diff --git a/examples/benchmark_buffer_size/benchmark.py b/examples/benchmark_buffer_size/benchmark.py index ae46fda..af6afb1 100755 --- a/examples/benchmark_buffer_size/benchmark.py +++ b/examples/benchmark_buffer_size/benchmark.py @@ -16,7 +16,6 @@ from tensorizer.stream_io import ( CURLStreamFile, RedisStreamFile, - default_s3_read_endpoint, open_stream, ) @@ -117,20 +116,34 @@ parser.add_argument( "--convert-json", default="", help="Convert JSON to human readable" ) +parser.add_argument( + "--s3-endpoint", + type=str, + help="The S3 storage URL to load the models from (default: accel-object.ord1.coreweave.com)", + default="accel-object.ord1.coreweave.com" +) +parser.add_argument( + "--bucket", + type=str, + help="The bucket where the models are located (default: tensorized)", + default="tensorized" +) args = parser.parse_args() model_name: str = args.model -http_uri = ( - "http://tensorized.accel-object.ord1.coreweave.com" - f"/{model_name}/model.tensors" -) +http_uri = f"http://{args.bucket}.{args.s3_url}/{model_name}/model.tensors" + https_uri = http_uri.replace("http://", "https://") -s3_uri = f"s3://tensorized/{model_name}/model.tensors" +s3_uri = f"s3://{args.bucket}/{model_name}/model.tensors" sanitized_model_file = model_name.replace("/", "_") file_uri = f"{args.file_prefix}{sanitized_model_file}.tensors" local_uri = f"http://localhost:3000/{sanitized_model_file}.tensors" +s3_endpoint = f"http://{args.s3_url}" +if args.test_http: + s3_endpoint = s3_endpoint.replace("http://", "https://") + # Get nodename from environment, or default to os.uname().nodename nodename = os.getenv("K8S_NODE_NAME") or os.uname().nodename @@ -367,6 +380,7 @@ def io_test( def deserialize_test( source=http_uri, + s3_endpoint=s3_endpoint, plaid_mode=False, verify_hash=False, lazy_load=False, @@ -378,7 +392,7 @@ def deserialize_test( plaid_mode_buffers = None stream = open_stream( source, - s3_endpoint=default_s3_read_endpoint, + s3_endpoint=s3_endpoint, buffer_size=buffer_size, force_http=force_http, )