Skip to content

GRPCGATEWAY-20 match against named URL template slots correctly #21

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 10 additions & 5 deletions build.sbt
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
import com.trueaccord.scalapb.compiler.Version.{grpcJavaVersion, scalapbVersion}

organization in ThisBuild := "beyondthelines"
version in ThisBuild := "0.0.8"
version in ThisBuild := "0.0.10-SNAPSHOT"
licenses in ThisBuild := ("MIT", url("http://opensource.org/licenses/MIT")) :: Nil
bintrayOrganization in ThisBuild := Some("beyondthelines")
bintrayPackageLabels in ThisBuild := Seq("scala", "protobuf", "grpc")
scalaVersion in ThisBuild := "2.12.4"

val googleapisVersion = "0.0.3"
val scalatestVersion = "3.0.4"

lazy val runtime = (project in file("runtime"))
.settings(
crossScalaVersions := Seq("2.12.4", "2.11.11"),
Expand All @@ -16,8 +19,9 @@ lazy val runtime = (project in file("runtime"))
"com.trueaccord.scalapb" %% "scalapb-runtime-grpc" % scalapbVersion,
"com.trueaccord.scalapb" %% "scalapb-json4s" % "0.3.3",
"io.grpc" % "grpc-netty" % grpcJavaVersion,
"org.scalatest" %% "scalatest" % scalatestVersion % Test,
"org.webjars" % "swagger-ui" % "3.5.0",
"com.google.api.grpc" % "googleapis-common-protos" % "0.0.3" % "protobuf"
"com.google.api.grpc" % "googleapis-common-protos" % googleapisVersion % "protobuf"
),
PB.protoSources in Compile += target.value / "protobuf_external",
includeFilter in PB.generate := new SimpleFilter(
Expand All @@ -35,9 +39,10 @@ lazy val generator = (project in file("generator"))
crossScalaVersions := Seq("2.12.4", "2.10.6"),
name := "GrpcGatewayGenerator",
libraryDependencies ++= Seq(
"com.trueaccord.scalapb" %% "compilerplugin" % scalapbVersion,
"com.trueaccord.scalapb" %% "scalapb-runtime-grpc" % scalapbVersion,
"com.google.api.grpc" % "googleapis-common-protos" % "0.0.3" % "protobuf"
"com.trueaccord.scalapb" %% "compilerplugin" % scalapbVersion,
"com.trueaccord.scalapb" %% "scalapb-runtime-grpc" % scalapbVersion,
"com.google.api.grpc" % "googleapis-common-protos" % googleapisVersion % "protobuf",
"org.scalatest" %% "scalatest" % scalatestVersion % Test
),
PB.protoSources in Compile += target.value / "protobuf_external",
includeFilter in PB.generate := new SimpleFilter(
Expand Down
187 changes: 116 additions & 71 deletions generator/src/main/scala/grpcgateway/generators/GatewayGenerator.scala
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
package grpcgateway.generators

import com.google.api.AnnotationsProto
import com.google.api.{AnnotationsProto, HttpRule}
import com.google.api.HttpRule.PatternCase
import com.google.protobuf.Descriptors.FieldDescriptor.JavaType
import com.google.protobuf.Descriptors._
Expand All @@ -10,10 +10,11 @@ import com.trueaccord.scalapb.compiler.FunctionalPrinter.PrinterEndo
import com.trueaccord.scalapb.compiler.{DescriptorPimps, FunctionalPrinter}

import scala.collection.JavaConverters._
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
import scalapbshade.v0_6_7.com.trueaccord.scalapb.Scalapb

object GatewayGenerator extends protocbridge.ProtocCodeGenerator with DescriptorPimps {

override val params = com.trueaccord.scalapb.compiler.GeneratorParams()

override def run(requestBytes: Array[Byte]): Array[Byte] = {
Expand All @@ -23,7 +24,6 @@ object GatewayGenerator extends protocbridge.ProtocCodeGenerator with Descriptor

val b = CodeGeneratorResponse.newBuilder
val request = CodeGeneratorRequest.parseFrom(requestBytes, registry)

val fileDescByName: Map[String, FileDescriptor] =
request.getProtoFileList.asScala.foldLeft[Map[String, FileDescriptor]](Map.empty) {
case (acc, fp) =>
Expand All @@ -50,16 +50,16 @@ object GatewayGenerator extends protocbridge.ProtocCodeGenerator with Descriptor
.add(
"import _root_.com.trueaccord.scalapb.GeneratedMessage",
"import _root_.com.trueaccord.scalapb.json.JsonFormat",
"import _root_.grpcgateway.handlers._",
"import _root_.io.grpc._",
"import _root_.io.netty.handler.codec.http.{HttpMethod, QueryStringDecoder}"
"import _root_.grpcgateway.handlers.GrpcGatewayHandler",
"import _root_.grpcgateway.handlers.jsonException2GatewayExceptionPF",
"import _root_.io.grpc.ManagedChannel",
"import _root_.io.netty.handler.codec.http.HttpMethod"
)
.newline
.add(
"import scala.collection.JavaConverters._",
"import scala.concurrent.{ExecutionContext, Future}",
"import com.trueaccord.scalapb.json.JsonFormatException",
"import scala.util._"
"import grpcgateway.util.{RestfulUrl, UrlTemplate}",
"import scala.util.Try"
)
.newline
.print(fileDesc.getServices.asScala) { case (p, s) => generateService(s)(p) }
Expand All @@ -73,13 +73,14 @@ object GatewayGenerator extends protocbridge.ProtocCodeGenerator with Descriptor
_.add(s"class ${service.getName}Handler(channel: ManagedChannel)(implicit ec: ExecutionContext)").indent
.add(
"extends GrpcGatewayHandler(channel)(ec) {",
"// a function that takes a RestfulUrl and produces a function that takes a request body and returns a response message",
"type RestfulHandler = RestfulUrl => (String) => Future[GeneratedMessage]",
"",
s"""override val name: String = "${service.getName}"""",
s"private val stub = ${service.getName}Grpc.stub(channel)"
)
.newline
.call(generateSupportsCall(service))
.newline
.call(generateUnaryCall(service))
.call(generateCallSeqsByVerb(getUnaryCallsWithHttpExtension(service)))
.outdent
.add("}")
.newline
Expand All @@ -91,61 +92,12 @@ object GatewayGenerator extends protocbridge.ProtocCodeGenerator with Descriptor
}
}

private def generateUnaryCall(service: ServiceDescriptor): PrinterEndo = { printer =>
val methods = getUnaryCallsWithHttpExtension(service)
printer
.add(s"override def unaryCall(method: HttpMethod, uri: String, body: String): Future[GeneratedMessage] = {")
.indent
.add(
"val queryString = new QueryStringDecoder(uri)",
"(method.name, queryString.path) match {"
)
.indent
.print(methods) { case (p, m) => generateMethodHandlerCase(m)(p) }
.add("case (methodName, path) => ")
.addIndented("""Future.failed(InvalidArgument(s"No route defined for $methodName($path)"))""")
.outdent
.add("}")
.outdent
.add("}")
}

private def generateSupportsCall(service: ServiceDescriptor): PrinterEndo = { printer =>
val methods = getUnaryCallsWithHttpExtension(service)
printer
.add(s"override def supportsCall(method: HttpMethod, uri: String): Boolean = {")
.indent
.add(
"val queryString = new QueryStringDecoder(uri)",
"(method.name, queryString.path) match {"
)
.indent
.print(methods) { case (p, m) => generateMethodCase(m)(p) }
.add("case _ => false")
.outdent
.add("}")
.outdent
.add("}")
}

private def generateMethodCase(method: MethodDescriptor): PrinterEndo = { printer =>
val http = method.getOptions.getExtension(AnnotationsProto.http)
http.getPatternCase match {
case PatternCase.GET => printer.add(s"""case ("GET", "${http.getGet}") => true""")
case PatternCase.POST => printer.add(s"""case ("POST", "${http.getPost}") => true""")
case PatternCase.PUT => printer.add(s"""case ("PUT", "${http.getPut}") => true""")
case PatternCase.DELETE => printer.add(s"""case ("DELETE", "${http.getDelete}") => true""")
case _ => printer
}
}

private def generateMethodHandlerCase(method: MethodDescriptor): PrinterEndo = { printer =>
val http = method.getOptions.getExtension(AnnotationsProto.http)
val methodName = method.getName.charAt(0).toLower + method.getName.substring(1)
http.getPatternCase match {
case PatternCase.GET =>
printer
.add(s"""case ("GET", "${http.getGet}") => """)
.indent
.add("val input = Try {")
.indent
Expand All @@ -156,7 +108,6 @@ object GatewayGenerator extends protocbridge.ProtocCodeGenerator with Descriptor
.outdent
case PatternCase.POST =>
printer
.add(s"""case ("POST", "${http.getPost}") => """)
.add("for {")
.addIndented(
s"""msg <- Future.fromTry(Try(JsonFormat.fromJsonString[${method.getInputType.getName}](body)).recoverWith(jsonException2GatewayExceptionPF))""",
Expand All @@ -165,7 +116,6 @@ object GatewayGenerator extends protocbridge.ProtocCodeGenerator with Descriptor
.add("} yield res")
case PatternCase.PUT =>
printer
.add(s"""case ("PUT", "${http.getPut}") => """)
.add("for {")
.addIndented(
s"""msg <- Future.fromTry(Try(JsonFormat.fromJsonString[${method.getInputType.getName}](body)).recoverWith(jsonException2GatewayExceptionPF))""",
Expand All @@ -174,7 +124,6 @@ object GatewayGenerator extends protocbridge.ProtocCodeGenerator with Descriptor
.add("} yield res")
case PatternCase.DELETE =>
printer
.add(s"""case ("DELETE", "${http.getDelete}") => """)
.indent
.add("val input = Try {")
.indent
Expand Down Expand Up @@ -203,37 +152,37 @@ object GatewayGenerator extends protocbridge.ProtocCodeGenerator with Descriptor
case JavaType.ENUM =>
p.add(s"val ${inputName(f, prefix)} = ")
.addIndented(
s"""${f.getName}.valueOf(queryString.parameters().get("$prefix${f.getJsonName}").asScala.head)"""
s"""${f.getName}.valueOf(url.parameter("$prefix${f.getJsonName}"))"""
)
case JavaType.BOOLEAN =>
p.add(s"val ${inputName(f, prefix)} = ")
.addIndented(
s"""queryString.parameters().get("$prefix${f.getJsonName}").asScala.head.toBoolean"""
s"""url.parameter("$prefix${f.getJsonName}").toBoolean"""
)
case JavaType.DOUBLE =>
p.add(s"val ${inputName(f, prefix)} = ")
.addIndented(
s"""queryString.parameters().get("$prefix${f.getJsonName}").asScala.head.toDouble"""
s"""url.parameter("$prefix${f.getJsonName}").toDouble"""
)
case JavaType.FLOAT =>
p.add(s"val ${inputName(f, prefix)} = ")
.addIndented(
s"""queryString.parameters().get("$prefix${f.getJsonName}").asScala.head.toFloat"""
s"""url.parameter("$prefix${f.getJsonName}").toFloat"""
)
case JavaType.INT =>
p.add(s"val ${inputName(f, prefix)} = ")
.addIndented(
s"""queryString.parameters().get("$prefix${f.getJsonName}").asScala.head.toInt"""
s"""url.parameter("$prefix${f.getJsonName}").toInt"""
)
case JavaType.LONG =>
p.add(s"val ${inputName(f, prefix)} = ")
.addIndented(
s"""queryString.parameters().get("$prefix${f.getJsonName}").asScala.head.toLong"""
s"""url.parameter("$prefix${f.getJsonName}").toLong"""
)
case JavaType.STRING =>
p.add(s"val ${inputName(f, prefix)} = ")
.addIndented(
s"""queryString.parameters().get("$prefix${f.getJsonName}").asScala.head"""
s"""url.parameter("$prefix${f.getJsonName}")"""
)
case jt => throw new Exception(s"Unknown java type: $jt")
}
Expand All @@ -246,4 +195,100 @@ object GatewayGenerator extends protocbridge.ProtocCodeGenerator with Descriptor
name.charAt(0).toLower + name.substring(1)
}

private def generateCallSeqsByVerb(descritors: mutable.Seq[MethodDescriptor]): PrinterEndo = { printer =>
val verbToMethods: mutable.Map[PatternCase, Seq[RestfulMethod]] = MethodDescriptors.methodsByVerb(descritors)
printer
.call(generateCallSeqsByVerb(verbToMethods))
.call(generateSupportsCall(verbToMethods.keySet))
}

private def generateCallSeqsByVerb(verbToMethods: mutable.Map[PatternCase, Seq[RestfulMethod]]): PrinterEndo = { printer =>
printer.
print(verbToMethods) { case (p, (pattern,methods)) => generateCallSeq(pattern, methods)(p) }
}

private def generateCallSeq(verb: PatternCase, methods: Seq[RestfulMethod]): PrinterEndo = { printer =>
printer
.add(s"private val ${verb.name().toLowerCase}Calls: Seq[(UrlTemplate, RestfulHandler)] = Seq(")
.indent
.print(methods) { case (p, method) => generateCall(method)(p) }
.outdent
.add(")") // Seq
.newline
}

private def generateCall(method: RestfulMethod): PrinterEndo = { printer =>
printer
.add("(") // pair
.add(
s"""UrlTemplate("${method.urlTemplate}"),""",
"(url: RestfulUrl) => (body: String) => {" // function
)
.indent
.call(generateMethodHandlerCase(method.method))
.outdent
.add("}") // function
.add("),") // pair
}

private def generateSupportsCall(verbs: collection.Set[PatternCase]): PrinterEndo = { printer =>
printer
.add(s"override def supportsCall(method: HttpMethod, uri: String): Option[UnaryCall] = {")
.indent
.add("method.name match {")
.indent
.print(verbs) { case (p, verb) => generateVerbCase(verb)(p) }
.add("case _ => None")
.outdent
.add("}") // match
.outdent
.add("}") // def
}

private def generateVerbCase(verb: PatternCase): PrinterEndo = { printer =>
printer
.add(s"""case "${verb.name().toUpperCase}" =>""")
.indent
.add(s"for ((restful, handler) <- ${verb.name().toLowerCase}Calls) {")
.indent
.add("val mayBe = restful.matchUri(uri).map((url: RestfulUrl) => handler(url))")
.add("if (mayBe.isDefined) {")
.indent
.add("return mayBe")
.outdent
.add("}") //if
.outdent
.add("}") // for
.newline
.add("None") // def
.newline
.outdent // case body
}

}

private case class RestfulMethod(urlTemplate: String, method: MethodDescriptor)

private object MethodDescriptors {
def methodsByVerb(descriptors: mutable.Seq[MethodDescriptor]) : mutable.Map[PatternCase, Seq[RestfulMethod]] = {
val map = mutable.Map[PatternCase, ArrayBuffer[RestfulMethod]]()

descriptors.foreach((md: MethodDescriptor) => {
val http = md.getOptions.getExtension(AnnotationsProto.http)
val seq = map.getOrElseUpdate(http.getPatternCase, ArrayBuffer())
seq += RestfulMethod(urlTemplate(http), md)
})

map.asInstanceOf[mutable.Map[PatternCase, Seq[RestfulMethod]]] // todo how to do it with "A <:" ?
}

private def urlTemplate(http: HttpRule): String = {
http.getPatternCase match {
case PatternCase.GET => http.getGet
case PatternCase.POST => http.getPost
case PatternCase.PUT => http.getPut
case PatternCase.DELETE => http.getDelete
case _ => throw new IllegalArgumentException(s"Unsupported pattern: ${http.getPatternCase}")
}
}
}
Binary file not shown.
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
package grpcgateway.generators

import java.nio.file.{Files, Paths}

import com.google.protobuf.compiler.PluginProtos.{CodeGeneratorRequest, CodeGeneratorResponse}
import org.scalatest.{Assertions, FlatSpec}
import protocbridge.frontend.PluginFrontend

class GatewayGeneratorTest extends FlatSpec with Assertions {
private val DIR = "generator/target/scala-2.12/test-classes/"

it should "generate" in {
val requestProtoStream = Files.newInputStream(Paths.get(DIR + "objectstore_proto.bin"))
val request = CodeGeneratorRequest.parseFrom(requestProtoStream)

val responseBytes: Array[Byte] = PluginFrontend.runWithBytes(GatewayGenerator, request.toByteArray)
val generatedResponse = CodeGeneratorResponse.parseFrom(responseBytes)

for (i <- 0 until generatedResponse.getFileCount) {
val file = generatedResponse.getFile(i)
Files.write(Paths.get(DIR + file.getName.substring(file.getName.lastIndexOf("/") + 1)), file.getContent.getBytes())
}
}
}
Loading