Improve config and startup handling

- use `Unmarshal` method of viper
- move config loading, defaulting and stuff to config package
- move serve to an extra command and move keep only global flags in root command
- add startup logic to onInit method in root command
- update mocks
- clean config structs
- adopt changes in plugins
- update default config
This commit is contained in:
Peter 2020-04-28 00:26:15 +02:00 committed by baez90
parent 1775d3dbc6
commit 91f0cf6963
40 changed files with 561 additions and 975 deletions

View file

@ -2,7 +2,8 @@ VERSION = $(shell git describe --dirty --tags --always)
DIR = $(dir $(realpath $(firstword $(MAKEFILE_LIST))))
BUILD_PATH = $(DIR)/main.go
PKGS = $(shell go list ./...)
TEST_PKGS = $(shell find . -type f -name "*_test.go" -not -path "./plugins/*" -not -path "*/mock/*" -printf '%h\n' | sort -u)
TEST_PKGS = $(shell go list ./...)
PROTO_FILES = $(shell find $(DIR)api/ -type f -name "*.proto")
GOARGS = GOOS=linux GOARCH=amd64
GO_BUILD_ARGS = -ldflags="-w -s"
GO_CONTAINER_BUILD_ARGS = -ldflags="-w -s" -a -installsuffix cgo
@ -14,7 +15,7 @@ DEBUG_ARGS?= --development-logs=true
CONTAINER_BUILDER ?= podman
DOCKER_IMAGE ?= inetmock
.PHONY: clean all format deps update-deps compile debug generate snapshot-release test cli-cover-report html-cover-report plugins $(PLUGINS) $(GO_GEN_FILES)
.PHONY: clean all format deps update-deps compile debug generate protoc snapshot-release test cli-cover-report html-cover-report plugins $(PLUGINS) $(GO_GEN_FILES)
all: clean format compile test plugins
clean:
@ -55,6 +56,9 @@ debug:
generate:
@go generate ./...
protoc:
@protoc --proto_path $(DIR)api/ --go_out=plugins=grpc:internal/grpc --go_opt=paths=source_relative $(shell find $(DIR)api/ -type f -name "*.proto")
snapshot-release:
@goreleaser release --snapshot --skip-publish --rm-dist

View file

