Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

Enumer is a tool to generate Go code that adds useful methods to Go enums (constants with a specific type).
It started as a fork of [Rob Pike’s Stringer tool](https://godoc.org/golang.org/x/tools/cmd/stringer)
maintained by [Álvaro López Espinosa](https://github.com/alvaroloes/enumer).
maintained by [Álvaro López Espinosa](https://github.com/alvaroloes/enumer).
This was again forked here as (https://github.com/dmarkham/enumer) picking up where Álvaro left off.


Expand Down Expand Up @@ -40,6 +40,8 @@ Flags:
comma-separated list of type names; must be set
-values
if true, alternative string values method will be generated. Default: false
-validate
if true, a `Validate() error` method will be generated. Default: false
-yaml
if true, yaml marshaling methods will be generated. Default: false
```
Expand Down
27 changes: 26 additions & 1 deletion enumer.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package main
import "fmt"

// Arguments to format are:
//
// [1]: type name
const stringNameToValueMethod = `// %[1]sString retrieves an enum value from the enum constants string name.
// Throws an error if the param is not part of the enum.
Expand All @@ -19,6 +20,7 @@ func %[1]sString(s string) (%[1]s, error) {
`

// Arguments to format are:
//
// [1]: type name
const stringValuesMethod = `// %[1]sValues returns all values of the enum
func %[1]sValues() []%[1]s {
Expand All @@ -27,6 +29,7 @@ func %[1]sValues() []%[1]s {
`

// Arguments to format are:
//
// [1]: type name
const stringsMethod = `// %[1]sStrings returns a slice of all String values of the enum
func %[1]sStrings() []string {
Expand All @@ -37,6 +40,7 @@ func %[1]sStrings() []string {
`

// Arguments to format are:
//
// [1]: type name
const stringBelongsMethodLoop = `// IsA%[1]s returns "true" if the value is listed in the enum definition. "false" otherwise
func (i %[1]s) IsA%[1]s() bool {
Expand All @@ -50,15 +54,29 @@ func (i %[1]s) IsA%[1]s() bool {
`

// Arguments to format are:
//
// [1]: type name
const stringBelongsMethodSet = `// IsA%[1]s returns "true" if the value is listed in the enum definition. "false" otherwise
func (i %[1]s) IsA%[1]s() bool {
_, ok := _%[1]sMap[i]
_, ok := _%[1]sMap[i]
return ok
}
`

// Arguments to format are:
//
// [1]: type name
const validateMethod = `// Validate returns an error if the value is not listed in the enum definition.
func (i %[1]s) Validate() error {
if !i.IsA%[1]s() {
return fmt.Errorf("%%d is not a valid value for %[1]s values", i)
}
return nil
}
`

// Arguments to format are:
//
// [1]: type name
const altStringValuesMethod = `func (%[1]s) Values() []string {
return %[1]sStrings()
Expand Down Expand Up @@ -144,6 +162,7 @@ func (g *Generator) printNamesSlice(runs [][]Value, typeName string, runsThresho
}

// Arguments to format are:
//
// [1]: type name
const jsonMethods = `
// MarshalJSON implements the json.Marshaler interface for %[1]s
Expand All @@ -169,6 +188,7 @@ func (g *Generator) buildJSONMethods(runs [][]Value, typeName string, runsThresh
}

// Arguments to format are:
//
// [1]: type name
const textMethods = `
// MarshalText implements the encoding.TextMarshaler interface for %[1]s
Expand All @@ -189,6 +209,7 @@ func (g *Generator) buildTextMethods(runs [][]Value, typeName string, runsThresh
}

// Arguments to format are:
//
// [1]: type name
const yamlMethods = `
// MarshalYAML implements a YAML Marshaler for %[1]s
Expand All @@ -212,3 +233,7 @@ func (i *%[1]s) UnmarshalYAML(unmarshal func(interface{}) error) error {
func (g *Generator) buildYAMLMethods(runs [][]Value, typeName string, runsThreshold int) {
g.Printf(yamlMethods, typeName)
}

func (g *Generator) buildValidateMethod(typeName string) {
g.Printf(validateMethod, typeName)
}
107 changes: 52 additions & 55 deletions golden_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
package main

import (
"fmt"
"io/ioutil"
"os"
"path/filepath"
Expand Down Expand Up @@ -76,6 +77,10 @@ var goldenLinecomment = []Golden{
{"dayWithLinecomment", linecommentIn},
}

var goldenValidateMethod = []Golden{
{name: "validate", input: dayIn},
}

// Each example starts with "type XXX [u]int", with a single space separating them.

// Simple test: enumeration of type int starting at 0.
Expand Down Expand Up @@ -315,95 +320,87 @@ const (

func TestGolden(t *testing.T) {
for _, test := range golden {
runGoldenTest(t, test, false, false, false, false, false, false, true, "", "")
runGoldenTest(t, test, false, false, false, false, false, false, true, false, "", "")
}
for _, test := range goldenJSON {
runGoldenTest(t, test, true, false, false, false, false, false, false, "", "")
runGoldenTest(t, test, true, false, false, false, false, false, false, false, "", "")
}
for _, test := range goldenText {
runGoldenTest(t, test, false, false, false, true, false, false, false, "", "")
runGoldenTest(t, test, false, false, false, true, false, false, false, false, "", "")
}
for _, test := range goldenYAML {
runGoldenTest(t, test, false, true, false, false, false, false, false, "", "")
runGoldenTest(t, test, false, true, false, false, false, false, false, false, "", "")
}
for _, test := range goldenSQL {
runGoldenTest(t, test, false, false, true, false, false, false, false, "", "")
runGoldenTest(t, test, false, false, true, false, false, false, false, false, "", "")
}
for _, test := range goldenJSONAndSQL {
runGoldenTest(t, test, true, false, true, false, false, false, false, "", "")
runGoldenTest(t, test, true, false, true, false, false, false, false, false, "", "")
}
for _, test := range goldenGQLGen {
runGoldenTest(t, test, false, false, false, false, false, true, false, "", "")
runGoldenTest(t, test, false, false, false, false, false, true, false, false, "", "")
}
for _, test := range goldenTrimPrefix {
runGoldenTest(t, test, false, false, false, false, false, false, false, "Day", "")
runGoldenTest(t, test, false, false, false, false, false, false, false, false, "Day", "")
}
for _, test := range goldenTrimPrefixMultiple {
runGoldenTest(t, test, false, false, false, false, false, false, false, "Day,Night", "")
runGoldenTest(t, test, false, false, false, false, false, false, false, false, "Day,Night", "")
}
for _, test := range goldenWithPrefix {
runGoldenTest(t, test, false, false, false, false, false, false, false, "", "Day")
runGoldenTest(t, test, false, false, false, false, false, false, false, false, "", "Day")
}
for _, test := range goldenTrimAndAddPrefix {
runGoldenTest(t, test, false, false, false, false, false, false, false, "Day", "Night")
runGoldenTest(t, test, false, false, false, false, false, false, false, false, "Day", "Night")
}
for _, test := range goldenLinecomment {
runGoldenTest(t, test, false, false, false, false, true, false, false, "", "")
runGoldenTest(t, test, false, false, false, false, true, false, false, false, "", "")
}
for _, test := range goldenValidateMethod {
runGoldenTest(t, test, false, false, false, false, false, false, false, true, "", "")
}
}

func runGoldenTest(t *testing.T, test Golden,
generateJSON, generateYAML, generateSQL, generateText, linecomment, generateGQLGen, generateValuesMethod bool,
generateJSON, generateYAML, generateSQL, generateText, linecomment, generateGQLGen, generateValuesMethod, generateIsValidMethod bool,
trimPrefix string, prefix string) {
t.Run(test.name, func(t *testing.T) {
var g Generator
file := test.name + ".go"
input := "package test\n" + test.input

var g Generator
file := test.name + ".go"
input := "package test\n" + test.input
dir := t.TempDir()

dir, err := ioutil.TempDir("", "stringer")
if err != nil {
t.Error(err)
}
defer func() {
err = os.RemoveAll(dir)
absFile := filepath.Join(dir, file)
err := ioutil.WriteFile(absFile, []byte(input), 0644)
if err != nil {
t.Error(err)
t.Fatal("writing test input to file:", err)
}
}()

absFile := filepath.Join(dir, file)
err = ioutil.WriteFile(absFile, []byte(input), 0644)
if err != nil {
t.Error(err)
}
g.parsePackage([]string{absFile}, nil)
// Extract the name and type of the constant from the first line.
tokens := strings.SplitN(test.input, " ", 3)
if len(tokens) != 3 {
t.Fatalf("%s: need type declaration on first line", test.name)
}
g.generate(tokens[1], generateJSON, generateYAML, generateSQL, generateText, generateGQLGen, "noop", trimPrefix, prefix, linecomment, generateValuesMethod)
got := string(g.format())
if got != loadGolden(test.name) {
// Use this to help build a golden text when changes are needed
//goldenFile := fmt.Sprintf("./testdata/%v.golden", test.name)
//err = ioutil.WriteFile(goldenFile, []byte(got), 0644)
//if err != nil {
// t.Error(err)
//}
t.Errorf("%s: got\n====\n%s====\nexpected\n====%s", test.name, got, loadGolden(test.name))
}
g.parsePackage([]string{absFile}, nil)
// Extract the name and type of the constant from the first line.
tokens := strings.SplitN(test.input, " ", 3)
if len(tokens) != 3 {
t.Fatalf("%s: need type declaration on first line", test.name)
}
g.generate(tokens[1], generateJSON, generateYAML, generateSQL, generateText, generateGQLGen, "noop", trimPrefix, prefix, linecomment, generateValuesMethod, generateIsValidMethod)
got := string(g.format())
golden := loadGolden(t, test.name)
if got != golden {
// Use this to help build a golden text when changes are needed
// goldenFile := fmt.Sprintf("./testdata/%v.golden", test.name)
// err = ioutil.WriteFile(goldenFile, []byte(got), 0644)
// if err != nil {
// t.Error(err)
// }
t.Errorf("%s: got\n====\n%s====\nexpected\n====%s", test.name, got, golden)
}
})
}

func loadGolden(name string) string {
fh, err := os.Open("testdata/" + name + ".golden")
if err != nil {
return ""
}
defer fh.Close()
b, err := ioutil.ReadAll(fh)
func loadGolden(t *testing.T, name string) string {
t.Helper()
b, err := os.ReadFile(fmt.Sprintf("./testdata/%v.golden", name))
if err != nil {
return ""
t.Fatalf("error while loading golden file %v: %v", name, err)
}
return string(b)

Expand Down
24 changes: 15 additions & 9 deletions stringer.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ var (
text = flag.Bool("text", false, "if true, text marshaling methods will be generated. Default: false")
gqlgen = flag.Bool("gqlgen", false, "if true, GraphQL marshaling methods for gqlgen will be generated. Default: false")
altValuesFunc = flag.Bool("values", false, "if true, alternative string values method will be generated. Default: false")
validateFunc = flag.Bool("validate", false, "if true, a `Validate() error` method will be generated. Default: false")
output = flag.String("output", "", "output file name; default srcdir/<type>_string.go")
transformMethod = flag.String("transform", "noop", "enum item name transformation method. Default: noop")
trimPrefix = flag.String("trimprefix", "", "transform each item name by removing a prefix. Default: \"\"")
Expand Down Expand Up @@ -135,7 +136,7 @@ func main() {

// Run generate for each type.
for _, typeName := range typs {
g.generate(typeName, *json, *yaml, *sql, *text, *gqlgen, *transformMethod, *trimPrefix, *addPrefix, *linecomment, *altValuesFunc)
g.generate(typeName, *json, *yaml, *sql, *text, *gqlgen, *transformMethod, *trimPrefix, *addPrefix, *linecomment, *altValuesFunc, *validateFunc)
}

// Format the output.
Expand Down Expand Up @@ -415,7 +416,7 @@ func (g *Generator) prefixValueNames(values []Value, prefix string) {
// generate produces the String method for the named type.
func (g *Generator) generate(typeName string,
includeJSON, includeYAML, includeSQL, includeText, includeGQLGen bool,
transformMethod string, trimPrefix string, addPrefix string, lineComment bool, includeValuesMethod bool) {
transformMethod string, trimPrefix string, addPrefix string, lineComment bool, includeValuesMethod bool, includeValidateFunc bool) {
values := make([]Value, 0, 100)
for _, file := range g.pkg.files {
file.lineComment = lineComment
Expand Down Expand Up @@ -465,6 +466,9 @@ func (g *Generator) generate(typeName string,
if includeValuesMethod {
g.buildAltStringValuesMethod(typeName)
}
if includeValidateFunc {
g.buildValidateMethod(typeName)
}

g.buildNoOpOrderChangeDetect(runs, typeName)

Expand Down Expand Up @@ -774,9 +778,10 @@ func (g *Generator) buildOneRun(runs [][]Value, typeName string) {
}

// Arguments to format are:
// [1]: type name
// [2]: size of index element (8 for uint8 etc.)
// [3]: less than zero check (for signed types)
//
// [1]: type name
// [2]: size of index element (8 for uint8 etc.)
// [3]: less than zero check (for signed types)
const stringOneRun = `func (i %[1]s) String() string {
if %[3]si >= %[1]s(len(_%[1]sIndex)-1) {
return fmt.Sprintf("%[1]s(%%d)", i)
Expand All @@ -786,10 +791,11 @@ const stringOneRun = `func (i %[1]s) String() string {
`

// Arguments to format are:
// [1]: type name
// [2]: lowest defined value for type, as a string
// [3]: size of index element (8 for uint8 etc.)
// [4]: less than zero check (for signed types)
//
// [1]: type name
// [2]: lowest defined value for type, as a string
// [3]: size of index element (8 for uint8 etc.)
// [4]: less than zero check (for signed types)
const stringOneRunWithOffset = `func (i %[1]s) String() string {
i -= %[2]s
if %[4]si >= %[1]s(len(_%[1]sIndex)-1) {
Expand Down
Loading