Implement log and writer sinks
Add reader
This commit is contained in:
parent
1e8139e845
commit
63607dfd67
15 changed files with 883 additions and 88 deletions
|
@ -99,7 +99,7 @@ func (e *endpointManager) StartEndpoints() {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *endpointManager) ShutdownEndpoints() {
|
func (e *endpointManager) ShutdownEndpoints() {
|
||||||
var waitGroup sync.WaitGroup
|
waitGroup := new(sync.WaitGroup)
|
||||||
waitGroup.Add(len(e.properlyStartedEndpoints))
|
waitGroup.Add(len(e.properlyStartedEndpoints))
|
||||||
|
|
||||||
parentCtx, _ := context.WithTimeout(context.Background(), shutdownTimeout)
|
parentCtx, _ := context.WithTimeout(context.Background(), shutdownTimeout)
|
||||||
|
@ -112,7 +112,7 @@ func (e *endpointManager) ShutdownEndpoints() {
|
||||||
zap.String("endpoint", endpoint.Name()),
|
zap.String("endpoint", endpoint.Name()),
|
||||||
)
|
)
|
||||||
endpointLogger.Info("Triggering shutdown of endpoint")
|
endpointLogger.Info("Triggering shutdown of endpoint")
|
||||||
go shutdownEndpoint(ctx, endpoint, endpointLogger, &waitGroup)
|
go shutdownEndpoint(ctx, endpoint, endpointLogger, waitGroup)
|
||||||
}
|
}
|
||||||
|
|
||||||
waitGroup.Wait()
|
waitGroup.Wait()
|
||||||
|
|
|
@ -1,14 +1,18 @@
|
||||||
package audit
|
package audit
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"encoding/binary"
|
||||||
|
"math/big"
|
||||||
"net"
|
"net"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"google.golang.org/protobuf/proto"
|
"google.golang.org/protobuf/proto"
|
||||||
|
"google.golang.org/protobuf/types/known/anypb"
|
||||||
|
"google.golang.org/protobuf/types/known/timestamppb"
|
||||||
)
|
)
|
||||||
|
|
||||||
type EventDetails interface {
|
type Details interface {
|
||||||
ProtoMessage() proto.Message
|
MarshalToWireFormat() (*anypb.Any, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type Event struct {
|
type Event struct {
|
||||||
|
@ -20,6 +24,98 @@ type Event struct {
|
||||||
DestinationIP net.IP
|
DestinationIP net.IP
|
||||||
SourcePort uint16
|
SourcePort uint16
|
||||||
DestinationPort uint16
|
DestinationPort uint16
|
||||||
ProtocolDetails EventDetails
|
ProtocolDetails *anypb.Any
|
||||||
TLS *TLSDetails
|
TLS *TLSDetails
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (e *Event) ProtoMessage() proto.Message {
|
||||||
|
var sourceIP isEventEntity_SourceIP
|
||||||
|
if ipv4 := e.SourceIP.To4(); ipv4 != nil {
|
||||||
|
if len(ipv4) == 16 {
|
||||||
|
sourceIP = &EventEntity_SourceIPv4{SourceIPv4: binary.BigEndian.Uint32(ipv4[12:16])}
|
||||||
|
} else {
|
||||||
|
sourceIP = &EventEntity_SourceIPv4{SourceIPv4: binary.BigEndian.Uint32(ipv4)}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
ipv6 := big.NewInt(0)
|
||||||
|
ipv6.SetBytes(e.SourceIP)
|
||||||
|
sourceIP = &EventEntity_SourceIPv6{SourceIPv6: ipv6.Uint64()}
|
||||||
|
}
|
||||||
|
|
||||||
|
var destinationIP isEventEntity_DestinationIP
|
||||||
|
if ipv4 := e.DestinationIP.To4(); ipv4 != nil {
|
||||||
|
if len(ipv4) == 16 {
|
||||||
|
destinationIP = &EventEntity_DestinationIPv4{DestinationIPv4: binary.BigEndian.Uint32(ipv4[12:16])}
|
||||||
|
} else {
|
||||||
|
destinationIP = &EventEntity_DestinationIPv4{DestinationIPv4: binary.BigEndian.Uint32(ipv4)}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
ipv6 := big.NewInt(0)
|
||||||
|
ipv6.SetBytes(e.SourceIP)
|
||||||
|
destinationIP = &EventEntity_DestinationIPv6{DestinationIPv6: ipv6.Uint64()}
|
||||||
|
}
|
||||||
|
|
||||||
|
var tlsDetails *TLSDetailsEntity = nil
|
||||||
|
if e.TLS != nil {
|
||||||
|
tlsDetails = e.TLS.ProtoMessage()
|
||||||
|
}
|
||||||
|
|
||||||
|
return &EventEntity{
|
||||||
|
Id: e.ID,
|
||||||
|
Timestamp: timestamppb.New(e.Timestamp),
|
||||||
|
Transport: e.Transport,
|
||||||
|
Application: e.Application,
|
||||||
|
SourceIP: sourceIP,
|
||||||
|
DestinationIP: destinationIP,
|
||||||
|
SourcePort: uint32(e.SourcePort),
|
||||||
|
DestinationPort: uint32(e.DestinationPort),
|
||||||
|
Tls: tlsDetails,
|
||||||
|
ProtocolDetails: e.ProtocolDetails,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *Event) ApplyDefaults(id int64) {
|
||||||
|
e.ID = id
|
||||||
|
emptyTime := time.Time{}
|
||||||
|
if e.Timestamp == emptyTime {
|
||||||
|
e.Timestamp = time.Now().UTC()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewEventFromProto(msg *EventEntity) (ev Event) {
|
||||||
|
var sourceIP net.IP
|
||||||
|
switch ip := msg.GetSourceIP().(type) {
|
||||||
|
case *EventEntity_SourceIPv4:
|
||||||
|
buf := make([]byte, 4)
|
||||||
|
binary.BigEndian.PutUint32(buf, ip.SourceIPv4)
|
||||||
|
sourceIP = buf
|
||||||
|
sourceIP = sourceIP.To4()
|
||||||
|
case *EventEntity_SourceIPv6:
|
||||||
|
sourceIP = big.NewInt(int64(ip.SourceIPv6)).Bytes()
|
||||||
|
}
|
||||||
|
|
||||||
|
var destinationIP net.IP
|
||||||
|
switch ip := msg.GetDestinationIP().(type) {
|
||||||
|
case *EventEntity_DestinationIPv4:
|
||||||
|
buf := make([]byte, 4)
|
||||||
|
binary.BigEndian.PutUint32(buf, ip.DestinationIPv4)
|
||||||
|
destinationIP = buf
|
||||||
|
destinationIP = destinationIP.To4()
|
||||||
|
case *EventEntity_DestinationIPv6:
|
||||||
|
destinationIP = big.NewInt(int64(ip.DestinationIPv6)).Bytes()
|
||||||
|
}
|
||||||
|
|
||||||
|
ev = Event{
|
||||||
|
ID: msg.GetId(),
|
||||||
|
Timestamp: msg.GetTimestamp().AsTime(),
|
||||||
|
Transport: msg.GetTransport(),
|
||||||
|
Application: msg.GetApplication(),
|
||||||
|
SourceIP: sourceIP,
|
||||||
|
DestinationIP: destinationIP,
|
||||||
|
SourcePort: uint16(msg.GetSourcePort()),
|
||||||
|
DestinationPort: uint16(msg.GetDestinationPort()),
|
||||||
|
ProtocolDetails: msg.GetProtocolDetails(),
|
||||||
|
TLS: NewTLSDetailsFromProto(msg.GetTls()),
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
|
@ -5,7 +5,6 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/bwmarrin/snowflake"
|
"github.com/bwmarrin/snowflake"
|
||||||
"github.com/imdario/mergo"
|
|
||||||
"gitlab.com/inetmock/inetmock/pkg/logging"
|
"gitlab.com/inetmock/inetmock/pkg/logging"
|
||||||
"go.uber.org/zap"
|
"go.uber.org/zap"
|
||||||
)
|
)
|
||||||
|
@ -56,10 +55,9 @@ func NewEventStream(logger logging.Logger, options ...EventStreamOption) (es Eve
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *eventStream) Emit(ev Event) {
|
func (e *eventStream) Emit(ev Event) {
|
||||||
defaultEvent := e.newDefaultEvent()
|
ev.ApplyDefaults(e.idGenerator.Generate().Int64())
|
||||||
_ = mergo.Merge(defaultEvent, &ev)
|
|
||||||
select {
|
select {
|
||||||
case e.buffer <- defaultEvent:
|
case e.buffer <- &ev:
|
||||||
e.logger.Debug("pushed event to distribute loop")
|
e.logger.Debug("pushed event to distribute loop")
|
||||||
default:
|
default:
|
||||||
e.logger.Warn("buffer is full")
|
e.logger.Warn("buffer is full")
|
||||||
|
@ -107,11 +105,3 @@ func (e *eventStream) distribute() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *eventStream) newDefaultEvent() *Event {
|
|
||||||
id := e.idGenerator.Generate()
|
|
||||||
return &Event{
|
|
||||||
ID: id.Int64(),
|
|
||||||
Timestamp: time.Now().UTC(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
|
@ -1,19 +1,54 @@
|
||||||
package audit_test
|
package audit_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"crypto/tls"
|
||||||
"net"
|
"net"
|
||||||
|
"net/http"
|
||||||
"sync"
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"gitlab.com/inetmock/inetmock/pkg/audit"
|
"gitlab.com/inetmock/inetmock/pkg/audit"
|
||||||
"gitlab.com/inetmock/inetmock/pkg/logging"
|
"gitlab.com/inetmock/inetmock/pkg/logging"
|
||||||
|
"google.golang.org/protobuf/types/known/anypb"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
defaultSink = &testSink{
|
defaultSink = &testSink{
|
||||||
name: "test defaultSink",
|
name: "test defaultSink",
|
||||||
}
|
}
|
||||||
|
testEvents = []audit.Event{
|
||||||
|
{
|
||||||
|
Transport: audit.TransportProtocol_TCP,
|
||||||
|
Application: audit.AppProtocol_HTTP,
|
||||||
|
SourceIP: net.ParseIP("127.0.0.1"),
|
||||||
|
DestinationIP: net.ParseIP("127.0.0.1"),
|
||||||
|
SourcePort: 32344,
|
||||||
|
DestinationPort: 80,
|
||||||
|
TLS: &audit.TLSDetails{
|
||||||
|
Version: tls.VersionTLS13,
|
||||||
|
CipherSuite: tls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA,
|
||||||
|
ServerName: "localhost",
|
||||||
|
},
|
||||||
|
ProtocolDetails: mustMarshalToWireformat(audit.HTTPDetails{
|
||||||
|
Method: "GET",
|
||||||
|
Host: "localhost",
|
||||||
|
URI: "http://localhost/asdf",
|
||||||
|
Proto: "HTTP 1.1",
|
||||||
|
Headers: http.Header{
|
||||||
|
"Accept": []string{"application/json"},
|
||||||
|
},
|
||||||
|
}),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Transport: audit.TransportProtocol_TCP,
|
||||||
|
Application: audit.AppProtocol_DNS,
|
||||||
|
SourceIP: net.ParseIP("::1"),
|
||||||
|
DestinationIP: net.ParseIP("::1"),
|
||||||
|
SourcePort: 32344,
|
||||||
|
DestinationPort: 80,
|
||||||
|
},
|
||||||
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
type testSink struct {
|
type testSink struct {
|
||||||
|
@ -45,17 +80,25 @@ func wgMockSink(t testing.TB, wg *sync.WaitGroup) audit.Sink {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func mustMarshalToWireformat(d audit.Details) *anypb.Any {
|
||||||
|
any, err := d.MarshalToWireFormat()
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
return any
|
||||||
|
}
|
||||||
|
|
||||||
func Test_eventStream_RegisterSink(t *testing.T) {
|
func Test_eventStream_RegisterSink(t *testing.T) {
|
||||||
type args struct {
|
type args struct {
|
||||||
s audit.Sink
|
s audit.Sink
|
||||||
}
|
}
|
||||||
|
type testCase struct {
|
||||||
tests := []struct {
|
|
||||||
name string
|
name string
|
||||||
args args
|
args args
|
||||||
setup func(e audit.EventStream)
|
setup func(e audit.EventStream)
|
||||||
wantErr bool
|
wantErr bool
|
||||||
}{
|
}
|
||||||
|
tests := []testCase{
|
||||||
{
|
{
|
||||||
name: "Register test defaultSink",
|
name: "Register test defaultSink",
|
||||||
args: args{
|
args: args{
|
||||||
|
@ -74,8 +117,8 @@ func Test_eventStream_RegisterSink(t *testing.T) {
|
||||||
wantErr: true,
|
wantErr: true,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
scenario := func(tt testCase) func(t *testing.T) {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
return func(t *testing.T) {
|
||||||
var err error
|
var err error
|
||||||
var e audit.EventStream
|
var e audit.EventStream
|
||||||
if e, err = audit.NewEventStream(logging.CreateTestLogger(t)); err != nil {
|
if e, err = audit.NewEventStream(logging.CreateTestLogger(t)); err != nil {
|
||||||
|
@ -103,7 +146,11 @@ func Test_eventStream_RegisterSink(t *testing.T) {
|
||||||
if !found {
|
if !found {
|
||||||
t.Errorf("expected defaultSink name %s not found in registered sinks %v", tt.args.s.Name(), e.Sinks())
|
t.Errorf("expected defaultSink name %s not found in registered sinks %v", tt.args.s.Name(), e.Sinks())
|
||||||
}
|
}
|
||||||
})
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, scenario(tt))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -112,26 +159,18 @@ func Test_eventStream_Emit(t *testing.T) {
|
||||||
evs []audit.Event
|
evs []audit.Event
|
||||||
opts []audit.EventStreamOption
|
opts []audit.EventStreamOption
|
||||||
}
|
}
|
||||||
tests := []struct {
|
type testCase struct {
|
||||||
name string
|
name string
|
||||||
args args
|
args args
|
||||||
subscribe bool
|
subscribe bool
|
||||||
}{
|
}
|
||||||
|
tests := []testCase{
|
||||||
{
|
{
|
||||||
name: "Expect to get a single event",
|
name: "Expect to get a single event",
|
||||||
subscribe: true,
|
subscribe: true,
|
||||||
args: args{
|
args: args{
|
||||||
opts: []audit.EventStreamOption{audit.WithBufferSize(10)},
|
opts: []audit.EventStreamOption{audit.WithBufferSize(10)},
|
||||||
evs: []audit.Event{
|
evs: testEvents[:1],
|
||||||
{
|
|
||||||
Transport: audit.TransportProtocol_TCP,
|
|
||||||
Application: audit.AppProtocol_HTTP,
|
|
||||||
SourceIP: net.ParseIP("127.0.0.1"),
|
|
||||||
DestinationIP: net.ParseIP("127.0.0.1"),
|
|
||||||
SourcePort: 32344,
|
|
||||||
DestinationPort: 80,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
@ -139,46 +178,21 @@ func Test_eventStream_Emit(t *testing.T) {
|
||||||
subscribe: true,
|
subscribe: true,
|
||||||
args: args{
|
args: args{
|
||||||
opts: []audit.EventStreamOption{audit.WithBufferSize(10)},
|
opts: []audit.EventStreamOption{audit.WithBufferSize(10)},
|
||||||
evs: []audit.Event{
|
evs: testEvents,
|
||||||
{
|
|
||||||
Transport: audit.TransportProtocol_TCP,
|
|
||||||
Application: audit.AppProtocol_HTTP,
|
|
||||||
SourceIP: net.ParseIP("127.0.0.1"),
|
|
||||||
DestinationIP: net.ParseIP("127.0.0.1"),
|
|
||||||
SourcePort: 32344,
|
|
||||||
DestinationPort: 80,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Transport: audit.TransportProtocol_TCP,
|
|
||||||
Application: audit.AppProtocol_DNS,
|
|
||||||
SourceIP: net.ParseIP("::1"),
|
|
||||||
DestinationIP: net.ParseIP("::1"),
|
|
||||||
SourcePort: 32344,
|
|
||||||
DestinationPort: 80,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Emit without subscribe sink",
|
name: "Emit without subscribe sink",
|
||||||
args: args{
|
args: args{
|
||||||
opts: []audit.EventStreamOption{audit.WithBufferSize(0)},
|
opts: []audit.EventStreamOption{audit.WithBufferSize(0)},
|
||||||
evs: []audit.Event{
|
evs: testEvents[:1],
|
||||||
{
|
|
||||||
Transport: audit.TransportProtocol_TCP,
|
|
||||||
Application: audit.AppProtocol_HTTP,
|
|
||||||
SourceIP: net.ParseIP("127.0.0.1"),
|
|
||||||
DestinationIP: net.ParseIP("127.0.0.1"),
|
|
||||||
SourcePort: 32344,
|
|
||||||
DestinationPort: 80,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
subscribe: false,
|
subscribe: false,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
scenario := func(tt testCase) func(t *testing.T) {
|
||||||
|
return func(t *testing.T) {
|
||||||
var err error
|
var err error
|
||||||
var e audit.EventStream
|
var e audit.EventStream
|
||||||
if e, err = audit.NewEventStream(logging.CreateTestLogger(t), tt.args.opts...); err != nil {
|
if e, err = audit.NewEventStream(logging.CreateTestLogger(t), tt.args.opts...); err != nil {
|
||||||
|
@ -189,8 +203,8 @@ func Test_eventStream_Emit(t *testing.T) {
|
||||||
_ = e.Close()
|
_ = e.Close()
|
||||||
})
|
})
|
||||||
|
|
||||||
emittedWaitGroup := &sync.WaitGroup{}
|
emittedWaitGroup := new(sync.WaitGroup)
|
||||||
receivedWaitGroup := &sync.WaitGroup{}
|
receivedWaitGroup := new(sync.WaitGroup)
|
||||||
|
|
||||||
emittedWaitGroup.Add(len(tt.args.evs))
|
emittedWaitGroup.Add(len(tt.args.evs))
|
||||||
|
|
||||||
|
@ -219,7 +233,11 @@ func Test_eventStream_Emit(t *testing.T) {
|
||||||
case <-time.After(5 * time.Second):
|
case <-time.After(5 * time.Second):
|
||||||
t.Errorf("did not get all expected events in time")
|
t.Errorf("did not get all expected events in time")
|
||||||
}
|
}
|
||||||
})
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, scenario(tt))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -4,7 +4,7 @@ import (
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"google.golang.org/protobuf/proto"
|
"google.golang.org/protobuf/types/known/anypb"
|
||||||
)
|
)
|
||||||
|
|
||||||
type HTTPDetails struct {
|
type HTTPDetails struct {
|
||||||
|
@ -15,7 +15,7 @@ type HTTPDetails struct {
|
||||||
Headers http.Header
|
Headers http.Header
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d HTTPDetails) ProtoMessage() proto.Message {
|
func (d HTTPDetails) MarshalToWireFormat() (any *anypb.Any, err error) {
|
||||||
var method = HTTPMethod_GET
|
var method = HTTPMethod_GET
|
||||||
if methodValue, known := HTTPMethod_value[strings.ToUpper(d.Method)]; known {
|
if methodValue, known := HTTPMethod_value[strings.ToUpper(d.Method)]; known {
|
||||||
method = HTTPMethod(methodValue)
|
method = HTTPMethod(methodValue)
|
||||||
|
@ -29,11 +29,14 @@ func (d HTTPDetails) ProtoMessage() proto.Message {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return &HTTPDetailsEntity{
|
protoDetails := &HTTPDetailsEntity{
|
||||||
Method: method,
|
Method: method,
|
||||||
Host: d.Host,
|
Host: d.Host,
|
||||||
Uri: d.URI,
|
Uri: d.URI,
|
||||||
Proto: d.Proto,
|
Proto: d.Proto,
|
||||||
Headers: headers,
|
Headers: headers,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
any, err = anypb.New(protoDetails)
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
53
pkg/audit/log_sink.go
Normal file
53
pkg/audit/log_sink.go
Normal file
|
@ -0,0 +1,53 @@
|
||||||
|
package audit
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/tls"
|
||||||
|
|
||||||
|
"gitlab.com/inetmock/inetmock/pkg/logging"
|
||||||
|
"go.uber.org/zap"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
logSinkName = "logging"
|
||||||
|
)
|
||||||
|
|
||||||
|
func NewLogSink(logger logging.Logger) Sink {
|
||||||
|
return &logSink{
|
||||||
|
logger: logger,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type logSink struct {
|
||||||
|
logger logging.Logger
|
||||||
|
}
|
||||||
|
|
||||||
|
func (logSink) Name() string {
|
||||||
|
return logSinkName
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l logSink) OnSubscribe(evs <-chan Event) {
|
||||||
|
go func(logger logging.Logger, evs <-chan Event) {
|
||||||
|
for ev := range evs {
|
||||||
|
eventLogger := logger
|
||||||
|
|
||||||
|
if ev.TLS != nil {
|
||||||
|
eventLogger = eventLogger.With(
|
||||||
|
zap.String("tls_server_name", ev.TLS.ServerName),
|
||||||
|
zap.String("tls_cipher_suite", tls.CipherSuiteName(ev.TLS.CipherSuite)),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
eventLogger.Info(
|
||||||
|
"handled request",
|
||||||
|
zap.Time("timestamp", ev.Timestamp),
|
||||||
|
zap.String("application", ev.Application.String()),
|
||||||
|
zap.String("transport", ev.Transport.String()),
|
||||||
|
zap.String("source_ip", ev.SourceIP.String()),
|
||||||
|
zap.Uint16("source_port", ev.SourcePort),
|
||||||
|
zap.String("destination_ip", ev.DestinationIP.String()),
|
||||||
|
zap.Uint16("destination_port", ev.DestinationPort),
|
||||||
|
zap.Any("details", ev.ProtocolDetails),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}(l.logger, evs)
|
||||||
|
}
|
110
pkg/audit/log_sink_test.go
Normal file
110
pkg/audit/log_sink_test.go
Normal file
|
@ -0,0 +1,110 @@
|
||||||
|
package audit_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/golang/mock/gomock"
|
||||||
|
logging_mock "gitlab.com/inetmock/inetmock/internal/mock/logging"
|
||||||
|
"gitlab.com/inetmock/inetmock/pkg/audit"
|
||||||
|
"gitlab.com/inetmock/inetmock/pkg/logging"
|
||||||
|
"go.uber.org/zap"
|
||||||
|
)
|
||||||
|
|
||||||
|
//nolint:dupl
|
||||||
|
func Test_logSink_OnSubscribe(t *testing.T) {
|
||||||
|
type fields struct {
|
||||||
|
loggerSetup func(t *testing.T, wg *sync.WaitGroup) logging.Logger
|
||||||
|
}
|
||||||
|
type testCase struct {
|
||||||
|
name string
|
||||||
|
fields fields
|
||||||
|
events []audit.Event
|
||||||
|
}
|
||||||
|
tests := []testCase{
|
||||||
|
{
|
||||||
|
name: "Get a single log line",
|
||||||
|
fields: fields{
|
||||||
|
loggerSetup: func(t *testing.T, wg *sync.WaitGroup) logging.Logger {
|
||||||
|
ctrl := gomock.NewController(t)
|
||||||
|
t.Cleanup(ctrl.Finish)
|
||||||
|
loggerMock := logging_mock.NewMockLogger(ctrl)
|
||||||
|
|
||||||
|
loggerMock.
|
||||||
|
EXPECT().
|
||||||
|
With(gomock.Any()).
|
||||||
|
Return(loggerMock)
|
||||||
|
|
||||||
|
loggerMock.
|
||||||
|
EXPECT().
|
||||||
|
Info("handled request", gomock.Any()).
|
||||||
|
Do(func(_ string, _ ...zap.Field) {
|
||||||
|
wg.Done()
|
||||||
|
}).
|
||||||
|
Times(1)
|
||||||
|
|
||||||
|
return loggerMock
|
||||||
|
},
|
||||||
|
},
|
||||||
|
events: testEvents[:1],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Get multiple events",
|
||||||
|
fields: fields{
|
||||||
|
loggerSetup: func(t *testing.T, wg *sync.WaitGroup) logging.Logger {
|
||||||
|
ctrl := gomock.NewController(t)
|
||||||
|
t.Cleanup(ctrl.Finish)
|
||||||
|
loggerMock := logging_mock.NewMockLogger(ctrl)
|
||||||
|
|
||||||
|
loggerMock.
|
||||||
|
EXPECT().
|
||||||
|
With(gomock.Any()).
|
||||||
|
Return(loggerMock)
|
||||||
|
|
||||||
|
loggerMock.
|
||||||
|
EXPECT().
|
||||||
|
Info("handled request", gomock.Any()).
|
||||||
|
Do(func(_ string, _ ...zap.Field) {
|
||||||
|
wg.Done()
|
||||||
|
}).
|
||||||
|
Times(2)
|
||||||
|
|
||||||
|
return loggerMock
|
||||||
|
},
|
||||||
|
},
|
||||||
|
events: testEvents,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
scenario := func(tt testCase) func(t *testing.T) {
|
||||||
|
return func(t *testing.T) {
|
||||||
|
wg := new(sync.WaitGroup)
|
||||||
|
wg.Add(len(tt.events))
|
||||||
|
|
||||||
|
logSink := audit.NewLogSink(tt.fields.loggerSetup(t, wg))
|
||||||
|
var evs audit.EventStream
|
||||||
|
var err error
|
||||||
|
if evs, err = audit.NewEventStream(logging.CreateTestLogger(t)); err != nil {
|
||||||
|
t.Errorf("NewEventStream() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = evs.RegisterSink(logSink); err != nil {
|
||||||
|
t.Errorf("RegisterSink() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, ev := range tt.events {
|
||||||
|
evs.Emit(ev)
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-time.After(100 * time.Millisecond):
|
||||||
|
t.Errorf("not all events recorded in time")
|
||||||
|
case <-waitGroupDone(wg):
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, scenario(tt))
|
||||||
|
}
|
||||||
|
}
|
87
pkg/audit/reader.go
Normal file
87
pkg/audit/reader.go
Normal file
|
@ -0,0 +1,87 @@
|
||||||
|
package audit
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/binary"
|
||||||
|
"io"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"google.golang.org/protobuf/proto"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
lengthBufferSize = 4
|
||||||
|
defaultPayloadBufferSize = 4096
|
||||||
|
)
|
||||||
|
|
||||||
|
type Reader interface {
|
||||||
|
Read() (Event, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
type EventReaderOption func(reader *eventReader)
|
||||||
|
|
||||||
|
var (
|
||||||
|
WithReaderByteOrder = func(order binary.ByteOrder) EventReaderOption {
|
||||||
|
return func(reader *eventReader) {
|
||||||
|
reader.byteOrder = order
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
func NewEventReader(source io.Reader, opts ...EventReaderOption) Reader {
|
||||||
|
reader := &eventReader{
|
||||||
|
source: source,
|
||||||
|
byteOrder: defaultOrder,
|
||||||
|
lengthPool: &sync.Pool{
|
||||||
|
New: func() interface{} {
|
||||||
|
return make([]byte, lengthBufferSize)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
payloadPool: &sync.Pool{
|
||||||
|
New: func() interface{} {
|
||||||
|
return make([]byte, defaultPayloadBufferSize)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, opt := range opts {
|
||||||
|
opt(reader)
|
||||||
|
}
|
||||||
|
|
||||||
|
return reader
|
||||||
|
}
|
||||||
|
|
||||||
|
type eventReader struct {
|
||||||
|
lengthPool *sync.Pool
|
||||||
|
payloadPool *sync.Pool
|
||||||
|
byteOrder binary.ByteOrder
|
||||||
|
source io.Reader
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e eventReader) Read() (ev Event, err error) {
|
||||||
|
lengthBuf := e.lengthPool.Get().([]byte)
|
||||||
|
if _, err = e.source.Read(lengthBuf); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
length := e.byteOrder.Uint32(lengthBuf)
|
||||||
|
var msgBuf []byte
|
||||||
|
if length <= defaultPayloadBufferSize {
|
||||||
|
msgBuf = e.payloadPool.Get().([]byte)[:length]
|
||||||
|
} else {
|
||||||
|
msgBuf = make([]byte, length)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err = e.source.Read(msgBuf); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
protoEv := new(EventEntity)
|
||||||
|
|
||||||
|
if err = proto.Unmarshal(msgBuf, protoEv); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
ev = NewEventFromProto(protoEv)
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
97
pkg/audit/reader_test.go
Normal file
97
pkg/audit/reader_test.go
Normal file
|
@ -0,0 +1,97 @@
|
||||||
|
package audit_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/binary"
|
||||||
|
"encoding/hex"
|
||||||
|
"io"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"gitlab.com/inetmock/inetmock/pkg/audit"
|
||||||
|
"google.golang.org/protobuf/proto"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
//nolint:lll
|
||||||
|
httpPayloadBytesLittleEndian = `dd000000120b088092b8c398feffffff011801200248d8fc0150505a3308041224544c535f45434448455f45434453415f574954485f4145535f3235365f4342435f5348411a096c6f63616c686f73746282010a34747970652e676f6f676c65617069732e636f6d2f696e65746d6f636b2e61756469742e4854545044657461696c73456e74697479124a12096c6f63616c686f73741a15687474703a2f2f6c6f63616c686f73742f6173646622084854545020312e312a1c0a0641636365707412120a106170706c69636174696f6e2f6a736f6e28818080f80738818080f807`
|
||||||
|
//nolint:lll
|
||||||
|
httpPayloadBytesBigEndian = `000000dd120b088092b8c398feffffff011801200248d8fc0150505a3308041224544c535f45434448455f45434453415f574954485f4145535f3235365f4342435f5348411a096c6f63616c686f73746282010a34747970652e676f6f676c65617069732e636f6d2f696e65746d6f636b2e61756469742e4854545044657461696c73456e74697479124a12096c6f63616c686f73741a15687474703a2f2f6c6f63616c686f73742f6173646622084854545020312e312a1c0a0641636365707412120a106170706c69636174696f6e2f6a736f6e28818080f80738818080f807`
|
||||||
|
dnsPayloadBytesLittleEndian = `1b000000120b088092b8c398feffffff011801200148d8fc01505030014001`
|
||||||
|
dnsPayloadBytesBigEndian = `0000001b120b088092b8c398feffffff011801200148d8fc01505030014001`
|
||||||
|
)
|
||||||
|
|
||||||
|
func mustDecodeHex(hexBytes string) io.Reader {
|
||||||
|
b, err := hex.DecodeString(hexBytes)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
return bytes.NewReader(b)
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_eventReader_Read(t *testing.T) {
|
||||||
|
type fields struct {
|
||||||
|
source io.Reader
|
||||||
|
order binary.ByteOrder
|
||||||
|
}
|
||||||
|
type testCase struct {
|
||||||
|
name string
|
||||||
|
fields fields
|
||||||
|
wantEv audit.Event
|
||||||
|
wantErr bool
|
||||||
|
}
|
||||||
|
tests := []testCase{
|
||||||
|
{
|
||||||
|
name: "Read HTTP payload - little endian",
|
||||||
|
fields: fields{
|
||||||
|
source: mustDecodeHex(httpPayloadBytesLittleEndian),
|
||||||
|
order: binary.LittleEndian,
|
||||||
|
},
|
||||||
|
wantEv: testEvents[0],
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Read HTTP payload - big endian",
|
||||||
|
fields: fields{
|
||||||
|
source: mustDecodeHex(httpPayloadBytesBigEndian),
|
||||||
|
order: binary.BigEndian,
|
||||||
|
},
|
||||||
|
wantEv: testEvents[0],
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Read DNS payload - little endian",
|
||||||
|
fields: fields{
|
||||||
|
source: mustDecodeHex(dnsPayloadBytesLittleEndian),
|
||||||
|
order: binary.LittleEndian,
|
||||||
|
},
|
||||||
|
wantEv: testEvents[1],
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Read DNS payload - big endian",
|
||||||
|
fields: fields{
|
||||||
|
source: mustDecodeHex(dnsPayloadBytesBigEndian),
|
||||||
|
order: binary.BigEndian,
|
||||||
|
},
|
||||||
|
wantEv: testEvents[1],
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
scenario := func(tt testCase) func(t *testing.T) {
|
||||||
|
return func(t *testing.T) {
|
||||||
|
e := audit.NewEventReader(tt.fields.source, audit.WithReaderByteOrder(tt.fields.order))
|
||||||
|
gotEv, err := e.Read()
|
||||||
|
if (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("Read() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err == nil && !proto.Equal(gotEv.ProtoMessage(), tt.wantEv.ProtoMessage()) {
|
||||||
|
t.Errorf("Read() gotEv = %v, want %v", gotEv, tt.wantEv)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, scenario(tt))
|
||||||
|
}
|
||||||
|
}
|
|
@ -2,8 +2,31 @@ package audit
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
|
)
|
||||||
|
|
||||||
"google.golang.org/protobuf/proto"
|
var (
|
||||||
|
tlsToEntity = map[uint16]TLSVersion{
|
||||||
|
tls.VersionSSL30: TLSVersion_SSLv30,
|
||||||
|
tls.VersionTLS10: TLSVersion_TLS10,
|
||||||
|
tls.VersionTLS11: TLSVersion_TLS11,
|
||||||
|
tls.VersionTLS12: TLSVersion_TLS12,
|
||||||
|
tls.VersionTLS13: TLSVersion_TLS13,
|
||||||
|
}
|
||||||
|
entityToTls = map[TLSVersion]uint16{
|
||||||
|
TLSVersion_SSLv30: tls.VersionSSL30,
|
||||||
|
TLSVersion_TLS10: tls.VersionTLS10,
|
||||||
|
TLSVersion_TLS11: tls.VersionTLS11,
|
||||||
|
TLSVersion_TLS12: tls.VersionTLS12,
|
||||||
|
TLSVersion_TLS13: tls.VersionTLS13,
|
||||||
|
}
|
||||||
|
cipherSuiteIDLookup = func(name string) uint16 {
|
||||||
|
for _, cs := range tls.CipherSuites() {
|
||||||
|
if cs.Name == name {
|
||||||
|
return cs.ID
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return 0
|
||||||
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
type TLSDetails struct {
|
type TLSDetails struct {
|
||||||
|
@ -12,24 +35,21 @@ type TLSDetails struct {
|
||||||
ServerName string
|
ServerName string
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d TLSDetails) ProtoMessage() proto.Message {
|
func NewTLSDetailsFromProto(entity *TLSDetailsEntity) *TLSDetails {
|
||||||
var version TLSVersion
|
if entity == nil {
|
||||||
|
return nil
|
||||||
switch d.Version {
|
|
||||||
case tls.VersionTLS10:
|
|
||||||
version = TLSVersion_TLS10
|
|
||||||
case tls.VersionTLS11:
|
|
||||||
version = TLSVersion_TLS11
|
|
||||||
case tls.VersionTLS12:
|
|
||||||
version = TLSVersion_TLS12
|
|
||||||
case tls.VersionTLS13:
|
|
||||||
version = TLSVersion_TLS13
|
|
||||||
default:
|
|
||||||
version = TLSVersion_SSLv30
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return &TLSDetails{
|
||||||
|
Version: entityToTls[entity.GetVersion()],
|
||||||
|
CipherSuite: cipherSuiteIDLookup(entity.GetCipherSuite()),
|
||||||
|
ServerName: entity.GetServerName(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d TLSDetails) ProtoMessage() *TLSDetailsEntity {
|
||||||
return &TLSDetailsEntity{
|
return &TLSDetailsEntity{
|
||||||
Version: version,
|
Version: tlsToEntity[d.Version],
|
||||||
CipherSuite: tls.CipherSuiteName(d.CipherSuite),
|
CipherSuite: tls.CipherSuiteName(d.CipherSuite),
|
||||||
ServerName: d.ServerName,
|
ServerName: d.ServerName,
|
||||||
}
|
}
|
||||||
|
|
85
pkg/audit/writer.go
Normal file
85
pkg/audit/writer.go
Normal file
|
@ -0,0 +1,85 @@
|
||||||
|
//go:generate mockgen -source=$GOFILE -destination=./../../internal/mock/audit/writer.mock.go -package=audit_mock
|
||||||
|
|
||||||
|
package audit
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/binary"
|
||||||
|
"errors"
|
||||||
|
"io"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"google.golang.org/protobuf/proto"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
WithWriterByteOrder = func(order binary.ByteOrder) func(writer *eventWriter) {
|
||||||
|
return func(writer *eventWriter) {
|
||||||
|
writer.byteOrder = order
|
||||||
|
}
|
||||||
|
}
|
||||||
|
defaultOrder = binary.BigEndian
|
||||||
|
ErrValueMostNotBeNil = errors.New("event value must not be nil")
|
||||||
|
)
|
||||||
|
|
||||||
|
type Writer interface {
|
||||||
|
io.Closer
|
||||||
|
Write(ev *Event) error
|
||||||
|
}
|
||||||
|
|
||||||
|
type EventWriterOption func(writer *eventWriter)
|
||||||
|
|
||||||
|
func NewEventWriter(target io.Writer, opts ...EventWriterOption) Writer {
|
||||||
|
writer := &eventWriter{
|
||||||
|
target: target,
|
||||||
|
byteOrder: defaultOrder,
|
||||||
|
lengthPool: &sync.Pool{
|
||||||
|
New: func() interface{} {
|
||||||
|
return make([]byte, lengthBufferSize)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, opt := range opts {
|
||||||
|
opt(writer)
|
||||||
|
}
|
||||||
|
|
||||||
|
return writer
|
||||||
|
}
|
||||||
|
|
||||||
|
type eventWriter struct {
|
||||||
|
lengthPool *sync.Pool
|
||||||
|
target io.Writer
|
||||||
|
byteOrder binary.ByteOrder
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e eventWriter) Write(ev *Event) (err error) {
|
||||||
|
if ev == nil {
|
||||||
|
return ErrValueMostNotBeNil
|
||||||
|
}
|
||||||
|
var bytes []byte
|
||||||
|
|
||||||
|
if bytes, err = proto.Marshal(ev.ProtoMessage()); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
buf := e.lengthPool.Get().([]byte)
|
||||||
|
e.byteOrder.PutUint32(buf, uint32(len(bytes)))
|
||||||
|
|
||||||
|
if _, err = e.target.Write(buf); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if _, err = e.target.Write(bytes); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if syncerInst, ok := e.target.(syncer); ok {
|
||||||
|
err = syncerInst.Sync()
|
||||||
|
}
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e eventWriter) Close() error {
|
||||||
|
if closer, isCloser := e.target.(io.Closer); isCloser {
|
||||||
|
return closer.Close()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
47
pkg/audit/writer_sink.go
Normal file
47
pkg/audit/writer_sink.go
Normal file
|
@ -0,0 +1,47 @@
|
||||||
|
package audit
|
||||||
|
|
||||||
|
type WriterSinkOption func(sink *writerCloserSink)
|
||||||
|
|
||||||
|
var (
|
||||||
|
WithCloseOnExit WriterSinkOption = func(sink *writerCloserSink) {
|
||||||
|
sink.closeOnExit = true
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
func NewWriterSink(name string, target Writer, opts ...WriterSinkOption) Sink {
|
||||||
|
sink := &writerCloserSink{
|
||||||
|
name: name,
|
||||||
|
target: target,
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, opt := range opts {
|
||||||
|
opt(sink)
|
||||||
|
}
|
||||||
|
|
||||||
|
return sink
|
||||||
|
}
|
||||||
|
|
||||||
|
type writerCloserSink struct {
|
||||||
|
name string
|
||||||
|
target Writer
|
||||||
|
closeOnExit bool
|
||||||
|
}
|
||||||
|
|
||||||
|
type syncer interface {
|
||||||
|
Sync() error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f writerCloserSink) Name() string {
|
||||||
|
return f.name
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f writerCloserSink) OnSubscribe(evs <-chan Event) {
|
||||||
|
go func(target Writer, closeOnExit bool, evs <-chan Event) {
|
||||||
|
for ev := range evs {
|
||||||
|
_ = target.Write(&ev)
|
||||||
|
}
|
||||||
|
if closeOnExit {
|
||||||
|
_ = target.Close()
|
||||||
|
}
|
||||||
|
}(f.target, f.closeOnExit, evs)
|
||||||
|
}
|
72
pkg/audit/writer_sink_test.go
Normal file
72
pkg/audit/writer_sink_test.go
Normal file
|
@ -0,0 +1,72 @@
|
||||||
|
package audit_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/golang/mock/gomock"
|
||||||
|
audit_mock "gitlab.com/inetmock/inetmock/internal/mock/audit"
|
||||||
|
"gitlab.com/inetmock/inetmock/pkg/audit"
|
||||||
|
"gitlab.com/inetmock/inetmock/pkg/logging"
|
||||||
|
)
|
||||||
|
|
||||||
|
func Test_writerCloserSink_OnSubscribe(t *testing.T) {
|
||||||
|
type testCase struct {
|
||||||
|
name string
|
||||||
|
events []audit.Event
|
||||||
|
}
|
||||||
|
tests := []testCase{
|
||||||
|
{
|
||||||
|
name: "Get a single event",
|
||||||
|
events: testEvents[:1],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Get multiple events",
|
||||||
|
events: testEvents,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
scenario := func(tt testCase) func(t *testing.T) {
|
||||||
|
return func(t *testing.T) {
|
||||||
|
wg := new(sync.WaitGroup)
|
||||||
|
wg.Add(len(tt.events))
|
||||||
|
|
||||||
|
ctrl := gomock.NewController(t)
|
||||||
|
t.Cleanup(ctrl.Finish)
|
||||||
|
writerMock := audit_mock.NewMockWriter(ctrl)
|
||||||
|
writerMock.
|
||||||
|
EXPECT().
|
||||||
|
Write(gomock.Any()).
|
||||||
|
Do(func(_ *audit.Event) {
|
||||||
|
wg.Done()
|
||||||
|
}).
|
||||||
|
Times(len(tt.events))
|
||||||
|
|
||||||
|
writerCloserSink := audit.NewWriterSink("WriterMock", writerMock, audit.WithCloseOnExit)
|
||||||
|
var evs audit.EventStream
|
||||||
|
var err error
|
||||||
|
|
||||||
|
if evs, err = audit.NewEventStream(logging.CreateTestLogger(t)); err != nil {
|
||||||
|
t.Errorf("NewEventStream() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = evs.RegisterSink(writerCloserSink); err != nil {
|
||||||
|
t.Errorf("RegisterSink() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, ev := range tt.events {
|
||||||
|
evs.Emit(ev)
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-time.After(100 * time.Millisecond):
|
||||||
|
t.Errorf("not all events recorded in time")
|
||||||
|
case <-waitGroupDone(wg):
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, scenario(tt))
|
||||||
|
}
|
||||||
|
}
|
117
pkg/audit/writer_test.go
Normal file
117
pkg/audit/writer_test.go
Normal file
|
@ -0,0 +1,117 @@
|
||||||
|
//go:generate mockgen -source=$GOFILE -destination=./../../internal/mock/audit/writercloser.mock.go -package=audit_mock
|
||||||
|
|
||||||
|
package audit_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/binary"
|
||||||
|
"encoding/hex"
|
||||||
|
"io"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/golang/mock/gomock"
|
||||||
|
audit_mock "gitlab.com/inetmock/inetmock/internal/mock/audit"
|
||||||
|
"gitlab.com/inetmock/inetmock/pkg/audit"
|
||||||
|
)
|
||||||
|
|
||||||
|
type WriterCloserSyncer interface {
|
||||||
|
io.WriteCloser
|
||||||
|
Sync() error
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_eventWriter_Write(t *testing.T) {
|
||||||
|
type fields struct {
|
||||||
|
order binary.ByteOrder
|
||||||
|
}
|
||||||
|
type args struct {
|
||||||
|
evs []audit.Event
|
||||||
|
}
|
||||||
|
type testCase struct {
|
||||||
|
name string
|
||||||
|
fields fields
|
||||||
|
args args
|
||||||
|
wantErr bool
|
||||||
|
}
|
||||||
|
tests := []testCase{
|
||||||
|
{
|
||||||
|
name: "Write a single event - little endian",
|
||||||
|
fields: fields{
|
||||||
|
order: binary.LittleEndian,
|
||||||
|
},
|
||||||
|
args: args{
|
||||||
|
evs: testEvents[:1],
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Write a single event - big endian",
|
||||||
|
fields: fields{
|
||||||
|
order: binary.BigEndian,
|
||||||
|
},
|
||||||
|
args: args{
|
||||||
|
evs: testEvents[:1],
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Write multiple events - little endian",
|
||||||
|
fields: fields{
|
||||||
|
order: binary.LittleEndian,
|
||||||
|
},
|
||||||
|
args: args{
|
||||||
|
evs: testEvents,
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Write multiple events - big endian",
|
||||||
|
fields: fields{
|
||||||
|
order: binary.BigEndian,
|
||||||
|
},
|
||||||
|
args: args{
|
||||||
|
evs: testEvents,
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
scenario := func(tt testCase) func(t *testing.T) {
|
||||||
|
return func(t *testing.T) {
|
||||||
|
ctrl := gomock.NewController(t)
|
||||||
|
t.Cleanup(ctrl.Finish)
|
||||||
|
writerMock := audit_mock.NewMockWriterCloserSyncer(ctrl)
|
||||||
|
calls := make([]*gomock.Call, 0)
|
||||||
|
for i := 0; i < len(tt.args.evs); i++ {
|
||||||
|
calls = append(calls,
|
||||||
|
writerMock.
|
||||||
|
EXPECT().
|
||||||
|
Write(gomock.Any()).
|
||||||
|
Do(func(data []byte) {
|
||||||
|
t.Logf("got payload = %s", hex.EncodeToString(data))
|
||||||
|
t.Logf("got length %d", tt.fields.order.Uint32(data))
|
||||||
|
}),
|
||||||
|
writerMock.
|
||||||
|
EXPECT().
|
||||||
|
Write(gomock.Any()).
|
||||||
|
Do(func(data []byte) {
|
||||||
|
t.Logf("got payload = %s", hex.EncodeToString(data))
|
||||||
|
}),
|
||||||
|
writerMock.
|
||||||
|
EXPECT().
|
||||||
|
Sync(),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
gomock.InOrder(calls...)
|
||||||
|
|
||||||
|
e := audit.NewEventWriter(writerMock, audit.WithWriterByteOrder(tt.fields.order))
|
||||||
|
|
||||||
|
for _, ev := range tt.args.evs {
|
||||||
|
if err := e.Write(&ev); (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("Write() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, scenario(tt))
|
||||||
|
}
|
||||||
|
}
|
|
@ -40,7 +40,7 @@ func AddTLSInterceptor(registry api.HandlerRegistry) (err error) {
|
||||||
registry.RegisterHandler(name, func() api.ProtocolHandler {
|
registry.RegisterHandler(name, func() api.ProtocolHandler {
|
||||||
return &tlsInterceptor{
|
return &tlsInterceptor{
|
||||||
logger: logger,
|
logger: logger,
|
||||||
currentConnectionsCount: &sync.WaitGroup{},
|
currentConnectionsCount: new(sync.WaitGroup),
|
||||||
currentConnections: make(map[uuid.UUID]*proxyConn),
|
currentConnections: make(map[uuid.UUID]*proxyConn),
|
||||||
connectionsMutex: &sync.Mutex{},
|
connectionsMutex: &sync.Mutex{},
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue