Integrate into handlers
This commit is contained in:
parent
63607dfd67
commit
0468c93671
10 changed files with 190 additions and 52 deletions
|
@ -11,6 +11,7 @@ import (
|
|||
"github.com/spf13/cobra"
|
||||
"gitlab.com/inetmock/inetmock/internal/endpoints"
|
||||
"gitlab.com/inetmock/inetmock/pkg/api"
|
||||
"gitlab.com/inetmock/inetmock/pkg/audit"
|
||||
"gitlab.com/inetmock/inetmock/pkg/cert"
|
||||
"gitlab.com/inetmock/inetmock/pkg/config"
|
||||
"gitlab.com/inetmock/inetmock/pkg/health"
|
||||
|
@ -47,6 +48,7 @@ type app struct {
|
|||
registry api.HandlerRegistry
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
eventStream audit.EventStream
|
||||
}
|
||||
|
||||
func (a *app) MustRun() {
|
||||
|
@ -82,6 +84,10 @@ func (a app) EndpointManager() endpoints.EndpointManager {
|
|||
return a.endpointManager
|
||||
}
|
||||
|
||||
func (a app) Audit() audit.Emitter {
|
||||
return a.eventStream
|
||||
}
|
||||
|
||||
func (a app) HandlerRegistry() api.HandlerRegistry {
|
||||
return a.registry
|
||||
}
|
||||
|
@ -137,15 +143,32 @@ func NewApp(registrations ...api.Registration) (inetmockApp App, err error) {
|
|||
return
|
||||
}
|
||||
|
||||
a.endpointManager = endpoints.NewEndpointManager(
|
||||
a.registry,
|
||||
a.Logger().Named("EndpointManager"),
|
||||
a.checker,
|
||||
a,
|
||||
)
|
||||
|
||||
a.cfg = config.CreateConfig(cmd.Flags())
|
||||
|
||||
if err = a.cfg.ReadConfig(configFilePath); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
a.certStore, err = cert.NewDefaultStore(a.cfg, a.rootLogger)
|
||||
a.endpointManager = endpoints.NewEndpointManager(a.registry, a.Logger().Named("EndpointManager"), a.checker, a)
|
||||
if a.certStore, err = cert.NewDefaultStore(a.cfg, a.rootLogger); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
a.eventStream, err = audit.NewEventStream(
|
||||
a.Logger().Named("EventStream"),
|
||||
audit.WithSinkBufferSize(10),
|
||||
)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
err = a.eventStream.RegisterSink(audit.NewLogSink(a.Logger().Named("LogSink")))
|
||||
return
|
||||
}
|
||||
|
||||
|
|
|
@ -4,6 +4,7 @@ package api
|
|||
import (
|
||||
"context"
|
||||
|
||||
"gitlab.com/inetmock/inetmock/pkg/audit"
|
||||
"gitlab.com/inetmock/inetmock/pkg/cert"
|
||||
"gitlab.com/inetmock/inetmock/pkg/config"
|
||||
"gitlab.com/inetmock/inetmock/pkg/logging"
|
||||
|
@ -12,6 +13,7 @@ import (
|
|||
type PluginContext interface {
|
||||
Logger() logging.Logger
|
||||
CertStore() cert.Store
|
||||
Audit() audit.Emitter
|
||||
}
|
||||
|
||||
type ProtocolHandler interface {
|
||||
|
|
|
@ -16,13 +16,18 @@ func init() {
|
|||
type eventStream struct {
|
||||
logger logging.Logger
|
||||
buffer chan *Event
|
||||
sinks map[string]chan Event
|
||||
sinks map[string]*registeredSink
|
||||
lock sync.Locker
|
||||
idGenerator *snowflake.Node
|
||||
sinkBufferSize int
|
||||
sinkConsumptionTimeout time.Duration
|
||||
}
|
||||
|
||||
type registeredSink struct {
|
||||
downstream chan Event
|
||||
lock sync.Locker
|
||||
}
|
||||
|
||||
func NewEventStream(logger logging.Logger, options ...EventStreamOption) (es EventStream, err error) {
|
||||
cfg := newEventStreamCfg()
|
||||
|
||||
|
@ -39,7 +44,7 @@ func NewEventStream(logger logging.Logger, options ...EventStreamOption) (es Eve
|
|||
generatorIdx++
|
||||
underlying := &eventStream{
|
||||
logger: logger,
|
||||
sinks: make(map[string]chan Event),
|
||||
sinks: make(map[string]*registeredSink),
|
||||
buffer: make(chan *Event, cfg.bufferSize),
|
||||
sinkBufferSize: cfg.sinkBuffersize,
|
||||
sinkConsumptionTimeout: cfg.sinkConsumptionTimeout,
|
||||
|
@ -47,7 +52,10 @@ func NewEventStream(logger logging.Logger, options ...EventStreamOption) (es Eve
|
|||
lock: &sync.Mutex{},
|
||||
}
|
||||
|
||||
go underlying.distribute()
|
||||
// start distribute workers
|
||||
for i := 0; i < cfg.distributeParallelization; i++ {
|
||||
go underlying.distribute()
|
||||
}
|
||||
|
||||
es = underlying
|
||||
|
||||
|
@ -71,9 +79,12 @@ func (e *eventStream) RegisterSink(s Sink) error {
|
|||
return ErrSinkAlreadyRegistered
|
||||
}
|
||||
|
||||
downstream := make(chan Event, e.sinkBufferSize)
|
||||
s.OnSubscribe(downstream)
|
||||
e.sinks[name] = downstream
|
||||
rs := ®isteredSink{
|
||||
downstream: make(chan Event, e.sinkBufferSize),
|
||||
lock: new(sync.Mutex),
|
||||
}
|
||||
s.OnSubscribe(rs.downstream)
|
||||
e.sinks[name] = rs
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -86,22 +97,24 @@ func (e eventStream) Sinks() (sinks []string) {
|
|||
|
||||
func (e *eventStream) Close() error {
|
||||
close(e.buffer)
|
||||
for _, downstream := range e.sinks {
|
||||
close(downstream)
|
||||
for _, rs := range e.sinks {
|
||||
close(rs.downstream)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e *eventStream) distribute() {
|
||||
for ev := range e.buffer {
|
||||
for name, s := range e.sinks {
|
||||
for name, rs := range e.sinks {
|
||||
rs.lock.Lock()
|
||||
e.logger.Debug("notify sink", zap.String("sink", name))
|
||||
select {
|
||||
case s <- *ev:
|
||||
case rs.downstream <- *ev:
|
||||
e.logger.Debug("pushed event to sink channel")
|
||||
case <-time.After(e.sinkConsumptionTimeout):
|
||||
e.logger.Warn("sink consummation timed out")
|
||||
}
|
||||
rs.lock.Unlock()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -169,7 +169,7 @@ func Test_eventStream_Emit(t *testing.T) {
|
|||
name: "Expect to get a single event",
|
||||
subscribe: true,
|
||||
args: args{
|
||||
opts: []audit.EventStreamOption{audit.WithBufferSize(10)},
|
||||
opts: []audit.EventStreamOption{},
|
||||
evs: testEvents[:1],
|
||||
},
|
||||
},
|
||||
|
@ -177,7 +177,7 @@ func Test_eventStream_Emit(t *testing.T) {
|
|||
name: "Expect to get multiple events",
|
||||
subscribe: true,
|
||||
args: args{
|
||||
opts: []audit.EventStreamOption{audit.WithBufferSize(10)},
|
||||
opts: []audit.EventStreamOption{},
|
||||
evs: testEvents,
|
||||
},
|
||||
},
|
||||
|
|
|
@ -46,7 +46,6 @@ func (l logSink) OnSubscribe(evs <-chan Event) {
|
|||
zap.Uint16("source_port", ev.SourcePort),
|
||||
zap.String("destination_ip", ev.DestinationIP.String()),
|
||||
zap.Uint16("destination_port", ev.DestinationPort),
|
||||
zap.Any("details", ev.ProtocolDetails),
|
||||
)
|
||||
}
|
||||
}(l.logger, evs)
|
||||
|
|
|
@ -1,6 +1,9 @@
|
|||
package audit
|
||||
|
||||
import "time"
|
||||
import (
|
||||
"runtime"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultEventStreamBufferSize = 100
|
||||
|
@ -9,8 +12,9 @@ const (
|
|||
)
|
||||
|
||||
var (
|
||||
generatorIdx int64 = 1
|
||||
WithBufferSize = func(bufferSize int) EventStreamOption {
|
||||
generatorIdx int64 = 1
|
||||
defaultDistributeParallelization = runtime.NumCPU() / 2
|
||||
WithBufferSize = func(bufferSize int) EventStreamOption {
|
||||
return func(cfg *eventStreamCfg) {
|
||||
cfg.bufferSize = bufferSize
|
||||
}
|
||||
|
@ -30,23 +34,33 @@ var (
|
|||
cfg.sinkConsumptionTimeout = timeout
|
||||
}
|
||||
}
|
||||
WithDistributeParallelization = func(parallelization int) EventStreamOption {
|
||||
return func(cfg *eventStreamCfg) {
|
||||
if parallelization <= 0 || parallelization > runtime.NumCPU() {
|
||||
return
|
||||
}
|
||||
cfg.distributeParallelization = parallelization
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
type EventStreamOption func(cfg *eventStreamCfg)
|
||||
|
||||
type eventStreamCfg struct {
|
||||
bufferSize int
|
||||
sinkBuffersize int
|
||||
generatorIndex int64
|
||||
sinkConsumptionTimeout time.Duration
|
||||
bufferSize int
|
||||
sinkBuffersize int
|
||||
generatorIndex int64
|
||||
distributeParallelization int
|
||||
sinkConsumptionTimeout time.Duration
|
||||
}
|
||||
|
||||
func newEventStreamCfg() eventStreamCfg {
|
||||
cfg := eventStreamCfg{
|
||||
generatorIndex: generatorIdx,
|
||||
sinkBuffersize: defaultSinkBufferSize,
|
||||
bufferSize: defaultEventStreamBufferSize,
|
||||
sinkConsumptionTimeout: defaultSinkConsumptionTimeout,
|
||||
generatorIndex: generatorIdx,
|
||||
sinkBuffersize: defaultSinkBufferSize,
|
||||
bufferSize: defaultEventStreamBufferSize,
|
||||
sinkConsumptionTimeout: defaultSinkConsumptionTimeout,
|
||||
distributeParallelization: defaultDistributeParallelization,
|
||||
}
|
||||
generatorIdx++
|
||||
|
||||
|
|
35
plugins/http_mock/conn_context.go
Normal file
35
plugins/http_mock/conn_context.go
Normal file
|
@ -0,0 +1,35 @@
|
|||
package http_mock
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
)
|
||||
|
||||
type httpContextKey string
|
||||
|
||||
const (
|
||||
remoteAddrKey httpContextKey = "RemoteAddr"
|
||||
localAddrKey httpContextKey = "LocalAddr"
|
||||
)
|
||||
|
||||
func StoreConnPropertiesInContext(ctx context.Context, c net.Conn) context.Context {
|
||||
ctx = context.WithValue(ctx, remoteAddrKey, c.RemoteAddr())
|
||||
ctx = context.WithValue(ctx, localAddrKey, c.LocalAddr())
|
||||
return ctx
|
||||
}
|
||||
|
||||
func LocalAddr(ctx context.Context) net.Addr {
|
||||
val := ctx.Value(localAddrKey)
|
||||
if val == nil {
|
||||
return nil
|
||||
}
|
||||
return val.(net.Addr)
|
||||
}
|
||||
|
||||
func RemoteAddr(ctx context.Context) net.Addr {
|
||||
val := ctx.Value(remoteAddrKey)
|
||||
if val == nil {
|
||||
return nil
|
||||
}
|
||||
return val.(net.Addr)
|
||||
}
|
|
@ -40,9 +40,14 @@ func (p *httpHandler) Start(ctx api.PluginContext, config config.HandlerConfig)
|
|||
|
||||
router := &RegexpHandler{
|
||||
logger: p.logger,
|
||||
emitter: ctx.Audit(),
|
||||
handlerName: config.HandlerName,
|
||||
}
|
||||
p.server = &http.Server{Addr: config.ListenAddr(), Handler: router}
|
||||
p.server = &http.Server{
|
||||
Addr: config.ListenAddr(),
|
||||
Handler: router,
|
||||
ConnContext: StoreConnPropertiesInContext,
|
||||
}
|
||||
|
||||
for _, rule := range options.Rules {
|
||||
router.setupRoute(rule)
|
||||
|
|
|
@ -14,6 +14,7 @@ import (
|
|||
"github.com/golang/mock/gomock"
|
||||
"github.com/spf13/viper"
|
||||
api_mock "gitlab.com/inetmock/inetmock/internal/mock/api"
|
||||
audit_mock "gitlab.com/inetmock/inetmock/internal/mock/audit"
|
||||
"gitlab.com/inetmock/inetmock/pkg/api"
|
||||
"gitlab.com/inetmock/inetmock/pkg/config"
|
||||
"gitlab.com/inetmock/inetmock/pkg/logging"
|
||||
|
@ -72,11 +73,18 @@ func setupHandler(b *testing.B, ctrl *gomock.Controller, listenPort uint16) (api
|
|||
b.Error("handler not registered")
|
||||
}
|
||||
|
||||
emitter := audit_mock.NewMockEmitter(ctrl)
|
||||
emitter.EXPECT().Emit(gomock.Any()).AnyTimes()
|
||||
|
||||
mockApp := api_mock.NewMockPluginContext(ctrl)
|
||||
mockApp.EXPECT().
|
||||
Logger().
|
||||
Return(testLogger)
|
||||
|
||||
mockApp.EXPECT().
|
||||
Audit().
|
||||
Return(emitter)
|
||||
|
||||
v := viper.New()
|
||||
v.Set("rules", []map[string]string{
|
||||
{
|
||||
|
|
|
@ -1,13 +1,16 @@
|
|||
package http_mock
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"net"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"gitlab.com/inetmock/inetmock/pkg/audit"
|
||||
"gitlab.com/inetmock/inetmock/pkg/logging"
|
||||
"go.uber.org/zap"
|
||||
"google.golang.org/protobuf/types/known/anypb"
|
||||
)
|
||||
|
||||
type route struct {
|
||||
|
@ -19,16 +22,13 @@ type RegexpHandler struct {
|
|||
handlerName string
|
||||
logger logging.Logger
|
||||
routes []*route
|
||||
emitter audit.Emitter
|
||||
}
|
||||
|
||||
func (h *RegexpHandler) Handler(rule targetRule, handler http.Handler) {
|
||||
h.routes = append(h.routes, &route{rule, handler})
|
||||
}
|
||||
|
||||
func (h *RegexpHandler) HandleFunc(rule targetRule, handler func(http.ResponseWriter, *http.Request)) {
|
||||
h.routes = append(h.routes, &route{rule, http.HandlerFunc(handler)})
|
||||
}
|
||||
|
||||
func (h *RegexpHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
timer := prometheus.NewTimer(requestDurationHistogram.WithLabelValues(h.handlerName))
|
||||
defer timer.ObserveDuration()
|
||||
|
@ -53,25 +53,64 @@ func (h *RegexpHandler) setupRoute(rule targetRule) {
|
|||
zap.String("response", rule.Response()),
|
||||
)
|
||||
|
||||
h.Handler(rule, createHandlerForTarget(h.logger, rule.response))
|
||||
}
|
||||
|
||||
func createHandlerForTarget(logger logging.Logger, targetPath string) http.Handler {
|
||||
return http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
|
||||
headerWriter := &bytes.Buffer{}
|
||||
request.Header.Write(headerWriter)
|
||||
|
||||
logger.Info(
|
||||
"Handling request",
|
||||
zap.String("source", request.RemoteAddr),
|
||||
zap.String("host", request.Host),
|
||||
zap.String("method", request.Method),
|
||||
zap.String("protocol", request.Proto),
|
||||
zap.String("path", request.RequestURI),
|
||||
zap.String("response", targetPath),
|
||||
zap.Reflect("headers", request.Header),
|
||||
)
|
||||
|
||||
http.ServeFile(writer, request, targetPath)
|
||||
h.Handler(rule, emittingFileHandler{
|
||||
emitter: h.emitter,
|
||||
targetPath: rule.response,
|
||||
})
|
||||
}
|
||||
|
||||
type emittingFileHandler struct {
|
||||
emitter audit.Emitter
|
||||
targetPath string
|
||||
}
|
||||
|
||||
func (f emittingFileHandler) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
|
||||
f.emitter.Emit(eventFromRequest(request))
|
||||
http.ServeFile(writer, request, f.targetPath)
|
||||
}
|
||||
|
||||
func eventFromRequest(request *http.Request) audit.Event {
|
||||
details := audit.HTTPDetails{
|
||||
Method: request.Method,
|
||||
Host: request.Host,
|
||||
URI: request.RequestURI,
|
||||
Proto: request.Proto,
|
||||
Headers: request.Header,
|
||||
}
|
||||
var wire *anypb.Any
|
||||
if w, err := details.MarshalToWireFormat(); err == nil {
|
||||
wire = w
|
||||
}
|
||||
|
||||
localIP, localPort := ipPortFromAddr(LocalAddr(request.Context()))
|
||||
remoteIP, remotePort := ipPortFromAddr(RemoteAddr(request.Context()))
|
||||
|
||||
return audit.Event{
|
||||
Transport: audit.TransportProtocol_TCP,
|
||||
Application: audit.AppProtocol_HTTP,
|
||||
SourceIP: remoteIP,
|
||||
DestinationIP: localIP,
|
||||
SourcePort: remotePort,
|
||||
DestinationPort: localPort,
|
||||
ProtocolDetails: wire,
|
||||
}
|
||||
}
|
||||
|
||||
func ipPortFromAddr(addr net.Addr) (ip net.IP, port uint16) {
|
||||
if tcpAddr, isTCPAddr := addr.(*net.TCPAddr); isTCPAddr {
|
||||
return tcpAddr.IP, uint16(tcpAddr.Port)
|
||||
}
|
||||
|
||||
ipPortSplit := strings.Split(addr.String(), ":")
|
||||
if len(ipPortSplit) != 2 {
|
||||
return
|
||||
}
|
||||
|
||||
ip = net.ParseIP(ipPortSplit[0])
|
||||
|
||||
if p, err := strconv.Atoi(ipPortSplit[1]); err == nil {
|
||||
port = uint16(p)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
|
Loading…
Add table
Reference in a new issue