现在位置: 首页 > PyTorch 教程 > 正文

PyTorch torch.load_state_dict 函数


Pytorch torch 参考手册 Pytorch torch 参考手册

torch.load_state_dict 是 PyTorch 中用于加载状态字典的函数。从字典对象中加载模型优化器等的状态。

函数定义

torch.load_state_dict(state_dict, strict=True)

使用示例

实例

import torch
import torch.nn as nn

# 定义一个简单的模型
model = nn.Linear(10, 5)

# 保存模型状态字典
state_dict = model.state_dict()
torch.save(state_dict, 'model_state.pt')

# 加载状态字典
loaded_state = torch.load('model_state.pt')
model.load_state_dict(loaded_state)
print("State dict loaded successfully!")

Pytorch torch 参考手册 Pytorch torch 参考手册