现在位置: 首页 > PyTorch 教程 > 正文

PyTorch torch.set_float32_matmul_precision 函数


Pytorch torch 参考手册 Pytorch torch 参考手册

torch.set_float32_matmul_precision 是 PyTorch 中用于设置 float32 矩阵乘法精度的函数。可以选择使用 lower precision 来提高性能,或使用 higher precision 来提高准确性。

函数定义

torch.set_float32_matmul_precision(precision)

参数说明

  • precision: 精度级别,可选值:
    • "highest": 最高精度(默认)
    • "high": 高精度
    • "medium": 中等精度(使用 TensorFloat-32)

使用示例

实例

import torch

# 设置为中等精度(使用 TensorFloat-32 加速)
torch.set_float32_matmul_precision("medium")

# 创建矩阵进行测试
a = torch.randn(100, 100)
b = torch.randn(100, 100)

# 矩阵乘法
c = torch.matmul(a, b)

print("使用中等精度进行矩阵乘法")
print("结果形状:", c.shape)

# 恢复最高精度
torch.set_float32_matmul_precision("highest")

输出结果为:

使用中等精度进行矩阵乘法
结果形状: torch.Size([100, 100])

Pytorch torch 参考手册 Pytorch torch 参考手册