The following error may occur when loading the weights of a pre-trained model with load_state_dict:
RuntimeError: Error(s) in loading state_dict for model:
Missing key(s) in state_dict: ~~~~
Unexpected key(s) in state_dict: ~~~~
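This happens when the checkpoint lacks keys the model expects (missing keys) or contains keys the model does not have (unexpected keys), typically because the key names do not match.
Setting strict=False resolves this error easily: keys whose names do not match are skipped instead of raising an error.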
model.load_state_dict(checkpoint, strict=False)
For more details, check the PyTorch documentation for load_state_dict.
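To see what strict=False actually skipped, you can inspect its return value: load_state_dict returns a named tuple with missing_keys and unexpected_keys. A minimal sketch on a toy model; TinyNet and the renamed "classifier." keys are hypothetical and only for illustration:

```python
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical toy model: a backbone layer plus a classification head."""

    def __init__(self, num_classes: int = 3):
        super().__init__()
        self.backbone = nn.Linear(4, 8)
        self.head = nn.Linear(8, num_classes)


# Simulate a checkpoint whose head keys use a different name ("classifier.").
source = TinyNet().state_dict()
checkpoint = {k.replace("head.", "classifier."): v for k, v in source.items()}

model = TinyNet()
# strict=False skips keys whose names do not match instead of raising.
result = model.load_state_dict(checkpoint, strict=False)
print("missing:", result.missing_keys)        # in the model, absent from the checkpoint
print("unexpected:", result.unexpected_keys)  # in the checkpoint, absent from the model
```

Checking these two lists is a quick way to confirm that only the layers you meant to skip were skipped.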
However, if you change a layer before loading (e.g. change num_classes), the following error may occur because the parameter sizes no longer match. strict=False does not prevent it, since it only relaxes the key-name check, not the shape check.
RuntimeError: Error(s) in loading state_dict for model:
size mismatch for head.weight: copying a param with shape torch.Size([1000, 768]) from checkpoint, the shape in current model is torch.Size([6, 768]).
size mismatch for head.bias: copying a param with shape torch.Size([1000]) from checkpoint, the shape in current model is torch.Size([6]).
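A minimal sketch reproducing this on a hypothetical toy model, with 10 classes in the "pre-trained" checkpoint and 6 in the new model standing in for the 1000 vs 6 above. All key names match, so strict=False has nothing to skip, and the shape check still raises:

```python
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical toy model: a backbone layer plus a classification head."""

    def __init__(self, num_classes: int):
        super().__init__()
        self.backbone = nn.Linear(4, 8)
        self.head = nn.Linear(8, num_classes)


pretrained = TinyNet(num_classes=10).state_dict()  # stands in for the checkpoint
model = TinyNet(num_classes=6)                        # head resized for the new task

error_message = ""
try:
    # Names match, so nothing is skipped; head.weight / head.bias fail the shape check.
    model.load_state_dict(pretrained, strict=False)
except RuntimeError as err:
    error_message = str(err)

print(error_message)
```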
Here is my solution: skip the problematic layer by using its keys.
As mentioned above, with strict=False, keys whose names do not match are ignored.
So renaming the keys of the problematic layer in the checkpoint solves it: the renamed entries no longer match any parameter in the model, so they are skipped as unexpected keys and the resized layer keeps its new initialization.
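import torch
from collections import OrderedDict

# checkpoint (path to the weights file), device and model are defined earlier
state_dict = torch.load(checkpoint, map_location=device)['model']  # weights stored under 'model'
temp = OrderedDict()
for i, j in state_dict.items():  # go through every key in the checkpoint
    name = i.replace("head.", "")  # rename the mismatched keys: "head.weight" -> "weight"
    temp[name] = j
model.load_state_dict(temp, strict=False)  # renamed keys are skipped as unexpected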
(The code is adapted from here.)
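The same renaming can be checked on a toy model without a real checkpoint file. A minimal sketch, assuming a hypothetical TinyNet whose classifier is called head as in the error above: after loading, the other layers should hold the checkpoint's weights while the head keeps its own initialization.

```python
from collections import OrderedDict

import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical toy model: a backbone layer plus a classification head."""

    def __init__(self, num_classes: int):
        super().__init__()
        self.backbone = nn.Linear(4, 8)
        self.head = nn.Linear(8, num_classes)


state_dict = TinyNet(num_classes=10).state_dict()  # stands in for the checkpoint
model = TinyNet(num_classes=6)
head_before = model.head.weight.detach().clone()

temp = OrderedDict()
for key, value in state_dict.items():
    temp[key.replace("head.", "")] = value  # "head.weight" -> "weight"

result = model.load_state_dict(temp, strict=False)
print(result.unexpected_keys)  # the renamed head entries, skipped
print(result.missing_keys)     # the new head, left as initialized

backbone_loaded = torch.equal(model.backbone.weight, state_dict["backbone.weight"])
head_untouched = torch.equal(model.head.weight, head_before)
```

One caveat with str.replace: it rewrites "head." anywhere in a key, so a layer whose name merely ends in "head" (say a hypothetical "multihead.proj") would be renamed too and silently left unloaded. Checking missing_keys after loading catches that.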
The model now loads successfully: all other layers receive the pre-trained weights, and the resized head keeps its fresh initialization.
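A different route, not the one used above, is to drop every checkpoint entry whose shape disagrees with the model instead of renaming keys by hand. A minimal sketch, assuming the same hypothetical TinyNet; filter_mismatched is a helper name made up for this example:

```python
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical toy model: a backbone layer plus a classification head."""

    def __init__(self, num_classes: int):
        super().__init__()
        self.backbone = nn.Linear(4, 8)
        self.head = nn.Linear(8, num_classes)


def filter_mismatched(model: nn.Module, state_dict: dict) -> dict:
    """Keep only checkpoint entries whose name and shape both match the model."""
    model_state = model.state_dict()
    return {
        key: value
        for key, value in state_dict.items()
        if key in model_state and value.shape == model_state[key].shape
    }


checkpoint = TinyNet(num_classes=1000).state_dict()  # stands in for the pre-trained weights
model = TinyNet(num_classes=6)

filtered = filter_mismatched(model, checkpoint)
result = model.load_state_dict(filtered, strict=False)
print(sorted(filtered))     # only the backbone entries survive
print(result.missing_keys)  # head.weight and head.bias keep their fresh initialization
```

This needs no knowledge of the layer names, but it also hides typos: any key that fails to match is silently dropped, so it is still worth printing missing_keys.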