Skip to content

Commit

Permalink
Merge branch 'main' into benchmarking_script
Browse files (browse the repository at this point in the history)
  • Loading branch information
Jack-Khuu authored Nov 18, 2024
2 parents 80d2c6d + 5da240a commit b7fc9c7
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 7 deletions.
2 changes: 1 addition & 1 deletion torchchat/cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -533,7 +533,7 @@ def arg_init(args):
# Localized import to minimize expensive imports
from torchchat.utils.build_utils import get_device_str

if args.device is None:
if args.device is None or args.device == "fast":
args.device = get_device_str(
args.quantize.get("executor", {}).get("accelerator", default_device)
)
Expand Down
8 changes: 7 additions & 1 deletion torchchat/cli/download.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,8 @@ def _download_direct(
def download_and_convert(
model: str, models_dir: Path, hf_token: Optional[str] = None
) -> None:
if model is None:
raise ValueError("'download' command needs a model name or alias.")
model_config = resolve_model_config(model)
model_dir = models_dir / model_config.name

Expand Down Expand Up @@ -234,4 +236,8 @@ def where_main(args) -> None:

# Subcommand to download model artifacts.
# Subcommand to download model artifacts.
def download_main(args) -> None:
    """Entry point for the ``download`` subcommand.

    Delegates to ``download_and_convert``; a ``ValueError`` (e.g. a missing
    model name/alias) is treated as a user error: the message is written to
    stderr and the process exits with status 1.
    """
    try:
        download_and_convert(args.model, args.model_directory, args.hf_token)
        return
    except ValueError as err:
        # User-facing failure: report the reason without a traceback.
        print(err, file=sys.stderr)
    sys.exit(1)
25 changes: 20 additions & 5 deletions torchchat/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -1189,12 +1189,27 @@ def callback(x, *, done_generating=False):
f"Mean Accepted: {sum([idx * i for idx, i in enumerate(counts_aggregated)])/sum(counts_aggregated)}"
)

print(
f"\n Average tokens/sec (total): {torch.mean(torch.tensor(aggregate_metrics['tokens_per_sec'])).item():.2f} \
\nAverage tokens/sec (first token): {torch.mean(torch.tensor(aggregate_metrics['first_token_per_sec'])).item():.2f} \
\nAverage tokens/sec (next tokens): {torch.mean(torch.tensor(aggregate_metrics['next_tokens_per_sec'])).item():.2f} \n\
avg_tokens_sec = torch.mean(
torch.tensor(aggregate_metrics["tokens_per_sec"])
).item()
avg_first_token_sec = torch.mean(
torch.tensor(aggregate_metrics["first_token_per_sec"])
).item()
avg_next_tokens_sec = torch.mean(
torch.tensor(aggregate_metrics["next_tokens_per_sec"])
).item()

if not (
torch.isnan(torch.tensor(avg_tokens_sec))
or torch.isnan(torch.tensor(avg_first_token_sec))
or torch.isnan(torch.tensor(avg_next_tokens_sec))
):
print(
f"\n Average tokens/sec (total): {avg_tokens_sec:.2f} \
\nAverage tokens/sec (first token): {avg_first_token_sec:.2f} \
\nAverage tokens/sec (next tokens): {avg_next_tokens_sec:.2f} \n\
"
)
)
if torch.cuda.is_available():
print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB")

Expand Down

0 comments on commit b7fc9c7

Please sign in to comment.