Subscribe sinks with a context to automatically remove them when context is canceled
This commit is contained in:
parent
6d2737b501
commit
cc72595398
10 changed files with 30 additions and 31 deletions
|
@ -286,7 +286,7 @@ func (a *app) WithEventStream() App {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = eventStream.RegisterSink(sink.NewLogSink(a.Logger().Named("LogSink"))); err != nil {
|
if err = eventStream.RegisterSink(a.ctx, sink.NewLogSink(a.Logger().Named("LogSink"))); err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -25,7 +25,7 @@ func (a *auditServer) ListSinks(context.Context, *ListSinksRequest) (*ListSinksR
|
||||||
|
|
||||||
func (a *auditServer) WatchEvents(req *WatchEventsRequest, srv Audit_WatchEventsServer) (err error) {
|
func (a *auditServer) WatchEvents(req *WatchEventsRequest, srv Audit_WatchEventsServer) (err error) {
|
||||||
a.logger.Info("watcher attached", zap.String("name", req.WatcherName))
|
a.logger.Info("watcher attached", zap.String("name", req.WatcherName))
|
||||||
err = a.eventStream.RegisterSink(sink.NewGRPCSink(srv.Context(), req.WatcherName, func(ev audit.Event) {
|
err = a.eventStream.RegisterSink(srv.Context(), sink.NewGRPCSink(req.WatcherName, func(ev audit.Event) {
|
||||||
if err = srv.Send(ev.ProtoMessage()); err != nil {
|
if err = srv.Send(ev.ProtoMessage()); err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -59,7 +59,7 @@ func (a *auditServer) RegisterFileSink(_ context.Context, req *RegisterFileSinkR
|
||||||
if writer, err = os.OpenFile(req.TargetPath, flags, permissions); err != nil {
|
if writer, err = os.OpenFile(req.TargetPath, flags, permissions); err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if err = a.eventStream.RegisterSink(sink.NewWriterSink(req.TargetPath, audit.NewEventWriter(writer))); err != nil {
|
if err = a.eventStream.RegisterSink(context.Background(), sink.NewWriterSink(req.TargetPath, audit.NewEventWriter(writer))); err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
resp = &RegisterFileSinkResponse{}
|
resp = &RegisterFileSinkResponse{}
|
||||||
|
|
|
@ -3,6 +3,7 @@
|
||||||
package audit
|
package audit
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"io"
|
"io"
|
||||||
)
|
)
|
||||||
|
@ -15,17 +16,15 @@ type Emitter interface {
|
||||||
Emit(ev Event)
|
Emit(ev Event)
|
||||||
}
|
}
|
||||||
|
|
||||||
type CloseHandle func()
|
|
||||||
|
|
||||||
type Sink interface {
|
type Sink interface {
|
||||||
Name() string
|
Name() string
|
||||||
OnSubscribe(evs <-chan Event, close CloseHandle)
|
OnSubscribe(evs <-chan Event)
|
||||||
}
|
}
|
||||||
|
|
||||||
type EventStream interface {
|
type EventStream interface {
|
||||||
io.Closer
|
io.Closer
|
||||||
Emitter
|
Emitter
|
||||||
RegisterSink(s Sink) error
|
RegisterSink(ctx context.Context, s Sink) error
|
||||||
Sinks() []string
|
Sinks() []string
|
||||||
RemoveSink(name string) (exists bool)
|
RemoveSink(name string) (exists bool)
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package audit
|
package audit
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
@ -89,7 +90,7 @@ func (e *eventStream) RemoveSink(name string) (exists bool) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *eventStream) RegisterSink(s Sink) error {
|
func (e *eventStream) RegisterSink(ctx context.Context, s Sink) error {
|
||||||
name := s.Name()
|
name := s.Name()
|
||||||
|
|
||||||
if _, exists := e.sinks[name]; exists {
|
if _, exists := e.sinks[name]; exists {
|
||||||
|
@ -101,9 +102,12 @@ func (e *eventStream) RegisterSink(s Sink) error {
|
||||||
lock: new(sync.Mutex),
|
lock: new(sync.Mutex),
|
||||||
}
|
}
|
||||||
|
|
||||||
s.OnSubscribe(rs.downstream, func() {
|
s.OnSubscribe(rs.downstream)
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
<-ctx.Done()
|
||||||
e.RemoveSink(name)
|
e.RemoveSink(name)
|
||||||
})
|
}()
|
||||||
|
|
||||||
e.sinks[name] = rs
|
e.sinks[name] = rs
|
||||||
return nil
|
return nil
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package audit_test
|
package audit_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
@ -61,7 +62,7 @@ func (t *testSink) Name() string {
|
||||||
return t.name
|
return t.name
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *testSink) OnSubscribe(evs <-chan audit.Event, _ audit.CloseHandle) {
|
func (t *testSink) OnSubscribe(evs <-chan audit.Event) {
|
||||||
go func() {
|
go func() {
|
||||||
for ev := range evs {
|
for ev := range evs {
|
||||||
if t.consumer != nil {
|
if t.consumer != nil {
|
||||||
|
@ -105,7 +106,7 @@ func Test_eventStream_RegisterSink(t *testing.T) {
|
||||||
s: defaultSink,
|
s: defaultSink,
|
||||||
},
|
},
|
||||||
setup: func(e audit.EventStream) {
|
setup: func(e audit.EventStream) {
|
||||||
_ = e.RegisterSink(defaultSink)
|
_ = e.RegisterSink(context.Background(), defaultSink)
|
||||||
},
|
},
|
||||||
wantErr: true,
|
wantErr: true,
|
||||||
},
|
},
|
||||||
|
@ -126,7 +127,7 @@ func Test_eventStream_RegisterSink(t *testing.T) {
|
||||||
tt.setup(e)
|
tt.setup(e)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := e.RegisterSink(tt.args.s); (err != nil) != tt.wantErr {
|
if err := e.RegisterSink(context.Background(), tt.args.s); (err != nil) != tt.wantErr {
|
||||||
t.Errorf("RegisterSink() error = %v, wantErr %v", err, tt.wantErr)
|
t.Errorf("RegisterSink() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -203,7 +204,7 @@ func Test_eventStream_Emit(t *testing.T) {
|
||||||
|
|
||||||
if tt.subscribe {
|
if tt.subscribe {
|
||||||
receivedWaitGroup.Add(len(tt.args.evs))
|
receivedWaitGroup.Add(len(tt.args.evs))
|
||||||
if err := e.RegisterSink(wgMockSink(t, receivedWaitGroup)); err != nil {
|
if err := e.RegisterSink(context.Background(), wgMockSink(t, receivedWaitGroup)); err != nil {
|
||||||
t.Errorf("RegisterSink() error = %v", err)
|
t.Errorf("RegisterSink() error = %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -6,10 +6,9 @@ import (
|
||||||
"gitlab.com/inetmock/inetmock/pkg/audit"
|
"gitlab.com/inetmock/inetmock/pkg/audit"
|
||||||
)
|
)
|
||||||
|
|
||||||
func NewGRPCSink(ctx context.Context, name string, consumer func(ev audit.Event)) audit.Sink {
|
func NewGRPCSink(name string, consumer func(ev audit.Event)) audit.Sink {
|
||||||
return &grpcSink{
|
return &grpcSink{
|
||||||
name: name,
|
name: name,
|
||||||
ctx: ctx,
|
|
||||||
consumer: consumer,
|
consumer: consumer,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -24,16 +23,10 @@ func (g grpcSink) Name() string {
|
||||||
return g.name
|
return g.name
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g grpcSink) OnSubscribe(evs <-chan audit.Event, handle audit.CloseHandle) {
|
func (g grpcSink) OnSubscribe(evs <-chan audit.Event) {
|
||||||
go func(ctx context.Context, consumer func(ev audit.Event), evs <-chan audit.Event, handle audit.CloseHandle) {
|
go func(consumer func(ev audit.Event), evs <-chan audit.Event) {
|
||||||
for {
|
for ev := range evs {
|
||||||
select {
|
consumer(ev)
|
||||||
case ev := <-evs:
|
|
||||||
consumer(ev)
|
|
||||||
case <-ctx.Done():
|
|
||||||
handle()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}(g.ctx, g.consumer, evs, handle)
|
}(g.consumer, evs)
|
||||||
}
|
}
|
||||||
|
|
|
@ -24,7 +24,7 @@ func (logSink) Name() string {
|
||||||
return logSinkName
|
return logSinkName
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l logSink) OnSubscribe(evs <-chan audit.Event, _ audit.CloseHandle) {
|
func (l logSink) OnSubscribe(evs <-chan audit.Event) {
|
||||||
go func(logger logging.Logger, evs <-chan audit.Event) {
|
go func(logger logging.Logger, evs <-chan audit.Event) {
|
||||||
for ev := range evs {
|
for ev := range evs {
|
||||||
eventLogger := logger
|
eventLogger := logger
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package sink_test
|
package sink_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"sync"
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
@ -90,7 +91,7 @@ func Test_logSink_OnSubscribe(t *testing.T) {
|
||||||
t.Errorf("NewEventStream() error = %v", err)
|
t.Errorf("NewEventStream() error = %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = evs.RegisterSink(logSink); err != nil {
|
if err = evs.RegisterSink(context.Background(), logSink); err != nil {
|
||||||
t.Errorf("RegisterSink() error = %v", err)
|
t.Errorf("RegisterSink() error = %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -33,7 +33,7 @@ func (f writerCloserSink) Name() string {
|
||||||
return f.name
|
return f.name
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f writerCloserSink) OnSubscribe(evs <-chan audit.Event, _ audit.CloseHandle) {
|
func (f writerCloserSink) OnSubscribe(evs <-chan audit.Event) {
|
||||||
go func(target audit.Writer, closeOnExit bool, evs <-chan audit.Event) {
|
go func(target audit.Writer, closeOnExit bool, evs <-chan audit.Event) {
|
||||||
for ev := range evs {
|
for ev := range evs {
|
||||||
_ = target.Write(&ev)
|
_ = target.Write(&ev)
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package sink_test
|
package sink_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
@ -91,7 +92,7 @@ func Test_writerCloserSink_OnSubscribe(t *testing.T) {
|
||||||
t.Errorf("NewEventStream() error = %v", err)
|
t.Errorf("NewEventStream() error = %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = evs.RegisterSink(writerCloserSink); err != nil {
|
if err = evs.RegisterSink(context.Background(), writerCloserSink); err != nil {
|
||||||
t.Errorf("RegisterSink() error = %v", err)
|
t.Errorf("RegisterSink() error = %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue