390 字
1 分钟
微调实战技巧与数据工程
一、数据准备与清洗
1.1 数据格式设计
```mermaid
graph TB
    subgraph "对话格式"
        A["System Prompt"]
        B["User Message"]
        C["Assistant Response"]
        D["多轮对话"]
    end
    subgraph "格式要求"
        E["消息角色分明"]
        F["完整对话链"]
        G["避免截断"]
    end
    A --> B --> C --> D
    E --> F --> G
```
# Standard chat-format sample: a "messages" list holding one system turn,
# one user turn and one assistant turn, each as a role/content pair.
conversation_template = dict(
    messages=[
        dict(role="system", content="你是一个专业的法律顾问。"),
        dict(role="user", content="合同违约怎么处理?"),
        dict(
            role="assistant",
            content="合同违约的处理方式包括:\n1. 协商解决\n2. 调解\n3. 仲裁\n4. 诉讼...",
        ),
    ],
)
# Instruction-tuning sample: the task description, an optional input payload,
# and the expected output.
instruction_template = dict(
    instruction="将以下中文翻译为英文",
    input="今天天气真好",
    output="The weather is nice today.",
)
# ChatML format. Each turn is "<|im_start|>{role}\n{content}<|im_end|>" and
# turns are separated by a newline. The newline after the role name is part
# of the format; without it the role and the content run together.
# Fill the placeholders with str.format(system_prompt=..., user_message=...,
# assistant_response=...).
chatml_template = (
    "<|im_start|>system\n{system_prompt}<|im_end|>\n"
    "<|im_start|>user\n{user_message}<|im_end|>\n"
    "<|im_start|>assistant\n{assistant_response}<|im_end|>"
)
# Next section — 1.2 数据清洗规则 (data cleaning rules)
class DataCleaner:
    """Rule-based cleaning pipeline for fine-tuning samples.

    Two sample shapes are recognised:

    * chat samples: ``{"messages": [{"role": ..., "content": ...}, ...]}``
    * instruction samples: ``{"instruction": ..., "output": ...}`` with an
      optional ``"input"``.

    Args:
        min_length: Minimum total character count of a sample.
        max_length: Maximum total character count of a sample.
        min_response_length: Minimum character count of the final answer.
    """

    # Substrings that mark an unfinished answer. NOTE(review): "..." and "xxx"
    # also match legitimate text containing an ellipsis; kept as-is to
    # preserve the original filtering behaviour.
    PLACEHOLDERS = ("[TODO]", "[TBD]", "[PLACEHOLDER]", "xxx", "...")

    def __init__(self, min_length: int = 20, max_length: int = 2048,
                 min_response_length: int = 10):
        self.min_length = min_length
        self.max_length = max_length
        self.min_response_length = min_response_length

    def clean(self, dataset: list) -> list:
        """Return the samples that pass every filter, in input order.

        Filters run in a fixed order and short-circuit: format check, length
        check, quality check, then near-duplicate check against the samples
        already accepted.
        """
        cleaned = []
        for item in dataset:
            keep = (
                self.validate_format(item)
                and self.validate_length(item)
                and self.quality_filter(item)
                and not self.is_duplicate(item, cleaned)
            )
            if keep:
                cleaned.append(item)
        return cleaned

    def validate_format(self, item: dict) -> bool:
        """Check that *item* is a well-formed chat or instruction sample.

        A chat sample needs at least two messages, each a dict with a string
        ``content``; the first role must be ``system`` or ``user`` and the
        last must be ``assistant``. An instruction sample needs string
        ``instruction`` and ``output`` fields. Malformed samples return
        ``False`` instead of raising, so one bad record cannot abort a run.
        """
        if not isinstance(item, dict):
            return False

        if "messages" in item:
            messages = item["messages"]
            if not isinstance(messages, list) or len(messages) < 2:
                return False
            if not all(
                isinstance(m, dict) and isinstance(m.get("content"), str)
                for m in messages
            ):
                return False
            if messages[0].get("role") not in ("system", "user"):
                return False
            return messages[-1].get("role") == "assistant"

        if "instruction" in item:
            return all(
                isinstance(item.get(key), str)
                for key in ("instruction", "output")
            )

        return False

    def validate_length(self, item: dict) -> bool:
        """Check that the total sample text is within the configured bounds.

        Chat messages are joined with a single space; instruction fields are
        concatenated directly. Lengths are character counts, not tokens.
        """
        if "messages" in item:
            text = " ".join(m.get("content") or "" for m in item["messages"])
        else:
            text = "".join(
                item.get(key) or "" for key in ("instruction", "input", "output")
            )
        return self.min_length <= len(text) <= self.max_length

    def quality_filter(self, item: dict) -> bool:
        """Apply cheap heuristics to the final answer of *item*."""
        content = self._response_text(item)

        # Reject answers that are too short to be useful.
        if len(content) < self.min_response_length:
            return False

        # Reject answers that still contain a placeholder marker.
        if any(marker in content for marker in self.PLACEHOLDERS):
            return False

        # Heuristic: an answer longer than 50 characters is expected to
        # contain at least two Chinese full stops.
        if content.count("。") < 2 and len(content) > 50:
            return False

        return True

    def is_duplicate(self, item: dict, existing: list) -> bool:
        """Report whether *item*'s answer nearly repeats a recent one.

        Only the last 100 entries of *existing* are compared, which bounds
        the cost per sample. The comparison is the character-set overlap
        computed by :meth:`similarity`; it is not a semantic measure.
        """
        text = self._response_text(item)
        return any(
            self.similarity(text, self._response_text(previous)) > 0.95
            for previous in existing[-100:]
        )

    def similarity(self, text1: str, text2: str) -> float:
        """Return the Jaccard similarity of the two texts' character sets.

        Two empty texts are treated as identical (``1.0``) rather than
        dividing by zero.
        """
        chars1, chars2 = set(text1), set(text2)
        union = chars1 | chars2
        if not union:
            return 1.0
        return len(chars1 & chars2) / len(union)

    @staticmethod
    def _response_text(item: dict) -> str:
        """Return the final answer text of a chat or instruction sample."""
        if "messages" in item:
            return item["messages"][-1].get("content") or ""
        return item.get("output") or ""
