Integrate into handlers

This commit is contained in:
Peter 2021-01-04 17:52:21 +01:00
parent 63607dfd67
commit 0468c93671
Signed by: prskr
GPG key ID: C1DB5D2E8DB512F9
10 changed files with 190 additions and 52 deletions

View file

@ -11,6 +11,7 @@ import (
"github.com/spf13/cobra"
"gitlab.com/inetmock/inetmock/internal/endpoints"
"gitlab.com/inetmock/inetmock/pkg/api"
"gitlab.com/inetmock/inetmock/pkg/audit"
"gitlab.com/inetmock/inetmock/pkg/cert"
"gitlab.com/inetmock/inetmock/pkg/config"
"gitlab.com/inetmock/inetmock/pkg/health"
@ -47,6 +48,7 @@ type app struct {
registry api.HandlerRegistry
ctx context.Context
cancel context.CancelFunc
eventStream audit.EventStream
}
func (a *app) MustRun() {
@ -82,6 +84,10 @@ func (a app) EndpointManager() endpoints.EndpointManager {
return a.endpointManager
}
func (a app) Audit() audit.Emitter {
return a.eventStream
}
func (a app) HandlerRegistry() api.HandlerRegistry {
return a.registry
}
@ -137,15 +143,32 @@ func NewApp(registrations ...api.Registration) (inetmockApp App, err error) {
return
}
a.endpointManager = endpoints.NewEndpointManager(
a.registry,
a.Logger().Named("EndpointManager"),
a.checker,
a,
)
a.cfg = config.CreateConfig(cmd.Flags())
if err = a.cfg.ReadConfig(configFilePath); err != nil {
return
}
a.certStore, err = cert.NewDefaultStore(a.cfg, a.rootLogger)
a.endpointManager = endpoints.NewEndpointManager(a.registry, a.Logger().Named("EndpointManager"), a.checker, a)
if a.certStore, err = cert.NewDefaultStore(a.cfg, a.rootLogger); err != nil {
return
}
a.eventStream, err = audit.NewEventStream(
a.Logger().Named("EventStream"),
audit.WithSinkBufferSize(10),
)
if err != nil {
return
}
err = a.eventStream.RegisterSink(audit.NewLogSink(a.Logger().Named("LogSink")))
return
}

View file

@ -4,6 +4,7 @@ package api
import (
"context"
"gitlab.com/inetmock/inetmock/pkg/audit"
"gitlab.com/inetmock/inetmock/pkg/cert"
"gitlab.com/inetmock/inetmock/pkg/config"
"gitlab.com/inetmock/inetmock/pkg/logging"
@ -12,6 +13,7 @@ import (
type PluginContext interface {
Logger() logging.Logger
CertStore() cert.Store
Audit() audit.Emitter
}
type ProtocolHandler interface {

View file

@ -16,13 +16,18 @@ func init() {
type eventStream struct {
logger logging.Logger
buffer chan *Event
sinks map[string]chan Event
sinks map[string]*registeredSink
lock sync.Locker
idGenerator *snowflake.Node
sinkBufferSize int
sinkConsumptionTimeout time.Duration
}
type registeredSink struct {
downstream chan Event
lock sync.Locker
}
func NewEventStream(logger logging.Logger, options ...EventStreamOption) (es EventStream, err error) {
cfg := newEventStreamCfg()
@ -39,7 +44,7 @@ func NewEventStream(logger logging.Logger, options ...EventStreamOption) (es Eve
generatorIdx++
underlying := &eventStream{
logger: logger,
sinks: make(map[string]chan Event),
sinks: make(map[string]*registeredSink),
buffer: make(chan *Event, cfg.bufferSize),
sinkBufferSize: cfg.sinkBuffersize,
sinkConsumptionTimeout: cfg.sinkConsumptionTimeout,
@ -47,7 +52,10 @@ func NewEventStream(logger logging.Logger, options ...EventStreamOption) (es Eve
lock: &sync.Mutex{},
}
// start distribute workers
for i := 0; i < cfg.distributeParallelization; i++ {
go underlying.distribute()
}
es = underlying
@ -71,9 +79,12 @@ func (e *eventStream) RegisterSink(s Sink) error {
return ErrSinkAlreadyRegistered
}
downstream := make(chan Event, e.sinkBufferSize)
s.OnSubscribe(downstream)
e.sinks[name] = downstream
rs := &registeredSink{
downstream: make(chan Event, e.sinkBufferSize),
lock: new(sync.Mutex),
}
s.OnSubscribe(rs.downstream)
e.sinks[name] = rs
return nil
}
@ -86,22 +97,24 @@ func (e eventStream) Sinks() (sinks []string) {
func (e *eventStream) Close() error {
close(e.buffer)
for _, downstream := range e.sinks {
close(downstream)
for _, rs := range e.sinks {
close(rs.downstream)
}
return nil
}
func (e *eventStream) distribute() {
for ev := range e.buffer {
for name, s := range e.sinks {
for name, rs := range e.sinks {
rs.lock.Lock()
e.logger.Debug("notify sink", zap.String("sink", name))
select {
case s <- *ev:
case rs.downstream <- *ev:
e.logger.Debug("pushed event to sink channel")
case <-time.After(e.sinkConsumptionTimeout):
e.logger.Warn("sink consummation timed out")
}
rs.lock.Unlock()
}
}
}

View file

@ -169,7 +169,7 @@ func Test_eventStream_Emit(t *testing.T) {
name: "Expect to get a single event",
subscribe: true,
args: args{
opts: []audit.EventStreamOption{audit.WithBufferSize(10)},
opts: []audit.EventStreamOption{},
evs: testEvents[:1],
},
},
@ -177,7 +177,7 @@ func Test_eventStream_Emit(t *testing.T) {
name: "Expect to get multiple events",
subscribe: true,
args: args{
opts: []audit.EventStreamOption{audit.WithBufferSize(10)},
opts: []audit.EventStreamOption{},
evs: testEvents,
},
},

View file

@ -46,7 +46,6 @@ func (l logSink) OnSubscribe(evs <-chan Event) {
zap.Uint16("source_port", ev.SourcePort),
zap.String("destination_ip", ev.DestinationIP.String()),
zap.Uint16("destination_port", ev.DestinationPort),
zap.Any("details", ev.ProtocolDetails),
)
}
}(l.logger, evs)

View file

@ -1,6 +1,9 @@
package audit
import "time"
import (
"runtime"
"time"
)
const (
defaultEventStreamBufferSize = 100
@ -10,6 +13,7 @@ const (
var (
generatorIdx int64 = 1
defaultDistributeParallelization = runtime.NumCPU() / 2
WithBufferSize = func(bufferSize int) EventStreamOption {
return func(cfg *eventStreamCfg) {
cfg.bufferSize = bufferSize
@ -30,6 +34,14 @@ var (
cfg.sinkConsumptionTimeout = timeout
}
}
WithDistributeParallelization = func(parallelization int) EventStreamOption {
return func(cfg *eventStreamCfg) {
if parallelization <= 0 || parallelization > runtime.NumCPU() {
return
}
cfg.distributeParallelization = parallelization
}
}
)
type EventStreamOption func(cfg *eventStreamCfg)
@ -38,6 +50,7 @@ type eventStreamCfg struct {
bufferSize int
sinkBuffersize int
generatorIndex int64
distributeParallelization int
sinkConsumptionTimeout time.Duration
}
@ -47,6 +60,7 @@ func newEventStreamCfg() eventStreamCfg {
sinkBuffersize: defaultSinkBufferSize,
bufferSize: defaultEventStreamBufferSize,
sinkConsumptionTimeout: defaultSinkConsumptionTimeout,
distributeParallelization: defaultDistributeParallelization,
}
generatorIdx++

View file

@ -0,0 +1,35 @@
package http_mock
import (
"context"
"net"
)
type httpContextKey string
const (
remoteAddrKey httpContextKey = "RemoteAddr"
localAddrKey httpContextKey = "LocalAddr"
)
func StoreConnPropertiesInContext(ctx context.Context, c net.Conn) context.Context {
ctx = context.WithValue(ctx, remoteAddrKey, c.RemoteAddr())
ctx = context.WithValue(ctx, localAddrKey, c.LocalAddr())
return ctx
}
func LocalAddr(ctx context.Context) net.Addr {
val := ctx.Value(localAddrKey)
if val == nil {
return nil
}
return val.(net.Addr)
}
func RemoteAddr(ctx context.Context) net.Addr {
val := ctx.Value(remoteAddrKey)
if val == nil {
return nil
}
return val.(net.Addr)
}

View file

@ -40,9 +40,14 @@ func (p *httpHandler) Start(ctx api.PluginContext, config config.HandlerConfig)
router := &RegexpHandler{
logger: p.logger,
emitter: ctx.Audit(),
handlerName: config.HandlerName,
}
p.server = &http.Server{Addr: config.ListenAddr(), Handler: router}
p.server = &http.Server{
Addr: config.ListenAddr(),
Handler: router,
ConnContext: StoreConnPropertiesInContext,
}
for _, rule := range options.Rules {
router.setupRoute(rule)

View file

@ -14,6 +14,7 @@ import (
"github.com/golang/mock/gomock"
"github.com/spf13/viper"
api_mock "gitlab.com/inetmock/inetmock/internal/mock/api"
audit_mock "gitlab.com/inetmock/inetmock/internal/mock/audit"
"gitlab.com/inetmock/inetmock/pkg/api"
"gitlab.com/inetmock/inetmock/pkg/config"
"gitlab.com/inetmock/inetmock/pkg/logging"
@ -72,11 +73,18 @@ func setupHandler(b *testing.B, ctrl *gomock.Controller, listenPort uint16) (api
b.Error("handler not registered")
}
emitter := audit_mock.NewMockEmitter(ctrl)
emitter.EXPECT().Emit(gomock.Any()).AnyTimes()
mockApp := api_mock.NewMockPluginContext(ctrl)
mockApp.EXPECT().
Logger().
Return(testLogger)
mockApp.EXPECT().
Audit().
Return(emitter)
v := viper.New()
v.Set("rules", []map[string]string{
{

View file

@ -1,13 +1,16 @@
package http_mock
import (
"bytes"
"net"
"net/http"
"strconv"
"strings"
"github.com/prometheus/client_golang/prometheus"
"gitlab.com/inetmock/inetmock/pkg/audit"
"gitlab.com/inetmock/inetmock/pkg/logging"
"go.uber.org/zap"
"google.golang.org/protobuf/types/known/anypb"
)
type route struct {
@ -19,16 +22,13 @@ type RegexpHandler struct {
handlerName string
logger logging.Logger
routes []*route
emitter audit.Emitter
}
func (h *RegexpHandler) Handler(rule targetRule, handler http.Handler) {
h.routes = append(h.routes, &route{rule, handler})
}
func (h *RegexpHandler) HandleFunc(rule targetRule, handler func(http.ResponseWriter, *http.Request)) {
h.routes = append(h.routes, &route{rule, http.HandlerFunc(handler)})
}
func (h *RegexpHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
timer := prometheus.NewTimer(requestDurationHistogram.WithLabelValues(h.handlerName))
defer timer.ObserveDuration()
@ -53,25 +53,64 @@ func (h *RegexpHandler) setupRoute(rule targetRule) {
zap.String("response", rule.Response()),
)
h.Handler(rule, createHandlerForTarget(h.logger, rule.response))
}
func createHandlerForTarget(logger logging.Logger, targetPath string) http.Handler {
return http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
headerWriter := &bytes.Buffer{}
request.Header.Write(headerWriter)
logger.Info(
"Handling request",
zap.String("source", request.RemoteAddr),
zap.String("host", request.Host),
zap.String("method", request.Method),
zap.String("protocol", request.Proto),
zap.String("path", request.RequestURI),
zap.String("response", targetPath),
zap.Reflect("headers", request.Header),
)
http.ServeFile(writer, request, targetPath)
h.Handler(rule, emittingFileHandler{
emitter: h.emitter,
targetPath: rule.response,
})
}
type emittingFileHandler struct {
emitter audit.Emitter
targetPath string
}
func (f emittingFileHandler) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
f.emitter.Emit(eventFromRequest(request))
http.ServeFile(writer, request, f.targetPath)
}
func eventFromRequest(request *http.Request) audit.Event {
details := audit.HTTPDetails{
Method: request.Method,
Host: request.Host,
URI: request.RequestURI,
Proto: request.Proto,
Headers: request.Header,
}
var wire *anypb.Any
if w, err := details.MarshalToWireFormat(); err == nil {
wire = w
}
localIP, localPort := ipPortFromAddr(LocalAddr(request.Context()))
remoteIP, remotePort := ipPortFromAddr(RemoteAddr(request.Context()))
return audit.Event{
Transport: audit.TransportProtocol_TCP,
Application: audit.AppProtocol_HTTP,
SourceIP: remoteIP,
DestinationIP: localIP,
SourcePort: remotePort,
DestinationPort: localPort,
ProtocolDetails: wire,
}
}
func ipPortFromAddr(addr net.Addr) (ip net.IP, port uint16) {
if tcpAddr, isTCPAddr := addr.(*net.TCPAddr); isTCPAddr {
return tcpAddr.IP, uint16(tcpAddr.Port)
}
ipPortSplit := strings.Split(addr.String(), ":")
if len(ipPortSplit) != 2 {
return
}
ip = net.ParseIP(ipPortSplit[0])
if p, err := strconv.Atoi(ipPortSplit[1]); err == nil {
port = uint16(p)
}
return
}