common/wasirpc/host2module.go

122 lines
2.7 KiB
Go

//go:build !wasi
package wasirpc
import (
"bytes"
"context"
"errors"
"fmt"
"io"
"net/http"
"sync"
"connectrpc.com/connect"
"github.com/tetratelabs/wazero/api"
)
// ErrMissingMandatoryFunction is returned by NewModuleClient when the module
// does not export a function the client requires.
var ErrMissingMandatoryFunction = errors.New("missing mandatory function")

// ErrUnmappedFunction is returned by Do when the request URL path does not
// name a function exported by the module.
var ErrUnmappedFunction = errors.New("unmapped function")

// Compile-time check that hostToModuleClient satisfies connect.HTTPClient.
var _ connect.HTTPClient = (*hostToModuleClient)(nil)
// allocateFuncName is the export the module must provide so the host can
// obtain a guest buffer to copy request payloads into.
const allocateFuncName string = "allocate"

// defaultIntegerSize is the shift used to unpack a function's single result:
// the bits above it hold the response pointer, the low 32 bits its length.
const defaultIntegerSize = 32
// hostToModuleClient adapts an instantiated wazero module to the
// connect.HTTPClient interface: each request is dispatched to one of the
// module's exported functions.
type hostToModuleClient struct {
	// mod is the module instance; its linear memory carries both request and
	// response payloads.
	mod api.Module

	// lock is held for the whole duration of each Do call.
	lock sync.Mutex
	// alloc is the module's exported "allocate" function.
	alloc api.Function
	// functions maps every exported function name to its callable handle.
	functions map[string]api.Function
}
// NewModuleClient returns a connect.HTTPClient whose requests are served by
// the functions exported by mod; the request URL path selects the exported
// function by name.
//
// mod must export a function named "allocate", which the client uses to get
// guest memory for request payloads. If it is missing, the returned error
// wraps ErrMissingMandatoryFunction.
func NewModuleClient(mod api.Module) (connect.HTTPClient, error) {
	definitions := mod.ExportedFunctionDefinitions()
	client := &hostToModuleClient{
		mod:       mod,
		functions: make(map[string]api.Function, len(definitions)),
	}
	for name := range definitions {
		client.functions[name] = mod.ExportedFunction(name)
	}

	alloc, ok := client.functions[allocateFuncName]
	if !ok {
		return nil, fmt.Errorf("%w: %s", ErrMissingMandatoryFunction, allocateFuncName)
	}
	client.alloc = alloc
	return client, nil
}
// Do implements connect.HTTPClient by invoking the module function whose name
// equals request.URL.Path.
//
// The exchange with the module works as follows:
//   - the whole request body is copied into a buffer obtained from the
//     module's "allocate" export, and the target function is called with
//     (pointer, length);
//   - if the function returns no value, the response has an empty body;
//   - otherwise it must return exactly one value packing the response pointer
//     in the high 32 bits and the response length in the low 32 bits; a zero
//     pointer also yields an empty body.
//
// A successful call always produces a synthetic "200 OK" response. An unknown
// path yields an error wrapping ErrUnmappedFunction; errors from reading the
// body, allocating, or calling the function are returned unchanged.
//
// Calls are serialized with h.lock because wazero function handles and module
// memory must not be used from several goroutines at once.
//
// NOTE(review): neither the request buffer nor the response buffer is released
// back to the module here; confirm the guest reclaims them itself.
func (h *hostToModuleClient) Do(request *http.Request) (resp *http.Response, err error) {
	// Like an http.RoundTripper, Do owns the request body and must close it,
	// even on error. A nil body means the request has no body.
	if request.Body != nil {
		defer request.Body.Close()
	}

	h.lock.Lock()
	defer h.lock.Unlock()

	exportedFunc, ok := h.functions[request.URL.Path]
	if !ok {
		return nil, fmt.Errorf("%w: %s", ErrUnmappedFunction, request.URL.Path)
	}

	var reqData []byte
	if request.Body != nil {
		if reqData, err = io.ReadAll(request.Body); err != nil {
			return nil, err
		}
	}

	ctx := request.Context()
	reqDataPtr, err := h.allocate(ctx, uint64(len(reqData)))
	if err != nil {
		return nil, err
	}
	if !h.mod.Memory().Write(uint32(reqDataPtr), reqData) {
		return nil, errors.New("failed to write to memory")
	}

	results, err := exportedFunc.Call(ctx, reqDataPtr, uint64(len(reqData)))
	if err != nil {
		return nil, err
	}

	// http.Response documents Body as always non-nil for client responses, and
	// a nil Header map panics on Set, so start from an empty, well-formed
	// response.
	resp = &http.Response{
		Status:     "200 OK",
		StatusCode: http.StatusOK,
		Header:     make(http.Header),
		Body:       http.NoBody,
		Request:    request,
	}
	switch {
	case len(results) == 0:
		return resp, nil
	case len(results) > 1:
		return nil, fmt.Errorf("unexpected number of results: %d", len(results))
	}

	respDataPtr, size := uint32(results[0]>>defaultIntegerSize), uint32(results[0])
	if respDataPtr == 0 {
		return resp, nil
	}
	respData, ok := h.mod.Memory().Read(respDataPtr, size)
	if !ok {
		return nil, errors.New("failed to read result data")
	}
	// Memory.Read returns a view into guest memory, not a copy. The caller
	// consumes the body after h.lock is released, when a subsequent call may
	// overwrite or reallocate that memory, so detach the bytes while locked.
	body := make([]byte, len(respData))
	copy(body, respData)
	resp.Body = io.NopCloser(bytes.NewReader(body))
	resp.ContentLength = int64(len(body))
	return resp, nil
}
// allocate calls the module's "allocate" export with size, the number of
// bytes the caller intends to write, and returns the guest pointer it reports.
//
// The export must return exactly one value. Any other result count is
// reported as an error instead of causing an index-out-of-range panic.
func (h *hostToModuleClient) allocate(ctx context.Context, size uint64) (ptr uint64, err error) {
	results, err := h.alloc.Call(ctx, size)
	if err != nil {
		return 0, fmt.Errorf("calling %s: %w", allocateFuncName, err)
	}
	if len(results) != 1 {
		return 0, fmt.Errorf("%s: unexpected number of results: %d", allocateFuncName, len(results))
	}
	return results[0], nil
}