Distill DeepSeek-R1 (#438)
Co-authored-by: wangzhihong <[email protected]>
JingofXin and wzh1994 authored Feb 21, 2025
1 parent ab7f279 commit c511e1c
Showing 3 changed files with 138 additions and 4 deletions.
132 changes: 132 additions & 0 deletions examples/distill_deepseek_r1.py
@@ -0,0 +1,132 @@
import os
import re
import json

import lazyllm
from lazyllm import finetune, deploy, launchers, warp

from modelscope.msdatasets import MsDataset


def load_data(data_path):
with open(data_path, 'r') as file:
dataset = json.load(file)
return dataset

def save_res(data, file_path):
with open(file_path, 'w') as file:
json.dump(data, file, ensure_ascii=False, indent=4)

def build_data_path(file_name):
data_root = os.path.join(os.getcwd(), 'dataset')
if not os.path.exists(data_root):
os.makedirs(data_root)
save_path = os.path.join(data_root, file_name)
return save_path

def get_dataset():
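    # Download GSM8K from ModelScope, rename 'question'/'answer' to the 'instruction'/'output' schema, and dump the train/test splits to JSON.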
train_path = build_data_path('train_set.json')
eval_path = build_data_path('eval_set.json')
ds = MsDataset.load('modelscope/gsm8k', subset_name='main')
ds = ds.rename_column('question', 'instruction').rename_column('answer', 'output')
with open(train_path, 'w') as file:
json.dump(ds['train'].to_list(), file, ensure_ascii=False, indent=4)
with open(eval_path, 'w') as file:
json.dump(ds['test'].to_list(), file, ensure_ascii=False, indent=4)
return train_path, eval_path

def distill_dataset(data_path, model=None, demo=False):
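    # Ask the teacher model to answer the training questions; keep only responses whose \boxed{} answer matches the reference and that close their reasoning with </think>, retrying the rest for up to 15 rounds.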
inputs = load_data(data_path)[:1] if demo else load_data(data_path)
res_list = []
try_n = 0
while inputs:
print(">>>" * 12, f"{try_n+1} times left: ", len(inputs))
with warp(_concurrent=1) as wp:
wp.func = model
        queries = [item['instruction'] for item in inputs]
        results = wp(queries)
valid_data, inputs = filter(inputs, results)
res_list.extend(valid_data)
try_n += 1
if try_n == 15:
break
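    # In demo mode the single distilled sample is duplicated so the fine-tuning step still has data to iterate over.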
res_list = res_list * 120 if demo else res_list
distilled_train_set_path = build_data_path('distilled_train_data.json')
save_res(res_list, distilled_train_set_path)
save_res(inputs, build_data_path('left_data.json'))
return distilled_train_set_path

def filter(inputs, results):
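    # GSM8K references end with '\n#### <answer>'; a teacher response is valid only if it boxes that exact value and contains a closed </think> block.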
valid = []
retry = []
for i, item in enumerate(inputs):
true_v = item['output'].split('\n#### ')[-1].strip()
if f'\\boxed{{{true_v}}}' in results[i] and '</think>' in results[i]:
valid.append({'instruction': item['instruction'], 'output': results[i], 'input': ''})
else:
retry.append(item)
return valid, retry

def extract_boxed_content(text):
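    # Capture the contents of \boxed{...}, tolerating one level of nested braces.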
pattern = r'boxed{((?:[^{}]*|{.*?})*)}'
contents = re.findall(pattern, text)
return contents

def calculate_score(eval_set, infer_set):
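    # An eval item counts as correct only when the response contains exactly one distinct \boxed{} value and it equals the reference answer.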
assert len(eval_set) == len(infer_set)
score = 0
for index, eval_item in enumerate(eval_set):
output = infer_set[index]
if 'boxed{' in output:
res = extract_boxed_content(output)
res = list(set(res))
res = res[0] if len(res) == 1 else res
if type(res) is list:
continue
true_v = eval_item['output'].split('\n#### ')[-1].strip()
if true_v == res.strip():
score += 1
    return f'{score}/{len(eval_set)}, {round(score / len(eval_set) * 100, 2)}%'

def main(teacher_name, student_name, demo=False, sft_data_path=None):
    # Launch the teacher model
    teacher_model = lazyllm.OnlineChatModule(teacher_name)

# Load and Distill Dataset
train_set_path, eval_set_path = get_dataset()
eval_set = load_data(eval_set_path)
if not sft_data_path:
sft_data_path = distill_dataset(train_set_path, teacher_model, demo)

# Train and Infer
infer_data = [item['instruction'] for item in eval_set]
student_model = lazyllm.TrainableModule(student_name)\
.mode('finetune')\
.trainset(sft_data_path)\
.finetune_method((finetune.llamafactory, {
'learning_rate': 1e-4,
'cutoff_len': 5120,
'max_samples': 20000,
'val_size': 0.01,
'per_device_train_batch_size': 2,
'num_train_epochs': 2.0,
'launcher': launchers.sco(nnode=1, nproc=8, ngpus=8)
}))\
.prompt(dict(system='You are a helpful assistant.', drop_builtin_system=True))\
.deploy_method(deploy.Vllm)
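    # Override the prompt's start-of-assistant marker so generated responses open with a <think> block, mirroring the teacher's R1-style outputs.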
student_model._prompt._soa = '<|im_start|>assistant\n\n<think>'
student_model.evalset(infer_data)
student_model.update()

# Score
    score = calculate_score(eval_set, student_model.eval_result)
print("All Done. Score is: ", score)

if __name__ == '__main__':
teacher_model_name = 'DeepSeek-R1'
student_model_name = 'internlm2-chat-7b'
# Demo
main(teacher_model_name, student_model_name, demo=True)
    # Full run: reuse a previously distilled dataset instead of re-querying the teacher
# sft_data_path = 'path/to/gsm8k_deepseeko1_7148.json'
# main(teacher_model_name, student_model_name, sft_data_path=sft_data_path)
4 changes: 2 additions & 2 deletions lazyllm/components/deploy/vllm.py
@@ -19,9 +19,9 @@ class Vllm(LazyLLMDeployBase):
'stream': False,
'stop': ['<|im_end|>', '<|im_start|>', '</s>', '<|assistant|>', '<|user|>', '<|system|>', '<eos>'],
'skip_special_tokens': False,
-        'temperature': 0.01,
+        'temperature': 0.6,
'top_p': 0.8,
-        'max_tokens': 1024
+        'max_tokens': 4096
}
auto_map = {'tp': 'tensor-parallel-size'}

6 changes: 4 additions & 2 deletions lazyllm/module/module.py
@@ -178,7 +178,7 @@ def _get_eval_tasks(self):
def set_result(x): self.eval_result = x

def parallel_infer():
-            with ThreadPoolExecutor(max_workers=100) as executor:
+            with ThreadPoolExecutor(max_workers=200) as executor:
results = list(executor.map(lambda item: self(**item)
if isinstance(item, dict) else self(item), self._evalset))
return results
@@ -831,10 +831,12 @@ def status(self, task_name: Optional[str] = None):
def prompt(self, prompt: str = '', history: Optional[List[List[str]]] = None):
if self.base_model != '' and prompt == '' and ModelManager.get_model_type(self.base_model) != 'llm':
prompt = None
+        clear_system = isinstance(prompt, dict) and prompt.get('drop_builtin_system')
prompt = super(__class__, self).prompt(prompt, history)._prompt
self._tools = getattr(prompt, "_tools", None)
-        keys = ModelManager.get_model_prompt_keys(self.base_model)
+        keys = ModelManager.get_model_prompt_keys(self.base_model).copy()
if keys:
+            if clear_system: keys['system'] = ''
prompt._set_model_configs(**keys)
for key in ["tool_start_token", "tool_args_token", "tool_end_token"]:
if key in keys: setattr(self, f"_{key}", keys[key])
