Peter Kurfer
49e58ac2e4
- move to Gitlab - make code better testable - create app abstraction for server - cleanup
301 lines
6.8 KiB
Go
301 lines
6.8 KiB
Go
package cert
|
|
|
|
import (
|
|
"crypto/tls"
|
|
"crypto/x509"
|
|
"fmt"
|
|
"os"
|
|
"path"
|
|
"reflect"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/golang/mock/gomock"
|
|
certmock "gitlab.com/inetmock/inetmock/internal/mock/cert"
|
|
"gitlab.com/inetmock/inetmock/pkg/config"
|
|
)
|
|
|
|
// Fixtures shared by the certificate tests in this file.
const (
	// cnLocalhost is the common name used for the cache-miss and in-memory
	// lookups.
	cnLocalhost = "localhost"
	// caCN is the common name given to most of the CA certificates issued in
	// these tests.
	caCN = "UnitTests"

	// serverRelativeValidity is the relative validity setupCertGen configures
	// for server certificates.
	serverRelativeValidity = 24 * time.Hour
	// caRelativeValidity is the relative validity setupCertGen configures for
	// CA certificates (one week).
	caRelativeValidity = 7 * 24 * time.Hour
)

// serverCN is cnLocalhost suffixed with the Unix time (in seconds) taken at
// package initialization, so files written under this name to the shared
// temporary directory are unlikely to clash with leftovers from earlier runs.
var serverCN = fmt.Sprintf("%s-%d", cnLocalhost, time.Now().Unix())
|
|
|
|
func Test_certShouldBeRenewed(t *testing.T) {
|
|
ctrl := gomock.NewController(t)
|
|
defer ctrl.Finish()
|
|
type args struct {
|
|
timeSource TimeSource
|
|
cert *x509.Certificate
|
|
}
|
|
tests := []struct {
|
|
name string
|
|
args args
|
|
want bool
|
|
}{
|
|
{
|
|
name: "Expect cert should not be renewed right after creation",
|
|
args: args{
|
|
timeSource: func() TimeSource {
|
|
tsMock := certmock.NewMockTimeSource(ctrl)
|
|
tsMock.
|
|
EXPECT().
|
|
UTCNow().
|
|
Return(time.Now().UTC()).
|
|
Times(1)
|
|
return tsMock
|
|
}(),
|
|
cert: &x509.Certificate{
|
|
NotAfter: time.Now().UTC().Add(serverRelativeValidity),
|
|
NotBefore: time.Now().UTC().Add(-serverRelativeValidity),
|
|
},
|
|
},
|
|
want: false,
|
|
},
|
|
{
|
|
name: "Expect cert should be renewed if the remaining lifetime is less than a quarter of the total lifetime",
|
|
args: args{
|
|
timeSource: func() TimeSource {
|
|
tsMock := certmock.NewMockTimeSource(ctrl)
|
|
tsMock.
|
|
EXPECT().
|
|
UTCNow().
|
|
Return(time.Now().UTC().Add(serverRelativeValidity/2 + 1*time.Hour)).
|
|
Times(1)
|
|
return tsMock
|
|
}(),
|
|
cert: &x509.Certificate{
|
|
NotAfter: time.Now().UTC().Add(serverRelativeValidity),
|
|
NotBefore: time.Now().UTC().Add(-serverRelativeValidity),
|
|
},
|
|
},
|
|
want: true,
|
|
},
|
|
}
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
if got := certShouldBeRenewed(tt.args.timeSource, tt.args.cert); got != tt.want {
|
|
t.Errorf("certShouldBeRenewed() = %v, want %v", got, tt.want)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func Test_fileSystemCache_Get(t *testing.T) {
|
|
type fields struct {
|
|
certCachePath string
|
|
inMemCache map[string]*tls.Certificate
|
|
timeSource TimeSource
|
|
}
|
|
type args struct {
|
|
cn string
|
|
}
|
|
tests := []struct {
|
|
name string
|
|
fields fields
|
|
args args
|
|
wantOk bool
|
|
}{
|
|
{
|
|
name: "Get a miss when no cert is present",
|
|
fields: fields{
|
|
certCachePath: os.TempDir(),
|
|
inMemCache: make(map[string]*tls.Certificate),
|
|
timeSource: NewTimeSource(),
|
|
},
|
|
args: args{
|
|
cnLocalhost,
|
|
},
|
|
wantOk: false,
|
|
},
|
|
{
|
|
name: "Get a prepared certificate from the memory cache",
|
|
fields: func() fields {
|
|
certGen := setupCertGen()
|
|
|
|
caCrt, _ := certGen.CACert(GenerationOptions{
|
|
CommonName: caCN,
|
|
})
|
|
|
|
srvCrt, _ := certGen.ServerCert(GenerationOptions{
|
|
CommonName: cnLocalhost,
|
|
}, caCrt)
|
|
|
|
return fields{
|
|
certCachePath: os.TempDir(),
|
|
inMemCache: map[string]*tls.Certificate{
|
|
cnLocalhost: srvCrt,
|
|
},
|
|
timeSource: NewTimeSource(),
|
|
}
|
|
}(),
|
|
args: args{
|
|
cn: cnLocalhost,
|
|
},
|
|
wantOk: true,
|
|
},
|
|
{
|
|
name: "Get a prepared certificate from the file system",
|
|
fields: func() fields {
|
|
certGen := setupCertGen()
|
|
|
|
caCrt, _ := certGen.CACert(GenerationOptions{
|
|
CommonName: "INetMock",
|
|
})
|
|
|
|
srvCrt, _ := certGen.ServerCert(GenerationOptions{
|
|
CommonName: serverCN,
|
|
}, caCrt)
|
|
|
|
pem := NewPEM(srvCrt)
|
|
if err := pem.Write(serverCN, os.TempDir()); err != nil {
|
|
panic(err)
|
|
}
|
|
|
|
t.Cleanup(func() {
|
|
for _, f := range []string{
|
|
path.Join(os.TempDir(), fmt.Sprintf("%s.pem", serverCN)),
|
|
path.Join(os.TempDir(), fmt.Sprintf("%s.key", serverCN)),
|
|
} {
|
|
_ = os.Remove(f)
|
|
}
|
|
})
|
|
|
|
return fields{
|
|
certCachePath: os.TempDir(),
|
|
inMemCache: make(map[string]*tls.Certificate),
|
|
timeSource: NewTimeSource(),
|
|
}
|
|
}(),
|
|
args: args{
|
|
cn: serverCN,
|
|
},
|
|
wantOk: true,
|
|
},
|
|
}
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
f := &fileSystemCache{
|
|
certCachePath: tt.fields.certCachePath,
|
|
inMemCache: tt.fields.inMemCache,
|
|
timeSource: tt.fields.timeSource,
|
|
}
|
|
gotCrt, gotOk := f.Get(tt.args.cn)
|
|
|
|
if gotOk && (gotCrt == nil || !reflect.DeepEqual(reflect.TypeOf(new(tls.Certificate)), reflect.TypeOf(gotCrt))) {
|
|
t.Errorf("Wanted propert certificate but got %v", gotCrt)
|
|
}
|
|
|
|
if gotOk != tt.wantOk {
|
|
t.Errorf("Get() gotOk = %v, want %v", gotOk, tt.wantOk)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func Test_fileSystemCache_Put(t *testing.T) {
|
|
type fields struct {
|
|
certCachePath string
|
|
inMemCache map[string]*tls.Certificate
|
|
timeSource TimeSource
|
|
}
|
|
type args struct {
|
|
cert *tls.Certificate
|
|
}
|
|
tests := []struct {
|
|
name string
|
|
fields fields
|
|
args args
|
|
wantErr bool
|
|
}{
|
|
{
|
|
name: "Want error if nil cert is passed to put",
|
|
fields: fields{
|
|
certCachePath: os.TempDir(),
|
|
inMemCache: make(map[string]*tls.Certificate),
|
|
timeSource: NewTimeSource(),
|
|
},
|
|
args: args{
|
|
cert: nil,
|
|
},
|
|
wantErr: true,
|
|
},
|
|
{
|
|
name: "Want error if empty cert is passed to put",
|
|
fields: fields{
|
|
certCachePath: os.TempDir(),
|
|
inMemCache: make(map[string]*tls.Certificate),
|
|
timeSource: NewTimeSource(),
|
|
},
|
|
args: args{
|
|
cert: &tls.Certificate{},
|
|
},
|
|
wantErr: true,
|
|
},
|
|
{
|
|
name: "No error if valid cert is passed",
|
|
fields: fields{
|
|
certCachePath: os.TempDir(),
|
|
inMemCache: make(map[string]*tls.Certificate),
|
|
timeSource: NewTimeSource(),
|
|
},
|
|
args: args{
|
|
cert: func() *tls.Certificate {
|
|
gen := setupCertGen()
|
|
ca, _ := gen.CACert(GenerationOptions{
|
|
CommonName: caCN,
|
|
})
|
|
|
|
srvCN := fmt.Sprintf("%s-%d", cnLocalhost, time.Now().Unix())
|
|
|
|
t.Cleanup(func() {
|
|
for _, f := range []string{
|
|
path.Join(os.TempDir(), fmt.Sprintf("%s.pem", srvCN)),
|
|
path.Join(os.TempDir(), fmt.Sprintf("%s.key", srvCN)),
|
|
} {
|
|
_ = os.Remove(f)
|
|
}
|
|
})
|
|
|
|
srvCrt, _ := gen.ServerCert(GenerationOptions{
|
|
CommonName: srvCN,
|
|
}, ca)
|
|
return srvCrt
|
|
}(),
|
|
},
|
|
wantErr: false,
|
|
},
|
|
}
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
f := &fileSystemCache{
|
|
certCachePath: tt.fields.certCachePath,
|
|
inMemCache: tt.fields.inMemCache,
|
|
timeSource: tt.fields.timeSource,
|
|
}
|
|
if err := f.Put(tt.args.cert); (err != nil) != tt.wantErr {
|
|
t.Errorf("Put() error = %v, wantErr %v", err, tt.wantErr)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func setupCertGen() Generator {
|
|
return NewDefaultGenerator(config.CertOptions{
|
|
Validity: config.ValidityByPurpose{
|
|
Server: config.ValidityDuration{
|
|
NotBeforeRelative: serverRelativeValidity,
|
|
NotAfterRelative: serverRelativeValidity,
|
|
},
|
|
CA: config.ValidityDuration{
|
|
NotBeforeRelative: caRelativeValidity,
|
|
NotAfterRelative: caRelativeValidity,
|
|
},
|
|
},
|
|
})
|
|
}
|