122 lines
2.7 KiB
Go
122 lines
2.7 KiB
Go
//go:build !wasi
|
|
|
|
package wasirpc
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"sync"
|
|
|
|
"connectrpc.com/connect"
|
|
"github.com/tetratelabs/wazero/api"
|
|
)
|
|
|
|
var (
|
|
ErrMissingMandatoryFunction = errors.New("missing mandatory function")
|
|
ErrUnmappedFunction = errors.New("unmapped function")
|
|
_ connect.HTTPClient = (*hostToModuleClient)(nil)
|
|
)
|
|
|
|
// allocateFuncName names the module export that NewModuleClient requires and
// that Do calls to obtain guest memory for each request payload.
const allocateFuncName string = "allocate"

// defaultIntegerSize is the shift applied to a packed uint64 export result to
// extract the response pointer held in its upper 32 bits.
const defaultIntegerSize = 32
|
|
|
|
type hostToModuleClient struct {
|
|
lock sync.Mutex
|
|
mod api.Module
|
|
alloc api.Function
|
|
functions map[string]api.Function
|
|
}
|
|
|
|
func NewModuleClient(mod api.Module) (connect.HTTPClient, error) {
|
|
client := &hostToModuleClient{
|
|
mod: mod,
|
|
}
|
|
|
|
definitions := mod.ExportedFunctionDefinitions()
|
|
client.functions = make(map[string]api.Function, len(definitions))
|
|
|
|
for funcName, _ := range definitions {
|
|
client.functions[funcName] = mod.ExportedFunction(funcName)
|
|
}
|
|
|
|
var ok bool
|
|
if client.alloc, ok = client.functions[allocateFuncName]; !ok {
|
|
return nil, fmt.Errorf("%w: %s", ErrMissingMandatoryFunction, allocateFuncName)
|
|
}
|
|
|
|
return client, nil
|
|
}
|
|
|
|
func (h *hostToModuleClient) Do(request *http.Request) (resp *http.Response, err error) {
|
|
h.lock.Lock()
|
|
defer h.lock.Unlock()
|
|
|
|
exportedFunc, ok := h.functions[request.URL.Path]
|
|
if !ok {
|
|
return nil, fmt.Errorf("%w: %s", ErrUnmappedFunction, request.URL.Path)
|
|
}
|
|
|
|
reqData, err := io.ReadAll(request.Body)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
reqDataPtr, err := h.allocate(request.Context(), uint64(len(reqData)))
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if !h.mod.Memory().Write(uint32(reqDataPtr), reqData) {
|
|
return nil, errors.New("failed to write to memory")
|
|
}
|
|
|
|
results, err := exportedFunc.Call(request.Context(), reqDataPtr, uint64(len(reqData)))
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
resp = &http.Response{
|
|
Status: "200 OK",
|
|
StatusCode: http.StatusOK,
|
|
Request: request,
|
|
}
|
|
|
|
if len(results) == 0 {
|
|
return resp, nil
|
|
}
|
|
|
|
if len(results) > 1 {
|
|
return nil, fmt.Errorf("unexpected number of results: %d", len(results))
|
|
}
|
|
|
|
respDataPtr, size := uint32(results[0]>>defaultIntegerSize), uint32(results[0])
|
|
if respDataPtr == 0 {
|
|
resp.Body = io.NopCloser(bytes.NewReader(make([]byte, 0)))
|
|
return resp, nil
|
|
}
|
|
|
|
respData, ok := h.mod.Memory().Read(respDataPtr, size)
|
|
if !ok {
|
|
return nil, errors.New("failed to read result data")
|
|
}
|
|
|
|
resp.Body = io.NopCloser(bytes.NewReader(respData))
|
|
resp.ContentLength = int64(len(respData))
|
|
|
|
return resp, nil
|
|
}
|
|
|
|
func (h *hostToModuleClient) allocate(ctx context.Context, size uint64) (ptr uint64, err error) {
|
|
results, err := h.alloc.Call(ctx, size)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
return results[0], nil
|
|
}
|