Integrate into handlers

This commit is contained in:
Peter 2021-01-04 17:52:21 +01:00
parent 63607dfd67
commit 0468c93671
Signed by: prskr
GPG key ID: C1DB5D2E8DB512F9
10 changed files with 190 additions and 52 deletions

View file

@ -11,6 +11,7 @@ import (
"github.com/spf13/cobra" "github.com/spf13/cobra"
"gitlab.com/inetmock/inetmock/internal/endpoints" "gitlab.com/inetmock/inetmock/internal/endpoints"
"gitlab.com/inetmock/inetmock/pkg/api" "gitlab.com/inetmock/inetmock/pkg/api"
"gitlab.com/inetmock/inetmock/pkg/audit"
"gitlab.com/inetmock/inetmock/pkg/cert" "gitlab.com/inetmock/inetmock/pkg/cert"
"gitlab.com/inetmock/inetmock/pkg/config" "gitlab.com/inetmock/inetmock/pkg/config"
"gitlab.com/inetmock/inetmock/pkg/health" "gitlab.com/inetmock/inetmock/pkg/health"
@ -47,6 +48,7 @@ type app struct {
registry api.HandlerRegistry registry api.HandlerRegistry
ctx context.Context ctx context.Context
cancel context.CancelFunc cancel context.CancelFunc
eventStream audit.EventStream
} }
func (a *app) MustRun() { func (a *app) MustRun() {
@ -82,6 +84,10 @@ func (a app) EndpointManager() endpoints.EndpointManager {
return a.endpointManager return a.endpointManager
} }
func (a app) Audit() audit.Emitter {
return a.eventStream
}
func (a app) HandlerRegistry() api.HandlerRegistry { func (a app) HandlerRegistry() api.HandlerRegistry {
return a.registry return a.registry
} }
@ -137,15 +143,32 @@ func NewApp(registrations ...api.Registration) (inetmockApp App, err error) {
return return
} }
a.endpointManager = endpoints.NewEndpointManager(
a.registry,
a.Logger().Named("EndpointManager"),
a.checker,
a,
)
a.cfg = config.CreateConfig(cmd.Flags()) a.cfg = config.CreateConfig(cmd.Flags())
if err = a.cfg.ReadConfig(configFilePath); err != nil { if err = a.cfg.ReadConfig(configFilePath); err != nil {
return return
} }
a.certStore, err = cert.NewDefaultStore(a.cfg, a.rootLogger) if a.certStore, err = cert.NewDefaultStore(a.cfg, a.rootLogger); err != nil {
a.endpointManager = endpoints.NewEndpointManager(a.registry, a.Logger().Named("EndpointManager"), a.checker, a) return
}
a.eventStream, err = audit.NewEventStream(
a.Logger().Named("EventStream"),
audit.WithSinkBufferSize(10),
)
if err != nil {
return
}
err = a.eventStream.RegisterSink(audit.NewLogSink(a.Logger().Named("LogSink")))
return return
} }

View file

