Skip to main content

Pytorch 数学运算

PyTorch 数学运算核心指南:从 sum@

在深度学习的引擎盖下,一切皆为数学。损失的计算、梯度的传播、权重的更新、性能的评估,每一个环节都离不开大量的数学运算。PyTorch 提供了一套极其丰富、高度优化的数学函数库,它们是构建和训练任何神经网络的基石。

本指南将系统地梳理 PyTorch 中最重要的数学运算,并将其分为四大类:规约(Reduction)运算逐元素(Element-wise)运算线性代数运算比较运算,帮助你彻底掌握这些核心工具。

第一部分:规约运算 - 从张量中提炼信息

规约(Reduction)操作会将一个张量沿着指定的维度进行计算,输出一个维度更少(或被“压缩”)的张量,从而实现信息的提炼与汇总。

1. torch.sum(input, dim=None, keepdim=False)

  • 原理: 计算张量中所有元素或沿着指定维度 dim 的元素之和。

  • 关键参数:

    • dim: 指定要规约的维度。可以是单个维度或一个元组。
    • keepdim=True: 这是个至关重要的参数。它会在输出中保留被规约的维度,但其大小变为1。这对于后续的广播(Broadcasting)操作至关重要。
  • 应用示例:

    x = torch.tensor([[1, 2, 3], [4, 5, 6]])
    
    # 1. 全局求和
    total_sum = torch.sum(x)
    print(f"Total sum: {total_sum}") # tensor(21)
    
    # 2. 沿列求和 (规约行,dim=0)
    col_sum = torch.sum(x, dim=0)
    print(f"Column sum: {col_sum}") # tensor([5, 7, 9])
    
    # 3. 沿行求和 (规约列,dim=1),并保持维度
    row_sum_kept = torch.sum(x, dim=1, keepdim=True)
    print(f"Row sum (keepdim=True):\n{row_sum_kept}")
    print(f"Shape with keepdim=True: {row_sum_kept.shape}") # torch.Size([2, 1])
    # 保持维度使得广播成为可能,例如,按行进行归一化
    # normalized_x = x / row_sum_kept
    

2. torch.mean(input, dim=None, keepdim=False)

  • 原理: 计算算术平均值。参数 dimkeepdim 的作用与 sum 完全相同。
  • 应用示例: 计算批处理损失。在训练循环中,每个样本都会产生一个损失值,你需要计算整个批次的平均损失来进行反向传播。
    batch_losses = torch.tensor([1.2, 0.9, 1.5, 1.0]) # 一个batch中4个样本的损失
    mean_loss = torch.mean(batch_losses)
    print(f"Mean loss for the batch: {mean_loss}") # tensor(1.1500)
    

3. torch.max(input, dim=None)torch.min(input, dim=None)

  • 原理: 查找最大/最小值。当指定 dim 时,它们会返回一个包含两个张量的命名元组 (values, indices),分别是被找到的值和它们在该维度上的索引。
  • 应用示例: 从 Logits 中获取预测类别。这是 torch.max 最经典的应用场景。
    logits = torch.randn(4, 5) # 4个样本, 5个类别
    print(f"Logits:\n{logits}")
    
    # 沿类别维度 (dim=1) 寻找最大值
    max_values, predicted_classes = torch.max(logits, dim=1)
    
    print(f"\nMax logit values: {max_values}")
    print(f"Predicted class indices: {predicted_classes}") # 这就是模型的预测结果
    

4. torch.argmax(input, dim=None)torch.argmin(input, dim=None)

  • 原理: torch.maxtorch.min 的快捷方式,当 你只关心索引而不关心具体值 时使用。它直接返回最大/最小值的索引。
  • 应用示例: 与上面相同,但代码更直接。
    predicted_classes = torch.argmax(logits, dim=1)
    print(f"Predicted class indices (using argmax): {predicted_classes}")
    

第二部分:逐元素运算 - 并行处理的基础

这类操作会对输入的每个元素独立地应用一个函数。输出张量的形状通常与输入张量相同。

1. 基础算术: add, sub, mul, div

  • 原理: 执行加减乘除。可以直接使用操作符 +, -, *, /。支持广播机制。
  • 应用示例:
    x = torch.tensor([[1, 2], [3, 4]])
    y = torch.tensor([[5, 6], [7, 8]])
    
    # 逐元素相加
    z = x + y # 等价于 torch.add(x, y)
    print(f"x + y:\n{z}")
    
    # 广播一个标量
    w = x * 2
    print(f"x * 2:\n{w}")
    

2. 数学函数: sqrt, exp, log, pow

  • 原理: 应用标准的数学函数,如平方根、指数、对数、幂函数等。
  • 应用示例: 实现 RMSNorm (Root Mean Square Normalization)
    # 简化版的 RMSNorm
    variance = torch.mean(x.pow(2), dim=-1, keepdim=True)
    # 加上一个很小的epsilon防止除以零
    normalized_x = x * torch.rsqrt(variance + 1e-5) # rsqrt是1/sqrt,效率更高
    

