From 5ab02731156a6b16f61f0f17de5537f0bb49c00a Mon Sep 17 00:00:00 2001
From: Giuseppe Franco
Date: Mon, 18 Dec 2023 15:28:48 +0000
Subject: [PATCH] Review

---
 src/brevitas/graph/equalize.py                | 123 ++++++++----------
 tests/brevitas/graph/equalization_fixtures.py |   5 +-
 2 files changed, 58 insertions(+), 70 deletions(-)

diff --git a/src/brevitas/graph/equalize.py b/src/brevitas/graph/equalize.py
index 42fa77e48..b50bab6d8 100644
--- a/src/brevitas/graph/equalize.py
+++ b/src/brevitas/graph/equalize.py
@@ -155,17 +155,16 @@ def add(
         if type == 'srcs' or type == 'sinks':
             assert indexes is not None
             full_source_name = name + '$' + str(indexes)
-            self.srcs[full_source_name] = indexes
             getattr(self, type)[full_source_name] = indexes
-        else:
-            getattr(self, type).add(name)
+        elif type == 'acts':
+            self.acts.add(name)
         self.name_to_module[name] = module

     def add_srcs(self, src_name: str, src: nn.Module, indexes: EqualizationIndexes):
         self.add('srcs', src_name, src, indexes)

     def add_sinks(self, sink_name: str, sink: nn.Module, indexes: EqualizationIndexes):
-        self.add('srcs', sink_name, sink, indexes)
+        self.add('sinks', sink_name, sink, indexes)

     def add_acts(self, act_name: str, act: nn.Module):
         self.add('acts', act_name, act)
@@ -373,8 +372,7 @@ def transpose(module: torch.nn.Module, axis: int):


 def _cross_layer_equalization(
-        srcs: Dict[nn.Module, List[int]],
-        sinks: Dict[nn.Module, List[int]],
+        region: Region,
         merge_bias: bool,
         scale_computation_type: str,
         bias_shrinkage: Optional[Union[float, str]] = None,
@@ -387,10 +385,6 @@
     ranges of the second tensors' input channel
     """

-    # Determine device and type of tensors
-    device = next(sinks['sinks0'][0].parameters()).device
-    dtype = next(sinks['sinks0'][0].parameters()).dtype
-
     # If equalization criteria are not met, we return a scalar one to indicate that no equalization
     # has been performed
     def _no_equalize():
@@ -398,23 +392,24 @@

     act_sink_axes = {}
     act_sources_axes = {}
-    single_module = list(sinks.values())[0][0]
+    single_module = region.get_module_from_name(next(iter(region.sinks_names)))
     device = next(single_module.parameters()).device
     dtype = next(single_module.parameters()).dtype

     max_shape_srcs = 0
-    for name, (k, v) in srcs.items():
-        max_shape_srcs = max(max_shape_srcs, v.end + v.offset)
+    for name, indexes in region.srcs.items():
+        max_shape_srcs = max(max_shape_srcs, indexes.end + indexes.offset)
     max_shape_sinks = 0
-    for name, (k, v) in sinks.items():
-        max_shape_sinks = max(max_shape_sinks, v.offset + (v.end - v.start))
+    for name, indexes in region.sinks.items():
+        max_shape_sinks = max(max_shape_sinks, indexes.offset + (indexes.end - indexes.start))

     # Exit if source and sink have different sizes
-    if max_shape_srcs != max_shape_sinks and len(srcs) > 0:
+    if max_shape_srcs != max_shape_sinks and len(region.srcs) > 0:
         return _no_equalize()

     src_axes = {}
-    for i, (name, (module, indexes)) in enumerate(srcs.items()):
+    for name, indexes in region.srcs.items():
+        module = region.get_module_from_name(name)
         # If module is not supported, do not perform graph equalization
         axis = _get_output_axis(module)
         act_sources_axes[name] = _get_act_axis(module)
@@ -422,11 +417,11 @@
             return _no_equalize()
         if isinstance(module, nn.MultiheadAttention):
             module = module.out_proj
-        srcs[name] = (module, indexes)
         src_axes[name] = (module, axis)

     sink_axes = {}
-    for i, (name, (module, indexes)) in enumerate(sinks.items()):
+    for name, indexes in region.sinks.items():
+        module = region.get_module_from_name(name)
         axis = _get_input_axis(module)
         act_sink_axes[name] = _get_act_axis(module)
         # If module is not supported, do not perform graph equalization
@@ -436,7 +431,6 @@
         if isinstance(module, nn.MultiheadAttention) and module.in_proj_weight is not None:
             # For sinks, we only need to modify the weight but not the bias
             module = WeightBiasTuple(module.in_proj_weight)
-            sinks[name] = (module, indexes)
         elif isinstance(module, nn.MultiheadAttention) and module.in_proj_weight is None:
             return _no_equalize()
         sink_axes[name] = (module, axis)
@@ -469,7 +463,7 @@
     for k, v in sink_weights.items():
         # Sinks can be partially equalized, thus we need to select
         # only the channels we are interested in
-        indexes = sinks[k][1]
+        indexes = region.sinks[k]
         # Compute the range of the channels we need to equalize
         weight_range = scale_fn(v.reshape(v.size(0), -1))[indexes.start:indexes.end]
         # Compute the numbers of channels we are equalizing
@@ -501,7 +495,7 @@
     for k, v in src_weights.items():
         # Srcs are always fully equalized, thus we simply need to apply the offset to position them
         # correctly with respect to the other srcs matrices.
-        indexes = srcs[k][1]
+        indexes = region.srcs[k]
         channel_start = indexes.offset + indexes.start
         channel_end = indexes.offset + indexes.end
         weight_range = scale_fn(v.reshape(v.size(0), -1))
@@ -533,7 +527,7 @@
             insert_mul_node_fn(inverse_scaling_factors, act_val_shape, act_axis)
     if len(src_axes) > 0:
         for name, (module, axis) in src_axes.items():
-            indexes = srcs[name][1]
+            indexes = region.srcs[name]
             channel_start = indexes.offset + indexes.start
             channel_end = indexes.offset + indexes.end
             if hasattr(module, 'bias') and module.bias is not None:
@@ -550,10 +544,10 @@
                 module.weight.clone() * torch.reshape(
                     inverse_scaling_factors[channel_start:channel_end], src_broadcast_size),
                 attr='weight')
-    for module, axis in sink_axes.items():
+    for name, (module, axis) in sink_axes.items():
         sink_broadcast_size = [1] * module.weight.ndim
         sink_broadcast_size[axis] = module.weight.size(axis)
-        indexes = sinks[name][1]
+        indexes = region.sinks[name]
         channel_range = indexes.end - indexes.start
         partial_scaling = torch.ones(module.weight.size(axis), device=device, dtype=dtype)
         # We replace the scaling factors of the channels we need to equalize, leaving the other to
@@ -582,15 +576,6 @@ def _update_weights(original_module, new_value, attr='weight'):
     setattr(original_module, attr, nn.Parameter(new_value))


-def _organize_region(region, type):
-    region_dict = {}
-    region = getattr(region, type)
-    for i, (k, v) in enumerate(region.items()):
-        name = type + str(i)
-        region_dict[name] = (region.get_module_from_name(k), v)
-    return region_dict
-
-
 def _equalize(
     model: GraphModule,
     regions: Set[Tuple[str]],
@@ -605,11 +590,8 @@
     for i in range(iterations):
         scale_factor_max = None
         for ii, region in enumerate(regions):
-            srcs_dict = _organize_region(region, 'srcs')
-            sinks_dict = _organize_region(region, 'sinks')
             scale_factors_region = _cross_layer_equalization(
-                srcs_dict,
-                sinks_dict,
+                region,
                 merge_bias=merge_bias,
                 bias_shrinkage=bias_shrinkage,
                 scale_computation_type=scale_computation_type)
@@ -996,7 +978,7 @@ def __init__(self, model, scale_computation_type: str = 'maxabs'):
         self.hooks = []
         self.add_mul_node = True

-        regions = []
+        regions: List[Region] = []
         self.find_module(model, regions)
         self.regions = regions

@@ -1014,39 +996,42 @@ def find_module(self, model, regions: List):
                       _supported_layers) and not isinstance(model, _batch_norm + (nn.LayerNorm,)):
             weight = get_weight_sink([model])
             eq_indexes = EqualizationIndexes(0, weight[0].shape[0], 0)
-            regions.append((model, eq_indexes))
+            region = Region(sinks={'sinks0': eq_indexes}, name_to_module={'sinks0': model})
+            regions.append(region)
         else:
             for module in model.children():
                 self.find_module(module, regions)

     def setup(self):
         for region in self.regions:
+            module = region.get_module_from_name('sinks0')
             batch_dim = 0
             if hasattr(region, 'batch_first'):
                 batch_dim = 0 if region.batch_first else 1

             hook_fn = partial(
-                self.forward_stats_hook, name=region[0], batch_dim=batch_dim, use_inp=True)
-            new_instance = KwargsForwardHook(region[0], hook_fn)
-            ModuleInstanceToModuleInstance(region[0], new_instance).apply(self.model)
+                self.forward_stats_hook, name=module, batch_dim=batch_dim, use_inp=True)
+            new_instance = KwargsForwardHook(module, hook_fn)
+            ModuleInstanceToModuleInstance(module, new_instance).apply(self.model)
             self.hooks.append(new_instance)

     def apply(self, alpha):
         scale_factors = []
         self.remove_hooks()
         for region in self.regions:
-            if self.float_act_map[region[0]] == None:
+            module = region.get_module_from_name('sinks0')
+            if self.float_act_map[module] == None:
                 continue
-            sinks = region
             insert_mul_fn = partial(
-                self.insert_mul_node, region=region[0], batch_dim=self.batch_dim_act_map[region[0]])
+                self.insert_mul_node, region=module, batch_dim=self.batch_dim_act_map[module])
             scale_factors.append(
-                _cross_layer_equalization({}, {'sinks0': sinks},
-                                          False,
-                                          scale_computation_type=self.scale_computation_type,
-                                          list_of_act_val=[self.float_act_map[region[0]]],
-                                          list_of_insert_mul_node_fn=[insert_mul_fn],
-                                          alpha=alpha))
+                _cross_layer_equalization(
+                    region,
+                    False,
+                    scale_computation_type=self.scale_computation_type,
+                    list_of_act_val=[self.float_act_map[module]],
+                    list_of_insert_mul_node_fn=[insert_mul_fn],
+                    alpha=alpha))
         return scale_factors

     def insert_mul_node(self, scale, shape, axis, region, batch_dim=0):
@@ -1067,6 +1052,7 @@ def __init__(
         self.float_act_map = {}
         self.batch_dim_act_map = {}
         self.hooks = []
+        self.hooked_modules = set()
         self.add_mul_node = add_mul_node
         self.regions = _extract_regions(model, add_mul_node=add_mul_node, return_acts=True)

@@ -1091,19 +1077,26 @@ def setup(self):

             # We assume that the entire region has a unique batch_dim
             batch_dim = 0
-            region_to_search = region.sinks if len(region.acts) == 0 else region.acts
-            for name in region.srcs + region.sinks:
+            for name in region.srcs:
+                module = region.get_module_from_name(name)
+                if hasattr(module, 'batch_first') and not module.batch_first:
+                    batch_dim = 1
+            for name in region.sinks:
                 module = region.get_module_from_name(name)
-                if hasattr(module, 'batch_first'):
-                    batch_dim = 0 if module.batch_first else 1
+                if hasattr(module, 'batch_first') and not module.batch_first:
+                    batch_dim = 1
+
+            region_to_search = region.sinks_names if len(region.acts) == 0 else region.acts
             for name in region_to_search:
                 module = region.get_module_from_name(name)
-                use_inp = True if region_to_search == region.sinks else False
-                hook_fn = partial(
-                    self.forward_stats_hook, name=name, batch_dim=batch_dim, use_inp=use_inp)
-                new_instance = KwargsForwardHook(module, hook_fn)
-                ModuleInstanceToModuleInstance(module, new_instance).apply(self.model)
-                self.hooks.append(new_instance)
+                if module not in self.hooked_modules:
+                    self.hooked_modules.add(module)
+                    use_inp = True if region_to_search == region.sinks_names else False
+                    hook_fn = partial(
+                        self.forward_stats_hook, name=name, batch_dim=batch_dim, use_inp=use_inp)
+                    new_instance = KwargsForwardHook(module, hook_fn)
+                    ModuleInstanceToModuleInstance(module, new_instance).apply(self.model)
+                    self.hooks.append(new_instance)

         self.regions = [x for x in self.regions if x not in regions_to_drop]

@@ -1129,12 +1122,10 @@ def apply(self, alpha):
                             self.insert_mul_node,
                             act_node=act_node,
                             batch_dim=self.batch_dim_act_map[act_name]))
-            srcs_dict = _organize_region(region, 'srcs')
-            sinks_dict = _organize_region(region, 'sinks')
+
             scale_factors.append(
                 _cross_layer_equalization(
-                    srcs_dict,
-                    sinks_dict,
+                    region,
                     False,
                     scale_computation_type=self.scale_computation_type,
                     list_of_act_val=list_of_act_val,
diff --git a/tests/brevitas/graph/equalization_fixtures.py b/tests/brevitas/graph/equalization_fixtures.py
index 03812af48..fde3a60d9 100644
--- a/tests/brevitas/graph/equalization_fixtures.py
+++ b/tests/brevitas/graph/equalization_fixtures.py
@@ -33,11 +33,8 @@ def equalize_test(model, regions, merge_bias, bias_shrinkage, scale_computation_
     scale_factors_regions = []
     for i in range(3):
         for region in regions:
-            srcs_dict = _organize_region(region, 'srcs')
-            sinks_dict = _organize_region(region, 'sinks')
             scale_factors_region = _cross_layer_equalization(
-                srcs_dict,
-                sinks_dict,
+                region,
                 merge_bias=merge_bias,
                 bias_shrinkage=bias_shrinkage,
                 scale_computation_type=scale_computation_type)
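
For readers following the refactor: `_cross_layer_equalization` now takes a single `Region` (whose `srcs`/`sinks` dicts map layer names to `EqualizationIndexes`, with `name_to_module` resolving names back to modules) instead of two pre-flattened `srcs`/`sinks` dicts built by the removed `_organize_region`. The snippet below is a minimal sketch of the new calling convention on a toy two-layer region; the layer names, channel counts, and the 'maxabs' setting are illustrative, not taken from the patch:

    import torch.nn as nn

    from brevitas.graph.equalize import _cross_layer_equalization
    from brevitas.graph.equalize import EqualizationIndexes
    from brevitas.graph.equalize import Region

    # Toy region: conv1's 16 output channels feed conv2's 16 input channels,
    # so the source and sink channel ranges match (the equalization criterion).
    conv1 = nn.Conv2d(3, 16, kernel_size=3)
    conv2 = nn.Conv2d(16, 32, kernel_size=3)

    # EqualizationIndexes(start, end, offset) is the channel slice a layer
    # contributes; a fully equalized layer spans (0, num_channels, 0).
    region = Region(
        srcs={'conv1': EqualizationIndexes(0, 16, 0)},
        sinks={'conv2': EqualizationIndexes(0, 16, 0)},
        name_to_module={'conv1': conv1, 'conv2': conv2})

    # One weight-equalization pass over the region, as in _equalize above;
    # returns the per-channel scaling factors applied to the weights.
    scale_factors = _cross_layer_equalization(
        region, merge_bias=False, scale_computation_type='maxabs')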