-
Notifications
You must be signed in to change notification settings - Fork 13
/
Copy pathapgpy.py
111 lines (89 loc) · 2.68 KB
/
apgpy.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
from __future__ import print_function
import numpy as np
from apgwrapper import NumpyWrapper
from functools import partial
def npwrap(x):
    """Return ``x`` wrapped in ``NumpyWrapper`` when it is a ``numpy.ndarray``.

    Values of any other type are returned as-is, so the function can be
    applied unconditionally to user-supplied inputs and callback results.
    """
    return NumpyWrapper(x) if isinstance(x, np.ndarray) else x
def npwrapfunc(f, *args):
    """Call ``f(*args)`` and pass its result through ``npwrap``.

    Used (via ``functools.partial``) to make user callbacks that return raw
    ``numpy.ndarray`` values hand back wrapped objects instead.
    """
    result = f(*args)
    return npwrap(result)
def _bb_step_size(df, x, g):
    """Return a Barzilai-Borwein estimate of the initial step size at ``x``.

    A trial gradient step of length ``1 / ||g||`` is taken to ``x_hat`` and
    the estimate ``|<x - x_hat, g - g_hat>| / ||g - g_hat||^2`` is returned,
    where ``g_hat = df(x_hat.data)``.

    Degenerate cases that would otherwise produce inf/nan (and poison every
    later iterate) fall back to a finite step:
      * ``||g|| == 0`` (gradient vanishes at ``x``): return ``1.0``;
      * ``g_hat == g`` (no curvature seen along the trial step): return the
        trial step length ``1 / ||g||``.
    Pass ``step_size`` to ``solve`` to override this estimate entirely.
    """
    g_norm = g.norm()
    if g_norm == 0:
        return 1.
    t = 1. / g_norm
    x_hat = x - t * g
    g_hat = df(x_hat.data)
    dg = g - g_hat
    dg_sq = dg.norm() ** 2
    if dg_sq == 0:
        return t
    return abs((x - x_hat).dot(dg) / dg_sq)


def _plot_errors(errs):
    """Plot the per-iteration residual history on a logarithmic y-axis."""
    # imported here so matplotlib is only needed when plotting is requested
    import matplotlib.pyplot as plt
    plt.figure()
    plt.semilogy(errs)
    plt.xlabel('iters')
    plt.title('||Gk||/(1+||xk||)')
    plt.draw()


def solve(grad_f, prox_h, x_init,
          max_iters=2500,
          eps=1e-6,
          alpha=1.01,
          beta=0.5,
          use_restart=True,
          gen_plots=False,
          quiet=False,
          use_gra=False,
          step_size=False,
          fixed_step_size=False,
          debug=False):
    """Run an accelerated proximal gradient iteration and return the result.

    Each iteration takes a gradient step from the extrapolated point ``y``,
    applies ``prox_h`` (if given), updates the momentum weight ``theta``,
    optionally restarts momentum, and adapts the step size ``t``.

    Parameters
    ----------
    grad_f : callable
        ``grad_f(v)`` returns the gradient of the smooth term at ``v``
        (called with the raw ``.data`` of the current point).
    prox_h : callable or None
        Called as ``prox_h(v, t)`` after every gradient step; expected to
        return the proximal operator of ``t * h`` at ``v``. If falsy, the
        proximal step is skipped.
    x_init : numpy.ndarray or wrapper object
        Starting point. A ``numpy.ndarray`` is wrapped in ``NumpyWrapper``;
        anything else must already provide ``copy``, ``norm``, ``dot``,
        arithmetic operators and ``.data``.
    max_iters : int
        Maximum number of iterations.
    eps : float
        Stop when ``||y - x|| / (1 + ||x||) / t < eps``.
    alpha : float
        Per-iteration growth cap for the step size (``t <= alpha * t_old``).
    beta : float
        Per-iteration shrink floor for the step size (``t >= beta * t_old``).
    use_restart : bool
        Reset momentum when ``<y - x, x - x_old> > 0`` (ignored if
        ``use_gra``).
    gen_plots : bool
        Record the residual of every iteration and plot it with matplotlib.
    quiet : bool
        Suppress progress output.
    use_gra : bool
        Disable acceleration: ``theta`` stays 1, so ``y == x`` every step.
    step_size : float or False
        Initial step size; if falsy, a Barzilai-Borwein estimate is used.
    fixed_step_size : bool
        Keep ``t`` constant instead of adapting it each iteration.
    debug : bool
        Print restart and back-tracking events.

    Returns
    -------
    The ``.data`` attribute of the final iterate ``x``.
    """
    df = partial(npwrapfunc, grad_f)
    ph = partial(npwrapfunc, prox_h)

    x_init = npwrap(x_init)
    x = x_init.copy()
    y = x.copy()
    g = df(y.data)
    theta = 1.

    if not step_size:
        # barzilai-borwein step-size initialization:
        t = _bb_step_size(df, x, g)
    else:
        t = step_size

    if gen_plots:
        errs = np.zeros(max_iters)

    k = 0
    err1 = np.nan
    iter_str = 'iter num %i, norm(Gk)/(1+norm(xk)): %1.2e, step-size: %1.2e'
    for k in range(max_iters):
        if not quiet and k % 100 == 0:
            print(iter_str % (k, err1, t))

        x_old = x.copy()
        y_old = y.copy()

        # gradient step from the extrapolated point, then the proximal step
        x = y - t * g
        if prox_h:
            x = ph(x.data, t)

        # Gk = (y - x) / t; residual is ||Gk|| relative to 1 + ||x||
        err1 = (y - x).norm() / (1 + x.norm()) / t
        if gen_plots:
            errs[k] = err1
        if err1 < eps:
            break

        # theta_{k+1} is the positive root of
        # theta_{k+1}^2 = (1 - theta_{k+1}) * theta_k^2
        if not use_gra:
            theta = 2. / (1 + np.sqrt(1 + 4 / (theta ** 2)))
        else:
            theta = 1.

        # gradient-based restart: if the step x - x_old points along the
        # direction y - x, roll back to x_old and reset the momentum
        if not use_gra and use_restart and (y - x).dot(x - x_old) > 0:
            if debug:
                print('restart, dg = %1.2e' % (y - x).dot(x - x_old))
            x = x_old.copy()
            y = x.copy()
            theta = 1.
        else:
            y = x + (1 - theta) * (x - x_old)

        g_old = g.copy()
        g = df(y.data)

        # tfocs-style backtracking:
        if not fixed_step_size:
            t_old = t
            dy = y - y_old
            dy_dg = abs(dy.dot(g_old - g))
            if dy_dg > 0:
                t_hat = 0.5 * (dy.norm() ** 2) / dy_dg
            else:
                # no curvature observed along dy: the estimate is unbounded,
                # so t may only grow (by alpha); avoids a division by zero
                t_hat = np.inf
            t = min(alpha * t, max(beta * t, t_hat))
            if debug and t_old > t:
                print('back-track, t = %1.2e, t_old = %1.2e, t_hat = %1.2e' % (t, t_old, t_hat))

    if not quiet:
        print(iter_str % (k, err1, t))
        print('terminated')

    if gen_plots:
        # errs[k] is written before any exit from the loop, so the recorded
        # history is errs[0..k]; slice once (previously sliced twice,
        # dropping the first two and the last residual)
        _plot_errors(errs[:k + 1])

    return x.data