From af4498700dd5ebe43c86fcdd0048c8758c9d5f24 Mon Sep 17 00:00:00 2001 From: Peter Kurfer Date: Sun, 7 Aug 2022 12:22:21 +0200 Subject: [PATCH] feat: add retry-behavior based on attempts --- api/check_handler.go | 17 ++++++---------- api/mux.go | 5 +++-- check/api.go | 8 +++++++- check/collection.go | 8 ++++---- check/context.go | 34 +++++++++++++++++++++++++++++++ config/app_config.go | 11 +++++++--- config/endpoints.go | 13 ++++++++++-- config/flags.go | 21 ++++++++++++++++++- protocols/http/checks_test.go | 9 ++++++-- protocols/http/get.go | 4 ++-- protocols/redis/checks_test.go | 9 ++++++-- protocols/redis/container_test.go | 2 ++ protocols/redis/get.go | 18 +++++++++++++++- protocols/redis/ping.go | 19 ++++++++++++++++- protocols/sql/checks_test.go | 9 ++++++-- protocols/sql/container_test.go | 14 ++----------- protocols/sql/select.go | 18 +++++++++++++++- 17 files changed, 172 insertions(+), 47 deletions(-) create mode 100644 check/context.go diff --git a/api/check_handler.go b/api/check_handler.go index ff2e319..2b169d6 100644 --- a/api/check_handler.go +++ b/api/check_handler.go @@ -1,7 +1,6 @@ package api import ( - "context" "net/http" "time" @@ -11,19 +10,15 @@ import ( var _ http.Handler = (*CheckHandler)(nil) type CheckHandler struct { - Timeout time.Duration - Check check.SystemChecker + Timeout time.Duration + Attempts uint + Check check.SystemChecker } func (c CheckHandler) ServeHTTP(writer http.ResponseWriter, request *http.Request) { - var ( - ctx = request.Context() - cancel context.CancelFunc - ) - if c.Timeout != 0 { - ctx, cancel = context.WithTimeout(ctx, c.Timeout) - defer cancel() - } + ctx, cancel := check.AttemptsContext(request.Context(), c.Attempts, c.Timeout) + defer cancel() + if err := c.Check.Execute(ctx); err != nil { writer.WriteHeader(http.StatusServiceUnavailable) _, _ = writer.Write([]byte(err.Error())) diff --git a/api/mux.go b/api/mux.go index 0101d52..beb4685 100644 --- a/api/mux.go +++ b/api/mux.go @@ -22,8 +22,9 @@ func PrepareMux(instance *config.Nurse, modLookup check.ModuleLookup, srvLookup } mux.Handle(route.String(), CheckHandler{ - Timeout: spec.Timeout(instance.CheckTimeout), - Check: chk, + Timeout: spec.Timeout(instance.CheckTimeout), + Attempts: spec.Attempts(instance.CheckAttempts), + Check: chk, }) } diff --git a/check/api.go b/check/api.go index c2ec900..32da261 100644 --- a/check/api.go +++ b/check/api.go @@ -19,9 +19,15 @@ type ( UnmarshalCheck(c grammar.Check, lookup config.ServerLookup) error } + Context interface { + context.Context + AttemptContext() (context.Context, context.CancelFunc) + WithParent(ctx context.Context) Context + } + SystemChecker interface { Unmarshaler - Execute(ctx context.Context) error + Execute(ctx Context) error } CallUnmarshaler interface { diff --git a/check/collection.go b/check/collection.go index ebf79bb..e472cef 100644 --- a/check/collection.go +++ b/check/collection.go @@ -1,8 +1,6 @@ package check import ( - "context" - "golang.org/x/sync/errgroup" "code.1533b4dc0.de/prskr/nurse/config" @@ -17,13 +15,15 @@ func (Collection) UnmarshalCheck(grammar.Check, config.ServerLookup) error { panic("unmarshalling is not supported for a collection") } -func (c Collection) Execute(ctx context.Context) error { +func (c Collection) Execute(ctx Context) error { grp, grpCtx := errgroup.WithContext(ctx) + chkCtx := ctx.WithParent(grpCtx) + for i := range c { chk := c[i] grp.Go(func() error { - return chk.Execute(grpCtx) + return chk.Execute(chkCtx) }) } diff --git a/check/context.go b/check/context.go new file mode 100644 index 0000000..c366a19 --- /dev/null +++ b/check/context.go @@ -0,0 +1,34 @@ +package check + +import ( + "context" + "time" +) + +var _ Context = (*checkContext)(nil) + +func AttemptsContext(parent context.Context, numberOfAttempts uint, attemptTimeout time.Duration) (*checkContext, context.CancelFunc) { + finalTimeout := time.Duration(numberOfAttempts) * attemptTimeout + base, cancel := context.WithTimeout(parent, finalTimeout) + + return &checkContext{ + Context: base, + attemptTimeout: attemptTimeout, + }, cancel +} + +type checkContext struct { + attemptTimeout time.Duration + context.Context +} + +func (c *checkContext) WithParent(ctx context.Context) Context { + return &checkContext{ + Context: ctx, + attemptTimeout: c.attemptTimeout, + } +} + +func (c *checkContext) AttemptContext() (context.Context, context.CancelFunc) { + return context.WithTimeout(c, c.attemptTimeout) +} diff --git a/config/app_config.go b/config/app_config.go index 3ccd014..f2ceb9c 100644 --- a/config/app_config.go +++ b/config/app_config.go @@ -33,9 +33,10 @@ func New(opts ...Option) (*Nurse, error) { } type Nurse struct { - Servers map[string]Server - Endpoints map[Route]EndpointSpec - CheckTimeout time.Duration + Servers map[string]Server + Endpoints map[Route]EndpointSpec + CheckTimeout time.Duration + CheckAttempts uint } func (n Nurse) ServerLookup() (*ServerRegister, error) { @@ -57,6 +58,10 @@ func (n Nurse) Merge(other Nurse) Nurse { n.CheckTimeout = other.CheckTimeout } + if n.CheckAttempts == 0 { + n.CheckAttempts = other.CheckAttempts + } + for name, srv := range other.Servers { if _, ok := n.Servers[name]; !ok { n.Servers[name] = srv diff --git a/config/endpoints.go b/config/endpoints.go index 15204ac..77211f3 100644 --- a/config/endpoints.go +++ b/config/endpoints.go @@ -19,8 +19,9 @@ func (r Route) String() string { } type EndpointSpec struct { - CheckTimeout time.Duration - Checks []grammar.Check + CheckTimeout time.Duration + CheckAttempts uint + Checks []grammar.Check } func (s EndpointSpec) Timeout(fallback time.Duration) time.Duration { @@ -31,6 +32,14 @@ func (s EndpointSpec) Timeout(fallback time.Duration) time.Duration { return fallback } +func (s EndpointSpec) Attempts(fallback uint) uint { + if s.CheckAttempts != 0 { + return s.CheckAttempts + } + + return fallback +} + func (s *EndpointSpec) Parse(text string) error { parser, err := grammar.NewParser[grammar.Script]() if err != nil { diff --git a/config/flags.go b/config/flags.go index e563a11..0c62061 100644 --- a/config/flags.go +++ b/config/flags.go @@ -3,10 +3,14 @@ package config import ( "flag" "os" + "strconv" "time" ) -const defaultCheckTimeout = 500 * time.Millisecond +const ( + defaultCheckTimeout = 500 * time.Millisecond + defaultAttemptCount = 20 +) func ConfigureFlags(cfg *Nurse) *flag.FlagSet { set := flag.NewFlagSet("nurse", flag.ContinueOnError) @@ -18,6 +22,13 @@ func ConfigureFlags(cfg *Nurse) *flag.FlagSet { "Timeout when running checks", ) + set.UintVar( + &cfg.CheckAttempts, + "check-attempts", + LookupEnvOr("NURSE_CHECK_ATTEMPTS", defaultAttemptCount, parseUint), + "Number of attempts for a check", + ) + return set } @@ -39,3 +50,11 @@ func LookupEnvOr[T any](envKey string, fallback T, parse func(envVal string) (T, func Identity[T any](in T) (T, error) { return in, nil } + +func parseUint(val string) (uint, error) { + parsed, err := strconv.ParseUint(val, 10, 32) + if err != nil { + return 0, err + } + return uint(parsed), nil +} diff --git a/protocols/http/checks_test.go b/protocols/http/checks_test.go index cb94bf3..b916274 100644 --- a/protocols/http/checks_test.go +++ b/protocols/http/checks_test.go @@ -8,9 +8,11 @@ import ( "net/http/httptest" "strings" "testing" + "time" "github.com/maxatome/go-testdeep/td" + "code.1533b4dc0.de/prskr/nurse/check" "code.1533b4dc0.de/prskr/nurse/grammar" httpcheck "code.1533b4dc0.de/prskr/nurse/protocols/http" ) @@ -134,10 +136,13 @@ func TestChecks_Execute(t *testing.T) { clientInjectable.SetClient(testServer.Client()) } + ctx, cancel := check.AttemptsContext(context.Background(), 100, 100*time.Millisecond) + t.Cleanup(cancel) + if tt.wantErr { - td.CmpError(t, chk.Execute(context.Background())) + td.CmpError(t, chk.Execute(ctx)) } else { - td.CmpNoError(t, chk.Execute(context.Background())) + td.CmpNoError(t, chk.Execute(ctx)) } }) } diff --git a/protocols/http/get.go b/protocols/http/get.go index 608b617..b1132f7 100644 --- a/protocols/http/get.go +++ b/protocols/http/get.go @@ -2,7 +2,6 @@ package http import ( "bytes" - "context" "io" "net/http" @@ -37,7 +36,8 @@ func (g *GenericCheck) SetClient(client *http.Client) { g.Client = client } -func (g *GenericCheck) Execute(ctx context.Context) error { +func (g *GenericCheck) Execute(ctx check.Context) error { + //TODO adopt var body io.Reader if len(g.Body) > 0 { body = bytes.NewReader(g.Body) diff --git a/protocols/redis/checks_test.go b/protocols/redis/checks_test.go index ad33c1b..8b808b5 100644 --- a/protocols/redis/checks_test.go +++ b/protocols/redis/checks_test.go @@ -5,11 +5,13 @@ import ( "fmt" "strings" "testing" + "time" redisCli "github.com/go-redis/redis/v8" "github.com/google/uuid" "github.com/maxatome/go-testdeep/td" + "code.1533b4dc0.de/prskr/nurse/check" "code.1533b4dc0.de/prskr/nurse/config" "code.1533b4dc0.de/prskr/nurse/grammar" "code.1533b4dc0.de/prskr/nurse/protocols/redis" @@ -90,10 +92,13 @@ func TestChecks_Execute(t *testing.T) { chk, err := redisModule.Lookup(*parsedCheck, register) td.CmpNoError(t, err, "redis.LookupCheck()") + ctx, cancel := check.AttemptsContext(context.Background(), 100, 500*time.Millisecond) + t.Cleanup(cancel) + if tt.wantErr { - td.CmpError(t, chk.Execute(context.Background())) + td.CmpError(t, chk.Execute(ctx)) } else { - td.CmpNoError(t, chk.Execute(context.Background())) + td.CmpNoError(t, chk.Execute(ctx)) } }) } diff --git a/protocols/redis/container_test.go b/protocols/redis/container_test.go index f992418..0ecfbd9 100644 --- a/protocols/redis/container_test.go +++ b/protocols/redis/container_test.go @@ -7,6 +7,7 @@ import ( "time" "github.com/testcontainers/testcontainers-go" + "github.com/testcontainers/testcontainers-go/wait" "code.1533b4dc0.de/prskr/nurse/config" ) @@ -24,6 +25,7 @@ func PrepareRedisContainer(tb testing.TB) *config.Server { ExposedPorts: []string{redisPort}, SkipReaper: true, AutoRemove: true, + WaitingFor: wait.ForListeningPort(redisPort), }, Started: true, Logger: testcontainers.TestLogger(tb), diff --git a/protocols/redis/get.go b/protocols/redis/get.go index 9384f99..d278446 100644 --- a/protocols/redis/get.go +++ b/protocols/redis/get.go @@ -19,7 +19,23 @@ type GetCheck struct { Key string } -func (g *GetCheck) Execute(ctx context.Context) error { +func (g *GetCheck) Execute(ctx check.Context) error { + for { + select { + case <-ctx.Done(): + return ctx.Err() + default: + attemptCtx, cancel := ctx.AttemptContext() + err := g.executeAttempt(attemptCtx) + cancel() + if err == nil { + return nil + } + } + } +} + +func (g *GetCheck) executeAttempt(ctx context.Context) error { cmd := g.Get(ctx, g.Key) if err := cmd.Err(); err != nil { diff --git a/protocols/redis/ping.go b/protocols/redis/ping.go index 3013550..a36999a 100644 --- a/protocols/redis/ping.go +++ b/protocols/redis/ping.go @@ -20,10 +20,27 @@ type PingCheck struct { Message string } -func (p PingCheck) Execute(ctx context.Context) error { +func (p PingCheck) Execute(ctx check.Context) error { if p.Message == "" { return p.Ping(ctx).Err() } + + for { + select { + case <-ctx.Done(): + return ctx.Err() + default: + attemptCtx, cancel := ctx.AttemptContext() + err := p.executeAttempt(attemptCtx) + cancel() + if err == nil { + return nil + } + } + } +} + +func (p PingCheck) executeAttempt(ctx context.Context) error { if resp, err := p.Do(ctx, "PING", p.Message).Text(); err != nil { return err } else if resp != p.Message { diff --git a/protocols/sql/checks_test.go b/protocols/sql/checks_test.go index 81fb65d..b590b35 100644 --- a/protocols/sql/checks_test.go +++ b/protocols/sql/checks_test.go @@ -6,9 +6,11 @@ import ( "fmt" "strings" "testing" + "time" "github.com/maxatome/go-testdeep/td" + "code.1533b4dc0.de/prskr/nurse/check" "code.1533b4dc0.de/prskr/nurse/config" "code.1533b4dc0.de/prskr/nurse/grammar" sqlchk "code.1533b4dc0.de/prskr/nurse/protocols/sql" @@ -92,10 +94,13 @@ func TestChecks_Execute(t *testing.T) { chk, err := sqlModule.Lookup(*parsedCheck, register) td.CmpNoError(t, err, "redis.LookupCheck()") + ctx, cancel := check.AttemptsContext(context.Background(), 100, 500*time.Millisecond) + t.Cleanup(cancel) + if tt.wantErr { - td.CmpError(t, chk.Execute(context.Background())) + td.CmpError(t, chk.Execute(ctx)) } else { - td.CmpNoError(t, chk.Execute(context.Background())) + td.CmpNoError(t, chk.Execute(ctx)) } }) } diff --git a/protocols/sql/container_test.go b/protocols/sql/container_test.go index 3f6a56c..1b21395 100644 --- a/protocols/sql/container_test.go +++ b/protocols/sql/container_test.go @@ -38,12 +38,7 @@ func PreparePostgresContainer(tb testing.TB) (name string, cfg *config.Server) { "POSTGRES_PASSWORD": dbPassword, "POSTGRES_DB": dbName, }, - WaitingFor: wait.ForAll( - wait.ForListeningPort(postgresPort), - /*wait.ForSQL(postgresPort, "pgx", func(port nat.Port) string { - return fmt.Sprintf("postgres://%s:%s@localhost:%d/%s", dbUser, dbPassword, port.Int(), dbName) - }),*/ - ), + WaitingFor: wait.ForListeningPort(postgresPort), }, Started: true, Logger: testcontainers.TestLogger(tb), @@ -93,12 +88,7 @@ func PrepareMariaDBContainer(tb testing.TB) (name string, cfg *config.Server) { "MARIADB_RANDOM_ROOT_PASSWORD": "1", "MARIADB_DATABASE": dbName, }, - WaitingFor: wait.ForAll( - wait.ForListeningPort(mysqlPort), - /* wait.ForSQL(mysqlPort, "mysql", func(port nat.Port) string { - return fmt.Sprintf("%s:%s@tcp(localhost:%d)/%s", dbUser, dbPassword, port.Int(), dbName) - }),*/ - ), + WaitingFor: wait.ForListeningPort(mysqlPort), }, Started: true, Logger: testcontainers.TestLogger(tb), diff --git a/protocols/sql/select.go b/protocols/sql/select.go index b4ef88c..d600b4c 100644 --- a/protocols/sql/select.go +++ b/protocols/sql/select.go @@ -43,7 +43,23 @@ func (s *SelectCheck) UnmarshalCheck(c grammar.Check, lookup config.ServerLookup return nil } -func (s *SelectCheck) Execute(ctx context.Context) (err error) { +func (s *SelectCheck) Execute(ctx check.Context) error { + for { + select { + case <-ctx.Done(): + return ctx.Err() + default: + attemptCtx, cancel := ctx.AttemptContext() + err := s.executeAttempt(attemptCtx) + cancel() + if err == nil { + return nil + } + } + } +} + +func (s *SelectCheck) executeAttempt(ctx context.Context) (err error) { var rows *sql.Rows rows, err = s.QueryContext(ctx, s.Query) if err != nil {