Move HTTPS handling to http_mock handler

This commit is contained in:
Peter 2021-01-26 18:42:07 +01:00
parent 63a6516d99
commit 6d2737b501
Signed by: prskr
GPG key ID: C1DB5D2E8DB512F9
14 changed files with 124 additions and 97 deletions

View file

@ -1,5 +0,0 @@
-----BEGIN PRIVATE KEY-----
MIGHAgEAMBMGByqGSM49AgEGCCqGSM49AwEHBG0wawIBAQQgTTz25fFLS2WO4hXD
162B059HEe+MAQtV4iGXf7HfKCihRANCAAT3D181Tzrz6i9Mx75pmyAsg+itojO9
sHXZSswmfsh46IVK46m0hXNHgPvD2WYW5m1PHvRl3B0vDo/2Y6sOU/Q9
-----END PRIVATE KEY-----

View file

@ -1,12 +0,0 @@
-----BEGIN CERTIFICATE-----
MIIB3DCCAYKgAwIBAgIQHQIFIEcNZjsDP+wDtGPMXzAKBggqhkjOPQQDAjBOMRAw
DgYDVQQGEwdnZXJtYW55MREwDwYDVQQHEwhEb3J0bXVuZDERMA8GA1UEChMISU5l
dE1vY2sxFDASBgNVBAMTC0lOZXRNb2NrIENBMB4XDTIwMDYxNTEwNTEzNloXDTIw
MDYxNTEwNTEzNlowTjEQMA4GA1UEBhMHZ2VybWFueTERMA8GA1UEBxMIRG9ydG11
bmQxETAPBgNVBAoTCElOZXRNb2NrMRQwEgYDVQQDEwtJTmV0TW9jayBDQTBZMBMG
ByqGSM49AgEGCCqGSM49AwEHA0IABPcPXzVPOvPqL0zHvmmbICyD6K2iM72wddlK
zCZ+yHjohUrjqbSFc0eA+8PZZhbmbU8e9GXcHS8Oj/Zjqw5T9D2jQjBAMA4GA1Ud
DwEB/wQEAwIChDAdBgNVHSUEFjAUBggrBgEFBQcDAgYIKwYBBQUHAwEwDwYDVR0T
AQH/BAUwAwEB/zAKBggqhkjOPQQDAgNIADBFAiBecJsOL7ej0kCkWOnoQJpW3JuY
KQIxQBT+XXPKEJj14AIhANG4twTloC3amz8Y7Zn3DVtvjXlTgg8YwjBFG+JioQOe
-----END CERTIFICATE-----

View file

