关注

零基础解读MiniMind代码:MiniMind-O训练代码

MiniMind-O 逐行代码解读 · 第四章:训练代码

📁 文件:trainer/train_sft_omni.py(263行)+ trainer/trainer_utils.py(200行) 🔖 训练是把数据"喂"给模型、让模型学习的过程——本章解读训练的每一个步骤


📚 本章导读

训练代码分为两部分:

  1. train_sft_omni.py — 训练主流程,包含命令行参数、训练循环、loss 计算
  2. trainer_utils.py — 工具函数,包含模型初始化、学习率调度、checkpoint 保存

Part A: trainer_utils.py 工具函数

1️⃣ is_main_process / Logger(第16-21行)

def is_main_process():
    return not dist.is_initialized() or dist.get_rank() == 0

def Logger(content):
    if is_main_process():
        print(content)

分布式训练中:多张 GPU 各自运行一份代码,但日志只需要主 GPU(rank=0)打印,避免重复。


2️⃣ get_lr 学习率调度(第25-27行)

def get_lr(current_step, total_steps, lr):
    return lr * (0.1 + 0.45 * (1 + math.cos(math.pi * current_step / total_steps)))

余弦退火调度

lr
1.0 * base ┤████████╗
           │         ╚════╗
           │              ╚════╗
           │                   ╚════╗
0.1 * base ┤                        ╚════════
           └──────────────────────────────────── step
           0                              total_steps
  • 起始 lr = 1.0 × base_lr
  • 结束 lr = 0.1 × base_lr
  • 平滑过渡,避免训练后期震荡

公式推导

  • cos(0) = 1 → 0.1 + 0.45 * (1 + 1) = 0.1 + 0.9 = 1.0 → 起始倍率
  • cos(π) = -1 → 0.1 + 0.45 * (1 - 1) = 0.1 → 结束倍率

3️⃣ init_distributed_mode 分布式初始化(第30-37行)

def init_distributed_mode():
    if int(os.environ.get("RANK", -1)) == -1:
        return 0  # 没有 RANK 环境变量 → 单卡模式,local_rank=0
    
    dist.init_process_group(backend="nccl")  # 初始化 NCCL 通信后端
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)  # 绑定当前进程使用的 GPU
    return local_rank

torchrun 设置的环境变量

变量说明
RANK全局进程编号(0, 1, 2, ...)
LOCAL_RANK当前节点内的 GPU 编号
WORLD_SIZE总进程数

4️⃣ setup_seed 随机种子(第40-47行)

def setup_seed(seed: int):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True   # cuDNN 确定性模式
    torch.backends.cudnn.benchmark = False       # 关闭自动调优

为什么关闭 benchmark? benchmark=True 时 cuDNN 会尝试不同的卷积算法选最快的,但不同算法可能产生不同结果。关闭后保证每次运行结果一致(可复现性),但可能慢一点。


5️⃣ log_model_params 模型参数统计(第50-62行)

def log_model_params(model, ignore_patterns=['audio_encoder', 'vision_encoder']):
    def should_count(n): return not any(p in n for p in ignore_patterns)
    total = sum(p.numel() for n, p in model.named_parameters() if should_count(n)) / 1e6  # 总参数量(M)
    
    cfg = model.config
    n_routed = getattr(cfg, 'n_routed_experts', getattr(cfg, 'num_experts', 0))  # 路由专家数
    n_active = getattr(cfg, 'num_experts_per_tok', 0)  # 每次激活的专家数
    n_shared = getattr(cfg, 'n_shared_experts', 0)     # 共享专家数
    
    expert = sum(p.numel() for n, p in model.named_parameters()
                 if 'mlp.experts.0.' in n and should_count(n)) / 1e6  # 单个专家参数量
    shared_expert = sum(p.numel() for n, p in model.named_parameters()
                        if 'mlp.shared_experts.0.' in n and should_count(n)) / 1e6
    
    base = total - (expert * n_routed) - (shared_expert * n_shared)  # 非专家部分
    active = base + (expert * n_active) + (shared_expert * n_shared)  # 每次实际激活的参数量
    
    if active < total:
        Logger(f'Model Params: {total:.2f}M-A{active:.2f}M')  # 如 113.13M-A56.8M
    else:
        Logger(f'Model Params: {total:.2f}M')

