from itertools import combinations

import pytest

import networkx as nx


def path_graph():
    """Return a path graph of length three."""
    G = nx.path_graph(3, create_using=nx.DiGraph)
    G.graph["name"] = "path"
    nx.freeze(G)
    return G


def fork_graph():
    """Return a three node fork graph."""
    G = nx.DiGraph(name="fork")
    G.add_edges_from([(0, 1), (0, 2)])
    nx.freeze(G)
    return G


def collider_graph():
    """Return a collider/v-structure graph with three nodes."""
    G = nx.DiGraph(name="collider")
    G.add_edges_from([(0, 2), (1, 2)])
    nx.freeze(G)
    return G


def naive_bayes_graph():
    """Return a simple Naive Bayes PGM graph."""
    G = nx.DiGraph(name="naive_bayes")
    G.add_edges_from([(0, 1), (0, 2), (0, 3), (0, 4)])
    nx.freeze(G)
    return G


def asia_graph():
    """Return the 'Asia' PGM graph."""
    G = nx.DiGraph(name="asia")
    G.add_edges_from(
        [
            ("asia", "tuberculosis"),
            ("smoking", "cancer"),
            ("smoking", "bronchitis"),
            ("tuberculosis", "either"),
            ("cancer", "either"),
            ("either", "xray"),
            ("either", "dyspnea"),
            ("bronchitis", "dyspnea"),
        ]
    )
    nx.freeze(G)
    return G


# The builders above are plain functions so they can also be called directly
# inside ``pytest.mark.parametrize``; the fixtures below expose the same
# graphs under the builders' names for tests that request them as arguments.


@pytest.fixture(name="path_graph")
def path_graph_fixture():
    """Fixture wrapper around :func:`path_graph`."""
    return path_graph()


@pytest.fixture(name="fork_graph")
def fork_graph_fixture():
    """Fixture wrapper around :func:`fork_graph`."""
    return fork_graph()


@pytest.fixture(name="collider_graph")
def collider_graph_fixture():
    """Fixture wrapper around :func:`collider_graph`."""
    return collider_graph()


@pytest.fixture(name="naive_bayes_graph")
def naive_bayes_graph_fixture():
    """Fixture wrapper around :func:`naive_bayes_graph`."""
    return naive_bayes_graph()


@pytest.fixture(name="asia_graph")
def asia_graph_fixture():
    """Fixture wrapper around :func:`asia_graph`."""
    return asia_graph()


@pytest.fixture()
def large_collider_graph():
    """Return the DiGraph A -> B <- C, B -> D -> E, B -> F, G -> E."""
    edge_list = [("A", "B"), ("C", "B"), ("B", "D"), ("D", "E"), ("B", "F"), ("G", "E")]
    G = nx.DiGraph(edge_list)
    return G


@pytest.fixture()
def chain_and_fork_graph():
    """Return the DiGraph A -> B -> C with B -> D -> C."""
    edge_list = [("A", "B"), ("B", "C"), ("B", "D"), ("D", "C")]
    G = nx.DiGraph(edge_list)
    return G


@pytest.fixture()
def no_separating_set_graph():
    """Return the single-edge DiGraph A -> B."""
    edge_list = [("A", "B")]
    G = nx.DiGraph(edge_list)
    return G


@pytest.fixture()
def large_no_separating_set_graph():
    """Return the DiGraph A -> B with A <- C -> B."""
    edge_list = [("A", "B"), ("C", "A"), ("C", "B")]
    G = nx.DiGraph(edge_list)
    return G


@pytest.fixture()
def collider_trek_graph():
    """Return the DiGraph A -> B <- C -> D."""
    edge_list = [("A", "B"), ("C", "B"), ("C", "D")]
    G = nx.DiGraph(edge_list)
    return G


@pytest.mark.parametrize(
    "graph",
    [path_graph(), fork_graph(), collider_graph(), naive_bayes_graph(), asia_graph()],
)
def test_markov_condition(graph):
    """Test that the Markov condition holds for each PGM graph."""
    for node in graph.nodes:
        parents = set(graph.predecessors(node))
        # Everything that is neither the node itself, one of its descendants,
        # nor one of its parents.
        non_descendants = graph.nodes - nx.descendants(graph, node) - {node} - parents
        assert nx.is_d_separator(graph, {node}, non_descendants, parents)


def test_path_graph_dsep(path_graph):
    """Example-based test of d-separation for path_graph."""
    assert nx.is_d_separator(path_graph, {0}, {2}, {1})
    assert not nx.is_d_separator(path_graph, {0}, {2}, set())


