76 lines
1.6 KiB
Go
76 lines
1.6 KiB
Go
|
package integration
|
||
|
|
||
|
import (
|
||
|
"context"
|
||
|
"errors"
|
||
|
"github.com/tetratelabs/wazero/api"
|
||
|
)
|
||
|
|
||
|
// Message is the minimal contract a payload must satisfy to be exchanged with
// the guest module: it can encode itself to bytes and decode itself from them.
// The VT-suffixed method names match the set of methods this interface requires.
type Message interface {
	// MarshalVT returns the encoded form of the message.
	MarshalVT() ([]byte, error)
	// UnmarshalVT replaces the message contents with the decoded data.
	UnmarshalVT(data []byte) error
}
|
||
|
|
||
|
func newMemoryManager(mod api.Module) *memoryManager {
|
||
|
return &memoryManager{
|
||
|
allocate: mod.ExportedFunction("malloc"),
|
||
|
deallocate: mod.ExportedFunction("free"),
|
||
|
}
|
||
|
}
|
||
|
|
||
|
type memoryManager struct {
|
||
|
allocate api.Function
|
||
|
deallocate api.Function
|
||
|
danglingAllocations []uint64
|
||
|
}
|
||
|
|
||
|
func (m *memoryManager) WriteMessage(ctx context.Context, mod api.Module, msg Message) (uint64, error) {
|
||
|
data, err := msg.MarshalVT()
|
||
|
if err != nil {
|
||
|
return 0, err
|
||
|
}
|
||
|
ptr, err := m.Allocate(ctx, uint64(len(data)))
|
||
|
if err != nil {
|
||
|
return 0, err
|
||
|
}
|
||
|
|
||
|
if !mod.Memory().Write(uint32(ptr), data) {
|
||
|
return 0, errors.New("failed to write message to memory")
|
||
|
}
|
||
|
|
||
|
return (ptr << uint64(32)) | uint64(len(data)), nil
|
||
|
}
|
||
|
|
||
|
func (m *memoryManager) Allocate(ctx context.Context, size uint64) (ptr uint64, err error) {
|
||
|
results, err := m.allocate.Call(ctx, size)
|
||
|
if err != nil {
|
||
|
return 0, err
|
||
|
}
|
||
|
|
||
|
m.danglingAllocations = append(m.danglingAllocations, results[0])
|
||
|
|
||
|
return results[0], nil
|
||
|
}
|
||
|
|
||
|
func (m *memoryManager) WithMem(ctx context.Context, size uint64, delegate func(ptr uint64) error) error {
|
||
|
results, err := m.allocate.Call(ctx, size)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
defer m.deallocate.Call(ctx, results[0])
|
||
|
|
||
|
return delegate(results[0])
|
||
|
}
|
||
|
|
||
|
func (m *memoryManager) Close() (err error) {
|
||
|
ctx := context.Background()
|
||
|
|
||
|
for i := range m.danglingAllocations {
|
||
|
_, e := m.deallocate.Call(ctx, m.danglingAllocations[i])
|
||
|
err = errors.Join(err, e)
|
||
|
}
|
||
|
|
||
|
return err
|
||
|
}
|