PyTorch torch.narrow 函数
torch.narrow 是 PyTorch 中用于返回张量切片的函数。它返回指定维度上从起始位置到指定长度的切片视图。
函数定义
torch.narrow(input, dim, start, length)
参数:
input(Tensor): 输入张量。dim(int): 要进行切片的维度。start(int): 起始索引。length(int): 切片的长度。
返回值:
torch.Tensor: 返回指定切片后的张量视图。
使用示例
实例
import torch
# 创建一个张量
x = torch.tensor([[1, 2, 3, 4],
[5, 6, 7, 8],
[9, 10, 11, 12]])
# 在第一维(行)上,从索引0开始取2行
y = torch.narrow(x, dim=0, start=0, length=2)
print("原始张量:")
print(x)
print("n切片结果:")
print(y)
# 创建一个张量
x = torch.tensor([[1, 2, 3, 4],
[5, 6, 7, 8],
[9, 10, 11, 12]])
# 在第一维(行)上,从索引0开始取2行
y = torch.narrow(x, dim=0, start=0, length=2)
print("原始张量:")
print(x)
print("n切片结果:")
print(y)
输出结果为:
原始张量:
tensor([[ 1, 2, 3, 4],
[ 5, 6, 7, 8],
[ 9, 10, 11, 12]])
切片结果:
tensor([[1, 2, 3, 4],
[5, 6, 7, 8]])
实例
import torch
# 创建一个张量
x = torch.tensor([[1, 2, 3, 4],
[5, 6, 7, 8],
[9, 10, 11, 12]])
# 在第二维(列)上,从索引1开始取2列
y = torch.narrow(x, dim=1, start=1, length=2)
print("原始张量:")
print(x)
print("n切片结果:")
print(y)
# 创建一个张量
x = torch.tensor([[1, 2, 3, 4],
[5, 6, 7, 8],
[9, 10, 11, 12]])
# 在第二维(列)上,从索引1开始取2列
y = torch.narrow(x, dim=1, start=1, length=2)
print("原始张量:")
print(x)
print("n切片结果:")
print(y)
输出结果为:
原始张量:
tensor([[ 1, 2, 3, 4],
[ 5, 6, 7, 8],
[ 9, 10, 11, 12]])
切片结果:
tensor([[ 2, 3],
[ 6, 7],
[10, 11]])
实例
import torch
# 创建一个3D张量
x = torch.randn(5, 6, 7)
# 在第一维上,从索引2开始取3个元素
y = torch.narrow(x, dim=0, start=2, length=3)
print("原始形状:", x.shape)
print("切片后形状:", y.shape)
# 创建一个3D张量
x = torch.randn(5, 6, 7)
# 在第一维上,从索引2开始取3个元素
y = torch.narrow(x, dim=0, start=2, length=3)
print("原始形状:", x.shape)
print("切片后形状:", y.shape)
输出结果为:
原始形状: torch.Size([5, 6, 7]) 切片后形状: torch.Size([3, 6, 7])
注意:torch.narrow 返回的是原始张量的视图,不是副本。如果需要副本,可以使用 torch.narrow_copy。

Pytorch torch 参考手册