"""
This module contains the GraphDiscovery class which is responsible for
creating, fitting, and evaluating causal discovery experiments.
"""
import os
import pickle
import pandas as pd
import matplotlib.pyplot as plt
from typing import List, Tuple
from causalexplain.common import (
DEFAULT_REGRESSORS,
utils,
)
from causalexplain.common import plot
from causalexplain.common.notebook import Experiment
from causalexplain.metrics.compare_graphs import evaluate_graph
class GraphDiscovery:
    """
    Create, fit, combine and evaluate causal discovery experiments.

    One ``Experiment`` is created per regressor and stored in
    ``self.trainer`` (a dict keyed by ``"<dataset_name>_<model_type>"``).
    For the ``'rex'`` estimator, the DAGs of the first two experiments are
    combined into a final one stored under ``"<dataset_name>_rex"``.

    An instance can also be built without arguments, in which case it holds
    no data and is meant to be populated through :meth:`load`.
    """

    def __init__(
        self,
        experiment_name: str = None,
        model_type: str = 'rex',
        csv_filename: str = None,
        true_dag_filename: str = None,
        verbose: bool = False,
        seed: int = 42
    ) -> None:
        """
        Initializes a new instance of the GraphDiscovery class.

        ``experiment_name`` and ``csv_filename`` must be given together, or
        both omitted. Empty or whitespace-only strings count as omitted.
        When both are omitted, the remaining arguments are ignored and the
        instance is created with default settings and no data.

        Args:
            experiment_name (str, optional): The name of the experiment.
            model_type (str, optional): The type of model to use. Valid options
                are: 'rex', 'pc', 'fci', 'ges', 'lingam', 'cam', 'notears'.
            csv_filename (str, optional): The filename of the CSV file containing
                the data.
            true_dag_filename (str, optional): The filename of the DOT file
                containing the true causal graph.
            verbose (bool, optional): Whether to print verbose output.
            seed (int, optional): The random seed for reproducibility.

        Raises:
            ValueError: If only one of ``experiment_name`` / ``csv_filename``
                is provided.
            FileNotFoundError: If ``csv_filename`` does not exist.
        """
        # Normalize empty/whitespace strings to None
        experiment_name = self._blank_to_none(experiment_name)
        csv_filename = self._blank_to_none(csv_filename)

        if (experiment_name is None) != (csv_filename is None):
            raise ValueError(
                f"Both 'experiment_name' and 'csv_filename' must be provided together, "
                f"or none of them. Got experiment_name='{experiment_name}', "
                f"csv_filename='{csv_filename}'")

        # State read by save/load/export/plot/model. It must exist on every
        # instance, including the "empty" one built without arguments.
        self.trainer = {}
        self.ref_graph = None
        self.data_columns = None
        self.dag = None
        self.metrics = None

        if experiment_name is None:
            # No data given: empty instance with default settings.
            self.experiment_name = None
            self.estimator = 'rex'
            self.csv_filename = None
            self.dot_filename = None
            self.verbose = False
            self.seed = 42
            return

        self.experiment_name = experiment_name
        self.estimator = model_type
        self.csv_filename = csv_filename
        self.dot_filename = true_dag_filename
        self.verbose = verbose
        self.seed = seed
        self.dataset_path = os.path.dirname(csv_filename)
        self.output_path = os.getcwd()

        # Read the reference graph
        if true_dag_filename is not None:
            self.ref_graph = utils.graph_from_dot_file(true_dag_filename)

        # Make sure the data file exists
        if not os.path.exists(csv_filename):
            raise FileNotFoundError(f"Data file {csv_filename} not found")
        self.dataset_name = os.path.splitext(os.path.basename(csv_filename))[0]

        # Read the column names of the data.
        data = pd.read_csv(csv_filename)
        self.data_columns = list(data.columns)
        del data

        if self.estimator == 'rex':
            # Copy so that changes to this instance's list never leak into
            # the shared module-level default.
            self.regressors = list(DEFAULT_REGRESSORS)
        else:
            self.regressors = [self.estimator]

    @staticmethod
    def _blank_to_none(value):
        """Return None for empty/whitespace-only strings, else the (stripped) value."""
        if isinstance(value, str):
            value = value.strip()
            if value == "":
                return None
        return value

    def _evaluate(self, dag):
        """Return the metrics of `dag` against the reference graph, or None if
        there is no reference graph or no column names to evaluate with."""
        if self.ref_graph is not None and self.data_columns is not None:
            return evaluate_graph(self.ref_graph, dag, self.data_columns)
        return None

    def create_experiments(self) -> dict:
        """
        Create an Experiment object for each regressor.

        Any previously created experiments are discarded. Each experiment is
        stored in ``self.trainer`` under the key
        ``"<dataset_name>_<model_type>"``.

        Returns:
            dict: A dictionary of Experiment objects
        """
        self.trainer = {}
        for model_type in self.regressors:
            trainer_name = f"{self.dataset_name}_{model_type}"
            self.trainer[trainer_name] = Experiment(
                experiment_name=self.dataset_name,
                csv_filename=self.csv_filename,
                dot_filename=self.dot_filename,
                model_type=model_type,
                input_path=self.dataset_path,
                output_path=self.output_path,
                verbose=False)
        return self.trainer

    def fit_experiments(
        self,
        hpo_iterations: int = None,
        bootstrap_iterations: int = None,
        prior: List[List[str]] = None,
        **kwargs
    ) -> None:
        """
        Fit the Experiment objects.

        Args:
            hpo_iterations (int, optional): Number of HPO trials for REX.
                Defaults to None.
            bootstrap_iterations (int, optional): Number of bootstrap trials
                for REX. Defaults to None.
            prior (List[List[str]], optional): Currently not forwarded to
                the experiments; accepted for API compatibility.
            **kwargs: Extra keyword arguments passed to ``fit_predict``.
                They override the arguments built by this method.
        """
        xargs = {'verbose': self.verbose}
        if self.estimator == 'rex':
            xargs['hpo_n_trials'] = hpo_iterations
            xargs['bootstrap_trials'] = bootstrap_iterations

        # Combine the arguments
        xargs.update(kwargs)

        for trainer_name, experiment in self.trainer.items():
            # The "<dataset>_rex" entry holds the combined DAG produced by
            # combine_and_evaluate_dags; it is not fitted.
            if not trainer_name.endswith("_rex"):
                experiment.fit_predict(estimator=self.estimator, **xargs)

    def combine_and_evaluate_dags(
            self, prior: List[List[str]] = None) -> "Experiment":
        """
        Retrieve the final DAG from the Experiment objects and evaluate it.

        For estimators other than 'rex', the DAG of the single experiment is
        used. For 'rex', the DAGs of the first two experiments are combined
        and stored in a new experiment named ``"<dataset_name>_rex"``.
        In both cases ``self.dag`` and ``self.metrics`` are updated; metrics
        are None when there is no reference graph.

        Args:
            prior (List[List[str]], optional): The prior to use for ReX.
                Defaults to None.

        Returns:
            Experiment: The experiment object with the final DAG

        Raises:
            IndexError: For 'rex', if fewer than two experiments exist.
        """
        if self.estimator != 'rex':
            trainer_key = f"{self.dataset_name}_{self.estimator}"
            experiment = self.trainer[trainer_key]
            estimator_obj = getattr(experiment, self.estimator)
            experiment.dag = estimator_obj.dag
            experiment.metrics = self._evaluate(estimator_obj.dag)
            self.dag = experiment.dag
            self.metrics = experiment.metrics
            return experiment

        # For ReX, we need to combine the DAGs. Hardcoded for now to combine
        # the first and second DAGs
        trainer_keys = list(self.trainer.keys())
        estimator1 = getattr(self.trainer[trainer_keys[0]], 'rex')
        estimator2 = getattr(self.trainer[trainer_keys[1]], 'rex')
        _, _, dag, _ = utils.combine_dags(
            estimator1.dag, estimator2.dag,
            discrepancies=estimator1.shaps.shap_discrepancies,
            prior=prior
        )

        # Create a new Experiment object for the combined DAG. Any previous
        # entry is removed first so the new one becomes the LAST key of the
        # dict, which is what `model`, `export` and `plot` rely on.
        new_trainer = f"{self.dataset_name}_rex"
        self.trainer.pop(new_trainer, None)
        combined = Experiment(
            experiment_name=self.dataset_name,
            model_type='rex',
            input_path=self.dataset_path,
            output_path=self.output_path,
            verbose=False)
        self.trainer[new_trainer] = combined

        # Set the DAG and evaluate it
        combined.ref_graph = self.ref_graph
        combined.dag = dag
        combined.metrics = self._evaluate(dag)

        self.dag = combined.dag
        self.metrics = combined.metrics
        return combined

    def run(
        self,
        hpo_iterations: int = None,
        bootstrap_iterations: int = None,
        prior: List[List[str]] = None,
        **kwargs
    ):
        """
        Run the experiment: create, fit, then combine and evaluate.

        Args:
            hpo_iterations (int, optional): Number of HPO trials for REX.
                Defaults to None.
            bootstrap_iterations (int, optional): Number of bootstrap trials
                for REX. Defaults to None.
            prior (List[List[str]], optional): The prior to use for ReX
                when combining DAGs. Defaults to None.
            **kwargs: Extra keyword arguments forwarded to
                :meth:`fit_experiments`.
        """
        self.create_experiments()
        self.fit_experiments(
            hpo_iterations, bootstrap_iterations, prior, **kwargs)
        self.combine_and_evaluate_dags(prior=prior)

    def save(self, full_filename_path: str) -> None:
        """
        Save the experiments (the ``trainer`` dictionary) to disk.

        An existing file is not overwritten; the name actually used is the
        one returned by ``utils.save_experiment`` and is printed.

        Args:
            full_filename_path (str): A full path where to save the model,
                including the filename. If it has no directory part, the
                current working directory is used.

        Raises:
            AssertionError: If there is nothing to save, no path is given,
                or the output directory does not exist.
        """
        assert self.trainer, "No trainer to save"
        assert full_filename_path, "No output path specified"
        full_dir_path = os.path.dirname(full_filename_path)

        # Check only if not local dir
        if full_dir_path != "." and full_dir_path != "":
            assert os.path.exists(full_dir_path), \
                f"Output directory {full_dir_path} does not exist"
        else:
            full_dir_path = os.getcwd()

        saved_as = utils.save_experiment(
            os.path.basename(full_filename_path), full_dir_path,
            self.trainer, overwrite=False)
        print(f"Saved model as: {saved_as}", flush=True)

    def load(self, model_path: str) -> dict:
        """
        Load the experiments from a pickle file.

        ``self.dag`` and ``self.metrics`` are taken from the last experiment
        in the loaded dictionary.

        Args:
            model_path (str): Path to the pickle file containing the model

        Returns:
            dict: The loaded dictionary of Experiment objects
        """
        # SECURITY: unpickling can execute arbitrary code. Only load files
        # that come from a trusted source.
        with open(model_path, 'rb') as f:
            self.trainer = pickle.load(f)
        print(f"Loaded model from: {model_path}", flush=True)

        # Set the dag and metrics
        final_experiment = self.model
        self.dag = final_experiment.dag
        self.metrics = final_experiment.metrics
        return self.trainer

    def printout_results(self, graph, metrics):
        """
        Print the DAG to stdout in hierarchical order, followed by metrics.

        Edges are printed depth-first, starting from the nodes without
        predecessors; each level is indented one more space. Nodes not
        reachable from any root (e.g. inside a cycle) are traversed last.

        Args:
            graph (nx.DiGraph): The DAG to be printed.
            metrics: The metrics to print after the graph. Nothing is
                printed for them when None.
        """
        if len(graph.edges()) == 0:
            print("Empty graph")
            return

        print("Resulting Graph:\n---------------")

        def dfs(node, visited, indent=""):
            if node in visited:
                return  # Avoid revisiting nodes
            visited.add(node)
            # Print edges for this node
            for neighbor in graph.successors(node):
                print(f"{indent}{node} -> {neighbor}")
                dfs(neighbor, visited, indent + " ")

        visited = set()

        # Start traversal from all nodes without predecessors (roots)
        for node in graph.nodes:
            if graph.in_degree(node) == 0:
                dfs(node, visited)

        # Handle disconnected components (not reachable from any "root")
        for node in graph.nodes:
            if node not in visited:
                dfs(node, visited)

        if metrics is not None:
            print("\nGraph Metrics:\n-------------")
            print(metrics)

    def export(self, output_file: str) -> str:
        """
        Export the DAG of the last experiment to a DOT file.

        Args:
            output_file (str): The path to the output DOT file.

        Returns:
            str: The value returned by ``utils.graph_to_dot_file`` (the path
                the graph was saved as).
        """
        saved_as = utils.graph_to_dot_file(self.model.dag, output_file)
        return saved_as

    def plot(
        self,
        show_metrics: bool = False,
        show_node_fill: bool = True,
        title: str = None,
        ax: "plt.Axes" = None,
        figsize: Tuple[int, int] = (5, 5),
        dpi: int = 75,
        save_to_pdf: str = None,
        layout: str = 'dot',
        **kwargs
    ):
        """
        Plot the DAG of the last experiment.

        The experiment's reference graph, if any, is passed along as the
        reference to compare against.

        Args:
            show_metrics (bool, optional): Whether to show the metrics on
                the plot. Defaults to False.
            show_node_fill (bool, optional): Whether to fill the nodes with
                color. Defaults to True.
            title (str, optional): The title of the plot. Defaults to None.
            ax (plt.Axes, optional): The matplotlib axes to plot on.
                Defaults to None.
            figsize (Tuple[int, int], optional): The size of the plot.
                Defaults to (5, 5).
            dpi (int, optional): The DPI of the plot. Defaults to 75.
            save_to_pdf (str, optional): The path to save the plot as a PDF.
                Defaults to None.
            layout (str, optional): The layout to use for the plot.
                Defaults to 'dot'. Other option is 'circular'.
            **kwargs: Extra keyword arguments forwarded to ``plot.dag``.
        """
        model = self.model
        plot.dag(
            graph=model.dag, reference=model.ref_graph,
            show_metrics=show_metrics, show_node_fill=show_node_fill,
            title=title, ax=ax, figsize=figsize, dpi=dpi,
            save_to_pdf=save_to_pdf, layout=layout,
            **kwargs)

    @property
    def model(self):
        """The last experiment in ``self.trainer`` (the one holding the
        final DAG). Raises IndexError when there are no experiments."""
        return self.trainer[list(self.trainer.keys())[-1]]