example_hessian.py
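"""Example: per-layer Hessian quantities via autograd_lib.

Builds a five-layer 1x1 linear network initialized to the identity and, for
each layer, accumulates the full Hessian of a least-squares loss, its
diagonal, and the two KFAC (Kronecker-factored) curvature factors.
"""
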
from collections import defaultdict

import torch
from torch import nn
from attrdict import AttrDefault

from autograd_lib import autograd_lib


def simple_model(d, num_layers):
    """Creates a simple linear neural network initialized to the identity."""
    layers = []
    for _ in range(num_layers):
        layer = nn.Linear(d, d, bias=False)
        layer.weight.data.copy_(torch.eye(d))
        layers.append(layer)
    return nn.Sequential(*layers)


def least_squares(data, targets=None):
    """Least-squares loss (like MSELoss, but with an extra 1/2 factor)."""
    if targets is None:
        targets = torch.zeros_like(data)
    err = data - targets.view(-1, data.shape[1])
    return torch.sum(err * err) / 2 / len(data)


d = 1
n = 1
model = simple_model(d, num_layers=5)
data = torch.ones((n, d))
targets = torch.ones((n, d))
loss_fn = least_squares

autograd_lib.register(model)

# Per-layer accumulators: full Hessian, Hessian diagonal, and KFAC factors.
hess = defaultdict(float)
hess_diag = defaultdict(float)
hess_kfac = defaultdict(lambda: AttrDefault(float))

activations = {}


def save_activations(layer, A, _):
    activations[layer] = A
    # KFAC left factor
    hess_kfac[layer].AA += torch.einsum("ni,nj->ij", A, A)


with autograd_lib.module_hook(save_activations):
    output = model(data)
    loss = loss_fn(output, targets)


def compute_hess(layer, _, B):
    A = activations[layer]
    BA = torch.einsum("nl,ni->nli", B, A)

    # full Hessian
    hess[layer] += torch.einsum("nli,nkj->likj", BA, BA)

    # Hessian diagonal
    hess_diag[layer] += torch.einsum("ni,nj->ij", B * B, A * A)

    # KFAC right factor
    hess_kfac[layer].BB += torch.einsum("ni,nj->ij", B, B)


with autograd_lib.module_hook(compute_hess):
    autograd_lib.backward_hessian(output, loss='LeastSquares')
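
# Note: for a single example, the full block computed above,
# hess[layer][l,i,k,j] = sum_n B[n,l] A[n,i] B[n,k] A[n,j], equals the
# Kronecker product of the two KFAC factors (BB kron AA); with several
# examples, KFAC approximates the sum of per-example Kronecker products by
# the Kronecker product of the summed factors.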

# Print the accumulated Hessian diagonal of each linear layer (skipping the
# Sequential container, whose defaultdict entry would just print 0.0).
for layer in model.modules():
    if isinstance(layer, nn.Linear):
        print(hess_diag[layer])
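

# ---------------------------------------------------------------------------
# Optional cross-check -- a sketch added here, not part of the original
# example. Holding the other layers fixed, the network is linear in the
# first layer's weight, so the Hessian block accumulated above should match
# what torch.autograd.functional.hessian computes directly. (Assumes that
# running layers outside a module_hook context leaves the accumulators
# untouched.)
from torch.autograd.functional import hessian as reference_hessian

first = model[0]


def loss_wrt_first_weight(w):
    out = data @ w.t()  # same as the bias-free nn.Linear forward
    for layer in list(model)[1:]:
        out = layer(out)
    return least_squares(out, targets)


ref = reference_hessian(loss_wrt_first_weight, first.weight.detach().clone())
# Both tensors are (d, d, d, d) with (out, in, out, in) index ordering.
print(torch.allclose(ref.reshape(d * d, d * d),
                     hess[first].reshape(d * d, d * d)))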