Move TLS/cert handling to main app

- apply changes in proxy plugin and TLS interceptor
- add HTTPS proxy support
- move ca-generation command to main app
- minor refactoring to improve API stability
- move mocks to extra packages to avoid cycling imports
- fix bug in multi-port configuration
- change HTTP proxy to redirect to HTTP mock instead of maintaining custom rules
This commit is contained in:
Peter 2020-04-26 00:22:45 +02:00
parent 43d3c62e01
commit 7c2a41ad25
65 changed files with 1722 additions and 1335 deletions

View file

@ -3,16 +3,22 @@
###############
.git/
plugins/
.idea/
.vscode/
deploy/
dist/
doc/
#########
# Files #
#########
*.so
*.out
main
inetmock
README.md
LICENSE
.dockerignore
.gitignore
Dockerfile

View file

@ -3,7 +3,7 @@
before:
hooks:
# You may remove this if you don't use go modules.
- make handlers
- make plugins
builds:
- id: "default"
ldflags:

View file

@ -38,7 +38,7 @@ WORKDIR /app
COPY --from=build /etc/passwd /etc/group /etc/
COPY --from=build --chown=$USER /work/inetmock ./
COPY --from=build --chown=$USER /work/plugins/ ./plugins/
COPY --from=build --chown=$USER /work/*.so ./plugins/
USER $USER:$USER

View file

@ -1,9 +1,8 @@
VERSst/pluginsION = $(shell git describe --dirty --tags --always)
VERSION = $(shell git describe --dirty --tags --always)
DIR = $(dir $(realpath $(firstword $(MAKEFILE_LIST))))
BUILD_PATH = $(DIR)/main.go
PKGS = $(shell go list ./...)
TEST_PKGS = $(shell find . -type f -name "*_test.go" -not -path "./plugins/*" -printf '%h\n' | sort -u)
GO_GEN_FILES = $(shell grep -rnwl --include="*.go" "go:generate" $(DIR))
TEST_PKGS = $(shell find . -type f -name "*_test.go" -not -path "./plugins/*" -not -path "*/mock/*" -printf '%h\n' | sort -u)
GOARGS = GOOS=linux GOARCH=amd64
GO_BUILD_ARGS = -ldflags="-w -s"
GO_CONTAINER_BUILD_ARGS = -ldflags="-w -s" -a -installsuffix cgo
@ -12,8 +11,10 @@ BINARY_NAME = inetmock
PLUGINS = $(wildcard $(DIR)plugins/*/.)
DEBUG_PORT = 2345
DEBUG_ARGS?= --development-logs=true
CONTAINER_BUILDER ?= podman
DOCKER_IMAGE ?= inetmock
.PHONY: clean all format deps update-deps compile debug generate snapshot-release test cli-cover-report html-cover-report plugins $(PLUGINS) $(.PHONY: clean all format deps update-deps compile debug generate snapshot-release test cli-cover-report html-cover-report plugins $(PLUGINS) $(GO_GEN_FILES)
.PHONY: clean all format deps update-deps compile debug generate snapshot-release test cli-cover-report html-cover-report plugins $(PLUGINS) $(GO_GEN_FILES)
all: clean format compile test plugins
clean:
@ -45,26 +46,26 @@ endif
debug: export INETMOCK_PLUGINS_DIRECTORY = $(DIR)
debug:
dlv exec $(DIR)$(BINARY_NAME) \
dlv debug $(DIR) \
--headless \
--listen=:2345 \
--api-version=2 \
--accept-multiclient \
-- $(DEBUG_ARGS)
generate:
@for go_gen_target in $(GO_GEN_FILES); do \
go generate $$go_gen_target; \
done
@go generate ./...
snapshot-release:
@goreleaser release --snapshot --skip-publish --rm-dist
container:
@$(CONTAINER_BUILDER) build -t $(DOCKER_IMAGE):latest -f $(DIR)Dockerfile $(DIR)
test:
@go test -coverprofile=./cov-raw.out -v $(TEST_PKGS)
@cat ./cov-raw.out | grep -v "generated" > ./cov.out
cli-cover-report:
cli-cover-report: test
@go tool cover -func=cov.out
html-cover-report: test

4
go.mod
View file

@ -13,8 +13,8 @@ require (
github.com/spf13/jwalterweatherman v1.1.0 // indirect
github.com/spf13/pflag v1.0.5
github.com/spf13/viper v1.6.3
go.uber.org/zap v1.14.1
golang.org/x/sys v0.0.0-20200409092240-59c9f1ba88fa // indirect
go.uber.org/zap v1.15.0
golang.org/x/sys v0.0.0-20200420163511-1957bb5e6d1f // indirect
golang.org/x/text v0.3.2 // indirect
golang.org/x/tools v0.0.0-20191216052735-49a3e744a425 // indirect
gopkg.in/ini.v1 v1.55.0 // indirect

8
go.sum
View file

@ -147,8 +147,8 @@ go.uber.org/multierr v1.5.0/go.mod h1:FeouvMocqHpRaaGuG9EjoKcStLC43Zu/fmqdUMPcKY
go.uber.org/tools v0.0.0-20190618225709-2cfd321de3ee h1:0mgffUl7nfd+FpvXMVz4IDEaUSmT1ysygQC7qYo7sG4=
go.uber.org/tools v0.0.0-20190618225709-2cfd321de3ee/go.mod h1:vJERXedbb3MVM5f9Ejo0C68/HhF8uaILCdgjnY+goOA=
go.uber.org/zap v1.10.0/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q=
go.uber.org/zap v1.14.1 h1:nYDKopTbvAPq/NrUVZwT15y2lpROBiLLyoRTbXOYWOo=
go.uber.org/zap v1.14.1/go.mod h1:Mb2vm2krFEG5DV0W9qcHBYFtp/Wku1cvYaqPsS/WYfc=
go.uber.org/zap v1.15.0 h1:ZZCA22JRF2gQE5FoNmhmrf7jeJJ2uhqDUNRYKm8dvmM=
go.uber.org/zap v1.15.0/go.mod h1:Mb2vm2krFEG5DV0W9qcHBYFtp/Wku1cvYaqPsS/WYfc=
golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
@ -178,8 +178,8 @@ golang.org/x/sys v0.0.0-20181116152217-5ac8a444bdc5/go.mod h1:STP8DvDyc/dI5b8T5h
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20191005200804-aed5e4c7ecf9/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200409092240-59c9f1ba88fa h1:mQTN3ECqfsViCNBgq+A40vdwhkGykrrQlYe3mPj6BoU=
golang.org/x/sys v0.0.0-20200409092240-59c9f1ba88fa/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200420163511-1957bb5e6d1f h1:gWF768j/LaZugp8dyS4UwsslYCYz9XgFxvlgsn0n9H8=
golang.org/x/sys v0.0.0-20200420163511-1957bb5e6d1f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.2 h1:tW2bmiBqwgJj/UpqtC8EpXEZVYOwU0yG4iWbprSVAcs=

144
internal/cmd/ca.go Normal file
View file

@ -0,0 +1,144 @@
package cmd
import (
"crypto/tls"
"crypto/x509"
"github.com/baez90/inetmock/pkg/cert"
"github.com/baez90/inetmock/pkg/logging"
"github.com/spf13/cobra"
"go.uber.org/zap"
"time"
)
const (
generateCACommonName = "cn"
generateCaOrganizationName = "o"
generateCaOrganizationalUnitName = "ou"
generateCaCountryName = "c"
generateCaLocalityName = "l"
generateCaStateName = "st"
generateCaStreetAddressName = "street-address"
generateCaPostalCodeName = "postal-code"
generateCACertOutPath = "out-dir"
generateCACurveName = "curve"
generateCANotBeforeRelative = "not-before"
generateCANotAfterRelative = "not-after"
)
var (
generateCaCmd *cobra.Command
caCertOptions cert.GenerationOptions
)
func init() {
generateCaCmd = &cobra.Command{
Use: "generate-ca",
Short: "Generate a new CA certificate and corresponding key",
Long: ``,
Run: runGenerateCA,
}
generateCaCmd.Flags().StringVar(&caCertOptions.CommonName, generateCACommonName, "INetMock", "Certificate Common Name that will also be used as file name during generation.")
generateCaCmd.Flags().StringSliceVar(&caCertOptions.Organization, generateCaOrganizationName, nil, "Organization information to append to certificate")
generateCaCmd.Flags().StringSliceVar(&caCertOptions.OrganizationalUnit, generateCaOrganizationalUnitName, nil, "Organizational unit information to append to certificate")
generateCaCmd.Flags().StringSliceVar(&caCertOptions.Country, generateCaCountryName, nil, "Country information to append to certificate")
generateCaCmd.Flags().StringSliceVar(&caCertOptions.Province, generateCaStateName, nil, "State information to append to certificate")
generateCaCmd.Flags().StringSliceVar(&caCertOptions.Locality, generateCaLocalityName, nil, "Locality information to append to certificate")
generateCaCmd.Flags().StringSliceVar(&caCertOptions.StreetAddress, generateCaStreetAddressName, nil, "Street address information to append to certificate")
generateCaCmd.Flags().StringSliceVar(&caCertOptions.PostalCode, generateCaPostalCodeName, nil, "Postal code information to append to certificate")
generateCaCmd.Flags().String(generateCACertOutPath, "", "Path where CA files should be stored")
generateCaCmd.Flags().String(generateCACurveName, "", "Name of the curve to use, if empty ED25519 is used, other valid values are [P224, P256,P384,P521]")
generateCaCmd.Flags().Duration(generateCANotBeforeRelative, 17520*time.Hour, "Relative time value since when in the past the CA certificate should be valid. The value has a time unit, the greatest time unit is h for hour.")
generateCaCmd.Flags().Duration(generateCANotAfterRelative, 17520*time.Hour, "Relative time value until when in the future the CA certificate should be valid. The value has a time unit, the greatest time unit is h for hour.")
}
func runGenerateCA(cmd *cobra.Command, args []string) {
var certOutPath, curveName string
var notBefore, notAfter time.Duration
var err error
if certOutPath, err = getStringFlag(generateCaCmd, generateCACertOutPath, logger); err != nil {
return
}
if curveName, err = getStringFlag(generateCaCmd, generateCACurveName, logger); err != nil {
return
}
if notBefore, err = getDurationFlag(generateCaCmd, generateCANotBeforeRelative, logger); err != nil {
return
}
if notAfter, err = getDurationFlag(generateCaCmd, generateCANotAfterRelative, logger); err != nil {
return
}
logger, _ := logging.CreateLogger()
logger = logger.With(
zap.String(generateCACurveName, curveName),
zap.String(generateCACertOutPath, certOutPath),
)
generator := cert.NewDefaultGenerator(cert.Options{
CertCachePath: certOutPath,
Curve: cert.CurveType(curveName),
Validity: cert.ValidityByPurpose{
CA: cert.ValidityDuration{
NotAfterRelative: notAfter,
NotBeforeRelative: notBefore,
},
},
})
var caCrt *tls.Certificate
if caCrt, err = generator.CACert(caCertOptions); err != nil {
logger.Error(
"failed to generate CA certificate",
zap.Error(err),
)
return
}
if len(caCrt.Certificate) < 1 {
logger.Error("no public key given for generated CA certificate")
return
}
var pubKey *x509.Certificate
if pubKey, err = x509.ParseCertificate(caCrt.Certificate[0]); err != nil {
logger.Error(
"failed to parse public key from generated CA",
zap.Error(err),
)
return
}
pemCrt := cert.NewPEM(caCrt)
if err = pemCrt.Write(pubKey.Subject.CommonName, certOutPath); err != nil {
logger.Error(
"failed to write Ca files",
zap.Error(err),
)
}
logger.Info("completed certificate generation")
}
func getDurationFlag(cmd *cobra.Command, flagName string, logger logging.Logger) (val time.Duration, err error) {
if val, err = cmd.Flags().GetDuration(flagName); err != nil {
logger.Error(
"failed to parse parse flag",
zap.String("flag", flagName),
zap.Error(err),
)
}
return
}
func getStringFlag(cmd *cobra.Command, flagName string, logger logging.Logger) (val string, err error) {
if val, err = cmd.Flags().GetString(flagName); err != nil {
logger.Error(
"failed to parse parse flag",
zap.String("flag", flagName),
zap.Error(err),
)
}
return
}

View file

@ -3,6 +3,7 @@ package cmd
import (
"github.com/baez90/inetmock/internal/endpoints"
"github.com/baez90/inetmock/internal/plugins"
"github.com/baez90/inetmock/pkg/api"
"github.com/baez90/inetmock/pkg/logging"
"github.com/baez90/inetmock/pkg/path"
"github.com/spf13/viper"
@ -29,9 +30,8 @@ func initApp() (err error) {
registry := plugins.Registry()
endpointManager = endpoints.NewEndpointManager(logger)
if err = rootCmd.ParseFlags(os.Args); err != nil {
return
}
rootCmd.Flags().ParseErrorsWhitelist.UnknownFlags = false
_ = rootCmd.ParseFlags(os.Args)
if err = appConfig.ReadConfig(configFilePath); err != nil {
logger.Error(
@ -43,6 +43,14 @@ func initApp() (err error) {
viperInst := viper.GetViper()
pluginDir := viperInst.GetString("plugins-directory")
if err = api.InitServices(viperInst, logger); err != nil {
logger.Error(
"failed to initialize app services",
zap.Error(err),
)
}
pluginLoadStartTime := time.Now()
if err = registry.LoadPlugins(pluginDir); err != nil {
logger.Error("Failed to load plugins",

View file

@ -18,5 +18,5 @@ The easiest way to explore what commands are available is to start with 'inetmoc
This help page contains a list of available sub-commands starting with the name of the plugin as a prefix.
`,
}
rootCmd.AddCommand(pluginsCmd)
rootCmd.AddCommand(pluginsCmd, generateCaCmd)
}

View file

@ -1,6 +1,7 @@
package config
import (
"github.com/baez90/inetmock/pkg/api"
"github.com/spf13/viper"
)
@ -9,7 +10,7 @@ type MultiHandlerConfig interface {
ListenAddress() string
Ports() []uint16
Options() *viper.Viper
HandlerConfigs() []HandlerConfig
HandlerConfigs() []api.HandlerConfig
}
type multiHandlerConfig struct {
@ -39,10 +40,10 @@ func (m multiHandlerConfig) Options() *viper.Viper {
return m.options
}
func (m multiHandlerConfig) HandlerConfigs() []HandlerConfig {
configs := make([]HandlerConfig, 0)
func (m multiHandlerConfig) HandlerConfigs() []api.HandlerConfig {
configs := make([]api.HandlerConfig, 0)
for _, port := range m.ports {
configs = append(configs, NewHandlerConfig(m.handlerName, port, m.listenAddress, m.options))
configs = append(configs, api.NewHandlerConfig(m.handlerName, port, m.listenAddress, m.options))
}
return configs
}

View file

@ -1,6 +1,7 @@
package config
import (
"github.com/baez90/inetmock/pkg/api"
"github.com/spf13/viper"
"reflect"
"testing"
@ -16,12 +17,12 @@ func Test_multiHandlerConfig_HandlerConfigs(t *testing.T) {
tests := []struct {
name string
fields fields
want []HandlerConfig
want []api.HandlerConfig
}{
{
name: "Get empty array if no ports are set",
fields: fields{},
want: make([]HandlerConfig, 0),
want: make([]api.HandlerConfig, 0),
},
{
name: "Get a single handler config if only one port is set",
@ -31,13 +32,8 @@ func Test_multiHandlerConfig_HandlerConfigs(t *testing.T) {
listenAddress: "0.0.0.0",
options: nil,
},
want: []HandlerConfig{
&handlerConfig{
handlerName: "sampleHandler",
port: 80,
listenAddress: "0.0.0.0",
options: nil,
},
want: []api.HandlerConfig{
api.NewHandlerConfig("sampleHandler", 80, "0.0.0.0", nil),
},
},
{
@ -48,19 +44,9 @@ func Test_multiHandlerConfig_HandlerConfigs(t *testing.T) {
listenAddress: "0.0.0.0",
options: nil,
},
want: []HandlerConfig{
&handlerConfig{
handlerName: "sampleHandler",
port: 80,
listenAddress: "0.0.0.0",
options: nil,
},
&handlerConfig{
handlerName: "sampleHandler",
port: 8080,
listenAddress: "0.0.0.0",
options: nil,
},
want: []api.HandlerConfig{
api.NewHandlerConfig("sampleHandler", 80, "0.0.0.0", nil),
api.NewHandlerConfig("sampleHandler", 8080, "0.0.0.0", nil),
},
},
}

View file

@ -1,6 +1,15 @@
package config
import "github.com/spf13/viper"
import (
"github.com/spf13/viper"
)
const (
pluginConfigKey = "handler"
listenAddressConfigKey = "listenAddress"
portConfigKey = "port"
portsConfigKey = "ports"
)
func CreateMultiHandlerConfig(handlerConfig *viper.Viper) MultiHandlerConfig {
return NewMultiHandlerConfig(

View file

@ -1,8 +1,7 @@
//go:generate mockgen -source=endpoint.go -destination=./../../internal/mock/endpoint_mock.go -package=mock
//go:generate mockgen -source=endpoint.go -destination=./../../internal/mock/endpoints/endpoint_mock.go -package=endpoints_mock
package endpoints
import (
"github.com/baez90/inetmock/internal/config"
"github.com/baez90/inetmock/pkg/api"
)
@ -15,7 +14,7 @@ type Endpoint interface {
type endpoint struct {
name string
handler api.ProtocolHandler
config config.HandlerConfig
config api.HandlerConfig
}
func (e endpoint) Name() string {

View file

@ -40,19 +40,20 @@ func (e endpointManager) StartedEndpoints() []Endpoint {
return e.properlyStartedEndpoints
}
func (e *endpointManager) CreateEndpoint(name string, multiHandlerConfig config.MultiHandlerConfig) (err error) {
if handler, ok := e.registry.HandlerForName(multiHandlerConfig.HandlerName()); ok {
func (e *endpointManager) CreateEndpoint(name string, multiHandlerConfig config.MultiHandlerConfig) error {
for _, handlerConfig := range multiHandlerConfig.HandlerConfigs() {
if handler, ok := e.registry.HandlerForName(multiHandlerConfig.HandlerName()); ok {
e.registeredEndpoints = append(e.registeredEndpoints, &endpoint{
name: name,
handler: handler,
config: handlerConfig,
})
}
} else {
err = fmt.Errorf("no matching handler registered for name %s", multiHandlerConfig.HandlerName())
return fmt.Errorf("no matching handler registered for names %s", multiHandlerConfig.HandlerName())
}
return
}
return nil
}
func (e *endpointManager) StartEndpoints() {

View file

@ -2,7 +2,9 @@ package endpoints
import (
"github.com/baez90/inetmock/internal/config"
"github.com/baez90/inetmock/internal/mock"
api_mock "github.com/baez90/inetmock/internal/mock/api"
logging_mock "github.com/baez90/inetmock/internal/mock/logging"
plugins_mock "github.com/baez90/inetmock/internal/mock/plugins"
"github.com/baez90/inetmock/internal/plugins"
"github.com/baez90/inetmock/pkg/logging"
"github.com/golang/mock/gomock"
@ -33,18 +35,18 @@ func Test_endpointManager_CreateEndpoint(t *testing.T) {
wantEndpoints: 1,
fields: fields{
logger: func() logging.Logger {
return mock.NewMockLogger(gomock.NewController(t))
return logging_mock.NewMockLogger(gomock.NewController(t))
}(),
registeredEndpoints: nil,
properlyStartedEndpoints: nil,
registry: func() plugins.HandlerRegistry {
registry := mock.NewMockHandlerRegistry(gomock.NewController(t))
registry := plugins_mock.NewMockHandlerRegistry(gomock.NewController(t))
registry.
EXPECT().
HandlerForName("sampleHandler").
MinTimes(1).
MaxTimes(1).
Return(mock.NewMockProtocolHandler(gomock.NewController(t)), true)
Return(api_mock.NewMockProtocolHandler(gomock.NewController(t)), true)
return registry
}(),
},
@ -64,12 +66,12 @@ func Test_endpointManager_CreateEndpoint(t *testing.T) {
wantEndpoints: 0,
fields: fields{
logger: func() logging.Logger {
return mock.NewMockLogger(gomock.NewController(t))
return logging_mock.NewMockLogger(gomock.NewController(t))
}(),
registeredEndpoints: nil,
properlyStartedEndpoints: nil,
registry: func() plugins.HandlerRegistry {
registry := mock.NewMockHandlerRegistry(gomock.NewController(t))
registry := plugins_mock.NewMockHandlerRegistry(gomock.NewController(t))
registry.
EXPECT().
HandlerForName("sampleHandler").

View file

@ -2,8 +2,7 @@ package endpoints
import (
"fmt"
"github.com/baez90/inetmock/internal/config"
"github.com/baez90/inetmock/internal/mock"
api_mock "github.com/baez90/inetmock/internal/mock/api"
"github.com/baez90/inetmock/pkg/api"
"github.com/golang/mock/gomock"
"testing"
@ -13,7 +12,7 @@ func Test_endpoint_Name(t *testing.T) {
type fields struct {
name string
handler api.ProtocolHandler
config config.HandlerConfig
config api.HandlerConfig
}
tests := []struct {
name string
@ -51,7 +50,7 @@ func Test_endpoint_Shutdown(t *testing.T) {
type fields struct {
name string
handler api.ProtocolHandler
config config.HandlerConfig
config api.HandlerConfig
}
tests := []struct {
name string
@ -62,7 +61,7 @@ func Test_endpoint_Shutdown(t *testing.T) {
name: "Expect no error if mocked handler does not return one",
fields: fields{
handler: func() api.ProtocolHandler {
handler := mock.NewMockProtocolHandler(gomock.NewController(t))
handler := api_mock.NewMockProtocolHandler(gomock.NewController(t))
handler.EXPECT().
Shutdown().
MaxTimes(1).
@ -76,7 +75,7 @@ func Test_endpoint_Shutdown(t *testing.T) {
name: "Expect error if mocked handler returns one",
fields: fields{
handler: func() api.ProtocolHandler {
handler := mock.NewMockProtocolHandler(gomock.NewController(t))
handler := api_mock.NewMockProtocolHandler(gomock.NewController(t))
handler.EXPECT().
Shutdown().
MaxTimes(1).
@ -103,7 +102,7 @@ func Test_endpoint_Shutdown(t *testing.T) {
func Test_endpoint_Start(t *testing.T) {
demoHandlerConfig := config.NewHandlerConfig(
demoHandlerConfig := api.NewHandlerConfig(
"sampleHandler",
80,
"0.0.0.0",
@ -113,7 +112,7 @@ func Test_endpoint_Start(t *testing.T) {
type fields struct {
name string
handler api.ProtocolHandler
config config.HandlerConfig
config api.HandlerConfig
}
tests := []struct {
name string
@ -124,7 +123,7 @@ func Test_endpoint_Start(t *testing.T) {
name: "Expect no error if mocked handler does not return one",
fields: fields{
handler: func() api.ProtocolHandler {
handler := mock.NewMockProtocolHandler(gomock.NewController(t))
handler := api_mock.NewMockProtocolHandler(gomock.NewController(t))
handler.EXPECT().
Start(nil).
MaxTimes(1).
@ -138,7 +137,7 @@ func Test_endpoint_Start(t *testing.T) {
name: "Expect error if mocked handler returns one",
fields: fields{
handler: func() api.ProtocolHandler {
handler := mock.NewMockProtocolHandler(gomock.NewController(t))
handler := api_mock.NewMockProtocolHandler(gomock.NewController(t))
handler.EXPECT().
Start(nil).
MaxTimes(1).
@ -153,7 +152,7 @@ func Test_endpoint_Start(t *testing.T) {
fields: fields{
config: demoHandlerConfig,
handler: func() api.ProtocolHandler {
handler := mock.NewMockProtocolHandler(gomock.NewController(t))
handler := api_mock.NewMockProtocolHandler(gomock.NewController(t))
handler.EXPECT().
Start(demoHandlerConfig).
MaxTimes(1).

View file

@ -1,4 +1,4 @@
//go:generate mockgen -source=loading.go -destination=./../../internal/mock/handler_registry_mock.go -package=mock
//go:generate mockgen -source=loading.go -destination=./../../internal/mock/plugins/handler_registry_mock.go -package=plugins_mock
package plugins
import (

View file

@ -15,15 +15,15 @@ x-response-rules: &httpResponseRules
- pattern: ".*"
response: ./assets/fakeFiles/default.html
x-tls-options: &tlsOptions
tls:
ecdsaCurve: P256
validity:
ca:
notBeforeRelative: 17520h
notAfterRelative: 17520h
domain:
notBeforeRelative: 168h
notAfterRelative: 168h
server:
NotBeforeRelative: 168h
NotAfterRelative: 168h
rootCaCert:
publicKey: ./ca.pem
privateKey: ./ca.key
@ -43,8 +43,9 @@ endpoints:
listenAddress: 0.0.0.0
port: 3128
options:
fallback: notfound
<<: *httpResponseRules
target:
ipAddress: 127.0.0.1
port: 80
httpsDowngrade:
handler: tls_interceptor
listenAddress: 0.0.0.0
@ -52,7 +53,6 @@ endpoints:
- 443
- 8443
options:
<<: *tlsOptions
target:
ipAddress: 127.0.0.1
port: 80
@ -75,7 +75,6 @@ endpoints:
listenAddress: 0.0.0.0
port: 853
options:
<<: *tlsOptions
target:
ipAddress: 127.0.0.1
port: 53

View file

@ -1,14 +1,7 @@
package config
package api
import "github.com/spf13/viper"
const (
pluginConfigKey = "handler"
listenAddressConfigKey = "listenAddress"
portConfigKey = "port"
portsConfigKey = "ports"
)
type HandlerConfig interface {
HandlerName() string
ListenAddress() string

View file

@ -1,4 +1,4 @@
package config
package api
import (
"github.com/spf13/viper"

View file

@ -1,8 +1,7 @@
//go:generate mockgen -source=protocol_handler.go -destination=./../../internal/mock/protocol_handler_mock.go -package=mock
//go:generate mockgen -source=protocol_handler.go -destination=./../../internal/mock/api/protocol_handler_mock.go -package=api_mock
package api
import (
"github.com/baez90/inetmock/internal/config"
"go.uber.org/zap"
)
@ -11,6 +10,6 @@ type PluginInstanceFactory func() ProtocolHandler
type LoggingFactory func() (*zap.Logger, error)
type ProtocolHandler interface {
Start(config config.HandlerConfig) error
Start(config HandlerConfig) error
Shutdown() error
}

41
pkg/api/services.go Normal file
View file

@ -0,0 +1,41 @@
package api
import (
"github.com/baez90/inetmock/pkg/cert"
"github.com/baez90/inetmock/pkg/logging"
"github.com/spf13/viper"
)
var (
svcs Services
)
type Services interface {
CertStore() cert.Store
}
type services struct {
certStore cert.Store
}
func InitServices(
config *viper.Viper,
logger logging.Logger,
) error {
certStore, err := cert.NewDefaultStore(config, logger)
if err != nil {
return err
}
svcs = &services{
certStore: certStore,
}
return nil
}
func ServicesInstance() Services {
return svcs
}
func (s *services) CertStore() cert.Store {
return s.certStore
}

22
pkg/cert/addr_utils.go Normal file
View file

@ -0,0 +1,22 @@
package cert
import (
"fmt"
"strings"
)
func extractIPFromAddress(addr string) (ip string, err error) {
if idx := strings.LastIndex(addr, ":"); idx < 0 {
err = fmt.Errorf("addr %s does not match expected scheme <ip>:<port>", addr)
} else {
/* get IP part of address */
ip = addr[0:idx]
/* trim [ ] for IPv6 addresses */
if ip[0] == '[' {
ip = ip[1 : len(ip)-1]
}
}
return
}

View file

@ -1,4 +1,4 @@
package main
package cert
import "testing"

78
pkg/cert/cache.go Normal file
View file

@ -0,0 +1,78 @@
//go:generate mockgen -source=cache.go -destination=./../../internal/mock/cert/cert_cache_mock.go -package=cert_mock
package cert
import (
"crypto/tls"
"crypto/x509"
"errors"
)
type Cache interface {
Put(cert *tls.Certificate) error
Get(cn string) (*tls.Certificate, bool)
}
func NewFileSystemCache(certCachePath string, source TimeSource) Cache {
return &fileSystemCache{
certCachePath: certCachePath,
inMemCache: make(map[string]*tls.Certificate),
timeSource: source,
}
}
type fileSystemCache struct {
certCachePath string
inMemCache map[string]*tls.Certificate
timeSource TimeSource
}
func (f *fileSystemCache) Put(cert *tls.Certificate) (err error) {
if cert == nil {
err = errors.New("cert may not be nil")
return
}
var cn string
if len(cert.Certificate) > 0 {
if pubKey, err := x509.ParseCertificate(cert.Certificate[0]); err != nil {
return err
} else {
cn = pubKey.Subject.CommonName
}
f.inMemCache[cn] = cert
pemCrt := NewPEM(cert)
err = pemCrt.Write(cn, f.certCachePath)
} else {
err = errors.New("no public key present for certificate")
}
return
}
func (f *fileSystemCache) Get(cn string) (*tls.Certificate, bool) {
if crt, ok := f.inMemCache[cn]; ok {
return crt, true
}
pemCrt := NewPEM(nil)
if err := pemCrt.Read(cn, f.certCachePath); err != nil || pemCrt.Cert() == nil {
return nil, false
}
x509Cert, err := x509.ParseCertificate(pemCrt.Cert().Certificate[0])
if err == nil && !certShouldBeRenewed(f.timeSource, x509Cert) {
return pemCrt.Cert(), true
}
return nil, false
}
func certShouldBeRenewed(timeSource TimeSource, cert *x509.Certificate) bool {
lifetime := cert.NotAfter.Sub(cert.NotBefore)
// if the cert is closer to the end of the lifetime than lifetime/2 it should be renewed
if cert.NotAfter.Sub(timeSource.UTCNow()) < lifetime/4 {
return true
}
return false
}

299
pkg/cert/cache_test.go Normal file
View file

@ -0,0 +1,299 @@
package cert
import (
"crypto/tls"
"crypto/x509"
"fmt"
certmock "github.com/baez90/inetmock/internal/mock/cert"
"github.com/golang/mock/gomock"
"os"
"path"
"reflect"
"testing"
"time"
)
const (
cnLocalhost = "localhost"
caCN = "UnitTests"
serverRelativeValidity = 24 * time.Hour
caRelativeValidity = 168 * time.Hour
)
var (
serverCN = fmt.Sprintf("%s-%d", cnLocalhost, time.Now().Unix())
)
func Test_certShouldBeRenewed(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
type args struct {
timeSource TimeSource
cert *x509.Certificate
}
tests := []struct {
name string
args args
want bool
}{
{
name: "Expect cert should not be renewed right after creation",
args: args{
timeSource: func() TimeSource {
tsMock := certmock.NewMockTimeSource(ctrl)
tsMock.
EXPECT().
UTCNow().
Return(time.Now().UTC()).
Times(1)
return tsMock
}(),
cert: &x509.Certificate{
NotAfter: time.Now().UTC().Add(serverRelativeValidity),
NotBefore: time.Now().UTC().Add(-serverRelativeValidity),
},
},
want: false,
},
{
name: "Expect cert should be renewed if the remaining lifetime is less than a quarter of the total lifetime",
args: args{
timeSource: func() TimeSource {
tsMock := certmock.NewMockTimeSource(ctrl)
tsMock.
EXPECT().
UTCNow().
Return(time.Now().UTC().Add(serverRelativeValidity/2 + 1*time.Hour)).
Times(1)
return tsMock
}(),
cert: &x509.Certificate{
NotAfter: time.Now().UTC().Add(serverRelativeValidity),
NotBefore: time.Now().UTC().Add(-serverRelativeValidity),
},
},
want: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := certShouldBeRenewed(tt.args.timeSource, tt.args.cert); got != tt.want {
t.Errorf("certShouldBeRenewed() = %v, want %v", got, tt.want)
}
})
}
}
func Test_fileSystemCache_Get(t *testing.T) {
type fields struct {
certCachePath string
inMemCache map[string]*tls.Certificate
timeSource TimeSource
}
type args struct {
cn string
}
tests := []struct {
name string
fields fields
args args
wantOk bool
}{
{
name: "Get a miss when no cert is present",
fields: fields{
certCachePath: os.TempDir(),
inMemCache: make(map[string]*tls.Certificate),
timeSource: NewTimeSource(),
},
args: args{
cnLocalhost,
},
wantOk: false,
},
{
name: "Get a prepared certificate from the memory cache",
fields: func() fields {
certGen := setupCertGen()
caCrt, _ := certGen.CACert(GenerationOptions{
CommonName: caCN,
})
srvCrt, _ := certGen.ServerCert(GenerationOptions{
CommonName: cnLocalhost,
}, caCrt)
return fields{
certCachePath: os.TempDir(),
inMemCache: map[string]*tls.Certificate{
cnLocalhost: srvCrt,
},
timeSource: NewTimeSource(),
}
}(),
args: args{
cn: cnLocalhost,
},
wantOk: true,
},
{
name: "Get a prepared certificate from the file system",
fields: func() fields {
certGen := setupCertGen()
caCrt, _ := certGen.CACert(GenerationOptions{
CommonName: "INetMock",
})
srvCrt, _ := certGen.ServerCert(GenerationOptions{
CommonName: serverCN,
}, caCrt)
pem := NewPEM(srvCrt)
if err := pem.Write(serverCN, os.TempDir()); err != nil {
panic(err)
}
t.Cleanup(func() {
for _, f := range []string{
path.Join(os.TempDir(), fmt.Sprintf("%s.pem", serverCN)),
path.Join(os.TempDir(), fmt.Sprintf("%s.key", serverCN)),
} {
_ = os.Remove(f)
}
})
return fields{
certCachePath: os.TempDir(),
inMemCache: make(map[string]*tls.Certificate),
timeSource: NewTimeSource(),
}
}(),
args: args{
cn: serverCN,
},
wantOk: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
f := &fileSystemCache{
certCachePath: tt.fields.certCachePath,
inMemCache: tt.fields.inMemCache,
timeSource: tt.fields.timeSource,
}
gotCrt, gotOk := f.Get(tt.args.cn)
if gotOk && (gotCrt == nil || !reflect.DeepEqual(reflect.TypeOf(new(tls.Certificate)), reflect.TypeOf(gotCrt))) {
t.Errorf("Wanted propert certificate but got %v", gotCrt)
}
if gotOk != tt.wantOk {
t.Errorf("Get() gotOk = %v, want %v", gotOk, tt.wantOk)
}
})
}
}
func Test_fileSystemCache_Put(t *testing.T) {
type fields struct {
certCachePath string
inMemCache map[string]*tls.Certificate
timeSource TimeSource
}
type args struct {
cert *tls.Certificate
}
tests := []struct {
name string
fields fields
args args
wantErr bool
}{
{
name: "Want error if nil cert is passed to put",
fields: fields{
certCachePath: os.TempDir(),
inMemCache: make(map[string]*tls.Certificate),
timeSource: NewTimeSource(),
},
args: args{
cert: nil,
},
wantErr: true,
},
{
name: "Want error if empty cert is passed to put",
fields: fields{
certCachePath: os.TempDir(),
inMemCache: make(map[string]*tls.Certificate),
timeSource: NewTimeSource(),
},
args: args{
cert: &tls.Certificate{},
},
wantErr: true,
},
{
name: "No error if valid cert is passed",
fields: fields{
certCachePath: os.TempDir(),
inMemCache: make(map[string]*tls.Certificate),
timeSource: NewTimeSource(),
},
args: args{
cert: func() *tls.Certificate {
gen := setupCertGen()
ca, _ := gen.CACert(GenerationOptions{
CommonName: caCN,
})
srvCN := fmt.Sprintf("%s-%d", cnLocalhost, time.Now().Unix())
t.Cleanup(func() {
for _, f := range []string{
path.Join(os.TempDir(), fmt.Sprintf("%s.pem", srvCN)),
path.Join(os.TempDir(), fmt.Sprintf("%s.key", srvCN)),
} {
_ = os.Remove(f)
}
})
srvCrt, _ := gen.ServerCert(GenerationOptions{
CommonName: srvCN,
}, ca)
return srvCrt
}(),
},
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
f := &fileSystemCache{
certCachePath: tt.fields.certCachePath,
inMemCache: tt.fields.inMemCache,
timeSource: tt.fields.timeSource,
}
if err := f.Put(tt.args.cert); (err != nil) != tt.wantErr {
t.Errorf("Put() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
func setupCertGen() Generator {
return NewDefaultGenerator(Options{
Validity: ValidityByPurpose{
Server: ValidityDuration{
NotBeforeRelative: serverRelativeValidity,
NotAfterRelative: serverRelativeValidity,
},
CA: ValidityDuration{
NotBeforeRelative: caRelativeValidity,
NotAfterRelative: caRelativeValidity,
},
},
})
}

21
pkg/cert/constants.go Normal file
View file

@ -0,0 +1,21 @@
package cert
const (
CurveTypeP224 CurveType = "P224"
CurveTypeP256 CurveType = "P256"
CurveTypeP384 CurveType = "P384"
CurveTypeP521 CurveType = "P521"
CurveTypeED25519 CurveType = "ED25519"
defaultServerValidityDuration = "168h"
defaultCAValidityDuration = "17520h"
certCachePathConfigKey = "tls.certCachePath"
ecdsaCurveConfigKey = "tls.ecdsaCurve"
publicKeyConfigKey = "tls.rootCaCert.publicKey"
privateKeyPathConfigKey = "tls.rootCaCert.privateKey"
caCertValidityNotBeforeKey = "tls.validity.ca.notBeforeRelative"
caCertValidityNotAfterKey = "tls.validity.ca.notAfterRelative"
serverCertValidityNotBeforeKey = "tls.validity.server.notBeforeRelative"
serverCertValidityNotAfterKey = "tls.validity.server.notAfterRelative"
)

43
pkg/cert/defaults.go Normal file
View file

@ -0,0 +1,43 @@
package cert
import (
"github.com/baez90/inetmock/pkg/defaulting"
"reflect"
)
var (
certOptionsDefaulter defaulting.Defaulter = func(instance interface{}) {
switch o := instance.(type) {
case *GenerationOptions:
if len(o.Country) < 1 {
o.Country = []string{"US"}
}
if len(o.Locality) < 1 {
o.Locality = []string{"San Francisco"}
}
if len(o.Organization) < 1 {
o.Organization = []string{"INetMock"}
}
if len(o.StreetAddress) < 1 {
o.StreetAddress = []string{"Golden Gate Bridge"}
}
if len(o.PostalCode) < 1 {
o.PostalCode = []string{"94016"}
}
if len(o.Province) < 1 {
o.Province = []string{""}
}
}
}
)
func init() {
certOptionsType := reflect.TypeOf(GenerationOptions{})
defaulters.Register(certOptionsType, certOptionsDefaulter)
}

39
pkg/cert/defaults_test.go Normal file
View file

@ -0,0 +1,39 @@
package cert
import (
"reflect"
"testing"
)
func Test_certOptionsDefaulter(t *testing.T) {
tests := []struct {
name string
arg GenerationOptions
expected GenerationOptions
}{
{
name: "",
arg: GenerationOptions{
CommonName: "CA",
},
expected: GenerationOptions{
CommonName: "CA",
Country: []string{"US"},
Locality: []string{"San Francisco"},
Organization: []string{"INetMock"},
StreetAddress: []string{"Golden Gate Bridge"},
PostalCode: []string{"94016"},
Province: []string{""},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
certOptionsDefaulter(&tt.arg)
if !reflect.DeepEqual(tt.expected, tt.arg) {
t.Errorf("Apply defaulter expected=%v got=%v", tt.expected, tt.arg)
}
})
}
}

184
pkg/cert/generator.go Normal file
View file

@ -0,0 +1,184 @@
package cert
import (
"crypto/ecdsa"
"crypto/ed25519"
"crypto/rand"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"github.com/baez90/inetmock/pkg/defaulting"
"math/big"
"net"
)
var (
defaulters = defaulting.New()
)
type GenerationOptions struct {
CommonName string
Organization []string
OrganizationalUnit []string
IPAddresses []net.IP
DNSNames []string
Country []string
Province []string
Locality []string
StreetAddress []string
PostalCode []string
}
type Generator interface {
CACert(options GenerationOptions) (*tls.Certificate, error)
ServerCert(options GenerationOptions, ca *tls.Certificate) (*tls.Certificate, error)
}
func NewDefaultGenerator(options Options) Generator {
return NewGenerator(options, NewTimeSource(), defaultKeyProvider(options))
}
func NewGenerator(options Options, source TimeSource, provider KeyProvider) Generator {
return &generator{
options: options,
provider: provider,
timeSource: source,
}
}
type generator struct {
options Options
provider KeyProvider
timeSource TimeSource
outDir string
}
func (g *generator) privateKey() (key interface{}, err error) {
if g.provider != nil {
return g.provider()
} else {
return defaultKeyProvider(g.options)()
}
}
func (g *generator) ServerCert(options GenerationOptions, ca *tls.Certificate) (cert *tls.Certificate, err error) {
defaulters.Apply(&options)
var serialNumber *big.Int
if serialNumber, err = generateSerialNumber(); err != nil {
return
}
var privateKey interface{}
if privateKey, err = g.privateKey(); err != nil {
return
}
template := x509.Certificate{
SerialNumber: serialNumber,
Subject: pkix.Name{
CommonName: options.CommonName,
Organization: options.Organization,
OrganizationalUnit: options.OrganizationalUnit,
Country: options.Country,
Province: options.Province,
Locality: options.Locality,
StreetAddress: options.StreetAddress,
PostalCode: options.PostalCode,
},
IPAddresses: options.IPAddresses,
DNSNames: options.DNSNames,
NotBefore: g.timeSource.UTCNow().Add(-g.options.Validity.Server.NotBeforeRelative),
NotAfter: g.timeSource.UTCNow().Add(g.options.Validity.Server.NotAfterRelative),
KeyUsage: x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth},
}
var caCrt *x509.Certificate
caCrt, err = x509.ParseCertificate(ca.Certificate[0])
var derBytes []byte
if derBytes, err = x509.CreateCertificate(rand.Reader, &template, caCrt, publicKey(privateKey), ca.PrivateKey); err != nil {
return
}
var privateKeyBytes []byte
if privateKeyBytes, err = x509.MarshalPKCS8PrivateKey(privateKey); err != nil {
return
}
if cert, err = parseCert(derBytes, privateKeyBytes); err != nil {
return
}
return
}
func (g generator) CACert(options GenerationOptions) (crt *tls.Certificate, err error) {
defaulters.Apply(&options)
var privateKey interface{}
var serialNumber *big.Int
if serialNumber, err = generateSerialNumber(); err != nil {
return
}
if privateKey, err = g.privateKey(); err != nil {
return
}
template := x509.Certificate{
SerialNumber: serialNumber,
Subject: pkix.Name{
CommonName: options.CommonName,
Organization: options.Organization,
Country: options.Country,
Province: options.Province,
Locality: options.Locality,
StreetAddress: options.StreetAddress,
PostalCode: options.PostalCode,
},
IsCA: true,
NotBefore: g.timeSource.UTCNow().Add(-g.options.Validity.Server.NotBeforeRelative),
NotAfter: g.timeSource.UTCNow().Add(g.options.Validity.Server.NotAfterRelative),
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth},
BasicConstraintsValid: true,
}
var derBytes []byte
if derBytes, err = x509.CreateCertificate(rand.Reader, &template, &template, publicKey(privateKey), privateKey); err != nil {
return
}
var privateKeyBytes []byte
if privateKeyBytes, err = x509.MarshalPKCS8PrivateKey(privateKey); err != nil {
return
}
crt, err = parseCert(derBytes, privateKeyBytes)
return
}
func generateSerialNumber() (*big.Int, error) {
serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
return rand.Int(rand.Reader, serialNumberLimit)
}
func publicKey(privateKey interface{}) interface{} {
switch k := privateKey.(type) {
case *ecdsa.PrivateKey:
return &k.PublicKey
case ed25519.PrivateKey:
return k.Public().(ed25519.PublicKey)
default:
return nil
}
}
func parseCert(derBytes []byte, privateKeyBytes []byte) (*tls.Certificate, error) {
pemEncodedPublicKey := pem.EncodeToMemory(&pem.Block{Type: certificateBlockType, Bytes: derBytes})
pemEncodedPrivateKey := pem.EncodeToMemory(&pem.Block{Type: privateKeyBlockType, Bytes: privateKeyBytes})
cert, err := tls.X509KeyPair(pemEncodedPublicKey, pemEncodedPrivateKey)
return &cert, err
}

79
pkg/cert/options.go Normal file
View file

@ -0,0 +1,79 @@
package cert
import (
"fmt"
"github.com/spf13/viper"
"os"
"strings"
"time"
)
var (
requiredConfigKeys = []string{
publicKeyConfigKey,
privateKeyPathConfigKey,
}
)
type CurveType string
type File struct {
PublicKeyPath string
PrivateKeyPath string
}
type ValidityDuration struct {
NotBeforeRelative time.Duration
NotAfterRelative time.Duration
}
type ValidityByPurpose struct {
CA ValidityDuration
Server ValidityDuration
}
type Options struct {
RootCACert File
CertCachePath string
Curve CurveType
Validity ValidityByPurpose
}
func loadFromConfig(config *viper.Viper) (Options, error) {
missingKeys := make([]string, 0)
for _, requiredKey := range requiredConfigKeys {
if !config.IsSet(requiredKey) {
missingKeys = append(missingKeys, requiredKey)
}
}
if len(missingKeys) > 0 {
return Options{}, fmt.Errorf("config keys are missing: %s", strings.Join(missingKeys, ", "))
}
config.SetDefault(certCachePathConfigKey, os.TempDir())
config.SetDefault(ecdsaCurveConfigKey, string(CurveTypeED25519))
config.SetDefault(caCertValidityNotBeforeKey, defaultCAValidityDuration)
config.SetDefault(caCertValidityNotAfterKey, defaultCAValidityDuration)
config.SetDefault(serverCertValidityNotBeforeKey, defaultServerValidityDuration)
config.SetDefault(serverCertValidityNotAfterKey, defaultServerValidityDuration)
return Options{
CertCachePath: config.GetString(certCachePathConfigKey),
Curve: CurveType(config.GetString(ecdsaCurveConfigKey)),
Validity: ValidityByPurpose{
CA: ValidityDuration{
NotBeforeRelative: config.GetDuration(caCertValidityNotBeforeKey),
NotAfterRelative: config.GetDuration(caCertValidityNotAfterKey),
},
Server: ValidityDuration{
NotBeforeRelative: config.GetDuration(serverCertValidityNotBeforeKey),
NotAfterRelative: config.GetDuration(serverCertValidityNotAfterKey),
},
},
RootCACert: File{
PublicKeyPath: config.GetString(publicKeyConfigKey),
PrivateKeyPath: config.GetString(privateKeyPathConfigKey),
},
}, nil
}

137
pkg/cert/options_test.go Normal file
View file

@ -0,0 +1,137 @@
package cert
import (
"github.com/spf13/viper"
"os"
"reflect"
"strings"
"testing"
"time"
)
func readViper(cfg string) *viper.Viper {
vpr := viper.New()
vpr.SetConfigType("yaml")
if err := vpr.ReadConfig(strings.NewReader(cfg)); err != nil {
panic(err)
}
return vpr
}
func Test_loadFromConfig(t *testing.T) {
type args struct {
config *viper.Viper
}
tests := []struct {
name string
args args
want Options
wantErr bool
}{
{
name: "Parse valid TLS configuration",
wantErr: false,
args: args{
config: readViper(`
tls:
ecdsaCurve: P256
validity:
ca:
notBeforeRelative: 17520h
notAfterRelative: 17520h
server:
NotBeforeRelative: 168h
NotAfterRelative: 168h
rootCaCert:
publicKey: ./ca.pem
privateKey: ./ca.key
certCachePath: /tmp/inetmock/
`),
},
want: Options{
RootCACert: File{
PublicKeyPath: "./ca.pem",
PrivateKeyPath: "./ca.key",
},
CertCachePath: "/tmp/inetmock/",
Curve: CurveTypeP256,
Validity: ValidityByPurpose{
CA: ValidityDuration{
NotBeforeRelative: 17520 * time.Hour,
NotAfterRelative: 17520 * time.Hour,
},
Server: ValidityDuration{
NotBeforeRelative: 168 * time.Hour,
NotAfterRelative: 168 * time.Hour,
},
},
},
},
{
name: "Get an error if CA public key path is missing",
args: args{
readViper(`
tls:
rootCaCert:
privateKey: ./ca.key
`),
},
want: Options{},
wantErr: true,
},
{
name: "Get an error if CA private key path is missing",
args: args{
readViper(`
tls:
rootCaCert:
publicKey: ./ca.pem
`),
},
want: Options{},
wantErr: true,
},
{
name: "Get default options if all required fields are set",
args: args{
readViper(`
tls:
rootCaCert:
publicKey: ./ca.pem
privateKey: ./ca.key
`),
},
want: Options{
RootCACert: File{
PublicKeyPath: "./ca.pem",
PrivateKeyPath: "./ca.key",
},
CertCachePath: os.TempDir(),
Curve: CurveTypeED25519,
Validity: ValidityByPurpose{
CA: ValidityDuration{
NotBeforeRelative: 17520 * time.Hour,
NotAfterRelative: 17520 * time.Hour,
},
Server: ValidityDuration{
NotBeforeRelative: 168 * time.Hour,
NotAfterRelative: 168 * time.Hour,
},
},
},
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := loadFromConfig(tt.args.config)
if (err != nil) != tt.wantErr {
t.Errorf("loadFromConfig() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("loadFromConfig() got = %v, want %v", got, tt.want)
}
})
}
}

79
pkg/cert/pem.go Normal file
View file

@ -0,0 +1,79 @@
package cert
import (
"crypto/tls"
"crypto/x509"
"encoding/pem"
"errors"
"fmt"
"github.com/baez90/inetmock/pkg/path"
"os"
"path/filepath"
)
const (
certificateBlockType = "CERTIFICATE"
privateKeyBlockType = "PRIVATE KEY"
)
type PEM interface {
Cert() *tls.Certificate
Write(cn string, outDir string) error
Read(cn string, inDir string) error
ReadFrom(pubKeyPath, privateKeyPath string) error
}
func NewPEM(crt *tls.Certificate) PEM {
return &pemCrt{
crt: crt,
}
}
type pemCrt struct {
crt *tls.Certificate
}
func (p *pemCrt) Cert() *tls.Certificate {
return p.crt
}
func (p pemCrt) Write(cn string, outDir string) (err error) {
var certOut *os.File
if certOut, err = os.Create(filepath.Join(outDir, fmt.Sprintf("%s.pem", cn))); err != nil {
return
}
defer certOut.Close()
if err = pem.Encode(certOut, &pem.Block{Type: certificateBlockType, Bytes: p.crt.Certificate[0]}); err != nil {
return
}
var keyOut *os.File
if keyOut, err = os.OpenFile(filepath.Join(outDir, fmt.Sprintf("%s.key", cn)), os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600); err != nil {
return
}
var privKeyBytes []byte
privKeyBytes, err = x509.MarshalPKCS8PrivateKey(p.crt.PrivateKey)
err = pem.Encode(keyOut, &pem.Block{Type: privateKeyBlockType, Bytes: privKeyBytes})
return
}
func (p *pemCrt) Read(cn string, inDir string) error {
certPath := filepath.Join(inDir, fmt.Sprintf("%s.pem", cn))
keyPath := filepath.Join(inDir, fmt.Sprintf("%s.key", cn))
return p.ReadFrom(certPath, keyPath)
}
func (p *pemCrt) ReadFrom(pubKeyPath, privateKeyPath string) (err error) {
var tlsCrt tls.Certificate
if path.FileExists(pubKeyPath) && path.FileExists(privateKeyPath) {
if tlsCrt, err = tls.LoadX509KeyPair(pubKeyPath, privateKeyPath); err == nil {
p.crt = &tlsCrt
}
} else {
err = errors.New("either public or private key file do not exist")
}
return
}

153
pkg/cert/store.go Normal file
View file

@ -0,0 +1,153 @@
package cert
import (
"crypto/ecdsa"
"crypto/ed25519"
"crypto/elliptic"
"crypto/rand"
"crypto/tls"
"crypto/x509"
"github.com/baez90/inetmock/pkg/logging"
"github.com/spf13/viper"
"go.uber.org/zap"
"net"
)
const (
ipv4Loopback = "127.0.0.1"
)
var (
defaultKeyProvider = func(options Options) func() (key interface{}, err error) {
return func() (key interface{}, err error) {
return privateKeyForCurve(options)
}
}
)
type KeyProvider func() (key interface{}, err error)
type Store interface {
CACert() *tls.Certificate
GetCertificate(serverName string, ip string) (*tls.Certificate, error)
TLSConfig() *tls.Config
}
func NewDefaultStore(
config *viper.Viper,
logger logging.Logger,
) (Store, error) {
options, err := loadFromConfig(config)
if err != nil {
return nil, err
}
timeSource := NewTimeSource()
return NewStore(
options,
NewFileSystemCache(options.CertCachePath, timeSource),
NewDefaultGenerator(options),
logger,
)
}
func NewStore(
options Options,
cache Cache,
generator Generator,
logger logging.Logger,
) (Store, error) {
store := &store{
options: options,
cache: cache,
generator: generator,
logger: logger,
}
if err := store.loadCACert(); err != nil {
return nil, err
}
return store, nil
}
type store struct {
options Options
caCert *tls.Certificate
cache Cache
timeSource TimeSource
generator Generator
logger logging.Logger
}
func (s *store) TLSConfig() *tls.Config {
rootCaPool := x509.NewCertPool()
rootCaPubKey, _ := x509.ParseCertificate(s.caCert.Certificate[0])
rootCaPool.AddCert(rootCaPubKey)
return &tls.Config{
GetCertificate: func(info *tls.ClientHelloInfo) (cert *tls.Certificate, err error) {
var localIp string
if localIp, err = extractIPFromAddress(info.Conn.LocalAddr().String()); err != nil {
localIp = ipv4Loopback
}
if cert, err = s.GetCertificate(info.ServerName, localIp); err != nil {
s.logger.Error(
"error while resolving certificate",
zap.String("serverName", info.ServerName),
zap.String("localAddr", localIp),
zap.Error(err),
)
}
return
},
RootCAs: rootCaPool,
}
}
func (s *store) loadCACert() (err error) {
pemCrt := NewPEM(nil)
if err = pemCrt.ReadFrom(s.options.RootCACert.PublicKeyPath, s.options.RootCACert.PrivateKeyPath); err != nil {
return
}
s.caCert = pemCrt.Cert()
return
}
func (s *store) CACert() *tls.Certificate {
return s.caCert
}
func (s *store) GetCertificate(serverName string, ip string) (cert *tls.Certificate, err error) {
if crt, ok := s.cache.Get(serverName); ok {
return crt, nil
}
if cert, err = s.generator.ServerCert(GenerationOptions{
CommonName: serverName,
DNSNames: []string{serverName},
IPAddresses: []net.IP{net.ParseIP(ip)},
}, s.caCert); err == nil {
s.cache.Put(cert)
}
return
}
func privateKeyForCurve(options Options) (privateKey interface{}, err error) {
switch options.Curve {
case CurveTypeP224:
privateKey, err = ecdsa.GenerateKey(elliptic.P224(), rand.Reader)
case CurveTypeP256:
privateKey, err = ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
case CurveTypeP384:
privateKey, err = ecdsa.GenerateKey(elliptic.P384(), rand.Reader)
case CurveTypeP521:
privateKey, err = ecdsa.GenerateKey(elliptic.P521(), rand.Reader)
default:
_, privateKey, err = ed25519.GenerateKey(rand.Reader)
}
return
}

20
pkg/cert/time_source.go Normal file
View file

@ -0,0 +1,20 @@
//go:generate mockgen -source=time_source.go -destination=./../../internal/mock/cert/time_source_mock.go -package=cert_mock
package cert
import "time"
type TimeSource interface {
UTCNow() time.Time
}
func NewTimeSource() TimeSource {
return &defaultTimeSource{}
}
type defaultTimeSource struct {
}
func (d defaultTimeSource) UTCNow() time.Time {
return time.Now().UTC()
}

View file

@ -0,0 +1,43 @@
package defaulting
import (
"reflect"
)
type Defaulter func(instance interface{})
type Registry interface {
Register(t reflect.Type, defaulter ...Defaulter)
Apply(instance interface{})
}
func New() Registry {
return &registry{
defaulters: make(map[reflect.Type][]Defaulter),
}
}
type registry struct {
defaulters map[reflect.Type][]Defaulter
}
func (r *registry) Register(t reflect.Type, defaulter ...Defaulter) {
var given []Defaulter
if r, ok := r.defaulters[t]; ok {
given = r
}
for _, d := range defaulter {
given = append(given, d)
}
r.defaulters[t] = given
}
func (r *registry) Apply(instance interface{}) {
if defs, ok := r.defaulters[reflect.TypeOf(instance)]; ok {
for _, def := range defs {
def(instance)
}
}
}

View file

@ -0,0 +1,111 @@
package defaulting
import (
"reflect"
"testing"
)
func Test_registry_Apply(t *testing.T) {
type sample struct {
i int
}
type fields struct {
defaulters map[reflect.Type][]Defaulter
}
type args struct {
instance interface{}
}
type expect struct {
result interface{}
}
tests := []struct {
name string
fields fields
args args
expect expect
}{
{
name: "Expect setting a sample value",
fields: fields{
defaulters: map[reflect.Type][]Defaulter{
reflect.TypeOf(&sample{}): {func(instance interface{}) {
switch i := instance.(type) {
case *sample:
i.i = 42
}
}},
},
},
args: args{
instance: &sample{},
},
expect: expect{
result: &sample{
i: 42,
},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
r := &registry{
defaulters: tt.fields.defaulters,
}
r.Apply(tt.args.instance)
if !reflect.DeepEqual(tt.expect.result, tt.args.instance) {
t.Errorf("Apply() expected = %v got %v", tt.args.instance, tt.expect.result)
}
})
}
}
func Test_registry_Register(t *testing.T) {
type sample struct {
}
type fields struct {
defaulters map[reflect.Type][]Defaulter
}
type args struct {
t reflect.Type
defaulter []Defaulter
}
type expect struct {
length int
}
tests := []struct {
name string
fields fields
args args
expect expect
}{
{
name: "",
fields: fields{
defaulters: make(map[reflect.Type][]Defaulter),
},
args: args{
t: reflect.TypeOf(sample{}),
defaulter: []Defaulter{func(instance interface{}) {
}},
},
expect: expect{
length: 1,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
r := &registry{
defaulters: tt.fields.defaulters,
}
r.Register(tt.args.t, tt.args.defaulter...)
if length := len(r.defaulters); length != tt.expect.length {
t.Errorf("len(r.defaulters) expect %d got %d", tt.expect.length, length)
}
})
}
}

View file

@ -1,4 +1,4 @@
//go:generate mockgen -source=logger.go -destination=./../../internal/mock/logger_mock.go -package=mock
//go:generate mockgen -source=logger.go -destination=./../../internal/mock/logging/logger_mock.go -package=logging_mock
package logging
import "go.uber.org/zap"

View file

@ -6,7 +6,7 @@ require (
github.com/baez90/inetmock v0.0.1
github.com/miekg/dns v1.1.29
github.com/spf13/viper v1.6.3
go.uber.org/zap v1.14.1
go.uber.org/zap v1.15.0
golang.org/x/crypto v0.0.0-20200406173513-056763e48d71 // indirect
golang.org/x/net v0.0.0-20200324143707-d3edc9973b7e // indirect
)

View file

@ -149,6 +149,8 @@ go.uber.org/tools v0.0.0-20190618225709-2cfd321de3ee/go.mod h1:vJERXedbb3MVM5f9E
go.uber.org/zap v1.10.0/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q=
go.uber.org/zap v1.14.1 h1:nYDKopTbvAPq/NrUVZwT15y2lpROBiLLyoRTbXOYWOo=
go.uber.org/zap v1.14.1/go.mod h1:Mb2vm2krFEG5DV0W9qcHBYFtp/Wku1cvYaqPsS/WYfc=
go.uber.org/zap v1.15.0 h1:ZZCA22JRF2gQE5FoNmhmrf7jeJJ2uhqDUNRYKm8dvmM=
go.uber.org/zap v1.15.0/go.mod h1:Mb2vm2krFEG5DV0W9qcHBYFtp/Wku1cvYaqPsS/WYfc=
golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
@ -188,6 +190,8 @@ golang.org/x/sys v0.0.0-20191005200804-aed5e4c7ecf9/go.mod h1:h1NjWce9XRLGQEsW7w
golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200409092240-59c9f1ba88fa h1:mQTN3ECqfsViCNBgq+A40vdwhkGykrrQlYe3mPj6BoU=
golang.org/x/sys v0.0.0-20200409092240-59c9f1ba88fa/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200420163511-1957bb5e6d1f h1:gWF768j/LaZugp8dyS4UwsslYCYz9XgFxvlgsn0n9H8=
golang.org/x/sys v0.0.0-20200420163511-1957bb5e6d1f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.2 h1:tW2bmiBqwgJj/UpqtC8EpXEZVYOwU0yG4iWbprSVAcs=

View file

@ -2,7 +2,7 @@ package main
import (
"fmt"
"github.com/baez90/inetmock/internal/config"
"github.com/baez90/inetmock/pkg/api"
"github.com/baez90/inetmock/pkg/logging"
"github.com/miekg/dns"
"go.uber.org/zap"
@ -13,7 +13,7 @@ type dnsHandler struct {
dnsServer []*dns.Server
}
func (d *dnsHandler) Start(config config.HandlerConfig) (err error) {
func (d *dnsHandler) Start(config api.HandlerConfig) (err error) {
options := loadFromConfig(config.Options())
addr := fmt.Sprintf("%s:%d", config.ListenAddress(), config.Port())

View file

@ -5,7 +5,7 @@ go 1.14
require (
github.com/baez90/inetmock v0.0.1
github.com/spf13/viper v1.6.3
go.uber.org/zap v1.14.1
go.uber.org/zap v1.15.0
)
replace github.com/baez90/inetmock v0.0.1 => ../../

View file

@ -147,6 +147,8 @@ go.uber.org/tools v0.0.0-20190618225709-2cfd321de3ee/go.mod h1:vJERXedbb3MVM5f9E
go.uber.org/zap v1.10.0/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q=
go.uber.org/zap v1.14.1 h1:nYDKopTbvAPq/NrUVZwT15y2lpROBiLLyoRTbXOYWOo=
go.uber.org/zap v1.14.1/go.mod h1:Mb2vm2krFEG5DV0W9qcHBYFtp/Wku1cvYaqPsS/WYfc=
go.uber.org/zap v1.15.0 h1:ZZCA22JRF2gQE5FoNmhmrf7jeJJ2uhqDUNRYKm8dvmM=
go.uber.org/zap v1.15.0/go.mod h1:Mb2vm2krFEG5DV0W9qcHBYFtp/Wku1cvYaqPsS/WYfc=
golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
@ -178,6 +180,8 @@ golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7w
golang.org/x/sys v0.0.0-20191005200804-aed5e4c7ecf9/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200409092240-59c9f1ba88fa h1:mQTN3ECqfsViCNBgq+A40vdwhkGykrrQlYe3mPj6BoU=
golang.org/x/sys v0.0.0-20200409092240-59c9f1ba88fa/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200420163511-1957bb5e6d1f h1:gWF768j/LaZugp8dyS4UwsslYCYz9XgFxvlgsn0n9H8=
golang.org/x/sys v0.0.0-20200420163511-1957bb5e6d1f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.2 h1:tW2bmiBqwgJj/UpqtC8EpXEZVYOwU0yG4iWbprSVAcs=

View file

@ -2,7 +2,7 @@ package main
import (
"fmt"
"github.com/baez90/inetmock/internal/config"
"github.com/baez90/inetmock/pkg/api"
"github.com/baez90/inetmock/pkg/logging"
"go.uber.org/zap"
"net/http"
@ -18,7 +18,7 @@ type httpHandler struct {
server *http.Server
}
func (p *httpHandler) Start(config config.HandlerConfig) (err error) {
func (p *httpHandler) Start(config api.HandlerConfig) (err error) {
options := loadFromConfig(config.Options())
addr := fmt.Sprintf("%s:%d", config.ListenAddress(), config.Port())
p.server = &http.Server{Addr: addr, Handler: p.router}

View file

@ -1,53 +0,0 @@
package main
import (
"gopkg.in/elazarl/goproxy.v1"
"net/http"
)
const (
passthroughStrategyName = "passthrough"
notFoundStrategyName = "notfound"
)
var (
fallbackStrategies map[string]ProxyFallbackStrategy
)
func init() {
fallbackStrategies = map[string]ProxyFallbackStrategy{
passthroughStrategyName: &passThroughFallbackStrategy{},
notFoundStrategyName: &notFoundFallbackStrategy{},
}
}
func StrategyForName(name string) ProxyFallbackStrategy {
if strategy, ok := fallbackStrategies[name]; ok {
return strategy
}
return fallbackStrategies[notFoundStrategyName]
}
type ProxyFallbackStrategy interface {
Apply(request *http.Request) (*http.Response, error)
}
type passThroughFallbackStrategy struct {
}
func (p passThroughFallbackStrategy) Apply(request *http.Request) (*http.Response, error) {
return nil, nil
}
type notFoundFallbackStrategy struct {
}
func (n notFoundFallbackStrategy) Apply(request *http.Request) (response *http.Response, err error) {
response = goproxy.NewResponse(
request,
goproxy.ContentTypeText,
http.StatusNotFound,
"The requested resource was not found",
)
return
}

View file

@ -1,46 +0,0 @@
package main
import (
"reflect"
"testing"
)
func TestStrategyForName(t *testing.T) {
type args struct {
name string
}
tests := []struct {
name string
args args
want reflect.Type
}{
{
name: "Test get notfound strategy",
want: reflect.TypeOf(&notFoundFallbackStrategy{}),
args: args{
name: "notfound",
},
},
{
name: "Test get pass through strategy",
want: reflect.TypeOf(&passThroughFallbackStrategy{}),
args: args{
name: "passthrough",
},
},
{
name: "Test get fallback strategy notfound because key is not known",
want: reflect.TypeOf(&notFoundFallbackStrategy{}),
args: args{
name: "asdf12234",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := StrategyForName(tt.args.name); reflect.TypeOf(got) != tt.want {
t.Errorf("StrategyForName() = %v, want %v", got, tt.want)
}
})
}
}

View file

@ -4,9 +4,8 @@ go 1.14
require (
github.com/baez90/inetmock v0.0.1
github.com/elazarl/goproxy v0.0.0-20200315184450-1f3cb6622dad
github.com/spf13/viper v1.6.3
go.uber.org/zap v1.14.1
go.uber.org/zap v1.15.0
gopkg.in/elazarl/goproxy.v1 v1.0.0-20180725130230-947c36da3153
)

View file

@ -152,6 +152,8 @@ go.uber.org/tools v0.0.0-20190618225709-2cfd321de3ee/go.mod h1:vJERXedbb3MVM5f9E
go.uber.org/zap v1.10.0/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q=
go.uber.org/zap v1.14.1 h1:nYDKopTbvAPq/NrUVZwT15y2lpROBiLLyoRTbXOYWOo=
go.uber.org/zap v1.14.1/go.mod h1:Mb2vm2krFEG5DV0W9qcHBYFtp/Wku1cvYaqPsS/WYfc=
go.uber.org/zap v1.15.0 h1:ZZCA22JRF2gQE5FoNmhmrf7jeJJ2uhqDUNRYKm8dvmM=
go.uber.org/zap v1.15.0/go.mod h1:Mb2vm2krFEG5DV0W9qcHBYFtp/Wku1cvYaqPsS/WYfc=
golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
@ -183,6 +185,8 @@ golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7w
golang.org/x/sys v0.0.0-20191005200804-aed5e4c7ecf9/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200409092240-59c9f1ba88fa h1:mQTN3ECqfsViCNBgq+A40vdwhkGykrrQlYe3mPj6BoU=
golang.org/x/sys v0.0.0-20200409092240-59c9f1ba88fa/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200420163511-1957bb5e6d1f h1:gWF768j/LaZugp8dyS4UwsslYCYz9XgFxvlgsn0n9H8=
golang.org/x/sys v0.0.0-20200420163511-1957bb5e6d1f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.2 h1:tW2bmiBqwgJj/UpqtC8EpXEZVYOwU0yG4iWbprSVAcs=

View file

@ -2,7 +2,7 @@ package main
import (
"fmt"
"github.com/baez90/inetmock/internal/config"
"github.com/baez90/inetmock/pkg/api"
"github.com/baez90/inetmock/pkg/logging"
"go.uber.org/zap"
"gopkg.in/elazarl/goproxy.v1"
@ -19,7 +19,7 @@ type httpProxy struct {
server *http.Server
}
func (h *httpProxy) Start(config config.HandlerConfig) (err error) {
func (h *httpProxy) Start(config api.HandlerConfig) (err error) {
options := loadFromConfig(config.Options())
addr := fmt.Sprintf("%s:%d", config.ListenAddress(), config.Port())
h.server = &http.Server{Addr: addr, Handler: h.proxy}
@ -27,11 +27,20 @@ func (h *httpProxy) Start(config config.HandlerConfig) (err error) {
zap.String("address", addr),
)
tlsConfig := api.ServicesInstance().CertStore().TLSConfig()
proxyHandler := &proxyHttpHandler{
options: options,
logger: h.logger,
}
proxyHttpsHandler := &proxyHttpsHandler{
tlsConfig: tlsConfig,
logger: h.logger,
}
h.proxy.OnRequest().Do(proxyHandler)
h.proxy.OnRequest().HandleConnect(proxyHttpsHandler)
go h.startProxy()
return
}

View file

@ -1,52 +1,42 @@
package main
import (
"fmt"
"github.com/spf13/viper"
"regexp"
)
const (
rulesConfigKey = "rules"
patternConfigKey = "pattern"
responseConfigKey = "response"
fallbackStrategyConfigKey = "fallback"
targetSchemeConfigKey = "target.scheme"
targetIpAddressConfigKey = "target.ipAddress"
targetPortConfigKey = "target.port"
)
type targetRule struct {
pattern *regexp.Regexp
response string
type redirectionTarget struct {
scheme string
ipAddress string
port uint16
}
func (tr targetRule) Pattern() *regexp.Regexp {
return tr.pattern
}
func (tr targetRule) Response() string {
return tr.response
func (rt redirectionTarget) host() string {
return fmt.Sprintf("%s:%d", rt.ipAddress, rt.port)
}
type httpProxyOptions struct {
Rules []targetRule
FallbackStrategy ProxyFallbackStrategy
redirectionTarget redirectionTarget
}
func loadFromConfig(config *viper.Viper) (options httpProxyOptions) {
options.FallbackStrategy = StrategyForName(config.GetString(fallbackStrategyConfigKey))
anonRules := config.Get(rulesConfigKey).([]interface{})
config.SetDefault(targetSchemeConfigKey, "http")
config.SetDefault(targetIpAddressConfigKey, "127.0.0.1")
config.SetDefault(targetPortConfigKey, "80")
for _, i := range anonRules {
innerData := i.(map[interface{}]interface{})
if rulePattern, err := regexp.Compile(innerData[patternConfigKey].(string)); err == nil {
options.Rules = append(options.Rules, targetRule{
pattern: rulePattern,
response: innerData[responseConfigKey].(string),
})
} else {
panic(err)
options = httpProxyOptions{
redirectionTarget{
scheme: config.GetString(targetSchemeConfigKey),
ipAddress: config.GetString(targetIpAddressConfigKey),
port: uint16(config.GetInt(targetPortConfigKey)),
},
}
}
return
}

View file

@ -1,126 +1 @@
package main
import (
"bytes"
"github.com/spf13/viper"
"reflect"
"regexp"
"testing"
)
func Test_loadFromConfig(t *testing.T) {
type args struct {
config string
}
tests := []struct {
name string
args args
wantOptions httpProxyOptions
}{
{
name: "Parse proper configuration with notfound strategy",
args: args{
config: `
fallback: notfound,
rules:
- pattern: ".*"
response: ./assets/fakeFiles/default.html
`,
},
wantOptions: httpProxyOptions{
FallbackStrategy: StrategyForName(notFoundStrategyName),
Rules: []targetRule{
{
response: "./assets/fakeFiles/default.html",
pattern: regexp.MustCompile(".*"),
},
},
},
},
{
name: "Parse proper configuration with pass through strategy",
args: args{
config: `
fallback: passthrough
rules:
- pattern: ".*"
response: ./assets/fakeFiles/default.html
`,
},
wantOptions: httpProxyOptions{
FallbackStrategy: StrategyForName(passthroughStrategyName),
Rules: []targetRule{
{
response: "./assets/fakeFiles/default.html",
pattern: regexp.MustCompile(".*"),
},
},
},
},
{
name: "Parse proper configuration and preserve order of rules",
args: args{
config: `
fallback: notfound
rules:
- pattern: ".*\\.(?i)txt"
response: ./assets/fakeFiles/default.txt
- pattern: ".*"
response: ./assets/fakeFiles/default.html
`,
},
wantOptions: httpProxyOptions{
FallbackStrategy: StrategyForName(notFoundStrategyName),
Rules: []targetRule{
{
response: "./assets/fakeFiles/default.txt",
pattern: regexp.MustCompile(".*\\.(?i)txt"),
},
{
response: "./assets/fakeFiles/default.html",
pattern: regexp.MustCompile(".*"),
},
},
},
},
{
name: "Parse configuration with non existing fallback strategy key - falling back to 'notfound'",
args: args{
config: `
fallback: doesNotExist
rules: []
`,
},
wantOptions: httpProxyOptions{
FallbackStrategy: StrategyForName(notFoundStrategyName),
Rules: nil,
},
},
{
name: "Parse configuration without any fallback key",
args: args{
config: `
f4llb4ck: doesNotExist
rules: []
`,
},
wantOptions: httpProxyOptions{
FallbackStrategy: StrategyForName(notFoundStrategyName),
Rules: nil,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
config := viper.New()
config.SetConfigType("yaml")
if err := config.ReadConfig(bytes.NewBufferString(tt.args.config)); err != nil {
t.Errorf("failed to read config %v", err)
return
}
if gotOptions := loadFromConfig(config); !reflect.DeepEqual(gotOptions, tt.wantOptions) {
t.Errorf("loadFromConfig() = %v, want %v", gotOptions, tt.wantOptions)
}
})
}
}

View file

@ -1,14 +1,13 @@
package main
import (
"context"
"crypto/tls"
"github.com/baez90/inetmock/pkg/logging"
"go.uber.org/zap"
"gopkg.in/elazarl/goproxy.v1"
"io"
"mime"
"net/http"
"os"
"path/filepath"
"net/url"
)
type proxyHttpHandler struct {
@ -16,25 +15,27 @@ type proxyHttpHandler struct {
logger logging.Logger
}
/*
TODO implement HTTPS proxy like in TLS interceptor
func (p *proxyHttpHandler) HandleConnect(req string, ctx *goproxy.ProxyCtx) (*goproxy.ConnectAction, string) {
type proxyHttpsHandler struct {
tlsConfig *tls.Config
logger logging.Logger
}
func (p *proxyHttpsHandler) HandleConnect(req string, _ *goproxy.ProxyCtx) (*goproxy.ConnectAction, string) {
p.logger.Info(
"Intercepting HTTPS proxy request",
zap.String("request", req),
)
return &goproxy.ConnectAction{
Action: goproxy.OkConnect,
Action: goproxy.ConnectMitm,
TLSConfig: func(host string, ctx *goproxy.ProxyCtx) (*tls.Config, error) {
return p.tlsConfig, nil
},
}, ""
}*/
func (p *proxyHttpHandler) Handle(req *http.Request, _ *goproxy.ProxyCtx) (retReq *http.Request, resp *http.Response) {
}
func (p *proxyHttpHandler) Handle(req *http.Request, ctx *goproxy.ProxyCtx) (retReq *http.Request, resp *http.Response) {
retReq = req
resp = &http.Response{
Request: req,
TransferEncoding: req.TransferEncoding,
Header: make(http.Header),
StatusCode: http.StatusOK,
}
p.logger.Info(
"Handling request",
zap.String("source", req.RemoteAddr),
@ -45,56 +46,48 @@ func (p *proxyHttpHandler) Handle(req *http.Request, _ *goproxy.ProxyCtx) (retRe
zap.Reflect("headers", req.Header),
)
for _, rule := range p.options.Rules {
if rule.pattern.MatchString(req.URL.Path) {
if file, err := os.Open(rule.response); err != nil {
var err error
if resp, err = ctx.RoundTrip(p.redirectHTTPRequest(req)); err != nil {
p.logger.Error(
"failed to open response target file",
zap.String("resonse", rule.response),
"error while doing roundtrip",
zap.Error(err),
)
continue
} else {
resp.Body = file
if stat, err := file.Stat(); err == nil {
resp.ContentLength = stat.Size()
}
if contentType, err := GetContentType(rule, file); err == nil {
resp.Header["Content-Type"] = []string{contentType}
}
p.logger.Info("returning fake response from rules")
return req, resp
}
}
}
if resp, err := p.options.FallbackStrategy.Apply(req); err != nil {
p.logger.Error(
"failed to apply fallback strategy",
zap.Error(err),
)
} else {
p.logger.Info("returning fake response from fallback strategy")
return req, resp
}
p.logger.Info("falling back to proxying request through")
return req, nil
}
func GetContentType(rule targetRule, file *os.File) (contentType string, err error) {
if contentType = mime.TypeByExtension(filepath.Ext(rule.response)); contentType != "" {
return
}
var buf [512]byte
n, _ := io.ReadFull(file, buf[:])
contentType = http.DetectContentType(buf[:n])
_, err = file.Seek(0, io.SeekStart)
return
}
func (p proxyHttpHandler) redirectHTTPRequest(originalRequest *http.Request) (redirectReq *http.Request) {
redirectReq = &http.Request{
Method: originalRequest.Method,
URL: &url.URL{
Host: p.options.redirectionTarget.host(),
Scheme: p.options.redirectionTarget.scheme,
Path: originalRequest.URL.Path,
ForceQuery: originalRequest.URL.ForceQuery,
Fragment: originalRequest.URL.Fragment,
Opaque: originalRequest.URL.Opaque,
RawPath: originalRequest.URL.RawPath,
RawQuery: originalRequest.URL.RawQuery,
User: originalRequest.URL.User,
},
Proto: originalRequest.Proto,
ProtoMajor: originalRequest.ProtoMajor,
ProtoMinor: originalRequest.ProtoMinor,
Header: originalRequest.Header,
Body: originalRequest.Body,
GetBody: originalRequest.GetBody,
ContentLength: originalRequest.ContentLength,
TransferEncoding: originalRequest.TransferEncoding,
Close: false,
Host: originalRequest.Host,
Form: originalRequest.Form,
PostForm: originalRequest.PostForm,
MultipartForm: originalRequest.MultipartForm,
Trailer: originalRequest.Trailer,
}
redirectReq = redirectReq.WithContext(context.Background())
return
}

View file

@ -1,20 +0,0 @@
package main
import (
"fmt"
"regexp"
"strings"
)
var (
ipExtractionRegex = regexp.MustCompile(`(.+):\d{1,5}$`)
)
func extractIPFromAddress(addr string) (string, error) {
matches := ipExtractionRegex.FindAllStringSubmatch(addr, -1)
if len(matches) > 0 && len(matches[0]) >= 1 {
return strings.Trim(matches[0][1], "[]"), nil
} else {
return "", fmt.Errorf("failed to extract IP address from addr %s", addr)
}
}

View file

@ -1,204 +0,0 @@
package main
import (
"crypto/rand"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"fmt"
"github.com/baez90/inetmock/pkg/logging"
"github.com/baez90/inetmock/pkg/path"
"go.uber.org/zap"
"math/big"
"net"
"path/filepath"
)
type keyProvider func() (key interface{}, err error)
type certStore struct {
options *tlsOptions
keyProvider keyProvider
caCert *x509.Certificate
caPrivateKey interface{}
certCache map[string]*tls.Certificate
timeSourceInstance timeSource
logger logging.Logger
}
func (cs *certStore) timeSource() timeSource {
if cs.timeSourceInstance == nil {
cs.timeSourceInstance = createTimeSource()
}
return cs.timeSourceInstance
}
func (cs *certStore) initCaCert() (err error) {
if path.FileExists(cs.options.rootCaCert.publicKeyPath) && path.FileExists(cs.options.rootCaCert.privateKeyPath) {
var tlsCert tls.Certificate
if tlsCert, err = tls.LoadX509KeyPair(cs.options.rootCaCert.publicKeyPath, cs.options.rootCaCert.privateKeyPath); err != nil {
return
} else if len(tlsCert.Certificate) > 0 {
cs.caPrivateKey = tlsCert.PrivateKey
cs.caCert, err = x509.ParseCertificate(tlsCert.Certificate[0])
}
} else {
cs.caCert, cs.caPrivateKey, err = cs.generateCaCert()
}
return
}
func (cs *certStore) initProvider() {
if cs.keyProvider == nil {
cs.keyProvider = func() (key interface{}, err error) {
return privateKeyForCurve(cs.options)
}
}
}
func (cs *certStore) generateCaCert() (pubKey *x509.Certificate, privateKey interface{}, err error) {
cs.initProvider()
timeSource := cs.timeSource()
var serialNumber *big.Int
if serialNumber, err = generateSerialNumber(); err != nil {
return
}
if privateKey, err = cs.keyProvider(); err != nil {
return
}
template := x509.Certificate{
SerialNumber: serialNumber,
Subject: pkix.Name{
Organization: []string{"INetMock"},
Country: []string{"US"},
Province: []string{""},
Locality: []string{"San Francisco"},
StreetAddress: []string{"Golden Gate Bridge"},
PostalCode: []string{"94016"},
},
IsCA: true,
NotBefore: timeSource.UTCNow().Add(-cs.options.validity.ca.notBeforeRelative),
NotAfter: timeSource.UTCNow().Add(cs.options.validity.ca.notAfterRelative),
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth},
BasicConstraintsValid: true,
}
var derBytes []byte
if derBytes, err = x509.CreateCertificate(rand.Reader, &template, &template, publicKey(privateKey), privateKey); err != nil {
return
}
pubKey, err = x509.ParseCertificate(derBytes)
if err = writePublicKey(cs.options.rootCaCert.publicKeyPath, derBytes); err != nil {
return
}
var privateKeyBytes []byte
if privateKeyBytes, err = x509.MarshalPKCS8PrivateKey(privateKey); err != nil {
return
}
if err = writePrivateKey(cs.options.rootCaCert.privateKeyPath, privateKeyBytes); err != nil {
return
}
return
}
func (cs *certStore) getCertificate(serverName string, ip string) (cert *tls.Certificate, err error) {
if cert, ok := cs.certCache[serverName]; ok {
return cert, nil
}
certPath := filepath.Join(cs.options.certCachePath, fmt.Sprintf("%s.pem", serverName))
keyPath := filepath.Join(cs.options.certCachePath, fmt.Sprintf("%s.key", serverName))
if path.FileExists(certPath) && path.FileExists(keyPath) {
if tlsCert, loadErr := tls.LoadX509KeyPair(certPath, keyPath); loadErr == nil {
cs.certCache[serverName] = &tlsCert
x509Cert, err := x509.ParseCertificate(tlsCert.Certificate[0])
if err == nil && !certShouldBeRenewed(cs.timeSource(), x509Cert) {
return &tlsCert, nil
}
}
}
if cert, err = cs.generateDomainCert(serverName, ip); err == nil {
cs.certCache[serverName] = cert
}
return
}
func (cs *certStore) generateDomainCert(
serverName string,
localIp string,
) (cert *tls.Certificate, err error) {
cs.initProvider()
if cs.caCert == nil {
if err = cs.initCaCert(); err != nil {
return
}
}
var serialNumber *big.Int
if serialNumber, err = generateSerialNumber(); err != nil {
return
}
var privateKey interface{}
if privateKey, err = cs.keyProvider(); err != nil {
return
}
notBefore := cs.timeSource().UTCNow().Add(-cs.options.validity.domain.notBeforeRelative)
notAfter := cs.timeSource().UTCNow().Add(cs.options.validity.domain.notAfterRelative)
cs.logger.Info(
"generate domain certificate",
zap.Time("notBefore", notBefore),
zap.Time("notAfter", notAfter),
)
template := x509.Certificate{
SerialNumber: serialNumber,
Subject: pkix.Name{
Organization: []string{"INetMock"},
},
IPAddresses: []net.IP{net.ParseIP(localIp)},
DNSNames: []string{serverName},
NotBefore: notBefore,
NotAfter: notAfter,
KeyUsage: x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth},
}
var derBytes []byte
if derBytes, err = x509.CreateCertificate(rand.Reader, &template, cs.caCert, publicKey(privateKey), cs.caPrivateKey); err != nil {
return
}
if err = writePublicKey(filepath.Join(cs.options.certCachePath, fmt.Sprintf("%s.pem", serverName)), derBytes); err != nil {
return
}
var privateKeyBytes []byte
if privateKeyBytes, err = x509.MarshalPKCS8PrivateKey(privateKey); err != nil {
return
}
if cert, err = parseCert(derBytes, privateKeyBytes); err != nil {
return
}
if err = writePrivateKey(filepath.Join(cs.options.certCachePath, fmt.Sprintf("%s.key", serverName)), privateKeyBytes); err != nil {
return
}
return
}

View file

@ -1,213 +0,0 @@
package main
import (
"crypto/tls"
"crypto/x509"
"github.com/baez90/inetmock/pkg/logging"
"go.uber.org/zap"
"io/ioutil"
"os"
"path/filepath"
"reflect"
"testing"
"time"
)
func Test_generateCaCert(t *testing.T) {
tmpDir, err := ioutil.TempDir(os.TempDir(), "*-inetmock")
if err != nil {
t.Errorf("failed to create temp dir %v", err)
return
}
options := &tlsOptions{
ecdsaCurve: "P256",
rootCaCert: cert{
publicKeyPath: filepath.Join(tmpDir, "localhost.pem"),
privateKeyPath: filepath.Join(tmpDir, "localhost.key"),
},
validity: validity{
ca: certValidity{
notBeforeRelative: time.Hour * 24 * 30,
notAfterRelative: time.Hour * 24 * 30,
},
},
}
certStore := certStore{
options: options,
}
defer func() {
_ = os.Remove(tmpDir)
}()
_, _, err = certStore.generateCaCert()
if err != nil {
t.Errorf("failed to generate CA cert %v", err)
}
if _, err = os.Stat(options.rootCaCert.publicKeyPath); err != nil {
t.Errorf("cert file was not created")
}
if _, err = os.Stat(options.rootCaCert.privateKeyPath); err != nil {
t.Errorf("cert file was not created")
}
}
func Test_generateDomainCert(t *testing.T) {
tmpDir, err := ioutil.TempDir(os.TempDir(), "*-inetmock")
if err != nil {
t.Errorf("failed to create temp dir %v", err)
return
}
defer func() {
_ = os.Remove(tmpDir)
}()
caTlsCert, _ := loadPEMCert(testCaCrt, testCaKey)
caCert, _ := x509.ParseCertificate(caTlsCert.Certificate[0])
options := &tlsOptions{
ecdsaCurve: "P256",
certCachePath: tmpDir,
validity: validity{
domain: certValidity{
notAfterRelative: time.Hour * 24 * 30,
notBeforeRelative: time.Hour * 24 * 30,
},
ca: certValidity{
notAfterRelative: time.Hour * 24 * 30,
notBeforeRelative: time.Hour * 24 * 30,
},
},
}
zapLogger, _ := zap.NewDevelopment()
logger := logging.NewLogger(zapLogger)
certStore := certStore{
options: options,
caCert: caCert,
logger: logger,
caPrivateKey: caTlsCert.PrivateKey,
}
type args struct {
domain string
ip string
}
tests := []struct {
name string
args args
wantErr bool
}{
{
name: "Test create google.com cert",
args: args{
domain: "google.com",
ip: "127.0.0.1",
},
wantErr: false,
},
{
name: "Test create golem.de cert",
args: args{
domain: "golem.de",
ip: "127.0.0.1",
},
wantErr: false,
},
{
name: "Test create golem.de cert with any IP address",
args: args{
domain: "golem.de",
ip: "10.10.0.10",
},
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if domainCert, err := certStore.generateDomainCert(
tt.args.domain,
tt.args.ip,
); (err != nil) != tt.wantErr || reflect.DeepEqual(domainCert, tls.Certificate{}) {
t.Errorf("generateDomainCert() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
func Test_certStore_initCaCert(t *testing.T) {
tmpDir, err := ioutil.TempDir(os.TempDir(), "*-inetmock")
if err != nil {
t.Errorf("failed to create temp dir %v", err)
return
}
defer func() {
_ = os.Remove(tmpDir)
}()
caCertPath := filepath.Join(tmpDir, "cacert.pem")
caKeyPath := filepath.Join(tmpDir, "cacert.key")
if err := ioutil.WriteFile(caCertPath, testCaCrt, 0600); err != nil {
t.Errorf("failed to write cacert.pem %v", err)
return
}
if err := ioutil.WriteFile(caKeyPath, testCaKey, 0600); err != nil {
t.Errorf("failed to write cacert.key %v", err)
return
}
type fields struct {
options *tlsOptions
caCert *x509.Certificate
}
tests := []struct {
name string
fields fields
wantErr bool
}{
{
name: "Init CA cert from file",
wantErr: false,
fields: fields{
options: &tlsOptions{
rootCaCert: cert{
publicKeyPath: caCertPath,
privateKeyPath: caKeyPath,
},
},
},
},
{
name: "Init CA with new cert",
wantErr: false,
fields: fields{
options: &tlsOptions{
rootCaCert: cert{
publicKeyPath: filepath.Join(tmpDir, "nonexistent.pem"),
privateKeyPath: filepath.Join(tmpDir, "nonexistent.key"),
},
},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cs := &certStore{
options: tt.fields.options,
caCert: tt.fields.caCert,
}
if err := cs.initCaCert(); (err != nil) != tt.wantErr || cs.caCert == nil {
t.Errorf("initCaCert() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}

View file

@ -1,106 +0,0 @@
package main
import (
"crypto/ecdsa"
"crypto/ed25519"
"crypto/elliptic"
"crypto/rand"
"crypto/tls"
"crypto/x509"
"encoding/pem"
"math/big"
"os"
)
type curveType string
const (
certificateBlockType = "CERTIFICATE"
privateKeyBlockType = "PRIVATE KEY"
curveTypeP224 curveType = "P224"
curveTypeP256 curveType = "P256"
curveTypeP384 curveType = "P384"
curveTypeP521 curveType = "P521"
curveTypeED25519 curveType = "ED25519"
)
func loadPEMCert(certPEMBytes []byte, keyPEMBytes []byte) (*tls.Certificate, error) {
cert, err := tls.X509KeyPair(certPEMBytes, keyPEMBytes)
return &cert, err
}
func parseCert(derBytes []byte, privateKeyBytes []byte) (*tls.Certificate, error) {
pemEncodedPublicKey := pem.EncodeToMemory(&pem.Block{Type: certificateBlockType, Bytes: derBytes})
pemEncodedPrivateKey := pem.EncodeToMemory(&pem.Block{Type: privateKeyBlockType, Bytes: privateKeyBytes})
cert, err := tls.X509KeyPair(pemEncodedPublicKey, pemEncodedPrivateKey)
return &cert, err
}
func writePublicKey(crtOutPath string, derBytes []byte) (err error) {
var certOut *os.File
if certOut, err = os.Create(crtOutPath); err != nil {
return
}
if err = pem.Encode(certOut, &pem.Block{Type: certificateBlockType, Bytes: derBytes}); err != nil {
return
}
err = certOut.Close()
return
}
func writePrivateKey(keyOutPath string, privateKeyBytes []byte) (err error) {
var keyOut *os.File
if keyOut, err = os.OpenFile(keyOutPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600); err != nil {
return
}
if err = pem.Encode(keyOut, &pem.Block{Type: privateKeyBlockType, Bytes: privateKeyBytes}); err != nil {
return
}
err = keyOut.Close()
return
}
func certShouldBeRenewed(timeSource timeSource, cert *x509.Certificate) bool {
lifetime := cert.NotAfter.Sub(cert.NotBefore)
// if the cert is closer to the end of the lifetime than lifetime/2 it should be renewed
if cert.NotAfter.Sub(timeSource.UTCNow()) < lifetime/4 {
return true
}
return false
}
func generateSerialNumber() (*big.Int, error) {
serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
return rand.Int(rand.Reader, serialNumberLimit)
}
func privateKeyForCurve(options *tlsOptions) (privateKey interface{}, err error) {
switch options.ecdsaCurve {
case "P224":
privateKey, err = ecdsa.GenerateKey(elliptic.P224(), rand.Reader)
case "P256":
privateKey, err = ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
case "P384":
privateKey, err = ecdsa.GenerateKey(elliptic.P384(), rand.Reader)
case "P521":
privateKey, err = ecdsa.GenerateKey(elliptic.P521(), rand.Reader)
default:
_, privateKey, err = ed25519.GenerateKey(rand.Reader)
}
return
}
func publicKey(privateKey interface{}) interface{} {
switch k := privateKey.(type) {
case *ecdsa.PrivateKey:
return &k.PublicKey
case ed25519.PrivateKey:
return k.Public().(ed25519.PublicKey)
default:
return nil
}
}

View file

@ -1,74 +0,0 @@
package main
import (
"crypto/x509"
"testing"
"time"
)
type testTimeSource struct {
nowValue time.Time
}
func (t testTimeSource) UTCNow() time.Time {
return t.nowValue
}
func Test_certShouldBeRenewed(t *testing.T) {
type args struct {
timeSource timeSource
cert *x509.Certificate
}
tests := []struct {
name string
args args
want bool
}{
{
name: "Detect cert is expired",
want: true,
args: args{
cert: &x509.Certificate{
NotAfter: time.Now().UTC().Add(1 * time.Hour),
NotBefore: time.Now().UTC().Add(-1 * time.Hour),
},
timeSource: testTimeSource{
nowValue: time.Now().UTC().Add(2 * time.Hour),
},
},
},
{
name: "Detect cert should be renewed",
want: true,
args: args{
cert: &x509.Certificate{
NotAfter: time.Now().UTC().Add(1 * time.Hour),
NotBefore: time.Now().UTC().Add(-1 * time.Hour),
},
timeSource: testTimeSource{
nowValue: time.Now().UTC().Add(45 * time.Minute),
},
},
},
{
name: "Detect cert shouldn't be renewed",
want: false,
args: args{
cert: &x509.Certificate{
NotAfter: time.Now().UTC().Add(1 * time.Hour),
NotBefore: time.Now().UTC().Add(-1 * time.Hour),
},
timeSource: testTimeSource{
nowValue: time.Now().UTC().Add(25 * time.Minute),
},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := certShouldBeRenewed(tt.args.timeSource, tt.args.cert); got != tt.want {
t.Errorf("certShouldBeRenewed() = %v, want %v", got, tt.want)
}
})
}
}

View file

@ -1,113 +0,0 @@
package main
import (
"github.com/baez90/inetmock/pkg/logging"
"github.com/spf13/cobra"
"go.uber.org/zap"
"time"
)
const (
generateCACertOutPath = "cert-out"
generateCAKeyOutPath = "key-out"
generateCACurveName = "curve"
generateCANotBeforeRelative = "not-before"
generateCANotAfterRelative = "not-after"
)
func generateCACmd(logger logging.Logger) *cobra.Command {
cmd := &cobra.Command{
Use: "generate-ca",
Short: "Generate a new CA certificate and corresponding key",
Long: ``,
Run: runGenerateCA(logger),
}
cmd.Flags().String(generateCACertOutPath, "", "Path where CA cert file should be stored")
cmd.Flags().String(generateCAKeyOutPath, "", "Path where CA key file should be stored")
cmd.Flags().String(generateCACurveName, "", "Name of the curve to use, if empty ED25519 is used, other valid values are [P224, P256,P384,P521]")
cmd.Flags().Duration(generateCANotBeforeRelative, 17520*time.Hour, "Relative time value since when in the past the CA certificate should be valid. The value has a time unit, the greatest time unit is h for hour.")
cmd.Flags().Duration(generateCANotAfterRelative, 17520*time.Hour, "Relative time value until when in the future the CA certificate should be valid. The value has a time unit, the greatest time unit is h for hour.")
return cmd
}
func getDurationFlag(cmd *cobra.Command, flagName string, logger logging.Logger) (val time.Duration, err error) {
if val, err = cmd.Flags().GetDuration(flagName); err != nil {
logger.Error(
"failed to parse parse flag",
zap.String("flag", flagName),
zap.Error(err),
)
}
return
}
func getStringFlag(cmd *cobra.Command, flagName string, logger logging.Logger) (val string, err error) {
if val, err = cmd.Flags().GetString(flagName); err != nil {
logger.Error(
"failed to parse parse flag",
zap.String("flag", flagName),
zap.Error(err),
)
}
return
}
func runGenerateCA(logger logging.Logger) func(cmd *cobra.Command, args []string) {
return func(cmd *cobra.Command, args []string) {
var certOutPath, keyOutPath, curveName string
var notBefore, notAfter time.Duration
var err error
if certOutPath, err = getStringFlag(cmd, generateCACertOutPath, logger); err != nil {
return
}
if keyOutPath, err = getStringFlag(cmd, generateCAKeyOutPath, logger); err != nil {
return
}
if curveName, err = getStringFlag(cmd, generateCACurveName, logger); err != nil {
return
}
if notBefore, err = getDurationFlag(cmd, generateCANotBeforeRelative, logger); err != nil {
return
}
if notAfter, err = getDurationFlag(cmd, generateCANotAfterRelative, logger); err != nil {
return
}
logger := logger.With(
zap.String(generateCACurveName, curveName),
zap.String(generateCACertOutPath, certOutPath),
zap.String(generateCAKeyOutPath, keyOutPath),
)
certStore := certStore{
options: &tlsOptions{
ecdsaCurve: curveType(curveName),
validity: validity{
ca: certValidity{
notAfterRelative: notAfter,
notBeforeRelative: notBefore,
},
},
rootCaCert: cert{
publicKeyPath: certOutPath,
privateKeyPath: keyOutPath,
},
},
}
if _, _, err := certStore.generateCaCert(); err != nil {
logger.Error(
"failed to generate CA cert",
zap.Error(err),
)
} else {
logger.Info("Successfully generated CA cert")
}
}
}

View file

@ -4,9 +4,8 @@ go 1.14
require (
github.com/baez90/inetmock v0.0.1
github.com/spf13/cobra v1.0.0
github.com/spf13/viper v1.6.3
go.uber.org/zap v1.14.1
go.uber.org/zap v1.15.0
)
replace github.com/baez90/inetmock v0.0.1 => ../../

View file

@ -147,6 +147,8 @@ go.uber.org/tools v0.0.0-20190618225709-2cfd321de3ee/go.mod h1:vJERXedbb3MVM5f9E
go.uber.org/zap v1.10.0/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q=
go.uber.org/zap v1.14.1 h1:nYDKopTbvAPq/NrUVZwT15y2lpROBiLLyoRTbXOYWOo=
go.uber.org/zap v1.14.1/go.mod h1:Mb2vm2krFEG5DV0W9qcHBYFtp/Wku1cvYaqPsS/WYfc=
go.uber.org/zap v1.15.0 h1:ZZCA22JRF2gQE5FoNmhmrf7jeJJ2uhqDUNRYKm8dvmM=
go.uber.org/zap v1.15.0/go.mod h1:Mb2vm2krFEG5DV0W9qcHBYFtp/Wku1cvYaqPsS/WYfc=
golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
@ -178,6 +180,8 @@ golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7w
golang.org/x/sys v0.0.0-20191005200804-aed5e4c7ecf9/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200409092240-59c9f1ba88fa h1:mQTN3ECqfsViCNBgq+A40vdwhkGykrrQlYe3mPj6BoU=
golang.org/x/sys v0.0.0-20200409092240-59c9f1ba88fa/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200420163511-1957bb5e6d1f h1:gWF768j/LaZugp8dyS4UwsslYCYz9XgFxvlgsn0n9H8=
golang.org/x/sys v0.0.0-20200420163511-1957bb5e6d1f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.2 h1:tW2bmiBqwgJj/UpqtC8EpXEZVYOwU0yG4iWbprSVAcs=

View file

@ -19,5 +19,5 @@ func init() {
logger: logger,
currentConnectionsCount: &sync.WaitGroup{},
}
}, generateCACmd(logger))
})
}

View file

@ -2,9 +2,9 @@ package main
import (
"crypto/tls"
"crypto/x509"
"fmt"
"github.com/baez90/inetmock/internal/config"
"github.com/baez90/inetmock/pkg/api"
"github.com/baez90/inetmock/pkg/cert"
"github.com/baez90/inetmock/pkg/logging"
"go.uber.org/zap"
"net"
@ -17,16 +17,16 @@ const (
)
type tlsInterceptor struct {
options tlsOptions
logger logging.Logger
listener net.Listener
certStore *certStore
options *tlsOptions
certStore cert.Store
shutdownRequested bool
currentConnectionsCount *sync.WaitGroup
currentConnections []*proxyConn
}
func (t *tlsInterceptor) Start(config config.HandlerConfig) (err error) {
func (t *tlsInterceptor) Start(config api.HandlerConfig) (err error) {
t.options = loadFromConfig(config.Options())
addr := fmt.Sprintf("%s:%d", config.ListenAddress(), config.Port())
@ -35,33 +35,7 @@ func (t *tlsInterceptor) Start(config config.HandlerConfig) (err error) {
zap.String("target", t.options.redirectionTarget.address()),
)
t.certStore = &certStore{
options: t.options,
certCache: make(map[string]*tls.Certificate),
logger: t.logger,
}
if err = t.certStore.initCaCert(); err != nil {
t.logger.Error(
"failed to initialize CA cert",
zap.Error(err),
)
err = fmt.Errorf(
"failed to initialize CA cert: %w",
err,
)
return
}
rootCaPool := x509.NewCertPool()
rootCaPool.AddCert(t.certStore.caCert)
tlsConfig := &tls.Config{
GetCertificate: t.getCert,
RootCAs: rootCaPool,
}
if t.listener, err = tls.Listen("tcp", addr, tlsConfig); err != nil {
if t.listener, err = tls.Listen("tcp", addr, api.ServicesInstance().CertStore().TLSConfig()); err != nil {
t.logger.Fatal(
"failed to create tls listener",
zap.Error(err),
@ -122,23 +96,6 @@ func (t *tlsInterceptor) startListener() {
}
}
func (t *tlsInterceptor) getCert(info *tls.ClientHelloInfo) (cert *tls.Certificate, err error) {
var localIp string
if localIp, err = extractIPFromAddress(info.Conn.LocalAddr().String()); err != nil {
localIp = "127.0.0.1"
}
if cert, err = t.certStore.getCertificate(info.ServerName, localIp); err != nil {
t.logger.Error(
"error while resolving certificate",
zap.String("serverName", info.ServerName),
zap.String("localAddr", localIp),
zap.Error(err),
)
}
return
}
func (t *tlsInterceptor) proxyConn(conn net.Conn) {
defer conn.Close()

View file

@ -3,37 +3,13 @@ package main
import (
"fmt"
"github.com/spf13/viper"
"time"
)
const (
certCachePathConfigKey = "certCachePath"
ecdsaCurveConfigKey = "ecdsaCurve"
targetIpAddressConfigKey = "target.ipAddress"
targetPortConfigKey = "target.port"
publicKeyConfigKey = "rootCaCert.publicKey"
privateKeyPathConfigKey = "rootCaCert.privateKey"
caCertValidityNotBeforeKey = "validity.ca.notBeforeRelative"
caCertValidityNotAfterKey = "validity.ca.notAfterRelative"
domainCertValidityNotBeforeKey = "validity.domain.notBeforeRelative"
domainCertValidityNotAfterKey = "validity.domain.notAfterRelative"
)
type cert struct {
publicKeyPath string
privateKeyPath string
}
type certValidity struct {
notBeforeRelative time.Duration
notAfterRelative time.Duration
}
type validity struct {
ca certValidity
domain certValidity
}
type redirectionTarget struct {
ipAddress string
port uint16
@ -44,35 +20,14 @@ func (rt redirectionTarget) address() string {
}
type tlsOptions struct {
rootCaCert cert
certCachePath string
redirectionTarget redirectionTarget
ecdsaCurve curveType
validity validity
}
func loadFromConfig(config *viper.Viper) *tlsOptions {
return &tlsOptions{
certCachePath: config.GetString(certCachePathConfigKey),
ecdsaCurve: curveType(config.GetString(ecdsaCurveConfigKey)),
func loadFromConfig(config *viper.Viper) tlsOptions {
return tlsOptions{
redirectionTarget: redirectionTarget{
ipAddress: config.GetString(targetIpAddressConfigKey),
port: uint16(config.GetInt(targetPortConfigKey)),
},
validity: validity{
ca: certValidity{
notBeforeRelative: config.GetDuration(caCertValidityNotBeforeKey),
notAfterRelative: config.GetDuration(caCertValidityNotAfterKey),
},
domain: certValidity{
notBeforeRelative: config.GetDuration(domainCertValidityNotBeforeKey),
notAfterRelative: config.GetDuration(domainCertValidityNotAfterKey),
},
},
rootCaCert: cert{
publicKeyPath: config.GetString(publicKeyConfigKey),
privateKeyPath: config.GetString(privateKeyPathConfigKey),
},
}
}

View file

@ -1,55 +0,0 @@
package main
import (
"fmt"
"io/ioutil"
"os"
"path/filepath"
"time"
)
var (
testCaCrt []byte
testCaKey []byte
)
func init() {
tmpDir, err := ioutil.TempDir(os.TempDir(), "*-inetmock")
if err != nil {
panic(fmt.Sprintf("failed to create temp dir %v", err))
}
options := &tlsOptions{
ecdsaCurve: "P256",
rootCaCert: cert{
publicKeyPath: filepath.Join(tmpDir, "localhost.pem"),
privateKeyPath: filepath.Join(tmpDir, "localhost.key"),
},
validity: validity{
ca: certValidity{
notBeforeRelative: time.Hour * 24 * 30,
notAfterRelative: time.Hour * 24 * 30,
},
},
}
certStore := certStore{
options: options,
keyProvider: func() (key interface{}, err error) {
return privateKeyForCurve(options)
},
}
defer func() {
_ = os.Remove(tmpDir)
}()
_, _, err = certStore.generateCaCert()
testCaCrt, _ = ioutil.ReadFile(options.rootCaCert.publicKeyPath)
testCaKey, _ = ioutil.ReadFile(options.rootCaCert.privateKeyPath)
if err != nil {
panic(fmt.Sprintf("failed to generate CA cert %v", err))
}
}

View file

@ -1,18 +0,0 @@
package main
import "time"
type timeSource interface {
UTCNow() time.Time
}
func createTimeSource() timeSource {
return &defaultTimeSource{}
}
type defaultTimeSource struct {
}
func (d defaultTimeSource) UTCNow() time.Time {
return time.Now().UTC()
}