现在位置: 首页 > PyTorch 教程 > 正文

PyTorch torch.eye 函数


Pytorch torch 参考手册 Pytorch torch 参考手册

torch.eye 是 PyTorch 中用于创建单位矩阵(对角线上元素为 1,其余为 0 的矩阵)的函数。

这在深度学习中常用于初始化、创建掩码等场景。

函数定义

torch.eye(n, m, dtype, device, requires_grad)

参数:

  • n (int): 行数。
  • m (int, 可选): 列数。如果未指定,则创建 n×n 的方阵。
  • dtype (torch.dtype, 可选): 数据类型。
  • device (torch.device, 可选): 设备。
  • requires_grad (bool, 可选): 是否需要计算梯度。

返回值:

  • torch.Tensor: 返回单位矩阵。

使用示例

示例 1: 创建方阵

实例

import torch

# 创建 3x3 单位矩阵
I = torch.eye(3)

print(I)

输出结果为:

tensor([[1., 0., 0.],
        [0., 1., 0.],
        [0., 0., 1.]])

示例 2: 创建非方阵

实例

import torch

# 创建 3x4 的单位矩阵
I = torch.eye(3, 4)

print(I)

输出结果为:

tensor([[1., 0., 0., 0.],
        [0., 1., 0., 0.],
        [0., 0., 1., 0.]])

Pytorch torch 参考手册 Pytorch torch 参考手册