//go:build !wasi package wasirpc import ( "bytes" "context" "errors" "fmt" "io" "net/http" "sync" "connectrpc.com/connect" "github.com/tetratelabs/wazero/api" ) var ( ErrMissingMandatoryFunction = errors.New("missing mandatory function") ErrUnmappedFunction = errors.New("unmapped function") _ connect.HTTPClient = (*hostToModuleClient)(nil) ) const ( allocateFuncName string = "allocate" defaultIntegerSize = 32 ) type hostToModuleClient struct { lock sync.Mutex mod api.Module alloc api.Function functions map[string]api.Function } func NewModuleClient(mod api.Module) (connect.HTTPClient, error) { client := &hostToModuleClient{ mod: mod, } definitions := mod.ExportedFunctionDefinitions() client.functions = make(map[string]api.Function, len(definitions)) for funcName, _ := range definitions { client.functions[funcName] = mod.ExportedFunction(funcName) } var ok bool if client.alloc, ok = client.functions[allocateFuncName]; !ok { return nil, fmt.Errorf("%w: %s", ErrMissingMandatoryFunction, allocateFuncName) } return client, nil } func (h *hostToModuleClient) Do(request *http.Request) (resp *http.Response, err error) { h.lock.Lock() defer h.lock.Unlock() exportedFunc, ok := h.functions[request.URL.Path] if !ok { return nil, fmt.Errorf("%w: %s", ErrUnmappedFunction, request.URL.Path) } reqData, err := io.ReadAll(request.Body) if err != nil { return nil, err } reqDataPtr, err := h.allocate(request.Context(), uint64(len(reqData))) if err != nil { return nil, err } if !h.mod.Memory().Write(uint32(reqDataPtr), reqData) { return nil, errors.New("failed to write to memory") } results, err := exportedFunc.Call(request.Context(), reqDataPtr, uint64(len(reqData))) if err != nil { return nil, err } resp = &http.Response{ Status: "200 OK", StatusCode: http.StatusOK, Request: request, Header: map[string][]string{ "Content-Type": {"application/proto"}, }, } if len(results) == 0 { return resp, nil } if len(results) > 1 { return nil, fmt.Errorf("unexpected number of results: %d", len(results)) } respDataPtr, size := uint32(results[0]>>defaultIntegerSize), uint32(results[0]) if respDataPtr == 0 { resp.Body = io.NopCloser(bytes.NewReader(make([]byte, 0))) return resp, nil } respData, ok := h.mod.Memory().Read(respDataPtr, size) if !ok { return nil, errors.New("failed to read result data") } resp.Body = io.NopCloser(bytes.NewReader(respData)) resp.ContentLength = int64(len(respData)) return resp, nil } func (h *hostToModuleClient) allocate(ctx context.Context, size uint64) (ptr uint64, err error) { results, err := h.alloc.Call(ctx, size) if err != nil { return 0, err } return results[0], nil }