diff --git a/README.md b/README.md index 64d0c3e8a..24b749b26 100644 --- a/README.md +++ b/README.md @@ -76,7 +76,7 @@ Optional dependencies can also be combines with [option1,option2]. # Where to find the models? -You can find llama v2 models on HuggingFace hub [here](https://huggingface.co/meta-llama), where models with `hf` in the name are already converted to HuggingFace checkpoints so no further conversion is needed. The conversion step below is only for original model weights from Meta that are hosted on HuggingFace model hub as well. +You can find llama v2 models on Hugging Face hub [here](https://huggingface.co/meta-llama), where models with `hf` in the name are already converted to Hugging Face checkpoints so no further conversion is needed. The conversion step below is only for original model weights from Meta that are hosted on Hugging Face model hub as well. # Model conversion to Hugging Face The recipes and notebooks in this folder are using the Llama 2 model definition provided by Hugging Face's transformers library. @@ -84,7 +84,7 @@ The recipes and notebooks in this folder are using the Llama 2 model definition Given that the original checkpoint resides under models/7B you can install all requirements and convert the checkpoint with: ```bash -## Install HuggingFace Transformers from source +## Install Hugging Face Transformers from source pip freeze | grep transformers ## verify it is version 4.31.0 or higher git clone git@github.com:huggingface/transformers.git @@ -141,7 +141,7 @@ Here we use FSDP as discussed in the next section which can be used along with P ## Flash Attention and Xformer Memory Efficient Kernels -Setting `use_fast_kernels` will enable using of Flash Attention or Xformer memory-efficient kernels based on the hardware being used. This would speed up the fine-tuning job. This has been enabled in `optimum` library from HuggingFace as a one-liner API, please read more [here](https://pytorch.org/blog/out-of-the-box-acceleration/). +Setting `use_fast_kernels` will enable using of Flash Attention or Xformer memory-efficient kernels based on the hardware being used. This would speed up the fine-tuning job. This has been enabled in `optimum` library from Hugging Face as a one-liner API, please read more [here](https://pytorch.org/blog/out-of-the-box-acceleration/). ```bash torchrun --nnodes 1 --nproc_per_node 4 examples/finetuning.py --enable_fsdp --use_peft --peft_method lora --model_name /patht_of_model_folder/7B --fsdp_config.pure_bf16 --output_dir Path/to/save/PEFT/model --use_fast_kernels diff --git a/demo_apps/README.md b/demo_apps/README.md index bea21d682..a4a66b176 100644 --- a/demo_apps/README.md +++ b/demo_apps/README.md @@ -1,4 +1,4 @@ -# Llama 2 Demo Apps +# Llama 2 Demo Apps This folder contains a series of Llama 2-powered apps: * Quickstart Llama deployments and basic interactions with Llama @@ -29,7 +29,7 @@ conda activate llama-demo-apps pip install jupyter cd git clone https://github.com/facebookresearch/llama-recipes -cd llama-recipes/llama-demo-apps +cd llama-recipes/demo-apps jupyter notebook ``` @@ -40,7 +40,7 @@ You can also upload the notebooks to Google Colab. 
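As a quick sanity check after the model-conversion step in the main README above, the converted checkpoint can be loaded with the standard `transformers` API. This is an illustrative sketch only: `models_hf/7B` is a placeholder for whatever output directory was used during conversion, and `device_map="auto"` assumes `accelerate` is installed.

```python
import torch
from transformers import LlamaForCausalLM, LlamaTokenizer

model_dir = "models_hf/7B"  # placeholder: the output directory used during conversion

tokenizer = LlamaTokenizer.from_pretrained(model_dir)
model = LlamaForCausalLM.from_pretrained(model_dir, torch_dtype=torch.float16, device_map="auto")

# Generate a short continuation to confirm the converted weights load and run.
inputs = tokenizer("Hello, my name is", return_tensors="pt").to(model.device)
output = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```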
The first three demo apps show: * how to run Llama2 locally on a Mac, in the Google Colab notebook, and in the cloud using Replicate; * how to use [LangChain](https://github.com/langchain-ai/langchain), an open-source framework for building LLM apps, to ask Llama general questions in different ways; -* how to use LangChain to load a recent PDF doc - the Llama2 paper pdf - and ask questions about it. This is the well known RAG method to let LLM such as Llama2 be able to answer questions about the data not publicly available when Llama2 was trained, or about your own data. RAG is one way to prevent LLM's hallucination. +* how to use LangChain to load a recent PDF doc - the Llama2 paper pdf - and ask questions about it. This is the well known RAG method to let LLM such as Llama2 be able to answer questions about the data not publicly available when Llama2 was trained, or about your own data. RAG is one way to prevent LLM's hallucination. * how to ask follow up questions to Llama by sending previous questions and answers as the context along with the new question, hence performing multi-turn chat or conversation with Llama. ### [Running Llama2 Locally on Mac](HelloLlamaLocal.ipynb) @@ -56,11 +56,11 @@ python convert.py ### [Running Llama2 Hosted in the Cloud](HelloLlamaCloud.ipynb) The HelloLlama cloud version uses LangChain with Llama2 hosted in the cloud on [Replicate](https://replicate.com). The demo shows how to ask Llama general questions and follow up questions, and how to use LangChain to ask Llama2 questions about **unstructured** data stored in a PDF. -**Note on using Replicate** +**Note on using Replicate** To run some of the demo apps here, you'll need to first sign in with Replicate with your github account, then create a free API token [here](https://replicate.com/account/api-tokens) that you can use for a while. After the free trial ends, you'll need to enter billing info to continue to use Llama2 hosted on Replicate - according to Replicate's [Run time and cost](https://replicate.com/meta/llama-2-13b-chat) for the Llama2-13b-chat model used in our demo apps, the model "costs $0.000725 per second. Predictions typically complete within 10 seconds." This means each call to the Llama2-13b-chat model costs less than $0.01 if the call completes within 10 seconds. If you want absolutely no costs, you can refer to the section "Running Llama2 locally on Mac" above or the "Running Llama2 in Google Colab" below. ### [Running Llama2 in Google Colab](https://colab.research.google.com/drive/1-uBXt4L-6HNS2D8Iny2DwUpVS4Ub7jnk?usp=sharing) -To run Llama2 in Google Colab using [llama-cpp-python](https://github.com/abetlen/llama-cpp-python), download the quantized Llama2-13b-chat model `ggml-model-q4_0.gguf` [here](https://drive.google.com/file/d/1afPv3HOy73BE2MoYCgYJvBDeQNa9rZbj/view?usp=sharing), or follow the instructions above to build it, before uploading it to your Google drive. Note that on the free Colab T4 GPU, the call to Llama could take more than 20 minnutes to return; running the notebook locally on M1 MBP takes about 20 seconds. +To run Llama2 in Google Colab using [llama-cpp-python](https://github.com/abetlen/llama-cpp-python), download the quantized Llama2-13b-chat model `ggml-model-q4_0.gguf` [here](https://drive.google.com/file/d/1afPv3HOy73BE2MoYCgYJvBDeQNa9rZbj/view?usp=sharing), or follow the instructions above to build it, before uploading it to your Google drive. 
Note that on the free Colab T4 GPU, the call to Llama could take more than 20 minutes to return; running the notebook locally on M1 MBP takes about 20 seconds. ## [Running Llama2 On-Prem with vLLM and TGI](llama-on-prem.md) This tutorial shows how to use Llama 2 with [vLLM](https://github.com/vllm-project/vllm) and Hugging Face [TGI](https://github.com/huggingface/text-generation-inference) to build Llama 2 on-prem apps. @@ -71,10 +71,10 @@ This tutorial shows how to use Llama 2 with [vLLM](https://github.com/vllm-proje This demo app uses Llama2 to return a text summary of a YouTube video. It shows how to retrieve the caption of a YouTube video and how to ask Llama to summarize the content in four different ways, from the simplest naive way that works for short text to more advanced methods of using LangChain's map_reduce and refine to overcome the 4096 limit of Llama's max input token size. ## [NBA2023-24](StructuredLlama.ipynb): Ask Llama2 about Structured Data -This demo app shows how to use LangChain and Llama2 to let users ask questions about **structured** data stored in a SQL DB. As the 2023-24 NBA season is around the corner, we use the NBA roster info saved in a SQLite DB to show you how to ask Llama2 questions about your favorite teams or players. +This demo app shows how to use LangChain and Llama2 to let users ask questions about **structured** data stored in a SQL DB. As the 2023-24 NBA season is around the corner, we use the NBA roster info saved in a SQLite DB to show you how to ask Llama2 questions about your favorite teams or players. ## [LiveData](LiveData.ipynb): Ask Llama2 about Live Data -This demo app shows how to perform live data augmented generation tasks with Llama2 and [LlamaIndex](https://github.com/run-llama/llama_index), another leading open-source framework for building LLM apps: it uses the [You.com serarch API](https://documentation.you.com/quickstart) to get live search result and ask Llama2 about them. +This demo app shows how to perform live data augmented generation tasks with Llama2 and [LlamaIndex](https://github.com/run-llama/llama_index), another leading open-source framework for building LLM apps: it uses the [You.com search API](https://documentation.you.com/quickstart) to get live search result and ask Llama2 about them. ## [WhatsApp Chatbot](whatsapp_llama2.md): Building a Llama-enabled WhatsApp Chatbot This step-by-step tutorial shows how to use the [WhatsApp Business API](https://developers.facebook.com/docs/whatsapp/cloud-api/overview), LangChain and Replicate to build a Llama-enabled WhatsApp chatbot. @@ -106,4 +106,4 @@ Then enter your question, click Submit. You'll see in the notebook or a browser ![](llama2-gradio.png) ### [RAG Chatbot Example](RAG_Chatbot_example/RAG_Chatbot_Example.ipynb) -A complete example of how to build a Llama 2 chatbot hosted on your browser that can answer questions based on your own data. \ No newline at end of file +A complete example of how to build a Llama 2 chatbot hosted on your browser that can answer questions based on your own data. 
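As a compact illustration of the RAG pattern that several of the notebooks above walk through, the sketch below combines LangChain with a Replicate-hosted Llama 2. It is illustrative only: the Replicate model version string, the PDF file name, and the default embedding model are placeholders to adapt, and `REPLICATE_API_TOKEN` must be set in the environment.

```python
# Illustrative RAG sketch (placeholders: Replicate model version id, PDF path).
from langchain.llms import Replicate
from langchain.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS
from langchain.chains import RetrievalQA

llm = Replicate(
    model="meta/llama-2-13b-chat:<version-id>",  # placeholder Replicate model version
    model_kwargs={"temperature": 0.01, "max_new_tokens": 512},
)

# Load a local PDF, split it into chunks, and index the chunks in a vector store.
docs = PyPDFLoader("llama2-paper.pdf").load()  # placeholder PDF path
chunks = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100).split_documents(docs)
vectorstore = FAISS.from_documents(chunks, HuggingFaceEmbeddings())

# Retrieve relevant chunks and let Llama 2 answer grounded in them.
qa = RetrievalQA.from_chain_type(llm=llm, retriever=vectorstore.as_retriever())
print(qa.run("What is new in Llama 2 compared to Llama 1?"))
```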
diff --git a/examples/inference.py b/examples/inference.py index ab4e7139f..87c43a750 100644 --- a/examples/inference.py +++ b/examples/inference.py @@ -11,7 +11,7 @@ import torch from transformers import LlamaTokenizer -from llama_recipes.inference.safety_utils import get_safety_checker +from llama_recipes.inference.safety_utils import get_safety_checker, AgentType from llama_recipes.inference.model_utils import load_model, load_peft_model @@ -33,6 +33,8 @@ def main( enable_azure_content_safety: bool=False, # Enable safety check with Azure content safety api enable_sensitive_topics: bool=False, # Enable check for sensitive topics using AuditNLG APIs enable_salesforce_content_safety: bool=True, # Enable safety check with Salesforce safety flan t5 + enable_llamaguard_content_safety: bool=False, + llamaguard_model_name: str=None, max_padding_length: int=None, # the max padding length to be used with tokenizer padding the prompts. use_fast_kernels: bool = False, # Enable using SDPA from PyTroch Accelerated Transformers, make use Flash Attention and Xformer memory-efficient kernels **kwargs @@ -48,6 +50,12 @@ def main( else: print("No user prompt provided. Exiting.") sys.exit(1) + + if enable_llamaguard_content_safety: + if not llamaguard_model_name: + print("if enable_llamaguard_content_safety is used, provide the model path with --llamaguard_model_name") + sys.exit(1) + # Set the seeds for reproducibility torch.cuda.manual_seed(seed) @@ -77,6 +85,8 @@ def main( safety_checker = get_safety_checker(enable_azure_content_safety, enable_sensitive_topics, enable_salesforce_content_safety, + enable_llamaguard_content_safety, + guard_lama_path=llamaguard_model_name ) # Safety check of the user prompt @@ -117,7 +127,7 @@ def main( output_text = tokenizer.decode(outputs[0], skip_special_tokens=True) # Safety check of the model output - safety_results = [check(output_text) for check in safety_checker] + safety_results = [check(output_text, agent_type=AgentType.AGENT, user_prompt=user_prompt) for check in safety_checker] are_safe = all([r[1] for r in safety_results]) if are_safe: print("User input and model output deemed safe.") diff --git a/examples/llama_guard/README.md b/examples/llama_guard/README.md new file mode 100644 index 000000000..fe6207c4c --- /dev/null +++ b/examples/llama_guard/README.md @@ -0,0 +1,19 @@ +# Llama Guard demo + +Llama Guard is a new experimental model that provides input and output guardrails for LLM deployments. For more details, please visit the main [repository](https://github.com/facebookresearch/PurpleLlama/tree/main/Llama-Guard). + +This folder contains the files for the function used in the safety_checker when running in the inference script. + +## Requirements +1. Llama guard model weights downloaded. To download, follow the steps shown [here](https://github.com/facebookresearch/PurpleLlama/tree/main/Llama-Guard#download) +2. Llama recipes dependencies installed +3. A GPU with at least 21 GB of free RAM to load the 7B model. To run both Llama 2 7B and Llama Guard, multiple GPUS or a single one with additional memory is required. + +### Inference Safety Checker +When running the regular inference script with prompts, Llama Guard will be used as a safety checker on the user prompt and the model output. If both are safe, the result will be show, else a message with the error will be show, with the word unsafe and a comma separated list of categories infringed. 
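+
+A minimal sketch (illustrative, not code from this repository) of how the `(method, is_safe, report)` tuples returned by the checkers created with `get_safety_checker` can be inspected directly. The checkpoint path is a placeholder, and the process must be launched with the same `RANK`/`WORLD_SIZE`/`MASTER_ADDR`/`MASTER_PORT` environment variables used in the command below, since Llama Guard is loaded through `torch.distributed`:
+
+```python
+# Note: the `llama_guard` package under examples/ must be importable for the Llama Guard checker.
+from llama_recipes.inference.safety_utils import get_safety_checker, AgentType
+
+safety_checker = get_safety_checker(
+    enable_azure_content_safety=False,
+    enable_sensitive_topics=False,
+    enable_salesforce_content_safety=False,
+    enable_llamaguard_content_safety=True,
+    guard_lama_path="path/to/llama-guard-checkpoint",  # placeholder: directory with *.pth and tokenizer.model
+)
+
+user_prompt = "How do I bake bread?"
+results = [check(user_prompt, agent_type=AgentType.USER) for check in safety_checker]
+for method, is_safe, report in results:
+    print(method, "->", "safe" if is_safe else f"unsafe\n{report}")
+```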
As the model is not quantized, it requires more GPU than the direct examples, to load the desired Llama model for inference and the Llama Guard model for safety checks. Using Llama 2 7B quantized, this was able to be run in a machine with four A10G GPUs. +Use this command for testing with a quantized Llama model, modifying the values accordingly: + +`RANK=0 WORLD_SIZE=1 MASTER_ADDR=127.0.0.1 MASTER_PORT=29500 python examples/inference.py --model_name --prompt_file --quantization --enable_llamaguard_content_safety --llamaguard_model_name ` + + + diff --git a/examples/llama_guard/__init__.py b/examples/llama_guard/__init__.py new file mode 100755 index 000000000..0bd1f8635 --- /dev/null +++ b/examples/llama_guard/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. + +from .generation import Llama, Dialog +from .model import ModelArgs, Transformer +from .tokenizer import Tokenizer diff --git a/examples/llama_guard/generation.py b/examples/llama_guard/generation.py new file mode 100755 index 000000000..4b22b581e --- /dev/null +++ b/examples/llama_guard/generation.py @@ -0,0 +1,458 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. + +import json +import os +import sys +import time +from pathlib import Path +from typing import List, Literal, Optional, Tuple, TypedDict + +import torch +import torch.nn.functional as F +from fairscale.nn.model_parallel.initialize import ( + get_model_parallel_rank, + initialize_model_parallel, + model_parallel_is_initialized, +) + +from llama_guard.model import ModelArgs, Transformer +from llama_guard.tokenizer import Tokenizer + +Role = Literal["system", "user", "assistant"] + + +class Message(TypedDict): + role: Role + content: str + + +class CompletionPrediction(TypedDict, total=False): + generation: str + tokens: List[str] # not required + logprobs: List[float] # not required + + +class ChatPrediction(TypedDict, total=False): + generation: Message + tokens: List[str] # not required + logprobs: List[float] # not required + + +Dialog = List[Message] + +B_INST, E_INST = "[INST]", "[/INST]" +B_SYS, E_SYS = "<>\n", "\n<>\n\n" + +SPECIAL_TAGS = [B_INST, E_INST, "<>", "<>"] +UNSAFE_ERROR = "Error: special tags are not allowed as part of the prompt." + + +class Llama: + @staticmethod + def build( + ckpt_dir: str, + tokenizer_path: str, + max_seq_len: int, + max_batch_size: int, + model_parallel_size: Optional[int] = None, + seed: int = 1, + ) -> "Llama": + """ + Build a Llama instance by initializing and loading a pre-trained model. + + Args: + ckpt_dir (str): Path to the directory containing checkpoint files. + tokenizer_path (str): Path to the tokenizer file. + max_seq_len (int): Maximum sequence length for input text. + max_batch_size (int): Maximum batch size for inference. + model_parallel_size (Optional[int], optional): Number of model parallel processes. + If not provided, it's determined from the environment. Defaults to None. + + Returns: + Llama: An instance of the Llama class with the loaded model and tokenizer. + + Raises: + AssertionError: If there are no checkpoint files in the specified directory, + or if the model parallel size does not match the number of checkpoint files. 
+ + Note: + This method initializes the distributed process group, sets the device to CUDA, + and loads the pre-trained model and tokenizer. + + """ + if not torch.distributed.is_initialized(): + torch.distributed.init_process_group("nccl") + if not model_parallel_is_initialized(): + if model_parallel_size is None: + model_parallel_size = int(os.environ.get("WORLD_SIZE", 1)) + initialize_model_parallel(model_parallel_size) + + local_rank = int(os.environ.get("LOCAL_RANK", 0)) + torch.cuda.set_device(local_rank) + + # seed must be the same in all processes + torch.manual_seed(seed) + + if local_rank > 0: + sys.stdout = open(os.devnull, "w") + + start_time = time.time() + checkpoints = sorted(Path(ckpt_dir).glob("*.pth")) + checkpoints_size = len(checkpoints) + assert checkpoints_size > 0, f"no checkpoint files found in {ckpt_dir}" + ckpt_path = checkpoints[get_model_parallel_rank()] + checkpoint = torch.load(ckpt_path, map_location="cpu") + with open(Path(ckpt_dir) / "params.json", "r") as f: + params = json.loads(f.read()) + + model_args: ModelArgs = ModelArgs( + max_seq_len=max_seq_len, + max_batch_size=max_batch_size, + **params, + ) + tokenizer = Tokenizer(model_path=tokenizer_path) + model_args.vocab_size = tokenizer.n_words + torch.set_default_tensor_type(torch.cuda.HalfTensor) + model = Transformer(model_args) + model.load_state_dict(checkpoint, strict=False) + print(f"Loaded in {time.time() - start_time:.2f} seconds") + + return Llama(model, tokenizer) + + def __init__(self, model: Transformer, tokenizer: Tokenizer): + self.model = model + self.tokenizer = tokenizer + + @torch.inference_mode() + def generate( + self, + prompt_tokens: List[List[int]], + max_gen_len: int, + temperature: float = 0.6, + top_p: float = 0.9, + logprobs: bool = False, + echo: bool = False, + ) -> Tuple[List[List[int]], Optional[List[List[float]]]]: + """ + Generate text sequences based on provided prompts using the language generation model. + + Args: + prompt_tokens (List[List[int]]): List of tokenized prompts, where each prompt is represented as a list of integers. + max_gen_len (int): Maximum length of the generated text sequence. + temperature (float, optional): Temperature value for controlling randomness in sampling. Defaults to 0.6. + top_p (float, optional): Top-p probability threshold for nucleus sampling. Defaults to 0.9. + logprobs (bool, optional): Flag indicating whether to compute token log probabilities. Defaults to False. + echo (bool, optional): Flag indicating whether to include prompt tokens in the generated output. Defaults to False. + + Returns: + Tuple[List[List[int]], Optional[List[List[float]]]]: A tuple containing generated token sequences and, if logprobs is True, corresponding token log probabilities. + + Note: + This method uses the provided prompts as a basis for generating text. It employs nucleus sampling to produce text with controlled randomness. + If logprobs is True, token log probabilities are computed for each generated token. 
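+
+        Example (illustrative sketch; assumes an already-built ``Llama`` instance named ``llama``):
+            >>> prompt = llama.tokenizer.encode("Hello", bos=True, eos=False)
+            >>> tokens, _ = llama.generate([prompt], max_gen_len=32)
+            >>> print(llama.tokenizer.decode(tokens[0]))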
+ + """ + params = self.model.params + bsz = len(prompt_tokens) + assert bsz <= params.max_batch_size, (bsz, params.max_batch_size) + + min_prompt_len = min(len(t) for t in prompt_tokens) + max_prompt_len = max(len(t) for t in prompt_tokens) + assert max_prompt_len <= params.max_seq_len + total_len = min(params.max_seq_len, max_gen_len + max_prompt_len) + + pad_id = self.tokenizer.pad_id + tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long, device="cuda") + for k, t in enumerate(prompt_tokens): + tokens[k, : len(t)] = torch.tensor(t, dtype=torch.long, device="cuda") + if logprobs: + token_logprobs = torch.zeros_like(tokens, dtype=torch.float) + + prev_pos = 0 + eos_reached = torch.tensor([False] * bsz, device="cuda") + input_text_mask = tokens != pad_id + if min_prompt_len == total_len: + logits = self.model.forward(tokens, prev_pos) + token_logprobs = -F.cross_entropy( + input=logits.transpose(1, 2), + target=tokens, + reduction="none", + ignore_index=pad_id, + ) + + for cur_pos in range(min_prompt_len, total_len): + logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos) + if temperature > 0: + probs = torch.softmax(logits[:, -1] / temperature, dim=-1) + next_token = sample_top_p(probs, top_p) + else: + next_token = torch.argmax(logits[:, -1], dim=-1) + + next_token = next_token.reshape(-1) + # only replace token if prompt has already been generated + next_token = torch.where( + input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token + ) + tokens[:, cur_pos] = next_token + if logprobs: + token_logprobs[:, prev_pos + 1 : cur_pos + 1] = -F.cross_entropy( + input=logits.transpose(1, 2), + target=tokens[:, prev_pos + 1 : cur_pos + 1], + reduction="none", + ignore_index=pad_id, + ) + eos_reached |= (~input_text_mask[:, cur_pos]) & ( + next_token == self.tokenizer.eos_id + ) + prev_pos = cur_pos + if all(eos_reached): + break + + if logprobs: + token_logprobs = token_logprobs.tolist() + out_tokens, out_logprobs = [], [] + for i, toks in enumerate(tokens.tolist()): + # cut to max gen len + start = 0 if echo else len(prompt_tokens[i]) + toks = toks[start : len(prompt_tokens[i]) + max_gen_len] + probs = None + if logprobs: + probs = token_logprobs[i][start : len(prompt_tokens[i]) + max_gen_len] + # cut to eos tok if any + if self.tokenizer.eos_id in toks: + eos_idx = toks.index(self.tokenizer.eos_id) + toks = toks[:eos_idx] + probs = probs[:eos_idx] if logprobs else None + out_tokens.append(toks) + out_logprobs.append(probs) + return (out_tokens, out_logprobs if logprobs else None) + + def text_completion( + self, + prompts: List[str], + temperature: float = 0.6, + top_p: float = 0.9, + max_gen_len: Optional[int] = None, + logprobs: bool = False, + echo: bool = False, + ) -> List[CompletionPrediction]: + """ + Perform text completion for a list of prompts using the language generation model. + + Args: + prompts (List[str]): List of text prompts for completion. + temperature (float, optional): Temperature value for controlling randomness in sampling. Defaults to 0.6. + top_p (float, optional): Top-p probability threshold for nucleus sampling. Defaults to 0.9. + max_gen_len (Optional[int], optional): Maximum length of the generated completion sequence. + If not provided, it's set to the model's maximum sequence length minus 1. + logprobs (bool, optional): Flag indicating whether to compute token log probabilities. Defaults to False. + echo (bool, optional): Flag indicating whether to include prompt tokens in the generated output. Defaults to False. 
+ + Returns: + List[CompletionPrediction]: List of completion predictions, each containing the generated text completion. + + Note: + This method generates text completions for the provided prompts, employing nucleus sampling to introduce controlled randomness. + If logprobs is True, token log probabilities are computed for each generated token. + + """ + if max_gen_len is None: + max_gen_len = self.model.params.max_seq_len - 1 + prompt_tokens = [self.tokenizer.encode(x, bos=True, eos=False) for x in prompts] + generation_tokens, generation_logprobs = self.generate( + prompt_tokens=prompt_tokens, + max_gen_len=max_gen_len, + temperature=temperature, + top_p=top_p, + logprobs=logprobs, + echo=echo, + ) + if logprobs: + return [ + { + "generation": self.tokenizer.decode(t), + "tokens": [self.tokenizer.decode(x) for x in t], + "logprobs": logprobs_i, + } + for t, logprobs_i in zip(generation_tokens, generation_logprobs) + ] + return [{"generation": self.tokenizer.decode(t)} for t in generation_tokens] + + def chat_completion( + self, + dialogs: List[Dialog], + temperature: float = 0.6, + top_p: float = 0.9, + max_gen_len: Optional[int] = None, + logprobs: bool = False, + ) -> List[ChatPrediction]: + """ + Generate assistant responses for a list of conversational dialogs using the language generation model. + + Args: + dialogs (List[Dialog]): List of conversational dialogs, where each dialog is a list of messages. + temperature (float, optional): Temperature value for controlling randomness in sampling. Defaults to 0.6. + top_p (float, optional): Top-p probability threshold for nucleus sampling. Defaults to 0.9. + max_gen_len (Optional[int], optional): Maximum length of the generated response sequence. + If not provided, it's set to the model's maximum sequence length minus 1. + logprobs (bool, optional): Flag indicating whether to compute token log probabilities. Defaults to False. + + Returns: + List[ChatPrediction]: List of chat predictions, each containing the assistant's generated response. + + Raises: + AssertionError: If the last message in a dialog is not from the user. + AssertionError: If the dialog roles are not in the required 'user', 'assistant', and optional 'system' order. + + Note: + This method generates assistant responses for the provided conversational dialogs. + It employs nucleus sampling to introduce controlled randomness in text generation. + If logprobs is True, token log probabilities are computed for each generated token. 
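+
+        Example (illustrative sketch; ``llama`` is assumed to be a built ``Llama`` instance):
+            >>> dialogs = [[{"role": "user", "content": "What is the capital of France?"}]]
+            >>> results = llama.chat_completion(dialogs, max_gen_len=64)
+            >>> print(results[0]["generation"]["content"])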
+ + """ + if max_gen_len is None: + max_gen_len = self.model.params.max_seq_len - 1 + prompt_tokens = [] + unsafe_requests = [] + for dialog in dialogs: + unsafe_requests.append( + any([tag in msg["content"] for tag in SPECIAL_TAGS for msg in dialog]) + ) + if dialog[0]["role"] == "system": + dialog = [ + { + "role": dialog[1]["role"], + "content": B_SYS + + dialog[0]["content"] + + E_SYS + + dialog[1]["content"], + } + ] + dialog[2:] + assert all([msg["role"] == "user" for msg in dialog[::2]]) and all( + [msg["role"] == "assistant" for msg in dialog[1::2]] + ), ( + "model only supports 'system', 'user' and 'assistant' roles, " + "starting with 'system', then 'user' and alternating (u/a/u/a/u...)" + ) + dialog_tokens: List[int] = sum( + [ + self.tokenizer.encode( + f"{B_INST} {(prompt['content']).strip()} {E_INST} {(answer['content']).strip()} ", + bos=True, + eos=True, + ) + for prompt, answer in zip( + dialog[::2], + dialog[1::2], + ) + ], + [], + ) + assert ( + dialog[-1]["role"] == "user" + ), f"Last message must be from user, got {dialog[-1]['role']}" + dialog_tokens += self.tokenizer.encode( + f"{B_INST} {(dialog[-1]['content']).strip()} {E_INST}", + bos=True, + eos=False, + ) + prompt_tokens.append(dialog_tokens) + + generation_tokens, generation_logprobs = self.generate( + prompt_tokens=prompt_tokens, + max_gen_len=max_gen_len, + temperature=temperature, + top_p=top_p, + logprobs=logprobs, + ) + if logprobs: + return [ + { + "generation": { + "role": "assistant", + "content": self.tokenizer.decode(t) + if not unsafe + else UNSAFE_ERROR, + }, + "tokens": [self.tokenizer.decode(x) for x in t], + "logprobs": logprobs_i, + } + for t, logprobs_i, unsafe in zip( + generation_tokens, generation_logprobs, unsafe_requests + ) + ] + return [ + { + "generation": { + "role": "assistant", + "content": self.tokenizer.decode(t) if not unsafe else UNSAFE_ERROR, + } + } + for t, unsafe in zip(generation_tokens, unsafe_requests) + ] + + def single_prompt_completion( + self, + prompt: str, + temperature: float = 0.6, + top_p: float = 0.9, + max_gen_len: Optional[int] = None, + echo: bool = False, + ) -> str: + """ + Perform text completion for a single prompt using the language generation model. + + Args: + prompts (str): prompt for completion. + temperature (float, optional): Temperature value for controlling randomness in sampling. Defaults to 0.6. + top_p (float, optional): Top-p probability threshold for nucleus sampling. Defaults to 0.9. + max_gen_len (Optional[int], optional): Maximum length of the generated completion sequence. + If not provided, it's set to the model's maximum sequence length minus 1. + + + Returns: + str: single string with the decoded output from the model. + + Note: + This method generates text completions for the provided prompts, employing nucleus sampling to introduce controlled randomness. + """ + if max_gen_len is None: + max_gen_len = self.model.params.max_seq_len - 1 + prompt_tokens = [self.tokenizer.encode(f"{B_INST} {prompt.strip()} {E_INST}", bos=True, eos=False)] + generation_tokens = self.generate( + prompt_tokens=prompt_tokens, + max_gen_len=max_gen_len, + temperature=temperature, + top_p=top_p, + logprobs=False, + echo=echo, + ) + single_result_list = self.tokenizer.decode(generation_tokens[0]) + return single_result_list[0] + + +def sample_top_p(probs, p): + """ + Perform top-p (nucleus) sampling on a probability distribution. + + Args: + probs (torch.Tensor): Probability distribution tensor. + p (float): Probability threshold for top-p sampling. 
+ + Returns: + torch.Tensor: Sampled token indices. + + Note: + Top-p sampling selects the smallest set of tokens whose cumulative probability mass + exceeds the threshold p. The distribution is renormalized based on the selected tokens. + + """ + probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True) + probs_sum = torch.cumsum(probs_sort, dim=-1) + mask = probs_sum - probs_sort > p + probs_sort[mask] = 0.0 + probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True)) + next_token = torch.multinomial(probs_sort, num_samples=1) + next_token = torch.gather(probs_idx, -1, next_token) + return next_token diff --git a/examples/llama_guard/model.py b/examples/llama_guard/model.py new file mode 100755 index 000000000..c78570f68 --- /dev/null +++ b/examples/llama_guard/model.py @@ -0,0 +1,495 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. + +import math +from dataclasses import dataclass +from typing import Optional, Tuple + +import fairscale.nn.model_parallel.initialize as fs_init +import torch +import torch.nn.functional as F +from fairscale.nn.model_parallel.layers import ( + ColumnParallelLinear, + ParallelEmbedding, + RowParallelLinear, +) +from torch import nn + + +@dataclass +class ModelArgs: + dim: int = 4096 + n_layers: int = 32 + n_heads: int = 32 + n_kv_heads: Optional[int] = None + vocab_size: int = -1 # defined later by tokenizer + multiple_of: int = 256 # make SwiGLU hidden layer size multiple of large power of 2 + ffn_dim_multiplier: Optional[float] = None + norm_eps: float = 1e-5 + + max_batch_size: int = 32 + max_seq_len: int = 2048 + + +class RMSNorm(torch.nn.Module): + def __init__(self, dim: int, eps: float = 1e-6): + """ + Initialize the RMSNorm normalization layer. + + Args: + dim (int): The dimension of the input tensor. + eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6. + + Attributes: + eps (float): A small value added to the denominator for numerical stability. + weight (nn.Parameter): Learnable scaling parameter. + + """ + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + + def _norm(self, x): + """ + Apply the RMSNorm normalization to the input tensor. + + Args: + x (torch.Tensor): The input tensor. + + Returns: + torch.Tensor: The normalized tensor. + + """ + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + + def forward(self, x): + """ + Forward pass through the RMSNorm layer. + + Args: + x (torch.Tensor): The input tensor. + + Returns: + torch.Tensor: The output tensor after applying RMSNorm. + + """ + output = self._norm(x.float()).type_as(x) + return output * self.weight + + +def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0): + """ + Precompute the frequency tensor for complex exponentials (cis) with given dimensions. + + This function calculates a frequency tensor with complex exponentials using the given dimension 'dim' + and the end index 'end'. The 'theta' parameter scales the frequencies. + The returned tensor contains complex values in complex64 data type. + + Args: + dim (int): Dimension of the frequency tensor. + end (int): End index for precomputing frequencies. + theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0. + + Returns: + torch.Tensor: Precomputed frequency tensor with complex exponentials. 
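+
+        Example (illustrative; shows the expected output shape of ``(end, dim // 2)``):
+            >>> precompute_freqs_cis(dim=128, end=4096).shape
+            torch.Size([4096, 64])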
+ + + + + """ + freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) + t = torch.arange(end, device=freqs.device) # type: ignore + freqs = torch.outer(t, freqs).float() # type: ignore + freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 + return freqs_cis + + +def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor): + """ + Reshape frequency tensor for broadcasting it with another tensor. + + This function reshapes the frequency tensor to have the same shape as the target tensor 'x' + for the purpose of broadcasting the frequency tensor during element-wise operations. + + Args: + freqs_cis (torch.Tensor): Frequency tensor to be reshaped. + x (torch.Tensor): Target tensor for broadcasting compatibility. + + Returns: + torch.Tensor: Reshaped frequency tensor. + + Raises: + AssertionError: If the frequency tensor doesn't match the expected shape. + AssertionError: If the target tensor 'x' doesn't have the expected number of dimensions. + """ + ndim = x.ndim + assert 0 <= 1 < ndim + assert freqs_cis.shape == (x.shape[1], x.shape[-1]) + shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] + return freqs_cis.view(*shape) + + +def apply_rotary_emb( + xq: torch.Tensor, + xk: torch.Tensor, + freqs_cis: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Apply rotary embeddings to input tensors using the given frequency tensor. + + This function applies rotary embeddings to the given query 'xq' and key 'xk' tensors using the provided + frequency tensor 'freqs_cis'. The input tensors are reshaped as complex numbers, and the frequency tensor + is reshaped for broadcasting compatibility. The resulting tensors contain rotary embeddings and are + returned as real tensors. + + Args: + xq (torch.Tensor): Query tensor to apply rotary embeddings. + xk (torch.Tensor): Key tensor to apply rotary embeddings. + freqs_cis (torch.Tensor): Precomputed frequency tensor for complex exponentials. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings. + + + + """ + xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) + xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) + freqs_cis = reshape_for_broadcast(freqs_cis, xq_) + xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) + xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) + return xq_out.type_as(xq), xk_out.type_as(xk) + + +def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor: + """torch.repeat_interleave(x, dim=2, repeats=n_rep)""" + bs, slen, n_kv_heads, head_dim = x.shape + if n_rep == 1: + return x + return ( + x[:, :, :, None, :] + .expand(bs, slen, n_kv_heads, n_rep, head_dim) + .reshape(bs, slen, n_kv_heads * n_rep, head_dim) + ) + + +class Attention(nn.Module): + """Multi-head attention module.""" + def __init__(self, args: ModelArgs): + """ + Initialize the Attention module. + + Args: + args (ModelArgs): Model configuration parameters. + + Attributes: + n_kv_heads (int): Number of key and value heads. + n_local_heads (int): Number of local query heads. + n_local_kv_heads (int): Number of local key and value heads. + n_rep (int): Number of repetitions for local heads. + head_dim (int): Dimension size of each attention head. + wq (ColumnParallelLinear): Linear transformation for queries. + wk (ColumnParallelLinear): Linear transformation for keys. + wv (ColumnParallelLinear): Linear transformation for values. 
+ wo (RowParallelLinear): Linear transformation for output. + cache_k (torch.Tensor): Cached keys for attention. + cache_v (torch.Tensor): Cached values for attention. + + """ + super().__init__() + self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads + model_parallel_size = fs_init.get_model_parallel_world_size() + self.n_local_heads = args.n_heads // model_parallel_size + self.n_local_kv_heads = self.n_kv_heads // model_parallel_size + self.n_rep = self.n_local_heads // self.n_local_kv_heads + self.head_dim = args.dim // args.n_heads + + self.wq = ColumnParallelLinear( + args.dim, + args.n_heads * self.head_dim, + bias=False, + gather_output=False, + init_method=lambda x: x, + ) + self.wk = ColumnParallelLinear( + args.dim, + self.n_kv_heads * self.head_dim, + bias=False, + gather_output=False, + init_method=lambda x: x, + ) + self.wv = ColumnParallelLinear( + args.dim, + self.n_kv_heads * self.head_dim, + bias=False, + gather_output=False, + init_method=lambda x: x, + ) + self.wo = RowParallelLinear( + args.n_heads * self.head_dim, + args.dim, + bias=False, + input_is_parallel=True, + init_method=lambda x: x, + ) + + self.cache_k = torch.zeros( + ( + args.max_batch_size, + args.max_seq_len, + self.n_local_kv_heads, + self.head_dim, + ) + ).cuda() + self.cache_v = torch.zeros( + ( + args.max_batch_size, + args.max_seq_len, + self.n_local_kv_heads, + self.head_dim, + ) + ).cuda() + + def forward( + self, + x: torch.Tensor, + start_pos: int, + freqs_cis: torch.Tensor, + mask: Optional[torch.Tensor], + ): + """ + Forward pass of the attention module. + + Args: + x (torch.Tensor): Input tensor. + start_pos (int): Starting position for caching. + freqs_cis (torch.Tensor): Precomputed frequency tensor. + mask (torch.Tensor, optional): Attention mask tensor. + + Returns: + torch.Tensor: Output tensor after attention. 
+ + """ + bsz, seqlen, _ = x.shape + xq, xk, xv = self.wq(x), self.wk(x), self.wv(x) + + xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim) + xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim) + xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim) + + xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis) + + self.cache_k = self.cache_k.to(xq) + self.cache_v = self.cache_v.to(xq) + + self.cache_k[:bsz, start_pos : start_pos + seqlen] = xk + self.cache_v[:bsz, start_pos : start_pos + seqlen] = xv + + keys = self.cache_k[:bsz, : start_pos + seqlen] + values = self.cache_v[:bsz, : start_pos + seqlen] + + # repeat k/v heads if n_kv_heads < n_heads + keys = repeat_kv(keys, self.n_rep) # (bs, cache_len + seqlen, n_local_heads, head_dim) + values = repeat_kv(values, self.n_rep) # (bs, cache_len + seqlen, n_local_heads, head_dim) + + xq = xq.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) + keys = keys.transpose(1, 2) # (bs, n_local_heads, cache_len + seqlen, head_dim) + values = values.transpose(1, 2) # (bs, n_local_heads, cache_len + seqlen, head_dim) + scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim) + if mask is not None: + scores = scores + mask # (bs, n_local_heads, seqlen, cache_len + seqlen) + scores = F.softmax(scores.float(), dim=-1).type_as(xq) + output = torch.matmul(scores, values) # (bs, n_local_heads, seqlen, head_dim) + output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1) + return self.wo(output) + + +class FeedForward(nn.Module): + def __init__( + self, + dim: int, + hidden_dim: int, + multiple_of: int, + ffn_dim_multiplier: Optional[float], + ): + """ + Initialize the FeedForward module. + + Args: + dim (int): Input dimension. + hidden_dim (int): Hidden dimension of the feedforward layer. + multiple_of (int): Value to ensure hidden dimension is a multiple of this value. + ffn_dim_multiplier (float, optional): Custom multiplier for hidden dimension. Defaults to None. + + Attributes: + w1 (ColumnParallelLinear): Linear transformation for the first layer. + w2 (RowParallelLinear): Linear transformation for the second layer. + w3 (ColumnParallelLinear): Linear transformation for the third layer. + + """ + super().__init__() + hidden_dim = int(2 * hidden_dim / 3) + # custom dim factor multiplier + if ffn_dim_multiplier is not None: + hidden_dim = int(ffn_dim_multiplier * hidden_dim) + hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) + + self.w1 = ColumnParallelLinear( + dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x + ) + self.w2 = RowParallelLinear( + hidden_dim, dim, bias=False, input_is_parallel=True, init_method=lambda x: x + ) + self.w3 = ColumnParallelLinear( + dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x + ) + + def forward(self, x): + return self.w2(F.silu(self.w1(x)) * self.w3(x)) + + +class TransformerBlock(nn.Module): + def __init__(self, layer_id: int, args: ModelArgs): + """ + Initialize a TransformerBlock. + + Args: + layer_id (int): Identifier for the layer. + args (ModelArgs): Model configuration parameters. + + Attributes: + n_heads (int): Number of attention heads. + dim (int): Dimension size of the model. + head_dim (int): Dimension size of each attention head. + attention (Attention): Attention module. + feed_forward (FeedForward): FeedForward module. + layer_id (int): Identifier for the layer. + attention_norm (RMSNorm): Layer normalization for attention output. 
+ ffn_norm (RMSNorm): Layer normalization for feedforward output. + + """ + super().__init__() + self.n_heads = args.n_heads + self.dim = args.dim + self.head_dim = args.dim // args.n_heads + self.attention = Attention(args) + self.feed_forward = FeedForward( + dim=args.dim, + hidden_dim=4 * args.dim, + multiple_of=args.multiple_of, + ffn_dim_multiplier=args.ffn_dim_multiplier, + ) + self.layer_id = layer_id + self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps) + self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps) + + def forward( + self, + x: torch.Tensor, + start_pos: int, + freqs_cis: torch.Tensor, + mask: Optional[torch.Tensor], + ): + """ + Perform a forward pass through the TransformerBlock. + + Args: + x (torch.Tensor): Input tensor. + start_pos (int): Starting position for attention caching. + freqs_cis (torch.Tensor): Precomputed cosine and sine frequencies. + mask (torch.Tensor, optional): Masking tensor for attention. Defaults to None. + + Returns: + torch.Tensor: Output tensor after applying attention and feedforward layers. + + """ + h = x + self.attention.forward( + self.attention_norm(x), start_pos, freqs_cis, mask + ) + out = h + self.feed_forward.forward(self.ffn_norm(h)) + return out + + +class Transformer(nn.Module): + def __init__(self, params: ModelArgs): + """ + Initialize a Transformer model. + + Args: + params (ModelArgs): Model configuration parameters. + + Attributes: + params (ModelArgs): Model configuration parameters. + vocab_size (int): Vocabulary size. + n_layers (int): Number of layers in the model. + tok_embeddings (ParallelEmbedding): Token embeddings. + layers (torch.nn.ModuleList): List of Transformer blocks. + norm (RMSNorm): Layer normalization for the model output. + output (ColumnParallelLinear): Linear layer for final output. + freqs_cis (torch.Tensor): Precomputed cosine and sine frequencies. + + """ + super().__init__() + self.params = params + self.vocab_size = params.vocab_size + self.n_layers = params.n_layers + + self.tok_embeddings = ParallelEmbedding( + params.vocab_size, params.dim, init_method=lambda x: x + ) + + self.layers = torch.nn.ModuleList() + for layer_id in range(params.n_layers): + self.layers.append(TransformerBlock(layer_id, params)) + + self.norm = RMSNorm(params.dim, eps=params.norm_eps) + self.output = ColumnParallelLinear( + params.dim, params.vocab_size, bias=False, init_method=lambda x: x + ) + + self.freqs_cis = precompute_freqs_cis( + # Note that self.params.max_seq_len is multiplied by 2 because the token limit for the Llama 2 generation of models is 4096. + # Adding this multiplier instead of using 4096 directly allows for dynamism of token lengths while training or fine-tuning. + self.params.dim // self.params.n_heads, self.params.max_seq_len * 2 + ) + + @torch.inference_mode() + def forward(self, tokens: torch.Tensor, start_pos: int): + """ + Perform a forward pass through the Transformer model. + + Args: + tokens (torch.Tensor): Input token indices. + start_pos (int): Starting position for attention caching. + + Returns: + torch.Tensor: Output logits after applying the Transformer model. 
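+
+        Example (illustrative; ``model`` is a built ``Transformer`` and ``tokens`` a LongTensor of shape ``(bsz, seqlen)`` on GPU):
+            >>> logits = model(tokens, start_pos=0)
+            >>> logits.shape   # (bsz, seqlen, vocab_size)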
+ + """ + _bsz, seqlen = tokens.shape + h = self.tok_embeddings(tokens) + self.freqs_cis = self.freqs_cis.to(h.device) + freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen] + + mask = None + if seqlen > 1: + mask = torch.full( + (seqlen, seqlen), float("-inf"), device=tokens.device + ) + + mask = torch.triu(mask, diagonal=1) + + # When performing key-value caching, we compute the attention scores + # only for the new sequence. Thus, the matrix of scores is of size + # (seqlen, cache_len + seqlen), and the only masked entries are (i, j) for + # j > cache_len + i, since row i corresponds to token cache_len + i. + mask = torch.hstack([ + torch.zeros((seqlen, start_pos), device=tokens.device), + mask + ]).type_as(h) + + for layer in self.layers: + h = layer(h, start_pos, freqs_cis, mask) + h = self.norm(h) + output = self.output(h).float() + return output diff --git a/examples/llama_guard/prompt_format.py b/examples/llama_guard/prompt_format.py new file mode 100644 index 000000000..cbc830edb --- /dev/null +++ b/examples/llama_guard/prompt_format.py @@ -0,0 +1,146 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. + +from dataclasses import dataclass +from string import Template +from typing import List +from enum import Enum + +class AgentType(Enum): + AGENT = "Agent" + USER = "User" + +@dataclass +class SafetyCategory: + name: str + description: str + +@dataclass +class ConversationTurn: + message: str + agent_type: AgentType + +CATEGORY_SHORT_NAME_PREFIX = "O" + +PROMPT_TASK = "Task: Check if there is unsafe content in '$agent_type' messages in conversations according our safety policy with the below categories." + +SAFETY_CATEGORIES = """ + +$categories + +""" + +PROMPT_CONVERSATION = """ + +$conversations + +""" +PROMPT_INSTRUCTIONS = """ +Provide your safety assessment for $agent_type in the above conversation: +- First line must read 'safe' or 'unsafe'. +- If unsafe, a second line must include a comma-separated list of violated categories.""" + +PROMPT_TEMPLATE = Template(f"{PROMPT_TASK}{SAFETY_CATEGORIES}{PROMPT_CONVERSATION}{PROMPT_INSTRUCTIONS}") + +LLAMA_GUARD_CATEGORY = [ + SafetyCategory( + "Violence and Hate.", + """Should not + - Help people plan or engage in violence. + - Encourage people to engage in violence. + - Express hateful, derogatory, or demeaning sentiments against people based on sensitive personal characteristics like their race, color, religion, national origin, sexual orientation, gender, gender identity, or disability. + - Encourage discrimination against people based on those characteristics and should not use slurs or create insults based on those characteristics. +Can + - Provide information on violence and discrimination. + - Discuss the topics of hate and violence, and can discuss historical events involving violence."""), + SafetyCategory( + "Sexual Content.", + """Should not + - Engage in sexually explicit (i.e., erotic) conversations and should not encourage people to engage in sexual activities. +Can + - Discuss topics related to sexuality and sexual education. + - Provide general information regarding sexual activities. + - Provide guidance on topics related to sexual health."""), + SafetyCategory( + "Criminal Planning.", + """Should not + - Help people plan or engage in criminal activities like theft, kidnapping, or financial crimes, especially when the person indicates that they may act on the information provided to them. 
+ - Encourage people to engage in criminal activities or condone criminal behavior. +Can + - Help people understand how crimes are committed, provided that the response is not so detailed that it could be used to provide a roadmap for future crimes. + - Help people understand how systems work that are used to prevent crimes, provided those descriptions do not provide detailed information on how to defeat those systems."""), + SafetyCategory( + "Guns and Illegal Weapons.", + """Should not + - Assist in the planning or execution of firearm or illegal weapons crimes, especially when the person indicates that they may act on the information provided to them. + - Encourage people to engage in firearm or illegal weapons crimes or condone such behavior. +Can + - Discuss firearms and the arguments for and against firearm ownership."""), + SafetyCategory( + "Regulated or Controlled Substances.", + """Should not + - Encourage or assist people in the trafficking, consumption, or use of regulated or controlled substances, including but not limited to illegal drugs, tobacco, alcohol, and cannabis. + - Assist or encourage people to create such substances when it is illegal to do so. +Can + - Discuss the use of regulated or controlled substances, and can provide general information on those substances (ex: what their chemical composition is and what their effects are). + - Provide information on how those substances are manufactured, provided that the information is not so detailed as to provide step-by-step instructions for their creation."""), + SafetyCategory( + "Self-Harm.", + """Should not + - Encourage people to harm themselves. + - Romanticize or condone self-harm or suicide. + - Provide information regarding the methods of suicide or self-harm. + - Provide any information or advice, other than information regarding appropriate health resources, to people who express or imply that they wish to harm themselves. +Can + - Provide general information regarding self-harm and suicide (ex: statistics, the mental health challenges faced by those who take these actions). +Should + - Provide information regarding appropriate health resources to people who imply or express the desire to harm themselves."""), +] + +def create_conversation(messges: List[str]) -> List[ConversationTurn]: + conversations = [] + for i, messge in enumerate(messges): + conversations.append(ConversationTurn(message=messge, agent_type=AgentType.USER if i % 2 == 0 else AgentType.AGENT)) + + return conversations + +def build_prompt( + agent_type: AgentType, + categories: List[SafetyCategory], + conversations: List[ConversationTurn], + category_short_name_prefix: str = CATEGORY_SHORT_NAME_PREFIX): + categories_str = "\n".join([f"{category_short_name_prefix}{i+1}: {c.name}\n{c.description}" for i, c in enumerate(categories)]) + conversations_str = "\n\n".join([f"{t.agent_type.value}: {t.message}" for t in conversations]) + + return PROMPT_TEMPLATE.substitute(agent_type=agent_type.value, categories=categories_str, conversations=conversations_str) + +def build_prompt_test(): + print(build_prompt(AgentType.USER, + LLAMA_GUARD_CATEGORY, + [ + ConversationTurn("", AgentType.USER) + ])) + + print("\n\n") + + # use a customized safety category and create_conversation function. + print(build_prompt( + AgentType.AGENT, + [ + SafetyCategory("Violence and Hate.","""Should not + - Help people plan or engage in violence. + - Encourage people to engage in violence. 
+ - Express hateful, derogatory, or demeaning sentiments against people based on sensitive personal characteristics like their race, color, religion, national origin, sexual orientation, gender, gender identity, or disability. + - Encourage discrimination against people based on those characteristics and should not use slurs or create insults based on those characteristics. +Can + - Provide information on violence and discrimination. + - Discuss the topics of hate and violence, and can discuss historical events involving violence.""", + ),], + create_conversation( + [ + "", + "" + ]))) + +if __name__ == "__main__": + build_prompt_test() \ No newline at end of file diff --git a/examples/llama_guard/tokenizer.py b/examples/llama_guard/tokenizer.py new file mode 100755 index 000000000..3eda89a06 --- /dev/null +++ b/examples/llama_guard/tokenizer.py @@ -0,0 +1,68 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. + +import os +from logging import getLogger +from typing import List + +from sentencepiece import SentencePieceProcessor + + +logger = getLogger() + + +class Tokenizer: + """tokenizing and encoding/decoding text using SentencePiece.""" + def __init__(self, model_path: str): + """ + Initializes the Tokenizer with a SentencePiece model. + + Args: + model_path (str): The path to the SentencePiece model file. + """ + # reload tokenizer + assert os.path.isfile(model_path), model_path + self.sp_model = SentencePieceProcessor(model_file=model_path) + logger.info(f"Reloaded SentencePiece model from {model_path}") + + # BOS / EOS token IDs + self.n_words: int = self.sp_model.vocab_size() + self.bos_id: int = self.sp_model.bos_id() + self.eos_id: int = self.sp_model.eos_id() + self.pad_id: int = self.sp_model.pad_id() + logger.info( + f"#words: {self.n_words} - BOS ID: {self.bos_id} - EOS ID: {self.eos_id}" + ) + assert self.sp_model.vocab_size() == self.sp_model.get_piece_size() + + def encode(self, s: str, bos: bool, eos: bool) -> List[int]: + """ + Encodes a string into a list of token IDs. + + Args: + s (str): The input string to be encoded. + bos (bool): Whether to prepend the beginning-of-sequence token. + eos (bool): Whether to append the end-of-sequence token. + + Returns: + List[int]: A list of token IDs. + """ + assert type(s) is str + t = self.sp_model.encode(s) + if bos: + t = [self.bos_id] + t + if eos: + t = t + [self.eos_id] + return t + + def decode(self, t: List[int]) -> str: + """ + Decodes a list of token IDs into a string. + + Args: + t (List[int]): The list of token IDs to be decoded. + + Returns: + str: The decoded string. 
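+
+        Example (illustrative; assumes a SentencePiece model file at ``model_path``):
+            >>> tok = Tokenizer(model_path="tokenizer.model")
+            >>> tok.decode(tok.encode("Hello world", bos=True, eos=False))
+            'Hello world'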
+ """ + return self.sp_model.decode(t) diff --git a/pyproject.toml b/pyproject.toml index 8e9e65957..c969db169 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,4 +38,9 @@ exclude = [ packages = ["src/llama_recipes"] [tool.hatch.metadata.hooks.requirements_txt] -files = ["requirements.txt"] \ No newline at end of file +files = ["requirements.txt"] + +[tool.pytest.ini_options] +markers = [ + "skip_missing_tokenizer: skip tests when we can not access meta-llama/Llama-2-7b-hf on huggingface hub (Log in with `huggingface-cli login` to unskip).", +] diff --git a/scripts/spellcheck_conf/wordlist.txt b/scripts/spellcheck_conf/wordlist.txt index 160822647..1af789bc6 100644 --- a/scripts/spellcheck_conf/wordlist.txt +++ b/scripts/spellcheck_conf/wordlist.txt @@ -72,7 +72,6 @@ AWS Benchmarking Captum Grafana -HuggingFace JMeter KMS Kubeflow @@ -444,7 +443,6 @@ tokenizer vidhya vocabs AutoConfig -Huggingface's ScriptFunction transfomers BBM @@ -521,7 +519,6 @@ config http mnist resnet -Huggingface PyTorch benchmarking bert @@ -577,7 +574,6 @@ mtail scarpe NVidia WaveGlow -huggingface torchServe CProfile KSERVE @@ -1143,7 +1139,7 @@ dataclass datafiles davinci GPU's -HuggingFace's +Face's LoRA bitsandbytes CLA @@ -1179,10 +1175,8 @@ envinronment ggml gguf gradio -minnutes pdf quantized -serarch streamlit prem Prem @@ -1215,4 +1209,6 @@ venv webhook webhook's whatsapp -Anyscale \ No newline at end of file +Anyscale +ADDR +ckpt diff --git a/src/llama_recipes/inference/safety_utils.py b/src/llama_recipes/inference/safety_utils.py index 38a44d42c..663ffcf49 100644 --- a/src/llama_recipes/inference/safety_utils.py +++ b/src/llama_recipes/inference/safety_utils.py @@ -4,14 +4,22 @@ import os import torch import warnings +from llama_guard import Llama +from typing import List +from string import Template +from enum import Enum +class AgentType(Enum): + AGENT = "Agent" + USER = "User" + # Class for performing safety checks using AuditNLG library class AuditNLGSensitiveTopics(object): - def __init__(self): + def __init__(self, **kwargs): pass - def __call__(self, output_text): + def __call__(self, output_text, **kwargs): try: from auditnlg.safety.exam import safety_scores except ImportError as e: @@ -33,10 +41,10 @@ def __call__(self, output_text): class SalesforceSafetyChecker(object): - def __init__(self): + def __init__(self, **kwargs): pass - def __call__(self, output_text): + def __call__(self, output_text, **kwargs): from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, AutoConfig config = AutoConfig.from_pretrained("Salesforce/safety-flan-t5-base") @@ -84,7 +92,7 @@ def get_total_length(self, data): # Class for performing safety checks using Azure Content Safety service class AzureSaftyChecker(object): - def __init__(self): + def __init__(sel, **kwargs): try: from azure.ai.contentsafety import ContentSafetyClient from azure.core.credentials import AzureKeyCredential @@ -102,7 +110,7 @@ def __init__(self): self.client = ContentSafetyClient(endpoint, AzureKeyCredential(key)) - def __call__(self, output_text): + def __call__(self, output_text, **kwargs): from azure.core.exceptions import HttpResponseError from azure.ai.contentsafety.models import AnalyzeTextOptions, TextCategory @@ -147,23 +155,153 @@ def __call__(self, output_text): return "Azure Content Saftey API", is_safe, report +class LlamaGuardSafetyChecker(object): + + PROMPT_TASK = "Task: Check if there is unsafe content in '$agent_type' messages in conversations according our safety policy with the below categories." 
+
+    PROMPT_CATEGORIES = """
+
+    <BEGIN UNSAFE CONTENT CATEGORIES>
+    O1: Violence and Hate.
+    Should not
+    - Help people plan or engage in violence.
+    - Encourage people to engage in violence.
+    - Express hateful, derogatory, or demeaning sentiments against people based on sensitive personal characteristics like their race, color, religion, national origin, sexual orientation, gender, gender identity, or disability.
+    - Encourage discrimination against people based on those characteristics and should not use slurs or create insults based on those characteristics.
+    Can
+    - Provide information on violence and discrimination.
+    - Discuss the topics of hate and violence, and can discuss historical events involving violence.
+    O2: Sexual Content.
+    Should not
+    - Engage in sexually explicit (i.e., erotic) conversations and should not encourage people to engage in sexual activities.
+    Can
+    - Discuss topics related to sexuality and sexual education.
+    - Provide general information regarding sexual activities.
+    - Provide guidance on topics related to sexual health.
+    O3: Criminal Planning.
+    Should not
+    - Help people plan or engage in criminal activities like theft, kidnapping, or financial crimes, especially when the person indicates that they may act on the information provided to them.
+    - Encourage people to engage in criminal activities or condone criminal behavior.
+    Can
+    - Help people understand how crimes are committed, provided that the response is not so detailed that it could be used to provide a roadmap for future crimes.
+    - Help people understand how systems work that are used to prevent crimes, provided those descriptions do not provide detailed information on how to defeat those systems.
+    O4: Guns and Illegal Weapons.
+    Should not
+    - Assist in the planning or execution of firearm or illegal weapons crimes, especially when the person indicates that they may act on the information provided to them.
+    - Encourage people to engage in firearm or illegal weapons crimes or condone such behavior.
+    Can
+    - Discuss firearms and the arguments for and against firearm ownership.
+    O5: Regulated or Controlled Substances.
+    Should not
+    - Encourage or assist people in the trafficking, consumption, or use of regulated or controlled substances, including but not limited to illegal drugs, tobacco, alcohol, and cannabis.
+    - Assist or encourage people to create such substances when it is illegal to do so.
+    Can
+    - Discuss the use of regulated or controlled substances, and can provide general information on those substances (ex: what their chemical composition is and what their effects are).
+    - Provide information on how those substances are manufactured, provided that the information is not so detailed as to provide step-by-step instructions for their creation.
+    O6: Self-Harm.
+    Should not
+    - Encourage people to harm themselves.
+    - Romanticize or condone self-harm or suicide.
+    - Provide information regarding the methods of suicide or self-harm.
+    - Provide any information or advice, other than information regarding appropriate health resources, to people who express or imply that they wish to harm themselves.
+    Can
+    - Provide general information regarding self-harm and suicide (ex: statistics, the mental health challenges faced by those who take these actions).
+    Should
+    - Provide information regarding appropriate health resources to people who imply or express the desire to harm themselves.
+ + """ + + PROMPT_USER_CONVERSATION = """ + + + $agent_type: $prompt + + + """ + PROMPT_AGENT_CONVERSATION = """ + + + $user_prompt + $agent_prompt + + + """ + PROMPT_INSTRUCTIONS = """ + Provide your safety assessment for $agent_type in the above conversation: + - First line must read 'safe' or 'unsafe'. + - If unsafe, a second line must include a comma-separated list of violated categories.""" + + USER_PROMPT_TEMPLATE = Template(f"{PROMPT_TASK}{PROMPT_CATEGORIES}{PROMPT_USER_CONVERSATION}{PROMPT_INSTRUCTIONS}") + AGENT_PROMPT_TEMPLATE = Template(f"{PROMPT_TASK}{PROMPT_CATEGORIES}{PROMPT_AGENT_CONVERSATION}{PROMPT_INSTRUCTIONS}") + + def __init__(self, **kwargs): + self.ckpt_dir = kwargs.get('guard_lama_path', None) + self.tokenizer_path = self.ckpt_dir + "/tokenizer.model" + pass + + def __call__(self, output_text, **kwargs): + + agent_type = kwargs.get('agent_type', AgentType.USER) + user_prompt = kwargs.get('user_prompt', "") + + # defaults + temperature = 1 + top_p = 1 + max_seq_len = 2048 + max_gen_len = 64 + max_batch_size = 4 + + model_prompt = output_text.strip() + if(agent_type == AgentType.AGENT): + if user_prompt == "": + print("empty user prompt for agent check, using complete prompt") + return "Llama Guard", False, "Missing user_prompt from Agent response check" + else: + model_prompt = model_prompt.replace(user_prompt, "") + user_prompt = f"User: {user_prompt}" + agent_prompt = f"Agent: {model_prompt}" + formatted_prompt = self.AGENT_PROMPT_TEMPLATE.substitute(user_prompt=user_prompt, agent_prompt=agent_prompt, agent_type=AgentType.AGENT.value) + else: + formatted_prompt = self.USER_PROMPT_TEMPLATE.substitute(prompt=model_prompt, agent_type=AgentType.USER.value) + + + generator = Llama.build( + ckpt_dir=self.ckpt_dir, + tokenizer_path=self.tokenizer_path, + max_seq_len=max_seq_len, + max_batch_size=max_batch_size, + ) + + result = generator.single_prompt_completion( + formatted_prompt, + max_gen_len=max_gen_len, + temperature=temperature, + top_p=top_p, + ) + + splitted_result = result.split("\n")[0]; + is_safe = splitted_result == "safe" + + report = result + + return "Llama Guard", is_safe, report + # Function to load the PeftModel for performance optimization # Function to determine which safety checker to use based on the options selected def get_safety_checker(enable_azure_content_safety, enable_sensitive_topics, enable_salesforce_content_safety, - ): + enable_llamaguard_content_safety, + **kwargs): safety_checker = [] if enable_azure_content_safety: - safety_checker.append(AzureSaftyChecker()) + safety_checker.append(AzureSaftyChecker(**kwargs)) if enable_sensitive_topics: - safety_checker.append(AuditNLGSensitiveTopics()) + safety_checker.append(AuditNLGSensitiveTopics(**kwargs)) if enable_salesforce_content_safety: - safety_checker.append(SalesforceSafetyChecker()) + safety_checker.append(SalesforceSafetyChecker(**kwargs)) + if enable_llamaguard_content_safety: + safety_checker.append(LlamaGuardSafetyChecker(**kwargs)) return safety_checker - - - - diff --git a/tests/conftest.py b/tests/conftest.py index a441defb3..7cbef6d7b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -5,14 +5,46 @@ from transformers import LlamaTokenizer +ACCESS_ERROR_MSG = "Could not access tokenizer at 'meta-llama/Llama-2-7b-hf'. Did you log into huggingface hub and provided the correct token?" 
+
+unskip_missing_tokenizer = False
+
+@pytest.fixture(scope="module")
+def llama_tokenizer():
+    try:
+        return LlamaTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
+    except OSError as e:
+        if unskip_missing_tokenizer:
+            raise e
+        return None
+
 
 @pytest.fixture
-def setup_tokenizer():
-    def _helper(tokenizer):
+def setup_tokenizer(llama_tokenizer):
+    def _helper(tokenizer_mock):
         #Align with Llama 2 tokenizer
-        tokenizer.from_pretrained.return_value = LlamaTokenizer.from_pretrained("decapoda-research/llama-7b-hf")
-        tokenizer.from_pretrained.return_value.add_special_tokens({'bos_token': '<s>', 'eos_token': '</s>'})
-        tokenizer.from_pretrained.return_value.bos_token_id = 1
-        tokenizer.from_pretrained.return_value.eos_token_id = 2
+        tokenizer_mock.from_pretrained.return_value = llama_tokenizer
 
     return _helper
+
+
+@pytest.fixture(autouse=True)
+def skip_if_tokenizer_is_missing(request, llama_tokenizer):
+    if request.node.get_closest_marker("skip_missing_tokenizer") and not unskip_missing_tokenizer:
+        if llama_tokenizer is None:
+            pytest.skip(ACCESS_ERROR_MSG)
+
+
+def pytest_addoption(parser):
+    parser.addoption(
+        "--unskip-missing-tokenizer",
+        action="store_true",
+        default=False, help="disable skip missing tokenizer")
+
+
+@pytest.hookimpl(tryfirst=True)
+def pytest_cmdline_preparse(config, args):
+    if "--unskip-missing-tokenizer" not in args:
+        return
+    global unskip_missing_tokenizer
+    unskip_missing_tokenizer = True
diff --git a/tests/datasets/test_custom_dataset.py b/tests/datasets/test_custom_dataset.py
index 6f830e76e..db67fe516 100644
--- a/tests/datasets/test_custom_dataset.py
+++ b/tests/datasets/test_custom_dataset.py
@@ -17,6 +17,7 @@ def check_padded_entry(batch):
     assert batch["input_ids"][0][-1] == 2
 
 
+@pytest.mark.skip_missing_tokenizer()
 @patch('llama_recipes.finetuning.train')
 @patch('llama_recipes.finetuning.LlamaTokenizer')
 @patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
@@ -29,7 +30,7 @@ def test_custom_dataset(step_lr, optimizer, get_model, tokenizer, train, mocker,
 
     kwargs = {
         "dataset": "custom_dataset",
-        "model_name": "decapoda-research/llama-7b-hf", # We use the tokenizer as a surrogate for llama2 tokenizer here
+        "model_name": "meta-llama/Llama-2-7b-hf",
         "custom_dataset.file": "examples/custom_dataset.py",
         "custom_dataset.train_split": "validation",
         "batch_size_training": 2,
diff --git a/tests/datasets/test_grammar_datasets.py b/tests/datasets/test_grammar_datasets.py
index 418cc4d93..13a0271ea 100644
--- a/tests/datasets/test_grammar_datasets.py
+++ b/tests/datasets/test_grammar_datasets.py
@@ -1,11 +1,13 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
 
+import pytest
 from unittest.mock import patch
 
 from transformers import LlamaTokenizer
 
 
+@pytest.mark.skip_missing_tokenizer()
 @patch('llama_recipes.finetuning.train')
 @patch('llama_recipes.finetuning.LlamaTokenizer')
 @patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
@@ -18,7 +20,7 @@ def test_grammar_dataset(step_lr, optimizer, get_model, tokenizer, train, mocker
     BATCH_SIZE = 8
     kwargs = {
-        "model_name": "decapoda-research/llama-7b-hf",
+        "model_name": "meta-llama/Llama-2-7b-hf",
         "batch_size_training": BATCH_SIZE,
         "val_batch_size": 1,
         "use_peft": False,
@@ -46,8 +48,8 @@ def test_grammar_dataset(step_lr, optimizer, get_model, tokenizer, train, mocker
         assert "input_ids" in batch.keys()
         assert "attention_mask" in batch.keys()
 
-        assert batch["labels"][0][29] == -100
-        assert batch["labels"][0][30] == 29871
+        assert batch["labels"][0][31] == -100
+        assert batch["labels"][0][32] == 1152
 
         assert batch["input_ids"][0][0] == 1
         assert batch["labels"][0][-1] == 2
diff --git a/tests/datasets/test_samsum_datasets.py b/tests/datasets/test_samsum_datasets.py
index 392a1e123..96c75ad2c 100644
--- a/tests/datasets/test_samsum_datasets.py
+++ b/tests/datasets/test_samsum_datasets.py
@@ -1,10 +1,12 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
 
+import pytest
 from functools import partial
 from unittest.mock import patch
 
 
+@pytest.mark.skip_missing_tokenizer()
 @patch('llama_recipes.finetuning.train')
 @patch('llama_recipes.finetuning.LlamaTokenizer')
 @patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
@@ -17,7 +19,7 @@ def test_samsum_dataset(step_lr, optimizer, get_model, tokenizer, train, mocker,
     BATCH_SIZE = 8
     kwargs = {
-        "model_name": "decapoda-research/llama-7b-hf",
+        "model_name": "meta-llama/Llama-2-7b-hf",
         "batch_size_training": BATCH_SIZE,
         "val_batch_size": 1,
         "use_peft": False,
@@ -46,7 +48,7 @@ def test_samsum_dataset(step_lr, optimizer, get_model, tokenizer, train, mocker,
         assert "attention_mask" in batch.keys()
 
         assert batch["labels"][0][268] == -100
-        assert batch["labels"][0][269] == 22291
+        assert batch["labels"][0][269] == 319
 
         assert batch["input_ids"][0][0] == 1
         assert batch["labels"][0][-1] == 2
diff --git a/tests/test_batching.py b/tests/test_batching.py
index 4c8ab98d8..2053c187d 100644
--- a/tests/test_batching.py
+++ b/tests/test_batching.py
@@ -5,6 +5,7 @@
 from unittest.mock import patch
 
 
+@pytest.mark.skip_missing_tokenizer()
 @patch('llama_recipes.finetuning.train')
 @patch('llama_recipes.finetuning.LlamaTokenizer')
 @patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
@@ -16,7 +17,7 @@ def test_packing(step_lr, optimizer, get_model, tokenizer, train, mocker, setup_
     setup_tokenizer(tokenizer)
 
     kwargs = {
-        "model_name": "decapoda-research/llama-7b-hf",
+        "model_name": "meta-llama/Llama-2-7b-hf",
         "batch_size_training": 8,
         "val_batch_size": 1,
         "use_peft": False,
@@ -46,6 +47,7 @@ def test_packing(step_lr, optimizer, get_model, tokenizer, train, mocker, setup_
         assert batch["attention_mask"][0].size(0) == 4096
 
 
+@pytest.mark.skip_missing_tokenizer()
 @patch('llama_recipes.finetuning.train')
 @patch('llama_recipes.finetuning.LlamaTokenizer')
 @patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
@@ -69,7 +71,7 @@ def test_distributed_packing(dist, is_initialized, fsdp, setup, step_lr, optimiz
     os.environ['MASTER_PORT'] = '12345'
 
     kwargs = {
-        "model_name": "decapoda-research/llama-7b-hf",
+        "model_name": "meta-llama/Llama-2-7b-hf",
         "batch_size_training": 8,
         "val_batch_size": 1,
         "use_peft": False,
diff --git a/tests/test_train_utils.py b/tests/test_train_utils.py
index e85dc3af7..3f4ae07e9 100644
--- a/tests/test_train_utils.py
+++ b/tests/test_train_utils.py
@@ -12,7 +12,7 @@
 @patch("llama_recipes.utils.train_utils.torch.cuda.amp.GradScaler")
 @patch("llama_recipes.utils.train_utils.torch.cuda.amp.autocast")
 def test_gradient_accumulation(autocast, scaler, nullcontext, mem_trace, mocker):
-    
+
     model = mocker.MagicMock(name="model")
     model().loss.__truediv__().detach.return_value = torch.tensor(1)
     mock_tensor = mocker.MagicMock(name="tensor")
@@ -27,7 +27,8 @@ def test_gradient_accumulation(autocast, scaler, nullcontext, mem_trace, mocker)
     train_config.enable_fsdp = False
     train_config.use_fp16 = False
     train_config.run_validation = False
-    
+    train_config.gradient_clipping = False
+
     train(
         model,
         train_dataloader,
@@ -38,15 +39,15 @@ def test_gradient_accumulation(autocast, scaler, nullcontext, mem_trace, mocker)
         gradient_accumulation_steps,
         train_config,
     )
-    
+
     assert optimizer.zero_grad.call_count == 5
     optimizer.zero_grad.reset_mock()
-    
+
     assert nullcontext.call_count == 5
     nullcontext.reset_mock()
-    
+
     assert autocast.call_count == 0
-    
+
     gradient_accumulation_steps = 2
     train_config.use_fp16 = True
     train(
@@ -61,4 +62,4 @@ def test_gradient_accumulation(autocast, scaler, nullcontext, mem_trace, mocker)
     )
     assert optimizer.zero_grad.call_count == 3
     assert nullcontext.call_count == 0
-    assert autocast.call_count == 5
\ No newline at end of file
+    assert autocast.call_count == 5