Source code for causalexplain.independence.cond_indep

#
# Module to compute conditional independencies in a DAG, and
# sufficient sets.
#
# Author: Jesús Renero
#
import itertools
from typing import List, Optional, Union
import networkx as nx


class ConditionalIndependencies:
    """
    A class to store conditional independencies in a graph.

    Each independence is stored as a key ``(var1, var2, conditioning_set)``
    of ``_cache``, where ``conditioning_set`` is ``None`` for a marginal
    independence or a tuple of node names otherwise.

    Attributes
    ----------
    _cache : dict
        Keys are ``(var1, var2, conditioning_set)`` triples; every value is
        ``True``. The dict is used as an insertion-ordered set.

    Methods
    -------
    add(var1, var2, conditioning_set=None)
        Adds a new conditional independence to the cache.
    __str__()
        Returns a human-readable listing of the conditional independencies.
    __repr__()
        Returns a compact, dict-like representation of the conditional
        independencies.
    """

    def __init__(self):
        self._cache = {}

    def __str__(self):
        """Return one ``x ⊥ y [| z1, z2, ...]`` line per stored independence."""
        lines = ["Conditional Independencies:"]
        for x, y, z in self._cache:
            if z is None:
                lines.append(f" {x} ⊥ {y}")
            else:
                lines.append(f" {x} ⊥ {y} | {', '.join(map(str, z))}")
        return "\n".join(lines) + "\n"

    def __repr__(self) -> str:
        """Returns a string representation of the ConditionalIndependencies object."""
        items = []
        # Iterate over the keys only: each key is the (x, y, z) triple, the
        # value is just the ``True`` placeholder.
        for x, y, z in self._cache:
            conditioning = [] if z is None else list(z)
            items.append(f"({x!r}, {y!r}):{conditioning!r}")
        return "{" + ", ".join(items) + "}"

    def add(self, var1: str, var2: str,
            conditioning_set: Optional[Union[str, List[str]]] = None) -> None:
        """
        Adds a new conditional independence to the cache.

        Adding an independence that is already stored is a no-op and keeps
        its original position.

        Parameters
        ----------
        var1 : str
            A node in the graph.
        var2 : str
            A node in the graph.
        conditioning_set : str, list of str or None
            The conditioning nodes. A single node name may be passed as a
            plain string; ``None`` means a marginal independence.
        """
        if conditioning_set is None:
            key_set = None
        elif isinstance(conditioning_set, str):
            # A bare node name must not be split into its characters.
            key_set = (conditioning_set,)
        else:
            key_set = tuple(conditioning_set)
        self._cache[(var1, var2, key_set)] = True
class SufficientSets:
    """
    A class to collect the sufficient sets found in a graph.

    Attributes
    ----------
    _cache : list
        The distinct sufficient sets, in the order they were first added.

    Methods
    -------
    add(suff_set)
        Adds the new sufficient sets in ``suff_set`` to the cache.
    __str__()
        Returns a human-readable listing of the sufficient sets.
    __repr__()
        Returns a list-like representation of the sufficient sets.
    """

    def __init__(self):
        self._cache = []

    def add(self, suff_set):
        """
        Adds new sufficient sets to the cache, skipping duplicates.

        Parameters
        ----------
        suff_set : iterable
            The sufficient sets to add. Each element is stored as-is. The
            elements are compared with ``==``, so unhashable values such as
            lists are supported.
        """
        for element in suff_set:
            if element not in self._cache:
                self._cache.append(element)

    def __str__(self):
        """
        Returns a string representation of the sufficient sets.

        Returns
        -------
        str
            One line per sufficient set, or a fixed message when none have
            been added.
        """
        if not self._cache:
            return "No sufficient sets found"
        s = "Sufficient sets:\n"
        for sufficient_set in self._cache:
            s += f" {sufficient_set}\n"
        return s

    def __repr__(self):
        """Return the sufficient sets as ``[set1, set2, ...]``."""
        # join() places separators by position, so repeated values cannot
        # suppress a comma.
        return "[" + ", ".join(f"{s}" for s in self._cache) + "]"
