causalexplain.common package#

Submodules#

Configuration for metric types used across the causalexplain package.

A module to run experiments with the causalexplain package, and simplify the process of loading and saving experiments in notebooks.

Example

>>> from causalexplain.common.notebook import Experiment
>>> experiment = Experiment("linear", csv_filename="linear.csv")
>>> rex = experiment.load()

(C) 2023, 2024 J. Renero

class BaseExperiment(input_path, output_path, train_anyway=False, save_anyway=False, scale=False, train_size=0.9, random_state=42, verbose=False)[source]#

Bases: object

Base class for notebook experiments.

Parameters:
  • input_path (str) – Path to the input data.

  • output_path (str) – Path to save experiment outputs.

  • train_anyway (bool, optional) – Whether to train even if cached outputs exist.

  • save_anyway (bool, optional) – Whether to overwrite cached outputs.

  • scale (bool, optional) – Whether to scale the input data.

  • train_size (float, optional) – Proportion of samples used for training.

  • random_state (int, optional) – Random seed for reproducibility.

  • verbose (bool, optional) – Whether to display verbose output.

model_type: str | None = None#
__init__(input_path, output_path, train_anyway=False, save_anyway=False, scale=False, train_size=0.9, random_state=42, verbose=False)[source]#
property data: DataFrame | None#
property train_data: DataFrame#
property test_data: DataFrame#
prepare_experiment_input(experiment_filename, csv_filename=None, dot_filename=None, data=None, data_is_processed=False, train_idx=None, test_idx=None)[source]#
  • Loads the data,

  • splits it into train and test sets,

  • scales it, and

  • loads the reference graph from the dot file, which must have the same name as the experiment file, with the .dot extension.

experiment_exists(name)[source]#

Checks whether the experiment exists in the output path

create_estimator(estimator_name, name, **kwargs)[source]#

Create an estimator instance from the registry.

class Experiment(experiment_name, csv_filename=None, dot_filename=None, data=None, data_is_processed=False, train_idx=None, test_idx=None, model_type='nn', input_path='/Users/renero/phd/data/RC4/', output_path='/Users/renero/phd/output/RC4/', train_size=0.9, random_state=42, verbose=False)[source]#

Bases: BaseExperiment

Notebook-friendly wrapper for training and evaluating estimators.

estimator_name = None#
rex: Rex | None = None#
pc: PC | None = None#
lingam: DirectLiNGAM | None = None#
ges: GES | None = None#
fci: FCI | None = None#
cam: CAM | None = None#
notears: NOTEARS | None = None#
estimator: Any | None = None#
__init__(experiment_name, csv_filename=None, dot_filename=None, data=None, data_is_processed=False, train_idx=None, test_idx=None, model_type='nn', input_path='/Users/renero/phd/data/RC4/', output_path='/Users/renero/phd/output/RC4/', train_size=0.9, random_state=42, verbose=False)[source]#

Initialize an experiment session.

fit(estimator_name='rex', **kwargs)[source]#

Fit the selected estimator using the experiment data.

predict(estimator='rex', **kwargs)[source]#

Run prediction for the active estimator.

fit_predict(estimator='rex', **kwargs)[source]#

Fit and predict with the selected estimator.

load(exp_name=None)[source]#

Load a previously saved experiment.

save(exp_name=None, overwrite=False)[source]#

Save the experiment data to disk.

This file includes all the plot methods for the causal graph

(C) J. Renero, 2022, 2023

class ShapSummary(*args, **kwargs)[source]#

Bases: Protocol

is_fitted_: bool#
feature_names: Sequence[str]#
__init__(*args, **kwargs)#
class ShapDiscrepancy(*args, **kwargs)[source]#

Bases: Protocol

is_fitted_: bool#
feature_names: Sequence[str]#
shap_discrepancies: dict#
X_test: DataFrame#
shap_scaled_values: dict#
__init__(*args, **kwargs)#
setup_plot(**kwargs)[source]#

Customize figure settings.

Parameters:
  • tex (bool, optional) – use LaTeX. Defaults to True.

  • font (str, optional) – font type. Defaults to “serif”.

  • dpi (int, optional) – dots per inch. Defaults to 180.

add_grid(ax, lines=True, locations=None)[source]#

Add a grid to the current plot.

Parameters:
  • ax (Axis) – axis object in which to draw the grid.

  • lines (bool, optional) – add lines to the grid. Defaults to True.

  • locations (tuple, optional) – (xminor, xmajor, yminor, ymajor). Defaults to None.

subplots(plot_func, *plot_args, **kwargs)[source]#

Plots a set of subplots.

format_graph(G, Gt=None, ok_color='green', inv_color='lightgreen', wrong_color='black', missing_color=None)[source]#
draw_graph_subplot(G, root_causes=None, layout=None, title=None, ax=None, **kwargs)[source]#

Draw a graph in a subplot.

Parameters:
  • G (nx.DiGraph) – The graph to be drawn.

  • layout (dict) – The layout of the graph.

  • title (str) – The title of the graph.

  • ax (Axes) – The axis in which to draw the graph.

  • **formatting_kwargs (dict) – The formatting arguments for the graph.

Return type:

None

cleanup_graph(G)[source]#
set_colormap(color_threshold=0.15, max_color=0.8, cmap_name='OrRd')[source]#

Set the colormap for the graph edges.

Parameters:
  • color_threshold (float) – The threshold for the color of the values in the plot, below which the color will be white.

  • max_color (float) – The maximum color for the edges, above which the color will be red.

Returns:

The colormap to be used in the plot.

Return type:

LinearColormap

dag2dot(G, undirected=False, name='my_dotgraph', odots=True)[source]#

Display a DOT of the graph in the notebook.

Parameters:
  • G (nx.Graph or DiGraph) – the graph to be represented.

  • undirected (bool) – default False, indicates whether the plot is forced to contain no arrows.

  • plot (bool) – defaults to True. When False, the Dot object is generated but not plotted, e.g. when the object is only needed to produce a PNG version of the DOT.

  • name (str) – the name to be embedded in the Dot object for this graph.

  • odots (bool) – represent bidirected edges with circles (odots). If this is set to False, such edges simply have no arrowheads.

Returns:

pydot.Dot object

Return type:

Dot
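The mapping from edges to DOT source can be sketched as follows. This is an illustrative stand-alone helper (`edges_to_dot` is not part of the package), showing only the directed/undirected rendering rule; the real function returns a pydot.Dot object rather than a string.

```python
def edges_to_dot(edges, name="my_dotgraph", undirected=False):
    """Build a minimal DOT source string from (u, v) edge pairs."""
    # An undirected DOT graph uses "graph" and "--"; a directed one
    # uses "digraph" and "->".
    kind, arrow = ("graph", "--") if undirected else ("digraph", "->")
    lines = [f"{kind} {name} {{"]
    for u, v in edges:
        lines.append(f'  "{u}" {arrow} "{v}";')
    lines.append("}")
    return "\n".join(lines)

print(edges_to_dot([("a", "b"), ("b", "c")]))
```

Passing `undirected=True` drops all arrowheads, matching the `undirected` flag described above.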

values_distribution(values, threshold=None, **kwargs)[source]#

Plot density and cumulative density for a set of values.

Parameters:
  • values (array-like) – Values to plot.

  • threshold (float, optional) – Optional threshold marker for the ECDF plot.

  • **kwargs – Additional keyword arguments for customizing the plot.

correlation_matrix(corr_matrix, sorted_colnames=None, threshold=0.5, ax=None, **kwargs)[source]#

Plot a thresholded correlation matrix.

Parameters:
  • corr_matrix (pandas.DataFrame) – Correlation matrix.

  • sorted_colnames (list[str], optional) – Column ordering for the plot.

  • threshold (float, optional) – Minimum absolute correlation to display.

  • ax (matplotlib.axes.Axes, optional) – Axes to draw into.

  • **kwargs – Plot customization options (title, fontsize, fontname, xrot).

hierarchies(hierarchies, threshold=0.5, **kwargs)[source]#

Plot hierarchical clustering and its correlation matrix.

dag(graph, reference=None, root_causes=None, show_metrics=False, show_node_fill=True, title=None, ax=None, figsize=(5, 5), dpi=75, save_to_pdf=None, layout='dot', **kwargs)[source]#

Compare two graphs using dot.

dags(dags, ref_graph, titles, figsize=(15, 12), dpi=300)[source]#

Plots multiple directed acyclic graphs (DAGs) in a grid layout.

Parameters:
  • dags (list) – List of DAGs to plot.

  • ref_graph – Reference graph used for layout.

  • titles (list) – List of titles for each DAG.

  • figsize (tuple) – Figure size (default: (15, 12)).

  • dpi (int) – Dots per inch (default: 300).

Raises:

ValueError – If there are too many DAGs to plot.

Returns:

None

shap_values(shaps, **kwargs)[source]#
shap_discrepancies(shaps, target_name, threshold=100.0, regression_line=False, reduced=False, **kwargs)[source]#

Plot the discrepancies of the SHAP values.

deprecated_dags(graph, reference=None, names=None, figsize=(10, 5), dpi=75, save_to_pdf=None, **kwargs)[source]#

Compare two graphs using dot.

score_by_method(metrics, metric, methods, **kwargs)[source]#

Plot a metric distribution grouped by method.

scores_by_method(metrics, methods=None, title=None, pdf_filename=None, **kwargs)[source]#

Plot the metrics for all the experiments matching the input pattern

Parameters:
  • metrics (pandas DataFrame) – A DataFrame with the metrics for all the experiments

  • methods (list) – The list of methods to plot. If None, all the methods will be plotted. The methods included are ‘rex_mlp’, ‘rex_gbt’, ‘rex_intersection’ and ‘rex_union’.

  • title (str) – The title of the plot

  • pdf_filename (str) – The filename to save the plot to. If None, the plot will be displayed on screen, otherwise it will be saved to the specified filename.

  • ylim (tuple, optional) – The y-axis limits of the plot. Default is None.

  • dpi (int, optional) – Dots per inch. Default is 300.

score_by_subtype(metrics, score_name, methods=None, pdf_filename=None, **kwargs)[source]#

Plot score distributions across data subtypes.

combined_metrics(metrics, metric_types=None, title=None, acyclic=False, medians=False, pdf_filename=None)[source]#

Plot the metrics for all the experiments matching the input pattern

Parameters:
  • metrics (dict) – A dictionary with the metrics for all the experiments

  • title (str) – The title of the plot

  • acyclic (bool) – Whether to plot the metrics for the no_cycles graphs

  • medians (bool) – Whether to plot the median lines

Return type:

None

latex_table_by_datatype(df, method, metrics=None)[source]#
latex_table_by_method(df, methods=None, metric_names=None)[source]#

Utility functions for causalexplain (C) J. Renero, 2022, 2023, 2024, 2025

save_experiment(obj_name, folder, results, overwrite=False)[source]#

Creates a folder for the experiment and saves results. Results is a dictionary that will be saved as an opaque pickle. When the experiment is later loaded, the only parameter needed is the folder name.

Parameters:
  • obj_name (str) – the name to be given to the pickle file to be saved. If a file with that name already exists, a file with the same name and a numeric suffix will be generated.

  • folder (str) – a full path to the folder where the experiment is to be saved. If the folder does not exist it will be created.

  • results (obj) – the object to be saved as experiment. This is typically a dictionary with different items representing different parts of the experiment.

  • overwrite (bool) – If True, then the experiment is saved even if a file with the same name already exists. If False, then the experiment is saved only if a file with the same name does not exist.

Returns:

(str) The name under which the experiment has been saved

Return type:

str

load_experiment(obj_name, folder)[source]#

Loads a pickled experiment from the given folder. It is not necessary to add the “pickle” extension to the experiment name.
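The save/load round trip described above can be sketched with the standard pickle module. The helper names (`save_results`, `load_results`) and the fixed “.pickle” extension are illustrative, not the package's internals; in particular this sketch omits the collision handling that `save_experiment` performs.

```python
import os
import pickle
import tempfile

def save_results(obj_name, folder, results):
    """Create the folder if needed and pickle `results` into it."""
    os.makedirs(folder, exist_ok=True)
    path = os.path.join(folder, obj_name + ".pickle")
    with open(path, "wb") as fh:
        pickle.dump(results, fh)
    return path

def load_results(obj_name, folder):
    """Load a pickled experiment; the extension is added automatically."""
    with open(os.path.join(folder, obj_name + ".pickle"), "rb") as fh:
        return pickle.load(fh)

with tempfile.TemporaryDirectory() as tmp:
    save_results("exp1", tmp, {"score": 0.9})
    print(load_results("exp1", tmp))  # {'score': 0.9}
```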

valid_output_name(filename, path, extension=None)[source]#

Builds a valid file name. If a file with that name already exists, a number (1, 2, …) is appended until a filename is found that does not exist.

Parameters:
  • filename (str) – The base filename to be set.

  • path (str) – The path where the file is to be created.

  • extension (str) – The extension of the file, without the dot ‘.’. If no extension is specified, any extension is searched, to avoid returning the path of an existing file regardless of its extension.

Returns:

The filename if the name is valid and the file does not exist, None otherwise.

Return type:

str
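The collision-avoidance logic can be sketched as a loop over `os.path.exists`. The helper name `next_free_name` is illustrative, and this sketch always returns a usable path rather than None, so it is an approximation of the behaviour described above, not a reimplementation.

```python
import os
import tempfile

def next_free_name(filename, path, extension):
    """Append _1, _2, ... to `filename` until the path does not exist."""
    candidate = os.path.join(path, f"{filename}.{extension}")
    idx = 1
    while os.path.exists(candidate):
        candidate = os.path.join(path, f"{filename}_{idx}.{extension}")
        idx += 1
    return candidate

with tempfile.TemporaryDirectory() as tmp:
    open(os.path.join(tmp, "exp.pickle"), "w").close()  # simulate a clash
    print(next_free_name("exp", tmp, "pickle"))         # ends in exp_1.pickle
```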

graph_from_dot_file(dot_file)[source]#

Reads a dot file and returns a networkx DiGraph object.

Parameters:

dot_file (str | Path) – (str or Path) The full path to the dot file to be read.

Returns:

networkx.DiGraph or None if the file could not be read or parsed.

Return type:

DiGraph | None

graph_from_dictionary(d)[source]#

Builds a graph from a dictionary like {‘u’: [‘v’, ‘w’], ‘x’: [‘y’]}. The elements of the lists can also be tuples that include a weight.

Parameters:

d (dict) – A dictionary of the form {‘u’: [‘v’, ‘w’], ‘x’: [‘y’]}, where an edge between ‘u’ goes towards ‘v’ and ‘w’, and also an edge from ‘x’ goes towards ‘y’. The format can also be like: {‘u’: [(‘v’, 0.2), (‘w’, 0.7)], ‘x’: [(‘y’, 0.5)]}, where the values in the tuple are interpreted as weights.

Returns:

networkx.DiGraph with the nodes and edges specified.

Return type:

Graph | DiGraph
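The two accepted dictionary formats can be illustrated with a small stand-alone parser. This sketch (`edges_from_dict` is a hypothetical name) produces a flat edge list rather than a networkx graph, to show only the parsing rule: plain names become unweighted edges, and (name, weight) tuples become weighted ones.

```python
def edges_from_dict(d):
    """Turn {'u': ['v', ('w', 0.7)]} into (source, target, weight) triples."""
    edges = []
    for u, successors in d.items():
        for item in successors:
            if isinstance(item, tuple):   # ('v', 0.2) -> weighted edge
                v, w = item
                edges.append((u, v, w))
            else:                         # 'v' -> unweighted edge
                edges.append((u, item, None))
    return edges

print(edges_from_dict({"u": [("v", 0.2)], "x": ["y"]}))
# [('u', 'v', 0.2), ('x', 'y', None)]
```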

graph_from_adjacency(adjacency, node_labels=None, th=0.0, inverse=False, absolute_values=False)[source]#

Manually parses the adjacency matrix to build a graph.

Parameters:
  • adjacency (ndarray) – a numpy adjacency matrix

  • node_labels – an array with the same length as the number of columns of the adjacency matrix, containing the labels to use for every node.

  • th – (float) weight threshold for a cell to be considered a valid edge.

  • inverse (bool) – Set to True if rows in the adjacency matrix reflect where edges come from, rather than where they go to.

  • absolute_values (bool) – Take the absolute value of the weight to check whether it is greater than the threshold.

Returns:

The Graph (DiGraph)

Return type:

DiGraph
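The matrix-to-edges rule can be sketched without numpy or networkx: a cell (i, j) above the threshold becomes an edge i → j (or j → i when `inverse` is set), optionally comparing weights in absolute value. The helper name is illustrative.

```python
def edges_from_adjacency(matrix, labels, th=0.0, inverse=False,
                         absolute_values=False):
    """Turn a (nested-list) adjacency matrix into (u, v, weight) triples."""
    edges = []
    for i, row in enumerate(matrix):
        for j, w in enumerate(row):
            value = abs(w) if absolute_values else w
            if value > th:
                # `inverse` swaps the interpretation of rows and columns.
                u, v = (labels[j], labels[i]) if inverse else (labels[i], labels[j])
                edges.append((u, v, w))
    return edges

m = [[0.0, 0.8], [-0.5, 0.0]]
print(edges_from_adjacency(m, ["a", "b"], th=0.4, absolute_values=True))
# [('a', 'b', 0.8), ('b', 'a', -0.5)]
```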

graph_from_adjacency_file(file, labels=None, th=0.0, sep=',', header=True)[source]#

Read Adjacency matrix from a file and return a Graph

Parameters:
  • file (Path | str) – (str) the full path of the file to read

  • labels – (List[str]) the list of node names to be used. If None, the node names are extracted from the adj file. The names must be already sorted in the same order as the adjacency matrix.

  • th – (float) weight threshold to be considered a valid edge.

  • sep – (str) the separator used in the file

  • header (bool) – whether the file has a header. If True, then ‘infer’ is used to read the header.

Returns:

DiGraph, DataFrame

Return type:

Tuple[DiGraph, DataFrame]

graph_to_adjacency(graph, labels, weight_label='weight')[source]#

A method to generate the adjacency matrix of the graph. Labels are sorted for better readability.

Parameters:
  • graph (Graph | DiGraph) – (Union[Graph, DiGraph]) the graph to be converted.

  • labels – (List[str]) the list of node names to be used. If None, the node names are extracted from the graph. The names must already be sorted in the same order as the adjacency matrix.

  • weight_label (str) – the label used to identify the weights.

Returns:

A 2d array containing the adjacency matrix of the graph.

Return type:

numpy.ndarray
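The reverse direction of the previous helper can be sketched the same way: each edge weight is placed at (row, col) in a matrix indexed by sorted labels, with missing edges defaulting to 0. Plain nested lists stand in for the numpy array the real function returns; the helper name is illustrative.

```python
def adjacency_from_edges(edges, labels):
    """Build a dense adjacency matrix from (u, v, weight) triples."""
    labels = sorted(labels)  # labels are sorted for readability
    index = {name: i for i, name in enumerate(labels)}
    matrix = [[0.0] * len(labels) for _ in labels]
    for u, v, w in edges:
        matrix[index[u]][index[v]] = w
    return labels, matrix

names, adj = adjacency_from_edges([("b", "a", 0.7)], ["b", "a"])
print(names, adj)  # ['a', 'b'] [[0.0, 0.0], [0.7, 0.0]]
```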

graph_to_adjacency_file(graph, output_file, labels)[source]#

A method to write the adjacency matrix of the graph to a file. If graph has weights, these are the values stored in the adjacency matrix.

Parameters:
  • graph (Graph | DiGraph) – (Union[Graph, DiGraph]) the graph to be saved.

  • output_file (Path | str) – (str) The full path where graph is to be saved

graph_to_dot_file(graph, output_file)[source]#

A method to write the graph in dot format to a file.

Parameters:
  • graph (Graph | DiGraph) – (Union[Graph, DiGraph]) the graph to be saved.

  • output_file (Path | str) – (str) The full path where graph is to be saved

Returns:

(str) The path to the output DOT file.

Return type:

str

select_device(force=None)[source]#

Selects the device to be used for training. If force is not None, the device is forced to the one specified; otherwise the device is selected based on GPU availability, falling back to the CPU when no GPUs are available.

Parameters:

force (str) – If not None, then the device is forced to be the one specified. If None, then the device is selected based on the availability of GPUs. If no GPUs are available, then the CPU is selected.

Returns:

(str) The device to be used for training.

Raises:

ValueError – If the forced device is not available or not a valid device.

Return type:

str
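The selection policy can be sketched as a pure function over availability flags. The real function probes the ML backend for GPU support; here that probing is replaced by explicit booleans (`cuda_ok`, `mps_ok`, both hypothetical parameters) so only the fallback logic is shown.

```python
def pick_device(force=None, cuda_ok=False, mps_ok=False):
    """Pick 'cuda', 'mps' or 'cpu'; honour `force` if it is available."""
    available = {"cpu"}
    if cuda_ok:
        available.add("cuda")
    if mps_ok:
        available.add("mps")
    if force is not None:
        if force not in available:
            raise ValueError(f"Device {force!r} is not available")
        return force
    # No forced device: prefer GPU backends, fall back to CPU.
    if cuda_ok:
        return "cuda"
    if mps_ok:
        return "mps"
    return "cpu"

print(pick_device())              # 'cpu'
print(pick_device(cuda_ok=True))  # 'cuda'
```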

resolve_device(requested=None)[source]#

Resolve an explicit device request, defaulting to CPU.

Parameters:

requested (Optional[str]) – Requested device (“cpu”, “cuda”, “mps”).

Returns:

The resolved device string.

Return type:

str

Raises:

ValueError – If the requested device is invalid or unavailable.

graph_intersection(g1, g2)[source]#

Returns the intersection of two graphs. The intersection is defined as the set of nodes and edges that are common to both graphs. The intersection is performed on the nodes and edges, not on the attributes of the nodes and edges.

Parameters:
  • g1 (networkx.DiGraph) – The first graph.

  • g2 (networkx.DiGraph) – The second graph.

Returns:

(networkx.DiGraph) The intersection of the two graphs.

Return type:

DiGraph
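Since the intersection is defined on nodes and edges only (attributes are ignored), it reduces to two set intersections. This stand-alone sketch operates on plain sets instead of networkx graphs; the helper name is illustrative.

```python
def intersect(nodes1, edges1, nodes2, edges2):
    """Attribute-free graph intersection as plain set operations."""
    return set(nodes1) & set(nodes2), set(edges1) & set(edges2)

nodes, edges = intersect({"a", "b", "c"}, {("a", "b"), ("b", "c")},
                         {"a", "b"}, {("a", "b")})
print(sorted(nodes), sorted(edges))  # ['a', 'b'] [('a', 'b')]
```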

graph_union(g1, g2)[source]#

Returns the union of two graphs. The union is defined as the set of nodes and edges that appear in either graph. The union is performed on the nodes and edges, not on the attributes of the nodes and edges.

digraph_from_connected_features(X, feature_names, models, connections, root_causes, prior=None, reciprocity=True, anm_iterations=10, max_anm_samples=400, verbose=False)[source]#

Builds a directed graph from a set of features and their connections. The connections are determined by the SHAP values of the features, and the orientation of each edge is determined by the causal-direction method defined in the independence module.

correct_edge_from_prior(dag, u, v, prior, verbose)[source]#

Orient an edge using prior temporal knowledge.

valid_candidates_from_prior(feature_names, effect, prior)[source]#

This method returns the valid candidates for a given effect, based on the prior information. The prior is a list of lists, where each inner list contains the features known to be at the same level of the hierarchy.

break_cycles_using_prior(original_dag, prior, verbose=False)[source]#

This method removes edges in the DAG that are incompatible with the temporal relationship defined in the prior knowledge. Any edge pointing backwards, according to the order established in the prior, is removed.
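The pruning rule can be sketched as follows: with the prior read as ordered tiers of variables, any edge whose source sits in a later tier than its target points backwards in time and is dropped. The helper name and the example variables are illustrative.

```python
def drop_backward_edges(edges, prior):
    """Keep only edges consistent with the tier order given by `prior`."""
    # Map each variable to its tier index (0 = root causes).
    tier = {name: level for level, names in enumerate(prior)
            for name in names}
    return [(u, v) for u, v in edges if tier[u] <= tier[v]]

prior = [["season"], ["rain"], ["wet_grass"]]
edges = [("season", "rain"), ("wet_grass", "rain"), ("rain", "wet_grass")]
print(drop_backward_edges(edges, prior))
# [('season', 'rain'), ('rain', 'wet_grass')]
```

Note that edges within the same tier are kept here; only strictly backward-pointing edges are removed.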

potential_misoriented_edges(loop, discrepancies, verbose=False)[source]#

Find potential misoriented edges in a loop based on discrepancies. The loop is a list of nodes that form a loop in a DAG. The discrepancies is a dictionary containing the discrepancies between nodes. The function returns a list of potential misoriented edges, sorted by the difference in goodness-of-fit scores.

Parameters:
  • loop (List[str]) – The list of nodes in the loop.

  • discrepancies (Dict) – A dictionary containing discrepancies between nodes. It is typically a positive float number between 0 and 1.

  • verbose (bool, optional) – Whether to print verbose output. Defaults to False.

Returns:

A list of potential misoriented edges, sorted by the difference in goodness-of-fit scores.

Return type:

List[Tuple[str, str, float]]

break_cycles_if_present(dag, discrepancies, prior=None, verbose=False)[source]#

Break cycles in a DAG using discrepancies and optional priors.

stringfy_object(object_)[source]#

Convert an object into a string representation, including its attributes.

Parameters:

object_ (object) – The object to be converted.

Returns:

A string representation of the object and its attributes.

Return type:

str

get_feature_names(X)[source]#

Get the feature names from the input data. The feature names can be extracted from a pandas DataFrame, a numpy array or a list.

get_feature_types(X)[source]#

Get the feature types from the input data. The feature types can be binary, multiclass or continuous. The classification is done based on the number of unique values in the array. If the array has only two unique values, then the variable is classified as binary. If the unique values are integers, then the variable is classified as multiclass. Otherwise, the variable is classified as continuous.
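The classification rule above can be restated as code: exactly two unique values means binary, all-integer unique values means multiclass, anything else is continuous. This single-column sketch (`classify_feature` is a hypothetical name) illustrates the rule; the real function classifies every column of the input data.

```python
def classify_feature(values):
    """Classify one column as 'binary', 'multiclass' or 'continuous'."""
    unique = set(values)
    if len(unique) == 2:
        return "binary"
    if all(float(v).is_integer() for v in unique):
        return "multiclass"
    return "continuous"

print(classify_feature([0, 1, 0, 1]))     # 'binary'
print(classify_feature([1, 2, 3, 2]))     # 'multiclass'
print(classify_feature([0.1, 0.7, 0.3]))  # 'continuous'
```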

cast_categoricals_to_int(X)[source]#

Cast all categorical features in the dataframe to integer values.

Parameters:

X (pd.DataFrame) – The dataframe to cast.

Returns:

The dataframe with all categorical features cast to integer values.

Return type:

pd.DataFrame
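The cast can be sketched on a plain column instead of a pandas DataFrame: each distinct category gets a stable integer code. Sorting the categories first (an assumption of this sketch, not necessarily the package's encoding) makes the codes deterministic.

```python
def encode_categories(column):
    """Map each distinct category to an integer code, sorted for stability."""
    codes = {cat: i for i, cat in enumerate(sorted(set(column)))}
    return [codes[value] for value in column]

print(encode_categories(["low", "high", "low", "mid"]))
# [1, 0, 1, 2]   (high=0, low=1, mid=2)
```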

find_crossing_point(f1, f2, x_values)[source]#

This function finds the exact crossing point between two curves defined by f1 and f2 over x_values. It interpolates between points to provide a more accurate crossing point.
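The interpolation step can be sketched as a scan for a sign change of f1 - f2 between consecutive samples, solving the linear crossing inside that interval. The helper name is illustrative and only the first crossing is returned.

```python
def crossing_point(f1, f2, x_values):
    """Return the first x where the sampled curves f1 and f2 cross."""
    diff = [a - b for a, b in zip(f1, f2)]
    for i in range(len(diff) - 1):
        if diff[i] == 0:
            return x_values[i]        # exact crossing at a sample point
        if diff[i] * diff[i + 1] < 0:
            # Sign change: linearly interpolate between x_i and x_{i+1}.
            t = diff[i] / (diff[i] - diff[i + 1])
            return x_values[i] + t * (x_values[i + 1] - x_values[i])
    return None                       # the curves never cross

# f1(x) = x and f2(x) = 1 - x cross at x = 0.5.
print(crossing_point([0.0, 1.0], [1.0, 0.0], [0.0, 1.0]))  # 0.5
```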

format_time(seconds)[source]#

Convert the time in seconds to a more convenient time unit.

Parameters:

seconds (float) – The time in seconds.

Returns:

The time in the most appropriate unit and the unit string.

Return type:

(float, str)

Examples

>>> format_time(1.0)
(1.0, 's')
>>> format_time(60.0)
(1.0, 'm')
>>> format_time(3600.0)
(1.0, 'h')
>>> format_time(86400.0)
(1.0, 'd')
>>> format_time(604800.0)
(1.0, 'w')
>>> format_time(2592000.0)
(1.0, 'm')
>>> format_time(31536000.0)
(1.0, 'y')
>>> format_time(315360000.0)
(1.0, 'a')
combine_dags(dag1, dag2, discrepancies, prior=None)[source]#

Combine two directed acyclic graphs (DAGs) into a single DAG.

Parameters:
  • dag1 (nx.DiGraph) – The first DAG.

  • dag2 (nx.DiGraph) – The second DAG.

  • discrepancies (Dict) – A Dictionary containing the permutation importances for each edge in the DAGs.

  • prior (Optional[List[List[str]]], optional) – A list of lists containing the prior knowledge about the edges in the DAGs. The lists define a hierarchy of edges that represent a temporal relation in cause and effect. If a node is in the first list, then it is a root cause. If a node is in the second list, then it is caused by the nodes in the first list or the second list, and so on.

Returns:

A tuple containing four graphs: the union of the two DAGs, the intersection of the two DAGs, the union of the two DAGs after removing cycles, and the intersection of the two DAGs after removing cycles.

Return type:

Tuple[nx.DiGraph, nx.DiGraph, nx.DiGraph, nx.DiGraph]

list_files(input_pattern, where)[source]#

List all the files in the input path matching the input pattern

read_json_file(file_path)[source]#

Read a JSON file and return the prior list, if present.

pretty_print(obj, prefix='')[source]#

Return a pretty string representation of an object using pprint.

Module contents#