Skip to main content

一个简单的预训练代码(了解流程即可)

好的,这里提供一个使用 PyTorch 构建的极简 Decoder-only Transformer(类似 GPT-like 结构)进行预训练(Causal Language Modeling - CLM)的代码示例。

这个例子会非常简化,以便于理解核心概念:

  1. 数据: 使用一个极小的、重复的字符级文本数据集。
  2. 分词器 (Tokenizer): 一个简单的字符级分词器。
  3. 模型: 一个简化的 Transformer Decoder (这里为了方便,我们会使用 nn.TransformerEncoder 并传入一个 causal mask,这在实践中是构建 decoder-only 模型的常见做法,因为它只包含自注意力和前馈网络)。
  4. 预训练任务: Causal Language Modeling (预测下一个 token)。
  5. 训练循环: 基本的 PyTorch 训练流程。
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import math
import random

# --- 1. 配置 ---
VOCAB_SIZE = 0 # 会根据数据动态确定
D_MODEL = 64      # 模型的维度
N_HEAD = 4        # 多头注意力的头数
NUM_DECODER_LAYERS = 2 # Decoder层数 (这里用Encoder层模拟)
DIM_FEEDFORWARD = 128 # 前馈网络的隐藏层维度
SEQ_LEN = 20      # 输入序列的最大长度
BATCH_SIZE = 8
EPOCHS = 50
LEARNING_RATE = 1e-3
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --- 2. 准备数据和分词器 ---
# 极简的字符级数据
raw_data = [
    "hello world, this is a test.",
    "pytorch is fun to learn.",
    "language models are powerful.",
    "pretraining helps a lot.",
    "simple example for understanding."
] * 20 # 重复数据以增加样本量

# 简易字符级分词器
class SimpleCharTokenizer:
    def __init__(self):
        self.char_to_ix = {}
        self.ix_to_char = {}
        self.vocab_size = 0
        self.pad_token_id = 0
        self.bos_token_id = 1 # Start of sentence (Beginning of Sequence)
        self.eos_token_id = 2 # End of sentence (End of Sequence)
        self.unk_token_id = 3 # Unknown token

    def fit(self, texts):
        all_chars = set()
        for text in texts:
            all_chars.update(list(text))

        # 预留特殊 token
        self.char_to_ix = {"<pad>": self.pad_token_id, "<bos>": self.bos_token_id, "<eos>": self.eos_token_id, "<unk>": self.unk_token_id}
        self.ix_to_char = {v: k for k, v in self.char_to_ix.items()}
        
        current_idx = len(self.char_to_ix)
        for char in sorted(list(all_chars)):
            if char not in self.char_to_ix:
                self.char_to_ix[char] = current_idx
                self.ix_to_char[current_idx] = char
                current_idx += 1
        self.vocab_size = len(self.char_to_ix)
        global VOCAB_SIZE # 更新全局词汇表大小
        VOCAB_SIZE = self.vocab_size


    def encode(self, text, max_len):
        tokens = [self.bos_token_id] + [self.char_to_ix.get(char, self.unk_token_id) for char in text] + [self.eos_token_id]
        padding_needed = max_len - len(tokens)
        if padding_needed > 0:
            tokens += [self.pad_token_id] * padding_needed
        elif padding_needed < 0:
            tokens = tokens[:max_len-1] + [self.eos_token_id] # 确保EOS在末尾
        return tokens

    def decode(self, token_ids):
        chars = []
        for token_id in token_ids:
            if token_id == self.eos_token_id:
                break
            if token_id != self.pad_token_id and token_id != self.bos_token_id:
                chars.append(self.ix_to_char.get(token_id, "<unk>"))
        return "".join(chars)

tokenizer = SimpleCharTokenizer()
tokenizer.fit(raw_data)
print(f"Vocabulary Size: {VOCAB_SIZE}")
print(f"Char to Ix: {tokenizer.char_to_ix}")


