Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

added MatrixUnionCols IR #5037

Merged
merged 4 commits into from
Dec 24, 2018
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
added MatrixUnionCols IR
  • Loading branch information
cseed committed Dec 22, 2018
commit f28726807255377262d9f69593698e92bfbab1d7
1 change: 1 addition & 0 deletions hail/python/hail/expr/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -866,6 +866,7 @@ def _is_prefix_of(self, other):
len(self._fields) <= len(other._fields) and
all(x == y for x, y in zip(self._field_types.values(), other._field_types.values())))


class ttuple(HailType):
"""Hail type for tuples.

Expand Down
8 changes: 8 additions & 0 deletions hail/python/hail/ir/matrix_ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,14 @@ def render(self, r):
'(' + ' '.join(f'"{escape_str(f)}"' for f in self.new_key) + ')' if self.new_key is not None else 'None',
r(self.child), r(self.new_col))

class MatrixUnionCols(MatrixIR):
    """Matrix IR node that combines two matrix IR children, ``left`` and ``right``.

    :param left: left-hand matrix IR child.
    :param right: right-hand matrix IR child.
    """

    def __init__(self, left, right):
        # Initialize the base IR state first, as the sibling MatrixIR nodes in this file do.
        super().__init__()
        self.left = left
        self.right = right

    def render(self, r):
        """Return the textual form ``(MatrixUnionCols <left> <right>)``.

        :param r: callable used to render each child IR to a string.
        """
        return f'(MatrixUnionCols {r(self.left)} {r(self.right)})'

class MatrixMapEntries(MatrixIR):
def __init__(self, child, new_entry):
super().__init__()
Expand Down
19 changes: 18 additions & 1 deletion hail/python/hail/matrixtable.py
Original file line number Diff line number Diff line change
Expand Up @@ -3104,7 +3104,24 @@ def union_cols(self, other: 'MatrixTable') -> 'MatrixTable':
:class:`.MatrixTable`
Dataset with columns from both datasets.
"""
return MatrixTable._from_java(self._jmt.unionCols(other._jmt))
if self._entry_type != other._entry_type:
raise ValueError('entry types differ\n'
f' left: {self._entry_type}\n'
f' right: {other.entry_type}')
if self._col_type != other._col_type:
raise ValueError('column types differ\n'
f' left: {self._col_type}\n'
f' right: {other.col_type}')
if list(self.col_key.values()) != list(other.col_key.values()):
raise ValueError('column key types differ\n'
f' left: {', '.join(self.col_key.values()}\n'
f' right: {', '.join(other.col_key.values()}')
if list(self.row_key.values()) != list(other.row_key.values()):
raise ValueError('row key types differ\n'
f' left: {', '.join(self.row_key.values()}\n'
f' right: {', '.join(other.row_key.values()}')

return MatrixTable(MatrixUnionCols(self._mir, other._mir))

