common/protocol/unmarshal.go

206 lines
6.6 KiB
Go

package protocol
import (
"errors"
"fmt"
"reflect"
"strings"
commonv1 "code.icb4dc0.de/buildr/api/generated/common/v1"
)
// ErrUnmatchingType is the sentinel wrapped into every error produced when a
// wire value of a ModuleSpec cannot be decoded into the Go type of its target
// field (for example a string value for an int field). Errors returned by the
// decoder wrap it with %w, so match it with errors.Is rather than ==.
var ErrUnmatchingType = errors.New("field type does not match wire value")
// unmarshalConfigOption is the option type accepted (variadically) by
// Unmarshal. Each option receives a pointer to the unmarshalConfig being built
// and may adjust it; Unmarshal applies all options, in order, before the
// reflective decoding starts.
type unmarshalConfigOption interface {
	applyToUnmarshalConfig(c *unmarshalConfig)
}
// Unmarshal decodes input into the struct pointed to by into.
//
// If into implements SpecUnmarshaler, decoding is delegated entirely to its
// UnmarshalModuleSpec method and opts are not applied. Otherwise the exported
// fields of the target struct are populated via reflection: each field is
// looked up in input.Values by its lower-cased Go field name, and fields
// without a matching entry are left untouched.
//
// An error is returned when into is not a non-nil pointer to a struct, when
// input is nil, or when a wire value cannot be mapped onto its field's type.
func Unmarshal(input *commonv1.ModuleSpec, into any, opts ...unmarshalConfigOption) error {
	if u, ok := into.(SpecUnmarshaler); ok {
		return u.UnmarshalModuleSpec(input)
	}

	cfg := unmarshalConfig{
		protocolConfig: defaultConfig(),
	}
	for _, opt := range opts {
		opt.applyToUnmarshalConfig(&cfg)
	}

	val := reflect.ValueOf(into)
	if val.Kind() != reflect.Pointer {
		return errors.New("into value must be a pointer")
	}
	// A typed nil pointer passes the kind check, but val.Elem() would be the
	// zero Value and the reflective walk would panic on it.
	if val.IsNil() {
		return errors.New("into value must not be a nil pointer")
	}
	target := val.Elem()
	// Only structs can be walked field by field; anything else would panic in
	// NumField.
	if target.Kind() != reflect.Struct {
		return fmt.Errorf("into value must point to a struct, got %s", val.Type())
	}
	if input == nil {
		return errors.New("input module spec must not be nil")
	}

	return cfg.unmarshal(input.Values, target, target.Type())
}
// unmarshalConfig holds the settings used while reflectively decoding a
// ModuleSpec. Unmarshal initialises the embedded protocolConfig from
// defaultConfig() and then lets each unmarshalConfigOption adjust it.
type unmarshalConfig struct {
	protocolConfig
}
// unmarshal fills the exported fields of a struct from input.
//
// into is either the struct value itself (as passed by Unmarshal, addressable
// because it was reached through a pointer) or a pointer to the struct (as
// passed by mapSpecValueTo); intoType is the struct type in both cases. Each
// exported field is looked up in input under its lower-cased Go field name and
// fields without an entry are left untouched. A field that provides a
// SpecValueUnmarshaler (on its own type or on a pointer to it) decodes the
// wire value itself; every other field is converted with mapSpecValueTo.
// Returned errors are prefixed with the offending field name and wrap the
// underlying error with %w.
func (cfg unmarshalConfig) unmarshal(input map[string]*commonv1.ModuleSpec_Value, into reflect.Value, intoType reflect.Type) error {
	// Normalise both call shapes to the settable struct value.
	if into.Kind() == reflect.Pointer {
		into = into.Elem()
	}

	for i := 0; i < intoType.NumField(); i++ {
		tf := intoType.Field(i)
		if !tf.IsExported() {
			continue
		}

		val, ok := input[strings.ToLower(tf.Name)]
		if !ok {
			continue
		}

		field := into.Field(i)
		if u, ok := specValueUnmarshalerOf(field); ok {
			if err := u.UnmarshalSpecValue(val); err != nil {
				return fmt.Errorf("field %s: %w", tf.Name, err)
			}
			continue
		}

		mapped, err := cfg.mapSpecValueTo(val, tf.Type)
		if err != nil {
			return fmt.Errorf("field %s: %w", tf.Name, err)
		}
		// Defensive: reflect.Value.Set panics when handed the zero Value.
		if !mapped.IsValid() {
			return fmt.Errorf("field %s: %w: no value decoded for %s", tf.Name, ErrUnmatchingType, tf.Type)
		}
		field.Set(mapped)
	}

	return nil
}

