Improve config and startup handling
- use `Unmarshal` method of viper - move config loading, defaulting and stuff to config package - move serve to an extra command and move keep only global flags in root command - add startup logic to onInit method in root command - update mocks - clean config structs - adopt changes in plugins - update default config
This commit is contained in:
parent
1775d3dbc6
commit
91f0cf6963
40 changed files with 561 additions and 975 deletions
8
Makefile
8
Makefile
|
@ -2,7 +2,8 @@ VERSION = $(shell git describe --dirty --tags --always)
|
||||||
DIR = $(dir $(realpath $(firstword $(MAKEFILE_LIST))))
|
DIR = $(dir $(realpath $(firstword $(MAKEFILE_LIST))))
|
||||||
BUILD_PATH = $(DIR)/main.go
|
BUILD_PATH = $(DIR)/main.go
|
||||||
PKGS = $(shell go list ./...)
|
PKGS = $(shell go list ./...)
|
||||||
TEST_PKGS = $(shell find . -type f -name "*_test.go" -not -path "./plugins/*" -not -path "*/mock/*" -printf '%h\n' | sort -u)
|
TEST_PKGS = $(shell go list ./...)
|
||||||
|
PROTO_FILES = $(shell find $(DIR)api/ -type f -name "*.proto")
|
||||||
GOARGS = GOOS=linux GOARCH=amd64
|
GOARGS = GOOS=linux GOARCH=amd64
|
||||||
GO_BUILD_ARGS = -ldflags="-w -s"
|
GO_BUILD_ARGS = -ldflags="-w -s"
|
||||||
GO_CONTAINER_BUILD_ARGS = -ldflags="-w -s" -a -installsuffix cgo
|
GO_CONTAINER_BUILD_ARGS = -ldflags="-w -s" -a -installsuffix cgo
|
||||||
|
@ -14,7 +15,7 @@ DEBUG_ARGS?= --development-logs=true
|
||||||
CONTAINER_BUILDER ?= podman
|
CONTAINER_BUILDER ?= podman
|
||||||
DOCKER_IMAGE ?= inetmock
|
DOCKER_IMAGE ?= inetmock
|
||||||
|
|
||||||
.PHONY: clean all format deps update-deps compile debug generate snapshot-release test cli-cover-report html-cover-report plugins $(PLUGINS) $(GO_GEN_FILES)
|
.PHONY: clean all format deps update-deps compile debug generate protoc snapshot-release test cli-cover-report html-cover-report plugins $(PLUGINS) $(GO_GEN_FILES)
|
||||||
all: clean format compile test plugins
|
all: clean format compile test plugins
|
||||||
|
|
||||||
clean:
|
clean:
|
||||||
|
@ -55,6 +56,9 @@ debug:
|
||||||
generate:
|
generate:
|
||||||
@go generate ./...
|
@go generate ./...
|
||||||
|
|
||||||
|
protoc:
|
||||||
|
@protoc --proto_path $(DIR)api/ --go_out=plugins=grpc:internal/grpc --go_opt=paths=source_relative $(shell find $(DIR)api/ -type f -name "*.proto")
|
||||||
|
|
||||||
snapshot-release:
|
snapshot-release:
|
||||||
@goreleaser release --snapshot --skip-publish --rm-dist
|
@goreleaser release --snapshot --skip-publish --rm-dist
|
||||||
|
|
||||||
|
|
|
@ -4,6 +4,7 @@ import (
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
"github.com/baez90/inetmock/pkg/cert"
|
"github.com/baez90/inetmock/pkg/cert"
|
||||||
|
config2 "github.com/baez90/inetmock/pkg/config"
|
||||||
"github.com/baez90/inetmock/pkg/logging"
|
"github.com/baez90/inetmock/pkg/logging"
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
"go.uber.org/zap"
|
"go.uber.org/zap"
|
||||||
|
@ -77,11 +78,11 @@ func runGenerateCA(cmd *cobra.Command, args []string) {
|
||||||
zap.String(generateCACertOutPath, certOutPath),
|
zap.String(generateCACertOutPath, certOutPath),
|
||||||
)
|
)
|
||||||
|
|
||||||
generator := cert.NewDefaultGenerator(cert.Options{
|
generator := cert.NewDefaultGenerator(config2.CertOptions{
|
||||||
CertCachePath: certOutPath,
|
CertCachePath: certOutPath,
|
||||||
Curve: cert.CurveType(curveName),
|
Curve: config2.CurveType(curveName),
|
||||||
Validity: cert.ValidityByPurpose{
|
Validity: config2.ValidityByPurpose{
|
||||||
CA: cert.ValidityDuration{
|
CA: config2.ValidityDuration{
|
||||||
NotAfterRelative: notAfter,
|
NotAfterRelative: notAfter,
|
||||||
NotBeforeRelative: notBefore,
|
NotBeforeRelative: notBefore,
|
||||||
},
|
},
|
||||||
|
|
|
@ -1,70 +0,0 @@
|
||||||
package cmd
|
|
||||||
|
|
||||||
import (
|
|
||||||
"github.com/baez90/inetmock/internal/endpoints"
|
|
||||||
"github.com/baez90/inetmock/internal/plugins"
|
|
||||||
"github.com/baez90/inetmock/pkg/api"
|
|
||||||
"github.com/baez90/inetmock/pkg/logging"
|
|
||||||
"github.com/baez90/inetmock/pkg/path"
|
|
||||||
"github.com/spf13/viper"
|
|
||||||
"go.uber.org/zap"
|
|
||||||
"os"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
var (
|
|
||||||
appIsInitialized = false
|
|
||||||
)
|
|
||||||
|
|
||||||
func initApp() (err error) {
|
|
||||||
if appIsInitialized {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
appIsInitialized = true
|
|
||||||
logging.ConfigureLogging(
|
|
||||||
logging.ParseLevel(logLevel),
|
|
||||||
developmentLogs,
|
|
||||||
map[string]interface{}{"cwd": path.WorkingDirectory()},
|
|
||||||
)
|
|
||||||
logger, _ = logging.CreateLogger()
|
|
||||||
registry := plugins.Registry()
|
|
||||||
endpointManager = endpoints.NewEndpointManager(logger)
|
|
||||||
|
|
||||||
rootCmd.Flags().ParseErrorsWhitelist.UnknownFlags = false
|
|
||||||
_ = rootCmd.ParseFlags(os.Args)
|
|
||||||
|
|
||||||
if err = appConfig.ReadConfig(configFilePath); err != nil {
|
|
||||||
logger.Error(
|
|
||||||
"unrecoverable error occurred during reading the config file",
|
|
||||||
zap.Error(err),
|
|
||||||
)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
viperInst := viper.GetViper()
|
|
||||||
pluginDir := viperInst.GetString("plugins-directory")
|
|
||||||
|
|
||||||
if err = api.InitServices(viperInst, logger); err != nil {
|
|
||||||
logger.Error(
|
|
||||||
"failed to initialize app services",
|
|
||||||
zap.Error(err),
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
pluginLoadStartTime := time.Now()
|
|
||||||
if err = registry.LoadPlugins(pluginDir); err != nil {
|
|
||||||
logger.Error("Failed to load plugins",
|
|
||||||
zap.String("pluginsDirectory", pluginDir),
|
|
||||||
zap.Error(err),
|
|
||||||
)
|
|
||||||
}
|
|
||||||
pluginLoadDuration := time.Since(pluginLoadStartTime)
|
|
||||||
logger.Info(
|
|
||||||
"loading plugins completed",
|
|
||||||
zap.Duration("pluginLoadDuration", pluginLoadDuration),
|
|
||||||
)
|
|
||||||
|
|
||||||
pluginsCmd.AddCommand(registry.PluginCommands()...)
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
|
@ -5,10 +5,6 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
pluginsCmd *cobra.Command
|
|
||||||
)
|
|
||||||
|
|
||||||
func init() {
|
|
||||||
pluginsCmd = &cobra.Command{
|
pluginsCmd = &cobra.Command{
|
||||||
Use: "plugins",
|
Use: "plugins",
|
||||||
Short: "Use the plugins prefix to interact with commands that are provided by plugins",
|
Short: "Use the plugins prefix to interact with commands that are provided by plugins",
|
||||||
|
@ -18,5 +14,4 @@ The easiest way to explore what commands are available is to start with 'inetmoc
|
||||||
This help page contains a list of available sub-commands starting with the name of the plugin as a prefix.
|
This help page contains a list of available sub-commands starting with the name of the plugin as a prefix.
|
||||||
`,
|
`,
|
||||||
}
|
}
|
||||||
rootCmd.AddCommand(pluginsCmd, generateCaCmd)
|
)
|
||||||
}
|
|
||||||
|
|
|
@ -1,77 +1,91 @@
|
||||||
package cmd
|
package cmd
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/baez90/inetmock/internal/config"
|
"github.com/baez90/inetmock/internal/plugins"
|
||||||
"github.com/baez90/inetmock/internal/endpoints"
|
|
||||||
"github.com/baez90/inetmock/pkg/api"
|
"github.com/baez90/inetmock/pkg/api"
|
||||||
|
"github.com/baez90/inetmock/pkg/config"
|
||||||
"github.com/baez90/inetmock/pkg/logging"
|
"github.com/baez90/inetmock/pkg/logging"
|
||||||
|
"github.com/baez90/inetmock/pkg/path"
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
"github.com/spf13/viper"
|
|
||||||
"go.uber.org/zap"
|
"go.uber.org/zap"
|
||||||
"os"
|
"time"
|
||||||
"os/signal"
|
|
||||||
"strings"
|
|
||||||
"syscall"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
logger logging.Logger
|
logger logging.Logger
|
||||||
rootCmd = cobra.Command{
|
rootCmd *cobra.Command
|
||||||
Use: "",
|
|
||||||
Short: "INetMock is lightweight internet mock",
|
|
||||||
Run: startInetMock,
|
|
||||||
}
|
|
||||||
|
|
||||||
configFilePath string
|
pluginsDirectory string
|
||||||
logLevel string
|
configFilePath string
|
||||||
developmentLogs bool
|
logLevel string
|
||||||
handlers []api.ProtocolHandler
|
developmentLogs bool
|
||||||
appConfig = config.CreateConfig()
|
|
||||||
endpointManager endpoints.EndpointManager
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
rootCmd.PersistentFlags().String("plugins-directory", "", "Directory where plugins should be loaded from")
|
cobra.OnInitialize(onInit)
|
||||||
|
rootCmd = &cobra.Command{
|
||||||
|
Use: "",
|
||||||
|
Short: "INetMock is lightweight internet mock",
|
||||||
|
}
|
||||||
|
|
||||||
|
rootCmd.PersistentFlags().StringVar(&pluginsDirectory, "plugins-directory", "", "Directory where plugins should be loaded from")
|
||||||
rootCmd.PersistentFlags().StringVar(&configFilePath, "config", "", "Path to config file that should be used")
|
rootCmd.PersistentFlags().StringVar(&configFilePath, "config", "", "Path to config file that should be used")
|
||||||
rootCmd.PersistentFlags().StringVar(&logLevel, "log-level", "info", "logging level to use")
|
rootCmd.PersistentFlags().StringVar(&logLevel, "log-level", "info", "logging level to use")
|
||||||
rootCmd.PersistentFlags().BoolVar(&developmentLogs, "development-logs", false, "Enable development mode logs")
|
rootCmd.PersistentFlags().BoolVar(&developmentLogs, "development-logs", false, "Enable development mode logs")
|
||||||
|
|
||||||
appConfig.InitConfig(rootCmd.PersistentFlags())
|
rootCmd.AddCommand(
|
||||||
|
serveCmd,
|
||||||
|
generateCaCmd,
|
||||||
|
pluginsCmd,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
func startInetMock(cmd *cobra.Command, args []string) {
|
func onInit() {
|
||||||
for endpointName := range viper.GetStringMap(config.EndpointsKey) {
|
logging.ConfigureLogging(
|
||||||
handlerSubConfig := viper.Sub(strings.Join([]string{config.EndpointsKey, endpointName}, "."))
|
logging.ParseLevel(logLevel),
|
||||||
handlerConfig := config.CreateMultiHandlerConfig(handlerSubConfig)
|
developmentLogs,
|
||||||
if err := endpointManager.CreateEndpoint(endpointName, handlerConfig); err != nil {
|
map[string]interface{}{"cwd": path.WorkingDirectory()},
|
||||||
logger.Warn(
|
|
||||||
"error occurred while creating endpoint",
|
|
||||||
zap.String("endpointName", endpointName),
|
|
||||||
zap.String("handlerName", handlerConfig.HandlerName()),
|
|
||||||
zap.Error(err),
|
|
||||||
)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
endpointManager.StartEndpoints()
|
|
||||||
|
|
||||||
signalChannel := make(chan os.Signal, 1)
|
|
||||||
signal.Notify(signalChannel, syscall.SIGINT, syscall.SIGTERM, syscall.SIGQUIT)
|
|
||||||
|
|
||||||
// block until canceled
|
|
||||||
s := <-signalChannel
|
|
||||||
|
|
||||||
logger.Info(
|
|
||||||
"got signal to quit",
|
|
||||||
zap.String("signal", s.String()),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
endpointManager.ShutdownEndpoints()
|
logger, _ = logging.CreateLogger()
|
||||||
|
config.CreateConfig(rootCmd.Flags())
|
||||||
|
appConfig := config.Instance()
|
||||||
|
|
||||||
|
if err := appConfig.ReadConfig(configFilePath); err != nil {
|
||||||
|
logger.Error(
|
||||||
|
"failed to read config file",
|
||||||
|
zap.Error(err),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := api.InitServices(appConfig, logger); err != nil {
|
||||||
|
logger.Error(
|
||||||
|
"failed to initialize app services",
|
||||||
|
zap.Error(err),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
registry := plugins.Registry()
|
||||||
|
|
||||||
|
cfg := config.Instance()
|
||||||
|
|
||||||
|
pluginLoadStartTime := time.Now()
|
||||||
|
if err := registry.LoadPlugins(cfg.PluginsDir()); err != nil {
|
||||||
|
logger.Error("Failed to load plugins",
|
||||||
|
zap.String("pluginsDirectory", cfg.PluginsDir()),
|
||||||
|
zap.Error(err),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
pluginLoadDuration := time.Since(pluginLoadStartTime)
|
||||||
|
logger.Info(
|
||||||
|
"loading plugins completed",
|
||||||
|
zap.Duration("pluginLoadDuration", pluginLoadDuration),
|
||||||
|
)
|
||||||
|
|
||||||
|
pluginsCmd.AddCommand(registry.PluginCommands()...)
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func ExecuteRootCommand() error {
|
func ExecuteRootCommand() error {
|
||||||
if err := initApp(); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return rootCmd.Execute()
|
return rootCmd.Execute()
|
||||||
}
|
}
|
||||||
|
|
54
internal/cmd/serve.go
Normal file
54
internal/cmd/serve.go
Normal file
|
@ -0,0 +1,54 @@
|
||||||
|
package cmd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/baez90/inetmock/internal/endpoints"
|
||||||
|
"github.com/baez90/inetmock/pkg/config"
|
||||||
|
"github.com/spf13/cobra"
|
||||||
|
"go.uber.org/zap"
|
||||||
|
"os"
|
||||||
|
"os/signal"
|
||||||
|
"strings"
|
||||||
|
"syscall"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
endpointManager endpoints.EndpointManager
|
||||||
|
serveCmd = &cobra.Command{
|
||||||
|
Use: "serve",
|
||||||
|
Short: "Starts the INetMock server",
|
||||||
|
Long: ``,
|
||||||
|
Run: startINetMock,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
func startINetMock(_ *cobra.Command, _ []string) {
|
||||||
|
endpointManager = endpoints.NewEndpointManager(logger)
|
||||||
|
cfg := config.Instance()
|
||||||
|
for endpointName, endpointHandler := range cfg.EndpointConfigs() {
|
||||||
|
handlerSubConfig := cfg.Viper().Sub(strings.Join([]string{config.EndpointsKey, endpointName, config.OptionsKey}, "."))
|
||||||
|
endpointHandler.Options = handlerSubConfig
|
||||||
|
if err := endpointManager.CreateEndpoint(endpointName, endpointHandler); err != nil {
|
||||||
|
logger.Warn(
|
||||||
|
"error occurred while creating endpoint",
|
||||||
|
zap.String("endpointName", endpointName),
|
||||||
|
zap.String("handlerName", endpointHandler.Handler),
|
||||||
|
zap.Error(err),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
endpointManager.StartEndpoints()
|
||||||
|
|
||||||
|
signalChannel := make(chan os.Signal, 1)
|
||||||
|
signal.Notify(signalChannel, syscall.SIGINT, syscall.SIGTERM, syscall.SIGQUIT)
|
||||||
|
|
||||||
|
// block until canceled
|
||||||
|
s := <-signalChannel
|
||||||
|
|
||||||
|
logger.Info(
|
||||||
|
"got signal to quit",
|
||||||
|
zap.String("signal", s.String()),
|
||||||
|
)
|
||||||
|
|
||||||
|
endpointManager.ShutdownEndpoints()
|
||||||
|
}
|
|
@ -1,6 +0,0 @@
|
||||||
package config
|
|
||||||
|
|
||||||
const (
|
|
||||||
EndpointsKey = "endpoints"
|
|
||||||
OptionsKey = "options"
|
|
||||||
)
|
|
|
@ -1,55 +0,0 @@
|
||||||
package config
|
|
||||||
|
|
||||||
import (
|
|
||||||
"github.com/baez90/inetmock/pkg/logging"
|
|
||||||
"github.com/baez90/inetmock/pkg/path"
|
|
||||||
"github.com/spf13/pflag"
|
|
||||||
"github.com/spf13/viper"
|
|
||||||
"go.uber.org/zap"
|
|
||||||
"strings"
|
|
||||||
)
|
|
||||||
|
|
||||||
func CreateConfig() Config {
|
|
||||||
logger, _ := logging.CreateLogger()
|
|
||||||
return &config{
|
|
||||||
logger: logger.Named("Config"),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
type Config interface {
|
|
||||||
InitConfig(flags *pflag.FlagSet)
|
|
||||||
ReadConfig(configFilePath string) error
|
|
||||||
}
|
|
||||||
|
|
||||||
type config struct {
|
|
||||||
logger logging.Logger
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c config) InitConfig(flags *pflag.FlagSet) {
|
|
||||||
viper.SetConfigName("config")
|
|
||||||
viper.SetConfigType("yaml")
|
|
||||||
viper.AddConfigPath("/etc/inetmock/")
|
|
||||||
viper.AddConfigPath("$HOME/.inetmock")
|
|
||||||
viper.AddConfigPath(".")
|
|
||||||
viper.SetEnvPrefix("INetMock")
|
|
||||||
_ = viper.BindPFlags(flags)
|
|
||||||
viper.SetEnvKeyReplacer(strings.NewReplacer("-", "_"))
|
|
||||||
viper.AutomaticEnv()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *config) ReadConfig(configFilePath string) (err error) {
|
|
||||||
if configFilePath != "" && path.FileExists(configFilePath) {
|
|
||||||
c.logger.Info(
|
|
||||||
"loading config from passed config file path",
|
|
||||||
zap.String("configFilePath", configFilePath),
|
|
||||||
)
|
|
||||||
viper.SetConfigFile(configFilePath)
|
|
||||||
}
|
|
||||||
if err = viper.ReadInConfig(); err != nil {
|
|
||||||
if _, ok := err.(viper.ConfigFileNotFoundError); ok {
|
|
||||||
err = nil
|
|
||||||
c.logger.Warn("failed to load config")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
|
@ -1,49 +0,0 @@
|
||||||
package config
|
|
||||||
|
|
||||||
import (
|
|
||||||
"github.com/baez90/inetmock/pkg/api"
|
|
||||||
"github.com/spf13/viper"
|
|
||||||
)
|
|
||||||
|
|
||||||
type MultiHandlerConfig interface {
|
|
||||||
HandlerName() string
|
|
||||||
ListenAddress() string
|
|
||||||
Ports() []uint16
|
|
||||||
Options() *viper.Viper
|
|
||||||
HandlerConfigs() []api.HandlerConfig
|
|
||||||
}
|
|
||||||
|
|
||||||
type multiHandlerConfig struct {
|
|
||||||
handlerName string
|
|
||||||
ports []uint16
|
|
||||||
listenAddress string
|
|
||||||
options *viper.Viper
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewMultiHandlerConfig(handlerName string, ports []uint16, listenAddress string, options *viper.Viper) MultiHandlerConfig {
|
|
||||||
return &multiHandlerConfig{handlerName: handlerName, ports: ports, listenAddress: listenAddress, options: options}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m multiHandlerConfig) HandlerName() string {
|
|
||||||
return m.handlerName
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m multiHandlerConfig) ListenAddress() string {
|
|
||||||
return m.listenAddress
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m multiHandlerConfig) Ports() []uint16 {
|
|
||||||
return m.ports
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m multiHandlerConfig) Options() *viper.Viper {
|
|
||||||
return m.options
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m multiHandlerConfig) HandlerConfigs() []api.HandlerConfig {
|
|
||||||
configs := make([]api.HandlerConfig, 0)
|
|
||||||
for _, port := range m.ports {
|
|
||||||
configs = append(configs, api.NewHandlerConfig(m.handlerName, port, m.listenAddress, m.options))
|
|
||||||
}
|
|
||||||
return configs
|
|
||||||
}
|
|
|
@ -1,226 +0,0 @@
|
||||||
package config
|
|
||||||
|
|
||||||
import (
|
|
||||||
"github.com/baez90/inetmock/pkg/api"
|
|
||||||
"github.com/spf13/viper"
|
|
||||||
"reflect"
|
|
||||||
"testing"
|
|
||||||
)
|
|
||||||
|
|
||||||
func Test_multiHandlerConfig_HandlerConfigs(t *testing.T) {
|
|
||||||
type fields struct {
|
|
||||||
handlerName string
|
|
||||||
ports []uint16
|
|
||||||
listenAddress string
|
|
||||||
options *viper.Viper
|
|
||||||
}
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
fields fields
|
|
||||||
want []api.HandlerConfig
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "Get empty array if no ports are set",
|
|
||||||
fields: fields{},
|
|
||||||
want: make([]api.HandlerConfig, 0),
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Get a single handler config if only one port is set",
|
|
||||||
fields: fields{
|
|
||||||
handlerName: "sampleHandler",
|
|
||||||
ports: []uint16{80},
|
|
||||||
listenAddress: "0.0.0.0",
|
|
||||||
options: nil,
|
|
||||||
},
|
|
||||||
want: []api.HandlerConfig{
|
|
||||||
api.NewHandlerConfig("sampleHandler", 80, "0.0.0.0", nil),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Get multiple handler configs if only one port is set",
|
|
||||||
fields: fields{
|
|
||||||
handlerName: "sampleHandler",
|
|
||||||
ports: []uint16{80, 8080},
|
|
||||||
listenAddress: "0.0.0.0",
|
|
||||||
options: nil,
|
|
||||||
},
|
|
||||||
want: []api.HandlerConfig{
|
|
||||||
api.NewHandlerConfig("sampleHandler", 80, "0.0.0.0", nil),
|
|
||||||
api.NewHandlerConfig("sampleHandler", 8080, "0.0.0.0", nil),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
m := multiHandlerConfig{
|
|
||||||
handlerName: tt.fields.handlerName,
|
|
||||||
ports: tt.fields.ports,
|
|
||||||
listenAddress: tt.fields.listenAddress,
|
|
||||||
options: tt.fields.options,
|
|
||||||
}
|
|
||||||
if got := m.HandlerConfigs(); !reflect.DeepEqual(got, tt.want) {
|
|
||||||
t.Errorf("HandlerConfigs() = %v, want %v", got, tt.want)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func Test_multiHandlerConfig_HandlerName(t *testing.T) {
|
|
||||||
type fields struct {
|
|
||||||
handlerName string
|
|
||||||
ports []uint16
|
|
||||||
listenAddress string
|
|
||||||
options *viper.Viper
|
|
||||||
}
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
fields fields
|
|
||||||
want string
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "Get empty handler name for uninitialized struct",
|
|
||||||
fields: fields{},
|
|
||||||
want: "",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Get expected handler name for initialized struct",
|
|
||||||
fields: fields{
|
|
||||||
handlerName: "sampleHandler",
|
|
||||||
},
|
|
||||||
want: "sampleHandler",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
m := multiHandlerConfig{
|
|
||||||
handlerName: tt.fields.handlerName,
|
|
||||||
ports: tt.fields.ports,
|
|
||||||
listenAddress: tt.fields.listenAddress,
|
|
||||||
options: tt.fields.options,
|
|
||||||
}
|
|
||||||
if got := m.HandlerName(); got != tt.want {
|
|
||||||
t.Errorf("HandlerName() = %v, want %v", got, tt.want)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func Test_multiHandlerConfig_ListenAddress(t *testing.T) {
|
|
||||||
type fields struct {
|
|
||||||
handlerName string
|
|
||||||
ports []uint16
|
|
||||||
listenAddress string
|
|
||||||
options *viper.Viper
|
|
||||||
}
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
fields fields
|
|
||||||
want string
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "Get empty ListenAddress for uninitialized struct",
|
|
||||||
fields: fields{},
|
|
||||||
want: "",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Get expected ListenAddress for initialized struct",
|
|
||||||
fields: fields{
|
|
||||||
listenAddress: "0.0.0.0",
|
|
||||||
},
|
|
||||||
want: "0.0.0.0",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
m := multiHandlerConfig{
|
|
||||||
handlerName: tt.fields.handlerName,
|
|
||||||
ports: tt.fields.ports,
|
|
||||||
listenAddress: tt.fields.listenAddress,
|
|
||||||
options: tt.fields.options,
|
|
||||||
}
|
|
||||||
if got := m.ListenAddress(); got != tt.want {
|
|
||||||
t.Errorf("ListenAddress() = %v, want %v", got, tt.want)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func Test_multiHandlerConfig_Options(t *testing.T) {
|
|
||||||
type fields struct {
|
|
||||||
handlerName string
|
|
||||||
ports []uint16
|
|
||||||
listenAddress string
|
|
||||||
options *viper.Viper
|
|
||||||
}
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
fields fields
|
|
||||||
want *viper.Viper
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "Get nil Options for uninitialized struct",
|
|
||||||
fields: fields{},
|
|
||||||
want: nil,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Get expected Options for initialized struct",
|
|
||||||
fields: fields{
|
|
||||||
options: viper.New(),
|
|
||||||
},
|
|
||||||
want: viper.New(),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
m := multiHandlerConfig{
|
|
||||||
handlerName: tt.fields.handlerName,
|
|
||||||
ports: tt.fields.ports,
|
|
||||||
listenAddress: tt.fields.listenAddress,
|
|
||||||
options: tt.fields.options,
|
|
||||||
}
|
|
||||||
if got := m.Options(); !reflect.DeepEqual(got, tt.want) {
|
|
||||||
t.Errorf("Options() = %v, want %v", got, tt.want)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func Test_multiHandlerConfig_Ports(t *testing.T) {
|
|
||||||
type fields struct {
|
|
||||||
handlerName string
|
|
||||||
ports []uint16
|
|
||||||
listenAddress string
|
|
||||||
options *viper.Viper
|
|
||||||
}
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
fields fields
|
|
||||||
want []uint16
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "Get empty Ports for uninitialized struct",
|
|
||||||
fields: fields{},
|
|
||||||
want: nil,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Get expected Ports for initialized struct",
|
|
||||||
fields: fields{
|
|
||||||
ports: []uint16{80, 8080},
|
|
||||||
},
|
|
||||||
want: []uint16{80, 8080},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
m := multiHandlerConfig{
|
|
||||||
handlerName: tt.fields.handlerName,
|
|
||||||
ports: tt.fields.ports,
|
|
||||||
listenAddress: tt.fields.listenAddress,
|
|
||||||
options: tt.fields.options,
|
|
||||||
}
|
|
||||||
if got := m.Ports(); !reflect.DeepEqual(got, tt.want) {
|
|
||||||
t.Errorf("Ports() = %v, want %v", got, tt.want)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,35 +0,0 @@
|
||||||
package config
|
|
||||||
|
|
||||||
import (
|
|
||||||
"github.com/spf13/viper"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
pluginConfigKey = "handler"
|
|
||||||
listenAddressConfigKey = "listenAddress"
|
|
||||||
portConfigKey = "port"
|
|
||||||
portsConfigKey = "ports"
|
|
||||||
)
|
|
||||||
|
|
||||||
func CreateMultiHandlerConfig(handlerConfig *viper.Viper) MultiHandlerConfig {
|
|
||||||
return NewMultiHandlerConfig(
|
|
||||||
handlerConfig.GetString(pluginConfigKey),
|
|
||||||
portsFromConfig(handlerConfig),
|
|
||||||
handlerConfig.GetString(listenAddressConfigKey),
|
|
||||||
handlerConfig.Sub(OptionsKey),
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
func portsFromConfig(handlerConfig *viper.Viper) (ports []uint16) {
|
|
||||||
if portsInt := handlerConfig.GetIntSlice(portsConfigKey); len(portsInt) > 0 {
|
|
||||||
for _, port := range portsInt {
|
|
||||||
ports = append(ports, uint16(port))
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if portInt := handlerConfig.GetInt(portConfigKey); portInt > 0 {
|
|
||||||
ports = append(ports, uint16(portInt))
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
|
@ -1,145 +0,0 @@
|
||||||
package config
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"github.com/spf13/viper"
|
|
||||||
"reflect"
|
|
||||||
"testing"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestCreateMultiHandlerConfig(t *testing.T) {
|
|
||||||
type args struct {
|
|
||||||
handlerConfig *viper.Viper
|
|
||||||
}
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
args args
|
|
||||||
want MultiHandlerConfig
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "Get simple multiHandlerConfig from config",
|
|
||||||
args: args{
|
|
||||||
handlerConfig: configFromString(`
|
|
||||||
handler: sampleHandler
|
|
||||||
listenAddress: 0.0.0.0
|
|
||||||
ports:
|
|
||||||
- 80
|
|
||||||
- 8080
|
|
||||||
options: {}
|
|
||||||
`),
|
|
||||||
},
|
|
||||||
want: &multiHandlerConfig{
|
|
||||||
handlerName: "sampleHandler",
|
|
||||||
ports: []uint16{80, 8080},
|
|
||||||
listenAddress: "0.0.0.0",
|
|
||||||
options: viper.New(),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Get more complex multiHandlerConfig from config",
|
|
||||||
args: args{
|
|
||||||
handlerConfig: configFromString(`
|
|
||||||
handler: sampleHandler
|
|
||||||
listenAddress: 0.0.0.0
|
|
||||||
ports:
|
|
||||||
- 80
|
|
||||||
- 8080
|
|
||||||
options:
|
|
||||||
optionA: asdf
|
|
||||||
optionB: as1234
|
|
||||||
`),
|
|
||||||
},
|
|
||||||
want: &multiHandlerConfig{
|
|
||||||
handlerName: "sampleHandler",
|
|
||||||
ports: []uint16{80, 8080},
|
|
||||||
listenAddress: "0.0.0.0",
|
|
||||||
options: configFromString(`
|
|
||||||
nesting:
|
|
||||||
optionA: asdf
|
|
||||||
optionB: as1234
|
|
||||||
`).Sub("nesting"),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
if got := CreateMultiHandlerConfig(tt.args.handlerConfig); !reflect.DeepEqual(got, tt.want) {
|
|
||||||
t.Errorf("CreateMultiHandlerConfig() = %v, want %v", got, tt.want)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func Test_portsFromConfig(t *testing.T) {
|
|
||||||
type args struct {
|
|
||||||
handlerConfig *viper.Viper
|
|
||||||
}
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
args args
|
|
||||||
wantPorts []uint16
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "Empty array if config value is not set",
|
|
||||||
args: args{
|
|
||||||
handlerConfig: viper.New(),
|
|
||||||
},
|
|
||||||
wantPorts: nil,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Array of one if `port` is set",
|
|
||||||
args: args{
|
|
||||||
handlerConfig: configFromString(`
|
|
||||||
port: 80
|
|
||||||
`),
|
|
||||||
},
|
|
||||||
wantPorts: []uint16{80},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Array of one if `ports` is set as array",
|
|
||||||
args: args{
|
|
||||||
handlerConfig: configFromString(`
|
|
||||||
ports:
|
|
||||||
- 80
|
|
||||||
`),
|
|
||||||
},
|
|
||||||
wantPorts: []uint16{80},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Array of two if `ports` is set as array",
|
|
||||||
args: args{
|
|
||||||
handlerConfig: configFromString(`
|
|
||||||
ports:
|
|
||||||
- 80
|
|
||||||
- 8080
|
|
||||||
`),
|
|
||||||
},
|
|
||||||
wantPorts: []uint16{80, 8080},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Array of two if `port` is set as array",
|
|
||||||
args: args{
|
|
||||||
handlerConfig: configFromString(`
|
|
||||||
ports:
|
|
||||||
- 80
|
|
||||||
- 8080
|
|
||||||
`),
|
|
||||||
},
|
|
||||||
wantPorts: []uint16{80, 8080},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
if gotPorts := portsFromConfig(tt.args.handlerConfig); !reflect.DeepEqual(gotPorts, tt.wantPorts) {
|
|
||||||
t.Errorf("portsFromConfig() = %v, want %v", gotPorts, tt.wantPorts)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func configFromString(yaml string) (config *viper.Viper) {
|
|
||||||
config = viper.New()
|
|
||||||
config.SetConfigType("yaml")
|
|
||||||
_ = config.ReadConfig(bytes.NewBufferString(yaml))
|
|
||||||
return
|
|
||||||
}
|
|
|
@ -3,6 +3,7 @@ package endpoints
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/baez90/inetmock/pkg/api"
|
"github.com/baez90/inetmock/pkg/api"
|
||||||
|
"github.com/baez90/inetmock/pkg/config"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Endpoint interface {
|
type Endpoint interface {
|
||||||
|
@ -14,7 +15,7 @@ type Endpoint interface {
|
||||||
type endpoint struct {
|
type endpoint struct {
|
||||||
name string
|
name string
|
||||||
handler api.ProtocolHandler
|
handler api.ProtocolHandler
|
||||||
config api.HandlerConfig
|
config config.HandlerConfig
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e endpoint) Name() string {
|
func (e endpoint) Name() string {
|
||||||
|
|
|
@ -2,8 +2,8 @@ package endpoints
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/baez90/inetmock/internal/config"
|
|
||||||
"github.com/baez90/inetmock/internal/plugins"
|
"github.com/baez90/inetmock/internal/plugins"
|
||||||
|
config2 "github.com/baez90/inetmock/pkg/config"
|
||||||
"github.com/baez90/inetmock/pkg/logging"
|
"github.com/baez90/inetmock/pkg/logging"
|
||||||
"go.uber.org/zap"
|
"go.uber.org/zap"
|
||||||
"sync"
|
"sync"
|
||||||
|
@ -13,7 +13,7 @@ import (
|
||||||
type EndpointManager interface {
|
type EndpointManager interface {
|
||||||
RegisteredEndpoints() []Endpoint
|
RegisteredEndpoints() []Endpoint
|
||||||
StartedEndpoints() []Endpoint
|
StartedEndpoints() []Endpoint
|
||||||
CreateEndpoint(name string, multiHandlerConfig config.MultiHandlerConfig) error
|
CreateEndpoint(name string, multiHandlerConfig config2.MultiHandlerConfig) error
|
||||||
StartEndpoints()
|
StartEndpoints()
|
||||||
ShutdownEndpoints()
|
ShutdownEndpoints()
|
||||||
}
|
}
|
||||||
|
@ -40,16 +40,16 @@ func (e endpointManager) StartedEndpoints() []Endpoint {
|
||||||
return e.properlyStartedEndpoints
|
return e.properlyStartedEndpoints
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *endpointManager) CreateEndpoint(name string, multiHandlerConfig config.MultiHandlerConfig) error {
|
func (e *endpointManager) CreateEndpoint(name string, multiHandlerConfig config2.MultiHandlerConfig) error {
|
||||||
for _, handlerConfig := range multiHandlerConfig.HandlerConfigs() {
|
for _, handlerConfig := range multiHandlerConfig.HandlerConfigs() {
|
||||||
if handler, ok := e.registry.HandlerForName(multiHandlerConfig.HandlerName()); ok {
|
if handler, ok := e.registry.HandlerForName(multiHandlerConfig.Handler); ok {
|
||||||
e.registeredEndpoints = append(e.registeredEndpoints, &endpoint{
|
e.registeredEndpoints = append(e.registeredEndpoints, &endpoint{
|
||||||
name: name,
|
name: name,
|
||||||
handler: handler,
|
handler: handler,
|
||||||
config: handlerConfig,
|
config: handlerConfig,
|
||||||
})
|
})
|
||||||
} else {
|
} else {
|
||||||
return fmt.Errorf("no matching handler registered for names %s", multiHandlerConfig.HandlerName())
|
return fmt.Errorf("no matching handler registered for names %s", multiHandlerConfig.Handler)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1,11 +1,11 @@
|
||||||
package endpoints
|
package endpoints
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/baez90/inetmock/internal/config"
|
|
||||||
api_mock "github.com/baez90/inetmock/internal/mock/api"
|
api_mock "github.com/baez90/inetmock/internal/mock/api"
|
||||||
logging_mock "github.com/baez90/inetmock/internal/mock/logging"
|
logging_mock "github.com/baez90/inetmock/internal/mock/logging"
|
||||||
plugins_mock "github.com/baez90/inetmock/internal/mock/plugins"
|
plugins_mock "github.com/baez90/inetmock/internal/mock/plugins"
|
||||||
"github.com/baez90/inetmock/internal/plugins"
|
"github.com/baez90/inetmock/internal/plugins"
|
||||||
|
config2 "github.com/baez90/inetmock/pkg/config"
|
||||||
"github.com/baez90/inetmock/pkg/logging"
|
"github.com/baez90/inetmock/pkg/logging"
|
||||||
"github.com/golang/mock/gomock"
|
"github.com/golang/mock/gomock"
|
||||||
"testing"
|
"testing"
|
||||||
|
@ -20,7 +20,7 @@ func Test_endpointManager_CreateEndpoint(t *testing.T) {
|
||||||
}
|
}
|
||||||
type args struct {
|
type args struct {
|
||||||
name string
|
name string
|
||||||
multiHandlerConfig config.MultiHandlerConfig
|
multiHandlerConfig config2.MultiHandlerConfig
|
||||||
}
|
}
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
|
@ -52,12 +52,11 @@ func Test_endpointManager_CreateEndpoint(t *testing.T) {
|
||||||
},
|
},
|
||||||
args: args{
|
args: args{
|
||||||
name: "sampleEndpoint",
|
name: "sampleEndpoint",
|
||||||
multiHandlerConfig: config.NewMultiHandlerConfig(
|
multiHandlerConfig: config2.MultiHandlerConfig{
|
||||||
"sampleHandler",
|
Handler: "sampleHandler",
|
||||||
[]uint16{80},
|
Ports: []uint16{80},
|
||||||
"0.0.0.0",
|
ListenAddress: "0.0.0.0",
|
||||||
nil,
|
},
|
||||||
),
|
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
@ -83,12 +82,11 @@ func Test_endpointManager_CreateEndpoint(t *testing.T) {
|
||||||
},
|
},
|
||||||
args: args{
|
args: args{
|
||||||
name: "sampleEndpoint",
|
name: "sampleEndpoint",
|
||||||
multiHandlerConfig: config.NewMultiHandlerConfig(
|
multiHandlerConfig: config2.MultiHandlerConfig{
|
||||||
"sampleHandler",
|
Handler: "sampleHandler",
|
||||||
[]uint16{80},
|
Ports: []uint16{80},
|
||||||
"0.0.0.0",
|
ListenAddress: "0.0.0.0",
|
||||||
nil,
|
},
|
||||||
),
|
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
|
@ -4,6 +4,7 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
api_mock "github.com/baez90/inetmock/internal/mock/api"
|
api_mock "github.com/baez90/inetmock/internal/mock/api"
|
||||||
"github.com/baez90/inetmock/pkg/api"
|
"github.com/baez90/inetmock/pkg/api"
|
||||||
|
"github.com/baez90/inetmock/pkg/config"
|
||||||
"github.com/golang/mock/gomock"
|
"github.com/golang/mock/gomock"
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
@ -12,7 +13,7 @@ func Test_endpoint_Name(t *testing.T) {
|
||||||
type fields struct {
|
type fields struct {
|
||||||
name string
|
name string
|
||||||
handler api.ProtocolHandler
|
handler api.ProtocolHandler
|
||||||
config api.HandlerConfig
|
config config.HandlerConfig
|
||||||
}
|
}
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
|
@ -50,7 +51,7 @@ func Test_endpoint_Shutdown(t *testing.T) {
|
||||||
type fields struct {
|
type fields struct {
|
||||||
name string
|
name string
|
||||||
handler api.ProtocolHandler
|
handler api.ProtocolHandler
|
||||||
config api.HandlerConfig
|
config config.HandlerConfig
|
||||||
}
|
}
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
|
@ -102,17 +103,16 @@ func Test_endpoint_Shutdown(t *testing.T) {
|
||||||
|
|
||||||
func Test_endpoint_Start(t *testing.T) {
|
func Test_endpoint_Start(t *testing.T) {
|
||||||
|
|
||||||
demoHandlerConfig := api.NewHandlerConfig(
|
demoHandlerConfig := config.HandlerConfig{
|
||||||
"sampleHandler",
|
HandlerName: "sampleHandler",
|
||||||
80,
|
Port: 80,
|
||||||
"0.0.0.0",
|
ListenAddress: "0.0.0.0",
|
||||||
nil,
|
}
|
||||||
)
|
|
||||||
|
|
||||||
type fields struct {
|
type fields struct {
|
||||||
name string
|
name string
|
||||||
handler api.ProtocolHandler
|
handler api.ProtocolHandler
|
||||||
config api.HandlerConfig
|
config config.HandlerConfig
|
||||||
}
|
}
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
|
@ -125,7 +125,7 @@ func Test_endpoint_Start(t *testing.T) {
|
||||||
handler: func() api.ProtocolHandler {
|
handler: func() api.ProtocolHandler {
|
||||||
handler := api_mock.NewMockProtocolHandler(gomock.NewController(t))
|
handler := api_mock.NewMockProtocolHandler(gomock.NewController(t))
|
||||||
handler.EXPECT().
|
handler.EXPECT().
|
||||||
Start(nil).
|
Start(gomock.Any()).
|
||||||
MaxTimes(1).
|
MaxTimes(1).
|
||||||
Return(nil)
|
Return(nil)
|
||||||
return handler
|
return handler
|
||||||
|
@ -139,7 +139,7 @@ func Test_endpoint_Start(t *testing.T) {
|
||||||
handler: func() api.ProtocolHandler {
|
handler: func() api.ProtocolHandler {
|
||||||
handler := api_mock.NewMockProtocolHandler(gomock.NewController(t))
|
handler := api_mock.NewMockProtocolHandler(gomock.NewController(t))
|
||||||
handler.EXPECT().
|
handler.EXPECT().
|
||||||
Start(nil).
|
Start(gomock.Any()).
|
||||||
MaxTimes(1).
|
MaxTimes(1).
|
||||||
Return(fmt.Errorf(""))
|
Return(fmt.Errorf(""))
|
||||||
return handler
|
return handler
|
||||||
|
|
|
@ -10,6 +10,7 @@ import (
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"plugin"
|
"plugin"
|
||||||
"regexp"
|
"regexp"
|
||||||
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
|
@ -19,6 +20,7 @@ var (
|
||||||
|
|
||||||
type HandlerRegistry interface {
|
type HandlerRegistry interface {
|
||||||
LoadPlugins(pluginsPath string) error
|
LoadPlugins(pluginsPath string) error
|
||||||
|
AvailableHandlers() []string
|
||||||
RegisterHandler(handlerName string, handlerProvider api.PluginInstanceFactory, subCommands ...*cobra.Command)
|
RegisterHandler(handlerName string, handlerProvider api.PluginInstanceFactory, subCommands ...*cobra.Command)
|
||||||
HandlerForName(handlerName string) (api.ProtocolHandler, bool)
|
HandlerForName(handlerName string) (api.ProtocolHandler, bool)
|
||||||
PluginCommands() []*cobra.Command
|
PluginCommands() []*cobra.Command
|
||||||
|
@ -29,11 +31,19 @@ type handlerRegistry struct {
|
||||||
pluginCommands []*cobra.Command
|
pluginCommands []*cobra.Command
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (h handlerRegistry) AvailableHandlers() (availableHandlers []string) {
|
||||||
|
for s := range h.handlers {
|
||||||
|
availableHandlers = append(availableHandlers, s)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
func (h handlerRegistry) PluginCommands() []*cobra.Command {
|
func (h handlerRegistry) PluginCommands() []*cobra.Command {
|
||||||
return h.pluginCommands
|
return h.pluginCommands
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *handlerRegistry) HandlerForName(handlerName string) (instance api.ProtocolHandler, ok bool) {
|
func (h *handlerRegistry) HandlerForName(handlerName string) (instance api.ProtocolHandler, ok bool) {
|
||||||
|
handlerName = strings.ToLower(handlerName)
|
||||||
var provider api.PluginInstanceFactory
|
var provider api.PluginInstanceFactory
|
||||||
if provider, ok = h.handlers[handlerName]; ok {
|
if provider, ok = h.handlers[handlerName]; ok {
|
||||||
instance = provider()
|
instance = provider()
|
||||||
|
@ -42,6 +52,7 @@ func (h *handlerRegistry) HandlerForName(handlerName string) (instance api.Proto
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *handlerRegistry) RegisterHandler(handlerName string, handlerProvider api.PluginInstanceFactory, subCommands ...*cobra.Command) {
|
func (h *handlerRegistry) RegisterHandler(handlerName string, handlerProvider api.PluginInstanceFactory, subCommands ...*cobra.Command) {
|
||||||
|
handlerName = strings.ToLower(handlerName)
|
||||||
if _, exists := h.handlers[handlerName]; exists {
|
if _, exists := h.handlers[handlerName]; exists {
|
||||||
panic(fmt.Sprintf("handler with name %s is already registered - there's something strange...in the neighborhood", handlerName))
|
panic(fmt.Sprintf("handler with name %s is already registered - there's something strange...in the neighborhood", handlerName))
|
||||||
}
|
}
|
||||||
|
|
|
@ -25,8 +25,8 @@ tls:
|
||||||
NotBeforeRelative: 168h
|
NotBeforeRelative: 168h
|
||||||
NotAfterRelative: 168h
|
NotAfterRelative: 168h
|
||||||
rootCaCert:
|
rootCaCert:
|
||||||
publicKey: ./ca.pem
|
publicKeyPath: ./ca.pem
|
||||||
privateKey: ./ca.key
|
privateKeyPath: ./ca.key
|
||||||
certCachePath: /tmp/inetmock/
|
certCachePath: /tmp/inetmock/
|
||||||
|
|
||||||
endpoints:
|
endpoints:
|
||||||
|
@ -41,7 +41,8 @@ endpoints:
|
||||||
proxy:
|
proxy:
|
||||||
handler: http_proxy
|
handler: http_proxy
|
||||||
listenAddress: 0.0.0.0
|
listenAddress: 0.0.0.0
|
||||||
port: 3128
|
ports:
|
||||||
|
- 3128
|
||||||
options:
|
options:
|
||||||
target:
|
target:
|
||||||
ipAddress: 127.0.0.1
|
ipAddress: 127.0.0.1
|
||||||
|
@ -59,7 +60,8 @@ endpoints:
|
||||||
plainDns:
|
plainDns:
|
||||||
handler: dns_mock
|
handler: dns_mock
|
||||||
listenAddress: 0.0.0.0
|
listenAddress: 0.0.0.0
|
||||||
port: 53
|
ports:
|
||||||
|
- 53
|
||||||
options:
|
options:
|
||||||
rules:
|
rules:
|
||||||
- pattern: ".*\\.google\\.com"
|
- pattern: ".*\\.google\\.com"
|
||||||
|
@ -73,7 +75,8 @@ endpoints:
|
||||||
dnsOverTlsDowngrade:
|
dnsOverTlsDowngrade:
|
||||||
handler: tls_interceptor
|
handler: tls_interceptor
|
||||||
listenAddress: 0.0.0.0
|
listenAddress: 0.0.0.0
|
||||||
port: 853
|
ports:
|
||||||
|
- 853
|
||||||
options:
|
options:
|
||||||
target:
|
target:
|
||||||
ipAddress: 127.0.0.1
|
ipAddress: 127.0.0.1
|
||||||
|
|
|
@ -1,42 +0,0 @@
|
||||||
package api
|
|
||||||
|
|
||||||
import "github.com/spf13/viper"
|
|
||||||
|
|
||||||
type HandlerConfig interface {
|
|
||||||
HandlerName() string
|
|
||||||
ListenAddress() string
|
|
||||||
Port() uint16
|
|
||||||
Options() *viper.Viper
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewHandlerConfig(handlerName string, port uint16, listenAddress string, options *viper.Viper) HandlerConfig {
|
|
||||||
return &handlerConfig{
|
|
||||||
handlerName: handlerName,
|
|
||||||
port: port,
|
|
||||||
listenAddress: listenAddress,
|
|
||||||
options: options,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
type handlerConfig struct {
|
|
||||||
handlerName string
|
|
||||||
port uint16
|
|
||||||
listenAddress string
|
|
||||||
options *viper.Viper
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h handlerConfig) HandlerName() string {
|
|
||||||
return h.handlerName
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h handlerConfig) ListenAddress() string {
|
|
||||||
return h.listenAddress
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h handlerConfig) Port() uint16 {
|
|
||||||
return h.port
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h handlerConfig) Options() *viper.Viper {
|
|
||||||
return h.options
|
|
||||||
}
|
|
|
@ -2,6 +2,7 @@
|
||||||
package api
|
package api
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"github.com/baez90/inetmock/pkg/config"
|
||||||
"go.uber.org/zap"
|
"go.uber.org/zap"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -10,6 +11,6 @@ type PluginInstanceFactory func() ProtocolHandler
|
||||||
type LoggingFactory func() (*zap.Logger, error)
|
type LoggingFactory func() (*zap.Logger, error)
|
||||||
|
|
||||||
type ProtocolHandler interface {
|
type ProtocolHandler interface {
|
||||||
Start(config HandlerConfig) error
|
Start(config config.HandlerConfig) error
|
||||||
Shutdown() error
|
Shutdown() error
|
||||||
}
|
}
|
||||||
|
|
|
@ -2,8 +2,8 @@ package api
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/baez90/inetmock/pkg/cert"
|
"github.com/baez90/inetmock/pkg/cert"
|
||||||
|
config2 "github.com/baez90/inetmock/pkg/config"
|
||||||
"github.com/baez90/inetmock/pkg/logging"
|
"github.com/baez90/inetmock/pkg/logging"
|
||||||
"github.com/spf13/viper"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
|
@ -19,7 +19,7 @@ type services struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func InitServices(
|
func InitServices(
|
||||||
config *viper.Viper,
|
config config2.Config,
|
||||||
logger logging.Logger,
|
logger logging.Logger,
|
||||||
) error {
|
) error {
|
||||||
certStore, err := cert.NewDefaultStore(config, logger)
|
certStore, err := cert.NewDefaultStore(config, logger)
|
||||||
|
|
|
@ -5,6 +5,7 @@ import (
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
"fmt"
|
"fmt"
|
||||||
certmock "github.com/baez90/inetmock/internal/mock/cert"
|
certmock "github.com/baez90/inetmock/internal/mock/cert"
|
||||||
|
config2 "github.com/baez90/inetmock/pkg/config"
|
||||||
"github.com/golang/mock/gomock"
|
"github.com/golang/mock/gomock"
|
||||||
"os"
|
"os"
|
||||||
"path"
|
"path"
|
||||||
|
@ -284,13 +285,13 @@ func Test_fileSystemCache_Put(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func setupCertGen() Generator {
|
func setupCertGen() Generator {
|
||||||
return NewDefaultGenerator(Options{
|
return NewDefaultGenerator(config2.CertOptions{
|
||||||
Validity: ValidityByPurpose{
|
Validity: config2.ValidityByPurpose{
|
||||||
Server: ValidityDuration{
|
Server: config2.ValidityDuration{
|
||||||
NotBeforeRelative: serverRelativeValidity,
|
NotBeforeRelative: serverRelativeValidity,
|
||||||
NotAfterRelative: serverRelativeValidity,
|
NotAfterRelative: serverRelativeValidity,
|
||||||
},
|
},
|
||||||
CA: ValidityDuration{
|
CA: config2.ValidityDuration{
|
||||||
NotBeforeRelative: caRelativeValidity,
|
NotBeforeRelative: caRelativeValidity,
|
||||||
NotAfterRelative: caRelativeValidity,
|
NotAfterRelative: caRelativeValidity,
|
||||||
},
|
},
|
||||||
|
|
|
@ -1,19 +1,11 @@
|
||||||
package cert
|
package cert
|
||||||
|
|
||||||
const (
|
const (
|
||||||
CurveTypeP224 CurveType = "P224"
|
|
||||||
CurveTypeP256 CurveType = "P256"
|
|
||||||
CurveTypeP384 CurveType = "P384"
|
|
||||||
CurveTypeP521 CurveType = "P521"
|
|
||||||
CurveTypeED25519 CurveType = "ED25519"
|
|
||||||
|
|
||||||
defaultServerValidityDuration = "168h"
|
defaultServerValidityDuration = "168h"
|
||||||
defaultCAValidityDuration = "17520h"
|
defaultCAValidityDuration = "17520h"
|
||||||
|
|
||||||
certCachePathConfigKey = "tls.certCachePath"
|
certCachePathConfigKey = "tls.certCachePath"
|
||||||
ecdsaCurveConfigKey = "tls.ecdsaCurve"
|
ecdsaCurveConfigKey = "tls.ecdsaCurve"
|
||||||
publicKeyConfigKey = "tls.rootCaCert.publicKey"
|
|
||||||
privateKeyPathConfigKey = "tls.rootCaCert.privateKey"
|
|
||||||
caCertValidityNotBeforeKey = "tls.validity.ca.notBeforeRelative"
|
caCertValidityNotBeforeKey = "tls.validity.ca.notBeforeRelative"
|
||||||
caCertValidityNotAfterKey = "tls.validity.ca.notAfterRelative"
|
caCertValidityNotAfterKey = "tls.validity.ca.notAfterRelative"
|
||||||
serverCertValidityNotBeforeKey = "tls.validity.server.notBeforeRelative"
|
serverCertValidityNotBeforeKey = "tls.validity.server.notBeforeRelative"
|
||||||
|
|
|
@ -8,6 +8,7 @@ import (
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
"crypto/x509/pkix"
|
"crypto/x509/pkix"
|
||||||
"encoding/pem"
|
"encoding/pem"
|
||||||
|
config2 "github.com/baez90/inetmock/pkg/config"
|
||||||
"github.com/baez90/inetmock/pkg/defaulting"
|
"github.com/baez90/inetmock/pkg/defaulting"
|
||||||
"math/big"
|
"math/big"
|
||||||
"net"
|
"net"
|
||||||
|
@ -35,11 +36,11 @@ type Generator interface {
|
||||||
ServerCert(options GenerationOptions, ca *tls.Certificate) (*tls.Certificate, error)
|
ServerCert(options GenerationOptions, ca *tls.Certificate) (*tls.Certificate, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewDefaultGenerator(options Options) Generator {
|
func NewDefaultGenerator(options config2.CertOptions) Generator {
|
||||||
return NewGenerator(options, NewTimeSource(), defaultKeyProvider(options))
|
return NewGenerator(options, NewTimeSource(), defaultKeyProvider(options))
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewGenerator(options Options, source TimeSource, provider KeyProvider) Generator {
|
func NewGenerator(options config2.CertOptions, source TimeSource, provider KeyProvider) Generator {
|
||||||
return &generator{
|
return &generator{
|
||||||
options: options,
|
options: options,
|
||||||
provider: provider,
|
provider: provider,
|
||||||
|
@ -48,7 +49,7 @@ func NewGenerator(options Options, source TimeSource, provider KeyProvider) Gene
|
||||||
}
|
}
|
||||||
|
|
||||||
type generator struct {
|
type generator struct {
|
||||||
options Options
|
options config2.CertOptions
|
||||||
provider KeyProvider
|
provider KeyProvider
|
||||||
timeSource TimeSource
|
timeSource TimeSource
|
||||||
outDir string
|
outDir string
|
||||||
|
|
|
@ -1,79 +1,15 @@
|
||||||
package cert
|
package cert
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"github.com/baez90/inetmock/pkg/config"
|
||||||
"github.com/spf13/viper"
|
|
||||||
"os"
|
"os"
|
||||||
"strings"
|
|
||||||
"time"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
func init() {
|
||||||
requiredConfigKeys = []string{
|
config.AddDefaultValue(certCachePathConfigKey, os.TempDir())
|
||||||
publicKeyConfigKey,
|
config.AddDefaultValue(ecdsaCurveConfigKey, string(config.CurveTypeED25519))
|
||||||
privateKeyPathConfigKey,
|
config.AddDefaultValue(caCertValidityNotBeforeKey, defaultCAValidityDuration)
|
||||||
}
|
config.AddDefaultValue(caCertValidityNotAfterKey, defaultCAValidityDuration)
|
||||||
)
|
config.AddDefaultValue(serverCertValidityNotBeforeKey, defaultServerValidityDuration)
|
||||||
|
config.AddDefaultValue(serverCertValidityNotAfterKey, defaultServerValidityDuration)
|
||||||
type CurveType string
|
|
||||||
|
|
||||||
type File struct {
|
|
||||||
PublicKeyPath string
|
|
||||||
PrivateKeyPath string
|
|
||||||
}
|
|
||||||
|
|
||||||
type ValidityDuration struct {
|
|
||||||
NotBeforeRelative time.Duration
|
|
||||||
NotAfterRelative time.Duration
|
|
||||||
}
|
|
||||||
|
|
||||||
type ValidityByPurpose struct {
|
|
||||||
CA ValidityDuration
|
|
||||||
Server ValidityDuration
|
|
||||||
}
|
|
||||||
|
|
||||||
type Options struct {
|
|
||||||
RootCACert File
|
|
||||||
CertCachePath string
|
|
||||||
Curve CurveType
|
|
||||||
Validity ValidityByPurpose
|
|
||||||
}
|
|
||||||
|
|
||||||
func loadFromConfig(config *viper.Viper) (Options, error) {
|
|
||||||
missingKeys := make([]string, 0)
|
|
||||||
for _, requiredKey := range requiredConfigKeys {
|
|
||||||
if !config.IsSet(requiredKey) {
|
|
||||||
missingKeys = append(missingKeys, requiredKey)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(missingKeys) > 0 {
|
|
||||||
return Options{}, fmt.Errorf("config keys are missing: %s", strings.Join(missingKeys, ", "))
|
|
||||||
}
|
|
||||||
|
|
||||||
config.SetDefault(certCachePathConfigKey, os.TempDir())
|
|
||||||
config.SetDefault(ecdsaCurveConfigKey, string(CurveTypeED25519))
|
|
||||||
config.SetDefault(caCertValidityNotBeforeKey, defaultCAValidityDuration)
|
|
||||||
config.SetDefault(caCertValidityNotAfterKey, defaultCAValidityDuration)
|
|
||||||
config.SetDefault(serverCertValidityNotBeforeKey, defaultServerValidityDuration)
|
|
||||||
config.SetDefault(serverCertValidityNotAfterKey, defaultServerValidityDuration)
|
|
||||||
|
|
||||||
return Options{
|
|
||||||
CertCachePath: config.GetString(certCachePathConfigKey),
|
|
||||||
Curve: CurveType(config.GetString(ecdsaCurveConfigKey)),
|
|
||||||
Validity: ValidityByPurpose{
|
|
||||||
CA: ValidityDuration{
|
|
||||||
NotBeforeRelative: config.GetDuration(caCertValidityNotBeforeKey),
|
|
||||||
NotAfterRelative: config.GetDuration(caCertValidityNotAfterKey),
|
|
||||||
},
|
|
||||||
Server: ValidityDuration{
|
|
||||||
NotBeforeRelative: config.GetDuration(serverCertValidityNotBeforeKey),
|
|
||||||
NotAfterRelative: config.GetDuration(serverCertValidityNotAfterKey),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
RootCACert: File{
|
|
||||||
PublicKeyPath: config.GetString(publicKeyConfigKey),
|
|
||||||
PrivateKeyPath: config.GetString(privateKeyPathConfigKey),
|
|
||||||
},
|
|
||||||
}, nil
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,9 +1,9 @@
|
||||||
package cert
|
package cert
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"github.com/baez90/inetmock/pkg/config"
|
||||||
"github.com/spf13/viper"
|
"github.com/spf13/viper"
|
||||||
"os"
|
"os"
|
||||||
"reflect"
|
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
@ -25,7 +25,7 @@ func Test_loadFromConfig(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
args args
|
args args
|
||||||
want Options
|
want config.CertOptions
|
||||||
wantErr bool
|
wantErr bool
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
|
@ -48,19 +48,19 @@ tls:
|
||||||
certCachePath: /tmp/inetmock/
|
certCachePath: /tmp/inetmock/
|
||||||
`),
|
`),
|
||||||
},
|
},
|
||||||
want: Options{
|
want: config.CertOptions{
|
||||||
RootCACert: File{
|
RootCACert: config.File{
|
||||||
PublicKeyPath: "./ca.pem",
|
PublicKeyPath: "./ca.pem",
|
||||||
PrivateKeyPath: "./ca.key",
|
PrivateKeyPath: "./ca.key",
|
||||||
},
|
},
|
||||||
CertCachePath: "/tmp/inetmock/",
|
CertCachePath: "/tmp/inetmock/",
|
||||||
Curve: CurveTypeP256,
|
Curve: config.CurveTypeP256,
|
||||||
Validity: ValidityByPurpose{
|
Validity: config.ValidityByPurpose{
|
||||||
CA: ValidityDuration{
|
CA: config.ValidityDuration{
|
||||||
NotBeforeRelative: 17520 * time.Hour,
|
NotBeforeRelative: 17520 * time.Hour,
|
||||||
NotAfterRelative: 17520 * time.Hour,
|
NotAfterRelative: 17520 * time.Hour,
|
||||||
},
|
},
|
||||||
Server: ValidityDuration{
|
Server: config.ValidityDuration{
|
||||||
NotBeforeRelative: 168 * time.Hour,
|
NotBeforeRelative: 168 * time.Hour,
|
||||||
NotAfterRelative: 168 * time.Hour,
|
NotAfterRelative: 168 * time.Hour,
|
||||||
},
|
},
|
||||||
|
@ -76,7 +76,7 @@ tls:
|
||||||
privateKey: ./ca.key
|
privateKey: ./ca.key
|
||||||
`),
|
`),
|
||||||
},
|
},
|
||||||
want: Options{},
|
want: config.CertOptions{},
|
||||||
wantErr: true,
|
wantErr: true,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
@ -88,7 +88,7 @@ tls:
|
||||||
publicKey: ./ca.pem
|
publicKey: ./ca.pem
|
||||||
`),
|
`),
|
||||||
},
|
},
|
||||||
want: Options{},
|
want: config.CertOptions{},
|
||||||
wantErr: true,
|
wantErr: true,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
@ -101,19 +101,19 @@ tls:
|
||||||
privateKey: ./ca.key
|
privateKey: ./ca.key
|
||||||
`),
|
`),
|
||||||
},
|
},
|
||||||
want: Options{
|
want: config.CertOptions{
|
||||||
RootCACert: File{
|
RootCACert: config.File{
|
||||||
PublicKeyPath: "./ca.pem",
|
PublicKeyPath: "./ca.pem",
|
||||||
PrivateKeyPath: "./ca.key",
|
PrivateKeyPath: "./ca.key",
|
||||||
},
|
},
|
||||||
CertCachePath: os.TempDir(),
|
CertCachePath: os.TempDir(),
|
||||||
Curve: CurveTypeED25519,
|
Curve: config.CurveTypeED25519,
|
||||||
Validity: ValidityByPurpose{
|
Validity: config.ValidityByPurpose{
|
||||||
CA: ValidityDuration{
|
CA: config.ValidityDuration{
|
||||||
NotBeforeRelative: 17520 * time.Hour,
|
NotBeforeRelative: 17520 * time.Hour,
|
||||||
NotAfterRelative: 17520 * time.Hour,
|
NotAfterRelative: 17520 * time.Hour,
|
||||||
},
|
},
|
||||||
Server: ValidityDuration{
|
Server: config.ValidityDuration{
|
||||||
NotBeforeRelative: 168 * time.Hour,
|
NotBeforeRelative: 168 * time.Hour,
|
||||||
NotAfterRelative: 168 * time.Hour,
|
NotAfterRelative: 168 * time.Hour,
|
||||||
},
|
},
|
||||||
|
@ -124,14 +124,7 @@ tls:
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
got, err := loadFromConfig(tt.args.config)
|
|
||||||
if (err != nil) != tt.wantErr {
|
|
||||||
t.Errorf("loadFromConfig() error = %v, wantErr %v", err, tt.wantErr)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if !reflect.DeepEqual(got, tt.want) {
|
|
||||||
t.Errorf("loadFromConfig() got = %v, want %v", got, tt.want)
|
|
||||||
}
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -7,8 +7,8 @@ import (
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
|
"github.com/baez90/inetmock/pkg/config"
|
||||||
"github.com/baez90/inetmock/pkg/logging"
|
"github.com/baez90/inetmock/pkg/logging"
|
||||||
"github.com/spf13/viper"
|
|
||||||
"go.uber.org/zap"
|
"go.uber.org/zap"
|
||||||
"net"
|
"net"
|
||||||
)
|
)
|
||||||
|
@ -18,7 +18,7 @@ const (
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
defaultKeyProvider = func(options Options) func() (key interface{}, err error) {
|
defaultKeyProvider = func(options config.CertOptions) func() (key interface{}, err error) {
|
||||||
return func() (key interface{}, err error) {
|
return func() (key interface{}, err error) {
|
||||||
return privateKeyForCurve(options)
|
return privateKeyForCurve(options)
|
||||||
}
|
}
|
||||||
|
@ -34,24 +34,20 @@ type Store interface {
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewDefaultStore(
|
func NewDefaultStore(
|
||||||
config *viper.Viper,
|
config config.Config,
|
||||||
logger logging.Logger,
|
logger logging.Logger,
|
||||||
) (Store, error) {
|
) (Store, error) {
|
||||||
options, err := loadFromConfig(config)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
timeSource := NewTimeSource()
|
timeSource := NewTimeSource()
|
||||||
return NewStore(
|
return NewStore(
|
||||||
options,
|
config.TLSConfig(),
|
||||||
NewFileSystemCache(options.CertCachePath, timeSource),
|
NewFileSystemCache(config.TLSConfig().CertCachePath, timeSource),
|
||||||
NewDefaultGenerator(options),
|
NewDefaultGenerator(config.TLSConfig()),
|
||||||
logger,
|
logger,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewStore(
|
func NewStore(
|
||||||
options Options,
|
options config.CertOptions,
|
||||||
cache Cache,
|
cache Cache,
|
||||||
generator Generator,
|
generator Generator,
|
||||||
logger logging.Logger,
|
logger logging.Logger,
|
||||||
|
@ -71,7 +67,7 @@ func NewStore(
|
||||||
}
|
}
|
||||||
|
|
||||||
type store struct {
|
type store struct {
|
||||||
options Options
|
options config.CertOptions
|
||||||
caCert *tls.Certificate
|
caCert *tls.Certificate
|
||||||
cache Cache
|
cache Cache
|
||||||
timeSource TimeSource
|
timeSource TimeSource
|
||||||
|
@ -135,15 +131,15 @@ func (s *store) GetCertificate(serverName string, ip string) (cert *tls.Certific
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func privateKeyForCurve(options Options) (privateKey interface{}, err error) {
|
func privateKeyForCurve(options config.CertOptions) (privateKey interface{}, err error) {
|
||||||
switch options.Curve {
|
switch options.Curve {
|
||||||
case CurveTypeP224:
|
case config.CurveTypeP224:
|
||||||
privateKey, err = ecdsa.GenerateKey(elliptic.P224(), rand.Reader)
|
privateKey, err = ecdsa.GenerateKey(elliptic.P224(), rand.Reader)
|
||||||
case CurveTypeP256:
|
case config.CurveTypeP256:
|
||||||
privateKey, err = ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
privateKey, err = ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||||
case CurveTypeP384:
|
case config.CurveTypeP384:
|
||||||
privateKey, err = ecdsa.GenerateKey(elliptic.P384(), rand.Reader)
|
privateKey, err = ecdsa.GenerateKey(elliptic.P384(), rand.Reader)
|
||||||
case CurveTypeP521:
|
case config.CurveTypeP521:
|
||||||
privateKey, err = ecdsa.GenerateKey(elliptic.P521(), rand.Reader)
|
privateKey, err = ecdsa.GenerateKey(elliptic.P521(), rand.Reader)
|
||||||
default:
|
default:
|
||||||
_, privateKey, err = ed25519.GenerateKey(rand.Reader)
|
_, privateKey, err = ed25519.GenerateKey(rand.Reader)
|
||||||
|
|
27
pkg/config/certs.go
Normal file
27
pkg/config/certs.go
Normal file
|
@ -0,0 +1,27 @@
|
||||||
|
package config
|
||||||
|
|
||||||
|
import "time"
|
||||||
|
|
||||||
|
type CurveType string
|
||||||
|
|
||||||
|
type File struct {
|
||||||
|
PublicKeyPath string
|
||||||
|
PrivateKeyPath string
|
||||||
|
}
|
||||||
|
|
||||||
|
type ValidityDuration struct {
|
||||||
|
NotBeforeRelative time.Duration
|
||||||
|
NotAfterRelative time.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
type ValidityByPurpose struct {
|
||||||
|
CA ValidityDuration
|
||||||
|
Server ValidityDuration
|
||||||
|
}
|
||||||
|
|
||||||
|
type CertOptions struct {
|
||||||
|
RootCACert File
|
||||||
|
CertCachePath string
|
||||||
|
Curve CurveType
|
||||||
|
Validity ValidityByPurpose
|
||||||
|
}
|
109
pkg/config/config.go
Normal file
109
pkg/config/config.go
Normal file
|
@ -0,0 +1,109 @@
|
||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/baez90/inetmock/pkg/logging"
|
||||||
|
"github.com/baez90/inetmock/pkg/path"
|
||||||
|
"github.com/spf13/pflag"
|
||||||
|
"github.com/spf13/viper"
|
||||||
|
"go.uber.org/zap"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
appConfig Config
|
||||||
|
)
|
||||||
|
|
||||||
|
func CreateConfig(flags *pflag.FlagSet) {
|
||||||
|
logger, _ := logging.CreateLogger()
|
||||||
|
configInstance := &config{
|
||||||
|
logger: logger.Named("Config"),
|
||||||
|
cfg: viper.New(),
|
||||||
|
}
|
||||||
|
|
||||||
|
configInstance.cfg.SetConfigName("config")
|
||||||
|
configInstance.cfg.SetConfigType("yaml")
|
||||||
|
configInstance.cfg.AddConfigPath("/etc/inetmock/")
|
||||||
|
configInstance.cfg.AddConfigPath("$HOME/.inetmock")
|
||||||
|
configInstance.cfg.AddConfigPath(".")
|
||||||
|
configInstance.cfg.SetEnvPrefix("INetMock")
|
||||||
|
_ = configInstance.cfg.BindPFlags(flags)
|
||||||
|
configInstance.cfg.SetEnvKeyReplacer(strings.NewReplacer("-", "_"))
|
||||||
|
configInstance.cfg.AutomaticEnv()
|
||||||
|
|
||||||
|
for k, v := range registeredDefaults {
|
||||||
|
configInstance.cfg.SetDefault(k, v)
|
||||||
|
}
|
||||||
|
|
||||||
|
for k, v := range registeredAliases {
|
||||||
|
configInstance.cfg.RegisterAlias(k, v)
|
||||||
|
}
|
||||||
|
|
||||||
|
appConfig = configInstance
|
||||||
|
}
|
||||||
|
|
||||||
|
func Instance() Config {
|
||||||
|
return appConfig
|
||||||
|
}
|
||||||
|
|
||||||
|
type Config interface {
|
||||||
|
ReadConfig(configFilePath string) error
|
||||||
|
ReadConfigString(config, format string) error
|
||||||
|
Viper() *viper.Viper
|
||||||
|
TLSConfig() CertOptions
|
||||||
|
PluginsDir() string
|
||||||
|
EndpointConfigs() map[string]MultiHandlerConfig
|
||||||
|
}
|
||||||
|
|
||||||
|
type config struct {
|
||||||
|
cfg *viper.Viper
|
||||||
|
logger logging.Logger
|
||||||
|
TLS CertOptions
|
||||||
|
PluginsDirectory string
|
||||||
|
Endpoints map[string]MultiHandlerConfig
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *config) ReadConfigString(config, format string) (err error) {
|
||||||
|
c.cfg.SetConfigType(format)
|
||||||
|
if err = c.cfg.ReadConfig(strings.NewReader(config)); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
err = c.cfg.Unmarshal(c)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *config) EndpointConfigs() map[string]MultiHandlerConfig {
|
||||||
|
return c.Endpoints
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *config) PluginsDir() string {
|
||||||
|
return c.PluginsDirectory
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *config) TLSConfig() CertOptions {
|
||||||
|
return c.TLS
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *config) Viper() *viper.Viper {
|
||||||
|
return c.cfg
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *config) ReadConfig(configFilePath string) (err error) {
|
||||||
|
if configFilePath != "" && path.FileExists(configFilePath) {
|
||||||
|
c.logger.Info(
|
||||||
|
"loading config from passed config file path",
|
||||||
|
zap.String("configFilePath", configFilePath),
|
||||||
|
)
|
||||||
|
c.cfg.SetConfigFile(configFilePath)
|
||||||
|
}
|
||||||
|
if err = c.cfg.ReadInConfig(); err != nil {
|
||||||
|
if _, ok := err.(viper.ConfigFileNotFoundError); ok {
|
||||||
|
err = nil
|
||||||
|
c.logger.Warn("failed to load config")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
err = c.cfg.Unmarshal(c)
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
88
pkg/config/config_test.go
Normal file
88
pkg/config/config_test.go
Normal file
|
@ -0,0 +1,88 @@
|
||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/spf13/pflag"
|
||||||
|
"os"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func Test_config_ReadConfig(t *testing.T) {
|
||||||
|
type args struct {
|
||||||
|
flags *pflag.FlagSet
|
||||||
|
config string
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
args args
|
||||||
|
matcher func(Config) bool
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Test endpoints config",
|
||||||
|
args: args{
|
||||||
|
flags: pflag.NewFlagSet("", pflag.ContinueOnError),
|
||||||
|
config: `
|
||||||
|
endpoints:
|
||||||
|
plainHttp:
|
||||||
|
handler: http_mock
|
||||||
|
listenAddress: 0.0.0.0
|
||||||
|
ports:
|
||||||
|
- 80
|
||||||
|
- 8080
|
||||||
|
options: {}
|
||||||
|
proxy:
|
||||||
|
handler: http_proxy
|
||||||
|
listenAddress: 0.0.0.0
|
||||||
|
ports:
|
||||||
|
- 3128
|
||||||
|
options:
|
||||||
|
target:
|
||||||
|
ipAddress: 127.0.0.1
|
||||||
|
port: 80
|
||||||
|
`,
|
||||||
|
},
|
||||||
|
matcher: func(c Config) bool {
|
||||||
|
if len(c.EndpointConfigs()) < 1 {
|
||||||
|
t.Error("Expected EndpointConfigs to be set but is empty")
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Test loading of flags by adding",
|
||||||
|
args: args{
|
||||||
|
flags: func() *pflag.FlagSet {
|
||||||
|
flags := pflag.NewFlagSet("", pflag.ContinueOnError)
|
||||||
|
flags.String("plugins-directory", os.TempDir(), "")
|
||||||
|
|
||||||
|
return flags
|
||||||
|
}(),
|
||||||
|
config: ``,
|
||||||
|
},
|
||||||
|
matcher: func(c Config) bool {
|
||||||
|
if c.PluginsDir() != os.TempDir() {
|
||||||
|
t.Errorf("PluginsDir() Expected = %s, Got = %s", os.TempDir(), c.PluginsDir())
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
CreateConfig(tt.args.flags)
|
||||||
|
cfg := Instance()
|
||||||
|
if err := cfg.ReadConfigString(tt.args.config, "yaml"); (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("ReadConfig() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if !tt.matcher(cfg) {
|
||||||
|
t.Error("matcher error")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
12
pkg/config/constants.go
Normal file
12
pkg/config/constants.go
Normal file
|
@ -0,0 +1,12 @@
|
||||||
|
package config
|
||||||
|
|
||||||
|
const (
|
||||||
|
EndpointsKey = "endpoints"
|
||||||
|
OptionsKey = "options"
|
||||||
|
|
||||||
|
CurveTypeP224 CurveType = "P224"
|
||||||
|
CurveTypeP256 CurveType = "P256"
|
||||||
|
CurveTypeP384 CurveType = "P384"
|
||||||
|
CurveTypeP521 CurveType = "P521"
|
||||||
|
CurveTypeED25519 CurveType = "ED25519"
|
||||||
|
)
|
17
pkg/config/defaults.go
Normal file
17
pkg/config/defaults.go
Normal file
|
@ -0,0 +1,17 @@
|
||||||
|
package config
|
||||||
|
|
||||||
|
var (
|
||||||
|
registeredDefaults = make(map[string]interface{})
|
||||||
|
// default aliases
|
||||||
|
registeredAliases = map[string]string{
|
||||||
|
"PluginsDirectory": "plugins-directory",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
func AddDefaultValue(key string, val interface{}) {
|
||||||
|
registeredDefaults[key] = val
|
||||||
|
}
|
||||||
|
|
||||||
|
func AddAlias(alias, orig string) {
|
||||||
|
registeredAliases[alias] = orig
|
||||||
|
}
|
10
pkg/config/handler_config.go
Normal file
10
pkg/config/handler_config.go
Normal file
|
@ -0,0 +1,10 @@
|
||||||
|
package config
|
||||||
|
|
||||||
|
import "github.com/spf13/viper"
|
||||||
|
|
||||||
|
type HandlerConfig struct {
|
||||||
|
HandlerName string
|
||||||
|
Port uint16
|
||||||
|
ListenAddress string
|
||||||
|
Options *viper.Viper
|
||||||
|
}
|
|
@ -1,4 +1,4 @@
|
||||||
package api
|
package config
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/spf13/viper"
|
"github.com/spf13/viper"
|
||||||
|
@ -19,12 +19,12 @@ func Test_handlerConfig_HandlerName(t *testing.T) {
|
||||||
want string
|
want string
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "Get empty HandlerName for uninitialized struct",
|
name: "Get empty Handler for uninitialized struct",
|
||||||
fields: fields{},
|
fields: fields{},
|
||||||
want: "",
|
want: "",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Get expected HandlerName for initialized struct",
|
name: "Get expected Handler for initialized struct",
|
||||||
fields: fields{
|
fields: fields{
|
||||||
handlerName: "sampleHandler",
|
handlerName: "sampleHandler",
|
||||||
},
|
},
|
||||||
|
@ -33,14 +33,14 @@ func Test_handlerConfig_HandlerName(t *testing.T) {
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
h := handlerConfig{
|
h := HandlerConfig{
|
||||||
handlerName: tt.fields.handlerName,
|
HandlerName: tt.fields.handlerName,
|
||||||
port: tt.fields.port,
|
Port: tt.fields.port,
|
||||||
listenAddress: tt.fields.listenAddress,
|
ListenAddress: tt.fields.listenAddress,
|
||||||
options: tt.fields.options,
|
Options: tt.fields.options,
|
||||||
}
|
}
|
||||||
if got := h.HandlerName(); got != tt.want {
|
if got := h.HandlerName; got != tt.want {
|
||||||
t.Errorf("HandlerName() = %v, want %v", got, tt.want)
|
t.Errorf("Handler() = %v, want %v", got, tt.want)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -73,13 +73,13 @@ func Test_handlerConfig_ListenAddress(t *testing.T) {
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
h := handlerConfig{
|
h := HandlerConfig{
|
||||||
handlerName: tt.fields.handlerName,
|
HandlerName: tt.fields.handlerName,
|
||||||
port: tt.fields.port,
|
Port: tt.fields.port,
|
||||||
listenAddress: tt.fields.listenAddress,
|
ListenAddress: tt.fields.listenAddress,
|
||||||
options: tt.fields.options,
|
Options: tt.fields.options,
|
||||||
}
|
}
|
||||||
if got := h.ListenAddress(); got != tt.want {
|
if got := h.ListenAddress; got != tt.want {
|
||||||
t.Errorf("ListenAddress() = %v, want %v", got, tt.want)
|
t.Errorf("ListenAddress() = %v, want %v", got, tt.want)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
@ -113,13 +113,13 @@ func Test_handlerConfig_Options(t *testing.T) {
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
h := handlerConfig{
|
h := HandlerConfig{
|
||||||
handlerName: tt.fields.handlerName,
|
HandlerName: tt.fields.handlerName,
|
||||||
port: tt.fields.port,
|
Port: tt.fields.port,
|
||||||
listenAddress: tt.fields.listenAddress,
|
ListenAddress: tt.fields.listenAddress,
|
||||||
options: tt.fields.options,
|
Options: tt.fields.options,
|
||||||
}
|
}
|
||||||
if got := h.Options(); !reflect.DeepEqual(got, tt.want) {
|
if got := h.Options; !reflect.DeepEqual(got, tt.want) {
|
||||||
t.Errorf("Options() = %v, want %v", got, tt.want)
|
t.Errorf("Options() = %v, want %v", got, tt.want)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
@ -153,13 +153,13 @@ func Test_handlerConfig_Port(t *testing.T) {
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
h := handlerConfig{
|
h := HandlerConfig{
|
||||||
handlerName: tt.fields.handlerName,
|
HandlerName: tt.fields.handlerName,
|
||||||
port: tt.fields.port,
|
Port: tt.fields.port,
|
||||||
listenAddress: tt.fields.listenAddress,
|
ListenAddress: tt.fields.listenAddress,
|
||||||
options: tt.fields.options,
|
Options: tt.fields.options,
|
||||||
}
|
}
|
||||||
if got := h.Port(); got != tt.want {
|
if got := h.Port; got != tt.want {
|
||||||
t.Errorf("Port() = %v, want %v", got, tt.want)
|
t.Errorf("Port() = %v, want %v", got, tt.want)
|
||||||
}
|
}
|
||||||
})
|
})
|
25
pkg/config/multi_handler_config.go
Normal file
25
pkg/config/multi_handler_config.go
Normal file
|
@ -0,0 +1,25 @@
|
||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/spf13/viper"
|
||||||
|
)
|
||||||
|
|
||||||
|
type MultiHandlerConfig struct {
|
||||||
|
Handler string
|
||||||
|
Ports []uint16
|
||||||
|
ListenAddress string
|
||||||
|
Options *viper.Viper
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m MultiHandlerConfig) HandlerConfigs() []HandlerConfig {
|
||||||
|
configs := make([]HandlerConfig, 0)
|
||||||
|
for _, port := range m.Ports {
|
||||||
|
configs = append(configs, HandlerConfig{
|
||||||
|
HandlerName: m.Handler,
|
||||||
|
Port: port,
|
||||||
|
ListenAddress: m.ListenAddress,
|
||||||
|
Options: m.Options,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return configs
|
||||||
|
}
|
|
@ -2,7 +2,7 @@ package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/baez90/inetmock/pkg/api"
|
"github.com/baez90/inetmock/pkg/config"
|
||||||
"github.com/baez90/inetmock/pkg/logging"
|
"github.com/baez90/inetmock/pkg/logging"
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
"go.uber.org/zap"
|
"go.uber.org/zap"
|
||||||
|
@ -13,9 +13,9 @@ type dnsHandler struct {
|
||||||
dnsServer []*dns.Server
|
dnsServer []*dns.Server
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *dnsHandler) Start(config api.HandlerConfig) (err error) {
|
func (d *dnsHandler) Start(config config.HandlerConfig) (err error) {
|
||||||
options := loadFromConfig(config.Options())
|
options := loadFromConfig(config.Options)
|
||||||
addr := fmt.Sprintf("%s:%d", config.ListenAddress(), config.Port())
|
addr := fmt.Sprintf("%s:%d", config.ListenAddress, config.Port)
|
||||||
|
|
||||||
handler := ®exHandler{
|
handler := ®exHandler{
|
||||||
fallback: options.Fallback,
|
fallback: options.Fallback,
|
||||||
|
|
|
@ -2,7 +2,7 @@ package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/baez90/inetmock/pkg/api"
|
"github.com/baez90/inetmock/pkg/config"
|
||||||
"github.com/baez90/inetmock/pkg/logging"
|
"github.com/baez90/inetmock/pkg/logging"
|
||||||
"go.uber.org/zap"
|
"go.uber.org/zap"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
@ -18,9 +18,9 @@ type httpHandler struct {
|
||||||
server *http.Server
|
server *http.Server
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *httpHandler) Start(config api.HandlerConfig) (err error) {
|
func (p *httpHandler) Start(config config.HandlerConfig) (err error) {
|
||||||
options := loadFromConfig(config.Options())
|
options := loadFromConfig(config.Options)
|
||||||
addr := fmt.Sprintf("%s:%d", config.ListenAddress(), config.Port())
|
addr := fmt.Sprintf("%s:%d", config.ListenAddress, config.Port)
|
||||||
p.server = &http.Server{Addr: addr, Handler: p.router}
|
p.server = &http.Server{Addr: addr, Handler: p.router}
|
||||||
p.logger = p.logger.With(
|
p.logger = p.logger.With(
|
||||||
zap.String("address", addr),
|
zap.String("address", addr),
|
||||||
|
|
|
@ -3,6 +3,7 @@ package main
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/baez90/inetmock/pkg/api"
|
"github.com/baez90/inetmock/pkg/api"
|
||||||
|
"github.com/baez90/inetmock/pkg/config"
|
||||||
"github.com/baez90/inetmock/pkg/logging"
|
"github.com/baez90/inetmock/pkg/logging"
|
||||||
"go.uber.org/zap"
|
"go.uber.org/zap"
|
||||||
"gopkg.in/elazarl/goproxy.v1"
|
"gopkg.in/elazarl/goproxy.v1"
|
||||||
|
@ -19,9 +20,9 @@ type httpProxy struct {
|
||||||
server *http.Server
|
server *http.Server
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *httpProxy) Start(config api.HandlerConfig) (err error) {
|
func (h *httpProxy) Start(config config.HandlerConfig) (err error) {
|
||||||
options := loadFromConfig(config.Options())
|
options := loadFromConfig(config.Options)
|
||||||
addr := fmt.Sprintf("%s:%d", config.ListenAddress(), config.Port())
|
addr := fmt.Sprintf("%s:%d", config.ListenAddress, config.Port)
|
||||||
h.server = &http.Server{Addr: addr, Handler: h.proxy}
|
h.server = &http.Server{Addr: addr, Handler: h.proxy}
|
||||||
h.logger = h.logger.With(
|
h.logger = h.logger.With(
|
||||||
zap.String("address", addr),
|
zap.String("address", addr),
|
||||||
|
|
|
@ -5,6 +5,7 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/baez90/inetmock/pkg/api"
|
"github.com/baez90/inetmock/pkg/api"
|
||||||
"github.com/baez90/inetmock/pkg/cert"
|
"github.com/baez90/inetmock/pkg/cert"
|
||||||
|
"github.com/baez90/inetmock/pkg/config"
|
||||||
"github.com/baez90/inetmock/pkg/logging"
|
"github.com/baez90/inetmock/pkg/logging"
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
"go.uber.org/zap"
|
"go.uber.org/zap"
|
||||||
|
@ -27,9 +28,9 @@ type tlsInterceptor struct {
|
||||||
currentConnections map[uuid.UUID]*proxyConn
|
currentConnections map[uuid.UUID]*proxyConn
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *tlsInterceptor) Start(config api.HandlerConfig) (err error) {
|
func (t *tlsInterceptor) Start(config config.HandlerConfig) (err error) {
|
||||||
t.options = loadFromConfig(config.Options())
|
t.options = loadFromConfig(config.Options)
|
||||||
addr := fmt.Sprintf("%s:%d", config.ListenAddress(), config.Port())
|
addr := fmt.Sprintf("%s:%d", config.ListenAddress, config.Port)
|
||||||
|
|
||||||
t.logger = t.logger.With(
|
t.logger = t.logger.With(
|
||||||
zap.String("address", addr),
|
zap.String("address", addr),
|
||||||
|
|
|
@ -1,77 +0,0 @@
|
||||||
endpoints:
|
|
||||||
plainHttp:
|
|
||||||
handler: http_proxy
|
|
||||||
listenAddress: 0.0.0.0
|
|
||||||
port: 80
|
|
||||||
options:
|
|
||||||
rules:
|
|
||||||
- pattern: ".*\\.(?i)exe"
|
|
||||||
response: ./assets/fakeFiles/sample.exe
|
|
||||||
- pattern: ".*\\.(?i)(jpg|jpeg)"
|
|
||||||
response: ./assets/fakeFiles/default.jpg
|
|
||||||
- pattern: ".*\\.(?i)png"
|
|
||||||
response: ./assets/fakeFiles/default.png
|
|
||||||
- pattern: ".*\\.(?i)gif"
|
|
||||||
response: ./assets/fakeFiles/default.gif
|
|
||||||
- pattern: ".*\\.(?i)ico"
|
|
||||||
response: ./assets/fakeFiles/default.ico
|
|
||||||
- pattern: ".*\\.(?i)txt"
|
|
||||||
response: ./assets/fakeFiles/default.txt
|
|
||||||
- pattern: ".*"
|
|
||||||
response: ./assets/fakeFiles/default.html
|
|
||||||
httpsDowngrade:
|
|
||||||
handler: tls_interceptor
|
|
||||||
listenAddress: 0.0.0.0
|
|
||||||
port: 443
|
|
||||||
options:
|
|
||||||
ecdsaCurve: P256
|
|
||||||
validity:
|
|
||||||
ca:
|
|
||||||
notBeforeRelative: 17520h
|
|
||||||
notAfterRelative: 17520h
|
|
||||||
domain:
|
|
||||||
notBeforeRelative: 168h
|
|
||||||
notAfterRelative: 168h
|
|
||||||
rootCaCert:
|
|
||||||
publicKey: ./ca.pem
|
|
||||||
privateKey: ./ca.key
|
|
||||||
certCachePath: /tmp/inetmock/
|
|
||||||
target:
|
|
||||||
ipAddress: 127.0.0.1
|
|
||||||
port: 80
|
|
||||||
plainDns:
|
|
||||||
handler: dns_mock
|
|
||||||
listenAddress: 0.0.0.0
|
|
||||||
port: 53
|
|
||||||
options:
|
|
||||||
rules:
|
|
||||||
- pattern: "www.golem.de"
|
|
||||||
response: 77.247.84.129
|
|
||||||
- pattern: ".*\\.google\\.com"
|
|
||||||
response: 1.1.1.1
|
|
||||||
- pattern: ".*\\.reddit\\.com"
|
|
||||||
response: 2.2.2.2
|
|
||||||
fallback:
|
|
||||||
strategy: incremental
|
|
||||||
args:
|
|
||||||
startIP: 10.0.0.0
|
|
||||||
dnsOverTlsDowngrade:
|
|
||||||
handler: tls_interceptor
|
|
||||||
listenAddress: 0.0.0.0
|
|
||||||
port: 853
|
|
||||||
options:
|
|
||||||
ecdsaCurve: P256
|
|
||||||
validity:
|
|
||||||
ca:
|
|
||||||
notBeforeRelative: 17520h
|
|
||||||
notAfterRelative: 17520h
|
|
||||||
domain:
|
|
||||||
notBeforeRelative: 168h
|
|
||||||
notAfterRelative: 168h
|
|
||||||
rootCaCert:
|
|
||||||
publicKey: ./ca.pem
|
|
||||||
privateKey: ./ca.key
|
|
||||||
certCachePath: /tmp/inetmock/
|
|
||||||
target:
|
|
||||||
ipAddress: 127.0.0.1
|
|
||||||
port: 53
|
|
Loading…
Reference in a new issue