causalexplain.common package
Submodules
Configuration for metric types used across the causalexplain package.
A module to run experiments with the causalexplain package and to simplify loading and saving experiments in notebooks.
Example
>>> from causalexplain.common.notebook import Experiment
>>> experiment = Experiment("linear", csv_filename="linear.csv")
>>> rex = experiment.load()
2023, 2024 J. Renero
- class BaseExperiment(input_path, output_path, train_anyway=False, save_anyway=False, scale=False, train_size=0.9, random_state=42, verbose=False)
Bases: object
Base class for experiments.
Args:
- input_path (str): The path to the input data.
- output_path (str): The path where the experiment output is saved.
- train_anyway (bool, optional): Whether to train the model even if the experiment exists. Defaults to False.
- save_anyway (bool, optional): Whether to save the experiment even if it exists. Defaults to False.
- train_size (float, optional): The proportion of data to use for training. Defaults to 0.9.
- random_state (int, optional): The random state for reproducibility. Defaults to 42.
- verbose (bool, optional): Whether to display verbose output. Defaults to False.
Methods
create_estimator(estimator_name, name, **kwargs): Dynamically creates an instance of a class based on the estimator name.
experiment_exists(name): Checks whether the experiment exists in the output path.
prepare_experiment_input(experiment_filename)
- __init__(input_path, output_path, train_anyway=False, save_anyway=False, scale=False, train_size=0.9, random_state=42, verbose=False)
- prepare_experiment_input(experiment_filename, csv_filename=None, dot_filename=None)
Loads the data and then:
- splits it into train and test sets,
- scales it, and
- loads the reference graph from the DOT file, which must have the same name as the experiment file, with the .dot extension.
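The first two steps above (split, then scale) can be illustrated with a small standard-library sketch. This is not the package's code. The helper name `split_and_scale` is hypothetical, and rows are plain lists of floats rather than a DataFrame. The sketch also assumes that "scales it" means standardization using statistics computed on the training split only.

```python
import random
from statistics import mean, pstdev
from typing import List, Tuple

Rows = List[List[float]]

def split_and_scale(rows: Rows, train_size: float = 0.9,
                    random_state: int = 42) -> Tuple[Rows, Rows]:
    """Shuffle, split into train/test, then standardize every column."""
    rng = random.Random(random_state)  # seeded for reproducibility
    shuffled = list(rows)
    rng.shuffle(shuffled)
    n_train = int(round(len(shuffled) * train_size))
    train, test = shuffled[:n_train], shuffled[n_train:]

    # Assumption: scaling statistics come from the training split only.
    n_cols = len(rows[0])
    means = [mean(r[j] for r in train) for j in range(n_cols)]
    stds = [pstdev([r[j] for r in train]) or 1.0 for j in range(n_cols)]

    def scale(part: Rows) -> Rows:
        return [[(r[j] - means[j]) / stds[j] for j in range(n_cols)] for r in part]

    return scale(train), scale(test)
```

With the defaults, 20 rows yield 18 training rows and 2 test rows, and the same seed always gives the same split.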
- create_estimator(estimator_name, name, **kwargs)
Dynamically creates an instance of a class based on the estimator name.
Args:
- estimator_name (str): The name of the estimator (key in the ‘estimators’ dictionary).
- name (str): The name of the estimator instance.
- **kwargs: Arbitrary keyword arguments to be passed to the class constructor.
Returns: An instance of the specified class, or None if the class does not exist.
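Here is a minimal sketch of the name-to-class lookup that create_estimator describes. It assumes the ‘estimators’ dictionary maps names to classes. The registry `ESTIMATORS`, the class `LinearStub` and the function `create_estimator_sketch` are hypothetical illustrations, not the package's own objects.

```python
from typing import Any, Dict, Optional

class LinearStub:
    """Hypothetical estimator used only for this illustration."""

    def __init__(self, name: str, **kwargs: Any) -> None:
        self.name = name
        self.params = kwargs

# Hypothetical registry: maps estimator names to classes.
ESTIMATORS: Dict[str, type] = {"linear_stub": LinearStub}

def create_estimator_sketch(estimator_name: str, name: str, **kwargs: Any) -> Optional[Any]:
    """Look up a class by name and instantiate it, or return None if unknown."""
    cls = ESTIMATORS.get(estimator_name)
    if cls is None:
        return None
    return cls(name=name, **kwargs)
```

Returning None instead of raising mirrors the documented contract. Callers therefore need to check the result.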
- class Experiment(experiment_name, csv_filename=None, dot_filename=None, model_type='nn', input_path='/Users/renero/phd/data/RC4/', output_path='/Users/renero/phd/output/RC4/', train_size=0.9, random_state=42, verbose=False)
Bases: BaseExperiment
Represents an experiment for causal graph analysis.
- Attributes:
- estimator_name
Methods
create_estimator(estimator_name, name, **kwargs): Dynamically creates an instance of a class based on the estimator name.
experiment_exists(name): Checks whether the experiment exists in the output path.
fit([estimator_name]): Fits the experiment data.
fit_predict([estimator]): Fits and predicts with the experiment data.
load([exp_name]): Loads the experiment data.
predict([estimator]): Predicts with the experiment data.
prepare_experiment_input(experiment_filename)
save([exp_name, overwrite]): Saves the experiment data.
- estimator_name = None
- __init__(experiment_name, csv_filename=None, dot_filename=None, model_type='nn', input_path='/Users/renero/phd/data/RC4/', output_path='/Users/renero/phd/output/RC4/', train_size=0.9, random_state=42, verbose=False)
Initializes a new instance of the Experiment class.
- Parameters:
experiment_name (str) – The name of the experiment.
csv_filename (str, optional) – The filename of the CSV file containing the data. Defaults to None.
dot_filename (str, optional) – The filename of the DOT file containing the causal graph. Defaults to None.
model_type (str, optional) – The type of model to use. Defaults to ‘nn’. Other options are: ‘gbt’, ‘cam’, ‘pc’, ‘fci’, ‘notears’, ‘ges’ and ‘lingam’.
input_path (str, optional) – The path to the input data. Defaults to “/Users/renero/phd/data/RC4/”.
output_path (str, optional) – The path to save the output. Defaults to “/Users/renero/phd/output/RC4/”.
train_size (float, optional) – The proportion of data to use for training. Defaults to 0.9.
random_state (int, optional) – The random seed for reproducibility. Defaults to 42.
verbose (bool, optional) – Whether to print verbose output. Defaults to False.
- fit(estimator_name='rex', **kwargs)
Fits the experiment data.
- Parameters:
**kwargs – Additional keyword arguments to pass to the Rex constructor.
- Returns:
The fitted experiment data.
- predict(estimator='rex', **kwargs)
Predicts with the experiment data.
- Parameters:
**kwargs – Additional keyword arguments to pass to the predict() method.
- Returns:
The fitted experiment data.
- fit_predict(estimator='rex', **kwargs)
Fits and predicts with the experiment data.
- Parameters:
**kwargs – Additional keyword arguments to pass to the Rex constructor.
- Returns:
The fitted experiment data.
This module includes all the plotting methods for the causal graph.
Renero, 2022, 2023
- format_graph(G, Gt=None, ok_color='green', inv_color='lightgreen', wrong_color='black', missing_color=None)
- draw_graph_subplot(G, root_causes=None, layout=None, title=None, ax=None, **kwargs)
Draw a graph in a subplot.
- set_colormap(color_threshold=0.15, max_color=0.8, cmap_name='OrRd')
Set the colormap for the graph edges.
- dag2dot(G, undirected=False, name='my_dotgraph', odots=True)
Displays a DOT representation of the graph in the notebook.
- Parameters:
G (nx.Graph or DiGraph) – The graph to be represented.
undirected (bool) – Default False; indicates whether the plot is forced to contain no arrows.
name (str) – The name to be embedded in the Dot object for this graph.
odots (bool) – Represent bidirectional edges (biconnections) with circles (odots). If this is set to False, the edge simply has no arrowheads.
- Returns:
pydot.Dot object
- Return type:
Dot
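The odots behaviour can be illustrated by emitting DOT text directly. This sketch avoids pydot and networkx and works on a plain edge list. The function `edges_to_dot` is hypothetical. The Graphviz attributes it uses (`dir=both` with `arrowhead=odot`/`arrowtail=odot`, or `dir=none`) are an assumption about how biconnections might be drawn, not a copy of dag2dot.

```python
from typing import Iterable, Set, Tuple

def edges_to_dot(edges: Iterable[Tuple[str, str]], name: str = "my_dotgraph",
                 undirected: bool = False, odots: bool = True) -> str:
    """Render a directed edge list as Graphviz DOT text."""
    edge_set: Set[Tuple[str, str]] = set(edges)
    emitted: Set[Tuple[str, str]] = set()
    lines = [f'digraph "{name}" {{']
    for u, v in sorted(edge_set):
        if (v, u) in emitted:
            continue  # the reverse edge was already drawn as a biconnection
        emitted.add((u, v))
        if undirected:
            attrs = " [dir=none]"
        elif (v, u) in edge_set:
            # u <-> v: one edge with circles at both ends, or no arrowheads
            attrs = " [dir=both, arrowhead=odot, arrowtail=odot]" if odots else " [dir=none]"
        else:
            attrs = ""
        lines.append(f'  "{u}" -> "{v}"{attrs};')
    lines.append("}")
    return "\n".join(lines)
```

For example, `edges_to_dot([("a", "b"), ("b", "a"), ("b", "c")])` draws a single circled edge between a and b plus a plain arrow from b to c.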
- values_distribution(values, threshold=None, **kwargs)
Plot the probability density and cumulative density of a given set of values.
- Parameters:
values (array-like) – The values to be plotted.
**kwargs – Additional keyword arguments for customizing the plot.
- Return type:
None
- correlation_matrix(corr_matrix, sorted_colnames=None, threshold=0.5, ax=None, **kwargs)
Plot the correlation matrix of the data.
- Parameters:
corr_matrix (pd.DataFrame) – Correlation matrix.
sorted_colnames (List[str]) – List of sorted column names. If the DataFrame already contains the column names sorted, this argument is not needed.
threshold (float) – Threshold for the correlation. Values below this threshold are not displayed.
ax (matplotlib.axes.Axes) – Axes on which to plot the correlation matrix, when the plot is embedded in a subplot. Otherwise, a new figure is created and this argument is not necessary.
**kwargs – Keyword arguments to be passed to the plot_dendogram function:
- title (str): Title of the plot.
- fontsize (int): Font size for the labels.
- fontname (str): Font name for the labels.
- xrot (int): Rotation of the labels.
- Return type:
None
- hierarchies(hierarchies, threshold=0.5, **kwargs)
Plot the hierarchical clustering and correlation matrix of the data.
https://www.kaggle.com/code/sgalella/correlation-heatmaps-with-hierarchical-clustering/notebook
- Parameters:
hierarchies (HierarchicalClustering) – Hierarchical clustering object.
threshold (float) – Threshold for the correlation.
**kwargs – Additional keyword arguments to be passed to the correlation_matrix function.
- Return type:
None
- dag(graph, reference=None, root_causes=None, show_metrics=False, show_node_fill=True, title=None, ax=None, figsize=(5, 5), dpi=75, save_to_pdf=None, layout='dot', **kwargs)
Compare two graphs using dot.
- dags(dags, ref_graph, titles, figsize=(15, 12), dpi=300)
Plots multiple directed acyclic graphs (DAGs) in a grid layout.
Parameters:
- dags (list): List of DAGs to plot.
- ref_graph: Reference graph used for layout.
- titles (list): List of titles for each DAG.
- figsize (tuple): Figure size (default: (15, 12)).
- dpi (int): Dots per inch (default: 300).
Raises:
- ValueError: If there are too many DAGs to plot.
Returns: None
- shap_discrepancies(shaps, target_name, threshold=100.0, regression_line=False, reduced=False, **kwargs)
Plot the discrepancies of the SHAP values.
- deprecated_dags(graph, reference=None, names=None, figsize=(10, 5), dpi=75, save_to_pdf=None, **kwargs)
Compare two graphs using dot.
- score_by_method(metrics, metric, methods, **kwargs)
Plots the score by method.
Parameters:
- metrics: DataFrame containing the metrics data.
- metric: The metric to be plotted.
- methods: List of methods to be included in the plot.
- **kwargs: Additional keyword arguments for customization, such as:
  - figsize: The size of the figure. Default is (4, 3).
  - dpi: The resolution of the figure in dots per inch. Default is 300.
  - title: The title of the plot. Default is None.
  - pdf_filename: The filename to save the plot to. If None, the plot is displayed on screen; otherwise it is saved to the specified filename.
  - method_column: The name of the column containing the method names. Default is ‘method’.
Returns: None
- scores_by_method(metrics, methods=None, title=None, pdf_filename=None, **kwargs)
Plot the metrics for all the experiments matching the input pattern.
- Parameters:
metrics (pandas DataFrame) – A DataFrame with the metrics for all the experiments.
methods (list) – The list of methods to plot. If None, all the methods are plotted. The methods included are: ‘rex_mlp’, ‘rex_gbt’, ‘rex_intersection’ and ‘rex_union’.
title (str) – The title of the plot.
pdf_filename (str) – The filename to save the plot to. If None, the plot is displayed on screen; otherwise it is saved to the specified filename.
Optional parameters:
ylim (tuple, optional) – The y-axis limits of the plot. Default is None.
dpi (int, optional) – Default is 300.
(tuple, optional)
- score_by_subtype(metrics, score_name, methods=None, pdf_filename=None, **kwargs)
Plots the score by subtype.
Parameters:
- metrics (pandas DataFrame): The metrics for all the experiments, stored one experiment per row. This DataFrame contains the following columns:
  method (str): The name of the method used.
  data_type (str): The type of data used.
  f1 (float): The F1 score.
  precision (float): The precision score.
  recall (float): The recall score.
  aupr (float): The area under the precision-recall curve.
  Tp (int): The number of true positives.
  Tn (int): The number of true negatives.
  Fp (int): The number of false positives.
  Fn (int): The number of false negatives.
  shd (int): The structural Hamming distance.
  sid (int): The structural intervention distance.
  n_edges (int): The number of edges in the graph.
  ref_n_edges (int): The number of edges in the reference graph.
  diff_edges (int): The difference between the number of edges in the graph and the reference graph.
  name (str): The name of the experiment.
- score_name (str): The name of the score to plot. Valid names are ‘f1’, ‘precision’, ‘recall’, ‘aupr’, ‘shd’, ‘sid’, ‘n_edges’, ‘ref_n_edges’ and ‘diff_edges’.
- methods (list, optional): The list of methods to plot. If None, all the methods are plotted. The methods included are: ‘rex_intersection’, ‘rex_union’, ‘pc’, ‘fci’, ‘ges’, ‘lingam’.
- pdf_filename (str, optional): The filename to save the plot to. If None, the plot is displayed on screen; otherwise it is saved to the specified filename.
Optional parameters:
- figsize (tuple, optional): The size of the figure. Default is (2, 1).
- dpi (int, optional): The resolution of the figure in dots per inch. Default is 300.
- method_column (str, optional): The name of the column in the metrics DataFrame that contains the method name. Default is ‘method’.
Returns: None
- combined_metrics(metrics, metric_types=None, title=None, acyclic=False, medians=False, pdf_filename=None)
Plot the metrics for all the experiments matching the input pattern.
Utility functions for causalexplain.
(C) J. Renero, 2022, 2023, 2024
- save_experiment(obj_name, folder, results, overwrite=False)
Creates a folder for the experiment and saves the results. The results, typically a dictionary, are saved as an opaque pickle. When the experiment needs to be loaded, the only parameter needed is the folder name.
- Parameters:
obj_name (str) – The name to be given to the pickle file. If a file with that name already exists, a file with the same name and an extension is generated.
folder (str) – A full path to the folder where the experiment is to be saved. If the folder does not exist, it is created.
results (obj) – The object to be saved as the experiment. This is typically a dictionary whose items represent different parts of the experiment.
overwrite (bool) – If True, the experiment is saved even if a file with the same name already exists. If False, the experiment is saved only if no file with the same name exists.
- Returns:
(str) The name under which the experiment has been saved.
- load_experiment(obj_name, folder)
Loads a pickle from a given folder name. It is not necessary to add the “pickle” extension to the experiment name.
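Here is a round-trip sketch of the save/load pair using the standard `pickle` module. The function names are hypothetical. The collision rule (appending `_1`, `_2`, … when `overwrite` is False) is an assumption modelled on valid_output_name, and the real functions may behave differently.

```python
import os
import pickle
from typing import Any

def save_experiment_sketch(obj_name: str, folder: str, results: Any,
                           overwrite: bool = False) -> str:
    """Pickle `results` into `folder`; return the name actually used."""
    os.makedirs(folder, exist_ok=True)  # create the folder if missing
    name, counter = obj_name, 1
    if not overwrite:
        # Assumed collision rule: append _1, _2, ... until the name is free.
        while os.path.exists(os.path.join(folder, f"{name}.pickle")):
            name = f"{obj_name}_{counter}"
            counter += 1
    with open(os.path.join(folder, f"{name}.pickle"), "wb") as fh:
        pickle.dump(results, fh)
    return name

def load_experiment_sketch(obj_name: str, folder: str) -> Any:
    """Load a pickle; the '.pickle' extension is optional in `obj_name`."""
    if not obj_name.endswith(".pickle"):
        obj_name += ".pickle"
    with open(os.path.join(folder, obj_name), "rb") as fh:
        return pickle.load(fh)
```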
- valid_output_name(filename, path, extension=None)
Builds a valid name. If a file with that name already exists, it appends a number (1, 2, …) until it finds a filename that does not exist.
- Parameters:
filename (str) – The base filename to be set.
path (str) – The path where the filename is to be set.
extension (str) – The extension of the file, without the dot ‘.’. If no extension is specified, files with any extension are checked, so that the path of an existing file is never returned, whatever its extension.
- Returns:
The filename if the name is valid and the file does not exist; None otherwise.
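The numbering rule of valid_output_name can be sketched with `os.path`. This is a hypothetical helper, and the `_1`, `_2` suffix format is assumed. Unlike the documented function, it always returns a free path instead of None.

```python
import os
from typing import Optional

def valid_output_name_sketch(filename: str, path: str,
                             extension: Optional[str] = None) -> str:
    """Return a path in `path` whose name does not clash with existing files."""
    def taken(stem: str) -> bool:
        if extension is not None:
            return os.path.exists(os.path.join(path, f"{stem}.{extension}"))
        # No extension given: any file sharing the stem counts as a clash.
        return any(os.path.splitext(f)[0] == stem for f in os.listdir(path))

    candidate, counter = filename, 1
    while taken(candidate):
        candidate = f"{filename}_{counter}"
        counter += 1
    suffix = f".{extension}" if extension is not None else ""
    return os.path.join(path, candidate + suffix)
```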
- graph_from_dictionary(d)
Builds a graph from a dictionary like {'u': ['v', 'w'], 'x': ['y']}. The elements of the lists can be tuples that include a weight.
- Parameters:
d (dict) – A dictionary of the form {'u': ['v', 'w'], 'x': ['y']}, where edges go from 'u' to 'v' and 'w', and from 'x' to 'y'. The format can also be {'u': [('v', 0.2), ('w', 0.7)], 'x': [('y', 0.5)]}, where the second value in each tuple is interpreted as the edge weight.
- Returns:
networkx.DiGraph with the nodes and edges specified.
- Return type:
Graph | DiGraph
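The two accepted dictionary formats can be normalized into weighted edge triples. This standard-library sketch stands in for a `networkx.DiGraph`. The name `edges_from_dictionary` and the default weight of 1.0 for unweighted entries are assumptions.

```python
from typing import Dict, List, Set, Tuple, Union

Target = Union[str, Tuple[str, float]]

def edges_from_dictionary(d: Dict[str, List[Target]]) -> Set[Tuple[str, str, float]]:
    """Turn {'u': ['v', ('w', 0.7)]} into {(source, target, weight), ...}."""
    edges: Set[Tuple[str, str, float]] = set()
    for source, targets in d.items():
        for target in targets:
            if isinstance(target, tuple):
                node, weight = target          # weighted form: ('v', 0.2)
            else:
                node, weight = target, 1.0     # assumed default weight
            edges.add((source, node, float(weight)))
    return edges
```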
- graph_from_adjacency(adjacency, node_labels=None, th=0.0, inverse=False, absolute_values=False)
Manually parses the adjacency matrix to build a graph.
- Parameters:
adjacency (ndarray) – A NumPy adjacency matrix.
node_labels – An array with the same length as the number of columns in the adjacency matrix, containing the label to use for each node.
th – (float) Weight threshold for an edge to be considered valid.
inverse (bool) – Set to True if the rows of the adjacency matrix reflect where edges come from, instead of where they go to.
absolute_values (bool) – Take the absolute value of the weight when checking whether it is greater than the threshold.
- Returns:
The Graph (DiGraph)
- Return type:
DiGraph
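The sketch below shows the `th`, `inverse` and `absolute_values` options, returning (source, target, weight) triples instead of a DiGraph. The helper is hypothetical and follows a literal reading of the `inverse` description. By default, entry [i][j] is taken as an edge from node j into node i, and `inverse=True` flips this to an edge from i to j.

```python
from typing import List, Optional, Sequence, Set, Tuple

Edge = Tuple[str, str, float]

def edges_from_adjacency(adjacency: Sequence[Sequence[float]],
                         node_labels: Optional[List[str]] = None,
                         th: float = 0.0, inverse: bool = False,
                         absolute_values: bool = False) -> Set[Edge]:
    n = len(adjacency)
    labels = node_labels if node_labels is not None else [str(i) for i in range(n)]
    edges: Set[Edge] = set()
    for i in range(n):
        for j in range(n):
            weight = adjacency[i][j]
            value = abs(weight) if absolute_values else weight
            if value <= th:
                continue  # not greater than the threshold: no edge
            # Literal reading of `inverse`: by default row i is where the edge
            # goes to; with inverse=True row i is where it comes from.
            if inverse:
                edges.add((labels[i], labels[j], weight))
            else:
                edges.add((labels[j], labels[i], weight))
    return edges
```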
- graph_from_adjacency_file(file, labels=None, th=0.0, sep=',', header=True)
Reads an adjacency matrix from a file and returns a graph.
- Parameters:
labels – (List[str]) The list of node names to be used. If None, the node names are extracted from the adjacency file. The names must already be sorted in the same order as the adjacency matrix.
th – (float) Weight threshold for an edge to be considered valid.
sep – (str) The separator used in the file.
header (bool) – Whether the file has a header. If True, then ‘infer’ is used to read the header.
- Returns:
DiGraph, DataFrame
- graph_to_adjacency(graph, labels, weight_label='weight')
Generates the adjacency matrix of the graph. Labels are sorted for better readability.
- Parameters:
graph (Graph | DiGraph) – The graph to be converted.
labels (List[str]) – The list of node names to be used. If None, the node names are extracted from the graph. The names must already be sorted in the same order as the adjacency matrix.
weight_label (str) – The label used to identify the weights.
- Returns:
- (numpy.ndarray) A 2d array containing the adjacency matrix of
the graph.
- Return type:
graph
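The reverse direction, from weighted edges to a matrix with sorted labels, can be sketched as follows. The row-is-source convention and the dict-of-edges input are assumptions of this hypothetical helper.

```python
from typing import Dict, List, Tuple

def adjacency_from_edges(edges: Dict[Tuple[str, str], float],
                         labels: List[str]) -> List[List[float]]:
    ordered = sorted(labels)  # sorted for readability
    index = {name: k for k, name in enumerate(ordered)}
    matrix = [[0.0] * len(ordered) for _ in ordered]
    for (source, target), weight in edges.items():
        # Assumed convention: row = source node, column = target node.
        matrix[index[source]][index[target]] = weight
    return matrix
```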
- graph_to_adjacency_file(graph, output_file, labels)
Writes the adjacency matrix of the graph to a file. If the graph has weights, these are the values stored in the adjacency matrix.
- select_device(force=None)
Selects the device to be used for training. If force is not None, then the device is forced to be the one specified. If force is None, then the device is selected based on the availability of GPUs. If no GPUs are available, then the CPU is selected.
- Parameters:
force (str) – If not None, then the device is forced to be the one specified. If None, then the device is selected based on the availability of GPUs. If no GPUs are available, then the CPU is selected.
- Returns:
(str) The device to be used for training.
- Raises:
ValueError – If the forced device is not available or not a valid device.
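The selection order can be sketched as a pure function. The availability flags are parameters here so that the logic is testable. In a PyTorch setting they might come from `torch.cuda.is_available()` and `torch.backends.mps.is_available()`. The device names and the MPS fallback are assumptions, not the package's documented list.

```python
from typing import Optional

KNOWN_DEVICES = ("cpu", "cuda", "mps")  # assumed set of valid names

def select_device_sketch(force: Optional[str] = None, cuda_available: bool = False,
                         mps_available: bool = False) -> str:
    """Pick a training device; honour `force` but validate it first."""
    available = {"cpu": True, "cuda": cuda_available, "mps": mps_available}
    if force is not None:
        if force not in KNOWN_DEVICES:
            raise ValueError(f"Unknown device: {force!r}")
        if not available[force]:
            raise ValueError(f"Device {force!r} is not available")
        return force
    # No forced device: prefer a GPU backend, fall back to the CPU.
    for device in ("cuda", "mps"):
        if available[device]:
            return device
    return "cpu"
```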
- graph_intersection(g1, g2)
Returns the intersection of two graphs. The intersection is defined as the set of nodes and edges that are common to both graphs. The intersection is performed on the nodes and edges, not on the attributes of the nodes and edges.
- Parameters:
g1 (networkx.DiGraph) – The first graph.
g2 (networkx.DiGraph) – The second graph.
- Returns:
(networkx.DiGraph) The intersection of the two graphs.
- Return type:
DiGraph
- graph_union(g1, g2)
Returns the union of two graphs. The union is defined as the set of nodes and edges that are in either of the two graphs. The union is performed on the nodes and edges, not on the attributes of the nodes and edges.
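Set operations capture the intersection and union definitions above when a graph is reduced to a pair (node set, edge set). This representation and both helper names are hypothetical. Attributes are ignored, matching the docstrings.

```python
from typing import Set, Tuple

# A graph reduced to (node set, set of directed edges); attributes are ignored.
Graph = Tuple[Set[str], Set[Tuple[str, str]]]

def graph_intersection_sketch(g1: Graph, g2: Graph) -> Graph:
    """Nodes and edges present in both graphs."""
    return g1[0] & g2[0], g1[1] & g2[1]

def graph_union_sketch(g1: Graph, g2: Graph) -> Graph:
    """Nodes and edges present in at least one of the graphs."""
    return g1[0] | g2[0], g1[1] | g2[1]
```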
- digraph_from_connected_features(X, feature_names, models, connections, root_causes, prior=None, reciprocity=True, anm_iterations=10, max_anm_samples=400, verbose=False)
Builds a directed graph from a set of features and their connections. The connections are determined by the SHAP values of the features. The orientation of each edge, which reflects the causal direction between the features, is determined by the edge-orientation method defined in the independence module.
- correct_edge_from_prior(dag, u, v, prior, verbose)
Corrects an edge in the graph according to prior knowledge. This function corrects the orientation of an edge in a directed acyclic graph (DAG) based on prior knowledge. The prior knowledge is a list of lists of node names, ordered according to a temporal structure. The assumption is that nodes from the first layer of the temporal structure cannot receive any incoming edge.
The function checks the relative positions of the two nodes in the prior knowledge and:
- If both nodes are in the top list, it removes the edge, based on the assumption above.
- If one node is before the other, it adds a new edge in the correct direction.
- If the nodes are not in a clear order, it leaves the edge unchanged.
- If the edge reflects a backward connection between nodes from different temporal layers, it removes the edge.
- Returns:
orientation – The orientation of the edge: -1 if the edge is reversed, +1 if the edge is removed or kept, and 0 if the edge orientation is not clear.
- valid_candidates_from_prior(feature_names, effect, prior)
This method returns the valid candidates for a given effect, based on the prior information. The prior information is a list of lists, where each list contains the features that are known to be in the same level of the hierarchy.
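Here is a sketch of the layer rule. It assumes a feature may be caused only by features in the same or an earlier prior layer, as the break_cycles_if_present description states. It also assumes first-layer nodes receive no incoming edges, as correct_edge_from_prior states. The helper name and the handling of features missing from the prior are assumptions.

```python
from typing import List

def valid_candidates_sketch(feature_names: List[str], effect: str,
                            prior: List[List[str]]) -> List[str]:
    layer = {name: k for k, level in enumerate(prior) for name in level}
    if effect not in layer:
        # Assumption: with no prior about the effect, any other feature qualifies.
        return [f for f in feature_names if f != effect]
    if layer[effect] == 0:
        return []  # first-layer nodes cannot receive incoming edges
    return [f for f in feature_names
            if f != effect and f in layer and layer[f] <= layer[effect]]
```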
- break_cycles_using_prior(original_dag, prior, verbose=False)
This method removes edges from the DAG that are incompatible with the temporal relationship defined in the prior knowledge. Any edge pointing backwards, according to the order established in the prior knowledge, is removed.
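The backward-edge rule can be expressed with a layer index per node. This hypothetical helper works on a set of (source, target) pairs. It keeps any edge whose endpoints are not both listed in the prior.

```python
from typing import List, Set, Tuple

Edge = Tuple[str, str]

def remove_backward_edges(edges: Set[Edge], prior: List[List[str]]) -> Set[Edge]:
    """Drop every edge whose source sits in a later prior layer than its target."""
    layer = {name: k for k, level in enumerate(prior) for name in level}
    kept: Set[Edge] = set()
    for u, v in edges:
        if u in layer and v in layer and layer[u] > layer[v]:
            continue  # points backwards in time
        kept.add((u, v))
    return kept
```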
- potential_misoriented_edges(loop, discrepancies, verbose=False)
Finds potential misoriented edges in a loop based on discrepancies. The loop is a list of nodes that form a loop in the graph, and discrepancies is a dictionary containing the discrepancies between nodes. The function returns a list of potential misoriented edges, sorted by the difference in goodness-of-fit scores.
- Returns:
A list of potential misoriented edges, sorted by the difference in goodness-of-fit scores.
- break_cycles_if_present(dag, discrepancies, prior=None, verbose=False)
Breaks cycles in a graph that is meant to be a directed acyclic graph (DAG) by removing the edge with the lowest goodness of fit (R2). If there are multiple cycles, they are all traversed and fixed. If prior is set, the cycles are broken using the prior knowledge.
Parameters:
- dag (nx.DiGraph): The DAG to break cycles in.
- discrepancies (pd.DataFrame): A DataFrame containing the permutation importances for each edge in the DAG.
- prior (List[List[str]]): A list of lists containing the prior knowledge about the edges in the DAG. The lists define a hierarchy of edges that represent a temporal relation in cause and effect. If a node is in the first list, then it is a root cause. If a node is in the second list, then it is caused by the nodes in the first list or the second list, and so on.
Returns:
- dag (nx.DiGraph): The DAG with cycles broken.
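The cycle-breaking loop can be sketched with a depth-first search over a set of edges. The `r2` mapping from edge to goodness-of-fit score is a hypothetical stand-in for the discrepancies argument, whose exact structure the docstring does not define. The prior-based branch is omitted.

```python
from typing import Dict, List, Optional, Set, Tuple

Edge = Tuple[str, str]

def find_cycle(edges: Set[Edge]) -> Optional[List[Edge]]:
    """Return the edges of one directed cycle, or None if the graph is acyclic."""
    successors: Dict[str, List[str]] = {}
    for u, v in sorted(edges):
        successors.setdefault(u, []).append(v)
    state: Dict[str, int] = {}  # 1 = on the DFS stack, 2 = fully explored
    stack: List[str] = []

    def dfs(node: str) -> Optional[List[Edge]]:
        state[node] = 1
        stack.append(node)
        for nxt in successors.get(node, []):
            if state.get(nxt) == 1:  # back edge closes a cycle
                path = stack[stack.index(nxt):] + [nxt]
                return list(zip(path[:-1], path[1:]))
            if nxt not in state:
                found = dfs(nxt)
                if found is not None:
                    return found
        stack.pop()
        state[node] = 2
        return None

    for start in sorted(successors):
        if start not in state:
            cycle = dfs(start)
            if cycle is not None:
                return cycle
    return None

def break_cycles_sketch(edges: Set[Edge], r2: Dict[Edge, float]) -> Set[Edge]:
    """Repeatedly drop the lowest-R2 edge of a remaining cycle until none is left."""
    remaining = set(edges)
    cycle = find_cycle(remaining)
    while cycle is not None:
        weakest = min(cycle, key=lambda e: r2.get(e, 0.0))
        remaining.discard(weakest)
        cycle = find_cycle(remaining)
    return remaining
```

Re-running the search after each removal keeps the sketch simple. The cost is repeated traversals, which is acceptable for the small graphs typical of these experiments.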
- stringfy_object(object_)
Convert an object into a string representation, including its attributes.
- get_feature_names(X)
Get the feature names from the input data. The feature names can be extracted from a pandas DataFrame, a numpy array or a list.
- get_feature_types(X)
Get the feature types from the input data. The feature types can be binary, multiclass or continuous. The classification is done based on the number of unique values in the array. If the array has only two unique values, then the variable is classified as binary. If the unique values are integers, then the variable is classified as multiclass. Otherwise, the variable is classified as continuous.
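The classification rule can be written directly as code. In this sketch, columns are given as a plain dict of sequences rather than a DataFrame, and "integer" is interpreted as `float(v).is_integer()`. Both choices are assumptions, and the helper names are hypothetical.

```python
from typing import Dict, Sequence

def feature_type(values: Sequence[float]) -> str:
    unique = set(values)
    if len(unique) == 2:
        return "binary"
    if all(float(v).is_integer() for v in unique):
        return "multiclass"
    return "continuous"

def get_feature_types_sketch(columns: Dict[str, Sequence[float]]) -> Dict[str, str]:
    return {name: feature_type(vals) for name, vals in columns.items()}
```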
- cast_categoricals_to_int(X)
Cast all categorical features in the dataframe to integer values.
- Parameters:
X (pd.DataFrame) – The dataframe to cast.
- Returns:
The dataframe with all categorical features cast to integer values.
- Return type:
pd.DataFrame
- find_crossing_point(f1, f2, x_values)
This function finds the exact crossing point between two curves defined by f1 and f2 over x_values. It interpolates between points to provide a more accurate crossing point.
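Here is a sketch of the interpolation step. It assumes `f1` and `f2` are sampled values on the same `x_values` grid, and that interpolation is linear between the two samples that bracket a sign change of `f1 - f2`. It returns the first crossing found, or None. The helper name is hypothetical.

```python
from typing import Optional, Sequence

def find_crossing_point_sketch(f1: Sequence[float], f2: Sequence[float],
                               x_values: Sequence[float]) -> Optional[float]:
    for k in range(len(x_values) - 1):
        d0 = f1[k] - f2[k]
        d1 = f1[k + 1] - f2[k + 1]
        if d0 == 0:
            return x_values[k]  # the curves meet exactly at a sample
        if d0 * d1 < 0:
            # Root of the straight line through (x0, d0) and (x1, d1).
            x0, x1 = x_values[k], x_values[k + 1]
            return x0 + (x1 - x0) * d0 / (d0 - d1)
    if f1[-1] == f2[-1]:
        return x_values[-1]
    return None
```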
- format_time(seconds)
Converts the time in seconds to a more convenient time unit.
- Parameters:
seconds (float) – The time in seconds.
- Returns:
The time in the most appropriate unit and the unit string.
Examples
>>> format_time(1.0)
(1.0, 's')
>>> format_time(60.0)
(1.0, 'm')
>>> format_time(3600.0)
(1.0, 'h')
>>> format_time(86400.0)
(1.0, 'd')
>>> format_time(604800.0)
(1.0, 'w')
>>> format_time(2592000.0)
(1.0, 'm')
>>> format_time(31536000.0)
(1.0, 'y')
>>> format_time(315360000.0)
(1.0, 'a')
- combine_dags(dag1, dag2, discrepancies, prior=None)
Combine two directed acyclic graphs (DAGs) into a single DAG.
- Parameters:
dag1 (nx.DiGraph) – The first DAG.
dag2 (nx.DiGraph) – The second DAG.
discrepancies (Dict) – A dictionary containing the permutation importances for each edge in the DAGs.
prior (Optional[List[List[str]]], optional) – A list of lists containing the prior knowledge about the edges in the DAGs. The lists define a hierarchy of edges that represent a temporal relation in cause and effect. If a node is in the first list, then it is a root cause. If a node is in the second list, then it is caused by the nodes in the first list or the second list, and so on.
- Returns:
A tuple containing four graphs: the union of the two DAGs, the intersection of the two DAGs, the union of the two DAGs after removing cycles, and the intersection of the two DAGs after removing cycles.
- Return type:
Tuple[nx.DiGraph, nx.DiGraph, nx.DiGraph, nx.DiGraph]