Skip to content

Commit

Permalink
Merge pull request #983 from qazwsxedckll/dev
Browse files Browse the repository at this point in the history
feat: new protoc-gen-go-grain using protogen
  • Loading branch information
rogeralsing authored Dec 16, 2023
2 parents 96b6848 + 170af6b commit 7b6b1f5
Show file tree
Hide file tree
Showing 10 changed files with 634 additions and 0 deletions.
3 changes: 3 additions & 0 deletions protobuf/protoc-gen-go-grain/Makefile
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
.PHONY: gentestdata
gentestdata:
protoc --go_out=. --go_opt=paths=source_relative --plugin=protoc-gen-go-grain=protoc-gen-go-grain.sh --go-grain_out=. --go-grain_opt=paths=source_relative testdata/hello/hello.proto
99 changes: 99 additions & 0 deletions protobuf/protoc-gen-go-grain/generate.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
package main

import (
"fmt"

"google.golang.org/protobuf/compiler/protogen"
"google.golang.org/protobuf/types/descriptorpb"
)

const deprecationComment = "// Deprecated: Do not use."

const (
timePackage = protogen.GoImportPath("time")
errorsPackage = protogen.GoImportPath("errors")
fmtPackage = protogen.GoImportPath("fmt")
slogPackage = protogen.GoImportPath("log/slog")
protoPackage = protogen.GoImportPath("google.golang.org/protobuf/proto")
actorPackage = protogen.GoImportPath("github.com/asynkron/protoactor-go/actor")
clusterPackage = protogen.GoImportPath("github.com/asynkron/protoactor-go/cluster")
)

func generateFile(gen *protogen.Plugin, file *protogen.File) {
filename := file.GeneratedFilenamePrefix + "_grain.pb.go"
g := gen.NewGeneratedFile(filename, file.GoImportPath)

generateHeader(gen, g, file)
generateContent(gen, g, file)
}

func generateHeader(gen *protogen.Plugin, g *protogen.GeneratedFile, file *protogen.File) {
g.P("// Code generated by protoc-gen-grain. DO NOT EDIT.")
g.P("// versions:")
g.P("// protoc-gen-grain ", version)
protocVersion := "(unknown)"
if v := gen.Request.GetCompilerVersion(); v != nil {
protocVersion = fmt.Sprintf("v%v.%v.%v", v.GetMajor(), v.GetMinor(), v.GetPatch())
if s := v.GetSuffix(); s != "" {
protocVersion += "-" + s
}
}
g.P("// protoc ", protocVersion)
if file.Proto.GetOptions().GetDeprecated() {
g.P("// ", file.Desc.Path(), " is a deprecated file.")
} else {
g.P("// source: ", file.Desc.Path())
}
g.P()
}

func generateContent(gen *protogen.Plugin, g *protogen.GeneratedFile, file *protogen.File) {
g.P("package ", file.GoPackageName)
g.P()

if len(file.Services) == 0 {
return
}

g.QualifiedGoIdent(actorPackage.Ident(""))
g.QualifiedGoIdent(clusterPackage.Ident(""))
g.QualifiedGoIdent(protoPackage.Ident(""))
g.QualifiedGoIdent(errorsPackage.Ident(""))
g.QualifiedGoIdent(fmtPackage.Ident(""))
g.QualifiedGoIdent(timePackage.Ident(""))
g.QualifiedGoIdent(slogPackage.Ident(""))

for _, service := range file.Services {
generateService(service, file, g)
}
}

func generateService(service *protogen.Service, file *protogen.File, g *protogen.GeneratedFile) {
if service.Desc.Options().(*descriptorpb.ServiceOptions).GetDeprecated() {
g.P("//")
g.P(deprecationComment)
}

sd := &serviceDesc{
Name: service.GoName,
}

for i, method := range service.Methods {
if method.Desc.IsStreamingClient() || method.Desc.IsStreamingServer() {
continue
}

md := &methodDesc{
Name: method.GoName,
Input: g.QualifiedGoIdent(method.Input.GoIdent),
Output: g.QualifiedGoIdent(method.Output.GoIdent),
Index: i,
}

sd.Methods = append(sd.Methods, md)
}

if len(sd.Methods) != 0 {
g.P(sd.execute())
}
}
18 changes: 18 additions & 0 deletions protobuf/protoc-gen-go-grain/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
package main

import (
"google.golang.org/protobuf/compiler/protogen"
)

func main() {
protogen.Options{}.Run(func(gen *protogen.Plugin) error {
for _, f := range gen.Files {
if !f.Generate {
continue
}
generateFile(gen, f)
}

return nil
})
}
3 changes: 3 additions & 0 deletions protobuf/protoc-gen-go-grain/protoc-gen-go-grain.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
#!/usr/bin/env bash

exec go run .
36 changes: 36 additions & 0 deletions protobuf/protoc-gen-go-grain/template.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
package main

import (
"bytes"
_ "embed"
"strings"
"text/template"
)

//go:embed templates/grain.tmpl
var grainTemplate string

type serviceDesc struct {
Name string // Greeter
Methods []*methodDesc
}

type methodDesc struct {
Name string
Input string
Output string
Index int
}

func (s *serviceDesc) execute() string {
buf := new(bytes.Buffer)
tmpl, err := template.New("grain").Parse(strings.TrimSpace(grainTemplate))
if err != nil {
panic(err)
}
if err := tmpl.Execute(buf, s); err != nil {
panic(err)
}

return strings.Trim(buf.String(), "\r\n")
}
143 changes: 143 additions & 0 deletions protobuf/protoc-gen-go-grain/templates/grain.tmpl
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
{{ $service := . }}
var x{{ $service.Name }}Factory func() {{ $service.Name }}