"-A" 的含义:MoE 模型虽然总参数量大,但每个 token 只激活一部分专家。113.13M-A56.8M 表示总参数113M,每次推理只用56.8M。


6️⃣ init_omni_model 模型初始化(第65-104行)⭐

def init_omni_model(omni_config, from_weight='full_sft', tokenizer_path='../model',
                     audio_encoder_path='../model/SenseVoiceSmall',
                     vision_model_path='../model/siglip2-base-p32-256-ve',
                     save_dir='../out', device='cuda', freeze_backbone='none', from_resume=0):
Step 1: 加载 tokenizer 和模型骨架
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
    model = MiniMindOmni(omni_config, audio_encoder_path=audio_encoder_path,
                         vision_model_path=vision_model_path)
Step 2: 加载预训练权重
    if from_weight != 'none':
        moe_suffix = '_moe' if omni_config.use_moe else ''
        weight_path = f'{save_dir}/{from_weight}_{omni_config.hidden_size}{moe_suffix}.pth'
        # 例如: out/llm_768.pth 或 out/sft_omni_768.pth
        
        if os.path.exists(weight_path):
            weights = torch.load(weight_path, map_location=device)
            
            # 检查 shape 不匹配的权重
            param_shapes = {k: v.shape for k, v in model.named_parameters()}
            incompatible = {k for k, v in weights.items()
                           if k in param_shapes and v.shape != param_shapes[k]}
            if incompatible:
                Logger(f'跳过shape不匹配的权重: {incompatible}')
                weights = {k: v for k, v in weights.items() if k not in incompatible}
            
            model.load_state_dict(weights, strict=False)
            Logger(f'已加载权重: {weight_path}')

strict=False:允许权重文件中有模型不存在的 key,也允许模型有权重文件中没有的 key。这对增量训练非常重要——新的 Talker 层在旧权重中不存在,不会报错。

Step 3: Talker 层初始化(关键!)
            if from_resume == 0 and omni_config.talker_hidden_size == omni_config.hidden_size:
                n_talker = omni_config.num_talker_hidden_layers  # 4
                n_thinker = len(model.thinker.layers)              # 8
                has_talker = any(k.startswith('talker.layers.') for k in weights)
                
                if not has_talker and n_talker > 0:
                    # 从 Thinker 的最后 4 层复制到 Talker 的 4 层
                    for i in range(n_talker):
                        src = n_thinker - n_talker + i  # src = 4, 5, 6, 7
                        model.talker.layers[i].load_state_dict(
                            model.thinker.layers[src].state_dict()
                        )
                    Logger(f'Talker层初始化: 复制thinker layers[{n_thinker-n_talker}:{n_thinker}] → talker layers[0:{n_talker}]')

为什么这样做?

  • Talker 和 Thinker 使用相同的 MiniMindBlock 结构
  • Talker 是新加的层,没有预训练权重
  • 从 Thinker 的后几层复制权重,比随机初始化好得多
  • Thinker 后几层已经学会了"理解语义",Talker 需要的就是理解语义来生成语音
Step 4: 冻结策略
    if freeze_backbone == 'all':
        # 冻结整个 Thinker 主干
        for param in model.model.parameters():
            param.requires_grad = False
    elif freeze_backbone == 'last1':
        # 冻结除最后一层之外的所有层
        for param in model.model.parameters():
            param.requires_grad = False
        if hasattr(model.model, 'layers') and len(model.model.layers) > 0:
            for param in model.model.layers[-1].parameters():
                param.requires_grad = True
    
    return model.to(device), tokenizer
