代码拉取完成,页面将自动刷新
同步操作将从 Charent/ChatLM-mini-Chinese 强制同步,此操作会覆盖自 Fork 仓库以来所做的任何修改,且无法恢复!!!
确定后同步将在后台操作,完成时将刷新页面,请耐心等待。
# coding=utf-8
from typing import Dict, Optional
import time
import os
import pandas as pd
import torch
from datasets import Dataset, load_dataset
from transformers import PreTrainedTokenizerFast, TrainingArguments
from trl import DPOTrainer
from tokenizers import Tokenizer
from peft import LoraConfig, TaskType, PeftModel
from config import DpoConfig, T5ModelConfig
from model.chat_model import TextToTextModel
from utils.functions import get_T5_config
# Turn off TensorFlow's oneDNN custom operations (and the startup notice about them).
# NOTE(review): this runs after the third-party imports above, so it only takes
# effect if none of them has already imported TensorFlow — TODO confirm, or move
# this line above those imports.
os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'
def get_dataset(split: str, file: str, cache_dir: str = '.cache', seed: int = 2333, eos_token: str = '[EOS]') -> Dataset:
    """Load a local JSON preference dataset and prepare it for DPO training.

    Every record in ``file`` must provide ``prompt``, ``chosen`` and ``rejected``
    fields. ``eos_token`` is appended to each of them to mark the end of a
    sentence, which is used during generation. The mapped dataset is then shuffled.

    Args:
        split: Split to load through ``datasets.load_dataset`` (e.g. ``"train"``).
        file: Path to the JSON data file.
        cache_dir: Cache directory used by ``datasets``.
        seed: Seed for the final shuffle (defaults to the previously hard-coded 2333).
        eos_token: End-of-sentence marker appended to every field (default ``"[EOS]"``).

    Returns:
        A shuffled ``Dataset`` whose ``prompt``/``chosen``/``rejected`` columns
        are strings ending in ``eos_token``.
    """
    dataset = load_dataset('json', data_files=file, split=split, cache_dir=cache_dir)

    def split_prompt_and_responses(sample: dict) -> Dict[str, str]:
        # Append the end-of-sentence marker so the model learns where an answer stops.
        return {
            "prompt": f"{sample['prompt']}{eos_token}",
            "chosen": f"{sample['chosen']}{eos_token}",
            "rejected": f"{sample['rejected']}{eos_token}",
        }

    return dataset.map(split_prompt_and_responses).shuffle(seed=seed)
def train_dpo(config: DpoConfig, peft_config: Optional[LoraConfig] = None) -> None:
    """Fine-tune the SFT model with DPO and save the resulting model or LoRA adapter.

    Args:
        config: DPO settings (tokenizer/model/data paths, optimisation and
            logging hyper-parameters, ``beta``, ``max_seq_len``, ...).
        peft_config: Optional LoRA configuration forwarded to ``DPOTrainer``.
            When ``None``, the trainer's output is saved to
            ``<dirname(config.sft_model_file)>/dpo``. Otherwise it is saved to
            ``<dirname(config.sft_model_file)>/lora``.

    Side effects:
        Writes the trained weights (see above) and the trainer's log history
        as a CSV file under ``./logs``. Prints the model save directory.
    """
    # step 1. load the tokenizer
    tokenizer = PreTrainedTokenizerFast.from_pretrained(config.tokenizer_dir)

    # step 2. load the model to train and the reference model, both initialised
    # from the same SFT weights
    if os.path.isdir(config.sft_model_file):
        # a directory is loaded with from_pretrained
        model_train = TextToTextModel.from_pretrained(config.sft_model_file)
        model_ref = TextToTextModel.from_pretrained(config.sft_model_file)
    else:
        # a single file is treated as a state dict. Read it from disk only once;
        # load_state_dict copies the tensors, so the two models do not share storage.
        t5_config = get_T5_config(
            T5ModelConfig(),
            vocab_size=len(tokenizer),
            decoder_start_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )
        state_dict = torch.load(config.sft_model_file, map_location='cpu')  # load on CPU to avoid device errors
        model_train = TextToTextModel(t5_config)
        model_train.load_state_dict(state_dict)
        model_ref = TextToTextModel(t5_config)
        model_ref.load_state_dict(state_dict)
        del state_dict  # drop the extra copy before training starts

    # step 3. training data
    train_dataset = get_dataset("train", file=config.dpo_train_file)

    # step 4. evaluation data (currently disabled)
    # eval_dataset = get_dataset("train", file=config.dpo_eval_file)
    eval_dataset = None

    # step 5. training arguments
    training_args = TrainingArguments(
        per_device_train_batch_size=config.per_device_train_batch_size,
        num_train_epochs=config.num_train_epochs,
        auto_find_batch_size=True,
        remove_unused_columns=False,
        gradient_accumulation_steps=config.gradient_accumulation_steps,
        learning_rate=config.learning_rate,
        logging_first_step=True,
        logging_steps=config.logging_steps,
        save_steps=config.save_steps,
        output_dir=config.output_dir,
        optim="adafactor",
        report_to="tensorboard",
        log_level='info',
        warmup_steps=config.warmup_steps,
        bf16=False,
        fp16=config.fp16,
        seed=config.seed,
        logging_dir=config.log_dir,
    )

    # step 6. DPO trainer
    dpo_trainer = DPOTrainer(
        model_train,
        model_ref,
        peft_config=peft_config,
        args=training_args,
        beta=config.beta,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        tokenizer=tokenizer,
        max_length=config.max_seq_len,
        max_target_length=config.max_seq_len,
        max_prompt_length=config.max_seq_len,
        # Sample generation only happens during evaluation, so enable it only when
        # there is an eval set. NOTE(review): some trl versions also require wandb
        # whenever this flag is set.
        generate_during_eval=eval_dataset is not None,
        is_encoder_decoder=True,
    )

    # step 7. train
    dpo_trainer.train(
        # resume_from_checkpoint=True
    )

    # step 8. save the model / LoRA adapter first, so a failure while writing
    # logs cannot lose the training result
    suffix = 'lora' if peft_config is not None else 'dpo'
    model_save_dir = os.path.join(os.path.dirname(config.sft_model_file), suffix)
    dpo_trainer.save_model(model_save_dir)
    print('save model or lora adapter to: {}'.format(model_save_dir))

    # step 9. save the log history; create ./logs if it does not exist yet
    log_dir = './logs'
    os.makedirs(log_dir, exist_ok=True)
    loss_log = pd.DataFrame(dpo_trainer.state.log_history)
    loss_log.to_csv(os.path.join(log_dir, f"dpo_train_log_{time.strftime('%Y%m%d-%H%M')}.csv"))
