Adding upstream version 1.34.4.
Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
parent
e393c3af3f
commit
4978089aab
4963 changed files with 677545 additions and 0 deletions
267
plugins/common/socket/datagram.go
Normal file
267
plugins/common/socket/datagram.go
Normal file
|
@ -0,0 +1,267 @@
|
|||
package socket
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/alitto/pond"
|
||||
|
||||
"github.com/influxdata/telegraf"
|
||||
"github.com/influxdata/telegraf/config"
|
||||
"github.com/influxdata/telegraf/internal"
|
||||
)
|
||||
|
||||
type packetListener struct {
|
||||
Encoding string
|
||||
MaxDecompressionSize int64
|
||||
SocketMode string
|
||||
ReadBufferSize int
|
||||
Log telegraf.Logger
|
||||
|
||||
conn net.PacketConn
|
||||
decoders sync.Pool
|
||||
path string
|
||||
wg sync.WaitGroup
|
||||
parsePool *pond.WorkerPool
|
||||
}
|
||||
|
||||
func newPacketListener(encoding string, maxDecompressionSize config.Size, maxWorkers int) *packetListener {
|
||||
return &packetListener{
|
||||
Encoding: encoding,
|
||||
MaxDecompressionSize: int64(maxDecompressionSize),
|
||||
parsePool: pond.New(maxWorkers, 0, pond.MinWorkers(maxWorkers/2+1)),
|
||||
}
|
||||
}
|
||||
|
||||
func (l *packetListener) listenData(onData CallbackData, onError CallbackError) {
|
||||
l.wg.Add(1)
|
||||
|
||||
go func() {
|
||||
defer l.wg.Done()
|
||||
|
||||
buf := make([]byte, l.ReadBufferSize)
|
||||
for {
|
||||
n, src, err := l.conn.ReadFrom(buf)
|
||||
receiveTime := time.Now()
|
||||
if err != nil {
|
||||
if !strings.HasSuffix(err.Error(), ": use of closed network connection") {
|
||||
if onError != nil {
|
||||
onError(err)
|
||||
}
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
d := make([]byte, n)
|
||||
copy(d, buf[:n])
|
||||
l.parsePool.Submit(func() {
|
||||
decoder := l.decoders.Get().(internal.ContentDecoder)
|
||||
defer l.decoders.Put(decoder)
|
||||
body, err := decoder.Decode(d)
|
||||
if err != nil && onError != nil {
|
||||
onError(fmt.Errorf("unable to decode incoming packet: %w", err))
|
||||
}
|
||||
|
||||
if l.path != "" {
|
||||
src = &net.UnixAddr{Name: l.path, Net: "unixgram"}
|
||||
}
|
||||
|
||||
onData(src, body, receiveTime)
|
||||
})
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (l *packetListener) listenConnection(onConnection CallbackConnection, onError CallbackError) {
|
||||
l.wg.Add(1)
|
||||
go func() {
|
||||
defer l.wg.Done()
|
||||
defer l.conn.Close()
|
||||
|
||||
buf := make([]byte, l.ReadBufferSize)
|
||||
for {
|
||||
// Wait for packets and read them
|
||||
n, src, err := l.conn.ReadFrom(buf)
|
||||
if err != nil {
|
||||
if !strings.HasSuffix(err.Error(), ": use of closed network connection") {
|
||||
if onError != nil {
|
||||
onError(err)
|
||||
}
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
d := make([]byte, n)
|
||||
copy(d, buf[:n])
|
||||
l.parsePool.Submit(func() {
|
||||
// Decode the contents depending on the given encoding
|
||||
decoder := l.decoders.Get().(internal.ContentDecoder)
|
||||
// Not possible to immediately return the decoder to the Pool after calling Decode, because some
|
||||
// decoders return a reference to their internal buffers. This would cause data races.
|
||||
defer l.decoders.Put(decoder)
|
||||
body, err := decoder.Decode(d[:n])
|
||||
if err != nil && onError != nil {
|
||||
onError(fmt.Errorf("unable to decode incoming packet: %w", err))
|
||||
}
|
||||
|
||||
// Workaround to provide remote endpoints for Unix-type sockets
|
||||
if l.path != "" {
|
||||
src = &net.UnixAddr{Name: l.path, Net: "unixgram"}
|
||||
}
|
||||
|
||||
// Create a pipe and notify the caller via Callback that new data is
|
||||
// available. Afterwards write the data. Please note: Write() will
|
||||
// block until all data is consumed!
|
||||
reader, writer := io.Pipe()
|
||||
go onConnection(src, reader)
|
||||
if _, err := writer.Write(body); err != nil && onError != nil {
|
||||
onError(err)
|
||||
}
|
||||
writer.Close()
|
||||
})
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (l *packetListener) setupUnixgram(u *url.URL, socketMode string, bufferSize int) error {
|
||||
l.path = filepath.FromSlash(u.Path)
|
||||
if runtime.GOOS == "windows" && strings.Contains(l.path, ":") {
|
||||
l.path = strings.TrimPrefix(l.path, `\`)
|
||||
}
|
||||
if err := os.Remove(l.path); err != nil && !errors.Is(err, os.ErrNotExist) {
|
||||
return fmt.Errorf("removing socket failed: %w", err)
|
||||
}
|
||||
|
||||
conn, err := net.ListenPacket(u.Scheme, l.path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("listening (unixgram) failed: %w", err)
|
||||
}
|
||||
l.conn = conn
|
||||
|
||||
// Set permissions on socket
|
||||
if socketMode != "" {
|
||||
// Convert from octal in string to int
|
||||
i, err := strconv.ParseUint(socketMode, 8, 32)
|
||||
if err != nil {
|
||||
return fmt.Errorf("converting socket mode failed: %w", err)
|
||||
}
|
||||
|
||||
perm := os.FileMode(uint32(i))
|
||||
if err := os.Chmod(u.Path, perm); err != nil {
|
||||
return fmt.Errorf("changing socket permissions failed: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if bufferSize > 0 {
|
||||
l.ReadBufferSize = bufferSize
|
||||
} else {
|
||||
l.ReadBufferSize = 64 * 1024 // 64kb - IP packet size
|
||||
}
|
||||
|
||||
return l.setupDecoder()
|
||||
}
|
||||
|
||||
func (l *packetListener) setupUDP(u *url.URL, ifname string, bufferSize int) error {
|
||||
var conn *net.UDPConn
|
||||
|
||||
addr, err := net.ResolveUDPAddr(u.Scheme, u.Host)
|
||||
if err != nil {
|
||||
return fmt.Errorf("resolving UDP address failed: %w", err)
|
||||
}
|
||||
if addr.IP.IsMulticast() {
|
||||
var iface *net.Interface
|
||||
if ifname != "" {
|
||||
var err error
|
||||
iface, err = net.InterfaceByName(ifname)
|
||||
if err != nil {
|
||||
return fmt.Errorf("resolving address of %q failed: %w", ifname, err)
|
||||
}
|
||||
}
|
||||
conn, err = net.ListenMulticastUDP(u.Scheme, iface, addr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("listening (udp multicast) failed: %w", err)
|
||||
}
|
||||
} else {
|
||||
conn, err = net.ListenUDP(u.Scheme, addr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("listening (udp) failed: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if bufferSize > 0 {
|
||||
if err := conn.SetReadBuffer(bufferSize); err != nil {
|
||||
l.Log.Warnf("Setting read buffer on %s socket failed: %v", u.Scheme, err)
|
||||
}
|
||||
}
|
||||
|
||||
l.ReadBufferSize = 64 * 1024 // 64kb - IP packet size
|
||||
l.conn = conn
|
||||
return l.setupDecoder()
|
||||
}
|
||||
|
||||
func (l *packetListener) setupIP(u *url.URL) error {
|
||||
conn, err := net.ListenPacket(u.Scheme, u.Host)
|
||||
if err != nil {
|
||||
return fmt.Errorf("listening (ip) failed: %w", err)
|
||||
}
|
||||
|
||||
l.ReadBufferSize = 64 * 1024 // 64kb - IP packet size
|
||||
l.conn = conn
|
||||
return l.setupDecoder()
|
||||
}
|
||||
|
||||
func (l *packetListener) setupDecoder() error {
|
||||
// Create a decoder for the given encoding
|
||||
var options []internal.DecodingOption
|
||||
if l.MaxDecompressionSize > 0 {
|
||||
options = append(options, internal.WithMaxDecompressionSize(l.MaxDecompressionSize))
|
||||
}
|
||||
|
||||
l.decoders = sync.Pool{New: func() any {
|
||||
decoder, err := internal.NewContentDecoder(l.Encoding, options...)
|
||||
if err != nil {
|
||||
l.Log.Errorf("creating decoder failed: %v", err)
|
||||
return nil
|
||||
}
|
||||
|
||||
return decoder
|
||||
}}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (l *packetListener) address() net.Addr {
|
||||
return l.conn.LocalAddr()
|
||||
}
|
||||
|
||||
func (l *packetListener) close() error {
|
||||
if err := l.conn.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
l.wg.Wait()
|
||||
|
||||
if l.path != "" {
|
||||
fn := filepath.FromSlash(l.path)
|
||||
if runtime.GOOS == "windows" && strings.Contains(fn, ":") {
|
||||
fn = strings.TrimPrefix(fn, `\`)
|
||||
}
|
||||
if err := os.Remove(fn); err != nil && !errors.Is(err, os.ErrNotExist) {
|
||||
// Ignore file-not-exists errors when removing the socket
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
l.parsePool.StopAndWait()
|
||||
|
||||
return nil
|
||||
}
|
37
plugins/common/socket/socket.conf
Normal file
37
plugins/common/socket/socket.conf
Normal file
|
@ -0,0 +1,37 @@
|
|||
## Permission for unix sockets (only available on unix sockets)
|
||||
## This setting may not be respected by some platforms. To safely restrict
|
||||
## permissions it is recommended to place the socket into a previously
|
||||
## created directory with the desired permissions.
|
||||
## ex: socket_mode = "777"
|
||||
# socket_mode = ""
|
||||
|
||||
## Maximum number of concurrent connections (only available on stream sockets like TCP)
|
||||
## Zero means unlimited.
|
||||
# max_connections = 0
|
||||
|
||||
## Read timeout (only available on stream sockets like TCP)
|
||||
## Zero means unlimited.
|
||||
# read_timeout = "0s"
|
||||
|
||||
## Optional TLS configuration (only available on stream sockets like TCP)
|
||||
# tls_cert = "/etc/telegraf/cert.pem"
|
||||
# tls_key = "/etc/telegraf/key.pem"
|
||||
## Enables client authentication if set.
|
||||
# tls_allowed_cacerts = ["/etc/telegraf/clientca.pem"]
|
||||
|
||||
## Maximum socket buffer size (in bytes when no unit specified)
|
||||
## For stream sockets, once the buffer fills up, the sender will start
|
||||
## backing up. For datagram sockets, once the buffer fills up, metrics will
|
||||
## start dropping. Defaults to the OS default.
|
||||
# read_buffer_size = "64KiB"
|
||||
|
||||
## Period between keep alive probes (only applies to TCP sockets)
|
||||
## Zero disables keep alive probes. Defaults to the OS configuration.
|
||||
# keep_alive_period = "5m"
|
||||
|
||||
## Content encoding for message payloads
|
||||
## Can be set to "gzip" for compressed payloads or "identity" for no encoding.
|
||||
# content_encoding = "identity"
|
||||
|
||||
## Maximum size of decoded packet (in bytes when no unit specified)
|
||||
# max_decompression_size = "500MB"
|
181
plugins/common/socket/socket.go
Normal file
181
plugins/common/socket/socket.go
Normal file
|
@ -0,0 +1,181 @@
|
|||
package socket
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/url"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/influxdata/telegraf"
|
||||
"github.com/influxdata/telegraf/config"
|
||||
common_tls "github.com/influxdata/telegraf/plugins/common/tls"
|
||||
)
|
||||
|
||||
type CallbackData func(net.Addr, []byte, time.Time)
|
||||
type CallbackConnection func(net.Addr, io.ReadCloser)
|
||||
type CallbackError func(error)
|
||||
|
||||
type listener interface {
|
||||
address() net.Addr
|
||||
listenData(CallbackData, CallbackError)
|
||||
listenConnection(CallbackConnection, CallbackError)
|
||||
close() error
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
MaxConnections uint64 `toml:"max_connections"`
|
||||
ReadBufferSize config.Size `toml:"read_buffer_size"`
|
||||
ReadTimeout config.Duration `toml:"read_timeout"`
|
||||
KeepAlivePeriod *config.Duration `toml:"keep_alive_period"`
|
||||
SocketMode string `toml:"socket_mode"`
|
||||
ContentEncoding string `toml:"content_encoding"`
|
||||
MaxDecompressionSize config.Size `toml:"max_decompression_size"`
|
||||
MaxParallelParsers int `toml:"max_parallel_parsers"`
|
||||
common_tls.ServerConfig
|
||||
}
|
||||
|
||||
type Socket struct {
|
||||
Config
|
||||
|
||||
url *url.URL
|
||||
interfaceName string
|
||||
tlsCfg *tls.Config
|
||||
log telegraf.Logger
|
||||
|
||||
splitter bufio.SplitFunc
|
||||
listener listener
|
||||
}
|
||||
|
||||
func (cfg *Config) NewSocket(address string, splitcfg *SplitConfig, logger telegraf.Logger) (*Socket, error) {
|
||||
s := &Socket{
|
||||
Config: *cfg,
|
||||
log: logger,
|
||||
}
|
||||
|
||||
// Setup the splitter if given
|
||||
if splitcfg != nil {
|
||||
splitter, err := splitcfg.NewSplitter()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
s.splitter = splitter
|
||||
}
|
||||
|
||||
// Resolve the interface to an address if any given
|
||||
ifregex := regexp.MustCompile(`%([\w\.]+)`)
|
||||
if matches := ifregex.FindStringSubmatch(address); len(matches) == 2 {
|
||||
s.interfaceName = matches[1]
|
||||
address = strings.Replace(address, "%"+s.interfaceName, "", 1)
|
||||
}
|
||||
|
||||
// Preparing TLS configuration
|
||||
tlsCfg, err := s.ServerConfig.TLSConfig()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("getting TLS config failed: %w", err)
|
||||
}
|
||||
s.tlsCfg = tlsCfg
|
||||
|
||||
// Parse and check the address
|
||||
u, err := url.Parse(address)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parsing address failed: %w", err)
|
||||
}
|
||||
s.url = u
|
||||
|
||||
switch s.url.Scheme {
|
||||
case "tcp", "tcp4", "tcp6", "unix", "unixpacket",
|
||||
"udp", "udp4", "udp6", "ip", "ip4", "ip6", "unixgram", "vsock":
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown protocol %q in %q", u.Scheme, address)
|
||||
}
|
||||
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func (s *Socket) Setup() error {
|
||||
s.MaxParallelParsers = max(s.MaxParallelParsers, 1)
|
||||
switch s.url.Scheme {
|
||||
case "tcp", "tcp4", "tcp6":
|
||||
l := newStreamListener(
|
||||
s.Config,
|
||||
s.splitter,
|
||||
s.log,
|
||||
)
|
||||
|
||||
if err := l.setupTCP(s.url, s.tlsCfg); err != nil {
|
||||
return err
|
||||
}
|
||||
s.listener = l
|
||||
case "unix", "unixpacket":
|
||||
l := newStreamListener(
|
||||
s.Config,
|
||||
s.splitter,
|
||||
s.log,
|
||||
)
|
||||
|
||||
if err := l.setupUnix(s.url, s.tlsCfg, s.SocketMode); err != nil {
|
||||
return err
|
||||
}
|
||||
s.listener = l
|
||||
case "udp", "udp4", "udp6":
|
||||
l := newPacketListener(s.ContentEncoding, s.MaxDecompressionSize, s.MaxParallelParsers)
|
||||
if err := l.setupUDP(s.url, s.interfaceName, int(s.ReadBufferSize)); err != nil {
|
||||
return err
|
||||
}
|
||||
s.listener = l
|
||||
case "ip", "ip4", "ip6":
|
||||
l := newPacketListener(s.ContentEncoding, s.MaxDecompressionSize, s.MaxParallelParsers)
|
||||
if err := l.setupIP(s.url); err != nil {
|
||||
return err
|
||||
}
|
||||
s.listener = l
|
||||
case "unixgram":
|
||||
l := newPacketListener(s.ContentEncoding, s.MaxDecompressionSize, s.MaxParallelParsers)
|
||||
if err := l.setupUnixgram(s.url, s.SocketMode, int(s.ReadBufferSize)); err != nil {
|
||||
return err
|
||||
}
|
||||
s.listener = l
|
||||
case "vsock":
|
||||
l := newStreamListener(
|
||||
s.Config,
|
||||
s.splitter,
|
||||
s.log,
|
||||
)
|
||||
|
||||
if err := l.setupVsock(s.url); err != nil {
|
||||
return err
|
||||
}
|
||||
s.listener = l
|
||||
default:
|
||||
return fmt.Errorf("unknown protocol %q", s.url.Scheme)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Socket) Listen(onData CallbackData, onError CallbackError) {
|
||||
s.listener.listenData(onData, onError)
|
||||
}
|
||||
|
||||
func (s *Socket) ListenConnection(onConnection CallbackConnection, onError CallbackError) {
|
||||
s.listener.listenConnection(onConnection, onError)
|
||||
}
|
||||
|
||||
func (s *Socket) Close() {
|
||||
if s.listener != nil {
|
||||
// Ignore the returned error as we cannot do anything about it anyway
|
||||
if err := s.listener.close(); err != nil {
|
||||
s.log.Warnf("Closing socket failed: %v", err)
|
||||
}
|
||||
s.listener = nil
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Socket) Address() net.Addr {
|
||||
return s.listener.address()
|
||||
}
|
845
plugins/common/socket/socket_test.go
Normal file
845
plugins/common/socket/socket_test.go
Normal file
|
@ -0,0 +1,845 @@
|
|||
package socket
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"os"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/influxdata/telegraf"
|
||||
"github.com/influxdata/telegraf/config"
|
||||
"github.com/influxdata/telegraf/internal"
|
||||
"github.com/influxdata/telegraf/metric"
|
||||
_ "github.com/influxdata/telegraf/plugins/parsers/all"
|
||||
"github.com/influxdata/telegraf/plugins/parsers/influx"
|
||||
"github.com/influxdata/telegraf/testutil"
|
||||
)
|
||||
|
||||
var pki = testutil.NewPKI("../../../testutil/pki")
|
||||
|
||||
func TestListenData(t *testing.T) {
|
||||
messages := [][]byte{
|
||||
[]byte("test,foo=bar v=1i 123456789\ntest,foo=baz v=2i 123456790\n"),
|
||||
[]byte("test,foo=zab v=3i 123456791\n"),
|
||||
}
|
||||
expectedTemplates := []telegraf.Metric{
|
||||
metric.New(
|
||||
"test",
|
||||
map[string]string{"foo": "bar"},
|
||||
map[string]interface{}{"v": int64(1)},
|
||||
time.Unix(0, 123456789),
|
||||
),
|
||||
metric.New(
|
||||
"test",
|
||||
map[string]string{"foo": "baz"},
|
||||
map[string]interface{}{"v": int64(2)},
|
||||
time.Unix(0, 123456790),
|
||||
),
|
||||
metric.New(
|
||||
"test",
|
||||
map[string]string{"foo": "zab"},
|
||||
map[string]interface{}{"v": int64(3)},
|
||||
time.Unix(0, 123456791),
|
||||
),
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
schema string
|
||||
buffersize config.Size
|
||||
encoding string
|
||||
}{
|
||||
{
|
||||
name: "TCP",
|
||||
schema: "tcp",
|
||||
buffersize: config.Size(1024),
|
||||
},
|
||||
{
|
||||
name: "TCP with TLS",
|
||||
schema: "tcp+tls",
|
||||
},
|
||||
{
|
||||
name: "TCP with gzip encoding",
|
||||
schema: "tcp",
|
||||
buffersize: config.Size(1024),
|
||||
encoding: "gzip",
|
||||
},
|
||||
{
|
||||
name: "UDP",
|
||||
schema: "udp",
|
||||
buffersize: config.Size(1024),
|
||||
},
|
||||
{
|
||||
name: "UDP with gzip encoding",
|
||||
schema: "udp",
|
||||
buffersize: config.Size(1024),
|
||||
encoding: "gzip",
|
||||
},
|
||||
{
|
||||
name: "unix socket",
|
||||
schema: "unix",
|
||||
buffersize: config.Size(1024),
|
||||
},
|
||||
{
|
||||
name: "unix socket with TLS",
|
||||
schema: "unix+tls",
|
||||
},
|
||||
{
|
||||
name: "unix socket with gzip encoding",
|
||||
schema: "unix",
|
||||
encoding: "gzip",
|
||||
},
|
||||
{
|
||||
name: "unixgram socket",
|
||||
schema: "unixgram",
|
||||
buffersize: config.Size(1024),
|
||||
},
|
||||
}
|
||||
|
||||
serverTLS := pki.TLSServerConfig()
|
||||
clientTLS := pki.TLSClientConfig()
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
proto := strings.TrimSuffix(tt.schema, "+tls")
|
||||
|
||||
// Prepare the address and socket if needed
|
||||
var sockPath string
|
||||
var serviceAddress string
|
||||
var tlsCfg *tls.Config
|
||||
switch proto {
|
||||
case "tcp", "udp":
|
||||
serviceAddress = proto + "://" + "127.0.0.1:0"
|
||||
case "unix", "unixgram":
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("Skipping on Windows, as unixgram sockets are not supported")
|
||||
}
|
||||
|
||||
// Create a socket
|
||||
sockPath = testutil.TempSocket(t)
|
||||
f, err := os.Create(sockPath)
|
||||
require.NoError(t, err)
|
||||
defer f.Close()
|
||||
serviceAddress = proto + "://" + sockPath
|
||||
}
|
||||
|
||||
// Setup the configuration according to test specification
|
||||
cfg := &Config{
|
||||
ContentEncoding: tt.encoding,
|
||||
ReadBufferSize: tt.buffersize,
|
||||
}
|
||||
if strings.HasSuffix(tt.schema, "tls") {
|
||||
cfg.ServerConfig = *serverTLS
|
||||
var err error
|
||||
tlsCfg, err = clientTLS.TLSConfig()
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// Create the socket
|
||||
sock, err := cfg.NewSocket(serviceAddress, &SplitConfig{}, &testutil.Logger{})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create callbacks
|
||||
parser := &influx.Parser{}
|
||||
require.NoError(t, parser.Init())
|
||||
|
||||
var acc testutil.Accumulator
|
||||
onData := func(remote net.Addr, data []byte, _ time.Time) {
|
||||
m, err := parser.Parse(data)
|
||||
require.NoError(t, err)
|
||||
addr, _, err := net.SplitHostPort(remote.String())
|
||||
if err != nil {
|
||||
addr = remote.String()
|
||||
}
|
||||
for i := range m {
|
||||
m[i].AddTag("source", addr)
|
||||
}
|
||||
acc.AddMetrics(m)
|
||||
}
|
||||
onError := func(err error) {
|
||||
acc.AddError(err)
|
||||
}
|
||||
|
||||
// Start the listener
|
||||
require.NoError(t, sock.Setup())
|
||||
sock.Listen(onData, onError)
|
||||
defer sock.Close()
|
||||
|
||||
addr := sock.Address()
|
||||
|
||||
// Create a noop client
|
||||
// Server is async, so verify no errors at the end.
|
||||
client, err := createClient(serviceAddress, addr, tlsCfg)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, client.Close())
|
||||
|
||||
// Setup the client for submitting data
|
||||
client, err = createClient(serviceAddress, addr, tlsCfg)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Conditionally add the source address to the expectation
|
||||
expected := make([]telegraf.Metric, 0, len(expectedTemplates))
|
||||
for _, tmpl := range expectedTemplates {
|
||||
m := tmpl.Copy()
|
||||
switch proto {
|
||||
case "tcp", "udp":
|
||||
laddr := client.LocalAddr().String()
|
||||
addr, _, err := net.SplitHostPort(laddr)
|
||||
if err != nil {
|
||||
addr = laddr
|
||||
}
|
||||
m.AddTag("source", addr)
|
||||
case "unix", "unixgram":
|
||||
m.AddTag("source", sockPath)
|
||||
}
|
||||
expected = append(expected, m)
|
||||
}
|
||||
|
||||
// Send the data with the correct encoding
|
||||
encoder, err := internal.NewContentEncoder(tt.encoding)
|
||||
require.NoError(t, err)
|
||||
|
||||
for i, msg := range messages {
|
||||
m, err := encoder.Encode(msg)
|
||||
require.NoErrorf(t, err, "encoding failed for msg %d", i)
|
||||
_, err = client.Write(m)
|
||||
require.NoErrorf(t, err, "sending msg %d failed", i)
|
||||
}
|
||||
|
||||
// Test the resulting metrics and compare against expected results
|
||||
require.Eventuallyf(t, func() bool {
|
||||
acc.Lock()
|
||||
defer acc.Unlock()
|
||||
return acc.NMetrics() >= uint64(len(expected))
|
||||
}, time.Second, 100*time.Millisecond, "did not receive metrics (%d)", acc.NMetrics())
|
||||
|
||||
actual := acc.GetTelegrafMetrics()
|
||||
testutil.RequireMetricsEqual(t, expected, actual, testutil.SortMetrics())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestListenConnection(t *testing.T) {
|
||||
messages := [][]byte{
|
||||
[]byte("test,foo=bar v=1i 123456789\ntest,foo=baz v=2i 123456790\n"),
|
||||
[]byte("test,foo=zab v=3i 123456791\n"),
|
||||
}
|
||||
expectedTemplates := []telegraf.Metric{
|
||||
metric.New(
|
||||
"test",
|
||||
map[string]string{"foo": "bar"},
|
||||
map[string]interface{}{"v": int64(1)},
|
||||
time.Unix(0, 123456789),
|
||||
),
|
||||
metric.New(
|
||||
"test",
|
||||
map[string]string{"foo": "baz"},
|
||||
map[string]interface{}{"v": int64(2)},
|
||||
time.Unix(0, 123456790),
|
||||
),
|
||||
metric.New(
|
||||
"test",
|
||||
map[string]string{"foo": "zab"},
|
||||
map[string]interface{}{"v": int64(3)},
|
||||
time.Unix(0, 123456791),
|
||||
),
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
schema string
|
||||
buffersize config.Size
|
||||
encoding string
|
||||
}{
|
||||
{
|
||||
name: "TCP",
|
||||
schema: "tcp",
|
||||
buffersize: config.Size(1024),
|
||||
},
|
||||
{
|
||||
name: "TCP with TLS",
|
||||
schema: "tcp+tls",
|
||||
},
|
||||
{
|
||||
name: "TCP with gzip encoding",
|
||||
schema: "tcp",
|
||||
buffersize: config.Size(1024),
|
||||
encoding: "gzip",
|
||||
},
|
||||
{
|
||||
name: "UDP",
|
||||
schema: "udp",
|
||||
buffersize: config.Size(1024),
|
||||
},
|
||||
{
|
||||
name: "UDP with gzip encoding",
|
||||
schema: "udp",
|
||||
buffersize: config.Size(1024),
|
||||
encoding: "gzip",
|
||||
},
|
||||
{
|
||||
name: "unix socket",
|
||||
schema: "unix",
|
||||
buffersize: config.Size(1024),
|
||||
},
|
||||
{
|
||||
name: "unix socket with TLS",
|
||||
schema: "unix+tls",
|
||||
},
|
||||
{
|
||||
name: "unix socket with gzip encoding",
|
||||
schema: "unix",
|
||||
encoding: "gzip",
|
||||
},
|
||||
{
|
||||
name: "unixgram socket",
|
||||
schema: "unixgram",
|
||||
buffersize: config.Size(1024),
|
||||
},
|
||||
}
|
||||
|
||||
serverTLS := pki.TLSServerConfig()
|
||||
clientTLS := pki.TLSClientConfig()
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
proto := strings.TrimSuffix(tt.schema, "+tls")
|
||||
|
||||
// Prepare the address and socket if needed
|
||||
var sockPath string
|
||||
var serviceAddress string
|
||||
var tlsCfg *tls.Config
|
||||
switch proto {
|
||||
case "tcp", "udp":
|
||||
serviceAddress = proto + "://" + "127.0.0.1:0"
|
||||
case "unix", "unixgram":
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("Skipping on Windows, as unixgram sockets are not supported")
|
||||
}
|
||||
|
||||
// Create a socket
|
||||
sockPath = testutil.TempSocket(t)
|
||||
f, err := os.Create(sockPath)
|
||||
require.NoError(t, err)
|
||||
defer f.Close()
|
||||
serviceAddress = proto + "://" + sockPath
|
||||
}
|
||||
|
||||
// Setup the configuration according to test specification
|
||||
cfg := &Config{
|
||||
ContentEncoding: tt.encoding,
|
||||
ReadBufferSize: tt.buffersize,
|
||||
}
|
||||
if strings.HasSuffix(tt.schema, "tls") {
|
||||
cfg.ServerConfig = *serverTLS
|
||||
var err error
|
||||
tlsCfg, err = clientTLS.TLSConfig()
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// Create the socket
|
||||
sock, err := cfg.NewSocket(serviceAddress, &SplitConfig{}, &testutil.Logger{})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create callbacks
|
||||
parser := &influx.Parser{}
|
||||
require.NoError(t, parser.Init())
|
||||
|
||||
var acc testutil.Accumulator
|
||||
onConnection := func(remote net.Addr, reader io.ReadCloser) {
|
||||
data, err := io.ReadAll(reader)
|
||||
require.NoError(t, err)
|
||||
m, err := parser.Parse(data)
|
||||
require.NoError(t, err)
|
||||
addr, _, err := net.SplitHostPort(remote.String())
|
||||
if err != nil {
|
||||
addr = remote.String()
|
||||
}
|
||||
for i := range m {
|
||||
m[i].AddTag("source", addr)
|
||||
}
|
||||
acc.AddMetrics(m)
|
||||
}
|
||||
onError := func(err error) {
|
||||
acc.AddError(err)
|
||||
}
|
||||
|
||||
// Start the listener
|
||||
require.NoError(t, sock.Setup())
|
||||
sock.ListenConnection(onConnection, onError)
|
||||
defer sock.Close()
|
||||
|
||||
addr := sock.Address()
|
||||
|
||||
// Create a noop client
|
||||
// Server is async, so verify no errors at the end.
|
||||
client, err := createClient(serviceAddress, addr, tlsCfg)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, client.Close())
|
||||
|
||||
// Setup the client for submitting data
|
||||
client, err = createClient(serviceAddress, addr, tlsCfg)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Conditionally add the source address to the expectation
|
||||
expected := make([]telegraf.Metric, 0, len(expectedTemplates))
|
||||
for _, tmpl := range expectedTemplates {
|
||||
m := tmpl.Copy()
|
||||
switch proto {
|
||||
case "tcp", "udp":
|
||||
laddr := client.LocalAddr().String()
|
||||
addr, _, err := net.SplitHostPort(laddr)
|
||||
if err != nil {
|
||||
addr = laddr
|
||||
}
|
||||
m.AddTag("source", addr)
|
||||
case "unix", "unixgram":
|
||||
m.AddTag("source", sockPath)
|
||||
}
|
||||
expected = append(expected, m)
|
||||
}
|
||||
|
||||
// Send the data with the correct encoding
|
||||
encoder, err := internal.NewContentEncoder(tt.encoding)
|
||||
require.NoError(t, err)
|
||||
|
||||
for i, msg := range messages {
|
||||
m, err := encoder.Encode(msg)
|
||||
require.NoErrorf(t, err, "encoding failed for msg %d", i)
|
||||
_, err = client.Write(m)
|
||||
require.NoErrorf(t, err, "sending msg %d failed", i)
|
||||
}
|
||||
client.Close()
|
||||
|
||||
// Test the resulting metrics and compare against expected results
|
||||
require.Eventuallyf(t, func() bool {
|
||||
acc.Lock()
|
||||
defer acc.Unlock()
|
||||
return acc.NMetrics() >= uint64(len(expected))
|
||||
}, time.Second, 100*time.Millisecond, "did not receive metrics (%d)", acc.NMetrics())
|
||||
|
||||
actual := acc.GetTelegrafMetrics()
|
||||
testutil.RequireMetricsEqual(t, expected, actual, testutil.SortMetrics())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestClosingConnections(t *testing.T) {
|
||||
// Setup the configuration
|
||||
cfg := &Config{
|
||||
ReadBufferSize: 1024,
|
||||
}
|
||||
|
||||
// Create the socket
|
||||
serviceAddress := "tcp://127.0.0.1:0"
|
||||
logger := &testutil.CaptureLogger{}
|
||||
sock, err := cfg.NewSocket(serviceAddress, &SplitConfig{}, logger)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create callbacks
|
||||
parser := &influx.Parser{}
|
||||
require.NoError(t, parser.Init())
|
||||
|
||||
var acc testutil.Accumulator
|
||||
onData := func(_ net.Addr, data []byte, _ time.Time) {
|
||||
m, err := parser.Parse(data)
|
||||
require.NoError(t, err)
|
||||
acc.AddMetrics(m)
|
||||
}
|
||||
onError := func(err error) {
|
||||
acc.AddError(err)
|
||||
}
|
||||
|
||||
// Start the listener
|
||||
require.NoError(t, sock.Setup())
|
||||
sock.Listen(onData, onError)
|
||||
defer sock.Close()
|
||||
|
||||
addr := sock.Address()
|
||||
|
||||
// Create a noop client
|
||||
client, err := createClient(serviceAddress, addr, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = client.Write([]byte("test value=42i\n"))
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Eventually(t, func() bool {
|
||||
acc.Lock()
|
||||
defer acc.Unlock()
|
||||
return acc.NMetrics() >= 1
|
||||
}, time.Second, 100*time.Millisecond, "did not receive metric")
|
||||
|
||||
// This has to be a stream-listener...
|
||||
listener, ok := sock.listener.(*streamListener)
|
||||
require.True(t, ok)
|
||||
listener.Lock()
|
||||
conns := listener.connections
|
||||
listener.Unlock()
|
||||
require.NotZero(t, conns)
|
||||
|
||||
sock.Close()
|
||||
|
||||
// Verify that plugin.Stop() closed the client's connection
|
||||
require.NoError(t, client.SetReadDeadline(time.Now().Add(time.Second)))
|
||||
buf := []byte{1}
|
||||
_, err = client.Read(buf)
|
||||
require.Equal(t, err, io.EOF)
|
||||
|
||||
require.Empty(t, logger.Errors())
|
||||
require.Empty(t, logger.Warnings())
|
||||
}
|
||||
func TestMaxConnections(t *testing.T) {
|
||||
if runtime.GOOS == "darwin" {
|
||||
t.Skip("Skipping on darwin due to missing socket options")
|
||||
}
|
||||
|
||||
// Setup the configuration
|
||||
period := config.Duration(10 * time.Millisecond)
|
||||
cfg := &Config{
|
||||
MaxConnections: 5,
|
||||
KeepAlivePeriod: &period,
|
||||
}
|
||||
|
||||
// Create the socket
|
||||
serviceAddress := "tcp://127.0.0.1:0"
|
||||
sock, err := cfg.NewSocket(serviceAddress, nil, &testutil.Logger{})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create callback
|
||||
var errs []error
|
||||
var mu sync.Mutex
|
||||
onData := func(_ net.Addr, _ []byte, _ time.Time) {}
|
||||
onError := func(err error) {
|
||||
mu.Lock()
|
||||
errs = append(errs, err)
|
||||
mu.Unlock()
|
||||
}
|
||||
|
||||
// Start the listener
|
||||
require.NoError(t, sock.Setup())
|
||||
sock.Listen(onData, onError)
|
||||
defer sock.Close()
|
||||
|
||||
addr := sock.Address()
|
||||
|
||||
// Create maximum number of connections and write some data. All of this
|
||||
// should succeed...
|
||||
clients := make([]*net.TCPConn, 0, cfg.MaxConnections)
|
||||
for i := 0; i < int(cfg.MaxConnections); i++ {
|
||||
c, err := net.DialTCP("tcp", nil, addr.(*net.TCPAddr))
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, c.SetWriteBuffer(0))
|
||||
require.NoError(t, c.SetNoDelay(true))
|
||||
clients = append(clients, c)
|
||||
|
||||
_, err = c.Write([]byte("test value=42i\n"))
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func() {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
require.Empty(t, errs)
|
||||
}()
|
||||
|
||||
// Create another client. This should fail because we already reached the
|
||||
// connection limit and the connection should be closed...
|
||||
client, err := net.DialTCP("tcp", nil, addr.(*net.TCPAddr))
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, client.SetWriteBuffer(0))
|
||||
require.NoError(t, client.SetNoDelay(true))
|
||||
|
||||
require.Eventually(t, func() bool {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
return len(errs) > 0
|
||||
}, 3*time.Second, 100*time.Millisecond)
|
||||
func() {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
require.Len(t, errs, 1)
|
||||
require.ErrorContains(t, errs[0], "too many connections")
|
||||
errs = make([]error, 0)
|
||||
}()
|
||||
|
||||
require.Eventually(t, func() bool {
|
||||
_, err := client.Write([]byte("fail\n"))
|
||||
return err != nil
|
||||
}, 3*time.Second, 100*time.Millisecond)
|
||||
_, err = client.Write([]byte("test\n"))
|
||||
require.Error(t, err)
|
||||
|
||||
// Check other connections are still good
|
||||
for _, c := range clients {
|
||||
_, err := c.Write([]byte("test\n"))
|
||||
require.NoError(t, err)
|
||||
}
|
||||
func() {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
require.Empty(t, errs)
|
||||
}()
|
||||
|
||||
// Close the first client and check if we can connect now
|
||||
require.NoError(t, clients[0].Close())
|
||||
client, err = net.DialTCP("tcp", nil, addr.(*net.TCPAddr))
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, client.SetWriteBuffer(0))
|
||||
require.NoError(t, client.SetNoDelay(true))
|
||||
_, err = client.Write([]byte("success\n"))
|
||||
require.NoError(t, err)
|
||||
|
||||
// Close all connections
|
||||
require.NoError(t, client.Close())
|
||||
for _, c := range clients[1:] {
|
||||
require.NoError(t, c.Close())
|
||||
}
|
||||
|
||||
// Close the clients and check the connection counter
|
||||
listener, ok := sock.listener.(*streamListener)
|
||||
require.True(t, ok)
|
||||
require.Eventually(t, func() bool {
|
||||
listener.Lock()
|
||||
conns := listener.connections
|
||||
listener.Unlock()
|
||||
return conns == 0
|
||||
}, 3*time.Second, 100*time.Millisecond)
|
||||
|
||||
// Close the socket and check again...
|
||||
sock.Close()
|
||||
listener.Lock()
|
||||
conns := listener.connections
|
||||
listener.Unlock()
|
||||
require.Zero(t, conns)
|
||||
}
|
||||
|
||||
func TestNoSplitter(t *testing.T) {
|
||||
messages := [][]byte{
|
||||
[]byte("test,foo=bar v"),
|
||||
[]byte("=1i 123456789\ntest,foo=baz v=2i 123456790\ntest,foo=zab v=3i 123456791\n"),
|
||||
}
|
||||
expectedTemplates := []telegraf.Metric{
|
||||
metric.New(
|
||||
"test",
|
||||
map[string]string{"foo": "bar"},
|
||||
map[string]interface{}{"v": int64(1)},
|
||||
time.Unix(0, 123456789),
|
||||
),
|
||||
metric.New(
|
||||
"test",
|
||||
map[string]string{"foo": "baz"},
|
||||
map[string]interface{}{"v": int64(2)},
|
||||
time.Unix(0, 123456790),
|
||||
),
|
||||
metric.New(
|
||||
"test",
|
||||
map[string]string{"foo": "zab"},
|
||||
map[string]interface{}{"v": int64(3)},
|
||||
time.Unix(0, 123456791),
|
||||
),
|
||||
}
|
||||
|
||||
// Prepare the address and socket if needed
|
||||
serviceAddress := "tcp://127.0.0.1:0"
|
||||
|
||||
// Setup the configuration according to test specification
|
||||
cfg := &Config{}
|
||||
|
||||
// Create the socket
|
||||
sock, err := cfg.NewSocket(serviceAddress, nil, &testutil.Logger{})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create callbacks
|
||||
parser := &influx.Parser{}
|
||||
require.NoError(t, parser.Init())
|
||||
|
||||
var acc testutil.Accumulator
|
||||
onConnection := func(remote net.Addr, reader io.ReadCloser) {
|
||||
data, err := io.ReadAll(reader)
|
||||
require.NoError(t, err)
|
||||
m, err := parser.Parse(data)
|
||||
require.NoError(t, err)
|
||||
addr, _, err := net.SplitHostPort(remote.String())
|
||||
if err != nil {
|
||||
addr = remote.String()
|
||||
}
|
||||
for i := range m {
|
||||
m[i].AddTag("source", addr)
|
||||
}
|
||||
acc.AddMetrics(m)
|
||||
}
|
||||
onError := func(err error) {
|
||||
acc.AddError(err)
|
||||
}
|
||||
|
||||
// Start the listener
|
||||
require.NoError(t, sock.Setup())
|
||||
sock.ListenConnection(onConnection, onError)
|
||||
defer sock.Close()
|
||||
|
||||
addr := sock.Address()
|
||||
|
||||
// Create a noop client
|
||||
// Server is async, so verify no errors at the end.
|
||||
client, err := createClient(serviceAddress, addr, nil)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, client.Close())
|
||||
|
||||
// Setup the client for submitting data
|
||||
client, err = createClient(serviceAddress, addr, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Conditionally add the source address to the expectation
|
||||
expected := make([]telegraf.Metric, 0, len(expectedTemplates))
|
||||
for _, tmpl := range expectedTemplates {
|
||||
m := tmpl.Copy()
|
||||
laddr := client.LocalAddr().String()
|
||||
addr, _, err := net.SplitHostPort(laddr)
|
||||
if err != nil {
|
||||
addr = laddr
|
||||
}
|
||||
m.AddTag("source", addr)
|
||||
expected = append(expected, m)
|
||||
}
|
||||
|
||||
// Send the data
|
||||
for i, msg := range messages {
|
||||
_, err = client.Write(msg)
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
require.NoErrorf(t, err, "sending msg %d failed", i)
|
||||
}
|
||||
client.Close()
|
||||
|
||||
// Test the resulting metrics and compare against expected results
|
||||
require.Eventuallyf(t, func() bool {
|
||||
acc.Lock()
|
||||
defer acc.Unlock()
|
||||
return acc.NMetrics() >= uint64(len(expected))
|
||||
}, time.Second, 100*time.Millisecond, "did not receive metrics (%d)", acc.NMetrics())
|
||||
|
||||
actual := acc.GetTelegrafMetrics()
|
||||
testutil.RequireMetricsEqual(t, expected, actual, testutil.SortMetrics())
|
||||
}
|
||||
|
||||
func TestTLSMemLeak(t *testing.T) {
|
||||
// For issue https://github.com/influxdata/telegraf/issues/15509
|
||||
|
||||
// Prepare the address and socket if needed
|
||||
serviceAddress := "tcp://127.0.0.1:0"
|
||||
|
||||
// Setup a TLS socket to trigger the issue
|
||||
cfg := &Config{
|
||||
ServerConfig: *pki.TLSServerConfig(),
|
||||
}
|
||||
|
||||
// Create the socket
|
||||
sock, err := cfg.NewSocket(serviceAddress, nil, &testutil.Logger{})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create callbacks
|
||||
onConnection := func(_ net.Addr, reader io.ReadCloser) {
|
||||
//nolint:errcheck // We are not interested in the data so ignore all errors
|
||||
io.Copy(io.Discard, reader)
|
||||
}
|
||||
|
||||
// Start the listener
|
||||
require.NoError(t, sock.Setup())
|
||||
sock.ListenConnection(onConnection, nil)
|
||||
defer sock.Close()
|
||||
|
||||
addr := sock.Address()
|
||||
|
||||
// Setup the client side TLS
|
||||
tlsCfg, err := pki.TLSClientConfig().TLSConfig()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Define a single client write sequence
|
||||
data := []byte("test value=42i")
|
||||
write := func() error {
|
||||
conn, err := tls.Dial("tcp", addr.String(), tlsCfg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
_, err = conn.Write(data)
|
||||
return err
|
||||
}
|
||||
|
||||
// Define a test with the given number of connections
|
||||
maxConcurrency := runtime.GOMAXPROCS(0)
|
||||
testCycle := func(connections int) (uint64, error) {
|
||||
var mu sync.Mutex
|
||||
var errs []error
|
||||
var wg sync.WaitGroup
|
||||
for count := 1; count < connections; count++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
if err := write(); err != nil {
|
||||
mu.Lock()
|
||||
errs = append(errs, err)
|
||||
mu.Unlock()
|
||||
}
|
||||
}()
|
||||
if count%maxConcurrency == 0 {
|
||||
wg.Wait()
|
||||
mu.Lock()
|
||||
if len(errs) > 0 {
|
||||
mu.Unlock()
|
||||
return 0, errors.Join(errs...)
|
||||
}
|
||||
mu.Unlock()
|
||||
}
|
||||
}
|
||||
//nolint:revive // We need to actively run the garbage collector to get reliable measurements
|
||||
runtime.GC()
|
||||
|
||||
var stats runtime.MemStats
|
||||
runtime.ReadMemStats(&stats)
|
||||
return stats.HeapObjects, nil
|
||||
}
|
||||
|
||||
// Measure the memory usage after a short warmup and after some time.
|
||||
// The final number of heap objects should not exceed the number of
|
||||
// runs by a save margin
|
||||
|
||||
// Warmup, do a low number of runs to initialize all data structures
|
||||
// taking them out of the equation.
|
||||
initial, err := testCycle(100)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Do some more runs and make sure the memory growth is bound
|
||||
final, err := testCycle(2000)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Less(t, final, 3*initial)
|
||||
}
|
||||
|
||||
func createClient(endpoint string, addr net.Addr, tlsCfg *tls.Config) (net.Conn, error) {
|
||||
// Determine the protocol in a crude fashion
|
||||
parts := strings.SplitN(endpoint, "://", 2)
|
||||
if len(parts) != 2 {
|
||||
return nil, fmt.Errorf("invalid endpoint %q", endpoint)
|
||||
}
|
||||
protocol := parts[0]
|
||||
|
||||
if tlsCfg == nil {
|
||||
return net.Dial(protocol, addr.String())
|
||||
}
|
||||
|
||||
if protocol == "unix" {
|
||||
tlsCfg.InsecureSkipVerify = true
|
||||
}
|
||||
return tls.Dial(protocol, addr.String(), tlsCfg)
|
||||
}
|
37
plugins/common/socket/splitter.conf
Normal file
37
plugins/common/socket/splitter.conf
Normal file
|
@ -0,0 +1,37 @@
|
|||
## Message splitting strategy and corresponding settings for stream sockets
|
||||
## (tcp, tcp4, tcp6, unix or unixpacket). The setting is ignored for packet
|
||||
## listeners such as udp.
|
||||
## Available strategies are:
|
||||
## newline -- split at newlines (default)
|
||||
## null -- split at null bytes
|
||||
## delimiter -- split at delimiter byte-sequence in hex-format
|
||||
## given in `splitting_delimiter`
|
||||
## fixed length -- split after number of bytes given in `splitting_length`
|
||||
## variable length -- split depending on length information received in the
|
||||
## data. The length field information is specified in
|
||||
## `splitting_length_field`.
|
||||
# splitting_strategy = "newline"
|
||||
|
||||
## Delimiter used to split received data to messages consumed by the parser.
|
||||
## The delimiter is a hex byte-sequence marking the end of a message
|
||||
## e.g. "0x0D0A", "x0d0a" or "0d0a" marks a Windows line-break (CR LF).
|
||||
## The value is case-insensitive and can be specified with "0x" or "x" prefix
|
||||
## or without.
|
||||
## Note: This setting is only used for splitting_strategy = "delimiter".
|
||||
# splitting_delimiter = ""
|
||||
|
||||
## Fixed length of a message in bytes.
|
||||
## Note: This setting is only used for splitting_strategy = "fixed length".
|
||||
# splitting_length = 0
|
||||
|
||||
## Specification of the length field contained in the data to split messages
|
||||
## with variable length. The specification contains the following fields:
|
||||
## offset -- start of length field in bytes from begin of data
|
||||
## bytes -- length of length field in bytes
|
||||
## endianness -- endianness of the value, either "be" for big endian or
|
||||
## "le" for little endian
|
||||
## header_length -- total length of header to be skipped when passing
|
||||
## data on to the parser. If zero (default), the header
|
||||
## is passed on to the parser together with the message.
|
||||
## Note: This setting is only used for splitting_strategy = "variable length".
|
||||
# splitting_length_field = {offset = 0, bytes = 0, endianness = "be", header_length = 0}
|
167
plugins/common/socket/splitters.go
Normal file
167
plugins/common/socket/splitters.go
Normal file
|
@ -0,0 +1,167 @@
|
|||
package socket
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type lengthFieldSpec struct {
|
||||
Offset int64 `toml:"offset"`
|
||||
Bytes int64 `toml:"bytes"`
|
||||
Endianness string `toml:"endianness"`
|
||||
HeaderLength int64 `toml:"header_length"`
|
||||
converter func([]byte) int
|
||||
}
|
||||
|
||||
type SplitConfig struct {
|
||||
SplittingStrategy string `toml:"splitting_strategy"`
|
||||
SplittingDelimiter string `toml:"splitting_delimiter"`
|
||||
SplittingLength int `toml:"splitting_length"`
|
||||
SplittingLengthField lengthFieldSpec `toml:"splitting_length_field"`
|
||||
}
|
||||
|
||||
func (cfg *SplitConfig) NewSplitter() (bufio.SplitFunc, error) {
|
||||
switch cfg.SplittingStrategy {
|
||||
case "", "newline":
|
||||
return bufio.ScanLines, nil
|
||||
case "null":
|
||||
return scanNull, nil
|
||||
case "delimiter":
|
||||
re := regexp.MustCompile(`(\s*0?x)`)
|
||||
d := re.ReplaceAllString(strings.ToLower(cfg.SplittingDelimiter), "")
|
||||
delimiter, err := hex.DecodeString(d)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("decoding delimiter failed: %w", err)
|
||||
}
|
||||
return createScanDelimiter(delimiter), nil
|
||||
case "fixed length":
|
||||
return createScanFixedLength(cfg.SplittingLength), nil
|
||||
case "variable length":
|
||||
// Create the converter function
|
||||
var order binary.ByteOrder
|
||||
switch strings.ToLower(cfg.SplittingLengthField.Endianness) {
|
||||
case "", "be":
|
||||
order = binary.BigEndian
|
||||
case "le":
|
||||
order = binary.LittleEndian
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid 'endianness' %q", cfg.SplittingLengthField.Endianness)
|
||||
}
|
||||
|
||||
switch cfg.SplittingLengthField.Bytes {
|
||||
case 1:
|
||||
cfg.SplittingLengthField.converter = func(b []byte) int {
|
||||
return int(b[0])
|
||||
}
|
||||
case 2:
|
||||
cfg.SplittingLengthField.converter = func(b []byte) int {
|
||||
return int(order.Uint16(b))
|
||||
}
|
||||
case 4:
|
||||
cfg.SplittingLengthField.converter = func(b []byte) int {
|
||||
return int(order.Uint32(b))
|
||||
}
|
||||
case 8:
|
||||
cfg.SplittingLengthField.converter = func(b []byte) int {
|
||||
return int(order.Uint64(b))
|
||||
}
|
||||
default:
|
||||
cfg.SplittingLengthField.converter = func(b []byte) int {
|
||||
buf := make([]byte, 8)
|
||||
start := 0
|
||||
if order == binary.BigEndian {
|
||||
start = 8 - len(b)
|
||||
}
|
||||
for i := 0; i < len(b); i++ {
|
||||
buf[start+i] = b[i]
|
||||
}
|
||||
return int(order.Uint64(buf))
|
||||
}
|
||||
}
|
||||
|
||||
// Check if we have enough bytes in the header
|
||||
return createScanVariableLength(cfg.SplittingLengthField), nil
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("unknown 'splitting_strategy' %q", cfg.SplittingStrategy)
|
||||
}
|
||||
|
||||
func scanNull(data []byte, atEOF bool) (advance int, token []byte, err error) {
|
||||
if atEOF && len(data) == 0 {
|
||||
return 0, nil, nil
|
||||
}
|
||||
if i := bytes.IndexByte(data, 0); i >= 0 {
|
||||
return i + 1, data[:i], nil
|
||||
}
|
||||
if atEOF {
|
||||
return len(data), data, nil
|
||||
}
|
||||
// Request more data.
|
||||
return 0, nil, nil
|
||||
}
|
||||
|
||||
func createScanDelimiter(delimiter []byte) bufio.SplitFunc {
|
||||
return func(data []byte, atEOF bool) (advance int, token []byte, err error) {
|
||||
if atEOF && len(data) == 0 {
|
||||
return 0, nil, nil
|
||||
}
|
||||
if i := bytes.Index(data, delimiter); i >= 0 {
|
||||
return i + len(delimiter), data[:i], nil
|
||||
}
|
||||
if atEOF {
|
||||
return len(data), data, nil
|
||||
}
|
||||
// Request more data.
|
||||
return 0, nil, nil
|
||||
}
|
||||
}
|
||||
|
||||
func createScanFixedLength(length int) bufio.SplitFunc {
|
||||
return func(data []byte, atEOF bool) (advance int, token []byte, err error) {
|
||||
if atEOF && len(data) == 0 {
|
||||
return 0, nil, nil
|
||||
}
|
||||
if len(data) >= length {
|
||||
return length, data[:length], nil
|
||||
}
|
||||
if atEOF {
|
||||
return len(data), data, nil
|
||||
}
|
||||
// Request more data.
|
||||
return 0, nil, nil
|
||||
}
|
||||
}
|
||||
|
||||
func createScanVariableLength(spec lengthFieldSpec) bufio.SplitFunc {
|
||||
minlen := int(spec.Offset)
|
||||
minlen += int(spec.Bytes)
|
||||
headerLen := int(spec.HeaderLength)
|
||||
|
||||
return func(data []byte, atEOF bool) (advance int, token []byte, err error) {
|
||||
if atEOF && len(data) == 0 {
|
||||
return 0, nil, nil
|
||||
}
|
||||
dataLen := len(data)
|
||||
if dataLen >= minlen {
|
||||
// Extract the length field and convert it to a number
|
||||
lf := data[spec.Offset : spec.Offset+spec.Bytes]
|
||||
length := spec.converter(lf)
|
||||
start := headerLen
|
||||
end := length + headerLen
|
||||
// If we have enough data return it without the header
|
||||
if end <= dataLen {
|
||||
return end, data[start:end], nil
|
||||
}
|
||||
}
|
||||
if atEOF {
|
||||
return len(data), data, nil
|
||||
}
|
||||
// Request more data.
|
||||
return 0, nil, nil
|
||||
}
|
||||
}
|
476
plugins/common/socket/stream.go
Normal file
476
plugins/common/socket/stream.go
Normal file
|
@ -0,0 +1,476 @@
|
|||
package socket
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"math"
|
||||
"net"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/alitto/pond"
|
||||
"github.com/mdlayher/vsock"
|
||||
|
||||
"github.com/influxdata/telegraf"
|
||||
"github.com/influxdata/telegraf/config"
|
||||
"github.com/influxdata/telegraf/internal"
|
||||
)
|
||||
|
||||
type hasSetReadBuffer interface {
|
||||
SetReadBuffer(bytes int) error
|
||||
}
|
||||
|
||||
type streamListener struct {
|
||||
Encoding string
|
||||
ReadBufferSize int
|
||||
MaxConnections uint64
|
||||
ReadTimeout config.Duration
|
||||
KeepAlivePeriod *config.Duration
|
||||
Splitter bufio.SplitFunc
|
||||
Log telegraf.Logger
|
||||
|
||||
listener net.Listener
|
||||
connections uint64
|
||||
path string
|
||||
cancel context.CancelFunc
|
||||
parsePool *pond.WorkerPool
|
||||
|
||||
wg sync.WaitGroup
|
||||
sync.Mutex
|
||||
}
|
||||
|
||||
func newStreamListener(conf Config, splitter bufio.SplitFunc, log telegraf.Logger) *streamListener {
|
||||
return &streamListener{
|
||||
ReadBufferSize: int(conf.ReadBufferSize),
|
||||
ReadTimeout: conf.ReadTimeout,
|
||||
KeepAlivePeriod: conf.KeepAlivePeriod,
|
||||
MaxConnections: conf.MaxConnections,
|
||||
Encoding: conf.ContentEncoding,
|
||||
Splitter: splitter,
|
||||
Log: log,
|
||||
|
||||
parsePool: pond.New(
|
||||
conf.MaxParallelParsers,
|
||||
0,
|
||||
pond.MinWorkers(conf.MaxParallelParsers/2+1)),
|
||||
}
|
||||
}
|
||||
|
||||
func (l *streamListener) setupTCP(u *url.URL, tlsCfg *tls.Config) error {
|
||||
var err error
|
||||
if tlsCfg == nil {
|
||||
l.listener, err = net.Listen(u.Scheme, u.Host)
|
||||
} else {
|
||||
l.listener, err = tls.Listen(u.Scheme, u.Host, tlsCfg)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (l *streamListener) setupUnix(u *url.URL, tlsCfg *tls.Config, socketMode string) error {
|
||||
l.path = filepath.FromSlash(u.Path)
|
||||
if runtime.GOOS == "windows" && strings.Contains(l.path, ":") {
|
||||
l.path = strings.TrimPrefix(l.path, `\`)
|
||||
}
|
||||
if err := os.Remove(l.path); err != nil && !errors.Is(err, os.ErrNotExist) {
|
||||
return fmt.Errorf("removing socket failed: %w", err)
|
||||
}
|
||||
|
||||
var err error
|
||||
if tlsCfg == nil {
|
||||
l.listener, err = net.Listen(u.Scheme, l.path)
|
||||
} else {
|
||||
l.listener, err = tls.Listen(u.Scheme, l.path, tlsCfg)
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Set permissions on socket
|
||||
if socketMode != "" {
|
||||
// Convert from octal in string to int
|
||||
i, err := strconv.ParseUint(socketMode, 8, 32)
|
||||
if err != nil {
|
||||
return fmt.Errorf("converting socket mode failed: %w", err)
|
||||
}
|
||||
|
||||
perm := os.FileMode(uint32(i))
|
||||
if err := os.Chmod(u.Path, perm); err != nil {
|
||||
return fmt.Errorf("changing socket permissions failed: %w", err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (l *streamListener) setupVsock(u *url.URL) error {
|
||||
var err error
|
||||
|
||||
addrTuple := strings.SplitN(u.String(), ":", 2)
|
||||
|
||||
// Check address string for containing two tokens
|
||||
if len(addrTuple) < 2 {
|
||||
return errors.New("port and/or CID number missing")
|
||||
}
|
||||
// Parse CID and port number from address string both being 32-bit
|
||||
// source: https://man7.org/linux/man-pages/man7/vsock.7.html
|
||||
cid, err := strconv.ParseUint(addrTuple[0], 10, 32)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse CID %s: %w", addrTuple[0], err)
|
||||
}
|
||||
if (cid >= uint64(math.Pow(2, 32))-1) && (cid <= 0) {
|
||||
return fmt.Errorf("value of CID %d is out of range", cid)
|
||||
}
|
||||
port, err := strconv.ParseUint(addrTuple[1], 10, 32)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse port number %s: %w", addrTuple[1], err)
|
||||
}
|
||||
if (port >= uint64(math.Pow(2, 32))-1) && (port <= 0) {
|
||||
return fmt.Errorf("port number %d is out of range", port)
|
||||
}
|
||||
|
||||
l.listener, err = vsock.Listen(uint32(port), nil)
|
||||
return err
|
||||
}
|
||||
|
||||
func (l *streamListener) setupConnection(conn net.Conn) error {
|
||||
addr := conn.RemoteAddr().String()
|
||||
l.Lock()
|
||||
if l.MaxConnections > 0 && l.connections >= l.MaxConnections {
|
||||
l.Unlock()
|
||||
// Ignore the returned error as we cannot do anything about it anyway
|
||||
_ = conn.Close()
|
||||
return fmt.Errorf("unable to accept connection from %q: too many connections", addr)
|
||||
}
|
||||
l.connections++
|
||||
l.Unlock()
|
||||
|
||||
if l.ReadBufferSize > 0 {
|
||||
if rb, ok := conn.(hasSetReadBuffer); ok {
|
||||
if err := rb.SetReadBuffer(l.ReadBufferSize); err != nil {
|
||||
l.Log.Warnf("Setting read buffer on socket failed: %v", err)
|
||||
}
|
||||
} else {
|
||||
l.Log.Warn("Cannot set read buffer on socket of this type")
|
||||
}
|
||||
}
|
||||
|
||||
// Set keep alive handlings
|
||||
if l.KeepAlivePeriod != nil {
|
||||
if c, ok := conn.(*tls.Conn); ok {
|
||||
conn = c.NetConn()
|
||||
}
|
||||
tcpConn, ok := conn.(*net.TCPConn)
|
||||
if !ok {
|
||||
l.Log.Warnf("connection not a TCP connection (%T)", conn)
|
||||
}
|
||||
if *l.KeepAlivePeriod == 0 {
|
||||
if err := tcpConn.SetKeepAlive(false); err != nil {
|
||||
l.Log.Warnf("Cannot set keep-alive: %v", err)
|
||||
}
|
||||
} else {
|
||||
if err := tcpConn.SetKeepAlive(true); err != nil {
|
||||
l.Log.Warnf("Cannot set keep-alive: %v", err)
|
||||
}
|
||||
err := tcpConn.SetKeepAlivePeriod(time.Duration(*l.KeepAlivePeriod))
|
||||
if err != nil {
|
||||
l.Log.Warnf("Cannot set keep-alive period: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (l *streamListener) closeConnection(conn net.Conn) {
|
||||
// Fallback to enforce blocked reads on connections to end immediately
|
||||
//nolint:errcheck // Ignore errors as this is a fallback only
|
||||
conn.SetReadDeadline(time.Now())
|
||||
|
||||
addr := conn.RemoteAddr().String()
|
||||
if err := conn.Close(); err != nil && !errors.Is(err, net.ErrClosed) && !errors.Is(err, syscall.EPIPE) {
|
||||
l.Log.Warnf("Cannot close connection to %q: %v", addr, err)
|
||||
} else {
|
||||
l.Lock()
|
||||
l.connections--
|
||||
l.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
func (l *streamListener) address() net.Addr {
|
||||
return l.listener.Addr()
|
||||
}
|
||||
|
||||
func (l *streamListener) close() error {
|
||||
if l.listener != nil {
|
||||
// Continue even if we cannot close the listener in order to at least
|
||||
// close all active connections
|
||||
if err := l.listener.Close(); err != nil {
|
||||
l.Log.Errorf("Cannot close listener: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
if l.cancel != nil {
|
||||
l.cancel()
|
||||
l.cancel = nil
|
||||
}
|
||||
l.wg.Wait()
|
||||
|
||||
if l.path != "" {
|
||||
fn := filepath.FromSlash(l.path)
|
||||
if runtime.GOOS == "windows" && strings.Contains(fn, ":") {
|
||||
fn = strings.TrimPrefix(fn, `\`)
|
||||
}
|
||||
// Ignore file-not-exists errors when removing the socket
|
||||
if err := os.Remove(fn); err != nil && !errors.Is(err, os.ErrNotExist) {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
l.parsePool.StopAndWait()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (l *streamListener) listenData(onData CallbackData, onError CallbackError) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
l.cancel = cancel
|
||||
|
||||
l.wg.Add(1)
|
||||
go func() {
|
||||
defer l.wg.Done()
|
||||
|
||||
for {
|
||||
conn, err := l.listener.Accept()
|
||||
if err != nil {
|
||||
if !errors.Is(err, net.ErrClosed) && onError != nil {
|
||||
onError(err)
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
if err := l.setupConnection(conn); err != nil && onError != nil {
|
||||
onError(err)
|
||||
continue
|
||||
}
|
||||
|
||||
l.wg.Add(1)
|
||||
go l.handleReaderConn(ctx, conn, onData, onError)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (l *streamListener) handleReaderConn(ctx context.Context, conn net.Conn, onData CallbackData, onError CallbackError) {
|
||||
defer l.wg.Done()
|
||||
|
||||
localCtx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
defer l.closeConnection(conn)
|
||||
stopFunc := context.AfterFunc(localCtx, func() { l.closeConnection(conn) })
|
||||
defer stopFunc()
|
||||
|
||||
reader := l.read
|
||||
if l.Splitter == nil {
|
||||
reader = l.readAll
|
||||
}
|
||||
if err := reader(conn, onData); err != nil {
|
||||
if !errors.Is(err, io.EOF) && !errors.Is(err, syscall.ECONNRESET) {
|
||||
if onError != nil {
|
||||
onError(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (l *streamListener) listenConnection(onConnection CallbackConnection, onError CallbackError) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
l.cancel = cancel
|
||||
|
||||
l.wg.Add(1)
|
||||
go func() {
|
||||
defer l.wg.Done()
|
||||
|
||||
for {
|
||||
conn, err := l.listener.Accept()
|
||||
if err != nil {
|
||||
if !errors.Is(err, net.ErrClosed) && onError != nil {
|
||||
onError(err)
|
||||
}
|
||||
break
|
||||
}
|
||||
if err := l.setupConnection(conn); err != nil && onError != nil {
|
||||
onError(err)
|
||||
continue
|
||||
}
|
||||
|
||||
l.wg.Add(1)
|
||||
go func(c net.Conn) {
|
||||
if err := l.handleConnection(ctx, c, onConnection); err != nil {
|
||||
if !errors.Is(err, io.EOF) && !errors.Is(err, syscall.ECONNRESET) {
|
||||
if onError != nil {
|
||||
onError(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}(conn)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (l *streamListener) read(conn net.Conn, onData CallbackData) error {
|
||||
decoder, err := internal.NewStreamContentDecoder(l.Encoding, conn)
|
||||
if err != nil {
|
||||
return fmt.Errorf("creating decoder failed: %w", err)
|
||||
}
|
||||
|
||||
timeout := time.Duration(l.ReadTimeout)
|
||||
|
||||
scanner := bufio.NewScanner(decoder)
|
||||
if l.ReadBufferSize > bufio.MaxScanTokenSize {
|
||||
scanner.Buffer(make([]byte, l.ReadBufferSize), l.ReadBufferSize)
|
||||
}
|
||||
scanner.Split(l.Splitter)
|
||||
for {
|
||||
// Set the read deadline, if any, then start reading. The read
|
||||
// will accept the deadline and return if no or insufficient data
|
||||
// arrived in time. We need to set the deadline in every cycle as
|
||||
// it is an ABSOLUTE time and not a timeout.
|
||||
if timeout > 0 {
|
||||
deadline := time.Now().Add(timeout)
|
||||
if err := conn.SetReadDeadline(deadline); err != nil {
|
||||
return fmt.Errorf("setting read deadline failed: %w", err)
|
||||
}
|
||||
}
|
||||
if !scanner.Scan() {
|
||||
// Exit if no data arrived e.g. due to timeout or closed connection
|
||||
break
|
||||
}
|
||||
|
||||
receiveTime := time.Now()
|
||||
src := conn.RemoteAddr()
|
||||
if l.path != "" {
|
||||
src = &net.UnixAddr{Name: l.path, Net: "unix"}
|
||||
}
|
||||
|
||||
data := scanner.Bytes()
|
||||
d := make([]byte, len(data))
|
||||
copy(d, data)
|
||||
l.parsePool.Submit(func() {
|
||||
onData(src, d, receiveTime)
|
||||
})
|
||||
}
|
||||
|
||||
if err := scanner.Err(); err != nil {
|
||||
if errors.Is(err, os.ErrDeadlineExceeded) {
|
||||
// Ignore the timeout and silently close the connection
|
||||
l.Log.Debug(err)
|
||||
return nil
|
||||
}
|
||||
if errors.Is(err, net.ErrClosed) {
|
||||
// Ignore the connection closing of the remote side
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (l *streamListener) readAll(conn net.Conn, onData CallbackData) error {
|
||||
src := conn.RemoteAddr()
|
||||
if l.path != "" {
|
||||
src = &net.UnixAddr{Name: l.path, Net: "unix"}
|
||||
}
|
||||
|
||||
decoder, err := internal.NewStreamContentDecoder(l.Encoding, conn)
|
||||
if err != nil {
|
||||
return fmt.Errorf("creating decoder failed: %w", err)
|
||||
}
|
||||
|
||||
timeout := time.Duration(l.ReadTimeout)
|
||||
// Set the read deadline, if any, then start reading. The read
|
||||
// will accept the deadline and return if no or insufficient data
|
||||
// arrived in time. We need to set the deadline in every cycle as
|
||||
// it is an ABSOLUTE time and not a timeout.
|
||||
if timeout > 0 {
|
||||
deadline := time.Now().Add(timeout)
|
||||
if err := conn.SetReadDeadline(deadline); err != nil {
|
||||
return fmt.Errorf("setting read deadline failed: %w", err)
|
||||
}
|
||||
}
|
||||
buf, err := io.ReadAll(decoder)
|
||||
if err != nil {
|
||||
return fmt.Errorf("read on %s failed: %w", src, err)
|
||||
}
|
||||
|
||||
receiveTime := time.Now()
|
||||
l.parsePool.Submit(func() {
|
||||
onData(src, buf, receiveTime)
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (l *streamListener) handleConnection(ctx context.Context, conn net.Conn, onConnection CallbackConnection) error {
|
||||
defer l.wg.Done()
|
||||
|
||||
localCtx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
defer l.closeConnection(conn)
|
||||
stopFunc := context.AfterFunc(localCtx, func() { l.closeConnection(conn) })
|
||||
defer stopFunc()
|
||||
|
||||
// Prepare the data decoder for the connection
|
||||
decoder, err := internal.NewStreamContentDecoder(l.Encoding, conn)
|
||||
if err != nil {
|
||||
return fmt.Errorf("creating decoder failed: %w", err)
|
||||
}
|
||||
|
||||
// Get the remote address
|
||||
src := conn.RemoteAddr()
|
||||
if l.path != "" {
|
||||
src = &net.UnixAddr{Name: l.path, Net: "unix"}
|
||||
}
|
||||
|
||||
// Create a pipe and feed it to the callback
|
||||
reader, writer := io.Pipe()
|
||||
defer writer.Close()
|
||||
go onConnection(src, reader)
|
||||
|
||||
timeout := time.Duration(l.ReadTimeout)
|
||||
buf := make([]byte, 4096) // 4kb
|
||||
for {
|
||||
// Set the read deadline, if any, then start reading. The read
|
||||
// will accept the deadline and return if no or insufficient data
|
||||
// arrived in time. We need to set the deadline in every cycle as
|
||||
// it is an ABSOLUTE time and not a timeout.
|
||||
if timeout > 0 {
|
||||
deadline := time.Now().Add(timeout)
|
||||
if err := conn.SetReadDeadline(deadline); err != nil {
|
||||
return fmt.Errorf("setting read deadline failed: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Copy the data
|
||||
n, err := decoder.Read(buf)
|
||||
if err != nil {
|
||||
if !strings.HasSuffix(err.Error(), ": use of closed network connection") {
|
||||
if !errors.Is(err, os.ErrDeadlineExceeded) && errors.Is(err, net.ErrClosed) {
|
||||
writer.CloseWithError(err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
if _, err := writer.Write(buf[:n]); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue