feat: add retry-behavior based on attempts
This commit is contained in:
parent
2134fae875
commit
af4498700d
17 changed files with 172 additions and 47 deletions
|
@ -1,7 +1,6 @@
|
|||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
|
@ -12,18 +11,14 @@ var _ http.Handler = (*CheckHandler)(nil)
|
|||
|
||||
// CheckHandler is an http.Handler that runs a single check per request and
// reports failure to the client.
type CheckHandler struct {
|
||||
// Timeout is handed to check.AttemptsContext as the per-attempt timeout.
Timeout time.Duration
|
||||
// Attempts is handed to check.AttemptsContext as the number of attempts
// that make up the overall time budget of a request.
Attempts uint
|
||||
// Check is the checker executed for every incoming request.
Check check.SystemChecker
|
||||
}
|
||||
|
||||
func (c CheckHandler) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
|
||||
var (
|
||||
ctx = request.Context()
|
||||
cancel context.CancelFunc
|
||||
)
|
||||
if c.Timeout != 0 {
|
||||
ctx, cancel = context.WithTimeout(ctx, c.Timeout)
|
||||
ctx, cancel := check.AttemptsContext(request.Context(), c.Attempts, c.Timeout)
|
||||
defer cancel()
|
||||
}
|
||||
|
||||
if err := c.Check.Execute(ctx); err != nil {
|
||||
writer.WriteHeader(http.StatusServiceUnavailable)
|
||||
_, _ = writer.Write([]byte(err.Error()))
|
||||
|
|
|
@ -23,6 +23,7 @@ func PrepareMux(instance *config.Nurse, modLookup check.ModuleLookup, srvLookup
|
|||
|
||||
mux.Handle(route.String(), CheckHandler{
|
||||
Timeout: spec.Timeout(instance.CheckTimeout),
|
||||
Attempts: spec.Attempts(instance.CheckAttempts),
|
||||
Check: chk,
|
||||
})
|
||||
}
|
||||
|
|
|
@ -19,9 +19,15 @@ type (
|
|||
UnmarshalCheck(c grammar.Check, lookup config.ServerLookup) error
|
||||
}
|
||||
|
||||
// Context is a context.Context that can additionally hand out contexts for
// individual attempts of a check and be re-based onto another parent.
Context interface {
|
||||
context.Context
|
||||
// AttemptContext returns a context and its cancel function intended to
// scope one attempt; callers in this package's checks cancel it right
// after the attempt finished.
AttemptContext() (context.Context, context.CancelFunc)
|
||||
// WithParent returns a Context that is backed by ctx instead of the
// current underlying context.
WithParent(ctx context.Context) Context
|
||||
}
|
||||
|
||||
SystemChecker interface {
|
||||
Unmarshaler
|
||||
Execute(ctx context.Context) error
|
||||
Execute(ctx Context) error
|
||||
}
|
||||
|
||||
CallUnmarshaler interface {
|
||||
|
|
|
@ -1,8 +1,6 @@
|
|||
package check
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"golang.org/x/sync/errgroup"
|
||||
|
||||
"code.1533b4dc0.de/prskr/nurse/config"
|
||||
|
@ -17,13 +15,15 @@ func (Collection) UnmarshalCheck(grammar.Check, config.ServerLookup) error {
|
|||
panic("unmarshalling is not supported for a collection")
|
||||
}
|
||||
|
||||
func (c Collection) Execute(ctx context.Context) error {
|
||||
func (c Collection) Execute(ctx Context) error {
|
||||
grp, grpCtx := errgroup.WithContext(ctx)
|
||||
|
||||
chkCtx := ctx.WithParent(grpCtx)
|
||||
|
||||
for i := range c {
|
||||
chk := c[i]
|
||||
grp.Go(func() error {
|
||||
return chk.Execute(grpCtx)
|
||||
return chk.Execute(chkCtx)
|
||||
})
|
||||
}
|
||||
|
||||
|
|
34
check/context.go
Normal file
34
check/context.go
Normal file
|
@ -0,0 +1,34 @@
|
|||
package check
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
)
|
||||
|
||||
var _ Context = (*checkContext)(nil)
|
||||
|
||||
func AttemptsContext(parent context.Context, numberOfAttempts uint, attemptTimeout time.Duration) (*checkContext, context.CancelFunc) {
|
||||
finalTimeout := time.Duration(numberOfAttempts) * attemptTimeout
|
||||
base, cancel := context.WithTimeout(parent, finalTimeout)
|
||||
|
||||
return &checkContext{
|
||||
Context: base,
|
||||
attemptTimeout: attemptTimeout,
|
||||
}, cancel
|
||||
}
|
||||
|
||||
type checkContext struct {
|
||||
attemptTimeout time.Duration
|
||||
context.Context
|
||||
}
|
||||
|
||||
func (c *checkContext) WithParent(ctx context.Context) Context {
|
||||
return &checkContext{
|
||||
Context: ctx,
|
||||
attemptTimeout: c.attemptTimeout,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *checkContext) AttemptContext() (context.Context, context.CancelFunc) {
|
||||
return context.WithTimeout(c, c.attemptTimeout)
|
||||
}
|
|
@ -36,6 +36,7 @@ type Nurse struct {
|
|||
Servers map[string]Server
|
||||
Endpoints map[Route]EndpointSpec
|
||||
CheckTimeout time.Duration
|
||||
CheckAttempts uint
|
||||
}
|
||||
|
||||
func (n Nurse) ServerLookup() (*ServerRegister, error) {
|
||||
|
@ -57,6 +58,10 @@ func (n Nurse) Merge(other Nurse) Nurse {
|
|||
n.CheckTimeout = other.CheckTimeout
|
||||
}
|
||||
|
||||
if n.CheckAttempts == 0 {
|
||||
n.CheckAttempts = other.CheckAttempts
|
||||
}
|
||||
|
||||
for name, srv := range other.Servers {
|
||||
if _, ok := n.Servers[name]; !ok {
|
||||
n.Servers[name] = srv
|
||||
|
|
|
@ -20,6 +20,7 @@ func (r Route) String() string {
|
|||
|
||||
// EndpointSpec holds the per-endpoint check settings.
type EndpointSpec struct {
|
||||
// CheckTimeout is the endpoint specific timeout; see Timeout for how a
// fallback is applied.
CheckTimeout time.Duration
|
||||
// CheckAttempts is the endpoint specific attempt count; zero means unset
// and makes Attempts return its fallback.
CheckAttempts uint
|
||||
// Checks are the checks belonging to this endpoint.
// NOTE(review): presumably filled by Parse - confirm.
Checks []grammar.Check
|
||||
}
|
||||
|
||||
|
@ -31,6 +32,14 @@ func (s EndpointSpec) Timeout(fallback time.Duration) time.Duration {
|
|||
return fallback
|
||||
}
|
||||
|
||||
func (s EndpointSpec) Attempts(fallback uint) uint {
|
||||
if s.CheckAttempts != 0 {
|
||||
return s.CheckAttempts
|
||||
}
|
||||
|
||||
return fallback
|
||||
}
|
||||
|
||||
func (s *EndpointSpec) Parse(text string) error {
|
||||
parser, err := grammar.NewParser[grammar.Script]()
|
||||
if err != nil {
|
||||
|
|
|
@ -3,10 +3,14 @@ package config
|
|||
import (
|
||||
"flag"
|
||||
"os"
|
||||
"strconv"
|
||||
"time"
|
||||
)
|
||||
|
||||
const defaultCheckTimeout = 500 * time.Millisecond
|
||||
// Defaults for the command line flags registered by ConfigureFlags.
const (
|
||||
// defaultCheckTimeout is the default timeout when running checks.
// NOTE(review): its use as flag default is outside the visible hunk - confirm.
defaultCheckTimeout = 500 * time.Millisecond
|
||||
// defaultAttemptCount is the fallback passed to LookupEnvOr for
// NURSE_CHECK_ATTEMPTS when registering the check-attempts flag.
defaultAttemptCount = 20
|
||||
)
|
||||
|
||||
func ConfigureFlags(cfg *Nurse) *flag.FlagSet {
|
||||
set := flag.NewFlagSet("nurse", flag.ContinueOnError)
|
||||
|
@ -18,6 +22,13 @@ func ConfigureFlags(cfg *Nurse) *flag.FlagSet {
|
|||
"Timeout when running checks",
|
||||
)
|
||||
|
||||
set.UintVar(
|
||||
&cfg.CheckAttempts,
|
||||
"check-attempts",
|
||||
LookupEnvOr("NURSE_CHECK_ATTEMPTS", defaultAttemptCount, parseUint),
|
||||
"Number of attempts for a check",
|
||||
)
|
||||
|
||||
return set
|
||||
}
|
||||
|
||||
|
@ -39,3 +50,11 @@ func LookupEnvOr[T any](envKey string, fallback T, parse func(envVal string) (T,
|
|||
// Identity is a parse function that performs no conversion: it hands back its
// input unchanged and never fails.
func Identity[T any](in T) (out T, err error) {
	out = in
	return out, nil
}
|
||||
|
||||
// parseUint parses val as a base-10 unsigned integer that fits into 32 bits.
// On failure it returns zero together with the error from strconv.ParseUint.
func parseUint(val string) (uint, error) {
	number, parseErr := strconv.ParseUint(val, 10, 32)
	if parseErr == nil {
		return uint(number), nil
	}

	return 0, parseErr
}
|
||||
|
|
|
@ -8,9 +8,11 @@ import (
|
|||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/maxatome/go-testdeep/td"
|
||||
|
||||
"code.1533b4dc0.de/prskr/nurse/check"
|
||||
"code.1533b4dc0.de/prskr/nurse/grammar"
|
||||
httpcheck "code.1533b4dc0.de/prskr/nurse/protocols/http"
|
||||
)
|
||||
|
@ -134,10 +136,13 @@ func TestChecks_Execute(t *testing.T) {
|
|||
clientInjectable.SetClient(testServer.Client())
|
||||
}
|
||||
|
||||
ctx, cancel := check.AttemptsContext(context.Background(), 100, 100*time.Millisecond)
|
||||
t.Cleanup(cancel)
|
||||
|
||||
if tt.wantErr {
|
||||
td.CmpError(t, chk.Execute(context.Background()))
|
||||
td.CmpError(t, chk.Execute(ctx))
|
||||
} else {
|
||||
td.CmpNoError(t, chk.Execute(context.Background()))
|
||||
td.CmpNoError(t, chk.Execute(ctx))
|
||||
}
|
||||
})
|
||||
}
|
||||
|
|
|
@ -2,7 +2,6 @@ package http
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"io"
|
||||
"net/http"
|
||||
|
||||
|
@ -37,7 +36,8 @@ func (g *GenericCheck) SetClient(client *http.Client) {
|
|||
g.Client = client
|
||||
}
|
||||
|
||||
func (g *GenericCheck) Execute(ctx context.Context) error {
|
||||
func (g *GenericCheck) Execute(ctx check.Context) error {
|
||||
// TODO: adopt the attempt-based retry loop (ctx.AttemptContext) that the redis and sql checks use
|
||||
var body io.Reader
|
||||
if len(g.Body) > 0 {
|
||||
body = bytes.NewReader(g.Body)
|
||||
|
|
|
@ -5,11 +5,13 @@ import (
|
|||
"fmt"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
redisCli "github.com/go-redis/redis/v8"
|
||||
"github.com/google/uuid"
|
||||
"github.com/maxatome/go-testdeep/td"
|
||||
|
||||
"code.1533b4dc0.de/prskr/nurse/check"
|
||||
"code.1533b4dc0.de/prskr/nurse/config"
|
||||
"code.1533b4dc0.de/prskr/nurse/grammar"
|
||||
"code.1533b4dc0.de/prskr/nurse/protocols/redis"
|
||||
|
@ -90,10 +92,13 @@ func TestChecks_Execute(t *testing.T) {
|
|||
chk, err := redisModule.Lookup(*parsedCheck, register)
|
||||
td.CmpNoError(t, err, "redis.LookupCheck()")
|
||||
|
||||
ctx, cancel := check.AttemptsContext(context.Background(), 100, 500*time.Millisecond)
|
||||
t.Cleanup(cancel)
|
||||
|
||||
if tt.wantErr {
|
||||
td.CmpError(t, chk.Execute(context.Background()))
|
||||
td.CmpError(t, chk.Execute(ctx))
|
||||
} else {
|
||||
td.CmpNoError(t, chk.Execute(context.Background()))
|
||||
td.CmpNoError(t, chk.Execute(ctx))
|
||||
}
|
||||
})
|
||||
}
|
||||
|
|
|
@ -7,6 +7,7 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/testcontainers/testcontainers-go"
|
||||
"github.com/testcontainers/testcontainers-go/wait"
|
||||
|
||||
"code.1533b4dc0.de/prskr/nurse/config"
|
||||
)
|
||||
|
@ -24,6 +25,7 @@ func PrepareRedisContainer(tb testing.TB) *config.Server {
|
|||
ExposedPorts: []string{redisPort},
|
||||
SkipReaper: true,
|
||||
AutoRemove: true,
|
||||
WaitingFor: wait.ForListeningPort(redisPort),
|
||||
},
|
||||
Started: true,
|
||||
Logger: testcontainers.TestLogger(tb),
|
||||
|
|
|
@ -19,7 +19,23 @@ type GetCheck struct {
|
|||
Key string
|
||||
}
|
||||
|
||||
func (g *GetCheck) Execute(ctx context.Context) error {
|
||||
func (g *GetCheck) Execute(ctx check.Context) error {
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
default:
|
||||
attemptCtx, cancel := ctx.AttemptContext()
|
||||
err := g.executeAttempt(attemptCtx)
|
||||
cancel()
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (g *GetCheck) executeAttempt(ctx context.Context) error {
|
||||
cmd := g.Get(ctx, g.Key)
|
||||
|
||||
if err := cmd.Err(); err != nil {
|
||||
|
|
|
@ -20,10 +20,27 @@ type PingCheck struct {
|
|||
Message string
|
||||
}
|
||||
|
||||
func (p PingCheck) Execute(ctx context.Context) error {
|
||||
func (p PingCheck) Execute(ctx check.Context) error {
|
||||
if p.Message == "" {
|
||||
return p.Ping(ctx).Err()
|
||||
}
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
default:
|
||||
attemptCtx, cancel := ctx.AttemptContext()
|
||||
err := p.executeAttempt(attemptCtx)
|
||||
cancel()
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (p PingCheck) executeAttempt(ctx context.Context) error {
|
||||
if resp, err := p.Do(ctx, "PING", p.Message).Text(); err != nil {
|
||||
return err
|
||||
} else if resp != p.Message {
|
||||
|
|
|
@ -6,9 +6,11 @@ import (
|
|||
"fmt"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/maxatome/go-testdeep/td"
|
||||
|
||||
"code.1533b4dc0.de/prskr/nurse/check"
|
||||
"code.1533b4dc0.de/prskr/nurse/config"
|
||||
"code.1533b4dc0.de/prskr/nurse/grammar"
|
||||
sqlchk "code.1533b4dc0.de/prskr/nurse/protocols/sql"
|
||||
|
@ -92,10 +94,13 @@ func TestChecks_Execute(t *testing.T) {
|
|||
chk, err := sqlModule.Lookup(*parsedCheck, register)
|
||||
td.CmpNoError(t, err, "redis.LookupCheck()")
|
||||
|
||||
ctx, cancel := check.AttemptsContext(context.Background(), 100, 500*time.Millisecond)
|
||||
t.Cleanup(cancel)
|
||||
|
||||
if tt.wantErr {
|
||||
td.CmpError(t, chk.Execute(context.Background()))
|
||||
td.CmpError(t, chk.Execute(ctx))
|
||||
} else {
|
||||
td.CmpNoError(t, chk.Execute(context.Background()))
|
||||
td.CmpNoError(t, chk.Execute(ctx))
|
||||
}
|
||||
})
|
||||
}
|
||||
|
|
|
@ -38,12 +38,7 @@ func PreparePostgresContainer(tb testing.TB) (name string, cfg *config.Server) {
|
|||
"POSTGRES_PASSWORD": dbPassword,
|
||||
"POSTGRES_DB": dbName,
|
||||
},
|
||||
WaitingFor: wait.ForAll(
|
||||
wait.ForListeningPort(postgresPort),
|
||||
/*wait.ForSQL(postgresPort, "pgx", func(port nat.Port) string {
|
||||
return fmt.Sprintf("postgres://%s:%s@localhost:%d/%s", dbUser, dbPassword, port.Int(), dbName)
|
||||
}),*/
|
||||
),
|
||||
WaitingFor: wait.ForListeningPort(postgresPort),
|
||||
},
|
||||
Started: true,
|
||||
Logger: testcontainers.TestLogger(tb),
|
||||
|
@ -93,12 +88,7 @@ func PrepareMariaDBContainer(tb testing.TB) (name string, cfg *config.Server) {
|
|||
"MARIADB_RANDOM_ROOT_PASSWORD": "1",
|
||||
"MARIADB_DATABASE": dbName,
|
||||
},
|
||||
WaitingFor: wait.ForAll(
|
||||
wait.ForListeningPort(mysqlPort),
|
||||
/* wait.ForSQL(mysqlPort, "mysql", func(port nat.Port) string {
|
||||
return fmt.Sprintf("%s:%s@tcp(localhost:%d)/%s", dbUser, dbPassword, port.Int(), dbName)
|
||||
}),*/
|
||||
),
|
||||
WaitingFor: wait.ForListeningPort(mysqlPort),
|
||||
},
|
||||
Started: true,
|
||||
Logger: testcontainers.TestLogger(tb),
|
||||
|
|
|
@ -43,7 +43,23 @@ func (s *SelectCheck) UnmarshalCheck(c grammar.Check, lookup config.ServerLookup
|
|||
return nil
|
||||
}
|
||||
|
||||
func (s *SelectCheck) Execute(ctx context.Context) (err error) {
|
||||
func (s *SelectCheck) Execute(ctx check.Context) error {
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
default:
|
||||
attemptCtx, cancel := ctx.AttemptContext()
|
||||
err := s.executeAttempt(attemptCtx)
|
||||
cancel()
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *SelectCheck) executeAttempt(ctx context.Context) (err error) {
|
||||
var rows *sql.Rows
|
||||
rows, err = s.QueryContext(ctx, s.Query)
|
||||
if err != nil {
|
||||
|
|
Loading…
Reference in a new issue