PyTorch torch.set_grad_enabled 函数
torch.set_grad_enabled 是 PyTorch 中用于全局设置是否启用梯度计算的函数。它可以动态地启用或禁用梯度计算,与上下文管理器不同,它是一个可以改变全局状态的函数。
这在需要根据条件动态控制梯度计算时非常有用。
函数定义
torch.set_grad_enabled(mode)
参数:
mode(bool): 如果为True,启用梯度计算;如果为False,禁用梯度计算。
返回值:
- 返回一个上下文管理器,可以用于
with语句。
使用示例
示例 1: 基本用法
实例
import torch
# 创建一个需要梯度的张量
x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
# 禁用梯度计算
torch.set_grad_enabled(False)
y1 = x * 2
print("禁用梯度后:", y1.requires_grad)
# 启用梯度计算
torch.set_grad_enabled(True)
y2 = x * 2
print("启用梯度后:", y2.requires_grad)
# 创建一个需要梯度的张量
x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
# 禁用梯度计算
torch.set_grad_enabled(False)
y1 = x * 2
print("禁用梯度后:", y1.requires_grad)
# 启用梯度计算
torch.set_grad_enabled(True)
y2 = x * 2
print("启用梯度后:", y2.requires_grad)
输出结果为:
禁用梯度后: False 启用梯度后: True
示例 2: 作为上下文管理器使用
实例
import torch
x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
# 使用上下文管理器
with torch.set_grad_enabled(False):
y1 = x * 2
print("在上下文管理器内:", y1.requires_grad)
y2 = x * 2
print("在上下文管理器外:", y2.requires_grad)
x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
# 使用上下文管理器
with torch.set_grad_enabled(False):
y1 = x * 2
print("在上下文管理器内:", y1.requires_grad)
y2 = x * 2
print("在上下文管理器外:", y2.requires_grad)
输出结果为:
在上下文管理器内: False 在上下文管理器外: True
示例 3: 动态控制训练和推理
实例
import torch
import torch.nn as nn
model = nn.Linear(10, 2)
def forward_pass(x, training=True):
"""根据 training 参数控制梯度"""
with torch.set_grad_enabled(training):
output = model(x)
print(f"training={training}, requires_grad={output.requires_grad}")
return output
# 训练模式
x = torch.randn(5, 10)
forward_pass(x, training=True)
# 推理模式
forward_pass(x, training=False)
import torch.nn as nn
model = nn.Linear(10, 2)
def forward_pass(x, training=True):
"""根据 training 参数控制梯度"""
with torch.set_grad_enabled(training):
output = model(x)
print(f"training={training}, requires_grad={output.requires_grad}")
return output
# 训练模式
x = torch.randn(5, 10)
forward_pass(x, training=True)
# 推理模式
forward_pass(x, training=False)
输出结果为:
training=True, requires_grad=True training=False, requires_grad=False
示例 4: 保存和恢复梯度状态
实例
import torch
# 初始状态
print("初始状态:", torch.is_grad_enabled())
# 创建返回原始状态的上下文管理器
old = torch.is_grad_enabled()
# 临时禁用梯度
with torch.set_grad_enabled(False):
print("在内部:", torch.is_grad_enabled())
# 自动恢复(但这里需要手动恢复)
torch.set_grad_enabled(old)
print("恢复后:", torch.is_grad_enabled())
# 初始状态
print("初始状态:", torch.is_grad_enabled())
# 创建返回原始状态的上下文管理器
old = torch.is_grad_enabled()
# 临时禁用梯度
with torch.set_grad_enabled(False):
print("在内部:", torch.is_grad_enabled())
# 自动恢复(但这里需要手动恢复)
torch.set_grad_enabled(old)
print("恢复后:", torch.is_grad_enabled())
输出结果为:
初始状态: True 在内部: False 恢复后: True
相关函数
torch.no_grad(): 禁用梯度计算的上下文管理器。torch.enable_grad(): 启用梯度计算的上下文管理器。torch.is_grad_enabled(): 检查当前是否启用梯度计算。
注意事项
set_grad_enabled可以作为函数直接调用改变全局状态,也可以作为上下文管理器使用。- 作为函数调用时,需要手动恢复原始状态,否则会影响后续代码。
- 建议使用上下文管理器的方式,以确保状态正确恢复。
- 注意全局状态的影响,在复杂代码中要小心使用。

Pytorch torch 参考手册