From 63607dfd6755639493e4ffe837015fcaefce4dc6 Mon Sep 17 00:00:00 2001 From: Peter Kurfer Date: Sat, 2 Jan 2021 17:24:06 +0100 Subject: [PATCH] Implement log and writer sinks Add reader --- internal/endpoints/endpoint_manager.go | 4 +- pkg/audit/event.go | 102 ++++++++++++++++++++- pkg/audit/event_stream.go | 14 +-- pkg/audit/event_stream_test.go | 120 ++++++++++++++----------- pkg/audit/http_details.go | 9 +- pkg/audit/log_sink.go | 53 +++++++++++ pkg/audit/log_sink_test.go | 110 +++++++++++++++++++++++ pkg/audit/reader.go | 87 ++++++++++++++++++ pkg/audit/reader_test.go | 97 ++++++++++++++++++++ pkg/audit/tls_details.go | 52 +++++++---- pkg/audit/writer.go | 85 ++++++++++++++++++ pkg/audit/writer_sink.go | 47 ++++++++++ pkg/audit/writer_sink_test.go | 72 +++++++++++++++ pkg/audit/writer_test.go | 117 ++++++++++++++++++++++++ plugins/tls_interceptor/register.go | 2 +- 15 files changed, 883 insertions(+), 88 deletions(-) create mode 100644 pkg/audit/log_sink.go create mode 100644 pkg/audit/log_sink_test.go create mode 100644 pkg/audit/reader.go create mode 100644 pkg/audit/reader_test.go create mode 100644 pkg/audit/writer.go create mode 100644 pkg/audit/writer_sink.go create mode 100644 pkg/audit/writer_sink_test.go create mode 100644 pkg/audit/writer_test.go diff --git a/internal/endpoints/endpoint_manager.go b/internal/endpoints/endpoint_manager.go index afc13e2..0470f8d 100644 --- a/internal/endpoints/endpoint_manager.go +++ b/internal/endpoints/endpoint_manager.go @@ -99,7 +99,7 @@ func (e *endpointManager) StartEndpoints() { } func (e *endpointManager) ShutdownEndpoints() { - var waitGroup sync.WaitGroup + waitGroup := new(sync.WaitGroup) waitGroup.Add(len(e.properlyStartedEndpoints)) parentCtx, _ := context.WithTimeout(context.Background(), shutdownTimeout) @@ -112,7 +112,7 @@ func (e *endpointManager) ShutdownEndpoints() { zap.String("endpoint", endpoint.Name()), ) endpointLogger.Info("Triggering shutdown of endpoint") - go shutdownEndpoint(ctx, endpoint, endpointLogger, &waitGroup) + go shutdownEndpoint(ctx, endpoint, endpointLogger, waitGroup) } waitGroup.Wait() diff --git a/pkg/audit/event.go b/pkg/audit/event.go index e625734..84fb7d1 100644 --- a/pkg/audit/event.go +++ b/pkg/audit/event.go @@ -1,14 +1,18 @@ package audit import ( + "encoding/binary" + "math/big" "net" "time" "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/types/known/anypb" + "google.golang.org/protobuf/types/known/timestamppb" ) -type EventDetails interface { - ProtoMessage() proto.Message +type Details interface { + MarshalToWireFormat() (*anypb.Any, error) } type Event struct { @@ -20,6 +24,98 @@ type Event struct { DestinationIP net.IP SourcePort uint16 DestinationPort uint16 - ProtocolDetails EventDetails + ProtocolDetails *anypb.Any TLS *TLSDetails } + +func (e *Event) ProtoMessage() proto.Message { + var sourceIP isEventEntity_SourceIP + if ipv4 := e.SourceIP.To4(); ipv4 != nil { + if len(ipv4) == 16 { + sourceIP = &EventEntity_SourceIPv4{SourceIPv4: binary.BigEndian.Uint32(ipv4[12:16])} + } else { + sourceIP = &EventEntity_SourceIPv4{SourceIPv4: binary.BigEndian.Uint32(ipv4)} + } + } else { + ipv6 := big.NewInt(0) + ipv6.SetBytes(e.SourceIP) + sourceIP = &EventEntity_SourceIPv6{SourceIPv6: ipv6.Uint64()} + } + + var destinationIP isEventEntity_DestinationIP + if ipv4 := e.DestinationIP.To4(); ipv4 != nil { + if len(ipv4) == 16 { + destinationIP = &EventEntity_DestinationIPv4{DestinationIPv4: binary.BigEndian.Uint32(ipv4[12:16])} + } else { + destinationIP = &EventEntity_DestinationIPv4{DestinationIPv4: binary.BigEndian.Uint32(ipv4)} + } + } else { + ipv6 := big.NewInt(0) + ipv6.SetBytes(e.SourceIP) + destinationIP = &EventEntity_DestinationIPv6{DestinationIPv6: ipv6.Uint64()} + } + + var tlsDetails *TLSDetailsEntity = nil + if e.TLS != nil { + tlsDetails = e.TLS.ProtoMessage() + } + + return &EventEntity{ + Id: e.ID, + Timestamp: timestamppb.New(e.Timestamp), + Transport: e.Transport, + Application: e.Application, + SourceIP: sourceIP, + DestinationIP: destinationIP, + SourcePort: uint32(e.SourcePort), + DestinationPort: uint32(e.DestinationPort), + Tls: tlsDetails, + ProtocolDetails: e.ProtocolDetails, + } +} + +func (e *Event) ApplyDefaults(id int64) { + e.ID = id + emptyTime := time.Time{} + if e.Timestamp == emptyTime { + e.Timestamp = time.Now().UTC() + } +} + +func NewEventFromProto(msg *EventEntity) (ev Event) { + var sourceIP net.IP + switch ip := msg.GetSourceIP().(type) { + case *EventEntity_SourceIPv4: + buf := make([]byte, 4) + binary.BigEndian.PutUint32(buf, ip.SourceIPv4) + sourceIP = buf + sourceIP = sourceIP.To4() + case *EventEntity_SourceIPv6: + sourceIP = big.NewInt(int64(ip.SourceIPv6)).Bytes() + } + + var destinationIP net.IP + switch ip := msg.GetDestinationIP().(type) { + case *EventEntity_DestinationIPv4: + buf := make([]byte, 4) + binary.BigEndian.PutUint32(buf, ip.DestinationIPv4) + destinationIP = buf + destinationIP = destinationIP.To4() + case *EventEntity_DestinationIPv6: + destinationIP = big.NewInt(int64(ip.DestinationIPv6)).Bytes() + } + + ev = Event{ + ID: msg.GetId(), + Timestamp: msg.GetTimestamp().AsTime(), + Transport: msg.GetTransport(), + Application: msg.GetApplication(), + SourceIP: sourceIP, + DestinationIP: destinationIP, + SourcePort: uint16(msg.GetSourcePort()), + DestinationPort: uint16(msg.GetDestinationPort()), + ProtocolDetails: msg.GetProtocolDetails(), + TLS: NewTLSDetailsFromProto(msg.GetTls()), + } + return +} diff --git a/pkg/audit/event_stream.go b/pkg/audit/event_stream.go index 0019805..bc22331 100644 --- a/pkg/audit/event_stream.go +++ b/pkg/audit/event_stream.go @@ -5,7 +5,6 @@ import ( "time" "github.com/bwmarrin/snowflake" - "github.com/imdario/mergo" "gitlab.com/inetmock/inetmock/pkg/logging" "go.uber.org/zap" ) @@ -56,10 +55,9 @@ func NewEventStream(logger logging.Logger, options ...EventStreamOption) (es Eve } func (e *eventStream) Emit(ev Event) { - defaultEvent := e.newDefaultEvent() - _ = mergo.Merge(defaultEvent, &ev) + ev.ApplyDefaults(e.idGenerator.Generate().Int64()) select { - case e.buffer <- defaultEvent: + case e.buffer <- &ev: e.logger.Debug("pushed event to distribute loop") default: e.logger.Warn("buffer is full") @@ -107,11 +105,3 @@ func (e *eventStream) distribute() { } } } - -func (e *eventStream) newDefaultEvent() *Event { - id := e.idGenerator.Generate() - return &Event{ - ID: id.Int64(), - Timestamp: time.Now().UTC(), - } -} diff --git a/pkg/audit/event_stream_test.go b/pkg/audit/event_stream_test.go index adf6c08..17732e3 100644 --- a/pkg/audit/event_stream_test.go +++ b/pkg/audit/event_stream_test.go @@ -1,19 +1,54 @@ package audit_test import ( + "crypto/tls" "net" + "net/http" "sync" "testing" "time" "gitlab.com/inetmock/inetmock/pkg/audit" "gitlab.com/inetmock/inetmock/pkg/logging" + "google.golang.org/protobuf/types/known/anypb" ) var ( defaultSink = &testSink{ name: "test defaultSink", } + testEvents = []audit.Event{ + { + Transport: audit.TransportProtocol_TCP, + Application: audit.AppProtocol_HTTP, + SourceIP: net.ParseIP("127.0.0.1"), + DestinationIP: net.ParseIP("127.0.0.1"), + SourcePort: 32344, + DestinationPort: 80, + TLS: &audit.TLSDetails{ + Version: tls.VersionTLS13, + CipherSuite: tls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA, + ServerName: "localhost", + }, + ProtocolDetails: mustMarshalToWireformat(audit.HTTPDetails{ + Method: "GET", + Host: "localhost", + URI: "http://localhost/asdf", + Proto: "HTTP 1.1", + Headers: http.Header{ + "Accept": []string{"application/json"}, + }, + }), + }, + { + Transport: audit.TransportProtocol_TCP, + Application: audit.AppProtocol_DNS, + SourceIP: net.ParseIP("::1"), + DestinationIP: net.ParseIP("::1"), + SourcePort: 32344, + DestinationPort: 80, + }, + } ) type testSink struct { @@ -45,17 +80,25 @@ func wgMockSink(t testing.TB, wg *sync.WaitGroup) audit.Sink { } } +func mustMarshalToWireformat(d audit.Details) *anypb.Any { + any, err := d.MarshalToWireFormat() + if err != nil { + panic(err) + } + return any +} + func Test_eventStream_RegisterSink(t *testing.T) { type args struct { s audit.Sink } - - tests := []struct { + type testCase struct { name string args args setup func(e audit.EventStream) wantErr bool - }{ + } + tests := []testCase{ { name: "Register test defaultSink", args: args{ @@ -74,8 +117,8 @@ func Test_eventStream_RegisterSink(t *testing.T) { wantErr: true, }, } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { + scenario := func(tt testCase) func(t *testing.T) { + return func(t *testing.T) { var err error var e audit.EventStream if e, err = audit.NewEventStream(logging.CreateTestLogger(t)); err != nil { @@ -103,7 +146,11 @@ func Test_eventStream_RegisterSink(t *testing.T) { if !found { t.Errorf("expected defaultSink name %s not found in registered sinks %v", tt.args.s.Name(), e.Sinks()) } - }) + } + } + + for _, tt := range tests { + t.Run(tt.name, scenario(tt)) } } @@ -112,26 +159,18 @@ func Test_eventStream_Emit(t *testing.T) { evs []audit.Event opts []audit.EventStreamOption } - tests := []struct { + type testCase struct { name string args args subscribe bool - }{ + } + tests := []testCase{ { name: "Expect to get a single event", subscribe: true, args: args{ opts: []audit.EventStreamOption{audit.WithBufferSize(10)}, - evs: []audit.Event{ - { - Transport: audit.TransportProtocol_TCP, - Application: audit.AppProtocol_HTTP, - SourceIP: net.ParseIP("127.0.0.1"), - DestinationIP: net.ParseIP("127.0.0.1"), - SourcePort: 32344, - DestinationPort: 80, - }, - }, + evs: testEvents[:1], }, }, { @@ -139,46 +178,21 @@ func Test_eventStream_Emit(t *testing.T) { subscribe: true, args: args{ opts: []audit.EventStreamOption{audit.WithBufferSize(10)}, - evs: []audit.Event{ - { - Transport: audit.TransportProtocol_TCP, - Application: audit.AppProtocol_HTTP, - SourceIP: net.ParseIP("127.0.0.1"), - DestinationIP: net.ParseIP("127.0.0.1"), - SourcePort: 32344, - DestinationPort: 80, - }, - { - Transport: audit.TransportProtocol_TCP, - Application: audit.AppProtocol_DNS, - SourceIP: net.ParseIP("::1"), - DestinationIP: net.ParseIP("::1"), - SourcePort: 32344, - DestinationPort: 80, - }, - }, + evs: testEvents, }, }, { name: "Emit without subscribe sink", args: args{ opts: []audit.EventStreamOption{audit.WithBufferSize(0)}, - evs: []audit.Event{ - { - Transport: audit.TransportProtocol_TCP, - Application: audit.AppProtocol_HTTP, - SourceIP: net.ParseIP("127.0.0.1"), - DestinationIP: net.ParseIP("127.0.0.1"), - SourcePort: 32344, - DestinationPort: 80, - }, - }, + evs: testEvents[:1], }, subscribe: false, }, } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { + + scenario := func(tt testCase) func(t *testing.T) { + return func(t *testing.T) { var err error var e audit.EventStream if e, err = audit.NewEventStream(logging.CreateTestLogger(t), tt.args.opts...); err != nil { @@ -189,8 +203,8 @@ func Test_eventStream_Emit(t *testing.T) { _ = e.Close() }) - emittedWaitGroup := &sync.WaitGroup{} - receivedWaitGroup := &sync.WaitGroup{} + emittedWaitGroup := new(sync.WaitGroup) + receivedWaitGroup := new(sync.WaitGroup) emittedWaitGroup.Add(len(tt.args.evs)) @@ -219,7 +233,11 @@ func Test_eventStream_Emit(t *testing.T) { case <-time.After(5 * time.Second): t.Errorf("did not get all expected events in time") } - }) + } + } + + for _, tt := range tests { + t.Run(tt.name, scenario(tt)) } } diff --git a/pkg/audit/http_details.go b/pkg/audit/http_details.go index dc5b096..b755a45 100644 --- a/pkg/audit/http_details.go +++ b/pkg/audit/http_details.go @@ -4,7 +4,7 @@ import ( "net/http" "strings" - "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/types/known/anypb" ) type HTTPDetails struct { @@ -15,7 +15,7 @@ type HTTPDetails struct { Headers http.Header } -func (d HTTPDetails) ProtoMessage() proto.Message { +func (d HTTPDetails) MarshalToWireFormat() (any *anypb.Any, err error) { var method = HTTPMethod_GET if methodValue, known := HTTPMethod_value[strings.ToUpper(d.Method)]; known { method = HTTPMethod(methodValue) @@ -29,11 +29,14 @@ func (d HTTPDetails) ProtoMessage() proto.Message { } } - return &HTTPDetailsEntity{ + protoDetails := &HTTPDetailsEntity{ Method: method, Host: d.Host, Uri: d.URI, Proto: d.Proto, Headers: headers, } + + any, err = anypb.New(protoDetails) + return } diff --git a/pkg/audit/log_sink.go b/pkg/audit/log_sink.go new file mode 100644 index 0000000..a6e07f1 --- /dev/null +++ b/pkg/audit/log_sink.go @@ -0,0 +1,53 @@ +package audit + +import ( + "crypto/tls" + + "gitlab.com/inetmock/inetmock/pkg/logging" + "go.uber.org/zap" +) + +const ( + logSinkName = "logging" +) + +func NewLogSink(logger logging.Logger) Sink { + return &logSink{ + logger: logger, + } +} + +type logSink struct { + logger logging.Logger +} + +func (logSink) Name() string { + return logSinkName +} + +func (l logSink) OnSubscribe(evs <-chan Event) { + go func(logger logging.Logger, evs <-chan Event) { + for ev := range evs { + eventLogger := logger + + if ev.TLS != nil { + eventLogger = eventLogger.With( + zap.String("tls_server_name", ev.TLS.ServerName), + zap.String("tls_cipher_suite", tls.CipherSuiteName(ev.TLS.CipherSuite)), + ) + } + + eventLogger.Info( + "handled request", + zap.Time("timestamp", ev.Timestamp), + zap.String("application", ev.Application.String()), + zap.String("transport", ev.Transport.String()), + zap.String("source_ip", ev.SourceIP.String()), + zap.Uint16("source_port", ev.SourcePort), + zap.String("destination_ip", ev.DestinationIP.String()), + zap.Uint16("destination_port", ev.DestinationPort), + zap.Any("details", ev.ProtocolDetails), + ) + } + }(l.logger, evs) +} diff --git a/pkg/audit/log_sink_test.go b/pkg/audit/log_sink_test.go new file mode 100644 index 0000000..085dc25 --- /dev/null +++ b/pkg/audit/log_sink_test.go @@ -0,0 +1,110 @@ +package audit_test + +import ( + "sync" + "testing" + "time" + + "github.com/golang/mock/gomock" + logging_mock "gitlab.com/inetmock/inetmock/internal/mock/logging" + "gitlab.com/inetmock/inetmock/pkg/audit" + "gitlab.com/inetmock/inetmock/pkg/logging" + "go.uber.org/zap" +) + +//nolint:dupl +func Test_logSink_OnSubscribe(t *testing.T) { + type fields struct { + loggerSetup func(t *testing.T, wg *sync.WaitGroup) logging.Logger + } + type testCase struct { + name string + fields fields + events []audit.Event + } + tests := []testCase{ + { + name: "Get a single log line", + fields: fields{ + loggerSetup: func(t *testing.T, wg *sync.WaitGroup) logging.Logger { + ctrl := gomock.NewController(t) + t.Cleanup(ctrl.Finish) + loggerMock := logging_mock.NewMockLogger(ctrl) + + loggerMock. + EXPECT(). + With(gomock.Any()). + Return(loggerMock) + + loggerMock. + EXPECT(). + Info("handled request", gomock.Any()). + Do(func(_ string, _ ...zap.Field) { + wg.Done() + }). + Times(1) + + return loggerMock + }, + }, + events: testEvents[:1], + }, + { + name: "Get multiple events", + fields: fields{ + loggerSetup: func(t *testing.T, wg *sync.WaitGroup) logging.Logger { + ctrl := gomock.NewController(t) + t.Cleanup(ctrl.Finish) + loggerMock := logging_mock.NewMockLogger(ctrl) + + loggerMock. + EXPECT(). + With(gomock.Any()). + Return(loggerMock) + + loggerMock. + EXPECT(). + Info("handled request", gomock.Any()). + Do(func(_ string, _ ...zap.Field) { + wg.Done() + }). + Times(2) + + return loggerMock + }, + }, + events: testEvents, + }, + } + scenario := func(tt testCase) func(t *testing.T) { + return func(t *testing.T) { + wg := new(sync.WaitGroup) + wg.Add(len(tt.events)) + + logSink := audit.NewLogSink(tt.fields.loggerSetup(t, wg)) + var evs audit.EventStream + var err error + if evs, err = audit.NewEventStream(logging.CreateTestLogger(t)); err != nil { + t.Errorf("NewEventStream() error = %v", err) + } + + if err = evs.RegisterSink(logSink); err != nil { + t.Errorf("RegisterSink() error = %v", err) + } + + for _, ev := range tt.events { + evs.Emit(ev) + } + + select { + case <-time.After(100 * time.Millisecond): + t.Errorf("not all events recorded in time") + case <-waitGroupDone(wg): + } + } + } + + for _, tt := range tests { + t.Run(tt.name, scenario(tt)) + } +} diff --git a/pkg/audit/reader.go b/pkg/audit/reader.go new file mode 100644 index 0000000..d679565 --- /dev/null +++ b/pkg/audit/reader.go @@ -0,0 +1,87 @@ +package audit + +import ( + "encoding/binary" + "io" + "sync" + + "google.golang.org/protobuf/proto" +) + +const ( + lengthBufferSize = 4 + defaultPayloadBufferSize = 4096 +) + +type Reader interface { + Read() (Event, error) +} + +type EventReaderOption func(reader *eventReader) + +var ( + WithReaderByteOrder = func(order binary.ByteOrder) EventReaderOption { + return func(reader *eventReader) { + reader.byteOrder = order + } + } +) + +func NewEventReader(source io.Reader, opts ...EventReaderOption) Reader { + reader := &eventReader{ + source: source, + byteOrder: defaultOrder, + lengthPool: &sync.Pool{ + New: func() interface{} { + return make([]byte, lengthBufferSize) + }, + }, + payloadPool: &sync.Pool{ + New: func() interface{} { + return make([]byte, defaultPayloadBufferSize) + }, + }, + } + + for _, opt := range opts { + opt(reader) + } + + return reader +} + +type eventReader struct { + lengthPool *sync.Pool + payloadPool *sync.Pool + byteOrder binary.ByteOrder + source io.Reader +} + +func (e eventReader) Read() (ev Event, err error) { + lengthBuf := e.lengthPool.Get().([]byte) + if _, err = e.source.Read(lengthBuf); err != nil { + return + } + + length := e.byteOrder.Uint32(lengthBuf) + var msgBuf []byte + if length <= defaultPayloadBufferSize { + msgBuf = e.payloadPool.Get().([]byte)[:length] + } else { + msgBuf = make([]byte, length) + } + + if _, err = e.source.Read(msgBuf); err != nil { + return + } + + protoEv := new(EventEntity) + + if err = proto.Unmarshal(msgBuf, protoEv); err != nil { + return + } + + ev = NewEventFromProto(protoEv) + + return +} diff --git a/pkg/audit/reader_test.go b/pkg/audit/reader_test.go new file mode 100644 index 0000000..3335e79 --- /dev/null +++ b/pkg/audit/reader_test.go @@ -0,0 +1,97 @@ +package audit_test + +import ( + "bytes" + "encoding/binary" + "encoding/hex" + "io" + "testing" + + "gitlab.com/inetmock/inetmock/pkg/audit" + "google.golang.org/protobuf/proto" +) + +var ( + //nolint:lll + httpPayloadBytesLittleEndian = `dd000000120b088092b8c398feffffff011801200248d8fc0150505a3308041224544c535f45434448455f45434453415f574954485f4145535f3235365f4342435f5348411a096c6f63616c686f73746282010a34747970652e676f6f676c65617069732e636f6d2f696e65746d6f636b2e61756469742e4854545044657461696c73456e74697479124a12096c6f63616c686f73741a15687474703a2f2f6c6f63616c686f73742f6173646622084854545020312e312a1c0a0641636365707412120a106170706c69636174696f6e2f6a736f6e28818080f80738818080f807` + //nolint:lll + httpPayloadBytesBigEndian = `000000dd120b088092b8c398feffffff011801200248d8fc0150505a3308041224544c535f45434448455f45434453415f574954485f4145535f3235365f4342435f5348411a096c6f63616c686f73746282010a34747970652e676f6f676c65617069732e636f6d2f696e65746d6f636b2e61756469742e4854545044657461696c73456e74697479124a12096c6f63616c686f73741a15687474703a2f2f6c6f63616c686f73742f6173646622084854545020312e312a1c0a0641636365707412120a106170706c69636174696f6e2f6a736f6e28818080f80738818080f807` + dnsPayloadBytesLittleEndian = `1b000000120b088092b8c398feffffff011801200148d8fc01505030014001` + dnsPayloadBytesBigEndian = `0000001b120b088092b8c398feffffff011801200148d8fc01505030014001` +) + +func mustDecodeHex(hexBytes string) io.Reader { + b, err := hex.DecodeString(hexBytes) + if err != nil { + panic(err) + } + return bytes.NewReader(b) +} + +func Test_eventReader_Read(t *testing.T) { + type fields struct { + source io.Reader + order binary.ByteOrder + } + type testCase struct { + name string + fields fields + wantEv audit.Event + wantErr bool + } + tests := []testCase{ + { + name: "Read HTTP payload - little endian", + fields: fields{ + source: mustDecodeHex(httpPayloadBytesLittleEndian), + order: binary.LittleEndian, + }, + wantEv: testEvents[0], + wantErr: false, + }, + { + name: "Read HTTP payload - big endian", + fields: fields{ + source: mustDecodeHex(httpPayloadBytesBigEndian), + order: binary.BigEndian, + }, + wantEv: testEvents[0], + wantErr: false, + }, + { + name: "Read DNS payload - little endian", + fields: fields{ + source: mustDecodeHex(dnsPayloadBytesLittleEndian), + order: binary.LittleEndian, + }, + wantEv: testEvents[1], + wantErr: false, + }, + { + name: "Read DNS payload - big endian", + fields: fields{ + source: mustDecodeHex(dnsPayloadBytesBigEndian), + order: binary.BigEndian, + }, + wantEv: testEvents[1], + wantErr: false, + }, + } + scenario := func(tt testCase) func(t *testing.T) { + return func(t *testing.T) { + e := audit.NewEventReader(tt.fields.source, audit.WithReaderByteOrder(tt.fields.order)) + gotEv, err := e.Read() + if (err != nil) != tt.wantErr { + t.Errorf("Read() error = %v, wantErr %v", err, tt.wantErr) + return + } + + if err == nil && !proto.Equal(gotEv.ProtoMessage(), tt.wantEv.ProtoMessage()) { + t.Errorf("Read() gotEv = %v, want %v", gotEv, tt.wantEv) + } + } + } + for _, tt := range tests { + t.Run(tt.name, scenario(tt)) + } +} diff --git a/pkg/audit/tls_details.go b/pkg/audit/tls_details.go index fc9ae75..a984007 100644 --- a/pkg/audit/tls_details.go +++ b/pkg/audit/tls_details.go @@ -2,8 +2,31 @@ package audit import ( "crypto/tls" +) - "google.golang.org/protobuf/proto" +var ( + tlsToEntity = map[uint16]TLSVersion{ + tls.VersionSSL30: TLSVersion_SSLv30, + tls.VersionTLS10: TLSVersion_TLS10, + tls.VersionTLS11: TLSVersion_TLS11, + tls.VersionTLS12: TLSVersion_TLS12, + tls.VersionTLS13: TLSVersion_TLS13, + } + entityToTls = map[TLSVersion]uint16{ + TLSVersion_SSLv30: tls.VersionSSL30, + TLSVersion_TLS10: tls.VersionTLS10, + TLSVersion_TLS11: tls.VersionTLS11, + TLSVersion_TLS12: tls.VersionTLS12, + TLSVersion_TLS13: tls.VersionTLS13, + } + cipherSuiteIDLookup = func(name string) uint16 { + for _, cs := range tls.CipherSuites() { + if cs.Name == name { + return cs.ID + } + } + return 0 + } ) type TLSDetails struct { @@ -12,24 +35,21 @@ type TLSDetails struct { ServerName string } -func (d TLSDetails) ProtoMessage() proto.Message { - var version TLSVersion - - switch d.Version { - case tls.VersionTLS10: - version = TLSVersion_TLS10 - case tls.VersionTLS11: - version = TLSVersion_TLS11 - case tls.VersionTLS12: - version = TLSVersion_TLS12 - case tls.VersionTLS13: - version = TLSVersion_TLS13 - default: - version = TLSVersion_SSLv30 +func NewTLSDetailsFromProto(entity *TLSDetailsEntity) *TLSDetails { + if entity == nil { + return nil } + return &TLSDetails{ + Version: entityToTls[entity.GetVersion()], + CipherSuite: cipherSuiteIDLookup(entity.GetCipherSuite()), + ServerName: entity.GetServerName(), + } +} + +func (d TLSDetails) ProtoMessage() *TLSDetailsEntity { return &TLSDetailsEntity{ - Version: version, + Version: tlsToEntity[d.Version], CipherSuite: tls.CipherSuiteName(d.CipherSuite), ServerName: d.ServerName, } diff --git a/pkg/audit/writer.go b/pkg/audit/writer.go new file mode 100644 index 0000000..5a318b4 --- /dev/null +++ b/pkg/audit/writer.go @@ -0,0 +1,85 @@ +//go:generate mockgen -source=$GOFILE -destination=./../../internal/mock/audit/writer.mock.go -package=audit_mock + +package audit + +import ( + "encoding/binary" + "errors" + "io" + "sync" + + "google.golang.org/protobuf/proto" +) + +var ( + WithWriterByteOrder = func(order binary.ByteOrder) func(writer *eventWriter) { + return func(writer *eventWriter) { + writer.byteOrder = order + } + } + defaultOrder = binary.BigEndian + ErrValueMostNotBeNil = errors.New("event value must not be nil") +) + +type Writer interface { + io.Closer + Write(ev *Event) error +} + +type EventWriterOption func(writer *eventWriter) + +func NewEventWriter(target io.Writer, opts ...EventWriterOption) Writer { + writer := &eventWriter{ + target: target, + byteOrder: defaultOrder, + lengthPool: &sync.Pool{ + New: func() interface{} { + return make([]byte, lengthBufferSize) + }, + }, + } + + for _, opt := range opts { + opt(writer) + } + + return writer +} + +type eventWriter struct { + lengthPool *sync.Pool + target io.Writer + byteOrder binary.ByteOrder +} + +func (e eventWriter) Write(ev *Event) (err error) { + if ev == nil { + return ErrValueMostNotBeNil + } + var bytes []byte + + if bytes, err = proto.Marshal(ev.ProtoMessage()); err != nil { + return + } + buf := e.lengthPool.Get().([]byte) + e.byteOrder.PutUint32(buf, uint32(len(bytes))) + + if _, err = e.target.Write(buf); err != nil { + return + } + if _, err = e.target.Write(bytes); err != nil { + return + } + if syncerInst, ok := e.target.(syncer); ok { + err = syncerInst.Sync() + } + + return +} + +func (e eventWriter) Close() error { + if closer, isCloser := e.target.(io.Closer); isCloser { + return closer.Close() + } + return nil +} diff --git a/pkg/audit/writer_sink.go b/pkg/audit/writer_sink.go new file mode 100644 index 0000000..803d29a --- /dev/null +++ b/pkg/audit/writer_sink.go @@ -0,0 +1,47 @@ +package audit + +type WriterSinkOption func(sink *writerCloserSink) + +var ( + WithCloseOnExit WriterSinkOption = func(sink *writerCloserSink) { + sink.closeOnExit = true + } +) + +func NewWriterSink(name string, target Writer, opts ...WriterSinkOption) Sink { + sink := &writerCloserSink{ + name: name, + target: target, + } + + for _, opt := range opts { + opt(sink) + } + + return sink +} + +type writerCloserSink struct { + name string + target Writer + closeOnExit bool +} + +type syncer interface { + Sync() error +} + +func (f writerCloserSink) Name() string { + return f.name +} + +func (f writerCloserSink) OnSubscribe(evs <-chan Event) { + go func(target Writer, closeOnExit bool, evs <-chan Event) { + for ev := range evs { + _ = target.Write(&ev) + } + if closeOnExit { + _ = target.Close() + } + }(f.target, f.closeOnExit, evs) +} diff --git a/pkg/audit/writer_sink_test.go b/pkg/audit/writer_sink_test.go new file mode 100644 index 0000000..395cc5a --- /dev/null +++ b/pkg/audit/writer_sink_test.go @@ -0,0 +1,72 @@ +package audit_test + +import ( + "sync" + "testing" + "time" + + "github.com/golang/mock/gomock" + audit_mock "gitlab.com/inetmock/inetmock/internal/mock/audit" + "gitlab.com/inetmock/inetmock/pkg/audit" + "gitlab.com/inetmock/inetmock/pkg/logging" +) + +func Test_writerCloserSink_OnSubscribe(t *testing.T) { + type testCase struct { + name string + events []audit.Event + } + tests := []testCase{ + { + name: "Get a single event", + events: testEvents[:1], + }, + { + name: "Get multiple events", + events: testEvents, + }, + } + scenario := func(tt testCase) func(t *testing.T) { + return func(t *testing.T) { + wg := new(sync.WaitGroup) + wg.Add(len(tt.events)) + + ctrl := gomock.NewController(t) + t.Cleanup(ctrl.Finish) + writerMock := audit_mock.NewMockWriter(ctrl) + writerMock. + EXPECT(). + Write(gomock.Any()). + Do(func(_ *audit.Event) { + wg.Done() + }). + Times(len(tt.events)) + + writerCloserSink := audit.NewWriterSink("WriterMock", writerMock, audit.WithCloseOnExit) + var evs audit.EventStream + var err error + + if evs, err = audit.NewEventStream(logging.CreateTestLogger(t)); err != nil { + t.Errorf("NewEventStream() error = %v", err) + } + + if err = evs.RegisterSink(writerCloserSink); err != nil { + t.Errorf("RegisterSink() error = %v", err) + } + + for _, ev := range tt.events { + evs.Emit(ev) + } + + select { + case <-time.After(100 * time.Millisecond): + t.Errorf("not all events recorded in time") + case <-waitGroupDone(wg): + } + } + } + + for _, tt := range tests { + t.Run(tt.name, scenario(tt)) + } +} diff --git a/pkg/audit/writer_test.go b/pkg/audit/writer_test.go new file mode 100644 index 0000000..ce8264b --- /dev/null +++ b/pkg/audit/writer_test.go @@ -0,0 +1,117 @@ +//go:generate mockgen -source=$GOFILE -destination=./../../internal/mock/audit/writercloser.mock.go -package=audit_mock + +package audit_test + +import ( + "encoding/binary" + "encoding/hex" + "io" + "testing" + + "github.com/golang/mock/gomock" + audit_mock "gitlab.com/inetmock/inetmock/internal/mock/audit" + "gitlab.com/inetmock/inetmock/pkg/audit" +) + +type WriterCloserSyncer interface { + io.WriteCloser + Sync() error +} + +func Test_eventWriter_Write(t *testing.T) { + type fields struct { + order binary.ByteOrder + } + type args struct { + evs []audit.Event + } + type testCase struct { + name string + fields fields + args args + wantErr bool + } + tests := []testCase{ + { + name: "Write a single event - little endian", + fields: fields{ + order: binary.LittleEndian, + }, + args: args{ + evs: testEvents[:1], + }, + wantErr: false, + }, + { + name: "Write a single event - big endian", + fields: fields{ + order: binary.BigEndian, + }, + args: args{ + evs: testEvents[:1], + }, + wantErr: false, + }, + { + name: "Write multiple events - little endian", + fields: fields{ + order: binary.LittleEndian, + }, + args: args{ + evs: testEvents, + }, + wantErr: false, + }, + { + name: "Write multiple events - big endian", + fields: fields{ + order: binary.BigEndian, + }, + args: args{ + evs: testEvents, + }, + wantErr: false, + }, + } + scenario := func(tt testCase) func(t *testing.T) { + return func(t *testing.T) { + ctrl := gomock.NewController(t) + t.Cleanup(ctrl.Finish) + writerMock := audit_mock.NewMockWriterCloserSyncer(ctrl) + calls := make([]*gomock.Call, 0) + for i := 0; i < len(tt.args.evs); i++ { + calls = append(calls, + writerMock. + EXPECT(). + Write(gomock.Any()). + Do(func(data []byte) { + t.Logf("got payload = %s", hex.EncodeToString(data)) + t.Logf("got length %d", tt.fields.order.Uint32(data)) + }), + writerMock. + EXPECT(). + Write(gomock.Any()). + Do(func(data []byte) { + t.Logf("got payload = %s", hex.EncodeToString(data)) + }), + writerMock. + EXPECT(). + Sync(), + ) + } + gomock.InOrder(calls...) + + e := audit.NewEventWriter(writerMock, audit.WithWriterByteOrder(tt.fields.order)) + + for _, ev := range tt.args.evs { + if err := e.Write(&ev); (err != nil) != tt.wantErr { + t.Errorf("Write() error = %v, wantErr %v", err, tt.wantErr) + } + } + } + } + + for _, tt := range tests { + t.Run(tt.name, scenario(tt)) + } +} diff --git a/plugins/tls_interceptor/register.go b/plugins/tls_interceptor/register.go index d354da5..e8ec9cb 100644 --- a/plugins/tls_interceptor/register.go +++ b/plugins/tls_interceptor/register.go @@ -40,7 +40,7 @@ func AddTLSInterceptor(registry api.HandlerRegistry) (err error) { registry.RegisterHandler(name, func() api.ProtocolHandler { return &tlsInterceptor{ logger: logger, - currentConnectionsCount: &sync.WaitGroup{}, + currentConnectionsCount: new(sync.WaitGroup), currentConnections: make(map[uuid.UUID]*proxyConn), connectionsMutex: &sync.Mutex{}, }