Skip to content

Commit

Permalink
- moved entity sql to embedded fs
Browse files Browse the repository at this point in the history
  • Loading branch information
adranwit committed Nov 12, 2024
1 parent 13f286f commit cde5ed3
Show file tree
Hide file tree
Showing 7 changed files with 111 additions and 17 deletions.
8 changes: 7 additions & 1 deletion cmd/command/generate.go
Original file line number Diff line number Diff line change
Expand Up @@ -394,12 +394,18 @@ func (s *Service) translateGenerationOptions(gen *options.Generate, info *plugin
}

func (s *Service) generateEntity(ctx context.Context, pkg string, gen *options.Generate, info *plugin.Info, template *codegen.Template) error {
code, err := template.GenerateEntity(ctx, pkg, info)
embedContent := make(map[string]string)
embedURI := strings.ToLower(template.Spec.Namespace)

code, err := template.GenerateEntity(ctx, pkg, info, embedContent)
if err != nil {
return err
}
entityName := ensureGoFileCaseFormat(template)
s.Files.Append(asset.NewFile(gen.EntityLocation(template.FilePrefix(), template.FileMethodFragment(), entityName), code))
for k, v := range embedContent {
s.Files.Append(asset.NewFile(gen.EmbedLocation(embedURI+"/"+k, template.FileMethodFragment()), v))
}
return nil
}

