Skip to content

Commit 03ca462

Browse files
committed

File tree

1 file changed

+79
-53
lines changed

1 file changed

+79
-53
lines changed

lib/service-gen.go

Lines changed: 79 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,16 @@ import (
1616
)
1717

1818
type MethodInfo struct {
19-
OriginalName string
20-
Name string
21-
InputType string
22-
IsInputPointer bool
23-
OutputType string
24-
IsOutputPointer bool
25-
IsWorkflow bool
26-
IsService bool
19+
OriginalName string
20+
Name string
21+
InputType string
22+
IsInputPointer bool
23+
IsInputPrimitive bool
24+
OutputType string
25+
IsOutputPointer bool
26+
IsOutputPrimitive bool
27+
IsWorkflow bool
28+
IsService bool
2729
}
2830

2931
type ServiceInfo struct {
@@ -72,16 +74,18 @@ func (t *{{.ServiceStructName}}) GetInputType(method string) (any, error) {
7274
}
7375
7476
func (t *{{.ServiceStructName}}) GetOutputType(method string) (any, error) {
75-
method = strings.ToLower(method)
76-
switch method {
77-
{{range .Methods}}case "{{.Name}}":
78-
{
79-
return &{{.OutputType}}{}, nil
80-
}
81-
{{end}}default:
82-
{
83-
return nil, errors.New("method not found")
84-
}
77+
switch strings.ToLower(method) {
78+
{{range .Methods}}
79+
case "{{.Name}}":
80+
{{if .IsOutputPrimitive}}
81+
var v {{.OutputType}}
82+
return &v, nil
83+
{{else}}
84+
return &{{.OutputType}}{}, nil
85+
{{end}}
86+
{{end}}
87+
default:
88+
return nil, fmt.Errorf("method %q not found", method)
8589
}
8690
}
8791
@@ -289,6 +293,50 @@ func validateFunctionParams(fn *ast.FuncDecl) (string, error) {
289293
return "", fmt.Errorf("function %s: first parameter must be polycode.ServiceContext or polycode.WorkflowContext", fn.Name.Name)
290294
}
291295

296+
func extractType(expr ast.Expr) (typeStr string, isPointer bool, isPrimitive bool) {
297+
switch t := expr.(type) {
298+
299+
case *ast.StarExpr:
300+
innerType, _, primitive := extractType(t.X)
301+
return innerType, true, primitive
302+
303+
case *ast.SelectorExpr:
304+
// Handles pkg.Type
305+
if pkgIdent, ok := t.X.(*ast.Ident); ok {
306+
typeName := fmt.Sprintf("%s.%s", pkgIdent.Name, t.Sel.Name)
307+
return typeName, false, false
308+
}
309+
310+
return t.Sel.Name, false, false
311+
312+
case *ast.Ident:
313+
// Handles builtin and local types
314+
return t.Name, false, primitiveTypes[t.Name]
315+
316+
case *ast.ArrayType:
317+
elemType, _, _ := extractType(t.Elt)
318+
return "[]" + elemType, false, false
319+
320+
case *ast.MapType:
321+
keyType, _, _ := extractType(t.Key)
322+
valType, _, _ := extractType(t.Value)
323+
return fmt.Sprintf("map[%s]%s", keyType, valType), false, false
324+
325+
case *ast.InterfaceType:
326+
return "interface{}", false, false
327+
328+
default:
329+
return fmt.Sprintf("%T", t), false, false
330+
}
331+
}
332+
333+
var primitiveTypes = map[string]bool{
334+
"string": true, "bool": true, "int": true, "int8": true, "int16": true,
335+
"int32": true, "int64": true, "uint": true, "uint8": true, "uint16": true,
336+
"uint32": true, "uint64": true, "float32": true, "float64": true,
337+
"byte": true, "rune": true, "any": true, "interface{}": true,
338+
}
339+
292340
// Updated parseDir function to mark methods as workflow or service
293341
func parseDir(serviceFolder string) ([]MethodInfo, []string, error) {
294342
fset := token.NewFileSet()
@@ -330,44 +378,22 @@ func parseDir(serviceFolder string) ([]MethodInfo, []string, error) {
330378

331379
// Extract the function name and input/output parameters
332380
methodName := strings.ToLower(fn.Name.Name) // Normalize to lowercase
333-
334-
inputType := ""
335-
isInputPointer := false
336-
337-
// Handle pointer types and normal types
338-
if starExpr, ok := fn.Type.Params.List[1].Type.(*ast.StarExpr); ok {
339-
isInputPointer = true
340-
if selectorExpr, ok := starExpr.X.(*ast.SelectorExpr); ok {
341-
inputType = fmt.Sprintf("%s.%s", selectorExpr.X.(*ast.Ident).Name, selectorExpr.Sel.Name)
342-
}
343-
} else if selectorExpr, ok := fn.Type.Params.List[1].Type.(*ast.SelectorExpr); ok {
344-
inputType = fmt.Sprintf("%s.%s", selectorExpr.X.(*ast.Ident).Name, selectorExpr.Sel.Name)
345-
}
346-
347-
outputType := ""
348-
isOutputPointer := false
349-
350-
// Handle pointer types and normal types
351-
if starExpr, ok := fn.Type.Results.List[0].Type.(*ast.StarExpr); ok {
352-
isOutputPointer = true
353-
if selectorExpr, ok := starExpr.X.(*ast.SelectorExpr); ok {
354-
outputType = fmt.Sprintf("%s.%s", selectorExpr.X.(*ast.Ident).Name, selectorExpr.Sel.Name)
355-
}
356-
} else if selectorExpr, ok := fn.Type.Results.List[0].Type.(*ast.SelectorExpr); ok {
357-
outputType = fmt.Sprintf("%s.%s", selectorExpr.X.(*ast.Ident).Name, selectorExpr.Sel.Name)
358-
}
381+
inputType, isInputPointer, isInputPrimitive := extractType(fn.Type.Params.List[1].Type)
382+
outputType, isOutputPointer, isOutputPrimitive := extractType(fn.Type.Results.List[0].Type)
359383

360384
// Append the method and its corresponding input type to methods
361-
if inputType != "" {
385+
if inputType != "" && outputType != "" {
362386
methods = append(methods, MethodInfo{
363-
OriginalName: OriginalName,
364-
Name: methodName,
365-
InputType: inputType,
366-
IsInputPointer: isInputPointer,
367-
OutputType: outputType,
368-
IsOutputPointer: isOutputPointer,
369-
IsWorkflow: contextType == "Workflow",
370-
IsService: contextType == "Service",
387+
OriginalName: OriginalName,
388+
Name: methodName,
389+
InputType: inputType,
390+
IsInputPointer: isInputPointer,
391+
IsInputPrimitive: isInputPrimitive,
392+
OutputType: outputType,
393+
IsOutputPointer: isOutputPointer,
394+
IsOutputPrimitive: isOutputPrimitive,
395+
IsWorkflow: contextType == "Workflow",
396+
IsService: contextType == "Service",
371397
})
372398
}
373399
}

0 commit comments

Comments
 (0)