PyTorch torch.promote_types 函数
torch.promote_types 是 PyTorch 中用于类型提升的函数。它接受多个数据类型作为输入,返回能够容纳所有输入类型的公共类型(最合适的类型)。
函数定义
torch.promote_types(type1, type2)
参数说明
type1: 第一个数据类型type2: 第二个数据类型
使用示例
实例
import torch
# 类型提升示例
dtype1 = torch.float32
dtype2 = torch.float64
# 获取公共类型
result_dtype = torch.promote_types(dtype1, dtype2)
print("float32 和 float64 的公共类型:", result_dtype)
# 整数类型和浮点数类型
dtype3 = torch.int32
dtype4 = torch.float32
result_dtype2 = torch.promote_types(dtype3, dtype4)
print("int32 和 float32 的公共类型:", result_dtype2)
# 类型提升示例
dtype1 = torch.float32
dtype2 = torch.float64
# 获取公共类型
result_dtype = torch.promote_types(dtype1, dtype2)
print("float32 和 float64 的公共类型:", result_dtype)
# 整数类型和浮点数类型
dtype3 = torch.int32
dtype4 = torch.float32
result_dtype2 = torch.promote_types(dtype3, dtype4)
print("int32 和 float32 的公共类型:", result_dtype2)
输出结果为:
float32 和 float64 的公共类型: torch.float64 int32 和 float32 的公共类型: torch.float32

Pytorch torch 参考手册