Move TLS/cert handling to main app
- apply changes in proxy plugin and TLS interceptor - add HTTPS proxy support - move ca-generation command to main app - minor refactoring to improve API stability - move mocks to extra packages to avoid cycling imports - fix bug in multi-port configuration - change HTTP proxy to redirect to HTTP mock instead of maintaining custom rules
This commit is contained in:
parent
43d3c62e01
commit
7c2a41ad25
65 changed files with 1722 additions and 1335 deletions
|
@ -3,16 +3,22 @@
|
|||
###############
|
||||
|
||||
.git/
|
||||
plugins/
|
||||
.idea/
|
||||
.vscode/
|
||||
deploy/
|
||||
dist/
|
||||
doc/
|
||||
|
||||
#########
|
||||
# Files #
|
||||
#########
|
||||
|
||||
*.so
|
||||
*.out
|
||||
main
|
||||
inetmock
|
||||
README.md
|
||||
LICENSE
|
||||
.dockerignore
|
||||
.gitignore
|
||||
Dockerfile
|
||||
Dockerfile
|
||||
|
|
|
@ -3,7 +3,7 @@
|
|||
before:
|
||||
hooks:
|
||||
# You may remove this if you don't use go modules.
|
||||
- make handlers
|
||||
- make plugins
|
||||
builds:
|
||||
- id: "default"
|
||||
ldflags:
|
||||
|
|
|
@ -38,8 +38,8 @@ WORKDIR /app
|
|||
|
||||
COPY --from=build /etc/passwd /etc/group /etc/
|
||||
COPY --from=build --chown=$USER /work/inetmock ./
|
||||
COPY --from=build --chown=$USER /work/plugins/ ./plugins/
|
||||
COPY --from=build --chown=$USER /work/*.so ./plugins/
|
||||
|
||||
USER $USER:$USER
|
||||
|
||||
ENTRYPOINT ["/app/inetmock"]
|
||||
ENTRYPOINT ["/app/inetmock"]
|
||||
|
|
23
Makefile
23
Makefile
|
@ -1,9 +1,8 @@
|
|||
VERSst/pluginsION = $(shell git describe --dirty --tags --always)
|
||||
VERSION = $(shell git describe --dirty --tags --always)
|
||||
DIR = $(dir $(realpath $(firstword $(MAKEFILE_LIST))))
|
||||
BUILD_PATH = $(DIR)/main.go
|
||||
PKGS = $(shell go list ./...)
|
||||
TEST_PKGS = $(shell find . -type f -name "*_test.go" -not -path "./plugins/*" -printf '%h\n' | sort -u)
|
||||
GO_GEN_FILES = $(shell grep -rnwl --include="*.go" "go:generate" $(DIR))
|
||||
TEST_PKGS = $(shell find . -type f -name "*_test.go" -not -path "./plugins/*" -not -path "*/mock/*" -printf '%h\n' | sort -u)
|
||||
GOARGS = GOOS=linux GOARCH=amd64
|
||||
GO_BUILD_ARGS = -ldflags="-w -s"
|
||||
GO_CONTAINER_BUILD_ARGS = -ldflags="-w -s" -a -installsuffix cgo
|
||||
|
@ -12,8 +11,10 @@ BINARY_NAME = inetmock
|
|||
PLUGINS = $(wildcard $(DIR)plugins/*/.)
|
||||
DEBUG_PORT = 2345
|
||||
DEBUG_ARGS?= --development-logs=true
|
||||
CONTAINER_BUILDER ?= podman
|
||||
DOCKER_IMAGE ?= inetmock
|
||||
|
||||
.PHONY: clean all format deps update-deps compile debug generate snapshot-release test cli-cover-report html-cover-report plugins $(PLUGINS) $(.PHONY: clean all format deps update-deps compile debug generate snapshot-release test cli-cover-report html-cover-report plugins $(PLUGINS) $(GO_GEN_FILES)
|
||||
.PHONY: clean all format deps update-deps compile debug generate snapshot-release test cli-cover-report html-cover-report plugins $(PLUGINS) $(GO_GEN_FILES)
|
||||
all: clean format compile test plugins
|
||||
|
||||
clean:
|
||||
|
@ -45,26 +46,26 @@ endif
|
|||
|
||||
debug: export INETMOCK_PLUGINS_DIRECTORY = $(DIR)
|
||||
debug:
|
||||
dlv exec $(DIR)$(BINARY_NAME) \
|
||||
dlv debug $(DIR) \
|
||||
--headless \
|
||||
--listen=:2345 \
|
||||
--api-version=2 \
|
||||
--accept-multiclient \
|
||||
-- $(DEBUG_ARGS)
|
||||
|
||||
generate:
|
||||
@for go_gen_target in $(GO_GEN_FILES); do \
|
||||
go generate $$go_gen_target; \
|
||||
done
|
||||
@go generate ./...
|
||||
|
||||
snapshot-release:
|
||||
@goreleaser release --snapshot --skip-publish --rm-dist
|
||||
|
||||
container:
|
||||
@$(CONTAINER_BUILDER) build -t $(DOCKER_IMAGE):latest -f $(DIR)Dockerfile $(DIR)
|
||||
|
||||
test:
|
||||
@go test -coverprofile=./cov-raw.out -v $(TEST_PKGS)
|
||||
@cat ./cov-raw.out | grep -v "generated" > ./cov.out
|
||||
|
||||
cli-cover-report:
|
||||
cli-cover-report: test
|
||||
@go tool cover -func=cov.out
|
||||
|
||||
html-cover-report: test
|
||||
|
@ -72,4 +73,4 @@ html-cover-report: test
|
|||
|
||||
plugins: $(PLUGINS)
|
||||
$(PLUGINS):
|
||||
$(MAKE) -C $@
|
||||
$(MAKE) -C $@
|
||||
|
|
4
go.mod
4
go.mod
|
@ -13,8 +13,8 @@ require (
|
|||
github.com/spf13/jwalterweatherman v1.1.0 // indirect
|
||||
github.com/spf13/pflag v1.0.5
|
||||
github.com/spf13/viper v1.6.3
|
||||
go.uber.org/zap v1.14.1
|
||||
golang.org/x/sys v0.0.0-20200409092240-59c9f1ba88fa // indirect
|
||||
go.uber.org/zap v1.15.0
|
||||
golang.org/x/sys v0.0.0-20200420163511-1957bb5e6d1f // indirect
|
||||
golang.org/x/text v0.3.2 // indirect
|
||||
golang.org/x/tools v0.0.0-20191216052735-49a3e744a425 // indirect
|
||||
gopkg.in/ini.v1 v1.55.0 // indirect
|
||||
|
|
8
go.sum
8
go.sum
|
@ -147,8 +147,8 @@ go.uber.org/multierr v1.5.0/go.mod h1:FeouvMocqHpRaaGuG9EjoKcStLC43Zu/fmqdUMPcKY
|
|||
go.uber.org/tools v0.0.0-20190618225709-2cfd321de3ee h1:0mgffUl7nfd+FpvXMVz4IDEaUSmT1ysygQC7qYo7sG4=
|
||||
go.uber.org/tools v0.0.0-20190618225709-2cfd321de3ee/go.mod h1:vJERXedbb3MVM5f9Ejo0C68/HhF8uaILCdgjnY+goOA=
|
||||
go.uber.org/zap v1.10.0/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q=
|
||||
go.uber.org/zap v1.14.1 h1:nYDKopTbvAPq/NrUVZwT15y2lpROBiLLyoRTbXOYWOo=
|
||||
go.uber.org/zap v1.14.1/go.mod h1:Mb2vm2krFEG5DV0W9qcHBYFtp/Wku1cvYaqPsS/WYfc=
|
||||
go.uber.org/zap v1.15.0 h1:ZZCA22JRF2gQE5FoNmhmrf7jeJJ2uhqDUNRYKm8dvmM=
|
||||
go.uber.org/zap v1.15.0/go.mod h1:Mb2vm2krFEG5DV0W9qcHBYFtp/Wku1cvYaqPsS/WYfc=
|
||||
golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
|
||||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||
golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
||||
|
@ -178,8 +178,8 @@ golang.org/x/sys v0.0.0-20181116152217-5ac8a444bdc5/go.mod h1:STP8DvDyc/dI5b8T5h
|
|||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20191005200804-aed5e4c7ecf9/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20200409092240-59c9f1ba88fa h1:mQTN3ECqfsViCNBgq+A40vdwhkGykrrQlYe3mPj6BoU=
|
||||
golang.org/x/sys v0.0.0-20200409092240-59c9f1ba88fa/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20200420163511-1957bb5e6d1f h1:gWF768j/LaZugp8dyS4UwsslYCYz9XgFxvlgsn0n9H8=
|
||||
golang.org/x/sys v0.0.0-20200420163511-1957bb5e6d1f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/text v0.3.2 h1:tW2bmiBqwgJj/UpqtC8EpXEZVYOwU0yG4iWbprSVAcs=
|
||||
|
|
144
internal/cmd/ca.go
Normal file
144
internal/cmd/ca.go
Normal file
|
@ -0,0 +1,144 @@
|
|||
package cmd
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"github.com/baez90/inetmock/pkg/cert"
|
||||
"github.com/baez90/inetmock/pkg/logging"
|
||||
"github.com/spf13/cobra"
|
||||
"go.uber.org/zap"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
generateCACommonName = "cn"
|
||||
generateCaOrganizationName = "o"
|
||||
generateCaOrganizationalUnitName = "ou"
|
||||
generateCaCountryName = "c"
|
||||
generateCaLocalityName = "l"
|
||||
generateCaStateName = "st"
|
||||
generateCaStreetAddressName = "street-address"
|
||||
generateCaPostalCodeName = "postal-code"
|
||||
generateCACertOutPath = "out-dir"
|
||||
generateCACurveName = "curve"
|
||||
generateCANotBeforeRelative = "not-before"
|
||||
generateCANotAfterRelative = "not-after"
|
||||
)
|
||||
|
||||
var (
|
||||
generateCaCmd *cobra.Command
|
||||
caCertOptions cert.GenerationOptions
|
||||
)
|
||||
|
||||
func init() {
|
||||
generateCaCmd = &cobra.Command{
|
||||
Use: "generate-ca",
|
||||
Short: "Generate a new CA certificate and corresponding key",
|
||||
Long: ``,
|
||||
Run: runGenerateCA,
|
||||
}
|
||||
|
||||
generateCaCmd.Flags().StringVar(&caCertOptions.CommonName, generateCACommonName, "INetMock", "Certificate Common Name that will also be used as file name during generation.")
|
||||
generateCaCmd.Flags().StringSliceVar(&caCertOptions.Organization, generateCaOrganizationName, nil, "Organization information to append to certificate")
|
||||
generateCaCmd.Flags().StringSliceVar(&caCertOptions.OrganizationalUnit, generateCaOrganizationalUnitName, nil, "Organizational unit information to append to certificate")
|
||||
generateCaCmd.Flags().StringSliceVar(&caCertOptions.Country, generateCaCountryName, nil, "Country information to append to certificate")
|
||||
generateCaCmd.Flags().StringSliceVar(&caCertOptions.Province, generateCaStateName, nil, "State information to append to certificate")
|
||||
generateCaCmd.Flags().StringSliceVar(&caCertOptions.Locality, generateCaLocalityName, nil, "Locality information to append to certificate")
|
||||
generateCaCmd.Flags().StringSliceVar(&caCertOptions.StreetAddress, generateCaStreetAddressName, nil, "Street address information to append to certificate")
|
||||
generateCaCmd.Flags().StringSliceVar(&caCertOptions.PostalCode, generateCaPostalCodeName, nil, "Postal code information to append to certificate")
|
||||
generateCaCmd.Flags().String(generateCACertOutPath, "", "Path where CA files should be stored")
|
||||
generateCaCmd.Flags().String(generateCACurveName, "", "Name of the curve to use, if empty ED25519 is used, other valid values are [P224, P256,P384,P521]")
|
||||
generateCaCmd.Flags().Duration(generateCANotBeforeRelative, 17520*time.Hour, "Relative time value since when in the past the CA certificate should be valid. The value has a time unit, the greatest time unit is h for hour.")
|
||||
generateCaCmd.Flags().Duration(generateCANotAfterRelative, 17520*time.Hour, "Relative time value until when in the future the CA certificate should be valid. The value has a time unit, the greatest time unit is h for hour.")
|
||||
}
|
||||
|
||||
func runGenerateCA(cmd *cobra.Command, args []string) {
|
||||
var certOutPath, curveName string
|
||||
var notBefore, notAfter time.Duration
|
||||
var err error
|
||||
|
||||
if certOutPath, err = getStringFlag(generateCaCmd, generateCACertOutPath, logger); err != nil {
|
||||
return
|
||||
}
|
||||
if curveName, err = getStringFlag(generateCaCmd, generateCACurveName, logger); err != nil {
|
||||
return
|
||||
}
|
||||
if notBefore, err = getDurationFlag(generateCaCmd, generateCANotBeforeRelative, logger); err != nil {
|
||||
return
|
||||
}
|
||||
if notAfter, err = getDurationFlag(generateCaCmd, generateCANotAfterRelative, logger); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
logger, _ := logging.CreateLogger()
|
||||
|
||||
logger = logger.With(
|
||||
zap.String(generateCACurveName, curveName),
|
||||
zap.String(generateCACertOutPath, certOutPath),
|
||||
)
|
||||
|
||||
generator := cert.NewDefaultGenerator(cert.Options{
|
||||
CertCachePath: certOutPath,
|
||||
Curve: cert.CurveType(curveName),
|
||||
Validity: cert.ValidityByPurpose{
|
||||
CA: cert.ValidityDuration{
|
||||
NotAfterRelative: notAfter,
|
||||
NotBeforeRelative: notBefore,
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
var caCrt *tls.Certificate
|
||||
if caCrt, err = generator.CACert(caCertOptions); err != nil {
|
||||
logger.Error(
|
||||
"failed to generate CA certificate",
|
||||
zap.Error(err),
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
if len(caCrt.Certificate) < 1 {
|
||||
logger.Error("no public key given for generated CA certificate")
|
||||
return
|
||||
}
|
||||
|
||||
var pubKey *x509.Certificate
|
||||
if pubKey, err = x509.ParseCertificate(caCrt.Certificate[0]); err != nil {
|
||||
logger.Error(
|
||||
"failed to parse public key from generated CA",
|
||||
zap.Error(err),
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
pemCrt := cert.NewPEM(caCrt)
|
||||
if err = pemCrt.Write(pubKey.Subject.CommonName, certOutPath); err != nil {
|
||||
logger.Error(
|
||||
"failed to write Ca files",
|
||||
zap.Error(err),
|
||||
)
|
||||
}
|
||||
logger.Info("completed certificate generation")
|
||||
}
|
||||
|
||||
func getDurationFlag(cmd *cobra.Command, flagName string, logger logging.Logger) (val time.Duration, err error) {
|
||||
if val, err = cmd.Flags().GetDuration(flagName); err != nil {
|
||||
logger.Error(
|
||||
"failed to parse parse flag",
|
||||
zap.String("flag", flagName),
|
||||
zap.Error(err),
|
||||
)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func getStringFlag(cmd *cobra.Command, flagName string, logger logging.Logger) (val string, err error) {
|
||||
if val, err = cmd.Flags().GetString(flagName); err != nil {
|
||||
logger.Error(
|
||||
"failed to parse parse flag",
|
||||
zap.String("flag", flagName),
|
||||
zap.Error(err),
|
||||
)
|
||||
}
|
||||
return
|
||||
}
|
|
@ -3,6 +3,7 @@ package cmd
|
|||
import (
|
||||
"github.com/baez90/inetmock/internal/endpoints"
|
||||
"github.com/baez90/inetmock/internal/plugins"
|
||||
"github.com/baez90/inetmock/pkg/api"
|
||||
"github.com/baez90/inetmock/pkg/logging"
|
||||
"github.com/baez90/inetmock/pkg/path"
|
||||
"github.com/spf13/viper"
|
||||
|
@ -29,9 +30,8 @@ func initApp() (err error) {
|
|||
registry := plugins.Registry()
|
||||
endpointManager = endpoints.NewEndpointManager(logger)
|
||||
|
||||
if err = rootCmd.ParseFlags(os.Args); err != nil {
|
||||
return
|
||||
}
|
||||
rootCmd.Flags().ParseErrorsWhitelist.UnknownFlags = false
|
||||
_ = rootCmd.ParseFlags(os.Args)
|
||||
|
||||
if err = appConfig.ReadConfig(configFilePath); err != nil {
|
||||
logger.Error(
|
||||
|
@ -43,6 +43,14 @@ func initApp() (err error) {
|
|||
|
||||
viperInst := viper.GetViper()
|
||||
pluginDir := viperInst.GetString("plugins-directory")
|
||||
|
||||
if err = api.InitServices(viperInst, logger); err != nil {
|
||||
logger.Error(
|
||||
"failed to initialize app services",
|
||||
zap.Error(err),
|
||||
)
|
||||
}
|
||||
|
||||
pluginLoadStartTime := time.Now()
|
||||
if err = registry.LoadPlugins(pluginDir); err != nil {
|
||||
logger.Error("Failed to load plugins",
|
||||
|
|
|
@ -18,5 +18,5 @@ The easiest way to explore what commands are available is to start with 'inetmoc
|
|||
This help page contains a list of available sub-commands starting with the name of the plugin as a prefix.
|
||||
`,
|
||||
}
|
||||
rootCmd.AddCommand(pluginsCmd)
|
||||
rootCmd.AddCommand(pluginsCmd, generateCaCmd)
|
||||
}
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package config
|
||||
|
||||
import (
|
||||
"github.com/baez90/inetmock/pkg/api"
|
||||
"github.com/spf13/viper"
|
||||
)
|
||||
|
||||
|
@ -9,7 +10,7 @@ type MultiHandlerConfig interface {
|
|||
ListenAddress() string
|
||||
Ports() []uint16
|
||||
Options() *viper.Viper
|
||||
HandlerConfigs() []HandlerConfig
|
||||
HandlerConfigs() []api.HandlerConfig
|
||||
}
|
||||
|
||||
type multiHandlerConfig struct {
|
||||
|
@ -39,10 +40,10 @@ func (m multiHandlerConfig) Options() *viper.Viper {
|
|||
return m.options
|
||||
}
|
||||
|
||||
func (m multiHandlerConfig) HandlerConfigs() []HandlerConfig {
|
||||
configs := make([]HandlerConfig, 0)
|
||||
func (m multiHandlerConfig) HandlerConfigs() []api.HandlerConfig {
|
||||
configs := make([]api.HandlerConfig, 0)
|
||||
for _, port := range m.ports {
|
||||
configs = append(configs, NewHandlerConfig(m.handlerName, port, m.listenAddress, m.options))
|
||||
configs = append(configs, api.NewHandlerConfig(m.handlerName, port, m.listenAddress, m.options))
|
||||
}
|
||||
return configs
|
||||
}
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package config
|
||||
|
||||
import (
|
||||
"github.com/baez90/inetmock/pkg/api"
|
||||
"github.com/spf13/viper"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
@ -16,12 +17,12 @@ func Test_multiHandlerConfig_HandlerConfigs(t *testing.T) {
|
|||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
want []HandlerConfig
|
||||
want []api.HandlerConfig
|
||||
}{
|
||||
{
|
||||
name: "Get empty array if no ports are set",
|
||||
fields: fields{},
|
||||
want: make([]HandlerConfig, 0),
|
||||
want: make([]api.HandlerConfig, 0),
|
||||
},
|
||||
{
|
||||
name: "Get a single handler config if only one port is set",
|
||||
|
@ -31,13 +32,8 @@ func Test_multiHandlerConfig_HandlerConfigs(t *testing.T) {
|
|||
listenAddress: "0.0.0.0",
|
||||
options: nil,
|
||||
},
|
||||
want: []HandlerConfig{
|
||||
&handlerConfig{
|
||||
handlerName: "sampleHandler",
|
||||
port: 80,
|
||||
listenAddress: "0.0.0.0",
|
||||
options: nil,
|
||||
},
|
||||
want: []api.HandlerConfig{
|
||||
api.NewHandlerConfig("sampleHandler", 80, "0.0.0.0", nil),
|
||||
},
|
||||
},
|
||||
{
|
||||
|
@ -48,19 +44,9 @@ func Test_multiHandlerConfig_HandlerConfigs(t *testing.T) {
|
|||
listenAddress: "0.0.0.0",
|
||||
options: nil,
|
||||
},
|
||||
want: []HandlerConfig{
|
||||
&handlerConfig{
|
||||
handlerName: "sampleHandler",
|
||||
port: 80,
|
||||
listenAddress: "0.0.0.0",
|
||||
options: nil,
|
||||
},
|
||||
&handlerConfig{
|
||||
handlerName: "sampleHandler",
|
||||
port: 8080,
|
||||
listenAddress: "0.0.0.0",
|
||||
options: nil,
|
||||
},
|
||||
want: []api.HandlerConfig{
|
||||
api.NewHandlerConfig("sampleHandler", 80, "0.0.0.0", nil),
|
||||
api.NewHandlerConfig("sampleHandler", 8080, "0.0.0.0", nil),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
|
|
@ -1,6 +1,15 @@
|
|||
package config
|
||||
|
||||
import "github.com/spf13/viper"
|
||||
import (
|
||||
"github.com/spf13/viper"
|
||||
)
|
||||
|
||||
const (
|
||||
pluginConfigKey = "handler"
|
||||
listenAddressConfigKey = "listenAddress"
|
||||
portConfigKey = "port"
|
||||
portsConfigKey = "ports"
|
||||
)
|
||||
|
||||
func CreateMultiHandlerConfig(handlerConfig *viper.Viper) MultiHandlerConfig {
|
||||
return NewMultiHandlerConfig(
|
||||
|
|
|
@ -1,8 +1,7 @@
|
|||
//go:generate mockgen -source=endpoint.go -destination=./../../internal/mock/endpoint_mock.go -package=mock
|
||||
//go:generate mockgen -source=endpoint.go -destination=./../../internal/mock/endpoints/endpoint_mock.go -package=endpoints_mock
|
||||
package endpoints
|
||||
|
||||
import (
|
||||
"github.com/baez90/inetmock/internal/config"
|
||||
"github.com/baez90/inetmock/pkg/api"
|
||||
)
|
||||
|
||||
|
@ -15,7 +14,7 @@ type Endpoint interface {
|
|||
type endpoint struct {
|
||||
name string
|
||||
handler api.ProtocolHandler
|
||||
config config.HandlerConfig
|
||||
config api.HandlerConfig
|
||||
}
|
||||
|
||||
func (e endpoint) Name() string {
|
||||
|
|
|
@ -40,19 +40,20 @@ func (e endpointManager) StartedEndpoints() []Endpoint {
|
|||
return e.properlyStartedEndpoints
|
||||
}
|
||||
|
||||
func (e *endpointManager) CreateEndpoint(name string, multiHandlerConfig config.MultiHandlerConfig) (err error) {
|
||||
if handler, ok := e.registry.HandlerForName(multiHandlerConfig.HandlerName()); ok {
|
||||
for _, handlerConfig := range multiHandlerConfig.HandlerConfigs() {
|
||||
func (e *endpointManager) CreateEndpoint(name string, multiHandlerConfig config.MultiHandlerConfig) error {
|
||||
for _, handlerConfig := range multiHandlerConfig.HandlerConfigs() {
|
||||
if handler, ok := e.registry.HandlerForName(multiHandlerConfig.HandlerName()); ok {
|
||||
e.registeredEndpoints = append(e.registeredEndpoints, &endpoint{
|
||||
name: name,
|
||||
handler: handler,
|
||||
config: handlerConfig,
|
||||
})
|
||||
} else {
|
||||
return fmt.Errorf("no matching handler registered for names %s", multiHandlerConfig.HandlerName())
|
||||
}
|
||||
} else {
|
||||
err = fmt.Errorf("no matching handler registered for name %s", multiHandlerConfig.HandlerName())
|
||||
}
|
||||
return
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e *endpointManager) StartEndpoints() {
|
||||
|
|
|
@ -2,7 +2,9 @@ package endpoints
|
|||
|
||||
import (
|
||||
"github.com/baez90/inetmock/internal/config"
|
||||
"github.com/baez90/inetmock/internal/mock"
|
||||
api_mock "github.com/baez90/inetmock/internal/mock/api"
|
||||
logging_mock "github.com/baez90/inetmock/internal/mock/logging"
|
||||
plugins_mock "github.com/baez90/inetmock/internal/mock/plugins"
|
||||
"github.com/baez90/inetmock/internal/plugins"
|
||||
"github.com/baez90/inetmock/pkg/logging"
|
||||
"github.com/golang/mock/gomock"
|
||||
|
@ -33,18 +35,18 @@ func Test_endpointManager_CreateEndpoint(t *testing.T) {
|
|||
wantEndpoints: 1,
|
||||
fields: fields{
|
||||
logger: func() logging.Logger {
|
||||
return mock.NewMockLogger(gomock.NewController(t))
|
||||
return logging_mock.NewMockLogger(gomock.NewController(t))
|
||||
}(),
|
||||
registeredEndpoints: nil,
|
||||
properlyStartedEndpoints: nil,
|
||||
registry: func() plugins.HandlerRegistry {
|
||||
registry := mock.NewMockHandlerRegistry(gomock.NewController(t))
|
||||
registry := plugins_mock.NewMockHandlerRegistry(gomock.NewController(t))
|
||||
registry.
|
||||
EXPECT().
|
||||
HandlerForName("sampleHandler").
|
||||
MinTimes(1).
|
||||
MaxTimes(1).
|
||||
Return(mock.NewMockProtocolHandler(gomock.NewController(t)), true)
|
||||
Return(api_mock.NewMockProtocolHandler(gomock.NewController(t)), true)
|
||||
return registry
|
||||
}(),
|
||||
},
|
||||
|
@ -64,12 +66,12 @@ func Test_endpointManager_CreateEndpoint(t *testing.T) {
|
|||
wantEndpoints: 0,
|
||||
fields: fields{
|
||||
logger: func() logging.Logger {
|
||||
return mock.NewMockLogger(gomock.NewController(t))
|
||||
return logging_mock.NewMockLogger(gomock.NewController(t))
|
||||
}(),
|
||||
registeredEndpoints: nil,
|
||||
properlyStartedEndpoints: nil,
|
||||
registry: func() plugins.HandlerRegistry {
|
||||
registry := mock.NewMockHandlerRegistry(gomock.NewController(t))
|
||||
registry := plugins_mock.NewMockHandlerRegistry(gomock.NewController(t))
|
||||
registry.
|
||||
EXPECT().
|
||||
HandlerForName("sampleHandler").
|
||||
|
|
|
@ -2,8 +2,7 @@ package endpoints
|
|||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/baez90/inetmock/internal/config"
|
||||
"github.com/baez90/inetmock/internal/mock"
|
||||
api_mock "github.com/baez90/inetmock/internal/mock/api"
|
||||
"github.com/baez90/inetmock/pkg/api"
|
||||
"github.com/golang/mock/gomock"
|
||||
"testing"
|
||||
|
@ -13,7 +12,7 @@ func Test_endpoint_Name(t *testing.T) {
|
|||
type fields struct {
|
||||
name string
|
||||
handler api.ProtocolHandler
|
||||
config config.HandlerConfig
|
||||
config api.HandlerConfig
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
|
@ -51,7 +50,7 @@ func Test_endpoint_Shutdown(t *testing.T) {
|
|||
type fields struct {
|
||||
name string
|
||||
handler api.ProtocolHandler
|
||||
config config.HandlerConfig
|
||||
config api.HandlerConfig
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
|
@ -62,7 +61,7 @@ func Test_endpoint_Shutdown(t *testing.T) {
|
|||
name: "Expect no error if mocked handler does not return one",
|
||||
fields: fields{
|
||||
handler: func() api.ProtocolHandler {
|
||||
handler := mock.NewMockProtocolHandler(gomock.NewController(t))
|
||||
handler := api_mock.NewMockProtocolHandler(gomock.NewController(t))
|
||||
handler.EXPECT().
|
||||
Shutdown().
|
||||
MaxTimes(1).
|
||||
|
@ -76,7 +75,7 @@ func Test_endpoint_Shutdown(t *testing.T) {
|
|||
name: "Expect error if mocked handler returns one",
|
||||
fields: fields{
|
||||
handler: func() api.ProtocolHandler {
|
||||
handler := mock.NewMockProtocolHandler(gomock.NewController(t))
|
||||
handler := api_mock.NewMockProtocolHandler(gomock.NewController(t))
|
||||
handler.EXPECT().
|
||||
Shutdown().
|
||||
MaxTimes(1).
|
||||
|
@ -103,7 +102,7 @@ func Test_endpoint_Shutdown(t *testing.T) {
|
|||
|
||||
func Test_endpoint_Start(t *testing.T) {
|
||||
|
||||
demoHandlerConfig := config.NewHandlerConfig(
|
||||
demoHandlerConfig := api.NewHandlerConfig(
|
||||
"sampleHandler",
|
||||
80,
|
||||
"0.0.0.0",
|
||||
|
@ -113,7 +112,7 @@ func Test_endpoint_Start(t *testing.T) {
|
|||
type fields struct {
|
||||
name string
|
||||
handler api.ProtocolHandler
|
||||
config config.HandlerConfig
|
||||
config api.HandlerConfig
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
|
@ -124,7 +123,7 @@ func Test_endpoint_Start(t *testing.T) {
|
|||
name: "Expect no error if mocked handler does not return one",
|
||||
fields: fields{
|
||||
handler: func() api.ProtocolHandler {
|
||||
handler := mock.NewMockProtocolHandler(gomock.NewController(t))
|
||||
handler := api_mock.NewMockProtocolHandler(gomock.NewController(t))
|
||||
handler.EXPECT().
|
||||
Start(nil).
|
||||
MaxTimes(1).
|
||||
|
@ -138,7 +137,7 @@ func Test_endpoint_Start(t *testing.T) {
|
|||
name: "Expect error if mocked handler returns one",
|
||||
fields: fields{
|
||||
handler: func() api.ProtocolHandler {
|
||||
handler := mock.NewMockProtocolHandler(gomock.NewController(t))
|
||||
handler := api_mock.NewMockProtocolHandler(gomock.NewController(t))
|
||||
handler.EXPECT().
|
||||
Start(nil).
|
||||
MaxTimes(1).
|
||||
|
@ -153,7 +152,7 @@ func Test_endpoint_Start(t *testing.T) {
|
|||
fields: fields{
|
||||
config: demoHandlerConfig,
|
||||
handler: func() api.ProtocolHandler {
|
||||
handler := mock.NewMockProtocolHandler(gomock.NewController(t))
|
||||
handler := api_mock.NewMockProtocolHandler(gomock.NewController(t))
|
||||
handler.EXPECT().
|
||||
Start(demoHandlerConfig).
|
||||
MaxTimes(1).
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
//go:generate mockgen -source=loading.go -destination=./../../internal/mock/handler_registry_mock.go -package=mock
|
||||
//go:generate mockgen -source=loading.go -destination=./../../internal/mock/plugins/handler_registry_mock.go -package=plugins_mock
|
||||
package plugins
|
||||
|
||||
import (
|
||||
|
|
|
@ -15,15 +15,15 @@ x-response-rules: &httpResponseRules
|
|||
- pattern: ".*"
|
||||
response: ./assets/fakeFiles/default.html
|
||||
|
||||
x-tls-options: &tlsOptions
|
||||
tls:
|
||||
ecdsaCurve: P256
|
||||
validity:
|
||||
ca:
|
||||
notBeforeRelative: 17520h
|
||||
notAfterRelative: 17520h
|
||||
domain:
|
||||
notBeforeRelative: 168h
|
||||
notAfterRelative: 168h
|
||||
server:
|
||||
NotBeforeRelative: 168h
|
||||
NotAfterRelative: 168h
|
||||
rootCaCert:
|
||||
publicKey: ./ca.pem
|
||||
privateKey: ./ca.key
|
||||
|
@ -43,8 +43,9 @@ endpoints:
|
|||
listenAddress: 0.0.0.0
|
||||
port: 3128
|
||||
options:
|
||||
fallback: notfound
|
||||
<<: *httpResponseRules
|
||||
target:
|
||||
ipAddress: 127.0.0.1
|
||||
port: 80
|
||||
httpsDowngrade:
|
||||
handler: tls_interceptor
|
||||
listenAddress: 0.0.0.0
|
||||
|
@ -52,7 +53,6 @@ endpoints:
|
|||
- 443
|
||||
- 8443
|
||||
options:
|
||||
<<: *tlsOptions
|
||||
target:
|
||||
ipAddress: 127.0.0.1
|
||||
port: 80
|
||||
|
@ -75,7 +75,6 @@ endpoints:
|
|||
listenAddress: 0.0.0.0
|
||||
port: 853
|
||||
options:
|
||||
<<: *tlsOptions
|
||||
target:
|
||||
ipAddress: 127.0.0.1
|
||||
port: 53
|
|
@ -1,14 +1,7 @@
|
|||
package config
|
||||
package api
|
||||
|
||||
import "github.com/spf13/viper"
|
||||
|
||||
const (
|
||||
pluginConfigKey = "handler"
|
||||
listenAddressConfigKey = "listenAddress"
|
||||
portConfigKey = "port"
|
||||
portsConfigKey = "ports"
|
||||
)
|
||||
|
||||
type HandlerConfig interface {
|
||||
HandlerName() string
|
||||
ListenAddress() string
|
|
@ -1,4 +1,4 @@
|
|||
package config
|
||||
package api
|
||||
|
||||
import (
|
||||
"github.com/spf13/viper"
|
|
@ -1,8 +1,7 @@
|
|||
//go:generate mockgen -source=protocol_handler.go -destination=./../../internal/mock/protocol_handler_mock.go -package=mock
|
||||
//go:generate mockgen -source=protocol_handler.go -destination=./../../internal/mock/api/protocol_handler_mock.go -package=api_mock
|
||||
package api
|
||||
|
||||
import (
|
||||
"github.com/baez90/inetmock/internal/config"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
|
@ -11,6 +10,6 @@ type PluginInstanceFactory func() ProtocolHandler
|
|||
type LoggingFactory func() (*zap.Logger, error)
|
||||
|
||||
type ProtocolHandler interface {
|
||||
Start(config config.HandlerConfig) error
|
||||
Start(config HandlerConfig) error
|
||||
Shutdown() error
|
||||
}
|
||||
|
|
41
pkg/api/services.go
Normal file
41
pkg/api/services.go
Normal file
|
@ -0,0 +1,41 @@
|
|||
package api
|
||||
|
||||
import (
|
||||
"github.com/baez90/inetmock/pkg/cert"
|
||||
"github.com/baez90/inetmock/pkg/logging"
|
||||
"github.com/spf13/viper"
|
||||
)
|
||||
|
||||
var (
|
||||
svcs Services
|
||||
)
|
||||
|
||||
type Services interface {
|
||||
CertStore() cert.Store
|
||||
}
|
||||
|
||||
type services struct {
|
||||
certStore cert.Store
|
||||
}
|
||||
|
||||
func InitServices(
|
||||
config *viper.Viper,
|
||||
logger logging.Logger,
|
||||
) error {
|
||||
certStore, err := cert.NewDefaultStore(config, logger)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
svcs = &services{
|
||||
certStore: certStore,
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func ServicesInstance() Services {
|
||||
return svcs
|
||||
}
|
||||
|
||||
func (s *services) CertStore() cert.Store {
|
||||
return s.certStore
|
||||
}
|
22
pkg/cert/addr_utils.go
Normal file
22
pkg/cert/addr_utils.go
Normal file
|
@ -0,0 +1,22 @@
|
|||
package cert
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func extractIPFromAddress(addr string) (ip string, err error) {
|
||||
if idx := strings.LastIndex(addr, ":"); idx < 0 {
|
||||
err = fmt.Errorf("addr %s does not match expected scheme <ip>:<port>", addr)
|
||||
|
||||
} else {
|
||||
/* get IP part of address */
|
||||
ip = addr[0:idx]
|
||||
|
||||
/* trim [ ] for IPv6 addresses */
|
||||
if ip[0] == '[' {
|
||||
ip = ip[1 : len(ip)-1]
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
|
@ -1,4 +1,4 @@
|
|||
package main
|
||||
package cert
|
||||
|
||||
import "testing"
|
||||
|
78
pkg/cert/cache.go
Normal file
78
pkg/cert/cache.go
Normal file
|
@ -0,0 +1,78 @@
|
|||
//go:generate mockgen -source=cache.go -destination=./../../internal/mock/cert/cert_cache_mock.go -package=cert_mock
|
||||
|
||||
package cert
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"errors"
|
||||
)
|
||||
|
||||
type Cache interface {
|
||||
Put(cert *tls.Certificate) error
|
||||
Get(cn string) (*tls.Certificate, bool)
|
||||
}
|
||||
|
||||
func NewFileSystemCache(certCachePath string, source TimeSource) Cache {
|
||||
return &fileSystemCache{
|
||||
certCachePath: certCachePath,
|
||||
inMemCache: make(map[string]*tls.Certificate),
|
||||
timeSource: source,
|
||||
}
|
||||
}
|
||||
|
||||
type fileSystemCache struct {
|
||||
certCachePath string
|
||||
inMemCache map[string]*tls.Certificate
|
||||
timeSource TimeSource
|
||||
}
|
||||
|
||||
func (f *fileSystemCache) Put(cert *tls.Certificate) (err error) {
|
||||
if cert == nil {
|
||||
err = errors.New("cert may not be nil")
|
||||
return
|
||||
}
|
||||
var cn string
|
||||
if len(cert.Certificate) > 0 {
|
||||
if pubKey, err := x509.ParseCertificate(cert.Certificate[0]); err != nil {
|
||||
return err
|
||||
} else {
|
||||
cn = pubKey.Subject.CommonName
|
||||
}
|
||||
|
||||
f.inMemCache[cn] = cert
|
||||
pemCrt := NewPEM(cert)
|
||||
err = pemCrt.Write(cn, f.certCachePath)
|
||||
} else {
|
||||
err = errors.New("no public key present for certificate")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (f *fileSystemCache) Get(cn string) (*tls.Certificate, bool) {
|
||||
|
||||
if crt, ok := f.inMemCache[cn]; ok {
|
||||
return crt, true
|
||||
}
|
||||
|
||||
pemCrt := NewPEM(nil)
|
||||
if err := pemCrt.Read(cn, f.certCachePath); err != nil || pemCrt.Cert() == nil {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
x509Cert, err := x509.ParseCertificate(pemCrt.Cert().Certificate[0])
|
||||
if err == nil && !certShouldBeRenewed(f.timeSource, x509Cert) {
|
||||
return pemCrt.Cert(), true
|
||||
}
|
||||
|
||||
return nil, false
|
||||
}
|
||||
|
||||
func certShouldBeRenewed(timeSource TimeSource, cert *x509.Certificate) bool {
|
||||
lifetime := cert.NotAfter.Sub(cert.NotBefore)
|
||||
// if the cert is closer to the end of the lifetime than lifetime/2 it should be renewed
|
||||
if cert.NotAfter.Sub(timeSource.UTCNow()) < lifetime/4 {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
299
pkg/cert/cache_test.go
Normal file
299
pkg/cert/cache_test.go
Normal file
|
@ -0,0 +1,299 @@
|
|||
package cert
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"fmt"
|
||||
certmock "github.com/baez90/inetmock/internal/mock/cert"
|
||||
"github.com/golang/mock/gomock"
|
||||
"os"
|
||||
"path"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
cnLocalhost = "localhost"
|
||||
caCN = "UnitTests"
|
||||
serverRelativeValidity = 24 * time.Hour
|
||||
caRelativeValidity = 168 * time.Hour
|
||||
)
|
||||
|
||||
var (
|
||||
serverCN = fmt.Sprintf("%s-%d", cnLocalhost, time.Now().Unix())
|
||||
)
|
||||
|
||||
func Test_certShouldBeRenewed(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
type args struct {
|
||||
timeSource TimeSource
|
||||
cert *x509.Certificate
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "Expect cert should not be renewed right after creation",
|
||||
args: args{
|
||||
timeSource: func() TimeSource {
|
||||
tsMock := certmock.NewMockTimeSource(ctrl)
|
||||
tsMock.
|
||||
EXPECT().
|
||||
UTCNow().
|
||||
Return(time.Now().UTC()).
|
||||
Times(1)
|
||||
return tsMock
|
||||
}(),
|
||||
cert: &x509.Certificate{
|
||||
NotAfter: time.Now().UTC().Add(serverRelativeValidity),
|
||||
NotBefore: time.Now().UTC().Add(-serverRelativeValidity),
|
||||
},
|
||||
},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "Expect cert should be renewed if the remaining lifetime is less than a quarter of the total lifetime",
|
||||
args: args{
|
||||
timeSource: func() TimeSource {
|
||||
tsMock := certmock.NewMockTimeSource(ctrl)
|
||||
tsMock.
|
||||
EXPECT().
|
||||
UTCNow().
|
||||
Return(time.Now().UTC().Add(serverRelativeValidity/2 + 1*time.Hour)).
|
||||
Times(1)
|
||||
return tsMock
|
||||
}(),
|
||||
cert: &x509.Certificate{
|
||||
NotAfter: time.Now().UTC().Add(serverRelativeValidity),
|
||||
NotBefore: time.Now().UTC().Add(-serverRelativeValidity),
|
||||
},
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := certShouldBeRenewed(tt.args.timeSource, tt.args.cert); got != tt.want {
|
||||
t.Errorf("certShouldBeRenewed() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_fileSystemCache_Get(t *testing.T) {
|
||||
type fields struct {
|
||||
certCachePath string
|
||||
inMemCache map[string]*tls.Certificate
|
||||
timeSource TimeSource
|
||||
}
|
||||
type args struct {
|
||||
cn string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
args args
|
||||
wantOk bool
|
||||
}{
|
||||
{
|
||||
name: "Get a miss when no cert is present",
|
||||
fields: fields{
|
||||
certCachePath: os.TempDir(),
|
||||
inMemCache: make(map[string]*tls.Certificate),
|
||||
timeSource: NewTimeSource(),
|
||||
},
|
||||
args: args{
|
||||
cnLocalhost,
|
||||
},
|
||||
wantOk: false,
|
||||
},
|
||||
{
|
||||
name: "Get a prepared certificate from the memory cache",
|
||||
fields: func() fields {
|
||||
certGen := setupCertGen()
|
||||
|
||||
caCrt, _ := certGen.CACert(GenerationOptions{
|
||||
CommonName: caCN,
|
||||
})
|
||||
|
||||
srvCrt, _ := certGen.ServerCert(GenerationOptions{
|
||||
CommonName: cnLocalhost,
|
||||
}, caCrt)
|
||||
|
||||
return fields{
|
||||
certCachePath: os.TempDir(),
|
||||
inMemCache: map[string]*tls.Certificate{
|
||||
cnLocalhost: srvCrt,
|
||||
},
|
||||
timeSource: NewTimeSource(),
|
||||
}
|
||||
}(),
|
||||
args: args{
|
||||
cn: cnLocalhost,
|
||||
},
|
||||
wantOk: true,
|
||||
},
|
||||
{
|
||||
name: "Get a prepared certificate from the file system",
|
||||
fields: func() fields {
|
||||
certGen := setupCertGen()
|
||||
|
||||
caCrt, _ := certGen.CACert(GenerationOptions{
|
||||
CommonName: "INetMock",
|
||||
})
|
||||
|
||||
srvCrt, _ := certGen.ServerCert(GenerationOptions{
|
||||
CommonName: serverCN,
|
||||
}, caCrt)
|
||||
|
||||
pem := NewPEM(srvCrt)
|
||||
if err := pem.Write(serverCN, os.TempDir()); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
t.Cleanup(func() {
|
||||
for _, f := range []string{
|
||||
path.Join(os.TempDir(), fmt.Sprintf("%s.pem", serverCN)),
|
||||
path.Join(os.TempDir(), fmt.Sprintf("%s.key", serverCN)),
|
||||
} {
|
||||
_ = os.Remove(f)
|
||||
}
|
||||
})
|
||||
|
||||
return fields{
|
||||
certCachePath: os.TempDir(),
|
||||
inMemCache: make(map[string]*tls.Certificate),
|
||||
timeSource: NewTimeSource(),
|
||||
}
|
||||
}(),
|
||||
args: args{
|
||||
cn: serverCN,
|
||||
},
|
||||
wantOk: true,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
f := &fileSystemCache{
|
||||
certCachePath: tt.fields.certCachePath,
|
||||
inMemCache: tt.fields.inMemCache,
|
||||
timeSource: tt.fields.timeSource,
|
||||
}
|
||||
gotCrt, gotOk := f.Get(tt.args.cn)
|
||||
|
||||
if gotOk && (gotCrt == nil || !reflect.DeepEqual(reflect.TypeOf(new(tls.Certificate)), reflect.TypeOf(gotCrt))) {
|
||||
t.Errorf("Wanted propert certificate but got %v", gotCrt)
|
||||
}
|
||||
|
||||
if gotOk != tt.wantOk {
|
||||
t.Errorf("Get() gotOk = %v, want %v", gotOk, tt.wantOk)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_fileSystemCache_Put(t *testing.T) {
|
||||
type fields struct {
|
||||
certCachePath string
|
||||
inMemCache map[string]*tls.Certificate
|
||||
timeSource TimeSource
|
||||
}
|
||||
type args struct {
|
||||
cert *tls.Certificate
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
args args
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "Want error if nil cert is passed to put",
|
||||
fields: fields{
|
||||
certCachePath: os.TempDir(),
|
||||
inMemCache: make(map[string]*tls.Certificate),
|
||||
timeSource: NewTimeSource(),
|
||||
},
|
||||
args: args{
|
||||
cert: nil,
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "Want error if empty cert is passed to put",
|
||||
fields: fields{
|
||||
certCachePath: os.TempDir(),
|
||||
inMemCache: make(map[string]*tls.Certificate),
|
||||
timeSource: NewTimeSource(),
|
||||
},
|
||||
args: args{
|
||||
cert: &tls.Certificate{},
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "No error if valid cert is passed",
|
||||
fields: fields{
|
||||
certCachePath: os.TempDir(),
|
||||
inMemCache: make(map[string]*tls.Certificate),
|
||||
timeSource: NewTimeSource(),
|
||||
},
|
||||
args: args{
|
||||
cert: func() *tls.Certificate {
|
||||
gen := setupCertGen()
|
||||
ca, _ := gen.CACert(GenerationOptions{
|
||||
CommonName: caCN,
|
||||
})
|
||||
|
||||
srvCN := fmt.Sprintf("%s-%d", cnLocalhost, time.Now().Unix())
|
||||
|
||||
t.Cleanup(func() {
|
||||
for _, f := range []string{
|
||||
path.Join(os.TempDir(), fmt.Sprintf("%s.pem", srvCN)),
|
||||
path.Join(os.TempDir(), fmt.Sprintf("%s.key", srvCN)),
|
||||
} {
|
||||
_ = os.Remove(f)
|
||||
}
|
||||
})
|
||||
|
||||
srvCrt, _ := gen.ServerCert(GenerationOptions{
|
||||
CommonName: srvCN,
|
||||
}, ca)
|
||||
return srvCrt
|
||||
}(),
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
f := &fileSystemCache{
|
||||
certCachePath: tt.fields.certCachePath,
|
||||
inMemCache: tt.fields.inMemCache,
|
||||
timeSource: tt.fields.timeSource,
|
||||
}
|
||||
if err := f.Put(tt.args.cert); (err != nil) != tt.wantErr {
|
||||
t.Errorf("Put() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func setupCertGen() Generator {
|
||||
return NewDefaultGenerator(Options{
|
||||
Validity: ValidityByPurpose{
|
||||
Server: ValidityDuration{
|
||||
NotBeforeRelative: serverRelativeValidity,
|
||||
NotAfterRelative: serverRelativeValidity,
|
||||
},
|
||||
CA: ValidityDuration{
|
||||
NotBeforeRelative: caRelativeValidity,
|
||||
NotAfterRelative: caRelativeValidity,
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
21
pkg/cert/constants.go
Normal file
21
pkg/cert/constants.go
Normal file
|
@ -0,0 +1,21 @@
|
|||
package cert
|
||||
|
||||
const (
|
||||
CurveTypeP224 CurveType = "P224"
|
||||
CurveTypeP256 CurveType = "P256"
|
||||
CurveTypeP384 CurveType = "P384"
|
||||
CurveTypeP521 CurveType = "P521"
|
||||
CurveTypeED25519 CurveType = "ED25519"
|
||||
|
||||
defaultServerValidityDuration = "168h"
|
||||
defaultCAValidityDuration = "17520h"
|
||||
|
||||
certCachePathConfigKey = "tls.certCachePath"
|
||||
ecdsaCurveConfigKey = "tls.ecdsaCurve"
|
||||
publicKeyConfigKey = "tls.rootCaCert.publicKey"
|
||||
privateKeyPathConfigKey = "tls.rootCaCert.privateKey"
|
||||
caCertValidityNotBeforeKey = "tls.validity.ca.notBeforeRelative"
|
||||
caCertValidityNotAfterKey = "tls.validity.ca.notAfterRelative"
|
||||
serverCertValidityNotBeforeKey = "tls.validity.server.notBeforeRelative"
|
||||
serverCertValidityNotAfterKey = "tls.validity.server.notAfterRelative"
|
||||
)
|
43
pkg/cert/defaults.go
Normal file
43
pkg/cert/defaults.go
Normal file
|
@ -0,0 +1,43 @@
|
|||
package cert
|
||||
|
||||
import (
|
||||
"github.com/baez90/inetmock/pkg/defaulting"
|
||||
"reflect"
|
||||
)
|
||||
|
||||
var (
|
||||
certOptionsDefaulter defaulting.Defaulter = func(instance interface{}) {
|
||||
switch o := instance.(type) {
|
||||
case *GenerationOptions:
|
||||
|
||||
if len(o.Country) < 1 {
|
||||
o.Country = []string{"US"}
|
||||
}
|
||||
|
||||
if len(o.Locality) < 1 {
|
||||
o.Locality = []string{"San Francisco"}
|
||||
}
|
||||
|
||||
if len(o.Organization) < 1 {
|
||||
o.Organization = []string{"INetMock"}
|
||||
}
|
||||
|
||||
if len(o.StreetAddress) < 1 {
|
||||
o.StreetAddress = []string{"Golden Gate Bridge"}
|
||||
}
|
||||
|
||||
if len(o.PostalCode) < 1 {
|
||||
o.PostalCode = []string{"94016"}
|
||||
}
|
||||
|
||||
if len(o.Province) < 1 {
|
||||
o.Province = []string{""}
|
||||
}
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
func init() {
|
||||
certOptionsType := reflect.TypeOf(GenerationOptions{})
|
||||
defaulters.Register(certOptionsType, certOptionsDefaulter)
|
||||
}
|
39
pkg/cert/defaults_test.go
Normal file
39
pkg/cert/defaults_test.go
Normal file
|
@ -0,0 +1,39 @@
|
|||
package cert
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func Test_certOptionsDefaulter(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
arg GenerationOptions
|
||||
expected GenerationOptions
|
||||
}{
|
||||
{
|
||||
name: "",
|
||||
arg: GenerationOptions{
|
||||
CommonName: "CA",
|
||||
},
|
||||
expected: GenerationOptions{
|
||||
CommonName: "CA",
|
||||
Country: []string{"US"},
|
||||
Locality: []string{"San Francisco"},
|
||||
Organization: []string{"INetMock"},
|
||||
StreetAddress: []string{"Golden Gate Bridge"},
|
||||
PostalCode: []string{"94016"},
|
||||
Province: []string{""},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
certOptionsDefaulter(&tt.arg)
|
||||
if !reflect.DeepEqual(tt.expected, tt.arg) {
|
||||
t.Errorf("Apply defaulter expected=%v got=%v", tt.expected, tt.arg)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
184
pkg/cert/generator.go
Normal file
184
pkg/cert/generator.go
Normal file
|
@ -0,0 +1,184 @@
|
|||
package cert
|
||||
|
||||
import (
|
||||
"crypto/ecdsa"
|
||||
"crypto/ed25519"
|
||||
"crypto/rand"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"encoding/pem"
|
||||
"github.com/baez90/inetmock/pkg/defaulting"
|
||||
"math/big"
|
||||
"net"
|
||||
)
|
||||
|
||||
var (
|
||||
defaulters = defaulting.New()
|
||||
)
|
||||
|
||||
type GenerationOptions struct {
|
||||
CommonName string
|
||||
Organization []string
|
||||
OrganizationalUnit []string
|
||||
IPAddresses []net.IP
|
||||
DNSNames []string
|
||||
Country []string
|
||||
Province []string
|
||||
Locality []string
|
||||
StreetAddress []string
|
||||
PostalCode []string
|
||||
}
|
||||
|
||||
type Generator interface {
|
||||
CACert(options GenerationOptions) (*tls.Certificate, error)
|
||||
ServerCert(options GenerationOptions, ca *tls.Certificate) (*tls.Certificate, error)
|
||||
}
|
||||
|
||||
func NewDefaultGenerator(options Options) Generator {
|
||||
return NewGenerator(options, NewTimeSource(), defaultKeyProvider(options))
|
||||
}
|
||||
|
||||
func NewGenerator(options Options, source TimeSource, provider KeyProvider) Generator {
|
||||
return &generator{
|
||||
options: options,
|
||||
provider: provider,
|
||||
timeSource: source,
|
||||
}
|
||||
}
|
||||
|
||||
type generator struct {
|
||||
options Options
|
||||
provider KeyProvider
|
||||
timeSource TimeSource
|
||||
outDir string
|
||||
}
|
||||
|
||||
func (g *generator) privateKey() (key interface{}, err error) {
|
||||
if g.provider != nil {
|
||||
return g.provider()
|
||||
} else {
|
||||
return defaultKeyProvider(g.options)()
|
||||
}
|
||||
}
|
||||
|
||||
func (g *generator) ServerCert(options GenerationOptions, ca *tls.Certificate) (cert *tls.Certificate, err error) {
|
||||
defaulters.Apply(&options)
|
||||
var serialNumber *big.Int
|
||||
if serialNumber, err = generateSerialNumber(); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
var privateKey interface{}
|
||||
if privateKey, err = g.privateKey(); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
template := x509.Certificate{
|
||||
SerialNumber: serialNumber,
|
||||
Subject: pkix.Name{
|
||||
CommonName: options.CommonName,
|
||||
Organization: options.Organization,
|
||||
OrganizationalUnit: options.OrganizationalUnit,
|
||||
Country: options.Country,
|
||||
Province: options.Province,
|
||||
Locality: options.Locality,
|
||||
StreetAddress: options.StreetAddress,
|
||||
PostalCode: options.PostalCode,
|
||||
},
|
||||
IPAddresses: options.IPAddresses,
|
||||
DNSNames: options.DNSNames,
|
||||
NotBefore: g.timeSource.UTCNow().Add(-g.options.Validity.Server.NotBeforeRelative),
|
||||
NotAfter: g.timeSource.UTCNow().Add(g.options.Validity.Server.NotAfterRelative),
|
||||
KeyUsage: x509.KeyUsageDigitalSignature,
|
||||
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth},
|
||||
}
|
||||
var caCrt *x509.Certificate
|
||||
caCrt, err = x509.ParseCertificate(ca.Certificate[0])
|
||||
|
||||
var derBytes []byte
|
||||
if derBytes, err = x509.CreateCertificate(rand.Reader, &template, caCrt, publicKey(privateKey), ca.PrivateKey); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
var privateKeyBytes []byte
|
||||
if privateKeyBytes, err = x509.MarshalPKCS8PrivateKey(privateKey); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if cert, err = parseCert(derBytes, privateKeyBytes); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (g generator) CACert(options GenerationOptions) (crt *tls.Certificate, err error) {
|
||||
defaulters.Apply(&options)
|
||||
|
||||
var privateKey interface{}
|
||||
var serialNumber *big.Int
|
||||
if serialNumber, err = generateSerialNumber(); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if privateKey, err = g.privateKey(); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
template := x509.Certificate{
|
||||
SerialNumber: serialNumber,
|
||||
Subject: pkix.Name{
|
||||
CommonName: options.CommonName,
|
||||
Organization: options.Organization,
|
||||
Country: options.Country,
|
||||
Province: options.Province,
|
||||
Locality: options.Locality,
|
||||
StreetAddress: options.StreetAddress,
|
||||
PostalCode: options.PostalCode,
|
||||
},
|
||||
IsCA: true,
|
||||
NotBefore: g.timeSource.UTCNow().Add(-g.options.Validity.Server.NotBeforeRelative),
|
||||
NotAfter: g.timeSource.UTCNow().Add(g.options.Validity.Server.NotAfterRelative),
|
||||
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
|
||||
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth},
|
||||
BasicConstraintsValid: true,
|
||||
}
|
||||
|
||||
var derBytes []byte
|
||||
if derBytes, err = x509.CreateCertificate(rand.Reader, &template, &template, publicKey(privateKey), privateKey); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
var privateKeyBytes []byte
|
||||
if privateKeyBytes, err = x509.MarshalPKCS8PrivateKey(privateKey); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
crt, err = parseCert(derBytes, privateKeyBytes)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func generateSerialNumber() (*big.Int, error) {
|
||||
serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
|
||||
return rand.Int(rand.Reader, serialNumberLimit)
|
||||
}
|
||||
|
||||
func publicKey(privateKey interface{}) interface{} {
|
||||
switch k := privateKey.(type) {
|
||||
case *ecdsa.PrivateKey:
|
||||
return &k.PublicKey
|
||||
case ed25519.PrivateKey:
|
||||
return k.Public().(ed25519.PublicKey)
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func parseCert(derBytes []byte, privateKeyBytes []byte) (*tls.Certificate, error) {
|
||||
pemEncodedPublicKey := pem.EncodeToMemory(&pem.Block{Type: certificateBlockType, Bytes: derBytes})
|
||||
pemEncodedPrivateKey := pem.EncodeToMemory(&pem.Block{Type: privateKeyBlockType, Bytes: privateKeyBytes})
|
||||
cert, err := tls.X509KeyPair(pemEncodedPublicKey, pemEncodedPrivateKey)
|
||||
return &cert, err
|
||||
}
|
79
pkg/cert/options.go
Normal file
79
pkg/cert/options.go
Normal file
|
@ -0,0 +1,79 @@
|
|||
package cert
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/spf13/viper"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
requiredConfigKeys = []string{
|
||||
publicKeyConfigKey,
|
||||
privateKeyPathConfigKey,
|
||||
}
|
||||
)
|
||||
|
||||
type CurveType string
|
||||
|
||||
type File struct {
|
||||
PublicKeyPath string
|
||||
PrivateKeyPath string
|
||||
}
|
||||
|
||||
type ValidityDuration struct {
|
||||
NotBeforeRelative time.Duration
|
||||
NotAfterRelative time.Duration
|
||||
}
|
||||
|
||||
type ValidityByPurpose struct {
|
||||
CA ValidityDuration
|
||||
Server ValidityDuration
|
||||
}
|
||||
|
||||
type Options struct {
|
||||
RootCACert File
|
||||
CertCachePath string
|
||||
Curve CurveType
|
||||
Validity ValidityByPurpose
|
||||
}
|
||||
|
||||
func loadFromConfig(config *viper.Viper) (Options, error) {
|
||||
missingKeys := make([]string, 0)
|
||||
for _, requiredKey := range requiredConfigKeys {
|
||||
if !config.IsSet(requiredKey) {
|
||||
missingKeys = append(missingKeys, requiredKey)
|
||||
}
|
||||
}
|
||||
|
||||
if len(missingKeys) > 0 {
|
||||
return Options{}, fmt.Errorf("config keys are missing: %s", strings.Join(missingKeys, ", "))
|
||||
}
|
||||
|
||||
config.SetDefault(certCachePathConfigKey, os.TempDir())
|
||||
config.SetDefault(ecdsaCurveConfigKey, string(CurveTypeED25519))
|
||||
config.SetDefault(caCertValidityNotBeforeKey, defaultCAValidityDuration)
|
||||
config.SetDefault(caCertValidityNotAfterKey, defaultCAValidityDuration)
|
||||
config.SetDefault(serverCertValidityNotBeforeKey, defaultServerValidityDuration)
|
||||
config.SetDefault(serverCertValidityNotAfterKey, defaultServerValidityDuration)
|
||||
|
||||
return Options{
|
||||
CertCachePath: config.GetString(certCachePathConfigKey),
|
||||
Curve: CurveType(config.GetString(ecdsaCurveConfigKey)),
|
||||
Validity: ValidityByPurpose{
|
||||
CA: ValidityDuration{
|
||||
NotBeforeRelative: config.GetDuration(caCertValidityNotBeforeKey),
|
||||
NotAfterRelative: config.GetDuration(caCertValidityNotAfterKey),
|
||||
},
|
||||
Server: ValidityDuration{
|
||||
NotBeforeRelative: config.GetDuration(serverCertValidityNotBeforeKey),
|
||||
NotAfterRelative: config.GetDuration(serverCertValidityNotAfterKey),
|
||||
},
|
||||
},
|
||||
RootCACert: File{
|
||||
PublicKeyPath: config.GetString(publicKeyConfigKey),
|
||||
PrivateKeyPath: config.GetString(privateKeyPathConfigKey),
|
||||
},
|
||||
}, nil
|
||||
}
|
137
pkg/cert/options_test.go
Normal file
137
pkg/cert/options_test.go
Normal file
|
@ -0,0 +1,137 @@
|
|||
package cert
|
||||
|
||||
import (
|
||||
"github.com/spf13/viper"
|
||||
"os"
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func readViper(cfg string) *viper.Viper {
|
||||
vpr := viper.New()
|
||||
vpr.SetConfigType("yaml")
|
||||
if err := vpr.ReadConfig(strings.NewReader(cfg)); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return vpr
|
||||
}
|
||||
|
||||
func Test_loadFromConfig(t *testing.T) {
|
||||
type args struct {
|
||||
config *viper.Viper
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want Options
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "Parse valid TLS configuration",
|
||||
wantErr: false,
|
||||
args: args{
|
||||
config: readViper(`
|
||||
tls:
|
||||
ecdsaCurve: P256
|
||||
validity:
|
||||
ca:
|
||||
notBeforeRelative: 17520h
|
||||
notAfterRelative: 17520h
|
||||
server:
|
||||
NotBeforeRelative: 168h
|
||||
NotAfterRelative: 168h
|
||||
rootCaCert:
|
||||
publicKey: ./ca.pem
|
||||
privateKey: ./ca.key
|
||||
certCachePath: /tmp/inetmock/
|
||||
`),
|
||||
},
|
||||
want: Options{
|
||||
RootCACert: File{
|
||||
PublicKeyPath: "./ca.pem",
|
||||
PrivateKeyPath: "./ca.key",
|
||||
},
|
||||
CertCachePath: "/tmp/inetmock/",
|
||||
Curve: CurveTypeP256,
|
||||
Validity: ValidityByPurpose{
|
||||
CA: ValidityDuration{
|
||||
NotBeforeRelative: 17520 * time.Hour,
|
||||
NotAfterRelative: 17520 * time.Hour,
|
||||
},
|
||||
Server: ValidityDuration{
|
||||
NotBeforeRelative: 168 * time.Hour,
|
||||
NotAfterRelative: 168 * time.Hour,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Get an error if CA public key path is missing",
|
||||
args: args{
|
||||
readViper(`
|
||||
tls:
|
||||
rootCaCert:
|
||||
privateKey: ./ca.key
|
||||
`),
|
||||
},
|
||||
want: Options{},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "Get an error if CA private key path is missing",
|
||||
args: args{
|
||||
readViper(`
|
||||
tls:
|
||||
rootCaCert:
|
||||
publicKey: ./ca.pem
|
||||
`),
|
||||
},
|
||||
want: Options{},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "Get default options if all required fields are set",
|
||||
args: args{
|
||||
readViper(`
|
||||
tls:
|
||||
rootCaCert:
|
||||
publicKey: ./ca.pem
|
||||
privateKey: ./ca.key
|
||||
`),
|
||||
},
|
||||
want: Options{
|
||||
RootCACert: File{
|
||||
PublicKeyPath: "./ca.pem",
|
||||
PrivateKeyPath: "./ca.key",
|
||||
},
|
||||
CertCachePath: os.TempDir(),
|
||||
Curve: CurveTypeED25519,
|
||||
Validity: ValidityByPurpose{
|
||||
CA: ValidityDuration{
|
||||
NotBeforeRelative: 17520 * time.Hour,
|
||||
NotAfterRelative: 17520 * time.Hour,
|
||||
},
|
||||
Server: ValidityDuration{
|
||||
NotBeforeRelative: 168 * time.Hour,
|
||||
NotAfterRelative: 168 * time.Hour,
|
||||
},
|
||||
},
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := loadFromConfig(tt.args.config)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("loadFromConfig() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("loadFromConfig() got = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
79
pkg/cert/pem.go
Normal file
79
pkg/cert/pem.go
Normal file
|
@ -0,0 +1,79 @@
|
|||
package cert
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"encoding/pem"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/baez90/inetmock/pkg/path"
|
||||
"os"
|
||||
"path/filepath"
|
||||
)
|
||||
|
||||
const (
|
||||
certificateBlockType = "CERTIFICATE"
|
||||
privateKeyBlockType = "PRIVATE KEY"
|
||||
)
|
||||
|
||||
type PEM interface {
|
||||
Cert() *tls.Certificate
|
||||
Write(cn string, outDir string) error
|
||||
Read(cn string, inDir string) error
|
||||
ReadFrom(pubKeyPath, privateKeyPath string) error
|
||||
}
|
||||
|
||||
func NewPEM(crt *tls.Certificate) PEM {
|
||||
return &pemCrt{
|
||||
crt: crt,
|
||||
}
|
||||
}
|
||||
|
||||
type pemCrt struct {
|
||||
crt *tls.Certificate
|
||||
}
|
||||
|
||||
func (p *pemCrt) Cert() *tls.Certificate {
|
||||
return p.crt
|
||||
}
|
||||
|
||||
func (p pemCrt) Write(cn string, outDir string) (err error) {
|
||||
|
||||
var certOut *os.File
|
||||
if certOut, err = os.Create(filepath.Join(outDir, fmt.Sprintf("%s.pem", cn))); err != nil {
|
||||
return
|
||||
}
|
||||
defer certOut.Close()
|
||||
if err = pem.Encode(certOut, &pem.Block{Type: certificateBlockType, Bytes: p.crt.Certificate[0]}); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
var keyOut *os.File
|
||||
if keyOut, err = os.OpenFile(filepath.Join(outDir, fmt.Sprintf("%s.key", cn)), os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600); err != nil {
|
||||
return
|
||||
}
|
||||
var privKeyBytes []byte
|
||||
privKeyBytes, err = x509.MarshalPKCS8PrivateKey(p.crt.PrivateKey)
|
||||
err = pem.Encode(keyOut, &pem.Block{Type: privateKeyBlockType, Bytes: privKeyBytes})
|
||||
return
|
||||
}
|
||||
|
||||
func (p *pemCrt) Read(cn string, inDir string) error {
|
||||
certPath := filepath.Join(inDir, fmt.Sprintf("%s.pem", cn))
|
||||
keyPath := filepath.Join(inDir, fmt.Sprintf("%s.key", cn))
|
||||
|
||||
return p.ReadFrom(certPath, keyPath)
|
||||
}
|
||||
|
||||
func (p *pemCrt) ReadFrom(pubKeyPath, privateKeyPath string) (err error) {
|
||||
var tlsCrt tls.Certificate
|
||||
if path.FileExists(pubKeyPath) && path.FileExists(privateKeyPath) {
|
||||
if tlsCrt, err = tls.LoadX509KeyPair(pubKeyPath, privateKeyPath); err == nil {
|
||||
p.crt = &tlsCrt
|
||||
}
|
||||
} else {
|
||||
err = errors.New("either public or private key file do not exist")
|
||||
}
|
||||
|
||||
return
|
||||
}
|
153
pkg/cert/store.go
Normal file
153
pkg/cert/store.go
Normal file
|
@ -0,0 +1,153 @@
|
|||
package cert
|
||||
|
||||
import (
|
||||
"crypto/ecdsa"
|
||||
"crypto/ed25519"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"github.com/baez90/inetmock/pkg/logging"
|
||||
"github.com/spf13/viper"
|
||||
"go.uber.org/zap"
|
||||
"net"
|
||||
)
|
||||
|
||||
const (
|
||||
ipv4Loopback = "127.0.0.1"
|
||||
)
|
||||
|
||||
var (
|
||||
defaultKeyProvider = func(options Options) func() (key interface{}, err error) {
|
||||
return func() (key interface{}, err error) {
|
||||
return privateKeyForCurve(options)
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
type KeyProvider func() (key interface{}, err error)
|
||||
|
||||
type Store interface {
|
||||
CACert() *tls.Certificate
|
||||
GetCertificate(serverName string, ip string) (*tls.Certificate, error)
|
||||
TLSConfig() *tls.Config
|
||||
}
|
||||
|
||||
func NewDefaultStore(
|
||||
config *viper.Viper,
|
||||
logger logging.Logger,
|
||||
) (Store, error) {
|
||||
options, err := loadFromConfig(config)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
timeSource := NewTimeSource()
|
||||
return NewStore(
|
||||
options,
|
||||
NewFileSystemCache(options.CertCachePath, timeSource),
|
||||
NewDefaultGenerator(options),
|
||||
logger,
|
||||
)
|
||||
}
|
||||
|
||||
func NewStore(
|
||||
options Options,
|
||||
cache Cache,
|
||||
generator Generator,
|
||||
logger logging.Logger,
|
||||
) (Store, error) {
|
||||
store := &store{
|
||||
options: options,
|
||||
cache: cache,
|
||||
generator: generator,
|
||||
logger: logger,
|
||||
}
|
||||
|
||||
if err := store.loadCACert(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return store, nil
|
||||
}
|
||||
|
||||
type store struct {
|
||||
options Options
|
||||
caCert *tls.Certificate
|
||||
cache Cache
|
||||
timeSource TimeSource
|
||||
generator Generator
|
||||
logger logging.Logger
|
||||
}
|
||||
|
||||
func (s *store) TLSConfig() *tls.Config {
|
||||
rootCaPool := x509.NewCertPool()
|
||||
rootCaPubKey, _ := x509.ParseCertificate(s.caCert.Certificate[0])
|
||||
rootCaPool.AddCert(rootCaPubKey)
|
||||
|
||||
return &tls.Config{
|
||||
GetCertificate: func(info *tls.ClientHelloInfo) (cert *tls.Certificate, err error) {
|
||||
var localIp string
|
||||
if localIp, err = extractIPFromAddress(info.Conn.LocalAddr().String()); err != nil {
|
||||
localIp = ipv4Loopback
|
||||
}
|
||||
|
||||
if cert, err = s.GetCertificate(info.ServerName, localIp); err != nil {
|
||||
s.logger.Error(
|
||||
"error while resolving certificate",
|
||||
zap.String("serverName", info.ServerName),
|
||||
zap.String("localAddr", localIp),
|
||||
zap.Error(err),
|
||||
)
|
||||
}
|
||||
|
||||
return
|
||||
},
|
||||
RootCAs: rootCaPool,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *store) loadCACert() (err error) {
|
||||
pemCrt := NewPEM(nil)
|
||||
if err = pemCrt.ReadFrom(s.options.RootCACert.PublicKeyPath, s.options.RootCACert.PrivateKeyPath); err != nil {
|
||||
return
|
||||
}
|
||||
s.caCert = pemCrt.Cert()
|
||||
return
|
||||
}
|
||||
|
||||
func (s *store) CACert() *tls.Certificate {
|
||||
return s.caCert
|
||||
}
|
||||
|
||||
func (s *store) GetCertificate(serverName string, ip string) (cert *tls.Certificate, err error) {
|
||||
if crt, ok := s.cache.Get(serverName); ok {
|
||||
return crt, nil
|
||||
}
|
||||
|
||||
if cert, err = s.generator.ServerCert(GenerationOptions{
|
||||
CommonName: serverName,
|
||||
DNSNames: []string{serverName},
|
||||
IPAddresses: []net.IP{net.ParseIP(ip)},
|
||||
}, s.caCert); err == nil {
|
||||
s.cache.Put(cert)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func privateKeyForCurve(options Options) (privateKey interface{}, err error) {
|
||||
switch options.Curve {
|
||||
case CurveTypeP224:
|
||||
privateKey, err = ecdsa.GenerateKey(elliptic.P224(), rand.Reader)
|
||||
case CurveTypeP256:
|
||||
privateKey, err = ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
case CurveTypeP384:
|
||||
privateKey, err = ecdsa.GenerateKey(elliptic.P384(), rand.Reader)
|
||||
case CurveTypeP521:
|
||||
privateKey, err = ecdsa.GenerateKey(elliptic.P521(), rand.Reader)
|
||||
default:
|
||||
_, privateKey, err = ed25519.GenerateKey(rand.Reader)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
20
pkg/cert/time_source.go
Normal file
20
pkg/cert/time_source.go
Normal file
|
@ -0,0 +1,20 @@
|
|||
//go:generate mockgen -source=time_source.go -destination=./../../internal/mock/cert/time_source_mock.go -package=cert_mock
|
||||
|
||||
package cert
|
||||
|
||||
import "time"
|
||||
|
||||
type TimeSource interface {
|
||||
UTCNow() time.Time
|
||||
}
|
||||
|
||||
func NewTimeSource() TimeSource {
|
||||
return &defaultTimeSource{}
|
||||
}
|
||||
|
||||
type defaultTimeSource struct {
|
||||
}
|
||||
|
||||
func (d defaultTimeSource) UTCNow() time.Time {
|
||||
return time.Now().UTC()
|
||||
}
|
43
pkg/defaulting/defaulter.go
Normal file
43
pkg/defaulting/defaulter.go
Normal file
|
@ -0,0 +1,43 @@
|
|||
package defaulting
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
)
|
||||
|
||||
type Defaulter func(instance interface{})
|
||||
|
||||
type Registry interface {
|
||||
Register(t reflect.Type, defaulter ...Defaulter)
|
||||
Apply(instance interface{})
|
||||
}
|
||||
|
||||
func New() Registry {
|
||||
return ®istry{
|
||||
defaulters: make(map[reflect.Type][]Defaulter),
|
||||
}
|
||||
}
|
||||
|
||||
type registry struct {
|
||||
defaulters map[reflect.Type][]Defaulter
|
||||
}
|
||||
|
||||
func (r *registry) Register(t reflect.Type, defaulter ...Defaulter) {
|
||||
var given []Defaulter
|
||||
if r, ok := r.defaulters[t]; ok {
|
||||
given = r
|
||||
}
|
||||
|
||||
for _, d := range defaulter {
|
||||
given = append(given, d)
|
||||
}
|
||||
|
||||
r.defaulters[t] = given
|
||||
}
|
||||
|
||||
func (r *registry) Apply(instance interface{}) {
|
||||
if defs, ok := r.defaulters[reflect.TypeOf(instance)]; ok {
|
||||
for _, def := range defs {
|
||||
def(instance)
|
||||
}
|
||||
}
|
||||
}
|
111
pkg/defaulting/defaulter_test.go
Normal file
111
pkg/defaulting/defaulter_test.go
Normal file
|
@ -0,0 +1,111 @@
|
|||
package defaulting
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func Test_registry_Apply(t *testing.T) {
|
||||
type sample struct {
|
||||
i int
|
||||
}
|
||||
type fields struct {
|
||||
defaulters map[reflect.Type][]Defaulter
|
||||
}
|
||||
type args struct {
|
||||
instance interface{}
|
||||
}
|
||||
type expect struct {
|
||||
result interface{}
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
args args
|
||||
expect expect
|
||||
}{
|
||||
{
|
||||
name: "Expect setting a sample value",
|
||||
fields: fields{
|
||||
defaulters: map[reflect.Type][]Defaulter{
|
||||
reflect.TypeOf(&sample{}): {func(instance interface{}) {
|
||||
switch i := instance.(type) {
|
||||
case *sample:
|
||||
i.i = 42
|
||||
}
|
||||
}},
|
||||
},
|
||||
},
|
||||
args: args{
|
||||
instance: &sample{},
|
||||
},
|
||||
expect: expect{
|
||||
result: &sample{
|
||||
i: 42,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
r := ®istry{
|
||||
defaulters: tt.fields.defaulters,
|
||||
}
|
||||
|
||||
r.Apply(tt.args.instance)
|
||||
|
||||
if !reflect.DeepEqual(tt.expect.result, tt.args.instance) {
|
||||
t.Errorf("Apply() expected = %v got %v", tt.args.instance, tt.expect.result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_registry_Register(t *testing.T) {
|
||||
type sample struct {
|
||||
}
|
||||
type fields struct {
|
||||
defaulters map[reflect.Type][]Defaulter
|
||||
}
|
||||
type args struct {
|
||||
t reflect.Type
|
||||
defaulter []Defaulter
|
||||
}
|
||||
type expect struct {
|
||||
length int
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
args args
|
||||
expect expect
|
||||
}{
|
||||
{
|
||||
name: "",
|
||||
fields: fields{
|
||||
defaulters: make(map[reflect.Type][]Defaulter),
|
||||
},
|
||||
args: args{
|
||||
t: reflect.TypeOf(sample{}),
|
||||
defaulter: []Defaulter{func(instance interface{}) {
|
||||
|
||||
}},
|
||||
},
|
||||
expect: expect{
|
||||
length: 1,
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
r := ®istry{
|
||||
defaulters: tt.fields.defaulters,
|
||||
}
|
||||
r.Register(tt.args.t, tt.args.defaulter...)
|
||||
|
||||
if length := len(r.defaulters); length != tt.expect.length {
|
||||
t.Errorf("len(r.defaulters) expect %d got %d", tt.expect.length, length)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
|
@ -1,4 +1,4 @@
|
|||
//go:generate mockgen -source=logger.go -destination=./../../internal/mock/logger_mock.go -package=mock
|
||||
//go:generate mockgen -source=logger.go -destination=./../../internal/mock/logging/logger_mock.go -package=logging_mock
|
||||
package logging
|
||||
|
||||
import "go.uber.org/zap"
|
||||
|
|
|
@ -6,7 +6,7 @@ require (
|
|||
github.com/baez90/inetmock v0.0.1
|
||||
github.com/miekg/dns v1.1.29
|
||||
github.com/spf13/viper v1.6.3
|
||||
go.uber.org/zap v1.14.1
|
||||
go.uber.org/zap v1.15.0
|
||||
golang.org/x/crypto v0.0.0-20200406173513-056763e48d71 // indirect
|
||||
golang.org/x/net v0.0.0-20200324143707-d3edc9973b7e // indirect
|
||||
)
|
||||
|
|
|
@ -149,6 +149,8 @@ go.uber.org/tools v0.0.0-20190618225709-2cfd321de3ee/go.mod h1:vJERXedbb3MVM5f9E
|
|||
go.uber.org/zap v1.10.0/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q=
|
||||
go.uber.org/zap v1.14.1 h1:nYDKopTbvAPq/NrUVZwT15y2lpROBiLLyoRTbXOYWOo=
|
||||
go.uber.org/zap v1.14.1/go.mod h1:Mb2vm2krFEG5DV0W9qcHBYFtp/Wku1cvYaqPsS/WYfc=
|
||||
go.uber.org/zap v1.15.0 h1:ZZCA22JRF2gQE5FoNmhmrf7jeJJ2uhqDUNRYKm8dvmM=
|
||||
go.uber.org/zap v1.15.0/go.mod h1:Mb2vm2krFEG5DV0W9qcHBYFtp/Wku1cvYaqPsS/WYfc=
|
||||
golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
|
||||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||
golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
||||
|
@ -188,6 +190,8 @@ golang.org/x/sys v0.0.0-20191005200804-aed5e4c7ecf9/go.mod h1:h1NjWce9XRLGQEsW7w
|
|||
golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20200409092240-59c9f1ba88fa h1:mQTN3ECqfsViCNBgq+A40vdwhkGykrrQlYe3mPj6BoU=
|
||||
golang.org/x/sys v0.0.0-20200409092240-59c9f1ba88fa/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20200420163511-1957bb5e6d1f h1:gWF768j/LaZugp8dyS4UwsslYCYz9XgFxvlgsn0n9H8=
|
||||
golang.org/x/sys v0.0.0-20200420163511-1957bb5e6d1f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/text v0.3.2 h1:tW2bmiBqwgJj/UpqtC8EpXEZVYOwU0yG4iWbprSVAcs=
|
||||
|
|
|
@ -2,7 +2,7 @@ package main
|
|||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/baez90/inetmock/internal/config"
|
||||
"github.com/baez90/inetmock/pkg/api"
|
||||
"github.com/baez90/inetmock/pkg/logging"
|
||||
"github.com/miekg/dns"
|
||||
"go.uber.org/zap"
|
||||
|
@ -13,7 +13,7 @@ type dnsHandler struct {
|
|||
dnsServer []*dns.Server
|
||||
}
|
||||
|
||||
func (d *dnsHandler) Start(config config.HandlerConfig) (err error) {
|
||||
func (d *dnsHandler) Start(config api.HandlerConfig) (err error) {
|
||||
options := loadFromConfig(config.Options())
|
||||
addr := fmt.Sprintf("%s:%d", config.ListenAddress(), config.Port())
|
||||
|
||||
|
|
|
@ -5,7 +5,7 @@ go 1.14
|
|||
require (
|
||||
github.com/baez90/inetmock v0.0.1
|
||||
github.com/spf13/viper v1.6.3
|
||||
go.uber.org/zap v1.14.1
|
||||
go.uber.org/zap v1.15.0
|
||||
)
|
||||
|
||||
replace github.com/baez90/inetmock v0.0.1 => ../../
|
||||
|
|
|
@ -147,6 +147,8 @@ go.uber.org/tools v0.0.0-20190618225709-2cfd321de3ee/go.mod h1:vJERXedbb3MVM5f9E
|
|||
go.uber.org/zap v1.10.0/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q=
|
||||
go.uber.org/zap v1.14.1 h1:nYDKopTbvAPq/NrUVZwT15y2lpROBiLLyoRTbXOYWOo=
|
||||
go.uber.org/zap v1.14.1/go.mod h1:Mb2vm2krFEG5DV0W9qcHBYFtp/Wku1cvYaqPsS/WYfc=
|
||||
go.uber.org/zap v1.15.0 h1:ZZCA22JRF2gQE5FoNmhmrf7jeJJ2uhqDUNRYKm8dvmM=
|
||||
go.uber.org/zap v1.15.0/go.mod h1:Mb2vm2krFEG5DV0W9qcHBYFtp/Wku1cvYaqPsS/WYfc=
|
||||
golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
|
||||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||
golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
||||
|
@ -178,6 +180,8 @@ golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7w
|
|||
golang.org/x/sys v0.0.0-20191005200804-aed5e4c7ecf9/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20200409092240-59c9f1ba88fa h1:mQTN3ECqfsViCNBgq+A40vdwhkGykrrQlYe3mPj6BoU=
|
||||
golang.org/x/sys v0.0.0-20200409092240-59c9f1ba88fa/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20200420163511-1957bb5e6d1f h1:gWF768j/LaZugp8dyS4UwsslYCYz9XgFxvlgsn0n9H8=
|
||||
golang.org/x/sys v0.0.0-20200420163511-1957bb5e6d1f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/text v0.3.2 h1:tW2bmiBqwgJj/UpqtC8EpXEZVYOwU0yG4iWbprSVAcs=
|
||||
|
|
|
@ -2,7 +2,7 @@ package main
|
|||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/baez90/inetmock/internal/config"
|
||||
"github.com/baez90/inetmock/pkg/api"
|
||||
"github.com/baez90/inetmock/pkg/logging"
|
||||
"go.uber.org/zap"
|
||||
"net/http"
|
||||
|
@ -18,7 +18,7 @@ type httpHandler struct {
|
|||
server *http.Server
|
||||
}
|
||||
|
||||
func (p *httpHandler) Start(config config.HandlerConfig) (err error) {
|
||||
func (p *httpHandler) Start(config api.HandlerConfig) (err error) {
|
||||
options := loadFromConfig(config.Options())
|
||||
addr := fmt.Sprintf("%s:%d", config.ListenAddress(), config.Port())
|
||||
p.server = &http.Server{Addr: addr, Handler: p.router}
|
||||
|
|
|
@ -1,53 +0,0 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"gopkg.in/elazarl/goproxy.v1"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
const (
|
||||
passthroughStrategyName = "passthrough"
|
||||
notFoundStrategyName = "notfound"
|
||||
)
|
||||
|
||||
var (
|
||||
fallbackStrategies map[string]ProxyFallbackStrategy
|
||||
)
|
||||
|
||||
func init() {
|
||||
fallbackStrategies = map[string]ProxyFallbackStrategy{
|
||||
passthroughStrategyName: &passThroughFallbackStrategy{},
|
||||
notFoundStrategyName: ¬FoundFallbackStrategy{},
|
||||
}
|
||||
}
|
||||
|
||||
func StrategyForName(name string) ProxyFallbackStrategy {
|
||||
if strategy, ok := fallbackStrategies[name]; ok {
|
||||
return strategy
|
||||
}
|
||||
return fallbackStrategies[notFoundStrategyName]
|
||||
}
|
||||
|
||||
type ProxyFallbackStrategy interface {
|
||||
Apply(request *http.Request) (*http.Response, error)
|
||||
}
|
||||
|
||||
type passThroughFallbackStrategy struct {
|
||||
}
|
||||
|
||||
func (p passThroughFallbackStrategy) Apply(request *http.Request) (*http.Response, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
type notFoundFallbackStrategy struct {
|
||||
}
|
||||
|
||||
func (n notFoundFallbackStrategy) Apply(request *http.Request) (response *http.Response, err error) {
|
||||
response = goproxy.NewResponse(
|
||||
request,
|
||||
goproxy.ContentTypeText,
|
||||
http.StatusNotFound,
|
||||
"The requested resource was not found",
|
||||
)
|
||||
return
|
||||
}
|
|
@ -1,46 +0,0 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestStrategyForName(t *testing.T) {
|
||||
type args struct {
|
||||
name string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want reflect.Type
|
||||
}{
|
||||
{
|
||||
name: "Test get notfound strategy",
|
||||
want: reflect.TypeOf(¬FoundFallbackStrategy{}),
|
||||
args: args{
|
||||
name: "notfound",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Test get pass through strategy",
|
||||
want: reflect.TypeOf(&passThroughFallbackStrategy{}),
|
||||
args: args{
|
||||
name: "passthrough",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Test get fallback strategy notfound because key is not known",
|
||||
want: reflect.TypeOf(¬FoundFallbackStrategy{}),
|
||||
args: args{
|
||||
name: "asdf12234",
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := StrategyForName(tt.args.name); reflect.TypeOf(got) != tt.want {
|
||||
t.Errorf("StrategyForName() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
|
@ -4,9 +4,8 @@ go 1.14
|
|||
|
||||
require (
|
||||
github.com/baez90/inetmock v0.0.1
|
||||
github.com/elazarl/goproxy v0.0.0-20200315184450-1f3cb6622dad
|
||||
github.com/spf13/viper v1.6.3
|
||||
go.uber.org/zap v1.14.1
|
||||
go.uber.org/zap v1.15.0
|
||||
gopkg.in/elazarl/goproxy.v1 v1.0.0-20180725130230-947c36da3153
|
||||
)
|
||||
|
||||
|
|
|
@ -152,6 +152,8 @@ go.uber.org/tools v0.0.0-20190618225709-2cfd321de3ee/go.mod h1:vJERXedbb3MVM5f9E
|
|||
go.uber.org/zap v1.10.0/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q=
|
||||
go.uber.org/zap v1.14.1 h1:nYDKopTbvAPq/NrUVZwT15y2lpROBiLLyoRTbXOYWOo=
|
||||
go.uber.org/zap v1.14.1/go.mod h1:Mb2vm2krFEG5DV0W9qcHBYFtp/Wku1cvYaqPsS/WYfc=
|
||||
go.uber.org/zap v1.15.0 h1:ZZCA22JRF2gQE5FoNmhmrf7jeJJ2uhqDUNRYKm8dvmM=
|
||||
go.uber.org/zap v1.15.0/go.mod h1:Mb2vm2krFEG5DV0W9qcHBYFtp/Wku1cvYaqPsS/WYfc=
|
||||
golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
|
||||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||
golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
||||
|
@ -183,6 +185,8 @@ golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7w
|
|||
golang.org/x/sys v0.0.0-20191005200804-aed5e4c7ecf9/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20200409092240-59c9f1ba88fa h1:mQTN3ECqfsViCNBgq+A40vdwhkGykrrQlYe3mPj6BoU=
|
||||
golang.org/x/sys v0.0.0-20200409092240-59c9f1ba88fa/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20200420163511-1957bb5e6d1f h1:gWF768j/LaZugp8dyS4UwsslYCYz9XgFxvlgsn0n9H8=
|
||||
golang.org/x/sys v0.0.0-20200420163511-1957bb5e6d1f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/text v0.3.2 h1:tW2bmiBqwgJj/UpqtC8EpXEZVYOwU0yG4iWbprSVAcs=
|
||||
|
|
|
@ -2,7 +2,7 @@ package main
|
|||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/baez90/inetmock/internal/config"
|
||||
"github.com/baez90/inetmock/pkg/api"
|
||||
"github.com/baez90/inetmock/pkg/logging"
|
||||
"go.uber.org/zap"
|
||||
"gopkg.in/elazarl/goproxy.v1"
|
||||
|
@ -19,7 +19,7 @@ type httpProxy struct {
|
|||
server *http.Server
|
||||
}
|
||||
|
||||
func (h *httpProxy) Start(config config.HandlerConfig) (err error) {
|
||||
func (h *httpProxy) Start(config api.HandlerConfig) (err error) {
|
||||
options := loadFromConfig(config.Options())
|
||||
addr := fmt.Sprintf("%s:%d", config.ListenAddress(), config.Port())
|
||||
h.server = &http.Server{Addr: addr, Handler: h.proxy}
|
||||
|
@ -27,11 +27,20 @@ func (h *httpProxy) Start(config config.HandlerConfig) (err error) {
|
|||
zap.String("address", addr),
|
||||
)
|
||||
|
||||
tlsConfig := api.ServicesInstance().CertStore().TLSConfig()
|
||||
|
||||
proxyHandler := &proxyHttpHandler{
|
||||
options: options,
|
||||
logger: h.logger,
|
||||
}
|
||||
|
||||
proxyHttpsHandler := &proxyHttpsHandler{
|
||||
tlsConfig: tlsConfig,
|
||||
logger: h.logger,
|
||||
}
|
||||
|
||||
h.proxy.OnRequest().Do(proxyHandler)
|
||||
h.proxy.OnRequest().HandleConnect(proxyHttpsHandler)
|
||||
go h.startProxy()
|
||||
return
|
||||
}
|
||||
|
|
|
@ -1,52 +1,42 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/spf13/viper"
|
||||
"regexp"
|
||||
)
|
||||
|
||||
const (
|
||||
rulesConfigKey = "rules"
|
||||
patternConfigKey = "pattern"
|
||||
responseConfigKey = "response"
|
||||
fallbackStrategyConfigKey = "fallback"
|
||||
targetSchemeConfigKey = "target.scheme"
|
||||
targetIpAddressConfigKey = "target.ipAddress"
|
||||
targetPortConfigKey = "target.port"
|
||||
)
|
||||
|
||||
type targetRule struct {
|
||||
pattern *regexp.Regexp
|
||||
response string
|
||||
type redirectionTarget struct {
|
||||
scheme string
|
||||
ipAddress string
|
||||
port uint16
|
||||
}
|
||||
|
||||
func (tr targetRule) Pattern() *regexp.Regexp {
|
||||
return tr.pattern
|
||||
}
|
||||
|
||||
func (tr targetRule) Response() string {
|
||||
return tr.response
|
||||
func (rt redirectionTarget) host() string {
|
||||
return fmt.Sprintf("%s:%d", rt.ipAddress, rt.port)
|
||||
}
|
||||
|
||||
type httpProxyOptions struct {
|
||||
Rules []targetRule
|
||||
FallbackStrategy ProxyFallbackStrategy
|
||||
redirectionTarget redirectionTarget
|
||||
}
|
||||
|
||||
func loadFromConfig(config *viper.Viper) (options httpProxyOptions) {
|
||||
options.FallbackStrategy = StrategyForName(config.GetString(fallbackStrategyConfigKey))
|
||||
|
||||
anonRules := config.Get(rulesConfigKey).([]interface{})
|
||||
config.SetDefault(targetSchemeConfigKey, "http")
|
||||
config.SetDefault(targetIpAddressConfigKey, "127.0.0.1")
|
||||
config.SetDefault(targetPortConfigKey, "80")
|
||||
|
||||
for _, i := range anonRules {
|
||||
innerData := i.(map[interface{}]interface{})
|
||||
|
||||
if rulePattern, err := regexp.Compile(innerData[patternConfigKey].(string)); err == nil {
|
||||
options.Rules = append(options.Rules, targetRule{
|
||||
pattern: rulePattern,
|
||||
response: innerData[responseConfigKey].(string),
|
||||
})
|
||||
} else {
|
||||
panic(err)
|
||||
}
|
||||
options = httpProxyOptions{
|
||||
redirectionTarget{
|
||||
scheme: config.GetString(targetSchemeConfigKey),
|
||||
ipAddress: config.GetString(targetIpAddressConfigKey),
|
||||
port: uint16(config.GetInt(targetPortConfigKey)),
|
||||
},
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
|
|
@ -1,126 +1 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"github.com/spf13/viper"
|
||||
"reflect"
|
||||
"regexp"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func Test_loadFromConfig(t *testing.T) {
|
||||
type args struct {
|
||||
config string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
wantOptions httpProxyOptions
|
||||
}{
|
||||
{
|
||||
name: "Parse proper configuration with notfound strategy",
|
||||
args: args{
|
||||
config: `
|
||||
fallback: notfound,
|
||||
rules:
|
||||
- pattern: ".*"
|
||||
response: ./assets/fakeFiles/default.html
|
||||
`,
|
||||
},
|
||||
wantOptions: httpProxyOptions{
|
||||
FallbackStrategy: StrategyForName(notFoundStrategyName),
|
||||
Rules: []targetRule{
|
||||
{
|
||||
response: "./assets/fakeFiles/default.html",
|
||||
pattern: regexp.MustCompile(".*"),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Parse proper configuration with pass through strategy",
|
||||
args: args{
|
||||
config: `
|
||||
fallback: passthrough
|
||||
rules:
|
||||
- pattern: ".*"
|
||||
response: ./assets/fakeFiles/default.html
|
||||
`,
|
||||
},
|
||||
wantOptions: httpProxyOptions{
|
||||
FallbackStrategy: StrategyForName(passthroughStrategyName),
|
||||
Rules: []targetRule{
|
||||
{
|
||||
response: "./assets/fakeFiles/default.html",
|
||||
pattern: regexp.MustCompile(".*"),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Parse proper configuration and preserve order of rules",
|
||||
args: args{
|
||||
config: `
|
||||
fallback: notfound
|
||||
rules:
|
||||
- pattern: ".*\\.(?i)txt"
|
||||
response: ./assets/fakeFiles/default.txt
|
||||
- pattern: ".*"
|
||||
response: ./assets/fakeFiles/default.html
|
||||
`,
|
||||
},
|
||||
wantOptions: httpProxyOptions{
|
||||
FallbackStrategy: StrategyForName(notFoundStrategyName),
|
||||
Rules: []targetRule{
|
||||
{
|
||||
response: "./assets/fakeFiles/default.txt",
|
||||
pattern: regexp.MustCompile(".*\\.(?i)txt"),
|
||||
},
|
||||
{
|
||||
response: "./assets/fakeFiles/default.html",
|
||||
pattern: regexp.MustCompile(".*"),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Parse configuration with non existing fallback strategy key - falling back to 'notfound'",
|
||||
args: args{
|
||||
config: `
|
||||
fallback: doesNotExist
|
||||
rules: []
|
||||
`,
|
||||
},
|
||||
wantOptions: httpProxyOptions{
|
||||
FallbackStrategy: StrategyForName(notFoundStrategyName),
|
||||
Rules: nil,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Parse configuration without any fallback key",
|
||||
args: args{
|
||||
config: `
|
||||
f4llb4ck: doesNotExist
|
||||
rules: []
|
||||
`,
|
||||
},
|
||||
wantOptions: httpProxyOptions{
|
||||
FallbackStrategy: StrategyForName(notFoundStrategyName),
|
||||
Rules: nil,
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
config := viper.New()
|
||||
config.SetConfigType("yaml")
|
||||
if err := config.ReadConfig(bytes.NewBufferString(tt.args.config)); err != nil {
|
||||
t.Errorf("failed to read config %v", err)
|
||||
return
|
||||
}
|
||||
if gotOptions := loadFromConfig(config); !reflect.DeepEqual(gotOptions, tt.wantOptions) {
|
||||
t.Errorf("loadFromConfig() = %v, want %v", gotOptions, tt.wantOptions)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,14 +1,13 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"github.com/baez90/inetmock/pkg/logging"
|
||||
"go.uber.org/zap"
|
||||
"gopkg.in/elazarl/goproxy.v1"
|
||||
"io"
|
||||
"mime"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"net/url"
|
||||
)
|
||||
|
||||
type proxyHttpHandler struct {
|
||||
|
@ -16,25 +15,27 @@ type proxyHttpHandler struct {
|
|||
logger logging.Logger
|
||||
}
|
||||
|
||||
/*
|
||||
TODO implement HTTPS proxy like in TLS interceptor
|
||||
func (p *proxyHttpHandler) HandleConnect(req string, ctx *goproxy.ProxyCtx) (*goproxy.ConnectAction, string) {
|
||||
type proxyHttpsHandler struct {
|
||||
tlsConfig *tls.Config
|
||||
logger logging.Logger
|
||||
}
|
||||
|
||||
func (p *proxyHttpsHandler) HandleConnect(req string, _ *goproxy.ProxyCtx) (*goproxy.ConnectAction, string) {
|
||||
p.logger.Info(
|
||||
"Intercepting HTTPS proxy request",
|
||||
zap.String("request", req),
|
||||
)
|
||||
|
||||
return &goproxy.ConnectAction{
|
||||
Action: goproxy.OkConnect,
|
||||
|
||||
Action: goproxy.ConnectMitm,
|
||||
TLSConfig: func(host string, ctx *goproxy.ProxyCtx) (*tls.Config, error) {
|
||||
return p.tlsConfig, nil
|
||||
},
|
||||
}, ""
|
||||
}*/
|
||||
|
||||
func (p *proxyHttpHandler) Handle(req *http.Request, _ *goproxy.ProxyCtx) (retReq *http.Request, resp *http.Response) {
|
||||
}
|
||||
|
||||
func (p *proxyHttpHandler) Handle(req *http.Request, ctx *goproxy.ProxyCtx) (retReq *http.Request, resp *http.Response) {
|
||||
retReq = req
|
||||
resp = &http.Response{
|
||||
Request: req,
|
||||
TransferEncoding: req.TransferEncoding,
|
||||
Header: make(http.Header),
|
||||
StatusCode: http.StatusOK,
|
||||
}
|
||||
|
||||
p.logger.Info(
|
||||
"Handling request",
|
||||
zap.String("source", req.RemoteAddr),
|
||||
|
@ -45,56 +46,48 @@ func (p *proxyHttpHandler) Handle(req *http.Request, _ *goproxy.ProxyCtx) (retRe
|
|||
zap.Reflect("headers", req.Header),
|
||||
)
|
||||
|
||||
for _, rule := range p.options.Rules {
|
||||
if rule.pattern.MatchString(req.URL.Path) {
|
||||
if file, err := os.Open(rule.response); err != nil {
|
||||
p.logger.Error(
|
||||
"failed to open response target file",
|
||||
zap.String("resonse", rule.response),
|
||||
zap.Error(err),
|
||||
)
|
||||
continue
|
||||
} else {
|
||||
resp.Body = file
|
||||
|
||||
if stat, err := file.Stat(); err == nil {
|
||||
resp.ContentLength = stat.Size()
|
||||
}
|
||||
|
||||
if contentType, err := GetContentType(rule, file); err == nil {
|
||||
resp.Header["Content-Type"] = []string{contentType}
|
||||
}
|
||||
|
||||
p.logger.Info("returning fake response from rules")
|
||||
return req, resp
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if resp, err := p.options.FallbackStrategy.Apply(req); err != nil {
|
||||
var err error
|
||||
if resp, err = ctx.RoundTrip(p.redirectHTTPRequest(req)); err != nil {
|
||||
p.logger.Error(
|
||||
"failed to apply fallback strategy",
|
||||
"error while doing roundtrip",
|
||||
zap.Error(err),
|
||||
)
|
||||
} else {
|
||||
p.logger.Info("returning fake response from fallback strategy")
|
||||
return req, resp
|
||||
return req, nil
|
||||
}
|
||||
|
||||
p.logger.Info("falling back to proxying request through")
|
||||
return req, nil
|
||||
}
|
||||
|
||||
func GetContentType(rule targetRule, file *os.File) (contentType string, err error) {
|
||||
if contentType = mime.TypeByExtension(filepath.Ext(rule.response)); contentType != "" {
|
||||
return
|
||||
}
|
||||
|
||||
var buf [512]byte
|
||||
|
||||
n, _ := io.ReadFull(file, buf[:])
|
||||
|
||||
contentType = http.DetectContentType(buf[:n])
|
||||
_, err = file.Seek(0, io.SeekStart)
|
||||
return
|
||||
}
|
||||
|
||||
func (p proxyHttpHandler) redirectHTTPRequest(originalRequest *http.Request) (redirectReq *http.Request) {
|
||||
redirectReq = &http.Request{
|
||||
Method: originalRequest.Method,
|
||||
URL: &url.URL{
|
||||
Host: p.options.redirectionTarget.host(),
|
||||
Scheme: p.options.redirectionTarget.scheme,
|
||||
Path: originalRequest.URL.Path,
|
||||
ForceQuery: originalRequest.URL.ForceQuery,
|
||||
Fragment: originalRequest.URL.Fragment,
|
||||
Opaque: originalRequest.URL.Opaque,
|
||||
RawPath: originalRequest.URL.RawPath,
|
||||
RawQuery: originalRequest.URL.RawQuery,
|
||||
User: originalRequest.URL.User,
|
||||
},
|
||||
Proto: originalRequest.Proto,
|
||||
ProtoMajor: originalRequest.ProtoMajor,
|
||||
ProtoMinor: originalRequest.ProtoMinor,
|
||||
Header: originalRequest.Header,
|
||||
Body: originalRequest.Body,
|
||||
GetBody: originalRequest.GetBody,
|
||||
ContentLength: originalRequest.ContentLength,
|
||||
TransferEncoding: originalRequest.TransferEncoding,
|
||||
Close: false,
|
||||
Host: originalRequest.Host,
|
||||
Form: originalRequest.Form,
|
||||
PostForm: originalRequest.PostForm,
|
||||
MultipartForm: originalRequest.MultipartForm,
|
||||
Trailer: originalRequest.Trailer,
|
||||
}
|
||||
redirectReq = redirectReq.WithContext(context.Background())
|
||||
|
||||
return
|
||||
}
|
||||
|
|
|
@ -1,20 +0,0 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var (
|
||||
ipExtractionRegex = regexp.MustCompile(`(.+):\d{1,5}$`)
|
||||
)
|
||||
|
||||
func extractIPFromAddress(addr string) (string, error) {
|
||||
matches := ipExtractionRegex.FindAllStringSubmatch(addr, -1)
|
||||
if len(matches) > 0 && len(matches[0]) >= 1 {
|
||||
return strings.Trim(matches[0][1], "[]"), nil
|
||||
} else {
|
||||
return "", fmt.Errorf("failed to extract IP address from addr %s", addr)
|
||||
}
|
||||
}
|
|
@ -1,204 +0,0 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"fmt"
|
||||
"github.com/baez90/inetmock/pkg/logging"
|
||||
"github.com/baez90/inetmock/pkg/path"
|
||||
"go.uber.org/zap"
|
||||
"math/big"
|
||||
"net"
|
||||
"path/filepath"
|
||||
)
|
||||
|
||||
type keyProvider func() (key interface{}, err error)
|
||||
|
||||
type certStore struct {
|
||||
options *tlsOptions
|
||||
keyProvider keyProvider
|
||||
caCert *x509.Certificate
|
||||
caPrivateKey interface{}
|
||||
certCache map[string]*tls.Certificate
|
||||
timeSourceInstance timeSource
|
||||
logger logging.Logger
|
||||
}
|
||||
|
||||
func (cs *certStore) timeSource() timeSource {
|
||||
if cs.timeSourceInstance == nil {
|
||||
cs.timeSourceInstance = createTimeSource()
|
||||
}
|
||||
return cs.timeSourceInstance
|
||||
}
|
||||
|
||||
func (cs *certStore) initCaCert() (err error) {
|
||||
if path.FileExists(cs.options.rootCaCert.publicKeyPath) && path.FileExists(cs.options.rootCaCert.privateKeyPath) {
|
||||
var tlsCert tls.Certificate
|
||||
if tlsCert, err = tls.LoadX509KeyPair(cs.options.rootCaCert.publicKeyPath, cs.options.rootCaCert.privateKeyPath); err != nil {
|
||||
return
|
||||
} else if len(tlsCert.Certificate) > 0 {
|
||||
cs.caPrivateKey = tlsCert.PrivateKey
|
||||
cs.caCert, err = x509.ParseCertificate(tlsCert.Certificate[0])
|
||||
}
|
||||
} else {
|
||||
cs.caCert, cs.caPrivateKey, err = cs.generateCaCert()
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (cs *certStore) initProvider() {
|
||||
if cs.keyProvider == nil {
|
||||
cs.keyProvider = func() (key interface{}, err error) {
|
||||
return privateKeyForCurve(cs.options)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (cs *certStore) generateCaCert() (pubKey *x509.Certificate, privateKey interface{}, err error) {
|
||||
cs.initProvider()
|
||||
timeSource := cs.timeSource()
|
||||
var serialNumber *big.Int
|
||||
if serialNumber, err = generateSerialNumber(); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if privateKey, err = cs.keyProvider(); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
template := x509.Certificate{
|
||||
SerialNumber: serialNumber,
|
||||
Subject: pkix.Name{
|
||||
Organization: []string{"INetMock"},
|
||||
Country: []string{"US"},
|
||||
Province: []string{""},
|
||||
Locality: []string{"San Francisco"},
|
||||
StreetAddress: []string{"Golden Gate Bridge"},
|
||||
PostalCode: []string{"94016"},
|
||||
},
|
||||
IsCA: true,
|
||||
NotBefore: timeSource.UTCNow().Add(-cs.options.validity.ca.notBeforeRelative),
|
||||
NotAfter: timeSource.UTCNow().Add(cs.options.validity.ca.notAfterRelative),
|
||||
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
|
||||
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth},
|
||||
BasicConstraintsValid: true,
|
||||
}
|
||||
|
||||
var derBytes []byte
|
||||
if derBytes, err = x509.CreateCertificate(rand.Reader, &template, &template, publicKey(privateKey), privateKey); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
pubKey, err = x509.ParseCertificate(derBytes)
|
||||
|
||||
if err = writePublicKey(cs.options.rootCaCert.publicKeyPath, derBytes); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
var privateKeyBytes []byte
|
||||
if privateKeyBytes, err = x509.MarshalPKCS8PrivateKey(privateKey); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if err = writePrivateKey(cs.options.rootCaCert.privateKeyPath, privateKeyBytes); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (cs *certStore) getCertificate(serverName string, ip string) (cert *tls.Certificate, err error) {
|
||||
if cert, ok := cs.certCache[serverName]; ok {
|
||||
return cert, nil
|
||||
}
|
||||
|
||||
certPath := filepath.Join(cs.options.certCachePath, fmt.Sprintf("%s.pem", serverName))
|
||||
keyPath := filepath.Join(cs.options.certCachePath, fmt.Sprintf("%s.key", serverName))
|
||||
|
||||
if path.FileExists(certPath) && path.FileExists(keyPath) {
|
||||
if tlsCert, loadErr := tls.LoadX509KeyPair(certPath, keyPath); loadErr == nil {
|
||||
cs.certCache[serverName] = &tlsCert
|
||||
x509Cert, err := x509.ParseCertificate(tlsCert.Certificate[0])
|
||||
if err == nil && !certShouldBeRenewed(cs.timeSource(), x509Cert) {
|
||||
return &tlsCert, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if cert, err = cs.generateDomainCert(serverName, ip); err == nil {
|
||||
cs.certCache[serverName] = cert
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (cs *certStore) generateDomainCert(
|
||||
serverName string,
|
||||
localIp string,
|
||||
) (cert *tls.Certificate, err error) {
|
||||
cs.initProvider()
|
||||
|
||||
if cs.caCert == nil {
|
||||
if err = cs.initCaCert(); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
var serialNumber *big.Int
|
||||
if serialNumber, err = generateSerialNumber(); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
var privateKey interface{}
|
||||
if privateKey, err = cs.keyProvider(); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
notBefore := cs.timeSource().UTCNow().Add(-cs.options.validity.domain.notBeforeRelative)
|
||||
notAfter := cs.timeSource().UTCNow().Add(cs.options.validity.domain.notAfterRelative)
|
||||
|
||||
cs.logger.Info(
|
||||
"generate domain certificate",
|
||||
zap.Time("notBefore", notBefore),
|
||||
zap.Time("notAfter", notAfter),
|
||||
)
|
||||
|
||||
template := x509.Certificate{
|
||||
SerialNumber: serialNumber,
|
||||
Subject: pkix.Name{
|
||||
Organization: []string{"INetMock"},
|
||||
},
|
||||
IPAddresses: []net.IP{net.ParseIP(localIp)},
|
||||
DNSNames: []string{serverName},
|
||||
NotBefore: notBefore,
|
||||
NotAfter: notAfter,
|
||||
KeyUsage: x509.KeyUsageDigitalSignature,
|
||||
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth},
|
||||
}
|
||||
|
||||
var derBytes []byte
|
||||
if derBytes, err = x509.CreateCertificate(rand.Reader, &template, cs.caCert, publicKey(privateKey), cs.caPrivateKey); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if err = writePublicKey(filepath.Join(cs.options.certCachePath, fmt.Sprintf("%s.pem", serverName)), derBytes); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
var privateKeyBytes []byte
|
||||
if privateKeyBytes, err = x509.MarshalPKCS8PrivateKey(privateKey); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if cert, err = parseCert(derBytes, privateKeyBytes); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if err = writePrivateKey(filepath.Join(cs.options.certCachePath, fmt.Sprintf("%s.key", serverName)), privateKeyBytes); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
return
|
||||
}
|
|
@ -1,213 +0,0 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"github.com/baez90/inetmock/pkg/logging"
|
||||
"go.uber.org/zap"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func Test_generateCaCert(t *testing.T) {
|
||||
tmpDir, err := ioutil.TempDir(os.TempDir(), "*-inetmock")
|
||||
if err != nil {
|
||||
t.Errorf("failed to create temp dir %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
options := &tlsOptions{
|
||||
ecdsaCurve: "P256",
|
||||
rootCaCert: cert{
|
||||
publicKeyPath: filepath.Join(tmpDir, "localhost.pem"),
|
||||
privateKeyPath: filepath.Join(tmpDir, "localhost.key"),
|
||||
},
|
||||
validity: validity{
|
||||
ca: certValidity{
|
||||
notBeforeRelative: time.Hour * 24 * 30,
|
||||
notAfterRelative: time.Hour * 24 * 30,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
certStore := certStore{
|
||||
options: options,
|
||||
}
|
||||
|
||||
defer func() {
|
||||
_ = os.Remove(tmpDir)
|
||||
}()
|
||||
|
||||
_, _, err = certStore.generateCaCert()
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("failed to generate CA cert %v", err)
|
||||
}
|
||||
|
||||
if _, err = os.Stat(options.rootCaCert.publicKeyPath); err != nil {
|
||||
t.Errorf("cert file was not created")
|
||||
}
|
||||
|
||||
if _, err = os.Stat(options.rootCaCert.privateKeyPath); err != nil {
|
||||
t.Errorf("cert file was not created")
|
||||
}
|
||||
}
|
||||
|
||||
func Test_generateDomainCert(t *testing.T) {
|
||||
|
||||
tmpDir, err := ioutil.TempDir(os.TempDir(), "*-inetmock")
|
||||
if err != nil {
|
||||
t.Errorf("failed to create temp dir %v", err)
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
_ = os.Remove(tmpDir)
|
||||
}()
|
||||
|
||||
caTlsCert, _ := loadPEMCert(testCaCrt, testCaKey)
|
||||
caCert, _ := x509.ParseCertificate(caTlsCert.Certificate[0])
|
||||
|
||||
options := &tlsOptions{
|
||||
ecdsaCurve: "P256",
|
||||
certCachePath: tmpDir,
|
||||
validity: validity{
|
||||
domain: certValidity{
|
||||
notAfterRelative: time.Hour * 24 * 30,
|
||||
notBeforeRelative: time.Hour * 24 * 30,
|
||||
},
|
||||
ca: certValidity{
|
||||
notAfterRelative: time.Hour * 24 * 30,
|
||||
notBeforeRelative: time.Hour * 24 * 30,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
zapLogger, _ := zap.NewDevelopment()
|
||||
logger := logging.NewLogger(zapLogger)
|
||||
|
||||
certStore := certStore{
|
||||
options: options,
|
||||
caCert: caCert,
|
||||
logger: logger,
|
||||
caPrivateKey: caTlsCert.PrivateKey,
|
||||
}
|
||||
|
||||
type args struct {
|
||||
domain string
|
||||
ip string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "Test create google.com cert",
|
||||
args: args{
|
||||
domain: "google.com",
|
||||
ip: "127.0.0.1",
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Test create golem.de cert",
|
||||
args: args{
|
||||
domain: "golem.de",
|
||||
ip: "127.0.0.1",
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Test create golem.de cert with any IP address",
|
||||
args: args{
|
||||
domain: "golem.de",
|
||||
ip: "10.10.0.10",
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if domainCert, err := certStore.generateDomainCert(
|
||||
tt.args.domain,
|
||||
tt.args.ip,
|
||||
); (err != nil) != tt.wantErr || reflect.DeepEqual(domainCert, tls.Certificate{}) {
|
||||
t.Errorf("generateDomainCert() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_certStore_initCaCert(t *testing.T) {
|
||||
|
||||
tmpDir, err := ioutil.TempDir(os.TempDir(), "*-inetmock")
|
||||
if err != nil {
|
||||
t.Errorf("failed to create temp dir %v", err)
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
_ = os.Remove(tmpDir)
|
||||
}()
|
||||
|
||||
caCertPath := filepath.Join(tmpDir, "cacert.pem")
|
||||
caKeyPath := filepath.Join(tmpDir, "cacert.key")
|
||||
|
||||
if err := ioutil.WriteFile(caCertPath, testCaCrt, 0600); err != nil {
|
||||
t.Errorf("failed to write cacert.pem %v", err)
|
||||
return
|
||||
}
|
||||
if err := ioutil.WriteFile(caKeyPath, testCaKey, 0600); err != nil {
|
||||
t.Errorf("failed to write cacert.key %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
type fields struct {
|
||||
options *tlsOptions
|
||||
caCert *x509.Certificate
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "Init CA cert from file",
|
||||
wantErr: false,
|
||||
fields: fields{
|
||||
options: &tlsOptions{
|
||||
rootCaCert: cert{
|
||||
publicKeyPath: caCertPath,
|
||||
privateKeyPath: caKeyPath,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Init CA with new cert",
|
||||
wantErr: false,
|
||||
fields: fields{
|
||||
options: &tlsOptions{
|
||||
rootCaCert: cert{
|
||||
publicKeyPath: filepath.Join(tmpDir, "nonexistent.pem"),
|
||||
privateKeyPath: filepath.Join(tmpDir, "nonexistent.key"),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
cs := &certStore{
|
||||
options: tt.fields.options,
|
||||
caCert: tt.fields.caCert,
|
||||
}
|
||||
if err := cs.initCaCert(); (err != nil) != tt.wantErr || cs.caCert == nil {
|
||||
t.Errorf("initCaCert() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
|
@ -1,106 +0,0 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"crypto/ecdsa"
|
||||
"crypto/ed25519"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"encoding/pem"
|
||||
"math/big"
|
||||
"os"
|
||||
)
|
||||
|
||||
type curveType string
|
||||
|
||||
const (
|
||||
certificateBlockType = "CERTIFICATE"
|
||||
privateKeyBlockType = "PRIVATE KEY"
|
||||
|
||||
curveTypeP224 curveType = "P224"
|
||||
curveTypeP256 curveType = "P256"
|
||||
curveTypeP384 curveType = "P384"
|
||||
curveTypeP521 curveType = "P521"
|
||||
curveTypeED25519 curveType = "ED25519"
|
||||
)
|
||||
|
||||
func loadPEMCert(certPEMBytes []byte, keyPEMBytes []byte) (*tls.Certificate, error) {
|
||||
cert, err := tls.X509KeyPair(certPEMBytes, keyPEMBytes)
|
||||
return &cert, err
|
||||
}
|
||||
|
||||
func parseCert(derBytes []byte, privateKeyBytes []byte) (*tls.Certificate, error) {
|
||||
pemEncodedPublicKey := pem.EncodeToMemory(&pem.Block{Type: certificateBlockType, Bytes: derBytes})
|
||||
pemEncodedPrivateKey := pem.EncodeToMemory(&pem.Block{Type: privateKeyBlockType, Bytes: privateKeyBytes})
|
||||
cert, err := tls.X509KeyPair(pemEncodedPublicKey, pemEncodedPrivateKey)
|
||||
return &cert, err
|
||||
}
|
||||
|
||||
func writePublicKey(crtOutPath string, derBytes []byte) (err error) {
|
||||
var certOut *os.File
|
||||
if certOut, err = os.Create(crtOutPath); err != nil {
|
||||
return
|
||||
}
|
||||
if err = pem.Encode(certOut, &pem.Block{Type: certificateBlockType, Bytes: derBytes}); err != nil {
|
||||
return
|
||||
}
|
||||
err = certOut.Close()
|
||||
return
|
||||
}
|
||||
|
||||
func writePrivateKey(keyOutPath string, privateKeyBytes []byte) (err error) {
|
||||
var keyOut *os.File
|
||||
if keyOut, err = os.OpenFile(keyOutPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if err = pem.Encode(keyOut, &pem.Block{Type: privateKeyBlockType, Bytes: privateKeyBytes}); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
err = keyOut.Close()
|
||||
return
|
||||
}
|
||||
|
||||
func certShouldBeRenewed(timeSource timeSource, cert *x509.Certificate) bool {
|
||||
lifetime := cert.NotAfter.Sub(cert.NotBefore)
|
||||
// if the cert is closer to the end of the lifetime than lifetime/2 it should be renewed
|
||||
if cert.NotAfter.Sub(timeSource.UTCNow()) < lifetime/4 {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func generateSerialNumber() (*big.Int, error) {
|
||||
serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
|
||||
return rand.Int(rand.Reader, serialNumberLimit)
|
||||
}
|
||||
|
||||
func privateKeyForCurve(options *tlsOptions) (privateKey interface{}, err error) {
|
||||
switch options.ecdsaCurve {
|
||||
case "P224":
|
||||
privateKey, err = ecdsa.GenerateKey(elliptic.P224(), rand.Reader)
|
||||
case "P256":
|
||||
privateKey, err = ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
case "P384":
|
||||
privateKey, err = ecdsa.GenerateKey(elliptic.P384(), rand.Reader)
|
||||
case "P521":
|
||||
privateKey, err = ecdsa.GenerateKey(elliptic.P521(), rand.Reader)
|
||||
default:
|
||||
_, privateKey, err = ed25519.GenerateKey(rand.Reader)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func publicKey(privateKey interface{}) interface{} {
|
||||
switch k := privateKey.(type) {
|
||||
case *ecdsa.PrivateKey:
|
||||
return &k.PublicKey
|
||||
case ed25519.PrivateKey:
|
||||
return k.Public().(ed25519.PublicKey)
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
|
@ -1,74 +0,0 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"crypto/x509"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
type testTimeSource struct {
|
||||
nowValue time.Time
|
||||
}
|
||||
|
||||
func (t testTimeSource) UTCNow() time.Time {
|
||||
return t.nowValue
|
||||
}
|
||||
|
||||
func Test_certShouldBeRenewed(t *testing.T) {
|
||||
type args struct {
|
||||
timeSource timeSource
|
||||
cert *x509.Certificate
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "Detect cert is expired",
|
||||
want: true,
|
||||
args: args{
|
||||
cert: &x509.Certificate{
|
||||
NotAfter: time.Now().UTC().Add(1 * time.Hour),
|
||||
NotBefore: time.Now().UTC().Add(-1 * time.Hour),
|
||||
},
|
||||
timeSource: testTimeSource{
|
||||
nowValue: time.Now().UTC().Add(2 * time.Hour),
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Detect cert should be renewed",
|
||||
want: true,
|
||||
args: args{
|
||||
cert: &x509.Certificate{
|
||||
NotAfter: time.Now().UTC().Add(1 * time.Hour),
|
||||
NotBefore: time.Now().UTC().Add(-1 * time.Hour),
|
||||
},
|
||||
timeSource: testTimeSource{
|
||||
nowValue: time.Now().UTC().Add(45 * time.Minute),
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Detect cert shouldn't be renewed",
|
||||
want: false,
|
||||
args: args{
|
||||
cert: &x509.Certificate{
|
||||
NotAfter: time.Now().UTC().Add(1 * time.Hour),
|
||||
NotBefore: time.Now().UTC().Add(-1 * time.Hour),
|
||||
},
|
||||
timeSource: testTimeSource{
|
||||
nowValue: time.Now().UTC().Add(25 * time.Minute),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := certShouldBeRenewed(tt.args.timeSource, tt.args.cert); got != tt.want {
|
||||
t.Errorf("certShouldBeRenewed() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
|
@ -1,113 +0,0 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"github.com/baez90/inetmock/pkg/logging"
|
||||
"github.com/spf13/cobra"
|
||||
"go.uber.org/zap"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
generateCACertOutPath = "cert-out"
|
||||
generateCAKeyOutPath = "key-out"
|
||||
generateCACurveName = "curve"
|
||||
generateCANotBeforeRelative = "not-before"
|
||||
generateCANotAfterRelative = "not-after"
|
||||
)
|
||||
|
||||
func generateCACmd(logger logging.Logger) *cobra.Command {
|
||||
cmd := &cobra.Command{
|
||||
Use: "generate-ca",
|
||||
Short: "Generate a new CA certificate and corresponding key",
|
||||
Long: ``,
|
||||
Run: runGenerateCA(logger),
|
||||
}
|
||||
|
||||
cmd.Flags().String(generateCACertOutPath, "", "Path where CA cert file should be stored")
|
||||
cmd.Flags().String(generateCAKeyOutPath, "", "Path where CA key file should be stored")
|
||||
cmd.Flags().String(generateCACurveName, "", "Name of the curve to use, if empty ED25519 is used, other valid values are [P224, P256,P384,P521]")
|
||||
cmd.Flags().Duration(generateCANotBeforeRelative, 17520*time.Hour, "Relative time value since when in the past the CA certificate should be valid. The value has a time unit, the greatest time unit is h for hour.")
|
||||
cmd.Flags().Duration(generateCANotAfterRelative, 17520*time.Hour, "Relative time value until when in the future the CA certificate should be valid. The value has a time unit, the greatest time unit is h for hour.")
|
||||
|
||||
return cmd
|
||||
}
|
||||
|
||||
func getDurationFlag(cmd *cobra.Command, flagName string, logger logging.Logger) (val time.Duration, err error) {
|
||||
if val, err = cmd.Flags().GetDuration(flagName); err != nil {
|
||||
logger.Error(
|
||||
"failed to parse parse flag",
|
||||
zap.String("flag", flagName),
|
||||
zap.Error(err),
|
||||
)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func getStringFlag(cmd *cobra.Command, flagName string, logger logging.Logger) (val string, err error) {
|
||||
if val, err = cmd.Flags().GetString(flagName); err != nil {
|
||||
logger.Error(
|
||||
"failed to parse parse flag",
|
||||
zap.String("flag", flagName),
|
||||
zap.Error(err),
|
||||
)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func runGenerateCA(logger logging.Logger) func(cmd *cobra.Command, args []string) {
|
||||
return func(cmd *cobra.Command, args []string) {
|
||||
var certOutPath, keyOutPath, curveName string
|
||||
var notBefore, notAfter time.Duration
|
||||
var err error
|
||||
|
||||
if certOutPath, err = getStringFlag(cmd, generateCACertOutPath, logger); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if keyOutPath, err = getStringFlag(cmd, generateCAKeyOutPath, logger); err != nil {
|
||||
return
|
||||
}
|
||||
if curveName, err = getStringFlag(cmd, generateCACurveName, logger); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if notBefore, err = getDurationFlag(cmd, generateCANotBeforeRelative, logger); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if notAfter, err = getDurationFlag(cmd, generateCANotAfterRelative, logger); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
logger := logger.With(
|
||||
zap.String(generateCACurveName, curveName),
|
||||
zap.String(generateCACertOutPath, certOutPath),
|
||||
zap.String(generateCAKeyOutPath, keyOutPath),
|
||||
)
|
||||
|
||||
certStore := certStore{
|
||||
options: &tlsOptions{
|
||||
ecdsaCurve: curveType(curveName),
|
||||
validity: validity{
|
||||
ca: certValidity{
|
||||
notAfterRelative: notAfter,
|
||||
notBeforeRelative: notBefore,
|
||||
},
|
||||
},
|
||||
rootCaCert: cert{
|
||||
publicKeyPath: certOutPath,
|
||||
privateKeyPath: keyOutPath,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
if _, _, err := certStore.generateCaCert(); err != nil {
|
||||
logger.Error(
|
||||
"failed to generate CA cert",
|
||||
zap.Error(err),
|
||||
)
|
||||
} else {
|
||||
logger.Info("Successfully generated CA cert")
|
||||
}
|
||||
}
|
||||
}
|
|
@ -4,9 +4,8 @@ go 1.14
|
|||
|
||||
require (
|
||||
github.com/baez90/inetmock v0.0.1
|
||||
github.com/spf13/cobra v1.0.0
|
||||
github.com/spf13/viper v1.6.3
|
||||
go.uber.org/zap v1.14.1
|
||||
go.uber.org/zap v1.15.0
|
||||
)
|
||||
|
||||
replace github.com/baez90/inetmock v0.0.1 => ../../
|
||||
|
|
|
@ -147,6 +147,8 @@ go.uber.org/tools v0.0.0-20190618225709-2cfd321de3ee/go.mod h1:vJERXedbb3MVM5f9E
|
|||
go.uber.org/zap v1.10.0/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q=
|
||||
go.uber.org/zap v1.14.1 h1:nYDKopTbvAPq/NrUVZwT15y2lpROBiLLyoRTbXOYWOo=
|
||||
go.uber.org/zap v1.14.1/go.mod h1:Mb2vm2krFEG5DV0W9qcHBYFtp/Wku1cvYaqPsS/WYfc=
|
||||
go.uber.org/zap v1.15.0 h1:ZZCA22JRF2gQE5FoNmhmrf7jeJJ2uhqDUNRYKm8dvmM=
|
||||
go.uber.org/zap v1.15.0/go.mod h1:Mb2vm2krFEG5DV0W9qcHBYFtp/Wku1cvYaqPsS/WYfc=
|
||||
golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
|
||||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||
golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
||||
|
@ -178,6 +180,8 @@ golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7w
|
|||
golang.org/x/sys v0.0.0-20191005200804-aed5e4c7ecf9/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20200409092240-59c9f1ba88fa h1:mQTN3ECqfsViCNBgq+A40vdwhkGykrrQlYe3mPj6BoU=
|
||||
golang.org/x/sys v0.0.0-20200409092240-59c9f1ba88fa/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20200420163511-1957bb5e6d1f h1:gWF768j/LaZugp8dyS4UwsslYCYz9XgFxvlgsn0n9H8=
|
||||
golang.org/x/sys v0.0.0-20200420163511-1957bb5e6d1f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/text v0.3.2 h1:tW2bmiBqwgJj/UpqtC8EpXEZVYOwU0yG4iWbprSVAcs=
|
||||
|
|
|
@ -19,5 +19,5 @@ func init() {
|
|||
logger: logger,
|
||||
currentConnectionsCount: &sync.WaitGroup{},
|
||||
}
|
||||
}, generateCACmd(logger))
|
||||
})
|
||||
}
|
||||
|
|
|
@ -2,9 +2,9 @@ package main
|
|||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"fmt"
|
||||
"github.com/baez90/inetmock/internal/config"
|
||||
"github.com/baez90/inetmock/pkg/api"
|
||||
"github.com/baez90/inetmock/pkg/cert"
|
||||
"github.com/baez90/inetmock/pkg/logging"
|
||||
"go.uber.org/zap"
|
||||
"net"
|
||||
|
@ -17,16 +17,16 @@ const (
|
|||
)
|
||||
|
||||
type tlsInterceptor struct {
|
||||
options tlsOptions
|
||||
logger logging.Logger
|
||||
listener net.Listener
|
||||
certStore *certStore
|
||||
options *tlsOptions
|
||||
certStore cert.Store
|
||||
shutdownRequested bool
|
||||
currentConnectionsCount *sync.WaitGroup
|
||||
currentConnections []*proxyConn
|
||||
}
|
||||
|
||||
func (t *tlsInterceptor) Start(config config.HandlerConfig) (err error) {
|
||||
func (t *tlsInterceptor) Start(config api.HandlerConfig) (err error) {
|
||||
t.options = loadFromConfig(config.Options())
|
||||
addr := fmt.Sprintf("%s:%d", config.ListenAddress(), config.Port())
|
||||
|
||||
|
@ -35,33 +35,7 @@ func (t *tlsInterceptor) Start(config config.HandlerConfig) (err error) {
|
|||
zap.String("target", t.options.redirectionTarget.address()),
|
||||
)
|
||||
|
||||
t.certStore = &certStore{
|
||||
options: t.options,
|
||||
certCache: make(map[string]*tls.Certificate),
|
||||
logger: t.logger,
|
||||
}
|
||||
|
||||
if err = t.certStore.initCaCert(); err != nil {
|
||||
t.logger.Error(
|
||||
"failed to initialize CA cert",
|
||||
zap.Error(err),
|
||||
)
|
||||
err = fmt.Errorf(
|
||||
"failed to initialize CA cert: %w",
|
||||
err,
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
rootCaPool := x509.NewCertPool()
|
||||
rootCaPool.AddCert(t.certStore.caCert)
|
||||
|
||||
tlsConfig := &tls.Config{
|
||||
GetCertificate: t.getCert,
|
||||
RootCAs: rootCaPool,
|
||||
}
|
||||
|
||||
if t.listener, err = tls.Listen("tcp", addr, tlsConfig); err != nil {
|
||||
if t.listener, err = tls.Listen("tcp", addr, api.ServicesInstance().CertStore().TLSConfig()); err != nil {
|
||||
t.logger.Fatal(
|
||||
"failed to create tls listener",
|
||||
zap.Error(err),
|
||||
|
@ -122,23 +96,6 @@ func (t *tlsInterceptor) startListener() {
|
|||
}
|
||||
}
|
||||
|
||||
func (t *tlsInterceptor) getCert(info *tls.ClientHelloInfo) (cert *tls.Certificate, err error) {
|
||||
var localIp string
|
||||
if localIp, err = extractIPFromAddress(info.Conn.LocalAddr().String()); err != nil {
|
||||
localIp = "127.0.0.1"
|
||||
}
|
||||
if cert, err = t.certStore.getCertificate(info.ServerName, localIp); err != nil {
|
||||
t.logger.Error(
|
||||
"error while resolving certificate",
|
||||
zap.String("serverName", info.ServerName),
|
||||
zap.String("localAddr", localIp),
|
||||
zap.Error(err),
|
||||
)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (t *tlsInterceptor) proxyConn(conn net.Conn) {
|
||||
defer conn.Close()
|
||||
|
||||
|
|
|
@ -3,37 +3,13 @@ package main
|
|||
import (
|
||||
"fmt"
|
||||
"github.com/spf13/viper"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
certCachePathConfigKey = "certCachePath"
|
||||
ecdsaCurveConfigKey = "ecdsaCurve"
|
||||
targetIpAddressConfigKey = "target.ipAddress"
|
||||
targetPortConfigKey = "target.port"
|
||||
publicKeyConfigKey = "rootCaCert.publicKey"
|
||||
privateKeyPathConfigKey = "rootCaCert.privateKey"
|
||||
caCertValidityNotBeforeKey = "validity.ca.notBeforeRelative"
|
||||
caCertValidityNotAfterKey = "validity.ca.notAfterRelative"
|
||||
domainCertValidityNotBeforeKey = "validity.domain.notBeforeRelative"
|
||||
domainCertValidityNotAfterKey = "validity.domain.notAfterRelative"
|
||||
targetIpAddressConfigKey = "target.ipAddress"
|
||||
targetPortConfigKey = "target.port"
|
||||
)
|
||||
|
||||
type cert struct {
|
||||
publicKeyPath string
|
||||
privateKeyPath string
|
||||
}
|
||||
|
||||
type certValidity struct {
|
||||
notBeforeRelative time.Duration
|
||||
notAfterRelative time.Duration
|
||||
}
|
||||
|
||||
type validity struct {
|
||||
ca certValidity
|
||||
domain certValidity
|
||||
}
|
||||
|
||||
type redirectionTarget struct {
|
||||
ipAddress string
|
||||
port uint16
|
||||
|
@ -44,35 +20,14 @@ func (rt redirectionTarget) address() string {
|
|||
}
|
||||
|
||||
type tlsOptions struct {
|
||||
rootCaCert cert
|
||||
certCachePath string
|
||||
redirectionTarget redirectionTarget
|
||||
ecdsaCurve curveType
|
||||
validity validity
|
||||
}
|
||||
|
||||
func loadFromConfig(config *viper.Viper) *tlsOptions {
|
||||
|
||||
return &tlsOptions{
|
||||
certCachePath: config.GetString(certCachePathConfigKey),
|
||||
ecdsaCurve: curveType(config.GetString(ecdsaCurveConfigKey)),
|
||||
func loadFromConfig(config *viper.Viper) tlsOptions {
|
||||
return tlsOptions{
|
||||
redirectionTarget: redirectionTarget{
|
||||
ipAddress: config.GetString(targetIpAddressConfigKey),
|
||||
port: uint16(config.GetInt(targetPortConfigKey)),
|
||||
},
|
||||
validity: validity{
|
||||
ca: certValidity{
|
||||
notBeforeRelative: config.GetDuration(caCertValidityNotBeforeKey),
|
||||
notAfterRelative: config.GetDuration(caCertValidityNotAfterKey),
|
||||
},
|
||||
domain: certValidity{
|
||||
notBeforeRelative: config.GetDuration(domainCertValidityNotBeforeKey),
|
||||
notAfterRelative: config.GetDuration(domainCertValidityNotAfterKey),
|
||||
},
|
||||
},
|
||||
rootCaCert: cert{
|
||||
publicKeyPath: config.GetString(publicKeyConfigKey),
|
||||
privateKeyPath: config.GetString(privateKeyPathConfigKey),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,55 +0,0 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
testCaCrt []byte
|
||||
testCaKey []byte
|
||||
)
|
||||
|
||||
func init() {
|
||||
tmpDir, err := ioutil.TempDir(os.TempDir(), "*-inetmock")
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("failed to create temp dir %v", err))
|
||||
}
|
||||
|
||||
options := &tlsOptions{
|
||||
ecdsaCurve: "P256",
|
||||
rootCaCert: cert{
|
||||
publicKeyPath: filepath.Join(tmpDir, "localhost.pem"),
|
||||
privateKeyPath: filepath.Join(tmpDir, "localhost.key"),
|
||||
},
|
||||
validity: validity{
|
||||
ca: certValidity{
|
||||
notBeforeRelative: time.Hour * 24 * 30,
|
||||
notAfterRelative: time.Hour * 24 * 30,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
certStore := certStore{
|
||||
options: options,
|
||||
keyProvider: func() (key interface{}, err error) {
|
||||
return privateKeyForCurve(options)
|
||||
},
|
||||
}
|
||||
|
||||
defer func() {
|
||||
_ = os.Remove(tmpDir)
|
||||
}()
|
||||
|
||||
_, _, err = certStore.generateCaCert()
|
||||
|
||||
testCaCrt, _ = ioutil.ReadFile(options.rootCaCert.publicKeyPath)
|
||||
testCaKey, _ = ioutil.ReadFile(options.rootCaCert.privateKeyPath)
|
||||
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("failed to generate CA cert %v", err))
|
||||
}
|
||||
}
|
|
@ -1,18 +0,0 @@
|
|||
package main
|
||||
|
||||
import "time"
|
||||
|
||||
type timeSource interface {
|
||||
UTCNow() time.Time
|
||||
}
|
||||
|
||||
func createTimeSource() timeSource {
|
||||
return &defaultTimeSource{}
|
||||
}
|
||||
|
||||
type defaultTimeSource struct {
|
||||
}
|
||||
|
||||
func (d defaultTimeSource) UTCNow() time.Time {
|
||||
return time.Now().UTC()
|
||||
}
|
Loading…
Reference in a new issue