aboutsummaryrefslogtreecommitdiff
path: root/.venv/lib/python3.12/site-packages/networkx/algorithms/tests/test_d_separation.py
diff options
context:
space:
mode:
Diffstat (limited to '.venv/lib/python3.12/site-packages/networkx/algorithms/tests/test_d_separation.py')
-rw-r--r--.venv/lib/python3.12/site-packages/networkx/algorithms/tests/test_d_separation.py348
1 files changed, 348 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/networkx/algorithms/tests/test_d_separation.py b/.venv/lib/python3.12/site-packages/networkx/algorithms/tests/test_d_separation.py
new file mode 100644
index 00000000..6f629713
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/networkx/algorithms/tests/test_d_separation.py
@@ -0,0 +1,348 @@
+from itertools import combinations
+
+import pytest
+
+import networkx as nx
+
+
+def path_graph():
+ """Return a path graph of length three."""
+ G = nx.path_graph(3, create_using=nx.DiGraph)
+ G.graph["name"] = "path"
+ nx.freeze(G)
+ return G
+
+
+def fork_graph():
+ """Return a three node fork graph."""
+ G = nx.DiGraph(name="fork")
+ G.add_edges_from([(0, 1), (0, 2)])
+ nx.freeze(G)
+ return G
+
+
+def collider_graph():
+ """Return a collider/v-structure graph with three nodes."""
+ G = nx.DiGraph(name="collider")
+ G.add_edges_from([(0, 2), (1, 2)])
+ nx.freeze(G)
+ return G
+
+
+def naive_bayes_graph():
+ """Return a simply Naive Bayes PGM graph."""
+ G = nx.DiGraph(name="naive_bayes")
+ G.add_edges_from([(0, 1), (0, 2), (0, 3), (0, 4)])
+ nx.freeze(G)
+ return G
+
+
+def asia_graph():
+ """Return the 'Asia' PGM graph."""
+ G = nx.DiGraph(name="asia")
+ G.add_edges_from(
+ [
+ ("asia", "tuberculosis"),
+ ("smoking", "cancer"),
+ ("smoking", "bronchitis"),
+ ("tuberculosis", "either"),
+ ("cancer", "either"),
+ ("either", "xray"),
+ ("either", "dyspnea"),
+ ("bronchitis", "dyspnea"),
+ ]
+ )
+ nx.freeze(G)
+ return G
+
+
+@pytest.fixture(name="path_graph")
+def path_graph_fixture():
+ return path_graph()
+
+
+@pytest.fixture(name="fork_graph")
+def fork_graph_fixture():
+ return fork_graph()
+
+
+@pytest.fixture(name="collider_graph")
+def collider_graph_fixture():
+ return collider_graph()
+
+
+@pytest.fixture(name="naive_bayes_graph")
+def naive_bayes_graph_fixture():
+ return naive_bayes_graph()
+
+
+@pytest.fixture(name="asia_graph")
+def asia_graph_fixture():
+ return asia_graph()
+
+
+@pytest.fixture()
+def large_collider_graph():
+ edge_list = [("A", "B"), ("C", "B"), ("B", "D"), ("D", "E"), ("B", "F"), ("G", "E")]
+ G = nx.DiGraph(edge_list)
+ return G
+
+
+@pytest.fixture()
+def chain_and_fork_graph():
+ edge_list = [("A", "B"), ("B", "C"), ("B", "D"), ("D", "C")]
+ G = nx.DiGraph(edge_list)
+ return G
+
+
+@pytest.fixture()
+def no_separating_set_graph():
+ edge_list = [("A", "B")]
+ G = nx.DiGraph(edge_list)
+ return G
+
+
+@pytest.fixture()
+def large_no_separating_set_graph():
+ edge_list = [("A", "B"), ("C", "A"), ("C", "B")]
+ G = nx.DiGraph(edge_list)
+ return G
+
+
+@pytest.fixture()
+def collider_trek_graph():
+ edge_list = [("A", "B"), ("C", "B"), ("C", "D")]
+ G = nx.DiGraph(edge_list)
+ return G
+
+
+@pytest.mark.parametrize(
+ "graph",
+ [path_graph(), fork_graph(), collider_graph(), naive_bayes_graph(), asia_graph()],
+)
+def test_markov_condition(graph):
+ """Test that the Markov condition holds for each PGM graph."""
+ for node in graph.nodes:
+ parents = set(graph.predecessors(node))
+ non_descendants = graph.nodes - nx.descendants(graph, node) - {node} - parents
+ assert nx.is_d_separator(graph, {node}, non_descendants, parents)
+
+
+def test_path_graph_dsep(path_graph):
+ """Example-based test of d-separation for path_graph."""
+ assert nx.is_d_separator(path_graph, {0}, {2}, {1})
+ assert not nx.is_d_separator(path_graph, {0}, {2}, set())
+
+
+def test_fork_graph_dsep(fork_graph):
+ """Example-based test of d-separation for fork_graph."""
+ assert nx.is_d_separator(fork_graph, {1}, {2}, {0})
+ assert not nx.is_d_separator(fork_graph, {1}, {2}, set())
+
+
+def test_collider_graph_dsep(collider_graph):
+ """Example-based test of d-separation for collider_graph."""
+ assert nx.is_d_separator(collider_graph, {0}, {1}, set())
+ assert not nx.is_d_separator(collider_graph, {0}, {1}, {2})
+
+
+def test_naive_bayes_dsep(naive_bayes_graph):
+ """Example-based test of d-separation for naive_bayes_graph."""
+ for u, v in combinations(range(1, 5), 2):
+ assert nx.is_d_separator(naive_bayes_graph, {u}, {v}, {0})
+ assert not nx.is_d_separator(naive_bayes_graph, {u}, {v}, set())
+
+
+def test_asia_graph_dsep(asia_graph):
+ """Example-based test of d-separation for asia_graph."""
+ assert nx.is_d_separator(
+ asia_graph, {"asia", "smoking"}, {"dyspnea", "xray"}, {"bronchitis", "either"}
+ )
+ assert nx.is_d_separator(
+ asia_graph, {"tuberculosis", "cancer"}, {"bronchitis"}, {"smoking", "xray"}
+ )
+
+
+def test_undirected_graphs_are_not_supported():
+ """
+ Test that undirected graphs are not supported.
+
+ d-separation and its related algorithms do not apply in
+ the case of undirected graphs.
+ """
+ g = nx.path_graph(3, nx.Graph)
+ with pytest.raises(nx.NetworkXNotImplemented):
+ nx.is_d_separator(g, {0}, {1}, {2})
+ with pytest.raises(nx.NetworkXNotImplemented):
+ nx.is_minimal_d_separator(g, {0}, {1}, {2})
+ with pytest.raises(nx.NetworkXNotImplemented):
+ nx.find_minimal_d_separator(g, {0}, {1})
+
+
+def test_cyclic_graphs_raise_error():
+ """
+ Test that cycle graphs should cause erroring.
+
+ This is because PGMs assume a directed acyclic graph.
+ """
+ g = nx.cycle_graph(3, nx.DiGraph)
+ with pytest.raises(nx.NetworkXError):
+ nx.is_d_separator(g, {0}, {1}, {2})
+ with pytest.raises(nx.NetworkXError):
+ nx.find_minimal_d_separator(g, {0}, {1})
+ with pytest.raises(nx.NetworkXError):
+ nx.is_minimal_d_separator(g, {0}, {1}, {2})
+
+
+def test_invalid_nodes_raise_error(asia_graph):
+ """
+ Test that graphs that have invalid nodes passed in raise errors.
+ """
+ # Check both set and node arguments
+ with pytest.raises(nx.NodeNotFound):
+ nx.is_d_separator(asia_graph, {0}, {1}, {2})
+ with pytest.raises(nx.NodeNotFound):
+ nx.is_d_separator(asia_graph, 0, 1, 2)
+ with pytest.raises(nx.NodeNotFound):
+ nx.is_minimal_d_separator(asia_graph, {0}, {1}, {2})
+ with pytest.raises(nx.NodeNotFound):
+ nx.is_minimal_d_separator(asia_graph, 0, 1, 2)
+ with pytest.raises(nx.NodeNotFound):
+ nx.find_minimal_d_separator(asia_graph, {0}, {1})
+ with pytest.raises(nx.NodeNotFound):
+ nx.find_minimal_d_separator(asia_graph, 0, 1)
+
+
+def test_nondisjoint_node_sets_raise_error(collider_graph):
+ """
+ Test that error is raised when node sets aren't disjoint.
+ """
+ with pytest.raises(nx.NetworkXError):
+ nx.is_d_separator(collider_graph, 0, 1, 0)
+ with pytest.raises(nx.NetworkXError):
+ nx.is_d_separator(collider_graph, 0, 2, 0)
+ with pytest.raises(nx.NetworkXError):
+ nx.is_d_separator(collider_graph, 0, 0, 1)
+ with pytest.raises(nx.NetworkXError):
+ nx.is_d_separator(collider_graph, 1, 0, 0)
+ with pytest.raises(nx.NetworkXError):
+ nx.find_minimal_d_separator(collider_graph, 0, 0)
+ with pytest.raises(nx.NetworkXError):
+ nx.find_minimal_d_separator(collider_graph, 0, 1, included=0)
+ with pytest.raises(nx.NetworkXError):
+ nx.find_minimal_d_separator(collider_graph, 1, 0, included=0)
+ with pytest.raises(nx.NetworkXError):
+ nx.is_minimal_d_separator(collider_graph, 0, 0, set())
+ with pytest.raises(nx.NetworkXError):
+ nx.is_minimal_d_separator(collider_graph, 0, 1, set(), included=0)
+ with pytest.raises(nx.NetworkXError):
+ nx.is_minimal_d_separator(collider_graph, 1, 0, set(), included=0)
+
+
+def test_is_minimal_d_separator(
+ large_collider_graph,
+ chain_and_fork_graph,
+ no_separating_set_graph,
+ large_no_separating_set_graph,
+ collider_trek_graph,
+):
+ # Case 1:
+ # create a graph A -> B <- C
+ # B -> D -> E;
+ # B -> F;
+ # G -> E;
+ assert not nx.is_d_separator(large_collider_graph, {"B"}, {"E"}, set())
+
+ # minimal set of the corresponding graph
+ # for B and E should be (D,)
+ Zmin = nx.find_minimal_d_separator(large_collider_graph, "B", "E")
+ # check that the minimal d-separator is a d-separating set
+ assert nx.is_d_separator(large_collider_graph, "B", "E", Zmin)
+ # the minimal separating set should also pass the test for minimality
+ assert nx.is_minimal_d_separator(large_collider_graph, "B", "E", Zmin)
+ # function should also work with set arguments
+ assert nx.is_minimal_d_separator(large_collider_graph, {"A", "B"}, {"G", "E"}, Zmin)
+ assert Zmin == {"D"}
+
+ # Case 2:
+ # create a graph A -> B -> C
+ # B -> D -> C;
+ assert not nx.is_d_separator(chain_and_fork_graph, {"A"}, {"C"}, set())
+ Zmin = nx.find_minimal_d_separator(chain_and_fork_graph, "A", "C")
+
+ # the minimal separating set should pass the test for minimality
+ assert nx.is_minimal_d_separator(chain_and_fork_graph, "A", "C", Zmin)
+ assert Zmin == {"B"}
+ Znotmin = Zmin.union({"D"})
+ assert not nx.is_minimal_d_separator(chain_and_fork_graph, "A", "C", Znotmin)
+
+ # Case 3:
+ # create a graph A -> B
+
+ # there is no m-separating set between A and B at all, so
+ # no minimal m-separating set can exist
+ assert not nx.is_d_separator(no_separating_set_graph, {"A"}, {"B"}, set())
+ assert nx.find_minimal_d_separator(no_separating_set_graph, "A", "B") is None
+
+ # Case 4:
+ # create a graph A -> B with A <- C -> B
+
+ # there is no m-separating set between A and B at all, so
+ # no minimal m-separating set can exist
+ # however, the algorithm will initially propose C as a
+ # minimal (but invalid) separating set
+ assert not nx.is_d_separator(large_no_separating_set_graph, {"A"}, {"B"}, {"C"})
+ assert nx.find_minimal_d_separator(large_no_separating_set_graph, "A", "B") is None
+
+ # Test `included` and `excluded` args
+ # create graph A -> B <- C -> D
+ assert nx.find_minimal_d_separator(collider_trek_graph, "A", "D", included="B") == {
+ "B",
+ "C",
+ }
+ assert (
+ nx.find_minimal_d_separator(
+ collider_trek_graph, "A", "D", included="B", restricted="B"
+ )
+ is None
+ )
+
+
+def test_is_minimal_d_separator_checks_dsep():
+ """Test that is_minimal_d_separator checks for d-separation as well."""
+ g = nx.DiGraph()
+ g.add_edges_from(
+ [
+ ("A", "B"),
+ ("A", "E"),
+ ("B", "C"),
+ ("B", "D"),
+ ("D", "C"),
+ ("D", "F"),
+ ("E", "D"),
+ ("E", "F"),
+ ]
+ )
+
+ assert not nx.is_d_separator(g, {"C"}, {"F"}, {"D"})
+
+ # since {'D'} and {} are not d-separators, we return false
+ assert not nx.is_minimal_d_separator(g, "C", "F", {"D"})
+ assert not nx.is_minimal_d_separator(g, "C", "F", set())
+
+
+def test__reachable(large_collider_graph):
+ reachable = nx.algorithms.d_separation._reachable
+ g = large_collider_graph
+ x = {"F", "D"}
+ ancestors = {"A", "B", "C", "D", "F"}
+ assert reachable(g, x, ancestors, {"B"}) == {"B", "F", "D"}
+ assert reachable(g, x, ancestors, set()) == ancestors
+
+
+def test_deprecations():
+ G = nx.DiGraph([(0, 1), (1, 2)])
+ with pytest.deprecated_call():
+ nx.d_separated(G, 0, 2, {1})
+ with pytest.deprecated_call():
+ z = nx.minimal_d_separator(G, 0, 2)