PyTorch torch.cross 函数
torch.cross 是 PyTorch 中用于计算两个 3 维向量(或批量的 3 维向量)叉积的函数。叉积产生一个垂直于两个输入向量的新向量。
函数定义
torch.cross(input, other, dim=-1)
使用示例
实例
import torch
# 两个 3 维向量的叉积
a = torch.tensor([1, 0, 0])
b = torch.tensor([0, 1, 0])
c = torch.cross(a, b)
print("a:", a)
print("b:", b)
print("a x b:", c)
# 输出: tensor([0, 0, 1])
# 批量计算叉积
a = torch.tensor([[1, 0, 0], [0, 1, 0]])
b = torch.tensor([[0, 1, 0], [1, 0, 0]])
result = torch.cross(a, b)
print("批量叉积:")
print(result)
# tensor([[0, 0, 1],
# [0, 0, -1]])
# 指定维度
a = torch.randn(3, 4, 3)
b = torch.randn(3, 4, 3)
result = torch.cross(a, b, dim=2)
print("指定维度叉积形状:", result.shape)
# 两个 3 维向量的叉积
a = torch.tensor([1, 0, 0])
b = torch.tensor([0, 1, 0])
c = torch.cross(a, b)
print("a:", a)
print("b:", b)
print("a x b:", c)
# 输出: tensor([0, 0, 1])
# 批量计算叉积
a = torch.tensor([[1, 0, 0], [0, 1, 0]])
b = torch.tensor([[0, 1, 0], [1, 0, 0]])
result = torch.cross(a, b)
print("批量叉积:")
print(result)
# tensor([[0, 0, 1],
# [0, 0, -1]])
# 指定维度
a = torch.randn(3, 4, 3)
b = torch.randn(3, 4, 3)
result = torch.cross(a, b, dim=2)
print("指定维度叉积形状:", result.shape)

Pytorch torch 参考手册