Subscribe sinks with a context to automatically remove them when context is canceled
This commit is contained in:
parent
6d2737b501
commit
cc72595398
10 changed files with 30 additions and 31 deletions
|
@ -286,7 +286,7 @@ func (a *app) WithEventStream() App {
|
|||
return
|
||||
}
|
||||
|
||||
if err = eventStream.RegisterSink(sink.NewLogSink(a.Logger().Named("LogSink"))); err != nil {
|
||||
if err = eventStream.RegisterSink(a.ctx, sink.NewLogSink(a.Logger().Named("LogSink"))); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
|
|
|
@ -25,7 +25,7 @@ func (a *auditServer) ListSinks(context.Context, *ListSinksRequest) (*ListSinksR
|
|||
|
||||
func (a *auditServer) WatchEvents(req *WatchEventsRequest, srv Audit_WatchEventsServer) (err error) {
|
||||
a.logger.Info("watcher attached", zap.String("name", req.WatcherName))
|
||||
err = a.eventStream.RegisterSink(sink.NewGRPCSink(srv.Context(), req.WatcherName, func(ev audit.Event) {
|
||||
err = a.eventStream.RegisterSink(srv.Context(), sink.NewGRPCSink(req.WatcherName, func(ev audit.Event) {
|
||||
if err = srv.Send(ev.ProtoMessage()); err != nil {
|
||||
return
|
||||
}
|
||||
|
@ -59,7 +59,7 @@ func (a *auditServer) RegisterFileSink(_ context.Context, req *RegisterFileSinkR
|
|||
if writer, err = os.OpenFile(req.TargetPath, flags, permissions); err != nil {
|
||||
return
|
||||
}
|
||||
if err = a.eventStream.RegisterSink(sink.NewWriterSink(req.TargetPath, audit.NewEventWriter(writer))); err != nil {
|
||||
if err = a.eventStream.RegisterSink(context.Background(), sink.NewWriterSink(req.TargetPath, audit.NewEventWriter(writer))); err != nil {
|
||||
return
|
||||
}
|
||||
resp = &RegisterFileSinkResponse{}
|
||||
|
|
|
@ -3,6 +3,7 @@
|
|||
package audit
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
)
|
||||
|
@ -15,17 +16,15 @@ type Emitter interface {
|
|||
Emit(ev Event)
|
||||
}
|
||||
|
||||
type CloseHandle func()
|
||||
|
||||
type Sink interface {
|
||||
Name() string
|
||||
OnSubscribe(evs <-chan Event, close CloseHandle)
|
||||
OnSubscribe(evs <-chan Event)
|
||||
}
|
||||
|
||||
type EventStream interface {
|
||||
io.Closer
|
||||
Emitter
|
||||
RegisterSink(s Sink) error
|
||||
RegisterSink(ctx context.Context, s Sink) error
|
||||
Sinks() []string
|
||||
RemoveSink(name string) (exists bool)
|
||||
}
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package audit
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
|
@ -89,7 +90,7 @@ func (e *eventStream) RemoveSink(name string) (exists bool) {
|
|||
return
|
||||
}
|
||||
|
||||
func (e *eventStream) RegisterSink(s Sink) error {
|
||||
func (e *eventStream) RegisterSink(ctx context.Context, s Sink) error {
|
||||
name := s.Name()
|
||||
|
||||
if _, exists := e.sinks[name]; exists {
|
||||
|
@ -101,9 +102,12 @@ func (e *eventStream) RegisterSink(s Sink) error {
|
|||
lock: new(sync.Mutex),
|
||||
}
|
||||
|
||||
s.OnSubscribe(rs.downstream, func() {
|
||||
s.OnSubscribe(rs.downstream)
|
||||
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
e.RemoveSink(name)
|
||||
})
|
||||
}()
|
||||
|
||||
e.sinks[name] = rs
|
||||
return nil
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package audit_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"net"
|
||||
"net/http"
|
||||
|
@ -61,7 +62,7 @@ func (t *testSink) Name() string {
|
|||
return t.name
|
||||
}
|
||||
|
||||
func (t *testSink) OnSubscribe(evs <-chan audit.Event, _ audit.CloseHandle) {
|
||||
func (t *testSink) OnSubscribe(evs <-chan audit.Event) {
|
||||
go func() {
|
||||
for ev := range evs {
|
||||
if t.consumer != nil {
|
||||
|
@ -105,7 +106,7 @@ func Test_eventStream_RegisterSink(t *testing.T) {
|
|||
s: defaultSink,
|
||||
},
|
||||
setup: func(e audit.EventStream) {
|
||||
_ = e.RegisterSink(defaultSink)
|
||||
_ = e.RegisterSink(context.Background(), defaultSink)
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
|
@ -126,7 +127,7 @@ func Test_eventStream_RegisterSink(t *testing.T) {
|
|||
tt.setup(e)
|
||||
}
|
||||
|
||||
if err := e.RegisterSink(tt.args.s); (err != nil) != tt.wantErr {
|
||||
if err := e.RegisterSink(context.Background(), tt.args.s); (err != nil) != tt.wantErr {
|
||||
t.Errorf("RegisterSink() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
|
||||
|
@ -203,7 +204,7 @@ func Test_eventStream_Emit(t *testing.T) {
|
|||
|
||||
if tt.subscribe {
|
||||
receivedWaitGroup.Add(len(tt.args.evs))
|
||||
if err := e.RegisterSink(wgMockSink(t, receivedWaitGroup)); err != nil {
|
||||
if err := e.RegisterSink(context.Background(), wgMockSink(t, receivedWaitGroup)); err != nil {
|
||||
t.Errorf("RegisterSink() error = %v", err)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -6,10 +6,9 @@ import (
|
|||
"gitlab.com/inetmock/inetmock/pkg/audit"
|
||||
)
|
||||
|
||||
func NewGRPCSink(ctx context.Context, name string, consumer func(ev audit.Event)) audit.Sink {
|
||||
func NewGRPCSink(name string, consumer func(ev audit.Event)) audit.Sink {
|
||||
return &grpcSink{
|
||||
name: name,
|
||||
ctx: ctx,
|
||||
consumer: consumer,
|
||||
}
|
||||
}
|
||||
|
@ -24,16 +23,10 @@ func (g grpcSink) Name() string {
|
|||
return g.name
|
||||
}
|
||||
|
||||
func (g grpcSink) OnSubscribe(evs <-chan audit.Event, handle audit.CloseHandle) {
|
||||
go func(ctx context.Context, consumer func(ev audit.Event), evs <-chan audit.Event, handle audit.CloseHandle) {
|
||||
for {
|
||||
select {
|
||||
case ev := <-evs:
|
||||
func (g grpcSink) OnSubscribe(evs <-chan audit.Event) {
|
||||
go func(consumer func(ev audit.Event), evs <-chan audit.Event) {
|
||||
for ev := range evs {
|
||||
consumer(ev)
|
||||
case <-ctx.Done():
|
||||
handle()
|
||||
return
|
||||
}
|
||||
}
|
||||
}(g.ctx, g.consumer, evs, handle)
|
||||
}(g.consumer, evs)
|
||||
}
|
||||
|
|
|
@ -24,7 +24,7 @@ func (logSink) Name() string {
|
|||
return logSinkName
|
||||
}
|
||||
|
||||
func (l logSink) OnSubscribe(evs <-chan audit.Event, _ audit.CloseHandle) {
|
||||
func (l logSink) OnSubscribe(evs <-chan audit.Event) {
|
||||
go func(logger logging.Logger, evs <-chan audit.Event) {
|
||||
for ev := range evs {
|
||||
eventLogger := logger
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package sink_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
@ -90,7 +91,7 @@ func Test_logSink_OnSubscribe(t *testing.T) {
|
|||
t.Errorf("NewEventStream() error = %v", err)
|
||||
}
|
||||
|
||||
if err = evs.RegisterSink(logSink); err != nil {
|
||||
if err = evs.RegisterSink(context.Background(), logSink); err != nil {
|
||||
t.Errorf("RegisterSink() error = %v", err)
|
||||
}
|
||||
|
||||
|
|
|
@ -33,7 +33,7 @@ func (f writerCloserSink) Name() string {
|
|||
return f.name
|
||||
}
|
||||
|
||||
func (f writerCloserSink) OnSubscribe(evs <-chan audit.Event, _ audit.CloseHandle) {
|
||||
func (f writerCloserSink) OnSubscribe(evs <-chan audit.Event) {
|
||||
go func(target audit.Writer, closeOnExit bool, evs <-chan audit.Event) {
|
||||
for ev := range evs {
|
||||
_ = target.Write(&ev)
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package sink_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"net"
|
||||
"net/http"
|
||||
|
@ -91,7 +92,7 @@ func Test_writerCloserSink_OnSubscribe(t *testing.T) {
|
|||
t.Errorf("NewEventStream() error = %v", err)
|
||||
}
|
||||
|
||||
if err = evs.RegisterSink(writerCloserSink); err != nil {
|
||||
if err = evs.RegisterSink(context.Background(), writerCloserSink); err != nil {
|
||||
t.Errorf("RegisterSink() error = %v", err)
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in a new issue