diff --git a/hail/python/hail/backend/backend.py b/hail/python/hail/backend/backend.py index 7240a7283be..d9873a69b04 100644 --- a/hail/python/hail/backend/backend.py +++ b/hail/python/hail/backend/backend.py @@ -9,7 +9,7 @@ class Backend(object): @abc.abstractmethod - def interpret(self, ir): + def execute(self, ir): return @abc.abstractmethod @@ -30,7 +30,7 @@ def _to_java_ir(self, ir): ir._jir = ir.parse(code, ir_map=r.jirs) return ir._jir - def interpret(self, ir): + def execute(self, ir): return ir.typ._from_json( Env.hail().expr.ir.Interpret.interpretJSON( self._to_java_ir(ir))) @@ -50,7 +50,7 @@ def __init__(self, host, port=80, scheme='http'): self.host = host self.port = port - def interpret(self, ir): + def execute(self, ir): r = Renderer(stop_at_jir=True) code = r(ir) assert len(r.jirs) == 0 diff --git a/hail/python/hail/expr/expressions/expression_utils.py b/hail/python/hail/expr/expressions/expression_utils.py index 677cfc2a7ae..34e0a2c2ec5 100644 --- a/hail/python/hail/expr/expressions/expression_utils.py +++ b/hail/python/hail/expr/expressions/expression_utils.py @@ -190,7 +190,7 @@ def eval_typed(expression): analyze('eval_typed', expression, Indices(expression._indices.source)) if expression._indices.source is None: - return (Env.hc()._backend.interpret(expression._ir), expression.dtype) + return (Env.backend().execute(expression._ir), expression.dtype) else: return expression.collect()[0], expression.dtype diff --git a/hail/python/hail/ir/ir.py b/hail/python/hail/ir/ir.py index 20bcf79980a..9ce3b5a912a 100644 --- a/hail/python/hail/ir/ir.py +++ b/hail/python/hail/ir/ir.py @@ -1291,9 +1291,9 @@ def __eq__(self, other): class TableExport(IR): @typecheck_method(child=TableIR, path=str, - types_file=str, + types_file=nullable(str), header=bool, - export_type=hail_type) + export_type=int) def __init__(self, child, path, types_file, header, export_type): super().__init__(child) self.child = child @@ -1308,12 +1308,12 @@ def copy(self, child): return new_instance(child, self.path, self.types_file, self.header, self.export_type) def render(self, r): - return '(TableExport "{}" "{}" "{}" {} {})'.format( + return '(TableExport {} "{}" "{}" {} {})'.format( + r(self.child), escape_str(self.path), - escape_str(self.types_file), - escape_str(self.header), - self.export_type, - r(self.child)) + escape_str(self.types_file) if self.types_file else 'None', + self.header, + self.export_type) def __eq__(self, other): return isinstance(other, TableExport) and \ diff --git a/hail/python/hail/ir/table_ir.py b/hail/python/hail/ir/table_ir.py index a176020e1e8..a86f14cfa15 100644 --- a/hail/python/hail/ir/table_ir.py +++ b/hail/python/hail/ir/table_ir.py @@ -25,6 +25,15 @@ def render(self, r): return '(TableJoin {} {} {} {})'.format( escape_id(self.join_type), self.join_key, r(self.left), r(self.right)) +class TableLeftJoinRightDistinct(TableIR): + def __init__(self, left, right, root): + self.left = left + self.right = right + self.root = root + + def render(self, r): + return '(TableLeftJoinRightDistinct {} {} {})'.format( + escape_id(self.root), r(self.left), r(self.right)) class TableUnion(TableIR): def __init__(self, children): diff --git a/hail/python/hail/matrixtable.py b/hail/python/hail/matrixtable.py index 9175fd3db50..7916c0756e7 100644 --- a/hail/python/hail/matrixtable.py +++ b/hail/python/hail/matrixtable.py @@ -1688,7 +1688,7 @@ def aggregate_rows(self, expr) -> Any: base, _ = self._process_joins(expr) analyze('MatrixTable.aggregate_rows', expr, self._global_indices, {self._row_axis}) subst_query = subst(expr._ir, {}, {'va': Ref('row')}) - return Env.hc()._backend.interpret(TableAggregate(MatrixRowsTable(base._mir), subst_query)) + return Env.backend().execute(TableAggregate(MatrixRowsTable(base._mir), subst_query)) @typecheck_method(expr=expr_any) def aggregate_cols(self, expr) -> Any: @@ -1733,7 +1733,7 @@ def aggregate_cols(self, expr) -> Any: base, _ = self._process_joins(expr) analyze('MatrixTable.aggregate_cols', expr, self._global_indices, {self._col_axis}) subst_query = subst(expr._ir, {}, {'sa': Ref('row')}) - return Env.hc()._backend.interpret(TableAggregate(MatrixColsTable(base._mir), subst_query)) + return Env.backend().execute(TableAggregate(MatrixColsTable(base._mir), subst_query)) @typecheck_method(expr=expr_any) def aggregate_entries(self, expr) -> Any: @@ -1773,7 +1773,7 @@ def aggregate_entries(self, expr) -> Any: base, _ = self._process_joins(expr) analyze('MatrixTable.aggregate_entries', expr, self._global_indices, {self._row_axis, self._col_axis}) - return Env.hc()._backend.interpret(MatrixAggregate(base._mir, expr._ir)) + return Env.backend().execute(MatrixAggregate(base._mir, expr._ir)) @typecheck_method(field_expr=oneof(str, Expression)) def explode_rows(self, field_expr) -> 'MatrixTable': @@ -2088,7 +2088,7 @@ def count_rows(self) -> int: Number of rows in the matrix. """ - return Env.hc()._backend.interpret( + return Env.backend().execute( TableCount(MatrixRowsTable(self._mir))) def _force_count_rows(self): @@ -2113,7 +2113,7 @@ def count_cols(self) -> int: Number of columns in the matrix. """ - return Env.hc()._backend.interpret( + return Env.backend().execute( TableCount(MatrixColsTable(self._mir))) def count(self) -> Tuple[int, int]: @@ -2160,7 +2160,7 @@ def write(self, output: str, overwrite: bool = False, stage_locally: bool = Fals """ writer = MatrixNativeWriter(output, overwrite, stage_locally, _codec_spec) - Env.hc()._backend.interpret(MatrixWrite(self._mir, writer)) + Env.backend().execute(MatrixWrite(self._mir, writer)) def globals_table(self) -> Table: """Returns a table with a single row with the globals of the matrix table. diff --git a/hail/python/hail/methods/impex.py b/hail/python/hail/methods/impex.py index 87fb6f77968..50af4850546 100644 --- a/hail/python/hail/methods/impex.py +++ b/hail/python/hail/methods/impex.py @@ -160,7 +160,7 @@ def export_gen(dataset, output, precision=4, gp=None, id1=None, id2=None, entry_exprs=entry_exprs) writer = MatrixGENWriter(output, precision) - Env.hc()._backend.interpret(MatrixWrite(dataset._mir, writer)) + Env.backend().execute(MatrixWrite(dataset._mir, writer)) @typecheck(dataset=MatrixTable, @@ -301,7 +301,7 @@ def export_plink(dataset, output, call=None, fam_id=None, ind_id=None, pat_id=No raise TypeError("\n".join(errors)) writer = MatrixPLINKWriter(output) - Env.hc()._backend.interpret(MatrixWrite(dataset._mir, writer)) + Env.backend().execute(MatrixWrite(dataset._mir, writer)) @typecheck(dataset=MatrixTable, @@ -418,7 +418,7 @@ def export_vcf(dataset, output, append_to_header=None, parallel=None, metadata=N append_to_header, Env.hail().utils.ExportType.getExportType(parallel), metadata) - Env.hc()._backend.interpret(MatrixWrite(dataset._mir, writer)) + Env.backend().execute(MatrixWrite(dataset._mir, writer)) @typecheck(path=str, diff --git a/hail/python/hail/table.py b/hail/python/hail/table.py index e916c44a6ef..400dec28946 100644 --- a/hail/python/hail/table.py +++ b/hail/python/hail/table.py @@ -411,7 +411,7 @@ def count(self): ------- :obj:`int` """ - return Env.hc()._backend.interpret(TableCount(self._tir)) + return Env.backend().execute(TableCount(self._tir)) def _force_count(self): return self._jt.forceCount() @@ -996,7 +996,8 @@ def export(self, output, types_file=None, header=True, parallel=None): the export will be slower. """ - self._jt.export(output, types_file, header, Env.hail().utils.ExportType.getExportType(parallel)) + Env.backend().execute( + TableExport(self._tir, output, types_file, header, Env.hail().utils.ExportType.getExportType(parallel))) def group_by(self, *exprs, **named_exprs) -> 'GroupedTable': """Group by a new key for use with :meth:`.GroupedTable.aggregate`. @@ -1135,7 +1136,7 @@ def aggregate(self, expr): base, _ = self._process_joins(expr) analyze('Table.aggregate', expr, self._global_indices, {self._row_axis}) - return Env.hc()._backend.interpret(TableAggregate(base._tir, expr._ir)) + return Env.backend().execute(TableAggregate(base._tir, expr._ir)) @typecheck_method(output=str, overwrite=bool, @@ -1165,7 +1166,7 @@ def write(self, output: str, overwrite = False, stage_locally: bool = False, If ``True``, overwrite an existing file at the destination. """ - Env.hc()._backend.interpret(TableWrite(self._tir, output, overwrite, stage_locally, _codec_spec)) + Env.backend().execute(TableWrite(self._tir, output, overwrite, stage_locally, _codec_spec)) @typecheck_method(n=int, width=int, truncate=nullable(int), types=bool, handler=anyfunc) def show(self, n=10, width=90, truncate=None, types=True, handler=print): @@ -1341,10 +1342,11 @@ def types_compatible(left, right): def joiner(left): if not is_key: original_key = list(left.key) - left = Table._from_java(left.key_by()._jt.mapRows(str(Apply('annotate', - left._row._ir, - hl.struct(**dict(zip(uids, exprs)))._ir))) - ).key_by(*uids) + left = Table(TableMapRows(left.key_by()._tir, + Apply('annotate', + left._row._ir, + hl.struct(**dict(zip(uids, exprs)))._ir)) + ).key_by(*uids) rekey_f = lambda t: t.key_by(*original_key) else: rekey_f = identity @@ -1352,7 +1354,7 @@ def joiner(left): if is_interval: left = Table._from_java(left._jt.intervalJoin(self._jt, uid)) else: - left = Table._from_java(left._jt.leftJoinRightDistinct(self._jt, uid)) + left = Table(TableLeftJoinRightDistinct(left._tir, self._tir, uid)) return rekey_f(left) all_uids.append(uid) @@ -1387,7 +1389,8 @@ def joiner(left): key = None else: key = [str(k._ir) for k in exprs] - joiner = lambda left: MatrixTable._from_java(left._jmt.annotateRowsTableIR(right._jt, uid, key)) + joiner = lambda left: MatrixTable(MatrixAnnotateRowsTable( + left._mir, right._tir, uid, key)) ast = Join(GetField(TopLevelReference('va'), uid), [uid], exprs, @@ -1399,7 +1402,7 @@ def joiner(left): exprs[i] is src.col_key[i] for i in range(len(exprs))]): # key is already correct def joiner(left): - return MatrixTable._from_java(left._jmt.annotateColsTable(right._jt, uid)) + return MatrixTable(MatrixAnnotateColsTable(left._mir, right._tir, uid)) else: index_uid = Env.get_uid() uids = [Env.get_uid() for _ in exprs] @@ -1417,10 +1420,12 @@ def joiner(left: MatrixTable): .join(self, 'inner') .key_by(index_uid) .drop(*uids)) - result = MatrixTable._from_java(left.add_col_index(index_uid) - .key_cols_by(index_uid) - ._jmt - .annotateColsTable(joined._jt, uid)).key_cols_by(*prev_key) + result = MatrixTable(MatrixAnnotateColsTable( + (left.add_col_index(index_uid) + .key_cols_by(index_uid) + ._mir), + joined._tir, + uid)).key_cols_by(*prev_key) return result ir = Join(GetField(TopLevelReference('sa'), uid), all_uids, @@ -2318,11 +2323,8 @@ def to_matrix_table(self, row_key, col_key, row_fields=[], col_fields=[], n_part if len(col_key) == 0: raise ValueError(f"'to_matrix_table': require at least one col key field") - return hl.MatrixTable._from_java(self._jt.jToMatrixTable(row_key, - col_key, - row_fields, - col_fields, - n_partitions)) + return hl.MatrixTable(TableToMatrixTable( + self._tir, row_key, col_key, row_fields, col_fields, n_partitions)) @property def globals(self) -> 'StructExpression': @@ -2617,8 +2619,8 @@ def _filter_partitions(self, parts, keep=True): entries_field_name=str, col_key=sequenceof(str)) def _unlocalize_entries(self, entries_field_name, cols_field_name, col_key): - return hl.MatrixTable._from_java( - self._jt.unlocalizeEntries(entries_field_name, cols_field_name, col_key)) + return hl.MatrixTable(CastTableToMatrix( + self._tir, entries_field_name, cols_field_name, col_key)) @staticmethod @typecheck(tables=sequenceof(table_type), data_field_name=str, global_field_name=str) @@ -2632,9 +2634,7 @@ def _multi_way_zip_join(tables, data_field_name, global_field_name): raise TypeError('All input tables to multi_way_zip_join must have the same row type') if any(head.globals.dtype != t.globals.dtype for t in tables): raise TypeError('All input tables to multi_way_zip_join must have the same global type') - jt = Env.hail().table.Table.multiWayZipJoin([t._jt for t in tables], - data_field_name, - global_field_name) - return Table._from_java(jt) + return Table(TableMultiWayZipJoin( + [t._tir for t in tables], data_field_name, global_field_name)) table_type.set(Table) diff --git a/hail/python/hail/utils/java.py b/hail/python/hail/utils/java.py index 2dc78e1492a..d82d269bb6a 100644 --- a/hail/python/hail/utils/java.py +++ b/hail/python/hail/utils/java.py @@ -61,6 +61,10 @@ def hc(): assert Env._hc is not None return Env._hc + @staticmethod + def backend(): + return Env.hc()._backend + @staticmethod def sql_context(): return Env.hc()._sql_context diff --git a/hail/src/main/scala/is/hail/expr/ir/Parser.scala b/hail/src/main/scala/is/hail/expr/ir/Parser.scala index 751a10ffc1c..a0a817a42ee 100644 --- a/hail/src/main/scala/is/hail/expr/ir/Parser.scala +++ b/hail/src/main/scala/is/hail/expr/ir/Parser.scala @@ -730,6 +730,13 @@ object IRParser { val child = table_ir(env)(it) val query = ir_value_expr(env.update(child.typ.refMap))(it) TableAggregate(child, query) + case "TableExport" => + val child = table_ir(env)(it) + val path = string_literal(it) + val typesFile = opt(it, string_literal).orNull + val header = boolean_literal(it) + val exportType = int32_literal(it) + TableExport(child, path, typesFile, header, exportType) case "TableWrite" => val path = string_literal(it) val overwrite = boolean_literal(it) diff --git a/hail/src/main/scala/is/hail/table/Table.scala b/hail/src/main/scala/is/hail/table/Table.scala index b74bfb70e0e..cdf3bdc9f26 100644 --- a/hail/src/main/scala/is/hail/table/Table.scala +++ b/hail/src/main/scala/is/hail/table/Table.scala @@ -439,17 +439,6 @@ class Table(val hc: HailContext, val tir: TableIR) { signature = keySignature ++ TStruct(name -> TArray(valueSignature))) } - def jToMatrixTable(rowKeys: java.util.ArrayList[String], - colKeys: java.util.ArrayList[String], - rowFields: java.util.ArrayList[String], - colFields: java.util.ArrayList[String], - nPartitions: java.lang.Integer): MatrixTable = { - - toMatrixTable(rowKeys.asScala.toArray, colKeys.asScala.toArray, - rowFields.asScala.toArray, colFields.asScala.toArray, - Option(nPartitions).map(_.asInstanceOf[Int]) - ) - } def toMatrixTable( rowKeys: Array[String],