Skip to content

Commit b8ca932

Browse files
committed
added MatrixUnionCols IR
1 parent b863e77 commit b8ca932

File tree

9 files changed

+87
-105
lines changed

9 files changed

+87
-105
lines changed

hail/python/hail/expr/types.py

+1
Original file line numberDiff line numberDiff line change
@@ -866,6 +866,7 @@ def _is_prefix_of(self, other):
866866
len(self._fields) <= len(other._fields) and
867867
all(x == y for x, y in zip(self._field_types.values(), other._field_types.values())))
868868

869+
869870
class ttuple(HailType):
870871
"""Hail type for tuples.
871872

hail/python/hail/ir/matrix_ir.py

+8
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,14 @@ def render(self, r):
5555
'(' + ' '.join(f'"{escape_str(f)}"' for f in self.new_key) + ')' if self.new_key is not None else 'None',
5656
r(self.child), r(self.new_col))
5757

58+
class MatrixUnionCols(MatrixIR):
59+
def __init__(self, left, right):
60+
self.left = left
61+
self.right = right
62+
63+
def render(self, r):
64+
return f'(MatrixUnionCols {r(self.left)} {r(self.right)})'
65+
5866
class MatrixMapEntries(MatrixIR):
5967
def __init__(self, child, new_entry):
6068
super().__init__()

hail/python/hail/matrixtable.py

+18-1
Original file line numberDiff line numberDiff line change
@@ -3104,7 +3104,24 @@ def union_cols(self, other: 'MatrixTable') -> 'MatrixTable':
31043104
:class:`.MatrixTable`
31053105
Dataset with columns from both datasets.
31063106
"""
3107-
return MatrixTable._from_java(self._jmt.unionCols(other._jmt))
3107+
if self._entry_type != other._entry_type:
3108+
raise ValueError('entry types differ\n'
3109+
f' left: {self._entry_type}\n'
3110+
f' right: {other.entry_type}')
3111+
if self._col_type != other._col_type:
3112+
raise ValueError('column types differ\n'
3113+
f' left: {self._col_type}\n'
3114+
f' right: {other.col_type}')
3115+
if list(self.col_key.values()) != list(other.col_key.values()):
3116+
raise ValueError('column key types differ\n'
3117+
f' left: {', '.join(self.col_key.values()}\n'
3118+
f' right: {', '.join(other.col_key.values()}')
3119+
if list(self.row_key.values()) != list(other.row_key.values()):
3120+
raise ValueError('row key types differ\n'
3121+
f' left: {', '.join(self.row_key.values()}\n'
3122+
f' right: {', '.join(other.row_key.values()}')
3123+
3124+
return MatrixTable(MatrixUnionCols(self._mir, other._mir))
31083125

31093126
@typecheck_method(n=int)
31103127
def head(self, n: int) -> 'MatrixTable':

hail/src/main/scala/is/hail/expr/ir/IRBuilder.scala

+9-1
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,8 @@ object IRBuilder {
4141
def irIf(cond: IRProxy)(cnsq: IRProxy)(altr: IRProxy): IRProxy = (env: E) =>
4242
If(cond(env), cnsq(env), altr(env))
4343

44+
def makeArray(first: IRProxy, rest: IRProxy*): IRProxy = arrayToProxy(first +: rest)
45+
4446
def makeStruct(fields: (Symbol, IRProxy)*): IRProxy = (env: E) =>
4547
MakeStruct(fields.map { case (s, ir) => (s.name, ir(env)) })
4648

@@ -124,7 +126,13 @@ object IRBuilder {
124126
def map(f: LambdaProxy): IRProxy = (env: E) => {
125127
val array = ir(env)
126128
val eltType = array.typ.asInstanceOf[TArray].elementType
127-
ArrayMap(ir(env), f.s.name, f.body(env.bind(f.s.name -> eltType)))
129+
ArrayMap(array, f.s.name, f.body(env.bind(f.s.name -> eltType)))
130+
}
131+
132+
def flatMap(f: LambdaProxy): IRProxy = (env: E) => {
133+
val array = ir(env)
134+
val eltType = array.typ.asInstanceOf[TArray].elementType
135+
ArrayFlatMap(array, f.s.name, f.body(env.bind(f.s.name -> eltType)))
128136
}
129137

130138
def sort(ascending: IRProxy, onKey: Boolean = false): IRProxy = (env: E) => ArraySort(ir(env), ascending(env), onKey)

hail/src/main/scala/is/hail/expr/ir/LowerMatrixIR.scala

+23-1
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,28 @@ object LowerMatrixIR {
147147
}
148148
}))
149149