def get_backdoor_paths(dag: nx.DiGraph, x: str, y: str):
    """
    Returns all backdoor paths between two nodes in a graph.

    A backdoor path from ``x`` to ``y`` is a simple path between them in the
    undirected skeleton of ``dag`` whose first edge points *into* ``x``. In
    other words, the second node on the path is a parent of ``x``. The
    orientation of the remaining edges is not constrained.

    Parameters
    ----------
    dag : nx.DiGraph
        A directed graph.
    x : str
        Start node of the paths.
    y : str
        End node of the paths.

    Returns
    -------
    paths : list
        Each path as a list of nodes from ``x`` to ``y``. The list is empty
        when either node is missing from the graph or when ``x == y``.
    """
    if x not in dag.nodes() or y not in dag.nodes():
        return []
    if x == y:
        return []

    skeleton = dag.to_undirected()
    # Paths between two distinct nodes always have >= 2 nodes, so p[1] exists.
    return [
        p
        for p in nx.all_simple_paths(skeleton, source=x, target=y)
        if dag.has_edge(p[1], x)
    ]
def get_paths(graph: nx.DiGraph, x: str, y: str):
    """
    List every simple directed path from ``x`` to ``y``.

    Parameters
    ----------
    graph : nx.DiGraph
        A directed graph.
    x : str
        Source node.
    y : str
        Target node.

    Returns
    -------
    list
        Each path is a list of nodes from ``x`` to ``y``. The result is empty
        when either node is absent from ``graph`` or when ``x == y``.
    """
    endpoint_missing = x not in graph.nodes() or y not in graph.nodes()
    if endpoint_missing or x == y:
        return []
    return list(nx.all_simple_paths(graph, x, y))
def find_colliders_in_path(dag: nx.DiGraph, path: List[str]):
    """
    Returns all colliders in a path.

    A node strictly inside ``path`` is a collider when ``dag`` has edges
    into it from both of its neighbours on the path (``a -> node <- b``).

    Parameters
    ----------
    dag : nx.DiGraph
        A directed graph.
    path : list
        A sequence of nodes of ``dag``.

    Returns
    -------
    colliders : set
        The collider nodes. The set is empty for paths shorter than three
        nodes.
    """
    return {
        middle
        for before, middle, after in zip(path, path[1:], path[2:])
        if dag.has_edge(before, middle) and dag.has_edge(after, middle)
    }
def _blocks_backdoor_path(dag, sufficient_set, path, verbose=False):
    """
    Decide whether conditioning on ``sufficient_set`` blocks ``path``.

    Rules applied, in order:

    * no node of the set lies on the path -> the path is NOT blocked;
    * the set's nodes on the path are all non-colliders -> blocked;
    * the set's nodes on the path are exactly the path's colliders -> NOT
      blocked (every collider is opened and no non-collider is conditioned
      on);
    * otherwise, either a non-collider is conditioned on or some collider
      is left unconditioned -> blocked.

    NOTE(review): under d-separation, conditioning on a *descendant* of a
    collider also opens it. That rule is not modelled here.

    Returns
    -------
    bool
        True when the path is considered blocked.
    """
    if verbose:
        print(f" Checking path {path}")
    colliders = find_colliders_in_path(dag, path)
    if verbose:
        if colliders:
            print(f" ! Colliders in path: {colliders}")
        else:
            print(" - No colliders in path")

    # Nodes of the candidate set that lie on this path.
    nodes_in_path = set(sufficient_set).intersection(path)
    if verbose:
        if nodes_in_path:
            print(f" + Nodes from sufficient set in path: {nodes_in_path}")
        else:
            print(f" - No nodes from sufficient set in path")

    if not nodes_in_path:
        if verbose:
            print(f" No nodes in {sufficient_set} are in {path}, "
                  f"so they do not block this path ")
        return False

    if not nodes_in_path.intersection(colliders):
        if verbose:
            print(f" Path {path} blocked by nodes in "
                  f"{sufficient_set} ")
        return True

    # Compare the sets, not their sizes. Equal sizes with different members
    # means a non-collider is conditioned on, and that blocks the path.
    if nodes_in_path == colliders:
        if verbose:
            print(f" ALL nodes in {sufficient_set} are colliders "
                  f"in {path} \n"
                  f" => {nodes_in_path.intersection(colliders)} == "
                  f"{colliders}")
        return False

    if verbose:
        print(f" Some nodes in {sufficient_set} are NOT colliders "
              f"in {path} ")
    return True


