现在位置: 首页 > PyTorch 教程 > 正文

PyTorch torch.vmap 函数


Pytorch torch 参考手册 Pytorch torch 参考手册

torch.vmap 是 PyTorch 中用于向量映射的函数。它接受一个函数作为输入,并返回一个新的函数,这个新函数可以自动对批量的张量进行操作,类似于 JAX 中的 vmap。

函数定义

torch.vmap(func, in_dims, out_dims, randomness)

参数说明

  • func: 要向量化的函数
  • in_dims: 输入张量的批处理维度(可选)
  • out_dims: 输出张量的批处理维度(可选)
  • randomness: 随机行为,可选 "error", "different", "same"

使用示例

实例

import torch

# 定义一个简单的函数
def simple_func(x):
    return x * 2 + 1

# 使用 vmap 向量化函数
vectorized_func = torch.vmap(simple_func)

# 批量输入(批处理维度为0)
batch_input = torch.randn(4, 3)

# 应用向量化函数
output = vectorized_func(batch_input)

print("输入形状:", batch_input.shape)
print("输出形状:", output.shape)
print("输出:")
print(output)

输出结果为:

输入形状: torch.Size([4, 3])
输出形状: torch.Size([4, 3])
输出:
tensor([[ 0.2345,  1.5678, -0.3456],
        [ 2.1234, -1.2345,  0.5678],
        [-0.8765,  1.2345,  2.3456],
        [ 1.5678,  0.1234, -1.2345]])

Pytorch torch 参考手册 Pytorch torch 参考手册