Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 25 additions & 18 deletions mpax/solver_log.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,23 +187,32 @@ def display_problem_details(qp: QuadraticProgrammingProblem) -> None:
The quadratic programming problem object containing the matrix and vector details.
"""
if logging.root.level == logging.INFO:
if isinstance(qp.constraint_matrix, (BCOO, BCSR)):
constraint_matrix_nnz_count = len(qp.constraint_matrix.data)
constraint_matrix_data = qp.constraint_matrix.data
else:
constraint_matrix_nnz_count = jnp.count_nonzero(qp.constraint_matrix)
constraint_matrix_data = qp.constraint_matrix
if isinstance(qp.objective_matrix, (BCOO, BCSR)):
objective_matrix_data = qp.objective_matrix.data
else:
objective_matrix_data = qp.objective_matrix
jax_debug_log(
"There are {:d} variables, {:d} constraints (including {:d} equalities) and {:d} nonzero coefficients.",
qp.constraint_matrix.shape[1],
qp.constraint_matrix.shape[0],
qp.num_variables,
qp.num_constraints,
qp.num_equalities,
len(qp.constraint_matrix.data),
constraint_matrix_nnz_count,
logger=logger,
level=logging.INFO,
)

nz_constraints = qp.constraint_matrix.data
jax_debug_log(
"Absolute value of nonzero constraint matrix elements:\n"
" largest={:.6f}, smallest={:.6f}, avg={:.6f}",
jnp.max(jnp.abs(nz_constraints)),
jnp.min(jnp.abs(nz_constraints)),
jnp.mean(jnp.abs(nz_constraints)),
jnp.max(jnp.abs(constraint_matrix_data), initial=0),
jnp.min(jnp.abs(constraint_matrix_data), initial=0),
jnp.mean(jnp.abs(constraint_matrix_data)),
logger=logger,
level=logging.INFO,
)
Expand All @@ -221,17 +230,15 @@ def display_problem_details(qp: QuadraticProgrammingProblem) -> None:
level=logging.INFO,
)

if len(qp.objective_matrix.data) > 0:
nz_objectives = qp.objective_matrix.data
jax_debug_log(
"Absolute value of objective matrix elements:"
" largest={:.6f}, smallest={:.6f}, avg={:.6f}",
jnp.max(jnp.abs(nz_objectives)),
jnp.min(jnp.abs(nz_objectives)),
jnp.mean(jnp.abs(nz_objectives)),
logger=logger,
level=logging.INFO,
)
jax_debug_log(
"Absolute value of objective matrix elements:"
" largest={:.6f}, smallest={:.6f}, avg={:.6f}",
jnp.max(jnp.abs(objective_matrix_data), initial=0),
jnp.min(jnp.abs(objective_matrix_data), initial=0),
jnp.mean(jnp.abs(objective_matrix_data)),
logger=logger,
level=logging.INFO,
)

jax_debug_log(
"Absolute value of objective vector elements:\n"
Expand Down
11 changes: 11 additions & 0 deletions tests/rapdhg_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,17 @@ def test_rapdhg_lp_with_jit():
assert pytest.approx(result.primal_objective, rel=1e-2) == expected_obj


def test_rapdhg_lp_with_jit_dense_matrix():
"""Test the raPDHG solver on a sample LP problem."""
for model_filename, expected_obj in lp_model_objs.items():
gurobi_model = gp.read(pytest_cache_dir + "/" + model_filename)
qp = create_qp_from_gurobi(gurobi_model, use_sparse_matrix=False)
solver = raPDHG(eps_abs=1e-6, eps_rel=1e-6, verbose=True)
result = solver.optimize(qp)

assert pytest.approx(result.primal_objective, rel=1e-2) == expected_obj


def test_rapdhg_qp_with_jit():
"""Test the raPDHG solver on a sample LP problem."""
for model_filename, expected_obj in qp_model_objs.items():
Expand Down