Skip to content

Commit

Permalink
Merge pull request #1019 from AI-Hypercomputer:fix/aqteinsum
Browse files (browse the repository at this point in the history)
PiperOrigin-RevId: 695485206
  • Loading branch information
maxtext authors committed Nov 11, 2024
2 parents 1ff8505 + eb39a37 commit 97c3274
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 7 deletions.
6 changes: 3 additions & 3 deletions MaxText/configs/quantization/mp_scale.json
Original file line number | Diff line number | Diff line change
@@ -1,8 +1,8 @@
{
".*/query": {"bits": 4, "scale": 0.8},
".*/key": {"bits": 4, "scale": 0.9},
".*/query": {"bits": 4, "clipping_scale": 0.8},
".*/key": {"bits": 4, "clipping_scale": 0.9},
".*/value": {"bits": 4},
".*/out": {"bits": 4},
".*/wi_0": {"bits": 4},
".*/wo": {"bits": 4}
}
}
24 changes: 20 additions & 4 deletions MaxText/layers/quantizations.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -69,7 +69,11 @@ def _tiling_fn(lhs, rhs, dimension_numbers, tile_size):


def _rhs_axis_metadata_wrapper(
x: jnp.ndarray, tile_map, no_sharding_axis: Sequence[int], mesh_axes: Tuple[str, ...], is_tiled: bool
x: jnp.ndarray,
tile_map,
no_sharding_axis: Sequence[int],
mesh_axes: Tuple[str, ...],
is_tiled: bool,
):
mesh_axes = list(mesh_axes)
if is_tiled:
Expand Down Expand Up @@ -124,6 +128,7 @@ def dot_general_cls(self, mesh_axes: Tuple[str, ...] = ()):
else:
quant_dg, is_tiled, tiling_fn = self.quant_dg, False, None
rhs_axis_metadata_wrapper = self._get_rhs_axis_metadata_wrapper(mesh_axes, is_tiled)

aqt_dg_cls = functools.partial(
aqt_flax.AqtDotGeneral,
quant_dg,
Expand All @@ -138,15 +143,21 @@ def dot_general_cls(self, mesh_axes: Tuple[str, ...] = ()):

def einsum(self, mesh_axes: Tuple[str, ...] = ()):
"""Returns einsum configured with aqt params."""
rhs_axis_metadata_wrapper = self._get_rhs_axis_metadata_wrapper(mesh_axes)
if isinstance(self.quant_dg, dict):
quant_dg, is_tiled, tiling_fn = self._get_mixed_precision_cfg()
else:
quant_dg, is_tiled, tiling_fn = self.quant_dg, False, None

rhs_axis_metadata_wrapper = self._get_rhs_axis_metadata_wrapper(mesh_axes, is_tiled)
aqt_einsum = functools.partial(
aqt_flax.AqtEinsum(
cfg=self.quant_dg,
cfg=quant_dg,
rhs_quant_mode=self.quant_mode,
lhs_freeze_mode=aqt_flax.FreezerMode.NONE,
rhs_freeze_mode=aqt_flax.FreezerMode.CALIBRATION_AND_VALUE,
rhs_axis_metadata_wrapper=rhs_axis_metadata_wrapper,
use_legacy_freezer=False,
tiling_fn=tiling_fn,
)
)
return aqt_einsum
Expand Down Expand Up @@ -337,7 +348,12 @@ def quantize(self, kv: Array, axis_names: AxisNames):
return value, scale
raise ValueError(f"Invalid KV quant dtype:{self.dtype}.")

def einsum_fn_with_rhs_qtensor(self, kv: Array | aqt_tensor.QTensor, rhs_dequant_mode=None, rhs_calibration_mode=None):
def einsum_fn_with_rhs_qtensor(
self,
kv: Array | aqt_tensor.QTensor,
rhs_dequant_mode=None,
rhs_calibration_mode=None,
):
# Assumes kv is already quantized.
einsum = jnp.einsum
if isinstance(kv, aqt_tensor.QTensor):
Expand Down

0 comments on commit 97c3274

Please sign in to comment.