// specValueUnmarshalerOf returns the SpecValueUnmarshaler responsible for
// decoding into field, checking the field's own type first and then a pointer
// to the field, so that pointer-receiver implementations mutate the field in
// place (mirroring how encoding/json locates Unmarshaler implementations).
func specValueUnmarshalerOf(field reflect.Value) (SpecValueUnmarshaler, bool) {
	// A nil pointer field whose type implements the interface is allocated
	// first; otherwise the method would be invoked on a nil receiver.
	if field.Kind() == reflect.Pointer && field.IsNil() &&
		field.Type().Implements(reflect.TypeOf((*SpecValueUnmarshaler)(nil)).Elem()) {
		field.Set(reflect.New(field.Type().Elem()))
	}
	if u, ok := field.Interface().(SpecValueUnmarshaler); ok {
		return u, true
	}
	if field.CanAddr() {
		if u, ok := field.Addr().Interface().(SpecValueUnmarshaler); ok {
			return u, true
		}
	}
	return nil, false
}
// mapSpecValueTo converts one wire value into a reflect.Value that can be
// assigned to a field of type targetType.
//
// Accepted combinations (named types of the listed kinds work as well, e.g.
// `type Mode string` or `type Port int32`):
//   - object: struct or pointer to struct, decoded recursively via unmarshal
//   - map: map with a string-kinded key; every value is mapped recursively
//   - single: bool, string, any signed integer kind, float32 or float64
//   - bool/string/int/double slices: slice of a matching element kind
//
// Numbers that do not fit the target kind are rejected instead of being
// silently truncated. Every mismatch is reported as an error wrapping
// ErrUnmatchingType; whenever err is nil the returned Value is valid, so it
// can be passed straight to reflect.Value.Set. Slices are freshly allocated
// and never alias the wire message.
func (cfg unmarshalConfig) mapSpecValueTo(val *commonv1.ModuleSpec_Value, targetType reflect.Type) (reflect.Value, error) {
	if val == nil {
		return reflect.Value{}, fmt.Errorf("%w: expected %s, got nil value", ErrUnmatchingType, targetType.String())
	}

	switch val.Type {
	case commonv1.ModuleSpec_ValueTypeUnknown:
		return reflect.Value{}, fmt.Errorf("%w: expected %s", ErrUnmatchingType, targetType.String())
	case commonv1.ModuleSpec_ValueTypeObject:
		return cfg.mapObjectSpecValue(val, targetType)
	case commonv1.ModuleSpec_ValueTypeMap:
		return cfg.mapMapSpecValue(val, targetType)
	case commonv1.ModuleSpec_ValueTypeSingle:
		return mapSingleSpecValue(val, targetType)
	case commonv1.ModuleSpec_ValueTypeBoolSlice:
		if err := checkSpecSliceTarget(targetType, "bool", reflect.Bool); err != nil {
			return reflect.Value{}, err
		}
		return buildSpecSlice(targetType, val.BoolValues, func(elem reflect.Value, v bool) error {
			elem.SetBool(v)
			return nil
		})
	case commonv1.ModuleSpec_ValueTypeStringSlice:
		if err := checkSpecSliceTarget(targetType, "string", reflect.String); err != nil {
			return reflect.Value{}, err
		}
		return buildSpecSlice(targetType, val.StringValues, func(elem reflect.Value, v string) error {
			elem.SetString(v)
			return nil
		})
	case commonv1.ModuleSpec_ValueTypeIntSlice:
		if err := checkSpecSliceTarget(targetType, "int", specIntKinds...); err != nil {
			return reflect.Value{}, err
		}
		return buildSpecSlice(targetType, val.IntValues, setSpecInt)
	case commonv1.ModuleSpec_ValueTypeDoubleSlice:
		if err := checkSpecSliceTarget(targetType, "float", specFloatKinds...); err != nil {
			return reflect.Value{}, err
		}
		return buildSpecSlice(targetType, val.DoubleValues, setSpecFloat)
	default:
		// An unrecognised wire type must not yield a zero Value with a nil
		// error: the caller would hand it to Set, which panics.
		return reflect.Value{}, fmt.Errorf("%w: unsupported wire value type %v for %s", ErrUnmatchingType, val.Type, targetType.String())
	}
}

