现在位置: 首页 > PyTorch 教程 > 正文

PyTorch torch.is_inference_mode_enabled 函数


Pytorch torch 参考手册 Pytorch torch 参考手册

torch.is_inference_mode_enabled 是 PyTorch 中用于检查当前是否启用推理模式的函数。它返回当前是否处于 inference_mode 上下文的布尔值。

这在编写需要根据推理模式状态执行不同逻辑的代码时非常有用。

函数定义

torch.is_inference_mode_enabled()

参数:

  • 无参数。

返回值:

  • 返回一个布尔值:如果当前启用了推理模式,返回 True;否则返回 False

使用示例

示例 1: 基本用法

实例

import torch

# 默认状态下,推理模式是禁用的
print("默认状态:", torch.is_inference_mode_enabled())

# 在 inference_mode 上下文中
with torch.inference_mode():
    print("在 inference_mode 中:", torch.is_inference_mode_enabled())

# 退出后恢复
print("退出后:", torch.is_inference_mode_enabled())

输出结果为:

默认状态: False
在 inference_mode 中: True
退出后: False

示例 2: 对比 is_grad_enabled 和 is_inference_mode_enabled

实例

import torch

# 在 no_grad 中
with torch.no_grad():
    print("在 no_grad 中:")
    print("  is_grad_enabled:", torch.is_grad_enabled())
    print("  is_inference_mode_enabled:", torch.is_inference_mode_enabled())

# 在 inference_mode 中
with torch.inference_mode():
    print("在 inference_mode 中:")
    print("  is_grad_enabled:", torch.is_grad_enabled())
    print("  is_inference_mode_enabled:", torch.is_inference_mode_enabled())

# 默认状态
print("默认状态:")
print("  is_grad_enabled:", torch.is_grad_enabled())
print("  is_inference_mode_enabled:", torch.is_inference_mode_enabled())

输出结果为:

在 no_grad 中:
  is_grad_enabled: False
  is_inference_mode_enabled: False
在 inference_mode 中:
  is_grad_enabled: False
  is_inference_mode_enabled: True
默认状态:
  is_grad_enabled: True
  is_inference_mode_enabled: False

示例 3: 在条件判断中使用

实例

import torch
import torch.nn as nn

def forward_pass(x, model):
    """根据推理模式执行不同的优化"""
    if torch.is_inference_mode_enabled():
        print("使用推理模式优化")
        # 推理模式下的优化计算
        return model(x)
    elif torch.is_grad_enabled():
        print("训练模式")
        return model(x)
    else:
        print("eval 模式")
        return model(x)

model = nn.Linear(10, 5)
x = torch.randn(5, 10)

# 测试不同状态
print("=== 训练模式 ===")
with torch.enable_grad():
    result = forward_pass(x, model)

print("n=== 推理模式 ===")
with torch.inference_mode():
    result = forward_pass(x, model)

print("n=== eval 模式 ===")
with torch.no_grad():
    result = forward_pass(x, model)

输出结果为:

=== 训练模式 ===
训练模式

=== 推理模式 ===
使用推理模式优化

=== eval 模式 ===
eval 模式

示例 4: 在自定义模块中检测

实例

import torch
import torch.nn as nn

class OptimizedLayer(nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(20, 10))

    def forward(self, x):
        # 检测推理模式并应用优化
        if torch.is_inference_mode_enabled():
            # 推理模式:使用更高效的计算方式
            return torch.mm(x, self.weight)
        elif torch.is_grad_enabled():
            # 训练模式:保留梯度计算
            return torch.mm(x, self.weight)
        else:
            # eval 模式:不需要梯度但可以使用优化
            with torch.no_grad():
                return torch.mm(x, self.weight)

layer = OptimizedLayer()
x = torch.randn(5, 10)

print("=== 训练模式 ===")
with torch.enable_grad():
    _ = layer(x)

print("n=== 推理模式 ===")
with torch.inference_mode():
    _ = layer(x)

print("n=== eval 模式 ===")
with torch.no_grad():
    _ = layer(x)

输出结果为:

=== 训练模式 ===
训练模式

=== 推理模式 ===
推理模式:使用更高效的计算方式

=== eval 模式 ===
eval 模式

相关函数

  • torch.inference_mode(): 启用推理模式的上下文管理器。
  • torch.no_grad(): 禁用梯度计算的上下文管理器。
  • torch.is_grad_enabled(): 检查是否启用梯度计算。

注意事项

  • is_inference_mode_enabled 是一个只读函数,不会改变任何状态。
  • 它专门用于检测 inference_mode,而不是 no_grad
  • inference_mode 中时,is_grad_enabled 也返回 False,但反之不成立。
  • 编写通用代码时,可以使用此函数来区分不同的运行模式。

Pytorch torch 参考手册 Pytorch torch 参考手册