Skip to content

Commit

Permalink
fix: handle nested PyTorch model structures
Browse files Browse the repository at this point in the history
  • Loading branch information
evilsocket committed Oct 29, 2024
1 parent 2a180c2 commit 6afa276
Showing 1 changed file with 84 additions and 67 deletions.
151 changes: 84 additions & 67 deletions src/core/pytorch/inspect.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,76 +3,93 @@
import os
import argparse

parser = argparse.ArgumentParser(description="Inspect PyTorch model files")
parser.add_argument("file", help="Path to PyTorch model file")
parser.add_argument(
"--detailed", action="store_true", help="Show detailed tensor information"
)
parser.add_argument("--filter", help="Filter tensors by name pattern")

args = parser.parse_args()
def main():
parser = argparse.ArgumentParser(description="Inspect PyTorch model files")
parser.add_argument("file", help="Path to PyTorch model file")
parser.add_argument(
"--detailed", action="store_true", help="Show detailed tensor information"
)
parser.add_argument("--filter", help="Filter tensors by name pattern")

file_path = os.path.abspath(args.file)
file_size = os.path.getsize(file_path)
args = parser.parse_args()

try:
model = torch.load(
file_path, weights_only=True, mmap=True, map_location=torch.device("cpu")
)
except RuntimeError:
# RuntimeError: mmap can only be used with files saved with `torch.save(/model.bin, _use_new_zipfile_serialization=True), please torch.save your checkpoint with this option in order to use mmap.
model = torch.load(file_path, weights_only=True, map_location=torch.device("cpu"))

all_metadata = getattr(model, "_metadata") if hasattr(model, "_metadata") else {}
model_metadata = all_metadata[""] if "" in all_metadata else {}

inspection = {
"file_path": file_path,
"file_type": "PyTorch",
"file_size": file_size,
"version": str(model_metadata["version"] if "version" in model_metadata else ""),
"num_tensors": len(model.items()),
"data_size": 0,
"unique_shapes": [],
"unique_dtypes": [],
"metadata": {k: str(v) for (k, v) in model_metadata.items()},
"tensors": [] if args.detailed else None,
}

for tensor_name, tensor in model.items():
inspection["data_size"] += tensor.shape.numel() * tensor.element_size()

shape = list(tensor.shape)
if shape != []:
if shape not in inspection["unique_shapes"]:
inspection["unique_shapes"].append(shape)

dtype = str(tensor.dtype).replace("torch.", "")
if dtype not in inspection["unique_dtypes"]:
inspection["unique_dtypes"].append(dtype)

if args.detailed:
if args.filter and args.filter not in tensor_name:
continue

layer_name = tensor_name.split(".")[0]
inspection["tensors"].append(
{
"id": tensor_name,
"shape": shape,
"dtype": dtype,
"size": tensor.shape.numel() * tensor.element_size(),
"metadata": {k: str(v) for (k, v) in all_metadata[layer_name].items()}
if layer_name in all_metadata
else {},
}
file_path = os.path.abspath(args.file)
file_size = os.path.getsize(file_path)

try:
model = torch.load(
file_path, weights_only=True, mmap=True, map_location=torch.device("cpu")
)
except RuntimeError:
# handle: RuntimeError: mmap can only be used with files saved with `torch.save(/model.bin, _use_new_zipfile_serialization=True),
# please torch.save your checkpoint with this option in order to use mmap.
model = torch.load(
file_path, weights_only=True, map_location=torch.device("cpu")
)

# data can be compressed or shared among multiple vectors(?) in which case this would be negative
inspection["header_size"] = (
inspection["file_size"] - inspection["data_size"]
if inspection["data_size"] < inspection["file_size"]
else 0
)
all_metadata = getattr(model, "_metadata") if hasattr(model, "_metadata") else {}
model_metadata = all_metadata[""] if "" in all_metadata else {}

inspection = {
"file_path": file_path,
"file_type": "PyTorch",
"file_size": file_size,
"version": str(
model_metadata["version"] if "version" in model_metadata else ""
),
"num_tensors": len(model.items()),
"data_size": 0,
"unique_shapes": [],
"unique_dtypes": [],
"metadata": {k: str(v) for (k, v) in model_metadata.items()},
"tensors": [] if args.detailed else None,
}

# handle nested dictionary case
if "model" in model:
model = model["model"]

for tensor_name, tensor in model.items():
inspection["data_size"] += tensor.shape.numel() * tensor.element_size()

shape = list(tensor.shape)
if shape != []:
if shape not in inspection["unique_shapes"]:
inspection["unique_shapes"].append(shape)

dtype = str(tensor.dtype).replace("torch.", "")
if dtype not in inspection["unique_dtypes"]:
inspection["unique_dtypes"].append(dtype)

if args.detailed:
if args.filter and args.filter not in tensor_name:
continue

layer_name = tensor_name.split(".")[0]
inspection["tensors"].append(
{
"id": tensor_name,
"shape": shape,
"dtype": dtype,
"size": tensor.shape.numel() * tensor.element_size(),
"metadata": {
k: str(v) for (k, v) in all_metadata[layer_name].items()
}
if layer_name in all_metadata
else {},
}
)

# data can be compressed or shared among multiple vectors(?) in which case this would be negative
inspection["header_size"] = (
inspection["file_size"] - inspection["data_size"]
if inspection["data_size"] < inspection["file_size"]
else 0
)

print(json.dumps(inspection))


print(json.dumps(inspection))
if __name__ == "__main__":
main()

0 comments on commit 6afa276

Please sign in to comment.