// specIntKinds and specFloatKinds list the reflect kinds that integer and
// double wire values may be decoded into.
var (
	specIntKinds   = []reflect.Kind{reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64}
	specFloatKinds = []reflect.Kind{reflect.Float32, reflect.Float64}
)

// mapObjectSpecValue decodes an object wire value into a new struct, or a new
// pointer to struct, of targetType.
func (cfg unmarshalConfig) mapObjectSpecValue(val *commonv1.ModuleSpec_Value, targetType reflect.Type) (reflect.Value, error) {
	wantPointer := targetType.Kind() == reflect.Pointer
	structType := targetType
	if wantPointer {
		structType = targetType.Elem()
	}
	if structType.Kind() != reflect.Struct {
		return reflect.Value{}, fmt.Errorf("%w: expected struct, got %v", ErrUnmatchingType, targetType)
	}

	structPtr := reflect.New(structType)
	if err := cfg.unmarshal(val.ComplexValue, structPtr, structType); err != nil {
		return reflect.Value{}, err
	}
	if wantPointer {
		return structPtr, nil
	}
	return structPtr.Elem(), nil
}

// mapMapSpecValue decodes a map wire value into a new map of targetType,
// mapping each entry's value recursively.
func (cfg unmarshalConfig) mapMapSpecValue(val *commonv1.ModuleSpec_Value, targetType reflect.Type) (reflect.Value, error) {
	if targetType.Kind() != reflect.Map {
		return reflect.Value{}, fmt.Errorf("%w: expected map, got %v", ErrUnmatchingType, targetType)
	}
	keyType := targetType.Key()
	// Wire keys are strings; any other key kind would make SetMapIndex panic.
	if keyType.Kind() != reflect.String {
		return reflect.Value{}, fmt.Errorf("%w: expected map with string keys, got %v", ErrUnmatchingType, targetType)
	}

	mapVal := reflect.MakeMapWithSize(targetType, len(val.ComplexValue))
	for k, v := range val.ComplexValue {
		mapped, err := cfg.mapSpecValueTo(v, targetType.Elem())
		if err != nil {
			return reflect.Value{}, fmt.Errorf("key %q: %w", k, err)
		}
		// Convert supports named string key types such as `type Key string`.
		mapVal.SetMapIndex(reflect.ValueOf(k).Convert(keyType), mapped)
	}
	return mapVal, nil
}

