Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merge LPGD into diffcp #67

Open
wants to merge 22 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ These inputs must conform to the [SCS convention](https://github.com/bodono/scs-

The values in `cone_dict` denote the sizes of each cone; the values of `diffcp.SOC`, `diffcp.PSD`, and `diffcp.EXP` should be lists. The order of the rows of `A` must match the ordering of the cones given above. For more details, consult the [SCS documentation](https://github.com/cvxgrp/scs/blob/master/README.md).

To enable [Lagrangian Proximal Gradient Descent (LPGD)](https://arxiv.org/abs/2407.05920) differentiation of the conic program based on efficient finite-differences, provide the `mode=LPGD` option along with the argument `derivative_kwargs=dict(tau=0.1, rho=0.1)` to specify the perturbation and regularization strength. Alternatively, the derivative kwargs can also be passed directly to the returned `derivative` and `adjoint_derivative` function.
To enable [Lagrangian Proximal Gradient Descent (LPGD)](https://arxiv.org/abs/2407.05920) differentiation of the conic program based on efficient finite-differences, provide the `mode="lpgd"` option along with the argument `derivative_kwargs=dict(tau=0.1, rho=0.1)` to specify the perturbation and regularization strength. Alternatively, the derivative kwargs can also be passed directly to the returned `derivative` and `adjoint_derivative` functions.

#### Return value
The function `solve_and_derivative` returns a tuple
Expand Down
163 changes: 102 additions & 61 deletions diffcp/cone_program.py

Large diffs are not rendered by default.

18 changes: 12 additions & 6 deletions diffcp/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def embed_problem(A, b, c, P, cone_dict):
return A_emb, b_emb, c_emb, P_emb, cone_dict_emb


def compute_perturbed_solution(dA, db, dc, tau, rho, A, b, c, P, cone_dict, x, y, s, solver_kwargs, solve_method, solve_internal):
def compute_perturbed_solution(dA, db, dc, dP, tau, rho, A, b, c, P, cone_dict, x, y, s, solver_kwargs, solve_method, solve_internal):
"""
Computes the perturbed solution x_right, y_right, s_right at (A, b, c) with
perturbations dA, db, dc and optional regularization rho.
Expand All @@ -92,6 +92,8 @@ def compute_perturbed_solution(dA, db, dc, tau, rho, A, b, c, P, cone_dict, x, y
pattern as the matrix `A` from the cone program
db: NumPy array representing perturbation in `b`
dc: NumPy array representing perturbation in `c`
dP: (optional) SciPy sparse matrix in CSC format; must have same sparsity
pattern as the matrix `P` from the cone program
tau: Perturbation strength parameter
rho: Regularization strength parameter
Returns:
Expand All @@ -105,16 +107,20 @@ def compute_perturbed_solution(dA, db, dc, tau, rho, A, b, c, P, cone_dict, x, y
A_pert = A + tau * dA
b_pert = b + tau * db
c_pert = c + tau * dc
if dP is not None:
P_pert = P + tau * dP
else:
P_pert = P

# Regularize: Effectively adds a rho/2 |x-x^*|^2 term to the objective
P_reg = regularize_P(P, rho=rho, size=n)
P_pert_reg = regularize_P(P_pert, rho=rho, size=n)
c_pert_reg = c_pert - rho * x

# Set warm start
warm_start = (x, y, s)
warm_start = (x, y, s) if solve_method not in ["ECOS", "Clarabel"] else None

# Solve the perturbed problem
result_pert = solve_internal(A=A_pert, b=b_pert, c=c_pert_reg, P=P_reg, cone_dict=cone_dict,
result_pert = solve_internal(A=A_pert, b=b_pert, c=c_pert_reg, P=P_pert_reg, cone_dict=cone_dict,
solve_method=solve_method, warm_start=warm_start, **solver_kwargs)
# Extract the solutions
x_pert, y_pert, s_pert = result_pert["x"], result_pert["y"], result_pert["s"]
Expand Down Expand Up @@ -169,7 +175,7 @@ def compute_adjoint_perturbed_solution(dx, dy, ds, tau, rho, A, b, c, P, cone_di
c_pert_reg = c_pert - rho * x

# Set warm start
warm_start = (x, y, s) if solve_method != "ECOS" else None
warm_start = (x, y, s) if solve_method not in ["ECOS", "Clarabel"] else None

# Solve the perturbed problem
# Note: In special case solve_method=='SCS' and rho==0, this could be sped up strongly by using solver.update
Expand All @@ -190,7 +196,7 @@ def compute_adjoint_perturbed_solution(dx, dy, ds, tau, rho, A, b, c, P, cone_di
c_emb_pert_reg = c_emb_pert - rho * np.hstack([x, s])

# Set warm start
if solve_method == "ECOS":
if solve_method in ["ECOS", "Clarabel"]:
warm_start = None
else:
warm_start = (np.hstack([x, s]), np.hstack([y, y]), np.hstack([s, s]))
PTNobel marked this conversation as resolved.
Show resolved Hide resolved
Expand Down
47 changes: 47 additions & 0 deletions examples/batch_clarabel_lpgd_quadratic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import diffcp
import time
import numpy as np
import scipy.sparse as sparse

# Sizes handed to the problem generator below.
m = 100
n = 50

batch_size = 16
n_jobs = 2

# Build `batch_size` independent problem instances; each list holds one entry
# per instance and is later passed to the batched solver.
As, bs, cs, Ks, Ps = [], [], [], [], []
for _ in range(batch_size):
    A, b, c, K = diffcp.utils.least_squares_eq_scs_data(m, n)
    # Empty (all-zero) quadratic term with one row/column per entry of `c`,
    # reduced to its upper triangle and stored in CSC format.
    P = sparse.triu(sparse.csc_matrix((c.size, c.size))).tocsc()
    As.append(A)
    bs.append(b)
    cs.append(c)
    Ks.append(K)
    Ps.append(P)

def time_function(f, N=1):
    """Time repeated calls of ``f`` and summarize the durations.

    Args:
        f: Zero-argument callable to benchmark. Its return value is ignored.
        N: Number of times ``f`` is called (default 1).

    Returns:
        A tuple ``(mean, std)`` with the mean and standard deviation, in
        seconds, of the ``N`` measured call durations.
    """
    durations = []
    for _ in range(N):
        # perf_counter is monotonic and high-resolution, so a measured
        # duration can never be negative or distorted by system clock
        # adjustments (unlike time.time()).
        start = time.perf_counter()
        f()
        durations.append(time.perf_counter() - start)
    return np.mean(durations), np.std(durations)

# Benchmark the batched LPGD solve and its adjoint for 1 to 7 parallel jobs.
for n_jobs in range(1, 8):
    def solve_batch(n_jobs_forward):
        # Single place for the solver configuration; only the number of
        # forward jobs differs between the timed call and the setup call.
        return diffcp.solve_and_derivative_batch(
            As, bs, cs, Ks,
            n_jobs_forward=n_jobs_forward,
            n_jobs_backward=n_jobs,
            solve_method="Clarabel",
            verbose=False,
            mode="lpgd",
            derivative_kwargs=dict(tau=1e-3, rho=0.1),
            Ps=Ps,
        )

    def f_forward():
        return solve_batch(n_jobs)

    # Solve once with a single forward job to obtain the adjoint callable
    # (and solutions to feed into it) for the backward timing.
    xs, ys, ss, D_batch, DT_batch = solve_batch(1)

    def f_backward():
        DT_batch(xs, ys, ss, return_dP=True)

    mean_forward, std_forward = time_function(f_forward)
    mean_backward, std_backward = time_function(f_backward)
    timings = (n_jobs, mean_forward, std_forward, mean_backward, std_backward)
    print("%03d | %4.4f +/- %2.2f | %4.4f +/- %2.2f" % timings)
40 changes: 40 additions & 0 deletions examples/clarabel_example_lpgd_quadratic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import diffcp

import numpy as np
import utils
from scipy import sparse

np.set_printoptions(precision=5, suppress=True)


# Cone for the random program: the product of a 3-d zero cone,
# a 3-d positive orthant cone, and one 5-d second order cone.
K = dict(z=3, l=3, q=[5])

# Total cone dimension (3 + 3 + 5) and number of variables.
m = K['z'] + K['l'] + sum(K['q'])
n = 5

# Fix the seed so the generated program is reproducible.
np.random.seed(0)

A, b, c = utils.random_cone_prog(m, n, K)
# Empty (all-zero) quadratic term, kept as an upper-triangular CSC matrix.
P = sparse.triu(sparse.csc_matrix((c.size, c.size))).tocsc()

# Solve the cone program with Clarabel in LPGD mode; besides the solution this
# returns the derivative and its adjoint as callables.
x, y, s, derivative, adjoint_derivative = diffcp.solve_and_derivative(
    A, b, c, K,
    P=P,
    solve_method="Clarabel",
    verbose=False,
    mode="lpgd",
    derivative_kwargs=dict(tau=0.1, rho=0.0),
)

print("x =", x)
print("y =", y)
print("s =", s)

# Adjoint derivative: perturb x by c and leave y, s unperturbed; also request
# the perturbation with respect to P.
dA, db, dc, dP = adjoint_derivative(
    dx=c, dy=np.zeros(m), ds=np.zeros(m), return_dP=True)

# Forward derivative, evaluated on dummy perturbations (the data itself).
dx, dy, ds = derivative(dA=A, db=b, dc=c, dP=P)
2 changes: 1 addition & 1 deletion examples/dual_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@


# We generate a random cone program with a cone
# defined as a product of a 3-d fixed cone, 3-d positive orthant cone,
# defined as a product of a 3-d zero cone, 3-d positive orthant cone,
# and a 5-d second order cone.
K = {
'z': 3,
Expand Down
2 changes: 1 addition & 1 deletion examples/dual_example_lpgd.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@


# We generate a random cone program with a cone
# defined as a product of a 3-d fixed cone, 3-d positive orthant cone,
# defined as a product of a 3-d zero cone, 3-d positive orthant cone,
# and a 5-d second order cone.
K = {
'z': 3,
Expand Down
2 changes: 1 addition & 1 deletion examples/ecos_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@


# We generate a random cone program with a cone
# defined as a product of a 3-d fixed cone, 3-d positive orthant cone,
# defined as a product of a 3-d zero cone, 3-d positive orthant cone,
# and a 5-d second order cone.
K = {
'z': 3,
Expand Down
2 changes: 1 addition & 1 deletion examples/ecos_example_lpgd.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@


# We generate a random cone program with a cone
# defined as a product of a 3-d fixed cone, 3-d positive orthant cone,
# defined as a product of a 3-d zero cone, 3-d positive orthant cone,
# and a 5-d second order cone.
K = {
'z': 3,
Expand Down