PyTorch torch.eye 函数
torch.eye 是 PyTorch 中用于创建单位矩阵(对角线上元素为 1,其余为 0 的矩阵)的函数。
这在深度学习中常用于初始化、创建掩码等场景。
函数定义
torch.eye(n, m, dtype, device, requires_grad)
参数:
n(int): 行数。m(int, 可选): 列数。如果未指定,则创建 n×n 的方阵。dtype(torch.dtype, 可选): 数据类型。device(torch.device, 可选): 设备。requires_grad(bool, 可选): 是否需要计算梯度。
返回值:
torch.Tensor: 返回单位矩阵。
使用示例
示例 1: 创建方阵
实例
import torch
# 创建 3x3 单位矩阵
I = torch.eye(3)
print(I)
# 创建 3x3 单位矩阵
I = torch.eye(3)
print(I)
输出结果为:
tensor([[1., 0., 0.],
[0., 1., 0.],
[0., 0., 1.]])
示例 2: 创建非方阵
实例
import torch
# 创建 3x4 的单位矩阵
I = torch.eye(3, 4)
print(I)
# 创建 3x4 的单位矩阵
I = torch.eye(3, 4)
print(I)
输出结果为:
tensor([[1., 0., 0., 0.],
[0., 1., 0., 0.],
[0., 0., 1., 0.]])

Pytorch torch 参考手册