def get_sufficient_sets_for_pair(dag, x, y, verbose=False):
    """
    Compute the sufficient sets for a pair of nodes in a graph.

    A sufficient set is a set of nodes that contains no descendant of ``x``
    and blocks every backdoor path between ``x`` and ``y``. Candidates are
    the interior nodes of each individual backdoor path. Sets built by
    combining nodes from several paths are therefore not explored.

    Parameters
    ----------
    dag : nx.DiGraph
        A directed graph.
    x : str
        A node in the graph.
    y : str
        A node in the graph.
    verbose : bool
        If True, print additional information.

    Returns
    -------
    sufficient_sets : list
        The accepted candidate sets, each as a list of nodes in path order.
    """
    backdoor_paths = get_backdoor_paths(dag, x, y)
    if verbose:
        if backdoor_paths:
            print(f" Found {len(backdoor_paths)} backdoor paths")
        else:
            print(" No backdoor paths found")
    if not backdoor_paths:
        return []

    # Step 1: each backdoor path proposes its interior nodes as a candidate,
    # unless one of them is a descendant of x. Computed once: loop-invariant.
    descendants = nx.descendants(dag, x)
    sufficient_sets = []
    for path in backdoor_paths:
        if verbose:
            print(f" Checking backdoor path: {path}")
        sufficient_set = path[1:-1]
        if any(d in descendants for d in sufficient_set):
            if verbose:
                print(
                    f"Path {path} discarded because it contains a descendant of x")
            continue
        sufficient_sets.append(sufficient_set)

    # Step 2: keep only the candidates that block every backdoor path.
    final_suff_set = []
    for sufficient_set in sufficient_sets:
        if verbose:
            print(
                f" ",
                f"Checking that {sufficient_set} blocks all backdoor paths "
                f"between {x} and {y}")
        # Discard candidates whose own path contains a collider.
        colliders = find_colliders_in_path(dag, [x] + sufficient_set + [y])
        if colliders:
            if verbose:
                print(
                    f" {sufficient_set} contains a collider: {colliders}")
            continue

        # No short-circuit, so verbose output covers every path.
        all_conditions = True
        for path in backdoor_paths:
            if not _blocks_backdoor_path(dag, sufficient_set, path, verbose):
                all_conditions = False

        if all_conditions:
            if verbose:
                print(
                    f" {sufficient_set} blocks all backdoor paths "
                    f"between x and y")
            final_suff_set.append(sufficient_set)
    return final_suff_set
def get_sufficient_sets(dag, verbose=False):
    """
    Get the sufficient sets (admissible sets) for all pairs of nodes in a graph.

    Pairs are enumerated with ``itertools.combinations`` over
    ``dag.nodes()``, so each unordered pair is examined once, as
    ``(earlier node, later node)``.

    NOTE(review): backdoor paths are defined relative to the first node of
    the pair, so the reversed pair ``(y, x)`` is never examined. Confirm
    that this is intended.

    Parameters
    ----------
    dag : nx.DiGraph
        A directed graph.
    verbose : bool
        If True, print additional information.

    Returns
    -------
    suff_sets : SufficientSets
        The distinct sufficient sets found for all pairs of nodes.
    """
    suff_sets = SufficientSets()
    for x, y in itertools.combinations(dag.nodes(), 2):
        if verbose:
            # end="" so the per-pair report continues on the same line.
            print(f"Checking pair ({x}, {y})...", end="")
        sufficient_set = get_sufficient_sets_for_pair(dag, x, y, verbose)
        if not sufficient_set:
            continue
        if verbose:
            print(f" Adding sufficient set: {sufficient_set}")
        suff_sets.add(sufficient_set)
    return suff_sets
