PyTorch torch.vmap 函数
torch.vmap 是 PyTorch 中用于向量映射的函数。它接受一个函数作为输入,并返回一个新的函数,这个新函数可以自动对批量的张量进行操作,类似于 JAX 中的 vmap。
函数定义
torch.vmap(func, in_dims, out_dims, randomness)
参数说明
func: 要向量化的函数in_dims: 输入张量的批处理维度(可选)out_dims: 输出张量的批处理维度(可选)randomness: 随机行为,可选 "error", "different", "same"
使用示例
实例
import torch
# 定义一个简单的函数
def simple_func(x):
return x * 2 + 1
# 使用 vmap 向量化函数
vectorized_func = torch.vmap(simple_func)
# 批量输入(批处理维度为0)
batch_input = torch.randn(4, 3)
# 应用向量化函数
output = vectorized_func(batch_input)
print("输入形状:", batch_input.shape)
print("输出形状:", output.shape)
print("输出:")
print(output)
# 定义一个简单的函数
def simple_func(x):
return x * 2 + 1
# 使用 vmap 向量化函数
vectorized_func = torch.vmap(simple_func)
# 批量输入(批处理维度为0)
batch_input = torch.randn(4, 3)
# 应用向量化函数
output = vectorized_func(batch_input)
print("输入形状:", batch_input.shape)
print("输出形状:", output.shape)
print("输出:")
print(output)
输出结果为:
输入形状: torch.Size([4, 3])
输出形状: torch.Size([4, 3])
输出:
tensor([[ 0.2345, 1.5678, -0.3456],
[ 2.1234, -1.2345, 0.5678],
[-0.8765, 1.2345, 2.3456],
[ 1.5678, 0.1234, -1.2345]])

Pytorch torch 参考手册