Skip to content

Commit

Permalink
move functionality to backend
Browse files Browse the repository at this point in the history
  • Loading branch information
cseed committed Dec 19, 2018
1 parent fbdb0c5 commit 6fbb602
Show file tree
Hide file tree
Showing 12 changed files with 140 additions and 97 deletions.
39 changes: 27 additions & 12 deletions hail/python/hail/backend/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

from hail.utils.java import *
from hail.expr.types import dtype
from hail.expr.table_type import *
from hail.expr.matrix_type import *
from hail.ir.renderer import Renderer


Expand All @@ -10,21 +12,36 @@ class Backend(object):
def interpret(self, ir):
return

@abc.abstractmethod
def table_read_type(self, table_read_ir):
return

class SparkBackend(Backend):
def interpret(self, ir):
assert isinstance(ir, hail.ir.IR)
@abc.abstractmethod
def matrix_read_type(self, matrix_read_ir):
return

r = Renderer(stop_at_jir=True)
code = r(ir)
ir_map = {name: jir for name, jir in r.jirs.items()}

jir = ir.to_java_ir()
class SparkBackend(Backend):
def _to_java_ir(self, ir):
if not hasattr(ir, '_jir'):
r = Renderer(stop_at_jir=True)
code = r(ir)
# FIXME parse should be static
ir._jir = ir.parse(code, ir_map=r.jirs)
return ir._jir

def interpret(self, ir):
return ir.typ._from_json(
Env.hail().expr.ir.Interpret.interpretJSON(
self._to_java_ir(ir)))

typ = dtype(jir.typ().toString())
result = Env.hail().expr.ir.Interpret.interpretPyIR(code, {}, ir_map)
def table_read_type(self, tir):
jir = self._to_java_ir(tir)
return ttable._from_java(jir.typ())

return typ._from_json(result)
def matrix_read_type(self, mir):
jir = self._to_java_ir(mir)
return tmatrix._from_java(jir.typ())


class ServiceBackend(Backend):
Expand All @@ -34,8 +51,6 @@ def __init__(self, host, port=80, scheme='http'):
self.port = port

def interpret(self, ir):
assert isinstance(ir, hail.ir.IR)

r = Renderer(stop_at_jir=True)
code = r(ir)
assert len(r.jirs) == 0
Expand Down
13 changes: 12 additions & 1 deletion hail/python/hail/expr/matrix_type.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,19 @@
from hail.typecheck import *
from hail.utils.java import escape_parsable
from hail.expr.types import tstruct
from hail.expr.types import dtype, tstruct
from hail.utils.java import jiterable_to_list

class tmatrix(object):
@staticmethod
def _from_java(jtt):
return tmatrix(
dtype(jtt.globalType().toString()),
dtype(jtt.colType().toString()),
jiterable_to_list(jtt.colKey()),
dtype(jtt.rowType().toString()),
jiterable_to_list(jtt.rowKey()),
dtype(jtt.entryType().toString()))