# Next section — 1.3 数据增强 (data augmentation)
class DataAugmenter:
    """Generate paraphrased variants of training samples with an LLM.

    Args:
        llm: Object exposing ``generate(prompt: str) -> str``.
    """

    def __init__(self, llm):
        self.llm = llm

    def augment(self, item: dict, num_variants: int = 2) -> list:
        """Return ``[item, variant_1, ..., variant_n]``.

        The first element is the original sample itself, left unmodified;
        every variant is an independent copy.
        """
        variants = [item]  # keep the original
        variants.extend(self._generate_variant(item) for _ in range(num_variants))
        return variants

    def _generate_variant(self, item: dict) -> dict:
        """Dispatch to the chat or instruction augmenter by sample shape."""
        if "messages" in item:
            return self._augment_conversation(item)
        return self._augment_instruction(item)

    def _augment_conversation(self, item: dict) -> dict:
        """Return a copy of a chat sample with its prompts paraphrased.

        The system prompt (when the first message is one) and the first user
        message are rewritten; assistant turns are kept verbatim.
        """
        import copy

        # Work on a deep copy: editing the caller's dict in place would
        # overwrite the original and make every variant the same object.
        variant = copy.deepcopy(item)
        messages = variant["messages"]

        # 1. Rephrase the system prompt.
        if messages and messages[0]["role"] == "system":
            original_system = messages[0]["content"]
            messages[0]["content"] = self.llm.generate(
                f"将以下系统提示改写,保持语义但换一种表达:\n{original_system}"
            )

        # 2. Expand the first user question, wherever it sits, so that
        #    conversations without a system prompt are augmented too.
        first_user = next((m for m in messages if m["role"] == "user"), None)
        if first_user is not None:
            original_user = first_user["content"]
            first_user["content"] = self.llm.generate(
                f"将以下问题改写得更详细具体:\n{original_user}"
            )

        return variant

    def _augment_instruction(self, item: dict) -> dict:
        """Return a copy of an instruction sample with a reworded instruction.

        Only ``instruction`` is rewritten; ``input`` and ``output`` are kept
        so the sample stays a valid input/answer pair.
        """
        import copy

        variant = copy.deepcopy(item)
        instruction = variant["instruction"]
        variant["instruction"] = self.llm.generate(
            f"将以下指令改写,保持语义但换一种表达:\n{instruction}"
        )
        return variant
