package redis import ( "errors" "fmt" "strings" "github.com/go-redis/redis/v8" "github.com/baez90/nurse/check" "github.com/baez90/nurse/grammar" "github.com/baez90/nurse/validation" ) var ( ErrNoSuchValidator = errors.New("no such validator") _ CmdValidator = (ValidationChain)(nil) _ CmdValidator = (*GenericCmdValidator)(nil) knownValidators = map[string]func() unmarshallableCmdValidator{ "equals": func() unmarshallableCmdValidator { return new(GenericCmdValidator) }, } ) type ( CmdValidator interface { Validate(cmder redis.Cmder) error } unmarshallableCmdValidator interface { CmdValidator check.CallUnmarshaler } ) func ValidatorsForFilters(filters *grammar.Filters) (ValidationChain, error) { if filters == nil || filters.Chain == nil { return ValidationChain{}, nil } chain := make(ValidationChain, 0, len(filters.Chain)) for i := range filters.Chain { validationCall := filters.Chain[i] if validatorProvider, ok := knownValidators[strings.ToLower(validationCall.Name)]; !ok { return nil, fmt.Errorf("%w: %s", ErrNoSuchValidator, validationCall.Name) } else { validator := validatorProvider() if err := validator.UnmarshalCall(validationCall); err != nil { return nil, err } chain = append(chain, validator) } } return chain, nil } type ValidationChain []CmdValidator func (v ValidationChain) UnmarshalCall(grammar.Call) error { return errors.New("cannot unmarshal chain") } func (v ValidationChain) Validate(cmder redis.Cmder) error { for i := range v { if err := v[i].Validate(cmder); err != nil { return err } } return nil } func GenericCommandValidatorFor[T validation.Value](want T) (*GenericCmdValidator, error) { comparator, err := validation.JSONValueComparatorFor(want) if err != nil { return nil, err } return &GenericCmdValidator{ comparator: comparator, }, nil } type GenericCmdValidator struct { comparator validation.ValueComparator } func (g *GenericCmdValidator) UnmarshalCall(c grammar.Call) error { if err := grammar.ValidateParameterCount(c.Params, 1); err != nil { return err } var err error switch c.Params[0].Type() { case grammar.ParamTypeInt: if g.comparator, err = validation.JSONValueComparatorFor(*c.Params[0].Int); err != nil { return err } case grammar.ParamTypeFloat: if g.comparator, err = validation.JSONValueComparatorFor(*c.Params[0].Float); err != nil { return err } case grammar.ParamTypeString: if g.comparator, err = validation.JSONValueComparatorFor(*c.Params[0].String); err != nil { return err } default: return fmt.Errorf("param type is unkown") } return nil } func (g GenericCmdValidator) Validate(cmder redis.Cmder) error { if err := cmder.Err(); err != nil { return err } switch in := cmder.(type) { case *redis.StringCmd: res, err := in.Result() if err != nil { return err } if !g.comparator.Equals(res) { return fmt.Errorf("got %s - but didn't match expected value", res) } } return nil }