PyTorch torch.flip 函数
torch.flip 是 PyTorch 中用于沿指定维度翻转张量的函数。
函数定义
torch.flip(input, dims)
使用示例
实例
import torch
x = torch.arange(8).reshape(2, 4)
print("原始:")
print(x)
print("水平翻转:")
print(torch.flip(x, dims=[1]))
print("垂直翻转:")
print(torch.flip(x, dims=[0]))
x = torch.arange(8).reshape(2, 4)
print("原始:")
print(x)
print("水平翻转:")
print(torch.flip(x, dims=[1]))
print("垂直翻转:")
print(torch.flip(x, dims=[0]))
输出结果为:
原始:
tensor([[0, 1, 2, 3],
[4, 5, 6, 7]])
水平翻转:
tensor([[3, 2, 1, 0],
[7, 6, 5, 4]])
垂直翻转:
tensor([[4, 5, 6, 7],
[0, 1, 2, 3]])

Pytorch torch 参考手册