Skip to content

Commit 0bba577

Browse files
samiam2013bincyber
andcommitted
Add -typederrors flag for typed enum conversion errors
Co-authored-by: @bincyber <bincyber@users.noreply.github.com>
1 parent 750eb57 commit 0bba577

File tree

5 files changed

+78
-14
lines changed

5 files changed

+78
-14
lines changed

README.md

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,8 @@ Flags:
3737
transform each item name by removing a prefix or comma separated list of prefixes. Default: ""
3838
-type string
3939
comma-separated list of type names; must be set
40+
-typederrors
41+
if true, errors from errs/ will be errors.Join()-ed for errors.Is(...) to simplify invalid value handling. Default: false
4042
-values
4143
if true, alternative string values method will be generated. Default: false
4244
-yaml
@@ -70,6 +72,9 @@ When Enumer is applied to a type, it will generate:
7072
the enum conform to the `gopkg.in/yaml.v2.Marshaler` and `gopkg.in/yaml.v2.Unmarshaler` interfaces.
7173
- When the flag `sql` is provided, the methods for implementing the `Scanner` and `Valuer` interfaces.
7274
Useful when storing the enum in a database.
75+
- When the flag `typederrors` is provided, the string conversion functions will return errors wrapped with
76+
`errors.Join()` containing a typed error from the `errs` package. This allows you to use `errors.Is()` to
77+
check for specific enum validation failures.
7378

7479

