现在位置: 首页 > PyTorch 教程 > 正文

PyTorch torch.set_grad_enabled 函数


Pytorch torch 参考手册 Pytorch torch 参考手册

torch.set_grad_enabled 是 PyTorch 中用于全局设置是否启用梯度计算的函数。它可以动态地启用或禁用梯度计算,与上下文管理器不同,它是一个可以改变全局状态的函数。

这在需要根据条件动态控制梯度计算时非常有用。

函数定义

torch.set_grad_enabled(mode)

参数:

  • mode (bool): 如果为 True,启用梯度计算;如果为 False,禁用梯度计算。

返回值:

  • 返回一个上下文管理器,可以用于 with 语句。

使用示例

示例 1: 基本用法

实例

import torch

# 创建一个需要梯度的张量
x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)

# 禁用梯度计算
torch.set_grad_enabled(False)
y1 = x * 2
print("禁用梯度后:", y1.requires_grad)

# 启用梯度计算
torch.set_grad_enabled(True)
y2 = x * 2
print("启用梯度后:", y2.requires_grad)

输出结果为:

禁用梯度后: False
启用梯度后: True

示例 2: 作为上下文管理器使用

实例

import torch

x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)

# 使用上下文管理器
with torch.set_grad_enabled(False):
    y1 = x * 2
    print("在上下文管理器内:", y1.requires_grad)

y2 = x * 2
print("在上下文管理器外:", y2.requires_grad)

输出结果为:

在上下文管理器内: False
在上下文管理器外: True

示例 3: 动态控制训练和推理

实例

import torch
import torch.nn as nn

model = nn.Linear(10, 2)

def forward_pass(x, training=True):
    """根据 training 参数控制梯度"""
    with torch.set_grad_enabled(training):
        output = model(x)
        print(f"training={training}, requires_grad={output.requires_grad}")
    return output

# 训练模式
x = torch.randn(5, 10)
forward_pass(x, training=True)

# 推理模式
forward_pass(x, training=False)

输出结果为:

training=True, requires_grad=True
training=False, requires_grad=False

示例 4: 保存和恢复梯度状态

实例

import torch

# 初始状态
print("初始状态:", torch.is_grad_enabled())

# 创建返回原始状态的上下文管理器
old = torch.is_grad_enabled()

# 临时禁用梯度
with torch.set_grad_enabled(False):
    print("在内部:", torch.is_grad_enabled())

# 自动恢复(但这里需要手动恢复)
torch.set_grad_enabled(old)
print("恢复后:", torch.is_grad_enabled())

输出结果为:

初始状态: True
在内部: False
恢复后: True

相关函数

  • torch.no_grad(): 禁用梯度计算的上下文管理器。
  • torch.enable_grad(): 启用梯度计算的上下文管理器。
  • torch.is_grad_enabled(): 检查当前是否启用梯度计算。

注意事项

  • set_grad_enabled 可以作为函数直接调用改变全局状态,也可以作为上下文管理器使用。
  • 作为函数调用时,需要手动恢复原始状态,否则会影响后续代码。
  • 建议使用上下文管理器的方式,以确保状态正确恢复。
  • 注意全局状态的影响,在复杂代码中要小心使用。

Pytorch torch 参考手册 Pytorch torch 参考手册