def merge_lora_weight_into_model(config: DpoConfig, peft_config: LoraConfig) -> None:
    """Merge a DPO-trained LoRA adapter into the SFT model and save the merged model.

    The adapter is read from the ``lora`` sub-directory next to
    ``config.sft_model_file``, the directory ``train_dpo`` uses when it is given a
    ``peft_config``. The merged model is written with ``save_pretrained`` to
    ``config.sft_model_file + '.dpo_lora_merged'``.

    Args:
        config: DPO settings; ``tokenizer_dir`` and ``sft_model_file`` are used.
        peft_config: The LoRA configuration the adapter was trained with.
    """
    # step 1. load the tokenizer (needed for the vocab size / special token ids)
    tokenizer = PreTrainedTokenizerFast.from_pretrained(config.tokenizer_dir)

    # step 2. load the SFT base model
    if os.path.isdir(config.sft_model_file):
        # a directory is loaded with from_pretrained
        sft_model = TextToTextModel.from_pretrained(config.sft_model_file)
    else:
        # a single file is treated as a state dict
        t5_config = get_T5_config(
            T5ModelConfig(),
            vocab_size=len(tokenizer),
            decoder_start_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )
        sft_model = TextToTextModel(t5_config)
        sft_model.load_state_dict(torch.load(config.sft_model_file, map_location='cpu'))  # load on CPU to avoid device errors

    # step 3. wrap the base model with the adapter. This path must match the
    # directory train_dpo saves the LoRA adapter to.
    adapter_save_dir = os.path.join(os.path.dirname(config.sft_model_file), 'lora')
    print('load adapter from dir: {}'.format(adapter_save_dir))
    # from_pretrained already loads the adapter weights, so no separate
    # load_adapter call is needed.
    peft_model = PeftModel.from_pretrained(
        model=sft_model,
        model_id=adapter_save_dir,
        config=peft_config,
        adapter_name='adapter',
    )

    # step 4. fold the LoRA weights into the base weights and drop the wrappers
    merged_model = peft_model.merge_and_unload()

    # step 5. save the merged model returned by merge_and_unload
    save_merge_file = config.sft_model_file + '.dpo_lora_merged'
    merged_model.save_pretrained(save_merge_file)
    print('save merge model file to: {}'.format(save_merge_file))
if __name__ == "__main__":
    # LoRA settings for the seq2seq (text-to-text) model. They are only used when
    # the LoRA path is enabled below.
    lora_cfg = LoraConfig(
        task_type=TaskType.SEQ_2_SEQ_LM,
        inference_mode=False,
        r=16,
        lora_alpha=16,
        lora_dropout=0.1,
        bias="all",
    )
    cfg = DpoConfig()

    # Stage 1: DPO training without LoRA (pass peft_config=lora_cfg to train an adapter instead).
    train_dpo(cfg, peft_config=None)

    # Stage 2 (disabled): merge the trained LoRA adapter back into the base model.
    # merge_lora_weight_into_model(cfg, lora_cfg)
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。