feat: add retry-behavior based on attempts

This commit is contained in:
Peter 2022-08-07 12:22:21 +02:00
parent 2134fae875
commit af4498700d
Signed by: prskr
GPG key ID: C1DB5D2E8DB512F9
17 changed files with 172 additions and 47 deletions

View file

@ -1,7 +1,6 @@
package api package api
import ( import (
"context"
"net/http" "net/http"
"time" "time"
@ -12,18 +11,14 @@ var _ http.Handler = (*CheckHandler)(nil)
type CheckHandler struct { type CheckHandler struct {
Timeout time.Duration Timeout time.Duration
Attempts uint
Check check.SystemChecker Check check.SystemChecker
} }
func (c CheckHandler) ServeHTTP(writer http.ResponseWriter, request *http.Request) { func (c CheckHandler) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
var ( ctx, cancel := check.AttemptsContext(request.Context(), c.Attempts, c.Timeout)
ctx = request.Context()
cancel context.CancelFunc
)
if c.Timeout != 0 {
ctx, cancel = context.WithTimeout(ctx, c.Timeout)
defer cancel() defer cancel()
}
if err := c.Check.Execute(ctx); err != nil { if err := c.Check.Execute(ctx); err != nil {
writer.WriteHeader(http.StatusServiceUnavailable) writer.WriteHeader(http.StatusServiceUnavailable)
_, _ = writer.Write([]byte(err.Error())) _, _ = writer.Write([]byte(err.Error()))

View file

@ -23,6 +23,7 @@ func PrepareMux(instance *config.Nurse, modLookup check.ModuleLookup, srvLookup
mux.Handle(route.String(), CheckHandler{ mux.Handle(route.String(), CheckHandler{
Timeout: spec.Timeout(instance.CheckTimeout), Timeout: spec.Timeout(instance.CheckTimeout),
Attempts: spec.Attempts(instance.CheckAttempts),
Check: chk, Check: chk,
}) })
} }

View file

@ -19,9 +19,15 @@ type (
UnmarshalCheck(c grammar.Check, lookup config.ServerLookup) error UnmarshalCheck(c grammar.Check, lookup config.ServerLookup) error
} }
Context interface {
context.Context
AttemptContext() (context.Context, context.CancelFunc)
WithParent(ctx context.Context) Context
}
SystemChecker interface { SystemChecker interface {
Unmarshaler Unmarshaler
Execute(ctx context.Context) error Execute(ctx Context) error
} }
CallUnmarshaler interface { CallUnmarshaler interface {

View file

@ -1,8 +1,6 @@
package check package check
import ( import (
"context"
"golang.org/x/sync/errgroup" "golang.org/x/sync/errgroup"
"code.1533b4dc0.de/prskr/nurse/config" "code.1533b4dc0.de/prskr/nurse/config"
@ -17,13 +15,15 @@ func (Collection) UnmarshalCheck(grammar.Check, config.ServerLookup) error {
panic("unmarshalling is not supported for a collection") panic("unmarshalling is not supported for a collection")
} }
func (c Collection) Execute(ctx context.Context) error { func (c Collection) Execute(ctx Context) error {
grp, grpCtx := errgroup.WithContext(ctx) grp, grpCtx := errgroup.WithContext(ctx)
chkCtx := ctx.WithParent(grpCtx)
for i := range c { for i := range c {
chk := c[i] chk := c[i]
grp.Go(func() error { grp.Go(func() error {
return chk.Execute(grpCtx) return chk.Execute(chkCtx)
}) })
} }

34
check/context.go Normal file
View file

@ -0,0 +1,34 @@
package check
import (
"context"
"time"
)
var _ Context = (*checkContext)(nil)
func AttemptsContext(parent context.Context, numberOfAttempts uint, attemptTimeout time.Duration) (*checkContext, context.CancelFunc) {
finalTimeout := time.Duration(numberOfAttempts) * attemptTimeout
base, cancel := context.WithTimeout(parent, finalTimeout)
return &checkContext{
Context: base,
attemptTimeout: attemptTimeout,
}, cancel
}
type checkContext struct {
attemptTimeout time.Duration
context.Context
}
func (c *checkContext) WithParent(ctx context.Context) Context {
return &checkContext{
Context: ctx,
attemptTimeout: c.attemptTimeout,
}
}
func (c *checkContext) AttemptContext() (context.Context, context.CancelFunc) {
return context.WithTimeout(c, c.attemptTimeout)
}

View file

@ -36,6 +36,7 @@ type Nurse struct {
Servers map[string]Server Servers map[string]Server
Endpoints map[Route]EndpointSpec Endpoints map[Route]EndpointSpec
CheckTimeout time.Duration CheckTimeout time.Duration
CheckAttempts uint
} }
func (n Nurse) ServerLookup() (*ServerRegister, error) { func (n Nurse) ServerLookup() (*ServerRegister, error) {
@ -57,6 +58,10 @@ func (n Nurse) Merge(other Nurse) Nurse {
n.CheckTimeout = other.CheckTimeout n.CheckTimeout = other.CheckTimeout
} }
if n.CheckAttempts == 0 {
n.CheckAttempts = other.CheckAttempts
}
for name, srv := range other.Servers { for name, srv := range other.Servers {
if _, ok := n.Servers[name]; !ok { if _, ok := n.Servers[name]; !ok {
n.Servers[name] = srv n.Servers[name] = srv

View file

@ -20,6 +20,7 @@ func (r Route) String() string {
type EndpointSpec struct { type EndpointSpec struct {
CheckTimeout time.Duration CheckTimeout time.Duration
CheckAttempts uint
Checks []grammar.Check Checks []grammar.Check
} }
@ -31,6 +32,14 @@ func (s EndpointSpec) Timeout(fallback time.Duration) time.Duration {
return fallback return fallback
} }
func (s EndpointSpec) Attempts(fallback uint) uint {
if s.CheckAttempts != 0 {
return s.CheckAttempts
}
return fallback
}
func (s *EndpointSpec) Parse(text string) error { func (s *EndpointSpec) Parse(text string) error {
parser, err := grammar.NewParser[grammar.Script]() parser, err := grammar.NewParser[grammar.Script]()
if err != nil { if err != nil {

View file

@ -3,10 +3,14 @@ package config
import ( import (
"flag" "flag"
"os" "os"
"strconv"
"time" "time"
) )
const defaultCheckTimeout = 500 * time.Millisecond const (
defaultCheckTimeout = 500 * time.Millisecond
defaultAttemptCount = 20
)
func ConfigureFlags(cfg *Nurse) *flag.FlagSet { func ConfigureFlags(cfg *Nurse) *flag.FlagSet {
set := flag.NewFlagSet("nurse", flag.ContinueOnError) set := flag.NewFlagSet("nurse", flag.ContinueOnError)
@ -18,6 +22,13 @@ func ConfigureFlags(cfg *Nurse) *flag.FlagSet {
"Timeout when running checks", "Timeout when running checks",
) )
set.UintVar(
&cfg.CheckAttempts,
"check-attempts",
LookupEnvOr("NURSE_CHECK_ATTEMPTS", defaultAttemptCount, parseUint),
"Number of attempts for a check",
)
return set return set
} }
@ -39,3 +50,11 @@ func LookupEnvOr[T any](envKey string, fallback T, parse func(envVal string) (T,
func Identity[T any](in T) (T, error) { func Identity[T any](in T) (T, error) {
return in, nil return in, nil
} }
func parseUint(val string) (uint, error) {
parsed, err := strconv.ParseUint(val, 10, 32)
if err != nil {
return 0, err
}
return uint(parsed), nil
}

View file

@ -8,9 +8,11 @@ import (
"net/http/httptest" "net/http/httptest"
"strings" "strings"
"testing" "testing"
"time"
"github.com/maxatome/go-testdeep/td" "github.com/maxatome/go-testdeep/td"
"code.1533b4dc0.de/prskr/nurse/check"
"code.1533b4dc0.de/prskr/nurse/grammar" "code.1533b4dc0.de/prskr/nurse/grammar"
httpcheck "code.1533b4dc0.de/prskr/nurse/protocols/http" httpcheck "code.1533b4dc0.de/prskr/nurse/protocols/http"
) )
@ -134,10 +136,13 @@ func TestChecks_Execute(t *testing.T) {
clientInjectable.SetClient(testServer.Client()) clientInjectable.SetClient(testServer.Client())
} }
ctx, cancel := check.AttemptsContext(context.Background(), 100, 100*time.Millisecond)
t.Cleanup(cancel)
if tt.wantErr { if tt.wantErr {
td.CmpError(t, chk.Execute(context.Background())) td.CmpError(t, chk.Execute(ctx))
} else { } else {
td.CmpNoError(t, chk.Execute(context.Background())) td.CmpNoError(t, chk.Execute(ctx))
} }
}) })
} }

View file

@ -2,7 +2,6 @@ package http
import ( import (
"bytes" "bytes"
"context"
"io" "io"
"net/http" "net/http"
@ -37,7 +36,8 @@ func (g *GenericCheck) SetClient(client *http.Client) {
g.Client = client g.Client = client
} }
func (g *GenericCheck) Execute(ctx context.Context) error { func (g *GenericCheck) Execute(ctx check.Context) error {
//TODO adopt
var body io.Reader var body io.Reader
if len(g.Body) > 0 { if len(g.Body) > 0 {
body = bytes.NewReader(g.Body) body = bytes.NewReader(g.Body)

View file

@ -5,11 +5,13 @@ import (
"fmt" "fmt"
"strings" "strings"
"testing" "testing"
"time"
redisCli "github.com/go-redis/redis/v8" redisCli "github.com/go-redis/redis/v8"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/maxatome/go-testdeep/td" "github.com/maxatome/go-testdeep/td"
"code.1533b4dc0.de/prskr/nurse/check"
"code.1533b4dc0.de/prskr/nurse/config" "code.1533b4dc0.de/prskr/nurse/config"
"code.1533b4dc0.de/prskr/nurse/grammar" "code.1533b4dc0.de/prskr/nurse/grammar"
"code.1533b4dc0.de/prskr/nurse/protocols/redis" "code.1533b4dc0.de/prskr/nurse/protocols/redis"
@ -90,10 +92,13 @@ func TestChecks_Execute(t *testing.T) {
chk, err := redisModule.Lookup(*parsedCheck, register) chk, err := redisModule.Lookup(*parsedCheck, register)
td.CmpNoError(t, err, "redis.LookupCheck()") td.CmpNoError(t, err, "redis.LookupCheck()")
ctx, cancel := check.AttemptsContext(context.Background(), 100, 500*time.Millisecond)
t.Cleanup(cancel)
if tt.wantErr { if tt.wantErr {
td.CmpError(t, chk.Execute(context.Background())) td.CmpError(t, chk.Execute(ctx))
} else { } else {
td.CmpNoError(t, chk.Execute(context.Background())) td.CmpNoError(t, chk.Execute(ctx))
} }
}) })
} }

View file

@ -7,6 +7,7 @@ import (
"time" "time"
"github.com/testcontainers/testcontainers-go" "github.com/testcontainers/testcontainers-go"
"github.com/testcontainers/testcontainers-go/wait"
"code.1533b4dc0.de/prskr/nurse/config" "code.1533b4dc0.de/prskr/nurse/config"
) )
@ -24,6 +25,7 @@ func PrepareRedisContainer(tb testing.TB) *config.Server {
ExposedPorts: []string{redisPort}, ExposedPorts: []string{redisPort},
SkipReaper: true, SkipReaper: true,
AutoRemove: true, AutoRemove: true,
WaitingFor: wait.ForListeningPort(redisPort),
}, },
Started: true, Started: true,
Logger: testcontainers.TestLogger(tb), Logger: testcontainers.TestLogger(tb),

View file

@ -19,7 +19,23 @@ type GetCheck struct {
Key string Key string
} }
func (g *GetCheck) Execute(ctx context.Context) error { func (g *GetCheck) Execute(ctx check.Context) error {
for {
select {
case <-ctx.Done():
return ctx.Err()
default:
attemptCtx, cancel := ctx.AttemptContext()
err := g.executeAttempt(attemptCtx)
cancel()
if err == nil {
return nil
}
}
}
}
func (g *GetCheck) executeAttempt(ctx context.Context) error {
cmd := g.Get(ctx, g.Key) cmd := g.Get(ctx, g.Key)
if err := cmd.Err(); err != nil { if err := cmd.Err(); err != nil {

View file

@ -20,10 +20,27 @@ type PingCheck struct {
Message string Message string
} }
func (p PingCheck) Execute(ctx context.Context) error { func (p PingCheck) Execute(ctx check.Context) error {
if p.Message == "" { if p.Message == "" {
return p.Ping(ctx).Err() return p.Ping(ctx).Err()
} }
for {
select {
case <-ctx.Done():
return ctx.Err()
default:
attemptCtx, cancel := ctx.AttemptContext()
err := p.executeAttempt(attemptCtx)
cancel()
if err == nil {
return nil
}
}
}
}
func (p PingCheck) executeAttempt(ctx context.Context) error {
if resp, err := p.Do(ctx, "PING", p.Message).Text(); err != nil { if resp, err := p.Do(ctx, "PING", p.Message).Text(); err != nil {
return err return err
} else if resp != p.Message { } else if resp != p.Message {

View file

@ -6,9 +6,11 @@ import (
"fmt" "fmt"
"strings" "strings"
"testing" "testing"
"time"
"github.com/maxatome/go-testdeep/td" "github.com/maxatome/go-testdeep/td"
"code.1533b4dc0.de/prskr/nurse/check"
"code.1533b4dc0.de/prskr/nurse/config" "code.1533b4dc0.de/prskr/nurse/config"
"code.1533b4dc0.de/prskr/nurse/grammar" "code.1533b4dc0.de/prskr/nurse/grammar"
sqlchk "code.1533b4dc0.de/prskr/nurse/protocols/sql" sqlchk "code.1533b4dc0.de/prskr/nurse/protocols/sql"
@ -92,10 +94,13 @@ func TestChecks_Execute(t *testing.T) {
chk, err := sqlModule.Lookup(*parsedCheck, register) chk, err := sqlModule.Lookup(*parsedCheck, register)
td.CmpNoError(t, err, "redis.LookupCheck()") td.CmpNoError(t, err, "redis.LookupCheck()")
ctx, cancel := check.AttemptsContext(context.Background(), 100, 500*time.Millisecond)
t.Cleanup(cancel)
if tt.wantErr { if tt.wantErr {
td.CmpError(t, chk.Execute(context.Background())) td.CmpError(t, chk.Execute(ctx))
} else { } else {
td.CmpNoError(t, chk.Execute(context.Background())) td.CmpNoError(t, chk.Execute(ctx))
} }
}) })
} }

View file

@ -38,12 +38,7 @@ func PreparePostgresContainer(tb testing.TB) (name string, cfg *config.Server) {
"POSTGRES_PASSWORD": dbPassword, "POSTGRES_PASSWORD": dbPassword,
"POSTGRES_DB": dbName, "POSTGRES_DB": dbName,
}, },
WaitingFor: wait.ForAll( WaitingFor: wait.ForListeningPort(postgresPort),
wait.ForListeningPort(postgresPort),
/*wait.ForSQL(postgresPort, "pgx", func(port nat.Port) string {
return fmt.Sprintf("postgres://%s:%s@localhost:%d/%s", dbUser, dbPassword, port.Int(), dbName)
}),*/
),
}, },
Started: true, Started: true,
Logger: testcontainers.TestLogger(tb), Logger: testcontainers.TestLogger(tb),
@ -93,12 +88,7 @@ func PrepareMariaDBContainer(tb testing.TB) (name string, cfg *config.Server) {
"MARIADB_RANDOM_ROOT_PASSWORD": "1", "MARIADB_RANDOM_ROOT_PASSWORD": "1",
"MARIADB_DATABASE": dbName, "MARIADB_DATABASE": dbName,
}, },
WaitingFor: wait.ForAll( WaitingFor: wait.ForListeningPort(mysqlPort),
wait.ForListeningPort(mysqlPort),
/* wait.ForSQL(mysqlPort, "mysql", func(port nat.Port) string {
return fmt.Sprintf("%s:%s@tcp(localhost:%d)/%s", dbUser, dbPassword, port.Int(), dbName)
}),*/
),
}, },
Started: true, Started: true,
Logger: testcontainers.TestLogger(tb), Logger: testcontainers.TestLogger(tb),

View file

@ -43,7 +43,23 @@ func (s *SelectCheck) UnmarshalCheck(c grammar.Check, lookup config.ServerLookup
return nil return nil
} }
func (s *SelectCheck) Execute(ctx context.Context) (err error) { func (s *SelectCheck) Execute(ctx check.Context) error {
for {
select {
case <-ctx.Done():
return ctx.Err()
default:
attemptCtx, cancel := ctx.AttemptContext()
err := s.executeAttempt(attemptCtx)
cancel()
if err == nil {
return nil
}
}
}
}
func (s *SelectCheck) executeAttempt(ctx context.Context) (err error) {
var rows *sql.Rows var rows *sql.Rows
rows, err = s.QueryContext(ctx, s.Query) rows, err = s.QueryContext(ctx, s.Query)
if err != nil { if err != nil {