Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 62 additions & 0 deletions qiita_db/artifact.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,11 @@

from __future__ import division
from future.utils import viewitems
from itertools import chain
from datetime import datetime

import networkx as nx

import qiita_db as qdb


Expand Down Expand Up @@ -646,6 +649,65 @@ def parents(self):
return [Artifact(p_id)
for p_id in qdb.sql_connection.TRN.execute_fetchflatten()]

def _create_lineage_graph_from_edge_list(self, edge_list):
"""Generates an artifact graph from the given `edge_list`

Parameters
----------
edge_list : list of (int, int)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

From the if statement below, edge_list here should be optional and set to None by default in the function definition.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nope, the if statement is specifically for the case in which the list is empty, but it is not an optional parameter.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, can you add a note to that effect? That was not clear from the code.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

List of (parent_artifact_id, artifact_id)

Returns
-------
networkx.DiGraph
The graph representing the artifact lineage stored in `edge_list`
"""
lineage = nx.DiGraph()
# In case the edge list is empty, only 'self' is present in the graph
if edge_list:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If this is for the case when the list is empty, this should probably be set to check that. Right now anything can be passed as the edge_list and it will just return the single node graph.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you elaborate? this is the pythonic way of testing if a list is empty or not.

# By creating all the artifacts here we are saving DB calls
nodes = {a_id: Artifact(a_id)
for a_id in set(chain.from_iterable(edge_list))}

for parent, child in edge_list:
lineage.add_edge(nodes[parent], nodes[child])
else:
lineage.add_node(self)

return lineage

@property
def ancestors(self):
"""Returns the ancestors of the artifact

Returns
-------
networkx.DiGraph
The ancestors of the artifact
"""
with qdb.sql_connection.TRN:
sql = """SELECT parent_id, artifact_id
FROM qiita.artifact_ancestry(%s)"""
qdb.sql_connection.TRN.add(sql, [self.id])
edges = qdb.sql_connection.TRN.execute_fetchindex()
return self._create_lineage_graph_from_edge_list(edges)

@property
def descendants(self):
"""Returns the descendants of the artifact

Returns
-------
networkx.DiGraph
The descendants of the artifact
"""
with qdb.sql_connection.TRN:
sql = """SELECT parent_id, artifact_id
FROM qiita.artifact_descendants(%s)"""
qdb.sql_connection.TRN.add(sql, [self.id])
edges = qdb.sql_connection.TRN.execute_fetchindex()
return self._create_lineage_graph_from_edge_list(edges)

