diff --git a/modules/chat.py b/modules/chat.py index 4e0bde1c90..85b6e614ab 100644 --- a/modules/chat.py +++ b/modules/chat.py @@ -704,22 +704,22 @@ def load_character(character, name1, name2): return name1, name2, picture, greeting, context -def load_instruction_template(template): +def load_instruction_template(template, current_instruction_template=None, current_chat_template=None): if template == 'None': - return '' + return '', current_chat_template for filepath in [Path(f'instruction-templates/{template}.yaml'), Path('instruction-templates/Alpaca.yaml')]: if filepath.exists(): break else: - return '' + return '', current_chat_template file_contents = open(filepath, 'r', encoding='utf-8').read() data = yaml.safe_load(file_contents) if 'instruction_template' in data: - return data['instruction_template'] + return data['instruction_template'], data['chat_template'] if 'chat_template' in data else current_chat_template else: - return jinja_template_from_old_format(data) + return jinja_template_from_old_format(data), current_chat_template @functools.cache @@ -821,9 +821,10 @@ def generate_character_yaml(name, greeting, context): return yaml.dump(data, sort_keys=False, width=float("inf")) -def generate_instruction_template_yaml(instruction_template): +def generate_instruction_template_yaml(instruction_template, chat_template): data = { - 'instruction_template': instruction_template + 'instruction_template': instruction_template, + 'chat_template': chat_template } return my_yaml_output(data) diff --git a/modules/ui_chat.py b/modules/ui_chat.py index 3193bd6748..49c7423202 100644 --- a/modules/ui_chat.py +++ b/modules/ui_chat.py @@ -316,13 +316,12 @@ def create_event_handlers(): shared.gradio['delete_character'].click(lambda: gr.update(visible=True), None, gradio('character_deleter')) shared.gradio['load_template'].click( - chat.load_instruction_template, gradio('instruction_template'), gradio('instruction_template_str')).then( - lambda: "Select template to load...", None, gradio('instruction_template')) + chat.load_instruction_template, gradio(['instruction_template', 'instruction_template_str', 'chat_template_str']), gradio(['instruction_template_str', 'chat_template_str'])) shared.gradio['save_template'].click( - lambda: 'My Template.yaml', None, gradio('save_filename')).then( + lambda x: x + '.yaml', gradio('instruction_template'), gradio('save_filename')).then( lambda: 'instruction-templates/', None, gradio('save_root')).then( - chat.generate_instruction_template_yaml, gradio('instruction_template_str'), gradio('save_contents')).then( + chat.generate_instruction_template_yaml, gradio(['instruction_template_str', 'chat_template_str']), gradio('save_contents')).then( lambda: gr.update(visible=True), None, gradio('file_saver')) shared.gradio['delete_template'].click(