Skip to content

Commit

Permalink
Review
Browse files: browse the repository at this point in the history
  • Loading branch information
Giuseppe5 committed Dec 18, 2023
1 parent 9b812c2 commit 5ab0273
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 70 deletions.
123 changes: 57 additions & 66 deletions src/brevitas/graph/equalize.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,17 +155,16 @@ def add(
if type == 'srcs' or type == 'sinks':
assert indexes is not None
full_source_name = name + '$' + str(indexes)
self.srcs[full_source_name] = indexes
getattr(self, type)[full_source_name] = indexes
else:
getattr(self, type).add(name)
elif type == 'acts':
self.acts.add(name)
self.name_to_module[name] = module

def add_srcs(self, src_name: str, src: nn.Module, indexes: EqualizationIndexes):
self.add('srcs', src_name, src, indexes)

def add_sinks(self, sink_name: str, sink: nn.Module, indexes: EqualizationIndexes):
self.add('srcs', sink_name, sink, indexes)
self.add('sinks', sink_name, sink, indexes)

def add_acts(self, act_name: str, act: nn.Module):
self.add('acts', act_name, act)
Expand Down Expand Up @@ -373,8 +372,7 @@ def transpose(module: torch.nn.Module, axis: int):


def _cross_layer_equalization(
srcs: Dict[nn.Module, List[int]],
sinks: Dict[nn.Module, List[int]],
region: Region,
merge_bias: bool,
scale_computation_type: str,
bias_shrinkage: Optional[Union[float, str]] = None,
Expand All @@ -387,46 +385,43 @@ def _cross_layer_equalization(
ranges of the second tensors' input channel
"""

# Determine device and type of tensors
device = next(sinks['sinks0'][0].parameters()).device
dtype = next(sinks['sinks0'][0].parameters()).dtype

# If equalization criteria are not met, we return a scalar one to indicate that no equalization
# has been performed
def _no_equalize():
return torch.tensor(1., dtype=dtype, device=device)

act_sink_axes = {}
act_sources_axes = {}
single_module = list(sinks.values())[0][0]
single_module = region.get_module_from_name(next(iter(region.sinks_names)))
device = next(single_module.parameters()).device
dtype = next(single_module.parameters()).dtype

max_shape_srcs = 0
for name, (k, v) in srcs.items():
max_shape_srcs = max(max_shape_srcs, v.end + v.offset)
for name, indexes in region.srcs.items():
max_shape_srcs = max(max_shape_srcs, indexes.end + indexes.offset)
max_shape_sinks = 0
for name, (k, v) in sinks.items():
max_shape_sinks = max(max_shape_sinks, v.offset + (v.end - v.start))
for name, indexes in region.sinks.items():
max_shape_sinks = max(max_shape_sinks, indexes.offset + (indexes.end - indexes.start))

# Exit if source and sink have different sizes
if max_shape_srcs != max_shape_sinks and len(srcs) > 0:
if max_shape_srcs != max_shape_sinks and len(region.srcs) > 0:
return _no_equalize()

src_axes = {}
for i, (name, (module, indexes)) in enumerate(srcs.items()):
for name, indexes in region.srcs.items():
module = region.get_module_from_name(name)
# If module is not supported, do not perform graph equalization
axis = _get_output_axis(module)
act_sources_axes[name] = _get_act_axis(module)
if not isinstance(module, _supported_layers):
return _no_equalize()
if isinstance(module, nn.MultiheadAttention):
module = module.out_proj
srcs[name] = (module, indexes)
src_axes[name] = (module, axis)

sink_axes = {}
for i, (name, (module, indexes)) in enumerate(sinks.items()):
for name, indexes in region.sinks.items():
module = region.get_module_from_name(name)
axis = _get_input_axis(module)
act_sink_axes[name] = _get_act_axis(module)
# If module is not supported, do not perform graph equalization
Expand All @@ -436,7 +431,6 @@ def _no_equalize():
if isinstance(module, nn.MultiheadAttention) and module.in_proj_weight is not None:
# For sinks, we only need to modify the weight but not the bias
module = WeightBiasTuple(module.in_proj_weight)
sinks[name] = (module, indexes)
elif isinstance(module, nn.MultiheadAttention) and module.in_proj_weight is None:
return _no_equalize()
sink_axes[name] = (module, axis)
Expand Down Expand Up @@ -469,7 +463,7 @@ def _no_equalize():
for k, v in sink_weights.items():
# Sinks can be partially equalized, thus we need to select
# only the channels we are interested in
indexes = sinks[k][1]
indexes = region.sinks[k]
# Compute the range of the channels we need to equalize
weight_range = scale_fn(v.reshape(v.size(0), -1))[indexes.start:indexes.end]
# Compute the numbers of channels we are equalizing
Expand Down Expand Up @@ -501,7 +495,7 @@ def _no_equalize():
for k, v in src_weights.items():
# Srcs are always fully equalized, thus we simply need to apply the offset to position them
# correctly with respect to the other srcs matrices.
indexes = srcs[k][1]
indexes = region.srcs[k]
channel_start = indexes.offset + indexes.start
channel_end = indexes.offset + indexes.end
weight_range = scale_fn(v.reshape(v.size(0), -1))
Expand Down Expand Up @@ -533,7 +527,7 @@ def _no_equalize():
insert_mul_node_fn(inverse_scaling_factors, act_val_shape, act_axis)
if len(src_axes) > 0:
for name, (module, axis) in src_axes.items():
indexes = srcs[name][1]
indexes = region.srcs[name]
channel_start = indexes.offset + indexes.start
channel_end = indexes.offset + indexes.end
if hasattr(module, 'bias') and module.bias is not None:
Expand All @@ -550,10 +544,10 @@ def _no_equalize():
module.weight.clone() * torch.reshape(
inverse_scaling_factors[channel_start:channel_end], src_broadcast_size),
attr='weight')
for module, axis in sink_axes.items():
for name, (module, axis) in sink_axes.items():
sink_broadcast_size = [1] * module.weight.ndim
sink_broadcast_size[axis] = module.weight.size(axis)
indexes = sinks[name][1]
indexes = region.sinks[name]
channel_range = indexes.end - indexes.start
partial_scaling = torch.ones(module.weight.size(axis), device=device, dtype=dtype)
# We replace the scaling factors of the channels we need to equalize, leaving the other to
Expand Down Expand Up @@ -582,15 +576,6 @@ def _update_weights(original_module, new_value, attr='weight'):
setattr(original_module, attr, nn.Parameter(new_value))


def _organize_region(region, type):
region_dict = {}
region = getattr(region, type)
for i, (k, v) in enumerate(region.items()):
name = type + str(i)
region_dict[name] = (region.get_module_from_name(k), v)
return region_dict


def _equalize(
model: GraphModule,
regions: Set[Tuple[str]],
Expand All @@ -605,11 +590,8 @@ def _equalize(
for i in range(iterations):
scale_factor_max = None
for ii, region in enumerate(regions):
srcs_dict = _organize_region(region, 'srcs')
sinks_dict = _organize_region(region, 'sinks')
scale_factors_region = _cross_layer_equalization(
srcs_dict,
sinks_dict,
region,
merge_bias=merge_bias,
bias_shrinkage=bias_shrinkage,
scale_computation_type=scale_computation_type)
Expand Down Expand Up @@ -996,7 +978,7 @@ def __init__(self, model, scale_computation_type: str = 'maxabs'):
self.hooks = []
self.add_mul_node = True

regions = []
regions: List[Region] = []
self.find_module(model, regions)
self.regions = regions

Expand All @@ -1014,39 +996,42 @@ def find_module(self, model, regions: List):
_supported_layers) and not isinstance(model, _batch_norm + (nn.LayerNorm,)):
weight = get_weight_sink([model])
eq_indexes = EqualizationIndexes(0, weight[0].shape[0], 0)
regions.append((model, eq_indexes))
region = Region(sinks={'sinks0': eq_indexes}, name_to_module={'sinks0': model})
regions.append(region)
else:
for module in model.children():
self.find_module(module, regions)

def setup(self):
for region in self.regions:
module = region.get_module_from_name('sinks0')
batch_dim = 0
if hasattr(region, 'batch_first'):
batch_dim = 0 if region.batch_first else 1

hook_fn = partial(
self.forward_stats_hook, name=region[0], batch_dim=batch_dim, use_inp=True)
new_instance = KwargsForwardHook(region[0], hook_fn)
ModuleInstanceToModuleInstance(region[0], new_instance).apply(self.model)
self.forward_stats_hook, name=module, batch_dim=batch_dim, use_inp=True)
new_instance = KwargsForwardHook(module, hook_fn)
ModuleInstanceToModuleInstance(module, new_instance).apply(self.model)
self.hooks.append(new_instance)

def apply(self, alpha):
scale_factors = []
self.remove_hooks()
for region in self.regions:
if self.float_act_map[region[0]] == None:
module = region.get_module_from_name('sinks0')
if self.float_act_map[module] == None:
continue
sinks = region
insert_mul_fn = partial(
self.insert_mul_node, region=region[0], batch_dim=self.batch_dim_act_map[region[0]])
self.insert_mul_node, region=module, batch_dim=self.batch_dim_act_map[module])
scale_factors.append(
_cross_layer_equalization({}, {'sinks0': sinks},
False,
scale_computation_type=self.scale_computation_type,
list_of_act_val=[self.float_act_map[region[0]]],
list_of_insert_mul_node_fn=[insert_mul_fn],
alpha=alpha))
_cross_layer_equalization(
region,
False,
scale_computation_type=self.scale_computation_type,
list_of_act_val=[self.float_act_map[module]],
list_of_insert_mul_node_fn=[insert_mul_fn],
alpha=alpha))
return scale_factors

def insert_mul_node(self, scale, shape, axis, region, batch_dim=0):
Expand All @@ -1067,6 +1052,7 @@ def __init__(
self.float_act_map = {}
self.batch_dim_act_map = {}
self.hooks = []
self.hooked_modules = set()
self.add_mul_node = add_mul_node
self.regions = _extract_regions(model, add_mul_node=add_mul_node, return_acts=True)

Expand All @@ -1091,19 +1077,26 @@ def setup(self):

# We assume that the entire region has a unique batch_dim
batch_dim = 0
region_to_search = region.sinks if len(region.acts) == 0 else region.acts
for name in region.srcs + region.sinks:
for name in region.srcs:
module = region.get_module_from_name(name)
if hasattr(module, 'batch_first') and not module.batch_first:
batch_dim = 1
for name in region.sinks:
module = region.get_module_from_name(name)
if hasattr(module, 'batch_first'):
batch_dim = 0 if module.batch_first else 1
if hasattr(module, 'batch_first') and not module.batch_first:
batch_dim = 1

region_to_search = region.sinks_names if len(region.acts) == 0 else region.acts
for name in region_to_search:
module = region.get_module_from_name(name)
use_inp = True if region_to_search == region.sinks else False
hook_fn = partial(
self.forward_stats_hook, name=name, batch_dim=batch_dim, use_inp=use_inp)
new_instance = KwargsForwardHook(module, hook_fn)
ModuleInstanceToModuleInstance(module, new_instance).apply(self.model)
self.hooks.append(new_instance)
if module not in self.hooked_modules:
self.hooked_modules.add(module)
use_inp = True if region_to_search == region.sinks_names else False
hook_fn = partial(
self.forward_stats_hook, name=name, batch_dim=batch_dim, use_inp=use_inp)
new_instance = KwargsForwardHook(module, hook_fn)
ModuleInstanceToModuleInstance(module, new_instance).apply(self.model)
self.hooks.append(new_instance)

self.regions = [x for x in self.regions if x not in regions_to_drop]

Expand All @@ -1129,12 +1122,10 @@ def apply(self, alpha):
self.insert_mul_node,
act_node=act_node,
batch_dim=self.batch_dim_act_map[act_name]))
srcs_dict = _organize_region(region, 'srcs')
sinks_dict = _organize_region(region, 'sinks')

scale_factors.append(
_cross_layer_equalization(
srcs_dict,
sinks_dict,
region,
False,
scale_computation_type=self.scale_computation_type,
list_of_act_val=list_of_act_val,
Expand Down
5 changes: 1 addition & 4 deletions tests/brevitas/graph/equalization_fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,8 @@ def equalize_test(model, regions, merge_bias, bias_shrinkage, scale_computation_
scale_factors_regions = []
for i in range(3):
for region in regions:
srcs_dict = _organize_region(region, 'srcs')
sinks_dict = _organize_region(region, 'sinks')
scale_factors_region = _cross_layer_equalization(
srcs_dict,
sinks_dict,
region,
merge_bias=merge_bias,
bias_shrinkage=bias_shrinkage,
scale_computation_type=scale_computation_type)
Expand Down

0 comments on commit 5ab0273

Please sign in to comment.