Skip to main content

Pytorch nn AND nn.Functional

PyTorch API 技术指南:原理与实践

本指南旨在为开发者提供一份关于 PyTorch 核心 API 的技术参考。内容涵盖了从模型架构定义到高级张量操作的常用接口,并对每个 API 的基本原理和标准用法进行了解析,辅以简明的代码示例。

第一部分:核心模型架构 (torch.nn)

torch.nn 命名空间提供了构建神经网络所需的所有基础模块。这些模块是面向对象的设计,封装了可学习的参数和相应的计算逻辑。

1.1 nn.Module: 所有神经网络模块的基类

原理: nn.Module 是 PyTorch 模型构建的核心。任何自定义的网络层或完整的模型都应继承此类。它提供了参数跟踪、子模块注册、设备转移 (.to(device))、模式切换 (.train()/.eval()) 等一系列基础功能。

示例:

import torch
import torch.nn as nn

class CustomModel(nn.Module):
    def __init__(self, in_features, out_features):
        super(CustomModel, self).__init__()
        self.layer = nn.Linear(in_features, out_features)

    def forward(self, x):
        return self.layer(x)

# 实例化模型
model = CustomModel(128, 10)
print(model)
# 输出:
# CustomModel(
#   (layer): Linear(in_features=128, out_features=10, bias=True)
# )

1.2 nn.Parameter: 可学习参数的封装

原理: nn.Parametertorch.Tensor 的一个包装类。当一个 nn.Parameter 对象被赋值为 nn.Module 的属性时,它会自动被注册为模型的可学习参数。这意味着在反向传播后,优化器会计算并更新其梯度。

示例:

class CustomLayer(nn.Module):
    def __init__(self, size):
        super(CustomLayer, self).__init__()
        # 定义一个可学习的缩放因子
        self.scale = nn.Parameter(torch.ones(size))

    def forward(self, x):
        return x * self.scale

# 访问参数
layer = CustomLayer(5)
for param in layer.parameters():
    print(f"Parameter Shape: {param.shape}, Requires Grad: {param.requires_grad}")

1.3 nn.Linear: 全连接层

原理: nn.Linear 对输入数据应用一个线性变换:$y = xA^T + b$。它内部维护了权重矩阵 A 和偏置向量 b 作为可学习参数。

示例:

# 输入: (Batch Size, in_features)
input_tensor = torch.randn(64, 128)
linear_layer = nn.Linear(in_features=128, out_features=10)
output = linear_layer(input_tensor)
print(f"Output shape: {output.shape}") # torch.Size([64, 10])

1.4 nn.Conv2d: 二维卷积层

原理: nn.Conv2d 在输入的二维信号(如图像)上应用二维卷积操作。它通过滑动一个可学习的卷积核(kernel)来提取局部特征。关键参数包括 in_channels(输入通道数)、out_channels(输出通道数,即卷积核数量)、kernel_size(卷积核尺寸)、stridepadding

示例:

# 输入: (Batch, Channels, Height, Width)
input_image = torch.randn(16, 3, 224, 224)
conv_layer = nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, stride=1, padding=1)
feature_maps = conv_layer(input_image)
print(f"Feature map shape: {feature_maps.shape}") # torch.Size([16, 32, 224, 224])

1.5 nn.Sequential: 顺序容器

原理: nn.Sequential 是一个模块容器,它会按照模块被传入构造函数的顺序,依次将输入数据传递给每个模块。这是一种快速构建简单线性堆叠模型的便捷方式。

示例:

model = nn.Sequential(
    nn.Linear(784, 256),
    nn.ReLU(),
    nn.Linear(256, 10)
)
input_data = torch.randn(64, 784)
output = model(input_data)
print(f"Sequential model output shape: {output.shape}") # torch.Size([64, 10])

1.6 nn.Dropout: Dropout 正则化

原理: Dropout 是一种在训练期间使用的正则化技术。它会以指定的概率 p 将输入张量中的部分元素随机置为零,其余元素则按 1/(1-p) 的比例进行缩放,以保持整体期望值不变。这能有效防止神经元之间的共适应,降低过拟合风险。在评估模式 (.eval())下,nn.Dropout 不会执行任何操作。

示例:

dropout_layer = nn.Dropout(p=0.5)
input_tensor = torch.ones(1, 10)
dropout_layer.train() # 切换到训练模式
output_train = dropout_layer(input_tensor)
print(f"Train mode output (example): {output_train}") # 约一半元素为0,其余为2.0

dropout_layer.eval() # 切换到评估模式
output_eval = dropout_layer(input_tensor)
print(f"Eval mode output: {output_eval}") # 所有元素仍为1.0

第二部分:函数式接口 (torch.nn.functional)

torch.nn.functional(通常简写为 F)提供了一系列无状态的函数,它们是 nn 模块的底层实现。这些函数不包含可学习参数,调用时需要显式传入所有输入,包括权重。

2.1 F.relu: ReLU 激活函数

原理: 修正线性单元(Rectified Linear Unit)是一种分段线性函数,其数学表达式为 $f(x) = \max(0, x)$。它能有效缓解梯度消失问题,且计算开销小。

示例:

import torch.nn.functional as F

input_tensor = torch.tensor([-1.0, 0.0, 2.0])
output = F.relu(input_tensor)
print(f"ReLU output: {output}") # tensor([0., 0., 2.])

2.2 F.softmax: Softmax 激活函数

原理: Softmax 函数将一个实数向量转换为一个概率分布。它对向量中的每个元素应用指数函数,然后将结果归一化,使得所有元素的和为1。dim 参数指定了进行归一化的维度。

示例:

logits = torch.tensor([[1.0, 2.0, 3.0], [1.0, 1.0, 1.0]])
probabilities = F.softmax(logits, dim=1)
print(f"Softmax probabilities:\n{probabilities}")
# 输出:
# tensor([[0.0900, 0.2447, 0.6652],
#         [0.3333, 0.3333, 0.3333]])

2.3 F.cross_entropy: 交叉熵损失

原理: 交叉熵损失函数通常用于多分类任务。它在内部高效地集成了 LogSoftmaxNLLLoss(负对数似然损失)。它衡量的是模型预测的概率分布与真实的类别标签(通常是 one-hot 形式,但函数接口接收类别索引)之间的差异。

示例:

# 模型输出 (logits) 和真实标签
logits = torch.randn(4, 5) # 4个样本, 5个类别
labels = torch.tensor([1, 0, 4, 2]) # 每个样本的真实类别索引

loss = F.cross_entropy(logits, labels)
print(f"Cross Entropy Loss: {loss.item()}")

2.4 F.mse_loss: 均方误差损失

原理: 均方误差(Mean Squared Error)用于回归任务。它计算的是模型预测值与真实值之间差的平方的平均值。

示例:

predictions = torch.tensor([1.5, 2.8, 4.1])
targets = torch.tensor([1.0, 3.0, 4.0])
loss = F.mse_loss(predictions, targets)
print(f"Mean Squared Error Loss: {loss.item()}")

第三部分:张量初始化与操作

3.1 torch.nn.init.*: 权重初始化

原理: torch.nn.init 模块提供了多种权重初始化策略,以帮助模型更好地收敛。kaiming_uniform_ 是其中一种,它根据 He 初始化策略从均匀分布中采样,适用于配合 ReLU 激活函数。函数名末尾的下划线 _ 表示这是一个原地(in-place)操作。

示例:

import torch.nn.init as init
import math

weight = torch.empty(3, 5)
init.kaiming_uniform_(weight, a=math.sqrt(5))
print("Kaiming Initialized Weight (sample):\n", weight)

3.2 torch.Tensor.view: 张量视图重塑

原理: .view() 方法可以在不改变底层数据存储的情况下,高效地改变张量的形状。新旧形状的元素总数必须一致。返回的是一个共享数据内存的新张量视图。

示例:

x = torch.arange(12) # tensor([ 0,  1, ..., 11])
y = x.view(3, 4)
print(f"Reshaped tensor:\n{y}")

3.3 torch.topk: 查找 Top-K 元素

原理: torch.topk 函数用于在指定维度上查找输入张量中最大(或最小)的 k 个值及其索引。

示例:

x = torch.tensor([[1, 9, 2], [8, 3, 5]])
values, indices = torch.topk(x, k=2, dim=1)
print(f"Top-2 Values: {values}")   # tensor([[9, 2], [8, 5]])
print(f"Top-2 Indices: {indices}") # tensor([[1, 2], [0, 2]])

3.4 torch.Tensor.scatter_add_: 高级索引聚合

原理: scatter_add_ 是一个原地操作,它将一个源张量 src 中的值,根据 index 张量指定的位置,累加到目标张量 self 中。这对于实现复杂的稀疏更新或自定义聚合操作非常有用。

