Move HTTPS handling to http_mock handler

This commit is contained in:
Peter 2021-01-26 18:42:07 +01:00
parent 63a6516d99
commit 6d2737b501
Signed by: prskr
GPG key ID: C1DB5D2E8DB512F9
14 changed files with 124 additions and 97 deletions

View file

@ -1,5 +0,0 @@
-----BEGIN PRIVATE KEY-----
MIGHAgEAMBMGByqGSM49AgEGCCqGSM49AwEHBG0wawIBAQQgTTz25fFLS2WO4hXD
162B059HEe+MAQtV4iGXf7HfKCihRANCAAT3D181Tzrz6i9Mx75pmyAsg+itojO9
sHXZSswmfsh46IVK46m0hXNHgPvD2WYW5m1PHvRl3B0vDo/2Y6sOU/Q9
-----END PRIVATE KEY-----

View file

@ -1,12 +0,0 @@
-----BEGIN CERTIFICATE-----
MIIB3DCCAYKgAwIBAgIQHQIFIEcNZjsDP+wDtGPMXzAKBggqhkjOPQQDAjBOMRAw
DgYDVQQGEwdnZXJtYW55MREwDwYDVQQHEwhEb3J0bXVuZDERMA8GA1UEChMISU5l
dE1vY2sxFDASBgNVBAMTC0lOZXRNb2NrIENBMB4XDTIwMDYxNTEwNTEzNloXDTIw
MDYxNTEwNTEzNlowTjEQMA4GA1UEBhMHZ2VybWFueTERMA8GA1UEBxMIRG9ydG11
bmQxETAPBgNVBAoTCElOZXRNb2NrMRQwEgYDVQQDEwtJTmV0TW9jayBDQTBZMBMG
ByqGSM49AgEGCCqGSM49AwEHA0IABPcPXzVPOvPqL0zHvmmbICyD6K2iM72wddlK
zCZ+yHjohUrjqbSFc0eA+8PZZhbmbU8e9GXcHS8Oj/Zjqw5T9D2jQjBAMA4GA1Ud
DwEB/wQEAwIChDAdBgNVHSUEFjAUBggrBgEFBQcDAgYIKwYBBQUHAwEwDwYDVR0T
AQH/BAUwAwEB/zAKBggqhkjOPQQDAgNIADBFAiBecJsOL7ej0kCkWOnoQJpW3JuY
KQIxQBT+XXPKEJj14AIhANG4twTloC3amz8Y7Zn3DVtvjXlTgg8YwjBFG+JioQOe
-----END CERTIFICATE-----

View file

