Skip to content

Commit

Permalink
Merge pull request #965 from benjeffery/fix-json
Browse files Browse the repository at this point in the history
Encode numpy arrays in sample batch json
  • Loading branch information
benjeffery authored Sep 6, 2024
2 parents f9de549 + fe04950 commit ed64181
Showing 1 changed file with 17 additions and 3 deletions.
20 changes: 17 additions & 3 deletions tsinfer/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -945,13 +945,27 @@ def common_params(self) -> dict:
}

def save(self, path):
def numpy_encoder(obj):
if isinstance(obj, np.ndarray):
return {
"__numpy__": True,
"data": obj.tolist(),
"dtype": str(obj.dtype),
}
return obj

with open(path, "w") as f:
json.dump(dataclasses.asdict(self), f, indent=2)
json.dump(dataclasses.asdict(self), f, indent=2, default=numpy_encoder)

@classmethod
def load(cls, path):
def numpy_decoder(dct):
if "__numpy__" in dct:
return np.array(dct["data"], dtype=dct["dtype"])
return dct

with open(path) as f:
wd_dict = json.load(f)
wd_dict = json.load(f, object_hook=numpy_decoder)
return cls(**wd_dict)


Expand Down Expand Up @@ -1048,7 +1062,7 @@ def match_samples_batch_init(
sample_times = sample_times.tolist()
wd.sample_indexes = sample_indexes
wd.sample_times = sample_times
num_samples_per_partition = min_work_per_job // variant_data.num_sites
num_samples_per_partition = int(min_work_per_job // variant_data.num_sites)
if num_samples_per_partition == 0:
num_samples_per_partition = 1
wd.num_samples_per_partition = num_samples_per_partition
Expand Down

0 comments on commit ed64181

Please sign in to comment.