freeze_backbone效果使用场景
none全部可训练Step 1/3 全量微调
all只训练 Talker + 投影层极小显存训练
last1只训练最后1层 + Talker + 投影层折中方案

7️⃣ omni_checkpoint Checkpoint 管理(第107-162行)

def omni_checkpoint(omni_config, weight='pretrain_omni', model=None, optimizer=None,
                    epoch=0, step=0, wandb=None, save_dir='../checkpoints', **kwargs):
保存模式(model 不为 None)
    if model is not None:
        # 1. 取出原始模型(去掉 DDP 和 compile 包装)
        raw_model = model.module if isinstance(model, DistributedDataParallel) else model
        raw_model = getattr(raw_model, '_orig_mod', raw_model)  # torch.compile 包装
        
        # 2. 清理不需要保存的权重
        clean_state_dict = {
            k: v for k, v in raw_model.state_dict().items()
            if not k.startswith('audio_encoder.') and not k.startswith('vision_encoder.')
        }
        # 不保存冻结的编码器!它们可以从原始路径重新加载
        # 这样 checkpoint 从 1.1GB 降到 226MB
        
        state_dict = {k: v.half().cpu() for k, v in clean_state_dict.items()}  # FP16 + 移到CPU
        
        # 3. 原子写入(防崩溃损坏)
        ckp_tmp = ckp_path + '.tmp'
        torch.save(state_dict, ckp_tmp)
        os.replace(ckp_tmp, ckp_path)  # 原子操作:先写临时文件,再重命名

os.replace 是原子的:如果写入过程中断电,临时文件可能损坏,但原始文件完好。重命名是瞬间的原子操作。

        # 4. 保存 resume checkpoint(含 optimizer 状态)
        resume_data = {
            'model': state_dict,
            'optimizer': optimizer.state_dict(),
            'epoch': epoch,
            'step': step,
            'world_size': dist.get_world_size() if dist.is_initialized() else 1,
            'wandb_id': wandb_id,
        }
        # 也用原子写入
加载模式(model 为 None)
    else:
        if os.path.exists(resume_path):
            ckp_data = torch.load(resume_path, map_location='cpu')
            saved_ws = ckp_data.get('world_size', 1)
            current_ws = dist.get_world_size() if dist.is_initialized() else 1
            if saved_ws != current_ws:
                # GPU数量变化时,调整 step 数
                ckp_data['step'] = ckp_data['step'] * saved_ws // current_ws
                Logger(f'GPU数量变化({saved_ws}→{current_ws}),step已自动转换为{ckp_data["step"]}')
            return ckp_data
        return None

GPU 数量变化的处理:如果之前用4卡训练保存了 step=1000,现在改用1卡继续,step 需要调整。因为4卡时每个 step 处理 4×batch_size 的数据,1卡只处理 1×batch_size,所以 step 数要乘4才能处理等量的数据。


8️⃣ SkipBatchSampler(第176-199行)

class SkipBatchSampler(Sampler):
    def __init__(self, sampler, batch_size, skip_batches=0):
        self.sampler = sampler
        self.batch_size = batch_size
        self.skip_batches = skip_batches  # 续训时跳过的 batch 数
    
    def __iter__(self):
        batch = []
        skipped = 0
        for idx in self.sampler:
            batch.append(idx)
            if len(batch) == self.batch_size:
                if skipped < self.skip_batches:
                    skipped += 1  # 跳过已训练的 batch
                    batch = []
                    continue
                yield batch
                batch = []
        if len(batch) > 0 and skipped >= self.skip_batches:
            yield batch  # 最后不完整的 batch

用途:续训时,从上次中断的 step 继续,不重复训练已经处理过的数据。


Part B: train_sft_omni.py 训练主脚本

1️⃣ omni_collate_fn 自定义批处理函数(第24-48行)

def omni_collate_fn(batch):
    input_ids, labels, audio_labels, audio_inputs, audio_lens, pixel_values, spk_emb = zip(*batch)

