Source code for causalexplain.estimators.pc.pdag

import networkx as nx

from .dag import DAG
from warnings import warn


class PDAG(nx.DiGraph):
    """
    Class for representing PDAGs (also known as CPDAGs).

    PDAGs are the equivalence classes of DAGs and contain both directed and
    undirected edges.

    **Note:** In this class, undirected edges are represented using two
    edges in both directions, i.e. an undirected edge between X - Y is
    represented using X -> Y and X <- Y.

    Attributes
    ----------
    latents: set
        Nodes which are latent variables.
    directed_edges: set
        The directed edges given at construction time.
    undirected_edges: set
        The undirected edges given at construction time (one orientation
        per edge, exactly as supplied).
    """

    def __init__(self, directed_ebunch=None, undirected_ebunch=None, latents=None):
        """
        Initializes a PDAG class.

        Parameters
        ----------
        directed_ebunch: list, array-like of 2-tuples
            List of directed edges in the PDAG. Defaults to no edges.

        undirected_ebunch: list, array-like of 2-tuples
            List of undirected edges in the PDAG. Defaults to no edges.

        latents: list, array-like
            List of nodes which are latent variables. Defaults to none.

        Returns
        -------
        An instance of the PDAG object.
        """
        # Normalise the inputs: any iterable of edges is accepted (not only
        # lists), and every edge is stored as a hashable tuple so it can be
        # kept in the edge sets below.
        directed_ebunch = (
            [] if directed_ebunch is None else [tuple(e) for e in directed_ebunch]
        )
        undirected_ebunch = (
            [] if undirected_ebunch is None else [tuple(e) for e in undirected_ebunch]
        )
        latents = () if latents is None else latents

        # Every undirected edge X - Y is inserted as both X -> Y and Y -> X.
        super().__init__(
            directed_ebunch
            + undirected_ebunch
            + [(Y, X) for (X, Y) in undirected_ebunch]
        )
        self.latents = set(latents)
        self.directed_edges = set(directed_ebunch)
        self.undirected_edges = set(undirected_ebunch)

        # TODO: Fix the cycle issue. No cycle check is performed here yet:
        # every undirected edge is stored as a pair of opposite directed
        # edges and therefore forms a cycle of length 2 on its own, so a
        # check based on nx.simple_cycles would have to ignore cycles of
        # length <= 2 before raising a ValueError.

    def copy(self):
        """
        Returns a copy of the object instance.

        The copy is rebuilt from ``directed_edges``, ``undirected_edges`` and
        ``latents``; nodes that take part in no edge are carried over as
        well.

        Returns
        -------
        PDAG instance: Returns a copy of self.
        """
        duplicate = PDAG(
            directed_ebunch=list(self.directed_edges),
            undirected_ebunch=list(self.undirected_edges),
            latents=self.latents,
        )
        # Rebuilding from the edge sets alone would drop isolated nodes.
        duplicate.add_nodes_from(self.nodes())
        return duplicate

    def to_dag(self, required_edges=None):
        """
        Returns one possible DAG which is represented using the PDAG.

        Repeatedly picks a node X that has no directed outgoing edge and
        whose undirected neighbours are adjacent to all other predecessors
        of X, orients every remaining edge at X towards X, and removes X
        from a working copy of the PDAG. If no such node exists, a warning
        is emitted and the remaining edges are oriented arbitrarily.

        Parameters
        ----------
        required_edges: list, array-like of 2-tuples
            The list of edges that should be included in the DAG.
            NOTE: this argument is currently not used by the implementation.

        Returns
        -------
        Returns an instance of DAG.
        """
        dag = DAG()
        # Add all the nodes and the directed edges.
        dag.add_nodes_from(self.nodes())
        dag.add_edges_from(self.directed_edges)
        # Give the DAG its own set so later changes do not leak back here.
        dag.latents = set(self.latents)

        pdag = self.copy()
        while pdag.number_of_nodes() > 0:
            # Find a node with (1) no directed outgoing edges and
            # (2) the set of undirected neighbors is either empty or
            #     undirected neighbors + parents of X are a clique.
            found = False
            for X in pdag.nodes():
                parents = list(pdag.predecessors(X))
                successors = set(pdag.successors(X))
                directed_outgoing_edges = successors - set(parents)
                undirected_neighbors = successors & set(parents)
                neighbors_are_clique = all(
                    pdag.has_edge(Y, Z)
                    for Z in parents
                    for Y in undirected_neighbors
                    if Y != Z
                )

                if not directed_outgoing_edges and (
                    not undirected_neighbors or neighbors_are_clique
                ):
                    found = True
                    # Orient all remaining edges at X as incoming edges of X.
                    for Y in parents:
                        dag.add_edge(Y, X)
                    # Safe while iterating: the loop is left immediately.
                    pdag.remove_node(X)
                    break

            if not found:
                warn(
                    "PDAG has no faithful extension (= no oriented DAG with the "
                    + "same v-structures as PDAG). Remaining undirected PDAG edges "
                    + "oriented arbitrarily."
                )
                for X, Y in pdag.edges():
                    if not dag.has_edge(Y, X):
                        try:
                            dag.add_edge(X, Y)
                        except ValueError:
                            # Best effort: skip edges the DAG rejects
                            # (presumably because they would close a cycle;
                            # confirm against DAG.add_edge).
                            pass
                break

        return dag