Subscribe sinks with a context to automatically remove them when context is canceled

This commit is contained in:
Peter 2021-01-26 18:50:53 +01:00
parent 6d2737b501
commit cc72595398
Signed by: prskr
GPG key ID: C1DB5D2E8DB512F9
10 changed files with 30 additions and 31 deletions

View file

@ -286,7 +286,7 @@ func (a *app) WithEventStream() App {
return return
} }
if err = eventStream.RegisterSink(sink.NewLogSink(a.Logger().Named("LogSink"))); err != nil { if err = eventStream.RegisterSink(a.ctx, sink.NewLogSink(a.Logger().Named("LogSink"))); err != nil {
return return
} }

View file

@ -25,7 +25,7 @@ func (a *auditServer) ListSinks(context.Context, *ListSinksRequest) (*ListSinksR
func (a *auditServer) WatchEvents(req *WatchEventsRequest, srv Audit_WatchEventsServer) (err error) { func (a *auditServer) WatchEvents(req *WatchEventsRequest, srv Audit_WatchEventsServer) (err error) {
a.logger.Info("watcher attached", zap.String("name", req.WatcherName)) a.logger.Info("watcher attached", zap.String("name", req.WatcherName))
err = a.eventStream.RegisterSink(sink.NewGRPCSink(srv.Context(), req.WatcherName, func(ev audit.Event) { err = a.eventStream.RegisterSink(srv.Context(), sink.NewGRPCSink(req.WatcherName, func(ev audit.Event) {
if err = srv.Send(ev.ProtoMessage()); err != nil { if err = srv.Send(ev.ProtoMessage()); err != nil {
return return
} }
@ -59,7 +59,7 @@ func (a *auditServer) RegisterFileSink(_ context.Context, req *RegisterFileSinkR
if writer, err = os.OpenFile(req.TargetPath, flags, permissions); err != nil { if writer, err = os.OpenFile(req.TargetPath, flags, permissions); err != nil {
return return
} }
if err = a.eventStream.RegisterSink(sink.NewWriterSink(req.TargetPath, audit.NewEventWriter(writer))); err != nil { if err = a.eventStream.RegisterSink(context.Background(), sink.NewWriterSink(req.TargetPath, audit.NewEventWriter(writer))); err != nil {
return return
} }
resp = &RegisterFileSinkResponse{} resp = &RegisterFileSinkResponse{}

View file

@ -3,6 +3,7 @@
package audit package audit
import ( import (
"context"
"errors" "errors"
"io" "io"
) )
@ -15,17 +16,15 @@ type Emitter interface {
Emit(ev Event) Emit(ev Event)
} }
type CloseHandle func()
type Sink interface { type Sink interface {
Name() string Name() string
OnSubscribe(evs <-chan Event, close CloseHandle) OnSubscribe(evs <-chan Event)
} }
type EventStream interface { type EventStream interface {
io.Closer io.Closer
Emitter Emitter
RegisterSink(s Sink) error RegisterSink(ctx context.Context, s Sink) error
Sinks() []string Sinks() []string
RemoveSink(name string) (exists bool) RemoveSink(name string) (exists bool)
} }

View file

@ -1,6 +1,7 @@
package audit package audit
import ( import (
"context"
"sync" "sync"
"time" "time"
@ -89,7 +90,7 @@ func (e *eventStream) RemoveSink(name string) (exists bool) {
return return
} }
func (e *eventStream) RegisterSink(s Sink) error { func (e *eventStream) RegisterSink(ctx context.Context, s Sink) error {
name := s.Name() name := s.Name()
if _, exists := e.sinks[name]; exists { if _, exists := e.sinks[name]; exists {
@ -101,9 +102,12 @@ func (e *eventStream) RegisterSink(s Sink) error {
lock: new(sync.Mutex), lock: new(sync.Mutex),
} }
s.OnSubscribe(rs.downstream, func() { s.OnSubscribe(rs.downstream)
go func() {
<-ctx.Done()
e.RemoveSink(name) e.RemoveSink(name)
}) }()
e.sinks[name] = rs e.sinks[name] = rs
return nil return nil

View file

@ -1,6 +1,7 @@
package audit_test package audit_test
import ( import (
"context"
"crypto/tls" "crypto/tls"
"net" "net"
"net/http" "net/http"
@ -61,7 +62,7 @@ func (t *testSink) Name() string {
return t.name return t.name
} }
func (t *testSink) OnSubscribe(evs <-chan audit.Event, _ audit.CloseHandle) { func (t *testSink) OnSubscribe(evs <-chan audit.Event) {
go func() { go func() {
for ev := range evs { for ev := range evs {
if t.consumer != nil { if t.consumer != nil {
@ -105,7 +106,7 @@ func Test_eventStream_RegisterSink(t *testing.T) {
s: defaultSink, s: defaultSink,
}, },
setup: func(e audit.EventStream) { setup: func(e audit.EventStream) {
_ = e.RegisterSink(defaultSink) _ = e.RegisterSink(context.Background(), defaultSink)
}, },
wantErr: true, wantErr: true,
}, },
@ -126,7 +127,7 @@ func Test_eventStream_RegisterSink(t *testing.T) {
tt.setup(e) tt.setup(e)
} }
if err := e.RegisterSink(tt.args.s); (err != nil) != tt.wantErr { if err := e.RegisterSink(context.Background(), tt.args.s); (err != nil) != tt.wantErr {
t.Errorf("RegisterSink() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("RegisterSink() error = %v, wantErr %v", err, tt.wantErr)
} }
@ -203,7 +204,7 @@ func Test_eventStream_Emit(t *testing.T) {
if tt.subscribe { if tt.subscribe {
receivedWaitGroup.Add(len(tt.args.evs)) receivedWaitGroup.Add(len(tt.args.evs))
if err := e.RegisterSink(wgMockSink(t, receivedWaitGroup)); err != nil { if err := e.RegisterSink(context.Background(), wgMockSink(t, receivedWaitGroup)); err != nil {
t.Errorf("RegisterSink() error = %v", err) t.Errorf("RegisterSink() error = %v", err)
} }
} }

View file

@ -6,10 +6,9 @@ import (
"gitlab.com/inetmock/inetmock/pkg/audit" "gitlab.com/inetmock/inetmock/pkg/audit"
) )
func NewGRPCSink(ctx context.Context, name string, consumer func(ev audit.Event)) audit.Sink { func NewGRPCSink(name string, consumer func(ev audit.Event)) audit.Sink {
return &grpcSink{ return &grpcSink{
name: name, name: name,
ctx: ctx,
consumer: consumer, consumer: consumer,
} }
} }
@ -24,16 +23,10 @@ func (g grpcSink) Name() string {
return g.name return g.name
} }
func (g grpcSink) OnSubscribe(evs <-chan audit.Event, handle audit.CloseHandle) { func (g grpcSink) OnSubscribe(evs <-chan audit.Event) {
go func(ctx context.Context, consumer func(ev audit.Event), evs <-chan audit.Event, handle audit.CloseHandle) { go func(consumer func(ev audit.Event), evs <-chan audit.Event) {
for { for ev := range evs {
select { consumer(ev)
case ev := <-evs:
consumer(ev)
case <-ctx.Done():
handle()
return
}
} }
}(g.ctx, g.consumer, evs, handle) }(g.consumer, evs)
} }

View file

@ -24,7 +24,7 @@ func (logSink) Name() string {
return logSinkName return logSinkName
} }
func (l logSink) OnSubscribe(evs <-chan audit.Event, _ audit.CloseHandle) { func (l logSink) OnSubscribe(evs <-chan audit.Event) {
go func(logger logging.Logger, evs <-chan audit.Event) { go func(logger logging.Logger, evs <-chan audit.Event) {
for ev := range evs { for ev := range evs {
eventLogger := logger eventLogger := logger

View file

@ -1,6 +1,7 @@
package sink_test package sink_test
import ( import (
"context"
"sync" "sync"
"testing" "testing"
"time" "time"
@ -90,7 +91,7 @@ func Test_logSink_OnSubscribe(t *testing.T) {
t.Errorf("NewEventStream() error = %v", err) t.Errorf("NewEventStream() error = %v", err)
} }
if err = evs.RegisterSink(logSink); err != nil { if err = evs.RegisterSink(context.Background(), logSink); err != nil {
t.Errorf("RegisterSink() error = %v", err) t.Errorf("RegisterSink() error = %v", err)
} }

View file

@ -33,7 +33,7 @@ func (f writerCloserSink) Name() string {
return f.name return f.name
} }
func (f writerCloserSink) OnSubscribe(evs <-chan audit.Event, _ audit.CloseHandle) { func (f writerCloserSink) OnSubscribe(evs <-chan audit.Event) {
go func(target audit.Writer, closeOnExit bool, evs <-chan audit.Event) { go func(target audit.Writer, closeOnExit bool, evs <-chan audit.Event) {
for ev := range evs { for ev := range evs {
_ = target.Write(&ev) _ = target.Write(&ev)

View file

@ -1,6 +1,7 @@
package sink_test package sink_test
import ( import (
"context"
"crypto/tls" "crypto/tls"
"net" "net"
"net/http" "net/http"
@ -91,7 +92,7 @@ func Test_writerCloserSink_OnSubscribe(t *testing.T) {
t.Errorf("NewEventStream() error = %v", err) t.Errorf("NewEventStream() error = %v", err)
} }
if err = evs.RegisterSink(writerCloserSink); err != nil { if err = evs.RegisterSink(context.Background(), writerCloserSink); err != nil {
t.Errorf("RegisterSink() error = %v", err) t.Errorf("RegisterSink() error = %v", err)
} }