PyTorch torch.moveaxis 函数
torch.moveaxis 是 PyTorch 中用于移动张量轴的函数。它将指定轴移动到新的位置,返回一个新的张量视图。
这个函数与 torch.movedim 功能相同,只是参数名称不同(moveaxis 使用 "axis" 而 movedim 使用 "dim")。
函数定义
torch.moveaxis(input, source, destination)
参数:
input(Tensor): 输入张量。source(int 或 tuple of int): 要移动的原始轴索引。可以是一个整数或轴索引的元组。destination(int 或 tuple of int): 目标位置索引。可以是一个整数或轴索引的元组,长度必须与 source 相同。
返回值:
torch.Tensor: 返回轴移动后的张量视图。
使用示例
实例
import torch
# 创建一个形状为 (batch, seq_len, feature) 的张量
x = torch.randn(32, 10, 128)
# 将最后一个维度移动到第一个位置
y = torch.moveaxis(x, -1, 0)
print("原始形状:", x.shape)
print("移动后形状:", y.shape)
# 创建一个形状为 (batch, seq_len, feature) 的张量
x = torch.randn(32, 10, 128)
# 将最后一个维度移动到第一个位置
y = torch.moveaxis(x, -1, 0)
print("原始形状:", x.shape)
print("移动后形状:", y.shape)
输出结果为:
原始形状: torch.Size([32, 10, 128]) 移动后形状: torch.Size([128, 10, 32])
实例
import torch
# 创建一个四维张量 (time, channel, height, width)
x = torch.randn(10, 3, 32, 32)
# 将 time 维度移动到最后
y = torch.moveaxis(x, 0, -1)
print("原始形状:", x.shape)
print("移动后形状:", y.shape)
# 创建一个四维张量 (time, channel, height, width)
x = torch.randn(10, 3, 32, 32)
# 将 time 维度移动到最后
y = torch.moveaxis(x, 0, -1)
print("原始形状:", x.shape)
print("移动后形状:", y.shape)
输出结果为:
原始形状: torch.Size([10, 3, 32, 32]) 移动后形状: torch.Size([3, 32, 32, 10])
实例
import torch
# 多个轴一起移动
x = torch.randn(2, 3, 4, 5)
# 将轴0和1一起移动到轴2和3的位置
y = torch.moveaxis(x, source=(0, 1), destination=(2, 3))
print("原始形状:", x.shape)
print("移动后形状:", y.shape)
# 多个轴一起移动
x = torch.randn(2, 3, 4, 5)
# 将轴0和1一起移动到轴2和3的位置
y = torch.moveaxis(x, source=(0, 1), destination=(2, 3))
print("原始形状:", x.shape)
print("移动后形状:", y.shape)
输出结果为:
原始形状: torch.Size([2, 3, 4, 5]) 移动后形状: torch.Size([4, 5, 2, 3])
注意:torch.moveaxis 返回的是视图,而不是副本,因此操作是高效的。该函数常用于调整数据形状以适应不同的深度学习框架或API。

Pytorch torch 参考手册