Source code for causalexplain.estimators.cam.pruning
"""
- Python translation of the R `pruning` function; the notes below list how the R idioms were mapped to NumPy/Python.
- `dim(G)[1]` is replaced with `G.shape[0]` to get the number of rows.
- `matrix(0,p,p)` is replaced with `np.zeros((p, p))` to create a zero matrix.
- `which(G[,i]==1)` is replaced with `np.where(G[:, i] == 1)[0]` to find the indices
where the condition is true.
- `cbind(X[,parents],X[,i])` is replaced with `np.hstack((X[:, parents], X[:, [i]]))`
to concatenate arrays horizontally.
- The `cat` function is replaced with `print` for output.
- The `pruneMethod` function is passed as `prune_method` and called accordingly.
"""
import numpy as np
from .selGam import selGam
[docs]
def pruning(
        X,
        G,
        verbose=False,
        prune_method=None,
        prune_method_pars=None):
    """Prune the parent sets of a candidate DAG.

    For every variable ``i``, the candidate parents are the indices ``j``
    with ``G[j, i] == 1``. The data columns of those parents, followed by the
    column of ``i`` itself, are handed to ``prune_method``. Its result selects
    which candidates are kept, and kept edges are set to 1 in the output.

    Args:
        X (array-like): 2-D data matrix whose column ``i`` corresponds to
            variable ``i`` (presumably one row per sample). Anything accepted
            by ``np.asarray`` works (ndarray, DataFrame, nested lists).
        G (array-like): Square adjacency matrix of the candidate DAG. An
            entry ``G[j, i] == 1`` marks ``j`` as a candidate parent of ``i``.
        verbose (bool, optional): Whether to print debug messages. It is
            also forwarded to ``prune_method``. Defaults to False.
        prune_method (callable, optional): Called as
            ``prune_method(Xtmp, pars=..., verbose=..., k=len(parents) + 1)``,
            where the last column of ``Xtmp`` is the target variable. Its
            return value is used to index the candidate-parent array (e.g. a
            boolean mask or integer positions). Defaults to ``selGam``.
        prune_method_pars (dict, optional): Parameters forwarded to
            ``prune_method`` as ``pars``. Defaults to
            ``{'cutOffPVal': 0.001, 'numBasisFcts': 10}``, built fresh on
            every call.

    Returns:
        np.ndarray: ``(p, p)`` float matrix with ``[j, i] == 1`` when ``j``
        was kept as a parent of ``i`` and 0 elsewhere.
    """
    if prune_method is None:
        prune_method = selGam
    if prune_method_pars is None:
        # Build the defaults per call. A shared mutable default would leak
        # any mutation made by prune_method into every later call.
        prune_method_pars = {'cutOffPVal': 0.001, 'numBasisFcts': 10}

    # The 2-D fancy indexing below requires real ndarrays.
    X = np.asarray(X)
    G = np.asarray(G)

    p = G.shape[0]
    finalG = np.zeros((p, p))
    for i in range(p):
        parents = np.where(G[:, i] == 1)[0]
        lenpa = len(parents)
        if verbose:
            print(f"Pruning variable: {i}")
            print(f". Considered parents: {parents}")
        if lenpa > 0:
            # Candidate parents first, target variable i as the last column.
            Xtmp = np.hstack((X[:, parents], X[:, [i]]))
            selected_par = prune_method(
                Xtmp, pars=prune_method_pars, verbose=verbose, k=lenpa + 1)
            final_parents = parents[selected_par]
            finalG[final_parents, i] = 1
            if verbose:
                print(f". Final parents: {final_parents}")
                pruned = [j for j in parents if j not in final_parents]
                print(f". Pruned parents of {i}: {pruned}")
    return finalG