wasi-module-sdk-go/integration/test_host.go

277 lines
6.1 KiB
Go
Raw Normal View History

2023-05-08 13:21:31 +00:00
package integration
import (
"context"
"crypto/md5"
"errors"
2023-05-13 15:46:09 +00:00
"fmt"
2023-05-08 13:21:31 +00:00
"os"
"time"
2023-05-18 16:21:44 +00:00
"code.icb4dc0.de/buildr/wasi-module-sdk-go/mem"
2023-05-08 13:21:31 +00:00
"github.com/google/uuid"
"github.com/tetratelabs/wazero"
"github.com/tetratelabs/wazero/api"
"github.com/tetratelabs/wazero/imports/wasi_snapshot_preview1"
"golang.org/x/exp/slog"
2023-05-18 16:21:44 +00:00
sdk "code.icb4dc0.de/buildr/wasi-module-sdk-go"
rpcv1 "code.icb4dc0.de/buildr/wasi-module-sdk-go/protocol/generated/rpc/v1"
2023-05-08 13:21:31 +00:00
)
// StateKey derives a state key from a module category, a module name and a
// key by hashing them with MD5. The returned string holds the raw 16-byte
// digest (it is not hex-encoded).
//
// NOTE: the three parts are hashed back-to-back without a separator, so e.g.
// ("ab", "c") and ("a", "bc") yield the same key. The format is kept as-is
// because it presumably has to match the key derivation used elsewhere.
func StateKey(cat sdk.Category, modName, key string) string {
	digest := md5.Sum([]byte(cat.String() + modName + key))
	return string(digest[:])
}
// TestSpec describes the task that Host.Run asks the guest module to start.
type TestSpec struct {
	// ModuleCategory is sent (via its String form) as the category of the
	// module reference.
	ModuleCategory sdk.Category
	// ModuleType is sent as the module type; ModuleName becomes the name of
	// the task reference.
	ModuleType, ModuleName string
	// RawTaskSpec is forwarded unchanged as the request's RawTask payload.
	RawTaskSpec []byte
}
// HostOption customises a Host while NewHost constructs it.
type HostOption func(*Host)
// WithState returns a HostOption that pre-populates the host's in-memory
// state, so that a guest get_state call whose key equals key (byte-for-byte)
// receives value.
//
// value is copied, so later modifications by the caller do not leak into the
// host's state. An empty value is stored as nil, which yields the same empty
// Data in the get_state response.
func WithState(key string, value []byte) HostOption {
	return func(h *Host) {
		// Match getState/setState: tolerate a Host whose map was never made
		// (e.g. a zero-value Host) instead of panicking on a nil-map write.
		if h.state == nil {
			h.state = make(map[string][]byte)
		}
		h.state[key] = append([]byte(nil), value...)
	}
}
// NewHost creates a Host that logs through logger and starts with an empty
// state map. The options are then applied in the order given, so a later
// WithState for the same key overwrites an earlier one.
func NewHost(logger *slog.Logger, opts ...HostOption) *Host {
	host := new(Host)
	host.Logger = logger
	host.state = map[string][]byte{}

	for _, apply := range opts {
		apply(host)
	}

	return host
}
// Host runs a WASI module and implements the "buildr" host functions it
// imports (log_msg, get_state, set_state), backed by an in-memory state map.
type Host struct {
	// memMgr handles buffers in the guest's memory (WithMem, Deallocate,
	// WriteMessage); it is assigned by Run after the module is instantiated.
	memMgr *memoryManager
	// state holds the values served by get_state and written by set_state,
	// keyed by the raw key bytes sent by the guest.
	state map[string][]byte
	// Logger receives the guest's log_msg records as well as the host's own
	// diagnostics.
	Logger *slog.Logger
}
func (h *Host) Run(ctx context.Context, wasiPayload []byte, spec TestSpec) (err error) {
runtimeConfig := wazero.NewRuntimeConfig().
WithCloseOnContextDone(true)
r := wazero.NewRuntimeWithConfig(ctx, runtimeConfig)
defer func() {
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
defer cancel()
err = errors.Join(err, r.Close(ctx))
}()
_, err = r.NewHostModuleBuilder("buildr").
NewFunctionBuilder().WithFunc(h.log).Export("log_msg").
NewFunctionBuilder().WithFunc(h.getState).Export("get_state").
NewFunctionBuilder().WithFunc(h.setState).Export("set_state").
Instantiate(ctx)
if err != nil {
return err
}
closer, err := wasi_snapshot_preview1.Instantiate(ctx, r)
if err != nil {
return err
}
defer func() {
err = errors.Join(err, closer.Close(context.Background()))
}()
config := wazero.NewFSConfig().
WithDirMount(".", "/work")
moduleConfig := wazero.NewModuleConfig().
WithStdout(os.Stdout).
WithStderr(os.Stderr).
WithFSConfig(config)
mod, err := r.InstantiateWithConfig(ctx, wasiPayload, moduleConfig)
if err != nil {
return err
}
2023-05-13 15:46:09 +00:00
h.memMgr = newMemoryManager(mod)
inv, err := h.getInventory(ctx, mod)
if err != nil {
return err
}
for _, ref := range inv {
fmt.Printf("%s/%s/n", ref.Category, ref.Type)
}
2023-05-08 13:21:31 +00:00
startTask := &rpcv1.StartTaskRequest{
2023-05-13 15:46:09 +00:00
Reference: &rpcv1.TaskReference{
Id: uuid.NewString(),
Module: &rpcv1.ModuleReference{
ModuleCategory: spec.ModuleCategory.String(),
ModuleType: spec.ModuleType,
},
Name: spec.ModuleName,
2023-05-08 13:21:31 +00:00
},
Buildr: &rpcv1.Buildr{
Repo: &rpcv1.Buildr_Repo{
Root: "/work",
},
},
RawTask: spec.RawTaskSpec,
}
data, err := startTask.MarshalVT()
if err != nil {
return err
}
run := mod.ExportedFunction("run")
defer func() {
err = errors.Join(err, h.memMgr.Close())
}()
err = h.memMgr.WithMem(ctx, uint64(len(data)), func(ptr uint64) error {
if !mod.Memory().Write(uint32(ptr), data) {
return errors.New("failed to write to memory")
}
_, err = run.Call(ctx, ptr, uint64(len(data)))
return err
})
return err
}
2023-05-13 15:46:09 +00:00
// getInventory calls the guest's exported "inventory" function and decodes
// the PluginInventory it returns into sdk.References (category and type).
//
// The guest returns one packed value that mem.UnifiedPtrToSizePtr splits into
// a pointer and a size within the guest's memory. After the buffer has been
// read, it is handed back via h.memMgr.Deallocate, so h.memMgr must be set
// before calling this. A deallocation error is joined into the returned error.
func (h *Host) getInventory(ctx context.Context, mod api.Module) (refs []sdk.Reference, err error) {
	// A missing export yields a nil function; calling it would panic.
	inventoryFn := mod.ExportedFunction("inventory")
	if inventoryFn == nil {
		return nil, errors.New("module does not export an \"inventory\" function")
	}

	result, err := inventoryFn.Call(ctx)
	if err != nil {
		return nil, err
	}
	if len(result) == 0 {
		return nil, errors.New("failed to get inventory - no result")
	}

	ptr, size := mem.UnifiedPtrToSizePtr(result[0])
	if ptr == 0 {
		return nil, errors.New("failed to get inventory - 0 pointer")
	}
	defer func() {
		err = errors.Join(err, h.memMgr.Deallocate(ctx, uint64(ptr)))
	}()

	data, ok := mod.Memory().Read(ptr, size)
	if !ok {
		return nil, errors.New("failed to get inventory")
	}

	var inventory rpcv1.PluginInventory
	if err = inventory.UnmarshalVT(data); err != nil {
		return nil, err
	}

	refs = make([]sdk.Reference, 0, len(inventory.Modules))
	for _, m := range inventory.Modules {
		refs = append(refs, sdk.Reference{
			Category: sdk.Category(m.GetModuleCategory()),
			Type:     m.GetModuleType(),
		})
	}

	return refs, nil
}
2023-05-08 13:21:31 +00:00
// getState implements the "get_state" host function. It reads a
// GetStateRequest of byteCount bytes at offset in the guest's memory, looks
// up the requested key in h.state and writes a GetStateResponse back through
// h.memMgr.WriteMessage.
//
// It returns the value produced by WriteMessage, or 0 if the request cannot
// be read or decoded or the response cannot be written. A missing key is not
// an error: the response is written with empty Data.
func (h *Host) getState(ctx context.Context, m api.Module, offset, byteCount uint32) uint64 {
	buf, ok := m.Memory().Read(offset, byteCount)
	if !ok {
		h.Logger.Error("failed to read getStateRequest from module memory",
			slog.Uint64("offset", uint64(offset)),
			slog.Uint64("byteCount", uint64(byteCount)))
		return 0
	}

	getStateReq := new(rpcv1.GetStateRequest)
	if err := getStateReq.UnmarshalVT(buf); err != nil {
		h.Logger.Error("failed to unmarshal getStateRequest", slog.String("err", err.Error()))
		return 0
	}

	// Looking up a key in a nil map is safe and yields nil, so no lazy
	// initialisation of h.state is needed for reads.
	resp := new(rpcv1.GetStateResponse)
	resp.Data = h.state[string(getStateReq.Key)]

	ptr, err := h.memMgr.WriteMessage(ctx, m, resp)
	if err != nil {
		h.Logger.Error("Failed to write message", slog.String("err", err.Error()))
		return 0
	}
	return ptr
}
// setState implements the "set_state" host function. It reads a SetState
// message of byteCount bytes at offset in the guest's memory and stores its
// Data under its Key in h.state. The outcome is reported as an rpcv1.Result
// written back through h.memMgr.WriteMessage:
//   - Success is true if the value was stored;
//   - Error is set if the key is empty.
//
// It returns the value produced by WriteMessage, or 0 if the message cannot
// be read or decoded or the result cannot be written.
func (h *Host) setState(ctx context.Context, m api.Module, offset, byteCount uint32) (result uint64) {
	if h.state == nil {
		h.state = make(map[string][]byte)
	}

	buf, ok := m.Memory().Read(offset, byteCount)
	if !ok {
		h.Logger.Error("failed to read SetState from module memory",
			slog.Uint64("offset", uint64(offset)),
			slog.Uint64("byteCount", uint64(byteCount)))
		return 0
	}

	setState := new(rpcv1.SetState)
	if err := setState.UnmarshalVT(buf); err != nil {
		h.Logger.Error("failed to unmarshal SetState", slog.String("err", err.Error()))
		return 0
	}

	var resp rpcv1.Result
	if len(setState.Key) == 0 {
		resp.Error = "key must not be empty"
	} else {
		h.state[string(setState.Key)] = setState.Data
		resp.Success = true
	}

	ptr, err := h.memMgr.WriteMessage(ctx, m, &resp)
	if err != nil {
		h.Logger.Error("Failed to write message", slog.String("err", err.Error()))
		return 0
	}
	return ptr
}
// log implements the "log_msg" host function. It reads a TaskLog message of
// byteCount bytes at offset in the guest's memory and re-emits it through
// h.Logger's handler as a record with the guest's time, level, message and
// string attributes. taskLog.Time is interpreted as Unix microseconds.
//
// Failures to read or decode the message are logged as warnings and the
// record is dropped.
func (h *Host) log(ctx context.Context, m api.Module, offset, byteCount uint32) {
	buf, ok := m.Memory().Read(offset, byteCount)
	if !ok {
		h.Logger.Warn("failed to read task log from module memory",
			slog.Uint64("offset", uint64(offset)),
			slog.Uint64("byteCount", uint64(byteCount)))
		return
	}

	taskLog := new(rpcv1.TaskLog)
	if err := taskLog.UnmarshalVT(buf); err != nil {
		h.Logger.Warn("failed to unmarshal task log", slog.String("err", err.Error()))
		return
	}

	level := slog.Level(taskLog.Level)
	handler := h.Logger.Handler()
	// slog.Logger consults Handler.Enabled before calling Handle; Handle
	// itself does not filter by level. Since we bypass the Logger here, do
	// the check ourselves so the configured minimum level is honoured.
	if !handler.Enabled(ctx, level) {
		return
	}

	rec := slog.NewRecord(time.UnixMicro(taskLog.Time), level, taskLog.Message, 0)
	for i := range taskLog.Attributes {
		attr := taskLog.Attributes[i]
		rec.AddAttrs(slog.String(attr.Key, attr.Value))
	}

	// A handler error has nowhere useful to go from inside a host function,
	// so it is deliberately dropped.
	_ = handler.Handle(ctx, rec)
}