Move TLS/cert handling to main app

- apply changes in proxy plugin and TLS interceptor
- add HTTPS proxy support
- move ca-generation command to main app
- minor refactoring to improve API stability
- move mocks to extra packages to avoid cycling imports
- fix bug in multi-port configuration
- change HTTP proxy to redirect to HTTP mock instead of maintaining custom rules
This commit is contained in:
Peter 2020-04-26 00:22:45 +02:00
parent 43d3c62e01
commit 7c2a41ad25
65 changed files with 1722 additions and 1335 deletions

View file

@ -3,16 +3,22 @@
############### ###############
.git/ .git/
plugins/ .idea/
.vscode/
deploy/
dist/
doc/
######### #########
# Files # # Files #
######### #########
*.so
*.out *.out
main main
inetmock inetmock
README.md README.md
LICENSE
.dockerignore .dockerignore
.gitignore .gitignore
Dockerfile Dockerfile

View file

@ -3,7 +3,7 @@
before: before:
hooks: hooks:
# You may remove this if you don't use go modules. # You may remove this if you don't use go modules.
- make handlers - make plugins
builds: builds:
- id: "default" - id: "default"
ldflags: ldflags:

View file

@ -38,8 +38,8 @@ WORKDIR /app
COPY --from=build /etc/passwd /etc/group /etc/ COPY --from=build /etc/passwd /etc/group /etc/
COPY --from=build --chown=$USER /work/inetmock ./ COPY --from=build --chown=$USER /work/inetmock ./
COPY --from=build --chown=$USER /work/plugins/ ./plugins/ COPY --from=build --chown=$USER /work/*.so ./plugins/
USER $USER:$USER USER $USER:$USER
ENTRYPOINT ["/app/inetmock"] ENTRYPOINT ["/app/inetmock"]

View file

@ -1,9 +1,8 @@
VERSst/pluginsION = $(shell git describe --dirty --tags --always) VERSION = $(shell git describe --dirty --tags --always)
DIR = $(dir $(realpath $(firstword $(MAKEFILE_LIST)))) DIR = $(dir $(realpath $(firstword $(MAKEFILE_LIST))))
BUILD_PATH = $(DIR)/main.go BUILD_PATH = $(DIR)/main.go
PKGS = $(shell go list ./...) PKGS = $(shell go list ./...)
TEST_PKGS = $(shell find . -type f -name "*_test.go" -not -path "./plugins/*" -printf '%h\n' | sort -u) TEST_PKGS = $(shell find . -type f -name "*_test.go" -not -path "./plugins/*" -not -path "*/mock/*" -printf '%h\n' | sort -u)
GO_GEN_FILES = $(shell grep -rnwl --include="*.go" "go:generate" $(DIR))
GOARGS = GOOS=linux GOARCH=amd64 GOARGS = GOOS=linux GOARCH=amd64
GO_BUILD_ARGS = -ldflags="-w -s" GO_BUILD_ARGS = -ldflags="-w -s"
GO_CONTAINER_BUILD_ARGS = -ldflags="-w -s" -a -installsuffix cgo GO_CONTAINER_BUILD_ARGS = -ldflags="-w -s" -a -installsuffix cgo
@ -12,8 +11,10 @@ BINARY_NAME = inetmock
PLUGINS = $(wildcard $(DIR)plugins/*/.) PLUGINS = $(wildcard $(DIR)plugins/*/.)
DEBUG_PORT = 2345 DEBUG_PORT = 2345
DEBUG_ARGS?= --development-logs=true DEBUG_ARGS?= --development-logs=true
CONTAINER_BUILDER ?= podman
DOCKER_IMAGE ?= inetmock
.PHONY: clean all format deps update-deps compile debug generate snapshot-release test cli-cover-report html-cover-report plugins $(PLUGINS) $(.PHONY: clean all format deps update-deps compile debug generate snapshot-release test cli-cover-report html-cover-report plugins $(PLUGINS) $(GO_GEN_FILES) .PHONY: clean all format deps update-deps compile debug generate snapshot-release test cli-cover-report html-cover-report plugins $(PLUGINS) $(GO_GEN_FILES)
all: clean format compile test plugins all: clean format compile test plugins
clean: clean:
@ -45,26 +46,26 @@ endif
debug: export INETMOCK_PLUGINS_DIRECTORY = $(DIR) debug: export INETMOCK_PLUGINS_DIRECTORY = $(DIR)
debug: debug:
dlv exec $(DIR)$(BINARY_NAME) \ dlv debug $(DIR) \
--headless \ --headless \
--listen=:2345 \ --listen=:2345 \
--api-version=2 \ --api-version=2 \
--accept-multiclient \
-- $(DEBUG_ARGS) -- $(DEBUG_ARGS)
generate: generate:
@for go_gen_target in $(GO_GEN_FILES); do \ @go generate ./...
go generate $$go_gen_target; \
done
snapshot-release: snapshot-release:
@goreleaser release --snapshot --skip-publish --rm-dist @goreleaser release --snapshot --skip-publish --rm-dist
container:
@$(CONTAINER_BUILDER) build -t $(DOCKER_IMAGE):latest -f $(DIR)Dockerfile $(DIR)
test: test:
@go test -coverprofile=./cov-raw.out -v $(TEST_PKGS) @go test -coverprofile=./cov-raw.out -v $(TEST_PKGS)
@cat ./cov-raw.out | grep -v "generated" > ./cov.out @cat ./cov-raw.out | grep -v "generated" > ./cov.out
cli-cover-report: cli-cover-report: test
@go tool cover -func=cov.out @go tool cover -func=cov.out
html-cover-report: test html-cover-report: test
@ -72,4 +73,4 @@ html-cover-report: test
plugins: $(PLUGINS) plugins: $(PLUGINS)
$(PLUGINS): $(PLUGINS):
$(MAKE) -C $@ $(MAKE) -C $@

4
go.mod
View file

@ -13,8 +13,8 @@ require (
github.com/spf13/jwalterweatherman v1.1.0 // indirect github.com/spf13/jwalterweatherman v1.1.0 // indirect
github.com/spf13/pflag v1.0.5 github.com/spf13/pflag v1.0.5
github.com/spf13/viper v1.6.3 github.com/spf13/viper v1.6.3
go.uber.org/zap v1.14.1 go.uber.org/zap v1.15.0
golang.org/x/sys v0.0.0-20200409092240-59c9f1ba88fa // indirect golang.org/x/sys v0.0.0-20200420163511-1957bb5e6d1f // indirect
golang.org/x/text v0.3.2 // indirect golang.org/x/text v0.3.2 // indirect
golang.org/x/tools v0.0.0-20191216052735-49a3e744a425 // indirect golang.org/x/tools v0.0.0-20191216052735-49a3e744a425 // indirect
gopkg.in/ini.v1 v1.55.0 // indirect gopkg.in/ini.v1 v1.55.0 // indirect

8
go.sum
View file

@ -147,8 +147,8 @@ go.uber.org/multierr v1.5.0/go.mod h1:FeouvMocqHpRaaGuG9EjoKcStLC43Zu/fmqdUMPcKY
go.uber.org/tools v0.0.0-20190618225709-2cfd321de3ee h1:0mgffUl7nfd+FpvXMVz4IDEaUSmT1ysygQC7qYo7sG4= go.uber.org/tools v0.0.0-20190618225709-2cfd321de3ee h1:0mgffUl7nfd+FpvXMVz4IDEaUSmT1ysygQC7qYo7sG4=
go.uber.org/tools v0.0.0-20190618225709-2cfd321de3ee/go.mod h1:vJERXedbb3MVM5f9Ejo0C68/HhF8uaILCdgjnY+goOA= go.uber.org/tools v0.0.0-20190618225709-2cfd321de3ee/go.mod h1:vJERXedbb3MVM5f9Ejo0C68/HhF8uaILCdgjnY+goOA=
go.uber.org/zap v1.10.0/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q= go.uber.org/zap v1.10.0/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q=
go.uber.org/zap v1.14.1 h1:nYDKopTbvAPq/NrUVZwT15y2lpROBiLLyoRTbXOYWOo= go.uber.org/zap v1.15.0 h1:ZZCA22JRF2gQE5FoNmhmrf7jeJJ2uhqDUNRYKm8dvmM=
go.uber.org/zap v1.14.1/go.mod h1:Mb2vm2krFEG5DV0W9qcHBYFtp/Wku1cvYaqPsS/WYfc= go.uber.org/zap v1.15.0/go.mod h1:Mb2vm2krFEG5DV0W9qcHBYFtp/Wku1cvYaqPsS/WYfc=
golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
@ -178,8 +178,8 @@ golang.org/x/sys v0.0.0-20181116152217-5ac8a444bdc5/go.mod h1:STP8DvDyc/dI5b8T5h
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20191005200804-aed5e4c7ecf9/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191005200804-aed5e4c7ecf9/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200409092240-59c9f1ba88fa h1:mQTN3ECqfsViCNBgq+A40vdwhkGykrrQlYe3mPj6BoU= golang.org/x/sys v0.0.0-20200420163511-1957bb5e6d1f h1:gWF768j/LaZugp8dyS4UwsslYCYz9XgFxvlgsn0n9H8=
golang.org/x/sys v0.0.0-20200409092240-59c9f1ba88fa/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200420163511-1957bb5e6d1f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.2 h1:tW2bmiBqwgJj/UpqtC8EpXEZVYOwU0yG4iWbprSVAcs= golang.org/x/text v0.3.2 h1:tW2bmiBqwgJj/UpqtC8EpXEZVYOwU0yG4iWbprSVAcs=

144
internal/cmd/ca.go Normal file
View file

@ -0,0 +1,144 @@
package cmd
import (
"crypto/tls"
"crypto/x509"
"github.com/baez90/inetmock/pkg/cert"
"github.com/baez90/inetmock/pkg/logging"
"github.com/spf13/cobra"
"go.uber.org/zap"
"time"
)
const (
generateCACommonName = "cn"
generateCaOrganizationName = "o"
generateCaOrganizationalUnitName = "ou"
generateCaCountryName = "c"
generateCaLocalityName = "l"
generateCaStateName = "st"
generateCaStreetAddressName = "street-address"
generateCaPostalCodeName = "postal-code"
generateCACertOutPath = "out-dir"
generateCACurveName = "curve"
generateCANotBeforeRelative = "not-before"
generateCANotAfterRelative = "not-after"
)
var (
generateCaCmd *cobra.Command
caCertOptions cert.GenerationOptions
)
func init() {
generateCaCmd = &cobra.Command{
Use: "generate-ca",
Short: "Generate a new CA certificate and corresponding key",
Long: ``,
Run: runGenerateCA,
}
generateCaCmd.Flags().StringVar(&caCertOptions.CommonName, generateCACommonName, "INetMock", "Certificate Common Name that will also be used as file name during generation.")
generateCaCmd.Flags().StringSliceVar(&caCertOptions.Organization, generateCaOrganizationName, nil, "Organization information to append to certificate")
generateCaCmd.Flags().StringSliceVar(&caCertOptions.OrganizationalUnit, generateCaOrganizationalUnitName, nil, "Organizational unit information to append to certificate")
generateCaCmd.Flags().StringSliceVar(&caCertOptions.Country, generateCaCountryName, nil, "Country information to append to certificate")
generateCaCmd.Flags().StringSliceVar(&caCertOptions.Province, generateCaStateName, nil, "State information to append to certificate")
generateCaCmd.Flags().StringSliceVar(&caCertOptions.Locality, generateCaLocalityName, nil, "Locality information to append to certificate")
generateCaCmd.Flags().StringSliceVar(&caCertOptions.StreetAddress, generateCaStreetAddressName, nil, "Street address information to append to certificate")
generateCaCmd.Flags().StringSliceVar(&caCertOptions.PostalCode, generateCaPostalCodeName, nil, "Postal code information to append to certificate")
generateCaCmd.Flags().String(generateCACertOutPath, "", "Path where CA files should be stored")
generateCaCmd.Flags().String(generateCACurveName, "", "Name of the curve to use, if empty ED25519 is used, other valid values are [P224, P256,P384,P521]")
generateCaCmd.Flags().Duration(generateCANotBeforeRelative, 17520*time.Hour, "Relative time value since when in the past the CA certificate should be valid. The value has a time unit, the greatest time unit is h for hour.")
generateCaCmd.Flags().Duration(generateCANotAfterRelative, 17520*time.Hour, "Relative time value until when in the future the CA certificate should be valid. The value has a time unit, the greatest time unit is h for hour.")
}
func runGenerateCA(cmd *cobra.Command, args []string) {
var certOutPath, curveName string
var notBefore, notAfter time.Duration
var err error
if certOutPath, err = getStringFlag(generateCaCmd, generateCACertOutPath, logger); err != nil {
return
}
if curveName, err = getStringFlag(generateCaCmd, generateCACurveName, logger); err != nil {
return
}
if notBefore, err = getDurationFlag(generateCaCmd, generateCANotBeforeRelative, logger); err != nil {
return
}
if notAfter, err = getDurationFlag(generateCaCmd, generateCANotAfterRelative, logger); err != nil {
return
}
logger, _ := logging.CreateLogger()
logger = logger.With(
zap.String(generateCACurveName, curveName),
zap.String(generateCACertOutPath, certOutPath),
)
generator := cert.NewDefaultGenerator(cert.Options{
CertCachePath: certOutPath,
Curve: cert.CurveType(curveName),
Validity: cert.ValidityByPurpose{
CA: cert.ValidityDuration{
NotAfterRelative: notAfter,
NotBeforeRelative: notBefore,
},
},
})
var caCrt *tls.Certificate
if caCrt, err = generator.CACert(caCertOptions); err != nil {
logger.Error(
"failed to generate CA certificate",
zap.Error(err),
)
return
}
if len(caCrt.Certificate) < 1 {
logger.Error("no public key given for generated CA certificate")
return
}
var pubKey *x509.Certificate
if pubKey, err = x509.ParseCertificate(caCrt.Certificate[0]); err != nil {
logger.Error(
"failed to parse public key from generated CA",
zap.Error(err),
)
return
}
pemCrt := cert.NewPEM(caCrt)
if err = pemCrt.Write(pubKey.Subject.CommonName, certOutPath); err != nil {
logger.Error(
"failed to write Ca files",
zap.Error(err),
)
}
logger.Info("completed certificate generation")
}
func getDurationFlag(cmd *cobra.Command, flagName string, logger logging.Logger) (val time.Duration, err error) {
if val, err = cmd.Flags().GetDuration(flagName); err != nil {
logger.Error(
"failed to parse parse flag",
zap.String("flag", flagName),
zap.Error(err),
)
}
return
}
func getStringFlag(cmd *cobra.Command, flagName string, logger logging.Logger) (val string, err error) {
if val, err = cmd.Flags().GetString(flagName); err != nil {
logger.Error(
"failed to parse parse flag",
zap.String("flag", flagName),
zap.Error(err),
)
}
return
}

View file

@ -3,6 +3,7 @@ package cmd
import ( import (
"github.com/baez90/inetmock/internal/endpoints" "github.com/baez90/inetmock/internal/endpoints"
"github.com/baez90/inetmock/internal/plugins" "github.com/baez90/inetmock/internal/plugins"
"github.com/baez90/inetmock/pkg/api"
"github.com/baez90/inetmock/pkg/logging" "github.com/baez90/inetmock/pkg/logging"
"github.com/baez90/inetmock/pkg/path" "github.com/baez90/inetmock/pkg/path"
"github.com/spf13/viper" "github.com/spf13/viper"
@ -29,9 +30,8 @@ func initApp() (err error) {
registry := plugins.Registry() registry := plugins.Registry()
endpointManager = endpoints.NewEndpointManager(logger) endpointManager = endpoints.NewEndpointManager(logger)
if err = rootCmd.ParseFlags(os.Args); err != nil { rootCmd.Flags().ParseErrorsWhitelist.UnknownFlags = false
return _ = rootCmd.ParseFlags(os.Args)
}
if err = appConfig.ReadConfig(configFilePath); err != nil { if err = appConfig.ReadConfig(configFilePath); err != nil {
logger.Error( logger.Error(
@ -43,6 +43,14 @@ func initApp() (err error) {
viperInst := viper.GetViper() viperInst := viper.GetViper()
pluginDir := viperInst.GetString("plugins-directory") pluginDir := viperInst.GetString("plugins-directory")
if err = api.InitServices(viperInst, logger); err != nil {
logger.Error(
"failed to initialize app services",
zap.Error(err),
)
}
pluginLoadStartTime := time.Now() pluginLoadStartTime := time.Now()
if err = registry.LoadPlugins(pluginDir); err != nil { if err = registry.LoadPlugins(pluginDir); err != nil {
logger.Error("Failed to load plugins", logger.Error("Failed to load plugins",

View file

@ -18,5 +18,5 @@ The easiest way to explore what commands are available is to start with 'inetmoc
This help page contains a list of available sub-commands starting with the name of the plugin as a prefix. This help page contains a list of available sub-commands starting with the name of the plugin as a prefix.
`, `,
} }
rootCmd.AddCommand(pluginsCmd) rootCmd.AddCommand(pluginsCmd, generateCaCmd)
} }

View file

@ -1,6 +1,7 @@
package config package config
import ( import (
"github.com/baez90/inetmock/pkg/api"
"github.com/spf13/viper" "github.com/spf13/viper"
) )
@ -9,7 +10,7 @@ type MultiHandlerConfig interface {
ListenAddress() string ListenAddress() string
Ports() []uint16 Ports() []uint16
Options() *viper.Viper Options() *viper.Viper
HandlerConfigs() []HandlerConfig HandlerConfigs() []api.HandlerConfig
} }
type multiHandlerConfig struct { type multiHandlerConfig struct {
@ -39,10 +40,10 @@ func (m multiHandlerConfig) Options() *viper.Viper {
return m.options return m.options
} }
func (m multiHandlerConfig) HandlerConfigs() []HandlerConfig { func (m multiHandlerConfig) HandlerConfigs() []api.HandlerConfig {
configs := make([]HandlerConfig, 0) configs := make([]api.HandlerConfig, 0)
for _, port := range m.ports { for _, port := range m.ports {
configs = append(configs, NewHandlerConfig(m.handlerName, port, m.listenAddress, m.options)) configs = append(configs, api.NewHandlerConfig(m.handlerName, port, m.listenAddress, m.options))
} }
return configs return configs
} }

View file

@ -1,6 +1,7 @@
package config package config
import ( import (
"github.com/baez90/inetmock/pkg/api"
"github.com/spf13/viper" "github.com/spf13/viper"
"reflect" "reflect"
"testing" "testing"
@ -16,12 +17,12 @@ func Test_multiHandlerConfig_HandlerConfigs(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
fields fields fields fields
want []HandlerConfig want []api.HandlerConfig
}{ }{
{ {
name: "Get empty array if no ports are set", name: "Get empty array if no ports are set",
fields: fields{}, fields: fields{},
want: make([]HandlerConfig, 0), want: make([]api.HandlerConfig, 0),
}, },
{ {
name: "Get a single handler config if only one port is set", name: "Get a single handler config if only one port is set",
@ -31,13 +32,8 @@ func Test_multiHandlerConfig_HandlerConfigs(t *testing.T) {
listenAddress: "0.0.0.0", listenAddress: "0.0.0.0",
options: nil, options: nil,
}, },
want: []HandlerConfig{ want: []api.HandlerConfig{
&handlerConfig{ api.NewHandlerConfig("sampleHandler", 80, "0.0.0.0", nil),
handlerName: "sampleHandler",
port: 80,
listenAddress: "0.0.0.0",
options: nil,
},
}, },
}, },
{ {
@ -48,19 +44,9 @@ func Test_multiHandlerConfig_HandlerConfigs(t *testing.T) {
listenAddress: "0.0.0.0", listenAddress: "0.0.0.0",
options: nil, options: nil,
}, },
want: []HandlerConfig{ want: []api.HandlerConfig{
&handlerConfig{ api.NewHandlerConfig("sampleHandler", 80, "0.0.0.0", nil),
handlerName: "sampleHandler", api.NewHandlerConfig("sampleHandler", 8080, "0.0.0.0", nil),
port: 80,
listenAddress: "0.0.0.0",
options: nil,
},
&handlerConfig{
handlerName: "sampleHandler",
port: 8080,
listenAddress: "0.0.0.0",
options: nil,
},
}, },
}, },
} }

View file

@ -1,6 +1,15 @@
package config package config
import "github.com/spf13/viper" import (
"github.com/spf13/viper"
)
const (
pluginConfigKey = "handler"
listenAddressConfigKey = "listenAddress"
portConfigKey = "port"
portsConfigKey = "ports"
)
func CreateMultiHandlerConfig(handlerConfig *viper.Viper) MultiHandlerConfig { func CreateMultiHandlerConfig(handlerConfig *viper.Viper) MultiHandlerConfig {
return NewMultiHandlerConfig( return NewMultiHandlerConfig(

View file

@ -1,8 +1,7 @@
//go:generate mockgen -source=endpoint.go -destination=./../../internal/mock/endpoint_mock.go -package=mock //go:generate mockgen -source=endpoint.go -destination=./../../internal/mock/endpoints/endpoint_mock.go -package=endpoints_mock
package endpoints package endpoints
import ( import (
"github.com/baez90/inetmock/internal/config"
"github.com/baez90/inetmock/pkg/api" "github.com/baez90/inetmock/pkg/api"
) )
@ -15,7 +14,7 @@ type Endpoint interface {
type endpoint struct { type endpoint struct {
name string name string
handler api.ProtocolHandler handler api.ProtocolHandler
config config.HandlerConfig config api.HandlerConfig
} }
func (e endpoint) Name() string { func (e endpoint) Name() string {

View file

@ -40,19 +40,20 @@ func (e endpointManager) StartedEndpoints() []Endpoint {
return e.properlyStartedEndpoints return e.properlyStartedEndpoints
} }
func (e *endpointManager) CreateEndpoint(name string, multiHandlerConfig config.MultiHandlerConfig) (err error) { func (e *endpointManager) CreateEndpoint(name string, multiHandlerConfig config.MultiHandlerConfig) error {
if handler, ok := e.registry.HandlerForName(multiHandlerConfig.HandlerName()); ok { for _, handlerConfig := range multiHandlerConfig.HandlerConfigs() {
for _, handlerConfig := range multiHandlerConfig.HandlerConfigs() { if handler, ok := e.registry.HandlerForName(multiHandlerConfig.HandlerName()); ok {
e.registeredEndpoints = append(e.registeredEndpoints, &endpoint{ e.registeredEndpoints = append(e.registeredEndpoints, &endpoint{
name: name, name: name,
handler: handler, handler: handler,
config: handlerConfig, config: handlerConfig,
}) })
} else {
return fmt.Errorf("no matching handler registered for names %s", multiHandlerConfig.HandlerName())
} }
} else {
err = fmt.Errorf("no matching handler registered for name %s", multiHandlerConfig.HandlerName())
} }
return
return nil
} }
func (e *endpointManager) StartEndpoints() { func (e *endpointManager) StartEndpoints() {

View file

@ -2,7 +2,9 @@ package endpoints
import ( import (
"github.com/baez90/inetmock/internal/config" "github.com/baez90/inetmock/internal/config"
"github.com/baez90/inetmock/internal/mock" api_mock "github.com/baez90/inetmock/internal/mock/api"
logging_mock "github.com/baez90/inetmock/internal/mock/logging"
plugins_mock "github.com/baez90/inetmock/internal/mock/plugins"
"github.com/baez90/inetmock/internal/plugins" "github.com/baez90/inetmock/internal/plugins"
"github.com/baez90/inetmock/pkg/logging" "github.com/baez90/inetmock/pkg/logging"
"github.com/golang/mock/gomock" "github.com/golang/mock/gomock"
@ -33,18 +35,18 @@ func Test_endpointManager_CreateEndpoint(t *testing.T) {
wantEndpoints: 1, wantEndpoints: 1,
fields: fields{ fields: fields{
logger: func() logging.Logger { logger: func() logging.Logger {
return mock.NewMockLogger(gomock.NewController(t)) return logging_mock.NewMockLogger(gomock.NewController(t))
}(), }(),
registeredEndpoints: nil, registeredEndpoints: nil,
properlyStartedEndpoints: nil, properlyStartedEndpoints: nil,
registry: func() plugins.HandlerRegistry { registry: func() plugins.HandlerRegistry {
registry := mock.NewMockHandlerRegistry(gomock.NewController(t)) registry := plugins_mock.NewMockHandlerRegistry(gomock.NewController(t))
registry. registry.
EXPECT(). EXPECT().
HandlerForName("sampleHandler"). HandlerForName("sampleHandler").
MinTimes(1). MinTimes(1).
MaxTimes(1). MaxTimes(1).
Return(mock.NewMockProtocolHandler(gomock.NewController(t)), true) Return(api_mock.NewMockProtocolHandler(gomock.NewController(t)), true)
return registry return registry
}(), }(),
}, },
@ -64,12 +66,12 @@ func Test_endpointManager_CreateEndpoint(t *testing.T) {
wantEndpoints: 0, wantEndpoints: 0,
fields: fields{ fields: fields{
logger: func() logging.Logger { logger: func() logging.Logger {
return mock.NewMockLogger(gomock.NewController(t)) return logging_mock.NewMockLogger(gomock.NewController(t))
}(), }(),
registeredEndpoints: nil, registeredEndpoints: nil,
properlyStartedEndpoints: nil, properlyStartedEndpoints: nil,
registry: func() plugins.HandlerRegistry { registry: func() plugins.HandlerRegistry {
registry := mock.NewMockHandlerRegistry(gomock.NewController(t)) registry := plugins_mock.NewMockHandlerRegistry(gomock.NewController(t))
registry. registry.
EXPECT(). EXPECT().
HandlerForName("sampleHandler"). HandlerForName("sampleHandler").

View file

@ -2,8 +2,7 @@ package endpoints
import ( import (
"fmt" "fmt"
"github.com/baez90/inetmock/internal/config" api_mock "github.com/baez90/inetmock/internal/mock/api"
"github.com/baez90/inetmock/internal/mock"
"github.com/baez90/inetmock/pkg/api" "github.com/baez90/inetmock/pkg/api"
"github.com/golang/mock/gomock" "github.com/golang/mock/gomock"
"testing" "testing"
@ -13,7 +12,7 @@ func Test_endpoint_Name(t *testing.T) {
type fields struct { type fields struct {
name string name string
handler api.ProtocolHandler handler api.ProtocolHandler
config config.HandlerConfig config api.HandlerConfig
} }
tests := []struct { tests := []struct {
name string name string
@ -51,7 +50,7 @@ func Test_endpoint_Shutdown(t *testing.T) {
type fields struct { type fields struct {
name string name string
handler api.ProtocolHandler handler api.ProtocolHandler
config config.HandlerConfig config api.HandlerConfig
} }
tests := []struct { tests := []struct {
name string name string
@ -62,7 +61,7 @@ func Test_endpoint_Shutdown(t *testing.T) {
name: "Expect no error if mocked handler does not return one", name: "Expect no error if mocked handler does not return one",
fields: fields{ fields: fields{
handler: func() api.ProtocolHandler { handler: func() api.ProtocolHandler {
handler := mock.NewMockProtocolHandler(gomock.NewController(t)) handler := api_mock.NewMockProtocolHandler(gomock.NewController(t))
handler.EXPECT(). handler.EXPECT().
Shutdown(). Shutdown().
MaxTimes(1). MaxTimes(1).
@ -76,7 +75,7 @@ func Test_endpoint_Shutdown(t *testing.T) {
name: "Expect error if mocked handler returns one", name: "Expect error if mocked handler returns one",
fields: fields{ fields: fields{
handler: func() api.ProtocolHandler { handler: func() api.ProtocolHandler {
handler := mock.NewMockProtocolHandler(gomock.NewController(t)) handler := api_mock.NewMockProtocolHandler(gomock.NewController(t))
handler.EXPECT(). handler.EXPECT().
Shutdown(). Shutdown().
MaxTimes(1). MaxTimes(1).
@ -103,7 +102,7 @@ func Test_endpoint_Shutdown(t *testing.T) {
func Test_endpoint_Start(t *testing.T) { func Test_endpoint_Start(t *testing.T) {
demoHandlerConfig := config.NewHandlerConfig( demoHandlerConfig := api.NewHandlerConfig(
"sampleHandler", "sampleHandler",
80, 80,
"0.0.0.0", "0.0.0.0",
@ -113,7 +112,7 @@ func Test_endpoint_Start(t *testing.T) {
type fields struct { type fields struct {
name string name string
handler api.ProtocolHandler handler api.ProtocolHandler
config config.HandlerConfig config api.HandlerConfig
} }
tests := []struct { tests := []struct {
name string name string
@ -124,7 +123,7 @@ func Test_endpoint_Start(t *testing.T) {
name: "Expect no error if mocked handler does not return one", name: "Expect no error if mocked handler does not return one",
fields: fields{ fields: fields{
handler: func() api.ProtocolHandler { handler: func() api.ProtocolHandler {
handler := mock.NewMockProtocolHandler(gomock.NewController(t)) handler := api_mock.NewMockProtocolHandler(gomock.NewController(t))
handler.EXPECT(). handler.EXPECT().
Start(nil). Start(nil).
MaxTimes(1). MaxTimes(1).
@ -138,7 +137,7 @@ func Test_endpoint_Start(t *testing.T) {
name: "Expect error if mocked handler returns one", name: "Expect error if mocked handler returns one",
fields: fields{ fields: fields{
handler: func() api.ProtocolHandler { handler: func() api.ProtocolHandler {
handler := mock.NewMockProtocolHandler(gomock.NewController(t)) handler := api_mock.NewMockProtocolHandler(gomock.NewController(t))
handler.EXPECT(). handler.EXPECT().
Start(nil). Start(nil).
MaxTimes(1). MaxTimes(1).
@ -153,7 +152,7 @@ func Test_endpoint_Start(t *testing.T) {
fields: fields{ fields: fields{
config: demoHandlerConfig, config: demoHandlerConfig,
handler: func() api.ProtocolHandler { handler: func() api.ProtocolHandler {
handler := mock.NewMockProtocolHandler(gomock.NewController(t)) handler := api_mock.NewMockProtocolHandler(gomock.NewController(t))
handler.EXPECT(). handler.EXPECT().
Start(demoHandlerConfig). Start(demoHandlerConfig).
MaxTimes(1). MaxTimes(1).

View file

@ -1,4 +1,4 @@
//go:generate mockgen -source=loading.go -destination=./../../internal/mock/handler_registry_mock.go -package=mock //go:generate mockgen -source=loading.go -destination=./../../internal/mock/plugins/handler_registry_mock.go -package=plugins_mock
package plugins package plugins
import ( import (

View file

@ -15,15 +15,15 @@ x-response-rules: &httpResponseRules
- pattern: ".*" - pattern: ".*"
response: ./assets/fakeFiles/default.html response: ./assets/fakeFiles/default.html
x-tls-options: &tlsOptions tls:
ecdsaCurve: P256 ecdsaCurve: P256
validity: validity:
ca: ca:
notBeforeRelative: 17520h notBeforeRelative: 17520h
notAfterRelative: 17520h notAfterRelative: 17520h
domain: server:
notBeforeRelative: 168h NotBeforeRelative: 168h
notAfterRelative: 168h NotAfterRelative: 168h
rootCaCert: rootCaCert:
publicKey: ./ca.pem publicKey: ./ca.pem
privateKey: ./ca.key privateKey: ./ca.key
@ -43,8 +43,9 @@ endpoints:
listenAddress: 0.0.0.0 listenAddress: 0.0.0.0
port: 3128 port: 3128
options: options:
fallback: notfound target:
<<: *httpResponseRules ipAddress: 127.0.0.1
port: 80
httpsDowngrade: httpsDowngrade:
handler: tls_interceptor handler: tls_interceptor
listenAddress: 0.0.0.0 listenAddress: 0.0.0.0
@ -52,7 +53,6 @@ endpoints:
- 443 - 443
- 8443 - 8443
options: options:
<<: *tlsOptions
target: target:
ipAddress: 127.0.0.1 ipAddress: 127.0.0.1
port: 80 port: 80
@ -75,7 +75,6 @@ endpoints:
listenAddress: 0.0.0.0 listenAddress: 0.0.0.0
port: 853 port: 853
options: options:
<<: *tlsOptions
target: target:
ipAddress: 127.0.0.1 ipAddress: 127.0.0.1
port: 53 port: 53

View file

@ -1,14 +1,7 @@
package config package api
import "github.com/spf13/viper" import "github.com/spf13/viper"
const (
pluginConfigKey = "handler"
listenAddressConfigKey = "listenAddress"
portConfigKey = "port"
portsConfigKey = "ports"
)
type HandlerConfig interface { type HandlerConfig interface {
HandlerName() string HandlerName() string
ListenAddress() string ListenAddress() string

View file

@ -1,4 +1,4 @@
package config package api
import ( import (
"github.com/spf13/viper" "github.com/spf13/viper"

View file

@ -1,8 +1,7 @@
//go:generate mockgen -source=protocol_handler.go -destination=./../../internal/mock/protocol_handler_mock.go -package=mock //go:generate mockgen -source=protocol_handler.go -destination=./../../internal/mock/api/protocol_handler_mock.go -package=api_mock
package api package api
import ( import (
"github.com/baez90/inetmock/internal/config"
"go.uber.org/zap" "go.uber.org/zap"
) )
@ -11,6 +10,6 @@ type PluginInstanceFactory func() ProtocolHandler
type LoggingFactory func() (*zap.Logger, error) type LoggingFactory func() (*zap.Logger, error)
type ProtocolHandler interface { type ProtocolHandler interface {
Start(config config.HandlerConfig) error Start(config HandlerConfig) error
Shutdown() error Shutdown() error
} }

41
pkg/api/services.go Normal file
View file

@ -0,0 +1,41 @@
package api
import (
"github.com/baez90/inetmock/pkg/cert"
"github.com/baez90/inetmock/pkg/logging"
"github.com/spf13/viper"
)
var (
svcs Services
)
type Services interface {
CertStore() cert.Store
}
type services struct {
certStore cert.Store
}
func InitServices(
config *viper.Viper,
logger logging.Logger,
) error {
certStore, err := cert.NewDefaultStore(config, logger)
if err != nil {
return err
}
svcs = &services{
certStore: certStore,
}
return nil
}
func ServicesInstance() Services {
return svcs
}
func (s *services) CertStore() cert.Store {
return s.certStore
}

22
pkg/cert/addr_utils.go Normal file
View file

@ -0,0 +1,22 @@
package cert
import (
"fmt"
"strings"
)
func extractIPFromAddress(addr string) (ip string, err error) {
if idx := strings.LastIndex(addr, ":"); idx < 0 {
err = fmt.Errorf("addr %s does not match expected scheme <ip>:<port>", addr)
} else {
/* get IP part of address */
ip = addr[0:idx]
/* trim [ ] for IPv6 addresses */
if ip[0] == '[' {
ip = ip[1 : len(ip)-1]
}
}
return
}

