package protocol

import (
	"errors"
	"fmt"
	"reflect"
	"sort"
	"strings"

	rpcv1 "code.icb4dc0.de/buildr/wasi-module-sdk-go/protocol/generated/rpc/v1"
)

var (
	// ErrUnmatchingType is returned, possibly wrapped, when a wire value
	// cannot be stored in the Go type of its destination.
	ErrUnmatchingType = errors.New("field type does not match wire value")
)

// Unmarshal decodes the values of input into the struct pointed to by into.
//
// Struct fields are matched against the keys of input.Values by name: the
// part of the `hcl` struct tag before the first comma (lowercased) if the tag
// is present and non-empty, the lowercased field name if the tag name is
// empty, or the field name verbatim if there is no tag. Unexported fields are
// ignored, keys absent from input leave the field untouched, and keys that
// match no field produce an "unknown fields" error. input is not modified.
func Unmarshal(input *rpcv1.ModuleSpec, into any) error {
	cfg := unmarshalConfig{
		TagName: "hcl",
	}

	val := reflect.ValueOf(into)
	if val.Kind() != reflect.Pointer {
		return errors.New("into value must be a pointer")
	}
	// The following cases would otherwise panic inside package reflect.
	if val.IsNil() {
		return errors.New("into value must be a non-nil pointer")
	}
	target := val.Elem()
	if target.Kind() != reflect.Struct {
		return fmt.Errorf("into value must point to a struct, got %v", val.Type())
	}
	if input == nil {
		return errors.New("input spec must not be nil")
	}

	return cfg.unmarshal(input.Values, target, target.Type())
}

// unmarshalConfig controls how struct fields are matched to spec keys.
type unmarshalConfig struct {
	// TagName is the struct tag consulted by fieldName.
	TagName string
}

// unmarshal assigns the entries of input to the exported fields of into,
// a struct value of type intoType (or a pointer to one).
//
// It returns an error if a value cannot be mapped to its field's type, or if
// input contains keys that match no field. input is only read; consumed keys
// are tracked locally so the caller's spec stays intact.
func (cfg unmarshalConfig) unmarshal(input map[string]*rpcv1.ModuleSpec_Value, into reflect.Value, intoType reflect.Type) error {
	if into.Kind() == reflect.Pointer {
		into = into.Elem()
	}

	consumed := make(map[string]struct{}, len(input))
	for i := 0; i < intoType.NumField(); i++ {
		field := intoType.Field(i)
		if !field.IsExported() {
			continue
		}

		name := cfg.fieldName(field)
		if _, done := consumed[name]; done {
			// A key is assigned to the first matching field only.
			continue
		}
		specVal, ok := input[name]
		if !ok {
			continue
		}

		mapped, err := cfg.mapSpecValueTo(specVal, field.Type)
		if err != nil {
			return fmt.Errorf("failed to map value for field %s: %w", field.Name, err)
		}
		into.Field(i).Set(mapped)
		consumed[name] = struct{}{}
	}

	if len(consumed) < len(input) {
		unknown := make([]string, 0, len(input)-len(consumed))
		for k := range input {
			if _, ok := consumed[k]; !ok {
				unknown = append(unknown, k)
			}
		}
		// Sorted so the error message is deterministic.
		sort.Strings(unknown)
		return fmt.Errorf("unknown fields: %v", unknown)
	}

	return nil
}