# Next section — 二、训练超参数调优 (Part 2: training hyper-parameter tuning)
2.1 学习率与 Batch Size
# Learning-rate selection guide, keyed by fine-tuning method. Every method
# uses the same warmup length; only the range and the note differ.
lr_guidelines = {
    method: {"lr": lr_range, "warmup_steps": 100, "notes": notes}
    for method, lr_range, notes in (
        ("full_finetune", "1e-6 到 2e-5", "较低学习率,避免破坏预训练权重"),
        ("lora", "1e-4 到 3e-4", "LoRA 对学习率较敏感,可适当提高"),
        ("qlora", "2e-4 到 5e-4", "量化模型需要稍高学习率"),
    )
}
# Batch Size 与 Learning Rate 的关系
# 经验公式:lr ∝ √(batch_size)
# batch_size 翻倍 → lr 可提高约 1.4 倍(√2 ≈ 1.41)

2.2 学习率调度
from transformers import get_cosine_schedule_with_warmup
def create_scheduler(optimizer, num_training_steps, warmup_ratio=0.03):
    """Build a cosine learning-rate scheduler with a warmup phase.

    Args:
        optimizer: Optimizer whose learning rate is scheduled.
        num_training_steps: Total number of optimizer steps in the run.
        warmup_ratio: Fraction of ``num_training_steps`` spent warming up;
            the resulting step count is truncated to an integer.

    Returns:
        The scheduler returned by ``get_cosine_schedule_with_warmup``.
    """
    return get_cosine_schedule_with_warmup(
        optimizer,
        num_warmup_steps=int(num_training_steps * warmup_ratio),
        num_training_steps=num_training_steps,
        num_cycles=0.5,  # half a cosine period
    )
# 学习率曲线可视化
"""
lr
 |      ╭─────────────╮
 |     ╱               ╲
 |    ╱                 ╲
 |   ╱                   ╲
 |  ╱                     ╲
 +--------------------------- steps
   warmup      training
"""

2.3 早停与保存
class EarlyStopping:
    """Signal when validation loss has stopped improving.

    Call the instance once per evaluation with the current validation loss.

    Args:
        patience: Number of consecutive non-improving calls after which the
            stop flag is raised.
        min_delta: Amount by which the loss must drop below the best value
            for the call to count as an improvement.
    """

    def __init__(self, patience=3, min_delta=0.01):
        self.patience, self.min_delta = patience, min_delta
        self.counter = 0
        self.best_loss = None
        self.early_stop = False

    def __call__(self, val_loss):
        """Record *val_loss* and return the current stop flag."""
        # The first observation only establishes the baseline.
        if self.best_loss is None:
            self.best_loss = val_loss
            return self.early_stop

        stalled = val_loss > self.best_loss - self.min_delta
        if not stalled:
            # Real improvement: new baseline, patience starts over.
            self.best_loss = val_loss
            self.counter = 0
            return self.early_stop

        self.counter += 1
        if self.counter >= self.patience:
            self.early_stop = True
        return self.early_stop
# Best-checkpoint tracking. "val_loss" starts at +inf so the first epoch
# always wins the comparison; starting from a list would make
# `val_loss < best_model_checkpoint["val_loss"]` raise TypeError.
best_model_checkpoint = {
    "train_loss": None,
    "val_loss": float("inf"),
    "epoch": 0,
    "model_state": None,
}

