Inconsistent Inference Results After Loading PyTorch Model Checkpoint

  • Thread starter: Dana Dăscălescu (Guest)
I am experiencing an issue with my PyTorch model where I get different inference results after saving and loading the model state dictionary. Here is the code I am using to save the model state:

Code:
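import os
from typing import Optional

import torch

def save_net_state(self, base_path: str = '', epoch: Optional[int] = None, latest: bool = False, best: bool = False):
    # os.makedirs('') raises FileNotFoundError, so only create a non-empty path
    if base_path:
        os.makedirs(base_path, exist_ok=True)

    if latest:
        # Full training state, so training can be resumed
        filename = 'latest_checkpoint.pkl'
        to_save = {
            "epoch": epoch,
            "model_weights": self.model.state_dict(),
            "optimizer": self.optimizer.state_dict()
        }
    elif best:
        # Weights only, for evaluation
        filename = 'best_model.pkl'
        to_save = {
            "epoch": epoch,
            "model_weights": self.model.state_dict()
        }
    else:
        # Without this guard, filename and to_save would be undefined below
        raise ValueError("Set either latest=True or best=True")

    path_to_save = os.path.join(base_path, filename)
    torch.save(to_save, path_to_save)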

To load the model back for evaluation:

Code:
model_path = './best_model.pkl'
checkpoint = torch.load(model_path, map_location='cpu')
model.load_state_dict(checkpoint["model_weights"], strict=True)
model.eval()
model.cuda()
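
One way to narrow this down is to check whether a save/load round trip reproduces the outputs on a fixed batch. The following is a minimal sketch, not the model from this thread: `make_model` is a hypothetical stand-in network with BatchNorm and Dropout layers, and an in-memory buffer replaces the checkpoint file.

```python
import io

import torch
import torch.nn as nn


def make_model() -> nn.Module:
    # Hypothetical stand-in for the real network (assumption, not from the post)
    torch.manual_seed(0)
    return nn.Sequential(
        nn.Linear(8, 16),
        nn.BatchNorm1d(16),
        nn.ReLU(),
        nn.Dropout(0.5),
        nn.Linear(16, 3),
    )


def roundtrip_outputs(model: nn.Module, batch: torch.Tensor):
    """Run the batch before saving and again after loading into a fresh model."""
    model.eval()  # disable Dropout, use BatchNorm running statistics
    with torch.no_grad():
        before = model(batch)

    # Save to an in-memory buffer using the same dict layout as the post
    buffer = io.BytesIO()
    torch.save({"model_weights": model.state_dict()}, buffer)
    buffer.seek(0)
    checkpoint = torch.load(buffer, map_location="cpu")

    restored = make_model()
    restored.load_state_dict(checkpoint["model_weights"], strict=True)
    restored.eval()
    with torch.no_grad():
        after = restored(batch)
    return before, after


model = make_model()
model.train()
model(torch.randn(32, 8))  # one training-mode pass so BatchNorm statistics are non-trivial

batch = torch.randn(4, 8)
before, after = roundtrip_outputs(model, batch)
outputs_match = torch.allclose(before, after)
print("Outputs match after round trip:", outputs_match)
```

If the same check passes on the real model, serialization itself is probably not the cause, and it may be worth checking how the 81% figure was measured (for example, in which mode, on which data split, and at which epoch relative to the saved checkpoint).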

The state dictionary has no missing or extra keys. However, inference results differ from those I got before saving: at the time of saving, the model achieved around 81% accuracy, but after loading it consistently achieves around 78.5%. No changes were made to the model architecture, the dataset, or the preprocessing steps between saving and loading.

I've already tried to compare state_dicts:

Code:
def compare_state_dicts(model, checkpoint):
    model_keys = set(model.state_dict().keys())
    checkpoint_keys = set(checkpoint["model_weights"].keys())

    missing_keys = model_keys - checkpoint_keys
    extra_keys = checkpoint_keys - model_keys

    print(f"Missing keys: {missing_keys}")
    print(f"Extra keys: {extra_keys}")

compare_state_dicts(model, checkpoint)

This results in empty sets.
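
Matching key sets only show that the entry names agree; they say nothing about the stored values. A minimal sketch of a value-level comparison follows, assuming a small illustrative network (`net`) and a hypothetical helper `diff_state_dicts`, neither of which is from the original code. It also shows that a training-mode forward pass alone changes the BatchNorm buffers in the state dict.

```python
import torch
import torch.nn as nn


def diff_state_dicts(model: nn.Module, saved_weights: dict) -> list:
    """Return the names of entries whose values differ from the saved ones."""
    mismatched = []
    for name, tensor in model.state_dict().items():
        saved = saved_weights[name]
        if tensor.shape != saved.shape or not torch.equal(tensor.cpu(), saved.cpu()):
            mismatched.append(name)
    return mismatched


torch.manual_seed(0)
net = nn.Sequential(nn.Linear(4, 4), nn.BatchNorm1d(4))

# Clone the tensors: state_dict() returns references to the live parameters
snapshot = {k: v.clone() for k, v in net.state_dict().items()}
same = diff_state_dicts(net, snapshot)

net.train()
net(torch.randn(16, 4))  # no optimizer step, but BatchNorm running stats update
changed = diff_state_dicts(net, snapshot)

print("Differences right after snapshot:", same)
print("Differences after a training-mode forward pass:", changed)
```

Applied to the thread's setup, the call would be `diff_state_dicts(model, checkpoint["model_weights"])` after loading; an empty list would indicate that the loaded values match the file.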