LLaMA-Factory监督微调实战指南
Google Gemma 2B 微调实战(IT科技新闻标题生成)
这篇实战,我们将用 Google 的 Gemma-2b 模型来微调一个根据IT科技新闻正文生成标题的模型。同时会介绍如何利用高度集成的训练框架,快速完成整个流程。
开始前
为了尽可能简化流程,直接选用了 linux-cn 数据集[1]作为本次训练任务的训练数据。
模型方面选择了 Gemma-2b[2]。对于这个任务来说,2b 级别的参数模型已经绰绰有余了,当然如果想试试 7b 也完全没问题。
训练框架直接用了 LLaMA-Factory[3],它不光能搞定监督微调(SFT),还支持预训练(PT)、奖励模型(RM)以及 PPO/DPO 的训练,一步到位。
数据整理
linux-cn 数据集本身已经被清洗和格式化过,我们只需要把需要的字段提取出来,再按 LLaMA-Factory 监督微调的格式转换一下就行。
这个任务只用到数据集里的 title 和 content 两个字段。LLaMA-Factory 的监督微调格式是这样的 JSON 文件:
[
{
"instruction": "What are the three primary colors?",
"input": "",
"output": "The three primary colors are red, blue, and yellow. These colors are called primary because they cannot be created by mixing other colors and all other colors can be made by combining them in various proportions. In the additive color system, used for light, the primary colors are red, green, and blue (RGB).",
},
...
]
因为我们用的是预训练模型,所以还要指定一个 prompt template。这么做的好处是,如果以后想混合训练多个不同类型的任务,它们之间不会互相干扰。
完整的数据转换代码如下:
import json
result = []
prompt_template = """Generate a title for the article:
{content}
---
Title:
"""
with open('archve.jsonl', 'r') as f:
for line in f:
p = json.loads(line)
result.append({
"instruction": prompt_template.replace("{content}", p['content']),
"input": "",
"output": p['title']
})
with open('itnews_data.json', 'w') as f:
json.dump(result, f, ensure_ascii=False, indent=4)
做完这一步,模型训练就可以开工了。不过说真的,整个流程里最耗时、最让人头疼的往往就是数据收集和整理这一步。
模型微调
首先确认 LLaMA-Factory 框架已在本地就绪——下载项目并完成安装。具体安装过程可以参考项目 README,这里就不展开了。
接下来把数据集放到框架的 data 目录下,然后在 dataset_info.json 里注册自定义数据集。本文实例添加的信息如下:
"itnews": {
"file_name": "itnews_data.json",
},
不同任务在这个框架里可能有不同的数据集格式要求,可以参照项目 dataset_info.json 的 README[4]。
一切就绪后,执行以下命令就能开始微调了(本文在单张 A100 80G 上完成):
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
--stage sft \
--do_train True \
--model_name_or_path google/gemma-2b \
--finetuning_type lora \
--template default \
--dataset itnews \
--use_unsloth \
--cutoff_len 8192 \
--learning_rate 5e-05 \
--num_train_epochs 10.0 \
--max_samples 10000 \
--per_device_train_batch_size 4 \
--per_device_eval_batch_size 4 \
--gradient_accumulation_steps 4 \
--lr_scheduler_type cosine \
--max_grad_norm 1.0 \
--logging_steps 10 \
--sa ve_steps 100 \
--eval_steps 100 \
--evaluation_strategy steps \
--warmup_steps 0 \
--output_dir sa ves/Gemma-2B/lora/train_v1 \
--bf16 True \
--lora_rank 8 \
--lora_dropout 0.1 \
--lora_target q_proj,v_proj \
--val_size 0.1 \
--load_best_model_at_end True \
--plot_loss True \
--report_to "tensorboard"
几个关键参数简单说明一下:
--stage 任务类型,这里用 sft(监督微调),其他任务需要改成对应的类型。
--dataset 数据集名称,也就是刚才在 dataset_info.json 里指定的名字。
--use_unsloth 一个训练翻跟斗,官方称在 Gemma 7b 上能加速 2.4 倍、节省超一半显存。使用前需要按官方文档[5]安装。
--cutoff_len 文本令牌化后的最大长度,Gemma 2b 的最大长度是 8192,这里就直接设了 8192。注意,更长的上下文意味着更高的 GPU 显存开销。
--max_samples 数据集加载的最大条数。调试时特别好用,比如不确定 cutoff_len 和 batch_size 的时候,可以先加载一小部分数据测试显存占用。
--learning_rate 和 --num_train_epochs 学习率和训练轮数。这些数值往往来自经验,通常通过观察 loss 来调整。但在 LLM 训练中,最终的评判标准是模型是否满足业务需求——完美的 loss 不一定意味着好用的模型。
--per_device_train_batch_size、--per_device_eval_batch_size、--gradient_accumulation_steps 这三个参数需要根据显存大小、是否使用多 GPU 等情况灵活调整。
--output_dir 模型保存目录。
更多参数解释可以查看项目说明[6]以及 Transformers Trainer 的说明[7]。
模型使用
直接通过 Transformers 库加载训练好的 LoRA 权重就能用了:
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
peft_model_id = "checkpoint-2000"
model = AutoModelForCausalLM.from_pretrained(peft_model_id, device_map="cuda")
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b")
input_text = """
Generate a title for the article:
{content}
---
Title:
""" # 固定格式
encoding = tokenizer(input_text, return_tensors="pt").to("cuda")
outputs = model.generate(**encoding, max_length=8192, temperature=0.2, do_sample=True)
generated_ids = outputs[:, encoding.input_ids.shape[1]:]
generated_texts = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
print(generated_texts[0])
拿一篇自己写的、大约 5000 tokens 的微服务文章[8]做测试(这篇不在训练数据里)。在相同 prompt 下,对比微调前后的输出:
gemma-2b-it(未微调)
> 微服务架构
概述
微服务架构的定义
微服务架构的定义
微服务架构的定义
微服务架构的定义
微服务架构的定义
微服务架构的定义
...
LoRA 微调后
> 微服务架构的优势
简单的测试就能看出,微调后的模型在返回格式上更稳定,也更符合我们的预期。
总结
如果不想自己训练但又想尝试,可以在 HuggingFace 上搜索 gemma-2b-technology-news-title-generation-lora[9],能找到从 100 步到 2200 步的所有 checkpoint。
本文展示了一种相对简单的方式,来训练一个符合自己需求的模型。但在真实的企业场景中,还会涉及数据集合理生成、集群训练、模型 A/B 测试、企业级部署等问题。这些内容就留到以后的文章里再聊了。
![我们在这里将直接使用 LLaMA-Factory[3] 训练框架来直接完成监督微调部分工作](/uploadfile/2026/0624/e398d19d2c481e428cccad277d3f76e0.webp)