PyTorch torch.is_inference_mode_enabled 函数
torch.is_inference_mode_enabled 是 PyTorch 中用于检查当前是否启用推理模式的函数。它返回当前是否处于 inference_mode 上下文的布尔值。
这在编写需要根据推理模式状态执行不同逻辑的代码时非常有用。
函数定义
torch.is_inference_mode_enabled()
参数:
- 无参数。
返回值:
- 返回一个布尔值:如果当前启用了推理模式,返回
True;否则返回False。
使用示例
示例 1: 基本用法
实例
import torch
# 默认状态下,推理模式是禁用的
print("默认状态:", torch.is_inference_mode_enabled())
# 在 inference_mode 上下文中
with torch.inference_mode():
print("在 inference_mode 中:", torch.is_inference_mode_enabled())
# 退出后恢复
print("退出后:", torch.is_inference_mode_enabled())
# 默认状态下,推理模式是禁用的
print("默认状态:", torch.is_inference_mode_enabled())
# 在 inference_mode 上下文中
with torch.inference_mode():
print("在 inference_mode 中:", torch.is_inference_mode_enabled())
# 退出后恢复
print("退出后:", torch.is_inference_mode_enabled())
输出结果为:
默认状态: False 在 inference_mode 中: True 退出后: False
示例 2: 对比 is_grad_enabled 和 is_inference_mode_enabled
实例
import torch
# 在 no_grad 中
with torch.no_grad():
print("在 no_grad 中:")
print(" is_grad_enabled:", torch.is_grad_enabled())
print(" is_inference_mode_enabled:", torch.is_inference_mode_enabled())
# 在 inference_mode 中
with torch.inference_mode():
print("在 inference_mode 中:")
print(" is_grad_enabled:", torch.is_grad_enabled())
print(" is_inference_mode_enabled:", torch.is_inference_mode_enabled())
# 默认状态
print("默认状态:")
print(" is_grad_enabled:", torch.is_grad_enabled())
print(" is_inference_mode_enabled:", torch.is_inference_mode_enabled())
# 在 no_grad 中
with torch.no_grad():
print("在 no_grad 中:")
print(" is_grad_enabled:", torch.is_grad_enabled())
print(" is_inference_mode_enabled:", torch.is_inference_mode_enabled())
# 在 inference_mode 中
with torch.inference_mode():
print("在 inference_mode 中:")
print(" is_grad_enabled:", torch.is_grad_enabled())
print(" is_inference_mode_enabled:", torch.is_inference_mode_enabled())
# 默认状态
print("默认状态:")
print(" is_grad_enabled:", torch.is_grad_enabled())
print(" is_inference_mode_enabled:", torch.is_inference_mode_enabled())
输出结果为:
在 no_grad 中: is_grad_enabled: False is_inference_mode_enabled: False 在 inference_mode 中: is_grad_enabled: False is_inference_mode_enabled: True 默认状态: is_grad_enabled: True is_inference_mode_enabled: False
示例 3: 在条件判断中使用
实例
import torch
import torch.nn as nn
def forward_pass(x, model):
"""根据推理模式执行不同的优化"""
if torch.is_inference_mode_enabled():
print("使用推理模式优化")
# 推理模式下的优化计算
return model(x)
elif torch.is_grad_enabled():
print("训练模式")
return model(x)
else:
print("eval 模式")
return model(x)
model = nn.Linear(10, 5)
x = torch.randn(5, 10)
# 测试不同状态
print("=== 训练模式 ===")
with torch.enable_grad():
result = forward_pass(x, model)
print("n=== 推理模式 ===")
with torch.inference_mode():
result = forward_pass(x, model)
print("n=== eval 模式 ===")
with torch.no_grad():
result = forward_pass(x, model)
import torch.nn as nn
def forward_pass(x, model):
"""根据推理模式执行不同的优化"""
if torch.is_inference_mode_enabled():
print("使用推理模式优化")
# 推理模式下的优化计算
return model(x)
elif torch.is_grad_enabled():
print("训练模式")
return model(x)
else:
print("eval 模式")
return model(x)
model = nn.Linear(10, 5)
x = torch.randn(5, 10)
# 测试不同状态
print("=== 训练模式 ===")
with torch.enable_grad():
result = forward_pass(x, model)
print("n=== 推理模式 ===")
with torch.inference_mode():
result = forward_pass(x, model)
print("n=== eval 模式 ===")
with torch.no_grad():
result = forward_pass(x, model)
输出结果为:
=== 训练模式 === 训练模式 === 推理模式 === 使用推理模式优化 === eval 模式 === eval 模式
示例 4: 在自定义模块中检测
实例
import torch
import torch.nn as nn
class OptimizedLayer(nn.Module):
def __init__(self):
super().__init__()
self.weight = nn.Parameter(torch.randn(20, 10))
def forward(self, x):
# 检测推理模式并应用优化
if torch.is_inference_mode_enabled():
# 推理模式:使用更高效的计算方式
return torch.mm(x, self.weight)
elif torch.is_grad_enabled():
# 训练模式:保留梯度计算
return torch.mm(x, self.weight)
else:
# eval 模式:不需要梯度但可以使用优化
with torch.no_grad():
return torch.mm(x, self.weight)
layer = OptimizedLayer()
x = torch.randn(5, 10)
print("=== 训练模式 ===")
with torch.enable_grad():
_ = layer(x)
print("n=== 推理模式 ===")
with torch.inference_mode():
_ = layer(x)
print("n=== eval 模式 ===")
with torch.no_grad():
_ = layer(x)
import torch.nn as nn
class OptimizedLayer(nn.Module):
def __init__(self):
super().__init__()
self.weight = nn.Parameter(torch.randn(20, 10))
def forward(self, x):
# 检测推理模式并应用优化
if torch.is_inference_mode_enabled():
# 推理模式:使用更高效的计算方式
return torch.mm(x, self.weight)
elif torch.is_grad_enabled():
# 训练模式:保留梯度计算
return torch.mm(x, self.weight)
else:
# eval 模式:不需要梯度但可以使用优化
with torch.no_grad():
return torch.mm(x, self.weight)
layer = OptimizedLayer()
x = torch.randn(5, 10)
print("=== 训练模式 ===")
with torch.enable_grad():
_ = layer(x)
print("n=== 推理模式 ===")
with torch.inference_mode():
_ = layer(x)
print("n=== eval 模式 ===")
with torch.no_grad():
_ = layer(x)
输出结果为:
=== 训练模式 === 训练模式 === 推理模式 === 推理模式:使用更高效的计算方式 === eval 模式 === eval 模式
相关函数
torch.inference_mode(): 启用推理模式的上下文管理器。torch.no_grad(): 禁用梯度计算的上下文管理器。torch.is_grad_enabled(): 检查是否启用梯度计算。
注意事项
is_inference_mode_enabled是一个只读函数,不会改变任何状态。- 它专门用于检测
inference_mode,而不是no_grad。 - 在
inference_mode中时,is_grad_enabled也返回False,但反之不成立。 - 编写通用代码时,可以使用此函数来区分不同的运行模式。

Pytorch torch 参考手册