Skip to content

Commit 14bc3b5

Browse files
authored
interp: add support of Go generics in interpreter
Status: * [x] parsing code with generics * [x] instantiate generics from concrete types * [x] automatic type inference * [x] support of generic recursive types * [x] support of generic methods * [x] support of generic receivers in methods * [x] support of multiple type parameters * [x] support of generic constraints * [x] tests (see _test/gen*.go) Fixes #1363.
1 parent 255b1cf commit 14bc3b5

19 files changed

+986
-123
lines changed

_test/gen1.go

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
package main
2+
3+
import "fmt"
4+
5+
// SumInts adds together the values of m.
6+
func SumInts(m map[string]int64) int64 {
7+
var s int64
8+
for _, v := range m {
9+
s += v
10+
}
11+
return s
12+
}
13+
14+
// SumFloats adds together the values of m.
15+
func SumFloats(m map[string]float64) float64 {
16+
var s float64
17+
for _, v := range m {
18+
s += v
19+
}
20+
return s
21+
}
22+
23+
func main() {
24+
// Initialize a map for the integer values
25+
ints := map[string]int64{
26+
"first": 34,
27+
"second": 12,
28+
}
29+
30+
// Initialize a map for the float values
31+
floats := map[string]float64{
32+
"first": 35.98,
33+
"second": 26.99,
34+
}
35+
36+
fmt.Printf("Non-Generic Sums: %v and %v\n",
37+
SumInts(ints),
38+
SumFloats(floats))
39+
}

_test/gen2.go

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
package main
2+
3+
import "fmt"
4+
5+
// SumIntsOrFloats sums the values of map m. It supports both int64 and float64
6+
// as types for map values.
7+
func SumIntsOrFloats[K comparable, V int64 | float64](m map[K]V) V {
8+
var s V
9+
for _, v := range m {
10+
s += v
11+
}
12+
return s
13+
}
14+
15+
func main() {
16+
// Initialize a map for the integer values
17+
ints := map[string]int64{
18+
"first": 34,
19+
"second": 12,
20+
}
21+
22+
// Initialize a map for the float values
23+
floats := map[string]float64{
24+
"first": 35.98,
25+
"second": 26.99,
26+
}
27+
28+
fmt.Printf("Generic Sums: %v and %v\n",
29+
SumIntsOrFloats[string, int64](ints),
30+
SumIntsOrFloats[string, float64](floats))
31+
}
32+
33+
// Output:
34+
// Generic Sums: 46 and 62.97

_test/gen3.go

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
package main
2+
3+
type Number interface {
4+
int | int64 | ~float64
5+
}
6+
7+
func Sum[T Number](numbers []T) T {
8+
var total T
9+
for _, x := range numbers {
10+
total += x
11+
}
12+
return total
13+
}
14+
15+
func main() {
16+
xs := []int{3, 5, 10}
17+
total := Sum(xs)
18+
println(total)
19+
}
20+
21+
// Output:
22+
// 18

_test/gen4.go

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
package main
2+
3+
import "fmt"
4+
5+
type List[T any] struct {
6+
head, tail *element[T]
7+
}
8+
9+
// A recursive generic type.
10+
type element[T any] struct {
11+
next *element[T]
12+
val T
13+
}
14+
15+
func (lst *List[T]) Push(v T) {
16+
if lst.tail == nil {
17+
lst.head = &element[T]{val: v}
18+
lst.tail = lst.head
19+
} else {
20+
lst.tail.next = &element[T]{val: v}
21+
lst.tail = lst.tail.next
22+
}
23+
}
24+
25+
func (lst *List[T]) GetAll() []T {
26+
var elems []T
27+
for e := lst.head; e != nil; e = e.next {
28+
elems = append(elems, e.val)
29+
}
30+
return elems
31+
}
32+
33+
func main() {
34+
lst := List[int]{}
35+
lst.Push(10)
36+
lst.Push(13)
37+
lst.Push(23)
38+
fmt.Println("list:", lst.GetAll())
39+
}
40+
41+
// Output:
42+
// list: [10 13 23]