// mapSpecValueTo converts a single wire value into a reflect.Value whose type
// is exactly targetType, ready to be assigned to a field, map entry or
// pointer target.
//
// Pointer targets are decoded into a freshly allocated pointee. Values are
// built from targetType itself (reflect.New/MakeSlice/MakeMap), so named
// types such as `type Mode string` or `type Tags []string` are supported.
// Every failure wraps ErrUnmatchingType; an invalid Value is never returned
// together with a nil error.
func (cfg unmarshalConfig) mapSpecValueTo(val *rpcv1.ModuleSpec_Value, targetType reflect.Type) (reflect.Value, error) {
	if val == nil {
		return reflect.Value{}, fmt.Errorf("%w: missing value for %v", ErrUnmatchingType, targetType)
	}

	if targetType.Kind() == reflect.Pointer {
		elem, err := cfg.mapSpecValueTo(val, targetType.Elem())
		if err != nil {
			return reflect.Value{}, err
		}
		ptr := reflect.New(targetType.Elem())
		ptr.Elem().Set(elem)
		return ptr, nil
	}

	switch val.Type {
	case rpcv1.ModuleSpec_ValueTypeUnknown:
		return reflect.Value{}, ErrUnmatchingType
	case rpcv1.ModuleSpec_ValueTypeObject:
		switch targetType.Kind() {
		case reflect.Struct:
			structVal := reflect.New(targetType).Elem()
			if err := cfg.unmarshal(val.ComplexValue, structVal, targetType); err != nil {
				return reflect.Value{}, err
			}
			return structVal, nil
		case reflect.Map:
			// Objects may also be decoded into maps, like ValueTypeMap.
			return cfg.mapComplexToMap(val.ComplexValue, targetType)
		default:
			return reflect.Value{}, fmt.Errorf("%w: expected struct, got %s", ErrUnmatchingType, targetType.String())
		}
	case rpcv1.ModuleSpec_ValueTypeMap:
		return cfg.mapComplexToMap(val.ComplexValue, targetType)
	case rpcv1.ModuleSpec_ValueTypeSingle:
		return cfg.mapSingleValue(val, targetType)
	case rpcv1.ModuleSpec_ValueTypeBoolSlice:
		values := val.BoolValues
		return cfg.mapSlice(targetType, reflect.Bool, len(values), func(dst reflect.Value, i int) error {
			dst.SetBool(values[i])
			return nil
		})
	case rpcv1.ModuleSpec_ValueTypeStringSlice:
		values := val.StringValues
		return cfg.mapSlice(targetType, reflect.String, len(values), func(dst reflect.Value, i int) error {
			dst.SetString(values[i])
			return nil
		})
	case rpcv1.ModuleSpec_ValueTypeIntSlice:
		values := val.IntValues
		return cfg.mapSlice(targetType, reflect.Int64, len(values), func(dst reflect.Value, i int) error {
			return cfg.setInt(dst, values[i])
		})
	case rpcv1.ModuleSpec_ValueTypeDoubleSlice:
		values := val.DoubleValues
		return cfg.mapSlice(targetType, reflect.Float64, len(values), func(dst reflect.Value, i int) error {
			return cfg.setFloat(dst, values[i])
		})
	default:
		return reflect.Value{}, fmt.Errorf("%w: unsupported value type %v", ErrUnmatchingType, val.Type)
	}
}

// mapComplexToMap decodes the entries of input into a new map of targetType,
// which must be a map with a string-kinded key.
func (cfg unmarshalConfig) mapComplexToMap(input map[string]*rpcv1.ModuleSpec_Value, targetType reflect.Type) (reflect.Value, error) {
	if targetType.Kind() != reflect.Map {
		return reflect.Value{}, fmt.Errorf("%w: expected map, got %v", ErrUnmatchingType, targetType)
	}
	if targetType.Key().Kind() != reflect.String {
		return reflect.Value{}, fmt.Errorf("%w: expected map with string keys, got %v", ErrUnmatchingType, targetType)
	}

	out := reflect.MakeMap(targetType)
	for k, v := range input {
		mapped, err := cfg.mapSpecValueTo(v, targetType.Elem())
		if err != nil {
			return reflect.Value{}, fmt.Errorf("map key %q: %w", k, err)
		}
		// Build the key from the map's key type so named string keys work.
		key := reflect.New(targetType.Key()).Elem()
		key.SetString(k)
		out.SetMapIndex(key, mapped)
	}
	return out, nil
}

// mapSingleValue converts the scalar carried by val.SingleValue into
// targetType. Integers may target any signed integer kind and doubles any
// float kind; values that do not fit the target are rejected, not truncated.
func (cfg unmarshalConfig) mapSingleValue(val *rpcv1.ModuleSpec_Value, targetType reflect.Type) (reflect.Value, error) {
	out := reflect.New(targetType).Elem()

	switch sv := val.SingleValue.(type) {
	case *rpcv1.ModuleSpec_Value_BoolValue:
		if err := cfg.checkScalarTarget(targetType, reflect.Bool); err != nil {
			return reflect.Value{}, err
		}
		out.SetBool(sv.BoolValue)
	case *rpcv1.ModuleSpec_Value_StringValue:
		if err := cfg.checkScalarTarget(targetType, reflect.String); err != nil {
			return reflect.Value{}, err
		}
		out.SetString(sv.StringValue)
	case *rpcv1.ModuleSpec_Value_IntValue:
		if err := cfg.checkScalarTarget(targetType, reflect.Int64); err != nil {
			return reflect.Value{}, err
		}
		if err := cfg.setInt(out, int64(sv.IntValue)); err != nil {
			return reflect.Value{}, err
		}
	case *rpcv1.ModuleSpec_Value_DoubleValue:
		if err := cfg.checkScalarTarget(targetType, reflect.Float64); err != nil {
			return reflect.Value{}, err
		}
		if err := cfg.setFloat(out, float64(sv.DoubleValue)); err != nil {
			return reflect.Value{}, err
		}
	default:
		return reflect.Value{}, fmt.Errorf("%w: unsupported single value %T", ErrUnmatchingType, val.SingleValue)
	}

	return out, nil
}

