Implement log and writer sinks

Add reader
This commit is contained in:
Peter 2021-01-02 17:24:06 +01:00
parent 1e8139e845
commit 63607dfd67
Signed by: prskr
GPG key ID: C1DB5D2E8DB512F9
15 changed files with 883 additions and 88 deletions

View file

@ -99,7 +99,7 @@ func (e *endpointManager) StartEndpoints() {
}
func (e *endpointManager) ShutdownEndpoints() {
var waitGroup sync.WaitGroup
waitGroup := new(sync.WaitGroup)
waitGroup.Add(len(e.properlyStartedEndpoints))
parentCtx, _ := context.WithTimeout(context.Background(), shutdownTimeout)
@ -112,7 +112,7 @@ func (e *endpointManager) ShutdownEndpoints() {
zap.String("endpoint", endpoint.Name()),
)
endpointLogger.Info("Triggering shutdown of endpoint")
go shutdownEndpoint(ctx, endpoint, endpointLogger, &waitGroup)
go shutdownEndpoint(ctx, endpoint, endpointLogger, waitGroup)
}
waitGroup.Wait()

View file

@ -1,14 +1,18 @@
package audit
import (
"encoding/binary"
"math/big"
"net"
"time"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/anypb"
"google.golang.org/protobuf/types/known/timestamppb"
)
type EventDetails interface {
ProtoMessage() proto.Message
type Details interface {
MarshalToWireFormat() (*anypb.Any, error)
}
type Event struct {
@ -20,6 +24,98 @@ type Event struct {
DestinationIP net.IP
SourcePort uint16
DestinationPort uint16
ProtocolDetails EventDetails
ProtocolDetails *anypb.Any
TLS *TLSDetails
}
func (e *Event) ProtoMessage() proto.Message {
var sourceIP isEventEntity_SourceIP
if ipv4 := e.SourceIP.To4(); ipv4 != nil {
if len(ipv4) == 16 {
sourceIP = &EventEntity_SourceIPv4{SourceIPv4: binary.BigEndian.Uint32(ipv4[12:16])}
} else {
sourceIP = &EventEntity_SourceIPv4{SourceIPv4: binary.BigEndian.Uint32(ipv4)}
}
} else {
ipv6 := big.NewInt(0)
ipv6.SetBytes(e.SourceIP)
sourceIP = &EventEntity_SourceIPv6{SourceIPv6: ipv6.Uint64()}
}
var destinationIP isEventEntity_DestinationIP
if ipv4 := e.DestinationIP.To4(); ipv4 != nil {
if len(ipv4) == 16 {
destinationIP = &EventEntity_DestinationIPv4{DestinationIPv4: binary.BigEndian.Uint32(ipv4[12:16])}
} else {
destinationIP = &EventEntity_DestinationIPv4{DestinationIPv4: binary.BigEndian.Uint32(ipv4)}
}
} else {
ipv6 := big.NewInt(0)
ipv6.SetBytes(e.SourceIP)
destinationIP = &EventEntity_DestinationIPv6{DestinationIPv6: ipv6.Uint64()}
}
var tlsDetails *TLSDetailsEntity = nil
if e.TLS != nil {
tlsDetails = e.TLS.ProtoMessage()
}
return &EventEntity{
Id: e.ID,
Timestamp: timestamppb.New(e.Timestamp),
Transport: e.Transport,
Application: e.Application,
SourceIP: sourceIP,
DestinationIP: destinationIP,
SourcePort: uint32(e.SourcePort),
DestinationPort: uint32(e.DestinationPort),
Tls: tlsDetails,
ProtocolDetails: e.ProtocolDetails,
}
}
func (e *Event) ApplyDefaults(id int64) {
e.ID = id
emptyTime := time.Time{}
if e.Timestamp == emptyTime {
e.Timestamp = time.Now().UTC()
}
}
func NewEventFromProto(msg *EventEntity) (ev Event) {
var sourceIP net.IP
switch ip := msg.GetSourceIP().(type) {
case *EventEntity_SourceIPv4:
buf := make([]byte, 4)
binary.BigEndian.PutUint32(buf, ip.SourceIPv4)
sourceIP = buf
sourceIP = sourceIP.To4()
case *EventEntity_SourceIPv6:
sourceIP = big.NewInt(int64(ip.SourceIPv6)).Bytes()
}
var destinationIP net.IP
switch ip := msg.GetDestinationIP().(type) {
case *EventEntity_DestinationIPv4:
buf := make([]byte, 4)
binary.BigEndian.PutUint32(buf, ip.DestinationIPv4)
destinationIP = buf
destinationIP = destinationIP.To4()
case *EventEntity_DestinationIPv6:
destinationIP = big.NewInt(int64(ip.DestinationIPv6)).Bytes()
}
ev = Event{
ID: msg.GetId(),
Timestamp: msg.GetTimestamp().AsTime(),
Transport: msg.GetTransport(),
Application: msg.GetApplication(),
SourceIP: sourceIP,
DestinationIP: destinationIP,
SourcePort: uint16(msg.GetSourcePort()),
DestinationPort: uint16(msg.GetDestinationPort()),
ProtocolDetails: msg.GetProtocolDetails(),
TLS: NewTLSDetailsFromProto(msg.GetTls()),
}
return
}

View file

@ -5,7 +5,6 @@ import (
"time"
"github.com/bwmarrin/snowflake"
"github.com/imdario/mergo"
"gitlab.com/inetmock/inetmock/pkg/logging"
"go.uber.org/zap"
)
@ -56,10 +55,9 @@ func NewEventStream(logger logging.Logger, options ...EventStreamOption) (es Eve
}
func (e *eventStream) Emit(ev Event) {
defaultEvent := e.newDefaultEvent()
_ = mergo.Merge(defaultEvent, &ev)
ev.ApplyDefaults(e.idGenerator.Generate().Int64())
select {
case e.buffer <- defaultEvent:
case e.buffer <- &ev:
e.logger.Debug("pushed event to distribute loop")
default:
e.logger.Warn("buffer is full")
@ -107,11 +105,3 @@ func (e *eventStream) distribute() {
}
}
}
func (e *eventStream) newDefaultEvent() *Event {
id := e.idGenerator.Generate()
return &Event{
ID: id.Int64(),
Timestamp: time.Now().UTC(),
}
}

View file

@ -1,19 +1,54 @@
package audit_test
import (
"crypto/tls"
"net"
"net/http"
"sync"
"testing"
"time"
"gitlab.com/inetmock/inetmock/pkg/audit"
"gitlab.com/inetmock/inetmock/pkg/logging"
"google.golang.org/protobuf/types/known/anypb"
)
var (
defaultSink = &testSink{
name: "test defaultSink",
}
testEvents = []audit.Event{
{
Transport: audit.TransportProtocol_TCP,
Application: audit.AppProtocol_HTTP,
SourceIP: net.ParseIP("127.0.0.1"),
DestinationIP: net.ParseIP("127.0.0.1"),
SourcePort: 32344,
DestinationPort: 80,
TLS: &audit.TLSDetails{
Version: tls.VersionTLS13,
CipherSuite: tls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA,
ServerName: "localhost",
},
ProtocolDetails: mustMarshalToWireformat(audit.HTTPDetails{
Method: "GET",
Host: "localhost",
URI: "http://localhost/asdf",
Proto: "HTTP 1.1",
Headers: http.Header{
"Accept": []string{"application/json"},
},
}),
},
{
Transport: audit.TransportProtocol_TCP,
Application: audit.AppProtocol_DNS,
SourceIP: net.ParseIP("::1"),
DestinationIP: net.ParseIP("::1"),
SourcePort: 32344,
DestinationPort: 80,
},
}
)
type testSink struct {
@ -45,17 +80,25 @@ func wgMockSink(t testing.TB, wg *sync.WaitGroup) audit.Sink {
}
}
func mustMarshalToWireformat(d audit.Details) *anypb.Any {
any, err := d.MarshalToWireFormat()
if err != nil {
panic(err)
}
return any
}
func Test_eventStream_RegisterSink(t *testing.T) {
type args struct {
s audit.Sink
}
tests := []struct {
type testCase struct {
name string
args args
setup func(e audit.EventStream)
wantErr bool
}{
}
tests := []testCase{
{
name: "Register test defaultSink",
args: args{
@ -74,8 +117,8 @@ func Test_eventStream_RegisterSink(t *testing.T) {
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
scenario := func(tt testCase) func(t *testing.T) {
return func(t *testing.T) {
var err error
var e audit.EventStream
if e, err = audit.NewEventStream(logging.CreateTestLogger(t)); err != nil {
@ -103,7 +146,11 @@ func Test_eventStream_RegisterSink(t *testing.T) {
if !found {
t.Errorf("expected defaultSink name %s not found in registered sinks %v", tt.args.s.Name(), e.Sinks())
}
})
}
}
for _, tt := range tests {
t.Run(tt.name, scenario(tt))
}
}
@ -112,26 +159,18 @@ func Test_eventStream_Emit(t *testing.T) {
evs []audit.Event
opts []audit.EventStreamOption
}
tests := []struct {
type testCase struct {
name string
args args
subscribe bool
}{
}
tests := []testCase{
{
name: "Expect to get a single event",
subscribe: true,
args: args{
opts: []audit.EventStreamOption{audit.WithBufferSize(10)},
evs: []audit.Event{
{
Transport: audit.TransportProtocol_TCP,
Application: audit.AppProtocol_HTTP,
SourceIP: net.ParseIP("127.0.0.1"),
DestinationIP: net.ParseIP("127.0.0.1"),
SourcePort: 32344,
DestinationPort: 80,
},
},
evs: testEvents[:1],
},
},
{
@ -139,46 +178,21 @@ func Test_eventStream_Emit(t *testing.T) {
subscribe: true,
args: args{
opts: []audit.EventStreamOption{audit.WithBufferSize(10)},
evs: []audit.Event{
{
Transport: audit.TransportProtocol_TCP,
Application: audit.AppProtocol_HTTP,
SourceIP: net.ParseIP("127.0.0.1"),
DestinationIP: net.ParseIP("127.0.0.1"),
SourcePort: 32344,
DestinationPort: 80,
},
{
Transport: audit.TransportProtocol_TCP,
Application: audit.AppProtocol_DNS,
SourceIP: net.ParseIP("::1"),
DestinationIP: net.ParseIP("::1"),
SourcePort: 32344,
DestinationPort: 80,
},
},
evs: testEvents,
},
},
{
name: "Emit without subscribe sink",
args: args{
opts: []audit.EventStreamOption{audit.WithBufferSize(0)},
evs: []audit.Event{
{
Transport: audit.TransportProtocol_TCP,
Application: audit.AppProtocol_HTTP,
SourceIP: net.ParseIP("127.0.0.1"),
DestinationIP: net.ParseIP("127.0.0.1"),
SourcePort: 32344,
DestinationPort: 80,
},
},
evs: testEvents[:1],
},
subscribe: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
scenario := func(tt testCase) func(t *testing.T) {
return func(t *testing.T) {
var err error
var e audit.EventStream
if e, err = audit.NewEventStream(logging.CreateTestLogger(t), tt.args.opts...); err != nil {
@ -189,8 +203,8 @@ func Test_eventStream_Emit(t *testing.T) {
_ = e.Close()
})
emittedWaitGroup := &sync.WaitGroup{}
receivedWaitGroup := &sync.WaitGroup{}
emittedWaitGroup := new(sync.WaitGroup)
receivedWaitGroup := new(sync.WaitGroup)
emittedWaitGroup.Add(len(tt.args.evs))
@ -219,7 +233,11 @@ func Test_eventStream_Emit(t *testing.T) {
case <-time.After(5 * time.Second):
t.Errorf("did not get all expected events in time")
}
})
}
}
for _, tt := range tests {
t.Run(tt.name, scenario(tt))
}
}

View file

@ -4,7 +4,7 @@ import (
"net/http"
"strings"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/anypb"
)
type HTTPDetails struct {
@ -15,7 +15,7 @@ type HTTPDetails struct {
Headers http.Header
}
func (d HTTPDetails) ProtoMessage() proto.Message {
func (d HTTPDetails) MarshalToWireFormat() (any *anypb.Any, err error) {
var method = HTTPMethod_GET
if methodValue, known := HTTPMethod_value[strings.ToUpper(d.Method)]; known {
method = HTTPMethod(methodValue)
@ -29,11 +29,14 @@ func (d HTTPDetails) ProtoMessage() proto.Message {
}
}
return &HTTPDetailsEntity{
protoDetails := &HTTPDetailsEntity{
Method: method,
Host: d.Host,
Uri: d.URI,
Proto: d.Proto,
Headers: headers,
}
any, err = anypb.New(protoDetails)
return
}

53
pkg/audit/log_sink.go Normal file
View file

@ -0,0 +1,53 @@
package audit
import (
"crypto/tls"
"gitlab.com/inetmock/inetmock/pkg/logging"
"go.uber.org/zap"
)
const (
logSinkName = "logging"
)
func NewLogSink(logger logging.Logger) Sink {
return &logSink{
logger: logger,
}
}
type logSink struct {
logger logging.Logger
}
func (logSink) Name() string {
return logSinkName
}
func (l logSink) OnSubscribe(evs <-chan Event) {
go func(logger logging.Logger, evs <-chan Event) {
for ev := range evs {
eventLogger := logger
if ev.TLS != nil {
eventLogger = eventLogger.With(
zap.String("tls_server_name", ev.TLS.ServerName),
zap.String("tls_cipher_suite", tls.CipherSuiteName(ev.TLS.CipherSuite)),
)
}
eventLogger.Info(
"handled request",
zap.Time("timestamp", ev.Timestamp),
zap.String("application", ev.Application.String()),
zap.String("transport", ev.Transport.String()),
zap.String("source_ip", ev.SourceIP.String()),
zap.Uint16("source_port", ev.SourcePort),
zap.String("destination_ip", ev.DestinationIP.String()),
zap.Uint16("destination_port", ev.DestinationPort),
zap.Any("details", ev.ProtocolDetails),
)
}
}(l.logger, evs)
}

110
pkg/audit/log_sink_test.go Normal file
View file

@ -0,0 +1,110 @@
package audit_test
import (
"sync"
"testing"
"time"
"github.com/golang/mock/gomock"
logging_mock "gitlab.com/inetmock/inetmock/internal/mock/logging"
"gitlab.com/inetmock/inetmock/pkg/audit"
"gitlab.com/inetmock/inetmock/pkg/logging"
"go.uber.org/zap"
)
//nolint:dupl
func Test_logSink_OnSubscribe(t *testing.T) {
type fields struct {
loggerSetup func(t *testing.T, wg *sync.WaitGroup) logging.Logger
}
type testCase struct {
name string
fields fields
events []audit.Event
}
tests := []testCase{
{
name: "Get a single log line",
fields: fields{
loggerSetup: func(t *testing.T, wg *sync.WaitGroup) logging.Logger {
ctrl := gomock.NewController(t)
t.Cleanup(ctrl.Finish)
loggerMock := logging_mock.NewMockLogger(ctrl)
loggerMock.
EXPECT().
With(gomock.Any()).
Return(loggerMock)
loggerMock.
EXPECT().
Info("handled request", gomock.Any()).
Do(func(_ string, _ ...zap.Field) {
wg.Done()
}).
Times(1)
return loggerMock
},
},
events: testEvents[:1],
},
{
name: "Get multiple events",
fields: fields{
loggerSetup: func(t *testing.T, wg *sync.WaitGroup) logging.Logger {
ctrl := gomock.NewController(t)
t.Cleanup(ctrl.Finish)
loggerMock := logging_mock.NewMockLogger(ctrl)
loggerMock.
EXPECT().
With(gomock.Any()).
Return(loggerMock)
loggerMock.
EXPECT().
Info("handled request", gomock.Any()).
Do(func(_ string, _ ...zap.Field) {
wg.Done()
}).
Times(2)
return loggerMock
},
},
events: testEvents,
},
}
scenario := func(tt testCase) func(t *testing.T) {
return func(t *testing.T) {
wg := new(sync.WaitGroup)
wg.Add(len(tt.events))
logSink := audit.NewLogSink(tt.fields.loggerSetup(t, wg))
var evs audit.EventStream
var err error
if evs, err = audit.NewEventStream(logging.CreateTestLogger(t)); err != nil {
t.Errorf("NewEventStream() error = %v", err)
}
if err = evs.RegisterSink(logSink); err != nil {
t.Errorf("RegisterSink() error = %v", err)
}
for _, ev := range tt.events {
evs.Emit(ev)
}
select {
case <-time.After(100 * time.Millisecond):
t.Errorf("not all events recorded in time")
case <-waitGroupDone(wg):
}
}
}
for _, tt := range tests {
t.Run(tt.name, scenario(tt))
}
}

87
pkg/audit/reader.go Normal file
View file

@ -0,0 +1,87 @@
package audit
import (
"encoding/binary"
"io"
"sync"
"google.golang.org/protobuf/proto"
)
const (
lengthBufferSize = 4
defaultPayloadBufferSize = 4096
)
type Reader interface {
Read() (Event, error)
}
type EventReaderOption func(reader *eventReader)
var (
WithReaderByteOrder = func(order binary.ByteOrder) EventReaderOption {
return func(reader *eventReader) {
reader.byteOrder = order
}
}
)
func NewEventReader(source io.Reader, opts ...EventReaderOption) Reader {
reader := &eventReader{
source: source,
byteOrder: defaultOrder,
lengthPool: &sync.Pool{
New: func() interface{} {
return make([]byte, lengthBufferSize)
},
},
payloadPool: &sync.Pool{
New: func() interface{} {
return make([]byte, defaultPayloadBufferSize)
},
},
}
for _, opt := range opts {
opt(reader)
}
return reader
}
type eventReader struct {
lengthPool *sync.Pool
payloadPool *sync.Pool
byteOrder binary.ByteOrder
source io.Reader
}
func (e eventReader) Read() (ev Event, err error) {
lengthBuf := e.lengthPool.Get().([]byte)
if _, err = e.source.Read(lengthBuf); err != nil {
return
}
length := e.byteOrder.Uint32(lengthBuf)
var msgBuf []byte
if length <= defaultPayloadBufferSize {
msgBuf = e.payloadPool.Get().([]byte)[:length]
} else {
msgBuf = make([]byte, length)
}
if _, err = e.source.Read(msgBuf); err != nil {
return
}
protoEv := new(EventEntity)
if err = proto.Unmarshal(msgBuf, protoEv); err != nil {
return
}
ev = NewEventFromProto(protoEv)
return
}

97
pkg/audit/reader_test.go Normal file
View file

@ -0,0 +1,97 @@
package audit_test
import (
"bytes"
"encoding/binary"
"encoding/hex"
"io"
"testing"
"gitlab.com/inetmock/inetmock/pkg/audit"
"google.golang.org/protobuf/proto"
)
var (
//nolint:lll
httpPayloadBytesLittleEndian = `dd000000120b088092b8c398feffffff011801200248d8fc0150505a3308041224544c535f45434448455f45434453415f574954485f4145535f3235365f4342435f5348411a096c6f63616c686f73746282010a34747970652e676f6f676c65617069732e636f6d2f696e65746d6f636b2e61756469742e4854545044657461696c73456e74697479124a12096c6f63616c686f73741a15687474703a2f2f6c6f63616c686f73742f6173646622084854545020312e312a1c0a0641636365707412120a106170706c69636174696f6e2f6a736f6e28818080f80738818080f807`
//nolint:lll
httpPayloadBytesBigEndian = `000000dd120b088092b8c398feffffff011801200248d8fc0150505a3308041224544c535f45434448455f45434453415f574954485f4145535f3235365f4342435f5348411a096c6f63616c686f73746282010a34747970652e676f6f676c65617069732e636f6d2f696e65746d6f636b2e61756469742e4854545044657461696c73456e74697479124a12096c6f63616c686f73741a15687474703a2f2f6c6f63616c686f73742f6173646622084854545020312e312a1c0a0641636365707412120a106170706c69636174696f6e2f6a736f6e28818080f80738818080f807`
dnsPayloadBytesLittleEndian = `1b000000120b088092b8c398feffffff011801200148d8fc01505030014001`
dnsPayloadBytesBigEndian = `0000001b120b088092b8c398feffffff011801200148d8fc01505030014001`
)
func mustDecodeHex(hexBytes string) io.Reader {
b, err := hex.DecodeString(hexBytes)
if err != nil {
panic(err)
}
return bytes.NewReader(b)
}
func Test_eventReader_Read(t *testing.T) {
type fields struct {
source io.Reader
order binary.ByteOrder
}
type testCase struct {
name string
fields fields
wantEv audit.Event
wantErr bool
}
tests := []testCase{
{
name: "Read HTTP payload - little endian",
fields: fields{
source: mustDecodeHex(httpPayloadBytesLittleEndian),
order: binary.LittleEndian,
},
wantEv: testEvents[0],
wantErr: false,
},
{
name: "Read HTTP payload - big endian",
fields: fields{
source: mustDecodeHex(httpPayloadBytesBigEndian),
order: binary.BigEndian,
},
wantEv: testEvents[0],
wantErr: false,
},
{
name: "Read DNS payload - little endian",
fields: fields{
source: mustDecodeHex(dnsPayloadBytesLittleEndian),
order: binary.LittleEndian,
},
wantEv: testEvents[1],
wantErr: false,
},
{
name: "Read DNS payload - big endian",
fields: fields{
source: mustDecodeHex(dnsPayloadBytesBigEndian),
order: binary.BigEndian,
},
wantEv: testEvents[1],
wantErr: false,
},
}
scenario := func(tt testCase) func(t *testing.T) {
return func(t *testing.T) {
e := audit.NewEventReader(tt.fields.source, audit.WithReaderByteOrder(tt.fields.order))
gotEv, err := e.Read()
if (err != nil) != tt.wantErr {
t.Errorf("Read() error = %v, wantErr %v", err, tt.wantErr)
return
}
if err == nil && !proto.Equal(gotEv.ProtoMessage(), tt.wantEv.ProtoMessage()) {
t.Errorf("Read() gotEv = %v, want %v", gotEv, tt.wantEv)
}
}
}
for _, tt := range tests {
t.Run(tt.name, scenario(tt))
}
}

View file

@ -2,8 +2,31 @@ package audit
import (
"crypto/tls"
)
"google.golang.org/protobuf/proto"
var (
tlsToEntity = map[uint16]TLSVersion{
tls.VersionSSL30: TLSVersion_SSLv30,
tls.VersionTLS10: TLSVersion_TLS10,
tls.VersionTLS11: TLSVersion_TLS11,
tls.VersionTLS12: TLSVersion_TLS12,
tls.VersionTLS13: TLSVersion_TLS13,
}
entityToTls = map[TLSVersion]uint16{
TLSVersion_SSLv30: tls.VersionSSL30,
TLSVersion_TLS10: tls.VersionTLS10,
TLSVersion_TLS11: tls.VersionTLS11,
TLSVersion_TLS12: tls.VersionTLS12,
TLSVersion_TLS13: tls.VersionTLS13,
}
cipherSuiteIDLookup = func(name string) uint16 {
for _, cs := range tls.CipherSuites() {
if cs.Name == name {
return cs.ID
}
}
return 0
}
)
type TLSDetails struct {
@ -12,24 +35,21 @@ type TLSDetails struct {
ServerName string
}
func (d TLSDetails) ProtoMessage() proto.Message {
var version TLSVersion
switch d.Version {
case tls.VersionTLS10:
version = TLSVersion_TLS10
case tls.VersionTLS11:
version = TLSVersion_TLS11
case tls.VersionTLS12:
version = TLSVersion_TLS12
case tls.VersionTLS13:
version = TLSVersion_TLS13
default:
version = TLSVersion_SSLv30
func NewTLSDetailsFromProto(entity *TLSDetailsEntity) *TLSDetails {
if entity == nil {
return nil
}
return &TLSDetails{
Version: entityToTls[entity.GetVersion()],
CipherSuite: cipherSuiteIDLookup(entity.GetCipherSuite()),
ServerName: entity.GetServerName(),
}
}
func (d TLSDetails) ProtoMessage() *TLSDetailsEntity {
return &TLSDetailsEntity{
Version: version,
Version: tlsToEntity[d.Version],
CipherSuite: tls.CipherSuiteName(d.CipherSuite),
ServerName: d.ServerName,
}

85
pkg/audit/writer.go Normal file
View file

@ -0,0 +1,85 @@
//go:generate mockgen -source=$GOFILE -destination=./../../internal/mock/audit/writer.mock.go -package=audit_mock
package audit
import (
"encoding/binary"
"errors"
"io"
"sync"
"google.golang.org/protobuf/proto"
)
var (
WithWriterByteOrder = func(order binary.ByteOrder) func(writer *eventWriter) {
return func(writer *eventWriter) {
writer.byteOrder = order
}
}
defaultOrder = binary.BigEndian
ErrValueMostNotBeNil = errors.New("event value must not be nil")
)
type Writer interface {
io.Closer
Write(ev *Event) error
}
type EventWriterOption func(writer *eventWriter)
func NewEventWriter(target io.Writer, opts ...EventWriterOption) Writer {
writer := &eventWriter{
target: target,
byteOrder: defaultOrder,
lengthPool: &sync.Pool{
New: func() interface{} {
return make([]byte, lengthBufferSize)
},
},
}
for _, opt := range opts {
opt(writer)
}
return writer
}
type eventWriter struct {
lengthPool *sync.Pool
target io.Writer
byteOrder binary.ByteOrder
}
func (e eventWriter) Write(ev *Event) (err error) {
if ev == nil {
return ErrValueMostNotBeNil
}
var bytes []byte
if bytes, err = proto.Marshal(ev.ProtoMessage()); err != nil {
return
}
buf := e.lengthPool.Get().([]byte)
e.byteOrder.PutUint32(buf, uint32(len(bytes)))
if _, err = e.target.Write(buf); err != nil {
return
}
if _, err = e.target.Write(bytes); err != nil {
return
}
if syncerInst, ok := e.target.(syncer); ok {
err = syncerInst.Sync()
}
return
}
func (e eventWriter) Close() error {
if closer, isCloser := e.target.(io.Closer); isCloser {
return closer.Close()
}
return nil
}

47
pkg/audit/writer_sink.go Normal file
View file

@ -0,0 +1,47 @@
package audit
type WriterSinkOption func(sink *writerCloserSink)
var (
WithCloseOnExit WriterSinkOption = func(sink *writerCloserSink) {
sink.closeOnExit = true
}
)
func NewWriterSink(name string, target Writer, opts ...WriterSinkOption) Sink {
sink := &writerCloserSink{
name: name,
target: target,
}
for _, opt := range opts {
opt(sink)
}
return sink
}
type writerCloserSink struct {
name string
target Writer
closeOnExit bool
}
type syncer interface {
Sync() error
}
func (f writerCloserSink) Name() string {
return f.name
}
func (f writerCloserSink) OnSubscribe(evs <-chan Event) {
go func(target Writer, closeOnExit bool, evs <-chan Event) {
for ev := range evs {
_ = target.Write(&ev)
}
if closeOnExit {
_ = target.Close()
}
}(f.target, f.closeOnExit, evs)
}

View file

@ -0,0 +1,72 @@
package audit_test
import (
"sync"
"testing"
"time"
"github.com/golang/mock/gomock"
audit_mock "gitlab.com/inetmock/inetmock/internal/mock/audit"
"gitlab.com/inetmock/inetmock/pkg/audit"
"gitlab.com/inetmock/inetmock/pkg/logging"
)
func Test_writerCloserSink_OnSubscribe(t *testing.T) {
type testCase struct {
name string
events []audit.Event
}
tests := []testCase{
{
name: "Get a single event",
events: testEvents[:1],
},
{
name: "Get multiple events",
events: testEvents,
},
}
scenario := func(tt testCase) func(t *testing.T) {
return func(t *testing.T) {
wg := new(sync.WaitGroup)
wg.Add(len(tt.events))
ctrl := gomock.NewController(t)
t.Cleanup(ctrl.Finish)
writerMock := audit_mock.NewMockWriter(ctrl)
writerMock.
EXPECT().
Write(gomock.Any()).
Do(func(_ *audit.Event) {
wg.Done()
}).
Times(len(tt.events))
writerCloserSink := audit.NewWriterSink("WriterMock", writerMock, audit.WithCloseOnExit)
var evs audit.EventStream
var err error
if evs, err = audit.NewEventStream(logging.CreateTestLogger(t)); err != nil {
t.Errorf("NewEventStream() error = %v", err)
}
if err = evs.RegisterSink(writerCloserSink); err != nil {
t.Errorf("RegisterSink() error = %v", err)
}
for _, ev := range tt.events {
evs.Emit(ev)
}
select {
case <-time.After(100 * time.Millisecond):
t.Errorf("not all events recorded in time")
case <-waitGroupDone(wg):
}
}
}
for _, tt := range tests {
t.Run(tt.name, scenario(tt))
}
}

117
pkg/audit/writer_test.go Normal file
View file

@ -0,0 +1,117 @@
//go:generate mockgen -source=$GOFILE -destination=./../../internal/mock/audit/writercloser.mock.go -package=audit_mock
package audit_test
import (
"encoding/binary"
"encoding/hex"
"io"
"testing"
"github.com/golang/mock/gomock"
audit_mock "gitlab.com/inetmock/inetmock/internal/mock/audit"
"gitlab.com/inetmock/inetmock/pkg/audit"
)
type WriterCloserSyncer interface {
io.WriteCloser
Sync() error
}
func Test_eventWriter_Write(t *testing.T) {
type fields struct {
order binary.ByteOrder
}
type args struct {
evs []audit.Event
}
type testCase struct {
name string
fields fields
args args
wantErr bool
}
tests := []testCase{
{
name: "Write a single event - little endian",
fields: fields{
order: binary.LittleEndian,
},
args: args{
evs: testEvents[:1],
},
wantErr: false,
},
{
name: "Write a single event - big endian",
fields: fields{
order: binary.BigEndian,
},
args: args{
evs: testEvents[:1],
},
wantErr: false,
},
{
name: "Write multiple events - little endian",
fields: fields{
order: binary.LittleEndian,
},
args: args{
evs: testEvents,
},
wantErr: false,
},
{
name: "Write multiple events - big endian",
fields: fields{
order: binary.BigEndian,
},
args: args{
evs: testEvents,
},
wantErr: false,
},
}
scenario := func(tt testCase) func(t *testing.T) {
return func(t *testing.T) {
ctrl := gomock.NewController(t)
t.Cleanup(ctrl.Finish)
writerMock := audit_mock.NewMockWriterCloserSyncer(ctrl)
calls := make([]*gomock.Call, 0)
for i := 0; i < len(tt.args.evs); i++ {
calls = append(calls,
writerMock.
EXPECT().
Write(gomock.Any()).
Do(func(data []byte) {
t.Logf("got payload = %s", hex.EncodeToString(data))
t.Logf("got length %d", tt.fields.order.Uint32(data))
}),
writerMock.
EXPECT().
Write(gomock.Any()).
Do(func(data []byte) {
t.Logf("got payload = %s", hex.EncodeToString(data))
}),
writerMock.
EXPECT().
Sync(),
)
}
gomock.InOrder(calls...)
e := audit.NewEventWriter(writerMock, audit.WithWriterByteOrder(tt.fields.order))
for _, ev := range tt.args.evs {
if err := e.Write(&ev); (err != nil) != tt.wantErr {
t.Errorf("Write() error = %v, wantErr %v", err, tt.wantErr)
}
}
}
}
for _, tt := range tests {
t.Run(tt.name, scenario(tt))
}
}

View file

@ -40,7 +40,7 @@ func AddTLSInterceptor(registry api.HandlerRegistry) (err error) {
registry.RegisterHandler(name, func() api.ProtocolHandler {
return &tlsInterceptor{
logger: logger,
currentConnectionsCount: &sync.WaitGroup{},
currentConnectionsCount: new(sync.WaitGroup),
currentConnections: make(map[uuid.UUID]*proxyConn),
connectionsMutex: &sync.Mutex{},
}