diff --git a/go.mod b/go.mod index 8b9098a..aaae089 100644 --- a/go.mod +++ b/go.mod @@ -19,6 +19,8 @@ require ( github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1 // indirect github.com/Microsoft/go-winio v0.5.2 // indirect github.com/Microsoft/hcsshim v0.9.3 // indirect + github.com/PaesslerAG/gval v1.0.0 // indirect + github.com/PaesslerAG/jsonpath v0.1.1 // indirect github.com/cenkalti/backoff/v4 v4.1.3 // indirect github.com/cespare/xxhash/v2 v2.1.2 // indirect github.com/containerd/cgroups v1.0.4 // indirect diff --git a/go.sum b/go.sum index a88e43b..00991b9 100644 --- a/go.sum +++ b/go.sum @@ -66,6 +66,11 @@ github.com/Microsoft/hcsshim/test v0.0.0-20201218223536-d3e5debf77da/go.mod h1:5 github.com/Microsoft/hcsshim/test v0.0.0-20210227013316-43a75bb4edd3/go.mod h1:mw7qgWloBUl75W/gVH3cQszUg1+gUITj7D6NY7ywVnY= github.com/NYTimes/gziphandler v0.0.0-20170623195520-56545f4a5d46/go.mod h1:3wb06e3pkSAbeQ52E9H9iFoQsEEwGN64994WTCIhntQ= github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAElWljhcU= +github.com/PaesslerAG/gval v1.0.0 h1:GEKnRwkWDdf9dOmKcNrar9EA1bz1z9DqPIO1+iLzhd8= +github.com/PaesslerAG/gval v1.0.0/go.mod h1:y/nm5yEyTeX6av0OfKJNp9rBNj2XrGhAf5+v24IBN1I= +github.com/PaesslerAG/jsonpath v0.1.0/go.mod h1:4BzmtoM/PI8fPO4aQGIusjGxGir2BzcV0grWtFzq1Y8= +github.com/PaesslerAG/jsonpath v0.1.1 h1:c1/AToHQMVsduPAa4Vh6xp2U0evy4t8SWp8imEsylIk= +github.com/PaesslerAG/jsonpath v0.1.1/go.mod h1:lVboNxFGal/VwW6d9JzIy56bUsYAP6tH/x80vjnCseY= github.com/PuerkitoBio/purell v1.0.0/go.mod h1:c11w/QuzBsJSee3cPx9rAFu61PvFxuPbtSwDGJws/X0= github.com/PuerkitoBio/purell v1.1.1/go.mod h1:c11w/QuzBsJSee3cPx9rAFu61PvFxuPbtSwDGJws/X0= github.com/PuerkitoBio/urlesc v0.0.0-20160726150825-5bd2802263f2/go.mod h1:uGdkoq3SwY9Y+13GIhn11/XLaGBb4BfwItxLd5jeuXE= diff --git a/grammar/param.go b/grammar/param.go index 6d7a377..2d88a3c 100644 --- a/grammar/param.go +++ b/grammar/param.go @@ -11,12 +11,37 @@ func ValidateParameterCount(params []Param, expected int) error { return nil } +type ParamType uint8 + +const ( + ParamTypeUnknown ParamType = iota + ParamTypeString + ParamTypeInt + ParamTypeFloat +) + type Param struct { String *string `parser:"@String | @RawString"` Int *int `parser:"| @Int"` Float *float64 `parser:"| @Float"` } +func (p Param) Type() ParamType { + if p.String != nil { + return ParamTypeString + } + + if p.Int != nil { + return ParamTypeInt + } + + if p.Float != nil { + return ParamTypeFloat + } + + return ParamTypeUnknown +} + func (p Param) AsString() (string, error) { if p.String == nil { return "", fmt.Errorf("string is nil %w", ErrTypeMismatch) diff --git a/protocols/redis/checks_test.go b/protocols/redis/checks_test.go index 4e7045f..ff703fa 100644 --- a/protocols/redis/checks_test.go +++ b/protocols/redis/checks_test.go @@ -37,7 +37,7 @@ func TestChecks_Execute(t *testing.T) { }, { name: "Get value - validate value", - check: `redis.GET("%s", "some_key") => String("some_value")`, + check: `redis.GET("%s", "some_key") => Equals("some_value")`, setup: func(tb testing.TB, cli redisCli.UniversalClient) { tb.Helper() td.CmpNoError(tb, cli.Set(context.Background(), "some_key", "some_value", 0).Err()) diff --git a/protocols/redis/ping.go b/protocols/redis/ping.go index 3892c01..4cccd46 100644 --- a/protocols/redis/ping.go +++ b/protocols/redis/ping.go @@ -38,7 +38,8 @@ func (p *PingCheck) UnmarshalCheck(c grammar.Check, lookup config.ServerLookup) serverAndMessageArgCount = 2 ) - p.validators = append(p.validators, StringCmdValidator("PONG")) + val, _ := GenericCommandValidatorFor("PONG") + p.validators = append(p.validators, val) init := c.Initiator switch len(init.Params) { @@ -48,7 +49,8 @@ func (p *PingCheck) UnmarshalCheck(c grammar.Check, lookup config.ServerLookup) if msg, err := init.Params[1].AsString(); err != nil { return err } else { - p.validators = ValidationChain{StringCmdValidator(msg)} + val, _ := GenericCommandValidatorFor(msg) + p.validators = ValidationChain{val} } fallthrough case serverOnlyArgCount: diff --git a/protocols/redis/validation.go b/protocols/redis/validation.go index ce1ee3b..dc06ca9 100644 --- a/protocols/redis/validation.go +++ b/protocols/redis/validation.go @@ -9,17 +9,18 @@ import ( "github.com/baez90/nurse/check" "github.com/baez90/nurse/grammar" + "github.com/baez90/nurse/validation" ) var ( ErrNoSuchValidator = errors.New("no such validator") _ CmdValidator = (ValidationChain)(nil) - _ CmdValidator = (*StringCmdValidator)(nil) + _ CmdValidator = (*GenericCmdValidator)(nil) knownValidators = map[string]func() unmarshallableCmdValidator{ - "string": func() unmarshallableCmdValidator { - return new(StringCmdValidator) + "equals": func() unmarshallableCmdValidator { + return new(GenericCmdValidator) }, } ) @@ -71,33 +72,61 @@ func (v ValidationChain) Validate(cmder redis.Cmder) error { return nil } -type StringCmdValidator string +func GenericCommandValidatorFor[T validation.Value](want T) (*GenericCmdValidator, error) { + comparator, err := validation.JSONValueComparatorFor(want) + if err != nil { + return nil, err + } + return &GenericCmdValidator{ + comparator: comparator, + }, nil +} -func (s *StringCmdValidator) UnmarshalCall(c grammar.Call) error { +type GenericCmdValidator struct { + comparator validation.ValueComparator +} + +func (g *GenericCmdValidator) UnmarshalCall(c grammar.Call) error { if err := grammar.ValidateParameterCount(c.Params, 1); err != nil { return err } - want, err := c.Params[0].AsString() - if err != nil { - return err + var err error + + switch c.Params[0].Type() { + case grammar.ParamTypeInt: + if g.comparator, err = validation.JSONValueComparatorFor(*c.Params[0].Int); err != nil { + return err + } + case grammar.ParamTypeFloat: + if g.comparator, err = validation.JSONValueComparatorFor(*c.Params[0].Float); err != nil { + return err + } + case grammar.ParamTypeString: + if g.comparator, err = validation.JSONValueComparatorFor(*c.Params[0].String); err != nil { + return err + } + default: + return fmt.Errorf("param type is unkown") } - *s = StringCmdValidator(want) return nil } -func (s StringCmdValidator) Validate(cmder redis.Cmder) error { +func (g GenericCmdValidator) Validate(cmder redis.Cmder) error { if err := cmder.Err(); err != nil { return err } - if stringCmd, ok := cmder.(*redis.StringCmd); !ok { - return errors.New("not a string result") - } else if got, err := stringCmd.Result(); err != nil { - return err - } else if want := string(s); got != want { - return fmt.Errorf("want %s but got %s", want, got) + switch in := cmder.(type) { + case *redis.StringCmd: + res, err := in.Result() + if err != nil { + return err + } + if !g.comparator.Equals(res) { + return fmt.Errorf("got %s - but didn't match expected value", res) + } } return nil diff --git a/validation/comparator.go b/validation/comparator.go new file mode 100644 index 0000000..d426d74 --- /dev/null +++ b/validation/comparator.go @@ -0,0 +1,39 @@ +package validation + +import "math" + +const equalityThreshold = 0.00000001 + +var ( + _ ValueComparator = (*GenericComparator[int])(nil) + _ ValueComparator = FloatComparator(0) +) + +type ValueComparator interface { + Equals(got any) bool +} + +type GenericComparator[T int | string] struct { + Want T + Parser func(got any) (T, error) +} + +func (g GenericComparator[T]) Equals(got any) bool { + parsed, err := g.Parser(got) + if err != nil { + return false + } + + return parsed == g.Want +} + +type FloatComparator float64 + +func (f FloatComparator) Equals(got any) bool { + val, err := ParseJSONFloat(got) + if err != nil { + return false + } + + return math.Abs(float64(f)-val) < equalityThreshold +} diff --git a/validation/conversion.go b/validation/conversion.go new file mode 100644 index 0000000..db4b3e7 --- /dev/null +++ b/validation/conversion.go @@ -0,0 +1,43 @@ +package validation + +func ToFloat64(val any) float64 { + switch i := val.(type) { + case float64: + return i + case float32: + return float64(i) + default: + return 0 + } +} + +func ToInt(val any) int { + switch i := val.(type) { + case int: + return i + case int8: + return int(i) + case int16: + return int(i) + case int32: + return int(i) + case int64: + return int(i) + case uint: + return int(i) + case uint8: + return int(i) + case uint16: + return int(i) + case uint32: + return int(i) + case uint64: + return int(i) + case float32: + return int(i) + case float64: + return int(i) + default: + return 0 + } +} diff --git a/validation/json_parsing.go b/validation/json_parsing.go new file mode 100644 index 0000000..6be2609 --- /dev/null +++ b/validation/json_parsing.go @@ -0,0 +1,76 @@ +package validation + +import ( + "encoding/json" + "fmt" +) + +func ParseJSONInt(got any) (int, error) { + switch in := got.(type) { + case float32, float64: + return ToInt(in), nil + case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64: + return ToInt(got), nil + case []byte: + var val json.Number + if err := json.Unmarshal(in, &val); err != nil { + return 0, err + } + if i, err := val.Int64(); err != nil { + return 0, err + } else { + return int(i), nil + } + case string: + var val json.Number + if err := json.Unmarshal([]byte(in), &val); err != nil { + return 0, err + } + if i, err := val.Int64(); err != nil { + return 0, err + } else { + return int(i), nil + } + default: + return 0, fmt.Errorf("cannot convert value %v to int", got) + } +} + +func ParseJSONFloat(got any) (float64, error) { + switch in := got.(type) { + case float32: + return float64(in), nil + case float64: + return in, nil + case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64: + i := got.(int) + return float64(i), nil + case []byte: + var val json.Number + if err := json.Unmarshal(in, &val); err != nil { + return 0, err + } + return val.Float64() + case string: + var val json.Number + if err := json.Unmarshal([]byte(in), &val); err != nil { + return 0, err + } + return val.Float64() + default: + return 0, fmt.Errorf("cannot convert value %v to float", got) + } +} + +func ParseJSONString(got any) (string, error) { + switch in := got.(type) { + case float32, float64, int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64: + return fmt.Sprintf("%v", in), nil + case []byte: + return string(in), nil + case string: + return in, nil + default: + return "", fmt.Errorf("cannot convert value %v to float", got) + } +} diff --git a/validation/jsonpath.go b/validation/jsonpath.go new file mode 100644 index 0000000..08a88f5 --- /dev/null +++ b/validation/jsonpath.go @@ -0,0 +1,65 @@ +package validation + +import ( + "encoding/json" + "fmt" + + "github.com/PaesslerAG/jsonpath" +) + +var _ ValueComparator = (*JSONPathValidator)(nil) + +func JSONPathValidatorFor[T Value](path string, want T) (*JSONPathValidator, error) { + comparator, err := JSONValueComparatorFor(want) + if err != nil { + return nil, err + } + + return &JSONPathValidator{ + Path: path, + Comparator: comparator, + }, nil +} + +type JSONPathValidator struct { + Path string + Comparator ValueComparator +} + +func (j JSONPathValidator) Equals(got any) bool { + parsed, err := parse(got) + if err != nil { + return false + } + val, err := jsonpath.Get(j.Path, parsed) + if err != nil { + return false + } + + return j.Comparator.Equals(val) +} + +func parse(in any) (any, error) { + keyValue := make(map[string]any) + arr := make([]any, 0) + switch data := in.(type) { + case []byte: + if err := json.Unmarshal(data, &keyValue); err == nil { + return keyValue, nil + } + + err := json.Unmarshal(data, &arr) + + return arr, err + case string: + raw := []byte(data) + if err := json.Unmarshal(raw, &keyValue); err == nil { + return keyValue, nil + } + err := json.Unmarshal(raw, &arr) + + return arr, err + } + + return nil, fmt.Errorf("cannot convert %v to JSON structure", in) +} diff --git a/validation/jsonpath_test.go b/validation/jsonpath_test.go new file mode 100644 index 0000000..19b2962 --- /dev/null +++ b/validation/jsonpath_test.go @@ -0,0 +1,85 @@ +package validation_test + +import ( + "testing" + + "github.com/baez90/nurse/validation" +) + +type jsonPathValidator_EqualsTestCase[V validation.Value] struct { + testName string + expected V + jsonPath string + json string + want bool +} + +func (tt jsonPathValidator_EqualsTestCase[V]) name() string { + return tt.testName +} + +func (tt jsonPathValidator_EqualsTestCase[V]) run(t *testing.T) { + t.Parallel() + t.Helper() + validator, err := validation.JSONPathValidatorFor(tt.jsonPath, tt.expected) + if err != nil { + t.Fatalf("JSONPathValidatorFor() err = %v", err) + } + + if validator.Equals(tt.json) != tt.want { + t.Errorf("Failed to equal value in %s to %v", tt.json, tt.expected) + } +} + +func TestJSONPathValidator_Equals(t *testing.T) { + t.Parallel() + tests := []testCase{ + jsonPathValidator_EqualsTestCase[string]{ + testName: "Simple object navigation", + expected: "hello", + jsonPath: "$.greeting", + json: `{"greeting": "hello"}`, + want: true, + }, + jsonPathValidator_EqualsTestCase[string]{ + testName: "Simple object navigation - number as string", + expected: "42", + jsonPath: "$.number", + json: `{"number": 42}`, + want: true, + }, + jsonPathValidator_EqualsTestCase[string]{ + testName: "Simple array navigation", + expected: "world", + jsonPath: "$[1]", + json: `["hello", "world"]`, + want: true, + }, + jsonPathValidator_EqualsTestCase[int]{ + testName: "Simple array navigation - string to int", + expected: 37, + jsonPath: "$[1]", + json: `["13", "37"]`, + want: true, + }, + jsonPathValidator_EqualsTestCase[int]{ + testName: "Simple array navigation - string to int - wrong value", + expected: 42, + jsonPath: "$[1]", + json: `["13", "37"]`, + want: false, + }, + jsonPathValidator_EqualsTestCase[string]{ + testName: "Simple array navigation - int to string", + expected: "37", + jsonPath: "$[1]", + json: `[13, 37]`, + want: true, + }, + } + //nolint:paralleltest + for _, tt := range tests { + tt := tt + t.Run(tt.name(), tt.run) + } +} diff --git a/validation/jsonval.go b/validation/jsonval.go new file mode 100644 index 0000000..e3595e9 --- /dev/null +++ b/validation/jsonval.go @@ -0,0 +1,52 @@ +package validation + +import ( + "fmt" +) + +var _ ValueComparator = (*JSONValueComparator)(nil) + +type Value interface { + float32 | float64 | int | int8 | int16 | int32 | int64 | uint | uint8 | uint16 | uint32 | uint64 | string | []byte +} + +func JSONValueComparatorFor[T Value](want T) (*JSONValueComparator, error) { + ti := any(want) + switch in := ti.(type) { + case float32, float64: + return &JSONValueComparator{ + Comparator: FloatComparator(ToFloat64(in)), + }, nil + case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64: + return &JSONValueComparator{ + Comparator: GenericComparator[int]{ + Want: ToInt(in), + Parser: ParseJSONInt, + }, + }, nil + case string: + return &JSONValueComparator{ + Comparator: GenericComparator[string]{ + Want: in, + Parser: ParseJSONString, + }, + }, nil + case []byte: + return &JSONValueComparator{ + Comparator: GenericComparator[string]{ + Want: string(in), + Parser: ParseJSONString, + }, + }, nil + default: + return nil, fmt.Errorf("no matching type detected for %v", want) + } +} + +type JSONValueComparator struct { + Comparator ValueComparator +} + +func (j JSONValueComparator) Equals(got any) bool { + return j.Comparator.Equals(got) +} diff --git a/validation/jsonval_test.go b/validation/jsonval_test.go new file mode 100644 index 0000000..5872cfb --- /dev/null +++ b/validation/jsonval_test.go @@ -0,0 +1,180 @@ +package validation_test + +import ( + "testing" + + "github.com/baez90/nurse/validation" +) + +type testCase interface { + run(t *testing.T) + name() string +} + +type jsonValueComparator_EqualsTestCase[V validation.Value] struct { + testName string + expected V + got any + want bool +} + +func (tt jsonValueComparator_EqualsTestCase[V]) run(t *testing.T) { + t.Parallel() + t.Helper() + comparator, err := validation.JSONValueComparatorFor(tt.expected) + if err != nil { + t.Fatalf("validation.JSONValueComparatorFor() err = %v", err) + } + + if got := comparator.Equals(tt.got); got != tt.want { + t.Errorf("Equals() = %v, want %v", got, tt.want) + } +} + +func (tt jsonValueComparator_EqualsTestCase[V]) name() string { + return tt.testName +} + +func TestJSONValueComparator_Equals(t *testing.T) { + t.Parallel() + tests := []testCase{ + jsonValueComparator_EqualsTestCase[int]{ + testName: "Test int equality", + expected: 42, + got: 42, + want: true, + }, + jsonValueComparator_EqualsTestCase[int]{ + testName: "Test int equality - wrong value", + expected: 42, + got: 43, + want: false, + }, + jsonValueComparator_EqualsTestCase[int]{ + testName: "Test int equality - string value", + expected: 42, + got: "42", + want: true, + }, + jsonValueComparator_EqualsTestCase[int]{ + testName: "Test int equality - []byte value", + expected: 42, + got: []byte("42"), + want: true, + }, + jsonValueComparator_EqualsTestCase[int]{ + testName: "Test int equality - float value", + expected: 42, + got: 42.0, + want: true, + }, + jsonValueComparator_EqualsTestCase[int8]{ + testName: "Test int8 equality", + expected: 42, + got: 42, + want: true, + }, + jsonValueComparator_EqualsTestCase[int8]{ + testName: "Test int8 equality - wrong value", + expected: 42, + got: 43, + want: false, + }, + jsonValueComparator_EqualsTestCase[int8]{ + testName: "Test int8 equality - int16 value", + expected: 42, + got: int16(42), + want: true, + }, + jsonValueComparator_EqualsTestCase[int8]{ + testName: "Test int8 equality - uint16 value", + expected: 42, + got: uint16(42), + want: true, + }, + jsonValueComparator_EqualsTestCase[float32]{ + testName: "Test float32 equality - float value", + expected: 42.0, + got: 42.0, + want: true, + }, + jsonValueComparator_EqualsTestCase[float32]{ + testName: "Test float32 equality - float value", + expected: 42.0, + got: float64(42.0), + want: true, + }, + jsonValueComparator_EqualsTestCase[float64]{ + testName: "Test float64 equality - float value", + expected: 42.0, + got: 42.0, + want: true, + }, + jsonValueComparator_EqualsTestCase[float64]{ + testName: "Test float64 equality - int value", + expected: 42.0, + got: 42, + want: true, + }, + jsonValueComparator_EqualsTestCase[float64]{ + testName: "Test float64 equality - []byte value", + expected: 42.0, + got: []byte("42"), + want: true, + }, + jsonValueComparator_EqualsTestCase[float64]{ + testName: "Test float64 equality - float32 value", + expected: 42.0, + got: float32(42.0), + want: true, + }, + jsonValueComparator_EqualsTestCase[float64]{ + testName: "Test float64 equality - string value", + expected: 42.0, + got: "42.0", + want: true, + }, + jsonValueComparator_EqualsTestCase[float64]{ + testName: "Test float64 equality - string value without dot", + expected: 42.0, + got: "42", + want: true, + }, + jsonValueComparator_EqualsTestCase[string]{ + testName: "Test string equality", + expected: "hello", + got: "hello", + want: true, + }, + jsonValueComparator_EqualsTestCase[string]{ + testName: "Test string equality - []byte value", + expected: "hello", + got: []byte("hello"), + want: true, + }, + jsonValueComparator_EqualsTestCase[string]{ + testName: "Test string equality - int value", + expected: "1337", + got: 1337, + want: true, + }, + jsonValueComparator_EqualsTestCase[string]{ + testName: "Test string equality - float value", + expected: "13.37", + got: 13.37, + want: true, + }, + jsonValueComparator_EqualsTestCase[string]{ + testName: "Test string equality - wrong case", + expected: "hello", + got: "HELLO", + want: false, + }, + } + + //nolint:paralleltest + for _, tt := range tests { + tt := tt + t.Run(tt.name(), tt.run) + } +}