示例:

# 目标: 计算每个类别的得分总和
num_classes = 3
scores = torch.tensor([0.9, 0.1, 0.8, 0.5, 0.7])
class_indices = torch.tensor([0, 1, 0, 2, 1])

# 初始化类别总分张量
total_scores = torch.zeros(num_classes)

# 聚合分数
total_scores.scatter_add_(dim=0, index=class_indices, src=scores)
print(f"Total scores per class: {total_scores}") # tensor([1.7000, 0.8000, 0.5000])

第四部分:高级与专用网络层 (torch.nn)

本部分将介绍更专门化的 nn.Module,它们是构建现代深度学习架构(如 CNNs、RNNs 和 Transformers)的关键组件。

4.1 nn.BatchNorm2d: 二维批量归一化

原理: 批量归一化(Batch Normalization)是一种用于稳定和加速深度神经网络训练的技术。nn.BatchNorm2d 专门用于四维输入(典型的如图像数据 [N, C, H, W])。在训练时,它对一个 mini-batch 内的数据,沿着通道(Channel)维度进行归一化,使其均值为0,方差为1。此外,它还包含两个可学习的仿射变换参数($\gamma$ 和 $\beta$),用于恢复网络的表达能力。在评估模式下,它会使用在训练过程中累积的全局均值和方差进行归一化。

示例:

# 4个样本, 3个通道, 32x32的图像
input_tensor = torch.randn(4, 3, 32, 32)
# 归一化的特征数量必须与输入通道数匹配
batch_norm_layer = nn.BatchNorm2d(num_features=3)

# 在训练模式下
batch_norm_layer.train()
output_train = batch_norm_layer(input_tensor)
# output_train 的每个通道在 batch 维度上均值接近0,方差接近1

# 在评估模式下
batch_norm_layer.eval()
output_eval = batch_norm_layer(input_tensor)
print(f"Output shape remains the same: {output_eval.shape}")

4.2 nn.LayerNorm: 层归一化

原理: 层归一化(Layer Normalization)是另一种归一化技术。与批量归一化不同,它在单个样本内部对特征进行归一化,因此其计算完全独立于 batch 中的其他样本。这使得它非常适用于序列数据(如 NLP 任务中的 RNN 和 Transformer),其中序列长度可变,使用批量归一化会很棘手。它对指定 normalized_shape 的维度进行归一化。

示例:

# 2个样本, 序列长度10, 特征维度20
input_seq = torch.randn(2, 10, 20)
# 对最后一个维度(特征维度)进行归一化
layer_norm = nn.LayerNorm(normalized_shape=20)
output_norm = layer_norm(input_seq)

# output_norm 中每个样本的每个时间步的特征向量均值为0,方差为1
print(f"LayerNorm output shape: {output_norm.shape}") # torch.Size([2, 10, 20])

4.3 nn.Embedding: 嵌入层

原理: nn.Embedding 本质上是一个大型的查找表(lookup table)。它存储了固定大小词汇表或类别库的密集向量(embeddings)。当输入一个由类别索引组成的列表时,它会返回这些索引对应的嵌入向量。这在自然语言处理中用于将单词ID映射为词向量,或在推荐系统中将用户/物品ID映射为特征向量。其内部维护一个大小为 (num_embeddings, embedding_dim) 的可学习权重矩阵。

示例:

# 假设词汇表大小为100, 每个词用一个32维向量表示
embedding_layer = nn.Embedding(num_embeddings=100, embedding_dim=32)

# 输入是一个由单词索引组成的张量 (2个句子, 每个句子4个词)
input_indices = torch.LongTensor([[1, 2, 4, 5], [4, 3, 2, 9]])
embedded_vectors = embedding_layer(input_indices)

print(f"Embedded vectors shape: {embedded_vectors.shape}") # torch.Size([2, 4, 32])

4.4 nn.LSTM: 长短期记忆网络层

原理: LSTM (Long Short-Term Memory) 是一种特殊的循环神经网络(RNN),旨在解决传统 RNN 中的长期依赖和梯度消失问题。它通过引入一个单元状态(cell state)和三个门(遗忘门、输入门、输出门)来精细地控制信息的流动。这使得模型能够选择性地记忆、更新和输出信息,从而在长序列上表现更佳。

示例:

# 输入特征维度10, 隐藏状态维度20, 堆叠2层
lstm_layer = nn.LSTM(input_size=10, hidden_size=20, num_layers=2)

