"""
Utility functions for causalexplain
(C) J. Renero, 2022, 2023, 2024
"""
# pylint: disable=E1101:no-member, W0201:attribute-defined-outside-init
# pylint: disable=C0103:invalid-name, W0511:fixme
# pylint: disable=C0116:missing-function-docstring
# pylint: disable=R0913:too-many-arguments
# pylint: disable=R0914:too-many-locals, R0915:too-many-statements
# pylint: disable=W0106:expression-not-assigned, R1702:too-many-branches
import glob
import json
import os
import pickle
import types
from os.path import join
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union
import networkx as nx
import numpy as np
import pandas as pd
import pydot
import pydotplus
import torch
from ..independence.edge_orientation import get_edge_orientation
AnyGraph = Union[nx.Graph, nx.DiGraph]
def save_experiment(
obj_name: str,
folder: str,
results: object,
overwrite: bool = False) -> str:
"""
Creates a folder for the experiment and saves results. Results is a
dictionary that will be saved as an opaque pickle. When the experiment will
require to be loaded, the only parameter needed are the folder name.
Args:
obj_name (str): the name to be given to the pickle file to be saved. If
a file already exists with that name, a file with same name and a
extension will be generated.
folder (str): a full path to the folder where the experiment is to
be saved. If the folder does not exist it will be created.
results (obj): the object to be saved as experiment. This is typically a
dictionary with different items representing different parts of the
experiment.
overwrite (bool): If True, then the experiment is saved even if a file
with the same name already exists. If False, then the experiment is
saved only if a file with the same name does not exist.
Return:
(str) The name under which the experiment has been saved
"""
    if not os.path.exists(folder):
        Path(folder).mkdir(parents=True, exist_ok=True)
if not overwrite:
output = valid_output_name(obj_name, folder, extension="pickle")
else:
output = join(folder, obj_name if obj_name.endswith('.pickle') \
else obj_name + '.pickle')
with open(output, 'wb') as handle:
pickle.dump(results, handle, protocol=pickle.HIGHEST_PROTOCOL)
return output
def load_experiment(obj_name: str, folder: str):
"""
Loads a pickle from a given folder name. It is not necessary to add the "pickle"
extension to the experiment name.
Parameters:
-----------
obj_name (str): The name of the object saved in pickle format that is to be loaded.
folder (str): A full path where looking for the experiment object.
Returns:
--------
An obj loaded from a pickle file.
"""
    # Path.suffix includes the leading dot, so compare against ".pickle"
    if Path(obj_name).suffix != ".pickle":
        ext = ".pickle"
    else:
        ext = ''
experiment = f"{str(Path(folder, obj_name))}{ext}"
with open(experiment, 'rb') as h:
obj = pickle.load(h, encoding="utf-8")
return obj
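# Illustrative round trip (a sketch; the object name and folder are
# hypothetical):
#
#     results = {"score": 0.87, "graph": nx.DiGraph()}
#     saved_as = save_experiment("toy_run", "/tmp/experiments", results)
#     restored = load_experiment("toy_run", "/tmp/experiments")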
def valid_output_name(filename: str, path: str, extension=None) -> str:
"""
Builds a valid name. In case there's another file which already exists
adds a number (1, 2, ...) until finds a valid filename which does not
exist.
Returns
-------
The filename if the name is valid and file does not exists,
None otherwise.
Params
------
filename: str
The base filename to be set.
path: str
The path where trying to set the filename
extension:str
The extension of the file, without the dot '.' If no extension is
specified, any extension is searched to avoid returning a filepath
of an existing file, no matter what extension it has.
"""
    if extension:
        if extension[0] == '.':
            extension = extension[1:]
        if filename.endswith(f".{extension}"):
            base_filepath = join(path, filename)
        else:
            base_filepath = join(path, filename) + f".{extension}"
    else:
        base_filepath = join(path, filename)
output_filepath = base_filepath
idx = 1
    while len(glob.glob(f"{output_filepath}*")) > 0:
        if extension:
            output_filepath = join(path, filename) + f"_{idx}.{extension}"
        else:
            output_filepath = join(path, f"{filename}_{idx}")
        idx += 1
return output_filepath
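# Illustrative behaviour (a sketch; paths are hypothetical): if
# "/tmp/exp/run.pickle" already exists, the call below returns
# "/tmp/exp/run_1.pickle"; further collisions yield "run_2.pickle", etc.
#
#     output = valid_output_name("run", "/tmp/exp", extension="pickle")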
def graph_from_dot_file(dot_file: Union[str, Path]) -> Optional[nx.DiGraph]:
""" Returns a NetworkX DiGraph from a DOT FILE. """
# Check if "dot_file" exists
if not os.path.exists(dot_file):
return None
try:
dot_object = pydot.graph_from_dot_file(dot_file)
        if not dot_object or not isinstance(dot_object, (list, tuple)):
            return None
dot_string = dot_object[0].to_string()
if not dot_string:
return None
dot_string = dot_string.replace('\"\\n\"', '')
dot_string = dot_string.replace("\n;\n", "\n")
dotplus = pydotplus.graph_from_dot_data(dot_string)
if dotplus is None:
return None
dotplus.set_strict(True) # type: ignore
final_graph = nx.nx_pydot.from_pydot(dotplus)
if '\\n' in final_graph.nodes:
final_graph.remove_node('\\n')
return final_graph
except Exception as e:
print(f"Error processing dot file: {str(e)}")
return None
def graph_from_dictionary(d: Dict[str, List[Union[str, Tuple[str, float]]]]) -> AnyGraph:
"""
Builds a graph from a dictionary like {'u': ['v', 'w'], 'x': ['y']}.
The elements of the list can be tuples including weight
Args:
d (dict): A dictionary of the form {'u': ['v', 'w'], 'x': ['y']}, where
an edge between 'u' goes towards 'v' and 'w', and also an edge from
'x' goes towards 'y'. The format can also be like:
{'u': [('v', 0.2), ('w', 0.7)], 'x': [('y', 0.5)]}, where the values
in the tuple are interpreted as weights.
Returns:
networkx.DiGraph with the nodes and edges specified.
"""
g = nx.DiGraph()
for node, parents in d.items():
if len(parents) > 0:
if isinstance(parents[0], tuple):
for parent, weight in parents:
g.add_edge(parent, node, weight=weight)
else:
for parent in parents:
g.add_edge(parent, node)
return g
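# Illustrative usage (a sketch): both calls below build a DiGraph with
# edges v->u, w->u and y->x; the second one also sets the 'weight'
# attribute on each edge.
#
#     g_plain = graph_from_dictionary({'u': ['v', 'w'], 'x': ['y']})
#     g_weighted = graph_from_dictionary(
#         {'u': [('v', 0.2), ('w', 0.7)], 'x': [('y', 0.5)]})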
def graph_from_adjacency(
adjacency: np.ndarray,
node_labels=None,
th=0.0,
inverse: bool = False,
absolute_values: bool = False
) -> nx.DiGraph:
"""
Manually parse the adj matrix to shape a dot graph
Args:
adjacency: a numpy adjacency matrix
node_labels: an array of same length as nr of columns in the adjacency
matrix containing the labels to use with every node.
th: (float) weight threshold to be considered a valid edge.
inverse (bool): Set to true if rows in adjacency reflects where edges
are comming from, instead of where are they going to.
absolute_values: Take absolute value of weight label to check if
its greater than the threshold.
Returns:
The Graph (DiGraph)
"""
G = nx.DiGraph()
G.add_nodes_from(range(adjacency.shape[1]))
# What to do with absolute values?
def not_abs(x):
return x
w_val = np.abs if absolute_values else not_abs
def weight_gt(w, thresh):
return w != 0.0 if thresh is None else w_val(w) > thresh
# Do I have a threshold to consider?
for i, row in enumerate(adjacency):
for j, value in enumerate(row):
if inverse:
if weight_gt(adjacency[j][i], th):
G.add_edge(i, j, weight=w_val(adjacency[j][i]))
else:
if weight_gt(value, th):
G.add_edge(i, j, weight=w_val(value))
# Map the current column numbers to the letters used in toy dataset
if node_labels is not None and len(node_labels) == adjacency.shape[1]:
mapping = dict(zip(sorted(G), node_labels))
G = nx.relabel_nodes(G, mapping)
return G
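# Illustrative usage (a sketch): rows act as edge sources and columns as
# targets (unless inverse=True), so this matrix yields edges a->b and b->c.
#
#     adj = np.array([[0.0, 0.9, 0.0],
#                     [0.0, 0.0, 0.4],
#                     [0.0, 0.0, 0.0]])
#     g = graph_from_adjacency(adj, node_labels=['a', 'b', 'c'], th=0.1)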
def graph_from_adjacency_file(
file: Union[Path, str],
labels=None,
th=0.0,
sep=",",
        header: bool = True
) -> Tuple[nx.DiGraph, pd.DataFrame]:
"""
Read Adjacency matrix from a file and return a Graph
Args:
file: (str) the full path of the file to read
labels: (List[str]) the list of node names to be used. If None, the
node names are extracted from the adj file. The names must be
already sorted in the same order as the adjacency matrix.
th: (float) weight threshold to be considered a valid edge.
sep: (str) the separator used in the file
header: (bool) whether the file has a header. If True, then 'infer'
is used to read the header.
Returns:
DiGraph, DataFrame
"""
header_infer = "infer" if header else None
df = pd.read_csv(file, dtype="str", delimiter=sep, header=header_infer)
df = df.astype("float64")
if labels is None:
labels = list(df.columns)
G = graph_from_adjacency(df.values, node_labels=labels, th=th)
return G, df
def graph_to_adjacency(
graph: AnyGraph,
labels: List[str],
weight_label: str = "weight") -> np.ndarray:
"""
A method to generate the adjacency matrix of the graph. Labels are
sorted for better readability.
Args:
graph: (Union[Graph, DiGraph]) the graph to be converted.
node_names: (List[str]) the list of node names to be used. If None, the
node names are extracted from the graph. The names must be already
sorted in the same order as the adjacency matrix.
weight_label: the label used to identify the weights.
Return:
graph: (numpy.ndarray) A 2d array containing the adjacency matrix of
the graph.
"""
symbol_map = {"o": 1, ">": 2, "-": 3}
# Fix for the case where an empty node is parsed from the .dot file
if '\\n' in labels:
labels.remove('\\n')
mat = np.zeros((len(labels), (len(labels))))
for x in labels:
for y in labels:
if graph.has_edge(x, y):
if bool(graph.get_edge_data(x, y)):
if y in graph.get_edge_data(x, y).keys():
mat[labels.index(x)][labels.index(y)] = symbol_map[
graph.get_edge_data(x, y)[y]
]
else:
weight = graph.get_edge_data(x, y)[weight_label]
if weight is None:
weight = 1
mat[labels.index(x)][labels.index(y)] = weight
else:
mat[labels.index(x)][labels.index(y)] = 1
mat[np.isnan(mat)] = 0
return mat
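# Illustrative usage (a sketch): a weighted edge a->b appears in the matrix
# at row 'a', column 'b'.
#
#     g = nx.DiGraph()
#     g.add_edge('a', 'b', weight=0.5)
#     graph_to_adjacency(g, ['a', 'b'])   # -> array([[0. , 0.5], [0. , 0. ]])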
def graph_to_adjacency_file(graph: AnyGraph, output_file: Union[Path, str], labels):
"""
A method to write the adjacency matrix of the graph to a file. If graph has
weights, these are the values stored in the adjacency matrix.
Args:
graph: (Union[Graph, DiGraph] the graph to be saved
output_file: (str) The full path where graph is to be saved
"""
mat = graph_to_adjacency(graph, labels)
with open(output_file, "w", encoding="utf-8") as f:
f.write(",".join([f"{label}" for label in labels]))
f.write("\n")
for i, label in enumerate(labels):
f.write(f"{label}")
f.write(",")
f.write(",".join([str(point) for point in mat[i]]))
f.write("\n")
return
def graph_to_dot_file(graph: AnyGraph, output_file: Union[Path, str]):
"""
A method to write the graph in dot format to a file.
Args:
graph: (Union[Graph, DiGraph] the graph to be saved
output_file: (str) The full path where graph is to be saved
"""
nx.drawing.nx_pydot.write_dot(graph, output_file)
def select_device(force: Optional[str] = None) -> str:
"""
Selects the device to be used for training. If force is not None, then
the device is forced to be the one specified. If force is None, then
the device is selected based on the availability of GPUs. If no GPUs are
available, then the CPU is selected.
Args:
force (str): If not None, then the device is forced to be the one
specified. If None, then the device is selected based on the
availability of GPUs. If no GPUs are available, then the CPU is
selected.
Returns:
(str) The device to be used for training.
Raises:
ValueError: If the forced device is not available or not a valid device.
"""
    if force is not None:
        if force in ['cuda', 'mps', 'cpu']:
            if force == 'cuda' and torch.cuda.is_available():
                device = force
            elif force == 'mps' and torch.backends.mps.is_available():
                device = force
            elif force == 'mps' and not torch.backends.mps.is_available():
                device = 'cpu'
            elif force == 'cpu':
                device = force
            else:
                raise ValueError(f"Device not available: {force}")
        else:
            raise ValueError(f"Invalid device: {force}")
else:
if torch.cuda.is_available() and torch.backends.cuda.is_built():
device = "cuda"
elif torch.backends.mps.is_available() and torch.backends.mps.is_built():
device = "mps"
else:
device = "cpu"
return device
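# Illustrative usage (a sketch):
#
#     device = select_device()              # "cuda", "mps" or "cpu"
#     model = torch.nn.Linear(4, 1).to(device)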
def graph_intersection(g1: nx.DiGraph, g2: nx.DiGraph) -> nx.DiGraph:
"""
Returns the intersection of two graphs. The intersection is defined as the
set of nodes and edges that are common to both graphs. The intersection is
performed on the nodes and edges, not on the attributes of the nodes and
edges.
Args:
g1 (networkx.DiGraph): The first graph.
g2 (networkx.DiGraph): The second graph.
Returns:
(networkx.DiGraph) The intersection of the two graphs.
"""
    # Get the nodes common to g1 and g2 (the intersection of both)
    nodes = set(g1.nodes).intersection(set(g2.nodes))
# Check if any of the graphs is empty
if len(nodes) == 0:
return nx.DiGraph()
# For the nodes in the resulting DAG ("g"), get the data from the nodes in
# g1 and g2, as the average value of both
first_node = list(g1.nodes)[0]
data_fields = g1.nodes[first_node].keys()
nodes_data = {}
    for key in data_fields:
        # Check if the value for this key is numeric; accumulate each key
        # into the node's data dict instead of overwriting it
        if isinstance(g1.nodes[first_node][key], (int, float, np.number)):
            for n in nodes:
                if n in g1.nodes and n in g2.nodes:
                    nodes_data.setdefault(n, {})[key] = (
                        g1.nodes[n][key] + g2.nodes[n][key]) / 2
                elif n in g1.nodes:
                    nodes_data.setdefault(n, {})[key] = g1.nodes[n][key]
                elif n in g2.nodes:
                    nodes_data.setdefault(n, {})[key] = g2.nodes[n][key]
# Get the edges from g1 and g2, and take only those that match completely
edges = g1.edges & g2.edges
# Create a new graph with the nodes, nodes_data and edges
g = nx.DiGraph()
g.add_nodes_from(nodes)
g.add_edges_from(edges)
nx.set_node_attributes(g, nodes_data)
return g
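# Illustrative usage (a sketch): only the edge present in both graphs
# survives.
#
#     g1 = nx.DiGraph([('a', 'b'), ('b', 'c')])
#     g2 = nx.DiGraph([('a', 'b'), ('c', 'b')])
#     g = graph_intersection(g1, g2)    # edges: {('a', 'b')}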
def graph_union(g1: nx.DiGraph, g2: nx.DiGraph) -> nx.DiGraph:
"""
Returns the union of two graphs. The union is defined as the set of nodes and
edges that are in both graphs. The union is performed on the nodes and edges,
not on the attributes of the nodes and edges.
"""
nodes = set(g1.nodes).union(set(g2.nodes))
g = nx.DiGraph()
g.add_nodes_from(nodes)
g.add_edges_from(g1.edges)
g.add_edges_from(g2.edges)
    # For the nodes in the resulting DAG ("g"), get the data from the nodes
    # in g1 and g2, as the average value of both
    if g1.number_of_nodes() == 0:
        return g
    first_node = list(g1.nodes)[0]
    data_fields = g1.nodes[first_node].keys()
    nodes_data = {}
    for key in data_fields:
        # Check if the value for this key is numeric; accumulate each key
        # into the node's data dict instead of overwriting it
        if isinstance(g1.nodes[first_node][key], (int, float, np.number)):
            for n in nodes:
                if n in g1.nodes and n in g2.nodes:
                    nodes_data.setdefault(n, {})[key] = (
                        g1.nodes[n][key] + g2.nodes[n][key]) / 2
                elif n in g1.nodes:
                    nodes_data.setdefault(n, {})[key] = g1.nodes[n][key]
                elif n in g2.nodes:
                    nodes_data.setdefault(n, {})[key] = g2.nodes[n][key]
nx.set_node_attributes(g, nodes_data)
return g
def digraph_from_connected_features(
X,
feature_names,
models,
connections,
root_causes,
prior: Optional[List[List[str]]] = None,
reciprocity=True,
anm_iterations=10,
max_anm_samples=400,
verbose=False):
"""
Builds a directed graph from a set of features and their connections. The
connections are determined by the SHAP values of the features. The
orientation of the edges is determined by the causal direction of the
features. The orientation is determined by the method of edge orientation
defined in the independence module.
Parameters:
-----------
X (numpy.ndarray): The input data.
feature_names (list): The list of feature names.
models (obj): The model used to explain the data.
connections (dict): The dictionary of connections between features. The
dictionary is of the form {'feature1': ['feature2', 'feature3'], ...}
root_causes (list): The list of root causes. The root causes are the
features that are not caused by any other feature.
reciprocity (bool): If True, then the edges are oriented only if the
direction is reciprocal. If False, then the edges are oriented
regardless of the reciprocity.
anm_iterations=10 (int): The number of iterations to be used by the independence
module to determine the orientation of the edges.
max_anm_samples=400 (int): The maximum number of samples to be used by the
independence module to determine the orientation of the edges. If the
number of samples is larger than the number of samples in the data, then
the number of samples is set to the number of samples in the data.
verbose (bool): If True, then the function prints information about the
orientation of the edges.
Returns:
--------
networkx.DiGraph: The directed graph with the nodes and edges determined
by the connections and the orientation determined by the independence
module.
"""
if verbose:
print("-----\ndigraph_from_connected_features()")
unoriented_graph = nx.Graph()
for target in feature_names:
for peer in connections[target]:
# Add edges ONLY between nodes where SHAP recognizes both directions
if reciprocity:
if target in connections[peer]:
unoriented_graph.add_edge(target, peer)
else:
unoriented_graph.add_edge(target, peer)
dag = nx.DiGraph()
# Add regression mean score to each node
for i, feature in enumerate(feature_names):
dag.add_node(feature)
dag.nodes[feature]['regr_score'] = models.scoring[i]
# Determine edge orientation for each edge
if X.shape[0] > max_anm_samples:
X = X.sample(max_anm_samples)
if verbose:
print(f"Reduced number of samples to {max_anm_samples}")
for u, v in unoriented_graph.edges():
# Set orientation to 0 ~ unknown
orientation = 0
# Check if we can set it from prior knowledge
if prior:
orientation = correct_edge_from_prior(dag, u, v, prior, verbose)
if orientation == 0:
print(f"Checking edge {u} -> {v}...") if verbose else None
orientation = get_edge_orientation(
X, u, v, iters=anm_iterations, method="gpr", verbose=verbose)
if orientation == +1:
print(f" Edge {u} -> {v} added from ANM") if verbose else None
dag.add_edge(u, v)
elif orientation == -1:
print(f" Edge {v} -> {u} added from ANM") if verbose else None
dag.add_edge(v, u)
else:
pass
# Apply a simple fix: if quality of regression for a feature is poor, then
# that feature can be considered a root or parent node. Therefore, we cannot
# have an edge pointing to that node.
if root_causes:
changes = []
for parent_node in root_causes:
print(f"Checking root cause {parent_node}...") if verbose else None
for cause, effect in dag.edges():
if effect == parent_node:
print(
f"Reverting edge {cause} -> {effect}") if verbose else None
changes.append((cause, effect))
for cause, effect in changes:
dag.remove_edge(cause, effect)
dag.add_edge(effect, cause)
return dag
def correct_edge_from_prior(dag, u, v, prior, verbose):
"""
Correct an edge in the graph according to the prior knowledge.
This function corrects the orientation of an edge in a directed acyclic graph
(DAG) based on prior knowledge. The prior knowledge is a list of lists of node
names, ordered according to a temporal structure. The assumption is that nodes
from the first layer of temporal structure cannot receive any incoming edge.
The function checks the relative positions of the two nodes in the prior
knowledge and:
    - If both nodes are in the top list, it removes the edge, based on that
      assumption.
    - If one node is before the other, it adds a new edge in the correct
      direction.
    - If the nodes are not in a clear order, it leaves the edge unchanged.
    - If the edge reflects a backward connection between nodes from
      different temporal layers, it removes the edge.

    It returns the orientation of the edge: -1 if the edge is reversed, +1
    if the edge is removed or kept, and 0 if the edge orientation is not
    clear.
Parameters
----------
dag : nx.DiGraph
The graph to be corrected.
u : str
The first node of the edge.
v : str
The second node of the edge.
prior : List[List[str]]
The prior knowledge, a list of lists of node names, ordered according to
a temporal structure.
verbose : bool
If True, print debug messages.
Returns
-------
orientation : int
The orientation of the edge, -1 if the edge is reversed, +1 if the edge
is removed or kept, and 0 if the edge orientation is not clear.
"""
    # Check whether either 'u' or 'v' is NOT in the prior. Remember that
    # 'prior' is a list of lists of node names, ordered according to a
    # temporal structure.
    if not any([u in p for p in prior]) or not any([v in p for p in prior]):
        return 0
    print(f"Checking edge {u} -> {v}...") if verbose else None
    idx_u = [i for i, l in enumerate(prior) if u in l][0]
    idx_v = [i for i, l in enumerate(prior) if v in l][0]
    both_in_top_list = idx_u == 0 and idx_v == 0
    u_is_before_v = idx_u < idx_v
    v_is_before_u = idx_v < idx_u
if both_in_top_list:
print(
f"Edge {u} -x- {v} removed: both top list") if verbose else None
dag.remove_edge(u, v) if dag.has_edge(u, v) else None # Experimental XXX Beware of this line!
return +1
elif u_is_before_v:
print(
f"Edge {u} -> {v} added: {u} before {v}") if verbose else None
dag.add_edge(u, v) if not dag.has_edge(v, u) else None
return +1
elif v_is_before_u:
print(
f"Edge {v} -> {u} added: {v} before {u}") if verbose else None
dag.remove_edge(u, v) if dag.has_edge(u, v) else None
return -1
else:
return 0
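# Illustrative usage (a sketch): with the two-layer prior below, a backward
# edge c -> a is reversed (returns -1), while a -> c is kept (returns +1).
#
#     dag = nx.DiGraph([('c', 'a')])
#     prior = [['a', 'b'], ['c']]
#     correct_edge_from_prior(dag, 'c', 'a', prior, verbose=False)   # -1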
def valid_candidates_from_prior(feature_names, effect, prior):
"""
This method returns the valid candidates for a given effect, based on the
prior information. The prior information is a list of lists, where each list
contains the features that are known to be in the same level of the hierarchy.
Parameters:
-----------
feature_names: List[str]
The list of feature names.
effect: str
The effect for which to find the valid candidates.
prior: List[List[str]]
The prior information about the hierarchy of the features.
Returns:
--------
List[str]
The valid candidates for the given effect.
"""
if prior is None or prior == []:
return [c for c in feature_names if c != effect]
    # identify which of the lists defined in the prior contains the effect
idx = [i for i, sublist in enumerate(prior) if effect in sublist]
if not idx:
raise ValueError(f"Effect '{effect}' not found in prior")
    # candidates are the elements in the list at index `idx` and in all
    # the lists before it
    candidates = [item for sublist in prior[:idx[0] + 1]
                  for item in sublist]
# return the candidates, excluding the effect itself
return [c for c in candidates if c != effect]
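# Illustrative usage (a sketch): with a two-layer prior, the candidates for
# an effect in the second layer come from both layers.
#
#     prior = [['a', 'b'], ['c', 'd']]
#     valid_candidates_from_prior(['a', 'b', 'c', 'd'], 'c', prior)
#     # -> ['a', 'b', 'd']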
def break_cycles_using_prior(
original_dag: nx.DiGraph,
prior: List[List[str]],
verbose: bool = False) -> nx.DiGraph:
"""
This method remove potential edges in the DAG by removing edges that are
incompatible with the temporal relationship defined in the prior knowledge.
Any edge pointing backwards, according to the order established in the prior
knowledge, is removed.
Parameters:
-----------
original_dag (nx.DiGraph): The original directed acyclic graph.
prior (List[List[str]]): A list of lists containing the prior knowledge
about the edges in the DAG. The lists define a hierarchy of edges that
represent a temporal relation in cause and effect. If a node is in the first
list, then it is a root cause. If a node is in the second list, then it is
caused by the nodes in the first list or the second list, and so on.
verbose (bool): If True, then the method prints information about the
edges that are removed.
Returns:
--------
nx.DiGraph: The directed acyclic graph with the edges removed according to
the prior knowledge.
"""
if prior is None or prior == []:
return original_dag
new_dag = original_dag.copy()
if verbose:
print("Prior knowledge:", prior)
cycles = list(nx.simple_cycles(new_dag))
while len(cycles) > 0:
cycle = cycles.pop(0)
if verbose:
print("Checking cycle", cycle)
for i, node in enumerate(cycle):
if node == cycle[-1]:
neighbour = cycle[0]
else:
neighbour = cycle[i+1]
# Check if the edge between the two nodes is in the prior knowledge
if [node, neighbour] not in prior:
if verbose:
print(
f"↳ Checking '{node} -> {neighbour}' in the prior knowledge")
idx_node = [i for i, l in enumerate(prior) if node in l][0]
idx_neighbour = [i for i, l in enumerate(
prior) if neighbour in l][0]
if idx_neighbour - idx_node < 0:
if verbose:
print(
f" ↳ Breaking cycle, removing edge {node} -> {neighbour}")
print("** Recomputing cycles **")
new_dag.remove_edge(node, neighbour)
cycles = list(nx.simple_cycles(new_dag))
break
return new_dag
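# Illustrative usage (a sketch): the 2-cycle below is broken by removing
# the edge that points backwards with respect to the prior.
#
#     dag = nx.DiGraph([('a', 'b'), ('b', 'a')])
#     prior = [['a'], ['b']]
#     break_cycles_using_prior(dag, prior)   # keeps a -> b, removes b -> a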
def potential_misoriented_edges(
loop: List[str],
discrepancies: Dict,
verbose: bool = False) -> List[Tuple[str, str, float]]:
"""
    Find potential misoriented edges in a loop, based on discrepancies.
    The loop is a list of nodes that form a cycle in a DAG. The
    discrepancies argument is a dictionary containing the discrepancies
    between nodes. The function returns a list of potentially misoriented
    edges, sorted by the difference in goodness-of-fit scores.

    Args:
        loop (List[str]): The list of nodes in the loop.
        discrepancies (Dict): A dictionary containing discrepancies between
            nodes. Each discrepancy is typically a positive float between
            0 and 1.
        verbose (bool, optional): Whether to print verbose output.
            Defaults to False.
Returns:
List[Tuple[str, str, float]]: A list of potential misoriented edges,
sorted by the difference in goodness-of-fit scores.
"""
potential_misoriented_edges = []
for i, node in enumerate(loop):
if i == len(loop) - 1:
neighbor = loop[0]
else:
neighbor = loop[i + 1]
if neighbor not in discrepancies or node not in discrepancies:
continue
if node not in discrepancies[neighbor] or neighbor not in discrepancies[node]:
continue
fwd_gof = 1. - discrepancies[neighbor][node].shap_gof
bwd_gof = 1. - discrepancies[node][neighbor].shap_gof
orientation = +1 if fwd_gof < bwd_gof else -1
if orientation == -1:
potential_misoriented_edges.append(
(node, neighbor, fwd_gof - bwd_gof))
if verbose:
direction = "-->" if fwd_gof < bwd_gof else "<--"
print(
f"Edge: {node}{direction}{neighbor} FWD:{fwd_gof:.3f} "
f"BWD:{bwd_gof:.3f}")
return sorted(potential_misoriented_edges, key=lambda x: x[2], reverse=True)
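# Illustrative usage (a sketch; `Discrepancy` is a hypothetical stand-in
# for the objects stored in the discrepancies dictionary, which only need
# a `shap_gof` attribute):
#
#     from collections import namedtuple
#     Discrepancy = namedtuple('Discrepancy', 'shap_gof')
#     discrepancies = {
#         'a': {'b': Discrepancy(0.9)},
#         'b': {'a': Discrepancy(0.2)},
#     }
#     potential_misoriented_edges(['a', 'b'], discrepancies)
#     # -> [('a', 'b', 0.7)], i.e. a -> b looks misoriented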
def break_cycles_if_present(
dag: nx.DiGraph,
discrepancies: Dict,
prior: Optional[List[List[str]]] = None,
verbose: bool = False) -> nx.DiGraph:
"""
Breaks cycles in a directed acyclic graph (DAG) by removing the edge with
the lowest goodness of fit (R2). If there are multiple cycles, they are
all traversed and fixed. If prior is set, then the cycles are broken
using the prior knowledge.
    Parameters:
    - dag (nx.DiGraph): the DAG to break cycles in.
    - discrepancies (Dict): a dictionary containing the discrepancies
      (goodness-of-fit values) for each edge in the DAG.
    - prior (List[List[str]]): a list of lists containing the prior
      knowledge about the edges in the DAG. The lists define a hierarchy of
      edges that represent a temporal relation in cause and effect. If a
      node is in the first list, then it is a root cause. If a node is in
      the second list, then it is caused by the nodes in the first list or
      the second list, and so on.
    - verbose (bool): If True, print information about the cycles being
      broken.

    Returns:
    - dag (nx.DiGraph): the DAG with cycles broken.
    """
if verbose:
print("-----\nbreak_cycles_if_present()")
# If prior is set, then break cycles using the prior knowledge
if prior:
new_dag = break_cycles_using_prior(dag, prior, verbose)
else:
new_dag = dag.copy()
# Find all cycles in the DAG
cycles = list(nx.simple_cycles(new_dag))
# This might be important, to remove first double edges
cycles.sort(key=len)
if len(cycles) == 0:
if verbose:
print("No cycles found")
return new_dag
# Traverse all cycles, fixing them
cycles_info = []
while len(cycles) > 0:
cycle = cycles.pop(0)
# For every pair of consecutive nodes in the cycle, store its SHAP
# discrepancy
cycle_edges = {}
for node in cycle:
if node == cycle[-1]:
neighbour = cycle[0]
else:
neighbour = cycle[cycle.index(node)+1]
cycle_edges[(node, neighbour)
] = discrepancies[neighbour][node].shap_gof
cycles_info.append((cycle, cycle_edges))
# Find the edge with the lowest SHAP discrepancy
min_gof = min(cycle_edges.values())
min_edge = [edge for edge, gof in cycle_edges.items() if gof ==
min_gof][0]
# Check if any of the edges in the loop is potentially misoriented.
# If so, change the orientation of the edge with the lowest SHAP
# discrepancy
potential_misoriented = potential_misoriented_edges(
cycle, discrepancies, verbose)
if len(potential_misoriented) > 0:
if verbose:
print("Potential misoriented edges in the cycle:")
for edge in potential_misoriented:
print(f" {edge[0]} --> {edge[1]} ({edge[2]:.3f})")
# Check if changing the orientation of the edge would break
# the cycle
test_dag = new_dag.copy()
            # Check if the edge is still in the test_dag
if min_edge in test_dag.edges:
test_dag.remove_edge(*min_edge)
test_dag.add_edge(*potential_misoriented[0][1::-1])
test_cycles = list(nx.simple_cycles(test_dag))
if len(test_cycles) == len(cycles):
if verbose:
print(
f"Breaking cycle {cycle} by changing orientation of "
f"edge {min_edge}")
new_dag.remove_edge(*min_edge)
new_dag.add_edge(*potential_misoriented[0][1::-1])
continue
# Remove the edge with the lowest SHAP discrepancy, checking that
# the edge still exists
if min_edge in new_dag.edges:
if verbose:
print(f"Breaking cycle {cycle} by removing edge {min_edge}")
new_dag.remove_edge(*min_edge)
# Recompute whether there're cycles in the DAG after this change
cycles = list(nx.simple_cycles(new_dag))
cycles.sort(key=len)
return new_dag
def stringfy_object(object_: object) -> str:
"""
Convert an object into a string representation, including its attributes.
Args:
object_ (object): The object to be converted.
Returns:
str: A string representation of the object and its attributes.
"""
forbidden_attrs = [
'fit', 'predict', 'fit_predict', 'score', 'get_params', 'set_params']
ret = "Object attributes\n"
ret += f"{'-'*80}\n"
    for attr in dir(object_):
        if attr.startswith('_') or attr in forbidden_attrs:
            continue
        value = getattr(object_, attr)
        if isinstance(value, types.MethodType):
            continue
        if isinstance(value, (pd.DataFrame, np.ndarray)):
            ret += f"{attr:25} {type(value).__name__} {value.shape}\n"
            continue
        if isinstance(value, nx.DiGraph):
            n_nodes = value.number_of_nodes()
            n_edges = value.number_of_edges()
            ret += f"{attr:25} {n_nodes} nodes, {n_edges} edges\n"
            continue
        if isinstance(value, dict):
            keys_list = [f"{k}:{type(v)}" for k, v in value.items()]
            ret += f"{attr:25} dict {str(keys_list)[:60]}...\n"
            continue
        if isinstance(value, list):
            list_ = '[' + ', '.join(str(e) for e in value) + ']'
            ret += f"{attr:25} list {list_[:60]}...\n"
            continue
        if isinstance(value, (int, float, str, bool)):
            ret += f"{attr:25} {value}\n"
            continue
        ret += f"{attr:25} <{value.__class__.__name__}>\n"
return ret
def get_feature_names(X: Union[pd.DataFrame, np.ndarray, list]) -> List[str]:
"""
Get the feature names from the input data. The feature names can be
extracted from a pandas DataFrame, a numpy array or a list.
Parameters:
-----------
X : Union[pd.DataFrame, np.ndarray, list]
The input data.
Returns:
--------
List[str]
The list of feature names.
"""
assert isinstance(X, (pd.DataFrame, np.ndarray, list)), \
"X must be a pandas DataFrame, a numpy array or a list"
if isinstance(X, pd.DataFrame):
feature_names = list(X.columns)
elif isinstance(X, np.ndarray):
feature_names = [f"X{i}" for i in range(X.shape[1])]
else:
feature_names = [f"X{i}" for i in range(len(X[0]))]
return feature_names
def _classify_variable(arr):
"""
This method inspect the contents of the array and returns the type of the
variable. If the variables are NOT numbers, then it returns multiclass. If
the variable is a number, then it returns binary if it has only two unique
values. If the number of unique values is greater than two, then it returns
continuous.
Args:
arr (Union[pd.Series, np.ndarray]): The array to be inspected.
Returns:
str: The type of the variable. Can be binary, multiclass or continuous.
"""
# Convert numpy array to pandas Series if needed
if isinstance(arr, np.ndarray):
unique_values = np.unique(arr)
else:
unique_values = arr.unique()
# Handle empty series
if len(unique_values) == 0:
return "continuous"
# Check if the variable is not a number
if unique_values.dtype.kind not in "biufc":
return "multiclass"
# Check if the variable has only two unique values
elif len(unique_values) == 2:
return "binary"
else:
return "continuous"
def get_feature_types(X) -> Dict[str, str]:
"""
Get the feature types from the input data. The feature types can be
binary, multiclass or continuous. The classification is done based on the
number of unique values in the array. If the array has only two unique
values, then the variable is classified as binary. If the unique values
are integers, then the variable is classified as multiclass. Otherwise,
the variable is classified as continuous.
Parameters:
-----------
X : Union[pd.DataFrame, np.ndarray, list]
The input data.
Returns:
--------
Dict[str, str]
A dictionary with the feature names as keys and the feature types as
values.
Example:
--------
>>> X = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
>>> get_feature_types(X)
{'X0': 'multiclass', 'X1': 'multiclass', 'X2': 'multiclass'}
"""
feature_names = get_feature_names(X)
feature_types = {}
if isinstance(X, np.ndarray):
for i, feature in enumerate(X.T):
feature_types[feature_names[i]] = _classify_variable(feature)
elif isinstance(X, pd.DataFrame):
for feature in X.columns:
feature_types[feature] = _classify_variable(X[feature].values)
    else:
        # Classify each column of the list of rows, not each scalar in the
        # first row
        for i in range(len(X[0])):
            column = np.array([row[i] for row in X])
            feature_types[feature_names[i]] = _classify_variable(column)
return feature_types
def cast_categoricals_to_int(X: pd.DataFrame) -> pd.DataFrame:
"""
Cast all categorical features in the dataframe to integer values.
Parameters
----------
X: pd.DataFrame
The dataframe to cast.
Returns
-------
pd.DataFrame
The dataframe with all categorical features cast to integer values.
"""
X = X.copy()
for feature in X.columns:
feature_dtype = _classify_variable(X[feature])
        if feature_dtype in ("multiclass", "binary"):
            # Convert all the values to integers. For example, 'a' -> 0,
            # 'b' -> 1, 'c' -> 2, respecting the order of the unique values.
X[feature] = X[feature].astype("category")
X[feature] = X[feature].cat.codes
X[feature] = X[feature].astype("int16")
return X
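# Illustrative usage (a sketch): the string column is mapped to category
# codes ('a' -> 0, 'b' -> 1) while the continuous column is left untouched.
#
#     df = pd.DataFrame({'cat': ['a', 'b', 'a'], 'num': [0.1, 2.3, 4.5]})
#     cast_categoricals_to_int(df)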
def find_crossing_point(
        f1: List[float],       # Values of the first curve (e.g., precision)
        f2: List[float],       # Values of the second curve (e.g., recall)
        x_values: List[float]  # X values (e.g., thresholds) of the curves
) -> Optional[float]:
"""
This function finds the exact crossing point between two curves defined by
f1 and f2 over x_values. It interpolates between points to provide a more
accurate crossing point.
Parameters:
-----------
f1: List[float]
Values of the first curve (precision, for example).
f2: List[float]
Values of the second curve (recall, for example).
x_values: List[float]
X values (thresholds, for example) corresponding to the curve evaluations.
Returns:
--------
Optional[float]
The x value where the first crossing occurs between the two curves.
If no crossing is found, returns None.
"""
    def interpolate_crossing(
            f1_prev: float,  # Value of f1 at the previous x value
            f2_prev: float,  # Value of f2 at the previous x value
            x_prev: float,   # Previous x value
            x_next: float,   # Next x value
            f1_next: float,  # Value of f1 at the next x value
            f2_next: float   # Value of f2 at the next x value
    ) -> float:
        # Linearly interpolate the difference (f1 - f2), whose sign changes
        # between the two points, and return the x where it crosses zero.
        gap_prev = abs(f1_prev - f2_prev)
        gap_next = abs(f1_next - f2_next)
        return x_prev + (x_next - x_prev) * gap_prev / (gap_prev + gap_next)
# Loop through the x values and find where the first crossing occurs
for i in range(1, len(x_values)):
# Check for a crossing between consecutive points
if (f1[i-1] < f2[i-1] and f1[i] > f2[i]) or \
(f1[i-1] > f2[i-1] and f1[i] < f2[i]):
# Interpolate to find the exact crossing point
return interpolate_crossing(
f1[i-1], f2[i-1], x_values[i-1], x_values[i], f1[i], f2[i])
return None # No crossing found
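# Illustrative usage (a sketch): the curves below cross between x=1.0 and
# x=2.0; interpolating the gap (0.2 before, 0.6 after) locates the crossing
# at x=1.25.
#
#     f1 = [0.0, 0.4, 0.8]
#     f2 = [1.0, 0.6, 0.2]
#     find_crossing_point(f1, f2, [0.0, 1.0, 2.0])   # -> 1.25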
def combine_dags(
dag1: nx.DiGraph,
dag2: nx.DiGraph,
discrepancies: Dict,
prior: Optional[List[List[str]]] = None
) -> Tuple[nx.DiGraph, nx.DiGraph, nx.DiGraph, nx.DiGraph]:
"""
Combine two directed acyclic graphs (DAGs) into a single DAG.
Parameters
----------
dag1 : nx.DiGraph
The first DAG.
dag2 : nx.DiGraph
The second DAG.
discrepancies : Dict
A Dictionary containing the permutation importances for each edge in the DAGs.
prior : Optional[List[List[str]]], optional
A list of lists containing the prior knowledge about the edges in the DAGs.
The lists define a hierarchy of edges that represent a temporal relation in
cause and effect. If a node is in the first list, then it is a root cause.
If a node is in the second list, then it is caused by the nodes in the
first list or the second list, and so on.
Returns
-------
Tuple[nx.DiGraph, nx.DiGraph, nx.DiGraph, nx.DiGraph]
A tuple containing four graphs: the union of the two DAGs, the
intersection of the two DAGs, the union of the two DAGs after removing
cycles, and the intersection of the two DAGs after removing cycles.
"""
# Compute the union of the two DAGs
union = graph_union(dag1, dag2)
# Compute the intersection of the two DAGs
inter = graph_intersection(dag1, dag2)
# Remove cycles from the union and intersection
union_cycles_removed = break_cycles_if_present(
union, discrepancies, prior)
inter_cycles_removed = break_cycles_if_present(
inter, discrepancies, prior)
# XXX: Experiment -> Use prior to correct edges in the union
if prior:
edges_list = list(union_cycles_removed.edges)
for u, v in edges_list:
correct_edge_from_prior(union_cycles_removed, u, v, prior, verbose=False)
return union, inter, union_cycles_removed, inter_cycles_removed
def list_files(input_pattern: str, where: str) -> list:
"""
List all the files in the input path matching the input pattern
Parameters:
-----------
input_pattern : str
The pattern to match the files
where : str
The path to use to look for the files matching the pattern
"""
    input_files = glob.glob(os.path.join(where, input_pattern))
input_files = sorted([os.path.basename(os.path.splitext(f)[0])
for f in input_files])
assert len(input_files) > 0, \
f"No files found in {where} matching <{input_pattern}>"
return sorted(list(set(input_files)))
def read_json_file(file_path: str):
"""
Read a JSON file and return its content as a dictionary."
"""
try:
with open(file_path, 'r', encoding='utf-8') as file:
data = json.load(file)
# Get 'prior' key, default to empty list
prior = data.get("prior", [])
return prior
except FileNotFoundError:
print(f"Error: File '{file_path}' not found.")
return []
except json.JSONDecodeError:
print("Error: Invalid JSON format.")
return []