@ -8,7 +8,6 @@ import (
"github.com/spf13/cobra" "github.com/spf13/cobra"
"gitlab.com/inetmock/inetmock/pkg/cert" "gitlab.com/inetmock/inetmock/pkg/cert"
"gitlab.com/inetmock/inetmock/pkg/config" "gitlab.com/inetmock/inetmock/pkg/config"
"gitlab.com/inetmock/inetmock/pkg/logging"
"go.uber.org/zap" "go.uber.org/zap"
) )
@ -34,6 +33,7 @@ var (
certOutPath, curveName string certOutPath, curveName string
) )
//nolint:lll
func init() { func init() {
generateCaCmd = &cobra.Command{ generateCaCmd = &cobra.Command{
Use: "generate-ca", Use: "generate-ca",
@ -50,7 +50,7 @@ func init() {
generateCaCmd.Flags().StringSliceVar(&caCertOptions.Locality, generateCaLocalityName, nil, "Locality information to append to certificate") generateCaCmd.Flags().StringSliceVar(&caCertOptions.Locality, generateCaLocalityName, nil, "Locality information to append to certificate")
generateCaCmd.Flags().StringSliceVar(&caCertOptions.StreetAddress, generateCaStreetAddressName, nil, "Street address information to append to certificate") generateCaCmd.Flags().StringSliceVar(&caCertOptions.StreetAddress, generateCaStreetAddressName, nil, "Street address information to append to certificate")
generateCaCmd.Flags().StringSliceVar(&caCertOptions.PostalCode, generateCaPostalCodeName, nil, "Postal code information to append to certificate") generateCaCmd.Flags().StringSliceVar(&caCertOptions.PostalCode, generateCaPostalCodeName, nil, "Postal code information to append to certificate")
generateCaCmd.Flags().StringVar(&certOutPath, generateCACertOutPath, "", "Path where CA files should be stored") generateCaCmd.Flags().StringVar(&certOutPath, generateCACertOutPath, "", "Path where CA files should be stored")
generateCaCmd.Flags().StringVar(&curveName, generateCACurveName, "", "Name of the curve to use, if empty ED25519 is used, other valid values are [P224, P256,P384,P521]") generateCaCmd.Flags().StringVar(&curveName, generateCACurveName, "", "Name of the curve to use, if empty ED25519 is used, other valid values are [P224, P256,P384,P521]")
generateCaCmd.Flags().DurationVar(&notBefore, generateCANotBeforeRelative, 17520*time.Hour, "Relative time value since when in the past the CA certificate should be valid. The value has a time unit, the greatest time unit is h for hour.") generateCaCmd.Flags().DurationVar(&notBefore, generateCANotBeforeRelative, 17520*time.Hour, "Relative time value since when in the past the CA certificate should be valid. The value has a time unit, the greatest time unit is h for hour.")
generateCaCmd.Flags().DurationVar(&notAfter, generateCANotAfterRelative, 17520*time.Hour, "Relative time value until when in the future the CA certificate should be valid. The value has a time unit, the greatest time unit is h for hour.") generateCaCmd.Flags().DurationVar(&notAfter, generateCANotAfterRelative, 17520*time.Hour, "Relative time value until when in the future the CA certificate should be valid. The value has a time unit, the greatest time unit is h for hour.")
@ -108,25 +108,3 @@ func runGenerateCA(_ *cobra.Command, _ []string) {
} }
logger.Info("completed certificate generation") logger.Info("completed certificate generation")
} }
func getDurationFlag(cmd *cobra.Command, flagName string, logger logging.Logger) (val time.Duration, err error) {
if val, err = cmd.Flags().GetDuration(flagName); err != nil {
logger.Error(
"failed to parse parse flag",
zap.String("flag", flagName),
zap.Error(err),
)
}
return
}
func getStringFlag(cmd *cobra.Command, flagName string, logger logging.Logger) (val string, err error) {
if val, err = cmd.Flags().GetString(flagName); err != nil {
logger.Error(
"failed to parse parse flag",
zap.String("flag", flagName),
zap.Error(err),
)
}
return
}

View file

@ -69,6 +69,15 @@ endpoints:
- 8080 - 8080
options: options:
<<: *httpResponseRules <<: *httpResponseRules
https:
handler: http_mock
listenAddress: 0.0.0.0
ports:
- 443
- 8443
options:
tls: true
<<: *httpResponseRules
proxy: proxy:
handler: http_proxy handler: http_proxy
listenAddress: 0.0.0.0 listenAddress: 0.0.0.0
@ -78,16 +87,6 @@ endpoints:
target: target:
ipAddress: 127.0.0.1 ipAddress: 127.0.0.1
port: 80 port: 80
httpsDowngrade:
handler: tls_interceptor
listenAddress: 0.0.0.0
ports:
- 443
- 8443
options:
target:
ipAddress: 127.0.0.1
port: 80
plainDns: plainDns:
handler: dns_mock handler: dns_mock
listenAddress: 0.0.0.0 listenAddress: 0.0.0.0

View file

@ -108,35 +108,67 @@ func (a *app) MustRun() {
} }
func (a *app) Logger() logging.Logger { func (a *app) Logger() logging.Logger {
return a.ctx.Value(loggerKey).(logging.Logger) val := a.ctx.Value(loggerKey)
if val == nil {
return nil
}
return val.(logging.Logger)
} }
func (a *app) Config() config.Config { func (a *app) Config() config.Config {
return a.ctx.Value(configKey).(config.Config) val := a.ctx.Value(configKey)
if val == nil {
return nil
}
return val.(config.Config)
} }
func (a *app) CertStore() cert.Store { func (a *app) CertStore() cert.Store {
return a.ctx.Value(certStoreKey).(cert.Store) val := a.ctx.Value(certStoreKey)
if val == nil {
return nil
}
return val.(cert.Store)
} }
func (a *app) Checker() health.Checker { func (a *app) Checker() health.Checker {
return a.ctx.Value(healthCheckerKey).(health.Checker) val := a.ctx.Value(healthCheckerKey)
if val == nil {
return nil
}
return val.(health.Checker)
} }
func (a *app) EndpointManager() endpoint.EndpointManager { func (a *app) EndpointManager() endpoint.EndpointManager {
return a.ctx.Value(endpointManagerKey).(endpoint.EndpointManager) val := a.ctx.Value(endpointManagerKey)
if val == nil {
return nil
}
return val.(endpoint.EndpointManager)
} }
func (a *app) Audit() audit.Emitter { func (a *app) Audit() audit.Emitter {
return a.ctx.Value(eventStreamKey).(audit.Emitter) val := a.ctx.Value(eventStreamKey)
if val == nil {
return nil
}
return val.(audit.Emitter)
} }
func (a *app) EventStream() audit.EventStream { func (a *app) EventStream() audit.EventStream {
return a.ctx.Value(eventStreamKey).(audit.EventStream) val := a.ctx.Value(eventStreamKey)
if val == nil {
return nil
}
return val.(audit.EventStream)
} }
func (a *app) HandlerRegistry() api.HandlerRegistry { func (a *app) HandlerRegistry() api.HandlerRegistry {
return a.ctx.Value(handlerRegistryKey).(api.HandlerRegistry) val := a.ctx.Value(handlerRegistryKey)
if val == nil {
return nil
}
return val.(api.HandlerRegistry)
} }
func (a *app) Context() context.Context { func (a *app) Context() context.Context {

View file

@ -149,8 +149,6 @@ func startEndpoint(ep Endpoint, ctx api.PluginContext, logger logging.Logger) (s
success = false success = false
} }
close(startSuccessful)
return return
} }

