Improve config and startup handling
- use `Unmarshal` method of viper - move config loading, defaulting and stuff to config package - move serve to an extra command and move keep only global flags in root command - add startup logic to onInit method in root command - update mocks - clean config structs - adopt changes in plugins - update default config
This commit is contained in:
parent
1775d3dbc6
commit
91f0cf6963
40 changed files with 561 additions and 975 deletions
8
Makefile
8
Makefile
|
@ -2,7 +2,8 @@ VERSION = $(shell git describe --dirty --tags --always)
|
|||
DIR = $(dir $(realpath $(firstword $(MAKEFILE_LIST))))
|
||||
BUILD_PATH = $(DIR)/main.go
|
||||
PKGS = $(shell go list ./...)
|
||||
TEST_PKGS = $(shell find . -type f -name "*_test.go" -not -path "./plugins/*" -not -path "*/mock/*" -printf '%h\n' | sort -u)
|
||||
TEST_PKGS = $(shell go list ./...)
|
||||
PROTO_FILES = $(shell find $(DIR)api/ -type f -name "*.proto")
|
||||
GOARGS = GOOS=linux GOARCH=amd64
|
||||
GO_BUILD_ARGS = -ldflags="-w -s"
|
||||
GO_CONTAINER_BUILD_ARGS = -ldflags="-w -s" -a -installsuffix cgo
|
||||
|
@ -14,7 +15,7 @@ DEBUG_ARGS?= --development-logs=true
|
|||
CONTAINER_BUILDER ?= podman
|
||||
DOCKER_IMAGE ?= inetmock
|
||||
|
||||
.PHONY: clean all format deps update-deps compile debug generate snapshot-release test cli-cover-report html-cover-report plugins $(PLUGINS) $(GO_GEN_FILES)
|
||||
.PHONY: clean all format deps update-deps compile debug generate protoc snapshot-release test cli-cover-report html-cover-report plugins $(PLUGINS) $(GO_GEN_FILES)
|
||||
all: clean format compile test plugins
|
||||
|
||||
clean:
|
||||
|
@ -55,6 +56,9 @@ debug:
|
|||
generate:
|
||||
@go generate ./...
|
||||
|
||||
protoc:
|
||||
@protoc --proto_path $(DIR)api/ --go_out=plugins=grpc:internal/grpc --go_opt=paths=source_relative $(shell find $(DIR)api/ -type f -name "*.proto")
|
||||
|
||||
snapshot-release:
|
||||
@goreleaser release --snapshot --skip-publish --rm-dist
|
||||
|
||||
|
|
|
@ -4,6 +4,7 @@ import (
|
|||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"github.com/baez90/inetmock/pkg/cert"
|
||||
config2 "github.com/baez90/inetmock/pkg/config"
|
||||
"github.com/baez90/inetmock/pkg/logging"
|
||||
"github.com/spf13/cobra"
|
||||
"go.uber.org/zap"
|
||||
|
@ -77,11 +78,11 @@ func runGenerateCA(cmd *cobra.Command, args []string) {
|
|||
zap.String(generateCACertOutPath, certOutPath),
|
||||
)
|
||||
|
||||
generator := cert.NewDefaultGenerator(cert.Options{
|
||||
generator := cert.NewDefaultGenerator(config2.CertOptions{
|
||||
CertCachePath: certOutPath,
|
||||
Curve: cert.CurveType(curveName),
|
||||
Validity: cert.ValidityByPurpose{
|
||||
CA: cert.ValidityDuration{
|
||||
Curve: config2.CurveType(curveName),
|
||||
Validity: config2.ValidityByPurpose{
|
||||
CA: config2.ValidityDuration{
|
||||
NotAfterRelative: notAfter,
|
||||
NotBeforeRelative: notBefore,
|
||||
},
|
||||
|
|
|
@ -1,70 +0,0 @@
|
|||
package cmd
|
||||
|
||||
import (
|
||||
"github.com/baez90/inetmock/internal/endpoints"
|
||||
"github.com/baez90/inetmock/internal/plugins"
|
||||
"github.com/baez90/inetmock/pkg/api"
|
||||
"github.com/baez90/inetmock/pkg/logging"
|
||||
"github.com/baez90/inetmock/pkg/path"
|
||||
"github.com/spf13/viper"
|
||||
"go.uber.org/zap"
|
||||
"os"
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
appIsInitialized = false
|
||||
)
|
||||
|
||||
func initApp() (err error) {
|
||||
if appIsInitialized {
|
||||
return
|
||||
}
|
||||
appIsInitialized = true
|
||||
logging.ConfigureLogging(
|
||||
logging.ParseLevel(logLevel),
|
||||
developmentLogs,
|
||||
map[string]interface{}{"cwd": path.WorkingDirectory()},
|
||||
)
|
||||
logger, _ = logging.CreateLogger()
|
||||
registry := plugins.Registry()
|
||||
endpointManager = endpoints.NewEndpointManager(logger)
|
||||
|
||||
rootCmd.Flags().ParseErrorsWhitelist.UnknownFlags = false
|
||||
_ = rootCmd.ParseFlags(os.Args)
|
||||
|
||||
if err = appConfig.ReadConfig(configFilePath); err != nil {
|
||||
logger.Error(
|
||||
"unrecoverable error occurred during reading the config file",
|
||||
zap.Error(err),
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
viperInst := viper.GetViper()
|
||||
pluginDir := viperInst.GetString("plugins-directory")
|
||||
|
||||
if err = api.InitServices(viperInst, logger); err != nil {
|
||||
logger.Error(
|
||||
"failed to initialize app services",
|
||||
zap.Error(err),
|
||||
)
|
||||
}
|
||||
|
||||
pluginLoadStartTime := time.Now()
|
||||
if err = registry.LoadPlugins(pluginDir); err != nil {
|
||||
logger.Error("Failed to load plugins",
|
||||
zap.String("pluginsDirectory", pluginDir),
|
||||
zap.Error(err),
|
||||
)
|
||||
}
|
||||
pluginLoadDuration := time.Since(pluginLoadStartTime)
|
||||
logger.Info(
|
||||
"loading plugins completed",
|
||||
zap.Duration("pluginLoadDuration", pluginLoadDuration),
|
||||
)
|
||||
|
||||
pluginsCmd.AddCommand(registry.PluginCommands()...)
|
||||
|
||||
return
|
||||
}
|
|
@ -5,10 +5,6 @@ import (
|
|||
)
|
||||
|
||||
var (
|
||||
pluginsCmd *cobra.Command
|
||||
)
|
||||
|
||||
func init() {
|
||||
pluginsCmd = &cobra.Command{
|
||||
Use: "plugins",
|
||||
Short: "Use the plugins prefix to interact with commands that are provided by plugins",
|
||||
|
@ -18,5 +14,4 @@ The easiest way to explore what commands are available is to start with 'inetmoc
|
|||
This help page contains a list of available sub-commands starting with the name of the plugin as a prefix.
|
||||
`,
|
||||
}
|
||||
rootCmd.AddCommand(pluginsCmd, generateCaCmd)
|
||||
}
|
||||
)
|
||||
|
|
|
@ -1,77 +1,91 @@
|
|||
package cmd
|
||||
|
||||
import (
|
||||
"github.com/baez90/inetmock/internal/config"
|
||||
"github.com/baez90/inetmock/internal/endpoints"
|
||||
"github.com/baez90/inetmock/internal/plugins"
|
||||
"github.com/baez90/inetmock/pkg/api"
|
||||
"github.com/baez90/inetmock/pkg/config"
|
||||
"github.com/baez90/inetmock/pkg/logging"
|
||||
"github.com/baez90/inetmock/pkg/path"
|
||||
"github.com/spf13/cobra"
|
||||
"github.com/spf13/viper"
|
||||
"go.uber.org/zap"
|
||||
"os"
|
||||
"os/signal"
|
||||
"strings"
|
||||
"syscall"
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
logger logging.Logger
|
||||
rootCmd = cobra.Command{
|
||||
Use: "",
|
||||
Short: "INetMock is lightweight internet mock",
|
||||
Run: startInetMock,
|
||||
}
|
||||
rootCmd *cobra.Command
|
||||
|
||||
pluginsDirectory string
|
||||
configFilePath string
|
||||
logLevel string
|
||||
developmentLogs bool
|
||||
handlers []api.ProtocolHandler
|
||||
appConfig = config.CreateConfig()
|
||||
endpointManager endpoints.EndpointManager
|
||||
)
|
||||
|
||||
func init() {
|
||||
rootCmd.PersistentFlags().String("plugins-directory", "", "Directory where plugins should be loaded from")
|
||||
cobra.OnInitialize(onInit)
|
||||
rootCmd = &cobra.Command{
|
||||
Use: "",
|
||||
Short: "INetMock is lightweight internet mock",
|
||||
}
|
||||
|
||||
rootCmd.PersistentFlags().StringVar(&pluginsDirectory, "plugins-directory", "", "Directory where plugins should be loaded from")
|
||||
rootCmd.PersistentFlags().StringVar(&configFilePath, "config", "", "Path to config file that should be used")
|
||||
rootCmd.PersistentFlags().StringVar(&logLevel, "log-level", "info", "logging level to use")
|
||||
rootCmd.PersistentFlags().BoolVar(&developmentLogs, "development-logs", false, "Enable development mode logs")
|
||||
|
||||
appConfig.InitConfig(rootCmd.PersistentFlags())
|
||||
rootCmd.AddCommand(
|
||||
serveCmd,
|
||||
generateCaCmd,
|
||||
pluginsCmd,
|
||||
)
|
||||
}
|
||||
|
||||
func startInetMock(cmd *cobra.Command, args []string) {
|
||||
for endpointName := range viper.GetStringMap(config.EndpointsKey) {
|
||||
handlerSubConfig := viper.Sub(strings.Join([]string{config.EndpointsKey, endpointName}, "."))
|
||||
handlerConfig := config.CreateMultiHandlerConfig(handlerSubConfig)
|
||||
if err := endpointManager.CreateEndpoint(endpointName, handlerConfig); err != nil {
|
||||
logger.Warn(
|
||||
"error occurred while creating endpoint",
|
||||
zap.String("endpointName", endpointName),
|
||||
zap.String("handlerName", handlerConfig.HandlerName()),
|
||||
func onInit() {
|
||||
logging.ConfigureLogging(
|
||||
logging.ParseLevel(logLevel),
|
||||
developmentLogs,
|
||||
map[string]interface{}{"cwd": path.WorkingDirectory()},
|
||||
)
|
||||
|
||||
logger, _ = logging.CreateLogger()
|
||||
config.CreateConfig(rootCmd.Flags())
|
||||
appConfig := config.Instance()
|
||||
|
||||
if err := appConfig.ReadConfig(configFilePath); err != nil {
|
||||
logger.Error(
|
||||
"failed to read config file",
|
||||
zap.Error(err),
|
||||
)
|
||||
}
|
||||
|
||||
if err := api.InitServices(appConfig, logger); err != nil {
|
||||
logger.Error(
|
||||
"failed to initialize app services",
|
||||
zap.Error(err),
|
||||
)
|
||||
}
|
||||
|
||||
endpointManager.StartEndpoints()
|
||||
registry := plugins.Registry()
|
||||
|
||||
signalChannel := make(chan os.Signal, 1)
|
||||
signal.Notify(signalChannel, syscall.SIGINT, syscall.SIGTERM, syscall.SIGQUIT)
|
||||
|
||||
// block until canceled
|
||||
s := <-signalChannel
|
||||
cfg := config.Instance()
|
||||
|
||||
pluginLoadStartTime := time.Now()
|
||||
if err := registry.LoadPlugins(cfg.PluginsDir()); err != nil {
|
||||
logger.Error("Failed to load plugins",
|
||||
zap.String("pluginsDirectory", cfg.PluginsDir()),
|
||||
zap.Error(err),
|
||||
)
|
||||
}
|
||||
pluginLoadDuration := time.Since(pluginLoadStartTime)
|
||||
logger.Info(
|
||||
"got signal to quit",
|
||||
zap.String("signal", s.String()),
|
||||
"loading plugins completed",
|
||||
zap.Duration("pluginLoadDuration", pluginLoadDuration),
|
||||
)
|
||||
|
||||
endpointManager.ShutdownEndpoints()
|
||||
pluginsCmd.AddCommand(registry.PluginCommands()...)
|
||||
|
||||
}
|
||||
|
||||
func ExecuteRootCommand() error {
|
||||
if err := initApp(); err != nil {
|
||||
return err
|
||||
}
|
||||
return rootCmd.Execute()
|
||||
}
|
||||
|
|
54
internal/cmd/serve.go
Normal file
54
internal/cmd/serve.go
Normal file
|
@ -0,0 +1,54 @@
|
|||
package cmd
|
||||
|
||||
import (
|
||||
"github.com/baez90/inetmock/internal/endpoints"
|
||||
"github.com/baez90/inetmock/pkg/config"
|
||||
"github.com/spf13/cobra"
|
||||
"go.uber.org/zap"
|
||||
"os"
|
||||
"os/signal"
|
||||
"strings"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
var (
|
||||
endpointManager endpoints.EndpointManager
|
||||
serveCmd = &cobra.Command{
|
||||
Use: "serve",
|
||||
Short: "Starts the INetMock server",
|
||||
Long: ``,
|
||||
Run: startINetMock,
|
||||
}
|
||||
)
|
||||
|
||||
func startINetMock(_ *cobra.Command, _ []string) {
|
||||
endpointManager = endpoints.NewEndpointManager(logger)
|
||||
cfg := config.Instance()
|
||||
for endpointName, endpointHandler := range cfg.EndpointConfigs() {
|
||||
handlerSubConfig := cfg.Viper().Sub(strings.Join([]string{config.EndpointsKey, endpointName, config.OptionsKey}, "."))
|
||||
endpointHandler.Options = handlerSubConfig
|
||||
if err := endpointManager.CreateEndpoint(endpointName, endpointHandler); err != nil {
|
||||
logger.Warn(
|
||||
"error occurred while creating endpoint",
|
||||
zap.String("endpointName", endpointName),
|
||||
zap.String("handlerName", endpointHandler.Handler),
|
||||
zap.Error(err),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
endpointManager.StartEndpoints()
|
||||
|
||||
signalChannel := make(chan os.Signal, 1)
|
||||
signal.Notify(signalChannel, syscall.SIGINT, syscall.SIGTERM, syscall.SIGQUIT)
|
||||
|
||||
// block until canceled
|
||||
s := <-signalChannel
|
||||
|
||||
logger.Info(
|
||||
"got signal to quit",
|
||||
zap.String("signal", s.String()),
|
||||
)
|
||||
|
||||
endpointManager.ShutdownEndpoints()
|
||||
}
|
|
@ -1,6 +0,0 @@
|
|||
package config
|
||||
|
||||
const (
|
||||
EndpointsKey = "endpoints"
|
||||
OptionsKey = "options"
|
||||
)
|
|
@ -1,55 +0,0 @@
|
|||
package config
|
||||
|
||||
import (
|
||||
"github.com/baez90/inetmock/pkg/logging"
|
||||
"github.com/baez90/inetmock/pkg/path"
|
||||
"github.com/spf13/pflag"
|
||||
"github.com/spf13/viper"
|
||||
"go.uber.org/zap"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func CreateConfig() Config {
|
||||
logger, _ := logging.CreateLogger()
|
||||
return &config{
|
||||
logger: logger.Named("Config"),
|
||||
}
|
||||
}
|
||||
|
||||
type Config interface {
|
||||
InitConfig(flags *pflag.FlagSet)
|
||||
ReadConfig(configFilePath string) error
|
||||
}
|
||||
|
||||
type config struct {
|
||||
logger logging.Logger
|
||||
}
|
||||
|
||||
func (c config) InitConfig(flags *pflag.FlagSet) {
|
||||
viper.SetConfigName("config")
|
||||
viper.SetConfigType("yaml")
|
||||
viper.AddConfigPath("/etc/inetmock/")
|
||||
viper.AddConfigPath("$HOME/.inetmock")
|
||||
viper.AddConfigPath(".")
|
||||
viper.SetEnvPrefix("INetMock")
|
||||
_ = viper.BindPFlags(flags)
|
||||
viper.SetEnvKeyReplacer(strings.NewReplacer("-", "_"))
|
||||
viper.AutomaticEnv()
|
||||
}
|
||||
|
||||
func (c *config) ReadConfig(configFilePath string) (err error) {
|
||||
if configFilePath != "" && path.FileExists(configFilePath) {
|
||||
c.logger.Info(
|
||||
"loading config from passed config file path",
|
||||
zap.String("configFilePath", configFilePath),
|
||||
)
|
||||
viper.SetConfigFile(configFilePath)
|
||||
}
|
||||
if err = viper.ReadInConfig(); err != nil {
|
||||
if _, ok := err.(viper.ConfigFileNotFoundError); ok {
|
||||
err = nil
|
||||
c.logger.Warn("failed to load config")
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
|
@ -1,49 +0,0 @@
|
|||
package config
|
||||
|
||||
import (
|
||||
"github.com/baez90/inetmock/pkg/api"
|
||||
"github.com/spf13/viper"
|
||||
)
|
||||
|
||||
type MultiHandlerConfig interface {
|
||||
HandlerName() string
|
||||
ListenAddress() string
|
||||
Ports() []uint16
|
||||
Options() *viper.Viper
|
||||
HandlerConfigs() []api.HandlerConfig
|
||||
}
|
||||
|
||||
type multiHandlerConfig struct {
|
||||
handlerName string
|
||||
ports []uint16
|
||||
listenAddress string
|
||||
options *viper.Viper
|
||||
}
|
||||
|
||||
func NewMultiHandlerConfig(handlerName string, ports []uint16, listenAddress string, options *viper.Viper) MultiHandlerConfig {
|
||||
return &multiHandlerConfig{handlerName: handlerName, ports: ports, listenAddress: listenAddress, options: options}
|
||||
}
|
||||
|
||||
func (m multiHandlerConfig) HandlerName() string {
|
||||
return m.handlerName
|
||||
}
|
||||
|
||||
func (m multiHandlerConfig) ListenAddress() string {
|
||||
return m.listenAddress
|
||||
}
|
||||
|
||||
func (m multiHandlerConfig) Ports() []uint16 {
|
||||
return m.ports
|
||||
}
|
||||
|
||||
func (m multiHandlerConfig) Options() *viper.Viper {
|
||||
return m.options
|
||||
}
|
||||
|
||||
func (m multiHandlerConfig) HandlerConfigs() []api.HandlerConfig {
|
||||
configs := make([]api.HandlerConfig, 0)
|
||||
for _, port := range m.ports {
|
||||
configs = append(configs, api.NewHandlerConfig(m.handlerName, port, m.listenAddress, m.options))
|
||||
}
|
||||
return configs
|
||||
}
|
|
@ -1,226 +0,0 @@
|
|||
package config
|
||||
|
||||
import (
|
||||
"github.com/baez90/inetmock/pkg/api"
|
||||
"github.com/spf13/viper"
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func Test_multiHandlerConfig_HandlerConfigs(t *testing.T) {
|
||||
type fields struct {
|
||||
handlerName string
|
||||
ports []uint16
|
||||
listenAddress string
|
||||
options *viper.Viper
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
want []api.HandlerConfig
|
||||
}{
|
||||
{
|
||||
name: "Get empty array if no ports are set",
|
||||
fields: fields{},
|
||||
want: make([]api.HandlerConfig, 0),
|
||||
},
|
||||
{
|
||||
name: "Get a single handler config if only one port is set",
|
||||
fields: fields{
|
||||
handlerName: "sampleHandler",
|
||||
ports: []uint16{80},
|
||||
listenAddress: "0.0.0.0",
|
||||
options: nil,
|
||||
},
|
||||
want: []api.HandlerConfig{
|
||||
api.NewHandlerConfig("sampleHandler", 80, "0.0.0.0", nil),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Get multiple handler configs if only one port is set",
|
||||
fields: fields{
|
||||
handlerName: "sampleHandler",
|
||||
ports: []uint16{80, 8080},
|
||||
listenAddress: "0.0.0.0",
|
||||
options: nil,
|
||||
},
|
||||
want: []api.HandlerConfig{
|
||||
api.NewHandlerConfig("sampleHandler", 80, "0.0.0.0", nil),
|
||||
api.NewHandlerConfig("sampleHandler", 8080, "0.0.0.0", nil),
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
m := multiHandlerConfig{
|
||||
handlerName: tt.fields.handlerName,
|
||||
ports: tt.fields.ports,
|
||||
listenAddress: tt.fields.listenAddress,
|
||||
options: tt.fields.options,
|
||||
}
|
||||
if got := m.HandlerConfigs(); !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("HandlerConfigs() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_multiHandlerConfig_HandlerName(t *testing.T) {
|
||||
type fields struct {
|
||||
handlerName string
|
||||
ports []uint16
|
||||
listenAddress string
|
||||
options *viper.Viper
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "Get empty handler name for uninitialized struct",
|
||||
fields: fields{},
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
name: "Get expected handler name for initialized struct",
|
||||
fields: fields{
|
||||
handlerName: "sampleHandler",
|
||||
},
|
||||
want: "sampleHandler",
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
m := multiHandlerConfig{
|
||||
handlerName: tt.fields.handlerName,
|
||||
ports: tt.fields.ports,
|
||||
listenAddress: tt.fields.listenAddress,
|
||||
options: tt.fields.options,
|
||||
}
|
||||
if got := m.HandlerName(); got != tt.want {
|
||||
t.Errorf("HandlerName() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_multiHandlerConfig_ListenAddress(t *testing.T) {
|
||||
type fields struct {
|
||||
handlerName string
|
||||
ports []uint16
|
||||
listenAddress string
|
||||
options *viper.Viper
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "Get empty ListenAddress for uninitialized struct",
|
||||
fields: fields{},
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
name: "Get expected ListenAddress for initialized struct",
|
||||
fields: fields{
|
||||
listenAddress: "0.0.0.0",
|
||||
},
|
||||
want: "0.0.0.0",
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
m := multiHandlerConfig{
|
||||
handlerName: tt.fields.handlerName,
|
||||
ports: tt.fields.ports,
|
||||
listenAddress: tt.fields.listenAddress,
|
||||
options: tt.fields.options,
|
||||
}
|
||||
if got := m.ListenAddress(); got != tt.want {
|
||||
t.Errorf("ListenAddress() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_multiHandlerConfig_Options(t *testing.T) {
|
||||
type fields struct {
|
||||
handlerName string
|
||||
ports []uint16
|
||||
listenAddress string
|
||||
options *viper.Viper
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
want *viper.Viper
|
||||
}{
|
||||
{
|
||||
name: "Get nil Options for uninitialized struct",
|
||||
fields: fields{},
|
||||
want: nil,
|
||||
},
|
||||
{
|
||||
name: "Get expected Options for initialized struct",
|
||||
fields: fields{
|
||||
options: viper.New(),
|
||||
},
|
||||
want: viper.New(),
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
m := multiHandlerConfig{
|
||||
handlerName: tt.fields.handlerName,
|
||||
ports: tt.fields.ports,
|
||||
listenAddress: tt.fields.listenAddress,
|
||||
options: tt.fields.options,
|
||||
}
|
||||
if got := m.Options(); !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("Options() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_multiHandlerConfig_Ports(t *testing.T) {
|
||||
type fields struct {
|
||||
handlerName string
|
||||
ports []uint16
|
||||
listenAddress string
|
||||
options *viper.Viper
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
want []uint16
|
||||
}{
|
||||
{
|
||||
name: "Get empty Ports for uninitialized struct",
|
||||
fields: fields{},
|
||||
want: nil,
|
||||
},
|
||||
{
|
||||
name: "Get expected Ports for initialized struct",
|
||||
fields: fields{
|
||||
ports: []uint16{80, 8080},
|
||||
},
|
||||
want: []uint16{80, 8080},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
m := multiHandlerConfig{
|
||||
handlerName: tt.fields.handlerName,
|
||||
ports: tt.fields.ports,
|
||||
listenAddress: tt.fields.listenAddress,
|
||||
options: tt.fields.options,
|
||||
}
|
||||
if got := m.Ports(); !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("Ports() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
|
@ -1,35 +0,0 @@
|
|||
package config
|
||||
|
||||
import (
|
||||
"github.com/spf13/viper"
|
||||
)
|
||||
|
||||
const (
|
||||
pluginConfigKey = "handler"
|
||||
listenAddressConfigKey = "listenAddress"
|
||||
portConfigKey = "port"
|
||||
portsConfigKey = "ports"
|
||||
)
|
||||
|
||||
func CreateMultiHandlerConfig(handlerConfig *viper.Viper) MultiHandlerConfig {
|
||||
return NewMultiHandlerConfig(
|
||||
handlerConfig.GetString(pluginConfigKey),
|
||||
portsFromConfig(handlerConfig),
|
||||
handlerConfig.GetString(listenAddressConfigKey),
|
||||
handlerConfig.Sub(OptionsKey),
|
||||
)
|
||||
}
|
||||
|
||||
func portsFromConfig(handlerConfig *viper.Viper) (ports []uint16) {
|
||||
if portsInt := handlerConfig.GetIntSlice(portsConfigKey); len(portsInt) > 0 {
|
||||
for _, port := range portsInt {
|
||||
ports = append(ports, uint16(port))
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if portInt := handlerConfig.GetInt(portConfigKey); portInt > 0 {
|
||||
ports = append(ports, uint16(portInt))
|
||||
}
|
||||
return
|
||||
}
|
|
@ -1,145 +0,0 @@
|
|||
package config
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"github.com/spf13/viper"
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestCreateMultiHandlerConfig(t *testing.T) {
|
||||
type args struct {
|
||||
handlerConfig *viper.Viper
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want MultiHandlerConfig
|
||||
}{
|
||||
{
|
||||
name: "Get simple multiHandlerConfig from config",
|
||||
args: args{
|
||||
handlerConfig: configFromString(`
|
||||
handler: sampleHandler
|
||||
listenAddress: 0.0.0.0
|
||||
ports:
|
||||
- 80
|
||||
- 8080
|
||||
options: {}
|
||||
`),
|
||||
},
|
||||
want: &multiHandlerConfig{
|
||||
handlerName: "sampleHandler",
|
||||
ports: []uint16{80, 8080},
|
||||
listenAddress: "0.0.0.0",
|
||||
options: viper.New(),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Get more complex multiHandlerConfig from config",
|
||||
args: args{
|
||||
handlerConfig: configFromString(`
|
||||
handler: sampleHandler
|
||||
listenAddress: 0.0.0.0
|
||||
ports:
|
||||
- 80
|
||||
- 8080
|
||||
options:
|
||||
optionA: asdf
|
||||
optionB: as1234
|
||||
`),
|
||||
},
|
||||
want: &multiHandlerConfig{
|
||||
handlerName: "sampleHandler",
|
||||
ports: []uint16{80, 8080},
|
||||
listenAddress: "0.0.0.0",
|
||||
options: configFromString(`
|
||||
nesting:
|
||||
optionA: asdf
|
||||
optionB: as1234
|
||||
`).Sub("nesting"),
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := CreateMultiHandlerConfig(tt.args.handlerConfig); !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("CreateMultiHandlerConfig() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_portsFromConfig(t *testing.T) {
|
||||
type args struct {
|
||||
handlerConfig *viper.Viper
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
wantPorts []uint16
|
||||
}{
|
||||
{
|
||||
name: "Empty array if config value is not set",
|
||||
args: args{
|
||||
handlerConfig: viper.New(),
|
||||
},
|
||||
wantPorts: nil,
|
||||
},
|
||||
{
|
||||
name: "Array of one if `port` is set",
|
||||
args: args{
|
||||
handlerConfig: configFromString(`
|
||||
port: 80
|
||||
`),
|
||||
},
|
||||
wantPorts: []uint16{80},
|
||||
},
|
||||
{
|
||||
name: "Array of one if `ports` is set as array",
|
||||
args: args{
|
||||
handlerConfig: configFromString(`
|
||||
ports:
|
||||
- 80
|
||||
`),
|
||||
},
|
||||
wantPorts: []uint16{80},
|
||||
},
|
||||
{
|
||||
name: "Array of two if `ports` is set as array",
|
||||
args: args{
|
||||
handlerConfig: configFromString(`
|
||||
ports:
|
||||
- 80
|
||||
- 8080
|
||||
`),
|
||||
},
|
||||
wantPorts: []uint16{80, 8080},
|
||||
},
|
||||
{
|
||||
name: "Array of two if `port` is set as array",
|
||||
args: args{
|
||||
handlerConfig: configFromString(`
|
||||
ports:
|
||||
- 80
|
||||
- 8080
|
||||
`),
|
||||
},
|
||||
wantPorts: []uint16{80, 8080},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if gotPorts := portsFromConfig(tt.args.handlerConfig); !reflect.DeepEqual(gotPorts, tt.wantPorts) {
|
||||
t.Errorf("portsFromConfig() = %v, want %v", gotPorts, tt.wantPorts)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func configFromString(yaml string) (config *viper.Viper) {
|
||||
config = viper.New()
|
||||
config.SetConfigType("yaml")
|
||||
_ = config.ReadConfig(bytes.NewBufferString(yaml))
|
||||
return
|
||||
}
|
|
@ -3,6 +3,7 @@ package endpoints
|
|||
|
||||
import (
|
||||
"github.com/baez90/inetmock/pkg/api"
|
||||
"github.com/baez90/inetmock/pkg/config"
|
||||
)
|
||||
|
||||
type Endpoint interface {
|
||||
|
@ -14,7 +15,7 @@ type Endpoint interface {
|
|||
type endpoint struct {
|
||||
name string
|
||||
handler api.ProtocolHandler
|
||||
config api.HandlerConfig
|
||||
config config.HandlerConfig
|
||||
}
|
||||
|
||||
func (e endpoint) Name() string {
|
||||
|
|
|
@ -2,8 +2,8 @@ package endpoints
|
|||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/baez90/inetmock/internal/config"
|
||||
"github.com/baez90/inetmock/internal/plugins"
|
||||
config2 "github.com/baez90/inetmock/pkg/config"
|
||||
"github.com/baez90/inetmock/pkg/logging"
|
||||
"go.uber.org/zap"
|
||||
"sync"
|
||||
|
@ -13,7 +13,7 @@ import (
|
|||
type EndpointManager interface {
|
||||
RegisteredEndpoints() []Endpoint
|
||||
StartedEndpoints() []Endpoint
|
||||
CreateEndpoint(name string, multiHandlerConfig config.MultiHandlerConfig) error
|
||||
CreateEndpoint(name string, multiHandlerConfig config2.MultiHandlerConfig) error
|
||||
StartEndpoints()
|
||||
ShutdownEndpoints()
|
||||
}
|
||||
|
@ -40,16 +40,16 @@ func (e endpointManager) StartedEndpoints() []Endpoint {
|
|||
return e.properlyStartedEndpoints
|
||||
}
|
||||
|
||||
func (e *endpointManager) CreateEndpoint(name string, multiHandlerConfig config.MultiHandlerConfig) error {
|
||||
func (e *endpointManager) CreateEndpoint(name string, multiHandlerConfig config2.MultiHandlerConfig) error {
|
||||
for _, handlerConfig := range multiHandlerConfig.HandlerConfigs() {
|
||||
if handler, ok := e.registry.HandlerForName(multiHandlerConfig.HandlerName()); ok {
|
||||
if handler, ok := e.registry.HandlerForName(multiHandlerConfig.Handler); ok {
|
||||
e.registeredEndpoints = append(e.registeredEndpoints, &endpoint{
|
||||
name: name,
|
||||
handler: handler,
|
||||
config: handlerConfig,
|
||||
})
|
||||
} else {
|
||||
return fmt.Errorf("no matching handler registered for names %s", multiHandlerConfig.HandlerName())
|
||||
return fmt.Errorf("no matching handler registered for names %s", multiHandlerConfig.Handler)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -1,11 +1,11 @@
|
|||
package endpoints
|
||||
|
||||
import (
|
||||
"github.com/baez90/inetmock/internal/config"
|
||||
api_mock "github.com/baez90/inetmock/internal/mock/api"
|
||||
logging_mock "github.com/baez90/inetmock/internal/mock/logging"
|
||||
plugins_mock "github.com/baez90/inetmock/internal/mock/plugins"
|
||||
"github.com/baez90/inetmock/internal/plugins"
|
||||
config2 "github.com/baez90/inetmock/pkg/config"
|
||||
"github.com/baez90/inetmock/pkg/logging"
|
||||
"github.com/golang/mock/gomock"
|
||||
"testing"
|
||||
|
@ -20,7 +20,7 @@ func Test_endpointManager_CreateEndpoint(t *testing.T) {
|
|||
}
|
||||
type args struct {
|
||||
name string
|
||||
multiHandlerConfig config.MultiHandlerConfig
|
||||
multiHandlerConfig config2.MultiHandlerConfig
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
|
@ -52,12 +52,11 @@ func Test_endpointManager_CreateEndpoint(t *testing.T) {
|
|||
},
|
||||
args: args{
|
||||
name: "sampleEndpoint",
|
||||
multiHandlerConfig: config.NewMultiHandlerConfig(
|
||||
"sampleHandler",
|
||||
[]uint16{80},
|
||||
"0.0.0.0",
|
||||
nil,
|
||||
),
|
||||
multiHandlerConfig: config2.MultiHandlerConfig{
|
||||
Handler: "sampleHandler",
|
||||
Ports: []uint16{80},
|
||||
ListenAddress: "0.0.0.0",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
|
@ -83,12 +82,11 @@ func Test_endpointManager_CreateEndpoint(t *testing.T) {
|
|||
},
|
||||
args: args{
|
||||
name: "sampleEndpoint",
|
||||
multiHandlerConfig: config.NewMultiHandlerConfig(
|
||||
"sampleHandler",
|
||||
[]uint16{80},
|
||||
"0.0.0.0",
|
||||
nil,
|
||||
),
|
||||
multiHandlerConfig: config2.MultiHandlerConfig{
|
||||
Handler: "sampleHandler",
|
||||
Ports: []uint16{80},
|
||||
ListenAddress: "0.0.0.0",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
|
|
@ -4,6 +4,7 @@ import (
|
|||
"fmt"
|
||||
api_mock "github.com/baez90/inetmock/internal/mock/api"
|
||||
"github.com/baez90/inetmock/pkg/api"
|
||||
"github.com/baez90/inetmock/pkg/config"
|
||||
"github.com/golang/mock/gomock"
|
||||
"testing"
|
||||
)
|
||||
|
@ -12,7 +13,7 @@ func Test_endpoint_Name(t *testing.T) {
|
|||
type fields struct {
|
||||
name string
|
||||
handler api.ProtocolHandler
|
||||
config api.HandlerConfig
|
||||
config config.HandlerConfig
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
|
@ -50,7 +51,7 @@ func Test_endpoint_Shutdown(t *testing.T) {
|
|||
type fields struct {
|
||||
name string
|
||||
handler api.ProtocolHandler
|
||||
config api.HandlerConfig
|
||||
config config.HandlerConfig
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
|
@ -102,17 +103,16 @@ func Test_endpoint_Shutdown(t *testing.T) {
|
|||
|
||||
func Test_endpoint_Start(t *testing.T) {
|
||||
|
||||
demoHandlerConfig := api.NewHandlerConfig(
|
||||
"sampleHandler",
|
||||
80,
|
||||
"0.0.0.0",
|
||||
nil,
|
||||
)
|
||||
demoHandlerConfig := config.HandlerConfig{
|
||||
HandlerName: "sampleHandler",
|
||||
Port: 80,
|
||||
ListenAddress: "0.0.0.0",
|
||||
}
|
||||
|
||||
type fields struct {
|
||||
name string
|
||||
handler api.ProtocolHandler
|
||||
config api.HandlerConfig
|
||||
config config.HandlerConfig
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
|
@ -125,7 +125,7 @@ func Test_endpoint_Start(t *testing.T) {
|
|||
handler: func() api.ProtocolHandler {
|
||||
handler := api_mock.NewMockProtocolHandler(gomock.NewController(t))
|
||||
handler.EXPECT().
|
||||
Start(nil).
|
||||
Start(gomock.Any()).
|
||||
MaxTimes(1).
|
||||
Return(nil)
|
||||
return handler
|
||||
|
@ -139,7 +139,7 @@ func Test_endpoint_Start(t *testing.T) {
|
|||
handler: func() api.ProtocolHandler {
|
||||
handler := api_mock.NewMockProtocolHandler(gomock.NewController(t))
|
||||
handler.EXPECT().
|
||||
Start(nil).
|
||||
Start(gomock.Any()).
|
||||
MaxTimes(1).
|
||||
Return(fmt.Errorf(""))
|
||||
return handler
|
||||
|
|
|
@ -10,6 +10,7 @@ import (
|
|||
"path/filepath"
|
||||
"plugin"
|
||||
"regexp"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var (
|
||||
|
@ -19,6 +20,7 @@ var (
|
|||
|
||||
type HandlerRegistry interface {
|
||||
LoadPlugins(pluginsPath string) error
|
||||
AvailableHandlers() []string
|
||||
RegisterHandler(handlerName string, handlerProvider api.PluginInstanceFactory, subCommands ...*cobra.Command)
|
||||
HandlerForName(handlerName string) (api.ProtocolHandler, bool)
|
||||
PluginCommands() []*cobra.Command
|
||||
|
@ -29,11 +31,19 @@ type handlerRegistry struct {
|
|||
pluginCommands []*cobra.Command
|
||||
}
|
||||
|
||||
func (h handlerRegistry) AvailableHandlers() (availableHandlers []string) {
|
||||
for s := range h.handlers {
|
||||
availableHandlers = append(availableHandlers, s)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (h handlerRegistry) PluginCommands() []*cobra.Command {
|
||||
return h.pluginCommands
|
||||
}
|
||||
|
||||
func (h *handlerRegistry) HandlerForName(handlerName string) (instance api.ProtocolHandler, ok bool) {
|
||||
handlerName = strings.ToLower(handlerName)
|
||||
var provider api.PluginInstanceFactory
|
||||
if provider, ok = h.handlers[handlerName]; ok {
|
||||
instance = provider()
|
||||
|
@ -42,6 +52,7 @@ func (h *handlerRegistry) HandlerForName(handlerName string) (instance api.Proto
|
|||
}
|
||||
|
||||
func (h *handlerRegistry) RegisterHandler(handlerName string, handlerProvider api.PluginInstanceFactory, subCommands ...*cobra.Command) {
|
||||
handlerName = strings.ToLower(handlerName)
|
||||
if _, exists := h.handlers[handlerName]; exists {
|
||||
panic(fmt.Sprintf("handler with name %s is already registered - there's something strange...in the neighborhood", handlerName))
|
||||
}
|
||||
|
|
|
@ -25,8 +25,8 @@ tls:
|
|||
NotBeforeRelative: 168h
|
||||
NotAfterRelative: 168h
|
||||
rootCaCert:
|
||||
publicKey: ./ca.pem
|
||||
privateKey: ./ca.key
|
||||
publicKeyPath: ./ca.pem
|
||||
privateKeyPath: ./ca.key
|
||||
certCachePath: /tmp/inetmock/
|
||||
|
||||
endpoints:
|
||||
|
@ -41,7 +41,8 @@ endpoints:
|
|||
proxy:
|
||||
handler: http_proxy
|
||||
listenAddress: 0.0.0.0
|
||||
port: 3128
|
||||
ports:
|
||||
- 3128
|
||||
options:
|
||||
target:
|
||||
ipAddress: 127.0.0.1
|
||||
|
@ -59,7 +60,8 @@ endpoints:
|
|||
plainDns:
|
||||
handler: dns_mock
|
||||
listenAddress: 0.0.0.0
|
||||
port: 53
|
||||
ports:
|
||||
- 53
|
||||
options:
|
||||
rules:
|
||||
- pattern: ".*\\.google\\.com"
|
||||
|
@ -73,7 +75,8 @@ endpoints:
|
|||
dnsOverTlsDowngrade:
|
||||
handler: tls_interceptor
|
||||
listenAddress: 0.0.0.0
|
||||
port: 853
|
||||
ports:
|
||||
- 853
|
||||
options:
|
||||
target:
|
||||
ipAddress: 127.0.0.1
|
||||
|
|
|
@ -1,42 +0,0 @@
|
|||
package api
|
||||
|
||||
import "github.com/spf13/viper"
|
||||
|
||||
type HandlerConfig interface {
|
||||
HandlerName() string
|
||||
ListenAddress() string
|
||||
Port() uint16
|
||||
Options() *viper.Viper
|
||||
}
|
||||
|
||||
func NewHandlerConfig(handlerName string, port uint16, listenAddress string, options *viper.Viper) HandlerConfig {
|
||||
return &handlerConfig{
|
||||
handlerName: handlerName,
|
||||
port: port,
|
||||
listenAddress: listenAddress,
|
||||
options: options,
|
||||
}
|
||||
}
|
||||
|
||||
type handlerConfig struct {
|
||||
handlerName string
|
||||
port uint16
|
||||
listenAddress string
|
||||
options *viper.Viper
|
||||
}
|
||||
|
||||
func (h handlerConfig) HandlerName() string {
|
||||
return h.handlerName
|
||||
}
|
||||
|
||||
func (h handlerConfig) ListenAddress() string {
|
||||
return h.listenAddress
|
||||
}
|
||||
|
||||
func (h handlerConfig) Port() uint16 {
|
||||
return h.port
|
||||
}
|
||||
|
||||
func (h handlerConfig) Options() *viper.Viper {
|
||||
return h.options
|
||||
}
|
|
@ -2,6 +2,7 @@
|
|||
package api
|
||||
|
||||
import (
|
||||
"github.com/baez90/inetmock/pkg/config"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
|
@ -10,6 +11,6 @@ type PluginInstanceFactory func() ProtocolHandler
|
|||
type LoggingFactory func() (*zap.Logger, error)
|
||||
|
||||
type ProtocolHandler interface {
|
||||
Start(config HandlerConfig) error
|
||||
Start(config config.HandlerConfig) error
|
||||
Shutdown() error
|
||||
}
|
||||
|
|
|
@ -2,8 +2,8 @@ package api
|
|||
|
||||
import (
|
||||
"github.com/baez90/inetmock/pkg/cert"
|
||||
config2 "github.com/baez90/inetmock/pkg/config"
|
||||
"github.com/baez90/inetmock/pkg/logging"
|
||||
"github.com/spf13/viper"
|
||||
)
|
||||
|
||||
var (
|
||||
|
@ -19,7 +19,7 @@ type services struct {
|
|||
}
|
||||
|
||||
func InitServices(
|
||||
config *viper.Viper,
|
||||
config config2.Config,
|
||||
logger logging.Logger,
|
||||
) error {
|
||||
certStore, err := cert.NewDefaultStore(config, logger)
|
||||
|
|
|
@ -5,6 +5,7 @@ import (
|
|||
"crypto/x509"
|
||||
"fmt"
|
||||
certmock "github.com/baez90/inetmock/internal/mock/cert"
|
||||
config2 "github.com/baez90/inetmock/pkg/config"
|
||||
"github.com/golang/mock/gomock"
|
||||
"os"
|
||||
"path"
|
||||
|
@ -284,13 +285,13 @@ func Test_fileSystemCache_Put(t *testing.T) {
|
|||
}
|
||||
|
||||
func setupCertGen() Generator {
|
||||
return NewDefaultGenerator(Options{
|
||||
Validity: ValidityByPurpose{
|
||||
Server: ValidityDuration{
|
||||
return NewDefaultGenerator(config2.CertOptions{
|
||||
Validity: config2.ValidityByPurpose{
|
||||
Server: config2.ValidityDuration{
|
||||
NotBeforeRelative: serverRelativeValidity,
|
||||
NotAfterRelative: serverRelativeValidity,
|
||||
},
|
||||
CA: ValidityDuration{
|
||||
CA: config2.ValidityDuration{
|
||||
NotBeforeRelative: caRelativeValidity,
|
||||
NotAfterRelative: caRelativeValidity,
|
||||
},
|
||||
|
|
|
@ -1,19 +1,11 @@
|
|||
package cert
|
||||
|
||||
const (
|
||||
CurveTypeP224 CurveType = "P224"
|
||||
CurveTypeP256 CurveType = "P256"
|
||||
CurveTypeP384 CurveType = "P384"
|
||||
CurveTypeP521 CurveType = "P521"
|
||||
CurveTypeED25519 CurveType = "ED25519"
|
||||
|
||||
defaultServerValidityDuration = "168h"
|
||||
defaultCAValidityDuration = "17520h"
|
||||
|
||||
certCachePathConfigKey = "tls.certCachePath"
|
||||
ecdsaCurveConfigKey = "tls.ecdsaCurve"
|
||||
publicKeyConfigKey = "tls.rootCaCert.publicKey"
|
||||
privateKeyPathConfigKey = "tls.rootCaCert.privateKey"
|
||||
caCertValidityNotBeforeKey = "tls.validity.ca.notBeforeRelative"
|
||||
caCertValidityNotAfterKey = "tls.validity.ca.notAfterRelative"
|
||||
serverCertValidityNotBeforeKey = "tls.validity.server.notBeforeRelative"
|
||||
|
|
|
@ -8,6 +8,7 @@ import (
|
|||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"encoding/pem"
|
||||
config2 "github.com/baez90/inetmock/pkg/config"
|
||||
"github.com/baez90/inetmock/pkg/defaulting"
|
||||
"math/big"
|
||||
"net"
|
||||
|
@ -35,11 +36,11 @@ type Generator interface {
|
|||
ServerCert(options GenerationOptions, ca *tls.Certificate) (*tls.Certificate, error)
|
||||
}
|
||||
|
||||
func NewDefaultGenerator(options Options) Generator {
|
||||
func NewDefaultGenerator(options config2.CertOptions) Generator {
|
||||
return NewGenerator(options, NewTimeSource(), defaultKeyProvider(options))
|
||||
}
|
||||
|
||||
func NewGenerator(options Options, source TimeSource, provider KeyProvider) Generator {
|
||||
func NewGenerator(options config2.CertOptions, source TimeSource, provider KeyProvider) Generator {
|
||||
return &generator{
|
||||
options: options,
|
||||
provider: provider,
|
||||
|
@ -48,7 +49,7 @@ func NewGenerator(options Options, source TimeSource, provider KeyProvider) Gene
|
|||
}
|
||||
|
||||
type generator struct {
|
||||
options Options
|
||||
options config2.CertOptions
|
||||
provider KeyProvider
|
||||
timeSource TimeSource
|
||||
outDir string
|
||||
|
|
|
@ -1,79 +1,15 @@
|
|||
package cert
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/spf13/viper"
|
||||
"github.com/baez90/inetmock/pkg/config"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
requiredConfigKeys = []string{
|
||||
publicKeyConfigKey,
|
||||
privateKeyPathConfigKey,
|
||||
}
|
||||
)
|
||||
|
||||
type CurveType string
|
||||
|
||||
type File struct {
|
||||
PublicKeyPath string
|
||||
PrivateKeyPath string
|
||||
}
|
||||
|
||||
type ValidityDuration struct {
|
||||
NotBeforeRelative time.Duration
|
||||
NotAfterRelative time.Duration
|
||||
}
|
||||
|
||||
type ValidityByPurpose struct {
|
||||
CA ValidityDuration
|
||||
Server ValidityDuration
|
||||
}
|
||||
|
||||
type Options struct {
|
||||
RootCACert File
|
||||
CertCachePath string
|
||||
Curve CurveType
|
||||
Validity ValidityByPurpose
|
||||
}
|
||||
|
||||
func loadFromConfig(config *viper.Viper) (Options, error) {
|
||||
missingKeys := make([]string, 0)
|
||||
for _, requiredKey := range requiredConfigKeys {
|
||||
if !config.IsSet(requiredKey) {
|
||||
missingKeys = append(missingKeys, requiredKey)
|
||||
}
|
||||
}
|
||||
|
||||
if len(missingKeys) > 0 {
|
||||
return Options{}, fmt.Errorf("config keys are missing: %s", strings.Join(missingKeys, ", "))
|
||||
}
|
||||
|
||||
config.SetDefault(certCachePathConfigKey, os.TempDir())
|
||||
config.SetDefault(ecdsaCurveConfigKey, string(CurveTypeED25519))
|
||||
config.SetDefault(caCertValidityNotBeforeKey, defaultCAValidityDuration)
|
||||
config.SetDefault(caCertValidityNotAfterKey, defaultCAValidityDuration)
|
||||
config.SetDefault(serverCertValidityNotBeforeKey, defaultServerValidityDuration)
|
||||
config.SetDefault(serverCertValidityNotAfterKey, defaultServerValidityDuration)
|
||||
|
||||
return Options{
|
||||
CertCachePath: config.GetString(certCachePathConfigKey),
|
||||
Curve: CurveType(config.GetString(ecdsaCurveConfigKey)),
|
||||
Validity: ValidityByPurpose{
|
||||
CA: ValidityDuration{
|
||||
NotBeforeRelative: config.GetDuration(caCertValidityNotBeforeKey),
|
||||
NotAfterRelative: config.GetDuration(caCertValidityNotAfterKey),
|
||||
},
|
||||
Server: ValidityDuration{
|
||||
NotBeforeRelative: config.GetDuration(serverCertValidityNotBeforeKey),
|
||||
NotAfterRelative: config.GetDuration(serverCertValidityNotAfterKey),
|
||||
},
|
||||
},
|
||||
RootCACert: File{
|
||||
PublicKeyPath: config.GetString(publicKeyConfigKey),
|
||||
PrivateKeyPath: config.GetString(privateKeyPathConfigKey),
|
||||
},
|
||||
}, nil
|
||||
func init() {
|
||||
config.AddDefaultValue(certCachePathConfigKey, os.TempDir())
|
||||
config.AddDefaultValue(ecdsaCurveConfigKey, string(config.CurveTypeED25519))
|
||||
config.AddDefaultValue(caCertValidityNotBeforeKey, defaultCAValidityDuration)
|
||||
config.AddDefaultValue(caCertValidityNotAfterKey, defaultCAValidityDuration)
|
||||
config.AddDefaultValue(serverCertValidityNotBeforeKey, defaultServerValidityDuration)
|
||||
config.AddDefaultValue(serverCertValidityNotAfterKey, defaultServerValidityDuration)
|
||||
}
|
||||
|
|
|
@ -1,9 +1,9 @@
|
|||
package cert
|
||||
|
||||
import (
|
||||
"github.com/baez90/inetmock/pkg/config"
|
||||
"github.com/spf13/viper"
|
||||
"os"
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
@ -25,7 +25,7 @@ func Test_loadFromConfig(t *testing.T) {
|
|||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want Options
|
||||
want config.CertOptions
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
|
@ -48,19 +48,19 @@ tls:
|
|||
certCachePath: /tmp/inetmock/
|
||||
`),
|
||||
},
|
||||
want: Options{
|
||||
RootCACert: File{
|
||||
want: config.CertOptions{
|
||||
RootCACert: config.File{
|
||||
PublicKeyPath: "./ca.pem",
|
||||
PrivateKeyPath: "./ca.key",
|
||||
},
|
||||
CertCachePath: "/tmp/inetmock/",
|
||||
Curve: CurveTypeP256,
|
||||
Validity: ValidityByPurpose{
|
||||
CA: ValidityDuration{
|
||||
Curve: config.CurveTypeP256,
|
||||
Validity: config.ValidityByPurpose{
|
||||
CA: config.ValidityDuration{
|
||||
NotBeforeRelative: 17520 * time.Hour,
|
||||
NotAfterRelative: 17520 * time.Hour,
|
||||
},
|
||||
Server: ValidityDuration{
|
||||
Server: config.ValidityDuration{
|
||||
NotBeforeRelative: 168 * time.Hour,
|
||||
NotAfterRelative: 168 * time.Hour,
|
||||
},
|
||||
|
@ -76,7 +76,7 @@ tls:
|
|||
privateKey: ./ca.key
|
||||
`),
|
||||
},
|
||||
want: Options{},
|
||||
want: config.CertOptions{},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
|
@ -88,7 +88,7 @@ tls:
|
|||
publicKey: ./ca.pem
|
||||
`),
|
||||
},
|
||||
want: Options{},
|
||||
want: config.CertOptions{},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
|
@ -101,19 +101,19 @@ tls:
|
|||
privateKey: ./ca.key
|
||||
`),
|
||||
},
|
||||
want: Options{
|
||||
RootCACert: File{
|
||||
want: config.CertOptions{
|
||||
RootCACert: config.File{
|
||||
PublicKeyPath: "./ca.pem",
|
||||
PrivateKeyPath: "./ca.key",
|
||||
},
|
||||
CertCachePath: os.TempDir(),
|
||||
Curve: CurveTypeED25519,
|
||||
Validity: ValidityByPurpose{
|
||||
CA: ValidityDuration{
|
||||
Curve: config.CurveTypeED25519,
|
||||
Validity: config.ValidityByPurpose{
|
||||
CA: config.ValidityDuration{
|
||||
NotBeforeRelative: 17520 * time.Hour,
|
||||
NotAfterRelative: 17520 * time.Hour,
|
||||
},
|
||||
Server: ValidityDuration{
|
||||
Server: config.ValidityDuration{
|
||||
NotBeforeRelative: 168 * time.Hour,
|
||||
NotAfterRelative: 168 * time.Hour,
|
||||
},
|
||||
|
@ -124,14 +124,7 @@ tls:
|
|||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := loadFromConfig(tt.args.config)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("loadFromConfig() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("loadFromConfig() got = %v, want %v", got, tt.want)
|
||||
}
|
||||
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
@ -7,8 +7,8 @@ import (
|
|||
"crypto/rand"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"github.com/baez90/inetmock/pkg/config"
|
||||
"github.com/baez90/inetmock/pkg/logging"
|
||||
"github.com/spf13/viper"
|
||||
"go.uber.org/zap"
|
||||
"net"
|
||||
)
|
||||
|
@ -18,7 +18,7 @@ const (
|
|||
)
|
||||
|
||||
var (
|
||||
defaultKeyProvider = func(options Options) func() (key interface{}, err error) {
|
||||
defaultKeyProvider = func(options config.CertOptions) func() (key interface{}, err error) {
|
||||
return func() (key interface{}, err error) {
|
||||
return privateKeyForCurve(options)
|
||||
}
|
||||
|
@ -34,24 +34,20 @@ type Store interface {
|
|||
}
|
||||
|
||||
func NewDefaultStore(
|
||||
config *viper.Viper,
|
||||
config config.Config,
|
||||
logger logging.Logger,
|
||||
) (Store, error) {
|
||||
options, err := loadFromConfig(config)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
timeSource := NewTimeSource()
|
||||
return NewStore(
|
||||
options,
|
||||
NewFileSystemCache(options.CertCachePath, timeSource),
|
||||
NewDefaultGenerator(options),
|
||||
config.TLSConfig(),
|
||||
NewFileSystemCache(config.TLSConfig().CertCachePath, timeSource),
|
||||
NewDefaultGenerator(config.TLSConfig()),
|
||||
logger,
|
||||
)
|
||||
}
|
||||
|
||||
func NewStore(
|
||||
options Options,
|
||||
options config.CertOptions,
|
||||
cache Cache,
|
||||
generator Generator,
|
||||
logger logging.Logger,
|
||||
|
@ -71,7 +67,7 @@ func NewStore(
|
|||
}
|
||||
|
||||
type store struct {
|
||||
options Options
|
||||
options config.CertOptions
|
||||
caCert *tls.Certificate
|
||||
cache Cache
|
||||
timeSource TimeSource
|
||||
|
@ -135,15 +131,15 @@ func (s *store) GetCertificate(serverName string, ip string) (cert *tls.Certific
|
|||
return
|
||||
}
|
||||
|
||||
func privateKeyForCurve(options Options) (privateKey interface{}, err error) {
|
||||
func privateKeyForCurve(options config.CertOptions) (privateKey interface{}, err error) {
|
||||
switch options.Curve {
|
||||
case CurveTypeP224:
|
||||
case config.CurveTypeP224:
|
||||
privateKey, err = ecdsa.GenerateKey(elliptic.P224(), rand.Reader)
|
||||
case CurveTypeP256:
|
||||
case config.CurveTypeP256:
|
||||
privateKey, err = ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
case CurveTypeP384:
|
||||
case config.CurveTypeP384:
|
||||
privateKey, err = ecdsa.GenerateKey(elliptic.P384(), rand.Reader)
|
||||
case CurveTypeP521:
|
||||
case config.CurveTypeP521:
|
||||
privateKey, err = ecdsa.GenerateKey(elliptic.P521(), rand.Reader)
|
||||
default:
|
||||
_, privateKey, err = ed25519.GenerateKey(rand.Reader)
|
||||
|
|
27
pkg/config/certs.go
Normal file
27
pkg/config/certs.go
Normal file
|
@ -0,0 +1,27 @@
|
|||
package config
|
||||
|
||||
import "time"
|
||||
|
||||
type CurveType string
|
||||
|
||||
type File struct {
|
||||
PublicKeyPath string
|
||||
PrivateKeyPath string
|
||||
}
|
||||
|
||||
type ValidityDuration struct {
|
||||
NotBeforeRelative time.Duration
|
||||
NotAfterRelative time.Duration
|
||||
}
|
||||
|
||||
type ValidityByPurpose struct {
|
||||
CA ValidityDuration
|
||||
Server ValidityDuration
|
||||
}
|
||||
|
||||
type CertOptions struct {
|
||||
RootCACert File
|
||||
CertCachePath string
|
||||
Curve CurveType
|
||||
Validity ValidityByPurpose
|
||||
}
|
109
pkg/config/config.go
Normal file
109
pkg/config/config.go
Normal file
|
@ -0,0 +1,109 @@
|
|||
package config
|
||||
|
||||
import (
|
||||
"github.com/baez90/inetmock/pkg/logging"
|
||||
"github.com/baez90/inetmock/pkg/path"
|
||||
"github.com/spf13/pflag"
|
||||
"github.com/spf13/viper"
|
||||
"go.uber.org/zap"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var (
|
||||
appConfig Config
|
||||
)
|
||||
|
||||
func CreateConfig(flags *pflag.FlagSet) {
|
||||
logger, _ := logging.CreateLogger()
|
||||
configInstance := &config{
|
||||
logger: logger.Named("Config"),
|
||||
cfg: viper.New(),
|
||||
}
|
||||
|
||||
configInstance.cfg.SetConfigName("config")
|
||||
configInstance.cfg.SetConfigType("yaml")
|
||||
configInstance.cfg.AddConfigPath("/etc/inetmock/")
|
||||
configInstance.cfg.AddConfigPath("$HOME/.inetmock")
|
||||
configInstance.cfg.AddConfigPath(".")
|
||||
configInstance.cfg.SetEnvPrefix("INetMock")
|
||||
_ = configInstance.cfg.BindPFlags(flags)
|
||||
configInstance.cfg.SetEnvKeyReplacer(strings.NewReplacer("-", "_"))
|
||||
configInstance.cfg.AutomaticEnv()
|
||||
|
||||
for k, v := range registeredDefaults {
|
||||
configInstance.cfg.SetDefault(k, v)
|
||||
}
|
||||
|
||||
for k, v := range registeredAliases {
|
||||
configInstance.cfg.RegisterAlias(k, v)
|
||||
}
|
||||
|
||||
appConfig = configInstance
|
||||
}
|
||||
|
||||
func Instance() Config {
|
||||
return appConfig
|
||||
}
|
||||
|
||||
type Config interface {
|
||||
ReadConfig(configFilePath string) error
|
||||
ReadConfigString(config, format string) error
|
||||
Viper() *viper.Viper
|
||||
TLSConfig() CertOptions
|
||||
PluginsDir() string
|
||||
EndpointConfigs() map[string]MultiHandlerConfig
|
||||
}
|
||||
|
||||
type config struct {
|
||||
cfg *viper.Viper
|
||||
logger logging.Logger
|
||||
TLS CertOptions
|
||||
PluginsDirectory string
|
||||
Endpoints map[string]MultiHandlerConfig
|
||||
}
|
||||
|
||||
func (c *config) ReadConfigString(config, format string) (err error) {
|
||||
c.cfg.SetConfigType(format)
|
||||
if err = c.cfg.ReadConfig(strings.NewReader(config)); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
err = c.cfg.Unmarshal(c)
|
||||
return
|
||||
}
|
||||
|
||||
func (c *config) EndpointConfigs() map[string]MultiHandlerConfig {
|
||||
return c.Endpoints
|
||||
}
|
||||
|
||||
func (c *config) PluginsDir() string {
|
||||
return c.PluginsDirectory
|
||||
}
|
||||
|
||||
func (c *config) TLSConfig() CertOptions {
|
||||
return c.TLS
|
||||
}
|
||||
|
||||
func (c *config) Viper() *viper.Viper {
|
||||
return c.cfg
|
||||
}
|
||||
|
||||
func (c *config) ReadConfig(configFilePath string) (err error) {
|
||||
if configFilePath != "" && path.FileExists(configFilePath) {
|
||||
c.logger.Info(
|
||||
"loading config from passed config file path",
|
||||
zap.String("configFilePath", configFilePath),
|
||||
)
|
||||
c.cfg.SetConfigFile(configFilePath)
|
||||
}
|
||||
if err = c.cfg.ReadInConfig(); err != nil {
|
||||
if _, ok := err.(viper.ConfigFileNotFoundError); ok {
|
||||
err = nil
|
||||
c.logger.Warn("failed to load config")
|
||||
}
|
||||
}
|
||||
|
||||
err = c.cfg.Unmarshal(c)
|
||||
|
||||
return
|
||||
}
|
88
pkg/config/config_test.go
Normal file
88
pkg/config/config_test.go
Normal file
|
@ -0,0 +1,88 @@
|
|||
package config
|
||||
|
||||
import (
|
||||
"github.com/spf13/pflag"
|
||||
"os"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func Test_config_ReadConfig(t *testing.T) {
|
||||
type args struct {
|
||||
flags *pflag.FlagSet
|
||||
config string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
matcher func(Config) bool
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "Test endpoints config",
|
||||
args: args{
|
||||
flags: pflag.NewFlagSet("", pflag.ContinueOnError),
|
||||
config: `
|
||||
endpoints:
|
||||
plainHttp:
|
||||
handler: http_mock
|
||||
listenAddress: 0.0.0.0
|
||||
ports:
|
||||
- 80
|
||||
- 8080
|
||||
options: {}
|
||||
proxy:
|
||||
handler: http_proxy
|
||||
listenAddress: 0.0.0.0
|
||||
ports:
|
||||
- 3128
|
||||
options:
|
||||
target:
|
||||
ipAddress: 127.0.0.1
|
||||
port: 80
|
||||
`,
|
||||
},
|
||||
matcher: func(c Config) bool {
|
||||
if len(c.EndpointConfigs()) < 1 {
|
||||
t.Error("Expected EndpointConfigs to be set but is empty")
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Test loading of flags by adding",
|
||||
args: args{
|
||||
flags: func() *pflag.FlagSet {
|
||||
flags := pflag.NewFlagSet("", pflag.ContinueOnError)
|
||||
flags.String("plugins-directory", os.TempDir(), "")
|
||||
|
||||
return flags
|
||||
}(),
|
||||
config: ``,
|
||||
},
|
||||
matcher: func(c Config) bool {
|
||||
if c.PluginsDir() != os.TempDir() {
|
||||
t.Errorf("PluginsDir() Expected = %s, Got = %s", os.TempDir(), c.PluginsDir())
|
||||
}
|
||||
return true
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
CreateConfig(tt.args.flags)
|
||||
cfg := Instance()
|
||||
if err := cfg.ReadConfigString(tt.args.config, "yaml"); (err != nil) != tt.wantErr {
|
||||
t.Errorf("ReadConfig() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
|
||||
if !tt.matcher(cfg) {
|
||||
t.Error("matcher error")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
12
pkg/config/constants.go
Normal file
12
pkg/config/constants.go
Normal file
|
@ -0,0 +1,12 @@
|
|||
package config
|
||||
|
||||
const (
|
||||
EndpointsKey = "endpoints"
|
||||
OptionsKey = "options"
|
||||
|
||||
CurveTypeP224 CurveType = "P224"
|
||||
CurveTypeP256 CurveType = "P256"
|
||||
CurveTypeP384 CurveType = "P384"
|
||||
CurveTypeP521 CurveType = "P521"
|
||||
CurveTypeED25519 CurveType = "ED25519"
|
||||
)
|
17
pkg/config/defaults.go
Normal file
17
pkg/config/defaults.go
Normal file
|
@ -0,0 +1,17 @@
|
|||
package config
|
||||
|
||||
var (
|
||||
registeredDefaults = make(map[string]interface{})
|
||||
// default aliases
|
||||
registeredAliases = map[string]string{
|
||||
"PluginsDirectory": "plugins-directory",
|
||||
}
|
||||
)
|
||||
|
||||
func AddDefaultValue(key string, val interface{}) {
|
||||
registeredDefaults[key] = val
|
||||
}
|
||||
|
||||
func AddAlias(alias, orig string) {
|
||||
registeredAliases[alias] = orig
|
||||
}
|
10
pkg/config/handler_config.go
Normal file
10
pkg/config/handler_config.go
Normal file
|
@ -0,0 +1,10 @@
|
|||
package config
|
||||
|
||||
import "github.com/spf13/viper"
|
||||
|
||||
type HandlerConfig struct {
|
||||
HandlerName string
|
||||
Port uint16
|
||||
ListenAddress string
|
||||
Options *viper.Viper
|
||||
}
|
|
@ -1,4 +1,4 @@
|
|||
package api
|
||||
package config
|
||||
|
||||
import (
|
||||
"github.com/spf13/viper"
|
||||
|
@ -19,12 +19,12 @@ func Test_handlerConfig_HandlerName(t *testing.T) {
|
|||
want string
|
||||
}{
|
||||
{
|
||||
name: "Get empty HandlerName for uninitialized struct",
|
||||
name: "Get empty Handler for uninitialized struct",
|
||||
fields: fields{},
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
name: "Get expected HandlerName for initialized struct",
|
||||
name: "Get expected Handler for initialized struct",
|
||||
fields: fields{
|
||||
handlerName: "sampleHandler",
|
||||
},
|
||||
|
@ -33,14 +33,14 @@ func Test_handlerConfig_HandlerName(t *testing.T) {
|
|||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
h := handlerConfig{
|
||||
handlerName: tt.fields.handlerName,
|
||||
port: tt.fields.port,
|
||||
listenAddress: tt.fields.listenAddress,
|
||||
options: tt.fields.options,
|
||||
h := HandlerConfig{
|
||||
HandlerName: tt.fields.handlerName,
|
||||
Port: tt.fields.port,
|
||||
ListenAddress: tt.fields.listenAddress,
|
||||
Options: tt.fields.options,
|
||||
}
|
||||
if got := h.HandlerName(); got != tt.want {
|
||||
t.Errorf("HandlerName() = %v, want %v", got, tt.want)
|
||||
if got := h.HandlerName; got != tt.want {
|
||||
t.Errorf("Handler() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
@ -73,13 +73,13 @@ func Test_handlerConfig_ListenAddress(t *testing.T) {
|
|||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
h := handlerConfig{
|
||||
handlerName: tt.fields.handlerName,
|
||||
port: tt.fields.port,
|
||||
listenAddress: tt.fields.listenAddress,
|
||||
options: tt.fields.options,
|
||||
h := HandlerConfig{
|
||||
HandlerName: tt.fields.handlerName,
|
||||
Port: tt.fields.port,
|
||||
ListenAddress: tt.fields.listenAddress,
|
||||
Options: tt.fields.options,
|
||||
}
|
||||
if got := h.ListenAddress(); got != tt.want {
|
||||
if got := h.ListenAddress; got != tt.want {
|
||||
t.Errorf("ListenAddress() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
|
@ -113,13 +113,13 @@ func Test_handlerConfig_Options(t *testing.T) {
|
|||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
h := handlerConfig{
|
||||
handlerName: tt.fields.handlerName,
|
||||
port: tt.fields.port,
|
||||
listenAddress: tt.fields.listenAddress,
|
||||
options: tt.fields.options,
|
||||
h := HandlerConfig{
|
||||
HandlerName: tt.fields.handlerName,
|
||||
Port: tt.fields.port,
|
||||
ListenAddress: tt.fields.listenAddress,
|
||||
Options: tt.fields.options,
|
||||
}
|
||||
if got := h.Options(); !reflect.DeepEqual(got, tt.want) {
|
||||
if got := h.Options; !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("Options() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
|
@ -153,13 +153,13 @@ func Test_handlerConfig_Port(t *testing.T) {
|
|||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
h := handlerConfig{
|
||||
handlerName: tt.fields.handlerName,
|
||||
port: tt.fields.port,
|
||||
listenAddress: tt.fields.listenAddress,
|
||||
options: tt.fields.options,
|
||||
h := HandlerConfig{
|
||||
HandlerName: tt.fields.handlerName,
|
||||
Port: tt.fields.port,
|
||||
ListenAddress: tt.fields.listenAddress,
|
||||
Options: tt.fields.options,
|
||||
}
|
||||
if got := h.Port(); got != tt.want {
|
||||
if got := h.Port; got != tt.want {
|
||||
t.Errorf("Port() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
25
pkg/config/multi_handler_config.go
Normal file
25
pkg/config/multi_handler_config.go
Normal file
|
@ -0,0 +1,25 @@
|
|||
package config
|
||||
|
||||
import (
|
||||
"github.com/spf13/viper"
|
||||
)
|
||||
|
||||
type MultiHandlerConfig struct {
|
||||
Handler string
|
||||
Ports []uint16
|
||||
ListenAddress string
|
||||
Options *viper.Viper
|
||||
}
|
||||
|
||||
func (m MultiHandlerConfig) HandlerConfigs() []HandlerConfig {
|
||||
configs := make([]HandlerConfig, 0)
|
||||
for _, port := range m.Ports {
|
||||
configs = append(configs, HandlerConfig{
|
||||
HandlerName: m.Handler,
|
||||
Port: port,
|
||||
ListenAddress: m.ListenAddress,
|
||||
Options: m.Options,
|
||||
})
|
||||
}
|
||||
return configs
|
||||
}
|
|
@ -2,7 +2,7 @@ package main
|
|||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/baez90/inetmock/pkg/api"
|
||||
"github.com/baez90/inetmock/pkg/config"
|
||||
"github.com/baez90/inetmock/pkg/logging"
|
||||
"github.com/miekg/dns"
|
||||
"go.uber.org/zap"
|
||||
|
@ -13,9 +13,9 @@ type dnsHandler struct {
|
|||
dnsServer []*dns.Server
|
||||
}
|
||||
|
||||
func (d *dnsHandler) Start(config api.HandlerConfig) (err error) {
|
||||
options := loadFromConfig(config.Options())
|
||||
addr := fmt.Sprintf("%s:%d", config.ListenAddress(), config.Port())
|
||||
func (d *dnsHandler) Start(config config.HandlerConfig) (err error) {
|
||||
options := loadFromConfig(config.Options)
|
||||
addr := fmt.Sprintf("%s:%d", config.ListenAddress, config.Port)
|
||||
|
||||
handler := ®exHandler{
|
||||
fallback: options.Fallback,
|
||||
|
|
|
@ -2,7 +2,7 @@ package main
|
|||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/baez90/inetmock/pkg/api"
|
||||
"github.com/baez90/inetmock/pkg/config"
|
||||
"github.com/baez90/inetmock/pkg/logging"
|
||||
"go.uber.org/zap"
|
||||
"net/http"
|
||||
|
@ -18,9 +18,9 @@ type httpHandler struct {
|
|||
server *http.Server
|
||||
}
|
||||
|
||||
func (p *httpHandler) Start(config api.HandlerConfig) (err error) {
|
||||
options := loadFromConfig(config.Options())
|
||||
addr := fmt.Sprintf("%s:%d", config.ListenAddress(), config.Port())
|
||||
func (p *httpHandler) Start(config config.HandlerConfig) (err error) {
|
||||
options := loadFromConfig(config.Options)
|
||||
addr := fmt.Sprintf("%s:%d", config.ListenAddress, config.Port)
|
||||
p.server = &http.Server{Addr: addr, Handler: p.router}
|
||||
p.logger = p.logger.With(
|
||||
zap.String("address", addr),
|
||||
|
|
|
@ -3,6 +3,7 @@ package main
|
|||
import (
|
||||
"fmt"
|
||||
"github.com/baez90/inetmock/pkg/api"
|
||||
"github.com/baez90/inetmock/pkg/config"
|
||||
"github.com/baez90/inetmock/pkg/logging"
|
||||
"go.uber.org/zap"
|
||||
"gopkg.in/elazarl/goproxy.v1"
|
||||
|
@ -19,9 +20,9 @@ type httpProxy struct {
|
|||
server *http.Server
|
||||
}
|
||||
|
||||
func (h *httpProxy) Start(config api.HandlerConfig) (err error) {
|
||||
options := loadFromConfig(config.Options())
|
||||
addr := fmt.Sprintf("%s:%d", config.ListenAddress(), config.Port())
|
||||
func (h *httpProxy) Start(config config.HandlerConfig) (err error) {
|
||||
options := loadFromConfig(config.Options)
|
||||
addr := fmt.Sprintf("%s:%d", config.ListenAddress, config.Port)
|
||||
h.server = &http.Server{Addr: addr, Handler: h.proxy}
|
||||
h.logger = h.logger.With(
|
||||
zap.String("address", addr),
|
||||
|
|
|
@ -5,6 +5,7 @@ import (
|
|||
"fmt"
|
||||
"github.com/baez90/inetmock/pkg/api"
|
||||
"github.com/baez90/inetmock/pkg/cert"
|
||||
"github.com/baez90/inetmock/pkg/config"
|
||||
"github.com/baez90/inetmock/pkg/logging"
|
||||
"github.com/google/uuid"
|
||||
"go.uber.org/zap"
|
||||
|
@ -27,9 +28,9 @@ type tlsInterceptor struct {
|
|||
currentConnections map[uuid.UUID]*proxyConn
|
||||
}
|
||||
|
||||
func (t *tlsInterceptor) Start(config api.HandlerConfig) (err error) {
|
||||
t.options = loadFromConfig(config.Options())
|
||||
addr := fmt.Sprintf("%s:%d", config.ListenAddress(), config.Port())
|
||||
func (t *tlsInterceptor) Start(config config.HandlerConfig) (err error) {
|
||||
t.options = loadFromConfig(config.Options)
|
||||
addr := fmt.Sprintf("%s:%d", config.ListenAddress, config.Port)
|
||||
|
||||
t.logger = t.logger.With(
|
||||
zap.String("address", addr),
|
||||
|
|
|
@ -1,77 +0,0 @@
|
|||
endpoints:
|
||||
plainHttp:
|
||||
handler: http_proxy
|
||||
listenAddress: 0.0.0.0
|
||||
port: 80
|
||||
options:
|
||||
rules:
|
||||
- pattern: ".*\\.(?i)exe"
|
||||
response: ./assets/fakeFiles/sample.exe
|
||||
- pattern: ".*\\.(?i)(jpg|jpeg)"
|
||||
response: ./assets/fakeFiles/default.jpg
|
||||
- pattern: ".*\\.(?i)png"
|
||||
response: ./assets/fakeFiles/default.png
|
||||
- pattern: ".*\\.(?i)gif"
|
||||
response: ./assets/fakeFiles/default.gif
|
||||
- pattern: ".*\\.(?i)ico"
|
||||
response: ./assets/fakeFiles/default.ico
|
||||
- pattern: ".*\\.(?i)txt"
|
||||
response: ./assets/fakeFiles/default.txt
|
||||
- pattern: ".*"
|
||||
response: ./assets/fakeFiles/default.html
|
||||
httpsDowngrade:
|
||||
handler: tls_interceptor
|
||||
listenAddress: 0.0.0.0
|
||||
port: 443
|
||||
options:
|
||||
ecdsaCurve: P256
|
||||
validity:
|
||||
ca:
|
||||
notBeforeRelative: 17520h
|
||||
notAfterRelative: 17520h
|
||||
domain:
|
||||
notBeforeRelative: 168h
|
||||
notAfterRelative: 168h
|
||||
rootCaCert:
|
||||
publicKey: ./ca.pem
|
||||
privateKey: ./ca.key
|
||||
certCachePath: /tmp/inetmock/
|
||||
target:
|
||||
ipAddress: 127.0.0.1
|
||||
port: 80
|
||||
plainDns:
|
||||
handler: dns_mock
|
||||
listenAddress: 0.0.0.0
|
||||
port: 53
|
||||
options:
|
||||
rules:
|
||||
- pattern: "www.golem.de"
|
||||
response: 77.247.84.129
|
||||
- pattern: ".*\\.google\\.com"
|
||||
response: 1.1.1.1
|
||||
- pattern: ".*\\.reddit\\.com"
|
||||
response: 2.2.2.2
|
||||
fallback:
|
||||
strategy: incremental
|
||||
args:
|
||||
startIP: 10.0.0.0
|
||||
dnsOverTlsDowngrade:
|
||||
handler: tls_interceptor
|
||||
listenAddress: 0.0.0.0
|
||||
port: 853
|
||||
options:
|
||||
ecdsaCurve: P256
|
||||
validity:
|
||||
ca:
|
||||
notBeforeRelative: 17520h
|
||||
notAfterRelative: 17520h
|
||||
domain:
|
||||
notBeforeRelative: 168h
|
||||
notAfterRelative: 168h
|
||||
rootCaCert:
|
||||
publicKey: ./ca.pem
|
||||
privateKey: ./ca.key
|
||||
certCachePath: /tmp/inetmock/
|
||||
target:
|
||||
ipAddress: 127.0.0.1
|
||||
port: 53
|
Loading…
Reference in a new issue