1
0
Fork 0
telegraf/plugins/inputs/dns_query/dns_query.go

255 lines
5.3 KiB
Go
Raw Permalink Normal View History

//go:generate ../../../tools/readme_config_includer/generator
package dns_query
import (
_ "embed"
"errors"
"fmt"
"net"
"slices"
"strconv"
"sync"
"time"
"github.com/miekg/dns"
"github.com/influxdata/telegraf"
"github.com/influxdata/telegraf/config"
"github.com/influxdata/telegraf/plugins/inputs"
)
//go:embed sample.conf
var sampleConfig string
// ignoredErrors lists the response codes, as stored in the "rcode" tag, for
// which Gather does not report the query error to the accumulator. The metric
// itself is still emitted for those responses.
var ignoredErrors = []string{
"NXDOMAIN",
}
// resultType is the numeric outcome of a query; it is emitted (converted to
// uint64) in the "result_code" field of each metric.
type resultType uint64
// The declaration order fixes the emitted values via iota:
// successResult = 0, timeoutResult = 1, errorResult = 2.
const (
successResult resultType = iota
timeoutResult
errorResult
)
// DNSQuery is the plugin configuration and state. Gather sends one query per
// (domain, server) pair and emits a "dns_query" metric for each of them.
type DNSQuery struct {
// Domains are the names to query; Init defaults this to "." when empty.
Domains []string `toml:"domains"`
// Network is handed to the DNS client as its network; Init defaults it to "udp".
Network string `toml:"network"`
// Servers are the DNS server hosts each domain is queried against.
Servers []string `toml:"servers"`
// RecordType is the record type name to query (see parseRecordType for the
// accepted names); Init defaults it to "NS".
RecordType string `toml:"record_type"`
// Port is the server port; Init replaces values below 1 with 53.
Port int `toml:"port"`
// Timeout is used as the read timeout of the DNS client.
Timeout config.Duration `toml:"timeout"`
// IncludeFields enables optional fields; Init accepts only "first_ip" and "all_ips".
IncludeFields []string `toml:"include_fields"`
// fieldEnabled is the lookup-table built from IncludeFields in Init.
fieldEnabled map[string]bool
}
// SampleConfig returns the sample configuration embedded from sample.conf.
func (*DNSQuery) SampleConfig() string {
return sampleConfig
}
// Init validates the configured include_fields and fills in defaults for every
// option that was left unset. It returns an error for an unknown field name.
func (d *DNSQuery) Init() error {
	// Build the lookup-table of optional fields, rejecting unknown names
	d.fieldEnabled = make(map[string]bool, len(d.IncludeFields))
	for _, name := range d.IncludeFields {
		if name != "first_ip" && name != "all_ips" {
			return fmt.Errorf("invalid field %q included", name)
		}
		d.fieldEnabled[name] = true
	}

	// Fall back to an NS query for "." when no domain is configured; this
	// overrides any configured record type.
	if len(d.Domains) == 0 {
		d.Domains = []string{"."}
		d.RecordType = "NS"
	}
	if d.RecordType == "" {
		d.RecordType = "NS"
	}

	// Transport defaults
	if d.Port < 1 {
		d.Port = 53
	}
	if d.Network == "" {
		d.Network = "udp"
	}

	return nil
}
// Gather queries every configured domain against every configured server, one
// goroutine per pair, and adds a "dns_query" metric for each pair once the
// query returns. It waits for all queries to finish and always returns nil;
// query failures are reported through the accumulator instead.
func (d *DNSQuery) Gather(acc telegraf.Accumulator) error {
	var wg sync.WaitGroup

	worker := func(domain, server string) {
		defer wg.Done()

		fields, tags, err := d.query(domain, server)
		if err != nil {
			// Ignored rcodes and timeouts are visible in the metric itself
			// and are therefore not reported as errors.
			var opErr *net.OpError
			ignored := slices.Contains(ignoredErrors, tags["rcode"])
			timedOut := errors.As(err, &opErr) && opErr.Timeout()
			if !ignored && !timedOut {
				acc.AddError(err)
			}
		}
		acc.AddFields("dns_query", fields, tags)
	}

	for _, domain := range d.Domains {
		for _, server := range d.Servers {
			wg.Add(1)
			go worker(domain, server)
		}
	}
	wg.Wait()

	return nil
}
// query sends a single question for domain to server and returns the fields
// and tags of the resulting metric. Both maps are always returned, also on
// error, in which case they describe the failure: "result" is "timeout" for
// network timeouts and "error" otherwise. An error is returned if the record
// type is unknown, the exchange fails, or the response code is not a success.
func (d *DNSQuery) query(domain, server string) (map[string]interface{}, map[string]string, error) {
	// Start with the values of a failed query; they are overwritten as the
	// query progresses.
	fields := map[string]interface{}{
		"query_time_ms": float64(0),
		"result_code":   uint64(errorResult),
	}
	tags := map[string]string{
		"server":      server,
		"domain":      domain,
		"record_type": d.RecordType,
		"result":      "error",
	}

	qtype, err := d.parseRecordType()
	if err != nil {
		return fields, tags, err
	}

	// Build the request
	request := new(dns.Msg)
	request.SetQuestion(dns.Fqdn(domain), qtype)
	request.RecursionDesired = true

	// Send it to the server
	client := dns.Client{
		Net:         d.Network,
		ReadTimeout: time.Duration(d.Timeout),
	}
	target := net.JoinHostPort(server, strconv.Itoa(d.Port))
	response, rtt, err := client.Exchange(request, target)
	if err != nil {
		// Only network timeouts get their own result; everything else keeps
		// the generic error result set above.
		var opErr *net.OpError
		if errors.As(err, &opErr) && opErr.Timeout() {
			tags["result"] = "timeout"
			fields["result_code"] = uint64(timeoutResult)
		}
		return fields, tags, err
	}

	// A response arrived, so the response code and timing are valid
	rcode := dns.RcodeToString[response.Rcode]
	tags["rcode"] = rcode
	fields["rcode_value"] = response.Rcode
	fields["query_time_ms"] = float64(rtt.Nanoseconds()) / 1e6

	if response.Rcode != dns.RcodeSuccess {
		return fields, tags, fmt.Errorf("invalid answer (%s) from %s after %s query for %s", rcode, server, d.RecordType, domain)
	}

	tags["result"] = "success"
	fields["result_code"] = uint64(successResult)

	// Record-type specific fields; later records overwrite earlier ones
	for _, rr := range response.Answer {
		switch rec := rr.(type) {
		case *dns.A, *dns.AAAA, *dns.CNAME:
			fields["name"] = rec.Header().Name
		case *dns.MX:
			fields["name"] = rec.Hdr.Name
			fields["preference"] = rec.Preference
		case *dns.SOA:
			fields["name"] = rec.Hdr.Name
			fields["expire"] = rec.Expire
			fields["minttl"] = rec.Minttl
			fields["refresh"] = rec.Refresh
			fields["retry"] = rec.Retry
			fields["serial"] = rec.Serial
		}
	}

	// Optional IP fields: "ip" holds the first address found, "ip_<n>" holds
	// each address keyed by its position in the answer section.
	wantAll := d.fieldEnabled["all_ips"]
	needFirst := d.fieldEnabled["first_ip"]
	if needFirst || wantAll {
		for idx, rr := range response.Answer {
			ip, ok := extractIP(rr)
			if !ok {
				continue
			}
			if needFirst {
				fields["ip"] = ip
				needFirst = false
			}
			if wantAll {
				fields["ip_"+strconv.Itoa(idx)] = ip
			}
		}
	}

	return fields, tags, nil
}
// parseRecordType translates the configured record type name into the numeric
// type used in the question. It returns zero and an error for any name other
// than the ones listed below.
func (d *DNSQuery) parseRecordType() (uint16, error) {
	switch d.RecordType {
	case "A":
		return dns.TypeA, nil
	case "AAAA":
		return dns.TypeAAAA, nil
	case "ANY":
		return dns.TypeANY, nil
	case "CNAME":
		return dns.TypeCNAME, nil
	case "MX":
		return dns.TypeMX, nil
	case "NS":
		return dns.TypeNS, nil
	case "PTR":
		return dns.TypePTR, nil
	case "SOA":
		return dns.TypeSOA, nil
	case "SPF":
		return dns.TypeSPF, nil
	case "SRV":
		return dns.TypeSRV, nil
	case "TXT":
		return dns.TypeTXT, nil
	}

	return 0, fmt.Errorf("record type %s not recognized", d.RecordType)
}
// extractIP returns the address of an A or AAAA record in string form. The
// boolean is false, and the string empty, for any other record.
func extractIP(record dns.RR) (string, bool) {
	switch rr := record.(type) {
	case *dns.A:
		return rr.A.String(), true
	case *dns.AAAA:
		return rr.AAAA.String(), true
	}
	return "", false
}
func init() {
inputs.Add("dns_query", func() telegraf.Input {
return &DNSQuery{
Timeout: config.Duration(2 * time.Second),
}
})
}