PyTorch torch.load_state_dict 函数
torch.load_state_dict 是 PyTorch 中用于加载状态字典的函数。从字典对象中加载模型优化器等的状态。
函数定义
torch.load_state_dict(state_dict, strict=True)
使用示例
实例
import torch
import torch.nn as nn
# 定义一个简单的模型
model = nn.Linear(10, 5)
# 保存模型状态字典
state_dict = model.state_dict()
torch.save(state_dict, 'model_state.pt')
# 加载状态字典
loaded_state = torch.load('model_state.pt')
model.load_state_dict(loaded_state)
print("State dict loaded successfully!")
import torch.nn as nn
# 定义一个简单的模型
model = nn.Linear(10, 5)
# 保存模型状态字典
state_dict = model.state_dict()
torch.save(state_dict, 'model_state.pt')
# 加载状态字典
loaded_state = torch.load('model_state.pt')
model.load_state_dict(loaded_state)
print("State dict loaded successfully!")

Pytorch torch 参考手册