diff --git a/.goreleaser.yml b/.goreleaser.yml index dad29e5..2f81f84 100644 --- a/.goreleaser.yml +++ b/.goreleaser.yml @@ -3,7 +3,7 @@ before: hooks: # You may remove this if you don't use go modules. - - make plugins + - make handlers builds: - id: "default" ldflags: diff --git a/Makefile b/Makefile index 66a0630..46debe7 100644 --- a/Makefile +++ b/Makefile @@ -3,6 +3,7 @@ DIR = $(dir $(realpath $(firstword $(MAKEFILE_LIST)))) BUILD_PATH = $(DIR)/main.go PKGS = $(shell go list ./...) TEST_PKGS = $(shell find . -type f -name "*_test.go" -not -path "./plugins/*" -printf '%h\n' | sort -u) +GO_GEN_FILES = $(shell grep -rnwl --include="*.go" "go:generate" $(DIR)) GOARGS = GOOS=linux GOARCH=amd64 GO_BUILD_ARGS = -ldflags="-w -s" GO_CONTAINER_BUILD_ARGS = -ldflags="-w -s" -a -installsuffix cgo @@ -11,10 +12,8 @@ BINARY_NAME = inetmock PLUGINS = $(wildcard $(DIR)plugins/*/.) DEBUG_PORT = 2345 DEBUG_ARGS?= --development-logs=true -INETMOCK_PLUGINS_DIRECTORY = $(DIR) - -.PHONY: clean all format deps update-deps compile debug snapshot-release test cli-cover-report html-cover-report plugins $(PLUGINS) +.PHONY: clean all format deps update-deps compile debug generate snapshot-release test cli-cover-report html-cover-report plugins $(PLUGINS) $(.PHONY: clean all format deps update-deps compile debug generate snapshot-release test cli-cover-report html-cover-report plugins $(PLUGINS) $(GO_GEN_FILES) all: clean format compile test plugins clean: @@ -44,8 +43,8 @@ else @$(GOARGS) go build $(GO_BUILD_ARGS) -o $(DIR)$(BINARY_NAME) $(BUILD_PATH) endif +debug: export INETMOCK_PLUGINS_DIRECTORY = $(DIR) debug: - @export INETMOCK_PLUGINS_DIRECTORY dlv exec $(DIR)$(BINARY_NAME) \ --headless \ --listen=:2345 \ @@ -53,6 +52,11 @@ debug: --accept-multiclient \ -- $(DEBUG_ARGS) +generate: + @for go_gen_target in $(GO_GEN_FILES); do \ + go generate $$go_gen_target; \ + done + snapshot-release: @goreleaser release --snapshot --skip-publish --rm-dist @@ -63,7 +67,7 @@ test: cli-cover-report: @go tool cover -func=cov.out -html-cover-report: +html-cover-report: test @go tool cover -html=cov.out -o .coverage.html plugins: $(PLUGINS) diff --git a/go.mod b/go.mod index 1cd0e7c..c8d8d9d 100644 --- a/go.mod +++ b/go.mod @@ -4,6 +4,7 @@ go 1.13 require ( github.com/fsnotify/fsnotify v1.4.9 // indirect + github.com/golang/mock v1.4.3 github.com/mitchellh/mapstructure v1.2.2 // indirect github.com/pelletier/go-toml v1.7.0 // indirect github.com/spf13/afero v1.2.2 // indirect diff --git a/go.sum b/go.sum index 34c4685..36f19f8 100644 --- a/go.sum +++ b/go.sum @@ -33,7 +33,10 @@ github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7a github.com/gogo/protobuf v1.2.1/go.mod h1:hp+jE20tsWTFYpLwKvXlhS1hjn+gTNwPg2I6zVXpSg4= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= github.com/golang/groupcache v0.0.0-20190129154638-5b532d6fd5ef/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= +github.com/golang/mock v1.1.1 h1:G5FRp8JnTd7RQH5kemVNlMeyXQAztQ3mOWV95KxsXH8= github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= +github.com/golang/mock v1.4.3 h1:GV+pQPG/EUUbkh47niozDcADz6go/dUwhVzdUQHIVRw= +github.com/golang/mock v1.4.3/go.mod h1:UOMv5ysSaYNkG+OFQykRIcU/QvvxJf3p21QfJ2Bt3cw= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/google/btree v1.0.0/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= @@ -177,6 +180,7 @@ golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20191005200804-aed5e4c7ecf9/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200409092240-59c9f1ba88fa h1:mQTN3ECqfsViCNBgq+A40vdwhkGykrrQlYe3mPj6BoU= golang.org/x/sys v0.0.0-20200409092240-59c9f1ba88fa/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.2 h1:tW2bmiBqwgJj/UpqtC8EpXEZVYOwU0yG4iWbprSVAcs= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= @@ -186,6 +190,7 @@ golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGm golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= golang.org/x/tools v0.0.0-20190328211700-ab21143f2384/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= +golang.org/x/tools v0.0.0-20190425150028-36563e24a262/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= golang.org/x/tools v0.0.0-20190621195816-6e04913cbbac/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc= golang.org/x/tools v0.0.0-20191029041327-9cc4af7d6b2c/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20191029190741-b9c20aec41a5/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= @@ -215,3 +220,5 @@ gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.1-2019.2.3 h1:3JgtbtFHMiCmsznwGVTUWbgGov+pVqnlf1dEJTNAXeM= honnef.co/go/tools v0.0.1-2019.2.3/go.mod h1:a3bituU0lyd329TUQxRnasdCoJDkEUEAqEt0JzvZhAg= +rsc.io/quote/v3 v3.1.0/go.mod h1:yEA65RcK8LyAZtP9Kv3t0HmxON59tX3rD+tICJqUlj0= +rsc.io/sampler v1.3.0/go.mod h1:T1hPZKmBbMNahiBKFy5HrXp6adAjACjK9JXDnKaTXpA= diff --git a/internal/cmd/init.go b/internal/cmd/init.go index 5a5d4e9..eb8494a 100644 --- a/internal/cmd/init.go +++ b/internal/cmd/init.go @@ -1,12 +1,14 @@ package cmd import ( + "github.com/baez90/inetmock/internal/endpoints" "github.com/baez90/inetmock/internal/plugins" "github.com/baez90/inetmock/pkg/logging" "github.com/baez90/inetmock/pkg/path" "github.com/spf13/viper" "go.uber.org/zap" "os" + "time" ) var ( @@ -25,6 +27,7 @@ func initApp() (err error) { ) logger, _ = logging.CreateLogger() registry := plugins.Registry() + endpointManager = endpoints.NewEndpointManager(logger) if err = rootCmd.ParseFlags(os.Args); err != nil { return @@ -40,12 +43,18 @@ func initApp() (err error) { viperInst := viper.GetViper() pluginDir := viperInst.GetString("plugins-directory") + pluginLoadStartTime := time.Now() if err = registry.LoadPlugins(pluginDir); err != nil { logger.Error("Failed to load plugins", zap.String("pluginsDirectory", pluginDir), zap.Error(err), ) } + pluginLoadDuration := time.Since(pluginLoadStartTime) + logger.Info( + "loading plugins completed", + zap.Duration("pluginLoadDuration", pluginLoadDuration), + ) pluginsCmd.AddCommand(registry.PluginCommands()...) diff --git a/internal/cmd/root.go b/internal/cmd/root.go index 293db28..bdab14b 100644 --- a/internal/cmd/root.go +++ b/internal/cmd/root.go @@ -2,20 +2,20 @@ package cmd import ( "github.com/baez90/inetmock/internal/config" - "github.com/baez90/inetmock/internal/plugins" + "github.com/baez90/inetmock/internal/endpoints" "github.com/baez90/inetmock/pkg/api" + "github.com/baez90/inetmock/pkg/logging" "github.com/spf13/cobra" "github.com/spf13/viper" "go.uber.org/zap" "os" "os/signal" "strings" - "sync" "syscall" ) var ( - logger *zap.Logger + logger logging.Logger rootCmd = cobra.Command{ Use: "", Short: "INetMock is lightweight internet mock", @@ -27,6 +27,7 @@ var ( developmentLogs bool handlers []api.ProtocolHandler appConfig = config.CreateConfig() + endpointManager endpoints.EndpointManager ) func init() { @@ -39,28 +40,21 @@ func init() { } func startInetMock(cmd *cobra.Command, args []string) { - registry := plugins.Registry() - var wg sync.WaitGroup - - //todo introduce endpoint type and move startup and shutdown to this type - - for key, val := range viper.GetStringMap(config.EndpointsKey) { - handlerSubConfig := viper.Sub(strings.Join([]string{config.EndpointsKey, key, config.OptionsKey}, ".")) - pluginConfig := config.CreateHandlerConfig(val, handlerSubConfig) - logger.Info(key, zap.Any("value", pluginConfig)) - - if handler, ok := registry.HandlerForName(pluginConfig.HandlerName()); ok { - handlers = append(handlers, handler) - go startEndpoint(handler, pluginConfig, logger) - wg.Add(1) - } else { + for endpointName := range viper.GetStringMap(config.EndpointsKey) { + handlerSubConfig := viper.Sub(strings.Join([]string{config.EndpointsKey, endpointName}, ".")) + handlerConfig := config.CreateMultiHandlerConfig(handlerSubConfig) + if err := endpointManager.CreateEndpoint(endpointName, handlerConfig); err != nil { logger.Warn( - "no matching handler registered", - zap.String("handler", pluginConfig.HandlerName()), + "error occurred while creating endpoint", + zap.String("endpointName", endpointName), + zap.String("handlerName", handlerConfig.HandlerName()), + zap.Error(err), ) } } + endpointManager.StartEndpoints() + signalChannel := make(chan os.Signal, 1) signal.Notify(signalChannel, syscall.SIGINT, syscall.SIGTERM, syscall.SIGQUIT) @@ -72,35 +66,7 @@ func startInetMock(cmd *cobra.Command, args []string) { zap.String("signal", s.String()), ) - for _, handler := range handlers { - go shutdownEndpoint(handler, &wg, logger) - } - - wg.Wait() -} - -func startEndpoint(handler api.ProtocolHandler, config config.HandlerConfig, logger *zap.Logger) { - defer func() { - if r := recover(); r != nil { - logger.Fatal( - "recovered panic during startup of endpoint", - zap.Any("recovered", r), - ) - } - }() - handler.Start(config) -} - -func shutdownEndpoint(handler api.ProtocolHandler, wg *sync.WaitGroup, logger *zap.Logger) { - defer func() { - if r := recover(); r != nil { - logger.Fatal( - "recovered panic during shutdown of endpoint", - zap.Any("recovered", r), - ) - } - }() - handler.Shutdown(wg) + endpointManager.ShutdownEndpoints() } func ExecuteRootCommand() error { diff --git a/internal/config/handler_config.go b/internal/config/handler_config.go index 27bf5a3..cfeee30 100644 --- a/internal/config/handler_config.go +++ b/internal/config/handler_config.go @@ -4,19 +4,36 @@ import "github.com/spf13/viper" const ( pluginConfigKey = "handler" - listenAddressConfigKey = "listenaddress" + listenAddressConfigKey = "listenAddress" portConfigKey = "port" + portsConfigKey = "ports" ) +type HandlerConfig interface { + HandlerName() string + ListenAddress() string + Port() uint16 + Options() *viper.Viper +} + +func NewHandlerConfig(handlerName string, port uint16, listenAddress string, options *viper.Viper) HandlerConfig { + return &handlerConfig{ + handlerName: handlerName, + port: port, + listenAddress: listenAddress, + options: options, + } +} + type handlerConfig struct { - pluginName string + handlerName string port uint16 listenAddress string options *viper.Viper } func (h handlerConfig) HandlerName() string { - return h.pluginName + return h.handlerName } func (h handlerConfig) ListenAddress() string { @@ -30,20 +47,3 @@ func (h handlerConfig) Port() uint16 { func (h handlerConfig) Options() *viper.Viper { return h.options } - -type HandlerConfig interface { - HandlerName() string - ListenAddress() string - Port() uint16 - Options() *viper.Viper -} - -func CreateHandlerConfig(configMap interface{}, subConfig *viper.Viper) HandlerConfig { - underlyingMap := configMap.(map[string]interface{}) - return &handlerConfig{ - pluginName: underlyingMap[pluginConfigKey].(string), - listenAddress: underlyingMap[listenAddressConfigKey].(string), - port: uint16(underlyingMap[portConfigKey].(int)), - options: subConfig, - } -} diff --git a/internal/config/handler_config_test.go b/internal/config/handler_config_test.go new file mode 100644 index 0000000..838c5fd --- /dev/null +++ b/internal/config/handler_config_test.go @@ -0,0 +1,167 @@ +package config + +import ( + "github.com/spf13/viper" + "reflect" + "testing" +) + +func Test_handlerConfig_HandlerName(t *testing.T) { + type fields struct { + handlerName string + port uint16 + listenAddress string + options *viper.Viper + } + tests := []struct { + name string + fields fields + want string + }{ + { + name: "Get empty HandlerName for uninitialized struct", + fields: fields{}, + want: "", + }, + { + name: "Get expected HandlerName for initialized struct", + fields: fields{ + handlerName: "sampleHandler", + }, + want: "sampleHandler", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + h := handlerConfig{ + handlerName: tt.fields.handlerName, + port: tt.fields.port, + listenAddress: tt.fields.listenAddress, + options: tt.fields.options, + } + if got := h.HandlerName(); got != tt.want { + t.Errorf("HandlerName() = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_handlerConfig_ListenAddress(t *testing.T) { + type fields struct { + handlerName string + port uint16 + listenAddress string + options *viper.Viper + } + tests := []struct { + name string + fields fields + want string + }{ + { + name: "Get empty ListenAddress for uninitialized struct", + fields: fields{}, + want: "", + }, + { + name: "Get expected ListenAddress for initialized struct", + fields: fields{ + listenAddress: "0.0.0.0", + }, + want: "0.0.0.0", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + h := handlerConfig{ + handlerName: tt.fields.handlerName, + port: tt.fields.port, + listenAddress: tt.fields.listenAddress, + options: tt.fields.options, + } + if got := h.ListenAddress(); got != tt.want { + t.Errorf("ListenAddress() = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_handlerConfig_Options(t *testing.T) { + type fields struct { + handlerName string + port uint16 + listenAddress string + options *viper.Viper + } + tests := []struct { + name string + fields fields + want *viper.Viper + }{ + { + name: "Get nil Options for uninitialized struct", + fields: fields{}, + want: nil, + }, + { + name: "Get expected Options for initialized struct", + fields: fields{ + options: viper.New(), + }, + want: viper.New(), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + h := handlerConfig{ + handlerName: tt.fields.handlerName, + port: tt.fields.port, + listenAddress: tt.fields.listenAddress, + options: tt.fields.options, + } + if got := h.Options(); !reflect.DeepEqual(got, tt.want) { + t.Errorf("Options() = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_handlerConfig_Port(t *testing.T) { + type fields struct { + handlerName string + port uint16 + listenAddress string + options *viper.Viper + } + tests := []struct { + name string + fields fields + want uint16 + }{ + { + name: "Get empty Port for uninitialized struct", + fields: fields{}, + want: 0, + }, + { + name: "Get expected Port for initialized struct", + fields: fields{ + port: 80, + }, + want: 80, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + h := handlerConfig{ + handlerName: tt.fields.handlerName, + port: tt.fields.port, + listenAddress: tt.fields.listenAddress, + options: tt.fields.options, + } + if got := h.Port(); got != tt.want { + t.Errorf("Port() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/internal/config/loading.go b/internal/config/loading.go index e699b82..c0b1a03 100644 --- a/internal/config/loading.go +++ b/internal/config/loading.go @@ -22,7 +22,7 @@ type Config interface { } type config struct { - logger *zap.Logger + logger logging.Logger } func (c config) InitConfig(flags *pflag.FlagSet) { diff --git a/internal/config/multi_handler_config.go b/internal/config/multi_handler_config.go new file mode 100644 index 0000000..6ea35a6 --- /dev/null +++ b/internal/config/multi_handler_config.go @@ -0,0 +1,48 @@ +package config + +import ( + "github.com/spf13/viper" +) + +type MultiHandlerConfig interface { + HandlerName() string + ListenAddress() string + Ports() []uint16 + Options() *viper.Viper + HandlerConfigs() []HandlerConfig +} + +type multiHandlerConfig struct { + handlerName string + ports []uint16 + listenAddress string + options *viper.Viper +} + +func NewMultiHandlerConfig(handlerName string, ports []uint16, listenAddress string, options *viper.Viper) MultiHandlerConfig { + return &multiHandlerConfig{handlerName: handlerName, ports: ports, listenAddress: listenAddress, options: options} +} + +func (m multiHandlerConfig) HandlerName() string { + return m.handlerName +} + +func (m multiHandlerConfig) ListenAddress() string { + return m.listenAddress +} + +func (m multiHandlerConfig) Ports() []uint16 { + return m.ports +} + +func (m multiHandlerConfig) Options() *viper.Viper { + return m.options +} + +func (m multiHandlerConfig) HandlerConfigs() []HandlerConfig { + configs := make([]HandlerConfig, 0) + for _, port := range m.ports { + configs = append(configs, NewHandlerConfig(m.handlerName, port, m.listenAddress, m.options)) + } + return configs +} diff --git a/internal/config/multi_handler_config_test.go b/internal/config/multi_handler_config_test.go new file mode 100644 index 0000000..0d285e9 --- /dev/null +++ b/internal/config/multi_handler_config_test.go @@ -0,0 +1,240 @@ +package config + +import ( + "github.com/spf13/viper" + "reflect" + "testing" +) + +func Test_multiHandlerConfig_HandlerConfigs(t *testing.T) { + type fields struct { + handlerName string + ports []uint16 + listenAddress string + options *viper.Viper + } + tests := []struct { + name string + fields fields + want []HandlerConfig + }{ + { + name: "Get empty array if no ports are set", + fields: fields{}, + want: make([]HandlerConfig, 0), + }, + { + name: "Get a single handler config if only one port is set", + fields: fields{ + handlerName: "sampleHandler", + ports: []uint16{80}, + listenAddress: "0.0.0.0", + options: nil, + }, + want: []HandlerConfig{ + &handlerConfig{ + handlerName: "sampleHandler", + port: 80, + listenAddress: "0.0.0.0", + options: nil, + }, + }, + }, + { + name: "Get multiple handler configs if only one port is set", + fields: fields{ + handlerName: "sampleHandler", + ports: []uint16{80, 8080}, + listenAddress: "0.0.0.0", + options: nil, + }, + want: []HandlerConfig{ + &handlerConfig{ + handlerName: "sampleHandler", + port: 80, + listenAddress: "0.0.0.0", + options: nil, + }, + &handlerConfig{ + handlerName: "sampleHandler", + port: 8080, + listenAddress: "0.0.0.0", + options: nil, + }, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + m := multiHandlerConfig{ + handlerName: tt.fields.handlerName, + ports: tt.fields.ports, + listenAddress: tt.fields.listenAddress, + options: tt.fields.options, + } + if got := m.HandlerConfigs(); !reflect.DeepEqual(got, tt.want) { + t.Errorf("HandlerConfigs() = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_multiHandlerConfig_HandlerName(t *testing.T) { + type fields struct { + handlerName string + ports []uint16 + listenAddress string + options *viper.Viper + } + tests := []struct { + name string + fields fields + want string + }{ + { + name: "Get empty handler name for uninitialized struct", + fields: fields{}, + want: "", + }, + { + name: "Get expected handler name for initialized struct", + fields: fields{ + handlerName: "sampleHandler", + }, + want: "sampleHandler", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + m := multiHandlerConfig{ + handlerName: tt.fields.handlerName, + ports: tt.fields.ports, + listenAddress: tt.fields.listenAddress, + options: tt.fields.options, + } + if got := m.HandlerName(); got != tt.want { + t.Errorf("HandlerName() = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_multiHandlerConfig_ListenAddress(t *testing.T) { + type fields struct { + handlerName string + ports []uint16 + listenAddress string + options *viper.Viper + } + tests := []struct { + name string + fields fields + want string + }{ + { + name: "Get empty ListenAddress for uninitialized struct", + fields: fields{}, + want: "", + }, + { + name: "Get expected ListenAddress for initialized struct", + fields: fields{ + listenAddress: "0.0.0.0", + }, + want: "0.0.0.0", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + m := multiHandlerConfig{ + handlerName: tt.fields.handlerName, + ports: tt.fields.ports, + listenAddress: tt.fields.listenAddress, + options: tt.fields.options, + } + if got := m.ListenAddress(); got != tt.want { + t.Errorf("ListenAddress() = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_multiHandlerConfig_Options(t *testing.T) { + type fields struct { + handlerName string + ports []uint16 + listenAddress string + options *viper.Viper + } + tests := []struct { + name string + fields fields + want *viper.Viper + }{ + { + name: "Get nil Options for uninitialized struct", + fields: fields{}, + want: nil, + }, + { + name: "Get expected Options for initialized struct", + fields: fields{ + options: viper.New(), + }, + want: viper.New(), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + m := multiHandlerConfig{ + handlerName: tt.fields.handlerName, + ports: tt.fields.ports, + listenAddress: tt.fields.listenAddress, + options: tt.fields.options, + } + if got := m.Options(); !reflect.DeepEqual(got, tt.want) { + t.Errorf("Options() = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_multiHandlerConfig_Ports(t *testing.T) { + type fields struct { + handlerName string + ports []uint16 + listenAddress string + options *viper.Viper + } + tests := []struct { + name string + fields fields + want []uint16 + }{ + { + name: "Get empty Ports for uninitialized struct", + fields: fields{}, + want: nil, + }, + { + name: "Get expected Ports for initialized struct", + fields: fields{ + ports: []uint16{80, 8080}, + }, + want: []uint16{80, 8080}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + m := multiHandlerConfig{ + handlerName: tt.fields.handlerName, + ports: tt.fields.ports, + listenAddress: tt.fields.listenAddress, + options: tt.fields.options, + } + if got := m.Ports(); !reflect.DeepEqual(got, tt.want) { + t.Errorf("Ports() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/internal/config/parsing.go b/internal/config/parsing.go new file mode 100644 index 0000000..adc4a09 --- /dev/null +++ b/internal/config/parsing.go @@ -0,0 +1,26 @@ +package config + +import "github.com/spf13/viper" + +func CreateMultiHandlerConfig(handlerConfig *viper.Viper) MultiHandlerConfig { + return NewMultiHandlerConfig( + handlerConfig.GetString(pluginConfigKey), + portsFromConfig(handlerConfig), + handlerConfig.GetString(listenAddressConfigKey), + handlerConfig.Sub(OptionsKey), + ) +} + +func portsFromConfig(handlerConfig *viper.Viper) (ports []uint16) { + if portsInt := handlerConfig.GetIntSlice(portsConfigKey); len(portsInt) > 0 { + for _, port := range portsInt { + ports = append(ports, uint16(port)) + } + return + } + + if portInt := handlerConfig.GetInt(portConfigKey); portInt > 0 { + ports = append(ports, uint16(portInt)) + } + return +} diff --git a/internal/config/parsing_test.go b/internal/config/parsing_test.go new file mode 100644 index 0000000..5f1efb7 --- /dev/null +++ b/internal/config/parsing_test.go @@ -0,0 +1,145 @@ +package config + +import ( + "bytes" + "github.com/spf13/viper" + "reflect" + "testing" +) + +func TestCreateMultiHandlerConfig(t *testing.T) { + type args struct { + handlerConfig *viper.Viper + } + tests := []struct { + name string + args args + want MultiHandlerConfig + }{ + { + name: "Get simple multiHandlerConfig from config", + args: args{ + handlerConfig: configFromString(` +handler: sampleHandler +listenAddress: 0.0.0.0 +ports: +- 80 +- 8080 +options: {} +`), + }, + want: &multiHandlerConfig{ + handlerName: "sampleHandler", + ports: []uint16{80, 8080}, + listenAddress: "0.0.0.0", + options: viper.New(), + }, + }, + { + name: "Get more complex multiHandlerConfig from config", + args: args{ + handlerConfig: configFromString(` +handler: sampleHandler +listenAddress: 0.0.0.0 +ports: +- 80 +- 8080 +options: + optionA: asdf + optionB: as1234 +`), + }, + want: &multiHandlerConfig{ + handlerName: "sampleHandler", + ports: []uint16{80, 8080}, + listenAddress: "0.0.0.0", + options: configFromString(` +nesting: + optionA: asdf + optionB: as1234 +`).Sub("nesting"), + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := CreateMultiHandlerConfig(tt.args.handlerConfig); !reflect.DeepEqual(got, tt.want) { + t.Errorf("CreateMultiHandlerConfig() = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_portsFromConfig(t *testing.T) { + type args struct { + handlerConfig *viper.Viper + } + tests := []struct { + name string + args args + wantPorts []uint16 + }{ + { + name: "Empty array if config value is not set", + args: args{ + handlerConfig: viper.New(), + }, + wantPorts: nil, + }, + { + name: "Array of one if `port` is set", + args: args{ + handlerConfig: configFromString(` +port: 80 +`), + }, + wantPorts: []uint16{80}, + }, + { + name: "Array of one if `ports` is set as array", + args: args{ + handlerConfig: configFromString(` +ports: +- 80 +`), + }, + wantPorts: []uint16{80}, + }, + { + name: "Array of two if `ports` is set as array", + args: args{ + handlerConfig: configFromString(` +ports: +- 80 +- 8080 +`), + }, + wantPorts: []uint16{80, 8080}, + }, + { + name: "Array of two if `port` is set as array", + args: args{ + handlerConfig: configFromString(` +ports: +- 80 +- 8080 +`), + }, + wantPorts: []uint16{80, 8080}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if gotPorts := portsFromConfig(tt.args.handlerConfig); !reflect.DeepEqual(gotPorts, tt.wantPorts) { + t.Errorf("portsFromConfig() = %v, want %v", gotPorts, tt.wantPorts) + } + }) + } +} + +func configFromString(yaml string) (config *viper.Viper) { + config = viper.New() + config.SetConfigType("yaml") + _ = config.ReadConfig(bytes.NewBufferString(yaml)) + return +} diff --git a/internal/endpoints/endpoint.go b/internal/endpoints/endpoint.go new file mode 100644 index 0000000..b209339 --- /dev/null +++ b/internal/endpoints/endpoint.go @@ -0,0 +1,31 @@ +//go:generate mockgen -source=endpoint.go -destination=./../../mock/endpoint_mock.go -package=mock +package endpoints + +import ( + "github.com/baez90/inetmock/internal/config" + "github.com/baez90/inetmock/pkg/api" +) + +type Endpoint interface { + Start() error + Shutdown() error + Name() string +} + +type endpoint struct { + name string + handler api.ProtocolHandler + config config.HandlerConfig +} + +func (e endpoint) Name() string { + return e.name +} + +func (e *endpoint) Start() (err error) { + return e.handler.Start(e.config) +} + +func (e *endpoint) Shutdown() (err error) { + return e.handler.Shutdown() +} diff --git a/internal/endpoints/endpoint_manager.go b/internal/endpoints/endpoint_manager.go new file mode 100644 index 0000000..209e3b7 --- /dev/null +++ b/internal/endpoints/endpoint_manager.go @@ -0,0 +1,130 @@ +package endpoints + +import ( + "fmt" + "github.com/baez90/inetmock/internal/config" + "github.com/baez90/inetmock/internal/plugins" + "github.com/baez90/inetmock/pkg/logging" + "go.uber.org/zap" + "sync" + "time" +) + +type EndpointManager interface { + RegisteredEndpoints() []Endpoint + StartedEndpoints() []Endpoint + CreateEndpoint(name string, multiHandlerConfig config.MultiHandlerConfig) error + StartEndpoints() + ShutdownEndpoints() +} + +func NewEndpointManager(logger logging.Logger) EndpointManager { + return &endpointManager{ + logger: logger, + registry: plugins.Registry(), + } +} + +type endpointManager struct { + logger logging.Logger + registeredEndpoints []Endpoint + properlyStartedEndpoints []Endpoint + registry plugins.HandlerRegistry +} + +func (e endpointManager) RegisteredEndpoints() []Endpoint { + return e.registeredEndpoints +} + +func (e endpointManager) StartedEndpoints() []Endpoint { + return e.properlyStartedEndpoints +} + +func (e *endpointManager) CreateEndpoint(name string, multiHandlerConfig config.MultiHandlerConfig) (err error) { + if handler, ok := e.registry.HandlerForName(multiHandlerConfig.HandlerName()); ok { + for _, handlerConfig := range multiHandlerConfig.HandlerConfigs() { + e.registeredEndpoints = append(e.registeredEndpoints, &endpoint{ + name: name, + handler: handler, + config: handlerConfig, + }) + } + } else { + err = fmt.Errorf("no matching handler registered for name %s", multiHandlerConfig.HandlerName()) + } + return +} + +func (e *endpointManager) StartEndpoints() { + startTime := time.Now() + for _, endpoint := range e.registeredEndpoints { + endpointLogger := e.logger.With( + zap.String("endpoint", endpoint.Name()), + ) + endpointLogger.Info("Starting endpoint") + if ok := startEndpoint(endpoint, endpointLogger); ok { + e.properlyStartedEndpoints = append(e.properlyStartedEndpoints, endpoint) + endpointLogger.Info("successfully started endpoint") + } else { + endpointLogger.Error("error occurred during endpoint startup - will be skipped for now") + } + } + endpointStartupDuration := time.Since(startTime) + e.logger.Info( + "Startup of all endpoints completed", + zap.Duration("startupTime", endpointStartupDuration), + ) +} + +func (e *endpointManager) ShutdownEndpoints() { + var waitGroup sync.WaitGroup + waitGroup.Add(len(e.properlyStartedEndpoints)) + + for _, endpoint := range e.properlyStartedEndpoints { + endpointLogger := e.logger.With( + zap.String("endpoint", endpoint.Name()), + ) + endpointLogger.Info("Triggering shutdown of endpoint") + go shutdownEndpoint(endpoint, endpointLogger, &waitGroup) + } + + waitGroup.Wait() +} + +func startEndpoint(ep Endpoint, logger logging.Logger) (success bool) { + defer func() { + if r := recover(); r != nil { + logger.Fatal( + "recovered panic during startup of endpoint", + zap.Any("recovered", r), + ) + } + }() + if err := ep.Start(); err != nil { + logger.Error( + "failed to start endpoint", + zap.Error(err), + ) + } else { + success = true + } + return +} + +func shutdownEndpoint(ep Endpoint, logger logging.Logger, wg *sync.WaitGroup) { + defer func() { + if r := recover(); r != nil { + logger.Fatal( + "recovered panic during shutdown of endpoint", + zap.Any("recovered", r), + ) + } + wg.Done() + }() + if err := ep.Shutdown(); err != nil { + logger.Error( + "Failed to shutdown endpoint", + zap.Error(err), + ) + } +} diff --git a/internal/endpoints/endpoint_manager_test.go b/internal/endpoints/endpoint_manager_test.go new file mode 100644 index 0000000..b629cef --- /dev/null +++ b/internal/endpoints/endpoint_manager_test.go @@ -0,0 +1,116 @@ +package endpoints + +import ( + "github.com/baez90/inetmock/internal/config" + "github.com/baez90/inetmock/internal/plugins" + "github.com/baez90/inetmock/mock" + "github.com/baez90/inetmock/pkg/logging" + "github.com/golang/mock/gomock" + "testing" +) + +func Test_endpointManager_CreateEndpoint(t *testing.T) { + type fields struct { + logger logging.Logger + registeredEndpoints []Endpoint + properlyStartedEndpoints []Endpoint + registry plugins.HandlerRegistry + } + type args struct { + name string + multiHandlerConfig config.MultiHandlerConfig + } + tests := []struct { + name string + fields fields + args args + wantErr bool + wantEndpoints int + }{ + { + name: "Test add endpoint", + wantErr: false, + wantEndpoints: 1, + fields: fields{ + logger: func() logging.Logger { + return mock.NewMockLogger(gomock.NewController(t)) + }(), + registeredEndpoints: nil, + properlyStartedEndpoints: nil, + registry: func() plugins.HandlerRegistry { + registry := mock.NewMockHandlerRegistry(gomock.NewController(t)) + registry. + EXPECT(). + HandlerForName("sampleHandler"). + MinTimes(1). + MaxTimes(1). + Return(mock.NewMockProtocolHandler(gomock.NewController(t)), true) + return registry + }(), + }, + args: args{ + name: "sampleEndpoint", + multiHandlerConfig: config.NewMultiHandlerConfig( + "sampleHandler", + []uint16{80}, + "0.0.0.0", + nil, + ), + }, + }, + { + name: "Test add unknown handler", + wantErr: true, + wantEndpoints: 0, + fields: fields{ + logger: func() logging.Logger { + return mock.NewMockLogger(gomock.NewController(t)) + }(), + registeredEndpoints: nil, + properlyStartedEndpoints: nil, + registry: func() plugins.HandlerRegistry { + registry := mock.NewMockHandlerRegistry(gomock.NewController(t)) + registry. + EXPECT(). + HandlerForName("sampleHandler"). + MinTimes(1). + MaxTimes(1). + Return(nil, false) + return registry + }(), + }, + args: args{ + name: "sampleEndpoint", + multiHandlerConfig: config.NewMultiHandlerConfig( + "sampleHandler", + []uint16{80}, + "0.0.0.0", + nil, + ), + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + e := &endpointManager{ + logger: tt.fields.logger, + registeredEndpoints: tt.fields.registeredEndpoints, + properlyStartedEndpoints: tt.fields.properlyStartedEndpoints, + registry: tt.fields.registry, + } + + if err := e.CreateEndpoint(tt.args.name, tt.args.multiHandlerConfig); (err != nil) != tt.wantErr { + t.Errorf("CreateEndpoint() error = %v, wantErr %v", err, tt.wantErr) + } + + if len(e.RegisteredEndpoints()) != tt.wantEndpoints { + t.Errorf("RegisteredEndpoints() = %d, want = 1", len(e.RegisteredEndpoints())) + return + } + + if len(e.RegisteredEndpoints()) > 0 && e.RegisteredEndpoints()[0].Name() != tt.args.name { + t.Errorf("Name() = %s, want = %s", e.RegisteredEndpoints()[0].Name(), tt.args.name) + } + }) + } +} diff --git a/internal/endpoints/endpoint_test.go b/internal/endpoints/endpoint_test.go new file mode 100644 index 0000000..d63a99f --- /dev/null +++ b/internal/endpoints/endpoint_test.go @@ -0,0 +1,179 @@ +package endpoints + +import ( + "fmt" + "github.com/baez90/inetmock/internal/config" + "github.com/baez90/inetmock/mock" + "github.com/baez90/inetmock/pkg/api" + "github.com/golang/mock/gomock" + "testing" +) + +func Test_endpoint_Name(t *testing.T) { + type fields struct { + name string + handler api.ProtocolHandler + config config.HandlerConfig + } + tests := []struct { + name string + fields fields + want string + }{ + { + name: "Empty Name if struct is uninitialized", + fields: fields{}, + want: "", + }, + { + name: "Expected Name if struct is initialized", + fields: fields{ + name: "sampleHandler", + }, + want: "sampleHandler", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + e := endpoint{ + name: tt.fields.name, + handler: tt.fields.handler, + config: tt.fields.config, + } + if got := e.Name(); got != tt.want { + t.Errorf("Name() = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_endpoint_Shutdown(t *testing.T) { + type fields struct { + name string + handler api.ProtocolHandler + config config.HandlerConfig + } + tests := []struct { + name string + fields fields + wantErr bool + }{ + { + name: "Expect no error if mocked handler does not return one", + fields: fields{ + handler: func() api.ProtocolHandler { + handler := mock.NewMockProtocolHandler(gomock.NewController(t)) + handler.EXPECT(). + Shutdown(). + MaxTimes(1). + Return(nil) + return handler + }(), + }, + wantErr: false, + }, + { + name: "Expect error if mocked handler returns one", + fields: fields{ + handler: func() api.ProtocolHandler { + handler := mock.NewMockProtocolHandler(gomock.NewController(t)) + handler.EXPECT(). + Shutdown(). + MaxTimes(1). + Return(fmt.Errorf("")) + return handler + }(), + }, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + e := &endpoint{ + name: tt.fields.name, + handler: tt.fields.handler, + config: tt.fields.config, + } + if err := e.Shutdown(); (err != nil) != tt.wantErr { + t.Errorf("Shutdown() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func Test_endpoint_Start(t *testing.T) { + + demoHandlerConfig := config.NewHandlerConfig( + "sampleHandler", + 80, + "0.0.0.0", + nil, + ) + + type fields struct { + name string + handler api.ProtocolHandler + config config.HandlerConfig + } + tests := []struct { + name string + fields fields + wantErr bool + }{ + { + name: "Expect no error if mocked handler does not return one", + fields: fields{ + handler: func() api.ProtocolHandler { + handler := mock.NewMockProtocolHandler(gomock.NewController(t)) + handler.EXPECT(). + Start(nil). + MaxTimes(1). + Return(nil) + return handler + }(), + }, + wantErr: false, + }, + { + name: "Expect error if mocked handler returns one", + fields: fields{ + handler: func() api.ProtocolHandler { + handler := mock.NewMockProtocolHandler(gomock.NewController(t)) + handler.EXPECT(). + Start(nil). + MaxTimes(1). + Return(fmt.Errorf("")) + return handler + }(), + }, + wantErr: true, + }, + { + name: "Expect config to be passed to Start call", + fields: fields{ + config: demoHandlerConfig, + handler: func() api.ProtocolHandler { + handler := mock.NewMockProtocolHandler(gomock.NewController(t)) + handler.EXPECT(). + Start(demoHandlerConfig). + MaxTimes(1). + Return(fmt.Errorf("")) + return handler + }(), + }, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + e := &endpoint{ + name: tt.fields.name, + handler: tt.fields.handler, + config: tt.fields.config, + } + if err := e.Start(); (err != nil) != tt.wantErr { + t.Errorf("Start() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} diff --git a/internal/plugins/loading.go b/internal/plugins/loading.go index 1ef3ad7..5e67ab2 100644 --- a/internal/plugins/loading.go +++ b/internal/plugins/loading.go @@ -1,3 +1,4 @@ +//go:generate mockgen -source=loading.go -destination=./../../mock/plugins_mock.go -package=mock package plugins import ( @@ -42,7 +43,7 @@ func (h *handlerRegistry) HandlerForName(handlerName string) (instance api.Proto func (h *handlerRegistry) RegisterHandler(handlerName string, handlerProvider api.PluginInstanceFactory, subCommands ...*cobra.Command) { if _, exists := h.handlers[handlerName]; exists { - panic(fmt.Sprintf("plugin %s already registered - there's something strange...in the neighborhood")) + panic(fmt.Sprintf("handler with name %s is already registered - there's something strange...in the neighborhood", handlerName)) } h.handlers[handlerName] = handlerProvider diff --git a/internal/plugins/loading_test.go b/internal/plugins/loading_test.go new file mode 100644 index 0000000..7ad9ae7 --- /dev/null +++ b/internal/plugins/loading_test.go @@ -0,0 +1,109 @@ +package plugins + +import ( + "github.com/baez90/inetmock/pkg/api" + "github.com/spf13/cobra" + "reflect" + "testing" +) + +func Test_handlerRegistry_PluginCommands(t *testing.T) { + type fields struct { + handlers map[string]api.PluginInstanceFactory + pluginCommands []*cobra.Command + } + tests := []struct { + name string + fields fields + want []*cobra.Command + }{ + { + name: "Default is an nil array of commands", + fields: fields{}, + want: nil, + }, + { + name: "Returns a copy of the given array of commands", + fields: fields{ + pluginCommands: []*cobra.Command{ + { + Use: "my-super-command", + Short: "bla bla bla, description", + }, + }, + }, + want: []*cobra.Command{ + { + Use: "my-super-command", + Short: "bla bla bla, description", + }, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + h := handlerRegistry{ + handlers: tt.fields.handlers, + pluginCommands: tt.fields.pluginCommands, + } + if got := h.PluginCommands(); !reflect.DeepEqual(got, tt.want) { + t.Errorf("PluginCommands() = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_handlerRegistry_HandlerForName(t *testing.T) { + type fields struct { + handlers map[string]api.PluginInstanceFactory + pluginCommands []*cobra.Command + } + type args struct { + handlerName string + } + tests := []struct { + name string + fields fields + args args + wantInstance api.ProtocolHandler + wantOk bool + }{ + { + name: "No instance if nothing is registered", + fields: fields{}, + args: args{}, + wantInstance: nil, + wantOk: false, + }, + { + name: "Nil instance from pseudo factory", + fields: fields{ + handlers: map[string]api.PluginInstanceFactory{ + "pseudo": func() api.ProtocolHandler { + return nil + }, + }, + }, + args: args{ + handlerName: "pseudo", + }, + wantInstance: nil, + wantOk: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + h := &handlerRegistry{ + handlers: tt.fields.handlers, + pluginCommands: tt.fields.pluginCommands, + } + gotInstance, gotOk := h.HandlerForName(tt.args.handlerName) + if !reflect.DeepEqual(gotInstance, tt.wantInstance) { + t.Errorf("HandlerForName() gotInstance = %v, want %v", gotInstance, tt.wantInstance) + } + if gotOk != tt.wantOk { + t.Errorf("HandlerForName() gotOk = %v, want %v", gotOk, tt.wantOk) + } + }) + } +} diff --git a/mock_config.yaml b/mock_config.yaml index 450facc..ad98264 100644 --- a/mock_config.yaml +++ b/mock_config.yaml @@ -33,7 +33,9 @@ endpoints: plainHttp: handler: http_mock listenAddress: 0.0.0.0 - port: 80 + ports: + - 80 + - 8080 options: <<: *httpResponseRules proxy: @@ -46,7 +48,9 @@ endpoints: httpsDowngrade: handler: tls_interceptor listenAddress: 0.0.0.0 - port: 443 + ports: + - 443 + - 8443 options: <<: *tlsOptions target: diff --git a/pkg/api/protocol_handler.go b/pkg/api/protocol_handler.go index 5b3607a..27c99b3 100644 --- a/pkg/api/protocol_handler.go +++ b/pkg/api/protocol_handler.go @@ -1,9 +1,9 @@ +//go:generate mockgen -source=protocol_handler.go -destination=./../../mock/protocol_handler_mock.go -package=mock package api import ( "github.com/baez90/inetmock/internal/config" "go.uber.org/zap" - "sync" ) type PluginInstanceFactory func() ProtocolHandler @@ -11,6 +11,6 @@ type PluginInstanceFactory func() ProtocolHandler type LoggingFactory func() (*zap.Logger, error) type ProtocolHandler interface { - Start(config config.HandlerConfig) - Shutdown(wg *sync.WaitGroup) + Start(config config.HandlerConfig) error + Shutdown() error } diff --git a/pkg/logging/factory.go b/pkg/logging/factory.go index dc671ec..7ba4ea0 100644 --- a/pkg/logging/factory.go +++ b/pkg/logging/factory.go @@ -17,7 +17,9 @@ func ConfigureLogging( ) { loggingConfig.Level = level loggingConfig.Development = developmentLogging - loggingConfig.InitialFields = initialFields + if initialFields != nil { + loggingConfig.InitialFields = initialFields + } } func ParseLevel(levelString string) zap.AtomicLevel { @@ -37,6 +39,10 @@ func ParseLevel(levelString string) zap.AtomicLevel { } } -func CreateLogger() (*zap.Logger, error) { - return loggingConfig.Build() +func CreateLogger() (Logger, error) { + if zapLogger, err := loggingConfig.Build(); err != nil { + return nil, err + } else { + return NewLogger(zapLogger), nil + } } diff --git a/pkg/logging/factory_test.go b/pkg/logging/factory_test.go new file mode 100644 index 0000000..061462d --- /dev/null +++ b/pkg/logging/factory_test.go @@ -0,0 +1,166 @@ +package logging + +import ( + "go.uber.org/zap" + "reflect" + "testing" +) + +func TestParseLevel(t *testing.T) { + type args struct { + levelString string + } + tests := []struct { + name string + args args + want zap.AtomicLevel + }{ + { + name: "Test parse DEBUG level", + args: args{ + levelString: "DEBUG", + }, + want: zap.NewAtomicLevelAt(zap.DebugLevel), + }, + { + name: "Test parse DeBuG level", + args: args{ + levelString: "DeBuG", + }, + want: zap.NewAtomicLevelAt(zap.DebugLevel), + }, + { + name: "Test parse INFO level", + args: args{ + levelString: "INFO", + }, + want: zap.NewAtomicLevelAt(zap.InfoLevel), + }, + { + name: "Test parse InFo level", + args: args{ + levelString: "InFo", + }, + want: zap.NewAtomicLevelAt(zap.InfoLevel), + }, + { + name: "Test parse WARN level", + args: args{ + levelString: "WARN", + }, + want: zap.NewAtomicLevelAt(zap.WarnLevel), + }, + { + name: "Test parse WaRn level", + args: args{ + levelString: "WaRn", + }, + want: zap.NewAtomicLevelAt(zap.WarnLevel), + }, + { + name: "Test parse ERROR level", + args: args{ + levelString: "ERROR", + }, + want: zap.NewAtomicLevelAt(zap.ErrorLevel), + }, + { + name: "Test parse ErRoR level", + args: args{ + levelString: "ErRoR", + }, + want: zap.NewAtomicLevelAt(zap.ErrorLevel), + }, + { + name: "Test parse FATAL level", + args: args{ + levelString: "FATAL", + }, + want: zap.NewAtomicLevelAt(zap.FatalLevel), + }, + { + name: "Test parse FaTaL level", + args: args{ + levelString: "FaTaL", + }, + want: zap.NewAtomicLevelAt(zap.FatalLevel), + }, + { + name: "Fallback to INFO level if unknown level", + args: args{ + levelString: "asdf23423", + }, + want: zap.NewAtomicLevelAt(zap.InfoLevel), + }, + { + name: "Fallback to INFO level if no level", + args: args{ + levelString: "", + }, + want: zap.NewAtomicLevelAt(zap.InfoLevel), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := ParseLevel(tt.args.levelString); !reflect.DeepEqual(got, tt.want) { + t.Errorf("ParseLevel() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestConfigureLogging(t *testing.T) { + type args struct { + level zap.AtomicLevel + developmentLogging bool + initialFields map[string]interface{} + } + tests := []struct { + name string + args args + }{ + { + name: "Test configure defaults", + args: args{}, + }, + { + name: "Test configure with initialFields", + args: args{ + initialFields: map[string]interface{}{ + "asdf": "hello, World", + }, + }, + }, + { + name: "Test configure development logging enabled", + args: args{ + developmentLogging: true, + }, + }, + { + name: "Test configure log level", + args: args{ + level: zap.NewAtomicLevelAt(zap.FatalLevel), + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ConfigureLogging(tt.args.level, tt.args.developmentLogging, tt.args.initialFields) + if loggingConfig.Development != tt.args.developmentLogging { + t.Errorf("loggingConfig.Development = %t, want %t", loggingConfig.Development, tt.args.developmentLogging) + return + } + + if loggingConfig.Level != tt.args.level { + t.Errorf("loggingConfig.Level = %v, want %v", loggingConfig.Level, tt.args.level) + return + } + + if tt.args.initialFields != nil && !reflect.DeepEqual(loggingConfig.InitialFields, tt.args.initialFields) { + t.Errorf("loggingConfig.InitialFields = %v, want %v", loggingConfig.InitialFields, tt.args.initialFields) + return + } + }) + } +} diff --git a/pkg/logging/logger.go b/pkg/logging/logger.go new file mode 100644 index 0000000..2cce141 --- /dev/null +++ b/pkg/logging/logger.go @@ -0,0 +1,60 @@ +//go:generate mockgen -source=logger.go -destination=./../../mock/logger_mock.go -package=mock +package logging + +import "go.uber.org/zap" + +type Logger interface { + Named(s string) Logger + With(fields ...zap.Field) Logger + Debug(msg string, fields ...zap.Field) + Info(msg string, fields ...zap.Field) + Warn(msg string, fields ...zap.Field) + Error(msg string, fields ...zap.Field) + Panic(msg string, fields ...zap.Field) + Fatal(msg string, fields ...zap.Field) + Sync() error +} + +type logger struct { + underlyingLogger *zap.Logger +} + +func NewLogger(underlyingLogger *zap.Logger) *logger { + return &logger{underlyingLogger: underlyingLogger} +} + +func (l logger) Named(s string) Logger { + return NewLogger(l.underlyingLogger.Named(s)) +} + +func (l logger) With(fields ...zap.Field) Logger { + return NewLogger(l.underlyingLogger.With(fields...)) +} + +func (l logger) Debug(msg string, fields ...zap.Field) { + l.underlyingLogger.Debug(msg, fields...) +} + +func (l logger) Info(msg string, fields ...zap.Field) { + l.underlyingLogger.Info(msg, fields...) +} + +func (l logger) Warn(msg string, fields ...zap.Field) { + l.underlyingLogger.Warn(msg, fields...) +} + +func (l logger) Error(msg string, fields ...zap.Field) { + l.underlyingLogger.Error(msg, fields...) +} + +func (l logger) Panic(msg string, fields ...zap.Field) { + l.underlyingLogger.Panic(msg, fields...) +} + +func (l logger) Fatal(msg string, fields ...zap.Field) { + l.underlyingLogger.Fatal(msg, fields...) +} + +func (l logger) Sync() error { + return l.underlyingLogger.Sync() +} diff --git a/plugins/dns_mock/go.sum b/plugins/dns_mock/go.sum index 3a5b570..689d313 100644 --- a/plugins/dns_mock/go.sum +++ b/plugins/dns_mock/go.sum @@ -34,6 +34,7 @@ github.com/gogo/protobuf v1.2.1/go.mod h1:hp+jE20tsWTFYpLwKvXlhS1hjn+gTNwPg2I6zV github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= github.com/golang/groupcache v0.0.0-20190129154638-5b532d6fd5ef/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= +github.com/golang/mock v1.4.3/go.mod h1:UOMv5ysSaYNkG+OFQykRIcU/QvvxJf3p21QfJ2Bt3cw= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/google/btree v1.0.0/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= @@ -187,6 +188,7 @@ golang.org/x/sys v0.0.0-20191005200804-aed5e4c7ecf9/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200409092240-59c9f1ba88fa h1:mQTN3ECqfsViCNBgq+A40vdwhkGykrrQlYe3mPj6BoU= golang.org/x/sys v0.0.0-20200409092240-59c9f1ba88fa/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.2 h1:tW2bmiBqwgJj/UpqtC8EpXEZVYOwU0yG4iWbprSVAcs= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= @@ -196,6 +198,7 @@ golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGm golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= golang.org/x/tools v0.0.0-20190328211700-ab21143f2384/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= +golang.org/x/tools v0.0.0-20190425150028-36563e24a262/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= golang.org/x/tools v0.0.0-20190621195816-6e04913cbbac/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc= golang.org/x/tools v0.0.0-20191029041327-9cc4af7d6b2c/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20191029190741-b9c20aec41a5/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= @@ -225,3 +228,5 @@ gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.1-2019.2.3 h1:3JgtbtFHMiCmsznwGVTUWbgGov+pVqnlf1dEJTNAXeM= honnef.co/go/tools v0.0.1-2019.2.3/go.mod h1:a3bituU0lyd329TUQxRnasdCoJDkEUEAqEt0JzvZhAg= +rsc.io/quote/v3 v3.1.0/go.mod h1:yEA65RcK8LyAZtP9Kv3t0HmxON59tX3rD+tICJqUlj0= +rsc.io/sampler v1.3.0/go.mod h1:T1hPZKmBbMNahiBKFy5HrXp6adAjACjK9JXDnKaTXpA= diff --git a/plugins/dns_mock/main.go b/plugins/dns_mock/main.go index 27f30b6..6b1ff1c 100644 --- a/plugins/dns_mock/main.go +++ b/plugins/dns_mock/main.go @@ -3,17 +3,17 @@ package main import ( "fmt" "github.com/baez90/inetmock/internal/config" + "github.com/baez90/inetmock/pkg/logging" "github.com/miekg/dns" "go.uber.org/zap" - "sync" ) type dnsHandler struct { - logger *zap.Logger + logger logging.Logger dnsServer []*dns.Server } -func (d *dnsHandler) Start(config config.HandlerConfig) { +func (d *dnsHandler) Start(config config.HandlerConfig) (err error) { options := loadFromConfig(config.Options()) addr := fmt.Sprintf("%s:%d", config.ListenAddress(), config.Port()) @@ -51,6 +51,7 @@ func (d *dnsHandler) Start(config config.HandlerConfig) { for _, dnsServer := range d.dnsServer { go d.startServer(dnsServer) } + return } func (d *dnsHandler) startServer(dnsServer *dns.Server) { @@ -62,7 +63,7 @@ func (d *dnsHandler) startServer(dnsServer *dns.Server) { } } -func (d *dnsHandler) Shutdown(wg *sync.WaitGroup) { +func (d *dnsHandler) Shutdown() error { d.logger.Info("shutting down DNS mock") for _, dnsServer := range d.dnsServer { if err := dnsServer.Shutdown(); err != nil { @@ -72,5 +73,5 @@ func (d *dnsHandler) Shutdown(wg *sync.WaitGroup) { ) } } - wg.Done() + return nil } diff --git a/plugins/dns_mock/regex_handler.go b/plugins/dns_mock/regex_handler.go index fb3cc0e..d245bcb 100644 --- a/plugins/dns_mock/regex_handler.go +++ b/plugins/dns_mock/regex_handler.go @@ -1,6 +1,7 @@ package main import ( + "github.com/baez90/inetmock/pkg/logging" "github.com/miekg/dns" "go.uber.org/zap" ) @@ -8,7 +9,7 @@ import ( type regexHandler struct { routes []resolverRule fallback ResolverFallback - logger *zap.Logger + logger logging.Logger } func (r2 *regexHandler) AddRule(rule resolverRule) { diff --git a/plugins/http_mock/go.sum b/plugins/http_mock/go.sum index 34c4685..0495479 100644 --- a/plugins/http_mock/go.sum +++ b/plugins/http_mock/go.sum @@ -34,6 +34,7 @@ github.com/gogo/protobuf v1.2.1/go.mod h1:hp+jE20tsWTFYpLwKvXlhS1hjn+gTNwPg2I6zV github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= github.com/golang/groupcache v0.0.0-20190129154638-5b532d6fd5ef/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= +github.com/golang/mock v1.4.3/go.mod h1:UOMv5ysSaYNkG+OFQykRIcU/QvvxJf3p21QfJ2Bt3cw= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/google/btree v1.0.0/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= @@ -177,6 +178,7 @@ golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20191005200804-aed5e4c7ecf9/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200409092240-59c9f1ba88fa h1:mQTN3ECqfsViCNBgq+A40vdwhkGykrrQlYe3mPj6BoU= golang.org/x/sys v0.0.0-20200409092240-59c9f1ba88fa/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.2 h1:tW2bmiBqwgJj/UpqtC8EpXEZVYOwU0yG4iWbprSVAcs= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= @@ -186,6 +188,7 @@ golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGm golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= golang.org/x/tools v0.0.0-20190328211700-ab21143f2384/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= +golang.org/x/tools v0.0.0-20190425150028-36563e24a262/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= golang.org/x/tools v0.0.0-20190621195816-6e04913cbbac/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc= golang.org/x/tools v0.0.0-20191029041327-9cc4af7d6b2c/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20191029190741-b9c20aec41a5/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= @@ -215,3 +218,5 @@ gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.1-2019.2.3 h1:3JgtbtFHMiCmsznwGVTUWbgGov+pVqnlf1dEJTNAXeM= honnef.co/go/tools v0.0.1-2019.2.3/go.mod h1:a3bituU0lyd329TUQxRnasdCoJDkEUEAqEt0JzvZhAg= +rsc.io/quote/v3 v3.1.0/go.mod h1:yEA65RcK8LyAZtP9Kv3t0HmxON59tX3rD+tICJqUlj0= +rsc.io/sampler v1.3.0/go.mod h1:T1hPZKmBbMNahiBKFy5HrXp6adAjACjK9JXDnKaTXpA= diff --git a/plugins/http_mock/http_handler.go b/plugins/http_mock/http_handler.go index 987c4d7..57a4d67 100644 --- a/plugins/http_mock/http_handler.go +++ b/plugins/http_mock/http_handler.go @@ -2,6 +2,7 @@ package main import ( "bytes" + "github.com/baez90/inetmock/pkg/logging" "go.uber.org/zap" "net/http" ) @@ -16,7 +17,7 @@ func (p *httpHandler) setupRoute(rule targetRule) { p.router.Handler(rule.Pattern(), createHandlerForTarget(p.logger, rule.response)) } -func createHandlerForTarget(logger *zap.Logger, targetPath string) http.Handler { +func createHandlerForTarget(logger logging.Logger, targetPath string) http.Handler { return http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) { headerWriter := &bytes.Buffer{} request.Header.Write(headerWriter) diff --git a/plugins/http_mock/main.go b/plugins/http_mock/main.go index 6f586a2..27a1362 100644 --- a/plugins/http_mock/main.go +++ b/plugins/http_mock/main.go @@ -3,9 +3,9 @@ package main import ( "fmt" "github.com/baez90/inetmock/internal/config" + "github.com/baez90/inetmock/pkg/logging" "go.uber.org/zap" "net/http" - "sync" ) const ( @@ -13,12 +13,12 @@ const ( ) type httpHandler struct { - logger *zap.Logger + logger logging.Logger router *RegexpHandler server *http.Server } -func (p *httpHandler) Start(config config.HandlerConfig) { +func (p *httpHandler) Start(config config.HandlerConfig) (err error) { options := loadFromConfig(config.Options()) addr := fmt.Sprintf("%s:%d", config.ListenAddress(), config.Port()) p.server = &http.Server{Addr: addr, Handler: p.router} @@ -31,18 +31,22 @@ func (p *httpHandler) Start(config config.HandlerConfig) { } go p.startServer() + return } -func (p *httpHandler) Shutdown(wg *sync.WaitGroup) { +func (p *httpHandler) Shutdown() (err error) { p.logger.Info("Shutting down HTTP mock") - if err := p.server.Close(); err != nil { + if err = p.server.Close(); err != nil { p.logger.Error( "failed to shutdown HTTP server", zap.Error(err), ) + err = fmt.Errorf( + "failed to shutdown HTTP server: %w", + err, + ) } - - wg.Done() + return } func (p *httpHandler) startServer() { diff --git a/plugins/http_proxy/go.sum b/plugins/http_proxy/go.sum index 6397fa2..0703ef9 100644 --- a/plugins/http_proxy/go.sum +++ b/plugins/http_proxy/go.sum @@ -38,6 +38,7 @@ github.com/gogo/protobuf v1.2.1/go.mod h1:hp+jE20tsWTFYpLwKvXlhS1hjn+gTNwPg2I6zV github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= github.com/golang/groupcache v0.0.0-20190129154638-5b532d6fd5ef/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= +github.com/golang/mock v1.4.3/go.mod h1:UOMv5ysSaYNkG+OFQykRIcU/QvvxJf3p21QfJ2Bt3cw= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/google/btree v1.0.0/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= @@ -182,6 +183,7 @@ golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20191005200804-aed5e4c7ecf9/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200409092240-59c9f1ba88fa h1:mQTN3ECqfsViCNBgq+A40vdwhkGykrrQlYe3mPj6BoU= golang.org/x/sys v0.0.0-20200409092240-59c9f1ba88fa/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.2 h1:tW2bmiBqwgJj/UpqtC8EpXEZVYOwU0yG4iWbprSVAcs= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= @@ -191,6 +193,7 @@ golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGm golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= golang.org/x/tools v0.0.0-20190328211700-ab21143f2384/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= +golang.org/x/tools v0.0.0-20190425150028-36563e24a262/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= golang.org/x/tools v0.0.0-20190621195816-6e04913cbbac/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc= golang.org/x/tools v0.0.0-20191029041327-9cc4af7d6b2c/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20191029190741-b9c20aec41a5/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= @@ -222,3 +225,5 @@ gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.1-2019.2.3 h1:3JgtbtFHMiCmsznwGVTUWbgGov+pVqnlf1dEJTNAXeM= honnef.co/go/tools v0.0.1-2019.2.3/go.mod h1:a3bituU0lyd329TUQxRnasdCoJDkEUEAqEt0JzvZhAg= +rsc.io/quote/v3 v3.1.0/go.mod h1:yEA65RcK8LyAZtP9Kv3t0HmxON59tX3rD+tICJqUlj0= +rsc.io/sampler v1.3.0/go.mod h1:T1hPZKmBbMNahiBKFy5HrXp6adAjACjK9JXDnKaTXpA= diff --git a/plugins/http_proxy/main.go b/plugins/http_proxy/main.go index 08b5496..0176053 100644 --- a/plugins/http_proxy/main.go +++ b/plugins/http_proxy/main.go @@ -3,10 +3,10 @@ package main import ( "fmt" "github.com/baez90/inetmock/internal/config" + "github.com/baez90/inetmock/pkg/logging" "go.uber.org/zap" "gopkg.in/elazarl/goproxy.v1" "net/http" - "sync" ) const ( @@ -14,12 +14,12 @@ const ( ) type httpProxy struct { - logger *zap.Logger + logger logging.Logger proxy *goproxy.ProxyHttpServer server *http.Server } -func (h *httpProxy) Start(config config.HandlerConfig) { +func (h *httpProxy) Start(config config.HandlerConfig) (err error) { options := loadFromConfig(config.Options()) addr := fmt.Sprintf("%s:%d", config.ListenAddress(), config.Port()) h.server = &http.Server{Addr: addr, Handler: h.proxy} @@ -33,6 +33,7 @@ func (h *httpProxy) Start(config config.HandlerConfig) { } h.proxy.OnRequest().Do(proxyHandler) go h.startProxy() + return } func (h *httpProxy) startProxy() { @@ -44,13 +45,18 @@ func (h *httpProxy) startProxy() { } } -func (h *httpProxy) Shutdown(wg *sync.WaitGroup) { - defer wg.Done() +func (h *httpProxy) Shutdown() (err error) { h.logger.Info("Shutting down HTTP proxy") - if err := h.server.Close(); err != nil { + if err = h.server.Close(); err != nil { h.logger.Error( "failed to shutdown proxy endpoint", zap.Error(err), ) + + err = fmt.Errorf( + "failed to shutdown proxy endpoint: %w", + err, + ) } + return } diff --git a/plugins/http_proxy/proxy_handler.go b/plugins/http_proxy/proxy_handler.go index b0a3798..55a6e8c 100644 --- a/plugins/http_proxy/proxy_handler.go +++ b/plugins/http_proxy/proxy_handler.go @@ -1,6 +1,7 @@ package main import ( + "github.com/baez90/inetmock/pkg/logging" "go.uber.org/zap" "gopkg.in/elazarl/goproxy.v1" "io" @@ -12,7 +13,7 @@ import ( type proxyHttpHandler struct { options httpProxyOptions - logger *zap.Logger + logger logging.Logger } /* diff --git a/plugins/tls_interceptor/cert_store.go b/plugins/tls_interceptor/cert_store.go index 807674e..229348d 100644 --- a/plugins/tls_interceptor/cert_store.go +++ b/plugins/tls_interceptor/cert_store.go @@ -6,6 +6,7 @@ import ( "crypto/x509" "crypto/x509/pkix" "fmt" + "github.com/baez90/inetmock/pkg/logging" "github.com/baez90/inetmock/pkg/path" "go.uber.org/zap" "math/big" @@ -22,7 +23,7 @@ type certStore struct { caPrivateKey interface{} certCache map[string]*tls.Certificate timeSourceInstance timeSource - logger *zap.Logger + logger logging.Logger } func (cs *certStore) timeSource() timeSource { diff --git a/plugins/tls_interceptor/cert_store_test.go b/plugins/tls_interceptor/cert_store_test.go index a4de478..31ebe6e 100644 --- a/plugins/tls_interceptor/cert_store_test.go +++ b/plugins/tls_interceptor/cert_store_test.go @@ -3,6 +3,7 @@ package main import ( "crypto/tls" "crypto/x509" + "github.com/baez90/inetmock/pkg/logging" "go.uber.org/zap" "io/ioutil" "os" @@ -85,7 +86,8 @@ func Test_generateDomainCert(t *testing.T) { }, } - logger, _ := zap.NewDevelopment() + zapLogger, _ := zap.NewDevelopment() + logger := logging.NewLogger(zapLogger) certStore := certStore{ options: options, diff --git a/plugins/tls_interceptor/generate_ca_cmd.go b/plugins/tls_interceptor/generate_ca_cmd.go index 7ac59ae..fe60313 100644 --- a/plugins/tls_interceptor/generate_ca_cmd.go +++ b/plugins/tls_interceptor/generate_ca_cmd.go @@ -1,6 +1,7 @@ package main import ( + "github.com/baez90/inetmock/pkg/logging" "github.com/spf13/cobra" "go.uber.org/zap" "time" @@ -14,7 +15,7 @@ const ( generateCANotAfterRelative = "not-after" ) -func generateCACmd(logger *zap.Logger) *cobra.Command { +func generateCACmd(logger logging.Logger) *cobra.Command { cmd := &cobra.Command{ Use: "generate-ca", Short: "Generate a new CA certificate and corresponding key", @@ -31,7 +32,7 @@ func generateCACmd(logger *zap.Logger) *cobra.Command { return cmd } -func getDurationFlag(cmd *cobra.Command, flagName string, logger *zap.Logger) (val time.Duration, err error) { +func getDurationFlag(cmd *cobra.Command, flagName string, logger logging.Logger) (val time.Duration, err error) { if val, err = cmd.Flags().GetDuration(flagName); err != nil { logger.Error( "failed to parse parse flag", @@ -42,7 +43,7 @@ func getDurationFlag(cmd *cobra.Command, flagName string, logger *zap.Logger) (v return } -func getStringFlag(cmd *cobra.Command, flagName string, logger *zap.Logger) (val string, err error) { +func getStringFlag(cmd *cobra.Command, flagName string, logger logging.Logger) (val string, err error) { if val, err = cmd.Flags().GetString(flagName); err != nil { logger.Error( "failed to parse parse flag", @@ -53,7 +54,7 @@ func getStringFlag(cmd *cobra.Command, flagName string, logger *zap.Logger) (val return } -func runGenerateCA(logger *zap.Logger) func(cmd *cobra.Command, args []string) { +func runGenerateCA(logger logging.Logger) func(cmd *cobra.Command, args []string) { return func(cmd *cobra.Command, args []string) { var certOutPath, keyOutPath, curveName string var notBefore, notAfter time.Duration diff --git a/plugins/tls_interceptor/go.sum b/plugins/tls_interceptor/go.sum index 34c4685..0495479 100644 --- a/plugins/tls_interceptor/go.sum +++ b/plugins/tls_interceptor/go.sum @@ -34,6 +34,7 @@ github.com/gogo/protobuf v1.2.1/go.mod h1:hp+jE20tsWTFYpLwKvXlhS1hjn+gTNwPg2I6zV github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= github.com/golang/groupcache v0.0.0-20190129154638-5b532d6fd5ef/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= +github.com/golang/mock v1.4.3/go.mod h1:UOMv5ysSaYNkG+OFQykRIcU/QvvxJf3p21QfJ2Bt3cw= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/google/btree v1.0.0/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= @@ -177,6 +178,7 @@ golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20191005200804-aed5e4c7ecf9/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200409092240-59c9f1ba88fa h1:mQTN3ECqfsViCNBgq+A40vdwhkGykrrQlYe3mPj6BoU= golang.org/x/sys v0.0.0-20200409092240-59c9f1ba88fa/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.2 h1:tW2bmiBqwgJj/UpqtC8EpXEZVYOwU0yG4iWbprSVAcs= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= @@ -186,6 +188,7 @@ golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGm golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= golang.org/x/tools v0.0.0-20190328211700-ab21143f2384/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= +golang.org/x/tools v0.0.0-20190425150028-36563e24a262/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= golang.org/x/tools v0.0.0-20190621195816-6e04913cbbac/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc= golang.org/x/tools v0.0.0-20191029041327-9cc4af7d6b2c/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20191029190741-b9c20aec41a5/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= @@ -215,3 +218,5 @@ gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.1-2019.2.3 h1:3JgtbtFHMiCmsznwGVTUWbgGov+pVqnlf1dEJTNAXeM= honnef.co/go/tools v0.0.1-2019.2.3/go.mod h1:a3bituU0lyd329TUQxRnasdCoJDkEUEAqEt0JzvZhAg= +rsc.io/quote/v3 v3.1.0/go.mod h1:yEA65RcK8LyAZtP9Kv3t0HmxON59tX3rD+tICJqUlj0= +rsc.io/sampler v1.3.0/go.mod h1:T1hPZKmBbMNahiBKFy5HrXp6adAjACjK9JXDnKaTXpA= diff --git a/plugins/tls_interceptor/main.go b/plugins/tls_interceptor/main.go index 24af7d3..550d4f0 100644 --- a/plugins/tls_interceptor/main.go +++ b/plugins/tls_interceptor/main.go @@ -5,6 +5,7 @@ import ( "crypto/x509" "fmt" "github.com/baez90/inetmock/internal/config" + "github.com/baez90/inetmock/pkg/logging" "go.uber.org/zap" "net" "sync" @@ -16,7 +17,7 @@ const ( ) type tlsInterceptor struct { - logger *zap.Logger + logger logging.Logger listener net.Listener certStore *certStore options *tlsOptions @@ -25,8 +26,7 @@ type tlsInterceptor struct { currentConnections []*proxyConn } -func (t *tlsInterceptor) Start(config config.HandlerConfig) { - var err error +func (t *tlsInterceptor) Start(config config.HandlerConfig) (err error) { t.options = loadFromConfig(config.Options()) addr := fmt.Sprintf("%s:%d", config.ListenAddress(), config.Port()) @@ -46,6 +46,11 @@ func (t *tlsInterceptor) Start(config config.HandlerConfig) { "failed to initialize CA cert", zap.Error(err), ) + err = fmt.Errorf( + "failed to initialize CA cert: %w", + err, + ) + return } rootCaPool := x509.NewCertPool() @@ -61,13 +66,18 @@ func (t *tlsInterceptor) Start(config config.HandlerConfig) { "failed to create tls listener", zap.Error(err), ) + err = fmt.Errorf( + "failed to create tls listener: %w", + err, + ) return } go t.startListener() + return } -func (t *tlsInterceptor) Shutdown(wg *sync.WaitGroup) { +func (t *tlsInterceptor) Shutdown() (err error) { t.logger.Info("Shutting down TLS interceptor") t.shutdownRequested = true done := make(chan struct{}) @@ -78,17 +88,21 @@ func (t *tlsInterceptor) Shutdown(wg *sync.WaitGroup) { select { case <-done: - wg.Done() + return case <-time.After(5 * time.Second): for _, proxyConn := range t.currentConnections { - if err := proxyConn.Close(); err != nil { + if err = proxyConn.Close(); err != nil { t.logger.Error( "error while closing remaining proxy connections", zap.Error(err), ) + err = fmt.Errorf( + "error while closing remaining proxy connections: %w", + err, + ) } } - wg.Done() + return } }