Expand Down
97 changes: 81 additions & 16 deletions internal/codegen/entity.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,11 @@ import (
_ "embed"
"fmt"
"github.com/viant/datly/internal/plugin"
"github.com/viant/datly/internal/setter"
"github.com/viant/datly/view/extension"
"github.com/viant/datly/view/tags"
"github.com/viant/tagly/format/text"

"github.com/viant/datly/view/state"
"github.com/viant/xreflect"
"go/format"
Expand All @@ -22,7 +26,7 @@ const (
var entityTemplate string

// GenerateEntity generate golang entity
func (t *Template) GenerateEntity(ctx context.Context, pkg string, info *plugin.Info) (string, error) {
func (t *Template) GenerateEntity(ctx context.Context, pkg string, info *plugin.Info, embedContent map[string]string) (string, error) {
pkg = info.Package(pkg)
if t.MethodFragment != "" && t.MethodFragment != "get" {
pkg = strings.ToLower(t.MethodFragment)
Expand Down Expand Up @@ -62,19 +66,44 @@ func (t *Template) GenerateEntity(ctx context.Context, pkg string, info *plugin.
recv := strings.ToLower(t.TypeDef.Name[:1])

afterSnippet := strings.Builder{}

entityType := getBodyType(rType)
t.generateSetters(entityType, &afterSnippet, recv)
if !t.IsHandler {
afterSnippet = strings.Builder{}
}

embedURI := strings.ToLower(t.Spec.Namespace)

generatedStruct := xreflect.GenerateStruct(t.TypeDef.Name, rType,
xreflect.WithPackage(pkg),
xreflect.WithImports(imps.Packages),
xreflect.WithSnippetBefore(initSnippet),
xreflect.WithOnStructField(t.adjustStructField(embedURI, embedContent, true)),

xreflect.WithSnippetAfter(afterSnippet.String()),
)

formatted, err := format.Source([]byte(generatedStruct))
if err != nil {
return "", err
}
return string(formatted), nil
}

func (t *Template) generateSetters(rType reflect.Type, afterSnippet *strings.Builder, recv string) {
for i := 0; i < rType.NumField(); i++ {
field := rType.Field(i)
isPtr := field.Type.Kind() == reflect.Ptr
rawType := field.Type
if isPtr {
rawType = field.Type.Elem()
}

if rawType.Name() == "" || strings.Contains(string(field.Tag), `json:"-\"`) {
continue
}

afterSnippet.WriteString(fmt.Sprintf("\nfunc (%v *%v) Set%v(value %v) {", recv, t.TypeDef.Name, field.Name, rawType.String()))
afterSnippet.WriteString(fmt.Sprintf("\nfunc (%v *%v) Set%v(value %v) {", recv, t.Spec.Type.Name, field.Name, rawType.String()))
if isPtr {
afterSnippet.WriteString(fmt.Sprintf("\n\t%v.%v = &value", recv, field.Name))
} else {
Expand All @@ -83,22 +112,41 @@ func (t *Template) GenerateEntity(ctx context.Context, pkg string, info *plugin.
afterSnippet.WriteString(fmt.Sprintf("\n\t%v.Has.%v = true", recv, field.Name))
afterSnippet.WriteString("\n}\n\n")
}
if !t.IsHandler {
afterSnippet = strings.Builder{}
}
}

generatedStruct := xreflect.GenerateStruct(t.TypeDef.Name, rType,
xreflect.WithPackage(pkg),
xreflect.WithImports(imps.Packages),
xreflect.WithSnippetBefore(initSnippet),
xreflect.WithSnippetAfter(afterSnippet.String()),
)
func getBodyType(rType reflect.Type) reflect.Type {
if rType.Kind() == reflect.Ptr {
rType = rType.Elem()
}
if rType.NumField() == 1 {
field := rType.Field(0)
fType := field.Type
if fType.Kind() == reflect.Slice {
fType = fType.Elem()
}
if fType.Kind() == reflect.Ptr {
fType = fType.Elem()
}
if fType.Kind() == reflect.Struct {
return fType
}
}

formatted, err := format.Source([]byte(generatedStruct))
if err != nil {
return "", err
//get body type
for i := 0; i < rType.NumField(); i++ {
field := rType.Field(i)
parameterTag, _ := tags.Parse(field.Tag, nil)
if parameterTag.Parameter != nil && parameterTag.Parameter.Kind == string(state.KindRequestBody) {
rType = field.Type
if rType.Kind() == reflect.Slice {
rType = rType.Elem()
}
if rType.Kind() == reflect.Ptr {
rType = rType.Elem()
}
}
}
return string(formatted), nil
return rType
}

func (t *Template) generateRegisterType() string {
Expand Down Expand Up @@ -127,3 +175,20 @@ func (t *Template) generateMapTypeBody() string {
initCode := strings.Join(initElements, "\n")
return initCode
}

func (c *Template) adjustStructField(embedURI string, embeds map[string]string, generateContract bool) func(aField *reflect.StructField, tag *string, typeName *string, doc *string) {
return func(aField *reflect.StructField, tag, typeName, doc *string) {
fieldTag := *tag
fieldTag, value := xreflect.RemoveTag(fieldTag, "sql")
if value != "" {
name := *typeName
setter.SetStringIfEmpty(&name, aField.Name)
key := text.CaseFormatUpperCamel.Format(name, text.CaseFormatLowerUnderscore)
key = strings.ReplaceAll(key, ".", "")
key += ".sql"
embeds[key] = value
fieldTag += fmt.Sprintf(` sql:"uri=%v/`+key+`" `, embedURI)
}
*tag = fieldTag
}
}
7 changes: 7 additions & 0 deletions internal/inference/spec.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,13 @@ type (
Selector []string
)

func (s *Spec) NormalizeSQL() {
if s.SQL == "" || s.Table == "" {
return
}
s.SQL = strings.ReplaceAll(s.SQL, "("+s.Table+")", s.Table)
}

func (s *Spec) EnsureRelationType() {
if len(s.Relations) == 0 {
return
Expand Down
5 changes: 5 additions & 0 deletions internal/inference/table.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,11 @@ func (t *Table) detect(ctx context.Context, db *sql.DB, SQL string) error {
}
t.Namespace = strings.ToLower(query.From.Alias)
from := sqlparser.Stringify(query.From.X)

trimFrom := strings.TrimSpace(from)
if strings.HasPrefix(trimFrom, "(") && strings.HasSuffix(trimFrom, ")") {
from = trimFrom[1 : len(trimFrom)-1]
}
if !HasWhitespace(from) {
t.Name = from
}
Expand Down
1 change: 1 addition & 0 deletions internal/inference/tag.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ func (t *Tags) Init(tag string) {
func (t *Tags) buildSqlxTag(source *Spec, field *Field) {
column := field.Column
tagValue := TagValue{}

tagValue.Append(column.Name)
if column.IsAutoincrement {
tagValue.Append("autoincrement")
Expand Down
9 changes: 9 additions & 0 deletions internal/translator/viewlet.go
Original file line number Diff line number Diff line change
Expand Up @@ -263,3 +263,12 @@ func (v *Viewlet) mergeTableJSONHint(hint string) error {
data, _ := parser.MergeStructs(&output, &v.OutputSettings)
return json.Unmarshal(data, &v.OutputSettings)
}

func (v *Viewlet) NormalizeSQL() {
if v.Spec != nil {
v.Spec.NormalizeSQL()
}
if v.Table != nil && v.Table.Name != "" {
v.SQL = strings.ReplaceAll(v.SQL, "("+v.Table.Name+")", v.Table.Name)
}
}
1 change: 1 addition & 0 deletions internal/translator/viewlets.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ func (n *Viewlets) Init(ctx context.Context, aQuery *query.Select, resource *Res
}

if err := n.Each(func(viewlet *Viewlet) error {
viewlet.NormalizeSQL()
if err := setType(ctx, viewlet, resource.Rule.Doc.Columns); err != nil {
return fmt.Errorf("failed to init viewlet: %v, %w", viewlet.Name, err)
}
Expand Down

0 comments on commit cde5ed3

Please sign in to comment.