Implement config reading

- support config file
- support reading checks, server definition and endpoint definitions from env
This commit is contained in:
Peter 2022-05-08 11:00:22 +02:00
parent 803f52ba46
commit 4c2aa968d2
Signed by: prskr
GPG key ID: C1DB5D2E8DB512F9
27 changed files with 1263 additions and 255 deletions

1
.gitignore vendored
View file

@ -16,3 +16,4 @@
dist/
.idea/
cue.mod/

View file

@ -4,12 +4,26 @@ import (
"context"
"errors"
"github.com/baez90/nurse/config"
"github.com/baez90/nurse/grammar"
)
var ErrNoSuchCheck = errors.New("no such check")
var (
ErrNoSuchCheck = errors.New("no such check")
ErrConflictingCheck = errors.New("check with same name already registered")
)
type SystemChecker interface {
grammar.CheckUnmarshaler
Execute(ctx context.Context) error
}
type (
Unmarshaler interface {
UnmarshalCheck(c grammar.Check, lookup config.ServerLookup) error
}
SystemChecker interface {
Unmarshaler
Execute(ctx context.Context) error
}
CallUnmarshaler interface {
UnmarshalCall(c grammar.Call) error
}
)

91
check/modules.go Normal file
View file

@ -0,0 +1,91 @@
package check
import (
"fmt"
"strings"
"sync"
"github.com/baez90/nurse/config"
"github.com/baez90/nurse/grammar"
)
type (
ModuleOption interface {
Apply(m *Module) error
}
ModuleOptionFunc func(m *Module) error
Factory interface {
New() SystemChecker
}
FactoryFunc func() SystemChecker
)
func (f ModuleOptionFunc) Apply(m *Module) error {
return f(m)
}
func (f FactoryFunc) New() SystemChecker {
return f()
}
func WithCheck(name string, factory Factory) ModuleOption {
return ModuleOptionFunc(func(m *Module) error {
return m.Register(name, factory)
})
}
func NewModule(opts ...ModuleOption) (*Module, error) {
m := &Module{
knownChecks: make(map[string]Factory),
}
for i := range opts {
if err := opts[i].Apply(m); err != nil {
return nil, err
}
}
return m, nil
}
type Module struct {
lock sync.RWMutex
knownChecks map[string]Factory
}
func (m *Module) Lookup(c grammar.Check, srvLookup config.ServerLookup) (SystemChecker, error) {
m.lock.RLock()
defer m.lock.RUnlock()
var (
factory Factory
ok bool
)
if factory, ok = m.knownChecks[strings.ToLower(c.Initiator.Name)]; !ok {
return nil, fmt.Errorf("%w: %s", ErrNoSuchCheck, c.Initiator.Name)
}
chk := factory.New()
if err := chk.UnmarshalCheck(c, srvLookup); err != nil {
return nil, err
}
return chk, nil
}
func (m *Module) Register(name string, factory Factory) error {
m.lock.Lock()
defer m.lock.Unlock()
name = strings.ToLower(name)
if _, ok := m.knownChecks[name]; ok {
return fmt.Errorf("%w: %s", ErrConflictingCheck, name)
}
m.knownChecks[name] = factory
return nil
}

91
config/app_config.go Normal file
View file

@ -0,0 +1,91 @@
package config
import (
"io"
"os"
"time"
)
type (
Option interface {
Extend(n Nurse) (Nurse, error)
}
OptionFunc func(n Nurse) (Nurse, error)
)
func (f OptionFunc) Extend(n Nurse) (Nurse, error) {
return f(n)
}
func New(opts ...Option) (*Nurse, error) {
var (
inst Nurse
err error
)
for i := range opts {
if inst, err = opts[i].Extend(inst); err != nil {
return nil, err
}
}
return &inst, nil
}
type Nurse struct {
Servers map[string]Server
Endpoints map[Route]EndpointSpec
CheckTimeout time.Duration
}
// Merge merges the current Nurse instance with another one
// giving the current instance precedence means no set value is overwritten
func (n Nurse) Merge(other Nurse) Nurse {
if n.CheckTimeout == 0 {
n.CheckTimeout = other.CheckTimeout
}
for name, srv := range other.Servers {
if _, ok := n.Servers[name]; !ok {
n.Servers[name] = srv
}
}
return n
}
func (n *Nurse) Unmarshal(reader io.ReadSeeker) error {
providers := []func(io.Reader) configDecoder{
newYAMLDecoder,
newJSONDecoder,
}
for i := range providers {
if err := n.tryUnmarshal(reader, providers[i]); err == nil {
return nil
}
}
return nil
}
func (n *Nurse) ReadFromFile(configFilePath string) error {
if file, err := os.Open(configFilePath); err != nil {
return err
} else {
defer func() {
_ = file.Close()
}()
return n.Unmarshal(file)
}
}
func (n *Nurse) tryUnmarshal(seeker io.ReadSeeker, prov func(reader io.Reader) configDecoder) error {
if _, err := seeker.Seek(0, 0); err != nil {
return err
}
decoder := prov(seeker)
return decoder.DecodeConfig(n)
}

41
config/decoder.go Normal file
View file

@ -0,0 +1,41 @@
package config
import (
"encoding/json"
"io"
"gopkg.in/yaml.v3"
)
var (
_ configDecoder = (*jsonDecoder)(nil)
_ configDecoder = (*yamlDecoder)(nil)
)
type configDecoder interface {
DecodeConfig(into *Nurse) error
}
func newJSONDecoder(r io.Reader) configDecoder {
return &jsonDecoder{decoder: json.NewDecoder(r)}
}
type jsonDecoder struct {
decoder *json.Decoder
}
func (j jsonDecoder) DecodeConfig(into *Nurse) error {
return j.decoder.Decode(into)
}
func newYAMLDecoder(r io.Reader) configDecoder {
return &yamlDecoder{decoder: yaml.NewDecoder(r)}
}
type yamlDecoder struct {
decoder *yaml.Decoder
}
func (y yamlDecoder) DecodeConfig(into *Nurse) error {
return y.decoder.Decode(into)
}

42
config/endpoints.go Normal file
View file