for epoch in range(num_epochs):
    train_loss = train_epoch(model, dataloader)
    val_loss = evaluate(model, val_dataloader)

    if val_loss < best_model_checkpoint["val_loss"]:
        best_model_checkpoint = {
            "train_loss": train_loss,
            "val_loss": val_loss,
            "epoch": epoch,
            # state_dict() hands back references to the live tensors, so
            # clone them; otherwise the "best" weights keep changing as
            # training continues.
            "model_state": {
                name: tensor.detach().clone()
                for name, tensor in model.state_dict().items()
            },
        }
# Next section — 三、防止灾难性遗忘 (Part 3: preventing catastrophic forgetting)
3.1 灾难性遗忘现象
```mermaid
graph TB
    subgraph "灾难性遗忘"
        A["微调前"] --> B["预训练知识:保留"]
        A --> C["任务能力:弱"]
        D["微调后"] --> E["预训练知识:丢失"]
        D --> F["任务能力:强"]
    end
```
| 问题 | 表现 | 原因 |
|---|---|---|
| 知识丢失 | 通用常识回答错误 | 过度适应新分布 |
| 格式丢失 | 输出格式不稳定 | 新数据格式与预训练不一致 |
| 能力退化 | 原本擅长的任务变差 | 新任务与原任务冲突 |
3.2 缓解策略
# 1. Mix pre-training data into the fine-tuning stream
class MixedPreTrainingSampler:
    """Iterate fine-tuning batches, randomly interleaving pre-training batches.

    Args:
        ft_data: Iterable of fine-tuning batches; every batch is yielded
            exactly once, in order.
        pretrain_data: Re-iterable collection of pre-training batches. It is
            restarted from the beginning whenever it runs out.
        mix_ratio: Probability, after each fine-tuning batch, of also
            yielding one pre-training batch.
    """

    def __init__(self, ft_data, pretrain_data, mix_ratio=0.1):
        self.ft_data = ft_data
        self.pretrain_data = pretrain_data
        self.mix_ratio = mix_ratio

    def __iter__(self):
        import random

        exhausted = object()  # sentinel: distinguishes "no batch" from a None batch
        pt_iter = iter(self.pretrain_data)

        for batch in self.ft_data:
            yield batch

            # Mix in a pre-training batch with probability mix_ratio.
            if random.random() >= self.mix_ratio:
                continue

            pt_batch = next(pt_iter, exhausted)
            if pt_batch is exhausted:
                # Restart the pre-training stream and still deliver this
                # mix-in instead of silently dropping it.
                pt_iter = iter(self.pretrain_data)
                pt_batch = next(pt_iter, exhausted)
            if pt_batch is not exhausted:  # an empty pretrain_data yields nothing
                yield pt_batch
# 2. Weight regularization
class WeightDecayCallback:
    """Penalise parameters for drifting away from a reference set of weights.

    Args:
        original_weights: Mapping from parameter name to its reference value.
        wd_coef: Multiplier applied to the summed squared drift.
    """

    def __init__(self, original_weights, wd_coef=0.01):
        self.original_weights = original_weights
        self.wd_coef = wd_coef

    def penalty(self, model):
        """Return ``wd_coef`` times the summed squared drift from the reference.

        Only parameters whose name appears in ``original_weights`` contribute.
        """
        squared_drift = sum(
            ((param - self.original_weights[name]) ** 2).sum()
            for name, param in model.named_parameters()
            if name in self.original_weights
        )
        return self.wd_coef * squared_drift
