diff --git a/internal/app/app.go b/internal/app/app.go index 48c137b..871b737 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -286,7 +286,7 @@ func (a *app) WithEventStream() App { return } - if err = eventStream.RegisterSink(sink.NewLogSink(a.Logger().Named("LogSink"))); err != nil { + if err = eventStream.RegisterSink(a.ctx, sink.NewLogSink(a.Logger().Named("LogSink"))); err != nil { return } diff --git a/internal/rpc/audit_server.go b/internal/rpc/audit_server.go index 0fcfa8a..b9ac4da 100644 --- a/internal/rpc/audit_server.go +++ b/internal/rpc/audit_server.go @@ -25,7 +25,7 @@ func (a *auditServer) ListSinks(context.Context, *ListSinksRequest) (*ListSinksR func (a *auditServer) WatchEvents(req *WatchEventsRequest, srv Audit_WatchEventsServer) (err error) { a.logger.Info("watcher attached", zap.String("name", req.WatcherName)) - err = a.eventStream.RegisterSink(sink.NewGRPCSink(srv.Context(), req.WatcherName, func(ev audit.Event) { + err = a.eventStream.RegisterSink(srv.Context(), sink.NewGRPCSink(req.WatcherName, func(ev audit.Event) { if err = srv.Send(ev.ProtoMessage()); err != nil { return } @@ -59,7 +59,7 @@ func (a *auditServer) RegisterFileSink(_ context.Context, req *RegisterFileSinkR if writer, err = os.OpenFile(req.TargetPath, flags, permissions); err != nil { return } - if err = a.eventStream.RegisterSink(sink.NewWriterSink(req.TargetPath, audit.NewEventWriter(writer))); err != nil { + if err = a.eventStream.RegisterSink(context.Background(), sink.NewWriterSink(req.TargetPath, audit.NewEventWriter(writer))); err != nil { return } resp = &RegisterFileSinkResponse{} diff --git a/pkg/audit/api.go b/pkg/audit/api.go index 4129a2d..b3a789f 100644 --- a/pkg/audit/api.go +++ b/pkg/audit/api.go @@ -3,6 +3,7 @@ package audit import ( + "context" "errors" "io" ) @@ -15,17 +16,15 @@ type Emitter interface { Emit(ev Event) } -type CloseHandle func() - type Sink interface { Name() string - OnSubscribe(evs <-chan Event, close CloseHandle) + OnSubscribe(evs <-chan Event) } type EventStream interface { io.Closer Emitter - RegisterSink(s Sink) error + RegisterSink(ctx context.Context, s Sink) error Sinks() []string RemoveSink(name string) (exists bool) } diff --git a/pkg/audit/event_stream.go b/pkg/audit/event_stream.go index 6dbd522..262a29d 100644 --- a/pkg/audit/event_stream.go +++ b/pkg/audit/event_stream.go @@ -1,6 +1,7 @@ package audit import ( + "context" "sync" "time" @@ -89,7 +90,7 @@ func (e *eventStream) RemoveSink(name string) (exists bool) { return } -func (e *eventStream) RegisterSink(s Sink) error { +func (e *eventStream) RegisterSink(ctx context.Context, s Sink) error { name := s.Name() if _, exists := e.sinks[name]; exists { @@ -101,9 +102,12 @@ func (e *eventStream) RegisterSink(s Sink) error { lock: new(sync.Mutex), } - s.OnSubscribe(rs.downstream, func() { + s.OnSubscribe(rs.downstream) + + go func() { + <-ctx.Done() e.RemoveSink(name) - }) + }() e.sinks[name] = rs return nil diff --git a/pkg/audit/event_stream_test.go b/pkg/audit/event_stream_test.go index e6b0e3d..5be54a2 100644 --- a/pkg/audit/event_stream_test.go +++ b/pkg/audit/event_stream_test.go @@ -1,6 +1,7 @@ package audit_test import ( + "context" "crypto/tls" "net" "net/http" @@ -61,7 +62,7 @@ func (t *testSink) Name() string { return t.name } -func (t *testSink) OnSubscribe(evs <-chan audit.Event, _ audit.CloseHandle) { +func (t *testSink) OnSubscribe(evs <-chan audit.Event) { go func() { for ev := range evs { if t.consumer != nil { @@ -105,7 +106,7 @@ func Test_eventStream_RegisterSink(t *testing.T) { s: defaultSink, }, setup: func(e audit.EventStream) { - _ = e.RegisterSink(defaultSink) + _ = e.RegisterSink(context.Background(), defaultSink) }, wantErr: true, }, @@ -126,7 +127,7 @@ func Test_eventStream_RegisterSink(t *testing.T) { tt.setup(e) } - if err := e.RegisterSink(tt.args.s); (err != nil) != tt.wantErr { + if err := e.RegisterSink(context.Background(), tt.args.s); (err != nil) != tt.wantErr { t.Errorf("RegisterSink() error = %v, wantErr %v", err, tt.wantErr) } @@ -203,7 +204,7 @@ func Test_eventStream_Emit(t *testing.T) { if tt.subscribe { receivedWaitGroup.Add(len(tt.args.evs)) - if err := e.RegisterSink(wgMockSink(t, receivedWaitGroup)); err != nil { + if err := e.RegisterSink(context.Background(), wgMockSink(t, receivedWaitGroup)); err != nil { t.Errorf("RegisterSink() error = %v", err) } } diff --git a/pkg/audit/sink/grpc_sink.go b/pkg/audit/sink/grpc_sink.go index de52156..ee74446 100644 --- a/pkg/audit/sink/grpc_sink.go +++ b/pkg/audit/sink/grpc_sink.go @@ -6,10 +6,9 @@ import ( "gitlab.com/inetmock/inetmock/pkg/audit" ) -func NewGRPCSink(ctx context.Context, name string, consumer func(ev audit.Event)) audit.Sink { +func NewGRPCSink(name string, consumer func(ev audit.Event)) audit.Sink { return &grpcSink{ name: name, - ctx: ctx, consumer: consumer, } } @@ -24,16 +23,10 @@ func (g grpcSink) Name() string { return g.name } -func (g grpcSink) OnSubscribe(evs <-chan audit.Event, handle audit.CloseHandle) { - go func(ctx context.Context, consumer func(ev audit.Event), evs <-chan audit.Event, handle audit.CloseHandle) { - for { - select { - case ev := <-evs: - consumer(ev) - case <-ctx.Done(): - handle() - return - } +func (g grpcSink) OnSubscribe(evs <-chan audit.Event) { + go func(consumer func(ev audit.Event), evs <-chan audit.Event) { + for ev := range evs { + consumer(ev) } - }(g.ctx, g.consumer, evs, handle) + }(g.consumer, evs) } diff --git a/pkg/audit/sink/log_sink.go b/pkg/audit/sink/log_sink.go index 52e9a5f..efebc0b 100644 --- a/pkg/audit/sink/log_sink.go +++ b/pkg/audit/sink/log_sink.go @@ -24,7 +24,7 @@ func (logSink) Name() string { return logSinkName } -func (l logSink) OnSubscribe(evs <-chan audit.Event, _ audit.CloseHandle) { +func (l logSink) OnSubscribe(evs <-chan audit.Event) { go func(logger logging.Logger, evs <-chan audit.Event) { for ev := range evs { eventLogger := logger diff --git a/pkg/audit/sink/log_sink_test.go b/pkg/audit/sink/log_sink_test.go index c64f200..fe072d7 100644 --- a/pkg/audit/sink/log_sink_test.go +++ b/pkg/audit/sink/log_sink_test.go @@ -1,6 +1,7 @@ package sink_test import ( + "context" "sync" "testing" "time" @@ -90,7 +91,7 @@ func Test_logSink_OnSubscribe(t *testing.T) { t.Errorf("NewEventStream() error = %v", err) } - if err = evs.RegisterSink(logSink); err != nil { + if err = evs.RegisterSink(context.Background(), logSink); err != nil { t.Errorf("RegisterSink() error = %v", err) } diff --git a/pkg/audit/sink/writer_sink.go b/pkg/audit/sink/writer_sink.go index 5409dde..e0e2ba5 100644 --- a/pkg/audit/sink/writer_sink.go +++ b/pkg/audit/sink/writer_sink.go @@ -33,7 +33,7 @@ func (f writerCloserSink) Name() string { return f.name } -func (f writerCloserSink) OnSubscribe(evs <-chan audit.Event, _ audit.CloseHandle) { +func (f writerCloserSink) OnSubscribe(evs <-chan audit.Event) { go func(target audit.Writer, closeOnExit bool, evs <-chan audit.Event) { for ev := range evs { _ = target.Write(&ev) diff --git a/pkg/audit/sink/writer_sink_test.go b/pkg/audit/sink/writer_sink_test.go index cb955b9..885a506 100644 --- a/pkg/audit/sink/writer_sink_test.go +++ b/pkg/audit/sink/writer_sink_test.go @@ -1,6 +1,7 @@ package sink_test import ( + "context" "crypto/tls" "net" "net/http" @@ -91,7 +92,7 @@ func Test_writerCloserSink_OnSubscribe(t *testing.T) { t.Errorf("NewEventStream() error = %v", err) } - if err = evs.RegisterSink(writerCloserSink); err != nil { + if err = evs.RegisterSink(context.Background(), writerCloserSink); err != nil { t.Errorf("RegisterSink() error = %v", err) }