# 数据集类
class CLMDataset(Dataset):
    def __init__(self, texts, tokenizer, seq_len):
        self.tokenizer = tokenizer
        self.seq_len = seq_len
        self.samples = []
        for text in texts:
            encoded = self.tokenizer.encode(text, self.seq_len + 1) # +1 for label shifting
            # 对于 CLM, input_ids 是 labels 的前缀,labels 是 input_ids 的后缀
            # e.g., text: "abc", encoded: [bos, a, b, c, eos]
            # input_ids: [bos, a, b, c] (len=4)
            # labels:    [a, b, c, eos] (len=4)
            # 我们需要确保 input_ids 和 labels 长度相同用于训练
            # 因此,若原始编码后长度为L, input_ids 为前 L-1 个, labels 为后 L-1 个
            if len(encoded) < 2: continue # 至少需要两个token来创建输入和标签

            input_ids = torch.tensor(encoded[:-1]) # 所有token除了最后一个
            labels = torch.tensor(encoded[1:])    # 所有token除了第一个
            
            # 确保长度一致,这在 encode 中已经部分处理,这里再确认
            if len(input_ids) == self.seq_len and len(labels) == self.seq_len:
                 self.samples.append({"input_ids": input_ids, "labels": labels})

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return self.samples[idx]

train_dataset = CLMDataset(raw_data, tokenizer, SEQ_LEN)
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)


# --- 3. 定义模型 ---
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1) # [max_len, 1, d_model] -> [max_len, 1, d_model]
        self.register_buffer('pe', pe) # 不作为模型参数

    def forward(self, x):
        # x shape: [seq_len, batch_size, d_model]
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)

class SimpleDecoderOnlyModel(nn.Module):
    def __init__(self, vocab_size, d_model, n_head, num_encoder_layers, dim_feedforward, max_seq_len):
        super().__init__()
        self.d_model = d_model
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoder = PositionalEncoding(d_model, max_len=max_seq_len)
        
        # 使用 TransformerEncoderLayer 和 TransformerEncoder 来构建 Decoder-only 结构
        # 需要提供一个 causal mask (三角注意力掩码)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model, 
            nhead=n_head, 
            dim_feedforward=dim_feedforward,
            batch_first=True, # 重要:输入输出为 (batch, seq, feature)
            dropout=0.1
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_encoder_layers)
        
        self.lm_head = nn.Linear(d_model, vocab_size) # 输出到词汇表大小

    def _generate_square_subsequent_mask(self, sz):
        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        return mask.to(DEVICE)

    def forward(self, src_input_ids):
        # src_input_ids shape: [batch_size, seq_len]
        src_emb = self.embedding(src_input_ids) * math.sqrt(self.d_model) # [batch_size, seq_len, d_model]
        
        # PositionalEncoding 期望 [seq_len, batch_size, d_model] 如果 batch_first=False
        # 但如果 nn.TransformerEncoderLayer batch_first=True, 它内部会处理
        # 为简单起见,我们调整 PositionalEncoding 或这里的输入
        # 这里我们的 TransformerEncoderLayer 是 batch_first=True, 所以输入应该是 [batch, seq, feature]
        # PositionalEncoding 如果也期望 batch_first,需要调整或重新实现
        # 为了简单,这里我们手动调整 PositionalEncoding 的输出再加回去
        
        # 手动实现 PositionalEncoding 的 batch_first 兼容
        # 创建位置ID
        batch_size, seq_len = src_input_ids.size()
        position_ids = torch.arange(0, seq_len, dtype=torch.long, device=DEVICE).unsqueeze(0).repeat(batch_size, 1)
        # 假设我们有一个可学习的位置嵌入或固定的正弦位置编码应用在 batch_first 数据上
        # 这里用简化的 PositionalEncoding (它期望 seq_len first)
        # src_emb = src_emb.transpose(0,1) # B,S,D -> S,B,D
        # src_emb = self.pos_encoder(src_emb)
        # src_emb = src_emb.transpose(0,1) # S,B,D -> B,S,D
        # 或者,让PositionalEncoding自己处理batch_first,或者直接在这里创建并添加:
        pe_temp = torch.zeros(seq_len, self.d_model).to(DEVICE)
        position = torch.arange(0, seq_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, self.d_model, 2).float() * (-math.log(10000.0) / self.d_model))
        pe_temp[:, 0::2] = torch.sin(position * div_term.to(DEVICE))
        pe_temp[:, 1::2] = torch.cos(position * div_term.to(DEVICE))
        src_emb = src_emb + pe_temp.unsqueeze(0) # [B, S, D] + [1, S, D] -> [B, S, D] via broadcasting
        
        # 创建 causal mask
        src_mask = self._generate_square_subsequent_mask(src_input_ids.size(1))
        
        # 创建 padding mask (如果需要)
        # True 表示该位置是 padding,应该被忽略
        src_key_padding_mask = (src_input_ids == tokenizer.pad_token_id) # [batch_size, seq_len]

        output = self.transformer_encoder(src_emb, mask=src_mask, src_key_padding_mask=src_key_padding_mask)
        # output shape: [batch_size, seq_len, d_model]
        
        logits = self.lm_head(output) # [batch_size, seq_len, vocab_size]
        return logits