为什么需要自定义 collate? 因为不同样本的 audio_inputs 长度不同,pixel_values 可能是字典或 tensor,PyTorch 默认的 collate 无法处理。

    # 文本标签和音频标签:直接 stack
    input_ids = torch.stack(input_ids)    # (B, 9, T)
    labels = torch.stack(labels)          # (B, T)
    audio_labels = torch.stack(audio_labels)  # (B, 8, T)
    audio_lens = torch.tensor(audio_lens, dtype=torch.long)  # (B,)
    
    # 音频输入:不同长度,需要 pad 到最长
    valid_audios = [a for a in audio_inputs if a is not None]
    if valid_audios:
        max_t = max(a.size(1) for a in valid_audios)  # 找最长的时间帧数
        padded = [a if a.size(1) == max_t 
                  else torch.nn.functional.pad(a, (0, 0, 0, max_t - a.size(1)))
                  for a in valid_audios]  # 短的填充零
        audio_inputs = torch.cat(padded, dim=0)  # (valid_B, max_T, 560)
    else:
        audio_inputs = None
    
    # 图像输入:字典格式需要按键合并
    valid_images = [p for p in pixel_values if p is not None]
    if valid_images:
        if hasattr(valid_images[0], 'keys'):  # 字典格式
            keys = set.intersection(*[set(d.keys()) for d in valid_images])
            pixel_values = {k: torch.cat([d[k] for d in valid_images], dim=0) for k in keys}
        else:  # tensor 格式
            pixel_values = torch.cat(valid_images, dim=0)
    else:
        pixel_values = None
    
    spk_emb = torch.stack(spk_emb)  # (B, 192)
    return input_ids, labels, audio_labels, audio_inputs, audio_lens, pixel_values, spk_emb

2️⃣ train_epoch 训练一个 epoch(第51-136行)⭐ 核心训练循环

def train_epoch(epoch, loader, iters, start_step=0, wandb=None):
    start_time = time.time()
    last_step = start_step
训练循环主体
    for step, (input_ids, labels, audio_labels, audio_inputs, audio_lens, pixel_values, spk_emb) in enumerate(loader, start=start_step + 1):

Step A: 数据移到 GPU

        input_ids = input_ids.to(args.device)
        labels = labels.to(args.device)
        audio_labels = audio_labels.to(args.device)
        audio_lens = audio_lens.to(args.device)
        if audio_inputs is not None:
            audio_inputs = audio_inputs.to(args.device)
        if pixel_values is not None:
            if hasattr(pixel_values, 'keys'):  # 字典
                pixel_values = {k: v.to(args.device) for k, v in pixel_values.items()}
            else:
                pixel_values = pixel_values.to(args.device)
        spk_emb = spk_emb.to(args.device)

Step B: 更新学习率

        lr = get_lr(epoch * iters + step, args.epochs * iters, args.learning_rate)
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

Step C: 前向传播(在混合精度上下文中)

        with autocast_ctx:
            res = model(input_ids, audio_inputs=audio_inputs, audio_lens=audio_lens,
                       pixel_values=pixel_values, spk_emb=spk_emb)

Step D: 计算 Loss

            loss_fct = nn.CrossEntropyLoss(reduction='none')  # 不自动求均值
            
            # === 文本损失 ===
            text_loss_raw = loss_fct(res.logits.view(-1, res.logits.size(-1)), labels.view(-1))
            text_mask = (labels.view(-1) != -100).float()  # 只计算有效位置
            text_loss = (text_loss_raw * text_mask).sum() / (text_mask.sum() + 1e-9)
            # === 音频损失(8层独立计算)===
            audio_loss = res.audio_logits[0].sum() * 0  # 初始化为 0(保持计算图)
            for i, al in enumerate(res.audio_logits):
                al_flat = al.view(-1, al.size(-1))                           # (B*T, 2112)
                target_flat = audio_labels[:, i, :].reshape(-1)              # (B*T,)
                layer_loss = loss_fct(al_flat, target_flat)                  # 逐位置损失
                valid_mask = (target_flat != -100).float()                    # 有效位置
                stop_mask = (target_flat == 2050).float()                     # stop token 位置
                weighted_loss = layer_loss * valid_mask * (1 + stop_mask * 9) # stop 权重 10x!
                msum = valid_mask.sum()
                if msum > 0:
                    audio_loss = audio_loss + weighted_loss.sum() / (msum + 1e-9)
            audio_loss = audio_loss / 8  # 8 层取平均