@typecheck_method(n=int)
def head(self, n: int) -> 'MatrixTable':
Expand Down
10 changes: 9 additions & 1 deletion hail/src/main/scala/is/hail/expr/ir/IRBuilder.scala
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ object IRBuilder {
/** Builds a proxy that, given an environment, yields an `If` node over the three resolved branches. */
def irIf(cond: IRProxy)(cnsq: IRProxy)(altr: IRProxy): IRProxy = { (env: E) =>
  val condition = cond(env)
  val consequent = cnsq(env)
  val alternate = altr(env)
  If(condition, consequent, alternate)
}

/** Builds an array proxy from at least one element: `first` followed by `rest`, in order. */
def makeArray(first: IRProxy, rest: IRProxy*): IRProxy = {
  val elements = first +: rest
  arrayToProxy(elements)
}

/** Builds a proxy for a `MakeStruct` whose field names are the symbols' names, in the given order. */
def makeStruct(fields: (Symbol, IRProxy)*): IRProxy = { (env: E) =>
  val resolved = fields.map { field =>
    val (symbol, proxy) = field
    (symbol.name, proxy(env))
  }
  MakeStruct(resolved)
}

Expand Down Expand Up @@ -124,7 +126,13 @@ object IRBuilder {
/**
  * Builds a proxy for an `ArrayMap` over this proxy's array.
  *
  * The lambda's symbol is bound to the array's element type while its body is resolved.
  *
  * @param f lambda applied to each element
  */
def map(f: LambdaProxy): IRProxy = (env: E) => {
  // Resolve the array once and reuse it for both the element type and the node itself.
  val array = ir(env)
  val eltType = array.typ.asInstanceOf[TArray].elementType
  ArrayMap(array, f.s.name, f.body(env.bind(f.s.name -> eltType)))
}

/**
  * Builds a proxy for an `ArrayFlatMap` over this proxy's array.
  *
  * The lambda's symbol is bound to the array's element type while its body is resolved.
  *
  * @param f lambda applied to each element
  */
def flatMap(f: LambdaProxy): IRProxy = { (env: E) =>
  val arrayIR = ir(env)
  val elementType = arrayIR.typ.asInstanceOf[TArray].elementType
  val bodyIR = f.body(env.bind(f.s.name -> elementType))
  ArrayFlatMap(arrayIR, f.s.name, bodyIR)
}

/**
  * Builds a proxy for an `ArraySort` over this proxy's array.
  *
  * @param ascending proxy resolved and passed as the sort's `ascending` argument
  * @param onKey     forwarded unchanged to `ArraySort`; defaults to `false`
  */
def sort(ascending: IRProxy, onKey: Boolean = false): IRProxy = { (env: E) =>
  val arrayIR = ir(env)
  val ascendingIR = ascending(env)
  ArraySort(arrayIR, ascendingIR, onKey)
}
Expand Down
24 changes: 23 additions & 1 deletion hail/src/main/scala/is/hail/expr/ir/LowerMatrixIR.scala
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,28 @@ object LowerMatrixIR {
}
}))