@ -4,6 +4,7 @@ package api
import ( import (
"context" "context"
"gitlab.com/inetmock/inetmock/pkg/audit"
"gitlab.com/inetmock/inetmock/pkg/cert" "gitlab.com/inetmock/inetmock/pkg/cert"
"gitlab.com/inetmock/inetmock/pkg/config" "gitlab.com/inetmock/inetmock/pkg/config"
"gitlab.com/inetmock/inetmock/pkg/logging" "gitlab.com/inetmock/inetmock/pkg/logging"
@ -12,6 +13,7 @@ import (
type PluginContext interface { type PluginContext interface {
Logger() logging.Logger Logger() logging.Logger
CertStore() cert.Store CertStore() cert.Store
Audit() audit.Emitter
} }
type ProtocolHandler interface { type ProtocolHandler interface {

View file

@ -16,13 +16,18 @@ func init() {
type eventStream struct { type eventStream struct {
logger logging.Logger logger logging.Logger
buffer chan *Event buffer chan *Event
sinks map[string]chan Event sinks map[string]*registeredSink
lock sync.Locker lock sync.Locker
idGenerator *snowflake.Node idGenerator *snowflake.Node
sinkBufferSize int sinkBufferSize int
sinkConsumptionTimeout time.Duration sinkConsumptionTimeout time.Duration
} }
type registeredSink struct {
downstream chan Event
lock sync.Locker
}
func NewEventStream(logger logging.Logger, options ...EventStreamOption) (es EventStream, err error) { func NewEventStream(logger logging.Logger, options ...EventStreamOption) (es EventStream, err error) {
cfg := newEventStreamCfg() cfg := newEventStreamCfg()
@ -39,7 +44,7 @@ func NewEventStream(logger logging.Logger, options ...EventStreamOption) (es Eve
generatorIdx++ generatorIdx++
underlying := &eventStream{ underlying := &eventStream{
logger: logger, logger: logger,
sinks: make(map[string]chan Event), sinks: make(map[string]*registeredSink),
buffer: make(chan *Event, cfg.bufferSize), buffer: make(chan *Event, cfg.bufferSize),
sinkBufferSize: cfg.sinkBuffersize, sinkBufferSize: cfg.sinkBuffersize,
sinkConsumptionTimeout: cfg.sinkConsumptionTimeout, sinkConsumptionTimeout: cfg.sinkConsumptionTimeout,
@ -47,7 +52,10 @@ func NewEventStream(logger logging.Logger, options ...EventStreamOption) (es Eve
lock: &sync.Mutex{}, lock: &sync.Mutex{},
} }
go underlying.distribute() // start distribute workers
for i := 0; i < cfg.distributeParallelization; i++ {
go underlying.distribute()
}
es = underlying es = underlying
@ -71,9 +79,12 @@ func (e *eventStream) RegisterSink(s Sink) error {
return ErrSinkAlreadyRegistered return ErrSinkAlreadyRegistered
} }
downstream := make(chan Event, e.sinkBufferSize) rs := &registeredSink{
s.OnSubscribe(downstream) downstream: make(chan Event, e.sinkBufferSize),
e.sinks[name] = downstream lock: new(sync.Mutex),
}
s.OnSubscribe(rs.downstream)
e.sinks[name] = rs
return nil return nil
} }
@ -86,22 +97,24 @@ func (e eventStream) Sinks() (sinks []string) {
func (e *eventStream) Close() error { func (e *eventStream) Close() error {
close(e.buffer) close(e.buffer)
for _, downstream := range e.sinks { for _, rs := range e.sinks {
close(downstream) close(rs.downstream)
} }
return nil return nil
} }
func (e *eventStream) distribute() { func (e *eventStream) distribute() {
for ev := range e.buffer { for ev := range e.buffer {
for name, s := range e.sinks { for name, rs := range e.sinks {
rs.lock.Lock()
e.logger.Debug("notify sink", zap.String("sink", name)) e.logger.Debug("notify sink", zap.String("sink", name))
select { select {
case s <- *ev: case rs.downstream <- *ev:
e.logger.Debug("pushed event to sink channel") e.logger.Debug("pushed event to sink channel")
case <-time.After(e.sinkConsumptionTimeout): case <-time.After(e.sinkConsumptionTimeout):
e.logger.Warn("sink consummation timed out") e.logger.Warn("sink consummation timed out")
} }
rs.lock.Unlock()
} }
} }
} }

View file

@ -169,7 +169,7 @@ func Test_eventStream_Emit(t *testing.T) {
name: "Expect to get a single event", name: "Expect to get a single event",
subscribe: true, subscribe: true,
args: args{ args: args{
opts: []audit.EventStreamOption{audit.WithBufferSize(10)}, opts: []audit.EventStreamOption{},
evs: testEvents[:1], evs: testEvents[:1],
}, },
}, },
@ -177,7 +177,7 @@ func Test_eventStream_Emit(t *testing.T) {
name: "Expect to get multiple events", name: "Expect to get multiple events",
subscribe: true, subscribe: true,
args: args{ args: args{
opts: []audit.EventStreamOption{audit.WithBufferSize(10)}, opts: []audit.EventStreamOption{},
evs: testEvents, evs: testEvents,
}, },
}, },

View file

@ -46,7 +46,6 @@ func (l logSink) OnSubscribe(evs <-chan Event) {
zap.Uint16("source_port", ev.SourcePort), zap.Uint16("source_port", ev.SourcePort),
zap.String("destination_ip", ev.DestinationIP.String()), zap.String("destination_ip", ev.DestinationIP.String()),
zap.Uint16("destination_port", ev.DestinationPort), zap.Uint16("destination_port", ev.DestinationPort),
zap.Any("details", ev.ProtocolDetails),
) )
} }
}(l.logger, evs) }(l.logger, evs)

View file

@ -1,6 +1,9 @@
package audit package audit
import "time" import (
"runtime"
"time"
)
const ( const (
defaultEventStreamBufferSize = 100 defaultEventStreamBufferSize = 100
@ -9,8 +12,9 @@ const (
) )
var ( var (
generatorIdx int64 = 1 generatorIdx int64 = 1
WithBufferSize = func(bufferSize int) EventStreamOption { defaultDistributeParallelization = runtime.NumCPU() / 2
WithBufferSize = func(bufferSize int) EventStreamOption {
return func(cfg *eventStreamCfg) { return func(cfg *eventStreamCfg) {
cfg.bufferSize = bufferSize cfg.bufferSize = bufferSize
} }
@ -30,23 +34,33 @@ var (
cfg.sinkConsumptionTimeout = timeout cfg.sinkConsumptionTimeout = timeout
} }
} }
WithDistributeParallelization = func(parallelization int) EventStreamOption {
return func(cfg *eventStreamCfg) {
if parallelization <= 0 || parallelization > runtime.NumCPU() {
return
}
cfg.distributeParallelization = parallelization
}
}
) )
type EventStreamOption func(cfg *eventStreamCfg) type EventStreamOption func(cfg *eventStreamCfg)
type eventStreamCfg struct { type eventStreamCfg struct {
bufferSize int bufferSize int
sinkBuffersize int sinkBuffersize int
generatorIndex int64 generatorIndex int64
sinkConsumptionTimeout time.Duration distributeParallelization int
sinkConsumptionTimeout time.Duration
} }
func newEventStreamCfg() eventStreamCfg { func newEventStreamCfg() eventStreamCfg {
cfg := eventStreamCfg{ cfg := eventStreamCfg{
generatorIndex: generatorIdx, generatorIndex: generatorIdx,
sinkBuffersize: defaultSinkBufferSize, sinkBuffersize: defaultSinkBufferSize,
bufferSize: defaultEventStreamBufferSize, bufferSize: defaultEventStreamBufferSize,
sinkConsumptionTimeout: defaultSinkConsumptionTimeout, sinkConsumptionTimeout: defaultSinkConsumptionTimeout,
distributeParallelization: defaultDistributeParallelization,
} }
generatorIdx++ generatorIdx++

View file

@ -0,0 +1,35 @@
package http_mock
import (
"context"
"net"
)
type httpContextKey string
const (
remoteAddrKey httpContextKey = "RemoteAddr"
localAddrKey httpContextKey = "LocalAddr"
)
func StoreConnPropertiesInContext(ctx context.Context, c net.Conn) context.Context {
ctx = context.WithValue(ctx, remoteAddrKey, c.RemoteAddr())
ctx = context.WithValue(ctx, localAddrKey, c.LocalAddr())
return ctx
}
func LocalAddr(ctx context.Context) net.Addr {
val := ctx.Value(localAddrKey)
if val == nil {
return nil
}
return val.(net.Addr)
}
func RemoteAddr(ctx context.Context) net.Addr {
val := ctx.Value(remoteAddrKey)
if val == nil {
return nil
}
return val.(net.Addr)
}

View file

@ -40,9 +40,14 @@ func (p *httpHandler) Start(ctx api.PluginContext, config config.HandlerConfig)
router := &RegexpHandler{ router := &RegexpHandler{
logger: p.logger, logger: p.logger,
emitter: ctx.Audit(),
handlerName: config.HandlerName, handlerName: config.HandlerName,
} }
p.server = &http.Server{Addr: config.ListenAddr(), Handler: router} p.server = &http.Server{
Addr: config.ListenAddr(),
Handler: router,
ConnContext: StoreConnPropertiesInContext,
}
for _, rule := range options.Rules { for _, rule := range options.Rules {
router.setupRoute(rule) router.setupRoute(rule)

View file

@ -14,6 +14,7 @@ import (
"github.com/golang/mock/gomock" "github.com/golang/mock/gomock"
"github.com/spf13/viper" "github.com/spf13/viper"
api_mock "gitlab.com/inetmock/inetmock/internal/mock/api" api_mock "gitlab.com/inetmock/inetmock/internal/mock/api"
audit_mock "gitlab.com/inetmock/inetmock/internal/mock/audit"
"gitlab.com/inetmock/inetmock/pkg/api" "gitlab.com/inetmock/inetmock/pkg/api"
"gitlab.com/inetmock/inetmock/pkg/config" "gitlab.com/inetmock/inetmock/pkg/config"
"gitlab.com/inetmock/inetmock/pkg/logging" "gitlab.com/inetmock/inetmock/pkg/logging"
@ -72,11 +73,18 @@ func setupHandler(b *testing.B, ctrl *gomock.Controller, listenPort uint16) (api
b.Error("handler not registered") b.Error("handler not registered")
} }
emitter := audit_mock.NewMockEmitter(ctrl)
emitter.EXPECT().Emit(gomock.Any()).AnyTimes()
mockApp := api_mock.NewMockPluginContext(ctrl) mockApp := api_mock.NewMockPluginContext(ctrl)
mockApp.EXPECT(). mockApp.EXPECT().
Logger(). Logger().
Return(testLogger) Return(testLogger)
mockApp.EXPECT().
Audit().
Return(emitter)
v := viper.New() v := viper.New()
v.Set("rules", []map[string]string{ v.Set("rules", []map[string]string{
{ {

View file

@ -1,13 +1,16 @@
package http_mock package http_mock
import ( import (
"bytes" "net"
"net/http" "net/http"
"strconv" "strconv"
"strings"
"github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus"
"gitlab.com/inetmock/inetmock/pkg/audit"
"gitlab.com/inetmock/inetmock/pkg/logging" "gitlab.com/inetmock/inetmock/pkg/logging"
"go.uber.org/zap" "go.uber.org/zap"
"google.golang.org/protobuf/types/known/anypb"
) )
type route struct { type route struct {
@ -19,16 +22,13 @@ type RegexpHandler struct {
handlerName string handlerName string
logger logging.Logger logger logging.Logger
routes []*route routes []*route
emitter audit.Emitter
} }
func (h *RegexpHandler) Handler(rule targetRule, handler http.Handler) { func (h *RegexpHandler) Handler(rule targetRule, handler http.Handler) {
h.routes = append(h.routes, &route{rule, handler}) h.routes = append(h.routes, &route{rule, handler})
} }
func (h *RegexpHandler) HandleFunc(rule targetRule, handler func(http.ResponseWriter, *http.Request)) {
h.routes = append(h.routes, &route{rule, http.HandlerFunc(handler)})
}
func (h *RegexpHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { func (h *RegexpHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
timer := prometheus.NewTimer(requestDurationHistogram.WithLabelValues(h.handlerName)) timer := prometheus.NewTimer(requestDurationHistogram.WithLabelValues(h.handlerName))
defer timer.ObserveDuration() defer timer.ObserveDuration()
@ -53,25 +53,64 @@ func (h *RegexpHandler) setupRoute(rule targetRule) {
zap.String("response", rule.Response()), zap.String("response", rule.Response()),
) )
h.Handler(rule, createHandlerForTarget(h.logger, rule.response)) h.Handler(rule, emittingFileHandler{
} emitter: h.emitter,
targetPath: rule.response,
func createHandlerForTarget(logger logging.Logger, targetPath string) http.Handler {
return http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
headerWriter := &bytes.Buffer{}
request.Header.Write(headerWriter)
logger.Info(
"Handling request",
zap.String("source", request.RemoteAddr),
zap.String("host", request.Host),
zap.String("method", request.Method),
zap.String("protocol", request.Proto),
zap.String("path", request.RequestURI),
zap.String("response", targetPath),
zap.Reflect("headers", request.Header),
)
http.ServeFile(writer, request, targetPath)
}) })
} }
type emittingFileHandler struct {
emitter audit.Emitter
targetPath string
}
func (f emittingFileHandler) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
f.emitter.Emit(eventFromRequest(request))
http.ServeFile(writer, request, f.targetPath)
}
func eventFromRequest(request *http.Request) audit.Event {
details := audit.HTTPDetails{
Method: request.Method,
Host: request.Host,
URI: request.RequestURI,
Proto: request.Proto,
Headers: request.Header,
}
var wire *anypb.Any
if w, err := details.MarshalToWireFormat(); err == nil {
wire = w
}
localIP, localPort := ipPortFromAddr(LocalAddr(request.Context()))
remoteIP, remotePort := ipPortFromAddr(RemoteAddr(request.Context()))
return audit.Event{
Transport: audit.TransportProtocol_TCP,
Application: audit.AppProtocol_HTTP,
SourceIP: remoteIP,
DestinationIP: localIP,
SourcePort: remotePort,
DestinationPort: localPort,
ProtocolDetails: wire,
}
}
func ipPortFromAddr(addr net.Addr) (ip net.IP, port uint16) {
if tcpAddr, isTCPAddr := addr.(*net.TCPAddr); isTCPAddr {
return tcpAddr.IP, uint16(tcpAddr.Port)
}
ipPortSplit := strings.Split(addr.String(), ":")
if len(ipPortSplit) != 2 {
return
}
ip = net.ParseIP(ipPortSplit[0])
if p, err := strconv.Atoi(ipPortSplit[1]); err == nil {
port = uint16(p)
}
return
}