现在位置: 首页 > PyTorch 教程 > 正文

PyTorch torch.moveaxis 函数


Pytorch torch 参考手册 Pytorch torch 参考手册

torch.moveaxis 是 PyTorch 中用于移动张量轴的函数。它将指定轴移动到新的位置,返回一个新的张量视图。

这个函数与 torch.movedim 功能相同,只是参数名称不同(moveaxis 使用 "axis" 而 movedim 使用 "dim")。

函数定义

torch.moveaxis(input, source, destination)

参数:

  • input (Tensor): 输入张量。
  • source (int 或 tuple of int): 要移动的原始轴索引。可以是一个整数或轴索引的元组。
  • destination (int 或 tuple of int): 目标位置索引。可以是一个整数或轴索引的元组,长度必须与 source 相同。

返回值:

  • torch.Tensor: 返回轴移动后的张量视图。

使用示例

实例

import torch

# 创建一个形状为 (batch, seq_len, feature) 的张量
x = torch.randn(32, 10, 128)

# 将最后一个维度移动到第一个位置
y = torch.moveaxis(x, -1, 0)

print("原始形状:", x.shape)
print("移动后形状:", y.shape)

输出结果为:

原始形状: torch.Size([32, 10, 128])
移动后形状: torch.Size([128, 10, 32])

实例

import torch

# 创建一个四维张量 (time, channel, height, width)
x = torch.randn(10, 3, 32, 32)

# 将 time 维度移动到最后
y = torch.moveaxis(x, 0, -1)

print("原始形状:", x.shape)
print("移动后形状:", y.shape)

输出结果为:

原始形状: torch.Size([10, 3, 32, 32])
移动后形状: torch.Size([3, 32, 32, 10])

实例

import torch

# 多个轴一起移动
x = torch.randn(2, 3, 4, 5)

# 将轴0和1一起移动到轴2和3的位置
y = torch.moveaxis(x, source=(0, 1), destination=(2, 3))

print("原始形状:", x.shape)
print("移动后形状:", y.shape)

输出结果为:

原始形状: torch.Size([2, 3, 4, 5])
移动后形状: torch.Size([4, 5, 2, 3])

注意:torch.moveaxis 返回的是视图,而不是副本,因此操作是高效的。该函数常用于调整数据形状以适应不同的深度学习框架或API。


Pytorch torch 参考手册 Pytorch torch 参考手册