View file

@ -1,6 +1,7 @@
package http package http
import ( import (
"crypto/tls"
"net/http" "net/http"
"gitlab.com/inetmock/inetmock/pkg/audit" "gitlab.com/inetmock/inetmock/pkg/audit"
@ -24,8 +25,8 @@ func EventFromRequest(request *http.Request, app audit.AppProtocol) audit.Event
if request.TLS != nil { if request.TLS != nil {
ev.TLS = &audit.TLSDetails{ ev.TLS = &audit.TLSDetails{
Version: request.TLS.Version, Version: audit.TLSVersionToEntity(request.TLS.Version).String(),
CipherSuite: request.TLS.CipherSuite, CipherSuite: tls.CipherSuiteName(request.TLS.CipherSuite),
ServerName: request.TLS.ServerName, ServerName: request.TLS.ServerName,
} }
} }

View file

@ -2,6 +2,7 @@ package mock
import ( import (
"context" "context"
"crypto/tls"
"errors" "errors"
"fmt" "fmt"
"net/http" "net/http"
@ -50,11 +51,16 @@ func (p *httpHandler) Start(ctx api.PluginContext, config config.HandlerConfig)
ConnContext: imHttp.StoreConnPropertiesInContext, ConnContext: imHttp.StoreConnPropertiesInContext,
} }
if options.TLS {
p.server.TLSConfig = ctx.CertStore().TLSConfig()
p.server.TLSNextProto = make(map[string]func(*http.Server, *tls.Conn, http.Handler))
}
for _, rule := range options.Rules { for _, rule := range options.Rules {
router.setupRoute(rule) router.setupRoute(rule)
} }
go p.startServer() go p.startServer(options.TLS)
return return
} }
@ -73,8 +79,17 @@ func (p *httpHandler) Shutdown(ctx context.Context) (err error) {
return return
} }
func (p *httpHandler) startServer() { func (p *httpHandler) startServer(tls bool) {
if err := p.server.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) { var listen func() error
if tls {
listen = func() error {
return p.server.ListenAndServeTLS("", "")
}
} else {
listen = p.server.ListenAndServe
}
if err := listen(); err != nil && !errors.Is(err, http.ErrServerClosed) {
p.logger.Error( p.logger.Error(
"failed to start http listener", "failed to start http listener",
zap.Error(err), zap.Error(err),

View file

@ -50,6 +50,7 @@ func (tr targetRule) Response() string {
} }
type httpOptions struct { type httpOptions struct {
TLS bool
Rules []targetRule Rules []targetRule
} }
@ -62,6 +63,7 @@ func loadFromConfig(config *viper.Viper) (options httpOptions, err error) {
} }
tmpRules := struct { tmpRules := struct {
TLS bool
Rules []tmpCfg Rules []tmpCfg
}{} }{}
@ -69,6 +71,8 @@ func loadFromConfig(config *viper.Viper) (options httpOptions, err error) {
return return
} }
options.TLS = tmpRules.TLS
for _, i := range tmpRules.Rules { for _, i := range tmpRules.Rules {
var rulePattern *regexp.Regexp var rulePattern *regexp.Regexp
var matchTargetValue RequestMatchTarget var matchTargetValue RequestMatchTarget

View file

@ -95,6 +95,34 @@ rules:
}, },
wantErr: false, wantErr: false,
}, },
{
name: "Parse config with header matcher and TLS true",
args: args{
config: `
tls: true
rules:
- pattern: "^application/octet-stream$"
target: Content-Type
matcher: Header
response: ./assets/fakeFiles/sample.exe
`,
},
wantOptions: httpOptions{
TLS: true,
Rules: []targetRule{
{
pattern: regexp.MustCompile("^application/octet-stream$"),
response: func() string {
p, _ := filepath.Abs("./assets/fakeFiles/sample.exe")
return p
}(),
requestMatchTarget: RequestMatchTargetHeader,
targetKey: "Content-Type",
},
},
},
wantErr: false,
},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {

View file

@ -27,8 +27,8 @@ var (
SourcePort: 32344, SourcePort: 32344,
DestinationPort: 80, DestinationPort: 80,
TLS: &audit.TLSDetails{ TLS: &audit.TLSDetails{
Version: tls.VersionTLS13, Version: audit.TLSVersionToEntity(tls.VersionTLS13).String(),
CipherSuite: tls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA, CipherSuite: tls.CipherSuiteName(tls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA),
ServerName: "localhost", ServerName: "localhost",
}, },
ProtocolDetails: details.HTTP{ ProtocolDetails: details.HTTP{

View file

@ -1,8 +1,6 @@
package sink package sink
import ( import (
"crypto/tls"
"gitlab.com/inetmock/inetmock/pkg/audit" "gitlab.com/inetmock/inetmock/pkg/audit"
"gitlab.com/inetmock/inetmock/pkg/logging" "gitlab.com/inetmock/inetmock/pkg/logging"
"go.uber.org/zap" "go.uber.org/zap"
@ -34,7 +32,8 @@ func (l logSink) OnSubscribe(evs <-chan audit.Event, _ audit.CloseHandle) {
if ev.TLS != nil { if ev.TLS != nil {
eventLogger = eventLogger.With( eventLogger = eventLogger.With(
zap.String("tls_server_name", ev.TLS.ServerName), zap.String("tls_server_name", ev.TLS.ServerName),
zap.String("tls_cipher_suite", tls.CipherSuiteName(ev.TLS.CipherSuite)), zap.String("tls_cipher_suite", ev.TLS.CipherSuite),
zap.String("tls_version", ev.TLS.Version),
) )
} }

View file

@ -27,8 +27,8 @@ var (
SourcePort: 32344, SourcePort: 32344,
DestinationPort: 80, DestinationPort: 80,
TLS: &audit.TLSDetails{ TLS: &audit.TLSDetails{
Version: tls.VersionTLS13, Version: audit.TLSVersionToEntity(tls.VersionTLS13).String(),
CipherSuite: tls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA, CipherSuite: tls.CipherSuiteName(tls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA),
ServerName: "localhost", ServerName: "localhost",
}, },
ProtocolDetails: details.HTTP{ ProtocolDetails: details.HTTP{

View file

@ -1,8 +1,6 @@
package audit package audit
import ( import "crypto/tls"
"crypto/tls"
)
var ( var (
tlsToEntity = map[uint16]TLSVersion{ tlsToEntity = map[uint16]TLSVersion{
@ -12,45 +10,37 @@ var (
tls.VersionTLS12: TLSVersion_TLS12, tls.VersionTLS12: TLSVersion_TLS12,
tls.VersionTLS13: TLSVersion_TLS13, tls.VersionTLS13: TLSVersion_TLS13,
} }
entityToTls = map[TLSVersion]uint16{
TLSVersion_SSLv30: tls.VersionSSL30,
TLSVersion_TLS10: tls.VersionTLS10,
TLSVersion_TLS11: tls.VersionTLS11,
TLSVersion_TLS12: tls.VersionTLS12,
TLSVersion_TLS13: tls.VersionTLS13,
}
cipherSuiteIDLookup = func(name string) uint16 {
for _, cs := range tls.CipherSuites() {
if cs.Name == name {
return cs.ID
}
}
return 0
}
) )
type TLSDetails struct { type TLSDetails struct {
Version uint16 Version string
CipherSuite uint16 CipherSuite string
ServerName string ServerName string
} }
func TLSVersionToEntity(version uint16) TLSVersion {
if v, known := tlsToEntity[version]; known {
return v
}
return TLSVersion_SSLv30
}
func NewTLSDetailsFromProto(entity *TLSDetailsEntity) *TLSDetails { func NewTLSDetailsFromProto(entity *TLSDetailsEntity) *TLSDetails {
if entity == nil { if entity == nil {
return nil return nil
} }
return &TLSDetails{ return &TLSDetails{
Version: entityToTls[entity.GetVersion()], Version: entity.GetVersion().String(),
CipherSuite: cipherSuiteIDLookup(entity.GetCipherSuite()), CipherSuite: entity.GetCipherSuite(),
ServerName: entity.GetServerName(), ServerName: entity.GetServerName(),
} }
} }
func (d TLSDetails) ProtoMessage() *TLSDetailsEntity { func (d TLSDetails) ProtoMessage() *TLSDetailsEntity {
return &TLSDetailsEntity{ return &TLSDetailsEntity{
Version: tlsToEntity[d.Version], Version: TLSVersion(TLSVersion_value[d.Version]),
CipherSuite: tls.CipherSuiteName(d.CipherSuite), CipherSuite: d.CipherSuite,
ServerName: d.ServerName, ServerName: d.ServerName,
} }
} }