Skip to content

Commit

Permalink
changes to C++ emit code for (coming) ordering (hail-is#4852)
Browse files Browse the repository at this point in the history
* cxx builder changes to prepare for orderings

* fix rebase

* fix bug

* fixed bug

* addressed comments
  • Loading branch information
cseed authored and danking committed Dec 1, 2018
1 parent 97f5aa5 commit 99f691d
Show file tree
Hide file tree
Showing 12 changed files with 257 additions and 257 deletions.
54 changes: 27 additions & 27 deletions hail/src/main/scala/is/hail/cxx/Compile.scala
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package is.hail.cxx

import java.io.PrintWriter

import is.hail.expr.ir
import is.hail.expr.types.physical._
import is.hail.nativecode.{NativeLongFuncL2, NativeModule, NativeStatus}
Expand All @@ -17,7 +19,16 @@ object Compile {
assert(ir.TypeToIRIntermediateClassTag(returnType.virtualType) == classTag[Long])
assert(returnType.isInstanceOf[PBaseStruct])

val fb = FunctionBuilder("f",
val tub = new TranslationUnitBuilder

tub.include("hail/hail.h")
tub.include("hail/Utils.h")
tub.include("hail/Region.h")

tub.include("<limits.h>")
tub.include("<math.h>")

val fb = tub.buildFunction("f",
Array("NativeStatus *" -> "st", "Region *" -> "region", "const char *" -> "v"),
"char *")

Expand All @@ -30,34 +41,23 @@ object Compile {
| abort();
|return ${ v.v };
|""".stripMargin
val f = fb.result()
val f = fb.end()

tub += new Definition {
def name: String = "entrypoint"

def define: String =
s"""
|long entrypoint(NativeStatus *st, long region, long v) {
| return (long)${ f.name }(st, (Region *)region, (char *)v);
|}
""".stripMargin
}

val tu = tub.end()
val mod = tu.build("-ggdb -O1")

val sb = new StringBuilder
sb.append(
s"""
|#include "hail/hail.h"
|#include "hail/Utils.h"
|#include "hail/Region.h"
|
|#include <limits.h>
|#include <math.h>
|
|NAMESPACE_HAIL_MODULE_BEGIN
|
|${ f.define }
|
|long entrypoint(NativeStatus *st, long region, long v) {
| return (long)${ f.name }(st, (Region *)region, (char *)v);
|}
|
|NAMESPACE_HAIL_MODULE_END
|""".stripMargin)

val modCode = sb.toString()

val options = "-ggdb -O1"
val st = new NativeStatus()
val mod = new NativeModule(options, modCode)
mod.findOrBuild(st)
assert(st.ok, st.toString())
val nativef = mod.findLongFuncL2(st, "entrypoint")
Expand Down
63 changes: 15 additions & 48 deletions hail/src/main/scala/is/hail/cxx/Definition.scala
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
package is.hail.cxx

import is.hail.cxx
import is.hail.utils.ArrayBuilder

trait Definition {
def name: String
def typ: String
def define: Code
}

Expand Down Expand Up @@ -53,28 +53,18 @@ class Function(returnType: Type, val name: String, args: Array[Variable], body:
def define: Code = s"$returnType $name(${args.map(a => s"${a.typ} ${a.name}").mkString(", ")}) {\n$body\n}"
}

object FunctionBuilder {

def apply(prefix: String, args: Array[(Type, String)], returnType: Type): FunctionBuilder =
new FunctionBuilder(
prefix,
args.map { case (typ, p) => new Variable(p, typ, null) },
returnType)

def apply(prefix: String, argTypes: Array[Type], returnType: Type): FunctionBuilder =
apply(prefix, argTypes.map(_ -> genSym("arg")), returnType)
}

class FunctionBuilder(prefix: String, args: Array[Variable], returnType: Type) {
class FunctionBuilder(val parent: ScopeBuilder, prefix: String, args: Array[Variable], returnType: Type)
extends DefinitionBuilder[Function] {

val statements: ArrayBuilder[Code] = new ArrayBuilder[Code]()

def +=(statement: Code) =
def +=(statement: Code) {
statements += statement
}

def getArg(i: Int): Variable = args(i)

def result(): Function = new Function(returnType, prefix, args, statements.result().mkString("\n"))
def build(): Function = new Function(returnType, prefix, args, statements.result().mkString("\n"))

def defaultReturn: Code = {
if (returnType == "long")
Expand All @@ -90,48 +80,25 @@ class FunctionBuilder(prefix: String, args: Array[Variable], returnType: Type) {

}

class Class(val name: String, superClass: String, privateDefs: Array[Definition], publicDefs: Array[Definition]) extends Definition {
class Class(val name: String, superClass: String, definitions: Array[Code]) extends Definition {
def typ: Type = name

override def toString: Type = name

def addSuperclass(newSuper: String): Class = new Class(name, newSuper, privateDefs, publicDefs)
def addSuperclass(newSuper: String): Class = new Class(name, newSuper, definitions)

def define: Code =
s"""class $name${ if (superClass == null) "" else s" : public $superClass" } {
| private:
| ${ privateDefs.map(_.define).mkString("\n") }
| public:
| ${ publicDefs.map(_.define).mkString("\n") }
| ${ definitions.mkString("\n") }
|};
""".stripMargin
}

class ClassBuilder(val name: String, superClass: String = null) {
private[this] val privateDefs = new ArrayBuilder[Definition]()
private[this] val publicDefs = new ArrayBuilder[Definition]()

def +=(d: Definition) { publicDefs += d }
class ClassBuilder(val parent: ScopeBuilder, val name: String, superClass: String = null)
extends ScopeBuilder with DefinitionBuilder[Class] {
def build(): Class = new Class(name, superClass, definitions.result())

def addPrivate(d: Definition) { privateDefs += d }

def result(): Class = new Class(name, superClass, privateDefs.result(), publicDefs.result())

def addConstructor(definition: Code) {
val className = name
publicDefs += new Definition {
def name: String = className
def typ: Type = name
def define: Code = definition
}
}

def addDestructor(definition: Code) {
val className = name
publicDefs += new Definition {
def name: String = className
def typ: Type = name
def define: Code = definition
}
}
}
def buildMethod(prefix: String, args: Array[(cxx.Type, String)], returnType: Type): FunctionBuilder =
new FunctionBuilder(this, prefix, args.map { case (typ, p) => new Variable(p, typ, null) }, returnType)
}
33 changes: 16 additions & 17 deletions hail/src/main/scala/is/hail/cxx/PackDecoder.scala
Original file line number Diff line number Diff line change
Expand Up @@ -175,15 +175,15 @@ object PackDecoder {
tub.include("<cstdio>")
tub.include("<memory>")

val decoderBuilder = new ClassBuilder(genSym("Decoder"), "NativeObj")
val decoderBuilder = tub.buildClass(genSym("Decoder"), "NativeObj")

val bufType = bufSpec.nativeInputBufferType
val buf = Variable("buf", s"std::shared_ptr<$bufType>")
decoderBuilder.addPrivate(buf)
decoderBuilder += buf

decoderBuilder.addConstructor(s"${ decoderBuilder.name }(std::shared_ptr<InputStream> is) : $buf(std::make_shared<$bufType>(is)) { }")
decoderBuilder += s"${ decoderBuilder.name }(std::shared_ptr<InputStream> is) : $buf(std::make_shared<$bufType>(is)) { }"

val rowFB = FunctionBuilder("decode_row", Array("NativeStatus*" -> "st", "Region *" -> "region"), "char *")
val rowFB = decoderBuilder.buildMethod("decode_row", Array("NativeStatus*" -> "st", "Region *" -> "region"), "char *")
val region = rowFB.getArg(1)
val initialSize = rt match {
case _: PArray | _: PBinary => 8
Expand All @@ -196,41 +196,40 @@ object PackDecoder {
case _: PArray | _: PBinary => s"return load_address($row);"
case _ => s"return $row;"
})
decoderBuilder += rowFB.result()
rowFB.end()

val byteFB = FunctionBuilder("decode_byte", Array("NativeStatus*" -> "st"), "char")
val byteFB = decoderBuilder.buildMethod("decode_byte", Array("NativeStatus*" -> "st"), "char")
byteFB += s"return $buf->read_byte();"
decoderBuilder += byteFB.result()
byteFB.end()

decoderBuilder.result()
decoderBuilder.end()
}

def buildModule(t: PType, rt: PType, bufSpec: BufferSpec): NativeDecoderModule = {
assert(t.isInstanceOf[PBaseStruct] || t.isInstanceOf[PArray])
val tub = new TranslationUnitBuilder()

val decoder = apply(t, rt, bufSpec, tub)
tub += decoder


tub.include("hail/Decoder.h")
tub.include("hail/ObjectArray.h")
tub.include("<memory>")

val inBufFB = FunctionBuilder("make_input_buffer", Array("NativeStatus*" -> "st", "long" -> "objects"), "NativeObjPtr")
val inBufFB = tub.buildFunction("make_input_buffer", Array("NativeStatus*" -> "st", "long" -> "objects"), "NativeObjPtr")
inBufFB += "UpcallEnv up;"
inBufFB += s"auto jinput_stream = reinterpret_cast<ObjectArray*>(${ inBufFB.getArg(1) })->at(0);"
inBufFB += s"return std::make_shared<$decoder>(std::make_shared<InputStream>(up, jinput_stream));"
tub += inBufFB.result()
inBufFB.end()

val rowFB = FunctionBuilder("decode_row", Array("NativeStatus*" -> "st", "long" -> "buf", "long" -> "region"), "long")
val rowFB = tub.buildFunction("decode_row", Array("NativeStatus*" -> "st", "long" -> "buf", "long" -> "region"), "long")
rowFB += s"return (long) reinterpret_cast<$decoder *>(${ rowFB.getArg(1) })->decode_row(${ rowFB.getArg(0) }, reinterpret_cast<Region *>(${ rowFB.getArg(2) }));"
tub += rowFB.result()
rowFB.end()

val byteFB = FunctionBuilder("decode_byte", Array("NativeStatus*" -> "st", "long" -> "buf"), "long")
val byteFB = tub.buildFunction("decode_byte", Array("NativeStatus*" -> "st", "long" -> "buf"), "long")
byteFB += s"return (long) reinterpret_cast<$decoder *>(${ byteFB.getArg(1) })->decode_byte(${ byteFB.getArg(0) });"
tub += byteFB.result()
byteFB.end()

val mod = tub.result().build("-O2")
val mod = tub.end().build("-O2")

NativeDecoderModule(mod.getKey, mod.getBinary)
}
Expand Down
49 changes: 24 additions & 25 deletions hail/src/main/scala/is/hail/cxx/PackEncoder.scala
Original file line number Diff line number Diff line change
Expand Up @@ -94,75 +94,74 @@ object PackEncoder {
tub.include("<cstdio>")
tub.include("<memory>")

val encBuilder = new ClassBuilder("Encoder", "NativeObj")
val encBuilder = tub.buildClass("Encoder", "NativeObj")

val bufType = bufSpec.nativeOutputBufferType
val buf = Variable("buf", s"std::shared_ptr<$bufType>")
encBuilder.addPrivate(buf)
encBuilder += buf

encBuilder.addConstructor(s"${ encBuilder.name }(std::shared_ptr<OutputStream> os) : $buf(std::make_shared<$bufType>(os)) { }")
encBuilder += s"${ encBuilder.name }(std::shared_ptr<OutputStream> os) : $buf(std::make_shared<$bufType>(os)) { }"

val rowFB = FunctionBuilder("encode_row", Array("NativeStatus*" -> "st", "char const*" -> "row"), "void")
val rowFB = encBuilder.buildMethod("encode_row", Array("NativeStatus*" -> "st", "const char *" -> "row"), "void")
rowFB += encode(t.fundamentalType, buf.ref, rowFB.getArg(1).ref)
rowFB += "return;"
encBuilder += rowFB.result()
rowFB.end()

val byteFB = FunctionBuilder("encode_byte", Array("NativeStatus*" -> "st", "char" -> "b"), "void")
val byteFB = encBuilder.buildMethod("encode_byte", Array("NativeStatus*" -> "st", "char" -> "b"), "void")
byteFB += s"$buf->write_byte(${ byteFB.getArg(1) });"
byteFB += "return;"
encBuilder += byteFB.result()
byteFB.end()

val flushFB = FunctionBuilder("flush", Array("NativeStatus*" -> "st"), "void")
val flushFB = encBuilder.buildMethod("flush", Array("NativeStatus*" -> "st"), "void")
flushFB += s"$buf->flush();"
flushFB += "return;"
encBuilder += flushFB.result()
flushFB.end()

val closeFB = FunctionBuilder("close", Array("NativeStatus*" -> "st"), "void")
val closeFB = encBuilder.buildMethod("close", Array("NativeStatus*" -> "st"), "void")
closeFB +=
s"""
|$buf->close();
|return;""".stripMargin
encBuilder += closeFB.result()
closeFB.end()

encBuilder.result()
encBuilder.end()
}

def buildModule(t: PType, bufSpec: BufferSpec): NativeEncoderModule = {
assert(t.isInstanceOf[PBaseStruct] || t.isInstanceOf[PArray])
val tub = new TranslationUnitBuilder()

val encClass = apply(t, bufSpec, tub)
tub += encClass

val outBufFB = FunctionBuilder("makeOutputBuffer", Array("NativeStatus*" -> "st", "long" -> "objects"), "NativeObjPtr")
val outBufFB = tub.buildFunction("makeOutputBuffer", Array("NativeStatus*" -> "st", "long" -> "objects"), "NativeObjPtr")
outBufFB += "UpcallEnv up;"
outBufFB += s"auto joutput_stream = reinterpret_cast<ObjectArray*>(${ outBufFB.getArg(1) })->at(0);"
val bufType = bufSpec.nativeOutputBufferType
outBufFB += s"return std::make_shared<$encClass>(std::make_shared<OutputStream>(up, joutput_stream));"
tub += outBufFB.result()
outBufFB.end()

val rowFB = FunctionBuilder("encode_row", Array("NativeStatus*" -> "st", "long" -> "buf", "long" -> "row"), "long")
val rowFB = tub.buildFunction("encode_row", Array("NativeStatus*" -> "st", "long" -> "buf", "long" -> "row"), "long")
rowFB += s"reinterpret_cast<$encClass *>(${ rowFB.getArg(1) })->encode_row(${ rowFB.getArg(0) }, reinterpret_cast<char *>(${ rowFB.getArg(2) }));"
rowFB += "return 0;"
tub += rowFB.result()
rowFB.end()

val byteFB = FunctionBuilder("encode_byte", Array("NativeStatus*" -> "st", "long" -> "buf", "long" -> "b"), "long")
val byteFB = tub.buildFunction("encode_byte", Array("NativeStatus*" -> "st", "long" -> "buf", "long" -> "b"), "long")
byteFB += s"reinterpret_cast<$encClass *>(${ byteFB.getArg(1) })->encode_byte(${ byteFB.getArg(0) }, ${ byteFB.getArg(2) } & 0xff);"
byteFB += "return 0;"
tub += byteFB.result()
byteFB.end()

val flushFB = FunctionBuilder("encoder_flush", Array("NativeStatus*" -> "st", "long" -> "buf"), "long")
val flushFB = tub.buildFunction("encoder_flush", Array("NativeStatus*" -> "st", "long" -> "buf"), "long")
flushFB += s"reinterpret_cast<$encClass *>(${ flushFB.getArg(1) })->flush(${ flushFB.getArg(0) });"
flushFB += "return 0;"
tub += flushFB.result()
flushFB.end()

val closeFB = FunctionBuilder("encoder_close", Array("NativeStatus*" -> "st", "long" -> "buf"), "long")
val closeFB = tub.buildFunction("encoder_close", Array("NativeStatus*" -> "st", "long" -> "buf"), "long")
closeFB += s"reinterpret_cast<$encClass *>(${ closeFB.getArg(1) })->close(${ closeFB.getArg(0) });"
closeFB += "return 0;"
tub += closeFB.result()
closeFB.end()

val mod = tub.result().build("-O1")
val mod = tub.end().build("-O1")

NativeEncoderModule(mod.getKey, mod.getBinary)
}
}
}
4 changes: 1 addition & 3 deletions hail/src/main/scala/is/hail/cxx/RVDEmitTriplet.scala
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ object RVDEmitTriplet {
tub: TranslationUnitBuilder
): RVDEmitTriplet = {
val decoder = codecSpec.buildNativeDecoderClass(t, requestedType.rowType, tub)
tub += decoder
tub.include("hail/Decoder.h")
tub.include("hail/ObjectArray.h")
tub.include("<memory>")
Expand Down Expand Up @@ -64,7 +63,6 @@ object RVDEmitTriplet {

def write[T](t: RVDEmitTriplet, tub: TranslationUnitBuilder, path: String, stageLocally: Boolean, codecSpec: CodecSpec): Array[Long] = {
val encClass = codecSpec.buildNativeEncoderClass(t.typ.rowType, tub)
tub += encClass
tub.include("hail/Encoder.h")

val os = Variable("os", "long")
Expand All @@ -91,7 +89,7 @@ object RVDEmitTriplet {
|return $nRows;
""".stripMargin)

val mod = tub.result().build("-O2 -llz4")
val mod = tub.end().build("-O2 -llz4")
val modKey = mod.getKey
val modBinary = mod.getBinary

Expand Down
Loading

0 comments on commit 99f691d

Please sign in to comment.