package reverse_dns

import (
	"context"
	"errors"
	"net"
	"sync"
	"sync/atomic"
	"time"

	"golang.org/x/sync/semaphore"
)

const defaultMaxWorkers = 10

var (
	ErrTimeout = errors.New("request timed out")
)

// AnyResolver is the subset of *net.Resolver used by this cache; any type
// with a matching LookupAddr method can be supplied in its place.
type AnyResolver interface {
	LookupAddr(ctx context.Context, addr string) (names []string, err error)
}
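
// A minimal stub sketch (hypothetical, for tests only): any type with this
// method satisfies AnyResolver, so lookups can be faked without the network.
//
//	type fakeResolver struct{}
//
//	func (fakeResolver) LookupAddr(_ context.Context, _ string) ([]string, error) {
//		return []string{"example.com."}, nil
//	}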

// ReverseDNSCache is safe to use across multiple goroutines.
// If multiple goroutines request the same IP at the same time, one of the
// requests will trigger the lookup and the rest will wait for its response.
type ReverseDNSCache struct {
	Resolver AnyResolver
	stats    RDNSCacheStats

	// settings
	ttl           time.Duration
	lookupTimeout time.Duration
	maxWorkers    int

	// internal
	rwLock              sync.RWMutex
	sem                 *semaphore.Weighted
	cancelCleanupWorker context.CancelFunc

	cache map[string]*dnslookup

	// Keep an ordered list of what is due to expire; expired items are popped
	// off the front of the list, which avoids iterating over the whole map to
	// find things to do. Because entries are appended as their lookups
	// complete and all share the same TTL, expiry times are nondecreasing.
	// As a bonus, we only have to read the first item to know whether anything
	// in the map has expired.
	// expireListLock must be held for access.
	expireList     []*dnslookup
	expireListLock sync.Mutex
}

// RDNSCacheStats holds the cache's operational counters.
type RDNSCacheStats struct {
	CacheHit          uint64
	CacheMiss         uint64
	CacheExpire       uint64
	RequestsAbandoned uint64
	RequestsFilled    uint64
}

// NewReverseDNSCache returns a cache with the given entry TTL, per-lookup
// timeout, and worker pool size (a non-positive size falls back to
// defaultMaxWorkers), and starts the background cleanup worker.
func NewReverseDNSCache(ttl, lookupTimeout time.Duration, workerPoolSize int) *ReverseDNSCache {
	if workerPoolSize <= 0 {
		workerPoolSize = defaultMaxWorkers
	}
	ctx, cancel := context.WithCancel(context.Background())
	d := &ReverseDNSCache{
		ttl:                 ttl,
		lookupTimeout:       lookupTimeout,
		cache:               make(map[string]*dnslookup),
		maxWorkers:          workerPoolSize,
		sem:                 semaphore.NewWeighted(int64(workerPoolSize)),
		cancelCleanupWorker: cancel,
		Resolver:            net.DefaultResolver,
	}
	d.startCleanupWorker(ctx)
	return d
}
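
// Usage sketch (illustrative values; pick a TTL, timeout, and pool size that
// fit the deployment):
//
//	cache := NewReverseDNSCache(24*time.Hour, 1*time.Second, 10)
//	defer cache.Stop()
//
//	domains, err := cache.Lookup("8.8.8.8")
//	if err != nil {
//		// ErrTimeout or a resolver error; treat the IP as unresolved.
//	}
//	_ = domains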

// dnslookup represents a lookup request/response. It may or may not be
// answered yet. Interested parties register themselves with existing requests
// or create new ones to get their DNS query answered. Answers are pushed out
// to the registered callbacks.
type dnslookup struct {
	ip        string // keep a copy for the expireList.
	domains   []string
	expiresAt time.Time
	completed bool
	callbacks []callbackChannelType
}

type lookupResult struct {
	domains []string
	err     error
}

type callbackChannelType chan lookupResult

// Lookup takes a string representing a parseable IPv4 or IPv6 address and
// blocks until it has resolved to 0-n results or the lookup timeout has
// elapsed. If the timeout elapses, it returns a nil slice and ErrTimeout.
func (d *ReverseDNSCache) Lookup(ip string) ([]string, error) {
	if len(ip) == 0 {
		return nil, nil
	}

	// Check whether the value is already cached.
	d.rwLock.RLock()
	result, found := d.lockedGetFromCache(ip)
	if found && result.completed && !result.expiresAt.Before(time.Now()) {
		defer d.rwLock.RUnlock()
		atomic.AddUint64(&d.stats.CacheHit, 1)
		// cache entry is valid
		return result.domains, nil
	}
	d.rwLock.RUnlock()

	// Not cached: kick off a lookup job and subscribe to the result.
	lookupChan := d.subscribeTo(ip)
	timer := time.NewTimer(d.lookupTimeout)
	defer timer.Stop()

	// The timer is still necessary even though doLookup respects the timeout:
	// worker-pool starvation can delay the answer past the deadline.
	select {
	case result := <-lookupChan:
		return result.domains, result.err
	case <-timer.C:
		return nil, ErrTimeout
	}
}
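
// Coalescing sketch (illustrative; 198.51.100.7 is a documentation address):
// concurrent lookups for the same IP share one resolver call, so the snippet
// below registers a single cache miss and the remaining callers count as hits.
//
//	var wg sync.WaitGroup
//	for i := 0; i < 5; i++ {
//		wg.Add(1)
//		go func() {
//			defer wg.Done()
//			_, _ = cache.Lookup("198.51.100.7")
//		}()
//	}
//	wg.Wait()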

// subscribeTo returns a channel that will receive the lookup result for ip,
// registering a new lookup job if one is not already cached or in flight.
func (d *ReverseDNSCache) subscribeTo(ip string) callbackChannelType {
	callback := make(callbackChannelType, 1)

	d.rwLock.Lock()
	defer d.rwLock.Unlock()

	// Confirm it's still not in the cache; this must be re-checked under the
	// write lock.
	result, found := d.lockedGetFromCache(ip)
	if found {
		atomic.AddUint64(&d.stats.CacheHit, 1)
		// Has the request been answered since we last checked?
		if result.completed {
			// We can return the answer with the channel.
			callback <- lookupResult{domains: result.domains}
			return callback
		}
		// There's a request but it hasn't been answered yet;
		// add ourselves to the subscribers and return the channel.
		result.callbacks = append(result.callbacks, callback)
		d.lockedSaveToCache(result)
		return callback
	}

	atomic.AddUint64(&d.stats.CacheMiss, 1)

	// Otherwise we need to register the request.
	l := &dnslookup{
		ip:        ip,
		expiresAt: time.Now().Add(d.ttl),
		callbacks: []callbackChannelType{callback},
	}

	d.lockedSaveToCache(l)
	go d.doLookup(l.ip)
	return callback
}

// lockedGetFromCache fetches from the internal IP cache, treating expired
// entries as absent. The caller MUST hold rwLock (read or write) before
// calling it, and keep it held while using the returned dnslookup.
func (d *ReverseDNSCache) lockedGetFromCache(ip string) (lookup *dnslookup, found bool) {
	lookup, found = d.cache[ip]
	if found && !lookup.expiresAt.After(time.Now()) {
		return nil, false
	}
	return lookup, found
}

// lockedSaveToCache stores a lookup in the internal IP cache, unless it has
// already expired. The caller MUST hold rwLock for writing before calling it.
func (d *ReverseDNSCache) lockedSaveToCache(lookup *dnslookup) {
	if !lookup.expiresAt.After(time.Now()) {
		return // don't cache.
	}
	d.cache[lookup.ip] = lookup
}

func (d *ReverseDNSCache) startCleanupWorker(ctx context.Context) {
	go func() {
		cleanupTick := time.NewTicker(10 * time.Second)
		defer cleanupTick.Stop() // release the ticker when the worker exits.
		for {
			select {
			case <-cleanupTick.C:
				d.cleanup()
			case <-ctx.Done():
				return
			}
		}
	}()
}

func (d *ReverseDNSCache) doLookup(ip string) {
	ctx, cancel := context.WithTimeout(context.Background(), d.lookupTimeout)
	defer cancel()
	if err := d.sem.Acquire(ctx, 1); err != nil {
		// Timed out waiting for a free worker slot.
		d.abandonLookup(ip, ErrTimeout)
		return
	}
	defer d.sem.Release(1)

	names, err := d.Resolver.LookupAddr(ctx, ip)
	if err != nil {
		d.abandonLookup(ip, err)
		return
	}

	d.rwLock.Lock()
	lookup, found := d.lockedGetFromCache(ip)
	if !found {
		d.rwLock.Unlock()
		return
	}

	lookup.domains = names
	lookup.completed = true
	lookup.expiresAt = time.Now().Add(d.ttl) // extend the TTL now that we have a reply.
	callbacks := lookup.callbacks
	lookup.callbacks = nil

	d.lockedSaveToCache(lookup)
	d.rwLock.Unlock()

	d.expireListLock.Lock()
	// Add it to the expireList.
	d.expireList = append(d.expireList, lookup)
	d.expireListLock.Unlock()

	atomic.AddUint64(&d.stats.RequestsFilled, uint64(len(callbacks)))
	for _, cb := range callbacks {
		cb <- lookupResult{domains: names}
		close(cb)
	}
}

// abandonLookup fails the pending lookup for ip, notifying every subscriber
// of err and dropping the entry from the cache.
func (d *ReverseDNSCache) abandonLookup(ip string, err error) {
	d.rwLock.Lock()
	lookup, found := d.lockedGetFromCache(ip)
	if !found {
		d.rwLock.Unlock()
		return
	}

	callbacks := lookup.callbacks
	delete(d.cache, lookup.ip)
	d.rwLock.Unlock()

	// Resolve the remaining callbacks to free their resources.
	atomic.AddUint64(&d.stats.RequestsAbandoned, uint64(len(callbacks)))
	for _, cb := range callbacks {
		cb <- lookupResult{err: err}
		close(cb)
	}
}

func (d *ReverseDNSCache) cleanup() {
	now := time.Now()
	d.expireListLock.Lock()
	if len(d.expireList) == 0 {
		d.expireListLock.Unlock()
		return
	}
	ipsToDelete := make([]string, 0, len(d.expireList))
	for i := 0; i < len(d.expireList); i++ {
		if !d.expireList[i].expiresAt.Before(now) {
			break // done. Nothing after this point is expired.
		}
		ipsToDelete = append(ipsToDelete, d.expireList[i].ip)
	}
	if len(ipsToDelete) == 0 {
		d.expireListLock.Unlock()
		return
	}
	d.expireList = d.expireList[len(ipsToDelete):]
	d.expireListLock.Unlock()

	atomic.AddUint64(&d.stats.CacheExpire, uint64(len(ipsToDelete)))

	d.rwLock.Lock()
	defer d.rwLock.Unlock()
	for _, ip := range ipsToDelete {
		delete(d.cache, ip)
	}
}

// Stats returns a snapshot of the cache counters. Each field is read
// atomically, but the snapshot as a whole is not taken under a lock.
func (d *ReverseDNSCache) Stats() RDNSCacheStats {
	stats := RDNSCacheStats{}
	stats.CacheHit = atomic.LoadUint64(&d.stats.CacheHit)
	stats.CacheMiss = atomic.LoadUint64(&d.stats.CacheMiss)
	stats.CacheExpire = atomic.LoadUint64(&d.stats.CacheExpire)
	stats.RequestsAbandoned = atomic.LoadUint64(&d.stats.RequestsAbandoned)
	stats.RequestsFilled = atomic.LoadUint64(&d.stats.RequestsFilled)
	return stats
}
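
// Reading the counters, a small sketch (guard the denominator in real code):
//
//	s := cache.Stats()
//	if total := s.CacheHit + s.CacheMiss; total > 0 {
//		hitRate := float64(s.CacheHit) / float64(total)
//		_ = hitRate
//	}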

// Stop shuts down the background cleanup worker.
func (d *ReverseDNSCache) Stop() {
	d.cancelCleanupWorker()
}