Skip to content

Commit 701fa38

Browse files
samiam2013bincyber
andcommitted
Add -typederrors flag for typed enum conversion errors
Co-authored-by: @bincyber <bincyber@users.noreply.github.com>
1 parent 750eb57 commit 701fa38

File tree

9 files changed

+232
-37
lines changed

9 files changed

+232
-37
lines changed

.github/workflows/go.yml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,3 +26,9 @@ jobs:
2626

2727
- name: Test
2828
run: go test -v ./...
29+
30+
- name: Run Golden Tests
31+
run: go test -v -run TestGolden
32+
33+
- name: Run End-to-End Tests
34+
run: go test -v -run TestEndToEnd

README.md

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,8 @@ Flags:
3737
transform each item name by removing a prefix or comma separated list of prefixes. Default: ""
3838
-type string
3939
comma-separated list of type names; must be set
40+
-typederrors
41+
if true, errors from errs/ will be errors.Join()-ed for errors.Is(...) to simplify invalid value handling. Default: false
4042
-values
4143
if true, alternative string values method will be generated. Default: false
4244
-yaml
@@ -70,6 +72,9 @@ When Enumer is applied to a type, it will generate:
7072
the enum conform to the `gopkg.in/yaml.v2.Marshaler` and `gopkg.in/yaml.v2.Unmarshaler` interfaces.
7173
- When the flag `sql` is provided, the methods for implementing the `Scanner` and `Valuer` interfaces.
7274
Useful when storing the enum in a database.
75+
- When the flag `typederrors` is provided, the string conversion functions will return errors wrapped with
76+
`errors.Join()` containing a typed error from the `errs` package. This allows you to use `errors.Is()` to
77+
check for specific enum validation failures.
7378

7479