View file

@ -1,4 +1,4 @@
package main package cert
import "testing" import "testing"

78
pkg/cert/cache.go Normal file
View file

@ -0,0 +1,78 @@
//go:generate mockgen -source=cache.go -destination=./../../internal/mock/cert/cert_cache_mock.go -package=cert_mock
package cert
import (
"crypto/tls"
"crypto/x509"
"errors"
)
type Cache interface {
Put(cert *tls.Certificate) error
Get(cn string) (*tls.Certificate, bool)
}
func NewFileSystemCache(certCachePath string, source TimeSource) Cache {
return &fileSystemCache{
certCachePath: certCachePath,
inMemCache: make(map[string]*tls.Certificate),
timeSource: source,
}
}
type fileSystemCache struct {
certCachePath string
inMemCache map[string]*tls.Certificate
timeSource TimeSource
}
func (f *fileSystemCache) Put(cert *tls.Certificate) (err error) {
if cert == nil {
err = errors.New("cert may not be nil")
return
}
var cn string
if len(cert.Certificate) > 0 {
if pubKey, err := x509.ParseCertificate(cert.Certificate[0]); err != nil {
return err
} else {
cn = pubKey.Subject.CommonName
}
f.inMemCache[cn] = cert
pemCrt := NewPEM(cert)
err = pemCrt.Write(cn, f.certCachePath)
} else {
err = errors.New("no public key present for certificate")
}
return
}
func (f *fileSystemCache) Get(cn string) (*tls.Certificate, bool) {
if crt, ok := f.inMemCache[cn]; ok {
return crt, true
}
pemCrt := NewPEM(nil)
if err := pemCrt.Read(cn, f.certCachePath); err != nil || pemCrt.Cert() == nil {
return nil, false
}
x509Cert, err := x509.ParseCertificate(pemCrt.Cert().Certificate[0])
if err == nil && !certShouldBeRenewed(f.timeSource, x509Cert) {
return pemCrt.Cert(), true
}
return nil, false
}
func certShouldBeRenewed(timeSource TimeSource, cert *x509.Certificate) bool {
lifetime := cert.NotAfter.Sub(cert.NotBefore)
// if the cert is closer to the end of the lifetime than lifetime/2 it should be renewed
if cert.NotAfter.Sub(timeSource.UTCNow()) < lifetime/4 {
return true
}
return false
}