// mapSingleSpecValue decodes a scalar wire value into a new value of
// targetType.
func mapSingleSpecValue(val *commonv1.ModuleSpec_Value, targetType reflect.Type) (reflect.Value, error) {
	out := reflect.New(targetType).Elem()

	switch sv := val.SingleValue.(type) {
	case *commonv1.ModuleSpec_Value_BoolValue:
		if targetType.Kind() != reflect.Bool {
			return reflect.Value{}, fmt.Errorf("%w: expected bool, got %v", ErrUnmatchingType, targetType)
		}
		out.SetBool(sv.BoolValue)
	case *commonv1.ModuleSpec_Value_StringValue:
		if targetType.Kind() != reflect.String {
			return reflect.Value{}, fmt.Errorf("%w: expected string, got %v", ErrUnmatchingType, targetType)
		}
		out.SetString(sv.StringValue)
	case *commonv1.ModuleSpec_Value_IntValue:
		if !specKindIn(targetType.Kind(), specIntKinds...) {
			return reflect.Value{}, fmt.Errorf("%w: expected int, got %v", ErrUnmatchingType, targetType)
		}
		if err := setSpecInt(out, int64(sv.IntValue)); err != nil {
			return reflect.Value{}, err
		}
	case *commonv1.ModuleSpec_Value_DoubleValue:
		if !specKindIn(targetType.Kind(), specFloatKinds...) {
			return reflect.Value{}, fmt.Errorf("%w: expected float, got %v", ErrUnmatchingType, targetType)
		}
		if err := setSpecFloat(out, float64(sv.DoubleValue)); err != nil {
			return reflect.Value{}, err
		}
	default:
		return reflect.Value{}, fmt.Errorf("%w: expected %v, got single value without payload", ErrUnmatchingType, targetType)
	}

	return out, nil
}

// checkSpecSliceTarget verifies that targetType is a slice whose element kind
// is one of kinds; want names the expected element type in the error message.
func checkSpecSliceTarget(targetType reflect.Type, want string, kinds ...reflect.Kind) error {
	if targetType.Kind() != reflect.Slice {
		return fmt.Errorf("%w: expected slice, got %v", ErrUnmatchingType, targetType)
	}
	if !specKindIn(targetType.Elem().Kind(), kinds...) {
		return fmt.Errorf("%w: expected %s, got %v", ErrUnmatchingType, want, targetType.Elem())
	}
	return nil
}

// specKindIn reports whether k is one of kinds.
func specKindIn(k reflect.Kind, kinds ...reflect.Kind) bool {
	for _, want := range kinds {
		if k == want {
			return true
		}
	}
	return false
}

// buildSpecSlice allocates a new value of the slice type targetType holding
// one element per entry of values, each written by set. A nil values slice
// yields a nil slice of targetType.
func buildSpecSlice[T any](targetType reflect.Type, values []T, set func(elem reflect.Value, v T) error) (reflect.Value, error) {
	if values == nil {
		return reflect.Zero(targetType), nil
	}
	out := reflect.MakeSlice(targetType, len(values), len(values))
	for i, v := range values {
		if err := set(out.Index(i), v); err != nil {
			return reflect.Value{}, fmt.Errorf("index %d: %w", i, err)
		}
	}
	return out, nil
}

// setSpecInt stores v into the signed-integer-kinded, settable elem, rejecting
// values that would overflow elem's kind.
func setSpecInt(elem reflect.Value, v int64) error {
	if elem.OverflowInt(v) {
		return fmt.Errorf("%w: %d overflows %v", ErrUnmatchingType, v, elem.Type())
	}
	elem.SetInt(v)
	return nil
}

// setSpecFloat stores v into the float-kinded, settable elem, rejecting finite
// values that would overflow a float32.
func setSpecFloat(elem reflect.Value, v float64) error {
	if elem.OverflowFloat(v) {
		return fmt.Errorf("%w: %g overflows %v", ErrUnmatchingType, v, elem.Type())
	}
	elem.SetFloat(v)
	return nil
}
type numeric interface {
Integer | Float
}
// convertNumericSlice returns a newly allocated slice containing every
// element of input converted to TOut with an ordinary Go conversion. Narrowing
// conversions (e.g. int64 to int8, float64 to float32) follow Go's conversion
// rules and may lose information. The result is never nil and has
// len(input) elements, even when input is nil.
func convertNumericSlice[TIn numeric, TOut numeric](input []TIn) []TOut {
	converted := make([]TOut, 0, len(input))
	for idx := range input {
		converted = append(converted, TOut(input[idx]))
	}
	return converted
}