PyTorch torch.no_grad 函数
torch.no_grad 是 PyTorch 中用于禁用梯度计算的上下文管理器。在 no_grad 块内创建的张量不会计算梯度,这可以显著减少内存消耗并提高推理速度。
这在模型推理(inference)和评估阶段是必不可少的,可以大幅提升性能和节省内存。
函数定义
torch.no_grad()
参数:
- 无参数。这是一个上下文管理器。
返回值:
- 返回一个上下文管理器,在该上下文中禁用梯度计算。
使用示例
示例 1: 基本用法
实例
import torch
# 创建一个需要梯度的张量
x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
# 在 no_grad 上下文中计算
with torch.no_grad():
y = x * 2
print("在 no_grad 中:", y.requires_grad)
# 在 no_grad 外部
z = x * 2
print("在 no_grad 外:", z.requires_grad)
# 创建一个需要梯度的张量
x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
# 在 no_grad 上下文中计算
with torch.no_grad():
y = x * 2
print("在 no_grad 中:", y.requires_grad)
# 在 no_grad 外部
z = x * 2
print("在 no_grad 外:", z.requires_grad)
输出结果为:
在 no_grad 中: False 在 no_grad 外: True
示例 2: 模型推理
实例
import torch
import torch.nn as nn
# 定义一个简单的模型
model = nn.Linear(10, 2)
# 创建输入数据
x = torch.randn(1, 10)
# 推理时使用 no_grad
with torch.no_grad():
output = model(x)
print("输出:", output)
# 不需要计算梯度
# 或者使用装饰器
# @torch.no_grad()
# def predict(x):
# return model(x)
import torch.nn as nn
# 定义一个简单的模型
model = nn.Linear(10, 2)
# 创建输入数据
x = torch.randn(1, 10)
# 推理时使用 no_grad
with torch.no_grad():
output = model(x)
print("输出:", output)
# 不需要计算梯度
# 或者使用装饰器
# @torch.no_grad()
# def predict(x):
# return model(x)
输出结果为:
输出: tensor([[0.0920, 0.3557]])
示例 3: 评估模式
实例
import torch
import torch.nn as nn
model = nn.Sequential(
nn.Linear(10, 5),
nn.ReLU(),
nn.Linear(5, 2)
)
# 切换到评估模式
model.eval()
# 准备测试数据
test_input = torch.randn(5, 10)
# 评估时禁用梯度
with torch.no_grad():
predictions = model(test_input)
print("预测结果形状:", predictions.shape)
import torch.nn as nn
model = nn.Sequential(
nn.Linear(10, 5),
nn.ReLU(),
nn.Linear(5, 2)
)
# 切换到评估模式
model.eval()
# 准备测试数据
test_input = torch.randn(5, 10)
# 评估时禁用梯度
with torch.no_grad():
predictions = model(test_input)
print("预测结果形状:", predictions.shape)
输出结果为:
预测结果形状: torch.Size([5, 2])
示例 4: 对比内存使用
实例
import torch
import torch.nn as nn
model = nn.Linear(1000, 1000)
# 创建大量输入
inputs = [torch.randn(100, 1000) for _ in range(100)]
# 不使用 no_grad(会记录梯度历史)
print("不使用 no_grad:")
for inp in inputs[:5]:
_ = model(inp)
# 使用 no_grad(不记录梯度历史)
print("使用 no_grad:")
with torch.no_grad():
for inp in inputs[:5]:
_ = model(inp)
import torch.nn as nn
model = nn.Linear(1000, 1000)
# 创建大量输入
inputs = [torch.randn(100, 1000) for _ in range(100)]
# 不使用 no_grad(会记录梯度历史)
print("不使用 no_grad:")
for inp in inputs[:5]:
_ = model(inp)
# 使用 no_grad(不记录梯度历史)
print("使用 no_grad:")
with torch.no_grad():
for inp in inputs[:5]:
_ = model(inp)
使用 no_grad 可以显著减少内存消耗,因为不需要为中间变量存储梯度信息。
相关函数
torch.enable_grad(): 启用梯度计算(与no_grad相反)。torch.set_grad_enabled(grad): 根据参数启用或禁用梯度计算。torch.inference_mode(): 更严格的推理模式,同时禁用梯度计算和 autograd。
注意事项
- 在模型推理和评估时,始终使用
torch.no_grad()以节省内存和提升速度。 - 如果在
no_grad块内创建的变量需要在块外使用,需要手动复制出来。 - 与
model.eval()配合使用效果最佳。

Pytorch torch 参考手册