150+
case MatrixUnionCols(left, right) =>
151+
val rightEntries = genUID()
152+
val rightCols = genUID()
153+
TableJoin(
154+
lower(left),
155+
lower(right)
156+
.mapRows('row
157+
.insertFields(Symbol(rightEntries) -> 'row(entriesField))
158+
.selectFields(right.typ.rowKey :+ rightEntries: _*))
159+
.mapGlobals('global
160+
.insertFields(Symbol(rightCols) -> 'global(colsField))
161+
.selectFields(rightCols)),
162+
"inner")
163+
.mapRows('row
164+
.insertFields(entriesField ->
165+
makeArray('row(entriesField), 'row(Symbol(rightEntries))).flatMap('a ~> 'a))
166+
.dropFields(Symbol(rightEntries)))
167+
.mapGlobals('global
168+
.insertFields(colsField ->
169+
makeArray('global(colsField), 'global(Symbol(rightCols))).flatMap('a ~> 'a))
170+
.dropFields(Symbol(rightCols)))
171+
150172
case MatrixMapEntries(child, newEntries) =>
151173
lower(child).mapRows('row.insertFields(entriesField ->
152174
irRange(0, 'global(colsField).len).map { 'i ~>
@@ -176,7 +198,7 @@ object LowerMatrixIR {
176198
.mapRows('row.insertFields(entriesField ->
177199
'global('newColIdx).map { 'kv ~>
178200
makeStruct(child.typ.entryType.fieldNames.map { s =>
179-
(Symbol(s), 'kv('value).map { 'i ~> 'row(entriesField)('i)(Symbol(s))})}: _*)
201+
(Symbol(s), 'kv('value).map { 'i ~> 'row(entriesField)('i)(Symbol(s)) }) }: _*)
180202
}))
181203
.mapGlobals('global
182204
.insertFields(colsField ->

hail/src/main/scala/is/hail/expr/ir/MatrixIR.scala

+14-1
Original file line numberDiff line numberDiff line change
@@ -1078,8 +1078,21 @@ case class MatrixAggregateColsByKey(child: MatrixIR, entryExpr: IR, colExpr: IR)
10781078
}
10791079
}
10801080

1081-
case class MatrixMapEntries(child: MatrixIR, newEntries: IR) extends MatrixIR {
1081+
case class MatrixUnionCols(left: MatrixIR, right: MatrixIR) extends MatrixIR {
1082+
def children: IndexedSeq[BaseIR] = Array(left, right)
1083+
1084+
def copy(newChildren: IndexedSeq[BaseIR]): MatrixUnionCols = {
1085+
assert(newChildren.length == 2)
1086+
MatrixUnionCols(newChildren(0).asInstanceOf[MatrixIR], newChildren(1).asInstanceOf[MatrixIR])
1087+
}
1088+
1089+
val typ: MatrixType = left.typ
10821090

1091+
override def columnCount: Option[Int] =
1092+
left.columnCount.flatMap(leftCount => right.columnCount.map(rightCount => leftCount + rightCount))
1093+
}
1094+
1095+
case class MatrixMapEntries(child: MatrixIR, newEntries: IR) extends MatrixIR {
10831096
def children: IndexedSeq[BaseIR] = Array(child, newEntries)
10841097

10851098
def copy(newChildren: IndexedSeq[BaseIR]): MatrixMapEntries = {

hail/src/main/scala/is/hail/expr/ir/Parser.scala

+4
Original file line numberDiff line numberDiff line change
@@ -942,6 +942,10 @@ object IRParser {
942942
val child = matrix_ir(env)(it)
943943
val newEntry = ir_value_expr(env.withRefMap(child.typ.refMap))(it)
944944
MatrixMapEntries(child, newEntry)
945+
case "MatrixUnionCols" =>
946+
val left = matrix_ir(env)(it)
947+
val right = matrix_ir(env)(it)
948+
MatrixUnionCols(left, right)
945949
case "MatrixMapGlobals" =>
946950
val child = matrix_ir(env)(it)
947951
val newGlobals = ir_value_expr(env.withRefMap(child.typ.refMap))(it)

hail/src/main/scala/is/hail/expr/ir/PruneDeadFields.scala

+10
Original file line numberDiff line numberDiff line change
@@ -417,6 +417,12 @@ object PruneDeadFields {
417417
case MatrixFilterEntries(child, pred) =>
418418
val irDep = memoizeAndGetDep(pred, pred.typ, child.typ, memo)
419419
memoizeMatrixIR(child, unify(child.typ, requestedType, irDep), memo)
420+
case MatrixUnionCols(left, right) =>
421+
memoizeMatrixIR(left, requestedType, memo)
422+
memoizeMatrixIR(right,
423+
requestedType.copy(globalType = TStruct.empty(),
424+
rvRowType = requestedType.rvRowType.filterSet((requestedType.rowKey :+ MatrixType.entriesIdentifier).toSet)._1),
425+
memo)
420426
case MatrixMapEntries(child, newEntries) =>
421427
val irDep = memoizeAndGetDep(newEntries, requestedType.entryType, child.typ, memo)
422428
val depMod = requestedType.copy(rvRowType = TStruct(requestedType.rvRowType.required, requestedType.rvRowType.fields.map { f =>
@@ -917,6 +923,10 @@ object PruneDeadFields {
917923
case MatrixFilterEntries(child, pred) =>
918924
val child2 = rebuild(child, memo)
919925
MatrixFilterEntries(child2, rebuild(pred, child2.typ, memo))
926+
case MatrixUnionCols(left, right) =>
927+
val left2 = rebuild(left, memo)
928+
val right2 = rebuild(right, memo)
929+
MatrixUnionCols(left2, right2)
920930
case MatrixMapEntries(child, newEntries) =>
921931
val child2 = rebuild(child, memo)
922932
MatrixMapEntries(child2, rebuild(newEntries, child2.typ, memo))

hail/src/main/scala/is/hail/variant/MatrixTable.scala

-101
Original file line numberDiff line numberDiff line change
@@ -494,107 +494,6 @@ class MatrixTable(val hc: HailContext, val ast: MatrixIR) {
494494
copyAST(MatrixLiteral(newValue))
495495
}
496496

497-
/**
498-
*
499-
* @param right right-hand dataset with which to join
500-
*/
501-
def unionCols(right: MatrixTable): MatrixTable = {
502-
if (entryType != right.entryType) {
503-
fatal(
504-
s"""union_cols: cannot combine datasets with different entry schema
505-
| left entry schema: @1
506-
| right entry schema: @2""".stripMargin,
507-
entryType.toString,
508-
right.entryType.toString)
509-
}
510-
511-
if (!colKeyTypes.sameElements(right.colKeyTypes)) {
512-
fatal(
513-
s"""union_cols: cannot combine datasets with different column key schema
514-
| left column schema: [${ colKeyTypes.map(_.toString).mkString(", ") }]
515-
| right column schema: [${ right.colKeyTypes.map(_.toString).mkString(", ") }]""".stripMargin)
516-
}
517-
518-
if (colType != right.colType) {
519-
fatal(
520-
s"""union_cols: cannot combine datasets with different column schema
521-
| left column schema: @1
522-
| right column schema: @2""".stripMargin,
523-
colType.toString,
524-
right.colType.toString)
525-
}
526-
527-
if (!rowKeyTypes.sameElements(right.rowKeyTypes)) {
528-
fatal(
529-
s"""union_cols: cannot combine datasets with different row key schema
530-
| left row key schema: @1
531-
| right row key schema: @2""".stripMargin,
532-
rowKeyTypes.map(_.toString).mkString(", "),
533-
right.rowKeyTypes.map(_.toString).mkString(", "))
534-
}
535-
536-
537-
val newMatrixType = matrixType.copyParts() // move entries to the end
538-
val newRVRowType = newMatrixType.rvRowType
539-
val leftRVRowType = rvRowType.physicalType
540-
val rightRVRowType = right.rvRowType.physicalType
541-
val localLeftSamples = numCols
542-
val localRightSamples = right.numCols
543-
val leftEntriesIndex = entriesIndex
544-
val rightEntriesIndex = right.entriesIndex
545-
val localEntriesType = matrixType.entryArrayType.physicalType
546-
assert(right.matrixType.entryArrayType == matrixType.entryArrayType)
547-
548-
val joiner = { (ctx: RVDContext, it: Iterator[JoinedRegionValue]) =>
549-
val rvb = ctx.rvb
550-
val rv2 = RegionValue()
551-
552-
it.map { jrv =>
553-
val lrv = jrv.rvLeft
554-
val rrv = jrv.rvRight
555-
556-
rvb.start(newRVRowType.physicalType)
557-
rvb.startStruct()
558-
var i = 0
559-
while (i < leftRVRowType.size) {
560-
if (i != leftEntriesIndex)
561-
rvb.addField(leftRVRowType, lrv, i)
562-
i += 1
563-
}
564-
rvb.startArray(localLeftSamples + localRightSamples)
565-
566-
val leftEntriesOffset = leftRVRowType.loadField(lrv.region, lrv.offset, leftEntriesIndex)
567-
val leftEntriesLength = localEntriesType.loadLength(lrv.region, leftEntriesOffset)
568-
assert(leftEntriesLength == localLeftSamples)
569-
570-
val rightEntriesOffset = rightRVRowType.loadField(rrv.region, rrv.offset, rightEntriesIndex)
571-
val rightEntriesLength = localEntriesType.loadLength(rrv.region, rightEntriesOffset)
572-
assert(rightEntriesLength == localRightSamples)
573-
574-
i = 0
575-
while (i < localLeftSamples) {
576-
rvb.addElement(localEntriesType, lrv.region, leftEntriesOffset, i)
577-
i += 1
578-
}
579-
580-
i = 0
581-
while (i < localRightSamples) {
582-
rvb.addElement(localEntriesType, rrv.region, rightEntriesOffset, i)
583-
i += 1
584-
}
585-
586-
rvb.endArray()
587-
rvb.endStruct()
588-
rv2.set(ctx.region, rvb.end())
589-
rv2
590-
}
591-
}
592-
593-
copyMT(matrixType = newMatrixType,
594-
colValues = colValues.copy(value = colValues.value ++ right.colValues.value),
595-
rvd = rvd.orderedJoinDistinct(right.rvd, "inner", joiner, newMatrixType.canonicalRVDType))
596-
}
597-
598497
def makeTable(separator: String = "."): Table = {
599498
matrixType.requireColKeyString()
600499
requireUniqueSamples("make_table")

0 commit comments

Comments
 (0)