PyTorch torch.set_float32_matmul_precision 函数
torch.set_float32_matmul_precision 是 PyTorch 中用于设置 float32 矩阵乘法精度的函数。可以选择使用 lower precision 来提高性能,或使用 higher precision 来提高准确性。
函数定义
torch.set_float32_matmul_precision(precision)
参数说明
precision: 精度级别,可选值:"highest": 最高精度(默认)"high": 高精度"medium": 中等精度(使用 TensorFloat-32)
使用示例
实例
import torch
# 设置为中等精度(使用 TensorFloat-32 加速)
torch.set_float32_matmul_precision("medium")
# 创建矩阵进行测试
a = torch.randn(100, 100)
b = torch.randn(100, 100)
# 矩阵乘法
c = torch.matmul(a, b)
print("使用中等精度进行矩阵乘法")
print("结果形状:", c.shape)
# 恢复最高精度
torch.set_float32_matmul_precision("highest")
# 设置为中等精度(使用 TensorFloat-32 加速)
torch.set_float32_matmul_precision("medium")
# 创建矩阵进行测试
a = torch.randn(100, 100)
b = torch.randn(100, 100)
# 矩阵乘法
c = torch.matmul(a, b)
print("使用中等精度进行矩阵乘法")
print("结果形状:", c.shape)
# 恢复最高精度
torch.set_float32_matmul_precision("highest")
输出结果为:
使用中等精度进行矩阵乘法 结果形状: torch.Size([100, 100])

Pytorch torch 参考手册