stop token 10 倍权重

普通 token 权重 = 1 + 0 * 9 = 1
stop token 权重 = 1 + 1 * 9 = 10
  • stop token 是音频生成的终止信号
  • 如果模型错过 stop → 无限生成噪音
  • 所以需要重点训练模型识别 stop
            # === 总损失 ===
            loss = (text_loss + audio_loss + res.aux_loss) / args.accumulation_steps
  • 除以 accumulation_steps 是为了梯度累积
  • 等效 batch_size = batch_size × accumulation_steps

Step E: 反向传播 + 梯度累积

        scaler.scale(loss).backward()  # 放大梯度(float16 用)
        if step % args.accumulation_steps == 0:
            scaler.unscale_(optimizer)  # 缩放回来
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)  # 梯度裁剪
            scaler.step(optimizer)      # 更新参数
            scaler.update()            # 更新 scaler
            optimizer.zero_grad(set_to_none=True)  # 清零梯度(set_to_none=True 更快)

set_to_none=True:不把梯度设为零张量,而是设为 None。这样 PyTorch 可以跳过一些不必要的计算。

Step F: 日志记录

        if step % args.log_interval == 0 or step == iters:
            spend_time = time.time() - start_time
            current_loss = loss.item() * args.accumulation_steps  # 恢复真实 loss
            text_loss_val = text_loss.item()
            audio_loss_val = audio_loss.item()
            current_lr = optimizer.param_groups[-1]['lr']
            eta_min = spend_time / max(step - start_step, 1) * (iters - step) // 60  # 预估剩余时间
            Logger(f'Epoch:[{epoch+1}/{args.epochs}]({step}/{iters}), '
                   f'loss: {current_loss:.4f}, text: {text_loss_val:.4f}, '
                   f'audio: {audio_loss_val:.4f}, lr: {current_lr:.08f}, epoch_time: {eta_min:.1f}min')

Step G: 保存 Checkpoint

        if (step % args.save_interval == 0 or step == iters) and is_main_process():
            model.eval()
            # ... 保存权重 ...
            model.train()

Step H: 释放显存

        del input_ids, labels, audio_labels, audio_inputs, audio_lens, pixel_values, spk_emb, res, loss
  • 手动删除不再需要的张量
  • 帮助 Python GC 和 PyTorch 缓存分配器回收显存

3️⃣ 命令行参数(第139-167行)

完整参数表:

参数默认值说明
--save_dir../out模型保存目录
--save_weightsft_omni保存权重前缀
--epochs15训练轮数
--batch_size32batch 大小
--learning_rate5e-4学习率
--devicecuda:0训练设备
--dtypebfloat16混合精度类型
--num_workers4DataLoader 线程数
--accumulation_steps1梯度累积步数
--grad_clip1.0梯度裁剪阈值
--log_interval100日志间隔
--save_interval1000保存间隔
--hidden_size768隐藏维度
--num_hidden_layers8层数
--max_seq_len512最大序列长度
--use_moe0是否 MoE
--data_path../dataset/sft_t2a_mini.parquet数据路径
--from_weightllm初始权重名
--from_resume0是否续训
--freeze_backbonenone冻结策略
--modeall训练模式
--use_compile0是否 torch.compile

--mode 的三个选项

模式冻结只训练用途
all全部Step 1/3 全量微调
audio_proj全部audio_projStep 2 音频对齐
vision_proj全部vision_proj视觉对齐(未用)

