现在位置: 首页 > PyTorch 教程 > 正文

PyTorch torch.dtype 函数


Pytorch torch 参考手册 Pytorch torch 参考手册

torch.dtype 是 PyTorch 中用于表示张量数据类型的对象,不是函数。它用于获取或指定张量的数据类型。

数据类型说明

torch.float32 - 32位浮点数
torch.float64 - 64位浮点数(双精度)
torch.float16 - 16位浮点数(半精度)
torch.bfloat16 - 16位浮点数(Brain Float)
torch.complex32 - 32位复数
torch.complex64 - 64位复数
torch.complex128 - 128位复数
torch.int8 - 8位整数
torch.int16 - 16位整数
torch.int32 - 32位整数
torch.int64 - 64位整数
torch.bool - 布尔类型
torch.uint8 - 无符号8位整数

使用示例

实例

import torch

# 创建不同数据类型的张量
x1 = torch.tensor([1.0, 2.0, 3.0])
x2 = torch.tensor([1, 2, 3], dtype=torch.int32)
x3 = torch.tensor([True, False, True], dtype=torch.bool)

# 获取张量的数据类型
print("x1 dtype:", x1.dtype)
print("x2 dtype:", x2.dtype)
print("x3 dtype:", x3.dtype)

Pytorch torch 参考手册 Pytorch torch 参考手册