问题的产生
在使用torch.save(model.state_dict,PATH)保存模型之后,在测试模型需要重新加载模型参数的时候会出现一下问题:
错误信息
1.错误信息提示报错原因
在训练模型的时候使用了多块gpu进行了分布式训练,因此,会在保存模型的时候使用了nn.DataParallel
,因此,在加载模型的key值中会多一个module.这七个字符。
问题的解决
1.去掉key值的前七个字符。
# original saved file with DataParallel
state_dict = torch.load(model_path)
# create new OrderedDict that does not contain `module.`
from collections import OrderedDict
new_state_dict = OrderedDict()
for k, v in state_dict.items():
name = k[7:] # remove `module.`
new_state_dict[name] = v
# load params
model.load_state_dict(new_state_dict)
2.更改模型加载方式
将模型直接加载到cpu上,然后在测试模型之前可以使用torch.cuda.is_available
加载到gpu上
-Save
torch.save(model.state_dict(), PATH)
-Load
device = torch.device('cpu')
model = YourNet()
model.load_state_dict(torch.load(PATH, map_location=device))