PyTorch torch.save 与 torch.load 函数
torch.save 和 torch.load 是 PyTorch 中用于序列化(保存)和反序列化(加载)张量、模型和其他 Python 对象的函数。
这些函数在保存训练好的模型、保存检查点以便恢复训练等场景中必不可少。
函数定义
torch.save(obj, f, pickle_module, pickle_protocol) torch.load(f, map_location, pickle_module, weights_only)
torch.save 参数:
obj: 要保存的对象,可以是张量、模型、字典、列表等。f: 文件路径(字符串或文件对象)。pickle_module(可选): 用于序列化的模块。pickle_protocol(可选): 序列化协议版本。
torch.load 参数:
f: 文件路径(字符串或文件对象)。map_location(可选): 指定如何将存储映射到不同的设备。pickle_module(可选): 用于反序列化的模块。weights_only(bool, 可选): 是否只加载权重而不加载 Python 对象。
使用示例
示例 1: 保存和加载张量
实例
import torch
# 创建一些张量
x = torch.tensor([1, 2, 3, 4, 5])
y = torch.randn(3, 4)
# 保存张量到文件
torch.save({'x': x, 'y': y}, 'tensors.pth')
# 从文件加载张量
loaded = torch.load('tensors.pth')
print("加载的数据:", loaded)
print("x:", loaded['x'])
print("y:", loaded['y'])
# 创建一些张量
x = torch.tensor([1, 2, 3, 4, 5])
y = torch.randn(3, 4)
# 保存张量到文件
torch.save({'x': x, 'y': y}, 'tensors.pth')
# 从文件加载张量
loaded = torch.load('tensors.pth')
print("加载的数据:", loaded)
print("x:", loaded['x'])
print("y:", loaded['y'])
输出结果为:
加载的数据: {'x': tensor([1, 2, 3, 4, 5]), 'y': tensor([[-0.3042, -0.9077, -1.0826, 0.9333],
[ 0.0551, 0.6728, 0.5942, -0.1522],
[-0.3744, 0.9239, -0.2104, -0.5239]])}
x: tensor([1, 2, 3, 4, 5])
y: tensor([[-0.3042, -0.9077, -1.0826, 0.9333],
[ 0.0551, 0.6728, 0.5942, -0.1522],
[-0.3744, 0.9239, -0.2104, -0.5239]])
示例 2: 保存和加载模型
实例
import torch
import torch.nn as nn
# 定义一个简单的神经网络
class SimpleNet(nn.Module):
def __init__(self):
super(SimpleNet, self).__init__()
self.fc1 = nn.Linear(10, 5)
self.fc2 = nn.Linear(5, 2)
def forward(self, x):
x = torch.relu(self.fc1(x))
x = self.fc2(x)
return x
# 创建模型实例
model = SimpleNet()
# 保存模型(保存整个模型)
torch.save(model, 'model.pth')
# 加载模型
loaded_model = torch.load('model.pth')
print("模型已保存并加载")
print(loaded_model)
import torch.nn as nn
# 定义一个简单的神经网络
class SimpleNet(nn.Module):
def __init__(self):
super(SimpleNet, self).__init__()
self.fc1 = nn.Linear(10, 5)
self.fc2 = nn.Linear(5, 2)
def forward(self, x):
x = torch.relu(self.fc1(x))
x = self.fc2(x)
return x
# 创建模型实例
model = SimpleNet()
# 保存模型(保存整个模型)
torch.save(model, 'model.pth')
# 加载模型
loaded_model = torch.load('model.pth')
print("模型已保存并加载")
print(loaded_model)
示例 3: 只保存模型参数(推荐方式)
实例
import torch
import torch.nn as nn
class SimpleNet(nn.Module):
def __init__(self):
super(SimpleNet, self).__init__()
self.fc1 = nn.Linear(10, 5)
self.fc2 = nn.Linear(5, 2)
def forward(self, x):
x = torch.relu(self.fc1(x))
x = self.fc2(x)
return x
model = SimpleNet()
# 只保存模型参数(推荐方式)
torch.save(model.state_dict(), 'model_weights.pth')
# 创建新模型并加载参数
new_model = SimpleNet()
new_model.load_state_dict(torch.load('model_weights.pth'))
print("模型参数已保存并加载")
print(new_model.state_dict().keys())
import torch.nn as nn
class SimpleNet(nn.Module):
def __init__(self):
super(SimpleNet, self).__init__()
self.fc1 = nn.Linear(10, 5)
self.fc2 = nn.Linear(5, 2)
def forward(self, x):
x = torch.relu(self.fc1(x))
x = self.fc2(x)
return x
model = SimpleNet()
# 只保存模型参数(推荐方式)
torch.save(model.state_dict(), 'model_weights.pth')
# 创建新模型并加载参数
new_model = SimpleNet()
new_model.load_state_dict(torch.load('model_weights.pth'))
print("模型参数已保存并加载")
print(new_model.state_dict().keys())
输出结果为:
模型参数已保存并加载 odict_keys(['fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias'])推荐只保存模型参数(
state_dict),而不是保存整个模型。这样可以在不同架构的模型之间复用参数。示例 4: 在 CPU 和 GPU 之间迁移
实例
import torch
# 假设在 GPU 上保存了模型
if torch.cuda.is_available():
x = torch.randn(2, 3, device='cuda')
torch.save(x, 'tensor_gpu.pth')
# 在 CPU 上加载
x_cpu = torch.load('tensor_gpu.pth', map_location='cpu')
print("加载到 CPU:", x_cpu.device)
使用
map_location参数可以将张量加载到不同的设备上。
注意事项
- 保存文件通常使用
.pth或.pt扩展名。 - 推荐只保存模型参数(
state_dict),而不是保存整个模型。 - 加载时需要注意安全性和版本兼容性。

Pytorch torch 参考手册