about summary refs log tree commit diff
path: root/.venv/lib/python3.12/site-packages/networkx/algorithms/tests/test_d_separation.py
diff options
context:
space:
mode:
authorS. Solomon Darnell2025-03-28 21:52:21 -0500
committerS. Solomon Darnell2025-03-28 21:52:21 -0500
commit4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
treeee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/networkx/algorithms/tests/test_d_separation.py
parentcc961e04ba734dd72309fb548a2f97d67d578813 (diff)
downloadgn-ai-master.tar.gz
two version of R2R are here HEAD master
Diffstat (limited to '.venv/lib/python3.12/site-packages/networkx/algorithms/tests/test_d_separation.py')
-rw-r--r--.venv/lib/python3.12/site-packages/networkx/algorithms/tests/test_d_separation.py348
1 files changed, 348 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/networkx/algorithms/tests/test_d_separation.py b/.venv/lib/python3.12/site-packages/networkx/algorithms/tests/test_d_separation.py
new file mode 100644
index 00000000..6f629713
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/networkx/algorithms/tests/test_d_separation.py
@@ -0,0 +1,348 @@
+from itertools import combinations
+
+import pytest
+
+import networkx as nx
+
+
+def path_graph():
+    """Return a path graph of length three."""
+    G = nx.path_graph(3, create_using=nx.DiGraph)
+    G.graph["name"] = "path"
+    nx.freeze(G)
+    return G
+
+
+def fork_graph():
+    """Return a three node fork graph."""
+    G = nx.DiGraph(name="fork")
+    G.add_edges_from([(0, 1), (0, 2)])
+    nx.freeze(G)
+    return G
+
+
+def collider_graph():
+    """Return a collider/v-structure graph with three nodes."""
+    G = nx.DiGraph(name="collider")
+    G.add_edges_from([(0, 2), (1, 2)])
+    nx.freeze(G)
+    return G
+
+
+def naive_bayes_graph():
+    """Return a simply Naive Bayes PGM graph."""
+    G = nx.DiGraph(name="naive_bayes")
+    G.add_edges_from([(0, 1), (0, 2), (0, 3), (0, 4)])
+    nx.freeze(G)
+    return G
+
+
+def asia_graph():
+    """Return the 'Asia' PGM graph."""
+    G = nx.DiGraph(name="asia")
+    G.add_edges_from(
+        [
+            ("asia", "tuberculosis"),
+            ("smoking", "cancer"),
+            ("smoking", "bronchitis"),
+            ("tuberculosis", "either"),
+            ("cancer", "either"),
+            ("either", "xray"),
+            ("either", "dyspnea"),
+            ("bronchitis", "dyspnea"),
+        ]
+    )
+    nx.freeze(G)
+    return G
+
+
+@pytest.fixture(name="path_graph")
+def path_graph_fixture():
+    return path_graph()
+
+
+@pytest.fixture(name="fork_graph")
+def fork_graph_fixture():
+    return fork_graph()
+
+
+@pytest.fixture(name="collider_graph")
+def collider_graph_fixture():
+    return collider_graph()
+
+
+@pytest.fixture(name="naive_bayes_graph")
+def naive_bayes_graph_fixture():
+    return naive_bayes_graph()
+
+
+@pytest.fixture(name="asia_graph")
+def asia_graph_fixture():
+    return asia_graph()
+
+
+@pytest.fixture()
+def large_collider_graph():
+    edge_list = [("A", "B"), ("C", "B"), ("B", "D"), ("D", "E"), ("B", "F"), ("G", "E")]
+    G = nx.DiGraph(edge_list)
+    return G
+
+
+@pytest.fixture()
+def chain_and_fork_graph():
+    edge_list = [("A", "B"), ("B", "C"), ("B", "D"), ("D", "C")]
+    G = nx.DiGraph(edge_list)
+    return G
+
+
+@pytest.fixture()
+def no_separating_set_graph():
+    edge_list = [("A", "B")]
+    G = nx.DiGraph(edge_list)
+    return G
+
+
+@pytest.fixture()
+def large_no_separating_set_graph():
+    edge_list = [("A", "B"), ("C", "A"), ("C", "B")]
+    G = nx.DiGraph(edge_list)
+    return G
+
+
+@pytest.fixture()
+def collider_trek_graph():
+    edge_list = [("A", "B"), ("C", "B"), ("C", "D")]
+    G = nx.DiGraph(edge_list)
+    return G
+
+
+@pytest.mark.parametrize(
+    "graph",
+    [path_graph(), fork_graph(), collider_graph(), naive_bayes_graph(), asia_graph()],
+)
+def test_markov_condition(graph):
+    """Test that the Markov condition holds for each PGM graph."""
+    for node in graph.nodes:
+        parents = set(graph.predecessors(node))
+        non_descendants = graph.nodes - nx.descendants(graph, node) - {node} - parents
+        assert nx.is_d_separator(graph, {node}, non_descendants, parents)
+
+
+def test_path_graph_dsep(path_graph):
+    """Example-based test of d-separation for path_graph."""
+    assert nx.is_d_separator(path_graph, {0}, {2}, {1})
+    assert not nx.is_d_separator(path_graph, {0}, {2}, set())
+
+
+def test_fork_graph_dsep(fork_graph):
+    """Example-based test of d-separation for fork_graph."""
+    assert nx.is_d_separator(fork_graph, {1}, {2}, {0})
+    assert not nx.is_d_separator(fork_graph, {1}, {2}, set())
+
+
+def test_collider_graph_dsep(collider_graph):
+    """Example-based test of d-separation for collider_graph."""
+    assert nx.is_d_separator(collider_graph, {0}, {1}, set())
+    assert not nx.is_d_separator(collider_graph, {0}, {1}, {2})
+
+
+def test_naive_bayes_dsep(naive_bayes_graph):
+    """Example-based test of d-separation for naive_bayes_graph."""
+    for u, v in combinations(range(1, 5), 2):
+        assert nx.is_d_separator(naive_bayes_graph, {u}, {v}, {0})
+        assert not nx.is_d_separator(naive_bayes_graph, {u}, {v}, set())
+
+
+def test_asia_graph_dsep(asia_graph):
+    """Example-based test of d-separation for asia_graph."""
+    assert nx.is_d_separator(
+        asia_graph, {"asia", "smoking"}, {"dyspnea", "xray"}, {"bronchitis", "either"}
+    )
+    assert nx.is_d_separator(
+        asia_graph, {"tuberculosis", "cancer"}, {"bronchitis"}, {"smoking", "xray"}
+    )
+
+
+def test_undirected_graphs_are_not_supported():
+    """
+    Test that undirected graphs are not supported.
+
+    d-separation and its related algorithms do not apply in
+    the case of undirected graphs.
+    """
+    g = nx.path_graph(3, nx.Graph)
+    with pytest.raises(nx.NetworkXNotImplemented):
+        nx.is_d_separator(g, {0}, {1}, {2})
+    with pytest.raises(nx.NetworkXNotImplemented):
+        nx.is_minimal_d_separator(g, {0}, {1}, {2})
+    with pytest.raises(nx.NetworkXNotImplemented):
+        nx.find_minimal_d_separator(g, {0}, {1})
+
+
+def test_cyclic_graphs_raise_error():
+    """
+    Test that cycle graphs should cause erroring.
+
+    This is because PGMs assume a directed acyclic graph.
+    """
+    g = nx.cycle_graph(3, nx.DiGraph)
+    with pytest.raises(nx.NetworkXError):
+        nx.is_d_separator(g, {0}, {1}, {2})
+    with pytest.raises(nx.NetworkXError):
+        nx.find_minimal_d_separator(g, {0}, {1})
+    with pytest.raises(nx.NetworkXError):
+        nx.is_minimal_d_separator(g, {0}, {1}, {2})
+
+
+def test_invalid_nodes_raise_error(asia_graph):
+    """
+    Test that graphs that have invalid nodes passed in raise errors.
+    """
+    # Check both set and node arguments
+    with pytest.raises(nx.NodeNotFound):
+        nx.is_d_separator(asia_graph, {0}, {1}, {2})
+    with pytest.raises(nx.NodeNotFound):
+        nx.is_d_separator(asia_graph, 0, 1, 2)
+    with pytest.raises(nx.NodeNotFound):
+        nx.is_minimal_d_separator(asia_graph, {0}, {1}, {2})
+    with pytest.raises(nx.NodeNotFound):
+        nx.is_minimal_d_separator(asia_graph, 0, 1, 2)
+    with pytest.raises(nx.NodeNotFound):
+        nx.find_minimal_d_separator(asia_graph, {0}, {1})
+    with pytest.raises(nx.NodeNotFound):
+        nx.find_minimal_d_separator(asia_graph, 0, 1)
+
+
+def test_nondisjoint_node_sets_raise_error(collider_graph):
+    """
+    Test that error is raised when node sets aren't disjoint.
+    """
+    with pytest.raises(nx.NetworkXError):
+        nx.is_d_separator(collider_graph, 0, 1, 0)
+    with pytest.raises(nx.NetworkXError):
+        nx.is_d_separator(collider_graph, 0, 2, 0)
+    with pytest.raises(nx.NetworkXError):
+        nx.is_d_separator(collider_graph, 0, 0, 1)
+    with pytest.raises(nx.NetworkXError):
+        nx.is_d_separator(collider_graph, 1, 0, 0)
+    with pytest.raises(nx.NetworkXError):
+        nx.find_minimal_d_separator(collider_graph, 0, 0)
+    with pytest.raises(nx.NetworkXError):
+        nx.find_minimal_d_separator(collider_graph, 0, 1, included=0)
+    with pytest.raises(nx.NetworkXError):
+        nx.find_minimal_d_separator(collider_graph, 1, 0, included=0)
+    with pytest.raises(nx.NetworkXError):
+        nx.is_minimal_d_separator(collider_graph, 0, 0, set())
+    with pytest.raises(nx.NetworkXError):
+        nx.is_minimal_d_separator(collider_graph, 0, 1, set(), included=0)
+    with pytest.raises(nx.NetworkXError):
+        nx.is_minimal_d_separator(collider_graph, 1, 0, set(), included=0)
+
+
+def test_is_minimal_d_separator(
+    large_collider_graph,
+    chain_and_fork_graph,
+    no_separating_set_graph,
+    large_no_separating_set_graph,
+    collider_trek_graph,
+):
+    # Case 1:
+    # create a graph A -> B <- C
+    # B -> D -> E;
+    # B -> F;
+    # G -> E;
+    assert not nx.is_d_separator(large_collider_graph, {"B"}, {"E"}, set())
+
+    # minimal set of the corresponding graph
+    # for B and E should be (D,)
+    Zmin = nx.find_minimal_d_separator(large_collider_graph, "B", "E")
+    # check that the minimal d-separator is a d-separating set
+    assert nx.is_d_separator(large_collider_graph, "B", "E", Zmin)
+    # the minimal separating set should also pass the test for minimality
+    assert nx.is_minimal_d_separator(large_collider_graph, "B", "E", Zmin)
+    # function should also work with set arguments
+    assert nx.is_minimal_d_separator(large_collider_graph, {"A", "B"}, {"G", "E"}, Zmin)
+    assert Zmin == {"D"}
+
+    # Case 2:
+    # create a graph A -> B -> C
+    # B -> D -> C;
+    assert not nx.is_d_separator(chain_and_fork_graph, {"A"}, {"C"}, set())
+    Zmin = nx.find_minimal_d_separator(chain_and_fork_graph, "A", "C")
+
+    # the minimal separating set should pass the test for minimality
+    assert nx.is_minimal_d_separator(chain_and_fork_graph, "A", "C", Zmin)
+    assert Zmin == {"B"}
+    Znotmin = Zmin.union({"D"})
+    assert not nx.is_minimal_d_separator(chain_and_fork_graph, "A", "C", Znotmin)
+
+    # Case 3:
+    # create a graph A -> B
+
+    # there is no m-separating set between A and B at all, so
+    # no minimal m-separating set can exist
+    assert not nx.is_d_separator(no_separating_set_graph, {"A"}, {"B"}, set())
+    assert nx.find_minimal_d_separator(no_separating_set_graph, "A", "B") is None
+
+    # Case 4:
+    # create a graph A -> B with A <- C -> B
+
+    # there is no m-separating set between A and B at all, so
+    # no minimal m-separating set can exist
+    # however, the algorithm will initially propose C as a
+    # minimal (but invalid) separating set
+    assert not nx.is_d_separator(large_no_separating_set_graph, {"A"}, {"B"}, {"C"})
+    assert nx.find_minimal_d_separator(large_no_separating_set_graph, "A", "B") is None
+
+    # Test `included` and `excluded` args
+    # create graph A -> B <- C -> D
+    assert nx.find_minimal_d_separator(collider_trek_graph, "A", "D", included="B") == {
+        "B",
+        "C",
+    }
+    assert (
+        nx.find_minimal_d_separator(
+            collider_trek_graph, "A", "D", included="B", restricted="B"
+        )
+        is None
+    )
+
+
+def test_is_minimal_d_separator_checks_dsep():
+    """Test that is_minimal_d_separator checks for d-separation as well."""
+    g = nx.DiGraph()
+    g.add_edges_from(
+        [
+            ("A", "B"),
+            ("A", "E"),
+            ("B", "C"),
+            ("B", "D"),
+            ("D", "C"),
+            ("D", "F"),
+            ("E", "D"),
+            ("E", "F"),
+        ]
+    )
+
+    assert not nx.is_d_separator(g, {"C"}, {"F"}, {"D"})
+
+    # since {'D'} and {} are not d-separators, we return false
+    assert not nx.is_minimal_d_separator(g, "C", "F", {"D"})
+    assert not nx.is_minimal_d_separator(g, "C", "F", set())
+
+
+def test__reachable(large_collider_graph):
+    reachable = nx.algorithms.d_separation._reachable
+    g = large_collider_graph
+    x = {"F", "D"}
+    ancestors = {"A", "B", "C", "D", "F"}
+    assert reachable(g, x, ancestors, {"B"}) == {"B", "F", "D"}
+    assert reachable(g, x, ancestors, set()) == ancestors
+
+
+def test_deprecations():
+    G = nx.DiGraph([(0, 1), (1, 2)])
+    with pytest.deprecated_call():
+        nx.d_separated(G, 0, 2, {1})
+    with pytest.deprecated_call():
+        z = nx.minimal_d_separator(G, 0, 2)