Skip to content

Commit

Permalink
Merge pull request #173 from coreweave/dhix/parameterize-url
Browse files Browse the repository at this point in the history
feat: Add the ability to sub in URL and bucket for benchmarks
  • Loading branch information
wbrown authored Aug 7, 2024
2 parents 1552f64 + 8e8a91e commit 44a3e60
Showing 1 changed file with 21 additions and 7 deletions.
28 changes: 21 additions & 7 deletions examples/benchmark_buffer_size/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
from tensorizer.stream_io import (
CURLStreamFile,
RedisStreamFile,
default_s3_read_endpoint,
open_stream,
)

Expand Down Expand Up @@ -117,20 +116,34 @@
parser.add_argument(
"--convert-json", default="", help="Convert JSON to human readable"
)
parser.add_argument(
"--s3-endpoint",
type=str,
help="The S3 storage URL to load the models from (default: accel-object.ord1.coreweave.com)",
default="accel-object.ord1.coreweave.com"
)
parser.add_argument(
"--bucket",
type=str,
help="The bucket where the models are located (default: tensorized)",
default="tensorized"
)
args = parser.parse_args()

model_name: str = args.model

http_uri = (
"http://tensorized.accel-object.ord1.coreweave.com"
f"/{model_name}/model.tensors"
)
http_uri = f"http://{args.bucket}.{args.s3_url}/{model_name}/model.tensors"

https_uri = http_uri.replace("http://", "https://")
s3_uri = f"s3://tensorized/{model_name}/model.tensors"
s3_uri = f"s3://{args.bucket}/{model_name}/model.tensors"
sanitized_model_file = model_name.replace("/", "_")
file_uri = f"{args.file_prefix}{sanitized_model_file}.tensors"
local_uri = f"http://localhost:3000/{sanitized_model_file}.tensors"

s3_endpoint = f"http://{args.s3_url}"
if args.test_http:
s3_endpoint = s3_endpoint.replace("http://", "https://")

# Get nodename from environment, or default to os.uname().nodename
nodename = os.getenv("K8S_NODE_NAME") or os.uname().nodename

Expand Down Expand Up @@ -367,6 +380,7 @@ def io_test(

def deserialize_test(
source=http_uri,
s3_endpoint=s3_endpoint,
plaid_mode=False,
verify_hash=False,
lazy_load=False,
Expand All @@ -378,7 +392,7 @@ def deserialize_test(
plaid_mode_buffers = None
stream = open_stream(
source,
s3_endpoint=default_s3_read_endpoint,
s3_endpoint=s3_endpoint,
buffer_size=buffer_size,
force_http=force_http,
)
Expand Down

0 comments on commit 44a3e60

Please sign in to comment.