PyTorch torch.nn.Tanh 函数
torch.nn.Tanh 是 PyTorch 中的双曲正切激活函数。
它将输入值映射到 -1 到 1 之间,输出零中心化,常用于循环神经网络。
函数定义
torch.nn.Tanh()
数学原理
Tanh(x) = (e^x - e^(-x)) / (e^x + e^(-x))
使用示例
示例 1: 基本用法
实例
import torch
import torch.nn as nn

# Demonstrate the hyperbolic tangent activation on a handful of values:
# outputs fall in (-1, 1) and are symmetric around zero.
activation = nn.Tanh()
values = torch.tensor([-2.0, -1.0, 0.0, 1.0, 2.0])
result = activation(values)
print("输入:", values.tolist())
print("输出:", result.tolist())
print("特点: 输出范围 [-1, 1],零中心化")
# Fix: the original only imported torch.nn, but torch.tensor is used below,
# so the script raised NameError on `torch`.
import torch
import torch.nn as nn

# Tanh maps each input into (-1, 1), centered at zero.
tanh = nn.Tanh()
x = torch.tensor([-2.0, -1.0, 0.0, 1.0, 2.0])
output = tanh(x)
print("输入:", x.tolist())
print("输出:", output.tolist())
print("特点: 输出范围 [-1, 1],零中心化")
示例 2: 在 LSTM 中使用
实例
import torch
import torch.nn as nn


class SimpleLSTMCell(nn.Module):
    """Toy LSTM-style cell showing where Tanh and Sigmoid are used.

    NOTE(review): weights are drawn with ``torch.randn`` on every forward
    call, so the cell is random and untrainable — this mirrors the
    original simplified example and is for illustration only.
    (Fix applied: the source listing had lost all indentation and was not
    valid Python; structure restored without changing any statement.)
    """

    def __init__(self, input_size, hidden_size):
        super(SimpleLSTMCell, self).__init__()
        self.hidden_size = hidden_size
        # Gating uses Tanh and Sigmoid.
        self.tanh = nn.Tanh()
        self.sigmoid = nn.Sigmoid()

    def forward(self, x, hidden):
        h, c = hidden
        # Simplified gate computation (random projection, not learned).
        gates = self.sigmoid(x @ torch.randn(x.shape[1], self.hidden_size * 4))
        # Tanh bounds the candidate memory to [-1, 1].
        candidate = self.tanh(x @ torch.randn(x.shape[1], self.hidden_size))
        return candidate, torch.zeros_like(h)


# Smoke test
cell = SimpleLSTMCell(10, 20)
x = torch.randn(1, 10)
h = torch.randn(1, 20)
c = torch.randn(1, 20)
new_h, new_c = cell(x, (h, c))
print("输入形状:", x.shape)
print("隐藏状态形状:", new_h.shape)
# Fix: the original only imported torch.nn, but torch.randn / torch.zeros_like
# are used below, so the script raised NameError on `torch`. Indentation,
# lost in the source listing, is also restored.
import torch
import torch.nn as nn


class SimpleLSTMCell(nn.Module):
    """Toy LSTM-style cell showing where Tanh and Sigmoid are used.

    NOTE(review): weights are drawn with ``torch.randn`` on every forward
    call, so the cell is random and untrainable — illustration only.
    """

    def __init__(self, input_size, hidden_size):
        super(SimpleLSTMCell, self).__init__()
        self.hidden_size = hidden_size
        # Gating uses Tanh and Sigmoid.
        self.tanh = nn.Tanh()
        self.sigmoid = nn.Sigmoid()

    def forward(self, x, hidden):
        h, c = hidden
        # Simplified gate computation (random projection, not learned).
        gates = self.sigmoid(x @ torch.randn(x.shape[1], self.hidden_size * 4))
        # Tanh bounds the candidate memory to [-1, 1].
        candidate = self.tanh(x @ torch.randn(x.shape[1], self.hidden_size))
        return candidate, torch.zeros_like(h)


# Smoke test
cell = SimpleLSTMCell(10, 20)
x = torch.randn(1, 10)
h = torch.randn(1, 20)
c = torch.randn(1, 20)
new_h, new_c = cell(x, (h, c))
print("输入形状:", x.shape)
print("隐藏状态形状:", new_h.shape)
示例 3: 对比 Sigmoid
实例
import torch
import torch.nn as nn
import numpy as np

# Tabulate Sigmoid vs Tanh on 11 evenly spaced points in [-3, 3]:
# Sigmoid lands in (0, 1), Tanh in (-1, 1).
x = np.linspace(-3, 3, 11)
x_tensor = torch.tensor(x, dtype=torch.float32)
sigmoid = nn.Sigmoid()
tanh = nn.Tanh()
print("x Sigmoid Tanh")
print("-" * 35)
for start in range(0, 11, 2):
    pair = x_tensor[start:start + 2]
    val = pair[0].item()
    s_val = sigmoid(pair)[0].item()
    t_val = tanh(pair)[0].item()
    print(f"{val:5.1f} {s_val:9.4f} {t_val:9.4f}")
# Fix: the original only imported torch.nn, but torch.tensor is used below,
# so the script raised NameError on `torch`.
import torch
import torch.nn as nn
import numpy as np

# Tabulate Sigmoid vs Tanh on 11 evenly spaced points in [-3, 3].
x = np.linspace(-3, 3, 11)
x_tensor = torch.tensor(x, dtype=torch.float32)
sigmoid = nn.Sigmoid()
tanh = nn.Tanh()
print("x Sigmoid Tanh")
print("-" * 35)
for i in range(0, 11, 2):
    xi = x_tensor[i:i+2]
    print(f"{xi[0].item():5.1f} {sigmoid(xi)[0].item():9.4f} {tanh(xi)[0].item():9.4f}")
常见问题
Q1: Tanh 和 Sigmoid 的区别?
Tanh 输出范围 [-1,1](零中心化),Sigmoid 输出 [0,1]。
Q2: 为什么 LSTM 用 Tanh?
Tanh 的零中心化特性使梯度流动更稳定。
使用场景
- 循环神经网络: LSTM、GRU 默认激活
- 生成模型: GAN 的生成器
- 门控机制: 控制信息范围

PyTorch torch.nn 参考手册