diff --git a/llms/mlx_lm/__init__.py b/llms/mlx_lm/__init__.py
index 538be9277..49a4a30be 100644
--- a/llms/mlx_lm/__init__.py
+++ b/llms/mlx_lm/__init__.py
@@ -6,4 +6,4 @@
 os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "1"
 
-from .utils import convert, generate, load, stream_generate
+from .utils import convert, generate, load, stream_generate, batch_generate
diff --git a/llms/mlx_lm/cache_prompt.py b/llms/mlx_lm/cache_prompt.py
index 9d7d1603d..4f88061e8 100644
--- a/llms/mlx_lm/cache_prompt.py
+++ b/llms/mlx_lm/cache_prompt.py
@@ -132,7 +132,7 @@ def main():
     prompt = args.prompt
     cache = make_prompt_cache(model, args.max_kv_size)
 
-    y = mx.array(tokenizer.encode(prompt))
+    y = mx.array(tokenizer.encode(prompt))[None]
 
     # Process the prompt
     start = time.time()
diff --git a/llms/mlx_lm/models/base.py b/llms/mlx_lm/models/base.py
index ad7a4a65a..568b85abb 100644
--- a/llms/mlx_lm/models/base.py
+++ b/llms/mlx_lm/models/base.py
@@ -28,6 +28,7 @@ def create_causal_mask(
     offset: int = 0,
     window_size: Optional[int] = None,
     lengths: Optional[mx.array] = None,
+    dtype: mx.Dtype = mx.float32,
 ):
     rinds = mx.arange(offset + N)
     linds = mx.arange(offset, offset + N) if offset else rinds
@@ -39,7 +40,9 @@
     if lengths is not None:
         lengths = lengths[:, None, None, None]
         mask = mask | (rinds >= lengths)
-    return mask * -1e9
+    # HACK: sometimes see NaN logprobs if no divide by 2 here
+    # return mask * (mx.finfo(dtype).min / 2)
+    return mask.astype(dtype) * (-65504.0 / 2)
 
 
 def create_attention_mask(h: mx.array, cache: Optional[Any] = None):
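The new `dtype` argument lets the additive mask be built directly in the model's compute dtype, and `-65504.0` is the most negative finite float16 value; per the HACK comment it is halved because using the full minimum sometimes produced NaN logprobs. A minimal sketch of how this combines with a padding mask for a left-padded batch (the token values here are illustrative; `batch_generate` in utils.py below does the same thing with the tokenizer's real attention mask):

    import mlx.core as mx
    from mlx_lm.models.base import create_causal_mask

    # Hypothetical left-padded batch: 1 = real token, 0 = padding.
    token_mask = mx.array([[0, 0, 1, 1], [1, 1, 1, 1]])  # (batch_size, prompt_len)

    dtype = mx.float16
    causal_mask = create_causal_mask(token_mask.shape[-1], dtype=dtype)

    # Broadcast (batch_size, 1, 1, prompt_len) against (prompt_len, prompt_len):
    # padded positions get the large negative fill, everything else stays causal.
    mask = mx.where(token_mask[:, None, None, :], causal_mask, -65504.0 / 2)
    print(mask.shape)  # (2, 1, 4, 4)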
""" if not (0 <= min_p <= 1.0): raise ValueError( @@ -147,11 +150,11 @@ def min_p_sampling( logprobs = logprobs * (1 / temperature) # Indices sorted in decreasing order - sorted_indices = mx.argsort(-logprobs).squeeze(0) - sorted_logprobs = logprobs[..., sorted_indices] + sorted_indices = mx.argsort(-logprobs) + sorted_logprobs = mx.take_along_axis(logprobs, sorted_indices, axis=-1) # Top probability - top_logprobs = logprobs[..., sorted_indices[0]] + top_logprobs = mx.expand_dims(sorted_logprobs[..., 0], axis=-1) # Calculate the min_p threshold scaled_min_p = top_logprobs + math.log(min_p) @@ -163,13 +166,18 @@ def min_p_sampling( # Create pool of tokens with probability less than scaled min_p selected_logprobs = mx.where(tokens_to_remove, -float("inf"), sorted_logprobs) - # Return sampled token - sorted_token = mx.random.categorical(selected_logprobs) - return sorted_indices[sorted_token] + # Return sampled token(s) + sampled_indices = mx.random.categorical(selected_logprobs) + tokens = mx.take_along_axis( + sorted_indices, mx.expand_dims(sampled_indices, axis=-1), axis=-1 + ) + return tokens.squeeze(-1) @partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state) -def top_p_sampling(logits: mx.array, top_p: float, temperature: float) -> mx.array: +def top_p_sampling( + logits: mx.array, top_p: float, temperature: float, axis: int = -1 +) -> mx.array: """ Apply top-p (nucleus) sampling to logits. @@ -177,29 +185,36 @@ def top_p_sampling(logits: mx.array, top_p: float, temperature: float) -> mx.arr logits: The logits from the model's output. top_p: The cumulative probability threshold for top-p filtering. temperature: Temperature parameter for softmax distribution reshaping. + axis: The axis along which to apply top-p sampling. Returns: - token selected based on the top-p criterion. + token(s) selected based on the top-p criterion. 
""" # referenced implementation from https://github.com/huggingface/transformers/blob/main/src/transformers/generation/logits_process.py#L449-L460 - probs = mx.softmax(logits * (1 / temperature), axis=-1) + # Apply temperature and compute softmax + probs = mx.softmax(logits / temperature, axis=axis) - # sort probs in ascending order - sorted_indices = mx.argsort(probs, axis=-1) - sorted_probs = probs[..., sorted_indices.squeeze(0)] + # Sort probs in descending order + sorted_indices = mx.argsort(-probs, axis=axis) + sorted_probs = mx.take_along_axis(probs, sorted_indices, axis=axis) - cumulative_probs = mx.cumsum(sorted_probs, axis=-1) + # Compute cumulative probabilities + cumulative_probs = mx.cumsum(sorted_probs, axis=axis) - # select tokens with cumulative probs below threshold - top_probs = mx.where( - cumulative_probs > 1 - top_p, - sorted_probs, - 0, - ) + # Create a mask for probs above the threshold + mask = cumulative_probs <= top_p + + # Apply the mask to the sorted probabilities + masked_probs = sorted_probs * mask - sorted_token = mx.random.categorical(mx.log(top_probs)) - token = sorted_indices.squeeze(0)[sorted_token] + # Sample from the normalized probabilities + sampled_indices = mx.random.categorical(mx.log(masked_probs), axis=axis) + + # Gather the original token indices + tokens = mx.take_along_axis( + sorted_indices, mx.expand_dims(sampled_indices, axis=axis), axis=axis + ) - return token + return tokens.squeeze(axis) @partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state) @@ -226,15 +241,15 @@ def make_repetition_penalty(penalty: float, context_size: int = 20): raise ValueError(f"penalty must be a non-negative float, got {penalty}") def repetition_penalty_processor(tokens, logits): - if len(tokens) > 0: - tokens = tokens[-context_size:] - selected_logits = logits[:, tokens] + if tokens.shape[-1] > 0: + tokens = tokens[..., -context_size:] + selected_logits = mx.take_along_axis(logits, tokens, axis=-1) selected_logits = mx.where( selected_logits < 0, selected_logits * penalty, selected_logits / penalty, ) - logits[:, tokens] = selected_logits + logits[mx.arange(tokens.shape[0])[:, None], tokens] = selected_logits return logits return repetition_penalty_processor diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py index 4d69115e0..714609f59 100644 --- a/llms/mlx_lm/utils.py +++ b/llms/mlx_lm/utils.py @@ -21,6 +21,7 @@ # Local imports from .models import cache +from .models.base import create_causal_mask from .sample_utils import make_logits_processors, make_sampler from .tokenizer_utils import TokenizerWrapper, load_tokenizer from .tuner.utils import dequantize as dequantize_model @@ -202,6 +203,7 @@ def generate_step( logits_processors: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None, max_kv_size: Optional[int] = None, prompt_cache: Optional[Any] = None, + mask: Optional[mx.array] = None, prefill_step_size: int = 512, kv_bits: Optional[int] = None, kv_group_size: int = 64, @@ -215,22 +217,27 @@ def generate_step( min_tokens_to_keep: Optional[int] = None, ) -> Generator[Tuple[mx.array, mx.array], None, None]: """ - A generator producing token ids based on the given prompt from the model. + A generator producing token ids based on the given prompt(s) from the model. Args: - prompt (mx.array): The input prompt. + prompt (mx.array): The input prompt(s), shaped (batch_size, prompt_len). model (nn.Module): The model to use for generation. max_tokens (int): The maximum number of tokens. Use``-1`` for an infinite generator. Default: ``256``. 
diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py
index 4d69115e0..714609f59 100644
--- a/llms/mlx_lm/utils.py
+++ b/llms/mlx_lm/utils.py
@@ -21,6 +21,7 @@
 # Local imports
 from .models import cache
+from .models.base import create_causal_mask
 from .sample_utils import make_logits_processors, make_sampler
 from .tokenizer_utils import TokenizerWrapper, load_tokenizer
 from .tuner.utils import dequantize as dequantize_model
@@ -202,6 +203,7 @@ def generate_step(
     logits_processors: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None,
     max_kv_size: Optional[int] = None,
     prompt_cache: Optional[Any] = None,
+    mask: Optional[mx.array] = None,
     prefill_step_size: int = 512,
     kv_bits: Optional[int] = None,
     kv_group_size: int = 64,
@@ -215,22 +217,27 @@
     min_tokens_to_keep: Optional[int] = None,
 ) -> Generator[Tuple[mx.array, mx.array], None, None]:
     """
-    A generator producing token ids based on the given prompt from the model.
+    A generator producing token ids based on the given prompt(s) from the model.
 
     Args:
-        prompt (mx.array): The input prompt.
+        prompt (mx.array): The input prompt(s), shaped (batch_size, prompt_len).
         model (nn.Module): The model to use for generation.
         max_tokens (int): The maximum number of tokens. Use``-1`` for an infinite
           generator. Default: ``256``.
-        sampler (Callable[mx.array, mx.array], optional): A sampler for sampling a
-            token from a vector of log probabilities. Default: ``None``.
+        sampler (Callable[mx.array, mx.array], optional): A sampler for sampling
+            tokens (shaped (batch_size,)) from a vector of log probabilities,
+            shaped (batch_size, vocab_size). Default: ``None``.
         logits_processors (List[Callable[[mx.array, mx.array], mx.array]], optional):
-            A list of functions that take tokens and logits and return the processed
-            logits. Default: ``None``.
+            A list of functions that take tokens (shaped (batch_size, gen_len)) and
+            logits and return the processed logits, shaped (batch_size, vocab_size).
+            Default: ``None``.
         max_kv_size (int, optional): Maximum size of the key-value cache. Old
             entries (except the first 4 tokens) will be overwritten.
         prompt_cache (List[Any], optional): A pre-computed prompt cache. Note, if
             provided, the cache will be updated in place.
+        mask (mx.array, optional): An attention mask to apply to the prompt.
+            Should be of shape (batch_size, 1, prompt_len, prompt_len + cache_len).
+            See: `create_causal_mask`.
         prefill_step_size (int): Step size for processing the prompt.
         kv_bits (int, optional): Number of bits to use for KV cache quantization.
             None implies no cache quantization. Default: ``None``.
@@ -241,12 +248,25 @@
             prompt tokens processed so far and the total number of prompt tokens.
 
     Yields:
-        Tuple[mx.array, mx.array]: One token and a vector of log probabilities.
+        Tuple[mx.array, mx.array]: One token (shaped (batch_size,))
+        and a vector of log probabilities (shaped (batch_size, vocab_size)).
     """
 
+    if prompt.ndim == 1:
+        print(
+            "[Warning] Passing a (prompt_len,)-shaped ``prompt`` into ``generate_step`` "
+            "is deprecated. Pass in a (batch_size, prompt_len)-shaped ``prompt`` instead."
+        )
+        prompt = prompt[None]
     y = prompt
     tokens = None
 
+    if y.shape[0] != 1 and max_kv_size is not None:
+        # TODO: If we have left-padded sequences, we need to evict all the
+        #       pad tokens from the `RotatingKVCache` before evicting from left plus 4.
+        #       The indexing of `mask` below will also break if we evict from cache.
+        raise ValueError("max_kv_size is not supported for batched generation.")
+
     # Create the KV cache for generation
     if prompt_cache is None:
         prompt_cache = cache.make_prompt_cache(
@@ -276,13 +296,19 @@ def generate_step(
     prompt_progress_callback = prompt_progress_callback or (lambda *_: None)
 
     def _step(y):
+        if y.ndim == 1:
+            y = mx.expand_dims(y, axis=-1)
         with mx.stream(generation_stream):
-            logits = model(y[None], cache=prompt_cache)
+            logits = model(
+                y,
+                cache=prompt_cache,
+                mask=mask if mask is not None and y.shape[-1] > 1 else None,
+            )
             logits = logits[:, -1, :]
 
             if logits_processors:
                 nonlocal tokens
-                tokens = mx.concat([tokens, y]) if tokens is not None else y
+                tokens = mx.concat([tokens, y], axis=-1) if tokens is not None else y
 
                 for processor in logits_processors:
                     logits = processor(tokens, logits)
@@ -293,20 +319,30 @@ def _step(y):
         logprobs = logits - mx.logsumexp(logits, keepdims=True)
         y = sampler(logprobs)
-        return y, logprobs.squeeze(0)
+        return y, logprobs
 
     with mx.stream(generation_stream):
         total_prompt_tokens = y.size
         prompt_processed_tokens = 0
-        while y.size > prefill_step_size:
-            model(y[:prefill_step_size][None], cache=prompt_cache)
+        while y.shape[-1] > prefill_step_size:
+            offset = prompt_cache[0].offset
+            model(
+                y[:, :prefill_step_size],
+                cache=prompt_cache,
+                mask=(
+                    mask[:, :, :prefill_step_size, : offset + prefill_step_size]
+                    if mask is not None
+                    else None
+                ),
+            )
             maybe_quantize_kv_cache(
                 prompt_cache, quantized_kv_start, kv_group_size, kv_bits
             )
             mx.eval([c.state for c in prompt_cache])
             prompt_progress_callback(prompt_processed_tokens, total_prompt_tokens)
-            prompt_processed_tokens += prefill_step_size
-            y = y[prefill_step_size:]
+            prompt_processed_tokens += y.shape[0] * prefill_step_size
+            y = y[:, prefill_step_size:]
+            mask = mask[:, :, prefill_step_size:, :] if mask is not None else None
             mx.metal.clear_cache()
 
         y, logprobs = _step(y)
 
@@ -322,7 +358,8 @@ def _step(y):
             prompt_progress_callback(total_prompt_tokens, total_prompt_tokens)
         if n == max_tokens:
             break
-        yield y.item(), logprobs
+        mx.eval(y)
+        yield y, logprobs
         if n % 256 == 0:
             mx.metal.clear_cache()
         y, logprobs = next_y, next_logprobs
@@ -357,12 +394,16 @@ def stream_generate(
         prompt if isinstance(prompt, list) else tokenizer.encode(prompt)
     )
 
+    if prompt.ndim == 1:
+        prompt = prompt[None]
+
     detokenizer = tokenizer.detokenizer
 
     with wired_limit(model, [generation_stream]):
         detokenizer.reset()
         tic = time.perf_counter()
         for n, (token, logprobs) in enumerate(generate_step(prompt, model, **kwargs)):
+            token, logprobs = token.item(), logprobs.squeeze(0)
             if n == 0:
                 prompt_time = time.perf_counter() - tic
                 prompt_tps = prompt.size / prompt_time
@@ -451,6 +492,97 @@ def generate(
     return text
 
 
+def batch_generate(
+    model: nn.Module,
+    tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper],
+    prompts: list[str],
+    verbose: bool = False,
+    **kwargs,
+) -> list[str]:
+    """
+    Generate complete responses from the model for a list of prompts.
+
+    Args:
+        model (nn.Module): The language model.
+        tokenizer (PreTrainedTokenizer): The tokenizer.
+        prompts (List[str]): The string prompts.
+        verbose (bool): If ``True``, print tokens and timing information.
+            Default: ``False``.
+        kwargs: The remaining options get passed to :func:`generate_step`.
+            See :func:`generate_step` for more details.
+    """
+    if "prompt_cache" in kwargs:
+        # TODO: Handle `prompt_cache` and `prompt` both left-padded, so that
+        #       we have <pad>text<pad>text. Should involve taking `prompt_cache_lens`
+        #       to extend `mask` below, and handling position_ids (see TODO below)
+        raise ValueError("Batch generation does not support prompt_cache yet.")
+    tokenizer = copy.deepcopy(tokenizer)
+    if not isinstance(tokenizer, TokenizerWrapper):
+        tokenizer = TokenizerWrapper(tokenizer)
+    # TODO: left-shift position_ids for absolute/rotary positional encodings
+    # Example: https://github.com/huggingface/transformers/issues/26072#issuecomment-2101209470
+    tokenizer._tokenizer.padding_side = "left"
+    if tokenizer.pad_token is None:
+        tokenizer._tokenizer.pad_token = tokenizer.eos_token
+        tokenizer._tokenizer.pad_token_id = tokenizer.eos_token_id
+    res = tokenizer._tokenizer(prompts, padding=True)
+    input_ids, token_mask = mx.array(res["input_ids"]), mx.array(res["attention_mask"])
+    dtype = mx.float32
+    for module in model.modules():
+        if isinstance(module, nn.QuantizedEmbedding) or isinstance(
+            module, nn.Embedding
+        ):
+            dtype = module(mx.zeros(1, dtype=input_ids.dtype)).dtype
+            break
+    causal_mask = create_causal_mask(token_mask.shape[-1], dtype=dtype)
+    # HACK: sometimes see NaN logprobs if no divide by 2 here
+    # mask = mx.where(token_mask[:, None, None, :], causal_mask, mx.finfo(dtype).min / 2)
+    mask = mx.where(token_mask[:, None, None, :], causal_mask, -65504.0 / 2)
+
+    output_toks = []
+    prompt_time = None
+    ended = mx.zeros(len(prompts), dtype=mx.bool_)
+    tic = time.perf_counter()
+    # TODO: non-generator version of `generate_step` so that we can
+    #       add or remove prompts from the batch as they start/finish
+    for tokens, _ in generate_step(input_ids, model, mask=mask, **kwargs):
+        if not prompt_time:
+            prompt_time = time.perf_counter() - tic
+            tic = time.perf_counter()
+        ended = ended | (tokens == tokenizer.eos_token_id)
+        if ended.all():
+            break
+        output_toks.append(tokens)
+        if verbose:
+            print(".", end="", flush=True)
+    output_toks = mx.stack(output_toks, axis=-1)
+    token_count = output_toks.size
+    response = [
+        response.split(tokenizer.eos_token)[0].split(tokenizer.pad_token)[0]
+        for response in tokenizer.batch_decode(output_toks.tolist())
+    ]
+
+    if verbose:
+        gen_time = time.perf_counter() - tic
+        if token_count <= 0:
+            print("No tokens generated for this prompt")
+        else:
+            print()
+            for p, resp in zip(prompts, response):
+                print("=" * 10)
+                print("Prompt:", p)
+                print(resp)
+                print("=" * 10)
+        if prompt_time:
+            prompt_tps = input_ids.size / prompt_time
+            print(f"Prompt: {input_ids.size} tokens, {prompt_tps:.3f} tokens-per-sec")
+        gen_tps = token_count / gen_time
+        print(f"Generation: {token_count} tokens, {gen_tps:.3f} tokens-per-sec")
+        peak_mem = mx.metal.get_peak_memory() / 2**30
+        print(f"Peak memory: {peak_mem:.3f} GB")
+
+    return response
+
+
 def load_config(model_path: Path) -> dict:
     try:
         with open(model_path / "config.json", "r") as f:
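The new entry point can be driven like this; a minimal usage sketch (the checkpoint name is the same small chat model the tests use, and the prompts are illustrative):

    from mlx_lm import load, batch_generate

    model, tokenizer = load("mlx-community/Qwen1.5-0.5B-Chat-4bit")

    responses = batch_generate(
        model,
        tokenizer,
        prompts=["What is MLX?", "Write a haiku about GPUs."],
        max_tokens=64,
        verbose=True,
    )
    for r in responses:
        print(r)

Prompts are left-padded to a common length, and every extra keyword argument (sampler, logits_processors, prefill_step_size, ...) is forwarded to generate_step exactly as in the single-prompt path.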
diff --git a/llms/tests/test_generate.py b/llms/tests/test_generate.py
index f23453943..bcdb0d9f7 100644
--- a/llms/tests/test_generate.py
+++ b/llms/tests/test_generate.py
@@ -2,12 +2,11 @@
 
 import unittest
 
-from mlx_lm.sample_utils import make_logits_processors
-from mlx_lm.utils import generate, load
+from mlx_lm.sample_utils import make_logits_processors, make_sampler
+from mlx_lm.utils import generate, batch_generate, load
 
 
 class TestGenerate(unittest.TestCase):
-
     @classmethod
     def setUpClass(cls):
         HF_MODEL_PATH = "mlx-community/Qwen1.5-0.5B-Chat-4bit"
@@ -49,7 +48,26 @@ def logits_processor(toks, logits):
             verbose=False,
             logits_processors=[logits_processor],
         )
-        self.assertEqual(len(all_toks), len(init_toks) + 5)
+        self.assertEqual(all_toks.shape[-1], len(init_toks) + 5)
+
+    def test_batch_generate(self):
+        logit_bias = {0: 20.0, 1: -20.0}
+        texts = batch_generate(
+            self.model,
+            self.tokenizer,
+            [
+                "hello",
+                "this is a longer prompt to test out the padding and masking. hello",
+            ],
+            max_tokens=5,
+            prefill_step_size=4,
+            sampler=make_sampler(temp=1.0, min_p=0.5),
+            logits_processors=make_logits_processors(
+                logit_bias, repetition_penalty=2.0
+            ),
+            verbose=False,
+        )
+        self.assertEqual(texts, ["!", "!"])
 
 
 if __name__ == "__main__":
diff --git a/llms/tests/test_prompt_cache.py b/llms/tests/test_prompt_cache.py
index de5694d58..5fcc6834a 100644
--- a/llms/tests/test_prompt_cache.py
+++ b/llms/tests/test_prompt_cache.py
@@ -121,21 +121,24 @@ def test_save_load_mixed_cache(self):
     def test_cache_with_generate(self):
         model, tokenizer = load(HF_MODEL_PATH)
         prompt = tokenizer.encode("this is a prompt", return_tensors="mlx")[0]
-        results = list(generate_step(prompt, model, max_tokens=4))
+        results = list(generate_step(prompt[None], model, max_tokens=4))
+        results = [(t.item(), l.squeeze(0)) for t, l in results]
         toks, all_logits = zip(*results)
 
         prompt_cache = make_prompt_cache(model)
         i = 0
         for tok, logits in generate_step(
-            prompt, model, prompt_cache=prompt_cache, max_tokens=2
+            prompt[None], model, prompt_cache=prompt_cache, max_tokens=2
         ):
+            tok, logits = tok.item(), logits.squeeze(0)
             self.assertEqual(tok, toks[i])
             self.assertTrue(mx.allclose(logits, all_logits[i]))
             i += 1
 
         for tok, logits in generate_step(
-            mx.array([toks[i]]), model, prompt_cache=prompt_cache, max_tokens=1
+            mx.array([[toks[i]]]), model, prompt_cache=prompt_cache, max_tokens=1
         ):
+            tok, logits = tok.item(), logits.squeeze(0)
             i += 1
             self.assertEqual(tok, toks[i])
             self.assertTrue(mx.allclose(logits, all_logits[i]))
@@ -205,14 +208,16 @@ def test_trim_cache_with_generate(self):
         prompt_cache = make_prompt_cache(model)
 
         # Generate one token so we process the full prompt
-        last_tok, _ = next(generate_step(prompt, model, prompt_cache=prompt_cache))
-        last_tok = mx.array([last_tok])
+        last_tok, _ = next(
+            generate_step(prompt[None], model, prompt_cache=prompt_cache)
+        )
 
         # Generate two more tokens
         results = zip(
-            range(2), generate_step(last_tok, model, prompt_cache=prompt_cache)
+            range(2), generate_step(last_tok[None], model, prompt_cache=prompt_cache)
         )
-        toks, all_logits = zip(*(r[1] for r in results))
+        results = [(t.item(), l.squeeze(0)) for _, (t, l) in results]
+        toks, all_logits = zip(*results)
 
         # To get back to the cache just after processing the prompt,
         # trim by 3 tokens
@@ -220,9 +225,10 @@
         # Generate the same thing again
         results = zip(
-            range(2), generate_step(last_tok, model, prompt_cache=prompt_cache)
+            range(2), generate_step(last_tok[None], model, prompt_cache=prompt_cache)
         )
-        second_toks, second_all_logits = zip(*(r[1] for r in results))
+        results = [(t.item(), l.squeeze(0)) for _, (t, l) in results]
+        second_toks, second_all_logits = zip(*results)
         self.assertEqual(toks, second_toks)
         self.assertTrue(
             all(mx.allclose(l, l2) for l, l2 in zip(all_logits, second_all_logits))
         )
@@ -278,14 +284,16 @@ def test_save_load_quantized_cache(self):
     def test_cache_to_quantized(self):
         model, tokenizer = load(HF_MODEL_PATH)
         prompt = tokenizer.encode("this is a prompt", return_tensors="mlx")[0]
-        results = zip(range(4), generate_step(prompt, model))
-        toks, all_logits = zip(*(r[1] for r in results))
+        results = zip(range(4), generate_step(prompt[None], model))
+        results = [(t.item(), l.squeeze(0)) for _, (t, l) in results]
+        toks, all_logits = zip(*results)
 
         prompt_cache = make_prompt_cache(model)
         i = 0
         for _, (tok, logits) in zip(
-            range(2), generate_step(prompt, model, prompt_cache=prompt_cache)
+            range(2), generate_step(prompt[None], model, prompt_cache=prompt_cache)
         ):
+            tok, logits = tok.item(), logits.squeeze(0)
             self.assertEqual(tok, toks[i])
             self.assertTrue(mx.allclose(logits, all_logits[i]))
             i += 1
@@ -294,8 +302,9 @@ def test_cache_to_quantized(self):
 
         for _, (tok, logits) in zip(
             range(1),
-            generate_step(mx.array([toks[i]]), model, prompt_cache=prompt_cache),
+            generate_step(mx.array([[toks[i]]]), model, prompt_cache=prompt_cache),
         ):
+            tok, logits = tok.item(), logits.squeeze(0)
             i += 1
             self.assertEqual(tok, toks[i])
             self.assertTrue(mx.allclose(logits, all_logits[i], rtol=2e-2))
diff --git a/llms/tests/test_sample_utils.py b/llms/tests/test_sample_utils.py
index c45fa4439..bc38e8b59 100644
--- a/llms/tests/test_sample_utils.py
+++ b/llms/tests/test_sample_utils.py
@@ -10,8 +10,9 @@ def test_top_p_sampling(self):
         logits = mx.log(probs)
         temperature = 1.0
 
-        token = top_p_sampling(logits, 0.3, temperature).item()
-        self.assertEqual(token, 0)
+        token = top_p_sampling(logits, 0.3, temperature)
+        self.assertEqual(token.shape, (1,))
+        self.assertEqual(token.item(), 0)
 
         token = top_p_sampling(logits, 0.95, temperature).item()
         self.assertTrue(token in (0, 3))
@@ -28,26 +29,41 @@ def test_top_p_sampling(self):
         token = top_p_sampling(logits, 0.95, temperature).item()
         self.assertTrue(token in (1, 2, 3))
 
+        # Batch mode works
+        probs = mx.array([[0.9, 0.0, 0.0, 0.1], [0.0, 0.5, 0.4, 0.1]])
+        logits = mx.log(probs)
+        token = top_p_sampling(logits, 0.4, temperature)
+        self.assertEqual(token.shape, (2,))
+        self.assertEqual(token.tolist(), [0, 1])
+
     def test_min_p_sampling(self):
         probs = mx.array([0.9, 0.0, 0.0, 0.1])[None]
         logits = mx.log(probs)
-        temperature = 1.0
         token = min_p_sampling(logits, 0.8)
-        self.assertEqual(token, 0)
+        self.assertEqual(token.shape, (1,))
+        self.assertEqual(token.item(), 0)
 
         probs = mx.array([0.9, 0.0, 0.0, 0.1])[None]
         logits = mx.log(probs)
-        temperature = 1.0
         for _ in range(5):
             token = min_p_sampling(logits, 0.05)
             self.assertTrue(token in (0, 3))
 
+        # Batch mode works
+        probs = mx.array([[0.6, 0.0, 0.0, 0.4], [0.7, 0.0, 0.0, 0.3]])
+        logits = mx.log(probs)
+        for _ in range(5):
+            token = min_p_sampling(logits, 0.65)
+            self.assertEqual(token.shape, (2,))
+            self.assertTrue(token.tolist() in ([0, 0], [3, 0]))
+
     def test_top_k_sampling(self):
         probs = mx.array([0.9, 0.0, 0.0, 0.1])[None]
         logits = mx.log(probs)
-        token = top_k_sampling(logits, 1).item()
-        self.assertEqual(token, 0)
+        token = top_k_sampling(logits, 1)
+        self.assertEqual(token.shape, (1,))
+        self.assertEqual(token.item(), 0)
 
         probs = mx.array([0.5, 0.0, 0.0, 0.5])[None]
         tokens = set()
@@ -61,6 +77,7 @@ def test_top_k_sampling(self):
         logits = mx.log(probs)
 
         tokens = top_k_sampling(logits, 1)
+        self.assertEqual(tokens.shape, (2,))
         self.assertEqual(tokens.tolist(), [0, 1])
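One behavioral note the updated tests above rely on: `generate_step` now yields arrays rather than Python scalars, so single-sequence callers convert explicitly. A brief sketch (the checkpoint and prompt mirror the tests; shapes follow the new docstring):

    import mlx.core as mx
    from mlx_lm import load
    from mlx_lm.utils import generate_step

    model, tokenizer = load("mlx-community/Qwen1.5-0.5B-Chat-4bit")
    prompt = mx.array(tokenizer.encode("this is a prompt"))[None]  # (1, prompt_len)

    for token, logprobs in generate_step(prompt, model, max_tokens=4):
        # token: (batch_size,), logprobs: (batch_size, vocab_size)
        print(token.item(), logprobs.squeeze(0).shape)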