Pytorch Tensor 变换
PyTorch 维度变换权威指南:从 view
到 permute
在 PyTorch 中,数据以张量(Tensor)的形式流动。如果你想自如地构建、调试和优化神经网络,那么精通张量维度的变换操作是必不可少的基本功。无论是为了匹配模型的输入要求,还是为了实现复杂的算法逻辑,你都将频繁地与张量的形状(Shape)打交道。
本指南将深入探讨 PyTorch 中所有核心的维度操作 API,不仅解释它们的用法,更剖析其背后的原理——特别是 内存布局(Memory Layout) 这一关键概念,它决定了哪些操作高效,哪些操作会引发错误。
核心概念:视图(View) vs. 内存重排(Permutation)
要真正理解维度变换,首先必须区分两类操作:
-
视图类操作 (
view
,reshape
): 这类操作 不改变数据在内存中的物理存储顺序。它们只是改变了 PyTorch 解析这个数据的方式(即改变张量的size
和stride
信息),从而创建一个新的“视图”。这类操作速度极快,因为它们不涉及任何数据复制。- 前提条件: 只有当张量在内存中是 连续的(Contiguous) 时,才能进行视图操作。所谓连续,即指张量在内存中的存储顺序与按其维度顺序遍历时的访问顺序一致。
-
内存重排类操作 (
permute
,transpose
): 这类操作会 改变维度的顺序。这通常会导致张量在内存中变得 不连续,因为数据的物理存储顺序与新的维度遍历顺序不再匹配。
理解这一点至关重要,因为它解释了为何一个 permute
操作之后,紧接着的 view
操作常常会失败。
API 深度解析
1. torch.Tensor.view(*shape)
: 高速视图变换
原理:
view
是最基础、最高效的形状变换方法。它严格要求目标张量是内存连续的。如果满足条件,它会返回一个共享原始数据内存的新张量视图,不产生任何数据拷贝。
用法:
- 可以传入新的维度信息。
- 使用
-1
作为占位符,PyTorch 会自动计算该维度的大小。
详细示例:
import torch
# 1. 基础用法
x = torch.arange(12) # 内存连续的 1D 张量
print(f"Original: {x.shape}, Is Contiguous: {x.is_contiguous()}")
# Original: torch.Size([12]), Is Contiguous: True
y = x.view(3, 4) # 变换为 2D 视图
print(f"View (3, 4):\n{y}")
print(f"Is y sharing storage with x? {y.storage().data_ptr() == x.storage().data_ptr()}") # True
# 2. 失败场景:对非连续张量使用 view
z = y.transpose(0, 1) # transpose 会导致内存不连续
print(f"Transposed shape: {z.shape}, Is Contiguous: {z.is_contiguous()}")
# Transposed shape: torch.Size([4, 3]), Is Contiguous: False
try:
w = z.view(12)
except RuntimeError as e:
print(f"\nError trying to view a non-contiguous tensor:\n{e}")
# Error: view size is not compatible with input tensor's size and stride
2. torch.reshape(input, shape)
: 更稳健的形状变换
原理:
reshape
是一个更“智能”和用户友好的版本。它会首先尝试返回一个视图(等同于 view
)。如果因为张量不连续而失败,它会自动创建一个内存连续的副本,然后再进行形状变换。
用法:
与 view
类似,但可以作用于非连续张量。
详细示例:
继续使用上面非连续的张量 z
。
# z 是非连续的
print(f"z shape: {z.shape}, Is Contiguous: {z.is_contiguous()}") # False
# reshape 成功处理了非连续张量
w = torch.reshape(z, (12,))
print(f"\nReshaped (12,): {w}")
print(f"Is w contiguous? {w.is_contiguous()}") # True
print(f"Did w create a copy? {w.storage().data_ptr() != z.storage().data_ptr()}") # True, 因为 z 不连续
```**何时选择**:优先使用 `view` 以确保效率和明确性。当你不确定张量是否连续,或者代码的健壮性比极致性能更重要时,可以使用 `reshape`。
### 3. `torch.Tensor.permute(*dims)`: 任意维度重排
**原理**:
`permute` 用于根据给定的索引序列重新排列张量的所有维度。例如,`permute(0, 2, 1)` 会将原来的第0维保持不变,第1维和第2维互换。返回的张量几乎总是非连续的,但它与原张量共享数据存储。
**用法**:
传入的维度索引数量必须与张量的维度数相等。
**详细示例**:
这是计算机视觉中一个极其常见的场景:在 `channels-last` (如 TensorFlow, NumPy) 和 `channels-first` (PyTorch 标准) 格式之间转换。
```python
# 模拟一个 channels-last 的图像 batch (N, H, W, C)
# 4张 64x64 的 3 通道图像
image_batch_last = torch.randn(4, 64, 64, 3)
# 将其转换为 PyTorch 期望的 channels-first 格式 (N, C, H, W)
# 维度 0 -> 0 (N)
# 维度 1 -> 2 (H)
# 维度 2 -> 3 (W)
# 维度 3 -> 1 (C)
image_batch_first = image_batch_last.permute(0, 3, 1, 2)
print(f"Original shape (N, H, W, C): {image_batch_last.shape}")
print(f"Permuted shape (N, C, H, W): {image_batch_first.shape}")
print(f"Is permuted tensor contiguous? {image_batch_first.is_contiguous()}") # False
4. torch.transpose(input, dim0, dim1)
: 两维互换
原理:
transpose
是 permute
的一个特例,它只交换指定的两个维度,使用起来更简洁。
用法: 指定要交换的两个维度的索引。
详细示例:
在 NLP 中,经常需要在批处理维度(N)和序列长度维度(S)之间切换,以适应不同层的要求(如 nn.LSTM
默认是 (S, N, E)
)。
# (Batch, Seq_Len, Embedding_Dim)
x = torch.randn(32, 100, 512)
# 交换 batch 和 seq_len 维度
x_transposed = x.transpose(0, 1)
print(f"Original shape (N, S, E): {x.shape}")
print(f"Transposed shape (S, N, E): {x_transposed.shape}")
5. torch.unsqueeze(input, dim)
& torch.squeeze(input, dim=None)
: 维度扩展与压缩
原理:
unsqueeze
: 在指定位置dim
插入一个大小为 1 的新维度。squeeze
: 移除所有大小为 1 的维度。如果指定dim
,则仅当该维度大小为 1 时才移除。
这两个函数对于匹配输入形状、创建批处理维度或广播(Broadcasting)操作至关重要。
详细示例:
# 1. unsqueeze: 为单个样本添加 batch 和 channel 维度
single_image = torch.randn(28, 28) # MNIST 图像
print(f"Single image shape: {single_image.shape}")
# 添加 channel 维度 (C, H, W)
img_with_channel = single_image.unsqueeze(0)
print(f"After unsqueeze(0): {img_with_channel.shape}") # torch.Size([1, 28, 28])
# 添加 batch 维度 (N, C, H, W)
img_batch = img_with_channel.unsqueeze(0)
print(f"After unsqueeze(0) again: {img_batch.shape}") # torch.Size([1, 1, 28, 28])
# 2. squeeze: 移除多余的维度
output = torch.randn(1, 10, 1)
print(f"\nOriginal output shape: {output.shape}")
# 移除所有大小为1的维度
squeezed_all = output.squeeze()
print(f"Squeezed all: {squeezed_all.shape}") # torch.Size([10])
# 只移除最后一个维度
squeezed_dim2 = output.squeeze(dim=2)
print(f"Squeezed dim 2: {squeezed_dim2.shape}") # torch.Size([1, 10])
关键技巧:.contiguous()
内存连续化
现在我们来解决前面 view
遇到的问题。当你对一个张量进行了 permute
或 transpose
后,它很可能不再是内存连续的,此时若想使用 view
,你必须先调用 .contiguous()
。
原理:
.contiguous()
方法会返回一个与原张量内容相同,但在内存中是连续存储的新张量。如果原张量已经是连续的,此操作将直接返回原张量,几乎没有开销。如果不是,它会重新开辟内存,并按正确的顺序复制数据。
黄金法则: permute
/ transpose
之后接 view
/ reshape
,中间通常需要 .contiguous()
。
x = torch.randn(4, 3, 2)
# 1. 维度重排导致不连续
y = x.permute(2, 0, 1)
print(f"Permuted shape: {y.shape}, Is Contiguous: {y.is_contiguous()}") # False
# 2. 直接 view 会失败
try:
z = y.view(2, 12)
except RuntimeError as e:
print(f"\nError: {e}")
# 3. 正确的做法:先连续化,再 view
w = y.contiguous().view(2, 12)
print(f"\nSuccessfully viewed after .contiguous(): {w.shape}")
总结
API | 功能 | 是否改变内存布局 | 是否需要连续内存 |
---|---|---|---|
view() |
高效重塑形状 | 否(仅改 stride) | 是 |
reshape() |
稳健重塑形状 | 可能(如果需要拷贝) | 否 |
permute() |
任意重排维度 | 是 | 否 |
transpose() |
交换两个维度 | 是 | 否 |
squeeze() |
移除大小为1的维度 | 否(仅改 stride) | 否 |
unsqueeze() |
添加大小为1的维度 | 否(仅改 stride) | 否 |
.contiguous() |
创建内存连续的副本 | 是(如果需要) | 否 |
No comments to display
No comments to display