现在位置: 首页 > PyTorch 教程 > 正文

PyTorch torch.nn.Bilinear 函数

PyTorch torch.nn 参考手册 PyTorch torch.nn 参考手册


torch.nn.Bilinear 是 PyTorch 中的双线性层。

它对两个输入进行双线性变换,常用于特征融合。

函数定义

torch.nn.Bilinear(in1_features, out_features, out_features, bias=True)

公式

y = x1 * W * x2 + b

使用示例

示例 1: 基本用法

实例

import torch
import torch.nn as nn

# 双线性层:两个 100 维输入 -> 50 维输出
bilinear = nn.Bilinear(100, 100, 50)

# 两个输入
x1 = torch.randn(4, 100)
x2 = torch.randn(4, 100)

output = bilinear(x1, x2)

print("输入1:", x1.shape)
print("输入2:", x2.shape)
print("输出:", output.shape)

示例 2: 特征融合

实例

import torch
import torch.nn as nn

# 双线性特征融合(类似 Skip-gram)
class FusionNet(nn.Module):
    def __init__(self, dim=128):
        super(FusionNet, self).__init__()
        self.bilinear = nn.Bilinear(dim, dim, dim)

    def forward(self, feat1, feat2):
        return self.bilinear(feat1, feat2)

model = FusionNet()
f1 = torch.randn(4, 128)
f2 = torch.randn(4, 128)

output = model(f1, f2)
print("融合输出:", output.shape)

示例 3: 参数数量

实例

import torch
import torch.nn as nn

# 双线性层参数
bilinear = nn.Bilinear(100, 100, 50)

print("参数数量:", sum(p.numel() for p in bilinear.parameters()))
print("权重形状:", bilinear.weight.shape)  # (50, 100, 100)

使用场景

  • 特征融合: 多模态
  • 注意力机制
  • 交互建模

注意:双线性层参数量大,谨慎使用。


PyTorch torch.nn 参考手册 PyTorch torch.nn 参考手册