Peter Kurfer
d70ba748f5
- merge packages to get a more concise layout because plugins are no more and therefore there's not a lot to be exported - fix test logger - rework config parsing to be easier and more transparent - remove unnecessary APIs because dynamic endpoint handling is rather a won't implement
161 lines
3.5 KiB
Go
161 lines
3.5 KiB
Go
package interceptor
|
|
|
|
import (
|
|
"context"
|
|
"crypto/tls"
|
|
"net"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/google/uuid"
|
|
"github.com/prometheus/client_golang/prometheus"
|
|
"gitlab.com/inetmock/inetmock/internal/endpoint"
|
|
"gitlab.com/inetmock/inetmock/pkg/logging"
|
|
"go.uber.org/zap"
|
|
)
|
|
|
|
// name is the static identifier string of the TLS interceptor handler.
const name = "tls_interceptor"
|
|
|
|
// tlsInterceptor terminates TLS on the endpoint's uplink listener and proxies
// every accepted connection to the configured TCP target.
type tlsInterceptor struct {
	// name is the handler name taken from the lifecycle in Start; it is used
	// as the label value of the Prometheus metrics.
	name string

	// options is decoded from the handler configuration in Start and holds
	// (among others) the proxy Target.
	options tlsOptions

	// logger is enriched with handler name, address and target in Start.
	logger logging.Logger

	// listener is the TLS listener wrapping the uplink's listener.
	listener net.Listener

	// shutdownRequested is set once the lifecycle context is done; the accept
	// loop checks it between Accept calls.
	// NOTE(review): written and read from different goroutines without
	// synchronization — consider an atomic flag.
	shutdownRequested bool

	// currentConnectionsCount counts in-flight proxy goroutines so shutdown
	// can wait for them to finish.
	currentConnectionsCount *sync.WaitGroup

	// currentConnections maps a per-connection UUID to its proxied
	// source/target pair so remaining connections can be force-closed on
	// shutdown. storeConnection and cleanConnection access it under
	// connectionsMutex.
	currentConnections map[uuid.UUID]*proxyConn

	// connectionsMutex serializes access to currentConnections.
	connectionsMutex *sync.Mutex
}
|
|
|
|
func (t *tlsInterceptor) Start(ctx endpoint.Lifecycle) (err error) {
|
|
t.name = ctx.Name()
|
|
|
|
if err = ctx.UnmarshalOptions(&t.options); err != nil {
|
|
return
|
|
}
|
|
|
|
t.logger = t.logger.With(
|
|
zap.String("handler_name", ctx.Name()),
|
|
zap.String("address", ctx.Uplink().Addr().String()),
|
|
zap.String("Target", t.options.Target.address()),
|
|
)
|
|
|
|
t.listener = tls.NewListener(ctx.Uplink().Listener, ctx.CertStore().TLSConfig())
|
|
|
|
go t.startListener()
|
|
go t.shutdownOnContextDone(ctx.Context())
|
|
return
|
|
}
|
|
|
|
func (t *tlsInterceptor) shutdownOnContextDone(ctx context.Context) {
|
|
<-ctx.Done()
|
|
t.logger.Info("Shutting down TLS interceptor")
|
|
t.shutdownRequested = true
|
|
done := make(chan struct{})
|
|
go func() {
|
|
t.currentConnectionsCount.Wait()
|
|
close(done)
|
|
}()
|
|
|
|
select {
|
|
case <-done:
|
|
return
|
|
case <-time.After(100 * time.Millisecond):
|
|
for _, proxyConn := range t.currentConnections {
|
|
if err := proxyConn.Close(); err != nil {
|
|
t.logger.Error(
|
|
"error while closing remaining proxy connections",
|
|
zap.Error(err),
|
|
)
|
|
}
|
|
}
|
|
return
|
|
}
|
|
}
|
|
|
|
func (t *tlsInterceptor) startListener() {
|
|
for !t.shutdownRequested {
|
|
conn, err := t.listener.Accept()
|
|
if err != nil {
|
|
t.logger.Error(
|
|
"error during accept",
|
|
zap.Error(err),
|
|
)
|
|
continue
|
|
}
|
|
|
|
handledRequestCounter.WithLabelValues(t.name).Inc()
|
|
openConnectionsGauge.WithLabelValues(t.name).Inc()
|
|
t.currentConnectionsCount.Add(1)
|
|
go t.proxyConn(conn)
|
|
}
|
|
}
|
|
|
|
func (t *tlsInterceptor) proxyConn(conn net.Conn) {
|
|
timer := prometheus.NewTimer(requestDurationHistogram.WithLabelValues(t.name))
|
|
defer func() {
|
|
_ = conn.Close()
|
|
t.currentConnectionsCount.Done()
|
|
openConnectionsGauge.WithLabelValues(t.name).Dec()
|
|
timer.ObserveDuration()
|
|
}()
|
|
|
|
rAddr, err := net.ResolveTCPAddr("tcp", t.options.Target.address())
|
|
if err != nil {
|
|
t.logger.Error(
|
|
"failed to resolve proxy Target",
|
|
zap.Error(err),
|
|
)
|
|
}
|
|
|
|
targetConn, err := net.DialTCP("tcp", nil, rAddr)
|
|
if err != nil {
|
|
t.logger.Error(
|
|
"failed to connect to proxy Target",
|
|
zap.Error(err),
|
|
)
|
|
return
|
|
}
|
|
defer targetConn.Close()
|
|
|
|
proxyCon := &proxyConn{
|
|
source: conn,
|
|
target: targetConn,
|
|
}
|
|
|
|
conUID := uuid.New()
|
|
t.storeConnection(conUID, proxyCon)
|
|
Pipe(conn, targetConn)
|
|
t.cleanConnection(conUID)
|
|
|
|
switch tlsConn := conn.(type) {
|
|
case *tls.Conn:
|
|
if tlsConn.Handshake() != nil {
|
|
t.logger.Error(
|
|
"error occurred during TLS handshake",
|
|
zap.Error(tlsConn.Handshake()),
|
|
)
|
|
}
|
|
}
|
|
|
|
t.logger.Info(
|
|
"connection closed",
|
|
zap.String("remoteAddr", conn.RemoteAddr().String()),
|
|
)
|
|
}
|
|
|
|
func (t *tlsInterceptor) storeConnection(connUUID uuid.UUID, conn *proxyConn) {
|
|
t.connectionsMutex.Lock()
|
|
defer t.connectionsMutex.Unlock()
|
|
t.currentConnections[connUUID] = conn
|
|
}
|
|
|
|
func (t *tlsInterceptor) cleanConnection(connUUID uuid.UUID) {
|
|
t.connectionsMutex.Lock()
|
|
defer t.connectionsMutex.Unlock()
|
|
if _, ok := t.currentConnections[connUUID]; ok {
|
|
delete(t.currentConnections, connUUID)
|
|
}
|
|
}
|