From 1da635d4d6283ae27a43b76085e9d97c48c63769 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Wed, 6 Dec 2023 15:30:15 +0000 Subject: [PATCH 01/16] Feat (graph/equalize): extended graph equalization --- src/brevitas/graph/equalize.py | 419 ++++++++++++++---- src/brevitas/graph/quantize.py | 1 + src/brevitas/graph/target/flexml.py | 6 +- tests/brevitas/graph/equalization_fixtures.py | 5 +- tests/brevitas/graph/test_equalization.py | 8 +- 5 files changed, 336 insertions(+), 103 deletions(-) diff --git a/src/brevitas/graph/equalize.py b/src/brevitas/graph/equalize.py index 9701c1fdf..b132e31b5 100644 --- a/src/brevitas/graph/equalize.py +++ b/src/brevitas/graph/equalize.py @@ -1,8 +1,14 @@ # Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. # SPDX-License-Identifier: BSD-3-Clause +<<<<<<< HEAD from abc import ABC from abc import abstractmethod +======= +from collections import namedtuple +from collections import OrderedDict +from copy import deepcopy +>>>>>>> Feat (graph/equalize): extended graph equalization from dataclasses import dataclass from dataclasses import field from functools import partial @@ -62,7 +68,13 @@ nn.ReLU, nn.LeakyReLU) -_scale_invariant_op = (torch.mul, operator.mul, operator.imul, operator.__mul__, operator.__imul__) +_scale_invariant_op = ( + torch.mul, + operator.mul, + operator.imul, + operator.__mul__, + operator.__imul__, + torch.nn.functional.interpolate) _select_op = (operator.getitem, operator.__getitem__) @@ -90,18 +102,36 @@ class WeightBiasTuple: # Required for being hashable @dataclass(eq=True, frozen=True) class Region: - srcs: Tuple = field(default_factory=tuple) - sinks: Tuple = field(default_factory=tuple) + srcs: Dict = field(default_factory=dict) + sinks: Dict = field(default_factory=dict) acts: Tuple = field(default_factory=tuple) @dataclass class WalkRegionState: - srcs: Set = field(default_factory=set) - sinks: Set = field(default_factory=set) + srcs: Dict = field(default_factory=dict) + sinks: Dict = field(default_factory=dict) acts: Set = field(default_factory=set) history: set = field(default_factory=set) add_mul_node: bool = False + offset: int = 0 + update_offset: bool = False + + +# Start and End identify the starting and ending channels of the weight matrix that need to be +# equalized. +# Offset refers to the relative position of these channels with respect to +# the other matrices' channels that are equalized simultaneously. +# Source matrix are always fully equalized, while sinks can be partially equalized. +@dataclass +class EqualizationIndexes: + start: int = 0 + end: int = 0 + offset: int = 0 + + +def __str__(self): + return str(self.start) + '_' + str(self.end) + '_' + str(self.offset) _UNSUPPORTED_OP = object() @@ -183,15 +213,6 @@ def _channel_maxabs(inp: torch.Tensor, dim: int = 1) -> torch.Tensor: return out -def _get_size(axes: Dict[nn.Module, int]) -> int: - m0, axis0 = list(axes.items())[0] - size = m0.weight.size(axis0) - for m, axis in axes.items(): - if m.weight.size(axis) != size: - return None - return size - - def _get_input_axis(module: nn.Module) -> Optional[int]: """ Given a sink module, determine the axis associated to the input channels. 
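Aside for readers of this hunk: a minimal, self-contained sketch of how the new EqualizationIndexes fields are meant to compose. It is not part of the patch; the layer names and channel counts are invented for illustration.

# Illustrative sketch only, not taken from the patch. `start`/`end` select the
# slice of a sink's channels to equalize; `offset` positions that slice on the
# shared axis that all srcs/sinks of a region share.
from dataclasses import dataclass


@dataclass
class EqualizationIndexes:
    start: int = 0
    end: int = 0
    offset: int = 0


# Two sources (16 and 32 output channels) whose outputs are concatenated and
# consumed by one sink with 48 input channels: each source is fully equalized
# and shifted by `offset`, while the sink spans the whole concatenated range.
srcs = {
    'conv_a': EqualizationIndexes(start=0, end=16, offset=0),
    'conv_b': EqualizationIndexes(start=0, end=32, offset=16),
}
sinks = {'conv_c': EqualizationIndexes(start=0, end=48, offset=0)}

# Mirrors the size check _cross_layer_equalization performs in this patch: the
# equalization axis seen from the sources must match the one seen from the sinks.
max_shape_srcs = max(i.end + i.offset for i in srcs.values())
max_shape_sinks = max(i.offset + (i.end - i.start) for i in sinks.values())
assert max_shape_srcs == max_shape_sinks == 48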
@@ -282,7 +303,7 @@ def _combine_weights_bias( return weight.data bias = bias.data - weight = weight.data.reshape(weight.shape[0], -1) + weight = weight.reshape(weight.shape[0], -1) bias = bias.reshape(-1, 1) weight = torch.where( @@ -324,8 +345,8 @@ def transpose(module: torch.nn.Module, axis: int): def _cross_layer_equalization( - srcs: List[nn.Module], - sinks: List[nn.Module], + srcs: Dict[nn.Module, List[int]], + sinks: Dict[nn.Module, List[int]], merge_bias: bool, scale_computation_type: str, bias_shrinkage: Optional[Union[float, str]] = None, @@ -339,41 +360,60 @@ def _cross_layer_equalization( """ # Determine device and type of tensors - device = next(sinks[0].parameters()).device - dtype = next(sinks[0].parameters()).dtype + device = next(sinks['sinks0'][0].parameters()).device + dtype = next(sinks['sinks0'][0].parameters()).dtype # If equalization criteria are not met, we return a scalar one to indicate that no equalization # has been performed def _no_equalize(): + print("No Eq") return torch.tensor(1., dtype=dtype, device=device) - src_axes = {} - sink_axes = {} act_sink_axes = {} act_sources_axes = {} + single_module = list(sinks.values())[0][0] + device = next(single_module.parameters()).device + dtype = next(single_module.parameters()).dtype + + max_shape_srcs = 0 + for name, (k, v) in srcs.items(): + max_shape_srcs = max(max_shape_srcs, v.end + v.offset) + max_shape_sinks = 0 + for name, (k, v) in sinks.items(): + max_shape_sinks = max(max_shape_sinks, v.offset + (v.end - v.start)) + # Exit if source and sink have different sizes + + if max_shape_srcs != max_shape_sinks and len(srcs) > 0: + return _no_equalize() - for i, module in enumerate(srcs): + # for i, module in enumerate(srcs): + src_axes = {} + for i, (name, (module, indexes)) in enumerate(srcs.items()): # If module is not supported, do not perform graph equalization + axis = _get_output_axis(module) + act_sources_axes[name] = _get_act_axis(module) if not isinstance(module, _supported_layers): return _no_equalize() if isinstance(module, nn.MultiheadAttention): - srcs[i] = module.out_proj - src_axes[srcs[i]] = _get_output_axis(module) - act_sources_axes[srcs[i]] = _get_act_axis(module) + module = module.out_proj + srcs[name] = (module, indexes) + src_axes[name] = (module, axis) - for i, module in enumerate(sinks): + sink_axes = {} + for i, (name, (module, indexes)) in enumerate(sinks.items()): + axis = _get_input_axis(module) + act_sink_axes[name] = _get_act_axis(module) # If module is not supported, do not perform graph equalization if not isinstance(module, _supported_layers): return _no_equalize() # For MultiheadAttention, we support only self-attetion if isinstance(module, nn.MultiheadAttention) and module.in_proj_weight is not None: # For sinks, we only need to modify the weight but not the bias - sinks[i] = WeightBiasTuple(weight=module.in_proj_weight) + module = WeightBiasTuple(module.in_proj_weight) + sinks[name] = (module, indexes) elif isinstance(module, nn.MultiheadAttention) and module.in_proj_weight is None: return _no_equalize() - sink_axes[sinks[i]] = _get_input_axis(module) - act_sink_axes[sinks[i]] = _get_act_axis(module) - + sink_axes[name] = (module, axis) # If act_val is enabled, use source or sink weights to determine the activation channel # For example, if the source is BatchNorm, we need to use the information coming from the sinks if list_of_act_val is not None: @@ -396,40 +436,51 @@ def _no_equalize(): if None in axes_to_check: return _no_equalize() - # Check if the sink_size is None, - # 
which means that the some of the sinks do not have the same size as the others. - sink_size = _get_size(sink_axes) - if None in [sink_size]: - return _no_equalize() - scale_fn = _select_scale_computation_fn(scale_computation_type) - sink_weights = [transpose(m, axis) for m, axis in sink_axes.items()] - sinks_range = scale_fn(torch.cat([w.reshape(w.size(0), -1) for w in sink_weights], 1)) + sink_weights = {name: transpose(m, axis) for name, (m, axis) in sink_axes.items()} + srcs_range = -1 * torch.ones(max_shape_srcs, device=device, dtype=dtype) + sinks_range = -1 * torch.ones(max_shape_sinks, device=device, dtype=dtype) + for k, v in sink_weights.items(): + # Sinks can be partially equalized, thus we need to select + # only the channels we are interested in + indexes = sinks[k][1] + # Compute the range of the channels we need to equalize + weight_range = scale_fn(v.reshape(v.size(0), -1))[indexes.start:indexes.end] + # Compute the numbers of channels we are equalizing + channel_range = indexes.end - indexes.start + # Use the offset and the range to update the correct range in the sinks + sinks_range[indexes.offset:indexes.offset + channel_range] = torch.max( + sinks_range[indexes.offset:indexes.offset + channel_range], weight_range) + sinks_range = torch.clamp(sinks_range, EPSILON) # Determine the srcs_range based on where we are performing activation equalization or # weight equalization if list_of_act_val is not None: list_of_act_val_shapes = [act_val.shape for act_val in list_of_act_val] + if len(list_of_act_val_shapes) > 0: + shape_0 = list_of_act_val_shapes[0] + if any(shape_0 != shape for shape in list_of_act_val_shapes): + return _no_equalize() list_of_act_val = [ transpose(WeightBiasTuple(act_val), act_axis) for act_val in list_of_act_val] srcs_range = scale_fn( torch.cat([act_val.reshape(act_val.size(0), -1) for act_val in list_of_act_val], 1)) else: - # If we do weight equalization, perform additional check on source size - src_size = _get_size(src_axes) - # Exit if source and sink have different different sizes, or if sources contains None - if src_size != sink_size or None in [src_size]: - warnings.warn( - "Detected source and sink with non compatible shapes, equalization is skipped") - return _no_equalize() - if merge_bias: - src_weights = [ - _combine_weights_bias(transpose(m, axis), bias_shrinkage, m.bias) for m, - axis in src_axes.items()] + src_weights = { + name: _combine_weights_bias(transpose(m, axis), bias_shrinkage, m.bias) + for name, (m, axis) in src_axes.items()} else: - src_weights = [transpose(m, axis) for m, axis in src_axes.items()] - srcs_range = scale_fn(torch.cat([w.reshape(w.size(0), -1) for w in src_weights], 1)) + src_weights = {name: transpose(m, axis) for name, (m, axis) in src_axes.items()} + for k, v in src_weights.items(): + # Srcs are always fully equalized, thus we simply need to apply the offset to position them + # correctly with respect to the other srcs matrices. 
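Aside: a minimal sketch of the bookkeeping done by the sink loop above and the source loop that follows, using made-up shapes and per-channel ranges instead of the patch's scale_fn output.

import torch

max_shape_sinks = 6
sinks_range = -1 * torch.ones(max_shape_sinks)

# Two sinks contribute per-channel ranges (made-up values) to the shared axis;
# where their slices overlap, the elementwise maximum wins, as in the loop above.
for start, end, offset, weight_range in [
        (0, 4, 0, torch.tensor([0.5, 2.0, 1.5, 3.0])),
        (0, 4, 2, torch.tensor([1.0, 4.0, 0.2, 0.7]))]:
    channel_range = end - start
    sinks_range[offset:offset + channel_range] = torch.max(
        sinks_range[offset:offset + channel_range], weight_range)

sinks_range = torch.clamp(sinks_range, 1e-9)  # local stand-in for EPSILON
print(sinks_range)  # tensor([0.5, 2.0, 1.5, 4.0, 0.2, 0.7])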
+ indexes = srcs[k][1] + channel_start = indexes.offset + indexes.start + channel_end = indexes.offset + indexes.end + weight_range = scale_fn(v.reshape(v.size(0), -1)) + srcs_range[channel_start:channel_end] = torch.max( + srcs_range[channel_start:channel_end], weight_range) # If there is a mismatch between srcs and sinks values, exit if srcs_range.shape != sinks_range.shape: @@ -455,32 +506,44 @@ def _no_equalize(): for act_val_shape, insert_mul_node_fn in zip(list_of_act_val_shapes, list_of_insert_mul_node_fn): insert_mul_node_fn(inverse_scaling_factors, act_val_shape, act_axis) if len(src_axes) > 0: - for module, axis in src_axes.items(): + for name, (module, axis) in src_axes.items(): + indexes = srcs[name][1] + channel_start = indexes.offset + indexes.start + channel_end = indexes.offset + indexes.end if hasattr(module, 'bias') and module.bias is not None: _update_weights( module, - module.bias.clone() * inverse_scaling_factors.view_as(module.bias), + module.bias.clone() * + inverse_scaling_factors[channel_start:channel_end].view_as(module.bias), attr='bias') src_broadcast_size = [1] * module.weight.ndim src_broadcast_size[axis] = module.weight.size(axis) + _update_weights( - module, ( - module.weight.clone() * - torch.reshape(inverse_scaling_factors, src_broadcast_size)), + module, + module.weight.clone() * torch.reshape( + inverse_scaling_factors[channel_start:channel_end], src_broadcast_size), attr='weight') for module, axis in sink_axes.items(): sink_broadcast_size = [1] * module.weight.ndim sink_broadcast_size[axis] = module.weight.size(axis) + indexes = sinks[name][1] + channel_range = indexes.end - indexes.start + partial_scaling = torch.ones(module.weight.size(axis), device=device, dtype=dtype) + # We replace the scaling factors of the channels we need to equalize, leaving the other to + # one (i.e., no equalization) + partial_scaling[indexes.start:indexes.end] = scaling_factors[indexes.offset:indexes.offset + + channel_range] if isinstance(module, _batch_norm): # We re-compute the bias as function of running_mean and running_var to adjust the # additive factor for equalization. 
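Aside: the partial_scaling idea from the sink update above, reduced to a toy example with invented sizes and factors; only the selected slice of input channels is rescaled, the remaining channels keep a factor of one and are left untouched.

import torch

# The shared axis has 4 equalized channels; the sink below has 8 input channels
# and only its slice [2, 6) takes part, placed at offset 0 on the shared axis.
start, end, offset = 2, 6, 0
scaling_factors = torch.tensor([1.2, 0.8, 1.5, 0.6])  # made-up values
partial_scaling = torch.ones(8)
partial_scaling[start:end] = scaling_factors[offset:offset + (end - start)]
print(partial_scaling)  # -> [1, 1, 1.2, 0.8, 1.5, 0.6, 1, 1], one factor per channel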
additive_factor = module.running_mean.data * module.weight.data / torch.sqrt( module.running_var.data + module.eps) _update_weights( - module, module.bias.clone() + additive_factor * (scaling_factors - 1), attr='bias') + module, module.bias.clone() + additive_factor * (partial_scaling - 1), attr='bias') _update_weights( module, - module.weight.clone() * torch.reshape(scaling_factors, sink_broadcast_size), + module.weight.clone() * torch.reshape(partial_scaling, sink_broadcast_size), attr='weight') return scaling_factors @@ -493,6 +556,16 @@ def _update_weights(original_module, new_value, attr='weight'): setattr(original_module, attr, nn.Parameter(new_value)) +def _organize_region(region, name_to_module, type): + region_dict = {} + region = getattr(region, type) + for i, (k, v) in enumerate(region.items()): + name = type + str(i) + k = k.split('$')[0] + region_dict[name] = (name_to_module[k], v) + return region_dict + + def _equalize( model: GraphModule, regions: Set[Tuple[str]], @@ -507,22 +580,25 @@ def _equalize( name_to_module: Dict[str, nn.Module] = {} name_set = set() for region in regions: - for name in region.srcs: - name_set.add(name) - for name in region.sinks: - name_set.add(name) + for name in region.srcs.keys(): + name_set.add(name.split("$")[0]) + for name in region.sinks.keys(): + name_set.add(name.split("$")[0]) for name, module in model.named_modules(): if name in name_set: name_to_module[name] = module for i in range(iterations): scale_factor_max = None - for region in regions: + for ii, region in enumerate(regions): + srcs_dict = _organize_region(region, name_to_module, 'srcs') + sinks_dict = _organize_region(region, name_to_module, 'sinks') scale_factors_region = _cross_layer_equalization( - [name_to_module[n] for n in region.srcs], [name_to_module[n] for n in region.sinks], + srcs_dict, + sinks_dict, merge_bias=merge_bias, - scale_computation_type=scale_computation_type, - bias_shrinkage=bias_shrinkage) + bias_shrinkage=bias_shrinkage, + scale_computation_type=scale_computation_type) scale_factor_region_max = torch.max(torch.abs(1 - scale_factors_region)) if scale_factor_max is not None: scale_factor_max = torch.max(scale_factor_max, scale_factor_region_max) @@ -557,46 +633,153 @@ def _is_scale_varying_activation(graph_model, node): def _is_scale_invariant_function(node: Node) -> bool: - return node.op == 'call_function' and node.target in _scale_invariant_op + _select_op + out = node.op == 'call_function' and node.target in _scale_invariant_op + _select_op + if node.target == torch.nn.functional.interpolate: + out &= node.kwargs.get('mode', None) == 'nearest' + return out def _is_reshaping_op(node: Node) -> bool: return node.target in _reshaping_op +def get_weight_source(module_list): + transpose = lambda module, axis: module.weight if axis == 0 else module.weight.transpose(0, 1) + for i, module in enumerate(module_list): + if isinstance(module, nn.MultiheadAttention): + if hasattr(module, 'out_proj'): + module_list[i] = module.out_proj + else: + raise RuntimeError("Configuration for Multiheadattention not supported") + srcs_axes = {module: _get_output_axis(module) for module in module_list} + weight = [transpose(m, axis) for m, axis in srcs_axes.items()] + return weight + + +def get_weight_sink(module_list): + transpose = lambda module, axis: module.weight if axis == 0 else module.weight.transpose(0, 1) + for i, module in enumerate(module_list): + if isinstance(module, nn.MultiheadAttention): + if hasattr(module, 'in_proj_weight'): + module_list[i] = 
WeightBiasTuple(module.in_proj_weight) + else: + raise RuntimeError("Configuration for Multiheadattention not supported") + sinks_axes = {module: _get_input_axis(module) for module in module_list} + weight = [transpose(m, axis) for m, axis in sinks_axes.items()] + return weight + + +def find_srcs_channel_dim(model, inp_node): + if _is_supported_module(model, inp_node): + # If we meet a supported module, determine the channel shape + module = get_module(model, inp_node.target) + # Since we are walking up, we consider the module as srcs + weight = get_weight_source([module]) + channel = weight[0].shape[0] + return channel + elif _is_add(inp_node): + # If it's add, we need the channel shape of one of the branches, since they are all the same + return find_srcs_channel_dim(model, inp_node.all_input_nodes[0]) + elif _is_cat(inp_node): + total_channels = 0 + # If it's cat, we need to sum the channel shape of all the branches + for n in inp_node.all_input_nodes: + total_channels += find_srcs_channel_dim(model, n) + return total_channels + elif _is_scale_invariant_module(model, inp_node) or _is_scale_invariant_function(inp_node): + return find_srcs_channel_dim(model, inp_node.all_input_nodes[0]) + else: + return _UNSUPPORTED_OP + + +def cat_handler(graph_model: GraphModule, starting_node: Node, state: WalkRegionState): + + state.srcs.clear() + state.sinks.clear() + state.history.clear() + state.srcs[starting_node.target] = _UNSUPPORTED_OP + state.update_offset = True + state.offset = 0 + find_srcs(graph_model, starting_node, state) + state.update_offset = False + state.offset = 0 + find_sinks(graph_model, starting_node, state) + + +def _is_cat(node): + return node.target in (torch.cat,) + + +def _is_cat_in_srcs(srcs): + out = False + for src in srcs: + out = out or src in (torch.cat,) + return out + + +def _is_add(node): + return ( + node.op == 'call_method' and node.target in _residual_methods or + node.op == 'call_function' and node.target in _residual_fns) + + def find_srcs(graph_model: GraphModule, starting_node: Node, state: WalkRegionState) -> Dict[str, Set]: node_list = starting_node.all_input_nodes + update_offset_state = state.update_offset for node in node_list: # we keep a history of how the graph has been walked already, invariant to the direction, # to avoid getting stuck in a loop path = (node, starting_node) + module = None if path not in state.history: state.history.add(path) else: continue if _is_supported_module(graph_model, node): - state.srcs.add(node.target) + module = get_module(graph_model, node.target) + weight = get_weight_source([module]) + eq_indexes = EqualizationIndexes(0, weight[0].shape[0], state.offset) + # state.srcs.add((node.target, eq_indexes)) + full_source_name = node.target + '$' + str(eq_indexes) + state.srcs[full_source_name] = eq_indexes # After we found a source, we need to check if it branches into multiple sinks find_sinks(graph_model, node, state) + state.offset = state.offset if not state.update_offset else state.offset + weight[ + 0].shape[0] elif _is_scale_invariant_module( graph_model, node) or _is_scale_invariant_function(node) or _is_reshaping_op(node): - find_srcs(graph_model, node, state) find_sinks(graph_model, node, state) + find_srcs(graph_model, node, state) elif (node.op == 'call_method' and node.target in _residual_methods or node.op == 'call_function' and node.target in _residual_fns): - find_srcs(graph_model, node, state) + state.update_offset = False find_sinks(graph_model, node, state) + find_srcs(graph_model, node, state) + 
state.update_offset = update_offset_state + elif _is_cat(node): + # We have never encoutered cat + if not _is_cat_in_srcs(state.srcs): + # We restart the region search starting from the cat + cat_handler(graph_model, node, state) + # We have encoutered cat already + else: + state.update_offset = False + find_sinks(graph_model, node, state) + state.update_offset = True + find_srcs(graph_model, node, state) + state.update_offset = update_offset_state elif node.target in _ignore_ops: continue else: # If we meet an unrecognized op, we add None to invalidate the region - state.srcs.add(_UNSUPPORTED_OP) + state.sinks[_UNSUPPORTED_OP] = _UNSUPPORTED_OP def find_sinks(graph_model: GraphModule, starting_node: Node, state: WalkRegionState) -> Dict[str, Set]: node_list = starting_node.users + update_offset_state = state.update_offset for node in node_list: # we keep a history of how the graph has been walked already, invariant to the direction, # to avoid getting stuck in a loop @@ -608,50 +791,88 @@ def find_sinks(graph_model: GraphModule, starting_node: Node, continue if _is_supported_module(graph_model, node): module = get_module(graph_model, node.target) + weight = get_weight_sink([module]) + eq_indexes = EqualizationIndexes(0, weight[0].shape[0], state.offset) # It is not possible to equalize through LayerNorm as sink - if isinstance(module, (nn.LayerNorm,) + _batch_norm): - state.sinks.add(_UNSUPPORTED_OP) + if isinstance(module, (nn.LayerNorm,)): + # state.sinks.add((_UNSUPPORTED_OP, _UNSUPPORTED_OP)) + state.sinks[_UNSUPPORTED_OP] = _UNSUPPORTED_OP else: - state.sinks.add(node.target) + full_sink_name = node.target + '$' + str(eq_indexes) + state.sinks[full_sink_name] = eq_indexes + # state.sinks[node.target] = eq_indexes elif _is_scale_invariant_module( graph_model, node) or _is_scale_invariant_function(node) or _is_reshaping_op(node): find_sinks(graph_model, node, state) elif (node.op == 'call_method' and node.target in _residual_methods or node.op == 'call_function' and node.target in _residual_fns): + state.update_offset = False find_sinks(graph_model, node, state) find_srcs(graph_model, node, state) + state.update_offset = update_offset_state + elif _is_cat(node): + # We have never encoutered cat + if not _is_cat_in_srcs(state.srcs): + # We restart the region search starting from the cat + cat_handler(graph_model, node, state) + # We have encoutered cat already + else: + # In this case we define all our sinks, and isolate only the channels we want + # to equalize (start, end). 
+ # Furthermore, we need to consider the offset given by the sources of the second cat + index = node.all_input_nodes.index(starting_node) + channels = [] + for n in node.all_input_nodes: + channels.append(find_srcs_channel_dim(graph_model, n)) + start = sum(channels[:index]) + end = start + channels[index] + new_state = WalkRegionState(offset=state.offset) + find_sinks(graph_model, node, new_state) + + for k in new_state.sinks.keys(): + state.sinks[k] = EqualizationIndexes(start, end, new_state.offset) + state.srcs.update(new_state.srcs) elif node.target in _ignore_ops: continue else: # If we meet an unrecognized op, we add None to invalidate the region - state.sinks.add(_UNSUPPORTED_OP) + state.sinks[_UNSUPPORTED_OP] = _UNSUPPORTED_OP def _extract_regions( graph_model: GraphModule, add_mul_node: bool = False, return_acts: bool = False) -> List[Region]: - regions = [] + regions = list() + regions_name = set() for node in graph_model.graph.nodes: if _is_supported_module(graph_model, node) or (add_mul_node and _is_scale_varying_activation(graph_model, node)): - state = WalkRegionState(srcs={node.target}, add_mul_node=add_mul_node) + state = WalkRegionState(add_mul_node=add_mul_node) if _is_scale_varying_activation(graph_model, node): state.acts.add(node.target) + else: + module = get_module(graph_model, node.target) + weight = get_weight_source([module]) + eq_indexes = EqualizationIndexes(0, weight[0].shape[0], 0) + full_source_name = node.target + '$' + str(eq_indexes) + state.srcs[full_source_name] = eq_indexes find_sinks(graph_model, node, state) - if state.sinks and _UNSUPPORTED_OP not in state.sinks and _UNSUPPORTED_OP not in state.srcs: - # each region should appear only once, so to make it hashable - # we convert srcs and sinks to ordered lists first, and then to tuples - srcs = tuple(sorted(state.srcs)) - sinks = tuple(sorted(state.sinks)) - acts = tuple(sorted(state.acts)) + if state.sinks and _UNSUPPORTED_OP not in state.sinks.keys( + ) and _UNSUPPORTED_OP not in state.srcs.keys(): + # Drop cat from the srcs + state.srcs = {k: v for k, v in state.srcs.items() if k is not torch.cat} + sorted_srcs = dict(sorted(state.srcs.items())) + sorted_sinks = dict(sorted(state.sinks.items())) + sorted_acts = tuple(sorted(state.acts)) if return_acts: - region_to_add = Region(srcs=srcs, sinks=sinks, acts=acts) + region = Region(srcs=sorted_srcs, sinks=sorted_sinks, acts=sorted_acts) else: - region_to_add = Region(srcs=srcs, sinks=sinks) - if region_to_add not in regions: - regions.append(region_to_add) + region = Region(srcs=sorted_srcs, sinks=sorted_sinks) + + if region not in regions: + regions.append(region) return regions @@ -773,7 +994,9 @@ def find_module(self, model, regions: List): """ if isinstance(model, _supported_layers) and not isinstance(model, _batch_norm + (nn.LayerNorm,)): - regions.append(model) + weight = get_weight_sink([model]) + eq_indexes = EqualizationIndexes(0, weight[0].shape[0], 0) + regions.append((model, eq_indexes)) else: for module in model.children(): self.find_module(module, regions) @@ -785,25 +1008,25 @@ def setup(self): batch_dim = 0 if region.batch_first else 1 hook_fn = partial( - self.forward_stats_hook, name=region, batch_dim=batch_dim, use_inp=True) - new_instance = KwargsForwardHook(region, hook_fn) - ModuleInstanceToModuleInstance(region, new_instance).apply(self.model) + self.forward_stats_hook, name=region[0], batch_dim=batch_dim, use_inp=True) + new_instance = KwargsForwardHook(region[0], hook_fn) + ModuleInstanceToModuleInstance(region[0], 
new_instance).apply(self.model) self.hooks.append(new_instance) def apply(self, alpha): scale_factors = [] self.remove_hooks() for region in self.regions: - if self.float_act_map[region] == None: + if self.float_act_map[region[0]] == None: continue sinks = region insert_mul_fn = partial( - self.insert_mul_node, region=region, batch_dim=self.batch_dim_act_map[region]) + self.insert_mul_node, region=region[0], batch_dim=self.batch_dim_act_map[region[0]]) scale_factors.append( - _cross_layer_equalization([], [sinks], + _cross_layer_equalization({}, {'sinks0': sinks}, False, scale_computation_type=self.scale_computation_type, - list_of_act_val=[self.float_act_map[region]], + list_of_act_val=[self.float_act_map[region[0]]], list_of_insert_mul_node_fn=[insert_mul_fn], alpha=alpha)) return scale_factors @@ -896,10 +1119,12 @@ def apply(self, alpha): self.insert_mul_node, act_node=act_node, batch_dim=self.batch_dim_act_map[act_name])) + srcs_dict = _organize_region(region, name_to_module, 'srcs') + sinks_dict = _organize_region(region, name_to_module, 'sinks') scale_factors.append( _cross_layer_equalization( - srcs, - sinks, + srcs_dict, + sinks_dict, False, scale_computation_type=self.scale_computation_type, list_of_act_val=list_of_act_val, diff --git a/src/brevitas/graph/quantize.py b/src/brevitas/graph/quantize.py index 63143c4e5..458be7029 100644 --- a/src/brevitas/graph/quantize.py +++ b/src/brevitas/graph/quantize.py @@ -262,6 +262,7 @@ def preprocess_for_quantize( equalize_iters=0, equalize_merge_bias=True, merge_bn=True, + include_cat=True, equalize_bias_shrinkage: str = 'vaiq', equalize_scale_computation: str = 'maxabs'): diff --git a/src/brevitas/graph/target/flexml.py b/src/brevitas/graph/target/flexml.py index c10f688e3..f6817c196 100644 --- a/src/brevitas/graph/target/flexml.py +++ b/src/brevitas/graph/target/flexml.py @@ -124,8 +124,9 @@ def preprocess_for_flexml_quantize( equalize_iters=0, equalize_merge_bias=True, merge_bn=True, - equalize_bias_shrinkage: str = 'vaiq', - equalize_scale_computation: str = 'maxabs', + equalize_bias_shrinkage='vaiq', + equalize_scale_computation='maxabs', + include_cat=True, **model_kwargs): training_state = model.training model.eval() @@ -141,6 +142,7 @@ def preprocess_for_flexml_quantize( equalize_iters, equalize_merge_bias, merge_bn, + include_cat, equalize_bias_shrinkage, equalize_scale_computation) model.train(training_state) diff --git a/tests/brevitas/graph/equalization_fixtures.py b/tests/brevitas/graph/equalization_fixtures.py index 8e13c49cd..466cf82a3 100644 --- a/tests/brevitas/graph/equalization_fixtures.py +++ b/tests/brevitas/graph/equalization_fixtures.py @@ -42,8 +42,11 @@ def equalize_test(model, regions, merge_bias, bias_shrinkage, scale_computation_ name_to_module[name] = module for i in range(3): for region in regions: + srcs_dict = _organize_region(region, name_to_module, 'srcs') + sinks_dict = _organize_region(region, name_to_module, 'sinks') scale_factors_region = _cross_layer_equalization( - [name_to_module[n] for n in region.srcs], [name_to_module[n] for n in region.sinks], + srcs_dict, + sinks_dict, merge_bias=merge_bias, bias_shrinkage=bias_shrinkage, scale_computation_type=scale_computation_type) diff --git a/tests/brevitas/graph/test_equalization.py b/tests/brevitas/graph/test_equalization.py index caca0fd29..cb07d26ee 100644 --- a/tests/brevitas/graph/test_equalization.py +++ b/tests/brevitas/graph/test_equalization.py @@ -36,7 +36,7 @@ def test_resnet18_equalization(): # Check that equalization is not introducing FP 
variations assert torch.allclose(expected_out, out, atol=ATOL) - regions = sorted(regions, key=lambda region: region.srcs[0]) + regions = sorted(regions, key=lambda region: sorted([r for r in region.srcs.keys()])) resnet_18_regions = sorted(RESNET_18_REGIONS, key=lambda region: region[0][0]) equalized_layers = set() for r in resnet_18_regions: @@ -45,8 +45,10 @@ def test_resnet18_equalization(): # Check that we found all the expected regions for region, expected_region in zip(regions, resnet_18_regions): - sources_check = set(region.srcs) == set(expected_region[0]) - sinks_check = set(region.sinks) == set(expected_region[1]) + srcs = list(region.srcs) + sources_check = set(srcs) == set(expected_region[0]) + sinks = list(region.sinks) + sinks_check = set(sinks) == set(expected_region[1]) assert sources_check assert sinks_check From 9dbc1456e2e9184ba2a1b0d8d802de8b36fda9db Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Wed, 6 Dec 2023 15:31:38 +0000 Subject: [PATCH 02/16] Add missing import --- tests/brevitas/graph/equalization_fixtures.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/brevitas/graph/equalization_fixtures.py b/tests/brevitas/graph/equalization_fixtures.py index 466cf82a3..7ff94fdee 100644 --- a/tests/brevitas/graph/equalization_fixtures.py +++ b/tests/brevitas/graph/equalization_fixtures.py @@ -11,6 +11,7 @@ from brevitas import torch_version from brevitas.graph.equalize import _cross_layer_equalization +from brevitas.graph.equalize import _organize_region SEED = 123456 ATOL = 1e-3 From 7ac452536fe8094d5bf4024549c58c4978da6b5e Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Wed, 6 Dec 2023 21:19:45 +0000 Subject: [PATCH 03/16] Fix test --- src/brevitas/graph/equalize.py | 23 ++++++------------- tests/brevitas/graph/equalization_fixtures.py | 6 +++-- tests/brevitas/graph/test_equalization.py | 4 ++-- 3 files changed, 13 insertions(+), 20 deletions(-) diff --git a/src/brevitas/graph/equalize.py b/src/brevitas/graph/equalize.py index b132e31b5..5af1f0e0b 100644 --- a/src/brevitas/graph/equalize.py +++ b/src/brevitas/graph/equalize.py @@ -1,14 +1,8 @@ # Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. 
# SPDX-License-Identifier: BSD-3-Clause -<<<<<<< HEAD from abc import ABC from abc import abstractmethod -======= -from collections import namedtuple -from collections import OrderedDict -from copy import deepcopy ->>>>>>> Feat (graph/equalize): extended graph equalization from dataclasses import dataclass from dataclasses import field from functools import partial @@ -183,10 +177,13 @@ def dict_name_to_module(model, regions): name_set = set() for region in regions: for name in region.srcs: + name = name.split("$")[0] name_set.add(name) for name in region.sinks: + name = name.split("$")[0] name_set.add(name) for name in region.acts: + name = name.split("$")[0] name_set.add(name) for name, module in model.named_modules(): if name in name_set: @@ -794,7 +791,7 @@ def find_sinks(graph_model: GraphModule, starting_node: Node, weight = get_weight_sink([module]) eq_indexes = EqualizationIndexes(0, weight[0].shape[0], state.offset) # It is not possible to equalize through LayerNorm as sink - if isinstance(module, (nn.LayerNorm,)): + if isinstance(module, (nn.LayerNorm,) + _batch_norm): # state.sinks.add((_UNSUPPORTED_OP, _UNSUPPORTED_OP)) state.sinks[_UNSUPPORTED_OP] = _UNSUPPORTED_OP else: @@ -1096,16 +1093,10 @@ def apply(self, alpha): name_to_module = dict_name_to_module(self.model, self.regions) for region in self.regions: region_to_search = region.sinks if len(region.acts) == 0 else region.acts - if any([self.float_act_map[name] is None for name in region_to_search]): + if any([self.float_act_map[name.split("$")[0]] is None for name in region_to_search]): continue - act_module = [name_to_module[act_name] for act_name in region.acts] - list_of_act_val = [self.float_act_map[name] for name in region_to_search] - sinks = [name_to_module[sink] for sink in region.sinks] - # Filter out scale_varying activations from the srcs - srcs = [ - name_to_module[src] - for src in region.srcs - if not isinstance(name_to_module[src], _scale_varying_activations)] + act_module = [name_to_module[act_name.split("$")[0]] for act_name in region.acts] + list_of_act_val = [self.float_act_map[name.split("$")[0]] for name in region_to_search] list_of_insert_mul_node_fn = None if self.add_mul_node and any([ diff --git a/tests/brevitas/graph/equalization_fixtures.py b/tests/brevitas/graph/equalization_fixtures.py index 7ff94fdee..d78c3c1c9 100644 --- a/tests/brevitas/graph/equalization_fixtures.py +++ b/tests/brevitas/graph/equalization_fixtures.py @@ -21,8 +21,8 @@ 'shufflenet_v2_x0_5': [0.318, 0.649], 'mobilenet_v2': [0.161, 0.320], 'resnet18': [0.487, 0.952], - 'googlenet': [0.1826, 0.413], - 'inception_v3': [0.264, 0.6], + 'googlenet': [0.495, 0.982], + 'inception_v3': [0.582, 0.989], 'alexnet': [0.875, 0.875],} IN_SIZE_CONV = (1, 3, 224, 224) @@ -34,8 +34,10 @@ def equalize_test(model, regions, merge_bias, bias_shrinkage, scale_computation_ name_set = set() for region in regions: for name in region.srcs: + name = name.split('$')[0] name_set.add(name) for name in region.sinks: + name = name.split('$')[0] name_set.add(name) scale_factors_regions = [] for name, module in model.named_modules(): diff --git a/tests/brevitas/graph/test_equalization.py b/tests/brevitas/graph/test_equalization.py index cb07d26ee..11fd579cc 100644 --- a/tests/brevitas/graph/test_equalization.py +++ b/tests/brevitas/graph/test_equalization.py @@ -86,8 +86,8 @@ def test_equalization_torchvision_models(model_coverage: tuple, merge_bias: bool srcs = set() sinks = set() for r in regions: - srcs.update(list(r.srcs)) - sinks.update(list(r.sinks)) + 
srcs.update([x.split("$")[0] for x in list(r.srcs)]) + sinks.update([x.split("$")[0] for x in list(r.sinks)]) count_region_srcs = 0 count_region_sinks = 0 From 60e15d00a1767cef62ef7ccf7c35a661d104342d Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Thu, 7 Dec 2023 15:55:00 +0000 Subject: [PATCH 04/16] Fix test v2 --- tests/brevitas/graph/equalization_fixtures.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/brevitas/graph/equalization_fixtures.py b/tests/brevitas/graph/equalization_fixtures.py index d78c3c1c9..015c339f8 100644 --- a/tests/brevitas/graph/equalization_fixtures.py +++ b/tests/brevitas/graph/equalization_fixtures.py @@ -22,7 +22,7 @@ 'mobilenet_v2': [0.161, 0.320], 'resnet18': [0.487, 0.952], 'googlenet': [0.495, 0.982], - 'inception_v3': [0.582, 0.989], + 'inception_v3': [0.497, 0.989], 'alexnet': [0.875, 0.875],} IN_SIZE_CONV = (1, 3, 224, 224) From bbeca9f0dc12f55600d47ebb723d79c47f88db11 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Thu, 7 Dec 2023 16:30:22 +0000 Subject: [PATCH 05/16] Review --- src/brevitas/graph/equalize.py | 38 +++++++++---------- tests/brevitas/graph/equalization_fixtures.py | 6 +-- tests/brevitas/graph/test_equalization.py | 4 +- 3 files changed, 23 insertions(+), 25 deletions(-) diff --git a/src/brevitas/graph/equalize.py b/src/brevitas/graph/equalize.py index 5af1f0e0b..c05366669 100644 --- a/src/brevitas/graph/equalize.py +++ b/src/brevitas/graph/equalize.py @@ -100,6 +100,14 @@ class Region: sinks: Dict = field(default_factory=dict) acts: Tuple = field(default_factory=tuple) + @property + def srcs_names(self): + return [name.split("$")[0] for name in self.srcs.keys()] + + @property + def sinks_names(self): + return [name.split("$")[0] for name in self.sinks.keys()] + @dataclass class WalkRegionState: @@ -176,14 +184,11 @@ def dict_name_to_module(model, regions): name_set = set() for region in regions: - for name in region.srcs: - name = name.split("$")[0] + for name in region.srcs_names: name_set.add(name) - for name in region.sinks: - name = name.split("$")[0] + for name in region.sinks_names: name_set.add(name) for name in region.acts: - name = name.split("$")[0] name_set.add(name) for name, module in model.named_modules(): if name in name_set: @@ -363,7 +368,6 @@ def _cross_layer_equalization( # If equalization criteria are not met, we return a scalar one to indicate that no equalization # has been performed def _no_equalize(): - print("No Eq") return torch.tensor(1., dtype=dtype, device=device) act_sink_axes = {} @@ -378,12 +382,11 @@ def _no_equalize(): max_shape_sinks = 0 for name, (k, v) in sinks.items(): max_shape_sinks = max(max_shape_sinks, v.offset + (v.end - v.start)) - # Exit if source and sink have different sizes + # Exit if source and sink have different sizes if max_shape_srcs != max_shape_sinks and len(srcs) > 0: return _no_equalize() - # for i, module in enumerate(srcs): src_axes = {} for i, (name, (module, indexes)) in enumerate(srcs.items()): # If module is not supported, do not perform graph equalization @@ -577,10 +580,10 @@ def _equalize( name_to_module: Dict[str, nn.Module] = {} name_set = set() for region in regions: - for name in region.srcs.keys(): - name_set.add(name.split("$")[0]) - for name in region.sinks.keys(): - name_set.add(name.split("$")[0]) + for name in region.srcs_names: + name_set.add(name) + for name in region.sinks_names: + name_set.add(name) for name, module in model.named_modules(): if name in name_set: @@ -737,7 +740,6 @@ def find_srcs(graph_model: GraphModule, 
starting_node: Node, module = get_module(graph_model, node.target) weight = get_weight_source([module]) eq_indexes = EqualizationIndexes(0, weight[0].shape[0], state.offset) - # state.srcs.add((node.target, eq_indexes)) full_source_name = node.target + '$' + str(eq_indexes) state.srcs[full_source_name] = eq_indexes # After we found a source, we need to check if it branches into multiple sinks @@ -792,12 +794,10 @@ def find_sinks(graph_model: GraphModule, starting_node: Node, eq_indexes = EqualizationIndexes(0, weight[0].shape[0], state.offset) # It is not possible to equalize through LayerNorm as sink if isinstance(module, (nn.LayerNorm,) + _batch_norm): - # state.sinks.add((_UNSUPPORTED_OP, _UNSUPPORTED_OP)) state.sinks[_UNSUPPORTED_OP] = _UNSUPPORTED_OP else: full_sink_name = node.target + '$' + str(eq_indexes) state.sinks[full_sink_name] = eq_indexes - # state.sinks[node.target] = eq_indexes elif _is_scale_invariant_module( graph_model, node) or _is_scale_invariant_function(node) or _is_reshaping_op(node): find_sinks(graph_model, node, state) @@ -1092,11 +1092,11 @@ def apply(self, alpha): self.remove_hooks() name_to_module = dict_name_to_module(self.model, self.regions) for region in self.regions: - region_to_search = region.sinks if len(region.acts) == 0 else region.acts - if any([self.float_act_map[name.split("$")[0]] is None for name in region_to_search]): + region_names = region.sinks_names if len(region.acts) == 0 else region.acts + if any([self.float_act_map[name] is None for name in region_names]): continue - act_module = [name_to_module[act_name.split("$")[0]] for act_name in region.acts] - list_of_act_val = [self.float_act_map[name.split("$")[0]] for name in region_to_search] + act_module = [name_to_module[act_name] for act_name in region.acts] + list_of_act_val = [self.float_act_map[name] for name in region_names] list_of_insert_mul_node_fn = None if self.add_mul_node and any([ diff --git a/tests/brevitas/graph/equalization_fixtures.py b/tests/brevitas/graph/equalization_fixtures.py index 015c339f8..8feeb5c6b 100644 --- a/tests/brevitas/graph/equalization_fixtures.py +++ b/tests/brevitas/graph/equalization_fixtures.py @@ -33,11 +33,9 @@ def equalize_test(model, regions, merge_bias, bias_shrinkage, scale_computation_ name_to_module = {} name_set = set() for region in regions: - for name in region.srcs: - name = name.split('$')[0] + for name in region.srcs_names: name_set.add(name) - for name in region.sinks: - name = name.split('$')[0] + for name in region.sinks_names: name_set.add(name) scale_factors_regions = [] for name, module in model.named_modules(): diff --git a/tests/brevitas/graph/test_equalization.py b/tests/brevitas/graph/test_equalization.py index 11fd579cc..f47c24434 100644 --- a/tests/brevitas/graph/test_equalization.py +++ b/tests/brevitas/graph/test_equalization.py @@ -86,8 +86,8 @@ def test_equalization_torchvision_models(model_coverage: tuple, merge_bias: bool srcs = set() sinks = set() for r in regions: - srcs.update([x.split("$")[0] for x in list(r.srcs)]) - sinks.update([x.split("$")[0] for x in list(r.sinks)]) + srcs.update([x for x in list(r.srcs_names)]) + sinks.update([x for x in list(r.sinks_names)]) count_region_srcs = 0 count_region_sinks = 0 From 2b6e81d95e4f67786847e2bd315884de0fcd385c Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Mon, 11 Dec 2023 17:14:18 +0000 Subject: [PATCH 06/16] Fix test --- tests/brevitas/graph/test_equalization.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git 
a/tests/brevitas/graph/test_equalization.py b/tests/brevitas/graph/test_equalization.py index f47c24434..e67fb4f63 100644 --- a/tests/brevitas/graph/test_equalization.py +++ b/tests/brevitas/graph/test_equalization.py @@ -36,7 +36,7 @@ def test_resnet18_equalization(): # Check that equalization is not introducing FP variations assert torch.allclose(expected_out, out, atol=ATOL) - regions = sorted(regions, key=lambda region: sorted([r for r in region.srcs.keys()])) + regions = sorted(regions, key=lambda region: sorted([r for r in region.srcs_names])) resnet_18_regions = sorted(RESNET_18_REGIONS, key=lambda region: region[0][0]) equalized_layers = set() for r in resnet_18_regions: @@ -45,9 +45,9 @@ def test_resnet18_equalization(): # Check that we found all the expected regions for region, expected_region in zip(regions, resnet_18_regions): - srcs = list(region.srcs) + srcs = region.srcs_names sources_check = set(srcs) == set(expected_region[0]) - sinks = list(region.sinks) + sinks = region.sinks_names sinks_check = set(sinks) == set(expected_region[1]) assert sources_check assert sinks_check From 24ad80ca8033464bca308f5c6e7a787267159b04 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Mon, 18 Dec 2023 10:42:55 +0000 Subject: [PATCH 07/16] Review --- src/brevitas/graph/equalize.py | 145 ++++++++++-------- tests/brevitas/graph/equalization_fixtures.py | 14 +- 2 files changed, 84 insertions(+), 75 deletions(-) diff --git a/src/brevitas/graph/equalize.py b/src/brevitas/graph/equalize.py index c05366669..6021ca6f9 100644 --- a/src/brevitas/graph/equalize.py +++ b/src/brevitas/graph/equalize.py @@ -86,6 +86,18 @@ _ignore_ops = (getattr, 'size') +# Start and End identify the starting and ending channels of the weight matrix that need to be +# equalized. +# Offset refers to the relative position of these channels with respect to +# the other matrices' channels that are equalized simultaneously. +# Source matrix are always fully equalized, while sinks can be partially equalized. +@dataclass +class EqualizationIndexes: + start: int = 0 + end: int = 0 + offset: int = 0 + + # Required for being hashable @dataclass(eq=True, frozen=True) class WeightBiasTuple: @@ -99,6 +111,7 @@ class Region: srcs: Dict = field(default_factory=dict) sinks: Dict = field(default_factory=dict) acts: Tuple = field(default_factory=tuple) + name_to_module: Dict = field(default_factory=dict) @property def srcs_names(self): @@ -108,6 +121,10 @@ def srcs_names(self): def sinks_names(self): return [name.split("$")[0] for name in self.sinks.keys()] + def get_module_from_name(self, name: str) -> nn.Module: + name = name.split("$")[0] + return self.name_to_module[name] + @dataclass class WalkRegionState: @@ -115,21 +132,47 @@ class WalkRegionState: sinks: Dict = field(default_factory=dict) acts: Set = field(default_factory=set) history: set = field(default_factory=set) + name_to_module: Dict = field(default_factory=dict) + add_mul_node: bool = False offset: int = 0 update_offset: bool = False + @property + def srcs_names(self): + return [name.split("$")[0] for name in self.srcs.keys()] -# Start and End identify the starting and ending channels of the weight matrix that need to be -# equalized. -# Offset refers to the relative position of these channels with respect to -# the other matrices' channels that are equalized simultaneously. -# Source matrix are always fully equalized, while sinks can be partially equalized. 
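Aside: the '$'-suffixed key convention these Region/WalkRegionState changes rely on, shown with an invented module name and suffix. Only the part before '$' is the FX module name, which is what the *_names properties and get_module_from_name recover; the exact suffix is whatever str() of the EqualizationIndexes instance produces.

import torch.nn as nn

# Invented name, module and suffix, just to show the convention.
name_to_module = {'features.3.conv': nn.Conv2d(64, 64, 3)}
full_src_name = 'features.3.conv' + '$' + '0_64_128'  # '<fx name>$<str(indexes)>'

def get_module_from_name(name):
    return name_to_module[name.split('$')[0]]

assert isinstance(get_module_from_name(full_src_name), nn.Conv2d)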
-@dataclass -class EqualizationIndexes: - start: int = 0 - end: int = 0 - offset: int = 0 + @property + def sinks_names(self): + return [name.split("$")[0] for name in self.sinks.keys()] + + def add( + self, + type: str, + name: str, + module: nn.Module, + indexes: Optional[EqualizationIndexes] = None): + if type == 'srcs' or type == 'sinks': + assert indexes is not None + full_source_name = name + '$' + str(indexes) + self.srcs[full_source_name] = indexes + getattr(self, type)[full_source_name] = indexes + else: + getattr(self, type).add(name) + self.name_to_module[name] = module + + def add_srcs(self, src_name: str, src: nn.Module, indexes: EqualizationIndexes): + self.add('srcs', src_name, src, indexes) + + def add_sinks(self, sink_name: str, sink: nn.Module, indexes: EqualizationIndexes): + self.add('srcs', sink_name, sink, indexes) + + def add_acts(self, act_name: str, act: nn.Module): + self.add('acts', act_name, act) + + def get_module_from_name(self, name: str) -> nn.Module: + name = name.split("$")[0] + return self.name_to_module[name] def __str__(self): @@ -179,23 +222,6 @@ def __exit__(self, type, value, traceback): return True # To propagate exceptions -def dict_name_to_module(model, regions): - name_to_module: Dict[str, torch.nn.Module] = {} - - name_set = set() - for region in regions: - for name in region.srcs_names: - name_set.add(name) - for name in region.sinks_names: - name_set.add(name) - for name in region.acts: - name_set.add(name) - for name, module in model.named_modules(): - if name in name_set: - name_to_module[name] = module - return name_to_module - - def _channel_range(inp: torch.Tensor, dim: int = 1) -> torch.Tensor: mins, _ = inp.min(dim=dim) maxs, _ = inp.max(dim=dim) @@ -556,13 +582,12 @@ def _update_weights(original_module, new_value, attr='weight'): setattr(original_module, attr, nn.Parameter(new_value)) -def _organize_region(region, name_to_module, type): +def _organize_region(region, type): region_dict = {} region = getattr(region, type) for i, (k, v) in enumerate(region.items()): name = type + str(i) - k = k.split('$')[0] - region_dict[name] = (name_to_module[k], v) + region_dict[name] = (region.get_module_from_name(k), v) return region_dict @@ -577,22 +602,11 @@ def _equalize( """ Generalized version of section 4.1 of https://arxiv.org/pdf/1906.04721.pdf """ - name_to_module: Dict[str, nn.Module] = {} - name_set = set() - for region in regions: - for name in region.srcs_names: - name_set.add(name) - for name in region.sinks_names: - name_set.add(name) - - for name, module in model.named_modules(): - if name in name_set: - name_to_module[name] = module for i in range(iterations): scale_factor_max = None for ii, region in enumerate(regions): - srcs_dict = _organize_region(region, name_to_module, 'srcs') - sinks_dict = _organize_region(region, name_to_module, 'sinks') + srcs_dict = _organize_region(region, 'srcs') + sinks_dict = _organize_region(region, 'sinks') scale_factors_region = _cross_layer_equalization( srcs_dict, sinks_dict, @@ -740,9 +754,9 @@ def find_srcs(graph_model: GraphModule, starting_node: Node, module = get_module(graph_model, node.target) weight = get_weight_source([module]) eq_indexes = EqualizationIndexes(0, weight[0].shape[0], state.offset) - full_source_name = node.target + '$' + str(eq_indexes) - state.srcs[full_source_name] = eq_indexes + # After we found a source, we need to check if it branches into multiple sinks + state.add_srcs(node.target, module, eq_indexes) find_sinks(graph_model, node, state) state.offset = state.offset 
if not state.update_offset else state.offset + weight[ 0].shape[0] @@ -796,8 +810,8 @@ def find_sinks(graph_model: GraphModule, starting_node: Node, if isinstance(module, (nn.LayerNorm,) + _batch_norm): state.sinks[_UNSUPPORTED_OP] = _UNSUPPORTED_OP else: - full_sink_name = node.target + '$' + str(eq_indexes) - state.sinks[full_sink_name] = eq_indexes + state.add_sinks(node.target, module, eq_indexes) + elif _is_scale_invariant_module( graph_model, node) or _is_scale_invariant_function(node) or _is_reshaping_op(node): find_sinks(graph_model, node, state) @@ -826,8 +840,11 @@ def find_sinks(graph_model: GraphModule, starting_node: Node, new_state = WalkRegionState(offset=state.offset) find_sinks(graph_model, node, new_state) - for k in new_state.sinks.keys(): - state.sinks[k] = EqualizationIndexes(start, end, new_state.offset) + for k in new_state.sinks_names: + state.add_sinks( + k, + new_state.get_module_from_name(k), + EqualizationIndexes(start, end, new_state.offset)) state.srcs.update(new_state.srcs) elif node.target in _ignore_ops: continue @@ -841,20 +858,19 @@ def _extract_regions( add_mul_node: bool = False, return_acts: bool = False) -> List[Region]: regions = list() - regions_name = set() for node in graph_model.graph.nodes: if _is_supported_module(graph_model, node) or (add_mul_node and _is_scale_varying_activation(graph_model, node)): state = WalkRegionState(add_mul_node=add_mul_node) if _is_scale_varying_activation(graph_model, node): - state.acts.add(node.target) + module = get_module(graph_model, node.target) + state.add_acts(node.target, module) else: module = get_module(graph_model, node.target) weight = get_weight_source([module]) eq_indexes = EqualizationIndexes(0, weight[0].shape[0], 0) - full_source_name = node.target + '$' + str(eq_indexes) - state.srcs[full_source_name] = eq_indexes + state.add_srcs(node.target.module, eq_indexes) find_sinks(graph_model, node, state) if state.sinks and _UNSUPPORTED_OP not in state.sinks.keys( ) and _UNSUPPORTED_OP not in state.srcs.keys(): @@ -864,9 +880,14 @@ def _extract_regions( sorted_sinks = dict(sorted(state.sinks.items())) sorted_acts = tuple(sorted(state.acts)) if return_acts: - region = Region(srcs=sorted_srcs, sinks=sorted_sinks, acts=sorted_acts) + region = Region( + srcs=sorted_srcs, + sinks=sorted_sinks, + acts=sorted_acts, + name_to_module=state.name_to_module) else: - region = Region(srcs=sorted_srcs, sinks=sorted_sinks) + region = Region( + srcs=sorted_srcs, sinks=sorted_sinks, name_to_module=state.name_to_module) if region not in regions: regions.append(region) @@ -1055,7 +1076,6 @@ def __init__( self.scale_fn = _channel_range def setup(self): - name_to_module = dict_name_to_module(self.model, self.regions) # Select only regions with activation to equalize through. 
# If a region has multiple scale varying activation, must also be dropped # because we can't propagate scaling factors @@ -1063,7 +1083,7 @@ def setup(self): for region in self.regions: # This condition is for redudancy, since # a region with two scale-varying activations cannot be detected in the first place - if len(region.acts) > 1 and any([isinstance(name_to_module[act_name], + if len(region.acts) > 1 and any([isinstance(region.get_module_from_name(act_name), _scale_varying_activations) for act_name in region.acts]): regions_to_drop.append(region) @@ -1073,11 +1093,11 @@ def setup(self): batch_dim = 0 region_to_search = region.sinks if len(region.acts) == 0 else region.acts for name in region.srcs + region.sinks: - module = name_to_module[name] + module = region.get_module_from_name(name) if hasattr(module, 'batch_first'): batch_dim = 0 if module.batch_first else 1 for name in region_to_search: - module = name_to_module[name] + module = region.get_module_from_name(name) use_inp = True if region_to_search == region.sinks else False hook_fn = partial( self.forward_stats_hook, name=name, batch_dim=batch_dim, use_inp=use_inp) @@ -1090,12 +1110,11 @@ def setup(self): def apply(self, alpha): scale_factors = [] self.remove_hooks() - name_to_module = dict_name_to_module(self.model, self.regions) for region in self.regions: region_names = region.sinks_names if len(region.acts) == 0 else region.acts if any([self.float_act_map[name] is None for name in region_names]): continue - act_module = [name_to_module[act_name] for act_name in region.acts] + act_module = [region.get_module_from_name(act_name) for act_name in region.acts] list_of_act_val = [self.float_act_map[name] for name in region_names] list_of_insert_mul_node_fn = None @@ -1110,8 +1129,8 @@ def apply(self, alpha): self.insert_mul_node, act_node=act_node, batch_dim=self.batch_dim_act_map[act_name])) - srcs_dict = _organize_region(region, name_to_module, 'srcs') - sinks_dict = _organize_region(region, name_to_module, 'sinks') + srcs_dict = _organize_region(region, 'srcs') + sinks_dict = _organize_region(region, 'sinks') scale_factors.append( _cross_layer_equalization( srcs_dict, diff --git a/tests/brevitas/graph/equalization_fixtures.py b/tests/brevitas/graph/equalization_fixtures.py index 8feeb5c6b..03812af48 100644 --- a/tests/brevitas/graph/equalization_fixtures.py +++ b/tests/brevitas/graph/equalization_fixtures.py @@ -30,21 +30,11 @@ def equalize_test(model, regions, merge_bias, bias_shrinkage, scale_computation_type): - name_to_module = {} - name_set = set() - for region in regions: - for name in region.srcs_names: - name_set.add(name) - for name in region.sinks_names: - name_set.add(name) scale_factors_regions = [] - for name, module in model.named_modules(): - if name in name_set: - name_to_module[name] = module for i in range(3): for region in regions: - srcs_dict = _organize_region(region, name_to_module, 'srcs') - sinks_dict = _organize_region(region, name_to_module, 'sinks') + srcs_dict = _organize_region(region, 'srcs') + sinks_dict = _organize_region(region, 'sinks') scale_factors_region = _cross_layer_equalization( srcs_dict, sinks_dict, From 9b812c292545157ea92dea6e3add0dad6753a226 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Mon, 18 Dec 2023 10:47:07 +0000 Subject: [PATCH 08/16] fix --- src/brevitas/graph/equalize.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/brevitas/graph/equalize.py b/src/brevitas/graph/equalize.py index 6021ca6f9..42fa77e48 100644 --- a/src/brevitas/graph/equalize.py 
+++ b/src/brevitas/graph/equalize.py @@ -870,7 +870,7 @@ def _extract_regions( module = get_module(graph_model, node.target) weight = get_weight_source([module]) eq_indexes = EqualizationIndexes(0, weight[0].shape[0], 0) - state.add_srcs(node.target.module, eq_indexes) + state.add_srcs(node.target, module, eq_indexes) find_sinks(graph_model, node, state) if state.sinks and _UNSUPPORTED_OP not in state.sinks.keys( ) and _UNSUPPORTED_OP not in state.srcs.keys(): From 5ab02731156a6b16f61f0f17de5537f0bb49c00a Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Mon, 18 Dec 2023 15:28:48 +0000 Subject: [PATCH 09/16] Review --- src/brevitas/graph/equalize.py | 123 ++++++++---------- tests/brevitas/graph/equalization_fixtures.py | 5 +- 2 files changed, 58 insertions(+), 70 deletions(-) diff --git a/src/brevitas/graph/equalize.py b/src/brevitas/graph/equalize.py index 42fa77e48..b50bab6d8 100644 --- a/src/brevitas/graph/equalize.py +++ b/src/brevitas/graph/equalize.py @@ -155,17 +155,16 @@ def add( if type == 'srcs' or type == 'sinks': assert indexes is not None full_source_name = name + '$' + str(indexes) - self.srcs[full_source_name] = indexes getattr(self, type)[full_source_name] = indexes - else: - getattr(self, type).add(name) + elif type == 'acts': + self.acts.add(name) self.name_to_module[name] = module def add_srcs(self, src_name: str, src: nn.Module, indexes: EqualizationIndexes): self.add('srcs', src_name, src, indexes) def add_sinks(self, sink_name: str, sink: nn.Module, indexes: EqualizationIndexes): - self.add('srcs', sink_name, sink, indexes) + self.add('sinks', sink_name, sink, indexes) def add_acts(self, act_name: str, act: nn.Module): self.add('acts', act_name, act) @@ -373,8 +372,7 @@ def transpose(module: torch.nn.Module, axis: int): def _cross_layer_equalization( - srcs: Dict[nn.Module, List[int]], - sinks: Dict[nn.Module, List[int]], + region: Region, merge_bias: bool, scale_computation_type: str, bias_shrinkage: Optional[Union[float, str]] = None, @@ -387,10 +385,6 @@ def _cross_layer_equalization( ranges of the second tensors' input channel """ - # Determine device and type of tensors - device = next(sinks['sinks0'][0].parameters()).device - dtype = next(sinks['sinks0'][0].parameters()).dtype - # If equalization criteria are not met, we return a scalar one to indicate that no equalization # has been performed def _no_equalize(): @@ -398,23 +392,24 @@ def _no_equalize(): act_sink_axes = {} act_sources_axes = {} - single_module = list(sinks.values())[0][0] + single_module = region.get_module_from_name(next(iter(region.sinks_names))) device = next(single_module.parameters()).device dtype = next(single_module.parameters()).dtype max_shape_srcs = 0 - for name, (k, v) in srcs.items(): - max_shape_srcs = max(max_shape_srcs, v.end + v.offset) + for name, indexes in region.srcs.items(): + max_shape_srcs = max(max_shape_srcs, indexes.end + indexes.offset) max_shape_sinks = 0 - for name, (k, v) in sinks.items(): - max_shape_sinks = max(max_shape_sinks, v.offset + (v.end - v.start)) + for name, indexes in region.sinks.items(): + max_shape_sinks = max(max_shape_sinks, indexes.offset + (indexes.end - indexes.start)) # Exit if source and sink have different sizes - if max_shape_srcs != max_shape_sinks and len(srcs) > 0: + if max_shape_srcs != max_shape_sinks and len(region.srcs) > 0: return _no_equalize() src_axes = {} - for i, (name, (module, indexes)) in enumerate(srcs.items()): + for name, indexes in region.srcs.items(): + module = region.get_module_from_name(name) # If module is not 
supported, do not perform graph equalization axis = _get_output_axis(module) act_sources_axes[name] = _get_act_axis(module) @@ -422,11 +417,11 @@ def _no_equalize(): return _no_equalize() if isinstance(module, nn.MultiheadAttention): module = module.out_proj - srcs[name] = (module, indexes) src_axes[name] = (module, axis) sink_axes = {} - for i, (name, (module, indexes)) in enumerate(sinks.items()): + for name, indexes in region.sinks.items(): + module = region.get_module_from_name(name) axis = _get_input_axis(module) act_sink_axes[name] = _get_act_axis(module) # If module is not supported, do not perform graph equalization @@ -436,7 +431,6 @@ def _no_equalize(): if isinstance(module, nn.MultiheadAttention) and module.in_proj_weight is not None: # For sinks, we only need to modify the weight but not the bias module = WeightBiasTuple(module.in_proj_weight) - sinks[name] = (module, indexes) elif isinstance(module, nn.MultiheadAttention) and module.in_proj_weight is None: return _no_equalize() sink_axes[name] = (module, axis) @@ -469,7 +463,7 @@ def _no_equalize(): for k, v in sink_weights.items(): # Sinks can be partially equalized, thus we need to select # only the channels we are interested in - indexes = sinks[k][1] + indexes = region.sinks[k] # Compute the range of the channels we need to equalize weight_range = scale_fn(v.reshape(v.size(0), -1))[indexes.start:indexes.end] # Compute the numbers of channels we are equalizing @@ -501,7 +495,7 @@ def _no_equalize(): for k, v in src_weights.items(): # Srcs are always fully equalized, thus we simply need to apply the offset to position them # correctly with respect to the other srcs matrices. - indexes = srcs[k][1] + indexes = region.srcs[k] channel_start = indexes.offset + indexes.start channel_end = indexes.offset + indexes.end weight_range = scale_fn(v.reshape(v.size(0), -1)) @@ -533,7 +527,7 @@ def _no_equalize(): insert_mul_node_fn(inverse_scaling_factors, act_val_shape, act_axis) if len(src_axes) > 0: for name, (module, axis) in src_axes.items(): - indexes = srcs[name][1] + indexes = region.srcs[name] channel_start = indexes.offset + indexes.start channel_end = indexes.offset + indexes.end if hasattr(module, 'bias') and module.bias is not None: @@ -550,10 +544,10 @@ def _no_equalize(): module.weight.clone() * torch.reshape( inverse_scaling_factors[channel_start:channel_end], src_broadcast_size), attr='weight') - for module, axis in sink_axes.items(): + for name, (module, axis) in sink_axes.items(): sink_broadcast_size = [1] * module.weight.ndim sink_broadcast_size[axis] = module.weight.size(axis) - indexes = sinks[name][1] + indexes = region.sinks[name] channel_range = indexes.end - indexes.start partial_scaling = torch.ones(module.weight.size(axis), device=device, dtype=dtype) # We replace the scaling factors of the channels we need to equalize, leaving the other to @@ -582,15 +576,6 @@ def _update_weights(original_module, new_value, attr='weight'): setattr(original_module, attr, nn.Parameter(new_value)) -def _organize_region(region, type): - region_dict = {} - region = getattr(region, type) - for i, (k, v) in enumerate(region.items()): - name = type + str(i) - region_dict[name] = (region.get_module_from_name(k), v) - return region_dict - - def _equalize( model: GraphModule, regions: Set[Tuple[str]], @@ -605,11 +590,8 @@ def _equalize( for i in range(iterations): scale_factor_max = None for ii, region in enumerate(regions): - srcs_dict = _organize_region(region, 'srcs') - sinks_dict = _organize_region(region, 'sinks') 
scale_factors_region = _cross_layer_equalization( - srcs_dict, - sinks_dict, + region, merge_bias=merge_bias, bias_shrinkage=bias_shrinkage, scale_computation_type=scale_computation_type) @@ -996,7 +978,7 @@ def __init__(self, model, scale_computation_type: str = 'maxabs'): self.hooks = [] self.add_mul_node = True - regions = [] + regions: List[Region] = [] self.find_module(model, regions) self.regions = regions @@ -1014,39 +996,42 @@ def find_module(self, model, regions: List): _supported_layers) and not isinstance(model, _batch_norm + (nn.LayerNorm,)): weight = get_weight_sink([model]) eq_indexes = EqualizationIndexes(0, weight[0].shape[0], 0) - regions.append((model, eq_indexes)) + region = Region(sinks={'sinks0': eq_indexes}, name_to_module={'sinks0': model}) + regions.append(region) else: for module in model.children(): self.find_module(module, regions) def setup(self): for region in self.regions: + module = region.get_module_from_name('sinks0') batch_dim = 0 if hasattr(region, 'batch_first'): batch_dim = 0 if region.batch_first else 1 hook_fn = partial( - self.forward_stats_hook, name=region[0], batch_dim=batch_dim, use_inp=True) - new_instance = KwargsForwardHook(region[0], hook_fn) - ModuleInstanceToModuleInstance(region[0], new_instance).apply(self.model) + self.forward_stats_hook, name=module, batch_dim=batch_dim, use_inp=True) + new_instance = KwargsForwardHook(module, hook_fn) + ModuleInstanceToModuleInstance(module, new_instance).apply(self.model) self.hooks.append(new_instance) def apply(self, alpha): scale_factors = [] self.remove_hooks() for region in self.regions: - if self.float_act_map[region[0]] == None: + module = region.get_module_from_name('sinks0') + if self.float_act_map[module] == None: continue - sinks = region insert_mul_fn = partial( - self.insert_mul_node, region=region[0], batch_dim=self.batch_dim_act_map[region[0]]) + self.insert_mul_node, region=module, batch_dim=self.batch_dim_act_map[module]) scale_factors.append( - _cross_layer_equalization({}, {'sinks0': sinks}, - False, - scale_computation_type=self.scale_computation_type, - list_of_act_val=[self.float_act_map[region[0]]], - list_of_insert_mul_node_fn=[insert_mul_fn], - alpha=alpha)) + _cross_layer_equalization( + region, + False, + scale_computation_type=self.scale_computation_type, + list_of_act_val=[self.float_act_map[module]], + list_of_insert_mul_node_fn=[insert_mul_fn], + alpha=alpha)) return scale_factors def insert_mul_node(self, scale, shape, axis, region, batch_dim=0): @@ -1067,6 +1052,7 @@ def __init__( self.float_act_map = {} self.batch_dim_act_map = {} self.hooks = [] + self.hooked_modules = set() self.add_mul_node = add_mul_node self.regions = _extract_regions(model, add_mul_node=add_mul_node, return_acts=True) @@ -1091,19 +1077,26 @@ def setup(self): # We assume that the entire region has a unique batch_dim batch_dim = 0 - region_to_search = region.sinks if len(region.acts) == 0 else region.acts - for name in region.srcs + region.sinks: + for name in region.srcs: + module = region.get_module_from_name(name) + if hasattr(module, 'batch_first') and not module.batch_first: + batch_dim = 1 + for name in region.sinks: module = region.get_module_from_name(name) - if hasattr(module, 'batch_first'): - batch_dim = 0 if module.batch_first else 1 + if hasattr(module, 'batch_first') and not module.batch_first: + batch_dim = 1 + + region_to_search = region.sinks_names if len(region.acts) == 0 else region.acts for name in region_to_search: module = region.get_module_from_name(name) - use_inp = True if 
region_to_search == region.sinks else False - hook_fn = partial( - self.forward_stats_hook, name=name, batch_dim=batch_dim, use_inp=use_inp) - new_instance = KwargsForwardHook(module, hook_fn) - ModuleInstanceToModuleInstance(module, new_instance).apply(self.model) - self.hooks.append(new_instance) + if module not in self.hooked_modules: + self.hooked_modules.add(module) + use_inp = True if region_to_search == region.sinks_names else False + hook_fn = partial( + self.forward_stats_hook, name=name, batch_dim=batch_dim, use_inp=use_inp) + new_instance = KwargsForwardHook(module, hook_fn) + ModuleInstanceToModuleInstance(module, new_instance).apply(self.model) + self.hooks.append(new_instance) self.regions = [x for x in self.regions if x not in regions_to_drop] @@ -1129,12 +1122,10 @@ def apply(self, alpha): self.insert_mul_node, act_node=act_node, batch_dim=self.batch_dim_act_map[act_name])) - srcs_dict = _organize_region(region, 'srcs') - sinks_dict = _organize_region(region, 'sinks') + scale_factors.append( _cross_layer_equalization( - srcs_dict, - sinks_dict, + region, False, scale_computation_type=self.scale_computation_type, list_of_act_val=list_of_act_val, diff --git a/tests/brevitas/graph/equalization_fixtures.py b/tests/brevitas/graph/equalization_fixtures.py index 03812af48..fde3a60d9 100644 --- a/tests/brevitas/graph/equalization_fixtures.py +++ b/tests/brevitas/graph/equalization_fixtures.py @@ -33,11 +33,8 @@ def equalize_test(model, regions, merge_bias, bias_shrinkage, scale_computation_ scale_factors_regions = [] for i in range(3): for region in regions: - srcs_dict = _organize_region(region, 'srcs') - sinks_dict = _organize_region(region, 'sinks') scale_factors_region = _cross_layer_equalization( - srcs_dict, - sinks_dict, + region, merge_bias=merge_bias, bias_shrinkage=bias_shrinkage, scale_computation_type=scale_computation_type) From 0c594ea557e04f0d68156d177d952dfac6d033b2 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Mon, 18 Dec 2023 15:32:48 +0000 Subject: [PATCH 10/16] Removed unused flag --- src/brevitas/graph/quantize.py | 1 - src/brevitas/graph/target/flexml.py | 2 -- 2 files changed, 3 deletions(-) diff --git a/src/brevitas/graph/quantize.py b/src/brevitas/graph/quantize.py index 458be7029..63143c4e5 100644 --- a/src/brevitas/graph/quantize.py +++ b/src/brevitas/graph/quantize.py @@ -262,7 +262,6 @@ def preprocess_for_quantize( equalize_iters=0, equalize_merge_bias=True, merge_bn=True, - include_cat=True, equalize_bias_shrinkage: str = 'vaiq', equalize_scale_computation: str = 'maxabs'): diff --git a/src/brevitas/graph/target/flexml.py b/src/brevitas/graph/target/flexml.py index f6817c196..9aedd337c 100644 --- a/src/brevitas/graph/target/flexml.py +++ b/src/brevitas/graph/target/flexml.py @@ -126,7 +126,6 @@ def preprocess_for_flexml_quantize( merge_bn=True, equalize_bias_shrinkage='vaiq', equalize_scale_computation='maxabs', - include_cat=True, **model_kwargs): training_state = model.training model.eval() @@ -142,7 +141,6 @@ def preprocess_for_flexml_quantize( equalize_iters, equalize_merge_bias, merge_bn, - include_cat, equalize_bias_shrinkage, equalize_scale_computation) model.train(training_state) From a70516d35226c5b7fd0c212b27d44970a13f9810 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Mon, 18 Dec 2023 15:55:22 +0000 Subject: [PATCH 11/16] Correct exception handling for cat --- src/brevitas/graph/equalize.py | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/src/brevitas/graph/equalize.py 
b/src/brevitas/graph/equalize.py index b50bab6d8..f4d1b412f 100644 --- a/src/brevitas/graph/equalize.py +++ b/src/brevitas/graph/equalize.py @@ -674,8 +674,14 @@ def find_srcs_channel_dim(model, inp_node): channel = weight[0].shape[0] return channel elif _is_add(inp_node): - # If it's add, we need the channel shape of one of the branches, since they are all the same - return find_srcs_channel_dim(model, inp_node.all_input_nodes[0]) + all_channels = [] + for n in inp_node.all_input_nodes: + all_channels.append(find_srcs_channel_dim(model, n)) + # All branches to add should have the same amount of channels + if all([channel == all_channels[0] for channel in all_channels]): + return all_channels[0] + else: + return _UNSUPPORTED_OP elif _is_cat(inp_node): total_channels = 0 # If it's cat, we need to sum the channel shape of all the branches @@ -816,7 +822,11 @@ def find_sinks(graph_model: GraphModule, starting_node: Node, index = node.all_input_nodes.index(starting_node) channels = [] for n in node.all_input_nodes: - channels.append(find_srcs_channel_dim(graph_model, n)) + channel_dim = find_srcs_channel_dim(graph_model, n) + if channel_dim is _UNSUPPORTED_OP: + state.sinks[_UNSUPPORTED_OP] = _UNSUPPORTED_OP + continue + channels.append(channel_dim) start = sum(channels[:index]) end = start + channels[index] new_state = WalkRegionState(offset=state.offset) From e2c277d2214ef8f98218ca4664a5129d1baa03c3 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Mon, 18 Dec 2023 16:37:38 +0000 Subject: [PATCH 12/16] Fix unsupported --- src/brevitas/graph/equalize.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/brevitas/graph/equalize.py b/src/brevitas/graph/equalize.py index f4d1b412f..a5898b811 100644 --- a/src/brevitas/graph/equalize.py +++ b/src/brevitas/graph/equalize.py @@ -144,7 +144,7 @@ def srcs_names(self): @property def sinks_names(self): - return [name.split("$")[0] for name in self.sinks.keys()] + return [name.split("$")[0] for name in self.sinks.keys() if name is not _UNSUPPORTED_OP] def add( self, From dabb8c5f74ea670c91683f5f87587e6f064e461b Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Wed, 20 Dec 2023 15:59:40 +0000 Subject: [PATCH 13/16] Fix tests --- tests/brevitas/graph/equalization_fixtures.py | 3 +-- tests/brevitas/graph/test_equalization.py | 14 +++----------- 2 files changed, 4 insertions(+), 13 deletions(-) diff --git a/tests/brevitas/graph/equalization_fixtures.py b/tests/brevitas/graph/equalization_fixtures.py index fde3a60d9..263543a82 100644 --- a/tests/brevitas/graph/equalization_fixtures.py +++ b/tests/brevitas/graph/equalization_fixtures.py @@ -11,7 +11,6 @@ from brevitas import torch_version from brevitas.graph.equalize import _cross_layer_equalization -from brevitas.graph.equalize import _organize_region SEED = 123456 ATOL = 1e-3 @@ -29,7 +28,7 @@ IN_SIZE_LINEAR = (1, 224, 3) -def equalize_test(model, regions, merge_bias, bias_shrinkage, scale_computation_type): +def equalize_test(regions, merge_bias, bias_shrinkage, scale_computation_type): scale_factors_regions = [] for i in range(3): for region in regions: diff --git a/tests/brevitas/graph/test_equalization.py b/tests/brevitas/graph/test_equalization.py index e67fb4f63..4f713211c 100644 --- a/tests/brevitas/graph/test_equalization.py +++ b/tests/brevitas/graph/test_equalization.py @@ -30,7 +30,7 @@ def test_resnet18_equalization(): model_orig = copy.deepcopy(model) regions = _extract_regions(model) _ = equalize_test( - model, regions, merge_bias=True, bias_shrinkage='vaiq', 
scale_computation_type='maxabs') + regions, merge_bias=True, bias_shrinkage='vaiq', scale_computation_type='maxabs') out = model(inp) # Check that equalization is not introducing FP variations @@ -75,11 +75,7 @@ def test_equalization_torchvision_models(model_coverage: tuple, merge_bias: bool regions = _extract_regions(model) scale_factor_regions = equalize_test( - model, - regions, - merge_bias=merge_bias, - bias_shrinkage='vaiq', - scale_computation_type='maxabs') + regions, merge_bias=merge_bias, bias_shrinkage='vaiq', scale_computation_type='maxabs') shape_scale_regions = [scale.shape for scale in scale_factor_regions] out = model(inp) @@ -132,11 +128,7 @@ def test_models(toy_model, merge_bias, request): model = symbolic_trace(model) regions = _extract_regions(model) scale_factor_regions = equalize_test( - model, - regions, - merge_bias=merge_bias, - bias_shrinkage='vaiq', - scale_computation_type='maxabs') + regions, merge_bias=merge_bias, bias_shrinkage='vaiq', scale_computation_type='maxabs') shape_scale_regions = [scale.shape for scale in scale_factor_regions] with torch.no_grad(): From 0f67b6b899675a08b9fa9676df2bd76ca227dc04 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Wed, 20 Dec 2023 20:59:50 +0000 Subject: [PATCH 14/16] Review --- src/brevitas/graph/equalize.py | 33 ++++++++++----------------------- 1 file changed, 10 insertions(+), 23 deletions(-) diff --git a/src/brevitas/graph/equalize.py b/src/brevitas/graph/equalize.py index a5898b811..0c0ae75c3 100644 --- a/src/brevitas/graph/equalize.py +++ b/src/brevitas/graph/equalize.py @@ -134,7 +134,7 @@ class WalkRegionState: history: set = field(default_factory=set) name_to_module: Dict = field(default_factory=dict) - add_mul_node: bool = False + cat_encoutered: bool = False offset: int = 0 update_offset: bool = False @@ -471,7 +471,6 @@ def _no_equalize(): # Use the offset and the range to update the correct range in the sinks sinks_range[indexes.offset:indexes.offset + channel_range] = torch.max( sinks_range[indexes.offset:indexes.offset + channel_range], weight_range) - sinks_range = torch.clamp(sinks_range, EPSILON) # Determine the srcs_range based on where we are performing activation equalization or # weight equalization @@ -589,7 +588,7 @@ def _equalize( """ for i in range(iterations): scale_factor_max = None - for ii, region in enumerate(regions): + for region in regions: scale_factors_region = _cross_layer_equalization( region, merge_bias=merge_bias, @@ -699,7 +698,8 @@ def cat_handler(graph_model: GraphModule, starting_node: Node, state: WalkRegion state.srcs.clear() state.sinks.clear() state.history.clear() - state.srcs[starting_node.target] = _UNSUPPORTED_OP + # Keep track that concatenation has been encoutered once + state.cat_encoutered = True state.update_offset = True state.offset = 0 find_srcs(graph_model, starting_node, state) @@ -712,13 +712,6 @@ def _is_cat(node): return node.target in (torch.cat,) -def _is_cat_in_srcs(srcs): - out = False - for src in srcs: - out = out or src in (torch.cat,) - return out - - def _is_add(node): return ( node.op == 'call_method' and node.target in _residual_methods or @@ -733,7 +726,6 @@ def find_srcs(graph_model: GraphModule, starting_node: Node, # we keep a history of how the graph has been walked already, invariant to the direction, # to avoid getting stuck in a loop path = (node, starting_node) - module = None if path not in state.history: state.history.add(path) else: @@ -759,11 +751,10 @@ def find_srcs(graph_model: GraphModule, starting_node: Node, 
find_srcs(graph_model, node, state) state.update_offset = update_offset_state elif _is_cat(node): - # We have never encoutered cat - if not _is_cat_in_srcs(state.srcs): + # The first time we encoutered a cat differes from all subsequent ones + if not state.cat_encoutered: # We restart the region search starting from the cat cat_handler(graph_model, node, state) - # We have encoutered cat already else: state.update_offset = False find_sinks(graph_model, node, state) @@ -810,11 +801,10 @@ def find_sinks(graph_model: GraphModule, starting_node: Node, find_srcs(graph_model, node, state) state.update_offset = update_offset_state elif _is_cat(node): - # We have never encoutered cat - if not _is_cat_in_srcs(state.srcs): + # The first time we encoutered a cat differes from all subsequent ones + if not state.cat_encoutered: # We restart the region search starting from the cat cat_handler(graph_model, node, state) - # We have encoutered cat already else: # In this case we define all our sinks, and isolate only the channels we want # to equalize (start, end). @@ -854,7 +844,7 @@ def _extract_regions( if _is_supported_module(graph_model, node) or (add_mul_node and _is_scale_varying_activation(graph_model, node)): - state = WalkRegionState(add_mul_node=add_mul_node) + state = WalkRegionState() if _is_scale_varying_activation(graph_model, node): module = get_module(graph_model, node.target) state.add_acts(node.target, module) @@ -864,10 +854,7 @@ def _extract_regions( eq_indexes = EqualizationIndexes(0, weight[0].shape[0], 0) state.add_srcs(node.target, module, eq_indexes) find_sinks(graph_model, node, state) - if state.sinks and _UNSUPPORTED_OP not in state.sinks.keys( - ) and _UNSUPPORTED_OP not in state.srcs.keys(): - # Drop cat from the srcs - state.srcs = {k: v for k, v in state.srcs.items() if k is not torch.cat} + if len(state.sinks) > 0 and _UNSUPPORTED_OP not in state.sinks.keys(): sorted_srcs = dict(sorted(state.srcs.items())) sorted_sinks = dict(sorted(state.sinks.items())) sorted_acts = tuple(sorted(state.acts)) From 704772e2e15ab8f8fad45c6c05d4c43ec2063dd8 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Thu, 21 Dec 2023 10:23:06 +0000 Subject: [PATCH 15/16] Last review --- src/brevitas/graph/equalize.py | 59 ++++++++++++++++------------------ 1 file changed, 27 insertions(+), 32 deletions(-) diff --git a/src/brevitas/graph/equalize.py b/src/brevitas/graph/equalize.py index 0c0ae75c3..29d306536 100644 --- a/src/brevitas/graph/equalize.py +++ b/src/brevitas/graph/equalize.py @@ -638,29 +638,24 @@ def _is_reshaping_op(node: Node) -> bool: return node.target in _reshaping_op -def get_weight_source(module_list): - transpose = lambda module, axis: module.weight if axis == 0 else module.weight.transpose(0, 1) - for i, module in enumerate(module_list): - if isinstance(module, nn.MultiheadAttention): - if hasattr(module, 'out_proj'): - module_list[i] = module.out_proj - else: - raise RuntimeError("Configuration for Multiheadattention not supported") - srcs_axes = {module: _get_output_axis(module) for module in module_list} - weight = [transpose(m, axis) for m, axis in srcs_axes.items()] +def get_weight_source(module): + transpose = lambda weight, axis: weight if axis == 0 else weight.transpose(0, 1) + if isinstance(module, nn.MultiheadAttention) and not hasattr(module, 'out_proj'): + raise RuntimeError("Configuration for Multiheadattention not supported") + weight = module.out_proj.weight if isinstance(module, nn.MultiheadAttention) else module.weight + axis = _get_output_axis(module) + weight 
= transpose(weight, axis) return weight -def get_weight_sink(module_list): - transpose = lambda module, axis: module.weight if axis == 0 else module.weight.transpose(0, 1) - for i, module in enumerate(module_list): - if isinstance(module, nn.MultiheadAttention): - if hasattr(module, 'in_proj_weight'): - module_list[i] = WeightBiasTuple(module.in_proj_weight) - else: - raise RuntimeError("Configuration for Multiheadattention not supported") - sinks_axes = {module: _get_input_axis(module) for module in module_list} - weight = [transpose(m, axis) for m, axis in sinks_axes.items()] +def get_weight_sink(module): + transpose = lambda weight, axis: weight if axis == 0 else weight.transpose(0, 1) + if isinstance(module, nn.MultiheadAttention) and not hasattr(module, 'in_proj_weight'): + raise RuntimeError("Configuration for Multiheadattention not supported") + weight = WeightBiasTuple(module.in_proj_weight).weight if isinstance( + module, nn.MultiheadAttention) else module.weight + axis = _get_input_axis(module) + weight = transpose(weight, axis) return weight @@ -669,8 +664,8 @@ def find_srcs_channel_dim(model, inp_node): # If we meet a supported module, determine the channel shape module = get_module(model, inp_node.target) # Since we are walking up, we consider the module as srcs - weight = get_weight_source([module]) - channel = weight[0].shape[0] + weight = get_weight_source(module) + channel = weight.shape[0] return channel elif _is_add(inp_node): all_channels = [] @@ -732,14 +727,14 @@ def find_srcs(graph_model: GraphModule, starting_node: Node, continue if _is_supported_module(graph_model, node): module = get_module(graph_model, node.target) - weight = get_weight_source([module]) - eq_indexes = EqualizationIndexes(0, weight[0].shape[0], state.offset) + weight = get_weight_source(module) + eq_indexes = EqualizationIndexes(0, weight.shape[0], state.offset) # After we found a source, we need to check if it branches into multiple sinks state.add_srcs(node.target, module, eq_indexes) find_sinks(graph_model, node, state) - state.offset = state.offset if not state.update_offset else state.offset + weight[ - 0].shape[0] + state.offset = state.offset if not state.update_offset else state.offset + weight.shape[ + 0] elif _is_scale_invariant_module( graph_model, node) or _is_scale_invariant_function(node) or _is_reshaping_op(node): find_sinks(graph_model, node, state) @@ -783,8 +778,8 @@ def find_sinks(graph_model: GraphModule, starting_node: Node, continue if _is_supported_module(graph_model, node): module = get_module(graph_model, node.target) - weight = get_weight_sink([module]) - eq_indexes = EqualizationIndexes(0, weight[0].shape[0], state.offset) + weight = get_weight_sink(module) + eq_indexes = EqualizationIndexes(0, weight.shape[0], state.offset) # It is not possible to equalize through LayerNorm as sink if isinstance(module, (nn.LayerNorm,) + _batch_norm): state.sinks[_UNSUPPORTED_OP] = _UNSUPPORTED_OP @@ -850,8 +845,8 @@ def _extract_regions( state.add_acts(node.target, module) else: module = get_module(graph_model, node.target) - weight = get_weight_source([module]) - eq_indexes = EqualizationIndexes(0, weight[0].shape[0], 0) + weight = get_weight_source(module) + eq_indexes = EqualizationIndexes(0, weight.shape[0], 0) state.add_srcs(node.target, module, eq_indexes) find_sinks(graph_model, node, state) if len(state.sinks) > 0 and _UNSUPPORTED_OP not in state.sinks.keys(): @@ -991,8 +986,8 @@ def find_module(self, model, regions: List): """ if isinstance(model, _supported_layers) and not 
isinstance(model, _batch_norm + (nn.LayerNorm,)): - weight = get_weight_sink([model]) - eq_indexes = EqualizationIndexes(0, weight[0].shape[0], 0) + weight = get_weight_sink(model) + eq_indexes = EqualizationIndexes(0, weight.shape[0], 0) region = Region(sinks={'sinks0': eq_indexes}, name_to_module={'sinks0': model}) regions.append(region) else: From a42908480228a78ba5975e54f183fbbe6315130e Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Thu, 21 Dec 2023 13:28:59 +0000 Subject: [PATCH 16/16] remove batchnorm support --- src/brevitas/graph/equalize.py | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/src/brevitas/graph/equalize.py b/src/brevitas/graph/equalize.py index 29d306536..fa455cbac 100644 --- a/src/brevitas/graph/equalize.py +++ b/src/brevitas/graph/equalize.py @@ -425,7 +425,7 @@ def _no_equalize(): axis = _get_input_axis(module) act_sink_axes[name] = _get_act_axis(module) # If module is not supported, do not perform graph equalization - if not isinstance(module, _supported_layers): + if not isinstance(module, _supported_layers) or module in _batch_norm: return _no_equalize() # For MultiheadAttention, we support only self-attetion if isinstance(module, nn.MultiheadAttention) and module.in_proj_weight is not None: @@ -553,13 +553,6 @@ def _no_equalize(): # one (i.e., no equalization) partial_scaling[indexes.start:indexes.end] = scaling_factors[indexes.offset:indexes.offset + channel_range] - if isinstance(module, _batch_norm): - # We re-compute the bias as function of running_mean and running_var to adjust the - # additive factor for equalization. - additive_factor = module.running_mean.data * module.weight.data / torch.sqrt( - module.running_var.data + module.eps) - _update_weights( - module, module.bias.clone() + additive_factor * (partial_scaling - 1), attr='bias') _update_weights( module, module.weight.clone() * torch.reshape(partial_scaling, sink_broadcast_size),
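
Note on the mechanism introduced by this series: regions now carry EqualizationIndexes (start, end, offset), so that sources feeding a torch.cat can each be assigned an offset in a shared channel space while sinks may be equalized over only a slice of their input channels. The sketch below is illustrative only and is not the Brevitas API: EqIdx, maxabs_per_channel, and toy_region_scales are made-up names, and the scale computation shown corresponds roughly to the alpha=0.5 / 'maxabs' case of _cross_layer_equalization, ignoring biases, MultiheadAttention handling, and activation equalization.

import torch

# Illustrative stand-in for EqualizationIndexes: start/end select a channel slice,
# offset places that slice in the region's shared channel space.
class EqIdx:
    def __init__(self, start: int, end: int, offset: int):
        self.start, self.end, self.offset = start, end, offset

def maxabs_per_channel(weight: torch.Tensor, dim: int) -> torch.Tensor:
    # Per-channel max-abs range along dim (0 = output channels, 1 = input channels).
    w = weight if dim == 0 else weight.transpose(0, 1)
    return w.reshape(w.size(0), -1).abs().amax(dim=1)

def toy_region_scales(srcs, sinks, eps: float = 1e-9) -> torch.Tensor:
    # Sources are fully equalized and placed at their offset; sinks contribute
    # only their [start, end) input-channel slice, placed at their offset.
    total = max(idx.offset + (idx.end - idx.start) for _, idx in sinks)
    srcs_range = torch.zeros(total)
    sinks_range = torch.zeros(total)
    for weight, idx in srcs:
        r = maxabs_per_channel(weight, dim=0)
        srcs_range[idx.offset:idx.offset + r.numel()] = torch.max(
            srcs_range[idx.offset:idx.offset + r.numel()], r)
    for weight, idx in sinks:
        r = maxabs_per_channel(weight, dim=1)[idx.start:idx.end]
        n = idx.end - idx.start
        sinks_range[idx.offset:idx.offset + n] = torch.max(
            sinks_range[idx.offset:idx.offset + n], r)
    # alpha = 0.5 weighting between source and sink ranges.
    return torch.sqrt(srcs_range.clamp_min(eps) / sinks_range.clamp_min(eps))

# Two linear sources whose outputs are concatenated and fed to one linear sink:
# src_a owns channels 0-3 of the shared space, src_b owns channels 4-9.
torch.manual_seed(0)
src_a = torch.randn(4, 8)
src_b = torch.randn(6, 8)
sink = torch.randn(5, 10)
scales = toy_region_scales(
    srcs=[(src_a, EqIdx(0, 4, 0)), (src_b, EqIdx(0, 6, 4))],
    sinks=[(sink, EqIdx(0, 10, 0))])
print(scales.shape)  # torch.Size([10]) -> one scaling factor per shared channel

Applying the result means dividing each source's output channels by its slice of the scale vector and multiplying the sink's input channels by the same slice, which leaves the float network function unchanged; the hypothetical helper above only illustrates how the offsets line up concatenated source channels against the sink's input channels.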