def test_fork_graph_dsep(fork_graph):
    """Example-based test of d-separation for fork_graph."""
    assert nx.is_d_separator(fork_graph, {1}, {2}, {0})
    assert not nx.is_d_separator(fork_graph, {1}, {2}, set())


def test_collider_graph_dsep(collider_graph):
    """Example-based test of d-separation for collider_graph."""
    assert nx.is_d_separator(collider_graph, {0}, {1}, set())
    assert not nx.is_d_separator(collider_graph, {0}, {1}, {2})


def test_naive_bayes_dsep(naive_bayes_graph):
    """Example-based test of d-separation for naive_bayes_graph."""
    for u, v in combinations(range(1, 5), 2):
        assert nx.is_d_separator(naive_bayes_graph, {u}, {v}, {0})
        assert not nx.is_d_separator(naive_bayes_graph, {u}, {v}, set())


def test_asia_graph_dsep(asia_graph):
    """Example-based test of d-separation for asia_graph."""
    assert nx.is_d_separator(
        asia_graph, {"asia", "smoking"}, {"dyspnea", "xray"}, {"bronchitis", "either"}
    )
    assert nx.is_d_separator(
        asia_graph, {"tuberculosis", "cancer"}, {"bronchitis"}, {"smoking", "xray"}
    )


def test_undirected_graphs_are_not_supported():
    """
    Test that undirected graphs are not supported.

    d-separation and its related algorithms do not apply in
    the case of undirected graphs.
    """
    g = nx.path_graph(3, nx.Graph)
    with pytest.raises(nx.NetworkXNotImplemented):
        nx.is_d_separator(g, {0}, {1}, {2})
    with pytest.raises(nx.NetworkXNotImplemented):
        nx.is_minimal_d_separator(g, {0}, {1}, {2})
    with pytest.raises(nx.NetworkXNotImplemented):
        nx.find_minimal_d_separator(g, {0}, {1})


def test_cyclic_graphs_raise_error():
    """
    Test that cyclic graphs raise an error.

    This is because PGMs assume a directed acyclic graph.
    """
    g = nx.cycle_graph(3, nx.DiGraph)
    with pytest.raises(nx.NetworkXError):
        nx.is_d_separator(g, {0}, {1}, {2})
    with pytest.raises(nx.NetworkXError):
        nx.find_minimal_d_separator(g, {0}, {1})
    with pytest.raises(nx.NetworkXError):
        nx.is_minimal_d_separator(g, {0}, {1}, {2})


def test_invalid_nodes_raise_error(asia_graph):
    """
    Test that passing nodes that are not in the graph raises an error.
    """
    # Check both set and node arguments
    with pytest.raises(nx.NodeNotFound):
        nx.is_d_separator(asia_graph, {0}, {1}, {2})
    with pytest.raises(nx.NodeNotFound):
        nx.is_d_separator(asia_graph, 0, 1, 2)
    with pytest.raises(nx.NodeNotFound):
        nx.is_minimal_d_separator(asia_graph, {0}, {1}, {2})
    with pytest.raises(nx.NodeNotFound):
        nx.is_minimal_d_separator(asia_graph, 0, 1, 2)
    with pytest.raises(nx.NodeNotFound):
        nx.find_minimal_d_separator(asia_graph, {0}, {1})
    with pytest.raises(nx.NodeNotFound):
        nx.find_minimal_d_separator(asia_graph, 0, 1)


def test_nondisjoint_node_sets_raise_error(collider_graph):
    """
    Test that error is raised when node sets aren't disjoint.
    """
    with pytest.raises(nx.NetworkXError):
        nx.is_d_separator(collider_graph, 0, 1, 0)
    with pytest.raises(nx.NetworkXError):
        nx.is_d_separator(collider_graph, 0, 2, 0)
    with pytest.raises(nx.NetworkXError):
        nx.is_d_separator(collider_graph, 0, 0, 1)
    with pytest.raises(nx.NetworkXError):
        nx.is_d_separator(collider_graph, 1, 0, 0)
    with pytest.raises(nx.NetworkXError):
        nx.find_minimal_d_separator(collider_graph, 0, 0)
    with pytest.raises(nx.NetworkXError):
        nx.find_minimal_d_separator(collider_graph, 0, 1, included=0)
    with pytest.raises(nx.NetworkXError):
        nx.find_minimal_d_separator(collider_graph, 1, 0, included=0)
    with pytest.raises(nx.NetworkXError):
        nx.is_minimal_d_separator(collider_graph, 0, 0, set())
    with pytest.raises(nx.NetworkXError):
        nx.is_minimal_d_separator(collider_graph, 0, 1, set(), included=0)
    with pytest.raises(nx.NetworkXError):
        nx.is_minimal_d_separator(collider_graph, 1, 0, set(), included=0)


