causalexplain.estimators.rex package
Submodules
- class Knowledge(rex, ref_graph)
Bases: object
This class collects everything we know about each edge in the proposed graph in terms of the following properties:
origin: the origin node
target: the target node
is_edge: whether the edge is in the reference graph
is_root_cause: whether the origin is a root cause
is_leaf_node: whether the origin is a leaf node
correlation: the correlation between the individual SHAP values and the origin node
KS_pval: the p-value of the Kolmogorov-Smirnov test between the origin and the target
shap_edge: whether the edge is in the graph constructed after evaluating mean SHAP values
shap_skedastic_pval: the p-value of the skedastic test for the SHAP values
parent_skedastic_pval: the p-value of the skedastic test for the parent values
mean_shap: the mean of the SHAP values between the origin and the target
slope_shap: the slope of the linear regression for target vs. SHAP values
slope_target: the slope of the linear regression for the target vs. origin values
potential_root: whether the origin is a potential root cause
regression_err: the regression error of the origin to the target
err_contrib: the error contribution of the origin to the target
con_ind_pval: the p-value of the conditional independence test between the origin and the target
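To make a few of these properties concrete, here is a minimal sketch of how such a per-edge record could be assembled from raw values. It is not the library's implementation: `EdgeRecord`, `pearson`, `ols_slope` and `describe_edge` are hypothetical names, only a subset of the properties is covered, and Pearson correlation is an assumption about how `correlation` might be computed.

```python
from dataclasses import dataclass
from math import sqrt
from statistics import mean


@dataclass
class EdgeRecord:
    """Hypothetical per-edge record holding a subset of the listed properties."""
    origin: str
    target: str
    is_edge: bool
    correlation: float
    mean_shap: float
    slope_target: float


def pearson(x, y):
    """Pearson correlation between two equally long sequences."""
    mx, my = mean(x), mean(y)
    sxy = sum((a - mx) * (b - my) for a, b in zip(x, y))
    sxx = sum((a - mx) ** 2 for a in x)
    syy = sum((b - my) ** 2 for b in y)
    return sxy / sqrt(sxx * syy)


def ols_slope(x, y):
    """Slope of the least-squares line that regresses y on x."""
    mx, my = mean(x), mean(y)
    sxy = sum((a - mx) * (b - my) for a, b in zip(x, y))
    sxx = sum((a - mx) ** 2 for a in x)
    return sxy / sxx


def describe_edge(origin, target, origin_values, target_values,
                  shap_values, ref_edges):
    """Collect the illustrative properties for the edge origin -> target."""
    return EdgeRecord(
        origin=origin,
        target=target,
        is_edge=(origin, target) in ref_edges,
        correlation=pearson(shap_values, origin_values),
        mean_shap=mean(shap_values),
        slope_target=ols_slope(origin_values, target_values),
    )


# Toy values, invented for illustration only.
record = describe_edge(
    "A", "B",
    origin_values=[1.0, 2.0, 3.0, 4.0],
    target_values=[2.0, 4.0, 6.0, 8.0],
    shap_values=[0.1, 0.2, 0.3, 0.4],
    ref_edges={("A", "B")},
)
```

In this toy case the target is exactly twice the origin, so `record.slope_target` is 2 and the SHAP values are perfectly correlated with the origin.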
Main class for the REX estimator. (C) J. Renero, 2022, 2023, 2024, 2025
- class Rex(name, model_type='nn', explainer='gradient', tune_model=False, correlation_th=None, corr_method='spearman', corr_alpha=0.6, corr_clusters=15, condlen=1, condsize=0, mean_pi_percentile=0.8, discrepancy_threshold=0.99, hpo_n_trials=20, bootstrap_trials=20, bootstrap_sampling_split='auto', bootstrap_tolerance='auto', bootstrap_parallel_jobs=0, parallel_jobs=0, verbose=False, prog_bar=True, silent=False, shap_fsize=(10, 10), dpi=75, pdf_filename=None, random_state=1234, **kwargs)
Bases: BaseEstimator, ClassifierMixin
Regression with Explainability (Rex) is a causal discovery method that uses a regression model to predict the outcome of a treatment and uses explainability to identify the causal variables.
- Parameters:
demo_param (str, default='demo_param') – A parameter used to demonstrate how to pass and store parameters.
Examples
>>> import numpy as np
>>> import pandas as pd
>>> from sklearn.preprocessing import StandardScaler
>>> from causalexplain.estimators.rex import Rex
>>> # `utils` and `input_path` are assumed to be defined by the caller
>>> dataset_name = 'rex_generated_linear_0'
>>> ref_graph = utils.graph_from_dot_file(f"../data/{dataset_name}.dot")
>>> data = pd.read_csv(f"{input_path}{dataset_name}.csv")
>>> scaler = StandardScaler()
>>> data = pd.DataFrame(scaler.fit_transform(data), columns=data.columns)
>>> train = data.sample(frac=0.8, random_state=42)
>>> test = data.drop(train.index)
>>> rex = Rex(
...     name=dataset_name, tune_model=False,
...     model_type='nn', explainer='gradient')
>>> rex.fit_predict(train, test, ref_graph)
- shaps = None
- hierarchies = None
- pi = None
- models = None
- dag = None
- indep = None
- __init__(name, model_type='nn', explainer='gradient', tune_model=False, correlation_th=None, corr_method='spearman', corr_alpha=0.6, corr_clusters=15, condlen=1, condsize=0, mean_pi_percentile=0.8, discrepancy_threshold=0.99, hpo_n_trials=20, bootstrap_trials=20, bootstrap_sampling_split='auto', bootstrap_tolerance='auto', bootstrap_parallel_jobs=0, parallel_jobs=0, verbose=False, prog_bar=True, silent=False, shap_fsize=(10, 10), dpi=75, pdf_filename=None, random_state=1234, **kwargs)
- is_fitted_ = False
- fit(X, y=None, pipeline=None)
Fit the model according to the given training data.
- Parameters:
X ({array-like, sparse matrix}, shape (n_samples, n_features)) – Training vector, where n_samples is the number of samples and n_features is the number of features.
y (array-like, shape (n_samples,)) – Target vector relative to X.
- Returns:
self – The fitted estimator.
- Return type:
Rex
- predict(X, ref_graph=None, prior=None, pipeline=None)
Predicts the causal graph from the given data.
- Parameters:
X – The input samples.
ref_graph – The reference graph, or ground truth.
prior – The prior to use for building the DAG. It is a list of lists of node/feature names, ordered according to a temporal structure: the first list contains the nodes to be considered as root causes, the second list contains the nodes to be considered as potential effects of the first set and of the other nodes in the second list, and so on. The number of lists in the prior is the depth of the conditioning sequence. The prior imposes two rules: only the nodes in the first list can be root causes, and the nodes in a given list cannot be causes of the nodes in any previous list. If the prior is not provided, the DAG is built without any prior information. Example: [['A', 'B'], ['C', 'D']]
- Returns:
- G_final – The final graph, after the correction stage.
- Return type:
nx.DiGraph
Examples
In the following example, where four features are used, the prior is defined as [['A', 'B'], ['C', 'D']], which means that the first set of features to be considered as root causes is 'A' and 'B', and the second set of features, to be considered as potential effects of the first set, is 'C' and 'D'.
The resulting DAG cannot contain any edge from 'C' or 'D' to 'A' or 'B'.
>>> rex.predict(X_test, ref_graph, prior=[['A', 'B'], ['C', 'D']])
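The ordering rule that a prior imposes can be sketched in a few lines of plain Python. This is an illustration of the constraint as described here, not the code Rex runs internally: `tier_of`, `violates_prior` and `filter_edges` are hypothetical helpers, and the sketch only checks the "no edge from a later list to an earlier list" rule, not the restriction on root causes.

```python
def tier_of(prior):
    """Map every node name to the index of the prior list that contains it."""
    return {node: index for index, tier in enumerate(prior) for node in tier}


def violates_prior(edge, prior):
    """True when the cause sits in a later list than the effect."""
    tiers = tier_of(prior)
    cause, effect = edge
    return tiers[cause] > tiers[effect]


def filter_edges(edges, prior):
    """Keep only the edges that respect the temporal ordering of the prior."""
    return [edge for edge in edges if not violates_prior(edge, prior)]


prior = [['A', 'B'], ['C', 'D']]
# Candidate edges invented for illustration.
candidate_edges = [('A', 'C'), ('C', 'A'), ('D', 'B'), ('B', 'D'), ('C', 'D')]
allowed = filter_edges(candidate_edges, prior)
# ('C', 'A') and ('D', 'B') point from the second list back to the first,
# so they are dropped; edges within the same list, such as ('C', 'D'), remain.
```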
- fit_predict(train, test, ref_graph, prior=None)
Fit the model according to the given training data and predict the causal graph from the test data.
- Parameters:
train ({array-like, sparse matrix}, shape (n_samples, n_features)) – Training vector, where n_samples is the number of samples and n_features is the number of features.
test ({array-like, sparse matrix}, shape (n_samples, n_features)) – Test vector, where n_samples is the number of samples and n_features is the number of features.
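ref_graph (nx.DiGraph) – The reference graph, or ground truth.
prior (list of lists of str, optional) – The prior to use for building the DAG, as described for predict.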
- Returns:
G_final – The final graph, after the correction stage.
- Return type:
nx.DiGraph
- bootstrap(X, ref_graph=None, num_iterations=20, sampling_split=0.2, prior=None, random_state=1234, tolerance=0.3, key_metric='f1', direction='maximize', parallel_jobs=0)
Finds the best tolerance value for the iterative predict method by iterating over different tolerance values and selecting the one that gives the best key_metric with respect to the reference graph.
- Parameters:
ref_graph (nx.DiGraph) – The reference graph to evaluate the key metric against.
target (str, optional) – The target DAG to evaluate. Defaults to 'shap'. Possible values: 'shap', 'rho', 'adjusted', 'perm_imp', 'indep', and 'final'.
key_metric (str, optional) – The key metric to evaluate. Defaults to 'f1'. Possible values: 'f1', 'precision', 'recall', 'shd', 'sid', 'aupr', 'Tp', 'Tn', 'Fp', 'Fn'.
direction (str, optional) – The direction of the key metric. Defaults to 'maximize'. Possible values: 'maximize' or 'minimize'.
parallel_jobs (int, optional) – Number of processes to run the iterations in parallel. Defaults to 0.
- Returns:
The best DAG found by the iterative predict method.
- Return type:
nx.DiGraph
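The selection step described for bootstrap (try several tolerance values, keep the one with the best key_metric in the given direction) might look roughly like the sketch below. It is an assumption-laden illustration rather than the library's code: `select_best` is a hypothetical helper, and the per-tolerance scores are invented placeholders standing in for metrics computed against a reference graph.

```python
def select_best(candidates, evaluate, direction="maximize"):
    """Return (best_candidate, best_score) under the requested direction."""
    if direction not in ("maximize", "minimize"):
        raise ValueError("direction must be 'maximize' or 'minimize'")
    pick = max if direction == "maximize" else min
    scored = [(evaluate(candidate), candidate) for candidate in candidates]
    best_score, best_candidate = pick(scored, key=lambda pair: pair[0])
    return best_candidate, best_score


tolerances = [0.1, 0.3, 0.5]
# Placeholder scores per tolerance; in practice each value would come from
# scoring the DAG predicted with that tolerance against the reference graph.
f1_by_tolerance = {0.1: 0.6, 0.3: 0.8, 0.5: 0.7}
shd_by_tolerance = {0.1: 4, 0.3: 3, 0.5: 1}

best_f1 = select_best(tolerances, f1_by_tolerance.get, direction="maximize")
best_shd = select_best(tolerances, shd_by_tolerance.get, direction="minimize")
```

Metrics such as F1, precision and recall are maximized, while distance-like metrics such as SHD are minimized, which is why the direction needs to be passed alongside the metric name.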
- score(ref_graph, predicted_graph='final')
Obtains the score of the predicted graph against the reference graph. The score contains different metrics, such as the precision, recall, F1-score, SHD or SID.
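As a rough illustration of the edge-level metrics mentioned here, the following sketch computes precision, recall, F1 and a simple SHD from two sets of directed edges. It is not the scoring code used by Rex: `edge_metrics` is a hypothetical helper, the graphs are assumed to have no self-loops, and the SHD convention chosen (a reversed edge counts as one difference) is only one of several in use.

```python
def edge_metrics(ref_edges, pred_edges):
    """Compare two collections of directed edges given as (cause, effect)."""
    ref, pred = set(ref_edges), set(pred_edges)
    tp = len(ref & pred)   # edges present in both graphs
    fp = len(pred - ref)   # predicted edges absent from the reference
    fn = len(ref - pred)   # reference edges that were missed
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = (2 * precision * recall / (precision + recall)
          if precision + recall else 0.0)

    # SHD: every node pair whose connection differs counts once, so a
    # reversed edge adds 1 (assumed convention; assumes no self-loops).
    shd = 0
    for pair in {frozenset(edge) for edge in ref | pred}:
        u, v = tuple(pair)
        ref_state = ((u, v) in ref, (v, u) in ref)
        pred_state = ((u, v) in pred, (v, u) in pred)
        shd += ref_state != pred_state

    return {"precision": precision, "recall": recall, "f1": f1, "shd": shd}


# Toy graphs, invented for illustration.
reference = [("A", "B"), ("B", "C"), ("A", "D")]
predicted = [("A", "B"), ("C", "B"), ("B", "D")]
metrics = edge_metrics(reference, predicted)
```

In this toy case one edge matches, one is reversed, one is missing and one is extra, so precision, recall and F1 are all 1/3 and the SHD is 3.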
- compute_regression_quality()
Compute the regression quality for each feature in the dataset.
- Returns:
A set of features that are considered as root causes.
- Return type:
set
- main(dataset_name, input_path='/Users/renero/phd/data/RC4/', output_path='/Users/renero/phd/output/RC4/', load_model=False, fit_model=True, predict_model=True, scale_data=False, tune_model=False, model_type='nn', explainer='gradient', save=False)
Custom main function to run the pipeline with the given dataset. Especially useful for testing and debugging.
Module contents
REX (Regression with Explainability) estimator module.
This module provides the REX estimator, which uses regression models and explainability techniques to discover causal relationships in data.
The classes Rex and Knowledge are re-exported at the package level; both are documented in full under Submodules above.