Distributed inference example #890

Draft · wants to merge 5 commits into main
50 changes: 43 additions & 7 deletions llms/mlx_lm/chat.py
@@ -16,6 +16,25 @@
DEFAULT_MODEL = "mlx-community/Llama-3.2-3B-Instruct-4bit"


def share_message(world, prompt):
if world.size() == 1:
return prompt

if world.rank() == 0:
size = mx.array([len(prompt)])
else:
size = mx.array([0])
size = mx.distributed.all_sum(size, stream=mx.cpu).item()
if size == 0:
return []

if world.rank() == 0:
prompt = mx.array(prompt)
else:
        prompt = mx.array([0] * size)
    return mx.distributed.all_sum(prompt, stream=mx.cpu).tolist()


def setup_arg_parser():
"""Set up and return the argument parser."""
parser = argparse.ArgumentParser(description="Chat with an LLM")
@@ -54,6 +73,7 @@ def setup_arg_parser():


def main():
world = mx.distributed.init()
parser = setup_arg_parser()
args = parser.parse_args()

@@ -63,16 +83,30 @@ def main():
args.model,
adapter_path=args.adapter_path,
tokenizer_config={"trust_remote_code": True},
+        sequential_load=mx.distributed.init().size() > 1,
)

print(f"[INFO] Starting chat session with {args.model}. To exit, enter 'q'.")
print(f"Node {world.rank()} of {world.size()}", flush=True)
print(
f"[INFO] Starting chat session with {args.model}. To exit, enter 'q'.",
flush=True,
)
world.barrier()
prompt_cache = make_prompt_cache(model, args.max_kv_size)
while True:
-        query = input(">> ")
-        if query == "q":
+        if world.rank() == 0:
+            query = input(">> ")
+            if query == "q":
+                prompt = []
+            else:
+                messages = [{"role": "user", "content": query}]
+                prompt = tokenizer.apply_chat_template(
+                    messages, add_generation_prompt=True
+                )
+
+        prompt = share_message(world, prompt if world.rank() == 0 else [])
+        if len(prompt) == 0:
            break
-        messages = [{"role": "user", "content": query}]
-        prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
for response in stream_generate(
model,
tokenizer,
@@ -81,8 +115,10 @@ def main():
sampler=make_sampler(args.temp, args.top_p),
prompt_cache=prompt_cache,
):
-            print(response.text, flush=True, end="")
-        print()
+            if world.rank() == 0:
+                print(response.text, flush=True, end="")
+        if world.rank() == 0:
+            print()


if __name__ == "__main__":
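Note: share_message above leans on the fact that all_sum acts as a broadcast when only rank 0 contributes nonzero values: the ranks first agree on the prompt length, then sum a zero-padded copy of the token ids. A minimal standalone sketch of that pattern (the token ids below are made up for illustration, not taken from this PR):

import mlx.core as mx

world = mx.distributed.init()

# Only rank 0 has real data; the other ranks start empty.
tokens = [101, 2023, 2003, 102] if world.rank() == 0 else []

# Step 1: agree on the length (summing [len, 0, 0, ...] yields rank 0's length everywhere).
size = mx.distributed.all_sum(mx.array([len(tokens)]), stream=mx.cpu).item()

# Step 2: non-root ranks contribute zeros, so the element-wise sum replicates rank 0's ids.
payload = mx.array(tokens) if world.rank() == 0 else mx.array([0] * size)
tokens = mx.distributed.all_sum(payload, stream=mx.cpu).tolist()

print(f"rank {world.rank()} received {tokens}", flush=True)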
11 changes: 9 additions & 2 deletions llms/mlx_lm/generate.py
@@ -191,6 +191,7 @@ def main():
model_path,
adapter_path=args.adapter_path,
tokenizer_config=tokenizer_config,
sequential_load=mx.distributed.init().size() > 1,
)
for eos_token in args.extra_eos_token:
tokenizer.add_eos_token(eos_token)
@@ -234,13 +235,17 @@ def main():
else:
draft_model = None
sampler = make_sampler(args.temp, args.top_p, args.min_p, args.min_tokens_to_keep)

+    world = mx.distributed.init()
+    print(f"Node {world.rank()} of {world.size()}", flush=True)
+    world.barrier()
response = generate(
model,
tokenizer,
prompt,
max_tokens=args.max_tokens,
-        verbose=args.verbose,
        sampler=sampler,
+        verbose=args.verbose and world.rank() == 0,
max_kv_size=args.max_kv_size,
prompt_cache=prompt_cache if using_cache else None,
kv_bits=args.kv_bits,
@@ -249,8 +254,10 @@ def main():
draft_model=draft_model,
num_draft_tokens=args.num_draft_tokens,
)
-    if not args.verbose:

+    if not args.verbose and mx.distributed.init().rank() == 0:
        print(response)
+    mx.synchronize()


if __name__ == "__main__":
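Note: the CLI changes above can also be driven from Python. The sketch below is not part of this PR; it assumes the ranks are started with an MPI-style launcher (for example mpirun -np 2 -- python example.py, as in the MLX distributed docs) and uses the chat default model only as a placeholder:

import mlx.core as mx
from mlx_lm import load, generate

world = mx.distributed.init()

# Every rank loads (and, with more than one rank, shards) the model.
model, tokenizer = load(
    "mlx-community/Llama-3.2-3B-Instruct-4bit",  # placeholder repo
    sequential_load=world.size() > 1,
)
world.barrier()

# All ranks must take part in the computation; only rank 0 reports the result.
response = generate(
    model,
    tokenizer,
    prompt="Explain tensor parallelism in one sentence.",
    max_tokens=64,
    verbose=False,
)
if world.rank() == 0:
    print(response)
mx.synchronize()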
30 changes: 30 additions & 0 deletions llms/mlx_lm/models/llama.py
@@ -200,6 +200,36 @@ def sanitize(self, weights):
k: v for k, v in weights.items() if "self_attn.rotary_emb.inv_freq" not in k
}

def shard(self, group: Optional[mx.distributed.Group] = None):
group = group or mx.distributed.init()

def all_to_sharded(l):
if isinstance(l, nn.QuantizedLinear):
return nn.QuantizedAllToShardedLinear.from_quantized_linear(l, group)
else:
return nn.AllToShardedLinear.from_linear(l, group)

def sharded_to_all(l):
if isinstance(l, nn.QuantizedLinear):
return nn.QuantizedShardedToAllLinear.from_quantized_linear(l, group)
else:
return nn.ShardedToAllLinear.from_linear(l, group)

N = group.size()
for layer in self.model.layers:
# Shard the self attention
layer.self_attn.q_proj = all_to_sharded(layer.self_attn.q_proj)
layer.self_attn.k_proj = all_to_sharded(layer.self_attn.k_proj)
layer.self_attn.v_proj = all_to_sharded(layer.self_attn.v_proj)
layer.self_attn.o_proj = sharded_to_all(layer.self_attn.o_proj)
layer.self_attn.n_heads //= N
layer.self_attn.n_kv_heads //= N

# Shard the MLP
layer.mlp.gate_proj = all_to_sharded(layer.mlp.gate_proj)
layer.mlp.down_proj = sharded_to_all(layer.mlp.down_proj)
layer.mlp.up_proj = all_to_sharded(layer.mlp.up_proj)

@property
def layers(self):
return self.model.layers
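Note: the two conversions used in shard() split a Linear along opposite axes. An all-to-sharded layer keeps 1/N of the output features on each rank (used for q_proj/k_proj/v_proj and gate_proj/up_proj), while a sharded-to-all layer keeps 1/N of the input features and all-reduces its output (used for o_proj and down_proj). A small shape sketch, assuming the sharded linear classes referenced in this diff and a group of size N:

import mlx.core as mx
import mlx.nn as nn

group = mx.distributed.init()
N = group.size()

full = nn.Linear(4096, 4096, bias=False)

# Column-parallel: replicated input, sharded output.
col = nn.AllToShardedLinear.from_linear(full, group)
# Row-parallel: sharded input, full output summed across ranks.
row = nn.ShardedToAllLinear.from_linear(full, group)

x = mx.random.normal((1, 4096))
print(col(x).shape)       # expected: (1, 4096 // N) on each rank
print(row(col(x)).shape)  # expected: (1, 4096), identical on every rank

This is also why n_heads and n_kv_heads are divided by N above: each rank only computes attention for its own slice of the heads.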
21 changes: 17 additions & 4 deletions llms/mlx_lm/utils.py
@@ -306,12 +306,12 @@ def _step(y):

y, logprobs = _step(y)

-    mx.async_eval(y, logprobs)
+    mx.eval(y, logprobs)
n = 0
while True:
if n != max_tokens:
next_y, next_logprobs = _step(y)
-            mx.async_eval(next_y, next_logprobs)
+            mx.eval(next_y, next_logprobs)
if n == 0:
mx.eval(y)
prompt_progress_callback(total_prompt_tokens, total_prompt_tokens)
@@ -628,6 +628,7 @@ def load_model(
model_path: Path,
lazy: bool = False,
strict: bool = True,
sequential_load: bool = False,
model_config: dict = {},
get_model_classes: Callable[[dict], Tuple[Type[nn.Module], Type]] = _get_classes,
) -> nn.Module:
@@ -699,7 +700,16 @@ def class_predicate(p, m):

model.load_weights(list(weights.items()), strict=strict)

if mx.distributed.init().size() > 1:
if not hasattr(model, "shard"):
raise RuntimeError("Model doesn't support distributed inference.")
model.shard()

if not lazy:
weights.clear()
if sequential_load:
for layer in model.layers:
mx.eval(layer.parameters())
mx.eval(model.parameters())

model.eval()
@@ -712,6 +722,7 @@ def load(
model_config={},
adapter_path: Optional[str] = None,
lazy: bool = False,
sequential_load: bool = False,
) -> Tuple[nn.Module, TokenizerWrapper]:
"""
Load the model and tokenizer from a given path or a huggingface repository.
@@ -727,6 +738,8 @@
lazy (bool): If ``False`` eval the model parameters to make sure they are
loaded in memory before returning, otherwise they will be loaded
when needed. Default: ``False``
        sequential_load (bool): If ``True`` then load the model one layer at a
            time to keep peak memory usage low. Default: ``False``.
Returns:
Tuple[nn.Module, TokenizerWrapper]: A tuple containing the loaded model and tokenizer.

@@ -736,7 +749,7 @@
"""
model_path = get_model_path(path_or_hf_repo)

-    model, config = load_model(model_path, lazy)
+    model, config = load_model(model_path, lazy=lazy, sequential_load=sequential_load)
if adapter_path is not None:
model = load_adapters(model, adapter_path)
model.eval()
@@ -750,7 +763,7 @@
def fetch_from_hub(
model_path: Path, lazy: bool = False
) -> Tuple[nn.Module, dict, PreTrainedTokenizer]:
-    model, config = load_model(model_path, lazy)
+    model, config = load_model(model_path, lazy=lazy)
tokenizer = load_tokenizer(
model_path, eos_token_ids=config.get("eos_token_id", None)
)
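Note: the sequential_load path matters because, with lazy loading, each sharded layer's parameters are slices of the full on-disk tensors; evaluating layer by layer keeps only one layer's worth of not-yet-sliced weights in flight at a time, which is presumably the memory waste the new docstring refers to. A sketch of the idea in isolation (this helper is illustrative, not part of the PR, and assumes the model exposes .layers as the mlx_lm models do):

import mlx.core as mx
import mlx.nn as nn

def materialize(model: nn.Module, sequential: bool = False) -> None:
    """Force lazily loaded parameters into memory."""
    if sequential:
        for layer in model.layers:
            mx.eval(layer.parameters())  # one transformer block at a time
    # Anything outside the blocks (embeddings, final norm, lm head), or the
    # whole model at once in the non-sequential case.
    mx.eval(model.parameters())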