Implement log and writer sinks
Add reader
This commit is contained in:
parent
1e8139e845
commit
63607dfd67
15 changed files with 883 additions and 88 deletions
|
@ -99,7 +99,7 @@ func (e *endpointManager) StartEndpoints() {
|
|||
}
|
||||
|
||||
func (e *endpointManager) ShutdownEndpoints() {
|
||||
var waitGroup sync.WaitGroup
|
||||
waitGroup := new(sync.WaitGroup)
|
||||
waitGroup.Add(len(e.properlyStartedEndpoints))
|
||||
|
||||
parentCtx, _ := context.WithTimeout(context.Background(), shutdownTimeout)
|
||||
|
@ -112,7 +112,7 @@ func (e *endpointManager) ShutdownEndpoints() {
|
|||
zap.String("endpoint", endpoint.Name()),
|
||||
)
|
||||
endpointLogger.Info("Triggering shutdown of endpoint")
|
||||
go shutdownEndpoint(ctx, endpoint, endpointLogger, &waitGroup)
|
||||
go shutdownEndpoint(ctx, endpoint, endpointLogger, waitGroup)
|
||||
}
|
||||
|
||||
waitGroup.Wait()
|
||||
|
|
|
@ -1,14 +1,18 @@
|
|||
package audit
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"math/big"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"google.golang.org/protobuf/proto"
|
||||
"google.golang.org/protobuf/types/known/anypb"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
)
|
||||
|
||||
type EventDetails interface {
|
||||
ProtoMessage() proto.Message
|
||||
type Details interface {
|
||||
MarshalToWireFormat() (*anypb.Any, error)
|
||||
}
|
||||
|
||||
type Event struct {
|
||||
|
@ -20,6 +24,98 @@ type Event struct {
|
|||
DestinationIP net.IP
|
||||
SourcePort uint16
|
||||
DestinationPort uint16
|
||||
ProtocolDetails EventDetails
|
||||
ProtocolDetails *anypb.Any
|
||||
TLS *TLSDetails
|
||||
}
|
||||
|
||||
func (e *Event) ProtoMessage() proto.Message {
|
||||
var sourceIP isEventEntity_SourceIP
|
||||
if ipv4 := e.SourceIP.To4(); ipv4 != nil {
|
||||
if len(ipv4) == 16 {
|
||||
sourceIP = &EventEntity_SourceIPv4{SourceIPv4: binary.BigEndian.Uint32(ipv4[12:16])}
|
||||
} else {
|
||||
sourceIP = &EventEntity_SourceIPv4{SourceIPv4: binary.BigEndian.Uint32(ipv4)}
|
||||
}
|
||||
} else {
|
||||
ipv6 := big.NewInt(0)
|
||||
ipv6.SetBytes(e.SourceIP)
|
||||
sourceIP = &EventEntity_SourceIPv6{SourceIPv6: ipv6.Uint64()}
|
||||
}
|
||||
|
||||
var destinationIP isEventEntity_DestinationIP
|
||||
if ipv4 := e.DestinationIP.To4(); ipv4 != nil {
|
||||
if len(ipv4) == 16 {
|
||||
destinationIP = &EventEntity_DestinationIPv4{DestinationIPv4: binary.BigEndian.Uint32(ipv4[12:16])}
|
||||
} else {
|
||||
destinationIP = &EventEntity_DestinationIPv4{DestinationIPv4: binary.BigEndian.Uint32(ipv4)}
|
||||
}
|
||||
} else {
|
||||
ipv6 := big.NewInt(0)
|
||||
ipv6.SetBytes(e.SourceIP)
|
||||
destinationIP = &EventEntity_DestinationIPv6{DestinationIPv6: ipv6.Uint64()}
|
||||
}
|
||||
|
||||
var tlsDetails *TLSDetailsEntity = nil
|
||||
if e.TLS != nil {
|
||||
tlsDetails = e.TLS.ProtoMessage()
|
||||
}
|
||||
|
||||
return &EventEntity{
|
||||
Id: e.ID,
|
||||
Timestamp: timestamppb.New(e.Timestamp),
|
||||
Transport: e.Transport,
|
||||
Application: e.Application,
|
||||
SourceIP: sourceIP,
|
||||
DestinationIP: destinationIP,
|
||||
SourcePort: uint32(e.SourcePort),
|
||||
DestinationPort: uint32(e.DestinationPort),
|
||||
Tls: tlsDetails,
|
||||
ProtocolDetails: e.ProtocolDetails,
|
||||
}
|
||||
}
|
||||
|
||||
func (e *Event) ApplyDefaults(id int64) {
|
||||
e.ID = id
|
||||
emptyTime := time.Time{}
|
||||
if e.Timestamp == emptyTime {
|
||||
e.Timestamp = time.Now().UTC()
|
||||
}
|
||||
}
|
||||
|
||||
func NewEventFromProto(msg *EventEntity) (ev Event) {
|
||||
var sourceIP net.IP
|
||||
switch ip := msg.GetSourceIP().(type) {
|
||||
case *EventEntity_SourceIPv4:
|
||||
buf := make([]byte, 4)
|
||||
binary.BigEndian.PutUint32(buf, ip.SourceIPv4)
|
||||
sourceIP = buf
|
||||
sourceIP = sourceIP.To4()
|
||||
case *EventEntity_SourceIPv6:
|
||||
sourceIP = big.NewInt(int64(ip.SourceIPv6)).Bytes()
|
||||
}
|
||||
|
||||
var destinationIP net.IP
|
||||
switch ip := msg.GetDestinationIP().(type) {
|
||||
case *EventEntity_DestinationIPv4:
|
||||
buf := make([]byte, 4)
|
||||
binary.BigEndian.PutUint32(buf, ip.DestinationIPv4)
|
||||
destinationIP = buf
|
||||
destinationIP = destinationIP.To4()
|
||||
case *EventEntity_DestinationIPv6:
|
||||
destinationIP = big.NewInt(int64(ip.DestinationIPv6)).Bytes()
|
||||
}
|
||||
|
||||
ev = Event{
|
||||
ID: msg.GetId(),
|
||||
Timestamp: msg.GetTimestamp().AsTime(),
|
||||
Transport: msg.GetTransport(),
|
||||
Application: msg.GetApplication(),
|
||||
SourceIP: sourceIP,
|
||||
DestinationIP: destinationIP,
|
||||
SourcePort: uint16(msg.GetSourcePort()),
|
||||
DestinationPort: uint16(msg.GetDestinationPort()),
|
||||
ProtocolDetails: msg.GetProtocolDetails(),
|
||||
TLS: NewTLSDetailsFromProto(msg.GetTls()),
|
||||
}
|
||||
return
|
||||
}
|
||||
|
|
|
@ -5,7 +5,6 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/bwmarrin/snowflake"
|
||||
"github.com/imdario/mergo"
|
||||
"gitlab.com/inetmock/inetmock/pkg/logging"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
@ -56,10 +55,9 @@ func NewEventStream(logger logging.Logger, options ...EventStreamOption) (es Eve
|
|||
}
|
||||
|
||||
func (e *eventStream) Emit(ev Event) {
|
||||
defaultEvent := e.newDefaultEvent()
|
||||
_ = mergo.Merge(defaultEvent, &ev)
|
||||
ev.ApplyDefaults(e.idGenerator.Generate().Int64())
|
||||
select {
|
||||
case e.buffer <- defaultEvent:
|
||||
case e.buffer <- &ev:
|
||||
e.logger.Debug("pushed event to distribute loop")
|
||||
default:
|
||||
e.logger.Warn("buffer is full")
|
||||
|
@ -107,11 +105,3 @@ func (e *eventStream) distribute() {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (e *eventStream) newDefaultEvent() *Event {
|
||||
id := e.idGenerator.Generate()
|
||||
return &Event{
|
||||
ID: id.Int64(),
|
||||
Timestamp: time.Now().UTC(),
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,19 +1,54 @@
|
|||
package audit_test
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"net"
|
||||
"net/http"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"gitlab.com/inetmock/inetmock/pkg/audit"
|
||||
"gitlab.com/inetmock/inetmock/pkg/logging"
|
||||
"google.golang.org/protobuf/types/known/anypb"
|
||||
)
|
||||
|
||||
var (
|
||||
defaultSink = &testSink{
|
||||
name: "test defaultSink",
|
||||
}
|
||||
testEvents = []audit.Event{
|
||||
{
|
||||
Transport: audit.TransportProtocol_TCP,
|
||||
Application: audit.AppProtocol_HTTP,
|
||||
SourceIP: net.ParseIP("127.0.0.1"),
|
||||
DestinationIP: net.ParseIP("127.0.0.1"),
|
||||
SourcePort: 32344,
|
||||
DestinationPort: 80,
|
||||
TLS: &audit.TLSDetails{
|
||||
Version: tls.VersionTLS13,
|
||||
CipherSuite: tls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA,
|
||||
ServerName: "localhost",
|
||||
},
|
||||
ProtocolDetails: mustMarshalToWireformat(audit.HTTPDetails{
|
||||
Method: "GET",
|
||||
Host: "localhost",
|
||||
URI: "http://localhost/asdf",
|
||||
Proto: "HTTP 1.1",
|
||||
Headers: http.Header{
|
||||
"Accept": []string{"application/json"},
|
||||
},
|
||||
}),
|
||||
},
|
||||
{
|
||||
Transport: audit.TransportProtocol_TCP,
|
||||
Application: audit.AppProtocol_DNS,
|
||||
SourceIP: net.ParseIP("::1"),
|
||||
DestinationIP: net.ParseIP("::1"),
|
||||
SourcePort: 32344,
|
||||
DestinationPort: 80,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
type testSink struct {
|
||||
|
@ -45,17 +80,25 @@ func wgMockSink(t testing.TB, wg *sync.WaitGroup) audit.Sink {
|
|||
}
|
||||
}
|
||||
|
||||
func mustMarshalToWireformat(d audit.Details) *anypb.Any {
|
||||
any, err := d.MarshalToWireFormat()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return any
|
||||
}
|
||||
|
||||
func Test_eventStream_RegisterSink(t *testing.T) {
|
||||
type args struct {
|
||||
s audit.Sink
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
type testCase struct {
|
||||
name string
|
||||
args args
|
||||
setup func(e audit.EventStream)
|
||||
wantErr bool
|
||||
}{
|
||||
}
|
||||
tests := []testCase{
|
||||
{
|
||||
name: "Register test defaultSink",
|
||||
args: args{
|
||||
|
@ -74,8 +117,8 @@ func Test_eventStream_RegisterSink(t *testing.T) {
|
|||
wantErr: true,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
scenario := func(tt testCase) func(t *testing.T) {
|
||||
return func(t *testing.T) {
|
||||
var err error
|
||||
var e audit.EventStream
|
||||
if e, err = audit.NewEventStream(logging.CreateTestLogger(t)); err != nil {
|
||||
|
@ -103,7 +146,11 @@ func Test_eventStream_RegisterSink(t *testing.T) {
|
|||
if !found {
|
||||
t.Errorf("expected defaultSink name %s not found in registered sinks %v", tt.args.s.Name(), e.Sinks())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, scenario(tt))
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -112,26 +159,18 @@ func Test_eventStream_Emit(t *testing.T) {
|
|||
evs []audit.Event
|
||||
opts []audit.EventStreamOption
|
||||
}
|
||||
tests := []struct {
|
||||
type testCase struct {
|
||||
name string
|
||||
args args
|
||||
subscribe bool
|
||||
}{
|
||||
}
|
||||
tests := []testCase{
|
||||
{
|
||||
name: "Expect to get a single event",
|
||||
subscribe: true,
|
||||
args: args{
|
||||
opts: []audit.EventStreamOption{audit.WithBufferSize(10)},
|
||||
evs: []audit.Event{
|
||||
{
|
||||
Transport: audit.TransportProtocol_TCP,
|
||||
Application: audit.AppProtocol_HTTP,
|
||||
SourceIP: net.ParseIP("127.0.0.1"),
|
||||
DestinationIP: net.ParseIP("127.0.0.1"),
|
||||
SourcePort: 32344,
|
||||
DestinationPort: 80,
|
||||
},
|
||||
},
|
||||
evs: testEvents[:1],
|
||||
},
|
||||
},
|
||||
{
|
||||
|
@ -139,46 +178,21 @@ func Test_eventStream_Emit(t *testing.T) {
|
|||
subscribe: true,
|
||||
args: args{
|
||||
opts: []audit.EventStreamOption{audit.WithBufferSize(10)},
|
||||
evs: []audit.Event{
|
||||
{
|
||||
Transport: audit.TransportProtocol_TCP,
|
||||
Application: audit.AppProtocol_HTTP,
|
||||
SourceIP: net.ParseIP("127.0.0.1"),
|
||||
DestinationIP: net.ParseIP("127.0.0.1"),
|
||||
SourcePort: 32344,
|
||||
DestinationPort: 80,
|
||||
},
|
||||
{
|
||||
Transport: audit.TransportProtocol_TCP,
|
||||
Application: audit.AppProtocol_DNS,
|
||||
SourceIP: net.ParseIP("::1"),
|
||||
DestinationIP: net.ParseIP("::1"),
|
||||
SourcePort: 32344,
|
||||
DestinationPort: 80,
|
||||
},
|
||||
},
|
||||
evs: testEvents,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Emit without subscribe sink",
|
||||
args: args{
|
||||
opts: []audit.EventStreamOption{audit.WithBufferSize(0)},
|
||||
evs: []audit.Event{
|
||||
{
|
||||
Transport: audit.TransportProtocol_TCP,
|
||||
Application: audit.AppProtocol_HTTP,
|
||||
SourceIP: net.ParseIP("127.0.0.1"),
|
||||
DestinationIP: net.ParseIP("127.0.0.1"),
|
||||
SourcePort: 32344,
|
||||
DestinationPort: 80,
|
||||
},
|
||||
},
|
||||
evs: testEvents[:1],
|
||||
},
|
||||
subscribe: false,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
|
||||
scenario := func(tt testCase) func(t *testing.T) {
|
||||
return func(t *testing.T) {
|
||||
var err error
|
||||
var e audit.EventStream
|
||||
if e, err = audit.NewEventStream(logging.CreateTestLogger(t), tt.args.opts...); err != nil {
|
||||
|
@ -189,8 +203,8 @@ func Test_eventStream_Emit(t *testing.T) {
|
|||
_ = e.Close()
|
||||
})
|
||||
|
||||
emittedWaitGroup := &sync.WaitGroup{}
|
||||
receivedWaitGroup := &sync.WaitGroup{}
|
||||
emittedWaitGroup := new(sync.WaitGroup)
|
||||
receivedWaitGroup := new(sync.WaitGroup)
|
||||
|
||||
emittedWaitGroup.Add(len(tt.args.evs))
|
||||
|
||||
|
@ -219,7 +233,11 @@ func Test_eventStream_Emit(t *testing.T) {
|
|||
case <-time.After(5 * time.Second):
|
||||
t.Errorf("did not get all expected events in time")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, scenario(tt))
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -4,7 +4,7 @@ import (
|
|||
"net/http"
|
||||
"strings"
|
||||
|
||||
"google.golang.org/protobuf/proto"
|
||||
"google.golang.org/protobuf/types/known/anypb"
|
||||
)
|
||||
|
||||
type HTTPDetails struct {
|
||||
|
@ -15,7 +15,7 @@ type HTTPDetails struct {
|
|||
Headers http.Header
|
||||
}
|
||||
|
||||
func (d HTTPDetails) ProtoMessage() proto.Message {
|
||||
func (d HTTPDetails) MarshalToWireFormat() (any *anypb.Any, err error) {
|
||||
var method = HTTPMethod_GET
|
||||
if methodValue, known := HTTPMethod_value[strings.ToUpper(d.Method)]; known {
|
||||
method = HTTPMethod(methodValue)
|
||||
|
@ -29,11 +29,14 @@ func (d HTTPDetails) ProtoMessage() proto.Message {
|
|||
}
|
||||
}
|
||||
|
||||
return &HTTPDetailsEntity{
|
||||
protoDetails := &HTTPDetailsEntity{
|
||||
Method: method,
|
||||
Host: d.Host,
|
||||
Uri: d.URI,
|
||||
Proto: d.Proto,
|
||||
Headers: headers,
|
||||
}
|
||||
|
||||
any, err = anypb.New(protoDetails)
|
||||
return
|
||||
}
|
||||
|
|
53
pkg/audit/log_sink.go
Normal file
53
pkg/audit/log_sink.go
Normal file
|
@ -0,0 +1,53 @@
|
|||
package audit
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
|
||||
"gitlab.com/inetmock/inetmock/pkg/logging"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
const (
|
||||
logSinkName = "logging"
|
||||
)
|
||||
|
||||
func NewLogSink(logger logging.Logger) Sink {
|
||||
return &logSink{
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
type logSink struct {
|
||||
logger logging.Logger
|
||||
}
|
||||
|
||||
func (logSink) Name() string {
|
||||
return logSinkName
|
||||
}
|
||||
|
||||
func (l logSink) OnSubscribe(evs <-chan Event) {
|
||||
go func(logger logging.Logger, evs <-chan Event) {
|
||||
for ev := range evs {
|
||||
eventLogger := logger
|
||||
|
||||
if ev.TLS != nil {
|
||||
eventLogger = eventLogger.With(
|
||||
zap.String("tls_server_name", ev.TLS.ServerName),
|
||||
zap.String("tls_cipher_suite", tls.CipherSuiteName(ev.TLS.CipherSuite)),
|
||||
)
|
||||
}
|
||||
|
||||
eventLogger.Info(
|
||||
"handled request",
|
||||
zap.Time("timestamp", ev.Timestamp),
|
||||
zap.String("application", ev.Application.String()),
|
||||
zap.String("transport", ev.Transport.String()),
|
||||
zap.String("source_ip", ev.SourceIP.String()),
|
||||
zap.Uint16("source_port", ev.SourcePort),
|
||||
zap.String("destination_ip", ev.DestinationIP.String()),
|
||||
zap.Uint16("destination_port", ev.DestinationPort),
|
||||
zap.Any("details", ev.ProtocolDetails),
|
||||
)
|
||||
}
|
||||
}(l.logger, evs)
|
||||
}
|
110
pkg/audit/log_sink_test.go
Normal file
110
pkg/audit/log_sink_test.go
Normal file
|
@ -0,0 +1,110 @@
|
|||
package audit_test
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
logging_mock "gitlab.com/inetmock/inetmock/internal/mock/logging"
|
||||
"gitlab.com/inetmock/inetmock/pkg/audit"
|
||||
"gitlab.com/inetmock/inetmock/pkg/logging"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
//nolint:dupl
|
||||
func Test_logSink_OnSubscribe(t *testing.T) {
|
||||
type fields struct {
|
||||
loggerSetup func(t *testing.T, wg *sync.WaitGroup) logging.Logger
|
||||
}
|
||||
type testCase struct {
|
||||
name string
|
||||
fields fields
|
||||
events []audit.Event
|
||||
}
|
||||
tests := []testCase{
|
||||
{
|
||||
name: "Get a single log line",
|
||||
fields: fields{
|
||||
loggerSetup: func(t *testing.T, wg *sync.WaitGroup) logging.Logger {
|
||||
ctrl := gomock.NewController(t)
|
||||
t.Cleanup(ctrl.Finish)
|
||||
loggerMock := logging_mock.NewMockLogger(ctrl)
|
||||
|
||||
loggerMock.
|
||||
EXPECT().
|
||||
With(gomock.Any()).
|
||||
Return(loggerMock)
|
||||
|
||||
loggerMock.
|
||||
EXPECT().
|
||||
Info("handled request", gomock.Any()).
|
||||
Do(func(_ string, _ ...zap.Field) {
|
||||
wg.Done()
|
||||
}).
|
||||
Times(1)
|
||||
|
||||
return loggerMock
|
||||
},
|
||||
},
|
||||
events: testEvents[:1],
|
||||
},
|
||||
{
|
||||
name: "Get multiple events",
|
||||
fields: fields{
|
||||
loggerSetup: func(t *testing.T, wg *sync.WaitGroup) logging.Logger {
|
||||
ctrl := gomock.NewController(t)
|
||||
t.Cleanup(ctrl.Finish)
|
||||
loggerMock := logging_mock.NewMockLogger(ctrl)
|
||||
|
||||
loggerMock.
|
||||
EXPECT().
|
||||
With(gomock.Any()).
|
||||
Return(loggerMock)
|
||||
|
||||
loggerMock.
|
||||
EXPECT().
|
||||
Info("handled request", gomock.Any()).
|
||||
Do(func(_ string, _ ...zap.Field) {
|
||||
wg.Done()
|
||||
}).
|
||||
Times(2)
|
||||
|
||||
return loggerMock
|
||||
},
|
||||
},
|
||||
events: testEvents,
|
||||
},
|
||||
}
|
||||
scenario := func(tt testCase) func(t *testing.T) {
|
||||
return func(t *testing.T) {
|
||||
wg := new(sync.WaitGroup)
|
||||
wg.Add(len(tt.events))
|
||||
|
||||
logSink := audit.NewLogSink(tt.fields.loggerSetup(t, wg))
|
||||
var evs audit.EventStream
|
||||
var err error
|
||||
if evs, err = audit.NewEventStream(logging.CreateTestLogger(t)); err != nil {
|
||||
t.Errorf("NewEventStream() error = %v", err)
|
||||
}
|
||||
|
||||
if err = evs.RegisterSink(logSink); err != nil {
|
||||
t.Errorf("RegisterSink() error = %v", err)
|
||||
}
|
||||
|
||||
for _, ev := range tt.events {
|
||||
evs.Emit(ev)
|
||||
}
|
||||
|
||||
select {
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
t.Errorf("not all events recorded in time")
|
||||
case <-waitGroupDone(wg):
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, scenario(tt))
|
||||
}
|
||||
}
|
87
pkg/audit/reader.go
Normal file
87
pkg/audit/reader.go
Normal file
|
@ -0,0 +1,87 @@
|
|||
package audit
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"io"
|
||||
"sync"
|
||||
|
||||
"google.golang.org/protobuf/proto"
|
||||
)
|
||||
|
||||
const (
|
||||
lengthBufferSize = 4
|
||||
defaultPayloadBufferSize = 4096
|
||||
)
|
||||
|
||||
type Reader interface {
|
||||
Read() (Event, error)
|
||||
}
|
||||
|
||||
type EventReaderOption func(reader *eventReader)
|
||||
|
||||
var (
|
||||
WithReaderByteOrder = func(order binary.ByteOrder) EventReaderOption {
|
||||
return func(reader *eventReader) {
|
||||
reader.byteOrder = order
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
func NewEventReader(source io.Reader, opts ...EventReaderOption) Reader {
|
||||
reader := &eventReader{
|
||||
source: source,
|
||||
byteOrder: defaultOrder,
|
||||
lengthPool: &sync.Pool{
|
||||
New: func() interface{} {
|
||||
return make([]byte, lengthBufferSize)
|
||||
},
|
||||
},
|
||||
payloadPool: &sync.Pool{
|
||||
New: func() interface{} {
|
||||
return make([]byte, defaultPayloadBufferSize)
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, opt := range opts {
|
||||
opt(reader)
|
||||
}
|
||||
|
||||
return reader
|
||||
}
|
||||
|
||||
type eventReader struct {
|
||||
lengthPool *sync.Pool
|
||||
payloadPool *sync.Pool
|
||||
byteOrder binary.ByteOrder
|
||||
source io.Reader
|
||||
}
|
||||
|
||||
func (e eventReader) Read() (ev Event, err error) {
|
||||
lengthBuf := e.lengthPool.Get().([]byte)
|
||||
if _, err = e.source.Read(lengthBuf); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
length := e.byteOrder.Uint32(lengthBuf)
|
||||
var msgBuf []byte
|
||||
if length <= defaultPayloadBufferSize {
|
||||
msgBuf = e.payloadPool.Get().([]byte)[:length]
|
||||
} else {
|
||||
msgBuf = make([]byte, length)
|
||||
}
|
||||
|
||||
if _, err = e.source.Read(msgBuf); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
protoEv := new(EventEntity)
|
||||
|
||||
if err = proto.Unmarshal(msgBuf, protoEv); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
ev = NewEventFromProto(protoEv)
|
||||
|
||||
return
|
||||
}
|
97
pkg/audit/reader_test.go
Normal file
97
pkg/audit/reader_test.go
Normal file
|
@ -0,0 +1,97 @@
|
|||
package audit_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"encoding/hex"
|
||||
"io"
|
||||
"testing"
|
||||
|
||||
"gitlab.com/inetmock/inetmock/pkg/audit"
|
||||
"google.golang.org/protobuf/proto"
|
||||
)
|
||||
|
||||
var (
|
||||
//nolint:lll
|
||||
httpPayloadBytesLittleEndian = `dd000000120b088092b8c398feffffff011801200248d8fc0150505a3308041224544c535f45434448455f45434453415f574954485f4145535f3235365f4342435f5348411a096c6f63616c686f73746282010a34747970652e676f6f676c65617069732e636f6d2f696e65746d6f636b2e61756469742e4854545044657461696c73456e74697479124a12096c6f63616c686f73741a15687474703a2f2f6c6f63616c686f73742f6173646622084854545020312e312a1c0a0641636365707412120a106170706c69636174696f6e2f6a736f6e28818080f80738818080f807`
|
||||
//nolint:lll
|
||||
httpPayloadBytesBigEndian = `000000dd120b088092b8c398feffffff011801200248d8fc0150505a3308041224544c535f45434448455f45434453415f574954485f4145535f3235365f4342435f5348411a096c6f63616c686f73746282010a34747970652e676f6f676c65617069732e636f6d2f696e65746d6f636b2e61756469742e4854545044657461696c73456e74697479124a12096c6f63616c686f73741a15687474703a2f2f6c6f63616c686f73742f6173646622084854545020312e312a1c0a0641636365707412120a106170706c69636174696f6e2f6a736f6e28818080f80738818080f807`
|
||||
dnsPayloadBytesLittleEndian = `1b000000120b088092b8c398feffffff011801200148d8fc01505030014001`
|
||||
dnsPayloadBytesBigEndian = `0000001b120b088092b8c398feffffff011801200148d8fc01505030014001`
|
||||
)
|
||||
|
||||
func mustDecodeHex(hexBytes string) io.Reader {
|
||||
b, err := hex.DecodeString(hexBytes)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return bytes.NewReader(b)
|
||||
}
|
||||
|
||||
func Test_eventReader_Read(t *testing.T) {
|
||||
type fields struct {
|
||||
source io.Reader
|
||||
order binary.ByteOrder
|
||||
}
|
||||
type testCase struct {
|
||||
name string
|
||||
fields fields
|
||||
wantEv audit.Event
|
||||
wantErr bool
|
||||
}
|
||||
tests := []testCase{
|
||||
{
|
||||
name: "Read HTTP payload - little endian",
|
||||
fields: fields{
|
||||
source: mustDecodeHex(httpPayloadBytesLittleEndian),
|
||||
order: binary.LittleEndian,
|
||||
},
|
||||
wantEv: testEvents[0],
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Read HTTP payload - big endian",
|
||||
fields: fields{
|
||||
source: mustDecodeHex(httpPayloadBytesBigEndian),
|
||||
order: binary.BigEndian,
|
||||
},
|
||||
wantEv: testEvents[0],
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Read DNS payload - little endian",
|
||||
fields: fields{
|
||||
source: mustDecodeHex(dnsPayloadBytesLittleEndian),
|
||||
order: binary.LittleEndian,
|
||||
},
|
||||
wantEv: testEvents[1],
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Read DNS payload - big endian",
|
||||
fields: fields{
|
||||
source: mustDecodeHex(dnsPayloadBytesBigEndian),
|
||||
order: binary.BigEndian,
|
||||
},
|
||||
wantEv: testEvents[1],
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
scenario := func(tt testCase) func(t *testing.T) {
|
||||
return func(t *testing.T) {
|
||||
e := audit.NewEventReader(tt.fields.source, audit.WithReaderByteOrder(tt.fields.order))
|
||||
gotEv, err := e.Read()
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("Read() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
|
||||
if err == nil && !proto.Equal(gotEv.ProtoMessage(), tt.wantEv.ProtoMessage()) {
|
||||
t.Errorf("Read() gotEv = %v, want %v", gotEv, tt.wantEv)
|
||||
}
|
||||
}
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, scenario(tt))
|
||||
}
|
||||
}
|
|
@ -2,8 +2,31 @@ package audit
|
|||
|
||||
import (
|
||||
"crypto/tls"
|
||||
)
|
||||
|
||||
"google.golang.org/protobuf/proto"
|
||||
var (
|
||||
tlsToEntity = map[uint16]TLSVersion{
|
||||
tls.VersionSSL30: TLSVersion_SSLv30,
|
||||
tls.VersionTLS10: TLSVersion_TLS10,
|
||||
tls.VersionTLS11: TLSVersion_TLS11,
|
||||
tls.VersionTLS12: TLSVersion_TLS12,
|
||||
tls.VersionTLS13: TLSVersion_TLS13,
|
||||
}
|
||||
entityToTls = map[TLSVersion]uint16{
|
||||
TLSVersion_SSLv30: tls.VersionSSL30,
|
||||
TLSVersion_TLS10: tls.VersionTLS10,
|
||||
TLSVersion_TLS11: tls.VersionTLS11,
|
||||
TLSVersion_TLS12: tls.VersionTLS12,
|
||||
TLSVersion_TLS13: tls.VersionTLS13,
|
||||
}
|
||||
cipherSuiteIDLookup = func(name string) uint16 {
|
||||
for _, cs := range tls.CipherSuites() {
|
||||
if cs.Name == name {
|
||||
return cs.ID
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}
|
||||
)
|
||||
|
||||
type TLSDetails struct {
|
||||
|
@ -12,24 +35,21 @@ type TLSDetails struct {
|
|||
ServerName string
|
||||
}
|
||||
|
||||
func (d TLSDetails) ProtoMessage() proto.Message {
|
||||
var version TLSVersion
|
||||
|
||||
switch d.Version {
|
||||
case tls.VersionTLS10:
|
||||
version = TLSVersion_TLS10
|
||||
case tls.VersionTLS11:
|
||||
version = TLSVersion_TLS11
|
||||
case tls.VersionTLS12:
|
||||
version = TLSVersion_TLS12
|
||||
case tls.VersionTLS13:
|
||||
version = TLSVersion_TLS13
|
||||
default:
|
||||
version = TLSVersion_SSLv30
|
||||
func NewTLSDetailsFromProto(entity *TLSDetailsEntity) *TLSDetails {
|
||||
if entity == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return &TLSDetails{
|
||||
Version: entityToTls[entity.GetVersion()],
|
||||
CipherSuite: cipherSuiteIDLookup(entity.GetCipherSuite()),
|
||||
ServerName: entity.GetServerName(),
|
||||
}
|
||||
}
|
||||
|
||||
func (d TLSDetails) ProtoMessage() *TLSDetailsEntity {
|
||||
return &TLSDetailsEntity{
|
||||
Version: version,
|
||||
Version: tlsToEntity[d.Version],
|
||||
CipherSuite: tls.CipherSuiteName(d.CipherSuite),
|
||||
ServerName: d.ServerName,
|
||||
}
|
||||
|
|
85
pkg/audit/writer.go
Normal file
85
pkg/audit/writer.go
Normal file
|
@ -0,0 +1,85 @@
|
|||
//go:generate mockgen -source=$GOFILE -destination=./../../internal/mock/audit/writer.mock.go -package=audit_mock
|
||||
|
||||
package audit
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"io"
|
||||
"sync"
|
||||
|
||||
"google.golang.org/protobuf/proto"
|
||||
)
|
||||
|
||||
var (
|
||||
WithWriterByteOrder = func(order binary.ByteOrder) func(writer *eventWriter) {
|
||||
return func(writer *eventWriter) {
|
||||
writer.byteOrder = order
|
||||
}
|
||||
}
|
||||
defaultOrder = binary.BigEndian
|
||||
ErrValueMostNotBeNil = errors.New("event value must not be nil")
|
||||
)
|
||||
|
||||
type Writer interface {
|
||||
io.Closer
|
||||
Write(ev *Event) error
|
||||
}
|
||||
|
||||
type EventWriterOption func(writer *eventWriter)
|
||||
|
||||
func NewEventWriter(target io.Writer, opts ...EventWriterOption) Writer {
|
||||
writer := &eventWriter{
|
||||
target: target,
|
||||
byteOrder: defaultOrder,
|
||||
lengthPool: &sync.Pool{
|
||||
New: func() interface{} {
|
||||
return make([]byte, lengthBufferSize)
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, opt := range opts {
|
||||
opt(writer)
|
||||
}
|
||||
|
||||
return writer
|
||||
}
|
||||
|
||||
type eventWriter struct {
|
||||
lengthPool *sync.Pool
|
||||
target io.Writer
|
||||
byteOrder binary.ByteOrder
|
||||
}
|
||||
|
||||
func (e eventWriter) Write(ev *Event) (err error) {
|
||||
if ev == nil {
|
||||
return ErrValueMostNotBeNil
|
||||
}
|
||||
var bytes []byte
|
||||
|
||||
if bytes, err = proto.Marshal(ev.ProtoMessage()); err != nil {
|
||||
return
|
||||
}
|
||||
buf := e.lengthPool.Get().([]byte)
|
||||
e.byteOrder.PutUint32(buf, uint32(len(bytes)))
|
||||
|
||||
if _, err = e.target.Write(buf); err != nil {
|
||||
return
|
||||
}
|
||||
if _, err = e.target.Write(bytes); err != nil {
|
||||
return
|
||||
}
|
||||
if syncerInst, ok := e.target.(syncer); ok {
|
||||
err = syncerInst.Sync()
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (e eventWriter) Close() error {
|
||||
if closer, isCloser := e.target.(io.Closer); isCloser {
|
||||
return closer.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
47
pkg/audit/writer_sink.go
Normal file
47
pkg/audit/writer_sink.go
Normal file
|
@ -0,0 +1,47 @@
|
|||
package audit
|
||||
|
||||
type WriterSinkOption func(sink *writerCloserSink)
|
||||
|
||||
var (
|
||||
WithCloseOnExit WriterSinkOption = func(sink *writerCloserSink) {
|
||||
sink.closeOnExit = true
|
||||
}
|
||||
)
|
||||
|
||||
func NewWriterSink(name string, target Writer, opts ...WriterSinkOption) Sink {
|
||||
sink := &writerCloserSink{
|
||||
name: name,
|
||||
target: target,
|
||||
}
|
||||
|
||||
for _, opt := range opts {
|
||||
opt(sink)
|
||||
}
|
||||
|
||||
return sink
|
||||
}
|
||||
|
||||
type writerCloserSink struct {
|
||||
name string
|
||||
target Writer
|
||||
closeOnExit bool
|
||||
}
|
||||
|
||||
type syncer interface {
|
||||
Sync() error
|
||||
}
|
||||
|
||||
func (f writerCloserSink) Name() string {
|
||||
return f.name
|
||||
}
|
||||
|
||||
func (f writerCloserSink) OnSubscribe(evs <-chan Event) {
|
||||
go func(target Writer, closeOnExit bool, evs <-chan Event) {
|
||||
for ev := range evs {
|
||||
_ = target.Write(&ev)
|
||||
}
|
||||
if closeOnExit {
|
||||
_ = target.Close()
|
||||
}
|
||||
}(f.target, f.closeOnExit, evs)
|
||||
}
|
72
pkg/audit/writer_sink_test.go
Normal file
72
pkg/audit/writer_sink_test.go
Normal file
|
@ -0,0 +1,72 @@
|
|||
package audit_test
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
audit_mock "gitlab.com/inetmock/inetmock/internal/mock/audit"
|
||||
"gitlab.com/inetmock/inetmock/pkg/audit"
|
||||
"gitlab.com/inetmock/inetmock/pkg/logging"
|
||||
)
|
||||
|
||||
func Test_writerCloserSink_OnSubscribe(t *testing.T) {
|
||||
type testCase struct {
|
||||
name string
|
||||
events []audit.Event
|
||||
}
|
||||
tests := []testCase{
|
||||
{
|
||||
name: "Get a single event",
|
||||
events: testEvents[:1],
|
||||
},
|
||||
{
|
||||
name: "Get multiple events",
|
||||
events: testEvents,
|
||||
},
|
||||
}
|
||||
scenario := func(tt testCase) func(t *testing.T) {
|
||||
return func(t *testing.T) {
|
||||
wg := new(sync.WaitGroup)
|
||||
wg.Add(len(tt.events))
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
t.Cleanup(ctrl.Finish)
|
||||
writerMock := audit_mock.NewMockWriter(ctrl)
|
||||
writerMock.
|
||||
EXPECT().
|
||||
Write(gomock.Any()).
|
||||
Do(func(_ *audit.Event) {
|
||||
wg.Done()
|
||||
}).
|
||||
Times(len(tt.events))
|
||||
|
||||
writerCloserSink := audit.NewWriterSink("WriterMock", writerMock, audit.WithCloseOnExit)
|
||||
var evs audit.EventStream
|
||||
var err error
|
||||
|
||||
if evs, err = audit.NewEventStream(logging.CreateTestLogger(t)); err != nil {
|
||||
t.Errorf("NewEventStream() error = %v", err)
|
||||
}
|
||||
|
||||
if err = evs.RegisterSink(writerCloserSink); err != nil {
|
||||
t.Errorf("RegisterSink() error = %v", err)
|
||||
}
|
||||
|
||||
for _, ev := range tt.events {
|
||||
evs.Emit(ev)
|
||||
}
|
||||
|
||||
select {
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
t.Errorf("not all events recorded in time")
|
||||
case <-waitGroupDone(wg):
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, scenario(tt))
|
||||
}
|
||||
}
|
117
pkg/audit/writer_test.go
Normal file
117
pkg/audit/writer_test.go
Normal file
|
@ -0,0 +1,117 @@
|
|||
//go:generate mockgen -source=$GOFILE -destination=./../../internal/mock/audit/writercloser.mock.go -package=audit_mock
|
||||
|
||||
package audit_test
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"encoding/hex"
|
||||
"io"
|
||||
"testing"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
audit_mock "gitlab.com/inetmock/inetmock/internal/mock/audit"
|
||||
"gitlab.com/inetmock/inetmock/pkg/audit"
|
||||
)
|
||||
|
||||
type WriterCloserSyncer interface {
|
||||
io.WriteCloser
|
||||
Sync() error
|
||||
}
|
||||
|
||||
func Test_eventWriter_Write(t *testing.T) {
|
||||
type fields struct {
|
||||
order binary.ByteOrder
|
||||
}
|
||||
type args struct {
|
||||
evs []audit.Event
|
||||
}
|
||||
type testCase struct {
|
||||
name string
|
||||
fields fields
|
||||
args args
|
||||
wantErr bool
|
||||
}
|
||||
tests := []testCase{
|
||||
{
|
||||
name: "Write a single event - little endian",
|
||||
fields: fields{
|
||||
order: binary.LittleEndian,
|
||||
},
|
||||
args: args{
|
||||
evs: testEvents[:1],
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Write a single event - big endian",
|
||||
fields: fields{
|
||||
order: binary.BigEndian,
|
||||
},
|
||||
args: args{
|
||||
evs: testEvents[:1],
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Write multiple events - little endian",
|
||||
fields: fields{
|
||||
order: binary.LittleEndian,
|
||||
},
|
||||
args: args{
|
||||
evs: testEvents,
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Write multiple events - big endian",
|
||||
fields: fields{
|
||||
order: binary.BigEndian,
|
||||
},
|
||||
args: args{
|
||||
evs: testEvents,
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
scenario := func(tt testCase) func(t *testing.T) {
|
||||
return func(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
t.Cleanup(ctrl.Finish)
|
||||
writerMock := audit_mock.NewMockWriterCloserSyncer(ctrl)
|
||||
calls := make([]*gomock.Call, 0)
|
||||
for i := 0; i < len(tt.args.evs); i++ {
|
||||
calls = append(calls,
|
||||
writerMock.
|
||||
EXPECT().
|
||||
Write(gomock.Any()).
|
||||
Do(func(data []byte) {
|
||||
t.Logf("got payload = %s", hex.EncodeToString(data))
|
||||
t.Logf("got length %d", tt.fields.order.Uint32(data))
|
||||
}),
|
||||
writerMock.
|
||||
EXPECT().
|
||||
Write(gomock.Any()).
|
||||
Do(func(data []byte) {
|
||||
t.Logf("got payload = %s", hex.EncodeToString(data))
|
||||
}),
|
||||
writerMock.
|
||||
EXPECT().
|
||||
Sync(),
|
||||
)
|
||||
}
|
||||
gomock.InOrder(calls...)
|
||||
|
||||
e := audit.NewEventWriter(writerMock, audit.WithWriterByteOrder(tt.fields.order))
|
||||
|
||||
for _, ev := range tt.args.evs {
|
||||
if err := e.Write(&ev); (err != nil) != tt.wantErr {
|
||||
t.Errorf("Write() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, scenario(tt))
|
||||
}
|
||||
}
|
|
@ -40,7 +40,7 @@ func AddTLSInterceptor(registry api.HandlerRegistry) (err error) {
|
|||
registry.RegisterHandler(name, func() api.ProtocolHandler {
|
||||
return &tlsInterceptor{
|
||||
logger: logger,
|
||||
currentConnectionsCount: &sync.WaitGroup{},
|
||||
currentConnectionsCount: new(sync.WaitGroup),
|
||||
currentConnections: make(map[uuid.UUID]*proxyConn),
|
||||
connectionsMutex: &sync.Mutex{},
|
||||
}
|
||||
|
|
Loading…
Add table
Reference in a new issue