Skip to content

Commit

Permalink
Encode numpy arrays
Browse files Browse the repository at this point in the history
  • Loading branch information
benjeffery committed Sep 5, 2024
1 parent a080e8f commit fe04950
Showing 1 changed file with 17 additions and 3 deletions.
20 changes: 17 additions & 3 deletions tsinfer/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -945,13 +945,27 @@ def common_params(self) -> dict:
}

def save(self, path):
def numpy_encoder(obj):
if isinstance(obj, np.ndarray):
return {

Check warning on line 950 in tsinfer/inference.py

View check run for this annotation

Codecov / codecov/patch

tsinfer/inference.py#L950

Added line #L950 was not covered by tests
"__numpy__": True,
"data": obj.tolist(),
"dtype": str(obj.dtype),
}
return obj

Check warning on line 955 in tsinfer/inference.py

View check run for this annotation

Codecov / codecov/patch

tsinfer/inference.py#L955

Added line #L955 was not covered by tests

with open(path, "w") as f:
json.dump(dataclasses.asdict(self), f, indent=2)
json.dump(dataclasses.asdict(self), f, indent=2, default=numpy_encoder)

@classmethod
def load(cls, path):
def numpy_decoder(dct):
if "__numpy__" in dct:
return np.array(dct["data"], dtype=dct["dtype"])

Check warning on line 964 in tsinfer/inference.py

View check run for this annotation

Codecov / codecov/patch

tsinfer/inference.py#L964

Added line #L964 was not covered by tests
return dct

with open(path) as f:
wd_dict = json.load(f)
wd_dict = json.load(f, object_hook=numpy_decoder)
return cls(**wd_dict)


Expand Down Expand Up @@ -1048,7 +1062,7 @@ def match_samples_batch_init(
sample_times = sample_times.tolist()
wd.sample_indexes = sample_indexes
wd.sample_times = sample_times
num_samples_per_partition = min_work_per_job // variant_data.num_sites
num_samples_per_partition = int(min_work_per_job // variant_data.num_sites)
if num_samples_per_partition == 0:
num_samples_per_partition = 1
wd.num_samples_per_partition = num_samples_per_partition
Expand Down

0 comments on commit fe04950

Please sign in to comment.