@ -8,7 +8,6 @@ import (
"github.com/spf13/cobra"
"gitlab.com/inetmock/inetmock/pkg/cert"
"gitlab.com/inetmock/inetmock/pkg/config"
"gitlab.com/inetmock/inetmock/pkg/logging"
"go.uber.org/zap"
)
@ -34,6 +33,7 @@ var (
certOutPath, curveName string
)
//nolint:lll
func init() {
generateCaCmd = &cobra.Command{
Use: "generate-ca",
@ -50,7 +50,7 @@ func init() {
generateCaCmd.Flags().StringSliceVar(&caCertOptions.Locality, generateCaLocalityName, nil, "Locality information to append to certificate")
generateCaCmd.Flags().StringSliceVar(&caCertOptions.StreetAddress, generateCaStreetAddressName, nil, "Street address information to append to certificate")
generateCaCmd.Flags().StringSliceVar(&caCertOptions.PostalCode, generateCaPostalCodeName, nil, "Postal code information to append to certificate")
generateCaCmd.Flags().StringVar(&certOutPath, generateCACertOutPath, "", "Path where CA files should be stored")
generateCaCmd.Flags().StringVar(&certOutPath, generateCACertOutPath, "", "Path where CA files should be stored")
generateCaCmd.Flags().StringVar(&curveName, generateCACurveName, "", "Name of the curve to use, if empty ED25519 is used, other valid values are [P224, P256,P384,P521]")
generateCaCmd.Flags().DurationVar(&notBefore, generateCANotBeforeRelative, 17520*time.Hour, "Relative time value since when in the past the CA certificate should be valid. The value has a time unit, the greatest time unit is h for hour.")
generateCaCmd.Flags().DurationVar(&notAfter, generateCANotAfterRelative, 17520*time.Hour, "Relative time value until when in the future the CA certificate should be valid. The value has a time unit, the greatest time unit is h for hour.")
@ -108,25 +108,3 @@ func runGenerateCA(_ *cobra.Command, _ []string) {
}
logger.Info("completed certificate generation")
}
func getDurationFlag(cmd *cobra.Command, flagName string, logger logging.Logger) (val time.Duration, err error) {
if val, err = cmd.Flags().GetDuration(flagName); err != nil {
logger.Error(
"failed to parse parse flag",
zap.String("flag", flagName),
zap.Error(err),
)
}
return
}
func getStringFlag(cmd *cobra.Command, flagName string, logger logging.Logger) (val string, err error) {
if val, err = cmd.Flags().GetString(flagName); err != nil {
logger.Error(
"failed to parse parse flag",
zap.String("flag", flagName),
zap.Error(err),
)
}
return
}

View file

@ -69,6 +69,15 @@ endpoints:
- 8080
options:
<<: *httpResponseRules
https:
handler: http_mock
listenAddress: 0.0.0.0
ports:
- 443
- 8443
options:
tls: true
<<: *httpResponseRules
proxy:
handler: http_proxy
listenAddress: 0.0.0.0
@ -78,16 +87,6 @@ endpoints:
target:
ipAddress: 127.0.0.1
port: 80
httpsDowngrade:
handler: tls_interceptor
listenAddress: 0.0.0.0
ports:
- 443
- 8443
options:
target:
ipAddress: 127.0.0.1
port: 80
plainDns:
handler: dns_mock
listenAddress: 0.0.0.0

View file

@ -108,35 +108,67 @@ func (a *app) MustRun() {
}
func (a *app) Logger() logging.Logger {
return a.ctx.Value(loggerKey).(logging.Logger)
val := a.ctx.Value(loggerKey)
if val == nil {
return nil
}
return val.(logging.Logger)
}
func (a *app) Config() config.Config {
return a.ctx.Value(configKey).(config.Config)
val := a.ctx.Value(configKey)
if val == nil {
return nil
}
return val.(config.Config)
}
func (a *app) CertStore() cert.Store {
return a.ctx.Value(certStoreKey).(cert.Store)
val := a.ctx.Value(certStoreKey)
if val == nil {
return nil
}
return val.(cert.Store)
}
func (a *app) Checker() health.Checker {
return a.ctx.Value(healthCheckerKey).(health.Checker)
val := a.ctx.Value(healthCheckerKey)
if val == nil {
return nil
}
return val.(health.Checker)
}
func (a *app) EndpointManager() endpoint.EndpointManager {
return a.ctx.Value(endpointManagerKey).(endpoint.EndpointManager)
val := a.ctx.Value(endpointManagerKey)
if val == nil {
return nil
}
return val.(endpoint.EndpointManager)
}
func (a *app) Audit() audit.Emitter {
return a.ctx.Value(eventStreamKey).(audit.Emitter)
val := a.ctx.Value(eventStreamKey)
if val == nil {
return nil
}
return val.(audit.Emitter)
}
func (a *app) EventStream() audit.EventStream {
return a.ctx.Value(eventStreamKey).(audit.EventStream)
val := a.ctx.Value(eventStreamKey)
if val == nil {
return nil
}
return val.(audit.EventStream)
}
func (a *app) HandlerRegistry() api.HandlerRegistry {
return a.ctx.Value(handlerRegistryKey).(api.HandlerRegistry)
val := a.ctx.Value(handlerRegistryKey)
if val == nil {
return nil
}
return val.(api.HandlerRegistry)
}
func (a *app) Context() context.Context {

View file

@ -149,8 +149,6 @@ func startEndpoint(ep Endpoint, ctx api.PluginContext, logger logging.Logger) (s
success = false
}
close(startSuccessful)
return
}

View file

@ -1,6 +1,7 @@
package http
import (
"crypto/tls"
"net/http"
"gitlab.com/inetmock/inetmock/pkg/audit"
@ -24,8 +25,8 @@ func EventFromRequest(request *http.Request, app audit.AppProtocol) audit.Event
if request.TLS != nil {
ev.TLS = &audit.TLSDetails{
Version: request.TLS.Version,
CipherSuite: request.TLS.CipherSuite,
Version: audit.TLSVersionToEntity(request.TLS.Version).String(),
CipherSuite: tls.CipherSuiteName(request.TLS.CipherSuite),
ServerName: request.TLS.ServerName,
}
}

View file

@ -2,6 +2,7 @@ package mock
import (
"context"
"crypto/tls"
"errors"
"fmt"
"net/http"
@ -50,11 +51,16 @@ func (p *httpHandler) Start(ctx api.PluginContext, config config.HandlerConfig)
ConnContext: imHttp.StoreConnPropertiesInContext,
}
if options.TLS {
p.server.TLSConfig = ctx.CertStore().TLSConfig()
p.server.TLSNextProto = make(map[string]func(*http.Server, *tls.Conn, http.Handler))
}
for _, rule := range options.Rules {
router.setupRoute(rule)
}
go p.startServer()
go p.startServer(options.TLS)
return
}
@ -73,8 +79,17 @@ func (p *httpHandler) Shutdown(ctx context.Context) (err error) {
return
}
func (p *httpHandler) startServer() {
if err := p.server.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
func (p *httpHandler) startServer(tls bool) {
var listen func() error
if tls {
listen = func() error {
return p.server.ListenAndServeTLS("", "")
}
} else {
listen = p.server.ListenAndServe
}
if err := listen(); err != nil && !errors.Is(err, http.ErrServerClosed) {
p.logger.Error(
"failed to start http listener",
zap.Error(err),

View file

@ -50,6 +50,7 @@ func (tr targetRule) Response() string {
}
type httpOptions struct {
TLS bool
Rules []targetRule
}
@ -62,6 +63,7 @@ func loadFromConfig(config *viper.Viper) (options httpOptions, err error) {
}
tmpRules := struct {
TLS bool
Rules []tmpCfg
}{}
@ -69,6 +71,8 @@ func loadFromConfig(config *viper.Viper) (options httpOptions, err error) {
return
}
options.TLS = tmpRules.TLS
for _, i := range tmpRules.Rules {
var rulePattern *regexp.Regexp
var matchTargetValue RequestMatchTarget

View file

@ -95,6 +95,34 @@ rules:
},
wantErr: false,
},
{
name: "Parse config with header matcher and TLS true",
args: args{
config: `
tls: true
rules:
- pattern: "^application/octet-stream$"
target: Content-Type
matcher: Header
response: ./assets/fakeFiles/sample.exe
`,
},
wantOptions: httpOptions{
TLS: true,
Rules: []targetRule{
{
pattern: regexp.MustCompile("^application/octet-stream$"),
response: func() string {
p, _ := filepath.Abs("./assets/fakeFiles/sample.exe")
return p
}(),
requestMatchTarget: RequestMatchTargetHeader,
targetKey: "Content-Type",
},
},
},
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {

View file

@ -27,8 +27,8 @@ var (
SourcePort: 32344,
DestinationPort: 80,
TLS: &audit.TLSDetails{
Version: tls.VersionTLS13,
CipherSuite: tls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA,
Version: audit.TLSVersionToEntity(tls.VersionTLS13).String(),
CipherSuite: tls.CipherSuiteName(tls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA),
ServerName: "localhost",
},
ProtocolDetails: details.HTTP{

View file

@ -1,8 +1,6 @@
package sink
import (
"crypto/tls"
"gitlab.com/inetmock/inetmock/pkg/audit"
"gitlab.com/inetmock/inetmock/pkg/logging"
"go.uber.org/zap"
@ -34,7 +32,8 @@ func (l logSink) OnSubscribe(evs <-chan audit.Event, _ audit.CloseHandle) {
if ev.TLS != nil {
eventLogger = eventLogger.With(
zap.String("tls_server_name", ev.TLS.ServerName),
zap.String("tls_cipher_suite", tls.CipherSuiteName(ev.TLS.CipherSuite)),
zap.String("tls_cipher_suite", ev.TLS.CipherSuite),
zap.String("tls_version", ev.TLS.Version),
)
}

View file

@ -27,8 +27,8 @@ var (
SourcePort: 32344,
DestinationPort: 80,
TLS: &audit.TLSDetails{
Version: tls.VersionTLS13,
CipherSuite: tls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA,
Version: audit.TLSVersionToEntity(tls.VersionTLS13).String(),
CipherSuite: tls.CipherSuiteName(tls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA),
ServerName: "localhost",
},
ProtocolDetails: details.HTTP{

View file

@ -1,8 +1,6 @@
package audit
import (
"crypto/tls"
)
import "crypto/tls"
var (
tlsToEntity = map[uint16]TLSVersion{
@ -12,45 +10,37 @@ var (
tls.VersionTLS12: TLSVersion_TLS12,
tls.VersionTLS13: TLSVersion_TLS13,
}
entityToTls = map[TLSVersion]uint16{
TLSVersion_SSLv30: tls.VersionSSL30,
TLSVersion_TLS10: tls.VersionTLS10,
TLSVersion_TLS11: tls.VersionTLS11,
TLSVersion_TLS12: tls.VersionTLS12,
TLSVersion_TLS13: tls.VersionTLS13,
}
cipherSuiteIDLookup = func(name string) uint16 {
for _, cs := range tls.CipherSuites() {
if cs.Name == name {
return cs.ID
}
}
return 0
}
)
type TLSDetails struct {
Version uint16
CipherSuite uint16
Version string
CipherSuite string
ServerName string
}
func TLSVersionToEntity(version uint16) TLSVersion {
if v, known := tlsToEntity[version]; known {
return v
}
return TLSVersion_SSLv30
}
func NewTLSDetailsFromProto(entity *TLSDetailsEntity) *TLSDetails {
if entity == nil {
return nil
}
return &TLSDetails{
Version: entityToTls[entity.GetVersion()],
CipherSuite: cipherSuiteIDLookup(entity.GetCipherSuite()),
Version: entity.GetVersion().String(),
CipherSuite: entity.GetCipherSuite(),
ServerName: entity.GetServerName(),
}
}
func (d TLSDetails) ProtoMessage() *TLSDetailsEntity {
return &TLSDetailsEntity{
Version: tlsToEntity[d.Version],
CipherSuite: tls.CipherSuiteName(d.CipherSuite),
Version: TLSVersion(TLSVersion_value[d.Version]),
CipherSuite: d.CipherSuite,
ServerName: d.ServerName,
}
}