Cleanup
This commit is contained in:
parent
29662726f6
commit
6442f9f915
63 changed files with 1889 additions and 1334 deletions
|
@ -18,6 +18,9 @@ linters-settings:
|
|||
- ifElseChain
|
||||
- octalLiteral
|
||||
- wrapperFunc
|
||||
settings:
|
||||
hugeParam:
|
||||
sizeThreshold: 200
|
||||
gocyclo:
|
||||
min-complexity: 15
|
||||
goimports:
|
||||
|
@ -57,6 +60,7 @@ linters:
|
|||
- dupl
|
||||
- errcheck
|
||||
- exhaustive
|
||||
- exportloopref
|
||||
- funlen
|
||||
- goconst
|
||||
- gocritic
|
||||
|
@ -71,12 +75,15 @@ linters:
|
|||
- lll
|
||||
- misspell
|
||||
- nakedret
|
||||
- nestif
|
||||
- noctx
|
||||
- nolintlint
|
||||
- scopelint
|
||||
- paralleltest
|
||||
- prealloc
|
||||
- staticcheck
|
||||
- structcheck
|
||||
- stylecheck
|
||||
- thelper
|
||||
- typecheck
|
||||
- unconvert
|
||||
- unparam
|
||||
|
@ -92,11 +99,6 @@ issues:
|
|||
linters:
|
||||
- gomnd
|
||||
|
||||
# https://github.com/go-critic/go-critic/issues/926
|
||||
- linters:
|
||||
- gocritic
|
||||
text: "unnecessaryDefer:"
|
||||
|
||||
run:
|
||||
skip-dirs:
|
||||
- internal/mock
|
||||
|
@ -104,6 +106,4 @@ run:
|
|||
# golangci.com configuration
|
||||
# https://github.com/golangci/golangci/wiki/Configuration
|
||||
service:
|
||||
golangci-lint-version: 1.37.x # use the fixed version to not introduce new linters unexpectedly
|
||||
prepare:
|
||||
- echo "here I can run custom commands, but no preparation needed for this repo"
|
||||
golangci-lint-version: 1.39.x # use the fixed version to not introduce new linters unexpectedly
|
||||
|
|
13
.run/imctl.run.xml
Normal file
13
.run/imctl.run.xml
Normal file
|
@ -0,0 +1,13 @@
|
|||
<component name="ProjectRunConfigurationManager">
|
||||
<configuration default="false" name="imctl" type="GoApplicationRunConfiguration" factoryName="Go Application">
|
||||
<module name="inetmock" />
|
||||
<working_directory value="$PROJECT_DIR$" />
|
||||
<parameters value="audit watch" />
|
||||
<sudo value="true" />
|
||||
<kind value="PACKAGE" />
|
||||
<package value="gitlab.com/inetmock/inetmock/cmd/imctl" />
|
||||
<directory value="$PROJECT_DIR$" />
|
||||
<filePath value="$PROJECT_DIR$/cmd/imctl/main.go" />
|
||||
<method v="2" />
|
||||
</configuration>
|
||||
</component>
|
13
.run/inetmock.run.xml
Normal file
13
.run/inetmock.run.xml
Normal file
|
@ -0,0 +1,13 @@
|
|||
<component name="ProjectRunConfigurationManager">
|
||||
<configuration default="false" name="inetmock" type="GoApplicationRunConfiguration" factoryName="Go Application">
|
||||
<module name="inetmock" />
|
||||
<working_directory value="$PROJECT_DIR$" />
|
||||
<parameters value="serve" />
|
||||
<sudo value="true" />
|
||||
<kind value="PACKAGE" />
|
||||
<package value="gitlab.com/inetmock/inetmock/cmd/inetmock" />
|
||||
<directory value="$PROJECT_DIR$" />
|
||||
<filePath value="$PROJECT_DIR$/cmd/inetmock/main.go" />
|
||||
<method v="2" />
|
||||
</configuration>
|
||||
</component>
|
20
Taskfile.yml
20
Taskfile.yml
|
@ -16,15 +16,18 @@ env:
|
|||
|
||||
tasks:
|
||||
clean:
|
||||
desc: clean all generated files
|
||||
cmds:
|
||||
- find . -type f \( -name "*.pb.go" -or -name "*.mock.go" \) -exec rm -f {} \;
|
||||
- rm -rf ./main {{ .OUT_DIR }}
|
||||
|
||||
format:
|
||||
desc: format all source files
|
||||
cmds:
|
||||
- go fmt ./...
|
||||
|
||||
deps:
|
||||
desc: download dependencies
|
||||
sources:
|
||||
- go.mod
|
||||
- go.sum
|
||||
|
@ -32,6 +35,7 @@ tasks:
|
|||
- go mod download
|
||||
|
||||
golangci-lint:
|
||||
desc: run Go linter
|
||||
deps:
|
||||
- deps
|
||||
- generate
|
||||
|
@ -39,6 +43,7 @@ tasks:
|
|||
- golangci-lint run --timeout 5m
|
||||
|
||||
protobuf-lint:
|
||||
desc: run protobuf linter
|
||||
dir: api/
|
||||
sources:
|
||||
- "proto/**/*.proto"
|
||||
|
@ -46,11 +51,13 @@ tasks:
|
|||
- buf lint
|
||||
|
||||
lint:
|
||||
desc: run all linters
|
||||
deps:
|
||||
- golangci-lint
|
||||
- protobuf-lint
|
||||
|
||||
protoc:
|
||||
desc: generate Go files from protobuf files
|
||||
deps:
|
||||
- protobuf-lint
|
||||
sources:
|
||||
|
@ -61,6 +68,7 @@ tasks:
|
|||
- buf protoc --proto_path ./api/proto/ --go_out=./pkg/ --go_opt=paths=source_relative --go-grpc_out=./pkg/ --go-grpc_opt=paths=source_relative {{ .PROTO_FILES }}
|
||||
|
||||
go-generate:
|
||||
desc: run all go:generate directives
|
||||
deps:
|
||||
- deps
|
||||
sources:
|
||||
|
@ -71,11 +79,13 @@ tasks:
|
|||
- go generate -x ./...
|
||||
|
||||
generate:
|
||||
desc: run all code generation steps
|
||||
deps:
|
||||
- go-generate
|
||||
- protoc
|
||||
|
||||
test:
|
||||
desc: run all unit tests
|
||||
sources:
|
||||
- "**/*.go"
|
||||
deps:
|
||||
|
@ -91,6 +101,7 @@ tasks:
|
|||
- rm -f {{ .OUT_DIR }}/cov-raw.out
|
||||
|
||||
integration-test:
|
||||
desc: run all benchmarks/integration tests
|
||||
deps:
|
||||
- deps
|
||||
- generate
|
||||
|
@ -101,18 +112,21 @@ tasks:
|
|||
{{ end }}
|
||||
|
||||
cli-cover-report:
|
||||
desc: generate a coverage report on the CLI
|
||||
deps:
|
||||
- test
|
||||
cmds:
|
||||
- go tool cover -func={{ .OUT_DIR }}/cov.out
|
||||
|
||||
html-cover-report:
|
||||
desc: generate a coverage report as HTML page
|
||||
deps:
|
||||
- test
|
||||
cmds:
|
||||
- go tool cover -html={{ .OUT_DIR }}/cov.out -o {{ .OUT_DIR }}/coverage.html
|
||||
|
||||
build-inetmock:
|
||||
desc: build the INetMock server part
|
||||
deps:
|
||||
- test
|
||||
cmds:
|
||||
|
@ -120,10 +134,12 @@ tasks:
|
|||
- go build -ldflags='-w -s' -o {{ .OUT_DIR }}/inetmock {{ .INETMOCK_PKG }}
|
||||
|
||||
debug-inetmock:
|
||||
desc: run INetMock server with delve for remote debugging
|
||||
cmds:
|
||||
- dlv --listen=:2345 --headless=true --api-version=2 --accept-multiclient --output {{ .OUT_DIR }}/__debug_bin debug {{ .INETMOCK_PKG }} -- serve
|
||||
|
||||
build-imctl:
|
||||
desc: build the imctl INetMock client CLI
|
||||
deps:
|
||||
- test
|
||||
cmds:
|
||||
|
@ -131,11 +147,13 @@ tasks:
|
|||
- go build -ldflags='-w -s' -o {{ .OUT_DIR }}/imctl {{ .IMCTL_PKG }}
|
||||
|
||||
build-all:
|
||||
desc: build all binaries
|
||||
deps:
|
||||
- build-inetmock
|
||||
- build-imctl
|
||||
|
||||
snapshot-release:
|
||||
desc: create a snapshot/test release without publishing any artifacts
|
||||
deps:
|
||||
- deps
|
||||
- generate
|
||||
|
@ -143,6 +161,7 @@ tasks:
|
|||
- goreleaser release --snapshot --skip-publish --rm-dist
|
||||
|
||||
release:
|
||||
desc: create a release - includes artifact publishing
|
||||
deps:
|
||||
- deps
|
||||
- generate
|
||||
|
@ -150,5 +169,6 @@ tasks:
|
|||
- goreleaser release
|
||||
|
||||
docs:
|
||||
desc: generate docs
|
||||
cmds:
|
||||
- mdbook build -d ./../public ./docs
|
||||
|
|
|
@ -113,7 +113,7 @@ func watchAuditEvents(_ *cobra.Command, _ []string) (err error) {
|
|||
|
||||
func runListSinks(*cobra.Command, []string) (err error) {
|
||||
auditClient := rpcV1.NewAuditServiceClient(conn)
|
||||
ctx, cancel := context.WithTimeout(cliApp.Context(), grpcTimeout)
|
||||
ctx, cancel := context.WithTimeout(cliApp.Context(), cfg.GRPCTimeout)
|
||||
defer cancel()
|
||||
|
||||
var resp *rpcV1.ListSinksResponse
|
||||
|
@ -125,19 +125,19 @@ func runListSinks(*cobra.Command, []string) (err error) {
|
|||
Name string
|
||||
}
|
||||
|
||||
var sinks []printableSink
|
||||
for _, s := range resp.Sinks {
|
||||
sinks = append(sinks, printableSink{Name: s})
|
||||
var sinks = make([]printableSink, len(resp.Sinks))
|
||||
for i, s := range resp.Sinks {
|
||||
sinks[i] = printableSink{Name: s}
|
||||
}
|
||||
|
||||
writer := format.Writer(outputFormat, os.Stdout)
|
||||
writer := format.Writer(cfg.Format, os.Stdout)
|
||||
err = writer.Write(sinks)
|
||||
return
|
||||
}
|
||||
|
||||
func runAddFile(_ *cobra.Command, args []string) (err error) {
|
||||
auditClient := rpcV1.NewAuditServiceClient(conn)
|
||||
ctx, cancel := context.WithTimeout(cliApp.Context(), grpcTimeout)
|
||||
ctx, cancel := context.WithTimeout(cliApp.Context(), cfg.GRPCTimeout)
|
||||
defer cancel()
|
||||
|
||||
var resp *rpcV1.RegisterFileSinkResponse
|
||||
|
@ -154,7 +154,7 @@ func runAddFile(_ *cobra.Command, args []string) (err error) {
|
|||
|
||||
func runRemoveFile(_ *cobra.Command, args []string) (err error) {
|
||||
auditClient := rpcV1.NewAuditServiceClient(conn)
|
||||
ctx, cancel := context.WithTimeout(cliApp.Context(), grpcTimeout)
|
||||
ctx, cancel := context.WithTimeout(cliApp.Context(), cfg.GRPCTimeout)
|
||||
defer cancel()
|
||||
|
||||
_, err = auditClient.RemoveFileSink(ctx, &rpcV1.RemoveFileSinkRequest{TargetPath: args[0]})
|
||||
|
|
|
@ -44,21 +44,22 @@ func fromComponentsHealth(componentsHealth map[string]*rpcV1.ComponentHealth) in
|
|||
Message string
|
||||
}
|
||||
|
||||
var componentsInfo []printableHealthInfo
|
||||
|
||||
var componentsInfo = make([]printableHealthInfo, len(componentsHealth))
|
||||
var idx int
|
||||
for componentName, component := range componentsHealth {
|
||||
componentsInfo = append(componentsInfo, printableHealthInfo{
|
||||
componentsInfo[idx] = printableHealthInfo{
|
||||
Component: componentName,
|
||||
State: component.State.String(),
|
||||
Message: component.Message,
|
||||
})
|
||||
}
|
||||
idx++
|
||||
}
|
||||
return componentsInfo
|
||||
}
|
||||
|
||||
func getHealthResult() (healthResp *rpcV1.GetHealthResponse, err error) {
|
||||
var healthClient = rpcV1.NewHealthServiceClient(conn)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), grpcTimeout)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), cfg.GRPCTimeout)
|
||||
healthResp, err = healthClient.GetHealth(ctx, &rpcV1.GetHealthRequest{})
|
||||
cancel()
|
||||
return
|
||||
|
@ -75,7 +76,7 @@ func runGeneralHealth(_ *cobra.Command, _ []string) {
|
|||
|
||||
printable := fromComponentsHealth(healthResp.ComponentsHealth)
|
||||
|
||||
writer := format.Writer(outputFormat, os.Stdout)
|
||||
writer := format.Writer(cfg.Format, os.Stdout)
|
||||
if err = writer.Write(printable); err != nil {
|
||||
fmt.Printf("Error occurred during writing response values: %v\n", err)
|
||||
}
|
||||
|
|
|
@ -5,6 +5,7 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
"github.com/spf13/pflag"
|
||||
"google.golang.org/grpc"
|
||||
|
||||
"gitlab.com/inetmock/inetmock/internal/app"
|
||||
|
@ -15,34 +16,53 @@ const (
|
|||
)
|
||||
|
||||
var (
|
||||
inetMockSocketPath string
|
||||
outputFormat string
|
||||
grpcTimeout time.Duration
|
||||
cliApp app.App
|
||||
conn *grpc.ClientConn
|
||||
cfg config
|
||||
)
|
||||
|
||||
//nolint:lll
|
||||
type config struct {
|
||||
SocketPath string
|
||||
Format string
|
||||
GRPCTimeout time.Duration
|
||||
}
|
||||
|
||||
func main() {
|
||||
healthCmd.AddCommand(generalHealthCmd, containerHealthCmd)
|
||||
cliApp = app.NewApp(
|
||||
app.Spec{
|
||||
Name: "imctl",
|
||||
Short: "IMCTL is the CLI app to interact with an INetMock server",
|
||||
Config: &cfg,
|
||||
SubCommands: []*cobra.Command{healthCmd, auditCmd, pcapCmd},
|
||||
LateInitTasks: []func(cmd *cobra.Command, args []string) (err error){
|
||||
initGRPCConnection,
|
||||
},
|
||||
IgnoreMissingConfigFile: true,
|
||||
FlagBindings: map[string]func(flagSet *pflag.FlagSet) *pflag.Flag{
|
||||
"grpctimeout": func(flagSet *pflag.FlagSet) *pflag.Flag {
|
||||
return flagSet.Lookup("grpc-timeout")
|
||||
},
|
||||
"format": func(flagSet *pflag.FlagSet) *pflag.Flag {
|
||||
return flagSet.Lookup("format")
|
||||
},
|
||||
"socketpath": func(flagSet *pflag.FlagSet) *pflag.Flag {
|
||||
return flagSet.Lookup("socket-path")
|
||||
},
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
cliApp = app.NewApp("imctl", "IMCTL is the CLI app to interact with an INetMock server").
|
||||
WithCommands(healthCmd, auditCmd, pcapCmd).
|
||||
WithInitTasks(func(_ *cobra.Command, _ []string) (err error) {
|
||||
return initGRPCConnection()
|
||||
}).
|
||||
WithLogger()
|
||||
|
||||
cliApp.RootCommand().PersistentFlags().StringVar(&inetMockSocketPath, "socket-path", "unix:///var/run/inetmock.sock", "Path to the INetMock socket file")
|
||||
cliApp.RootCommand().PersistentFlags().StringVarP(&outputFormat, "format", "f", "table", "Output format to use. Possible values: table, json, yaml")
|
||||
cliApp.RootCommand().PersistentFlags().DurationVar(&grpcTimeout, "grpc-timeout", defaultGRPCTimeout, "Timeout to connect to the gRPC API")
|
||||
cliApp.RootCommand().PersistentFlags().String("socket-path", "unix:///var/run/inetmock.sock", "Path to the INetMock socket file")
|
||||
cliApp.RootCommand().PersistentFlags().StringP("format", "f", "table", "Output format to use. Possible values: table, json, yaml")
|
||||
cliApp.RootCommand().PersistentFlags().Duration("grpc-timeout", defaultGRPCTimeout, "Timeout to connect to the gRPC API")
|
||||
|
||||
cliApp.MustRun()
|
||||
}
|
||||
|
||||
func initGRPCConnection() (err error) {
|
||||
dialCtx, cancel := context.WithTimeout(cliApp.Context(), grpcTimeout)
|
||||
conn, err = grpc.DialContext(dialCtx, inetMockSocketPath, grpc.WithInsecure())
|
||||
func initGRPCConnection(*cobra.Command, []string) (err error) {
|
||||
dialCtx, cancel := context.WithTimeout(cliApp.Context(), cfg.GRPCTimeout)
|
||||
conn, err = grpc.DialContext(dialCtx, cfg.SocketPath, grpc.WithInsecure())
|
||||
cancel()
|
||||
|
||||
return
|
||||
|
|
|
@ -56,7 +56,7 @@ var (
|
|||
|
||||
var err error
|
||||
pcapClient := rpcV1.NewPCAPServiceClient(conn)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), grpcTimeout)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), cfg.GRPCTimeout)
|
||||
defer cancel()
|
||||
var resp *rpcV1.ListAvailableDevicesResponse
|
||||
if resp, err = pcapClient.ListAvailableDevices(ctx, new(rpcV1.ListAvailableDevicesRequest)); err == nil {
|
||||
|
@ -91,7 +91,7 @@ var (
|
|||
ValidArgsFunction: func(*cobra.Command, []string, string) ([]string, cobra.ShellCompDirective) {
|
||||
var err error
|
||||
pcapClient := rpcV1.NewPCAPServiceClient(conn)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), grpcTimeout)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), cfg.GRPCTimeout)
|
||||
defer cancel()
|
||||
var resp *rpcV1.ListActiveRecordingsResponse
|
||||
if resp, err = pcapClient.ListActiveRecordings(ctx, new(rpcV1.ListActiveRecordingsRequest)); err == nil {
|
||||
|
@ -127,7 +127,7 @@ func init() {
|
|||
|
||||
func runListAvailableDevices(*cobra.Command, []string) (err error) {
|
||||
pcapClient := rpcV1.NewPCAPServiceClient(conn)
|
||||
ctx, cancel := context.WithTimeout(cliApp.Context(), grpcTimeout)
|
||||
ctx, cancel := context.WithTimeout(cliApp.Context(), cfg.GRPCTimeout)
|
||||
defer cancel()
|
||||
|
||||
var resp *rpcV1.ListAvailableDevicesResponse
|
||||
|
@ -149,7 +149,7 @@ func runListAvailableDevices(*cobra.Command, []string) (err error) {
|
|||
})
|
||||
}
|
||||
|
||||
writer := format.Writer(outputFormat, os.Stdout)
|
||||
writer := format.Writer(cfg.Format, os.Stdout)
|
||||
err = writer.Write(availableDevs)
|
||||
return
|
||||
}
|
||||
|
@ -163,7 +163,7 @@ func runListActiveRecordings(*cobra.Command, []string) error {
|
|||
|
||||
pcapClient := rpcV1.NewPCAPServiceClient(conn)
|
||||
|
||||
ctx, cancel := context.WithTimeout(cliApp.Context(), grpcTimeout)
|
||||
ctx, cancel := context.WithTimeout(cliApp.Context(), cfg.GRPCTimeout)
|
||||
defer cancel()
|
||||
|
||||
var err error
|
||||
|
@ -172,21 +172,21 @@ func runListActiveRecordings(*cobra.Command, []string) error {
|
|||
return err
|
||||
}
|
||||
|
||||
var out []printableSubscription
|
||||
for _, subscription := range resp.Subscriptions {
|
||||
var out = make([]printableSubscription, len(resp.Subscriptions))
|
||||
for idx, subscription := range resp.Subscriptions {
|
||||
splitIdx := strings.Index(subscription, ":")
|
||||
if splitIdx < 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
out = append(out, printableSubscription{
|
||||
out[idx] = printableSubscription{
|
||||
Name: subscription[splitIdx:],
|
||||
Device: subscription[:splitIdx],
|
||||
ConsumerKey: subscription,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
writer := format.Writer(outputFormat, os.Stdout)
|
||||
writer := format.Writer(cfg.Format, os.Stdout)
|
||||
|
||||
return writer.Write(out)
|
||||
}
|
||||
|
@ -198,7 +198,7 @@ func runAddRecording(_ *cobra.Command, args []string) (err error) {
|
|||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(cliApp.Context(), grpcTimeout)
|
||||
ctx, cancel := context.WithTimeout(cliApp.Context(), cfg.GRPCTimeout)
|
||||
defer cancel()
|
||||
|
||||
var resp *rpcV1.StartPCAPFileRecordingResponse
|
||||
|
@ -219,7 +219,7 @@ func runAddRecording(_ *cobra.Command, args []string) (err error) {
|
|||
func runRemoveCurrentlyRunningRecording(_ *cobra.Command, args []string) error {
|
||||
pcapClient := rpcV1.NewPCAPServiceClient(conn)
|
||||
|
||||
listRecsCtx, listRecsCancel := context.WithTimeout(cliApp.Context(), grpcTimeout)
|
||||
listRecsCtx, listRecsCancel := context.WithTimeout(cliApp.Context(), cfg.GRPCTimeout)
|
||||
defer listRecsCancel()
|
||||
|
||||
var err error
|
||||
|
@ -240,7 +240,7 @@ func runRemoveCurrentlyRunningRecording(_ *cobra.Command, args []string) error {
|
|||
return fmt.Errorf("the given subscription is not known: %s", args[0])
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(cliApp.Context(), grpcTimeout)
|
||||
ctx, cancel := context.WithTimeout(cliApp.Context(), cfg.GRPCTimeout)
|
||||
defer cancel()
|
||||
|
||||
var stopRecResp *rpcV1.StopPCAPFileRecordingResponse
|
||||
|
@ -259,7 +259,7 @@ func runRemoveCurrentlyRunningRecording(_ *cobra.Command, args []string) error {
|
|||
}
|
||||
|
||||
func isValidRecordDevice(device string, pcapClient rpcV1.PCAPServiceClient) (err error) {
|
||||
ctx, cancel := context.WithTimeout(cliApp.Context(), grpcTimeout)
|
||||
ctx, cancel := context.WithTimeout(cliApp.Context(), cfg.GRPCTimeout)
|
||||
defer cancel()
|
||||
var resp *rpcV1.ListAvailableDevicesResponse
|
||||
if resp, err = pcapClient.ListAvailableDevices(ctx, new(rpcV1.ListAvailableDevicesRequest)); err != nil {
|
||||
|
|
|
@ -1,33 +1,101 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"gitlab.com/inetmock/inetmock/internal/app"
|
||||
"gitlab.com/inetmock/inetmock/internal/endpoint"
|
||||
dns "gitlab.com/inetmock/inetmock/internal/endpoint/handler/dns/mock"
|
||||
http "gitlab.com/inetmock/inetmock/internal/endpoint/handler/http/mock"
|
||||
"gitlab.com/inetmock/inetmock/internal/endpoint/handler/http/proxy"
|
||||
"gitlab.com/inetmock/inetmock/internal/endpoint/handler/metrics"
|
||||
"gitlab.com/inetmock/inetmock/internal/endpoint/handler/tls/interceptor"
|
||||
"gitlab.com/inetmock/inetmock/pkg/cert"
|
||||
)
|
||||
|
||||
var (
|
||||
serverApp app.App
|
||||
)
|
||||
|
||||
func main() {
|
||||
serverApp = app.NewApp("inetmock", "INetMock is lightweight internet mock").
|
||||
WithHandlerRegistry(
|
||||
cfg appConfig
|
||||
registrations = []endpoint.Registration{
|
||||
http.AddHTTPMock,
|
||||
dns.AddDNSMock,
|
||||
interceptor.AddTLSInterceptor,
|
||||
proxy.AddHTTPProxy,
|
||||
metrics.AddMetricsExporter).
|
||||
WithCommands(serveCmd, generateCaCmd).
|
||||
WithConfig().
|
||||
WithLogger().
|
||||
WithHealthChecker().
|
||||
WithCertStore().
|
||||
WithEventStream().
|
||||
WithEndpointManager()
|
||||
metrics.AddMetricsExporter,
|
||||
}
|
||||
)
|
||||
|
||||
type Data struct {
|
||||
PCAP string
|
||||
Audit string
|
||||
}
|
||||
|
||||
func (d *Data) setup() (err error) {
|
||||
if d.PCAP, err = ensureDataDir(d.PCAP); err != nil {
|
||||
return
|
||||
}
|
||||
if d.Audit, err = ensureDataDir(d.Audit); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func ensureDataDir(dataDirPath string) (cleanedPath string, err error) {
|
||||
cleanedPath = dataDirPath
|
||||
if !filepath.IsAbs(cleanedPath) {
|
||||
if cleanedPath, err = filepath.Abs(cleanedPath); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
err = os.MkdirAll(cleanedPath, 0640)
|
||||
return
|
||||
}
|
||||
|
||||
type appConfig struct {
|
||||
TLS cert.Options
|
||||
Listeners map[string]endpoint.ListenerSpec
|
||||
API struct {
|
||||
Listen string
|
||||
}
|
||||
Data Data
|
||||
}
|
||||
|
||||
func (c *appConfig) APIURL() *url.URL {
|
||||
if u, err := url.Parse(c.API.Listen); err != nil {
|
||||
u, _ = url.Parse("tcp://:0")
|
||||
return u
|
||||
} else {
|
||||
return u
|
||||
}
|
||||
}
|
||||
|
||||
func main() {
|
||||
serverApp = app.NewApp(
|
||||
app.Spec{
|
||||
Name: "inetmock",
|
||||
Short: "INetMock is lightweight internet mock",
|
||||
Config: &cfg,
|
||||
SubCommands: []*cobra.Command{serveCmd, generateCaCmd},
|
||||
Defaults: map[string]interface{}{
|
||||
"api.listen": "tcp://:0",
|
||||
"data.pcap": "/var/lib/inetmock/data/pcap",
|
||||
"data.audit": "/var/lib/inetmock/data/audit",
|
||||
"tls.curve": cert.CurveTypeP256,
|
||||
"tls.minTLSVersion": cert.TLSVersionTLS10,
|
||||
"tls.includeInsecureCipherSuites": false,
|
||||
"tls.validity.server.notBeforeRelative": 168 * time.Hour,
|
||||
"tls.validity.server.notAfterRelative": 168 * time.Hour,
|
||||
"tls.certCachePath": "/tmp",
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
serverApp.MustRun()
|
||||
}
|
||||
|
|
|
@ -4,9 +4,19 @@ import (
|
|||
"github.com/spf13/cobra"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"gitlab.com/inetmock/inetmock/internal/endpoint"
|
||||
"gitlab.com/inetmock/inetmock/internal/pcap"
|
||||
"gitlab.com/inetmock/inetmock/internal/pcap/consumers/audit"
|
||||
audit2 "gitlab.com/inetmock/inetmock/internal/pcap/consumers/audit"
|
||||
"gitlab.com/inetmock/inetmock/internal/rpc"
|
||||
"gitlab.com/inetmock/inetmock/pkg/audit"
|
||||
"gitlab.com/inetmock/inetmock/pkg/audit/sink"
|
||||
"gitlab.com/inetmock/inetmock/pkg/cert"
|
||||
"gitlab.com/inetmock/inetmock/pkg/health"
|
||||
"gitlab.com/inetmock/inetmock/pkg/logging"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultEventBufferSize = 10
|
||||
)
|
||||
|
||||
var (
|
||||
|
@ -19,33 +29,67 @@ var (
|
|||
)
|
||||
|
||||
func startINetMock(_ *cobra.Command, _ []string) error {
|
||||
rpcAPI := rpc.NewINetMockAPI(
|
||||
serverApp.Config().APIURL(),
|
||||
serverApp.Logger(),
|
||||
serverApp.Checker(),
|
||||
serverApp.EventStream(),
|
||||
serverApp.Config().AuditDataDir(),
|
||||
serverApp.Config().PCAPDataDir(),
|
||||
)
|
||||
logger := serverApp.Logger()
|
||||
registry := endpoint.NewHandlerRegistry()
|
||||
|
||||
cfg := serverApp.Config()
|
||||
endpointOrchestrator := serverApp.EndpointManager()
|
||||
var err error
|
||||
appLogger := serverApp.Logger()
|
||||
|
||||
for name, spec := range cfg.ListenerSpecs() {
|
||||
if spec.Name == "" {
|
||||
spec.Name = name
|
||||
if err = cfg.Data.setup(); err != nil {
|
||||
appLogger.Error("Failed to setup data directories", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
if err := endpointOrchestrator.RegisterListener(spec); err != nil {
|
||||
logger.Error("Failed to register listener", zap.Error(err))
|
||||
|
||||
for _, registration := range registrations {
|
||||
if err = registration(registry); err != nil {
|
||||
appLogger.Error("Failed to run registration", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
errChan := serverApp.EndpointManager().StartEndpoints()
|
||||
var certStore cert.Store
|
||||
if certStore, err = cert.NewDefaultStore(cfg.TLS, appLogger.Named("CertStore")); err != nil {
|
||||
appLogger.Error("Failed to initialize cert store", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
var eventStream audit.EventStream
|
||||
if eventStream, err = setupEventStream(appLogger); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var endpointOrchestrator = endpoint.NewOrchestrator(
|
||||
serverApp.Context(),
|
||||
certStore,
|
||||
registry,
|
||||
eventStream,
|
||||
appLogger,
|
||||
)
|
||||
|
||||
checker := health.New()
|
||||
|
||||
rpcAPI := rpc.NewINetMockAPI(
|
||||
cfg.APIURL(),
|
||||
appLogger,
|
||||
checker,
|
||||
eventStream,
|
||||
cfg.Data.Audit,
|
||||
cfg.Data.PCAP,
|
||||
)
|
||||
|
||||
for name, spec := range cfg.Listeners {
|
||||
if spec.Name == "" {
|
||||
spec.Name = name
|
||||
}
|
||||
if err := endpointOrchestrator.RegisterListener(spec); err != nil {
|
||||
appLogger.Error("Failed to register listener", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
errChan := endpointOrchestrator.StartEndpoints()
|
||||
if err := rpcAPI.StartServer(); err != nil {
|
||||
serverApp.Shutdown()
|
||||
logger.Error(
|
||||
appLogger.Error(
|
||||
"failed to start gRPC API",
|
||||
zap.Error(err),
|
||||
)
|
||||
|
@ -60,22 +104,55 @@ loop:
|
|||
for {
|
||||
select {
|
||||
case err := <-errChan:
|
||||
logger.Error("got error from endpoint", zap.Error(err))
|
||||
appLogger.Error("got error from endpoint", zap.Error(err))
|
||||
case <-serverApp.Context().Done():
|
||||
break loop
|
||||
}
|
||||
}
|
||||
|
||||
logger.Info("App context canceled - shutting down")
|
||||
appLogger.Info("App context canceled - shutting down")
|
||||
|
||||
rpcAPI.StopServer()
|
||||
return nil
|
||||
}
|
||||
|
||||
func setupEventStream(appLogger logging.Logger) (audit.EventStream, error) {
|
||||
var evenStream audit.EventStream
|
||||
var err error
|
||||
evenStream, err = audit.NewEventStream(
|
||||
appLogger.Named("EventStream"),
|
||||
audit.WithSinkBufferSize(defaultEventBufferSize),
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
appLogger.Error("Failed to initialize event stream", zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err = evenStream.RegisterSink(serverApp.Context(), sink.NewLogSink(appLogger.Named("LogSink"))); err != nil {
|
||||
appLogger.Error("Failed to register log sink to event stream", zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var metricSink audit.Sink
|
||||
if metricSink, err = sink.NewMetricSink(); err != nil {
|
||||
appLogger.Error("Failed to setup metrics sink", zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err = evenStream.RegisterSink(serverApp.Context(), metricSink); err != nil {
|
||||
appLogger.Error("Failed to register metric sink", zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return evenStream, nil
|
||||
}
|
||||
|
||||
//nolint:deadcode
|
||||
func startAuditConsumer() error {
|
||||
func startAuditConsumer(eventStream audit.EventStream) error {
|
||||
recorder := pcap.NewRecorder()
|
||||
auditConsumer := audit.NewAuditConsumer("audit", serverApp.EventStream())
|
||||
|
||||
auditConsumer := audit2.NewAuditConsumer("audit", eventStream)
|
||||
|
||||
return recorder.StartRecording(serverApp.Context(), "lo", auditConsumer)
|
||||
}
|
||||
|
|
|
@ -59,7 +59,7 @@ api:
|
|||
|
||||
tls:
|
||||
curve: P256
|
||||
minTLSVersion: SSL3
|
||||
minTLSVersion: TLS10
|
||||
includeInsecureCipherSuites: false
|
||||
validity:
|
||||
ca:
|
||||
|
|
|
@ -63,7 +63,7 @@ api:
|
|||
|
||||
tls:
|
||||
curve: P256
|
||||
minTLSVersion: SSL3
|
||||
minTLSVersion: TLS10
|
||||
includeInsecureCipherSuites: false
|
||||
validity:
|
||||
ca:
|
||||
|
|
21
go.mod
21
go.mod
|
@ -6,25 +6,24 @@ require (
|
|||
github.com/bwmarrin/snowflake v0.3.0
|
||||
github.com/docker/go-connections v0.4.0
|
||||
github.com/golang/mock v1.5.0
|
||||
github.com/golang/protobuf v1.4.3
|
||||
github.com/google/gopacket v1.1.19
|
||||
github.com/google/uuid v1.2.0
|
||||
github.com/imdario/mergo v0.3.11
|
||||
github.com/jinzhu/copier v0.2.5
|
||||
github.com/miekg/dns v1.1.40
|
||||
github.com/imdario/mergo v0.3.12
|
||||
github.com/jinzhu/copier v0.3.0
|
||||
github.com/maxatome/go-testdeep v1.9.2
|
||||
github.com/miekg/dns v1.1.41
|
||||
github.com/mitchellh/mapstructure v1.4.1
|
||||
github.com/olekukonko/tablewriter v0.0.5
|
||||
github.com/prometheus/client_golang v1.9.0
|
||||
github.com/soheilhy/cmux v0.1.4
|
||||
github.com/prometheus/client_golang v1.10.0
|
||||
github.com/soheilhy/cmux v0.1.5
|
||||
github.com/spf13/cobra v1.1.3
|
||||
github.com/spf13/pflag v1.0.5
|
||||
github.com/spf13/viper v1.7.1
|
||||
github.com/testcontainers/testcontainers-go v0.9.0
|
||||
github.com/testcontainers/testcontainers-go v0.10.0
|
||||
go.uber.org/multierr v1.6.0
|
||||
go.uber.org/zap v1.16.0
|
||||
google.golang.org/grpc v1.36.0
|
||||
google.golang.org/protobuf v1.25.0
|
||||
google.golang.org/grpc v1.37.0
|
||||
google.golang.org/protobuf v1.26.0
|
||||
gopkg.in/elazarl/goproxy.v1 v1.0.0-20180725130230-947c36da3153
|
||||
gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b
|
||||
)
|
||||
|
||||
replace github.com/docker/docker => github.com/docker/docker v17.12.0-ce-rc1.0.20200916142827-bd33bbf0497b+incompatible
|
||||
|
|
|
@ -4,18 +4,17 @@ package app
|
|||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/signal"
|
||||
"strings"
|
||||
"syscall"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
"github.com/spf13/pflag"
|
||||
"github.com/spf13/viper"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"gitlab.com/inetmock/inetmock/internal/endpoint"
|
||||
"gitlab.com/inetmock/inetmock/pkg/audit"
|
||||
"gitlab.com/inetmock/inetmock/pkg/audit/sink"
|
||||
"gitlab.com/inetmock/inetmock/pkg/cert"
|
||||
"gitlab.com/inetmock/inetmock/pkg/health"
|
||||
"gitlab.com/inetmock/inetmock/pkg/logging"
|
||||
"gitlab.com/inetmock/inetmock/pkg/path"
|
||||
)
|
||||
|
@ -26,71 +25,30 @@ var (
|
|||
developmentLogs bool
|
||||
)
|
||||
|
||||
type contextKey string
|
||||
|
||||
const (
|
||||
loggerKey contextKey = "gitlab.com/inetmock/inetmock/app/context/logger"
|
||||
configKey contextKey = "gitlab.com/inetmock/inetmock/app/context/config"
|
||||
handlerRegistryKey contextKey = "gitlab.com/inetmock/inetmock/app/context/handlerRegistry"
|
||||
healthCheckerKey contextKey = "gitlab.com/inetmock/inetmock/app/context/healthChecker"
|
||||
endpointManagerKey contextKey = "gitlab.com/inetmock/inetmock/app/context/endpointManager"
|
||||
certStoreKey contextKey = "gitlab.com/inetmock/inetmock/app/context/certStore"
|
||||
eventStreamKey contextKey = "gitlab.com/inetmock/inetmock/app/context/eventStream"
|
||||
defaultEventBufferSize = 10
|
||||
)
|
||||
type Spec struct {
|
||||
Name string
|
||||
Short string
|
||||
Config interface{}
|
||||
IgnoreMissingConfigFile bool
|
||||
Defaults map[string]interface{}
|
||||
FlagBindings map[string]func(flagSet *pflag.FlagSet) *pflag.Flag
|
||||
SubCommands []*cobra.Command
|
||||
LateInitTasks []func(cmd *cobra.Command, args []string) (err error)
|
||||
}
|
||||
|
||||
type App interface {
|
||||
EventStream() audit.EventStream
|
||||
Config() Config
|
||||
Checker() health.Checker
|
||||
Logger() logging.Logger
|
||||
EndpointManager() endpoint.Orchestrator
|
||||
HandlerRegistry() endpoint.HandlerRegistry
|
||||
Context() context.Context
|
||||
RootCommand() *cobra.Command
|
||||
MustRun()
|
||||
Shutdown()
|
||||
|
||||
// WithCommands adds subcommands to the root command
|
||||
// requires nothing
|
||||
WithCommands(cmds ...*cobra.Command) App
|
||||
|
||||
// WithHandlerRegistry builds up the handler registry
|
||||
// requires nothing
|
||||
WithHandlerRegistry(registrations ...endpoint.Registration) App
|
||||
|
||||
// WithHealthChecker adds the health checker mechanism
|
||||
// requires nothing
|
||||
WithHealthChecker() App
|
||||
|
||||
// WithLogger configures the logging system
|
||||
// requires nothing
|
||||
WithLogger() App
|
||||
|
||||
// WithEndpointManager creates an endpoint manager instance and adds it to the context
|
||||
// requires WithHandlerRegistry, WithHealthChecker and WithLogger
|
||||
WithEndpointManager() App
|
||||
|
||||
// WithCertStore initializes the cert store
|
||||
// requires WithLogger and WithConfig
|
||||
WithCertStore() App
|
||||
|
||||
// WithEventStream adds the audit event stream
|
||||
// requires WithLogger
|
||||
WithEventStream() App
|
||||
|
||||
// WithConfig loads the config
|
||||
// requires nothing
|
||||
WithConfig() App
|
||||
|
||||
WithInitTasks(task ...func(cmd *cobra.Command, args []string) (err error)) App
|
||||
}
|
||||
|
||||
type app struct {
|
||||
rootCmd *cobra.Command
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
lateInitTasks []func(cmd *cobra.Command, args []string) (err error)
|
||||
logger logging.Logger
|
||||
}
|
||||
|
||||
func (a *app) MustRun() {
|
||||
|
@ -107,67 +65,7 @@ func (a *app) MustRun() {
|
|||
}
|
||||
|
||||
func (a *app) Logger() logging.Logger {
|
||||
val := a.ctx.Value(loggerKey)
|
||||
if val == nil {
|
||||
return nil
|
||||
}
|
||||
return val.(logging.Logger)
|
||||
}
|
||||
|
||||
func (a *app) Config() Config {
|
||||
val := a.ctx.Value(configKey)
|
||||
if val == nil {
|
||||
return nil
|
||||
}
|
||||
return val.(Config)
|
||||
}
|
||||
|
||||
func (a *app) CertStore() cert.Store {
|
||||
val := a.ctx.Value(certStoreKey)
|
||||
if val == nil {
|
||||
return nil
|
||||
}
|
||||
return val.(cert.Store)
|
||||
}
|
||||
|
||||
func (a *app) Checker() health.Checker {
|
||||
val := a.ctx.Value(healthCheckerKey)
|
||||
if val == nil {
|
||||
return nil
|
||||
}
|
||||
return val.(health.Checker)
|
||||
}
|
||||
|
||||
func (a *app) EndpointManager() endpoint.Orchestrator {
|
||||
val := a.ctx.Value(endpointManagerKey)
|
||||
if val == nil {
|
||||
return nil
|
||||
}
|
||||
return val.(endpoint.Orchestrator)
|
||||
}
|
||||
|
||||
func (a *app) Audit() audit.Emitter {
|
||||
val := a.ctx.Value(eventStreamKey)
|
||||
if val == nil {
|
||||
return nil
|
||||
}
|
||||
return val.(audit.Emitter)
|
||||
}
|
||||
|
||||
func (a *app) EventStream() audit.EventStream {
|
||||
val := a.ctx.Value(eventStreamKey)
|
||||
if val == nil {
|
||||
return nil
|
||||
}
|
||||
return val.(audit.EventStream)
|
||||
}
|
||||
|
||||
func (a *app) HandlerRegistry() endpoint.HandlerRegistry {
|
||||
val := a.ctx.Value(handlerRegistryKey)
|
||||
if val == nil {
|
||||
return nil
|
||||
}
|
||||
return val.(endpoint.HandlerRegistry)
|
||||
return a.logger
|
||||
}
|
||||
|
||||
func (a *app) Context() context.Context {
|
||||
|
@ -182,41 +80,27 @@ func (a *app) Shutdown() {
|
|||
a.cancel()
|
||||
}
|
||||
|
||||
// WithCommands adds subcommands to the root command
|
||||
// requires nothing
|
||||
func (a *app) WithCommands(cmds ...*cobra.Command) App {
|
||||
a.rootCmd.AddCommand(cmds...)
|
||||
return a
|
||||
func NewApp(spec Spec) App {
|
||||
if spec.Defaults == nil {
|
||||
spec.Defaults = make(map[string]interface{})
|
||||
}
|
||||
|
||||
// WithHandlerRegistry builds up the handler registry
|
||||
// requires nothing
|
||||
func (a *app) WithHandlerRegistry(registrations ...endpoint.Registration) App {
|
||||
registry := endpoint.NewHandlerRegistry()
|
||||
|
||||
for _, registration := range registrations {
|
||||
if err := registration(registry); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
ctx, cancel := initAppContext()
|
||||
a := &app{
|
||||
rootCmd: &cobra.Command{
|
||||
Use: spec.Name,
|
||||
Short: spec.Short,
|
||||
SilenceUsage: true,
|
||||
},
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
}
|
||||
|
||||
a.ctx = context.WithValue(a.ctx, handlerRegistryKey, registry)
|
||||
|
||||
return a
|
||||
}
|
||||
|
||||
// WithHealthChecker adds the health checker mechanism
|
||||
// requires nothing
|
||||
func (a *app) WithHealthChecker() App {
|
||||
checker := health.New()
|
||||
a.ctx = context.WithValue(a.ctx, healthCheckerKey, checker)
|
||||
return a
|
||||
}
|
||||
|
||||
// WithLogger configures the logging system
|
||||
// requires nothing
|
||||
func (a *app) WithLogger() App {
|
||||
a.lateInitTasks = append(a.lateInitTasks, func(cmd *cobra.Command, args []string) (err error) {
|
||||
lateInitTasks := []func(cmd *cobra.Command, args []string) (err error){
|
||||
func(*cobra.Command, []string) (err error) {
|
||||
return spec.readConfig(a.rootCmd)
|
||||
},
|
||||
func(cmd *cobra.Command, args []string) (err error) {
|
||||
logging.ConfigureLogging(
|
||||
logging.ParseLevel(logLevel),
|
||||
developmentLogs,
|
||||
|
@ -227,122 +111,24 @@ func (a *app) WithLogger() App {
|
|||
},
|
||||
)
|
||||
|
||||
var logger logging.Logger
|
||||
if logger, err = logging.CreateLogger(); err != nil {
|
||||
if a.logger, err = logging.CreateLogger(); err != nil {
|
||||
return
|
||||
}
|
||||
a.ctx = context.WithValue(a.ctx, loggerKey, logger)
|
||||
return
|
||||
})
|
||||
return a
|
||||
}
|
||||
|
||||
// WithEndpointManager creates an endpoint manager instance and adds it to the context
|
||||
// requires WithHandlerRegistry, WithHealthChecker and WithLogger
|
||||
func (a *app) WithEndpointManager() App {
|
||||
a.lateInitTasks = append(a.lateInitTasks, func(_ *cobra.Command, _ []string) (err error) {
|
||||
epMgr := endpoint.NewOrchestrator(
|
||||
a.Context(),
|
||||
a.CertStore(),
|
||||
a.HandlerRegistry(),
|
||||
a.Audit(),
|
||||
a.Logger().Named("Orchestrator"),
|
||||
)
|
||||
|
||||
a.ctx = context.WithValue(a.ctx, endpointManagerKey, epMgr)
|
||||
return
|
||||
})
|
||||
return a
|
||||
}
|
||||
|
||||
// WithCertStore initializes the cert store
|
||||
// requires WithLogger and WithConfig
|
||||
func (a *app) WithCertStore() App {
|
||||
a.lateInitTasks = append(a.lateInitTasks, func(cmd *cobra.Command, args []string) (err error) {
|
||||
var certStore cert.Store
|
||||
if certStore, err = cert.NewDefaultStore(
|
||||
a.Config().TLSConfig(),
|
||||
a.Logger().Named("CertStore"),
|
||||
); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
a.ctx = context.WithValue(a.ctx, certStoreKey, certStore)
|
||||
return
|
||||
})
|
||||
return a
|
||||
}
|
||||
|
||||
// WithEventStream adds the audit event stream
|
||||
// requires WithLogger
|
||||
func (a *app) WithEventStream() App {
|
||||
a.lateInitTasks = append(a.lateInitTasks, func(_ *cobra.Command, _ []string) (err error) {
|
||||
var eventStream audit.EventStream
|
||||
eventStream, err = audit.NewEventStream(
|
||||
a.Logger().Named("EventStream"),
|
||||
audit.WithSinkBufferSize(defaultEventBufferSize),
|
||||
)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if err = eventStream.RegisterSink(a.ctx, sink.NewLogSink(a.Logger().Named("LogSink"))); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
var metricSink audit.Sink
|
||||
if metricSink, err = sink.NewMetricSink(); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if err = eventStream.RegisterSink(a.ctx, metricSink); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
a.ctx = context.WithValue(a.ctx, eventStreamKey, eventStream)
|
||||
return
|
||||
})
|
||||
|
||||
return a
|
||||
}
|
||||
|
||||
// WithConfig loads the config
|
||||
// requires nothing
|
||||
func (a *app) WithConfig() App {
|
||||
a.lateInitTasks = append(a.lateInitTasks, func(cmd *cobra.Command, _ []string) (err error) {
|
||||
cfg := CreateConfig()
|
||||
if err = cfg.ReadConfig(configFilePath); err != nil {
|
||||
return
|
||||
}
|
||||
a.ctx = context.WithValue(a.ctx, configKey, cfg)
|
||||
return
|
||||
})
|
||||
|
||||
return a
|
||||
}
|
||||
|
||||
func (a *app) WithInitTasks(task ...func(cmd *cobra.Command, args []string) (err error)) App {
|
||||
a.lateInitTasks = append(a.lateInitTasks, task...)
|
||||
return a
|
||||
}
|
||||
|
||||
func NewApp(name, short string) App {
|
||||
ctx, cancel := initAppContext()
|
||||
a := &app{
|
||||
rootCmd: &cobra.Command{
|
||||
Use: name,
|
||||
Short: short,
|
||||
},
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
}
|
||||
|
||||
lateInitTasks = append(lateInitTasks, spec.LateInitTasks...)
|
||||
|
||||
a.rootCmd.AddCommand(completionCmd)
|
||||
a.rootCmd.AddCommand(spec.SubCommands...)
|
||||
|
||||
a.rootCmd.PersistentFlags().StringVar(&configFilePath, "config", "", "Path to config file that should be used")
|
||||
a.rootCmd.PersistentFlags().StringVar(&logLevel, "log-level", "info", "logging level to use")
|
||||
a.rootCmd.PersistentFlags().BoolVar(&developmentLogs, "development-logs", false, "Enable development mode logs")
|
||||
|
||||
a.rootCmd.PersistentPreRunE = func(cmd *cobra.Command, args []string) (err error) {
|
||||
for _, initTask := range a.lateInitTasks {
|
||||
for _, initTask := range lateInitTasks {
|
||||
if err = initTask(cmd, args); err != nil {
|
||||
return
|
||||
}
|
||||
|
@ -366,3 +152,43 @@ func initAppContext() (context.Context, context.CancelFunc) {
|
|||
|
||||
return ctx, cancel
|
||||
}
|
||||
|
||||
func (s Spec) readConfig(rootCmd *cobra.Command) error {
|
||||
viperCfg := viper.NewWithOptions()
|
||||
viperCfg.SetConfigName("config")
|
||||
viperCfg.SetConfigType("yaml")
|
||||
viperCfg.AddConfigPath(fmt.Sprintf("/etc/%s/", strings.ToLower(s.Name)))
|
||||
viperCfg.AddConfigPath(fmt.Sprintf("$HOME/.%s/", strings.ToLower(s.Name)))
|
||||
viperCfg.AddConfigPath(".")
|
||||
viperCfg.SetEnvKeyReplacer(strings.NewReplacer(".", "_"))
|
||||
viperCfg.SetEnvPrefix("INETMOCK")
|
||||
viperCfg.AutomaticEnv()
|
||||
|
||||
if s.FlagBindings != nil {
|
||||
for key, selector := range s.FlagBindings {
|
||||
if err := viperCfg.BindPFlag(key, selector(rootCmd.Flags())); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for k, v := range s.Defaults {
|
||||
viperCfg.SetDefault(k, v)
|
||||
}
|
||||
|
||||
if configFilePath != "" && path.FileExists(configFilePath) {
|
||||
viperCfg.SetConfigFile(configFilePath)
|
||||
}
|
||||
|
||||
if err := viperCfg.ReadInConfig(); err != nil {
|
||||
if _, ok := err.(viper.ConfigFileNotFoundError); !(ok && s.IgnoreMissingConfigFile) {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if err := viperCfg.Unmarshal(s.Config); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -1,117 +0,0 @@
|
|||
package app
|
||||
|
||||
import (
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"github.com/spf13/viper"
|
||||
|
||||
"gitlab.com/inetmock/inetmock/internal/endpoint"
|
||||
"gitlab.com/inetmock/inetmock/pkg/cert"
|
||||
"gitlab.com/inetmock/inetmock/pkg/path"
|
||||
)
|
||||
|
||||
func CreateConfig() Config {
|
||||
configInstance := &config{
|
||||
cfg: viper.New(),
|
||||
}
|
||||
|
||||
configInstance.cfg.SetConfigName("config")
|
||||
configInstance.cfg.SetConfigType("yaml")
|
||||
configInstance.cfg.AddConfigPath("/etc/inetmock/")
|
||||
configInstance.cfg.AddConfigPath("$HOME/.inetmock")
|
||||
configInstance.cfg.AddConfigPath(".")
|
||||
configInstance.cfg.SetEnvPrefix("INetMock")
|
||||
configInstance.cfg.SetEnvKeyReplacer(strings.NewReplacer("-", "_"))
|
||||
configInstance.cfg.AutomaticEnv()
|
||||
|
||||
for k, v := range registeredDefaults {
|
||||
configInstance.cfg.SetDefault(k, v)
|
||||
}
|
||||
|
||||
for k, v := range registeredAliases {
|
||||
configInstance.cfg.RegisterAlias(k, v)
|
||||
}
|
||||
|
||||
return configInstance
|
||||
}
|
||||
|
||||
type Config interface {
|
||||
ReadConfig(configFilePath string) error
|
||||
ReadConfigString(config, format string) error
|
||||
TLSConfig() cert.Options
|
||||
APIURL() *url.URL
|
||||
AuditDataDir() string
|
||||
PCAPDataDir() string
|
||||
ListenerSpecs() map[string]endpoint.ListenerSpec
|
||||
}
|
||||
|
||||
type apiConfig struct {
|
||||
Listen string
|
||||
}
|
||||
|
||||
type config struct {
|
||||
cfg *viper.Viper
|
||||
TLS cert.Options
|
||||
Listeners map[string]endpoint.ListenerSpec
|
||||
API apiConfig
|
||||
Data Data
|
||||
}
|
||||
|
||||
func (c *config) AuditDataDir() string {
|
||||
return c.Data.Audit
|
||||
}
|
||||
|
||||
func (c *config) PCAPDataDir() string {
|
||||
return c.Data.PCAP
|
||||
}
|
||||
|
||||
func (c *config) APIURL() *url.URL {
|
||||
if u, err := url.Parse(c.API.Listen); err != nil {
|
||||
u, _ = url.Parse("tcp://:0")
|
||||
return u
|
||||
} else {
|
||||
return u
|
||||
}
|
||||
}
|
||||
|
||||
func (c *config) ListenerSpecs() map[string]endpoint.ListenerSpec {
|
||||
return c.Listeners
|
||||
}
|
||||
|
||||
func (c *config) TLSConfig() cert.Options {
|
||||
return c.TLS
|
||||
}
|
||||
|
||||
func (c *config) ReadConfigString(config, format string) (err error) {
|
||||
c.cfg.SetConfigType(format)
|
||||
if err = c.cfg.ReadConfig(strings.NewReader(config)); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
err = c.cfg.Unmarshal(c)
|
||||
return
|
||||
}
|
||||
|
||||
func (c *config) ReadConfig(configFilePath string) (err error) {
|
||||
if configFilePath != "" && path.FileExists(configFilePath) {
|
||||
c.cfg.SetConfigFile(configFilePath)
|
||||
}
|
||||
if err = c.cfg.ReadInConfig(); err != nil {
|
||||
if _, ok := err.(viper.ConfigFileNotFoundError); ok {
|
||||
err = nil
|
||||
} else {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if err = c.cfg.Unmarshal(c); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if err = c.Data.setup(); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
return
|
||||
}
|
|
@ -1,74 +0,0 @@
|
|||
package app
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"gitlab.com/inetmock/inetmock/internal/endpoint"
|
||||
)
|
||||
|
||||
func Test_config_ReadConfig(t *testing.T) {
|
||||
type args struct {
|
||||
config string
|
||||
}
|
||||
type testCase struct {
|
||||
name string
|
||||
args args
|
||||
wantListeners map[string]endpoint.ListenerSpec
|
||||
wantErr bool
|
||||
}
|
||||
tests := []testCase{
|
||||
{
|
||||
name: "Test endpoints config",
|
||||
args: args{
|
||||
// language=yaml
|
||||
config: `
|
||||
listeners:
|
||||
tcp_80:
|
||||
name: ''
|
||||
protocol: tcp
|
||||
listenAddress: ''
|
||||
port: 80
|
||||
tcp_443:
|
||||
name: ''
|
||||
protocol: tcp
|
||||
listenAddress: ''
|
||||
port: 443
|
||||
`,
|
||||
},
|
||||
wantListeners: map[string]endpoint.ListenerSpec{
|
||||
"tcp_80": {
|
||||
Name: "",
|
||||
Protocol: "tcp",
|
||||
Address: "",
|
||||
Port: 80,
|
||||
Endpoints: nil,
|
||||
},
|
||||
"tcp_443": {
|
||||
Name: "",
|
||||
Protocol: "tcp",
|
||||
Address: "",
|
||||
Port: 443,
|
||||
Endpoints: nil,
|
||||
},
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
scenario := func(tt testCase) func(t *testing.T) {
|
||||
return func(t *testing.T) {
|
||||
cfg := CreateConfig()
|
||||
if err := cfg.ReadConfigString(tt.args.config, "yaml"); (err != nil) != tt.wantErr {
|
||||
t.Errorf("ReadConfig() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(tt.wantListeners, cfg.ListenerSpecs()) {
|
||||
t.Errorf("want = %v, got = %v", tt.wantListeners, cfg.ListenerSpecs())
|
||||
}
|
||||
}
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, scenario(tt))
|
||||
}
|
||||
}
|
|
@ -1,6 +0,0 @@
|
|||
package app
|
||||
|
||||
const (
|
||||
EndpointsKey = "endpoints"
|
||||
OptionsKey = "options"
|
||||
)
|
|
@ -1,34 +0,0 @@
|
|||
package app
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
)
|
||||
|
||||
type Data struct {
|
||||
PCAP string
|
||||
Audit string
|
||||
}
|
||||
|
||||
func (d *Data) setup() (err error) {
|
||||
if d.PCAP, err = ensureDataDir(d.PCAP); err != nil {
|
||||
return
|
||||
}
|
||||
if d.Audit, err = ensureDataDir(d.Audit); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func ensureDataDir(dataDirPath string) (cleanedPath string, err error) {
|
||||
cleanedPath = dataDirPath
|
||||
if !filepath.IsAbs(cleanedPath) {
|
||||
if cleanedPath, err = filepath.Abs(cleanedPath); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
err = os.MkdirAll(cleanedPath, 0640)
|
||||
return
|
||||
}
|
|
@ -1,15 +0,0 @@
|
|||
package app
|
||||
|
||||
var (
|
||||
registeredDefaults = make(map[string]interface{})
|
||||
// default aliases
|
||||
registeredAliases = map[string]string{}
|
||||
)
|
||||
|
||||
func AddDefaultValue(key string, val interface{}) {
|
||||
registeredDefaults[key] = val
|
||||
}
|
||||
|
||||
func AddAlias(alias, orig string) {
|
||||
registeredAliases[alias] = orig
|
||||
}
|
|
@ -4,9 +4,12 @@ import (
|
|||
"net"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/maxatome/go-testdeep/td"
|
||||
)
|
||||
|
||||
func Test_randomIPFallback_GetIP(t *testing.T) {
|
||||
t.Parallel()
|
||||
ra := randomIPFallback{}
|
||||
for i := 0; i < 1000; i++ {
|
||||
if got := ra.GetIP(); reflect.DeepEqual(got, net.IP{}) {
|
||||
|
@ -16,6 +19,7 @@ func Test_randomIPFallback_GetIP(t *testing.T) {
|
|||
}
|
||||
|
||||
func Test_incrementalIPFallback_GetIP(t *testing.T) {
|
||||
t.Parallel()
|
||||
type fields struct {
|
||||
latestIP uint32
|
||||
}
|
||||
|
@ -59,24 +63,22 @@ func Test_incrementalIPFallback_GetIP(t *testing.T) {
|
|||
},
|
||||
},
|
||||
}
|
||||
scenario := func(tt testCase) func(t *testing.T) {
|
||||
return func(t *testing.T) {
|
||||
for _, tc := range tests {
|
||||
tt := tc
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
i := &incrementalIPFallback{
|
||||
latestIP: tt.fields.latestIP,
|
||||
}
|
||||
for k := 0; k < len(tt.want); k++ {
|
||||
if got := i.GetIP(); !reflect.DeepEqual(got, tt.want[k]) {
|
||||
t.Errorf("GetIP() = %v, want %v", got, tt.want[k])
|
||||
td.Cmp(t, i.GetIP(), tt.want[k])
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, scenario(tt))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_ipToInt32(t *testing.T) {
|
||||
t.Parallel()
|
||||
type args struct {
|
||||
ip net.IP
|
||||
}
|
||||
|
@ -101,14 +103,11 @@ func Test_ipToInt32(t *testing.T) {
|
|||
want: 3232281098,
|
||||
},
|
||||
}
|
||||
scenario := func(tt testCase) func(t *testing.T) {
|
||||
return func(t *testing.T) {
|
||||
if got := ipToInt32(tt.args.ip); got != tt.want {
|
||||
t.Errorf("ipToInt32() = %v, want %v", got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, scenario(tt))
|
||||
for _, tc := range tests {
|
||||
tt := tc
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
td.Cmp(t, ipToInt32(tt.args.ip), tt.want)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
@ -53,8 +53,9 @@ func Benchmark_httpHandler(b *testing.B) {
|
|||
scheme: "https",
|
||||
},
|
||||
}
|
||||
scenario := func(bm benchmark) func(bm *testing.B) {
|
||||
return func(b *testing.B) {
|
||||
for _, bc := range benchmarks {
|
||||
bm := bc
|
||||
b.Run(bm.name, func(b *testing.B) {
|
||||
var err error
|
||||
var endpoint string
|
||||
if endpoint, err = setupContainer(b, bm.scheme, bm.port); err != nil {
|
||||
|
@ -88,10 +89,7 @@ func Benchmark_httpHandler(b *testing.B) {
|
|||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
for _, bm := range benchmarks {
|
||||
b.Run(bm.name, scenario(bm))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -2,11 +2,11 @@ package mock
|
|||
|
||||
import (
|
||||
"path/filepath"
|
||||
"reflect"
|
||||
"regexp"
|
||||
"testing"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/maxatome/go-testdeep/td"
|
||||
"github.com/mitchellh/mapstructure"
|
||||
|
||||
endpoint_mock "gitlab.com/inetmock/inetmock/internal/mock/endpoint"
|
||||
|
@ -14,6 +14,7 @@ import (
|
|||
|
||||
//nolint:funlen
|
||||
func Test_loadFromConfig(t *testing.T) {
|
||||
t.Parallel()
|
||||
type args struct {
|
||||
config map[string]interface{}
|
||||
}
|
||||
|
@ -159,10 +160,11 @@ func Test_loadFromConfig(t *testing.T) {
|
|||
wantErr: false,
|
||||
},
|
||||
}
|
||||
scenario := func(tt testCase) func(t *testing.T) {
|
||||
return func(t *testing.T) {
|
||||
for _, tc := range tests {
|
||||
tt := tc
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctrl := gomock.NewController(t)
|
||||
t.Cleanup(ctrl.Finish)
|
||||
lcMock := endpoint_mock.NewMockLifecycle(ctrl)
|
||||
|
||||
lcMock.EXPECT().UnmarshalOptions(gomock.Any()).Do(func(cfg interface{}) {
|
||||
|
@ -174,12 +176,7 @@ func Test_loadFromConfig(t *testing.T) {
|
|||
t.Errorf("loadFromConfig() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(gotOptions, tt.wantOptions) {
|
||||
t.Errorf("loadFromConfig() gotOptions = %v, want %v", gotOptions, tt.wantOptions)
|
||||
}
|
||||
}
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, scenario(tt))
|
||||
td.Cmp(t, gotOptions, tt.wantOptions)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
@ -53,8 +53,8 @@ func Benchmark_httpProxy(b *testing.B) {
|
|||
scheme: "https",
|
||||
},
|
||||
}
|
||||
scenario := func(bm benchmark) func(bm *testing.B) {
|
||||
return func(b *testing.B) {
|
||||
for _, bm := range benchmarks {
|
||||
b.Run(bm.name, func(b *testing.B) {
|
||||
var err error
|
||||
var endpoint string
|
||||
if endpoint, err = setupContainer(b, bm.port); err != nil {
|
||||
|
@ -90,10 +90,7 @@ func Benchmark_httpProxy(b *testing.B) {
|
|||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
for _, bm := range benchmarks {
|
||||
b.Run(bm.name, scenario(bm))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -54,14 +54,15 @@ func (l *ListenerSpec) ConfigureMultiplexing(tlsConfig *tls.Config) ([]Endpoint,
|
|||
}
|
||||
}
|
||||
|
||||
var endpoints []Endpoint
|
||||
if len(l.Endpoints) <= 1 {
|
||||
for name, s := range l.Endpoints {
|
||||
endpoints = append(endpoints, Endpoint{
|
||||
endpoints := []Endpoint{
|
||||
{
|
||||
name: fmt.Sprintf("%s:%s", l.Name, name),
|
||||
uplink: *l.Uplink,
|
||||
Spec: s,
|
||||
})
|
||||
},
|
||||
}
|
||||
return endpoints, nil, nil
|
||||
}
|
||||
}
|
||||
|
@ -70,10 +71,12 @@ func (l *ListenerSpec) ConfigureMultiplexing(tlsConfig *tls.Config) ([]Endpoint,
|
|||
return nil, nil, ErrUDPMultiplexer
|
||||
}
|
||||
|
||||
var epNames []string
|
||||
var epNames = make([]string, len(l.Endpoints))
|
||||
var multiplexEndpoints = make(map[string]MultiplexHandler)
|
||||
var idx int
|
||||
for name, spec := range l.Endpoints {
|
||||
epNames = append(epNames, name)
|
||||
epNames[idx] = name
|
||||
idx++
|
||||
if ep, ok := spec.Handler.(MultiplexHandler); !ok {
|
||||
return nil, nil, fmt.Errorf("handler %s %w", spec.HandlerRef, ErrMultiplexingNotSupported)
|
||||
} else {
|
||||
|
@ -90,7 +93,8 @@ func (l *ListenerSpec) ConfigureMultiplexing(tlsConfig *tls.Config) ([]Endpoint,
|
|||
|
||||
var tlsRequired = false
|
||||
|
||||
for _, epName := range epNames {
|
||||
var endpoints = make([]Endpoint, len(epNames))
|
||||
for i, epName := range epNames {
|
||||
epSpec := l.Endpoints[epName]
|
||||
var epMux = plainMux
|
||||
if epSpec.TLS {
|
||||
|
@ -106,7 +110,7 @@ func (l *ListenerSpec) ConfigureMultiplexing(tlsConfig *tls.Config) ([]Endpoint,
|
|||
Spec: epSpec,
|
||||
}
|
||||
|
||||
endpoints = append(endpoints, epListener)
|
||||
endpoints[i] = epListener
|
||||
}
|
||||
|
||||
var muxes []cmux.CMux
|
||||
|
|
|
@ -7,6 +7,7 @@ import (
|
|||
|
||||
//nolint:funlen
|
||||
func Test_tblWriter_Write(t *testing.T) {
|
||||
t.Parallel()
|
||||
type s1 struct {
|
||||
Name string
|
||||
Age int
|
||||
|
@ -139,8 +140,10 @@ func Test_tblWriter_Write(t *testing.T) {
|
|||
`,
|
||||
},
|
||||
}
|
||||
scenario := func(tt testCase) func(t *testing.T) {
|
||||
return func(t *testing.T) {
|
||||
for _, tc := range tests {
|
||||
tt := tc
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
bldr := new(strings.Builder)
|
||||
|
||||
// hack to be able to format expected strings pretty
|
||||
|
@ -153,9 +156,6 @@ func Test_tblWriter_Write(t *testing.T) {
|
|||
if bldr.String() != tt.wantResult {
|
||||
t.Errorf("Write() got = %s, want %s", bldr.String(), tt.wantResult)
|
||||
}
|
||||
}
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, scenario(tt))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
@ -21,6 +21,7 @@ const (
|
|||
)
|
||||
|
||||
func Test_recorder_CompleteWorkflow(t *testing.T) {
|
||||
t.Parallel()
|
||||
var recorder = pcap.NewRecorder()
|
||||
var buffer = bytes.NewBuffer(nil)
|
||||
var err error
|
||||
|
|
|
@ -30,6 +30,7 @@ func (f fakeWriterCloser) Close() error {
|
|||
}
|
||||
|
||||
func Test_recorder_Subscriptions(t *testing.T) {
|
||||
t.Parallel()
|
||||
type subscriptionRequest struct {
|
||||
Name string
|
||||
Device string
|
||||
|
@ -82,8 +83,10 @@ func Test_recorder_Subscriptions(t *testing.T) {
|
|||
},
|
||||
},
|
||||
}
|
||||
scenario := func(tt testCase) func(t *testing.T) {
|
||||
return func(t *testing.T) {
|
||||
for _, tc := range tests {
|
||||
tt := tc
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
r := pcap.NewRecorder()
|
||||
|
||||
t.Cleanup(func() {
|
||||
|
@ -101,14 +104,12 @@ func Test_recorder_Subscriptions(t *testing.T) {
|
|||
if gotSubscriptions := sortSubscriptions(r.Subscriptions()); !reflect.DeepEqual(gotSubscriptions, tt.wantSubscriptions) {
|
||||
t.Errorf("Subscriptions() = %v, want %v", gotSubscriptions, tt.wantSubscriptions)
|
||||
}
|
||||
}
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, scenario(tt))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_recorder_StartRecordingWithOptions(t *testing.T) {
|
||||
t.Parallel()
|
||||
type args struct {
|
||||
device string
|
||||
consumer pcap.Consumer
|
||||
|
@ -155,8 +156,10 @@ func Test_recorder_StartRecordingWithOptions(t *testing.T) {
|
|||
wantErr: true,
|
||||
},
|
||||
}
|
||||
scenario := func(tt testCase) func(t *testing.T) {
|
||||
return func(t *testing.T) {
|
||||
for _, tc := range tests {
|
||||
tt := tc
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
var err error
|
||||
var recorder pcap.Recorder
|
||||
|
||||
|
@ -174,14 +177,12 @@ func Test_recorder_StartRecordingWithOptions(t *testing.T) {
|
|||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("StartRecordingWithOptions() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
}
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, scenario(tt))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_recorder_AvailableDevices(t *testing.T) {
|
||||
t.Parallel()
|
||||
type testCase struct {
|
||||
name string
|
||||
mtacher func(got []pcap.Device) error
|
||||
|
@ -209,8 +210,11 @@ func Test_recorder_AvailableDevices(t *testing.T) {
|
|||
wantErr: false,
|
||||
},
|
||||
}
|
||||
scenario := func(tt testCase) func(t *testing.T) {
|
||||
return func(t *testing.T) {
|
||||
|
||||
for _, tc := range tests {
|
||||
tt := tc
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
re := pcap.NewRecorder()
|
||||
t.Cleanup(func() {
|
||||
if err := re.Close(); err != nil {
|
||||
|
@ -225,14 +229,12 @@ func Test_recorder_AvailableDevices(t *testing.T) {
|
|||
if err := tt.mtacher(gotDevices); err != nil {
|
||||
t.Errorf("AvailableDevices() matcher error = %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, scenario(tt))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_recorder_StopRecording(t *testing.T) {
|
||||
t.Parallel()
|
||||
type args struct {
|
||||
consumerKey string
|
||||
}
|
||||
|
@ -249,6 +251,7 @@ func Test_recorder_StopRecording(t *testing.T) {
|
|||
consumerKey: "lo:test.pcap",
|
||||
},
|
||||
recorderSetup: func(t *testing.T) (recorder pcap.Recorder, err error) {
|
||||
t.Helper()
|
||||
return pcap.NewRecorder(), nil
|
||||
},
|
||||
wantErr: true,
|
||||
|
@ -259,6 +262,7 @@ func Test_recorder_StopRecording(t *testing.T) {
|
|||
consumerKey: "lo:test",
|
||||
},
|
||||
recorderSetup: func(t *testing.T) (recorder pcap.Recorder, err error) {
|
||||
t.Helper()
|
||||
recorder = pcap.NewRecorder()
|
||||
err = recorder.StartRecording(context.Background(), "lo", consumers.NewNoOpConsumerWithName("test"))
|
||||
|
||||
|
@ -272,6 +276,7 @@ func Test_recorder_StopRecording(t *testing.T) {
|
|||
consumerKey: "lo:test",
|
||||
},
|
||||
recorderSetup: func(t *testing.T) (recorder pcap.Recorder, err error) {
|
||||
t.Helper()
|
||||
recorder = pcap.NewRecorder()
|
||||
var writerConsumer pcap.Consumer
|
||||
gotClosed := false
|
||||
|
@ -299,8 +304,10 @@ func Test_recorder_StopRecording(t *testing.T) {
|
|||
wantErr: false,
|
||||
},
|
||||
}
|
||||
scenario := func(tt testCase) func(t *testing.T) {
|
||||
return func(t *testing.T) {
|
||||
for _, tc := range tests {
|
||||
tt := tc
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
var err error
|
||||
var r pcap.Recorder
|
||||
if r, err = tt.recorderSetup(t); err != nil {
|
||||
|
@ -316,14 +323,12 @@ func Test_recorder_StopRecording(t *testing.T) {
|
|||
if err := r.StopRecording(tt.args.consumerKey); (err != nil) != tt.wantErr {
|
||||
t.Errorf("StopRecording() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
}
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, scenario(tt))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_recorder_StopRecordingFromContext(t *testing.T) {
|
||||
t.Parallel()
|
||||
var err error
|
||||
var writerConsumer pcap.Consumer
|
||||
gotClosed := false
|
||||
|
|
|
@ -28,6 +28,14 @@ type auditServer struct {
|
|||
auditDataDirPath string
|
||||
}
|
||||
|
||||
func NewAuditServiceServer(logger logging.Logger, eventStream audit.EventStream, auditDataDirPath string) v1.AuditServiceServer {
|
||||
return &auditServer{
|
||||
logger: logger,
|
||||
eventStream: eventStream,
|
||||
auditDataDirPath: auditDataDirPath,
|
||||
}
|
||||
}
|
||||
|
||||
func (a *auditServer) ListSinks(context.Context, *v1.ListSinksRequest) (*v1.ListSinksResponse, error) {
|
||||
return &v1.ListSinksResponse{
|
||||
Sinks: a.eventStream.Sinks(),
|
||||
|
@ -35,7 +43,8 @@ func (a *auditServer) ListSinks(context.Context, *v1.ListSinksRequest) (*v1.List
|
|||
}
|
||||
|
||||
func (a *auditServer) WatchEvents(req *v1.WatchEventsRequest, srv v1.AuditService_WatchEventsServer) (err error) {
|
||||
a.logger.Info("watcher attached", zap.String("name", req.WatcherName))
|
||||
var logger = a.logger
|
||||
logger.Info("watcher attached", zap.String("name", req.WatcherName))
|
||||
err = a.eventStream.RegisterSink(srv.Context(), sink.NewGenericSink(req.WatcherName, func(ev audit.Event) {
|
||||
if err = srv.Send(&v1.WatchEventsResponse{Entity: ev.ProtoMessage()}); err != nil {
|
||||
return
|
||||
|
@ -47,7 +56,7 @@ func (a *auditServer) WatchEvents(req *v1.WatchEventsRequest, srv v1.AuditServic
|
|||
}
|
||||
|
||||
<-srv.Context().Done()
|
||||
a.logger.Info("Watcher detached", zap.String("name", req.WatcherName))
|
||||
logger.Info("Watcher detached", zap.String("name", req.WatcherName))
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -76,7 +85,9 @@ func (a *auditServer) RegisterFileSink(_ context.Context, req *v1.RegisterFileSi
|
|||
|
||||
func (a *auditServer) RemoveFileSink(_ context.Context, req *v1.RemoveFileSinkRequest) (*v1.RemoveFileSinkResponse, error) {
|
||||
if gotRemoved := a.eventStream.RemoveSink(req.TargetPath); gotRemoved {
|
||||
return &v1.RemoveFileSinkResponse{}, nil
|
||||
return &v1.RemoveFileSinkResponse{
|
||||
SinkGotRemoved: gotRemoved,
|
||||
}, nil
|
||||
}
|
||||
return nil, status.Error(codes.NotFound, "file sink with given target path not found")
|
||||
}
|
||||
|
|
114
internal/rpc/audit_server_test.go
Normal file
114
internal/rpc/audit_server_test.go
Normal file
|
@ -0,0 +1,114 @@
|
|||
package rpc_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/maxatome/go-testdeep/td"
|
||||
|
||||
"google.golang.org/grpc"
|
||||
|
||||
auditm "gitlab.com/inetmock/inetmock/internal/mock/audit"
|
||||
"gitlab.com/inetmock/inetmock/internal/rpc"
|
||||
"gitlab.com/inetmock/inetmock/internal/rpc/test"
|
||||
tst "gitlab.com/inetmock/inetmock/internal/test"
|
||||
"gitlab.com/inetmock/inetmock/pkg/audit"
|
||||
"gitlab.com/inetmock/inetmock/pkg/logging"
|
||||
v1 "gitlab.com/inetmock/inetmock/pkg/rpc/v1"
|
||||
)
|
||||
|
||||
const (
|
||||
grpcTimeout = 100 * time.Millisecond
|
||||
)
|
||||
|
||||
func Test_auditServer_RemoveFileSink(t *testing.T) {
|
||||
t.Parallel()
|
||||
type fields struct {
|
||||
eventStreamSetup func(t *testing.T) audit.EventStream
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
req *v1.RemoveFileSinkRequest
|
||||
fields fields
|
||||
want td.StructFields
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "Remove existing file sink - success",
|
||||
req: &v1.RemoveFileSinkRequest{
|
||||
TargetPath: "test.pcap",
|
||||
},
|
||||
fields: fields{
|
||||
eventStreamSetup: func(t *testing.T) audit.EventStream {
|
||||
t.Helper()
|
||||
ctrl := gomock.NewController(t)
|
||||
|
||||
es := auditm.NewMockEventStream(ctrl)
|
||||
es.
|
||||
EXPECT().
|
||||
RemoveSink("test.pcap").
|
||||
Return(true)
|
||||
|
||||
return es
|
||||
},
|
||||
},
|
||||
want: td.StructFields{
|
||||
"SinkGotRemoved": true,
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Remove non-existing file sink - success",
|
||||
req: &v1.RemoveFileSinkRequest{
|
||||
TargetPath: "test.pcap",
|
||||
},
|
||||
fields: fields{
|
||||
eventStreamSetup: func(t *testing.T) audit.EventStream {
|
||||
t.Helper()
|
||||
ctrl := gomock.NewController(t)
|
||||
|
||||
es := auditm.NewMockEventStream(ctrl)
|
||||
es.
|
||||
EXPECT().
|
||||
RemoveSink("test.pcap").
|
||||
Return(false)
|
||||
|
||||
return es
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
tt := tc
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
testCtx := tst.Context(t)
|
||||
logger := logging.CreateTestLogger(t)
|
||||
|
||||
srv := test.NewTestGRPCServer(t, func(registrar grpc.ServiceRegistrar) {
|
||||
v1.RegisterAuditServiceServer(registrar, rpc.NewAuditServiceServer(logger, tt.fields.eventStreamSetup(t), t.TempDir()))
|
||||
})
|
||||
|
||||
ctx, cancel := context.WithTimeout(testCtx, grpcTimeout)
|
||||
conn := srv.Dial(ctx, t)
|
||||
cancel()
|
||||
|
||||
client := v1.NewAuditServiceClient(conn)
|
||||
|
||||
ctx, cancel = context.WithTimeout(testCtx, grpcTimeout)
|
||||
t.Cleanup(cancel)
|
||||
got, err := client.RemoveFileSink(ctx, tt.req)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("RemoveFileSink() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
|
||||
if !tt.wantErr {
|
||||
td.CmpStruct(t, got, new(v1.RemoveFileSinkResponse), tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
|
@ -63,12 +63,7 @@ func (i *inetmockAPI) StartServer() (err error) {
|
|||
checker: i.checker,
|
||||
})
|
||||
|
||||
v1.RegisterAuditServiceServer(i.server, &auditServer{
|
||||
logger: i.logger,
|
||||
eventStream: i.eventStream,
|
||||
auditDataDirPath: i.auditDataDir,
|
||||
})
|
||||
|
||||
v1.RegisterAuditServiceServer(i.server, NewAuditServiceServer(i.logger, i.eventStream, i.auditDataDir))
|
||||
v1.RegisterPCAPServiceServer(i.server, NewPCAPServer(i.pcapDataDir, pcap.NewRecorder()))
|
||||
|
||||
reflection.Register(i.server)
|
||||
|
|
|
@ -7,6 +7,7 @@ import (
|
|||
"reflect"
|
||||
"sort"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"google.golang.org/grpc"
|
||||
|
@ -15,10 +16,12 @@ import (
|
|||
"gitlab.com/inetmock/inetmock/internal/pcap/consumers"
|
||||
"gitlab.com/inetmock/inetmock/internal/rpc"
|
||||
"gitlab.com/inetmock/inetmock/internal/rpc/test"
|
||||
tst "gitlab.com/inetmock/inetmock/internal/test"
|
||||
rpcV1 "gitlab.com/inetmock/inetmock/pkg/rpc/v1"
|
||||
)
|
||||
|
||||
func Test_pcapServer_ListActiveRecordings(t *testing.T) {
|
||||
t.Parallel()
|
||||
type testCase struct {
|
||||
name string
|
||||
recorderSetup func(t *testing.T) (recorder pcap.Recorder, err error)
|
||||
|
@ -29,6 +32,7 @@ func Test_pcapServer_ListActiveRecordings(t *testing.T) {
|
|||
{
|
||||
name: "No subscriptions",
|
||||
recorderSetup: func(t *testing.T) (recorder pcap.Recorder, err error) {
|
||||
t.Helper()
|
||||
recorder = pcap.NewRecorder()
|
||||
return
|
||||
},
|
||||
|
@ -38,6 +42,7 @@ func Test_pcapServer_ListActiveRecordings(t *testing.T) {
|
|||
{
|
||||
name: "Listening to lo interface",
|
||||
recorderSetup: func(t *testing.T) (recorder pcap.Recorder, err error) {
|
||||
t.Helper()
|
||||
recorder = pcap.NewRecorder()
|
||||
err = recorder.StartRecording(context.Background(), "lo", consumers.NewNoOpConsumerWithName("test"))
|
||||
return
|
||||
|
@ -46,8 +51,10 @@ func Test_pcapServer_ListActiveRecordings(t *testing.T) {
|
|||
wantErr: false,
|
||||
},
|
||||
}
|
||||
scenario := func(tt testCase) func(t *testing.T) {
|
||||
return func(t *testing.T) {
|
||||
for _, tc := range tests {
|
||||
tt := tc
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
var err error
|
||||
var recorder pcap.Recorder
|
||||
if recorder, err = tt.recorderSetup(t); err != nil {
|
||||
|
@ -68,24 +75,24 @@ func Test_pcapServer_ListActiveRecordings(t *testing.T) {
|
|||
t.Errorf("ListActiveRecordings() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if (gotResp == nil) != tt.wantErr {
|
||||
if gotResp == nil {
|
||||
if !tt.wantErr {
|
||||
t.Errorf("response was nil")
|
||||
}
|
||||
return
|
||||
} else {
|
||||
}
|
||||
|
||||
sort.Strings(gotResp.Subscriptions)
|
||||
sort.Strings(tt.wantSubscriptions)
|
||||
if !reflect.DeepEqual(gotResp.Subscriptions, tt.wantSubscriptions) {
|
||||
t.Errorf("ListActiveRecordings() gotResp = %v, want %v", gotResp.Subscriptions, tt.wantSubscriptions)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, scenario(tt))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_pcapServer_ListAvailableDevices(t *testing.T) {
|
||||
t.Parallel()
|
||||
type testCase struct {
|
||||
name string
|
||||
matcher func(devs []*rpcV1.ListAvailableDevicesResponse_PCAPDevice) error
|
||||
|
@ -135,8 +142,10 @@ func Test_pcapServer_ListAvailableDevices(t *testing.T) {
|
|||
wantErr: false,
|
||||
},
|
||||
}
|
||||
scenario := func(tt testCase) func(t *testing.T) {
|
||||
return func(t *testing.T) {
|
||||
for _, tc := range tests {
|
||||
tt := tc
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
var err error
|
||||
var recorder = pcap.NewRecorder()
|
||||
|
||||
|
@ -157,14 +166,12 @@ func Test_pcapServer_ListAvailableDevices(t *testing.T) {
|
|||
t.Errorf("ListAvailableDevices() matcher error = %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, scenario(tt))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_pcapServer_StartPCAPFileRecording(t *testing.T) {
|
||||
t.Parallel()
|
||||
type testCase struct {
|
||||
name string
|
||||
req *rpcV1.StartPCAPFileRecordingRequest
|
||||
|
@ -196,8 +203,10 @@ func Test_pcapServer_StartPCAPFileRecording(t *testing.T) {
|
|||
wantErr: true,
|
||||
},
|
||||
}
|
||||
scenario := func(tt testCase) func(t *testing.T) {
|
||||
return func(t *testing.T) {
|
||||
for _, tc := range tests {
|
||||
tt := tc
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
var err error
|
||||
var recorder = pcap.NewRecorder()
|
||||
|
||||
|
@ -217,14 +226,12 @@ func Test_pcapServer_StartPCAPFileRecording(t *testing.T) {
|
|||
if currentSubs := recorder.Subscriptions(); !reflect.DeepEqual(currentSubs, tt.wantSubscriptions) {
|
||||
t.Errorf("StartPCAPFileRecording() got = %v, want %v", currentSubs, tt.wantSubscriptions)
|
||||
}
|
||||
}
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, scenario(tt))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_pcapServer_StopPCAPFileRecord(t *testing.T) {
|
||||
t.Parallel()
|
||||
type testCase struct {
|
||||
name string
|
||||
keyToRemove string
|
||||
|
@ -237,6 +244,7 @@ func Test_pcapServer_StopPCAPFileRecord(t *testing.T) {
|
|||
name: "Remove non existing recording",
|
||||
keyToRemove: "lo:asdf.pcap",
|
||||
recorderSetup: func(t *testing.T) (recorder pcap.Recorder, err error) {
|
||||
t.Helper()
|
||||
recorder = pcap.NewRecorder()
|
||||
return
|
||||
},
|
||||
|
@ -247,6 +255,7 @@ func Test_pcapServer_StopPCAPFileRecord(t *testing.T) {
|
|||
name: "Remove an existing recording",
|
||||
keyToRemove: "lo:test.pcap",
|
||||
recorderSetup: func(t *testing.T) (recorder pcap.Recorder, err error) {
|
||||
t.Helper()
|
||||
recorder = pcap.NewRecorder()
|
||||
err = recorder.StartRecording(context.Background(), "lo", consumers.NewNoOpConsumerWithName("test.pcap"))
|
||||
return
|
||||
|
@ -255,8 +264,10 @@ func Test_pcapServer_StopPCAPFileRecord(t *testing.T) {
|
|||
wantErr: false,
|
||||
},
|
||||
}
|
||||
scenario := func(tt testCase) func(t *testing.T) {
|
||||
return func(t *testing.T) {
|
||||
for _, tc := range tests {
|
||||
tt := tc
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
var err error
|
||||
var recorder pcap.Recorder
|
||||
if recorder, err = tt.recorderSetup(t); err != nil {
|
||||
|
@ -281,44 +292,21 @@ func Test_pcapServer_StopPCAPFileRecord(t *testing.T) {
|
|||
if gotResp.Removed != tt.removedRecording {
|
||||
t.Errorf("StopPCAPFileRecord() removed = %v, want %v", gotResp.Removed, tt.removedRecording)
|
||||
}
|
||||
}
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, scenario(tt))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func setupTestPCAPServer(t *testing.T, recorder pcap.Recorder) rpcV1.PCAPServiceClient {
|
||||
t.Helper()
|
||||
var err error
|
||||
var srv *test.GRPCServer
|
||||
p := rpc.NewPCAPServer(t.TempDir(), recorder)
|
||||
srv, err = test.NewTestGRPCServer(func(registrar grpc.ServiceRegistrar) {
|
||||
|
||||
srv := test.NewTestGRPCServer(t, func(registrar grpc.ServiceRegistrar) {
|
||||
rpcV1.RegisterPCAPServiceServer(registrar, p)
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("NewTestGRPCServer() error = %v", err)
|
||||
}
|
||||
|
||||
if err = srv.StartServer(); err != nil {
|
||||
t.Fatalf("StartServer() error = %v", err)
|
||||
}
|
||||
|
||||
t.Cleanup(func() {
|
||||
srv.StopServer()
|
||||
})
|
||||
|
||||
var conn *grpc.ClientConn
|
||||
if conn, err = srv.Dial(context.Background(), grpc.WithInsecure()); err != nil {
|
||||
t.Fatalf("Dial() error = %v", err)
|
||||
}
|
||||
|
||||
t.Cleanup(func() {
|
||||
if err := conn.Close(); err != nil {
|
||||
t.Errorf("conn.Close() error = %v", err)
|
||||
}
|
||||
})
|
||||
ctx, cancel := context.WithTimeout(tst.Context(t), 100*time.Millisecond)
|
||||
var conn = srv.Dial(ctx, t)
|
||||
cancel()
|
||||
|
||||
return rpcV1.NewPCAPServiceClient(conn)
|
||||
}
|
||||
|
|
|
@ -4,100 +4,66 @@ import (
|
|||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/test/bufconn"
|
||||
)
|
||||
|
||||
const (
|
||||
bufconnBufferSize = 1024 * 1024
|
||||
)
|
||||
|
||||
type APIRegistration func(registrar grpc.ServiceRegistrar)
|
||||
|
||||
type GRPCServer struct {
|
||||
mutex sync.Mutex
|
||||
serverRunning chan struct{}
|
||||
cancel context.CancelFunc
|
||||
server *grpc.Server
|
||||
addr *net.TCPAddr
|
||||
listener net.Listener
|
||||
listener *bufconn.Listener
|
||||
}
|
||||
|
||||
func NewTestGRPCServer(registrations ...APIRegistration) (srv *GRPCServer, err error) {
|
||||
srv = new(GRPCServer)
|
||||
if srv.listener, err = net.Listen("tcp", "127.0.0.1:0"); err != nil {
|
||||
return
|
||||
func NewTestGRPCServer(tb testing.TB, registrations ...APIRegistration) (srv *GRPCServer) {
|
||||
tb.Helper()
|
||||
srv = &GRPCServer{
|
||||
listener: bufconn.Listen(bufconnBufferSize),
|
||||
server: grpc.NewServer(),
|
||||
}
|
||||
|
||||
var isTCPAddr bool
|
||||
if srv.addr, isTCPAddr = srv.listener.Addr().(*net.TCPAddr); !isTCPAddr {
|
||||
err = errors.New("expected TPC addr but wasn't")
|
||||
}
|
||||
|
||||
srv.server = grpc.NewServer()
|
||||
|
||||
for _, reg := range registrations {
|
||||
reg(srv.server)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (t *GRPCServer) StartServer() (err error) {
|
||||
t.mutex.Lock()
|
||||
var ctx context.Context
|
||||
ctx, t.cancel = context.WithCancel(context.Background())
|
||||
t.mutex.Unlock()
|
||||
errs := t.startServerAsync(ctx)
|
||||
select {
|
||||
case err = <-errs:
|
||||
default:
|
||||
}
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
tb.Cleanup(cancel)
|
||||
go srv.startServerLifecycle(ctx, tb)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (t *GRPCServer) Dial(ctx context.Context, opts ...grpc.DialOption) (*grpc.ClientConn, error) {
|
||||
opts = append(opts, grpc.WithInsecure())
|
||||
|
||||
return grpc.DialContext(ctx, t.addr.String(), opts...)
|
||||
func (t *GRPCServer) Dial(ctx context.Context, tb testing.TB, opts ...grpc.DialOption) *grpc.ClientConn {
|
||||
tb.Helper()
|
||||
opts = append(opts, grpc.WithInsecure(), grpc.WithContextDialer(func(context.Context, string) (net.Conn, error) {
|
||||
return t.listener.Dial()
|
||||
}))
|
||||
conn, err := grpc.DialContext(ctx, "", opts...)
|
||||
if err != nil {
|
||||
tb.Fatalf("failed to connect to gRPC test server - error = %v", err)
|
||||
}
|
||||
tb.Cleanup(func() {
|
||||
if err := conn.Close(); err != nil {
|
||||
tb.Errorf("Failed to close bufconn connection error = %v", err)
|
||||
}
|
||||
})
|
||||
return conn
|
||||
}
|
||||
|
||||
func (t *GRPCServer) StopServer() {
|
||||
t.cancel()
|
||||
}
|
||||
|
||||
func (t *GRPCServer) IsRunning() bool {
|
||||
t.mutex.Lock()
|
||||
defer t.mutex.Unlock()
|
||||
|
||||
if t.serverRunning == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
select {
|
||||
case _, more := <-t.serverRunning:
|
||||
return more
|
||||
default:
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
func (t *GRPCServer) startServerAsync(ctx context.Context) (errs chan error) {
|
||||
errs = make(chan error)
|
||||
t.mutex.Lock()
|
||||
t.serverRunning = make(chan struct{})
|
||||
t.mutex.Unlock()
|
||||
defer close(t.serverRunning)
|
||||
|
||||
func (t *GRPCServer) startServerLifecycle(ctx context.Context, tb testing.TB) {
|
||||
tb.Helper()
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
if t.IsRunning() {
|
||||
t.server.Stop()
|
||||
}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
if err := t.server.Serve(t.listener); err != nil && !errors.Is(err, grpc.ErrServerStopped) {
|
||||
errs <- err
|
||||
tb.Errorf("error occurred during running the gRPC test server - error - %v", err)
|
||||
}
|
||||
}()
|
||||
return
|
||||
}
|
||||
|
|
18
internal/test/context.go
Normal file
18
internal/test/context.go
Normal file
|
@ -0,0 +1,18 @@
|
|||
package test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func Context(t *testing.T) context.Context {
|
||||
t.Helper()
|
||||
if deadline, ok := t.Deadline(); ok {
|
||||
ctx, cancel := context.WithDeadline(context.Background(), deadline)
|
||||
t.Cleanup(cancel)
|
||||
return ctx
|
||||
}
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
t.Cleanup(cancel)
|
||||
return ctx
|
||||
}
|
|
@ -16,6 +16,7 @@ import (
|
|||
const containerShutdownTimeout = 5 * time.Second
|
||||
|
||||
func SetupINetMockContainer(ctx context.Context, tb testing.TB, exposedPorts ...string) (testcontainers.Container, error) {
|
||||
tb.Helper()
|
||||
//nolint:dogsled
|
||||
_, fileName, _, _ := runtime.Caller(0)
|
||||
|
||||
|
|
|
@ -62,7 +62,6 @@ func NewEventStream(logger logging.Logger, options ...EventStreamOption) (EventS
|
|||
return underlying, err
|
||||
}
|
||||
|
||||
//nolint:gocritic
|
||||
func (e *eventStream) Emit(ev Event) {
|
||||
ev.ApplyDefaults(e.idGenerator.Generate().Int64())
|
||||
select {
|
||||
|
|
|
@ -53,17 +53,19 @@ var (
|
|||
}
|
||||
)
|
||||
|
||||
func wgMockSink(t testing.TB, wg *sync.WaitGroup) audit.Sink {
|
||||
func wgMockSink(tb testing.TB, wg *sync.WaitGroup) audit.Sink {
|
||||
tb.Helper()
|
||||
return sink.NewGenericSink(
|
||||
"WG mock sink",
|
||||
func(event audit.Event) {
|
||||
t.Logf("Got event = %v", event)
|
||||
tb.Logf("Got event = %v", event)
|
||||
wg.Done()
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
func Test_eventStream_RegisterSink(t *testing.T) {
|
||||
t.Parallel()
|
||||
type args struct {
|
||||
s audit.Sink
|
||||
}
|
||||
|
@ -92,8 +94,10 @@ func Test_eventStream_RegisterSink(t *testing.T) {
|
|||
wantErr: true,
|
||||
},
|
||||
}
|
||||
scenario := func(tt testCase) func(t *testing.T) {
|
||||
return func(t *testing.T) {
|
||||
for _, tc := range tests {
|
||||
tt := tc
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
var err error
|
||||
var e audit.EventStream
|
||||
if e, err = audit.NewEventStream(logging.CreateTestLogger(t)); err != nil {
|
||||
|
@ -121,15 +125,12 @@ func Test_eventStream_RegisterSink(t *testing.T) {
|
|||
if !found {
|
||||
t.Errorf("expected defaultSink name %s not found in registered sinks %v", tt.args.s.Name(), e.Sinks())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, scenario(tt))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_eventStream_Emit(t *testing.T) {
|
||||
t.Parallel()
|
||||
type args struct {
|
||||
evs []*audit.Event
|
||||
opts []audit.EventStreamOption
|
||||
|
@ -165,9 +166,10 @@ func Test_eventStream_Emit(t *testing.T) {
|
|||
subscribe: false,
|
||||
},
|
||||
}
|
||||
|
||||
scenario := func(tt testCase) func(t *testing.T) {
|
||||
return func(t *testing.T) {
|
||||
for _, tc := range tests {
|
||||
tt := tc
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
var err error
|
||||
var e audit.EventStream
|
||||
if e, err = audit.NewEventStream(logging.CreateTestLogger(t), tt.args.opts...); err != nil {
|
||||
|
@ -208,10 +210,6 @@ func Test_eventStream_Emit(t *testing.T) {
|
|||
case <-time.After(5 * time.Second):
|
||||
t.Errorf("did not get all expected events in time")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, scenario(tt))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
@ -2,9 +2,9 @@ package audit
|
|||
|
||||
import (
|
||||
"net/http"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/maxatome/go-testdeep/td"
|
||||
"google.golang.org/protobuf/proto"
|
||||
"google.golang.org/protobuf/types/known/anypb"
|
||||
|
||||
|
@ -13,6 +13,7 @@ import (
|
|||
)
|
||||
|
||||
func Test_guessDetailsFromApp(t *testing.T) {
|
||||
t.Parallel()
|
||||
mustAny := func(msg proto.Message) *anypb.Any {
|
||||
a, err := anypb.New(msg)
|
||||
if err != nil {
|
||||
|
@ -49,14 +50,12 @@ func Test_guessDetailsFromApp(t *testing.T) {
|
|||
},
|
||||
},
|
||||
}
|
||||
scenario := func(tt testCase) func(t *testing.T) {
|
||||
return func(t *testing.T) {
|
||||
if got := guessDetailsFromApp(tt.args.any); !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("guessDetailsFromApp() = %v, want %v", got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, scenario(tt))
|
||||
for _, tc := range tests {
|
||||
tt := tc
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
got := guessDetailsFromApp(tt.args.any)
|
||||
td.Cmp(t, got, tt.want)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
@ -4,9 +4,10 @@ import (
|
|||
"bytes"
|
||||
"encoding/hex"
|
||||
"io"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/maxatome/go-testdeep/td"
|
||||
|
||||
"gitlab.com/inetmock/inetmock/pkg/audit"
|
||||
)
|
||||
|
||||
|
@ -26,6 +27,7 @@ func mustDecodeHex(hexBytes string) io.Reader {
|
|||
}
|
||||
|
||||
func Test_eventReader_Read(t *testing.T) {
|
||||
t.Parallel()
|
||||
type fields struct {
|
||||
source io.Reader
|
||||
}
|
||||
|
@ -53,8 +55,10 @@ func Test_eventReader_Read(t *testing.T) {
|
|||
wantErr: false,
|
||||
},
|
||||
}
|
||||
scenario := func(tt testCase) func(t *testing.T) {
|
||||
return func(t *testing.T) {
|
||||
for _, tc := range tests {
|
||||
tt := tc
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
e := audit.NewEventReader(tt.fields.source)
|
||||
gotEv, err := e.Read()
|
||||
if (err != nil) != tt.wantErr {
|
||||
|
@ -62,12 +66,7 @@ func Test_eventReader_Read(t *testing.T) {
|
|||
return
|
||||
}
|
||||
|
||||
if err == nil && !reflect.DeepEqual(gotEv, *tt.wantEv) {
|
||||
t.Errorf("Read() gotEv = %v, want %v", gotEv, tt.wantEv)
|
||||
}
|
||||
}
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, scenario(tt))
|
||||
td.Cmp(t, gotEv, *tt.wantEv)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
@ -13,6 +13,7 @@ import (
|
|||
)
|
||||
|
||||
func Test_genericSink_OnSubscribe(t *testing.T) {
|
||||
t.Parallel()
|
||||
type testCase struct {
|
||||
name string
|
||||
events []*audit.Event
|
||||
|
@ -27,8 +28,10 @@ func Test_genericSink_OnSubscribe(t *testing.T) {
|
|||
events: testEvents,
|
||||
},
|
||||
}
|
||||
scenario := func(tt testCase) func(t *testing.T) {
|
||||
return func(t *testing.T) {
|
||||
for _, tc := range tests {
|
||||
tt := tc
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
wg := new(sync.WaitGroup)
|
||||
wg.Add(len(tt.events))
|
||||
|
||||
|
@ -57,9 +60,6 @@ func Test_genericSink_OnSubscribe(t *testing.T) {
|
|||
t.Errorf("not all events recorded in time")
|
||||
case <-wait.ForWaitGroupDone(wg):
|
||||
}
|
||||
}
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, scenario(tt))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
@ -18,6 +18,7 @@ import (
|
|||
|
||||
//nolint:dupl
|
||||
func Test_logSink_OnSubscribe(t *testing.T) {
|
||||
t.Parallel()
|
||||
type fields struct {
|
||||
loggerSetup func(t *testing.T, wg *sync.WaitGroup) logging.Logger
|
||||
}
|
||||
|
@ -31,8 +32,8 @@ func Test_logSink_OnSubscribe(t *testing.T) {
|
|||
name: "Get a single log line",
|
||||
fields: fields{
|
||||
loggerSetup: func(t *testing.T, wg *sync.WaitGroup) logging.Logger {
|
||||
t.Helper()
|
||||
ctrl := gomock.NewController(t)
|
||||
t.Cleanup(ctrl.Finish)
|
||||
loggerMock := logging_mock.NewMockLogger(ctrl)
|
||||
|
||||
loggerMock.
|
||||
|
@ -57,8 +58,8 @@ func Test_logSink_OnSubscribe(t *testing.T) {
|
|||
name: "Get multiple events",
|
||||
fields: fields{
|
||||
loggerSetup: func(t *testing.T, wg *sync.WaitGroup) logging.Logger {
|
||||
t.Helper()
|
||||
ctrl := gomock.NewController(t)
|
||||
t.Cleanup(ctrl.Finish)
|
||||
loggerMock := logging_mock.NewMockLogger(ctrl)
|
||||
|
||||
loggerMock.
|
||||
|
@ -80,8 +81,10 @@ func Test_logSink_OnSubscribe(t *testing.T) {
|
|||
events: testEvents,
|
||||
},
|
||||
}
|
||||
scenario := func(tt testCase) func(t *testing.T) {
|
||||
return func(t *testing.T) {
|
||||
for _, tc := range tests {
|
||||
tt := tc
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
wg := new(sync.WaitGroup)
|
||||
wg.Add(len(tt.events))
|
||||
|
||||
|
@ -107,10 +110,6 @@ func Test_logSink_OnSubscribe(t *testing.T) {
|
|||
t.Errorf("not all events recorded in time")
|
||||
case <-wait.ForWaitGroupDone(wg):
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, scenario(tt))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
@ -56,6 +56,7 @@ var (
|
|||
)
|
||||
|
||||
func Test_writerCloserSink_OnSubscribe(t *testing.T) {
|
||||
t.Parallel()
|
||||
type testCase struct {
|
||||
name string
|
||||
events []*audit.Event
|
||||
|
@ -70,13 +71,14 @@ func Test_writerCloserSink_OnSubscribe(t *testing.T) {
|
|||
events: testEvents,
|
||||
},
|
||||
}
|
||||
scenario := func(tt testCase) func(t *testing.T) {
|
||||
return func(t *testing.T) {
|
||||
for _, tc := range tests {
|
||||
tt := tc
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
wg := new(sync.WaitGroup)
|
||||
wg.Add(len(tt.events))
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
t.Cleanup(ctrl.Finish)
|
||||
|
||||
writerMock := audit_mock.NewMockWriter(ctrl)
|
||||
writerMock.
|
||||
|
@ -111,10 +113,6 @@ func Test_writerCloserSink_OnSubscribe(t *testing.T) {
|
|||
t.Errorf("not all events recorded in time")
|
||||
case <-wait.ForWaitGroupDone(wg):
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, scenario(tt))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,13 +1,12 @@
|
|||
// Code generated by protoc-gen-go. DO NOT EDIT.
|
||||
// versions:
|
||||
// protoc-gen-go v1.25.0
|
||||
// protoc-gen-go v1.26.0
|
||||
// protoc v3.15.2
|
||||
// source: audit/v1/dns_details.proto
|
||||
|
||||
package v1
|
||||
|
||||
import (
|
||||
proto "github.com/golang/protobuf/proto"
|
||||
protoreflect "google.golang.org/protobuf/reflect/protoreflect"
|
||||
protoimpl "google.golang.org/protobuf/runtime/protoimpl"
|
||||
reflect "reflect"
|
||||
|
@ -21,10 +20,6 @@ const (
|
|||
_ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20)
|
||||
)
|
||||
|
||||
// This is a compile-time assertion that a sufficiently up-to-date version
|
||||
// of the legacy proto package is being used.
|
||||
const _ = proto.ProtoPackageIsVersion4
|
||||
|
||||
type DNSOpCode int32
|
||||
|
||||
const (
|
||||
|
|
|
@ -1,13 +1,12 @@
|
|||
// Code generated by protoc-gen-go. DO NOT EDIT.
|
||||
// versions:
|
||||
// protoc-gen-go v1.25.0
|
||||
// protoc-gen-go v1.26.0
|
||||
// protoc v3.15.2
|
||||
// source: audit/v1/event_entity.proto
|
||||
|
||||
package v1
|
||||
|
||||
import (
|
||||
proto "github.com/golang/protobuf/proto"
|
||||
protoreflect "google.golang.org/protobuf/reflect/protoreflect"
|
||||
protoimpl "google.golang.org/protobuf/runtime/protoimpl"
|
||||
anypb "google.golang.org/protobuf/types/known/anypb"
|
||||
|
@ -23,10 +22,6 @@ const (
|
|||
_ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20)
|
||||
)
|
||||
|
||||
// This is a compile-time assertion that a sufficiently up-to-date version
|
||||
// of the legacy proto package is being used.
|
||||
const _ = proto.ProtoPackageIsVersion4
|
||||
|
||||
type TransportProtocol int32
|
||||
|
||||
const (
|
||||
|
|
|
@ -1,13 +1,12 @@
|
|||
// Code generated by protoc-gen-go. DO NOT EDIT.
|
||||
// versions:
|
||||
// protoc-gen-go v1.25.0
|
||||
// protoc-gen-go v1.26.0
|
||||
// protoc v3.15.2
|
||||
// source: audit/v1/http_details.proto
|
||||
|
||||
package v1
|
||||
|
||||
import (
|
||||
proto "github.com/golang/protobuf/proto"
|
||||
protoreflect "google.golang.org/protobuf/reflect/protoreflect"
|
||||
protoimpl "google.golang.org/protobuf/runtime/protoimpl"
|
||||
reflect "reflect"
|
||||
|
@ -21,10 +20,6 @@ const (
|
|||
_ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20)
|
||||
)
|
||||
|
||||
// This is a compile-time assertion that a sufficiently up-to-date version
|
||||
// of the legacy proto package is being used.
|
||||
const _ = proto.ProtoPackageIsVersion4
|
||||
|
||||
type HTTPMethod int32
|
||||
|
||||
const (
|
||||
|
|
|
@ -20,6 +20,7 @@ type WriterCloserSyncer interface {
|
|||
}
|
||||
|
||||
func Test_eventWriter_Write(t *testing.T) {
|
||||
t.Parallel()
|
||||
type args struct {
|
||||
evs []*audit.Event
|
||||
}
|
||||
|
@ -44,10 +45,11 @@ func Test_eventWriter_Write(t *testing.T) {
|
|||
wantErr: false,
|
||||
},
|
||||
}
|
||||
scenario := func(tt testCase) func(t *testing.T) {
|
||||
return func(t *testing.T) {
|
||||
for _, tc := range tests {
|
||||
tt := tc
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctrl := gomock.NewController(t)
|
||||
t.Cleanup(ctrl.Finish)
|
||||
writerMock := audit_mock.NewMockWriterCloserSyncer(ctrl)
|
||||
calls := make([]*gomock.Call, 0)
|
||||
for i := 0; i < len(tt.args.evs); i++ {
|
||||
|
@ -79,10 +81,6 @@ func Test_eventWriter_Write(t *testing.T) {
|
|||
t.Errorf("Write() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, scenario(tt))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
@ -3,6 +3,7 @@ package cert
|
|||
import "testing"
|
||||
|
||||
func Test_extractIPFromAddress(t *testing.T) {
|
||||
t.Parallel()
|
||||
type args struct {
|
||||
addr string
|
||||
}
|
||||
|
@ -30,8 +31,10 @@ func Test_extractIPFromAddress(t *testing.T) {
|
|||
},
|
||||
},
|
||||
}
|
||||
scenario := func(tt testCase) func(t *testing.T) {
|
||||
return func(t *testing.T) {
|
||||
for _, tc := range tests {
|
||||
tt := tc
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
got, err := extractIPFromAddress(tt.args.addr)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("extractIPFromAddress() error = %v, wantErr %v", err, tt.wantErr)
|
||||
|
@ -40,9 +43,6 @@ func Test_extractIPFromAddress(t *testing.T) {
|
|||
if got != tt.want {
|
||||
t.Errorf("extractIPFromAddress() got = %v, want %v", got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, scenario(tt))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
@ -6,11 +6,11 @@ import (
|
|||
"fmt"
|
||||
"os"
|
||||
"path"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/maxatome/go-testdeep/td"
|
||||
|
||||
certmock "gitlab.com/inetmock/inetmock/internal/mock/cert"
|
||||
)
|
||||
|
@ -27,10 +27,9 @@ var (
|
|||
)
|
||||
|
||||
func Test_certShouldBeRenewed(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
t.Parallel()
|
||||
type args struct {
|
||||
timeSource TimeSource
|
||||
timeSourceSetup func(ctrl *gomock.Controller) TimeSource
|
||||
cert *x509.Certificate
|
||||
}
|
||||
type testCase struct {
|
||||
|
@ -42,7 +41,7 @@ func Test_certShouldBeRenewed(t *testing.T) {
|
|||
{
|
||||
name: "Expect cert should not be renewed right after creation",
|
||||
args: args{
|
||||
timeSource: func() TimeSource {
|
||||
timeSourceSetup: func(ctrl *gomock.Controller) TimeSource {
|
||||
tsMock := certmock.NewMockTimeSource(ctrl)
|
||||
tsMock.
|
||||
EXPECT().
|
||||
|
@ -50,7 +49,7 @@ func Test_certShouldBeRenewed(t *testing.T) {
|
|||
Return(time.Now().UTC()).
|
||||
Times(1)
|
||||
return tsMock
|
||||
}(),
|
||||
},
|
||||
cert: &x509.Certificate{
|
||||
NotAfter: time.Now().UTC().Add(serverRelativeValidity),
|
||||
NotBefore: time.Now().UTC().Add(-serverRelativeValidity),
|
||||
|
@ -61,7 +60,7 @@ func Test_certShouldBeRenewed(t *testing.T) {
|
|||
{
|
||||
name: "Expect cert should be renewed if the remaining lifetime is less than a quarter of the total lifetime",
|
||||
args: args{
|
||||
timeSource: func() TimeSource {
|
||||
timeSourceSetup: func(ctrl *gomock.Controller) TimeSource {
|
||||
tsMock := certmock.NewMockTimeSource(ctrl)
|
||||
tsMock.
|
||||
EXPECT().
|
||||
|
@ -69,7 +68,7 @@ func Test_certShouldBeRenewed(t *testing.T) {
|
|||
Return(time.Now().UTC().Add(serverRelativeValidity/2 + 1*time.Hour)).
|
||||
Times(1)
|
||||
return tsMock
|
||||
}(),
|
||||
},
|
||||
cert: &x509.Certificate{
|
||||
NotAfter: time.Now().UTC().Add(serverRelativeValidity),
|
||||
NotBefore: time.Now().UTC().Add(-serverRelativeValidity),
|
||||
|
@ -78,50 +77,21 @@ func Test_certShouldBeRenewed(t *testing.T) {
|
|||
want: true,
|
||||
},
|
||||
}
|
||||
scenario := func(tt testCase) func(t *testing.T) {
|
||||
return func(t *testing.T) {
|
||||
if got := certShouldBeRenewed(tt.args.timeSource, tt.args.cert); got != tt.want {
|
||||
for _, tc := range tests {
|
||||
tt := tc
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctrl := gomock.NewController(t)
|
||||
if got := certShouldBeRenewed(tt.args.timeSourceSetup(ctrl), tt.args.cert); got != tt.want {
|
||||
t.Errorf("certShouldBeRenewed() = %v, want %v", got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, scenario(tt))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
//nolint:funlen
|
||||
func Test_fileSystemCache_Get(t *testing.T) {
|
||||
type fields struct {
|
||||
certCachePath string
|
||||
inMemCache map[string]*tls.Certificate
|
||||
timeSource TimeSource
|
||||
}
|
||||
type args struct {
|
||||
cn string
|
||||
}
|
||||
type testCase struct {
|
||||
name string
|
||||
fields fields
|
||||
args args
|
||||
wantOk bool
|
||||
}
|
||||
tests := []testCase{
|
||||
{
|
||||
name: "Get a miss when no cert is present",
|
||||
fields: fields{
|
||||
certCachePath: os.TempDir(),
|
||||
inMemCache: make(map[string]*tls.Certificate),
|
||||
timeSource: NewTimeSource(),
|
||||
},
|
||||
args: args{
|
||||
cnLocalhost,
|
||||
},
|
||||
wantOk: false,
|
||||
},
|
||||
{
|
||||
name: "Get a prepared certificate from the memory cache",
|
||||
fields: func() fields {
|
||||
t.Parallel()
|
||||
|
||||
certGen := setupCertGen()
|
||||
|
||||
caCrt, _ := certGen.CACert(GenerationOptions{
|
||||
|
@ -132,14 +102,40 @@ func Test_fileSystemCache_Get(t *testing.T) {
|
|||
CommonName: cnLocalhost,
|
||||
}, caCrt)
|
||||
|
||||
return fields{
|
||||
certCachePath: os.TempDir(),
|
||||
type fields struct {
|
||||
inMemCache map[string]*tls.Certificate
|
||||
timeSource TimeSource
|
||||
}
|
||||
type args struct {
|
||||
cn string
|
||||
}
|
||||
type testCase struct {
|
||||
name string
|
||||
fields fields
|
||||
polluteCertCache bool
|
||||
args args
|
||||
wantOk bool
|
||||
}
|
||||
tests := []testCase{
|
||||
{
|
||||
name: "Get a miss when no cert is present",
|
||||
fields: fields{
|
||||
inMemCache: make(map[string]*tls.Certificate),
|
||||
timeSource: NewTimeSource(),
|
||||
},
|
||||
args: args{
|
||||
cnLocalhost,
|
||||
},
|
||||
wantOk: false,
|
||||
},
|
||||
{
|
||||
name: "Get a prepared certificate from the memory cache",
|
||||
fields: fields{
|
||||
inMemCache: map[string]*tls.Certificate{
|
||||
cnLocalhost: srvCrt,
|
||||
},
|
||||
timeSource: NewTimeSource(),
|
||||
}
|
||||
}(),
|
||||
},
|
||||
args: args{
|
||||
cn: cnLocalhost,
|
||||
},
|
||||
|
@ -147,67 +143,50 @@ func Test_fileSystemCache_Get(t *testing.T) {
|
|||
},
|
||||
{
|
||||
name: "Get a prepared certificate from the file system",
|
||||
fields: func() fields {
|
||||
certGen := setupCertGen()
|
||||
|
||||
caCrt, _ := certGen.CACert(GenerationOptions{
|
||||
CommonName: "INetMock",
|
||||
})
|
||||
|
||||
srvCrt, _ := certGen.ServerCert(GenerationOptions{
|
||||
CommonName: serverCN,
|
||||
}, caCrt)
|
||||
|
||||
pem := NewPEM(srvCrt)
|
||||
if err := pem.Write(serverCN, os.TempDir()); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
t.Cleanup(func() {
|
||||
for _, f := range []string{
|
||||
path.Join(os.TempDir(), fmt.Sprintf("%s.pem", serverCN)),
|
||||
path.Join(os.TempDir(), fmt.Sprintf("%s.key", serverCN)),
|
||||
} {
|
||||
_ = os.Remove(f)
|
||||
}
|
||||
})
|
||||
|
||||
return fields{
|
||||
certCachePath: os.TempDir(),
|
||||
fields: fields{
|
||||
inMemCache: make(map[string]*tls.Certificate),
|
||||
timeSource: NewTimeSource(),
|
||||
}
|
||||
}(),
|
||||
},
|
||||
args: args{
|
||||
cn: serverCN,
|
||||
},
|
||||
polluteCertCache: true,
|
||||
wantOk: true,
|
||||
},
|
||||
}
|
||||
scenario := func(tt testCase) func(t *testing.T) {
|
||||
return func(t *testing.T) {
|
||||
for _, tc := range tests {
|
||||
tt := tc
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
dir := t.TempDir()
|
||||
f := &fileSystemCache{
|
||||
certCachePath: tt.fields.certCachePath,
|
||||
certCachePath: dir,
|
||||
inMemCache: tt.fields.inMemCache,
|
||||
timeSource: tt.fields.timeSource,
|
||||
}
|
||||
|
||||
if tt.polluteCertCache {
|
||||
pem := NewPEM(srvCrt)
|
||||
if err := pem.Write(serverCN, dir); err != nil {
|
||||
t.Fatalf("polluteCertCache error = %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
gotCrt, gotOk := f.Get(tt.args.cn)
|
||||
|
||||
if gotOk && (gotCrt == nil || !reflect.DeepEqual(reflect.TypeOf(new(tls.Certificate)), reflect.TypeOf(gotCrt))) {
|
||||
if gotOk && (gotCrt == nil || !td.CmpIsa(t, gotCrt, new(tls.Certificate))) {
|
||||
t.Errorf("Wanted propert certificate but got %v", gotCrt)
|
||||
}
|
||||
|
||||
if gotOk != tt.wantOk {
|
||||
t.Errorf("Get() gotOk = %v, want %v", gotOk, tt.wantOk)
|
||||
}
|
||||
}
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, scenario(tt))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_fileSystemCache_Put(t *testing.T) {
|
||||
t.Parallel()
|
||||
type fields struct {
|
||||
certCachePath string
|
||||
inMemCache map[string]*tls.Certificate
|
||||
|
@ -281,8 +260,10 @@ func Test_fileSystemCache_Put(t *testing.T) {
|
|||
wantErr: false,
|
||||
},
|
||||
}
|
||||
scenario := func(tt testCase) func(t *testing.T) {
|
||||
return func(t *testing.T) {
|
||||
for _, tc := range tests {
|
||||
tt := tc
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
f := &fileSystemCache{
|
||||
certCachePath: tt.fields.certCachePath,
|
||||
inMemCache: tt.fields.inMemCache,
|
||||
|
@ -291,10 +272,7 @@ func Test_fileSystemCache_Put(t *testing.T) {
|
|||
if err := f.Put(tt.args.cert); (err != nil) != tt.wantErr {
|
||||
t.Errorf("Put() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
}
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, scenario(tt))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -1,11 +1,13 @@
|
|||
package cert
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/maxatome/go-testdeep/td"
|
||||
)
|
||||
|
||||
func Test_certOptionsDefaulter(t *testing.T) {
|
||||
t.Parallel()
|
||||
type testCase struct {
|
||||
name string
|
||||
arg GenerationOptions
|
||||
|
@ -60,17 +62,14 @@ func Test_certOptionsDefaulter(t *testing.T) {
|
|||
},
|
||||
},
|
||||
}
|
||||
scenario := func(tt testCase) func(t *testing.T) {
|
||||
return func(t *testing.T) {
|
||||
for _, tc := range tests {
|
||||
tt := tc
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
if err := applyDefaultGenerationOptions(&tt.arg); err != nil {
|
||||
t.Errorf("applyDefaultGenerationOptions() error = %v", err)
|
||||
}
|
||||
if !reflect.DeepEqual(tt.expected, tt.arg) {
|
||||
t.Errorf("Apply defaulter expected=%v got=%v", tt.expected, tt.arg)
|
||||
}
|
||||
}
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, scenario(tt))
|
||||
td.Cmp(t, tt.arg, tt.expected)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
@ -30,12 +30,10 @@ type Generator interface {
|
|||
ServerCert(options GenerationOptions, ca *tls.Certificate) (*tls.Certificate, error)
|
||||
}
|
||||
|
||||
//nolint:gocritic
|
||||
func NewDefaultGenerator(options Options) Generator {
|
||||
return NewGenerator(options, NewTimeSource(), defaultKeyProvider(options))
|
||||
}
|
||||
|
||||
//nolint:gocritic
|
||||
func NewGenerator(options Options, source TimeSource, provider KeyProvider) Generator {
|
||||
return &generator{
|
||||
options: options,
|
||||
|
|
|
@ -1,10 +0,0 @@
|
|||
package cert
|
||||
|
||||
/*func init() {
|
||||
config.AddDefaultValue(certCachePathConfigKey, os.TempDir())
|
||||
config.AddDefaultValue(ecdsaCurveConfigKey, string(CurveTypeED25519))
|
||||
config.AddDefaultValue(caCertValidityNotBeforeKey, defaultCAValidityDuration)
|
||||
config.AddDefaultValue(caCertValidityNotAfterKey, defaultCAValidityDuration)
|
||||
config.AddDefaultValue(serverCertValidityNotBeforeKey, defaultServerValidityDuration)
|
||||
config.AddDefaultValue(serverCertValidityNotAfterKey, defaultServerValidityDuration)
|
||||
}*/
|
|
@ -1,131 +0,0 @@
|
|||
package cert
|
||||
|
||||
import (
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/spf13/viper"
|
||||
)
|
||||
|
||||
func readViper(cfg string) *viper.Viper {
|
||||
vpr := viper.New()
|
||||
vpr.SetConfigType("yaml")
|
||||
if err := vpr.ReadConfig(strings.NewReader(cfg)); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return vpr
|
||||
}
|
||||
|
||||
//nolint:funlen
|
||||
func Test_loadFromConfig(t *testing.T) {
|
||||
type args struct {
|
||||
config *viper.Viper
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want Options
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "Parse valid TLS configuration",
|
||||
wantErr: false,
|
||||
args: args{
|
||||
config: readViper(`
|
||||
tls:
|
||||
ecdsaCurve: P256
|
||||
validity:
|
||||
ca:
|
||||
notBeforeRelative: 17520h
|
||||
notAfterRelative: 17520h
|
||||
server:
|
||||
NotBeforeRelative: 168h
|
||||
NotAfterRelative: 168h
|
||||
rootCaCert:
|
||||
publicKey: ./ca.pem
|
||||
privateKey: ./ca.key
|
||||
certCachePath: /tmp/inetmock/
|
||||
`),
|
||||
},
|
||||
want: Options{
|
||||
RootCACert: File{
|
||||
PublicKeyPath: "./ca.pem",
|
||||
PrivateKeyPath: "./ca.key",
|
||||
},
|
||||
CertCachePath: "/tmp/inetmock/",
|
||||
Curve: CurveTypeP256,
|
||||
Validity: ValidityByPurpose{
|
||||
CA: ValidityDuration{
|
||||
NotBeforeRelative: 17520 * time.Hour,
|
||||
NotAfterRelative: 17520 * time.Hour,
|
||||
},
|
||||
Server: ValidityDuration{
|
||||
NotBeforeRelative: 168 * time.Hour,
|
||||
NotAfterRelative: 168 * time.Hour,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Get an error if CA public key path is missing",
|
||||
args: args{
|
||||
readViper(`
|
||||
tls:
|
||||
rootCaCert:
|
||||
privateKey: ./ca.key
|
||||
`),
|
||||
},
|
||||
want: Options{},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "Get an error if CA private key path is missing",
|
||||
args: args{
|
||||
readViper(`
|
||||
tls:
|
||||
rootCaCert:
|
||||
publicKey: ./ca.pem
|
||||
`),
|
||||
},
|
||||
want: Options{},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "Get default options if all required fields are set",
|
||||
args: args{
|
||||
readViper(`
|
||||
tls:
|
||||
rootCaCert:
|
||||
publicKey: ./ca.pem
|
||||
privateKey: ./ca.key
|
||||
`),
|
||||
},
|
||||
want: Options{
|
||||
RootCACert: File{
|
||||
PublicKeyPath: "./ca.pem",
|
||||
PrivateKeyPath: "./ca.key",
|
||||
},
|
||||
CertCachePath: os.TempDir(),
|
||||
Curve: CurveTypeED25519,
|
||||
Validity: ValidityByPurpose{
|
||||
CA: ValidityDuration{
|
||||
NotBeforeRelative: 17520 * time.Hour,
|
||||
NotAfterRelative: 17520 * time.Hour,
|
||||
},
|
||||
Server: ValidityDuration{
|
||||
NotBeforeRelative: 168 * time.Hour,
|
||||
NotAfterRelative: 168 * time.Hour,
|
||||
},
|
||||
},
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
|
||||
})
|
||||
}
|
||||
}
|
|
@ -33,7 +33,6 @@ type Store interface {
|
|||
TLSConfig() *tls.Config
|
||||
}
|
||||
|
||||
//nolint:gocritic
|
||||
func NewDefaultStore(
|
||||
options Options,
|
||||
logger logging.Logger,
|
||||
|
@ -47,7 +46,6 @@ func NewDefaultStore(
|
|||
)
|
||||
}
|
||||
|
||||
//nolint:gocritic
|
||||
func NewStore(
|
||||
options Options,
|
||||
cache Cache,
|
||||
|
@ -149,7 +147,6 @@ func (s *store) GetCertificate(serverName, ip string) (cert *tls.Certificate, er
|
|||
return
|
||||
}
|
||||
|
||||
//nolint:gocritic
|
||||
func privateKeyForCurve(options Options) (privateKey interface{}, err error) {
|
||||
switch options.Curve {
|
||||
case CurveTypeP224:
|
||||
|
|
|
@ -1,12 +1,14 @@
|
|||
package health
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/maxatome/go-testdeep/td"
|
||||
)
|
||||
|
||||
//nolint:funlen
|
||||
func Test_checker_IsHealthy(t *testing.T) {
|
||||
t.Parallel()
|
||||
type fields struct {
|
||||
componentChecks map[string]Check
|
||||
}
|
||||
|
@ -145,17 +147,15 @@ func Test_checker_IsHealthy(t *testing.T) {
|
|||
},
|
||||
},
|
||||
}
|
||||
scenario := func(tt testCase) func(t *testing.T) {
|
||||
return func(t *testing.T) {
|
||||
for _, tc := range tests {
|
||||
tt := tc
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
c := &checker{
|
||||
componentChecks: tt.fields.componentChecks,
|
||||
}
|
||||
if gotR := c.IsHealthy(); !reflect.DeepEqual(gotR, tt.wantR) {
|
||||
t.Errorf("IsHealthy() = %v, want %v", gotR, tt.wantR)
|
||||
}
|
||||
}
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, scenario(tt))
|
||||
gotR := c.IsHealthy()
|
||||
td.Cmp(t, gotR, tt.wantR)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
@ -50,6 +50,7 @@ func CreateLogger() (Logger, error) {
|
|||
}
|
||||
|
||||
func CreateTestLogger(tb testing.TB) Logger {
|
||||
tb.Helper()
|
||||
logger := &testLogger{
|
||||
testRunning: make(chan struct{}),
|
||||
tb: tb,
|
||||
|
|
|
@ -9,6 +9,7 @@ import (
|
|||
|
||||
//nolint:funlen
|
||||
func TestParseLevel(t *testing.T) {
|
||||
t.Parallel()
|
||||
type args struct {
|
||||
levelString string
|
||||
}
|
||||
|
@ -103,19 +104,19 @@ func TestParseLevel(t *testing.T) {
|
|||
want: zap.NewAtomicLevelAt(zap.InfoLevel),
|
||||
},
|
||||
}
|
||||
scenario := func(tt testCase) func(t *testing.T) {
|
||||
return func(t *testing.T) {
|
||||
for _, tc := range tests {
|
||||
tt := tc
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
if got := ParseLevel(tt.args.levelString); !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("ParseLevel() = %v, want %v", got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, scenario(tt))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigureLogging(t *testing.T) {
|
||||
t.Parallel()
|
||||
type args struct {
|
||||
level zap.AtomicLevel
|
||||
developmentLogging bool
|
||||
|
@ -151,8 +152,10 @@ func TestConfigureLogging(t *testing.T) {
|
|||
},
|
||||
},
|
||||
}
|
||||
scenario := func(tt testCase) func(t *testing.T) {
|
||||
return func(t *testing.T) {
|
||||
for _, tc := range tests {
|
||||
tt := tc
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ConfigureLogging(tt.args.level, tt.args.developmentLogging, tt.args.initialFields)
|
||||
if loggingConfig.Development != tt.args.developmentLogging {
|
||||
t.Errorf("loggingConfig.Development = %t, want %t", loggingConfig.Development, tt.args.developmentLogging)
|
||||
|
@ -168,9 +171,6 @@ func TestConfigureLogging(t *testing.T) {
|
|||
t.Errorf("loggingConfig.InitialFields = %v, want %v", loggingConfig.InitialFields, tt.args.initialFields)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, scenario(tt))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,58 +1,47 @@
|
|||
package path
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"math/rand"
|
||||
"os"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/maxatome/go-testdeep/td"
|
||||
)
|
||||
|
||||
func TestFileExists(t *testing.T) {
|
||||
tmpFile, err := ioutil.TempFile("", "inetmock")
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("failed to create temp file: %v", err)
|
||||
}
|
||||
defer func() {
|
||||
tmpFile.Close()
|
||||
os.Remove(tmpFile.Name())
|
||||
}()
|
||||
|
||||
type args struct {
|
||||
filename string
|
||||
}
|
||||
t.Parallel()
|
||||
type testCase struct {
|
||||
name string
|
||||
args args
|
||||
want bool
|
||||
}
|
||||
tests := []testCase{
|
||||
{
|
||||
name: "Ensure temp file exists",
|
||||
want: true,
|
||||
args: args{
|
||||
filename: tmpFile.Name(),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Ensure random file name does not exist",
|
||||
want: false,
|
||||
args: args{
|
||||
//nolint:gosec
|
||||
filename: path.Join(os.TempDir(), fmt.Sprintf("asdf-%d", rand.Uint32())),
|
||||
},
|
||||
},
|
||||
}
|
||||
scenario := func(tt testCase) func(t *testing.T) {
|
||||
return func(t *testing.T) {
|
||||
if got := FileExists(tt.args.filename); got != tt.want {
|
||||
for _, tc := range tests {
|
||||
tt := tc
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
tempDir := t.TempDir()
|
||||
var filePath = filepath.Join(tempDir, "nonexistent")
|
||||
if tt.want {
|
||||
tmpFile, err := os.CreateTemp(tempDir, "testFileExists.*.tmp")
|
||||
if td.CmpNoError(t, err) {
|
||||
t.Cleanup(func() {
|
||||
_ = tmpFile.Close()
|
||||
})
|
||||
filePath = filepath.Join(tempDir, filepath.Base(tmpFile.Name()))
|
||||
}
|
||||
}
|
||||
if got := FileExists(filePath); got != tt.want {
|
||||
t.Errorf("FileExists() = %v, want %v", got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, scenario(tt))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,13 +1,12 @@
|
|||
// Code generated by protoc-gen-go. DO NOT EDIT.
|
||||
// versions:
|
||||
// protoc-gen-go v1.25.0
|
||||
// protoc-gen-go v1.26.0
|
||||
// protoc v3.15.2
|
||||
// source: rpc/v1/audit.proto
|
||||
|
||||
package v1
|
||||
|
||||
import (
|
||||
proto "github.com/golang/protobuf/proto"
|
||||
v1 "gitlab.com/inetmock/inetmock/pkg/audit/v1"
|
||||
protoreflect "google.golang.org/protobuf/reflect/protoreflect"
|
||||
protoimpl "google.golang.org/protobuf/runtime/protoimpl"
|
||||
|
@ -22,10 +21,6 @@ const (
|
|||
_ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20)
|
||||
)
|
||||
|
||||
// This is a compile-time assertion that a sufficiently up-to-date version
|
||||
// of the legacy proto package is being used.
|
||||
const _ = proto.ProtoPackageIsVersion4
|
||||
|
||||
type WatchEventsRequest struct {
|
||||
state protoimpl.MessageState
|
||||
sizeCache protoimpl.SizeCache
|
||||
|
|
|
@ -11,6 +11,7 @@ import (
|
|||
|
||||
// This is a compile-time assertion to ensure that this generated file
|
||||
// is compatible with the grpc package it is being compiled against.
|
||||
// Requires gRPC-Go v1.32.0 or later.
|
||||
const _ = grpc.SupportPackageIsVersion7
|
||||
|
||||
// AuditServiceClient is the client API for AuditService service.
|
||||
|
@ -32,7 +33,7 @@ func NewAuditServiceClient(cc grpc.ClientConnInterface) AuditServiceClient {
|
|||
}
|
||||
|
||||
func (c *auditServiceClient) WatchEvents(ctx context.Context, in *WatchEventsRequest, opts ...grpc.CallOption) (AuditService_WatchEventsClient, error) {
|
||||
stream, err := c.cc.NewStream(ctx, &_AuditService_serviceDesc.Streams[0], "/inetmock.rpc.v1.AuditService/WatchEvents", opts...)
|
||||
stream, err := c.cc.NewStream(ctx, &AuditService_ServiceDesc.Streams[0], "/inetmock.rpc.v1.AuditService/WatchEvents", opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -127,7 +128,7 @@ type UnsafeAuditServiceServer interface {
|
|||
}
|
||||
|
||||
func RegisterAuditServiceServer(s grpc.ServiceRegistrar, srv AuditServiceServer) {
|
||||
s.RegisterService(&_AuditService_serviceDesc, srv)
|
||||
s.RegisterService(&AuditService_ServiceDesc, srv)
|
||||
}
|
||||
|
||||
func _AuditService_WatchEvents_Handler(srv interface{}, stream grpc.ServerStream) error {
|
||||
|
@ -205,7 +206,10 @@ func _AuditService_ListSinks_Handler(srv interface{}, ctx context.Context, dec f
|
|||
return interceptor(ctx, in, info, handler)
|
||||
}
|
||||
|
||||
var _AuditService_serviceDesc = grpc.ServiceDesc{
|
||||
// AuditService_ServiceDesc is the grpc.ServiceDesc for AuditService service.
|
||||
// It's only intended for direct use with grpc.RegisterService,
|
||||
// and not to be introspected or modified (even as a copy)
|
||||
var AuditService_ServiceDesc = grpc.ServiceDesc{
|
||||
ServiceName: "inetmock.rpc.v1.AuditService",
|
||||
HandlerType: (*AuditServiceServer)(nil),
|
||||
Methods: []grpc.MethodDesc{
|
||||
|
|
|
@ -1,13 +1,12 @@
|
|||
// Code generated by protoc-gen-go. DO NOT EDIT.
|
||||
// versions:
|
||||
// protoc-gen-go v1.25.0
|
||||
// protoc-gen-go v1.26.0
|
||||
// protoc v3.15.2
|
||||
// source: rpc/v1/health.proto
|
||||
|
||||
package v1
|
||||
|
||||
import (
|
||||
proto "github.com/golang/protobuf/proto"
|
||||
protoreflect "google.golang.org/protobuf/reflect/protoreflect"
|
||||
protoimpl "google.golang.org/protobuf/runtime/protoimpl"
|
||||
reflect "reflect"
|
||||
|
@ -21,10 +20,6 @@ const (
|
|||
_ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20)
|
||||
)
|
||||
|
||||
// This is a compile-time assertion that a sufficiently up-to-date version
|
||||
// of the legacy proto package is being used.
|
||||
const _ = proto.ProtoPackageIsVersion4
|
||||
|
||||
type HealthState int32
|
||||
|
||||
const (
|
||||
|
|
|
@ -11,6 +11,7 @@ import (
|
|||
|
||||
// This is a compile-time assertion to ensure that this generated file
|
||||
// is compatible with the grpc package it is being compiled against.
|
||||
// Requires gRPC-Go v1.32.0 or later.
|
||||
const _ = grpc.SupportPackageIsVersion7
|
||||
|
||||
// HealthServiceClient is the client API for HealthService service.
|
||||
|
@ -62,7 +63,7 @@ type UnsafeHealthServiceServer interface {
|
|||
}
|
||||
|
||||
func RegisterHealthServiceServer(s grpc.ServiceRegistrar, srv HealthServiceServer) {
|
||||
s.RegisterService(&_HealthService_serviceDesc, srv)
|
||||
s.RegisterService(&HealthService_ServiceDesc, srv)
|
||||
}
|
||||
|
||||
func _HealthService_GetHealth_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
||||
|
@ -83,7 +84,10 @@ func _HealthService_GetHealth_Handler(srv interface{}, ctx context.Context, dec
|
|||
return interceptor(ctx, in, info, handler)
|
||||
}
|
||||
|
||||
var _HealthService_serviceDesc = grpc.ServiceDesc{
|
||||
// HealthService_ServiceDesc is the grpc.ServiceDesc for HealthService service.
|
||||
// It's only intended for direct use with grpc.RegisterService,
|
||||
// and not to be introspected or modified (even as a copy)
|
||||
var HealthService_ServiceDesc = grpc.ServiceDesc{
|
||||
ServiceName: "inetmock.rpc.v1.HealthService",
|
||||
HandlerType: (*HealthServiceServer)(nil),
|
||||
Methods: []grpc.MethodDesc{
|
||||
|
|
|
@ -1,13 +1,12 @@
|
|||
// Code generated by protoc-gen-go. DO NOT EDIT.
|
||||
// versions:
|
||||
// protoc-gen-go v1.25.0
|
||||
// protoc-gen-go v1.26.0
|
||||
// protoc v3.15.2
|
||||
// source: rpc/v1/pcap.proto
|
||||
|
||||
package v1
|
||||
|
||||
import (
|
||||
proto "github.com/golang/protobuf/proto"
|
||||
protoreflect "google.golang.org/protobuf/reflect/protoreflect"
|
||||
protoimpl "google.golang.org/protobuf/runtime/protoimpl"
|
||||
durationpb "google.golang.org/protobuf/types/known/durationpb"
|
||||
|
@ -22,10 +21,6 @@ const (
|
|||
_ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20)
|
||||
)
|
||||
|
||||
// This is a compile-time assertion that a sufficiently up-to-date version
|
||||
// of the legacy proto package is being used.
|
||||
const _ = proto.ProtoPackageIsVersion4
|
||||
|
||||
type ListAvailableDevicesRequest struct {
|
||||
state protoimpl.MessageState
|
||||
sizeCache protoimpl.SizeCache
|
||||
|
|
|
@ -11,6 +11,7 @@ import (
|
|||
|
||||
// This is a compile-time assertion to ensure that this generated file
|
||||
// is compatible with the grpc package it is being compiled against.
|
||||
// Requires gRPC-Go v1.32.0 or later.
|
||||
const _ = grpc.SupportPackageIsVersion7
|
||||
|
||||
// PCAPServiceClient is the client API for PCAPService service.
|
||||
|
@ -104,7 +105,7 @@ type UnsafePCAPServiceServer interface {
|
|||
}
|
||||
|
||||
func RegisterPCAPServiceServer(s grpc.ServiceRegistrar, srv PCAPServiceServer) {
|
||||
s.RegisterService(&_PCAPService_serviceDesc, srv)
|
||||
s.RegisterService(&PCAPService_ServiceDesc, srv)
|
||||
}
|
||||
|
||||
func _PCAPService_ListAvailableDevices_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
||||
|
@ -179,7 +180,10 @@ func _PCAPService_StopPCAPFileRecording_Handler(srv interface{}, ctx context.Con
|
|||
return interceptor(ctx, in, info, handler)
|
||||
}
|
||||
|
||||
var _PCAPService_serviceDesc = grpc.ServiceDesc{
|
||||
// PCAPService_ServiceDesc is the grpc.ServiceDesc for PCAPService service.
|
||||
// It's only intended for direct use with grpc.RegisterService,
|
||||
// and not to be introspected or modified (even as a copy)
|
||||
var PCAPService_ServiceDesc = grpc.ServiceDesc{
|
||||
ServiceName: "inetmock.rpc.v1.PCAPService",
|
||||
HandlerType: (*PCAPServiceServer)(nil),
|
||||
Methods: []grpc.MethodDesc{
|
||||
|
|
Loading…
Add table
Reference in a new issue