PyTorch torch.dtype 函数
torch.dtype 是 PyTorch 中用于表示张量数据类型的对象,不是函数。它用于获取或指定张量的数据类型。
数据类型说明
torch.float32 - 32位浮点数 torch.float64 - 64位浮点数(双精度) torch.float16 - 16位浮点数(半精度) torch.bfloat16 - 16位浮点数(Brain Float) torch.complex32 - 32位复数 torch.complex64 - 64位复数 torch.complex128 - 128位复数 torch.int8 - 8位整数 torch.int16 - 16位整数 torch.int32 - 32位整数 torch.int64 - 64位整数 torch.bool - 布尔类型 torch.uint8 - 无符号8位整数
使用示例
实例
import torch
# 创建不同数据类型的张量
x1 = torch.tensor([1.0, 2.0, 3.0])
x2 = torch.tensor([1, 2, 3], dtype=torch.int32)
x3 = torch.tensor([True, False, True], dtype=torch.bool)
# 获取张量的数据类型
print("x1 dtype:", x1.dtype)
print("x2 dtype:", x2.dtype)
print("x3 dtype:", x3.dtype)
# 创建不同数据类型的张量
x1 = torch.tensor([1.0, 2.0, 3.0])
x2 = torch.tensor([1, 2, 3], dtype=torch.int32)
x3 = torch.tensor([True, False, True], dtype=torch.bool)
# 获取张量的数据类型
print("x1 dtype:", x1.dtype)
print("x2 dtype:", x2.dtype)
print("x3 dtype:", x3.dtype)

Pytorch torch 参考手册