PyTorch torch.inference_mode 函数
torch.inference_mode 是 PyTorch 中用于推理模式的上下文管理器。它比 torch.no_grad 更严格,不仅禁用梯度计算,还禁用 autograd 引擎的所有跟踪功能。
这在模型推理时比 no_grad 更加高效,可以进一步减少内存使用和提升推理速度。
函数定义
torch.inference_mode(mode=True)
参数:
mode(bool, 可选): 如果为True(默认值),启用推理模式;如果为False,退出推理模式。默认为True。
返回值:
- 返回一个上下文管理器,在该上下文中禁用梯度和 autograd。
使用示例
示例 1: 基本用法
实例
import torch
x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
# 在 inference_mode 上下文中
with torch.inference_mode():
y = x * 2
print("在 inference_mode 中:", y.requires_grad)
# 在 no_grad 上下文中
with torch.no_grad():
z = x * 2
print("在 no_grad 中:", z.requires_grad)
x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
# 在 inference_mode 上下文中
with torch.inference_mode():
y = x * 2
print("在 inference_mode 中:", y.requires_grad)
# 在 no_grad 上下文中
with torch.no_grad():
z = x * 2
print("在 no_grad 中:", z.requires_grad)
输出结果为:
在 inference_mode 中: False 在 no_grad 中: False
示例 2: 对比 no_grad 和 inference_mode
实例
import torch
# 创建张量
x = torch.randn(100, 100)
# 在 inference_mode 中
with torch.inference_mode():
# 进行了大量计算
for _ in range(10):
x = torch.mm(x, x)
# 即使计算完成,在上下文中的张量也不能用于反向传播
result = x.sum()
# 检查是否可以转换为需要梯度的张量
print("在 inference_mode 中:", result.is_leaf)
# 在 no_grad 中做同样的计算
x2 = torch.randn(100, 100)
with torch.no_grad():
for _ in range(10):
x2 = torch.mm(x2, x2)
result2 = x2.sum()
print("在 no_grad 中:", result2.is_leaf)
# 创建张量
x = torch.randn(100, 100)
# 在 inference_mode 中
with torch.inference_mode():
# 进行了大量计算
for _ in range(10):
x = torch.mm(x, x)
# 即使计算完成,在上下文中的张量也不能用于反向传播
result = x.sum()
# 检查是否可以转换为需要梯度的张量
print("在 inference_mode 中:", result.is_leaf)
# 在 no_grad 中做同样的计算
x2 = torch.randn(100, 100)
with torch.no_grad():
for _ in range(10):
x2 = torch.mm(x2, x2)
result2 = x2.sum()
print("在 no_grad 中:", result2.is_leaf)
输出结果为:
在 inference_mode 中: False 在 no_grad 中: True
示例 3: 模型推理
实例
import torch
import torch.nn as nn
# 定义一个简单的模型
model = nn.Sequential(
nn.Linear(10, 20),
nn.ReLU(),
nn.Linear(20, 5)
)
model.eval()
# 创建输入数据
x = torch.randn(100, 10)
# 使用 inference_mode 进行推理
with torch.inference_mode():
output = model(x)
print("输出形状:", output.shape)
print("输出 requires_grad:", output.requires_grad)
# 也可以使用装饰器
@torch.inference_mode()
def predict(x):
return model(x)
result = predict(x)
print("装饰器方式 - 输出形状:", result.shape)
import torch.nn as nn
# 定义一个简单的模型
model = nn.Sequential(
nn.Linear(10, 20),
nn.ReLU(),
nn.Linear(20, 5)
)
model.eval()
# 创建输入数据
x = torch.randn(100, 10)
# 使用 inference_mode 进行推理
with torch.inference_mode():
output = model(x)
print("输出形状:", output.shape)
print("输出 requires_grad:", output.requires_grad)
# 也可以使用装饰器
@torch.inference_mode()
def predict(x):
return model(x)
result = predict(x)
print("装饰器方式 - 输出形状:", result.shape)
输出结果为:
输出形状: torch.Size([100, 5]) 输出 requires_grad: False 装饰器方式 - 输出形状: torch.Size([100, 5])
示例 4: 内存优化对比
实例
import torch
import torch.nn as nn
model = nn.Sequential(
nn.Linear(1000, 1000),
nn.ReLU(),
nn.Linear(1000, 1000),
nn.ReLU(),
nn.Linear(1000, 10)
)
# 测试不同模式的内存占用
x = torch.randn(50, 1000)
print("不使用任何上下文管理器:")
_ = model(x)
print("n使用 no_grad:")
with torch.no_grad():
_ = model(x)
print("n使用 inference_mode:")
with torch.inference_mode():
_ = model(x)
import torch.nn as nn
model = nn.Sequential(
nn.Linear(1000, 1000),
nn.ReLU(),
nn.Linear(1000, 1000),
nn.ReLU(),
nn.Linear(1000, 10)
)
# 测试不同模式的内存占用
x = torch.randn(50, 1000)
print("不使用任何上下文管理器:")
_ = model(x)
print("n使用 no_grad:")
with torch.no_grad():
_ = model(x)
print("n使用 inference_mode:")
with torch.inference_mode():
_ = model(x)
使用 inference_mode 可以进一步优化内存,因为它完全禁用了 autograd 引擎。
相关函数
torch.no_grad(): 禁用梯度计算,但仍然保留部分 autograd 功能。torch.enable_grad(): 启用梯度计算。torch.is_inference_mode_enabled(): 检查是否启用推理模式。
注意事项
inference_mode比no_grad更严格,禁用更多的功能。- 在
inference_mode中创建的张量被标记为非叶子节点,无法用于反向传播。 - 推荐在模型推理和评估时使用
inference_mode以获得最佳性能。 inference_mode不能与no_grad嵌套使用。

Pytorch torch 参考手册