# # XXX: This is not used, and should be removed or replaced by the NX # implementation of d-separation or DoWhy implementation. #
def get_conditional_independencies(dag, verbose=False):
    """
    Computes a set of conditional independencies implied by the graph.

    For every pair of non-adjacent nodes ``(x, y)``:

    * if ``x`` and ``y`` are d-separated by the empty set, the marginal
      independence ``x ⊥ y`` is recorded;
    * otherwise, the interior nodes of every directed path
      ``x -> ... -> y`` are proposed as a conditioning set. The claim
      ``x ⊥ y | interior`` is recorded only if that set actually
      d-separates ``x`` and ``y``. Multi-node sets equal to the whole rest
      of the graph are skipped.

    NOTE(review): only directed paths from ``x`` to ``y`` propose
    conditioning sets. Independencies that need conditioning on a common
    cause (e.g. ``x <- z -> y``) are therefore not reported, and the result
    is a subset of the independencies implied by the DAG.

    Parameters
    ----------
    dag : networkx.DiGraph
        The directed acyclic graph representing the causal relationships
        between the variables.
    verbose : bool, optional
        If True, prints additional information about the computation.

    Returns
    -------
    cond_indeps : ConditionalIndependencies
        The conditional independencies found.

    Raises
    ------
    networkx.NetworkXError
        Raised by the NetworkX d-separation test when ``dag`` is not a
        directed acyclic graph.
    """
    # NetworkX 3.3 renamed ``d_separated`` to ``is_d_separator`` and
    # deprecated the old name. Use whichever is available.
    is_d_separator = getattr(nx, "is_d_separator", None) or nx.d_separated

    cond_indeps = ConditionalIndependencies()
    for x, y in itertools.combinations(dag.nodes(), 2):
        # Adjacent nodes can never be d-separated.
        if dag.has_edge(x, y) or dag.has_edge(y, x):
            continue

        if is_d_separator(dag, {x}, {y}, set()):
            cond_indeps.add(x, y)
            if verbose:
                print(f"Pair ({x}, {y})\n"
                      f" {x} ⊥ {y} | ∅")
            continue

        if verbose:
            print(f"Pair ({x}, {y})")
        rest_of_graph = set(dag.nodes()) - {x, y}
        for path in get_paths(dag, x, y):
            if verbose:
                print(" Path:", path)
            # The graph has passed the DAG check above, so every interior
            # node of a directed path is a non-collider (a collider would
            # need a 2-cycle). The interior is therefore the candidate set.
            blockers = set(path[1:-1])
            if len(blockers) > 1 and blockers == rest_of_graph:
                if verbose:
                    print(f" The set of blockers is the entire graph. ")
                continue
            # Blocking one path does not block the others, so verify.
            if not is_d_separator(dag, {x}, {y}, blockers):
                if verbose:
                    print(f" {blockers} does not d-separate {x} and {y}")
                continue
            # Sort so the stored key does not depend on set iteration order.
            cond_indeps.add(x, y, sorted(blockers, key=str))
            if verbose:
                print(f" (no colliders on path {path})\n"
                      f" {x} ⊥ {y} | {blockers}")
    return cond_indeps
def custom_main():
    """Demo: print the sufficient sets and conditional independencies of a
    four-node DAG, with verbose tracing enabled for both."""
    edges = [('x1', 'x2'), ('x2', 'x3'), ('x1', 'x4'), ('x2', 'x4')]
    graph = nx.DiGraph()
    graph.add_edges_from(edges)
    print(get_sufficient_sets(graph, verbose=True))
    print(get_conditional_independencies(graph, verbose=True))
def main():
    """Demo: print the sufficient sets (verbose) and the conditional
    independencies (quiet) of a DAG over x, y, z1, z2 and z3."""
    edges = [
        ('z1', 'x'), ('z1', 'z3'), ('z3', 'x'), ('z3', 'y'),
        ('x', 'y'), ('z2', 'z3'), ('z2', 'y'),
    ]
    graph = nx.DiGraph()
    graph.add_edges_from(edges)
    print(get_sufficient_sets(graph, verbose=True))
    print(get_conditional_independencies(graph, verbose=False))
def dag_main():
    """Demo: build the same four-node graph with the project's ``DAG`` class
    and print the result of its ``get_independencies()`` method."""
    # Imported lazily so the module does not depend on the PC package
    # unless this demo runs.
    from causalexplain.estimators.pc.dag import DAG

    edges = [('x1', 'x2'), ('x2', 'x3'), ('x1', 'x4'), ('x2', 'x4')]
    print(DAG(edges).get_independencies())
if __name__ == "__main__":
    # Replace custom_main() with main() to run the larger example instead.
    custom_main()
    print("\n-----\n")
    dag_main()