@property
def children(self):
"""Returns the list of children of the artifact
Expand Down
39 changes: 39 additions & 0 deletions qiita_db/support_files/patches/33.sql
Original file line number Diff line number Diff line change
Expand Up @@ -811,3 +811,42 @@ BEGIN
END IF;
END
$$ LANGUAGE plpgsql;


-- Create a function to return the ancestors of an Artifact
CREATE FUNCTION qiita.artifact_ancestry(a_id bigint) RETURNS SETOF qiita.parent_artifact AS $$
BEGIN
IF EXISTS(SELECT * FROM qiita.parent_artifact WHERE artifact_id = a_id) THEN
RETURN QUERY WITH RECURSIVE root AS (
SELECT artifact_id, parent_id
FROM qiita.parent_artifact
WHERE artifact_id = a_id
UNION
SELECT p.artifact_id, p.parent_id
FROM qiita.parent_artifact p
JOIN root r ON (r.parent_id = p.artifact_id)
)
SELECT DISTINCT artifact_id, parent_id
FROM root;
END IF;
END
$$ LANGUAGE plpgsql;

-- Create a function to return the descendants of an artifact
CREATE FUNCTION qiita.artifact_descendants(a_id bigint) RETURNS SETOF qiita.parent_artifact AS $$
BEGIN
IF EXISTS(SELECT * FROM qiita.parent_artifact WHERE parent_id = a_id) THEN
RETURN QUERY WITH RECURSIVE root AS (
SELECT artifact_id, parent_id
FROM qiita.parent_artifact
WHERE parent_id = a_id
UNION
SELECT p.artifact_id, p.parent_id
FROM qiita.parent_artifact p
JOIN root r ON (r.artifact_id = p.parent_id)
)
SELECT DISTINCT artifact_id, parent_id
FROM root;
END IF;
END
$$ LANGUAGE plpgsql;
103 changes: 103 additions & 0 deletions qiita_db/test/test_artifact.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from functools import partial

import pandas as pd
import networkx as nx
from biom import example_table as et
from biom.util import biom_open

Expand Down Expand Up @@ -496,6 +497,108 @@ def test_parents(self):
exp_parents = [qdb.artifact.Artifact(2)]
self.assertEqual(qdb.artifact.Artifact(4).parents, exp_parents)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would be good to add a test in for _create_lineage_graph_from_edge_list explicitly.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Adding

def test_create_lineage_graph_from_edge_list_empty(self):
tester = qdb.artifact.Artifact(1)
obs = tester._create_lineage_graph_from_edge_list([])
self.assertTrue(isinstance(obs, nx.DiGraph))
self.assertEqual(obs.nodes(), [tester])
self.assertEqual(obs.edges(), [])

def test_create_lineage_graph_from_edge_list(self):
tester = qdb.artifact.Artifact(1)
obs = tester._create_lineage_graph_from_edge_list(
[(1, 2), (2, 4), (1, 3), (3, 4)])
self.assertTrue(isinstance(obs, nx.DiGraph))
exp = [qdb.artifact.Artifact(1), qdb.artifact.Artifact(2),
qdb.artifact.Artifact(3), qdb.artifact.Artifact(4)]
self.assertItemsEqual(obs.nodes(), exp)
exp = [(qdb.artifact.Artifact(1), qdb.artifact.Artifact(2)),
(qdb.artifact.Artifact(2), qdb.artifact.Artifact(4)),
(qdb.artifact.Artifact(1), qdb.artifact.Artifact(3)),
(qdb.artifact.Artifact(3), qdb.artifact.Artifact(4))]
self.assertItemsEqual(obs.edges(), exp)

def test_ancestors(self):
obs = qdb.artifact.Artifact(1).ancestors
self.assertTrue(isinstance(obs, nx.DiGraph))
obs_nodes = obs.nodes()
self.assertEqual(obs_nodes, [qdb.artifact.Artifact(1)])
obs_edges = obs.edges()
self.assertEqual(obs_edges, [])

obs = qdb.artifact.Artifact(2).ancestors
self.assertTrue(isinstance(obs, nx.DiGraph))
obs_nodes = obs.nodes()
exp_nodes = [qdb.artifact.Artifact(1), qdb.artifact.Artifact(2)]
self.assertItemsEqual(obs_nodes, exp_nodes)
obs_edges = obs.edges()
exp_edges = [(qdb.artifact.Artifact(1), qdb.artifact.Artifact(2))]
self.assertItemsEqual(obs_edges, exp_edges)

obs = qdb.artifact.Artifact(3).ancestors
self.assertTrue(isinstance(obs, nx.DiGraph))
obs_nodes = obs.nodes()
exp_nodes = [qdb.artifact.Artifact(1), qdb.artifact.Artifact(3)]
self.assertItemsEqual(obs_nodes, exp_nodes)
obs_edges = obs.edges()
exp_edges = [(qdb.artifact.Artifact(1), qdb.artifact.Artifact(3))]
self.assertItemsEqual(obs_edges, exp_edges)

obs = qdb.artifact.Artifact(4).ancestors
self.assertTrue(isinstance(obs, nx.DiGraph))
obs_nodes = obs.nodes()
exp_nodes = [qdb.artifact.Artifact(1), qdb.artifact.Artifact(2),
qdb.artifact.Artifact(4)]
self.assertItemsEqual(obs_nodes, exp_nodes)
obs_edges = obs.edges()
exp_edges = [(qdb.artifact.Artifact(1), qdb.artifact.Artifact(2)),
(qdb.artifact.Artifact(2), qdb.artifact.Artifact(4))]
self.assertItemsEqual(obs_edges, exp_edges)

def test_descendants(self):
obs = qdb.artifact.Artifact(1).descendants
self.assertTrue(isinstance(obs, nx.DiGraph))
obs_nodes = obs.nodes()
exp_nodes = [qdb.artifact.Artifact(1), qdb.artifact.Artifact(2),
qdb.artifact.Artifact(3), qdb.artifact.Artifact(4)]
self.assertItemsEqual(obs_nodes, exp_nodes)
obs_edges = obs.edges()
exp_edges = [(qdb.artifact.Artifact(1), qdb.artifact.Artifact(2)),
(qdb.artifact.Artifact(1), qdb.artifact.Artifact(3)),
(qdb.artifact.Artifact(2), qdb.artifact.Artifact(4))]
self.assertItemsEqual(obs_edges, exp_edges)

obs = qdb.artifact.Artifact(2).descendants
self.assertTrue(isinstance(obs, nx.DiGraph))
obs_nodes = obs.nodes()
exp_nodes = [qdb.artifact.Artifact(2), qdb.artifact.Artifact(4)]
self.assertItemsEqual(obs_nodes, exp_nodes)
obs_edges = obs.edges()
exp_edges = [(qdb.artifact.Artifact(2), qdb.artifact.Artifact(4))]
self.assertItemsEqual(obs_edges, exp_edges)

obs = qdb.artifact.Artifact(3).descendants
self.assertTrue(isinstance(obs, nx.DiGraph))
obs_nodes = obs.nodes()
self.assertItemsEqual(obs_nodes, [qdb.artifact.Artifact(3)])
obs_edges = obs.edges()
self.assertItemsEqual(obs_edges, [])

obs = qdb.artifact.Artifact(4).descendants
self.assertTrue(isinstance(obs, nx.DiGraph))
obs_nodes = obs.nodes()
self.assertItemsEqual(obs_nodes, [qdb.artifact.Artifact(4)])
obs_edges = obs.edges()
self.assertItemsEqual(obs_edges, [])

def test_children(self):
exp = [qdb.artifact.Artifact(2), qdb.artifact.Artifact(3)]
self.assertEqual(qdb.artifact.Artifact(1).children, exp)
exp = [qdb.artifact.Artifact(4)]
self.assertEqual(qdb.artifact.Artifact(2).children, exp)
self.assertEqual(qdb.artifact.Artifact(3).children, [])
self.assertEqual(qdb.artifact.Artifact(4).children, [])

def test_prep_templates(self):
self.assertEqual(
qdb.artifact.Artifact(1).prep_templates,
Expand Down
55 changes: 55 additions & 0 deletions qiita_db/test/test_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,61 @@ def test_find_artifact_roots_is_child_multiple_parents_multiple_root(self):
exp = [[1], [new_root.id]]
self.assertEqual(obs, exp)

def test_artifact_ancestry_root(self):
"""Correctly returns the ancestry of a root artifact"""
sql = "SELECT * FROM qiita.artifact_ancestry(%s)"
obs = self.conn_handler.execute_fetchall(sql, [1])
exp = []
self.assertEqual(obs, exp)

def test_artifact_ancestry_leaf(self):
"""Correctly returns the ancestry of a leaf artifact"""
sql = "SELECT * FROM qiita.artifact_ancestry(%s)"
obs = self.conn_handler.execute_fetchall(sql, [4])
exp = [[4, 2], [2, 1]]
self.assertItemsEqual(obs, exp)

def test_artifact_ancestry_leaf_multiple_parents(self):
"""Correctly returns the ancestry of a leaf artifact w multiple parents
"""
sql = """INSERT INTO qiita.parent_artifact (artifact_id, parent_id)
VALUES (%s, %s)"""
self.conn_handler.execute(sql, [4, 3])
sql = "SELECT * FROM qiita.artifact_ancestry(%s)"
obs = self.conn_handler.execute_fetchall(sql, [4])
exp = [[4, 3], [3, 1], [4, 2], [2, 1]]
self.assertItemsEqual(obs, exp)

def test_artifact_ancestry_middle(self):
"""Correctly returns the ancestry of an artifact in the middle of the
DAG"""
sql = "SELECT * FROM qiita.artifact_ancestry(%s)"
obs = self.conn_handler.execute_fetchall(sql, [2])
exp = [[2, 1]]
self.assertEqual(obs, exp)

def test_artifact_descendants_leaf(self):
"""Correctly returns the descendants of a leaf artifact"""
sql = "SELECT * FROM qiita.artifact_descendants(%s)"
obs = self.conn_handler.execute_fetchall(sql, [4])
exp = []
self.assertEqual(obs, exp)

def test_artifact_descendants_root(self):
"""Correctly returns the descendants of a root artifact"""
sql = "SELECT * FROM qiita.artifact_descendants(%s)"
obs = self.conn_handler.execute_fetchall(sql, [1])
exp = [[2, 1], [3, 1], [4, 2]]
self.assertItemsEqual(obs, exp)

def test_artifact_descendants_middle(self):
"""Correctly returns the descendants of an artifact in the middle of
the DAG"""
sql = "SELECT * FROM qiita.artifact_descendants(%s)"
obs = self.conn_handler.execute_fetchall(sql, [2])
exp = [[4, 2]]
self.assertEqual(obs, exp)


if __name__ == '__main__':
main()