现在位置: 首页 > PyTorch 教程 > 正文

PyTorch torch.linspace 函数


Pytorch torch 参考手册 Pytorch torch 参考手册

torch.linspace 是 PyTorch 中用于创建等间隔序列张量的函数。它会创建一个一维张量,包含从起始值到结束值的等间隔数列。

torch.arange 不同,torch.linspace 指定的是元素数量而不是步长,这在需要精确控制元素个数时非常有用。

函数定义

torch.linspace(start, end, steps, dtype=None, device=None, requires_grad=False)

参数:

  • start (float): 序列的起始值。
  • end (float): 序列的结束值(包含)。
  • steps (int): 序列中元素的个数,必须为正整数。
  • dtype (torch.dtype, 可选): 指定张量的数据类型。
  • device (torch.device, 可选): 指定张量存储的设备。
  • requires_grad (bool, 可选): 是否需要计算梯度。

返回值:

  • torch.Tensor: 返回一个一维张量。

使用示例

示例 1: 创建 5 个等间隔点

实例

import torch

# 创建从 0 到 10,共 5 个点的等间隔序列
x = torch.linspace(0, 10, 5)

print(x)

输出结果为:

tensor([ 0.0000,  2.5000,  5.0000,  7.5000, 10.0000])

示例 2: 创建 10 个点

实例

import torch

# 创建从 -1 到 1,共 10 个点的等间隔序列
x = torch.linspace(-1, 1, 10)

print(x)

输出结果为:

tensor([-1.0000, -0.7778, -0.5556, -0.3333, -0.1111,  0.1111,  0.3333,  0.5556,
         0.7778,  1.0000])

示例 3: 用于神经网络学习率调度

实例

import torch

# 模拟学习率从 0.1 逐渐降低到 0.001
learning_rates = torch.linspace(0.1, 0.001, 100)

print("初始学习率:", learning_rates[0].item())
print("最终学习率:", learning_rates[-1].item())

输出结果为:

初始学习率: 0.10000000149011612
最终学习率: 0.0010000000474974513

在这个示例中,我们创建了一个包含 100 个学习率值的序列,常用于学习率调度。


torch.arange 与 torch.linspace 的区别

  • torch.arange(start, end, step): 根据步长创建序列,元素个数由范围和步长决定。
  • torch.linspace(start, end, steps): 根据元素个数创建序列,间隔由范围和元素个数决定。

Pytorch torch 参考手册 Pytorch torch 参考手册