4️⃣ 主流程(第169-262行)

    # ========== 1. 初始化 ==========
    local_rank = init_distributed_mode()
    setup_seed(42 + (...))
    
    # ========== 2. 配置模型 ==========
    omni_config = OmniConfig(hidden_size=768, num_hidden_layers=8, use_moe=False)
    ckp_data = omni_checkpoint(...) if args.from_resume else None  # 续训时加载
    
    # ========== 3. 混合精度 ==========
    dtype = torch.bfloat16 if args.dtype == "bfloat16" else torch.float16
    autocast_ctx = nullcontext() if "cpu" in args.device else torch.cuda.amp.autocast(dtype=dtype)
    
    # ========== 4. Wandb ==========
    if args.use_wandb and is_main_process():
        import swanlab as wandb  # 用 swanlab 替代 wandb
        wandb.init(...)
    
    # ========== 5. 模型 ==========
    model, tokenizer = init_omni_model(omni_config, from_weight=args.from_weight, ...)
    
    if args.use_compile == 1:
        model = torch.compile(model)  # 编译加速
    
    # 手动移动冻结的编码器到 GPU
    if model.audio_encoder is not None: model.audio_encoder.to(args.device)
    if model.vision_encoder is not None: model.vision_encoder.to(args.device)
    
    # mode 特殊处理
    if args.mode == 'audio_proj':
        for p in model.parameters(): p.requires_grad = False
        for p in model.audio_proj.parameters(): p.requires_grad = True
    
    # ========== 6. 数据集 ==========
    train_ds = OmniDataset(args.data_path, tokenizer, ...)
    
    # ========== 7. 优化器 ==========
    scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype == 'float16'))
    optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate)
    
    # ========== 8. 续训恢复 ==========
    if ckp_data:
        model.load_state_dict(ckp_data['model'], strict=False)
        optimizer.load_state_dict(ckp_data['optimizer'])
        scaler.load_state_dict(ckp_data['scaler'])
        start_epoch, start_step = ckp_data['epoch'], ckp_data.get('step', 0)
    
    # ========== 9. DDP 包装 ==========
    if dist.is_initialized():
        model = DistributedDataParallel(model, device_ids=[local_rank])
    
    # ========== 10. 训练循环 ==========
    for epoch in range(start_epoch, args.epochs):
        setup_seed(42 + epoch)
        batch_sampler = SkipBatchSampler(...)
        loader = DataLoader(train_ds, batch_sampler=batch_sampler, collate_fn=omni_collate_fn,
                           num_workers=args.num_workers, pin_memory=True)
        train_epoch(epoch, loader, len(loader), start_step, wandb)

📋 本章关键概念速查

三阶段训练流程

Step 1: sft_t2a (文本→音频对齐)
  数据: sft_t2a_mini.parquet
  模式: all (全量训练)
  权重: llm_768.pth → sft_zero_768.pth
  目标: 学会"说话"(文本→音频码的映射)

Step 2: audio_proj (音频理解对齐)
  数据: sft_a2a_mini.parquet
  模式: audio_proj (只训练投影层)
  权重: sft_zero_768.pth → sft_zero_768.pth (覆盖)
  目标: 学会"听"(音频特征→LLM空间的对齐)

Step 3: 全参数微调
  数据: sft_a2a_mini.parquet
  模式: all (全量训练)
  学习率: 2e-5 (比Step1小25倍)
  目标: 联合优化"听"和"说"

损失函数组成

total_loss = text_loss + audio_loss + aux_loss
             ↓           ↓            ↓
          文本预测    8层音频预测    MoE路由均衡
          (CrossEntropy)  (stop token 10x权重)  (负载均衡)

✅ 第四章完成 | 下一篇:推理 + WebUI 代码

转载自 CSDN-专业IT技术社区

原文链接:https://blog.csdn.net/liukanghao/article/details/162153106

评论

赞0

评论列表

微信小程序
QQ小程序

关于作者

点赞数:0
关注数:0
粉丝:0
文章:0
关注标签:0
加入于:--