现在位置: 首页 > PyTorch 教程 > 正文

PyTorch torch.from_numpy 函数


Pytorch torch 参考手册 Pytorch torch 参考手册

torch.from_numpy 是 PyTorch 中用于从 NumPy 数组创建张量的函数。创建的张量与原始 NumPy 数组共享内存。

这在需要同时使用 PyTorch 和 NumPy 进行数据处理的场景中非常有用。

函数定义

torch.from_numpy(ndarray)

参数:

  • ndarray (numpy.ndarray): 输入的 NumPy 数组。

返回值:

  • torch.Tensor: 返回一个与 NumPy 数组共享内存的张量。

使用示例

示例 1: 从 NumPy 数组创建张量

实例

import torch
import numpy as np

# 创建 NumPy 数组
numpy_array = np.array([[1, 2, 3], [4, 5, 6]])

# 转换为 PyTorch 张量
tensor = torch.from_numpy(numpy_array)

print("NumPy 数组:")
print(numpy_array)
print("PyTorch 张量:")
print(tensor)

输出结果为:

NumPy 数组:
[[1 2 3]
 [4 5 6]]
PyTorch 张量:
tensor([[1, 2, 3],
        [4, 5, 6]])

示例 2: 内存共享

实例

import torch
import numpy as np

# 创建 NumPy 数组
numpy_array = np.array([1, 2, 3])

# 转换为 PyTorch 张量
tensor = torch.from_numpy(numpy_array)

# 修改张量
tensor[0] = 100

# NumPy 数组也会改变
print("NumPy 数组:", numpy_array)
print("PyTorch 张量:", tensor)

输出结果为:

NumPy 数组: [100   2   3]
PyTorch 张量: tensor([100,   2,   3])

由于共享内存,修改一个会影响另一个。

示例 3: 数据类型转换

实例

import torch
import numpy as np

# NumPy 默认创建 int64
numpy_array = np.array([1, 2, 3])

tensor = torch.from_numpy(numpy_array)
print("dtype:", tensor.dtype)  # torch.int64

Pytorch torch 参考手册 Pytorch torch 参考手册