@ -0,0 +1,42 @@
package config
import (
"encoding"
"fmt"
"path"
"strings"
"time"
"github.com/baez90/nurse/grammar"
)
var _ encoding.TextUnmarshaler = (*EndpointSpec)(nil)
type Route string
func (r Route) String() string {
val := string(r)
val = strings.Trim(val, "/")
return path.Clean(fmt.Sprintf("/%s", val))
}
type EndpointSpec struct {
CheckTimeout time.Duration
Checks []grammar.Check
}
func (e *EndpointSpec) UnmarshalText(text []byte) error {
parser, err := grammar.NewParser[grammar.Script]()
if err != nil {
return err
}
script, err := parser.Parse(string(text))
if err != nil {
return err
}
e.Checks = script.Checks
return nil
}

61
config/env.go Normal file
View file

@ -0,0 +1,61 @@
package config
import (
"os"
"path"
. "strings"
)
const (
ServerKeyPrefix = "NURSE_SERVER"
EndpointKeyPrefix = "NURSE_ENDPOINT"
)
func ServersFromEnv() (map[string]Server, error) {
servers := make(map[string]Server)
for _, kv := range os.Environ() {
key, value, valid := Cut(kv, "=")
if !valid {
continue
}
if !HasPrefix(key, ServerKeyPrefix) {
continue
}
serverName := ToLower(Trim(Replace(key, ServerKeyPrefix, "", -1), "_"))
srv := Server{}
if err := srv.UnmarshalURL(value); err != nil {
return nil, err
}
servers[serverName] = srv
}
return servers, nil
}
func EndpointsFromEnv() (map[Route]EndpointSpec, error) {
endpoints := make(map[Route]EndpointSpec)
for _, kv := range os.Environ() {
key, value, valid := Cut(kv, "=")
if !valid {
continue
}
if !HasPrefix(key, EndpointKeyPrefix) {
continue
}
endpointRoute := path.Join(Split(ToLower(Trim(Replace(key, EndpointKeyPrefix, "", -1), "_")), "_")...)
spec := EndpointSpec{}
if err := spec.UnmarshalText([]byte(value)); err != nil {
return nil, err
}
endpoints[Route(endpointRoute)] = spec
}
return endpoints, nil
}

162
config/env_test.go Normal file
View file

@ -0,0 +1,162 @@
package config_test
import (
"fmt"
"testing"
"github.com/maxatome/go-testdeep/td"
"github.com/baez90/nurse/config"
)
//nolint:paralleltest // not possible with env setup
func TestServersFromEnv(t *testing.T) {
tests := []struct {
name string
env map[string]string
want any
wantErr bool
}{
{
name: "Empty env",
want: td.NotNil(),
},
{
name: "Single server",
env: map[string]string{
fmt.Sprintf("%s_REDIS_1", config.ServerKeyPrefix): "redis://localhost:6379",
},
want: td.Map(make(map[string]config.Server), td.MapEntries{
"redis_1": config.Server{
Type: config.ServerTypeRedis,
Hosts: []string{"localhost:6379"},
Args: make(map[string]any),
},
}),
wantErr: false,
},
{
name: "Multiple servers",
env: map[string]string{
fmt.Sprintf("%s_REDIS_1", config.ServerKeyPrefix): "redis://localhost:6379",
fmt.Sprintf("%s_REDIS_2", config.ServerKeyPrefix): "redis://redis:6379",
},
want: td.Map(make(map[string]config.Server), td.MapEntries{
"redis_1": config.Server{
Type: config.ServerTypeRedis,
Hosts: []string{"localhost:6379"},
Args: make(map[string]any),
},
"redis_2": config.Server{
Type: config.ServerTypeRedis,
Hosts: []string{"redis:6379"},
Args: make(map[string]any),
},
}),
wantErr: false,
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
if tt.env != nil {
for k, v := range tt.env {
t.Setenv(k, v)
}
}
got, err := config.ServersFromEnv()
if (err != nil) != tt.wantErr {
t.Errorf("ServersFromEnv() error = %v, wantErr %v", err, tt.wantErr)
return
}
td.Cmp(t, got, tt.want)
})
}
}
//nolint:paralleltest // not possible with env setup
func TestEndpointsFromEnv(t *testing.T) {
tests := []struct {
name string
env map[string]string
want any
wantErr bool
}{
{
name: "Empty env",
want: td.NotNil(),
},
{
name: "Single endpoint",
env: map[string]string{
fmt.Sprintf("%s_readiness", config.EndpointKeyPrefix): `redis.PING("local_redis")`,
},
want: td.Map(make(map[config.Route]config.EndpointSpec), td.MapEntries{
config.Route("readiness"): td.Struct(config.EndpointSpec{}, td.StructFields{
"Checks": td.Len(1),
}),
}),
wantErr: false,
},
{
name: "Single endpoint - multiple checks",
env: map[string]string{
fmt.Sprintf("%s_readiness", config.EndpointKeyPrefix): `redis.PING("local_redis");redis.GET("local_redis", "serving") => String("ok")`,
},
want: td.Map(make(map[config.Route]config.EndpointSpec), td.MapEntries{
config.Route("readiness"): td.Struct(config.EndpointSpec{}, td.StructFields{
"Checks": td.Len(2),
}),
}),
wantErr: false,
},
{
name: "Single endpoint - sub-route",
env: map[string]string{
fmt.Sprintf("%s_READINESS_REDIS", config.EndpointKeyPrefix): `redis.PING("local_redis");redis.GET("local_redis", "serving") => String("ok")`,
},
want: td.Map(make(map[config.Route]config.EndpointSpec), td.MapEntries{
config.Route("readiness/redis"): td.Struct(config.EndpointSpec{}, td.StructFields{
"Checks": td.Len(2),
}),
}),
wantErr: false,
},
{
name: "Multiple endpoints",
env: map[string]string{
fmt.Sprintf("%s_readiness", config.EndpointKeyPrefix): `redis.PING("local_redis")`,
fmt.Sprintf("%s_liveness", config.EndpointKeyPrefix): `redis.GET("local_redis", "serving") => String("ok")`,
},
want: td.Map(make(map[config.Route]config.EndpointSpec), td.MapEntries{
config.Route("readiness"): td.Struct(config.EndpointSpec{}, td.StructFields{
"Checks": td.Len(1),
}),
config.Route("liveness"): td.Struct(config.EndpointSpec{}, td.StructFields{
"Checks": td.Len(1),
}),
}),
wantErr: false,
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
if tt.env != nil {
for k, v := range tt.env {
t.Setenv(k, v)
}
}
got, err := config.EndpointsFromEnv()
if (err != nil) != tt.wantErr {
t.Errorf("EndpointsFromEnv() error = %v, wantErr %v", err, tt.wantErr)
return
}
td.Cmp(t, got, tt.want)
})
}
}