# 输入序列: 长度为5, batch大小为3, 特征维度10
input_seq = torch.randn(5, 3, 10)
# 初始隐藏状态和单元状态
h0 = torch.randn(2, 3, 20) # (num_layers, batch, hidden_size)
c0 = torch.randn(2, 3, 20)

# 前向传播
output, (hn, cn) = lstm_layer(input_seq, (h0, c0))

print(f"Output sequence shape: {output.shape}") # (seq_len, batch, hidden_size) -> (5, 3, 20)
print(f"Final hidden state shape: {hn.shape}") # (num_layers, batch, hidden_size) -> (2, 3, 20)

第五部分:高级张量操作与数据处理

本部分关注于对张量进行结构性变换和条件化处理的函数,这些是数据预处理和模型内部数据流控制的基础。

5.1 torch.cat: 张量拼接

原理: torch.cat 用于将一个张量序列沿着一个已存在的维度进行拼接。所有待拼接的张量在非拼接维度上必须具有相同的尺寸。

示例:

x = torch.randn(2, 3)
y = torch.randn(4, 3)
z = torch.randn(2, 2)

# 沿维度0 (行) 拼接, 要求列数相同
cat_dim0 = torch.cat([x, y], dim=0)
print(f"Concat on dim 0 shape: {cat_dim0.shape}") # torch.Size([6, 3])

# 沿维度1 (列) 拼接, 要求行数相同
cat_dim1 = torch.cat([x, z], dim=1)
print(f"Concat on dim 1 shape: {cat_dim1.shape}") # torch.Size([2, 5])

5.2 torch.stack: 张量堆叠

原理: torch.stack 用于将一个张量序列沿着一个新的维度进行堆叠。所有待堆叠的张量必须具有完全相同的尺寸。这相当于在拼接前为每个张量增加一个新维度。

示例:

a = torch.tensor([1, 2, 3])
b = torch.tensor([4, 5, 6])

# 沿新的维度0进行堆叠
stacked_tensor = torch.stack([a, b], dim=0)
print(f"Stacked tensor:\n{stacked_tensor}")
print(f"Stacked tensor shape: {stacked_tensor.shape}") # torch.Size([2, 3])

5.3 torch.squeeze & torch.unsqueeze: 维度压缩与扩展

原理:

  • torch.unsqueeze(input, dim): 在指定维度 dim 上为输入张量增加一个大小为1的新维度。
  • torch.squeeze(input, dim=None): 移除输入张量中所有大小为1的维度。如果指定了 dim,则只移除该维度(如果其大小为1)。

这两个操作对于调整张量维度以满足特定层的输入要求(如卷积层或批处理)至关重要。

示例:

x = torch.randn(3, 4) # Shape: [3, 4]
# 在维度0增加一个批处理维度
y = torch.unsqueeze(x, 0)
print(f"Unsqueeze shape: {y.shape}") # Shape: [1, 3, 4]

# 移除所有大小为1的维度
z = torch.squeeze(y)
print(f"Squeeze shape: {z.shape}") # Shape: [3, 4]

5.4 torch.where: 条件化选择

原理: torch.where(condition, x, y) 是一个元素级的条件判断函数。它返回一个新的张量,其元素根据 condition 张量(一个布尔张量)来选择:如果 condition 中对应位置的元素为 True,则从 x 中取值;否则从 y 中取值。

示例:

x = torch.tensor([-2.0, 0.0, 3.0, -1.0])
# 将所有负数替换为0 (类似于 ReLU 的一种实现)
output = torch.where(x > 0, x, torch.tensor(0.))
print(f"torch.where output: {output}") # tensor([0., 0., 3., 0.])

5.5 torch.permute: 维度重排

原理: torch.permute 用于根据指定的维度顺序,对张量的维度进行重排。它返回一个共享底层数据存储的新视图。这在需要将数据格式从 (N, H, W, C) 转换为 (N, C, H, W)(PyTorch 标准)时非常有用。

示例:

# 一个典型的 "channels-last" 格式的图像 batch
x_channels_last = torch.randn(16, 32, 32, 3) # (N, H, W, C)

# 重排为 "channels-first"
x_channels_first = x_channels_last.permute(0, 3, 1, 2) # (N, C, H, W)
print(f"Permuted shape: {x_channels_first.shape}") # torch.Size([16, 3, 32, 32])