api/internal/endpoint/endpoint_manager.go

181 lines
4.7 KiB
Go

package endpoint
import (
"context"
"fmt"
"sync"
"time"
"github.com/google/uuid"
"gitlab.com/inetmock/inetmock/pkg/api"
"gitlab.com/inetmock/inetmock/pkg/config"
"gitlab.com/inetmock/inetmock/pkg/health"
"gitlab.com/inetmock/inetmock/pkg/logging"
"go.uber.org/zap"
)
// startupTimeoutDuration bounds how long startEndpoint waits for an
// endpoint's Start call to report back before treating the startup as failed.
const startupTimeoutDuration = 100 * time.Millisecond
type EndpointManager interface {
RegisteredEndpoints() []Endpoint
StartedEndpoints() []Endpoint
CreateEndpoint(name string, multiHandlerConfig config.EndpointConfig) error
StartEndpoints()
ShutdownEndpoints()
}
// NewEndpointManager returns an EndpointManager that resolves handlers through
// registry, logs via logger, reports per-endpoint health to checker and passes
// pluginContext to every endpoint it starts.
//
// The logger parameter was previously named "logging", which shadowed the
// imported logging package inside the function body.
func NewEndpointManager(registry api.HandlerRegistry, logger logging.Logger, checker health.Checker, pluginContext api.PluginContext) EndpointManager {
	return &endpointManager{
		registry:      registry,
		logger:        logger,
		checker:       checker,
		pluginContext: pluginContext,
	}
}
// endpointManager is the default EndpointManager implementation.
//
// It keeps every endpoint created through CreateEndpoint and, separately, the
// subset that StartEndpoints managed to start. Only that subset is handed to
// ShutdownEndpoints.
//
// NOTE(review): no field is guarded by a mutex, so a single manager is
// presumably not meant for concurrent use — confirm with callers.
type endpointManager struct {
	registry      api.HandlerRegistry
	logger        logging.Logger
	checker       health.Checker
	pluginContext api.PluginContext

	registeredEndpoints, properlyStartedEndpoints []Endpoint
}
// RegisteredEndpoints returns every endpoint created via CreateEndpoint, in
// creation order. The returned slice aliases the manager's internal storage.
//
// Uses a pointer receiver like every other endpointManager method.
func (e *endpointManager) RegisteredEndpoints() []Endpoint {
	return e.registeredEndpoints
}
// StartedEndpoints returns the endpoints that StartEndpoints started
// successfully, in start order. The returned slice aliases the manager's
// internal storage.
//
// Uses a pointer receiver like every other endpointManager method.
func (e *endpointManager) StartedEndpoints() []Endpoint {
	return e.properlyStartedEndpoints
}
// CreateEndpoint registers one endpoint per handler config returned by
// endpointConfig.HandlerConfigs(). All of them share the given name and are
// backed by the handler registered under endpointConfig.Handler.
//
// It returns an error if the registry has no handler under that name. When
// HandlerConfigs() is empty, nothing is looked up and nil is returned.
func (e *endpointManager) CreateEndpoint(name string, endpointConfig config.EndpointConfig) error {
	for _, handlerConfig := range endpointConfig.HandlerConfigs() {
		// The lookup deliberately stays inside the loop: each endpoint gets
		// whatever HandlerForName returns for its own call.
		// NOTE(review): presumably a fresh instance per call — do not hoist
		// without confirming the registry semantics.
		handler, ok := e.registry.HandlerForName(endpointConfig.Handler)
		if !ok {
			return fmt.Errorf("no matching handler registered for name %q", endpointConfig.Handler)
		}
		e.registeredEndpoints = append(e.registeredEndpoints, &endpoint{
			id:      uuid.New(),
			name:    name,
			handler: handler,
			config:  handlerConfig,
		})
	}
	return nil
}
// StartEndpoints starts every registered endpoint in registration order.
//
// For each endpoint it registers a static health check reflecting the
// outcome: HEALTHY when startEndpoint succeeded, UNHEALTHY otherwise.
// Successfully started endpoints are remembered for ShutdownEndpoints. Failed
// ones are logged and skipped. The total startup time is logged at the end.
func (e *endpointManager) StartEndpoints() {
	startTime := time.Now()
	for _, endpoint := range e.registeredEndpoints {
		endpointLogger := e.logger.With(
			zap.String("endpoint", endpoint.Name()),
		)
		endpointLogger.Info("Starting endpoint")

		var registerErr error
		if ok := startEndpoint(endpoint, e.pluginContext, endpointLogger); ok {
			registerErr = e.checker.RegisterCheck(
				endpointComponentName(endpoint),
				health.StaticResultCheckWithMessage(health.HEALTHY, "Successfully started"),
			)
			e.properlyStartedEndpoints = append(e.properlyStartedEndpoints, endpoint)
			endpointLogger.Info("successfully started endpoint")
		} else {
			registerErr = e.checker.RegisterCheck(
				endpointComponentName(endpoint),
				health.StaticResultCheckWithMessage(health.UNHEALTHY, "failed to start"),
			)
			endpointLogger.Error("error occurred during endpoint startup - will be skipped for now")
		}
		// A failed check registration must not abort startup, but it should not
		// vanish silently either (previously discarded with `_ =`).
		if registerErr != nil {
			endpointLogger.Error(
				"failed to register health check for endpoint",
				zap.Error(registerErr),
			)
		}
	}
	endpointStartupDuration := time.Since(startTime)
	e.logger.Info(
		"Startup of all endpoints completed",
		zap.Duration("startupTime", endpointStartupDuration),
	)
}
// ShutdownEndpoints shuts down all successfully started endpoints concurrently
// and blocks until every shutdown call has returned.
//
// All shutdowns share a parent context bounded by shutdownTimeout. Each one
// also gets its own deadline from shutdownTimePerEndpoint.
// NOTE(review): that per-endpoint budget is split as if shutdowns ran one
// after another, although they run in parallel — confirm this is intended.
func (e *endpointManager) ShutdownEndpoints() {
	parentCtx, cancelParent := context.WithTimeout(context.Background(), shutdownTimeout)
	// Previously the cancel funcs were discarded, so the contexts' timers
	// lingered until their deadlines expired (go vet: lostcancel).
	defer cancelParent()

	perHandlerTimeout := e.shutdownTimePerEndpoint()
	waitGroup := new(sync.WaitGroup)
	waitGroup.Add(len(e.properlyStartedEndpoints))
	for _, endpoint := range e.properlyStartedEndpoints {
		ctx, cancel := context.WithTimeout(parentCtx, perHandlerTimeout)
		endpointLogger := e.logger.With(
			zap.String("endpoint", endpoint.Name()),
		)
		endpointLogger.Info("Triggering shutdown of endpoint")
		// Values are passed as arguments so each goroutine sees this
		// iteration's endpoint regardless of the Go version's loop semantics.
		go func(ctx context.Context, cancel context.CancelFunc, ep Endpoint, logger logging.Logger) {
			defer cancel()
			shutdownEndpoint(ctx, ep, logger, waitGroup)
		}(ctx, cancel, endpoint, endpointLogger)
	}
	waitGroup.Wait()
}
// startEndpoint calls ep.Start(ctx) on a separate goroutine and waits at most
// startupTimeoutDuration for the result.
//
// It returns true only if Start returned a nil error within the timeout.
// An error, a panic inside Start, or a timeout all yield false. A Start call
// that outlives the timeout keeps running in the background, and its late
// result is discarded.
func startEndpoint(ep Endpoint, ctx api.PluginContext, logger logging.Logger) bool {
	// Buffered and never closed. The worker can always deliver its single
	// result and exit, even after we stopped waiting. Previously the receiver
	// closed an unbuffered channel on timeout, so a late send from the worker
	// panicked with "send on closed channel".
	startSuccessful := make(chan bool, 1)
	go func() {
		defer func() {
			if r := recover(); r != nil {
				// Error rather than Fatal: Fatal terminates the process, which
				// defeats the recovery and made the send below unreachable.
				logger.Error(
					"recovered panic during startup of endpoint",
					zap.Any("recovered", r),
				)
				startSuccessful <- false
			}
		}()
		if err := ep.Start(ctx); err != nil {
			logger.Error(
				"failed to start endpoint",
				zap.Error(err),
			)
			startSuccessful <- false
			return
		}
		startSuccessful <- true
	}()

	timer := time.NewTimer(startupTimeoutDuration)
	defer timer.Stop()
	select {
	case success := <-startSuccessful:
		return success
	case <-timer.C:
		logger.Error(
			"endpoint startup timed out",
			zap.Duration("timeout", startupTimeoutDuration),
		)
		return false
	}
}
// shutdownEndpoint calls ep.Shutdown(ctx) and logs any returned error.
//
// A panic raised by Shutdown is recovered and logged so that one misbehaving
// endpoint cannot take down the shutdown of the others. wg.Done is always
// called exactly once.
func shutdownEndpoint(ctx context.Context, ep Endpoint, logger logging.Logger, wg *sync.WaitGroup) {
	// Registered first so it runs last, and runs even if the recovery handler
	// below misbehaves.
	defer wg.Done()
	defer func() {
		if r := recover(); r != nil {
			// Error rather than Fatal: Fatal terminates the process, which
			// would make recovering pointless.
			logger.Error(
				"recovered panic during shutdown of endpoint",
				zap.Any("recovered", r),
			)
		}
	}()
	if err := ep.Shutdown(ctx); err != nil {
		logger.Error(
			"Failed to shutdown endpoint",
			zap.Error(err),
		)
	}
}
// shutdownTimePerEndpoint splits 90% of shutdownTimeout evenly across the
// successfully started endpoints.
//
// With no started endpoints, the whole 90% budget is returned. The previous
// code divided by zero in that case, producing +Inf, and converting +Inf to
// time.Duration is implementation-defined per the Go spec.
func (e *endpointManager) shutdownTimePerEndpoint() time.Duration {
	budget := float64(shutdownTimeout) * 0.9
	if n := len(e.properlyStartedEndpoints); n > 0 {
		budget /= float64(n)
	}
	return time.Duration(budget)
}
// endpointComponentName derives the health-check component name of an
// endpoint by prefixing its name with "endpoint_".
func endpointComponentName(ep Endpoint) string {
	const componentPrefix = "endpoint_"
	// Both operands are strings, so fmt.Sprint inserts no separator.
	return fmt.Sprint(componentPrefix, ep.Name())
}