PyTorch torch.nn.Bilinear 函数
torch.nn.Bilinear 是 PyTorch 中的双线性层。
它对两个输入进行双线性变换,常用于特征融合。
函数定义
torch.nn.Bilinear(in1_features, out_features, out_features, bias=True)
公式
y = x1 * W * x2 + b
使用示例
示例 1: 基本用法
实例
import torch
import torch.nn as nn
# 双线性层:两个 100 维输入 -> 50 维输出
bilinear = nn.Bilinear(100, 100, 50)
# 两个输入
x1 = torch.randn(4, 100)
x2 = torch.randn(4, 100)
output = bilinear(x1, x2)
print("输入1:", x1.shape)
print("输入2:", x2.shape)
print("输出:", output.shape)
import torch.nn as nn
# 双线性层:两个 100 维输入 -> 50 维输出
bilinear = nn.Bilinear(100, 100, 50)
# 两个输入
x1 = torch.randn(4, 100)
x2 = torch.randn(4, 100)
output = bilinear(x1, x2)
print("输入1:", x1.shape)
print("输入2:", x2.shape)
print("输出:", output.shape)
示例 2: 特征融合
实例
import torch
import torch.nn as nn
# 双线性特征融合(类似 Skip-gram)
class FusionNet(nn.Module):
def __init__(self, dim=128):
super(FusionNet, self).__init__()
self.bilinear = nn.Bilinear(dim, dim, dim)
def forward(self, feat1, feat2):
return self.bilinear(feat1, feat2)
model = FusionNet()
f1 = torch.randn(4, 128)
f2 = torch.randn(4, 128)
output = model(f1, f2)
print("融合输出:", output.shape)
import torch.nn as nn
# 双线性特征融合(类似 Skip-gram)
class FusionNet(nn.Module):
def __init__(self, dim=128):
super(FusionNet, self).__init__()
self.bilinear = nn.Bilinear(dim, dim, dim)
def forward(self, feat1, feat2):
return self.bilinear(feat1, feat2)
model = FusionNet()
f1 = torch.randn(4, 128)
f2 = torch.randn(4, 128)
output = model(f1, f2)
print("融合输出:", output.shape)
示例 3: 参数数量
实例
import torch
import torch.nn as nn
# 双线性层参数
bilinear = nn.Bilinear(100, 100, 50)
print("参数数量:", sum(p.numel() for p in bilinear.parameters()))
print("权重形状:", bilinear.weight.shape) # (50, 100, 100)
import torch.nn as nn
# 双线性层参数
bilinear = nn.Bilinear(100, 100, 50)
print("参数数量:", sum(p.numel() for p in bilinear.parameters()))
print("权重形状:", bilinear.weight.shape) # (50, 100, 100)
使用场景
- 特征融合: 多模态
- 注意力机制
- 交互建模
注意:双线性层参数量大,谨慎使用。

PyTorch torch.nn 参考手册