Commit

update
fanis-khafizov committed Jan 30, 2025
1 parent 7166b00 commit 8cb9d2b
Showing 3 changed files with 22 additions and 7 deletions.
4 changes: 2 additions & 2 deletions code/ResNet/compressors.py
@@ -157,7 +157,7 @@ def update(self, X_train, y_train, criterion, lr, eta, num_steps):
             self.w[name] = mirror_descent(
                 model=self.model,
                 param_name=name,
-                impact=self.w[name],
+                impact=None,
                 lr=lr,
                 eta=eta,
                 lambda_value=0.1,
@@ -227,7 +227,7 @@ def __init__(self, model, k, weighted=True):
         """
         self.model = model
         self.k = k
-        self.w = {name: (imp := torch.ones_like(param)) / 2
+        self.w = {name: (imp := torch.ones_like(param))
                   for name, param in model.named_parameters()
                   }
         self.weighted = weighted
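For readers skimming the diff: `self.w` now starts as all-ones impact weights, and the first `mirror_descent` call passes `impact=None`, so the initial impact is derived inside `descent.py` (see the next file). The snippet below is a hypothetical sketch of the kind of impact-weighted top-k selection an ImpK-style compressor might perform; the function name `impk_mask`, the scoring rule, and the toy usage are illustrative assumptions, not code from this repository.

import torch

def impk_mask(param: torch.Tensor, w: torch.Tensor, k: int) -> torch.Tensor:
    # Score every coordinate by impact-weighted magnitude (assumed scoring rule).
    scores = (w * param).abs().flatten()
    topk = torch.topk(scores, k).indices
    mask = torch.zeros_like(scores)
    mask[topk] = 1.0
    # Keep only the k highest-scoring entries of the parameter tensor.
    return (param.flatten() * mask).view_as(param)

# Toy usage: with the new initialization w = torch.ones_like(param), the scores
# reduce to plain |param|, i.e. ordinary magnitude top-k, until mirror_descent
# starts reshaping the impact weights.
param = torch.randn(4, 4)
w = torch.ones_like(param)
compressed = impk_mask(param, w, k=3)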
8 changes: 7 additions & 1 deletion code/ResNet/descent.py
@@ -20,12 +20,18 @@ def mirror_descent(model, X_train, y_train, param_name, impact: torch.Tensor, lr
     Returns:
         torch.Tensor: The optimized impact tensor.
     """
-    impact = impact.clone().detach().requires_grad_(True)

     original_param = dict(model.named_parameters())[param_name]

     outputs = model(X_train)
     loss = criterion(outputs, y_train)
+    param_grad = torch.autograd.grad(loss, original_param, create_graph=True)[0]
+
+    if impact is None:
+        impact = param_grad.abs().clone().detach().requires_grad_(True)
+    else:
+        impact = impact.clone().detach().requires_grad_(True)
+
     new_params = {name: param.clone() for name, param in model.named_parameters()}

     for _ in range(num_steps):
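The hunk is cut off before the update loop, so the actual step on `impact` is not visible here. Below is a minimal, self-contained sketch of one way such a loop could look, assuming an exponentiated-gradient (entropic mirror-descent) step with `lambda_value` acting as a quadratic penalty; the surrogate objective, the simplex normalisation, and the simplified signature are assumptions for illustration, not the implementation in `descent.py`.

import torch

def mirror_descent_sketch(param_grad, impact, lr, eta, lambda_value, num_steps):
    # Mirrors the new default above: when no impact is supplied, seed it from |grad|.
    if impact is None:
        impact = param_grad.abs()
    impact = impact.clone().detach().requires_grad_(True)

    for _ in range(num_steps):
        # Assumed surrogate: reward impact on coordinates with large gradient
        # magnitude, penalised by an L2 term scaled by lambda_value.
        surrogate = -(impact * param_grad.abs()).sum() + lambda_value * (impact ** 2).sum()
        (grad_impact,) = torch.autograd.grad(surrogate, impact)
        with torch.no_grad():
            # Exponentiated-gradient step, then renormalise so the impact
            # stays on a scaled simplex (assumed constraint).
            impact *= torch.exp(-eta * lr * grad_impact)
            impact /= impact.sum().clamp_min(1e-12)
    return impact.detach()

# Toy usage with a random stand-in for param_grad:
g = torch.randn(10)
w = mirror_descent_sketch(g, impact=None, lr=0.01, eta=2.0, lambda_value=0.1, num_steps=20)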
17 changes: 13 additions & 4 deletions code/ResNet/main.py
@@ -14,11 +14,10 @@
 trainloader, testloader, classes = load_data()
 device = get_device()

-num_epochs = 2
-
 config = {
     'param_usage': 0.001,
     'num_restarts': 1,
+    'num_epochs': 30,
 }

 compress_configs = [
@@ -31,13 +30,13 @@
     {
         'compression_type': 'ImpK_b',
         'lr': 0.01,
-        'eta': 2.,
+        'eta': 5.,
         'num_steps': 20,
     },
     {
         'compression_type': 'ImpK_c',
         'lr': 0.01,
-        'eta': 100000.,
+        'eta': 1000000.,
         'num_steps': 20,
     }
 ]
@@ -48,6 +47,7 @@

 param_usage = config['param_usage']
 num_restarts = config['num_restarts']
+num_epochs = config['num_epochs']

 for compress_config in compress_configs:
     compression_type = compress_config['compression_type']
@@ -91,6 +91,15 @@
         test_log[compression_type].append(test_loss)
         test_acc[compression_type].append(test_accuracy)

+print("Train Loss")
+print(train_log)
+print("Train Accuracy")
+print(train_acc)
+print("Test Loss")
+print(test_log)
+print("Test Accuracy")
+print(test_acc)
+

 fig_train, axs_train = plt.subplots(1, 2, figsize=(16, 7))
 fig_test, axs_test = plt.subplots(1, 2, figsize=(16, 7))
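The plotting code that fills these subplots is truncated by the diff view. Below is a hedged sketch of a continuation that would draw the logged dictionaries onto the two figure pairs created above; it assumes the in-scope names train_log, train_acc, test_log, test_acc, axs_train and axs_test from this file, and the axis layout and labels are guesses, not the repository's actual plotting code.

# Hypothetical continuation: one curve per compression type,
# loss on the left axis, accuracy on the right.
for compression_type in train_log:
    axs_train[0].plot(train_log[compression_type], label=compression_type)
    axs_train[1].plot(train_acc[compression_type], label=compression_type)
    axs_test[0].plot(test_log[compression_type], label=compression_type)
    axs_test[1].plot(test_acc[compression_type], label=compression_type)

for ax, title in zip([*axs_train, *axs_test],
                     ["Train Loss", "Train Accuracy", "Test Loss", "Test Accuracy"]):
    ax.set_xlabel("Epoch")
    ax.set_title(title)
    ax.legend()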