case MatrixUnionCols(left, right) =>
val rightEntries = genUID()
val rightCols = genUID()
TableJoin(
lower(left),
lower(right)
.mapRows('row
.insertFields(Symbol(rightEntries) -> 'row(entriesField))
.selectFields(right.typ.rowKey :+ rightEntries: _*))
.mapGlobals('global
.insertFields(Symbol(rightCols) -> 'global(colsField))
.selectFields(rightCols)),
"inner")
.mapRows('row
.insertFields(entriesField ->
makeArray('row(entriesField), 'row(Symbol(rightEntries))).flatMap('a ~> 'a))
.dropFields(Symbol(rightEntries)))
.mapGlobals('global
.insertFields(colsField ->
makeArray('global(colsField), 'global(Symbol(rightCols))).flatMap('a ~> 'a))
.dropFields(Symbol(rightCols)))

case MatrixMapEntries(child, newEntries) =>
lower(child).mapRows('row.insertFields(entriesField ->
irRange(0, 'global(colsField).len).map { 'i ~>
Expand Down Expand Up @@ -176,7 +198,7 @@ object LowerMatrixIR {
.mapRows('row.insertFields(entriesField ->
'global('newColIdx).map { 'kv ~>
makeStruct(child.typ.entryType.fieldNames.map { s =>
(Symbol(s), 'kv('value).map { 'i ~> 'row(entriesField)('i)(Symbol(s))})}: _*)
(Symbol(s), 'kv('value).map { 'i ~> 'row(entriesField)('i)(Symbol(s)) }) }: _*)
}))
.mapGlobals('global
.insertFields(colsField ->
Expand Down
15 changes: 14 additions & 1 deletion hail/src/main/scala/is/hail/expr/ir/MatrixIR.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1078,8 +1078,21 @@ case class MatrixAggregateColsByKey(child: MatrixIR, entryExpr: IR, colExpr: IR)
}
}

case class MatrixMapEntries(child: MatrixIR, newEntries: IR) extends MatrixIR {
/**
  * Matrix IR node with two matrix children, `left` and `right`.
  *
  * The node's type is taken from `left`.
  */
case class MatrixUnionCols(left: MatrixIR, right: MatrixIR) extends MatrixIR {
  def children: IndexedSeq[BaseIR] = Array(left, right)

  /** Rebuilds the node from exactly two children, both of which must be matrix IRs. */
  def copy(newChildren: IndexedSeq[BaseIR]): MatrixUnionCols = {
    assert(newChildren.length == 2)
    val newLeft = newChildren(0).asInstanceOf[MatrixIR]
    val newRight = newChildren(1).asInstanceOf[MatrixIR]
    MatrixUnionCols(newLeft, newRight)
  }

  val typ: MatrixType = left.typ

  /** Sum of both children's column counts, defined only when both counts are known. */
  override def columnCount: Option[Int] =
    for {
      leftCount <- left.columnCount
      rightCount <- right.columnCount
    } yield leftCount + rightCount
}

case class MatrixMapEntries(child: MatrixIR, newEntries: IR) extends MatrixIR {
def children: IndexedSeq[BaseIR] = Array(child, newEntries)

def copy(newChildren: IndexedSeq[BaseIR]): MatrixMapEntries = {
Expand Down
4 changes: 4 additions & 0 deletions hail/src/main/scala/is/hail/expr/ir/Parser.scala
Original file line number Diff line number Diff line change
Expand Up @@ -942,6 +942,10 @@ object IRParser {
val child = matrix_ir(env)(it)
val newEntry = ir_value_expr(env.withRefMap(child.typ.refMap))(it)
MatrixMapEntries(child, newEntry)
case "MatrixUnionCols" =>
val left = matrix_ir(env)(it)
val right = matrix_ir(env)(it)
MatrixUnionCols(left, right)
case "MatrixMapGlobals" =>
val child = matrix_ir(env)(it)
val newGlobals = ir_value_expr(env.withRefMap(child.typ.refMap))(it)
Expand Down
10 changes: 10 additions & 0 deletions hail/src/main/scala/is/hail/expr/ir/PruneDeadFields.scala
Original file line number Diff line number Diff line change
Expand Up @@ -417,6 +417,12 @@ object PruneDeadFields {
case MatrixFilterEntries(child, pred) =>
val irDep = memoizeAndGetDep(pred, pred.typ, child.typ, memo)
memoizeMatrixIR(child, unify(child.typ, requestedType, irDep), memo)
case MatrixUnionCols(left, right) =>
memoizeMatrixIR(left, requestedType, memo)
memoizeMatrixIR(right,
requestedType.copy(globalType = TStruct.empty(),
rvRowType = requestedType.rvRowType.filterSet((requestedType.rowKey :+ MatrixType.entriesIdentifier).toSet)._1),
memo)
case MatrixMapEntries(child, newEntries) =>
val irDep = memoizeAndGetDep(newEntries, requestedType.entryType, child.typ, memo)
val depMod = requestedType.copy(rvRowType = TStruct(requestedType.rvRowType.required, requestedType.rvRowType.fields.map { f =>
Expand Down Expand Up @@ -917,6 +923,10 @@ object PruneDeadFields {
case MatrixFilterEntries(child, pred) =>
val child2 = rebuild(child, memo)
MatrixFilterEntries(child2, rebuild(pred, child2.typ, memo))
case MatrixUnionCols(left, right) =>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is the base case, you can delete.

val left2 = rebuild(left, memo)
val right2 = rebuild(right, memo)
MatrixUnionCols(left2, right2)
case MatrixMapEntries(child, newEntries) =>
val child2 = rebuild(child, memo)
MatrixMapEntries(child2, rebuild(newEntries, child2.typ, memo))
Expand Down
101 changes: 0 additions & 101 deletions hail/src/main/scala/is/hail/variant/MatrixTable.scala
Original file line number Diff line number Diff line change
Expand Up @@ -494,107 +494,6 @@ class MatrixTable(val hc: HailContext, val ast: MatrixIR) {
copyAST(MatrixLiteral(newValue))
}

/**
 * Appends the columns of `right` to this dataset, pairing rows through an
 * "inner" `orderedJoinDistinct` of the two underlying RVDs.
 *
 * Calls `fatal` when the entry schema, column key types, column schema, or row
 * key types of the two datasets differ. Each output row holds the left row's
 * non-entry fields followed by a single entries array containing the left
 * entries and then the right entries.
 *
 * @param right right-hand dataset with which to join
 * @return dataset whose column values are this dataset's followed by `right`'s
 */
def unionCols(right: MatrixTable): MatrixTable = {
// Schema compatibility checks; each mismatch aborts with a descriptive message.
if (entryType != right.entryType) {
fatal(
s"""union_cols: cannot combine datasets with different entry schema
| left entry schema: @1
| right entry schema: @2""".stripMargin,
entryType.toString,
right.entryType.toString)
}

if (!colKeyTypes.sameElements(right.colKeyTypes)) {
fatal(
s"""union_cols: cannot combine datasets with different column key schema
| left column schema: [${ colKeyTypes.map(_.toString).mkString(", ") }]
| right column schema: [${ right.colKeyTypes.map(_.toString).mkString(", ") }]""".stripMargin)
}

if (colType != right.colType) {
fatal(
s"""union_cols: cannot combine datasets with different column schema
| left column schema: @1
| right column schema: @2""".stripMargin,
colType.toString,
right.colType.toString)
}

if (!rowKeyTypes.sameElements(right.rowKeyTypes)) {
fatal(
s"""union_cols: cannot combine datasets with different row key schema
| left row key schema: @1
| right row key schema: @2""".stripMargin,
rowKeyTypes.map(_.toString).mkString(", "),
right.rowKeyTypes.map(_.toString).mkString(", "))
}


val newMatrixType = matrixType.copyParts() // move entries to the end
// Local values read by the joiner closure below.
val newRVRowType = newMatrixType.rvRowType
val leftRVRowType = rvRowType.physicalType
val rightRVRowType = right.rvRowType.physicalType
val localLeftSamples = numCols
val localRightSamples = right.numCols
val leftEntriesIndex = entriesIndex
val rightEntriesIndex = right.entriesIndex
val localEntriesType = matrixType.entryArrayType.physicalType
assert(right.matrixType.entryArrayType == matrixType.entryArrayType)

// Builds one output row per joined (left, right) pair; the same rv2 instance is
// re-set and returned for every element of the iterator.
val joiner = { (ctx: RVDContext, it: Iterator[JoinedRegionValue]) =>
val rvb = ctx.rvb
val rv2 = RegionValue()

it.map { jrv =>
val lrv = jrv.rvLeft
val rrv = jrv.rvRight

rvb.start(newRVRowType.physicalType)
rvb.startStruct()
// Copy every left field except the entries array.
var i = 0
while (i < leftRVRowType.size) {
if (i != leftEntriesIndex)
rvb.addField(leftRVRowType, lrv, i)
i += 1
}
// The combined entries array has one slot per column from each side.
rvb.startArray(localLeftSamples + localRightSamples)

val leftEntriesOffset = leftRVRowType.loadField(lrv.region, lrv.offset, leftEntriesIndex)
val leftEntriesLength = localEntriesType.loadLength(lrv.region, leftEntriesOffset)
assert(leftEntriesLength == localLeftSamples)

val rightEntriesOffset = rightRVRowType.loadField(rrv.region, rrv.offset, rightEntriesIndex)
val rightEntriesLength = localEntriesType.loadLength(rrv.region, rightEntriesOffset)
assert(rightEntriesLength == localRightSamples)

// Left entries first ...
i = 0
while (i < localLeftSamples) {
rvb.addElement(localEntriesType, lrv.region, leftEntriesOffset, i)
i += 1
}

// ... then right entries.
i = 0
while (i < localRightSamples) {
rvb.addElement(localEntriesType, rrv.region, rightEntriesOffset, i)
i += 1
}

rvb.endArray()
rvb.endStruct()
rv2.set(ctx.region, rvb.end())
rv2
}
}

// Column values are concatenated left-then-right, matching the entries order above.
copyMT(matrixType = newMatrixType,
colValues = colValues.copy(value = colValues.value ++ right.colValues.value),
rvd = rvd.orderedJoinDistinct(right.rvd, "inner", joiner, newMatrixType.canonicalRVDType))
}

def aggregateRowsJSON(expr: String): String = {
val (a, t) = aggregateRows(expr)
val jv = JSONAnnotationImpex.exportAnnotation(a, t)
Expand Down