现在位置: 首页 > PyTorch 教程 > 正文

PyTorch torch.cross 函数


Pytorch torch 参考手册 Pytorch torch 参考手册

torch.cross 是 PyTorch 中用于计算两个 3 维向量(或批量的 3 维向量)叉积的函数。叉积产生一个垂直于两个输入向量的新向量。

函数定义

torch.cross(input, other, dim=-1)

使用示例

实例

import torch

# 两个 3 维向量的叉积
a = torch.tensor([1, 0, 0])
b = torch.tensor([0, 1, 0])
c = torch.cross(a, b)
print("a:", a)
print("b:", b)
print("a x b:", c)
# 输出: tensor([0, 0, 1])

# 批量计算叉积
a = torch.tensor([[1, 0, 0], [0, 1, 0]])
b = torch.tensor([[0, 1, 0], [1, 0, 0]])
result = torch.cross(a, b)
print("批量叉积:")
print(result)
# tensor([[0, 0, 1],
#         [0, 0, -1]])

# 指定维度
a = torch.randn(3, 4, 3)
b = torch.randn(3, 4, 3)
result = torch.cross(a, b, dim=2)
print("指定维度叉积形状:", result.shape)

Pytorch torch 参考手册 Pytorch torch 参考手册