7580
For example, if we have an enum type called `Pill`,
@@ -200,7 +205,7 @@ For a module-aware repo with `enumer` in the `go.mod` file, generation can be ca
200205
//go:generate go run github.com/dmarkham/enumer -type=YOURTYPE
201206
```
202207

203-
There are four boolean flags: `json`, `text`, `yaml` and `sql`. You can use any combination of them (i.e. `enumer -type=Pill -json -text`),
208+
There are five boolean flags: `json`, `text`, `yaml`, `sql`, and `typederrors`. You can use any combination of them (i.e. `enumer -type=Pill -json -text -typederrors`),
204209

205210
For enum string representation transformation the `transform` and `trimprefix` flags
206211
were added (i.e. `enumer -type=MyType -json -transform=snake`).
@@ -215,6 +220,28 @@ If a prefix is provided via the `addprefix` flag, it will be added to the start
215220

216221
The boolean flag `values` will additionally create an alternative string values method `Values() []string` to fullfill the `EnumValues` interface of [ent](https://entgo.io/docs/schema-fields/#enum-fields).
217222

223+
## Typed Error Handling
224+
225+
When using the `typederrors` flag, you can handle enum validation errors specifically using `errors.Is()`:
226+
227+
```go
228+
import (
229+
"errors"
230+
"github.com/dmarkham/enumer/errs"
231+
)
232+
233+
// This will return a typed error that can be checked
234+
pill, err := PillString("InvalidValue")
235+
if err != nil {
236+
if errors.Is(err, errs.ErrValueInvalid) {
237+
// Handle invalid enum value specifically
238+
fmt.Println("Invalid pill value provided")
239+
}
240+
// The error also contains a descriptive message
241+
fmt.Printf("Error: %v\n", err)
242+
}
243+
```
244+
218245
## Inspiring projects
219246

220247
- [Álvaro López Espinosa](https://github.com/alvaroloes/enumer)

enumer.go

Lines changed: 30 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,9 @@ package main
33
import "fmt"
44

55
// Arguments to format are:
6+
//
67
// [1]: type name
8+
// [2]: complete error expression
79
const stringNameToValueMethod = `// %[1]sString retrieves an enum value from the enum constants string name.
810
// Throws an error if the param is not part of the enum.
911
func %[1]sString(s string) (%[1]s, error) {
@@ -14,11 +16,12 @@ func %[1]sString(s string) (%[1]s, error) {
1416
if val, ok := _%[1]sNameToValueMap[strings.ToLower(s)]; ok {
1517
return val, nil
1618
}
17-
return 0, fmt.Errorf("%%s does not belong to %[1]s values", s)
19+
return 0, %[2]s
1820
}
1921
`
2022

2123
// Arguments to format are:
24+
//
2225
// [1]: type name
2326
const stringValuesMethod = `// %[1]sValues returns all values of the enum
2427
func %[1]sValues() []%[1]s {
@@ -27,6 +30,7 @@ func %[1]sValues() []%[1]s {
2730
`
2831

2932
// Arguments to format are:
33+
//
3034
// [1]: type name
3135
const stringsMethod = `// %[1]sStrings returns a slice of all String values of the enum
3236
func %[1]sStrings() []string {
@@ -37,6 +41,7 @@ func %[1]sStrings() []string {
3741
`
3842

3943
// Arguments to format are:
44+
//
4045
// [1]: type name
4146
const stringBelongsMethodLoop = `// IsA%[1]s returns "true" if the value is listed in the enum definition. "false" otherwise
4247
func (i %[1]s) IsA%[1]s() bool {
@@ -50,6 +55,7 @@ func (i %[1]s) IsA%[1]s() bool {
5055
`
5156

5257
// Arguments to format are:
58+
//
5359
// [1]: type name
5460
const stringBelongsMethodSet = `// IsA%[1]s returns "true" if the value is listed in the enum definition. "false" otherwise
5561
func (i %[1]s) IsA%[1]s() bool {
@@ -59,6 +65,7 @@ func (i %[1]s) IsA%[1]s() bool {
5965
`
6066

6167
// Arguments to format are:
68+
//
6269
// [1]: type name
6370
const altStringValuesMethod = `func (%[1]s) Values() []string {
6471
return %[1]sStrings()
@@ -70,7 +77,9 @@ func (g *Generator) buildAltStringValuesMethod(typeName string) {
7077
g.Printf(altStringValuesMethod, typeName)
7178
}
7279

73-
func (g *Generator) buildBasicExtras(runs [][]Value, typeName string, runsThreshold int) {
80+
var ErrInvalidValue = fmt.Errorf("invalid enumer value")
81+
82+
func (g *Generator) buildBasicExtras(runs [][]Value, typeName string, runsThreshold int, useTypedErrors bool) {
7483
// At this moment, either "g.declareIndexAndNameVars()" or "g.declareNameVars()" has been called
7584

7685
// Print the slice of values
@@ -89,7 +98,13 @@ func (g *Generator) buildBasicExtras(runs [][]Value, typeName string, runsThresh
8998
g.printNamesSlice(runs, typeName, runsThreshold)
9099

91100
// Print the basic extra methods
92-
g.Printf(stringNameToValueMethod, typeName)
101+
var errorCode string
102+
if useTypedErrors {
103+
errorCode = fmt.Sprintf(`errors.Join(errs.ErrValueInvalid, fmt.Errorf("%%s does not belong to %s values", s))`, typeName)
104+
} else {
105+
errorCode = fmt.Sprintf(`fmt.Errorf("%%s does not belong to %s values", s)`, typeName)
106+
}
107+
g.Printf(stringNameToValueMethod, typeName, errorCode)
93108
g.Printf(stringValuesMethod, typeName)
94109
g.Printf(stringsMethod, typeName)
95110
if len(runs) <= runsThreshold {
@@ -144,6 +159,7 @@ func (g *Generator) printNamesSlice(runs [][]Value, typeName string, runsThresho
144159
}
145160

146161
// Arguments to format are:
162+
//
147163
// [1]: type name
148164
const jsonMethods = `
149165
// MarshalJSON implements the json.Marshaler interface for %[1]s
@@ -164,11 +180,14 @@ func (i *%[1]s) UnmarshalJSON(data []byte) error {
164180
}
165181
`
166182

167-
func (g *Generator) buildJSONMethods(runs [][]Value, typeName string, runsThreshold int) {
183+
func (g *Generator) buildJSONMethods(runs [][]Value, typeName string, runsThreshold int, useTypedErrors bool) {
184+
// For now, just use the standard template
185+
// We rely on the %[1]sString method to provide typed errors when enabled
168186
g.Printf(jsonMethods, typeName)
169187
}
170188

171189
// Arguments to format are:
190+
//
172191
// [1]: type name
173192
const textMethods = `
174193
// MarshalText implements the encoding.TextMarshaler interface for %[1]s
@@ -184,11 +203,14 @@ func (i *%[1]s) UnmarshalText(text []byte) error {
184203
}
185204
`
186205

187-
func (g *Generator) buildTextMethods(runs [][]Value, typeName string, runsThreshold int) {
206+
func (g *Generator) buildTextMethods(runs [][]Value, typeName string, runsThreshold int, useTypedErrors bool) {
207+
// For now, just use the standard template
208+
// We rely on the %[1]sString method to provide typed errors when enabled
188209
g.Printf(textMethods, typeName)
189210
}
190211

191212
// Arguments to format are:
213+
//
192214
// [1]: type name
193215
const yamlMethods = `
194216
// MarshalYAML implements a YAML Marshaler for %[1]s
@@ -209,6 +231,8 @@ func (i *%[1]s) UnmarshalYAML(unmarshal func(interface{}) error) error {
209231
}
210232
`
211233

212-
func (g *Generator) buildYAMLMethods(runs [][]Value, typeName string, runsThreshold int) {
234+
func (g *Generator) buildYAMLMethods(runs [][]Value, typeName string, runsThreshold int, useTypedErrors bool) {
235+
// For now, just use the standard template
236+
// We rely on the %[1]sString method to provide typed errors when enabled
213237
g.Printf(yamlMethods, typeName)
214238
}

errs/errs.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
package errs
2+
3+
import "errors"
4+
5+
// This package defines custom error types for use in the generated code.
6+
7+
// ErrValueInvalid is returned when a value does not belong to the set of valid values for a type.
8+
var ErrValueInvalid = errors.New("the input value is not valid for the type")

golden_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -382,7 +382,7 @@ func runGoldenTest(t *testing.T, test Golden,
382382
if len(tokens) != 3 {
383383
t.Fatalf("%s: need type declaration on first line", test.name)
384384
}
385-
g.generate(tokens[1], generateJSON, generateYAML, generateSQL, generateText, generateGQLGen, "noop", trimPrefix, prefix, linecomment, generateValuesMethod)
385+
g.generate(tokens[1], generateJSON, generateYAML, generateSQL, generateText, generateGQLGen, "noop", trimPrefix, prefix, linecomment, generateValuesMethod, false)
386386
got := string(g.format())
387387
if got != loadGolden(test.name) {
388388
// Use this to help build a golden text when changes are needed

stringer.go

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ var (
5656
trimPrefix = flag.String("trimprefix", "", "transform each item name by removing a prefix or comma separated list of prefixes. Default: \"\"")
5757
addPrefix = flag.String("addprefix", "", "transform each item name by adding a prefix. Default: \"\"")
5858
linecomment = flag.Bool("linecomment", false, "use line comment text as printed text when present")
59+
typedErrors = flag.Bool("typederrors", false, "if true, use typed errors for enum string conversion methods. Default: false")
5960
)
6061

6162
var comments arrayFlags
@@ -119,6 +120,10 @@ func main() {
119120
g.Printf("package %s", g.pkg.name)
120121
g.Printf("\n")
121122
g.Printf("import (\n")
123+
if *typedErrors {
124+
g.Printf("\t\"errors\"\n")
125+
g.Printf("\t\"github.com/dmarkham/enumer/errs\"\n")
126+
}
122127
g.Printf("\t\"fmt\"\n")
123128
g.Printf("\t\"strings\"\n")
124129
if *sql {
@@ -135,7 +140,7 @@ func main() {
135140

136141
// Run generate for each type.
137142
for _, typeName := range typs {
138-
g.generate(typeName, *json, *yaml, *sql, *text, *gqlgen, *transformMethod, *trimPrefix, *addPrefix, *linecomment, *altValuesFunc)
143+
g.generate(typeName, *json, *yaml, *sql, *text, *gqlgen, *transformMethod, *trimPrefix, *addPrefix, *linecomment, *altValuesFunc, *typedErrors)
139144
}
140145

141146
// Format the output.
@@ -415,7 +420,7 @@ func (g *Generator) prefixValueNames(values []Value, prefix string) {
415420
// generate produces the String method for the named type.
416421
func (g *Generator) generate(typeName string,
417422
includeJSON, includeYAML, includeSQL, includeText, includeGQLGen bool,
418-
transformMethod string, trimPrefix string, addPrefix string, lineComment bool, includeValuesMethod bool) {
423+
transformMethod string, trimPrefix string, addPrefix string, lineComment bool, includeValuesMethod bool, useTypedErrors bool) {
419424
values := make([]Value, 0, 100)
420425
for _, file := range g.pkg.files {
421426
file.lineComment = lineComment
@@ -468,15 +473,15 @@ func (g *Generator) generate(typeName string,
468473

469474
g.buildNoOpOrderChangeDetect(runs, typeName)
470475

471-
g.buildBasicExtras(runs, typeName, runsThreshold)
476+
g.buildBasicExtras(runs, typeName, runsThreshold, useTypedErrors)
472477
if includeJSON {
473-
g.buildJSONMethods(runs, typeName, runsThreshold)
478+
g.buildJSONMethods(runs, typeName, runsThreshold, useTypedErrors)
474479
}
475480
if includeText {
476-
g.buildTextMethods(runs, typeName, runsThreshold)
481+
g.buildTextMethods(runs, typeName, runsThreshold, useTypedErrors)
477482
}
478483
if includeYAML {
479-
g.buildYAMLMethods(runs, typeName, runsThreshold)
484+
g.buildYAMLMethods(runs, typeName, runsThreshold, useTypedErrors)
480485
}
481486
if includeSQL {
482487
g.addValueAndScanMethod(typeName)

0 commit comments

Comments
 (0)