# --- 4. 训练 ---
model = SimpleDecoderOnlyModel(VOCAB_SIZE, D_MODEL, N_HEAD, NUM_DECODER_LAYERS, DIM_FEEDFORWARD, SEQ_LEN).to(DEVICE)
criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id) # 忽略padding token的损失
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

print(f"Training on {DEVICE}")
model.train()
for epoch in range(EPOCHS):
    total_loss = 0
    for batch_idx, batch in enumerate(train_dataloader):
        input_ids = batch["input_ids"].to(DEVICE) # [batch_size, seq_len]
        labels = batch["labels"].to(DEVICE)       # [batch_size, seq_len]

        optimizer.zero_grad()
        
        # 前向传播
        logits = model(input_ids) # [batch_size, seq_len, vocab_size]
        
        # 计算损失
        # CrossEntropyLoss 期望 logits 为 [N, C] 或 [B, C, d1, d2,...], targets 为 [N] 或 [B, d1, d2,...]
        # 这里 logits: [batch_size * seq_len, vocab_size]
        # labels: [batch_size * seq_len]
        loss = criterion(logits.reshape(-1, VOCAB_SIZE), labels.reshape(-1))
        
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()

        if batch_idx % 10 == 0:
            print(f"Epoch {epoch+1}/{EPOCHS}, Batch {batch_idx}/{len(train_dataloader)}, Loss: {loss.item():.4f}")

    avg_loss = total_loss / len(train_dataloader)
    print(f"Epoch {epoch+1} finished. Average Loss: {avg_loss:.4f}")

    # --- 5. 简单生成测试 (greedy search) ---
    if (epoch + 1) % 10 == 0 :
        model.eval()
        with torch.no_grad():
            prompt = "hello "
            print(f"\nGenerating from prompt: '{prompt}'")
            generated_ids = tokenizer.encode(prompt, SEQ_LEN)[:-1] # 去掉EOS和padding,只保留BOS和prompt
            
            # 找到BOS的位置并截取其后的内容作为实际的输入序列
            try:
                bos_idx = generated_ids.index(tokenizer.bos_token_id)
                current_ids_list = generated_ids[bos_idx:] 
            except ValueError: # 如果没有BOS(例如encode逻辑变化),则直接使用
                 current_ids_list = generated_ids[:next((i for i, x in enumerate(generated_ids) if x == tokenizer.pad_token_id), len(generated_ids))]


            for _ in range(SEQ_LEN * 2): # 最多生成 SEQ_LEN*2 个 token
                input_tensor = torch.tensor([current_ids_list[-SEQ_LEN:]]).to(DEVICE) # 取最近的 SEQ_LEN 个 token
                
                output_logits = model(input_tensor) # [1, current_len, vocab_size]
                
                # 只关心最后一个时间步的预测
                next_token_logits = output_logits[:, -1, :] # [1, vocab_size]
                predicted_token_id = torch.argmax(next_token_logits, dim=-1).item()
                
                current_ids_list.append(predicted_token_id)
                
                if predicted_token_id == tokenizer.eos_token_id:
                    break
            
            generated_text = tokenizer.decode(current_ids_list)
            print(f"Generated: '{generated_text}'")
        model.train()
