现在位置: 首页 > PyTorch 教程 > 正文

PyTorch torch.tensor 函数


Pytorch torch 参考手册 Pytorch torch 参考手册

torch.tensor 是 PyTorch 中用于创建张量的核心函数。它通过复制数据创建一个新的张量,并且不会保留原始数据的自动梯度历史。

这是创建 PyTorch 张量最常用的方式之一,适用于大多数场景。

函数定义

torch.tensor(data, dtype=None, device=None, requires_grad=False, pin_memory=False)

参数:

  • data (任意类型): 张量的数据,可以是 Python 列表、元组、NumPy 数组等。
  • dtype (torch.dtype, 可选): 指定张量的数据类型,如 torch.float32torch.int64 等。
  • device (torch.device, 可选): 指定张量存储的设备,如 torch.device('cpu')torch.device('cuda')
  • requires_grad (bool, 可选): 是否需要计算梯度,默认为 False
  • pin_memory (bool, 可选): 是否使用锁页内存,默认为 False

返回值:

  • torch.Tensor: 返回一个 PyTorch 张量。

使用示例

以下是一些使用 torch.tensor 函数的示例。

示例 1: 从列表创建张量

实例

import torch

# 从 Python 列表创建一维张量
data = [1, 2, 3, 4, 5]
x = torch.tensor(data)

print(x)
print(x.dtype)

输出结果为:

tensor([1, 2, 3, 4, 5])
torch.int64

在这个示例中,我们从 Python 列表创建了一个一维张量。PyTorch 会自动推断数据类型为 int64

示例 2: 指定数据类型

实例

import torch

# 创建 float32 类型的张量
x = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float32)

print(x)
print(x.dtype)

输出结果为:

tensor([1., 2., 3.])
torch.float32

在这个示例中,我们显式指定了张量的数据类型为 float32

示例 3: 创建需要梯度的张量

实例

import torch

# 创建一个需要计算梯度的张量
x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)

print(x.requires_grad)

输出结果为:

True

在这个示例中,我们创建了一个需要计算梯度的张量,这在训练神经网络时非常有用。

示例 4: 从二维列表创建张量

实例

import torch

# 从二维列表创建二维张量(矩阵)
data = [[1, 2, 3], [4, 5, 6]]
x = torch.tensor(data)

print(x)
print(x.shape)

输出结果为:

tensor([[1, 2, 3],
        [4, 5, 6]])
torch.Size([2, 3])

在这个示例中,我们从二维列表创建了一个 2x3 的二维张量。

示例 5: 在 CUDA 设备上创建张量

实例

import torch

# 检查 CUDA 是否可用
if torch.cuda.is_available():
    # 在 CUDA 设备上创建张量
    x = torch.tensor([1, 2, 3], device='cuda')
    print(x.device)
else:
    print("CUDA 不可用")

输出结果为:

cuda:0

在这个示例中,我们检查 CUDA 是否可用,然后在 GPU 上创建张量。


torch.tensor 与 torch.as_tensor 的区别

torch.tensortorch.as_tensor 都用于创建张量,但它们有重要的区别:

  • torch.tensor: 总是复制数据,创建的张量与原始数据不共享内存。
  • torch.as_tensor: 尽可能共享数据,不复制内存,只在必要时复制。

如果需要保留自动梯度历史或避免不必要的数据复制,可以使用 torch.as_tensor


注意事项

  • torch.tensor 会复制数据,因此对返回的张量的修改不会影响原始数据。
  • 如果需要创建与另一个张量相同形状和设备的新张量,可以使用 torch.zeros_like()torch.ones_like()
  • 创建张量时,如果不指定 dtype,PyTorch 会自动推断数据类型。

Pytorch torch 参考手册 Pytorch torch 参考手册