208 lines
6.5 KiB
Go
208 lines
6.5 KiB
Go
package protocol
|
|
|
|
import (
	"errors"
	"fmt"
	"reflect"
	"sort"
	"strings"

	rpcv1 "code.icb4dc0.de/buildr/wasi-module-sdk-go/protocol/generated/rpc/v1"
)
|
|
|
|
// ErrUnmatchingType is the sentinel for a wire value that cannot be stored in
// the Go field it was matched to. Decoding code usually wraps it with extra
// context via %w, so test for it with errors.Is.
var ErrUnmatchingType = errors.New("field type does not match wire value")
|
|
|
|
func Unmarshal(input *rpcv1.ModuleSpec, into any) error {
|
|
cfg := unmarshalConfig{
|
|
TagName: "hcl",
|
|
}
|
|
|
|
val := reflect.ValueOf(into)
|
|
if val.Kind() != reflect.Ptr {
|
|
return errors.New("into value must be a pointer")
|
|
}
|
|
return cfg.unmarshal(input.Values, val.Elem(), reflect.TypeOf(into).Elem())
|
|
}
|
|
|
|
// unmarshalConfig holds the settings used while mapping spec keys onto
// struct fields.
type unmarshalConfig struct {
	// TagName is the struct tag key consulted when deriving a field's wire
	// name (Unmarshal sets it to "hcl").
	TagName string
}
|
|
|
|
func (cfg unmarshalConfig) unmarshal(input map[string]*rpcv1.ModuleSpec_Value, into reflect.Value, intoType reflect.Type) error {
|
|
for i := 0; i < intoType.NumField(); i++ {
|
|
tf := intoType.Field(i)
|
|
if !tf.IsExported() {
|
|
continue
|
|
}
|
|
|
|
fieldName := cfg.fieldName(tf)
|
|
|
|
val, ok := input[fieldName]
|
|
if !ok {
|
|
continue
|
|
}
|
|
|
|
if mapped, err := cfg.mapSpecValueTo(val, tf.Type); err != nil {
|
|
return fmt.Errorf("failed to map value for field %s: %w", tf.Name, err)
|
|
} else if into.Type().Kind() == reflect.Pointer {
|
|
into.Elem().Field(i).Set(mapped)
|
|
} else {
|
|
into.Field(i).Set(mapped)
|
|
}
|
|
|
|
delete(input, fieldName)
|
|
}
|
|
|
|
if len(input) > 0 {
|
|
keys := make([]string, 0, len(input))
|
|
for k := range input {
|
|
keys = append(keys, k)
|
|
}
|
|
return fmt.Errorf("unknown fields: %v", keys)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (cfg unmarshalConfig) mapSpecValueTo(val *rpcv1.ModuleSpec_Value, targetType reflect.Type) (reflect.Value, error) {
|
|
switch val.Type {
|
|
case rpcv1.ModuleSpec_ValueTypeUnknown:
|
|
return reflect.Value{}, ErrUnmatchingType
|
|
case rpcv1.ModuleSpec_ValueTypeObject:
|
|
if targetType.Kind() == reflect.Struct {
|
|
structVal := reflect.New(targetType)
|
|
if err := cfg.unmarshal(val.ComplexValue, structVal, targetType); err != nil {
|
|
return reflect.Value{}, err
|
|
} else {
|
|
return structVal.Elem(), nil
|
|
}
|
|
} else if targetType.Kind() == reflect.Pointer && targetType.Elem().Kind() == reflect.Struct {
|
|
structVal := reflect.New(targetType.Elem())
|
|
if err := cfg.unmarshal(val.ComplexValue, structVal, targetType.Elem()); err != nil {
|
|
return reflect.Value{}, err
|
|
} else {
|
|
return structVal, nil
|
|
}
|
|
} else if targetType.Kind() != reflect.Map {
|
|
return reflect.Value{}, fmt.Errorf("%w: expected struct, got %s", ErrUnmatchingType, targetType.String())
|
|
}
|
|
fallthrough
|
|
case rpcv1.ModuleSpec_ValueTypeMap:
|
|
if targetType.Kind() != reflect.Map {
|
|
return reflect.Value{}, fmt.Errorf("%w: expected map, got %v", ErrUnmatchingType, targetType)
|
|
}
|
|
|
|
mapVal := reflect.MakeMap(targetType)
|
|
|
|
for k, v := range val.ComplexValue {
|
|
if mappedVal, err := cfg.mapSpecValueTo(v, targetType.Elem()); err != nil {
|
|
return reflect.Value{}, err
|
|
} else {
|
|
mapVal.SetMapIndex(reflect.ValueOf(k), mappedVal)
|
|
}
|
|
}
|
|
|
|
return mapVal, nil
|
|
|
|
case rpcv1.ModuleSpec_ValueTypeSingle:
|
|
switch sv := val.SingleValue.(type) {
|
|
case *rpcv1.ModuleSpec_Value_BoolValue:
|
|
if targetType.Kind() != reflect.Bool {
|
|
return reflect.Value{}, fmt.Errorf("%w: expected bool, got %v", ErrUnmatchingType, targetType)
|
|
}
|
|
return reflect.ValueOf(sv.BoolValue), nil
|
|
case *rpcv1.ModuleSpec_Value_StringValue:
|
|
if targetType.Kind() != reflect.String {
|
|
return reflect.Value{}, fmt.Errorf("%w: expected string, got %v", ErrUnmatchingType, targetType)
|
|
}
|
|
return reflect.ValueOf(sv.StringValue), nil
|
|
case *rpcv1.ModuleSpec_Value_IntValue:
|
|
if targetType.Kind() != reflect.Int {
|
|
return reflect.Value{}, fmt.Errorf("%w: expected int, got %v", ErrUnmatchingType, targetType)
|
|
}
|
|
return reflect.ValueOf(int(sv.IntValue)), nil
|
|
case *rpcv1.ModuleSpec_Value_DoubleValue:
|
|
if targetType.Kind() != reflect.Float64 {
|
|
return reflect.Value{}, fmt.Errorf("%w: expected float64, got %v", ErrUnmatchingType, targetType)
|
|
}
|
|
return reflect.ValueOf(sv.DoubleValue), nil
|
|
}
|
|
case rpcv1.ModuleSpec_ValueTypeBoolSlice:
|
|
if targetType.Kind() != reflect.Slice {
|
|
return reflect.Value{}, fmt.Errorf("%w: expected slice, got %v", ErrUnmatchingType, targetType)
|
|
} else if targetType.Elem().Kind() != reflect.Bool {
|
|
return reflect.Value{}, fmt.Errorf("%w: expected bool, got %v", ErrUnmatchingType, targetType.Elem())
|
|
}
|
|
return reflect.ValueOf(val.BoolValues), nil
|
|
case rpcv1.ModuleSpec_ValueTypeStringSlice:
|
|
if targetType.Kind() != reflect.Slice {
|
|
return reflect.Value{}, fmt.Errorf("%w: expected slice, got %v", ErrUnmatchingType, targetType)
|
|
} else if targetType.Elem().Kind() != reflect.String {
|
|
return reflect.Value{}, fmt.Errorf("%w: expected string, got %v", ErrUnmatchingType, targetType.Elem())
|
|
}
|
|
return reflect.ValueOf(val.StringValues), nil
|
|
case rpcv1.ModuleSpec_ValueTypeIntSlice:
|
|
if targetType.Kind() != reflect.Slice {
|
|
return reflect.Value{}, fmt.Errorf("%w: expected slice, got %v", ErrUnmatchingType, targetType)
|
|
}
|
|
switch targetType.Elem().Kind() {
|
|
case reflect.Int:
|
|
return reflect.ValueOf(convertNumericSlice[int64, int](val.IntValues)), nil
|
|
case reflect.Int8:
|
|
return reflect.ValueOf(convertNumericSlice[int64, int8](val.IntValues)), nil
|
|
case reflect.Int16:
|
|
return reflect.ValueOf(convertNumericSlice[int64, int16](val.IntValues)), nil
|
|
case reflect.Int32:
|
|
return reflect.ValueOf(convertNumericSlice[int64, int32](val.IntValues)), nil
|
|
case reflect.Int64:
|
|
return reflect.ValueOf(val.IntValues), nil
|
|
default:
|
|
return reflect.Value{}, fmt.Errorf("%w: expected int, got %v", ErrUnmatchingType, targetType.Elem())
|
|
}
|
|
case rpcv1.ModuleSpec_ValueTypeDoubleSlice:
|
|
if targetType.Kind() != reflect.Slice {
|
|
return reflect.Value{}, fmt.Errorf("%w: expected slice, got %v", ErrUnmatchingType, targetType)
|
|
}
|
|
switch targetType.Elem().Kind() {
|
|
case reflect.Float32:
|
|
return reflect.ValueOf(convertNumericSlice[float64, float32](val.DoubleValues)), nil
|
|
case reflect.Float64:
|
|
return reflect.ValueOf(val.DoubleValues), nil
|
|
default:
|
|
return reflect.Value{}, fmt.Errorf("%w: expected int, got %v", ErrUnmatchingType, targetType.Elem())
|
|
}
|
|
}
|
|
|
|
return reflect.Value{}, nil
|
|
}
|
|
|
|
func (cfg unmarshalConfig) fieldName(field reflect.StructField) string {
|
|
tagVal, ok := field.Tag.Lookup(cfg.TagName)
|
|
if !ok {
|
|
return field.Name
|
|
}
|
|
|
|
split := strings.Split(tagVal, ",")
|
|
|
|
switch {
|
|
case len(split) == 0:
|
|
fallthrough
|
|
case split[0] == "":
|
|
return strings.ToLower(field.Name)
|
|
default:
|
|
return strings.ToLower(split[0])
|
|
}
|
|
}
|
|
|
|
type numeric interface {
|
|
Integer | Float
|
|
}
|
|
|
|
func convertNumericSlice[TIn numeric, TOut numeric](input []TIn) []TOut {
|
|
output := make([]TOut, len(input))
|
|
for i, v := range input {
|
|
output[i] = TOut(v)
|
|
}
|
|
return output
|
|
}
|