api/pkg/cert/cache_test.go

302 lines
6.8 KiB
Go
Raw Normal View History

package cert
import (
"crypto/tls"
"crypto/x509"
"fmt"
"os"
"path"
"reflect"
"testing"
"time"
"github.com/golang/mock/gomock"
certmock "gitlab.com/inetmock/inetmock/internal/mock/cert"
"gitlab.com/inetmock/inetmock/pkg/config"
)
// Common names used for the certificates generated by these tests.
const (
	// cnLocalhost is the common name given to generated server certificates.
	cnLocalhost = "localhost"
	// caCN is the common name given to the generated CA certificate.
	caCN = "UnitTests"
)

// Validity spans handed to setupCertGen, which uses each one for both
// NotBeforeRelative and NotAfterRelative of its purpose.
const (
	serverRelativeValidity = 24 * time.Hour
	caRelativeValidity     = 7 * 24 * time.Hour
)

// serverCN is cnLocalhost suffixed with the Unix timestamp (seconds) taken at
// package initialisation, so PEM files written under this name differ between
// runs started in different seconds.
var serverCN = fmt.Sprintf("%s-%d", cnLocalhost, time.Now().Unix())
// Test_certShouldBeRenewed drives certShouldBeRenewed with a mocked clock.
// Each case uses a certificate valid from now-serverRelativeValidity to
// now+serverRelativeValidity (48h in total). At "now" it must not be renewed.
// At now+13h only 11h remain, which is less than a quarter (12h) of the
// lifetime, so it must be renewed.
func Test_certShouldBeRenewed(t *testing.T) {
	ctrl := gomock.NewController(t)
	defer ctrl.Finish()

	// clockAt builds a TimeSource mock whose UTCNow must be called exactly
	// once and then reports the given instant.
	clockAt := func(instant time.Time) TimeSource {
		clock := certmock.NewMockTimeSource(ctrl)
		clock.EXPECT().UTCNow().Return(instant).Times(1)
		return clock
	}

	// certAroundNow builds a certificate centred on the current time with
	// serverRelativeValidity on either side.
	certAroundNow := func() *x509.Certificate {
		return &x509.Certificate{
			NotAfter:  time.Now().UTC().Add(serverRelativeValidity),
			NotBefore: time.Now().UTC().Add(-serverRelativeValidity),
		}
	}

	type args struct {
		timeSource TimeSource
		cert       *x509.Certificate
	}
	cases := []struct {
		name string
		args args
		want bool
	}{
		{
			name: "Expect cert should not be renewed right after creation",
			args: args{
				timeSource: clockAt(time.Now().UTC()),
				cert:       certAroundNow(),
			},
			want: false,
		},
		{
			name: "Expect cert should be renewed if the remaining lifetime is less than a quarter of the total lifetime",
			args: args{
				timeSource: clockAt(time.Now().UTC().Add(serverRelativeValidity/2 + 1*time.Hour)),
				cert:       certAroundNow(),
			},
			want: true,
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			got := certShouldBeRenewed(tc.args.timeSource, tc.args.cert)
			if got != tc.want {
				t.Errorf("certShouldBeRenewed() = %v, want %v", got, tc.want)
			}
		})
	}
}
// Test_fileSystemCache_Get checks that fileSystemCache.Get reports a miss for
// a common name nothing was stored under, and a hit both for a certificate
// pre-seeded in the in-memory cache and for one written as PEM files into the
// cache directory.
//
// Every case gets its own directory from t.TempDir(). Files left behind in the
// shared system temp directory therefore cannot turn the expected miss into a
// hit. The testing package also deletes the directories, including the PEM
// files written by the file-system case, when the test finishes.
func Test_fileSystemCache_Get(t *testing.T) {
	type fields struct {
		certCachePath string
		inMemCache    map[string]*tls.Certificate
		timeSource    TimeSource
	}
	type args struct {
		cn string
	}
	tests := []struct {
		name   string
		fields fields
		args   args
		wantOk bool
	}{
		{
			name: "Get a miss when no cert is present",
			fields: fields{
				// Empty, test-private directory: guarantees no <cn>.pem exists.
				certCachePath: t.TempDir(),
				inMemCache:    make(map[string]*tls.Certificate),
				timeSource:    NewTimeSource(),
			},
			args: args{
				cn: cnLocalhost,
			},
			wantOk: false,
		},
		{
			name: "Get a prepared certificate from the memory cache",
			fields: func() fields {
				certGen := setupCertGen()
				caCrt, err := certGen.CACert(GenerationOptions{
					CommonName: caCN,
				})
				if err != nil {
					t.Fatalf("CACert() error = %v", err)
				}
				srvCrt, err := certGen.ServerCert(GenerationOptions{
					CommonName: cnLocalhost,
				}, caCrt)
				if err != nil {
					t.Fatalf("ServerCert() error = %v", err)
				}
				return fields{
					certCachePath: t.TempDir(),
					inMemCache: map[string]*tls.Certificate{
						cnLocalhost: srvCrt,
					},
					timeSource: NewTimeSource(),
				}
			}(),
			args: args{
				cn: cnLocalhost,
			},
			wantOk: true,
		},
		{
			name: "Get a prepared certificate from the file system",
			fields: func() fields {
				certDir := t.TempDir()
				certGen := setupCertGen()
				caCrt, err := certGen.CACert(GenerationOptions{
					CommonName: "INetMock",
				})
				if err != nil {
					t.Fatalf("CACert() error = %v", err)
				}
				srvCrt, err := certGen.ServerCert(GenerationOptions{
					CommonName: serverCN,
				}, caCrt)
				if err != nil {
					t.Fatalf("ServerCert() error = %v", err)
				}
				// Persist the certificate so Get has to load it from certDir;
				// the in-memory cache below is deliberately empty.
				pem := NewPEM(srvCrt)
				if err := pem.Write(serverCN, certDir); err != nil {
					t.Fatalf("writing PEM files for %q: %v", serverCN, err)
				}
				return fields{
					certCachePath: certDir,
					inMemCache:    make(map[string]*tls.Certificate),
					timeSource:    NewTimeSource(),
				}
			}(),
			args: args{
				cn: serverCN,
			},
			wantOk: true,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			f := &fileSystemCache{
				certCachePath: tt.fields.certCachePath,
				inMemCache:    tt.fields.inMemCache,
				timeSource:    tt.fields.timeSource,
			}
			gotCrt, gotOk := f.Get(tt.args.cn)
			// A hit must come with a usable *tls.Certificate.
			if gotOk && (gotCrt == nil || !reflect.DeepEqual(reflect.TypeOf(new(tls.Certificate)), reflect.TypeOf(gotCrt))) {
				t.Errorf("Wanted proper certificate but got %v", gotCrt)
			}
			if gotOk != tt.wantOk {
				t.Errorf("Get() gotOk = %v, want %v", gotOk, tt.wantOk)
			}
		})
	}
}
// Test_fileSystemCache_Put checks that fileSystemCache.Put rejects a nil and
// an empty certificate, and accepts a freshly generated server certificate.
//
// Failures while generating the fixture certificate abort the test
// immediately. Otherwise a nil certificate would reach Put, and the resulting
// failure would only show up as an unexpected Put() error.
func Test_fileSystemCache_Put(t *testing.T) {
	type fields struct {
		certCachePath string
		inMemCache    map[string]*tls.Certificate
		timeSource    TimeSource
	}
	type args struct {
		cert *tls.Certificate
	}
	tests := []struct {
		name    string
		fields  fields
		args    args
		wantErr bool
	}{
		{
			name: "Want error if nil cert is passed to put",
			fields: fields{
				certCachePath: os.TempDir(),
				inMemCache:    make(map[string]*tls.Certificate),
				timeSource:    NewTimeSource(),
			},
			args: args{
				cert: nil,
			},
			wantErr: true,
		},
		{
			name: "Want error if empty cert is passed to put",
			fields: fields{
				certCachePath: os.TempDir(),
				inMemCache:    make(map[string]*tls.Certificate),
				timeSource:    NewTimeSource(),
			},
			args: args{
				cert: &tls.Certificate{},
			},
			wantErr: true,
		},
		{
			name: "No error if valid cert is passed",
			fields: fields{
				certCachePath: os.TempDir(),
				inMemCache:    make(map[string]*tls.Certificate),
				timeSource:    NewTimeSource(),
			},
			args: args{
				cert: func() *tls.Certificate {
					gen := setupCertGen()
					ca, err := gen.CACert(GenerationOptions{
						CommonName: caCN,
					})
					if err != nil {
						t.Fatalf("CACert() error = %v", err)
					}
					srvCN := fmt.Sprintf("%s-%d", cnLocalhost, time.Now().Unix())
					// Remove the <cn>.pem/<cn>.key files that Put presumably writes
					// into certCachePath (os.TempDir()) for this certificate.
					t.Cleanup(func() {
						for _, f := range []string{
							path.Join(os.TempDir(), fmt.Sprintf("%s.pem", srvCN)),
							path.Join(os.TempDir(), fmt.Sprintf("%s.key", srvCN)),
						} {
							// Best effort: a missing file is not an error here.
							_ = os.Remove(f)
						}
					})
					srvCrt, err := gen.ServerCert(GenerationOptions{
						CommonName: srvCN,
					}, ca)
					if err != nil {
						t.Fatalf("ServerCert() error = %v", err)
					}
					return srvCrt
				}(),
			},
			wantErr: false,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			f := &fileSystemCache{
				certCachePath: tt.fields.certCachePath,
				inMemCache:    tt.fields.inMemCache,
				timeSource:    tt.fields.timeSource,
			}
			if err := f.Put(tt.args.cert); (err != nil) != tt.wantErr {
				t.Errorf("Put() error = %v, wantErr %v", err, tt.wantErr)
			}
		})
	}
}
// setupCertGen returns the default certificate generator configured for the
// tests. Server certificates use serverRelativeValidity and CA certificates
// use caRelativeValidity, each for both NotBeforeRelative and NotAfterRelative.
func setupCertGen() Generator {
	serverValidity := config.ValidityDuration{
		NotBeforeRelative: serverRelativeValidity,
		NotAfterRelative:  serverRelativeValidity,
	}
	caValidity := config.ValidityDuration{
		NotBeforeRelative: caRelativeValidity,
		NotAfterRelative:  caRelativeValidity,
	}
	opts := config.CertOptions{
		Validity: config.ValidityByPurpose{
			Server: serverValidity,
			CA:     caValidity,
		},
	}
	return NewDefaultGenerator(opts)
}