convert more Table operations to IR (hail-is#5015)
* convert more Table operations to IR

* addressed comments
cseed authored and danking committed Dec 20, 2018
1 parent ca29754 commit 5ef7b5c
Showing 10 changed files with 66 additions and 57 deletions.
6 changes: 3 additions & 3 deletions hail/python/hail/backend/backend.py
@@ -9,7 +9,7 @@

class Backend(object):
@abc.abstractmethod
- def interpret(self, ir):
+ def execute(self, ir):
return

@abc.abstractmethod
@@ -30,7 +30,7 @@ def _to_java_ir(self, ir):
ir._jir = ir.parse(code, ir_map=r.jirs)
return ir._jir

- def interpret(self, ir):
+ def execute(self, ir):
return ir.typ._from_json(
Env.hail().expr.ir.Interpret.interpretJSON(
self._to_java_ir(ir)))
@@ -50,7 +50,7 @@ def __init__(self, host, port=80, scheme='http'):
self.host = host
self.port = port

- def interpret(self, ir):
+ def execute(self, ir):
r = Renderer(stop_at_jir=True)
code = r(ir)
assert len(r.jirs) == 0
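The rename above touches the abstract Backend interface as well as both concrete backends, so every Backend implementation now overrides execute rather than interpret. A minimal, hypothetical sketch of what an implementation looks like after the rename (LoggingBackend is not part of this commit; the Backend stub mirrors the abstract class shown above):

    import abc

    class Backend(object):
        @abc.abstractmethod
        def execute(self, ir):
            return

    class LoggingBackend(Backend):
        """Hypothetical backend that logs each IR node before delegating."""
        def __init__(self, wrapped):
            self.wrapped = wrapped

        def execute(self, ir):
            print('executing IR node:', type(ir).__name__)
            return self.wrapped.execute(ir)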
2 changes: 1 addition & 1 deletion hail/python/hail/expr/expressions/expression_utils.py
@@ -190,7 +190,7 @@ def eval_typed(expression):
analyze('eval_typed', expression, Indices(expression._indices.source))

if expression._indices.source is None:
- return (Env.hc()._backend.interpret(expression._ir), expression.dtype)
+ return (Env.backend().execute(expression._ir), expression.dtype)
else:
return expression.collect()[0], expression.dtype

14 changes: 7 additions & 7 deletions hail/python/hail/ir/ir.py
@@ -1291,9 +1291,9 @@ def __eq__(self, other):
class TableExport(IR):
@typecheck_method(child=TableIR,
path=str,
- types_file=str,
+ types_file=nullable(str),
header=bool,
- export_type=hail_type)
+ export_type=int)
def __init__(self, child, path, types_file, header, export_type):
super().__init__(child)
self.child = child
@@ -1308,12 +1308,12 @@ def copy(self, child):
return new_instance(child, self.path, self.types_file, self.header, self.export_type)

def render(self, r):
- return '(TableExport "{}" "{}" "{}" {} {})'.format(
+ return '(TableExport {} "{}" "{}" {} {})'.format(
+ r(self.child),
escape_str(self.path),
- escape_str(self.types_file),
- escape_str(self.header),
- self.export_type,
- r(self.child))
+ escape_str(self.types_file) if self.types_file else 'None',
+ self.header,
+ self.export_type)

def __eq__(self, other):
return isinstance(other, TableExport) and \
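The render change above moves the child table IR to the front and prints the now-optional types file as None when it is absent. A standalone sketch of the new output format — escape_str is simplified and the argument values are made up for illustration:

    def escape_str(s):
        return s  # the real helper escapes quotes, backslashes, etc.

    def render_table_export(child, path, types_file, header, export_type):
        # mirrors the new TableExport.render argument order: child IR first
        return '(TableExport {} "{}" "{}" {} {})'.format(
            child,
            escape_str(path),
            escape_str(types_file) if types_file else 'None',
            header,
            export_type)

    print(render_table_export('(TableRange 10 4)', '/tmp/out.tsv', None, True, 0))
    # (TableExport (TableRange 10 4) "/tmp/out.tsv" "None" True 0)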
9 changes: 9 additions & 0 deletions hail/python/hail/ir/table_ir.py
@@ -25,6 +25,15 @@ def render(self, r):
return '(TableJoin {} {} {} {})'.format(
escape_id(self.join_type), self.join_key, r(self.left), r(self.right))

+ class TableLeftJoinRightDistinct(TableIR):
+ def __init__(self, left, right, root):
+ self.left = left
+ self.right = right
+ self.root = root
+
+ def render(self, r):
+ return '(TableLeftJoinRightDistinct {} {} {})'.format(
+ escape_id(self.root), r(self.left), r(self.right))

class TableUnion(TableIR):
def __init__(self, children):
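TableLeftJoinRightDistinct backs the keyed-table lookup path in table.py later in this diff. A usage sketch, assuming a Hail build that contains this commit — indexing one table by another's key now builds this IR node instead of calling leftJoinRightDistinct on the Java table:

    import hail as hl

    t1 = hl.utils.range_table(10)
    t2 = t1.annotate(square=t1.idx * t1.idx)

    # t2[t1.idx] is a keyed lookup; the joiner wraps both table IRs in
    # TableLeftJoinRightDistinct (interval joins still use the Java path).
    annotated = t1.annotate(sq=t2[t1.idx].square)
    annotated.show()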
12 changes: 6 additions & 6 deletions hail/python/hail/matrixtable.py
@@ -1688,7 +1688,7 @@ def aggregate_rows(self, expr) -> Any:
base, _ = self._process_joins(expr)
analyze('MatrixTable.aggregate_rows', expr, self._global_indices, {self._row_axis})
subst_query = subst(expr._ir, {}, {'va': Ref('row')})
- return Env.hc()._backend.interpret(TableAggregate(MatrixRowsTable(base._mir), subst_query))
+ return Env.backend().execute(TableAggregate(MatrixRowsTable(base._mir), subst_query))

@typecheck_method(expr=expr_any)
def aggregate_cols(self, expr) -> Any:
@@ -1733,7 +1733,7 @@ def aggregate_cols(self, expr) -> Any:
base, _ = self._process_joins(expr)
analyze('MatrixTable.aggregate_cols', expr, self._global_indices, {self._col_axis})
subst_query = subst(expr._ir, {}, {'sa': Ref('row')})
- return Env.hc()._backend.interpret(TableAggregate(MatrixColsTable(base._mir), subst_query))
+ return Env.backend().execute(TableAggregate(MatrixColsTable(base._mir), subst_query))

@typecheck_method(expr=expr_any)
def aggregate_entries(self, expr) -> Any:
@@ -1773,7 +1773,7 @@ def aggregate_entries(self, expr) -> Any:

base, _ = self._process_joins(expr)
analyze('MatrixTable.aggregate_entries', expr, self._global_indices, {self._row_axis, self._col_axis})
- return Env.hc()._backend.interpret(MatrixAggregate(base._mir, expr._ir))
+ return Env.backend().execute(MatrixAggregate(base._mir, expr._ir))

@typecheck_method(field_expr=oneof(str, Expression))
def explode_rows(self, field_expr) -> 'MatrixTable':
@@ -2088,7 +2088,7 @@ def count_rows(self) -> int:
Number of rows in the matrix.
"""

- return Env.hc()._backend.interpret(
+ return Env.backend().execute(
TableCount(MatrixRowsTable(self._mir)))

def _force_count_rows(self):
@@ -2113,7 +2113,7 @@ def count_cols(self) -> int:
Number of columns in the matrix.
"""

- return Env.hc()._backend.interpret(
+ return Env.backend().execute(
TableCount(MatrixColsTable(self._mir)))

def count(self) -> Tuple[int, int]:
@@ -2160,7 +2160,7 @@ def write(self, output: str, overwrite: bool = False, stage_locally: bool = Fals
"""

writer = MatrixNativeWriter(output, overwrite, stage_locally, _codec_spec)
- Env.hc()._backend.interpret(MatrixWrite(self._mir, writer))
+ Env.backend().execute(MatrixWrite(self._mir, writer))

def globals_table(self) -> Table:
"""Returns a table with a single row with the globals of the matrix table.
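With these changes, the MatrixTable counting, aggregation, and write entry points are thin wrappers that hand an IR to the backend. A sketch of the equivalence, assuming a build containing this commit and that TableCount and MatrixRowsTable are importable from hail.ir:

    import hail as hl
    from hail.ir import MatrixRowsTable, TableCount
    from hail.utils.java import Env

    mt = hl.balding_nichols_model(n_populations=1, n_samples=3, n_variants=5)

    # count_rows() now just executes TableCount over the rows table of the matrix
    assert mt.count_rows() == Env.backend().execute(TableCount(MatrixRowsTable(mt._mir)))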
6 changes: 3 additions & 3 deletions hail/python/hail/methods/impex.py
@@ -160,7 +160,7 @@ def export_gen(dataset, output, precision=4, gp=None, id1=None, id2=None,
entry_exprs=entry_exprs)

writer = MatrixGENWriter(output, precision)
- Env.hc()._backend.interpret(MatrixWrite(dataset._mir, writer))
+ Env.backend().execute(MatrixWrite(dataset._mir, writer))


@typecheck(dataset=MatrixTable,
@@ -301,7 +301,7 @@ def export_plink(dataset, output, call=None, fam_id=None, ind_id=None, pat_id=No
raise TypeError("\n".join(errors))

writer = MatrixPLINKWriter(output)
- Env.hc()._backend.interpret(MatrixWrite(dataset._mir, writer))
+ Env.backend().execute(MatrixWrite(dataset._mir, writer))


@typecheck(dataset=MatrixTable,
@@ -418,7 +418,7 @@ def export_vcf(dataset, output, append_to_header=None, parallel=None, metadata=N
append_to_header,
Env.hail().utils.ExportType.getExportType(parallel),
metadata)
- Env.hc()._backend.interpret(MatrixWrite(dataset._mir, writer))
+ Env.backend().execute(MatrixWrite(dataset._mir, writer))


@typecheck(path=str,
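The user-facing exporters are unchanged; they now submit a MatrixWrite IR through the backend instead of calling interpret. A small usage sketch (the toy dataset and output path are illustrative):

    import hail as hl

    mt = hl.balding_nichols_model(n_populations=1, n_samples=3, n_variants=5)
    hl.export_vcf(mt, '/tmp/example.vcf.bgz')  # builds MatrixWrite(mt._mir, writer) under the hood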
52 changes: 26 additions & 26 deletions hail/python/hail/table.py
@@ -411,7 +411,7 @@ def count(self):
-------
:obj:`int`
"""
- return Env.hc()._backend.interpret(TableCount(self._tir))
+ return Env.backend().execute(TableCount(self._tir))

def _force_count(self):
return self._jt.forceCount()
@@ -996,7 +996,8 @@ def export(self, output, types_file=None, header=True, parallel=None):
the export will be slower.
"""

- self._jt.export(output, types_file, header, Env.hail().utils.ExportType.getExportType(parallel))
+ Env.backend().execute(
+ TableExport(self._tir, output, types_file, header, Env.hail().utils.ExportType.getExportType(parallel)))

def group_by(self, *exprs, **named_exprs) -> 'GroupedTable':
"""Group by a new key for use with :meth:`.GroupedTable.aggregate`.
@@ -1135,7 +1136,7 @@ def aggregate(self, expr):
base, _ = self._process_joins(expr)
analyze('Table.aggregate', expr, self._global_indices, {self._row_axis})

- return Env.hc()._backend.interpret(TableAggregate(base._tir, expr._ir))
+ return Env.backend().execute(TableAggregate(base._tir, expr._ir))

@typecheck_method(output=str,
overwrite=bool,
@@ -1165,7 +1166,7 @@ def write(self, output: str, overwrite = False, stage_locally: bool = False,
If ``True``, overwrite an existing file at the destination.
"""

- Env.hc()._backend.interpret(TableWrite(self._tir, output, overwrite, stage_locally, _codec_spec))
+ Env.backend().execute(TableWrite(self._tir, output, overwrite, stage_locally, _codec_spec))

@typecheck_method(n=int, width=int, truncate=nullable(int), types=bool, handler=anyfunc)
def show(self, n=10, width=90, truncate=None, types=True, handler=print):
@@ -1341,18 +1342,19 @@ def types_compatible(left, right):
def joiner(left):
if not is_key:
original_key = list(left.key)
- left = Table._from_java(left.key_by()._jt.mapRows(str(Apply('annotate',
- left._row._ir,
- hl.struct(**dict(zip(uids, exprs)))._ir)))
- ).key_by(*uids)
+ left = Table(TableMapRows(left.key_by()._tir,
+ Apply('annotate',
+ left._row._ir,
+ hl.struct(**dict(zip(uids, exprs)))._ir))
+ ).key_by(*uids)
rekey_f = lambda t: t.key_by(*original_key)
else:
rekey_f = identity

if is_interval:
left = Table._from_java(left._jt.intervalJoin(self._jt, uid))
else:
- left = Table._from_java(left._jt.leftJoinRightDistinct(self._jt, uid))
+ left = Table(TableLeftJoinRightDistinct(left._tir, self._tir, uid))
return rekey_f(left)

all_uids.append(uid)
@@ -1387,7 +1389,8 @@ def joiner(left):
key = None
else:
key = [str(k._ir) for k in exprs]
- joiner = lambda left: MatrixTable._from_java(left._jmt.annotateRowsTableIR(right._jt, uid, key))
+ joiner = lambda left: MatrixTable(MatrixAnnotateRowsTable(
+ left._mir, right._tir, uid, key))
ast = Join(GetField(TopLevelReference('va'), uid),
[uid],
exprs,
@@ -1399,7 +1402,7 @@ def joiner(left):
exprs[i] is src.col_key[i] for i in range(len(exprs))]):
# key is already correct
def joiner(left):
- return MatrixTable._from_java(left._jmt.annotateColsTable(right._jt, uid))
+ return MatrixTable(MatrixAnnotateColsTable(left._mir, right._tir, uid))
else:
index_uid = Env.get_uid()
uids = [Env.get_uid() for _ in exprs]
@@ -1417,10 +1420,12 @@ def joiner(left: MatrixTable):
.join(self, 'inner')
.key_by(index_uid)
.drop(*uids))
- result = MatrixTable._from_java(left.add_col_index(index_uid)
- .key_cols_by(index_uid)
- ._jmt
- .annotateColsTable(joined._jt, uid)).key_cols_by(*prev_key)
+ result = MatrixTable(MatrixAnnotateColsTable(
+ (left.add_col_index(index_uid)
+ .key_cols_by(index_uid)
+ ._mir),
+ joined._tir,
+ uid)).key_cols_by(*prev_key)
return result
ir = Join(GetField(TopLevelReference('sa'), uid),
all_uids,
@@ -2318,11 +2323,8 @@ def to_matrix_table(self, row_key, col_key, row_fields=[], col_fields=[], n_part
if len(col_key) == 0:
raise ValueError(f"'to_matrix_table': require at least one col key field")

- return hl.MatrixTable._from_java(self._jt.jToMatrixTable(row_key,
- col_key,
- row_fields,
- col_fields,
- n_partitions))
+ return hl.MatrixTable(TableToMatrixTable(
+ self._tir, row_key, col_key, row_fields, col_fields, n_partitions))

@property
def globals(self) -> 'StructExpression':
@@ -2617,8 +2619,8 @@ def _filter_partitions(self, parts, keep=True):
entries_field_name=str,
col_key=sequenceof(str))
def _unlocalize_entries(self, entries_field_name, cols_field_name, col_key):
- return hl.MatrixTable._from_java(
- self._jt.unlocalizeEntries(entries_field_name, cols_field_name, col_key))
+ return hl.MatrixTable(CastTableToMatrix(
+ self._tir, entries_field_name, cols_field_name, col_key))

@staticmethod
@typecheck(tables=sequenceof(table_type), data_field_name=str, global_field_name=str)
@@ -2632,9 +2634,7 @@ def _multi_way_zip_join(tables, data_field_name, global_field_name):
raise TypeError('All input tables to multi_way_zip_join must have the same row type')
if any(head.globals.dtype != t.globals.dtype for t in tables):
raise TypeError('All input tables to multi_way_zip_join must have the same global type')
- jt = Env.hail().table.Table.multiWayZipJoin([t._jt for t in tables],
- data_field_name,
- global_field_name)
- return Table._from_java(jt)
+ return Table(TableMultiWayZipJoin(
+ [t._tir for t in tables], data_field_name, global_field_name))

table_type.set(Table)
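Table.count, aggregate, write, export, the keyed join path, to_matrix_table, _unlocalize_entries, and _multi_way_zip_join all construct IR nodes now. A usage sketch for the export path, assuming a build with this commit — the public signature is unchanged, only the execution route differs:

    import hail as hl

    t = hl.utils.range_table(100).annotate(x=hl.rand_norm())
    # renders a TableExport node and passes it to Env.backend().execute(...)
    t.export('/tmp/example.tsv', types_file=None, header=True)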
4 changes: 4 additions & 0 deletions hail/python/hail/utils/java.py
@@ -61,6 +61,10 @@ def hc():
assert Env._hc is not None
return Env._hc

+ @staticmethod
+ def backend():
+ return Env.hc()._backend

@staticmethod
def sql_context():
return Env.hc()._sql_context
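Env.backend() is a small convenience for Env.hc()._backend, so IR-driven code paths can uniformly write Env.backend().execute(ir). A sketch, assuming a build with this commit and that TableCount is importable from hail.ir:

    import hail as hl
    from hail.ir import TableCount
    from hail.utils.java import Env

    hl.init()  # sets up the HailContext that Env.hc() returns
    t = hl.utils.range_table(30)
    assert Env.backend().execute(TableCount(t._tir)) == t.count() == 30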
7 changes: 7 additions & 0 deletions hail/src/main/scala/is/hail/expr/ir/Parser.scala
@@ -730,6 +730,13 @@ object IRParser {
val child = table_ir(env)(it)
val query = ir_value_expr(env.update(child.typ.refMap))(it)
TableAggregate(child, query)
+ case "TableExport" =>
+ val child = table_ir(env)(it)
+ val path = string_literal(it)
+ val typesFile = opt(it, string_literal).orNull
+ val header = boolean_literal(it)
+ val exportType = int32_literal(it)
+ TableExport(child, path, typesFile, header, exportType)
case "TableWrite" =>
val path = string_literal(it)
val overwrite = boolean_literal(it)
11 changes: 0 additions & 11 deletions hail/src/main/scala/is/hail/table/Table.scala
@@ -439,17 +439,6 @@ class Table(val hc: HailContext, val tir: TableIR) {
signature = keySignature ++ TStruct(name -> TArray(valueSignature)))
}

- def jToMatrixTable(rowKeys: java.util.ArrayList[String],
- colKeys: java.util.ArrayList[String],
- rowFields: java.util.ArrayList[String],
- colFields: java.util.ArrayList[String],
- nPartitions: java.lang.Integer): MatrixTable = {
-
- toMatrixTable(rowKeys.asScala.toArray, colKeys.asScala.toArray,
- rowFields.asScala.toArray, colFields.asScala.toArray,
- Option(nPartitions).map(_.asInstanceOf[Int])
- )
- }

def toMatrixTable(
rowKeys: Array[String],
