Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 28 additions & 7 deletions solvers/python_pgd.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,48 @@
from math import sqrt
import numpy as np
from scipy import sparse


from benchopt import BaseSolver
from benchopt import safe_import_context


with safe_import_context() as import_ctx:
import cupy as cp

class Solver(BaseSolver):
name = 'Python-PGD' # proximal gradient, optionally accelerated

requirements = ['cupy']

# any parameter defined here is accessible as a class attribute
parameters = {'use_acceleration': [False, True]}
parameters = {'use_acceleration': [False, True],
'use_gpu': [False, True]}

def skip(self, X, y, lmbd):
if sparse.issparse(X) and self.use_gpu:
return True, "sparse is not supported with GPU"
return False, None

def set_objective(self, X, y, lmbd):
if self.use_gpu:
X, y = cp.array(X), cp.array(y)
self.X, self.y, self.lmbd = X, y, lmbd

def run(self, n_iter):
L = self.compute_lipschitz_cste()

xp = cp if self.use_gpu else np

n_features = self.X.shape[1]
w = np.zeros(n_features)
w = xp.zeros(n_features)
if self.use_acceleration:
z = np.zeros(n_features)
z = xp.zeros(n_features)

t_new = 1
for _ in range(n_iter):
if self.use_acceleration:
t_old = t_new
t_new = (1 + np.sqrt(1 + 4 * t_old ** 2)) / 2
t_new = (1 + sqrt(1 + 4 * t_old ** 2)) / 2
w_old = w.copy()
z -= self.X.T @ (self.X @ z - self.y) / L
w = self.st(z, self.lmbd / L)
Expand All @@ -35,18 +51,23 @@ def run(self, n_iter):
w -= self.X.T @ (self.X @ w - self.y) / L
w = self.st(w, self.lmbd / L)

if self.use_gpu:
w = cp.asnumpy(w)

self.w = w

def st(self, w, mu):
w -= np.clip(w, -mu, mu)
xp = cp if self.use_gpu else np
w -= xp.clip(w, -mu, mu)
return w

def get_result(self):
return self.w

def compute_lipschitz_cste(self, max_iter=100):
if not sparse.issparse(self.X):
return np.linalg.norm(self.X, ord=2) ** 2
xp = cp if self.use_gpu else np
return xp.linalg.norm(self.X, ord=2) ** 2

n, m = self.X.shape
if n < m:
Expand Down