@typecheck_method(global_type=tstruct,
col_type=tstruct, col_key=sequenceof(str),
row_type=tstruct, row_key=sequenceof(str),
Expand Down
10 changes: 9 additions & 1 deletion hail/python/hail/expr/table_type.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,16 @@
from hail.typecheck import *
from hail.utils.java import escape_parsable
from hail.expr.types import tstruct
from hail.expr.types import dtype, tstruct
from hail.utils.java import jiterable_to_list

class ttable(object):
@staticmethod
def _from_java(jtt):
return ttable(
dtype(jtt.globalType().toString()),
dtype(jtt.rowType().toString()),
jiterable_to_list(jtt.key()))

@typecheck_method(global_type=tstruct, row_type=tstruct, row_key=sequenceof(str))
def __init__(self, global_type, row_type, row_key):
self.global_type = global_type
Expand Down
2 changes: 1 addition & 1 deletion hail/python/hail/expr/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -843,7 +843,7 @@ def _pretty(self, l, indent, increment):
pre_indent = indent
indent += increment
l.append('struct {')
for i, (f, t) in enumerate(self._field_types):
for i, (f, t) in enumerate(self._field_types.items()):
if i > 0:
l.append(', ')
l.append('\n')
Expand Down
31 changes: 22 additions & 9 deletions hail/python/hail/ir/base_ir.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import abc
from .renderer import Renderer
from hail.expr.matrix_type import *
from hail.expr.table_type import *
from hail.utils.java import Env


Expand All @@ -11,18 +13,14 @@ def __str__(self):
r = Renderer(stop_at_jir = False)
return r(self)

def to_java_ir(self):
if not hasattr(self, '_jir'):
r = Renderer(stop_at_jir=True)
code = r(self)
ir_map = {name: jir for name, jir in r.jirs.items()}
self._jir = self.parse(code, ir_map=ir_map)
return self._jir

@abc.abstractmethod
def parse(self, code, ref_map, ir_map):
return

@abc.abstractproperty
def typ(self):
return


class IR(BaseIR):
def __init__(self, *children):
Expand Down Expand Up @@ -63,6 +61,11 @@ def map_ir(self, f):
def bound_variables(self):
return {v for child in self.children for v in child.bound_variables}

@property
def typ(self):
jir = Env.hc()._backend._to_java_ir(self)
return dtype(jir.typ().toString())

def parse(self, code, ref_map={}, ir_map={}):
return Env.hail().expr.ir.IRParser.parse_value_ir(code, ref_map, ir_map)

Expand All @@ -71,6 +74,11 @@ class TableIR(BaseIR):
def __init__(self):
super().__init__()

@property
def typ(self):
jtir = Env.hc()._backend._to_java_ir(self)
return ttable._from_java(jtir.typ())

def parse(self, code, ref_map={}, ir_map={}):
return Env.hail().expr.ir.IRParser.parse_table_ir(code, ref_map, ir_map)

Expand All @@ -79,5 +87,10 @@ class MatrixIR(BaseIR):
def __init__(self):
super().__init__()

@property
def typ(self):
jmir = Env.hc()._backend._to_java_ir(self)
return tmatrix._from_java(jmir.typ())

def parse(self, code, ref_map={}, ir_map={}):
return Env.hail().expr.ir.IRParser.parse_matrix_ir(code, ref_map, ir_map)
return Env.hail().expr.ir.IRParser.parse_matrix_ir(code, ref_map, ir_map)
60 changes: 38 additions & 22 deletions hail/python/hail/ir/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,38 +148,46 @@ class Cast(IR):
def __init__(self, v, typ):
super().__init__(v)
self.v = v
self.typ = typ
self._typ = typ

@property
def typ(self):
return self._typ

@typecheck_method(v=IR)
def copy(self, v):
new_instance = self.__class__
return new_instance(v, self.typ)
return new_instance(v, self._typ)

def render(self, r):
return '(Cast {} {})'.format(self.typ._parsable_string(), r(self.v))
return '(Cast {} {})'.format(self._typ._parsable_string(), r(self.v))

def __eq__(self, other):
return isinstance(other, Cast) and \
other.v == self.v and \
other.typ == self.typ
other._typ == self._typ


class NA(IR):
@typecheck_method(typ=hail_type)
def __init__(self, typ):
super().__init__()
self.typ = typ
self._typ = typ

@property
def typ(self):
return self._typ

def copy(self):
new_instance = self.__class__
return new_instance(self.typ)
return new_instance(self._typ)

def render(self, r):
return '(NA {})'.format(self.typ._parsable_string())
return '(NA {})'.format(self._typ._parsable_string())

def __eq__(self, other):
return isinstance(other, NA) and \
other.typ == self.typ
other._typ == self._typ


class IsNA(IR):
Expand Down Expand Up @@ -354,25 +362,25 @@ def __eq__(self, other):


class MakeArray(IR):
@typecheck_method(args=sequenceof(IR), typ=nullable(hail_type))
def __init__(self, args, typ):
@typecheck_method(args=sequenceof(IR), element_type=nullable(hail_type))
def __init__(self, args, element_type):
super().__init__(*args)
self.args = args
self.typ = typ
self._element_type = element_type

def copy(self, *args):
new_instance = self.__class__
return new_instance(list(args), self.typ)
return new_instance(list(args), self._element_type)

def render(self, r):
return '(MakeArray {} {})'.format(
self.typ._parsable_string() if self.typ is not None else 'None',
self._element_type._parsable_string() if self._element_type is not None else 'None',
' '.join([r(x) for x in self.args]))

def __eq__(self, other):
return isinstance(other, MakeArray) and \
other.args == self.args and \
other.typ == self.typ
other._element_type == self._element_type


class ArrayRef(IR):
Expand Down Expand Up @@ -1056,39 +1064,47 @@ class In(IR):
def __init__(self, i, typ):
super().__init__()
self.i = i
self.typ = typ
self._typ = typ

@property
def typ(self):
return self._typ

def copy(self):
new_instance = self.__class__
return new_instance(self.i, self.typ)
return new_instance(self.i, self._typ)

def render(self, r):
return '(In {} {})'.format(self.typ._parsable_string(), self.i)
return '(In {} {})'.format(self._typ._parsable_string(), self.i)

def __eq__(self, other):
return isinstance(other, In) and \
other.i == self.i and \
other.typ == self.typ
other._typ == self._typ


class Die(IR):
@typecheck_method(message=IR, typ=hail_type)
def __init__(self, message, typ):
super().__init__()
self.message = message
self.typ = typ
self._typ = typ

@property
def typ(self):
return self._typ

def copy(self):
new_instance = self.__class__
return new_instance(self.message, self.typ)
return new_instance(self.message, self._typ)

def render(self, r):
return '(Die {} {})'.format(self.typ._parsable_string(), r(self.message))
return '(Die {} {})'.format(self._typ._parsable_string(), r(self.message))

def __eq__(self, other):
return isinstance(other, Die) and \
other.message == self.message and \
other.typ == self.typ
other._typ == self._typ


class Apply(IR):
Expand Down
8 changes: 4 additions & 4 deletions hail/python/hail/ir/table_ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,26 +94,26 @@ def __init__(self, path, drop_rows, typ):
super().__init__()
self.path = path
self.drop_rows = drop_rows
self.typ = typ
self._typ = typ

def render(self, r):
return '(TableRead "{}" {} {})'.format(
escape_str(self.path),
self.drop_rows,
self.typ)
self._typ)


class TableImport(TableIR):
def __init__(self, paths, typ, reader_options):
super().__init__()
self.paths = paths
self.typ = typ
self._typ = typ
self.reader_options = reader_options

def render(self, r):
return '(TableImport ({}) {} {})'.format(
' '.join([escape_str(path) for path in self.paths]),
self.typ._parsable_string(),
self._typ._parsable_string(),
escape_str(json.dumps(self.reader_options)))


Expand Down
Loading

0 comments on commit 6fbb602

Please sign in to comment.