diff --git a/Makefile b/Makefile index 5e073af..3fbf645 100644 --- a/Makefile +++ b/Makefile @@ -2,7 +2,8 @@ VERSION = $(shell git describe --dirty --tags --always) DIR = $(dir $(realpath $(firstword $(MAKEFILE_LIST)))) BUILD_PATH = $(DIR)/main.go PKGS = $(shell go list ./...) -TEST_PKGS = $(shell find . -type f -name "*_test.go" -not -path "./plugins/*" -not -path "*/mock/*" -printf '%h\n' | sort -u) +TEST_PKGS = $(shell go list ./...) +PROTO_FILES = $(shell find $(DIR)api/ -type f -name "*.proto") GOARGS = GOOS=linux GOARCH=amd64 GO_BUILD_ARGS = -ldflags="-w -s" GO_CONTAINER_BUILD_ARGS = -ldflags="-w -s" -a -installsuffix cgo @@ -14,7 +15,7 @@ DEBUG_ARGS?= --development-logs=true CONTAINER_BUILDER ?= podman DOCKER_IMAGE ?= inetmock -.PHONY: clean all format deps update-deps compile debug generate snapshot-release test cli-cover-report html-cover-report plugins $(PLUGINS) $(GO_GEN_FILES) +.PHONY: clean all format deps update-deps compile debug generate protoc snapshot-release test cli-cover-report html-cover-report plugins $(PLUGINS) $(GO_GEN_FILES) all: clean format compile test plugins clean: @@ -55,6 +56,9 @@ debug: generate: @go generate ./... +protoc: + @protoc --proto_path $(DIR)api/ --go_out=plugins=grpc:internal/grpc --go_opt=paths=source_relative $(shell find $(DIR)api/ -type f -name "*.proto") + snapshot-release: @goreleaser release --snapshot --skip-publish --rm-dist diff --git a/internal/cmd/ca.go b/internal/cmd/ca.go index 9c0c439..c4fb97f 100644 --- a/internal/cmd/ca.go +++ b/internal/cmd/ca.go @@ -4,6 +4,7 @@ import ( "crypto/tls" "crypto/x509" "github.com/baez90/inetmock/pkg/cert" + config2 "github.com/baez90/inetmock/pkg/config" "github.com/baez90/inetmock/pkg/logging" "github.com/spf13/cobra" "go.uber.org/zap" @@ -77,11 +78,11 @@ func runGenerateCA(cmd *cobra.Command, args []string) { zap.String(generateCACertOutPath, certOutPath), ) - generator := cert.NewDefaultGenerator(cert.Options{ + generator := cert.NewDefaultGenerator(config2.CertOptions{ CertCachePath: certOutPath, - Curve: cert.CurveType(curveName), - Validity: cert.ValidityByPurpose{ - CA: cert.ValidityDuration{ + Curve: config2.CurveType(curveName), + Validity: config2.ValidityByPurpose{ + CA: config2.ValidityDuration{ NotAfterRelative: notAfter, NotBeforeRelative: notBefore, }, diff --git a/internal/cmd/init.go b/internal/cmd/init.go deleted file mode 100644 index f25c678..0000000 --- a/internal/cmd/init.go +++ /dev/null @@ -1,70 +0,0 @@ -package cmd - -import ( - "github.com/baez90/inetmock/internal/endpoints" - "github.com/baez90/inetmock/internal/plugins" - "github.com/baez90/inetmock/pkg/api" - "github.com/baez90/inetmock/pkg/logging" - "github.com/baez90/inetmock/pkg/path" - "github.com/spf13/viper" - "go.uber.org/zap" - "os" - "time" -) - -var ( - appIsInitialized = false -) - -func initApp() (err error) { - if appIsInitialized { - return - } - appIsInitialized = true - logging.ConfigureLogging( - logging.ParseLevel(logLevel), - developmentLogs, - map[string]interface{}{"cwd": path.WorkingDirectory()}, - ) - logger, _ = logging.CreateLogger() - registry := plugins.Registry() - endpointManager = endpoints.NewEndpointManager(logger) - - rootCmd.Flags().ParseErrorsWhitelist.UnknownFlags = false - _ = rootCmd.ParseFlags(os.Args) - - if err = appConfig.ReadConfig(configFilePath); err != nil { - logger.Error( - "unrecoverable error occurred during reading the config file", - zap.Error(err), - ) - return - } - - viperInst := viper.GetViper() - pluginDir := viperInst.GetString("plugins-directory") - - if err = api.InitServices(viperInst, logger); err != nil { - logger.Error( - "failed to initialize app services", - zap.Error(err), - ) - } - - pluginLoadStartTime := time.Now() - if err = registry.LoadPlugins(pluginDir); err != nil { - logger.Error("Failed to load plugins", - zap.String("pluginsDirectory", pluginDir), - zap.Error(err), - ) - } - pluginLoadDuration := time.Since(pluginLoadStartTime) - logger.Info( - "loading plugins completed", - zap.Duration("pluginLoadDuration", pluginLoadDuration), - ) - - pluginsCmd.AddCommand(registry.PluginCommands()...) - - return -} diff --git a/internal/cmd/plugins.go b/internal/cmd/plugins.go index a4ef58d..4091505 100644 --- a/internal/cmd/plugins.go +++ b/internal/cmd/plugins.go @@ -5,10 +5,6 @@ import ( ) var ( - pluginsCmd *cobra.Command -) - -func init() { pluginsCmd = &cobra.Command{ Use: "plugins", Short: "Use the plugins prefix to interact with commands that are provided by plugins", @@ -18,5 +14,4 @@ The easiest way to explore what commands are available is to start with 'inetmoc This help page contains a list of available sub-commands starting with the name of the plugin as a prefix. `, } - rootCmd.AddCommand(pluginsCmd, generateCaCmd) -} +) diff --git a/internal/cmd/root.go b/internal/cmd/root.go index bdab14b..f0c8550 100644 --- a/internal/cmd/root.go +++ b/internal/cmd/root.go @@ -1,77 +1,91 @@ package cmd import ( - "github.com/baez90/inetmock/internal/config" - "github.com/baez90/inetmock/internal/endpoints" + "github.com/baez90/inetmock/internal/plugins" "github.com/baez90/inetmock/pkg/api" + "github.com/baez90/inetmock/pkg/config" "github.com/baez90/inetmock/pkg/logging" + "github.com/baez90/inetmock/pkg/path" "github.com/spf13/cobra" - "github.com/spf13/viper" "go.uber.org/zap" - "os" - "os/signal" - "strings" - "syscall" + "time" ) var ( logger logging.Logger - rootCmd = cobra.Command{ - Use: "", - Short: "INetMock is lightweight internet mock", - Run: startInetMock, - } + rootCmd *cobra.Command - configFilePath string - logLevel string - developmentLogs bool - handlers []api.ProtocolHandler - appConfig = config.CreateConfig() - endpointManager endpoints.EndpointManager + pluginsDirectory string + configFilePath string + logLevel string + developmentLogs bool ) func init() { - rootCmd.PersistentFlags().String("plugins-directory", "", "Directory where plugins should be loaded from") + cobra.OnInitialize(onInit) + rootCmd = &cobra.Command{ + Use: "", + Short: "INetMock is lightweight internet mock", + } + + rootCmd.PersistentFlags().StringVar(&pluginsDirectory, "plugins-directory", "", "Directory where plugins should be loaded from") rootCmd.PersistentFlags().StringVar(&configFilePath, "config", "", "Path to config file that should be used") rootCmd.PersistentFlags().StringVar(&logLevel, "log-level", "info", "logging level to use") rootCmd.PersistentFlags().BoolVar(&developmentLogs, "development-logs", false, "Enable development mode logs") - appConfig.InitConfig(rootCmd.PersistentFlags()) + rootCmd.AddCommand( + serveCmd, + generateCaCmd, + pluginsCmd, + ) } -func startInetMock(cmd *cobra.Command, args []string) { - for endpointName := range viper.GetStringMap(config.EndpointsKey) { - handlerSubConfig := viper.Sub(strings.Join([]string{config.EndpointsKey, endpointName}, ".")) - handlerConfig := config.CreateMultiHandlerConfig(handlerSubConfig) - if err := endpointManager.CreateEndpoint(endpointName, handlerConfig); err != nil { - logger.Warn( - "error occurred while creating endpoint", - zap.String("endpointName", endpointName), - zap.String("handlerName", handlerConfig.HandlerName()), - zap.Error(err), - ) - } - } - - endpointManager.StartEndpoints() - - signalChannel := make(chan os.Signal, 1) - signal.Notify(signalChannel, syscall.SIGINT, syscall.SIGTERM, syscall.SIGQUIT) - - // block until canceled - s := <-signalChannel - - logger.Info( - "got signal to quit", - zap.String("signal", s.String()), +func onInit() { + logging.ConfigureLogging( + logging.ParseLevel(logLevel), + developmentLogs, + map[string]interface{}{"cwd": path.WorkingDirectory()}, ) - endpointManager.ShutdownEndpoints() + logger, _ = logging.CreateLogger() + config.CreateConfig(rootCmd.Flags()) + appConfig := config.Instance() + + if err := appConfig.ReadConfig(configFilePath); err != nil { + logger.Error( + "failed to read config file", + zap.Error(err), + ) + } + + if err := api.InitServices(appConfig, logger); err != nil { + logger.Error( + "failed to initialize app services", + zap.Error(err), + ) + } + + registry := plugins.Registry() + + cfg := config.Instance() + + pluginLoadStartTime := time.Now() + if err := registry.LoadPlugins(cfg.PluginsDir()); err != nil { + logger.Error("Failed to load plugins", + zap.String("pluginsDirectory", cfg.PluginsDir()), + zap.Error(err), + ) + } + pluginLoadDuration := time.Since(pluginLoadStartTime) + logger.Info( + "loading plugins completed", + zap.Duration("pluginLoadDuration", pluginLoadDuration), + ) + + pluginsCmd.AddCommand(registry.PluginCommands()...) + } func ExecuteRootCommand() error { - if err := initApp(); err != nil { - return err - } return rootCmd.Execute() } diff --git a/internal/cmd/serve.go b/internal/cmd/serve.go new file mode 100644 index 0000000..b6b6d18 --- /dev/null +++ b/internal/cmd/serve.go @@ -0,0 +1,54 @@ +package cmd + +import ( + "github.com/baez90/inetmock/internal/endpoints" + "github.com/baez90/inetmock/pkg/config" + "github.com/spf13/cobra" + "go.uber.org/zap" + "os" + "os/signal" + "strings" + "syscall" +) + +var ( + endpointManager endpoints.EndpointManager + serveCmd = &cobra.Command{ + Use: "serve", + Short: "Starts the INetMock server", + Long: ``, + Run: startINetMock, + } +) + +func startINetMock(_ *cobra.Command, _ []string) { + endpointManager = endpoints.NewEndpointManager(logger) + cfg := config.Instance() + for endpointName, endpointHandler := range cfg.EndpointConfigs() { + handlerSubConfig := cfg.Viper().Sub(strings.Join([]string{config.EndpointsKey, endpointName, config.OptionsKey}, ".")) + endpointHandler.Options = handlerSubConfig + if err := endpointManager.CreateEndpoint(endpointName, endpointHandler); err != nil { + logger.Warn( + "error occurred while creating endpoint", + zap.String("endpointName", endpointName), + zap.String("handlerName", endpointHandler.Handler), + zap.Error(err), + ) + } + } + + endpointManager.StartEndpoints() + + signalChannel := make(chan os.Signal, 1) + signal.Notify(signalChannel, syscall.SIGINT, syscall.SIGTERM, syscall.SIGQUIT) + + // block until canceled + s := <-signalChannel + + logger.Info( + "got signal to quit", + zap.String("signal", s.String()), + ) + + endpointManager.ShutdownEndpoints() +} diff --git a/internal/config/constants.go b/internal/config/constants.go deleted file mode 100644 index a645a7c..0000000 --- a/internal/config/constants.go +++ /dev/null @@ -1,6 +0,0 @@ -package config - -const ( - EndpointsKey = "endpoints" - OptionsKey = "options" -) diff --git a/internal/config/loading.go b/internal/config/loading.go deleted file mode 100644 index c0b1a03..0000000 --- a/internal/config/loading.go +++ /dev/null @@ -1,55 +0,0 @@ -package config - -import ( - "github.com/baez90/inetmock/pkg/logging" - "github.com/baez90/inetmock/pkg/path" - "github.com/spf13/pflag" - "github.com/spf13/viper" - "go.uber.org/zap" - "strings" -) - -func CreateConfig() Config { - logger, _ := logging.CreateLogger() - return &config{ - logger: logger.Named("Config"), - } -} - -type Config interface { - InitConfig(flags *pflag.FlagSet) - ReadConfig(configFilePath string) error -} - -type config struct { - logger logging.Logger -} - -func (c config) InitConfig(flags *pflag.FlagSet) { - viper.SetConfigName("config") - viper.SetConfigType("yaml") - viper.AddConfigPath("/etc/inetmock/") - viper.AddConfigPath("$HOME/.inetmock") - viper.AddConfigPath(".") - viper.SetEnvPrefix("INetMock") - _ = viper.BindPFlags(flags) - viper.SetEnvKeyReplacer(strings.NewReplacer("-", "_")) - viper.AutomaticEnv() -} - -func (c *config) ReadConfig(configFilePath string) (err error) { - if configFilePath != "" && path.FileExists(configFilePath) { - c.logger.Info( - "loading config from passed config file path", - zap.String("configFilePath", configFilePath), - ) - viper.SetConfigFile(configFilePath) - } - if err = viper.ReadInConfig(); err != nil { - if _, ok := err.(viper.ConfigFileNotFoundError); ok { - err = nil - c.logger.Warn("failed to load config") - } - } - return -} diff --git a/internal/config/multi_handler_config.go b/internal/config/multi_handler_config.go deleted file mode 100644 index f1b9fdf..0000000 --- a/internal/config/multi_handler_config.go +++ /dev/null @@ -1,49 +0,0 @@ -package config - -import ( - "github.com/baez90/inetmock/pkg/api" - "github.com/spf13/viper" -) - -type MultiHandlerConfig interface { - HandlerName() string - ListenAddress() string - Ports() []uint16 - Options() *viper.Viper - HandlerConfigs() []api.HandlerConfig -} - -type multiHandlerConfig struct { - handlerName string - ports []uint16 - listenAddress string - options *viper.Viper -} - -func NewMultiHandlerConfig(handlerName string, ports []uint16, listenAddress string, options *viper.Viper) MultiHandlerConfig { - return &multiHandlerConfig{handlerName: handlerName, ports: ports, listenAddress: listenAddress, options: options} -} - -func (m multiHandlerConfig) HandlerName() string { - return m.handlerName -} - -func (m multiHandlerConfig) ListenAddress() string { - return m.listenAddress -} - -func (m multiHandlerConfig) Ports() []uint16 { - return m.ports -} - -func (m multiHandlerConfig) Options() *viper.Viper { - return m.options -} - -func (m multiHandlerConfig) HandlerConfigs() []api.HandlerConfig { - configs := make([]api.HandlerConfig, 0) - for _, port := range m.ports { - configs = append(configs, api.NewHandlerConfig(m.handlerName, port, m.listenAddress, m.options)) - } - return configs -} diff --git a/internal/config/multi_handler_config_test.go b/internal/config/multi_handler_config_test.go deleted file mode 100644 index ce735fe..0000000 --- a/internal/config/multi_handler_config_test.go +++ /dev/null @@ -1,226 +0,0 @@ -package config - -import ( - "github.com/baez90/inetmock/pkg/api" - "github.com/spf13/viper" - "reflect" - "testing" -) - -func Test_multiHandlerConfig_HandlerConfigs(t *testing.T) { - type fields struct { - handlerName string - ports []uint16 - listenAddress string - options *viper.Viper - } - tests := []struct { - name string - fields fields - want []api.HandlerConfig - }{ - { - name: "Get empty array if no ports are set", - fields: fields{}, - want: make([]api.HandlerConfig, 0), - }, - { - name: "Get a single handler config if only one port is set", - fields: fields{ - handlerName: "sampleHandler", - ports: []uint16{80}, - listenAddress: "0.0.0.0", - options: nil, - }, - want: []api.HandlerConfig{ - api.NewHandlerConfig("sampleHandler", 80, "0.0.0.0", nil), - }, - }, - { - name: "Get multiple handler configs if only one port is set", - fields: fields{ - handlerName: "sampleHandler", - ports: []uint16{80, 8080}, - listenAddress: "0.0.0.0", - options: nil, - }, - want: []api.HandlerConfig{ - api.NewHandlerConfig("sampleHandler", 80, "0.0.0.0", nil), - api.NewHandlerConfig("sampleHandler", 8080, "0.0.0.0", nil), - }, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - m := multiHandlerConfig{ - handlerName: tt.fields.handlerName, - ports: tt.fields.ports, - listenAddress: tt.fields.listenAddress, - options: tt.fields.options, - } - if got := m.HandlerConfigs(); !reflect.DeepEqual(got, tt.want) { - t.Errorf("HandlerConfigs() = %v, want %v", got, tt.want) - } - }) - } -} - -func Test_multiHandlerConfig_HandlerName(t *testing.T) { - type fields struct { - handlerName string - ports []uint16 - listenAddress string - options *viper.Viper - } - tests := []struct { - name string - fields fields - want string - }{ - { - name: "Get empty handler name for uninitialized struct", - fields: fields{}, - want: "", - }, - { - name: "Get expected handler name for initialized struct", - fields: fields{ - handlerName: "sampleHandler", - }, - want: "sampleHandler", - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - m := multiHandlerConfig{ - handlerName: tt.fields.handlerName, - ports: tt.fields.ports, - listenAddress: tt.fields.listenAddress, - options: tt.fields.options, - } - if got := m.HandlerName(); got != tt.want { - t.Errorf("HandlerName() = %v, want %v", got, tt.want) - } - }) - } -} - -func Test_multiHandlerConfig_ListenAddress(t *testing.T) { - type fields struct { - handlerName string - ports []uint16 - listenAddress string - options *viper.Viper - } - tests := []struct { - name string - fields fields - want string - }{ - { - name: "Get empty ListenAddress for uninitialized struct", - fields: fields{}, - want: "", - }, - { - name: "Get expected ListenAddress for initialized struct", - fields: fields{ - listenAddress: "0.0.0.0", - }, - want: "0.0.0.0", - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - m := multiHandlerConfig{ - handlerName: tt.fields.handlerName, - ports: tt.fields.ports, - listenAddress: tt.fields.listenAddress, - options: tt.fields.options, - } - if got := m.ListenAddress(); got != tt.want { - t.Errorf("ListenAddress() = %v, want %v", got, tt.want) - } - }) - } -} - -func Test_multiHandlerConfig_Options(t *testing.T) { - type fields struct { - handlerName string - ports []uint16 - listenAddress string - options *viper.Viper - } - tests := []struct { - name string - fields fields - want *viper.Viper - }{ - { - name: "Get nil Options for uninitialized struct", - fields: fields{}, - want: nil, - }, - { - name: "Get expected Options for initialized struct", - fields: fields{ - options: viper.New(), - }, - want: viper.New(), - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - m := multiHandlerConfig{ - handlerName: tt.fields.handlerName, - ports: tt.fields.ports, - listenAddress: tt.fields.listenAddress, - options: tt.fields.options, - } - if got := m.Options(); !reflect.DeepEqual(got, tt.want) { - t.Errorf("Options() = %v, want %v", got, tt.want) - } - }) - } -} - -func Test_multiHandlerConfig_Ports(t *testing.T) { - type fields struct { - handlerName string - ports []uint16 - listenAddress string - options *viper.Viper - } - tests := []struct { - name string - fields fields - want []uint16 - }{ - { - name: "Get empty Ports for uninitialized struct", - fields: fields{}, - want: nil, - }, - { - name: "Get expected Ports for initialized struct", - fields: fields{ - ports: []uint16{80, 8080}, - }, - want: []uint16{80, 8080}, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - m := multiHandlerConfig{ - handlerName: tt.fields.handlerName, - ports: tt.fields.ports, - listenAddress: tt.fields.listenAddress, - options: tt.fields.options, - } - if got := m.Ports(); !reflect.DeepEqual(got, tt.want) { - t.Errorf("Ports() = %v, want %v", got, tt.want) - } - }) - } -} diff --git a/internal/config/parsing.go b/internal/config/parsing.go deleted file mode 100644 index ad85400..0000000 --- a/internal/config/parsing.go +++ /dev/null @@ -1,35 +0,0 @@ -package config - -import ( - "github.com/spf13/viper" -) - -const ( - pluginConfigKey = "handler" - listenAddressConfigKey = "listenAddress" - portConfigKey = "port" - portsConfigKey = "ports" -) - -func CreateMultiHandlerConfig(handlerConfig *viper.Viper) MultiHandlerConfig { - return NewMultiHandlerConfig( - handlerConfig.GetString(pluginConfigKey), - portsFromConfig(handlerConfig), - handlerConfig.GetString(listenAddressConfigKey), - handlerConfig.Sub(OptionsKey), - ) -} - -func portsFromConfig(handlerConfig *viper.Viper) (ports []uint16) { - if portsInt := handlerConfig.GetIntSlice(portsConfigKey); len(portsInt) > 0 { - for _, port := range portsInt { - ports = append(ports, uint16(port)) - } - return - } - - if portInt := handlerConfig.GetInt(portConfigKey); portInt > 0 { - ports = append(ports, uint16(portInt)) - } - return -} diff --git a/internal/config/parsing_test.go b/internal/config/parsing_test.go deleted file mode 100644 index 5f1efb7..0000000 --- a/internal/config/parsing_test.go +++ /dev/null @@ -1,145 +0,0 @@ -package config - -import ( - "bytes" - "github.com/spf13/viper" - "reflect" - "testing" -) - -func TestCreateMultiHandlerConfig(t *testing.T) { - type args struct { - handlerConfig *viper.Viper - } - tests := []struct { - name string - args args - want MultiHandlerConfig - }{ - { - name: "Get simple multiHandlerConfig from config", - args: args{ - handlerConfig: configFromString(` -handler: sampleHandler -listenAddress: 0.0.0.0 -ports: -- 80 -- 8080 -options: {} -`), - }, - want: &multiHandlerConfig{ - handlerName: "sampleHandler", - ports: []uint16{80, 8080}, - listenAddress: "0.0.0.0", - options: viper.New(), - }, - }, - { - name: "Get more complex multiHandlerConfig from config", - args: args{ - handlerConfig: configFromString(` -handler: sampleHandler -listenAddress: 0.0.0.0 -ports: -- 80 -- 8080 -options: - optionA: asdf - optionB: as1234 -`), - }, - want: &multiHandlerConfig{ - handlerName: "sampleHandler", - ports: []uint16{80, 8080}, - listenAddress: "0.0.0.0", - options: configFromString(` -nesting: - optionA: asdf - optionB: as1234 -`).Sub("nesting"), - }, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if got := CreateMultiHandlerConfig(tt.args.handlerConfig); !reflect.DeepEqual(got, tt.want) { - t.Errorf("CreateMultiHandlerConfig() = %v, want %v", got, tt.want) - } - }) - } -} - -func Test_portsFromConfig(t *testing.T) { - type args struct { - handlerConfig *viper.Viper - } - tests := []struct { - name string - args args - wantPorts []uint16 - }{ - { - name: "Empty array if config value is not set", - args: args{ - handlerConfig: viper.New(), - }, - wantPorts: nil, - }, - { - name: "Array of one if `port` is set", - args: args{ - handlerConfig: configFromString(` -port: 80 -`), - }, - wantPorts: []uint16{80}, - }, - { - name: "Array of one if `ports` is set as array", - args: args{ - handlerConfig: configFromString(` -ports: -- 80 -`), - }, - wantPorts: []uint16{80}, - }, - { - name: "Array of two if `ports` is set as array", - args: args{ - handlerConfig: configFromString(` -ports: -- 80 -- 8080 -`), - }, - wantPorts: []uint16{80, 8080}, - }, - { - name: "Array of two if `port` is set as array", - args: args{ - handlerConfig: configFromString(` -ports: -- 80 -- 8080 -`), - }, - wantPorts: []uint16{80, 8080}, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if gotPorts := portsFromConfig(tt.args.handlerConfig); !reflect.DeepEqual(gotPorts, tt.wantPorts) { - t.Errorf("portsFromConfig() = %v, want %v", gotPorts, tt.wantPorts) - } - }) - } -} - -func configFromString(yaml string) (config *viper.Viper) { - config = viper.New() - config.SetConfigType("yaml") - _ = config.ReadConfig(bytes.NewBufferString(yaml)) - return -} diff --git a/internal/endpoints/endpoint.go b/internal/endpoints/endpoint.go index 1cfd379..e79aa5f 100644 --- a/internal/endpoints/endpoint.go +++ b/internal/endpoints/endpoint.go @@ -3,6 +3,7 @@ package endpoints import ( "github.com/baez90/inetmock/pkg/api" + "github.com/baez90/inetmock/pkg/config" ) type Endpoint interface { @@ -14,7 +15,7 @@ type Endpoint interface { type endpoint struct { name string handler api.ProtocolHandler - config api.HandlerConfig + config config.HandlerConfig } func (e endpoint) Name() string { diff --git a/internal/endpoints/endpoint_manager.go b/internal/endpoints/endpoint_manager.go index 907f868..5416941 100644 --- a/internal/endpoints/endpoint_manager.go +++ b/internal/endpoints/endpoint_manager.go @@ -2,8 +2,8 @@ package endpoints import ( "fmt" - "github.com/baez90/inetmock/internal/config" "github.com/baez90/inetmock/internal/plugins" + config2 "github.com/baez90/inetmock/pkg/config" "github.com/baez90/inetmock/pkg/logging" "go.uber.org/zap" "sync" @@ -13,7 +13,7 @@ import ( type EndpointManager interface { RegisteredEndpoints() []Endpoint StartedEndpoints() []Endpoint - CreateEndpoint(name string, multiHandlerConfig config.MultiHandlerConfig) error + CreateEndpoint(name string, multiHandlerConfig config2.MultiHandlerConfig) error StartEndpoints() ShutdownEndpoints() } @@ -40,16 +40,16 @@ func (e endpointManager) StartedEndpoints() []Endpoint { return e.properlyStartedEndpoints } -func (e *endpointManager) CreateEndpoint(name string, multiHandlerConfig config.MultiHandlerConfig) error { +func (e *endpointManager) CreateEndpoint(name string, multiHandlerConfig config2.MultiHandlerConfig) error { for _, handlerConfig := range multiHandlerConfig.HandlerConfigs() { - if handler, ok := e.registry.HandlerForName(multiHandlerConfig.HandlerName()); ok { + if handler, ok := e.registry.HandlerForName(multiHandlerConfig.Handler); ok { e.registeredEndpoints = append(e.registeredEndpoints, &endpoint{ name: name, handler: handler, config: handlerConfig, }) } else { - return fmt.Errorf("no matching handler registered for names %s", multiHandlerConfig.HandlerName()) + return fmt.Errorf("no matching handler registered for names %s", multiHandlerConfig.Handler) } } diff --git a/internal/endpoints/endpoint_manager_test.go b/internal/endpoints/endpoint_manager_test.go index 3d0530a..fcbaba1 100644 --- a/internal/endpoints/endpoint_manager_test.go +++ b/internal/endpoints/endpoint_manager_test.go @@ -1,11 +1,11 @@ package endpoints import ( - "github.com/baez90/inetmock/internal/config" api_mock "github.com/baez90/inetmock/internal/mock/api" logging_mock "github.com/baez90/inetmock/internal/mock/logging" plugins_mock "github.com/baez90/inetmock/internal/mock/plugins" "github.com/baez90/inetmock/internal/plugins" + config2 "github.com/baez90/inetmock/pkg/config" "github.com/baez90/inetmock/pkg/logging" "github.com/golang/mock/gomock" "testing" @@ -20,7 +20,7 @@ func Test_endpointManager_CreateEndpoint(t *testing.T) { } type args struct { name string - multiHandlerConfig config.MultiHandlerConfig + multiHandlerConfig config2.MultiHandlerConfig } tests := []struct { name string @@ -52,12 +52,11 @@ func Test_endpointManager_CreateEndpoint(t *testing.T) { }, args: args{ name: "sampleEndpoint", - multiHandlerConfig: config.NewMultiHandlerConfig( - "sampleHandler", - []uint16{80}, - "0.0.0.0", - nil, - ), + multiHandlerConfig: config2.MultiHandlerConfig{ + Handler: "sampleHandler", + Ports: []uint16{80}, + ListenAddress: "0.0.0.0", + }, }, }, { @@ -83,12 +82,11 @@ func Test_endpointManager_CreateEndpoint(t *testing.T) { }, args: args{ name: "sampleEndpoint", - multiHandlerConfig: config.NewMultiHandlerConfig( - "sampleHandler", - []uint16{80}, - "0.0.0.0", - nil, - ), + multiHandlerConfig: config2.MultiHandlerConfig{ + Handler: "sampleHandler", + Ports: []uint16{80}, + ListenAddress: "0.0.0.0", + }, }, }, } diff --git a/internal/endpoints/endpoint_test.go b/internal/endpoints/endpoint_test.go index 65e5d76..3ddf5f0 100644 --- a/internal/endpoints/endpoint_test.go +++ b/internal/endpoints/endpoint_test.go @@ -4,6 +4,7 @@ import ( "fmt" api_mock "github.com/baez90/inetmock/internal/mock/api" "github.com/baez90/inetmock/pkg/api" + "github.com/baez90/inetmock/pkg/config" "github.com/golang/mock/gomock" "testing" ) @@ -12,7 +13,7 @@ func Test_endpoint_Name(t *testing.T) { type fields struct { name string handler api.ProtocolHandler - config api.HandlerConfig + config config.HandlerConfig } tests := []struct { name string @@ -50,7 +51,7 @@ func Test_endpoint_Shutdown(t *testing.T) { type fields struct { name string handler api.ProtocolHandler - config api.HandlerConfig + config config.HandlerConfig } tests := []struct { name string @@ -102,17 +103,16 @@ func Test_endpoint_Shutdown(t *testing.T) { func Test_endpoint_Start(t *testing.T) { - demoHandlerConfig := api.NewHandlerConfig( - "sampleHandler", - 80, - "0.0.0.0", - nil, - ) + demoHandlerConfig := config.HandlerConfig{ + HandlerName: "sampleHandler", + Port: 80, + ListenAddress: "0.0.0.0", + } type fields struct { name string handler api.ProtocolHandler - config api.HandlerConfig + config config.HandlerConfig } tests := []struct { name string @@ -125,7 +125,7 @@ func Test_endpoint_Start(t *testing.T) { handler: func() api.ProtocolHandler { handler := api_mock.NewMockProtocolHandler(gomock.NewController(t)) handler.EXPECT(). - Start(nil). + Start(gomock.Any()). MaxTimes(1). Return(nil) return handler @@ -139,7 +139,7 @@ func Test_endpoint_Start(t *testing.T) { handler: func() api.ProtocolHandler { handler := api_mock.NewMockProtocolHandler(gomock.NewController(t)) handler.EXPECT(). - Start(nil). + Start(gomock.Any()). MaxTimes(1). Return(fmt.Errorf("")) return handler diff --git a/internal/plugins/loading.go b/internal/plugins/loading.go index 52e5e03..cb68924 100644 --- a/internal/plugins/loading.go +++ b/internal/plugins/loading.go @@ -10,6 +10,7 @@ import ( "path/filepath" "plugin" "regexp" + "strings" ) var ( @@ -19,6 +20,7 @@ var ( type HandlerRegistry interface { LoadPlugins(pluginsPath string) error + AvailableHandlers() []string RegisterHandler(handlerName string, handlerProvider api.PluginInstanceFactory, subCommands ...*cobra.Command) HandlerForName(handlerName string) (api.ProtocolHandler, bool) PluginCommands() []*cobra.Command @@ -29,11 +31,19 @@ type handlerRegistry struct { pluginCommands []*cobra.Command } +func (h handlerRegistry) AvailableHandlers() (availableHandlers []string) { + for s := range h.handlers { + availableHandlers = append(availableHandlers, s) + } + return +} + func (h handlerRegistry) PluginCommands() []*cobra.Command { return h.pluginCommands } func (h *handlerRegistry) HandlerForName(handlerName string) (instance api.ProtocolHandler, ok bool) { + handlerName = strings.ToLower(handlerName) var provider api.PluginInstanceFactory if provider, ok = h.handlers[handlerName]; ok { instance = provider() @@ -42,6 +52,7 @@ func (h *handlerRegistry) HandlerForName(handlerName string) (instance api.Proto } func (h *handlerRegistry) RegisterHandler(handlerName string, handlerProvider api.PluginInstanceFactory, subCommands ...*cobra.Command) { + handlerName = strings.ToLower(handlerName) if _, exists := h.handlers[handlerName]; exists { panic(fmt.Sprintf("handler with name %s is already registered - there's something strange...in the neighborhood", handlerName)) } diff --git a/mock_config.yaml b/mock_config.yaml index 766d4ff..566c450 100644 --- a/mock_config.yaml +++ b/mock_config.yaml @@ -25,8 +25,8 @@ tls: NotBeforeRelative: 168h NotAfterRelative: 168h rootCaCert: - publicKey: ./ca.pem - privateKey: ./ca.key + publicKeyPath: ./ca.pem + privateKeyPath: ./ca.key certCachePath: /tmp/inetmock/ endpoints: @@ -41,7 +41,8 @@ endpoints: proxy: handler: http_proxy listenAddress: 0.0.0.0 - port: 3128 + ports: + - 3128 options: target: ipAddress: 127.0.0.1 @@ -59,7 +60,8 @@ endpoints: plainDns: handler: dns_mock listenAddress: 0.0.0.0 - port: 53 + ports: + - 53 options: rules: - pattern: ".*\\.google\\.com" @@ -73,7 +75,8 @@ endpoints: dnsOverTlsDowngrade: handler: tls_interceptor listenAddress: 0.0.0.0 - port: 853 + ports: + - 853 options: target: ipAddress: 127.0.0.1 diff --git a/pkg/api/handler_config.go b/pkg/api/handler_config.go deleted file mode 100644 index 4545a42..0000000 --- a/pkg/api/handler_config.go +++ /dev/null @@ -1,42 +0,0 @@ -package api - -import "github.com/spf13/viper" - -type HandlerConfig interface { - HandlerName() string - ListenAddress() string - Port() uint16 - Options() *viper.Viper -} - -func NewHandlerConfig(handlerName string, port uint16, listenAddress string, options *viper.Viper) HandlerConfig { - return &handlerConfig{ - handlerName: handlerName, - port: port, - listenAddress: listenAddress, - options: options, - } -} - -type handlerConfig struct { - handlerName string - port uint16 - listenAddress string - options *viper.Viper -} - -func (h handlerConfig) HandlerName() string { - return h.handlerName -} - -func (h handlerConfig) ListenAddress() string { - return h.listenAddress -} - -func (h handlerConfig) Port() uint16 { - return h.port -} - -func (h handlerConfig) Options() *viper.Viper { - return h.options -} diff --git a/pkg/api/protocol_handler.go b/pkg/api/protocol_handler.go index afb7b3e..8ba1ed1 100644 --- a/pkg/api/protocol_handler.go +++ b/pkg/api/protocol_handler.go @@ -2,6 +2,7 @@ package api import ( + "github.com/baez90/inetmock/pkg/config" "go.uber.org/zap" ) @@ -10,6 +11,6 @@ type PluginInstanceFactory func() ProtocolHandler type LoggingFactory func() (*zap.Logger, error) type ProtocolHandler interface { - Start(config HandlerConfig) error + Start(config config.HandlerConfig) error Shutdown() error } diff --git a/pkg/api/services.go b/pkg/api/services.go index 996a76e..a2f46c3 100644 --- a/pkg/api/services.go +++ b/pkg/api/services.go @@ -2,8 +2,8 @@ package api import ( "github.com/baez90/inetmock/pkg/cert" + config2 "github.com/baez90/inetmock/pkg/config" "github.com/baez90/inetmock/pkg/logging" - "github.com/spf13/viper" ) var ( @@ -19,7 +19,7 @@ type services struct { } func InitServices( - config *viper.Viper, + config config2.Config, logger logging.Logger, ) error { certStore, err := cert.NewDefaultStore(config, logger) diff --git a/pkg/cert/cache_test.go b/pkg/cert/cache_test.go index c23a7a7..4cb0c99 100644 --- a/pkg/cert/cache_test.go +++ b/pkg/cert/cache_test.go @@ -5,6 +5,7 @@ import ( "crypto/x509" "fmt" certmock "github.com/baez90/inetmock/internal/mock/cert" + config2 "github.com/baez90/inetmock/pkg/config" "github.com/golang/mock/gomock" "os" "path" @@ -284,13 +285,13 @@ func Test_fileSystemCache_Put(t *testing.T) { } func setupCertGen() Generator { - return NewDefaultGenerator(Options{ - Validity: ValidityByPurpose{ - Server: ValidityDuration{ + return NewDefaultGenerator(config2.CertOptions{ + Validity: config2.ValidityByPurpose{ + Server: config2.ValidityDuration{ NotBeforeRelative: serverRelativeValidity, NotAfterRelative: serverRelativeValidity, }, - CA: ValidityDuration{ + CA: config2.ValidityDuration{ NotBeforeRelative: caRelativeValidity, NotAfterRelative: caRelativeValidity, }, diff --git a/pkg/cert/constants.go b/pkg/cert/constants.go index 151158a..8960242 100644 --- a/pkg/cert/constants.go +++ b/pkg/cert/constants.go @@ -1,19 +1,11 @@ package cert const ( - CurveTypeP224 CurveType = "P224" - CurveTypeP256 CurveType = "P256" - CurveTypeP384 CurveType = "P384" - CurveTypeP521 CurveType = "P521" - CurveTypeED25519 CurveType = "ED25519" - defaultServerValidityDuration = "168h" defaultCAValidityDuration = "17520h" certCachePathConfigKey = "tls.certCachePath" ecdsaCurveConfigKey = "tls.ecdsaCurve" - publicKeyConfigKey = "tls.rootCaCert.publicKey" - privateKeyPathConfigKey = "tls.rootCaCert.privateKey" caCertValidityNotBeforeKey = "tls.validity.ca.notBeforeRelative" caCertValidityNotAfterKey = "tls.validity.ca.notAfterRelative" serverCertValidityNotBeforeKey = "tls.validity.server.notBeforeRelative" diff --git a/pkg/cert/generator.go b/pkg/cert/generator.go index 411f921..251fd1e 100644 --- a/pkg/cert/generator.go +++ b/pkg/cert/generator.go @@ -8,6 +8,7 @@ import ( "crypto/x509" "crypto/x509/pkix" "encoding/pem" + config2 "github.com/baez90/inetmock/pkg/config" "github.com/baez90/inetmock/pkg/defaulting" "math/big" "net" @@ -35,11 +36,11 @@ type Generator interface { ServerCert(options GenerationOptions, ca *tls.Certificate) (*tls.Certificate, error) } -func NewDefaultGenerator(options Options) Generator { +func NewDefaultGenerator(options config2.CertOptions) Generator { return NewGenerator(options, NewTimeSource(), defaultKeyProvider(options)) } -func NewGenerator(options Options, source TimeSource, provider KeyProvider) Generator { +func NewGenerator(options config2.CertOptions, source TimeSource, provider KeyProvider) Generator { return &generator{ options: options, provider: provider, @@ -48,7 +49,7 @@ func NewGenerator(options Options, source TimeSource, provider KeyProvider) Gene } type generator struct { - options Options + options config2.CertOptions provider KeyProvider timeSource TimeSource outDir string diff --git a/pkg/cert/options.go b/pkg/cert/options.go index fa18280..bb66d46 100644 --- a/pkg/cert/options.go +++ b/pkg/cert/options.go @@ -1,79 +1,15 @@ package cert import ( - "fmt" - "github.com/spf13/viper" + "github.com/baez90/inetmock/pkg/config" "os" - "strings" - "time" ) -var ( - requiredConfigKeys = []string{ - publicKeyConfigKey, - privateKeyPathConfigKey, - } -) - -type CurveType string - -type File struct { - PublicKeyPath string - PrivateKeyPath string -} - -type ValidityDuration struct { - NotBeforeRelative time.Duration - NotAfterRelative time.Duration -} - -type ValidityByPurpose struct { - CA ValidityDuration - Server ValidityDuration -} - -type Options struct { - RootCACert File - CertCachePath string - Curve CurveType - Validity ValidityByPurpose -} - -func loadFromConfig(config *viper.Viper) (Options, error) { - missingKeys := make([]string, 0) - for _, requiredKey := range requiredConfigKeys { - if !config.IsSet(requiredKey) { - missingKeys = append(missingKeys, requiredKey) - } - } - - if len(missingKeys) > 0 { - return Options{}, fmt.Errorf("config keys are missing: %s", strings.Join(missingKeys, ", ")) - } - - config.SetDefault(certCachePathConfigKey, os.TempDir()) - config.SetDefault(ecdsaCurveConfigKey, string(CurveTypeED25519)) - config.SetDefault(caCertValidityNotBeforeKey, defaultCAValidityDuration) - config.SetDefault(caCertValidityNotAfterKey, defaultCAValidityDuration) - config.SetDefault(serverCertValidityNotBeforeKey, defaultServerValidityDuration) - config.SetDefault(serverCertValidityNotAfterKey, defaultServerValidityDuration) - - return Options{ - CertCachePath: config.GetString(certCachePathConfigKey), - Curve: CurveType(config.GetString(ecdsaCurveConfigKey)), - Validity: ValidityByPurpose{ - CA: ValidityDuration{ - NotBeforeRelative: config.GetDuration(caCertValidityNotBeforeKey), - NotAfterRelative: config.GetDuration(caCertValidityNotAfterKey), - }, - Server: ValidityDuration{ - NotBeforeRelative: config.GetDuration(serverCertValidityNotBeforeKey), - NotAfterRelative: config.GetDuration(serverCertValidityNotAfterKey), - }, - }, - RootCACert: File{ - PublicKeyPath: config.GetString(publicKeyConfigKey), - PrivateKeyPath: config.GetString(privateKeyPathConfigKey), - }, - }, nil +func init() { + config.AddDefaultValue(certCachePathConfigKey, os.TempDir()) + config.AddDefaultValue(ecdsaCurveConfigKey, string(config.CurveTypeED25519)) + config.AddDefaultValue(caCertValidityNotBeforeKey, defaultCAValidityDuration) + config.AddDefaultValue(caCertValidityNotAfterKey, defaultCAValidityDuration) + config.AddDefaultValue(serverCertValidityNotBeforeKey, defaultServerValidityDuration) + config.AddDefaultValue(serverCertValidityNotAfterKey, defaultServerValidityDuration) } diff --git a/pkg/cert/options_test.go b/pkg/cert/options_test.go index f8bddb2..78b3efc 100644 --- a/pkg/cert/options_test.go +++ b/pkg/cert/options_test.go @@ -1,9 +1,9 @@ package cert import ( + "github.com/baez90/inetmock/pkg/config" "github.com/spf13/viper" "os" - "reflect" "strings" "testing" "time" @@ -25,7 +25,7 @@ func Test_loadFromConfig(t *testing.T) { tests := []struct { name string args args - want Options + want config.CertOptions wantErr bool }{ { @@ -48,19 +48,19 @@ tls: certCachePath: /tmp/inetmock/ `), }, - want: Options{ - RootCACert: File{ + want: config.CertOptions{ + RootCACert: config.File{ PublicKeyPath: "./ca.pem", PrivateKeyPath: "./ca.key", }, CertCachePath: "/tmp/inetmock/", - Curve: CurveTypeP256, - Validity: ValidityByPurpose{ - CA: ValidityDuration{ + Curve: config.CurveTypeP256, + Validity: config.ValidityByPurpose{ + CA: config.ValidityDuration{ NotBeforeRelative: 17520 * time.Hour, NotAfterRelative: 17520 * time.Hour, }, - Server: ValidityDuration{ + Server: config.ValidityDuration{ NotBeforeRelative: 168 * time.Hour, NotAfterRelative: 168 * time.Hour, }, @@ -76,7 +76,7 @@ tls: privateKey: ./ca.key `), }, - want: Options{}, + want: config.CertOptions{}, wantErr: true, }, { @@ -88,7 +88,7 @@ tls: publicKey: ./ca.pem `), }, - want: Options{}, + want: config.CertOptions{}, wantErr: true, }, { @@ -101,19 +101,19 @@ tls: privateKey: ./ca.key `), }, - want: Options{ - RootCACert: File{ + want: config.CertOptions{ + RootCACert: config.File{ PublicKeyPath: "./ca.pem", PrivateKeyPath: "./ca.key", }, CertCachePath: os.TempDir(), - Curve: CurveTypeED25519, - Validity: ValidityByPurpose{ - CA: ValidityDuration{ + Curve: config.CurveTypeED25519, + Validity: config.ValidityByPurpose{ + CA: config.ValidityDuration{ NotBeforeRelative: 17520 * time.Hour, NotAfterRelative: 17520 * time.Hour, }, - Server: ValidityDuration{ + Server: config.ValidityDuration{ NotBeforeRelative: 168 * time.Hour, NotAfterRelative: 168 * time.Hour, }, @@ -124,14 +124,7 @@ tls: } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := loadFromConfig(tt.args.config) - if (err != nil) != tt.wantErr { - t.Errorf("loadFromConfig() error = %v, wantErr %v", err, tt.wantErr) - return - } - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("loadFromConfig() got = %v, want %v", got, tt.want) - } + }) } } diff --git a/pkg/cert/store.go b/pkg/cert/store.go index 9943cfc..58d5c9a 100644 --- a/pkg/cert/store.go +++ b/pkg/cert/store.go @@ -7,8 +7,8 @@ import ( "crypto/rand" "crypto/tls" "crypto/x509" + "github.com/baez90/inetmock/pkg/config" "github.com/baez90/inetmock/pkg/logging" - "github.com/spf13/viper" "go.uber.org/zap" "net" ) @@ -18,7 +18,7 @@ const ( ) var ( - defaultKeyProvider = func(options Options) func() (key interface{}, err error) { + defaultKeyProvider = func(options config.CertOptions) func() (key interface{}, err error) { return func() (key interface{}, err error) { return privateKeyForCurve(options) } @@ -34,24 +34,20 @@ type Store interface { } func NewDefaultStore( - config *viper.Viper, + config config.Config, logger logging.Logger, ) (Store, error) { - options, err := loadFromConfig(config) - if err != nil { - return nil, err - } timeSource := NewTimeSource() return NewStore( - options, - NewFileSystemCache(options.CertCachePath, timeSource), - NewDefaultGenerator(options), + config.TLSConfig(), + NewFileSystemCache(config.TLSConfig().CertCachePath, timeSource), + NewDefaultGenerator(config.TLSConfig()), logger, ) } func NewStore( - options Options, + options config.CertOptions, cache Cache, generator Generator, logger logging.Logger, @@ -71,7 +67,7 @@ func NewStore( } type store struct { - options Options + options config.CertOptions caCert *tls.Certificate cache Cache timeSource TimeSource @@ -135,15 +131,15 @@ func (s *store) GetCertificate(serverName string, ip string) (cert *tls.Certific return } -func privateKeyForCurve(options Options) (privateKey interface{}, err error) { +func privateKeyForCurve(options config.CertOptions) (privateKey interface{}, err error) { switch options.Curve { - case CurveTypeP224: + case config.CurveTypeP224: privateKey, err = ecdsa.GenerateKey(elliptic.P224(), rand.Reader) - case CurveTypeP256: + case config.CurveTypeP256: privateKey, err = ecdsa.GenerateKey(elliptic.P256(), rand.Reader) - case CurveTypeP384: + case config.CurveTypeP384: privateKey, err = ecdsa.GenerateKey(elliptic.P384(), rand.Reader) - case CurveTypeP521: + case config.CurveTypeP521: privateKey, err = ecdsa.GenerateKey(elliptic.P521(), rand.Reader) default: _, privateKey, err = ed25519.GenerateKey(rand.Reader) diff --git a/pkg/config/certs.go b/pkg/config/certs.go new file mode 100644 index 0000000..907e972 --- /dev/null +++ b/pkg/config/certs.go @@ -0,0 +1,27 @@ +package config + +import "time" + +type CurveType string + +type File struct { + PublicKeyPath string + PrivateKeyPath string +} + +type ValidityDuration struct { + NotBeforeRelative time.Duration + NotAfterRelative time.Duration +} + +type ValidityByPurpose struct { + CA ValidityDuration + Server ValidityDuration +} + +type CertOptions struct { + RootCACert File + CertCachePath string + Curve CurveType + Validity ValidityByPurpose +} diff --git a/pkg/config/config.go b/pkg/config/config.go new file mode 100644 index 0000000..0f1a716 --- /dev/null +++ b/pkg/config/config.go @@ -0,0 +1,109 @@ +package config + +import ( + "github.com/baez90/inetmock/pkg/logging" + "github.com/baez90/inetmock/pkg/path" + "github.com/spf13/pflag" + "github.com/spf13/viper" + "go.uber.org/zap" + "strings" +) + +var ( + appConfig Config +) + +func CreateConfig(flags *pflag.FlagSet) { + logger, _ := logging.CreateLogger() + configInstance := &config{ + logger: logger.Named("Config"), + cfg: viper.New(), + } + + configInstance.cfg.SetConfigName("config") + configInstance.cfg.SetConfigType("yaml") + configInstance.cfg.AddConfigPath("/etc/inetmock/") + configInstance.cfg.AddConfigPath("$HOME/.inetmock") + configInstance.cfg.AddConfigPath(".") + configInstance.cfg.SetEnvPrefix("INetMock") + _ = configInstance.cfg.BindPFlags(flags) + configInstance.cfg.SetEnvKeyReplacer(strings.NewReplacer("-", "_")) + configInstance.cfg.AutomaticEnv() + + for k, v := range registeredDefaults { + configInstance.cfg.SetDefault(k, v) + } + + for k, v := range registeredAliases { + configInstance.cfg.RegisterAlias(k, v) + } + + appConfig = configInstance +} + +func Instance() Config { + return appConfig +} + +type Config interface { + ReadConfig(configFilePath string) error + ReadConfigString(config, format string) error + Viper() *viper.Viper + TLSConfig() CertOptions + PluginsDir() string + EndpointConfigs() map[string]MultiHandlerConfig +} + +type config struct { + cfg *viper.Viper + logger logging.Logger + TLS CertOptions + PluginsDirectory string + Endpoints map[string]MultiHandlerConfig +} + +func (c *config) ReadConfigString(config, format string) (err error) { + c.cfg.SetConfigType(format) + if err = c.cfg.ReadConfig(strings.NewReader(config)); err != nil { + return + } + + err = c.cfg.Unmarshal(c) + return +} + +func (c *config) EndpointConfigs() map[string]MultiHandlerConfig { + return c.Endpoints +} + +func (c *config) PluginsDir() string { + return c.PluginsDirectory +} + +func (c *config) TLSConfig() CertOptions { + return c.TLS +} + +func (c *config) Viper() *viper.Viper { + return c.cfg +} + +func (c *config) ReadConfig(configFilePath string) (err error) { + if configFilePath != "" && path.FileExists(configFilePath) { + c.logger.Info( + "loading config from passed config file path", + zap.String("configFilePath", configFilePath), + ) + c.cfg.SetConfigFile(configFilePath) + } + if err = c.cfg.ReadInConfig(); err != nil { + if _, ok := err.(viper.ConfigFileNotFoundError); ok { + err = nil + c.logger.Warn("failed to load config") + } + } + + err = c.cfg.Unmarshal(c) + + return +} diff --git a/pkg/config/config_test.go b/pkg/config/config_test.go new file mode 100644 index 0000000..1bd76fe --- /dev/null +++ b/pkg/config/config_test.go @@ -0,0 +1,88 @@ +package config + +import ( + "github.com/spf13/pflag" + "os" + "testing" +) + +func Test_config_ReadConfig(t *testing.T) { + type args struct { + flags *pflag.FlagSet + config string + } + tests := []struct { + name string + args args + matcher func(Config) bool + wantErr bool + }{ + { + name: "Test endpoints config", + args: args{ + flags: pflag.NewFlagSet("", pflag.ContinueOnError), + config: ` +endpoints: + plainHttp: + handler: http_mock + listenAddress: 0.0.0.0 + ports: + - 80 + - 8080 + options: {} + proxy: + handler: http_proxy + listenAddress: 0.0.0.0 + ports: + - 3128 + options: + target: + ipAddress: 127.0.0.1 + port: 80 +`, + }, + matcher: func(c Config) bool { + if len(c.EndpointConfigs()) < 1 { + t.Error("Expected EndpointConfigs to be set but is empty") + return false + } + + return true + }, + wantErr: false, + }, + { + name: "Test loading of flags by adding", + args: args{ + flags: func() *pflag.FlagSet { + flags := pflag.NewFlagSet("", pflag.ContinueOnError) + flags.String("plugins-directory", os.TempDir(), "") + + return flags + }(), + config: ``, + }, + matcher: func(c Config) bool { + if c.PluginsDir() != os.TempDir() { + t.Errorf("PluginsDir() Expected = %s, Got = %s", os.TempDir(), c.PluginsDir()) + } + return true + }, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + CreateConfig(tt.args.flags) + cfg := Instance() + if err := cfg.ReadConfigString(tt.args.config, "yaml"); (err != nil) != tt.wantErr { + t.Errorf("ReadConfig() error = %v, wantErr %v", err, tt.wantErr) + return + } + + if !tt.matcher(cfg) { + t.Error("matcher error") + } + }) + } +} diff --git a/pkg/config/constants.go b/pkg/config/constants.go new file mode 100644 index 0000000..1dc6cc1 --- /dev/null +++ b/pkg/config/constants.go @@ -0,0 +1,12 @@ +package config + +const ( + EndpointsKey = "endpoints" + OptionsKey = "options" + + CurveTypeP224 CurveType = "P224" + CurveTypeP256 CurveType = "P256" + CurveTypeP384 CurveType = "P384" + CurveTypeP521 CurveType = "P521" + CurveTypeED25519 CurveType = "ED25519" +) diff --git a/pkg/config/defaults.go b/pkg/config/defaults.go new file mode 100644 index 0000000..5d52ea3 --- /dev/null +++ b/pkg/config/defaults.go @@ -0,0 +1,17 @@ +package config + +var ( + registeredDefaults = make(map[string]interface{}) + // default aliases + registeredAliases = map[string]string{ + "PluginsDirectory": "plugins-directory", + } +) + +func AddDefaultValue(key string, val interface{}) { + registeredDefaults[key] = val +} + +func AddAlias(alias, orig string) { + registeredAliases[alias] = orig +} diff --git a/pkg/config/handler_config.go b/pkg/config/handler_config.go new file mode 100644 index 0000000..5f523dd --- /dev/null +++ b/pkg/config/handler_config.go @@ -0,0 +1,10 @@ +package config + +import "github.com/spf13/viper" + +type HandlerConfig struct { + HandlerName string + Port uint16 + ListenAddress string + Options *viper.Viper +} diff --git a/pkg/api/handler_config_test.go b/pkg/config/handler_config_test.go similarity index 68% rename from pkg/api/handler_config_test.go rename to pkg/config/handler_config_test.go index 23952cb..527136b 100644 --- a/pkg/api/handler_config_test.go +++ b/pkg/config/handler_config_test.go @@ -1,4 +1,4 @@ -package api +package config import ( "github.com/spf13/viper" @@ -19,12 +19,12 @@ func Test_handlerConfig_HandlerName(t *testing.T) { want string }{ { - name: "Get empty HandlerName for uninitialized struct", + name: "Get empty Handler for uninitialized struct", fields: fields{}, want: "", }, { - name: "Get expected HandlerName for initialized struct", + name: "Get expected Handler for initialized struct", fields: fields{ handlerName: "sampleHandler", }, @@ -33,14 +33,14 @@ func Test_handlerConfig_HandlerName(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - h := handlerConfig{ - handlerName: tt.fields.handlerName, - port: tt.fields.port, - listenAddress: tt.fields.listenAddress, - options: tt.fields.options, + h := HandlerConfig{ + HandlerName: tt.fields.handlerName, + Port: tt.fields.port, + ListenAddress: tt.fields.listenAddress, + Options: tt.fields.options, } - if got := h.HandlerName(); got != tt.want { - t.Errorf("HandlerName() = %v, want %v", got, tt.want) + if got := h.HandlerName; got != tt.want { + t.Errorf("Handler() = %v, want %v", got, tt.want) } }) } @@ -73,13 +73,13 @@ func Test_handlerConfig_ListenAddress(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - h := handlerConfig{ - handlerName: tt.fields.handlerName, - port: tt.fields.port, - listenAddress: tt.fields.listenAddress, - options: tt.fields.options, + h := HandlerConfig{ + HandlerName: tt.fields.handlerName, + Port: tt.fields.port, + ListenAddress: tt.fields.listenAddress, + Options: tt.fields.options, } - if got := h.ListenAddress(); got != tt.want { + if got := h.ListenAddress; got != tt.want { t.Errorf("ListenAddress() = %v, want %v", got, tt.want) } }) @@ -113,13 +113,13 @@ func Test_handlerConfig_Options(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - h := handlerConfig{ - handlerName: tt.fields.handlerName, - port: tt.fields.port, - listenAddress: tt.fields.listenAddress, - options: tt.fields.options, + h := HandlerConfig{ + HandlerName: tt.fields.handlerName, + Port: tt.fields.port, + ListenAddress: tt.fields.listenAddress, + Options: tt.fields.options, } - if got := h.Options(); !reflect.DeepEqual(got, tt.want) { + if got := h.Options; !reflect.DeepEqual(got, tt.want) { t.Errorf("Options() = %v, want %v", got, tt.want) } }) @@ -153,13 +153,13 @@ func Test_handlerConfig_Port(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - h := handlerConfig{ - handlerName: tt.fields.handlerName, - port: tt.fields.port, - listenAddress: tt.fields.listenAddress, - options: tt.fields.options, + h := HandlerConfig{ + HandlerName: tt.fields.handlerName, + Port: tt.fields.port, + ListenAddress: tt.fields.listenAddress, + Options: tt.fields.options, } - if got := h.Port(); got != tt.want { + if got := h.Port; got != tt.want { t.Errorf("Port() = %v, want %v", got, tt.want) } }) diff --git a/pkg/config/multi_handler_config.go b/pkg/config/multi_handler_config.go new file mode 100644 index 0000000..475b19d --- /dev/null +++ b/pkg/config/multi_handler_config.go @@ -0,0 +1,25 @@ +package config + +import ( + "github.com/spf13/viper" +) + +type MultiHandlerConfig struct { + Handler string + Ports []uint16 + ListenAddress string + Options *viper.Viper +} + +func (m MultiHandlerConfig) HandlerConfigs() []HandlerConfig { + configs := make([]HandlerConfig, 0) + for _, port := range m.Ports { + configs = append(configs, HandlerConfig{ + HandlerName: m.Handler, + Port: port, + ListenAddress: m.ListenAddress, + Options: m.Options, + }) + } + return configs +} diff --git a/plugins/dns_mock/main.go b/plugins/dns_mock/main.go index a4ae9fa..8947df9 100644 --- a/plugins/dns_mock/main.go +++ b/plugins/dns_mock/main.go @@ -2,7 +2,7 @@ package main import ( "fmt" - "github.com/baez90/inetmock/pkg/api" + "github.com/baez90/inetmock/pkg/config" "github.com/baez90/inetmock/pkg/logging" "github.com/miekg/dns" "go.uber.org/zap" @@ -13,9 +13,9 @@ type dnsHandler struct { dnsServer []*dns.Server } -func (d *dnsHandler) Start(config api.HandlerConfig) (err error) { - options := loadFromConfig(config.Options()) - addr := fmt.Sprintf("%s:%d", config.ListenAddress(), config.Port()) +func (d *dnsHandler) Start(config config.HandlerConfig) (err error) { + options := loadFromConfig(config.Options) + addr := fmt.Sprintf("%s:%d", config.ListenAddress, config.Port) handler := ®exHandler{ fallback: options.Fallback, diff --git a/plugins/http_mock/main.go b/plugins/http_mock/main.go index bdf5ce5..9f0abbb 100644 --- a/plugins/http_mock/main.go +++ b/plugins/http_mock/main.go @@ -2,7 +2,7 @@ package main import ( "fmt" - "github.com/baez90/inetmock/pkg/api" + "github.com/baez90/inetmock/pkg/config" "github.com/baez90/inetmock/pkg/logging" "go.uber.org/zap" "net/http" @@ -18,9 +18,9 @@ type httpHandler struct { server *http.Server } -func (p *httpHandler) Start(config api.HandlerConfig) (err error) { - options := loadFromConfig(config.Options()) - addr := fmt.Sprintf("%s:%d", config.ListenAddress(), config.Port()) +func (p *httpHandler) Start(config config.HandlerConfig) (err error) { + options := loadFromConfig(config.Options) + addr := fmt.Sprintf("%s:%d", config.ListenAddress, config.Port) p.server = &http.Server{Addr: addr, Handler: p.router} p.logger = p.logger.With( zap.String("address", addr), diff --git a/plugins/http_proxy/main.go b/plugins/http_proxy/main.go index 15678c9..b00b291 100644 --- a/plugins/http_proxy/main.go +++ b/plugins/http_proxy/main.go @@ -3,6 +3,7 @@ package main import ( "fmt" "github.com/baez90/inetmock/pkg/api" + "github.com/baez90/inetmock/pkg/config" "github.com/baez90/inetmock/pkg/logging" "go.uber.org/zap" "gopkg.in/elazarl/goproxy.v1" @@ -19,9 +20,9 @@ type httpProxy struct { server *http.Server } -func (h *httpProxy) Start(config api.HandlerConfig) (err error) { - options := loadFromConfig(config.Options()) - addr := fmt.Sprintf("%s:%d", config.ListenAddress(), config.Port()) +func (h *httpProxy) Start(config config.HandlerConfig) (err error) { + options := loadFromConfig(config.Options) + addr := fmt.Sprintf("%s:%d", config.ListenAddress, config.Port) h.server = &http.Server{Addr: addr, Handler: h.proxy} h.logger = h.logger.With( zap.String("address", addr), diff --git a/plugins/tls_interceptor/main.go b/plugins/tls_interceptor/main.go index 6cb8ce2..b4963af 100644 --- a/plugins/tls_interceptor/main.go +++ b/plugins/tls_interceptor/main.go @@ -5,6 +5,7 @@ import ( "fmt" "github.com/baez90/inetmock/pkg/api" "github.com/baez90/inetmock/pkg/cert" + "github.com/baez90/inetmock/pkg/config" "github.com/baez90/inetmock/pkg/logging" "github.com/google/uuid" "go.uber.org/zap" @@ -27,9 +28,9 @@ type tlsInterceptor struct { currentConnections map[uuid.UUID]*proxyConn } -func (t *tlsInterceptor) Start(config api.HandlerConfig) (err error) { - t.options = loadFromConfig(config.Options()) - addr := fmt.Sprintf("%s:%d", config.ListenAddress(), config.Port()) +func (t *tlsInterceptor) Start(config config.HandlerConfig) (err error) { + t.options = loadFromConfig(config.Options) + addr := fmt.Sprintf("%s:%d", config.ListenAddress, config.Port) t.logger = t.logger.With( zap.String("address", addr), diff --git a/proxy_config.yaml b/proxy_config.yaml deleted file mode 100644 index 7c3b15d..0000000 --- a/proxy_config.yaml +++ /dev/null @@ -1,77 +0,0 @@ -endpoints: - plainHttp: - handler: http_proxy - listenAddress: 0.0.0.0 - port: 80 - options: - rules: - - pattern: ".*\\.(?i)exe" - response: ./assets/fakeFiles/sample.exe - - pattern: ".*\\.(?i)(jpg|jpeg)" - response: ./assets/fakeFiles/default.jpg - - pattern: ".*\\.(?i)png" - response: ./assets/fakeFiles/default.png - - pattern: ".*\\.(?i)gif" - response: ./assets/fakeFiles/default.gif - - pattern: ".*\\.(?i)ico" - response: ./assets/fakeFiles/default.ico - - pattern: ".*\\.(?i)txt" - response: ./assets/fakeFiles/default.txt - - pattern: ".*" - response: ./assets/fakeFiles/default.html - httpsDowngrade: - handler: tls_interceptor - listenAddress: 0.0.0.0 - port: 443 - options: - ecdsaCurve: P256 - validity: - ca: - notBeforeRelative: 17520h - notAfterRelative: 17520h - domain: - notBeforeRelative: 168h - notAfterRelative: 168h - rootCaCert: - publicKey: ./ca.pem - privateKey: ./ca.key - certCachePath: /tmp/inetmock/ - target: - ipAddress: 127.0.0.1 - port: 80 - plainDns: - handler: dns_mock - listenAddress: 0.0.0.0 - port: 53 - options: - rules: - - pattern: "www.golem.de" - response: 77.247.84.129 - - pattern: ".*\\.google\\.com" - response: 1.1.1.1 - - pattern: ".*\\.reddit\\.com" - response: 2.2.2.2 - fallback: - strategy: incremental - args: - startIP: 10.0.0.0 - dnsOverTlsDowngrade: - handler: tls_interceptor - listenAddress: 0.0.0.0 - port: 853 - options: - ecdsaCurve: P256 - validity: - ca: - notBeforeRelative: 17520h - notAfterRelative: 17520h - domain: - notBeforeRelative: 168h - notAfterRelative: 168h - rootCaCert: - publicKey: ./ca.pem - privateKey: ./ca.key - certCachePath: /tmp/inetmock/ - target: - ipAddress: 127.0.0.1 - port: 53 \ No newline at end of file