现在位置: 首页 > PyTorch 教程 > 正文

PyTorch torch.narrow 函数


Pytorch torch 参考手册 Pytorch torch 参考手册

torch.narrow 是 PyTorch 中用于返回张量切片的函数。它返回指定维度上从起始位置到指定长度的切片视图。

函数定义

torch.narrow(input, dim, start, length)

参数:

  • input (Tensor): 输入张量。
  • dim (int): 要进行切片的维度。
  • start (int): 起始索引。
  • length (int): 切片的长度。

返回值:

  • torch.Tensor: 返回指定切片后的张量视图。

使用示例

实例

import torch

# 创建一个张量
x = torch.tensor([[1, 2, 3, 4],
                  [5, 6, 7, 8],
                  [9, 10, 11, 12]])

# 在第一维(行)上,从索引0开始取2行
y = torch.narrow(x, dim=0, start=0, length=2)

print("原始张量:")
print(x)
print("n切片结果:")
print(y)

输出结果为:

原始张量:
tensor([[ 1,  2,  3,  4],
        [ 5,  6,  7,  8],
        [ 9, 10, 11, 12]])

切片结果:
tensor([[1, 2, 3, 4],
        [5, 6, 7, 8]])

实例

import torch

# 创建一个张量
x = torch.tensor([[1, 2, 3, 4],
                  [5, 6, 7, 8],
                  [9, 10, 11, 12]])

# 在第二维(列)上,从索引1开始取2列
y = torch.narrow(x, dim=1, start=1, length=2)

print("原始张量:")
print(x)
print("n切片结果:")
print(y)

输出结果为:

原始张量:
tensor([[ 1,  2,  3,  4],
        [ 5,  6,  7,  8],
        [ 9, 10, 11, 12]])

切片结果:
tensor([[ 2,  3],
        [ 6,  7],
        [10, 11]])

实例

import torch

# 创建一个3D张量
x = torch.randn(5, 6, 7)

# 在第一维上,从索引2开始取3个元素
y = torch.narrow(x, dim=0, start=2, length=3)

print("原始形状:", x.shape)
print("切片后形状:", y.shape)

输出结果为:

原始形状: torch.Size([5, 6, 7])
切片后形状: torch.Size([3, 6, 7])

注意:torch.narrow 返回的是原始张量的视图,不是副本。如果需要副本,可以使用 torch.narrow_copy


Pytorch torch 参考手册 Pytorch torch 参考手册