# 3. EWC (Elastic Weight Consolidation)
class EWCCallback:
    """Elastic Weight Consolidation penalty anchored at the current weights.

    Args:
        model: Model whose current parameters become the anchor.
        fisher_diagonal: Mapping from parameter name to a per-element
            importance weight of the same shape as the parameter.
        opt_params: Stored unchanged on the instance; not used by this class.
    """

    def __init__(self, model, fisher_diagonal, opt_params):
        self.fisher = fisher_diagonal
        self.opt_params = opt_params
        # detach() before clone(): a plain clone() stays in the autograd graph
        # of the parameter, so the penalty's gradient through the anchor would
        # cancel the gradient through the parameter and the term would do
        # nothing during training.
        self.params_old = {
            n: p.detach().clone() for n, p in model.named_parameters()
        }

    def compute_loss(self, model):
        """Return the importance-weighted squared drift from the anchor.

        Only parameters whose name appears in the importance mapping
        contribute. No scaling factor is applied; weight the result when
        adding it to the task loss.
        """
        penalty = 0
        for name, param in model.named_parameters():
            if name in self.fisher:
                drift = param - self.params_old[name]
                penalty += (self.fisher[name] * drift ** 2).sum()
        return penalty
# Next section — 3.3 对比实验 (comparison experiments)
# Catastrophic-forgetting probe
def test_catastrophic_forgetting(model, test_tasks, evaluators=None):
    """Score *model* on each named task and return ``{task: score}``.

    Args:
        model: Model passed to every evaluator.
        test_tasks: Iterable of task names to run.
        evaluators: Optional mapping from task name to a callable taking the
            model and returning its score. Entries extend or override the
            built-in ``"common_sense"`` and ``"math"`` tasks.

    Raises:
        ValueError: If a task has no evaluator. Without this check an unknown
            task would silently be given the previous task's score, or fail
            on an unbound variable when it comes first.
    """
    # Lambdas defer the lookup of the evaluator functions until the task is
    # actually requested.
    available = {
        "common_sense": lambda m: evaluate_common_sense(m),
        "math": lambda m: evaluate_math(m),
    }
    if evaluators:
        available.update(evaluators)

    results = {}
    for task in test_tasks:
        try:
            evaluate_task = available[task]
        except KeyError:
            raise ValueError(f"no evaluator registered for task {task!r}") from None
        results[task] = evaluate_task(model)

    return results
# Effect of each method: score on the target task versus how much was
# forgotten (for forgetting_score, lower is better).
results_comparison = {
    method: {"task_performance": task_performance, "forgetting_score": forgetting_score}
    for method, task_performance, forgetting_score in (
        ("full_ft", 0.85, 0.30),
        ("lora_ft", 0.82, 0.10),
        ("lora_mix_pretrain", 0.83, 0.05),
    )
}
# Next section — 四、模型评估体系 (Part 4: model evaluation)
4.1 评估维度
```mermaid
graph TB
    A["模型评估"] --> B["任务能力"]
    A --> C["生成质量"]
    A --> D["安全性"]
    A --> E["效率"]
    B --> B1["准确率"]
    B --> B2["召回率"]
    B --> B3["F1"]
    C --> C1["流畅性"]
    C --> C2["相关性"]
    C --> C3["多样性"]
    D --> D1["幻觉率"]
    D --> D2["有害内容"]
    D --> D3["偏见检测"]
    E --> E1["推理速度"]
    E --> E2["显存占用"]
```
4.2 评估指标
class EvaluationMetrics:
    """Accumulate evaluation metrics in a single ``metrics`` dict.

    Every ``compute_*`` method writes its results into ``self.metrics`` and
    returns that same dict, so successive calls add to one report.
    """

    def __init__(self):
        self.metrics = {}

    def compute_task_metrics(self, predictions, references):
        """Compute accuracy and weighted F1 / precision / recall."""
        from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score

        self.metrics["accuracy"] = accuracy_score(references, predictions)
        self.metrics["f1"] = f1_score(references, predictions, average="weighted")
        self.metrics["precision"] = precision_score(references, predictions, average="weighted")
        self.metrics["recall"] = recall_score(references, predictions, average="weighted")

        return self.metrics

    def compute_generation_metrics(self, generations, references):
        """Compute averaged ROUGE-1 / ROUGE-2 / ROUGE-L F-scores."""
        from rouge import Rouge

        scores = Rouge().get_scores(generations, references, avg=True)

        for key in ("rouge-1", "rouge-2", "rouge-l"):
            self.metrics[key] = scores[key]["f"]

        return self.metrics

    def compute_quality_metrics(self, model, test_prompts):
        """Score answers with an LLM judge and record the mean and std.

        Args:
            model: Judge exposing ``generate(prompt) -> str``.
            test_prompts: Iterable of ``(prompt, response)`` pairs.

        The first standalone digit 1-5 in each judge reply is taken as the
        score, so replies such as ``"评分:4"`` are accepted. Replies with no
        such digit are skipped rather than aborting the run. When nothing
        could be scored, both metrics are ``nan``.
        """
        import re

        import numpy as np

        score_pattern = re.compile(r"(?<!\d)[1-5](?!\d)")
        quality_scores = []

        for prompt, response in test_prompts:
            reply = model.generate(
                f"评估以下回答的质量(1-5分):\n\n问题:{prompt}\n\n回答:{response}\n\n评分:"
            )
            found = score_pattern.search(str(reply))
            if found:
                quality_scores.append(int(found.group()))

        if quality_scores:
            self.metrics["quality_mean"] = np.mean(quality_scores)
            self.metrics["quality_std"] = np.std(quality_scores)
        else:
            self.metrics["quality_mean"] = float("nan")
            self.metrics["quality_std"] = float("nan")

        return self.metrics
