# lasso

Pure Go implementation of LASSO and Elastic Net regression with an sklearn-compatible API.

An efficient implementation of LASSO (Least Absolute Shrinkage and Selection Operator) and Elastic Net regression using sequential coordinate descent optimization. Supports pure LASSO (L1), pure Ridge (L2), and any Elastic Net mix in between. Optimized for cache locality and numerical stability.

## Features
- ⚡ Sequential coordinate descent - Optimized for cache locality and performance
- 📉 L1 regularization - Automatic feature selection and model simplification
- 🔀 Elastic Net support - Combine L1 (LASSO) and L2 (Ridge) regularization
- 🎯 Early stopping - Terminates training when convergence is detected
- 📊 Metrics tracking - Records MSE, R², and weight deltas during training
- 🔧 Feature standardization - Automatic data preprocessing
- 📈 Comprehensive evaluation - Supports R², MSE, and MAE metrics
- 📝 Training history - Access detailed logs of each iteration
- ⚙️ Configurable parameters - Tune lambda, alpha, tolerance, and more
- 🔄 Cross-validation - K-fold CV for automatic lambda selection
- 💾 Model persistence - Save/Load models to JSON
## Requirements

- Go 1.25+

## Installation

```bash
go get github.com/causalgo/lasso
```

## Quick Start

```go
package main
import (
	"fmt"

	"github.com/causalgo/lasso"
	"gonum.org/v1/gonum/mat"
)
func main() {
	// Training data
	X := mat.NewDense(4, 2, []float64{
		1, 2,
		3, 4,
		5, 6,
		7, 8,
	})
	y := []float64{3, 7, 11, 15}

	// Configure training
	cfg := lasso.NewDefaultConfig()
	cfg.Lambda = 0.1   // Regularization strength
	cfg.Verbose = true // Enable training logs

	// Train model
	model, err := lasso.Fit(X, y, cfg)
	if err != nil {
		panic(err)
	}

	// Make predictions
	newX := mat.NewDense(2, 2, []float64{
		2, 3,
		4, 5,
	})
	predictions := model.Predict(newX)
	fmt.Println("Predictions:", predictions) // [5.0001, 9.0000]

	// Evaluate model
	score := model.Score(X, y)
	fmt.Printf("R² score: %.4f\n", score) // 1.0000
}
```

## Configuration

```go
cfg := &lasso.Config{
	Lambda:      0.05, // Regularization parameter
	Alpha:       1.0,  // Elastic Net mixing: 1.0=LASSO, 0.0=Ridge, 0.5=balanced
	MaxIter:     2000, // Maximum iterations
	Tol:         1e-5, // Convergence tolerance
	Standardize: true, // Standardize features
	Verbose:     true, // Show training logs
	LogStep:     50,   // Log every 50 iterations
	EarlyStop:   true, // Enable early stopping
	StopAfter:   15,   // Stop after 15 iterations without improvement
	MinDelta:    1e-5, // Minimum improvement for early stopping
}
```
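The three early-stopping fields act together: training halts once the tracked loss has failed to improve on its best value by at least `MinDelta` for `StopAfter` consecutive iterations. A minimal, self-contained sketch of that rule, assuming this reading of the fields (illustrative only, not the package's internals):

```go
package main

import "fmt"

func main() {
	const (
		minDelta  = 1e-5 // mirrors cfg.MinDelta
		stopAfter = 3    // small value so the toy trace stops quickly
	)
	// Toy loss trace that plateaus after the third value.
	losses := []float64{1.0, 0.5, 0.2, 0.2, 0.2, 0.2}
	best := losses[0]
	stale := 0
	for i, loss := range losses[1:] {
		if best-loss > minDelta { // meaningful improvement: reset the counter
			best = loss
			stale = 0
			continue
		}
		stale++
		if stale >= stopAfter {
			fmt.Println("early stop at iteration", i+1)
			return
		}
	}
	fmt.Println("ran to completion; best loss:", best)
}
```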
## Regularization

```go
// Pure LASSO (L1 only) - produces sparse solutions
cfg := lasso.NewDefaultConfig()
cfg.Lambda = 0.1
cfg.Alpha = 1.0 // Default: pure LASSO

// Elastic Net (L1 + L2 mix) - balanced regularization
cfg.Alpha = 0.5 // 50% L1, 50% L2

// Pure Ridge (L2 only) - no sparsity, all features stay active
cfg.Alpha = 0.0 // Pure Ridge regularization
model, err := lasso.Fit(X, y, cfg)
```

The Elastic Net objective function:

```
minimize: (1/2n) * ||y - Xw||² + λ * (α * ||w||₁ + (1-α) * ||w||₂²/2)
```
Where:

- α = 1.0: Pure LASSO - encourages sparsity (feature selection)
- α = 0.0: Pure Ridge - shrinks coefficients without sparsity
- 0 < α < 1: Elastic Net - balances sparsity and stability
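For intuition about how this objective is optimized, here is a compact, self-contained sketch of the cyclic coordinate-descent update it implies: each weight is refit against its partial residual and passed through soft thresholding, with the L2 term showing up as a shrinkage denominator. It assumes unit-variance columns and toy data, and illustrates the math rather than the package's internals:

```go
package main

import "fmt"

// softThreshold implements S(z, γ) = sign(z) · max(|z| − γ, 0).
func softThreshold(z, gamma float64) float64 {
	switch {
	case z > gamma:
		return z - gamma
	case z < -gamma:
		return z + gamma
	default:
		return 0
	}
}

func main() {
	// Toy standardized data: two identical unit-variance columns, so the
	// L2 term's role in handling correlated features is visible.
	X := [][]float64{{-1.34, -1.34}, {-0.45, -0.45}, {0.45, 0.45}, {1.34, 1.34}}
	y := []float64{-1.2, -0.4, 0.4, 1.2}
	w := make([]float64, 2)
	lambda, alpha := 0.1, 0.5
	n := float64(len(y))

	for sweep := 0; sweep < 20; sweep++ { // cyclic sweeps over coordinates
		for j := range w {
			var rho float64
			for i := range y {
				var pred float64
				for k := range w {
					pred += X[i][k] * w[k]
				}
				// Partial residual: add coordinate j's own contribution back.
				rho += X[i][j] * (y[i] - pred + w[j]*X[i][j])
			}
			rho /= n
			// Elastic Net update: L1 soft threshold, L2 shrinkage denominator.
			w[j] = softThreshold(rho, lambda*alpha) / (1 + lambda*(1-alpha))
		}
	}
	fmt.Printf("weights after 20 sweeps: %.4f\n", w)
}
```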
## Training History

```go
model, err := lasso.Fit(X, y, cfg)
if err != nil {
	panic(err)
}

// Analyze training progress
for _, log := range model.History {
	if log.Iteration%100 == 0 {
		fmt.Printf("Iter %d: MSE=%.4f R²=%.4f\n",
			log.Iteration, log.MSE, log.R2)
	}
}
```

## Model Persistence

```go
// Save model to JSON
err := model.Save("model.json")
if err != nil {
	panic(err)
}

// Load model from JSON
loadedModel, err := lasso.Load("model.json")
if err != nil {
	panic(err)
}
```
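The loaded model is immediately usable; continuing the snippet above (with `newX` as in the Quick Start):

```go
preds := loadedModel.Predict(newX)
fmt.Println("Predictions from loaded model:", preds)
```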
## Cross-Validation

```go
// Automatic lambda selection via k-fold cross-validation
result, err := lasso.CrossValidate(X, y, &lasso.CVConfig{
	Lambdas: []float64{0.001, 0.01, 0.1, 1.0}, // Or nil for auto-generation
	NFolds:  5,                                // 5-fold CV
	Scoring: "mse",                            // "mse", "r2", or "mae"
	Seed:    42,                               // For reproducibility
	Config:  lasso.NewDefaultConfig(),         // Base training config
})
if err != nil {
	panic(err)
}

fmt.Printf("Best lambda: %.4f\n", result.BestLambda)
fmt.Printf("Best score: %.4f\n", result.BestScore)

// Use the best model (trained on full data with the best lambda)
predictions := result.Model.Predict(newX)

// Access detailed CV results
for lambda, scores := range result.CVScores {
	fmt.Printf("Lambda %.4f: mean=%.4f, scores=%v\n",
		lambda, result.MeanScores[lambda], scores)
}
```
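When `Lambdas` is nil, the grid is generated automatically. A common scheme for such grids, sketched below with a hypothetical helper (not a function of this package), is a descending log-spaced path; log spacing concentrates candidates at small λ, where the fit is most sensitive:

```go
package main

import (
	"fmt"
	"math"
)

// logSpacedLambdas returns k (> 1) values descending log-uniformly from
// lambdaMax to eps*lambdaMax. Hypothetical helper for illustration only.
func logSpacedLambdas(lambdaMax, eps float64, k int) []float64 {
	grid := make([]float64, k)
	logMax := math.Log(lambdaMax)
	logMin := math.Log(eps * lambdaMax)
	for i := range grid {
		t := float64(i) / float64(k-1)
		grid[i] = math.Exp(logMax + t*(logMin-logMax))
	}
	return grid
}

func main() {
	fmt.Println(logSpacedLambdas(1.0, 1e-3, 5))
	// [1 0.01778... -> descends to 0.001]
}
```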
## Benchmarks

Run benchmarks locally:

```bash
go test -bench=. -run=^Benchmark -benchmem ./...
```

Key optimizations:
- Sequential coordinate descent - better cache locality than a parallel update
- `RawMatrix()` access - direct slice operations with minimal bounds-checking overhead (see the sketch after this list)
- Minimal allocations - buffers are reused via `predictInto()`
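As referenced in the list above, here is the general `RawMatrix()` access pattern with gonum, as a standalone illustration (not code from this package): slicing `Data` by `Stride` keeps the hot loop on contiguous memory instead of going through `At(i, j)` calls.

```go
package main

import (
	"fmt"

	"gonum.org/v1/gonum/mat"
)

func main() {
	X := mat.NewDense(2, 3, []float64{
		1, 2, 3,
		4, 5, 6,
	})

	// RawMatrix exposes the backing storage (blas64.General).
	raw := X.RawMatrix()
	var sum float64
	for i := 0; i < raw.Rows; i++ {
		// Each row is a contiguous sub-slice of the backing array.
		row := raw.Data[i*raw.Stride : i*raw.Stride+raw.Cols]
		for _, v := range row {
			sum += v
		}
	}
	fmt.Println("sum:", sum) // 21
}
```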
## Documentation

- API Reference - Full documentation on pkg.go.dev
- CONTRIBUTING.md - Development guide and Git workflow
- CHANGELOG.md - Release history
- ROADMAP.md - Development roadmap
- SECURITY.md - Security policy
## Contributing

Contributions are welcome! Please read our Contributing Guide for details on:
- Git-Flow branching model
- Commit message conventions
- Code quality standards
- Pull request requirements
Quick start:

```bash
git checkout develop
git checkout -b feature/amazing-feature
# Make changes...
go fmt ./... && golangci-lint run && go test -race ./...
git commit -m "feat: add amazing feature"
git push origin feature/amazing-feature
```

## License

This project is licensed under the MIT License - see the LICENSE file for details.
causalgo - Machine learning tools for causal analysis in Go