现在位置: 首页 > PyTorch 教程 > 正文

PyTorch torch.promote_types 函数


Pytorch torch 参考手册 Pytorch torch 参考手册

torch.promote_types 是 PyTorch 中用于类型提升的函数。它接受多个数据类型作为输入,返回能够容纳所有输入类型的公共类型(最合适的类型)。

函数定义

torch.promote_types(type1, type2)

参数说明

  • type1: 第一个数据类型
  • type2: 第二个数据类型

使用示例

实例

import torch

# 类型提升示例
dtype1 = torch.float32
dtype2 = torch.float64

# 获取公共类型
result_dtype = torch.promote_types(dtype1, dtype2)

print("float32 和 float64 的公共类型:", result_dtype)

# 整数类型和浮点数类型
dtype3 = torch.int32
dtype4 = torch.float32
result_dtype2 = torch.promote_types(dtype3, dtype4)

print("int32 和 float32 的公共类型:", result_dtype2)

输出结果为:

float32 和 float64 的公共类型: torch.float64
int32 和 float32 的公共类型: torch.float32

Pytorch torch 参考手册 Pytorch torch 参考手册