// mapSlice builds a slice of targetType with n elements, each filled by
// setElem. The element type is validated against the wire kind up front, so
// an empty wire slice is still rejected for an incompatible target.
func (cfg unmarshalConfig) mapSlice(targetType reflect.Type, wire reflect.Kind, n int, setElem func(dst reflect.Value, i int) error) (reflect.Value, error) {
	if targetType.Kind() != reflect.Slice {
		return reflect.Value{}, fmt.Errorf("%w: expected slice, got %v", ErrUnmatchingType, targetType)
	}
	if err := cfg.checkScalarTarget(targetType.Elem(), wire); err != nil {
		return reflect.Value{}, err
	}

	out := reflect.MakeSlice(targetType, n, n)
	for i := 0; i < n; i++ {
		if err := setElem(out.Index(i), i); err != nil {
			return reflect.Value{}, fmt.Errorf("element %d: %w", i, err)
		}
	}
	return out, nil
}

// checkScalarTarget returns an error wrapping ErrUnmatchingType unless a wire
// scalar of kind wire (Bool, String, Int64 or Float64) can be stored in t.
func (unmarshalConfig) checkScalarTarget(t reflect.Type, wire reflect.Kind) error {
	var expected string
	switch wire {
	case reflect.Bool:
		if t.Kind() == reflect.Bool {
			return nil
		}
		expected = "bool"
	case reflect.String:
		if t.Kind() == reflect.String {
			return nil
		}
		expected = "string"
	case reflect.Int64:
		switch t.Kind() {
		case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
			return nil
		}
		expected = "int"
	case reflect.Float64:
		switch t.Kind() {
		case reflect.Float32, reflect.Float64:
			return nil
		}
		expected = "float"
	default:
		expected = wire.String()
	}
	return fmt.Errorf("%w: expected %s, got %v", ErrUnmatchingType, expected, t)
}

// setInt stores i in dst (a settable signed-integer value), rejecting values
// that would overflow dst's type instead of truncating them.
func (unmarshalConfig) setInt(dst reflect.Value, i int64) error {
	if dst.OverflowInt(i) {
		return fmt.Errorf("%w: value %d overflows %v", ErrUnmatchingType, i, dst.Type())
	}
	dst.SetInt(i)
	return nil
}

// setFloat stores f in dst (a settable float value), rejecting finite values
// outside dst's range (relevant for float32 targets).
func (unmarshalConfig) setFloat(dst reflect.Value, f float64) error {
	if dst.OverflowFloat(f) {
		return fmt.Errorf("%w: value %g overflows %v", ErrUnmatchingType, f, dst.Type())
	}
	dst.SetFloat(f)
	return nil
}

// fieldName returns the spec key a struct field is matched against:
// the field name verbatim when the field has no cfg.TagName tag, otherwise
// the lowercased tag name (the part before the first comma), falling back to
// the lowercased field name when that part is empty.
//
// NOTE(review): untagged fields are not lowercased while tagged-but-unnamed
// ones are; kept as-is for compatibility — confirm this is intended.
func (cfg unmarshalConfig) fieldName(field reflect.StructField) string {
	tagVal, ok := field.Tag.Lookup(cfg.TagName)
	if !ok {
		return field.Name
	}

	name, _, _ := strings.Cut(tagVal, ",")
	if name == "" {
		return strings.ToLower(field.Name)
	}
	return strings.ToLower(name)
}

// numeric is the set of types accepted by convertNumericSlice. Integer and
// Float are presumably declared in another file of this package.
type numeric interface {
	Integer | Float
}

// convertNumericSlice returns a new slice holding every element of input
// converted to TOut with a plain Go conversion; narrowing conversions are not
// range-checked.
func convertNumericSlice[TIn numeric, TOut numeric](input []TIn) []TOut {
	output := make([]TOut, len(input))
	for i, v := range input {
		output[i] = TOut(v)
	}
	return output
}