现在位置: 首页 > PyTorch 教程 > 正文

PyTorch torch.manual_seed 函数


Pytorch torch 参考手册 Pytorch torch 参考手册

torch.manual_seed 是 PyTorch 中用于设置随机数生成器种子的函数。设置种子可以确保结果的可重复性。

这在需要保证实验结果可复现时非常重要,例如调试、论文复现等场景。

函数定义

torch.manual_seed(seed)

参数:

  • seed (int): 随机种子。

返回值:


使用示例

示例 1: 设置种子保证可复现性

实例

import torch

# 设置随机种子
torch.manual_seed(42)

# 每次生成的随机数相同
x = torch.randn(3)
print("第一次:", x)

# 重新设置相同种子
torch.manual_seed(42)
y = torch.randn(3)
print("第二次:", y)

print("结果相同:", torch.equal(x, y))

输出结果为:

第一次: tensor([ 0.3367,  0.1288,  0.2345])
第二次: tensor([ 0.3367,  0.1288,  0.2345])
结果相同: True

示例 2: 完整可复现训练

实例

import torch
import random
import numpy as np

def set_seed(seed=42):
    # 设置 PyTorch 种子
    torch.manual_seed(seed)
    # 设置 NumPy 种子
    np.random.seed(seed)
    # 设置 Python random 种子
    random.seed(seed)
    # 确保 CUDA 确定性好(如果使用)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# 设置种子
set_seed(42)

# 生成随机数据
x = torch.randn(3, 4)
print(x)

输出结果为:

tensor([[ 0.3367,  0.1288,  0.2345,  0.2303],
        [-1.1229, -0.1863,  0.1735, -0.5524],
        [ 0.6351, -0.2582,  0.4602, -0.5270]])

为了完全保证可复现性,需要同时设置 PyTorch、NumPy 和 Python random 的种子。


Pytorch torch 参考手册 Pytorch torch 参考手册