7580
For example, if we have an enum type called `Pill`,
@@ -200,7 +205,7 @@ For a module-aware repo with `enumer` in the `go.mod` file, generation can be ca
200205
//go:generate go run github.com/dmarkham/enumer -type=YOURTYPE
201206
```
202207

203-
There are four boolean flags: `json`, `text`, `yaml` and `sql`. You can use any combination of them (i.e. `enumer -type=Pill -json -text`),
208+
There are five boolean flags: `json`, `text`, `yaml`, `sql`, and `typederrors`. You can use any combination of them (i.e. `enumer -type=Pill -json -text -typederrors`),
204209

205210
For enum string representation transformation the `transform` and `trimprefix` flags
206211
were added (i.e. `enumer -type=MyType -json -transform=snake`).
@@ -215,6 +220,28 @@ If a prefix is provided via the `addprefix` flag, it will be added to the start
215220

216221
The boolean flag `values` will additionally create an alternative string values method `Values() []string` to fullfill the `EnumValues` interface of [ent](https://entgo.io/docs/schema-fields/#enum-fields).
217222

223+
## Typed Error Handling
224+
225+
When using the `typederrors` flag, you can handle enum validation errors specifically using `errors.Is()`:
226+
227+
```go
228+
import (
229+
"errors"
230+
"github.com/dmarkham/enumer/errs"
231+
)
232+
233+
// This will return a typed error that can be checked
234+
pill, err := PillString("InvalidValue")
235+
if err != nil {
236+
if errors.Is(err, errs.ErrValueInvalid) {
237+
// Handle invalid enum value specifically
238+
fmt.Println("Invalid pill value provided")
239+
}
240+
// The error also contains a descriptive message
241+
fmt.Printf("Error: %v\n", err)
242+
}
243+
```
244+
218245
## Inspiring projects
219246

220247
- [Álvaro López Espinosa](https://github.com/alvaroloes/enumer)

endtoend_test.go

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
// go command is not available on android
66

7+
//go:build !android
78
// +build !android
89

910
package main
@@ -75,6 +76,7 @@ func TestEndToEnd(t *testing.T) {
7576
// Names are known to be ASCII and long enough.
7677
var typeName string
7778
var transformNameMethod string
79+
var useTypedErrors bool
7880

7981
switch name {
8082
case "transform_snake.go":
@@ -110,18 +112,22 @@ func TestEndToEnd(t *testing.T) {
110112
case "transform_whitespace.go":
111113
typeName = "WhitespaceSeparatedValue"
112114
transformNameMethod = "whitespace"
115+
case "typedErrors.go":
116+
typeName = "TypedErrorsValue"
117+
transformNameMethod = "noop"
118+
useTypedErrors = true
113119
default:
114120
typeName = fmt.Sprintf("%c%s", name[0]+'A'-'a', name[1:len(name)-len(".go")])
115121
transformNameMethod = "noop"
116122
}
117123

118-
stringerCompileAndRun(t, dir, stringer, typeName, name, transformNameMethod)
124+
stringerCompileAndRun(t, dir, stringer, typeName, name, transformNameMethod, useTypedErrors)
119125
}
120126
}
121127

122128
// stringerCompileAndRun runs stringer for the named file and compiles and
123129
// runs the target binary in directory dir. That binary will panic if the String method is incorrect.
124-
func stringerCompileAndRun(t *testing.T, dir, stringer, typeName, fileName, transformNameMethod string) {
130+
func stringerCompileAndRun(t *testing.T, dir, stringer, typeName, fileName, transformNameMethod string, useTypedErrors bool) {
125131
t.Logf("run: %s %s\n", fileName, typeName)
126132
source := filepath.Join(dir, fileName)
127133
err := copy(source, filepath.Join("testdata", fileName))
@@ -130,7 +136,12 @@ func stringerCompileAndRun(t *testing.T, dir, stringer, typeName, fileName, tran
130136
}
131137
stringSource := filepath.Join(dir, typeName+"_string.go")
132138
// Run stringer in temporary directory.
133-
err = run(stringer, "-type", typeName, "-output", stringSource, "-transform", transformNameMethod, source)
139+
args := []string{"-type", typeName, "-output", stringSource, "-transform", transformNameMethod}
140+
if useTypedErrors {
141+
args = append(args, "-typederrors", "-values")
142+
}
143+
args = append(args, source)
144+
err = run(stringer, args...)
134145
if err != nil {
135146
t.Fatal(err)
136147
}

enumer.go

Lines changed: 30 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,9 @@ package main
33
import "fmt"
44

55
// Arguments to format are:
6+
//
67
// [1]: type name
8+
// [2]: complete error expression
79
const stringNameToValueMethod = `// %[1]sString retrieves an enum value from the enum constants string name.
810
// Throws an error if the param is not part of the enum.
911
func %[1]sString(s string) (%[1]s, error) {
@@ -14,11 +16,12 @@ func %[1]sString(s string) (%[1]s, error) {
1416
if val, ok := _%[1]sNameToValueMap[strings.ToLower(s)]; ok {
1517
return val, nil
1618
}
17-
return 0, fmt.Errorf("%%s does not belong to %[1]s values", s)
19+
return 0, %[2]s
1820
}
1921
`
2022

2123
// Arguments to format are:
24+
//
2225
// [1]: type name
2326
const stringValuesMethod = `// %[1]sValues returns all values of the enum
2427
func %[1]sValues() []%[1]s {
@@ -27,6 +30,7 @@ func %[1]sValues() []%[1]s {
2730
`
2831

2932
// Arguments to format are:
33+
//
3034
// [1]: type name
3135
const stringsMethod = `// %[1]sStrings returns a slice of all String values of the enum
3236
func %[1]sStrings() []string {
@@ -37,6 +41,7 @@ func %[1]sStrings() []string {
3741
`
3842

3943
// Arguments to format are:
44+
//
4045
// [1]: type name
4146
const stringBelongsMethodLoop = `// IsA%[1]s returns "true" if the value is listed in the enum definition. "false" otherwise
4247
func (i %[1]s) IsA%[1]s() bool {
@@ -50,6 +55,7 @@ func (i %[1]s) IsA%[1]s() bool {
5055
`
5156

5257
// Arguments to format are:
58+
//
5359
// [1]: type name
5460
const stringBelongsMethodSet = `// IsA%[1]s returns "true" if the value is listed in the enum definition. "false" otherwise
5561
func (i %[1]s) IsA%[1]s() bool {
@@ -59,6 +65,7 @@ func (i %[1]s) IsA%[1]s() bool {
5965
`
6066

6167
// Arguments to format are:
68+
//
6269
// [1]: type name
6370
const altStringValuesMethod = `func (%[1]s) Values() []string {
6471
return %[1]sStrings()
@@ -70,7 +77,9 @@ func (g *Generator) buildAltStringValuesMethod(typeName string) {
7077
g.Printf(altStringValuesMethod, typeName)
7178
}
7279

73-
func (g *Generator) buildBasicExtras(runs [][]Value, typeName string, runsThreshold int) {
80+
var ErrInvalidValue = fmt.Errorf("invalid enumer value")
81+
82+
func (g *Generator) buildBasicExtras(runs [][]Value, typeName string, runsThreshold int, useTypedErrors bool) {
7483
// At this moment, either "g.declareIndexAndNameVars()" or "g.declareNameVars()" has been called
7584

7685
// Print the slice of values
@@ -89,7 +98,13 @@ func (g *Generator) buildBasicExtras(runs [][]Value, typeName string, runsThresh
8998
g.printNamesSlice(runs, typeName, runsThreshold)
9099

91100
// Print the basic extra methods
92-
g.Printf(stringNameToValueMethod, typeName)
101+
var errorCode string
102+
if useTypedErrors {
103+
errorCode = fmt.Sprintf(`errors.Join(enumerrs.ErrValueInvalid, fmt.Errorf("%%s does not belong to %s values", s))`, typeName)
104+
} else {
105+
errorCode = fmt.Sprintf(`fmt.Errorf("%%s does not belong to %s values", s)`, typeName)
106+
}
107+
g.Printf(stringNameToValueMethod, typeName, errorCode)
93108
g.Printf(stringValuesMethod, typeName)
94109
g.Printf(stringsMethod, typeName)
95110
if len(runs) <= runsThreshold {
@@ -144,6 +159,7 @@ func (g *Generator) printNamesSlice(runs [][]Value, typeName string, runsThresho
144159
}
145160

146161
// Arguments to format are:
162+
//
147163
// [1]: type name
148164
const jsonMethods = `
149165
// MarshalJSON implements the json.Marshaler interface for %[1]s
@@ -164,11 +180,14 @@ func (i *%[1]s) UnmarshalJSON(data []byte) error {
164180
}
165181
`
166182

167-
func (g *Generator) buildJSONMethods(runs [][]Value, typeName string, runsThreshold int) {
183+
func (g *Generator) buildJSONMethods(runs [][]Value, typeName string, runsThreshold int, useTypedErrors bool) {
184+
// For now, just use the standard template
185+
// We rely on the %[1]sString method to provide typed errors when enabled
168186
g.Printf(jsonMethods, typeName)
169187
}
170188

171189
// Arguments to format are:
190+
//
172191
// [1]: type name
173192
const textMethods = `
174193
// MarshalText implements the encoding.TextMarshaler interface for %[1]s
@@ -184,11 +203,14 @@ func (i *%[1]s) UnmarshalText(text []byte) error {
184203
}
185204
`
186205

187-
func (g *Generator) buildTextMethods(runs [][]Value, typeName string, runsThreshold int) {
206+
func (g *Generator) buildTextMethods(runs [][]Value, typeName string, runsThreshold int, useTypedErrors bool) {
207+
// For now, just use the standard template
208+
// We rely on the %[1]sString method to provide typed errors when enabled
188209
g.Printf(textMethods, typeName)
189210
}
190211

191212
// Arguments to format are:
213+
//
192214
// [1]: type name
193215
const yamlMethods = `
194216
// MarshalYAML implements a YAML Marshaler for %[1]s
@@ -209,6 +231,8 @@ func (i *%[1]s) UnmarshalYAML(unmarshal func(interface{}) error) error {
209231
}
210232
`
211233

212-
func (g *Generator) buildYAMLMethods(runs [][]Value, typeName string, runsThreshold int) {
234+
func (g *Generator) buildYAMLMethods(runs [][]Value, typeName string, runsThreshold int, useTypedErrors bool) {
235+
// For now, just use the standard template
236+
// We rely on the %[1]sString method to provide typed errors when enabled
213237
g.Printf(yamlMethods, typeName)
214238
}

enumerrs/errors.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
package enumerrs
2+
3+
import "errors"
4+
5+
// This package defines custom error types for use in the generated code.
6+
7+
// ErrValueInvalid is returned when a value does not belong to the set of valid values for a type.
8+
var ErrValueInvalid = errors.New("the input value is not valid for the type")

golden_test.go

Lines changed: 36 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,10 @@ var goldenLinecomment = []Golden{
7676
{"dayWithLinecomment", linecommentIn},
7777
}
7878

79+
var goldenTypedErrors = []Golden{
80+
{"typedErrors", typedErrorsIn},
81+
}
82+
7983
// Each example starts with "type XXX [u]int", with a single space separating them.
8084

8185
// Simple test: enumeration of type int starting at 0.
@@ -313,54 +317,65 @@ const (
313317
)
314318
`
315319

320+
const typedErrorsIn = `type TypedErrorsValue int
321+
const (
322+
TypedErrorsValueOne TypedErrorsValue = iota
323+
TypedErrorsValueTwo
324+
TypedErrorsValueThree
325+
)
326+
`
327+
316328
func TestGolden(t *testing.T) {
317329
for _, test := range golden {
318-
runGoldenTest(t, test, false, false, false, false, false, false, true, "", "")
330+
runGoldenTest(t, test, false, false, false, false, false, false, true, "", "", false)
319331
}
320332
for _, test := range goldenJSON {
321-
runGoldenTest(t, test, true, false, false, false, false, false, false, "", "")
333+
runGoldenTest(t, test, true, false, false, false, false, false, false, "", "", false)
322334
}
323335
for _, test := range goldenText {
324-
runGoldenTest(t, test, false, false, false, true, false, false, false, "", "")
336+
runGoldenTest(t, test, false, false, false, true, false, false, false, "", "", false)
325337
}
326338
for _, test := range goldenYAML {
327-
runGoldenTest(t, test, false, true, false, false, false, false, false, "", "")
339+
runGoldenTest(t, test, false, true, false, false, false, false, false, "", "", false)
328340
}
329341
for _, test := range goldenSQL {
330-
runGoldenTest(t, test, false, false, true, false, false, false, false, "", "")
342+
runGoldenTest(t, test, false, false, true, false, false, false, false, "", "", false)
331343
}
332344
for _, test := range goldenJSONAndSQL {
333-
runGoldenTest(t, test, true, false, true, false, false, false, false, "", "")
345+
runGoldenTest(t, test, true, false, true, false, false, false, false, "", "", false)
334346
}
335347
for _, test := range goldenGQLGen {
336-
runGoldenTest(t, test, false, false, false, false, false, true, false, "", "")
348+
runGoldenTest(t, test, false, false, false, false, false, true, false, "", "", false)
337349
}
338350
for _, test := range goldenTrimPrefix {
339-
runGoldenTest(t, test, false, false, false, false, false, false, false, "Day", "")
351+
runGoldenTest(t, test, false, false, false, false, false, false, false, "Day", "", false)
340352
}
341353
for _, test := range goldenTrimPrefixMultiple {
342-
runGoldenTest(t, test, false, false, false, false, false, false, false, "Day,Night", "")
354+
runGoldenTest(t, test, false, false, false, false, false, false, false, "Day,Night", "", false)
343355
}
344356
for _, test := range goldenWithPrefix {
345-
runGoldenTest(t, test, false, false, false, false, false, false, false, "", "Day")
357+
runGoldenTest(t, test, false, false, false, false, false, false, false, "", "Day", false)
346358
}
347359
for _, test := range goldenTrimAndAddPrefix {
348-
runGoldenTest(t, test, false, false, false, false, false, false, false, "Day", "Night")
360+
runGoldenTest(t, test, false, false, false, false, false, false, false, "Day", "Night", false)
349361
}
350362
for _, test := range goldenLinecomment {
351-
runGoldenTest(t, test, false, false, false, false, true, false, false, "", "")
363+
runGoldenTest(t, test, false, false, false, false, true, false, false, "", "", false)
364+
}
365+
for _, test := range goldenTypedErrors {
366+
runGoldenTest(t, test, false, false, false, false, false, false, false, "", "", true)
352367
}
353368
}
354369

355370
func runGoldenTest(t *testing.T, test Golden,
356371
generateJSON, generateYAML, generateSQL, generateText, linecomment, generateGQLGen, generateValuesMethod bool,
357-
trimPrefix string, prefix string) {
372+
trimPrefix string, prefix string, useTypedErrors bool) {
358373

359374
var g Generator
360375
file := test.name + ".go"
361376
input := "package test\n" + test.input
362377

363-
dir, err := ioutil.TempDir("", "stringer")
378+
dir, err := os.MkdirTemp("", "stringer")
364379
if err != nil {
365380
t.Error(err)
366381
}
@@ -372,7 +387,7 @@ func runGoldenTest(t *testing.T, test Golden,
372387
}()
373388

374389
absFile := filepath.Join(dir, file)
375-
err = ioutil.WriteFile(absFile, []byte(input), 0644)
390+
err = os.WriteFile(absFile, []byte(input), 0644)
376391
if err != nil {
377392
t.Error(err)
378393
}
@@ -382,15 +397,15 @@ func runGoldenTest(t *testing.T, test Golden,
382397
if len(tokens) != 3 {
383398
t.Fatalf("%s: need type declaration on first line", test.name)
384399
}
385-
g.generate(tokens[1], generateJSON, generateYAML, generateSQL, generateText, generateGQLGen, "noop", trimPrefix, prefix, linecomment, generateValuesMethod)
400+
g.generate(tokens[1], generateJSON, generateYAML, generateSQL, generateText, generateGQLGen, "noop", trimPrefix, prefix, linecomment, generateValuesMethod, useTypedErrors)
386401
got := string(g.format())
387402
if got != loadGolden(test.name) {
388403
// Use this to help build a golden text when changes are needed
389-
//goldenFile := fmt.Sprintf("./testdata/%v.golden", test.name)
390-
//err = ioutil.WriteFile(goldenFile, []byte(got), 0644)
391-
//if err != nil {
392-
// t.Error(err)
393-
//}
404+
// goldenFile := fmt.Sprintf("./testdata/%v.golden", test.name)
405+
// err = os.WriteFile(goldenFile, []byte(got), 0644)
406+
// if err != nil {
407+
// t.Error(err)
408+
// }
394409
t.Errorf("%s: got\n====\n%s====\nexpected\n====%s", test.name, got, loadGolden(test.name))
395410
}
396411
}

0 commit comments

Comments
 (0)