39
config/flags.go Normal file
View file

@ -0,0 +1,39 @@
package config
import (
"flag"
"os"
"time"
)
const defaultCheckTimeout = 500 * time.Millisecond
func ConfigureFlags(cfg *Nurse) *flag.FlagSet {
set := flag.NewFlagSet("nurse", flag.ContinueOnError)
set.DurationVar(
&cfg.CheckTimeout,
"check-timeout",
LookupEnvOr("NURSE_CHECK_TIMEOUT", defaultCheckTimeout, time.ParseDuration),
"Timeout when running checks",
)
return set
}
func LookupEnvOr[T any](envKey string, fallback T, parse func(envVal string) (T, error)) T {
envVal := os.Getenv(envKey)
if envVal == "" {
return fallback
}
if parsed, err := parse(envVal); err != nil {
return fallback
} else {
return parsed
}
}
func Identity[T any](in T) (T, error) {
return in, nil
}

View file

@ -10,17 +10,26 @@ import (
var (
ErrServerNameAlreadyRegistered = errors.New("server name is already registered")
ErrNoSuchServer = errors.New("no known server with given name")
DefaultLookup = &ServerLookup{
servers: make(map[string]Server),
}
_ ServerLookup = (*ServerRegister)(nil)
)
type ServerLookup struct {
type ServerLookup interface {
Lookup(name string) (*Server, error)
}
func NewServerRegister() *ServerRegister {
return &ServerRegister{
servers: make(map[string]Server),
}
}
type ServerRegister struct {
lock sync.RWMutex
servers map[string]Server
}
func (l *ServerLookup) Register(name string, srv Server) error {
func (l *ServerRegister) Register(name string, srv Server) error {
l.lock.Lock()
defer l.lock.Unlock()
@ -34,7 +43,7 @@ func (l *ServerLookup) Register(name string, srv Server) error {
return nil
}
func (l *ServerLookup) Lookup(name string) (*Server, error) {
func (l *ServerRegister) Lookup(name string) (*Server, error) {
l.lock.RLock()
defer l.lock.RUnlock()

119
config/options.go Normal file
View file

@ -0,0 +1,119 @@
package config
import (
"os"
"path/filepath"
"go.uber.org/zap"
)
func WithServersFromEnv() Option {
return OptionFunc(func(n Nurse) (Nurse, error) {
envServers, err := ServersFromEnv()
if err != nil {
return Nurse{}, err
}
if n.Servers == nil || len(n.Servers) == 0 {
n.Servers = envServers
return n, nil
}
for name, srv := range envServers {
if _, ok := n.Servers[name]; !ok {
n.Servers[name] = srv
}
}
return n, nil
})
}
func WithEndpointsFromEnv() Option {
return OptionFunc(func(n Nurse) (Nurse, error) {
envEndpoints, err := EndpointsFromEnv()
if err != nil {
return Nurse{}, err
}
if n.Endpoints == nil || len(n.Endpoints) == 0 {
n.Endpoints = envEndpoints
return n, nil
}
for route, spec := range envEndpoints {
if _, ok := n.Endpoints[route]; !ok {
n.Endpoints[route] = spec
}
}
return n, nil
})
}
func WithValuesFrom(other Nurse) Option {
return OptionFunc(func(n Nurse) (Nurse, error) {
return n.Merge(other), nil
})
}
func WithConfigFile(configFilePath string) Option {
logger := zap.L()
return OptionFunc(func(n Nurse) (Nurse, error) {
var out Nurse
if configFilePath != "" {
logger.Debug("Attempt to load config file")
if err := out.ReadFromFile(configFilePath); err == nil {
return out, nil
} else {
logger.Warn(
"Failed to load config file",
zap.String("config_file_path", configFilePath),
zap.Error(err),
)
}
}
if workingDir, err := os.Getwd(); err == nil {
configFilePath = filepath.Join(workingDir, "nurse.yaml")
logger.Debug("Attempt to load config file from current working directory")
if err = out.ReadFromFile(configFilePath); err == nil {
return out, nil
} else {
logger.Warn(
"Failed to load config file",
zap.String("config_file_path", configFilePath),
zap.Error(err),
)
}
}
if home, err := os.UserHomeDir(); err == nil {
configFilePath = filepath.Join(home, ".nurse.yaml")
logger.Debug("Attempt to load config file from user home directory")
if err = out.ReadFromFile(configFilePath); err == nil {
return out, nil
} else {
logger.Warn(
"Failed to load config file",
zap.String("config_file_path", configFilePath),
zap.Error(err),
)
}
}
configFilePath = filepath.Join("", "etc", "nurse", "config.yaml")
logger.Debug("Attempt to load config file from global config directory")
if err := out.ReadFromFile(configFilePath); err == nil {
return out, nil
} else {
logger.Warn(
"Failed to load config file",
zap.String("config_file_path", configFilePath),
zap.Error(err),
)
}
return Nurse{}, nil
})
}

View file

@ -2,11 +2,13 @@ package config
import (
"encoding/json"
"fmt"
"net/url"
"regexp"
"strings"
"github.com/baez90/nurse/internal/values"
"golang.org/x/exp/maps"
"gopkg.in/yaml.v3"
)
type ServerType uint
@ -14,48 +16,42 @@ type ServerType uint
const (
ServerTypeUnspecified ServerType = iota
ServerTypeRedis
//nolint:lll // cannot break regex
// language=regexp
urlSchemeRegex = `^(?P<scheme>\w+)://((?P<username>[\w-.]+)(:(?P<password>.*))?@)?(\{(?P<hosts>(.+:\d{1,5})(,(.+:\d{1,5}))*)}|(?P<host>.+:\d{1,5}))(?P<path>/.*$)?`
)
var hostsRegexp = regexp.MustCompile(`^{(.+:\d{1,5})(,(.+:\d{1,5}))*}|(.+:\d{1,5})$`)
func ParseFromURL(url *url.URL) (*Server, error) {
srv := &Server{
Type: SchemeToServerType(url.Scheme),
Hosts: hostsRegexp.FindAllString(url.Host, -1),
func (t ServerType) Scheme() string {
switch t {
case ServerTypeRedis:
return "redis"
case ServerTypeUnspecified:
fallthrough
default:
return "unknown"
}
if user := url.User; user != nil {
srv.Credentials = &Credentials{
Username: user.Username(),
}
if pw, ok := user.Password(); ok {
srv.Credentials.Password = values.StringP(pw)
}
}
srv.Path = strings.Split(strings.Trim(url.EscapedPath(), "/"), "/")
q := url.Query()
qm := map[string][]string(q)
srv.Args = make(map[string]any, len(qm))
for k := range qm {
var val any
if err := json.Unmarshal([]byte(q.Get(k)), &val); err != nil {
return nil, err
} else {
srv.Args[k] = val
}
}
return srv, nil
}
var hostsRegexp = regexp.MustCompile(urlSchemeRegex)
type Credentials struct {
Username string
Password *string
}
var (
_ json.Unmarshaler = (*Server)(nil)
_ yaml.Unmarshaler = (*Server)(nil)
)
type marshaledServer struct {
Type ServerType
URL string
Hosts []string
Args map[string]any
}
type Server struct {
Type ServerType
Credentials *Credentials
@ -63,3 +59,135 @@ type Server struct {
Path []string
Args map[string]any
}
func (s *Server) UnmarshalYAML(value *yaml.Node) error {
s.Args = make(map[string]any)
tmp := new(marshaledServer)
if err := value.Decode(tmp); err != nil {
return err
}
return s.mergedMarshaledServer(*tmp)
}
func (s *Server) UnmarshalJSON(bytes []byte) error {
s.Args = make(map[string]any)
tmp := new(marshaledServer)
if err := json.Unmarshal(bytes, tmp); err != nil {
return err
}
return s.mergedMarshaledServer(*tmp)
}
func (s *Server) UnmarshalURL(rawUrl string) error {
rawPath, username, password, err := s.extractBaseProperties(rawUrl)
if err != nil {
return err
}
if username != "" {
s.Credentials = &Credentials{
Username: username,
}
if password != "" {
s.Credentials.Password = &password
}
}
if rawPath != "" {
parsedUrl, err := url.Parse(fmt.Sprintf("%s://%s%s", s.Type.Scheme(), s.Hosts[0], rawPath))
if err != nil {
return err
}
if err = s.unmarshalPath(parsedUrl); err != nil {
return err
}
} else {
s.Args = make(map[string]any)
}
return nil
}
func (s *Server) extractBaseProperties(rawUrl string) (rawPath, username, password string, err error) {
allMatches := hostsRegexp.FindAllStringSubmatch(rawUrl, -1)
if matchLen := len(allMatches); matchLen != 1 {
return "", "", "", fmt.Errorf("ambiguous server match: %d", matchLen)
}
match := allMatches[0]
for i, name := range hostsRegexp.SubexpNames() {
if i == 0 {
continue
}
switch name {
case "scheme":
s.Type = SchemeToServerType(match[i])
case "host":
singleHost := match[i]
if singleHost != "" {
s.Hosts = []string{singleHost}
}
case "hosts":
hosts := strings.Split(match[i], ",")
for _, host := range hosts {
s.Hosts = append(s.Hosts, strings.TrimSpace(host))
}
case "path":
rawPath = match[i]
case "username":
username = match[i]
case "password":
password = match[i]
}
}
return rawPath, username, password, nil
}
func (s *Server) unmarshalPath(u *url.URL) error {
s.Path = strings.Split(strings.Trim(u.EscapedPath(), "/"), "/")
q := u.Query()
qm := map[string][]string(q)
s.Args = make(map[string]any, len(qm))
for k := range qm {
var val any
if err := json.Unmarshal([]byte(q.Get(k)), &val); err != nil {
return err
} else {
s.Args[k] = val
}
}
return nil
}
func (s *Server) mergedMarshaledServer(srv marshaledServer) error {
if srv.URL != "" {
if err := s.UnmarshalURL(srv.URL); err != nil {
return err
}
if srv.Args != nil {
maps.Copy(s.Args, srv.Args)
}
return nil
}
s.Type = srv.Type
s.Hosts = srv.Hosts
if srv.Args != nil {
s.Args = srv.Args
}
return nil
}

144
config/server_test.go Normal file
View file

@ -0,0 +1,144 @@
package config_test
import (
"testing"
"github.com/maxatome/go-testdeep/td"
"github.com/baez90/nurse/config"
"github.com/baez90/nurse/internal/values"
)
func TestParseFromURL(t *testing.T) {
t.Parallel()
tests := []struct {
name string
url string
want any
wantErr bool
}{
{
name: "Single Redis server",
url: "redis://localhost:6379/0",
want: &config.Server{
Type: config.ServerTypeRedis,
Hosts: []string{"localhost:6379"},
Path: []string{"0"},
Args: make(map[string]any),
},
},
{
name: "Single Redis server with args",
url: "redis://localhost:6379/0?MaxRetries=3",
want: &config.Server{
Type: config.ServerTypeRedis,
Hosts: []string{"localhost:6379"},
Path: []string{"0"},
Args: map[string]any{
"MaxRetries": 3.0,
},
},
},
{
name: "Single Redis server with user",
url: "redis://user@localhost:6379/0",
want: &config.Server{
Type: config.ServerTypeRedis,
Hosts: []string{"localhost:6379"},
Path: []string{"0"},
Args: make(map[string]any),
Credentials: &config.Credentials{
Username: "user",
},
},
},
{
name: "Single Redis server with credentials",
url: "redis://user:p4$$w0rd@localhost:6379/0",
want: &config.Server{
Type: config.ServerTypeRedis,
Hosts: []string{"localhost:6379"},
Path: []string{"0"},
Args: make(map[string]any),
Credentials: &config.Credentials{
Username: "user",
Password: values.StringP("p4$$w0rd"),
},
},
},
{
name: "Multiple Redis servers",
url: "redis://{localhost:6379,localhost:6380}/0",
want: &config.Server{
Type: config.ServerTypeRedis,
Hosts: []string{
"localhost:6379",
"localhost:6380",
},
Path: []string{"0"},
Args: make(map[string]any),
},
},
{
name: "Multiple Redis servers with args",
url: "redis://{localhost:6379,localhost:6380}/0?MaxRetries=3",
want: &config.Server{
Type: config.ServerTypeRedis,
Hosts: []string{
"localhost:6379",
"localhost:6380",
},
Path: []string{"0"},
Args: map[string]any{
"MaxRetries": 3.0,
},
},
},
{
name: "Multiple Redis servers with user",
url: "redis://user@{localhost:6379,localhost:6380}/0",
want: &config.Server{
Type: config.ServerTypeRedis,
Hosts: []string{
"localhost:6379",
"localhost:6380",
},
Credentials: &config.Credentials{
Username: "user",
},
Path: []string{"0"},
Args: make(map[string]any),
},
},
{
name: "Multiple Redis servers with credentials",
url: "redis://user:p4$$w0rd@{localhost:6379,localhost:6380}/0",
want: &config.Server{
Type: config.ServerTypeRedis,
Hosts: []string{
"localhost:6379",
"localhost:6380",
},
Credentials: &config.Credentials{
Username: "user",
Password: values.StringP("p4$$w0rd"),
},
Path: []string{"0"},
Args: make(map[string]any),
},
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
got := new(config.Server)
if err := got.UnmarshalURL(tt.url); (err != nil) != tt.wantErr {
t.Errorf("ParseFromURL() error = %v, wantErr %v", err, tt.wantErr)
return
}
td.Cmp(t, got, tt.want)
})
}
}

10
go.mod
View file

@ -5,9 +5,12 @@ go 1.18
require (
github.com/alecthomas/participle/v2 v2.0.0-alpha8
github.com/go-redis/redis/v8 v8.11.5
github.com/google/uuid v1.3.0
github.com/maxatome/go-testdeep v1.11.0
github.com/mitchellh/mapstructure v1.4.3
github.com/mitchellh/mapstructure v1.5.0
github.com/testcontainers/testcontainers-go v0.13.0
go.uber.org/zap v1.10.0
gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b
)
require (
@ -27,7 +30,6 @@ require (
github.com/gogo/protobuf v1.3.2 // indirect
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect
github.com/golang/protobuf v1.5.2 // indirect
github.com/google/uuid v1.3.0 // indirect
github.com/magiconair/properties v1.8.6 // indirect
github.com/moby/sys/mount v0.3.2 // indirect
github.com/moby/sys/mountinfo v0.6.1 // indirect
@ -39,10 +41,12 @@ require (
github.com/pkg/errors v0.9.1 // indirect
github.com/sirupsen/logrus v1.8.1 // indirect
go.opencensus.io v0.23.0 // indirect
go.uber.org/atomic v1.4.0 // indirect
go.uber.org/multierr v1.1.0 // indirect
golang.org/x/exp v0.0.0-20220428152302-39d4317da171 // indirect
golang.org/x/net v0.0.0-20220420153159-1850ba15e1be // indirect
golang.org/x/sys v0.0.0-20220412211240-33da011f77ad // indirect
google.golang.org/genproto v0.0.0-20220414192740-2d67ff6cf2b4 // indirect
google.golang.org/grpc v1.45.0 // indirect
google.golang.org/protobuf v1.28.0 // indirect
gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b // indirect
)

9
go.sum
View file

@ -500,8 +500,8 @@ github.com/miekg/pkcs11 v1.0.3/go.mod h1:XsNlhZGX73bx86s2hdc/FuaLm2CPZJemRLMA+WT
github.com/mistifyio/go-zfs v2.1.2-0.20190413222219-f784269be439+incompatible/go.mod h1:8AuVvqP/mXw1px98n46wfvcGfQ4ci2FwoAjKYxuo3Z4=
github.com/mitchellh/go-homedir v1.1.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0=
github.com/mitchellh/mapstructure v1.1.2/go.mod h1:FVVH3fgwuzCH5S8UJGiWEs2h04kUh9fWfEaFds41c1Y=
github.com/mitchellh/mapstructure v1.4.3 h1:OVowDSCllw/YjdLkam3/sm7wEtOy59d8ndGgCcyj8cs=
github.com/mitchellh/mapstructure v1.4.3/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo=
github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY=
github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo=
github.com/mitchellh/osext v0.0.0-20151018003038-5e2d6d41470f/go.mod h1:OkQIRizQZAeMln+1tSwduZz7+Af5oFlKirV/MSYes2A=
github.com/moby/locker v1.0.1/go.mod h1:S7SDdo5zpBK84bzzVlKr2V0hz+7x9hWbYC/kq7oQppc=
github.com/moby/sys/mount v0.2.0/go.mod h1:aAivFE2LB3W4bACsUXChRHQ0qKWsetY4Y9V7sxOougM=
@ -717,9 +717,12 @@ go.opencensus.io v0.23.0 h1:gqCw0LfLxScz8irSi8exQc7fyQ0fKQU/qnC/X8+V/1M=
go.opencensus.io v0.23.0/go.mod h1:XItmlyltB5F7CS4xOC1DcqMoFqwtC6OG2xF7mCv7P7E=
go.opentelemetry.io/proto/otlp v0.7.0/go.mod h1:PqfVotwruBrMGOCsRd/89rSnXhoiJIqeYNgFYFoEGnI=
go.uber.org/atomic v1.3.2/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE=
go.uber.org/atomic v1.4.0 h1:cxzIVoETapQEqDhQu3QfnvXAV4AlzcvUCxkVUFw3+EU=
go.uber.org/atomic v1.4.0/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE=
go.uber.org/goleak v1.1.12/go.mod h1:cwTWslyiVhfpKIDGSZEM2HlOvcqm+tG4zioyIeLoqMQ=
go.uber.org/multierr v1.1.0 h1:HoEmRHQPVSqub6w2z2d2EOVs2fjyFRGyofhKuyDq0QI=
go.uber.org/multierr v1.1.0/go.mod h1:wR5kodmAFQ0UK8QlbwjlSNy0Z68gJhDJUG5sjR94q/0=
go.uber.org/zap v1.10.0 h1:ORx85nbTijNz8ljznvCMR1ZBIPKFn3jQrag10X2AsuM=
go.uber.org/zap v1.10.0/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q=
golang.org/x/crypto v0.0.0-20171113213409-9f005a07e0d3/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
@ -744,6 +747,8 @@ golang.org/x/exp v0.0.0-20191227195350-da58074b4299/go.mod h1:2RIsYlXP63K8oxa1u0
golang.org/x/exp v0.0.0-20200119233911-0405dc783f0a/go.mod h1:2RIsYlXP63K8oxa1u096TMicItID8zy7Y6sNkU49FU4=
golang.org/x/exp v0.0.0-20200207192155-f17229e696bd/go.mod h1:J/WKrq2StrnmMY6+EHIKF9dgMWnmCNThgcyBT1FY9mM=
golang.org/x/exp v0.0.0-20200224162631-6cc2880d07d6/go.mod h1:3jZMyOhIsHpP37uCMkUooju7aAi5cS1Q23tOzKc+0MU=
golang.org/x/exp v0.0.0-20220428152302-39d4317da171 h1:TfdoLivD44QwvssI9Sv1xwa5DcL5XQr4au4sZ2F2NV4=
golang.org/x/exp v0.0.0-20220428152302-39d4317da171/go.mod h1:lgLbSvA5ygNOMpwM/9anMpWVlVJ7Z+cHWq/eFuinpGE=
golang.org/x/image v0.0.0-20190227222117-0694c2d4d067/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js=
golang.org/x/image v0.0.0-20190802002840-cff245a6509b/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0=
golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE=

View file

@ -1,9 +1,5 @@
package grammar
type CheckUnmarshaler interface {
UnmarshalCheck(c Check) error
}
type Call struct {
Module string `parser:"(@Module'.')?"`
Name string `parser:"@Ident"`
@ -20,5 +16,5 @@ type Check struct {
}
type Script struct {
Checks []Check `parser:"@@*"`
Checks []Check `parser:"(@@';'?)*"`
}

66
main.go
View file

@ -1,28 +1,66 @@
package main
import (
"log"
"os"
"time"
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
"github.com/baez90/nurse/config"
)
var checkTimeout time.Duration
var (
logLevel = zapcore.InfoLevel
cfgFile string
cfg config.Nurse
)
func main() {
}
if err := prepareFlags(); err != nil {
log.Fatalf("Failed to parse flags: %v", err)
}
setupLogging()
func lookupEnvOr[T any](envKey string, fallback T, parse func(envVal string) (T, error)) T {
envVal := os.Getenv(envKey)
if envVal == "" {
return fallback
logger := zap.L()
envCfg, err := config.New(
config.WithValuesFrom(cfg),
config.WithConfigFile(cfgFile),
config.WithServersFromEnv(),
config.WithEndpointsFromEnv(),
)
if err != nil {
logger.Error("Failed to load config from environment", zap.Error(err))
os.Exit(1)
}
if parsed, err := parse(envVal); err != nil {
return fallback
} else {
return parsed
}
logger.Debug("Loaded config", zap.Any("config", envCfg))
}
func identity[T any](in T) (T, error) {
return in, nil
func setupLogging() {
cfg := zap.NewProductionConfig()
cfg.Level = zap.NewAtomicLevelAt(logLevel)
logger, err := cfg.Build()
if err != nil {
log.Fatalf("Failed to setup logging: %v", err)
}
zap.ReplaceGlobals(logger)
}
func prepareFlags() error {
set := config.ConfigureFlags(&cfg)
set.Var(&logLevel, "log-level", "Log level to use")
set.StringVar(
&cfgFile,
"config",
config.LookupEnvOr[string]("NURSE_CONFIG", "", config.Identity[string]),
"Config file to load, if not set $HOME/.nurse.yaml, /etc/nurse/config.yaml and ./nurse.yaml are tried - optional",
)
return set.Parse(os.Args)
}

16
nurse.yaml Normal file
View file

@ -0,0 +1,16 @@
checkTimeout: 500ms
servers:
redis_local:
url: redis://localhost:6379/0
args:
MaxRetries: 3
endpoints:
/ready:
checkTimeout: 500ms
checks:
- redis.PING("redis_local")
/liveness:
checkTimeout: 100ms
checks:
- redis.GET("redis_local", "running") => String("ok")

View file

@ -1,35 +1,18 @@
package redis
import (
"fmt"
"strings"
"github.com/baez90/nurse/check"
"github.com/baez90/nurse/grammar"
)
var knownChecks = map[string]func() check.SystemChecker{
"ping": func() check.SystemChecker {
return new(PingCheck)
},
"get": func() check.SystemChecker {
return new(GetCheck)
},
}
func LookupCheck(c grammar.Check) (check.SystemChecker, error) {
var (
provider func() check.SystemChecker
ok bool
func Module() *check.Module {
m, _ := check.NewModule(
check.WithCheck("ping", check.FactoryFunc(func() check.SystemChecker {
return new(PingCheck)
})),
check.WithCheck("get", check.FactoryFunc(func() check.SystemChecker {
return new(GetCheck)
})),
)
if provider, ok = knownChecks[strings.ToLower(c.Initiator.Name)]; !ok {
return nil, fmt.Errorf("%w: %s", check.ErrNoSuchCheck, c.Initiator.Name)
}
chk := provider()
if err := chk.UnmarshalCheck(c); err != nil {
return nil, err
}
return chk, nil
return m
}

103
redis/checks_test.go Normal file
View file

@ -0,0 +1,103 @@
package redis_test
import (
"context"
"fmt"
"strings"
"testing"
redisCli "github.com/go-redis/redis/v8"
"github.com/google/uuid"
"github.com/maxatome/go-testdeep/td"
"github.com/baez90/nurse/config"
"github.com/baez90/nurse/grammar"
"github.com/baez90/nurse/redis"
)
func TestChecks_Execute(t *testing.T) {
t.Parallel()
redisModule := redis.Module()
tests := []struct {
name string
check string
setup func(tb testing.TB, cli redisCli.UniversalClient)
wantErr bool
}{
{
name: "Get value",
check: `redis.GET("%s", "some_key")`,
setup: func(tb testing.TB, cli redisCli.UniversalClient) {
tb.Helper()
td.CmpNoError(tb, cli.Set(context.Background(), "some_key", "some_value", 0).Err())
},
wantErr: false,
},
{
name: "Get value - validate value",
check: `redis.GET("%s", "some_key") => String("some_value")`,
setup: func(tb testing.TB, cli redisCli.UniversalClient) {
tb.Helper()
td.CmpNoError(tb, cli.Set(context.Background(), "some_key", "some_value", 0).Err())
},
wantErr: false,
},
{
name: "PING check",
check: `redis.PING("%s")`,
wantErr: false,
},
{
name: "PING check - with custom message",
check: `redis.PING("%s", "Hello, Redis!")`,
wantErr: false,
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
srv := PrepareRedisContainer(t)
serverName := uuid.NewString()
register := config.NewServerRegister()
if err := register.Register(serverName, *srv); err != nil {
t.Fatalf("DefaultLookup.Register() err = %v", err)
}
cli, err := redis.ClientForServer(srv)
if err != nil {
t.Fatalf("redis.ClientForServer() err = %v", err)
}
if tt.setup != nil {
tt.setup(t, cli)
}
if strings.Contains(tt.check, "%s") {
tt.check = fmt.Sprintf(tt.check, serverName)
}
parser, err := grammar.NewParser[grammar.Check]()
td.CmpNoError(t, err, "grammar.NewParser()")
parsedCheck, err := parser.Parse(tt.check)
td.CmpNoError(t, err, "parser.Parse()")
chk, err := redisModule.Lookup(*parsedCheck, register)
td.CmpNoError(t, err, "redis.LookupCheck()")
td.CmpNoError(t, chk.UnmarshalCheck(*parsedCheck, register), "get.UnmarshalCheck()")
td.CmpNoError(t, chk.Execute(context.Background()))
if tt.wantErr {
td.CmpError(t, chk.Execute(context.Background()))
} else {
td.CmpNoError(t, chk.Execute(context.Background()))
}
})
}
}

View file

@ -10,10 +10,10 @@ import (
"github.com/baez90/nurse/grammar"
)
func clientFromParam(p grammar.Param) (redis.UniversalClient, error) {
func clientFromParam(p grammar.Param, srvLookup config.ServerLookup) (redis.UniversalClient, error) {
if srvName, err := p.AsString(); err != nil {
return nil, err
} else if srv, err := config.DefaultLookup.Lookup(srvName); err != nil {
} else if srv, err := srvLookup.Lookup(srvName); err != nil {
return nil, err
} else if redisCli, err := ClientForServer(srv); err != nil {
return nil, err
@ -37,5 +37,10 @@ func ClientForServer(srv *config.Server) (redis.UniversalClient, error) {
return nil, err
}
if srv.Credentials != nil {
opts.Username = srv.Credentials.Username
opts.Password = *srv.Credentials.Password
}
return redis.NewUniversalClient(opts), nil
}

View file

@ -3,19 +3,16 @@ package redis_test
import (
"context"
"fmt"
"net/url"
"testing"
"time"
"github.com/testcontainers/testcontainers-go"
"github.com/testcontainers/testcontainers-go/wait"
"github.com/baez90/nurse/config"
)
func PrepareRedisContainer(tb testing.TB) *config.Server {
tb.Helper()
const redisPort = "6379/tcp"
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Minute)
@ -24,9 +21,9 @@ func PrepareRedisContainer(tb testing.TB) *config.Server {
container, err := testcontainers.GenericContainer(ctx, testcontainers.GenericContainerRequest{
ContainerRequest: testcontainers.ContainerRequest{
Image: "docker.io/redis:alpine",
Name: tb.Name(),
ExposedPorts: []string{redisPort},
WaitingFor: wait.ForListeningPort(redisPort),
SkipReaper: true,
AutoRemove: true,
},
Started: true,
Logger: testcontainers.TestLogger(tb),
@ -46,13 +43,8 @@ func PrepareRedisContainer(tb testing.TB) *config.Server {
tb.Fatalf("container.PortEndpoint() err = %v", err)
}
u, err := url.Parse(fmt.Sprintf("%s/0?MaxRetries=3", ep))
if err != nil {
tb.Fatalf("url.Parse() err = %v", err)
}
srv, err := config.ParseFromURL(u)
if err != nil {
srv := new(config.Server)
if err := srv.UnmarshalURL(fmt.Sprintf("%s/0?MaxRetries=3", ep)); err != nil {
tb.Fatalf("config.ParseFromURL() err = %v", err)
}
return srv

View file

@ -6,13 +6,11 @@ import (
"github.com/go-redis/redis/v8"
"github.com/baez90/nurse/check"
"github.com/baez90/nurse/config"
"github.com/baez90/nurse/grammar"
)
var (
_ check.SystemChecker = (*GetCheck)(nil)
_ grammar.CheckUnmarshaler = (*GetCheck)(nil)
)
var _ check.SystemChecker = (*GetCheck)(nil)
type GetCheck struct {
redis.UniversalClient
@ -30,7 +28,7 @@ func (g GetCheck) Execute(ctx context.Context) error {
return g.validators.Validate(cmd)
}
func (g *GetCheck) UnmarshalCheck(c grammar.Check) error {
func (g *GetCheck) UnmarshalCheck(c grammar.Check, lookup config.ServerLookup) error {
const serverAndKeyArgsNumber = 2
inst := c.Initiator
if err := grammar.ValidateParameterCount(inst.Params, serverAndKeyArgsNumber); err != nil {
@ -38,7 +36,7 @@ func (g *GetCheck) UnmarshalCheck(c grammar.Check) error {
}
var err error
if g.UniversalClient, err = clientFromParam(inst.Params[0]); err != nil {
if g.UniversalClient, err = clientFromParam(inst.Params[0], lookup); err != nil {
return err
}
@ -46,5 +44,9 @@ func (g *GetCheck) UnmarshalCheck(c grammar.Check) error {
return err
}
if g.validators, err = ValidatorsForFilters(c.Validators); err != nil {
return err
}
return nil
}

View file

@ -1,70 +0,0 @@
package redis_test
import (
"context"
"fmt"
"testing"
redisCli "github.com/go-redis/redis/v8"
"github.com/maxatome/go-testdeep/td"
"github.com/baez90/nurse/config"
"github.com/baez90/nurse/grammar"
"github.com/baez90/nurse/redis"
)
func TestGetCheck_Execute(t *testing.T) {
t.Parallel()
srv := PrepareRedisContainer(t)
if err := config.DefaultLookup.Register(t.Name(), *srv); err != nil {
t.Fatalf("DefaultLookup.Register() err = %v", err)
}
cli, err := redis.ClientForServer(srv)
if err != nil {
t.Fatalf("redis.ClientForServer() err = %v", err)
}
tests := []struct {
name string
check string
setup func(tb testing.TB, cli redisCli.UniversalClient)
wantErr bool
}{
{
name: "Get value",
check: fmt.Sprintf(`redis.GET("%s", "some_key")`, t.Name()),
setup: func(tb testing.TB, cli redisCli.UniversalClient) {
tb.Helper()
td.CmpNoError(tb, cli.Set(context.Background(), "some_key", "some_value", 0).Err())
},
wantErr: false,
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
if tt.setup != nil {
tt.setup(t, cli)
}
get := new(redis.GetCheck)
parser, err := grammar.NewParser[grammar.Check]()
td.CmpNoError(t, err, "grammar.NewParser()")
check, err := parser.Parse(tt.check)
td.CmpNoError(t, err, "parser.Parse()")
td.CmpNoError(t, get.UnmarshalCheck(*check), "get.UnmarshalCheck()")
td.CmpNoError(t, get.Execute(context.Background()))
if tt.wantErr {
td.CmpError(t, get.Execute(context.Background()))
} else {
td.CmpNoError(t, get.Execute(context.Background()))
}
})
}
}

View file

@ -7,13 +7,11 @@ import (
"github.com/go-redis/redis/v8"
"github.com/baez90/nurse/check"
"github.com/baez90/nurse/config"
"github.com/baez90/nurse/grammar"
)
var (
_ check.SystemChecker = (*PingCheck)(nil)
_ grammar.CheckUnmarshaler = (*PingCheck)(nil)
)
var _ check.SystemChecker = (*PingCheck)(nil)
type PingCheck struct {
redis.UniversalClient
@ -34,7 +32,7 @@ func (p PingCheck) Execute(ctx context.Context) error {
return nil
}
func (p *PingCheck) UnmarshalCheck(c grammar.Check) error {
func (p *PingCheck) UnmarshalCheck(c grammar.Check, lookup config.ServerLookup) error {
const (
serverOnlyArgCount = 1
serverAndMessageArgCount = 2
@ -54,7 +52,7 @@ func (p *PingCheck) UnmarshalCheck(c grammar.Check) error {
}
fallthrough
case serverOnlyArgCount:
if cli, err := clientFromParam(init.Params[0]); err != nil {
if cli, err := clientFromParam(init.Params[0], lookup); err != nil {
return err
} else {
p.UniversalClient = cli

View file

@ -1,61 +0,0 @@
package redis_test
import (
"context"
"fmt"
"testing"
"github.com/maxatome/go-testdeep/td"
"github.com/baez90/nurse/config"
"github.com/baez90/nurse/grammar"
"github.com/baez90/nurse/redis"
)
func TestPingCheck_Execute(t *testing.T) {
t.Parallel()
srv := PrepareRedisContainer(t)
if err := config.DefaultLookup.Register(t.Name(), *srv); err != nil {
t.Fatalf("DefaultLookup.Register() err = %v", err)
}
tests := []struct {
name string
check string
wantErr bool
}{
{
name: "PING check",
check: fmt.Sprintf(`redis.PING("%s")`, t.Name()),
wantErr: false,
},
// redis.PING("my-redis")
// redis.PING("my-redis", "Hello, Redis!")
// redis.GET("my-redis", "some-key") -> String("Hello")
// redis.PING("my-redis"); redis.GET("my-redis", "some-key") -> String("Hello")
{
name: "PING check - with custom message",
check: fmt.Sprintf(`redis.PING("%s", "Hello, Redis!")`, t.Name()),
wantErr: false,
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
ping := new(redis.PingCheck)
parser, err := grammar.NewParser[grammar.Check]()
td.CmpNoError(t, err, "grammar.NewParser()")
check, err := parser.Parse(tt.check)
td.CmpNoError(t, err, "parser.Parse()")
if err := ping.UnmarshalCheck(*check); err != nil {
t.Fatalf("UnmarshalCheck() err = %v", err)
}
if err := ping.Execute(context.Background()); (err != nil) != tt.wantErr {
t.Errorf("Execute() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}

View file

@ -3,23 +3,64 @@ package redis
import (
"errors"
"fmt"
"strings"
"github.com/go-redis/redis/v8"
"github.com/baez90/nurse/check"
"github.com/baez90/nurse/grammar"
)
var (
ErrNoSuchValidator = errors.New("no such validator")
_ CmdValidator = (ValidationChain)(nil)
_ CmdValidator = StringCmdValidator("")
_ CmdValidator = (*StringCmdValidator)(nil)
knownValidators = map[string]func() unmarshallableCmdValidator{
"string": func() unmarshallableCmdValidator {
return new(StringCmdValidator)
},
}
)
type (
CmdValidator interface {
Validate(cmder redis.Cmder) error
}
ValidationChain []CmdValidator
unmarshallableCmdValidator interface {
CmdValidator
check.CallUnmarshaler
}
)
func ValidatorsForFilters(filters *grammar.Filters) (ValidationChain, error) {
if filters == nil || filters.Chain == nil {
return ValidationChain{}, nil
}
chain := make(ValidationChain, 0, len(filters.Chain))
for i := range filters.Chain {
validationCall := filters.Chain[i]
if validatorProvider, ok := knownValidators[strings.ToLower(validationCall.Name)]; !ok {
return nil, fmt.Errorf("%w: %s", ErrNoSuchValidator, validationCall.Name)
} else {
validator := validatorProvider()
if err := validator.UnmarshalCall(validationCall); err != nil {
return nil, err
}
chain = append(chain, validator)
}
}
return chain, nil
}
type ValidationChain []CmdValidator
func (v ValidationChain) UnmarshalCall(grammar.Call) error {
return errors.New("cannot unmarshal chain")
}
func (v ValidationChain) Validate(cmder redis.Cmder) error {
for i := range v {
if err := v[i].Validate(cmder); err != nil {
@ -32,6 +73,20 @@ func (v ValidationChain) Validate(cmder redis.Cmder) error {
type StringCmdValidator string
func (s *StringCmdValidator) UnmarshalCall(c grammar.Call) error {
if err := grammar.ValidateParameterCount(c.Params, 1); err != nil {
return err
}
want, err := c.Params[0].AsString()
if err != nil {
return err
}
*s = StringCmdValidator(want)
return nil
}
func (s StringCmdValidator) Validate(cmder redis.Cmder) error {
if err := cmder.Err(); err != nil {
return err