Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Fix for multiple region with different ranges
Browse files Browse the repository at this point in the history
Giuseppe5 committed Sep 12, 2023
1 parent c5d42d5 commit 1a106ac
Showing 1 changed file with 32 additions and 16 deletions.
48 changes: 32 additions & 16 deletions src/brevitas/graph/equalize.py
Original file line number Diff line number Diff line change
@@ -32,15 +32,6 @@

EPSILON = 1e-9

Region = namedtuple('Region', ['srcs', 'sinks'])

# Start and End identify the starting and ending channels of the weight matrix that need to be
# equalized.
# Offset refers to the relative position of these channels with respect to
# the other matrices' channels that are equalized simultaneously.
# Source matrices are always fully equalized, while sinks can be partially equalized.
EqualizationIndexes = namedtuple('EqualizationIndexes', ['start', 'end', 'offset'])

_supported_layers = (
nn.ConvTranspose1d,
nn.ConvTranspose2d,
@@ -118,6 +109,22 @@ class WalkRegionState:
update_offset: bool = False


# Start and End identify the starting and ending channels of the weight matrix that need to be
# equalized.
# Offset refers to the relative position of these channels with respect to
# the other matrices' channels that are equalized simultaneously.
# Source matrices are always fully equalized, while sinks can be partially equalized.
@dataclass
class EqualizationIndexes:
    """Channel range of one weight matrix taking part in an equalization.

    Attributes are documented inline below. Instances compare by value
    (dataclass-generated ``__eq__``).
    """
    # First channel of the weight matrix that needs to be equalized.
    start: int = 0
    # End of the channel range; used as the upper bound of a slice.
    end: int = 0
    # Relative position of these channels with respect to the channels of the
    # other matrices that are equalized simultaneously.
    offset: int = 0

    def __str__(self) -> str:
        """Return the compact ``"<start>_<end>_<offset>"`` representation."""
        # Callers append this text to a module name after a '$' separator to
        # build distinct dictionary keys for the same module, so the format
        # must stay stable and must not contain '$'.
        return '_'.join(str(value) for value in (self.start, self.end, self.offset))


_UNSUPPORTED_OP = object()


@@ -352,8 +359,6 @@ def _no_equalize():
print("No Eq")
return torch.tensor(1., dtype=dtype, device=device)

src_axes = {}
sink_axes = {}
act_sink_axes = {}
act_sources_axes = {}
single_module = list(sinks.values())[0][0]
@@ -511,6 +516,13 @@ def _no_equalize():
# one (i.e., no equalization)
partial_scaling[indexes.start:indexes.end] = scaling_factors[indexes.offset:indexes.offset +
channel_range]
if isinstance(module, _batch_norm):
# We re-compute the bias as function of running_mean and running_var to adjust the
# additive factor for equalization.
additive_factor = module.running_mean.data * module.weight.data / torch.sqrt(
module.running_var.data + module.eps)
_update_weights(
module, module.bias.clone() + additive_factor * (partial_scaling - 1), attr='bias')
_update_weights(
module,
module.weight.clone() * torch.reshape(partial_scaling, src_broadcast_size),
@@ -531,6 +543,7 @@ def _organize_region(region, name_to_module, type):
region = getattr(region, type)
for i, (k, v) in enumerate(region.items()):
name = type + str(i)
k = k.split('$')[0]
region_dict[name] = (name_to_module[k], v)
return region_dict

@@ -550,9 +563,9 @@ def _equalize(
name_set = set()
for region in regions:
for name in region.srcs.keys():
name_set.add(name)
name_set.add(name.split("$")[0])
for name in region.sinks.keys():
name_set.add(name)
name_set.add(name.split("$")[0])

for name, module in model.named_modules():
if name in name_set:
@@ -713,7 +726,8 @@ def find_srcs(graph_model: GraphModule, starting_node: Node,
weight = get_weight_source([module])
eq_indexes = EqualizationIndexes(0, weight[0].shape[0], state.offset)
# state.srcs.add((node.target, eq_indexes))
state.srcs[node.target] = eq_indexes
full_source_name = node.target + '$' + str(eq_indexes)
state.srcs[full_source_name] = eq_indexes
# After we found a source, we need to check if it branches into multiple sinks
find_sinks(graph_model, node, state)
state.offset = state.offset if not state.update_offset else state.offset + weight[
@@ -763,11 +777,13 @@ def find_sinks(graph_model: GraphModule, starting_node: Node,
weight = get_weight_sink([module])
eq_indexes = EqualizationIndexes(0, weight[0].shape[0], state.offset)
# It is not possible to equalize through LayerNorm as sink
if isinstance(module, (nn.LayerNorm,) + _batch_norm):
if isinstance(module, (nn.LayerNorm,)):
# state.sinks.add((_UNSUPPORTED_OP, _UNSUPPORTED_OP))
state.sinks[_UNSUPPORTED_OP] = _UNSUPPORTED_OP
else:
state.sinks[node.target] = eq_indexes
full_sink_name = node.target + '$' + str(eq_indexes)
state.sinks[full_sink_name] = eq_indexes
# state.sinks[node.target] = eq_indexes
elif _is_scale_invariant_module(
graph_model, node) or _is_scale_invariant_function(node) or _is_reshaping_op(node):
find_sinks(graph_model, node, state)

0 comments on commit 1a106ac

Please sign in to comment.