Peter 2021-04-22 20:19:07 +00:00
parent 29662726f6
commit 6442f9f915
63 changed files with 1889 additions and 1334 deletions


@@ -18,6 +18,9 @@ linters-settings:
       - ifElseChain
       - octalLiteral
       - wrapperFunc
+    settings:
+      hugeParam:
+        sizeThreshold: 200
   gocyclo:
     min-complexity: 15
   goimports:
@@ -57,6 +60,7 @@ linters:
     - dupl
     - errcheck
     - exhaustive
+    - exportloopref
     - funlen
     - goconst
     - gocritic
@@ -71,12 +75,15 @@ linters:
     - lll
     - misspell
     - nakedret
+    - nestif
     - noctx
     - nolintlint
-    - scopelint
+    - paralleltest
+    - prealloc
     - staticcheck
     - structcheck
     - stylecheck
+    - thelper
     - typecheck
     - unconvert
     - unparam
@@ -92,11 +99,6 @@ issues:
       linters:
         - gomnd
-  # https://github.com/go-critic/go-critic/issues/926
-  - linters:
-      - gocritic
-    text: "unnecessaryDefer:"
 run:
   skip-dirs:
     - internal/mock
@@ -104,6 +106,4 @@ run:
 # golangci.com configuration
 # https://github.com/golangci/golangci/wiki/Configuration
 service:
-  golangci-lint-version: 1.37.x # use the fixed version to not introduce new linters unexpectedly
+  golangci-lint-version: 1.39.x # use the fixed version to not introduce new linters unexpectedly
-  prepare:
-    - echo "here I can run custom commands, but no preparation needed for this repo"
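
The new gocritic hugeParam threshold and the exportloopref linter both target common Go pitfalls. As a rough, hypothetical illustration (not code from this repository) of what they flag:

package example

// Options is a hypothetical struct whose size (512 bytes) is well above the
// configured hugeParam sizeThreshold of 200 bytes.
type Options struct {
    Buf [512]byte
}

// byValue copies all 512 bytes on every call; gocritic's hugeParam check
// would suggest passing *Options instead.
func byValue(o Options) int { return len(o.Buf) }

// pointers stores the address of the range variable; before Go 1.22 every
// element ends up pointing at the same variable, which exportloopref reports.
func pointers(values []int) []*int {
    var out []*int
    for _, v := range values {
        out = append(out, &v)
    }
    return out
}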

.run/imctl.run.xml (new file, 13 lines)

@ -0,0 +1,13 @@
<component name="ProjectRunConfigurationManager">
<configuration default="false" name="imctl" type="GoApplicationRunConfiguration" factoryName="Go Application">
<module name="inetmock" />
<working_directory value="$PROJECT_DIR$" />
<parameters value="audit watch" />
<sudo value="true" />
<kind value="PACKAGE" />
<package value="gitlab.com/inetmock/inetmock/cmd/imctl" />
<directory value="$PROJECT_DIR$" />
<filePath value="$PROJECT_DIR$/cmd/imctl/main.go" />
<method v="2" />
</configuration>
</component>

.run/inetmock.run.xml (new file, 13 lines)

@ -0,0 +1,13 @@
<component name="ProjectRunConfigurationManager">
<configuration default="false" name="inetmock" type="GoApplicationRunConfiguration" factoryName="Go Application">
<module name="inetmock" />
<working_directory value="$PROJECT_DIR$" />
<parameters value="serve" />
<sudo value="true" />
<kind value="PACKAGE" />
<package value="gitlab.com/inetmock/inetmock/cmd/inetmock" />
<directory value="$PROJECT_DIR$" />
<filePath value="$PROJECT_DIR$/cmd/inetmock/main.go" />
<method v="2" />
</configuration>
</component>


@@ -16,15 +16,18 @@ env:
 tasks:
   clean:
+    desc: clean all generated files
     cmds:
       - find . -type f \( -name "*.pb.go" -or -name "*.mock.go" \) -exec rm -f {} \;
       - rm -rf ./main {{ .OUT_DIR }}
   format:
+    desc: format all source files
     cmds:
       - go fmt ./...
   deps:
+    desc: download dependencies
     sources:
       - go.mod
       - go.sum
@@ -32,6 +35,7 @@ tasks:
       - go mod download
   golangci-lint:
+    desc: run Go linter
     deps:
       - deps
       - generate
@@ -39,6 +43,7 @@ tasks:
       - golangci-lint run --timeout 5m
   protobuf-lint:
+    desc: run protobuf linter
     dir: api/
     sources:
       - "proto/**/*.proto"
@@ -46,11 +51,13 @@ tasks:
       - buf lint
   lint:
+    desc: run all linters
     deps:
       - golangci-lint
       - protobuf-lint
   protoc:
+    desc: generate Go files from protobuf files
     deps:
       - protobuf-lint
     sources:
@@ -61,6 +68,7 @@ tasks:
       - buf protoc --proto_path ./api/proto/ --go_out=./pkg/ --go_opt=paths=source_relative --go-grpc_out=./pkg/ --go-grpc_opt=paths=source_relative {{ .PROTO_FILES }}
   go-generate:
+    desc: run all go:generate directives
     deps:
       - deps
     sources:
@@ -71,11 +79,13 @@ tasks:
       - go generate -x ./...
   generate:
+    desc: run all code generation steps
     deps:
       - go-generate
       - protoc
   test:
+    desc: run all unit tests
     sources:
       - "**/*.go"
     deps:
@@ -92,6 +101,7 @@ tasks:
       - rm -f {{ .OUT_DIR }}/cov-raw.out
   integration-test:
+    desc: run all benchmarks/integration tests
     deps:
       - deps
       - generate
@@ -101,18 +112,21 @@ tasks:
       {{ end }}
   cli-cover-report:
+    desc: generate a coverage report on the CLI
     deps:
       - test
     cmds:
       - go tool cover -func={{ .OUT_DIR }}/cov.out
   html-cover-report:
+    desc: generate a coverage report as HTML page
     deps:
       - test
     cmds:
       - go tool cover -html={{ .OUT_DIR }}/cov.out -o {{ .OUT_DIR }}/coverage.html
   build-inetmock:
+    desc: build the INetMock server part
     deps:
       - test
     cmds:
@@ -120,10 +134,12 @@ tasks:
       - go build -ldflags='-w -s' -o {{ .OUT_DIR }}/inetmock {{ .INETMOCK_PKG }}
   debug-inetmock:
+    desc: run INetMock server with delve for remote debugging
     cmds:
       - dlv --listen=:2345 --headless=true --api-version=2 --accept-multiclient --output {{ .OUT_DIR }}/__debug_bin debug {{ .INETMOCK_PKG }} -- serve
   build-imctl:
+    desc: build the imctl INetMock client CLI
     deps:
       - test
     cmds:
@@ -131,11 +147,13 @@ tasks:
       - go build -ldflags='-w -s' -o {{ .OUT_DIR }}/imctl {{ .IMCTL_PKG }}
   build-all:
+    desc: build all binaries
     deps:
       - build-inetmock
       - build-imctl
   snapshot-release:
+    desc: create a snapshot/test release without publishing any artifacts
     deps:
       - deps
       - generate
@@ -143,6 +161,7 @@ tasks:
       - goreleaser release --snapshot --skip-publish --rm-dist
   release:
+    desc: create a release - includes artifact publishing
     deps:
       - deps
       - generate
@@ -150,5 +169,6 @@ tasks:
       - goreleaser release
   docs:
+    desc: generate docs
     cmds:
       - mdbook build -d ./../public ./docs


@ -113,7 +113,7 @@ func watchAuditEvents(_ *cobra.Command, _ []string) (err error) {
func runListSinks(*cobra.Command, []string) (err error) { func runListSinks(*cobra.Command, []string) (err error) {
auditClient := rpcV1.NewAuditServiceClient(conn) auditClient := rpcV1.NewAuditServiceClient(conn)
ctx, cancel := context.WithTimeout(cliApp.Context(), grpcTimeout) ctx, cancel := context.WithTimeout(cliApp.Context(), cfg.GRPCTimeout)
defer cancel() defer cancel()
var resp *rpcV1.ListSinksResponse var resp *rpcV1.ListSinksResponse
@ -125,19 +125,19 @@ func runListSinks(*cobra.Command, []string) (err error) {
Name string Name string
} }
var sinks []printableSink var sinks = make([]printableSink, len(resp.Sinks))
for _, s := range resp.Sinks { for i, s := range resp.Sinks {
sinks = append(sinks, printableSink{Name: s}) sinks[i] = printableSink{Name: s}
} }
writer := format.Writer(outputFormat, os.Stdout) writer := format.Writer(cfg.Format, os.Stdout)
err = writer.Write(sinks) err = writer.Write(sinks)
return return
} }
func runAddFile(_ *cobra.Command, args []string) (err error) { func runAddFile(_ *cobra.Command, args []string) (err error) {
auditClient := rpcV1.NewAuditServiceClient(conn) auditClient := rpcV1.NewAuditServiceClient(conn)
ctx, cancel := context.WithTimeout(cliApp.Context(), grpcTimeout) ctx, cancel := context.WithTimeout(cliApp.Context(), cfg.GRPCTimeout)
defer cancel() defer cancel()
var resp *rpcV1.RegisterFileSinkResponse var resp *rpcV1.RegisterFileSinkResponse
@ -154,7 +154,7 @@ func runAddFile(_ *cobra.Command, args []string) (err error) {
func runRemoveFile(_ *cobra.Command, args []string) (err error) { func runRemoveFile(_ *cobra.Command, args []string) (err error) {
auditClient := rpcV1.NewAuditServiceClient(conn) auditClient := rpcV1.NewAuditServiceClient(conn)
ctx, cancel := context.WithTimeout(cliApp.Context(), grpcTimeout) ctx, cancel := context.WithTimeout(cliApp.Context(), cfg.GRPCTimeout)
defer cancel() defer cancel()
_, err = auditClient.RemoveFileSink(ctx, &rpcV1.RemoveFileSinkRequest{TargetPath: args[0]}) _, err = auditClient.RemoveFileSink(ctx, &rpcV1.RemoveFileSinkRequest{TargetPath: args[0]})


@ -44,21 +44,22 @@ func fromComponentsHealth(componentsHealth map[string]*rpcV1.ComponentHealth) in
Message string Message string
} }
var componentsInfo []printableHealthInfo var componentsInfo = make([]printableHealthInfo, len(componentsHealth))
var idx int
for componentName, component := range componentsHealth { for componentName, component := range componentsHealth {
componentsInfo = append(componentsInfo, printableHealthInfo{ componentsInfo[idx] = printableHealthInfo{
Component: componentName, Component: componentName,
State: component.State.String(), State: component.State.String(),
Message: component.Message, Message: component.Message,
}) }
idx++
} }
return componentsInfo return componentsInfo
} }
func getHealthResult() (healthResp *rpcV1.GetHealthResponse, err error) { func getHealthResult() (healthResp *rpcV1.GetHealthResponse, err error) {
var healthClient = rpcV1.NewHealthServiceClient(conn) var healthClient = rpcV1.NewHealthServiceClient(conn)
ctx, cancel := context.WithTimeout(context.Background(), grpcTimeout) ctx, cancel := context.WithTimeout(context.Background(), cfg.GRPCTimeout)
healthResp, err = healthClient.GetHealth(ctx, &rpcV1.GetHealthRequest{}) healthResp, err = healthClient.GetHealth(ctx, &rpcV1.GetHealthRequest{})
cancel() cancel()
return return
@ -75,7 +76,7 @@ func runGeneralHealth(_ *cobra.Command, _ []string) {
printable := fromComponentsHealth(healthResp.ComponentsHealth) printable := fromComponentsHealth(healthResp.ComponentsHealth)
writer := format.Writer(outputFormat, os.Stdout) writer := format.Writer(cfg.Format, os.Stdout)
if err = writer.Write(printable); err != nil { if err = writer.Write(printable); err != nil {
fmt.Printf("Error occurred during writing response values: %v\n", err) fmt.Printf("Error occurred during writing response values: %v\n", err)
} }


@ -5,6 +5,7 @@ import (
"time" "time"
"github.com/spf13/cobra" "github.com/spf13/cobra"
"github.com/spf13/pflag"
"google.golang.org/grpc" "google.golang.org/grpc"
"gitlab.com/inetmock/inetmock/internal/app" "gitlab.com/inetmock/inetmock/internal/app"
@ -15,34 +16,53 @@ const (
) )
var ( var (
inetMockSocketPath string cliApp app.App
outputFormat string conn *grpc.ClientConn
grpcTimeout time.Duration cfg config
cliApp app.App
conn *grpc.ClientConn
) )
//nolint:lll type config struct {
SocketPath string
Format string
GRPCTimeout time.Duration
}
func main() { func main() {
healthCmd.AddCommand(generalHealthCmd, containerHealthCmd) healthCmd.AddCommand(generalHealthCmd, containerHealthCmd)
cliApp = app.NewApp(
app.Spec{
Name: "imctl",
Short: "IMCTL is the CLI app to interact with an INetMock server",
Config: &cfg,
SubCommands: []*cobra.Command{healthCmd, auditCmd, pcapCmd},
LateInitTasks: []func(cmd *cobra.Command, args []string) (err error){
initGRPCConnection,
},
IgnoreMissingConfigFile: true,
FlagBindings: map[string]func(flagSet *pflag.FlagSet) *pflag.Flag{
"grpctimeout": func(flagSet *pflag.FlagSet) *pflag.Flag {
return flagSet.Lookup("grpc-timeout")
},
"format": func(flagSet *pflag.FlagSet) *pflag.Flag {
return flagSet.Lookup("format")
},
"socketpath": func(flagSet *pflag.FlagSet) *pflag.Flag {
return flagSet.Lookup("socket-path")
},
},
},
)
cliApp = app.NewApp("imctl", "IMCTL is the CLI app to interact with an INetMock server"). cliApp.RootCommand().PersistentFlags().String("socket-path", "unix:///var/run/inetmock.sock", "Path to the INetMock socket file")
WithCommands(healthCmd, auditCmd, pcapCmd). cliApp.RootCommand().PersistentFlags().StringP("format", "f", "table", "Output format to use. Possible values: table, json, yaml")
WithInitTasks(func(_ *cobra.Command, _ []string) (err error) { cliApp.RootCommand().PersistentFlags().Duration("grpc-timeout", defaultGRPCTimeout, "Timeout to connect to the gRPC API")
return initGRPCConnection()
}).
WithLogger()
cliApp.RootCommand().PersistentFlags().StringVar(&inetMockSocketPath, "socket-path", "unix:///var/run/inetmock.sock", "Path to the INetMock socket file")
cliApp.RootCommand().PersistentFlags().StringVarP(&outputFormat, "format", "f", "table", "Output format to use. Possible values: table, json, yaml")
cliApp.RootCommand().PersistentFlags().DurationVar(&grpcTimeout, "grpc-timeout", defaultGRPCTimeout, "Timeout to connect to the gRPC API")
cliApp.MustRun() cliApp.MustRun()
} }
func initGRPCConnection() (err error) { func initGRPCConnection(*cobra.Command, []string) (err error) {
dialCtx, cancel := context.WithTimeout(cliApp.Context(), grpcTimeout) dialCtx, cancel := context.WithTimeout(cliApp.Context(), cfg.GRPCTimeout)
conn, err = grpc.DialContext(dialCtx, inetMockSocketPath, grpc.WithInsecure()) conn, err = grpc.DialContext(dialCtx, cfg.SocketPath, grpc.WithInsecure())
cancel() cancel()
return return
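
The Spec.FlagBindings map above ties each persistent flag to a viper config key so the parsed values land in the cfg struct instead of package-level variables. A minimal standalone sketch of that binding mechanism, using only stock pflag/viper APIs (the names below are illustrative, not this repository's app package):

package main

import (
    "fmt"
    "time"

    "github.com/spf13/pflag"
    "github.com/spf13/viper"
)

type config struct {
    SocketPath  string
    GRPCTimeout time.Duration
}

func main() {
    flags := pflag.NewFlagSet("imctl", pflag.ContinueOnError)
    flags.String("socket-path", "unix:///var/run/inetmock.sock", "Path to the INetMock socket file")
    flags.Duration("grpc-timeout", 5*time.Second, "Timeout to connect to the gRPC API")
    _ = flags.Parse([]string{"--grpc-timeout", "10s"})

    v := viper.New()
    // Same idea as Spec.FlagBindings: map a config key to a specific flag.
    _ = v.BindPFlag("socketpath", flags.Lookup("socket-path"))
    _ = v.BindPFlag("grpctimeout", flags.Lookup("grpc-timeout"))

    var cfg config
    _ = v.Unmarshal(&cfg)
    fmt.Println(cfg.SocketPath, cfg.GRPCTimeout) // flag default and the parsed 10s
}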


@ -56,7 +56,7 @@ var (
var err error var err error
pcapClient := rpcV1.NewPCAPServiceClient(conn) pcapClient := rpcV1.NewPCAPServiceClient(conn)
ctx, cancel := context.WithTimeout(context.Background(), grpcTimeout) ctx, cancel := context.WithTimeout(context.Background(), cfg.GRPCTimeout)
defer cancel() defer cancel()
var resp *rpcV1.ListAvailableDevicesResponse var resp *rpcV1.ListAvailableDevicesResponse
if resp, err = pcapClient.ListAvailableDevices(ctx, new(rpcV1.ListAvailableDevicesRequest)); err == nil { if resp, err = pcapClient.ListAvailableDevices(ctx, new(rpcV1.ListAvailableDevicesRequest)); err == nil {
@ -91,7 +91,7 @@ var (
ValidArgsFunction: func(*cobra.Command, []string, string) ([]string, cobra.ShellCompDirective) { ValidArgsFunction: func(*cobra.Command, []string, string) ([]string, cobra.ShellCompDirective) {
var err error var err error
pcapClient := rpcV1.NewPCAPServiceClient(conn) pcapClient := rpcV1.NewPCAPServiceClient(conn)
ctx, cancel := context.WithTimeout(context.Background(), grpcTimeout) ctx, cancel := context.WithTimeout(context.Background(), cfg.GRPCTimeout)
defer cancel() defer cancel()
var resp *rpcV1.ListActiveRecordingsResponse var resp *rpcV1.ListActiveRecordingsResponse
if resp, err = pcapClient.ListActiveRecordings(ctx, new(rpcV1.ListActiveRecordingsRequest)); err == nil { if resp, err = pcapClient.ListActiveRecordings(ctx, new(rpcV1.ListActiveRecordingsRequest)); err == nil {
@ -127,7 +127,7 @@ func init() {
func runListAvailableDevices(*cobra.Command, []string) (err error) { func runListAvailableDevices(*cobra.Command, []string) (err error) {
pcapClient := rpcV1.NewPCAPServiceClient(conn) pcapClient := rpcV1.NewPCAPServiceClient(conn)
ctx, cancel := context.WithTimeout(cliApp.Context(), grpcTimeout) ctx, cancel := context.WithTimeout(cliApp.Context(), cfg.GRPCTimeout)
defer cancel() defer cancel()
var resp *rpcV1.ListAvailableDevicesResponse var resp *rpcV1.ListAvailableDevicesResponse
@ -149,7 +149,7 @@ func runListAvailableDevices(*cobra.Command, []string) (err error) {
}) })
} }
writer := format.Writer(outputFormat, os.Stdout) writer := format.Writer(cfg.Format, os.Stdout)
err = writer.Write(availableDevs) err = writer.Write(availableDevs)
return return
} }
@ -163,7 +163,7 @@ func runListActiveRecordings(*cobra.Command, []string) error {
pcapClient := rpcV1.NewPCAPServiceClient(conn) pcapClient := rpcV1.NewPCAPServiceClient(conn)
ctx, cancel := context.WithTimeout(cliApp.Context(), grpcTimeout) ctx, cancel := context.WithTimeout(cliApp.Context(), cfg.GRPCTimeout)
defer cancel() defer cancel()
var err error var err error
@ -172,21 +172,21 @@ func runListActiveRecordings(*cobra.Command, []string) error {
return err return err
} }
var out []printableSubscription var out = make([]printableSubscription, len(resp.Subscriptions))
for _, subscription := range resp.Subscriptions { for idx, subscription := range resp.Subscriptions {
splitIdx := strings.Index(subscription, ":") splitIdx := strings.Index(subscription, ":")
if splitIdx < 0 { if splitIdx < 0 {
continue continue
} }
out = append(out, printableSubscription{ out[idx] = printableSubscription{
Name: subscription[splitIdx:], Name: subscription[splitIdx:],
Device: subscription[:splitIdx], Device: subscription[:splitIdx],
ConsumerKey: subscription, ConsumerKey: subscription,
}) }
} }
writer := format.Writer(outputFormat, os.Stdout) writer := format.Writer(cfg.Format, os.Stdout)
return writer.Write(out) return writer.Write(out)
} }
@ -198,7 +198,7 @@ func runAddRecording(_ *cobra.Command, args []string) (err error) {
return return
} }
ctx, cancel := context.WithTimeout(cliApp.Context(), grpcTimeout) ctx, cancel := context.WithTimeout(cliApp.Context(), cfg.GRPCTimeout)
defer cancel() defer cancel()
var resp *rpcV1.StartPCAPFileRecordingResponse var resp *rpcV1.StartPCAPFileRecordingResponse
@ -219,7 +219,7 @@ func runAddRecording(_ *cobra.Command, args []string) (err error) {
func runRemoveCurrentlyRunningRecording(_ *cobra.Command, args []string) error { func runRemoveCurrentlyRunningRecording(_ *cobra.Command, args []string) error {
pcapClient := rpcV1.NewPCAPServiceClient(conn) pcapClient := rpcV1.NewPCAPServiceClient(conn)
listRecsCtx, listRecsCancel := context.WithTimeout(cliApp.Context(), grpcTimeout) listRecsCtx, listRecsCancel := context.WithTimeout(cliApp.Context(), cfg.GRPCTimeout)
defer listRecsCancel() defer listRecsCancel()
var err error var err error
@ -240,7 +240,7 @@ func runRemoveCurrentlyRunningRecording(_ *cobra.Command, args []string) error {
return fmt.Errorf("the given subscription is not known: %s", args[0]) return fmt.Errorf("the given subscription is not known: %s", args[0])
} }
ctx, cancel := context.WithTimeout(cliApp.Context(), grpcTimeout) ctx, cancel := context.WithTimeout(cliApp.Context(), cfg.GRPCTimeout)
defer cancel() defer cancel()
var stopRecResp *rpcV1.StopPCAPFileRecordingResponse var stopRecResp *rpcV1.StopPCAPFileRecordingResponse
@ -259,7 +259,7 @@ func runRemoveCurrentlyRunningRecording(_ *cobra.Command, args []string) error {
} }
func isValidRecordDevice(device string, pcapClient rpcV1.PCAPServiceClient) (err error) { func isValidRecordDevice(device string, pcapClient rpcV1.PCAPServiceClient) (err error) {
ctx, cancel := context.WithTimeout(cliApp.Context(), grpcTimeout) ctx, cancel := context.WithTimeout(cliApp.Context(), cfg.GRPCTimeout)
defer cancel() defer cancel()
var resp *rpcV1.ListAvailableDevicesResponse var resp *rpcV1.ListAvailableDevicesResponse
if resp, err = pcapClient.ListAvailableDevices(ctx, new(rpcV1.ListAvailableDevicesRequest)); err != nil { if resp, err = pcapClient.ListAvailableDevices(ctx, new(rpcV1.ListAvailableDevicesRequest)); err != nil {


@ -1,33 +1,101 @@
package main package main
import ( import (
"net/url"
"os"
"path/filepath"
"time"
"github.com/spf13/cobra"
"gitlab.com/inetmock/inetmock/internal/app" "gitlab.com/inetmock/inetmock/internal/app"
"gitlab.com/inetmock/inetmock/internal/endpoint"
dns "gitlab.com/inetmock/inetmock/internal/endpoint/handler/dns/mock" dns "gitlab.com/inetmock/inetmock/internal/endpoint/handler/dns/mock"
http "gitlab.com/inetmock/inetmock/internal/endpoint/handler/http/mock" http "gitlab.com/inetmock/inetmock/internal/endpoint/handler/http/mock"
"gitlab.com/inetmock/inetmock/internal/endpoint/handler/http/proxy" "gitlab.com/inetmock/inetmock/internal/endpoint/handler/http/proxy"
"gitlab.com/inetmock/inetmock/internal/endpoint/handler/metrics" "gitlab.com/inetmock/inetmock/internal/endpoint/handler/metrics"
"gitlab.com/inetmock/inetmock/internal/endpoint/handler/tls/interceptor" "gitlab.com/inetmock/inetmock/internal/endpoint/handler/tls/interceptor"
"gitlab.com/inetmock/inetmock/pkg/cert"
) )
var ( var (
serverApp app.App serverApp app.App
cfg appConfig
registrations = []endpoint.Registration{
http.AddHTTPMock,
dns.AddDNSMock,
interceptor.AddTLSInterceptor,
proxy.AddHTTPProxy,
metrics.AddMetricsExporter,
}
) )
type Data struct {
PCAP string
Audit string
}
func (d *Data) setup() (err error) {
if d.PCAP, err = ensureDataDir(d.PCAP); err != nil {
return
}
if d.Audit, err = ensureDataDir(d.Audit); err != nil {
return
}
return
}
func ensureDataDir(dataDirPath string) (cleanedPath string, err error) {
cleanedPath = dataDirPath
if !filepath.IsAbs(cleanedPath) {
if cleanedPath, err = filepath.Abs(cleanedPath); err != nil {
return
}
}
err = os.MkdirAll(cleanedPath, 0640)
return
}
type appConfig struct {
TLS cert.Options
Listeners map[string]endpoint.ListenerSpec
API struct {
Listen string
}
Data Data
}
func (c *appConfig) APIURL() *url.URL {
if u, err := url.Parse(c.API.Listen); err != nil {
u, _ = url.Parse("tcp://:0")
return u
} else {
return u
}
}
func main() { func main() {
serverApp = app.NewApp("inetmock", "INetMock is lightweight internet mock"). serverApp = app.NewApp(
WithHandlerRegistry( app.Spec{
http.AddHTTPMock, Name: "inetmock",
dns.AddDNSMock, Short: "INetMock is lightweight internet mock",
interceptor.AddTLSInterceptor, Config: &cfg,
proxy.AddHTTPProxy, SubCommands: []*cobra.Command{serveCmd, generateCaCmd},
metrics.AddMetricsExporter). Defaults: map[string]interface{}{
WithCommands(serveCmd, generateCaCmd). "api.listen": "tcp://:0",
WithConfig(). "data.pcap": "/var/lib/inetmock/data/pcap",
WithLogger(). "data.audit": "/var/lib/inetmock/data/audit",
WithHealthChecker(). "tls.curve": cert.CurveTypeP256,
WithCertStore(). "tls.minTLSVersion": cert.TLSVersionTLS10,
WithEventStream(). "tls.includeInsecureCipherSuites": false,
WithEndpointManager() "tls.validity.server.notBeforeRelative": 168 * time.Hour,
"tls.validity.server.notAfterRelative": 168 * time.Hour,
"tls.certCachePath": "/tmp",
},
},
)
serverApp.MustRun() serverApp.MustRun()
} }
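
The Defaults map uses dotted keys; viper expands them into the nested config struct when unmarshalling. A reduced sketch of that behavior with stock viper (the struct below is illustrative and much smaller than the real appConfig/cert.Options):

package main

import (
    "fmt"
    "time"

    "github.com/spf13/viper"
)

type serverConfig struct {
    API  struct{ Listen string }
    Data struct{ PCAP, Audit string }
    TLS  struct {
        Validity struct {
            Server struct{ NotBeforeRelative time.Duration }
        }
    }
}

func main() {
    v := viper.New()
    // Dotted default keys become nested fields on Unmarshal.
    v.SetDefault("api.listen", "tcp://:0")
    v.SetDefault("data.pcap", "/var/lib/inetmock/data/pcap")
    v.SetDefault("data.audit", "/var/lib/inetmock/data/audit")
    v.SetDefault("tls.validity.server.notBeforeRelative", 168*time.Hour)

    var cfg serverConfig
    if err := v.Unmarshal(&cfg); err != nil {
        panic(err)
    }
    fmt.Println(cfg.API.Listen, cfg.TLS.Validity.Server.NotBeforeRelative)
}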


@ -4,9 +4,19 @@ import (
"github.com/spf13/cobra" "github.com/spf13/cobra"
"go.uber.org/zap" "go.uber.org/zap"
"gitlab.com/inetmock/inetmock/internal/endpoint"
"gitlab.com/inetmock/inetmock/internal/pcap" "gitlab.com/inetmock/inetmock/internal/pcap"
"gitlab.com/inetmock/inetmock/internal/pcap/consumers/audit" audit2 "gitlab.com/inetmock/inetmock/internal/pcap/consumers/audit"
"gitlab.com/inetmock/inetmock/internal/rpc" "gitlab.com/inetmock/inetmock/internal/rpc"
"gitlab.com/inetmock/inetmock/pkg/audit"
"gitlab.com/inetmock/inetmock/pkg/audit/sink"
"gitlab.com/inetmock/inetmock/pkg/cert"
"gitlab.com/inetmock/inetmock/pkg/health"
"gitlab.com/inetmock/inetmock/pkg/logging"
)
const (
defaultEventBufferSize = 10
) )
var ( var (
@ -19,33 +29,67 @@ var (
) )
func startINetMock(_ *cobra.Command, _ []string) error { func startINetMock(_ *cobra.Command, _ []string) error {
rpcAPI := rpc.NewINetMockAPI( registry := endpoint.NewHandlerRegistry()
serverApp.Config().APIURL(),
serverApp.Logger(),
serverApp.Checker(),
serverApp.EventStream(),
serverApp.Config().AuditDataDir(),
serverApp.Config().PCAPDataDir(),
)
logger := serverApp.Logger()
cfg := serverApp.Config() var err error
endpointOrchestrator := serverApp.EndpointManager() appLogger := serverApp.Logger()
for name, spec := range cfg.ListenerSpecs() { if err = cfg.Data.setup(); err != nil {
if spec.Name == "" { appLogger.Error("Failed to setup data directories", zap.Error(err))
spec.Name = name return err
} }
if err := endpointOrchestrator.RegisterListener(spec); err != nil {
logger.Error("Failed to register listener", zap.Error(err)) for _, registration := range registrations {
if err = registration(registry); err != nil {
appLogger.Error("Failed to run registration", zap.Error(err))
return err return err
} }
} }
errChan := serverApp.EndpointManager().StartEndpoints() var certStore cert.Store
if certStore, err = cert.NewDefaultStore(cfg.TLS, appLogger.Named("CertStore")); err != nil {
appLogger.Error("Failed to initialize cert store", zap.Error(err))
return err
}
var eventStream audit.EventStream
if eventStream, err = setupEventStream(appLogger); err != nil {
return err
}
var endpointOrchestrator = endpoint.NewOrchestrator(
serverApp.Context(),
certStore,
registry,
eventStream,
appLogger,
)
checker := health.New()
rpcAPI := rpc.NewINetMockAPI(
cfg.APIURL(),
appLogger,
checker,
eventStream,
cfg.Data.Audit,
cfg.Data.PCAP,
)
for name, spec := range cfg.Listeners {
if spec.Name == "" {
spec.Name = name
}
if err := endpointOrchestrator.RegisterListener(spec); err != nil {
appLogger.Error("Failed to register listener", zap.Error(err))
return err
}
}
errChan := endpointOrchestrator.StartEndpoints()
if err := rpcAPI.StartServer(); err != nil { if err := rpcAPI.StartServer(); err != nil {
serverApp.Shutdown() serverApp.Shutdown()
logger.Error( appLogger.Error(
"failed to start gRPC API", "failed to start gRPC API",
zap.Error(err), zap.Error(err),
) )
@ -60,22 +104,55 @@ loop:
for { for {
select { select {
case err := <-errChan: case err := <-errChan:
logger.Error("got error from endpoint", zap.Error(err)) appLogger.Error("got error from endpoint", zap.Error(err))
case <-serverApp.Context().Done(): case <-serverApp.Context().Done():
break loop break loop
} }
} }
logger.Info("App context canceled - shutting down") appLogger.Info("App context canceled - shutting down")
rpcAPI.StopServer() rpcAPI.StopServer()
return nil return nil
} }
func setupEventStream(appLogger logging.Logger) (audit.EventStream, error) {
var evenStream audit.EventStream
var err error
evenStream, err = audit.NewEventStream(
appLogger.Named("EventStream"),
audit.WithSinkBufferSize(defaultEventBufferSize),
)
if err != nil {
appLogger.Error("Failed to initialize event stream", zap.Error(err))
return nil, err
}
if err = evenStream.RegisterSink(serverApp.Context(), sink.NewLogSink(appLogger.Named("LogSink"))); err != nil {
appLogger.Error("Failed to register log sink to event stream", zap.Error(err))
return nil, err
}
var metricSink audit.Sink
if metricSink, err = sink.NewMetricSink(); err != nil {
appLogger.Error("Failed to setup metrics sink", zap.Error(err))
return nil, err
}
if err = evenStream.RegisterSink(serverApp.Context(), metricSink); err != nil {
appLogger.Error("Failed to register metric sink", zap.Error(err))
return nil, err
}
return evenStream, nil
}
//nolint:deadcode //nolint:deadcode
func startAuditConsumer() error { func startAuditConsumer(eventStream audit.EventStream) error {
recorder := pcap.NewRecorder() recorder := pcap.NewRecorder()
auditConsumer := audit.NewAuditConsumer("audit", serverApp.EventStream())
auditConsumer := audit2.NewAuditConsumer("audit", eventStream)
return recorder.StartRecording(serverApp.Context(), "lo", auditConsumer) return recorder.StartRecording(serverApp.Context(), "lo", auditConsumer)
} }
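
With the builder methods gone, startINetMock wires registry, cert store, event stream and orchestrator itself; the run loop keeps the same shape: drain endpoint errors until the app context is canceled. Stripped of the INetMock types, a sketch of that loop (not project code):

package main

import (
    "context"
    "errors"
    "fmt"
    "time"
)

// drain logs endpoint errors until ctx is canceled, the same loop shape
// startINetMock uses with its errChan and serverApp.Context().
func drain(ctx context.Context, errChan <-chan error) {
    for {
        select {
        case err := <-errChan:
            fmt.Println("got error from endpoint:", err)
        case <-ctx.Done():
            fmt.Println("context canceled - shutting down")
            return
        }
    }
}

func main() {
    ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
    defer cancel()

    errChan := make(chan error, 1)
    errChan <- errors.New("listener failed")

    drain(ctx, errChan)
}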


@@ -59,7 +59,7 @@ api:
 tls:
   curve: P256
-  minTLSVersion: SSL3
+  minTLSVersion: TLS10
   includeInsecureCipherSuites: false
   validity:
     ca:


@@ -63,7 +63,7 @@ api:
 tls:
   curve: P256
-  minTLSVersion: SSL3
+  minTLSVersion: TLS10
   includeInsecureCipherSuites: false
   validity:
     ca:

go.mod (21 changed lines)

@@ -6,25 +6,24 @@ require (
    github.com/bwmarrin/snowflake v0.3.0
    github.com/docker/go-connections v0.4.0
    github.com/golang/mock v1.5.0
-   github.com/golang/protobuf v1.4.3
    github.com/google/gopacket v1.1.19
    github.com/google/uuid v1.2.0
-   github.com/imdario/mergo v0.3.11
+   github.com/imdario/mergo v0.3.12
-   github.com/jinzhu/copier v0.2.5
+   github.com/jinzhu/copier v0.3.0
-   github.com/miekg/dns v1.1.40
+   github.com/maxatome/go-testdeep v1.9.2
+   github.com/miekg/dns v1.1.41
    github.com/mitchellh/mapstructure v1.4.1
    github.com/olekukonko/tablewriter v0.0.5
-   github.com/prometheus/client_golang v1.9.0
+   github.com/prometheus/client_golang v1.10.0
-   github.com/soheilhy/cmux v0.1.4
+   github.com/soheilhy/cmux v0.1.5
    github.com/spf13/cobra v1.1.3
+   github.com/spf13/pflag v1.0.5
    github.com/spf13/viper v1.7.1
-   github.com/testcontainers/testcontainers-go v0.9.0
+   github.com/testcontainers/testcontainers-go v0.10.0
    go.uber.org/multierr v1.6.0
    go.uber.org/zap v1.16.0
-   google.golang.org/grpc v1.36.0
+   google.golang.org/grpc v1.37.0
-   google.golang.org/protobuf v1.25.0
+   google.golang.org/protobuf v1.26.0
    gopkg.in/elazarl/goproxy.v1 v1.0.0-20180725130230-947c36da3153
    gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b
 )
replace github.com/docker/docker => github.com/docker/docker v17.12.0-ce-rc1.0.20200916142827-bd33bbf0497b+incompatible

go.sum (1031 changed lines)

File diff suppressed because it is too large


@ -4,18 +4,17 @@ package app
import ( import (
"context" "context"
"fmt"
"os" "os"
"os/signal" "os/signal"
"strings"
"syscall" "syscall"
"github.com/spf13/cobra" "github.com/spf13/cobra"
"github.com/spf13/pflag"
"github.com/spf13/viper"
"go.uber.org/zap" "go.uber.org/zap"
"gitlab.com/inetmock/inetmock/internal/endpoint"
"gitlab.com/inetmock/inetmock/pkg/audit"
"gitlab.com/inetmock/inetmock/pkg/audit/sink"
"gitlab.com/inetmock/inetmock/pkg/cert"
"gitlab.com/inetmock/inetmock/pkg/health"
"gitlab.com/inetmock/inetmock/pkg/logging" "gitlab.com/inetmock/inetmock/pkg/logging"
"gitlab.com/inetmock/inetmock/pkg/path" "gitlab.com/inetmock/inetmock/pkg/path"
) )
@ -26,71 +25,30 @@ var (
developmentLogs bool developmentLogs bool
) )
type contextKey string type Spec struct {
Name string
const ( Short string
loggerKey contextKey = "gitlab.com/inetmock/inetmock/app/context/logger" Config interface{}
configKey contextKey = "gitlab.com/inetmock/inetmock/app/context/config" IgnoreMissingConfigFile bool
handlerRegistryKey contextKey = "gitlab.com/inetmock/inetmock/app/context/handlerRegistry" Defaults map[string]interface{}
healthCheckerKey contextKey = "gitlab.com/inetmock/inetmock/app/context/healthChecker" FlagBindings map[string]func(flagSet *pflag.FlagSet) *pflag.Flag
endpointManagerKey contextKey = "gitlab.com/inetmock/inetmock/app/context/endpointManager" SubCommands []*cobra.Command
certStoreKey contextKey = "gitlab.com/inetmock/inetmock/app/context/certStore" LateInitTasks []func(cmd *cobra.Command, args []string) (err error)
eventStreamKey contextKey = "gitlab.com/inetmock/inetmock/app/context/eventStream" }
defaultEventBufferSize = 10
)
type App interface { type App interface {
EventStream() audit.EventStream
Config() Config
Checker() health.Checker
Logger() logging.Logger Logger() logging.Logger
EndpointManager() endpoint.Orchestrator
HandlerRegistry() endpoint.HandlerRegistry
Context() context.Context Context() context.Context
RootCommand() *cobra.Command RootCommand() *cobra.Command
MustRun() MustRun()
Shutdown() Shutdown()
// WithCommands adds subcommands to the root command
// requires nothing
WithCommands(cmds ...*cobra.Command) App
// WithHandlerRegistry builds up the handler registry
// requires nothing
WithHandlerRegistry(registrations ...endpoint.Registration) App
// WithHealthChecker adds the health checker mechanism
// requires nothing
WithHealthChecker() App
// WithLogger configures the logging system
// requires nothing
WithLogger() App
// WithEndpointManager creates an endpoint manager instance and adds it to the context
// requires WithHandlerRegistry, WithHealthChecker and WithLogger
WithEndpointManager() App
// WithCertStore initializes the cert store
// requires WithLogger and WithConfig
WithCertStore() App
// WithEventStream adds the audit event stream
// requires WithLogger
WithEventStream() App
// WithConfig loads the config
// requires nothing
WithConfig() App
WithInitTasks(task ...func(cmd *cobra.Command, args []string) (err error)) App
} }
type app struct { type app struct {
rootCmd *cobra.Command rootCmd *cobra.Command
ctx context.Context ctx context.Context
cancel context.CancelFunc cancel context.CancelFunc
lateInitTasks []func(cmd *cobra.Command, args []string) (err error) logger logging.Logger
} }
func (a *app) MustRun() { func (a *app) MustRun() {
@ -107,67 +65,7 @@ func (a *app) MustRun() {
} }
func (a *app) Logger() logging.Logger { func (a *app) Logger() logging.Logger {
val := a.ctx.Value(loggerKey) return a.logger
if val == nil {
return nil
}
return val.(logging.Logger)
}
func (a *app) Config() Config {
val := a.ctx.Value(configKey)
if val == nil {
return nil
}
return val.(Config)
}
func (a *app) CertStore() cert.Store {
val := a.ctx.Value(certStoreKey)
if val == nil {
return nil
}
return val.(cert.Store)
}
func (a *app) Checker() health.Checker {
val := a.ctx.Value(healthCheckerKey)
if val == nil {
return nil
}
return val.(health.Checker)
}
func (a *app) EndpointManager() endpoint.Orchestrator {
val := a.ctx.Value(endpointManagerKey)
if val == nil {
return nil
}
return val.(endpoint.Orchestrator)
}
func (a *app) Audit() audit.Emitter {
val := a.ctx.Value(eventStreamKey)
if val == nil {
return nil
}
return val.(audit.Emitter)
}
func (a *app) EventStream() audit.EventStream {
val := a.ctx.Value(eventStreamKey)
if val == nil {
return nil
}
return val.(audit.EventStream)
}
func (a *app) HandlerRegistry() endpoint.HandlerRegistry {
val := a.ctx.Value(handlerRegistryKey)
if val == nil {
return nil
}
return val.(endpoint.HandlerRegistry)
} }
func (a *app) Context() context.Context { func (a *app) Context() context.Context {
@ -182,167 +80,55 @@ func (a *app) Shutdown() {
a.cancel() a.cancel()
} }
// WithCommands adds subcommands to the root command func NewApp(spec Spec) App {
// requires nothing if spec.Defaults == nil {
func (a *app) WithCommands(cmds ...*cobra.Command) App { spec.Defaults = make(map[string]interface{})
a.rootCmd.AddCommand(cmds...)
return a
}
// WithHandlerRegistry builds up the handler registry
// requires nothing
func (a *app) WithHandlerRegistry(registrations ...endpoint.Registration) App {
registry := endpoint.NewHandlerRegistry()
for _, registration := range registrations {
if err := registration(registry); err != nil {
panic(err)
}
} }
a.ctx = context.WithValue(a.ctx, handlerRegistryKey, registry)
return a
}
// WithHealthChecker adds the health checker mechanism
// requires nothing
func (a *app) WithHealthChecker() App {
checker := health.New()
a.ctx = context.WithValue(a.ctx, healthCheckerKey, checker)
return a
}
// WithLogger configures the logging system
// requires nothing
func (a *app) WithLogger() App {
a.lateInitTasks = append(a.lateInitTasks, func(cmd *cobra.Command, args []string) (err error) {
logging.ConfigureLogging(
logging.ParseLevel(logLevel),
developmentLogs,
map[string]interface{}{
"cwd": path.WorkingDirectory(),
"cmd": cmd.Name(),
"args": args,
},
)
var logger logging.Logger
if logger, err = logging.CreateLogger(); err != nil {
return
}
a.ctx = context.WithValue(a.ctx, loggerKey, logger)
return
})
return a
}
// WithEndpointManager creates an endpoint manager instance and adds it to the context
// requires WithHandlerRegistry, WithHealthChecker and WithLogger
func (a *app) WithEndpointManager() App {
a.lateInitTasks = append(a.lateInitTasks, func(_ *cobra.Command, _ []string) (err error) {
epMgr := endpoint.NewOrchestrator(
a.Context(),
a.CertStore(),
a.HandlerRegistry(),
a.Audit(),
a.Logger().Named("Orchestrator"),
)
a.ctx = context.WithValue(a.ctx, endpointManagerKey, epMgr)
return
})
return a
}
// WithCertStore initializes the cert store
// requires WithLogger and WithConfig
func (a *app) WithCertStore() App {
a.lateInitTasks = append(a.lateInitTasks, func(cmd *cobra.Command, args []string) (err error) {
var certStore cert.Store
if certStore, err = cert.NewDefaultStore(
a.Config().TLSConfig(),
a.Logger().Named("CertStore"),
); err != nil {
return
}
a.ctx = context.WithValue(a.ctx, certStoreKey, certStore)
return
})
return a
}
// WithEventStream adds the audit event stream
// requires WithLogger
func (a *app) WithEventStream() App {
a.lateInitTasks = append(a.lateInitTasks, func(_ *cobra.Command, _ []string) (err error) {
var eventStream audit.EventStream
eventStream, err = audit.NewEventStream(
a.Logger().Named("EventStream"),
audit.WithSinkBufferSize(defaultEventBufferSize),
)
if err != nil {
return
}
if err = eventStream.RegisterSink(a.ctx, sink.NewLogSink(a.Logger().Named("LogSink"))); err != nil {
return
}
var metricSink audit.Sink
if metricSink, err = sink.NewMetricSink(); err != nil {
return
}
if err = eventStream.RegisterSink(a.ctx, metricSink); err != nil {
return
}
a.ctx = context.WithValue(a.ctx, eventStreamKey, eventStream)
return
})
return a
}
// WithConfig loads the config
// requires nothing
func (a *app) WithConfig() App {
a.lateInitTasks = append(a.lateInitTasks, func(cmd *cobra.Command, _ []string) (err error) {
cfg := CreateConfig()
if err = cfg.ReadConfig(configFilePath); err != nil {
return
}
a.ctx = context.WithValue(a.ctx, configKey, cfg)
return
})
return a
}
func (a *app) WithInitTasks(task ...func(cmd *cobra.Command, args []string) (err error)) App {
a.lateInitTasks = append(a.lateInitTasks, task...)
return a
}
func NewApp(name, short string) App {
ctx, cancel := initAppContext() ctx, cancel := initAppContext()
a := &app{ a := &app{
rootCmd: &cobra.Command{ rootCmd: &cobra.Command{
Use: name, Use: spec.Name,
Short: short, Short: spec.Short,
SilenceUsage: true,
}, },
ctx: ctx, ctx: ctx,
cancel: cancel, cancel: cancel,
} }
lateInitTasks := []func(cmd *cobra.Command, args []string) (err error){
func(*cobra.Command, []string) (err error) {
return spec.readConfig(a.rootCmd)
},
func(cmd *cobra.Command, args []string) (err error) {
logging.ConfigureLogging(
logging.ParseLevel(logLevel),
developmentLogs,
map[string]interface{}{
"cwd": path.WorkingDirectory(),
"cmd": cmd.Name(),
"args": args,
},
)
if a.logger, err = logging.CreateLogger(); err != nil {
return
}
return
},
}
lateInitTasks = append(lateInitTasks, spec.LateInitTasks...)
a.rootCmd.AddCommand(completionCmd) a.rootCmd.AddCommand(completionCmd)
a.rootCmd.AddCommand(spec.SubCommands...)
a.rootCmd.PersistentFlags().StringVar(&configFilePath, "config", "", "Path to config file that should be used") a.rootCmd.PersistentFlags().StringVar(&configFilePath, "config", "", "Path to config file that should be used")
a.rootCmd.PersistentFlags().StringVar(&logLevel, "log-level", "info", "logging level to use") a.rootCmd.PersistentFlags().StringVar(&logLevel, "log-level", "info", "logging level to use")
a.rootCmd.PersistentFlags().BoolVar(&developmentLogs, "development-logs", false, "Enable development mode logs") a.rootCmd.PersistentFlags().BoolVar(&developmentLogs, "development-logs", false, "Enable development mode logs")
a.rootCmd.PersistentPreRunE = func(cmd *cobra.Command, args []string) (err error) { a.rootCmd.PersistentPreRunE = func(cmd *cobra.Command, args []string) (err error) {
for _, initTask := range a.lateInitTasks { for _, initTask := range lateInitTasks {
if err = initTask(cmd, args); err != nil { if err = initTask(cmd, args); err != nil {
return return
} }
@ -366,3 +152,43 @@ func initAppContext() (context.Context, context.CancelFunc) {
return ctx, cancel return ctx, cancel
} }
func (s Spec) readConfig(rootCmd *cobra.Command) error {
viperCfg := viper.NewWithOptions()
viperCfg.SetConfigName("config")
viperCfg.SetConfigType("yaml")
viperCfg.AddConfigPath(fmt.Sprintf("/etc/%s/", strings.ToLower(s.Name)))
viperCfg.AddConfigPath(fmt.Sprintf("$HOME/.%s/", strings.ToLower(s.Name)))
viperCfg.AddConfigPath(".")
viperCfg.SetEnvKeyReplacer(strings.NewReplacer(".", "_"))
viperCfg.SetEnvPrefix("INETMOCK")
viperCfg.AutomaticEnv()
if s.FlagBindings != nil {
for key, selector := range s.FlagBindings {
if err := viperCfg.BindPFlag(key, selector(rootCmd.Flags())); err != nil {
return err
}
}
}
for k, v := range s.Defaults {
viperCfg.SetDefault(k, v)
}
if configFilePath != "" && path.FileExists(configFilePath) {
viperCfg.SetConfigFile(configFilePath)
}
if err := viperCfg.ReadInConfig(); err != nil {
if _, ok := err.(viper.ConfigFileNotFoundError); !(ok && s.IgnoreMissingConfigFile) {
return err
}
}
if err := viperCfg.Unmarshal(s.Config); err != nil {
return err
}
return nil
}
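
Because readConfig combines SetEnvPrefix("INETMOCK"), a "." to "_" key replacer and AutomaticEnv, every config key can also be supplied through an environment variable. A quick standalone illustration of that mapping (standard viper behavior, not code from this commit):

package main

import (
    "fmt"
    "os"
    "strings"

    "github.com/spf13/viper"
)

func main() {
    // With this setup a key like "api.listen" resolves to INETMOCK_API_LISTEN.
    os.Setenv("INETMOCK_API_LISTEN", "tcp://:8080")

    v := viper.New()
    v.SetEnvPrefix("INETMOCK")
    v.SetEnvKeyReplacer(strings.NewReplacer(".", "_"))
    v.AutomaticEnv()

    fmt.Println(v.GetString("api.listen")) // tcp://:8080
}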


@ -1,117 +0,0 @@
package app
import (
"net/url"
"strings"
"github.com/spf13/viper"
"gitlab.com/inetmock/inetmock/internal/endpoint"
"gitlab.com/inetmock/inetmock/pkg/cert"
"gitlab.com/inetmock/inetmock/pkg/path"
)
func CreateConfig() Config {
configInstance := &config{
cfg: viper.New(),
}
configInstance.cfg.SetConfigName("config")
configInstance.cfg.SetConfigType("yaml")
configInstance.cfg.AddConfigPath("/etc/inetmock/")
configInstance.cfg.AddConfigPath("$HOME/.inetmock")
configInstance.cfg.AddConfigPath(".")
configInstance.cfg.SetEnvPrefix("INetMock")
configInstance.cfg.SetEnvKeyReplacer(strings.NewReplacer("-", "_"))
configInstance.cfg.AutomaticEnv()
for k, v := range registeredDefaults {
configInstance.cfg.SetDefault(k, v)
}
for k, v := range registeredAliases {
configInstance.cfg.RegisterAlias(k, v)
}
return configInstance
}
type Config interface {
ReadConfig(configFilePath string) error
ReadConfigString(config, format string) error
TLSConfig() cert.Options
APIURL() *url.URL
AuditDataDir() string
PCAPDataDir() string
ListenerSpecs() map[string]endpoint.ListenerSpec
}
type apiConfig struct {
Listen string
}
type config struct {
cfg *viper.Viper
TLS cert.Options
Listeners map[string]endpoint.ListenerSpec
API apiConfig
Data Data
}
func (c *config) AuditDataDir() string {
return c.Data.Audit
}
func (c *config) PCAPDataDir() string {
return c.Data.PCAP
}
func (c *config) APIURL() *url.URL {
if u, err := url.Parse(c.API.Listen); err != nil {
u, _ = url.Parse("tcp://:0")
return u
} else {
return u
}
}
func (c *config) ListenerSpecs() map[string]endpoint.ListenerSpec {
return c.Listeners
}
func (c *config) TLSConfig() cert.Options {
return c.TLS
}
func (c *config) ReadConfigString(config, format string) (err error) {
c.cfg.SetConfigType(format)
if err = c.cfg.ReadConfig(strings.NewReader(config)); err != nil {
return
}
err = c.cfg.Unmarshal(c)
return
}
func (c *config) ReadConfig(configFilePath string) (err error) {
if configFilePath != "" && path.FileExists(configFilePath) {
c.cfg.SetConfigFile(configFilePath)
}
if err = c.cfg.ReadInConfig(); err != nil {
if _, ok := err.(viper.ConfigFileNotFoundError); ok {
err = nil
} else {
return
}
}
if err = c.cfg.Unmarshal(c); err != nil {
return
}
if err = c.Data.setup(); err != nil {
return
}
return
}


@ -1,74 +0,0 @@
package app
import (
"reflect"
"testing"
"gitlab.com/inetmock/inetmock/internal/endpoint"
)
func Test_config_ReadConfig(t *testing.T) {
type args struct {
config string
}
type testCase struct {
name string
args args
wantListeners map[string]endpoint.ListenerSpec
wantErr bool
}
tests := []testCase{
{
name: "Test endpoints config",
args: args{
// language=yaml
config: `
listeners:
tcp_80:
name: ''
protocol: tcp
listenAddress: ''
port: 80
tcp_443:
name: ''
protocol: tcp
listenAddress: ''
port: 443
`,
},
wantListeners: map[string]endpoint.ListenerSpec{
"tcp_80": {
Name: "",
Protocol: "tcp",
Address: "",
Port: 80,
Endpoints: nil,
},
"tcp_443": {
Name: "",
Protocol: "tcp",
Address: "",
Port: 443,
Endpoints: nil,
},
},
wantErr: false,
},
}
scenario := func(tt testCase) func(t *testing.T) {
return func(t *testing.T) {
cfg := CreateConfig()
if err := cfg.ReadConfigString(tt.args.config, "yaml"); (err != nil) != tt.wantErr {
t.Errorf("ReadConfig() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(tt.wantListeners, cfg.ListenerSpecs()) {
t.Errorf("want = %v, got = %v", tt.wantListeners, cfg.ListenerSpecs())
}
}
}
for _, tt := range tests {
t.Run(tt.name, scenario(tt))
}
}


@ -1,6 +0,0 @@
package app
const (
EndpointsKey = "endpoints"
OptionsKey = "options"
)


@ -1,34 +0,0 @@
package app
import (
"os"
"path/filepath"
)
type Data struct {
PCAP string
Audit string
}
func (d *Data) setup() (err error) {
if d.PCAP, err = ensureDataDir(d.PCAP); err != nil {
return
}
if d.Audit, err = ensureDataDir(d.Audit); err != nil {
return
}
return
}
func ensureDataDir(dataDirPath string) (cleanedPath string, err error) {
cleanedPath = dataDirPath
if !filepath.IsAbs(cleanedPath) {
if cleanedPath, err = filepath.Abs(cleanedPath); err != nil {
return
}
}
err = os.MkdirAll(cleanedPath, 0640)
return
}


@ -1,15 +0,0 @@
package app
var (
registeredDefaults = make(map[string]interface{})
// default aliases
registeredAliases = map[string]string{}
)
func AddDefaultValue(key string, val interface{}) {
registeredDefaults[key] = val
}
func AddAlias(alias, orig string) {
registeredAliases[alias] = orig
}


@ -4,9 +4,12 @@ import (
"net" "net"
"reflect" "reflect"
"testing" "testing"
"github.com/maxatome/go-testdeep/td"
) )
func Test_randomIPFallback_GetIP(t *testing.T) { func Test_randomIPFallback_GetIP(t *testing.T) {
t.Parallel()
ra := randomIPFallback{} ra := randomIPFallback{}
for i := 0; i < 1000; i++ { for i := 0; i < 1000; i++ {
if got := ra.GetIP(); reflect.DeepEqual(got, net.IP{}) { if got := ra.GetIP(); reflect.DeepEqual(got, net.IP{}) {
@ -16,6 +19,7 @@ func Test_randomIPFallback_GetIP(t *testing.T) {
} }
func Test_incrementalIPFallback_GetIP(t *testing.T) { func Test_incrementalIPFallback_GetIP(t *testing.T) {
t.Parallel()
type fields struct { type fields struct {
latestIP uint32 latestIP uint32
} }
@ -59,24 +63,22 @@ func Test_incrementalIPFallback_GetIP(t *testing.T) {
}, },
}, },
} }
scenario := func(tt testCase) func(t *testing.T) { for _, tc := range tests {
return func(t *testing.T) { tt := tc
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
i := &incrementalIPFallback{ i := &incrementalIPFallback{
latestIP: tt.fields.latestIP, latestIP: tt.fields.latestIP,
} }
for k := 0; k < len(tt.want); k++ { for k := 0; k < len(tt.want); k++ {
if got := i.GetIP(); !reflect.DeepEqual(got, tt.want[k]) { td.Cmp(t, i.GetIP(), tt.want[k])
t.Errorf("GetIP() = %v, want %v", got, tt.want[k])
}
} }
} })
}
for _, tt := range tests {
t.Run(tt.name, scenario(tt))
} }
} }
func Test_ipToInt32(t *testing.T) { func Test_ipToInt32(t *testing.T) {
t.Parallel()
type args struct { type args struct {
ip net.IP ip net.IP
} }
@ -101,14 +103,11 @@ func Test_ipToInt32(t *testing.T) {
want: 3232281098, want: 3232281098,
}, },
} }
scenario := func(tt testCase) func(t *testing.T) { for _, tc := range tests {
return func(t *testing.T) { tt := tc
if got := ipToInt32(tt.args.ip); got != tt.want { t.Run(tt.name, func(t *testing.T) {
t.Errorf("ipToInt32() = %v, want %v", got, tt.want) t.Parallel()
} td.Cmp(t, ipToInt32(tt.args.ip), tt.want)
} })
}
for _, tt := range tests {
t.Run(tt.name, scenario(tt))
} }
} }
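
The tests converted here all follow the same parallel-subtest shape: re-bind the range variable (tt := tc) before t.Run so each parallel closure gets its own copy; on Go versions before 1.22 the subtests would otherwise all observe the last element of tests. Reduced to a standalone example (illustrative, not from this repository):

package example

import "testing"

func TestSquares(t *testing.T) {
    t.Parallel()
    tests := []struct {
        name string
        in   int
        want int
    }{
        {name: "two", in: 2, want: 4},
        {name: "three", in: 3, want: 9},
    }
    for _, tc := range tests {
        tt := tc // re-bind so the parallel closure below gets its own copy
        t.Run(tt.name, func(t *testing.T) {
            t.Parallel()
            if got := tt.in * tt.in; got != tt.want {
                t.Errorf("got %d, want %d", got, tt.want)
            }
        })
    }
}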


@ -53,8 +53,9 @@ func Benchmark_httpHandler(b *testing.B) {
scheme: "https", scheme: "https",
}, },
} }
scenario := func(bm benchmark) func(bm *testing.B) { for _, bc := range benchmarks {
return func(b *testing.B) { bm := bc
b.Run(bm.name, func(b *testing.B) {
var err error var err error
var endpoint string var endpoint string
if endpoint, err = setupContainer(b, bm.scheme, bm.port); err != nil { if endpoint, err = setupContainer(b, bm.scheme, bm.port); err != nil {
@ -88,10 +89,7 @@ func Benchmark_httpHandler(b *testing.B) {
} }
} }
}) })
} })
}
for _, bm := range benchmarks {
b.Run(bm.name, scenario(bm))
} }
} }


@ -2,11 +2,11 @@ package mock
import ( import (
"path/filepath" "path/filepath"
"reflect"
"regexp" "regexp"
"testing" "testing"
"github.com/golang/mock/gomock" "github.com/golang/mock/gomock"
"github.com/maxatome/go-testdeep/td"
"github.com/mitchellh/mapstructure" "github.com/mitchellh/mapstructure"
endpoint_mock "gitlab.com/inetmock/inetmock/internal/mock/endpoint" endpoint_mock "gitlab.com/inetmock/inetmock/internal/mock/endpoint"
@ -14,6 +14,7 @@ import (
//nolint:funlen //nolint:funlen
func Test_loadFromConfig(t *testing.T) { func Test_loadFromConfig(t *testing.T) {
t.Parallel()
type args struct { type args struct {
config map[string]interface{} config map[string]interface{}
} }
@ -159,10 +160,11 @@ func Test_loadFromConfig(t *testing.T) {
wantErr: false, wantErr: false,
}, },
} }
scenario := func(tt testCase) func(t *testing.T) { for _, tc := range tests {
return func(t *testing.T) { tt := tc
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
ctrl := gomock.NewController(t) ctrl := gomock.NewController(t)
t.Cleanup(ctrl.Finish)
lcMock := endpoint_mock.NewMockLifecycle(ctrl) lcMock := endpoint_mock.NewMockLifecycle(ctrl)
lcMock.EXPECT().UnmarshalOptions(gomock.Any()).Do(func(cfg interface{}) { lcMock.EXPECT().UnmarshalOptions(gomock.Any()).Do(func(cfg interface{}) {
@ -174,12 +176,7 @@ func Test_loadFromConfig(t *testing.T) {
t.Errorf("loadFromConfig() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("loadFromConfig() error = %v, wantErr %v", err, tt.wantErr)
return return
} }
if !reflect.DeepEqual(gotOptions, tt.wantOptions) { td.Cmp(t, gotOptions, tt.wantOptions)
t.Errorf("loadFromConfig() gotOptions = %v, want %v", gotOptions, tt.wantOptions) })
}
}
}
for _, tt := range tests {
t.Run(tt.name, scenario(tt))
} }
} }


@ -53,8 +53,8 @@ func Benchmark_httpProxy(b *testing.B) {
scheme: "https", scheme: "https",
}, },
} }
scenario := func(bm benchmark) func(bm *testing.B) { for _, bm := range benchmarks {
return func(b *testing.B) { b.Run(bm.name, func(b *testing.B) {
var err error var err error
var endpoint string var endpoint string
if endpoint, err = setupContainer(b, bm.port); err != nil { if endpoint, err = setupContainer(b, bm.port); err != nil {
@ -90,10 +90,7 @@ func Benchmark_httpProxy(b *testing.B) {
} }
} }
}) })
} })
}
for _, bm := range benchmarks {
b.Run(bm.name, scenario(bm))
} }
} }


@ -54,14 +54,15 @@ func (l *ListenerSpec) ConfigureMultiplexing(tlsConfig *tls.Config) ([]Endpoint,
} }
} }
var endpoints []Endpoint
if len(l.Endpoints) <= 1 { if len(l.Endpoints) <= 1 {
for name, s := range l.Endpoints { for name, s := range l.Endpoints {
endpoints = append(endpoints, Endpoint{ endpoints := []Endpoint{
name: fmt.Sprintf("%s:%s", l.Name, name), {
uplink: *l.Uplink, name: fmt.Sprintf("%s:%s", l.Name, name),
Spec: s, uplink: *l.Uplink,
}) Spec: s,
},
}
return endpoints, nil, nil return endpoints, nil, nil
} }
} }
@ -70,10 +71,12 @@ func (l *ListenerSpec) ConfigureMultiplexing(tlsConfig *tls.Config) ([]Endpoint,
return nil, nil, ErrUDPMultiplexer return nil, nil, ErrUDPMultiplexer
} }
var epNames []string var epNames = make([]string, len(l.Endpoints))
var multiplexEndpoints = make(map[string]MultiplexHandler) var multiplexEndpoints = make(map[string]MultiplexHandler)
var idx int
for name, spec := range l.Endpoints { for name, spec := range l.Endpoints {
epNames = append(epNames, name) epNames[idx] = name
idx++
if ep, ok := spec.Handler.(MultiplexHandler); !ok { if ep, ok := spec.Handler.(MultiplexHandler); !ok {
return nil, nil, fmt.Errorf("handler %s %w", spec.HandlerRef, ErrMultiplexingNotSupported) return nil, nil, fmt.Errorf("handler %s %w", spec.HandlerRef, ErrMultiplexingNotSupported)
} else { } else {
@ -90,7 +93,8 @@ func (l *ListenerSpec) ConfigureMultiplexing(tlsConfig *tls.Config) ([]Endpoint,
var tlsRequired = false var tlsRequired = false
for _, epName := range epNames { var endpoints = make([]Endpoint, len(epNames))
for i, epName := range epNames {
epSpec := l.Endpoints[epName] epSpec := l.Endpoints[epName]
var epMux = plainMux var epMux = plainMux
if epSpec.TLS { if epSpec.TLS {
@ -106,7 +110,7 @@ func (l *ListenerSpec) ConfigureMultiplexing(tlsConfig *tls.Config) ([]Endpoint,
Spec: epSpec, Spec: epSpec,
} }
endpoints = append(endpoints, epListener) endpoints[i] = epListener
} }
var muxes []cmux.CMux var muxes []cmux.CMux


@ -7,6 +7,7 @@ import (
//nolint:funlen //nolint:funlen
func Test_tblWriter_Write(t *testing.T) { func Test_tblWriter_Write(t *testing.T) {
t.Parallel()
type s1 struct { type s1 struct {
Name string Name string
Age int Age int
@ -139,8 +140,10 @@ func Test_tblWriter_Write(t *testing.T) {
`, `,
}, },
} }
scenario := func(tt testCase) func(t *testing.T) { for _, tc := range tests {
return func(t *testing.T) { tt := tc
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
bldr := new(strings.Builder) bldr := new(strings.Builder)
// hack to be able to format expected strings pretty // hack to be able to format expected strings pretty
@ -153,9 +156,6 @@ func Test_tblWriter_Write(t *testing.T) {
if bldr.String() != tt.wantResult { if bldr.String() != tt.wantResult {
t.Errorf("Write() got = %s, want %s", bldr.String(), tt.wantResult) t.Errorf("Write() got = %s, want %s", bldr.String(), tt.wantResult)
} }
} })
}
for _, tt := range tests {
t.Run(tt.name, scenario(tt))
} }
} }


@ -21,6 +21,7 @@ const (
) )
func Test_recorder_CompleteWorkflow(t *testing.T) { func Test_recorder_CompleteWorkflow(t *testing.T) {
t.Parallel()
var recorder = pcap.NewRecorder() var recorder = pcap.NewRecorder()
var buffer = bytes.NewBuffer(nil) var buffer = bytes.NewBuffer(nil)
var err error var err error


@ -30,6 +30,7 @@ func (f fakeWriterCloser) Close() error {
} }
func Test_recorder_Subscriptions(t *testing.T) { func Test_recorder_Subscriptions(t *testing.T) {
t.Parallel()
type subscriptionRequest struct { type subscriptionRequest struct {
Name string Name string
Device string Device string
@ -82,8 +83,10 @@ func Test_recorder_Subscriptions(t *testing.T) {
}, },
}, },
} }
scenario := func(tt testCase) func(t *testing.T) { for _, tc := range tests {
return func(t *testing.T) { tt := tc
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
r := pcap.NewRecorder() r := pcap.NewRecorder()
t.Cleanup(func() { t.Cleanup(func() {
@ -101,14 +104,12 @@ func Test_recorder_Subscriptions(t *testing.T) {
if gotSubscriptions := sortSubscriptions(r.Subscriptions()); !reflect.DeepEqual(gotSubscriptions, tt.wantSubscriptions) { if gotSubscriptions := sortSubscriptions(r.Subscriptions()); !reflect.DeepEqual(gotSubscriptions, tt.wantSubscriptions) {
t.Errorf("Subscriptions() = %v, want %v", gotSubscriptions, tt.wantSubscriptions) t.Errorf("Subscriptions() = %v, want %v", gotSubscriptions, tt.wantSubscriptions)
} }
} })
}
for _, tt := range tests {
t.Run(tt.name, scenario(tt))
} }
} }
func Test_recorder_StartRecordingWithOptions(t *testing.T) { func Test_recorder_StartRecordingWithOptions(t *testing.T) {
t.Parallel()
type args struct { type args struct {
device string device string
consumer pcap.Consumer consumer pcap.Consumer
@ -155,8 +156,10 @@ func Test_recorder_StartRecordingWithOptions(t *testing.T) {
wantErr: true, wantErr: true,
}, },
} }
scenario := func(tt testCase) func(t *testing.T) { for _, tc := range tests {
return func(t *testing.T) { tt := tc
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
var err error var err error
var recorder pcap.Recorder var recorder pcap.Recorder
@ -174,14 +177,12 @@ func Test_recorder_StartRecordingWithOptions(t *testing.T) {
if (err != nil) != tt.wantErr { if (err != nil) != tt.wantErr {
t.Errorf("StartRecordingWithOptions() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("StartRecordingWithOptions() error = %v, wantErr %v", err, tt.wantErr)
} }
} })
}
for _, tt := range tests {
t.Run(tt.name, scenario(tt))
} }
} }
func Test_recorder_AvailableDevices(t *testing.T) { func Test_recorder_AvailableDevices(t *testing.T) {
t.Parallel()
type testCase struct { type testCase struct {
name string name string
mtacher func(got []pcap.Device) error mtacher func(got []pcap.Device) error
@ -209,8 +210,11 @@ func Test_recorder_AvailableDevices(t *testing.T) {
wantErr: false, wantErr: false,
}, },
} }
scenario := func(tt testCase) func(t *testing.T) {
return func(t *testing.T) { for _, tc := range tests {
tt := tc
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
re := pcap.NewRecorder() re := pcap.NewRecorder()
t.Cleanup(func() { t.Cleanup(func() {
if err := re.Close(); err != nil { if err := re.Close(); err != nil {
@ -225,14 +229,12 @@ func Test_recorder_AvailableDevices(t *testing.T) {
if err := tt.mtacher(gotDevices); err != nil { if err := tt.mtacher(gotDevices); err != nil {
t.Errorf("AvailableDevices() matcher error = %v", err) t.Errorf("AvailableDevices() matcher error = %v", err)
} }
} })
}
for _, tt := range tests {
t.Run(tt.name, scenario(tt))
} }
} }
func Test_recorder_StopRecording(t *testing.T) { func Test_recorder_StopRecording(t *testing.T) {
t.Parallel()
type args struct { type args struct {
consumerKey string consumerKey string
} }
@ -249,6 +251,7 @@ func Test_recorder_StopRecording(t *testing.T) {
consumerKey: "lo:test.pcap", consumerKey: "lo:test.pcap",
}, },
recorderSetup: func(t *testing.T) (recorder pcap.Recorder, err error) { recorderSetup: func(t *testing.T) (recorder pcap.Recorder, err error) {
t.Helper()
return pcap.NewRecorder(), nil return pcap.NewRecorder(), nil
}, },
wantErr: true, wantErr: true,
@ -259,6 +262,7 @@ func Test_recorder_StopRecording(t *testing.T) {
consumerKey: "lo:test", consumerKey: "lo:test",
}, },
recorderSetup: func(t *testing.T) (recorder pcap.Recorder, err error) { recorderSetup: func(t *testing.T) (recorder pcap.Recorder, err error) {
t.Helper()
recorder = pcap.NewRecorder() recorder = pcap.NewRecorder()
err = recorder.StartRecording(context.Background(), "lo", consumers.NewNoOpConsumerWithName("test")) err = recorder.StartRecording(context.Background(), "lo", consumers.NewNoOpConsumerWithName("test"))
@ -272,6 +276,7 @@ func Test_recorder_StopRecording(t *testing.T) {
consumerKey: "lo:test", consumerKey: "lo:test",
}, },
recorderSetup: func(t *testing.T) (recorder pcap.Recorder, err error) { recorderSetup: func(t *testing.T) (recorder pcap.Recorder, err error) {
t.Helper()
recorder = pcap.NewRecorder() recorder = pcap.NewRecorder()
var writerConsumer pcap.Consumer var writerConsumer pcap.Consumer
gotClosed := false gotClosed := false
@ -299,8 +304,10 @@ func Test_recorder_StopRecording(t *testing.T) {
wantErr: false, wantErr: false,
}, },
} }
scenario := func(tt testCase) func(t *testing.T) { for _, tc := range tests {
return func(t *testing.T) { tt := tc
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
var err error var err error
var r pcap.Recorder var r pcap.Recorder
if r, err = tt.recorderSetup(t); err != nil { if r, err = tt.recorderSetup(t); err != nil {
@ -316,14 +323,12 @@ func Test_recorder_StopRecording(t *testing.T) {
if err := r.StopRecording(tt.args.consumerKey); (err != nil) != tt.wantErr { if err := r.StopRecording(tt.args.consumerKey); (err != nil) != tt.wantErr {
t.Errorf("StopRecording() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("StopRecording() error = %v, wantErr %v", err, tt.wantErr)
} }
} })
}
for _, tt := range tests {
t.Run(tt.name, scenario(tt))
} }
} }
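
The t.Helper() calls added to the recorderSetup functions make a failure inside the setup show up with the line number of the subtest that invoked it rather than the helper itself. Trimmed to its shape, modeled on the cases above:

recorderSetup: func(t *testing.T) (pcap.Recorder, error) {
    t.Helper() // report failures at the caller's line, not inside the helper
    recorder := pcap.NewRecorder()
    err := recorder.StartRecording(context.Background(), "lo", consumers.NewNoOpConsumerWithName("test"))
    return recorder, err
},
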
func Test_recorder_StopRecordingFromContext(t *testing.T) { func Test_recorder_StopRecordingFromContext(t *testing.T) {
t.Parallel()
var err error var err error
var writerConsumer pcap.Consumer var writerConsumer pcap.Consumer
gotClosed := false gotClosed := false

View file

@ -28,6 +28,14 @@ type auditServer struct {
auditDataDirPath string auditDataDirPath string
} }
func NewAuditServiceServer(logger logging.Logger, eventStream audit.EventStream, auditDataDirPath string) v1.AuditServiceServer {
return &auditServer{
logger: logger,
eventStream: eventStream,
auditDataDirPath: auditDataDirPath,
}
}
func (a *auditServer) ListSinks(context.Context, *v1.ListSinksRequest) (*v1.ListSinksResponse, error) { func (a *auditServer) ListSinks(context.Context, *v1.ListSinksRequest) (*v1.ListSinksResponse, error) {
return &v1.ListSinksResponse{ return &v1.ListSinksResponse{
Sinks: a.eventStream.Sinks(), Sinks: a.eventStream.Sinks(),
@ -35,7 +43,8 @@ func (a *auditServer) ListSinks(context.Context, *v1.ListSinksRequest) (*v1.List
} }
func (a *auditServer) WatchEvents(req *v1.WatchEventsRequest, srv v1.AuditService_WatchEventsServer) (err error) { func (a *auditServer) WatchEvents(req *v1.WatchEventsRequest, srv v1.AuditService_WatchEventsServer) (err error) {
a.logger.Info("watcher attached", zap.String("name", req.WatcherName)) var logger = a.logger
logger.Info("watcher attached", zap.String("name", req.WatcherName))
err = a.eventStream.RegisterSink(srv.Context(), sink.NewGenericSink(req.WatcherName, func(ev audit.Event) { err = a.eventStream.RegisterSink(srv.Context(), sink.NewGenericSink(req.WatcherName, func(ev audit.Event) {
if err = srv.Send(&v1.WatchEventsResponse{Entity: ev.ProtoMessage()}); err != nil { if err = srv.Send(&v1.WatchEventsResponse{Entity: ev.ProtoMessage()}); err != nil {
return return
@ -47,7 +56,7 @@ func (a *auditServer) WatchEvents(req *v1.WatchEventsRequest, srv v1.AuditServic
} }
<-srv.Context().Done() <-srv.Context().Done()
a.logger.Info("Watcher detached", zap.String("name", req.WatcherName)) logger.Info("Watcher detached", zap.String("name", req.WatcherName))
return return
} }
@ -76,7 +85,9 @@ func (a *auditServer) RegisterFileSink(_ context.Context, req *v1.RegisterFileSi
func (a *auditServer) RemoveFileSink(_ context.Context, req *v1.RemoveFileSinkRequest) (*v1.RemoveFileSinkResponse, error) { func (a *auditServer) RemoveFileSink(_ context.Context, req *v1.RemoveFileSinkRequest) (*v1.RemoveFileSinkResponse, error) {
if gotRemoved := a.eventStream.RemoveSink(req.TargetPath); gotRemoved { if gotRemoved := a.eventStream.RemoveSink(req.TargetPath); gotRemoved {
return &v1.RemoveFileSinkResponse{}, nil return &v1.RemoveFileSinkResponse{
SinkGotRemoved: gotRemoved,
}, nil
} }
return nil, status.Error(codes.NotFound, "file sink with given target path not found") return nil, status.Error(codes.NotFound, "file sink with given target path not found")
} }
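
Extracting NewAuditServiceServer gives the new RPC tests and the production wiring one shared constructor instead of a struct literal with unexported fields. As used further down in this commit (grpcServer stands in for whatever grpc.ServiceRegistrar the caller has in scope, and logger, eventStream and auditDataDir for its existing dependencies):

v1.RegisterAuditServiceServer(grpcServer, rpc.NewAuditServiceServer(logger, eventStream, auditDataDir))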

View file

@ -0,0 +1,114 @@
package rpc_test
import (
"context"
"testing"
"time"
"github.com/golang/mock/gomock"
"github.com/maxatome/go-testdeep/td"
"google.golang.org/grpc"
auditm "gitlab.com/inetmock/inetmock/internal/mock/audit"
"gitlab.com/inetmock/inetmock/internal/rpc"
"gitlab.com/inetmock/inetmock/internal/rpc/test"
tst "gitlab.com/inetmock/inetmock/internal/test"
"gitlab.com/inetmock/inetmock/pkg/audit"
"gitlab.com/inetmock/inetmock/pkg/logging"
v1 "gitlab.com/inetmock/inetmock/pkg/rpc/v1"
)
const (
grpcTimeout = 100 * time.Millisecond
)
func Test_auditServer_RemoveFileSink(t *testing.T) {
t.Parallel()
type fields struct {
eventStreamSetup func(t *testing.T) audit.EventStream
}
tests := []struct {
name string
req *v1.RemoveFileSinkRequest
fields fields
want td.StructFields
wantErr bool
}{
{
name: "Remove existing file sink - success",
req: &v1.RemoveFileSinkRequest{
TargetPath: "test.pcap",
},
fields: fields{
eventStreamSetup: func(t *testing.T) audit.EventStream {
t.Helper()
ctrl := gomock.NewController(t)
es := auditm.NewMockEventStream(ctrl)
es.
EXPECT().
RemoveSink("test.pcap").
Return(true)
return es
},
},
want: td.StructFields{
"SinkGotRemoved": true,
},
wantErr: false,
},
{
name: "Remove non-existing file sink - success",
req: &v1.RemoveFileSinkRequest{
TargetPath: "test.pcap",
},
fields: fields{
eventStreamSetup: func(t *testing.T) audit.EventStream {
t.Helper()
ctrl := gomock.NewController(t)
es := auditm.NewMockEventStream(ctrl)
es.
EXPECT().
RemoveSink("test.pcap").
Return(false)
return es
},
},
wantErr: true,
},
}
for _, tc := range tests {
tt := tc
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
testCtx := tst.Context(t)
logger := logging.CreateTestLogger(t)
srv := test.NewTestGRPCServer(t, func(registrar grpc.ServiceRegistrar) {
v1.RegisterAuditServiceServer(registrar, rpc.NewAuditServiceServer(logger, tt.fields.eventStreamSetup(t), t.TempDir()))
})
ctx, cancel := context.WithTimeout(testCtx, grpcTimeout)
conn := srv.Dial(ctx, t)
cancel()
client := v1.NewAuditServiceClient(conn)
ctx, cancel = context.WithTimeout(testCtx, grpcTimeout)
t.Cleanup(cancel)
got, err := client.RemoveFileSink(ctx, tt.req)
if (err != nil) != tt.wantErr {
t.Errorf("RemoveFileSink() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !tt.wantErr {
td.CmpStruct(t, got, new(v1.RemoveFileSinkResponse), tt.want)
}
})
}
}
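
The new test leans on go-testdeep for the assertion: td.StructFields names only the fields that matter and td.CmpStruct compares the response against the model with those overrides, so the protobuf bookkeeping fields never have to be spelled out. The call shape, taken from the test above:

want := td.StructFields{"SinkGotRemoved": true}
td.CmpStruct(t, got, new(v1.RemoveFileSinkResponse), want)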

View file

@ -63,12 +63,7 @@ func (i *inetmockAPI) StartServer() (err error) {
checker: i.checker, checker: i.checker,
}) })
v1.RegisterAuditServiceServer(i.server, &auditServer{ v1.RegisterAuditServiceServer(i.server, NewAuditServiceServer(i.logger, i.eventStream, i.auditDataDir))
logger: i.logger,
eventStream: i.eventStream,
auditDataDirPath: i.auditDataDir,
})
v1.RegisterPCAPServiceServer(i.server, NewPCAPServer(i.pcapDataDir, pcap.NewRecorder())) v1.RegisterPCAPServiceServer(i.server, NewPCAPServer(i.pcapDataDir, pcap.NewRecorder()))
reflection.Register(i.server) reflection.Register(i.server)

View file

@ -7,6 +7,7 @@ import (
"reflect" "reflect"
"sort" "sort"
"testing" "testing"
"time"
"github.com/google/uuid" "github.com/google/uuid"
"google.golang.org/grpc" "google.golang.org/grpc"
@ -15,10 +16,12 @@ import (
"gitlab.com/inetmock/inetmock/internal/pcap/consumers" "gitlab.com/inetmock/inetmock/internal/pcap/consumers"
"gitlab.com/inetmock/inetmock/internal/rpc" "gitlab.com/inetmock/inetmock/internal/rpc"
"gitlab.com/inetmock/inetmock/internal/rpc/test" "gitlab.com/inetmock/inetmock/internal/rpc/test"
tst "gitlab.com/inetmock/inetmock/internal/test"
rpcV1 "gitlab.com/inetmock/inetmock/pkg/rpc/v1" rpcV1 "gitlab.com/inetmock/inetmock/pkg/rpc/v1"
) )
func Test_pcapServer_ListActiveRecordings(t *testing.T) { func Test_pcapServer_ListActiveRecordings(t *testing.T) {
t.Parallel()
type testCase struct { type testCase struct {
name string name string
recorderSetup func(t *testing.T) (recorder pcap.Recorder, err error) recorderSetup func(t *testing.T) (recorder pcap.Recorder, err error)
@ -29,6 +32,7 @@ func Test_pcapServer_ListActiveRecordings(t *testing.T) {
{ {
name: "No subscriptions", name: "No subscriptions",
recorderSetup: func(t *testing.T) (recorder pcap.Recorder, err error) { recorderSetup: func(t *testing.T) (recorder pcap.Recorder, err error) {
t.Helper()
recorder = pcap.NewRecorder() recorder = pcap.NewRecorder()
return return
}, },
@ -38,6 +42,7 @@ func Test_pcapServer_ListActiveRecordings(t *testing.T) {
{ {
name: "Listening to lo interface", name: "Listening to lo interface",
recorderSetup: func(t *testing.T) (recorder pcap.Recorder, err error) { recorderSetup: func(t *testing.T) (recorder pcap.Recorder, err error) {
t.Helper()
recorder = pcap.NewRecorder() recorder = pcap.NewRecorder()
err = recorder.StartRecording(context.Background(), "lo", consumers.NewNoOpConsumerWithName("test")) err = recorder.StartRecording(context.Background(), "lo", consumers.NewNoOpConsumerWithName("test"))
return return
@ -46,8 +51,10 @@ func Test_pcapServer_ListActiveRecordings(t *testing.T) {
wantErr: false, wantErr: false,
}, },
} }
scenario := func(tt testCase) func(t *testing.T) { for _, tc := range tests {
return func(t *testing.T) { tt := tc
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
var err error var err error
var recorder pcap.Recorder var recorder pcap.Recorder
if recorder, err = tt.recorderSetup(t); err != nil { if recorder, err = tt.recorderSetup(t); err != nil {
@ -68,24 +75,24 @@ func Test_pcapServer_ListActiveRecordings(t *testing.T) {
t.Errorf("ListActiveRecordings() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("ListActiveRecordings() error = %v, wantErr %v", err, tt.wantErr)
return return
} }
if (gotResp == nil) != tt.wantErr { if gotResp == nil {
t.Errorf("response was nil") if !tt.wantErr {
return t.Errorf("response was nil")
} else {
sort.Strings(gotResp.Subscriptions)
sort.Strings(tt.wantSubscriptions)
if !reflect.DeepEqual(gotResp.Subscriptions, tt.wantSubscriptions) {
t.Errorf("ListActiveRecordings() gotResp = %v, want %v", gotResp.Subscriptions, tt.wantSubscriptions)
} }
return
} }
}
} sort.Strings(gotResp.Subscriptions)
for _, tt := range tests { sort.Strings(tt.wantSubscriptions)
t.Run(tt.name, scenario(tt)) if !reflect.DeepEqual(gotResp.Subscriptions, tt.wantSubscriptions) {
t.Errorf("ListActiveRecordings() gotResp = %v, want %v", gotResp.Subscriptions, tt.wantSubscriptions)
}
})
} }
} }
func Test_pcapServer_ListAvailableDevices(t *testing.T) { func Test_pcapServer_ListAvailableDevices(t *testing.T) {
t.Parallel()
type testCase struct { type testCase struct {
name string name string
matcher func(devs []*rpcV1.ListAvailableDevicesResponse_PCAPDevice) error matcher func(devs []*rpcV1.ListAvailableDevicesResponse_PCAPDevice) error
@ -135,8 +142,10 @@ func Test_pcapServer_ListAvailableDevices(t *testing.T) {
wantErr: false, wantErr: false,
}, },
} }
scenario := func(tt testCase) func(t *testing.T) { for _, tc := range tests {
return func(t *testing.T) { tt := tc
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
var err error var err error
var recorder = pcap.NewRecorder() var recorder = pcap.NewRecorder()
@ -157,14 +166,12 @@ func Test_pcapServer_ListAvailableDevices(t *testing.T) {
t.Errorf("ListAvailableDevices() matcher error = %v", err) t.Errorf("ListAvailableDevices() matcher error = %v", err)
} }
} }
} })
}
for _, tt := range tests {
t.Run(tt.name, scenario(tt))
} }
} }
func Test_pcapServer_StartPCAPFileRecording(t *testing.T) { func Test_pcapServer_StartPCAPFileRecording(t *testing.T) {
t.Parallel()
type testCase struct { type testCase struct {
name string name string
req *rpcV1.StartPCAPFileRecordingRequest req *rpcV1.StartPCAPFileRecordingRequest
@ -196,8 +203,10 @@ func Test_pcapServer_StartPCAPFileRecording(t *testing.T) {
wantErr: true, wantErr: true,
}, },
} }
scenario := func(tt testCase) func(t *testing.T) { for _, tc := range tests {
return func(t *testing.T) { tt := tc
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
var err error var err error
var recorder = pcap.NewRecorder() var recorder = pcap.NewRecorder()
@ -217,14 +226,12 @@ func Test_pcapServer_StartPCAPFileRecording(t *testing.T) {
if currentSubs := recorder.Subscriptions(); !reflect.DeepEqual(currentSubs, tt.wantSubscriptions) { if currentSubs := recorder.Subscriptions(); !reflect.DeepEqual(currentSubs, tt.wantSubscriptions) {
t.Errorf("StartPCAPFileRecording() got = %v, want %v", currentSubs, tt.wantSubscriptions) t.Errorf("StartPCAPFileRecording() got = %v, want %v", currentSubs, tt.wantSubscriptions)
} }
} })
}
for _, tt := range tests {
t.Run(tt.name, scenario(tt))
} }
} }
func Test_pcapServer_StopPCAPFileRecord(t *testing.T) { func Test_pcapServer_StopPCAPFileRecord(t *testing.T) {
t.Parallel()
type testCase struct { type testCase struct {
name string name string
keyToRemove string keyToRemove string
@ -237,6 +244,7 @@ func Test_pcapServer_StopPCAPFileRecord(t *testing.T) {
name: "Remove non existing recording", name: "Remove non existing recording",
keyToRemove: "lo:asdf.pcap", keyToRemove: "lo:asdf.pcap",
recorderSetup: func(t *testing.T) (recorder pcap.Recorder, err error) { recorderSetup: func(t *testing.T) (recorder pcap.Recorder, err error) {
t.Helper()
recorder = pcap.NewRecorder() recorder = pcap.NewRecorder()
return return
}, },
@ -247,6 +255,7 @@ func Test_pcapServer_StopPCAPFileRecord(t *testing.T) {
name: "Remove an existing recording", name: "Remove an existing recording",
keyToRemove: "lo:test.pcap", keyToRemove: "lo:test.pcap",
recorderSetup: func(t *testing.T) (recorder pcap.Recorder, err error) { recorderSetup: func(t *testing.T) (recorder pcap.Recorder, err error) {
t.Helper()
recorder = pcap.NewRecorder() recorder = pcap.NewRecorder()
err = recorder.StartRecording(context.Background(), "lo", consumers.NewNoOpConsumerWithName("test.pcap")) err = recorder.StartRecording(context.Background(), "lo", consumers.NewNoOpConsumerWithName("test.pcap"))
return return
@ -255,8 +264,10 @@ func Test_pcapServer_StopPCAPFileRecord(t *testing.T) {
wantErr: false, wantErr: false,
}, },
} }
scenario := func(tt testCase) func(t *testing.T) { for _, tc := range tests {
return func(t *testing.T) { tt := tc
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
var err error var err error
var recorder pcap.Recorder var recorder pcap.Recorder
if recorder, err = tt.recorderSetup(t); err != nil { if recorder, err = tt.recorderSetup(t); err != nil {
@ -281,44 +292,21 @@ func Test_pcapServer_StopPCAPFileRecord(t *testing.T) {
if gotResp.Removed != tt.removedRecording { if gotResp.Removed != tt.removedRecording {
t.Errorf("StopPCAPFileRecord() removed = %v, want %v", gotResp.Removed, tt.removedRecording) t.Errorf("StopPCAPFileRecord() removed = %v, want %v", gotResp.Removed, tt.removedRecording)
} }
} })
}
for _, tt := range tests {
t.Run(tt.name, scenario(tt))
} }
} }
func setupTestPCAPServer(t *testing.T, recorder pcap.Recorder) rpcV1.PCAPServiceClient { func setupTestPCAPServer(t *testing.T, recorder pcap.Recorder) rpcV1.PCAPServiceClient {
t.Helper() t.Helper()
var err error
var srv *test.GRPCServer
p := rpc.NewPCAPServer(t.TempDir(), recorder) p := rpc.NewPCAPServer(t.TempDir(), recorder)
srv, err = test.NewTestGRPCServer(func(registrar grpc.ServiceRegistrar) {
srv := test.NewTestGRPCServer(t, func(registrar grpc.ServiceRegistrar) {
rpcV1.RegisterPCAPServiceServer(registrar, p) rpcV1.RegisterPCAPServiceServer(registrar, p)
}) })
if err != nil { ctx, cancel := context.WithTimeout(tst.Context(t), 100*time.Millisecond)
t.Fatalf("NewTestGRPCServer() error = %v", err) var conn = srv.Dial(ctx, t)
} cancel()
if err = srv.StartServer(); err != nil {
t.Fatalf("StartServer() error = %v", err)
}
t.Cleanup(func() {
srv.StopServer()
})
var conn *grpc.ClientConn
if conn, err = srv.Dial(context.Background(), grpc.WithInsecure()); err != nil {
t.Fatalf("Dial() error = %v", err)
}
t.Cleanup(func() {
if err := conn.Close(); err != nil {
t.Errorf("conn.Close() error = %v", err)
}
})
return rpcV1.NewPCAPServiceClient(conn) return rpcV1.NewPCAPServiceClient(conn)
} }

View file

@ -4,100 +4,66 @@ import (
"context" "context"
"errors" "errors"
"net" "net"
"sync" "testing"
"google.golang.org/grpc" "google.golang.org/grpc"
"google.golang.org/grpc/test/bufconn"
)
const (
bufconnBufferSize = 1024 * 1024
) )
type APIRegistration func(registrar grpc.ServiceRegistrar) type APIRegistration func(registrar grpc.ServiceRegistrar)
type GRPCServer struct { type GRPCServer struct {
mutex sync.Mutex server *grpc.Server
serverRunning chan struct{} listener *bufconn.Listener
cancel context.CancelFunc
server *grpc.Server
addr *net.TCPAddr
listener net.Listener
} }
func NewTestGRPCServer(registrations ...APIRegistration) (srv *GRPCServer, err error) { func NewTestGRPCServer(tb testing.TB, registrations ...APIRegistration) (srv *GRPCServer) {
srv = new(GRPCServer) tb.Helper()
if srv.listener, err = net.Listen("tcp", "127.0.0.1:0"); err != nil { srv = &GRPCServer{
return listener: bufconn.Listen(bufconnBufferSize),
server: grpc.NewServer(),
} }
var isTCPAddr bool
if srv.addr, isTCPAddr = srv.listener.Addr().(*net.TCPAddr); !isTCPAddr {
err = errors.New("expected TPC addr but wasn't")
}
srv.server = grpc.NewServer()
for _, reg := range registrations { for _, reg := range registrations {
reg(srv.server) reg(srv.server)
} }
return ctx, cancel := context.WithCancel(context.Background())
} tb.Cleanup(cancel)
go srv.startServerLifecycle(ctx, tb)
func (t *GRPCServer) StartServer() (err error) {
t.mutex.Lock()
var ctx context.Context
ctx, t.cancel = context.WithCancel(context.Background())
t.mutex.Unlock()
errs := t.startServerAsync(ctx)
select {
case err = <-errs:
default:
}
return return
} }
func (t *GRPCServer) Dial(ctx context.Context, opts ...grpc.DialOption) (*grpc.ClientConn, error) { func (t *GRPCServer) Dial(ctx context.Context, tb testing.TB, opts ...grpc.DialOption) *grpc.ClientConn {
opts = append(opts, grpc.WithInsecure()) tb.Helper()
opts = append(opts, grpc.WithInsecure(), grpc.WithContextDialer(func(context.Context, string) (net.Conn, error) {
return grpc.DialContext(ctx, t.addr.String(), opts...) return t.listener.Dial()
} }))
conn, err := grpc.DialContext(ctx, "", opts...)
func (t *GRPCServer) StopServer() { if err != nil {
t.cancel() tb.Fatalf("failed to connect to gRPC test server - error = %v", err)
}
func (t *GRPCServer) IsRunning() bool {
t.mutex.Lock()
defer t.mutex.Unlock()
if t.serverRunning == nil {
return false
}
select {
case _, more := <-t.serverRunning:
return more
default:
return true
} }
tb.Cleanup(func() {
if err := conn.Close(); err != nil {
tb.Errorf("Failed to close bufconn connection error = %v", err)
}
})
return conn
} }
func (t *GRPCServer) startServerAsync(ctx context.Context) (errs chan error) { func (t *GRPCServer) startServerLifecycle(ctx context.Context, tb testing.TB) {
errs = make(chan error) tb.Helper()
t.mutex.Lock()
t.serverRunning = make(chan struct{})
t.mutex.Unlock()
defer close(t.serverRunning)
go func() { go func() {
<-ctx.Done() <-ctx.Done()
if t.IsRunning() { t.server.Stop()
t.server.Stop()
}
}() }()
go func() { if err := t.server.Serve(t.listener); err != nil && !errors.Is(err, grpc.ErrServerStopped) {
if err := t.server.Serve(t.listener); err != nil && !errors.Is(err, grpc.ErrServerStopped) { tb.Errorf("error occurred during running the gRPC test server - error - %v", err)
errs <- err }
}
}()
return
} }
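
The rewritten test server swaps the real TCP listener for an in-memory bufconn pipe, so no socket is ever opened and the whole lifecycle hangs off the test's context. Stripped of the repo's GRPCServer type, the wiring is roughly the following (ctx and t are assumed to be in scope; imports are net, context, google.golang.org/grpc and google.golang.org/grpc/test/bufconn):

lis := bufconn.Listen(1024 * 1024)
server := grpc.NewServer()
go func() {
    // Serve blocks until server.Stop() is called; the real helper stops it when the test context is cancelled
    _ = server.Serve(lis)
}()

conn, err := grpc.DialContext(ctx, "",
    grpc.WithInsecure(),
    grpc.WithContextDialer(func(context.Context, string) (net.Conn, error) {
        return lis.Dial() // dial the in-memory listener instead of a TCP address
    }),
)
if err != nil {
    t.Fatalf("bufconn dial failed: %v", err)
}
t.Cleanup(func() { _ = conn.Close() })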

internal/test/context.go (new file, 18 lines)
View file

@ -0,0 +1,18 @@
package test
import (
"context"
"testing"
)
func Context(t *testing.T) context.Context {
t.Helper()
if deadline, ok := t.Deadline(); ok {
ctx, cancel := context.WithDeadline(context.Background(), deadline)
t.Cleanup(cancel)
return ctx
}
ctx, cancel := context.WithCancel(context.Background())
t.Cleanup(cancel)
return ctx
}
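
test.Context derives a context from t.Deadline() when the test binary runs with a deadline and falls back to a plain cancellable context otherwise; either way cancellation is wired into t.Cleanup. Typical use, matching the RPC tests above (the package is imported as tst there; grpcTimeout, client and tt.req are as declared in that test):

ctx, cancel := context.WithTimeout(tst.Context(t), grpcTimeout)
t.Cleanup(cancel)
if _, err := client.RemoveFileSink(ctx, tt.req); err != nil {
    t.Fatalf("RemoveFileSink() error = %v", err)
}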

View file

@ -16,6 +16,7 @@ import (
const containerShutdownTimeout = 5 * time.Second const containerShutdownTimeout = 5 * time.Second
func SetupINetMockContainer(ctx context.Context, tb testing.TB, exposedPorts ...string) (testcontainers.Container, error) { func SetupINetMockContainer(ctx context.Context, tb testing.TB, exposedPorts ...string) (testcontainers.Container, error) {
tb.Helper()
//nolint:dogsled //nolint:dogsled
_, fileName, _, _ := runtime.Caller(0) _, fileName, _, _ := runtime.Caller(0)

View file

@ -62,7 +62,6 @@ func NewEventStream(logger logging.Logger, options ...EventStreamOption) (EventS
return underlying, err return underlying, err
} }
//nolint:gocritic
func (e *eventStream) Emit(ev Event) { func (e *eventStream) Emit(ev Event) {
ev.ApplyDefaults(e.idGenerator.Generate().Int64()) ev.ApplyDefaults(e.idGenerator.Generate().Int64())
select { select {

View file

@ -53,17 +53,19 @@ var (
} }
) )
func wgMockSink(t testing.TB, wg *sync.WaitGroup) audit.Sink { func wgMockSink(tb testing.TB, wg *sync.WaitGroup) audit.Sink {
tb.Helper()
return sink.NewGenericSink( return sink.NewGenericSink(
"WG mock sink", "WG mock sink",
func(event audit.Event) { func(event audit.Event) {
t.Logf("Got event = %v", event) tb.Logf("Got event = %v", event)
wg.Done() wg.Done()
}, },
) )
} }
func Test_eventStream_RegisterSink(t *testing.T) { func Test_eventStream_RegisterSink(t *testing.T) {
t.Parallel()
type args struct { type args struct {
s audit.Sink s audit.Sink
} }
@ -92,8 +94,10 @@ func Test_eventStream_RegisterSink(t *testing.T) {
wantErr: true, wantErr: true,
}, },
} }
scenario := func(tt testCase) func(t *testing.T) { for _, tc := range tests {
return func(t *testing.T) { tt := tc
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
var err error var err error
var e audit.EventStream var e audit.EventStream
if e, err = audit.NewEventStream(logging.CreateTestLogger(t)); err != nil { if e, err = audit.NewEventStream(logging.CreateTestLogger(t)); err != nil {
@ -121,15 +125,12 @@ func Test_eventStream_RegisterSink(t *testing.T) {
if !found { if !found {
t.Errorf("expected defaultSink name %s not found in registered sinks %v", tt.args.s.Name(), e.Sinks()) t.Errorf("expected defaultSink name %s not found in registered sinks %v", tt.args.s.Name(), e.Sinks())
} }
} })
}
for _, tt := range tests {
t.Run(tt.name, scenario(tt))
} }
} }
func Test_eventStream_Emit(t *testing.T) { func Test_eventStream_Emit(t *testing.T) {
t.Parallel()
type args struct { type args struct {
evs []*audit.Event evs []*audit.Event
opts []audit.EventStreamOption opts []audit.EventStreamOption
@ -165,9 +166,10 @@ func Test_eventStream_Emit(t *testing.T) {
subscribe: false, subscribe: false,
}, },
} }
for _, tc := range tests {
scenario := func(tt testCase) func(t *testing.T) { tt := tc
return func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
t.Parallel()
var err error var err error
var e audit.EventStream var e audit.EventStream
if e, err = audit.NewEventStream(logging.CreateTestLogger(t), tt.args.opts...); err != nil { if e, err = audit.NewEventStream(logging.CreateTestLogger(t), tt.args.opts...); err != nil {
@ -208,10 +210,6 @@ func Test_eventStream_Emit(t *testing.T) {
case <-time.After(5 * time.Second): case <-time.After(5 * time.Second):
t.Errorf("did not get all expected events in time") t.Errorf("did not get all expected events in time")
} }
} })
}
for _, tt := range tests {
t.Run(tt.name, scenario(tt))
} }
} }

View file

@ -2,9 +2,9 @@ package audit
import ( import (
"net/http" "net/http"
"reflect"
"testing" "testing"
"github.com/maxatome/go-testdeep/td"
"google.golang.org/protobuf/proto" "google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/anypb" "google.golang.org/protobuf/types/known/anypb"
@ -13,6 +13,7 @@ import (
) )
func Test_guessDetailsFromApp(t *testing.T) { func Test_guessDetailsFromApp(t *testing.T) {
t.Parallel()
mustAny := func(msg proto.Message) *anypb.Any { mustAny := func(msg proto.Message) *anypb.Any {
a, err := anypb.New(msg) a, err := anypb.New(msg)
if err != nil { if err != nil {
@ -49,14 +50,12 @@ func Test_guessDetailsFromApp(t *testing.T) {
}, },
}, },
} }
scenario := func(tt testCase) func(t *testing.T) { for _, tc := range tests {
return func(t *testing.T) { tt := tc
if got := guessDetailsFromApp(tt.args.any); !reflect.DeepEqual(got, tt.want) { t.Run(tt.name, func(t *testing.T) {
t.Errorf("guessDetailsFromApp() = %v, want %v", got, tt.want) t.Parallel()
} got := guessDetailsFromApp(tt.args.any)
} td.Cmp(t, got, tt.want)
} })
for _, tt := range tests {
t.Run(tt.name, scenario(tt))
} }
} }
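
Replacing reflect.DeepEqual with td.Cmp keeps the same pass/fail semantics here, but on mismatch it prints a structured, field-by-field diff instead of two opaque %v dumps:

got := guessDetailsFromApp(tt.args.any)
td.Cmp(t, got, tt.want)

The same swap shows up in the event reader, health checker and cert tests later in this commit.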

View file

@ -4,9 +4,10 @@ import (
"bytes" "bytes"
"encoding/hex" "encoding/hex"
"io" "io"
"reflect"
"testing" "testing"
"github.com/maxatome/go-testdeep/td"
"gitlab.com/inetmock/inetmock/pkg/audit" "gitlab.com/inetmock/inetmock/pkg/audit"
) )
@ -26,6 +27,7 @@ func mustDecodeHex(hexBytes string) io.Reader {
} }
func Test_eventReader_Read(t *testing.T) { func Test_eventReader_Read(t *testing.T) {
t.Parallel()
type fields struct { type fields struct {
source io.Reader source io.Reader
} }
@ -53,8 +55,10 @@ func Test_eventReader_Read(t *testing.T) {
wantErr: false, wantErr: false,
}, },
} }
scenario := func(tt testCase) func(t *testing.T) { for _, tc := range tests {
return func(t *testing.T) { tt := tc
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
e := audit.NewEventReader(tt.fields.source) e := audit.NewEventReader(tt.fields.source)
gotEv, err := e.Read() gotEv, err := e.Read()
if (err != nil) != tt.wantErr { if (err != nil) != tt.wantErr {
@ -62,12 +66,7 @@ func Test_eventReader_Read(t *testing.T) {
return return
} }
if err == nil && !reflect.DeepEqual(gotEv, *tt.wantEv) { td.Cmp(t, gotEv, *tt.wantEv)
t.Errorf("Read() gotEv = %v, want %v", gotEv, tt.wantEv) })
}
}
}
for _, tt := range tests {
t.Run(tt.name, scenario(tt))
} }
} }

View file

@ -13,6 +13,7 @@ import (
) )
func Test_genericSink_OnSubscribe(t *testing.T) { func Test_genericSink_OnSubscribe(t *testing.T) {
t.Parallel()
type testCase struct { type testCase struct {
name string name string
events []*audit.Event events []*audit.Event
@ -27,8 +28,10 @@ func Test_genericSink_OnSubscribe(t *testing.T) {
events: testEvents, events: testEvents,
}, },
} }
scenario := func(tt testCase) func(t *testing.T) { for _, tc := range tests {
return func(t *testing.T) { tt := tc
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
wg := new(sync.WaitGroup) wg := new(sync.WaitGroup)
wg.Add(len(tt.events)) wg.Add(len(tt.events))
@ -57,9 +60,6 @@ func Test_genericSink_OnSubscribe(t *testing.T) {
t.Errorf("not all events recorded in time") t.Errorf("not all events recorded in time")
case <-wait.ForWaitGroupDone(wg): case <-wait.ForWaitGroupDone(wg):
} }
} })
}
for _, tt := range tests {
t.Run(tt.name, scenario(tt))
} }
} }

View file

@ -18,6 +18,7 @@ import (
//nolint:dupl //nolint:dupl
func Test_logSink_OnSubscribe(t *testing.T) { func Test_logSink_OnSubscribe(t *testing.T) {
t.Parallel()
type fields struct { type fields struct {
loggerSetup func(t *testing.T, wg *sync.WaitGroup) logging.Logger loggerSetup func(t *testing.T, wg *sync.WaitGroup) logging.Logger
} }
@ -31,8 +32,8 @@ func Test_logSink_OnSubscribe(t *testing.T) {
name: "Get a single log line", name: "Get a single log line",
fields: fields{ fields: fields{
loggerSetup: func(t *testing.T, wg *sync.WaitGroup) logging.Logger { loggerSetup: func(t *testing.T, wg *sync.WaitGroup) logging.Logger {
t.Helper()
ctrl := gomock.NewController(t) ctrl := gomock.NewController(t)
t.Cleanup(ctrl.Finish)
loggerMock := logging_mock.NewMockLogger(ctrl) loggerMock := logging_mock.NewMockLogger(ctrl)
loggerMock. loggerMock.
@ -57,8 +58,8 @@ func Test_logSink_OnSubscribe(t *testing.T) {
name: "Get multiple events", name: "Get multiple events",
fields: fields{ fields: fields{
loggerSetup: func(t *testing.T, wg *sync.WaitGroup) logging.Logger { loggerSetup: func(t *testing.T, wg *sync.WaitGroup) logging.Logger {
t.Helper()
ctrl := gomock.NewController(t) ctrl := gomock.NewController(t)
t.Cleanup(ctrl.Finish)
loggerMock := logging_mock.NewMockLogger(ctrl) loggerMock := logging_mock.NewMockLogger(ctrl)
loggerMock. loggerMock.
@ -80,8 +81,10 @@ func Test_logSink_OnSubscribe(t *testing.T) {
events: testEvents, events: testEvents,
}, },
} }
scenario := func(tt testCase) func(t *testing.T) { for _, tc := range tests {
return func(t *testing.T) { tt := tc
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
wg := new(sync.WaitGroup) wg := new(sync.WaitGroup)
wg.Add(len(tt.events)) wg.Add(len(tt.events))
@ -107,10 +110,6 @@ func Test_logSink_OnSubscribe(t *testing.T) {
t.Errorf("not all events recorded in time") t.Errorf("not all events recorded in time")
case <-wait.ForWaitGroupDone(wg): case <-wait.ForWaitGroupDone(wg):
} }
} })
}
for _, tt := range tests {
t.Run(tt.name, scenario(tt))
} }
} }

View file

@ -56,6 +56,7 @@ var (
) )
func Test_writerCloserSink_OnSubscribe(t *testing.T) { func Test_writerCloserSink_OnSubscribe(t *testing.T) {
t.Parallel()
type testCase struct { type testCase struct {
name string name string
events []*audit.Event events []*audit.Event
@ -70,13 +71,14 @@ func Test_writerCloserSink_OnSubscribe(t *testing.T) {
events: testEvents, events: testEvents,
}, },
} }
scenario := func(tt testCase) func(t *testing.T) { for _, tc := range tests {
return func(t *testing.T) { tt := tc
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
wg := new(sync.WaitGroup) wg := new(sync.WaitGroup)
wg.Add(len(tt.events)) wg.Add(len(tt.events))
ctrl := gomock.NewController(t) ctrl := gomock.NewController(t)
t.Cleanup(ctrl.Finish)
writerMock := audit_mock.NewMockWriter(ctrl) writerMock := audit_mock.NewMockWriter(ctrl)
writerMock. writerMock.
@ -111,10 +113,6 @@ func Test_writerCloserSink_OnSubscribe(t *testing.T) {
t.Errorf("not all events recorded in time") t.Errorf("not all events recorded in time")
case <-wait.ForWaitGroupDone(wg): case <-wait.ForWaitGroupDone(wg):
} }
} })
}
for _, tt := range tests {
t.Run(tt.name, scenario(tt))
} }
} }
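
Several of these tests also drop their explicit defer ctrl.Finish() / t.Cleanup(ctrl.Finish) calls. That is only safe if the vendored gomock is 1.5 or newer, where NewController registers Finish through the test's Cleanup on its own; the pinned version is not visible in these hunks, so treat that as an assumption:

ctrl := gomock.NewController(t) // gomock >= 1.5 queues ctrl.Finish() via t.Cleanup automatically
writerMock := audit_mock.NewMockWriter(ctrl)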

View file

@ -1,13 +1,12 @@
// Code generated by protoc-gen-go. DO NOT EDIT. // Code generated by protoc-gen-go. DO NOT EDIT.
// versions: // versions:
// protoc-gen-go v1.25.0 // protoc-gen-go v1.26.0
// protoc v3.15.2 // protoc v3.15.2
// source: audit/v1/dns_details.proto // source: audit/v1/dns_details.proto
package v1 package v1
import ( import (
proto "github.com/golang/protobuf/proto"
protoreflect "google.golang.org/protobuf/reflect/protoreflect" protoreflect "google.golang.org/protobuf/reflect/protoreflect"
protoimpl "google.golang.org/protobuf/runtime/protoimpl" protoimpl "google.golang.org/protobuf/runtime/protoimpl"
reflect "reflect" reflect "reflect"
@ -21,10 +20,6 @@ const (
_ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20)
) )
// This is a compile-time assertion that a sufficiently up-to-date version
// of the legacy proto package is being used.
const _ = proto.ProtoPackageIsVersion4
type DNSOpCode int32 type DNSOpCode int32
const ( const (

View file

@ -1,13 +1,12 @@
// Code generated by protoc-gen-go. DO NOT EDIT. // Code generated by protoc-gen-go. DO NOT EDIT.
// versions: // versions:
// protoc-gen-go v1.25.0 // protoc-gen-go v1.26.0
// protoc v3.15.2 // protoc v3.15.2
// source: audit/v1/event_entity.proto // source: audit/v1/event_entity.proto
package v1 package v1
import ( import (
proto "github.com/golang/protobuf/proto"
protoreflect "google.golang.org/protobuf/reflect/protoreflect" protoreflect "google.golang.org/protobuf/reflect/protoreflect"
protoimpl "google.golang.org/protobuf/runtime/protoimpl" protoimpl "google.golang.org/protobuf/runtime/protoimpl"
anypb "google.golang.org/protobuf/types/known/anypb" anypb "google.golang.org/protobuf/types/known/anypb"
@ -23,10 +22,6 @@ const (
_ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20)
) )
// This is a compile-time assertion that a sufficiently up-to-date version
// of the legacy proto package is being used.
const _ = proto.ProtoPackageIsVersion4
type TransportProtocol int32 type TransportProtocol int32
const ( const (

View file

@ -1,13 +1,12 @@
// Code generated by protoc-gen-go. DO NOT EDIT. // Code generated by protoc-gen-go. DO NOT EDIT.
// versions: // versions:
// protoc-gen-go v1.25.0 // protoc-gen-go v1.26.0
// protoc v3.15.2 // protoc v3.15.2
// source: audit/v1/http_details.proto // source: audit/v1/http_details.proto
package v1 package v1
import ( import (
proto "github.com/golang/protobuf/proto"
protoreflect "google.golang.org/protobuf/reflect/protoreflect" protoreflect "google.golang.org/protobuf/reflect/protoreflect"
protoimpl "google.golang.org/protobuf/runtime/protoimpl" protoimpl "google.golang.org/protobuf/runtime/protoimpl"
reflect "reflect" reflect "reflect"
@ -21,10 +20,6 @@ const (
_ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20)
) )
// This is a compile-time assertion that a sufficiently up-to-date version
// of the legacy proto package is being used.
const _ = proto.ProtoPackageIsVersion4
type HTTPMethod int32 type HTTPMethod int32
const ( const (

View file

@ -20,6 +20,7 @@ type WriterCloserSyncer interface {
} }
func Test_eventWriter_Write(t *testing.T) { func Test_eventWriter_Write(t *testing.T) {
t.Parallel()
type args struct { type args struct {
evs []*audit.Event evs []*audit.Event
} }
@ -44,10 +45,11 @@ func Test_eventWriter_Write(t *testing.T) {
wantErr: false, wantErr: false,
}, },
} }
scenario := func(tt testCase) func(t *testing.T) { for _, tc := range tests {
return func(t *testing.T) { tt := tc
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
ctrl := gomock.NewController(t) ctrl := gomock.NewController(t)
t.Cleanup(ctrl.Finish)
writerMock := audit_mock.NewMockWriterCloserSyncer(ctrl) writerMock := audit_mock.NewMockWriterCloserSyncer(ctrl)
calls := make([]*gomock.Call, 0) calls := make([]*gomock.Call, 0)
for i := 0; i < len(tt.args.evs); i++ { for i := 0; i < len(tt.args.evs); i++ {
@ -79,10 +81,6 @@ func Test_eventWriter_Write(t *testing.T) {
t.Errorf("Write() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("Write() error = %v, wantErr %v", err, tt.wantErr)
} }
} }
} })
}
for _, tt := range tests {
t.Run(tt.name, scenario(tt))
} }
} }

View file

@ -3,6 +3,7 @@ package cert
import "testing" import "testing"
func Test_extractIPFromAddress(t *testing.T) { func Test_extractIPFromAddress(t *testing.T) {
t.Parallel()
type args struct { type args struct {
addr string addr string
} }
@ -30,8 +31,10 @@ func Test_extractIPFromAddress(t *testing.T) {
}, },
}, },
} }
scenario := func(tt testCase) func(t *testing.T) { for _, tc := range tests {
return func(t *testing.T) { tt := tc
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
got, err := extractIPFromAddress(tt.args.addr) got, err := extractIPFromAddress(tt.args.addr)
if (err != nil) != tt.wantErr { if (err != nil) != tt.wantErr {
t.Errorf("extractIPFromAddress() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("extractIPFromAddress() error = %v, wantErr %v", err, tt.wantErr)
@ -40,9 +43,6 @@ func Test_extractIPFromAddress(t *testing.T) {
if got != tt.want { if got != tt.want {
t.Errorf("extractIPFromAddress() got = %v, want %v", got, tt.want) t.Errorf("extractIPFromAddress() got = %v, want %v", got, tt.want)
} }
} })
}
for _, tt := range tests {
t.Run(tt.name, scenario(tt))
} }
} }

View file

@ -6,11 +6,11 @@ import (
"fmt" "fmt"
"os" "os"
"path" "path"
"reflect"
"testing" "testing"
"time" "time"
"github.com/golang/mock/gomock" "github.com/golang/mock/gomock"
"github.com/maxatome/go-testdeep/td"
certmock "gitlab.com/inetmock/inetmock/internal/mock/cert" certmock "gitlab.com/inetmock/inetmock/internal/mock/cert"
) )
@ -27,11 +27,10 @@ var (
) )
func Test_certShouldBeRenewed(t *testing.T) { func Test_certShouldBeRenewed(t *testing.T) {
ctrl := gomock.NewController(t) t.Parallel()
defer ctrl.Finish()
type args struct { type args struct {
timeSource TimeSource timeSourceSetup func(ctrl *gomock.Controller) TimeSource
cert *x509.Certificate cert *x509.Certificate
} }
type testCase struct { type testCase struct {
name string name string
@ -42,7 +41,7 @@ func Test_certShouldBeRenewed(t *testing.T) {
{ {
name: "Expect cert should not be renewed right after creation", name: "Expect cert should not be renewed right after creation",
args: args{ args: args{
timeSource: func() TimeSource { timeSourceSetup: func(ctrl *gomock.Controller) TimeSource {
tsMock := certmock.NewMockTimeSource(ctrl) tsMock := certmock.NewMockTimeSource(ctrl)
tsMock. tsMock.
EXPECT(). EXPECT().
@ -50,7 +49,7 @@ func Test_certShouldBeRenewed(t *testing.T) {
Return(time.Now().UTC()). Return(time.Now().UTC()).
Times(1) Times(1)
return tsMock return tsMock
}(), },
cert: &x509.Certificate{ cert: &x509.Certificate{
NotAfter: time.Now().UTC().Add(serverRelativeValidity), NotAfter: time.Now().UTC().Add(serverRelativeValidity),
NotBefore: time.Now().UTC().Add(-serverRelativeValidity), NotBefore: time.Now().UTC().Add(-serverRelativeValidity),
@ -61,7 +60,7 @@ func Test_certShouldBeRenewed(t *testing.T) {
{ {
name: "Expect cert should be renewed if the remaining lifetime is less than a quarter of the total lifetime", name: "Expect cert should be renewed if the remaining lifetime is less than a quarter of the total lifetime",
args: args{ args: args{
timeSource: func() TimeSource { timeSourceSetup: func(ctrl *gomock.Controller) TimeSource {
tsMock := certmock.NewMockTimeSource(ctrl) tsMock := certmock.NewMockTimeSource(ctrl)
tsMock. tsMock.
EXPECT(). EXPECT().
@ -69,7 +68,7 @@ func Test_certShouldBeRenewed(t *testing.T) {
Return(time.Now().UTC().Add(serverRelativeValidity/2 + 1*time.Hour)). Return(time.Now().UTC().Add(serverRelativeValidity/2 + 1*time.Hour)).
Times(1) Times(1)
return tsMock return tsMock
}(), },
cert: &x509.Certificate{ cert: &x509.Certificate{
NotAfter: time.Now().UTC().Add(serverRelativeValidity), NotAfter: time.Now().UTC().Add(serverRelativeValidity),
NotBefore: time.Now().UTC().Add(-serverRelativeValidity), NotBefore: time.Now().UTC().Add(-serverRelativeValidity),
@ -78,41 +77,51 @@ func Test_certShouldBeRenewed(t *testing.T) {
want: true, want: true,
}, },
} }
scenario := func(tt testCase) func(t *testing.T) { for _, tc := range tests {
return func(t *testing.T) { tt := tc
if got := certShouldBeRenewed(tt.args.timeSource, tt.args.cert); got != tt.want { t.Run(tt.name, func(t *testing.T) {
t.Parallel()
ctrl := gomock.NewController(t)
if got := certShouldBeRenewed(tt.args.timeSourceSetup(ctrl), tt.args.cert); got != tt.want {
t.Errorf("certShouldBeRenewed() = %v, want %v", got, tt.want) t.Errorf("certShouldBeRenewed() = %v, want %v", got, tt.want)
} }
} })
}
for _, tt := range tests {
t.Run(tt.name, scenario(tt))
} }
} }
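
Building the TimeSource mock inside the subtest, via a timeSourceSetup func(*gomock.Controller) field instead of an eagerly constructed mock, is what makes t.Parallel() safe here: each subtest owns its controller and its expectations. Trimmed to its shape (the EXPECT() chain is elided, just as it is hidden in the hunk):

timeSourceSetup: func(ctrl *gomock.Controller) TimeSource {
    tsMock := certmock.NewMockTimeSource(ctrl)
    // per-case EXPECT() calls go here
    return tsMock
},

// and inside the parallel subtest:
ctrl := gomock.NewController(t)
if got := certShouldBeRenewed(tt.args.timeSourceSetup(ctrl), tt.args.cert); got != tt.want {
    t.Errorf("certShouldBeRenewed() = %v, want %v", got, tt.want)
}
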
//nolint:funlen
func Test_fileSystemCache_Get(t *testing.T) { func Test_fileSystemCache_Get(t *testing.T) {
t.Parallel()
certGen := setupCertGen()
caCrt, _ := certGen.CACert(GenerationOptions{
CommonName: caCN,
})
srvCrt, _ := certGen.ServerCert(GenerationOptions{
CommonName: cnLocalhost,
}, caCrt)
type fields struct { type fields struct {
certCachePath string inMemCache map[string]*tls.Certificate
inMemCache map[string]*tls.Certificate timeSource TimeSource
timeSource TimeSource
} }
type args struct { type args struct {
cn string cn string
} }
type testCase struct { type testCase struct {
name string name string
fields fields fields fields
args args polluteCertCache bool
wantOk bool args args
wantOk bool
} }
tests := []testCase{ tests := []testCase{
{ {
name: "Get a miss when no cert is present", name: "Get a miss when no cert is present",
fields: fields{ fields: fields{
certCachePath: os.TempDir(), inMemCache: make(map[string]*tls.Certificate),
inMemCache: make(map[string]*tls.Certificate), timeSource: NewTimeSource(),
timeSource: NewTimeSource(),
}, },
args: args{ args: args{
cnLocalhost, cnLocalhost,
@ -121,25 +130,12 @@ func Test_fileSystemCache_Get(t *testing.T) {
}, },
{ {
name: "Get a prepared certificate from the memory cache", name: "Get a prepared certificate from the memory cache",
fields: func() fields { fields: fields{
certGen := setupCertGen() inMemCache: map[string]*tls.Certificate{
cnLocalhost: srvCrt,
caCrt, _ := certGen.CACert(GenerationOptions{ },
CommonName: caCN, timeSource: NewTimeSource(),
}) },
srvCrt, _ := certGen.ServerCert(GenerationOptions{
CommonName: cnLocalhost,
}, caCrt)
return fields{
certCachePath: os.TempDir(),
inMemCache: map[string]*tls.Certificate{
cnLocalhost: srvCrt,
},
timeSource: NewTimeSource(),
}
}(),
args: args{ args: args{
cn: cnLocalhost, cn: cnLocalhost,
}, },
@ -147,67 +143,50 @@ func Test_fileSystemCache_Get(t *testing.T) {
}, },
{ {
name: "Get a prepared certificate from the file system", name: "Get a prepared certificate from the file system",
fields: func() fields { fields: fields{
certGen := setupCertGen() inMemCache: make(map[string]*tls.Certificate),
timeSource: NewTimeSource(),
caCrt, _ := certGen.CACert(GenerationOptions{ },
CommonName: "INetMock",
})
srvCrt, _ := certGen.ServerCert(GenerationOptions{
CommonName: serverCN,
}, caCrt)
pem := NewPEM(srvCrt)
if err := pem.Write(serverCN, os.TempDir()); err != nil {
panic(err)
}
t.Cleanup(func() {
for _, f := range []string{
path.Join(os.TempDir(), fmt.Sprintf("%s.pem", serverCN)),
path.Join(os.TempDir(), fmt.Sprintf("%s.key", serverCN)),
} {
_ = os.Remove(f)
}
})
return fields{
certCachePath: os.TempDir(),
inMemCache: make(map[string]*tls.Certificate),
timeSource: NewTimeSource(),
}
}(),
args: args{ args: args{
cn: serverCN, cn: serverCN,
}, },
wantOk: true, polluteCertCache: true,
wantOk: true,
}, },
} }
scenario := func(tt testCase) func(t *testing.T) { for _, tc := range tests {
return func(t *testing.T) { tt := tc
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
dir := t.TempDir()
f := &fileSystemCache{ f := &fileSystemCache{
certCachePath: tt.fields.certCachePath, certCachePath: dir,
inMemCache: tt.fields.inMemCache, inMemCache: tt.fields.inMemCache,
timeSource: tt.fields.timeSource, timeSource: tt.fields.timeSource,
} }
if tt.polluteCertCache {
pem := NewPEM(srvCrt)
if err := pem.Write(serverCN, dir); err != nil {
t.Fatalf("polluteCertCache error = %v", err)
}
}
gotCrt, gotOk := f.Get(tt.args.cn) gotCrt, gotOk := f.Get(tt.args.cn)
if gotOk && (gotCrt == nil || !reflect.DeepEqual(reflect.TypeOf(new(tls.Certificate)), reflect.TypeOf(gotCrt))) { if gotOk && (gotCrt == nil || !td.CmpIsa(t, gotCrt, new(tls.Certificate))) {
t.Errorf("Wanted propert certificate but got %v", gotCrt) t.Errorf("Wanted propert certificate but got %v", gotCrt)
} }
if gotOk != tt.wantOk { if gotOk != tt.wantOk {
t.Errorf("Get() gotOk = %v, want %v", gotOk, tt.wantOk) t.Errorf("Get() gotOk = %v, want %v", gotOk, tt.wantOk)
} }
} })
}
for _, tt := range tests {
t.Run(tt.name, scenario(tt))
} }
} }
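
Moving the cache directory to t.TempDir(), together with the polluteCertCache flag that writes the PEM pair on demand, removes the os.TempDir() bookkeeping and the manual removal of the .pem/.key files; the directory is deleted automatically when the subtest finishes. The core of the new arrangement, condensed from the hunk above:

dir := t.TempDir() // removed automatically after the (sub)test completes
f := &fileSystemCache{
    certCachePath: dir,
    inMemCache:    tt.fields.inMemCache,
    timeSource:    tt.fields.timeSource,
}
if tt.polluteCertCache {
    if err := NewPEM(srvCrt).Write(serverCN, dir); err != nil {
        t.Fatalf("polluteCertCache error = %v", err)
    }
}
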
func Test_fileSystemCache_Put(t *testing.T) { func Test_fileSystemCache_Put(t *testing.T) {
t.Parallel()
type fields struct { type fields struct {
certCachePath string certCachePath string
inMemCache map[string]*tls.Certificate inMemCache map[string]*tls.Certificate
@ -281,8 +260,10 @@ func Test_fileSystemCache_Put(t *testing.T) {
wantErr: false, wantErr: false,
}, },
} }
scenario := func(tt testCase) func(t *testing.T) { for _, tc := range tests {
return func(t *testing.T) { tt := tc
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
f := &fileSystemCache{ f := &fileSystemCache{
certCachePath: tt.fields.certCachePath, certCachePath: tt.fields.certCachePath,
inMemCache: tt.fields.inMemCache, inMemCache: tt.fields.inMemCache,
@ -291,10 +272,7 @@ func Test_fileSystemCache_Put(t *testing.T) {
if err := f.Put(tt.args.cert); (err != nil) != tt.wantErr { if err := f.Put(tt.args.cert); (err != nil) != tt.wantErr {
t.Errorf("Put() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("Put() error = %v, wantErr %v", err, tt.wantErr)
} }
} })
}
for _, tt := range tests {
t.Run(tt.name, scenario(tt))
} }
} }

View file

@ -1,11 +1,13 @@
package cert package cert
import ( import (
"reflect"
"testing" "testing"
"github.com/maxatome/go-testdeep/td"
) )
func Test_certOptionsDefaulter(t *testing.T) { func Test_certOptionsDefaulter(t *testing.T) {
t.Parallel()
type testCase struct { type testCase struct {
name string name string
arg GenerationOptions arg GenerationOptions
@ -60,17 +62,14 @@ func Test_certOptionsDefaulter(t *testing.T) {
}, },
}, },
} }
scenario := func(tt testCase) func(t *testing.T) { for _, tc := range tests {
return func(t *testing.T) { tt := tc
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
if err := applyDefaultGenerationOptions(&tt.arg); err != nil { if err := applyDefaultGenerationOptions(&tt.arg); err != nil {
t.Errorf("applyDefaultGenerationOptions() error = %v", err) t.Errorf("applyDefaultGenerationOptions() error = %v", err)
} }
if !reflect.DeepEqual(tt.expected, tt.arg) { td.Cmp(t, tt.arg, tt.expected)
t.Errorf("Apply defaulter expected=%v got=%v", tt.expected, tt.arg) })
}
}
}
for _, tt := range tests {
t.Run(tt.name, scenario(tt))
} }
} }

View file

@ -30,12 +30,10 @@ type Generator interface {
ServerCert(options GenerationOptions, ca *tls.Certificate) (*tls.Certificate, error) ServerCert(options GenerationOptions, ca *tls.Certificate) (*tls.Certificate, error)
} }
//nolint:gocritic
func NewDefaultGenerator(options Options) Generator { func NewDefaultGenerator(options Options) Generator {
return NewGenerator(options, NewTimeSource(), defaultKeyProvider(options)) return NewGenerator(options, NewTimeSource(), defaultKeyProvider(options))
} }
//nolint:gocritic
func NewGenerator(options Options, source TimeSource, provider KeyProvider) Generator { func NewGenerator(options Options, source TimeSource, provider KeyProvider) Generator {
return &generator{ return &generator{
options: options, options: options,

View file

@ -1,10 +0,0 @@
package cert
/*func init() {
config.AddDefaultValue(certCachePathConfigKey, os.TempDir())
config.AddDefaultValue(ecdsaCurveConfigKey, string(CurveTypeED25519))
config.AddDefaultValue(caCertValidityNotBeforeKey, defaultCAValidityDuration)
config.AddDefaultValue(caCertValidityNotAfterKey, defaultCAValidityDuration)
config.AddDefaultValue(serverCertValidityNotBeforeKey, defaultServerValidityDuration)
config.AddDefaultValue(serverCertValidityNotAfterKey, defaultServerValidityDuration)
}*/

View file

@ -1,131 +0,0 @@
package cert
import (
"os"
"strings"
"testing"
"time"
"github.com/spf13/viper"
)
func readViper(cfg string) *viper.Viper {
vpr := viper.New()
vpr.SetConfigType("yaml")
if err := vpr.ReadConfig(strings.NewReader(cfg)); err != nil {
panic(err)
}
return vpr
}
//nolint:funlen
func Test_loadFromConfig(t *testing.T) {
type args struct {
config *viper.Viper
}
tests := []struct {
name string
args args
want Options
wantErr bool
}{
{
name: "Parse valid TLS configuration",
wantErr: false,
args: args{
config: readViper(`
tls:
ecdsaCurve: P256
validity:
ca:
notBeforeRelative: 17520h
notAfterRelative: 17520h
server:
NotBeforeRelative: 168h
NotAfterRelative: 168h
rootCaCert:
publicKey: ./ca.pem
privateKey: ./ca.key
certCachePath: /tmp/inetmock/
`),
},
want: Options{
RootCACert: File{
PublicKeyPath: "./ca.pem",
PrivateKeyPath: "./ca.key",
},
CertCachePath: "/tmp/inetmock/",
Curve: CurveTypeP256,
Validity: ValidityByPurpose{
CA: ValidityDuration{
NotBeforeRelative: 17520 * time.Hour,
NotAfterRelative: 17520 * time.Hour,
},
Server: ValidityDuration{
NotBeforeRelative: 168 * time.Hour,
NotAfterRelative: 168 * time.Hour,
},
},
},
},
{
name: "Get an error if CA public key path is missing",
args: args{
readViper(`
tls:
rootCaCert:
privateKey: ./ca.key
`),
},
want: Options{},
wantErr: true,
},
{
name: "Get an error if CA private key path is missing",
args: args{
readViper(`
tls:
rootCaCert:
publicKey: ./ca.pem
`),
},
want: Options{},
wantErr: true,
},
{
name: "Get default options if all required fields are set",
args: args{
readViper(`
tls:
rootCaCert:
publicKey: ./ca.pem
privateKey: ./ca.key
`),
},
want: Options{
RootCACert: File{
PublicKeyPath: "./ca.pem",
PrivateKeyPath: "./ca.key",
},
CertCachePath: os.TempDir(),
Curve: CurveTypeED25519,
Validity: ValidityByPurpose{
CA: ValidityDuration{
NotBeforeRelative: 17520 * time.Hour,
NotAfterRelative: 17520 * time.Hour,
},
Server: ValidityDuration{
NotBeforeRelative: 168 * time.Hour,
NotAfterRelative: 168 * time.Hour,
},
},
},
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
})
}
}

View file

@ -33,7 +33,6 @@ type Store interface {
TLSConfig() *tls.Config TLSConfig() *tls.Config
} }
//nolint:gocritic
func NewDefaultStore( func NewDefaultStore(
options Options, options Options,
logger logging.Logger, logger logging.Logger,
@ -47,7 +46,6 @@ func NewDefaultStore(
) )
} }
//nolint:gocritic
func NewStore( func NewStore(
options Options, options Options,
cache Cache, cache Cache,
@ -149,7 +147,6 @@ func (s *store) GetCertificate(serverName, ip string) (cert *tls.Certificate, er
return return
} }
//nolint:gocritic
func privateKeyForCurve(options Options) (privateKey interface{}, err error) { func privateKeyForCurve(options Options) (privateKey interface{}, err error) {
switch options.Curve { switch options.Curve {
case CurveTypeP224: case CurveTypeP224:

View file

@ -1,12 +1,14 @@
package health package health
import ( import (
"reflect"
"testing" "testing"
"github.com/maxatome/go-testdeep/td"
) )
//nolint:funlen //nolint:funlen
func Test_checker_IsHealthy(t *testing.T) { func Test_checker_IsHealthy(t *testing.T) {
t.Parallel()
type fields struct { type fields struct {
componentChecks map[string]Check componentChecks map[string]Check
} }
@ -145,17 +147,15 @@ func Test_checker_IsHealthy(t *testing.T) {
}, },
}, },
} }
scenario := func(tt testCase) func(t *testing.T) { for _, tc := range tests {
return func(t *testing.T) { tt := tc
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
c := &checker{ c := &checker{
componentChecks: tt.fields.componentChecks, componentChecks: tt.fields.componentChecks,
} }
if gotR := c.IsHealthy(); !reflect.DeepEqual(gotR, tt.wantR) { gotR := c.IsHealthy()
t.Errorf("IsHealthy() = %v, want %v", gotR, tt.wantR) td.Cmp(t, gotR, tt.wantR)
} })
}
}
for _, tt := range tests {
t.Run(tt.name, scenario(tt))
} }
} }

View file

@ -50,6 +50,7 @@ func CreateLogger() (Logger, error) {
} }
func CreateTestLogger(tb testing.TB) Logger { func CreateTestLogger(tb testing.TB) Logger {
tb.Helper()
logger := &testLogger{ logger := &testLogger{
testRunning: make(chan struct{}), testRunning: make(chan struct{}),
tb: tb, tb: tb,

View file

@ -9,6 +9,7 @@ import (
//nolint:funlen //nolint:funlen
func TestParseLevel(t *testing.T) { func TestParseLevel(t *testing.T) {
t.Parallel()
type args struct { type args struct {
levelString string levelString string
} }
@ -103,19 +104,19 @@ func TestParseLevel(t *testing.T) {
want: zap.NewAtomicLevelAt(zap.InfoLevel), want: zap.NewAtomicLevelAt(zap.InfoLevel),
}, },
} }
scenario := func(tt testCase) func(t *testing.T) { for _, tc := range tests {
return func(t *testing.T) { tt := tc
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
if got := ParseLevel(tt.args.levelString); !reflect.DeepEqual(got, tt.want) { if got := ParseLevel(tt.args.levelString); !reflect.DeepEqual(got, tt.want) {
t.Errorf("ParseLevel() = %v, want %v", got, tt.want) t.Errorf("ParseLevel() = %v, want %v", got, tt.want)
} }
} })
}
for _, tt := range tests {
t.Run(tt.name, scenario(tt))
} }
} }
func TestConfigureLogging(t *testing.T) { func TestConfigureLogging(t *testing.T) {
t.Parallel()
type args struct { type args struct {
level zap.AtomicLevel level zap.AtomicLevel
developmentLogging bool developmentLogging bool
@ -151,8 +152,10 @@ func TestConfigureLogging(t *testing.T) {
}, },
}, },
} }
scenario := func(tt testCase) func(t *testing.T) { for _, tc := range tests {
return func(t *testing.T) { tt := tc
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
ConfigureLogging(tt.args.level, tt.args.developmentLogging, tt.args.initialFields) ConfigureLogging(tt.args.level, tt.args.developmentLogging, tt.args.initialFields)
if loggingConfig.Development != tt.args.developmentLogging { if loggingConfig.Development != tt.args.developmentLogging {
t.Errorf("loggingConfig.Development = %t, want %t", loggingConfig.Development, tt.args.developmentLogging) t.Errorf("loggingConfig.Development = %t, want %t", loggingConfig.Development, tt.args.developmentLogging)
@ -168,9 +171,6 @@ func TestConfigureLogging(t *testing.T) {
t.Errorf("loggingConfig.InitialFields = %v, want %v", loggingConfig.InitialFields, tt.args.initialFields) t.Errorf("loggingConfig.InitialFields = %v, want %v", loggingConfig.InitialFields, tt.args.initialFields)
return return
} }
} })
}
for _, tt := range tests {
t.Run(tt.name, scenario(tt))
} }
} }

View file

@ -1,58 +1,47 @@
package path package path
import ( import (
"fmt"
"io/ioutil"
"math/rand"
"os" "os"
"path" "path/filepath"
"testing" "testing"
"github.com/maxatome/go-testdeep/td"
) )
func TestFileExists(t *testing.T) { func TestFileExists(t *testing.T) {
tmpFile, err := ioutil.TempFile("", "inetmock") t.Parallel()
if err != nil {
t.Errorf("failed to create temp file: %v", err)
}
defer func() {
tmpFile.Close()
os.Remove(tmpFile.Name())
}()
type args struct {
filename string
}
type testCase struct { type testCase struct {
name string name string
args args
want bool want bool
} }
tests := []testCase{ tests := []testCase{
{ {
name: "Ensure temp file exists", name: "Ensure temp file exists",
want: true, want: true,
args: args{
filename: tmpFile.Name(),
},
}, },
{ {
name: "Ensure random file name does not exist", name: "Ensure random file name does not exist",
want: false, want: false,
args: args{
//nolint:gosec
filename: path.Join(os.TempDir(), fmt.Sprintf("asdf-%d", rand.Uint32())),
},
}, },
} }
scenario := func(tt testCase) func(t *testing.T) { for _, tc := range tests {
return func(t *testing.T) { tt := tc
if got := FileExists(tt.args.filename); got != tt.want { t.Run(tt.name, func(t *testing.T) {
t.Parallel()
tempDir := t.TempDir()
var filePath = filepath.Join(tempDir, "nonexistent")
if tt.want {
tmpFile, err := os.CreateTemp(tempDir, "testFileExists.*.tmp")
if td.CmpNoError(t, err) {
t.Cleanup(func() {
_ = tmpFile.Close()
})
filePath = filepath.Join(tempDir, filepath.Base(tmpFile.Name()))
}
}
if got := FileExists(filePath); got != tt.want {
t.Errorf("FileExists() = %v, want %v", got, tt.want) t.Errorf("FileExists() = %v, want %v", got, tt.want)
} }
} })
}
for _, tt := range tests {
t.Run(tt.name, scenario(tt))
} }
} }
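
The rewritten FileExists test no longer leaks files into the system temp directory: the positive case is created with os.CreateTemp inside a per-test t.TempDir(), and the negative case is a name that cannot exist there. The positive branch, condensed:

tempDir := t.TempDir()
tmpFile, err := os.CreateTemp(tempDir, "testFileExists.*.tmp")
if td.CmpNoError(t, err) {
    t.Cleanup(func() { _ = tmpFile.Close() })
    filePath := filepath.Join(tempDir, filepath.Base(tmpFile.Name()))
    if !FileExists(filePath) {
        t.Errorf("FileExists(%s) = false, want true", filePath)
    }
}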

View file

@ -1,13 +1,12 @@
// Code generated by protoc-gen-go. DO NOT EDIT. // Code generated by protoc-gen-go. DO NOT EDIT.
// versions: // versions:
// protoc-gen-go v1.25.0 // protoc-gen-go v1.26.0
// protoc v3.15.2 // protoc v3.15.2
// source: rpc/v1/audit.proto // source: rpc/v1/audit.proto
package v1 package v1
import ( import (
proto "github.com/golang/protobuf/proto"
v1 "gitlab.com/inetmock/inetmock/pkg/audit/v1" v1 "gitlab.com/inetmock/inetmock/pkg/audit/v1"
protoreflect "google.golang.org/protobuf/reflect/protoreflect" protoreflect "google.golang.org/protobuf/reflect/protoreflect"
protoimpl "google.golang.org/protobuf/runtime/protoimpl" protoimpl "google.golang.org/protobuf/runtime/protoimpl"
@ -22,10 +21,6 @@ const (
_ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20)
) )
// This is a compile-time assertion that a sufficiently up-to-date version
// of the legacy proto package is being used.
const _ = proto.ProtoPackageIsVersion4
type WatchEventsRequest struct { type WatchEventsRequest struct {
state protoimpl.MessageState state protoimpl.MessageState
sizeCache protoimpl.SizeCache sizeCache protoimpl.SizeCache
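The dropped proto.ProtoPackageIsVersion4 assertion comes with the regeneration under protoc-gen-go v1.26, which emits code that depends only on the google.golang.org/protobuf runtime instead of the legacy github.com/golang/protobuf module. Callers keep working against the same API; a round-trip through the new proto package still treats the generated message as an ordinary proto.Message. A hedged sketch (the rpcv1 import path is a placeholder, not taken from the repo):

package v1_test

import (
	"testing"

	"google.golang.org/protobuf/proto"

	// Placeholder import path; substitute the actual location of the generated code.
	rpcv1 "gitlab.com/inetmock/inetmock/pkg/rpc/v1"
)

func TestWatchEventsRequestRoundTrip(t *testing.T) {
	t.Parallel()

	// A zero-valued generated message is an ordinary proto.Message.
	in := new(rpcv1.WatchEventsRequest)

	raw, err := proto.Marshal(in)
	if err != nil {
		t.Fatalf("marshal: %v", err)
	}

	out := new(rpcv1.WatchEventsRequest)
	if err := proto.Unmarshal(raw, out); err != nil {
		t.Fatalf("unmarshal: %v", err)
	}

	if !proto.Equal(in, out) {
		t.Errorf("round-trip mismatch: got %v, want %v", out, in)
	}
}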

View file

@@ -11,6 +11,7 @@ import (
// This is a compile-time assertion to ensure that this generated file // This is a compile-time assertion to ensure that this generated file
// is compatible with the grpc package it is being compiled against. // is compatible with the grpc package it is being compiled against.
// Requires gRPC-Go v1.32.0 or later.
const _ = grpc.SupportPackageIsVersion7 const _ = grpc.SupportPackageIsVersion7
// AuditServiceClient is the client API for AuditService service. // AuditServiceClient is the client API for AuditService service.
@@ -32,7 +33,7 @@ func NewAuditServiceClient(cc grpc.ClientConnInterface) AuditServiceClient {
} }
func (c *auditServiceClient) WatchEvents(ctx context.Context, in *WatchEventsRequest, opts ...grpc.CallOption) (AuditService_WatchEventsClient, error) { func (c *auditServiceClient) WatchEvents(ctx context.Context, in *WatchEventsRequest, opts ...grpc.CallOption) (AuditService_WatchEventsClient, error) {
stream, err := c.cc.NewStream(ctx, &_AuditService_serviceDesc.Streams[0], "/inetmock.rpc.v1.AuditService/WatchEvents", opts...) stream, err := c.cc.NewStream(ctx, &AuditService_ServiceDesc.Streams[0], "/inetmock.rpc.v1.AuditService/WatchEvents", opts...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -127,7 +128,7 @@ type UnsafeAuditServiceServer interface {
} }
func RegisterAuditServiceServer(s grpc.ServiceRegistrar, srv AuditServiceServer) { func RegisterAuditServiceServer(s grpc.ServiceRegistrar, srv AuditServiceServer) {
s.RegisterService(&_AuditService_serviceDesc, srv) s.RegisterService(&AuditService_ServiceDesc, srv)
} }
func _AuditService_WatchEvents_Handler(srv interface{}, stream grpc.ServerStream) error { func _AuditService_WatchEvents_Handler(srv interface{}, stream grpc.ServerStream) error {
@@ -205,7 +206,10 @@ func _AuditService_ListSinks_Handler(srv interface{}, ctx context.Context, dec f
return interceptor(ctx, in, info, handler) return interceptor(ctx, in, info, handler)
} }
var _AuditService_serviceDesc = grpc.ServiceDesc{ // AuditService_ServiceDesc is the grpc.ServiceDesc for AuditService service.
// It's only intended for direct use with grpc.RegisterService,
// and not to be introspected or modified (even as a copy)
var AuditService_ServiceDesc = grpc.ServiceDesc{
ServiceName: "inetmock.rpc.v1.AuditService", ServiceName: "inetmock.rpc.v1.AuditService",
HandlerType: (*AuditServiceServer)(nil), HandlerType: (*AuditServiceServer)(nil),
Methods: []grpc.MethodDesc{ Methods: []grpc.MethodDesc{
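The descriptor rename from the unexported _AuditService_serviceDesc to the exported AuditService_ServiceDesc matches newer protoc-gen-go-grpc output; typical servers are unaffected because they register through the generated helper, which now forwards to the exported descriptor. A hedged sketch of both registration paths (the import path and the embedded Unimplemented type are assumed from standard generated output, not shown in this diff):

package main

import (
	"log"
	"net"

	"google.golang.org/grpc"

	// Placeholder import path; substitute the actual location of the generated code.
	rpcv1 "gitlab.com/inetmock/inetmock/pkg/rpc/v1"
)

// auditServer is a placeholder implementation that relies on the generated
// Unimplemented embedding, so every RPC answers with codes.Unimplemented.
type auditServer struct {
	rpcv1.UnimplementedAuditServiceServer
}

func main() {
	lis, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		log.Fatal(err)
	}

	srv := grpc.NewServer()

	// Preferred: the generated helper, unchanged by the rename.
	rpcv1.RegisterAuditServiceServer(srv, &auditServer{})

	// Equivalent low-level form, now using the exported descriptor:
	// srv.RegisterService(&rpcv1.AuditService_ServiceDesc, &auditServer{})

	log.Fatal(srv.Serve(lis))
}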

View file

@@ -1,13 +1,12 @@
// Code generated by protoc-gen-go. DO NOT EDIT. // Code generated by protoc-gen-go. DO NOT EDIT.
// versions: // versions:
// protoc-gen-go v1.25.0 // protoc-gen-go v1.26.0
// protoc v3.15.2 // protoc v3.15.2
// source: rpc/v1/health.proto // source: rpc/v1/health.proto
package v1 package v1
import ( import (
proto "github.com/golang/protobuf/proto"
protoreflect "google.golang.org/protobuf/reflect/protoreflect" protoreflect "google.golang.org/protobuf/reflect/protoreflect"
protoimpl "google.golang.org/protobuf/runtime/protoimpl" protoimpl "google.golang.org/protobuf/runtime/protoimpl"
reflect "reflect" reflect "reflect"
@@ -21,10 +20,6 @@ const (
_ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20)
) )
// This is a compile-time assertion that a sufficiently up-to-date version
// of the legacy proto package is being used.
const _ = proto.ProtoPackageIsVersion4
type HealthState int32 type HealthState int32
const ( const (

View file

@@ -11,6 +11,7 @@ import (
// This is a compile-time assertion to ensure that this generated file // This is a compile-time assertion to ensure that this generated file
// is compatible with the grpc package it is being compiled against. // is compatible with the grpc package it is being compiled against.
// Requires gRPC-Go v1.32.0 or later.
const _ = grpc.SupportPackageIsVersion7 const _ = grpc.SupportPackageIsVersion7
// HealthServiceClient is the client API for HealthService service. // HealthServiceClient is the client API for HealthService service.
@@ -62,7 +63,7 @@ type UnsafeHealthServiceServer interface {
} }
func RegisterHealthServiceServer(s grpc.ServiceRegistrar, srv HealthServiceServer) { func RegisterHealthServiceServer(s grpc.ServiceRegistrar, srv HealthServiceServer) {
s.RegisterService(&_HealthService_serviceDesc, srv) s.RegisterService(&HealthService_ServiceDesc, srv)
} }
func _HealthService_GetHealth_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { func _HealthService_GetHealth_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
@@ -83,7 +84,10 @@ func _HealthService_GetHealth_Handler(srv interface{}, ctx context.Context, dec
return interceptor(ctx, in, info, handler) return interceptor(ctx, in, info, handler)
} }
var _HealthService_serviceDesc = grpc.ServiceDesc{ // HealthService_ServiceDesc is the grpc.ServiceDesc for HealthService service.
// It's only intended for direct use with grpc.RegisterService,
// and not to be introspected or modified (even as a copy)
var HealthService_ServiceDesc = grpc.ServiceDesc{
ServiceName: "inetmock.rpc.v1.HealthService", ServiceName: "inetmock.rpc.v1.HealthService",
HandlerType: (*HealthServiceServer)(nil), HandlerType: (*HealthServiceServer)(nil),
Methods: []grpc.MethodDesc{ Methods: []grpc.MethodDesc{

View file

@@ -1,13 +1,12 @@
// Code generated by protoc-gen-go. DO NOT EDIT. // Code generated by protoc-gen-go. DO NOT EDIT.
// versions: // versions:
// protoc-gen-go v1.25.0 // protoc-gen-go v1.26.0
// protoc v3.15.2 // protoc v3.15.2
// source: rpc/v1/pcap.proto // source: rpc/v1/pcap.proto
package v1 package v1
import ( import (
proto "github.com/golang/protobuf/proto"
protoreflect "google.golang.org/protobuf/reflect/protoreflect" protoreflect "google.golang.org/protobuf/reflect/protoreflect"
protoimpl "google.golang.org/protobuf/runtime/protoimpl" protoimpl "google.golang.org/protobuf/runtime/protoimpl"
durationpb "google.golang.org/protobuf/types/known/durationpb" durationpb "google.golang.org/protobuf/types/known/durationpb"
@@ -22,10 +21,6 @@ const (
_ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20)
) )
// This is a compile-time assertion that a sufficiently up-to-date version
// of the legacy proto package is being used.
const _ = proto.ProtoPackageIsVersion4
type ListAvailableDevicesRequest struct { type ListAvailableDevicesRequest struct {
state protoimpl.MessageState state protoimpl.MessageState
sizeCache protoimpl.SizeCache sizeCache protoimpl.SizeCache

View file

@@ -11,6 +11,7 @@ import (
// This is a compile-time assertion to ensure that this generated file // This is a compile-time assertion to ensure that this generated file
// is compatible with the grpc package it is being compiled against. // is compatible with the grpc package it is being compiled against.
// Requires gRPC-Go v1.32.0 or later.
const _ = grpc.SupportPackageIsVersion7 const _ = grpc.SupportPackageIsVersion7
// PCAPServiceClient is the client API for PCAPService service. // PCAPServiceClient is the client API for PCAPService service.
@@ -104,7 +105,7 @@ type UnsafePCAPServiceServer interface {
} }
func RegisterPCAPServiceServer(s grpc.ServiceRegistrar, srv PCAPServiceServer) { func RegisterPCAPServiceServer(s grpc.ServiceRegistrar, srv PCAPServiceServer) {
s.RegisterService(&_PCAPService_serviceDesc, srv) s.RegisterService(&PCAPService_ServiceDesc, srv)
} }
func _PCAPService_ListAvailableDevices_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { func _PCAPService_ListAvailableDevices_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
@@ -179,7 +180,10 @@ func _PCAPService_StopPCAPFileRecording_Handler(srv interface{}, ctx context.Con
return interceptor(ctx, in, info, handler) return interceptor(ctx, in, info, handler)
} }
var _PCAPService_serviceDesc = grpc.ServiceDesc{ // PCAPService_ServiceDesc is the grpc.ServiceDesc for PCAPService service.
// It's only intended for direct use with grpc.RegisterService,
// and not to be introspected or modified (even as a copy)
var PCAPService_ServiceDesc = grpc.ServiceDesc{
ServiceName: "inetmock.rpc.v1.PCAPService", ServiceName: "inetmock.rpc.v1.PCAPService",
HandlerType: (*PCAPServiceServer)(nil), HandlerType: (*PCAPServiceServer)(nil),
Methods: []grpc.MethodDesc{ Methods: []grpc.MethodDesc{