简介
记录训练通义千问Qwen2.5 VL-7B-Instruct模型Lora的一些主要步骤和命令。任务简单来说就是质检,微观结构的判定。
环境准备
租用AutoDL服务器,远程ssh。
硬件配置
GPU: H20-NVLink(96GB) * 1
CPU: 20 vCPU Intel® Xeon® Platinum 8457C
RAM: 200GB
费用 ¥7.98/时
软件依赖
PyTorch 2.5.1
Python 3.12(ubuntu22.04)
CUDA 12.4
ssh配置
配置端口映射和ssh私钥登陆, 将swift/tensorboard/vllm/ollama等选择性映射到本机
1 2 3 4 5 6 7 8 9 Host autodl HostName xxxx User root Port xxxxx IdentityFile ~/.ssh/xxxxx LocalForward 0.0.0.0:7860 localhost:7860 LocalForward 0.0.0.0:6006 localhost:6006 LocalForward 0.0.0.0:8000 localhost:8000 LocalForward 0.0.0.0:11434 localhost:11434
软件安装
modelscope
首先安装tmux和modelscope,然后马上进去tmux后台下载基座模型,这个需要一段比较长时间
1 2 3 4 5 apt update && apt install -y tmux # 进入tmux pip install modelscope modelscope download --model Qwen/Qwen2.5-VL-7B-Instruct
ms-swift
准备训练用工具swift, 训练过程依赖flash attention2, 也预先安装好. flash-attn的安装可能会卡,注意翻。
1 2 3 4 pip install ms-swift pip install flash-attn --no-build-isolation # 准备好,启动WebUI swift web-ui --lang zh
数据准备
数据预处理
数据只需要一个jsonl文件准备好对话内容即可,内容格式如下:
1 { "messages" : [ { "role" : "user" , "content" : "<image>请描述这个图片" } , { "role" : "assistant" , "content" : "这个图片描述的是……" } ] , "images" : "image-path.jpg" }
图片可能需要resize & crop,用ffmpeg处理一下:
1 find input -name "*.jpg" -exec bash -c 'ffmpeg -i {} -vf "scale=iw*0.5:ih*0.5,crop=500:500:0:0" -y output/$(basename {})' \;
数据标注
这一步省了。如果要做,可以用label studio。
训练配置
LoRA参数设置
基本上是swift的默认值
rank: 8
alpha: 32
dropout: 0.05
target modules: all-linear
训练超参数
训练时batch size为1用了89G VRAM,没尝试batch size=2是否可跑起来
batch size: 1
learning rate: 1e-4
epochs: 1000
warmup steps: 0
gradient accumulation: 16
save_steps: 100
命令行
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 swift sft \ --torch_dtype bfloat16 \ --model Qwen/Qwen2.5-VL-7B-Instruct \ --model_type qwen2_5_vl \ --template qwen2_5_vl \ --system You are a helpful assistant. \ --dataset /root/dataset/train.jsonl \ --max_length 1024 \ --init_weights True \ --learning_rate 1e-4 \ --num_train_epochs 1000 \ --attn_impl flash_attn \ --gradient_accumulation_steps 16 \ --eval_steps 500 \ --save_steps 100 \ --output_dir /root/output \ --report_to tensorboard \ --add_version False \ --output_dir /root/output/v0-20250411-021343 \ --logging_dir /root/output/v0-20250411-021343/runs \ --ignore_args_error True
评估
vllm
本来想用vllm提供接口,然后用Page Assist来调用,但是后面有其他问题放弃了,这里仅记录启动vllm的命令。
1 2 3 4 vllm serve \ ~/.cache/modelscope/hub/models/Qwen/Qwen2.5-VL-7B-Instruct \ --enable-lora \ --lora-modules lora=/path/to/checkpoint
尝试了transformers貌似无法用peft加载qwen-vl的lora,在用lora包裹model那一步会报错,这里仅给出不含lora的调用代码:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessorfrom qwen_vl_utils import process_vision_infomodel_id_or_path = '~/.cache/modelscope/hub/models/Qwen/Qwen2.5-VL-7B-Instruct' model = Qwen2_5_VLForConditionalGeneration.from_pretrained( model_id_or_path, torch_dtype="auto" , device_map="auto" ) processor = AutoProcessor.from_pretrained(model_id_or_path) text = processor.apply_chat_template( messages, tokenize=False , add_generation_prompt=True ) image_inputs, video_inputs = process_vision_info(messages) inputs = processor( text=[text], images=image_inputs, videos=video_inputs, padding=True , return_tensors="pt" , ) inputs = inputs.to(model.device) generated_ids = model.generate(**inputs, max_new_tokens=128 ) generated_ids_trimmed = [ out_ids[len (in_ids) :] for in_ids, out_ids in zip (inputs.input_ids, generated_ids) ] output_text = processor.batch_decode( generated_ids_trimmed, skip_special_tokens=True , clean_up_tokenization_spaces=False )
ms-swift
代码参考:https://github.com/modelscope/ms-swift/blob/main/examples/infer/demo.py
训练总结
训练样本太少,正负各5个,泛化能力不足
速度为115s/it,接近6.5小时跑了200个epochs
大概在20+epochs后,train acc达到顶峰,175+ epochs后,训练集loss比较平缓
其他总结
Qwen开源模型
用了modelscope原版下载的2.5-VL-7B-instruct模型,发现效果非常的离奇。有一句prompt正好测试集全通过,但是稍微改一个字,甚至一个标点符号,都会导致输出结果非常不同(temperature=0),而且对于指令遵循能力也是比较差。
阿里云官方API
由于任务性质,能提供的信息只有形状、位置、结构、材质、堆叠,判断有或没有,并不是一些常见的物品。从7B到Max都测试过,结果就是只有Max可以输出准确的描述,因此也只有它能给到正确的结果,并且按照系统提示词格式正确地结构化输出JSON。其他参数的模型,某些输入能输出JSON,而某些输入却会输出markdown+JSON,不太稳定。