Merge branch 'main' into tensor_decompression
Sara Adkins committed Mar 20, 2024
2 parents 38e3c6d + dead8b5 commit 830a9f5
Showing 2 changed files with 131 additions and 15 deletions.
@@ -104,7 +104,6 @@ def __init__(
mask_type: str = "unstructured",
active_weight_decay: float = 0.0001,
):

self._mask_type = mask_type

super(TopKASTPruningModifier, self).__init__(
@@ -262,11 +261,9 @@ def check_mask_update(
epoch, steps_per_epoch
)

self._module_masks.update_param_masks(
target=recomputation_sparsity or self._applied_sparsity
)
self._module_masks.update_param_masks(target=self._applied_sparsity)
self._grad_module_masks.update_param_masks(
target=recomputation_sparsity or self._grad_applied_sparsity
target=self._grad_applied_sparsity
)
self._sparsity_applied = True

@@ -383,7 +380,7 @@ def optimizer_pre_step(
w_i = w_i - w_i * self._active_weight_decay * current learning rate
For weights inactive in the forward pass but active in
the backward pass:
w_i = w_i - w_i * 1/forward_sparsity * self._active_weight_decay * \
w_i = w_i - w_i * 1/(1-forward_sparsity) * self._active_weight_decay * \
current learning rate
The reason that we multiply by the learning rate is that in the original
Expand Down Expand Up @@ -416,7 +413,7 @@ def optimizer_pre_step(
param -= (
self._active_weight_decay
* lr
* (1 / self.forward_sparsity)
* (1 / (1 - self.forward_sparsity))
* (1 - self._module_masks.param_masks[i])
* self._grad_module_masks.param_masks[i]
* param
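A note on the change above: the boosted weight decay applies to weights that receive gradients but are masked out of the forward pass, and it should scale with the fraction of weights kept in the forward pass, which is 1 - forward_sparsity rather than forward_sparsity. The following is a minimal standalone sketch of the update rule described in the docstring, not the modifier's actual implementation; param, forward_mask, grad_mask, and all numeric values are hypothetical, and only the 1 / (1 - forward_sparsity) scaling mirrors the corrected line.

import torch

# Illustrative values only; not taken from the repository.
forward_sparsity = 0.9        # fraction of weights masked out of the forward pass
active_weight_decay = 0.0001
lr = 0.25

param = torch.randn(16)
# Hypothetical masks: the forward-active set, and the larger gradient (backward) set.
forward_mask = (torch.rand(16) > forward_sparsity).float()
grad_mask = torch.clamp(forward_mask + (torch.rand(16) > 0.8).float(), max=1.0)

# Weights active in the forward pass: w_i -= w_i * active_weight_decay * lr
param = param - active_weight_decay * lr * forward_mask * param

# Weights active only in the backward pass get the decay boosted by
# 1 / (1 - forward_sparsity), matching the corrected factor above.
backward_only = (1.0 - forward_mask) * grad_mask
param = param - (
    active_weight_decay * lr * (1.0 / (1.0 - forward_sparsity)) * backward_only * param
)

With forward_sparsity = 0.9 only a tenth of the weights are forward-active, so the backward-only weights see a roughly tenfold stronger decay; scaling by 1 / forward_sparsity instead would make the boost weaker as sparsity grows, which appears to be what the diff corrects.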
@@ -35,7 +35,7 @@


def create_optim_sgd(
model: Module, lr: float = 0.00025, momentum: float = 0.9, weight_decay: float = 0
model: Module, lr: float = 0.25, momentum: float = 0.0, weight_decay: float = 0.0
) -> SGD:
return SGD(model.parameters(), lr=lr, momentum=momentum, weight_decay=weight_decay)

@@ -81,7 +81,7 @@ def create_optim_adam(model: Module, lr: float = 0.00025) -> Adam:
@pytest.mark.parametrize("model_lambda", [LinearNet], scope="function")
@pytest.mark.parametrize(
"optim_lambda",
[create_optim_sgd, create_optim_adam],
[create_optim_sgd],
scope="function",
)
class TestTopKASTPruningModifier(ScheduledModifierTest):
@@ -124,6 +124,125 @@ def _test_compression_sparsity_applied():
assert not modifier.update_ready(epoch, test_steps_per_epoch)
_test_compression_sparsity_applied()

# This test evaluates whether the gradients computed when Top-KAST
# is applied (so there is a forward mask) match those computed when the
# masked-out parameters are explicitly set to 0.
def test_topkast_forward_masking(
self,
modifier_lambda,
model_lambda,
optim_lambda,
test_steps_per_epoch, # noqa: F811
):
modifier = modifier_lambda()
model = model_lambda()
optimizer = optim_lambda(model)
self.initialize_helper(modifier, model)

batch_shape = 10
input_shape = model_lambda.layer_descs()[0].input_size
epoch = int(modifier.start_epoch)

while epoch < modifier.end_epoch:
if modifier.update_ready(epoch, test_steps_per_epoch):
modifier.scheduled_update(model, optimizer, epoch, test_steps_per_epoch)

# cache the model's weights before masking, so we can restore at
# the end of the test.
model_state_dict = copy.deepcopy(model.state_dict())

random_input = torch.randn(batch_shape, *input_shape)

# Compute gradients using the full weights, with the Top-KAST modifier's forward mask applied.
optimizer.zero_grad()
model(random_input).mean().backward()
grads_from_full_model = {}
for i, param in enumerate(modifier._module_masks._params):
grads_from_full_model[i] = modifier._module_masks._params[i].grad

# Now compute grads when the masked weights are actually just 0.
optimizer.zero_grad()
with torch.no_grad():
for i, param in enumerate(modifier._module_masks._params):
param.data.mul_(modifier._module_masks.param_masks[i])
model(random_input).mean().backward()
for i, param in enumerate(modifier._module_masks._params):
assert torch.allclose(
grads_from_full_model[i], modifier._module_masks._params[i].grad
)

# Restore the unmasked weights to continue the test.
model.load_state_dict(model_state_dict)
optimizer.step()
epoch += 1

# Test whether the gradients are masked and applied correctly.
def test_topkast_gradient_masking(
self,
modifier_lambda,
model_lambda,
optim_lambda,
test_steps_per_epoch, # noqa: F811
):
modifier = modifier_lambda()
model = model_lambda()
optimizer = optim_lambda(model)
self.initialize_helper(modifier, model)

batch_shape = 10
input_shape = model_lambda.layer_descs()[0].input_size
epoch = int(modifier.start_epoch)

while epoch < modifier.end_epoch:
if modifier.update_ready(epoch, test_steps_per_epoch):
modifier.scheduled_update(model, optimizer, epoch, test_steps_per_epoch)

# cache the model's weights before optimizer step.
layer_weights_pre = copy.deepcopy(modifier._module_masks)

optimizer.zero_grad()
model(torch.randn(batch_shape, *input_shape)).mean().backward()
optimizer.step()

for i, param in enumerate(modifier._module_masks._params):
# Params outside the gradient (backward) mask shouldn't change.
unchanged_mask = (1 - modifier._grad_module_masks.param_masks[i]).bool()
forward_mask = (modifier._module_masks.param_masks[i]).bool()
backward_mask = (
(1 - modifier._module_masks.param_masks[i])
* modifier._grad_module_masks.param_masks[i]
).bool()
# check that the three masks fully cover the space
assert torch.all(unchanged_mask + forward_mask + backward_mask)
assert torch.equal((~unchanged_mask), forward_mask + backward_mask)
assert torch.equal((~forward_mask), backward_mask + unchanged_mask)
assert torch.equal((~backward_mask), forward_mask + unchanged_mask)

# Confirm that the gradients were only applied to those weights that
# are in the backward mask.
# We are using SGD with no momentum as the optimizer, so we can check
# the calculation explicitly.
assert torch.equal(
modifier._module_masks._params[i][unchanged_mask],
layer_weights_pre._params[i][unchanged_mask],
)
assert torch.allclose(
modifier._module_masks._params[i][forward_mask],
(
layer_weights_pre._params[i]
- 0.25 * modifier._module_masks._params[i].grad
)[forward_mask],
)
assert torch.allclose(
modifier._module_masks._params[i][backward_mask],
(
layer_weights_pre._params[i]
- 0.25 * modifier._module_masks._params[i].grad
)[backward_mask],
)

epoch += 1

@pytest.mark.flaky(reruns=3, min_passes=2)
def test_weight_decay(
self,
@@ -144,7 +263,7 @@ def test_weight_decay(
while epoch < modifier.end_epoch:
if modifier.update_ready(epoch, test_steps_per_epoch):
modifier.scheduled_update(model, optimizer, epoch, test_steps_per_epoch)
# Cache the model's weights before optimizer step.
# cache the model's weights before optimizer step.

layer_weights_pre = copy.deepcopy(modifier._module_masks)
optimizer.zero_grad()
@@ -158,7 +277,7 @@
(1 - modifier._module_masks.param_masks[i])
* modifier._grad_module_masks.param_masks[i]
).bool()
# Check that the three masks fully cover the space
# check that the three masks fully cover the space
assert torch.all(unchanged_mask + forward_mask + backward_mask)
assert torch.equal((~unchanged_mask), forward_mask + backward_mask)
assert torch.equal((~forward_mask), backward_mask + unchanged_mask)
@@ -170,15 +289,15 @@
)
assert torch.allclose(
modifier._module_masks._params[i][forward_mask],
layer_weights_pre._params[i][forward_mask] * (1 - 0.0002 * 0.00025),
atol=1e-5,
layer_weights_pre._params[i][forward_mask] * (1 - 0.0002 * 0.25),
atol=1e-7,
equal_nan=True,
)
assert torch.allclose(
modifier._module_masks._params[i][backward_mask],
layer_weights_pre._params[i][backward_mask]
* (1 - 0.0002 * 0.00025 * 1 / modifier._forward_sparsity),
atol=1e-5,
* (1 - 0.0002 * 0.25 * 1 / (1 - modifier._forward_sparsity)),
atol=1e-7,
equal_nan=True,
)

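For reference, the decay factors asserted above can be checked by hand. This is a small arithmetic sketch, assuming active_weight_decay = 0.0002 and lr = 0.25 as they appear in the assertions, and an illustrative forward_sparsity of 0.9 (the test's configured sparsity is not shown in this diff).

active_weight_decay = 0.0002
lr = 0.25
forward_sparsity = 0.9  # illustrative placeholder, not the test's actual value

# Forward-active weights shrink by this factor per step:
forward_factor = 1 - active_weight_decay * lr
print(forward_factor)  # ~0.99995

# Backward-only weights have the decay boosted by 1 / (1 - forward_sparsity):
backward_factor = 1 - active_weight_decay * lr / (1 - forward_sparsity)
print(backward_factor)  # ~0.9995 with the illustrative forward_sparsity

The per-step change is on the order of 1e-5 times the weight, which is presumably why the tolerance in the assertions was tightened from atol=1e-5 to atol=1e-7; the looser tolerance could swallow the entire effect being tested.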