299
pkg/cert/cache_test.go Normal file
View file

@ -0,0 +1,299 @@
package cert
import (
"crypto/tls"
"crypto/x509"
"fmt"
certmock "github.com/baez90/inetmock/internal/mock/cert"
"github.com/golang/mock/gomock"
"os"
"path"
"reflect"
"testing"
"time"
)
const (
cnLocalhost = "localhost"
caCN = "UnitTests"
serverRelativeValidity = 24 * time.Hour
caRelativeValidity = 168 * time.Hour
)
var (
serverCN = fmt.Sprintf("%s-%d", cnLocalhost, time.Now().Unix())
)
func Test_certShouldBeRenewed(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
type args struct {
timeSource TimeSource
cert *x509.Certificate
}
tests := []struct {
name string
args args
want bool
}{
{
name: "Expect cert should not be renewed right after creation",
args: args{
timeSource: func() TimeSource {
tsMock := certmock.NewMockTimeSource(ctrl)
tsMock.
EXPECT().
UTCNow().
Return(time.Now().UTC()).
Times(1)
return tsMock
}(),
cert: &x509.Certificate{
NotAfter: time.Now().UTC().Add(serverRelativeValidity),
NotBefore: time.Now().UTC().Add(-serverRelativeValidity),
},
},
want: false,
},
{
name: "Expect cert should be renewed if the remaining lifetime is less than a quarter of the total lifetime",
args: args{
timeSource: func() TimeSource {
tsMock := certmock.NewMockTimeSource(ctrl)
tsMock.
EXPECT().
UTCNow().
Return(time.Now().UTC().Add(serverRelativeValidity/2 + 1*time.Hour)).
Times(1)
return tsMock
}(),
cert: &x509.Certificate{
NotAfter: time.Now().UTC().Add(serverRelativeValidity),
NotBefore: time.Now().UTC().Add(-serverRelativeValidity),
},
},
want: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := certShouldBeRenewed(tt.args.timeSource, tt.args.cert); got != tt.want {
t.Errorf("certShouldBeRenewed() = %v, want %v", got, tt.want)
}
})
}
}
func Test_fileSystemCache_Get(t *testing.T) {
type fields struct {
certCachePath string
inMemCache map[string]*tls.Certificate
timeSource TimeSource
}
type args struct {
cn string
}
tests := []struct {
name string
fields fields
args args
wantOk bool
}{
{
name: "Get a miss when no cert is present",
fields: fields{
certCachePath: os.TempDir(),
inMemCache: make(map[string]*tls.Certificate),
timeSource: NewTimeSource(),
},
args: args{
cnLocalhost,
},
wantOk: false,
},
{
name: "Get a prepared certificate from the memory cache",
fields: func() fields {
certGen := setupCertGen()
caCrt, _ := certGen.CACert(GenerationOptions{
CommonName: caCN,
})
srvCrt, _ := certGen.ServerCert(GenerationOptions{
CommonName: cnLocalhost,
}, caCrt)
return fields{
certCachePath: os.TempDir(),
inMemCache: map[string]*tls.Certificate{
cnLocalhost: srvCrt,
},
timeSource: NewTimeSource(),
}
}(),
args: args{
cn: cnLocalhost,
},
wantOk: true,
},
{
name: "Get a prepared certificate from the file system",
fields: func() fields {
certGen := setupCertGen()
caCrt, _ := certGen.CACert(GenerationOptions{
CommonName: "INetMock",
})
srvCrt, _ := certGen.ServerCert(GenerationOptions{
CommonName: serverCN,
}, caCrt)
pem := NewPEM(srvCrt)
if err := pem.Write(serverCN, os.TempDir()); err != nil {
panic(err)
}
t.Cleanup(func() {
for _, f := range []string{
path.Join(os.TempDir(), fmt.Sprintf("%s.pem", serverCN)),
path.Join(os.TempDir(), fmt.Sprintf("%s.key", serverCN)),
} {
_ = os.Remove(f)
}
})
return fields{
certCachePath: os.TempDir(),
inMemCache: make(map[string]*tls.Certificate),
timeSource: NewTimeSource(),
}
}(),
args: args{
cn: serverCN,
},
wantOk: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
f := &fileSystemCache{
certCachePath: tt.fields.certCachePath,
inMemCache: tt.fields.inMemCache,
timeSource: tt.fields.timeSource,
}
gotCrt, gotOk := f.Get(tt.args.cn)
if gotOk && (gotCrt == nil || !reflect.DeepEqual(reflect.TypeOf(new(tls.Certificate)), reflect.TypeOf(gotCrt))) {
t.Errorf("Wanted propert certificate but got %v", gotCrt)
}
if gotOk != tt.wantOk {
t.Errorf("Get() gotOk = %v, want %v", gotOk, tt.wantOk)
}
})
}
}
func Test_fileSystemCache_Put(t *testing.T) {
type fields struct {
certCachePath string
inMemCache map[string]*tls.Certificate
timeSource TimeSource
}
type args struct {
cert *tls.Certificate
}
tests := []struct {
name string
fields fields
args args
wantErr bool
}{
{
name: "Want error if nil cert is passed to put",
fields: fields{
certCachePath: os.TempDir(),
inMemCache: make(map[string]*tls.Certificate),
timeSource: NewTimeSource(),
},
args: args{
cert: nil,
},
wantErr: true,
},
{
name: "Want error if empty cert is passed to put",
fields: fields{
certCachePath: os.TempDir(),
inMemCache: make(map[string]*tls.Certificate),
timeSource: NewTimeSource(),
},
args: args{
cert: &tls.Certificate{},
},
wantErr: true,
},
{
name: "No error if valid cert is passed",
fields: fields{
certCachePath: os.TempDir(),
inMemCache: make(map[string]*tls.Certificate),
timeSource: NewTimeSource(),
},
args: args{
cert: func() *tls.Certificate {
gen := setupCertGen()
ca, _ := gen.CACert(GenerationOptions{
CommonName: caCN,
})
srvCN := fmt.Sprintf("%s-%d", cnLocalhost, time.Now().Unix())
t.Cleanup(func() {
for _, f := range []string{
path.Join(os.TempDir(), fmt.Sprintf("%s.pem", srvCN)),
path.Join(os.TempDir(), fmt.Sprintf("%s.key", srvCN)),
} {
_ = os.Remove(f)
}
})
srvCrt, _ := gen.ServerCert(GenerationOptions{
CommonName: srvCN,
}, ca)
return srvCrt
}(),
},
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
f := &fileSystemCache{
certCachePath: tt.fields.certCachePath,
inMemCache: tt.fields.inMemCache,
timeSource: tt.fields.timeSource,
}
if err := f.Put(tt.args.cert); (err != nil) != tt.wantErr {
t.Errorf("Put() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
func setupCertGen() Generator {
return NewDefaultGenerator(Options{
Validity: ValidityByPurpose{
Server: ValidityDuration{
NotBeforeRelative: serverRelativeValidity,
NotAfterRelative: serverRelativeValidity,
},
CA: ValidityDuration{
NotBeforeRelative: caRelativeValidity,
NotAfterRelative: caRelativeValidity,
},
},
})
}

21
pkg/cert/constants.go Normal file
View file

@ -0,0 +1,21 @@
package cert
const (
CurveTypeP224 CurveType = "P224"
CurveTypeP256 CurveType = "P256"
CurveTypeP384 CurveType = "P384"
CurveTypeP521 CurveType = "P521"
CurveTypeED25519 CurveType = "ED25519"
defaultServerValidityDuration = "168h"
defaultCAValidityDuration = "17520h"
certCachePathConfigKey = "tls.certCachePath"
ecdsaCurveConfigKey = "tls.ecdsaCurve"
publicKeyConfigKey = "tls.rootCaCert.publicKey"
privateKeyPathConfigKey = "tls.rootCaCert.privateKey"
caCertValidityNotBeforeKey = "tls.validity.ca.notBeforeRelative"
caCertValidityNotAfterKey = "tls.validity.ca.notAfterRelative"
serverCertValidityNotBeforeKey = "tls.validity.server.notBeforeRelative"
serverCertValidityNotAfterKey = "tls.validity.server.notAfterRelative"
)

43
pkg/cert/defaults.go Normal file
View file

@ -0,0 +1,43 @@
package cert
import (
"github.com/baez90/inetmock/pkg/defaulting"
"reflect"
)
var (
certOptionsDefaulter defaulting.Defaulter = func(instance interface{}) {
switch o := instance.(type) {
case *GenerationOptions:
if len(o.Country) < 1 {
o.Country = []string{"US"}
}
if len(o.Locality) < 1 {
o.Locality = []string{"San Francisco"}
}
if len(o.Organization) < 1 {
o.Organization = []string{"INetMock"}
}
if len(o.StreetAddress) < 1 {
o.StreetAddress = []string{"Golden Gate Bridge"}
}
if len(o.PostalCode) < 1 {
o.PostalCode = []string{"94016"}
}
if len(o.Province) < 1 {
o.Province = []string{""}
}
}
}
)
func init() {
certOptionsType := reflect.TypeOf(GenerationOptions{})
defaulters.Register(certOptionsType, certOptionsDefaulter)
}

39
pkg/cert/defaults_test.go Normal file
View file

@ -0,0 +1,39 @@
package cert
import (
"reflect"
"testing"
)
func Test_certOptionsDefaulter(t *testing.T) {
tests := []struct {
name string
arg GenerationOptions
expected GenerationOptions
}{
{
name: "",
arg: GenerationOptions{
CommonName: "CA",
},
expected: GenerationOptions{
CommonName: "CA",
Country: []string{"US"},
Locality: []string{"San Francisco"},
Organization: []string{"INetMock"},
StreetAddress: []string{"Golden Gate Bridge"},
PostalCode: []string{"94016"},
Province: []string{""},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
certOptionsDefaulter(&tt.arg)
if !reflect.DeepEqual(tt.expected, tt.arg) {
t.Errorf("Apply defaulter expected=%v got=%v", tt.expected, tt.arg)
}
})
}
}

184
pkg/cert/generator.go Normal file
View file

@ -0,0 +1,184 @@
package cert
import (
"crypto/ecdsa"
"crypto/ed25519"
"crypto/rand"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"github.com/baez90/inetmock/pkg/defaulting"
"math/big"
"net"
)
var (
defaulters = defaulting.New()
)
type GenerationOptions struct {
CommonName string
Organization []string
OrganizationalUnit []string
IPAddresses []net.IP
DNSNames []string
Country []string
Province []string
Locality []string
StreetAddress []string
PostalCode []string
}
type Generator interface {
CACert(options GenerationOptions) (*tls.Certificate, error)
ServerCert(options GenerationOptions, ca *tls.Certificate) (*tls.Certificate, error)
}
func NewDefaultGenerator(options Options) Generator {
return NewGenerator(options, NewTimeSource(), defaultKeyProvider(options))
}
func NewGenerator(options Options, source TimeSource, provider KeyProvider) Generator {
return &generator{
options: options,
provider: provider,
timeSource: source,
}
}
type generator struct {
options Options
provider KeyProvider
timeSource TimeSource
outDir string
}
func (g *generator) privateKey() (key interface{}, err error) {
if g.provider != nil {
return g.provider()
} else {
return defaultKeyProvider(g.options)()
}
}
func (g *generator) ServerCert(options GenerationOptions, ca *tls.Certificate) (cert *tls.Certificate, err error) {
defaulters.Apply(&options)
var serialNumber *big.Int
if serialNumber, err = generateSerialNumber(); err != nil {
return
}
var privateKey interface{}
if privateKey, err = g.privateKey(); err != nil {
return
}
template := x509.Certificate{
SerialNumber: serialNumber,
Subject: pkix.Name{
CommonName: options.CommonName,
Organization: options.Organization,
OrganizationalUnit: options.OrganizationalUnit,
Country: options.Country,
Province: options.Province,
Locality: options.Locality,
StreetAddress: options.StreetAddress,
PostalCode: options.PostalCode,
},
IPAddresses: options.IPAddresses,
DNSNames: options.DNSNames,
NotBefore: g.timeSource.UTCNow().Add(-g.options.Validity.Server.NotBeforeRelative),
NotAfter: g.timeSource.UTCNow().Add(g.options.Validity.Server.NotAfterRelative),
KeyUsage: x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth},
}
var caCrt *x509.Certificate
caCrt, err = x509.ParseCertificate(ca.Certificate[0])
var derBytes []byte
if derBytes, err = x509.CreateCertificate(rand.Reader, &template, caCrt, publicKey(privateKey), ca.PrivateKey); err != nil {
return
}
var privateKeyBytes []byte
if privateKeyBytes, err = x509.MarshalPKCS8PrivateKey(privateKey); err != nil {
return
}
if cert, err = parseCert(derBytes, privateKeyBytes); err != nil {
return
}
return
}
func (g generator) CACert(options GenerationOptions) (crt *tls.Certificate, err error) {
defaulters.Apply(&options)
var privateKey interface{}
var serialNumber *big.Int
if serialNumber, err = generateSerialNumber(); err != nil {
return
}
if privateKey, err = g.privateKey(); err != nil {
return
}
template := x509.Certificate{
SerialNumber: serialNumber,
Subject: pkix.Name{
CommonName: options.CommonName,
Organization: options.Organization,
Country: options.Country,
Province: options.Province,
Locality: options.Locality,
StreetAddress: options.StreetAddress,
PostalCode: options.PostalCode,
},
IsCA: true,
NotBefore: g.timeSource.UTCNow().Add(-g.options.Validity.Server.NotBeforeRelative),
NotAfter: g.timeSource.UTCNow().Add(g.options.Validity.Server.NotAfterRelative),
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth},
BasicConstraintsValid: true,
}
var derBytes []byte
if derBytes, err = x509.CreateCertificate(rand.Reader, &template, &template, publicKey(privateKey), privateKey); err != nil {
return
}
var privateKeyBytes []byte
if privateKeyBytes, err = x509.MarshalPKCS8PrivateKey(privateKey); err != nil {
return
}
crt, err = parseCert(derBytes, privateKeyBytes)
return
}
func generateSerialNumber() (*big.Int, error) {
serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
return rand.Int(rand.Reader, serialNumberLimit)
}
func publicKey(privateKey interface{}) interface{} {
switch k := privateKey.(type) {
case *ecdsa.PrivateKey:
return &k.PublicKey
case ed25519.PrivateKey:
return k.Public().(ed25519.PublicKey)
default:
return nil
}
}
func parseCert(derBytes []byte, privateKeyBytes []byte) (*tls.Certificate, error) {
pemEncodedPublicKey := pem.EncodeToMemory(&pem.Block{Type: certificateBlockType, Bytes: derBytes})
pemEncodedPrivateKey := pem.EncodeToMemory(&pem.Block{Type: privateKeyBlockType, Bytes: privateKeyBytes})
cert, err := tls.X509KeyPair(pemEncodedPublicKey, pemEncodedPrivateKey)
return &cert, err
}

79
pkg/cert/options.go Normal file
View file

@ -0,0 +1,79 @@
package cert
import (
"fmt"
"github.com/spf13/viper"
"os"
"strings"
"time"
)
var (
requiredConfigKeys = []string{
publicKeyConfigKey,
privateKeyPathConfigKey,
}
)
type CurveType string
type File struct {
PublicKeyPath string
PrivateKeyPath string
}
type ValidityDuration struct {
NotBeforeRelative time.Duration
NotAfterRelative time.Duration
}
type ValidityByPurpose struct {
CA ValidityDuration
Server ValidityDuration
}
type Options struct {
RootCACert File
CertCachePath string
Curve CurveType
Validity ValidityByPurpose
}
func loadFromConfig(config *viper.Viper) (Options, error) {
missingKeys := make([]string, 0)
for _, requiredKey := range requiredConfigKeys {
if !config.IsSet(requiredKey) {
missingKeys = append(missingKeys, requiredKey)
}
}
if len(missingKeys) > 0 {
return Options{}, fmt.Errorf("config keys are missing: %s", strings.Join(missingKeys, ", "))
}
config.SetDefault(certCachePathConfigKey, os.TempDir())
config.SetDefault(ecdsaCurveConfigKey, string(CurveTypeED25519))
config.SetDefault(caCertValidityNotBeforeKey, defaultCAValidityDuration)
config.SetDefault(caCertValidityNotAfterKey, defaultCAValidityDuration)
config.SetDefault(serverCertValidityNotBeforeKey, defaultServerValidityDuration)
config.SetDefault(serverCertValidityNotAfterKey, defaultServerValidityDuration)
return Options{
CertCachePath: config.GetString(certCachePathConfigKey),
Curve: CurveType(config.GetString(ecdsaCurveConfigKey)),
Validity: ValidityByPurpose{
CA: ValidityDuration{
NotBeforeRelative: config.GetDuration(caCertValidityNotBeforeKey),
NotAfterRelative: config.GetDuration(caCertValidityNotAfterKey),
},
Server: ValidityDuration{
NotBeforeRelative: config.GetDuration(serverCertValidityNotBeforeKey),
NotAfterRelative: config.GetDuration(serverCertValidityNotAfterKey),
},
},
RootCACert: File{
PublicKeyPath: config.GetString(publicKeyConfigKey),
PrivateKeyPath: config.GetString(privateKeyPathConfigKey),
},
}, nil
}

137
pkg/cert/options_test.go Normal file
View file

@ -0,0 +1,137 @@
package cert
import (
"github.com/spf13/viper"
"os"
"reflect"
"strings"
"testing"
"time"
)
func readViper(cfg string) *viper.Viper {
vpr := viper.New()
vpr.SetConfigType("yaml")
if err := vpr.ReadConfig(strings.NewReader(cfg)); err != nil {
panic(err)
}
return vpr
}
func Test_loadFromConfig(t *testing.T) {
type args struct {
config *viper.Viper
}
tests := []struct {
name string
args args
want Options
wantErr bool
}{
{
name: "Parse valid TLS configuration",
wantErr: false,
args: args{
config: readViper(`
tls:
ecdsaCurve: P256
validity:
ca:
notBeforeRelative: 17520h
notAfterRelative: 17520h
server:
NotBeforeRelative: 168h
NotAfterRelative: 168h
rootCaCert:
publicKey: ./ca.pem
privateKey: ./ca.key
certCachePath: /tmp/inetmock/
`),
},
want: Options{
RootCACert: File{
PublicKeyPath: "./ca.pem",
PrivateKeyPath: "./ca.key",
},
CertCachePath: "/tmp/inetmock/",
Curve: CurveTypeP256,
Validity: ValidityByPurpose{
CA: ValidityDuration{
NotBeforeRelative: 17520 * time.Hour,
NotAfterRelative: 17520 * time.Hour,
},
Server: ValidityDuration{
NotBeforeRelative: 168 * time.Hour,
NotAfterRelative: 168 * time.Hour,
},
},
},
},
{
name: "Get an error if CA public key path is missing",
args: args{
readViper(`
tls:
rootCaCert:
privateKey: ./ca.key
`),
},
want: Options{},
wantErr: true,
},
{
name: "Get an error if CA private key path is missing",
args: args{
readViper(`
tls:
rootCaCert:
publicKey: ./ca.pem
`),
},
want: Options{},
wantErr: true,
},
{
name: "Get default options if all required fields are set",
args: args{
readViper(`
tls:
rootCaCert:
publicKey: ./ca.pem
privateKey: ./ca.key
`),
},
want: Options{
RootCACert: File{
PublicKeyPath: "./ca.pem",
PrivateKeyPath: "./ca.key",
},
CertCachePath: os.TempDir(),
Curve: CurveTypeED25519,
Validity: ValidityByPurpose{
CA: ValidityDuration{
NotBeforeRelative: 17520 * time.Hour,
NotAfterRelative: 17520 * time.Hour,
},
Server: ValidityDuration{
NotBeforeRelative: 168 * time.Hour,
NotAfterRelative: 168 * time.Hour,
},
},
},
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := loadFromConfig(tt.args.config)
if (err != nil) != tt.wantErr {
t.Errorf("loadFromConfig() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("loadFromConfig() got = %v, want %v", got, tt.want)
}
})
}
}

79
pkg/cert/pem.go Normal file
View file

@ -0,0 +1,79 @@
package cert
import (
"crypto/tls"
"crypto/x509"
"encoding/pem"
"errors"
"fmt"
"github.com/baez90/inetmock/pkg/path"
"os"
"path/filepath"
)
const (
certificateBlockType = "CERTIFICATE"
privateKeyBlockType = "PRIVATE KEY"
)
type PEM interface {
Cert() *tls.Certificate
Write(cn string, outDir string) error
Read(cn string, inDir string) error
ReadFrom(pubKeyPath, privateKeyPath string) error
}
func NewPEM(crt *tls.Certificate) PEM {
return &pemCrt{
crt: crt,
}
}
type pemCrt struct {
crt *tls.Certificate
}
func (p *pemCrt) Cert() *tls.Certificate {
return p.crt
}
func (p pemCrt) Write(cn string, outDir string) (err error) {
var certOut *os.File
if certOut, err = os.Create(filepath.Join(outDir, fmt.Sprintf("%s.pem", cn))); err != nil {
return
}
defer certOut.Close()
if err = pem.Encode(certOut, &pem.Block{Type: certificateBlockType, Bytes: p.crt.Certificate[0]}); err != nil {
return
}
var keyOut *os.File
if keyOut, err = os.OpenFile(filepath.Join(outDir, fmt.Sprintf("%s.key", cn)), os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600); err != nil {
return
}
var privKeyBytes []byte
privKeyBytes, err = x509.MarshalPKCS8PrivateKey(p.crt.PrivateKey)
err = pem.Encode(keyOut, &pem.Block{Type: privateKeyBlockType, Bytes: privKeyBytes})
return
}
func (p *pemCrt) Read(cn string, inDir string) error {
certPath := filepath.Join(inDir, fmt.Sprintf("%s.pem", cn))
keyPath := filepath.Join(inDir, fmt.Sprintf("%s.key", cn))
return p.ReadFrom(certPath, keyPath)
}
func (p *pemCrt) ReadFrom(pubKeyPath, privateKeyPath string) (err error) {
var tlsCrt tls.Certificate
if path.FileExists(pubKeyPath) && path.FileExists(privateKeyPath) {
if tlsCrt, err = tls.LoadX509KeyPair(pubKeyPath, privateKeyPath); err == nil {
p.crt = &tlsCrt
}
} else {
err = errors.New("either public or private key file do not exist")
}
return
}

153
pkg/cert/store.go Normal file
View file

@ -0,0 +1,153 @@
package cert
import (
"crypto/ecdsa"
"crypto/ed25519"
"crypto/elliptic"
"crypto/rand"
"crypto/tls"
"crypto/x509"
"github.com/baez90/inetmock/pkg/logging"
"github.com/spf13/viper"
"go.uber.org/zap"
"net"
)
const (
ipv4Loopback = "127.0.0.1"
)
var (
defaultKeyProvider = func(options Options) func() (key interface{}, err error) {
return func() (key interface{}, err error) {
return privateKeyForCurve(options)
}
}
)
type KeyProvider func() (key interface{}, err error)
type Store interface {
CACert() *tls.Certificate
GetCertificate(serverName string, ip string) (*tls.Certificate, error)
TLSConfig() *tls.Config
}
func NewDefaultStore(
config *viper.Viper,
logger logging.Logger,
) (Store, error) {
options, err := loadFromConfig(config)
if err != nil {
return nil, err
}
timeSource := NewTimeSource()
return NewStore(
options,
NewFileSystemCache(options.CertCachePath, timeSource),
NewDefaultGenerator(options),
logger,
)
}
func NewStore(
options Options,
cache Cache,
generator Generator,
logger logging.Logger,
) (Store, error) {
store := &store{
options: options,
cache: cache,
generator: generator,
logger: logger,
}
if err := store.loadCACert(); err != nil {
return nil, err
}
return store, nil
}
type store struct {
options Options
caCert *tls.Certificate
cache Cache
timeSource TimeSource
generator Generator
logger logging.Logger
}
func (s *store) TLSConfig() *tls.Config {
rootCaPool := x509.NewCertPool()
rootCaPubKey, _ := x509.ParseCertificate(s.caCert.Certificate[0])
rootCaPool.AddCert(rootCaPubKey)
return &tls.Config{
GetCertificate: func(info *tls.ClientHelloInfo) (cert *tls.Certificate, err error) {
var localIp string
if localIp, err = extractIPFromAddress(info.Conn.LocalAddr().String()); err != nil {
localIp = ipv4Loopback
}
if cert, err = s.GetCertificate(info.ServerName, localIp); err != nil {
s.logger.Error(
"error while resolving certificate",
zap.String("serverName", info.ServerName),
zap.String("localAddr", localIp),
zap.Error(err),
)
}
return
},
RootCAs: rootCaPool,
}
}
func (s *store) loadCACert() (err error) {
pemCrt := NewPEM(nil)
if err = pemCrt.ReadFrom(s.options.RootCACert.PublicKeyPath, s.options.RootCACert.PrivateKeyPath); err != nil {
return
}
s.caCert = pemCrt.Cert()
return
}
func (s *store) CACert() *tls.Certificate {
return s.caCert
}
func (s *store) GetCertificate(serverName string, ip string) (cert *tls.Certificate, err error) {
if crt, ok := s.cache.Get(serverName); ok {
return crt, nil
}
if cert, err = s.generator.ServerCert(GenerationOptions{
CommonName: serverName,
DNSNames: []string{serverName},
IPAddresses: []net.IP{net.ParseIP(ip)},
}, s.caCert); err == nil {
s.cache.Put(cert)
}
return
}
func privateKeyForCurve(options Options) (privateKey interface{}, err error) {
switch options.Curve {
case CurveTypeP224:
privateKey, err = ecdsa.GenerateKey(elliptic.P224(), rand.Reader)
case CurveTypeP256:
privateKey, err = ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
case CurveTypeP384:
privateKey, err = ecdsa.GenerateKey(elliptic.P384(), rand.Reader)
case CurveTypeP521:
privateKey, err = ecdsa.GenerateKey(elliptic.P521(), rand.Reader)
default:
_, privateKey, err = ed25519.GenerateKey(rand.Reader)
}
return
}

20
pkg/cert/time_source.go Normal file
View file

@ -0,0 +1,20 @@
//go:generate mockgen -source=time_source.go -destination=./../../internal/mock/cert/time_source_mock.go -package=cert_mock
package cert
import "time"
type TimeSource interface {
UTCNow() time.Time
}
func NewTimeSource() TimeSource {
return &defaultTimeSource{}
}
type defaultTimeSource struct {
}
func (d defaultTimeSource) UTCNow() time.Time {
return time.Now().UTC()
}

View file

@ -0,0 +1,43 @@
package defaulting
import (
"reflect"
)
type Defaulter func(instance interface{})
type Registry interface {
Register(t reflect.Type, defaulter ...Defaulter)
Apply(instance interface{})
}
func New() Registry {
return &registry{
defaulters: make(map[reflect.Type][]Defaulter),
}
}
type registry struct {
defaulters map[reflect.Type][]Defaulter
}
func (r *registry) Register(t reflect.Type, defaulter ...Defaulter) {
var given []Defaulter
if r, ok := r.defaulters[t]; ok {
given = r
}
for _, d := range defaulter {
given = append(given, d)
}
r.defaulters[t] = given
}
func (r *registry) Apply(instance interface{}) {
if defs, ok := r.defaulters[reflect.TypeOf(instance)]; ok {
for _, def := range defs {
def(instance)
}
}
}

View file

@ -0,0 +1,111 @@
package defaulting
import (
"reflect"
"testing"
)
func Test_registry_Apply(t *testing.T) {
type sample struct {
i int
}
type fields struct {
defaulters map[reflect.Type][]Defaulter
}
type args struct {
instance interface{}
}
type expect struct {
result interface{}
}
tests := []struct {
name string
fields fields
args args
expect expect
}{
{
name: "Expect setting a sample value",
fields: fields{
defaulters: map[reflect.Type][]Defaulter{
reflect.TypeOf(&sample{}): {func(instance interface{}) {
switch i := instance.(type) {
case *sample:
i.i = 42
}
}},
},
},
args: args{
instance: &sample{},
},
expect: expect{
result: &sample{
i: 42,
},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
r := &registry{
defaulters: tt.fields.defaulters,
}
r.Apply(tt.args.instance)
if !reflect.DeepEqual(tt.expect.result, tt.args.instance) {
t.Errorf("Apply() expected = %v got %v", tt.args.instance, tt.expect.result)
}
})
}
}
func Test_registry_Register(t *testing.T) {
type sample struct {
}
type fields struct {
defaulters map[reflect.Type][]Defaulter
}
type args struct {
t reflect.Type
defaulter []Defaulter
}
type expect struct {
length int
}
tests := []struct {
name string
fields fields
args args
expect expect
}{
{
name: "",
fields: fields{
defaulters: make(map[reflect.Type][]Defaulter),
},
args: args{
t: reflect.TypeOf(sample{}),
defaulter: []Defaulter{func(instance interface{}) {
}},
},
expect: expect{
length: 1,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
r := &registry{
defaulters: tt.fields.defaulters,
}
r.Register(tt.args.t, tt.args.defaulter...)
if length := len(r.defaulters); length != tt.expect.length {
t.Errorf("len(r.defaulters) expect %d got %d", tt.expect.length, length)
}
})
}
}

View file

@ -1,4 +1,4 @@
//go:generate mockgen -source=logger.go -destination=./../../internal/mock/logger_mock.go -package=mock //go:generate mockgen -source=logger.go -destination=./../../internal/mock/logging/logger_mock.go -package=logging_mock
package logging package logging
import "go.uber.org/zap" import "go.uber.org/zap"

View file

@ -6,7 +6,7 @@ require (
github.com/baez90/inetmock v0.0.1 github.com/baez90/inetmock v0.0.1
github.com/miekg/dns v1.1.29 github.com/miekg/dns v1.1.29
github.com/spf13/viper v1.6.3 github.com/spf13/viper v1.6.3
go.uber.org/zap v1.14.1 go.uber.org/zap v1.15.0
golang.org/x/crypto v0.0.0-20200406173513-056763e48d71 // indirect golang.org/x/crypto v0.0.0-20200406173513-056763e48d71 // indirect
golang.org/x/net v0.0.0-20200324143707-d3edc9973b7e // indirect golang.org/x/net v0.0.0-20200324143707-d3edc9973b7e // indirect
) )

View file

@ -149,6 +149,8 @@ go.uber.org/tools v0.0.0-20190618225709-2cfd321de3ee/go.mod h1:vJERXedbb3MVM5f9E
go.uber.org/zap v1.10.0/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q= go.uber.org/zap v1.10.0/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q=
go.uber.org/zap v1.14.1 h1:nYDKopTbvAPq/NrUVZwT15y2lpROBiLLyoRTbXOYWOo= go.uber.org/zap v1.14.1 h1:nYDKopTbvAPq/NrUVZwT15y2lpROBiLLyoRTbXOYWOo=
go.uber.org/zap v1.14.1/go.mod h1:Mb2vm2krFEG5DV0W9qcHBYFtp/Wku1cvYaqPsS/WYfc= go.uber.org/zap v1.14.1/go.mod h1:Mb2vm2krFEG5DV0W9qcHBYFtp/Wku1cvYaqPsS/WYfc=
go.uber.org/zap v1.15.0 h1:ZZCA22JRF2gQE5FoNmhmrf7jeJJ2uhqDUNRYKm8dvmM=
go.uber.org/zap v1.15.0/go.mod h1:Mb2vm2krFEG5DV0W9qcHBYFtp/Wku1cvYaqPsS/WYfc=
golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
@ -188,6 +190,8 @@ golang.org/x/sys v0.0.0-20191005200804-aed5e4c7ecf9/go.mod h1:h1NjWce9XRLGQEsW7w
golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200409092240-59c9f1ba88fa h1:mQTN3ECqfsViCNBgq+A40vdwhkGykrrQlYe3mPj6BoU= golang.org/x/sys v0.0.0-20200409092240-59c9f1ba88fa h1:mQTN3ECqfsViCNBgq+A40vdwhkGykrrQlYe3mPj6BoU=
golang.org/x/sys v0.0.0-20200409092240-59c9f1ba88fa/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200409092240-59c9f1ba88fa/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200420163511-1957bb5e6d1f h1:gWF768j/LaZugp8dyS4UwsslYCYz9XgFxvlgsn0n9H8=
golang.org/x/sys v0.0.0-20200420163511-1957bb5e6d1f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.2 h1:tW2bmiBqwgJj/UpqtC8EpXEZVYOwU0yG4iWbprSVAcs= golang.org/x/text v0.3.2 h1:tW2bmiBqwgJj/UpqtC8EpXEZVYOwU0yG4iWbprSVAcs=

View file

@ -2,7 +2,7 @@ package main
import ( import (
"fmt" "fmt"
"github.com/baez90/inetmock/internal/config" "github.com/baez90/inetmock/pkg/api"
"github.com/baez90/inetmock/pkg/logging" "github.com/baez90/inetmock/pkg/logging"
"github.com/miekg/dns" "github.com/miekg/dns"
"go.uber.org/zap" "go.uber.org/zap"
@ -13,7 +13,7 @@ type dnsHandler struct {
dnsServer []*dns.Server dnsServer []*dns.Server
} }
func (d *dnsHandler) Start(config config.HandlerConfig) (err error) { func (d *dnsHandler) Start(config api.HandlerConfig) (err error) {
options := loadFromConfig(config.Options()) options := loadFromConfig(config.Options())
addr := fmt.Sprintf("%s:%d", config.ListenAddress(), config.Port()) addr := fmt.Sprintf("%s:%d", config.ListenAddress(), config.Port())

View file

@ -5,7 +5,7 @@ go 1.14
require ( require (
github.com/baez90/inetmock v0.0.1 github.com/baez90/inetmock v0.0.1
github.com/spf13/viper v1.6.3 github.com/spf13/viper v1.6.3
go.uber.org/zap v1.14.1 go.uber.org/zap v1.15.0
) )
replace github.com/baez90/inetmock v0.0.1 => ../../ replace github.com/baez90/inetmock v0.0.1 => ../../

View file

@ -147,6 +147,8 @@ go.uber.org/tools v0.0.0-20190618225709-2cfd321de3ee/go.mod h1:vJERXedbb3MVM5f9E
go.uber.org/zap v1.10.0/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q= go.uber.org/zap v1.10.0/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q=
go.uber.org/zap v1.14.1 h1:nYDKopTbvAPq/NrUVZwT15y2lpROBiLLyoRTbXOYWOo= go.uber.org/zap v1.14.1 h1:nYDKopTbvAPq/NrUVZwT15y2lpROBiLLyoRTbXOYWOo=
go.uber.org/zap v1.14.1/go.mod h1:Mb2vm2krFEG5DV0W9qcHBYFtp/Wku1cvYaqPsS/WYfc= go.uber.org/zap v1.14.1/go.mod h1:Mb2vm2krFEG5DV0W9qcHBYFtp/Wku1cvYaqPsS/WYfc=
go.uber.org/zap v1.15.0 h1:ZZCA22JRF2gQE5FoNmhmrf7jeJJ2uhqDUNRYKm8dvmM=
go.uber.org/zap v1.15.0/go.mod h1:Mb2vm2krFEG5DV0W9qcHBYFtp/Wku1cvYaqPsS/WYfc=
golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
@ -178,6 +180,8 @@ golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7w
golang.org/x/sys v0.0.0-20191005200804-aed5e4c7ecf9/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191005200804-aed5e4c7ecf9/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200409092240-59c9f1ba88fa h1:mQTN3ECqfsViCNBgq+A40vdwhkGykrrQlYe3mPj6BoU= golang.org/x/sys v0.0.0-20200409092240-59c9f1ba88fa h1:mQTN3ECqfsViCNBgq+A40vdwhkGykrrQlYe3mPj6BoU=
golang.org/x/sys v0.0.0-20200409092240-59c9f1ba88fa/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200409092240-59c9f1ba88fa/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200420163511-1957bb5e6d1f h1:gWF768j/LaZugp8dyS4UwsslYCYz9XgFxvlgsn0n9H8=
golang.org/x/sys v0.0.0-20200420163511-1957bb5e6d1f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.2 h1:tW2bmiBqwgJj/UpqtC8EpXEZVYOwU0yG4iWbprSVAcs= golang.org/x/text v0.3.2 h1:tW2bmiBqwgJj/UpqtC8EpXEZVYOwU0yG4iWbprSVAcs=

View file

@ -2,7 +2,7 @@ package main
import ( import (
"fmt" "fmt"
"github.com/baez90/inetmock/internal/config" "github.com/baez90/inetmock/pkg/api"
"github.com/baez90/inetmock/pkg/logging" "github.com/baez90/inetmock/pkg/logging"
"go.uber.org/zap" "go.uber.org/zap"
"net/http" "net/http"
@ -18,7 +18,7 @@ type httpHandler struct {
server *http.Server server *http.Server
} }
func (p *httpHandler) Start(config config.HandlerConfig) (err error) { func (p *httpHandler) Start(config api.HandlerConfig) (err error) {
options := loadFromConfig(config.Options()) options := loadFromConfig(config.Options())
addr := fmt.Sprintf("%s:%d", config.ListenAddress(), config.Port()) addr := fmt.Sprintf("%s:%d", config.ListenAddress(), config.Port())
p.server = &http.Server{Addr: addr, Handler: p.router} p.server = &http.Server{Addr: addr, Handler: p.router}

View file

@ -1,53 +0,0 @@
package main
import (
"gopkg.in/elazarl/goproxy.v1"
"net/http"
)
const (
passthroughStrategyName = "passthrough"
notFoundStrategyName = "notfound"
)
var (
fallbackStrategies map[string]ProxyFallbackStrategy
)
func init() {
fallbackStrategies = map[string]ProxyFallbackStrategy{
passthroughStrategyName: &passThroughFallbackStrategy{},
notFoundStrategyName: &notFoundFallbackStrategy{},
}
}
func StrategyForName(name string) ProxyFallbackStrategy {
if strategy, ok := fallbackStrategies[name]; ok {
return strategy
}
return fallbackStrategies[notFoundStrategyName]
}
type ProxyFallbackStrategy interface {
Apply(request *http.Request) (*http.Response, error)
}
type passThroughFallbackStrategy struct {
}
func (p passThroughFallbackStrategy) Apply(request *http.Request) (*http.Response, error) {
return nil, nil
}
type notFoundFallbackStrategy struct {
}
func (n notFoundFallbackStrategy) Apply(request *http.Request) (response *http.Response, err error) {
response = goproxy.NewResponse(
request,
goproxy.ContentTypeText,
http.StatusNotFound,
"The requested resource was not found",
)
return
}

View file

@ -1,46 +0,0 @@
package main
import (
"reflect"
"testing"
)
func TestStrategyForName(t *testing.T) {
type args struct {
name string
}
tests := []struct {
name string
args args
want reflect.Type
}{
{
name: "Test get notfound strategy",
want: reflect.TypeOf(&notFoundFallbackStrategy{}),
args: args{
name: "notfound",
},
},
{
name: "Test get pass through strategy",
want: reflect.TypeOf(&passThroughFallbackStrategy{}),
args: args{
name: "passthrough",
},
},
{
name: "Test get fallback strategy notfound because key is not known",
want: reflect.TypeOf(&notFoundFallbackStrategy{}),
args: args{
name: "asdf12234",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := StrategyForName(tt.args.name); reflect.TypeOf(got) != tt.want {
t.Errorf("StrategyForName() = %v, want %v", got, tt.want)
}
})
}
}

View file

@ -4,9 +4,8 @@ go 1.14
require ( require (
github.com/baez90/inetmock v0.0.1 github.com/baez90/inetmock v0.0.1
github.com/elazarl/goproxy v0.0.0-20200315184450-1f3cb6622dad
github.com/spf13/viper v1.6.3 github.com/spf13/viper v1.6.3
go.uber.org/zap v1.14.1 go.uber.org/zap v1.15.0
gopkg.in/elazarl/goproxy.v1 v1.0.0-20180725130230-947c36da3153 gopkg.in/elazarl/goproxy.v1 v1.0.0-20180725130230-947c36da3153
) )

View file

@ -152,6 +152,8 @@ go.uber.org/tools v0.0.0-20190618225709-2cfd321de3ee/go.mod h1:vJERXedbb3MVM5f9E
go.uber.org/zap v1.10.0/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q= go.uber.org/zap v1.10.0/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q=
go.uber.org/zap v1.14.1 h1:nYDKopTbvAPq/NrUVZwT15y2lpROBiLLyoRTbXOYWOo= go.uber.org/zap v1.14.1 h1:nYDKopTbvAPq/NrUVZwT15y2lpROBiLLyoRTbXOYWOo=
go.uber.org/zap v1.14.1/go.mod h1:Mb2vm2krFEG5DV0W9qcHBYFtp/Wku1cvYaqPsS/WYfc= go.uber.org/zap v1.14.1/go.mod h1:Mb2vm2krFEG5DV0W9qcHBYFtp/Wku1cvYaqPsS/WYfc=
go.uber.org/zap v1.15.0 h1:ZZCA22JRF2gQE5FoNmhmrf7jeJJ2uhqDUNRYKm8dvmM=
go.uber.org/zap v1.15.0/go.mod h1:Mb2vm2krFEG5DV0W9qcHBYFtp/Wku1cvYaqPsS/WYfc=
golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
@ -183,6 +185,8 @@ golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7w
golang.org/x/sys v0.0.0-20191005200804-aed5e4c7ecf9/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191005200804-aed5e4c7ecf9/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200409092240-59c9f1ba88fa h1:mQTN3ECqfsViCNBgq+A40vdwhkGykrrQlYe3mPj6BoU= golang.org/x/sys v0.0.0-20200409092240-59c9f1ba88fa h1:mQTN3ECqfsViCNBgq+A40vdwhkGykrrQlYe3mPj6BoU=
golang.org/x/sys v0.0.0-20200409092240-59c9f1ba88fa/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200409092240-59c9f1ba88fa/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200420163511-1957bb5e6d1f h1:gWF768j/LaZugp8dyS4UwsslYCYz9XgFxvlgsn0n9H8=
golang.org/x/sys v0.0.0-20200420163511-1957bb5e6d1f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.2 h1:tW2bmiBqwgJj/UpqtC8EpXEZVYOwU0yG4iWbprSVAcs= golang.org/x/text v0.3.2 h1:tW2bmiBqwgJj/UpqtC8EpXEZVYOwU0yG4iWbprSVAcs=

View file

@ -2,7 +2,7 @@ package main
import ( import (
"fmt" "fmt"
"github.com/baez90/inetmock/internal/config" "github.com/baez90/inetmock/pkg/api"
"github.com/baez90/inetmock/pkg/logging" "github.com/baez90/inetmock/pkg/logging"
"go.uber.org/zap" "go.uber.org/zap"
"gopkg.in/elazarl/goproxy.v1" "gopkg.in/elazarl/goproxy.v1"
@ -19,7 +19,7 @@ type httpProxy struct {
server *http.Server server *http.Server
} }
func (h *httpProxy) Start(config config.HandlerConfig) (err error) { func (h *httpProxy) Start(config api.HandlerConfig) (err error) {
options := loadFromConfig(config.Options()) options := loadFromConfig(config.Options())
addr := fmt.Sprintf("%s:%d", config.ListenAddress(), config.Port()) addr := fmt.Sprintf("%s:%d", config.ListenAddress(), config.Port())
h.server = &http.Server{Addr: addr, Handler: h.proxy} h.server = &http.Server{Addr: addr, Handler: h.proxy}
@ -27,11 +27,20 @@ func (h *httpProxy) Start(config config.HandlerConfig) (err error) {
zap.String("address", addr), zap.String("address", addr),
) )
tlsConfig := api.ServicesInstance().CertStore().TLSConfig()
proxyHandler := &proxyHttpHandler{ proxyHandler := &proxyHttpHandler{
options: options, options: options,
logger: h.logger, logger: h.logger,
} }
proxyHttpsHandler := &proxyHttpsHandler{
tlsConfig: tlsConfig,
logger: h.logger,
}
h.proxy.OnRequest().Do(proxyHandler) h.proxy.OnRequest().Do(proxyHandler)
h.proxy.OnRequest().HandleConnect(proxyHttpsHandler)
go h.startProxy() go h.startProxy()
return return
} }

View file

@ -1,52 +1,42 @@
package main package main
import ( import (
"fmt"
"github.com/spf13/viper" "github.com/spf13/viper"
"regexp"
) )
const ( const (
rulesConfigKey = "rules" targetSchemeConfigKey = "target.scheme"
patternConfigKey = "pattern" targetIpAddressConfigKey = "target.ipAddress"
responseConfigKey = "response" targetPortConfigKey = "target.port"
fallbackStrategyConfigKey = "fallback"
) )
type targetRule struct { type redirectionTarget struct {
pattern *regexp.Regexp scheme string
response string ipAddress string
port uint16
} }
func (tr targetRule) Pattern() *regexp.Regexp { func (rt redirectionTarget) host() string {
return tr.pattern return fmt.Sprintf("%s:%d", rt.ipAddress, rt.port)
}
func (tr targetRule) Response() string {
return tr.response
} }
type httpProxyOptions struct { type httpProxyOptions struct {
Rules []targetRule redirectionTarget redirectionTarget
FallbackStrategy ProxyFallbackStrategy
} }
func loadFromConfig(config *viper.Viper) (options httpProxyOptions) { func loadFromConfig(config *viper.Viper) (options httpProxyOptions) {
options.FallbackStrategy = StrategyForName(config.GetString(fallbackStrategyConfigKey))
anonRules := config.Get(rulesConfigKey).([]interface{}) config.SetDefault(targetSchemeConfigKey, "http")
config.SetDefault(targetIpAddressConfigKey, "127.0.0.1")
config.SetDefault(targetPortConfigKey, "80")
for _, i := range anonRules { options = httpProxyOptions{
innerData := i.(map[interface{}]interface{}) redirectionTarget{
scheme: config.GetString(targetSchemeConfigKey),
if rulePattern, err := regexp.Compile(innerData[patternConfigKey].(string)); err == nil { ipAddress: config.GetString(targetIpAddressConfigKey),
options.Rules = append(options.Rules, targetRule{ port: uint16(config.GetInt(targetPortConfigKey)),
pattern: rulePattern, },
response: innerData[responseConfigKey].(string),
})
} else {
panic(err)
}
} }
return return
} }

View file

@ -1,126 +1 @@
package main package main
import (
"bytes"
"github.com/spf13/viper"
"reflect"
"regexp"
"testing"
)
func Test_loadFromConfig(t *testing.T) {
type args struct {
config string
}
tests := []struct {
name string
args args
wantOptions httpProxyOptions
}{
{
name: "Parse proper configuration with notfound strategy",
args: args{
config: `
fallback: notfound,
rules:
- pattern: ".*"
response: ./assets/fakeFiles/default.html
`,
},
wantOptions: httpProxyOptions{
FallbackStrategy: StrategyForName(notFoundStrategyName),
Rules: []targetRule{
{
response: "./assets/fakeFiles/default.html",
pattern: regexp.MustCompile(".*"),
},
},
},
},
{
name: "Parse proper configuration with pass through strategy",
args: args{
config: `
fallback: passthrough
rules:
- pattern: ".*"
response: ./assets/fakeFiles/default.html
`,
},
wantOptions: httpProxyOptions{
FallbackStrategy: StrategyForName(passthroughStrategyName),
Rules: []targetRule{
{
response: "./assets/fakeFiles/default.html",
pattern: regexp.MustCompile(".*"),
},
},
},
},
{
name: "Parse proper configuration and preserve order of rules",
args: args{
config: `
fallback: notfound
rules:
- pattern: ".*\\.(?i)txt"
response: ./assets/fakeFiles/default.txt
- pattern: ".*"
response: ./assets/fakeFiles/default.html
`,
},
wantOptions: httpProxyOptions{
FallbackStrategy: StrategyForName(notFoundStrategyName),
Rules: []targetRule{
{
response: "./assets/fakeFiles/default.txt",
pattern: regexp.MustCompile(".*\\.(?i)txt"),
},
{
response: "./assets/fakeFiles/default.html",
pattern: regexp.MustCompile(".*"),
},
},
},
},
{
name: "Parse configuration with non existing fallback strategy key - falling back to 'notfound'",
args: args{
config: `
fallback: doesNotExist
rules: []
`,
},
wantOptions: httpProxyOptions{
FallbackStrategy: StrategyForName(notFoundStrategyName),
Rules: nil,
},
},
{
name: "Parse configuration without any fallback key",
args: args{
config: `
f4llb4ck: doesNotExist
rules: []
`,
},
wantOptions: httpProxyOptions{
FallbackStrategy: StrategyForName(notFoundStrategyName),
Rules: nil,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
config := viper.New()
config.SetConfigType("yaml")
if err := config.ReadConfig(bytes.NewBufferString(tt.args.config)); err != nil {
t.Errorf("failed to read config %v", err)
return
}
if gotOptions := loadFromConfig(config); !reflect.DeepEqual(gotOptions, tt.wantOptions) {
t.Errorf("loadFromConfig() = %v, want %v", gotOptions, tt.wantOptions)
}
})
}
}

View file

@ -1,14 +1,13 @@
package main package main
import ( import (
"context"
"crypto/tls"
"github.com/baez90/inetmock/pkg/logging" "github.com/baez90/inetmock/pkg/logging"
"go.uber.org/zap" "go.uber.org/zap"
"gopkg.in/elazarl/goproxy.v1" "gopkg.in/elazarl/goproxy.v1"
"io"
"mime"
"net/http" "net/http"
"os" "net/url"
"path/filepath"
) )
type proxyHttpHandler struct { type proxyHttpHandler struct {
@ -16,25 +15,27 @@ type proxyHttpHandler struct {
logger logging.Logger logger logging.Logger
} }
/* type proxyHttpsHandler struct {
TODO implement HTTPS proxy like in TLS interceptor tlsConfig *tls.Config
func (p *proxyHttpHandler) HandleConnect(req string, ctx *goproxy.ProxyCtx) (*goproxy.ConnectAction, string) { logger logging.Logger
}
func (p *proxyHttpsHandler) HandleConnect(req string, _ *goproxy.ProxyCtx) (*goproxy.ConnectAction, string) {
p.logger.Info(
"Intercepting HTTPS proxy request",
zap.String("request", req),
)
return &goproxy.ConnectAction{ return &goproxy.ConnectAction{
Action: goproxy.OkConnect, Action: goproxy.ConnectMitm,
TLSConfig: func(host string, ctx *goproxy.ProxyCtx) (*tls.Config, error) {
return p.tlsConfig, nil
},
}, "" }, ""
}*/ }
func (p *proxyHttpHandler) Handle(req *http.Request, _ *goproxy.ProxyCtx) (retReq *http.Request, resp *http.Response) {
func (p *proxyHttpHandler) Handle(req *http.Request, ctx *goproxy.ProxyCtx) (retReq *http.Request, resp *http.Response) {
retReq = req retReq = req
resp = &http.Response{
Request: req,
TransferEncoding: req.TransferEncoding,
Header: make(http.Header),
StatusCode: http.StatusOK,
}
p.logger.Info( p.logger.Info(
"Handling request", "Handling request",
zap.String("source", req.RemoteAddr), zap.String("source", req.RemoteAddr),
@ -45,56 +46,48 @@ func (p *proxyHttpHandler) Handle(req *http.Request, _ *goproxy.ProxyCtx) (retRe
zap.Reflect("headers", req.Header), zap.Reflect("headers", req.Header),
) )
for _, rule := range p.options.Rules { var err error
if rule.pattern.MatchString(req.URL.Path) { if resp, err = ctx.RoundTrip(p.redirectHTTPRequest(req)); err != nil {
if file, err := os.Open(rule.response); err != nil {
p.logger.Error(
"failed to open response target file",
zap.String("resonse", rule.response),
zap.Error(err),
)
continue
} else {
resp.Body = file
if stat, err := file.Stat(); err == nil {
resp.ContentLength = stat.Size()
}
if contentType, err := GetContentType(rule, file); err == nil {
resp.Header["Content-Type"] = []string{contentType}
}
p.logger.Info("returning fake response from rules")
return req, resp
}
}
}
if resp, err := p.options.FallbackStrategy.Apply(req); err != nil {
p.logger.Error( p.logger.Error(
"failed to apply fallback strategy", "error while doing roundtrip",
zap.Error(err), zap.Error(err),
) )
} else { return req, nil
p.logger.Info("returning fake response from fallback strategy")
return req, resp
} }
p.logger.Info("falling back to proxying request through") return
return req, nil }
}
func (p proxyHttpHandler) redirectHTTPRequest(originalRequest *http.Request) (redirectReq *http.Request) {
func GetContentType(rule targetRule, file *os.File) (contentType string, err error) { redirectReq = &http.Request{
if contentType = mime.TypeByExtension(filepath.Ext(rule.response)); contentType != "" { Method: originalRequest.Method,
return URL: &url.URL{
} Host: p.options.redirectionTarget.host(),
Scheme: p.options.redirectionTarget.scheme,
var buf [512]byte Path: originalRequest.URL.Path,
ForceQuery: originalRequest.URL.ForceQuery,
n, _ := io.ReadFull(file, buf[:]) Fragment: originalRequest.URL.Fragment,
Opaque: originalRequest.URL.Opaque,
contentType = http.DetectContentType(buf[:n]) RawPath: originalRequest.URL.RawPath,
_, err = file.Seek(0, io.SeekStart) RawQuery: originalRequest.URL.RawQuery,
User: originalRequest.URL.User,
},
Proto: originalRequest.Proto,
ProtoMajor: originalRequest.ProtoMajor,
ProtoMinor: originalRequest.ProtoMinor,
Header: originalRequest.Header,
Body: originalRequest.Body,
GetBody: originalRequest.GetBody,
ContentLength: originalRequest.ContentLength,
TransferEncoding: originalRequest.TransferEncoding,
Close: false,
Host: originalRequest.Host,
Form: originalRequest.Form,
PostForm: originalRequest.PostForm,
MultipartForm: originalRequest.MultipartForm,
Trailer: originalRequest.Trailer,
}
redirectReq = redirectReq.WithContext(context.Background())
return return
} }

View file

@ -1,20 +0,0 @@
package main
import (
"fmt"
"regexp"
"strings"
)
var (
ipExtractionRegex = regexp.MustCompile(`(.+):\d{1,5}$`)
)
func extractIPFromAddress(addr string) (string, error) {
matches := ipExtractionRegex.FindAllStringSubmatch(addr, -1)
if len(matches) > 0 && len(matches[0]) >= 1 {
return strings.Trim(matches[0][1], "[]"), nil
} else {
return "", fmt.Errorf("failed to extract IP address from addr %s", addr)
}
}

View file

@ -1,204 +0,0 @@
package main
import (
"crypto/rand"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"fmt"
"github.com/baez90/inetmock/pkg/logging"
"github.com/baez90/inetmock/pkg/path"
"go.uber.org/zap"
"math/big"
"net"
"path/filepath"
)
type keyProvider func() (key interface{}, err error)
type certStore struct {
options *tlsOptions
keyProvider keyProvider
caCert *x509.Certificate
caPrivateKey interface{}
certCache map[string]*tls.Certificate
timeSourceInstance timeSource
logger logging.Logger
}
func (cs *certStore) timeSource() timeSource {
if cs.timeSourceInstance == nil {
cs.timeSourceInstance = createTimeSource()
}
return cs.timeSourceInstance
}
func (cs *certStore) initCaCert() (err error) {
if path.FileExists(cs.options.rootCaCert.publicKeyPath) && path.FileExists(cs.options.rootCaCert.privateKeyPath) {
var tlsCert tls.Certificate
if tlsCert, err = tls.LoadX509KeyPair(cs.options.rootCaCert.publicKeyPath, cs.options.rootCaCert.privateKeyPath); err != nil {
return
} else if len(tlsCert.Certificate) > 0 {
cs.caPrivateKey = tlsCert.PrivateKey
cs.caCert, err = x509.ParseCertificate(tlsCert.Certificate[0])
}
} else {
cs.caCert, cs.caPrivateKey, err = cs.generateCaCert()
}
return
}
func (cs *certStore) initProvider() {
if cs.keyProvider == nil {
cs.keyProvider = func() (key interface{}, err error) {
return privateKeyForCurve(cs.options)
}
}
}
func (cs *certStore) generateCaCert() (pubKey *x509.Certificate, privateKey interface{}, err error) {
cs.initProvider()
timeSource := cs.timeSource()
var serialNumber *big.Int
if serialNumber, err = generateSerialNumber(); err != nil {
return
}
if privateKey, err = cs.keyProvider(); err != nil {
return
}
template := x509.Certificate{
SerialNumber: serialNumber,
Subject: pkix.Name{
Organization: []string{"INetMock"},
Country: []string{"US"},
Province: []string{""},
Locality: []string{"San Francisco"},
StreetAddress: []string{"Golden Gate Bridge"},
PostalCode: []string{"94016"},
},
IsCA: true,
NotBefore: timeSource.UTCNow().Add(-cs.options.validity.ca.notBeforeRelative),
NotAfter: timeSource.UTCNow().Add(cs.options.validity.ca.notAfterRelative),
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth},
BasicConstraintsValid: true,
}
var derBytes []byte
if derBytes, err = x509.CreateCertificate(rand.Reader, &template, &template, publicKey(privateKey), privateKey); err != nil {
return
}
pubKey, err = x509.ParseCertificate(derBytes)
if err = writePublicKey(cs.options.rootCaCert.publicKeyPath, derBytes); err != nil {
return
}
var privateKeyBytes []byte
if privateKeyBytes, err = x509.MarshalPKCS8PrivateKey(privateKey); err != nil {
return
}
if err = writePrivateKey(cs.options.rootCaCert.privateKeyPath, privateKeyBytes); err != nil {
return
}
return
}
func (cs *certStore) getCertificate(serverName string, ip string) (cert *tls.Certificate, err error) {
if cert, ok := cs.certCache[serverName]; ok {
return cert, nil
}
certPath := filepath.Join(cs.options.certCachePath, fmt.Sprintf("%s.pem", serverName))
keyPath := filepath.Join(cs.options.certCachePath, fmt.Sprintf("%s.key", serverName))
if path.FileExists(certPath) && path.FileExists(keyPath) {
if tlsCert, loadErr := tls.LoadX509KeyPair(certPath, keyPath); loadErr == nil {
cs.certCache[serverName] = &tlsCert
x509Cert, err := x509.ParseCertificate(tlsCert.Certificate[0])
if err == nil && !certShouldBeRenewed(cs.timeSource(), x509Cert) {
return &tlsCert, nil
}
}
}
if cert, err = cs.generateDomainCert(serverName, ip); err == nil {
cs.certCache[serverName] = cert
}
return
}
func (cs *certStore) generateDomainCert(
serverName string,
localIp string,
) (cert *tls.Certificate, err error) {
cs.initProvider()
if cs.caCert == nil {
if err = cs.initCaCert(); err != nil {
return
}
}
var serialNumber *big.Int
if serialNumber, err = generateSerialNumber(); err != nil {
return
}
var privateKey interface{}
if privateKey, err = cs.keyProvider(); err != nil {
return
}
notBefore := cs.timeSource().UTCNow().Add(-cs.options.validity.domain.notBeforeRelative)
notAfter := cs.timeSource().UTCNow().Add(cs.options.validity.domain.notAfterRelative)
cs.logger.Info(
"generate domain certificate",
zap.Time("notBefore", notBefore),
zap.Time("notAfter", notAfter),
)
template := x509.Certificate{
SerialNumber: serialNumber,
Subject: pkix.Name{
Organization: []string{"INetMock"},
},
IPAddresses: []net.IP{net.ParseIP(localIp)},
DNSNames: []string{serverName},
NotBefore: notBefore,
NotAfter: notAfter,
KeyUsage: x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth},
}
var derBytes []byte
if derBytes, err = x509.CreateCertificate(rand.Reader, &template, cs.caCert, publicKey(privateKey), cs.caPrivateKey); err != nil {
return
}
if err = writePublicKey(filepath.Join(cs.options.certCachePath, fmt.Sprintf("%s.pem", serverName)), derBytes); err != nil {
return
}
var privateKeyBytes []byte
if privateKeyBytes, err = x509.MarshalPKCS8PrivateKey(privateKey); err != nil {
return
}
if cert, err = parseCert(derBytes, privateKeyBytes); err != nil {
return
}
if err = writePrivateKey(filepath.Join(cs.options.certCachePath, fmt.Sprintf("%s.key", serverName)), privateKeyBytes); err != nil {
return
}
return
}

View file

@ -1,213 +0,0 @@
package main
import (
"crypto/tls"
"crypto/x509"
"github.com/baez90/inetmock/pkg/logging"
"go.uber.org/zap"
"io/ioutil"
"os"
"path/filepath"
"reflect"
"testing"
"time"
)
func Test_generateCaCert(t *testing.T) {
tmpDir, err := ioutil.TempDir(os.TempDir(), "*-inetmock")
if err != nil {
t.Errorf("failed to create temp dir %v", err)
return
}
options := &tlsOptions{
ecdsaCurve: "P256",
rootCaCert: cert{
publicKeyPath: filepath.Join(tmpDir, "localhost.pem"),
privateKeyPath: filepath.Join(tmpDir, "localhost.key"),
},
validity: validity{
ca: certValidity{
notBeforeRelative: time.Hour * 24 * 30,
notAfterRelative: time.Hour * 24 * 30,
},
},
}
certStore := certStore{
options: options,
}
defer func() {
_ = os.Remove(tmpDir)
}()
_, _, err = certStore.generateCaCert()
if err != nil {
t.Errorf("failed to generate CA cert %v", err)
}
if _, err = os.Stat(options.rootCaCert.publicKeyPath); err != nil {
t.Errorf("cert file was not created")
}
if _, err = os.Stat(options.rootCaCert.privateKeyPath); err != nil {
t.Errorf("cert file was not created")
}
}
func Test_generateDomainCert(t *testing.T) {
tmpDir, err := ioutil.TempDir(os.TempDir(), "*-inetmock")
if err != nil {
t.Errorf("failed to create temp dir %v", err)
return
}
defer func() {
_ = os.Remove(tmpDir)
}()
caTlsCert, _ := loadPEMCert(testCaCrt, testCaKey)
caCert, _ := x509.ParseCertificate(caTlsCert.Certificate[0])
options := &tlsOptions{
ecdsaCurve: "P256",
certCachePath: tmpDir,
validity: validity{
domain: certValidity{
notAfterRelative: time.Hour * 24 * 30,
notBeforeRelative: time.Hour * 24 * 30,
},
ca: certValidity{
notAfterRelative: time.Hour * 24 * 30,
notBeforeRelative: time.Hour * 24 * 30,
},
},
}
zapLogger, _ := zap.NewDevelopment()
logger := logging.NewLogger(zapLogger)
certStore := certStore{
options: options,
caCert: caCert,
logger: logger,
caPrivateKey: caTlsCert.PrivateKey,
}
type args struct {
domain string
ip string
}
tests := []struct {
name string
args args
wantErr bool
}{
{
name: "Test create google.com cert",
args: args{
domain: "google.com",
ip: "127.0.0.1",
},
wantErr: false,
},
{
name: "Test create golem.de cert",
args: args{
domain: "golem.de",
ip: "127.0.0.1",
},
wantErr: false,
},
{
name: "Test create golem.de cert with any IP address",
args: args{
domain: "golem.de",
ip: "10.10.0.10",
},
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if domainCert, err := certStore.generateDomainCert(
tt.args.domain,
tt.args.ip,
); (err != nil) != tt.wantErr || reflect.DeepEqual(domainCert, tls.Certificate{}) {
t.Errorf("generateDomainCert() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
func Test_certStore_initCaCert(t *testing.T) {
tmpDir, err := ioutil.TempDir(os.TempDir(), "*-inetmock")
if err != nil {
t.Errorf("failed to create temp dir %v", err)
return
}
defer func() {
_ = os.Remove(tmpDir)
}()
caCertPath := filepath.Join(tmpDir, "cacert.pem")
caKeyPath := filepath.Join(tmpDir, "cacert.key")
if err := ioutil.WriteFile(caCertPath, testCaCrt, 0600); err != nil {
t.Errorf("failed to write cacert.pem %v", err)
return
}
if err := ioutil.WriteFile(caKeyPath, testCaKey, 0600); err != nil {
t.Errorf("failed to write cacert.key %v", err)
return
}
type fields struct {
options *tlsOptions
caCert *x509.Certificate
}
tests := []struct {
name string
fields fields
wantErr bool
}{
{
name: "Init CA cert from file",
wantErr: false,
fields: fields{
options: &tlsOptions{
rootCaCert: cert{
publicKeyPath: caCertPath,
privateKeyPath: caKeyPath,
},
},
},
},
{
name: "Init CA with new cert",
wantErr: false,
fields: fields{
options: &tlsOptions{
rootCaCert: cert{
publicKeyPath: filepath.Join(tmpDir, "nonexistent.pem"),
privateKeyPath: filepath.Join(tmpDir, "nonexistent.key"),
},
},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cs := &certStore{
options: tt.fields.options,
caCert: tt.fields.caCert,
}
if err := cs.initCaCert(); (err != nil) != tt.wantErr || cs.caCert == nil {
t.Errorf("initCaCert() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}

View file

@ -1,106 +0,0 @@
package main
import (
"crypto/ecdsa"
"crypto/ed25519"
"crypto/elliptic"
"crypto/rand"
"crypto/tls"
"crypto/x509"
"encoding/pem"
"math/big"
"os"
)
type curveType string
const (
certificateBlockType = "CERTIFICATE"
privateKeyBlockType = "PRIVATE KEY"
curveTypeP224 curveType = "P224"
curveTypeP256 curveType = "P256"
curveTypeP384 curveType = "P384"
curveTypeP521 curveType = "P521"
curveTypeED25519 curveType = "ED25519"
)
func loadPEMCert(certPEMBytes []byte, keyPEMBytes []byte) (*tls.Certificate, error) {
cert, err := tls.X509KeyPair(certPEMBytes, keyPEMBytes)
return &cert, err
}
func parseCert(derBytes []byte, privateKeyBytes []byte) (*tls.Certificate, error) {
pemEncodedPublicKey := pem.EncodeToMemory(&pem.Block{Type: certificateBlockType, Bytes: derBytes})
pemEncodedPrivateKey := pem.EncodeToMemory(&pem.Block{Type: privateKeyBlockType, Bytes: privateKeyBytes})
cert, err := tls.X509KeyPair(pemEncodedPublicKey, pemEncodedPrivateKey)
return &cert, err
}
func writePublicKey(crtOutPath string, derBytes []byte) (err error) {
var certOut *os.File
if certOut, err = os.Create(crtOutPath); err != nil {
return
}
if err = pem.Encode(certOut, &pem.Block{Type: certificateBlockType, Bytes: derBytes}); err != nil {
return
}
err = certOut.Close()
return
}
func writePrivateKey(keyOutPath string, privateKeyBytes []byte) (err error) {
var keyOut *os.File
if keyOut, err = os.OpenFile(keyOutPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600); err != nil {
return
}
if err = pem.Encode(keyOut, &pem.Block{Type: privateKeyBlockType, Bytes: privateKeyBytes}); err != nil {
return
}
err = keyOut.Close()
return
}
func certShouldBeRenewed(timeSource timeSource, cert *x509.Certificate) bool {
lifetime := cert.NotAfter.Sub(cert.NotBefore)
// if the cert is closer to the end of the lifetime than lifetime/2 it should be renewed
if cert.NotAfter.Sub(timeSource.UTCNow()) < lifetime/4 {
return true
}
return false
}
func generateSerialNumber() (*big.Int, error) {
serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
return rand.Int(rand.Reader, serialNumberLimit)
}
func privateKeyForCurve(options *tlsOptions) (privateKey interface{}, err error) {
switch options.ecdsaCurve {
case "P224":
privateKey, err = ecdsa.GenerateKey(elliptic.P224(), rand.Reader)
case "P256":
privateKey, err = ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
case "P384":
privateKey, err = ecdsa.GenerateKey(elliptic.P384(), rand.Reader)
case "P521":
privateKey, err = ecdsa.GenerateKey(elliptic.P521(), rand.Reader)
default:
_, privateKey, err = ed25519.GenerateKey(rand.Reader)
}
return
}
func publicKey(privateKey interface{}) interface{} {
switch k := privateKey.(type) {
case *ecdsa.PrivateKey:
return &k.PublicKey
case ed25519.PrivateKey:
return k.Public().(ed25519.PublicKey)
default:
return nil
}
}

View file

@ -1,74 +0,0 @@
package main
import (
"crypto/x509"
"testing"
"time"
)
type testTimeSource struct {
nowValue time.Time
}
func (t testTimeSource) UTCNow() time.Time {
return t.nowValue
}
func Test_certShouldBeRenewed(t *testing.T) {
type args struct {
timeSource timeSource
cert *x509.Certificate
}
tests := []struct {
name string
args args
want bool
}{
{
name: "Detect cert is expired",
want: true,
args: args{
cert: &x509.Certificate{
NotAfter: time.Now().UTC().Add(1 * time.Hour),
NotBefore: time.Now().UTC().Add(-1 * time.Hour),
},
timeSource: testTimeSource{
nowValue: time.Now().UTC().Add(2 * time.Hour),
},
},
},
{
name: "Detect cert should be renewed",
want: true,
args: args{
cert: &x509.Certificate{
NotAfter: time.Now().UTC().Add(1 * time.Hour),
NotBefore: time.Now().UTC().Add(-1 * time.Hour),
},
timeSource: testTimeSource{
nowValue: time.Now().UTC().Add(45 * time.Minute),
},
},
},
{
name: "Detect cert shouldn't be renewed",
want: false,
args: args{
cert: &x509.Certificate{
NotAfter: time.Now().UTC().Add(1 * time.Hour),
NotBefore: time.Now().UTC().Add(-1 * time.Hour),
},
timeSource: testTimeSource{
nowValue: time.Now().UTC().Add(25 * time.Minute),
},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := certShouldBeRenewed(tt.args.timeSource, tt.args.cert); got != tt.want {
t.Errorf("certShouldBeRenewed() = %v, want %v", got, tt.want)
}
})
}
}

View file

@ -1,113 +0,0 @@
package main
import (
"github.com/baez90/inetmock/pkg/logging"
"github.com/spf13/cobra"
"go.uber.org/zap"
"time"
)
const (
generateCACertOutPath = "cert-out"
generateCAKeyOutPath = "key-out"
generateCACurveName = "curve"
generateCANotBeforeRelative = "not-before"
generateCANotAfterRelative = "not-after"
)
func generateCACmd(logger logging.Logger) *cobra.Command {
cmd := &cobra.Command{
Use: "generate-ca",
Short: "Generate a new CA certificate and corresponding key",
Long: ``,
Run: runGenerateCA(logger),
}
cmd.Flags().String(generateCACertOutPath, "", "Path where CA cert file should be stored")
cmd.Flags().String(generateCAKeyOutPath, "", "Path where CA key file should be stored")
cmd.Flags().String(generateCACurveName, "", "Name of the curve to use, if empty ED25519 is used, other valid values are [P224, P256,P384,P521]")
cmd.Flags().Duration(generateCANotBeforeRelative, 17520*time.Hour, "Relative time value since when in the past the CA certificate should be valid. The value has a time unit, the greatest time unit is h for hour.")
cmd.Flags().Duration(generateCANotAfterRelative, 17520*time.Hour, "Relative time value until when in the future the CA certificate should be valid. The value has a time unit, the greatest time unit is h for hour.")
return cmd
}
func getDurationFlag(cmd *cobra.Command, flagName string, logger logging.Logger) (val time.Duration, err error) {
if val, err = cmd.Flags().GetDuration(flagName); err != nil {
logger.Error(
"failed to parse parse flag",
zap.String("flag", flagName),
zap.Error(err),
)
}
return
}
func getStringFlag(cmd *cobra.Command, flagName string, logger logging.Logger) (val string, err error) {
if val, err = cmd.Flags().GetString(flagName); err != nil {
logger.Error(
"failed to parse parse flag",
zap.String("flag", flagName),
zap.Error(err),
)
}
return
}
func runGenerateCA(logger logging.Logger) func(cmd *cobra.Command, args []string) {
return func(cmd *cobra.Command, args []string) {
var certOutPath, keyOutPath, curveName string
var notBefore, notAfter time.Duration
var err error
if certOutPath, err = getStringFlag(cmd, generateCACertOutPath, logger); err != nil {
return
}
if keyOutPath, err = getStringFlag(cmd, generateCAKeyOutPath, logger); err != nil {
return
}
if curveName, err = getStringFlag(cmd, generateCACurveName, logger); err != nil {
return
}
if notBefore, err = getDurationFlag(cmd, generateCANotBeforeRelative, logger); err != nil {
return
}
if notAfter, err = getDurationFlag(cmd, generateCANotAfterRelative, logger); err != nil {
return
}
logger := logger.With(
zap.String(generateCACurveName, curveName),
zap.String(generateCACertOutPath, certOutPath),
zap.String(generateCAKeyOutPath, keyOutPath),
)
certStore := certStore{
options: &tlsOptions{
ecdsaCurve: curveType(curveName),
validity: validity{
ca: certValidity{
notAfterRelative: notAfter,
notBeforeRelative: notBefore,
},
},
rootCaCert: cert{
publicKeyPath: certOutPath,
privateKeyPath: keyOutPath,
},
},
}
if _, _, err := certStore.generateCaCert(); err != nil {
logger.Error(
"failed to generate CA cert",
zap.Error(err),
)
} else {
logger.Info("Successfully generated CA cert")
}
}
}

View file

@ -4,9 +4,8 @@ go 1.14
require ( require (
github.com/baez90/inetmock v0.0.1 github.com/baez90/inetmock v0.0.1
github.com/spf13/cobra v1.0.0
github.com/spf13/viper v1.6.3 github.com/spf13/viper v1.6.3
go.uber.org/zap v1.14.1 go.uber.org/zap v1.15.0
) )
replace github.com/baez90/inetmock v0.0.1 => ../../ replace github.com/baez90/inetmock v0.0.1 => ../../

View file

@ -147,6 +147,8 @@ go.uber.org/tools v0.0.0-20190618225709-2cfd321de3ee/go.mod h1:vJERXedbb3MVM5f9E
go.uber.org/zap v1.10.0/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q= go.uber.org/zap v1.10.0/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q=
go.uber.org/zap v1.14.1 h1:nYDKopTbvAPq/NrUVZwT15y2lpROBiLLyoRTbXOYWOo= go.uber.org/zap v1.14.1 h1:nYDKopTbvAPq/NrUVZwT15y2lpROBiLLyoRTbXOYWOo=
go.uber.org/zap v1.14.1/go.mod h1:Mb2vm2krFEG5DV0W9qcHBYFtp/Wku1cvYaqPsS/WYfc= go.uber.org/zap v1.14.1/go.mod h1:Mb2vm2krFEG5DV0W9qcHBYFtp/Wku1cvYaqPsS/WYfc=
go.uber.org/zap v1.15.0 h1:ZZCA22JRF2gQE5FoNmhmrf7jeJJ2uhqDUNRYKm8dvmM=
go.uber.org/zap v1.15.0/go.mod h1:Mb2vm2krFEG5DV0W9qcHBYFtp/Wku1cvYaqPsS/WYfc=
golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
@ -178,6 +180,8 @@ golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7w
golang.org/x/sys v0.0.0-20191005200804-aed5e4c7ecf9/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191005200804-aed5e4c7ecf9/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200409092240-59c9f1ba88fa h1:mQTN3ECqfsViCNBgq+A40vdwhkGykrrQlYe3mPj6BoU= golang.org/x/sys v0.0.0-20200409092240-59c9f1ba88fa h1:mQTN3ECqfsViCNBgq+A40vdwhkGykrrQlYe3mPj6BoU=
golang.org/x/sys v0.0.0-20200409092240-59c9f1ba88fa/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200409092240-59c9f1ba88fa/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200420163511-1957bb5e6d1f h1:gWF768j/LaZugp8dyS4UwsslYCYz9XgFxvlgsn0n9H8=
golang.org/x/sys v0.0.0-20200420163511-1957bb5e6d1f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.2 h1:tW2bmiBqwgJj/UpqtC8EpXEZVYOwU0yG4iWbprSVAcs= golang.org/x/text v0.3.2 h1:tW2bmiBqwgJj/UpqtC8EpXEZVYOwU0yG4iWbprSVAcs=

View file

@ -19,5 +19,5 @@ func init() {
logger: logger, logger: logger,
currentConnectionsCount: &sync.WaitGroup{}, currentConnectionsCount: &sync.WaitGroup{},
} }
}, generateCACmd(logger)) })
} }

View file

@ -2,9 +2,9 @@ package main
import ( import (
"crypto/tls" "crypto/tls"
"crypto/x509"
"fmt" "fmt"
"github.com/baez90/inetmock/internal/config" "github.com/baez90/inetmock/pkg/api"
"github.com/baez90/inetmock/pkg/cert"
"github.com/baez90/inetmock/pkg/logging" "github.com/baez90/inetmock/pkg/logging"
"go.uber.org/zap" "go.uber.org/zap"
"net" "net"
@ -17,16 +17,16 @@ const (
) )
type tlsInterceptor struct { type tlsInterceptor struct {
options tlsOptions
logger logging.Logger logger logging.Logger
listener net.Listener listener net.Listener
certStore *certStore certStore cert.Store
options *tlsOptions
shutdownRequested bool shutdownRequested bool
currentConnectionsCount *sync.WaitGroup currentConnectionsCount *sync.WaitGroup
currentConnections []*proxyConn currentConnections []*proxyConn
} }
func (t *tlsInterceptor) Start(config config.HandlerConfig) (err error) { func (t *tlsInterceptor) Start(config api.HandlerConfig) (err error) {
t.options = loadFromConfig(config.Options()) t.options = loadFromConfig(config.Options())
addr := fmt.Sprintf("%s:%d", config.ListenAddress(), config.Port()) addr := fmt.Sprintf("%s:%d", config.ListenAddress(), config.Port())
@ -35,33 +35,7 @@ func (t *tlsInterceptor) Start(config config.HandlerConfig) (err error) {
zap.String("target", t.options.redirectionTarget.address()), zap.String("target", t.options.redirectionTarget.address()),
) )
t.certStore = &certStore{ if t.listener, err = tls.Listen("tcp", addr, api.ServicesInstance().CertStore().TLSConfig()); err != nil {
options: t.options,
certCache: make(map[string]*tls.Certificate),
logger: t.logger,
}
if err = t.certStore.initCaCert(); err != nil {
t.logger.Error(
"failed to initialize CA cert",
zap.Error(err),
)
err = fmt.Errorf(
"failed to initialize CA cert: %w",
err,
)
return
}
rootCaPool := x509.NewCertPool()
rootCaPool.AddCert(t.certStore.caCert)
tlsConfig := &tls.Config{
GetCertificate: t.getCert,
RootCAs: rootCaPool,
}
if t.listener, err = tls.Listen("tcp", addr, tlsConfig); err != nil {
t.logger.Fatal( t.logger.Fatal(
"failed to create tls listener", "failed to create tls listener",
zap.Error(err), zap.Error(err),
@ -122,23 +96,6 @@ func (t *tlsInterceptor) startListener() {
} }
} }
func (t *tlsInterceptor) getCert(info *tls.ClientHelloInfo) (cert *tls.Certificate, err error) {
var localIp string
if localIp, err = extractIPFromAddress(info.Conn.LocalAddr().String()); err != nil {
localIp = "127.0.0.1"
}
if cert, err = t.certStore.getCertificate(info.ServerName, localIp); err != nil {
t.logger.Error(
"error while resolving certificate",
zap.String("serverName", info.ServerName),
zap.String("localAddr", localIp),
zap.Error(err),
)
}
return
}
func (t *tlsInterceptor) proxyConn(conn net.Conn) { func (t *tlsInterceptor) proxyConn(conn net.Conn) {
defer conn.Close() defer conn.Close()

View file

@ -3,37 +3,13 @@ package main
import ( import (
"fmt" "fmt"
"github.com/spf13/viper" "github.com/spf13/viper"
"time"
) )
const ( const (
certCachePathConfigKey = "certCachePath" targetIpAddressConfigKey = "target.ipAddress"
ecdsaCurveConfigKey = "ecdsaCurve" targetPortConfigKey = "target.port"
targetIpAddressConfigKey = "target.ipAddress"
targetPortConfigKey = "target.port"
publicKeyConfigKey = "rootCaCert.publicKey"
privateKeyPathConfigKey = "rootCaCert.privateKey"
caCertValidityNotBeforeKey = "validity.ca.notBeforeRelative"
caCertValidityNotAfterKey = "validity.ca.notAfterRelative"
domainCertValidityNotBeforeKey = "validity.domain.notBeforeRelative"
domainCertValidityNotAfterKey = "validity.domain.notAfterRelative"
) )
type cert struct {
publicKeyPath string
privateKeyPath string
}
type certValidity struct {
notBeforeRelative time.Duration
notAfterRelative time.Duration
}
type validity struct {
ca certValidity
domain certValidity
}
type redirectionTarget struct { type redirectionTarget struct {
ipAddress string ipAddress string
port uint16 port uint16
@ -44,35 +20,14 @@ func (rt redirectionTarget) address() string {
} }
type tlsOptions struct { type tlsOptions struct {
rootCaCert cert
certCachePath string
redirectionTarget redirectionTarget redirectionTarget redirectionTarget
ecdsaCurve curveType
validity validity
} }
func loadFromConfig(config *viper.Viper) *tlsOptions { func loadFromConfig(config *viper.Viper) tlsOptions {
return tlsOptions{
return &tlsOptions{
certCachePath: config.GetString(certCachePathConfigKey),
ecdsaCurve: curveType(config.GetString(ecdsaCurveConfigKey)),
redirectionTarget: redirectionTarget{ redirectionTarget: redirectionTarget{
ipAddress: config.GetString(targetIpAddressConfigKey), ipAddress: config.GetString(targetIpAddressConfigKey),
port: uint16(config.GetInt(targetPortConfigKey)), port: uint16(config.GetInt(targetPortConfigKey)),
}, },
validity: validity{
ca: certValidity{
notBeforeRelative: config.GetDuration(caCertValidityNotBeforeKey),
notAfterRelative: config.GetDuration(caCertValidityNotAfterKey),
},
domain: certValidity{
notBeforeRelative: config.GetDuration(domainCertValidityNotBeforeKey),
notAfterRelative: config.GetDuration(domainCertValidityNotAfterKey),
},
},
rootCaCert: cert{
publicKeyPath: config.GetString(publicKeyConfigKey),
privateKeyPath: config.GetString(privateKeyPathConfigKey),
},
} }
} }

View file

@ -1,55 +0,0 @@
package main
import (
"fmt"
"io/ioutil"
"os"
"path/filepath"
"time"
)
var (
testCaCrt []byte
testCaKey []byte
)
func init() {
tmpDir, err := ioutil.TempDir(os.TempDir(), "*-inetmock")
if err != nil {
panic(fmt.Sprintf("failed to create temp dir %v", err))
}
options := &tlsOptions{
ecdsaCurve: "P256",
rootCaCert: cert{
publicKeyPath: filepath.Join(tmpDir, "localhost.pem"),
privateKeyPath: filepath.Join(tmpDir, "localhost.key"),
},
validity: validity{
ca: certValidity{
notBeforeRelative: time.Hour * 24 * 30,
notAfterRelative: time.Hour * 24 * 30,
},
},
}
certStore := certStore{
options: options,
keyProvider: func() (key interface{}, err error) {
return privateKeyForCurve(options)
},
}
defer func() {
_ = os.Remove(tmpDir)
}()
_, _, err = certStore.generateCaCert()
testCaCrt, _ = ioutil.ReadFile(options.rootCaCert.publicKeyPath)
testCaKey, _ = ioutil.ReadFile(options.rootCaCert.privateKeyPath)
if err != nil {
panic(fmt.Sprintf("failed to generate CA cert %v", err))
}
}

View file

@ -1,18 +0,0 @@
package main
import "time"
type timeSource interface {
UTCNow() time.Time
}
func createTimeSource() timeSource {
return &defaultTimeSource{}
}
type defaultTimeSource struct {
}
func (d defaultTimeSource) UTCNow() time.Time {
return time.Now().UTC()
}