feat(mlx_lm)!: batch_generate #948

Open · wants to merge 13 commits into base: main
2 changes: 1 addition & 1 deletion llms/mlx_lm/__init__.py
@@ -6,4 +6,4 @@

os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "1"

-from .utils import convert, generate, load, stream_generate
+from .utils import convert, generate, load, stream_generate, batch_generate
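For reference, a minimal sketch of how the newly exported batch_generate might be called; the exact signature (a list of prompts plus generate-style keyword arguments) is an assumption based on the existing generate API, not something this diff shows:

from mlx_lm import load, batch_generate

model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.3-4bit")
# Hypothetical call: assumed to mirror generate(), but taking a list of
# prompts and returning one completion per prompt.
responses = batch_generate(
    model,
    tokenizer,
    prompts=["What is MLX?", "Write a haiku about GPUs."],
    max_tokens=64,
)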
2 changes: 1 addition & 1 deletion llms/mlx_lm/cache_prompt.py
@@ -132,7 +132,7 @@ def main():
prompt = args.prompt

cache = make_prompt_cache(model, args.max_kv_size)
-y = mx.array(tokenizer.encode(prompt))
+y = mx.array(tokenizer.encode(prompt))[None]

# Process the prompt
start = time.time()
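The trailing [None] adds a leading batch axis: with batch_generate in place, the model path now expects inputs of shape (batch, seq_len) even for a single prompt. A quick illustration:

import mlx.core as mx

tokens = mx.array([1, 2, 3])   # shape (3,): one encoded prompt
batched = tokens[None]         # shape (1, 3): a batch of one sequence
print(batched.shape)           # (1, 3)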
5 changes: 4 additions & 1 deletion llms/mlx_lm/models/base.py
@@ -28,6 +28,7 @@ def create_causal_mask(
offset: int = 0,
window_size: Optional[int] = None,
lengths: Optional[mx.array] = None,
+dtype: mx.Dtype = mx.float32,
):
rinds = mx.arange(offset + N)
linds = mx.arange(offset, offset + N) if offset else rinds
@@ -39,7 +40,9 @@ def create_causal_mask(
if lengths is not None:
lengths = lengths[:, None, None, None]
mask = mask | (rinds >= lengths)
-return mask * -1e9
+# HACK: we sometimes see NaN logprobs without the divide by 2 here
+# return mask * (mx.finfo(dtype).min / 2)
+return mask.astype(dtype) * (-65504.0 / 2)


def create_attention_mask(h: mx.array, cache: Optional[Any] = None):
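Why halve the value: -65504 is the most negative finite float16 number (mx.finfo(mx.float16).min). Used at full scale as an additive mask, adding attention scores on top can overflow to -inf, and -inf entries become NaN after softmax; halving leaves headroom. A minimal sketch of the idea, assuming only mlx:

import mlx.core as mx

# Boolean causal mask turned into an additive bias. Halving the most
# negative finite float16 value leaves headroom so adding attention
# scores cannot overflow to -inf (which would become NaN after softmax).
N = 4
rinds = mx.arange(N)
mask = rinds[:, None] < rinds[None]              # True above the diagonal
bias = mask.astype(mx.float16) * (-65504.0 / 2)  # additive attention bias
print(bias)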
71 changes: 43 additions & 28 deletions llms/mlx_lm/sample_utils.py
@@ -72,7 +72,7 @@ def make_logits_processors(
values = mx.array(list(logit_bias.values()))

def logit_bias_processor(_, logits):
-logits[:, indices] += values
+logits[..., indices] += values
return logits

logits_processors.append(logit_bias_processor)
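Switching from logits[:, indices] to logits[..., indices] keeps the bias on the vocabulary axis no matter how many leading batch dimensions the logits carry. A small self-contained check:

import mlx.core as mx

# The ellipsis applies the bias along the last (vocabulary) axis for
# every row of the batch at once.
logits = mx.zeros((2, 5))        # (batch=2, vocab=5)
indices = mx.array([1, 3])
values = mx.array([2.0, -1.0])
logits[..., indices] += values
print(logits)                    # columns 1 and 3 biased in both rows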
@@ -132,7 +132,10 @@ def min_p_sampling(
0.99-0.8 range.
min_tokens_to_keep (int, optional): Minimum number of tokens that cannot
be filtered. Default: ``1``.

+temperature: Temperature parameter for softmax distribution reshaping.
+Returns:
+token(s) selected based on the min-p criterion.
+Shape: the batch dimensions of ``logits``; the vocabulary axis is removed.
"""
if not (0 <= min_p <= 1.0):
raise ValueError(
@@ -147,11 +150,11 @@
logprobs = logprobs * (1 / temperature)

# Indices sorted in decreasing order
-sorted_indices = mx.argsort(-logprobs).squeeze(0)
-sorted_logprobs = logprobs[..., sorted_indices]
+sorted_indices = mx.argsort(-logprobs)
+sorted_logprobs = mx.take_along_axis(logprobs, sorted_indices, axis=-1)

# Top probability
-top_logprobs = logprobs[..., sorted_indices[0]]
+top_logprobs = mx.expand_dims(sorted_logprobs[..., 0], axis=-1)

# Calculate the min_p threshold
scaled_min_p = top_logprobs + math.log(min_p)
@@ -163,43 +166,55 @@
# Create pool of tokens with probability less than scaled min_p
selected_logprobs = mx.where(tokens_to_remove, -float("inf"), sorted_logprobs)

-# Return sampled token
-sorted_token = mx.random.categorical(selected_logprobs)
-return sorted_indices[sorted_token]
+# Return sampled token(s)
+sampled_indices = mx.random.categorical(selected_logprobs)
+tokens = mx.take_along_axis(
+sorted_indices, mx.expand_dims(sampled_indices, axis=-1), axis=-1
+)
+return tokens.squeeze(-1)
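The gather step is the key batched change: mx.random.categorical returns each row's sampled position within its own sorted distribution, and mx.take_along_axis maps those positions back to original vocabulary ids row by row, which the old squeeze(0)-based indexing could only do for a single sequence. A small illustration:

import mlx.core as mx

sorted_indices = mx.array([[3, 0, 2, 1],
                           [1, 2, 0, 3]])   # (batch=2, vocab=4)
sampled = mx.array([0, 2])                  # per-row positions in sorted order
tokens = mx.take_along_axis(
    sorted_indices, mx.expand_dims(sampled, axis=-1), axis=-1
)
print(tokens.squeeze(-1))                   # array([3, 0])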


@partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
-def top_p_sampling(logits: mx.array, top_p: float, temperature: float) -> mx.array:
+def top_p_sampling(
+logits: mx.array, top_p: float, temperature: float, axis: int = -1
+) -> mx.array:
"""
Apply top-p (nucleus) sampling to logits.

Args:
logits: The logits from the model's output.
top_p: The cumulative probability threshold for top-p filtering.
temperature: Temperature parameter for softmax distribution reshaping.
+axis: The axis along which to apply top-p sampling.
Returns:
-token selected based on the top-p criterion.
+token(s) selected based on the top-p criterion.
"""
-# referenced implementation from https://github.com/huggingface/transformers/blob/main/src/transformers/generation/logits_process.py#L449-L460
-probs = mx.softmax(logits * (1 / temperature), axis=-1)
+# Apply temperature and compute softmax
+probs = mx.softmax(logits / temperature, axis=axis)

-# sort probs in ascending order
-sorted_indices = mx.argsort(probs, axis=-1)
-sorted_probs = probs[..., sorted_indices.squeeze(0)]
+# Sort probs in descending order
+sorted_indices = mx.argsort(-probs, axis=axis)
+sorted_probs = mx.take_along_axis(probs, sorted_indices, axis=axis)

-cumulative_probs = mx.cumsum(sorted_probs, axis=-1)
+# Compute cumulative probabilities
+cumulative_probs = mx.cumsum(sorted_probs, axis=axis)

-# select tokens with cumulative probs below threshold
-top_probs = mx.where(
-cumulative_probs > 1 - top_p,
-sorted_probs,
-0,
-)
+# Keep tokens within the top-p threshold; include the token that crosses
+# it so the mask is never all-False when the top probability alone
+# exceeds top_p
+mask = (cumulative_probs - sorted_probs) < top_p
+
+# Apply the mask to the sorted probabilities
+masked_probs = sorted_probs * mask

-sorted_token = mx.random.categorical(mx.log(top_probs))
-token = sorted_indices.squeeze(0)[sorted_token]
+# Sample from the masked distribution (categorical normalizes internally)
+sampled_indices = mx.random.categorical(mx.log(masked_probs), axis=axis)
+
+# Gather the original token indices
+tokens = mx.take_along_axis(
+sorted_indices, mx.expand_dims(sampled_indices, axis=axis), axis=axis
+)
+
-return token
+return tokens.squeeze(axis)
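A quick smoke test of the batched path, assuming the module is importable as mlx_lm.sample_utils; the shapes and values here are made up for illustration:

import mlx.core as mx
from mlx_lm.sample_utils import top_p_sampling

logits = mx.array([[2.0, 1.0, 0.1, -1.0],
                   [0.0, 3.0, 0.5, 0.2]])   # (batch=2, vocab=4)
tokens = top_p_sampling(logits, top_p=0.9, temperature=1.0)
print(tokens.shape)                         # (2,): one token per row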


@partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
@@ -226,15 +241,15 @@ def make_repetition_penalty(penalty: float, context_size: int = 20):
raise ValueError(f"penalty must be a non-negative float, got {penalty}")

def repetition_penalty_processor(tokens, logits):
-if len(tokens) > 0:
-tokens = tokens[-context_size:]
-selected_logits = logits[:, tokens]
+if tokens.shape[-1] > 0:
+tokens = tokens[..., -context_size:]
+selected_logits = mx.take_along_axis(logits, tokens, axis=-1)
selected_logits = mx.where(
selected_logits < 0,
selected_logits * penalty,
selected_logits / penalty,
)
-logits[:, tokens] = selected_logits
+logits[mx.arange(tokens.shape[0])[:, None], tokens] = selected_logits
return logits

return repetition_penalty_processor
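The row-index column is what makes the write-back batch-safe: logits[:, tokens] would broadcast every row's token list across the whole batch, while pairing mx.arange(B)[:, None] with the (B, k) token matrix writes each row's penalized logits only to that row's positions. A small illustration:

import mlx.core as mx

logits = mx.zeros((2, 5))                    # (batch=2, vocab=5)
tokens = mx.array([[1, 2],
                   [3, 4]])                  # each row's recent tokens
rows = mx.arange(tokens.shape[0])[:, None]   # (2, 1), broadcasts with tokens
logits[rows, tokens] = -1.0                  # per-row scatter
print(logits)  # row 0 changed at cols 1, 2; row 1 at cols 3, 4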