1
0
Fork 0

Adding upstream version 1.34.4.

Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
Daniel Baumann 2025-05-24 07:26:29 +02:00
parent e393c3af3f
commit 4978089aab
Signed by: daniel
GPG key ID: FBB4F0E80A80222F
4963 changed files with 677545 additions and 0 deletions

189
plugins/common/adx/adx.go Normal file
View file

@ -0,0 +1,189 @@
package adx
import (
"bytes"
"context"
"errors"
"fmt"
"strings"
"time"
"github.com/Azure/azure-kusto-go/kusto"
kustoerrors "github.com/Azure/azure-kusto-go/kusto/data/errors"
"github.com/Azure/azure-kusto-go/kusto/ingest"
"github.com/Azure/azure-kusto-go/kusto/kql"
"github.com/influxdata/telegraf"
"github.com/influxdata/telegraf/config"
"github.com/influxdata/telegraf/internal"
)
const (
TablePerMetric = "tablepermetric"
SingleTable = "singletable"
// These control the amount of memory we use when ingesting blobs
bufferSize = 1 << 20 // 1 MiB
maxBuffers = 5
ManagedIngestion = "managed"
QueuedIngestion = "queued"
)
type Config struct {
Endpoint string `toml:"endpoint_url"`
Database string `toml:"database"`
Timeout config.Duration `toml:"timeout"`
MetricsGrouping string `toml:"metrics_grouping_type"`
TableName string `toml:"table_name"`
CreateTables bool `toml:"create_tables"`
IngestionType string `toml:"ingestion_type"`
}
type Client struct {
cfg *Config
client *kusto.Client
ingestors map[string]ingest.Ingestor
logger telegraf.Logger
}
func (cfg *Config) NewClient(app string, log telegraf.Logger) (*Client, error) {
if cfg.Endpoint == "" {
return nil, errors.New("endpoint configuration cannot be empty")
}
if cfg.Database == "" {
return nil, errors.New("database configuration cannot be empty")
}
cfg.MetricsGrouping = strings.ToLower(cfg.MetricsGrouping)
if cfg.MetricsGrouping == SingleTable && cfg.TableName == "" {
return nil, errors.New("table name cannot be empty for SingleTable metrics grouping type")
}
if cfg.MetricsGrouping == "" {
cfg.MetricsGrouping = TablePerMetric
}
if cfg.MetricsGrouping != SingleTable && cfg.MetricsGrouping != TablePerMetric {
return nil, errors.New("metrics grouping type is not valid")
}
if cfg.Timeout == 0 {
cfg.Timeout = config.Duration(20 * time.Second)
}
switch cfg.IngestionType {
case "":
cfg.IngestionType = QueuedIngestion
case ManagedIngestion, QueuedIngestion:
// Do nothing as those are valid
default:
return nil, fmt.Errorf("unknown ingestion type %q", cfg.IngestionType)
}
conn := kusto.NewConnectionStringBuilder(cfg.Endpoint).WithDefaultAzureCredential()
conn.SetConnectorDetails("Telegraf", internal.ProductToken(), app, "", false, "")
client, err := kusto.New(conn)
if err != nil {
return nil, err
}
return &Client{
cfg: cfg,
ingestors: make(map[string]ingest.Ingestor),
logger: log,
client: client,
}, nil
}
// Clean up and close the ingestor
func (adx *Client) Close() error {
var errs []error
for _, v := range adx.ingestors {
if err := v.Close(); err != nil {
// accumulate errors while closing ingestors
errs = append(errs, err)
}
}
if err := adx.client.Close(); err != nil {
errs = append(errs, err)
}
adx.client = nil
adx.ingestors = nil
if len(errs) == 0 {
return nil
}
// Combine errors into a single object and return the combined error
return kustoerrors.GetCombinedError(errs...)
}
func (adx *Client) PushMetrics(format ingest.FileOption, tableName string, metrics []byte) error {
ctx := context.Background()
ctx, cancel := context.WithTimeout(ctx, time.Duration(adx.cfg.Timeout))
defer cancel()
metricIngestor, err := adx.getMetricIngestor(ctx, tableName)
if err != nil {
return err
}
reader := bytes.NewReader(metrics)
mapping := ingest.IngestionMappingRef(tableName+"_mapping", ingest.JSON)
if metricIngestor != nil {
if _, err := metricIngestor.FromReader(ctx, reader, format, mapping); err != nil {
return fmt.Errorf("sending ingestion request to Azure Data Explorer for table %q failed: %w", tableName, err)
}
}
return nil
}
func (adx *Client) getMetricIngestor(ctx context.Context, tableName string) (ingest.Ingestor, error) {
if ingestor := adx.ingestors[tableName]; ingestor != nil {
return ingestor, nil
}
if adx.cfg.CreateTables {
if _, err := adx.client.Mgmt(ctx, adx.cfg.Database, createTableCommand(tableName)); err != nil {
return nil, fmt.Errorf("creating table for %q failed: %w", tableName, err)
}
if _, err := adx.client.Mgmt(ctx, adx.cfg.Database, createTableMappingCommand(tableName)); err != nil {
return nil, err
}
}
// Create a new ingestor client for the table
var ingestor ingest.Ingestor
var err error
switch strings.ToLower(adx.cfg.IngestionType) {
case ManagedIngestion:
ingestor, err = ingest.NewManaged(adx.client, adx.cfg.Database, tableName)
case QueuedIngestion:
ingestor, err = ingest.New(adx.client, adx.cfg.Database, tableName, ingest.WithStaticBuffer(bufferSize, maxBuffers))
default:
return nil, fmt.Errorf(`ingestion_type has to be one of %q or %q`, ManagedIngestion, QueuedIngestion)
}
if err != nil {
return nil, fmt.Errorf("creating ingestor for %q failed: %w", tableName, err)
}
adx.ingestors[tableName] = ingestor
return ingestor, nil
}
func createTableCommand(table string) kusto.Statement {
builder := kql.New(`.create-merge table ['`).AddTable(table).AddLiteral(`'] `)
builder.AddLiteral(`(['fields']:dynamic, ['name']:string, ['tags']:dynamic, ['timestamp']:datetime);`)
return builder
}
func createTableMappingCommand(table string) kusto.Statement {
builder := kql.New(`.create-or-alter table ['`).AddTable(table).AddLiteral(`'] `)
builder.AddLiteral(`ingestion json mapping '`).AddTable(table + "_mapping").AddLiteral(`' `)
builder.AddLiteral(`'[{"column":"fields", `)
builder.AddLiteral(`"Properties":{"Path":"$[\'fields\']"}},{"column":"name", `)
builder.AddLiteral(`"Properties":{"Path":"$[\'name\']"}},{"column":"tags", `)
builder.AddLiteral(`"Properties":{"Path":"$[\'tags\']"}},{"column":"timestamp", `)
builder.AddLiteral(`"Properties":{"Path":"$[\'timestamp\']"}}]'`)
return builder
}

View file

@ -0,0 +1,219 @@
package adx
import (
"bufio"
"context"
"encoding/json"
"io"
"testing"
"time"
"github.com/Azure/azure-kusto-go/kusto"
"github.com/Azure/azure-kusto-go/kusto/ingest"
"github.com/stretchr/testify/require"
"github.com/influxdata/telegraf"
"github.com/influxdata/telegraf/config"
serializers_json "github.com/influxdata/telegraf/plugins/serializers/json"
"github.com/influxdata/telegraf/testutil"
)
func TestInitBlankEndpointData(t *testing.T) {
plugin := Config{
Endpoint: "",
Database: "mydb",
}
_, err := plugin.NewClient("TestKusto.Telegraf", nil)
require.Error(t, err)
require.Equal(t, "endpoint configuration cannot be empty", err.Error())
}
func TestQueryConstruction(t *testing.T) {
const tableName = "mytable"
const expectedCreate = `.create-merge table ['mytable'] (['fields']:dynamic, ['name']:string, ['tags']:dynamic, ['timestamp']:datetime);`
const expectedMapping = `` +
`.create-or-alter table ['mytable'] ingestion json mapping 'mytable_mapping' '[{"column":"fields", ` +
`"Properties":{"Path":"$[\'fields\']"}},{"column":"name", "Properties":{"Path":"$[\'name\']"}},{"column":"tags", ` +
`"Properties":{"Path":"$[\'tags\']"}},{"column":"timestamp", "Properties":{"Path":"$[\'timestamp\']"}}]'`
require.Equal(t, expectedCreate, createTableCommand(tableName).String())
require.Equal(t, expectedMapping, createTableMappingCommand(tableName).String())
}
func TestGetMetricIngestor(t *testing.T) {
plugin := Client{
logger: testutil.Logger{},
client: kusto.NewMockClient(),
cfg: &Config{
Database: "mydb",
IngestionType: QueuedIngestion,
},
ingestors: map[string]ingest.Ingestor{"test1": &fakeIngestor{}},
}
ingestor, err := plugin.getMetricIngestor(t.Context(), "test1")
require.NoError(t, err)
require.NotNil(t, ingestor)
}
func TestGetMetricIngestorNoIngester(t *testing.T) {
plugin := Client{
logger: testutil.Logger{},
client: kusto.NewMockClient(),
cfg: &Config{
IngestionType: QueuedIngestion,
},
ingestors: map[string]ingest.Ingestor{"test1": &fakeIngestor{}},
}
ingestor, err := plugin.getMetricIngestor(t.Context(), "test1")
require.NoError(t, err)
require.NotNil(t, ingestor)
}
func TestPushMetrics(t *testing.T) {
plugin := Client{
logger: testutil.Logger{},
client: kusto.NewMockClient(),
cfg: &Config{
Database: "mydb",
Endpoint: "https://ingest-test.westus.kusto.windows.net",
IngestionType: QueuedIngestion,
},
ingestors: map[string]ingest.Ingestor{"test1": &fakeIngestor{}},
}
metrics := []byte(`{"fields": {"value": 1}, "name": "test1", "tags": {"tag1": "value1"}, "timestamp": "2021-01-01T00:00:00Z"}`)
require.NoError(t, plugin.PushMetrics(ingest.FileFormat(ingest.JSON), "test1", metrics))
}
func TestPushMetricsOutputs(t *testing.T) {
testCases := []struct {
name string
inputMetric []telegraf.Metric
metricsGrouping string
createTables bool
ingestionType string
}{
{
name: "Valid metric",
inputMetric: testutil.MockMetrics(),
createTables: true,
metricsGrouping: TablePerMetric,
},
{
name: "Don't create tables'",
inputMetric: testutil.MockMetrics(),
createTables: false,
metricsGrouping: TablePerMetric,
},
{
name: "SingleTable metric grouping type",
inputMetric: testutil.MockMetrics(),
createTables: true,
metricsGrouping: SingleTable,
},
{
name: "Valid metric managed ingestion",
inputMetric: testutil.MockMetrics(),
createTables: true,
metricsGrouping: TablePerMetric,
ingestionType: ManagedIngestion,
},
}
var expectedMetric = map[string]interface{}{
"metricName": "test1",
"fields": map[string]interface{}{
"value": 1.0,
},
"tags": map[string]interface{}{
"tag1": "value1",
},
"timestamp": float64(time.Date(2009, time.November, 10, 23, 0, 0, 0, time.UTC).UnixNano() / int64(time.Second)),
}
for _, tC := range testCases {
t.Run(tC.name, func(t *testing.T) {
ingestionType := "queued"
if tC.ingestionType != "" {
ingestionType = tC.ingestionType
}
serializer := &serializers_json.Serializer{
TimestampUnits: config.Duration(time.Nanosecond),
TimestampFormat: time.RFC3339Nano,
}
cfg := &Config{
Endpoint: "https://someendpoint.kusto.net",
Database: "databasename",
MetricsGrouping: tC.metricsGrouping,
TableName: "test1",
CreateTables: tC.createTables,
IngestionType: ingestionType,
Timeout: config.Duration(20 * time.Second),
}
client, err := cfg.NewClient("telegraf", &testutil.Logger{})
require.NoError(t, err)
// Inject the ingestor
ingestor := &fakeIngestor{}
client.ingestors["test1"] = ingestor
tableMetricGroups := make(map[string][]byte)
mockmetrics := testutil.MockMetrics()
for _, m := range mockmetrics {
metricInBytes, err := serializer.Serialize(m)
require.NoError(t, err)
tableMetricGroups[m.Name()] = append(tableMetricGroups[m.Name()], metricInBytes...)
}
format := ingest.FileFormat(ingest.JSON)
for tableName, tableMetrics := range tableMetricGroups {
require.NoError(t, client.PushMetrics(format, tableName, tableMetrics))
createdFakeIngestor := ingestor
require.EqualValues(t, expectedMetric["metricName"], createdFakeIngestor.actualOutputMetric["name"])
require.EqualValues(t, expectedMetric["fields"], createdFakeIngestor.actualOutputMetric["fields"])
require.EqualValues(t, expectedMetric["tags"], createdFakeIngestor.actualOutputMetric["tags"])
timestampStr := createdFakeIngestor.actualOutputMetric["timestamp"].(string)
parsedTime, err := time.Parse(time.RFC3339Nano, timestampStr)
parsedTimeFloat := float64(parsedTime.UnixNano()) / 1e9
require.NoError(t, err)
require.InDelta(t, expectedMetric["timestamp"].(float64), parsedTimeFloat, testutil.DefaultDelta)
}
})
}
}
func TestAlreadyClosed(t *testing.T) {
plugin := Client{
logger: testutil.Logger{},
cfg: &Config{
IngestionType: QueuedIngestion,
},
client: kusto.NewMockClient(),
}
require.NoError(t, plugin.Close())
}
type fakeIngestor struct {
actualOutputMetric map[string]interface{}
}
func (f *fakeIngestor) FromReader(_ context.Context, reader io.Reader, _ ...ingest.FileOption) (*ingest.Result, error) {
scanner := bufio.NewScanner(reader)
scanner.Scan()
firstLine := scanner.Text()
err := json.Unmarshal([]byte(firstLine), &f.actualOutputMetric)
if err != nil {
return nil, err
}
return &ingest.Result{}, nil
}
func (*fakeIngestor) FromFile(_ context.Context, _ string, _ ...ingest.FileOption) (*ingest.Result, error) {
return &ingest.Result{}, nil
}
func (*fakeIngestor) Close() error {
return nil
}

View file

@ -0,0 +1,23 @@
package auth
import (
"crypto/subtle"
"net/http"
)
type BasicAuth struct {
Username string `toml:"username"`
Password string `toml:"password"`
}
func (b *BasicAuth) Verify(r *http.Request) bool {
if b.Username == "" && b.Password == "" {
return true
}
username, password, ok := r.BasicAuth()
usernameComparison := subtle.ConstantTimeCompare([]byte(username), []byte(b.Username)) == 1
passwordComparison := subtle.ConstantTimeCompare([]byte(password), []byte(b.Password)) == 1
return ok && usernameComparison && passwordComparison
}

View file

@ -0,0 +1,34 @@
package auth
import (
"net/http/httptest"
"testing"
"github.com/stretchr/testify/require"
)
func TestBasicAuth_VerifyWithCredentials(t *testing.T) {
auth := BasicAuth{"username", "password"}
r := httptest.NewRequest("GET", "/github", nil)
r.SetBasicAuth(auth.Username, auth.Password)
require.True(t, auth.Verify(r))
}
func TestBasicAuth_VerifyWithoutCredentials(t *testing.T) {
auth := BasicAuth{}
r := httptest.NewRequest("GET", "/github", nil)
require.True(t, auth.Verify(r))
}
func TestBasicAuth_VerifyWithInvalidCredentials(t *testing.T) {
auth := BasicAuth{"username", "password"}
r := httptest.NewRequest("GET", "/github", nil)
r.SetBasicAuth("wrong-username", "wrong-password")
require.False(t, auth.Verify(r))
}

View file

@ -0,0 +1,84 @@
package aws
import (
"context"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/credentials"
"github.com/aws/aws-sdk-go-v2/credentials/stscreds"
"github.com/aws/aws-sdk-go-v2/service/sts"
)
// The endpoint_url supplied here is used for specific AWS service (Cloudwatch / Timestream / etc.)
type CredentialConfig struct {
Region string `toml:"region"`
AccessKey string `toml:"access_key"`
SecretKey string `toml:"secret_key"`
RoleARN string `toml:"role_arn"`
Profile string `toml:"profile"`
Filename string `toml:"shared_credential_file"`
Token string `toml:"token"`
EndpointURL string `toml:"endpoint_url"`
RoleSessionName string `toml:"role_session_name"`
WebIdentityTokenFile string `toml:"web_identity_token_file"`
}
func (c *CredentialConfig) Credentials() (aws.Config, error) {
if c.RoleARN != "" {
return c.configWithAssumeCredentials()
}
return c.configWithRootCredentials()
}
func (c *CredentialConfig) configWithRootCredentials() (aws.Config, error) {
options := []func(*config.LoadOptions) error{
config.WithRegion(c.Region),
}
if c.Profile != "" {
options = append(options, config.WithSharedConfigProfile(c.Profile))
}
if c.Filename != "" {
options = append(options, config.WithSharedCredentialsFiles([]string{c.Filename}))
}
if c.AccessKey != "" || c.SecretKey != "" {
provider := credentials.NewStaticCredentialsProvider(c.AccessKey, c.SecretKey, c.Token)
options = append(options, config.WithCredentialsProvider(provider))
}
return config.LoadDefaultConfig(context.Background(), options...)
}
func (c *CredentialConfig) configWithAssumeCredentials() (aws.Config, error) {
// To generate credentials using assumeRole, we need to create AWS STS client with the default AWS endpoint,
defaultConfig, err := c.configWithRootCredentials()
if err != nil {
return aws.Config{}, err
}
var provider aws.CredentialsProvider
stsService := sts.NewFromConfig(defaultConfig)
if c.WebIdentityTokenFile != "" {
provider = stscreds.NewWebIdentityRoleProvider(
stsService,
c.RoleARN,
stscreds.IdentityTokenFile(c.WebIdentityTokenFile),
func(opts *stscreds.WebIdentityRoleOptions) {
if c.RoleSessionName != "" {
opts.RoleSessionName = c.RoleSessionName
}
},
)
} else {
provider = stscreds.NewAssumeRoleProvider(stsService, c.RoleARN, func(opts *stscreds.AssumeRoleOptions) {
if c.RoleSessionName != "" {
opts.RoleSessionName = c.RoleSessionName
}
})
}
defaultConfig.Credentials = aws.NewCredentialsCache(provider)
return defaultConfig, nil
}

View file

@ -0,0 +1,136 @@
package cookie
import (
"context"
"fmt"
"io"
"net/http"
"net/http/cookiejar"
"strings"
"sync"
"time"
clockutil "github.com/benbjohnson/clock"
"github.com/influxdata/telegraf"
"github.com/influxdata/telegraf/config"
)
type CookieAuthConfig struct {
URL string `toml:"cookie_auth_url"`
Method string `toml:"cookie_auth_method"`
Headers map[string]*config.Secret `toml:"cookie_auth_headers"`
// HTTP Basic Auth Credentials
Username string `toml:"cookie_auth_username"`
Password string `toml:"cookie_auth_password"`
Body string `toml:"cookie_auth_body"`
Renewal config.Duration `toml:"cookie_auth_renewal"`
client *http.Client
wg sync.WaitGroup
}
func (c *CookieAuthConfig) Start(client *http.Client, log telegraf.Logger, clock clockutil.Clock) (err error) {
if err := c.initializeClient(client); err != nil {
return err
}
// continual auth renewal if set
if c.Renewal > 0 {
ticker := clock.Ticker(time.Duration(c.Renewal))
// this context is used in the tests only, it is to cancel the goroutine
go c.authRenewal(context.Background(), ticker, log)
}
return nil
}
func (c *CookieAuthConfig) initializeClient(client *http.Client) (err error) {
c.client = client
if c.Method == "" {
c.Method = http.MethodPost
}
return c.auth()
}
func (c *CookieAuthConfig) authRenewal(ctx context.Context, ticker *clockutil.Ticker, log telegraf.Logger) {
for {
select {
case <-ctx.Done():
c.wg.Done()
return
case <-ticker.C:
if err := c.auth(); err != nil && log != nil {
log.Errorf("renewal failed for %q: %v", c.URL, err)
}
}
}
}
func (c *CookieAuthConfig) auth() error {
var err error
// everytime we auth we clear out the cookie jar to ensure that the cookie
// is not used as a part of re-authing. The only way to empty or reset is
// to create a new cookie jar.
c.client.Jar, err = cookiejar.New(nil)
if err != nil {
return err
}
var body io.Reader
if c.Body != "" {
body = strings.NewReader(c.Body)
}
req, err := http.NewRequest(c.Method, c.URL, body)
if err != nil {
return err
}
if c.Username != "" {
req.SetBasicAuth(c.Username, c.Password)
}
for k, v := range c.Headers {
secret, err := v.Get()
if err != nil {
return err
}
headerVal := secret.String()
if strings.EqualFold(k, "host") {
req.Host = headerVal
} else {
req.Header.Add(k, headerVal)
}
secret.Destroy()
}
resp, err := c.client.Do(req)
if err != nil {
return err
}
defer resp.Body.Close()
respBody, err := io.ReadAll(resp.Body)
if err != nil {
return err
}
if resp.StatusCode < 200 || resp.StatusCode > 299 {
return fmt.Errorf("cookie auth renewal received status code: %v (%v) [%v]",
resp.StatusCode,
http.StatusText(resp.StatusCode),
string(respBody),
)
}
return nil
}

View file

@ -0,0 +1,281 @@
package cookie
import (
"context"
"errors"
"io"
"net/http"
"net/http/httptest"
"sync/atomic"
"testing"
"time"
clockutil "github.com/benbjohnson/clock"
"github.com/google/go-cmp/cmp"
"github.com/stretchr/testify/require"
"github.com/influxdata/telegraf/config"
"github.com/influxdata/telegraf/testutil"
)
const (
reqUser = "testUser"
reqPasswd = "testPassword"
reqBody = "a body"
reqHeaderKey = "hello"
reqHeaderVal = "world"
authEndpointNoCreds = "/auth"
authEndpointWithBasicAuth = "/authWithCreds"
authEndpointWithBasicAuthOnlyUsername = "/authWithCredsUser"
authEndpointWithBody = "/authWithBody"
authEndpointWithHeader = "/authWithHeader"
)
var fakeCookie = &http.Cookie{
Name: "test-cookie",
Value: "this is an auth cookie",
}
var reqHeaderValSecret = config.NewSecret([]byte(reqHeaderVal))
type fakeServer struct {
*httptest.Server
*int32
}
func newFakeServer(t *testing.T) fakeServer {
var c int32
return fakeServer{
Server: httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
authed := func() {
atomic.AddInt32(&c, 1) // increment auth counter
http.SetCookie(w, fakeCookie) // set fake cookie
}
switch r.URL.Path {
case authEndpointNoCreds:
authed()
case authEndpointWithHeader:
if !cmp.Equal(r.Header.Get(reqHeaderKey), reqHeaderVal) {
w.WriteHeader(http.StatusUnauthorized)
return
}
authed()
case authEndpointWithBody:
body, err := io.ReadAll(r.Body)
if err != nil {
w.WriteHeader(http.StatusInternalServerError)
t.Error(err)
return
}
if !cmp.Equal([]byte(reqBody), body) {
w.WriteHeader(http.StatusUnauthorized)
return
}
authed()
case authEndpointWithBasicAuth:
u, p, ok := r.BasicAuth()
if !ok || u != reqUser || p != reqPasswd {
w.WriteHeader(http.StatusUnauthorized)
return
}
authed()
case authEndpointWithBasicAuthOnlyUsername:
u, p, ok := r.BasicAuth()
if !ok || u != reqUser || p != "" {
w.WriteHeader(http.StatusUnauthorized)
return
}
authed()
default:
// ensure cookie exists on request
if _, err := r.Cookie(fakeCookie.Name); err != nil {
w.WriteHeader(http.StatusForbidden)
return
}
if _, err := w.Write([]byte("good test response")); err != nil {
w.WriteHeader(http.StatusInternalServerError)
t.Error(err)
return
}
}
})),
int32: &c,
}
}
func (s fakeServer) checkResp(t *testing.T, expCode int) {
t.Helper()
resp, err := s.Client().Get(s.URL + "/endpoint")
require.NoError(t, err)
defer resp.Body.Close()
require.Equal(t, expCode, resp.StatusCode)
if expCode == http.StatusOK {
require.Len(t, resp.Request.Cookies(), 1)
require.Equal(t, "test-cookie", resp.Request.Cookies()[0].Name)
}
}
func (s fakeServer) checkAuthCount(t *testing.T, atLeast int32) {
t.Helper()
require.GreaterOrEqual(t, atomic.LoadInt32(s.int32), atLeast)
}
func TestAuthConfig_Start(t *testing.T) {
const (
renewal = 50 * time.Millisecond
renewalCheck = 5 * renewal
)
type fields struct {
Method string
Username string
Password string
Body string
Headers map[string]*config.Secret
}
type args struct {
renewal time.Duration
endpoint string
}
tests := []struct {
name string
fields fields
args args
wantErr error
firstAuthCount int32
lastAuthCount int32
firstHTTPResponse int
lastHTTPResponse int
}{
{
name: "success no creds, no body, default method",
args: args{
renewal: renewal,
endpoint: authEndpointNoCreds,
},
firstAuthCount: 1,
lastAuthCount: 3,
firstHTTPResponse: http.StatusOK,
lastHTTPResponse: http.StatusOK,
},
{
name: "success no creds, no body, default method, header set",
args: args{
renewal: renewal,
endpoint: authEndpointWithHeader,
},
fields: fields{
Headers: map[string]*config.Secret{reqHeaderKey: &reqHeaderValSecret},
},
firstAuthCount: 1,
lastAuthCount: 3,
firstHTTPResponse: http.StatusOK,
lastHTTPResponse: http.StatusOK,
},
{
name: "success with creds, no body",
fields: fields{
Method: http.MethodPost,
Username: reqUser,
Password: reqPasswd,
},
args: args{
renewal: renewal,
endpoint: authEndpointWithBasicAuth,
},
firstAuthCount: 1,
lastAuthCount: 3,
firstHTTPResponse: http.StatusOK,
lastHTTPResponse: http.StatusOK,
},
{
name: "failure with bad creds",
fields: fields{
Method: http.MethodPost,
Username: reqUser,
Password: "a bad password",
},
args: args{
renewal: renewal,
endpoint: authEndpointWithBasicAuth,
},
wantErr: errors.New("cookie auth renewal received status code: 401 (Unauthorized) []"),
firstAuthCount: 0,
lastAuthCount: 0,
firstHTTPResponse: http.StatusForbidden,
lastHTTPResponse: http.StatusForbidden,
},
{
name: "success with no creds, with good body",
fields: fields{
Method: http.MethodPost,
Body: reqBody,
},
args: args{
renewal: renewal,
endpoint: authEndpointWithBody,
},
firstAuthCount: 1,
lastAuthCount: 3,
firstHTTPResponse: http.StatusOK,
lastHTTPResponse: http.StatusOK,
},
{
name: "failure with bad body",
fields: fields{
Method: http.MethodPost,
Body: "a bad body",
},
args: args{
renewal: renewal,
endpoint: authEndpointWithBody,
},
wantErr: errors.New("cookie auth renewal received status code: 401 (Unauthorized) []"),
firstAuthCount: 0,
lastAuthCount: 0,
firstHTTPResponse: http.StatusForbidden,
lastHTTPResponse: http.StatusForbidden,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
srv := newFakeServer(t)
c := &CookieAuthConfig{
URL: srv.URL + tt.args.endpoint,
Method: tt.fields.Method,
Username: tt.fields.Username,
Password: tt.fields.Password,
Body: tt.fields.Body,
Headers: tt.fields.Headers,
Renewal: config.Duration(tt.args.renewal),
}
if err := c.initializeClient(srv.Client()); tt.wantErr != nil {
require.EqualError(t, err, tt.wantErr.Error())
} else {
require.NoError(t, err)
}
mock := clockutil.NewMock()
ticker := mock.Ticker(time.Duration(c.Renewal))
defer ticker.Stop()
c.wg.Add(1)
ctx, cancel := context.WithCancel(t.Context())
go c.authRenewal(ctx, ticker, testutil.Logger{Name: "cookie_auth"})
srv.checkAuthCount(t, tt.firstAuthCount)
srv.checkResp(t, tt.firstHTTPResponse)
mock.Add(renewalCheck)
// Ensure that the auth renewal goroutine has completed
require.Eventually(t, func() bool { return atomic.LoadInt32(srv.int32) >= tt.lastAuthCount }, time.Second, 10*time.Millisecond)
cancel()
c.wg.Wait()
srv.checkAuthCount(t, tt.lastAuthCount)
srv.checkResp(t, tt.lastHTTPResponse)
srv.Close()
})
}
}

View file

@ -0,0 +1,79 @@
// Package docker contains few helper functions copied from
// https://github.com/docker/cli/blob/master/cli/command/container/stats_helpers.go
package docker
import (
"github.com/docker/docker/api/types/container"
)
// CalculateCPUPercentUnix calculate CPU usage (for Unix, in percentages)
func CalculateCPUPercentUnix(previousCPU, previousSystem uint64, v *container.StatsResponse) float64 {
var (
cpuPercent = 0.0
// calculate the change for the cpu usage of the container in between readings
cpuDelta = float64(v.CPUStats.CPUUsage.TotalUsage) - float64(previousCPU)
// calculate the change for the entire system between readings
systemDelta = float64(v.CPUStats.SystemUsage) - float64(previousSystem)
onlineCPUs = float64(v.CPUStats.OnlineCPUs)
)
if onlineCPUs == 0.0 {
onlineCPUs = float64(len(v.CPUStats.CPUUsage.PercpuUsage))
}
if systemDelta > 0.0 && cpuDelta > 0.0 {
cpuPercent = (cpuDelta / systemDelta) * onlineCPUs * 100.0
}
return cpuPercent
}
// CalculateCPUPercentWindows calculate CPU usage (for Windows, in percentages)
func CalculateCPUPercentWindows(v *container.StatsResponse) float64 {
// Max number of 100ns intervals between the previous time read and now
possIntervals := uint64(v.Read.Sub(v.PreRead).Nanoseconds()) // Start with number of ns intervals
possIntervals /= 100 // Convert to number of 100ns intervals
possIntervals *= uint64(v.NumProcs) // Multiple by the number of processors
// Intervals used
intervalsUsed := v.CPUStats.CPUUsage.TotalUsage - v.PreCPUStats.CPUUsage.TotalUsage
// Percentage avoiding divide-by-zero
if possIntervals > 0 {
return float64(intervalsUsed) / float64(possIntervals) * 100.0
}
return 0.00
}
// CalculateMemUsageUnixNoCache calculate memory usage of the container.
// Cache is intentionally excluded to avoid misinterpretation of the output.
//
// On Docker 19.03 and older, the result is `mem.Usage - mem.Stats["cache"]`.
// On new docker with cgroup v1 host, the result is `mem.Usage - mem.Stats["total_inactive_file"]`.
// On new docker with cgroup v2 host, the result is `mem.Usage - mem.Stats["inactive_file"]`.
//
// This definition is designed to be consistent with past values and the latest docker CLI
// * https://github.com/docker/cli/blob/6e2838e18645e06f3e4b6c5143898ccc44063e3b/cli/command/container/stats_helpers.go#L239
func CalculateMemUsageUnixNoCache(mem container.MemoryStats) float64 {
// Docker 19.03 and older
if v, isOldDocker := mem.Stats["cache"]; isOldDocker && v < mem.Usage {
return float64(mem.Usage - v)
}
// cgroup v1
if v, isCgroup1 := mem.Stats["total_inactive_file"]; isCgroup1 && v < mem.Usage {
return float64(mem.Usage - v)
}
// cgroup v2
if v := mem.Stats["inactive_file"]; v < mem.Usage {
return float64(mem.Usage - v)
}
return float64(mem.Usage)
}
// CalculateMemPercentUnixNoCache calculate memory usage of the container, in percentages.
func CalculateMemPercentUnixNoCache(limit, usedNoCache float64) float64 {
// MemoryStats.Limit will never be 0 unless the container is not running and we haven't
// got any data from cgroup
if limit != 0 {
return usedNoCache / limit * 100.0
}
return 0
}

View file

@ -0,0 +1,34 @@
package encoding
import (
"errors"
"golang.org/x/text/encoding"
"golang.org/x/text/encoding/unicode"
)
// NewDecoder returns an x/text Decoder for the specified text encoding. The
// Decoder converts a character encoding into utf-8 bytes. If a BOM is found
// it will be converted into an utf-8 BOM, you can use
// github.com/dimchansky/utfbom to strip the BOM.
//
// The "none" or "" encoding will pass through bytes unchecked. Use the utf-8
// encoding if you want invalid bytes replaced using the unicode
// replacement character.
//
// Detection of utf-16 endianness using the BOM is not currently provided due
// to the tail input plugins requirement to be able to start at the middle or
// end of the file.
func NewDecoder(enc string) (*Decoder, error) {
switch enc {
case "utf-8":
return createDecoder(unicode.UTF8.NewDecoder()), nil
case "utf-16le":
return createDecoder(unicode.UTF16(unicode.LittleEndian, unicode.IgnoreBOM).NewDecoder()), nil
case "utf-16be":
return createDecoder(unicode.UTF16(unicode.BigEndian, unicode.IgnoreBOM).NewDecoder()), nil
case "none", "":
return createDecoder(encoding.Nop.NewDecoder()), nil
}
return nil, errors.New("unknown character encoding")
}

View file

@ -0,0 +1,164 @@
package encoding
import (
"errors"
"io"
"golang.org/x/text/transform"
)
// Other than resetting r.err and r.transformComplete in Read() this
// was copied from x/text
func createDecoder(t transform.Transformer) *Decoder {
return &Decoder{Transformer: t}
}
// A Decoder converts bytes to UTF-8. It implements transform.Transformer.
//
// Transforming source bytes that are not of that encoding will not result in an
// error per se. Each byte that cannot be transcoded will be represented in the
// output by the UTF-8 encoding of '\uFFFD', the replacement rune.
type Decoder struct {
transform.Transformer
// This forces external creators of Decoders to use names in struct
// initializers, allowing for future extensibility without having to break
// code.
_ struct{}
}
// Bytes converts the given encoded bytes to UTF-8. It returns the converted
// bytes or nil, err if any error occurred.
func (d *Decoder) Bytes(b []byte) ([]byte, error) {
b, _, err := transform.Bytes(d, b)
if err != nil {
return nil, err
}
return b, nil
}
// String converts the given encoded string to UTF-8. It returns the converted
// string or "", err if any error occurred.
func (d *Decoder) String(s string) (string, error) {
s, _, err := transform.String(d, s)
if err != nil {
return "", err
}
return s, nil
}
// Reader wraps another Reader to decode its bytes.
//
// The Decoder may not be used for any other operation as long as the returned
// Reader is in use.
func (d *Decoder) Reader(r io.Reader) io.Reader {
return NewReader(r, d)
}
// Reader wraps another io.Reader by transforming the bytes read.
type Reader struct {
r io.Reader
t transform.Transformer
err error
// dst[dst0:dst1] contains bytes that have been transformed by t but
// not yet copied out via Read.
dst []byte
dst0, dst1 int
// src[src0:src1] contains bytes that have been read from r but not
// yet transformed through t.
src []byte
src0, src1 int
// transformComplete is whether the transformation is complete,
// regardless of whether it was successful.
transformComplete bool
}
var (
// errInconsistentByteCount means that Transform returned success (nil
// error) but also returned nSrc inconsistent with the src argument.
errInconsistentByteCount = errors.New("transform: inconsistent byte count returned")
)
const defaultBufSize = 4096
// NewReader returns a new Reader that wraps r by transforming the bytes read
// via t. It calls Reset on t.
func NewReader(r io.Reader, t transform.Transformer) *Reader {
t.Reset()
return &Reader{
r: r,
t: t,
dst: make([]byte, defaultBufSize),
src: make([]byte, defaultBufSize),
}
}
// Read implements the io.Reader interface.
func (r *Reader) Read(p []byte) (int, error) {
// Clear previous errors so a Read can be performed even if the last call
// returned EOF.
r.err = nil
r.transformComplete = false
n := 0
for {
// Copy out any transformed bytes and return the final error if we are done.
if r.dst0 != r.dst1 {
n = copy(p, r.dst[r.dst0:r.dst1])
r.dst0 += n
if r.dst0 == r.dst1 && r.transformComplete {
return n, r.err
}
return n, nil
} else if r.transformComplete {
return 0, r.err
}
// Try to transform some source bytes, or to flush the transformer if we
// are out of source bytes. We do this even if r.r.Read returned an error.
// As the io.Reader documentation says, "process the n > 0 bytes returned
// before considering the error".
if r.src0 != r.src1 || r.err != nil {
var err error
r.dst0 = 0
r.dst1, n, err = r.t.Transform(r.dst, r.src[r.src0:r.src1], errors.Is(r.err, io.EOF))
r.src0 += n
switch {
case err == nil:
if r.src0 != r.src1 {
r.err = errInconsistentByteCount
}
// The Transform call was successful; we are complete if we
// cannot read more bytes into src.
r.transformComplete = r.err != nil
continue
case errors.Is(err, transform.ErrShortDst) && (r.dst1 != 0 || n != 0):
// Make room in dst by copying out, and try again.
continue
case errors.Is(err, transform.ErrShortSrc) && r.src1-r.src0 != len(r.src) && r.err == nil:
// Read more bytes into src via the code below, and try again.
default:
r.transformComplete = true
// The reader error (r.err) takes precedence over the
// transformer error (err) unless r.err is nil or io.EOF.
if r.err == nil || errors.Is(r.err, io.EOF) {
r.err = err
}
continue
}
}
// Move any untransformed source bytes to the start of the buffer
// and read more bytes.
if r.src0 != 0 {
r.src0, r.src1 = 0, copy(r.src, r.src[r.src0:r.src1])
}
n, r.err = r.r.Read(r.src[r.src1:])
r.src1 += n
}
}

View file

@ -0,0 +1,78 @@
package encoding
import (
"bytes"
"io"
"testing"
"github.com/stretchr/testify/require"
)
func TestDecoder(t *testing.T) {
tests := []struct {
name string
encoding string
input []byte
expected []byte
expectedErr bool
}{
{
name: "no decoder utf-8",
encoding: "",
input: []byte("howdy"),
expected: []byte("howdy"),
},
{
name: "utf-8 decoder",
encoding: "utf-8",
input: []byte("howdy"),
expected: []byte("howdy"),
},
{
name: "utf-8 decoder invalid bytes replaced with replacement char",
encoding: "utf-8",
input: []byte("\xff\xfe"),
expected: []byte("\uFFFD\uFFFD"),
},
{
name: "utf-16le decoder no BOM",
encoding: "utf-16le",
input: []byte("h\x00o\x00w\x00d\x00y\x00"),
expected: []byte("howdy"),
},
{
name: "utf-16le decoder with BOM",
encoding: "utf-16le",
input: []byte("\xff\xfeh\x00o\x00w\x00d\x00y\x00"),
expected: []byte("\xef\xbb\xbfhowdy"),
},
{
name: "utf-16be decoder no BOM",
encoding: "utf-16be",
input: []byte("\x00h\x00o\x00w\x00d\x00y"),
expected: []byte("howdy"),
},
{
name: "utf-16be decoder with BOM",
encoding: "utf-16be",
input: []byte("\xfe\xff\x00h\x00o\x00w\x00d\x00y"),
expected: []byte("\xef\xbb\xbfhowdy"),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
decoder, err := NewDecoder(tt.encoding)
require.NoError(t, err)
buf := bytes.NewBuffer(tt.input)
r := decoder.Reader(buf)
actual, err := io.ReadAll(r)
if tt.expectedErr {
require.Error(t, err)
return
}
require.NoError(t, err)
require.Equal(t, tt.expected, actual)
})
}
}

View file

@ -0,0 +1,78 @@
package httpconfig
import (
"context"
"fmt"
"net/http"
"time"
"github.com/benbjohnson/clock"
"github.com/peterbourgon/unixtransport"
"github.com/influxdata/telegraf"
"github.com/influxdata/telegraf/config"
"github.com/influxdata/telegraf/plugins/common/cookie"
"github.com/influxdata/telegraf/plugins/common/oauth"
"github.com/influxdata/telegraf/plugins/common/proxy"
"github.com/influxdata/telegraf/plugins/common/tls"
)
// Common HTTP client struct.
type HTTPClientConfig struct {
Timeout config.Duration `toml:"timeout"`
IdleConnTimeout config.Duration `toml:"idle_conn_timeout"`
MaxIdleConns int `toml:"max_idle_conn"`
MaxIdleConnsPerHost int `toml:"max_idle_conn_per_host"`
ResponseHeaderTimeout config.Duration `toml:"response_timeout"`
proxy.HTTPProxy
tls.ClientConfig
oauth.OAuth2Config
cookie.CookieAuthConfig
}
func (h *HTTPClientConfig) CreateClient(ctx context.Context, log telegraf.Logger) (*http.Client, error) {
tlsCfg, err := h.ClientConfig.TLSConfig()
if err != nil {
return nil, fmt.Errorf("failed to set TLS config: %w", err)
}
prox, err := h.HTTPProxy.Proxy()
if err != nil {
return nil, fmt.Errorf("failed to set proxy: %w", err)
}
transport := &http.Transport{
TLSClientConfig: tlsCfg,
Proxy: prox,
IdleConnTimeout: time.Duration(h.IdleConnTimeout),
MaxIdleConns: h.MaxIdleConns,
MaxIdleConnsPerHost: h.MaxIdleConnsPerHost,
ResponseHeaderTimeout: time.Duration(h.ResponseHeaderTimeout),
}
// Register "http+unix" and "https+unix" protocol handler.
unixtransport.Register(transport)
client := &http.Client{
Transport: transport,
}
// While CreateOauth2Client returns a http.Client keeping the Transport configuration,
// it does not keep other http.Client parameters (e.g. Timeout).
client = h.OAuth2Config.CreateOauth2Client(ctx, client)
if h.CookieAuthConfig.URL != "" {
if err := h.CookieAuthConfig.Start(client, log, clock.New()); err != nil {
return nil, err
}
}
timeout := h.Timeout
if timeout == 0 {
timeout = config.Duration(time.Second * 5)
}
client.Timeout = time.Duration(timeout)
return client, nil
}

View file

@ -0,0 +1,275 @@
package jolokia2
import (
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"path"
"time"
"github.com/influxdata/telegraf/plugins/common/tls"
)
type Client struct {
URL string
client *http.Client
config *ClientConfig
}
type ClientConfig struct {
ResponseTimeout time.Duration
Username string
Password string
Origin string
ProxyConfig *ProxyConfig
tls.ClientConfig
}
type ProxyConfig struct {
DefaultTargetUsername string
DefaultTargetPassword string
Targets []ProxyTargetConfig
}
type ProxyTargetConfig struct {
Username string
Password string
URL string
}
type ReadRequest struct {
Mbean string
Attributes []string
Path string
}
type ReadResponse struct {
Status int
Value interface{}
RequestMbean string
RequestAttributes []string
RequestPath string
RequestTarget string
}
// Jolokia JSON request object. Example: {
// "type": "read",
// "mbean: "java.lang:type="Runtime",
// "attribute": "Uptime",
// "target": {
// "url: "service:jmx:rmi:///jndi/rmi://target:9010/jmxrmi"
// }
// }
type jolokiaRequest struct {
Type string `json:"type"`
Mbean string `json:"mbean"`
Attribute interface{} `json:"attribute,omitempty"`
Path string `json:"path,omitempty"`
Target *jolokiaTarget `json:"target,omitempty"`
}
type jolokiaTarget struct {
URL string `json:"url"`
User string `json:"user,omitempty"`
Password string `json:"password,omitempty"`
}
// Jolokia JSON response object. Example: {
// "request": {
// "type": "read"
// "mbean": "java.lang:type=Runtime",
// "attribute": "Uptime",
// "target": {
// "url": "service:jmx:rmi:///jndi/rmi://target:9010/jmxrmi"
// }
// },
// "value": 1214083,
// "timestamp": 1488059309,
// "status": 200
// }
type jolokiaResponse struct {
Request jolokiaRequest `json:"request"`
Value interface{} `json:"value"`
Status int `json:"status"`
}
func NewClient(address string, config *ClientConfig) (*Client, error) {
tlsConfig, err := config.ClientConfig.TLSConfig()
if err != nil {
return nil, err
}
transport := &http.Transport{
ResponseHeaderTimeout: config.ResponseTimeout,
TLSClientConfig: tlsConfig,
}
client := &http.Client{
Transport: transport,
Timeout: config.ResponseTimeout,
}
return &Client{
URL: address,
config: config,
client: client,
}, nil
}
func (c *Client) read(requests []ReadRequest) ([]ReadResponse, error) {
jRequests := makeJolokiaRequests(requests, c.config.ProxyConfig)
requestBody, err := json.Marshal(jRequests)
if err != nil {
return nil, err
}
requestURL, err := formatReadURL(c.URL, c.config.Username, c.config.Password)
if err != nil {
return nil, err
}
req, err := http.NewRequest("POST", requestURL, bytes.NewBuffer(requestBody))
if err != nil {
// err is not contained in returned error - it may contain sensitive data (password) which should not be logged
return nil, fmt.Errorf("unable to create new request for: %q", c.URL)
}
req.Header.Add("Content-type", "application/json")
if c.config.Origin != "" {
req.Header.Add("Origin", c.config.Origin)
}
resp, err := c.client.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("response from url %q has status code %d (%s), expected %d (%s)",
c.URL, resp.StatusCode, http.StatusText(resp.StatusCode), http.StatusOK, http.StatusText(http.StatusOK))
}
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, err
}
var jResponses []jolokiaResponse
if err = json.Unmarshal(responseBody, &jResponses); err != nil {
return nil, fmt.Errorf("decoding JSON response: %w: %s", err, responseBody)
}
return makeReadResponses(jResponses), nil
}
func makeJolokiaRequests(rrequests []ReadRequest, proxyConfig *ProxyConfig) []jolokiaRequest {
jrequests := make([]jolokiaRequest, 0)
if proxyConfig == nil {
for _, rr := range rrequests {
jrequests = append(jrequests, makeJolokiaRequest(rr, nil))
}
} else {
for _, t := range proxyConfig.Targets {
if t.Username == "" {
t.Username = proxyConfig.DefaultTargetUsername
}
if t.Password == "" {
t.Password = proxyConfig.DefaultTargetPassword
}
for _, rr := range rrequests {
jtarget := &jolokiaTarget{
URL: t.URL,
User: t.Username,
Password: t.Password,
}
jrequests = append(jrequests, makeJolokiaRequest(rr, jtarget))
}
}
}
return jrequests
}
func makeJolokiaRequest(rrequest ReadRequest, jtarget *jolokiaTarget) jolokiaRequest {
jrequest := jolokiaRequest{
Type: "read",
Mbean: rrequest.Mbean,
Path: rrequest.Path,
Target: jtarget,
}
if len(rrequest.Attributes) == 1 {
jrequest.Attribute = rrequest.Attributes[0]
}
if len(rrequest.Attributes) > 1 {
jrequest.Attribute = rrequest.Attributes
}
return jrequest
}
func makeReadResponses(jresponses []jolokiaResponse) []ReadResponse {
rresponses := make([]ReadResponse, 0)
for _, jr := range jresponses {
rrequest := ReadRequest{
Mbean: jr.Request.Mbean,
Path: jr.Request.Path,
Attributes: make([]string, 0),
}
attrValue := jr.Request.Attribute
if attrValue != nil {
attribute, ok := attrValue.(string)
if ok {
rrequest.Attributes = []string{attribute}
} else {
attributes, _ := attrValue.([]interface{})
rrequest.Attributes = make([]string, 0, len(attributes))
for _, attr := range attributes {
rrequest.Attributes = append(rrequest.Attributes, attr.(string))
}
}
}
rresponse := ReadResponse{
Value: jr.Value,
Status: jr.Status,
RequestMbean: rrequest.Mbean,
RequestAttributes: rrequest.Attributes,
RequestPath: rrequest.Path,
}
if jtarget := jr.Request.Target; jtarget != nil {
rresponse.RequestTarget = jtarget.URL
}
rresponses = append(rresponses, rresponse)
}
return rresponses
}
func formatReadURL(configURL, username, password string) (string, error) {
parsedURL, err := url.Parse(configURL)
if err != nil {
return "", err
}
readURL := url.URL{
Host: parsedURL.Host,
Scheme: parsedURL.Scheme,
}
if username != "" || password != "" {
readURL.User = url.UserPassword(username, password)
}
readURL.Path = path.Join(parsedURL.Path, "read")
readURL.Query().Add("ignoreErrors", "true")
return readURL.String(), nil
}

View file

@ -0,0 +1,266 @@
package jolokia2
import (
"fmt"
"sort"
"strings"
"github.com/influxdata/telegraf"
)
const defaultFieldName = "value"
type Gatherer struct {
metrics []Metric
requests []ReadRequest
}
func NewGatherer(metrics []Metric) *Gatherer {
return &Gatherer{
metrics: metrics,
requests: makeReadRequests(metrics),
}
}
// Gather adds points to an accumulator from responses returned
// by a Jolokia agent.
func (g *Gatherer) Gather(client *Client, acc telegraf.Accumulator) error {
var tags map[string]string
if client.config.ProxyConfig != nil {
tags = map[string]string{"jolokia_proxy_url": client.URL}
} else {
tags = map[string]string{"jolokia_agent_url": client.URL}
}
requests := makeReadRequests(g.metrics)
responses, err := client.read(requests)
if err != nil {
return err
}
g.gatherResponses(responses, tags, acc)
return nil
}
// gatherResponses adds points to an accumulator from the ReadResponse objects
// returned by a Jolokia agent.
func (g *Gatherer) gatherResponses(responses []ReadResponse, tags map[string]string, acc telegraf.Accumulator) {
series := make(map[string][]point)
for _, metric := range g.metrics {
points, ok := series[metric.Name]
if !ok {
points = make([]point, 0)
}
responsePoints, responseErrors := generatePoints(metric, responses)
points = append(points, responsePoints...)
for _, err := range responseErrors {
acc.AddError(err)
}
series[metric.Name] = points
}
for measurement, points := range series {
for _, point := range compactPoints(points) {
acc.AddFields(measurement,
point.Fields, mergeTags(point.Tags, tags))
}
}
}
// generatePoints creates points for the supplied metric from the ReadResponse objects returned by the Jolokia client.
func generatePoints(metric Metric, responses []ReadResponse) ([]point, []error) {
points := make([]point, 0)
errors := make([]error, 0)
for _, response := range responses {
switch response.Status {
case 200:
// Correct response status - do nothing.
case 404:
continue
default:
errors = append(errors, fmt.Errorf("unexpected status in response from target %s (%q): %d",
response.RequestTarget, response.RequestMbean, response.Status))
continue
}
if !metricMatchesResponse(metric, response) {
continue
}
pb := NewPointBuilder(metric, response.RequestAttributes, response.RequestPath)
ps, err := pb.Build(metric.Mbean, response.Value)
if err != nil {
errors = append(errors, err)
continue
}
for _, point := range ps {
if response.RequestTarget != "" {
point.Tags["jolokia_agent_url"] = response.RequestTarget
}
points = append(points, point)
}
}
return points, errors
}
// mergeTags combines two tag sets into a single tag set.
func mergeTags(metricTags, outerTags map[string]string) map[string]string {
tags := make(map[string]string)
for k, v := range outerTags {
tags[k] = strings.Trim(v, `'"`)
}
for k, v := range metricTags {
tags[k] = strings.Trim(v, `'"`)
}
return tags
}
// metricMatchesResponse returns true when the name, attributes, and path
// of a Metric match the corresponding elements in a ReadResponse object
// returned by a Jolokia agent.
func metricMatchesResponse(metric Metric, response ReadResponse) bool {
if !metric.MatchObjectName(response.RequestMbean) {
return false
}
if len(metric.Paths) == 0 {
return len(response.RequestAttributes) == 0
}
for _, attribute := range response.RequestAttributes {
if metric.MatchAttributeAndPath(attribute, response.RequestPath) {
return true
}
}
return false
}
// compactPoints attempts to remove points by compacting points
// with matching tag sets. When a match is found, the fields from
// one point are moved to another, and the empty point is removed.
func compactPoints(points []point) []point {
compactedPoints := make([]point, 0)
for _, sourcePoint := range points {
keepPoint := true
for _, compactPoint := range compactedPoints {
if !tagSetsMatch(sourcePoint.Tags, compactPoint.Tags) {
continue
}
keepPoint = false
for key, val := range sourcePoint.Fields {
compactPoint.Fields[key] = val
}
}
if keepPoint {
compactedPoints = append(compactedPoints, sourcePoint)
}
}
return compactedPoints
}
// tagSetsMatch returns true if two maps are equivalent.
func tagSetsMatch(a, b map[string]string) bool {
if len(a) != len(b) {
return false
}
for ak, av := range a {
bv, ok := b[ak]
if !ok {
return false
}
if av != bv {
return false
}
}
return true
}
// makeReadRequests creates ReadRequest objects from metrics definitions.
func makeReadRequests(metrics []Metric) []ReadRequest {
var requests []ReadRequest
for _, metric := range metrics {
if len(metric.Paths) == 0 {
requests = append(requests, ReadRequest{
Mbean: metric.Mbean,
Attributes: make([]string, 0),
})
} else {
attributes := make(map[string][]string)
for _, path := range metric.Paths {
segments := strings.Split(path, "/")
attribute := segments[0]
if _, ok := attributes[attribute]; !ok {
attributes[attribute] = make([]string, 0)
}
if len(segments) > 1 {
paths := attributes[attribute]
attributes[attribute] = append(paths, strings.Join(segments[1:], "/"))
}
}
rootAttributes := findRequestAttributesWithoutPaths(attributes)
if len(rootAttributes) > 0 {
requests = append(requests, ReadRequest{
Mbean: metric.Mbean,
Attributes: rootAttributes,
})
}
for _, deepAttribute := range findRequestAttributesWithPaths(attributes) {
for _, path := range attributes[deepAttribute] {
requests = append(requests, ReadRequest{
Mbean: metric.Mbean,
Attributes: []string{deepAttribute},
Path: path,
})
}
}
}
}
return requests
}
func findRequestAttributesWithoutPaths(attributes map[string][]string) []string {
results := make([]string, 0)
for attr, paths := range attributes {
if len(paths) == 0 {
results = append(results, attr)
}
}
sort.Strings(results)
return results
}
func findRequestAttributesWithPaths(attributes map[string][]string) []string {
results := make([]string, 0)
for attr, paths := range attributes {
if len(paths) != 0 {
results = append(results, attr)
}
}
sort.Strings(results)
return results
}

View file

@ -0,0 +1,104 @@
package jolokia2
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestJolokia2_makeReadRequests(t *testing.T) {
cases := []struct {
metric Metric
expected []ReadRequest
}{
{
metric: Metric{
Name: "object",
Mbean: "test:foo=bar",
},
expected: []ReadRequest{
{
Mbean: "test:foo=bar",
Attributes: make([]string, 0),
},
},
}, {
metric: Metric{
Name: "object_with_an_attribute",
Mbean: "test:foo=bar",
Paths: []string{"biz"},
},
expected: []ReadRequest{
{
Mbean: "test:foo=bar",
Attributes: []string{"biz"},
},
},
}, {
metric: Metric{
Name: "object_with_attributes",
Mbean: "test:foo=bar",
Paths: []string{"baz", "biz"},
},
expected: []ReadRequest{
{
Mbean: "test:foo=bar",
Attributes: []string{"baz", "biz"},
},
},
}, {
metric: Metric{
Name: "object_with_an_attribute_and_path",
Mbean: "test:foo=bar",
Paths: []string{"biz/baz"},
},
expected: []ReadRequest{
{
Mbean: "test:foo=bar",
Attributes: []string{"biz"},
Path: "baz",
},
},
}, {
metric: Metric{
Name: "object_with_an_attribute_and_a_deep_path",
Mbean: "test:foo=bar",
Paths: []string{"biz/baz/fiz/faz"},
},
expected: []ReadRequest{
{
Mbean: "test:foo=bar",
Attributes: []string{"biz"},
Path: "baz/fiz/faz",
},
},
}, {
metric: Metric{
Name: "object_with_attributes_and_paths",
Mbean: "test:foo=bar",
Paths: []string{"baz/biz", "faz/fiz"},
},
expected: []ReadRequest{
{
Mbean: "test:foo=bar",
Attributes: []string{"baz"},
Path: "biz",
},
{
Mbean: "test:foo=bar",
Attributes: []string{"faz"},
Path: "fiz",
},
},
},
}
for _, c := range cases {
payload := makeReadRequests([]Metric{c.metric})
require.Len(t, payload, len(c.expected), "Failing case: "+c.metric.Name)
for _, actual := range payload {
require.Contains(t, c.expected, actual, "Failing case: "+c.metric.Name)
}
}
}

View file

@ -0,0 +1,128 @@
package jolokia2
import "strings"
// A MetricConfig represents a TOML form of
// a Metric with some optional fields.
type MetricConfig struct {
Name string
Mbean string
Paths []string
FieldName *string
FieldPrefix *string
FieldSeparator *string
TagPrefix *string
TagKeys []string
}
// A Metric represents a specification for a
// Jolokia read request, and the transformations
// to apply to points generated from the responses.
type Metric struct {
Name string
Mbean string
Paths []string
FieldName string
FieldPrefix string
FieldSeparator string
TagPrefix string
TagKeys []string
mbeanDomain string
mbeanProperties []string
}
func NewMetric(config MetricConfig, defaultFieldPrefix, defaultFieldSeparator, defaultTagPrefix string) Metric {
metric := Metric{
Name: config.Name,
Mbean: config.Mbean,
Paths: config.Paths,
TagKeys: config.TagKeys,
}
if config.FieldName != nil {
metric.FieldName = *config.FieldName
}
if config.FieldPrefix == nil {
metric.FieldPrefix = defaultFieldPrefix
} else {
metric.FieldPrefix = *config.FieldPrefix
}
if config.FieldSeparator == nil {
metric.FieldSeparator = defaultFieldSeparator
} else {
metric.FieldSeparator = *config.FieldSeparator
}
if config.TagPrefix == nil {
metric.TagPrefix = defaultTagPrefix
} else {
metric.TagPrefix = *config.TagPrefix
}
mbeanDomain, mbeanProperties := parseMbeanObjectName(config.Mbean)
metric.mbeanDomain = mbeanDomain
metric.mbeanProperties = mbeanProperties
return metric
}
func (m Metric) MatchObjectName(name string) bool {
if name == m.Mbean {
return true
}
mbeanDomain, mbeanProperties := parseMbeanObjectName(name)
if mbeanDomain != m.mbeanDomain {
return false
}
if len(mbeanProperties) != len(m.mbeanProperties) {
return false
}
NEXT_PROPERTY:
for _, mbeanProperty := range m.mbeanProperties {
for i := range mbeanProperties {
if mbeanProperties[i] == mbeanProperty {
continue NEXT_PROPERTY
}
}
return false
}
return true
}
func (m Metric) MatchAttributeAndPath(attribute, innerPath string) bool {
path := attribute
if innerPath != "" {
path = path + "/" + innerPath
}
for i := range m.Paths {
if path == m.Paths[i] {
return true
}
}
return false
}
func parseMbeanObjectName(name string) (string, []string) {
index := strings.Index(name, ":")
if index == -1 {
return name, nil
}
domain := name[:index]
if index+1 > len(name) {
return domain, nil
}
return domain, strings.Split(name[index+1:], ",")
}

View file

@ -0,0 +1,274 @@
package jolokia2
import (
"fmt"
"strings"
)
type point struct {
Tags map[string]string
Fields map[string]interface{}
}
type pointBuilder struct {
metric Metric
objectAttributes []string
objectPath string
substitutions []string
}
func NewPointBuilder(metric Metric, attributes []string, path string) *pointBuilder {
return &pointBuilder{
metric: metric,
objectAttributes: attributes,
objectPath: path,
substitutions: makeSubstitutionList(metric.Mbean),
}
}
// Build generates a point for a given mbean name/pattern and value object.
func (pb *pointBuilder) Build(mbean string, value interface{}) ([]point, error) {
hasPattern := strings.Contains(mbean, "*")
if !hasPattern || value == nil {
value = map[string]interface{}{mbean: value}
}
valueMap, ok := value.(map[string]interface{})
if !ok {
return nil, fmt.Errorf("the response of %s's value should be a map", mbean)
}
points := make([]point, 0)
for mbean, value := range valueMap {
points = append(points, point{
Tags: pb.extractTags(mbean),
Fields: pb.extractFields(mbean, value),
})
}
return compactPoints(points), nil
}
// extractTags generates the map of tags for a given mbean name/pattern.
func (pb *pointBuilder) extractTags(mbean string) map[string]string {
propertyMap := makePropertyMap(mbean)
tagMap := make(map[string]string)
for key, value := range propertyMap {
if pb.includeTag(key) {
tagName := pb.formatTagName(key)
tagMap[tagName] = value
}
}
return tagMap
}
func (pb *pointBuilder) includeTag(tagName string) bool {
for _, t := range pb.metric.TagKeys {
if tagName == t {
return true
}
}
return false
}
func (pb *pointBuilder) formatTagName(tagName string) string {
if tagName == "" {
return ""
}
if tagPrefix := pb.metric.TagPrefix; tagPrefix != "" {
return tagPrefix + tagName
}
return tagName
}
// extractFields generates the map of fields for a given mbean name
// and value object.
func (pb *pointBuilder) extractFields(mbean string, value interface{}) map[string]interface{} {
fieldMap := make(map[string]interface{})
valueMap, ok := value.(map[string]interface{})
if ok {
// complex value
if len(pb.objectAttributes) == 0 {
// if there were no attributes requested,
// then the keys are attributes
pb.FillFields("", valueMap, fieldMap)
} else if len(pb.objectAttributes) == 1 {
// if there was a single attribute requested,
// then the keys are the attribute's properties
fieldName := pb.formatFieldName(pb.objectAttributes[0], pb.objectPath)
pb.FillFields(fieldName, valueMap, fieldMap)
} else {
// if there were multiple attributes requested,
// then the keys are the attribute names
for _, attribute := range pb.objectAttributes {
fieldName := pb.formatFieldName(attribute, pb.objectPath)
pb.FillFields(fieldName, valueMap[attribute], fieldMap)
}
}
} else {
// scalar value
var fieldName string
if len(pb.objectAttributes) == 0 {
fieldName = pb.formatFieldName(defaultFieldName, pb.objectPath)
} else {
fieldName = pb.formatFieldName(pb.objectAttributes[0], pb.objectPath)
}
pb.FillFields(fieldName, value, fieldMap)
}
if len(pb.substitutions) > 1 {
pb.applySubstitutions(mbean, fieldMap)
}
return fieldMap
}
// formatFieldName generates a field name from the supplied attribute and
// path. The return value has the configured FieldPrefix and FieldSuffix
// instructions applied.
func (pb *pointBuilder) formatFieldName(attribute, path string) string {
fieldName := attribute
fieldPrefix := pb.metric.FieldPrefix
fieldSeparator := pb.metric.FieldSeparator
if fieldPrefix != "" {
fieldName = fieldPrefix + fieldName
}
if path != "" {
fieldName = fieldName + fieldSeparator + strings.ReplaceAll(path, "/", fieldSeparator)
}
return fieldName
}
// FillFields recurses into the supplied value object, generating a named field
// for every value it discovers.
func (pb *pointBuilder) FillFields(name string, value interface{}, fieldMap map[string]interface{}) {
if valueMap, ok := value.(map[string]interface{}); ok {
// keep going until we get to something that is not a map
for key, innerValue := range valueMap {
if _, ok := innerValue.([]interface{}); ok {
continue
}
var innerName string
if name == "" {
innerName = pb.metric.FieldPrefix + key
} else {
innerName = name + pb.metric.FieldSeparator + key
}
pb.FillFields(innerName, innerValue, fieldMap)
}
return
}
if _, ok := value.([]interface{}); ok {
return
}
if pb.metric.FieldName != "" {
name = pb.metric.FieldName
if prefix := pb.metric.FieldPrefix; prefix != "" {
name = prefix + name
}
}
if name == "" {
name = defaultFieldName
}
fieldMap[name] = value
}
// applySubstitutions updates all the keys in the supplied map
// of fields to account for $1-style substitution instructions.
func (pb *pointBuilder) applySubstitutions(mbean string, fieldMap map[string]interface{}) {
properties := makePropertyMap(mbean)
for i, subKey := range pb.substitutions[1:] {
symbol := fmt.Sprintf("$%d", i+1)
substitution := properties[subKey]
for fieldName, fieldValue := range fieldMap {
newFieldName := strings.ReplaceAll(fieldName, symbol, substitution)
if fieldName != newFieldName {
fieldMap[newFieldName] = fieldValue
delete(fieldMap, fieldName)
}
}
}
}
// makePropertyMap returns a the mbean property-key list as
// a dictionary. foo:x=y becomes map[string]string { "x": "y" }
func makePropertyMap(mbean string) map[string]string {
props := make(map[string]string)
object := strings.SplitN(mbean, ":", 2)
domain := object[0]
if domain != "" && len(object) == 2 {
list := object[1]
for _, keyProperty := range strings.Split(list, ",") {
pair := strings.SplitN(keyProperty, "=", 2)
if len(pair) != 2 {
continue
}
if key := pair[0]; key != "" {
props[key] = pair[1]
}
}
}
return props
}
// makeSubstitutionList returns an array of values to
// use as substitutions when renaming fields
// with the $1..$N syntax. The first item in the list
// is always the mbean domain.
func makeSubstitutionList(mbean string) []string {
subs := make([]string, 0)
object := strings.SplitN(mbean, ":", 2)
domain := object[0]
if domain != "" && len(object) == 2 {
subs = append(subs, domain)
list := object[1]
for _, keyProperty := range strings.Split(list, ",") {
pair := strings.SplitN(keyProperty, "=", 2)
if len(pair) != 2 {
continue
}
key := pair[0]
if key == "" {
continue
}
property := pair[1]
if !strings.Contains(property, "*") {
continue
}
subs = append(subs, key)
}
}
return subs
}

View file

@ -0,0 +1,158 @@
package kafka
import (
"errors"
"math"
"strings"
"time"
"github.com/IBM/sarama"
"github.com/influxdata/telegraf"
"github.com/influxdata/telegraf/config"
"github.com/influxdata/telegraf/plugins/common/tls"
)
// ReadConfig for kafka clients meaning to read from Kafka.
type ReadConfig struct {
Config
}
// SetConfig on the sarama.Config object from the ReadConfig struct.
func (k *ReadConfig) SetConfig(cfg *sarama.Config, log telegraf.Logger) error {
cfg.Consumer.Return.Errors = true
return k.Config.SetConfig(cfg, log)
}
// WriteConfig for kafka clients meaning to write to kafka
type WriteConfig struct {
Config
RequiredAcks int `toml:"required_acks"`
MaxRetry int `toml:"max_retry"`
MaxMessageBytes int `toml:"max_message_bytes"`
IdempotentWrites bool `toml:"idempotent_writes"`
}
// SetConfig on the sarama.Config object from the WriteConfig struct.
func (k *WriteConfig) SetConfig(cfg *sarama.Config, log telegraf.Logger) error {
cfg.Producer.Return.Successes = true
cfg.Producer.Idempotent = k.IdempotentWrites
cfg.Producer.Retry.Max = k.MaxRetry
if k.MaxMessageBytes > 0 {
cfg.Producer.MaxMessageBytes = k.MaxMessageBytes
}
cfg.Producer.RequiredAcks = sarama.RequiredAcks(k.RequiredAcks)
if cfg.Producer.Idempotent {
cfg.Net.MaxOpenRequests = 1
}
return k.Config.SetConfig(cfg, log)
}
// Config common to all Kafka clients.
type Config struct {
SASLAuth
tls.ClientConfig
Version string `toml:"version"`
ClientID string `toml:"client_id"`
CompressionCodec int `toml:"compression_codec"`
EnableTLS *bool `toml:"enable_tls"`
KeepAlivePeriod *config.Duration `toml:"keep_alive_period"`
MetadataRetryMax int `toml:"metadata_retry_max"`
MetadataRetryType string `toml:"metadata_retry_type"`
MetadataRetryBackoff config.Duration `toml:"metadata_retry_backoff"`
MetadataRetryMaxDuration config.Duration `toml:"metadata_retry_max_duration"`
// Disable full metadata fetching
MetadataFull *bool `toml:"metadata_full"`
}
type BackoffFunc func(retries, maxRetries int) time.Duration
func makeBackoffFunc(backoff, maxDuration time.Duration) BackoffFunc {
return func(retries, _ int) time.Duration {
d := time.Duration(math.Pow(2, float64(retries))) * backoff
if maxDuration != 0 && d > maxDuration {
return maxDuration
}
return d
}
}
// SetConfig on the sarama.Config object from the Config struct.
func (k *Config) SetConfig(cfg *sarama.Config, log telegraf.Logger) error {
if k.Version != "" {
version, err := sarama.ParseKafkaVersion(k.Version)
if err != nil {
return err
}
cfg.Version = version
}
if k.ClientID != "" {
cfg.ClientID = k.ClientID
} else {
cfg.ClientID = "Telegraf"
}
cfg.Producer.Compression = sarama.CompressionCodec(k.CompressionCodec)
if k.EnableTLS != nil && *k.EnableTLS {
cfg.Net.TLS.Enable = true
}
tlsConfig, err := k.ClientConfig.TLSConfig()
if err != nil {
return err
}
if tlsConfig != nil {
cfg.Net.TLS.Config = tlsConfig
// To maintain backwards compatibility, if the enable_tls option is not
// set TLS is enabled if a non-default TLS config is used.
if k.EnableTLS == nil {
cfg.Net.TLS.Enable = true
}
}
if k.KeepAlivePeriod != nil {
// Defaults to OS setting (15s currently)
cfg.Net.KeepAlive = time.Duration(*k.KeepAlivePeriod)
}
if k.MetadataFull != nil {
// Defaults to true in Sarama
cfg.Metadata.Full = *k.MetadataFull
}
if k.MetadataRetryMax != 0 {
cfg.Metadata.Retry.Max = k.MetadataRetryMax
}
if k.MetadataRetryBackoff != 0 {
// If cfg.Metadata.Retry.BackoffFunc is set, sarama ignores
// cfg.Metadata.Retry.Backoff
cfg.Metadata.Retry.Backoff = time.Duration(k.MetadataRetryBackoff)
}
switch strings.ToLower(k.MetadataRetryType) {
default:
return errors.New("invalid metadata retry type")
case "exponential":
if k.MetadataRetryBackoff == 0 {
k.MetadataRetryBackoff = config.Duration(250 * time.Millisecond)
log.Warnf("metadata_retry_backoff is 0, using %s", time.Duration(k.MetadataRetryBackoff))
}
cfg.Metadata.Retry.BackoffFunc = makeBackoffFunc(
time.Duration(k.MetadataRetryBackoff),
time.Duration(k.MetadataRetryMaxDuration),
)
case "constant", "":
}
return k.SetSASLConfig(cfg)
}

View file

@ -0,0 +1,22 @@
package kafka
import (
"testing"
"time"
"github.com/stretchr/testify/require"
)
func TestBackoffFunc(t *testing.T) {
b := 250 * time.Millisecond
limit := 1100 * time.Millisecond
f := makeBackoffFunc(b, limit)
require.Equal(t, b, f(0, 0))
require.Equal(t, b*2, f(1, 0))
require.Equal(t, b*4, f(2, 0))
require.Equal(t, limit, f(3, 0)) // would be 2000 but that's greater than max
f = makeBackoffFunc(b, 0) // max = 0 means no max
require.Equal(t, b*8, f(3, 0)) // with no max, it's 2000
}

View file

@ -0,0 +1,41 @@
package kafka
import (
"sync"
"github.com/IBM/sarama"
"github.com/influxdata/telegraf"
"github.com/influxdata/telegraf/logger"
)
var (
log = logger.New("sarama", "", "")
once sync.Once
)
type debugLogger struct{}
func (*debugLogger) Print(v ...interface{}) {
log.Trace(v...)
}
func (*debugLogger) Printf(format string, v ...interface{}) {
log.Tracef(format, v...)
}
func (l *debugLogger) Println(v ...interface{}) {
l.Print(v...)
}
// SetLogger configures a debug logger for kafka (sarama)
func SetLogger(level telegraf.LogLevel) {
// Set-up the sarama logger only once
once.Do(func() {
sarama.Logger = &debugLogger{}
})
// Increase the log-level if needed.
if !log.Level().Includes(level) {
log.SetLevel(level)
}
}

View file

@ -0,0 +1,127 @@
package kafka
import (
"errors"
"fmt"
"github.com/IBM/sarama"
"github.com/influxdata/telegraf/config"
)
type SASLAuth struct {
SASLUsername config.Secret `toml:"sasl_username"`
SASLPassword config.Secret `toml:"sasl_password"`
SASLExtensions map[string]string `toml:"sasl_extensions"`
SASLMechanism string `toml:"sasl_mechanism"`
SASLVersion *int `toml:"sasl_version"`
// GSSAPI config
SASLGSSAPIServiceName string `toml:"sasl_gssapi_service_name"`
SASLGSSAPIAuthType string `toml:"sasl_gssapi_auth_type"`
SASLGSSAPIDisablePAFXFAST bool `toml:"sasl_gssapi_disable_pafxfast"`
SASLGSSAPIKerberosConfigPath string `toml:"sasl_gssapi_kerberos_config_path"`
SASLGSSAPIKeyTabPath string `toml:"sasl_gssapi_key_tab_path"`
SASLGSSAPIRealm string `toml:"sasl_gssapi_realm"`
// OAUTHBEARER config
SASLAccessToken config.Secret `toml:"sasl_access_token"`
}
// SetSASLConfig configures SASL for kafka (sarama)
func (k *SASLAuth) SetSASLConfig(cfg *sarama.Config) error {
username, err := k.SASLUsername.Get()
if err != nil {
return fmt.Errorf("getting username failed: %w", err)
}
cfg.Net.SASL.User = username.String()
defer username.Destroy()
password, err := k.SASLPassword.Get()
if err != nil {
return fmt.Errorf("getting password failed: %w", err)
}
cfg.Net.SASL.Password = password.String()
defer password.Destroy()
if k.SASLMechanism != "" {
cfg.Net.SASL.Mechanism = sarama.SASLMechanism(k.SASLMechanism)
switch cfg.Net.SASL.Mechanism {
case sarama.SASLTypeSCRAMSHA256:
cfg.Net.SASL.SCRAMClientGeneratorFunc = func() sarama.SCRAMClient {
return &XDGSCRAMClient{HashGeneratorFcn: SHA256}
}
case sarama.SASLTypeSCRAMSHA512:
cfg.Net.SASL.SCRAMClientGeneratorFunc = func() sarama.SCRAMClient {
return &XDGSCRAMClient{HashGeneratorFcn: SHA512}
}
case sarama.SASLTypeOAuth:
cfg.Net.SASL.TokenProvider = k // use self as token provider.
case sarama.SASLTypeGSSAPI:
cfg.Net.SASL.GSSAPI.ServiceName = k.SASLGSSAPIServiceName
cfg.Net.SASL.GSSAPI.AuthType = gssapiAuthType(k.SASLGSSAPIAuthType)
cfg.Net.SASL.GSSAPI.Username = username.String()
cfg.Net.SASL.GSSAPI.Password = password.String()
cfg.Net.SASL.GSSAPI.DisablePAFXFAST = k.SASLGSSAPIDisablePAFXFAST
cfg.Net.SASL.GSSAPI.KerberosConfigPath = k.SASLGSSAPIKerberosConfigPath
cfg.Net.SASL.GSSAPI.KeyTabPath = k.SASLGSSAPIKeyTabPath
cfg.Net.SASL.GSSAPI.Realm = k.SASLGSSAPIRealm
case sarama.SASLTypePlaintext:
// nothing.
default:
}
}
if !k.SASLUsername.Empty() || k.SASLMechanism != "" {
cfg.Net.SASL.Enable = true
version, err := SASLVersion(cfg.Version, k.SASLVersion)
if err != nil {
return err
}
cfg.Net.SASL.Version = version
}
return nil
}
// Token does nothing smart, it just grabs a hard-coded token from config.
func (k *SASLAuth) Token() (*sarama.AccessToken, error) {
token, err := k.SASLAccessToken.Get()
if err != nil {
return nil, fmt.Errorf("getting token failed: %w", err)
}
defer token.Destroy()
return &sarama.AccessToken{
Token: token.String(),
Extensions: k.SASLExtensions,
}, nil
}
func SASLVersion(kafkaVersion sarama.KafkaVersion, saslVersion *int) (int16, error) {
if saslVersion == nil {
if kafkaVersion.IsAtLeast(sarama.V1_0_0_0) {
return sarama.SASLHandshakeV1, nil
}
return sarama.SASLHandshakeV0, nil
}
switch *saslVersion {
case 0:
return sarama.SASLHandshakeV0, nil
case 1:
return sarama.SASLHandshakeV1, nil
default:
return 0, errors.New("invalid SASL version")
}
}
func gssapiAuthType(authType string) int {
switch authType {
case "KRB5_USER_AUTH":
return sarama.KRB5_USER_AUTH
case "KRB5_KEYTAB_AUTH":
return sarama.KRB5_KEYTAB_AUTH
default:
return 0
}
}

View file

@ -0,0 +1,35 @@
package kafka
import (
"crypto/sha256"
"crypto/sha512"
"hash"
"github.com/xdg/scram"
)
var SHA256 scram.HashGeneratorFcn = func() hash.Hash { return sha256.New() }
var SHA512 scram.HashGeneratorFcn = func() hash.Hash { return sha512.New() }
type XDGSCRAMClient struct {
*scram.Client
*scram.ClientConversation
scram.HashGeneratorFcn
}
func (x *XDGSCRAMClient) Begin(userName, password, authzID string) (err error) {
x.Client, err = x.HashGeneratorFcn.NewClient(userName, password, authzID)
if err != nil {
return err
}
x.ClientConversation = x.Client.NewConversation()
return nil
}
func (x *XDGSCRAMClient) Step(challenge string) (response string, err error) {
return x.ClientConversation.Step(challenge)
}
func (x *XDGSCRAMClient) Done() bool {
return x.ClientConversation.Done()
}

View file

@ -0,0 +1,35 @@
package logrus
import (
"io"
"log" //nolint:depguard // Allow exceptional but valid use of log here.
"strings"
"sync"
"github.com/sirupsen/logrus"
)
var once sync.Once
type LogHook struct {
}
// InstallHook installs a logging hook into the logrus standard logger, diverting all logs
// through the Telegraf logger at debug level. This is useful for libraries
// that directly log to the logrus system without providing an override method.
func InstallHook() {
once.Do(func() {
logrus.SetOutput(io.Discard)
logrus.AddHook(&LogHook{})
})
}
func (*LogHook) Fire(entry *logrus.Entry) error {
msg := strings.ReplaceAll(entry.Message, "\n", " ")
log.Print("D! [logrus] ", msg)
return nil
}
func (*LogHook) Levels() []logrus.Level {
return logrus.AllLevels
}

View file

@ -0,0 +1,95 @@
package mqtt
import (
"errors"
"fmt"
"net/url"
"strings"
paho "github.com/eclipse/paho.mqtt.golang"
"github.com/influxdata/telegraf/config"
"github.com/influxdata/telegraf/plugins/common/tls"
)
// mqtt v5-specific publish properties.
// See https://docs.oasis-open.org/mqtt/mqtt/v5.0/os/mqtt-v5.0-os.html#_Toc3901109
type PublishProperties struct {
ContentType string `toml:"content_type"`
ResponseTopic string `toml:"response_topic"`
MessageExpiry config.Duration `toml:"message_expiry"`
TopicAlias *uint16 `toml:"topic_alias"`
UserProperties map[string]string `toml:"user_properties"`
}
type MqttConfig struct {
Servers []string `toml:"servers"`
Protocol string `toml:"protocol"`
Username config.Secret `toml:"username"`
Password config.Secret `toml:"password"`
Timeout config.Duration `toml:"timeout"`
ConnectionTimeout config.Duration `toml:"connection_timeout"`
QoS int `toml:"qos"`
ClientID string `toml:"client_id"`
Retain bool `toml:"retain"`
KeepAlive int64 `toml:"keep_alive"`
PersistentSession bool `toml:"persistent_session"`
PublishPropertiesV5 *PublishProperties `toml:"v5"`
ClientTrace bool `toml:"client_trace"`
tls.ClientConfig
AutoReconnect bool `toml:"-"`
OnConnectionLost func(error) `toml:"-"`
}
// Client is a protocol neutral MQTT client for connecting,
// disconnecting, and publishing data to a topic.
// The protocol specific clients must implement this interface
type Client interface {
Connect() (bool, error)
Publish(topic string, data []byte) error
SubscribeMultiple(filters map[string]byte, callback paho.MessageHandler) error
AddRoute(topic string, callback paho.MessageHandler)
Close() error
}
func NewClient(cfg *MqttConfig) (Client, error) {
if len(cfg.Servers) == 0 {
return nil, errors.New("no servers specified")
}
if cfg.PersistentSession && cfg.ClientID == "" {
return nil, errors.New("persistent_session requires client_id")
}
if cfg.QoS > 2 || cfg.QoS < 0 {
return nil, fmt.Errorf("invalid QoS value %d; must be 0, 1 or 2", cfg.QoS)
}
switch cfg.Protocol {
case "", "3.1.1":
return NewMQTTv311Client(cfg)
case "5":
return NewMQTTv5Client(cfg)
}
return nil, fmt.Errorf("unsupported protocol %q: must be \"3.1.1\" or \"5\"", cfg.Protocol)
}
func parseServers(servers []string) ([]*url.URL, error) {
urls := make([]*url.URL, 0, len(servers))
for _, svr := range servers {
// Preserve support for host:port style servers; deprecated in Telegraf 1.4.4
if !strings.Contains(svr, "://") {
urls = append(urls, &url.URL{Scheme: "tcp", Host: svr})
continue
}
u, err := url.Parse(svr)
if err != nil {
return nil, err
}
urls = append(urls, u)
}
return urls, nil
}

View file

@ -0,0 +1,17 @@
package mqtt
import (
"github.com/influxdata/telegraf"
)
type mqttLogger struct {
telegraf.Logger
}
func (l mqttLogger) Printf(fmt string, args ...interface{}) {
l.Logger.Debugf(fmt, args...)
}
func (l mqttLogger) Println(args ...interface{}) {
l.Logger.Debug(args...)
}

View file

@ -0,0 +1,26 @@
package mqtt
import (
"testing"
"github.com/stretchr/testify/require"
)
// Test that default client has random ID
func TestRandomClientID(t *testing.T) {
var err error
cfg := &MqttConfig{
Servers: []string{"tcp://localhost:1883"},
}
client1, err := NewMQTTv311Client(cfg)
require.NoError(t, err)
client2, err := NewMQTTv311Client(cfg)
require.NoError(t, err)
options1 := client1.client.OptionsReader()
options2 := client2.client.OptionsReader()
require.NotEqual(t, options1.ClientID(), options2.ClientID())
}

View file

@ -0,0 +1,139 @@
package mqtt
import (
"fmt"
"time"
mqttv3 "github.com/eclipse/paho.mqtt.golang" // Library that supports v3.1.1
"github.com/influxdata/telegraf/internal"
"github.com/influxdata/telegraf/logger"
)
type mqttv311Client struct {
client mqttv3.Client
timeout time.Duration
qos int
retain bool
}
func NewMQTTv311Client(cfg *MqttConfig) (*mqttv311Client, error) {
opts := mqttv3.NewClientOptions()
opts.KeepAlive = cfg.KeepAlive
opts.WriteTimeout = time.Duration(cfg.Timeout)
if time.Duration(cfg.ConnectionTimeout) >= 1*time.Second {
opts.ConnectTimeout = time.Duration(cfg.ConnectionTimeout)
}
opts.SetCleanSession(!cfg.PersistentSession)
if cfg.OnConnectionLost != nil {
onConnectionLost := func(_ mqttv3.Client, err error) {
cfg.OnConnectionLost(err)
}
opts.SetConnectionLostHandler(onConnectionLost)
}
opts.SetAutoReconnect(cfg.AutoReconnect)
if cfg.ClientID != "" {
opts.SetClientID(cfg.ClientID)
} else {
id, err := internal.RandomString(5)
if err != nil {
return nil, fmt.Errorf("generating random client ID failed: %w", err)
}
opts.SetClientID("Telegraf-Output-" + id)
}
tlsCfg, err := cfg.ClientConfig.TLSConfig()
if err != nil {
return nil, err
}
opts.SetTLSConfig(tlsCfg)
if !cfg.Username.Empty() {
user, err := cfg.Username.Get()
if err != nil {
return nil, fmt.Errorf("getting username failed: %w", err)
}
opts.SetUsername(user.String())
user.Destroy()
}
if !cfg.Password.Empty() {
password, err := cfg.Password.Get()
if err != nil {
return nil, fmt.Errorf("getting password failed: %w", err)
}
opts.SetPassword(password.String())
password.Destroy()
}
servers, err := parseServers(cfg.Servers)
if err != nil {
return nil, err
}
for _, server := range servers {
if tlsCfg != nil {
server.Scheme = "tls"
}
broker := server.String()
opts.AddBroker(broker)
}
if cfg.ClientTrace {
log := &mqttLogger{logger.New("paho", "", "")}
mqttv3.ERROR = log
mqttv3.CRITICAL = log
mqttv3.WARN = log
mqttv3.DEBUG = log
}
return &mqttv311Client{
client: mqttv3.NewClient(opts),
timeout: time.Duration(cfg.Timeout),
qos: cfg.QoS,
retain: cfg.Retain,
}, nil
}
func (m *mqttv311Client) Connect() (bool, error) {
token := m.client.Connect()
if token.Wait() && token.Error() != nil {
return false, token.Error()
}
// Persistent sessions should skip subscription if a session is present, as
// the subscriptions are stored by the server.
type sessionPresent interface {
SessionPresent() bool
}
if t, ok := token.(sessionPresent); ok {
return t.SessionPresent(), nil
}
return false, nil
}
func (m *mqttv311Client) Publish(topic string, body []byte) error {
token := m.client.Publish(topic, byte(m.qos), m.retain, body)
if !token.WaitTimeout(m.timeout) {
return internal.ErrTimeout
}
return token.Error()
}
func (m *mqttv311Client) SubscribeMultiple(filters map[string]byte, callback mqttv3.MessageHandler) error {
token := m.client.SubscribeMultiple(filters, callback)
token.Wait()
return token.Error()
}
func (m *mqttv311Client) AddRoute(topic string, callback mqttv3.MessageHandler) {
m.client.AddRoute(topic, callback)
}
func (m *mqttv311Client) Close() error {
if m.client.IsConnected() {
m.client.Disconnect(100)
}
return nil
}

View file

@ -0,0 +1,165 @@
package mqtt
import (
"context"
"fmt"
"net/url"
"time"
mqttv5auto "github.com/eclipse/paho.golang/autopaho"
mqttv5 "github.com/eclipse/paho.golang/paho"
paho "github.com/eclipse/paho.mqtt.golang"
"github.com/influxdata/telegraf/config"
"github.com/influxdata/telegraf/internal"
"github.com/influxdata/telegraf/logger"
)
type mqttv5Client struct {
client *mqttv5auto.ConnectionManager
options mqttv5auto.ClientConfig
username config.Secret
password config.Secret
timeout time.Duration
qos int
retain bool
clientTrace bool
properties *mqttv5.PublishProperties
}
func NewMQTTv5Client(cfg *MqttConfig) (*mqttv5Client, error) {
opts := mqttv5auto.ClientConfig{
KeepAlive: uint16(cfg.KeepAlive),
OnConnectError: cfg.OnConnectionLost,
}
opts.ConnectPacketBuilder = func(c *mqttv5.Connect, _ *url.URL) (*mqttv5.Connect, error) {
c.CleanStart = cfg.PersistentSession
return c, nil
}
if time.Duration(cfg.ConnectionTimeout) >= 1*time.Second {
opts.ConnectTimeout = time.Duration(cfg.ConnectionTimeout)
}
if cfg.ClientID != "" {
opts.ClientID = cfg.ClientID
} else {
id, err := internal.RandomString(5)
if err != nil {
return nil, fmt.Errorf("generating random client ID failed: %w", err)
}
opts.ClientID = "Telegraf-Output-" + id
}
tlsCfg, err := cfg.ClientConfig.TLSConfig()
if err != nil {
return nil, err
}
if tlsCfg != nil {
opts.TlsCfg = tlsCfg
}
brokers := make([]*url.URL, 0)
servers, err := parseServers(cfg.Servers)
if err != nil {
return nil, err
}
for _, server := range servers {
if tlsCfg != nil {
server.Scheme = "tls"
}
brokers = append(brokers, server)
}
opts.BrokerUrls = brokers
// Build the v5 specific publish properties if they are present in the config.
// These should not change during the lifecycle of the client.
var properties *mqttv5.PublishProperties
if cfg.PublishPropertiesV5 != nil {
properties = &mqttv5.PublishProperties{
ContentType: cfg.PublishPropertiesV5.ContentType,
ResponseTopic: cfg.PublishPropertiesV5.ResponseTopic,
TopicAlias: cfg.PublishPropertiesV5.TopicAlias,
}
messageExpiry := time.Duration(cfg.PublishPropertiesV5.MessageExpiry)
if expirySeconds := uint32(messageExpiry.Seconds()); expirySeconds > 0 {
properties.MessageExpiry = &expirySeconds
}
properties.User = make([]mqttv5.UserProperty, 0, len(cfg.PublishPropertiesV5.UserProperties))
for k, v := range cfg.PublishPropertiesV5.UserProperties {
properties.User.Add(k, v)
}
}
return &mqttv5Client{
options: opts,
timeout: time.Duration(cfg.Timeout),
username: cfg.Username,
password: cfg.Password,
qos: cfg.QoS,
retain: cfg.Retain,
properties: properties,
clientTrace: cfg.ClientTrace,
}, nil
}
func (m *mqttv5Client) Connect() (bool, error) {
user, err := m.username.Get()
if err != nil {
return false, fmt.Errorf("getting username failed: %w", err)
}
defer user.Destroy()
pass, err := m.password.Get()
if err != nil {
return false, fmt.Errorf("getting password failed: %w", err)
}
defer pass.Destroy()
m.options.ConnectUsername = user.String()
m.options.ConnectPassword = []byte(pass.String())
if m.clientTrace {
log := mqttLogger{logger.New("paho", "", "")}
m.options.Debug = log
m.options.Errors = log
}
client, err := mqttv5auto.NewConnection(context.Background(), m.options)
if err != nil {
return false, err
}
m.client = client
return false, client.AwaitConnection(context.Background())
}
func (m *mqttv5Client) Publish(topic string, body []byte) error {
ctx, cancel := context.WithTimeout(context.Background(), m.timeout)
defer cancel()
_, err := m.client.Publish(ctx, &mqttv5.Publish{
Topic: topic,
QoS: byte(m.qos),
Retain: m.retain,
Payload: body,
Properties: m.properties,
})
return err
}
func (*mqttv5Client) SubscribeMultiple(filters map[string]byte, callback paho.MessageHandler) error {
_, _ = filters, callback
panic("not implemented")
}
func (*mqttv5Client) AddRoute(topic string, callback paho.MessageHandler) {
_, _ = topic, callback
panic("not implemented")
}
func (m *mqttv5Client) Close() error {
return m.client.Disconnect(context.Background())
}

View file

@ -0,0 +1,42 @@
package oauth
import (
"context"
"net/http"
"net/url"
"golang.org/x/oauth2"
"golang.org/x/oauth2/clientcredentials"
)
type OAuth2Config struct {
// OAuth2 Credentials
ClientID string `toml:"client_id"`
ClientSecret string `toml:"client_secret"`
TokenURL string `toml:"token_url"`
Audience string `toml:"audience"`
Scopes []string `toml:"scopes"`
}
func (o *OAuth2Config) CreateOauth2Client(ctx context.Context, client *http.Client) *http.Client {
if o.ClientID == "" || o.ClientSecret == "" || o.TokenURL == "" {
return client
}
oauthConfig := clientcredentials.Config{
ClientID: o.ClientID,
ClientSecret: o.ClientSecret,
TokenURL: o.TokenURL,
Scopes: o.Scopes,
EndpointParams: make(url.Values),
}
if o.Audience != "" {
oauthConfig.EndpointParams.Add("audience", o.Audience)
}
ctx = context.WithValue(ctx, oauth2.HTTPClient, client)
client = oauthConfig.Client(ctx)
return client
}

View file

@ -0,0 +1,241 @@
package opcua
import (
"context"
"errors"
"fmt"
"log" //nolint:depguard // just for debug
"net/url"
"strconv"
"time"
"github.com/gopcua/opcua"
"github.com/gopcua/opcua/debug"
"github.com/gopcua/opcua/ua"
"github.com/influxdata/telegraf"
"github.com/influxdata/telegraf/config"
"github.com/influxdata/telegraf/internal/choice"
)
type OpcUAWorkarounds struct {
AdditionalValidStatusCodes []string `toml:"additional_valid_status_codes"`
}
type ConnectionState opcua.ConnState
const (
Closed ConnectionState = ConnectionState(opcua.Closed)
Connected ConnectionState = ConnectionState(opcua.Connected)
Connecting ConnectionState = ConnectionState(opcua.Connecting)
Disconnected ConnectionState = ConnectionState(opcua.Disconnected)
Reconnecting ConnectionState = ConnectionState(opcua.Reconnecting)
)
func (c ConnectionState) String() string {
return opcua.ConnState(c).String()
}
type OpcUAClientConfig struct {
Endpoint string `toml:"endpoint"`
SecurityPolicy string `toml:"security_policy"`
SecurityMode string `toml:"security_mode"`
Certificate string `toml:"certificate"`
PrivateKey string `toml:"private_key"`
Username config.Secret `toml:"username"`
Password config.Secret `toml:"password"`
AuthMethod string `toml:"auth_method"`
ConnectTimeout config.Duration `toml:"connect_timeout"`
RequestTimeout config.Duration `toml:"request_timeout"`
ClientTrace bool `toml:"client_trace"`
OptionalFields []string `toml:"optional_fields"`
Workarounds OpcUAWorkarounds `toml:"workarounds"`
SessionTimeout config.Duration `toml:"session_timeout"`
}
func (o *OpcUAClientConfig) Validate() error {
if err := o.validateOptionalFields(); err != nil {
return fmt.Errorf("invalid 'optional_fields': %w", err)
}
return o.validateEndpoint()
}
func (o *OpcUAClientConfig) validateOptionalFields() error {
validFields := []string{"DataType"}
return choice.CheckSlice(o.OptionalFields, validFields)
}
func (o *OpcUAClientConfig) validateEndpoint() error {
if o.Endpoint == "" {
return errors.New("endpoint url is empty")
}
_, err := url.Parse(o.Endpoint)
if err != nil {
return errors.New("endpoint url is invalid")
}
switch o.SecurityPolicy {
case "None", "Basic128Rsa15", "Basic256", "Basic256Sha256", "auto":
default:
return fmt.Errorf("invalid security type %q in %q", o.SecurityPolicy, o.Endpoint)
}
switch o.SecurityMode {
case "None", "Sign", "SignAndEncrypt", "auto":
default:
return fmt.Errorf("invalid security type %q in %q", o.SecurityMode, o.Endpoint)
}
return nil
}
func (o *OpcUAClientConfig) CreateClient(telegrafLogger telegraf.Logger) (*OpcUAClient, error) {
err := o.Validate()
if err != nil {
return nil, err
}
if o.ClientTrace {
debug.Enable = true
debug.Logger = log.New(&DebugLogger{Log: telegrafLogger}, "", 0)
}
c := &OpcUAClient{
Config: o,
Log: telegrafLogger,
}
c.Log.Debug("Initialising OpcUAClient")
err = c.setupWorkarounds()
return c, err
}
type OpcUAClient struct {
Config *OpcUAClientConfig
Log telegraf.Logger
Client *opcua.Client
opts []opcua.Option
codes []ua.StatusCode
}
// / setupOptions read the endpoints from the specified server and setup all authentication
func (o *OpcUAClient) SetupOptions() error {
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(o.Config.ConnectTimeout))
defer cancel()
// Get a list of the endpoints for our target server
endpoints, err := opcua.GetEndpoints(ctx, o.Config.Endpoint)
if err != nil {
return err
}
if o.Config.Certificate == "" && o.Config.PrivateKey == "" {
if o.Config.SecurityPolicy != "None" || o.Config.SecurityMode != "None" {
o.Log.Debug("Generating self-signed certificate")
cert, privateKey, err := generateCert("urn:telegraf:gopcua:client", 2048,
o.Config.Certificate, o.Config.PrivateKey, 365*24*time.Hour)
if err != nil {
return err
}
o.Config.Certificate = cert
o.Config.PrivateKey = privateKey
}
}
o.Log.Debug("Configuring OPC UA connection options")
o.opts, err = o.generateClientOpts(endpoints)
return err
}
func (o *OpcUAClient) setupWorkarounds() error {
o.codes = []ua.StatusCode{ua.StatusOK}
for _, c := range o.Config.Workarounds.AdditionalValidStatusCodes {
val, err := strconv.ParseUint(c, 0, 32) // setting 32 bits to allow for safe conversion
if err != nil {
return err
}
o.codes = append(o.codes, ua.StatusCode(val))
}
return nil
}
func (o *OpcUAClient) StatusCodeOK(code ua.StatusCode) bool {
for _, val := range o.codes {
if val == code {
return true
}
}
return false
}
// Connect to an OPC UA device
func (o *OpcUAClient) Connect(ctx context.Context) error {
o.Log.Debug("Connecting OPC UA Client to server")
u, err := url.Parse(o.Config.Endpoint)
if err != nil {
return err
}
switch u.Scheme {
case "opc.tcp":
if err := o.SetupOptions(); err != nil {
return err
}
if o.Client != nil {
o.Log.Warnf("Closing connection to %q as already connected", u)
if err := o.Client.Close(ctx); err != nil {
// Only log the error but to not bail-out here as this prevents
// reconnections for multiple parties (see e.g. #9523).
o.Log.Errorf("Closing connection failed: %v", err)
}
}
o.Client, err = opcua.NewClient(o.Config.Endpoint, o.opts...)
if err != nil {
return fmt.Errorf("error in new client: %w", err)
}
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(o.Config.ConnectTimeout))
defer cancel()
if err := o.Client.Connect(ctx); err != nil {
return fmt.Errorf("error in Client Connection: %w", err)
}
o.Log.Debug("Connected to OPC UA Server")
default:
return fmt.Errorf("unsupported scheme %q in endpoint. Expected opc.tcp", u.Scheme)
}
return nil
}
func (o *OpcUAClient) Disconnect(ctx context.Context) error {
o.Log.Debug("Disconnecting from OPC UA Server")
u, err := url.Parse(o.Config.Endpoint)
if err != nil {
return err
}
switch u.Scheme {
case "opc.tcp":
// We can't do anything about failing to close a connection
err := o.Client.Close(ctx)
o.Client = nil
return err
default:
return errors.New("invalid controller")
}
}
func (o *OpcUAClient) State() ConnectionState {
if o.Client == nil {
return Disconnected
}
return ConnectionState(o.Client.State())
}

View file

@ -0,0 +1,33 @@
package opcua
import (
"testing"
"github.com/gopcua/opcua/ua"
"github.com/stretchr/testify/require"
)
func TestSetupWorkarounds(t *testing.T) {
o := OpcUAClient{
Config: &OpcUAClientConfig{
Workarounds: OpcUAWorkarounds{
AdditionalValidStatusCodes: []string{"0xC0", "0x00AA0000", "0x80000000"},
},
},
}
err := o.setupWorkarounds()
require.NoError(t, err)
require.Len(t, o.codes, 4)
require.Equal(t, o.codes[0], ua.StatusCode(0))
require.Equal(t, o.codes[1], ua.StatusCode(192))
require.Equal(t, o.codes[2], ua.StatusCode(11141120))
require.Equal(t, o.codes[3], ua.StatusCode(2147483648))
}
func TestCheckStatusCode(t *testing.T) {
var o OpcUAClient
o.codes = []ua.StatusCode{ua.StatusCode(0), ua.StatusCode(192), ua.StatusCode(11141120)}
require.True(t, o.StatusCodeOK(ua.StatusCode(192)))
}

View file

@ -0,0 +1,500 @@
package input
import (
"context"
"errors"
"fmt"
"sort"
"strconv"
"strings"
"time"
"github.com/gopcua/opcua/ua"
"github.com/influxdata/telegraf"
"github.com/influxdata/telegraf/config"
"github.com/influxdata/telegraf/internal/choice"
"github.com/influxdata/telegraf/metric"
"github.com/influxdata/telegraf/plugins/common/opcua"
)
type Trigger string
const (
Status Trigger = "Status"
StatusValue Trigger = "StatusValue"
StatusValueTimestamp Trigger = "StatusValueTimestamp"
)
type DeadbandType string
const (
Absolute DeadbandType = "Absolute"
Percent DeadbandType = "Percent"
)
type DataChangeFilter struct {
Trigger Trigger `toml:"trigger"`
DeadbandType DeadbandType `toml:"deadband_type"`
DeadbandValue *float64 `toml:"deadband_value"`
}
type MonitoringParameters struct {
SamplingInterval config.Duration `toml:"sampling_interval"`
QueueSize *uint32 `toml:"queue_size"`
DiscardOldest *bool `toml:"discard_oldest"`
DataChangeFilter *DataChangeFilter `toml:"data_change_filter"`
}
// NodeSettings describes how to map from a OPC UA node to a Metric
type NodeSettings struct {
FieldName string `toml:"name"`
Namespace string `toml:"namespace"`
IdentifierType string `toml:"identifier_type"`
Identifier string `toml:"identifier"`
DataType string `toml:"data_type" deprecated:"1.17.0;1.35.0;option is ignored"`
Description string `toml:"description" deprecated:"1.17.0;1.35.0;option is ignored"`
TagsSlice [][]string `toml:"tags" deprecated:"1.25.0;1.35.0;use 'default_tags' instead"`
DefaultTags map[string]string `toml:"default_tags"`
MonitoringParams MonitoringParameters `toml:"monitoring_params"`
}
// NodeID returns the OPC UA node id
func (tag *NodeSettings) NodeID() string {
return "ns=" + tag.Namespace + ";" + tag.IdentifierType + "=" + tag.Identifier
}
// NodeGroupSettings describes a mapping of group of nodes to Metrics
type NodeGroupSettings struct {
MetricName string `toml:"name"` // Overrides plugin's setting
Namespace string `toml:"namespace"` // Can be overridden by node setting
IdentifierType string `toml:"identifier_type"` // Can be overridden by node setting
Nodes []NodeSettings `toml:"nodes"`
TagsSlice [][]string `toml:"tags" deprecated:"1.26.0;1.35.0;use default_tags"`
DefaultTags map[string]string `toml:"default_tags"`
SamplingInterval config.Duration `toml:"sampling_interval"` // Can be overridden by monitoring parameters
}
type TimestampSource string
const (
TimestampSourceServer TimestampSource = "server"
TimestampSourceSource TimestampSource = "source"
TimestampSourceTelegraf TimestampSource = "gather"
)
// InputClientConfig a configuration for the input client
type InputClientConfig struct {
opcua.OpcUAClientConfig
MetricName string `toml:"name"`
Timestamp TimestampSource `toml:"timestamp"`
TimestampFormat string `toml:"timestamp_format"`
RootNodes []NodeSettings `toml:"nodes"`
Groups []NodeGroupSettings `toml:"group"`
}
func (o *InputClientConfig) Validate() error {
if o.MetricName == "" {
return errors.New("metric name is empty")
}
err := choice.Check(string(o.Timestamp), []string{"", "gather", "server", "source"})
if err != nil {
return err
}
if o.TimestampFormat == "" {
o.TimestampFormat = time.RFC3339Nano
}
if len(o.Groups) == 0 && len(o.RootNodes) == 0 {
return errors.New("no groups or root nodes provided to gather from")
}
for _, group := range o.Groups {
if len(group.Nodes) == 0 {
return errors.New("group has no nodes to collect from")
}
}
return nil
}
func (o *InputClientConfig) CreateInputClient(log telegraf.Logger) (*OpcUAInputClient, error) {
if err := o.Validate(); err != nil {
return nil, err
}
log.Debug("Initialising OpcUAInputClient")
opcClient, err := o.OpcUAClientConfig.CreateClient(log)
if err != nil {
return nil, err
}
c := &OpcUAInputClient{
OpcUAClient: opcClient,
Log: log,
Config: *o,
}
log.Debug("Initialising node to metric mapping")
if err := c.InitNodeMetricMapping(); err != nil {
return nil, err
}
c.initLastReceivedValues()
return c, nil
}
// NodeMetricMapping mapping from a single node to a metric
type NodeMetricMapping struct {
Tag NodeSettings
idStr string
metricName string
MetricTags map[string]string
}
// NewNodeMetricMapping builds a new NodeMetricMapping from the given argument
func NewNodeMetricMapping(metricName string, node NodeSettings, groupTags map[string]string) (*NodeMetricMapping, error) {
mergedTags := make(map[string]string)
for n, t := range groupTags {
mergedTags[n] = t
}
nodeTags := make(map[string]string)
if len(node.DefaultTags) > 0 {
nodeTags = node.DefaultTags
} else if len(node.TagsSlice) > 0 {
// fixme: once the TagsSlice has been removed (after deprecation), remove this if else logic
var err error
nodeTags, err = tagsSliceToMap(node.TagsSlice)
if err != nil {
return nil, err
}
}
for n, t := range nodeTags {
mergedTags[n] = t
}
return &NodeMetricMapping{
Tag: node,
idStr: node.NodeID(),
metricName: metricName,
MetricTags: mergedTags,
}, nil
}
// NodeValue The received value for a node
type NodeValue struct {
TagName string
Value interface{}
Quality ua.StatusCode
ServerTime time.Time
SourceTime time.Time
DataType ua.TypeID
IsArray bool
}
// OpcUAInputClient can receive data from an OPC UA server and map it to Metrics. This type does not contain
// logic for actually retrieving data from the server, but is used by other types like ReadClient and
// OpcUAInputSubscribeClient to store data needed to convert node ids to the corresponding metrics.
type OpcUAInputClient struct {
*opcua.OpcUAClient
Config InputClientConfig
Log telegraf.Logger
NodeMetricMapping []NodeMetricMapping
NodeIDs []*ua.NodeID
LastReceivedData []NodeValue
}
// Stop the connection to the client
func (o *OpcUAInputClient) Stop(ctx context.Context) <-chan struct{} {
ch := make(chan struct{})
defer close(ch)
err := o.Disconnect(ctx)
if err != nil {
o.Log.Warn("Disconnecting from server failed with error ", err)
}
return ch
}
// metricParts is only used to ensure no duplicate metrics are created
type metricParts struct {
metricName string
fieldName string
tags string // sorted by tag name and in format tag1=value1, tag2=value2
}
func newMP(n *NodeMetricMapping) metricParts {
keys := make([]string, 0, len(n.MetricTags))
for key := range n.MetricTags {
keys = append(keys, key)
}
sort.Strings(keys)
var sb strings.Builder
for i, key := range keys {
if i != 0 {
sb.WriteString(", ")
}
sb.WriteString(key)
sb.WriteString("=")
sb.WriteString(n.MetricTags[key])
}
x := metricParts{
metricName: n.metricName,
fieldName: n.Tag.FieldName,
tags: sb.String(),
}
return x
}
// fixme: once the TagsSlice has been removed (after deprecation), remove this
// tagsSliceToMap takes an array of pairs of strings and creates a map from it
func tagsSliceToMap(tags [][]string) (map[string]string, error) {
m := make(map[string]string)
for i, tag := range tags {
if len(tag) != 2 {
return nil, fmt.Errorf("tag %d needs 2 values, has %d: %v", i+1, len(tag), tag)
}
if tag[0] == "" {
return nil, fmt.Errorf("tag %d has empty name", i+1)
}
if tag[1] == "" {
return nil, fmt.Errorf("tag %d has empty value", i+1)
}
if _, ok := m[tag[0]]; ok {
return nil, fmt.Errorf("tag %d has duplicate key: %v", i+1, tag[0])
}
m[tag[0]] = tag[1]
}
return m, nil
}
func validateNodeToAdd(existing map[metricParts]struct{}, nmm *NodeMetricMapping) error {
if nmm.Tag.FieldName == "" {
return fmt.Errorf("empty name in %q", nmm.Tag.FieldName)
}
if len(nmm.Tag.Namespace) == 0 {
return errors.New("empty node namespace not allowed")
}
if len(nmm.Tag.Identifier) == 0 {
return errors.New("empty node identifier not allowed")
}
mp := newMP(nmm)
if _, exists := existing[mp]; exists {
return fmt.Errorf("name %q is duplicated (metric name %q, tags %q)",
mp.fieldName, mp.metricName, mp.tags)
}
switch nmm.Tag.IdentifierType {
case "i":
if _, err := strconv.Atoi(nmm.Tag.Identifier); err != nil {
return fmt.Errorf("identifier type %q does not match the type of identifier %q", nmm.Tag.IdentifierType, nmm.Tag.Identifier)
}
case "s", "g", "b":
// Valid identifier type - do nothing.
default:
return fmt.Errorf("invalid identifier type %q in %q", nmm.Tag.IdentifierType, nmm.Tag.FieldName)
}
existing[mp] = struct{}{}
return nil
}
// InitNodeMetricMapping builds nodes from the configuration
func (o *OpcUAInputClient) InitNodeMetricMapping() error {
existing := make(map[metricParts]struct{}, len(o.Config.RootNodes))
for _, node := range o.Config.RootNodes {
nmm, err := NewNodeMetricMapping(o.Config.MetricName, node, make(map[string]string))
if err != nil {
return err
}
if err := validateNodeToAdd(existing, nmm); err != nil {
return err
}
o.NodeMetricMapping = append(o.NodeMetricMapping, *nmm)
}
for _, group := range o.Config.Groups {
if group.MetricName == "" {
group.MetricName = o.Config.MetricName
}
if len(group.DefaultTags) > 0 && len(group.TagsSlice) > 0 {
o.Log.Warn("Tags found in both `tags` and `default_tags`, only using tags defined in `default_tags`")
}
groupTags := make(map[string]string)
if len(group.DefaultTags) > 0 {
groupTags = group.DefaultTags
} else if len(group.TagsSlice) > 0 {
// fixme: once the TagsSlice has been removed (after deprecation), remove this if else logic
var err error
groupTags, err = tagsSliceToMap(group.TagsSlice)
if err != nil {
return err
}
}
for _, node := range group.Nodes {
if node.Namespace == "" {
node.Namespace = group.Namespace
}
if node.IdentifierType == "" {
node.IdentifierType = group.IdentifierType
}
if node.MonitoringParams.SamplingInterval == 0 {
node.MonitoringParams.SamplingInterval = group.SamplingInterval
}
nmm, err := NewNodeMetricMapping(group.MetricName, node, groupTags)
if err != nil {
return err
}
if err := validateNodeToAdd(existing, nmm); err != nil {
return err
}
o.NodeMetricMapping = append(o.NodeMetricMapping, *nmm)
}
}
return nil
}
func (o *OpcUAInputClient) InitNodeIDs() error {
o.NodeIDs = make([]*ua.NodeID, 0, len(o.NodeMetricMapping))
for _, node := range o.NodeMetricMapping {
nid, err := ua.ParseNodeID(node.Tag.NodeID())
if err != nil {
return err
}
o.NodeIDs = append(o.NodeIDs, nid)
}
return nil
}
func (o *OpcUAInputClient) initLastReceivedValues() {
o.LastReceivedData = make([]NodeValue, len(o.NodeMetricMapping))
for nodeIdx, nmm := range o.NodeMetricMapping {
o.LastReceivedData[nodeIdx].TagName = nmm.Tag.FieldName
}
}
func (o *OpcUAInputClient) UpdateNodeValue(nodeIdx int, d *ua.DataValue) {
o.LastReceivedData[nodeIdx].Quality = d.Status
if !o.StatusCodeOK(d.Status) {
// Verify NodeIDs array has been built before trying to get item; otherwise show '?' for node id
if len(o.NodeIDs) > nodeIdx {
o.Log.Errorf("status not OK for node %v (%v): %v", o.NodeMetricMapping[nodeIdx].Tag.FieldName, o.NodeIDs[nodeIdx].String(), d.Status)
} else {
o.Log.Errorf("status not OK for node %v (%v): %v", o.NodeMetricMapping[nodeIdx].Tag.FieldName, '?', d.Status)
}
return
}
if d.Value != nil {
o.LastReceivedData[nodeIdx].DataType = d.Value.Type()
o.LastReceivedData[nodeIdx].IsArray = d.Value.Has(ua.VariantArrayValues)
o.LastReceivedData[nodeIdx].Value = d.Value.Value()
if o.LastReceivedData[nodeIdx].DataType == ua.TypeIDDateTime {
if t, ok := d.Value.Value().(time.Time); ok {
o.LastReceivedData[nodeIdx].Value = t.Format(o.Config.TimestampFormat)
}
}
}
o.LastReceivedData[nodeIdx].ServerTime = d.ServerTimestamp
o.LastReceivedData[nodeIdx].SourceTime = d.SourceTimestamp
}
func (o *OpcUAInputClient) MetricForNode(nodeIdx int) telegraf.Metric {
nmm := &o.NodeMetricMapping[nodeIdx]
tags := map[string]string{
"id": nmm.idStr,
}
for k, v := range nmm.MetricTags {
tags[k] = v
}
fields := make(map[string]interface{})
if o.LastReceivedData[nodeIdx].Value != nil {
// Simple scalar types can be stored directly under the field name while
// arrays (see 5.2.5) and structures (see 5.2.6) must be unpacked.
// Note: Structures and arrays of structures are currently not supported.
if o.LastReceivedData[nodeIdx].IsArray {
switch typedValue := o.LastReceivedData[nodeIdx].Value.(type) {
case []uint8:
fields = unpack(nmm.Tag.FieldName, typedValue)
case []uint16:
fields = unpack(nmm.Tag.FieldName, typedValue)
case []uint32:
fields = unpack(nmm.Tag.FieldName, typedValue)
case []uint64:
fields = unpack(nmm.Tag.FieldName, typedValue)
case []int8:
fields = unpack(nmm.Tag.FieldName, typedValue)
case []int16:
fields = unpack(nmm.Tag.FieldName, typedValue)
case []int32:
fields = unpack(nmm.Tag.FieldName, typedValue)
case []int64:
fields = unpack(nmm.Tag.FieldName, typedValue)
case []float32:
fields = unpack(nmm.Tag.FieldName, typedValue)
case []float64:
fields = unpack(nmm.Tag.FieldName, typedValue)
case []string:
fields = unpack(nmm.Tag.FieldName, typedValue)
case []bool:
fields = unpack(nmm.Tag.FieldName, typedValue)
default:
o.Log.Errorf("could not unpack variant array of type: %T", typedValue)
}
} else {
fields = map[string]interface{}{
nmm.Tag.FieldName: o.LastReceivedData[nodeIdx].Value,
}
}
}
fields["Quality"] = strings.TrimSpace(o.LastReceivedData[nodeIdx].Quality.Error())
if choice.Contains("DataType", o.Config.OptionalFields) {
fields["DataType"] = strings.Replace(o.LastReceivedData[nodeIdx].DataType.String(), "TypeID", "", 1)
}
if !o.StatusCodeOK(o.LastReceivedData[nodeIdx].Quality) {
mp := newMP(nmm)
o.Log.Debugf("status not OK for node %q(metric name %q, tags %q)",
mp.fieldName, mp.metricName, mp.tags)
}
var t time.Time
switch o.Config.Timestamp {
case TimestampSourceServer:
t = o.LastReceivedData[nodeIdx].ServerTime
case TimestampSourceSource:
t = o.LastReceivedData[nodeIdx].SourceTime
default:
t = time.Now()
}
return metric.New(nmm.metricName, tags, fields, t)
}
func unpack[Slice ~[]E, E any](prefix string, value Slice) map[string]interface{} {
fields := make(map[string]interface{}, len(value))
for i, v := range value {
key := fmt.Sprintf("%s[%d]", prefix, i)
fields[key] = v
}
return fields
}

View file

@ -0,0 +1,890 @@
package input
import (
"errors"
"testing"
"time"
"github.com/gopcua/opcua/ua"
"github.com/stretchr/testify/require"
"github.com/influxdata/telegraf"
"github.com/influxdata/telegraf/config"
"github.com/influxdata/telegraf/metric"
"github.com/influxdata/telegraf/plugins/common/opcua"
"github.com/influxdata/telegraf/testutil"
)
func TestTagsSliceToMap(t *testing.T) {
m, err := tagsSliceToMap([][]string{{"foo", "bar"}, {"baz", "bat"}})
require.NoError(t, err)
require.Len(t, m, 2)
require.Equal(t, "bar", m["foo"])
require.Equal(t, "bat", m["baz"])
}
func TestTagsSliceToMap_twoStrings(t *testing.T) {
var err error
_, err = tagsSliceToMap([][]string{{"foo", "bar", "baz"}})
require.Error(t, err)
_, err = tagsSliceToMap([][]string{{"foo"}})
require.Error(t, err)
}
func TestTagsSliceToMap_dupeKey(t *testing.T) {
_, err := tagsSliceToMap([][]string{{"foo", "bar"}, {"foo", "bat"}})
require.Error(t, err)
}
func TestTagsSliceToMap_empty(t *testing.T) {
_, err := tagsSliceToMap([][]string{{"foo", ""}})
require.Equal(t, errors.New("tag 1 has empty value"), err)
_, err = tagsSliceToMap([][]string{{"", "bar"}})
require.Equal(t, errors.New("tag 1 has empty name"), err)
}
func TestValidateOPCTags(t *testing.T) {
tests := []struct {
name string
config InputClientConfig
err error
}{
{
"duplicates",
InputClientConfig{
MetricName: "mn",
RootNodes: []NodeSettings{
{
FieldName: "fn",
Namespace: "2",
IdentifierType: "s",
Identifier: "i1",
TagsSlice: [][]string{{"t1", "v1"}, {"t2", "v2"}},
},
},
Groups: []NodeGroupSettings{
{
Nodes: []NodeSettings{
{
FieldName: "fn",
Namespace: "2",
IdentifierType: "s",
Identifier: "i1",
},
},
TagsSlice: [][]string{{"t1", "v1"}, {"t2", "v2"}},
},
},
},
errors.New(`name "fn" is duplicated (metric name "mn", tags "t1=v1, t2=v2")`),
},
{
"empty tag value not allowed",
InputClientConfig{
MetricName: "mn",
RootNodes: []NodeSettings{
{
FieldName: "fn",
IdentifierType: "s",
TagsSlice: [][]string{{"t1", ""}},
},
},
},
errors.New("tag 1 has empty value"),
},
{
"empty tag name not allowed",
InputClientConfig{
MetricName: "mn",
RootNodes: []NodeSettings{
{
FieldName: "fn",
IdentifierType: "s",
TagsSlice: [][]string{{"", "1"}},
},
},
},
errors.New("tag 1 has empty name"),
},
{
"different metric tag names",
InputClientConfig{
MetricName: "mn",
RootNodes: []NodeSettings{
{
FieldName: "fn",
Namespace: "2",
IdentifierType: "s",
Identifier: "i1",
TagsSlice: [][]string{{"t1", "v1"}, {"t2", "v2"}},
},
{
FieldName: "fn",
Namespace: "2",
IdentifierType: "s",
Identifier: "i1",
TagsSlice: [][]string{{"t1", "v1"}, {"t3", "v2"}},
},
},
},
nil,
},
{
"different metric tag values",
InputClientConfig{
MetricName: "mn",
RootNodes: []NodeSettings{
{
FieldName: "fn",
Namespace: "2",
IdentifierType: "s",
Identifier: "i1",
TagsSlice: [][]string{{"t1", "foo"}, {"t2", "v2"}},
},
{
FieldName: "fn",
Namespace: "2",
IdentifierType: "s",
Identifier: "i1",
TagsSlice: [][]string{{"t1", "bar"}, {"t2", "v2"}},
},
},
},
nil,
},
{
"different metric names",
InputClientConfig{
MetricName: "mn",
Groups: []NodeGroupSettings{
{
MetricName: "mn",
Namespace: "2",
Nodes: []NodeSettings{
{
FieldName: "fn",
IdentifierType: "s",
Identifier: "i1",
TagsSlice: [][]string{{"t1", "v1"}, {"t2", "v2"}},
},
},
},
{
MetricName: "mn2",
Namespace: "2",
Nodes: []NodeSettings{
{
FieldName: "fn",
IdentifierType: "s",
Identifier: "i1",
TagsSlice: [][]string{{"t1", "v1"}, {"t2", "v2"}},
},
},
},
},
},
nil,
},
{
"different field names",
InputClientConfig{
MetricName: "mn",
RootNodes: []NodeSettings{
{
FieldName: "fn",
Namespace: "2",
IdentifierType: "s",
Identifier: "i1",
TagsSlice: [][]string{{"t1", "v1"}, {"t2", "v2"}},
},
{
FieldName: "fn2",
Namespace: "2",
IdentifierType: "s",
Identifier: "i1",
TagsSlice: [][]string{{"t1", "v1"}, {"t2", "v2"}},
},
},
},
nil,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
o := OpcUAInputClient{
Config: tt.config,
Log: testutil.Logger{},
}
require.Equal(t, tt.err, o.InitNodeMetricMapping())
})
}
}
func TestNewNodeMetricMappingTags(t *testing.T) {
tests := []struct {
name string
settings NodeSettings
groupTags map[string]string
expectedTags map[string]string
err error
}{
{
name: "empty tags",
settings: NodeSettings{
FieldName: "f",
Namespace: "2",
IdentifierType: "s",
Identifier: "h",
},
groupTags: map[string]string{},
expectedTags: map[string]string{},
err: nil,
},
{
name: "node tags only",
settings: NodeSettings{
FieldName: "f",
Namespace: "2",
IdentifierType: "s",
Identifier: "h",
TagsSlice: [][]string{{"t1", "v1"}},
},
groupTags: map[string]string{},
expectedTags: map[string]string{"t1": "v1"},
err: nil,
},
{
name: "group tags only",
settings: NodeSettings{
FieldName: "f",
Namespace: "2",
IdentifierType: "s",
Identifier: "h",
},
groupTags: map[string]string{"t1": "v1"},
expectedTags: map[string]string{"t1": "v1"},
err: nil,
},
{
name: "node tag overrides group tags",
settings: NodeSettings{
FieldName: "f",
Namespace: "2",
IdentifierType: "s",
Identifier: "h",
TagsSlice: [][]string{{"t1", "v2"}},
},
groupTags: map[string]string{"t1": "v1"},
expectedTags: map[string]string{"t1": "v2"},
err: nil,
},
{
name: "node tag merged with group tags",
settings: NodeSettings{
FieldName: "f",
Namespace: "2",
IdentifierType: "s",
Identifier: "h",
TagsSlice: [][]string{{"t2", "v2"}},
},
groupTags: map[string]string{"t1": "v1"},
expectedTags: map[string]string{"t1": "v1", "t2": "v2"},
err: nil,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
nmm, err := NewNodeMetricMapping("testmetric", tt.settings, tt.groupTags)
require.Equal(t, tt.err, err)
require.Equal(t, tt.expectedTags, nmm.MetricTags)
})
}
}
func TestNewNodeMetricMappingIdStrInstantiated(t *testing.T) {
nmm, err := NewNodeMetricMapping("testmetric", NodeSettings{
FieldName: "f",
Namespace: "2",
IdentifierType: "s",
Identifier: "h",
}, map[string]string{})
require.NoError(t, err)
require.Equal(t, "ns=2;s=h", nmm.idStr)
}
func TestValidateNodeToAdd(t *testing.T) {
tests := []struct {
name string
existing map[metricParts]struct{}
nmm *NodeMetricMapping
err error
}{
{
name: "valid",
existing: map[metricParts]struct{}{},
nmm: func() *NodeMetricMapping {
nmm, err := NewNodeMetricMapping("testmetric", NodeSettings{
FieldName: "f",
Namespace: "2",
IdentifierType: "s",
Identifier: "hf",
}, map[string]string{})
require.NoError(t, err)
return nmm
}(),
err: nil,
},
{
name: "empty field name not allowed",
existing: map[metricParts]struct{}{},
nmm: func() *NodeMetricMapping {
nmm, err := NewNodeMetricMapping("testmetric", NodeSettings{
FieldName: "",
Namespace: "2",
IdentifierType: "s",
Identifier: "hf",
}, map[string]string{})
require.NoError(t, err)
return nmm
}(),
err: errors.New(`empty name in ""`),
},
{
name: "empty namespace not allowed",
existing: map[metricParts]struct{}{},
nmm: func() *NodeMetricMapping {
nmm, err := NewNodeMetricMapping("testmetric", NodeSettings{
FieldName: "f",
Namespace: "",
IdentifierType: "s",
Identifier: "hf",
}, map[string]string{})
require.NoError(t, err)
return nmm
}(),
err: errors.New("empty node namespace not allowed"),
},
{
name: "empty identifier type not allowed",
existing: map[metricParts]struct{}{},
nmm: func() *NodeMetricMapping {
nmm, err := NewNodeMetricMapping("testmetric", NodeSettings{
FieldName: "f",
Namespace: "2",
IdentifierType: "",
Identifier: "hf",
}, map[string]string{})
require.NoError(t, err)
return nmm
}(),
err: errors.New(`invalid identifier type "" in "f"`),
},
{
name: "invalid identifier type not allowed",
existing: map[metricParts]struct{}{},
nmm: func() *NodeMetricMapping {
nmm, err := NewNodeMetricMapping("testmetric", NodeSettings{
FieldName: "f",
Namespace: "2",
IdentifierType: "j",
Identifier: "hf",
}, map[string]string{})
require.NoError(t, err)
return nmm
}(),
err: errors.New(`invalid identifier type "j" in "f"`),
},
{
name: "duplicate metric not allowed",
existing: map[metricParts]struct{}{
{metricName: "testmetric", fieldName: "f", tags: "t1=v1, t2=v2"}: {},
},
nmm: func() *NodeMetricMapping {
nmm, err := NewNodeMetricMapping("testmetric", NodeSettings{
FieldName: "f",
Namespace: "2",
IdentifierType: "s",
Identifier: "hf",
TagsSlice: [][]string{{"t1", "v1"}, {"t2", "v2"}},
}, map[string]string{})
require.NoError(t, err)
return nmm
}(),
err: errors.New(`name "f" is duplicated (metric name "testmetric", tags "t1=v1, t2=v2")`),
},
{
name: "identifier type mismatch",
existing: map[metricParts]struct{}{},
nmm: func() *NodeMetricMapping {
nmm, err := NewNodeMetricMapping("testmetric", NodeSettings{
FieldName: "f",
Namespace: "2",
IdentifierType: "i",
Identifier: "hf",
}, map[string]string{})
require.NoError(t, err)
return nmm
}(),
err: errors.New(`identifier type "i" does not match the type of identifier "hf"`),
},
}
for idT, idV := range map[string]string{
"s": "hf",
"i": "1",
"g": "849683f0-ce92-4fa2-836f-a02cde61d75d",
"b": "aGVsbG8gSSBhbSBhIHRlc3QgaWRlbnRpZmllcg=="} {
tests = append(tests, struct {
name string
existing map[metricParts]struct{}
nmm *NodeMetricMapping
err error
}{
name: "identifier type " + idT + " allowed",
existing: map[metricParts]struct{}{},
nmm: func() *NodeMetricMapping {
nmm, err := NewNodeMetricMapping("testmetric", NodeSettings{
FieldName: "f",
Namespace: "2",
IdentifierType: idT,
Identifier: idV,
}, map[string]string{})
require.NoError(t, err)
return nmm
}(),
err: nil,
})
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := validateNodeToAdd(tt.existing, tt.nmm)
require.Equal(t, tt.err, err)
})
}
}
func TestInitNodeMetricMapping(t *testing.T) {
tests := []struct {
testname string
config InputClientConfig
expected []NodeMetricMapping
err error
}{
{
testname: "only root node",
config: InputClientConfig{
MetricName: "testmetric",
Timestamp: TimestampSourceTelegraf,
RootNodes: []NodeSettings{
{
FieldName: "f",
Namespace: "2",
IdentifierType: "s",
Identifier: "id1",
TagsSlice: [][]string{{"t1", "v1"}},
},
},
},
expected: []NodeMetricMapping{
{
Tag: NodeSettings{
FieldName: "f",
Namespace: "2",
IdentifierType: "s",
Identifier: "id1",
TagsSlice: [][]string{{"t1", "v1"}},
},
idStr: "ns=2;s=id1",
metricName: "testmetric",
MetricTags: map[string]string{"t1": "v1"},
},
},
err: nil,
},
{
testname: "root node and group node",
config: InputClientConfig{
MetricName: "testmetric",
Timestamp: TimestampSourceTelegraf,
RootNodes: []NodeSettings{
{
FieldName: "f",
Namespace: "2",
IdentifierType: "s",
Identifier: "id1",
TagsSlice: [][]string{{"t1", "v1"}},
},
},
Groups: []NodeGroupSettings{
{
MetricName: "groupmetric",
Namespace: "3",
IdentifierType: "s",
Nodes: []NodeSettings{
{
FieldName: "f",
Identifier: "id2",
TagsSlice: [][]string{{"t2", "v2"}},
},
},
},
},
},
expected: []NodeMetricMapping{
{
Tag: NodeSettings{
FieldName: "f",
Namespace: "2",
IdentifierType: "s",
Identifier: "id1",
TagsSlice: [][]string{{"t1", "v1"}},
},
idStr: "ns=2;s=id1",
metricName: "testmetric",
MetricTags: map[string]string{"t1": "v1"},
},
{
Tag: NodeSettings{
FieldName: "f",
Namespace: "3",
IdentifierType: "s",
Identifier: "id2",
TagsSlice: [][]string{{"t2", "v2"}},
},
idStr: "ns=3;s=id2",
metricName: "groupmetric",
MetricTags: map[string]string{"t2": "v2"},
},
},
err: nil,
},
{
testname: "only group node",
config: InputClientConfig{
MetricName: "testmetric",
Timestamp: TimestampSourceTelegraf,
Groups: []NodeGroupSettings{
{
MetricName: "groupmetric",
Namespace: "3",
IdentifierType: "s",
Nodes: []NodeSettings{
{
FieldName: "f",
Identifier: "id2",
TagsSlice: [][]string{{"t2", "v2"}},
},
},
},
},
},
expected: []NodeMetricMapping{
{
Tag: NodeSettings{
FieldName: "f",
Namespace: "3",
IdentifierType: "s",
Identifier: "id2",
TagsSlice: [][]string{{"t2", "v2"}},
},
idStr: "ns=3;s=id2",
metricName: "groupmetric",
MetricTags: map[string]string{"t2": "v2"},
},
},
err: nil,
},
{
testname: "tags and default only default tags used",
config: InputClientConfig{
MetricName: "testmetric",
Timestamp: TimestampSourceTelegraf,
Groups: []NodeGroupSettings{
{
MetricName: "groupmetric",
Namespace: "3",
IdentifierType: "s",
Nodes: []NodeSettings{
{
FieldName: "f",
Identifier: "id2",
TagsSlice: [][]string{{"t2", "v2"}},
DefaultTags: map[string]string{"t3": "v3"},
},
},
},
},
},
expected: []NodeMetricMapping{
{
Tag: NodeSettings{
FieldName: "f",
Namespace: "3",
IdentifierType: "s",
Identifier: "id2",
TagsSlice: [][]string{{"t2", "v2"}},
DefaultTags: map[string]string{"t3": "v3"},
},
idStr: "ns=3;s=id2",
metricName: "groupmetric",
MetricTags: map[string]string{"t3": "v3"},
},
},
err: nil,
},
{
testname: "only root node default overrides slice",
config: InputClientConfig{
MetricName: "testmetric",
Timestamp: TimestampSourceTelegraf,
RootNodes: []NodeSettings{
{
FieldName: "f",
Namespace: "2",
IdentifierType: "s",
Identifier: "id1",
TagsSlice: [][]string{{"t1", "v1"}},
DefaultTags: map[string]string{"t3": "v3"},
},
},
},
expected: []NodeMetricMapping{
{
Tag: NodeSettings{
FieldName: "f",
Namespace: "2",
IdentifierType: "s",
Identifier: "id1",
TagsSlice: [][]string{{"t1", "v1"}},
DefaultTags: map[string]string{"t3": "v3"},
},
idStr: "ns=2;s=id1",
metricName: "testmetric",
MetricTags: map[string]string{"t3": "v3"},
},
},
err: nil,
},
}
for _, tt := range tests {
t.Run(tt.testname, func(t *testing.T) {
o := OpcUAInputClient{Config: tt.config}
err := o.InitNodeMetricMapping()
require.NoError(t, err)
require.Equal(t, tt.expected, o.NodeMetricMapping)
})
}
}
func TestUpdateNodeValue(t *testing.T) {
type testStep struct {
nodeIdx int
value interface{}
status ua.StatusCode
expected interface{}
}
tests := []struct {
testname string
steps []testStep
}{
{
"value should update when code ok",
[]testStep{
{
0,
"Harmony",
ua.StatusOK,
"Harmony",
},
},
},
{
"value should not update when code bad",
[]testStep{
{
0,
"Harmony",
ua.StatusOK,
"Harmony",
},
{
0,
"Odium",
ua.StatusBad,
"Harmony",
},
{
0,
"Ati",
ua.StatusOK,
"Ati",
},
},
},
}
conf := &opcua.OpcUAClientConfig{
Endpoint: "opc.tcp://localhost:4930",
SecurityPolicy: "None",
SecurityMode: "None",
AuthMethod: "",
ConnectTimeout: config.Duration(2 * time.Second),
RequestTimeout: config.Duration(2 * time.Second),
Workarounds: opcua.OpcUAWorkarounds{},
}
c, err := conf.CreateClient(testutil.Logger{})
require.NoError(t, err)
o := OpcUAInputClient{
OpcUAClient: c,
Log: testutil.Logger{},
NodeMetricMapping: []NodeMetricMapping{
{
Tag: NodeSettings{
FieldName: "f",
},
},
{
Tag: NodeSettings{
FieldName: "f2",
},
},
},
LastReceivedData: make([]NodeValue, 2),
}
for _, tt := range tests {
t.Run(tt.testname, func(t *testing.T) {
o.LastReceivedData = make([]NodeValue, 2)
for i, step := range tt.steps {
v, err := ua.NewVariant(step.value)
require.NoError(t, err)
o.UpdateNodeValue(0, &ua.DataValue{
Value: v,
Status: step.status,
SourceTimestamp: time.Date(2022, 03, 17, 8, 33, 00, 00, &time.Location{}).Add(time.Duration(i) * time.Second),
SourcePicoseconds: 0,
ServerTimestamp: time.Date(2022, 03, 17, 8, 33, 00, 500, &time.Location{}).Add(time.Duration(i) * time.Second),
ServerPicoseconds: 0,
})
require.Equal(t, step.expected, o.LastReceivedData[0].Value)
}
})
}
}
func TestMetricForNode(t *testing.T) {
conf := &opcua.OpcUAClientConfig{
Endpoint: "opc.tcp://localhost:4930",
SecurityPolicy: "None",
SecurityMode: "None",
AuthMethod: "",
ConnectTimeout: config.Duration(2 * time.Second),
RequestTimeout: config.Duration(2 * time.Second),
Workarounds: opcua.OpcUAWorkarounds{},
}
c, err := conf.CreateClient(testutil.Logger{})
require.NoError(t, err)
o := OpcUAInputClient{
Config: InputClientConfig{
Timestamp: TimestampSourceSource,
},
OpcUAClient: c,
Log: testutil.Logger{},
LastReceivedData: make([]NodeValue, 2),
}
tests := []struct {
testname string
nmm []NodeMetricMapping
v interface{}
isArray bool
dataType ua.TypeID
time time.Time
status ua.StatusCode
expected telegraf.Metric
}{
{
testname: "metric build correctly",
nmm: []NodeMetricMapping{
{
Tag: NodeSettings{
FieldName: "fn",
},
idStr: "ns=3;s=hi",
metricName: "testingmetric",
MetricTags: map[string]string{"t1": "v1"},
},
},
v: 16,
isArray: false,
dataType: ua.TypeIDInt32,
time: time.Date(2022, 03, 17, 8, 55, 00, 00, &time.Location{}),
status: ua.StatusOK,
expected: metric.New("testingmetric",
map[string]string{"t1": "v1", "id": "ns=3;s=hi"},
map[string]interface{}{"Quality": "The operation succeeded. StatusGood (0x0)", "fn": 16},
time.Date(2022, 03, 17, 8, 55, 00, 00, &time.Location{})),
},
{
testname: "array-like metric build correctly",
nmm: []NodeMetricMapping{
{
Tag: NodeSettings{
FieldName: "fn",
},
idStr: "ns=3;s=hi",
metricName: "testingmetric",
MetricTags: map[string]string{"t1": "v1"},
},
},
v: []int32{16, 17},
isArray: true,
dataType: ua.TypeIDInt32,
time: time.Date(2022, 03, 17, 8, 55, 00, 00, &time.Location{}),
status: ua.StatusOK,
expected: metric.New("testingmetric",
map[string]string{"t1": "v1", "id": "ns=3;s=hi"},
map[string]interface{}{"Quality": "The operation succeeded. StatusGood (0x0)", "fn[0]": 16, "fn[1]": 17},
time.Date(2022, 03, 17, 8, 55, 00, 00, &time.Location{})),
},
{
testname: "nil does not panic",
nmm: []NodeMetricMapping{
{
Tag: NodeSettings{
FieldName: "fn",
},
idStr: "ns=3;s=hi",
metricName: "testingmetric",
MetricTags: map[string]string{"t1": "v1"},
},
},
v: nil,
isArray: false,
dataType: ua.TypeIDNull,
time: time.Date(2022, 03, 17, 8, 55, 00, 00, &time.Location{}),
status: ua.StatusOK,
expected: metric.New("testingmetric",
map[string]string{"t1": "v1", "id": "ns=3;s=hi"},
map[string]interface{}{"Quality": "The operation succeeded. StatusGood (0x0)"},
time.Date(2022, 03, 17, 8, 55, 00, 00, &time.Location{})),
},
}
for _, tt := range tests {
t.Run(tt.testname, func(t *testing.T) {
o.NodeMetricMapping = tt.nmm
o.LastReceivedData[0].SourceTime = tt.time
o.LastReceivedData[0].Quality = tt.status
o.LastReceivedData[0].Value = tt.v
o.LastReceivedData[0].DataType = tt.dataType
o.LastReceivedData[0].IsArray = tt.isArray
actual := o.MetricForNode(0)
require.Equal(t, tt.expected.Tags(), actual.Tags())
require.Equal(t, tt.expected.Fields(), actual.Fields())
require.Equal(t, tt.expected.Time(), actual.Time())
})
}
}

View file

@ -0,0 +1,15 @@
package opcua
import (
"github.com/influxdata/telegraf"
)
// DebugLogger logs messages from opcua at the debug level.
type DebugLogger struct {
Log telegraf.Logger
}
func (l *DebugLogger) Write(p []byte) (n int, err error) {
l.Log.Debug(string(p))
return len(p), nil
}

View file

@ -0,0 +1,359 @@
package opcua
import (
"crypto/ecdsa"
"crypto/rand"
"crypto/rsa"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"errors"
"fmt"
"math/big"
"net"
"net/url"
"os"
"strings"
"time"
"github.com/gopcua/opcua"
"github.com/gopcua/opcua/debug"
"github.com/gopcua/opcua/ua"
"github.com/influxdata/telegraf/config"
)
// SELF SIGNED CERT FUNCTIONS
func newTempDir() (string, error) {
dir, err := os.MkdirTemp("", "ssc")
return dir, err
}
func generateCert(host string, rsaBits int, certFile, keyFile string, dur time.Duration) (cert, key string, err error) {
dir, err := newTempDir()
if err != nil {
return "", "", fmt.Errorf("failed to create certificate: %w", err)
}
if len(host) == 0 {
return "", "", errors.New("missing required host parameter")
}
if rsaBits == 0 {
rsaBits = 2048
}
if len(certFile) == 0 {
certFile = dir + "/cert.pem"
}
if len(keyFile) == 0 {
keyFile = dir + "/key.pem"
}
priv, err := rsa.GenerateKey(rand.Reader, rsaBits)
if err != nil {
return "", "", fmt.Errorf("failed to generate private key: %w", err)
}
notBefore := time.Now()
notAfter := notBefore.Add(dur)
serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
if err != nil {
return "", "", fmt.Errorf("failed to generate serial number: %w", err)
}
template := x509.Certificate{
SerialNumber: serialNumber,
Subject: pkix.Name{
Organization: []string{"Telegraf OPC UA Client"},
},
NotBefore: notBefore,
NotAfter: notAfter,
KeyUsage: x509.KeyUsageContentCommitment | x509.KeyUsageKeyEncipherment |
x509.KeyUsageDigitalSignature | x509.KeyUsageDataEncipherment | x509.KeyUsageCertSign,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth},
BasicConstraintsValid: true,
}
hosts := strings.Split(host, ",")
for _, h := range hosts {
if ip := net.ParseIP(h); ip != nil {
template.IPAddresses = append(template.IPAddresses, ip)
} else {
template.DNSNames = append(template.DNSNames, h)
}
if uri, err := url.Parse(h); err == nil {
template.URIs = append(template.URIs, uri)
}
}
derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, publicKey(priv), priv)
if err != nil {
return "", "", fmt.Errorf("failed to create certificate: %w", err)
}
certOut, err := os.Create(certFile)
if err != nil {
return "", "", fmt.Errorf("failed to open %s for writing: %w", certFile, err)
}
if err := pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes}); err != nil {
return "", "", fmt.Errorf("failed to write data to %s: %w", certFile, err)
}
if err := certOut.Close(); err != nil {
return "", "", fmt.Errorf("error closing %s: %w", certFile, err)
}
keyOut, err := os.OpenFile(keyFile, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600)
if err != nil {
return "", "", fmt.Errorf("failed to open %s for writing: %w", keyFile, err)
}
keyBlock, err := pemBlockForKey(priv)
if err != nil {
return "", "", fmt.Errorf("error generating block: %w", err)
}
if err := pem.Encode(keyOut, keyBlock); err != nil {
return "", "", fmt.Errorf("failed to write data to %s: %w", keyFile, err)
}
if err := keyOut.Close(); err != nil {
return "", "", fmt.Errorf("error closing %s: %w", keyFile, err)
}
return certFile, keyFile, nil
}
func publicKey(priv interface{}) interface{} {
switch k := priv.(type) {
case *rsa.PrivateKey:
return &k.PublicKey
case *ecdsa.PrivateKey:
return &k.PublicKey
default:
return nil
}
}
func pemBlockForKey(priv interface{}) (*pem.Block, error) {
switch k := priv.(type) {
case *rsa.PrivateKey:
return &pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(k)}, nil
case *ecdsa.PrivateKey:
b, err := x509.MarshalECPrivateKey(k)
if err != nil {
return nil, fmt.Errorf("unable to marshal ECDSA private key: %w", err)
}
return &pem.Block{Type: "EC PRIVATE KEY", Bytes: b}, nil
default:
return nil, nil
}
}
func (o *OpcUAClient) generateClientOpts(endpoints []*ua.EndpointDescription) ([]opcua.Option, error) {
appuri := "urn:telegraf:gopcua:client"
appname := "Telegraf"
// ApplicationURI is automatically read from the cert so is not required if a cert if provided
opts := []opcua.Option{
opcua.ApplicationURI(appuri),
opcua.ApplicationName(appname),
opcua.RequestTimeout(time.Duration(o.Config.RequestTimeout)),
}
if o.Config.SessionTimeout != 0 {
opts = append(opts, opcua.SessionTimeout(time.Duration(o.Config.SessionTimeout)))
}
certFile := o.Config.Certificate
keyFile := o.Config.PrivateKey
policy := o.Config.SecurityPolicy
mode := o.Config.SecurityMode
var err error
if certFile == "" && keyFile == "" {
if policy != "None" || mode != "None" {
certFile, keyFile, err = generateCert(appuri, 2048, certFile, keyFile, 365*24*time.Hour)
if err != nil {
return nil, err
}
}
}
var cert []byte
if certFile != "" && keyFile != "" {
debug.Printf("Loading cert/key from %s/%s", certFile, keyFile)
c, err := tls.LoadX509KeyPair(certFile, keyFile)
if err != nil {
o.Log.Warnf("Failed to load certificate: %s", err)
} else {
pk, ok := c.PrivateKey.(*rsa.PrivateKey)
if !ok {
return nil, errors.New("invalid private key")
}
cert = c.Certificate[0]
opts = append(opts, opcua.PrivateKey(pk), opcua.Certificate(cert))
}
}
var secPolicy string
switch {
case policy == "auto":
// set it later
case strings.HasPrefix(policy, ua.SecurityPolicyURIPrefix):
secPolicy = policy
policy = ""
case policy == "None" || policy == "Basic128Rsa15" || policy == "Basic256" || policy == "Basic256Sha256" ||
policy == "Aes128_Sha256_RsaOaep" || policy == "Aes256_Sha256_RsaPss":
secPolicy = ua.SecurityPolicyURIPrefix + policy
policy = ""
default:
return nil, fmt.Errorf("invalid security policy: %s", policy)
}
o.Log.Debugf("security policy from configuration %s", secPolicy)
// Select the most appropriate authentication mode from server capabilities and user input
authMode, authOption, err := o.generateAuth(o.Config.AuthMethod, cert, o.Config.Username, o.Config.Password)
if err != nil {
return nil, err
}
opts = append(opts, authOption)
var secMode ua.MessageSecurityMode
switch strings.ToLower(mode) {
case "auto":
case "none":
secMode = ua.MessageSecurityModeNone
mode = ""
case "sign":
secMode = ua.MessageSecurityModeSign
mode = ""
case "signandencrypt":
secMode = ua.MessageSecurityModeSignAndEncrypt
mode = ""
default:
return nil, fmt.Errorf("invalid security mode: %s", mode)
}
// Allow input of only one of sec-mode,sec-policy when choosing 'None'
if secMode == ua.MessageSecurityModeNone || secPolicy == ua.SecurityPolicyURINone {
secMode = ua.MessageSecurityModeNone
secPolicy = ua.SecurityPolicyURINone
}
// Find the best endpoint based on our input and server recommendation (highest SecurityMode+SecurityLevel)
var serverEndpoint *ua.EndpointDescription
switch {
case mode == "auto" && policy == "auto": // No user selection, choose best
for _, e := range endpoints {
if serverEndpoint == nil || (e.SecurityMode >= serverEndpoint.SecurityMode && e.SecurityLevel >= serverEndpoint.SecurityLevel) {
serverEndpoint = e
}
}
case mode != "auto" && policy == "auto": // User only cares about mode, select highest securitylevel with that mode
for _, e := range endpoints {
if e.SecurityMode == secMode && (serverEndpoint == nil || e.SecurityLevel >= serverEndpoint.SecurityLevel) {
serverEndpoint = e
}
}
case mode == "auto" && policy != "auto": // User only cares about policy, select highest securitylevel with that policy
for _, e := range endpoints {
if e.SecurityPolicyURI == secPolicy && (serverEndpoint == nil || e.SecurityLevel >= serverEndpoint.SecurityLevel) {
serverEndpoint = e
}
}
default: // User cares about both
o.Log.Debugf("User cares about both the policy (%s) and security mode (%s)", secPolicy, secMode)
o.Log.Debugf("Server has %d endpoints", len(endpoints))
for _, e := range endpoints {
o.Log.Debugf("Evaluating endpoint %s, policy %s, mode %s, level %d", e.EndpointURL, e.SecurityPolicyURI, e.SecurityMode, e.SecurityLevel)
if e.SecurityPolicyURI == secPolicy && e.SecurityMode == secMode && (serverEndpoint == nil || e.SecurityLevel >= serverEndpoint.SecurityLevel) {
serverEndpoint = e
o.Log.Debugf(
"Security policy and mode found. Using server endpoint %s for security. Policy %s",
serverEndpoint.EndpointURL,
serverEndpoint.SecurityPolicyURI,
)
}
}
}
if serverEndpoint == nil { // Didn't find an endpoint with matching policy and mode.
return nil, errors.New("unable to find suitable server endpoint with selected sec-policy and sec-mode")
}
secPolicy = serverEndpoint.SecurityPolicyURI
secMode = serverEndpoint.SecurityMode
// Check that the selected endpoint is a valid combo
err = validateEndpointConfig(endpoints, secPolicy, secMode, authMode)
if err != nil {
return nil, fmt.Errorf("error validating input: %w", err)
}
opts = append(opts, opcua.SecurityFromEndpoint(serverEndpoint, authMode))
return opts, nil
}
func (o *OpcUAClient) generateAuth(a string, cert []byte, user, passwd config.Secret) (ua.UserTokenType, opcua.Option, error) {
var authMode ua.UserTokenType
var authOption opcua.Option
switch strings.ToLower(a) {
case "anonymous":
authMode = ua.UserTokenTypeAnonymous
authOption = opcua.AuthAnonymous()
case "username":
authMode = ua.UserTokenTypeUserName
var username, password []byte
if !user.Empty() {
usecret, err := user.Get()
if err != nil {
return 0, nil, fmt.Errorf("error reading the username input: %w", err)
}
defer usecret.Destroy()
username = usecret.Bytes()
}
if !passwd.Empty() {
psecret, err := passwd.Get()
if err != nil {
return 0, nil, fmt.Errorf("error reading the password input: %w", err)
}
defer psecret.Destroy()
password = psecret.Bytes()
}
authOption = opcua.AuthUsername(string(username), string(password))
case "certificate":
authMode = ua.UserTokenTypeCertificate
authOption = opcua.AuthCertificate(cert)
case "issuedtoken":
// todo: this is unsupported, fail here or fail in the opcua package?
authMode = ua.UserTokenTypeIssuedToken
authOption = opcua.AuthIssuedToken([]byte(nil))
default:
o.Log.Warnf("unknown auth-mode, defaulting to Anonymous")
authMode = ua.UserTokenTypeAnonymous
authOption = opcua.AuthAnonymous()
}
return authMode, authOption, nil
}
func validateEndpointConfig(endpoints []*ua.EndpointDescription, secPolicy string, secMode ua.MessageSecurityMode, authMode ua.UserTokenType) error {
for _, e := range endpoints {
if e.SecurityMode == secMode && e.SecurityPolicyURI == secPolicy {
for _, t := range e.UserIdentityTokens {
if t.TokenType == authMode {
return nil
}
}
}
}
return fmt.Errorf("server does not support an endpoint with security: %q, %q", secPolicy, secMode)
}

View file

@ -0,0 +1,84 @@
package parallel
import (
"sync"
"github.com/influxdata/telegraf"
)
type Ordered struct {
wg sync.WaitGroup
fn func(telegraf.Metric) []telegraf.Metric
// queue of jobs coming in. Workers pick jobs off this queue for processing
workerQueue chan job
// queue of ordered metrics going out
queue chan futureMetric
}
func NewOrdered(acc telegraf.Accumulator, fn func(telegraf.Metric) []telegraf.Metric, orderedQueueSize, workerCount int) *Ordered {
p := &Ordered{
fn: fn,
workerQueue: make(chan job, workerCount),
queue: make(chan futureMetric, orderedQueueSize),
}
p.startWorkers(workerCount)
p.wg.Add(1)
go func() {
p.readQueue(acc)
p.wg.Done()
}()
return p
}
func (p *Ordered) Enqueue(metric telegraf.Metric) {
future := make(futureMetric)
p.queue <- future
// write the future to the worker pool. Order doesn't matter now because the
// outgoing p.queue will enforce order regardless of the order the jobs are
// completed in
p.workerQueue <- job{
future: future,
metric: metric,
}
}
func (p *Ordered) readQueue(acc telegraf.Accumulator) {
// wait for the response from each worker in order
for mCh := range p.queue {
// allow each worker to write out multiple metrics
for metrics := range mCh {
for _, m := range metrics {
acc.AddMetric(m)
}
}
}
}
func (p *Ordered) startWorkers(count int) {
p.wg.Add(count)
for i := 0; i < count; i++ {
go func() {
for job := range p.workerQueue {
job.future <- p.fn(job.metric)
close(job.future)
}
p.wg.Done()
}()
}
}
func (p *Ordered) Stop() {
close(p.queue)
close(p.workerQueue)
p.wg.Wait()
}
type futureMetric chan []telegraf.Metric
type job struct {
future futureMetric
metric telegraf.Metric
}

View file

@ -0,0 +1,8 @@
package parallel
import "github.com/influxdata/telegraf"
type Parallel interface {
Enqueue(telegraf.Metric)
Stop()
}

View file

@ -0,0 +1,117 @@
package parallel_test
import (
"testing"
"time"
"github.com/stretchr/testify/require"
"github.com/influxdata/telegraf"
"github.com/influxdata/telegraf/metric"
"github.com/influxdata/telegraf/plugins/common/parallel"
"github.com/influxdata/telegraf/testutil"
)
func TestOrderedJobsStayOrdered(t *testing.T) {
acc := &testutil.Accumulator{}
p := parallel.NewOrdered(acc, jobFunc, 10000, 10)
now := time.Now()
for i := 0; i < 20000; i++ {
m := metric.New("test",
map[string]string{},
map[string]interface{}{
"val": i,
},
now,
)
now = now.Add(1)
p.Enqueue(m)
}
p.Stop()
i := 0
require.Len(t, acc.Metrics, 20000)
for _, m := range acc.GetTelegrafMetrics() {
v, ok := m.GetField("val")
require.True(t, ok)
require.EqualValues(t, i, v)
i++
}
}
func TestUnorderedJobsDontDropAnyJobs(t *testing.T) {
acc := &testutil.Accumulator{}
p := parallel.NewUnordered(acc, jobFunc, 10)
now := time.Now()
expectedTotal := 0
for i := 0; i < 20000; i++ {
expectedTotal += i
m := metric.New("test",
map[string]string{},
map[string]interface{}{
"val": i,
},
now,
)
now = now.Add(1)
p.Enqueue(m)
}
p.Stop()
actualTotal := int64(0)
require.Len(t, acc.Metrics, 20000)
for _, m := range acc.GetTelegrafMetrics() {
v, ok := m.GetField("val")
require.True(t, ok)
actualTotal += v.(int64)
}
require.EqualValues(t, expectedTotal, actualTotal)
}
func BenchmarkOrdered(b *testing.B) {
acc := &testutil.Accumulator{}
p := parallel.NewOrdered(acc, jobFunc, 10000, 10)
m := metric.New("test",
map[string]string{},
map[string]interface{}{
"val": 1,
},
time.Now(),
)
b.ResetTimer()
for i := 0; i < b.N; i++ {
p.Enqueue(m)
}
p.Stop()
}
func BenchmarkUnordered(b *testing.B) {
acc := &testutil.Accumulator{}
p := parallel.NewUnordered(acc, jobFunc, 10)
m := metric.New("test",
map[string]string{},
map[string]interface{}{
"val": 1,
},
time.Now(),
)
b.ResetTimer()
for i := 0; i < b.N; i++ {
p.Enqueue(m)
}
p.Stop()
}
func jobFunc(m telegraf.Metric) []telegraf.Metric {
return []telegraf.Metric{m}
}

View file

@ -0,0 +1,60 @@
package parallel
import (
"sync"
"github.com/influxdata/telegraf"
)
type Unordered struct {
wg sync.WaitGroup
acc telegraf.Accumulator
fn func(telegraf.Metric) []telegraf.Metric
inQueue chan telegraf.Metric
}
func NewUnordered(
acc telegraf.Accumulator,
fn func(telegraf.Metric) []telegraf.Metric,
workerCount int,
) *Unordered {
p := &Unordered{
acc: acc,
inQueue: make(chan telegraf.Metric, workerCount),
fn: fn,
}
// start workers
p.wg.Add(1)
go func() {
p.startWorkers(workerCount)
p.wg.Done()
}()
return p
}
func (p *Unordered) startWorkers(count int) {
wg := sync.WaitGroup{}
wg.Add(count)
for i := 0; i < count; i++ {
go func() {
for metric := range p.inQueue {
for _, m := range p.fn(metric) {
p.acc.AddMetric(m)
}
}
wg.Done()
}()
}
wg.Wait()
}
func (p *Unordered) Stop() {
close(p.inQueue)
p.wg.Wait()
}
func (p *Unordered) Enqueue(m telegraf.Metric) {
p.inQueue <- m
}

View file

@ -0,0 +1,184 @@
package postgresql
import (
"fmt"
"net"
"net/url"
"regexp"
"sort"
"strings"
"time"
"github.com/jackc/pgx/v4"
"github.com/jackc/pgx/v4/stdlib"
"github.com/influxdata/telegraf/config"
)
var socketRegexp = regexp.MustCompile(`/\.s\.PGSQL\.\d+$`)
var sanitizer = regexp.MustCompile(`(\s|^)((?:password|sslcert|sslkey|sslmode|sslrootcert)\s?=\s?(?:(?:'(?:[^'\\]|\\.)*')|(?:\S+)))`)
type Config struct {
Address config.Secret `toml:"address"`
OutputAddress string `toml:"outputaddress"`
MaxIdle int `toml:"max_idle"`
MaxOpen int `toml:"max_open"`
MaxLifetime config.Duration `toml:"max_lifetime"`
IsPgBouncer bool `toml:"-"`
}
func (c *Config) CreateService() (*Service, error) {
addrSecret, err := c.Address.Get()
if err != nil {
return nil, fmt.Errorf("getting address failed: %w", err)
}
addr := addrSecret.String()
defer addrSecret.Destroy()
if c.Address.Empty() || addr == "localhost" {
addr = "host=localhost sslmode=disable"
if err := c.Address.Set([]byte(addr)); err != nil {
return nil, err
}
}
connConfig, err := pgx.ParseConfig(addr)
if err != nil {
return nil, err
}
// Remove the socket name from the path
connConfig.Host = socketRegexp.ReplaceAllLiteralString(connConfig.Host, "")
// Specific support to make it work with PgBouncer too
// See https://github.com/influxdata/telegraf/issues/3253#issuecomment-357505343
if c.IsPgBouncer {
// Remove DriveConfig and revert it by the ParseConfig method
// See https://github.com/influxdata/telegraf/issues/9134
connConfig.PreferSimpleProtocol = true
}
// Provide the connection string without sensitive information for use as
// tag or other output properties
sanitizedAddr, err := c.sanitizedAddress()
if err != nil {
return nil, err
}
return &Service{
SanitizedAddress: sanitizedAddr,
ConnectionDatabase: connectionDatabase(sanitizedAddr),
maxIdle: c.MaxIdle,
maxOpen: c.MaxOpen,
maxLifetime: time.Duration(c.MaxLifetime),
dsn: stdlib.RegisterConnConfig(connConfig),
}, nil
}
// connectionDatabase determines the database to which the connection was made
func connectionDatabase(sanitizedAddr string) string {
connConfig, err := pgx.ParseConfig(sanitizedAddr)
if err != nil || connConfig.Database == "" {
return "postgres"
}
return connConfig.Database
}
// sanitizedAddress strips sensitive information from the connection string.
// If the user set the output address use that before parsing anything else.
func (c *Config) sanitizedAddress() (string, error) {
if c.OutputAddress != "" {
return c.OutputAddress, nil
}
// Get the address
addrSecret, err := c.Address.Get()
if err != nil {
return "", fmt.Errorf("getting address for sanitization failed: %w", err)
}
defer addrSecret.Destroy()
// Make sure we convert URI-formatted strings into key-values
addr := addrSecret.TemporaryString()
if strings.HasPrefix(addr, "postgres://") || strings.HasPrefix(addr, "postgresql://") {
if addr, err = toKeyValue(addr); err != nil {
return "", err
}
}
// Sanitize the string using a regular expression
sanitized := sanitizer.ReplaceAllString(addr, "")
return strings.TrimSpace(sanitized), nil
}
// Based on parseURLSettings() at https://github.com/jackc/pgx/blob/master/pgconn/config.go
func toKeyValue(uri string) (string, error) {
u, err := url.Parse(uri)
if err != nil {
return "", fmt.Errorf("parsing URI failed: %w", err)
}
// Check the protocol
if u.Scheme != "postgres" && u.Scheme != "postgresql" {
return "", fmt.Errorf("invalid connection protocol: %s", u.Scheme)
}
quoteIfNecessary := func(v string) string {
if !strings.ContainsAny(v, ` ='\`) {
return v
}
r := strings.ReplaceAll(v, `\`, `\\`)
r = strings.ReplaceAll(r, `'`, `\'`)
return "'" + r + "'"
}
// Extract the parameters
parts := make([]string, 0, len(u.Query())+5)
if u.User != nil {
parts = append(parts, "user="+quoteIfNecessary(u.User.Username()))
if password, found := u.User.Password(); found {
parts = append(parts, "password="+quoteIfNecessary(password))
}
}
// Handle multiple host:port's in url.Host by splitting them into host,host,host and port,port,port.
hostParts := strings.Split(u.Host, ",")
hosts := make([]string, 0, len(hostParts))
ports := make([]string, 0, len(hostParts))
var anyPortSet bool
for _, host := range hostParts {
if host == "" {
continue
}
h, p, err := net.SplitHostPort(host)
if err != nil {
if !strings.Contains(err.Error(), "missing port") {
return "", fmt.Errorf("failed to process host %q: %w", host, err)
}
h = host
}
anyPortSet = anyPortSet || err == nil
hosts = append(hosts, h)
ports = append(ports, p)
}
if len(hosts) > 0 {
parts = append(parts, "host="+strings.Join(hosts, ","))
}
if anyPortSet {
parts = append(parts, "port="+strings.Join(ports, ","))
}
database := strings.TrimLeft(u.Path, "/")
if database != "" {
parts = append(parts, "dbname="+quoteIfNecessary(database))
}
for k, v := range u.Query() {
parts = append(parts, k+"="+quoteIfNecessary(strings.Join(v, ",")))
}
// Required to produce a repeatable output e.g. for tags or testing
sort.Strings(parts)
return strings.Join(parts, " "), nil
}

View file

@ -0,0 +1,240 @@
package postgresql
import (
"strings"
"testing"
"github.com/stretchr/testify/require"
"github.com/influxdata/telegraf/config"
)
func TestURIParsing(t *testing.T) {
tests := []struct {
name string
uri string
expected string
}{
{
name: "short",
uri: `postgres://localhost`,
expected: "host=localhost",
},
{
name: "with port",
uri: `postgres://localhost:5432`,
expected: "host=localhost port=5432",
},
{
name: "with database",
uri: `postgres://localhost/mydb`,
expected: "dbname=mydb host=localhost",
},
{
name: "with additional parameters",
uri: `postgres://localhost/mydb?application_name=pgxtest&search_path=myschema&connect_timeout=5`,
expected: "application_name=pgxtest connect_timeout=5 dbname=mydb host=localhost search_path=myschema",
},
{
name: "with database setting in params",
uri: `postgres://localhost:5432/?database=mydb`,
expected: "database=mydb host=localhost port=5432",
},
{
name: "with authentication",
uri: `postgres://jack:secret@localhost:5432/mydb?sslmode=prefer`,
expected: "dbname=mydb host=localhost password=secret port=5432 sslmode=prefer user=jack",
},
{
name: "with spaces",
uri: `postgres://jack%20hunter:secret@localhost/mydb?application_name=pgx%20test`,
expected: "application_name='pgx test' dbname=mydb host=localhost password=secret user='jack hunter'",
},
{
name: "with equal signs",
uri: `postgres://jack%20hunter:secret@localhost/mydb?application_name=pgx%3Dtest`,
expected: "application_name='pgx=test' dbname=mydb host=localhost password=secret user='jack hunter'",
},
{
name: "multiple hosts",
uri: `postgres://jack:secret@foo:1,bar:2,baz:3/mydb?sslmode=disable`,
expected: "dbname=mydb host=foo,bar,baz password=secret port=1,2,3 sslmode=disable user=jack",
},
{
name: "multiple hosts without ports",
uri: `postgres://jack:secret@foo,bar,baz/mydb?sslmode=disable`,
expected: "dbname=mydb host=foo,bar,baz password=secret sslmode=disable user=jack",
},
}
for _, tt := range tests {
// Key value without spaces around equal sign
t.Run(tt.name, func(t *testing.T) {
actual, err := toKeyValue(tt.uri)
require.NoError(t, err)
require.Equalf(t, tt.expected, actual, "initial: %s", tt.uri)
})
}
}
func TestSanitizeAddressKeyValue(t *testing.T) {
keys := []string{"password", "sslcert", "sslkey", "sslmode", "sslrootcert"}
tests := []struct {
name string
value string
}{
{
name: "simple text",
value: `foo`,
},
{
name: "empty values",
value: `''`,
},
{
name: "space in value",
value: `'foo bar'`,
},
{
name: "equal sign in value",
value: `'foo=bar'`,
},
{
name: "escaped quote",
value: `'foo\'s bar'`,
},
{
name: "escaped quote no space",
value: `\'foobar\'s\'`,
},
{
name: "escaped backslash",
value: `'foo bar\\'`,
},
{
name: "escaped quote and backslash",
value: `'foo\\\'s bar'`,
},
{
name: "two escaped backslashes",
value: `'foo bar\\\\'`,
},
{
name: "multiple inline spaces",
value: "'foo \t bar'",
},
{
name: "leading space",
value: `' foo bar'`,
},
{
name: "trailing space",
value: `'foo bar '`,
},
{
name: "multiple equal signs",
value: `'foo===bar'`,
},
{
name: "leading equal sign",
value: `'=foo bar'`,
},
{
name: "trailing equal sign",
value: `'foo bar='`,
},
{
name: "mix of equal signs and spaces",
value: "'foo = a\t===\tbar'",
},
}
for _, tt := range tests {
// Key value without spaces around equal sign
t.Run(tt.name, func(t *testing.T) {
// Generate the DSN from the given keys and value
parts := make([]string, 0, len(keys))
for _, k := range keys {
parts = append(parts, k+"="+tt.value)
}
dsn := strings.Join(parts, " canary=ok ")
cfg := &Config{
Address: config.NewSecret([]byte(dsn)),
}
expected := strings.Join(make([]string, len(keys)), "canary=ok ")
expected = strings.TrimSpace(expected)
actual, err := cfg.sanitizedAddress()
require.NoError(t, err)
require.Equalf(t, expected, actual, "initial: %s", dsn)
})
// Key value with spaces around equal sign
t.Run("spaced "+tt.name, func(t *testing.T) {
// Generate the DSN from the given keys and value
parts := make([]string, 0, len(keys))
for _, k := range keys {
parts = append(parts, k+" = "+tt.value)
}
dsn := strings.Join(parts, " canary=ok ")
cfg := &Config{
Address: config.NewSecret([]byte(dsn)),
}
expected := strings.Join(make([]string, len(keys)), "canary=ok ")
expected = strings.TrimSpace(expected)
actual, err := cfg.sanitizedAddress()
require.NoError(t, err)
require.Equalf(t, expected, actual, "initial: %s", dsn)
})
}
}
func TestSanitizeAddressURI(t *testing.T) {
keys := []string{"password", "sslcert", "sslkey", "sslmode", "sslrootcert"}
tests := []struct {
name string
value string
}{
{
name: "simple text",
value: `foo`,
},
{
name: "empty values",
value: ``,
},
{
name: "space in value",
value: `foo bar`,
},
{
name: "equal sign in value",
value: `foo=bar`,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Generate the DSN from the given keys and value
value := strings.ReplaceAll(tt.value, "=", "%3D")
value = strings.ReplaceAll(value, " ", "%20")
parts := make([]string, 0, len(keys))
for _, k := range keys {
parts = append(parts, k+"="+value)
}
dsn := "postgresql://user:passwd@localhost:5432/db?" + strings.Join(parts, "&")
cfg := &Config{
Address: config.NewSecret([]byte(dsn)),
}
expected := "dbname=db host=localhost port=5432 user=user"
actual, err := cfg.sanitizedAddress()
require.NoError(t, err)
require.Equalf(t, expected, actual, "initial: %s", dsn)
})
}
}

View file

@ -0,0 +1,42 @@
package postgresql
import (
"database/sql"
"time"
// Blank import required to register driver
_ "github.com/jackc/pgx/v4/stdlib"
)
// Service common functionality shared between the postgresql and postgresql_extensible
// packages.
type Service struct {
DB *sql.DB
SanitizedAddress string
ConnectionDatabase string
dsn string
maxIdle int
maxOpen int
maxLifetime time.Duration
}
func (p *Service) Start() error {
db, err := sql.Open("pgx", p.dsn)
if err != nil {
return err
}
p.DB = db
p.DB.SetMaxOpenConns(p.maxOpen)
p.DB.SetMaxIdleConns(p.maxIdle)
p.DB.SetConnMaxLifetime(p.maxLifetime)
return nil
}
func (p *Service) Stop() {
if p.DB != nil {
p.DB.Close()
}
}

View file

@ -0,0 +1,140 @@
package proxy
import (
"bufio"
"context"
"fmt"
"net"
"net/http"
"net/url"
"golang.org/x/net/proxy"
)
// httpConnectProxy proxies (only?) TCP over a HTTP tunnel using the CONNECT method
type httpConnectProxy struct {
forward proxy.Dialer
url *url.URL
}
func (c *httpConnectProxy) DialContext(ctx context.Context, network, addr string) (net.Conn, error) {
// Prevent using UDP
if network == "udp" {
return nil, fmt.Errorf("cannot proxy %q traffic over HTTP CONNECT", network)
}
var proxyConn net.Conn
var err error
if dialer, ok := c.forward.(proxy.ContextDialer); ok {
proxyConn, err = dialer.DialContext(ctx, "tcp", c.url.Host)
} else {
shim := contextDialerShim{c.forward}
proxyConn, err = shim.DialContext(ctx, "tcp", c.url.Host)
}
if err != nil {
return nil, err
}
// Add and strip http:// to extract authority portion of the URL
// since CONNECT doesn't use a full URL. The request header would
// look something like: "CONNECT www.influxdata.com:443 HTTP/1.1"
requestURL, err := url.Parse("http://" + addr)
if err != nil {
if err := proxyConn.Close(); err != nil {
return nil, err
}
return nil, err
}
requestURL.Scheme = ""
// Build HTTP CONNECT request
req, err := http.NewRequest(http.MethodConnect, requestURL.String(), nil)
if err != nil {
if err := proxyConn.Close(); err != nil {
return nil, err
}
return nil, err
}
req.Close = false
if password, hasAuth := c.url.User.Password(); hasAuth {
req.SetBasicAuth(c.url.User.Username(), password)
}
err = req.Write(proxyConn)
if err != nil {
if err := proxyConn.Close(); err != nil {
return nil, err
}
return nil, err
}
resp, err := http.ReadResponse(bufio.NewReader(proxyConn), req)
if err != nil {
if err := proxyConn.Close(); err != nil {
return nil, err
}
return nil, err
}
if err := resp.Body.Close(); err != nil {
return nil, err
}
if resp.StatusCode != 200 {
if err := proxyConn.Close(); err != nil {
return nil, err
}
return nil, fmt.Errorf("failed to connect to proxy: %q", resp.Status)
}
return proxyConn, nil
}
func (c *httpConnectProxy) Dial(network, addr string) (net.Conn, error) {
return c.DialContext(context.Background(), network, addr)
}
func newHTTPConnectProxy(proxyURL *url.URL, forward proxy.Dialer) (proxy.Dialer, error) {
return &httpConnectProxy{forward, proxyURL}, nil
}
func init() {
// Register new proxy types
proxy.RegisterDialerType("http", newHTTPConnectProxy)
proxy.RegisterDialerType("https", newHTTPConnectProxy)
}
// contextDialerShim allows cancellation of the dial from a context even if the underlying
// dialer does not implement `proxy.ContextDialer`. Arguably, this shouldn't actually get run,
// unless a new proxy type is added that doesn't implement `proxy.ContextDialer`, as all the
// standard library dialers implement `proxy.ContextDialer`.
type contextDialerShim struct {
dialer proxy.Dialer
}
func (cd *contextDialerShim) Dial(network, addr string) (net.Conn, error) {
return cd.dialer.Dial(network, addr)
}
func (cd *contextDialerShim) DialContext(ctx context.Context, network, addr string) (net.Conn, error) {
var (
conn net.Conn
done = make(chan struct{}, 1)
err error
)
go func() {
conn, err = cd.dialer.Dial(network, addr)
close(done)
if conn != nil && ctx.Err() != nil {
_ = conn.Close()
}
}()
select {
case <-ctx.Done():
err = ctx.Err()
case <-done:
}
return conn, err
}

View file

@ -0,0 +1,37 @@
package proxy
import (
"context"
"net"
"time"
"golang.org/x/net/proxy"
)
type ProxiedDialer struct {
dialer proxy.Dialer
}
func (pd *ProxiedDialer) Dial(network, addr string) (net.Conn, error) {
return pd.dialer.Dial(network, addr)
}
func (pd *ProxiedDialer) DialContext(ctx context.Context, network, addr string) (net.Conn, error) {
if contextDialer, ok := pd.dialer.(proxy.ContextDialer); ok {
return contextDialer.DialContext(ctx, network, addr)
}
contextDialer := contextDialerShim{pd.dialer}
return contextDialer.DialContext(ctx, network, addr)
}
func (pd *ProxiedDialer) DialTimeout(network, addr string, timeout time.Duration) (net.Conn, error) {
ctx := context.Background()
if timeout.Seconds() != 0 {
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(ctx, timeout)
defer cancel()
}
return pd.DialContext(ctx, network, addr)
}

View file

@ -0,0 +1,57 @@
package proxy
import (
"fmt"
"net/http"
"net/url"
"golang.org/x/net/proxy"
)
type HTTPProxy struct {
UseSystemProxy bool `toml:"use_system_proxy"`
HTTPProxyURL string `toml:"http_proxy_url"`
}
type proxyFunc func(req *http.Request) (*url.URL, error)
func (p *HTTPProxy) Proxy() (proxyFunc, error) {
if p.UseSystemProxy {
return http.ProxyFromEnvironment, nil
} else if len(p.HTTPProxyURL) > 0 {
address, err := url.Parse(p.HTTPProxyURL)
if err != nil {
return nil, fmt.Errorf("error parsing proxy url %q: %w", p.HTTPProxyURL, err)
}
return http.ProxyURL(address), nil
}
return nil, nil
}
type TCPProxy struct {
UseProxy bool `toml:"use_proxy"`
ProxyURL string `toml:"proxy_url"`
}
func (p *TCPProxy) Proxy() (*ProxiedDialer, error) {
var dialer proxy.Dialer
if p.UseProxy {
if len(p.ProxyURL) > 0 {
parsed, err := url.Parse(p.ProxyURL)
if err != nil {
return nil, fmt.Errorf("error parsing proxy url %q: %w", p.ProxyURL, err)
}
if dialer, err = proxy.FromURL(parsed, proxy.Direct); err != nil {
return nil, err
}
} else {
dialer = proxy.FromEnvironment()
}
} else {
dialer = proxy.Direct
}
return &ProxiedDialer{dialer}, nil
}

View file

@ -0,0 +1,22 @@
package proxy
import (
"golang.org/x/net/proxy"
)
type Socks5ProxyConfig struct {
Socks5ProxyEnabled bool `toml:"socks5_enabled"`
Socks5ProxyAddress string `toml:"socks5_address"`
Socks5ProxyUsername string `toml:"socks5_username"`
Socks5ProxyPassword string `toml:"socks5_password"`
}
func (c *Socks5ProxyConfig) GetDialer() (proxy.Dialer, error) {
var auth *proxy.Auth
if c.Socks5ProxyPassword != "" || c.Socks5ProxyUsername != "" {
auth = new(proxy.Auth)
auth.User = c.Socks5ProxyUsername
auth.Password = c.Socks5ProxyPassword
}
return proxy.SOCKS5("tcp", c.Socks5ProxyAddress, auth, proxy.Direct)
}

View file

@ -0,0 +1,74 @@
package proxy
import (
"net"
"testing"
"time"
"github.com/armon/go-socks5"
"github.com/stretchr/testify/require"
)
func TestSocks5ProxyConfigIntegration(t *testing.T) {
if testing.Short() {
t.Skip("Skipping integration test in short mode")
}
const (
proxyAddress = "127.0.0.1:12345"
proxyUsername = "user"
proxyPassword = "password"
)
l, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
server, err := socks5.New(&socks5.Config{
AuthMethods: []socks5.Authenticator{socks5.UserPassAuthenticator{
Credentials: socks5.StaticCredentials{
proxyUsername: proxyPassword,
},
}},
})
require.NoError(t, err)
go func() {
if err := server.ListenAndServe("tcp", proxyAddress); err != nil {
t.Error(err)
}
}()
conf := Socks5ProxyConfig{
Socks5ProxyEnabled: true,
Socks5ProxyAddress: proxyAddress,
Socks5ProxyUsername: proxyUsername,
Socks5ProxyPassword: proxyPassword,
}
dialer, err := conf.GetDialer()
require.NoError(t, err)
var proxyConn net.Conn
for i := 0; i < 10; i++ {
proxyConn, err = dialer.Dial("tcp", l.Addr().String())
if err == nil {
break
}
time.Sleep(10 * time.Millisecond)
}
require.NotNil(t, proxyConn)
defer func() { require.NoError(t, proxyConn.Close()) }()
serverConn, err := l.Accept()
require.NoError(t, err)
defer func() { require.NoError(t, serverConn.Close()) }()
writePayload := []byte("test")
_, err = proxyConn.Write(writePayload)
require.NoError(t, err)
receivePayload := make([]byte, 4)
_, err = serverConn.Read(receivePayload)
require.NoError(t, err)
require.Equal(t, writePayload, receivePayload)
}

View file

@ -0,0 +1,155 @@
package psutil
import (
"os"
"github.com/shirou/gopsutil/v4/cpu"
"github.com/shirou/gopsutil/v4/disk"
"github.com/shirou/gopsutil/v4/mem"
"github.com/shirou/gopsutil/v4/net"
"github.com/stretchr/testify/mock"
)
// MockPS is a mock implementation of the PS interface for testing purposes.
type MockPS struct {
mock.Mock
PSDiskDeps
}
// MockPSDisk is a mock implementation of the PSDiskDeps interface for testing purposes.
type MockPSDisk struct {
*SystemPS
*mock.Mock
}
// MockDiskUsage is a mock implementation for disk usage operations.
type MockDiskUsage struct {
*mock.Mock
}
// CPUTimes returns the CPU times statistics.
func (m *MockPS) CPUTimes(_, _ bool) ([]cpu.TimesStat, error) {
ret := m.Called()
r0 := ret.Get(0).([]cpu.TimesStat)
r1 := ret.Error(1)
return r0, r1
}
// DiskUsage returns the disk usage statistics.
func (m *MockPS) DiskUsage(mountPointFilter, mountOptsExclude, fstypeExclude []string) ([]*disk.UsageStat, []*disk.PartitionStat, error) {
ret := m.Called(mountPointFilter, mountOptsExclude, fstypeExclude)
r0 := ret.Get(0).([]*disk.UsageStat)
r1 := ret.Get(1).([]*disk.PartitionStat)
r2 := ret.Error(2)
return r0, r1, r2
}
// NetIO returns network I/O statistics for every network interface installed on the system.
func (m *MockPS) NetIO() ([]net.IOCountersStat, error) {
ret := m.Called()
r0 := ret.Get(0).([]net.IOCountersStat)
r1 := ret.Error(1)
return r0, r1
}
// NetProto returns network statistics for the entire system.
func (m *MockPS) NetProto() ([]net.ProtoCountersStat, error) {
ret := m.Called()
r0 := ret.Get(0).([]net.ProtoCountersStat)
r1 := ret.Error(1)
return r0, r1
}
// DiskIO returns the disk I/O statistics.
func (m *MockPS) DiskIO(_ []string) (map[string]disk.IOCountersStat, error) {
ret := m.Called()
r0 := ret.Get(0).(map[string]disk.IOCountersStat)
r1 := ret.Error(1)
return r0, r1
}
// VMStat returns the virtual memory statistics.
func (m *MockPS) VMStat() (*mem.VirtualMemoryStat, error) {
ret := m.Called()
r0 := ret.Get(0).(*mem.VirtualMemoryStat)
r1 := ret.Error(1)
return r0, r1
}
// SwapStat returns the swap memory statistics.
func (m *MockPS) SwapStat() (*mem.SwapMemoryStat, error) {
ret := m.Called()
r0 := ret.Get(0).(*mem.SwapMemoryStat)
r1 := ret.Error(1)
return r0, r1
}
// NetConnections returns a list of network connections opened.
func (m *MockPS) NetConnections() ([]net.ConnectionStat, error) {
ret := m.Called()
r0 := ret.Get(0).([]net.ConnectionStat)
r1 := ret.Error(1)
return r0, r1
}
// NetConntrack returns more detailed info about the conntrack table.
func (m *MockPS) NetConntrack(perCPU bool) ([]net.ConntrackStat, error) {
ret := m.Called(perCPU)
r0 := ret.Get(0).([]net.ConntrackStat)
r1 := ret.Error(1)
return r0, r1
}
// Partitions returns the disk partition statistics.
func (m *MockDiskUsage) Partitions(all bool) ([]disk.PartitionStat, error) {
ret := m.Called(all)
r0 := ret.Get(0).([]disk.PartitionStat)
r1 := ret.Error(1)
return r0, r1
}
// OSGetenv returns the value of the environment variable named by the key.
func (m *MockDiskUsage) OSGetenv(key string) string {
ret := m.Called(key)
return ret.Get(0).(string)
}
// OSStat returns the FileInfo structure describing the named file.
func (m *MockDiskUsage) OSStat(name string) (os.FileInfo, error) {
ret := m.Called(name)
r0 := ret.Get(0).(os.FileInfo)
r1 := ret.Error(1)
return r0, r1
}
// PSDiskUsage returns a file system usage for the specified path.
func (m *MockDiskUsage) PSDiskUsage(path string) (*disk.UsageStat, error) {
ret := m.Called(path)
r0 := ret.Get(0).(*disk.UsageStat)
r1 := ret.Error(1)
return r0, r1
}

256
plugins/common/psutil/ps.go Normal file
View file

@ -0,0 +1,256 @@
package psutil
import (
"errors"
"os"
"path/filepath"
"strings"
"github.com/shirou/gopsutil/v4/cpu"
"github.com/shirou/gopsutil/v4/disk"
"github.com/shirou/gopsutil/v4/mem"
"github.com/shirou/gopsutil/v4/net"
"github.com/influxdata/telegraf"
"github.com/influxdata/telegraf/internal"
)
// PS is an interface that defines methods for gathering system statistics.
type PS interface {
// CPUTimes returns the CPU times statistics.
CPUTimes(perCPU, totalCPU bool) ([]cpu.TimesStat, error)
// DiskUsage returns the disk usage statistics.
DiskUsage(mountPointFilter []string, mountOptsExclude []string, fstypeExclude []string) ([]*disk.UsageStat, []*disk.PartitionStat, error)
// NetIO returns network I/O statistics for every network interface installed on the system.
NetIO() ([]net.IOCountersStat, error)
// NetProto returns network statistics for the entire system.
NetProto() ([]net.ProtoCountersStat, error)
// DiskIO returns the disk I/O statistics.
DiskIO(names []string) (map[string]disk.IOCountersStat, error)
// VMStat returns the virtual memory statistics.
VMStat() (*mem.VirtualMemoryStat, error)
// SwapStat returns the swap memory statistics.
SwapStat() (*mem.SwapMemoryStat, error)
// NetConnections returns a list of network connections opened.
NetConnections() ([]net.ConnectionStat, error)
// NetConntrack returns more detailed info about the conntrack table.
NetConntrack(perCPU bool) ([]net.ConntrackStat, error)
}
// PSDiskDeps is an interface that defines methods for gathering disk statistics.
type PSDiskDeps interface {
// Partitions returns the disk partition statistics.
Partitions(all bool) ([]disk.PartitionStat, error)
// OSGetenv returns the value of the environment variable named by the key.
OSGetenv(key string) string
// OSStat returns the FileInfo structure describing the named file.
OSStat(name string) (os.FileInfo, error)
// PSDiskUsage returns a file system usage for the specified path.
PSDiskUsage(path string) (*disk.UsageStat, error)
}
// SystemPS is a struct that implements the PS interface.
type SystemPS struct {
PSDiskDeps
Log telegraf.Logger `toml:"-"`
}
// SystemPSDisk is a struct that implements the PSDiskDeps interface.
type SystemPSDisk struct{}
// NewSystemPS creates a new instance of SystemPS.
func NewSystemPS() *SystemPS {
return &SystemPS{PSDiskDeps: &SystemPSDisk{}}
}
// CPUTimes returns the CPU times statistics.
func (*SystemPS) CPUTimes(perCPU, totalCPU bool) ([]cpu.TimesStat, error) {
var cpuTimes []cpu.TimesStat
if perCPU {
perCPUTimes, err := cpu.Times(true)
if err != nil {
return nil, err
}
cpuTimes = append(cpuTimes, perCPUTimes...)
}
if totalCPU {
totalCPUTimes, err := cpu.Times(false)
if err != nil {
return nil, err
}
cpuTimes = append(cpuTimes, totalCPUTimes...)
}
return cpuTimes, nil
}
// DiskUsage returns the disk usage statistics.
func (s *SystemPS) DiskUsage(mountPointFilter, mountOptsExclude, fstypeExclude []string) ([]*disk.UsageStat, []*disk.PartitionStat, error) {
parts, err := s.Partitions(true)
if err != nil {
return nil, nil, err
}
mountPointFilterSet := newSet()
for _, filter := range mountPointFilter {
mountPointFilterSet.add(filter)
}
mountOptFilterSet := newSet()
for _, filter := range mountOptsExclude {
mountOptFilterSet.add(filter)
}
fstypeExcludeSet := newSet()
for _, filter := range fstypeExclude {
fstypeExcludeSet.add(filter)
}
paths := newSet()
for _, part := range parts {
paths.add(part.Mountpoint)
}
// Autofs mounts indicate a potential mount, the partition will also be
// listed with the actual filesystem when mounted. Ignore the autofs
// partition to avoid triggering a mount.
fstypeExcludeSet.add("autofs")
var usage []*disk.UsageStat
var partitions []*disk.PartitionStat
hostMountPrefix := s.OSGetenv("HOST_MOUNT_PREFIX")
partitionRange:
for i := range parts {
p := parts[i]
for _, o := range p.Opts {
if !mountOptFilterSet.empty() && mountOptFilterSet.has(o) {
continue partitionRange
}
}
// If there is a filter set and if the mount point is not a
// member of the filter set, don't gather info on it.
if !mountPointFilterSet.empty() && !mountPointFilterSet.has(p.Mountpoint) {
continue
}
// If the mount point is a member of the exclude set,
// don't gather info on it.
if fstypeExcludeSet.has(p.Fstype) {
continue
}
// If there's a host mount prefix use it as newer gopsutil version check for
// the init's mountpoints usually pointing to the host-mountpoint but in the
// container. This won't work for checking the disk-usage as the disks are
// mounted at HOST_MOUNT_PREFIX...
mountpoint := p.Mountpoint
if hostMountPrefix != "" && !strings.HasPrefix(p.Mountpoint, hostMountPrefix) {
mountpoint = filepath.Join(hostMountPrefix, p.Mountpoint)
// Exclude conflicting paths
if paths.has(mountpoint) {
if s.Log != nil {
s.Log.Debugf("[SystemPS] => dropped by mount prefix (%q): %q", mountpoint, hostMountPrefix)
}
continue
}
}
du, err := s.PSDiskUsage(mountpoint)
if err != nil {
if s.Log != nil {
s.Log.Debugf("[SystemPS] => unable to get disk usage (%q): %v", mountpoint, err)
}
continue
}
du.Path = filepath.Join(string(os.PathSeparator), strings.TrimPrefix(p.Mountpoint, hostMountPrefix))
du.Fstype = p.Fstype
usage = append(usage, du)
partitions = append(partitions, &p)
}
return usage, partitions, nil
}
// NetProto returns network statistics for the entire system.
func (*SystemPS) NetProto() ([]net.ProtoCountersStat, error) {
return net.ProtoCounters(nil)
}
// NetIO returns network I/O statistics for every network interface installed on the system.
func (*SystemPS) NetIO() ([]net.IOCountersStat, error) {
return net.IOCounters(true)
}
// NetConnections returns a list of network connections opened.
func (*SystemPS) NetConnections() ([]net.ConnectionStat, error) {
return net.Connections("all")
}
// NetConntrack returns more detailed info about the conntrack table.
func (*SystemPS) NetConntrack(perCPU bool) ([]net.ConntrackStat, error) {
return net.ConntrackStats(perCPU)
}
// DiskIO returns the disk I/O statistics.
func (*SystemPS) DiskIO(names []string) (map[string]disk.IOCountersStat, error) {
m, err := disk.IOCounters(names...)
if errors.Is(err, internal.ErrNotImplemented) {
return nil, nil
}
return m, err
}
// VMStat returns the virtual memory statistics.
func (*SystemPS) VMStat() (*mem.VirtualMemoryStat, error) {
return mem.VirtualMemory()
}
// SwapStat returns the swap memory statistics.
func (*SystemPS) SwapStat() (*mem.SwapMemoryStat, error) {
return mem.SwapMemory()
}
// Partitions returns the disk partition statistics.
func (*SystemPSDisk) Partitions(all bool) ([]disk.PartitionStat, error) {
return disk.Partitions(all)
}
// OSGetenv returns the value of the environment variable named by the key.
func (*SystemPSDisk) OSGetenv(key string) string {
return os.Getenv(key)
}
// OSStat returns the FileInfo structure describing the named file.
func (*SystemPSDisk) OSStat(name string) (os.FileInfo, error) {
return os.Stat(name)
}
// PSDiskUsage returns a file system usage for the specified path.
func (*SystemPSDisk) PSDiskUsage(path string) (*disk.UsageStat, error) {
return disk.Usage(path)
}
type set struct {
m map[string]struct{}
}
func (s *set) empty() bool {
return len(s.m) == 0
}
func (s *set) add(key string) {
s.m[key] = struct{}{}
}
func (s *set) has(key string) bool {
var ok bool
_, ok = s.m[key]
return ok
}
func newSet() *set {
s := &set{
m: make(map[string]struct{}),
}
return s
}

View file

@ -0,0 +1,23 @@
package ratelimiter
import (
"errors"
"time"
"github.com/influxdata/telegraf/config"
)
type RateLimitConfig struct {
Limit config.Size `toml:"rate_limit"`
Period config.Duration `toml:"rate_limit_period"`
}
func (cfg *RateLimitConfig) CreateRateLimiter() (*RateLimiter, error) {
if cfg.Limit > 0 && cfg.Period <= 0 {
return nil, errors.New("invalid period for rate-limit")
}
return &RateLimiter{
limit: int64(cfg.Limit),
period: time.Duration(cfg.Period),
}, nil
}

View file

@ -0,0 +1,66 @@
package ratelimiter
import (
"errors"
"math"
"time"
)
var (
ErrLimitExceeded = errors.New("not enough tokens")
)
type RateLimiter struct {
limit int64
period time.Duration
periodStart time.Time
remaining int64
}
func (r *RateLimiter) Remaining(t time.Time) int64 {
if r.limit == 0 {
return math.MaxInt64
}
// Check for corner case
if !r.periodStart.Before(t) {
return 0
}
// We are in a new period, so the complete size is available
deltat := t.Sub(r.periodStart)
if deltat >= r.period {
return r.limit
}
return r.remaining
}
func (r *RateLimiter) Accept(t time.Time, used int64) {
if r.limit == 0 || r.periodStart.After(t) {
return
}
// Remember the first query and reset if we are in a new period
if r.periodStart.IsZero() {
r.periodStart = t
r.remaining = r.limit
} else if deltat := t.Sub(r.periodStart); deltat >= r.period {
r.periodStart = r.periodStart.Add(deltat.Truncate(r.period))
r.remaining = r.limit
}
// Update the state
r.remaining = max(r.remaining-used, 0)
}
func (r *RateLimiter) Undo(t time.Time, used int64) {
// Do nothing if we are not in the current period or unlimited because we
// already reset the limit on a new window.
if r.limit == 0 || r.periodStart.IsZero() || r.periodStart.After(t) || t.Sub(r.periodStart) >= r.period {
return
}
// Undo the state update
r.remaining = min(r.remaining+used, r.limit)
}

View file

@ -0,0 +1,189 @@
package ratelimiter
import (
"math"
"testing"
"time"
"github.com/stretchr/testify/require"
"github.com/influxdata/telegraf/config"
)
func TestInvalidPeriod(t *testing.T) {
cfg := &RateLimitConfig{Limit: config.Size(1024)}
_, err := cfg.CreateRateLimiter()
require.ErrorContains(t, err, "invalid period for rate-limit")
}
func TestUnlimited(t *testing.T) {
cfg := &RateLimitConfig{}
limiter, err := cfg.CreateRateLimiter()
require.NoError(t, err)
start := time.Now()
end := start.Add(30 * time.Minute)
for ts := start; ts.Before(end); ts = ts.Add(1 * time.Minute) {
require.EqualValues(t, int64(math.MaxInt64), limiter.Remaining(ts))
}
}
func TestUnlimitedWithPeriod(t *testing.T) {
cfg := &RateLimitConfig{
Period: config.Duration(5 * time.Minute),
}
limiter, err := cfg.CreateRateLimiter()
require.NoError(t, err)
start := time.Now()
end := start.Add(30 * time.Minute)
for ts := start; ts.Before(end); ts = ts.Add(1 * time.Minute) {
require.EqualValues(t, int64(math.MaxInt64), limiter.Remaining(ts))
}
}
func TestLimited(t *testing.T) {
tests := []struct {
name string
cfg *RateLimitConfig
step time.Duration
request []int64
expected []int64
}{
{
name: "constant usage",
cfg: &RateLimitConfig{
Limit: config.Size(1024),
Period: config.Duration(5 * time.Minute),
},
step: time.Minute,
request: []int64{300},
expected: []int64{1024, 724, 424, 124, 0, 1024, 724, 424, 124, 0},
},
{
name: "variable usage",
cfg: &RateLimitConfig{
Limit: config.Size(1024),
Period: config.Duration(5 * time.Minute),
},
step: time.Minute,
request: []int64{256, 128, 512, 64, 64, 1024, 0, 0, 0, 0, 128, 4096, 4096, 4096, 4096, 4096},
expected: []int64{1024, 768, 640, 128, 64, 1024, 0, 0, 0, 0, 1024, 896, 0, 0, 0, 1024},
},
}
// Run the test with an offset of period multiples
for _, tt := range tests {
t.Run(tt.name+" at period", func(t *testing.T) {
// Setup the limiter
limiter, err := tt.cfg.CreateRateLimiter()
require.NoError(t, err)
// Compute the actual values
start := time.Now().Truncate(tt.step)
for i, expected := range tt.expected {
ts := start.Add(time.Duration(i) * tt.step)
remaining := limiter.Remaining(ts)
use := min(remaining, tt.request[i%len(tt.request)])
require.Equalf(t, expected, remaining, "mismatch at index %d", i)
limiter.Accept(ts, use)
}
})
}
// Run the test at a time of period multiples
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Setup the limiter
limiter, err := tt.cfg.CreateRateLimiter()
require.NoError(t, err)
// Compute the actual values
start := time.Now().Truncate(tt.step).Add(1 * time.Second)
for i, expected := range tt.expected {
ts := start.Add(time.Duration(i) * tt.step)
remaining := limiter.Remaining(ts)
use := min(remaining, tt.request[i%len(tt.request)])
require.Equalf(t, expected, remaining, "mismatch at index %d", i)
limiter.Accept(ts, use)
}
})
}
}
func TestUndo(t *testing.T) {
tests := []struct {
name string
cfg *RateLimitConfig
step time.Duration
request []int64
expected []int64
}{
{
name: "constant usage",
cfg: &RateLimitConfig{
Limit: config.Size(1024),
Period: config.Duration(5 * time.Minute),
},
step: time.Minute,
request: []int64{300},
expected: []int64{1024, 724, 424, 124, 124, 1024, 724, 424, 124, 124},
},
{
name: "variable usage",
cfg: &RateLimitConfig{
Limit: config.Size(1024),
Period: config.Duration(5 * time.Minute),
},
step: time.Minute,
request: []int64{256, 128, 512, 64, 64, 1024, 0, 0, 0, 0, 128, 4096, 4096, 4096, 4096, 4096},
expected: []int64{1024, 768, 640, 128, 64, 1024, 0, 0, 0, 0, 1024, 896, 896, 896, 896, 1024},
},
}
// Run the test with an offset of period multiples
for _, tt := range tests {
t.Run(tt.name+" at period", func(t *testing.T) {
// Setup the limiter
limiter, err := tt.cfg.CreateRateLimiter()
require.NoError(t, err)
// Compute the actual values
start := time.Now().Truncate(tt.step)
for i, expected := range tt.expected {
ts := start.Add(time.Duration(i) * tt.step)
remaining := limiter.Remaining(ts)
use := min(remaining, tt.request[i%len(tt.request)])
require.Equalf(t, expected, remaining, "mismatch at index %d", i)
limiter.Accept(ts, use)
// Undo too large operations
if tt.request[i%len(tt.request)] > remaining {
limiter.Undo(ts, use)
}
}
})
}
// Run the test at a time of period multiples
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Setup the limiter
limiter, err := tt.cfg.CreateRateLimiter()
require.NoError(t, err)
// Compute the actual values
start := time.Now().Truncate(tt.step).Add(1 * time.Second)
for i, expected := range tt.expected {
ts := start.Add(time.Duration(i) * tt.step)
remaining := limiter.Remaining(ts)
use := min(remaining, tt.request[i%len(tt.request)])
require.Equalf(t, expected, remaining, "mismatch at index %d", i)
limiter.Accept(ts, use)
// Undo too large operations
if tt.request[i%len(tt.request)] > remaining {
limiter.Undo(ts, use)
}
}
})
}
}

View file

@ -0,0 +1,100 @@
package ratelimiter
import (
"bytes"
"math"
"github.com/influxdata/telegraf"
"github.com/influxdata/telegraf/internal"
)
// Serializer interface abstracting the different implementations of a
// limited-size serializer
type Serializer interface {
Serialize(metric telegraf.Metric, limit int64) ([]byte, error)
SerializeBatch(metrics []telegraf.Metric, limit int64) ([]byte, error)
}
// Individual serializers do serialize each metric individually using the
// serializer's Serialize() function and add the resulting output to the buffer
// until the limit is reached. This only works for serializers NOT requiring
// the serialization of a batch as-a-whole.
type IndividualSerializer struct {
serializer telegraf.Serializer
buffer *bytes.Buffer
}
func NewIndividualSerializer(s telegraf.Serializer) *IndividualSerializer {
return &IndividualSerializer{
serializer: s,
buffer: &bytes.Buffer{},
}
}
func (s *IndividualSerializer) Serialize(metric telegraf.Metric, limit int64) ([]byte, error) {
// Do the serialization
buf, err := s.serializer.Serialize(metric)
if err != nil {
return nil, err
}
// The serialized metric fits into the limit, so output it
if buflen := int64(len(buf)); buflen <= limit {
return buf, nil
}
// The serialized metric exceeds the limit
return nil, internal.ErrSizeLimitReached
}
func (s *IndividualSerializer) SerializeBatch(metrics []telegraf.Metric, limit int64) ([]byte, error) {
// Grow the buffer so it can hold at least the required size. This will
// save us from reallocate often
s.buffer.Reset()
if limit > 0 && limit > int64(s.buffer.Cap()) && limit < int64(math.MaxInt) {
s.buffer.Grow(int(limit))
}
// Prepare a potential write error and be optimistic
werr := &internal.PartialWriteError{
MetricsAccept: make([]int, 0, len(metrics)),
}
// Iterate through the metrics, serialize them and add them to the output
// buffer if they are within the size limit.
var used int64
for i, m := range metrics {
buf, err := s.serializer.Serialize(m)
if err != nil {
// Failing serialization is a fatal error so mark the metric as such
werr.Err = internal.ErrSerialization
werr.MetricsReject = append(werr.MetricsReject, i)
werr.MetricsRejectErrors = append(werr.MetricsRejectErrors, err)
continue
}
// The serialized metric fits into the limit, so add it to the output
if usedAdded := used + int64(len(buf)); usedAdded <= limit {
if _, err := s.buffer.Write(buf); err != nil {
return nil, err
}
werr.MetricsAccept = append(werr.MetricsAccept, i)
used = usedAdded
continue
}
// Return only the size-limit-reached error if all metrics failed.
if used == 0 {
return nil, internal.ErrSizeLimitReached
}
// Adding the serialized metric would exceed the limit so exit with an
// WriteError and fill in the required information
werr.Err = internal.ErrSizeLimitReached
break
}
if werr.Err != nil {
return s.buffer.Bytes(), werr
}
return s.buffer.Bytes(), nil
}

View file

@ -0,0 +1,352 @@
package ratelimiter
import (
"math"
"testing"
"time"
"github.com/stretchr/testify/require"
"github.com/influxdata/telegraf"
"github.com/influxdata/telegraf/internal"
"github.com/influxdata/telegraf/metric"
"github.com/influxdata/telegraf/plugins/serializers/influx"
)
func TestIndividualSerializer(t *testing.T) {
input := []telegraf.Metric{
metric.New(
"serializer_test",
map[string]string{
"source": "localhost",
"location": "factory_north",
"machine": "A",
"status": "ok",
},
map[string]interface{}{
"operating_hours": 123,
"temperature": 25.0,
"pressure": 1023.4,
},
time.Unix(1722443551, 0),
),
metric.New(
"serializer_test",
map[string]string{
"source": "localhost",
"location": "factory_north",
"machine": "B",
"status": "failed",
},
map[string]interface{}{
"operating_hours": 8430,
"temperature": 65.2,
"pressure": 985.9,
},
time.Unix(1722443554, 0),
),
metric.New(
"serializer_test",
map[string]string{
"source": "localhost",
"location": "factory_north",
"machine": "C",
"status": "warning",
},
map[string]interface{}{
"operating_hours": 6765,
"temperature": 42.5,
"pressure": 986.1,
},
time.Unix(1722443555, 0),
),
metric.New(
"device",
map[string]string{
"source": "localhost",
"location": "factory_north",
},
map[string]interface{}{
"status": "ok",
},
time.Unix(1722443556, 0),
),
metric.New(
"serializer_test",
map[string]string{
"source": "gateway_af43e",
"location": "factory_south",
"machine": "A",
"status": "ok",
},
map[string]interface{}{
"operating_hours": 5544,
"temperature": 18.6,
"pressure": 1069.4,
},
time.Unix(1722443552, 0),
),
metric.New(
"serializer_test",
map[string]string{
"source": "gateway_af43e",
"location": "factory_south",
"machine": "B",
"status": "ok",
},
map[string]interface{}{
"operating_hours": 65,
"temperature": 29.7,
"pressure": 1101.2,
},
time.Unix(1722443553, 0),
),
metric.New(
"device",
map[string]string{
"source": "gateway_af43e",
"location": "factory_south",
},
map[string]interface{}{
"status": "ok",
},
time.Unix(1722443559, 0),
),
metric.New(
"serializer_test",
map[string]string{
"source": "gateway_af43e",
"location": "factory_south",
"machine": "C",
"status": "off",
},
map[string]interface{}{
"operating_hours": 0,
"temperature": 0.0,
"pressure": 0.0,
},
time.Unix(1722443562, 0),
),
}
//nolint:lll // Resulting metrics should not be wrapped for readability
expected := []string{
"serializer_test,location=factory_north,machine=A,source=localhost,status=ok operating_hours=123i,pressure=1023.4,temperature=25 1722443551000000000\n" +
"serializer_test,location=factory_north,machine=B,source=localhost,status=failed operating_hours=8430i,pressure=985.9,temperature=65.2 1722443554000000000\n",
"serializer_test,location=factory_north,machine=C,source=localhost,status=warning operating_hours=6765i,pressure=986.1,temperature=42.5 1722443555000000000\n" +
"device,location=factory_north,source=localhost status=\"ok\" 1722443556000000000\n" +
"serializer_test,location=factory_south,machine=A,source=gateway_af43e,status=ok operating_hours=5544i,pressure=1069.4,temperature=18.6 1722443552000000000\n",
"serializer_test,location=factory_south,machine=B,source=gateway_af43e,status=ok operating_hours=65i,pressure=1101.2,temperature=29.7 1722443553000000000\n" +
"device,location=factory_south,source=gateway_af43e status=\"ok\" 1722443559000000000\n" +
"serializer_test,location=factory_south,machine=C,source=gateway_af43e,status=off operating_hours=0i,pressure=0,temperature=0 1722443562000000000\n",
}
// Setup the limited serializer
s := &influx.Serializer{SortFields: true}
require.NoError(t, s.Init())
serializer := NewIndividualSerializer(s)
var werr *internal.PartialWriteError
// Do the first serialization runs with all metrics
buf, err := serializer.SerializeBatch(input, 400)
require.ErrorAs(t, err, &werr)
require.ErrorIs(t, werr.Err, internal.ErrSizeLimitReached)
require.EqualValues(t, []int{0, 1}, werr.MetricsAccept)
require.Empty(t, werr.MetricsReject)
require.Equal(t, expected[0], string(buf))
// Run again with the successful metrics removed
buf, err = serializer.SerializeBatch(input[2:], 400)
require.ErrorAs(t, err, &werr)
require.ErrorIs(t, werr.Err, internal.ErrSizeLimitReached)
require.EqualValues(t, []int{0, 1, 2}, werr.MetricsAccept)
require.Empty(t, werr.MetricsReject)
require.Equal(t, expected[1], string(buf))
// Final run with the successful metrics removed
buf, err = serializer.SerializeBatch(input[5:], 400)
require.NoError(t, err)
require.Equal(t, expected[2], string(buf))
}
func TestIndividualSerializerFirstTooBig(t *testing.T) {
input := []telegraf.Metric{
metric.New(
"serializer_test",
map[string]string{
"source": "localhost",
"location": "factory_north",
"machine": "A",
"status": "ok",
},
map[string]interface{}{
"operating_hours": 123,
"temperature": 25.0,
"pressure": 1023.4,
},
time.Unix(1722443551, 0),
),
metric.New(
"serializer_test",
map[string]string{
"source": "localhost",
"location": "factory_north",
"machine": "B",
"status": "failed",
},
map[string]interface{}{
"operating_hours": 8430,
"temperature": 65.2,
"pressure": 985.9,
},
time.Unix(1722443554, 0),
),
}
// Setup the limited serializer
s := &influx.Serializer{SortFields: true}
require.NoError(t, s.Init())
serializer := NewIndividualSerializer(s)
// The first metric will already exceed the size so all metrics fail and
// we expect a shortcut error.
buf, err := serializer.SerializeBatch(input, 100)
require.ErrorIs(t, err, internal.ErrSizeLimitReached)
require.Empty(t, buf)
}
func TestIndividualSerializerUnlimited(t *testing.T) {
input := []telegraf.Metric{
metric.New(
"serializer_test",
map[string]string{
"source": "localhost",
"location": "factory_north",
"machine": "A",
"status": "ok",
},
map[string]interface{}{
"operating_hours": 123,
"temperature": 25.0,
"pressure": 1023.4,
},
time.Unix(1722443551, 0),
),
metric.New(
"serializer_test",
map[string]string{
"source": "localhost",
"location": "factory_north",
"machine": "B",
"status": "failed",
},
map[string]interface{}{
"operating_hours": 8430,
"temperature": 65.2,
"pressure": 985.9,
},
time.Unix(1722443554, 0),
),
metric.New(
"serializer_test",
map[string]string{
"source": "localhost",
"location": "factory_north",
"machine": "C",
"status": "warning",
},
map[string]interface{}{
"operating_hours": 6765,
"temperature": 42.5,
"pressure": 986.1,
},
time.Unix(1722443555, 0),
),
metric.New(
"device",
map[string]string{
"source": "localhost",
"location": "factory_north",
},
map[string]interface{}{
"status": "ok",
},
time.Unix(1722443556, 0),
),
metric.New(
"serializer_test",
map[string]string{
"source": "gateway_af43e",
"location": "factory_south",
"machine": "A",
"status": "ok",
},
map[string]interface{}{
"operating_hours": 5544,
"temperature": 18.6,
"pressure": 1069.4,
},
time.Unix(1722443552, 0),
),
metric.New(
"serializer_test",
map[string]string{
"source": "gateway_af43e",
"location": "factory_south",
"machine": "B",
"status": "ok",
},
map[string]interface{}{
"operating_hours": 65,
"temperature": 29.7,
"pressure": 1101.2,
},
time.Unix(1722443553, 0),
),
metric.New(
"device",
map[string]string{
"source": "gateway_af43e",
"location": "factory_south",
},
map[string]interface{}{
"status": "ok",
},
time.Unix(1722443559, 0),
),
metric.New(
"serializer_test",
map[string]string{
"source": "gateway_af43e",
"location": "factory_south",
"machine": "C",
"status": "off",
},
map[string]interface{}{
"operating_hours": 0,
"temperature": 0.0,
"pressure": 0.0,
},
time.Unix(1722443562, 0),
),
}
//nolint:lll // Resulting metrics should not be wrapped for readability
expected := "serializer_test,location=factory_north,machine=A,source=localhost,status=ok operating_hours=123i,pressure=1023.4,temperature=25 1722443551000000000\n" +
"serializer_test,location=factory_north,machine=B,source=localhost,status=failed operating_hours=8430i,pressure=985.9,temperature=65.2 1722443554000000000\n" +
"serializer_test,location=factory_north,machine=C,source=localhost,status=warning operating_hours=6765i,pressure=986.1,temperature=42.5 1722443555000000000\n" +
"device,location=factory_north,source=localhost status=\"ok\" 1722443556000000000\n" +
"serializer_test,location=factory_south,machine=A,source=gateway_af43e,status=ok operating_hours=5544i,pressure=1069.4,temperature=18.6 1722443552000000000\n" +
"serializer_test,location=factory_south,machine=B,source=gateway_af43e,status=ok operating_hours=65i,pressure=1101.2,temperature=29.7 1722443553000000000\n" +
"device,location=factory_south,source=gateway_af43e status=\"ok\" 1722443559000000000\n" +
"serializer_test,location=factory_south,machine=C,source=gateway_af43e,status=off operating_hours=0i,pressure=0,temperature=0 1722443562000000000\n"
// Setup the limited serializer
s := &influx.Serializer{SortFields: true}
require.NoError(t, s.Init())
serializer := NewIndividualSerializer(s)
// Do the first serialization runs with all metrics
buf, err := serializer.SerializeBatch(input, math.MaxInt64)
require.NoError(t, err)
require.Equal(t, expected, string(buf))
}

View file

@ -0,0 +1,64 @@
# Telegraf Execd Go Shim
The goal of this _shim_ is to make it trivial to extract an internal input,
processor, or output plugin from the main Telegraf repo out to a stand-alone
repo. This allows anyone to build and run it as a separate app using one of the
execd plugins:
- [inputs.execd](/plugins/inputs/execd)
- [processors.execd](/plugins/processors/execd)
- [outputs.execd](/plugins/outputs/execd)
## Steps to externalize a plugin
1. Move the project to an external repo, it's recommended to preserve the path
structure, (but not strictly necessary). eg if your plugin was at
`plugins/inputs/cpu`, it's recommended that it also be under `plugins/inputs/cpu`
in the new repo. For a further example of what this might look like, take a
look at [ssoroka/rand](https://github.com/ssoroka/rand) or
[danielnelson/telegraf-plugins](https://github.com/danielnelson/telegraf-plugins)
1. Copy [main.go](./example/cmd/main.go) into your project under the `cmd` folder.
This will be the entrypoint to the plugin when run as a stand-alone program, and
it will call the shim code for you to make that happen. It's recommended to
have only one plugin per repo, as the shim is not designed to run multiple
plugins at the same time (it would vastly complicate things).
1. Edit the main.go file to import your plugin. Within Telegraf this would have
been done in an all.go file, but here we don't split the two apart, and the change
just goes in the top of main.go. If you skip this step, your plugin will do nothing.
eg: `_ "github.com/me/my-plugin-telegraf/plugins/inputs/cpu"`
1. Optionally add a [plugin.conf](./example/cmd/plugin.conf) for configuration
specific to your plugin. Note that this config file **must be separate from the
rest of the config for Telegraf, and must not be in a shared directory where
Telegraf is expecting to load all configs**. If Telegraf reads this config file
it will not know which plugin it relates to. Telegraf instead uses an execd config
block to look for this plugin.
## Steps to build and run your plugin
1. Build the cmd/main.go. For my rand project this looks like `go build -o rand cmd/main.go`
1. If you're building an input, you can test out the binary just by running it.
eg `./rand -config plugin.conf`
Depending on your polling settings and whether you implemented a service plugin or
an input gathering plugin, you may see data right away, or you may have to hit enter
first, or wait for your poll duration to elapse, but the metrics will be written to
STDOUT. Ctrl-C to end your test.
If you're testing a processor or output manually, you can still do this but you
will need to feed valid metrics in on STDIN to verify that it is doing what you
want. This can be a very valuable debugging technique before hooking it up to
Telegraf.
1. Configure Telegraf to call your new plugin binary. For an input, this would
look something like:
```toml
[[inputs.execd]]
command = ["/path/to/rand", "-config", "/path/to/plugin.conf"]
signal = "none"
```
Refer to the execd plugin readmes for more information.
## Congratulations
You've done it! Consider publishing your plugin to github and open a Pull Request
back to the Telegraf repo letting us know about the availability of your
[external plugin](https://github.com/influxdata/telegraf/blob/master/EXTERNAL_PLUGINS.md).

View file

@ -0,0 +1,170 @@
package shim
import (
"errors"
"fmt"
"log" //nolint:depguard // Allow exceptional but valid use of log here.
"os"
"github.com/BurntSushi/toml"
"github.com/influxdata/telegraf"
"github.com/influxdata/telegraf/plugins/inputs"
"github.com/influxdata/telegraf/plugins/outputs"
"github.com/influxdata/telegraf/plugins/processors"
)
type config struct {
Inputs map[string][]toml.Primitive
Processors map[string][]toml.Primitive
Outputs map[string][]toml.Primitive
}
type loadedConfig struct {
Input telegraf.Input
Processor telegraf.StreamingProcessor
Output telegraf.Output
}
// LoadConfig Adds plugins to the shim
func (s *Shim) LoadConfig(filePath *string) error {
conf, err := LoadConfig(filePath)
if err != nil {
return err
}
if conf.Input != nil {
if err = s.AddInput(conf.Input); err != nil {
return fmt.Errorf("failed to add Input: %w", err)
}
} else if conf.Processor != nil {
if err = s.AddStreamingProcessor(conf.Processor); err != nil {
return fmt.Errorf("failed to add Processor: %w", err)
}
} else if conf.Output != nil {
if err = s.AddOutput(conf.Output); err != nil {
return fmt.Errorf("failed to add Output: %w", err)
}
}
return nil
}
// LoadConfig loads the config and returns inputs that later need to be loaded.
func LoadConfig(filePath *string) (loaded loadedConfig, err error) {
var data string
conf := config{}
if filePath != nil && *filePath != "" {
b, err := os.ReadFile(*filePath)
if err != nil {
return loadedConfig{}, err
}
data = expandEnvVars(b)
} else {
conf = DefaultImportedPlugins()
}
md, err := toml.Decode(data, &conf)
if err != nil {
return loadedConfig{}, err
}
return createPluginsWithTomlConfig(md, conf)
}
func expandEnvVars(contents []byte) string {
return os.Expand(string(contents), getEnv)
}
func getEnv(key string) string {
v := os.Getenv(key)
return envVarEscaper.Replace(v)
}
func createPluginsWithTomlConfig(md toml.MetaData, conf config) (loadedConfig, error) {
loadedConf := loadedConfig{}
for name, primitives := range conf.Inputs {
creator, ok := inputs.Inputs[name]
if !ok {
return loadedConf, errors.New("unknown input " + name)
}
plugin := creator()
if len(primitives) > 0 {
primitive := primitives[0]
if err := md.PrimitiveDecode(primitive, plugin); err != nil {
return loadedConf, err
}
}
loadedConf.Input = plugin
break
}
for name, primitives := range conf.Processors {
creator, ok := processors.Processors[name]
if !ok {
return loadedConf, errors.New("unknown processor " + name)
}
plugin := creator()
if len(primitives) > 0 {
primitive := primitives[0]
var p telegraf.PluginDescriber = plugin
if processor, ok := plugin.(processors.HasUnwrap); ok {
p = processor.Unwrap()
}
if err := md.PrimitiveDecode(primitive, p); err != nil {
return loadedConf, err
}
}
loadedConf.Processor = plugin
break
}
for name, primitives := range conf.Outputs {
creator, ok := outputs.Outputs[name]
if !ok {
return loadedConf, errors.New("unknown output " + name)
}
plugin := creator()
if len(primitives) > 0 {
primitive := primitives[0]
if err := md.PrimitiveDecode(primitive, plugin); err != nil {
return loadedConf, err
}
}
loadedConf.Output = plugin
break
}
return loadedConf, nil
}
// DefaultImportedPlugins defaults to whatever plugins happen to be loaded and
// have registered themselves with the registry. This makes loading plugins
// without having to define a config dead easy.
func DefaultImportedPlugins() config {
conf := config{
Inputs: make(map[string][]toml.Primitive, len(inputs.Inputs)),
Processors: make(map[string][]toml.Primitive, len(processors.Processors)),
Outputs: make(map[string][]toml.Primitive, len(outputs.Outputs)),
}
for name := range inputs.Inputs {
log.Println("No config found. Loading default config for plugin", name)
conf.Inputs[name] = make([]toml.Primitive, 0)
return conf
}
for name := range processors.Processors {
log.Println("No config found. Loading default config for plugin", name)
conf.Processors[name] = make([]toml.Primitive, 0)
return conf
}
for name := range outputs.Outputs {
log.Println("No config found. Loading default config for plugin", name)
conf.Outputs[name] = make([]toml.Primitive, 0)
return conf
}
return conf
}

View file

@ -0,0 +1,87 @@
package shim
import (
"testing"
"time"
"github.com/stretchr/testify/require"
"github.com/influxdata/telegraf"
cfg "github.com/influxdata/telegraf/config"
"github.com/influxdata/telegraf/plugins/inputs"
"github.com/influxdata/telegraf/plugins/processors"
)
func TestLoadConfig(t *testing.T) {
t.Setenv("SECRET_TOKEN", "xxxxxxxxxx")
t.Setenv("SECRET_VALUE", `test"\test`)
inputs.Add("test", func() telegraf.Input {
return &serviceInput{}
})
c := "./testdata/plugin.conf"
conf, err := LoadConfig(&c)
require.NoError(t, err)
inp := conf.Input.(*serviceInput)
require.Equal(t, "awesome name", inp.ServiceName)
require.Equal(t, "xxxxxxxxxx", inp.SecretToken)
require.Equal(t, `test"\test`, inp.SecretValue)
}
func TestLoadingSpecialTypes(t *testing.T) {
inputs.Add("test", func() telegraf.Input {
return &testDurationInput{}
})
c := "./testdata/special.conf"
conf, err := LoadConfig(&c)
require.NoError(t, err)
inp := conf.Input.(*testDurationInput)
require.EqualValues(t, 3*time.Second, inp.Duration)
require.EqualValues(t, 3*1000*1000, inp.Size)
require.EqualValues(t, 52, inp.Hex)
}
func TestLoadingProcessorWithConfig(t *testing.T) {
proc := &testConfigProcessor{}
processors.Add("test_config_load", func() telegraf.Processor {
return proc
})
c := "./testdata/processor.conf"
_, err := LoadConfig(&c)
require.NoError(t, err)
require.EqualValues(t, "yep", proc.Loaded)
}
type testDurationInput struct {
Duration cfg.Duration `toml:"duration"`
Size cfg.Size `toml:"size"`
Hex int64 `toml:"hex"`
}
func (*testDurationInput) SampleConfig() string {
return ""
}
func (*testDurationInput) Gather(telegraf.Accumulator) error {
return nil
}
type testConfigProcessor struct {
Loaded string `toml:"loaded"`
}
func (*testConfigProcessor) SampleConfig() string {
return ""
}
func (*testConfigProcessor) Apply(metrics ...telegraf.Metric) []telegraf.Metric {
return metrics
}

View file

@ -0,0 +1,64 @@
package main
import (
"flag"
"fmt"
"os"
"time"
// TODO: import your plugins
_ "github.com/influxdata/tail" // Example external package for showing where you can import your plugins
"github.com/influxdata/telegraf/plugins/common/shim"
)
var pollInterval = flag.Duration("poll_interval", 1*time.Second, "how often to send metrics")
var pollIntervalDisabled = flag.Bool(
"poll_interval_disabled",
false,
"set to true to disable polling. You want to use this when you are sending metrics on your own schedule",
)
var configFile = flag.String("config", "", "path to the config file for this plugin")
var err error
// This is designed to be simple; Just change the import above, and you're good.
//
// However, if you want to do all your config in code, you can like so:
//
// // initialize your plugin with any settings you want
//
// myInput := &mypluginname.MyPlugin{
// DefaultSettingHere: 3,
// }
//
// shim := shim.New()
//
// shim.AddInput(myInput)
//
// // now the shim.Run() call as below. Note the shim is only intended to run a single plugin.
func main() {
// parse command line options
flag.Parse()
if *pollIntervalDisabled {
*pollInterval = shim.PollIntervalDisabled
}
// create the shim. This is what will run your plugins.
shimLayer := shim.New()
// If no config is specified, all imported plugins are loaded.
// otherwise, follow what the config asks for.
// Check for settings from a config toml file,
// (or just use whatever plugins were imported above)
if err = shimLayer.LoadConfig(configFile); err != nil {
fmt.Fprintf(os.Stderr, "Err loading input: %s\n", err)
os.Exit(1)
}
// run a single plugin until stdin closes, or we receive a termination signal
if err = shimLayer.Run(*pollInterval); err != nil {
fmt.Fprintf(os.Stderr, "Err: %s\n", err)
os.Exit(1)
}
}

View file

@ -0,0 +1,2 @@
[[inputs.my_plugin_name]]
value_name = "value"

View file

@ -0,0 +1,145 @@
package shim
import (
"context"
"errors"
"fmt"
"io"
"os"
"os/signal"
"strings"
"syscall"
"time"
"github.com/influxdata/telegraf"
"github.com/influxdata/telegraf/logger"
"github.com/influxdata/telegraf/plugins/serializers/influx"
)
type empty struct{}
var (
forever = 100 * 365 * 24 * time.Hour
envVarEscaper = strings.NewReplacer(
`"`, `\"`,
`\`, `\\`,
)
)
const (
// PollIntervalDisabled is used to indicate that you want to disable polling,
// as opposed to duration 0 meaning poll constantly.
PollIntervalDisabled = time.Duration(0)
)
// Shim allows you to wrap your inputs and run them as if they were part of Telegraf,
// except built externally.
type Shim struct {
Input telegraf.Input
Processor telegraf.StreamingProcessor
Output telegraf.Output
log telegraf.Logger
// streams
stdin io.Reader
stdout io.Writer
stderr io.Writer
// outgoing metric channel
metricCh chan telegraf.Metric
// input only
gatherPromptCh chan empty
}
// New creates a new shim interface
func New() *Shim {
return &Shim{
metricCh: make(chan telegraf.Metric, 1),
stdin: os.Stdin,
stdout: os.Stdout,
stderr: os.Stderr,
log: logger.New("", "", ""),
}
}
func (*Shim) watchForShutdown(cancel context.CancelFunc) {
quit := make(chan os.Signal, 1)
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
go func() {
<-quit // user-triggered quit
// cancel, but keep looping until the metric channel closes.
cancel()
}()
}
// Run the input plugins..
func (s *Shim) Run(pollInterval time.Duration) error {
if s.Input != nil {
err := s.RunInput(pollInterval)
if err != nil {
return fmt.Errorf("running input failed: %w", err)
}
} else if s.Processor != nil {
err := s.RunProcessor()
if err != nil {
return fmt.Errorf("running processor failed: %w", err)
}
} else if s.Output != nil {
err := s.RunOutput()
if err != nil {
return fmt.Errorf("running output failed: %w", err)
}
} else {
return errors.New("nothing to run")
}
return nil
}
func hasQuit(ctx context.Context) bool {
return ctx.Err() != nil
}
func (s *Shim) writeProcessedMetrics() error {
serializer := &influx.Serializer{}
if err := serializer.Init(); err != nil {
return fmt.Errorf("creating serializer failed: %w", err)
}
for { //nolint:staticcheck // for-select used on purpose
select {
case m, open := <-s.metricCh:
if !open {
return nil
}
b, err := serializer.Serialize(m)
if err != nil {
m.Reject()
return fmt.Errorf("failed to serialize metric: %w", err)
}
// Write this to stdout
_, err = fmt.Fprint(s.stdout, string(b))
if err != nil {
m.Drop()
return fmt.Errorf("failed to write metric: %w", err)
}
m.Accept()
}
}
}
// LogName satisfies the MetricMaker interface
func (*Shim) LogName() string {
return ""
}
// MakeMetric satisfies the MetricMaker interface
func (*Shim) MakeMetric(m telegraf.Metric) telegraf.Metric {
return m // don't need to do anything to it.
}
// Log satisfies the MetricMaker interface
func (s *Shim) Log() telegraf.Logger {
return s.log
}

View file

@ -0,0 +1,78 @@
package shim
import (
"bufio"
"errors"
"io"
"testing"
"time"
"github.com/stretchr/testify/require"
"github.com/influxdata/telegraf"
"github.com/influxdata/telegraf/logger"
)
func TestShimSetsUpLogger(t *testing.T) {
stderrReader, stderrWriter := io.Pipe()
stdinReader, stdinWriter := io.Pipe()
runErroringInputPlugin(t, 40*time.Second, stdinReader, nil, stderrWriter)
_, err := stdinWriter.Write([]byte("\n"))
require.NoError(t, err)
r := bufio.NewReader(stderrReader)
out, err := r.ReadString('\n')
require.NoError(t, err)
require.Contains(t, out, "Error in plugin: intentional")
err = stdinWriter.Close()
require.NoError(t, err)
}
func runErroringInputPlugin(t *testing.T, interval time.Duration, stdin io.Reader, stdout, stderr io.Writer) (processed, exited chan bool) {
processed = make(chan bool, 1)
exited = make(chan bool, 1)
inp := &erroringInput{}
shim := New()
if stdin != nil {
shim.stdin = stdin
}
if stdout != nil {
shim.stdout = stdout
}
if stderr != nil {
shim.stderr = stderr
logger.RedirectLogging(stderr)
}
require.NoError(t, shim.AddInput(inp))
go func(e chan bool) {
if err := shim.Run(interval); err != nil {
t.Error(err)
}
e <- true
}(exited)
return processed, exited
}
type erroringInput struct {
}
func (*erroringInput) SampleConfig() string {
return ""
}
func (*erroringInput) Gather(acc telegraf.Accumulator) error {
acc.AddError(errors.New("intentional"))
return nil
}
func (*erroringInput) Start(telegraf.Accumulator) error {
return nil
}
func (*erroringInput) Stop() {
}

View file

@ -0,0 +1,116 @@
package shim
import (
"bufio"
"context"
"fmt"
"sync"
"time"
"github.com/influxdata/telegraf"
"github.com/influxdata/telegraf/agent"
"github.com/influxdata/telegraf/models"
)
// AddInput adds the input to the shim. Later calls to Run() will run this input.
func (s *Shim) AddInput(input telegraf.Input) error {
models.SetLoggerOnPlugin(input, s.Log())
if p, ok := input.(telegraf.Initializer); ok {
err := p.Init()
if err != nil {
return fmt.Errorf("failed to init input: %w", err)
}
}
s.Input = input
return nil
}
func (s *Shim) RunInput(pollInterval time.Duration) error {
// context is used only to close the stdin reader. everything else cascades
// from that point and closes cleanly when it's done.
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
s.watchForShutdown(cancel)
acc := agent.NewAccumulator(s, s.metricCh)
acc.SetPrecision(time.Nanosecond)
if serviceInput, ok := s.Input.(telegraf.ServiceInput); ok {
if err := serviceInput.Start(acc); err != nil {
return fmt.Errorf("failed to start input: %w", err)
}
}
s.gatherPromptCh = make(chan empty, 1)
go func() {
s.startGathering(ctx, s.Input, acc, pollInterval)
if serviceInput, ok := s.Input.(telegraf.ServiceInput); ok {
serviceInput.Stop()
}
// closing the metric channel gracefully stops writing to stdout
close(s.metricCh)
}()
wg := sync.WaitGroup{}
wg.Add(1)
go func() {
err := s.writeProcessedMetrics()
if err != nil {
s.log.Warn(err.Error())
}
wg.Done()
}()
go func() {
scanner := bufio.NewScanner(s.stdin)
for scanner.Scan() {
// push a non-blocking message to trigger metric collection.
s.pushCollectMetricsRequest()
}
cancel() // cancel gracefully stops gathering
}()
wg.Wait() // wait for writing to stdout to finish
return nil
}
func (s *Shim) startGathering(ctx context.Context, input telegraf.Input, acc telegraf.Accumulator, pollInterval time.Duration) {
if pollInterval == PollIntervalDisabled {
pollInterval = forever
}
t := time.NewTicker(pollInterval)
defer t.Stop()
for {
// give priority to stopping.
if hasQuit(ctx) {
return
}
// see what's up
select {
case <-ctx.Done():
return
case <-s.gatherPromptCh:
if err := input.Gather(acc); err != nil {
fmt.Fprintf(s.stderr, "failed to gather metrics: %s\n", err)
}
case <-t.C:
if err := input.Gather(acc); err != nil {
fmt.Fprintf(s.stderr, "failed to gather metrics: %s\n", err)
}
}
}
}
// pushCollectMetricsRequest pushes a non-blocking (nil) message to the
// gatherPromptCh channel to trigger metric collection.
// The channel is defined with a buffer of 1, so while it's full, subsequent
// requests are discarded.
func (s *Shim) pushCollectMetricsRequest() {
// push a message out to each channel to collect metrics. don't block.
select {
case s.gatherPromptCh <- empty{}:
default:
}
}

View file

@ -0,0 +1,140 @@
package shim
import (
"bufio"
"io"
"strings"
"testing"
"time"
"github.com/stretchr/testify/require"
"github.com/influxdata/telegraf"
)
func TestInputShimTimer(t *testing.T) {
stdoutReader, stdoutWriter := io.Pipe()
stdin, _ := io.Pipe() // hold the stdin pipe open
metricProcessed, _ := runInputPlugin(t, 10*time.Millisecond, stdin, stdoutWriter, nil)
<-metricProcessed
r := bufio.NewReader(stdoutReader)
out, err := r.ReadString('\n')
require.NoError(t, err)
require.Contains(t, out, "\n")
metricLine := strings.Split(out, "\n")[0]
require.Equal(t, "measurement,tag=tag field=1i 1234000005678", metricLine)
}
func TestInputShimStdinSignalingWorks(t *testing.T) {
stdinReader, stdinWriter := io.Pipe()
stdoutReader, stdoutWriter := io.Pipe()
metricProcessed, exited := runInputPlugin(t, 40*time.Second, stdinReader, stdoutWriter, nil)
_, err := stdinWriter.Write([]byte("\n"))
require.NoError(t, err)
<-metricProcessed
r := bufio.NewReader(stdoutReader)
out, err := r.ReadString('\n')
require.NoError(t, err)
require.Equal(t, "measurement,tag=tag field=1i 1234000005678\n", out)
err = stdinWriter.Close()
require.NoError(t, err)
go func() {
if _, err = io.ReadAll(r); err != nil {
t.Error(err)
}
}()
// check that it exits cleanly
<-exited
}
func runInputPlugin(t *testing.T, interval time.Duration, stdin io.Reader, stdout, stderr io.Writer) (processed, exited chan bool) {
processed = make(chan bool, 1)
exited = make(chan bool, 1)
inp := &testInput{
metricProcessed: processed,
}
shim := New()
if stdin != nil {
shim.stdin = stdin
}
if stdout != nil {
shim.stdout = stdout
}
if stderr != nil {
shim.stderr = stderr
}
err := shim.AddInput(inp)
require.NoError(t, err)
go func(e chan bool) {
if err := shim.Run(interval); err != nil {
t.Error(err)
}
e <- true
}(exited)
return processed, exited
}
type testInput struct {
metricProcessed chan bool
}
func (*testInput) SampleConfig() string {
return ""
}
func (i *testInput) Gather(acc telegraf.Accumulator) error {
acc.AddFields("measurement",
map[string]interface{}{
"field": 1,
},
map[string]string{
"tag": "tag",
}, time.Unix(1234, 5678))
i.metricProcessed <- true
return nil
}
func (*testInput) Start(telegraf.Accumulator) error {
return nil
}
func (*testInput) Stop() {
}
type serviceInput struct {
ServiceName string `toml:"service_name"`
SecretToken string `toml:"secret_token"`
SecretValue string `toml:"secret_value"`
}
func (*serviceInput) SampleConfig() string {
return ""
}
func (*serviceInput) Gather(acc telegraf.Accumulator) error {
acc.AddFields("measurement",
map[string]interface{}{
"field": 1,
},
map[string]string{
"tag": "tag",
}, time.Unix(1234, 5678))
return nil
}
func (*serviceInput) Start(telegraf.Accumulator) error {
return nil
}
func (*serviceInput) Stop() {
}

View file

@ -0,0 +1,54 @@
package shim
import (
"bufio"
"fmt"
"github.com/influxdata/telegraf"
"github.com/influxdata/telegraf/models"
"github.com/influxdata/telegraf/plugins/parsers/influx"
)
// AddOutput adds the input to the shim. Later calls to Run() will run this.
func (s *Shim) AddOutput(output telegraf.Output) error {
models.SetLoggerOnPlugin(output, s.Log())
if p, ok := output.(telegraf.Initializer); ok {
err := p.Init()
if err != nil {
return fmt.Errorf("failed to init input: %w", err)
}
}
s.Output = output
return nil
}
func (s *Shim) RunOutput() error {
parser := influx.Parser{}
err := parser.Init()
if err != nil {
return fmt.Errorf("failed to create new parser: %w", err)
}
err = s.Output.Connect()
if err != nil {
return fmt.Errorf("failed to start processor: %w", err)
}
defer s.Output.Close()
var m telegraf.Metric
scanner := bufio.NewScanner(s.stdin)
for scanner.Scan() {
m, err = parser.ParseLine(scanner.Text())
if err != nil {
fmt.Fprintf(s.stderr, "Failed to parse metric: %s\n", err)
continue
}
if err = s.Output.Write([]telegraf.Metric{m}); err != nil {
fmt.Fprintf(s.stderr, "Failed to write metric: %s\n", err)
}
}
return nil
}

View file

@ -0,0 +1,81 @@
package shim
import (
"io"
"sync"
"testing"
"time"
"github.com/stretchr/testify/require"
"github.com/influxdata/telegraf"
"github.com/influxdata/telegraf/metric"
"github.com/influxdata/telegraf/plugins/serializers/influx"
"github.com/influxdata/telegraf/testutil"
)
func TestOutputShim(t *testing.T) {
o := &testOutput{}
stdinReader, stdinWriter := io.Pipe()
s := New()
s.stdin = stdinReader
err := s.AddOutput(o)
require.NoError(t, err)
wg := sync.WaitGroup{}
wg.Add(1)
go func() {
if err := s.RunOutput(); err != nil {
t.Error(err)
}
wg.Done()
}()
serializer := &influx.Serializer{}
require.NoError(t, serializer.Init())
m := metric.New("thing",
map[string]string{
"a": "b",
},
map[string]interface{}{
"v": 1,
},
time.Now(),
)
b, err := serializer.Serialize(m)
require.NoError(t, err)
_, err = stdinWriter.Write(b)
require.NoError(t, err)
err = stdinWriter.Close()
require.NoError(t, err)
wg.Wait()
require.Len(t, o.MetricsWritten, 1)
mOut := o.MetricsWritten[0]
testutil.RequireMetricEqual(t, m, mOut)
}
type testOutput struct {
MetricsWritten []telegraf.Metric
}
func (*testOutput) Connect() error {
return nil
}
func (*testOutput) Close() error {
return nil
}
func (o *testOutput) Write(metrics []telegraf.Metric) error {
o.MetricsWritten = append(o.MetricsWritten, metrics...)
return nil
}
func (*testOutput) SampleConfig() string {
return ""
}

View file

@ -0,0 +1,81 @@
package shim
import (
"errors"
"fmt"
"sync"
"time"
"github.com/influxdata/telegraf"
"github.com/influxdata/telegraf/agent"
"github.com/influxdata/telegraf/models"
"github.com/influxdata/telegraf/plugins/parsers/influx"
"github.com/influxdata/telegraf/plugins/processors"
)
// AddProcessor adds the processor to the shim. Later calls to Run() will run this.
func (s *Shim) AddProcessor(processor telegraf.Processor) error {
models.SetLoggerOnPlugin(processor, s.Log())
p := processors.NewStreamingProcessorFromProcessor(processor)
return s.AddStreamingProcessor(p)
}
// AddStreamingProcessor adds the processor to the shim. Later calls to Run() will run this.
func (s *Shim) AddStreamingProcessor(processor telegraf.StreamingProcessor) error {
models.SetLoggerOnPlugin(processor, s.Log())
if p, ok := processor.(telegraf.Initializer); ok {
err := p.Init()
if err != nil {
return fmt.Errorf("failed to init input: %w", err)
}
}
s.Processor = processor
return nil
}
func (s *Shim) RunProcessor() error {
acc := agent.NewAccumulator(s, s.metricCh)
acc.SetPrecision(time.Nanosecond)
err := s.Processor.Start(acc)
if err != nil {
return fmt.Errorf("failed to start processor: %w", err)
}
wg := sync.WaitGroup{}
wg.Add(1)
go func() {
err := s.writeProcessedMetrics()
if err != nil {
s.log.Warn(err.Error())
}
wg.Done()
}()
parser := influx.NewStreamParser(s.stdin)
for {
m, err := parser.Next()
if err != nil {
if errors.Is(err, influx.EOF) {
break // stream ended
}
var parseErr *influx.ParseError
if errors.As(err, &parseErr) {
fmt.Fprintf(s.stderr, "Failed to parse metric: %s\b", parseErr)
continue
}
fmt.Fprintf(s.stderr, "Failure during reading stdin: %s\b", err)
continue
}
if err = s.Processor.Add(m, acc); err != nil {
fmt.Fprintf(s.stderr, "Failure during processing metric by processor: %v\b", err)
}
}
close(s.metricCh)
s.Processor.Stop()
wg.Wait()
return nil
}

View file

@ -0,0 +1,113 @@
package shim
import (
"bufio"
"io"
"math/rand"
"sync"
"testing"
"time"
"github.com/stretchr/testify/require"
"github.com/influxdata/telegraf"
"github.com/influxdata/telegraf/metric"
"github.com/influxdata/telegraf/plugins/parsers/influx"
serializers_influx "github.com/influxdata/telegraf/plugins/serializers/influx"
)
func TestProcessorShim(t *testing.T) {
testSendAndReceive(t, "f1", "fv1")
}
func TestProcessorShimWithLargerThanDefaultScannerBufferSize(t *testing.T) {
letters := []rune("ABCDEFGHIJKLMNOPQRSTUVWXYZ")
b := make([]rune, 0, bufio.MaxScanTokenSize*2)
for i := 0; i < bufio.MaxScanTokenSize*2; i++ {
b = append(b, letters[rand.Intn(len(letters))])
}
testSendAndReceive(t, "f1", string(b))
}
func testSendAndReceive(t *testing.T, fieldKey, fieldValue string) {
p := &testProcessor{"hi", "mom"}
stdinReader, stdinWriter := io.Pipe()
stdoutReader, stdoutWriter := io.Pipe()
s := New()
// inject test into shim
s.stdin = stdinReader
s.stdout = stdoutWriter
err := s.AddProcessor(p)
require.NoError(t, err)
wg := sync.WaitGroup{}
wg.Add(1)
go func() {
if err := s.RunProcessor(); err != nil {
t.Error(err)
}
wg.Done()
}()
serializer := &serializers_influx.Serializer{}
require.NoError(t, serializer.Init())
parser := influx.Parser{}
require.NoError(t, parser.Init())
m := metric.New("thing",
map[string]string{
"a": "b",
},
map[string]interface{}{
"v": 1,
fieldKey: fieldValue,
},
time.Now(),
)
b, err := serializer.Serialize(m)
require.NoError(t, err)
_, err = stdinWriter.Write(b)
require.NoError(t, err)
err = stdinWriter.Close()
require.NoError(t, err)
r := bufio.NewReader(stdoutReader)
out, err := r.ReadString('\n')
require.NoError(t, err)
mOut, err := parser.ParseLine(out)
require.NoError(t, err)
val, ok := mOut.GetTag(p.tagName)
require.True(t, ok)
require.Equal(t, p.tagValue, val)
val2, ok := mOut.Fields()[fieldKey]
require.True(t, ok)
require.Equal(t, fieldValue, val2)
go func() {
if _, err = io.ReadAll(r); err != nil {
t.Error(err)
}
}()
wg.Wait()
}
type testProcessor struct {
tagName string
tagValue string
}
func (p *testProcessor) Apply(in ...telegraf.Metric) []telegraf.Metric {
for _, m := range in {
m.AddTag(p.tagName, p.tagValue)
}
return in
}
func (*testProcessor) SampleConfig() string {
return ""
}

View file

@ -0,0 +1,4 @@
[[inputs.test]]
service_name = "awesome name"
secret_token = "${SECRET_TOKEN}"
secret_value = "$SECRET_VALUE"

View file

@ -0,0 +1,2 @@
[[processors.test_config_load]]
loaded = "yep"

View file

@ -0,0 +1,5 @@
# testing custom field types
[[inputs.test]]
duration = "3s"
size = "3MB"
hex = 0x34

View file

@ -0,0 +1,267 @@
package socket
import (
"errors"
"fmt"
"io"
"net"
"net/url"
"os"
"path/filepath"
"runtime"
"strconv"
"strings"
"sync"
"time"
"github.com/alitto/pond"
"github.com/influxdata/telegraf"
"github.com/influxdata/telegraf/config"
"github.com/influxdata/telegraf/internal"
)
type packetListener struct {
Encoding string
MaxDecompressionSize int64
SocketMode string
ReadBufferSize int
Log telegraf.Logger
conn net.PacketConn
decoders sync.Pool
path string
wg sync.WaitGroup
parsePool *pond.WorkerPool
}
func newPacketListener(encoding string, maxDecompressionSize config.Size, maxWorkers int) *packetListener {
return &packetListener{
Encoding: encoding,
MaxDecompressionSize: int64(maxDecompressionSize),
parsePool: pond.New(maxWorkers, 0, pond.MinWorkers(maxWorkers/2+1)),
}
}
func (l *packetListener) listenData(onData CallbackData, onError CallbackError) {
l.wg.Add(1)
go func() {
defer l.wg.Done()
buf := make([]byte, l.ReadBufferSize)
for {
n, src, err := l.conn.ReadFrom(buf)
receiveTime := time.Now()
if err != nil {
if !strings.HasSuffix(err.Error(), ": use of closed network connection") {
if onError != nil {
onError(err)
}
}
break
}
d := make([]byte, n)
copy(d, buf[:n])
l.parsePool.Submit(func() {
decoder := l.decoders.Get().(internal.ContentDecoder)
defer l.decoders.Put(decoder)
body, err := decoder.Decode(d)
if err != nil && onError != nil {
onError(fmt.Errorf("unable to decode incoming packet: %w", err))
}
if l.path != "" {
src = &net.UnixAddr{Name: l.path, Net: "unixgram"}
}
onData(src, body, receiveTime)
})
}
}()
}
func (l *packetListener) listenConnection(onConnection CallbackConnection, onError CallbackError) {
l.wg.Add(1)
go func() {
defer l.wg.Done()
defer l.conn.Close()
buf := make([]byte, l.ReadBufferSize)
for {
// Wait for packets and read them
n, src, err := l.conn.ReadFrom(buf)
if err != nil {
if !strings.HasSuffix(err.Error(), ": use of closed network connection") {
if onError != nil {
onError(err)
}
}
break
}
d := make([]byte, n)
copy(d, buf[:n])
l.parsePool.Submit(func() {
// Decode the contents depending on the given encoding
decoder := l.decoders.Get().(internal.ContentDecoder)
// Not possible to immediately return the decoder to the Pool after calling Decode, because some
// decoders return a reference to their internal buffers. This would cause data races.
defer l.decoders.Put(decoder)
body, err := decoder.Decode(d[:n])
if err != nil && onError != nil {
onError(fmt.Errorf("unable to decode incoming packet: %w", err))
}
// Workaround to provide remote endpoints for Unix-type sockets
if l.path != "" {
src = &net.UnixAddr{Name: l.path, Net: "unixgram"}
}
// Create a pipe and notify the caller via Callback that new data is
// available. Afterwards write the data. Please note: Write() will
// block until all data is consumed!
reader, writer := io.Pipe()
go onConnection(src, reader)
if _, err := writer.Write(body); err != nil && onError != nil {
onError(err)
}
writer.Close()
})
}
}()
}
func (l *packetListener) setupUnixgram(u *url.URL, socketMode string, bufferSize int) error {
l.path = filepath.FromSlash(u.Path)
if runtime.GOOS == "windows" && strings.Contains(l.path, ":") {
l.path = strings.TrimPrefix(l.path, `\`)
}
if err := os.Remove(l.path); err != nil && !errors.Is(err, os.ErrNotExist) {
return fmt.Errorf("removing socket failed: %w", err)
}
conn, err := net.ListenPacket(u.Scheme, l.path)
if err != nil {
return fmt.Errorf("listening (unixgram) failed: %w", err)
}
l.conn = conn
// Set permissions on socket
if socketMode != "" {
// Convert from octal in string to int
i, err := strconv.ParseUint(socketMode, 8, 32)
if err != nil {
return fmt.Errorf("converting socket mode failed: %w", err)
}
perm := os.FileMode(uint32(i))
if err := os.Chmod(u.Path, perm); err != nil {
return fmt.Errorf("changing socket permissions failed: %w", err)
}
}
if bufferSize > 0 {
l.ReadBufferSize = bufferSize
} else {
l.ReadBufferSize = 64 * 1024 // 64kb - IP packet size
}
return l.setupDecoder()
}
func (l *packetListener) setupUDP(u *url.URL, ifname string, bufferSize int) error {
var conn *net.UDPConn
addr, err := net.ResolveUDPAddr(u.Scheme, u.Host)
if err != nil {
return fmt.Errorf("resolving UDP address failed: %w", err)
}
if addr.IP.IsMulticast() {
var iface *net.Interface
if ifname != "" {
var err error
iface, err = net.InterfaceByName(ifname)
if err != nil {
return fmt.Errorf("resolving address of %q failed: %w", ifname, err)
}
}
conn, err = net.ListenMulticastUDP(u.Scheme, iface, addr)
if err != nil {
return fmt.Errorf("listening (udp multicast) failed: %w", err)
}
} else {
conn, err = net.ListenUDP(u.Scheme, addr)
if err != nil {
return fmt.Errorf("listening (udp) failed: %w", err)
}
}
if bufferSize > 0 {
if err := conn.SetReadBuffer(bufferSize); err != nil {
l.Log.Warnf("Setting read buffer on %s socket failed: %v", u.Scheme, err)
}
}
l.ReadBufferSize = 64 * 1024 // 64kb - IP packet size
l.conn = conn
return l.setupDecoder()
}
func (l *packetListener) setupIP(u *url.URL) error {
conn, err := net.ListenPacket(u.Scheme, u.Host)
if err != nil {
return fmt.Errorf("listening (ip) failed: %w", err)
}
l.ReadBufferSize = 64 * 1024 // 64kb - IP packet size
l.conn = conn
return l.setupDecoder()
}
func (l *packetListener) setupDecoder() error {
// Create a decoder for the given encoding
var options []internal.DecodingOption
if l.MaxDecompressionSize > 0 {
options = append(options, internal.WithMaxDecompressionSize(l.MaxDecompressionSize))
}
l.decoders = sync.Pool{New: func() any {
decoder, err := internal.NewContentDecoder(l.Encoding, options...)
if err != nil {
l.Log.Errorf("creating decoder failed: %v", err)
return nil
}
return decoder
}}
return nil
}
func (l *packetListener) address() net.Addr {
return l.conn.LocalAddr()
}
func (l *packetListener) close() error {
if err := l.conn.Close(); err != nil {
return err
}
l.wg.Wait()
if l.path != "" {
fn := filepath.FromSlash(l.path)
if runtime.GOOS == "windows" && strings.Contains(fn, ":") {
fn = strings.TrimPrefix(fn, `\`)
}
if err := os.Remove(fn); err != nil && !errors.Is(err, os.ErrNotExist) {
// Ignore file-not-exists errors when removing the socket
return err
}
}
l.parsePool.StopAndWait()
return nil
}

View file

@ -0,0 +1,37 @@
## Permission for unix sockets (only available on unix sockets)
## This setting may not be respected by some platforms. To safely restrict
## permissions it is recommended to place the socket into a previously
## created directory with the desired permissions.
## ex: socket_mode = "777"
# socket_mode = ""
## Maximum number of concurrent connections (only available on stream sockets like TCP)
## Zero means unlimited.
# max_connections = 0
## Read timeout (only available on stream sockets like TCP)
## Zero means unlimited.
# read_timeout = "0s"
## Optional TLS configuration (only available on stream sockets like TCP)
# tls_cert = "/etc/telegraf/cert.pem"
# tls_key = "/etc/telegraf/key.pem"
## Enables client authentication if set.
# tls_allowed_cacerts = ["/etc/telegraf/clientca.pem"]
## Maximum socket buffer size (in bytes when no unit specified)
## For stream sockets, once the buffer fills up, the sender will start
## backing up. For datagram sockets, once the buffer fills up, metrics will
## start dropping. Defaults to the OS default.
# read_buffer_size = "64KiB"
## Period between keep alive probes (only applies to TCP sockets)
## Zero disables keep alive probes. Defaults to the OS configuration.
# keep_alive_period = "5m"
## Content encoding for message payloads
## Can be set to "gzip" for compressed payloads or "identity" for no encoding.
# content_encoding = "identity"
## Maximum size of decoded packet (in bytes when no unit specified)
# max_decompression_size = "500MB"

View file

@ -0,0 +1,181 @@
package socket
import (
"bufio"
"crypto/tls"
"fmt"
"io"
"net"
"net/url"
"regexp"
"strings"
"time"
"github.com/influxdata/telegraf"
"github.com/influxdata/telegraf/config"
common_tls "github.com/influxdata/telegraf/plugins/common/tls"
)
type CallbackData func(net.Addr, []byte, time.Time)
type CallbackConnection func(net.Addr, io.ReadCloser)
type CallbackError func(error)
type listener interface {
address() net.Addr
listenData(CallbackData, CallbackError)
listenConnection(CallbackConnection, CallbackError)
close() error
}
type Config struct {
MaxConnections uint64 `toml:"max_connections"`
ReadBufferSize config.Size `toml:"read_buffer_size"`
ReadTimeout config.Duration `toml:"read_timeout"`
KeepAlivePeriod *config.Duration `toml:"keep_alive_period"`
SocketMode string `toml:"socket_mode"`
ContentEncoding string `toml:"content_encoding"`
MaxDecompressionSize config.Size `toml:"max_decompression_size"`
MaxParallelParsers int `toml:"max_parallel_parsers"`
common_tls.ServerConfig
}
type Socket struct {
Config
url *url.URL
interfaceName string
tlsCfg *tls.Config
log telegraf.Logger
splitter bufio.SplitFunc
listener listener
}
func (cfg *Config) NewSocket(address string, splitcfg *SplitConfig, logger telegraf.Logger) (*Socket, error) {
s := &Socket{
Config: *cfg,
log: logger,
}
// Setup the splitter if given
if splitcfg != nil {
splitter, err := splitcfg.NewSplitter()
if err != nil {
return nil, err
}
s.splitter = splitter
}
// Resolve the interface to an address if any given
ifregex := regexp.MustCompile(`%([\w\.]+)`)
if matches := ifregex.FindStringSubmatch(address); len(matches) == 2 {
s.interfaceName = matches[1]
address = strings.Replace(address, "%"+s.interfaceName, "", 1)
}
// Preparing TLS configuration
tlsCfg, err := s.ServerConfig.TLSConfig()
if err != nil {
return nil, fmt.Errorf("getting TLS config failed: %w", err)
}
s.tlsCfg = tlsCfg
// Parse and check the address
u, err := url.Parse(address)
if err != nil {
return nil, fmt.Errorf("parsing address failed: %w", err)
}
s.url = u
switch s.url.Scheme {
case "tcp", "tcp4", "tcp6", "unix", "unixpacket",
"udp", "udp4", "udp6", "ip", "ip4", "ip6", "unixgram", "vsock":
default:
return nil, fmt.Errorf("unknown protocol %q in %q", u.Scheme, address)
}
return s, nil
}
func (s *Socket) Setup() error {
s.MaxParallelParsers = max(s.MaxParallelParsers, 1)
switch s.url.Scheme {
case "tcp", "tcp4", "tcp6":
l := newStreamListener(
s.Config,
s.splitter,
s.log,
)
if err := l.setupTCP(s.url, s.tlsCfg); err != nil {
return err
}
s.listener = l
case "unix", "unixpacket":
l := newStreamListener(
s.Config,
s.splitter,
s.log,
)
if err := l.setupUnix(s.url, s.tlsCfg, s.SocketMode); err != nil {
return err
}
s.listener = l
case "udp", "udp4", "udp6":
l := newPacketListener(s.ContentEncoding, s.MaxDecompressionSize, s.MaxParallelParsers)
if err := l.setupUDP(s.url, s.interfaceName, int(s.ReadBufferSize)); err != nil {
return err
}
s.listener = l
case "ip", "ip4", "ip6":
l := newPacketListener(s.ContentEncoding, s.MaxDecompressionSize, s.MaxParallelParsers)
if err := l.setupIP(s.url); err != nil {
return err
}
s.listener = l
case "unixgram":
l := newPacketListener(s.ContentEncoding, s.MaxDecompressionSize, s.MaxParallelParsers)
if err := l.setupUnixgram(s.url, s.SocketMode, int(s.ReadBufferSize)); err != nil {
return err
}
s.listener = l
case "vsock":
l := newStreamListener(
s.Config,
s.splitter,
s.log,
)
if err := l.setupVsock(s.url); err != nil {
return err
}
s.listener = l
default:
return fmt.Errorf("unknown protocol %q", s.url.Scheme)
}
return nil
}
func (s *Socket) Listen(onData CallbackData, onError CallbackError) {
s.listener.listenData(onData, onError)
}
func (s *Socket) ListenConnection(onConnection CallbackConnection, onError CallbackError) {
s.listener.listenConnection(onConnection, onError)
}
func (s *Socket) Close() {
if s.listener != nil {
// Ignore the returned error as we cannot do anything about it anyway
if err := s.listener.close(); err != nil {
s.log.Warnf("Closing socket failed: %v", err)
}
s.listener = nil
}
}
func (s *Socket) Address() net.Addr {
return s.listener.address()
}

View file

@ -0,0 +1,845 @@
package socket
import (
"crypto/tls"
"errors"
"fmt"
"io"
"net"
"os"
"runtime"
"strings"
"sync"
"testing"
"time"
"github.com/stretchr/testify/require"
"github.com/influxdata/telegraf"
"github.com/influxdata/telegraf/config"
"github.com/influxdata/telegraf/internal"
"github.com/influxdata/telegraf/metric"
_ "github.com/influxdata/telegraf/plugins/parsers/all"
"github.com/influxdata/telegraf/plugins/parsers/influx"
"github.com/influxdata/telegraf/testutil"
)
var pki = testutil.NewPKI("../../../testutil/pki")
func TestListenData(t *testing.T) {
messages := [][]byte{
[]byte("test,foo=bar v=1i 123456789\ntest,foo=baz v=2i 123456790\n"),
[]byte("test,foo=zab v=3i 123456791\n"),
}
expectedTemplates := []telegraf.Metric{
metric.New(
"test",
map[string]string{"foo": "bar"},
map[string]interface{}{"v": int64(1)},
time.Unix(0, 123456789),
),
metric.New(
"test",
map[string]string{"foo": "baz"},
map[string]interface{}{"v": int64(2)},
time.Unix(0, 123456790),
),
metric.New(
"test",
map[string]string{"foo": "zab"},
map[string]interface{}{"v": int64(3)},
time.Unix(0, 123456791),
),
}
tests := []struct {
name string
schema string
buffersize config.Size
encoding string
}{
{
name: "TCP",
schema: "tcp",
buffersize: config.Size(1024),
},
{
name: "TCP with TLS",
schema: "tcp+tls",
},
{
name: "TCP with gzip encoding",
schema: "tcp",
buffersize: config.Size(1024),
encoding: "gzip",
},
{
name: "UDP",
schema: "udp",
buffersize: config.Size(1024),
},
{
name: "UDP with gzip encoding",
schema: "udp",
buffersize: config.Size(1024),
encoding: "gzip",
},
{
name: "unix socket",
schema: "unix",
buffersize: config.Size(1024),
},
{
name: "unix socket with TLS",
schema: "unix+tls",
},
{
name: "unix socket with gzip encoding",
schema: "unix",
encoding: "gzip",
},
{
name: "unixgram socket",
schema: "unixgram",
buffersize: config.Size(1024),
},
}
serverTLS := pki.TLSServerConfig()
clientTLS := pki.TLSClientConfig()
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
proto := strings.TrimSuffix(tt.schema, "+tls")
// Prepare the address and socket if needed
var sockPath string
var serviceAddress string
var tlsCfg *tls.Config
switch proto {
case "tcp", "udp":
serviceAddress = proto + "://" + "127.0.0.1:0"
case "unix", "unixgram":
if runtime.GOOS == "windows" {
t.Skip("Skipping on Windows, as unixgram sockets are not supported")
}
// Create a socket
sockPath = testutil.TempSocket(t)
f, err := os.Create(sockPath)
require.NoError(t, err)
defer f.Close()
serviceAddress = proto + "://" + sockPath
}
// Setup the configuration according to test specification
cfg := &Config{
ContentEncoding: tt.encoding,
ReadBufferSize: tt.buffersize,
}
if strings.HasSuffix(tt.schema, "tls") {
cfg.ServerConfig = *serverTLS
var err error
tlsCfg, err = clientTLS.TLSConfig()
require.NoError(t, err)
}
// Create the socket
sock, err := cfg.NewSocket(serviceAddress, &SplitConfig{}, &testutil.Logger{})
require.NoError(t, err)
// Create callbacks
parser := &influx.Parser{}
require.NoError(t, parser.Init())
var acc testutil.Accumulator
onData := func(remote net.Addr, data []byte, _ time.Time) {
m, err := parser.Parse(data)
require.NoError(t, err)
addr, _, err := net.SplitHostPort(remote.String())
if err != nil {
addr = remote.String()
}
for i := range m {
m[i].AddTag("source", addr)
}
acc.AddMetrics(m)
}
onError := func(err error) {
acc.AddError(err)
}
// Start the listener
require.NoError(t, sock.Setup())
sock.Listen(onData, onError)
defer sock.Close()
addr := sock.Address()
// Create a noop client
// Server is async, so verify no errors at the end.
client, err := createClient(serviceAddress, addr, tlsCfg)
require.NoError(t, err)
require.NoError(t, client.Close())
// Setup the client for submitting data
client, err = createClient(serviceAddress, addr, tlsCfg)
require.NoError(t, err)
// Conditionally add the source address to the expectation
expected := make([]telegraf.Metric, 0, len(expectedTemplates))
for _, tmpl := range expectedTemplates {
m := tmpl.Copy()
switch proto {
case "tcp", "udp":
laddr := client.LocalAddr().String()
addr, _, err := net.SplitHostPort(laddr)
if err != nil {
addr = laddr
}
m.AddTag("source", addr)
case "unix", "unixgram":
m.AddTag("source", sockPath)
}
expected = append(expected, m)
}
// Send the data with the correct encoding
encoder, err := internal.NewContentEncoder(tt.encoding)
require.NoError(t, err)
for i, msg := range messages {
m, err := encoder.Encode(msg)
require.NoErrorf(t, err, "encoding failed for msg %d", i)
_, err = client.Write(m)
require.NoErrorf(t, err, "sending msg %d failed", i)
}
// Test the resulting metrics and compare against expected results
require.Eventuallyf(t, func() bool {
acc.Lock()
defer acc.Unlock()
return acc.NMetrics() >= uint64(len(expected))
}, time.Second, 100*time.Millisecond, "did not receive metrics (%d)", acc.NMetrics())
actual := acc.GetTelegrafMetrics()
testutil.RequireMetricsEqual(t, expected, actual, testutil.SortMetrics())
})
}
}
func TestListenConnection(t *testing.T) {
messages := [][]byte{
[]byte("test,foo=bar v=1i 123456789\ntest,foo=baz v=2i 123456790\n"),
[]byte("test,foo=zab v=3i 123456791\n"),
}
expectedTemplates := []telegraf.Metric{
metric.New(
"test",
map[string]string{"foo": "bar"},
map[string]interface{}{"v": int64(1)},
time.Unix(0, 123456789),
),
metric.New(
"test",
map[string]string{"foo": "baz"},
map[string]interface{}{"v": int64(2)},
time.Unix(0, 123456790),
),
metric.New(
"test",
map[string]string{"foo": "zab"},
map[string]interface{}{"v": int64(3)},
time.Unix(0, 123456791),
),
}
tests := []struct {
name string
schema string
buffersize config.Size
encoding string
}{
{
name: "TCP",
schema: "tcp",
buffersize: config.Size(1024),
},
{
name: "TCP with TLS",
schema: "tcp+tls",
},
{
name: "TCP with gzip encoding",
schema: "tcp",
buffersize: config.Size(1024),
encoding: "gzip",
},
{
name: "UDP",
schema: "udp",
buffersize: config.Size(1024),
},
{
name: "UDP with gzip encoding",
schema: "udp",
buffersize: config.Size(1024),
encoding: "gzip",
},
{
name: "unix socket",
schema: "unix",
buffersize: config.Size(1024),
},
{
name: "unix socket with TLS",
schema: "unix+tls",
},
{
name: "unix socket with gzip encoding",
schema: "unix",
encoding: "gzip",
},
{
name: "unixgram socket",
schema: "unixgram",
buffersize: config.Size(1024),
},
}
serverTLS := pki.TLSServerConfig()
clientTLS := pki.TLSClientConfig()
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
proto := strings.TrimSuffix(tt.schema, "+tls")
// Prepare the address and socket if needed
var sockPath string
var serviceAddress string
var tlsCfg *tls.Config
switch proto {
case "tcp", "udp":
serviceAddress = proto + "://" + "127.0.0.1:0"
case "unix", "unixgram":
if runtime.GOOS == "windows" {
t.Skip("Skipping on Windows, as unixgram sockets are not supported")
}
// Create a socket
sockPath = testutil.TempSocket(t)
f, err := os.Create(sockPath)
require.NoError(t, err)
defer f.Close()
serviceAddress = proto + "://" + sockPath
}
// Setup the configuration according to test specification
cfg := &Config{
ContentEncoding: tt.encoding,
ReadBufferSize: tt.buffersize,
}
if strings.HasSuffix(tt.schema, "tls") {
cfg.ServerConfig = *serverTLS
var err error
tlsCfg, err = clientTLS.TLSConfig()
require.NoError(t, err)
}
// Create the socket
sock, err := cfg.NewSocket(serviceAddress, &SplitConfig{}, &testutil.Logger{})
require.NoError(t, err)
// Create callbacks
parser := &influx.Parser{}
require.NoError(t, parser.Init())
var acc testutil.Accumulator
onConnection := func(remote net.Addr, reader io.ReadCloser) {
data, err := io.ReadAll(reader)
require.NoError(t, err)
m, err := parser.Parse(data)
require.NoError(t, err)
addr, _, err := net.SplitHostPort(remote.String())
if err != nil {
addr = remote.String()
}
for i := range m {
m[i].AddTag("source", addr)
}
acc.AddMetrics(m)
}
onError := func(err error) {
acc.AddError(err)
}
// Start the listener
require.NoError(t, sock.Setup())
sock.ListenConnection(onConnection, onError)
defer sock.Close()
addr := sock.Address()
// Create a noop client
// Server is async, so verify no errors at the end.
client, err := createClient(serviceAddress, addr, tlsCfg)
require.NoError(t, err)
require.NoError(t, client.Close())
// Setup the client for submitting data
client, err = createClient(serviceAddress, addr, tlsCfg)
require.NoError(t, err)
// Conditionally add the source address to the expectation
expected := make([]telegraf.Metric, 0, len(expectedTemplates))
for _, tmpl := range expectedTemplates {
m := tmpl.Copy()
switch proto {
case "tcp", "udp":
laddr := client.LocalAddr().String()
addr, _, err := net.SplitHostPort(laddr)
if err != nil {
addr = laddr
}
m.AddTag("source", addr)
case "unix", "unixgram":
m.AddTag("source", sockPath)
}
expected = append(expected, m)
}
// Send the data with the correct encoding
encoder, err := internal.NewContentEncoder(tt.encoding)
require.NoError(t, err)
for i, msg := range messages {
m, err := encoder.Encode(msg)
require.NoErrorf(t, err, "encoding failed for msg %d", i)
_, err = client.Write(m)
require.NoErrorf(t, err, "sending msg %d failed", i)
}
client.Close()
// Test the resulting metrics and compare against expected results
require.Eventuallyf(t, func() bool {
acc.Lock()
defer acc.Unlock()
return acc.NMetrics() >= uint64(len(expected))
}, time.Second, 100*time.Millisecond, "did not receive metrics (%d)", acc.NMetrics())
actual := acc.GetTelegrafMetrics()
testutil.RequireMetricsEqual(t, expected, actual, testutil.SortMetrics())
})
}
}
func TestClosingConnections(t *testing.T) {
// Setup the configuration
cfg := &Config{
ReadBufferSize: 1024,
}
// Create the socket
serviceAddress := "tcp://127.0.0.1:0"
logger := &testutil.CaptureLogger{}
sock, err := cfg.NewSocket(serviceAddress, &SplitConfig{}, logger)
require.NoError(t, err)
// Create callbacks
parser := &influx.Parser{}
require.NoError(t, parser.Init())
var acc testutil.Accumulator
onData := func(_ net.Addr, data []byte, _ time.Time) {
m, err := parser.Parse(data)
require.NoError(t, err)
acc.AddMetrics(m)
}
onError := func(err error) {
acc.AddError(err)
}
// Start the listener
require.NoError(t, sock.Setup())
sock.Listen(onData, onError)
defer sock.Close()
addr := sock.Address()
// Create a noop client
client, err := createClient(serviceAddress, addr, nil)
require.NoError(t, err)
_, err = client.Write([]byte("test value=42i\n"))
require.NoError(t, err)
require.Eventually(t, func() bool {
acc.Lock()
defer acc.Unlock()
return acc.NMetrics() >= 1
}, time.Second, 100*time.Millisecond, "did not receive metric")
// This has to be a stream-listener...
listener, ok := sock.listener.(*streamListener)
require.True(t, ok)
listener.Lock()
conns := listener.connections
listener.Unlock()
require.NotZero(t, conns)
sock.Close()
// Verify that plugin.Stop() closed the client's connection
require.NoError(t, client.SetReadDeadline(time.Now().Add(time.Second)))
buf := []byte{1}
_, err = client.Read(buf)
require.Equal(t, err, io.EOF)
require.Empty(t, logger.Errors())
require.Empty(t, logger.Warnings())
}
func TestMaxConnections(t *testing.T) {
if runtime.GOOS == "darwin" {
t.Skip("Skipping on darwin due to missing socket options")
}
// Setup the configuration
period := config.Duration(10 * time.Millisecond)
cfg := &Config{
MaxConnections: 5,
KeepAlivePeriod: &period,
}
// Create the socket
serviceAddress := "tcp://127.0.0.1:0"
sock, err := cfg.NewSocket(serviceAddress, nil, &testutil.Logger{})
require.NoError(t, err)
// Create callback
var errs []error
var mu sync.Mutex
onData := func(_ net.Addr, _ []byte, _ time.Time) {}
onError := func(err error) {
mu.Lock()
errs = append(errs, err)
mu.Unlock()
}
// Start the listener
require.NoError(t, sock.Setup())
sock.Listen(onData, onError)
defer sock.Close()
addr := sock.Address()
// Create maximum number of connections and write some data. All of this
// should succeed...
clients := make([]*net.TCPConn, 0, cfg.MaxConnections)
for i := 0; i < int(cfg.MaxConnections); i++ {
c, err := net.DialTCP("tcp", nil, addr.(*net.TCPAddr))
require.NoError(t, err)
require.NoError(t, c.SetWriteBuffer(0))
require.NoError(t, c.SetNoDelay(true))
clients = append(clients, c)
_, err = c.Write([]byte("test value=42i\n"))
require.NoError(t, err)
}
func() {
mu.Lock()
defer mu.Unlock()
require.Empty(t, errs)
}()
// Create another client. This should fail because we already reached the
// connection limit and the connection should be closed...
client, err := net.DialTCP("tcp", nil, addr.(*net.TCPAddr))
require.NoError(t, err)
require.NoError(t, client.SetWriteBuffer(0))
require.NoError(t, client.SetNoDelay(true))
require.Eventually(t, func() bool {
mu.Lock()
defer mu.Unlock()
return len(errs) > 0
}, 3*time.Second, 100*time.Millisecond)
func() {
mu.Lock()
defer mu.Unlock()
require.Len(t, errs, 1)
require.ErrorContains(t, errs[0], "too many connections")
errs = make([]error, 0)
}()
require.Eventually(t, func() bool {
_, err := client.Write([]byte("fail\n"))
return err != nil
}, 3*time.Second, 100*time.Millisecond)
_, err = client.Write([]byte("test\n"))
require.Error(t, err)
// Check other connections are still good
for _, c := range clients {
_, err := c.Write([]byte("test\n"))
require.NoError(t, err)
}
func() {
mu.Lock()
defer mu.Unlock()
require.Empty(t, errs)
}()
// Close the first client and check if we can connect now
require.NoError(t, clients[0].Close())
client, err = net.DialTCP("tcp", nil, addr.(*net.TCPAddr))
require.NoError(t, err)
require.NoError(t, client.SetWriteBuffer(0))
require.NoError(t, client.SetNoDelay(true))
_, err = client.Write([]byte("success\n"))
require.NoError(t, err)
// Close all connections
require.NoError(t, client.Close())
for _, c := range clients[1:] {
require.NoError(t, c.Close())
}
// Close the clients and check the connection counter
listener, ok := sock.listener.(*streamListener)
require.True(t, ok)
require.Eventually(t, func() bool {
listener.Lock()
conns := listener.connections
listener.Unlock()
return conns == 0
}, 3*time.Second, 100*time.Millisecond)
// Close the socket and check again...
sock.Close()
listener.Lock()
conns := listener.connections
listener.Unlock()
require.Zero(t, conns)
}
func TestNoSplitter(t *testing.T) {
messages := [][]byte{
[]byte("test,foo=bar v"),
[]byte("=1i 123456789\ntest,foo=baz v=2i 123456790\ntest,foo=zab v=3i 123456791\n"),
}
expectedTemplates := []telegraf.Metric{
metric.New(
"test",
map[string]string{"foo": "bar"},
map[string]interface{}{"v": int64(1)},
time.Unix(0, 123456789),
),
metric.New(
"test",
map[string]string{"foo": "baz"},
map[string]interface{}{"v": int64(2)},
time.Unix(0, 123456790),
),
metric.New(
"test",
map[string]string{"foo": "zab"},
map[string]interface{}{"v": int64(3)},
time.Unix(0, 123456791),
),
}
// Prepare the address and socket if needed
serviceAddress := "tcp://127.0.0.1:0"
// Setup the configuration according to test specification
cfg := &Config{}
// Create the socket
sock, err := cfg.NewSocket(serviceAddress, nil, &testutil.Logger{})
require.NoError(t, err)
// Create callbacks
parser := &influx.Parser{}
require.NoError(t, parser.Init())
var acc testutil.Accumulator
onConnection := func(remote net.Addr, reader io.ReadCloser) {
data, err := io.ReadAll(reader)
require.NoError(t, err)
m, err := parser.Parse(data)
require.NoError(t, err)
addr, _, err := net.SplitHostPort(remote.String())
if err != nil {
addr = remote.String()
}
for i := range m {
m[i].AddTag("source", addr)
}
acc.AddMetrics(m)
}
onError := func(err error) {
acc.AddError(err)
}
// Start the listener
require.NoError(t, sock.Setup())
sock.ListenConnection(onConnection, onError)
defer sock.Close()
addr := sock.Address()
// Create a noop client
// Server is async, so verify no errors at the end.
client, err := createClient(serviceAddress, addr, nil)
require.NoError(t, err)
require.NoError(t, client.Close())
// Setup the client for submitting data
client, err = createClient(serviceAddress, addr, nil)
require.NoError(t, err)
// Conditionally add the source address to the expectation
expected := make([]telegraf.Metric, 0, len(expectedTemplates))
for _, tmpl := range expectedTemplates {
m := tmpl.Copy()
laddr := client.LocalAddr().String()
addr, _, err := net.SplitHostPort(laddr)
if err != nil {
addr = laddr
}
m.AddTag("source", addr)
expected = append(expected, m)
}
// Send the data
for i, msg := range messages {
_, err = client.Write(msg)
time.Sleep(100 * time.Millisecond)
require.NoErrorf(t, err, "sending msg %d failed", i)
}
client.Close()
// Test the resulting metrics and compare against expected results
require.Eventuallyf(t, func() bool {
acc.Lock()
defer acc.Unlock()
return acc.NMetrics() >= uint64(len(expected))
}, time.Second, 100*time.Millisecond, "did not receive metrics (%d)", acc.NMetrics())
actual := acc.GetTelegrafMetrics()
testutil.RequireMetricsEqual(t, expected, actual, testutil.SortMetrics())
}
func TestTLSMemLeak(t *testing.T) {
// For issue https://github.com/influxdata/telegraf/issues/15509
// Prepare the address and socket if needed
serviceAddress := "tcp://127.0.0.1:0"
// Setup a TLS socket to trigger the issue
cfg := &Config{
ServerConfig: *pki.TLSServerConfig(),
}
// Create the socket
sock, err := cfg.NewSocket(serviceAddress, nil, &testutil.Logger{})
require.NoError(t, err)
// Create callbacks
onConnection := func(_ net.Addr, reader io.ReadCloser) {
//nolint:errcheck // We are not interested in the data so ignore all errors
io.Copy(io.Discard, reader)
}
// Start the listener
require.NoError(t, sock.Setup())
sock.ListenConnection(onConnection, nil)
defer sock.Close()
addr := sock.Address()
// Setup the client side TLS
tlsCfg, err := pki.TLSClientConfig().TLSConfig()
require.NoError(t, err)
// Define a single client write sequence
data := []byte("test value=42i")
write := func() error {
conn, err := tls.Dial("tcp", addr.String(), tlsCfg)
if err != nil {
return err
}
defer conn.Close()
_, err = conn.Write(data)
return err
}
// Define a test with the given number of connections
maxConcurrency := runtime.GOMAXPROCS(0)
testCycle := func(connections int) (uint64, error) {
var mu sync.Mutex
var errs []error
var wg sync.WaitGroup
for count := 1; count < connections; count++ {
wg.Add(1)
go func() {
defer wg.Done()
if err := write(); err != nil {
mu.Lock()
errs = append(errs, err)
mu.Unlock()
}
}()
if count%maxConcurrency == 0 {
wg.Wait()
mu.Lock()
if len(errs) > 0 {
mu.Unlock()
return 0, errors.Join(errs...)
}
mu.Unlock()
}
}
//nolint:revive // We need to actively run the garbage collector to get reliable measurements
runtime.GC()
var stats runtime.MemStats
runtime.ReadMemStats(&stats)
return stats.HeapObjects, nil
}
// Measure the memory usage after a short warmup and after some time.
// The final number of heap objects should not exceed the number of
// runs by a save margin
// Warmup, do a low number of runs to initialize all data structures
// taking them out of the equation.
initial, err := testCycle(100)
require.NoError(t, err)
// Do some more runs and make sure the memory growth is bound
final, err := testCycle(2000)
require.NoError(t, err)
require.Less(t, final, 3*initial)
}
func createClient(endpoint string, addr net.Addr, tlsCfg *tls.Config) (net.Conn, error) {
// Determine the protocol in a crude fashion
parts := strings.SplitN(endpoint, "://", 2)
if len(parts) != 2 {
return nil, fmt.Errorf("invalid endpoint %q", endpoint)
}
protocol := parts[0]
if tlsCfg == nil {
return net.Dial(protocol, addr.String())
}
if protocol == "unix" {
tlsCfg.InsecureSkipVerify = true
}
return tls.Dial(protocol, addr.String(), tlsCfg)
}

View file

@ -0,0 +1,37 @@
## Message splitting strategy and corresponding settings for stream sockets
## (tcp, tcp4, tcp6, unix or unixpacket). The setting is ignored for packet
## listeners such as udp.
## Available strategies are:
## newline -- split at newlines (default)
## null -- split at null bytes
## delimiter -- split at delimiter byte-sequence in hex-format
## given in `splitting_delimiter`
## fixed length -- split after number of bytes given in `splitting_length`
## variable length -- split depending on length information received in the
## data. The length field information is specified in
## `splitting_length_field`.
# splitting_strategy = "newline"
## Delimiter used to split received data to messages consumed by the parser.
## The delimiter is a hex byte-sequence marking the end of a message
## e.g. "0x0D0A", "x0d0a" or "0d0a" marks a Windows line-break (CR LF).
## The value is case-insensitive and can be specified with "0x" or "x" prefix
## or without.
## Note: This setting is only used for splitting_strategy = "delimiter".
# splitting_delimiter = ""
## Fixed length of a message in bytes.
## Note: This setting is only used for splitting_strategy = "fixed length".
# splitting_length = 0
## Specification of the length field contained in the data to split messages
## with variable length. The specification contains the following fields:
## offset -- start of length field in bytes from begin of data
## bytes -- length of length field in bytes
## endianness -- endianness of the value, either "be" for big endian or
## "le" for little endian
## header_length -- total length of header to be skipped when passing
## data on to the parser. If zero (default), the header
## is passed on to the parser together with the message.
## Note: This setting is only used for splitting_strategy = "variable length".
# splitting_length_field = {offset = 0, bytes = 0, endianness = "be", header_length = 0}

View file

@ -0,0 +1,167 @@
package socket
import (
"bufio"
"bytes"
"encoding/binary"
"encoding/hex"
"fmt"
"regexp"
"strings"
)
type lengthFieldSpec struct {
Offset int64 `toml:"offset"`
Bytes int64 `toml:"bytes"`
Endianness string `toml:"endianness"`
HeaderLength int64 `toml:"header_length"`
converter func([]byte) int
}
type SplitConfig struct {
SplittingStrategy string `toml:"splitting_strategy"`
SplittingDelimiter string `toml:"splitting_delimiter"`
SplittingLength int `toml:"splitting_length"`
SplittingLengthField lengthFieldSpec `toml:"splitting_length_field"`
}
func (cfg *SplitConfig) NewSplitter() (bufio.SplitFunc, error) {
switch cfg.SplittingStrategy {
case "", "newline":
return bufio.ScanLines, nil
case "null":
return scanNull, nil
case "delimiter":
re := regexp.MustCompile(`(\s*0?x)`)
d := re.ReplaceAllString(strings.ToLower(cfg.SplittingDelimiter), "")
delimiter, err := hex.DecodeString(d)
if err != nil {
return nil, fmt.Errorf("decoding delimiter failed: %w", err)
}
return createScanDelimiter(delimiter), nil
case "fixed length":
return createScanFixedLength(cfg.SplittingLength), nil
case "variable length":
// Create the converter function
var order binary.ByteOrder
switch strings.ToLower(cfg.SplittingLengthField.Endianness) {
case "", "be":
order = binary.BigEndian
case "le":
order = binary.LittleEndian
default:
return nil, fmt.Errorf("invalid 'endianness' %q", cfg.SplittingLengthField.Endianness)
}
switch cfg.SplittingLengthField.Bytes {
case 1:
cfg.SplittingLengthField.converter = func(b []byte) int {
return int(b[0])
}
case 2:
cfg.SplittingLengthField.converter = func(b []byte) int {
return int(order.Uint16(b))
}
case 4:
cfg.SplittingLengthField.converter = func(b []byte) int {
return int(order.Uint32(b))
}
case 8:
cfg.SplittingLengthField.converter = func(b []byte) int {
return int(order.Uint64(b))
}
default:
cfg.SplittingLengthField.converter = func(b []byte) int {
buf := make([]byte, 8)
start := 0
if order == binary.BigEndian {
start = 8 - len(b)
}
for i := 0; i < len(b); i++ {
buf[start+i] = b[i]
}
return int(order.Uint64(buf))
}
}
// Check if we have enough bytes in the header
return createScanVariableLength(cfg.SplittingLengthField), nil
}
return nil, fmt.Errorf("unknown 'splitting_strategy' %q", cfg.SplittingStrategy)
}
func scanNull(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 {
return 0, nil, nil
}
if i := bytes.IndexByte(data, 0); i >= 0 {
return i + 1, data[:i], nil
}
if atEOF {
return len(data), data, nil
}
// Request more data.
return 0, nil, nil
}
func createScanDelimiter(delimiter []byte) bufio.SplitFunc {
return func(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 {
return 0, nil, nil
}
if i := bytes.Index(data, delimiter); i >= 0 {
return i + len(delimiter), data[:i], nil
}
if atEOF {
return len(data), data, nil
}
// Request more data.
return 0, nil, nil
}
}
func createScanFixedLength(length int) bufio.SplitFunc {
return func(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 {
return 0, nil, nil
}
if len(data) >= length {
return length, data[:length], nil
}
if atEOF {
return len(data), data, nil
}
// Request more data.
return 0, nil, nil
}
}
func createScanVariableLength(spec lengthFieldSpec) bufio.SplitFunc {
minlen := int(spec.Offset)
minlen += int(spec.Bytes)
headerLen := int(spec.HeaderLength)
return func(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 {
return 0, nil, nil
}
dataLen := len(data)
if dataLen >= minlen {
// Extract the length field and convert it to a number
lf := data[spec.Offset : spec.Offset+spec.Bytes]
length := spec.converter(lf)
start := headerLen
end := length + headerLen
// If we have enough data return it without the header
if end <= dataLen {
return end, data[start:end], nil
}
}
if atEOF {
return len(data), data, nil
}
// Request more data.
return 0, nil, nil
}
}

View file

@ -0,0 +1,476 @@
package socket
import (
"bufio"
"context"
"crypto/tls"
"errors"
"fmt"
"io"
"math"
"net"
"net/url"
"os"
"path/filepath"
"runtime"
"strconv"
"strings"
"sync"
"syscall"
"time"
"github.com/alitto/pond"
"github.com/mdlayher/vsock"
"github.com/influxdata/telegraf"
"github.com/influxdata/telegraf/config"
"github.com/influxdata/telegraf/internal"
)
type hasSetReadBuffer interface {
SetReadBuffer(bytes int) error
}
type streamListener struct {
Encoding string
ReadBufferSize int
MaxConnections uint64
ReadTimeout config.Duration
KeepAlivePeriod *config.Duration
Splitter bufio.SplitFunc
Log telegraf.Logger
listener net.Listener
connections uint64
path string
cancel context.CancelFunc
parsePool *pond.WorkerPool
wg sync.WaitGroup
sync.Mutex
}
func newStreamListener(conf Config, splitter bufio.SplitFunc, log telegraf.Logger) *streamListener {
return &streamListener{
ReadBufferSize: int(conf.ReadBufferSize),
ReadTimeout: conf.ReadTimeout,
KeepAlivePeriod: conf.KeepAlivePeriod,
MaxConnections: conf.MaxConnections,
Encoding: conf.ContentEncoding,
Splitter: splitter,
Log: log,
parsePool: pond.New(
conf.MaxParallelParsers,
0,
pond.MinWorkers(conf.MaxParallelParsers/2+1)),
}
}
func (l *streamListener) setupTCP(u *url.URL, tlsCfg *tls.Config) error {
var err error
if tlsCfg == nil {
l.listener, err = net.Listen(u.Scheme, u.Host)
} else {
l.listener, err = tls.Listen(u.Scheme, u.Host, tlsCfg)
}
return err
}
func (l *streamListener) setupUnix(u *url.URL, tlsCfg *tls.Config, socketMode string) error {
l.path = filepath.FromSlash(u.Path)
if runtime.GOOS == "windows" && strings.Contains(l.path, ":") {
l.path = strings.TrimPrefix(l.path, `\`)
}
if err := os.Remove(l.path); err != nil && !errors.Is(err, os.ErrNotExist) {
return fmt.Errorf("removing socket failed: %w", err)
}
var err error
if tlsCfg == nil {
l.listener, err = net.Listen(u.Scheme, l.path)
} else {
l.listener, err = tls.Listen(u.Scheme, l.path, tlsCfg)
}
if err != nil {
return err
}
// Set permissions on socket
if socketMode != "" {
// Convert from octal in string to int
i, err := strconv.ParseUint(socketMode, 8, 32)
if err != nil {
return fmt.Errorf("converting socket mode failed: %w", err)
}
perm := os.FileMode(uint32(i))
if err := os.Chmod(u.Path, perm); err != nil {
return fmt.Errorf("changing socket permissions failed: %w", err)
}
}
return nil
}
func (l *streamListener) setupVsock(u *url.URL) error {
var err error
addrTuple := strings.SplitN(u.String(), ":", 2)
// Check address string for containing two tokens
if len(addrTuple) < 2 {
return errors.New("port and/or CID number missing")
}
// Parse CID and port number from address string both being 32-bit
// source: https://man7.org/linux/man-pages/man7/vsock.7.html
cid, err := strconv.ParseUint(addrTuple[0], 10, 32)
if err != nil {
return fmt.Errorf("failed to parse CID %s: %w", addrTuple[0], err)
}
if (cid >= uint64(math.Pow(2, 32))-1) && (cid <= 0) {
return fmt.Errorf("value of CID %d is out of range", cid)
}
port, err := strconv.ParseUint(addrTuple[1], 10, 32)
if err != nil {
return fmt.Errorf("failed to parse port number %s: %w", addrTuple[1], err)
}
if (port >= uint64(math.Pow(2, 32))-1) && (port <= 0) {
return fmt.Errorf("port number %d is out of range", port)
}
l.listener, err = vsock.Listen(uint32(port), nil)
return err
}
func (l *streamListener) setupConnection(conn net.Conn) error {
addr := conn.RemoteAddr().String()
l.Lock()
if l.MaxConnections > 0 && l.connections >= l.MaxConnections {
l.Unlock()
// Ignore the returned error as we cannot do anything about it anyway
_ = conn.Close()
return fmt.Errorf("unable to accept connection from %q: too many connections", addr)
}
l.connections++
l.Unlock()
if l.ReadBufferSize > 0 {
if rb, ok := conn.(hasSetReadBuffer); ok {
if err := rb.SetReadBuffer(l.ReadBufferSize); err != nil {
l.Log.Warnf("Setting read buffer on socket failed: %v", err)
}
} else {
l.Log.Warn("Cannot set read buffer on socket of this type")
}
}
// Set keep alive handlings
if l.KeepAlivePeriod != nil {
if c, ok := conn.(*tls.Conn); ok {
conn = c.NetConn()
}
tcpConn, ok := conn.(*net.TCPConn)
if !ok {
l.Log.Warnf("connection not a TCP connection (%T)", conn)
}
if *l.KeepAlivePeriod == 0 {
if err := tcpConn.SetKeepAlive(false); err != nil {
l.Log.Warnf("Cannot set keep-alive: %v", err)
}
} else {
if err := tcpConn.SetKeepAlive(true); err != nil {
l.Log.Warnf("Cannot set keep-alive: %v", err)
}
err := tcpConn.SetKeepAlivePeriod(time.Duration(*l.KeepAlivePeriod))
if err != nil {
l.Log.Warnf("Cannot set keep-alive period: %v", err)
}
}
}
return nil
}
func (l *streamListener) closeConnection(conn net.Conn) {
// Fallback to enforce blocked reads on connections to end immediately
//nolint:errcheck // Ignore errors as this is a fallback only
conn.SetReadDeadline(time.Now())
addr := conn.RemoteAddr().String()
if err := conn.Close(); err != nil && !errors.Is(err, net.ErrClosed) && !errors.Is(err, syscall.EPIPE) {
l.Log.Warnf("Cannot close connection to %q: %v", addr, err)
} else {
l.Lock()
l.connections--
l.Unlock()
}
}
func (l *streamListener) address() net.Addr {
return l.listener.Addr()
}
func (l *streamListener) close() error {
if l.listener != nil {
// Continue even if we cannot close the listener in order to at least
// close all active connections
if err := l.listener.Close(); err != nil {
l.Log.Errorf("Cannot close listener: %v", err)
}
}
if l.cancel != nil {
l.cancel()
l.cancel = nil
}
l.wg.Wait()
if l.path != "" {
fn := filepath.FromSlash(l.path)
if runtime.GOOS == "windows" && strings.Contains(fn, ":") {
fn = strings.TrimPrefix(fn, `\`)
}
// Ignore file-not-exists errors when removing the socket
if err := os.Remove(fn); err != nil && !errors.Is(err, os.ErrNotExist) {
return err
}
}
l.parsePool.StopAndWait()
return nil
}
func (l *streamListener) listenData(onData CallbackData, onError CallbackError) {
ctx, cancel := context.WithCancel(context.Background())
l.cancel = cancel
l.wg.Add(1)
go func() {
defer l.wg.Done()
for {
conn, err := l.listener.Accept()
if err != nil {
if !errors.Is(err, net.ErrClosed) && onError != nil {
onError(err)
}
break
}
if err := l.setupConnection(conn); err != nil && onError != nil {
onError(err)
continue
}
l.wg.Add(1)
go l.handleReaderConn(ctx, conn, onData, onError)
}
}()
}
func (l *streamListener) handleReaderConn(ctx context.Context, conn net.Conn, onData CallbackData, onError CallbackError) {
defer l.wg.Done()
localCtx, cancel := context.WithCancel(ctx)
defer cancel()
defer l.closeConnection(conn)
stopFunc := context.AfterFunc(localCtx, func() { l.closeConnection(conn) })
defer stopFunc()
reader := l.read
if l.Splitter == nil {
reader = l.readAll
}
if err := reader(conn, onData); err != nil {
if !errors.Is(err, io.EOF) && !errors.Is(err, syscall.ECONNRESET) {
if onError != nil {
onError(err)
}
}
}
}
func (l *streamListener) listenConnection(onConnection CallbackConnection, onError CallbackError) {
ctx, cancel := context.WithCancel(context.Background())
l.cancel = cancel
l.wg.Add(1)
go func() {
defer l.wg.Done()
for {
conn, err := l.listener.Accept()
if err != nil {
if !errors.Is(err, net.ErrClosed) && onError != nil {
onError(err)
}
break
}
if err := l.setupConnection(conn); err != nil && onError != nil {
onError(err)
continue
}
l.wg.Add(1)
go func(c net.Conn) {
if err := l.handleConnection(ctx, c, onConnection); err != nil {
if !errors.Is(err, io.EOF) && !errors.Is(err, syscall.ECONNRESET) {
if onError != nil {
onError(err)
}
}
}
}(conn)
}
}()
}
func (l *streamListener) read(conn net.Conn, onData CallbackData) error {
decoder, err := internal.NewStreamContentDecoder(l.Encoding, conn)
if err != nil {
return fmt.Errorf("creating decoder failed: %w", err)
}
timeout := time.Duration(l.ReadTimeout)
scanner := bufio.NewScanner(decoder)
if l.ReadBufferSize > bufio.MaxScanTokenSize {
scanner.Buffer(make([]byte, l.ReadBufferSize), l.ReadBufferSize)
}
scanner.Split(l.Splitter)
for {
// Set the read deadline, if any, then start reading. The read
// will accept the deadline and return if no or insufficient data
// arrived in time. We need to set the deadline in every cycle as
// it is an ABSOLUTE time and not a timeout.
if timeout > 0 {
deadline := time.Now().Add(timeout)
if err := conn.SetReadDeadline(deadline); err != nil {
return fmt.Errorf("setting read deadline failed: %w", err)
}
}
if !scanner.Scan() {
// Exit if no data arrived e.g. due to timeout or closed connection
break
}
receiveTime := time.Now()
src := conn.RemoteAddr()
if l.path != "" {
src = &net.UnixAddr{Name: l.path, Net: "unix"}
}
data := scanner.Bytes()
d := make([]byte, len(data))
copy(d, data)
l.parsePool.Submit(func() {
onData(src, d, receiveTime)
})
}
if err := scanner.Err(); err != nil {
if errors.Is(err, os.ErrDeadlineExceeded) {
// Ignore the timeout and silently close the connection
l.Log.Debug(err)
return nil
}
if errors.Is(err, net.ErrClosed) {
// Ignore the connection closing of the remote side
return nil
}
return err
}
return nil
}
func (l *streamListener) readAll(conn net.Conn, onData CallbackData) error {
src := conn.RemoteAddr()
if l.path != "" {
src = &net.UnixAddr{Name: l.path, Net: "unix"}
}
decoder, err := internal.NewStreamContentDecoder(l.Encoding, conn)
if err != nil {
return fmt.Errorf("creating decoder failed: %w", err)
}
timeout := time.Duration(l.ReadTimeout)
// Set the read deadline, if any, then start reading. The read
// will accept the deadline and return if no or insufficient data
// arrived in time. We need to set the deadline in every cycle as
// it is an ABSOLUTE time and not a timeout.
if timeout > 0 {
deadline := time.Now().Add(timeout)
if err := conn.SetReadDeadline(deadline); err != nil {
return fmt.Errorf("setting read deadline failed: %w", err)
}
}
buf, err := io.ReadAll(decoder)
if err != nil {
return fmt.Errorf("read on %s failed: %w", src, err)
}
receiveTime := time.Now()
l.parsePool.Submit(func() {
onData(src, buf, receiveTime)
})
return nil
}
func (l *streamListener) handleConnection(ctx context.Context, conn net.Conn, onConnection CallbackConnection) error {
defer l.wg.Done()
localCtx, cancel := context.WithCancel(ctx)
defer cancel()
defer l.closeConnection(conn)
stopFunc := context.AfterFunc(localCtx, func() { l.closeConnection(conn) })
defer stopFunc()
// Prepare the data decoder for the connection
decoder, err := internal.NewStreamContentDecoder(l.Encoding, conn)
if err != nil {
return fmt.Errorf("creating decoder failed: %w", err)
}
// Get the remote address
src := conn.RemoteAddr()
if l.path != "" {
src = &net.UnixAddr{Name: l.path, Net: "unix"}
}
// Create a pipe and feed it to the callback
reader, writer := io.Pipe()
defer writer.Close()
go onConnection(src, reader)
timeout := time.Duration(l.ReadTimeout)
buf := make([]byte, 4096) // 4kb
for {
// Set the read deadline, if any, then start reading. The read
// will accept the deadline and return if no or insufficient data
// arrived in time. We need to set the deadline in every cycle as
// it is an ABSOLUTE time and not a timeout.
if timeout > 0 {
deadline := time.Now().Add(timeout)
if err := conn.SetReadDeadline(deadline); err != nil {
return fmt.Errorf("setting read deadline failed: %w", err)
}
}
// Copy the data
n, err := decoder.Read(buf)
if err != nil {
if !strings.HasSuffix(err.Error(), ": use of closed network connection") {
if !errors.Is(err, os.ErrDeadlineExceeded) && errors.Is(err, net.ErrClosed) {
writer.CloseWithError(err)
}
}
return nil
}
if _, err := writer.Write(buf[:n]); err != nil {
return err
}
}
}

View file

@ -0,0 +1,305 @@
package starlark
import (
"fmt"
"sort"
"time"
"go.starlark.net/starlark"
"github.com/influxdata/telegraf"
"github.com/influxdata/telegraf/metric"
)
func newMetric(_ *starlark.Thread, _ *starlark.Builtin, args starlark.Tuple, kwargs []starlark.Tuple) (starlark.Value, error) {
var (
name starlark.String
tags, fields starlark.Value
)
if err := starlark.UnpackArgs("Metric", args, kwargs, "name", &name, "tags?", &tags, "fields?", &fields); err != nil {
return nil, err
}
allFields, err := toFields(fields)
if err != nil {
return nil, err
}
allTags, err := toTags(tags)
if err != nil {
return nil, err
}
m := metric.New(string(name), allTags, allFields, time.Now())
return &Metric{metric: m}, nil
}
func toString(value starlark.Value, errorMsg string) (string, error) {
if value, ok := value.(starlark.String); ok {
return string(value), nil
}
return "", fmt.Errorf(errorMsg, value)
}
func items(value starlark.Value, errorMsg string) ([]starlark.Tuple, error) {
if iter, ok := value.(starlark.IterableMapping); ok {
return iter.Items(), nil
}
return nil, fmt.Errorf(errorMsg, value)
}
func deepcopy(_ *starlark.Thread, _ *starlark.Builtin, args starlark.Tuple, kwargs []starlark.Tuple) (starlark.Value, error) {
var sm *Metric
var track bool
if err := starlark.UnpackArgs("deepcopy", args, kwargs, "source", &sm, "track?", &track); err != nil {
return nil, err
}
// In case we copy a tracking metric but do not want to track the result,
// we have to strip the tracking information. This can be done by unwrapping
// the metric.
if tm, ok := sm.metric.(telegraf.TrackingMetric); ok && !track {
return &Metric{metric: tm.Unwrap().Copy()}, nil
}
// Copy the whole metric including potential tracking information
return &Metric{metric: sm.metric.Copy()}, nil
}
// catch(f) evaluates f() and returns its evaluation error message
// if it failed or None if it succeeded.
func catch(thread *starlark.Thread, _ *starlark.Builtin, args starlark.Tuple, kwargs []starlark.Tuple) (starlark.Value, error) {
var fn starlark.Callable
if err := starlark.UnpackArgs("catch", args, kwargs, "fn", &fn); err != nil {
return nil, err
}
if _, err := starlark.Call(thread, fn, nil, nil); err != nil {
//nolint:nilerr // nil returned on purpose, error put inside starlark.Value
return starlark.String(err.Error()), nil
}
return starlark.None, nil
}
type builtinMethod func(b *starlark.Builtin, args starlark.Tuple, kwargs []starlark.Tuple) (starlark.Value, error)
func builtinAttr(recv starlark.Value, name string, methods map[string]builtinMethod) (starlark.Value, error) {
method := methods[name]
if method == nil {
return starlark.None, fmt.Errorf("no such method %q", name)
}
// Allocate a closure over 'method'.
impl := func(_ *starlark.Thread, b *starlark.Builtin, args starlark.Tuple, kwargs []starlark.Tuple) (starlark.Value, error) {
return method(b, args, kwargs)
}
return starlark.NewBuiltin(name, impl).BindReceiver(recv), nil
}
func builtinAttrNames(methods map[string]builtinMethod) []string {
names := make([]string, 0, len(methods))
for name := range methods {
names = append(names, name)
}
sort.Strings(names)
return names
}
// --- dictionary methods ---
// https://github.com/google/starlark-go/blob/master/doc/spec.md#dict·clear
func dictClear(b *starlark.Builtin, args starlark.Tuple, kwargs []starlark.Tuple) (starlark.Value, error) {
if err := starlark.UnpackPositionalArgs(b.Name(), args, kwargs, 0); err != nil {
return starlark.None, fmt.Errorf("%s: %w", b.Name(), err)
}
type HasClear interface {
Clear() error
}
return starlark.None, b.Receiver().(HasClear).Clear()
}
// https://github.com/google/starlark-go/blob/master/doc/spec.md#dict·pop
func dictPop(b *starlark.Builtin, args starlark.Tuple, kwargs []starlark.Tuple) (starlark.Value, error) {
var k, d starlark.Value
if err := starlark.UnpackPositionalArgs(b.Name(), args, kwargs, 1, &k, &d); err != nil {
return starlark.None, fmt.Errorf("%s: %w", b.Name(), err)
}
type HasDelete interface {
Delete(k starlark.Value) (starlark.Value, bool, error)
}
if v, found, err := b.Receiver().(HasDelete).Delete(k); err != nil {
return starlark.None, fmt.Errorf("%s: %w", b.Name(), err) // dict is frozen or key is unhashable
} else if found {
return v, nil
} else if d != nil {
return d, nil
}
return starlark.None, fmt.Errorf("%s: missing key", b.Name())
}
// https://github.com/google/starlark-go/blob/master/doc/spec.md#dict·popitem
func dictPopitem(b *starlark.Builtin, args starlark.Tuple, kwargs []starlark.Tuple) (starlark.Value, error) {
if err := starlark.UnpackPositionalArgs(b.Name(), args, kwargs, 0); err != nil {
return starlark.None, fmt.Errorf("%s: %w", b.Name(), err)
}
type HasPopItem interface {
PopItem() (starlark.Value, error)
}
return b.Receiver().(HasPopItem).PopItem()
}
// https://github.com/google/starlark-go/blob/master/doc/spec.md#dict·get
func dictGet(b *starlark.Builtin, args starlark.Tuple, kwargs []starlark.Tuple) (starlark.Value, error) {
var key, dflt starlark.Value
if err := starlark.UnpackPositionalArgs(b.Name(), args, kwargs, 1, &key, &dflt); err != nil {
return starlark.None, fmt.Errorf("%s: %w", b.Name(), err)
}
if v, ok, err := b.Receiver().(starlark.Mapping).Get(key); err != nil {
return starlark.None, fmt.Errorf("%s: %w", b.Name(), err)
} else if ok {
return v, nil
} else if dflt != nil {
return dflt, nil
}
return starlark.None, nil
}
// https://github.com/google/starlark-go/blob/master/doc/spec.md#dict·setdefault
func dictSetdefault(b *starlark.Builtin, args starlark.Tuple, kwargs []starlark.Tuple) (starlark.Value, error) {
var key, dflt starlark.Value = nil, starlark.None
if err := starlark.UnpackPositionalArgs(b.Name(), args, kwargs, 1, &key, &dflt); err != nil {
return starlark.None, fmt.Errorf("%s: %w", b.Name(), err)
}
recv := b.Receiver().(starlark.HasSetKey)
v, found, err := recv.Get(key)
if err != nil {
return starlark.None, fmt.Errorf("%s: %w", b.Name(), err)
}
if !found {
v = dflt
if err := recv.SetKey(key, dflt); err != nil {
return starlark.None, fmt.Errorf("%s: %w", b.Name(), err)
}
}
return v, nil
}
// https://github.com/google/starlark-go/blob/master/doc/spec.md#dict·update
func dictUpdate(b *starlark.Builtin, args starlark.Tuple, kwargs []starlark.Tuple) (starlark.Value, error) {
// Unpack the arguments
if len(args) > 1 {
return nil, fmt.Errorf("update: got %d arguments, want at most 1", len(args))
}
// Get the target
dict := b.Receiver().(starlark.HasSetKey)
if len(args) == 1 {
switch updates := args[0].(type) {
case starlark.IterableMapping:
// Iterate over dict's key/value pairs, not just keys.
for _, item := range updates.Items() {
if err := dict.SetKey(item[0], item[1]); err != nil {
return nil, err // dict is frozen
}
}
default:
// all other sequences
iter := starlark.Iterate(updates)
if iter == nil {
return nil, fmt.Errorf("got %s, want iterable", updates.Type())
}
defer iter.Done()
var pair starlark.Value
for i := 0; iter.Next(&pair); i++ {
iterErr := func() error {
iter2 := starlark.Iterate(pair)
if iter2 == nil {
return fmt.Errorf("dictionary update sequence element #%d is not iterable (%s)", i, pair.Type())
}
defer iter2.Done()
length := starlark.Len(pair)
if length < 0 {
return fmt.Errorf("dictionary update sequence element #%d has unknown length (%s)", i, pair.Type())
} else if length != 2 {
return fmt.Errorf("dictionary update sequence element #%d has length %d, want 2", i, length)
}
var k, v starlark.Value
iter2.Next(&k)
iter2.Next(&v)
return dict.SetKey(k, v)
}()
if iterErr != nil {
return nil, iterErr
}
}
}
}
// Then add the kwargs.
before := starlark.Len(dict)
for _, pair := range kwargs {
if err := dict.SetKey(pair[0], pair[1]); err != nil {
return nil, err // dict is frozen
}
}
// In the common case, each kwarg will add another dict entry.
// If that's not so, check whether it is because there was a duplicate kwarg.
if starlark.Len(dict) < before+len(kwargs) {
keys := make(map[starlark.String]bool, len(kwargs))
for _, kv := range kwargs {
k := kv[0].(starlark.String)
if keys[k] {
return nil, fmt.Errorf("duplicate keyword arg: %v", k)
}
keys[k] = true
}
}
return starlark.None, nil
}
// https://github.com/google/starlark-go/blob/master/doc/spec.md#dict·items
func dictItems(b *starlark.Builtin, args starlark.Tuple, kwargs []starlark.Tuple) (starlark.Value, error) {
if err := starlark.UnpackPositionalArgs(b.Name(), args, kwargs, 0); err != nil {
return starlark.None, fmt.Errorf("%s: %w", b.Name(), err)
}
items := b.Receiver().(starlark.IterableMapping).Items()
res := make([]starlark.Value, 0, len(items))
for _, item := range items {
res = append(res, item) // convert [2]starlark.Value to starlark.Value
}
return starlark.NewList(res), nil
}
// https://github.com/google/starlark-go/blob/master/doc/spec.md#dict·keys
func dictKeys(b *starlark.Builtin, args starlark.Tuple, kwargs []starlark.Tuple) (starlark.Value, error) {
if err := starlark.UnpackPositionalArgs(b.Name(), args, kwargs, 0); err != nil {
return starlark.None, fmt.Errorf("%s: %w", b.Name(), err)
}
items := b.Receiver().(starlark.IterableMapping).Items()
res := make([]starlark.Value, 0, len(items))
for _, item := range items {
res = append(res, item[0])
}
return starlark.NewList(res), nil
}
// https://github.com/google/starlark-go/blob/master/doc/spec.md#dict·update
func dictValues(b *starlark.Builtin, args starlark.Tuple, kwargs []starlark.Tuple) (starlark.Value, error) {
if err := starlark.UnpackPositionalArgs(b.Name(), args, kwargs, 0); err != nil {
return starlark.None, fmt.Errorf("%s: %w", b.Name(), err)
}
items := b.Receiver().(starlark.IterableMapping).Items()
res := make([]starlark.Value, 0, len(items))
for _, item := range items {
res = append(res, item[1])
}
return starlark.NewList(res), nil
}

View file

@ -0,0 +1,308 @@
package starlark
import (
"errors"
"fmt"
"reflect"
"strings"
"go.starlark.net/starlark"
"github.com/influxdata/telegraf"
)
// FieldDict is a starlark.Value for the metric fields. It is heavily based on the
// starlark.Dict.
type FieldDict struct {
*Metric
}
func (d FieldDict) String() string {
buf := new(strings.Builder)
buf.WriteString("{")
sep := ""
for _, item := range d.Items() {
k, v := item[0], item[1]
buf.WriteString(sep)
buf.WriteString(k.String())
buf.WriteString(": ")
buf.WriteString(v.String())
sep = ", "
}
buf.WriteString("}")
return buf.String()
}
func (FieldDict) Type() string {
return "Fields"
}
func (d FieldDict) Freeze() {
// Disable linter check as the frozen variable is modified despite
// passing a value instead of a pointer, because `FieldDict` holds
// a pointer to the underlying metric containing the `frozen` field.
//revive:disable:modifies-value-receiver
d.frozen = true
}
func (d FieldDict) Truth() starlark.Bool {
return len(d.metric.FieldList()) != 0
}
func (FieldDict) Hash() (uint32, error) {
return 0, errors.New("not hashable")
}
// AttrNames implements the starlark.HasAttrs interface.
func (FieldDict) AttrNames() []string {
return builtinAttrNames(FieldDictMethods)
}
// Attr implements the starlark.HasAttrs interface.
func (d FieldDict) Attr(name string) (starlark.Value, error) {
return builtinAttr(d, name, FieldDictMethods)
}
var FieldDictMethods = map[string]builtinMethod{
"clear": dictClear,
"get": dictGet,
"items": dictItems,
"keys": dictKeys,
"pop": dictPop,
"popitem": dictPopitem,
"setdefault": dictSetdefault,
"update": dictUpdate,
"values": dictValues,
}
// Get implements the starlark.Mapping interface.
func (d FieldDict) Get(key starlark.Value) (v starlark.Value, found bool, err error) {
if k, ok := key.(starlark.String); ok {
gv, found := d.metric.GetField(k.GoString())
if !found {
return starlark.None, false, nil
}
v, err := asStarlarkValue(gv)
if err != nil {
return starlark.None, false, err
}
return v, true, nil
}
return starlark.None, false, errors.New("key must be of type 'str'")
}
// SetKey implements the starlark.HasSetKey interface to support map update
// using x[k]=v syntax, like a dictionary.
func (d FieldDict) SetKey(k, v starlark.Value) error {
if d.fieldIterCount > 0 {
return errors.New("cannot insert during iteration")
}
key, ok := k.(starlark.String)
if !ok {
return errors.New("field key must be of type 'str'")
}
gv, err := asGoValue(v)
if err != nil {
return err
}
d.metric.AddField(key.GoString(), gv)
return nil
}
// Items implements the starlark.IterableMapping interface.
func (d FieldDict) Items() []starlark.Tuple {
items := make([]starlark.Tuple, 0, len(d.metric.FieldList()))
for _, field := range d.metric.FieldList() {
key := starlark.String(field.Key)
sv, err := asStarlarkValue(field.Value)
if err != nil {
continue
}
pair := starlark.Tuple{key, sv}
items = append(items, pair)
}
return items
}
func (d FieldDict) Clear() error {
if d.fieldIterCount > 0 {
return errors.New("cannot delete during iteration")
}
keys := make([]string, 0, len(d.metric.FieldList()))
for _, field := range d.metric.FieldList() {
keys = append(keys, field.Key)
}
for _, key := range keys {
d.metric.RemoveField(key)
}
return nil
}
func (d FieldDict) PopItem() (starlark.Value, error) {
if d.fieldIterCount > 0 {
return nil, errors.New("cannot delete during iteration")
}
if len(d.metric.FieldList()) == 0 {
return nil, errors.New("popitem(): field dictionary is empty")
}
field := d.metric.FieldList()[0]
k := field.Key
v := field.Value
d.metric.RemoveField(k)
sk := starlark.String(k)
sv, err := asStarlarkValue(v)
if err != nil {
return nil, errors.New("could not convert to starlark value")
}
return starlark.Tuple{sk, sv}, nil
}
func (d FieldDict) Delete(k starlark.Value) (v starlark.Value, found bool, err error) {
if d.fieldIterCount > 0 {
return nil, false, errors.New("cannot delete during iteration")
}
if key, ok := k.(starlark.String); ok {
value, ok := d.metric.GetField(key.GoString())
if ok {
d.metric.RemoveField(key.GoString())
sv, err := asStarlarkValue(value)
return sv, ok, err
}
return starlark.None, false, nil
}
return starlark.None, false, errors.New("key must be of type 'str'")
}
// Iterate implements the starlark.Iterator interface.
func (d FieldDict) Iterate() starlark.Iterator {
d.fieldIterCount++
return &FieldIterator{Metric: d.Metric, fields: d.metric.FieldList()}
}
type FieldIterator struct {
*Metric
fields []*telegraf.Field
}
// Next implements the starlark.Iterator interface.
func (i *FieldIterator) Next(p *starlark.Value) bool {
if len(i.fields) == 0 {
return false
}
field := i.fields[0]
i.fields = i.fields[1:]
*p = starlark.String(field.Key)
return true
}
// Done implements the starlark.Iterator interface.
func (i *FieldIterator) Done() {
i.fieldIterCount--
}
// AsStarlarkValue converts a field value to a starlark.Value.
func asStarlarkValue(value interface{}) (starlark.Value, error) {
v := reflect.ValueOf(value)
switch v.Kind() {
case reflect.Slice:
length := v.Len()
array := make([]starlark.Value, 0, length)
for i := 0; i < length; i++ {
sVal, err := asStarlarkValue(v.Index(i).Interface())
if err != nil {
return starlark.None, err
}
array = append(array, sVal)
}
return starlark.NewList(array), nil
case reflect.Map:
dict := starlark.NewDict(v.Len())
iter := v.MapRange()
for iter.Next() {
sKey, err := asStarlarkValue(iter.Key().Interface())
if err != nil {
return starlark.None, err
}
sValue, err := asStarlarkValue(iter.Value().Interface())
if err != nil {
return starlark.None, err
}
if err := dict.SetKey(sKey, sValue); err != nil {
return starlark.None, err
}
}
return dict, nil
case reflect.Float32, reflect.Float64:
return starlark.Float(v.Float()), nil
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return starlark.MakeInt64(v.Int()), nil
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
return starlark.MakeUint64(v.Uint()), nil
case reflect.String:
return starlark.String(v.String()), nil
case reflect.Bool:
return starlark.Bool(v.Bool()), nil
}
return nil, fmt.Errorf("invalid type %T", value)
}
// AsGoValue converts a starlark.Value to a field value.
func asGoValue(value interface{}) (interface{}, error) {
switch v := value.(type) {
case starlark.Float:
return float64(v), nil
case starlark.Int:
n, ok := v.Int64()
if !ok {
return nil, fmt.Errorf("cannot represent integer %v as int64", v)
}
return n, nil
case starlark.String:
return string(v), nil
case starlark.Bool:
return bool(v), nil
}
return nil, fmt.Errorf("invalid starlark type %T", value)
}
// ToFields converts a starlark.Value to a map of values.
func toFields(value starlark.Value) (map[string]interface{}, error) {
if value == nil {
return nil, nil
}
items, err := items(value, "The type %T is unsupported as type of collection of fields")
if err != nil {
return nil, err
}
result := make(map[string]interface{}, len(items))
for _, item := range items {
key, err := toString(item[0], "The type %T is unsupported as type of key for fields")
if err != nil {
return nil, err
}
value, err := asGoValue(item[1])
if err != nil {
return nil, err
}
result[key] = value
}
return result, nil
}

View file

@ -0,0 +1,48 @@
package starlark
import (
"errors"
"fmt"
"go.starlark.net/starlark"
"go.starlark.net/starlarkstruct"
"github.com/influxdata/telegraf"
)
// Builds a module that defines all the supported logging functions which will log using the provided logger
func LogModule(logger telegraf.Logger) *starlarkstruct.Module {
var logFunc = func(_ *starlark.Thread, b *starlark.Builtin, args starlark.Tuple, kwargs []starlark.Tuple) (starlark.Value, error) {
return log(b, args, kwargs, logger)
}
return &starlarkstruct.Module{
Name: "log",
Members: starlark.StringDict{
"debug": starlark.NewBuiltin("log.debug", logFunc),
"info": starlark.NewBuiltin("log.info", logFunc),
"warn": starlark.NewBuiltin("log.warn", logFunc),
"error": starlark.NewBuiltin("log.error", logFunc),
},
}
}
// Logs the provided message according to the level chosen
func log(b *starlark.Builtin, args starlark.Tuple, kwargs []starlark.Tuple, logger telegraf.Logger) (starlark.Value, error) {
var msg starlark.String
if err := starlark.UnpackPositionalArgs(b.Name(), args, kwargs, 1, &msg); err != nil {
return starlark.None, fmt.Errorf("%s: %w", b.Name(), err)
}
switch b.Name() {
case "log.debug":
logger.Debug(string(msg))
case "log.info":
logger.Info(string(msg))
case "log.warn":
logger.Warn(string(msg))
case "log.error":
logger.Error(string(msg))
default:
return nil, errors.New("method " + b.Name() + " is unknown")
}
return starlark.None, nil
}

View file

@ -0,0 +1,156 @@
package starlark
import (
"errors"
"fmt"
"strings"
"time"
"go.starlark.net/starlark"
"github.com/influxdata/telegraf"
)
type Metric struct {
ID telegraf.TrackingID
metric telegraf.Metric
tagIterCount int
fieldIterCount int
frozen bool
}
// Wrap updates the starlark.Metric to wrap a new telegraf.Metric.
func (m *Metric) Wrap(metric telegraf.Metric) {
if tm, ok := metric.(telegraf.TrackingMetric); ok {
m.ID = tm.TrackingID()
}
m.metric = metric
m.tagIterCount = 0
m.fieldIterCount = 0
m.frozen = false
}
// Unwrap removes the telegraf.Metric from the startlark.Metric.
func (m *Metric) Unwrap() telegraf.Metric {
return m.metric
}
// String returns the starlark representation of the Metric.
//
// The String function is called by both the repr() and str() functions, and so
// it behaves more like the repr function would in Python.
func (m *Metric) String() string {
buf := new(strings.Builder)
buf.WriteString("Metric(")
buf.WriteString(m.Name().String())
buf.WriteString(", tags=")
buf.WriteString(m.Tags().String())
buf.WriteString(", fields=")
buf.WriteString(m.Fields().String())
buf.WriteString(", time=")
buf.WriteString(m.Time().String())
buf.WriteString(")")
if m.ID != 0 {
fmt.Fprintf(buf, "[tracking ID=%v]", m.ID)
}
return buf.String()
}
func (*Metric) Type() string {
return "Metric"
}
func (m *Metric) Freeze() {
m.frozen = true
}
func (*Metric) Truth() starlark.Bool {
return true
}
func (*Metric) Hash() (uint32, error) {
return 0, errors.New("not hashable")
}
// AttrNames implements the starlark.HasAttrs interface.
func (*Metric) AttrNames() []string {
return []string{"name", "tags", "fields", "time"}
}
// Attr implements the starlark.HasAttrs interface.
func (m *Metric) Attr(name string) (starlark.Value, error) {
switch name {
case "name":
return m.Name(), nil
case "tags":
return m.Tags(), nil
case "fields":
return m.Fields(), nil
case "time":
return m.Time(), nil
default:
// Returning nil, nil indicates "no such field or method"
return nil, nil
}
}
// SetField implements the starlark.HasSetField interface.
func (m *Metric) SetField(name string, value starlark.Value) error {
if m.frozen {
return errors.New("cannot modify frozen metric")
}
switch name {
case "name":
return m.SetName(value)
case "time":
return m.SetTime(value)
case "tags":
return errors.New("cannot set tags")
case "fields":
return errors.New("cannot set fields")
default:
return starlark.NoSuchAttrError(
fmt.Sprintf("cannot assign to field %q", name))
}
}
func (m *Metric) Name() starlark.String {
return starlark.String(m.metric.Name())
}
func (m *Metric) SetName(value starlark.Value) error {
if str, ok := value.(starlark.String); ok {
m.metric.SetName(str.GoString())
return nil
}
return errors.New("type error")
}
func (m *Metric) Tags() TagDict {
return TagDict{m}
}
func (m *Metric) Fields() FieldDict {
return FieldDict{m}
}
func (m *Metric) Time() starlark.Int {
return starlark.MakeInt64(m.metric.Time().UnixNano())
}
func (m *Metric) SetTime(value starlark.Value) error {
switch v := value.(type) {
case starlark.Int:
ns, ok := v.Int64()
if !ok {
return errors.New("type error: unrepresentable time")
}
tm := time.Unix(0, ns)
m.metric.SetTime(tm)
return nil
default:
return errors.New("type error")
}
}

View file

@ -0,0 +1,282 @@
package starlark
import (
"bytes"
"encoding/gob"
"errors"
"fmt"
"strings"
"go.starlark.net/lib/json"
"go.starlark.net/lib/math"
"go.starlark.net/lib/time"
"go.starlark.net/starlark"
"go.starlark.net/syntax"
"github.com/influxdata/telegraf"
)
type Common struct {
Source string `toml:"source"`
Script string `toml:"script"`
Constants map[string]interface{} `toml:"constants"`
Log telegraf.Logger `toml:"-"`
StarlarkLoadFunc func(module string, logger telegraf.Logger) (starlark.StringDict, error)
thread *starlark.Thread
builtins starlark.StringDict
globals starlark.StringDict
functions map[string]*starlark.Function
parameters map[string]starlark.Tuple
state *starlark.Dict
}
func (s *Common) GetState() interface{} {
// Return the actual byte-type instead of nil allowing the persister
// to guess instantiate variable of the appropriate type
if s.state == nil {
return make([]byte, 0)
}
// Convert the starlark dict into a golang dictionary for serialization
state := make(map[string]interface{}, s.state.Len())
items := s.state.Items()
for _, item := range items {
if len(item) != 2 {
// We do expect key-value pairs in the state so there should be
// two items.
s.Log.Errorf("state item %+v does not contain a key-value pair", item)
continue
}
k, ok := item.Index(0).(starlark.String)
if !ok {
s.Log.Errorf("state item %+v has invalid key type %T", item, item.Index(0))
continue
}
v, err := asGoValue(item.Index(1))
if err != nil {
s.Log.Errorf("state item %+v value cannot be converted: %v", item, err)
continue
}
state[k.GoString()] = v
}
// Do a binary GOB encoding to preserve types
var buf bytes.Buffer
if err := gob.NewEncoder(&buf).Encode(state); err != nil {
s.Log.Errorf("encoding state failed: %v", err)
return make([]byte, 0)
}
return buf.Bytes()
}
func (s *Common) SetState(state interface{}) error {
data, ok := state.([]byte)
if !ok {
return fmt.Errorf("unexpected type %T for state", state)
}
if len(data) == 0 {
return nil
}
// Decode the binary GOB encoding
var dict map[string]interface{}
if err := gob.NewDecoder(bytes.NewBuffer(data)).Decode(&dict); err != nil {
return fmt.Errorf("decoding state failed: %w", err)
}
// Convert the golang dict back to starlark types
s.state = starlark.NewDict(len(dict))
for k, v := range dict {
sv, err := asStarlarkValue(v)
if err != nil {
return fmt.Errorf("value %v of state item %q cannot be set: %w", v, k, err)
}
if err := s.state.SetKey(starlark.String(k), sv); err != nil {
return fmt.Errorf("state item %q cannot be set: %w", k, err)
}
}
s.builtins["state"] = s.state
return s.InitProgram()
}
func (s *Common) Init() error {
if s.Source == "" && s.Script == "" {
return errors.New("one of source or script must be set")
}
if s.Source != "" && s.Script != "" {
return errors.New("both source or script cannot be set")
}
s.builtins = starlark.StringDict{}
s.builtins["Metric"] = starlark.NewBuiltin("Metric", newMetric)
s.builtins["deepcopy"] = starlark.NewBuiltin("deepcopy", deepcopy)
s.builtins["catch"] = starlark.NewBuiltin("catch", catch)
if err := s.addConstants(&s.builtins); err != nil {
return err
}
// Initialize the program
if err := s.InitProgram(); err != nil {
// Try again with a declared state. This might be necessary for
// state persistence.
s.state = starlark.NewDict(0)
s.builtins["state"] = s.state
if serr := s.InitProgram(); serr != nil {
return err
}
}
s.functions = make(map[string]*starlark.Function)
s.parameters = make(map[string]starlark.Tuple)
return nil
}
func (s *Common) InitProgram() error {
// Load the program. In case of an error we can try to insert the state
// which can be used implicitly e.g. when persisting states
program, err := s.sourceProgram(s.builtins)
if err != nil {
return err
}
// Execute source
s.thread = &starlark.Thread{
Print: func(_ *starlark.Thread, msg string) { s.Log.Debug(msg) },
Load: func(_ *starlark.Thread, module string) (starlark.StringDict, error) {
return s.StarlarkLoadFunc(module, s.Log)
},
}
globals, err := program.Init(s.thread, s.builtins)
if err != nil {
return err
}
// In case the program declares a global "state" we should insert it to
// avoid warnings about inserting into a frozen variable
if _, found := globals["state"]; found {
globals["state"] = starlark.NewDict(0)
}
// Freeze the global state. This prevents modifications to the processor
// state and prevents scripts from containing errors storing tracking
// metrics. Tasks that require global state will not be possible due to
// this, so maybe we should relax this in the future.
globals.Freeze()
s.globals = globals
return nil
}
func (s *Common) GetParameters(name string) (starlark.Tuple, bool) {
parameters, found := s.parameters[name]
return parameters, found
}
func (s *Common) AddFunction(name string, params ...starlark.Value) error {
globalFn, found := s.globals[name]
if !found {
return fmt.Errorf("%s is not defined", name)
}
fn, found := globalFn.(*starlark.Function)
if !found {
return fmt.Errorf("%s is not a function", name)
}
if fn.NumParams() != len(params) {
return fmt.Errorf("%s function must take %d parameter(s)", name, len(params))
}
p := make(starlark.Tuple, len(params))
copy(p, params)
s.functions[name] = fn
s.parameters[name] = params
return nil
}
// Add all the constants defined in the plugin as constants of the script
func (s *Common) addConstants(builtins *starlark.StringDict) error {
for key, val := range s.Constants {
if key == "state" {
return errors.New("'state' constant uses reserved name")
}
sVal, err := asStarlarkValue(val)
if err != nil {
return fmt.Errorf("converting type %T failed: %w", val, err)
}
(*builtins)[key] = sVal
}
return nil
}
func (s *Common) sourceProgram(builtins starlark.StringDict) (*starlark.Program, error) {
var src interface{}
if s.Source != "" {
src = s.Source
}
// AllowFloat - obsolete, no effect
// AllowNestedDef - always on https://github.com/google/starlark-go/pull/328
// AllowLambda - always on https://github.com/google/starlark-go/pull/328
options := syntax.FileOptions{
Recursion: true,
GlobalReassign: true,
Set: true,
}
_, program, err := starlark.SourceProgramOptions(&options, s.Script, src, builtins.Has)
return program, err
}
// Call calls the function corresponding to the given name.
func (s *Common) Call(name string) (starlark.Value, error) {
fn, ok := s.functions[name]
if !ok {
return nil, fmt.Errorf("function %q does not exist", name)
}
args, ok := s.parameters[name]
if !ok {
return nil, fmt.Errorf("params for function %q do not exist", name)
}
return starlark.Call(s.thread, fn, args, nil)
}
func (s *Common) LogError(err error) {
var evalErr *starlark.EvalError
if errors.As(err, &evalErr) {
for _, line := range strings.Split(evalErr.Backtrace(), "\n") {
s.Log.Error(line)
}
} else {
s.Log.Error(err)
}
}
func LoadFunc(module string, logger telegraf.Logger) (starlark.StringDict, error) {
switch module {
case "json.star":
return starlark.StringDict{
"json": json.Module,
}, nil
case "logging.star":
return starlark.StringDict{
"log": LogModule(logger),
}, nil
case "math.star":
return starlark.StringDict{
"math": math.Module,
}, nil
case "time.star":
return starlark.StringDict{
"time": time.Module,
}, nil
default:
return nil, errors.New("module " + module + " is not available")
}
}

View file

@ -0,0 +1,226 @@
package starlark
import (
"errors"
"strings"
"go.starlark.net/starlark"
"github.com/influxdata/telegraf"
)
// TagDict is a starlark.Value for the metric tags. It is heavily based on the
// starlark.Dict.
type TagDict struct {
*Metric
}
func (d TagDict) String() string {
buf := new(strings.Builder)
buf.WriteString("{")
sep := ""
for _, item := range d.Items() {
k, v := item[0], item[1]
buf.WriteString(sep)
buf.WriteString(k.String())
buf.WriteString(": ")
buf.WriteString(v.String())
sep = ", "
}
buf.WriteString("}")
return buf.String()
}
func (TagDict) Type() string {
return "Tags"
}
func (d TagDict) Freeze() {
// Disable linter check as the frozen variable is modified despite
// passing a value instead of a pointer, because `TagDict` holds
// a pointer to the underlying metric containing the `frozen` field.
//revive:disable:modifies-value-receiver
d.frozen = true
}
func (d TagDict) Truth() starlark.Bool {
return len(d.metric.TagList()) != 0
}
func (TagDict) Hash() (uint32, error) {
return 0, errors.New("not hashable")
}
// AttrNames implements the starlark.HasAttrs interface.
func (TagDict) AttrNames() []string {
return builtinAttrNames(TagDictMethods)
}
// Attr implements the starlark.HasAttrs interface.
func (d TagDict) Attr(name string) (starlark.Value, error) {
return builtinAttr(d, name, TagDictMethods)
}
var TagDictMethods = map[string]builtinMethod{
"clear": dictClear,
"get": dictGet,
"items": dictItems,
"keys": dictKeys,
"pop": dictPop,
"popitem": dictPopitem,
"setdefault": dictSetdefault,
"update": dictUpdate,
"values": dictValues,
}
// Get implements the starlark.Mapping interface.
func (d TagDict) Get(key starlark.Value) (v starlark.Value, found bool, err error) {
if k, ok := key.(starlark.String); ok {
gv, found := d.metric.GetTag(k.GoString())
if !found {
return starlark.None, false, nil
}
return starlark.String(gv), true, err
}
return starlark.None, false, errors.New("key must be of type 'str'")
}
// SetKey implements the starlark.HasSetKey interface to support map update
// using x[k]=v syntax, like a dictionary.
func (d TagDict) SetKey(k, v starlark.Value) error {
if d.tagIterCount > 0 {
return errors.New("cannot insert during iteration")
}
key, ok := k.(starlark.String)
if !ok {
return errors.New("tag key must be of type 'str'")
}
value, ok := v.(starlark.String)
if !ok {
return errors.New("tag value must be of type 'str'")
}
d.metric.AddTag(key.GoString(), value.GoString())
return nil
}
// Items implements the starlark.IterableMapping interface.
func (d TagDict) Items() []starlark.Tuple {
items := make([]starlark.Tuple, 0, len(d.metric.TagList()))
for _, tag := range d.metric.TagList() {
key := starlark.String(tag.Key)
value := starlark.String(tag.Value)
pair := starlark.Tuple{key, value}
items = append(items, pair)
}
return items
}
func (d TagDict) Clear() error {
if d.tagIterCount > 0 {
return errors.New("cannot delete during iteration")
}
keys := make([]string, 0, len(d.metric.TagList()))
for _, tag := range d.metric.TagList() {
keys = append(keys, tag.Key)
}
for _, key := range keys {
d.metric.RemoveTag(key)
}
return nil
}
func (d TagDict) PopItem() (v starlark.Value, err error) {
if d.tagIterCount > 0 {
return nil, errors.New("cannot delete during iteration")
}
for _, tag := range d.metric.TagList() {
k := tag.Key
v := tag.Value
d.metric.RemoveTag(k)
sk := starlark.String(k)
sv := starlark.String(v)
return starlark.Tuple{sk, sv}, nil
}
return nil, errors.New("popitem(): tag dictionary is empty")
}
func (d TagDict) Delete(k starlark.Value) (v starlark.Value, found bool, err error) {
if d.tagIterCount > 0 {
return nil, false, errors.New("cannot delete during iteration")
}
if key, ok := k.(starlark.String); ok {
value, ok := d.metric.GetTag(key.GoString())
if ok {
d.metric.RemoveTag(key.GoString())
v := starlark.String(value)
return v, ok, err
}
return starlark.None, false, nil
}
return starlark.None, false, errors.New("key must be of type 'str'")
}
// Iterate implements the starlark.Iterator interface.
func (d TagDict) Iterate() starlark.Iterator {
d.tagIterCount++
return &TagIterator{Metric: d.Metric, tags: d.metric.TagList()}
}
type TagIterator struct {
*Metric
tags []*telegraf.Tag
}
// Next implements the starlark.Iterator interface.
func (i *TagIterator) Next(p *starlark.Value) bool {
if len(i.tags) == 0 {
return false
}
tag := i.tags[0]
i.tags = i.tags[1:]
*p = starlark.String(tag.Key)
return true
}
// Done implements the starlark.Iterator interface.
func (i *TagIterator) Done() {
i.tagIterCount--
}
// ToTags converts a starlark.Value to a map of string.
func toTags(value starlark.Value) (map[string]string, error) {
if value == nil {
return nil, nil
}
items, err := items(value, "The type %T is unsupported as type of collection of tags")
if err != nil {
return nil, err
}
result := make(map[string]string, len(items))
for _, item := range items {
key, err := toString(item[0], "The type %T is unsupported as type of key for tags")
if err != nil {
return nil, err
}
value, err := toString(item[1], "The type %T is unsupported as type of value for tags")
if err != nil {
return nil, err
}
result[key] = value
}
return result, nil
}

View file

@ -0,0 +1,24 @@
## Set to true/false to enforce TLS being enabled/disabled. If not set,
## enable TLS only if any of the other options are specified.
# tls_enable =
## Trusted root certificates for server
# tls_ca = "/path/to/cafile"
## Used for TLS client certificate authentication
# tls_cert = "/path/to/certfile"
## Used for TLS client certificate authentication
# tls_key = "/path/to/keyfile"
## Password for the key file if it is encrypted
# tls_key_pwd = ""
## Send the specified TLS server name via SNI
# tls_server_name = "kubernetes.example.com"
## Minimal TLS version to accept by the client
# tls_min_version = "TLS12"
## List of ciphers to accept, by default all secure ciphers will be accepted
## See https://pkg.go.dev/crypto/tls#pkg-constants for supported values.
## Use "all", "secure" and "insecure" to add all support ciphers, secure
## suites or insecure suites respectively.
# tls_cipher_suites = ["secure"]
## Renegotiation method, "never", "once" or "freely"
# tls_renegotiation_method = "never"
## Use TLS but skip chain & host verification
# insecure_skip_verify = false

View file

@ -0,0 +1,34 @@
package tls
import (
"crypto/tls"
"sync"
)
var tlsVersionMap = map[string]uint16{
"TLS10": tls.VersionTLS10,
"TLS11": tls.VersionTLS11,
"TLS12": tls.VersionTLS12,
"TLS13": tls.VersionTLS13,
}
var tlsCipherMapInit sync.Once
var tlsCipherMapSecure map[string]uint16
var tlsCipherMapInsecure map[string]uint16
func init() {
tlsCipherMapInit.Do(func() {
// Initialize the secure suites
suites := tls.CipherSuites()
tlsCipherMapSecure = make(map[string]uint16, len(suites))
for _, s := range suites {
tlsCipherMapSecure[s.Name] = s.ID
}
suites = tls.InsecureCipherSuites()
tlsCipherMapInsecure = make(map[string]uint16, len(suites))
for _, s := range suites {
tlsCipherMapInsecure[s.Name] = s.ID
}
})
}

View file

@ -0,0 +1,298 @@
package tls
import (
"crypto/rsa"
"crypto/tls"
"crypto/x509"
"encoding/pem"
"errors"
"fmt"
"os"
"go.step.sm/crypto/pemutil"
"github.com/influxdata/telegraf/internal/choice"
)
const TLSMinVersionDefault = tls.VersionTLS12
// ClientConfig represents the standard client TLS config.
type ClientConfig struct {
TLSCA string `toml:"tls_ca"`
TLSCert string `toml:"tls_cert"`
TLSKey string `toml:"tls_key"`
TLSKeyPwd string `toml:"tls_key_pwd"`
TLSMinVersion string `toml:"tls_min_version"`
TLSCipherSuites []string `toml:"tls_cipher_suites"`
InsecureSkipVerify bool `toml:"insecure_skip_verify"`
ServerName string `toml:"tls_server_name"`
RenegotiationMethod string `toml:"tls_renegotiation_method"`
Enable *bool `toml:"tls_enable"`
SSLCA string `toml:"ssl_ca" deprecated:"1.7.0;1.35.0;use 'tls_ca' instead"`
SSLCert string `toml:"ssl_cert" deprecated:"1.7.0;1.35.0;use 'tls_cert' instead"`
SSLKey string `toml:"ssl_key" deprecated:"1.7.0;1.35.0;use 'tls_key' instead"`
}
// ServerConfig represents the standard server TLS config.
type ServerConfig struct {
TLSCert string `toml:"tls_cert"`
TLSKey string `toml:"tls_key"`
TLSKeyPwd string `toml:"tls_key_pwd"`
TLSAllowedCACerts []string `toml:"tls_allowed_cacerts"`
TLSCipherSuites []string `toml:"tls_cipher_suites"`
TLSMinVersion string `toml:"tls_min_version"`
TLSMaxVersion string `toml:"tls_max_version"`
TLSAllowedDNSNames []string `toml:"tls_allowed_dns_names"`
}
// TLSConfig returns a tls.Config, may be nil without error if TLS is not
// configured.
func (c *ClientConfig) TLSConfig() (*tls.Config, error) {
// Check if TLS config is forcefully disabled
if c.Enable != nil && !*c.Enable {
return nil, nil
}
// Support deprecated variable names
if c.TLSCA == "" && c.SSLCA != "" {
c.TLSCA = c.SSLCA
}
if c.TLSCert == "" && c.SSLCert != "" {
c.TLSCert = c.SSLCert
}
if c.TLSKey == "" && c.SSLKey != "" {
c.TLSKey = c.SSLKey
}
// This check returns a nil (aka "disabled") or an empty config
// (aka, "use the default") if no field is set that would have an effect on
// a TLS connection. That is, any of:
// * client certificate settings,
// * peer certificate authorities,
// * disabled security,
// * an SNI server name, or
// * empty/never renegotiation method
empty := c.TLSCA == "" && c.TLSKey == "" && c.TLSCert == ""
empty = empty && !c.InsecureSkipVerify && c.ServerName == ""
empty = empty && (c.RenegotiationMethod == "" || c.RenegotiationMethod == "never")
if empty {
// Check if TLS config is forcefully enabled and supposed to
// use the system defaults.
if c.Enable != nil && *c.Enable {
return &tls.Config{}, nil
}
return nil, nil
}
var renegotiationMethod tls.RenegotiationSupport
switch c.RenegotiationMethod {
case "", "never":
renegotiationMethod = tls.RenegotiateNever
case "once":
renegotiationMethod = tls.RenegotiateOnceAsClient
case "freely":
renegotiationMethod = tls.RenegotiateFreelyAsClient
default:
return nil, fmt.Errorf("unrecognized renegotiation method %q, choose from: 'never', 'once', 'freely'", c.RenegotiationMethod)
}
tlsConfig := &tls.Config{
InsecureSkipVerify: c.InsecureSkipVerify,
Renegotiation: renegotiationMethod,
}
if c.TLSCA != "" {
pool, err := makeCertPool([]string{c.TLSCA})
if err != nil {
return nil, err
}
tlsConfig.RootCAs = pool
}
if c.TLSCert != "" && c.TLSKey != "" {
err := loadCertificate(tlsConfig, c.TLSCert, c.TLSKey, c.TLSKeyPwd)
if err != nil {
return nil, err
}
}
// Explicitly and consistently set the minimal accepted version using the
// defined default. We use this setting for both clients and servers
// instead of relying on Golang's default that is different for clients
// and servers and might change over time.
tlsConfig.MinVersion = TLSMinVersionDefault
if c.TLSMinVersion != "" {
version, err := ParseTLSVersion(c.TLSMinVersion)
if err != nil {
return nil, fmt.Errorf("could not parse tls min version %q: %w", c.TLSMinVersion, err)
}
tlsConfig.MinVersion = version
}
if c.ServerName != "" {
tlsConfig.ServerName = c.ServerName
}
if len(c.TLSCipherSuites) != 0 {
cipherSuites, err := ParseCiphers(c.TLSCipherSuites)
if err != nil {
return nil, fmt.Errorf("could not parse client cipher suites: %w", err)
}
tlsConfig.CipherSuites = cipherSuites
}
return tlsConfig, nil
}
// TLSConfig returns a tls.Config, may be nil without error if TLS is not
// configured.
func (c *ServerConfig) TLSConfig() (*tls.Config, error) {
if c.TLSCert == "" && c.TLSKey == "" && len(c.TLSAllowedCACerts) == 0 {
return nil, nil
}
tlsConfig := &tls.Config{}
if len(c.TLSAllowedCACerts) != 0 {
pool, err := makeCertPool(c.TLSAllowedCACerts)
if err != nil {
return nil, err
}
tlsConfig.ClientCAs = pool
tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert
}
if c.TLSCert != "" && c.TLSKey != "" {
err := loadCertificate(tlsConfig, c.TLSCert, c.TLSKey, c.TLSKeyPwd)
if err != nil {
return nil, err
}
}
if len(c.TLSCipherSuites) != 0 {
cipherSuites, err := ParseCiphers(c.TLSCipherSuites)
if err != nil {
return nil, fmt.Errorf("could not parse server cipher suites: %w", err)
}
tlsConfig.CipherSuites = cipherSuites
}
if c.TLSMaxVersion != "" {
version, err := ParseTLSVersion(c.TLSMaxVersion)
if err != nil {
return nil, fmt.Errorf(
"could not parse tls max version %q: %w", c.TLSMaxVersion, err)
}
tlsConfig.MaxVersion = version
}
// Explicitly and consistently set the minimal accepted version using the
// defined default. We use this setting for both clients and servers
// instead of relying on Golang's default that is different for clients
// and servers and might change over time.
tlsConfig.MinVersion = TLSMinVersionDefault
if c.TLSMinVersion != "" {
version, err := ParseTLSVersion(c.TLSMinVersion)
if err != nil {
return nil, fmt.Errorf("could not parse tls min version %q: %w", c.TLSMinVersion, err)
}
tlsConfig.MinVersion = version
}
if tlsConfig.MinVersion != 0 && tlsConfig.MaxVersion != 0 && tlsConfig.MinVersion > tlsConfig.MaxVersion {
return nil, fmt.Errorf("tls min version %q can't be greater than tls max version %q", tlsConfig.MinVersion, tlsConfig.MaxVersion)
}
// Since clientAuth is tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert
// there must be certs to validate.
if len(c.TLSAllowedCACerts) > 0 && len(c.TLSAllowedDNSNames) > 0 {
tlsConfig.VerifyPeerCertificate = c.verifyPeerCertificate
}
return tlsConfig, nil
}
func makeCertPool(certFiles []string) (*x509.CertPool, error) {
pool := x509.NewCertPool()
for _, certFile := range certFiles {
cert, err := os.ReadFile(certFile)
if err != nil {
return nil, fmt.Errorf("could not read certificate %q: %w", certFile, err)
}
if !pool.AppendCertsFromPEM(cert) {
return nil, fmt.Errorf("could not parse any PEM certificates %q: %w", certFile, err)
}
}
return pool, nil
}
func loadCertificate(config *tls.Config, certFile, keyFile, privateKeyPassphrase string) error {
certBytes, err := os.ReadFile(certFile)
if err != nil {
return fmt.Errorf("could not load certificate %q: %w", certFile, err)
}
keyBytes, err := os.ReadFile(keyFile)
if err != nil {
return fmt.Errorf("could not load private key %q: %w", keyFile, err)
}
keyPEMBlock, _ := pem.Decode(keyBytes)
if keyPEMBlock == nil {
return errors.New("failed to decode private key: no PEM data found")
}
var cert tls.Certificate
if keyPEMBlock.Type == "ENCRYPTED PRIVATE KEY" {
if privateKeyPassphrase == "" {
return errors.New("missing password for PKCS#8 encrypted private key")
}
rawDecryptedKey, err := pemutil.DecryptPKCS8PrivateKey(keyPEMBlock.Bytes, []byte(privateKeyPassphrase))
if err != nil {
return fmt.Errorf("failed to decrypt PKCS#8 private key: %w", err)
}
decryptedKey, err := x509.ParsePKCS8PrivateKey(rawDecryptedKey)
if err != nil {
return fmt.Errorf("failed to parse decrypted PKCS#8 private key: %w", err)
}
privateKey, ok := decryptedKey.(*rsa.PrivateKey)
if !ok {
return fmt.Errorf("decrypted key is not a RSA private key: %T", decryptedKey)
}
cert, err = tls.X509KeyPair(certBytes, pem.EncodeToMemory(&pem.Block{Type: keyPEMBlock.Type, Bytes: x509.MarshalPKCS1PrivateKey(privateKey)}))
if err != nil {
return fmt.Errorf("failed to load cert/key pair: %w", err)
}
} else if keyPEMBlock.Headers["Proc-Type"] == "4,ENCRYPTED" {
// The key is an encrypted private key with the DEK-Info header.
// This is currently unsupported because of the deprecation of x509.IsEncryptedPEMBlock and x509.DecryptPEMBlock.
return errors.New("password-protected keys in pkcs#1 format are not supported")
} else {
cert, err = tls.X509KeyPair(certBytes, keyBytes)
if err != nil {
return fmt.Errorf("failed to load cert/key pair: %w", err)
}
}
config.Certificates = []tls.Certificate{cert}
return nil
}
func (c *ServerConfig) verifyPeerCertificate(rawCerts [][]byte, _ [][]*x509.Certificate) error {
// The certificate chain is client + intermediate + root.
// Let's review the client certificate.
cert, err := x509.ParseCertificate(rawCerts[0])
if err != nil {
return fmt.Errorf("could not validate peer certificate: %w", err)
}
for _, name := range cert.DNSNames {
if choice.Contains(name, c.TLSAllowedDNSNames) {
return nil
}
}
return fmt.Errorf("peer certificate not in allowed DNS Name list: %v", cert.DNSNames)
}

View file

@ -0,0 +1,614 @@
package tls_test
import (
cryptotls "crypto/tls"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/stretchr/testify/require"
"github.com/influxdata/telegraf/plugins/common/tls"
"github.com/influxdata/telegraf/testutil"
)
var pki = testutil.NewPKI("../../../testutil/pki")
func TestClientConfig(t *testing.T) {
tests := []struct {
name string
client tls.ClientConfig
expNil bool
expErr bool
}{
{
name: "unset",
client: tls.ClientConfig{},
expNil: true,
},
{
name: "success",
client: tls.ClientConfig{
TLSCA: pki.CACertPath(),
TLSCert: pki.ClientCertPath(),
TLSKey: pki.ClientKeyPath(),
},
},
{
name: "success with tls key password set",
client: tls.ClientConfig{
TLSCA: pki.CACertPath(),
TLSCert: pki.ClientCertPath(),
TLSKey: pki.ClientKeyPath(),
TLSKeyPwd: "",
},
},
{
name: "success with unencrypted pkcs#8 key",
client: tls.ClientConfig{
TLSCA: pki.CACertPath(),
TLSCert: pki.ClientCertPath(),
TLSKey: pki.ClientPKCS8KeyPath(),
},
},
{
name: "encrypted pkcs#8 key but missing password",
client: tls.ClientConfig{
TLSCA: pki.CACertPath(),
TLSCert: pki.ClientCertPath(),
TLSKey: pki.ClientEncPKCS8KeyPath(),
},
expNil: true,
expErr: true,
},
{
name: "encrypted pkcs#8 key and incorrect password",
client: tls.ClientConfig{
TLSCA: pki.CACertPath(),
TLSCert: pki.ClientCertPath(),
TLSKey: pki.ClientEncPKCS8KeyPath(),
TLSKeyPwd: "incorrect",
},
expNil: true,
expErr: true,
},
{
name: "success with encrypted pkcs#8 key and password set",
client: tls.ClientConfig{
TLSCA: pki.CACertPath(),
TLSCert: pki.ClientCertPath(),
TLSKey: pki.ClientEncPKCS8KeyPath(),
TLSKeyPwd: "changeme",
},
},
{
name: "error with encrypted pkcs#1 key and password set",
client: tls.ClientConfig{
TLSCA: pki.CACertPath(),
TLSCert: pki.ClientCertPath(),
TLSKey: pki.ClientEncKeyPath(),
TLSKeyPwd: "changeme",
},
expNil: true,
expErr: true,
},
{
name: "invalid ca",
client: tls.ClientConfig{
TLSCA: pki.ClientKeyPath(),
TLSCert: pki.ClientCertPath(),
TLSKey: pki.ClientKeyPath(),
},
expNil: true,
expErr: true,
},
{
name: "missing ca is okay",
client: tls.ClientConfig{
TLSCert: pki.ClientCertPath(),
TLSKey: pki.ClientKeyPath(),
},
},
{
name: "invalid cert",
client: tls.ClientConfig{
TLSCA: pki.CACertPath(),
TLSCert: pki.ClientKeyPath(),
TLSKey: pki.ClientKeyPath(),
},
expNil: true,
expErr: true,
},
{
name: "missing cert skips client keypair",
client: tls.ClientConfig{
TLSCA: pki.CACertPath(),
TLSKey: pki.ClientKeyPath(),
},
expNil: false,
expErr: false,
},
{
name: "missing key skips client keypair",
client: tls.ClientConfig{
TLSCA: pki.CACertPath(),
TLSCert: pki.ClientCertPath(),
},
expNil: false,
expErr: false,
},
{
name: "support deprecated ssl field names",
client: tls.ClientConfig{
SSLCA: pki.CACertPath(),
SSLCert: pki.ClientCertPath(),
SSLKey: pki.ClientKeyPath(),
},
},
{
name: "set SNI server name",
client: tls.ClientConfig{
ServerName: "foo.example.com",
},
expNil: false,
expErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tlsConfig, err := tt.client.TLSConfig()
if !tt.expNil {
require.NotNil(t, tlsConfig)
} else {
require.Nil(t, tlsConfig)
}
if !tt.expErr {
require.NoError(t, err)
} else {
require.Error(t, err)
}
})
}
}
func TestServerConfig(t *testing.T) {
tests := []struct {
name string
server tls.ServerConfig
expNil bool
expErr bool
}{
{
name: "unset",
server: tls.ServerConfig{},
expNil: true,
},
{
name: "success",
server: tls.ServerConfig{
TLSCert: pki.ServerCertPath(),
TLSKey: pki.ServerKeyPath(),
TLSAllowedCACerts: []string{pki.CACertPath()},
TLSCipherSuites: []string{pki.CipherSuite()},
TLSAllowedDNSNames: []string{"localhost", "127.0.0.1"},
TLSMinVersion: pki.TLSMinVersion(),
TLSMaxVersion: pki.TLSMaxVersion(),
},
},
{
name: "success with tls key password set",
server: tls.ServerConfig{
TLSCert: pki.ServerCertPath(),
TLSKey: pki.ServerKeyPath(),
TLSKeyPwd: "",
TLSAllowedCACerts: []string{pki.CACertPath()},
TLSCipherSuites: []string{pki.CipherSuite()},
TLSMinVersion: pki.TLSMinVersion(),
TLSMaxVersion: pki.TLSMaxVersion(),
},
},
{
name: "missing tls cipher suites is okay",
server: tls.ServerConfig{
TLSCert: pki.ServerCertPath(),
TLSKey: pki.ServerKeyPath(),
TLSAllowedCACerts: []string{pki.CACertPath()},
TLSCipherSuites: []string{pki.CipherSuite()},
},
},
{
name: "missing tls max version is okay",
server: tls.ServerConfig{
TLSCert: pki.ServerCertPath(),
TLSKey: pki.ServerKeyPath(),
TLSAllowedCACerts: []string{pki.CACertPath()},
TLSCipherSuites: []string{pki.CipherSuite()},
TLSMaxVersion: pki.TLSMaxVersion(),
},
},
{
name: "missing tls min version is okay",
server: tls.ServerConfig{
TLSCert: pki.ServerCertPath(),
TLSKey: pki.ServerKeyPath(),
TLSAllowedCACerts: []string{pki.CACertPath()},
TLSCipherSuites: []string{pki.CipherSuite()},
TLSMinVersion: pki.TLSMinVersion(),
},
},
{
name: "missing tls min/max versions is okay",
server: tls.ServerConfig{
TLSCert: pki.ServerCertPath(),
TLSKey: pki.ServerKeyPath(),
TLSAllowedCACerts: []string{pki.CACertPath()},
TLSCipherSuites: []string{pki.CipherSuite()},
},
},
{
name: "invalid ca",
server: tls.ServerConfig{
TLSCert: pki.ServerCertPath(),
TLSKey: pki.ServerKeyPath(),
TLSAllowedCACerts: []string{pki.ServerKeyPath()},
},
expNil: true,
expErr: true,
},
{
name: "missing allowed ca is okay",
server: tls.ServerConfig{
TLSCert: pki.ServerCertPath(),
TLSKey: pki.ServerKeyPath(),
},
expNil: true,
expErr: true,
},
{
name: "invalid cert",
server: tls.ServerConfig{
TLSCert: pki.ServerKeyPath(),
TLSKey: pki.ServerKeyPath(),
TLSAllowedCACerts: []string{pki.CACertPath()},
},
expNil: true,
expErr: true,
},
{
name: "missing cert",
server: tls.ServerConfig{
TLSKey: pki.ServerKeyPath(),
TLSAllowedCACerts: []string{pki.CACertPath()},
},
expNil: true,
expErr: true,
},
{
name: "missing key",
server: tls.ServerConfig{
TLSCert: pki.ServerCertPath(),
TLSAllowedCACerts: []string{pki.CACertPath()},
},
expNil: true,
expErr: true,
},
{
name: "invalid cipher suites",
server: tls.ServerConfig{
TLSCert: pki.ServerCertPath(),
TLSKey: pki.ServerKeyPath(),
TLSAllowedCACerts: []string{pki.CACertPath()},
TLSCipherSuites: []string{pki.CACertPath()},
},
expNil: true,
expErr: true,
},
{
name: "TLS Max Version less than TLS Min version",
server: tls.ServerConfig{
TLSCert: pki.ServerCertPath(),
TLSKey: pki.ServerKeyPath(),
TLSAllowedCACerts: []string{pki.CACertPath()},
TLSCipherSuites: []string{pki.CACertPath()},
TLSMinVersion: pki.TLSMaxVersion(),
TLSMaxVersion: pki.TLSMinVersion(),
},
expNil: true,
expErr: true,
},
{
name: "invalid tls min version",
server: tls.ServerConfig{
TLSCert: pki.ServerCertPath(),
TLSKey: pki.ServerKeyPath(),
TLSAllowedCACerts: []string{pki.CACertPath()},
TLSCipherSuites: []string{pki.CipherSuite()},
TLSMinVersion: pki.ServerKeyPath(),
TLSMaxVersion: pki.TLSMaxVersion(),
},
expNil: true,
expErr: true,
},
{
name: "invalid tls max version",
server: tls.ServerConfig{
TLSCert: pki.ServerCertPath(),
TLSKey: pki.ServerKeyPath(),
TLSAllowedCACerts: []string{pki.CACertPath()},
TLSCipherSuites: []string{pki.CACertPath()},
TLSMinVersion: pki.TLSMinVersion(),
TLSMaxVersion: pki.ServerCertPath(),
},
expNil: true,
expErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tlsConfig, err := tt.server.TLSConfig()
if !tt.expNil {
require.NotNil(t, tlsConfig)
}
if !tt.expErr {
require.NoError(t, err)
}
})
}
}
func TestConnect(t *testing.T) {
clientConfig := tls.ClientConfig{
TLSCA: pki.CACertPath(),
TLSCert: pki.ClientCertPath(),
TLSKey: pki.ClientKeyPath(),
}
serverConfig := tls.ServerConfig{
TLSCert: pki.ServerCertPath(),
TLSKey: pki.ServerKeyPath(),
TLSAllowedCACerts: []string{pki.CACertPath()},
TLSAllowedDNSNames: []string{"localhost", "127.0.0.1"},
}
serverTLSConfig, err := serverConfig.TLSConfig()
require.NoError(t, err)
ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK)
}))
ts.TLS = serverTLSConfig
ts.StartTLS()
defer ts.Close()
clientTLSConfig, err := clientConfig.TLSConfig()
require.NoError(t, err)
client := http.Client{
Transport: &http.Transport{
TLSClientConfig: clientTLSConfig,
},
Timeout: 10 * time.Second,
}
resp, err := client.Get(ts.URL)
require.NoError(t, err)
defer resp.Body.Close()
require.Equal(t, 200, resp.StatusCode)
}
func TestConnectClientMinTLSVersion(t *testing.T) {
serverConfig := tls.ServerConfig{
TLSCert: pki.ServerCertPath(),
TLSKey: pki.ServerKeyPath(),
TLSAllowedCACerts: []string{pki.CACertPath()},
TLSAllowedDNSNames: []string{"localhost", "127.0.0.1"},
}
tests := []struct {
name string
cfg tls.ClientConfig
}{
{
name: "TLS version default",
cfg: tls.ClientConfig{
TLSCA: pki.CACertPath(),
TLSCert: pki.ClientCertPath(),
TLSKey: pki.ClientKeyPath(),
},
},
{
name: "TLS version 1.0",
cfg: tls.ClientConfig{
TLSCA: pki.CACertPath(),
TLSCert: pki.ClientCertPath(),
TLSKey: pki.ClientKeyPath(),
TLSMinVersion: "TLS10",
},
},
{
name: "TLS version 1.1",
cfg: tls.ClientConfig{
TLSCA: pki.CACertPath(),
TLSCert: pki.ClientCertPath(),
TLSKey: pki.ClientKeyPath(),
TLSMinVersion: "TLS11",
},
},
{
name: "TLS version 1.2",
cfg: tls.ClientConfig{
TLSCA: pki.CACertPath(),
TLSCert: pki.ClientCertPath(),
TLSKey: pki.ClientKeyPath(),
TLSMinVersion: "TLS12",
},
},
{
name: "TLS version 1.3",
cfg: tls.ClientConfig{
TLSCA: pki.CACertPath(),
TLSCert: pki.ClientCertPath(),
TLSKey: pki.ClientKeyPath(),
TLSMinVersion: "TLS13",
},
},
}
tlsVersions := []uint16{
cryptotls.VersionTLS10,
cryptotls.VersionTLS11,
cryptotls.VersionTLS12,
cryptotls.VersionTLS13,
}
tlsVersionNames := []string{
"TLS 1.0",
"TLS 1.1",
"TLS 1.2",
"TLS 1.3",
}
for _, tt := range tests {
clientTLSConfig, err := tt.cfg.TLSConfig()
require.NoError(t, err)
client := http.Client{
Transport: &http.Transport{
TLSClientConfig: clientTLSConfig,
},
Timeout: 1 * time.Second,
}
clientMinVersion := clientTLSConfig.MinVersion
if tt.cfg.TLSMinVersion == "" {
clientMinVersion = tls.TLSMinVersionDefault
}
for i, serverTLSMaxVersion := range tlsVersions {
serverVersionName := tlsVersionNames[i]
t.Run(tt.name+" vs "+serverVersionName, func(t *testing.T) {
// Constrain the server's maximum TLS version
serverTLSConfig, err := serverConfig.TLSConfig()
require.NoError(t, err)
serverTLSConfig.MinVersion = cryptotls.VersionTLS10
serverTLSConfig.MaxVersion = serverTLSMaxVersion
// Start the server
ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK)
}))
ts.TLS = serverTLSConfig
ts.StartTLS()
// Do the connection and cleanup
resp, err := client.Get(ts.URL)
ts.Close()
// Things should fail if the currently tested "serverTLSMaxVersion"
// is below the client minimum version.
if serverTLSMaxVersion < clientMinVersion {
require.ErrorContains(t, err, "tls: protocol version not supported")
} else {
require.NoErrorf(t, err, "server=%v client=%v", serverTLSMaxVersion, clientMinVersion)
require.Equal(t, 200, resp.StatusCode)
resp.Body.Close()
}
})
}
}
}
func TestConnectClientInvalidMinTLSVersion(t *testing.T) {
clientConfig := tls.ClientConfig{
TLSCA: pki.CACertPath(),
TLSCert: pki.ClientCertPath(),
TLSKey: pki.ClientKeyPath(),
TLSMinVersion: "garbage",
}
_, err := clientConfig.TLSConfig()
expected := `could not parse tls min version "garbage": unsupported version "garbage" (available: TLS10,TLS11,TLS12,TLS13)`
require.EqualError(t, err, expected)
}
func TestConnectWrongDNS(t *testing.T) {
clientConfig := tls.ClientConfig{
TLSCA: pki.CACertPath(),
TLSCert: pki.ClientCertPath(),
TLSKey: pki.ClientKeyPath(),
}
serverConfig := tls.ServerConfig{
TLSCert: pki.ServerCertPath(),
TLSKey: pki.ServerKeyPath(),
TLSAllowedCACerts: []string{pki.CACertPath()},
TLSAllowedDNSNames: []string{"localhos", "127.0.0.2"},
}
serverTLSConfig, err := serverConfig.TLSConfig()
require.NoError(t, err)
ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK)
}))
ts.TLS = serverTLSConfig
ts.StartTLS()
defer ts.Close()
clientTLSConfig, err := clientConfig.TLSConfig()
require.NoError(t, err)
client := http.Client{
Transport: &http.Transport{
TLSClientConfig: clientTLSConfig,
},
Timeout: 10 * time.Second,
}
resp, err := client.Get(ts.URL)
require.Error(t, err)
if resp != nil {
err = resp.Body.Close()
require.NoError(t, err)
}
}
func TestEnableFlagAuto(t *testing.T) {
cfgEmpty := tls.ClientConfig{}
cfg, err := cfgEmpty.TLSConfig()
require.NoError(t, err)
require.Nil(t, cfg)
cfgSet := tls.ClientConfig{InsecureSkipVerify: true}
cfg, err = cfgSet.TLSConfig()
require.NoError(t, err)
require.NotNil(t, cfg)
}
func TestEnableFlagDisabled(t *testing.T) {
enabled := false
cfgSet := tls.ClientConfig{
InsecureSkipVerify: true,
Enable: &enabled,
}
cfg, err := cfgSet.TLSConfig()
require.NoError(t, err)
require.Nil(t, cfg)
}
func TestEnableFlagEnabled(t *testing.T) {
enabled := true
cfgSet := tls.ClientConfig{Enable: &enabled}
cfg, err := cfgSet.TLSConfig()
require.NoError(t, err)
require.NotNil(t, cfg)
expected := &cryptotls.Config{}
require.Equal(t, expected, cfg)
}

112
plugins/common/tls/utils.go Normal file
View file

@ -0,0 +1,112 @@
package tls
import (
"errors"
"fmt"
"sort"
"strings"
)
var ErrCipherUnsupported = errors.New("unsupported cipher")
// InsecureCiphers returns the list of insecure ciphers among the list of given ciphers
func InsecureCiphers(ciphers []string) []string {
var insecure []string
for _, c := range ciphers {
cipher := strings.ToUpper(c)
if _, ok := tlsCipherMapInsecure[cipher]; ok {
insecure = append(insecure, c)
}
}
return insecure
}
// Ciphers returns the list of supported ciphers
func Ciphers() (secure, insecure []string) {
for c := range tlsCipherMapSecure {
secure = append(secure, c)
}
for c := range tlsCipherMapInsecure {
insecure = append(insecure, c)
}
return secure, insecure
}
// ParseCiphers returns a `[]uint16` by received `[]string` key that represents ciphers from crypto/tls.
// If some of ciphers in received list doesn't exists ParseCiphers returns nil with error
func ParseCiphers(ciphers []string) ([]uint16, error) {
suites := make([]uint16, 0)
added := make(map[uint16]bool, len(ciphers))
for _, c := range ciphers {
// Handle meta-keywords
switch c {
case "all":
for _, id := range tlsCipherMapInsecure {
if added[id] {
continue
}
suites = append(suites, id)
added[id] = true
}
for _, id := range tlsCipherMapSecure {
if added[id] {
continue
}
suites = append(suites, id)
added[id] = true
}
case "insecure":
for _, id := range tlsCipherMapInsecure {
if added[id] {
continue
}
suites = append(suites, id)
added[id] = true
}
case "secure":
for _, id := range tlsCipherMapSecure {
if added[id] {
continue
}
suites = append(suites, id)
added[id] = true
}
default:
cipher := strings.ToUpper(c)
id, ok := tlsCipherMapSecure[cipher]
if !ok {
idInsecure, ok := tlsCipherMapInsecure[cipher]
if !ok {
return nil, fmt.Errorf("%q %w", cipher, ErrCipherUnsupported)
}
id = idInsecure
}
if added[id] {
continue
}
suites = append(suites, id)
added[id] = true
}
}
return suites, nil
}
// ParseTLSVersion returns a `uint16` by received version string key that represents tls version from crypto/tls.
// If version isn't supported ParseTLSVersion returns 0 with error
func ParseTLSVersion(version string) (uint16, error) {
if v, ok := tlsVersionMap[version]; ok {
return v, nil
}
available := make([]string, 0, len(tlsVersionMap))
for n := range tlsVersionMap {
available = append(available, n)
}
sort.Strings(available)
return 0, fmt.Errorf("unsupported version %q (available: %s)", version, strings.Join(available, ","))
}

View file

@ -0,0 +1,263 @@
package yangmodel
import (
"encoding/base64"
"encoding/binary"
"errors"
"fmt"
"math"
"os"
"path/filepath"
"strconv"
"github.com/openconfig/goyang/pkg/yang"
)
var (
ErrInsufficientData = errors.New("insufficient data")
ErrNotFound = errors.New("no such node")
)
type Decoder struct {
modules map[string]*yang.Module
rootNodes map[string][]yang.Node
}
func NewDecoder(paths ...string) (*Decoder, error) {
modules := yang.NewModules()
modules.ParseOptions.IgnoreSubmoduleCircularDependencies = true
var moduleFiles []string
modulePaths := paths
unresolved := paths
for {
var newlyfound []string
for _, path := range unresolved {
entries, err := os.ReadDir(path)
if err != nil {
return nil, fmt.Errorf("reading directory %q failed: %w", path, err)
}
for _, entry := range entries {
info, err := entry.Info()
if err != nil {
fmt.Printf("Couldn't get info for %q: %v", entry.Name(), err)
continue
}
if info.Mode()&os.ModeSymlink != 0 {
target, err := filepath.EvalSymlinks(entry.Name())
if err != nil {
fmt.Printf("Couldn't evaluate symbolic links for %q: %v", entry.Name(), err)
continue
}
info, err = os.Lstat(target)
if err != nil {
fmt.Printf("Couldn't stat target %v: %v", target, err)
continue
}
}
newPath := filepath.Join(path, info.Name())
if info.IsDir() {
newlyfound = append(newlyfound, newPath)
continue
}
if info.Mode().IsRegular() && filepath.Ext(info.Name()) == ".yang" {
moduleFiles = append(moduleFiles, info.Name())
}
}
}
if len(newlyfound) == 0 {
break
}
modulePaths = append(modulePaths, newlyfound...)
unresolved = newlyfound
}
// Add the module paths
modules.AddPath(modulePaths...)
for _, fn := range moduleFiles {
if err := modules.Read(fn); err != nil {
fmt.Printf("reading file %q failed: %v\n", fn, err)
}
}
if errs := modules.Process(); len(errs) > 0 {
return nil, errors.Join(errs...)
}
// Get all root nodes defined in models with their origin. We require
// those nodes to later resolve paths to YANG model leaf nodes...
moduleLUT := make(map[string]*yang.Module)
moduleRootNodes := make(map[string][]yang.Node)
for _, m := range modules.Modules {
// Check if we processed the module already
if _, found := moduleLUT[m.Name]; found {
continue
}
// Create a module mapping for easily finding modules by name
moduleLUT[m.Name] = m
// Determine the origin defined in the module
var prefix string
for _, imp := range m.Import {
if imp.Name == "openconfig-extensions" {
prefix = imp.Name
if imp.Prefix != nil {
prefix = imp.Prefix.Name
}
break
}
}
var moduleOrigin string
if prefix != "" {
for _, e := range m.Extensions {
if e.Keyword == prefix+":origin" || e.Keyword == "origin" {
moduleOrigin = e.Argument
break
}
}
}
for _, u := range m.Uses {
root, err := yang.FindNode(m, u.Name)
if err != nil {
return nil, err
}
moduleRootNodes[moduleOrigin] = append(moduleRootNodes[moduleOrigin], root)
}
}
return &Decoder{modules: moduleLUT, rootNodes: moduleRootNodes}, nil
}
func (d *Decoder) FindLeaf(name, identifier string) (*yang.Leaf, error) {
// Get module name from the element
module, found := d.modules[name]
if !found {
return nil, fmt.Errorf("cannot find module %q", name)
}
for _, grp := range module.Grouping {
for _, leaf := range grp.Leaf {
if leaf.Name == identifier {
return leaf, nil
}
}
}
return nil, ErrNotFound
}
func DecodeLeafValue(leaf *yang.Leaf, value interface{}) (interface{}, error) {
schema := leaf.Type.YangType
// Ignore all non-string values as the types seem already converted...
s, ok := value.(string)
if !ok {
return value, nil
}
switch schema.Kind {
case yang.Ybinary:
// Binary values are encodes as base64 string, so decode the string
raw, err := base64.StdEncoding.DecodeString(s)
if err != nil {
return value, err
}
switch schema.Name {
case "ieeefloat32":
if len(raw) != 4 {
return raw, fmt.Errorf("%w, expected 4 but got %d bytes", ErrInsufficientData, len(raw))
}
return math.Float32frombits(binary.BigEndian.Uint32(raw)), nil
default:
return raw, nil
}
case yang.Yint8:
v, err := strconv.ParseInt(s, 10, 8)
if err != nil {
return value, fmt.Errorf("parsing %s %q failed: %w", yang.TypeKindToName[schema.Kind], s, err)
}
return int8(v), nil
case yang.Yint16:
v, err := strconv.ParseInt(s, 10, 16)
if err != nil {
return value, fmt.Errorf("parsing %s %q failed: %w", yang.TypeKindToName[schema.Kind], s, err)
}
return int16(v), nil
case yang.Yint32:
v, err := strconv.ParseInt(s, 10, 32)
if err != nil {
return value, fmt.Errorf("parsing %s %q failed: %w", yang.TypeKindToName[schema.Kind], s, err)
}
return int32(v), nil
case yang.Yint64:
v, err := strconv.ParseInt(s, 10, 64)
if err != nil {
return value, fmt.Errorf("parsing %s %q failed: %w", yang.TypeKindToName[schema.Kind], s, err)
}
return v, nil
case yang.Yuint8:
v, err := strconv.ParseUint(s, 10, 8)
if err != nil {
return value, fmt.Errorf("parsing %s %q failed: %w", yang.TypeKindToName[schema.Kind], s, err)
}
return uint8(v), nil
case yang.Yuint16:
v, err := strconv.ParseUint(s, 10, 16)
if err != nil {
return value, fmt.Errorf("parsing %s %q failed: %w", yang.TypeKindToName[schema.Kind], s, err)
}
return uint16(v), nil
case yang.Yuint32:
v, err := strconv.ParseUint(s, 10, 32)
if err != nil {
return value, fmt.Errorf("parsing %s %q failed: %w", yang.TypeKindToName[schema.Kind], s, err)
}
return uint32(v), nil
case yang.Yuint64:
v, err := strconv.ParseUint(s, 10, 64)
if err != nil {
return value, fmt.Errorf("parsing %s %q failed: %w", yang.TypeKindToName[schema.Kind], s, err)
}
return v, nil
case yang.Ydecimal64:
v, err := strconv.ParseFloat(s, 64)
if err != nil {
return value, fmt.Errorf("parsing %s %q failed: %w", yang.TypeKindToName[schema.Kind], s, err)
}
return v, nil
}
return value, nil
}
func (d *Decoder) DecodeLeafElement(namespace, identifier string, value interface{}) (interface{}, error) {
leaf, err := d.FindLeaf(namespace, identifier)
if err != nil {
return nil, fmt.Errorf("finding %s failed: %w", identifier, err)
}
return DecodeLeafValue(leaf, value)
}
func (d *Decoder) DecodePathElement(origin, path string, value interface{}) (interface{}, error) {
rootNodes, found := d.rootNodes[origin]
if !found || len(rootNodes) == 0 {
return value, nil
}
for _, root := range rootNodes {
node, err := yang.FindNode(root, path)
if node == nil || err != nil {
// The path does not exist in this root node
continue
}
// We do expect a leaf node...
if leaf, ok := node.(*yang.Leaf); ok {
return DecodeLeafValue(leaf, value)
}
}
return value, nil
}