feat(linalg): add weighted least squares solver#1104
feat(linalg): add weighted least squares solver#1104aamrindersingh wants to merge 8 commits intofortran-lang:masterfrom
Conversation
There was a problem hiding this comment.
Pull request overview
Adds an experimental weighted least-squares API to stdlib_linalg, enabling WLS solves with per-observation (diagonal) positive weights via row-scaling and reuse of the existing SVD-based lstsq backend.
Changes:
- Introduces
weighted_lstsq(w, A, b [, cond, overwrite_a, rank, err])for real/complex systems with real positive weights. - Adds unit tests and a new example program demonstrating outlier downweighting.
- Documents the new interface in the linalg spec.
Reviewed changes
Copilot reviewed 6 out of 6 changed files in this pull request and generated 4 comments.
Show a summary per file
| File | Description |
|---|---|
| test/linalg/test_linalg_lstsq.fypp | Adds WLS tests for uniform weights, solution-change under nonuniform weights, and negative-weight error handling. |
| src/linalg/stdlib_linalg_least_squares.fypp | Implements weighted least-squares by scaling rows of A/b with sqrt(w) and calling existing lstsq. |
| src/linalg/stdlib_linalg.fypp | Exposes weighted_lstsq as a public generic interface with documentation comments. |
| example/linalg/example_weighted_lstsq.f90 | Adds a runnable example demonstrating downweighting an outlier. |
| example/linalg/CMakeLists.txt | Registers the new example in the example build. |
| doc/specs/stdlib_linalg.md | Adds a new spec section for weighted_lstsq. |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
There was a problem hiding this comment.
Pull request overview
Copilot reviewed 6 out of 6 changed files in this pull request and generated 2 comments.
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
There was a problem hiding this comment.
Pull request overview
Copilot reviewed 6 out of 6 changed files in this pull request and generated 1 comment.
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
There was a problem hiding this comment.
Pull request overview
Copilot reviewed 6 out of 6 changed files in this pull request and generated 2 comments.
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
There was a problem hiding this comment.
Pull request overview
Copilot reviewed 6 out of 6 changed files in this pull request and generated 3 comments.
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| call stdlib_linalg_${ri}$_solve_lstsq_one(amat, b_scaled, x, cond=cond, overwrite_a=.true., rank=rank, err=err) | ||
|
|
There was a problem hiding this comment.
Errors raised by the underlying solve_lstsq_one call will be reported as coming from lstsq (because solve_lstsq_one uses the submodule-level this='lstsq'). This means weighted_lstsq failures (e.g., SVD nonconvergence) will surface as [lstsq] ..., which is misleading for users/debugging.
Consider calling solve_lstsq_one with a local linalg_state_type variable, then rewriting where_at to weighted_lstsq (or otherwise wrapping the state) before propagating via linalg_error_handling, so both returned err and error stop messages correctly identify weighted_lstsq.
| call stdlib_linalg_${ri}$_solve_lstsq_one(amat, b_scaled, x, cond=cond, overwrite_a=.true., rank=rank, err=err) | |
| block | |
| type(linalg_state_type) :: local_err | |
| if (present(err)) then | |
| call stdlib_linalg_${ri}$_solve_lstsq_one(amat, b_scaled, x, cond=cond, overwrite_a=.true., rank=rank, err=local_err) | |
| local_err%where_at = 'weighted_lstsq' | |
| err = local_err | |
| else | |
| call stdlib_linalg_${ri}$_solve_lstsq_one(amat, b_scaled, x, cond=cond, overwrite_a=.true., rank=rank, err=local_err) | |
| end if | |
| end block |
There was a problem hiding this comment.
@jvdp1 Need some inputs on this :-
right now weighted_lstsq validation errors say "weighted_lstsq" but internal errors say "lstsq". Should I change it so all errors just say "lstsq" for consistency or keep it as is?
There was a problem hiding this comment.
I ain't too familiar with how the linalg_error_type system works. @perazz : any input on this?
| !> Result array x[n] | ||
| ${rt}$, allocatable :: x(:) | ||
|
|
There was a problem hiding this comment.
New function result x is declared without the target attribute. In this codebase, linalg function results are typically declared allocatable, target (e.g., stdlib_linalg_${ri}$_lstsq_* in this same file uses ${rt}$, allocatable, target :: x...). Aligning the declaration avoids inconsistency and can prevent pointer-remapping surprises in compilers.
Suggestion: declare x as ${rt}$, allocatable, target :: x(:) here (and keep the interface in stdlib_linalg.fypp consistent).
| !> Result array x[n] | ||
| ${rt}$, allocatable :: x(:) | ||
| end function stdlib_linalg_${ri}$_weighted_lstsq |
There was a problem hiding this comment.
The weighted_lstsq interface declares the result as ${rt}$, allocatable :: x(:), but other linalg interfaces in this module consistently use allocatable, target for result arrays/matrices (e.g., lstsq result at src/linalg/stdlib_linalg.fypp around the existing allocatable, target :: x...). For consistency with the module’s established pattern (and to match the implementation if updated), consider declaring the result here as ${rt}$, allocatable, target :: x(:).
There was a problem hiding this comment.
target is to be used only if needed in the case one needs to set a pointer on top. If that's not the case, then please don't use it.
loiseaujc
left a comment
There was a problem hiding this comment.
I peaked a quick glance at your implementation. See some of my comments.
I'll try to take some time by the end of the week to go deeper into it. I've also enabled the CI/CD to run so that we can keep track of any test fail.
| ! Scale A column-wise (cache-friendly: column-major order) | ||
| do j = 1, n | ||
| amat(:, j) = sqrt_w(:) * a(:, j) | ||
| end do |
There was a problem hiding this comment.
This is a multiplication by a diagonal matrix. You can use the lascl2 function from lapack for that purpose.
There was a problem hiding this comment.
@loiseaujc ,
Looked into using lascl2 for diagnol matrix multiplication but it does not appear to be wrapped in stdlib LAPACK interface
Shall I proceed with opening a separate issue/PR for adding lascl2?
There was a problem hiding this comment.
@perazz : I've never really asked, but which version of lapackdid you use for the modernization? It seems like some of blas-like function in lapack v3.12.1 are no present in our port.
There was a problem hiding this comment.
A plain nested do loop is also fine here.
@loiseaujc the reference lapack was 3.10
There was a problem hiding this comment.
dlascl2 was introduced in 3.11 apparently, so that's the reason why. In any case, the code is as simple as
do concurrent( i=1:m, j=1:n)
x(i, j) = d(i) * x(i, j)
end doYou can probably keep it as is for the time being because it's so simple. Eventually we'll probably have to update the lapack module to 3.12 and we'll change these few lines at this time. May be you can simple add a comment mentioning dlascl2 so that we easily grep it and retrieve it whenever lapack will have been updated.
There was a problem hiding this comment.
Another option is to add a templated lascl2 implementation from LAPACK 3.11+ into our LAPACK module: it is only one routine, it should be manageable
There was a problem hiding this comment.
Sure enough. The code is extremely simple:
* =====================================================================
SUBROUTINE dlascl2( M, N, D, X, LDX )
*
* -- LAPACK computational routine --
* -- LAPACK is a software package provided by Univ. of Tennessee, --
* -- Univ. of California Berkeley, Univ. of Colorado Denver and NAG Ltd..--
*
* .. Scalar Arguments ..
INTEGER M, N, LDX
* ..
* .. Array Arguments ..
DOUBLE PRECISION D( * ), X( LDX, * )
* ..
*
* =====================================================================
*
* .. Local Scalars ..
INTEGER I, J
* ..
* .. Executable Statements ..
*
DO j = 1, n
DO i = 1, m
x( i, j ) = x( i, j ) * d( i )
END DO
END DO
RETURN
ENDI'm commuting right now and, out of curiosity, have started to look at the changelogs for lapack 3.10.1 and lapack 3.11.0. Nothing too complicated as far as I can see. May a bit tedious to port all of these changes, but fairly easy nonetheless.
There was a problem hiding this comment.
The code is extremely simple:
so many such cases; perhaps our implementation should have do concurrent internally. However, if libraries are going to pick it up in the future as part of the "external" API, maybe the future-proof choice is to use the same API?
There was a problem hiding this comment.
Using do concurrent internally makes sense. When you say the same api, you mean the same signature (i.e. dlascl2( M, N, D, X, LDX )) ?
There was a problem hiding this comment.
yep, same signature as the official API
Codecov Report❌ Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## master #1104 +/- ##
==========================================
- Coverage 68.51% 68.50% -0.01%
==========================================
Files 396 397 +1
Lines 12746 12755 +9
Branches 1376 1376
==========================================
+ Hits 8733 8738 +5
- Misses 4013 4017 +4 ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
perazz
left a comment
There was a problem hiding this comment.
Here is a summary of the outstanding edits/suggestions from the discussion above:
1a) add a procedure update_location to
stdlib/src/core/stdlib_error.fypp
Line 61 in 9591c0d
1b) in
linalg_error_handling, add a third optional parameter where_at and if (present(where_at)) update the location of the error message;1c) in
call stdlib_linalg_${ri}$_solve_lstsq_one(... err = err0) -> replace err with err0 and then call the error handler (err0, err, where_at='weighted_lstsq)2) create a templated
*lascl2 procedure, should be in the LAPACK's "blas-like" module
Resolves #1047 (1 of 2)
This PR adds the
weighted_lstsqinterface to stdlib_linalg for weighted least squares problems. It handles diagonal weight matrices where each observation has a different importance, common in heteroscedastic regression and outlier downweighting.The key design decisions are:
sqrt(w).overwrite_apattern fromsolve_luwherecopy_adefaults to.true.to preserve A unless the user explicitly opts into destruction for performance.Testing includes:
lstsqpatterns