PyTorch torch.sum 函数
torch.sum 是 PyTorch 中用于计算张量元素之和的函数。它可以计算所有元素的和,也可以沿指定维度计算。
这是深度学习中常用的归约操作,用于损失计算、统计等场景。
函数定义
torch.sum(input, dim, keepdim, dtype, out)
参数:
input(Tensor): 输入张量。dim(int 或 tuple of int, 可选): 要计算的维度。如果为None,则计算所有元素的和。keepdim(bool, 可选): 是否保持维度。默认为False。dtype(torch.dtype, 可选): 输出张量的数据类型。out(Tensor, 可选): 输出张量。
返回值:
torch.Tensor: 返回计算后的张量。
使用示例
示例 1: 计算所有元素之和
实例
import torch
# 创建张量
x = torch.tensor([[1, 2, 3], [4, 5, 6]])
# 计算所有元素之和
total = torch.sum(x)
print("张量:")
print(x)
print("元素之和:", total)
# 创建张量
x = torch.tensor([[1, 2, 3], [4, 5, 6]])
# 计算所有元素之和
total = torch.sum(x)
print("张量:")
print(x)
print("元素之和:", total)
输出结果为:
张量:
tensor([[1, 2, 3],
[4, 5, 6]])
元素之和: tensor(21)
示例 2: 沿指定维度求和
实例
import torch
x = torch.tensor([[1, 2, 3], [4, 5, 6]])
# 沿 dim=0(列)求和
sum_dim0 = torch.sum(x, dim=0)
print("沿 dim=0 求和:", sum_dim0)
# 沿 dim=1(行)求和
sum_dim1 = torch.sum(x, dim=1)
print("沿 dim=1 求和:", sum_dim1)
x = torch.tensor([[1, 2, 3], [4, 5, 6]])
# 沿 dim=0(列)求和
sum_dim0 = torch.sum(x, dim=0)
print("沿 dim=0 求和:", sum_dim0)
# 沿 dim=1(行)求和
sum_dim1 = torch.sum(x, dim=1)
print("沿 dim=1 求和:", sum_dim1)
输出结果为:
沿 dim=0 求和: tensor([5, 7, 9]) 沿 dim=1 求和: tensor([ 6, 15])
示例 3: 使用 keepdim 保持维度
实例
import torch
x = torch.tensor([[1, 2, 3], [4, 5, 6]])
# 不保持维度
sum1 = torch.sum(x, dim=0)
print("不保持维度:", sum1.shape)
# 保持维度
sum2 = torch.sum(x, dim=0, keepdim=True)
print("保持维度:", sum2.shape)
print(sum2)
x = torch.tensor([[1, 2, 3], [4, 5, 6]])
# 不保持维度
sum1 = torch.sum(x, dim=0)
print("不保持维度:", sum1.shape)
# 保持维度
sum2 = torch.sum(x, dim=0, keepdim=True)
print("保持维度:", sum2.shape)
print(sum2)
输出结果为:
不保持维度: torch.Size([3]) 保持维度: torch.Size([1, 3]) tensor([[5, 7, 9]])
示例 4: 在神经网络中计算损失
实例
import torch
# 模拟预测值和真实值
predictions = torch.tensor([0.1, 0.9, 0.8, 0.3])
targets = torch.tensor([0.0, 1.0, 1.0, 0.0])
# 计算均方误差损失
loss = torch.sum((predictions - targets) ** 2) / len(predictions)
print("MSE 损失:", loss.item())
# 模拟预测值和真实值
predictions = torch.tensor([0.1, 0.9, 0.8, 0.3])
targets = torch.tensor([0.0, 1.0, 1.0, 0.0])
# 计算均方误差损失
loss = torch.sum((predictions - targets) ** 2) / len(predictions)
print("MSE 损失:", loss.item())
输出结果为:
MSE 损失: 0.07499999690771103

Pytorch torch 参考手册