@ -4,6 +4,7 @@ import (
"crypto/tls"
"crypto/x509"
"github.com/baez90/inetmock/pkg/cert"
config2 "github.com/baez90/inetmock/pkg/config"
"github.com/baez90/inetmock/pkg/logging"
"github.com/spf13/cobra"
"go.uber.org/zap"
@ -77,11 +78,11 @@ func runGenerateCA(cmd *cobra.Command, args []string) {
zap.String(generateCACertOutPath, certOutPath),
)
generator := cert.NewDefaultGenerator(cert.Options{
generator := cert.NewDefaultGenerator(config2.CertOptions{
CertCachePath: certOutPath,
Curve: cert.CurveType(curveName),
Validity: cert.ValidityByPurpose{
CA: cert.ValidityDuration{
Curve: config2.CurveType(curveName),
Validity: config2.ValidityByPurpose{
CA: config2.ValidityDuration{
NotAfterRelative: notAfter,
NotBeforeRelative: notBefore,
},

View file

@ -1,70 +0,0 @@
package cmd
import (
"github.com/baez90/inetmock/internal/endpoints"
"github.com/baez90/inetmock/internal/plugins"
"github.com/baez90/inetmock/pkg/api"
"github.com/baez90/inetmock/pkg/logging"
"github.com/baez90/inetmock/pkg/path"
"github.com/spf13/viper"
"go.uber.org/zap"
"os"
"time"
)
var (
appIsInitialized = false
)
func initApp() (err error) {
if appIsInitialized {
return
}
appIsInitialized = true
logging.ConfigureLogging(
logging.ParseLevel(logLevel),
developmentLogs,
map[string]interface{}{"cwd": path.WorkingDirectory()},
)
logger, _ = logging.CreateLogger()
registry := plugins.Registry()
endpointManager = endpoints.NewEndpointManager(logger)
rootCmd.Flags().ParseErrorsWhitelist.UnknownFlags = false
_ = rootCmd.ParseFlags(os.Args)
if err = appConfig.ReadConfig(configFilePath); err != nil {
logger.Error(
"unrecoverable error occurred during reading the config file",
zap.Error(err),
)
return
}
viperInst := viper.GetViper()
pluginDir := viperInst.GetString("plugins-directory")
if err = api.InitServices(viperInst, logger); err != nil {
logger.Error(
"failed to initialize app services",
zap.Error(err),
)
}
pluginLoadStartTime := time.Now()
if err = registry.LoadPlugins(pluginDir); err != nil {
logger.Error("Failed to load plugins",
zap.String("pluginsDirectory", pluginDir),
zap.Error(err),
)
}
pluginLoadDuration := time.Since(pluginLoadStartTime)
logger.Info(
"loading plugins completed",
zap.Duration("pluginLoadDuration", pluginLoadDuration),
)
pluginsCmd.AddCommand(registry.PluginCommands()...)
return
}

View file

@ -5,10 +5,6 @@ import (
)
var (
pluginsCmd *cobra.Command
)
func init() {
pluginsCmd = &cobra.Command{
Use: "plugins",
Short: "Use the plugins prefix to interact with commands that are provided by plugins",
@ -18,5 +14,4 @@ The easiest way to explore what commands are available is to start with 'inetmoc
This help page contains a list of available sub-commands starting with the name of the plugin as a prefix.
`,
}
rootCmd.AddCommand(pluginsCmd, generateCaCmd)
}
)

View file

@ -1,77 +1,91 @@
package cmd
import (
"github.com/baez90/inetmock/internal/config"
"github.com/baez90/inetmock/internal/endpoints"
"github.com/baez90/inetmock/internal/plugins"
"github.com/baez90/inetmock/pkg/api"
"github.com/baez90/inetmock/pkg/config"
"github.com/baez90/inetmock/pkg/logging"
"github.com/baez90/inetmock/pkg/path"
"github.com/spf13/cobra"
"github.com/spf13/viper"
"go.uber.org/zap"
"os"
"os/signal"
"strings"
"syscall"
"time"
)
var (
logger logging.Logger
rootCmd = cobra.Command{
Use: "",
Short: "INetMock is lightweight internet mock",
Run: startInetMock,
}
rootCmd *cobra.Command
pluginsDirectory string
configFilePath string
logLevel string
developmentLogs bool
handlers []api.ProtocolHandler
appConfig = config.CreateConfig()
endpointManager endpoints.EndpointManager
)
func init() {
rootCmd.PersistentFlags().String("plugins-directory", "", "Directory where plugins should be loaded from")
cobra.OnInitialize(onInit)
rootCmd = &cobra.Command{
Use: "",
Short: "INetMock is lightweight internet mock",
}
rootCmd.PersistentFlags().StringVar(&pluginsDirectory, "plugins-directory", "", "Directory where plugins should be loaded from")
rootCmd.PersistentFlags().StringVar(&configFilePath, "config", "", "Path to config file that should be used")
rootCmd.PersistentFlags().StringVar(&logLevel, "log-level", "info", "logging level to use")
rootCmd.PersistentFlags().BoolVar(&developmentLogs, "development-logs", false, "Enable development mode logs")
appConfig.InitConfig(rootCmd.PersistentFlags())
rootCmd.AddCommand(
serveCmd,
generateCaCmd,
pluginsCmd,
)
}
func startInetMock(cmd *cobra.Command, args []string) {
for endpointName := range viper.GetStringMap(config.EndpointsKey) {
handlerSubConfig := viper.Sub(strings.Join([]string{config.EndpointsKey, endpointName}, "."))
handlerConfig := config.CreateMultiHandlerConfig(handlerSubConfig)
if err := endpointManager.CreateEndpoint(endpointName, handlerConfig); err != nil {
logger.Warn(
"error occurred while creating endpoint",
zap.String("endpointName", endpointName),
zap.String("handlerName", handlerConfig.HandlerName()),
func onInit() {
logging.ConfigureLogging(
logging.ParseLevel(logLevel),
developmentLogs,
map[string]interface{}{"cwd": path.WorkingDirectory()},
)
logger, _ = logging.CreateLogger()
config.CreateConfig(rootCmd.Flags())
appConfig := config.Instance()
if err := appConfig.ReadConfig(configFilePath); err != nil {
logger.Error(
"failed to read config file",
zap.Error(err),
)
}
if err := api.InitServices(appConfig, logger); err != nil {
logger.Error(
"failed to initialize app services",
zap.Error(err),
)
}
endpointManager.StartEndpoints()
registry := plugins.Registry()
signalChannel := make(chan os.Signal, 1)
signal.Notify(signalChannel, syscall.SIGINT, syscall.SIGTERM, syscall.SIGQUIT)
// block until canceled
s := <-signalChannel
cfg := config.Instance()
pluginLoadStartTime := time.Now()
if err := registry.LoadPlugins(cfg.PluginsDir()); err != nil {
logger.Error("Failed to load plugins",
zap.String("pluginsDirectory", cfg.PluginsDir()),
zap.Error(err),
)
}
pluginLoadDuration := time.Since(pluginLoadStartTime)
logger.Info(
"got signal to quit",
zap.String("signal", s.String()),
"loading plugins completed",
zap.Duration("pluginLoadDuration", pluginLoadDuration),
)
endpointManager.ShutdownEndpoints()
pluginsCmd.AddCommand(registry.PluginCommands()...)
}
func ExecuteRootCommand() error {
if err := initApp(); err != nil {
return err
}
return rootCmd.Execute()
}

54
internal/cmd/serve.go Normal file
View file

@ -0,0 +1,54 @@
package cmd
import (
"github.com/baez90/inetmock/internal/endpoints"
"github.com/baez90/inetmock/pkg/config"
"github.com/spf13/cobra"
"go.uber.org/zap"
"os"
"os/signal"
"strings"
"syscall"
)
var (
endpointManager endpoints.EndpointManager
serveCmd = &cobra.Command{
Use: "serve",
Short: "Starts the INetMock server",
Long: ``,
Run: startINetMock,
}
)
func startINetMock(_ *cobra.Command, _ []string) {
endpointManager = endpoints.NewEndpointManager(logger)
cfg := config.Instance()
for endpointName, endpointHandler := range cfg.EndpointConfigs() {
handlerSubConfig := cfg.Viper().Sub(strings.Join([]string{config.EndpointsKey, endpointName, config.OptionsKey}, "."))
endpointHandler.Options = handlerSubConfig
if err := endpointManager.CreateEndpoint(endpointName, endpointHandler); err != nil {
logger.Warn(
"error occurred while creating endpoint",
zap.String("endpointName", endpointName),
zap.String("handlerName", endpointHandler.Handler),
zap.Error(err),
)
}
}
endpointManager.StartEndpoints()
signalChannel := make(chan os.Signal, 1)
signal.Notify(signalChannel, syscall.SIGINT, syscall.SIGTERM, syscall.SIGQUIT)
// block until canceled
s := <-signalChannel
logger.Info(
"got signal to quit",
zap.String("signal", s.String()),
)
endpointManager.ShutdownEndpoints()
}

View file

@ -1,6 +0,0 @@
package config
const (
EndpointsKey = "endpoints"
OptionsKey = "options"
)

View file

@ -1,55 +0,0 @@
package config
import (
"github.com/baez90/inetmock/pkg/logging"
"github.com/baez90/inetmock/pkg/path"
"github.com/spf13/pflag"
"github.com/spf13/viper"
"go.uber.org/zap"
"strings"
)
func CreateConfig() Config {
logger, _ := logging.CreateLogger()
return &config{
logger: logger.Named("Config"),
}
}
type Config interface {
InitConfig(flags *pflag.FlagSet)
ReadConfig(configFilePath string) error
}
type config struct {
logger logging.Logger
}
func (c config) InitConfig(flags *pflag.FlagSet) {
viper.SetConfigName("config")
viper.SetConfigType("yaml")
viper.AddConfigPath("/etc/inetmock/")
viper.AddConfigPath("$HOME/.inetmock")
viper.AddConfigPath(".")
viper.SetEnvPrefix("INetMock")
_ = viper.BindPFlags(flags)
viper.SetEnvKeyReplacer(strings.NewReplacer("-", "_"))
viper.AutomaticEnv()
}
func (c *config) ReadConfig(configFilePath string) (err error) {
if configFilePath != "" && path.FileExists(configFilePath) {
c.logger.Info(
"loading config from passed config file path",
zap.String("configFilePath", configFilePath),
)
viper.SetConfigFile(configFilePath)
}
if err = viper.ReadInConfig(); err != nil {
if _, ok := err.(viper.ConfigFileNotFoundError); ok {
err = nil
c.logger.Warn("failed to load config")
}
}
return
}

View file

@ -1,49 +0,0 @@
package config
import (
"github.com/baez90/inetmock/pkg/api"
"github.com/spf13/viper"
)
type MultiHandlerConfig interface {
HandlerName() string
ListenAddress() string
Ports() []uint16
Options() *viper.Viper
HandlerConfigs() []api.HandlerConfig
}
type multiHandlerConfig struct {
handlerName string
ports []uint16
listenAddress string
options *viper.Viper
}
func NewMultiHandlerConfig(handlerName string, ports []uint16, listenAddress string, options *viper.Viper) MultiHandlerConfig {
return &multiHandlerConfig{handlerName: handlerName, ports: ports, listenAddress: listenAddress, options: options}
}
func (m multiHandlerConfig) HandlerName() string {
return m.handlerName
}
func (m multiHandlerConfig) ListenAddress() string {
return m.listenAddress
}
func (m multiHandlerConfig) Ports() []uint16 {
return m.ports
}
func (m multiHandlerConfig) Options() *viper.Viper {
return m.options
}
func (m multiHandlerConfig) HandlerConfigs() []api.HandlerConfig {
configs := make([]api.HandlerConfig, 0)
for _, port := range m.ports {
configs = append(configs, api.NewHandlerConfig(m.handlerName, port, m.listenAddress, m.options))
}
return configs
}

View file

@ -1,226 +0,0 @@
package config
import (
"github.com/baez90/inetmock/pkg/api"
"github.com/spf13/viper"
"reflect"
"testing"
)
func Test_multiHandlerConfig_HandlerConfigs(t *testing.T) {
type fields struct {
handlerName string
ports []uint16
listenAddress string
options *viper.Viper
}
tests := []struct {
name string
fields fields
want []api.HandlerConfig
}{
{
name: "Get empty array if no ports are set",
fields: fields{},
want: make([]api.HandlerConfig, 0),
},
{
name: "Get a single handler config if only one port is set",
fields: fields{
handlerName: "sampleHandler",
ports: []uint16{80},
listenAddress: "0.0.0.0",
options: nil,
},
want: []api.HandlerConfig{
api.NewHandlerConfig("sampleHandler", 80, "0.0.0.0", nil),
},
},
{
name: "Get multiple handler configs if only one port is set",
fields: fields{
handlerName: "sampleHandler",
ports: []uint16{80, 8080},
listenAddress: "0.0.0.0",
options: nil,
},
want: []api.HandlerConfig{
api.NewHandlerConfig("sampleHandler", 80, "0.0.0.0", nil),
api.NewHandlerConfig("sampleHandler", 8080, "0.0.0.0", nil),
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
m := multiHandlerConfig{
handlerName: tt.fields.handlerName,
ports: tt.fields.ports,
listenAddress: tt.fields.listenAddress,
options: tt.fields.options,
}
if got := m.HandlerConfigs(); !reflect.DeepEqual(got, tt.want) {
t.Errorf("HandlerConfigs() = %v, want %v", got, tt.want)
}
})
}
}
func Test_multiHandlerConfig_HandlerName(t *testing.T) {
type fields struct {
handlerName string
ports []uint16
listenAddress string
options *viper.Viper
}
tests := []struct {
name string
fields fields
want string
}{
{
name: "Get empty handler name for uninitialized struct",
fields: fields{},
want: "",
},
{
name: "Get expected handler name for initialized struct",
fields: fields{
handlerName: "sampleHandler",
},
want: "sampleHandler",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
m := multiHandlerConfig{
handlerName: tt.fields.handlerName,
ports: tt.fields.ports,
listenAddress: tt.fields.listenAddress,
options: tt.fields.options,
}
if got := m.HandlerName(); got != tt.want {
t.Errorf("HandlerName() = %v, want %v", got, tt.want)
}
})
}
}
func Test_multiHandlerConfig_ListenAddress(t *testing.T) {
type fields struct {
handlerName string
ports []uint16
listenAddress string
options *viper.Viper
}
tests := []struct {
name string
fields fields
want string
}{
{
name: "Get empty ListenAddress for uninitialized struct",
fields: fields{},
want: "",
},
{
name: "Get expected ListenAddress for initialized struct",
fields: fields{
listenAddress: "0.0.0.0",
},
want: "0.0.0.0",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
m := multiHandlerConfig{
handlerName: tt.fields.handlerName,
ports: tt.fields.ports,
listenAddress: tt.fields.listenAddress,
options: tt.fields.options,
}
if got := m.ListenAddress(); got != tt.want {
t.Errorf("ListenAddress() = %v, want %v", got, tt.want)
}
})
}
}
func Test_multiHandlerConfig_Options(t *testing.T) {
type fields struct {
handlerName string
ports []uint16
listenAddress string
options *viper.Viper
}
tests := []struct {
name string
fields fields
want *viper.Viper
}{
{
name: "Get nil Options for uninitialized struct",
fields: fields{},
want: nil,
},
{
name: "Get expected Options for initialized struct",
fields: fields{
options: viper.New(),
},
want: viper.New(),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
m := multiHandlerConfig{
handlerName: tt.fields.handlerName,
ports: tt.fields.ports,
listenAddress: tt.fields.listenAddress,
options: tt.fields.options,
}
if got := m.Options(); !reflect.DeepEqual(got, tt.want) {
t.Errorf("Options() = %v, want %v", got, tt.want)
}
})
}
}
func Test_multiHandlerConfig_Ports(t *testing.T) {
type fields struct {
handlerName string
ports []uint16
listenAddress string
options *viper.Viper
}
tests := []struct {
name string
fields fields
want []uint16
}{
{
name: "Get empty Ports for uninitialized struct",
fields: fields{},
want: nil,
},
{
name: "Get expected Ports for initialized struct",
fields: fields{
ports: []uint16{80, 8080},
},
want: []uint16{80, 8080},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
m := multiHandlerConfig{
handlerName: tt.fields.handlerName,
ports: tt.fields.ports,
listenAddress: tt.fields.listenAddress,
options: tt.fields.options,
}
if got := m.Ports(); !reflect.DeepEqual(got, tt.want) {
t.Errorf("Ports() = %v, want %v", got, tt.want)
}
})
}
}

View file

@ -1,35 +0,0 @@
package config
import (
"github.com/spf13/viper"
)
const (
pluginConfigKey = "handler"
listenAddressConfigKey = "listenAddress"
portConfigKey = "port"
portsConfigKey = "ports"
)
func CreateMultiHandlerConfig(handlerConfig *viper.Viper) MultiHandlerConfig {
return NewMultiHandlerConfig(
handlerConfig.GetString(pluginConfigKey),
portsFromConfig(handlerConfig),
handlerConfig.GetString(listenAddressConfigKey),
handlerConfig.Sub(OptionsKey),
)
}
func portsFromConfig(handlerConfig *viper.Viper) (ports []uint16) {
if portsInt := handlerConfig.GetIntSlice(portsConfigKey); len(portsInt) > 0 {
for _, port := range portsInt {
ports = append(ports, uint16(port))
}
return
}
if portInt := handlerConfig.GetInt(portConfigKey); portInt > 0 {
ports = append(ports, uint16(portInt))
}
return
}

View file

@ -1,145 +0,0 @@
package config
import (
"bytes"
"github.com/spf13/viper"
"reflect"
"testing"
)
func TestCreateMultiHandlerConfig(t *testing.T) {
type args struct {
handlerConfig *viper.Viper
}
tests := []struct {
name string
args args
want MultiHandlerConfig
}{
{
name: "Get simple multiHandlerConfig from config",
args: args{
handlerConfig: configFromString(`
handler: sampleHandler
listenAddress: 0.0.0.0
ports:
- 80
- 8080
options: {}
`),
},
want: &multiHandlerConfig{
handlerName: "sampleHandler",
ports: []uint16{80, 8080},
listenAddress: "0.0.0.0",
options: viper.New(),
},
},
{
name: "Get more complex multiHandlerConfig from config",
args: args{
handlerConfig: configFromString(`
handler: sampleHandler
listenAddress: 0.0.0.0
ports:
- 80
- 8080
options:
optionA: asdf
optionB: as1234
`),
},
want: &multiHandlerConfig{
handlerName: "sampleHandler",
ports: []uint16{80, 8080},
listenAddress: "0.0.0.0",
options: configFromString(`
nesting:
optionA: asdf
optionB: as1234
`).Sub("nesting"),
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := CreateMultiHandlerConfig(tt.args.handlerConfig); !reflect.DeepEqual(got, tt.want) {
t.Errorf("CreateMultiHandlerConfig() = %v, want %v", got, tt.want)
}
})
}
}
func Test_portsFromConfig(t *testing.T) {
type args struct {
handlerConfig *viper.Viper
}
tests := []struct {
name string
args args
wantPorts []uint16
}{
{
name: "Empty array if config value is not set",
args: args{
handlerConfig: viper.New(),
},
wantPorts: nil,
},
{
name: "Array of one if `port` is set",
args: args{
handlerConfig: configFromString(`
port: 80
`),
},
wantPorts: []uint16{80},
},
{
name: "Array of one if `ports` is set as array",
args: args{
handlerConfig: configFromString(`
ports:
- 80
`),
},
wantPorts: []uint16{80},
},
{
name: "Array of two if `ports` is set as array",
args: args{
handlerConfig: configFromString(`
ports:
- 80
- 8080
`),
},
wantPorts: []uint16{80, 8080},
},
{
name: "Array of two if `port` is set as array",
args: args{
handlerConfig: configFromString(`
ports:
- 80
- 8080
`),
},
wantPorts: []uint16{80, 8080},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if gotPorts := portsFromConfig(tt.args.handlerConfig); !reflect.DeepEqual(gotPorts, tt.wantPorts) {
t.Errorf("portsFromConfig() = %v, want %v", gotPorts, tt.wantPorts)
}
})
}
}
func configFromString(yaml string) (config *viper.Viper) {
config = viper.New()
config.SetConfigType("yaml")
_ = config.ReadConfig(bytes.NewBufferString(yaml))
return
}

View file

@ -3,6 +3,7 @@ package endpoints
import (
"github.com/baez90/inetmock/pkg/api"
"github.com/baez90/inetmock/pkg/config"
)
type Endpoint interface {
@ -14,7 +15,7 @@ type Endpoint interface {
type endpoint struct {
name string
handler api.ProtocolHandler
config api.HandlerConfig
config config.HandlerConfig
}
func (e endpoint) Name() string {

View file

@ -2,8 +2,8 @@ package endpoints
import (
"fmt"
"github.com/baez90/inetmock/internal/config"
"github.com/baez90/inetmock/internal/plugins"
config2 "github.com/baez90/inetmock/pkg/config"
"github.com/baez90/inetmock/pkg/logging"
"go.uber.org/zap"
"sync"
@ -13,7 +13,7 @@ import (
type EndpointManager interface {
RegisteredEndpoints() []Endpoint
StartedEndpoints() []Endpoint
CreateEndpoint(name string, multiHandlerConfig config.MultiHandlerConfig) error
CreateEndpoint(name string, multiHandlerConfig config2.MultiHandlerConfig) error
StartEndpoints()
ShutdownEndpoints()
}
@ -40,16 +40,16 @@ func (e endpointManager) StartedEndpoints() []Endpoint {
return e.properlyStartedEndpoints
}
func (e *endpointManager) CreateEndpoint(name string, multiHandlerConfig config.MultiHandlerConfig) error {
func (e *endpointManager) CreateEndpoint(name string, multiHandlerConfig config2.MultiHandlerConfig) error {
for _, handlerConfig := range multiHandlerConfig.HandlerConfigs() {
if handler, ok := e.registry.HandlerForName(multiHandlerConfig.HandlerName()); ok {
if handler, ok := e.registry.HandlerForName(multiHandlerConfig.Handler); ok {
e.registeredEndpoints = append(e.registeredEndpoints, &endpoint{
name: name,
handler: handler,
config: handlerConfig,
})
} else {
return fmt.Errorf("no matching handler registered for names %s", multiHandlerConfig.HandlerName())
return fmt.Errorf("no matching handler registered for names %s", multiHandlerConfig.Handler)
}
}

View file

@ -1,11 +1,11 @@
package endpoints
import (
"github.com/baez90/inetmock/internal/config"
api_mock "github.com/baez90/inetmock/internal/mock/api"
logging_mock "github.com/baez90/inetmock/internal/mock/logging"
plugins_mock "github.com/baez90/inetmock/internal/mock/plugins"
"github.com/baez90/inetmock/internal/plugins"
config2 "github.com/baez90/inetmock/pkg/config"
"github.com/baez90/inetmock/pkg/logging"
"github.com/golang/mock/gomock"
"testing"
@ -20,7 +20,7 @@ func Test_endpointManager_CreateEndpoint(t *testing.T) {
}
type args struct {
name string
multiHandlerConfig config.MultiHandlerConfig
multiHandlerConfig config2.MultiHandlerConfig
}
tests := []struct {
name string
@ -52,12 +52,11 @@ func Test_endpointManager_CreateEndpoint(t *testing.T) {
},
args: args{
name: "sampleEndpoint",
multiHandlerConfig: config.NewMultiHandlerConfig(
"sampleHandler",
[]uint16{80},
"0.0.0.0",
nil,
),
multiHandlerConfig: config2.MultiHandlerConfig{
Handler: "sampleHandler",
Ports: []uint16{80},
ListenAddress: "0.0.0.0",
},
},
},
{
@ -83,12 +82,11 @@ func Test_endpointManager_CreateEndpoint(t *testing.T) {
},
args: args{
name: "sampleEndpoint",
multiHandlerConfig: config.NewMultiHandlerConfig(
"sampleHandler",
[]uint16{80},
"0.0.0.0",
nil,
),
multiHandlerConfig: config2.MultiHandlerConfig{
Handler: "sampleHandler",
Ports: []uint16{80},
ListenAddress: "0.0.0.0",
},
},
},
}

View file

@ -4,6 +4,7 @@ import (
"fmt"
api_mock "github.com/baez90/inetmock/internal/mock/api"
"github.com/baez90/inetmock/pkg/api"
"github.com/baez90/inetmock/pkg/config"
"github.com/golang/mock/gomock"
"testing"
)
@ -12,7 +13,7 @@ func Test_endpoint_Name(t *testing.T) {
type fields struct {
name string
handler api.ProtocolHandler
config api.HandlerConfig
config config.HandlerConfig
}
tests := []struct {
name string
@ -50,7 +51,7 @@ func Test_endpoint_Shutdown(t *testing.T) {
type fields struct {
name string
handler api.ProtocolHandler
config api.HandlerConfig
config config.HandlerConfig
}
tests := []struct {
name string
@ -102,17 +103,16 @@ func Test_endpoint_Shutdown(t *testing.T) {
func Test_endpoint_Start(t *testing.T) {
demoHandlerConfig := api.NewHandlerConfig(
"sampleHandler",
80,
"0.0.0.0",
nil,
)
demoHandlerConfig := config.HandlerConfig{
HandlerName: "sampleHandler",
Port: 80,
ListenAddress: "0.0.0.0",
}
type fields struct {
name string
handler api.ProtocolHandler
config api.HandlerConfig
config config.HandlerConfig
}
tests := []struct {
name string
@ -125,7 +125,7 @@ func Test_endpoint_Start(t *testing.T) {
handler: func() api.ProtocolHandler {
handler := api_mock.NewMockProtocolHandler(gomock.NewController(t))
handler.EXPECT().
Start(nil).
Start(gomock.Any()).
MaxTimes(1).
Return(nil)
return handler
@ -139,7 +139,7 @@ func Test_endpoint_Start(t *testing.T) {
handler: func() api.ProtocolHandler {
handler := api_mock.NewMockProtocolHandler(gomock.NewController(t))
handler.EXPECT().
Start(nil).
Start(gomock.Any()).
MaxTimes(1).
Return(fmt.Errorf(""))
return handler

View file

@ -10,6 +10,7 @@ import (
"path/filepath"
"plugin"
"regexp"
"strings"
)
var (
@ -19,6 +20,7 @@ var (
type HandlerRegistry interface {
LoadPlugins(pluginsPath string) error
AvailableHandlers() []string
RegisterHandler(handlerName string, handlerProvider api.PluginInstanceFactory, subCommands ...*cobra.Command)
HandlerForName(handlerName string) (api.ProtocolHandler, bool)
PluginCommands() []*cobra.Command
@ -29,11 +31,19 @@ type handlerRegistry struct {
pluginCommands []*cobra.Command
}
func (h handlerRegistry) AvailableHandlers() (availableHandlers []string) {
for s := range h.handlers {
availableHandlers = append(availableHandlers, s)
}
return
}
func (h handlerRegistry) PluginCommands() []*cobra.Command {
return h.pluginCommands
}
func (h *handlerRegistry) HandlerForName(handlerName string) (instance api.ProtocolHandler, ok bool) {
handlerName = strings.ToLower(handlerName)
var provider api.PluginInstanceFactory
if provider, ok = h.handlers[handlerName]; ok {
instance = provider()
@ -42,6 +52,7 @@ func (h *handlerRegistry) HandlerForName(handlerName string) (instance api.Proto
}
func (h *handlerRegistry) RegisterHandler(handlerName string, handlerProvider api.PluginInstanceFactory, subCommands ...*cobra.Command) {
handlerName = strings.ToLower(handlerName)
if _, exists := h.handlers[handlerName]; exists {
panic(fmt.Sprintf("handler with name %s is already registered - there's something strange...in the neighborhood", handlerName))
}

View file

@ -25,8 +25,8 @@ tls:
NotBeforeRelative: 168h
NotAfterRelative: 168h
rootCaCert:
publicKey: ./ca.pem
privateKey: ./ca.key
publicKeyPath: ./ca.pem
privateKeyPath: ./ca.key
certCachePath: /tmp/inetmock/
endpoints:
@ -41,7 +41,8 @@ endpoints:
proxy:
handler: http_proxy
listenAddress: 0.0.0.0
port: 3128
ports:
- 3128
options:
target:
ipAddress: 127.0.0.1
@ -59,7 +60,8 @@ endpoints:
plainDns:
handler: dns_mock
listenAddress: 0.0.0.0
port: 53
ports:
- 53
options:
rules:
- pattern: ".*\\.google\\.com"
@ -73,7 +75,8 @@ endpoints:
dnsOverTlsDowngrade:
handler: tls_interceptor
listenAddress: 0.0.0.0
port: 853
ports:
- 853
options:
target:
ipAddress: 127.0.0.1

View file

@ -1,42 +0,0 @@
package api
import "github.com/spf13/viper"
type HandlerConfig interface {
HandlerName() string
ListenAddress() string
Port() uint16
Options() *viper.Viper
}
func NewHandlerConfig(handlerName string, port uint16, listenAddress string, options *viper.Viper) HandlerConfig {
return &handlerConfig{
handlerName: handlerName,
port: port,
listenAddress: listenAddress,
options: options,
}
}
type handlerConfig struct {
handlerName string
port uint16
listenAddress string
options *viper.Viper
}
func (h handlerConfig) HandlerName() string {
return h.handlerName
}
func (h handlerConfig) ListenAddress() string {
return h.listenAddress
}
func (h handlerConfig) Port() uint16 {
return h.port
}
func (h handlerConfig) Options() *viper.Viper {
return h.options
}

View file

@ -2,6 +2,7 @@
package api
import (
"github.com/baez90/inetmock/pkg/config"
"go.uber.org/zap"
)
@ -10,6 +11,6 @@ type PluginInstanceFactory func() ProtocolHandler
type LoggingFactory func() (*zap.Logger, error)
type ProtocolHandler interface {
Start(config HandlerConfig) error
Start(config config.HandlerConfig) error
Shutdown() error
}

View file

@ -2,8 +2,8 @@ package api
import (
"github.com/baez90/inetmock/pkg/cert"
config2 "github.com/baez90/inetmock/pkg/config"
"github.com/baez90/inetmock/pkg/logging"
"github.com/spf13/viper"
)
var (
@ -19,7 +19,7 @@ type services struct {
}
func InitServices(
config *viper.Viper,
config config2.Config,
logger logging.Logger,
) error {
certStore, err := cert.NewDefaultStore(config, logger)

View file

@ -5,6 +5,7 @@ import (
"crypto/x509"
"fmt"
certmock "github.com/baez90/inetmock/internal/mock/cert"
config2 "github.com/baez90/inetmock/pkg/config"
"github.com/golang/mock/gomock"
"os"
"path"
@ -284,13 +285,13 @@ func Test_fileSystemCache_Put(t *testing.T) {
}
func setupCertGen() Generator {
return NewDefaultGenerator(Options{
Validity: ValidityByPurpose{
Server: ValidityDuration{
return NewDefaultGenerator(config2.CertOptions{
Validity: config2.ValidityByPurpose{
Server: config2.ValidityDuration{
NotBeforeRelative: serverRelativeValidity,
NotAfterRelative: serverRelativeValidity,
},
CA: ValidityDuration{
CA: config2.ValidityDuration{
NotBeforeRelative: caRelativeValidity,
NotAfterRelative: caRelativeValidity,
},

View file

@ -1,19 +1,11 @@
package cert
const (
CurveTypeP224 CurveType = "P224"
CurveTypeP256 CurveType = "P256"
CurveTypeP384 CurveType = "P384"
CurveTypeP521 CurveType = "P521"
CurveTypeED25519 CurveType = "ED25519"
defaultServerValidityDuration = "168h"
defaultCAValidityDuration = "17520h"
certCachePathConfigKey = "tls.certCachePath"
ecdsaCurveConfigKey = "tls.ecdsaCurve"
publicKeyConfigKey = "tls.rootCaCert.publicKey"
privateKeyPathConfigKey = "tls.rootCaCert.privateKey"
caCertValidityNotBeforeKey = "tls.validity.ca.notBeforeRelative"
caCertValidityNotAfterKey = "tls.validity.ca.notAfterRelative"
serverCertValidityNotBeforeKey = "tls.validity.server.notBeforeRelative"

View file

@ -8,6 +8,7 @@ import (
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
config2 "github.com/baez90/inetmock/pkg/config"
"github.com/baez90/inetmock/pkg/defaulting"
"math/big"
"net"
@ -35,11 +36,11 @@ type Generator interface {
ServerCert(options GenerationOptions, ca *tls.Certificate) (*tls.Certificate, error)
}
func NewDefaultGenerator(options Options) Generator {
func NewDefaultGenerator(options config2.CertOptions) Generator {
return NewGenerator(options, NewTimeSource(), defaultKeyProvider(options))
}
func NewGenerator(options Options, source TimeSource, provider KeyProvider) Generator {
func NewGenerator(options config2.CertOptions, source TimeSource, provider KeyProvider) Generator {
return &generator{
options: options,
provider: provider,
@ -48,7 +49,7 @@ func NewGenerator(options Options, source TimeSource, provider KeyProvider) Gene
}
type generator struct {
options Options
options config2.CertOptions
provider KeyProvider
timeSource TimeSource
outDir string

View file

@ -1,79 +1,15 @@
package cert
import (
"fmt"
"github.com/spf13/viper"
"github.com/baez90/inetmock/pkg/config"
"os"
"strings"
"time"
)
var (
requiredConfigKeys = []string{
publicKeyConfigKey,
privateKeyPathConfigKey,
}
)
type CurveType string
type File struct {
PublicKeyPath string
PrivateKeyPath string
}
type ValidityDuration struct {
NotBeforeRelative time.Duration
NotAfterRelative time.Duration
}
type ValidityByPurpose struct {
CA ValidityDuration
Server ValidityDuration
}
type Options struct {
RootCACert File
CertCachePath string
Curve CurveType
Validity ValidityByPurpose
}
func loadFromConfig(config *viper.Viper) (Options, error) {
missingKeys := make([]string, 0)
for _, requiredKey := range requiredConfigKeys {
if !config.IsSet(requiredKey) {
missingKeys = append(missingKeys, requiredKey)
}
}
if len(missingKeys) > 0 {
return Options{}, fmt.Errorf("config keys are missing: %s", strings.Join(missingKeys, ", "))
}
config.SetDefault(certCachePathConfigKey, os.TempDir())
config.SetDefault(ecdsaCurveConfigKey, string(CurveTypeED25519))
config.SetDefault(caCertValidityNotBeforeKey, defaultCAValidityDuration)
config.SetDefault(caCertValidityNotAfterKey, defaultCAValidityDuration)
config.SetDefault(serverCertValidityNotBeforeKey, defaultServerValidityDuration)
config.SetDefault(serverCertValidityNotAfterKey, defaultServerValidityDuration)
return Options{
CertCachePath: config.GetString(certCachePathConfigKey),
Curve: CurveType(config.GetString(ecdsaCurveConfigKey)),
Validity: ValidityByPurpose{
CA: ValidityDuration{
NotBeforeRelative: config.GetDuration(caCertValidityNotBeforeKey),
NotAfterRelative: config.GetDuration(caCertValidityNotAfterKey),
},
Server: ValidityDuration{
NotBeforeRelative: config.GetDuration(serverCertValidityNotBeforeKey),
NotAfterRelative: config.GetDuration(serverCertValidityNotAfterKey),
},
},
RootCACert: File{
PublicKeyPath: config.GetString(publicKeyConfigKey),
PrivateKeyPath: config.GetString(privateKeyPathConfigKey),
},
}, nil
func init() {
config.AddDefaultValue(certCachePathConfigKey, os.TempDir())
config.AddDefaultValue(ecdsaCurveConfigKey, string(config.CurveTypeED25519))
config.AddDefaultValue(caCertValidityNotBeforeKey, defaultCAValidityDuration)
config.AddDefaultValue(caCertValidityNotAfterKey, defaultCAValidityDuration)
config.AddDefaultValue(serverCertValidityNotBeforeKey, defaultServerValidityDuration)
config.AddDefaultValue(serverCertValidityNotAfterKey, defaultServerValidityDuration)
}

View file

@ -1,9 +1,9 @@
package cert
import (
"github.com/baez90/inetmock/pkg/config"
"github.com/spf13/viper"
"os"
"reflect"
"strings"
"testing"
"time"
@ -25,7 +25,7 @@ func Test_loadFromConfig(t *testing.T) {
tests := []struct {
name string
args args
want Options
want config.CertOptions
wantErr bool
}{
{
@ -48,19 +48,19 @@ tls:
certCachePath: /tmp/inetmock/
`),
},
want: Options{
RootCACert: File{
want: config.CertOptions{
RootCACert: config.File{
PublicKeyPath: "./ca.pem",
PrivateKeyPath: "./ca.key",
},
CertCachePath: "/tmp/inetmock/",
Curve: CurveTypeP256,
Validity: ValidityByPurpose{
CA: ValidityDuration{
Curve: config.CurveTypeP256,
Validity: config.ValidityByPurpose{
CA: config.ValidityDuration{
NotBeforeRelative: 17520 * time.Hour,
NotAfterRelative: 17520 * time.Hour,
},
Server: ValidityDuration{
Server: config.ValidityDuration{
NotBeforeRelative: 168 * time.Hour,
NotAfterRelative: 168 * time.Hour,
},
@ -76,7 +76,7 @@ tls:
privateKey: ./ca.key
`),
},
want: Options{},
want: config.CertOptions{},
wantErr: true,
},
{
@ -88,7 +88,7 @@ tls:
publicKey: ./ca.pem
`),
},
want: Options{},
want: config.CertOptions{},
wantErr: true,
},
{
@ -101,19 +101,19 @@ tls:
privateKey: ./ca.key
`),
},
want: Options{
RootCACert: File{
want: config.CertOptions{
RootCACert: config.File{
PublicKeyPath: "./ca.pem",
PrivateKeyPath: "./ca.key",
},
CertCachePath: os.TempDir(),
Curve: CurveTypeED25519,
Validity: ValidityByPurpose{
CA: ValidityDuration{
Curve: config.CurveTypeED25519,
Validity: config.ValidityByPurpose{
CA: config.ValidityDuration{
NotBeforeRelative: 17520 * time.Hour,
NotAfterRelative: 17520 * time.Hour,
},
Server: ValidityDuration{
Server: config.ValidityDuration{
NotBeforeRelative: 168 * time.Hour,
NotAfterRelative: 168 * time.Hour,
},
@ -124,14 +124,7 @@ tls:
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := loadFromConfig(tt.args.config)
if (err != nil) != tt.wantErr {
t.Errorf("loadFromConfig() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("loadFromConfig() got = %v, want %v", got, tt.want)
}
})
}
}

View file

@ -7,8 +7,8 @@ import (
"crypto/rand"
"crypto/tls"
"crypto/x509"
"github.com/baez90/inetmock/pkg/config"
"github.com/baez90/inetmock/pkg/logging"
"github.com/spf13/viper"
"go.uber.org/zap"
"net"
)
@ -18,7 +18,7 @@ const (
)
var (
defaultKeyProvider = func(options Options) func() (key interface{}, err error) {
defaultKeyProvider = func(options config.CertOptions) func() (key interface{}, err error) {
return func() (key interface{}, err error) {
return privateKeyForCurve(options)
}
@ -34,24 +34,20 @@ type Store interface {
}
func NewDefaultStore(
config *viper.Viper,
config config.Config,
logger logging.Logger,
) (Store, error) {
options, err := loadFromConfig(config)
if err != nil {
return nil, err
}
timeSource := NewTimeSource()
return NewStore(
options,
NewFileSystemCache(options.CertCachePath, timeSource),
NewDefaultGenerator(options),
config.TLSConfig(),
NewFileSystemCache(config.TLSConfig().CertCachePath, timeSource),
NewDefaultGenerator(config.TLSConfig()),
logger,
)
}
func NewStore(
options Options,
options config.CertOptions,
cache Cache,
generator Generator,
logger logging.Logger,
@ -71,7 +67,7 @@ func NewStore(
}
type store struct {
options Options
options config.CertOptions
caCert *tls.Certificate
cache Cache
timeSource TimeSource
@ -135,15 +131,15 @@ func (s *store) GetCertificate(serverName string, ip string) (cert *tls.Certific
return
}
func privateKeyForCurve(options Options) (privateKey interface{}, err error) {
func privateKeyForCurve(options config.CertOptions) (privateKey interface{}, err error) {
switch options.Curve {
case CurveTypeP224:
case config.CurveTypeP224:
privateKey, err = ecdsa.GenerateKey(elliptic.P224(), rand.Reader)
case CurveTypeP256:
case config.CurveTypeP256:
privateKey, err = ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
case CurveTypeP384:
case config.CurveTypeP384:
privateKey, err = ecdsa.GenerateKey(elliptic.P384(), rand.Reader)
case CurveTypeP521:
case config.CurveTypeP521:
privateKey, err = ecdsa.GenerateKey(elliptic.P521(), rand.Reader)
default:
_, privateKey, err = ed25519.GenerateKey(rand.Reader)

27
pkg/config/certs.go Normal file
View file

@ -0,0 +1,27 @@
package config
import "time"
type CurveType string
type File struct {
PublicKeyPath string
PrivateKeyPath string
}
type ValidityDuration struct {
NotBeforeRelative time.Duration
NotAfterRelative time.Duration
}
type ValidityByPurpose struct {
CA ValidityDuration
Server ValidityDuration
}
type CertOptions struct {
RootCACert File
CertCachePath string
Curve CurveType
Validity ValidityByPurpose
}

109
pkg/config/config.go Normal file
View file

@ -0,0 +1,109 @@
package config
import (
"github.com/baez90/inetmock/pkg/logging"
"github.com/baez90/inetmock/pkg/path"
"github.com/spf13/pflag"
"github.com/spf13/viper"
"go.uber.org/zap"
"strings"
)
var (
appConfig Config
)
func CreateConfig(flags *pflag.FlagSet) {
logger, _ := logging.CreateLogger()
configInstance := &config{
logger: logger.Named("Config"),
cfg: viper.New(),
}
configInstance.cfg.SetConfigName("config")
configInstance.cfg.SetConfigType("yaml")
configInstance.cfg.AddConfigPath("/etc/inetmock/")
configInstance.cfg.AddConfigPath("$HOME/.inetmock")
configInstance.cfg.AddConfigPath(".")
configInstance.cfg.SetEnvPrefix("INetMock")
_ = configInstance.cfg.BindPFlags(flags)
configInstance.cfg.SetEnvKeyReplacer(strings.NewReplacer("-", "_"))
configInstance.cfg.AutomaticEnv()
for k, v := range registeredDefaults {
configInstance.cfg.SetDefault(k, v)
}
for k, v := range registeredAliases {
configInstance.cfg.RegisterAlias(k, v)
}
appConfig = configInstance
}
func Instance() Config {
return appConfig
}
type Config interface {
ReadConfig(configFilePath string) error
ReadConfigString(config, format string) error
Viper() *viper.Viper
TLSConfig() CertOptions
PluginsDir() string
EndpointConfigs() map[string]MultiHandlerConfig
}
type config struct {
cfg *viper.Viper
logger logging.Logger
TLS CertOptions
PluginsDirectory string
Endpoints map[string]MultiHandlerConfig
}
func (c *config) ReadConfigString(config, format string) (err error) {
c.cfg.SetConfigType(format)
if err = c.cfg.ReadConfig(strings.NewReader(config)); err != nil {
return
}
err = c.cfg.Unmarshal(c)
return
}
func (c *config) EndpointConfigs() map[string]MultiHandlerConfig {
return c.Endpoints
}
func (c *config) PluginsDir() string {
return c.PluginsDirectory
}
func (c *config) TLSConfig() CertOptions {
return c.TLS
}
func (c *config) Viper() *viper.Viper {
return c.cfg
}
func (c *config) ReadConfig(configFilePath string) (err error) {
if configFilePath != "" && path.FileExists(configFilePath) {
c.logger.Info(
"loading config from passed config file path",
zap.String("configFilePath", configFilePath),
)
c.cfg.SetConfigFile(configFilePath)
}
if err = c.cfg.ReadInConfig(); err != nil {
if _, ok := err.(viper.ConfigFileNotFoundError); ok {
err = nil
c.logger.Warn("failed to load config")
}
}
err = c.cfg.Unmarshal(c)
return
}

88
pkg/config/config_test.go Normal file
View file

@ -0,0 +1,88 @@
package config
import (
"github.com/spf13/pflag"
"os"
"testing"
)
func Test_config_ReadConfig(t *testing.T) {
type args struct {
flags *pflag.FlagSet
config string
}
tests := []struct {
name string
args args
matcher func(Config) bool
wantErr bool
}{
{
name: "Test endpoints config",
args: args{
flags: pflag.NewFlagSet("", pflag.ContinueOnError),
config: `
endpoints:
plainHttp:
handler: http_mock
listenAddress: 0.0.0.0
ports:
- 80
- 8080
options: {}
proxy:
handler: http_proxy
listenAddress: 0.0.0.0
ports:
- 3128
options:
target:
ipAddress: 127.0.0.1
port: 80
`,
},
matcher: func(c Config) bool {
if len(c.EndpointConfigs()) < 1 {
t.Error("Expected EndpointConfigs to be set but is empty")
return false
}
return true
},
wantErr: false,
},
{
name: "Test loading of flags by adding",
args: args{
flags: func() *pflag.FlagSet {
flags := pflag.NewFlagSet("", pflag.ContinueOnError)
flags.String("plugins-directory", os.TempDir(), "")
return flags
}(),
config: ``,
},
matcher: func(c Config) bool {
if c.PluginsDir() != os.TempDir() {
t.Errorf("PluginsDir() Expected = %s, Got = %s", os.TempDir(), c.PluginsDir())
}
return true
},
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
CreateConfig(tt.args.flags)
cfg := Instance()
if err := cfg.ReadConfigString(tt.args.config, "yaml"); (err != nil) != tt.wantErr {
t.Errorf("ReadConfig() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !tt.matcher(cfg) {
t.Error("matcher error")
}
})
}
}

12
pkg/config/constants.go Normal file
View file

@ -0,0 +1,12 @@
package config
const (
EndpointsKey = "endpoints"
OptionsKey = "options"
CurveTypeP224 CurveType = "P224"
CurveTypeP256 CurveType = "P256"
CurveTypeP384 CurveType = "P384"
CurveTypeP521 CurveType = "P521"
CurveTypeED25519 CurveType = "ED25519"
)

17
pkg/config/defaults.go Normal file
View file

@ -0,0 +1,17 @@
package config
var (
registeredDefaults = make(map[string]interface{})
// default aliases
registeredAliases = map[string]string{
"PluginsDirectory": "plugins-directory",
}
)
func AddDefaultValue(key string, val interface{}) {
registeredDefaults[key] = val
}
func AddAlias(alias, orig string) {
registeredAliases[alias] = orig
}

View file

@ -0,0 +1,10 @@
package config
import "github.com/spf13/viper"
type HandlerConfig struct {
HandlerName string
Port uint16
ListenAddress string
Options *viper.Viper
}

View file

@ -1,4 +1,4 @@
package api
package config
import (
"github.com/spf13/viper"
@ -19,12 +19,12 @@ func Test_handlerConfig_HandlerName(t *testing.T) {
want string
}{
{
name: "Get empty HandlerName for uninitialized struct",
name: "Get empty Handler for uninitialized struct",
fields: fields{},
want: "",
},
{
name: "Get expected HandlerName for initialized struct",
name: "Get expected Handler for initialized struct",
fields: fields{
handlerName: "sampleHandler",
},
@ -33,14 +33,14 @@ func Test_handlerConfig_HandlerName(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
h := handlerConfig{
handlerName: tt.fields.handlerName,
port: tt.fields.port,
listenAddress: tt.fields.listenAddress,
options: tt.fields.options,
h := HandlerConfig{
HandlerName: tt.fields.handlerName,
Port: tt.fields.port,
ListenAddress: tt.fields.listenAddress,
Options: tt.fields.options,
}
if got := h.HandlerName(); got != tt.want {
t.Errorf("HandlerName() = %v, want %v", got, tt.want)
if got := h.HandlerName; got != tt.want {
t.Errorf("Handler() = %v, want %v", got, tt.want)
}
})
}
@ -73,13 +73,13 @@ func Test_handlerConfig_ListenAddress(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
h := handlerConfig{
handlerName: tt.fields.handlerName,
port: tt.fields.port,
listenAddress: tt.fields.listenAddress,
options: tt.fields.options,
h := HandlerConfig{
HandlerName: tt.fields.handlerName,
Port: tt.fields.port,
ListenAddress: tt.fields.listenAddress,
Options: tt.fields.options,
}
if got := h.ListenAddress(); got != tt.want {
if got := h.ListenAddress; got != tt.want {
t.Errorf("ListenAddress() = %v, want %v", got, tt.want)
}
})
@ -113,13 +113,13 @@ func Test_handlerConfig_Options(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
h := handlerConfig{
handlerName: tt.fields.handlerName,
port: tt.fields.port,
listenAddress: tt.fields.listenAddress,
options: tt.fields.options,
h := HandlerConfig{
HandlerName: tt.fields.handlerName,
Port: tt.fields.port,
ListenAddress: tt.fields.listenAddress,
Options: tt.fields.options,
}
if got := h.Options(); !reflect.DeepEqual(got, tt.want) {
if got := h.Options; !reflect.DeepEqual(got, tt.want) {
t.Errorf("Options() = %v, want %v", got, tt.want)
}
})
@ -153,13 +153,13 @@ func Test_handlerConfig_Port(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
h := handlerConfig{
handlerName: tt.fields.handlerName,
port: tt.fields.port,
listenAddress: tt.fields.listenAddress,
options: tt.fields.options,
h := HandlerConfig{
HandlerName: tt.fields.handlerName,
Port: tt.fields.port,
ListenAddress: tt.fields.listenAddress,
Options: tt.fields.options,
}
if got := h.Port(); got != tt.want {
if got := h.Port; got != tt.want {
t.Errorf("Port() = %v, want %v", got, tt.want)
}
})

View file

@ -0,0 +1,25 @@
package config
import (
"github.com/spf13/viper"
)
type MultiHandlerConfig struct {
Handler string
Ports []uint16
ListenAddress string
Options *viper.Viper
}
func (m MultiHandlerConfig) HandlerConfigs() []HandlerConfig {
configs := make([]HandlerConfig, 0)
for _, port := range m.Ports {
configs = append(configs, HandlerConfig{
HandlerName: m.Handler,
Port: port,
ListenAddress: m.ListenAddress,
Options: m.Options,
})
}
return configs
}

View file

@ -2,7 +2,7 @@ package main
import (
"fmt"
"github.com/baez90/inetmock/pkg/api"
"github.com/baez90/inetmock/pkg/config"
"github.com/baez90/inetmock/pkg/logging"
"github.com/miekg/dns"
"go.uber.org/zap"
@ -13,9 +13,9 @@ type dnsHandler struct {
dnsServer []*dns.Server
}
func (d *dnsHandler) Start(config api.HandlerConfig) (err error) {
options := loadFromConfig(config.Options())
addr := fmt.Sprintf("%s:%d", config.ListenAddress(), config.Port())
func (d *dnsHandler) Start(config config.HandlerConfig) (err error) {
options := loadFromConfig(config.Options)
addr := fmt.Sprintf("%s:%d", config.ListenAddress, config.Port)
handler := &regexHandler{
fallback: options.Fallback,

View file

@ -2,7 +2,7 @@ package main
import (
"fmt"
"github.com/baez90/inetmock/pkg/api"
"github.com/baez90/inetmock/pkg/config"
"github.com/baez90/inetmock/pkg/logging"
"go.uber.org/zap"
"net/http"
@ -18,9 +18,9 @@ type httpHandler struct {
server *http.Server
}
func (p *httpHandler) Start(config api.HandlerConfig) (err error) {
options := loadFromConfig(config.Options())
addr := fmt.Sprintf("%s:%d", config.ListenAddress(), config.Port())
func (p *httpHandler) Start(config config.HandlerConfig) (err error) {
options := loadFromConfig(config.Options)
addr := fmt.Sprintf("%s:%d", config.ListenAddress, config.Port)
p.server = &http.Server{Addr: addr, Handler: p.router}
p.logger = p.logger.With(
zap.String("address", addr),

View file

@ -3,6 +3,7 @@ package main
import (
"fmt"
"github.com/baez90/inetmock/pkg/api"
"github.com/baez90/inetmock/pkg/config"
"github.com/baez90/inetmock/pkg/logging"
"go.uber.org/zap"
"gopkg.in/elazarl/goproxy.v1"
@ -19,9 +20,9 @@ type httpProxy struct {
server *http.Server
}
func (h *httpProxy) Start(config api.HandlerConfig) (err error) {
options := loadFromConfig(config.Options())
addr := fmt.Sprintf("%s:%d", config.ListenAddress(), config.Port())
func (h *httpProxy) Start(config config.HandlerConfig) (err error) {
options := loadFromConfig(config.Options)
addr := fmt.Sprintf("%s:%d", config.ListenAddress, config.Port)
h.server = &http.Server{Addr: addr, Handler: h.proxy}
h.logger = h.logger.With(
zap.String("address", addr),

View file

@ -5,6 +5,7 @@ import (
"fmt"
"github.com/baez90/inetmock/pkg/api"
"github.com/baez90/inetmock/pkg/cert"
"github.com/baez90/inetmock/pkg/config"
"github.com/baez90/inetmock/pkg/logging"
"github.com/google/uuid"
"go.uber.org/zap"
@ -27,9 +28,9 @@ type tlsInterceptor struct {
currentConnections map[uuid.UUID]*proxyConn
}
func (t *tlsInterceptor) Start(config api.HandlerConfig) (err error) {
t.options = loadFromConfig(config.Options())
addr := fmt.Sprintf("%s:%d", config.ListenAddress(), config.Port())
func (t *tlsInterceptor) Start(config config.HandlerConfig) (err error) {
t.options = loadFromConfig(config.Options)
addr := fmt.Sprintf("%s:%d", config.ListenAddress, config.Port)
t.logger = t.logger.With(
zap.String("address", addr),

View file

@ -1,77 +0,0 @@
endpoints:
plainHttp:
handler: http_proxy
listenAddress: 0.0.0.0
port: 80
options:
rules:
- pattern: ".*\\.(?i)exe"
response: ./assets/fakeFiles/sample.exe
- pattern: ".*\\.(?i)(jpg|jpeg)"
response: ./assets/fakeFiles/default.jpg
- pattern: ".*\\.(?i)png"
response: ./assets/fakeFiles/default.png
- pattern: ".*\\.(?i)gif"
response: ./assets/fakeFiles/default.gif
- pattern: ".*\\.(?i)ico"
response: ./assets/fakeFiles/default.ico
- pattern: ".*\\.(?i)txt"
response: ./assets/fakeFiles/default.txt
- pattern: ".*"
response: ./assets/fakeFiles/default.html
httpsDowngrade:
handler: tls_interceptor
listenAddress: 0.0.0.0
port: 443
options:
ecdsaCurve: P256
validity:
ca:
notBeforeRelative: 17520h
notAfterRelative: 17520h
domain:
notBeforeRelative: 168h
notAfterRelative: 168h
rootCaCert:
publicKey: ./ca.pem
privateKey: ./ca.key
certCachePath: /tmp/inetmock/
target:
ipAddress: 127.0.0.1
port: 80
plainDns:
handler: dns_mock
listenAddress: 0.0.0.0
port: 53
options:
rules:
- pattern: "www.golem.de"
response: 77.247.84.129
- pattern: ".*\\.google\\.com"
response: 1.1.1.1
- pattern: ".*\\.reddit\\.com"
response: 2.2.2.2
fallback:
strategy: incremental
args:
startIP: 10.0.0.0
dnsOverTlsDowngrade:
handler: tls_interceptor
listenAddress: 0.0.0.0
port: 853
options:
ecdsaCurve: P256
validity:
ca:
notBeforeRelative: 17520h
notAfterRelative: 17520h
domain:
notBeforeRelative: 168h
notAfterRelative: 168h
rootCaCert:
publicKey: ./ca.pem
privateKey: ./ca.key
certCachePath: /tmp/inetmock/
target:
ipAddress: 127.0.0.1
port: 53