3. torch.clamp(input, min=None, max=None)

  • 原理: 将张量中的所有元素限制(裁剪)在一个指定的闭区间 [min, max] 内。
  • 应用示例: 梯度裁剪实现 ReLU6 激活函数
    # 梯度裁剪: 防止梯度爆炸
    gradient = torch.randn(4) * 10 # 一个可能很大的梯度
    clipped_grad = torch.clamp(gradient, min=-1.0, max=1.0)
    print(f"Original Gradient: {gradient}")
    print(f"Clipped Gradient: {clipped_grad}")
    
    # ReLU6: f(x) = min(max(0, x), 6)
    relu6_output = torch.clamp(torch.randn(4), min=0, max=6)
    

第三部分:线性代数 - 深度学习的语言

神经网络中的大部分计算都是线性代数运算,尤其是矩阵乘法。

1. torch.matmul(input, other)@ 操作符

  • 原理: 这是一个功能极其强大的矩阵乘法函数,其行为根据输入张量的维度而变化:
    • 1D x 1D: 向量内积(Dot Product),返回一个标量。
    • 2D x 2D: 标准的矩阵-矩阵乘法。
    • N-D x M-D (N,M > 1): 支持批处理的矩阵乘法(Batched Matrix Multiplication)。这是深度学习中最重要的形式。
  • 应用示例: 手动实现一个线性层注意力机制中的矩阵乘法
    # 1. 线性层
    inputs = torch.randn(128, 64) # (Batch, In_Features)
    weights = torch.randn(32, 64) # (Out_Features, In_Features)
    output = torch.matmul(inputs, weights.T) # .T 是 .transpose(-1, -2) 的简写
    print(f"Linear layer output shape: {output.shape}") # torch.Size([128, 32])
    
    # 2. 批处理矩阵乘法 (注意力分数计算)
    # (Batch, Heads, SeqLen, HeadDim)
    queries = torch.randn(32, 8, 100, 64)
    keys = torch.randn(32, 8, 100, 64)
    # 使用 @ 操作符更简洁
    attention_scores = queries @ keys.transpose(-2, -1)
    print(f"Attention scores shape: {attention_scores.shape}") # torch.Size([32, 8, 100, 100])
    

2. 其他矩阵函数: mm, bmm, dot

  • torch.mm(mat1, mat2): 严格的 2D x 2D 矩阵乘法,不支持广播。
  • torch.bmm(batch1, batch2): 严格的批处理 矩阵乘法 (B, n, m) @ (B, m, p),不支持批次维度的广播。
  • torch.dot(vec1, vec2): 严格的 1D x 1D 向量内积。

选择建议: 优先使用 @torch.matmul,因为它们最灵活、最通用。仅在确定输入维度且追求极致代码明确性时,才使用专用函数。

第四部分:比较运算 - 构建逻辑与掩码

比较运算对张量进行逐元素比较,返回一个包含 TrueFalse 的布尔张量,它通常被用作 掩码(Mask)

运算符: >, <, >=, <=, ==, !=

  • 原理: torch.gt (大于), torch.lt (小于), torch.eq (等于) 等函数的便捷操作符。
  • 应用示例: 计算分类准确率
    labels = torch.tensor([1, 0, 4, 2])
    predictions = torch.tensor([1, 3, 4, 2]) # 模型的预测类别
    
    # 1. 生成布尔掩码
    correct_mask = (predictions == labels)
    print(f"Correctness mask: {correct_mask}") # tensor([ True, False,  True,  True])
    
    # 2. 使用掩码计算准确率
    # correct_mask.sum() 会将 True 当作 1,False 当作 0
    accuracy = correct_mask.sum() / len(labels)
    print(f"Accuracy: {accuracy.item():.2f}") # 0.75
    

关于原地(In-place)操作的说明

PyTorch 中许多函数都有一个带下划线 _ 的版本,如 add_()。这些是 原地操作,它们会直接修改调用者自身的数据,而不是返回一个新张量。

x = torch.ones(1)
y = torch.ones(1)

# 非原地操作
z = x.add(y)
print(f"x after add: {x}") # x 依然是 tensor([1.])

# 原地操作
x.add_(y)
print(f"x after add_: {x}") # x 变成了 tensor([2.])
```**使用建议**:
*   **优点**: 节省内存,因为避免了创建新张量。
*   **缺点**: 在计算图中,如果一个张量在反向传播时还需要使用其原始值,对其进行原地操作会破坏计算图,导致 autograd 错误。
*   **黄金法则**: 谨慎使用。在不影响反向传播的叶子节点或确定不再需要原始值时使用是安全的。

掌握这些数学运算是释放 PyTorch 全部潜力的关键。通过将它们像乐高积木一样组合起来,你就可以构建出任何复杂的计算逻辑。