causalexplain.common package#
Submodules#
Configuration for metric types used across the causalexplain package.
A module to run experiments with the causalexplain package, and simplify the process of loading and saving experiments in notebooks.
Example
>>> from causalexplain.common.notebook import Experiment
>>> experiment = Experiment("linear", csv_filename="linear.csv")
>>> rex = experiment.load()
2023, 2024 J. Renero
- class BaseExperiment(input_path, output_path, train_anyway=False, save_anyway=False, scale=False, train_size=0.9, random_state=42, verbose=False)[source]#
Bases: object
Base class for notebook experiments.
- Parameters:
input_path (str) – Path to the input data.
output_path (str) – Path to save experiment outputs.
train_anyway (bool, optional) – Whether to train even if cached outputs exist.
save_anyway (bool, optional) – Whether to overwrite cached outputs.
scale (bool, optional) – Whether to scale the input data.
train_size (float, optional) – Proportion of samples used for training.
random_state (int, optional) – Random seed for reproducibility.
verbose (bool, optional) – Whether to display verbose output.
- __init__(input_path, output_path, train_anyway=False, save_anyway=False, scale=False, train_size=0.9, random_state=42, verbose=False)[source]#
- prepare_experiment_input(experiment_filename, csv_filename=None, dot_filename=None, data=None, data_is_processed=False, train_idx=None, test_idx=None)[source]#
Loads the data and:
- splits it into train and test sets,
- scales it,
- loads the reference graph from the .dot file, which must have the same name as the experiment file, with a .dot extension (see the usage sketch below).
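A minimal sketch of how this method might be called from a notebook, assuming the classes are imported from the notebook module documented above; the paths and filenames are illustrative only:
>>> from causalexplain.common.notebook import BaseExperiment
>>> base = BaseExperiment("data/", "output/")
>>> base.prepare_experiment_input("linear", csv_filename="linear.csv", dot_filename="linear.dot")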
- class Experiment(experiment_name, csv_filename=None, dot_filename=None, data=None, data_is_processed=False, train_idx=None, test_idx=None, model_type='nn', input_path='/Users/renero/phd/data/RC4/', output_path='/Users/renero/phd/output/RC4/', train_size=0.9, random_state=42, verbose=False)[source]#
Bases: BaseExperiment
Notebook-friendly wrapper for training and evaluating estimators.
- estimator_name = None#
- lingam: DirectLiNGAM | None = None#
- __init__(experiment_name, csv_filename=None, dot_filename=None, data=None, data_is_processed=False, train_idx=None, test_idx=None, model_type='nn', input_path='/Users/renero/phd/data/RC4/', output_path='/Users/renero/phd/output/RC4/', train_size=0.9, random_state=42, verbose=False)[source]#
Initialize an experiment session.
This module includes all the plot methods for the causal graph.
Renero, 2022, 2023
- format_graph(G, Gt=None, ok_color='green', inv_color='lightgreen', wrong_color='black', missing_color=None)[source]#
- draw_graph_subplot(G, root_causes=None, layout=None, title=None, ax=None, **kwargs)[source]#
Draw a graph in a subplot.
- set_colormap(color_threshold=0.15, max_color=0.8, cmap_name='OrRd')[source]#
Set the colormap for the graph edges.
- dag2dot(G, undirected=False, name='my_dotgraph', odots=True)[source]#
Display a DOT of the graph in the notebook.
- Parameters:
G (nx.Graph or DiGraph) – the graph to be represented.
undirected (bool) – default False; indicates whether the plot is forced to contain no arrows.
name (str) – the name to be embedded in the Dot object for this graph.
odots (bool) – represent edges with biconnections (bidirectional edges) with circles (odots). If this is set to False, the edges simply have no arrowheads.
- Returns:
pydot.Dot object
- Return type:
Dot
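A small usage sketch for dag2dot; the import path is an assumption based on this module's location, and exporting to PNG requires Graphviz to be installed:
>>> import networkx as nx
>>> from causalexplain.common.plots import dag2dot
>>> G = nx.DiGraph([("x", "y"), ("y", "z")])
>>> dot = dag2dot(G, name="example_graph")
>>> dot.write_png("example_graph.png")  # pydot export; needs Graphviz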
- values_distribution(values, threshold=None, **kwargs)[source]#
Plot density and cumulative density for a set of values.
- Parameters:
values (array-like) – Values to plot.
threshold (float, optional) – Optional threshold marker for the ECDF plot.
**kwargs – Additional keyword arguments for customizing the plot.
- correlation_matrix(corr_matrix, sorted_colnames=None, threshold=0.5, ax=None, **kwargs)[source]#
Plot a thresholded correlation matrix.
- Parameters:
corr_matrix (pandas.DataFrame) – Correlation matrix.
sorted_colnames (list[str], optional) – Column ordering for the plot.
threshold (float, optional) – Minimum absolute correlation to display.
ax (matplotlib.axes.Axes, optional) – Axes to draw into.
**kwargs – Plot customization options (title, fontsize, fontname, xrot).
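An illustrative call, assuming the import path shown and a correlation matrix computed with pandas (column names are made up):
>>> import pandas as pd
>>> from causalexplain.common.plots import correlation_matrix
>>> df = pd.DataFrame({"a": [1, 2, 3, 4], "b": [2, 4, 6, 8], "c": [4, 3, 1, 2]})
>>> correlation_matrix(df.corr(), threshold=0.5, title="Example correlations")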
- hierarchies(hierarchies, threshold=0.5, **kwargs)[source]#
Plot hierarchical clustering and its correlation matrix.
- dag(graph, reference=None, root_causes=None, show_metrics=False, show_node_fill=True, title=None, ax=None, figsize=(5, 5), dpi=75, save_to_pdf=None, layout='dot', **kwargs)[source]#
Compare two graphs using dot.
- dags(dags, ref_graph, titles, figsize=(15, 12), dpi=300)[source]#
Plots multiple directed acyclic graphs (DAGs) in a grid layout.
- Parameters:
dags (list) – List of DAGs to plot.
ref_graph – Reference graph used for layout.
titles (list) – List of titles for each DAG.
figsize (tuple) – Figure size (default: (15, 12)).
dpi (int) – Dots per inch (default: 300).
- Raises:
ValueError – If there are too many DAGs to plot.
- Returns:
None
- shap_discrepancies(shaps, target_name, threshold=100.0, regression_line=False, reduced=False, **kwargs)[source]#
Plot the discrepancies of the SHAP values.
- deprecated_dags(graph, reference=None, names=None, figsize=(10, 5), dpi=75, save_to_pdf=None, **kwargs)[source]#
Compare two graphs using dot.
- score_by_method(metrics, metric, methods, **kwargs)[source]#
Plot a metric distribution grouped by method.
- scores_by_method(metrics, methods=None, title=None, pdf_filename=None, **kwargs)[source]#
Plot the metrics for all the experiments matching the input pattern.
- Parameters:
metrics (pandas DataFrame) – A DataFrame with the metrics for all the experiments
methods (list) – The list of methods to plot. If None, all methods will be plotted. The methods included are: 'rex_mlp', 'rex_gbt', 'rex_intersection' and 'rex_union'.
title (str) – The title of the plot
pdf_filename (str) – The filename to save the plot to. If None, the plot will be displayed on screen, otherwise it will be saved to the specified filename.
ylim (tuple, optional) – The y-axis limits of the plot. Default is None.
dpi (int, optional) – Dots per inch. Default is 300.
- score_by_subtype(metrics, score_name, methods=None, pdf_filename=None, **kwargs)[source]#
Plot score distributions across data subtypes.
- combined_metrics(metrics, metric_types=None, title=None, acyclic=False, medians=False, pdf_filename=None)[source]#
Plot the metrics for all the experiments matching the input pattern.
Utility functions for causalexplain (C) J. Renero, 2022, 2023, 2024, 2025
- save_experiment(obj_name, folder, results, overwrite=False)[source]#
Creates a folder for the experiment and saves results. Results is a dictionary that will be saved as an opaque pickle. When the experiment needs to be loaded later, only the object name and the folder are required.
- Parameters:
obj_name (str) – the name to be given to the pickle file to be saved. If a file with that name already exists, a new filename with a numeric suffix will be generated.
folder (str) – a full path to the folder where the experiment is to be saved. If the folder does not exist it will be created.
results (obj) – the object to be saved as experiment. This is typically a dictionary with different items representing different parts of the experiment.
overwrite (bool) – If True, then the experiment is saved even if a file with the same name already exists. If False, then the experiment is saved only if a file with the same name does not exist.
- Returns:
(str) The name under which the experiment has been saved
- Return type:
str
- load_experiment(obj_name, folder)[source]#
Loads a pickle from a given folder name. It is not necessary to add the “pickle” extension to the experiment name.
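A hedged save/load round-trip, assuming the name returned by save_experiment can be passed back to load_experiment; the folder and dictionary contents are illustrative:
>>> from causalexplain.common.utils import save_experiment, load_experiment
>>> results = {"metrics": {"f1": 0.9}, "graph": None}
>>> name = save_experiment("my_run", "/tmp/experiments", results)
>>> restored = load_experiment(name, "/tmp/experiments")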
- valid_output_name(filename, path, extension=None)[source]#
Builds a valid filename. If another file with that name already exists, a number (1, 2, …) is appended until a filename that does not exist is found.
- Parameters:
filename (str) – The base filename to be set.
path (str) – The path where the file will be created.
extension (str) – The extension of the file, without the dot '.'. If no extension is specified, any extension is searched so that the path of an existing file is never returned, no matter what extension it has.
- Returns:
The filename if the name is valid and the file does not exist, None otherwise.
- Return type:
str or None
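An illustrative call (the path is hypothetical; the suffixing scheme is the one described above):
>>> from causalexplain.common.utils import valid_output_name
>>> out = valid_output_name("report", "/tmp/results", extension="csv")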
- graph_from_dictionary(d)[source]#
Builds a graph from a dictionary like {'u': ['v', 'w'], 'x': ['y']}. The elements of the list can be tuples that include a weight.
- Parameters:
d (dict) – A dictionary of the form {'u': ['v', 'w'], 'x': ['y']}, where an edge goes from 'u' towards 'v' and 'w', and another edge goes from 'x' towards 'y'. The format can also be {'u': [('v', 0.2), ('w', 0.7)], 'x': [('y', 0.5)]}, where the values in the tuples are interpreted as weights.
- Returns:
networkx.DiGraph with the nodes and edges specified.
- Return type:
Graph | DiGraph
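Both dictionary forms described above, as a short sketch (import path assumed):
>>> from causalexplain.common.utils import graph_from_dictionary
>>> g = graph_from_dictionary({'u': ['v', 'w'], 'x': ['y']})
>>> gw = graph_from_dictionary({'u': [('v', 0.2), ('w', 0.7)], 'x': [('y', 0.5)]})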
- graph_from_adjacency(adjacency, node_labels=None, th=0.0, inverse=False, absolute_values=False)[source]#
Manually parses the adjacency matrix to build a directed graph.
- Parameters:
adjacency (ndarray) – a numpy adjacency matrix
node_labels – an array with the same length as the number of columns in the adjacency matrix, containing the labels to use for every node.
th (float) – weight threshold to be considered a valid edge.
inverse (bool) – Set to True if rows in the adjacency matrix indicate where edges come from, instead of where they go to.
absolute_values (bool) – Take the absolute value of the weight to check whether it is greater than the threshold.
- Returns:
The Graph (DiGraph)
- Return type:
DiGraph
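A minimal sketch, assuming rows index the source nodes by default (use inverse=True otherwise); the import path is an assumption:
>>> import numpy as np
>>> from causalexplain.common.utils import graph_from_adjacency
>>> adj = np.array([[0.0, 0.8, 0.0],
...                 [0.0, 0.0, 0.3],
...                 [0.0, 0.0, 0.0]])
>>> g = graph_from_adjacency(adj, node_labels=['a', 'b', 'c'], th=0.5)
>>> list(g.edges())  # only edges whose weight passes th, e.g. ('a', 'b')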
- graph_from_adjacency_file(file, labels=None, th=0.0, sep=',', header=True)[source]#
Read an adjacency matrix from a file and return a Graph.
- Parameters:
labels – (List[str]) the list of node names to be used. If None, the node names are extracted from the adj file. The names must be already sorted in the same order as the adjacency matrix.
th – (float) weight threshold to be considered a valid edge.
sep – (str) the separator used in the file
header (bool) – whether the file has a header. If True, then 'infer' is used to read the header.
- Returns:
DiGraph, DataFrame
- Return type:
Tuple[DiGraph, DataFrame]
- graph_to_adjacency(graph, labels, weight_label='weight')[source]#
A method to generate the adjacency matrix of the graph. Labels are sorted for better readability.
- Parameters:
graph (Graph | DiGraph) – the graph to be converted.
labels (List[str]) – the list of node names to be used. If None, the node names are extracted from the graph. The names must already be sorted in the same order as the adjacency matrix.
weight_label (str) – the label used to identify the weights.
- Returns:
A 2D array containing the adjacency matrix of the graph.
- Return type:
numpy.ndarray
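The reverse direction, converting a weighted DiGraph back into an adjacency matrix over sorted labels (import path assumed):
>>> import networkx as nx
>>> from causalexplain.common.utils import graph_to_adjacency
>>> g = nx.DiGraph()
>>> g.add_edge("a", "b", weight=0.7)
>>> g.add_edge("b", "c", weight=0.2)
>>> adj = graph_to_adjacency(g, labels=["a", "b", "c"])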
- graph_to_adjacency_file(graph, output_file, labels)[source]#
A method to write the adjacency matrix of the graph to a file. If the graph has weights, these are the values stored in the adjacency matrix.
- select_device(force=None)[source]#
Selects the device to be used for training. If force is not None, the device is forced to be the one specified; otherwise, the device is selected based on GPU availability, falling back to the CPU when no GPUs are available.
- Parameters:
force (str) – If not None, then the device is forced to be the one specified. If None, then the device is selected based on the availability of GPUs. If no GPUs are available, then the CPU is selected.
- Returns:
(str) The device to be used for training.
- Raises:
ValueError – If the forced device is not available or not a valid device.
- Return type:
str
- resolve_device(requested=None)[source]#
Resolve an explicit device request, defaulting to CPU.
- Parameters:
requested (Optional[str]) – Requested device (“cpu”, “cuda”, “mps”).
- Returns:
The resolved device string.
- Return type:
str
- Raises:
ValueError – If the requested device is invalid or unavailable.
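A usage sketch for device selection, with the behavior described above; the import path is an assumption:
>>> from causalexplain.common.utils import select_device, resolve_device
>>> select_device()         # picks a GPU backend when available, otherwise 'cpu'
>>> resolve_device()        # defaults to 'cpu' when nothing is requested
>>> resolve_device("cuda")  # raises ValueError if CUDA is unavailable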
- graph_intersection(g1, g2)[source]#
Returns the intersection of two graphs. The intersection is defined as the set of nodes and edges that are common to both graphs. The intersection is performed on the nodes and edges, not on the attributes of the nodes and edges.
- Parameters:
g1 (networkx.DiGraph) – The first graph.
g2 (networkx.DiGraph) – The second graph.
- Returns:
(networkx.DiGraph) The intersection of the two graphs.
- Return type:
DiGraph
- graph_union(g1, g2)[source]#
Returns the union of two graphs. The union is defined as the set of nodes and edges that appear in either graph. The union is performed on the nodes and edges, not on their attributes.
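A short sketch contrasting the two set operations on edges (import path assumed):
>>> import networkx as nx
>>> from causalexplain.common.utils import graph_intersection, graph_union
>>> g1 = nx.DiGraph([("a", "b"), ("b", "c")])
>>> g2 = nx.DiGraph([("a", "b"), ("c", "d")])
>>> common = graph_intersection(g1, g2)  # keeps only the shared edge a -> b
>>> merged = graph_union(g1, g2)         # keeps every edge from either graph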
- digraph_from_connected_features(X, feature_names, models, connections, root_causes, prior=None, reciprocity=True, anm_iterations=10, max_anm_samples=400, verbose=False)[source]#
Builds a directed graph from a set of features and their connections. The connections are determined by the SHAP values of the features, and the orientation of each edge is determined by the causal direction of the features, using the edge-orientation method defined in the independence module.
- correct_edge_from_prior(dag, u, v, prior, verbose)[source]#
Orient an edge using prior temporal knowledge.
- valid_candidates_from_prior(feature_names, effect, prior)[source]#
This method returns the valid candidates for a given effect, based on the prior information. The prior information is a list of lists, where each list contains the features that are known to be in the same level of the hierarchy.
- break_cycles_using_prior(original_dag, prior, verbose=False)[source]#
This method removes edges in the DAG that are incompatible with the temporal relationship defined in the prior knowledge. Any edge pointing backwards, according to the order established in the prior, is removed.
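A sketch of the prior format and its effect, using hypothetical feature names (import path assumed):
>>> import networkx as nx
>>> from causalexplain.common.utils import break_cycles_using_prior
>>> dag = nx.DiGraph([("season", "rain"), ("rain", "traffic"), ("traffic", "season")])
>>> prior = [["season"], ["rain"], ["traffic"]]  # earliest causes first
>>> acyclic = break_cycles_using_prior(dag, prior)  # drops the backward edge traffic -> season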
- potential_misoriented_edges(loop, discrepancies, verbose=False)[source]#
Find potential misoriented edges in a loop based on discrepancies. The loop is a list of nodes that form a loop in a DAG. The discrepancies is a dictionary containing the discrepancies between nodes. The function returns a list of potential misoriented edges, sorted by the difference in goodness-of-fit scores.
- Parameters:
loop (list) – A list of nodes that form a loop in the DAG.
discrepancies (dict) – A dictionary containing the discrepancies between nodes.
verbose (bool) – Whether to display verbose output.
- Returns:
A list of potential misoriented edges, sorted by the difference in goodness-of-fit scores.
- Return type:
list
- break_cycles_if_present(dag, discrepancies, prior=None, verbose=False)[source]#
Break cycles in a DAG using discrepancies and optional priors.
- stringfy_object(object_)[source]#
Convert an object into a string representation, including its attributes.
- get_feature_names(X)[source]#
Get the feature names from the input data. The feature names can be extracted from a pandas DataFrame, a numpy array or a list.
- get_feature_types(X)[source]#
Get the feature types from the input data. The feature types can be binary, multiclass or continuous. The classification is done based on the number of unique values in the array. If the array has only two unique values, then the variable is classified as binary. If the unique values are integers, then the variable is classified as multiclass. Otherwise, the variable is classified as continuous.
- cast_categoricals_to_int(X)[source]#
Cast all categorical features in the dataframe to integer values.
- Parameters:
X (pd.DataFrame) – The dataframe to cast.
- Returns:
The dataframe with all categorical features cast to integer values.
- Return type:
pd.DataFrame
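A minimal sketch (import path assumed):
>>> import pandas as pd
>>> from causalexplain.common.utils import cast_categoricals_to_int
>>> df = pd.DataFrame({"color": pd.Categorical(["red", "blue", "red"]), "x": [1.0, 2.0, 3.0]})
>>> df_int = cast_categoricals_to_int(df)  # 'color' now holds integer values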
- find_crossing_point(f1, f2, x_values)[source]#
This function finds the exact crossing point between two curves defined by f1 and f2 over x_values. It interpolates between points to provide a more accurate crossing point.
- format_time(seconds)[source]#
Convert the time in seconds to a more convenient time unit.
- Parameters:
seconds (float) – The time in seconds.
- Returns:
The time in the most appropriate unit and the unit string.
- Return type:
Tuple[float, str]
Examples
>>> format_time(1.0)
(1.0, 's')
>>> format_time(60.0)
(1.0, 'm')
>>> format_time(3600.0)
(1.0, 'h')
>>> format_time(86400.0)
(1.0, 'd')
>>> format_time(604800.0)
(1.0, 'w')
>>> format_time(2592000.0)
(1.0, 'm')
>>> format_time(31536000.0)
(1.0, 'y')
>>> format_time(315360000.0)
(1.0, 'a')
- combine_dags(dag1, dag2, discrepancies, prior=None)[source]#
Combine two directed acyclic graphs (DAGs) into a single DAG.
- Parameters:
dag1 (nx.DiGraph) – The first DAG.
dag2 (nx.DiGraph) – The second DAG.
discrepancies (Dict) – A Dictionary containing the permutation importances for each edge in the DAGs.
prior (Optional[List[List[str]]], optional) – A list of lists containing the prior knowledge about the edges in the DAGs. The lists define a hierarchy of edges that represent a temporal relation in cause and effect. If a node is in the first list, then it is a root cause. If a node is in the second list, then it is caused by the nodes in the first list or the second list, and so on.
- Returns:
A tuple containing four graphs: the union of the two DAGs, the intersection of the two DAGs, the union of the two DAGs after removing cycles, and the intersection of the two DAGs after removing cycles.
- Return type:
Tuple[nx.DiGraph, nx.DiGraph, nx.DiGraph, nx.DiGraph]
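A heavily hedged sketch; the exact structure expected in discrepancies is not documented here, so an empty dict is used purely as a placeholder and may not be meaningful in practice:
>>> import networkx as nx
>>> from causalexplain.common.utils import combine_dags
>>> dag1 = nx.DiGraph([("a", "b"), ("b", "c")])
>>> dag2 = nx.DiGraph([("a", "b"), ("c", "b")])
>>> union, inter, union_acyclic, inter_acyclic = combine_dags(dag1, dag2, {})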