现在位置: 首页 > PyTorch 教程 > 正文

PyTorch torch.movedim 函数


Pytorch torch 参考手册 Pytorch torch 参考手册

torch.movedim 是 PyTorch 中用于移动张量维度的函数。它将指定维度移动到新的位置,返回一个新的张量视图。

这个函数在调整张量形状以便进行特定操作时非常有用,例如在处理图像数据或准备神经网络输入时。

函数定义

torch.movedim(input, source, destination)

参数:

  • input (Tensor): 输入张量。
  • source (int 或 tuple of int): 要移动的原始维度索引。可以是一个整数或维度索引的元组。
  • destination (int 或 tuple of int): 目标位置索引。可以是一个整数或维度索引的元组,长度必须与 source 相同。

返回值:

  • torch.Tensor: 返回维度移动后的张量视图。

使用示例

实例

import torch

# 创建一个形状为 (batch, channel, height, width) 的张量
x = torch.randn(32, 3, 224, 224)

# 将通道维度移动到最后一个维度
# 从位置1移动到位置3
y = torch.movedim(x, 1, 3)

print("原始形状:", x.shape)
print("移动后形状:", y.shape)

输出结果为:

原始形状: torch.Size([32, 3, 224, 224])
移动后形状: torch.Size([32, 224, 224, 3])

实例

import torch

# 创建一个形状为 (D0, D1, D2, D3) 的张量
x = torch.randn(2, 3, 4, 5)

# 将多个维度一起移动
# 将维度0和1移动到维度2和3的位置
y = torch.movedim(x, source=(0, 1), destination=(2, 3))

print("原始形状:", x.shape)
print("移动后形状:", y.shape)

输出结果为:

原始形状: torch.Size([2, 3, 4, 5])
移动后形状: torch.Size([4, 5, 2, 3])

实例

import torch

# 将图像张量从 (N, C, H, W) 转换为 (N, H, W, C)
# 这在将数据发送到某些需要通道在最后位置的API时很有用
images = torch.randn(16, 3, 64, 64)
images_permuted = torch.movedim(images, 1, 3)

print("原始形状 (N, C, H, W):", images.shape)
print("转换后形状 (N, H, W, C):", images_permuted.shape)

注意:torch.movedim 返回的是视图,而不是副本,因此操作是高效的。类似的函数还有 torch.moveaxis,它与 torch.movedim 功能相同,但参数名称不同。


Pytorch torch 参考手册 Pytorch torch 参考手册