PyTorch torch.enable_grad 函数
torch.enable_grad 是 PyTorch 中用于启用梯度计算的上下文管理器。它与 torch.no_grad 相反,用于在需要梯度的代码块中明确启用梯度计算。
这在推理代码中临时需要训练模型,或者在某个局部区域需要启用梯度计算时非常有用。
函数定义
torch.enable_grad()
参数:
- 无参数。这是一个上下文管理器。
返回值:
- 返回一个上下文管理器,在该上下文中启用梯度计算。
使用示例
示例 1: 基本用法
实例
import torch
# 默认情况下,在 no_grad 上下文中
with torch.no_grad():
x = torch.tensor([1.0, 2.0, 3.0])
print("在 no_grad 中:", x.requires_grad)
# 使用 enable_grad 临时启用梯度
with torch.enable_grad():
y = x * 2
print("在 enable_grad 中:", y.requires_grad)
# 退出后恢复 no_grad 状态
z = x * 2
print("退出后:", z.requires_grad)
# 默认情况下,在 no_grad 上下文中
with torch.no_grad():
x = torch.tensor([1.0, 2.0, 3.0])
print("在 no_grad 中:", x.requires_grad)
# 使用 enable_grad 临时启用梯度
with torch.enable_grad():
y = x * 2
print("在 enable_grad 中:", y.requires_grad)
# 退出后恢复 no_grad 状态
z = x * 2
print("退出后:", z.requires_grad)
输出结果为:
在 no_grad 中: False 在 enable_grad 中: True 退出后: False
示例 2: 训练和推理混合使用
实例
import torch
import torch.nn as nn
model = nn.Linear(10, 2)
# 推理模式
with torch.no_grad():
x = torch.randn(5, 10)
output1 = model(x)
print("推理输出:", output1.shape)
# 如果需要在推理时临时训练部分参数
with torch.enable_grad():
# 创建一个需要梯度的张量进行计算
temp_weight = torch.randn(10, 10, requires_grad=True)
temp_output = torch.mm(x, temp_weight)
print("临时启用梯度:", temp_output.requires_grad)
import torch.nn as nn
model = nn.Linear(10, 2)
# 推理模式
with torch.no_grad():
x = torch.randn(5, 10)
output1 = model(x)
print("推理输出:", output1.shape)
# 如果需要在推理时临时训练部分参数
with torch.enable_grad():
# 创建一个需要梯度的张量进行计算
temp_weight = torch.randn(10, 10, requires_grad=True)
temp_output = torch.mm(x, temp_weight)
print("临时启用梯度:", temp_output.requires_grad)
输出结果为:
推理输出: torch.Size([5, 2]) 临时启用梯度: True
示例 3: 在装饰器中使用
实例
import torch
@torch.enable_grad()
def train_step(x, y):
"""模拟训练步骤"""
# 这个函数内部会启用梯度计算
loss = (x - y).sum()
return loss
# 在 no_grad 上下文中调用
with torch.no_grad():
x = torch.tensor([1.0, 2.0, 3.0])
y = torch.tensor([0.0, 0.0, 0.0])
# 装饰器确保梯度计算启用
loss = train_step(x, y)
print("Loss:", loss)
print("需要梯度:", loss.requires_grad)
@torch.enable_grad()
def train_step(x, y):
"""模拟训练步骤"""
# 这个函数内部会启用梯度计算
loss = (x - y).sum()
return loss
# 在 no_grad 上下文中调用
with torch.no_grad():
x = torch.tensor([1.0, 2.0, 3.0])
y = torch.tensor([0.0, 0.0, 0.0])
# 装饰器确保梯度计算启用
loss = train_step(x, y)
print("Loss:", loss)
print("需要梯度:", loss.requires_grad)
输出结果为:
Loss: tensor(6.) 需要梯度: True
相关函数
torch.no_grad(): 禁用梯度计算。torch.set_grad_enabled(grad): 根据参数启用或禁用梯度计算。torch.is_grad_enabled(): 检查当前是否启用梯度计算。
注意事项
enable_grad主要用于在no_grad上下文中临时启用梯度计算。- 如果在全局启用了梯度的环境中使用
enable_grad,不会有任何效果。 - 建议使用
enable_grad装饰器来确保函数内部始终启用梯度。

Pytorch torch 参考手册