Subscribe sinks with a context to automatically remove them when context is canceled

This commit is contained in:
Peter 2021-01-26 18:50:53 +01:00
parent 6d2737b501
commit cc72595398
Signed by: prskr
GPG key ID: C1DB5D2E8DB512F9
10 changed files with 30 additions and 31 deletions

View file

@ -286,7 +286,7 @@ func (a *app) WithEventStream() App {
return
}
if err = eventStream.RegisterSink(sink.NewLogSink(a.Logger().Named("LogSink"))); err != nil {
if err = eventStream.RegisterSink(a.ctx, sink.NewLogSink(a.Logger().Named("LogSink"))); err != nil {
return
}

View file

@ -25,7 +25,7 @@ func (a *auditServer) ListSinks(context.Context, *ListSinksRequest) (*ListSinksR
func (a *auditServer) WatchEvents(req *WatchEventsRequest, srv Audit_WatchEventsServer) (err error) {
a.logger.Info("watcher attached", zap.String("name", req.WatcherName))
err = a.eventStream.RegisterSink(sink.NewGRPCSink(srv.Context(), req.WatcherName, func(ev audit.Event) {
err = a.eventStream.RegisterSink(srv.Context(), sink.NewGRPCSink(req.WatcherName, func(ev audit.Event) {
if err = srv.Send(ev.ProtoMessage()); err != nil {
return
}
@ -59,7 +59,7 @@ func (a *auditServer) RegisterFileSink(_ context.Context, req *RegisterFileSinkR
if writer, err = os.OpenFile(req.TargetPath, flags, permissions); err != nil {
return
}
if err = a.eventStream.RegisterSink(sink.NewWriterSink(req.TargetPath, audit.NewEventWriter(writer))); err != nil {
if err = a.eventStream.RegisterSink(context.Background(), sink.NewWriterSink(req.TargetPath, audit.NewEventWriter(writer))); err != nil {
return
}
resp = &RegisterFileSinkResponse{}

View file

@ -3,6 +3,7 @@
package audit
import (
"context"
"errors"
"io"
)
@ -15,17 +16,15 @@ type Emitter interface {
Emit(ev Event)
}
type CloseHandle func()
type Sink interface {
Name() string
OnSubscribe(evs <-chan Event, close CloseHandle)
OnSubscribe(evs <-chan Event)
}
type EventStream interface {
io.Closer
Emitter
RegisterSink(s Sink) error
RegisterSink(ctx context.Context, s Sink) error
Sinks() []string
RemoveSink(name string) (exists bool)
}

View file

@ -1,6 +1,7 @@
package audit
import (
"context"
"sync"
"time"
@ -89,7 +90,7 @@ func (e *eventStream) RemoveSink(name string) (exists bool) {
return
}
func (e *eventStream) RegisterSink(s Sink) error {
func (e *eventStream) RegisterSink(ctx context.Context, s Sink) error {
name := s.Name()
if _, exists := e.sinks[name]; exists {
@ -101,9 +102,12 @@ func (e *eventStream) RegisterSink(s Sink) error {
lock: new(sync.Mutex),
}
s.OnSubscribe(rs.downstream, func() {
s.OnSubscribe(rs.downstream)
go func() {
<-ctx.Done()
e.RemoveSink(name)
})
}()
e.sinks[name] = rs
return nil

View file

@ -1,6 +1,7 @@
package audit_test
import (
"context"
"crypto/tls"
"net"
"net/http"
@ -61,7 +62,7 @@ func (t *testSink) Name() string {
return t.name
}
func (t *testSink) OnSubscribe(evs <-chan audit.Event, _ audit.CloseHandle) {
func (t *testSink) OnSubscribe(evs <-chan audit.Event) {
go func() {
for ev := range evs {
if t.consumer != nil {
@ -105,7 +106,7 @@ func Test_eventStream_RegisterSink(t *testing.T) {
s: defaultSink,
},
setup: func(e audit.EventStream) {
_ = e.RegisterSink(defaultSink)
_ = e.RegisterSink(context.Background(), defaultSink)
},
wantErr: true,
},
@ -126,7 +127,7 @@ func Test_eventStream_RegisterSink(t *testing.T) {
tt.setup(e)
}
if err := e.RegisterSink(tt.args.s); (err != nil) != tt.wantErr {
if err := e.RegisterSink(context.Background(), tt.args.s); (err != nil) != tt.wantErr {
t.Errorf("RegisterSink() error = %v, wantErr %v", err, tt.wantErr)
}
@ -203,7 +204,7 @@ func Test_eventStream_Emit(t *testing.T) {
if tt.subscribe {
receivedWaitGroup.Add(len(tt.args.evs))
if err := e.RegisterSink(wgMockSink(t, receivedWaitGroup)); err != nil {
if err := e.RegisterSink(context.Background(), wgMockSink(t, receivedWaitGroup)); err != nil {
t.Errorf("RegisterSink() error = %v", err)
}
}

View file

@ -6,10 +6,9 @@ import (
"gitlab.com/inetmock/inetmock/pkg/audit"
)
func NewGRPCSink(ctx context.Context, name string, consumer func(ev audit.Event)) audit.Sink {
func NewGRPCSink(name string, consumer func(ev audit.Event)) audit.Sink {
return &grpcSink{
name: name,
ctx: ctx,
consumer: consumer,
}
}
@ -24,16 +23,10 @@ func (g grpcSink) Name() string {
return g.name
}
func (g grpcSink) OnSubscribe(evs <-chan audit.Event, handle audit.CloseHandle) {
go func(ctx context.Context, consumer func(ev audit.Event), evs <-chan audit.Event, handle audit.CloseHandle) {
for {
select {
case ev := <-evs:
consumer(ev)
case <-ctx.Done():
handle()
return
}
func (g grpcSink) OnSubscribe(evs <-chan audit.Event) {
go func(consumer func(ev audit.Event), evs <-chan audit.Event) {
for ev := range evs {
consumer(ev)
}
}(g.ctx, g.consumer, evs, handle)
}(g.consumer, evs)
}

View file

@ -24,7 +24,7 @@ func (logSink) Name() string {
return logSinkName
}
func (l logSink) OnSubscribe(evs <-chan audit.Event, _ audit.CloseHandle) {
func (l logSink) OnSubscribe(evs <-chan audit.Event) {
go func(logger logging.Logger, evs <-chan audit.Event) {
for ev := range evs {
eventLogger := logger

View file

@ -1,6 +1,7 @@
package sink_test
import (
"context"
"sync"
"testing"
"time"
@ -90,7 +91,7 @@ func Test_logSink_OnSubscribe(t *testing.T) {
t.Errorf("NewEventStream() error = %v", err)
}
if err = evs.RegisterSink(logSink); err != nil {
if err = evs.RegisterSink(context.Background(), logSink); err != nil {
t.Errorf("RegisterSink() error = %v", err)
}

View file

@ -33,7 +33,7 @@ func (f writerCloserSink) Name() string {
return f.name
}
func (f writerCloserSink) OnSubscribe(evs <-chan audit.Event, _ audit.CloseHandle) {
func (f writerCloserSink) OnSubscribe(evs <-chan audit.Event) {
go func(target audit.Writer, closeOnExit bool, evs <-chan audit.Event) {
for ev := range evs {
_ = target.Write(&ev)

View file

@ -1,6 +1,7 @@
package sink_test
import (
"context"
"crypto/tls"
"net"
"net/http"
@ -91,7 +92,7 @@ func Test_writerCloserSink_OnSubscribe(t *testing.T) {
t.Errorf("NewEventStream() error = %v", err)
}
if err = evs.RegisterSink(writerCloserSink); err != nil {
if err = evs.RegisterSink(context.Background(), writerCloserSink); err != nil {
t.Errorf("RegisterSink() error = %v", err)
}