Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -28,4 +28,4 @@ coverage.txt
.idea
.vscode

enumer
./enumer
4 changes: 3 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

Enumer is a tool to generate Go code that adds useful methods to Go enums (constants with a specific type).
It started as a fork of [Rob Pike’s Stringer tool](https://godoc.org/golang.org/x/tools/cmd/stringer)
maintained by [Álvaro López Espinosa](https://github.com/alvaroloes/enumer).
maintained by [Álvaro López Espinosa](https://github.com/alvaroloes/enumer).
This was again forked here as (https://github.com/dmarkham/enumer) picking up where Álvaro left off.


Expand Down Expand Up @@ -41,6 +41,8 @@ Flags:
if true, alternative string values method will be generated. Default: false
-yaml
if true, yaml marshaling methods will be generated. Default: false
-customerror
if true, a custom error will be returned by the `<Type>String` function. Default: false
```


Expand Down
4 changes: 2 additions & 2 deletions endtoend_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

// go command is not available on android

//go:build !android
// +build !android

package main
Expand All @@ -12,7 +13,6 @@ import (
"fmt"
"go/build"
"io"
"io/ioutil"
"os"
"os/exec"
"path/filepath"
Expand Down Expand Up @@ -40,7 +40,7 @@ func init() {
// binary panics if the String method for X is not correct, including for error cases.

func TestEndToEnd(t *testing.T) {
dir, err := ioutil.TempDir("", "stringer")
dir, err := os.MkdirTemp("", "stringer")
if err != nil {
t.Fatal(err)
}
Expand Down
24 changes: 20 additions & 4 deletions enumer.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
package main

import "fmt"
import (
"fmt"
)

// Arguments to format are:
//
// [1]: type name
const stringNameToValueMethod = `// %[1]sString retrieves an enum value from the enum constants string name.
// Throws an error if the param is not part of the enum.
Expand All @@ -14,11 +17,12 @@ func %[1]sString(s string) (%[1]s, error) {
if val, ok := _%[1]sNameToValueMap[strings.ToLower(s)]; ok {
return val, nil
}
return 0, fmt.Errorf("%%s does not belong to %[1]s values", s)
return 0, %[2]s
}
`

// Arguments to format are:
//
// [1]: type name
const stringValuesMethod = `// %[1]sValues returns all values of the enum
func %[1]sValues() []%[1]s {
Expand All @@ -27,6 +31,7 @@ func %[1]sValues() []%[1]s {
`

// Arguments to format are:
//
// [1]: type name
const stringsMethod = `// %[1]sStrings returns a slice of all String values of the enum
func %[1]sStrings() []string {
Expand All @@ -37,6 +42,7 @@ func %[1]sStrings() []string {
`

// Arguments to format are:
//
// [1]: type name
const stringBelongsMethodLoop = `// IsA%[1]s returns "true" if the value is listed in the enum definition. "false" otherwise
func (i %[1]s) IsA%[1]s() bool {
Expand All @@ -50,6 +56,7 @@ func (i %[1]s) IsA%[1]s() bool {
`

// Arguments to format are:
//
// [1]: type name
const stringBelongsMethodSet = `// IsA%[1]s returns "true" if the value is listed in the enum definition. "false" otherwise
func (i %[1]s) IsA%[1]s() bool {
Expand All @@ -59,6 +66,7 @@ func (i %[1]s) IsA%[1]s() bool {
`

// Arguments to format are:
//
// [1]: type name
const altStringValuesMethod = `func (%[1]s) Values() []string {
return %[1]sStrings()
Expand All @@ -70,7 +78,7 @@ func (g *Generator) buildAltStringValuesMethod(typeName string) {
g.Printf(altStringValuesMethod, typeName)
}

func (g *Generator) buildBasicExtras(runs [][]Value, typeName string, runsThreshold int) {
func (g *Generator) buildBasicExtras(runs [][]Value, typeName string, runsThreshold int, customError bool) {
// At this moment, either "g.declareIndexAndNameVars()" or "g.declareNameVars()" has been called

// Print the slice of values
Expand All @@ -89,7 +97,12 @@ func (g *Generator) buildBasicExtras(runs [][]Value, typeName string, runsThresh
g.printNamesSlice(runs, typeName, runsThreshold)

// Print the basic extra methods
g.Printf(stringNameToValueMethod, typeName)
stringNameToValueErr := fmt.Sprintf(`fmt.Errorf("%%s does not belong to %s values", s)`, typeName)
if customError {
stringNameToValueErr = `enumer.InvalidEnumValueError`
}
g.Printf(stringNameToValueMethod, typeName, stringNameToValueErr)

g.Printf(stringValuesMethod, typeName)
g.Printf(stringsMethod, typeName)
if len(runs) <= runsThreshold {
Expand Down Expand Up @@ -144,6 +157,7 @@ func (g *Generator) printNamesSlice(runs [][]Value, typeName string, runsThresho
}

// Arguments to format are:
//
// [1]: type name
const jsonMethods = `
// MarshalJSON implements the json.Marshaler interface for %[1]s
Expand All @@ -169,6 +183,7 @@ func (g *Generator) buildJSONMethods(runs [][]Value, typeName string, runsThresh
}

// Arguments to format are:
//
// [1]: type name
const textMethods = `
// MarshalText implements the encoding.TextMarshaler interface for %[1]s
Expand All @@ -189,6 +204,7 @@ func (g *Generator) buildTextMethods(runs [][]Value, typeName string, runsThresh
}

// Arguments to format are:
//
// [1]: type name
const yamlMethods = `
// MarshalYAML implements a YAML Marshaler for %[1]s
Expand Down
49 changes: 25 additions & 24 deletions golden_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
package main

import (
"io/ioutil"
"os"
"path/filepath"
"strings"
Expand Down Expand Up @@ -68,6 +67,10 @@ var goldenWithPrefix = []Golden{
{"dayWithPrefix", dayIn},
}

var goldenWithCustomError = []Golden{
{"dayWithCustomError", dayIn},
}

var goldenTrimAndAddPrefix = []Golden{
{"dayTrimAndPrefix", trimPrefixIn},
}
Expand Down Expand Up @@ -315,52 +318,55 @@ const (

func TestGolden(t *testing.T) {
for _, test := range golden {
runGoldenTest(t, test, false, false, false, false, false, false, true, "", "")
runGoldenTest(t, test, false, false, false, false, false, false, true, false, "", "")
}
for _, test := range goldenJSON {
runGoldenTest(t, test, true, false, false, false, false, false, false, "", "")
runGoldenTest(t, test, true, false, false, false, false, false, false, false, "", "")
}
for _, test := range goldenText {
runGoldenTest(t, test, false, false, false, true, false, false, false, "", "")
runGoldenTest(t, test, false, false, false, true, false, false, false, false, "", "")
}
for _, test := range goldenYAML {
runGoldenTest(t, test, false, true, false, false, false, false, false, "", "")
runGoldenTest(t, test, false, true, false, false, false, false, false, false, "", "")
}
for _, test := range goldenSQL {
runGoldenTest(t, test, false, false, true, false, false, false, false, "", "")
runGoldenTest(t, test, false, false, true, false, false, false, false, false, "", "")
}
for _, test := range goldenJSONAndSQL {
runGoldenTest(t, test, true, false, true, false, false, false, false, "", "")
runGoldenTest(t, test, true, false, true, false, false, false, false, false, "", "")
}
for _, test := range goldenGQLGen {
runGoldenTest(t, test, false, false, false, false, false, true, false, "", "")
runGoldenTest(t, test, false, false, false, false, false, true, false, false, "", "")
}
for _, test := range goldenTrimPrefix {
runGoldenTest(t, test, false, false, false, false, false, false, false, "Day", "")
runGoldenTest(t, test, false, false, false, false, false, false, false, false, "Day", "")
}
for _, test := range goldenTrimPrefixMultiple {
runGoldenTest(t, test, false, false, false, false, false, false, false, "Day,Night", "")
runGoldenTest(t, test, false, false, false, false, false, false, false, false, "Day,Night", "")
}
for _, test := range goldenWithPrefix {
runGoldenTest(t, test, false, false, false, false, false, false, false, "", "Day")
runGoldenTest(t, test, false, false, false, false, false, false, false, false, "", "Day")
}
for _, test := range goldenWithCustomError {
runGoldenTest(t, test, false, false, false, false, false, false, false, true, "", "Day")
}
for _, test := range goldenTrimAndAddPrefix {
runGoldenTest(t, test, false, false, false, false, false, false, false, "Day", "Night")
runGoldenTest(t, test, false, false, false, false, false, false, false, false, "Day", "Night")
}
for _, test := range goldenLinecomment {
runGoldenTest(t, test, false, false, false, false, true, false, false, "", "")
runGoldenTest(t, test, false, false, false, false, true, false, false, false, "", "")
}
}

func runGoldenTest(t *testing.T, test Golden,
generateJSON, generateYAML, generateSQL, generateText, linecomment, generateGQLGen, generateValuesMethod bool,
generateJSON, generateYAML, generateSQL, generateText, linecomment, generateGQLGen, generateValuesMethod bool, generateCustomError bool,
trimPrefix string, prefix string) {

var g Generator
file := test.name + ".go"
input := "package test\n" + test.input

dir, err := ioutil.TempDir("", "stringer")
dir, err := os.MkdirTemp("", "stringer")
if err != nil {
t.Error(err)
}
Expand All @@ -372,7 +378,7 @@ func runGoldenTest(t *testing.T, test Golden,
}()

absFile := filepath.Join(dir, file)
err = ioutil.WriteFile(absFile, []byte(input), 0644)
err = os.WriteFile(absFile, []byte(input), 0644)
if err != nil {
t.Error(err)
}
Expand All @@ -382,12 +388,12 @@ func runGoldenTest(t *testing.T, test Golden,
if len(tokens) != 3 {
t.Fatalf("%s: need type declaration on first line", test.name)
}
g.generate(tokens[1], generateJSON, generateYAML, generateSQL, generateText, generateGQLGen, "noop", trimPrefix, prefix, linecomment, generateValuesMethod)
g.generate(tokens[1], generateJSON, generateYAML, generateSQL, generateText, generateGQLGen, "noop", trimPrefix, prefix, linecomment, generateCustomError, generateValuesMethod)
got := string(g.format())
if got != loadGolden(test.name) {
// Use this to help build a golden text when changes are needed
//goldenFile := fmt.Sprintf("./testdata/%v.golden", test.name)
//err = ioutil.WriteFile(goldenFile, []byte(got), 0644)
//err = os.WriteFile(goldenFile, []byte(got), 0644)
//if err != nil {
// t.Error(err)
//}
Expand All @@ -396,12 +402,7 @@ func runGoldenTest(t *testing.T, test Golden,
}

func loadGolden(name string) string {
fh, err := os.Open("testdata/" + name + ".golden")
if err != nil {
return ""
}
defer fh.Close()
b, err := ioutil.ReadAll(fh)
b, err := os.ReadFile("testdata/" + name + ".golden")
if err != nil {
return ""
}
Expand Down
5 changes: 5 additions & 0 deletions pkg/enumer/enumer.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
package enumer

import "errors"

var InvalidEnumValueError = errors.New("invalid enum value")
20 changes: 11 additions & 9 deletions stringer.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ import (
"go/importer"
"go/token"
"go/types"
"io/ioutil"
"log"
"os"
"path/filepath"
Expand Down Expand Up @@ -56,6 +55,7 @@ var (
trimPrefix = flag.String("trimprefix", "", "transform each item name by removing a prefix or comma separated list of prefixes. Default: \"\"")
addPrefix = flag.String("addprefix", "", "transform each item name by adding a prefix. Default: \"\"")
linecomment = flag.Bool("linecomment", false, "use line comment text as printed text when present")
customError = flag.Bool("customerror", false, "if true, a custom error will be returned by the `<Type>String` function. Default: false")
)

var comments arrayFlags
Expand Down Expand Up @@ -131,11 +131,15 @@ func main() {
g.Printf("\t\"io\"\n")
g.Printf("\t\"strconv\"\n")
}
if *customError {
g.Printf("\n\t\"github.com/dmarkham/enumer/pkg/enumer\"\n")
}

g.Printf(")\n")

// Run generate for each type.
for _, typeName := range typs {
g.generate(typeName, *json, *yaml, *sql, *text, *gqlgen, *transformMethod, *trimPrefix, *addPrefix, *linecomment, *altValuesFunc)
g.generate(typeName, *json, *yaml, *sql, *text, *gqlgen, *transformMethod, *trimPrefix, *addPrefix, *linecomment, *customError, *altValuesFunc)
}

// Format the output.
Expand All @@ -149,21 +153,19 @@ func main() {
}

// Write to tmpfile first
tmpFile, err := ioutil.TempFile(dir, fmt.Sprintf("%s_enumer_", typs[0]))
tmpFile, err := os.CreateTemp(dir, fmt.Sprintf("%s_enumer_", typs[0]))
if err != nil {
log.Fatalf("creating temporary file for output: %s", err)
}
_, err = tmpFile.Write(src)
if err != nil {
if _, err = tmpFile.Write(src); err != nil {
tmpFile.Close()
os.Remove(tmpFile.Name())
log.Fatalf("writing output: %s", err)
}
tmpFile.Close()

// Rename tmpfile to output file
err = os.Rename(tmpFile.Name(), outputName)
if err != nil {
if err := os.Rename(tmpFile.Name(), outputName); err != nil {
log.Fatalf("moving tempfile to output file: %s", err)
}
}
Expand Down Expand Up @@ -415,7 +417,7 @@ func (g *Generator) prefixValueNames(values []Value, prefix string) {
// generate produces the String method for the named type.
func (g *Generator) generate(typeName string,
includeJSON, includeYAML, includeSQL, includeText, includeGQLGen bool,
transformMethod string, trimPrefix string, addPrefix string, lineComment bool, includeValuesMethod bool) {
transformMethod string, trimPrefix string, addPrefix string, lineComment bool, customError bool, includeValuesMethod bool) {
values := make([]Value, 0, 100)
for _, file := range g.pkg.files {
file.lineComment = lineComment
Expand Down Expand Up @@ -468,7 +470,7 @@ func (g *Generator) generate(typeName string,

g.buildNoOpOrderChangeDetect(runs, typeName)

g.buildBasicExtras(runs, typeName, runsThreshold)
g.buildBasicExtras(runs, typeName, runsThreshold, customError)
if includeJSON {
g.buildJSONMethods(runs, typeName, runsThreshold)
}
Expand Down
Loading