// {{ $service.Name }}Factory produces a {{ $service.Name }}
func {{ $service.Name }}Factory(factory func() {{ $service.Name }}) {
x{{ $service.Name }}Factory = factory
}

// Get{{ $service.Name }}GrainClient instantiates a new {{ $service.Name }}GrainClient with given Identity
func Get{{ $service.Name }}GrainClient(c *cluster.Cluster, id string) *{{ $service.Name }}GrainClient {
if c == nil {
panic(fmt.Errorf("nil cluster instance"))
}
if id == "" {
panic(fmt.Errorf("empty id"))
}
return &{{ $service.Name }}GrainClient{Identity: id, cluster: c}
}

// Get{{ $service.Name }}Kind instantiates a new cluster.Kind for {{ $service.Name }}
func Get{{ $service.Name }}Kind(opts ...actor.PropsOption) *cluster.Kind {
props := actor.PropsFromProducer(func() actor.Actor {
return &{{ $service.Name }}Actor{
Timeout: 60 * time.Second,
}
}, opts...)
kind := cluster.NewKind("{{ $service.Name }}", props)
return kind
}

// Get{{ $service.Name }}Kind instantiates a new cluster.Kind for {{ $service.Name }}
func New{{ $service.Name }}Kind(factory func() {{ $service.Name }}, timeout time.Duration, opts ...actor.PropsOption) *cluster.Kind {
x{{ $service.Name }}Factory = factory
props := actor.PropsFromProducer(func() actor.Actor {
return &{{ $service.Name }}Actor{
Timeout: timeout,
}
}, opts...)
kind := cluster.NewKind("{{ $service.Name }}", props)
return kind
}

// {{ $service.Name }} interfaces the services available to the {{ $service.Name }}
type {{ $service.Name }} interface {
Init(ctx cluster.GrainContext)
Terminate(ctx cluster.GrainContext)
ReceiveDefault(ctx cluster.GrainContext)
{{- range $method := .Methods }}
{{ $method.Name }}(*{{ $method.Input }}, cluster.GrainContext) (*{{ $method.Output }}, error)
{{- end }}
}

// {{ $service.Name }}GrainClient holds the base data for the {{ $service.Name }}Grain
type {{ $service.Name }}GrainClient struct {
Identity string
cluster *cluster.Cluster
}
{{ range $method := .Methods}}
// {{ $method.Name }} requests the execution on to the cluster with CallOptions
func (g *{{ $service.Name }}GrainClient) {{ $method.Name }}(r *{{ $method.Input }}, opts ...cluster.GrainCallOption) (*{{ $method.Output }}, error) {
bytes, err := proto.Marshal(r)
if err != nil {
return nil, err
}
reqMsg := &cluster.GrainRequest{MethodIndex: {{ $method.Index }}, MessageData: bytes}
resp, err := g.cluster.Request(g.Identity, "{{ $service.Name }}", reqMsg, opts...)
if err != nil {
return nil, err
}
switch msg := resp.(type) {
case *cluster.GrainResponse:
result := &{{ $method.Output }}{}
err = proto.Unmarshal(msg.MessageData, result)
if err != nil {
return nil, err
}
return result, nil
case *cluster.GrainErrorResponse:
return nil, errors.New(msg.Err)
default:
return nil, errors.New("unknown response")
}
}
{{ end }}
// {{ $service.Name }}Actor represents the actor structure
type {{ $service.Name }}Actor struct {
ctx cluster.GrainContext
inner {{ $service.Name }}
Timeout time.Duration
}

// Receive ensures the lifecycle of the actor for the received message
func (a *{{ $service.Name }}Actor) Receive(ctx actor.Context) {
switch msg := ctx.Message().(type) {
case *actor.Started: //pass
case *cluster.ClusterInit:
a.ctx = cluster.NewGrainContext(ctx, msg.Identity, msg.Cluster)
a.inner = x{{ $service.Name }}Factory()
a.inner.Init(a.ctx)

if a.Timeout > 0 {
ctx.SetReceiveTimeout(a.Timeout)
}
case *actor.ReceiveTimeout:
ctx.Poison(ctx.Self())
case *actor.Stopped:
a.inner.Terminate(a.ctx)
case actor.AutoReceiveMessage: // pass
case actor.SystemMessage: // pass

case *cluster.GrainRequest:
switch msg.MethodIndex {
{{ range $method := .Methods -}}
case {{ $method.Index }}:
req := &{{ $method.Input }}{}
err := proto.Unmarshal(msg.MessageData, req)
if err != nil {
ctx.Logger().Error("[Grain] {{ $method.Name }}({{ $method.Input }}) proto.Unmarshal failed.", slog.Any("error", err))
resp := &cluster.GrainErrorResponse{Err: err.Error()}
ctx.Respond(resp)
return
}
r0, err := a.inner.{{ $method.Name }}(req, a.ctx)
if err != nil {
resp := &cluster.GrainErrorResponse{Err: err.Error()}
ctx.Respond(resp)
return
}
bytes, err := proto.Marshal(r0)
if err != nil {
ctx.Logger().Error("[Grain] {{ $method.Name }}({{ $method.Input }}) proto.Marshal failed", slog.Any("error", err))
resp := &cluster.GrainErrorResponse{Err: err.Error()}
ctx.Respond(resp)
return
}
resp := &cluster.GrainResponse{MessageData: bytes}
ctx.Respond(resp)
{{ end -}}
}
default:
a.inner.ReceiveDefault(a.ctx)
}
}
Loading

0 comments on commit 7b6b1f5

Please sign in to comment.