现在位置: 首页 > PyTorch 教程 > 正文

PyTorch torch.narrow_copy 函数


Pytorch torch 参考手册 Pytorch torch 参考手册

torch.narrow_copy 是 PyTorch 中用于返回张量切片副本的函数。它返回指定维度上从起始位置到指定长度的切片的副本。

该函数与 torch.narrow 的功能类似,但 torch.narrow 返回视图,而 torch.narrow_copy 返回副本。

函数定义

torch.narrow_copy(input, dim, start, length)

参数:

  • input (Tensor): 输入张量。
  • dim (int): 要进行切片的维度。
  • start (int): 起始索引。
  • length (int): 切片的长度。

返回值:

  • torch.Tensor: 返回指定切片后的张量副本。

使用示例

实例

import torch

# 创建一个张量
x = torch.tensor([[1, 2, 3, 4],
                  [5, 6, 7, 8],
                  [9, 10, 11, 12]])

# 在第一维(行)上,从索引0开始取2行
y = torch.narrow_copy(x, dim=0, start=0, length=2)

print("原始张量:")
print(x)
print("n切片副本:")
print(y)

# 修改副本不会影响原始张量
y[0, 0] = 100
print("n修改副本后,原始张量:", x[0, 0])
print("修改副本后,切片副本:", y[0, 0])

输出结果为:

原始张量:
tensor([[ 1,  2,  3,  4],
        [ 5,  6,  7,  8],
        [ 9, 10, 11, 12]])

切片副本:
tensor([[1, 2, 3, 4],
        [5, 6, 7, 8]])

修改副本后,原始张量: tensor(1)
修改副本后,切片副本: tensor(100)

实例

import torch

# 创建一个3D张量
x = torch.randn(5, 6, 7)

# 在第二维上,从索引2开始取4个元素
y = torch.narrow_copy(x, dim=1, start=2, length=4)

print("原始形状:", x.shape)
print("切片副本形状:", y.shape)

输出结果为:

原始形状: torch.Size([5, 6, 7])
切片副本形状: torch.Size([5, 4, 7])

注意:torch.narrow_copy 返回的是副本,而不是视图。这意味着修改返回的张量不会影响原始张量,但会占用额外的内存。如果不需要独立的副本,可以使用 torch.narrow 以节省内存。


Pytorch torch 参考手册 Pytorch torch 参考手册