_test/gen5.go

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
package main
2+
3+
import "fmt"
4+
5+
type Set[Elem comparable] struct {
6+
m map[Elem]struct{}
7+
}
8+
9+
func Make[Elem comparable]() Set[Elem] {
10+
return Set[Elem]{m: make(map[Elem]struct{})}
11+
}
12+
13+
func (s Set[Elem]) Add(v Elem) {
14+
s.m[v] = struct{}{}
15+
}
16+
17+
func main() {
18+
s := Make[int]()
19+
s.Add(1)
20+
fmt.Println(s)
21+
}
22+
23+
// Output:
24+
// {map[1:{}]}

_test/gen6.go

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
package main
2+
3+
func MapKeys[K comparable, V any](m map[K]V) []K {
4+
r := make([]K, 0, len(m))
5+
for k := range m {
6+
r = append(r, k)
7+
}
8+
return r
9+
}
10+
11+
func main() {
12+
var m = map[int]string{1: "2", 2: "4", 4: "8"}
13+
14+
// Test type inference
15+
println(len(MapKeys(m)))
16+
}
17+
18+
// Output:
19+
// 3

_test/gen7.go

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
package main
2+
3+
func MapKeys[K comparable, V any](m map[K]V) []K {
4+
r := make([]K, 0, len(m))
5+
for k := range m {
6+
r = append(r, k)
7+
}
8+
return r
9+
}
10+
11+
func main() {
12+
var m = map[int]string{1: "2", 2: "4", 4: "8"}
13+
14+
// Test type inference
15+
println(len(MapKeys))
16+
}
17+
18+
// Error:
19+
// invalid argument for len

_test/gen8.go

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
package main
2+
3+
type Float interface {
4+
~float32 | ~float64
5+
}
6+
7+
func add[T Float](a, b T) float64 { return float64(a) + float64(b) }
8+
9+
func main() {
10+
var x, y int = 1, 2
11+
println(add(x, y))
12+
}
13+
14+
// Error:
15+
// int does not implement main.Float

_test/gen9.go

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
package main
2+
3+
type Float interface {
4+
~float32 | ~float64
5+
}
6+
7+
func add[T Float](a, b T) float64 { return float64(a) + float64(b) }
8+
9+
func main() {
10+
println(add(1, 2))
11+
}
12+
13+
// Error:
14+
// untyped int does not implement main.Float

interp/ast.go

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ const (
7272
importSpec
7373
incDecStmt
7474
indexExpr
75+
indexListExpr
7576
interfaceType
7677
keyValueExpr
7778
labeledStmt
@@ -155,6 +156,7 @@ var kinds = [...]string{
155156
importSpec: "importSpec",
156157
incDecStmt: "incDecStmt",
157158
indexExpr: "indexExpr",
159+
indexListExpr: "indexListExpr",
158160
interfaceType: "interfaceType",
159161
keyValueExpr: "keyValueExpr",
160162
labeledStmt: "labeledStmt",
@@ -694,7 +696,7 @@ func (interp *Interpreter) ast(f ast.Node) (string, *node, error) {
694696
n := addChild(&root, anc, pos, funcDecl, aNop)
695697
n.val = n
696698
if a.Recv == nil {
697-
// function is not a method, create an empty receiver list
699+
// Function is not a method, create an empty receiver list.
698700
addChild(&root, astNode{n, nod}, pos, fieldList, aNop)
699701
}
700702
st.push(n, nod)
@@ -706,7 +708,13 @@ func (interp *Interpreter) ast(f ast.Node) (string, *node, error) {
706708
st.push(n, nod)
707709

708710
case *ast.FuncType:
709-
st.push(addChild(&root, anc, pos, funcType, aNop), nod)
711+
n := addChild(&root, anc, pos, funcType, aNop)
712+
n.val = n
713+
if a.TypeParams == nil {
714+
// Function has no type parameters, create an empty fied list.
715+
addChild(&root, astNode{n, nod}, pos, fieldList, aNop)
716+
}
717+
st.push(n, nod)
710718

711719
case *ast.GenDecl:
712720
var kind nkind
@@ -776,6 +784,9 @@ func (interp *Interpreter) ast(f ast.Node) (string, *node, error) {
776784
case *ast.IndexExpr:
777785
st.push(addChild(&root, anc, pos, indexExpr, aGetIndex), nod)
778786

787+
case *ast.IndexListExpr:
788+
st.push(addChild(&root, anc, pos, indexListExpr, aNop), nod)
789+
779790
case *ast.InterfaceType:
780791
st.push(addChild(&root, anc, pos, interfaceType, aNop), nod)
781792

0 commit comments

Comments
 (0)