diff --git a/api/proto/internal/rpc/audit.proto b/api/proto/internal/rpc/audit.proto index aa07869..2eb4a15 100644 --- a/api/proto/internal/rpc/audit.proto +++ b/api/proto/internal/rpc/audit.proto @@ -26,7 +26,7 @@ message RemoveFileSinkRequest { } message RemoveFileSinkResponse { - + bool SinkGotRemoved = 1; } service Audit { diff --git a/cmd/imctl/endpoints.go b/cmd/imctl/endpoints.go index d8c7d3e..997164d 100644 --- a/cmd/imctl/endpoints.go +++ b/cmd/imctl/endpoints.go @@ -51,7 +51,7 @@ func fromEndpoints(eps []*rpc.Endpoint) (out []*printableEndpoint) { func runGetEndpoints(_ *cobra.Command, _ []string) (err error) { endpointsClient := rpc.NewEndpointsClient(conn) - ctx, cancel := context.WithTimeout(context.Background(), grpcTimeout) + ctx, cancel := context.WithTimeout(appCtx, grpcTimeout) defer cancel() var endpointsResp *rpc.GetEndpointsResponse if endpointsResp, err = endpointsClient.GetEndpoints(ctx, &rpc.GetEndpointsRequest{}); err != nil { diff --git a/internal/rpc/audit_server.go b/internal/rpc/audit_server.go index db8d1fa..7cb59d2 100644 --- a/internal/rpc/audit_server.go +++ b/internal/rpc/audit_server.go @@ -5,20 +5,21 @@ import ( "io" "os" - "gitlab.com/inetmock/inetmock/internal/app" "gitlab.com/inetmock/inetmock/pkg/audit" "gitlab.com/inetmock/inetmock/pkg/audit/sink" + "gitlab.com/inetmock/inetmock/pkg/logging" "go.uber.org/zap" ) type auditServer struct { UnimplementedAuditServer - app app.App + logger logging.Logger + eventStream audit.EventStream } func (a *auditServer) WatchEvents(req *WatchEventsRequest, srv Audit_WatchEventsServer) (err error) { - a.app.Logger().Info("watcher attached", zap.String("name", req.WatcherName)) - err = a.app.EventStream().RegisterSink(sink.NewGRPCSink(srv.Context(), req.WatcherName, func(ev audit.Event) { + a.logger.Info("watcher attached", zap.String("name", req.WatcherName)) + err = a.eventStream.RegisterSink(sink.NewGRPCSink(srv.Context(), req.WatcherName, func(ev audit.Event) { if err = srv.Send(ev.ProtoMessage()); err != nil { return } @@ -29,7 +30,7 @@ func (a *auditServer) WatchEvents(req *WatchEventsRequest, srv Audit_WatchEvents } <-srv.Context().Done() - a.app.Logger().Info("Watcher detached", zap.String("name", req.WatcherName)) + a.logger.Info("Watcher detached", zap.String("name", req.WatcherName)) return } @@ -38,7 +39,7 @@ func (a *auditServer) RegisterFileSink(_ context.Context, req *RegisterFileSinkR if writer, err = os.OpenFile(req.TargetPath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644); err != nil { return } - if err = a.app.EventStream().RegisterSink(sink.NewWriterSink(req.TargetPath, audit.NewEventWriter(writer))); err != nil { + if err = a.eventStream.RegisterSink(sink.NewWriterSink(req.TargetPath, audit.NewEventWriter(writer))); err != nil { return } resp = &RegisterFileSinkResponse{} @@ -46,6 +47,8 @@ func (a *auditServer) RegisterFileSink(_ context.Context, req *RegisterFileSinkR } func (a *auditServer) RemoveFileSink(_ context.Context, req *RemoveFileSinkRequest) (*RemoveFileSinkResponse, error) { - a.app.EventStream().RemoveSink(req.TargetPath) - return &RemoveFileSinkResponse{}, nil + gotRemoved := a.eventStream.RemoveSink(req.TargetPath) + return &RemoveFileSinkResponse{ + SinkGotRemoved: gotRemoved, + }, nil } diff --git a/internal/rpc/grpc_api.go b/internal/rpc/grpc_api.go index 6c71d8c..b5f0f65 100644 --- a/internal/rpc/grpc_api.go +++ b/internal/rpc/grpc_api.go @@ -53,7 +53,8 @@ func (i *inetmockAPI) StartServer() (err error) { }) RegisterAuditServer(i.server, &auditServer{ - app: i.app, + logger: i.app.Logger(), + eventStream: i.app.EventStream(), }) go i.startServerAsync(lis) diff --git a/pkg/audit/api.go b/pkg/audit/api.go index d269121..4129a2d 100644 --- a/pkg/audit/api.go +++ b/pkg/audit/api.go @@ -27,5 +27,5 @@ type EventStream interface { Emitter RegisterSink(s Sink) error Sinks() []string - RemoveSink(name string) + RemoveSink(name string) (exists bool) } diff --git a/pkg/audit/event_stream.go b/pkg/audit/event_stream.go index d3986b0..6dbd522 100644 --- a/pkg/audit/event_stream.go +++ b/pkg/audit/event_stream.go @@ -72,11 +72,12 @@ func (e *eventStream) Emit(ev Event) { } } -func (e *eventStream) RemoveSink(name string) { +func (e *eventStream) RemoveSink(name string) (exists bool) { e.lock.Lock() defer e.lock.Unlock() - sink, exists := e.sinks[name] + var sink *registeredSink + sink, exists = e.sinks[name] if !exists { return } @@ -84,6 +85,8 @@ func (e *eventStream) RemoveSink(name string) { defer sink.lock.Unlock() delete(e.sinks, name) close(sink.downstream) + + return } func (e *eventStream) RegisterSink(s Sink) error {