# Source code for causalexplain.estimators.cam.train_lasso

"""
This Python version accomplishes the same task as the R function:

1. It uses `LassoCV` from scikit-learn to perform cross-validation and find the 
optimal regularization parameter (lambda in R, alpha in Python).
2. It then trains a final Lasso model using the optimal alpha.
3. The function returns a dictionary containing the fitted values, residuals, and 
the trained model.

Note that:
- The `cv.glmnet` in R is replaced by `LassoCV` in Python.
- The `glmnet` in R is replaced by `Lasso` in Python.
- In scikit-learn, the regularization parameter is called `alpha` instead of `lambda`.
- The `pars` parameter is kept for consistency, but it's not used in this 
implementation. You can extend the function to use additional parameters if needed.
- The cross-validation is hard-coded to 5 folds (`cv=5`); `pars` does not change it.
- A random state is set for reproducibility.

This Python version should provide equivalent functionality to the original R function.
"""
from typing import Dict, Any
from sklearn.linear_model import LassoCV, Lasso
import numpy as np


def train_lasso(X, y, pars=None) -> Dict[str, Any]:
    """Fit a Lasso regression whose penalty is chosen by cross-validation.

    A ``LassoCV`` with 5 folds selects the regularization strength, then a
    plain ``Lasso`` is refit on the same data using that strength.

    Args:
        X: Feature matrix, passed unchanged to the scikit-learn estimators.
        y: Target values, passed unchanged to the scikit-learn estimators.
        pars: Accepted for interface compatibility; currently unused.

    Returns:
        A dictionary with three entries:
            ``'Yfit'``: predictions of the final model on ``X``.
            ``'residuals'``: ``y`` minus those predictions.
            ``'model'``: the fitted ``Lasso`` estimator.
    """
    # Cross-validate to pick the penalty (lambda in R, alpha in scikit-learn).
    cv_model = LassoCV(cv=5, random_state=42)
    cv_model.fit(X, y)

    # Refit a plain Lasso with the selected penalty; this is the model that
    # is handed back to the caller.
    mod = Lasso(alpha=cv_model.alpha_)
    mod.fit(X, y)

    # Predict once and reuse the result for both the fitted values and the
    # residuals instead of running the model over X twice.
    y_fit = mod.predict(X)

    return {
        'Yfit': y_fit,
        'residuals': y - y_fit,
        'model': mod,
    }