api/pkg/cert/store.go

150 lines
3.4 KiB
Go
Raw Normal View History

package cert
import (
"crypto/ecdsa"
"crypto/ed25519"
"crypto/elliptic"
"crypto/rand"
"crypto/tls"
"crypto/x509"
"github.com/baez90/inetmock/pkg/config"
"github.com/baez90/inetmock/pkg/logging"
"go.uber.org/zap"
"net"
)
const (
ipv4Loopback = "127.0.0.1"
)
var (
defaultKeyProvider = func(options config.CertOptions) func() (key interface{}, err error) {
return func() (key interface{}, err error) {
return privateKeyForCurve(options)
}
}
)
type KeyProvider func() (key interface{}, err error)
type Store interface {
CACert() *tls.Certificate
GetCertificate(serverName string, ip string) (*tls.Certificate, error)
TLSConfig() *tls.Config
}
func NewDefaultStore(
config config.Config,
logger logging.Logger,
) (Store, error) {
timeSource := NewTimeSource()
return NewStore(
config.TLSConfig(),
NewFileSystemCache(config.TLSConfig().CertCachePath, timeSource),
NewDefaultGenerator(config.TLSConfig()),
logger,
)
}
func NewStore(
options config.CertOptions,
cache Cache,
generator Generator,
logger logging.Logger,
) (Store, error) {
store := &store{
options: options,
cache: cache,
generator: generator,
logger: logger,
}
if err := store.loadCACert(); err != nil {
return nil, err
}
return store, nil
}
type store struct {
options config.CertOptions
caCert *tls.Certificate
cache Cache
timeSource TimeSource
generator Generator
logger logging.Logger
}
func (s *store) TLSConfig() *tls.Config {
rootCaPool := x509.NewCertPool()
rootCaPubKey, _ := x509.ParseCertificate(s.caCert.Certificate[0])
rootCaPool.AddCert(rootCaPubKey)
return &tls.Config{
GetCertificate: func(info *tls.ClientHelloInfo) (cert *tls.Certificate, err error) {
var localIp string
if localIp, err = extractIPFromAddress(info.Conn.LocalAddr().String()); err != nil {
localIp = ipv4Loopback
}
if cert, err = s.GetCertificate(info.ServerName, localIp); err != nil {
s.logger.Error(
"error while resolving certificate",
zap.String("serverName", info.ServerName),
zap.String("localAddr", localIp),
zap.Error(err),
)
}
return
},
RootCAs: rootCaPool,
}
}
func (s *store) loadCACert() (err error) {
pemCrt := NewPEM(nil)
if err = pemCrt.ReadFrom(s.options.RootCACert.PublicKeyPath, s.options.RootCACert.PrivateKeyPath); err != nil {
return
}
s.caCert = pemCrt.Cert()
return
}
func (s *store) CACert() *tls.Certificate {
return s.caCert
}
func (s *store) GetCertificate(serverName string, ip string) (cert *tls.Certificate, err error) {
if crt, ok := s.cache.Get(serverName); ok {
return crt, nil
}
if cert, err = s.generator.ServerCert(GenerationOptions{
CommonName: serverName,
DNSNames: []string{serverName},
IPAddresses: []net.IP{net.ParseIP(ip)},
}, s.caCert); err == nil {
s.cache.Put(cert)
}
return
}
func privateKeyForCurve(options config.CertOptions) (privateKey interface{}, err error) {
switch options.Curve {
case config.CurveTypeP224:
privateKey, err = ecdsa.GenerateKey(elliptic.P224(), rand.Reader)
case config.CurveTypeP256:
privateKey, err = ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
case config.CurveTypeP384:
privateKey, err = ecdsa.GenerateKey(elliptic.P384(), rand.Reader)
case config.CurveTypeP521:
privateKey, err = ecdsa.GenerateKey(elliptic.P521(), rand.Reader)
default:
_, privateKey, err = ed25519.GenerateKey(rand.Reader)
}
return
}