# Next section — 4.3 自动化评估框架 (automated evaluation framework)
class AutoEvalFramework:
    """Run task, generation-quality, safety and latency evaluation on a model.

    Args:
        model: Model under test, exposing ``generate(text)``.
        judge_model: Optional separate judge; falls back to ``model``.
    """

    # Keywords used by the naive harmful-content check.
    HARMFUL_KEYWORDS = ("暴力", "色情", "歧视")

    def __init__(self, model, judge_model=None):
        self.model = model
        self.judge = judge_model or model  # a dedicated judge model may be supplied

    def run_full_evaluation(self, test_dataset):
        """Run every evaluation stage and return one report dict.

        Each item of *test_dataset* must provide ``"input"`` and ``"output"``.
        The dataset is materialised once because it is traversed twice (for
        predictions and for latency), which a one-shot iterator cannot
        support.
        """
        items = list(test_dataset)

        # 1. Task evaluation: collect predictions and references.
        preds, refs = [], []
        for item in items:
            preds.append(self.model.generate(item["input"]))
            refs.append(item["output"])

        # The metric functions live on EvaluationMetrics, not on this class.
        # A fresh instance per stage keeps the two result dicts separate,
        # since EvaluationMetrics returns its shared accumulator.
        return {
            "task_metrics": dict(EvaluationMetrics().compute_task_metrics(preds, refs)),
            # 2. Generation quality
            "generation_metrics": dict(
                EvaluationMetrics().compute_generation_metrics(preds, refs)
            ),
            # 3. Safety
            "safety_metrics": self.evaluate_safety(preds),
            # 4. Latency
            "latency_metrics": self.evaluate_latency(items),
        }

    def evaluate_safety(self, generations):
        """Return the share of *generations* flagged by :meth:`is_harmful`.

        Only the supplied texts are scanned; no adversarial prompts are sent
        to the model. With no generations the rates are reported as
        ``0.0`` / ``1.0`` instead of dividing by zero.
        """
        total = len(generations)
        if total == 0:
            return {"harmful_rate": 0.0, "safe_rate": 1.0}

        harmful_rate = sum(1 for text in generations if self.is_harmful(text)) / total
        return {
            "harmful_rate": harmful_rate,
            "safe_rate": 1 - harmful_rate,
        }

    def is_harmful(self, text):
        """Naive check: does *text* contain any harmful keyword."""
        return any(kw in text for kw in self.HARMFUL_KEYWORDS)

    def evaluate_latency(self, test_dataset, num_samples=100):
        """Time ``model.generate`` on up to *num_samples* items.

        Returns mean, p50, p95 and p99 latency in seconds. All four are
        ``nan`` when there is nothing to time.
        """
        import itertools
        import time

        import numpy as np

        latencies = []
        # islice works on any iterable; slicing needs a sequence and returns
        # something else entirely on some dataset types.
        for item in itertools.islice(test_dataset, num_samples):
            start = time.perf_counter()  # monotonic, unlike time.time()
            self.model.generate(item["input"])
            latencies.append(time.perf_counter() - start)

        if not latencies:
            nan = float("nan")
            return {
                "mean_latency": nan,
                "p50_latency": nan,
                "p95_latency": nan,
                "p99_latency": nan,
            }

        return {
            "mean_latency": np.mean(latencies),
            "p50_latency": np.percentile(latencies, 50),
            "p95_latency": np.percentile(latencies, 95),
            "p99_latency": np.percentile(latencies, 99),
        }
