nurse/protocols/redis/validation.go

98 lines
2 KiB
Go
Raw Permalink Normal View History

2022-04-28 16:35:02 +00:00
package redis
import (
"errors"
"log/slog"
2022-04-28 16:35:02 +00:00
"github.com/redis/go-redis/v9"
2022-09-22 09:46:36 +00:00
"code.icb4dc0.de/prskr/nurse/grammar"
"code.icb4dc0.de/prskr/nurse/validation"
2022-04-28 16:35:02 +00:00
)
var (
_ CmdValidator = (*GenericCmdValidator)(nil)
2022-06-09 20:12:45 +00:00
registry = validation.NewRegistry[redis.Cmder]()
2022-04-28 16:35:02 +00:00
)
2022-06-09 20:12:45 +00:00
func init() {
registry.Register("equals", func() validation.FromCall[redis.Cmder] {
return new(GenericCmdValidator)
})
}
2022-06-09 20:12:45 +00:00
type CmdValidator interface {
Validate(cmder redis.Cmder) error
2022-04-28 16:35:02 +00:00
}
// GenericCmdValidator checks a Redis command result against an expected
// value. It is created either by GenericCommandValidatorFor or, via the
// registry, by a factory followed by UnmarshalCall.
type GenericCmdValidator struct {
	// comparator performs the comparison against the expected value.
	comparator validation.ValueComparator
}

// GenericCommandValidatorFor builds a GenericCmdValidator whose comparator is
// obtained from validation.JSONValueComparatorFor(want). If building the
// comparator fails, the validator is nil and the error is returned unchanged.
func GenericCommandValidatorFor[T validation.Value](want T) (*GenericCmdValidator, error) {
	v := new(GenericCmdValidator)

	var err error
	if v.comparator, err = validation.JSONValueComparatorFor(want); err != nil {
		return nil, err
	}

	return v, nil
}
2022-04-28 16:35:02 +00:00
func (g *GenericCmdValidator) UnmarshalCall(c grammar.Call) error {
if err := grammar.ValidateParameterCount(c.Params, 1); err != nil {
return err
}
var err error
switch c.Params[0].Type() {
case grammar.ParamTypeInt:
if g.comparator, err = validation.JSONValueComparatorFor(*c.Params[0].Int); err != nil {
return err
}
case grammar.ParamTypeFloat:
if g.comparator, err = validation.JSONValueComparatorFor(*c.Params[0].Float); err != nil {
return err
}
case grammar.ParamTypeString:
if g.comparator, err = validation.JSONValueComparatorFor(*c.Params[0].String); err != nil {
return err
}
2022-06-09 20:40:32 +00:00
case grammar.ParamTypeUnknown:
fallthrough
default:
2022-06-09 20:12:45 +00:00
return errors.New("param type is unknown")
}
return nil
}
2022-06-09 20:12:45 +00:00
func (g *GenericCmdValidator) Validate(cmder redis.Cmder) error {
slog.Default().Debug("Validate Redis result")
2022-04-28 16:35:02 +00:00
if err := cmder.Err(); err != nil {
return err
}
switch in := cmder.(type) {
case *redis.StringCmd:
if err := in.Err(); err != nil {
return err
}
res, err := in.Result()
if err != nil {
return err
}
2022-06-09 20:12:45 +00:00
return g.comparator.Equals(res)
case *redis.StatusCmd:
if err := in.Err(); err != nil {
return err
}
2022-04-28 16:35:02 +00:00
}
return nil
}