现在位置: 首页 > PyTorch 教程 > 正文

PyTorch torch.inference_mode 函数


Pytorch torch 参考手册 Pytorch torch 参考手册

torch.inference_mode 是 PyTorch 中用于推理模式的上下文管理器。它比 torch.no_grad 更严格,不仅禁用梯度计算,还禁用 autograd 引擎的所有跟踪功能。

这在模型推理时比 no_grad 更加高效,可以进一步减少内存使用和提升推理速度。

函数定义

torch.inference_mode(mode=True)

参数:

  • mode (bool, 可选): 如果为 True(默认值),启用推理模式;如果为 False,退出推理模式。默认为 True

返回值:

  • 返回一个上下文管理器,在该上下文中禁用梯度和 autograd。

使用示例

示例 1: 基本用法

实例

import torch

x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)

# 在 inference_mode 上下文中
with torch.inference_mode():
    y = x * 2
    print("在 inference_mode 中:", y.requires_grad)

# 在 no_grad 上下文中
with torch.no_grad():
    z = x * 2
    print("在 no_grad 中:", z.requires_grad)

输出结果为:

在 inference_mode 中: False
在 no_grad 中: False

示例 2: 对比 no_grad 和 inference_mode

实例

import torch

# 创建张量
x = torch.randn(100, 100)

# 在 inference_mode 中
with torch.inference_mode():
    # 进行了大量计算
    for _ in range(10):
        x = torch.mm(x, x)

    # 即使计算完成,在上下文中的张量也不能用于反向传播
    result = x.sum()

    # 检查是否可以转换为需要梯度的张量
    print("在 inference_mode 中:", result.is_leaf)

# 在 no_grad 中做同样的计算
x2 = torch.randn(100, 100)
with torch.no_grad():
    for _ in range(10):
        x2 = torch.mm(x2, x2)
    result2 = x2.sum()
    print("在 no_grad 中:", result2.is_leaf)

输出结果为:

在 inference_mode 中: False
在 no_grad 中: True

示例 3: 模型推理

实例

import torch
import torch.nn as nn

# 定义一个简单的模型
model = nn.Sequential(
    nn.Linear(10, 20),
    nn.ReLU(),
    nn.Linear(20, 5)
)

model.eval()

# 创建输入数据
x = torch.randn(100, 10)

# 使用 inference_mode 进行推理
with torch.inference_mode():
    output = model(x)
    print("输出形状:", output.shape)
    print("输出 requires_grad:", output.requires_grad)

# 也可以使用装饰器
@torch.inference_mode()
def predict(x):
    return model(x)

result = predict(x)
print("装饰器方式 - 输出形状:", result.shape)

输出结果为:

输出形状: torch.Size([100, 5])
输出 requires_grad: False
装饰器方式 - 输出形状: torch.Size([100, 5])

示例 4: 内存优化对比

实例

import torch
import torch.nn as nn

model = nn.Sequential(
    nn.Linear(1000, 1000),
    nn.ReLU(),
    nn.Linear(1000, 1000),
    nn.ReLU(),
    nn.Linear(1000, 10)
)

# 测试不同模式的内存占用
x = torch.randn(50, 1000)

print("不使用任何上下文管理器:")
_ = model(x)

print("n使用 no_grad:")
with torch.no_grad():
    _ = model(x)

print("n使用 inference_mode:")
with torch.inference_mode():
    _ = model(x)

使用 inference_mode 可以进一步优化内存,因为它完全禁用了 autograd 引擎。


相关函数

  • torch.no_grad(): 禁用梯度计算,但仍然保留部分 autograd 功能。
  • torch.enable_grad(): 启用梯度计算。
  • torch.is_inference_mode_enabled(): 检查是否启用推理模式。

注意事项

  • inference_modeno_grad 更严格,禁用更多的功能。
  • inference_mode 中创建的张量被标记为非叶子节点,无法用于反向传播。
  • 推荐在模型推理和评估时使用 inference_mode 以获得最佳性能。
  • inference_mode 不能与 no_grad 嵌套使用。

Pytorch torch 参考手册 Pytorch torch 参考手册