现在位置: 首页 > PyTorch 教程 > 正文

PyTorch torch.ones 函数


Pytorch torch 参考手册 Pytorch torch 参考手册

torch.ones 是 PyTorch 中用于创建全一张量的函数。它会创建一个指定形状的张量,所有元素都初始化为 1。

这在深度学习中常用于初始化某些特定参数、创建掩码或作为数学运算的基准值。

函数定义

torch.ones(*size, dtype=None, device=None, requires_grad=False, pin_memory=False)

参数:

  • *size (int): 张量的形状,例如 3(3, 4)(2, 3, 4) 等。
  • dtype (torch.dtype, 可选): 指定张量的数据类型,默认为 torch.float32
  • device (torch.device, 可选): 指定张量存储的设备。
  • requires_grad (bool, 可选): 是否需要计算梯度。
  • pin_memory (bool, 可选): 是否使用锁页内存。

返回值:

  • torch.Tensor: 返回一个全一张量。

使用示例

示例 1: 创建一维全一张量

实例

import torch

# 创建包含 5 个元素的全一张量
x = torch.ones(5)

print(x)

输出结果为:

tensor([1., 1., 1., 1., 1.])

示例 2: 创建二维全一张量

实例

import torch

# 创建 3x4 的全一张量(矩阵)
x = torch.ones(3, 4)

print(x)
print(x.shape)

输出结果为:

tensor([[1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.]])
torch.Size([3, 4])

示例 3: 创建三维全一张量

实例

import torch

# 创建 2x3x4 的全一张量
x = torch.ones(2, 3, 4)

print(x.shape)

输出结果为:

torch.Size([2, 3, 4])

Pytorch torch 参考手册 Pytorch torch 参考手册