Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 51 additions & 0 deletions interp/interp.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ import (
"path"
"path/filepath"
"reflect"
"runtime"
"runtime/debug"
"strconv"
"strings"
"sync"
Expand Down Expand Up @@ -220,6 +222,7 @@ type Interpreter struct {
hooks *hooks // symbol hooks

debugger *Debugger
stopOnce sync.Once
}

const (
Expand Down Expand Up @@ -641,6 +644,46 @@ func (interp *Interpreter) EvalWithContext(ctx context.Context, src string) (ref
return v, err
}

func (interp *Interpreter) ExecFunc(ctx context.Context, fn interface{}, args ...interface{}) ([]reflect.Value, error) {
var result []reflect.Value
var err error

interp.mutex.Lock()
interp.done = make(chan struct{})
interp.cancelChan = !interp.opt.fastChan
interp.mutex.Unlock()

done := make(chan struct{})
vFn := reflect.Indirect(reflect.ValueOf(fn))
if vFn.Kind() != reflect.Func {
return nil, fmt.Errorf("fn is not function")
}
vArgs := make([]reflect.Value, len(args))
for i := range args {
vArgs[i] = reflect.ValueOf(args[i])
}
go func() {
defer close(done)
defer func() {
r := recover()
if r != nil {
var pc [64]uintptr // 64 frames should be enough.
n := runtime.Callers(1, pc[:])
err = Panic{Value: r, Callers: pc[:n], Stack: debug.Stack()}
}
}()
result = vFn.Call(vArgs)
}()

select {
case <-ctx.Done():
interp.stop()
return nil, ctx.Err()
case <-done:
}
return result, err
}

// stop sends a semaphore to all running frames and closes the chan
// operation short circuit channel. stop may only be called once per
// invocation of EvalWithContext.
Expand All @@ -649,6 +692,14 @@ func (interp *Interpreter) stop() {
close(interp.done)
}

// Stop sends a semaphore to all running frames and closes the chan
// operation short circuit channel.
func (interp *Interpreter) Stop() {
interp.stopOnce.Do(func() {
interp.stop()
})
}

func (interp *Interpreter) runid() uint64 { return atomic.LoadUint64(&interp.id) }

// getWrapper returns the wrapper type of the corresponding interface, or nil if not found.
Expand Down
35 changes: 35 additions & 0 deletions interp/interp_eval_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1749,3 +1749,38 @@ func TestRestrictedEnv(t *testing.T) {
t.Fatal("expected \"\", got " + s)
}
}

func TestExecFuncWithArgs(t *testing.T) {
ctx := context.Background()
ctx, cancel := context.WithTimeout(ctx, 100*time.Millisecond)
defer cancel()
done := make(chan struct{})
go func() {
defer close(done)
i := interp.New(interp.Options{})
i.Use(stdlib.Symbols)
eval(t, i, `
import "time"
func Bar(msg string) string {
for i:=0;;i++ {
println(i)
time.Sleep(1*time.Second)
}
return msg
}
`)

v := eval(t, i, "Bar")
bar := v.Interface().(func(string) string)

_, err := i.ExecFunc(ctx, bar, "hello")
if err != context.DeadlineExceeded {
t.Errorf("unexpected error")
}
}()
select {
case <-time.After(time.Second):
t.Errorf("timeout failed to terminate execution")
case <-done:
}
}