# Next section — 五、LoRA 实战调参指南 (Part 5: practical LoRA tuning guide)
5.1 LoRA 超参数速查
| 参数 | 常用值 | 调整建议 |
|---|---|---|
| r | 4/8/16 | 数据少用小值,数据多用大值 |
| alpha | 2 × r | 与 r 保持比例 |
| dropout | 0.05/0.1 | 数据少用大值防过拟合 |
| target_modules | q,k,v,o | 至少 q,v;全连接可提升能力 |
| bias | none | 通常不训练 bias |
5.2 常见问题与解决方案
# Problem 1: training loss does not decrease
diagnostics_loss_stuck = dict(
    [
        ("可能原因", ["学习率太低", "模型已经收敛", "数据有问题(全是相同标签)"]),
        ("解决方案", ["提高学习率(10 倍尝试)", "检查数据分布", "验证数据标签正确性"]),
    ]
)

# Problem 2: overfitting
diagnostics_overfitting = dict(
    [
        ("可能原因", ["数据太少", "r 值太大", "训练太多 epoch"]),
        ("解决方案", ["增加数据或数据增强", "减小 r 值", "减少 epoch,使用早停"]),
    ]
)

# Problem 3: poor output quality
diagnostics_quality = dict(
    [
        ("可能原因", ["数据质量差", "新知识与预训练冲突", "训练参数不当"]),
        ("解决方案", ["清洗数据", "混入预训练数据", "调整 r 和 alpha"]),
    ]
)
# Next section — 六、总结 (Part 6: summary)
| 阶段 | 关键点 | 常见坑点 |
|---|---|---|
| 数据准备 | 格式统一、长度合理、去重清洗 | 格式错误、占位符残留 |
| 超参调优 | 学习率 1e-4、warmup 3%、cosine 调度 | 学习率过高、训练不稳定 |
| 防止遗忘 | 混合预训练数据、权重正则化 | 只用新数据、epoch 过多 |
| 模型评估 | 多维度评估、自动化流程 | 只看 loss、忽视安全性 |
```mermaid
flowchart LR
    A["数据准备"] --> B["训练配置"]
    B --> C["模型训练"]
    C --> D["防止遗忘"]
    D --> E["模型评估"]
    E --> F{通过?}
    F -->|否| B
    F -->|是| G["部署上线"]
```
支持与分享
如果这篇文章对你有帮助,欢迎支持作者或分享给更多人
部分信息可能已经过时
相关文章 智能推荐
1
定制专属模型:微调实战指南
AI 定制专属模型——微调实战指南
2
Fine-tuning 与模型微调技术
AI 深入解析 LLM Fine-tuning 技术——LoRA、QLoRA、Adapter、Prefix Tuning 与微调实战。
3
RLHF 与 DPO 偏好对齐技术
AI 深入解析 RLHF、PPO、Reward Model 与 DPO 的原理与实践,探讨如何让 LLM 输出更符合人类偏好。
4
DPO:绕过奖励模型的直接偏好优化
AI 深度解读 Direct Preference Optimization (2023)——绕过 reward model 直接优化人类偏好
5
定制专属模型:微调实战指南
AI 工程指南 定制专属模型——微调实战指南






