现在位置: 首页 > PyTorch 教程 > 正文

PyTorch torch.save 与 torch.load 函数


Pytorch torch 参考手册 Pytorch torch 参考手册

torch.savetorch.load 是 PyTorch 中用于序列化(保存)和反序列化(加载)张量、模型和其他 Python 对象的函数。

这些函数在保存训练好的模型、保存检查点以便恢复训练等场景中必不可少。

函数定义

torch.save(obj, f, pickle_module, pickle_protocol)
torch.load(f, map_location, pickle_module, weights_only)

torch.save 参数:

  • obj: 要保存的对象,可以是张量、模型、字典、列表等。
  • f: 文件路径(字符串或文件对象)。
  • pickle_module (可选): 用于序列化的模块。
  • pickle_protocol (可选): 序列化协议版本。

torch.load 参数:

  • f: 文件路径(字符串或文件对象)。
  • map_location (可选): 指定如何将存储映射到不同的设备。
  • pickle_module (可选): 用于反序列化的模块。
  • weights_only (bool, 可选): 是否只加载权重而不加载 Python 对象。

使用示例

示例 1: 保存和加载张量

实例

import torch

# 创建一些张量
x = torch.tensor([1, 2, 3, 4, 5])
y = torch.randn(3, 4)

# 保存张量到文件
torch.save({'x': x, 'y': y}, 'tensors.pth')

# 从文件加载张量
loaded = torch.load('tensors.pth')

print("加载的数据:", loaded)
print("x:", loaded['x'])
print("y:", loaded['y'])

输出结果为:

加载的数据: {'x': tensor([1, 2, 3, 4, 5]), 'y': tensor([[-0.3042, -0.9077, -1.0826,  0.9333],
        [ 0.0551,  0.6728,  0.5942, -0.1522],
        [-0.3744,  0.9239, -0.2104, -0.5239]])}
x: tensor([1, 2, 3, 4, 5])
y: tensor([[-0.3042, -0.9077, -1.0826,  0.9333],
        [ 0.0551,  0.6728,  0.5942, -0.1522],
        [-0.3744,  0.9239, -0.2104, -0.5239]])

示例 2: 保存和加载模型

实例

import torch
import torch.nn as nn

# 定义一个简单的神经网络
class SimpleNet(nn.Module):
    def __init__(self):
        super(SimpleNet, self).__init__()
        self.fc1 = nn.Linear(10, 5)
        self.fc2 = nn.Linear(5, 2)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# 创建模型实例
model = SimpleNet()

# 保存模型(保存整个模型)
torch.save(model, 'model.pth')

# 加载模型
loaded_model = torch.load('model.pth')

print("模型已保存并加载")
print(loaded_model)

示例 3: 只保存模型参数(推荐方式)

实例

import torch
import torch.nn as nn

class SimpleNet(nn.Module):
    def __init__(self):
        super(SimpleNet, self).__init__()
        self.fc1 = nn.Linear(10, 5)
        self.fc2 = nn.Linear(5, 2)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

model = SimpleNet()

# 只保存模型参数(推荐方式)
torch.save(model.state_dict(), 'model_weights.pth')

# 创建新模型并加载参数
new_model = SimpleNet()
new_model.load_state_dict(torch.load('model_weights.pth'))

print("模型参数已保存并加载")
print(new_model.state_dict().keys())

输出结果为:

模型参数已保存并加载
odict_keys(['fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias'])

推荐只保存模型参数(state_dict),而不是保存整个模型。这样可以在不同架构的模型之间复用参数。

示例 4: 在 CPU 和 GPU 之间迁移

实例

import torch

# 假设在 GPU 上保存了模型
if torch.cuda.is_available():
    x = torch.randn(2, 3, device='cuda')
    torch.save(x, 'tensor_gpu.pth')

    # 在 CPU 上加载
    x_cpu = torch.load('tensor_gpu.pth', map_location='cpu')
    print("加载到 CPU:", x_cpu.device)

使用 map_location 参数可以将张量加载到不同的设备上。


注意事项

  • 保存文件通常使用 .pth.pt 扩展名。
  • 推荐只保存模型参数(state_dict),而不是保存整个模型。
  • 加载时需要注意安全性和版本兼容性。

Pytorch torch 参考手册 Pytorch torch 参考手册