Skip to content

Commit

Permalink
feat(truncate): expand axis logic
Browse files Browse the repository at this point in the history
  • Loading branch information
ljgray committed Dec 17, 2024
1 parent 22ba143 commit 5d0746c
Showing 1 changed file with 55 additions and 25 deletions.
80 changes: 55 additions & 25 deletions draco/core/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -801,6 +801,56 @@ def _get_params(self, container, dset):

return params

def _get_weights(self, container, dset, wdset):
"""Extract the weight dataset and broadcast agaonst the truncation dataset.
Parameters
----------
container
Container class.
dset : str
Dataset name
wdset : str
Weight dataset name
Returns
-------
weight : np.ndarray
Array of weights to use in truncation. If `dset` is complex,
this is scaled by a factor of 2.
Raises
------
KeyError
Raised if either `dset` or `wdset` does not exist.
ValueError
Raised if the weight dataset cannot be broadcast to
the shape of the dataset to be truncated.
"""
# Try to get weights from an attribute first
if hasattr(container, wdset):
weight = getattr(container, wdset)
else:
weight = container[wdset]

data = container[dset]

if isinstance(weight, memh5.MemDataset):
# Add missing broadcast axes to the weights dataset
waxes = weight.attrs.get("axis", [])
daxes = data.attrs.get("axis", [])
# Add length-one axes
slobj = tuple(slice(None) if ax in waxes else np.newaxis for ax in daxes)
weight = weight[:][slobj]

# Broadcast `weight` against the shape of the truncation array
weight = np.broadcast_to(weight, data[:].shape).copy().reshape(-1)

if np.iscomplexobj(data):
weight *= 2.0

return weight

def process(self, data):
"""Truncate the incoming data.
Expand All @@ -818,8 +868,6 @@ def process(self, data):
Raises
------
`caput.pipeline.PipelineRuntimeError`
If input data has mismatching dataset and weight array shapes.
`config.CaputConfigError`
If the input data container has no preset values and `fixed_precision` or
`variance_increase` are not set in the config.
Expand Down Expand Up @@ -855,42 +903,24 @@ def process(self, data):
val, specs["fixed_precision"]
).reshape(old_shape)
else:
# If possible, extract the weight dataset from
# an attribute
if hasattr(data, specs["weight_dataset"]):
invvar = getattr(data, specs["weight_dataset"])
else:
wdset = data[specs["weight_dataset"]]
# Add missing axes to the weights dataset if
# needed and if possible
waxes = wdset.attrs.get("axis", [])
daxes = data[dset].attrs.get("axis", [])
# Add length-one axes
slobj = tuple(
slice(None) if ax in waxes else np.newaxis for ax in daxes
)
invvar = wdset[:][slobj]

invvar = np.broadcast_to(invvar, data[dset][:].shape).copy().reshape(-1)
invvar *= (2.0 if np.iscomplexobj(data[dset]) else 1.0) / specs[
"variance_increase"
]
wdset = self._get_weights(data, dset, specs["weight_dataset"])
wdset /= specs["variance_increase"]

if np.iscomplexobj(data[dset]):
data[dset][:].real = truncate.bit_truncate_weights(
val.real,
invvar,
wdset,
specs["fixed_precision"],
).reshape(old_shape)
data[dset][:].imag = truncate.bit_truncate_weights(
val.imag,
invvar,
wdset,
specs["fixed_precision"],
).reshape(old_shape)
else:
data[dset][:] = truncate.bit_truncate_weights(
val,
invvar,
wdset,
specs["fixed_precision"],
).reshape(old_shape)

Expand Down

0 comments on commit 5d0746c

Please sign in to comment.