PyTorch torch.repeat_interleave 函数
torch.repeat_interleave 是 PyTorch 中用于沿指定维度重复元素的函数。
函数定义
torch.repeat_interleave(input, repeats, dim=None, *, output_size=None)
使用示例
实例
import torch
# 不指定 dim,沿所有元素重复
x = torch.tensor([1, 2, 3])
result = torch.repeat_interleave(x, 2)
print("每个元素重复 2 次:")
print(result)
# 按 dim=0 重复
y = torch.tensor([[1, 2], [3, 4]])
print("n原始张量:")
print(y)
result = torch.repeat_interleave(y, 2, dim=0)
print("沿 dim=0 重复 2 次:")
print(result)
# 每个元素不同重复次数
result = torch.repeat_interleave(y, torch.tensor([1, 2]), dim=0)
print("n沿 dim=0 不同重复次数 [1, 2]:")
print(result)
# 按 dim=1 重复
result = torch.repeat_interleave(y, 3, dim=1)
print("n沿 dim=1 重复 3 次:")
print(result)
# 返回 output_size
result = torch.repeat_interleave(x, 2, output_size=9)
print("n指定 output_size=9:")
print(result)
# 不指定 dim,沿所有元素重复
x = torch.tensor([1, 2, 3])
result = torch.repeat_interleave(x, 2)
print("每个元素重复 2 次:")
print(result)
# 按 dim=0 重复
y = torch.tensor([[1, 2], [3, 4]])
print("n原始张量:")
print(y)
result = torch.repeat_interleave(y, 2, dim=0)
print("沿 dim=0 重复 2 次:")
print(result)
# 每个元素不同重复次数
result = torch.repeat_interleave(y, torch.tensor([1, 2]), dim=0)
print("n沿 dim=0 不同重复次数 [1, 2]:")
print(result)
# 按 dim=1 重复
result = torch.repeat_interleave(y, 3, dim=1)
print("n沿 dim=1 重复 3 次:")
print(result)
# 返回 output_size
result = torch.repeat_interleave(x, 2, output_size=9)
print("n指定 output_size=9:")
print(result)
输出结果为:
每个元素重复 2 次:
tensor([1, 1, 2, 2, 3, 3])
原始张量:
tensor([[1, 2],
[3, 4]])
沿 dim=0 重复 2 次:
tensor([[1, 2],
[1, 2],
[3, 4],
[3, 4]])
沿 dim=0 不同重复次数 [1, 2]:
tensor([[1, 2],
[3, 4],
[3, 4]])
沿 dim=1 重复 3 次:
tensor([[1, 1, 1, 2, 2, 2],
[3, 3, 3, 4, 4, 4]])
指定 output_size=9:
tensor([1, 1, 2, 2, 3, 3, 1, 2, 3])

Pytorch torch 参考手册