Errors when calling model.load_state_dict(state_dict)
Notes on using load_state_dict with strict=False
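With strict=False, load_state_dict does not load parameters whose keys do not match; it simply skips them, so it cannot fix a mismatch between the checkpoint and the model. It is useful only when the model has been modified and you need just the backbone parameters: the backbone keys will most likely still match (print and check them to be safe), so strict=False lets that part of the weights load.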
import torch

print(model_path)
state_dict = torch.load(model_path, map_location=device)
# print(state_dict.keys())  # uncomment to inspect the checkpoint keys
# print(model)              # uncomment to inspect the model structure
# strict=False skips keys that do not match instead of raising an error
model.load_state_dict(state_dict, strict=False)
# The checkpoint key carries a 'module.' prefix: 'module.backbone.layer1.0.weight'
bl10w = state_dict["module.backbone.layer1.0.weight"]
print(bl10w.shape)
# The model key has no prefix: 'backbone.layer1.0.weight'
model_state_dict = model.state_dict()
mbl10w = model_state_dict["backbone.layer1.0.weight"]
print(mbl10w.shape)
# Compare the same element in the checkpoint and in the model
print(bl10w[0][0][0][0])
print(mbl10w[0][0][0][0])
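Result (the two printed values differ, so the checkpoint weights were not loaded: the checkpoint keys carry a 'module.' prefix that the model keys lack):
D:/python/python_data/SiamRPNPP_Change/save/SiamRPNPP_Transformer_Alexnet_epoch_39.pt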
torch.Size([96, 3, 11, 11])
torch.Size([96, 3, 11, 11])
tensor(0.0473, device='cuda:0')
tensor(0.0057)
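The mismatch in the output above is easy to miss because strict=False raises no error. One way to catch it is to inspect the value returned by load_state_dict, which lists the keys that were skipped. The following is a minimal sketch, not the original project's code: TinyNet is a hypothetical stand-in for the real model, and the checkpoint is simulated by adding the 'module.' prefix by hand.

```python
import torch
from torch import nn


class TinyNet(nn.Module):
    """Hypothetical stand-in for the real model: a backbone plus a head."""

    def __init__(self):
        super().__init__()
        self.backbone = nn.Conv2d(3, 4, kernel_size=3)
        self.head = nn.Linear(4, 2)


torch.manual_seed(0)
source = TinyNet()
# Simulate a checkpoint whose keys all carry the 'module.' prefix.
checkpoint = {"module." + k: v.clone() for k, v in source.state_dict().items()}

torch.manual_seed(1)
model = TinyNet()
result = model.load_state_dict(checkpoint, strict=False)

# Model keys that found no match in the checkpoint.
print(result.missing_keys)
# Checkpoint keys that found no match in the model.
print(result.unexpected_keys)

# True only if the backbone weights were actually copied from the checkpoint.
weights_loaded = torch.equal(model.backbone.weight, source.backbone.weight)
print(weights_loaded)
```

If every model key shows up in missing_keys, as it does here, nothing was loaded and the model still holds its initial weights.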
The correct way to handle mismatched model keys:
Reference: https://blog.csdn.net/xuru_0927/article/details/119274321
import torch

# load model
print(model_path)
state_dict = torch.load(model_path, map_location=device)
# Remove 'module.' from the checkpoint keys so they match the model keys
state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()}
# Default strict=True: raises an error if any key still does not match
model.load_state_dict(state_dict)
bl10w = state_dict["backbone.layer1.0.weight"]
print(bl10w.shape)
model_state_dict = model.state_dict()
mbl10w = model_state_dict["backbone.layer1.0.weight"]
print(mbl10w.shape)
# Compare two elements in the checkpoint and in the model
print(bl10w[0][0][0][0])
print(mbl10w[0][0][0][0])
print(bl10w[0][0][0][1])
print(mbl10w[0][0][0][1])
Result:
D:/python/python_data/SiamRPNPP_Change/save/SiamRPNPP_Transformer_Alexnet_epoch_39.pt
torch.Size([96, 3, 11, 11])
torch.Size([96, 3, 11, 11])
tensor(0.0473, device='cuda:0')
tensor(0.0473)
tensor(-0.0176, device='cuda:0')
tensor(-0.0176)
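One caveat about the key renaming used above: str.replace removes every occurrence of 'module.', not just the leading one, so a key that contains that text elsewhere would be altered too. A minimal sketch of a variant that strips only a leading prefix follows. The helper name strip_prefix and the example keys are hypothetical, and plain strings stand in for tensors because only the keys matter here.

```python
def strip_prefix(state_dict, prefix="module."):
    """Return a copy of state_dict with a leading `prefix` removed from each key."""
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


# Placeholder values stand in for tensors; the keys are made up for illustration.
checkpoint = {
    "module.backbone.layer1.0.weight": "w0",
    "module.head.submodule.weight": "w1",
}

cleaned = strip_prefix(checkpoint)
print(list(cleaned))

# For comparison: replace() also removes the 'module.' inside 'submodule.'.
replaced = "module.head.submodule.weight".replace("module.", "")
print(replaced)
```

For checkpoints like the one in this post, where 'module.' appears only at the start of each key, both approaches give the same result.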