print("\nTraining complete.")

代码解释:

  1. 配置 (Configuration): 设置模型和训练的超参数。
  2. 数据和分词器 (Data and Tokenizer):
    • SimpleCharTokenizer: 一个非常基础的字符级分词器,它会构建字符到索引的映射,并处理 <pad>, <bos> (Beginning Of Sequence), <eos> (End Of Sequence) 特殊标记。
    • CLMDataset: PyTorch Dataset 类。关键在于 __getitem__:对于因果语言模型,输入是目标序列的前缀,标签是目标序列的后缀。例如,如果序列是 [<bos>, t1, t2, t3, <eos>],那么模型的输入是 [<bos>, t1, t2, t3],对应的标签是 [t1, t2, t3, <eos>]
  3. 模型 (Model):
    • PositionalEncoding: 标准的Transformer正弦位置编码(在这个简化版中,我在SimpleDecoderOnlyModelforward里直接实现了batch-first兼容的添加方式,更规范的做法是让PositionalEncoding类本身支持batch_first)。
    • SimpleDecoderOnlyModel:
      • nn.Embedding: 将输入的 token ID 转换为向量。
      • nn.TransformerEncoderLayernn.TransformerEncoder: PyTorch 内置的模块。我们使用它来构建 Decoder-only 结构。关键在于 forward 方法中传入 src_mask,这个 src_mask 必须是一个上三角矩阵(causal mask 或 subsequent mask),以防止模型在预测当前 token 时看到未来的 token。
      • _generate_square_subsequent_mask: 生成这种 causal mask 的辅助函数。
      • src_key_padding_mask: 告诉模型哪些是 padding token,在注意力计算中应该忽略它们。
      • lm_head: 一个线性层,将 Transformer 的输出映射回词汇表大小,得到每个 token 的 logits。
  4. 训练 (Training):
    • 使用 nn.CrossEntropyLoss 作为损失函数。ignore_index=tokenizer.pad_token_id 确保 padding token 不参与损失计算。
    • 标准的 PyTorch 训练循环:获取批次数据,前向传播,计算损失,反向传播,更新参数。
  5. 简单生成测试 (Greedy Search):
    • 在每个 epoch 结束后,进行一个简单的文本生成测试。
    • 从一个 prompt 开始,模型自回归地一个接一个地预测下一个 token,直到遇到 <eos> token 或达到最大长度。
    • 这里使用的是贪心策略(每次都选择概率最高的 token)。

重要注意事项和改进点:

  • 数据量和质量: 这个例子用的数据非常小且简单,所以模型学到的东西会很有限。真实预训练需要海量高质量文本数据。
  • 分词器: 字符级分词器很简单,但对于真实场景,通常使用更高级的分词器,如 BPE (Byte Pair Encoding), WordPiece, 或 SentencePiece。
  • 模型规模: 这里的模型参数非常少。真实的大语言模型参数量巨大。
  • 位置编码: PositionalEncoding 类在这里没有直接用于 batch_first=True 的情况,我在 forward 中重新实现了逻辑。更优雅的方式是修改 PositionalEncoding 类使其原生支持 batch_first,或者使用可学习的位置嵌入 (nn.Embedding)。
  • 超参数调整: 学习率、batch size、模型维度等都需要仔细调整。
  • 正则化: Dropout 已包含,还可以考虑权重衰减 (Weight Decay)。
  • 优化器: AdamW 通常比 Adam 效果更好。
  • 学习率调度器 (Learning Rate Scheduler): 例如,warmup 后线性衰减。
  • 梯度裁剪 (Gradient Clipping): 防止梯度爆炸。
  • 更复杂的生成策略: Beam search, top-k sampling, top-p (nucleus) sampling 通常比 greedy search 生成的文本质量更高。
  • 评估指标: Perplexity 是评估语言模型常用的指标。

这个代码只是一个起点,帮助你理解 Decoder-only 模型进行 Causal Language Modeling 预训练的基本流程。要构建一个强大的预训练模型,还需要考虑上述诸多因素并进行大量工程实践。