From 6fbb602f69e1f7417e75f00530cb7929255e83b1 Mon Sep 17 00:00:00 2001 From: Cotton Seed Date: Tue, 18 Dec 2018 21:55:01 -0500 Subject: [PATCH] move functionality to backend --- hail/python/hail/backend/backend.py | 39 ++++++++---- hail/python/hail/expr/matrix_type.py | 13 +++- hail/python/hail/expr/table_type.py | 10 +++- hail/python/hail/expr/types.py | 2 +- hail/python/hail/ir/base_ir.py | 31 +++++++--- hail/python/hail/ir/ir.py | 60 ++++++++++++------- hail/python/hail/ir/table_ir.py | 8 +-- hail/python/hail/matrixtable.py | 35 +++++------ hail/python/hail/table.py | 15 ++--- .../scala/is/hail/expr/AnnotationImpex.scala | 2 + .../scala/is/hail/expr/ir/Interpret.scala | 21 +------ .../scala/is/hail/expr/types/TableType.scala | 1 - 12 files changed, 140 insertions(+), 97 deletions(-) diff --git a/hail/python/hail/backend/backend.py b/hail/python/hail/backend/backend.py index f150dc371855..7240a7283bed 100644 --- a/hail/python/hail/backend/backend.py +++ b/hail/python/hail/backend/backend.py @@ -2,6 +2,8 @@ from hail.utils.java import * from hail.expr.types import dtype +from hail.expr.table_type import * +from hail.expr.matrix_type import * from hail.ir.renderer import Renderer @@ -10,21 +12,36 @@ class Backend(object): def interpret(self, ir): return + @abc.abstractmethod + def table_read_type(self, table_read_ir): + return -class SparkBackend(Backend): - def interpret(self, ir): - assert isinstance(ir, hail.ir.IR) + @abc.abstractmethod + def matrix_read_type(self, matrix_read_ir): + return - r = Renderer(stop_at_jir=True) - code = r(ir) - ir_map = {name: jir for name, jir in r.jirs.items()} - jir = ir.to_java_ir() +class SparkBackend(Backend): + def _to_java_ir(self, ir): + if not hasattr(ir, '_jir'): + r = Renderer(stop_at_jir=True) + code = r(ir) + # FIXME parse should be static + ir._jir = ir.parse(code, ir_map=r.jirs) + return ir._jir + + def interpret(self, ir): + return ir.typ._from_json( + Env.hail().expr.ir.Interpret.interpretJSON( + self._to_java_ir(ir))) - typ = dtype(jir.typ().toString()) - result = Env.hail().expr.ir.Interpret.interpretPyIR(code, {}, ir_map) + def table_read_type(self, tir): + jir = self._to_java_ir(tir) + return ttable._from_java(jir.typ()) - return typ._from_json(result) + def matrix_read_type(self, mir): + jir = self._to_java_ir(mir) + return tmatrix._from_java(jir.typ()) class ServiceBackend(Backend): @@ -34,8 +51,6 @@ def __init__(self, host, port=80, scheme='http'): self.port = port def interpret(self, ir): - assert isinstance(ir, hail.ir.IR) - r = Renderer(stop_at_jir=True) code = r(ir) assert len(r.jirs) == 0 diff --git a/hail/python/hail/expr/matrix_type.py b/hail/python/hail/expr/matrix_type.py index 800ae8224823..30965693d10b 100644 --- a/hail/python/hail/expr/matrix_type.py +++ b/hail/python/hail/expr/matrix_type.py @@ -1,8 +1,19 @@ from hail.typecheck import * from hail.utils.java import escape_parsable -from hail.expr.types import tstruct +from hail.expr.types import dtype, tstruct +from hail.utils.java import jiterable_to_list class tmatrix(object): + @staticmethod + def _from_java(jtt): + return tmatrix( + dtype(jtt.globalType().toString()), + dtype(jtt.colType().toString()), + jiterable_to_list(jtt.colKey()), + dtype(jtt.rowType().toString()), + jiterable_to_list(jtt.rowKey()), + dtype(jtt.entryType().toString())) + @typecheck_method(global_type=tstruct, col_type=tstruct, col_key=sequenceof(str), row_type=tstruct, row_key=sequenceof(str), diff --git a/hail/python/hail/expr/table_type.py b/hail/python/hail/expr/table_type.py index 20630b72fd6e..68d1159ded7d 100644 --- a/hail/python/hail/expr/table_type.py +++ b/hail/python/hail/expr/table_type.py @@ -1,8 +1,16 @@ from hail.typecheck import * from hail.utils.java import escape_parsable -from hail.expr.types import tstruct +from hail.expr.types import dtype, tstruct +from hail.utils.java import jiterable_to_list class ttable(object): + @staticmethod + def _from_java(jtt): + return ttable( + dtype(jtt.globalType().toString()), + dtype(jtt.rowType().toString()), + jiterable_to_list(jtt.key())) + @typecheck_method(global_type=tstruct, row_type=tstruct, row_key=sequenceof(str)) def __init__(self, global_type, row_type, row_key): self.global_type = global_type diff --git a/hail/python/hail/expr/types.py b/hail/python/hail/expr/types.py index 3c1511f2ff35..f00f515d5a4f 100644 --- a/hail/python/hail/expr/types.py +++ b/hail/python/hail/expr/types.py @@ -843,7 +843,7 @@ def _pretty(self, l, indent, increment): pre_indent = indent indent += increment l.append('struct {') - for i, (f, t) in enumerate(self._field_types): + for i, (f, t) in enumerate(self._field_types.items()): if i > 0: l.append(', ') l.append('\n') diff --git a/hail/python/hail/ir/base_ir.py b/hail/python/hail/ir/base_ir.py index 3d05c2edddb3..dc93c98f3eb7 100644 --- a/hail/python/hail/ir/base_ir.py +++ b/hail/python/hail/ir/base_ir.py @@ -1,5 +1,7 @@ import abc from .renderer import Renderer +from hail.expr.matrix_type import * +from hail.expr.table_type import * from hail.utils.java import Env @@ -11,18 +13,14 @@ def __str__(self): r = Renderer(stop_at_jir = False) return r(self) - def to_java_ir(self): - if not hasattr(self, '_jir'): - r = Renderer(stop_at_jir=True) - code = r(self) - ir_map = {name: jir for name, jir in r.jirs.items()} - self._jir = self.parse(code, ir_map=ir_map) - return self._jir - @abc.abstractmethod def parse(self, code, ref_map, ir_map): return + @abc.abstractproperty + def typ(self): + return + class IR(BaseIR): def __init__(self, *children): @@ -63,6 +61,11 @@ def map_ir(self, f): def bound_variables(self): return {v for child in self.children for v in child.bound_variables} + @property + def typ(self): + jir = Env.hc()._backend._to_java_ir(self) + return dtype(jir.typ().toString()) + def parse(self, code, ref_map={}, ir_map={}): return Env.hail().expr.ir.IRParser.parse_value_ir(code, ref_map, ir_map) @@ -71,6 +74,11 @@ class TableIR(BaseIR): def __init__(self): super().__init__() + @property + def typ(self): + jtir = Env.hc()._backend._to_java_ir(self) + return ttable._from_java(jtir.typ()) + def parse(self, code, ref_map={}, ir_map={}): return Env.hail().expr.ir.IRParser.parse_table_ir(code, ref_map, ir_map) @@ -79,5 +87,10 @@ class MatrixIR(BaseIR): def __init__(self): super().__init__() + @property + def typ(self): + jmir = Env.hc()._backend._to_java_ir(self) + return tmatrix._from_java(jmir.typ()) + def parse(self, code, ref_map={}, ir_map={}): - return Env.hail().expr.ir.IRParser.parse_matrix_ir(code, ref_map, ir_map) \ No newline at end of file + return Env.hail().expr.ir.IRParser.parse_matrix_ir(code, ref_map, ir_map) diff --git a/hail/python/hail/ir/ir.py b/hail/python/hail/ir/ir.py index b6058dad2b3d..20bcf79980a9 100644 --- a/hail/python/hail/ir/ir.py +++ b/hail/python/hail/ir/ir.py @@ -148,38 +148,46 @@ class Cast(IR): def __init__(self, v, typ): super().__init__(v) self.v = v - self.typ = typ + self._typ = typ + + @property + def typ(self): + return self._typ @typecheck_method(v=IR) def copy(self, v): new_instance = self.__class__ - return new_instance(v, self.typ) + return new_instance(v, self._typ) def render(self, r): - return '(Cast {} {})'.format(self.typ._parsable_string(), r(self.v)) + return '(Cast {} {})'.format(self._typ._parsable_string(), r(self.v)) def __eq__(self, other): return isinstance(other, Cast) and \ other.v == self.v and \ - other.typ == self.typ + other._typ == self._typ class NA(IR): @typecheck_method(typ=hail_type) def __init__(self, typ): super().__init__() - self.typ = typ + self._typ = typ + + @property + def typ(self): + return self._typ def copy(self): new_instance = self.__class__ - return new_instance(self.typ) + return new_instance(self._typ) def render(self, r): - return '(NA {})'.format(self.typ._parsable_string()) + return '(NA {})'.format(self._typ._parsable_string()) def __eq__(self, other): return isinstance(other, NA) and \ - other.typ == self.typ + other._typ == self._typ class IsNA(IR): @@ -354,25 +362,25 @@ def __eq__(self, other): class MakeArray(IR): - @typecheck_method(args=sequenceof(IR), typ=nullable(hail_type)) - def __init__(self, args, typ): + @typecheck_method(args=sequenceof(IR), element_type=nullable(hail_type)) + def __init__(self, args, element_type): super().__init__(*args) self.args = args - self.typ = typ + self._element_type = element_type def copy(self, *args): new_instance = self.__class__ - return new_instance(list(args), self.typ) + return new_instance(list(args), self._element_type) def render(self, r): return '(MakeArray {} {})'.format( - self.typ._parsable_string() if self.typ is not None else 'None', + self._element_type._parsable_string() if self._element_type is not None else 'None', ' '.join([r(x) for x in self.args])) def __eq__(self, other): return isinstance(other, MakeArray) and \ other.args == self.args and \ - other.typ == self.typ + other._element_type == self._element_type class ArrayRef(IR): @@ -1056,19 +1064,23 @@ class In(IR): def __init__(self, i, typ): super().__init__() self.i = i - self.typ = typ + self._typ = typ + + @property + def typ(self): + return self._typ def copy(self): new_instance = self.__class__ - return new_instance(self.i, self.typ) + return new_instance(self.i, self._typ) def render(self, r): - return '(In {} {})'.format(self.typ._parsable_string(), self.i) + return '(In {} {})'.format(self._typ._parsable_string(), self.i) def __eq__(self, other): return isinstance(other, In) and \ other.i == self.i and \ - other.typ == self.typ + other._typ == self._typ class Die(IR): @@ -1076,19 +1088,23 @@ class Die(IR): def __init__(self, message, typ): super().__init__() self.message = message - self.typ = typ + self._typ = typ + + @property + def typ(self): + return self._typ def copy(self): new_instance = self.__class__ - return new_instance(self.message, self.typ) + return new_instance(self.message, self._typ) def render(self, r): - return '(Die {} {})'.format(self.typ._parsable_string(), r(self.message)) + return '(Die {} {})'.format(self._typ._parsable_string(), r(self.message)) def __eq__(self, other): return isinstance(other, Die) and \ other.message == self.message and \ - other.typ == self.typ + other._typ == self._typ class Apply(IR): diff --git a/hail/python/hail/ir/table_ir.py b/hail/python/hail/ir/table_ir.py index 4ee5a10b4248..536e043f23ff 100644 --- a/hail/python/hail/ir/table_ir.py +++ b/hail/python/hail/ir/table_ir.py @@ -94,26 +94,26 @@ def __init__(self, path, drop_rows, typ): super().__init__() self.path = path self.drop_rows = drop_rows - self.typ = typ + self._typ = typ def render(self, r): return '(TableRead "{}" {} {})'.format( escape_str(self.path), self.drop_rows, - self.typ) + self._typ) class TableImport(TableIR): def __init__(self, paths, typ, reader_options): super().__init__() self.paths = paths - self.typ = typ + self._typ = typ self.reader_options = reader_options def render(self, r): return '(TableImport ({}) {} {})'.format( ' '.join([escape_str(path) for path in self.paths]), - self.typ._parsable_string(), + self._typ._parsable_string(), escape_str(json.dumps(self.reader_options))) diff --git a/hail/python/hail/matrixtable.py b/hail/python/hail/matrixtable.py index a5dd5fdc8f49..9175fd3db500 100644 --- a/hail/python/hail/matrixtable.py +++ b/hail/python/hail/matrixtable.py @@ -549,10 +549,8 @@ def __init__(self, mir): super(MatrixTable, self).__init__() self._mir = mir - self._jmir = mir.to_java_ir() - self._jmt = Env.hail().variant.MatrixTable(Env.hc()._jhc, self._jmir) - - jmtype = self._jmir.typ() + self._jmt = Env.hail().variant.MatrixTable( + Env.hc()._jhc, Env.hc()._backend._to_java_ir(self._mir)) self._globals = None self._col_values = None @@ -565,20 +563,17 @@ def __init__(self, mir): self._col_indices = Indices(self, {self._col_axis}) self._entry_indices = Indices(self, {self._row_axis, self._col_axis}) - self._global_type = hl.dtype(jmtype.globalType().toString()) - self._col_type = hl.dtype(jmtype.colType().toString()) - self._row_type = hl.dtype(jmtype.rowType().toString()) - self._entry_type = hl.dtype(jmtype.entryType().toString()) + self._type = self._mir.typ - assert isinstance(self._global_type, tstruct), self._global_type - assert isinstance(self._col_type, tstruct), self._col_type - assert isinstance(self._row_type, tstruct), self._row_type - assert isinstance(self._entry_type, tstruct), self._entry_type + self._global_type = self._type.global_type + self._col_type = self._type.col_type + self._row_type = self._type.row_type + self._entry_type = self._type.entry_type self._globals = construct_reference('global', self._global_type, indices=self._global_indices) self._rvrow = construct_reference('va', - hl.dtype(jmtype.rvRowType().toString()), + self._type.row_type, indices=self._row_indices) self._row = hail.struct(**{k: self._rvrow[k] for k in self._row_type.keys()}) self._col = construct_reference('sa', self._col_type, @@ -592,10 +587,10 @@ def __init__(self, mir): 'g': self._entry_indices} self._row_key = hail.struct( - **{k: self._row[k] for k in jiterable_to_list(jmtype.rowKey())}) + **{k: self._row[k] for k in self._type.row_key}) self._partition_key = self._row_key self._col_key = hail.struct( - **{k: self._col[k] for k in jiterable_to_list(jmtype.colKey())}) + **{k: self._col[k] for k in self._type.col_key}) self._num_samples = None @@ -2093,7 +2088,8 @@ def count_rows(self) -> int: Number of rows in the matrix. """ - return self._jmt.countRows() + return Env.hc()._backend.interpret( + TableCount(MatrixRowsTable(self._mir))) def _force_count_rows(self): return self._jmt.forceCountRows() @@ -2116,7 +2112,9 @@ def count_cols(self) -> int: :obj:`int` Number of columns in the matrix. """ - return self._jmt.countCols() + + return Env.hc()._backend.interpret( + TableCount(MatrixColsTable(self._mir))) def count(self) -> Tuple[int, int]: """Count the number of rows and columns in the matrix. @@ -2131,8 +2129,7 @@ def count(self) -> Tuple[int, int]: :obj:`int`, :obj:`int` Number of rows, number of cols. """ - r = self._jmt.count() - return r._1(), r._2() + return (self.count_rows(), self.count_cols()) @typecheck_method(output=str, overwrite=bool, diff --git a/hail/python/hail/table.py b/hail/python/hail/table.py index c1aeedf276b5..f6a6c9e0548a 100644 --- a/hail/python/hail/table.py +++ b/hail/python/hail/table.py @@ -324,21 +324,18 @@ def __init__(self, tir): super(Table, self).__init__() self._tir = tir - self._jtir = tir.to_java_ir() - self._jt = Env.hail().table.Table(Env.hc()._jhc, self._jtir) + self._jt = Env.hail().table.Table( + Env.hc()._jhc, Env.hc()._backend._to_java_ir(self._tir)) - jttype = self._jtir.typ() + self._type = self._tir.typ self._row_axis = 'row' self._global_indices = Indices(axes=set(), source=self) self._row_indices = Indices(axes={self._row_axis}, source=self) - self._global_type = hl.dtype(jttype.globalType().toString()) - self._row_type = hl.dtype(jttype.rowType().toString()) - - assert isinstance(self._global_type, tstruct) - assert isinstance(self._row_type, tstruct) + self._global_type = self._type.global_type + self._row_type = self._type.row_type self._globals = construct_reference('global', self._global_type, indices=self._global_indices) self._row = construct_reference('row', self._row_type, indices=self._row_indices) @@ -347,7 +344,7 @@ def __init__(self, tir): 'row': self._row_indices} self._key = hail.struct( - **{k: self._row[k] for k in jiterable_to_list(jttype.key())}) + **{k: self._row[k] for k in self._type.row_key}) for k, v in itertools.chain(self._globals.items(), self._row.items()): diff --git a/hail/src/main/scala/is/hail/expr/AnnotationImpex.scala b/hail/src/main/scala/is/hail/expr/AnnotationImpex.scala index 630c98ebbd16..e544bfcd1efa 100644 --- a/hail/src/main/scala/is/hail/expr/AnnotationImpex.scala +++ b/hail/src/main/scala/is/hail/expr/AnnotationImpex.scala @@ -101,6 +101,8 @@ object JSONAnnotationImpex { case _: TFloat32 => JDouble(a.asInstanceOf[Float]) case _: TFloat64 => JDouble(a.asInstanceOf[Double]) case _: TString => JString(a.asInstanceOf[String]) + case TVoid => + JNull case TArray(elementType, _) => val arr = a.asInstanceOf[Seq[Any]] JArray(arr.map(elem => exportAnnotation(elem, elementType)).toList) diff --git a/hail/src/main/scala/is/hail/expr/ir/Interpret.scala b/hail/src/main/scala/is/hail/expr/ir/Interpret.scala index 9ecc5fb30713..d624d356af06 100644 --- a/hail/src/main/scala/is/hail/expr/ir/Interpret.scala +++ b/hail/src/main/scala/is/hail/expr/ir/Interpret.scala @@ -1,7 +1,5 @@ package is.hail.expr.ir -import java.io.{PrintWriter, StringWriter} - import is.hail.{HailContext, Uploader, stats} import is.hail.annotations.aggregators.RegionValueAggregator import is.hail.annotations._ @@ -13,29 +11,16 @@ import is.hail.expr.types.virtual._ import is.hail.methods._ import is.hail.utils._ import org.apache.spark.sql.Row -import org.json4s.JsonAST.JNull import org.json4s.jackson.JsonMethods -import scala.collection.JavaConverters._ - object Interpret { type Agg = (IndexedSeq[Row], TStruct) - def interpretPyIR(s: String, refMap: java.util.HashMap[String, Type], irMap: java.util.HashMap[String, BaseIR]): String = { - interpretPyIR(s, refMap.asScala.toMap, irMap.asScala.toMap) - } - - def interpretPyIR(s: String, refMap: Map[String, Type] = Map.empty, irMap: Map[String, BaseIR] = Map.empty): String = { - val ir = IRParser.parse_value_ir(s, IRParserEnvironment(refMap, irMap)) + def interpretJSON(ir: IR): String = { val t = ir.typ val value = Interpret[Any](ir) - - val jsonValue = t match { - case TVoid => JNull - case _ => JSONAnnotationImpex.exportAnnotation(value, t) - } - - JsonMethods.compact(jsonValue) + JsonMethods.compact( + JSONAnnotationImpex.exportAnnotation(value, t)) } def apply(tir: TableIR): TableValue = diff --git a/hail/src/main/scala/is/hail/expr/types/TableType.scala b/hail/src/main/scala/is/hail/expr/types/TableType.scala index d09903319f61..e2f5c0e9edbf 100644 --- a/hail/src/main/scala/is/hail/expr/types/TableType.scala +++ b/hail/src/main/scala/is/hail/expr/types/TableType.scala @@ -1,6 +1,5 @@ package is.hail.expr.types -import is.hail.expr.Parser import is.hail.expr.ir._ import is.hail.expr.types.virtual.{TStruct, Type} import is.hail.rvd.RVDType