From 0468c9367174ce23abddc6eeb4f6c4de9c1734b4 Mon Sep 17 00:00:00 2001 From: Peter Kurfer Date: Mon, 4 Jan 2021 17:52:21 +0100 Subject: [PATCH] Integrate into handlers --- internal/app/app.go | 27 +++++++++- pkg/api/protocol_handler.go | 2 + pkg/audit/event_stream.go | 33 ++++++++---- pkg/audit/event_stream_test.go | 4 +- pkg/audit/log_sink.go | 1 - pkg/audit/options.go | 36 +++++++++---- plugins/http_mock/conn_context.go | 35 ++++++++++++ plugins/http_mock/handler.go | 7 ++- plugins/http_mock/handler_test.go | 8 +++ plugins/http_mock/regex_router.go | 89 ++++++++++++++++++++++--------- 10 files changed, 190 insertions(+), 52 deletions(-) create mode 100644 plugins/http_mock/conn_context.go diff --git a/internal/app/app.go b/internal/app/app.go index ce8d1ff..6c4adfe 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -11,6 +11,7 @@ import ( "github.com/spf13/cobra" "gitlab.com/inetmock/inetmock/internal/endpoints" "gitlab.com/inetmock/inetmock/pkg/api" + "gitlab.com/inetmock/inetmock/pkg/audit" "gitlab.com/inetmock/inetmock/pkg/cert" "gitlab.com/inetmock/inetmock/pkg/config" "gitlab.com/inetmock/inetmock/pkg/health" @@ -47,6 +48,7 @@ type app struct { registry api.HandlerRegistry ctx context.Context cancel context.CancelFunc + eventStream audit.EventStream } func (a *app) MustRun() { @@ -82,6 +84,10 @@ func (a app) EndpointManager() endpoints.EndpointManager { return a.endpointManager } +func (a app) Audit() audit.Emitter { + return a.eventStream +} + func (a app) HandlerRegistry() api.HandlerRegistry { return a.registry } @@ -137,15 +143,32 @@ func NewApp(registrations ...api.Registration) (inetmockApp App, err error) { return } + a.endpointManager = endpoints.NewEndpointManager( + a.registry, + a.Logger().Named("EndpointManager"), + a.checker, + a, + ) + a.cfg = config.CreateConfig(cmd.Flags()) if err = a.cfg.ReadConfig(configFilePath); err != nil { return } - a.certStore, err = cert.NewDefaultStore(a.cfg, a.rootLogger) - a.endpointManager = endpoints.NewEndpointManager(a.registry, a.Logger().Named("EndpointManager"), a.checker, a) + if a.certStore, err = cert.NewDefaultStore(a.cfg, a.rootLogger); err != nil { + return + } + a.eventStream, err = audit.NewEventStream( + a.Logger().Named("EventStream"), + audit.WithSinkBufferSize(10), + ) + if err != nil { + return + } + + err = a.eventStream.RegisterSink(audit.NewLogSink(a.Logger().Named("LogSink"))) return } diff --git a/pkg/api/protocol_handler.go b/pkg/api/protocol_handler.go index 540ae77..f927570 100644 --- a/pkg/api/protocol_handler.go +++ b/pkg/api/protocol_handler.go @@ -4,6 +4,7 @@ package api import ( "context" + "gitlab.com/inetmock/inetmock/pkg/audit" "gitlab.com/inetmock/inetmock/pkg/cert" "gitlab.com/inetmock/inetmock/pkg/config" "gitlab.com/inetmock/inetmock/pkg/logging" @@ -12,6 +13,7 @@ import ( type PluginContext interface { Logger() logging.Logger CertStore() cert.Store + Audit() audit.Emitter } type ProtocolHandler interface { diff --git a/pkg/audit/event_stream.go b/pkg/audit/event_stream.go index bc22331..51f6c7c 100644 --- a/pkg/audit/event_stream.go +++ b/pkg/audit/event_stream.go @@ -16,13 +16,18 @@ func init() { type eventStream struct { logger logging.Logger buffer chan *Event - sinks map[string]chan Event + sinks map[string]*registeredSink lock sync.Locker idGenerator *snowflake.Node sinkBufferSize int sinkConsumptionTimeout time.Duration } +type registeredSink struct { + downstream chan Event + lock sync.Locker +} + func NewEventStream(logger logging.Logger, options ...EventStreamOption) (es EventStream, err error) { cfg := newEventStreamCfg() @@ -39,7 +44,7 @@ func NewEventStream(logger logging.Logger, options ...EventStreamOption) (es Eve generatorIdx++ underlying := &eventStream{ logger: logger, - sinks: make(map[string]chan Event), + sinks: make(map[string]*registeredSink), buffer: make(chan *Event, cfg.bufferSize), sinkBufferSize: cfg.sinkBuffersize, sinkConsumptionTimeout: cfg.sinkConsumptionTimeout, @@ -47,7 +52,10 @@ func NewEventStream(logger logging.Logger, options ...EventStreamOption) (es Eve lock: &sync.Mutex{}, } - go underlying.distribute() + // start distribute workers + for i := 0; i < cfg.distributeParallelization; i++ { + go underlying.distribute() + } es = underlying @@ -71,9 +79,12 @@ func (e *eventStream) RegisterSink(s Sink) error { return ErrSinkAlreadyRegistered } - downstream := make(chan Event, e.sinkBufferSize) - s.OnSubscribe(downstream) - e.sinks[name] = downstream + rs := ®isteredSink{ + downstream: make(chan Event, e.sinkBufferSize), + lock: new(sync.Mutex), + } + s.OnSubscribe(rs.downstream) + e.sinks[name] = rs return nil } @@ -86,22 +97,24 @@ func (e eventStream) Sinks() (sinks []string) { func (e *eventStream) Close() error { close(e.buffer) - for _, downstream := range e.sinks { - close(downstream) + for _, rs := range e.sinks { + close(rs.downstream) } return nil } func (e *eventStream) distribute() { for ev := range e.buffer { - for name, s := range e.sinks { + for name, rs := range e.sinks { + rs.lock.Lock() e.logger.Debug("notify sink", zap.String("sink", name)) select { - case s <- *ev: + case rs.downstream <- *ev: e.logger.Debug("pushed event to sink channel") case <-time.After(e.sinkConsumptionTimeout): e.logger.Warn("sink consummation timed out") } + rs.lock.Unlock() } } } diff --git a/pkg/audit/event_stream_test.go b/pkg/audit/event_stream_test.go index 17732e3..4e0cd70 100644 --- a/pkg/audit/event_stream_test.go +++ b/pkg/audit/event_stream_test.go @@ -169,7 +169,7 @@ func Test_eventStream_Emit(t *testing.T) { name: "Expect to get a single event", subscribe: true, args: args{ - opts: []audit.EventStreamOption{audit.WithBufferSize(10)}, + opts: []audit.EventStreamOption{}, evs: testEvents[:1], }, }, @@ -177,7 +177,7 @@ func Test_eventStream_Emit(t *testing.T) { name: "Expect to get multiple events", subscribe: true, args: args{ - opts: []audit.EventStreamOption{audit.WithBufferSize(10)}, + opts: []audit.EventStreamOption{}, evs: testEvents, }, }, diff --git a/pkg/audit/log_sink.go b/pkg/audit/log_sink.go index a6e07f1..c6ce90d 100644 --- a/pkg/audit/log_sink.go +++ b/pkg/audit/log_sink.go @@ -46,7 +46,6 @@ func (l logSink) OnSubscribe(evs <-chan Event) { zap.Uint16("source_port", ev.SourcePort), zap.String("destination_ip", ev.DestinationIP.String()), zap.Uint16("destination_port", ev.DestinationPort), - zap.Any("details", ev.ProtocolDetails), ) } }(l.logger, evs) diff --git a/pkg/audit/options.go b/pkg/audit/options.go index 24ee746..76b737b 100644 --- a/pkg/audit/options.go +++ b/pkg/audit/options.go @@ -1,6 +1,9 @@ package audit -import "time" +import ( + "runtime" + "time" +) const ( defaultEventStreamBufferSize = 100 @@ -9,8 +12,9 @@ const ( ) var ( - generatorIdx int64 = 1 - WithBufferSize = func(bufferSize int) EventStreamOption { + generatorIdx int64 = 1 + defaultDistributeParallelization = runtime.NumCPU() / 2 + WithBufferSize = func(bufferSize int) EventStreamOption { return func(cfg *eventStreamCfg) { cfg.bufferSize = bufferSize } @@ -30,23 +34,33 @@ var ( cfg.sinkConsumptionTimeout = timeout } } + WithDistributeParallelization = func(parallelization int) EventStreamOption { + return func(cfg *eventStreamCfg) { + if parallelization <= 0 || parallelization > runtime.NumCPU() { + return + } + cfg.distributeParallelization = parallelization + } + } ) type EventStreamOption func(cfg *eventStreamCfg) type eventStreamCfg struct { - bufferSize int - sinkBuffersize int - generatorIndex int64 - sinkConsumptionTimeout time.Duration + bufferSize int + sinkBuffersize int + generatorIndex int64 + distributeParallelization int + sinkConsumptionTimeout time.Duration } func newEventStreamCfg() eventStreamCfg { cfg := eventStreamCfg{ - generatorIndex: generatorIdx, - sinkBuffersize: defaultSinkBufferSize, - bufferSize: defaultEventStreamBufferSize, - sinkConsumptionTimeout: defaultSinkConsumptionTimeout, + generatorIndex: generatorIdx, + sinkBuffersize: defaultSinkBufferSize, + bufferSize: defaultEventStreamBufferSize, + sinkConsumptionTimeout: defaultSinkConsumptionTimeout, + distributeParallelization: defaultDistributeParallelization, } generatorIdx++ diff --git a/plugins/http_mock/conn_context.go b/plugins/http_mock/conn_context.go new file mode 100644 index 0000000..b22af86 --- /dev/null +++ b/plugins/http_mock/conn_context.go @@ -0,0 +1,35 @@ +package http_mock + +import ( + "context" + "net" +) + +type httpContextKey string + +const ( + remoteAddrKey httpContextKey = "RemoteAddr" + localAddrKey httpContextKey = "LocalAddr" +) + +func StoreConnPropertiesInContext(ctx context.Context, c net.Conn) context.Context { + ctx = context.WithValue(ctx, remoteAddrKey, c.RemoteAddr()) + ctx = context.WithValue(ctx, localAddrKey, c.LocalAddr()) + return ctx +} + +func LocalAddr(ctx context.Context) net.Addr { + val := ctx.Value(localAddrKey) + if val == nil { + return nil + } + return val.(net.Addr) +} + +func RemoteAddr(ctx context.Context) net.Addr { + val := ctx.Value(remoteAddrKey) + if val == nil { + return nil + } + return val.(net.Addr) +} diff --git a/plugins/http_mock/handler.go b/plugins/http_mock/handler.go index c87bd2a..f1acb23 100644 --- a/plugins/http_mock/handler.go +++ b/plugins/http_mock/handler.go @@ -40,9 +40,14 @@ func (p *httpHandler) Start(ctx api.PluginContext, config config.HandlerConfig) router := &RegexpHandler{ logger: p.logger, + emitter: ctx.Audit(), handlerName: config.HandlerName, } - p.server = &http.Server{Addr: config.ListenAddr(), Handler: router} + p.server = &http.Server{ + Addr: config.ListenAddr(), + Handler: router, + ConnContext: StoreConnPropertiesInContext, + } for _, rule := range options.Rules { router.setupRoute(rule) diff --git a/plugins/http_mock/handler_test.go b/plugins/http_mock/handler_test.go index f60d031..804943a 100644 --- a/plugins/http_mock/handler_test.go +++ b/plugins/http_mock/handler_test.go @@ -14,6 +14,7 @@ import ( "github.com/golang/mock/gomock" "github.com/spf13/viper" api_mock "gitlab.com/inetmock/inetmock/internal/mock/api" + audit_mock "gitlab.com/inetmock/inetmock/internal/mock/audit" "gitlab.com/inetmock/inetmock/pkg/api" "gitlab.com/inetmock/inetmock/pkg/config" "gitlab.com/inetmock/inetmock/pkg/logging" @@ -72,11 +73,18 @@ func setupHandler(b *testing.B, ctrl *gomock.Controller, listenPort uint16) (api b.Error("handler not registered") } + emitter := audit_mock.NewMockEmitter(ctrl) + emitter.EXPECT().Emit(gomock.Any()).AnyTimes() + mockApp := api_mock.NewMockPluginContext(ctrl) mockApp.EXPECT(). Logger(). Return(testLogger) + mockApp.EXPECT(). + Audit(). + Return(emitter) + v := viper.New() v.Set("rules", []map[string]string{ { diff --git a/plugins/http_mock/regex_router.go b/plugins/http_mock/regex_router.go index 3de7ea4..9b643d1 100644 --- a/plugins/http_mock/regex_router.go +++ b/plugins/http_mock/regex_router.go @@ -1,13 +1,16 @@ package http_mock import ( - "bytes" + "net" "net/http" "strconv" + "strings" "github.com/prometheus/client_golang/prometheus" + "gitlab.com/inetmock/inetmock/pkg/audit" "gitlab.com/inetmock/inetmock/pkg/logging" "go.uber.org/zap" + "google.golang.org/protobuf/types/known/anypb" ) type route struct { @@ -19,16 +22,13 @@ type RegexpHandler struct { handlerName string logger logging.Logger routes []*route + emitter audit.Emitter } func (h *RegexpHandler) Handler(rule targetRule, handler http.Handler) { h.routes = append(h.routes, &route{rule, handler}) } -func (h *RegexpHandler) HandleFunc(rule targetRule, handler func(http.ResponseWriter, *http.Request)) { - h.routes = append(h.routes, &route{rule, http.HandlerFunc(handler)}) -} - func (h *RegexpHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { timer := prometheus.NewTimer(requestDurationHistogram.WithLabelValues(h.handlerName)) defer timer.ObserveDuration() @@ -53,25 +53,64 @@ func (h *RegexpHandler) setupRoute(rule targetRule) { zap.String("response", rule.Response()), ) - h.Handler(rule, createHandlerForTarget(h.logger, rule.response)) -} - -func createHandlerForTarget(logger logging.Logger, targetPath string) http.Handler { - return http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) { - headerWriter := &bytes.Buffer{} - request.Header.Write(headerWriter) - - logger.Info( - "Handling request", - zap.String("source", request.RemoteAddr), - zap.String("host", request.Host), - zap.String("method", request.Method), - zap.String("protocol", request.Proto), - zap.String("path", request.RequestURI), - zap.String("response", targetPath), - zap.Reflect("headers", request.Header), - ) - - http.ServeFile(writer, request, targetPath) + h.Handler(rule, emittingFileHandler{ + emitter: h.emitter, + targetPath: rule.response, }) } + +type emittingFileHandler struct { + emitter audit.Emitter + targetPath string +} + +func (f emittingFileHandler) ServeHTTP(writer http.ResponseWriter, request *http.Request) { + f.emitter.Emit(eventFromRequest(request)) + http.ServeFile(writer, request, f.targetPath) +} + +func eventFromRequest(request *http.Request) audit.Event { + details := audit.HTTPDetails{ + Method: request.Method, + Host: request.Host, + URI: request.RequestURI, + Proto: request.Proto, + Headers: request.Header, + } + var wire *anypb.Any + if w, err := details.MarshalToWireFormat(); err == nil { + wire = w + } + + localIP, localPort := ipPortFromAddr(LocalAddr(request.Context())) + remoteIP, remotePort := ipPortFromAddr(RemoteAddr(request.Context())) + + return audit.Event{ + Transport: audit.TransportProtocol_TCP, + Application: audit.AppProtocol_HTTP, + SourceIP: remoteIP, + DestinationIP: localIP, + SourcePort: remotePort, + DestinationPort: localPort, + ProtocolDetails: wire, + } +} + +func ipPortFromAddr(addr net.Addr) (ip net.IP, port uint16) { + if tcpAddr, isTCPAddr := addr.(*net.TCPAddr); isTCPAddr { + return tcpAddr.IP, uint16(tcpAddr.Port) + } + + ipPortSplit := strings.Split(addr.String(), ":") + if len(ipPortSplit) != 2 { + return + } + + ip = net.ParseIP(ipPortSplit[0]) + + if p, err := strconv.Atoi(ipPortSplit[1]); err == nil { + port = uint16(p) + } + + return +}