Fix UpdateTTL and add tests

This commit is contained in:
Peter 2021-07-07 22:17:56 +02:00
parent 368932a31d
commit 5a266a8897
Signed by: prskr
GPG key ID: C1DB5D2E8DB512F9
3 changed files with 161 additions and 0 deletions

View file

@ -33,6 +33,7 @@ type Entry struct {
Name string Name string
Address net.IP Address net.IP
timeout time.Time timeout time.Time
index int
} }
func (e Entry) WithTTL(ttl time.Duration) *Entry { func (e Entry) WithTTL(ttl time.Duration) *Entry {
@ -40,6 +41,10 @@ func (e Entry) WithTTL(ttl time.Duration) *Entry {
return &e return &e
} }
func (e Entry) TTL() time.Time {
return e.timeout
}
type Cache interface { type Cache interface {
PutRecord(host string, address net.IP) PutRecord(host string, address net.IP)
ForwardLookup(host string, resolver IPResolver) net.IP ForwardLookup(host string, resolver IPResolver) net.IP

View file

@ -44,6 +44,7 @@ func (f EvictionCallbackFunc) OnEvicted(evictedEntries []*Entry) {
type TTLQueue interface { type TTLQueue interface {
Push(name string, address net.IP, ttl time.Duration) *Entry Push(name string, address net.IP, ttl time.Duration) *Entry
Get(idx int) *Entry
UpdateTTL(e *Entry, newTTL time.Duration) UpdateTTL(e *Entry, newTTL time.Duration)
Evict() Evict()
PeekFront() *Entry PeekFront() *Entry
@ -99,7 +100,29 @@ func (t *ttlQueue) Get(idx int) *Entry {
} }
func (t *ttlQueue) UpdateTTL(e *Entry, newTTL time.Duration) { func (t *ttlQueue) UpdateTTL(e *Entry, newTTL time.Duration) {
t.modLock.Lock()
defer t.modLock.Unlock()
var (
length = len(t.virtual)
insertIdx int
)
e.timeout = time.Now().UTC().Add(newTTL) e.timeout = time.Now().UTC().Add(newTTL)
insertIdx = sort.Search(length, func(i int) bool {
return t.virtual[i].timeout.After(e.timeout)
})
if insertIdx == e.index {
return
}
for idx := e.index; idx+1 < insertIdx; idx++ {
t.virtual[idx] = t.virtual[idx+1]
t.virtual[idx].index = idx
}
t.virtual[insertIdx-1] = e
} }
func (t *ttlQueue) Evict() { func (t *ttlQueue) Evict() {
@ -120,6 +143,7 @@ func (t *ttlQueue) Push(name string, address net.IP, ttl time.Duration) *Entry {
Name: name, Name: name,
Address: address, Address: address,
timeout: time.Now().UTC().Add(ttl), timeout: time.Now().UTC().Add(ttl),
index: length,
} }
) )
@ -146,6 +170,7 @@ func (t *ttlQueue) Push(name string, address net.IP, ttl time.Duration) *Entry {
t.virtual = append(t.virtual, nil) t.virtual = append(t.virtual, nil)
copy(t.virtual[insertIdx+1:], t.virtual[insertIdx:]) copy(t.virtual[insertIdx+1:], t.virtual[insertIdx:])
t.virtual[insertIdx] = entry t.virtual[insertIdx] = entry
entry.index = insertIdx
} }
return entry return entry

View file

@ -176,3 +176,134 @@ func Test_ttlQueue_Push(t *testing.T) {
}) })
} }
} }
func Test_ttlQueue_UpdateTTL(t1 *testing.T) {
t1.Parallel()
type seedEntry struct {
host string
ip net.IP
ttl time.Duration
}
type fields struct {
initialCapacity int
seeds []seedEntry
}
type args struct {
idxToUpdate int
newTTL time.Duration
}
tests := []struct {
name string
fields fields
args args
}{
{
name: "Single element",
fields: fields{
initialCapacity: 10,
seeds: []seedEntry{
{
host: "mail.gogle.ru",
ip: net.ParseIP("1.2.3.4"),
ttl: 100 * time.Millisecond,
},
},
},
args: args{
idxToUpdate: 0,
newTTL: 200 * time.Millisecond,
},
},
{
name: "Last element",
fields: fields{
initialCapacity: 10,
seeds: []seedEntry{
{
host: "mail.gogle.ru",
ip: net.ParseIP("1.2.3.4"),
ttl: 100 * time.Millisecond,
},
{
host: "asdf.gogle.ru",
ip: net.ParseIP("1.2.3.5"),
ttl: 120 * time.Millisecond,
},
},
},
args: args{
idxToUpdate: 1,
newTTL: 200 * time.Millisecond,
},
},
{
name: "Switch elements",
fields: fields{
initialCapacity: 10,
seeds: []seedEntry{
{
host: "mail.gogle.ru",
ip: net.ParseIP("1.2.3.4"),
ttl: 50 * time.Millisecond,
},
{
host: "asdf.gogle.ru",
ip: net.ParseIP("1.2.3.5"),
ttl: 150 * time.Millisecond,
},
},
},
args: args{
idxToUpdate: 0,
newTTL: 500 * time.Millisecond,
},
},
{
name: "Switch elements",
fields: fields{
initialCapacity: 10,
seeds: []seedEntry{
{
host: "mail.gogle.ru",
ip: net.ParseIP("1.2.3.4"),
ttl: 50 * time.Millisecond,
},
{
host: "honey.gogle.ru",
ip: net.ParseIP("1.2.3.4"),
ttl: 100 * time.Millisecond,
},
{
host: "asdf.gogle.ru",
ip: net.ParseIP("1.2.3.5"),
ttl: 150 * time.Millisecond,
},
},
},
args: args{
idxToUpdate: 1,
newTTL: 500 * time.Millisecond,
},
},
}
for _, tt := range tests {
tt := tt
t1.Run(tt.name, func(t1 *testing.T) {
t1.Parallel()
var queue = dns.NewQueue(tt.fields.initialCapacity)
for i := range tt.fields.seeds {
seed := tt.fields.seeds[i]
_ = queue.Push(seed.host, seed.ip, seed.ttl)
}
var entry = queue.Get(tt.args.idxToUpdate)
queue.UpdateTTL(entry, tt.args.newTTL)
var current = queue.PeekFront().TTL()
for i := 0; i < queue.Len(); i++ {
if current.After(queue.Get(i).TTL()) {
t1.Errorf("TTLs in wrong order got = %v, current = %v", queue.Get(i).TTL(), current)
}
}
})
}
}