causalexplain package
Subpackages
- causalexplain.common package
- Submodules
BaseExperiment
Experiment
setup_plot()
add_grid()
subplots()
format_graph()
draw_graph_subplot()
cleanup_graph()
set_colormap()
dag2dot()
values_distribution()
correlation_matrix()
hierarchies()
dag()
dags()
shap_values()
shap_discrepancies()
deprecated_dags()
score_by_method()
scores_by_method()
score_by_subtype()
combined_metrics()
latex_table_by_datatype()
latex_table_by_method()
save_experiment()
load_experiment()
valid_output_name()
graph_from_dot_file()
graph_from_dictionary()
graph_from_adjacency()
graph_from_adjacency_file()
graph_to_adjacency()
graph_to_adjacency_file()
graph_to_dot_file()
select_device()
graph_intersection()
graph_union()
digraph_from_connected_features()
correct_edge_from_prior()
valid_candidates_from_prior()
break_cycles_using_prior()
potential_misoriented_edges()
break_cycles_if_present()
stringfy_object()
get_feature_names()
get_feature_types()
cast_categoricals_to_int()
find_crossing_point()
format_time()
combine_dags()
list_files()
read_json_file()
- Module contents
- Submodules
- causalexplain.estimators package
- causalexplain.explainability package
- Submodules
Hierarchies
connect_isolated_nodes()
connect_hierarchies()
plot_dendogram_correlations()
PermutationImportance
RegQuality
ShapDiscrepancy
ShapDiscrepancy.target
ShapDiscrepancy.parent
ShapDiscrepancy.shap_heteroskedasticity
ShapDiscrepancy.parent_heteroskedasticity
ShapDiscrepancy.shap_p_value
ShapDiscrepancy.parent_p_value
ShapDiscrepancy.shap_model
ShapDiscrepancy.parent_model
ShapDiscrepancy.shap_discrepancy
ShapDiscrepancy.shap_correlation
ShapDiscrepancy.shap_gof
ShapDiscrepancy.ks_pvalue
ShapDiscrepancy.ks_result
ShapDiscrepancy.__init__()
ShapEstimator
ShapEstimator.device
ShapEstimator.shap_discrepancies
ShapEstimator.__init__()
ShapEstimator.explainer
ShapEstimator.models
ShapEstimator.correlation_th
ShapEstimator.mean_shap_percentile
ShapEstimator.iters
ShapEstimator.reciprocity
ShapEstimator.min_impact
ShapEstimator.exhaustive
ShapEstimator.parallel_jobs
ShapEstimator.on_gpu
ShapEstimator.verbose
ShapEstimator.prog_bar
ShapEstimator.silent
ShapEstimator.fit()
ShapEstimator.predict()
ShapEstimator.adjust()
ShapEstimator.set_predict_request()
ShapEstimator.compute_error_contribution()
custom_main()
sachs_main()
- Module contents
- Submodules
- causalexplain.generators package
- causalexplain.independence package
- Submodules
ConditionalIndependencies
SufficientSets
get_backdoor_paths()
get_paths()
find_colliders_in_path()
get_sufficient_sets_for_pair()
get_sufficient_sets()
get_conditional_independencies()
custom_main()
main()
dag_main()
get_edge_orientation()
estimate()
estimate_edge()
main()
select_features()
find_cluster_change_point()
main()
test()
GraphIndependence
HSIC_Values
HSIC
rbf_dot()
kernel_Delta_norm()
kernel_Delta()
kernel_Gaussian()
pairwise_mic()
fit_and_get_residuals()
run_feature_selection()
- Module contents
- Submodules
- causalexplain.metrics package
- causalexplain.models package
- Submodules
MLP
DFF
MDN
ColumnsDataset
RBF
MMDLoss
BaseModel
BaseModel.model
BaseModel.all_columns
BaseModel.callbacks
BaseModel.columns
BaseModel.logger
BaseModel.extra_trainer_args
BaseModel.scaler
BaseModel.train_loader
BaseModel.val_loader
BaseModel.n_rows
BaseModel.device
BaseModel.__init__()
BaseModel.init_logger()
BaseModel.init_callbacks()
BaseModel.init_data()
BaseModel.override_extras()
MLPModel
extract_weights()
see_weights_to_hidden()
see_weights_from_input()
plot_feature()
plot_features()
layer_weights()
summarize_weights()
identify_relationships()
infer_causal_relationships()
NNRegressor
custom_main()
GBTRegressor
custom_main()
- Module contents
- Submodules
Submodules
This module contains the GraphDiscovery class, which is responsible for creating, fitting, and evaluating causal discovery experiments.
- class GraphDiscovery(experiment_name=None, model_type='rex', csv_filename=None, true_dag_filename=None, verbose=False, seed=42)
Bases: object
- Attributes:
- model
Methods
- combine_and_evaluate_dags([prior]): Retrieve the DAG from the Experiment objects.
- create_experiments(): Create an Experiment object for each regressor.
- export(output_file): Export the DAG to a DOT file.
- fit_experiments([hpo_iterations, ...]): Fit the Experiment objects.
- load(model_path): Load the model from a pickle file.
- plot([show_metrics, show_node_fill, title, ...]): Plot the DAG using networkx and matplotlib.
- printout_results(graph, metrics): Print the DAG to stdout in hierarchical order.
- run([hpo_iterations, bootstrap_iterations, ...]): Run the experiment.
- save(full_filename_path): Save the model as an Experiment object.
- __init__(experiment_name=None, model_type='rex', csv_filename=None, true_dag_filename=None, verbose=False, seed=42)
Initializes a new instance of the GraphDiscovery class.
- Parameters:
experiment_name (str, optional) – The name of the experiment.
model_type (str, optional) – The type of model to use. Valid options are: ‘rex’, ‘pc’, ‘fci’, ‘ges’, ‘lingam’, ‘cam’, ‘notears’.
csv_filename (str, optional) – The filename of the CSV file containing the data.
true_dag_filename (str, optional) – The filename of the DOT file containing the true causal graph.
verbose (bool, optional) – Whether to print verbose output.
seed (int, optional) – The random seed for reproducibility.
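Both `true_dag_filename` and the `export()` method work with DOT files. As a rough illustration of that format, the sketch below writes a list of edges as a DOT digraph and reads the edges back. `dag_to_dot` and `edges_from_dot` are hypothetical helpers, not part of causalexplain; they assume one plain `u -> v;` statement per line with word-character identifiers, so files with subgraphs, chained edges, or spaces in node names need a real DOT parser (the package itself lists a `graph_from_dot_file()` function under causalexplain.common).

```python
import re

# Hypothetical helpers, not part of causalexplain. They only understand
# simple "u -> v;" edge statements (optionally with quoted identifiers and a
# trailing [attribute] list); every other line of the file is ignored.
EDGE_RE = re.compile(
    r'^\s*"?([\w.]+)"?\s*->\s*"?([\w.]+)"?\s*(?:\[.*\])?\s*;?\s*$'
)


def dag_to_dot(edges, name="G"):
    """Serialize (parent, child) pairs as a DOT digraph string."""
    lines = [f"digraph {name} {{"]
    lines += [f"  {parent} -> {child};" for parent, child in edges]
    lines.append("}")
    return "\n".join(lines)


def edges_from_dot(text):
    """Collect (parent, child) pairs from lines of the form 'u -> v;'."""
    edges = []
    for line in text.splitlines():
        match = EDGE_RE.match(line)
        if match:
            edges.append((match.group(1), match.group(2)))
    return edges


dot = dag_to_dot([("X", "Y"), ("Y", "Z")])
print(dot)
print(edges_from_dot(dot))  # [('X', 'Y'), ('Y', 'Z')]
```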
- create_experiments()
Create an Experiment object for each regressor.
- Returns:
A dictionary of Experiment objects.
- Return type:
dict
- fit_experiments(hpo_iterations=None, bootstrap_iterations=None, prior=None, **kwargs)
Fit the Experiment objects.
- Parameters:
hpo_iterations (int, optional) – Number of HPO trials for ReX. Defaults to None.
bootstrap_iterations (int, optional) – Number of bootstrap trials for ReX. Defaults to None.
prior (List[List[str]], optional) – The prior to use for ReX. Defaults to None.
trainer (dict) – A dictionary of Experiment objects. Not part of the signature above.
estimator (str) – The estimator to use (‘rex’ or other). Not part of the signature above.
verbose (bool, optional) – Whether to print verbose output. Defaults to False. Not part of the signature above.
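The `prior` argument is typed `List[List[str]]`. A minimal sketch, assuming the common reading of such a prior as ordered tiers of variable names in which an edge may not point from a later tier back into an earlier one. `tier_index` and `filter_edges` are hypothetical helpers, and the way ReX actually applies its prior (the package lists functions such as `correct_edge_from_prior()` and `break_cycles_using_prior()`) may differ.

```python
from typing import Dict, List, Tuple

Edge = Tuple[str, str]

# ASSUMPTION: prior = ordered tiers of variable names. Edges inside a tier
# or from an earlier tier to a later one are allowed; edges pointing back
# to an earlier tier are rejected. Variables absent from the prior are
# left unconstrained. Hypothetical helpers, not causalexplain's API.


def tier_index(prior: List[List[str]]) -> Dict[str, int]:
    """Map each variable name to the position of its tier."""
    return {name: i for i, tier in enumerate(prior) for name in tier}


def filter_edges(prior: List[List[str]], edges: List[Edge]) -> List[Edge]:
    """Keep only the edges that respect the tier order of the prior."""
    tiers = tier_index(prior)
    return [
        (u, v)
        for u, v in edges
        if u not in tiers or v not in tiers or tiers[u] <= tiers[v]
    ]


prior = [["age", "sex"], ["smoking"], ["cancer"]]
edges = [("age", "cancer"), ("cancer", "smoking"), ("sex", "age")]
print(filter_edges(prior, edges))  # [('age', 'cancer'), ('sex', 'age')]
```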
- combine_and_evaluate_dags(prior=None)
Retrieve the DAG from the Experiment objects.
- Parameters:
prior (List[List[str]], optional) – The prior to use for ReX. Defaults to None.
- Returns:
The experiment object with the final DAG.
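`combine_and_evaluate_dags()` merges the DAGs produced by the individual experiments and, presumably when a true DAG was supplied through `true_dag_filename`, scores the result. The sketch below is a generic illustration, not the package's combination rule or metric set: DAGs are plain sets of `(parent, child)` edges, combined by union or intersection (the package lists `graph_union()` and `graph_intersection()` helpers whose exact behavior is not documented here), then scored with edge precision, recall, F1, and a structural Hamming distance (SHD) that counts a reversed edge once. All function names are hypothetical.

```python
from typing import Set, Tuple

Edge = Tuple[str, str]

# Hypothetical helpers; each DAG is a set of directed (parent, child) edges.


def union(*dags: Set[Edge]) -> Set[Edge]:
    """Edges that appear in at least one DAG."""
    return set().union(*dags)


def intersection(*dags: Set[Edge]) -> Set[Edge]:
    """Edges that appear in every DAG."""
    return set(dags[0]).intersection(*dags[1:])


def edge_scores(predicted: Set[Edge], truth: Set[Edge]):
    """Precision, recall and F1 computed over directed edges."""
    tp = len(predicted & truth)
    precision = tp / len(predicted) if predicted else 0.0
    recall = tp / len(truth) if truth else 0.0
    f1 = 2 * precision * recall / (precision + recall) if tp else 0.0
    return precision, recall, f1


def shd(predicted: Set[Edge], truth: Set[Edge]) -> int:
    """Extra + missing + reversed edges; a reversed edge counts once."""
    reversed_ = {(u, v) for u, v in predicted - truth if (v, u) in truth}
    extra = predicted - truth - reversed_
    missing = {(u, v) for u, v in truth - predicted if (v, u) not in predicted}
    return len(extra) + len(missing) + len(reversed_)


dag_a = {("X", "Y"), ("Y", "Z")}
dag_b = {("X", "Y"), ("X", "Z")}
truth = {("X", "Y"), ("Y", "Z")}

consensus = intersection(dag_a, dag_b)
print(sorted(consensus))  # [('X', 'Y')]
print(edge_scores(consensus, truth), shd(consensus, truth))
```

Intersection favors precision (only edges every regressor agrees on survive), while union favors recall; which trade-off the package makes is not stated in this reference.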
- run(hpo_iterations=None, bootstrap_iterations=None, prior=None, **kwargs)
Run the experiment.
- save(full_filename_path)
Save the model as an Experiment object.
- Parameters:
full_filename_path (str) – The full path at which to save the model, including the filename.
- load(model_path)
Load the model from a pickle file.
- Parameters:
model_path (str) – Path to the pickle file containing the model.
- Returns:
The loaded Experiment object.
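`load()` reads a pickle file, so `save()` presumably writes one as well. The round trip below is a minimal sketch of that mechanism using only the standard library; the dictionary is a hypothetical stand-in for the real Experiment object, and the free functions `save_model` and `load_model` only mirror the parameter names of the two methods.

```python
import os
import pickle
import tempfile


def save_model(obj, full_filename_path):
    """Write obj to full_filename_path with pickle (hypothetical helper)."""
    with open(full_filename_path, "wb") as fh:
        pickle.dump(obj, fh)


def load_model(model_path):
    """Read a pickled object back from model_path (hypothetical helper)."""
    with open(model_path, "rb") as fh:
        return pickle.load(fh)


# Stand-in for a fitted model; the real Experiment object has other fields.
model_state = {"name": "demo", "edges": [("X", "Y")]}

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "experiment.pickle")
    save_model(model_state, path)
    restored = load_model(path)

print(restored)  # {'name': 'demo', 'edges': [('X', 'Y')]}
```

Only unpickle files from a trusted source, because loading a pickle can execute arbitrary code. Pickle also stores classes by reference, so the classes of a saved model must be importable when it is loaded.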
- printout_results(graph, metrics)
Print the DAG to stdout in hierarchical order.
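One way to compute a hierarchical order, sketched here with the standard-library `graphlib` module (Python 3.9+), is to group nodes into levels so that every parent lands in an earlier level than its children. `levels` and `print_hierarchy` are hypothetical helpers; the real method's output format and its use of `metrics` are not reproduced.

```python
from graphlib import TopologicalSorter

# Hypothetical helpers, not causalexplain's implementation.


def levels(edges):
    """Group DAG nodes into levels; parents always precede their children."""
    predecessors = {}
    for parent, child in edges:
        predecessors.setdefault(child, set()).add(parent)
        predecessors.setdefault(parent, set())
    sorter = TopologicalSorter(predecessors)
    sorter.prepare()  # raises graphlib.CycleError if the graph has a cycle
    result = []
    while sorter.is_active():
        ready = sorted(sorter.get_ready())  # nodes whose parents are all placed
        result.append(ready)
        sorter.done(*ready)
    return result


def print_hierarchy(edges):
    """Print each node, indented by level, with its parents."""
    parents = {}
    for parent, child in edges:
        parents.setdefault(child, []).append(parent)
    for depth, nodes in enumerate(levels(edges)):
        for node in nodes:
            incoming = ", ".join(sorted(parents.get(node, []))) or "(root)"
            print(f"{'  ' * depth}{node} <- {incoming}")


print_hierarchy([("X", "Y"), ("Y", "Z"), ("X", "Z"), ("W", "Z")])
```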
- plot(show_metrics=False, show_node_fill=True, title=None, ax=None, figsize=(5, 5), dpi=75, save_to_pdf=None, layout='dot', **kwargs)
Plot the DAG using networkx and matplotlib.
- property model
Module contents
CausalExplain: A Python package for causal discovery and inference.
This package provides tools for discovering and analyzing causal relationships in data using various methods and algorithms.
- class GraphDiscovery(experiment_name=None, model_type='rex', csv_filename=None, true_dag_filename=None, verbose=False, seed=42)
Re-exported at package level as causalexplain.GraphDiscovery; its attributes and methods are identical to the submodule documentation above.