PyTorch torch.narrow_copy 函数
torch.narrow_copy 是 PyTorch 中用于返回张量切片副本的函数。它返回指定维度上从起始位置到指定长度的切片的副本。
该函数与 torch.narrow 的功能类似,但 torch.narrow 返回视图,而 torch.narrow_copy 返回副本。
函数定义
torch.narrow_copy(input, dim, start, length)
参数:
input(Tensor): 输入张量。dim(int): 要进行切片的维度。start(int): 起始索引。length(int): 切片的长度。
返回值:
torch.Tensor: 返回指定切片后的张量副本。
使用示例
实例
import torch
# 创建一个张量
x = torch.tensor([[1, 2, 3, 4],
[5, 6, 7, 8],
[9, 10, 11, 12]])
# 在第一维(行)上,从索引0开始取2行
y = torch.narrow_copy(x, dim=0, start=0, length=2)
print("原始张量:")
print(x)
print("n切片副本:")
print(y)
# 修改副本不会影响原始张量
y[0, 0] = 100
print("n修改副本后,原始张量:", x[0, 0])
print("修改副本后,切片副本:", y[0, 0])
# 创建一个张量
x = torch.tensor([[1, 2, 3, 4],
[5, 6, 7, 8],
[9, 10, 11, 12]])
# 在第一维(行)上,从索引0开始取2行
y = torch.narrow_copy(x, dim=0, start=0, length=2)
print("原始张量:")
print(x)
print("n切片副本:")
print(y)
# 修改副本不会影响原始张量
y[0, 0] = 100
print("n修改副本后,原始张量:", x[0, 0])
print("修改副本后,切片副本:", y[0, 0])
输出结果为:
原始张量:
tensor([[ 1, 2, 3, 4],
[ 5, 6, 7, 8],
[ 9, 10, 11, 12]])
切片副本:
tensor([[1, 2, 3, 4],
[5, 6, 7, 8]])
修改副本后,原始张量: tensor(1)
修改副本后,切片副本: tensor(100)
实例
import torch
# 创建一个3D张量
x = torch.randn(5, 6, 7)
# 在第二维上,从索引2开始取4个元素
y = torch.narrow_copy(x, dim=1, start=2, length=4)
print("原始形状:", x.shape)
print("切片副本形状:", y.shape)
# 创建一个3D张量
x = torch.randn(5, 6, 7)
# 在第二维上,从索引2开始取4个元素
y = torch.narrow_copy(x, dim=1, start=2, length=4)
print("原始形状:", x.shape)
print("切片副本形状:", y.shape)
输出结果为:
原始形状: torch.Size([5, 6, 7]) 切片副本形状: torch.Size([5, 4, 7])
注意:torch.narrow_copy 返回的是副本,而不是视图。这意味着修改返回的张量不会影响原始张量,但会占用额外的内存。如果不需要独立的副本,可以使用 torch.narrow 以节省内存。

Pytorch torch 参考手册