def test_is_minimal_d_separator(
    large_collider_graph,
    chain_and_fork_graph,
    no_separating_set_graph,
    large_no_separating_set_graph,
    collider_trek_graph,
):
    """Example-based tests of finding and checking minimal d-separators."""
    # Case 1:
    # create a graph A -> B <- C
    # B -> D -> E;
    # B -> F;
    # G -> E;
    assert not nx.is_d_separator(large_collider_graph, {"B"}, {"E"}, set())

    # minimal set of the corresponding graph
    # for B and E should be (D,)
    Zmin = nx.find_minimal_d_separator(large_collider_graph, "B", "E")
    # check that the minimal d-separator is a d-separating set
    assert nx.is_d_separator(large_collider_graph, "B", "E", Zmin)
    # the minimal separating set should also pass the test for minimality
    assert nx.is_minimal_d_separator(large_collider_graph, "B", "E", Zmin)
    # function should also work with set arguments
    assert nx.is_minimal_d_separator(large_collider_graph, {"A", "B"}, {"G", "E"}, Zmin)
    assert Zmin == {"D"}

    # Case 2:
    # create a graph A -> B -> C
    # B -> D -> C;
    assert not nx.is_d_separator(chain_and_fork_graph, {"A"}, {"C"}, set())
    Zmin = nx.find_minimal_d_separator(chain_and_fork_graph, "A", "C")

    # the minimal separating set should pass the test for minimality
    assert nx.is_minimal_d_separator(chain_and_fork_graph, "A", "C", Zmin)
    assert Zmin == {"B"}
    # adding a superfluous node keeps the set separating but not minimal
    Znotmin = Zmin.union({"D"})
    assert not nx.is_minimal_d_separator(chain_and_fork_graph, "A", "C", Znotmin)

    # Case 3:
    # create a graph A -> B
    # there is no d-separating set between A and B at all, so
    # no minimal d-separating set can exist
    assert not nx.is_d_separator(no_separating_set_graph, {"A"}, {"B"}, set())
    assert nx.find_minimal_d_separator(no_separating_set_graph, "A", "B") is None

    # Case 4:
    # create a graph A -> B with A <- C -> B
    # there is no d-separating set between A and B at all, so
    # no minimal d-separating set can exist
    # however, the algorithm will initially propose C as a
    # minimal (but invalid) separating set
    assert not nx.is_d_separator(large_no_separating_set_graph, {"A"}, {"B"}, {"C"})
    assert nx.find_minimal_d_separator(large_no_separating_set_graph, "A", "B") is None

    # Test `included` and `restricted` args
    # create graph A -> B <- C -> D
    assert nx.find_minimal_d_separator(collider_trek_graph, "A", "D", included="B") == {
        "B",
        "C",
    }
    # with the separator additionally restricted to B alone, none is found
    assert (
        nx.find_minimal_d_separator(
            collider_trek_graph, "A", "D", included="B", restricted="B"
        )
        is None
    )


def test_is_minimal_d_separator_checks_dsep():
    """Test that is_minimal_d_separator checks for d-separation as well."""
    g = nx.DiGraph()
    g.add_edges_from(
        [
            ("A", "B"),
            ("A", "E"),
            ("B", "C"),
            ("B", "D"),
            ("D", "C"),
            ("D", "F"),
            ("E", "D"),
            ("E", "F"),
        ]
    )

    assert not nx.is_d_separator(g, {"C"}, {"F"}, {"D"})

    # since {'D'} and {} are not d-separators, we return false
    assert not nx.is_minimal_d_separator(g, "C", "F", {"D"})
    assert not nx.is_minimal_d_separator(g, "C", "F", set())


def test__reachable(large_collider_graph):
    """Example-based test of the private ``_reachable`` helper."""
    reachable = nx.algorithms.d_separation._reachable
    g = large_collider_graph
    x = {"F", "D"}
    ancestors = {"A", "B", "C", "D", "F"}
    assert reachable(g, x, ancestors, {"B"}) == {"B", "F", "D"}
    assert reachable(g, x, ancestors, set()) == ancestors


def test_deprecations():
    """Test that the old ``d_separated``/``minimal_d_separator`` names warn."""
    G = nx.DiGraph([(0, 1), (1, 2)])
    with pytest.deprecated_call():
        nx.d_separated(G, 0, 2, {1})
    with pytest.deprecated_call():
        # Only the warning matters here; the return value is not inspected.
        nx.minimal_d_separator(G, 0, 2)