diff --git a/.gitignore b/.gitignore index f52cf6c..f81058c 100644 --- a/.gitignore +++ b/.gitignore @@ -15,4 +15,5 @@ # vendor/ dist/ -.idea/ \ No newline at end of file +.idea/ +cue.mod/ \ No newline at end of file diff --git a/check/api.go b/check/api.go index edbc781..91baecc 100644 --- a/check/api.go +++ b/check/api.go @@ -4,12 +4,26 @@ import ( "context" "errors" + "github.com/baez90/nurse/config" "github.com/baez90/nurse/grammar" ) -var ErrNoSuchCheck = errors.New("no such check") +var ( + ErrNoSuchCheck = errors.New("no such check") + ErrConflictingCheck = errors.New("check with same name already registered") +) -type SystemChecker interface { - grammar.CheckUnmarshaler - Execute(ctx context.Context) error -} +type ( + Unmarshaler interface { + UnmarshalCheck(c grammar.Check, lookup config.ServerLookup) error + } + + SystemChecker interface { + Unmarshaler + Execute(ctx context.Context) error + } + + CallUnmarshaler interface { + UnmarshalCall(c grammar.Call) error + } +) diff --git a/check/modules.go b/check/modules.go new file mode 100644 index 0000000..ed34ee3 --- /dev/null +++ b/check/modules.go @@ -0,0 +1,91 @@ +package check + +import ( + "fmt" + "strings" + "sync" + + "github.com/baez90/nurse/config" + "github.com/baez90/nurse/grammar" +) + +type ( + ModuleOption interface { + Apply(m *Module) error + } + + ModuleOptionFunc func(m *Module) error + + Factory interface { + New() SystemChecker + } + + FactoryFunc func() SystemChecker +) + +func (f ModuleOptionFunc) Apply(m *Module) error { + return f(m) +} + +func (f FactoryFunc) New() SystemChecker { + return f() +} + +func WithCheck(name string, factory Factory) ModuleOption { + return ModuleOptionFunc(func(m *Module) error { + return m.Register(name, factory) + }) +} + +func NewModule(opts ...ModuleOption) (*Module, error) { + m := &Module{ + knownChecks: make(map[string]Factory), + } + + for i := range opts { + if err := opts[i].Apply(m); err != nil { + return nil, err + } + } + + return m, nil +} + +type Module struct { + lock sync.RWMutex + knownChecks map[string]Factory +} + +func (m *Module) Lookup(c grammar.Check, srvLookup config.ServerLookup) (SystemChecker, error) { + m.lock.RLock() + defer m.lock.RUnlock() + var ( + factory Factory + ok bool + ) + if factory, ok = m.knownChecks[strings.ToLower(c.Initiator.Name)]; !ok { + return nil, fmt.Errorf("%w: %s", ErrNoSuchCheck, c.Initiator.Name) + } + + chk := factory.New() + if err := chk.UnmarshalCheck(c, srvLookup); err != nil { + return nil, err + } + + return chk, nil +} + +func (m *Module) Register(name string, factory Factory) error { + m.lock.Lock() + defer m.lock.Unlock() + + name = strings.ToLower(name) + + if _, ok := m.knownChecks[name]; ok { + return fmt.Errorf("%w: %s", ErrConflictingCheck, name) + } + + m.knownChecks[name] = factory + + return nil +} diff --git a/config/app_config.go b/config/app_config.go new file mode 100644 index 0000000..33edee1 --- /dev/null +++ b/config/app_config.go @@ -0,0 +1,91 @@ +package config + +import ( + "io" + "os" + "time" +) + +type ( + Option interface { + Extend(n Nurse) (Nurse, error) + } + OptionFunc func(n Nurse) (Nurse, error) +) + +func (f OptionFunc) Extend(n Nurse) (Nurse, error) { + return f(n) +} + +func New(opts ...Option) (*Nurse, error) { + var ( + inst Nurse + err error + ) + + for i := range opts { + if inst, err = opts[i].Extend(inst); err != nil { + return nil, err + } + } + + return &inst, nil +} + +type Nurse struct { + Servers map[string]Server + Endpoints map[Route]EndpointSpec + CheckTimeout time.Duration +} + +// Merge merges the current Nurse instance with another one +// giving the current instance precedence means no set value is overwritten +func (n Nurse) Merge(other Nurse) Nurse { + if n.CheckTimeout == 0 { + n.CheckTimeout = other.CheckTimeout + } + + for name, srv := range other.Servers { + if _, ok := n.Servers[name]; !ok { + n.Servers[name] = srv + } + } + + return n +} + +func (n *Nurse) Unmarshal(reader io.ReadSeeker) error { + providers := []func(io.Reader) configDecoder{ + newYAMLDecoder, + newJSONDecoder, + } + + for i := range providers { + if err := n.tryUnmarshal(reader, providers[i]); err == nil { + return nil + } + } + + return nil +} + +func (n *Nurse) ReadFromFile(configFilePath string) error { + if file, err := os.Open(configFilePath); err != nil { + return err + } else { + defer func() { + _ = file.Close() + }() + + return n.Unmarshal(file) + } +} + +func (n *Nurse) tryUnmarshal(seeker io.ReadSeeker, prov func(reader io.Reader) configDecoder) error { + if _, err := seeker.Seek(0, 0); err != nil { + return err + } + + decoder := prov(seeker) + return decoder.DecodeConfig(n) +} diff --git a/config/decoder.go b/config/decoder.go new file mode 100644 index 0000000..3f29226 --- /dev/null +++ b/config/decoder.go @@ -0,0 +1,41 @@ +package config + +import ( + "encoding/json" + "io" + + "gopkg.in/yaml.v3" +) + +var ( + _ configDecoder = (*jsonDecoder)(nil) + _ configDecoder = (*yamlDecoder)(nil) +) + +type configDecoder interface { + DecodeConfig(into *Nurse) error +} + +func newJSONDecoder(r io.Reader) configDecoder { + return &jsonDecoder{decoder: json.NewDecoder(r)} +} + +type jsonDecoder struct { + decoder *json.Decoder +} + +func (j jsonDecoder) DecodeConfig(into *Nurse) error { + return j.decoder.Decode(into) +} + +func newYAMLDecoder(r io.Reader) configDecoder { + return &yamlDecoder{decoder: yaml.NewDecoder(r)} +} + +type yamlDecoder struct { + decoder *yaml.Decoder +} + +func (y yamlDecoder) DecodeConfig(into *Nurse) error { + return y.decoder.Decode(into) +} diff --git a/config/endpoints.go b/config/endpoints.go new file mode 100644 index 0000000..75549e5 --- /dev/null +++ b/config/endpoints.go @@ -0,0 +1,42 @@ +package config + +import ( + "encoding" + "fmt" + "path" + "strings" + "time" + + "github.com/baez90/nurse/grammar" +) + +var _ encoding.TextUnmarshaler = (*EndpointSpec)(nil) + +type Route string + +func (r Route) String() string { + val := string(r) + val = strings.Trim(val, "/") + + return path.Clean(fmt.Sprintf("/%s", val)) +} + +type EndpointSpec struct { + CheckTimeout time.Duration + Checks []grammar.Check +} + +func (e *EndpointSpec) UnmarshalText(text []byte) error { + parser, err := grammar.NewParser[grammar.Script]() + if err != nil { + return err + } + + script, err := parser.Parse(string(text)) + if err != nil { + return err + } + + e.Checks = script.Checks + return nil +} diff --git a/config/env.go b/config/env.go new file mode 100644 index 0000000..340a4fa --- /dev/null +++ b/config/env.go @@ -0,0 +1,61 @@ +package config + +import ( + "os" + "path" + . "strings" +) + +const ( + ServerKeyPrefix = "NURSE_SERVER" + EndpointKeyPrefix = "NURSE_ENDPOINT" +) + +func ServersFromEnv() (map[string]Server, error) { + servers := make(map[string]Server) + for _, kv := range os.Environ() { + key, value, valid := Cut(kv, "=") + if !valid { + continue + } + + if !HasPrefix(key, ServerKeyPrefix) { + continue + } + + serverName := ToLower(Trim(Replace(key, ServerKeyPrefix, "", -1), "_")) + srv := Server{} + if err := srv.UnmarshalURL(value); err != nil { + return nil, err + } + + servers[serverName] = srv + } + + return servers, nil +} + +func EndpointsFromEnv() (map[Route]EndpointSpec, error) { + endpoints := make(map[Route]EndpointSpec) + + for _, kv := range os.Environ() { + key, value, valid := Cut(kv, "=") + if !valid { + continue + } + + if !HasPrefix(key, EndpointKeyPrefix) { + continue + } + + endpointRoute := path.Join(Split(ToLower(Trim(Replace(key, EndpointKeyPrefix, "", -1), "_")), "_")...) + spec := EndpointSpec{} + if err := spec.UnmarshalText([]byte(value)); err != nil { + return nil, err + } + + endpoints[Route(endpointRoute)] = spec + } + + return endpoints, nil +} diff --git a/config/env_test.go b/config/env_test.go new file mode 100644 index 0000000..1deb3d4 --- /dev/null +++ b/config/env_test.go @@ -0,0 +1,162 @@ +package config_test + +import ( + "fmt" + "testing" + + "github.com/maxatome/go-testdeep/td" + + "github.com/baez90/nurse/config" +) + +//nolint:paralleltest // not possible with env setup +func TestServersFromEnv(t *testing.T) { + tests := []struct { + name string + env map[string]string + want any + wantErr bool + }{ + { + name: "Empty env", + want: td.NotNil(), + }, + { + name: "Single server", + env: map[string]string{ + fmt.Sprintf("%s_REDIS_1", config.ServerKeyPrefix): "redis://localhost:6379", + }, + want: td.Map(make(map[string]config.Server), td.MapEntries{ + "redis_1": config.Server{ + Type: config.ServerTypeRedis, + Hosts: []string{"localhost:6379"}, + Args: make(map[string]any), + }, + }), + wantErr: false, + }, + { + name: "Multiple servers", + env: map[string]string{ + fmt.Sprintf("%s_REDIS_1", config.ServerKeyPrefix): "redis://localhost:6379", + fmt.Sprintf("%s_REDIS_2", config.ServerKeyPrefix): "redis://redis:6379", + }, + want: td.Map(make(map[string]config.Server), td.MapEntries{ + "redis_1": config.Server{ + Type: config.ServerTypeRedis, + Hosts: []string{"localhost:6379"}, + Args: make(map[string]any), + }, + "redis_2": config.Server{ + Type: config.ServerTypeRedis, + Hosts: []string{"redis:6379"}, + Args: make(map[string]any), + }, + }), + wantErr: false, + }, + } + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + if tt.env != nil { + for k, v := range tt.env { + t.Setenv(k, v) + } + } + + got, err := config.ServersFromEnv() + if (err != nil) != tt.wantErr { + t.Errorf("ServersFromEnv() error = %v, wantErr %v", err, tt.wantErr) + return + } + + td.Cmp(t, got, tt.want) + }) + } +} + +//nolint:paralleltest // not possible with env setup +func TestEndpointsFromEnv(t *testing.T) { + tests := []struct { + name string + env map[string]string + want any + wantErr bool + }{ + { + name: "Empty env", + want: td.NotNil(), + }, + { + name: "Single endpoint", + env: map[string]string{ + fmt.Sprintf("%s_readiness", config.EndpointKeyPrefix): `redis.PING("local_redis")`, + }, + want: td.Map(make(map[config.Route]config.EndpointSpec), td.MapEntries{ + config.Route("readiness"): td.Struct(config.EndpointSpec{}, td.StructFields{ + "Checks": td.Len(1), + }), + }), + wantErr: false, + }, + { + name: "Single endpoint - multiple checks", + env: map[string]string{ + fmt.Sprintf("%s_readiness", config.EndpointKeyPrefix): `redis.PING("local_redis");redis.GET("local_redis", "serving") => String("ok")`, + }, + want: td.Map(make(map[config.Route]config.EndpointSpec), td.MapEntries{ + config.Route("readiness"): td.Struct(config.EndpointSpec{}, td.StructFields{ + "Checks": td.Len(2), + }), + }), + wantErr: false, + }, + { + name: "Single endpoint - sub-route", + env: map[string]string{ + fmt.Sprintf("%s_READINESS_REDIS", config.EndpointKeyPrefix): `redis.PING("local_redis");redis.GET("local_redis", "serving") => String("ok")`, + }, + want: td.Map(make(map[config.Route]config.EndpointSpec), td.MapEntries{ + config.Route("readiness/redis"): td.Struct(config.EndpointSpec{}, td.StructFields{ + "Checks": td.Len(2), + }), + }), + wantErr: false, + }, + { + name: "Multiple endpoints", + env: map[string]string{ + fmt.Sprintf("%s_readiness", config.EndpointKeyPrefix): `redis.PING("local_redis")`, + fmt.Sprintf("%s_liveness", config.EndpointKeyPrefix): `redis.GET("local_redis", "serving") => String("ok")`, + }, + want: td.Map(make(map[config.Route]config.EndpointSpec), td.MapEntries{ + config.Route("readiness"): td.Struct(config.EndpointSpec{}, td.StructFields{ + "Checks": td.Len(1), + }), + config.Route("liveness"): td.Struct(config.EndpointSpec{}, td.StructFields{ + "Checks": td.Len(1), + }), + }), + wantErr: false, + }, + } + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + if tt.env != nil { + for k, v := range tt.env { + t.Setenv(k, v) + } + } + + got, err := config.EndpointsFromEnv() + if (err != nil) != tt.wantErr { + t.Errorf("EndpointsFromEnv() error = %v, wantErr %v", err, tt.wantErr) + return + } + + td.Cmp(t, got, tt.want) + }) + } +} diff --git a/config/flags.go b/config/flags.go new file mode 100644 index 0000000..e693d60 --- /dev/null +++ b/config/flags.go @@ -0,0 +1,39 @@ +package config + +import ( + "flag" + "os" + "time" +) + +const defaultCheckTimeout = 500 * time.Millisecond + +func ConfigureFlags(cfg *Nurse) *flag.FlagSet { + set := flag.NewFlagSet("nurse", flag.ContinueOnError) + + set.DurationVar( + &cfg.CheckTimeout, + "check-timeout", + LookupEnvOr("NURSE_CHECK_TIMEOUT", defaultCheckTimeout, time.ParseDuration), + "Timeout when running checks", + ) + + return set +} + +func LookupEnvOr[T any](envKey string, fallback T, parse func(envVal string) (T, error)) T { + envVal := os.Getenv(envKey) + if envVal == "" { + return fallback + } + + if parsed, err := parse(envVal); err != nil { + return fallback + } else { + return parsed + } +} + +func Identity[T any](in T) (T, error) { + return in, nil +} diff --git a/config/lookup.go b/config/lookup.go index dc33f82..b0e6bf9 100644 --- a/config/lookup.go +++ b/config/lookup.go @@ -10,17 +10,26 @@ import ( var ( ErrServerNameAlreadyRegistered = errors.New("server name is already registered") ErrNoSuchServer = errors.New("no known server with given name") - DefaultLookup = &ServerLookup{ - servers: make(map[string]Server), - } + + _ ServerLookup = (*ServerRegister)(nil) ) -type ServerLookup struct { +type ServerLookup interface { + Lookup(name string) (*Server, error) +} + +func NewServerRegister() *ServerRegister { + return &ServerRegister{ + servers: make(map[string]Server), + } +} + +type ServerRegister struct { lock sync.RWMutex servers map[string]Server } -func (l *ServerLookup) Register(name string, srv Server) error { +func (l *ServerRegister) Register(name string, srv Server) error { l.lock.Lock() defer l.lock.Unlock() @@ -34,7 +43,7 @@ func (l *ServerLookup) Register(name string, srv Server) error { return nil } -func (l *ServerLookup) Lookup(name string) (*Server, error) { +func (l *ServerRegister) Lookup(name string) (*Server, error) { l.lock.RLock() defer l.lock.RUnlock() diff --git a/config/options.go b/config/options.go new file mode 100644 index 0000000..4e86af3 --- /dev/null +++ b/config/options.go @@ -0,0 +1,119 @@ +package config + +import ( + "os" + "path/filepath" + + "go.uber.org/zap" +) + +func WithServersFromEnv() Option { + return OptionFunc(func(n Nurse) (Nurse, error) { + envServers, err := ServersFromEnv() + if err != nil { + return Nurse{}, err + } + + if n.Servers == nil || len(n.Servers) == 0 { + n.Servers = envServers + return n, nil + } + + for name, srv := range envServers { + if _, ok := n.Servers[name]; !ok { + n.Servers[name] = srv + } + } + + return n, nil + }) +} + +func WithEndpointsFromEnv() Option { + return OptionFunc(func(n Nurse) (Nurse, error) { + envEndpoints, err := EndpointsFromEnv() + if err != nil { + return Nurse{}, err + } + + if n.Endpoints == nil || len(n.Endpoints) == 0 { + n.Endpoints = envEndpoints + return n, nil + } + + for route, spec := range envEndpoints { + if _, ok := n.Endpoints[route]; !ok { + n.Endpoints[route] = spec + } + } + + return n, nil + }) +} + +func WithValuesFrom(other Nurse) Option { + return OptionFunc(func(n Nurse) (Nurse, error) { + return n.Merge(other), nil + }) +} + +func WithConfigFile(configFilePath string) Option { + logger := zap.L() + return OptionFunc(func(n Nurse) (Nurse, error) { + var out Nurse + if configFilePath != "" { + logger.Debug("Attempt to load config file") + if err := out.ReadFromFile(configFilePath); err == nil { + return out, nil + } else { + logger.Warn( + "Failed to load config file", + zap.String("config_file_path", configFilePath), + zap.Error(err), + ) + } + } + + if workingDir, err := os.Getwd(); err == nil { + configFilePath = filepath.Join(workingDir, "nurse.yaml") + logger.Debug("Attempt to load config file from current working directory") + if err = out.ReadFromFile(configFilePath); err == nil { + return out, nil + } else { + logger.Warn( + "Failed to load config file", + zap.String("config_file_path", configFilePath), + zap.Error(err), + ) + } + } + + if home, err := os.UserHomeDir(); err == nil { + configFilePath = filepath.Join(home, ".nurse.yaml") + logger.Debug("Attempt to load config file from user home directory") + if err = out.ReadFromFile(configFilePath); err == nil { + return out, nil + } else { + logger.Warn( + "Failed to load config file", + zap.String("config_file_path", configFilePath), + zap.Error(err), + ) + } + } + + configFilePath = filepath.Join("", "etc", "nurse", "config.yaml") + logger.Debug("Attempt to load config file from global config directory") + if err := out.ReadFromFile(configFilePath); err == nil { + return out, nil + } else { + logger.Warn( + "Failed to load config file", + zap.String("config_file_path", configFilePath), + zap.Error(err), + ) + } + + return Nurse{}, nil + }) +} diff --git a/config/server.go b/config/server.go index 226e7ff..7051299 100644 --- a/config/server.go +++ b/config/server.go @@ -2,11 +2,13 @@ package config import ( "encoding/json" + "fmt" "net/url" "regexp" "strings" - "github.com/baez90/nurse/internal/values" + "golang.org/x/exp/maps" + "gopkg.in/yaml.v3" ) type ServerType uint @@ -14,48 +16,42 @@ type ServerType uint const ( ServerTypeUnspecified ServerType = iota ServerTypeRedis + + //nolint:lll // cannot break regex + // language=regexp + urlSchemeRegex = `^(?P\w+)://((?P[\w-.]+)(:(?P.*))?@)?(\{(?P(.+:\d{1,5})(,(.+:\d{1,5}))*)}|(?P.+:\d{1,5}))(?P/.*$)?` ) -var hostsRegexp = regexp.MustCompile(`^{(.+:\d{1,5})(,(.+:\d{1,5}))*}|(.+:\d{1,5})$`) - -func ParseFromURL(url *url.URL) (*Server, error) { - srv := &Server{ - Type: SchemeToServerType(url.Scheme), - Hosts: hostsRegexp.FindAllString(url.Host, -1), +func (t ServerType) Scheme() string { + switch t { + case ServerTypeRedis: + return "redis" + case ServerTypeUnspecified: + fallthrough + default: + return "unknown" } - - if user := url.User; user != nil { - srv.Credentials = &Credentials{ - Username: user.Username(), - } - if pw, ok := user.Password(); ok { - srv.Credentials.Password = values.StringP(pw) - } - } - - srv.Path = strings.Split(strings.Trim(url.EscapedPath(), "/"), "/") - - q := url.Query() - qm := map[string][]string(q) - srv.Args = make(map[string]any, len(qm)) - - for k := range qm { - var val any - if err := json.Unmarshal([]byte(q.Get(k)), &val); err != nil { - return nil, err - } else { - srv.Args[k] = val - } - } - - return srv, nil } +var hostsRegexp = regexp.MustCompile(urlSchemeRegex) + type Credentials struct { Username string Password *string } +var ( + _ json.Unmarshaler = (*Server)(nil) + _ yaml.Unmarshaler = (*Server)(nil) +) + +type marshaledServer struct { + Type ServerType + URL string + Hosts []string + Args map[string]any +} + type Server struct { Type ServerType Credentials *Credentials @@ -63,3 +59,135 @@ type Server struct { Path []string Args map[string]any } + +func (s *Server) UnmarshalYAML(value *yaml.Node) error { + s.Args = make(map[string]any) + + tmp := new(marshaledServer) + + if err := value.Decode(tmp); err != nil { + return err + } + + return s.mergedMarshaledServer(*tmp) +} + +func (s *Server) UnmarshalJSON(bytes []byte) error { + s.Args = make(map[string]any) + + tmp := new(marshaledServer) + + if err := json.Unmarshal(bytes, tmp); err != nil { + return err + } + + return s.mergedMarshaledServer(*tmp) +} + +func (s *Server) UnmarshalURL(rawUrl string) error { + rawPath, username, password, err := s.extractBaseProperties(rawUrl) + if err != nil { + return err + } + + if username != "" { + s.Credentials = &Credentials{ + Username: username, + } + if password != "" { + s.Credentials.Password = &password + } + } + + if rawPath != "" { + parsedUrl, err := url.Parse(fmt.Sprintf("%s://%s%s", s.Type.Scheme(), s.Hosts[0], rawPath)) + if err != nil { + return err + } + if err = s.unmarshalPath(parsedUrl); err != nil { + return err + } + } else { + s.Args = make(map[string]any) + } + + return nil +} + +func (s *Server) extractBaseProperties(rawUrl string) (rawPath, username, password string, err error) { + allMatches := hostsRegexp.FindAllStringSubmatch(rawUrl, -1) + if matchLen := len(allMatches); matchLen != 1 { + return "", "", "", fmt.Errorf("ambiguous server match: %d", matchLen) + } + + match := allMatches[0] + + for i, name := range hostsRegexp.SubexpNames() { + if i == 0 { + continue + } + + switch name { + case "scheme": + s.Type = SchemeToServerType(match[i]) + case "host": + singleHost := match[i] + if singleHost != "" { + s.Hosts = []string{singleHost} + } + case "hosts": + hosts := strings.Split(match[i], ",") + for _, host := range hosts { + s.Hosts = append(s.Hosts, strings.TrimSpace(host)) + } + case "path": + rawPath = match[i] + case "username": + username = match[i] + case "password": + password = match[i] + } + } + + return rawPath, username, password, nil +} + +func (s *Server) unmarshalPath(u *url.URL) error { + s.Path = strings.Split(strings.Trim(u.EscapedPath(), "/"), "/") + + q := u.Query() + qm := map[string][]string(q) + s.Args = make(map[string]any, len(qm)) + + for k := range qm { + var val any + if err := json.Unmarshal([]byte(q.Get(k)), &val); err != nil { + return err + } else { + s.Args[k] = val + } + } + + return nil +} + +func (s *Server) mergedMarshaledServer(srv marshaledServer) error { + if srv.URL != "" { + if err := s.UnmarshalURL(srv.URL); err != nil { + return err + } + + if srv.Args != nil { + maps.Copy(s.Args, srv.Args) + } + return nil + } + + s.Type = srv.Type + s.Hosts = srv.Hosts + if srv.Args != nil { + s.Args = srv.Args + } + + return nil +} diff --git a/config/server_test.go b/config/server_test.go new file mode 100644 index 0000000..695fa2f --- /dev/null +++ b/config/server_test.go @@ -0,0 +1,144 @@ +package config_test + +import ( + "testing" + + "github.com/maxatome/go-testdeep/td" + + "github.com/baez90/nurse/config" + "github.com/baez90/nurse/internal/values" +) + +func TestParseFromURL(t *testing.T) { + t.Parallel() + tests := []struct { + name string + url string + want any + wantErr bool + }{ + { + name: "Single Redis server", + url: "redis://localhost:6379/0", + want: &config.Server{ + Type: config.ServerTypeRedis, + Hosts: []string{"localhost:6379"}, + Path: []string{"0"}, + Args: make(map[string]any), + }, + }, + { + name: "Single Redis server with args", + url: "redis://localhost:6379/0?MaxRetries=3", + want: &config.Server{ + Type: config.ServerTypeRedis, + Hosts: []string{"localhost:6379"}, + Path: []string{"0"}, + Args: map[string]any{ + "MaxRetries": 3.0, + }, + }, + }, + { + name: "Single Redis server with user", + url: "redis://user@localhost:6379/0", + want: &config.Server{ + Type: config.ServerTypeRedis, + Hosts: []string{"localhost:6379"}, + Path: []string{"0"}, + Args: make(map[string]any), + Credentials: &config.Credentials{ + Username: "user", + }, + }, + }, + { + name: "Single Redis server with credentials", + url: "redis://user:p4$$w0rd@localhost:6379/0", + want: &config.Server{ + Type: config.ServerTypeRedis, + Hosts: []string{"localhost:6379"}, + Path: []string{"0"}, + Args: make(map[string]any), + Credentials: &config.Credentials{ + Username: "user", + Password: values.StringP("p4$$w0rd"), + }, + }, + }, + { + name: "Multiple Redis servers", + url: "redis://{localhost:6379,localhost:6380}/0", + want: &config.Server{ + Type: config.ServerTypeRedis, + Hosts: []string{ + "localhost:6379", + "localhost:6380", + }, + Path: []string{"0"}, + Args: make(map[string]any), + }, + }, + { + name: "Multiple Redis servers with args", + url: "redis://{localhost:6379,localhost:6380}/0?MaxRetries=3", + want: &config.Server{ + Type: config.ServerTypeRedis, + Hosts: []string{ + "localhost:6379", + "localhost:6380", + }, + Path: []string{"0"}, + Args: map[string]any{ + "MaxRetries": 3.0, + }, + }, + }, + { + name: "Multiple Redis servers with user", + url: "redis://user@{localhost:6379,localhost:6380}/0", + want: &config.Server{ + Type: config.ServerTypeRedis, + Hosts: []string{ + "localhost:6379", + "localhost:6380", + }, + Credentials: &config.Credentials{ + Username: "user", + }, + Path: []string{"0"}, + Args: make(map[string]any), + }, + }, + { + name: "Multiple Redis servers with credentials", + url: "redis://user:p4$$w0rd@{localhost:6379,localhost:6380}/0", + want: &config.Server{ + Type: config.ServerTypeRedis, + Hosts: []string{ + "localhost:6379", + "localhost:6380", + }, + Credentials: &config.Credentials{ + Username: "user", + Password: values.StringP("p4$$w0rd"), + }, + Path: []string{"0"}, + Args: make(map[string]any), + }, + }, + } + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + got := new(config.Server) + + if err := got.UnmarshalURL(tt.url); (err != nil) != tt.wantErr { + t.Errorf("ParseFromURL() error = %v, wantErr %v", err, tt.wantErr) + return + } + td.Cmp(t, got, tt.want) + }) + } +} diff --git a/go.mod b/go.mod index 013943c..b6eab7b 100644 --- a/go.mod +++ b/go.mod @@ -5,9 +5,12 @@ go 1.18 require ( github.com/alecthomas/participle/v2 v2.0.0-alpha8 github.com/go-redis/redis/v8 v8.11.5 + github.com/google/uuid v1.3.0 github.com/maxatome/go-testdeep v1.11.0 - github.com/mitchellh/mapstructure v1.4.3 + github.com/mitchellh/mapstructure v1.5.0 github.com/testcontainers/testcontainers-go v0.13.0 + go.uber.org/zap v1.10.0 + gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b ) require ( @@ -27,7 +30,6 @@ require ( github.com/gogo/protobuf v1.3.2 // indirect github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect github.com/golang/protobuf v1.5.2 // indirect - github.com/google/uuid v1.3.0 // indirect github.com/magiconair/properties v1.8.6 // indirect github.com/moby/sys/mount v0.3.2 // indirect github.com/moby/sys/mountinfo v0.6.1 // indirect @@ -39,10 +41,12 @@ require ( github.com/pkg/errors v0.9.1 // indirect github.com/sirupsen/logrus v1.8.1 // indirect go.opencensus.io v0.23.0 // indirect + go.uber.org/atomic v1.4.0 // indirect + go.uber.org/multierr v1.1.0 // indirect + golang.org/x/exp v0.0.0-20220428152302-39d4317da171 // indirect golang.org/x/net v0.0.0-20220420153159-1850ba15e1be // indirect golang.org/x/sys v0.0.0-20220412211240-33da011f77ad // indirect google.golang.org/genproto v0.0.0-20220414192740-2d67ff6cf2b4 // indirect google.golang.org/grpc v1.45.0 // indirect google.golang.org/protobuf v1.28.0 // indirect - gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b // indirect ) diff --git a/go.sum b/go.sum index 9366bb0..aecc10c 100644 --- a/go.sum +++ b/go.sum @@ -500,8 +500,8 @@ github.com/miekg/pkcs11 v1.0.3/go.mod h1:XsNlhZGX73bx86s2hdc/FuaLm2CPZJemRLMA+WT github.com/mistifyio/go-zfs v2.1.2-0.20190413222219-f784269be439+incompatible/go.mod h1:8AuVvqP/mXw1px98n46wfvcGfQ4ci2FwoAjKYxuo3Z4= github.com/mitchellh/go-homedir v1.1.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0= github.com/mitchellh/mapstructure v1.1.2/go.mod h1:FVVH3fgwuzCH5S8UJGiWEs2h04kUh9fWfEaFds41c1Y= -github.com/mitchellh/mapstructure v1.4.3 h1:OVowDSCllw/YjdLkam3/sm7wEtOy59d8ndGgCcyj8cs= -github.com/mitchellh/mapstructure v1.4.3/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= +github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY= +github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= github.com/mitchellh/osext v0.0.0-20151018003038-5e2d6d41470f/go.mod h1:OkQIRizQZAeMln+1tSwduZz7+Af5oFlKirV/MSYes2A= github.com/moby/locker v1.0.1/go.mod h1:S7SDdo5zpBK84bzzVlKr2V0hz+7x9hWbYC/kq7oQppc= github.com/moby/sys/mount v0.2.0/go.mod h1:aAivFE2LB3W4bACsUXChRHQ0qKWsetY4Y9V7sxOougM= @@ -717,9 +717,12 @@ go.opencensus.io v0.23.0 h1:gqCw0LfLxScz8irSi8exQc7fyQ0fKQU/qnC/X8+V/1M= go.opencensus.io v0.23.0/go.mod h1:XItmlyltB5F7CS4xOC1DcqMoFqwtC6OG2xF7mCv7P7E= go.opentelemetry.io/proto/otlp v0.7.0/go.mod h1:PqfVotwruBrMGOCsRd/89rSnXhoiJIqeYNgFYFoEGnI= go.uber.org/atomic v1.3.2/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE= +go.uber.org/atomic v1.4.0 h1:cxzIVoETapQEqDhQu3QfnvXAV4AlzcvUCxkVUFw3+EU= go.uber.org/atomic v1.4.0/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE= go.uber.org/goleak v1.1.12/go.mod h1:cwTWslyiVhfpKIDGSZEM2HlOvcqm+tG4zioyIeLoqMQ= +go.uber.org/multierr v1.1.0 h1:HoEmRHQPVSqub6w2z2d2EOVs2fjyFRGyofhKuyDq0QI= go.uber.org/multierr v1.1.0/go.mod h1:wR5kodmAFQ0UK8QlbwjlSNy0Z68gJhDJUG5sjR94q/0= +go.uber.org/zap v1.10.0 h1:ORx85nbTijNz8ljznvCMR1ZBIPKFn3jQrag10X2AsuM= go.uber.org/zap v1.10.0/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q= golang.org/x/crypto v0.0.0-20171113213409-9f005a07e0d3/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= @@ -744,6 +747,8 @@ golang.org/x/exp v0.0.0-20191227195350-da58074b4299/go.mod h1:2RIsYlXP63K8oxa1u0 golang.org/x/exp v0.0.0-20200119233911-0405dc783f0a/go.mod h1:2RIsYlXP63K8oxa1u096TMicItID8zy7Y6sNkU49FU4= golang.org/x/exp v0.0.0-20200207192155-f17229e696bd/go.mod h1:J/WKrq2StrnmMY6+EHIKF9dgMWnmCNThgcyBT1FY9mM= golang.org/x/exp v0.0.0-20200224162631-6cc2880d07d6/go.mod h1:3jZMyOhIsHpP37uCMkUooju7aAi5cS1Q23tOzKc+0MU= +golang.org/x/exp v0.0.0-20220428152302-39d4317da171 h1:TfdoLivD44QwvssI9Sv1xwa5DcL5XQr4au4sZ2F2NV4= +golang.org/x/exp v0.0.0-20220428152302-39d4317da171/go.mod h1:lgLbSvA5ygNOMpwM/9anMpWVlVJ7Z+cHWq/eFuinpGE= golang.org/x/image v0.0.0-20190227222117-0694c2d4d067/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js= golang.org/x/image v0.0.0-20190802002840-cff245a6509b/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0= golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= diff --git a/grammar/grammar.go b/grammar/grammar.go index a1c193b..89e4b16 100644 --- a/grammar/grammar.go +++ b/grammar/grammar.go @@ -1,9 +1,5 @@ package grammar -type CheckUnmarshaler interface { - UnmarshalCheck(c Check) error -} - type Call struct { Module string `parser:"(@Module'.')?"` Name string `parser:"@Ident"` @@ -20,5 +16,5 @@ type Check struct { } type Script struct { - Checks []Check `parser:"@@*"` + Checks []Check `parser:"(@@';'?)*"` } diff --git a/main.go b/main.go index 4a111e0..727ca1d 100644 --- a/main.go +++ b/main.go @@ -1,28 +1,66 @@ package main import ( + "log" "os" - "time" + + "go.uber.org/zap" + "go.uber.org/zap/zapcore" + + "github.com/baez90/nurse/config" ) -var checkTimeout time.Duration +var ( + logLevel = zapcore.InfoLevel + cfgFile string + cfg config.Nurse +) func main() { -} + if err := prepareFlags(); err != nil { + log.Fatalf("Failed to parse flags: %v", err) + } + setupLogging() -func lookupEnvOr[T any](envKey string, fallback T, parse func(envVal string) (T, error)) T { - envVal := os.Getenv(envKey) - if envVal == "" { - return fallback + logger := zap.L() + + envCfg, err := config.New( + config.WithValuesFrom(cfg), + config.WithConfigFile(cfgFile), + config.WithServersFromEnv(), + config.WithEndpointsFromEnv(), + ) + + if err != nil { + logger.Error("Failed to load config from environment", zap.Error(err)) + os.Exit(1) } - if parsed, err := parse(envVal); err != nil { - return fallback - } else { - return parsed - } + logger.Debug("Loaded config", zap.Any("config", envCfg)) } -func identity[T any](in T) (T, error) { - return in, nil +func setupLogging() { + cfg := zap.NewProductionConfig() + cfg.Level = zap.NewAtomicLevelAt(logLevel) + logger, err := cfg.Build() + if err != nil { + log.Fatalf("Failed to setup logging: %v", err) + } + + zap.ReplaceGlobals(logger) +} + +func prepareFlags() error { + set := config.ConfigureFlags(&cfg) + + set.Var(&logLevel, "log-level", "Log level to use") + + set.StringVar( + &cfgFile, + "config", + config.LookupEnvOr[string]("NURSE_CONFIG", "", config.Identity[string]), + "Config file to load, if not set $HOME/.nurse.yaml, /etc/nurse/config.yaml and ./nurse.yaml are tried - optional", + ) + + return set.Parse(os.Args) } diff --git a/nurse.yaml b/nurse.yaml new file mode 100644 index 0000000..0cf1383 --- /dev/null +++ b/nurse.yaml @@ -0,0 +1,16 @@ +checkTimeout: 500ms +servers: + redis_local: + url: redis://localhost:6379/0 + args: + MaxRetries: 3 + +endpoints: + /ready: + checkTimeout: 500ms + checks: + - redis.PING("redis_local") + /liveness: + checkTimeout: 100ms + checks: + - redis.GET("redis_local", "running") => String("ok") diff --git a/redis/checks.go b/redis/checks.go index 0cf34f4..bdf4a85 100644 --- a/redis/checks.go +++ b/redis/checks.go @@ -1,35 +1,18 @@ package redis import ( - "fmt" - "strings" - "github.com/baez90/nurse/check" - "github.com/baez90/nurse/grammar" ) -var knownChecks = map[string]func() check.SystemChecker{ - "ping": func() check.SystemChecker { - return new(PingCheck) - }, - "get": func() check.SystemChecker { - return new(GetCheck) - }, -} - -func LookupCheck(c grammar.Check) (check.SystemChecker, error) { - var ( - provider func() check.SystemChecker - ok bool +func Module() *check.Module { + m, _ := check.NewModule( + check.WithCheck("ping", check.FactoryFunc(func() check.SystemChecker { + return new(PingCheck) + })), + check.WithCheck("get", check.FactoryFunc(func() check.SystemChecker { + return new(GetCheck) + })), ) - if provider, ok = knownChecks[strings.ToLower(c.Initiator.Name)]; !ok { - return nil, fmt.Errorf("%w: %s", check.ErrNoSuchCheck, c.Initiator.Name) - } - chk := provider() - if err := chk.UnmarshalCheck(c); err != nil { - return nil, err - } - - return chk, nil + return m } diff --git a/redis/checks_test.go b/redis/checks_test.go new file mode 100644 index 0000000..f46345c --- /dev/null +++ b/redis/checks_test.go @@ -0,0 +1,103 @@ +package redis_test + +import ( + "context" + "fmt" + "strings" + "testing" + + redisCli "github.com/go-redis/redis/v8" + "github.com/google/uuid" + "github.com/maxatome/go-testdeep/td" + + "github.com/baez90/nurse/config" + "github.com/baez90/nurse/grammar" + "github.com/baez90/nurse/redis" +) + +func TestChecks_Execute(t *testing.T) { + t.Parallel() + + redisModule := redis.Module() + + tests := []struct { + name string + check string + setup func(tb testing.TB, cli redisCli.UniversalClient) + wantErr bool + }{ + { + name: "Get value", + check: `redis.GET("%s", "some_key")`, + setup: func(tb testing.TB, cli redisCli.UniversalClient) { + tb.Helper() + td.CmpNoError(tb, cli.Set(context.Background(), "some_key", "some_value", 0).Err()) + }, + wantErr: false, + }, + { + name: "Get value - validate value", + check: `redis.GET("%s", "some_key") => String("some_value")`, + setup: func(tb testing.TB, cli redisCli.UniversalClient) { + tb.Helper() + td.CmpNoError(tb, cli.Set(context.Background(), "some_key", "some_value", 0).Err()) + }, + wantErr: false, + }, + { + name: "PING check", + check: `redis.PING("%s")`, + wantErr: false, + }, + { + name: "PING check - with custom message", + check: `redis.PING("%s", "Hello, Redis!")`, + wantErr: false, + }, + } + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + srv := PrepareRedisContainer(t) + serverName := uuid.NewString() + + register := config.NewServerRegister() + + if err := register.Register(serverName, *srv); err != nil { + t.Fatalf("DefaultLookup.Register() err = %v", err) + } + + cli, err := redis.ClientForServer(srv) + if err != nil { + t.Fatalf("redis.ClientForServer() err = %v", err) + } + + if tt.setup != nil { + tt.setup(t, cli) + } + + if strings.Contains(tt.check, "%s") { + tt.check = fmt.Sprintf(tt.check, serverName) + } + + parser, err := grammar.NewParser[grammar.Check]() + td.CmpNoError(t, err, "grammar.NewParser()") + parsedCheck, err := parser.Parse(tt.check) + td.CmpNoError(t, err, "parser.Parse()") + + chk, err := redisModule.Lookup(*parsedCheck, register) + td.CmpNoError(t, err, "redis.LookupCheck()") + + td.CmpNoError(t, chk.UnmarshalCheck(*parsedCheck, register), "get.UnmarshalCheck()") + td.CmpNoError(t, chk.Execute(context.Background())) + + if tt.wantErr { + td.CmpError(t, chk.Execute(context.Background())) + } else { + td.CmpNoError(t, chk.Execute(context.Background())) + } + }) + } +} diff --git a/redis/client.go b/redis/client.go index 444ba54..76690ef 100644 --- a/redis/client.go +++ b/redis/client.go @@ -10,10 +10,10 @@ import ( "github.com/baez90/nurse/grammar" ) -func clientFromParam(p grammar.Param) (redis.UniversalClient, error) { +func clientFromParam(p grammar.Param, srvLookup config.ServerLookup) (redis.UniversalClient, error) { if srvName, err := p.AsString(); err != nil { return nil, err - } else if srv, err := config.DefaultLookup.Lookup(srvName); err != nil { + } else if srv, err := srvLookup.Lookup(srvName); err != nil { return nil, err } else if redisCli, err := ClientForServer(srv); err != nil { return nil, err @@ -37,5 +37,10 @@ func ClientForServer(srv *config.Server) (redis.UniversalClient, error) { return nil, err } + if srv.Credentials != nil { + opts.Username = srv.Credentials.Username + opts.Password = *srv.Credentials.Password + } + return redis.NewUniversalClient(opts), nil } diff --git a/redis/container_test.go b/redis/container_test.go index 879c480..c065fc8 100644 --- a/redis/container_test.go +++ b/redis/container_test.go @@ -3,19 +3,16 @@ package redis_test import ( "context" "fmt" - "net/url" "testing" "time" "github.com/testcontainers/testcontainers-go" - "github.com/testcontainers/testcontainers-go/wait" "github.com/baez90/nurse/config" ) func PrepareRedisContainer(tb testing.TB) *config.Server { tb.Helper() - const redisPort = "6379/tcp" ctx, cancel := context.WithTimeout(context.Background(), 1*time.Minute) @@ -24,9 +21,9 @@ func PrepareRedisContainer(tb testing.TB) *config.Server { container, err := testcontainers.GenericContainer(ctx, testcontainers.GenericContainerRequest{ ContainerRequest: testcontainers.ContainerRequest{ Image: "docker.io/redis:alpine", - Name: tb.Name(), ExposedPorts: []string{redisPort}, - WaitingFor: wait.ForListeningPort(redisPort), + SkipReaper: true, + AutoRemove: true, }, Started: true, Logger: testcontainers.TestLogger(tb), @@ -46,13 +43,8 @@ func PrepareRedisContainer(tb testing.TB) *config.Server { tb.Fatalf("container.PortEndpoint() err = %v", err) } - u, err := url.Parse(fmt.Sprintf("%s/0?MaxRetries=3", ep)) - if err != nil { - tb.Fatalf("url.Parse() err = %v", err) - } - - srv, err := config.ParseFromURL(u) - if err != nil { + srv := new(config.Server) + if err := srv.UnmarshalURL(fmt.Sprintf("%s/0?MaxRetries=3", ep)); err != nil { tb.Fatalf("config.ParseFromURL() err = %v", err) } return srv diff --git a/redis/get.go b/redis/get.go index a299465..4741a1c 100644 --- a/redis/get.go +++ b/redis/get.go @@ -6,13 +6,11 @@ import ( "github.com/go-redis/redis/v8" "github.com/baez90/nurse/check" + "github.com/baez90/nurse/config" "github.com/baez90/nurse/grammar" ) -var ( - _ check.SystemChecker = (*GetCheck)(nil) - _ grammar.CheckUnmarshaler = (*GetCheck)(nil) -) +var _ check.SystemChecker = (*GetCheck)(nil) type GetCheck struct { redis.UniversalClient @@ -30,7 +28,7 @@ func (g GetCheck) Execute(ctx context.Context) error { return g.validators.Validate(cmd) } -func (g *GetCheck) UnmarshalCheck(c grammar.Check) error { +func (g *GetCheck) UnmarshalCheck(c grammar.Check, lookup config.ServerLookup) error { const serverAndKeyArgsNumber = 2 inst := c.Initiator if err := grammar.ValidateParameterCount(inst.Params, serverAndKeyArgsNumber); err != nil { @@ -38,7 +36,7 @@ func (g *GetCheck) UnmarshalCheck(c grammar.Check) error { } var err error - if g.UniversalClient, err = clientFromParam(inst.Params[0]); err != nil { + if g.UniversalClient, err = clientFromParam(inst.Params[0], lookup); err != nil { return err } @@ -46,5 +44,9 @@ func (g *GetCheck) UnmarshalCheck(c grammar.Check) error { return err } + if g.validators, err = ValidatorsForFilters(c.Validators); err != nil { + return err + } + return nil } diff --git a/redis/get_test.go b/redis/get_test.go deleted file mode 100644 index 06804bf..0000000 --- a/redis/get_test.go +++ /dev/null @@ -1,70 +0,0 @@ -package redis_test - -import ( - "context" - "fmt" - "testing" - - redisCli "github.com/go-redis/redis/v8" - "github.com/maxatome/go-testdeep/td" - - "github.com/baez90/nurse/config" - "github.com/baez90/nurse/grammar" - "github.com/baez90/nurse/redis" -) - -func TestGetCheck_Execute(t *testing.T) { - t.Parallel() - srv := PrepareRedisContainer(t) - if err := config.DefaultLookup.Register(t.Name(), *srv); err != nil { - t.Fatalf("DefaultLookup.Register() err = %v", err) - } - - cli, err := redis.ClientForServer(srv) - if err != nil { - t.Fatalf("redis.ClientForServer() err = %v", err) - } - - tests := []struct { - name string - check string - setup func(tb testing.TB, cli redisCli.UniversalClient) - wantErr bool - }{ - { - name: "Get value", - check: fmt.Sprintf(`redis.GET("%s", "some_key")`, t.Name()), - setup: func(tb testing.TB, cli redisCli.UniversalClient) { - tb.Helper() - td.CmpNoError(tb, cli.Set(context.Background(), "some_key", "some_value", 0).Err()) - }, - wantErr: false, - }, - } - for _, tt := range tests { - tt := tt - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - - if tt.setup != nil { - tt.setup(t, cli) - } - - get := new(redis.GetCheck) - - parser, err := grammar.NewParser[grammar.Check]() - td.CmpNoError(t, err, "grammar.NewParser()") - check, err := parser.Parse(tt.check) - td.CmpNoError(t, err, "parser.Parse()") - - td.CmpNoError(t, get.UnmarshalCheck(*check), "get.UnmarshalCheck()") - td.CmpNoError(t, get.Execute(context.Background())) - - if tt.wantErr { - td.CmpError(t, get.Execute(context.Background())) - } else { - td.CmpNoError(t, get.Execute(context.Background())) - } - }) - } -} diff --git a/redis/ping.go b/redis/ping.go index 7c71049..3892c01 100644 --- a/redis/ping.go +++ b/redis/ping.go @@ -7,13 +7,11 @@ import ( "github.com/go-redis/redis/v8" "github.com/baez90/nurse/check" + "github.com/baez90/nurse/config" "github.com/baez90/nurse/grammar" ) -var ( - _ check.SystemChecker = (*PingCheck)(nil) - _ grammar.CheckUnmarshaler = (*PingCheck)(nil) -) +var _ check.SystemChecker = (*PingCheck)(nil) type PingCheck struct { redis.UniversalClient @@ -34,7 +32,7 @@ func (p PingCheck) Execute(ctx context.Context) error { return nil } -func (p *PingCheck) UnmarshalCheck(c grammar.Check) error { +func (p *PingCheck) UnmarshalCheck(c grammar.Check, lookup config.ServerLookup) error { const ( serverOnlyArgCount = 1 serverAndMessageArgCount = 2 @@ -54,7 +52,7 @@ func (p *PingCheck) UnmarshalCheck(c grammar.Check) error { } fallthrough case serverOnlyArgCount: - if cli, err := clientFromParam(init.Params[0]); err != nil { + if cli, err := clientFromParam(init.Params[0], lookup); err != nil { return err } else { p.UniversalClient = cli diff --git a/redis/ping_test.go b/redis/ping_test.go deleted file mode 100644 index b6ba76c..0000000 --- a/redis/ping_test.go +++ /dev/null @@ -1,61 +0,0 @@ -package redis_test - -import ( - "context" - "fmt" - "testing" - - "github.com/maxatome/go-testdeep/td" - - "github.com/baez90/nurse/config" - "github.com/baez90/nurse/grammar" - "github.com/baez90/nurse/redis" -) - -func TestPingCheck_Execute(t *testing.T) { - t.Parallel() - srv := PrepareRedisContainer(t) - if err := config.DefaultLookup.Register(t.Name(), *srv); err != nil { - t.Fatalf("DefaultLookup.Register() err = %v", err) - } - - tests := []struct { - name string - check string - wantErr bool - }{ - { - name: "PING check", - check: fmt.Sprintf(`redis.PING("%s")`, t.Name()), - wantErr: false, - }, - // redis.PING("my-redis") - // redis.PING("my-redis", "Hello, Redis!") - // redis.GET("my-redis", "some-key") -> String("Hello") - // redis.PING("my-redis"); redis.GET("my-redis", "some-key") -> String("Hello") - { - name: "PING check - with custom message", - check: fmt.Sprintf(`redis.PING("%s", "Hello, Redis!")`, t.Name()), - wantErr: false, - }, - } - for _, tt := range tests { - tt := tt - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - ping := new(redis.PingCheck) - - parser, err := grammar.NewParser[grammar.Check]() - td.CmpNoError(t, err, "grammar.NewParser()") - check, err := parser.Parse(tt.check) - td.CmpNoError(t, err, "parser.Parse()") - - if err := ping.UnmarshalCheck(*check); err != nil { - t.Fatalf("UnmarshalCheck() err = %v", err) - } - if err := ping.Execute(context.Background()); (err != nil) != tt.wantErr { - t.Errorf("Execute() error = %v, wantErr %v", err, tt.wantErr) - } - }) - } -} diff --git a/redis/validation.go b/redis/validation.go index 5f84957..ce1ee3b 100644 --- a/redis/validation.go +++ b/redis/validation.go @@ -3,23 +3,64 @@ package redis import ( "errors" "fmt" + "strings" "github.com/go-redis/redis/v8" + + "github.com/baez90/nurse/check" + "github.com/baez90/nurse/grammar" ) var ( + ErrNoSuchValidator = errors.New("no such validator") + _ CmdValidator = (ValidationChain)(nil) - _ CmdValidator = StringCmdValidator("") + _ CmdValidator = (*StringCmdValidator)(nil) + + knownValidators = map[string]func() unmarshallableCmdValidator{ + "string": func() unmarshallableCmdValidator { + return new(StringCmdValidator) + }, + } ) type ( CmdValidator interface { Validate(cmder redis.Cmder) error } - - ValidationChain []CmdValidator + unmarshallableCmdValidator interface { + CmdValidator + check.CallUnmarshaler + } ) +func ValidatorsForFilters(filters *grammar.Filters) (ValidationChain, error) { + if filters == nil || filters.Chain == nil { + return ValidationChain{}, nil + } + chain := make(ValidationChain, 0, len(filters.Chain)) + for i := range filters.Chain { + validationCall := filters.Chain[i] + if validatorProvider, ok := knownValidators[strings.ToLower(validationCall.Name)]; !ok { + return nil, fmt.Errorf("%w: %s", ErrNoSuchValidator, validationCall.Name) + } else { + validator := validatorProvider() + if err := validator.UnmarshalCall(validationCall); err != nil { + return nil, err + } + chain = append(chain, validator) + } + } + + return chain, nil +} + +type ValidationChain []CmdValidator + +func (v ValidationChain) UnmarshalCall(grammar.Call) error { + return errors.New("cannot unmarshal chain") +} + func (v ValidationChain) Validate(cmder redis.Cmder) error { for i := range v { if err := v[i].Validate(cmder); err != nil { @@ -32,6 +73,20 @@ func (v ValidationChain) Validate(cmder redis.Cmder) error { type StringCmdValidator string +func (s *StringCmdValidator) UnmarshalCall(c grammar.Call) error { + if err := grammar.ValidateParameterCount(c.Params, 1); err != nil { + return err + } + + want, err := c.Params[0].AsString() + if err != nil { + return err + } + + *s = StringCmdValidator(want) + return nil +} + func (s StringCmdValidator) Validate(cmder redis.Cmder) error { if err := cmder.Err(); err != nil { return err