Adding upstream version 1.34.4.
Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
parent
e393c3af3f
commit
4978089aab
4963 changed files with 677545 additions and 0 deletions
189
plugins/common/adx/adx.go
Normal file
189
plugins/common/adx/adx.go
Normal file
|
@ -0,0 +1,189 @@
|
|||
package adx
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Azure/azure-kusto-go/kusto"
|
||||
kustoerrors "github.com/Azure/azure-kusto-go/kusto/data/errors"
|
||||
"github.com/Azure/azure-kusto-go/kusto/ingest"
|
||||
"github.com/Azure/azure-kusto-go/kusto/kql"
|
||||
|
||||
"github.com/influxdata/telegraf"
|
||||
"github.com/influxdata/telegraf/config"
|
||||
"github.com/influxdata/telegraf/internal"
|
||||
)
|
||||
|
||||
const (
	// Supported metrics-grouping modes: one table per metric name, or all
	// metrics routed into a single configured table.
	TablePerMetric = "tablepermetric"
	SingleTable    = "singletable"
	// These control the amount of memory we use when ingesting blobs
	bufferSize = 1 << 20 // 1 MiB
	maxBuffers = 5
	// Supported ingestion strategies for Azure Data Explorer.
	ManagedIngestion = "managed"
	QueuedIngestion  = "queued"
)
|
||||
|
||||
// Config holds the common Azure Data Explorer (Kusto) connection and
// ingestion settings shared by ADX-based plugins.
type Config struct {
	Endpoint        string          `toml:"endpoint_url"`          // Kusto cluster URL; required
	Database        string          `toml:"database"`              // target database; required
	Timeout         config.Duration `toml:"timeout"`               // per-push timeout; NewClient defaults it to 20s
	MetricsGrouping string          `toml:"metrics_grouping_type"` // TablePerMetric (default) or SingleTable
	TableName       string          `toml:"table_name"`            // required when MetricsGrouping is SingleTable
	CreateTables    bool            `toml:"create_tables"`         // create table and JSON mapping on first use
	IngestionType   string          `toml:"ingestion_type"`        // ManagedIngestion or QueuedIngestion (default)
}
|
||||
|
||||
// Client wraps a Kusto connection together with a per-table cache of
// ingestors created on demand by getMetricIngestor.
type Client struct {
	cfg       *Config
	client    *kusto.Client              // management connection, also used to build ingestors
	ingestors map[string]ingest.Ingestor // lazily created, keyed by table name
	logger    telegraf.Logger
}
|
||||
|
||||
func (cfg *Config) NewClient(app string, log telegraf.Logger) (*Client, error) {
|
||||
if cfg.Endpoint == "" {
|
||||
return nil, errors.New("endpoint configuration cannot be empty")
|
||||
}
|
||||
if cfg.Database == "" {
|
||||
return nil, errors.New("database configuration cannot be empty")
|
||||
}
|
||||
|
||||
cfg.MetricsGrouping = strings.ToLower(cfg.MetricsGrouping)
|
||||
if cfg.MetricsGrouping == SingleTable && cfg.TableName == "" {
|
||||
return nil, errors.New("table name cannot be empty for SingleTable metrics grouping type")
|
||||
}
|
||||
|
||||
if cfg.MetricsGrouping == "" {
|
||||
cfg.MetricsGrouping = TablePerMetric
|
||||
}
|
||||
|
||||
if cfg.MetricsGrouping != SingleTable && cfg.MetricsGrouping != TablePerMetric {
|
||||
return nil, errors.New("metrics grouping type is not valid")
|
||||
}
|
||||
|
||||
if cfg.Timeout == 0 {
|
||||
cfg.Timeout = config.Duration(20 * time.Second)
|
||||
}
|
||||
|
||||
switch cfg.IngestionType {
|
||||
case "":
|
||||
cfg.IngestionType = QueuedIngestion
|
||||
case ManagedIngestion, QueuedIngestion:
|
||||
// Do nothing as those are valid
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown ingestion type %q", cfg.IngestionType)
|
||||
}
|
||||
|
||||
conn := kusto.NewConnectionStringBuilder(cfg.Endpoint).WithDefaultAzureCredential()
|
||||
conn.SetConnectorDetails("Telegraf", internal.ProductToken(), app, "", false, "")
|
||||
client, err := kusto.New(conn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &Client{
|
||||
cfg: cfg,
|
||||
ingestors: make(map[string]ingest.Ingestor),
|
||||
logger: log,
|
||||
client: client,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Clean up and close the ingestor
|
||||
func (adx *Client) Close() error {
|
||||
var errs []error
|
||||
for _, v := range adx.ingestors {
|
||||
if err := v.Close(); err != nil {
|
||||
// accumulate errors while closing ingestors
|
||||
errs = append(errs, err)
|
||||
}
|
||||
}
|
||||
if err := adx.client.Close(); err != nil {
|
||||
errs = append(errs, err)
|
||||
}
|
||||
|
||||
adx.client = nil
|
||||
adx.ingestors = nil
|
||||
|
||||
if len(errs) == 0 {
|
||||
return nil
|
||||
}
|
||||
// Combine errors into a single object and return the combined error
|
||||
return kustoerrors.GetCombinedError(errs...)
|
||||
}
|
||||
|
||||
func (adx *Client) PushMetrics(format ingest.FileOption, tableName string, metrics []byte) error {
|
||||
ctx := context.Background()
|
||||
ctx, cancel := context.WithTimeout(ctx, time.Duration(adx.cfg.Timeout))
|
||||
defer cancel()
|
||||
metricIngestor, err := adx.getMetricIngestor(ctx, tableName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
reader := bytes.NewReader(metrics)
|
||||
mapping := ingest.IngestionMappingRef(tableName+"_mapping", ingest.JSON)
|
||||
if metricIngestor != nil {
|
||||
if _, err := metricIngestor.FromReader(ctx, reader, format, mapping); err != nil {
|
||||
return fmt.Errorf("sending ingestion request to Azure Data Explorer for table %q failed: %w", tableName, err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (adx *Client) getMetricIngestor(ctx context.Context, tableName string) (ingest.Ingestor, error) {
|
||||
if ingestor := adx.ingestors[tableName]; ingestor != nil {
|
||||
return ingestor, nil
|
||||
}
|
||||
|
||||
if adx.cfg.CreateTables {
|
||||
if _, err := adx.client.Mgmt(ctx, adx.cfg.Database, createTableCommand(tableName)); err != nil {
|
||||
return nil, fmt.Errorf("creating table for %q failed: %w", tableName, err)
|
||||
}
|
||||
|
||||
if _, err := adx.client.Mgmt(ctx, adx.cfg.Database, createTableMappingCommand(tableName)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
// Create a new ingestor client for the table
|
||||
var ingestor ingest.Ingestor
|
||||
var err error
|
||||
switch strings.ToLower(adx.cfg.IngestionType) {
|
||||
case ManagedIngestion:
|
||||
ingestor, err = ingest.NewManaged(adx.client, adx.cfg.Database, tableName)
|
||||
case QueuedIngestion:
|
||||
ingestor, err = ingest.New(adx.client, adx.cfg.Database, tableName, ingest.WithStaticBuffer(bufferSize, maxBuffers))
|
||||
default:
|
||||
return nil, fmt.Errorf(`ingestion_type has to be one of %q or %q`, ManagedIngestion, QueuedIngestion)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("creating ingestor for %q failed: %w", tableName, err)
|
||||
}
|
||||
adx.ingestors[tableName] = ingestor
|
||||
|
||||
return ingestor, nil
|
||||
}
|
||||
|
||||
func createTableCommand(table string) kusto.Statement {
|
||||
builder := kql.New(`.create-merge table ['`).AddTable(table).AddLiteral(`'] `)
|
||||
builder.AddLiteral(`(['fields']:dynamic, ['name']:string, ['tags']:dynamic, ['timestamp']:datetime);`)
|
||||
|
||||
return builder
|
||||
}
|
||||
|
||||
func createTableMappingCommand(table string) kusto.Statement {
|
||||
builder := kql.New(`.create-or-alter table ['`).AddTable(table).AddLiteral(`'] `)
|
||||
builder.AddLiteral(`ingestion json mapping '`).AddTable(table + "_mapping").AddLiteral(`' `)
|
||||
builder.AddLiteral(`'[{"column":"fields", `)
|
||||
builder.AddLiteral(`"Properties":{"Path":"$[\'fields\']"}},{"column":"name", `)
|
||||
builder.AddLiteral(`"Properties":{"Path":"$[\'name\']"}},{"column":"tags", `)
|
||||
builder.AddLiteral(`"Properties":{"Path":"$[\'tags\']"}},{"column":"timestamp", `)
|
||||
builder.AddLiteral(`"Properties":{"Path":"$[\'timestamp\']"}}]'`)
|
||||
|
||||
return builder
|
||||
}
|
219
plugins/common/adx/adx_test.go
Normal file
219
plugins/common/adx/adx_test.go
Normal file
|
@ -0,0 +1,219 @@
|
|||
package adx
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Azure/azure-kusto-go/kusto"
|
||||
"github.com/Azure/azure-kusto-go/kusto/ingest"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/influxdata/telegraf"
|
||||
"github.com/influxdata/telegraf/config"
|
||||
serializers_json "github.com/influxdata/telegraf/plugins/serializers/json"
|
||||
"github.com/influxdata/telegraf/testutil"
|
||||
)
|
||||
|
||||
func TestInitBlankEndpointData(t *testing.T) {
|
||||
plugin := Config{
|
||||
Endpoint: "",
|
||||
Database: "mydb",
|
||||
}
|
||||
|
||||
_, err := plugin.NewClient("TestKusto.Telegraf", nil)
|
||||
require.Error(t, err)
|
||||
require.Equal(t, "endpoint configuration cannot be empty", err.Error())
|
||||
}
// TestQueryConstruction pins the exact KQL emitted for table creation and
// for the JSON ingestion mapping, so unintended schema changes are caught.
func TestQueryConstruction(t *testing.T) {
	const tableName = "mytable"
	const expectedCreate = `.create-merge table ['mytable'] (['fields']:dynamic, ['name']:string, ['tags']:dynamic, ['timestamp']:datetime);`
	const expectedMapping = `` +
		`.create-or-alter table ['mytable'] ingestion json mapping 'mytable_mapping' '[{"column":"fields", ` +
		`"Properties":{"Path":"$[\'fields\']"}},{"column":"name", "Properties":{"Path":"$[\'name\']"}},{"column":"tags", ` +
		`"Properties":{"Path":"$[\'tags\']"}},{"column":"timestamp", "Properties":{"Path":"$[\'timestamp\']"}}]'`
	require.Equal(t, expectedCreate, createTableCommand(tableName).String())
	require.Equal(t, expectedMapping, createTableMappingCommand(tableName).String())
}
// TestGetMetricIngestor verifies that a previously injected ingestor is
// returned from the cache without attempting to create a new one.
func TestGetMetricIngestor(t *testing.T) {
	plugin := Client{
		logger: testutil.Logger{},
		client: kusto.NewMockClient(),
		cfg: &Config{
			Database:      "mydb",
			IngestionType: QueuedIngestion,
		},
		ingestors: map[string]ingest.Ingestor{"test1": &fakeIngestor{}},
	}

	ingestor, err := plugin.getMetricIngestor(t.Context(), "test1")
	require.NoError(t, err)
	require.NotNil(t, ingestor)
}
// TestGetMetricIngestorNoIngester runs the cache lookup with a config that
// has no database set.
// NOTE(review): despite the name, an ingestor for "test1" is pre-injected
// and then requested, so the no-ingestor creation path is never exercised —
// confirm whether a missing table name was intended here.
func TestGetMetricIngestorNoIngester(t *testing.T) {
	plugin := Client{
		logger: testutil.Logger{},
		client: kusto.NewMockClient(),
		cfg: &Config{
			IngestionType: QueuedIngestion,
		},
		ingestors: map[string]ingest.Ingestor{"test1": &fakeIngestor{}},
	}

	ingestor, err := plugin.getMetricIngestor(t.Context(), "test1")
	require.NoError(t, err)
	require.NotNil(t, ingestor)
}
// TestPushMetrics pushes one serialized metric through a fake ingestor and
// expects the call to succeed end to end.
func TestPushMetrics(t *testing.T) {
	plugin := Client{
		logger: testutil.Logger{},
		client: kusto.NewMockClient(),
		cfg: &Config{
			Database:      "mydb",
			Endpoint:      "https://ingest-test.westus.kusto.windows.net",
			IngestionType: QueuedIngestion,
		},
		ingestors: map[string]ingest.Ingestor{"test1": &fakeIngestor{}},
	}

	metrics := []byte(`{"fields": {"value": 1}, "name": "test1", "tags": {"tag1": "value1"}, "timestamp": "2021-01-01T00:00:00Z"}`)
	require.NoError(t, plugin.PushMetrics(ingest.FileFormat(ingest.JSON), "test1", metrics))
}
// TestPushMetricsOutputs drives the serialize-and-push pipeline through a
// fake ingestor for several config permutations and verifies the exact JSON
// document the ingestor receives.
func TestPushMetricsOutputs(t *testing.T) {
	testCases := []struct {
		name            string
		inputMetric     []telegraf.Metric
		metricsGrouping string
		createTables    bool
		ingestionType   string
	}{
		{
			name:            "Valid metric",
			inputMetric:     testutil.MockMetrics(),
			createTables:    true,
			metricsGrouping: TablePerMetric,
		},
		{
			name:            "Don't create tables'",
			inputMetric:     testutil.MockMetrics(),
			createTables:    false,
			metricsGrouping: TablePerMetric,
		},
		{
			name:            "SingleTable metric grouping type",
			inputMetric:     testutil.MockMetrics(),
			createTables:    true,
			metricsGrouping: SingleTable,
		},
		{
			name:            "Valid metric managed ingestion",
			inputMetric:     testutil.MockMetrics(),
			createTables:    true,
			metricsGrouping: TablePerMetric,
			ingestionType:   ManagedIngestion,
		},
	}
	// Expected content of the ingested document; the timestamp is the fixed
	// testutil mock-metric time expressed as epoch seconds.
	var expectedMetric = map[string]interface{}{
		"metricName": "test1",
		"fields": map[string]interface{}{
			"value": 1.0,
		},
		"tags": map[string]interface{}{
			"tag1": "value1",
		},
		"timestamp": float64(time.Date(2009, time.November, 10, 23, 0, 0, 0, time.UTC).UnixNano() / int64(time.Second)),
	}
	for _, tC := range testCases {
		t.Run(tC.name, func(t *testing.T) {
			ingestionType := "queued"
			if tC.ingestionType != "" {
				ingestionType = tC.ingestionType
			}

			serializer := &serializers_json.Serializer{
				TimestampUnits:  config.Duration(time.Nanosecond),
				TimestampFormat: time.RFC3339Nano,
			}

			cfg := &Config{
				Endpoint:        "https://someendpoint.kusto.net",
				Database:        "databasename",
				MetricsGrouping: tC.metricsGrouping,
				TableName:       "test1",
				CreateTables:    tC.createTables,
				IngestionType:   ingestionType,
				Timeout:         config.Duration(20 * time.Second),
			}
			client, err := cfg.NewClient("telegraf", &testutil.Logger{})
			require.NoError(t, err)

			// Inject the ingestor
			ingestor := &fakeIngestor{}
			client.ingestors["test1"] = ingestor

			// NOTE(review): tC.inputMetric is never used below; the loop
			// serializes a fresh testutil.MockMetrics() instead — confirm.
			tableMetricGroups := make(map[string][]byte)
			mockmetrics := testutil.MockMetrics()
			for _, m := range mockmetrics {
				metricInBytes, err := serializer.Serialize(m)
				require.NoError(t, err)
				tableMetricGroups[m.Name()] = append(tableMetricGroups[m.Name()], metricInBytes...)
			}

			format := ingest.FileFormat(ingest.JSON)
			for tableName, tableMetrics := range tableMetricGroups {
				require.NoError(t, client.PushMetrics(format, tableName, tableMetrics))
				createdFakeIngestor := ingestor
				require.EqualValues(t, expectedMetric["metricName"], createdFakeIngestor.actualOutputMetric["name"])
				require.EqualValues(t, expectedMetric["fields"], createdFakeIngestor.actualOutputMetric["fields"])
				require.EqualValues(t, expectedMetric["tags"], createdFakeIngestor.actualOutputMetric["tags"])
				// The serializer emits RFC3339Nano; parse it back and compare
				// against the expected epoch-seconds value within a delta.
				timestampStr := createdFakeIngestor.actualOutputMetric["timestamp"].(string)
				parsedTime, err := time.Parse(time.RFC3339Nano, timestampStr)
				parsedTimeFloat := float64(parsedTime.UnixNano()) / 1e9
				require.NoError(t, err)
				require.InDelta(t, expectedMetric["timestamp"].(float64), parsedTimeFloat, testutil.DefaultDelta)
			}
		})
	}
}
// TestAlreadyClosed ensures Close succeeds on a client with no ingestors.
// NOTE(review): the name suggests double-close coverage, but Close is only
// invoked once here — confirm whether a second Close call was intended.
func TestAlreadyClosed(t *testing.T) {
	plugin := Client{
		logger: testutil.Logger{},
		cfg: &Config{
			IngestionType: QueuedIngestion,
		},
		client: kusto.NewMockClient(),
	}
	require.NoError(t, plugin.Close())
}
// fakeIngestor stands in for ingest.Ingestor in tests and captures the last
// JSON document passed to FromReader for later assertions.
type fakeIngestor struct {
	actualOutputMetric map[string]interface{}
}
|
||||
|
||||
func (f *fakeIngestor) FromReader(_ context.Context, reader io.Reader, _ ...ingest.FileOption) (*ingest.Result, error) {
|
||||
scanner := bufio.NewScanner(reader)
|
||||
scanner.Scan()
|
||||
firstLine := scanner.Text()
|
||||
err := json.Unmarshal([]byte(firstLine), &f.actualOutputMetric)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &ingest.Result{}, nil
|
||||
}
|
||||
|
||||
func (*fakeIngestor) FromFile(_ context.Context, _ string, _ ...ingest.FileOption) (*ingest.Result, error) {
|
||||
return &ingest.Result{}, nil
|
||||
}
|
||||
|
||||
func (*fakeIngestor) Close() error {
|
||||
return nil
|
||||
}
|
23
plugins/common/auth/basic_auth.go
Normal file
23
plugins/common/auth/basic_auth.go
Normal file
|
@ -0,0 +1,23 @@
|
|||
package auth
|
||||
|
||||
import (
|
||||
"crypto/subtle"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
type BasicAuth struct {
|
||||
Username string `toml:"username"`
|
||||
Password string `toml:"password"`
|
||||
}
|
||||
|
||||
func (b *BasicAuth) Verify(r *http.Request) bool {
|
||||
if b.Username == "" && b.Password == "" {
|
||||
return true
|
||||
}
|
||||
|
||||
username, password, ok := r.BasicAuth()
|
||||
|
||||
usernameComparison := subtle.ConstantTimeCompare([]byte(username), []byte(b.Username)) == 1
|
||||
passwordComparison := subtle.ConstantTimeCompare([]byte(password), []byte(b.Password)) == 1
|
||||
return ok && usernameComparison && passwordComparison
|
||||
}
|
34
plugins/common/auth/basic_auth_test.go
Normal file
34
plugins/common/auth/basic_auth_test.go
Normal file
|
@ -0,0 +1,34 @@
|
|||
package auth
|
||||
|
||||
import (
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestBasicAuth_VerifyWithCredentials(t *testing.T) {
|
||||
auth := BasicAuth{"username", "password"}
|
||||
|
||||
r := httptest.NewRequest("GET", "/github", nil)
|
||||
r.SetBasicAuth(auth.Username, auth.Password)
|
||||
|
||||
require.True(t, auth.Verify(r))
|
||||
}
|
||||
|
||||
func TestBasicAuth_VerifyWithoutCredentials(t *testing.T) {
|
||||
auth := BasicAuth{}
|
||||
|
||||
r := httptest.NewRequest("GET", "/github", nil)
|
||||
|
||||
require.True(t, auth.Verify(r))
|
||||
}
|
||||
|
||||
func TestBasicAuth_VerifyWithInvalidCredentials(t *testing.T) {
|
||||
auth := BasicAuth{"username", "password"}
|
||||
|
||||
r := httptest.NewRequest("GET", "/github", nil)
|
||||
r.SetBasicAuth("wrong-username", "wrong-password")
|
||||
|
||||
require.False(t, auth.Verify(r))
|
||||
}
|
84
plugins/common/aws/credentials.go
Normal file
84
plugins/common/aws/credentials.go
Normal file
|
@ -0,0 +1,84 @@
|
|||
package aws
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/aws/aws-sdk-go-v2/aws"
|
||||
"github.com/aws/aws-sdk-go-v2/config"
|
||||
"github.com/aws/aws-sdk-go-v2/credentials"
|
||||
"github.com/aws/aws-sdk-go-v2/credentials/stscreds"
|
||||
"github.com/aws/aws-sdk-go-v2/service/sts"
|
||||
)
|
||||
|
||||
// The endpoint_url supplied here is used for specific AWS service (Cloudwatch / Timestream / etc.)
|
||||
type CredentialConfig struct {
|
||||
Region string `toml:"region"`
|
||||
AccessKey string `toml:"access_key"`
|
||||
SecretKey string `toml:"secret_key"`
|
||||
RoleARN string `toml:"role_arn"`
|
||||
Profile string `toml:"profile"`
|
||||
Filename string `toml:"shared_credential_file"`
|
||||
Token string `toml:"token"`
|
||||
EndpointURL string `toml:"endpoint_url"`
|
||||
RoleSessionName string `toml:"role_session_name"`
|
||||
WebIdentityTokenFile string `toml:"web_identity_token_file"`
|
||||
}
|
||||
|
||||
func (c *CredentialConfig) Credentials() (aws.Config, error) {
|
||||
if c.RoleARN != "" {
|
||||
return c.configWithAssumeCredentials()
|
||||
}
|
||||
return c.configWithRootCredentials()
|
||||
}
|
||||
|
||||
func (c *CredentialConfig) configWithRootCredentials() (aws.Config, error) {
|
||||
options := []func(*config.LoadOptions) error{
|
||||
config.WithRegion(c.Region),
|
||||
}
|
||||
|
||||
if c.Profile != "" {
|
||||
options = append(options, config.WithSharedConfigProfile(c.Profile))
|
||||
}
|
||||
if c.Filename != "" {
|
||||
options = append(options, config.WithSharedCredentialsFiles([]string{c.Filename}))
|
||||
}
|
||||
|
||||
if c.AccessKey != "" || c.SecretKey != "" {
|
||||
provider := credentials.NewStaticCredentialsProvider(c.AccessKey, c.SecretKey, c.Token)
|
||||
options = append(options, config.WithCredentialsProvider(provider))
|
||||
}
|
||||
|
||||
return config.LoadDefaultConfig(context.Background(), options...)
|
||||
}
|
||||
|
||||
func (c *CredentialConfig) configWithAssumeCredentials() (aws.Config, error) {
|
||||
// To generate credentials using assumeRole, we need to create AWS STS client with the default AWS endpoint,
|
||||
defaultConfig, err := c.configWithRootCredentials()
|
||||
if err != nil {
|
||||
return aws.Config{}, err
|
||||
}
|
||||
|
||||
var provider aws.CredentialsProvider
|
||||
stsService := sts.NewFromConfig(defaultConfig)
|
||||
if c.WebIdentityTokenFile != "" {
|
||||
provider = stscreds.NewWebIdentityRoleProvider(
|
||||
stsService,
|
||||
c.RoleARN,
|
||||
stscreds.IdentityTokenFile(c.WebIdentityTokenFile),
|
||||
func(opts *stscreds.WebIdentityRoleOptions) {
|
||||
if c.RoleSessionName != "" {
|
||||
opts.RoleSessionName = c.RoleSessionName
|
||||
}
|
||||
},
|
||||
)
|
||||
} else {
|
||||
provider = stscreds.NewAssumeRoleProvider(stsService, c.RoleARN, func(opts *stscreds.AssumeRoleOptions) {
|
||||
if c.RoleSessionName != "" {
|
||||
opts.RoleSessionName = c.RoleSessionName
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
defaultConfig.Credentials = aws.NewCredentialsCache(provider)
|
||||
return defaultConfig, nil
|
||||
}
|
136
plugins/common/cookie/cookie.go
Normal file
136
plugins/common/cookie/cookie.go
Normal file
|
@ -0,0 +1,136 @@
|
|||
package cookie
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/cookiejar"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
clockutil "github.com/benbjohnson/clock"
|
||||
|
||||
"github.com/influxdata/telegraf"
|
||||
"github.com/influxdata/telegraf/config"
|
||||
)
|
// CookieAuthConfig implements cookie-based authentication for HTTP plugins:
// it sends a request to cookie_auth_url, keeps the returned cookies in the
// client's jar, and optionally re-authenticates on a fixed interval.
type CookieAuthConfig struct {
	URL    string `toml:"cookie_auth_url"`
	Method string `toml:"cookie_auth_method"` // defaults to POST in initializeClient

	Headers map[string]*config.Secret `toml:"cookie_auth_headers"`

	// HTTP Basic Auth Credentials
	Username string `toml:"cookie_auth_username"`
	Password string `toml:"cookie_auth_password"`

	Body    string          `toml:"cookie_auth_body"`
	Renewal config.Duration `toml:"cookie_auth_renewal"` // 0 disables periodic renewal

	client *http.Client
	// wg tracks the renewal goroutine; authRenewal calls Done on cancellation.
	wg sync.WaitGroup
}
|
||||
|
||||
func (c *CookieAuthConfig) Start(client *http.Client, log telegraf.Logger, clock clockutil.Clock) (err error) {
|
||||
if err := c.initializeClient(client); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// continual auth renewal if set
|
||||
if c.Renewal > 0 {
|
||||
ticker := clock.Ticker(time.Duration(c.Renewal))
|
||||
// this context is used in the tests only, it is to cancel the goroutine
|
||||
go c.authRenewal(context.Background(), ticker, log)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *CookieAuthConfig) initializeClient(client *http.Client) (err error) {
|
||||
c.client = client
|
||||
|
||||
if c.Method == "" {
|
||||
c.Method = http.MethodPost
|
||||
}
|
||||
|
||||
return c.auth()
|
||||
}
|
||||
|
||||
func (c *CookieAuthConfig) authRenewal(ctx context.Context, ticker *clockutil.Ticker, log telegraf.Logger) {
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
c.wg.Done()
|
||||
return
|
||||
case <-ticker.C:
|
||||
if err := c.auth(); err != nil && log != nil {
|
||||
log.Errorf("renewal failed for %q: %v", c.URL, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *CookieAuthConfig) auth() error {
|
||||
var err error
|
||||
|
||||
// everytime we auth we clear out the cookie jar to ensure that the cookie
|
||||
// is not used as a part of re-authing. The only way to empty or reset is
|
||||
// to create a new cookie jar.
|
||||
c.client.Jar, err = cookiejar.New(nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var body io.Reader
|
||||
if c.Body != "" {
|
||||
body = strings.NewReader(c.Body)
|
||||
}
|
||||
|
||||
req, err := http.NewRequest(c.Method, c.URL, body)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if c.Username != "" {
|
||||
req.SetBasicAuth(c.Username, c.Password)
|
||||
}
|
||||
|
||||
for k, v := range c.Headers {
|
||||
secret, err := v.Get()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
headerVal := secret.String()
|
||||
if strings.EqualFold(k, "host") {
|
||||
req.Host = headerVal
|
||||
} else {
|
||||
req.Header.Add(k, headerVal)
|
||||
}
|
||||
|
||||
secret.Destroy()
|
||||
}
|
||||
|
||||
resp, err := c.client.Do(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if resp.StatusCode < 200 || resp.StatusCode > 299 {
|
||||
return fmt.Errorf("cookie auth renewal received status code: %v (%v) [%v]",
|
||||
resp.StatusCode,
|
||||
http.StatusText(resp.StatusCode),
|
||||
string(respBody),
|
||||
)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
281
plugins/common/cookie/cookie_test.go
Normal file
281
plugins/common/cookie/cookie_test.go
Normal file
|
@ -0,0 +1,281 @@
|
|||
package cookie
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
clockutil "github.com/benbjohnson/clock"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/influxdata/telegraf/config"
|
||||
"github.com/influxdata/telegraf/testutil"
|
||||
)
|
||||
|
||||
const (
|
||||
reqUser = "testUser"
|
||||
reqPasswd = "testPassword"
|
||||
reqBody = "a body"
|
||||
reqHeaderKey = "hello"
|
||||
reqHeaderVal = "world"
|
||||
|
||||
authEndpointNoCreds = "/auth"
|
||||
authEndpointWithBasicAuth = "/authWithCreds"
|
||||
authEndpointWithBasicAuthOnlyUsername = "/authWithCredsUser"
|
||||
authEndpointWithBody = "/authWithBody"
|
||||
authEndpointWithHeader = "/authWithHeader"
|
||||
)
|
// fakeCookie is handed out by the fake server on every successful auth and
// must be presented on all other endpoints.
var fakeCookie = &http.Cookie{
	Name:  "test-cookie",
	Value: "this is an auth cookie",
}

// reqHeaderValSecret wraps reqHeaderVal for use as a cookie_auth_headers value.
var reqHeaderValSecret = config.NewSecret([]byte(reqHeaderVal))
// fakeServer couples the test HTTP server with a counter of successful
// authentications (the embedded *int32, accessed atomically).
type fakeServer struct {
	*httptest.Server
	*int32
}
// newFakeServer starts an HTTP server that mimics an authentication target:
// hits on the auth endpoints validate their respective credential style and
// hand out fakeCookie (incrementing the auth counter); every other path
// requires the cookie and returns a canned body.
func newFakeServer(t *testing.T) fakeServer {
	var c int32
	return fakeServer{
		Server: httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			authed := func() {
				atomic.AddInt32(&c, 1)        // increment auth counter
				http.SetCookie(w, fakeCookie) // set fake cookie
			}
			switch r.URL.Path {
			case authEndpointNoCreds:
				authed()
			case authEndpointWithHeader:
				// Require the configured header to match exactly.
				if !cmp.Equal(r.Header.Get(reqHeaderKey), reqHeaderVal) {
					w.WriteHeader(http.StatusUnauthorized)
					return
				}
				authed()
			case authEndpointWithBody:
				// Require the request body to match reqBody byte for byte.
				body, err := io.ReadAll(r.Body)
				if err != nil {
					w.WriteHeader(http.StatusInternalServerError)
					t.Error(err)
					return
				}
				if !cmp.Equal([]byte(reqBody), body) {
					w.WriteHeader(http.StatusUnauthorized)
					return
				}
				authed()
			case authEndpointWithBasicAuth:
				u, p, ok := r.BasicAuth()
				if !ok || u != reqUser || p != reqPasswd {
					w.WriteHeader(http.StatusUnauthorized)
					return
				}
				authed()
			case authEndpointWithBasicAuthOnlyUsername:
				// Username must match and the password must be empty.
				u, p, ok := r.BasicAuth()
				if !ok || u != reqUser || p != "" {
					w.WriteHeader(http.StatusUnauthorized)
					return
				}
				authed()
			default:
				// ensure cookie exists on request
				if _, err := r.Cookie(fakeCookie.Name); err != nil {
					w.WriteHeader(http.StatusForbidden)
					return
				}
				if _, err := w.Write([]byte("good test response")); err != nil {
					w.WriteHeader(http.StatusInternalServerError)
					t.Error(err)
					return
				}
			}
		})),
		int32: &c,
	}
}
// checkResp asserts that a request to a cookie-protected endpoint returns
// expCode and, on success, that the auth cookie was attached to the request.
func (s fakeServer) checkResp(t *testing.T, expCode int) {
	t.Helper()
	resp, err := s.Client().Get(s.URL + "/endpoint")
	require.NoError(t, err)
	defer resp.Body.Close()
	require.Equal(t, expCode, resp.StatusCode)

	if expCode == http.StatusOK {
		require.Len(t, resp.Request.Cookies(), 1)
		require.Equal(t, "test-cookie", resp.Request.Cookies()[0].Name)
	}
}
// checkAuthCount asserts that at least atLeast successful authentications
// have been recorded by the fake server.
func (s fakeServer) checkAuthCount(t *testing.T, atLeast int32) {
	t.Helper()
	require.GreaterOrEqual(t, atomic.LoadInt32(s.int32), atLeast)
}
// TestAuthConfig_Start exercises initial authentication plus the renewal
// loop against the fake server, using a mock clock to drive renewals and
// auth-count/HTTP-status assertions before and after several renewal ticks.
func TestAuthConfig_Start(t *testing.T) {
	const (
		renewal      = 50 * time.Millisecond
		renewalCheck = 5 * renewal
	)
	type fields struct {
		Method   string
		Username string
		Password string
		Body     string
		Headers  map[string]*config.Secret
	}
	type args struct {
		renewal  time.Duration
		endpoint string
	}
	tests := []struct {
		name              string
		fields            fields
		args              args
		wantErr           error
		firstAuthCount    int32
		lastAuthCount     int32
		firstHTTPResponse int
		lastHTTPResponse  int
	}{
		{
			name: "success no creds, no body, default method",
			args: args{
				renewal:  renewal,
				endpoint: authEndpointNoCreds,
			},
			firstAuthCount:    1,
			lastAuthCount:     3,
			firstHTTPResponse: http.StatusOK,
			lastHTTPResponse:  http.StatusOK,
		},
		{
			name: "success no creds, no body, default method, header set",
			args: args{
				renewal:  renewal,
				endpoint: authEndpointWithHeader,
			},
			fields: fields{
				Headers: map[string]*config.Secret{reqHeaderKey: &reqHeaderValSecret},
			},
			firstAuthCount:    1,
			lastAuthCount:     3,
			firstHTTPResponse: http.StatusOK,
			lastHTTPResponse:  http.StatusOK,
		},
		{
			name: "success with creds, no body",
			fields: fields{
				Method:   http.MethodPost,
				Username: reqUser,
				Password: reqPasswd,
			},
			args: args{
				renewal:  renewal,
				endpoint: authEndpointWithBasicAuth,
			},
			firstAuthCount:    1,
			lastAuthCount:     3,
			firstHTTPResponse: http.StatusOK,
			lastHTTPResponse:  http.StatusOK,
		},
		{
			name: "failure with bad creds",
			fields: fields{
				Method:   http.MethodPost,
				Username: reqUser,
				Password: "a bad password",
			},
			args: args{
				renewal:  renewal,
				endpoint: authEndpointWithBasicAuth,
			},
			wantErr:           errors.New("cookie auth renewal received status code: 401 (Unauthorized) []"),
			firstAuthCount:    0,
			lastAuthCount:     0,
			firstHTTPResponse: http.StatusForbidden,
			lastHTTPResponse:  http.StatusForbidden,
		},
		{
			name: "success with no creds, with good body",
			fields: fields{
				Method: http.MethodPost,
				Body:   reqBody,
			},
			args: args{
				renewal:  renewal,
				endpoint: authEndpointWithBody,
			},
			firstAuthCount:    1,
			lastAuthCount:     3,
			firstHTTPResponse: http.StatusOK,
			lastHTTPResponse:  http.StatusOK,
		},
		{
			name: "failure with bad body",
			fields: fields{
				Method: http.MethodPost,
				Body:   "a bad body",
			},
			args: args{
				renewal:  renewal,
				endpoint: authEndpointWithBody,
			},
			wantErr:           errors.New("cookie auth renewal received status code: 401 (Unauthorized) []"),
			firstAuthCount:    0,
			lastAuthCount:     0,
			firstHTTPResponse: http.StatusForbidden,
			lastHTTPResponse:  http.StatusForbidden,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			srv := newFakeServer(t)
			c := &CookieAuthConfig{
				URL:      srv.URL + tt.args.endpoint,
				Method:   tt.fields.Method,
				Username: tt.fields.Username,
				Password: tt.fields.Password,
				Body:     tt.fields.Body,
				Headers:  tt.fields.Headers,
				Renewal:  config.Duration(tt.args.renewal),
			}
			// initializeClient performs the first auth; failure cases must
			// surface their error here.
			if err := c.initializeClient(srv.Client()); tt.wantErr != nil {
				require.EqualError(t, err, tt.wantErr.Error())
			} else {
				require.NoError(t, err)
			}
			mock := clockutil.NewMock()
			ticker := mock.Ticker(time.Duration(c.Renewal))
			defer ticker.Stop()

			// The renewal goroutine calls wg.Done on ctx cancellation.
			c.wg.Add(1)
			ctx, cancel := context.WithCancel(t.Context())
			go c.authRenewal(ctx, ticker, testutil.Logger{Name: "cookie_auth"})

			srv.checkAuthCount(t, tt.firstAuthCount)
			srv.checkResp(t, tt.firstHTTPResponse)
			// Advance the mock clock to trigger several renewal ticks.
			mock.Add(renewalCheck)

			// Ensure that the auth renewal goroutine has completed
			require.Eventually(t, func() bool { return atomic.LoadInt32(srv.int32) >= tt.lastAuthCount }, time.Second, 10*time.Millisecond)

			cancel()
			c.wg.Wait()
			srv.checkAuthCount(t, tt.lastAuthCount)
			srv.checkResp(t, tt.lastHTTPResponse)

			srv.Close()
		})
	}
}
|
79
plugins/common/docker/stats_helpers.go
Normal file
79
plugins/common/docker/stats_helpers.go
Normal file
|
@ -0,0 +1,79 @@
|
|||
// Package docker contains few helper functions copied from
|
||||
// https://github.com/docker/cli/blob/master/cli/command/container/stats_helpers.go
|
||||
package docker
|
||||
|
||||
import (
|
||||
"github.com/docker/docker/api/types/container"
|
||||
)
|
||||
|
||||
// CalculateCPUPercentUnix calculate CPU usage (for Unix, in percentages)
|
||||
func CalculateCPUPercentUnix(previousCPU, previousSystem uint64, v *container.StatsResponse) float64 {
|
||||
var (
|
||||
cpuPercent = 0.0
|
||||
// calculate the change for the cpu usage of the container in between readings
|
||||
cpuDelta = float64(v.CPUStats.CPUUsage.TotalUsage) - float64(previousCPU)
|
||||
// calculate the change for the entire system between readings
|
||||
systemDelta = float64(v.CPUStats.SystemUsage) - float64(previousSystem)
|
||||
onlineCPUs = float64(v.CPUStats.OnlineCPUs)
|
||||
)
|
||||
|
||||
if onlineCPUs == 0.0 {
|
||||
onlineCPUs = float64(len(v.CPUStats.CPUUsage.PercpuUsage))
|
||||
}
|
||||
if systemDelta > 0.0 && cpuDelta > 0.0 {
|
||||
cpuPercent = (cpuDelta / systemDelta) * onlineCPUs * 100.0
|
||||
}
|
||||
return cpuPercent
|
||||
}
|
||||
|
||||
// CalculateCPUPercentWindows calculate CPU usage (for Windows, in percentages)
|
||||
func CalculateCPUPercentWindows(v *container.StatsResponse) float64 {
|
||||
// Max number of 100ns intervals between the previous time read and now
|
||||
possIntervals := uint64(v.Read.Sub(v.PreRead).Nanoseconds()) // Start with number of ns intervals
|
||||
possIntervals /= 100 // Convert to number of 100ns intervals
|
||||
possIntervals *= uint64(v.NumProcs) // Multiple by the number of processors
|
||||
|
||||
// Intervals used
|
||||
intervalsUsed := v.CPUStats.CPUUsage.TotalUsage - v.PreCPUStats.CPUUsage.TotalUsage
|
||||
|
||||
// Percentage avoiding divide-by-zero
|
||||
if possIntervals > 0 {
|
||||
return float64(intervalsUsed) / float64(possIntervals) * 100.0
|
||||
}
|
||||
return 0.00
|
||||
}
|
||||
|
||||
// CalculateMemUsageUnixNoCache calculate memory usage of the container.
|
||||
// Cache is intentionally excluded to avoid misinterpretation of the output.
|
||||
//
|
||||
// On Docker 19.03 and older, the result is `mem.Usage - mem.Stats["cache"]`.
|
||||
// On new docker with cgroup v1 host, the result is `mem.Usage - mem.Stats["total_inactive_file"]`.
|
||||
// On new docker with cgroup v2 host, the result is `mem.Usage - mem.Stats["inactive_file"]`.
|
||||
//
|
||||
// This definition is designed to be consistent with past values and the latest docker CLI
|
||||
// * https://github.com/docker/cli/blob/6e2838e18645e06f3e4b6c5143898ccc44063e3b/cli/command/container/stats_helpers.go#L239
|
||||
func CalculateMemUsageUnixNoCache(mem container.MemoryStats) float64 {
|
||||
// Docker 19.03 and older
|
||||
if v, isOldDocker := mem.Stats["cache"]; isOldDocker && v < mem.Usage {
|
||||
return float64(mem.Usage - v)
|
||||
}
|
||||
// cgroup v1
|
||||
if v, isCgroup1 := mem.Stats["total_inactive_file"]; isCgroup1 && v < mem.Usage {
|
||||
return float64(mem.Usage - v)
|
||||
}
|
||||
// cgroup v2
|
||||
if v := mem.Stats["inactive_file"]; v < mem.Usage {
|
||||
return float64(mem.Usage - v)
|
||||
}
|
||||
return float64(mem.Usage)
|
||||
}
|
||||
|
||||
// CalculateMemPercentUnixNoCache calculate memory usage of the container, in percentages.
func CalculateMemPercentUnixNoCache(limit, usedNoCache float64) float64 {
	// A zero limit means the container is not running and no cgroup data
	// was available; report zero rather than dividing by zero.
	if limit == 0 {
		return 0
	}
	return usedNoCache / limit * 100.0
}
|
34
plugins/common/encoding/decoder.go
Normal file
34
plugins/common/encoding/decoder.go
Normal file
|
@ -0,0 +1,34 @@
|
|||
package encoding
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"golang.org/x/text/encoding"
|
||||
"golang.org/x/text/encoding/unicode"
|
||||
)
|
||||
|
||||
// NewDecoder returns an x/text Decoder for the specified text encoding. The
// Decoder converts a character encoding into utf-8 bytes. If a BOM is found
// it will be converted into an utf-8 BOM, you can use
// github.com/dimchansky/utfbom to strip the BOM.
//
// The "none" or "" encoding will pass through bytes unchecked. Use the utf-8
// encoding if you want invalid bytes replaced using the unicode
// replacement character.
//
// Detection of utf-16 endianness using the BOM is not currently provided due
// to the tail input plugins requirement to be able to start at the middle or
// end of the file.
func NewDecoder(enc string) (*Decoder, error) {
	switch enc {
	case "utf-8":
		return createDecoder(unicode.UTF8.NewDecoder()), nil
	case "utf-16le":
		// IgnoreBOM: the endianness is fixed by the caller's choice of
		// encoding name; any BOM bytes are passed through as data.
		return createDecoder(unicode.UTF16(unicode.LittleEndian, unicode.IgnoreBOM).NewDecoder()), nil
	case "utf-16be":
		return createDecoder(unicode.UTF16(unicode.BigEndian, unicode.IgnoreBOM).NewDecoder()), nil
	case "none", "":
		// Nop passes bytes through without validation or replacement.
		return createDecoder(encoding.Nop.NewDecoder()), nil
	}
	return nil, errors.New("unknown character encoding")
}
|
164
plugins/common/encoding/decoder_reader.go
Normal file
164
plugins/common/encoding/decoder_reader.go
Normal file
|
@ -0,0 +1,164 @@
|
|||
package encoding
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
|
||||
"golang.org/x/text/transform"
|
||||
)
|
||||
|
||||
// Other than resetting r.err and r.transformComplete in Read() this
|
||||
// was copied from x/text
|
||||
|
||||
// createDecoder wraps an x/text transformer in a Decoder.
func createDecoder(t transform.Transformer) *Decoder {
	return &Decoder{Transformer: t}
}
|
||||
|
||||
// A Decoder converts bytes to UTF-8. It implements transform.Transformer.
//
// Transforming source bytes that are not of that encoding will not result in an
// error per se. Each byte that cannot be transcoded will be represented in the
// output by the UTF-8 encoding of '\uFFFD', the replacement rune.
type Decoder struct {
	transform.Transformer

	// This forces external creators of Decoders to use names in struct
	// initializers, allowing for future extensibility without having to break
	// code.
	_ struct{}
}
|
||||
|
||||
// Bytes converts the given encoded bytes to UTF-8. It returns the converted
|
||||
// bytes or nil, err if any error occurred.
|
||||
func (d *Decoder) Bytes(b []byte) ([]byte, error) {
|
||||
b, _, err := transform.Bytes(d, b)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return b, nil
|
||||
}
|
||||
|
||||
// String converts the given encoded string to UTF-8. It returns the converted
|
||||
// string or "", err if any error occurred.
|
||||
func (d *Decoder) String(s string) (string, error) {
|
||||
s, _, err := transform.String(d, s)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// Reader wraps another Reader to decode its bytes.
//
// The Decoder may not be used for any other operation as long as the returned
// Reader is in use.
func (d *Decoder) Reader(r io.Reader) io.Reader {
	return NewReader(r, d)
}
|
||||
|
||||
// Reader wraps another io.Reader by transforming the bytes read.
type Reader struct {
	r   io.Reader             // underlying source of untransformed bytes
	t   transform.Transformer // transformation applied to bytes read from r
	err error                 // sticky error from r or t, surfaced once output drains

	// dst[dst0:dst1] contains bytes that have been transformed by t but
	// not yet copied out via Read.
	dst        []byte
	dst0, dst1 int

	// src[src0:src1] contains bytes that have been read from r but not
	// yet transformed through t.
	src        []byte
	src0, src1 int

	// transformComplete is whether the transformation is complete,
	// regardless of whether it was successful.
	transformComplete bool
}
|
||||
|
||||
var (
	// errInconsistentByteCount means that Transform returned success (nil
	// error) but also returned nSrc inconsistent with the src argument.
	errInconsistentByteCount = errors.New("transform: inconsistent byte count returned")
)

// defaultBufSize is the size of the source and destination scratch buffers
// used by Reader.
const defaultBufSize = 4096
|
||||
|
||||
// NewReader returns a new Reader that wraps r by transforming the bytes read
|
||||
// via t. It calls Reset on t.
|
||||
func NewReader(r io.Reader, t transform.Transformer) *Reader {
|
||||
t.Reset()
|
||||
return &Reader{
|
||||
r: r,
|
||||
t: t,
|
||||
dst: make([]byte, defaultBufSize),
|
||||
src: make([]byte, defaultBufSize),
|
||||
}
|
||||
}
|
||||
|
||||
// Read implements the io.Reader interface.
//
// Other than resetting r.err and r.transformComplete on entry, this is the
// x/text transform.Reader logic; the reset allows further Reads after a
// previous call returned EOF.
func (r *Reader) Read(p []byte) (int, error) {
	// Clear previous errors so a Read can be performed even if the last call
	// returned EOF.
	r.err = nil
	r.transformComplete = false

	n := 0
	for {
		// Copy out any transformed bytes and return the final error if we are done.
		if r.dst0 != r.dst1 {
			n = copy(p, r.dst[r.dst0:r.dst1])
			r.dst0 += n
			if r.dst0 == r.dst1 && r.transformComplete {
				return n, r.err
			}
			return n, nil
		} else if r.transformComplete {
			return 0, r.err
		}

		// Try to transform some source bytes, or to flush the transformer if we
		// are out of source bytes. We do this even if r.r.Read returned an error.
		// As the io.Reader documentation says, "process the n > 0 bytes returned
		// before considering the error".
		if r.src0 != r.src1 || r.err != nil {
			var err error
			r.dst0 = 0
			// atEOF is true only once the underlying reader reported io.EOF.
			r.dst1, n, err = r.t.Transform(r.dst, r.src[r.src0:r.src1], errors.Is(r.err, io.EOF))
			r.src0 += n

			switch {
			case err == nil:
				if r.src0 != r.src1 {
					r.err = errInconsistentByteCount
				}
				// The Transform call was successful; we are complete if we
				// cannot read more bytes into src.
				r.transformComplete = r.err != nil
				continue
			case errors.Is(err, transform.ErrShortDst) && (r.dst1 != 0 || n != 0):
				// Make room in dst by copying out, and try again.
				continue
			case errors.Is(err, transform.ErrShortSrc) && r.src1-r.src0 != len(r.src) && r.err == nil:
				// Read more bytes into src via the code below, and try again.
			default:
				r.transformComplete = true
				// The reader error (r.err) takes precedence over the
				// transformer error (err) unless r.err is nil or io.EOF.
				if r.err == nil || errors.Is(r.err, io.EOF) {
					r.err = err
				}
				continue
			}
		}

		// Move any untransformed source bytes to the start of the buffer
		// and read more bytes.
		if r.src0 != 0 {
			r.src0, r.src1 = 0, copy(r.src, r.src[r.src0:r.src1])
		}
		n, r.err = r.r.Read(r.src[r.src1:])
		r.src1 += n
	}
}
|
78
plugins/common/encoding/decoder_test.go
Normal file
78
plugins/common/encoding/decoder_test.go
Normal file
|
@ -0,0 +1,78 @@
|
|||
package encoding
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestDecoder verifies NewDecoder for every supported encoding name:
// pass-through ("" / none), utf-8 with invalid-byte replacement, and both
// utf-16 endiannesses with and without a BOM (a BOM is converted to its
// utf-8 form rather than stripped).
func TestDecoder(t *testing.T) {
	tests := []struct {
		name        string
		encoding    string
		input       []byte
		expected    []byte
		expectedErr bool
	}{
		{
			name:     "no decoder utf-8",
			encoding: "",
			input:    []byte("howdy"),
			expected: []byte("howdy"),
		},
		{
			name:     "utf-8 decoder",
			encoding: "utf-8",
			input:    []byte("howdy"),
			expected: []byte("howdy"),
		},
		{
			name:     "utf-8 decoder invalid bytes replaced with replacement char",
			encoding: "utf-8",
			input:    []byte("\xff\xfe"),
			expected: []byte("\uFFFD\uFFFD"),
		},
		{
			name:     "utf-16le decoder no BOM",
			encoding: "utf-16le",
			input:    []byte("h\x00o\x00w\x00d\x00y\x00"),
			expected: []byte("howdy"),
		},
		{
			name:     "utf-16le decoder with BOM",
			encoding: "utf-16le",
			input:    []byte("\xff\xfeh\x00o\x00w\x00d\x00y\x00"),
			expected: []byte("\xef\xbb\xbfhowdy"),
		},
		{
			name:     "utf-16be decoder no BOM",
			encoding: "utf-16be",
			input:    []byte("\x00h\x00o\x00w\x00d\x00y"),
			expected: []byte("howdy"),
		},
		{
			name:     "utf-16be decoder with BOM",
			encoding: "utf-16be",
			input:    []byte("\xfe\xff\x00h\x00o\x00w\x00d\x00y"),
			expected: []byte("\xef\xbb\xbfhowdy"),
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			decoder, err := NewDecoder(tt.encoding)
			require.NoError(t, err)
			// Decode through the Reader wrapper, as the tail plugin does.
			buf := bytes.NewBuffer(tt.input)
			r := decoder.Reader(buf)
			actual, err := io.ReadAll(r)
			if tt.expectedErr {
				require.Error(t, err)
				return
			}
			require.NoError(t, err)
			require.Equal(t, tt.expected, actual)
		})
	}
}
|
78
plugins/common/http/config.go
Normal file
78
plugins/common/http/config.go
Normal file
|
@ -0,0 +1,78 @@
|
|||
package httpconfig
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/benbjohnson/clock"
|
||||
"github.com/peterbourgon/unixtransport"
|
||||
|
||||
"github.com/influxdata/telegraf"
|
||||
"github.com/influxdata/telegraf/config"
|
||||
"github.com/influxdata/telegraf/plugins/common/cookie"
|
||||
"github.com/influxdata/telegraf/plugins/common/oauth"
|
||||
"github.com/influxdata/telegraf/plugins/common/proxy"
|
||||
"github.com/influxdata/telegraf/plugins/common/tls"
|
||||
)
|
||||
|
||||
// Common HTTP client struct.
//
// HTTPClientConfig bundles the transport, proxy, TLS, OAuth2 and cookie-auth
// settings shared by HTTP-based plugins; CreateClient turns it into a
// ready-to-use *http.Client.
type HTTPClientConfig struct {
	Timeout               config.Duration `toml:"timeout"`                // overall request timeout; defaults to 5s in CreateClient
	IdleConnTimeout       config.Duration `toml:"idle_conn_timeout"`      // how long idle keep-alive connections are kept open
	MaxIdleConns          int             `toml:"max_idle_conn"`          // transport-wide idle connection cap
	MaxIdleConnsPerHost   int             `toml:"max_idle_conn_per_host"` // per-host idle connection cap
	ResponseHeaderTimeout config.Duration `toml:"response_timeout"`       // time to wait for response headers after the request is written

	proxy.HTTPProxy
	tls.ClientConfig
	oauth.OAuth2Config
	cookie.CookieAuthConfig
}
|
||||
|
||||
// CreateClient builds an *http.Client from the configuration: TLS and proxy
// settings go into the transport, "http+unix"/"https+unix" URL schemes are
// enabled, and OAuth2 and cookie-based authentication are layered on top when
// configured. The request timeout defaults to 5s when unset.
func (h *HTTPClientConfig) CreateClient(ctx context.Context, log telegraf.Logger) (*http.Client, error) {
	tlsCfg, err := h.ClientConfig.TLSConfig()
	if err != nil {
		return nil, fmt.Errorf("failed to set TLS config: %w", err)
	}

	prox, err := h.HTTPProxy.Proxy()
	if err != nil {
		return nil, fmt.Errorf("failed to set proxy: %w", err)
	}

	transport := &http.Transport{
		TLSClientConfig:       tlsCfg,
		Proxy:                 prox,
		IdleConnTimeout:       time.Duration(h.IdleConnTimeout),
		MaxIdleConns:          h.MaxIdleConns,
		MaxIdleConnsPerHost:   h.MaxIdleConnsPerHost,
		ResponseHeaderTimeout: time.Duration(h.ResponseHeaderTimeout),
	}

	// Register "http+unix" and "https+unix" protocol handler.
	unixtransport.Register(transport)

	client := &http.Client{
		Transport: transport,
	}

	// While CreateOauth2Client returns a http.Client keeping the Transport configuration,
	// it does not keep other http.Client parameters (e.g. Timeout).
	client = h.OAuth2Config.CreateOauth2Client(ctx, client)

	// Cookie auth is only started when a cookie-auth URL is configured; Start
	// performs the initial authentication and schedules renewals.
	if h.CookieAuthConfig.URL != "" {
		if err := h.CookieAuthConfig.Start(client, log, clock.New()); err != nil {
			return nil, err
		}
	}

	// Set the timeout last, after OAuth2 wrapping, so it is not lost.
	timeout := h.Timeout
	if timeout == 0 {
		timeout = config.Duration(time.Second * 5)
	}
	client.Timeout = time.Duration(timeout)

	return client, nil
}
|
275
plugins/common/jolokia2/client.go
Normal file
275
plugins/common/jolokia2/client.go
Normal file
|
@ -0,0 +1,275 @@
|
|||
package jolokia2
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"path"
|
||||
"time"
|
||||
|
||||
"github.com/influxdata/telegraf/plugins/common/tls"
|
||||
)
|
||||
|
||||
// Client executes batched Jolokia "read" operations against a single agent
// or proxy endpoint over HTTP.
type Client struct {
	URL    string
	client *http.Client
	config *ClientConfig
}

// ClientConfig holds connection, authentication and optional proxy-mode
// settings for a Client.
type ClientConfig struct {
	ResponseTimeout time.Duration
	Username        string
	Password        string
	Origin          string
	ProxyConfig     *ProxyConfig
	tls.ClientConfig
}

// ProxyConfig describes the JMX targets reached through a Jolokia proxy,
// with fallback credentials for targets that define none of their own.
type ProxyConfig struct {
	DefaultTargetUsername string
	DefaultTargetPassword string
	Targets               []ProxyTargetConfig
}

// ProxyTargetConfig identifies one JMX endpoint behind a Jolokia proxy.
type ProxyTargetConfig struct {
	Username string
	Password string
	URL      string
}

// ReadRequest describes one Jolokia read: an mbean, the attributes to fetch,
// and an optional path into the attribute value.
type ReadRequest struct {
	Mbean      string
	Attributes []string
	Path       string
}

// ReadResponse is the decoded result of a ReadRequest; the Request* fields
// echo the originating request so responses can be matched back to metrics.
type ReadResponse struct {
	Status            int
	Value             interface{}
	RequestMbean      string
	RequestAttributes []string
	RequestPath       string
	RequestTarget     string
}

// Jolokia JSON request object. Example: {
//	"type": "read",
//	"mbean: "java.lang:type="Runtime",
//	"attribute": "Uptime",
//	"target": {
//		"url: "service:jmx:rmi:///jndi/rmi://target:9010/jmxrmi"
//	}
// }
type jolokiaRequest struct {
	Type      string         `json:"type"`
	Mbean     string         `json:"mbean"`
	Attribute interface{}    `json:"attribute,omitempty"` // string for one attribute, []string for several
	Path      string         `json:"path,omitempty"`
	Target    *jolokiaTarget `json:"target,omitempty"` // set only in proxy mode
}

// jolokiaTarget addresses a proxied JMX endpoint inside a request body.
type jolokiaTarget struct {
	URL      string `json:"url"`
	User     string `json:"user,omitempty"`
	Password string `json:"password,omitempty"`
}

// Jolokia JSON response object. Example: {
//	"request": {
//		"type": "read"
//		"mbean": "java.lang:type=Runtime",
//		"attribute": "Uptime",
//		"target": {
//			"url": "service:jmx:rmi:///jndi/rmi://target:9010/jmxrmi"
//		}
//	},
//	"value": 1214083,
//	"timestamp": 1488059309,
//	"status": 200
// }
type jolokiaResponse struct {
	Request jolokiaRequest `json:"request"`
	Value   interface{}    `json:"value"`
	Status  int            `json:"status"`
}
|
||||
|
||||
func NewClient(address string, config *ClientConfig) (*Client, error) {
|
||||
tlsConfig, err := config.ClientConfig.TLSConfig()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
transport := &http.Transport{
|
||||
ResponseHeaderTimeout: config.ResponseTimeout,
|
||||
TLSClientConfig: tlsConfig,
|
||||
}
|
||||
|
||||
client := &http.Client{
|
||||
Transport: transport,
|
||||
Timeout: config.ResponseTimeout,
|
||||
}
|
||||
|
||||
return &Client{
|
||||
URL: address,
|
||||
config: config,
|
||||
client: client,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// read posts the given requests to the agent's /read endpoint as one JSON
// batch and decodes the array of responses. Credentials, if any, are embedded
// in the request URL by formatReadURL.
func (c *Client) read(requests []ReadRequest) ([]ReadResponse, error) {
	jRequests := makeJolokiaRequests(requests, c.config.ProxyConfig)
	requestBody, err := json.Marshal(jRequests)
	if err != nil {
		return nil, err
	}

	requestURL, err := formatReadURL(c.URL, c.config.Username, c.config.Password)
	if err != nil {
		return nil, err
	}

	req, err := http.NewRequest("POST", requestURL, bytes.NewBuffer(requestBody))
	if err != nil {
		// err is not contained in returned error - it may contain sensitive data (password) which should not be logged
		return nil, fmt.Errorf("unable to create new request for: %q", c.URL)
	}

	req.Header.Add("Content-type", "application/json")
	// Some Jolokia deployments enforce CORS-style Origin checks.
	if c.config.Origin != "" {
		req.Header.Add("Origin", c.config.Origin)
	}

	resp, err := c.client.Do(req)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("response from url %q has status code %d (%s), expected %d (%s)",
			c.URL, resp.StatusCode, http.StatusText(resp.StatusCode), http.StatusOK, http.StatusText(http.StatusOK))
	}

	responseBody, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, err
	}

	// Per-request errors are reported inside the array entries' status
	// fields, not via the HTTP status.
	var jResponses []jolokiaResponse
	if err = json.Unmarshal(responseBody, &jResponses); err != nil {
		return nil, fmt.Errorf("decoding JSON response: %w: %s", err, responseBody)
	}

	return makeReadResponses(jResponses), nil
}
|
||||
|
||||
func makeJolokiaRequests(rrequests []ReadRequest, proxyConfig *ProxyConfig) []jolokiaRequest {
|
||||
jrequests := make([]jolokiaRequest, 0)
|
||||
if proxyConfig == nil {
|
||||
for _, rr := range rrequests {
|
||||
jrequests = append(jrequests, makeJolokiaRequest(rr, nil))
|
||||
}
|
||||
} else {
|
||||
for _, t := range proxyConfig.Targets {
|
||||
if t.Username == "" {
|
||||
t.Username = proxyConfig.DefaultTargetUsername
|
||||
}
|
||||
if t.Password == "" {
|
||||
t.Password = proxyConfig.DefaultTargetPassword
|
||||
}
|
||||
|
||||
for _, rr := range rrequests {
|
||||
jtarget := &jolokiaTarget{
|
||||
URL: t.URL,
|
||||
User: t.Username,
|
||||
Password: t.Password,
|
||||
}
|
||||
|
||||
jrequests = append(jrequests, makeJolokiaRequest(rr, jtarget))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return jrequests
|
||||
}
|
||||
|
||||
func makeJolokiaRequest(rrequest ReadRequest, jtarget *jolokiaTarget) jolokiaRequest {
|
||||
jrequest := jolokiaRequest{
|
||||
Type: "read",
|
||||
Mbean: rrequest.Mbean,
|
||||
Path: rrequest.Path,
|
||||
Target: jtarget,
|
||||
}
|
||||
|
||||
if len(rrequest.Attributes) == 1 {
|
||||
jrequest.Attribute = rrequest.Attributes[0]
|
||||
}
|
||||
if len(rrequest.Attributes) > 1 {
|
||||
jrequest.Attribute = rrequest.Attributes
|
||||
}
|
||||
|
||||
return jrequest
|
||||
}
|
||||
|
||||
// makeReadResponses converts decoded Jolokia responses into ReadResponse
// values, reconstructing each originating request's attribute list (which
// arrives as either a single string or an array in the JSON echo).
func makeReadResponses(jresponses []jolokiaResponse) []ReadResponse {
	rresponses := make([]ReadResponse, 0)

	for _, jr := range jresponses {
		rrequest := ReadRequest{
			Mbean:      jr.Request.Mbean,
			Path:       jr.Request.Path,
			Attributes: make([]string, 0),
		}

		attrValue := jr.Request.Attribute
		if attrValue != nil {
			attribute, ok := attrValue.(string)
			if ok {
				// Scalar form: exactly one attribute was requested.
				rrequest.Attributes = []string{attribute}
			} else {
				// Array form; a failed assertion leaves attributes empty.
				attributes, _ := attrValue.([]interface{})
				rrequest.Attributes = make([]string, 0, len(attributes))
				for _, attr := range attributes {
					// NOTE(review): assumes every element is a string as
					// produced by makeJolokiaRequest; a non-string would panic.
					rrequest.Attributes = append(rrequest.Attributes, attr.(string))
				}
			}
		}
		rresponse := ReadResponse{
			Value:             jr.Value,
			Status:            jr.Status,
			RequestMbean:      rrequest.Mbean,
			RequestAttributes: rrequest.Attributes,
			RequestPath:       rrequest.Path,
		}
		// Preserve the proxy target URL, when present, for tagging.
		if jtarget := jr.Request.Target; jtarget != nil {
			rresponse.RequestTarget = jtarget.URL
		}

		rresponses = append(rresponses, rresponse)
	}

	return rresponses
}
|
||||
|
||||
// formatReadURL builds the Jolokia read endpoint URL from the configured base
// URL, embedding the credentials as URL userinfo and requesting that the
// agent report per-entry errors instead of failing the whole batch.
//
// Fix: the original called readURL.Query().Add(...), but url.URL.Query()
// returns a copy of the parsed values, so the "ignoreErrors=true" parameter
// was silently discarded. The query must be encoded back into RawQuery.
func formatReadURL(configURL, username, password string) (string, error) {
	parsedURL, err := url.Parse(configURL)
	if err != nil {
		return "", err
	}

	readURL := url.URL{
		Host:   parsedURL.Host,
		Scheme: parsedURL.Scheme,
	}

	if username != "" || password != "" {
		readURL.User = url.UserPassword(username, password)
	}

	readURL.Path = path.Join(parsedURL.Path, "read")

	// Query() returns a copy; mutate it and assign the encoding back.
	query := readURL.Query()
	query.Add("ignoreErrors", "true")
	readURL.RawQuery = query.Encode()

	return readURL.String(), nil
}
|
266
plugins/common/jolokia2/gatherer.go
Normal file
266
plugins/common/jolokia2/gatherer.go
Normal file
|
@ -0,0 +1,266 @@
|
|||
package jolokia2
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/influxdata/telegraf"
|
||||
)
|
||||
|
||||
// defaultFieldName is the fallback field name for point values; presumably
// consumed by the point builder (defined elsewhere in this package).
const defaultFieldName = "value"

// Gatherer translates metric definitions into Jolokia read requests and
// converts agent responses into accumulator points.
type Gatherer struct {
	metrics  []Metric      // metric definitions to gather
	requests []ReadRequest // read requests derived from metrics at construction time
}
|
||||
|
||||
func NewGatherer(metrics []Metric) *Gatherer {
|
||||
return &Gatherer{
|
||||
metrics: metrics,
|
||||
requests: makeReadRequests(metrics),
|
||||
}
|
||||
}
|
||||
|
||||
// Gather adds points to an accumulator from responses returned
|
||||
// by a Jolokia agent.
|
||||
func (g *Gatherer) Gather(client *Client, acc telegraf.Accumulator) error {
|
||||
var tags map[string]string
|
||||
|
||||
if client.config.ProxyConfig != nil {
|
||||
tags = map[string]string{"jolokia_proxy_url": client.URL}
|
||||
} else {
|
||||
tags = map[string]string{"jolokia_agent_url": client.URL}
|
||||
}
|
||||
|
||||
requests := makeReadRequests(g.metrics)
|
||||
responses, err := client.read(requests)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
g.gatherResponses(responses, tags, acc)
|
||||
return nil
|
||||
}
|
||||
|
||||
// gatherResponses adds points to an accumulator from the ReadResponse objects
// returned by a Jolokia agent.
func (g *Gatherer) gatherResponses(responses []ReadResponse, tags map[string]string, acc telegraf.Accumulator) {
	// Group generated points by measurement name so points with matching
	// tag sets can be merged before being emitted.
	series := make(map[string][]point)

	for _, metric := range g.metrics {
		points, ok := series[metric.Name]
		if !ok {
			points = make([]point, 0)
		}

		responsePoints, responseErrors := generatePoints(metric, responses)
		points = append(points, responsePoints...)
		// Report per-response errors without aborting the whole gather.
		for _, err := range responseErrors {
			acc.AddError(err)
		}

		series[metric.Name] = points
	}

	for measurement, points := range series {
		for _, point := range compactPoints(points) {
			// The outer tags (agent/proxy URL) are merged under the
			// point's own tags.
			acc.AddFields(measurement,
				point.Fields, mergeTags(point.Tags, tags))
		}
	}
}
|
||||
|
||||
// generatePoints creates points for the supplied metric from the ReadResponse objects returned by the Jolokia client.
// Responses with status 404 (mbean not found) are silently skipped; any other
// non-200 status is reported as an error.
func generatePoints(metric Metric, responses []ReadResponse) ([]point, []error) {
	points := make([]point, 0)
	errors := make([]error, 0)

	for _, response := range responses {
		switch response.Status {
		case 200:
			// Correct response status - do nothing.
		case 404:
			continue
		default:
			errors = append(errors, fmt.Errorf("unexpected status in response from target %s (%q): %d",
				response.RequestTarget, response.RequestMbean, response.Status))
			continue
		}

		// Each response carries the request it answered; only responses
		// whose request matches this metric definition contribute points.
		if !metricMatchesResponse(metric, response) {
			continue
		}

		pb := NewPointBuilder(metric, response.RequestAttributes, response.RequestPath)
		ps, err := pb.Build(metric.Mbean, response.Value)
		if err != nil {
			errors = append(errors, err)
			continue
		}

		for _, point := range ps {
			// In proxy mode, tag each point with the actual JMX target URL.
			if response.RequestTarget != "" {
				point.Tags["jolokia_agent_url"] = response.RequestTarget
			}

			points = append(points, point)
		}
	}

	return points, errors
}
|
||||
|
||||
// mergeTags combines two tag sets into a single tag set. On key collision the
// metric tags win over the outer tags; surrounding single or double quotes
// are stripped from every value.
func mergeTags(metricTags, outerTags map[string]string) map[string]string {
	merged := make(map[string]string, len(metricTags)+len(outerTags))
	// Apply outer tags first so metric tags overwrite them on collision.
	for _, set := range []map[string]string{outerTags, metricTags} {
		for key, value := range set {
			merged[key] = strings.Trim(value, `'"`)
		}
	}
	return merged
}
|
||||
|
||||
// metricMatchesResponse returns true when the name, attributes, and path
|
||||
// of a Metric match the corresponding elements in a ReadResponse object
|
||||
// returned by a Jolokia agent.
|
||||
func metricMatchesResponse(metric Metric, response ReadResponse) bool {
|
||||
if !metric.MatchObjectName(response.RequestMbean) {
|
||||
return false
|
||||
}
|
||||
|
||||
if len(metric.Paths) == 0 {
|
||||
return len(response.RequestAttributes) == 0
|
||||
}
|
||||
|
||||
for _, attribute := range response.RequestAttributes {
|
||||
if metric.MatchAttributeAndPath(attribute, response.RequestPath) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// compactPoints attempts to remove points by compacting points
|
||||
// with matching tag sets. When a match is found, the fields from
|
||||
// one point are moved to another, and the empty point is removed.
|
||||
func compactPoints(points []point) []point {
|
||||
compactedPoints := make([]point, 0)
|
||||
|
||||
for _, sourcePoint := range points {
|
||||
keepPoint := true
|
||||
|
||||
for _, compactPoint := range compactedPoints {
|
||||
if !tagSetsMatch(sourcePoint.Tags, compactPoint.Tags) {
|
||||
continue
|
||||
}
|
||||
|
||||
keepPoint = false
|
||||
for key, val := range sourcePoint.Fields {
|
||||
compactPoint.Fields[key] = val
|
||||
}
|
||||
}
|
||||
|
||||
if keepPoint {
|
||||
compactedPoints = append(compactedPoints, sourcePoint)
|
||||
}
|
||||
}
|
||||
|
||||
return compactedPoints
|
||||
}
|
||||
|
||||
// tagSetsMatch returns true if two maps are equivalent.
func tagSetsMatch(a, b map[string]string) bool {
	if len(a) != len(b) {
		return false
	}

	// Equal sizes plus every a-entry present in b implies equality.
	for key, want := range a {
		if got, ok := b[key]; !ok || got != want {
			return false
		}
	}

	return true
}
|
||||
|
||||
// makeReadRequests creates ReadRequest objects from metrics definitions.
// Path-less attributes of one mbean are batched into a single request; every
// attribute/path combination gets a request of its own.
func makeReadRequests(metrics []Metric) []ReadRequest {
	var requests []ReadRequest
	for _, metric := range metrics {
		if len(metric.Paths) == 0 {
			// No paths: read the whole mbean in one request.
			requests = append(requests, ReadRequest{
				Mbean:      metric.Mbean,
				Attributes: make([]string, 0),
			})
		} else {
			// Split each "attribute/deep/path" into its attribute and the
			// remaining path inside that attribute's value.
			attributes := make(map[string][]string)

			for _, path := range metric.Paths {
				segments := strings.Split(path, "/")
				attribute := segments[0]

				if _, ok := attributes[attribute]; !ok {
					attributes[attribute] = make([]string, 0)
				}

				if len(segments) > 1 {
					paths := attributes[attribute]
					attributes[attribute] = append(paths, strings.Join(segments[1:], "/"))
				}
			}

			// Attributes without inner paths are batched into one request.
			rootAttributes := findRequestAttributesWithoutPaths(attributes)
			if len(rootAttributes) > 0 {
				requests = append(requests, ReadRequest{
					Mbean:      metric.Mbean,
					Attributes: rootAttributes,
				})
			}

			// Each attribute+path pair needs its own request, since a
			// Jolokia read carries at most one path.
			for _, deepAttribute := range findRequestAttributesWithPaths(attributes) {
				for _, path := range attributes[deepAttribute] {
					requests = append(requests, ReadRequest{
						Mbean:      metric.Mbean,
						Attributes: []string{deepAttribute},
						Path:       path,
					})
				}
			}
		}
	}

	return requests
}
|
||||
|
||||
// findRequestAttributesWithoutPaths returns the sorted attribute names that
// have no inner paths associated with them.
func findRequestAttributesWithoutPaths(attributes map[string][]string) []string {
	names := make([]string, 0, len(attributes))
	for name, paths := range attributes {
		if len(paths) == 0 {
			names = append(names, name)
		}
	}

	// Sort for deterministic request ordering.
	sort.Strings(names)
	return names
}
|
||||
|
||||
// findRequestAttributesWithPaths returns the sorted attribute names that have
// at least one inner path associated with them.
func findRequestAttributesWithPaths(attributes map[string][]string) []string {
	names := make([]string, 0, len(attributes))
	for name, paths := range attributes {
		if len(paths) != 0 {
			names = append(names, name)
		}
	}

	// Sort for deterministic request ordering.
	sort.Strings(names)
	return names
}
|
104
plugins/common/jolokia2/gatherer_test.go
Normal file
104
plugins/common/jolokia2/gatherer_test.go
Normal file
|
@ -0,0 +1,104 @@
|
|||
package jolokia2
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestJolokia2_makeReadRequests verifies that makeReadRequests translates
// Metric definitions into the expected Jolokia read requests: path-less
// attributes are batched into one request, while "attr/inner" paths produce
// one request per attribute with the inner path preserved.
func TestJolokia2_makeReadRequests(t *testing.T) {
	cases := []struct {
		metric   Metric
		expected []ReadRequest
	}{
		{
			// No paths: a single request with an empty attribute list.
			metric: Metric{
				Name:  "object",
				Mbean: "test:foo=bar",
			},
			expected: []ReadRequest{
				{
					Mbean:      "test:foo=bar",
					Attributes: make([]string, 0),
				},
			},
		}, {
			metric: Metric{
				Name:  "object_with_an_attribute",
				Mbean: "test:foo=bar",
				Paths: []string{"biz"},
			},
			expected: []ReadRequest{
				{
					Mbean:      "test:foo=bar",
					Attributes: []string{"biz"},
				},
			},
		}, {
			// Multiple path-less attributes collapse into one request.
			metric: Metric{
				Name:  "object_with_attributes",
				Mbean: "test:foo=bar",
				Paths: []string{"baz", "biz"},
			},
			expected: []ReadRequest{
				{
					Mbean:      "test:foo=bar",
					Attributes: []string{"baz", "biz"},
				},
			},
		}, {
			// "attr/path" splits into the attribute plus an inner path.
			metric: Metric{
				Name:  "object_with_an_attribute_and_path",
				Mbean: "test:foo=bar",
				Paths: []string{"biz/baz"},
			},
			expected: []ReadRequest{
				{
					Mbean:      "test:foo=bar",
					Attributes: []string{"biz"},
					Path:       "baz",
				},
			},
		}, {
			// Only the first segment becomes the attribute; the remainder
			// stays joined as the path.
			metric: Metric{
				Name:  "object_with_an_attribute_and_a_deep_path",
				Mbean: "test:foo=bar",
				Paths: []string{"biz/baz/fiz/faz"},
			},
			expected: []ReadRequest{
				{
					Mbean:      "test:foo=bar",
					Attributes: []string{"biz"},
					Path:       "baz/fiz/faz",
				},
			},
		}, {
			// Each attribute that carries a path yields its own request.
			metric: Metric{
				Name:  "object_with_attributes_and_paths",
				Mbean: "test:foo=bar",
				Paths: []string{"baz/biz", "faz/fiz"},
			},
			expected: []ReadRequest{
				{
					Mbean:      "test:foo=bar",
					Attributes: []string{"baz"},
					Path:       "biz",
				},
				{
					Mbean:      "test:foo=bar",
					Attributes: []string{"faz"},
					Path:       "fiz",
				},
			},
		},
	}

	for _, c := range cases {
		payload := makeReadRequests([]Metric{c.metric})

		// Order of generated requests is not guaranteed, so check length
		// plus membership rather than element-wise equality.
		require.Len(t, payload, len(c.expected), "Failing case: "+c.metric.Name)
		for _, actual := range payload {
			require.Contains(t, c.expected, actual, "Failing case: "+c.metric.Name)
		}
	}
}
|
128
plugins/common/jolokia2/metric.go
Normal file
128
plugins/common/jolokia2/metric.go
Normal file
|
@ -0,0 +1,128 @@
|
|||
package jolokia2
|
||||
|
||||
import "strings"
|
||||
|
||||
// A MetricConfig represents a TOML form of
// a Metric with some optional fields.
type MetricConfig struct {
	Name  string   // measurement name for points produced from this metric
	Mbean string   // JMX object name (may contain wildcard properties)
	Paths []string // attribute names, optionally with "/"-separated inner paths
	// Pointer fields are optional overrides; nil means "fall back to the
	// plugin-level default" (resolved in NewMetric).
	FieldName      *string
	FieldPrefix    *string
	FieldSeparator *string
	TagPrefix      *string
	TagKeys        []string // mbean property keys to expose as tags
}
|
||||
|
||||
// A Metric represents a specification for a
// Jolokia read request, and the transformations
// to apply to points generated from the responses.
type Metric struct {
	Name           string
	Mbean          string
	Paths          []string
	FieldName      string
	FieldPrefix    string
	FieldSeparator string
	TagPrefix      string
	TagKeys        []string

	// Pre-parsed pieces of Mbean, cached by NewMetric so that
	// MatchObjectName does not re-parse the pattern on every call.
	mbeanDomain     string
	mbeanProperties []string
}
|
||||
|
||||
func NewMetric(config MetricConfig, defaultFieldPrefix, defaultFieldSeparator, defaultTagPrefix string) Metric {
|
||||
metric := Metric{
|
||||
Name: config.Name,
|
||||
Mbean: config.Mbean,
|
||||
Paths: config.Paths,
|
||||
TagKeys: config.TagKeys,
|
||||
}
|
||||
|
||||
if config.FieldName != nil {
|
||||
metric.FieldName = *config.FieldName
|
||||
}
|
||||
|
||||
if config.FieldPrefix == nil {
|
||||
metric.FieldPrefix = defaultFieldPrefix
|
||||
} else {
|
||||
metric.FieldPrefix = *config.FieldPrefix
|
||||
}
|
||||
|
||||
if config.FieldSeparator == nil {
|
||||
metric.FieldSeparator = defaultFieldSeparator
|
||||
} else {
|
||||
metric.FieldSeparator = *config.FieldSeparator
|
||||
}
|
||||
|
||||
if config.TagPrefix == nil {
|
||||
metric.TagPrefix = defaultTagPrefix
|
||||
} else {
|
||||
metric.TagPrefix = *config.TagPrefix
|
||||
}
|
||||
|
||||
mbeanDomain, mbeanProperties := parseMbeanObjectName(config.Mbean)
|
||||
metric.mbeanDomain = mbeanDomain
|
||||
metric.mbeanProperties = mbeanProperties
|
||||
|
||||
return metric
|
||||
}
|
||||
|
||||
// MatchObjectName reports whether the response object name corresponds to
// this metric's mbean pattern: either an exact string match, or the same
// domain with the same number of properties where every one of the metric's
// properties appears in the response name (order-insensitive).
func (m Metric) MatchObjectName(name string) bool {
	// Fast path: exact match.
	if name == m.Mbean {
		return true
	}

	mbeanDomain, mbeanProperties := parseMbeanObjectName(name)
	if mbeanDomain != m.mbeanDomain {
		return false
	}

	if len(mbeanProperties) != len(m.mbeanProperties) {
		return false
	}

	// Every property of the metric must occur in the candidate name.
NEXT_PROPERTY:
	for _, mbeanProperty := range m.mbeanProperties {
		for i := range mbeanProperties {
			if mbeanProperties[i] == mbeanProperty {
				continue NEXT_PROPERTY
			}
		}

		return false
	}

	return true
}
|
||||
|
||||
func (m Metric) MatchAttributeAndPath(attribute, innerPath string) bool {
|
||||
path := attribute
|
||||
if innerPath != "" {
|
||||
path = path + "/" + innerPath
|
||||
}
|
||||
|
||||
for i := range m.Paths {
|
||||
if path == m.Paths[i] {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// parseMbeanObjectName splits a JMX object name of the form
// "domain:key=value,key=value" into its domain and the raw list of
// key-property pairs. A name without a colon is all domain; a trailing
// colon ("domain:") has no properties.
func parseMbeanObjectName(name string) (string, []string) {
	index := strings.Index(name, ":")
	if index == -1 {
		return name, nil
	}

	domain := name[:index]

	// Bug fix: the previous bounds check was `index+1 > len(name)`, which can
	// never be true (strings.Index returns at most len(name)-1), so "domain:"
	// produced a single empty property instead of nil.
	properties := name[index+1:]
	if properties == "" {
		return domain, nil
	}

	return domain, strings.Split(properties, ",")
}
|
274
plugins/common/jolokia2/point_builder.go
Normal file
274
plugins/common/jolokia2/point_builder.go
Normal file
|
@ -0,0 +1,274 @@
|
|||
package jolokia2
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// point is an intermediate representation of a telegraf point: the tags and
// fields extracted from one Jolokia response object, before being emitted.
type point struct {
	Tags   map[string]string
	Fields map[string]interface{}
}
|
||||
|
||||
// pointBuilder turns the value object of a single Jolokia read response into
// points, according to the metric's naming/tagging rules.
type pointBuilder struct {
	metric           Metric
	objectAttributes []string // attributes that were requested
	objectPath       string   // inner path that was requested, if any
	substitutions    []string // $1..$N substitution keys (see makeSubstitutionList)
}
|
||||
|
||||
// NewPointBuilder constructs a pointBuilder for one read request, pre-computing
// the $1..$N substitution list from the metric's mbean pattern.
func NewPointBuilder(metric Metric, attributes []string, path string) *pointBuilder {
	return &pointBuilder{
		metric:           metric,
		objectAttributes: attributes,
		objectPath:       path,
		substitutions:    makeSubstitutionList(metric.Mbean),
	}
}
|
||||
|
||||
// Build generates a point for a given mbean name/pattern and value object.
// For a wildcard pattern the response value is already a map keyed by the
// concrete mbean names; a non-pattern (or nil) value is wrapped so both
// shapes are processed uniformly.
func (pb *pointBuilder) Build(mbean string, value interface{}) ([]point, error) {
	hasPattern := strings.Contains(mbean, "*")
	if !hasPattern || value == nil {
		// Normalize to the pattern shape: one entry keyed by the mbean name.
		value = map[string]interface{}{mbean: value}
	}

	valueMap, ok := value.(map[string]interface{})
	if !ok {
		return nil, fmt.Errorf("the response of %s's value should be a map", mbean)
	}

	// One point per concrete mbean in the response.
	points := make([]point, 0)
	for mbean, value := range valueMap {
		points = append(points, point{
			Tags:   pb.extractTags(mbean),
			Fields: pb.extractFields(mbean, value),
		})
	}

	return compactPoints(points), nil
}
|
||||
|
||||
// extractTags generates the map of tags for a given mbean name/pattern.
|
||||
func (pb *pointBuilder) extractTags(mbean string) map[string]string {
|
||||
propertyMap := makePropertyMap(mbean)
|
||||
tagMap := make(map[string]string)
|
||||
|
||||
for key, value := range propertyMap {
|
||||
if pb.includeTag(key) {
|
||||
tagName := pb.formatTagName(key)
|
||||
tagMap[tagName] = value
|
||||
}
|
||||
}
|
||||
|
||||
return tagMap
|
||||
}
|
||||
|
||||
func (pb *pointBuilder) includeTag(tagName string) bool {
|
||||
for _, t := range pb.metric.TagKeys {
|
||||
if tagName == t {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (pb *pointBuilder) formatTagName(tagName string) string {
|
||||
if tagName == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
if tagPrefix := pb.metric.TagPrefix; tagPrefix != "" {
|
||||
return tagPrefix + tagName
|
||||
}
|
||||
|
||||
return tagName
|
||||
}
|
||||
|
||||
// extractFields generates the map of fields for a given mbean name
// and value object. The interpretation of a map-shaped value depends on
// how many attributes were requested; scalar values become a single field.
// Finally, $1..$N substitutions are applied when the pattern defines any.
func (pb *pointBuilder) extractFields(mbean string, value interface{}) map[string]interface{} {
	fieldMap := make(map[string]interface{})
	valueMap, ok := value.(map[string]interface{})

	if ok {
		// complex value
		if len(pb.objectAttributes) == 0 {
			// if there were no attributes requested,
			// then the keys are attributes
			pb.FillFields("", valueMap, fieldMap)
		} else if len(pb.objectAttributes) == 1 {
			// if there was a single attribute requested,
			// then the keys are the attribute's properties
			fieldName := pb.formatFieldName(pb.objectAttributes[0], pb.objectPath)
			pb.FillFields(fieldName, valueMap, fieldMap)
		} else {
			// if there were multiple attributes requested,
			// then the keys are the attribute names
			for _, attribute := range pb.objectAttributes {
				fieldName := pb.formatFieldName(attribute, pb.objectPath)
				pb.FillFields(fieldName, valueMap[attribute], fieldMap)
			}
		}
	} else {
		// scalar value: derive the field name from the first requested
		// attribute, or fall back to the default field name.
		var fieldName string
		if len(pb.objectAttributes) == 0 {
			fieldName = pb.formatFieldName(defaultFieldName, pb.objectPath)
		} else {
			fieldName = pb.formatFieldName(pb.objectAttributes[0], pb.objectPath)
		}

		pb.FillFields(fieldName, value, fieldMap)
	}

	// substitutions[0] is always the domain; entries beyond it mean the
	// pattern has $1-style keys to substitute into field names.
	if len(pb.substitutions) > 1 {
		pb.applySubstitutions(mbean, fieldMap)
	}

	return fieldMap
}
|
||||
|
||||
// formatFieldName generates a field name from the supplied attribute and
|
||||
// path. The return value has the configured FieldPrefix and FieldSuffix
|
||||
// instructions applied.
|
||||
func (pb *pointBuilder) formatFieldName(attribute, path string) string {
|
||||
fieldName := attribute
|
||||
fieldPrefix := pb.metric.FieldPrefix
|
||||
fieldSeparator := pb.metric.FieldSeparator
|
||||
|
||||
if fieldPrefix != "" {
|
||||
fieldName = fieldPrefix + fieldName
|
||||
}
|
||||
|
||||
if path != "" {
|
||||
fieldName = fieldName + fieldSeparator + strings.ReplaceAll(path, "/", fieldSeparator)
|
||||
}
|
||||
|
||||
return fieldName
|
||||
}
|
||||
|
||||
// FillFields recurses into the supplied value object, generating a named field
// for every value it discovers. Map values recurse with the key joined onto
// the name; array values are skipped entirely; scalars are stored under the
// accumulated name (or the metric's FieldName override, or defaultFieldName).
func (pb *pointBuilder) FillFields(name string, value interface{}, fieldMap map[string]interface{}) {
	if valueMap, ok := value.(map[string]interface{}); ok {
		// keep going until we get to something that is not a map
		for key, innerValue := range valueMap {
			// arrays are not supported as field values; skip them
			if _, ok := innerValue.([]interface{}); ok {
				continue
			}

			var innerName string
			if name == "" {
				innerName = pb.metric.FieldPrefix + key
			} else {
				innerName = name + pb.metric.FieldSeparator + key
			}

			pb.FillFields(innerName, innerValue, fieldMap)
		}

		return
	}

	// top-level arrays are likewise dropped
	if _, ok := value.([]interface{}); ok {
		return
	}

	// a configured FieldName overrides the accumulated name entirely
	// (with the prefix re-applied)
	if pb.metric.FieldName != "" {
		name = pb.metric.FieldName
		if prefix := pb.metric.FieldPrefix; prefix != "" {
			name = prefix + name
		}
	}

	if name == "" {
		name = defaultFieldName
	}

	fieldMap[name] = value
}
|
||||
|
||||
// applySubstitutions updates all the keys in the supplied map
// of fields to account for $1-style substitution instructions.
// substitutions[0] is the mbean domain and carries no symbol, so iteration
// starts at index 1; $i is replaced by the concrete value of that property
// key taken from the response mbean name.
func (pb *pointBuilder) applySubstitutions(mbean string, fieldMap map[string]interface{}) {
	properties := makePropertyMap(mbean)

	for i, subKey := range pb.substitutions[1:] {
		symbol := fmt.Sprintf("$%d", i+1)
		substitution := properties[subKey]

		// rename any field whose name contains the symbol
		for fieldName, fieldValue := range fieldMap {
			newFieldName := strings.ReplaceAll(fieldName, symbol, substitution)
			if fieldName != newFieldName {
				fieldMap[newFieldName] = fieldValue
				delete(fieldMap, fieldName)
			}
		}
	}
}
|
||||
|
||||
// makePropertyMap returns the mbean property-key list as a dictionary:
// "foo:x=y" becomes map[string]string{"x": "y"}. Names without a domain or
// without a property list yield an empty map; malformed pairs are skipped.
func makePropertyMap(mbean string) map[string]string {
	props := make(map[string]string)

	object := strings.SplitN(mbean, ":", 2)
	if object[0] == "" || len(object) != 2 {
		return props
	}

	for _, keyProperty := range strings.Split(object[1], ",") {
		pair := strings.SplitN(keyProperty, "=", 2)
		if len(pair) != 2 || pair[0] == "" {
			continue
		}
		props[pair[0]] = pair[1]
	}

	return props
}
|
||||
|
||||
// makeSubstitutionList returns an array of values to use as substitutions
// when renaming fields with the $1..$N syntax. The first item in the list is
// always the mbean domain; subsequent items are the property keys whose
// values contain a "*" wildcard, in declaration order.
func makeSubstitutionList(mbean string) []string {
	object := strings.SplitN(mbean, ":", 2)
	if object[0] == "" || len(object) != 2 {
		// no domain or no property list: nothing to substitute
		return []string{}
	}

	subs := []string{object[0]}
	for _, keyProperty := range strings.Split(object[1], ",") {
		pair := strings.SplitN(keyProperty, "=", 2)
		if len(pair) != 2 || pair[0] == "" {
			continue
		}
		// only wildcard-valued keys participate in substitution
		if !strings.Contains(pair[1], "*") {
			continue
		}
		subs = append(subs, pair[0])
	}

	return subs
}
|
158
plugins/common/kafka/config.go
Normal file
158
plugins/common/kafka/config.go
Normal file
|
@ -0,0 +1,158 @@
|
|||
package kafka
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"math"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/IBM/sarama"
|
||||
|
||||
"github.com/influxdata/telegraf"
|
||||
"github.com/influxdata/telegraf/config"
|
||||
"github.com/influxdata/telegraf/plugins/common/tls"
|
||||
)
|
||||
|
||||
// ReadConfig for kafka clients meaning to read from Kafka.
// It embeds the common Config and only adds consumer-side wiring.
type ReadConfig struct {
	Config
}
|
||||
|
||||
// SetConfig on the sarama.Config object from the ReadConfig struct.
// It enables consumer error delivery, then applies the common settings.
func (k *ReadConfig) SetConfig(cfg *sarama.Config, log telegraf.Logger) error {
	cfg.Consumer.Return.Errors = true
	return k.Config.SetConfig(cfg, log)
}
|
||||
|
||||
// WriteConfig for kafka clients meaning to write to kafka.
// It embeds the common Config and adds producer-side options.
type WriteConfig struct {
	Config

	RequiredAcks     int  `toml:"required_acks"`      // maps to sarama.RequiredAcks
	MaxRetry         int  `toml:"max_retry"`          // producer retry limit
	MaxMessageBytes  int  `toml:"max_message_bytes"`  // 0 keeps sarama's default
	IdempotentWrites bool `toml:"idempotent_writes"`
}
|
||||
|
||||
// SetConfig on the sarama.Config object from the WriteConfig struct.
// It wires the producer options, then applies the common settings.
func (k *WriteConfig) SetConfig(cfg *sarama.Config, log telegraf.Logger) error {
	cfg.Producer.Return.Successes = true
	cfg.Producer.Idempotent = k.IdempotentWrites
	cfg.Producer.Retry.Max = k.MaxRetry
	if k.MaxMessageBytes > 0 {
		cfg.Producer.MaxMessageBytes = k.MaxMessageBytes
	}
	cfg.Producer.RequiredAcks = sarama.RequiredAcks(k.RequiredAcks)
	// Idempotent production requires at most one in-flight request.
	if cfg.Producer.Idempotent {
		cfg.Net.MaxOpenRequests = 1
	}
	return k.Config.SetConfig(cfg, log)
}
|
||||
|
||||
// Config common to all Kafka clients.
type Config struct {
	SASLAuth
	tls.ClientConfig

	Version          string           `toml:"version"`           // Kafka protocol version, e.g. "2.6.0"
	ClientID         string           `toml:"client_id"`         // defaults to "Telegraf" when empty
	CompressionCodec int              `toml:"compression_codec"` // maps to sarama.CompressionCodec
	EnableTLS        *bool            `toml:"enable_tls"`        // nil: enable implicitly when a TLS config exists
	KeepAlivePeriod  *config.Duration `toml:"keep_alive_period"` // nil: OS default

	MetadataRetryMax         int             `toml:"metadata_retry_max"`
	MetadataRetryType        string          `toml:"metadata_retry_type"` // "constant" (default) or "exponential"
	MetadataRetryBackoff     config.Duration `toml:"metadata_retry_backoff"`
	MetadataRetryMaxDuration config.Duration `toml:"metadata_retry_max_duration"`

	// Disable full metadata fetching
	MetadataFull *bool `toml:"metadata_full"`
}
|
||||
|
||||
type BackoffFunc func(retries, maxRetries int) time.Duration
|
||||
|
||||
func makeBackoffFunc(backoff, maxDuration time.Duration) BackoffFunc {
|
||||
return func(retries, _ int) time.Duration {
|
||||
d := time.Duration(math.Pow(2, float64(retries))) * backoff
|
||||
if maxDuration != 0 && d > maxDuration {
|
||||
return maxDuration
|
||||
}
|
||||
return d
|
||||
}
|
||||
}
|
||||
|
||||
// SetConfig on the sarama.Config object from the Config struct.
// It applies, in order: protocol version, client ID, producer compression,
// TLS, keep-alive, metadata fetching/retry settings, and finally SASL.
// Returns an error for an unparsable version, a bad TLS config, an unknown
// metadata retry type, or a SASL setup failure.
func (k *Config) SetConfig(cfg *sarama.Config, log telegraf.Logger) error {
	if k.Version != "" {
		version, err := sarama.ParseKafkaVersion(k.Version)
		if err != nil {
			return err
		}

		cfg.Version = version
	}

	if k.ClientID != "" {
		cfg.ClientID = k.ClientID
	} else {
		cfg.ClientID = "Telegraf"
	}

	cfg.Producer.Compression = sarama.CompressionCodec(k.CompressionCodec)

	if k.EnableTLS != nil && *k.EnableTLS {
		cfg.Net.TLS.Enable = true
	}

	tlsConfig, err := k.ClientConfig.TLSConfig()
	if err != nil {
		return err
	}

	if tlsConfig != nil {
		cfg.Net.TLS.Config = tlsConfig

		// To maintain backwards compatibility, if the enable_tls option is not
		// set TLS is enabled if a non-default TLS config is used.
		if k.EnableTLS == nil {
			cfg.Net.TLS.Enable = true
		}
	}

	if k.KeepAlivePeriod != nil {
		// Defaults to OS setting (15s currently)
		cfg.Net.KeepAlive = time.Duration(*k.KeepAlivePeriod)
	}

	if k.MetadataFull != nil {
		// Defaults to true in Sarama
		cfg.Metadata.Full = *k.MetadataFull
	}

	if k.MetadataRetryMax != 0 {
		cfg.Metadata.Retry.Max = k.MetadataRetryMax
	}

	if k.MetadataRetryBackoff != 0 {
		// If cfg.Metadata.Retry.BackoffFunc is set, sarama ignores
		// cfg.Metadata.Retry.Backoff
		cfg.Metadata.Retry.Backoff = time.Duration(k.MetadataRetryBackoff)
	}

	switch strings.ToLower(k.MetadataRetryType) {
	default:
		return errors.New("invalid metadata retry type")
	case "exponential":
		// Exponential retry needs a non-zero base backoff; warn and pick one.
		if k.MetadataRetryBackoff == 0 {
			k.MetadataRetryBackoff = config.Duration(250 * time.Millisecond)
			log.Warnf("metadata_retry_backoff is 0, using %s", time.Duration(k.MetadataRetryBackoff))
		}
		cfg.Metadata.Retry.BackoffFunc = makeBackoffFunc(
			time.Duration(k.MetadataRetryBackoff),
			time.Duration(k.MetadataRetryMaxDuration),
		)
	case "constant", "":
		// constant backoff uses cfg.Metadata.Retry.Backoff set above
	}

	return k.SetSASLConfig(cfg)
}
|
22
plugins/common/kafka/config_test.go
Normal file
22
plugins/common/kafka/config_test.go
Normal file
|
@ -0,0 +1,22 @@
|
|||
package kafka
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestBackoffFunc checks the exponential growth and the max-duration cap of
// the backoff function produced by makeBackoffFunc.
func TestBackoffFunc(t *testing.T) {
	b := 250 * time.Millisecond
	limit := 1100 * time.Millisecond

	f := makeBackoffFunc(b, limit)
	require.Equal(t, b, f(0, 0))
	require.Equal(t, b*2, f(1, 0))
	require.Equal(t, b*4, f(2, 0))
	require.Equal(t, limit, f(3, 0)) // would be 2000 but that's greater than max

	f = makeBackoffFunc(b, 0)      // max = 0 means no max
	require.Equal(t, b*8, f(3, 0)) // with no max, it's 2000
}
|
41
plugins/common/kafka/logger.go
Normal file
41
plugins/common/kafka/logger.go
Normal file
|
@ -0,0 +1,41 @@
|
|||
package kafka
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"github.com/IBM/sarama"
|
||||
|
||||
"github.com/influxdata/telegraf"
|
||||
"github.com/influxdata/telegraf/logger"
|
||||
)
|
||||
|
||||
var (
	// log is the dedicated telegraf logger for sarama output.
	log = logger.New("sarama", "", "")
	// once guards the one-time installation of the sarama logger.
	once sync.Once
)
|
||||
|
||||
// debugLogger adapts sarama's StdLogger interface to the telegraf logger,
// forwarding everything at trace level.
type debugLogger struct{}

func (*debugLogger) Print(v ...interface{}) {
	log.Trace(v...)
}

func (*debugLogger) Printf(format string, v ...interface{}) {
	log.Tracef(format, v...)
}

func (l *debugLogger) Println(v ...interface{}) {
	l.Print(v...)
}
|
||||
|
||||
// SetLogger configures a debug logger for kafka (sarama) and raises the
// logger's level to the requested one when it is not already included.
func SetLogger(level telegraf.LogLevel) {
	// Set-up the sarama logger only once
	once.Do(func() {
		sarama.Logger = &debugLogger{}
	})
	// Increase the log-level if needed.
	if !log.Level().Includes(level) {
		log.SetLevel(level)
	}
}
|
127
plugins/common/kafka/sasl.go
Normal file
127
plugins/common/kafka/sasl.go
Normal file
|
@ -0,0 +1,127 @@
|
|||
package kafka
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/IBM/sarama"
|
||||
|
||||
"github.com/influxdata/telegraf/config"
|
||||
)
|
||||
|
||||
// SASLAuth holds the SASL authentication settings shared by all Kafka
// plugins, covering plain/SCRAM, GSSAPI (Kerberos) and OAUTHBEARER.
type SASLAuth struct {
	SASLUsername   config.Secret     `toml:"sasl_username"`
	SASLPassword   config.Secret     `toml:"sasl_password"`
	SASLExtensions map[string]string `toml:"sasl_extensions"`
	SASLMechanism  string            `toml:"sasl_mechanism"`
	SASLVersion    *int              `toml:"sasl_version"` // nil: derived from the Kafka version

	// GSSAPI config
	SASLGSSAPIServiceName        string `toml:"sasl_gssapi_service_name"`
	SASLGSSAPIAuthType           string `toml:"sasl_gssapi_auth_type"` // "KRB5_USER_AUTH" or "KRB5_KEYTAB_AUTH"
	SASLGSSAPIDisablePAFXFAST    bool   `toml:"sasl_gssapi_disable_pafxfast"`
	SASLGSSAPIKerberosConfigPath string `toml:"sasl_gssapi_kerberos_config_path"`
	SASLGSSAPIKeyTabPath         string `toml:"sasl_gssapi_key_tab_path"`
	SASLGSSAPIRealm              string `toml:"sasl_gssapi_realm"`

	// OAUTHBEARER config
	SASLAccessToken config.Secret `toml:"sasl_access_token"`
}
|
||||
|
||||
// SetSASLConfig configures SASL for kafka (sarama): it resolves the secret
// credentials, wires the mechanism-specific settings (SCRAM client
// generators, OAuth token provider, GSSAPI parameters), and enables SASL
// with the appropriate handshake version when a username or mechanism is
// configured.
func (k *SASLAuth) SetSASLConfig(cfg *sarama.Config) error {
	username, err := k.SASLUsername.Get()
	if err != nil {
		return fmt.Errorf("getting username failed: %w", err)
	}
	cfg.Net.SASL.User = username.String()
	defer username.Destroy()
	password, err := k.SASLPassword.Get()
	if err != nil {
		return fmt.Errorf("getting password failed: %w", err)
	}
	cfg.Net.SASL.Password = password.String()
	defer password.Destroy()

	if k.SASLMechanism != "" {
		cfg.Net.SASL.Mechanism = sarama.SASLMechanism(k.SASLMechanism)
		switch cfg.Net.SASL.Mechanism {
		case sarama.SASLTypeSCRAMSHA256:
			cfg.Net.SASL.SCRAMClientGeneratorFunc = func() sarama.SCRAMClient {
				return &XDGSCRAMClient{HashGeneratorFcn: SHA256}
			}
		case sarama.SASLTypeSCRAMSHA512:
			cfg.Net.SASL.SCRAMClientGeneratorFunc = func() sarama.SCRAMClient {
				return &XDGSCRAMClient{HashGeneratorFcn: SHA512}
			}
		case sarama.SASLTypeOAuth:
			cfg.Net.SASL.TokenProvider = k // use self as token provider.
		case sarama.SASLTypeGSSAPI:
			cfg.Net.SASL.GSSAPI.ServiceName = k.SASLGSSAPIServiceName
			cfg.Net.SASL.GSSAPI.AuthType = gssapiAuthType(k.SASLGSSAPIAuthType)
			cfg.Net.SASL.GSSAPI.Username = username.String()
			cfg.Net.SASL.GSSAPI.Password = password.String()
			cfg.Net.SASL.GSSAPI.DisablePAFXFAST = k.SASLGSSAPIDisablePAFXFAST
			cfg.Net.SASL.GSSAPI.KerberosConfigPath = k.SASLGSSAPIKerberosConfigPath
			cfg.Net.SASL.GSSAPI.KeyTabPath = k.SASLGSSAPIKeyTabPath
			cfg.Net.SASL.GSSAPI.Realm = k.SASLGSSAPIRealm

		case sarama.SASLTypePlaintext:
			// nothing.
		default:
			// NOTE(review): unknown mechanism strings fall through silently;
			// sarama will reject them later — confirm intended.
		}
	}

	if !k.SASLUsername.Empty() || k.SASLMechanism != "" {
		cfg.Net.SASL.Enable = true

		version, err := SASLVersion(cfg.Version, k.SASLVersion)
		if err != nil {
			return err
		}
		cfg.Net.SASL.Version = version
	}
	return nil
}
|
||||
|
||||
// Token does nothing smart, it just grabs a hard-coded token from config.
// It implements sarama's AccessTokenProvider interface, attaching the
// configured SASL extensions to the token.
func (k *SASLAuth) Token() (*sarama.AccessToken, error) {
	token, err := k.SASLAccessToken.Get()
	if err != nil {
		return nil, fmt.Errorf("getting token failed: %w", err)
	}
	defer token.Destroy()
	return &sarama.AccessToken{
		Token:      token.String(),
		Extensions: k.SASLExtensions,
	}, nil
}
|
||||
|
||||
// SASLVersion resolves the SASL handshake version to use. When no explicit
// version is configured (nil), Kafka >= 1.0.0 gets handshake v1 and older
// brokers get v0; explicit values other than 0 or 1 are rejected.
func SASLVersion(kafkaVersion sarama.KafkaVersion, saslVersion *int) (int16, error) {
	if saslVersion == nil {
		if kafkaVersion.IsAtLeast(sarama.V1_0_0_0) {
			return sarama.SASLHandshakeV1, nil
		}
		return sarama.SASLHandshakeV0, nil
	}

	switch *saslVersion {
	case 0:
		return sarama.SASLHandshakeV0, nil
	case 1:
		return sarama.SASLHandshakeV1, nil
	default:
		return 0, errors.New("invalid SASL version")
	}
}
|
||||
|
||||
// gssapiAuthType maps the configuration string to the corresponding sarama
// GSSAPI auth-type constant; unknown strings map to 0.
func gssapiAuthType(authType string) int {
	switch authType {
	case "KRB5_USER_AUTH":
		return sarama.KRB5_USER_AUTH
	case "KRB5_KEYTAB_AUTH":
		return sarama.KRB5_KEYTAB_AUTH
	default:
		return 0
	}
}
|
35
plugins/common/kafka/scram_client.go
Normal file
35
plugins/common/kafka/scram_client.go
Normal file
|
@ -0,0 +1,35 @@
|
|||
package kafka
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"crypto/sha512"
|
||||
"hash"
|
||||
|
||||
"github.com/xdg/scram"
|
||||
)
|
||||
|
||||
// SHA256 and SHA512 are the hash generators used for the corresponding
// SCRAM mechanisms.
var SHA256 scram.HashGeneratorFcn = func() hash.Hash { return sha256.New() }
var SHA512 scram.HashGeneratorFcn = func() hash.Hash { return sha512.New() }

// XDGSCRAMClient adapts the xdg/scram client to sarama's SCRAMClient
// interface by embedding the client, its conversation, and the hash
// generator used to build them.
type XDGSCRAMClient struct {
	*scram.Client
	*scram.ClientConversation
	scram.HashGeneratorFcn
}
|
||||
|
||||
// Begin creates the underlying SCRAM client from the credentials and starts
// a new authentication conversation.
func (x *XDGSCRAMClient) Begin(userName, password, authzID string) (err error) {
	x.Client, err = x.HashGeneratorFcn.NewClient(userName, password, authzID)
	if err != nil {
		return err
	}
	x.ClientConversation = x.Client.NewConversation()
	return nil
}

// Step feeds the server challenge into the conversation and returns the
// client response.
func (x *XDGSCRAMClient) Step(challenge string) (response string, err error) {
	return x.ClientConversation.Step(challenge)
}

// Done reports whether the SCRAM conversation has completed.
func (x *XDGSCRAMClient) Done() bool {
	return x.ClientConversation.Done()
}
|
35
plugins/common/logrus/hook.go
Normal file
35
plugins/common/logrus/hook.go
Normal file
|
@ -0,0 +1,35 @@
|
|||
package logrus
|
||||
|
||||
import (
|
||||
"io"
|
||||
"log" //nolint:depguard // Allow exceptional but valid use of log here.
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// once guards the one-time installation of the logrus hook.
var once sync.Once

// LogHook is a logrus hook that diverts all logrus output into the
// Telegraf log at debug level.
type LogHook struct {
}
||||
|
||||
// InstallHook installs a logging hook into the logrus standard logger, diverting all logs
// through the Telegraf logger at debug level. This is useful for libraries
// that directly log to the logrus system without providing an override method.
// It is safe to call from multiple places; the hook is installed only once.
func InstallHook() {
	once.Do(func() {
		// Discard logrus's own output; the hook re-emits every entry.
		logrus.SetOutput(io.Discard)
		logrus.AddHook(&LogHook{})
	})
}
|
||||
|
||||
// Fire re-emits a logrus entry on the Telegraf debug log, flattening
// newlines so each entry stays on a single line.
func (*LogHook) Fire(entry *logrus.Entry) error {
	msg := strings.ReplaceAll(entry.Message, "\n", " ")
	log.Print("D! [logrus] ", msg)
	return nil
}

// Levels registers the hook for every logrus level.
func (*LogHook) Levels() []logrus.Level {
	return logrus.AllLevels
}
|
95
plugins/common/mqtt/mqtt.go
Normal file
95
plugins/common/mqtt/mqtt.go
Normal file
|
@ -0,0 +1,95 @@
|
|||
package mqtt
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
paho "github.com/eclipse/paho.mqtt.golang"
|
||||
|
||||
"github.com/influxdata/telegraf/config"
|
||||
"github.com/influxdata/telegraf/plugins/common/tls"
|
||||
)
|
||||
|
||||
// PublishProperties are mqtt v5-specific publish properties.
// See https://docs.oasis-open.org/mqtt/mqtt/v5.0/os/mqtt-v5.0-os.html#_Toc3901109
type PublishProperties struct {
	ContentType    string            `toml:"content_type"`
	ResponseTopic  string            `toml:"response_topic"`
	MessageExpiry  config.Duration   `toml:"message_expiry"`
	TopicAlias     *uint16           `toml:"topic_alias"` // nil: no alias set
	UserProperties map[string]string `toml:"user_properties"`
}
|
||||
|
||||
// MqttConfig holds the connection settings shared by the MQTT v3.1.1 and v5
// client implementations.
type MqttConfig struct {
	Servers             []string           `toml:"servers"`
	Protocol            string             `toml:"protocol"` // "", "3.1.1" or "5"
	Username            config.Secret      `toml:"username"`
	Password            config.Secret      `toml:"password"`
	Timeout             config.Duration    `toml:"timeout"`
	ConnectionTimeout   config.Duration    `toml:"connection_timeout"`
	QoS                 int                `toml:"qos"` // 0, 1 or 2
	ClientID            string             `toml:"client_id"`
	Retain              bool               `toml:"retain"`
	KeepAlive           int64              `toml:"keep_alive"`
	PersistentSession   bool               `toml:"persistent_session"` // requires ClientID
	PublishPropertiesV5 *PublishProperties `toml:"v5"`
	ClientTrace         bool               `toml:"client_trace"`

	tls.ClientConfig

	// Set by the plugins, not by TOML configuration.
	AutoReconnect    bool        `toml:"-"`
	OnConnectionLost func(error) `toml:"-"`
}
|
||||
|
||||
// Client is a protocol neutral MQTT client for connecting,
// disconnecting, and publishing data to a topic.
// The protocol specific clients must implement this interface.
type Client interface {
	Connect() (bool, error)
	Publish(topic string, data []byte) error
	SubscribeMultiple(filters map[string]byte, callback paho.MessageHandler) error
	AddRoute(topic string, callback paho.MessageHandler)
	Close() error
}
|
||||
|
||||
func NewClient(cfg *MqttConfig) (Client, error) {
|
||||
if len(cfg.Servers) == 0 {
|
||||
return nil, errors.New("no servers specified")
|
||||
}
|
||||
|
||||
if cfg.PersistentSession && cfg.ClientID == "" {
|
||||
return nil, errors.New("persistent_session requires client_id")
|
||||
}
|
||||
|
||||
if cfg.QoS > 2 || cfg.QoS < 0 {
|
||||
return nil, fmt.Errorf("invalid QoS value %d; must be 0, 1 or 2", cfg.QoS)
|
||||
}
|
||||
|
||||
switch cfg.Protocol {
|
||||
case "", "3.1.1":
|
||||
return NewMQTTv311Client(cfg)
|
||||
case "5":
|
||||
return NewMQTTv5Client(cfg)
|
||||
}
|
||||
return nil, fmt.Errorf("unsupported protocol %q: must be \"3.1.1\" or \"5\"", cfg.Protocol)
|
||||
}
|
||||
|
||||
// parseServers converts the configured server strings into URLs. Entries
// without a scheme are treated as legacy "host:port" addresses (deprecated
// in Telegraf 1.4.4) and get the "tcp" scheme; everything else is parsed
// as a full URL.
func parseServers(servers []string) ([]*url.URL, error) {
	urls := make([]*url.URL, 0, len(servers))
	for _, server := range servers {
		if strings.Contains(server, "://") {
			u, err := url.Parse(server)
			if err != nil {
				return nil, err
			}
			urls = append(urls, u)
			continue
		}

		// Preserve support for host:port style servers; deprecated in Telegraf 1.4.4
		urls = append(urls, &url.URL{Scheme: "tcp", Host: server})
	}
	return urls, nil
}
|
17
plugins/common/mqtt/mqtt_logger.go
Normal file
17
plugins/common/mqtt/mqtt_logger.go
Normal file
|
@ -0,0 +1,17 @@
|
|||
package mqtt
|
||||
|
||||
import (
|
||||
"github.com/influxdata/telegraf"
|
||||
)
|
||||
|
||||
// mqttLogger adapts the telegraf logger to the paho client's Logger
// interface, emitting everything at debug level.
type mqttLogger struct {
	telegraf.Logger
}

func (l mqttLogger) Printf(fmt string, args ...interface{}) {
	l.Logger.Debugf(fmt, args...)
}

func (l mqttLogger) Println(args ...interface{}) {
	l.Logger.Debug(args...)
}
|
26
plugins/common/mqtt/mqtt_test.go
Normal file
26
plugins/common/mqtt/mqtt_test.go
Normal file
|
@ -0,0 +1,26 @@
|
|||
package mqtt
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// Test that default client has random ID
func TestRandomClientID(t *testing.T) {
	var err error

	cfg := &MqttConfig{
		Servers: []string{"tcp://localhost:1883"},
	}

	// Two clients built from the same config (with no explicit client_id)
	// must not share an ID, otherwise they would evict each other from the
	// broker.
	client1, err := NewMQTTv311Client(cfg)
	require.NoError(t, err)

	client2, err := NewMQTTv311Client(cfg)
	require.NoError(t, err)

	options1 := client1.client.OptionsReader()
	options2 := client2.client.OptionsReader()
	require.NotEqual(t, options1.ClientID(), options2.ClientID())
}
|
139
plugins/common/mqtt/mqtt_v3.go
Normal file
139
plugins/common/mqtt/mqtt_v3.go
Normal file
|
@ -0,0 +1,139 @@
|
|||
package mqtt
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
mqttv3 "github.com/eclipse/paho.mqtt.golang" // Library that supports v3.1.1
|
||||
|
||||
"github.com/influxdata/telegraf/internal"
|
||||
"github.com/influxdata/telegraf/logger"
|
||||
)
|
||||
|
||||
// mqttv311Client wraps an MQTT v3.1.1 paho client together with the
// publish settings shared by all operations.
type mqttv311Client struct {
	client  mqttv3.Client
	timeout time.Duration // per-publish wait limit
	qos     int           // quality-of-service level (0..2)
	retain  bool          // ask the broker to retain published messages
}
|
||||
|
||||
// NewMQTTv311Client builds an MQTT v3.1.1 client from the common
// configuration. The returned client is fully configured but not yet
// connected; call Connect on it to establish the session.
func NewMQTTv311Client(cfg *MqttConfig) (*mqttv311Client, error) {
	opts := mqttv3.NewClientOptions()
	opts.KeepAlive = cfg.KeepAlive
	opts.WriteTimeout = time.Duration(cfg.Timeout)
	// Only override paho's default connect timeout for sane (>= 1s) values.
	if time.Duration(cfg.ConnectionTimeout) >= 1*time.Second {
		opts.ConnectTimeout = time.Duration(cfg.ConnectionTimeout)
	}
	// A persistent session maps to a non-clean session on the broker.
	opts.SetCleanSession(!cfg.PersistentSession)
	if cfg.OnConnectionLost != nil {
		onConnectionLost := func(_ mqttv3.Client, err error) {
			cfg.OnConnectionLost(err)
		}
		opts.SetConnectionLostHandler(onConnectionLost)
	}
	opts.SetAutoReconnect(cfg.AutoReconnect)

	// Use the configured client ID or generate a random one to avoid
	// clashing with other instances connected to the same broker.
	if cfg.ClientID != "" {
		opts.SetClientID(cfg.ClientID)
	} else {
		id, err := internal.RandomString(5)
		if err != nil {
			return nil, fmt.Errorf("generating random client ID failed: %w", err)
		}
		opts.SetClientID("Telegraf-Output-" + id)
	}

	tlsCfg, err := cfg.ClientConfig.TLSConfig()
	if err != nil {
		return nil, err
	}
	opts.SetTLSConfig(tlsCfg)

	// Credentials are stored as secrets; destroy them right after use.
	if !cfg.Username.Empty() {
		user, err := cfg.Username.Get()
		if err != nil {
			return nil, fmt.Errorf("getting username failed: %w", err)
		}
		opts.SetUsername(user.String())
		user.Destroy()
	}
	if !cfg.Password.Empty() {
		password, err := cfg.Password.Get()
		if err != nil {
			return nil, fmt.Errorf("getting password failed: %w", err)
		}
		opts.SetPassword(password.String())
		password.Destroy()
	}

	servers, err := parseServers(cfg.Servers)
	if err != nil {
		return nil, err
	}
	for _, server := range servers {
		// With TLS configured, force every broker URL onto the tls scheme.
		if tlsCfg != nil {
			server.Scheme = "tls"
		}
		broker := server.String()
		opts.AddBroker(broker)
	}

	// Optionally route paho's internal (package-level) logging through the
	// Telegraf logger for debugging.
	if cfg.ClientTrace {
		log := &mqttLogger{logger.New("paho", "", "")}
		mqttv3.ERROR = log
		mqttv3.CRITICAL = log
		mqttv3.WARN = log
		mqttv3.DEBUG = log
	}

	return &mqttv311Client{
		client:  mqttv3.NewClient(opts),
		timeout: time.Duration(cfg.Timeout),
		qos:     cfg.QoS,
		retain:  cfg.Retain,
	}, nil
}
|
||||
|
||||
// Connect establishes the connection to the broker. The returned boolean
// reports whether the broker restored a previous (persistent) session, in
// which case subscriptions are already in place server-side.
func (m *mqttv311Client) Connect() (bool, error) {
	token := m.client.Connect()

	if token.Wait() && token.Error() != nil {
		return false, token.Error()
	}

	// Persistent sessions should skip subscription if a session is present, as
	// the subscriptions are stored by the server.
	type sessionPresent interface {
		SessionPresent() bool
	}
	if t, ok := token.(sessionPresent); ok {
		return t.SessionPresent(), nil
	}

	return false, nil
}
|
||||
|
||||
func (m *mqttv311Client) Publish(topic string, body []byte) error {
|
||||
token := m.client.Publish(topic, byte(m.qos), m.retain, body)
|
||||
if !token.WaitTimeout(m.timeout) {
|
||||
return internal.ErrTimeout
|
||||
}
|
||||
return token.Error()
|
||||
}
|
||||
|
||||
// SubscribeMultiple subscribes to all given topic filters and blocks until
// the broker acknowledges the subscription.
func (m *mqttv311Client) SubscribeMultiple(filters map[string]byte, callback mqttv3.MessageHandler) error {
	token := m.client.SubscribeMultiple(filters, callback)
	token.Wait()
	return token.Error()
}

// AddRoute registers a message handler for the given topic.
func (m *mqttv311Client) AddRoute(topic string, callback mqttv3.MessageHandler) {
	m.client.AddRoute(topic, callback)
}

// Close disconnects from the broker, granting 100ms for in-flight work.
// It never fails; closing an already-disconnected client is a no-op.
func (m *mqttv311Client) Close() error {
	if m.client.IsConnected() {
		m.client.Disconnect(100)
	}
	return nil
}
|
165
plugins/common/mqtt/mqtt_v5.go
Normal file
165
plugins/common/mqtt/mqtt_v5.go
Normal file
|
@ -0,0 +1,165 @@
|
|||
package mqtt
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
mqttv5auto "github.com/eclipse/paho.golang/autopaho"
|
||||
mqttv5 "github.com/eclipse/paho.golang/paho"
|
||||
paho "github.com/eclipse/paho.mqtt.golang"
|
||||
|
||||
"github.com/influxdata/telegraf/config"
|
||||
"github.com/influxdata/telegraf/internal"
|
||||
"github.com/influxdata/telegraf/logger"
|
||||
)
|
||||
|
||||
// mqttv5Client wraps an autopaho (MQTT v5) connection manager. The client
// options and credentials are kept so the connection can be (re)built in
// Connect; publish properties are computed once at construction time.
type mqttv5Client struct {
	client      *mqttv5auto.ConnectionManager
	options     mqttv5auto.ClientConfig
	username    config.Secret
	password    config.Secret
	timeout     time.Duration // per-publish wait limit
	qos         int           // quality-of-service level (0..2)
	retain      bool          // ask the broker to retain published messages
	clientTrace bool          // route paho logging through Telegraf's logger
	properties  *mqttv5.PublishProperties
}
|
||||
|
||||
func NewMQTTv5Client(cfg *MqttConfig) (*mqttv5Client, error) {
|
||||
opts := mqttv5auto.ClientConfig{
|
||||
KeepAlive: uint16(cfg.KeepAlive),
|
||||
OnConnectError: cfg.OnConnectionLost,
|
||||
}
|
||||
opts.ConnectPacketBuilder = func(c *mqttv5.Connect, _ *url.URL) (*mqttv5.Connect, error) {
|
||||
c.CleanStart = cfg.PersistentSession
|
||||
return c, nil
|
||||
}
|
||||
|
||||
if time.Duration(cfg.ConnectionTimeout) >= 1*time.Second {
|
||||
opts.ConnectTimeout = time.Duration(cfg.ConnectionTimeout)
|
||||
}
|
||||
|
||||
if cfg.ClientID != "" {
|
||||
opts.ClientID = cfg.ClientID
|
||||
} else {
|
||||
id, err := internal.RandomString(5)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("generating random client ID failed: %w", err)
|
||||
}
|
||||
opts.ClientID = "Telegraf-Output-" + id
|
||||
}
|
||||
|
||||
tlsCfg, err := cfg.ClientConfig.TLSConfig()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if tlsCfg != nil {
|
||||
opts.TlsCfg = tlsCfg
|
||||
}
|
||||
|
||||
brokers := make([]*url.URL, 0)
|
||||
servers, err := parseServers(cfg.Servers)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, server := range servers {
|
||||
if tlsCfg != nil {
|
||||
server.Scheme = "tls"
|
||||
}
|
||||
brokers = append(brokers, server)
|
||||
}
|
||||
opts.BrokerUrls = brokers
|
||||
|
||||
// Build the v5 specific publish properties if they are present in the config.
|
||||
// These should not change during the lifecycle of the client.
|
||||
var properties *mqttv5.PublishProperties
|
||||
if cfg.PublishPropertiesV5 != nil {
|
||||
properties = &mqttv5.PublishProperties{
|
||||
ContentType: cfg.PublishPropertiesV5.ContentType,
|
||||
ResponseTopic: cfg.PublishPropertiesV5.ResponseTopic,
|
||||
TopicAlias: cfg.PublishPropertiesV5.TopicAlias,
|
||||
}
|
||||
|
||||
messageExpiry := time.Duration(cfg.PublishPropertiesV5.MessageExpiry)
|
||||
if expirySeconds := uint32(messageExpiry.Seconds()); expirySeconds > 0 {
|
||||
properties.MessageExpiry = &expirySeconds
|
||||
}
|
||||
|
||||
properties.User = make([]mqttv5.UserProperty, 0, len(cfg.PublishPropertiesV5.UserProperties))
|
||||
for k, v := range cfg.PublishPropertiesV5.UserProperties {
|
||||
properties.User.Add(k, v)
|
||||
}
|
||||
}
|
||||
|
||||
return &mqttv5Client{
|
||||
options: opts,
|
||||
timeout: time.Duration(cfg.Timeout),
|
||||
username: cfg.Username,
|
||||
password: cfg.Password,
|
||||
qos: cfg.QoS,
|
||||
retain: cfg.Retain,
|
||||
properties: properties,
|
||||
clientTrace: cfg.ClientTrace,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Connect resolves the credentials, creates the autopaho connection
// manager and waits for the initial connection. The boolean is always
// false: session-present detection is not surfaced by this client.
func (m *mqttv5Client) Connect() (bool, error) {
	user, err := m.username.Get()
	if err != nil {
		return false, fmt.Errorf("getting username failed: %w", err)
	}
	defer user.Destroy()
	pass, err := m.password.Get()
	if err != nil {
		return false, fmt.Errorf("getting password failed: %w", err)
	}
	defer pass.Destroy()
	m.options.ConnectUsername = user.String()
	m.options.ConnectPassword = []byte(pass.String())

	// Optionally route autopaho's logging through the Telegraf logger.
	if m.clientTrace {
		log := mqttLogger{logger.New("paho", "", "")}
		m.options.Debug = log
		m.options.Errors = log
	}

	client, err := mqttv5auto.NewConnection(context.Background(), m.options)
	if err != nil {
		return false, err
	}
	m.client = client
	// Block until the manager reports an established connection.
	return false, client.AwaitConnection(context.Background())
}
|
||||
|
||||
func (m *mqttv5Client) Publish(topic string, body []byte) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), m.timeout)
|
||||
defer cancel()
|
||||
|
||||
_, err := m.client.Publish(ctx, &mqttv5.Publish{
|
||||
Topic: topic,
|
||||
QoS: byte(m.qos),
|
||||
Retain: m.retain,
|
||||
Payload: body,
|
||||
Properties: m.properties,
|
||||
})
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// SubscribeMultiple is not implemented for the MQTT v5 client; callers
// must use the v3.1.1 client for subscription-based workflows.
func (*mqttv5Client) SubscribeMultiple(filters map[string]byte, callback paho.MessageHandler) error {
	_, _ = filters, callback
	panic("not implemented")
}

// AddRoute is not implemented for the MQTT v5 client.
func (*mqttv5Client) AddRoute(topic string, callback paho.MessageHandler) {
	_, _ = topic, callback
	panic("not implemented")
}

// Close disconnects from the broker.
func (m *mqttv5Client) Close() error {
	return m.client.Disconnect(context.Background())
}
|
42
plugins/common/oauth/config.go
Normal file
42
plugins/common/oauth/config.go
Normal file
|
@ -0,0 +1,42 @@
|
|||
package oauth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/url"
|
||||
|
||||
"golang.org/x/oauth2"
|
||||
"golang.org/x/oauth2/clientcredentials"
|
||||
)
|
||||
|
||||
// OAuth2Config holds the settings for the OAuth2 "client credentials"
// grant used to authenticate outgoing HTTP requests. All three of
// ClientID, ClientSecret and TokenURL must be set for it to take effect.
type OAuth2Config struct {
	// OAuth2 Credentials
	ClientID     string   `toml:"client_id"`
	ClientSecret string   `toml:"client_secret"`
	TokenURL     string   `toml:"token_url"`
	Audience     string   `toml:"audience"`
	Scopes       []string `toml:"scopes"`
}
|
||||
|
||||
func (o *OAuth2Config) CreateOauth2Client(ctx context.Context, client *http.Client) *http.Client {
|
||||
if o.ClientID == "" || o.ClientSecret == "" || o.TokenURL == "" {
|
||||
return client
|
||||
}
|
||||
|
||||
oauthConfig := clientcredentials.Config{
|
||||
ClientID: o.ClientID,
|
||||
ClientSecret: o.ClientSecret,
|
||||
TokenURL: o.TokenURL,
|
||||
Scopes: o.Scopes,
|
||||
EndpointParams: make(url.Values),
|
||||
}
|
||||
|
||||
if o.Audience != "" {
|
||||
oauthConfig.EndpointParams.Add("audience", o.Audience)
|
||||
}
|
||||
|
||||
ctx = context.WithValue(ctx, oauth2.HTTPClient, client)
|
||||
client = oauthConfig.Client(ctx)
|
||||
|
||||
return client
|
||||
}
|
241
plugins/common/opcua/client.go
Normal file
241
plugins/common/opcua/client.go
Normal file
|
@ -0,0 +1,241 @@
|
|||
package opcua
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log" //nolint:depguard // just for debug
|
||||
"net/url"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/gopcua/opcua"
|
||||
"github.com/gopcua/opcua/debug"
|
||||
"github.com/gopcua/opcua/ua"
|
||||
|
||||
"github.com/influxdata/telegraf"
|
||||
"github.com/influxdata/telegraf/config"
|
||||
"github.com/influxdata/telegraf/internal/choice"
|
||||
)
|
||||
|
||||
// OpcUAWorkarounds collects toggles for non-conformant servers.
type OpcUAWorkarounds struct {
	// Status codes (parsed via strconv.ParseUint, so hex like "0xC0" works)
	// that are accepted in addition to ua.StatusOK.
	AdditionalValidStatusCodes []string `toml:"additional_valid_status_codes"`
}

// ConnectionState mirrors gopcua's connection state for use by consumers
// of this package.
type ConnectionState opcua.ConnState

const (
	Closed       ConnectionState = ConnectionState(opcua.Closed)
	Connected    ConnectionState = ConnectionState(opcua.Connected)
	Connecting   ConnectionState = ConnectionState(opcua.Connecting)
	Disconnected ConnectionState = ConnectionState(opcua.Disconnected)
	Reconnecting ConnectionState = ConnectionState(opcua.Reconnecting)
)

// String returns the human-readable name of the connection state.
func (c ConnectionState) String() string {
	return opcua.ConnState(c).String()
}

// OpcUAClientConfig holds the connection, authentication and security
// settings shared by all OPC UA plugins.
type OpcUAClientConfig struct {
	Endpoint       string          `toml:"endpoint"`
	SecurityPolicy string          `toml:"security_policy"`
	SecurityMode   string          `toml:"security_mode"`
	Certificate    string          `toml:"certificate"`
	PrivateKey     string          `toml:"private_key"`
	Username       config.Secret   `toml:"username"`
	Password       config.Secret   `toml:"password"`
	AuthMethod     string          `toml:"auth_method"`
	ConnectTimeout config.Duration `toml:"connect_timeout"`
	RequestTimeout config.Duration `toml:"request_timeout"`
	ClientTrace    bool            `toml:"client_trace"`

	OptionalFields []string         `toml:"optional_fields"`
	Workarounds    OpcUAWorkarounds `toml:"workarounds"`
	SessionTimeout config.Duration  `toml:"session_timeout"`
}
|
||||
|
||||
// Validate checks the optional-field selection and the endpoint/security
// settings, returning the first problem found.
func (o *OpcUAClientConfig) Validate() error {
	if err := o.validateOptionalFields(); err != nil {
		return fmt.Errorf("invalid 'optional_fields': %w", err)
	}

	return o.validateEndpoint()
}
|
||||
|
||||
// validateOptionalFields ensures only supported optional fields are
// requested; currently only "DataType" is known.
func (o *OpcUAClientConfig) validateOptionalFields() error {
	validFields := []string{"DataType"}
	return choice.CheckSlice(o.OptionalFields, validFields)
}
|
||||
|
||||
func (o *OpcUAClientConfig) validateEndpoint() error {
|
||||
if o.Endpoint == "" {
|
||||
return errors.New("endpoint url is empty")
|
||||
}
|
||||
|
||||
_, err := url.Parse(o.Endpoint)
|
||||
if err != nil {
|
||||
return errors.New("endpoint url is invalid")
|
||||
}
|
||||
|
||||
switch o.SecurityPolicy {
|
||||
case "None", "Basic128Rsa15", "Basic256", "Basic256Sha256", "auto":
|
||||
default:
|
||||
return fmt.Errorf("invalid security type %q in %q", o.SecurityPolicy, o.Endpoint)
|
||||
}
|
||||
|
||||
switch o.SecurityMode {
|
||||
case "None", "Sign", "SignAndEncrypt", "auto":
|
||||
default:
|
||||
return fmt.Errorf("invalid security type %q in %q", o.SecurityMode, o.Endpoint)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// CreateClient validates the configuration and constructs an OpcUAClient.
// When client_trace is enabled, gopcua's debug logging is routed through
// the Telegraf logger (this toggles gopcua package-level state).
func (o *OpcUAClientConfig) CreateClient(telegrafLogger telegraf.Logger) (*OpcUAClient, error) {
	err := o.Validate()
	if err != nil {
		return nil, err
	}

	if o.ClientTrace {
		debug.Enable = true
		debug.Logger = log.New(&DebugLogger{Log: telegrafLogger}, "", 0)
	}

	c := &OpcUAClient{
		Config: o,
		Log:    telegrafLogger,
	}
	c.Log.Debug("Initialising OpcUAClient")

	// Pre-compute the set of acceptable status codes from the workarounds.
	err = c.setupWorkarounds()
	return c, err
}
|
||||
|
||||
// OpcUAClient wraps a gopcua client with Telegraf configuration, logging
// and the pre-computed set of acceptable status codes.
type OpcUAClient struct {
	Config *OpcUAClientConfig
	Log    telegraf.Logger

	// Client is nil until Connect succeeds and after Disconnect.
	Client *opcua.Client

	opts  []opcua.Option  // connection options built by SetupOptions
	codes []ua.StatusCode // status codes treated as OK (see setupWorkarounds)
}
|
||||
|
||||
// SetupOptions reads the endpoints from the specified server and sets up
// all authentication and security options for the connection.
func (o *OpcUAClient) SetupOptions() error {
	ctx, cancel := context.WithTimeout(context.Background(), time.Duration(o.Config.ConnectTimeout))
	defer cancel()
	// Get a list of the endpoints for our target server
	endpoints, err := opcua.GetEndpoints(ctx, o.Config.Endpoint)
	if err != nil {
		return err
	}

	// Without user-supplied credentials, generate a self-signed certificate
	// whenever any security is requested.
	if o.Config.Certificate == "" && o.Config.PrivateKey == "" {
		if o.Config.SecurityPolicy != "None" || o.Config.SecurityMode != "None" {
			o.Log.Debug("Generating self-signed certificate")
			cert, privateKey, err := generateCert("urn:telegraf:gopcua:client", 2048,
				o.Config.Certificate, o.Config.PrivateKey, 365*24*time.Hour)
			if err != nil {
				return err
			}

			o.Config.Certificate = cert
			o.Config.PrivateKey = privateKey
		}
	}

	o.Log.Debug("Configuring OPC UA connection options")
	o.opts, err = o.generateClientOpts(endpoints)

	return err
}
|
||||
|
||||
func (o *OpcUAClient) setupWorkarounds() error {
|
||||
o.codes = []ua.StatusCode{ua.StatusOK}
|
||||
for _, c := range o.Config.Workarounds.AdditionalValidStatusCodes {
|
||||
val, err := strconv.ParseUint(c, 0, 32) // setting 32 bits to allow for safe conversion
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
o.codes = append(o.codes, ua.StatusCode(val))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (o *OpcUAClient) StatusCodeOK(code ua.StatusCode) bool {
|
||||
for _, val := range o.codes {
|
||||
if val == code {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// Connect to an OPC UA device
func (o *OpcUAClient) Connect(ctx context.Context) error {
	o.Log.Debug("Connecting OPC UA Client to server")
	u, err := url.Parse(o.Config.Endpoint)
	if err != nil {
		return err
	}

	switch u.Scheme {
	case "opc.tcp":
		if err := o.SetupOptions(); err != nil {
			return err
		}

		// Drop any stale connection before building a new client.
		if o.Client != nil {
			o.Log.Warnf("Closing connection to %q as already connected", u)
			if err := o.Client.Close(ctx); err != nil {
				// Only log the error but to not bail-out here as this prevents
				// reconnections for multiple parties (see e.g. #9523).
				o.Log.Errorf("Closing connection failed: %v", err)
			}
		}

		o.Client, err = opcua.NewClient(o.Config.Endpoint, o.opts...)
		if err != nil {
			return fmt.Errorf("error in new client: %w", err)
		}
		// NOTE(review): the caller's ctx is only used for closing the stale
		// client above; the actual connect deliberately(?) runs on a fresh
		// context bounded by ConnectTimeout — confirm this is intended.
		ctx, cancel := context.WithTimeout(context.Background(), time.Duration(o.Config.ConnectTimeout))
		defer cancel()
		if err := o.Client.Connect(ctx); err != nil {
			return fmt.Errorf("error in Client Connection: %w", err)
		}
		o.Log.Debug("Connected to OPC UA Server")

	default:
		return fmt.Errorf("unsupported scheme %q in endpoint. Expected opc.tcp", u.Scheme)
	}
	return nil
}
|
||||
|
||||
// Disconnect closes the connection to the server. The client reference is
// cleared even when closing fails, so State() reports Disconnected.
func (o *OpcUAClient) Disconnect(ctx context.Context) error {
	o.Log.Debug("Disconnecting from OPC UA Server")
	u, err := url.Parse(o.Config.Endpoint)
	if err != nil {
		return err
	}

	switch u.Scheme {
	case "opc.tcp":
		// We can't do anything about failing to close a connection
		err := o.Client.Close(ctx)
		o.Client = nil
		return err
	default:
		return errors.New("invalid controller")
	}
}
|
||||
|
||||
// State returns the current connection state; a nil client counts as
// Disconnected.
func (o *OpcUAClient) State() ConnectionState {
	if o.Client == nil {
		return Disconnected
	}
	return ConnectionState(o.Client.State())
}
|
33
plugins/common/opcua/client_test.go
Normal file
33
plugins/common/opcua/client_test.go
Normal file
|
@ -0,0 +1,33 @@
|
|||
package opcua
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/gopcua/opcua/ua"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestSetupWorkarounds(t *testing.T) {
|
||||
o := OpcUAClient{
|
||||
Config: &OpcUAClientConfig{
|
||||
Workarounds: OpcUAWorkarounds{
|
||||
AdditionalValidStatusCodes: []string{"0xC0", "0x00AA0000", "0x80000000"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
err := o.setupWorkarounds()
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Len(t, o.codes, 4)
|
||||
require.Equal(t, o.codes[0], ua.StatusCode(0))
|
||||
require.Equal(t, o.codes[1], ua.StatusCode(192))
|
||||
require.Equal(t, o.codes[2], ua.StatusCode(11141120))
|
||||
require.Equal(t, o.codes[3], ua.StatusCode(2147483648))
|
||||
}
|
||||
|
||||
func TestCheckStatusCode(t *testing.T) {
|
||||
var o OpcUAClient
|
||||
o.codes = []ua.StatusCode{ua.StatusCode(0), ua.StatusCode(192), ua.StatusCode(11141120)}
|
||||
require.True(t, o.StatusCodeOK(ua.StatusCode(192)))
|
||||
}
|
500
plugins/common/opcua/input/input_client.go
Normal file
500
plugins/common/opcua/input/input_client.go
Normal file
|
@ -0,0 +1,500 @@
|
|||
package input
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gopcua/opcua/ua"
|
||||
|
||||
"github.com/influxdata/telegraf"
|
||||
"github.com/influxdata/telegraf/config"
|
||||
"github.com/influxdata/telegraf/internal/choice"
|
||||
"github.com/influxdata/telegraf/metric"
|
||||
"github.com/influxdata/telegraf/plugins/common/opcua"
|
||||
)
|
||||
|
||||
// Trigger selects which changes (status/value/timestamp) fire a data
// change notification.
type Trigger string

const (
	Status               Trigger = "Status"
	StatusValue          Trigger = "StatusValue"
	StatusValueTimestamp Trigger = "StatusValueTimestamp"
)

// DeadbandType selects how the deadband value is interpreted.
type DeadbandType string

const (
	Absolute DeadbandType = "Absolute"
	Percent  DeadbandType = "Percent"
)

// DataChangeFilter configures server-side filtering of data changes.
type DataChangeFilter struct {
	Trigger       Trigger      `toml:"trigger"`
	DeadbandType  DeadbandType `toml:"deadband_type"`
	DeadbandValue *float64     `toml:"deadband_value"`
}

// MonitoringParameters configures sampling and queueing of a monitored item.
type MonitoringParameters struct {
	SamplingInterval config.Duration   `toml:"sampling_interval"`
	QueueSize        *uint32           `toml:"queue_size"`
	DiscardOldest    *bool             `toml:"discard_oldest"`
	DataChangeFilter *DataChangeFilter `toml:"data_change_filter"`
}

// NodeSettings describes how to map from a OPC UA node to a Metric
type NodeSettings struct {
	FieldName        string               `toml:"name"`
	Namespace        string               `toml:"namespace"`
	IdentifierType   string               `toml:"identifier_type"`
	Identifier       string               `toml:"identifier"`
	DataType         string               `toml:"data_type" deprecated:"1.17.0;1.35.0;option is ignored"`
	Description      string               `toml:"description" deprecated:"1.17.0;1.35.0;option is ignored"`
	TagsSlice        [][]string           `toml:"tags" deprecated:"1.25.0;1.35.0;use 'default_tags' instead"`
	DefaultTags      map[string]string    `toml:"default_tags"`
	MonitoringParams MonitoringParameters `toml:"monitoring_params"`
}

// NodeID returns the OPC UA node id
func (tag *NodeSettings) NodeID() string {
	return "ns=" + tag.Namespace + ";" + tag.IdentifierType + "=" + tag.Identifier
}

// NodeGroupSettings describes a mapping of group of nodes to Metrics
type NodeGroupSettings struct {
	MetricName       string            `toml:"name"`            // Overrides plugin's setting
	Namespace        string            `toml:"namespace"`       // Can be overridden by node setting
	IdentifierType   string            `toml:"identifier_type"` // Can be overridden by node setting
	Nodes            []NodeSettings    `toml:"nodes"`
	TagsSlice        [][]string        `toml:"tags" deprecated:"1.26.0;1.35.0;use default_tags"`
	DefaultTags      map[string]string `toml:"default_tags"`
	SamplingInterval config.Duration   `toml:"sampling_interval"` // Can be overridden by monitoring parameters
}

// TimestampSource selects which timestamp is attached to produced metrics.
type TimestampSource string

const (
	TimestampSourceServer   TimestampSource = "server"
	TimestampSourceSource   TimestampSource = "source"
	TimestampSourceTelegraf TimestampSource = "gather"
)

// InputClientConfig a configuration for the input client
type InputClientConfig struct {
	opcua.OpcUAClientConfig
	MetricName      string              `toml:"name"`
	Timestamp       TimestampSource     `toml:"timestamp"`
	TimestampFormat string              `toml:"timestamp_format"`
	RootNodes       []NodeSettings      `toml:"nodes"`
	Groups          []NodeGroupSettings `toml:"group"`
}
|
||||
|
||||
// Validate checks the input configuration for consistency. As a side
// effect it defaults TimestampFormat to RFC3339Nano when unset.
func (o *InputClientConfig) Validate() error {
	if o.MetricName == "" {
		return errors.New("metric name is empty")
	}

	err := choice.Check(string(o.Timestamp), []string{"", "gather", "server", "source"})
	if err != nil {
		return err
	}

	if o.TimestampFormat == "" {
		o.TimestampFormat = time.RFC3339Nano
	}

	// There must be something to collect, and every group must be non-empty.
	if len(o.Groups) == 0 && len(o.RootNodes) == 0 {
		return errors.New("no groups or root nodes provided to gather from")
	}
	for _, group := range o.Groups {
		if len(group.Nodes) == 0 {
			return errors.New("group has no nodes to collect from")
		}
	}

	return nil
}
|
||||
|
||||
// CreateInputClient validates the configuration, builds the underlying
// OPC UA client and initializes the node-to-metric mapping along with the
// per-node last-received-value buffers.
func (o *InputClientConfig) CreateInputClient(log telegraf.Logger) (*OpcUAInputClient, error) {
	if err := o.Validate(); err != nil {
		return nil, err
	}

	log.Debug("Initialising OpcUAInputClient")
	opcClient, err := o.OpcUAClientConfig.CreateClient(log)
	if err != nil {
		return nil, err
	}

	c := &OpcUAInputClient{
		OpcUAClient: opcClient,
		Log:         log,
		Config:      *o,
	}

	log.Debug("Initialising node to metric mapping")
	if err := c.InitNodeMetricMapping(); err != nil {
		return nil, err
	}

	c.initLastReceivedValues()

	return c, nil
}
|
||||
|
||||
// NodeMetricMapping mapping from a single node to a metric
|
||||
type NodeMetricMapping struct {
|
||||
Tag NodeSettings
|
||||
idStr string
|
||||
metricName string
|
||||
MetricTags map[string]string
|
||||
}
|
||||
|
||||
// NewNodeMetricMapping builds a new NodeMetricMapping from the given argument
|
||||
func NewNodeMetricMapping(metricName string, node NodeSettings, groupTags map[string]string) (*NodeMetricMapping, error) {
|
||||
mergedTags := make(map[string]string)
|
||||
for n, t := range groupTags {
|
||||
mergedTags[n] = t
|
||||
}
|
||||
|
||||
nodeTags := make(map[string]string)
|
||||
if len(node.DefaultTags) > 0 {
|
||||
nodeTags = node.DefaultTags
|
||||
} else if len(node.TagsSlice) > 0 {
|
||||
// fixme: once the TagsSlice has been removed (after deprecation), remove this if else logic
|
||||
var err error
|
||||
nodeTags, err = tagsSliceToMap(node.TagsSlice)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
for n, t := range nodeTags {
|
||||
mergedTags[n] = t
|
||||
}
|
||||
|
||||
return &NodeMetricMapping{
|
||||
Tag: node,
|
||||
idStr: node.NodeID(),
|
||||
metricName: metricName,
|
||||
MetricTags: mergedTags,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// NodeValue The received value for a node
type NodeValue struct {
	TagName    string
	Value      interface{}
	Quality    ua.StatusCode
	ServerTime time.Time
	SourceTime time.Time
	DataType   ua.TypeID
	IsArray    bool
}

// OpcUAInputClient can receive data from an OPC UA server and map it to Metrics. This type does not contain
// logic for actually retrieving data from the server, but is used by other types like ReadClient and
// OpcUAInputSubscribeClient to store data needed to convert node ids to the corresponding metrics.
type OpcUAInputClient struct {
	*opcua.OpcUAClient
	Config InputClientConfig
	Log    telegraf.Logger

	// NodeMetricMapping, NodeIDs and LastReceivedData are parallel slices
	// indexed by node position.
	NodeMetricMapping []NodeMetricMapping
	NodeIDs           []*ua.NodeID
	LastReceivedData  []NodeValue
}
|
||||
|
||||
// Stop the connection to the client
func (o *OpcUAInputClient) Stop(ctx context.Context) <-chan struct{} {
	ch := make(chan struct{})
	// The deferred close runs before this function returns, so the channel
	// handed back to the caller is already closed — receivers observe
	// immediate completion.
	defer close(ch)
	err := o.Disconnect(ctx)
	if err != nil {
		o.Log.Warn("Disconnecting from server failed with error ", err)
	}

	return ch
}
|
||||
|
||||
// metricParts is only used to ensure no duplicate metrics are created
|
||||
type metricParts struct {
|
||||
metricName string
|
||||
fieldName string
|
||||
tags string // sorted by tag name and in format tag1=value1, tag2=value2
|
||||
}
|
||||
|
||||
func newMP(n *NodeMetricMapping) metricParts {
|
||||
keys := make([]string, 0, len(n.MetricTags))
|
||||
for key := range n.MetricTags {
|
||||
keys = append(keys, key)
|
||||
}
|
||||
sort.Strings(keys)
|
||||
var sb strings.Builder
|
||||
for i, key := range keys {
|
||||
if i != 0 {
|
||||
sb.WriteString(", ")
|
||||
}
|
||||
sb.WriteString(key)
|
||||
sb.WriteString("=")
|
||||
sb.WriteString(n.MetricTags[key])
|
||||
}
|
||||
x := metricParts{
|
||||
metricName: n.metricName,
|
||||
fieldName: n.Tag.FieldName,
|
||||
tags: sb.String(),
|
||||
}
|
||||
return x
|
||||
}
|
||||
|
||||
// fixme: once the TagsSlice has been removed (after deprecation), remove this
// tagsSliceToMap converts pairs of strings into a map, rejecting
// malformed pairs, empty names/values and duplicate keys. Positions in
// error messages are 1-based.
func tagsSliceToMap(tags [][]string) (map[string]string, error) {
	result := make(map[string]string, len(tags))
	for idx, pair := range tags {
		if len(pair) != 2 {
			return nil, fmt.Errorf("tag %d needs 2 values, has %d: %v", idx+1, len(pair), pair)
		}
		key, value := pair[0], pair[1]
		if key == "" {
			return nil, fmt.Errorf("tag %d has empty name", idx+1)
		}
		if value == "" {
			return nil, fmt.Errorf("tag %d has empty value", idx+1)
		}
		if _, seen := result[key]; seen {
			return nil, fmt.Errorf("tag %d has duplicate key: %v", idx+1, key)
		}
		result[key] = value
	}
	return result, nil
}
|
||||
|
||||
// validateNodeToAdd checks a node mapping for completeness, a valid
// identifier type, and uniqueness against the already-registered mappings.
// On success the mapping's dedup key is recorded in 'existing'.
func validateNodeToAdd(existing map[metricParts]struct{}, nmm *NodeMetricMapping) error {
	if nmm.Tag.FieldName == "" {
		return fmt.Errorf("empty name in %q", nmm.Tag.FieldName)
	}

	if len(nmm.Tag.Namespace) == 0 {
		return errors.New("empty node namespace not allowed")
	}

	if len(nmm.Tag.Identifier) == 0 {
		return errors.New("empty node identifier not allowed")
	}

	// Duplicate (metric name, field name, tags) triples would collide in
	// the output, so reject them.
	mp := newMP(nmm)
	if _, exists := existing[mp]; exists {
		return fmt.Errorf("name %q is duplicated (metric name %q, tags %q)",
			mp.fieldName, mp.metricName, mp.tags)
	}

	switch nmm.Tag.IdentifierType {
	case "i":
		// Numeric identifiers must actually parse as integers.
		if _, err := strconv.Atoi(nmm.Tag.Identifier); err != nil {
			return fmt.Errorf("identifier type %q does not match the type of identifier %q", nmm.Tag.IdentifierType, nmm.Tag.Identifier)
		}
	case "s", "g", "b":
		// Valid identifier type - do nothing.
	default:
		return fmt.Errorf("invalid identifier type %q in %q", nmm.Tag.IdentifierType, nmm.Tag.FieldName)
	}

	existing[mp] = struct{}{}
	return nil
}
|
||||
|
||||
// InitNodeMetricMapping builds nodes from the configuration
func (o *OpcUAInputClient) InitNodeMetricMapping() error {
	// Root nodes use the plugin-level metric name and no group tags.
	existing := make(map[metricParts]struct{}, len(o.Config.RootNodes))
	for _, node := range o.Config.RootNodes {
		nmm, err := NewNodeMetricMapping(o.Config.MetricName, node, make(map[string]string))
		if err != nil {
			return err
		}

		if err := validateNodeToAdd(existing, nmm); err != nil {
			return err
		}
		o.NodeMetricMapping = append(o.NodeMetricMapping, *nmm)
	}

	for _, group := range o.Config.Groups {
		// 'group' is a copy, so defaulting its fields does not modify the
		// stored configuration.
		if group.MetricName == "" {
			group.MetricName = o.Config.MetricName
		}

		if len(group.DefaultTags) > 0 && len(group.TagsSlice) > 0 {
			o.Log.Warn("Tags found in both `tags` and `default_tags`, only using tags defined in `default_tags`")
		}

		groupTags := make(map[string]string)
		if len(group.DefaultTags) > 0 {
			groupTags = group.DefaultTags
		} else if len(group.TagsSlice) > 0 {
			// fixme: once the TagsSlice has been removed (after deprecation), remove this if else logic
			var err error
			groupTags, err = tagsSliceToMap(group.TagsSlice)
			if err != nil {
				return err
			}
		}

		for _, node := range group.Nodes {
			// Group-level settings act as defaults for each node.
			if node.Namespace == "" {
				node.Namespace = group.Namespace
			}
			if node.IdentifierType == "" {
				node.IdentifierType = group.IdentifierType
			}
			if node.MonitoringParams.SamplingInterval == 0 {
				node.MonitoringParams.SamplingInterval = group.SamplingInterval
			}

			nmm, err := NewNodeMetricMapping(group.MetricName, node, groupTags)
			if err != nil {
				return err
			}

			if err := validateNodeToAdd(existing, nmm); err != nil {
				return err
			}
			o.NodeMetricMapping = append(o.NodeMetricMapping, *nmm)
		}
	}

	return nil
}
|
||||
|
||||
func (o *OpcUAInputClient) InitNodeIDs() error {
|
||||
o.NodeIDs = make([]*ua.NodeID, 0, len(o.NodeMetricMapping))
|
||||
for _, node := range o.NodeMetricMapping {
|
||||
nid, err := ua.ParseNodeID(node.Tag.NodeID())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
o.NodeIDs = append(o.NodeIDs, nid)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (o *OpcUAInputClient) initLastReceivedValues() {
|
||||
o.LastReceivedData = make([]NodeValue, len(o.NodeMetricMapping))
|
||||
for nodeIdx, nmm := range o.NodeMetricMapping {
|
||||
o.LastReceivedData[nodeIdx].TagName = nmm.Tag.FieldName
|
||||
}
|
||||
}
|
||||
|
||||
func (o *OpcUAInputClient) UpdateNodeValue(nodeIdx int, d *ua.DataValue) {
|
||||
o.LastReceivedData[nodeIdx].Quality = d.Status
|
||||
if !o.StatusCodeOK(d.Status) {
|
||||
// Verify NodeIDs array has been built before trying to get item; otherwise show '?' for node id
|
||||
if len(o.NodeIDs) > nodeIdx {
|
||||
o.Log.Errorf("status not OK for node %v (%v): %v", o.NodeMetricMapping[nodeIdx].Tag.FieldName, o.NodeIDs[nodeIdx].String(), d.Status)
|
||||
} else {
|
||||
o.Log.Errorf("status not OK for node %v (%v): %v", o.NodeMetricMapping[nodeIdx].Tag.FieldName, '?', d.Status)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if d.Value != nil {
|
||||
o.LastReceivedData[nodeIdx].DataType = d.Value.Type()
|
||||
o.LastReceivedData[nodeIdx].IsArray = d.Value.Has(ua.VariantArrayValues)
|
||||
|
||||
o.LastReceivedData[nodeIdx].Value = d.Value.Value()
|
||||
if o.LastReceivedData[nodeIdx].DataType == ua.TypeIDDateTime {
|
||||
if t, ok := d.Value.Value().(time.Time); ok {
|
||||
o.LastReceivedData[nodeIdx].Value = t.Format(o.Config.TimestampFormat)
|
||||
}
|
||||
}
|
||||
}
|
||||
o.LastReceivedData[nodeIdx].ServerTime = d.ServerTimestamp
|
||||
o.LastReceivedData[nodeIdx].SourceTime = d.SourceTimestamp
|
||||
}
|
||||
|
||||
// MetricForNode converts the last received value of the node at nodeIdx
// into a telegraf.Metric. Tags are the mapping's configured tags plus an
// "id" tag holding the node-id string; fields hold the (possibly unpacked)
// value plus a "Quality" field and, optionally, a "DataType" field. The
// metric timestamp comes from the configured timestamp source.
func (o *OpcUAInputClient) MetricForNode(nodeIdx int) telegraf.Metric {
    nmm := &o.NodeMetricMapping[nodeIdx]
    tags := map[string]string{
        "id": nmm.idStr,
    }
    for k, v := range nmm.MetricTags {
        tags[k] = v
    }

    fields := make(map[string]interface{})
    // A nil value (nothing received yet, or null variant) produces no value
    // field; only Quality (and optionally DataType) are emitted.
    if o.LastReceivedData[nodeIdx].Value != nil {
        // Simple scalar types can be stored directly under the field name while
        // arrays (see 5.2.5) and structures (see 5.2.6) must be unpacked.
        // Note: Structures and arrays of structures are currently not supported.
        if o.LastReceivedData[nodeIdx].IsArray {
            // The type switch enumerates every element type unpack()
            // supports; unsupported element types are logged and dropped.
            switch typedValue := o.LastReceivedData[nodeIdx].Value.(type) {
            case []uint8:
                fields = unpack(nmm.Tag.FieldName, typedValue)
            case []uint16:
                fields = unpack(nmm.Tag.FieldName, typedValue)
            case []uint32:
                fields = unpack(nmm.Tag.FieldName, typedValue)
            case []uint64:
                fields = unpack(nmm.Tag.FieldName, typedValue)
            case []int8:
                fields = unpack(nmm.Tag.FieldName, typedValue)
            case []int16:
                fields = unpack(nmm.Tag.FieldName, typedValue)
            case []int32:
                fields = unpack(nmm.Tag.FieldName, typedValue)
            case []int64:
                fields = unpack(nmm.Tag.FieldName, typedValue)
            case []float32:
                fields = unpack(nmm.Tag.FieldName, typedValue)
            case []float64:
                fields = unpack(nmm.Tag.FieldName, typedValue)
            case []string:
                fields = unpack(nmm.Tag.FieldName, typedValue)
            case []bool:
                fields = unpack(nmm.Tag.FieldName, typedValue)
            default:
                o.Log.Errorf("could not unpack variant array of type: %T", typedValue)
            }
        } else {
            fields = map[string]interface{}{
                nmm.Tag.FieldName: o.LastReceivedData[nodeIdx].Value,
            }
        }
    }

    fields["Quality"] = strings.TrimSpace(o.LastReceivedData[nodeIdx].Quality.Error())
    if choice.Contains("DataType", o.Config.OptionalFields) {
        // Strip the "TypeID" prefix so e.g. "TypeIDInt32" becomes "Int32".
        fields["DataType"] = strings.Replace(o.LastReceivedData[nodeIdx].DataType.String(), "TypeID", "", 1)
    }
    if !o.StatusCodeOK(o.LastReceivedData[nodeIdx].Quality) {
        mp := newMP(nmm)
        o.Log.Debugf("status not OK for node %q(metric name %q, tags %q)",
            mp.fieldName, mp.metricName, mp.tags)
    }

    // Pick the metric timestamp from the configured source; anything other
    // than server/source time falls back to the local clock.
    var t time.Time
    switch o.Config.Timestamp {
    case TimestampSourceServer:
        t = o.LastReceivedData[nodeIdx].ServerTime
    case TimestampSourceSource:
        t = o.LastReceivedData[nodeIdx].SourceTime
    default:
        t = time.Now()
    }

    return metric.New(nmm.metricName, tags, fields, t)
}
|
||||
|
||||
// unpack expands a slice value into one field per element, keyed as
// "prefix[i]" (e.g. "temp[0]", "temp[1]"). An empty slice yields an empty
// map. Key construction uses strconv + concatenation instead of
// fmt.Sprintf to avoid per-element reflection and interface boxing.
func unpack[Slice ~[]E, E any](prefix string, value Slice) map[string]interface{} {
    fields := make(map[string]interface{}, len(value))
    for i, v := range value {
        fields[prefix+"["+strconv.Itoa(i)+"]"] = v
    }
    return fields
}
|
890
plugins/common/opcua/input/input_client_test.go
Normal file
890
plugins/common/opcua/input/input_client_test.go
Normal file
|
@ -0,0 +1,890 @@
|
|||
package input
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gopcua/opcua/ua"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/influxdata/telegraf"
|
||||
"github.com/influxdata/telegraf/config"
|
||||
"github.com/influxdata/telegraf/metric"
|
||||
"github.com/influxdata/telegraf/plugins/common/opcua"
|
||||
"github.com/influxdata/telegraf/testutil"
|
||||
)
|
||||
|
||||
func TestTagsSliceToMap(t *testing.T) {
|
||||
m, err := tagsSliceToMap([][]string{{"foo", "bar"}, {"baz", "bat"}})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, m, 2)
|
||||
require.Equal(t, "bar", m["foo"])
|
||||
require.Equal(t, "bat", m["baz"])
|
||||
}
|
||||
|
||||
func TestTagsSliceToMap_twoStrings(t *testing.T) {
|
||||
var err error
|
||||
_, err = tagsSliceToMap([][]string{{"foo", "bar", "baz"}})
|
||||
require.Error(t, err)
|
||||
_, err = tagsSliceToMap([][]string{{"foo"}})
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestTagsSliceToMap_dupeKey(t *testing.T) {
|
||||
_, err := tagsSliceToMap([][]string{{"foo", "bar"}, {"foo", "bat"}})
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestTagsSliceToMap_empty(t *testing.T) {
|
||||
_, err := tagsSliceToMap([][]string{{"foo", ""}})
|
||||
require.Equal(t, errors.New("tag 1 has empty value"), err)
|
||||
_, err = tagsSliceToMap([][]string{{"", "bar"}})
|
||||
require.Equal(t, errors.New("tag 1 has empty name"), err)
|
||||
}
|
||||
|
||||
// TestValidateOPCTags exercises InitNodeMetricMapping's duplicate/tag
// validation: duplicate (metric, field, tags) triples across root nodes and
// groups must be rejected, empty tag names/values must be rejected, and
// nodes differing in tag names, tag values, metric names, or field names
// must all be accepted.
func TestValidateOPCTags(t *testing.T) {
    tests := []struct {
        name   string
        config InputClientConfig
        err    error
    }{
        {
            // Same field/metric/tags in a root node and a group node.
            "duplicates",
            InputClientConfig{
                MetricName: "mn",
                RootNodes: []NodeSettings{
                    {
                        FieldName:      "fn",
                        Namespace:      "2",
                        IdentifierType: "s",
                        Identifier:     "i1",
                        TagsSlice:      [][]string{{"t1", "v1"}, {"t2", "v2"}},
                    },
                },
                Groups: []NodeGroupSettings{
                    {
                        Nodes: []NodeSettings{
                            {
                                FieldName:      "fn",
                                Namespace:      "2",
                                IdentifierType: "s",
                                Identifier:     "i1",
                            },
                        },
                        TagsSlice: [][]string{{"t1", "v1"}, {"t2", "v2"}},
                    },
                },
            },
            errors.New(`name "fn" is duplicated (metric name "mn", tags "t1=v1, t2=v2")`),
        },
        {
            "empty tag value not allowed",
            InputClientConfig{
                MetricName: "mn",
                RootNodes: []NodeSettings{
                    {
                        FieldName:      "fn",
                        IdentifierType: "s",
                        TagsSlice:      [][]string{{"t1", ""}},
                    },
                },
            },
            errors.New("tag 1 has empty value"),
        },
        {
            "empty tag name not allowed",
            InputClientConfig{
                MetricName: "mn",
                RootNodes: []NodeSettings{
                    {
                        FieldName:      "fn",
                        IdentifierType: "s",
                        TagsSlice:      [][]string{{"", "1"}},
                    },
                },
            },
            errors.New("tag 1 has empty name"),
        },
        {
            // Identical nodes disambiguated by a differing tag name: OK.
            "different metric tag names",
            InputClientConfig{
                MetricName: "mn",
                RootNodes: []NodeSettings{
                    {
                        FieldName:      "fn",
                        Namespace:      "2",
                        IdentifierType: "s",
                        Identifier:     "i1",
                        TagsSlice:      [][]string{{"t1", "v1"}, {"t2", "v2"}},
                    },
                    {
                        FieldName:      "fn",
                        Namespace:      "2",
                        IdentifierType: "s",
                        Identifier:     "i1",
                        TagsSlice:      [][]string{{"t1", "v1"}, {"t3", "v2"}},
                    },
                },
            },
            nil,
        },
        {
            // Identical nodes disambiguated by a differing tag value: OK.
            "different metric tag values",
            InputClientConfig{
                MetricName: "mn",
                RootNodes: []NodeSettings{
                    {
                        FieldName:      "fn",
                        Namespace:      "2",
                        IdentifierType: "s",
                        Identifier:     "i1",
                        TagsSlice:      [][]string{{"t1", "foo"}, {"t2", "v2"}},
                    },
                    {
                        FieldName:      "fn",
                        Namespace:      "2",
                        IdentifierType: "s",
                        Identifier:     "i1",
                        TagsSlice:      [][]string{{"t1", "bar"}, {"t2", "v2"}},
                    },
                },
            },
            nil,
        },
        {
            // Same node in two groups with different metric names: OK.
            "different metric names",
            InputClientConfig{
                MetricName: "mn",
                Groups: []NodeGroupSettings{
                    {
                        MetricName: "mn",
                        Namespace:  "2",
                        Nodes: []NodeSettings{
                            {
                                FieldName:      "fn",
                                IdentifierType: "s",
                                Identifier:     "i1",
                                TagsSlice:      [][]string{{"t1", "v1"}, {"t2", "v2"}},
                            },
                        },
                    },
                    {
                        MetricName: "mn2",
                        Namespace:  "2",
                        Nodes: []NodeSettings{
                            {
                                FieldName:      "fn",
                                IdentifierType: "s",
                                Identifier:     "i1",
                                TagsSlice:      [][]string{{"t1", "v1"}, {"t2", "v2"}},
                            },
                        },
                    },
                },
            },
            nil,
        },
        {
            // Same node twice with different field names: OK.
            "different field names",
            InputClientConfig{
                MetricName: "mn",
                RootNodes: []NodeSettings{
                    {
                        FieldName:      "fn",
                        Namespace:      "2",
                        IdentifierType: "s",
                        Identifier:     "i1",
                        TagsSlice:      [][]string{{"t1", "v1"}, {"t2", "v2"}},
                    },
                    {
                        FieldName:      "fn2",
                        Namespace:      "2",
                        IdentifierType: "s",
                        Identifier:     "i1",
                        TagsSlice:      [][]string{{"t1", "v1"}, {"t2", "v2"}},
                    },
                },
            },
            nil,
        },
    }
    for _, tt := range tests {
        t.Run(tt.name, func(t *testing.T) {
            o := OpcUAInputClient{
                Config: tt.config,
                Log:    testutil.Logger{},
            }
            require.Equal(t, tt.err, o.InitNodeMetricMapping())
        })
    }
}
|
||||
|
||||
// TestNewNodeMetricMappingTags verifies how node-level tags and group-level
// tags are merged when a mapping is built: node tags override group tags on
// the same key, and otherwise the two sets are merged.
func TestNewNodeMetricMappingTags(t *testing.T) {
    tests := []struct {
        name         string
        settings     NodeSettings
        groupTags    map[string]string
        expectedTags map[string]string
        err          error
    }{
        {
            name: "empty tags",
            settings: NodeSettings{
                FieldName:      "f",
                Namespace:      "2",
                IdentifierType: "s",
                Identifier:     "h",
            },
            groupTags:    map[string]string{},
            expectedTags: map[string]string{},
            err:          nil,
        },
        {
            name: "node tags only",
            settings: NodeSettings{
                FieldName:      "f",
                Namespace:      "2",
                IdentifierType: "s",
                Identifier:     "h",
                TagsSlice:      [][]string{{"t1", "v1"}},
            },
            groupTags:    map[string]string{},
            expectedTags: map[string]string{"t1": "v1"},
            err:          nil,
        },
        {
            name: "group tags only",
            settings: NodeSettings{
                FieldName:      "f",
                Namespace:      "2",
                IdentifierType: "s",
                Identifier:     "h",
            },
            groupTags:    map[string]string{"t1": "v1"},
            expectedTags: map[string]string{"t1": "v1"},
            err:          nil,
        },
        {
            name: "node tag overrides group tags",
            settings: NodeSettings{
                FieldName:      "f",
                Namespace:      "2",
                IdentifierType: "s",
                Identifier:     "h",
                TagsSlice:      [][]string{{"t1", "v2"}},
            },
            groupTags:    map[string]string{"t1": "v1"},
            expectedTags: map[string]string{"t1": "v2"},
            err:          nil,
        },
        {
            name: "node tag merged with group tags",
            settings: NodeSettings{
                FieldName:      "f",
                Namespace:      "2",
                IdentifierType: "s",
                Identifier:     "h",
                TagsSlice:      [][]string{{"t2", "v2"}},
            },
            groupTags:    map[string]string{"t1": "v1"},
            expectedTags: map[string]string{"t1": "v1", "t2": "v2"},
            err:          nil,
        },
    }
    for _, tt := range tests {
        t.Run(tt.name, func(t *testing.T) {
            nmm, err := NewNodeMetricMapping("testmetric", tt.settings, tt.groupTags)
            require.Equal(t, tt.err, err)
            require.Equal(t, tt.expectedTags, nmm.MetricTags)
        })
    }
}
|
||||
|
||||
func TestNewNodeMetricMappingIdStrInstantiated(t *testing.T) {
|
||||
nmm, err := NewNodeMetricMapping("testmetric", NodeSettings{
|
||||
FieldName: "f",
|
||||
Namespace: "2",
|
||||
IdentifierType: "s",
|
||||
Identifier: "h",
|
||||
}, map[string]string{})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "ns=2;s=h", nmm.idStr)
|
||||
}
|
||||
|
||||
// TestValidateNodeToAdd exercises validateNodeToAdd directly: empty field
// names, namespaces, and identifier types are rejected; identifier types
// must be one of s/i/g/b; type "i" must parse as an integer; and mappings
// whose (metric, field, tags) triple is already present are rejected.
func TestValidateNodeToAdd(t *testing.T) {
    tests := []struct {
        name     string
        existing map[metricParts]struct{}
        nmm      *NodeMetricMapping
        err      error
    }{
        {
            name:     "valid",
            existing: map[metricParts]struct{}{},
            nmm: func() *NodeMetricMapping {
                nmm, err := NewNodeMetricMapping("testmetric", NodeSettings{
                    FieldName:      "f",
                    Namespace:      "2",
                    IdentifierType: "s",
                    Identifier:     "hf",
                }, map[string]string{})
                require.NoError(t, err)
                return nmm
            }(),
            err: nil,
        },
        {
            name:     "empty field name not allowed",
            existing: map[metricParts]struct{}{},
            nmm: func() *NodeMetricMapping {
                nmm, err := NewNodeMetricMapping("testmetric", NodeSettings{
                    FieldName:      "",
                    Namespace:      "2",
                    IdentifierType: "s",
                    Identifier:     "hf",
                }, map[string]string{})
                require.NoError(t, err)
                return nmm
            }(),
            err: errors.New(`empty name in ""`),
        },
        {
            name:     "empty namespace not allowed",
            existing: map[metricParts]struct{}{},
            nmm: func() *NodeMetricMapping {
                nmm, err := NewNodeMetricMapping("testmetric", NodeSettings{
                    FieldName:      "f",
                    Namespace:      "",
                    IdentifierType: "s",
                    Identifier:     "hf",
                }, map[string]string{})
                require.NoError(t, err)
                return nmm
            }(),
            err: errors.New("empty node namespace not allowed"),
        },
        {
            name:     "empty identifier type not allowed",
            existing: map[metricParts]struct{}{},
            nmm: func() *NodeMetricMapping {
                nmm, err := NewNodeMetricMapping("testmetric", NodeSettings{
                    FieldName:      "f",
                    Namespace:      "2",
                    IdentifierType: "",
                    Identifier:     "hf",
                }, map[string]string{})
                require.NoError(t, err)
                return nmm
            }(),
            err: errors.New(`invalid identifier type "" in "f"`),
        },
        {
            name:     "invalid identifier type not allowed",
            existing: map[metricParts]struct{}{},
            nmm: func() *NodeMetricMapping {
                nmm, err := NewNodeMetricMapping("testmetric", NodeSettings{
                    FieldName:      "f",
                    Namespace:      "2",
                    IdentifierType: "j",
                    Identifier:     "hf",
                }, map[string]string{})
                require.NoError(t, err)
                return nmm
            }(),
            err: errors.New(`invalid identifier type "j" in "f"`),
        },
        {
            name: "duplicate metric not allowed",
            existing: map[metricParts]struct{}{
                {metricName: "testmetric", fieldName: "f", tags: "t1=v1, t2=v2"}: {},
            },
            nmm: func() *NodeMetricMapping {
                nmm, err := NewNodeMetricMapping("testmetric", NodeSettings{
                    FieldName:      "f",
                    Namespace:      "2",
                    IdentifierType: "s",
                    Identifier:     "hf",
                    TagsSlice:      [][]string{{"t1", "v1"}, {"t2", "v2"}},
                }, map[string]string{})
                require.NoError(t, err)
                return nmm
            }(),
            err: errors.New(`name "f" is duplicated (metric name "testmetric", tags "t1=v1, t2=v2")`),
        },
        {
            // Type "i" requires a numeric identifier; "hf" is not one.
            name:     "identifier type mismatch",
            existing: map[metricParts]struct{}{},
            nmm: func() *NodeMetricMapping {
                nmm, err := NewNodeMetricMapping("testmetric", NodeSettings{
                    FieldName:      "f",
                    Namespace:      "2",
                    IdentifierType: "i",
                    Identifier:     "hf",
                }, map[string]string{})
                require.NoError(t, err)
                return nmm
            }(),
            err: errors.New(`identifier type "i" does not match the type of identifier "hf"`),
        },
    }

    // Generate one positive case per supported identifier type, each with a
    // representative well-formed identifier.
    for idT, idV := range map[string]string{
        "s": "hf",
        "i": "1",
        "g": "849683f0-ce92-4fa2-836f-a02cde61d75d",
        "b": "aGVsbG8gSSBhbSBhIHRlc3QgaWRlbnRpZmllcg=="} {
        tests = append(tests, struct {
            name     string
            existing map[metricParts]struct{}
            nmm      *NodeMetricMapping
            err      error
        }{
            name:     "identifier type " + idT + " allowed",
            existing: map[metricParts]struct{}{},
            nmm: func() *NodeMetricMapping {
                nmm, err := NewNodeMetricMapping("testmetric", NodeSettings{
                    FieldName:      "f",
                    Namespace:      "2",
                    IdentifierType: idT,
                    Identifier:     idV,
                }, map[string]string{})
                require.NoError(t, err)
                return nmm
            }(),
            err: nil,
        })
    }

    for _, tt := range tests {
        t.Run(tt.name, func(t *testing.T) {
            err := validateNodeToAdd(tt.existing, tt.nmm)
            require.Equal(t, tt.err, err)
        })
    }
}
|
||||
|
||||
// TestInitNodeMetricMapping verifies the mappings produced from root nodes
// and group nodes, including group-default inheritance (namespace,
// identifier type, metric name) and that DefaultTags take precedence over
// the deprecated TagsSlice.
func TestInitNodeMetricMapping(t *testing.T) {
    tests := []struct {
        testname string
        config   InputClientConfig
        expected []NodeMetricMapping
        err      error
    }{
        {
            testname: "only root node",
            config: InputClientConfig{
                MetricName: "testmetric",
                Timestamp:  TimestampSourceTelegraf,
                RootNodes: []NodeSettings{
                    {
                        FieldName:      "f",
                        Namespace:      "2",
                        IdentifierType: "s",
                        Identifier:     "id1",
                        TagsSlice:      [][]string{{"t1", "v1"}},
                    },
                },
            },
            expected: []NodeMetricMapping{
                {
                    Tag: NodeSettings{
                        FieldName:      "f",
                        Namespace:      "2",
                        IdentifierType: "s",
                        Identifier:     "id1",
                        TagsSlice:      [][]string{{"t1", "v1"}},
                    },
                    idStr:      "ns=2;s=id1",
                    metricName: "testmetric",
                    MetricTags: map[string]string{"t1": "v1"},
                },
            },
            err: nil,
        },
        {
            testname: "root node and group node",
            config: InputClientConfig{
                MetricName: "testmetric",
                Timestamp:  TimestampSourceTelegraf,
                RootNodes: []NodeSettings{
                    {
                        FieldName:      "f",
                        Namespace:      "2",
                        IdentifierType: "s",
                        Identifier:     "id1",
                        TagsSlice:      [][]string{{"t1", "v1"}},
                    },
                },
                Groups: []NodeGroupSettings{
                    {
                        MetricName:     "groupmetric",
                        Namespace:      "3",
                        IdentifierType: "s",
                        Nodes: []NodeSettings{
                            {
                                FieldName:  "f",
                                Identifier: "id2",
                                TagsSlice:  [][]string{{"t2", "v2"}},
                            },
                        },
                    },
                },
            },
            expected: []NodeMetricMapping{
                {
                    Tag: NodeSettings{
                        FieldName:      "f",
                        Namespace:      "2",
                        IdentifierType: "s",
                        Identifier:     "id1",
                        TagsSlice:      [][]string{{"t1", "v1"}},
                    },
                    idStr:      "ns=2;s=id1",
                    metricName: "testmetric",
                    MetricTags: map[string]string{"t1": "v1"},
                },
                {
                    // Namespace and identifier type are inherited from the group.
                    Tag: NodeSettings{
                        FieldName:      "f",
                        Namespace:      "3",
                        IdentifierType: "s",
                        Identifier:     "id2",
                        TagsSlice:      [][]string{{"t2", "v2"}},
                    },
                    idStr:      "ns=3;s=id2",
                    metricName: "groupmetric",
                    MetricTags: map[string]string{"t2": "v2"},
                },
            },
            err: nil,
        },
        {
            testname: "only group node",
            config: InputClientConfig{
                MetricName: "testmetric",
                Timestamp:  TimestampSourceTelegraf,
                Groups: []NodeGroupSettings{
                    {
                        MetricName:     "groupmetric",
                        Namespace:      "3",
                        IdentifierType: "s",
                        Nodes: []NodeSettings{
                            {
                                FieldName:  "f",
                                Identifier: "id2",
                                TagsSlice:  [][]string{{"t2", "v2"}},
                            },
                        },
                    },
                },
            },
            expected: []NodeMetricMapping{
                {
                    Tag: NodeSettings{
                        FieldName:      "f",
                        Namespace:      "3",
                        IdentifierType: "s",
                        Identifier:     "id2",
                        TagsSlice:      [][]string{{"t2", "v2"}},
                    },
                    idStr:      "ns=3;s=id2",
                    metricName: "groupmetric",
                    MetricTags: map[string]string{"t2": "v2"},
                },
            },
            err: nil,
        },
        {
            // When both TagsSlice and DefaultTags are set on a node,
            // DefaultTags wins.
            testname: "tags and default only default tags used",
            config: InputClientConfig{
                MetricName: "testmetric",
                Timestamp:  TimestampSourceTelegraf,
                Groups: []NodeGroupSettings{
                    {
                        MetricName:     "groupmetric",
                        Namespace:      "3",
                        IdentifierType: "s",
                        Nodes: []NodeSettings{
                            {
                                FieldName:   "f",
                                Identifier:  "id2",
                                TagsSlice:   [][]string{{"t2", "v2"}},
                                DefaultTags: map[string]string{"t3": "v3"},
                            },
                        },
                    },
                },
            },
            expected: []NodeMetricMapping{
                {
                    Tag: NodeSettings{
                        FieldName:      "f",
                        Namespace:      "3",
                        IdentifierType: "s",
                        Identifier:     "id2",
                        TagsSlice:      [][]string{{"t2", "v2"}},
                        DefaultTags:    map[string]string{"t3": "v3"},
                    },
                    idStr:      "ns=3;s=id2",
                    metricName: "groupmetric",
                    MetricTags: map[string]string{"t3": "v3"},
                },
            },
            err: nil,
        },
        {
            testname: "only root node default overrides slice",
            config: InputClientConfig{
                MetricName: "testmetric",
                Timestamp:  TimestampSourceTelegraf,
                RootNodes: []NodeSettings{
                    {
                        FieldName:      "f",
                        Namespace:      "2",
                        IdentifierType: "s",
                        Identifier:     "id1",
                        TagsSlice:      [][]string{{"t1", "v1"}},
                        DefaultTags:    map[string]string{"t3": "v3"},
                    },
                },
            },
            expected: []NodeMetricMapping{
                {
                    Tag: NodeSettings{
                        FieldName:      "f",
                        Namespace:      "2",
                        IdentifierType: "s",
                        Identifier:     "id1",
                        TagsSlice:      [][]string{{"t1", "v1"}},
                        DefaultTags:    map[string]string{"t3": "v3"},
                    },
                    idStr:      "ns=2;s=id1",
                    metricName: "testmetric",
                    MetricTags: map[string]string{"t3": "v3"},
                },
            },
            err: nil,
        },
    }

    for _, tt := range tests {
        t.Run(tt.testname, func(t *testing.T) {
            o := OpcUAInputClient{Config: tt.config}
            err := o.InitNodeMetricMapping()
            require.NoError(t, err)
            require.Equal(t, tt.expected, o.NodeMetricMapping)
        })
    }
}
|
||||
|
||||
// TestUpdateNodeValue verifies that UpdateNodeValue stores a new value when
// the status code is OK and preserves the last good value when the status
// is bad.
func TestUpdateNodeValue(t *testing.T) {
    type testStep struct {
        nodeIdx  int
        value    interface{}
        status   ua.StatusCode
        expected interface{}
    }
    tests := []struct {
        testname string
        steps    []testStep
    }{
        {
            "value should update when code ok",
            []testStep{
                {
                    0,
                    "Harmony",
                    ua.StatusOK,
                    "Harmony",
                },
            },
        },
        {
            // The bad-status middle step must leave the previous value intact.
            "value should not update when code bad",
            []testStep{
                {
                    0,
                    "Harmony",
                    ua.StatusOK,
                    "Harmony",
                },
                {
                    0,
                    "Odium",
                    ua.StatusBad,
                    "Harmony",
                },
                {
                    0,
                    "Ati",
                    ua.StatusOK,
                    "Ati",
                },
            },
        },
    }

    // A real (unconnected) client is required because UpdateNodeValue calls
    // StatusCodeOK on the embedded OpcUAClient.
    conf := &opcua.OpcUAClientConfig{
        Endpoint:       "opc.tcp://localhost:4930",
        SecurityPolicy: "None",
        SecurityMode:   "None",
        AuthMethod:     "",
        ConnectTimeout: config.Duration(2 * time.Second),
        RequestTimeout: config.Duration(2 * time.Second),
        Workarounds:    opcua.OpcUAWorkarounds{},
    }
    c, err := conf.CreateClient(testutil.Logger{})
    require.NoError(t, err)
    o := OpcUAInputClient{
        OpcUAClient: c,
        Log:         testutil.Logger{},
        NodeMetricMapping: []NodeMetricMapping{
            {
                Tag: NodeSettings{
                    FieldName: "f",
                },
            },
            {
                Tag: NodeSettings{
                    FieldName: "f2",
                },
            },
        },
        LastReceivedData: make([]NodeValue, 2),
    }

    for _, tt := range tests {
        t.Run(tt.testname, func(t *testing.T) {
            // Reset the cache so each case starts clean.
            o.LastReceivedData = make([]NodeValue, 2)
            for i, step := range tt.steps {
                v, err := ua.NewVariant(step.value)
                require.NoError(t, err)
                o.UpdateNodeValue(0, &ua.DataValue{
                    Value:             v,
                    Status:            step.status,
                    SourceTimestamp:   time.Date(2022, 03, 17, 8, 33, 00, 00, &time.Location{}).Add(time.Duration(i) * time.Second),
                    SourcePicoseconds: 0,
                    ServerTimestamp:   time.Date(2022, 03, 17, 8, 33, 00, 500, &time.Location{}).Add(time.Duration(i) * time.Second),
                    ServerPicoseconds: 0,
                })
                require.Equal(t, step.expected, o.LastReceivedData[0].Value)
            }
        })
    }
}
|
||||
|
||||
// TestMetricForNode verifies metric construction for a scalar value, an
// array value (unpacked into indexed fields), and a nil value (only the
// Quality field, no panic).
func TestMetricForNode(t *testing.T) {
    conf := &opcua.OpcUAClientConfig{
        Endpoint:       "opc.tcp://localhost:4930",
        SecurityPolicy: "None",
        SecurityMode:   "None",
        AuthMethod:     "",
        ConnectTimeout: config.Duration(2 * time.Second),
        RequestTimeout: config.Duration(2 * time.Second),
        Workarounds:    opcua.OpcUAWorkarounds{},
    }
    c, err := conf.CreateClient(testutil.Logger{})
    require.NoError(t, err)
    o := OpcUAInputClient{
        Config: InputClientConfig{
            Timestamp: TimestampSourceSource,
        },
        OpcUAClient:      c,
        Log:              testutil.Logger{},
        LastReceivedData: make([]NodeValue, 2),
    }

    tests := []struct {
        testname string
        nmm      []NodeMetricMapping
        v        interface{}
        isArray  bool
        dataType ua.TypeID
        time     time.Time
        status   ua.StatusCode
        expected telegraf.Metric
    }{
        {
            testname: "metric build correctly",
            nmm: []NodeMetricMapping{
                {
                    Tag: NodeSettings{
                        FieldName: "fn",
                    },
                    idStr:      "ns=3;s=hi",
                    metricName: "testingmetric",
                    MetricTags: map[string]string{"t1": "v1"},
                },
            },
            v:        16,
            isArray:  false,
            dataType: ua.TypeIDInt32,
            time:     time.Date(2022, 03, 17, 8, 55, 00, 00, &time.Location{}),
            status:   ua.StatusOK,
            expected: metric.New("testingmetric",
                map[string]string{"t1": "v1", "id": "ns=3;s=hi"},
                map[string]interface{}{"Quality": "The operation succeeded. StatusGood (0x0)", "fn": 16},
                time.Date(2022, 03, 17, 8, 55, 00, 00, &time.Location{})),
        },
        {
            // Arrays are unpacked into one field per element: fn[0], fn[1].
            testname: "array-like metric build correctly",
            nmm: []NodeMetricMapping{
                {
                    Tag: NodeSettings{
                        FieldName: "fn",
                    },
                    idStr:      "ns=3;s=hi",
                    metricName: "testingmetric",
                    MetricTags: map[string]string{"t1": "v1"},
                },
            },
            v:        []int32{16, 17},
            isArray:  true,
            dataType: ua.TypeIDInt32,
            time:     time.Date(2022, 03, 17, 8, 55, 00, 00, &time.Location{}),
            status:   ua.StatusOK,
            expected: metric.New("testingmetric",
                map[string]string{"t1": "v1", "id": "ns=3;s=hi"},
                map[string]interface{}{"Quality": "The operation succeeded. StatusGood (0x0)", "fn[0]": 16, "fn[1]": 17},
                time.Date(2022, 03, 17, 8, 55, 00, 00, &time.Location{})),
        },
        {
            testname: "nil does not panic",
            nmm: []NodeMetricMapping{
                {
                    Tag: NodeSettings{
                        FieldName: "fn",
                    },
                    idStr:      "ns=3;s=hi",
                    metricName: "testingmetric",
                    MetricTags: map[string]string{"t1": "v1"},
                },
            },
            v:        nil,
            isArray:  false,
            dataType: ua.TypeIDNull,
            time:     time.Date(2022, 03, 17, 8, 55, 00, 00, &time.Location{}),
            status:   ua.StatusOK,
            expected: metric.New("testingmetric",
                map[string]string{"t1": "v1", "id": "ns=3;s=hi"},
                map[string]interface{}{"Quality": "The operation succeeded. StatusGood (0x0)"},
                time.Date(2022, 03, 17, 8, 55, 00, 00, &time.Location{})),
        },
    }

    for _, tt := range tests {
        t.Run(tt.testname, func(t *testing.T) {
            o.NodeMetricMapping = tt.nmm
            o.LastReceivedData[0].SourceTime = tt.time
            o.LastReceivedData[0].Quality = tt.status
            o.LastReceivedData[0].Value = tt.v
            o.LastReceivedData[0].DataType = tt.dataType
            o.LastReceivedData[0].IsArray = tt.isArray
            actual := o.MetricForNode(0)
            require.Equal(t, tt.expected.Tags(), actual.Tags())
            require.Equal(t, tt.expected.Fields(), actual.Fields())
            require.Equal(t, tt.expected.Time(), actual.Time())
        })
    }
}
|
15
plugins/common/opcua/logger.go
Normal file
15
plugins/common/opcua/logger.go
Normal file
|
@ -0,0 +1,15 @@
|
|||
package opcua
|
||||
|
||||
import (
|
||||
"github.com/influxdata/telegraf"
|
||||
)
|
||||
|
||||
// DebugLogger logs messages from opcua at the debug level.
// It adapts a telegraf.Logger to io.Writer so it can be handed to the
// gopcua debug package as its log sink.
type DebugLogger struct {
    // Log receives each written payload verbatim at debug level.
    Log telegraf.Logger
}
|
||||
|
||||
func (l *DebugLogger) Write(p []byte) (n int, err error) {
|
||||
l.Log.Debug(string(p))
|
||||
return len(p), nil
|
||||
}
|
359
plugins/common/opcua/opcua_util.go
Normal file
359
plugins/common/opcua/opcua_util.go
Normal file
|
@ -0,0 +1,359 @@
|
|||
package opcua
|
||||
|
||||
import (
|
||||
"crypto/ecdsa"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"encoding/pem"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"net"
|
||||
"net/url"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gopcua/opcua"
|
||||
"github.com/gopcua/opcua/debug"
|
||||
"github.com/gopcua/opcua/ua"
|
||||
|
||||
"github.com/influxdata/telegraf/config"
|
||||
)
|
||||
|
||||
// SELF SIGNED CERT FUNCTIONS
|
||||
|
||||
// newTempDir creates a fresh temporary directory (pattern "ssc") used to
// hold the generated self-signed certificate and key files. The caller is
// responsible for removing it. The intermediate variables of the original
// were redundant; the call result is returned directly.
func newTempDir() (string, error) {
    return os.MkdirTemp("", "ssc")
}
|
||||
|
||||
func generateCert(host string, rsaBits int, certFile, keyFile string, dur time.Duration) (cert, key string, err error) {
|
||||
dir, err := newTempDir()
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("failed to create certificate: %w", err)
|
||||
}
|
||||
|
||||
if len(host) == 0 {
|
||||
return "", "", errors.New("missing required host parameter")
|
||||
}
|
||||
if rsaBits == 0 {
|
||||
rsaBits = 2048
|
||||
}
|
||||
if len(certFile) == 0 {
|
||||
certFile = dir + "/cert.pem"
|
||||
}
|
||||
if len(keyFile) == 0 {
|
||||
keyFile = dir + "/key.pem"
|
||||
}
|
||||
|
||||
priv, err := rsa.GenerateKey(rand.Reader, rsaBits)
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("failed to generate private key: %w", err)
|
||||
}
|
||||
|
||||
notBefore := time.Now()
|
||||
notAfter := notBefore.Add(dur)
|
||||
|
||||
serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
|
||||
serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("failed to generate serial number: %w", err)
|
||||
}
|
||||
|
||||
template := x509.Certificate{
|
||||
SerialNumber: serialNumber,
|
||||
Subject: pkix.Name{
|
||||
Organization: []string{"Telegraf OPC UA Client"},
|
||||
},
|
||||
NotBefore: notBefore,
|
||||
NotAfter: notAfter,
|
||||
|
||||
KeyUsage: x509.KeyUsageContentCommitment | x509.KeyUsageKeyEncipherment |
|
||||
x509.KeyUsageDigitalSignature | x509.KeyUsageDataEncipherment | x509.KeyUsageCertSign,
|
||||
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth},
|
||||
BasicConstraintsValid: true,
|
||||
}
|
||||
|
||||
hosts := strings.Split(host, ",")
|
||||
for _, h := range hosts {
|
||||
if ip := net.ParseIP(h); ip != nil {
|
||||
template.IPAddresses = append(template.IPAddresses, ip)
|
||||
} else {
|
||||
template.DNSNames = append(template.DNSNames, h)
|
||||
}
|
||||
if uri, err := url.Parse(h); err == nil {
|
||||
template.URIs = append(template.URIs, uri)
|
||||
}
|
||||
}
|
||||
|
||||
derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, publicKey(priv), priv)
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("failed to create certificate: %w", err)
|
||||
}
|
||||
|
||||
certOut, err := os.Create(certFile)
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("failed to open %s for writing: %w", certFile, err)
|
||||
}
|
||||
if err := pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes}); err != nil {
|
||||
return "", "", fmt.Errorf("failed to write data to %s: %w", certFile, err)
|
||||
}
|
||||
if err := certOut.Close(); err != nil {
|
||||
return "", "", fmt.Errorf("error closing %s: %w", certFile, err)
|
||||
}
|
||||
|
||||
keyOut, err := os.OpenFile(keyFile, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600)
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("failed to open %s for writing: %w", keyFile, err)
|
||||
}
|
||||
keyBlock, err := pemBlockForKey(priv)
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("error generating block: %w", err)
|
||||
}
|
||||
if err := pem.Encode(keyOut, keyBlock); err != nil {
|
||||
return "", "", fmt.Errorf("failed to write data to %s: %w", keyFile, err)
|
||||
}
|
||||
if err := keyOut.Close(); err != nil {
|
||||
return "", "", fmt.Errorf("error closing %s: %w", keyFile, err)
|
||||
}
|
||||
|
||||
return certFile, keyFile, nil
|
||||
}
|
||||
|
||||
func publicKey(priv interface{}) interface{} {
|
||||
switch k := priv.(type) {
|
||||
case *rsa.PrivateKey:
|
||||
return &k.PublicKey
|
||||
case *ecdsa.PrivateKey:
|
||||
return &k.PublicKey
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func pemBlockForKey(priv interface{}) (*pem.Block, error) {
|
||||
switch k := priv.(type) {
|
||||
case *rsa.PrivateKey:
|
||||
return &pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(k)}, nil
|
||||
case *ecdsa.PrivateKey:
|
||||
b, err := x509.MarshalECPrivateKey(k)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to marshal ECDSA private key: %w", err)
|
||||
}
|
||||
return &pem.Block{Type: "EC PRIVATE KEY", Bytes: b}, nil
|
||||
default:
|
||||
return nil, nil
|
||||
}
|
||||
}
|
||||
|
||||
// generateClientOpts assembles the opcua.Option list (application info,
// timeouts, certificate, authentication and endpoint security) used to
// create the OPC UA client, selecting the best matching server endpoint
// from the descriptions the server advertised.
func (o *OpcUAClient) generateClientOpts(endpoints []*ua.EndpointDescription) ([]opcua.Option, error) {
	appuri := "urn:telegraf:gopcua:client"
	appname := "Telegraf"

	// ApplicationURI is automatically read from the cert so is not required if a cert if provided
	opts := []opcua.Option{
		opcua.ApplicationURI(appuri),
		opcua.ApplicationName(appname),
		opcua.RequestTimeout(time.Duration(o.Config.RequestTimeout)),
	}

	if o.Config.SessionTimeout != 0 {
		opts = append(opts, opcua.SessionTimeout(time.Duration(o.Config.SessionTimeout)))
	}

	certFile := o.Config.Certificate
	keyFile := o.Config.PrivateKey
	policy := o.Config.SecurityPolicy
	mode := o.Config.SecurityMode
	var err error
	// Generate a self-signed certificate on the fly when security is
	// requested but no certificate/key pair was configured.
	if certFile == "" && keyFile == "" {
		if policy != "None" || mode != "None" {
			certFile, keyFile, err = generateCert(appuri, 2048, certFile, keyFile, 365*24*time.Hour)
			if err != nil {
				return nil, err
			}
		}
	}

	var cert []byte
	if certFile != "" && keyFile != "" {
		debug.Printf("Loading cert/key from %s/%s", certFile, keyFile)
		c, err := tls.LoadX509KeyPair(certFile, keyFile)
		if err != nil {
			// A load failure is only logged; the connection then proceeds
			// without a client certificate.
			o.Log.Warnf("Failed to load certificate: %s", err)
		} else {
			pk, ok := c.PrivateKey.(*rsa.PrivateKey)
			if !ok {
				return nil, errors.New("invalid private key")
			}
			cert = c.Certificate[0]
			opts = append(opts, opcua.PrivateKey(pk), opcua.Certificate(cert))
		}
	}

	// Map the configured (case-sensitive) policy name onto a full
	// security-policy URI. After this switch, policy is "auto" for
	// automatic selection and "" when the user pinned a specific value.
	var secPolicy string
	switch {
	case policy == "auto":
		// set it later
	case strings.HasPrefix(policy, ua.SecurityPolicyURIPrefix):
		secPolicy = policy
		policy = ""
	case policy == "None" || policy == "Basic128Rsa15" || policy == "Basic256" || policy == "Basic256Sha256" ||
		policy == "Aes128_Sha256_RsaOaep" || policy == "Aes256_Sha256_RsaPss":
		secPolicy = ua.SecurityPolicyURIPrefix + policy
		policy = ""
	default:
		return nil, fmt.Errorf("invalid security policy: %s", policy)
	}

	o.Log.Debugf("security policy from configuration %s", secPolicy)

	// Select the most appropriate authentication mode from server capabilities and user input
	authMode, authOption, err := o.generateAuth(o.Config.AuthMethod, cert, o.Config.Username, o.Config.Password)
	if err != nil {
		return nil, err
	}

	opts = append(opts, authOption)

	// Map the configured mode name (case-insensitive) onto the UA enum;
	// as above, mode stays "auto" for automatic selection and becomes ""
	// when pinned.
	var secMode ua.MessageSecurityMode
	switch strings.ToLower(mode) {
	case "auto":
	case "none":
		secMode = ua.MessageSecurityModeNone
		mode = ""
	case "sign":
		secMode = ua.MessageSecurityModeSign
		mode = ""
	case "signandencrypt":
		secMode = ua.MessageSecurityModeSignAndEncrypt
		mode = ""
	default:
		return nil, fmt.Errorf("invalid security mode: %s", mode)
	}

	// Allow input of only one of sec-mode,sec-policy when choosing 'None'
	if secMode == ua.MessageSecurityModeNone || secPolicy == ua.SecurityPolicyURINone {
		secMode = ua.MessageSecurityModeNone
		secPolicy = ua.SecurityPolicyURINone
	}

	// Find the best endpoint based on our input and server recommendation (highest SecurityMode+SecurityLevel)
	var serverEndpoint *ua.EndpointDescription
	switch {
	case mode == "auto" && policy == "auto": // No user selection, choose best
		for _, e := range endpoints {
			if serverEndpoint == nil || (e.SecurityMode >= serverEndpoint.SecurityMode && e.SecurityLevel >= serverEndpoint.SecurityLevel) {
				serverEndpoint = e
			}
		}

	case mode != "auto" && policy == "auto": // User only cares about mode, select highest securitylevel with that mode
		for _, e := range endpoints {
			if e.SecurityMode == secMode && (serverEndpoint == nil || e.SecurityLevel >= serverEndpoint.SecurityLevel) {
				serverEndpoint = e
			}
		}

	case mode == "auto" && policy != "auto": // User only cares about policy, select highest securitylevel with that policy
		for _, e := range endpoints {
			if e.SecurityPolicyURI == secPolicy && (serverEndpoint == nil || e.SecurityLevel >= serverEndpoint.SecurityLevel) {
				serverEndpoint = e
			}
		}

	default: // User cares about both
		o.Log.Debugf("User cares about both the policy (%s) and security mode (%s)", secPolicy, secMode)
		o.Log.Debugf("Server has %d endpoints", len(endpoints))
		for _, e := range endpoints {
			o.Log.Debugf("Evaluating endpoint %s, policy %s, mode %s, level %d", e.EndpointURL, e.SecurityPolicyURI, e.SecurityMode, e.SecurityLevel)
			if e.SecurityPolicyURI == secPolicy && e.SecurityMode == secMode && (serverEndpoint == nil || e.SecurityLevel >= serverEndpoint.SecurityLevel) {
				serverEndpoint = e
				o.Log.Debugf(
					"Security policy and mode found. Using server endpoint %s for security. Policy %s",
					serverEndpoint.EndpointURL,
					serverEndpoint.SecurityPolicyURI,
				)
			}
		}
	}

	if serverEndpoint == nil { // Didn't find an endpoint with matching policy and mode.
		return nil, errors.New("unable to find suitable server endpoint with selected sec-policy and sec-mode")
	}

	// Adopt the chosen endpoint's actual policy/mode for the final check.
	secPolicy = serverEndpoint.SecurityPolicyURI
	secMode = serverEndpoint.SecurityMode

	// Check that the selected endpoint is a valid combo
	err = validateEndpointConfig(endpoints, secPolicy, secMode, authMode)
	if err != nil {
		return nil, fmt.Errorf("error validating input: %w", err)
	}

	opts = append(opts, opcua.SecurityFromEndpoint(serverEndpoint, authMode))
	return opts, nil
}
|
||||
|
||||
// generateAuth maps the configured auth-mode string (case-insensitive)
// onto a UA user-token type and the matching opcua authentication
// option. Username/password secrets are only materialized for the
// "username" mode and destroyed again when this function returns.
// Unknown modes fall back to anonymous authentication with a warning.
func (o *OpcUAClient) generateAuth(a string, cert []byte, user, passwd config.Secret) (ua.UserTokenType, opcua.Option, error) {
	var authMode ua.UserTokenType
	var authOption opcua.Option
	switch strings.ToLower(a) {
	case "anonymous":
		authMode = ua.UserTokenTypeAnonymous
		authOption = opcua.AuthAnonymous()
	case "username":
		authMode = ua.UserTokenTypeUserName

		var username, password []byte
		if !user.Empty() {
			usecret, err := user.Get()
			if err != nil {
				return 0, nil, fmt.Errorf("error reading the username input: %w", err)
			}
			// deferred so the secret stays valid while its bytes are used below
			defer usecret.Destroy()
			username = usecret.Bytes()
		}

		if !passwd.Empty() {
			psecret, err := passwd.Get()
			if err != nil {
				return 0, nil, fmt.Errorf("error reading the password input: %w", err)
			}
			// deferred so the secret stays valid while its bytes are used below
			defer psecret.Destroy()
			password = psecret.Bytes()
		}
		authOption = opcua.AuthUsername(string(username), string(password))
	case "certificate":
		authMode = ua.UserTokenTypeCertificate
		authOption = opcua.AuthCertificate(cert)
	case "issuedtoken":
		// todo: this is unsupported, fail here or fail in the opcua package?
		authMode = ua.UserTokenTypeIssuedToken
		authOption = opcua.AuthIssuedToken([]byte(nil))
	default:
		o.Log.Warnf("unknown auth-mode, defaulting to Anonymous")
		authMode = ua.UserTokenTypeAnonymous
		authOption = opcua.AuthAnonymous()
	}

	return authMode, authOption, nil
}
|
||||
|
||||
func validateEndpointConfig(endpoints []*ua.EndpointDescription, secPolicy string, secMode ua.MessageSecurityMode, authMode ua.UserTokenType) error {
|
||||
for _, e := range endpoints {
|
||||
if e.SecurityMode == secMode && e.SecurityPolicyURI == secPolicy {
|
||||
for _, t := range e.UserIdentityTokens {
|
||||
if t.TokenType == authMode {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return fmt.Errorf("server does not support an endpoint with security: %q, %q", secPolicy, secMode)
|
||||
}
|
84
plugins/common/parallel/ordered.go
Normal file
84
plugins/common/parallel/ordered.go
Normal file
|
@ -0,0 +1,84 @@
|
|||
package parallel
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"github.com/influxdata/telegraf"
|
||||
)
|
||||
|
||||
// Ordered runs a processing function over metrics on a pool of workers
// while preserving the original metric order on output.
type Ordered struct {
	wg sync.WaitGroup
	fn func(telegraf.Metric) []telegraf.Metric

	// queue of jobs coming in. Workers pick jobs off this queue for processing
	workerQueue chan job

	// queue of ordered metrics going out
	queue chan futureMetric
}
|
||||
|
||||
func NewOrdered(acc telegraf.Accumulator, fn func(telegraf.Metric) []telegraf.Metric, orderedQueueSize, workerCount int) *Ordered {
|
||||
p := &Ordered{
|
||||
fn: fn,
|
||||
workerQueue: make(chan job, workerCount),
|
||||
queue: make(chan futureMetric, orderedQueueSize),
|
||||
}
|
||||
p.startWorkers(workerCount)
|
||||
p.wg.Add(1)
|
||||
go func() {
|
||||
p.readQueue(acc)
|
||||
p.wg.Done()
|
||||
}()
|
||||
return p
|
||||
}
|
||||
|
||||
// Enqueue submits a metric for processing. The future is pushed onto the
// ordered output queue *before* the job is handed to the worker pool, so
// results are always read back in submission order.
func (p *Ordered) Enqueue(metric telegraf.Metric) {
	future := make(futureMetric)
	p.queue <- future

	// write the future to the worker pool. Order doesn't matter now because the
	// outgoing p.queue will enforce order regardless of the order the jobs are
	// completed in
	p.workerQueue <- job{
		future: future,
		metric: metric,
	}
}
|
||||
|
||||
func (p *Ordered) readQueue(acc telegraf.Accumulator) {
|
||||
// wait for the response from each worker in order
|
||||
for mCh := range p.queue {
|
||||
// allow each worker to write out multiple metrics
|
||||
for metrics := range mCh {
|
||||
for _, m := range metrics {
|
||||
acc.AddMetric(m)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Ordered) startWorkers(count int) {
|
||||
p.wg.Add(count)
|
||||
for i := 0; i < count; i++ {
|
||||
go func() {
|
||||
for job := range p.workerQueue {
|
||||
job.future <- p.fn(job.metric)
|
||||
close(job.future)
|
||||
}
|
||||
p.wg.Done()
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
// Stop closes both queues and waits for the workers and the reader
// goroutine to finish. Enqueue must not be called after Stop.
func (p *Ordered) Stop() {
	close(p.queue)
	close(p.workerQueue)
	p.wg.Wait()
}
|
||||
|
||||
// futureMetric carries the (possibly multi-metric) result of one job
// from the worker that computed it to the ordered reader; the worker
// closes it when the job is done.
type futureMetric chan []telegraf.Metric

// job pairs an input metric with the future its results must be published on.
type job struct {
	future futureMetric
	metric telegraf.Metric
}
|
8
plugins/common/parallel/parallel.go
Normal file
8
plugins/common/parallel/parallel.go
Normal file
|
@ -0,0 +1,8 @@
|
|||
package parallel
|
||||
|
||||
import "github.com/influxdata/telegraf"
|
||||
|
||||
// Parallel is the common interface of the ordered and unordered metric
// processing pools: Enqueue submits a metric for processing and Stop
// shuts the pool down after draining outstanding work.
type Parallel interface {
	Enqueue(telegraf.Metric)
	Stop()
}
|
117
plugins/common/parallel/parallel_test.go
Normal file
117
plugins/common/parallel/parallel_test.go
Normal file
|
@ -0,0 +1,117 @@
|
|||
package parallel_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/influxdata/telegraf"
|
||||
"github.com/influxdata/telegraf/metric"
|
||||
"github.com/influxdata/telegraf/plugins/common/parallel"
|
||||
"github.com/influxdata/telegraf/testutil"
|
||||
)
|
||||
|
||||
// TestOrderedJobsStayOrdered enqueues 20000 metrics carrying increasing
// values into an Ordered pool and verifies they are accumulated in exact
// submission order despite concurrent workers.
func TestOrderedJobsStayOrdered(t *testing.T) {
	acc := &testutil.Accumulator{}

	p := parallel.NewOrdered(acc, jobFunc, 10000, 10)
	now := time.Now()
	for i := 0; i < 20000; i++ {
		m := metric.New("test",
			map[string]string{},
			map[string]interface{}{
				"val": i,
			},
			now,
		)
		now = now.Add(1)
		p.Enqueue(m)
	}
	p.Stop()

	i := 0
	require.Len(t, acc.Metrics, 20000)
	for _, m := range acc.GetTelegrafMetrics() {
		v, ok := m.GetField("val")
		require.True(t, ok)
		require.EqualValues(t, i, v)
		i++
	}
}
|
||||
|
||||
// TestUnorderedJobsDontDropAnyJobs enqueues 20000 metrics into an
// Unordered pool and verifies, via count and value sum, that every job
// is processed exactly once (order is not guaranteed).
func TestUnorderedJobsDontDropAnyJobs(t *testing.T) {
	acc := &testutil.Accumulator{}

	p := parallel.NewUnordered(acc, jobFunc, 10)

	now := time.Now()

	expectedTotal := 0
	for i := 0; i < 20000; i++ {
		expectedTotal += i
		m := metric.New("test",
			map[string]string{},
			map[string]interface{}{
				"val": i,
			},
			now,
		)
		now = now.Add(1)
		p.Enqueue(m)
	}
	p.Stop()

	actualTotal := int64(0)
	require.Len(t, acc.Metrics, 20000)
	for _, m := range acc.GetTelegrafMetrics() {
		v, ok := m.GetField("val")
		require.True(t, ok)
		actualTotal += v.(int64)
	}
	require.EqualValues(t, expectedTotal, actualTotal)
}
|
||||
|
||||
// BenchmarkOrdered measures the per-metric overhead of the ordered pool
// using the identity job function.
func BenchmarkOrdered(b *testing.B) {
	acc := &testutil.Accumulator{}

	p := parallel.NewOrdered(acc, jobFunc, 10000, 10)

	m := metric.New("test",
		map[string]string{},
		map[string]interface{}{
			"val": 1,
		},
		time.Now(),
	)

	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		p.Enqueue(m)
	}
	p.Stop()
}
|
||||
|
||||
// BenchmarkUnordered measures the per-metric overhead of the unordered
// pool using the identity job function.
func BenchmarkUnordered(b *testing.B) {
	acc := &testutil.Accumulator{}

	p := parallel.NewUnordered(acc, jobFunc, 10)

	m := metric.New("test",
		map[string]string{},
		map[string]interface{}{
			"val": 1,
		},
		time.Now(),
	)

	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		p.Enqueue(m)
	}
	p.Stop()
}
|
||||
|
||||
// jobFunc is the identity processing function used by the tests and
// benchmarks above: every input metric is passed through unchanged.
func jobFunc(m telegraf.Metric) []telegraf.Metric {
	return []telegraf.Metric{m}
}
|
60
plugins/common/parallel/unordered.go
Normal file
60
plugins/common/parallel/unordered.go
Normal file
|
@ -0,0 +1,60 @@
|
|||
package parallel
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"github.com/influxdata/telegraf"
|
||||
)
|
||||
|
||||
// Unordered runs a processing function over metrics on a pool of
// workers, writing the results to the accumulator without preserving the
// input order.
type Unordered struct {
	wg      sync.WaitGroup
	acc     telegraf.Accumulator
	fn      func(telegraf.Metric) []telegraf.Metric
	inQueue chan telegraf.Metric
}
|
||||
|
||||
// NewUnordered starts workerCount workers that apply fn to enqueued
// metrics and write the results straight to acc, without preserving
// input order.
func NewUnordered(
	acc telegraf.Accumulator,
	fn func(telegraf.Metric) []telegraf.Metric,
	workerCount int,
) *Unordered {
	p := &Unordered{
		acc:     acc,
		inQueue: make(chan telegraf.Metric, workerCount),
		fn:      fn,
	}

	// start workers
	// startWorkers itself blocks (via an inner WaitGroup) until all
	// workers have exited, so one entry on the outer WaitGroup covers
	// the whole pool.
	p.wg.Add(1)
	go func() {
		p.startWorkers(workerCount)
		p.wg.Done()
	}()

	return p
}
|
||||
|
||||
// startWorkers runs count workers that process metrics from the input
// queue and add the results to the accumulator. It blocks, via a local
// WaitGroup, until every worker has drained the closed queue.
func (p *Unordered) startWorkers(count int) {
	wg := sync.WaitGroup{}
	wg.Add(count)
	for i := 0; i < count; i++ {
		go func() {
			for metric := range p.inQueue {
				for _, m := range p.fn(metric) {
					p.acc.AddMetric(m)
				}
			}
			wg.Done()
		}()
	}
	wg.Wait()
}
|
||||
|
||||
// Stop closes the input queue and waits for all workers to finish.
// Enqueue must not be called after Stop.
func (p *Unordered) Stop() {
	close(p.inQueue)
	p.wg.Wait()
}
|
||||
|
||||
// Enqueue submits a metric for processing; it blocks when all workers
// are busy and the input buffer is full.
func (p *Unordered) Enqueue(m telegraf.Metric) {
	p.inQueue <- m
}
|
184
plugins/common/postgresql/config.go
Normal file
184
plugins/common/postgresql/config.go
Normal file
|
@ -0,0 +1,184 @@
|
|||
package postgresql
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/url"
|
||||
"regexp"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/v4"
|
||||
"github.com/jackc/pgx/v4/stdlib"
|
||||
|
||||
"github.com/influxdata/telegraf/config"
|
||||
)
|
||||
|
||||
// socketRegexp matches the PostgreSQL unix-socket file name (e.g.
// "/.s.PGSQL.5432") so it can be stripped from a parsed host path.
var socketRegexp = regexp.MustCompile(`/\.s\.PGSQL\.\d+$`)

// sanitizer matches sensitive key=value pairs (password and SSL
// settings), including quoted values with escapes, so they can be
// removed from a key/value connection string.
var sanitizer = regexp.MustCompile(`(\s|^)((?:password|sslcert|sslkey|sslmode|sslrootcert)\s?=\s?(?:(?:'(?:[^'\\]|\\.)*')|(?:\S+)))`)
|
||||
|
||||
// Config is the connection configuration shared by the postgresql-based
// plugins: the server address plus connection-pool limits.
type Config struct {
	Address       config.Secret   `toml:"address"`       // URI or key/value connection string
	OutputAddress string          `toml:"outputaddress"` // overrides the sanitized address used for tags
	MaxIdle       int             `toml:"max_idle"`
	MaxOpen       int             `toml:"max_open"`
	MaxLifetime   config.Duration `toml:"max_lifetime"`
	IsPgBouncer   bool            `toml:"-"` // set by the plugin, enables PgBouncer compatibility
}
|
||||
|
||||
// CreateService builds a postgresql Service from the configuration: it
// normalizes the address (defaulting to a local connection), parses it
// via pgx, registers the resulting connection config as a DSN for
// database/sql, and pre-computes the sanitized address used for tags.
func (c *Config) CreateService() (*Service, error) {
	addrSecret, err := c.Address.Get()
	if err != nil {
		return nil, fmt.Errorf("getting address failed: %w", err)
	}
	addr := addrSecret.String()
	defer addrSecret.Destroy()

	// An empty or bare "localhost" address is replaced by a minimal
	// key/value DSN; the stored secret is updated so later reads agree.
	if c.Address.Empty() || addr == "localhost" {
		addr = "host=localhost sslmode=disable"
		if err := c.Address.Set([]byte(addr)); err != nil {
			return nil, err
		}
	}

	connConfig, err := pgx.ParseConfig(addr)
	if err != nil {
		return nil, err
	}
	// Remove the socket name from the path
	connConfig.Host = socketRegexp.ReplaceAllLiteralString(connConfig.Host, "")

	// Specific support to make it work with PgBouncer too
	// See https://github.com/influxdata/telegraf/issues/3253#issuecomment-357505343
	if c.IsPgBouncer {
		// Remove DriveConfig and revert it by the ParseConfig method
		// See https://github.com/influxdata/telegraf/issues/9134
		connConfig.PreferSimpleProtocol = true
	}

	// Provide the connection string without sensitive information for use as
	// tag or other output properties
	sanitizedAddr, err := c.sanitizedAddress()
	if err != nil {
		return nil, err
	}

	return &Service{
		SanitizedAddress:   sanitizedAddr,
		ConnectionDatabase: connectionDatabase(sanitizedAddr),
		maxIdle:            c.MaxIdle,
		maxOpen:            c.MaxOpen,
		maxLifetime:        time.Duration(c.MaxLifetime),
		dsn:                stdlib.RegisterConnConfig(connConfig),
	}, nil
}
|
||||
|
||||
// connectionDatabase determines the database to which the connection was made
|
||||
func connectionDatabase(sanitizedAddr string) string {
|
||||
connConfig, err := pgx.ParseConfig(sanitizedAddr)
|
||||
if err != nil || connConfig.Database == "" {
|
||||
return "postgres"
|
||||
}
|
||||
|
||||
return connConfig.Database
|
||||
}
|
||||
|
||||
// sanitizedAddress strips sensitive information from the connection string.
// If the user set the output address use that before parsing anything else.
// URI-style addresses are first converted into key/value form so the
// sanitizer regular expression (password and SSL settings) applies.
func (c *Config) sanitizedAddress() (string, error) {
	if c.OutputAddress != "" {
		return c.OutputAddress, nil
	}

	// Get the address
	addrSecret, err := c.Address.Get()
	if err != nil {
		return "", fmt.Errorf("getting address for sanitization failed: %w", err)
	}
	defer addrSecret.Destroy()

	// Make sure we convert URI-formatted strings into key-values
	addr := addrSecret.TemporaryString()
	if strings.HasPrefix(addr, "postgres://") || strings.HasPrefix(addr, "postgresql://") {
		if addr, err = toKeyValue(addr); err != nil {
			return "", err
		}
	}

	// Sanitize the string using a regular expression
	sanitized := sanitizer.ReplaceAllString(addr, "")
	return strings.TrimSpace(sanitized), nil
}
|
||||
|
||||
// toKeyValue converts a postgres:// or postgresql:// URI into the
// equivalent space-separated key=value connection string. The resulting
// parameters are sorted so the output is deterministic.
// Based on parseURLSettings() at https://github.com/jackc/pgx/blob/master/pgconn/config.go
func toKeyValue(uri string) (string, error) {
	parsed, err := url.Parse(uri)
	if err != nil {
		return "", fmt.Errorf("parsing URI failed: %w", err)
	}

	// Only PostgreSQL schemes are accepted.
	switch parsed.Scheme {
	case "postgres", "postgresql":
	default:
		return "", fmt.Errorf("invalid connection protocol: %s", parsed.Scheme)
	}

	// quote wraps a value in single quotes when it contains characters
	// that are special in a key/value DSN, escaping backslashes and
	// single quotes.
	quote := func(v string) string {
		if !strings.ContainsAny(v, ` ='\`) {
			return v
		}
		escaped := strings.ReplaceAll(v, `\`, `\\`)
		escaped = strings.ReplaceAll(escaped, `'`, `\'`)
		return "'" + escaped + "'"
	}

	settings := make([]string, 0, len(parsed.Query())+5)
	if parsed.User != nil {
		settings = append(settings, "user="+quote(parsed.User.Username()))
		if pw, found := parsed.User.Password(); found {
			settings = append(settings, "password="+quote(pw))
		}
	}

	// Handle multiple host:port's in url.Host by splitting them into
	// host=h1,h2,h3 and port=p1,p2,p3 lists.
	hostParts := strings.Split(parsed.Host, ",")
	hosts := make([]string, 0, len(hostParts))
	ports := make([]string, 0, len(hostParts))
	var havePort bool
	for _, hostport := range hostParts {
		if hostport == "" {
			continue
		}

		h, p, err := net.SplitHostPort(hostport)
		if err != nil {
			// A host without a port is fine; anything else is fatal.
			if !strings.Contains(err.Error(), "missing port") {
				return "", fmt.Errorf("failed to process host %q: %w", hostport, err)
			}
			h = hostport
		}
		havePort = havePort || err == nil
		hosts = append(hosts, h)
		ports = append(ports, p)
	}
	if len(hosts) > 0 {
		settings = append(settings, "host="+strings.Join(hosts, ","))
	}
	if havePort {
		settings = append(settings, "port="+strings.Join(ports, ","))
	}

	if dbname := strings.TrimLeft(parsed.Path, "/"); dbname != "" {
		settings = append(settings, "dbname="+quote(dbname))
	}

	for key, values := range parsed.Query() {
		settings = append(settings, key+"="+quote(strings.Join(values, ",")))
	}

	// Required to produce a repeatable output e.g. for tags or testing
	sort.Strings(settings)
	return strings.Join(settings, " "), nil
}
|
240
plugins/common/postgresql/config_test.go
Normal file
240
plugins/common/postgresql/config_test.go
Normal file
|
@ -0,0 +1,240 @@
|
|||
package postgresql
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/influxdata/telegraf/config"
|
||||
)
|
||||
|
||||
// TestURIParsing verifies that toKeyValue converts postgres:// URIs into
// the expected sorted key=value connection strings, covering ports,
// databases, query parameters, credentials, percent-encoded characters
// and multi-host authorities.
func TestURIParsing(t *testing.T) {
	tests := []struct {
		name     string
		uri      string
		expected string
	}{
		{
			name:     "short",
			uri:      `postgres://localhost`,
			expected: "host=localhost",
		},
		{
			name:     "with port",
			uri:      `postgres://localhost:5432`,
			expected: "host=localhost port=5432",
		},
		{
			name:     "with database",
			uri:      `postgres://localhost/mydb`,
			expected: "dbname=mydb host=localhost",
		},
		{
			name:     "with additional parameters",
			uri:      `postgres://localhost/mydb?application_name=pgxtest&search_path=myschema&connect_timeout=5`,
			expected: "application_name=pgxtest connect_timeout=5 dbname=mydb host=localhost search_path=myschema",
		},
		{
			name:     "with database setting in params",
			uri:      `postgres://localhost:5432/?database=mydb`,
			expected: "database=mydb host=localhost port=5432",
		},
		{
			name:     "with authentication",
			uri:      `postgres://jack:secret@localhost:5432/mydb?sslmode=prefer`,
			expected: "dbname=mydb host=localhost password=secret port=5432 sslmode=prefer user=jack",
		},
		{
			name:     "with spaces",
			uri:      `postgres://jack%20hunter:secret@localhost/mydb?application_name=pgx%20test`,
			expected: "application_name='pgx test' dbname=mydb host=localhost password=secret user='jack hunter'",
		},
		{
			name:     "with equal signs",
			uri:      `postgres://jack%20hunter:secret@localhost/mydb?application_name=pgx%3Dtest`,
			expected: "application_name='pgx=test' dbname=mydb host=localhost password=secret user='jack hunter'",
		},
		{
			name:     "multiple hosts",
			uri:      `postgres://jack:secret@foo:1,bar:2,baz:3/mydb?sslmode=disable`,
			expected: "dbname=mydb host=foo,bar,baz password=secret port=1,2,3 sslmode=disable user=jack",
		},
		{
			name:     "multiple hosts without ports",
			uri:      `postgres://jack:secret@foo,bar,baz/mydb?sslmode=disable`,
			expected: "dbname=mydb host=foo,bar,baz password=secret sslmode=disable user=jack",
		},
	}

	for _, tt := range tests {
		// Key value without spaces around equal sign
		t.Run(tt.name, func(t *testing.T) {
			actual, err := toKeyValue(tt.uri)
			require.NoError(t, err)
			require.Equalf(t, tt.expected, actual, "initial: %s", tt.uri)
		})
	}
}
|
||||
|
||||
// TestSanitizeAddressKeyValue verifies that sanitizedAddress removes all
// sensitive key=value settings (password and SSL parameters) from
// key/value DSNs while keeping the surrounding "canary" settings, for a
// range of quoting/escaping variants and with or without spaces around
// the equal sign.
func TestSanitizeAddressKeyValue(t *testing.T) {
	keys := []string{"password", "sslcert", "sslkey", "sslmode", "sslrootcert"}
	tests := []struct {
		name  string
		value string
	}{
		{
			name:  "simple text",
			value: `foo`,
		},
		{
			name:  "empty values",
			value: `''`,
		},
		{
			name:  "space in value",
			value: `'foo bar'`,
		},
		{
			name:  "equal sign in value",
			value: `'foo=bar'`,
		},
		{
			name:  "escaped quote",
			value: `'foo\'s bar'`,
		},
		{
			name:  "escaped quote no space",
			value: `\'foobar\'s\'`,
		},
		{
			name:  "escaped backslash",
			value: `'foo bar\\'`,
		},
		{
			name:  "escaped quote and backslash",
			value: `'foo\\\'s bar'`,
		},
		{
			name:  "two escaped backslashes",
			value: `'foo bar\\\\'`,
		},
		{
			name:  "multiple inline spaces",
			value: "'foo \t bar'",
		},
		{
			name:  "leading space",
			value: `' foo bar'`,
		},
		{
			name:  "trailing space",
			value: `'foo bar '`,
		},
		{
			name:  "multiple equal signs",
			value: `'foo===bar'`,
		},
		{
			name:  "leading equal sign",
			value: `'=foo bar'`,
		},
		{
			name:  "trailing equal sign",
			value: `'foo bar='`,
		},
		{
			name:  "mix of equal signs and spaces",
			value: "'foo = a\t===\tbar'",
		},
	}

	for _, tt := range tests {
		// Key value without spaces around equal sign
		t.Run(tt.name, func(t *testing.T) {
			// Generate the DSN from the given keys and value
			parts := make([]string, 0, len(keys))
			for _, k := range keys {
				parts = append(parts, k+"="+tt.value)
			}
			dsn := strings.Join(parts, " canary=ok ")

			cfg := &Config{
				Address: config.NewSecret([]byte(dsn)),
			}

			expected := strings.Join(make([]string, len(keys)), "canary=ok ")
			expected = strings.TrimSpace(expected)
			actual, err := cfg.sanitizedAddress()
			require.NoError(t, err)
			require.Equalf(t, expected, actual, "initial: %s", dsn)
		})

		// Key value with spaces around equal sign
		t.Run("spaced "+tt.name, func(t *testing.T) {
			// Generate the DSN from the given keys and value
			parts := make([]string, 0, len(keys))
			for _, k := range keys {
				parts = append(parts, k+" = "+tt.value)
			}
			dsn := strings.Join(parts, " canary=ok ")

			cfg := &Config{
				Address: config.NewSecret([]byte(dsn)),
			}

			expected := strings.Join(make([]string, len(keys)), "canary=ok ")
			expected = strings.TrimSpace(expected)
			actual, err := cfg.sanitizedAddress()
			require.NoError(t, err)
			require.Equalf(t, expected, actual, "initial: %s", dsn)
		})
	}
}
|
||||
|
||||
// TestSanitizeAddressURI verifies that sanitizedAddress converts URI
// addresses to key/value form and removes all sensitive settings
// (password and SSL parameters) passed as percent-encoded query values.
func TestSanitizeAddressURI(t *testing.T) {
	keys := []string{"password", "sslcert", "sslkey", "sslmode", "sslrootcert"}
	tests := []struct {
		name  string
		value string
	}{
		{
			name:  "simple text",
			value: `foo`,
		},
		{
			name:  "empty values",
			value: ``,
		},
		{
			name:  "space in value",
			value: `foo bar`,
		},
		{
			name:  "equal sign in value",
			value: `foo=bar`,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// Generate the DSN from the given keys and value
			value := strings.ReplaceAll(tt.value, "=", "%3D")
			value = strings.ReplaceAll(value, " ", "%20")
			parts := make([]string, 0, len(keys))
			for _, k := range keys {
				parts = append(parts, k+"="+value)
			}
			dsn := "postgresql://user:passwd@localhost:5432/db?" + strings.Join(parts, "&")

			cfg := &Config{
				Address: config.NewSecret([]byte(dsn)),
			}

			expected := "dbname=db host=localhost port=5432 user=user"
			actual, err := cfg.sanitizedAddress()
			require.NoError(t, err)
			require.Equalf(t, expected, actual, "initial: %s", dsn)
		})
	}
}
|
42
plugins/common/postgresql/service.go
Normal file
42
plugins/common/postgresql/service.go
Normal file
|
@ -0,0 +1,42 @@
|
|||
package postgresql
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"time"
|
||||
|
||||
// Blank import required to register driver
|
||||
_ "github.com/jackc/pgx/v4/stdlib"
|
||||
)
|
||||
|
||||
// Service common functionality shared between the postgresql and postgresql_extensible
// packages.
type Service struct {
	// DB is the open database handle; populated by Start.
	DB *sql.DB
	// SanitizedAddress is the connection address with credentials removed,
	// suitable for use in tags and logs.
	SanitizedAddress string
	// ConnectionDatabase is the database name the connection targets.
	ConnectionDatabase string

	// dsn is the raw connection string passed to sql.Open.
	dsn string
	// maxIdle, maxOpen and maxLifetime configure the connection pool.
	maxIdle     int
	maxOpen     int
	maxLifetime time.Duration
}
|
||||
|
||||
func (p *Service) Start() error {
|
||||
db, err := sql.Open("pgx", p.dsn)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
p.DB = db
|
||||
|
||||
p.DB.SetMaxOpenConns(p.maxOpen)
|
||||
p.DB.SetMaxIdleConns(p.maxIdle)
|
||||
p.DB.SetConnMaxLifetime(p.maxLifetime)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *Service) Stop() {
|
||||
if p.DB != nil {
|
||||
p.DB.Close()
|
||||
}
|
||||
}
|
140
plugins/common/proxy/connect.go
Normal file
140
plugins/common/proxy/connect.go
Normal file
|
@ -0,0 +1,140 @@
|
|||
package proxy
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
|
||||
"golang.org/x/net/proxy"
|
||||
)
|
||||
|
||||
// httpConnectProxy proxies (only?) TCP over a HTTP tunnel using the CONNECT method
type httpConnectProxy struct {
	forward proxy.Dialer // dialer used to reach the proxy server itself
	url     *url.URL     // proxy address, optionally carrying user credentials
}
|
||||
|
||||
func (c *httpConnectProxy) DialContext(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
// Prevent using UDP
|
||||
if network == "udp" {
|
||||
return nil, fmt.Errorf("cannot proxy %q traffic over HTTP CONNECT", network)
|
||||
}
|
||||
|
||||
var proxyConn net.Conn
|
||||
var err error
|
||||
if dialer, ok := c.forward.(proxy.ContextDialer); ok {
|
||||
proxyConn, err = dialer.DialContext(ctx, "tcp", c.url.Host)
|
||||
} else {
|
||||
shim := contextDialerShim{c.forward}
|
||||
proxyConn, err = shim.DialContext(ctx, "tcp", c.url.Host)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Add and strip http:// to extract authority portion of the URL
|
||||
// since CONNECT doesn't use a full URL. The request header would
|
||||
// look something like: "CONNECT www.influxdata.com:443 HTTP/1.1"
|
||||
requestURL, err := url.Parse("http://" + addr)
|
||||
if err != nil {
|
||||
if err := proxyConn.Close(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
requestURL.Scheme = ""
|
||||
|
||||
// Build HTTP CONNECT request
|
||||
req, err := http.NewRequest(http.MethodConnect, requestURL.String(), nil)
|
||||
if err != nil {
|
||||
if err := proxyConn.Close(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
req.Close = false
|
||||
if password, hasAuth := c.url.User.Password(); hasAuth {
|
||||
req.SetBasicAuth(c.url.User.Username(), password)
|
||||
}
|
||||
|
||||
err = req.Write(proxyConn)
|
||||
if err != nil {
|
||||
if err := proxyConn.Close(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resp, err := http.ReadResponse(bufio.NewReader(proxyConn), req)
|
||||
if err != nil {
|
||||
if err := proxyConn.Close(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
if err := resp.Body.Close(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if resp.StatusCode != 200 {
|
||||
if err := proxyConn.Close(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return nil, fmt.Errorf("failed to connect to proxy: %q", resp.Status)
|
||||
}
|
||||
|
||||
return proxyConn, nil
|
||||
}
|
||||
|
||||
func (c *httpConnectProxy) Dial(network, addr string) (net.Conn, error) {
|
||||
return c.DialContext(context.Background(), network, addr)
|
||||
}
|
||||
|
||||
func newHTTPConnectProxy(proxyURL *url.URL, forward proxy.Dialer) (proxy.Dialer, error) {
|
||||
return &httpConnectProxy{forward, proxyURL}, nil
|
||||
}
|
||||
|
||||
// init hooks the HTTP CONNECT dialer into golang.org/x/net/proxy so that
// proxy.FromURL recognizes http:// and https:// proxy URLs.
func init() {
	// Register new proxy types
	proxy.RegisterDialerType("http", newHTTPConnectProxy)
	proxy.RegisterDialerType("https", newHTTPConnectProxy)
}
|
||||
|
||||
// contextDialerShim allows cancellation of the dial from a context even if the underlying
// dialer does not implement `proxy.ContextDialer`. Arguably, this shouldn't actually get run,
// unless a new proxy type is added that doesn't implement `proxy.ContextDialer`, as all the
// standard library dialers implement `proxy.ContextDialer`.
type contextDialerShim struct {
	dialer proxy.Dialer
}

// Dial forwards directly to the wrapped dialer without cancellation support.
func (cd *contextDialerShim) Dial(network, addr string) (net.Conn, error) {
	return cd.dialer.Dial(network, addr)
}

// DialContext runs the blocking Dial in a goroutine so the caller can return
// as soon as ctx is done. A connection that completes after cancellation is
// closed by the goroutine to avoid leaking it.
func (cd *contextDialerShim) DialContext(ctx context.Context, network, addr string) (net.Conn, error) {
	var (
		conn net.Conn
		done = make(chan struct{}, 1)
		err  error
	)

	go func() {
		conn, err = cd.dialer.Dial(network, addr)
		close(done)
		// Dial finished after the context was cancelled: drop the connection
		if conn != nil && ctx.Err() != nil {
			_ = conn.Close()
		}
	}()

	select {
	case <-ctx.Done():
		err = ctx.Err()
	case <-done:
	}

	return conn, err
}
|
37
plugins/common/proxy/dialer.go
Normal file
37
plugins/common/proxy/dialer.go
Normal file
|
@ -0,0 +1,37 @@
|
|||
package proxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"golang.org/x/net/proxy"
|
||||
)
|
||||
|
||||
type ProxiedDialer struct {
|
||||
dialer proxy.Dialer
|
||||
}
|
||||
|
||||
func (pd *ProxiedDialer) Dial(network, addr string) (net.Conn, error) {
|
||||
return pd.dialer.Dial(network, addr)
|
||||
}
|
||||
|
||||
func (pd *ProxiedDialer) DialContext(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
if contextDialer, ok := pd.dialer.(proxy.ContextDialer); ok {
|
||||
return contextDialer.DialContext(ctx, network, addr)
|
||||
}
|
||||
|
||||
contextDialer := contextDialerShim{pd.dialer}
|
||||
return contextDialer.DialContext(ctx, network, addr)
|
||||
}
|
||||
|
||||
func (pd *ProxiedDialer) DialTimeout(network, addr string, timeout time.Duration) (net.Conn, error) {
|
||||
ctx := context.Background()
|
||||
if timeout.Seconds() != 0 {
|
||||
var cancel context.CancelFunc
|
||||
ctx, cancel = context.WithTimeout(ctx, timeout)
|
||||
defer cancel()
|
||||
}
|
||||
|
||||
return pd.DialContext(ctx, network, addr)
|
||||
}
|
57
plugins/common/proxy/proxy.go
Normal file
57
plugins/common/proxy/proxy.go
Normal file
|
@ -0,0 +1,57 @@
|
|||
package proxy
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
|
||||
"golang.org/x/net/proxy"
|
||||
)
|
||||
|
||||
// HTTPProxy is a shared configuration snippet for plugins that make outbound
// HTTP requests through a proxy.
type HTTPProxy struct {
	UseSystemProxy bool   `toml:"use_system_proxy"`
	HTTPProxyURL   string `toml:"http_proxy_url"`
}

type proxyFunc func(req *http.Request) (*url.URL, error)

// Proxy returns the proxy-selection function matching the configuration, or
// nil when no proxy is configured.
func (p *HTTPProxy) Proxy() (proxyFunc, error) {
	if p.UseSystemProxy {
		return http.ProxyFromEnvironment, nil
	}

	if p.HTTPProxyURL == "" {
		return nil, nil
	}

	address, err := url.Parse(p.HTTPProxyURL)
	if err != nil {
		return nil, fmt.Errorf("error parsing proxy url %q: %w", p.HTTPProxyURL, err)
	}
	return http.ProxyURL(address), nil
}
|
||||
|
||||
type TCPProxy struct {
|
||||
UseProxy bool `toml:"use_proxy"`
|
||||
ProxyURL string `toml:"proxy_url"`
|
||||
}
|
||||
|
||||
func (p *TCPProxy) Proxy() (*ProxiedDialer, error) {
|
||||
var dialer proxy.Dialer
|
||||
if p.UseProxy {
|
||||
if len(p.ProxyURL) > 0 {
|
||||
parsed, err := url.Parse(p.ProxyURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing proxy url %q: %w", p.ProxyURL, err)
|
||||
}
|
||||
|
||||
if dialer, err = proxy.FromURL(parsed, proxy.Direct); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
dialer = proxy.FromEnvironment()
|
||||
}
|
||||
} else {
|
||||
dialer = proxy.Direct
|
||||
}
|
||||
|
||||
return &ProxiedDialer{dialer}, nil
|
||||
}
|
22
plugins/common/proxy/socks5.go
Normal file
22
plugins/common/proxy/socks5.go
Normal file
|
@ -0,0 +1,22 @@
|
|||
package proxy
|
||||
|
||||
import (
|
||||
"golang.org/x/net/proxy"
|
||||
)
|
||||
|
||||
type Socks5ProxyConfig struct {
|
||||
Socks5ProxyEnabled bool `toml:"socks5_enabled"`
|
||||
Socks5ProxyAddress string `toml:"socks5_address"`
|
||||
Socks5ProxyUsername string `toml:"socks5_username"`
|
||||
Socks5ProxyPassword string `toml:"socks5_password"`
|
||||
}
|
||||
|
||||
func (c *Socks5ProxyConfig) GetDialer() (proxy.Dialer, error) {
|
||||
var auth *proxy.Auth
|
||||
if c.Socks5ProxyPassword != "" || c.Socks5ProxyUsername != "" {
|
||||
auth = new(proxy.Auth)
|
||||
auth.User = c.Socks5ProxyUsername
|
||||
auth.Password = c.Socks5ProxyPassword
|
||||
}
|
||||
return proxy.SOCKS5("tcp", c.Socks5ProxyAddress, auth, proxy.Direct)
|
||||
}
|
74
plugins/common/proxy/socks5_test.go
Normal file
74
plugins/common/proxy/socks5_test.go
Normal file
|
@ -0,0 +1,74 @@
|
|||
package proxy
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/armon/go-socks5"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestSocks5ProxyConfigIntegration(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping integration test in short mode")
|
||||
}
|
||||
|
||||
const (
|
||||
proxyAddress = "127.0.0.1:12345"
|
||||
proxyUsername = "user"
|
||||
proxyPassword = "password"
|
||||
)
|
||||
|
||||
l, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
|
||||
server, err := socks5.New(&socks5.Config{
|
||||
AuthMethods: []socks5.Authenticator{socks5.UserPassAuthenticator{
|
||||
Credentials: socks5.StaticCredentials{
|
||||
proxyUsername: proxyPassword,
|
||||
},
|
||||
}},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
go func() {
|
||||
if err := server.ListenAndServe("tcp", proxyAddress); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}()
|
||||
|
||||
conf := Socks5ProxyConfig{
|
||||
Socks5ProxyEnabled: true,
|
||||
Socks5ProxyAddress: proxyAddress,
|
||||
Socks5ProxyUsername: proxyUsername,
|
||||
Socks5ProxyPassword: proxyPassword,
|
||||
}
|
||||
dialer, err := conf.GetDialer()
|
||||
require.NoError(t, err)
|
||||
|
||||
var proxyConn net.Conn
|
||||
for i := 0; i < 10; i++ {
|
||||
proxyConn, err = dialer.Dial("tcp", l.Addr().String())
|
||||
if err == nil {
|
||||
break
|
||||
}
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
}
|
||||
require.NotNil(t, proxyConn)
|
||||
defer func() { require.NoError(t, proxyConn.Close()) }()
|
||||
|
||||
serverConn, err := l.Accept()
|
||||
require.NoError(t, err)
|
||||
defer func() { require.NoError(t, serverConn.Close()) }()
|
||||
|
||||
writePayload := []byte("test")
|
||||
_, err = proxyConn.Write(writePayload)
|
||||
require.NoError(t, err)
|
||||
|
||||
receivePayload := make([]byte, 4)
|
||||
_, err = serverConn.Read(receivePayload)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, writePayload, receivePayload)
|
||||
}
|
155
plugins/common/psutil/mock_ps.go
Normal file
155
plugins/common/psutil/mock_ps.go
Normal file
|
@ -0,0 +1,155 @@
|
|||
package psutil
|
||||
|
||||
import (
|
||||
"os"
|
||||
|
||||
"github.com/shirou/gopsutil/v4/cpu"
|
||||
"github.com/shirou/gopsutil/v4/disk"
|
||||
"github.com/shirou/gopsutil/v4/mem"
|
||||
"github.com/shirou/gopsutil/v4/net"
|
||||
"github.com/stretchr/testify/mock"
|
||||
)
|
||||
|
||||
// MockPS is a mock implementation of the PS interface for testing purposes.
|
||||
type MockPS struct {
|
||||
mock.Mock
|
||||
PSDiskDeps
|
||||
}
|
||||
|
||||
// MockPSDisk is a mock implementation of the PSDiskDeps interface for testing purposes.
|
||||
type MockPSDisk struct {
|
||||
*SystemPS
|
||||
*mock.Mock
|
||||
}
|
||||
|
||||
// MockDiskUsage is a mock implementation for disk usage operations.
|
||||
type MockDiskUsage struct {
|
||||
*mock.Mock
|
||||
}
|
||||
|
||||
// CPUTimes returns the CPU times statistics.
|
||||
func (m *MockPS) CPUTimes(_, _ bool) ([]cpu.TimesStat, error) {
|
||||
ret := m.Called()
|
||||
|
||||
r0 := ret.Get(0).([]cpu.TimesStat)
|
||||
r1 := ret.Error(1)
|
||||
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// DiskUsage returns the disk usage statistics.
|
||||
func (m *MockPS) DiskUsage(mountPointFilter, mountOptsExclude, fstypeExclude []string) ([]*disk.UsageStat, []*disk.PartitionStat, error) {
|
||||
ret := m.Called(mountPointFilter, mountOptsExclude, fstypeExclude)
|
||||
|
||||
r0 := ret.Get(0).([]*disk.UsageStat)
|
||||
r1 := ret.Get(1).([]*disk.PartitionStat)
|
||||
r2 := ret.Error(2)
|
||||
|
||||
return r0, r1, r2
|
||||
}
|
||||
|
||||
// NetIO returns network I/O statistics for every network interface installed on the system.
|
||||
func (m *MockPS) NetIO() ([]net.IOCountersStat, error) {
|
||||
ret := m.Called()
|
||||
|
||||
r0 := ret.Get(0).([]net.IOCountersStat)
|
||||
r1 := ret.Error(1)
|
||||
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// NetProto returns network statistics for the entire system.
|
||||
func (m *MockPS) NetProto() ([]net.ProtoCountersStat, error) {
|
||||
ret := m.Called()
|
||||
|
||||
r0 := ret.Get(0).([]net.ProtoCountersStat)
|
||||
r1 := ret.Error(1)
|
||||
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// DiskIO returns the disk I/O statistics.
|
||||
func (m *MockPS) DiskIO(_ []string) (map[string]disk.IOCountersStat, error) {
|
||||
ret := m.Called()
|
||||
|
||||
r0 := ret.Get(0).(map[string]disk.IOCountersStat)
|
||||
r1 := ret.Error(1)
|
||||
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// VMStat returns the virtual memory statistics.
|
||||
func (m *MockPS) VMStat() (*mem.VirtualMemoryStat, error) {
|
||||
ret := m.Called()
|
||||
|
||||
r0 := ret.Get(0).(*mem.VirtualMemoryStat)
|
||||
r1 := ret.Error(1)
|
||||
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// SwapStat returns the swap memory statistics.
|
||||
func (m *MockPS) SwapStat() (*mem.SwapMemoryStat, error) {
|
||||
ret := m.Called()
|
||||
|
||||
r0 := ret.Get(0).(*mem.SwapMemoryStat)
|
||||
r1 := ret.Error(1)
|
||||
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// NetConnections returns a list of network connections opened.
|
||||
func (m *MockPS) NetConnections() ([]net.ConnectionStat, error) {
|
||||
ret := m.Called()
|
||||
|
||||
r0 := ret.Get(0).([]net.ConnectionStat)
|
||||
r1 := ret.Error(1)
|
||||
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// NetConntrack returns more detailed info about the conntrack table.
|
||||
func (m *MockPS) NetConntrack(perCPU bool) ([]net.ConntrackStat, error) {
|
||||
ret := m.Called(perCPU)
|
||||
|
||||
r0 := ret.Get(0).([]net.ConntrackStat)
|
||||
r1 := ret.Error(1)
|
||||
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// Partitions returns the disk partition statistics.
|
||||
func (m *MockDiskUsage) Partitions(all bool) ([]disk.PartitionStat, error) {
|
||||
ret := m.Called(all)
|
||||
|
||||
r0 := ret.Get(0).([]disk.PartitionStat)
|
||||
r1 := ret.Error(1)
|
||||
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// OSGetenv returns the value of the environment variable named by the key.
|
||||
func (m *MockDiskUsage) OSGetenv(key string) string {
|
||||
ret := m.Called(key)
|
||||
return ret.Get(0).(string)
|
||||
}
|
||||
|
||||
// OSStat returns the FileInfo structure describing the named file.
|
||||
func (m *MockDiskUsage) OSStat(name string) (os.FileInfo, error) {
|
||||
ret := m.Called(name)
|
||||
|
||||
r0 := ret.Get(0).(os.FileInfo)
|
||||
r1 := ret.Error(1)
|
||||
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// PSDiskUsage returns a file system usage for the specified path.
|
||||
func (m *MockDiskUsage) PSDiskUsage(path string) (*disk.UsageStat, error) {
|
||||
ret := m.Called(path)
|
||||
|
||||
r0 := ret.Get(0).(*disk.UsageStat)
|
||||
r1 := ret.Error(1)
|
||||
|
||||
return r0, r1
|
||||
}
|
256
plugins/common/psutil/ps.go
Normal file
256
plugins/common/psutil/ps.go
Normal file
|
@ -0,0 +1,256 @@
|
|||
package psutil
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/shirou/gopsutil/v4/cpu"
|
||||
"github.com/shirou/gopsutil/v4/disk"
|
||||
"github.com/shirou/gopsutil/v4/mem"
|
||||
"github.com/shirou/gopsutil/v4/net"
|
||||
|
||||
"github.com/influxdata/telegraf"
|
||||
"github.com/influxdata/telegraf/internal"
|
||||
)
|
||||
|
||||
// PS is an interface that defines methods for gathering system statistics.
type PS interface {
	// CPUTimes returns the CPU times statistics.
	CPUTimes(perCPU, totalCPU bool) ([]cpu.TimesStat, error)
	// DiskUsage returns the disk usage statistics.
	DiskUsage(mountPointFilter []string, mountOptsExclude []string, fstypeExclude []string) ([]*disk.UsageStat, []*disk.PartitionStat, error)
	// NetIO returns network I/O statistics for every network interface installed on the system.
	NetIO() ([]net.IOCountersStat, error)
	// NetProto returns network statistics for the entire system.
	NetProto() ([]net.ProtoCountersStat, error)
	// DiskIO returns the disk I/O statistics.
	DiskIO(names []string) (map[string]disk.IOCountersStat, error)
	// VMStat returns the virtual memory statistics.
	VMStat() (*mem.VirtualMemoryStat, error)
	// SwapStat returns the swap memory statistics.
	SwapStat() (*mem.SwapMemoryStat, error)
	// NetConnections returns a list of network connections opened.
	NetConnections() ([]net.ConnectionStat, error)
	// NetConntrack returns more detailed info about the conntrack table.
	NetConntrack(perCPU bool) ([]net.ConntrackStat, error)
}

// PSDiskDeps is an interface that defines methods for gathering disk statistics.
type PSDiskDeps interface {
	// Partitions returns the disk partition statistics.
	Partitions(all bool) ([]disk.PartitionStat, error)
	// OSGetenv returns the value of the environment variable named by the key.
	OSGetenv(key string) string
	// OSStat returns the FileInfo structure describing the named file.
	OSStat(name string) (os.FileInfo, error)
	// PSDiskUsage returns a file system usage for the specified path.
	PSDiskUsage(path string) (*disk.UsageStat, error)
}

// SystemPS is a struct that implements the PS interface.
type SystemPS struct {
	PSDiskDeps
	// Log is optional; methods must tolerate a nil logger.
	Log telegraf.Logger `toml:"-"`
}

// SystemPSDisk is a struct that implements the PSDiskDeps interface.
type SystemPSDisk struct{}

// NewSystemPS creates a new instance of SystemPS.
func NewSystemPS() *SystemPS {
	return &SystemPS{PSDiskDeps: &SystemPSDisk{}}
}
|
||||
|
||||
// CPUTimes returns the CPU times statistics.
|
||||
func (*SystemPS) CPUTimes(perCPU, totalCPU bool) ([]cpu.TimesStat, error) {
|
||||
var cpuTimes []cpu.TimesStat
|
||||
if perCPU {
|
||||
perCPUTimes, err := cpu.Times(true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
cpuTimes = append(cpuTimes, perCPUTimes...)
|
||||
}
|
||||
if totalCPU {
|
||||
totalCPUTimes, err := cpu.Times(false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
cpuTimes = append(cpuTimes, totalCPUTimes...)
|
||||
}
|
||||
return cpuTimes, nil
|
||||
}
|
||||
|
||||
// DiskUsage returns the disk usage statistics.
//
// mountPointFilter is an inclusive allow-list (empty means all), while
// mountOptsExclude and fstypeExclude drop partitions by mount option and
// filesystem type respectively. The returned usage and partition slices are
// parallel: entry i of one describes entry i of the other.
func (s *SystemPS) DiskUsage(mountPointFilter, mountOptsExclude, fstypeExclude []string) ([]*disk.UsageStat, []*disk.PartitionStat, error) {
	parts, err := s.Partitions(true)
	if err != nil {
		return nil, nil, err
	}

	// Build set representations of the filters for O(1) membership tests
	mountPointFilterSet := newSet()
	for _, filter := range mountPointFilter {
		mountPointFilterSet.add(filter)
	}
	mountOptFilterSet := newSet()
	for _, filter := range mountOptsExclude {
		mountOptFilterSet.add(filter)
	}
	fstypeExcludeSet := newSet()
	for _, filter := range fstypeExclude {
		fstypeExcludeSet.add(filter)
	}
	// Record every known mountpoint to detect conflicts after prefixing below
	paths := newSet()
	for _, part := range parts {
		paths.add(part.Mountpoint)
	}

	// Autofs mounts indicate a potential mount, the partition will also be
	// listed with the actual filesystem when mounted. Ignore the autofs
	// partition to avoid triggering a mount.
	fstypeExcludeSet.add("autofs")

	var usage []*disk.UsageStat
	var partitions []*disk.PartitionStat
	hostMountPrefix := s.OSGetenv("HOST_MOUNT_PREFIX")

partitionRange:
	for i := range parts {
		p := parts[i]

		// Skip partitions mounted with any of the excluded mount options
		for _, o := range p.Opts {
			if !mountOptFilterSet.empty() && mountOptFilterSet.has(o) {
				continue partitionRange
			}
		}
		// If there is a filter set and if the mount point is not a
		// member of the filter set, don't gather info on it.
		if !mountPointFilterSet.empty() && !mountPointFilterSet.has(p.Mountpoint) {
			continue
		}

		// If the mount point is a member of the exclude set,
		// don't gather info on it.
		if fstypeExcludeSet.has(p.Fstype) {
			continue
		}

		// If there's a host mount prefix use it as newer gopsutil version check for
		// the init's mountpoints usually pointing to the host-mountpoint but in the
		// container. This won't work for checking the disk-usage as the disks are
		// mounted at HOST_MOUNT_PREFIX...
		mountpoint := p.Mountpoint
		if hostMountPrefix != "" && !strings.HasPrefix(p.Mountpoint, hostMountPrefix) {
			mountpoint = filepath.Join(hostMountPrefix, p.Mountpoint)
			// Exclude conflicting paths
			if paths.has(mountpoint) {
				if s.Log != nil {
					s.Log.Debugf("[SystemPS] => dropped by mount prefix (%q): %q", mountpoint, hostMountPrefix)
				}
				continue
			}
		}

		// Best effort: partitions whose usage cannot be read are skipped
		du, err := s.PSDiskUsage(mountpoint)
		if err != nil {
			if s.Log != nil {
				s.Log.Debugf("[SystemPS] => unable to get disk usage (%q): %v", mountpoint, err)
			}
			continue
		}

		// Report the path as seen from the host, i.e. without the prefix
		du.Path = filepath.Join(string(os.PathSeparator), strings.TrimPrefix(p.Mountpoint, hostMountPrefix))
		du.Fstype = p.Fstype
		usage = append(usage, du)
		partitions = append(partitions, &p)
	}

	return usage, partitions, nil
}
|
||||
|
||||
// NetProto returns network statistics for the entire system.
func (*SystemPS) NetProto() ([]net.ProtoCountersStat, error) {
	return net.ProtoCounters(nil)
}

// NetIO returns network I/O statistics for every network interface installed on the system.
func (*SystemPS) NetIO() ([]net.IOCountersStat, error) {
	return net.IOCounters(true)
}

// NetConnections returns a list of network connections opened.
func (*SystemPS) NetConnections() ([]net.ConnectionStat, error) {
	return net.Connections("all")
}

// NetConntrack returns more detailed info about the conntrack table.
func (*SystemPS) NetConntrack(perCPU bool) ([]net.ConntrackStat, error) {
	return net.ConntrackStats(perCPU)
}

// DiskIO returns the disk I/O statistics. Platforms without counter support
// report no data instead of an error.
func (*SystemPS) DiskIO(names []string) (map[string]disk.IOCountersStat, error) {
	m, err := disk.IOCounters(names...)
	if errors.Is(err, internal.ErrNotImplemented) {
		return nil, nil
	}

	return m, err
}

// VMStat returns the virtual memory statistics.
func (*SystemPS) VMStat() (*mem.VirtualMemoryStat, error) {
	return mem.VirtualMemory()
}

// SwapStat returns the swap memory statistics.
func (*SystemPS) SwapStat() (*mem.SwapMemoryStat, error) {
	return mem.SwapMemory()
}

// Partitions returns the disk partition statistics.
func (*SystemPSDisk) Partitions(all bool) ([]disk.PartitionStat, error) {
	return disk.Partitions(all)
}

// OSGetenv returns the value of the environment variable named by the key.
func (*SystemPSDisk) OSGetenv(key string) string {
	return os.Getenv(key)
}

// OSStat returns the FileInfo structure describing the named file.
func (*SystemPSDisk) OSStat(name string) (os.FileInfo, error) {
	return os.Stat(name)
}

// PSDiskUsage returns a file system usage for the specified path.
func (*SystemPSDisk) PSDiskUsage(path string) (*disk.UsageStat, error) {
	return disk.Usage(path)
}
|
||||
|
||||
// set is a minimal string set used for the filter lookups above.
type set struct {
	m map[string]struct{}
}

// empty reports whether the set contains no elements.
func (s *set) empty() bool {
	return len(s.m) == 0
}

// add inserts key into the set; adding an existing key is a no-op.
func (s *set) add(key string) {
	s.m[key] = struct{}{}
}

// has reports whether key is a member of the set.
func (s *set) has(key string) bool {
	_, ok := s.m[key]
	return ok
}

// newSet returns an empty, ready-to-use set.
func newSet() *set {
	return &set{m: make(map[string]struct{})}
}
|
23
plugins/common/ratelimiter/config.go
Normal file
23
plugins/common/ratelimiter/config.go
Normal file
|
@ -0,0 +1,23 @@
|
|||
package ratelimiter
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"github.com/influxdata/telegraf/config"
|
||||
)
|
||||
|
||||
type RateLimitConfig struct {
|
||||
Limit config.Size `toml:"rate_limit"`
|
||||
Period config.Duration `toml:"rate_limit_period"`
|
||||
}
|
||||
|
||||
func (cfg *RateLimitConfig) CreateRateLimiter() (*RateLimiter, error) {
|
||||
if cfg.Limit > 0 && cfg.Period <= 0 {
|
||||
return nil, errors.New("invalid period for rate-limit")
|
||||
}
|
||||
return &RateLimiter{
|
||||
limit: int64(cfg.Limit),
|
||||
period: time.Duration(cfg.Period),
|
||||
}, nil
|
||||
}
|
66
plugins/common/ratelimiter/limiters.go
Normal file
66
plugins/common/ratelimiter/limiters.go
Normal file
|
@ -0,0 +1,66 @@
|
|||
package ratelimiter
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"math"
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrLimitExceeded = errors.New("not enough tokens")
|
||||
)
|
||||
|
||||
type RateLimiter struct {
|
||||
limit int64
|
||||
period time.Duration
|
||||
periodStart time.Time
|
||||
remaining int64
|
||||
}
|
||||
|
||||
func (r *RateLimiter) Remaining(t time.Time) int64 {
|
||||
if r.limit == 0 {
|
||||
return math.MaxInt64
|
||||
}
|
||||
|
||||
// Check for corner case
|
||||
if !r.periodStart.Before(t) {
|
||||
return 0
|
||||
}
|
||||
|
||||
// We are in a new period, so the complete size is available
|
||||
deltat := t.Sub(r.periodStart)
|
||||
if deltat >= r.period {
|
||||
return r.limit
|
||||
}
|
||||
|
||||
return r.remaining
|
||||
}
|
||||
|
||||
func (r *RateLimiter) Accept(t time.Time, used int64) {
|
||||
if r.limit == 0 || r.periodStart.After(t) {
|
||||
return
|
||||
}
|
||||
|
||||
// Remember the first query and reset if we are in a new period
|
||||
if r.periodStart.IsZero() {
|
||||
r.periodStart = t
|
||||
r.remaining = r.limit
|
||||
} else if deltat := t.Sub(r.periodStart); deltat >= r.period {
|
||||
r.periodStart = r.periodStart.Add(deltat.Truncate(r.period))
|
||||
r.remaining = r.limit
|
||||
}
|
||||
|
||||
// Update the state
|
||||
r.remaining = max(r.remaining-used, 0)
|
||||
}
|
||||
|
||||
func (r *RateLimiter) Undo(t time.Time, used int64) {
|
||||
// Do nothing if we are not in the current period or unlimited because we
|
||||
// already reset the limit on a new window.
|
||||
if r.limit == 0 || r.periodStart.IsZero() || r.periodStart.After(t) || t.Sub(r.periodStart) >= r.period {
|
||||
return
|
||||
}
|
||||
|
||||
// Undo the state update
|
||||
r.remaining = min(r.remaining+used, r.limit)
|
||||
}
|
189
plugins/common/ratelimiter/limiters_test.go
Normal file
189
plugins/common/ratelimiter/limiters_test.go
Normal file
|
@ -0,0 +1,189 @@
|
|||
package ratelimiter
|
||||
|
||||
import (
|
||||
"math"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/influxdata/telegraf/config"
|
||||
)
|
||||
|
||||
// TestInvalidPeriod ensures a positive limit without a period is rejected.
func TestInvalidPeriod(t *testing.T) {
	cfg := &RateLimitConfig{Limit: config.Size(1024)}
	_, err := cfg.CreateRateLimiter()
	require.ErrorContains(t, err, "invalid period for rate-limit")
}
|
||||
|
||||
// TestUnlimited checks that a zero limit always reports MaxInt64 remaining.
func TestUnlimited(t *testing.T) {
	cfg := &RateLimitConfig{}
	limiter, err := cfg.CreateRateLimiter()
	require.NoError(t, err)

	start := time.Now()
	end := start.Add(30 * time.Minute)
	for ts := start; ts.Before(end); ts = ts.Add(1 * time.Minute) {
		require.EqualValues(t, int64(math.MaxInt64), limiter.Remaining(ts))
	}
}
|
||||
|
||||
// TestUnlimitedWithPeriod checks that a configured period alone (zero limit)
// still behaves as unlimited.
func TestUnlimitedWithPeriod(t *testing.T) {
	cfg := &RateLimitConfig{
		Period: config.Duration(5 * time.Minute),
	}
	limiter, err := cfg.CreateRateLimiter()
	require.NoError(t, err)

	start := time.Now()
	end := start.Add(30 * time.Minute)
	for ts := start; ts.Before(end); ts = ts.Add(1 * time.Minute) {
		require.EqualValues(t, int64(math.MaxInt64), limiter.Remaining(ts))
	}
}
|
||||
|
||||
// TestLimited exercises the budget accounting across several windows for both
// constant and variable consumption patterns.
func TestLimited(t *testing.T) {
	tests := []struct {
		name     string
		cfg      *RateLimitConfig
		step     time.Duration // time advance between iterations
		request  []int64       // requested tokens, cycled over by index
		expected []int64       // expected Remaining() before each request
	}{
		{
			name: "constant usage",
			cfg: &RateLimitConfig{
				Limit:  config.Size(1024),
				Period: config.Duration(5 * time.Minute),
			},
			step:     time.Minute,
			request:  []int64{300},
			expected: []int64{1024, 724, 424, 124, 0, 1024, 724, 424, 124, 0},
		},
		{
			name: "variable usage",
			cfg: &RateLimitConfig{
				Limit:  config.Size(1024),
				Period: config.Duration(5 * time.Minute),
			},
			step:     time.Minute,
			request:  []int64{256, 128, 512, 64, 64, 1024, 0, 0, 0, 0, 128, 4096, 4096, 4096, 4096, 4096},
			expected: []int64{1024, 768, 640, 128, 64, 1024, 0, 0, 0, 0, 1024, 896, 0, 0, 0, 1024},
		},
	}

	// Run the test with timestamps aligned to step multiples
	for _, tt := range tests {
		t.Run(tt.name+" at period", func(t *testing.T) {
			// Setup the limiter
			limiter, err := tt.cfg.CreateRateLimiter()
			require.NoError(t, err)

			// Compute the actual values
			start := time.Now().Truncate(tt.step)
			for i, expected := range tt.expected {
				ts := start.Add(time.Duration(i) * tt.step)
				remaining := limiter.Remaining(ts)
				use := min(remaining, tt.request[i%len(tt.request)])
				require.Equalf(t, expected, remaining, "mismatch at index %d", i)
				limiter.Accept(ts, use)
			}
		})
	}

	// Run the test with timestamps offset from step multiples by one second
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// Setup the limiter
			limiter, err := tt.cfg.CreateRateLimiter()
			require.NoError(t, err)

			// Compute the actual values
			start := time.Now().Truncate(tt.step).Add(1 * time.Second)
			for i, expected := range tt.expected {
				ts := start.Add(time.Duration(i) * tt.step)
				remaining := limiter.Remaining(ts)
				use := min(remaining, tt.request[i%len(tt.request)])
				require.Equalf(t, expected, remaining, "mismatch at index %d", i)
				limiter.Accept(ts, use)
			}
		})
	}
}
|
||||
|
||||
// TestUndo mirrors TestLimited but rolls back requests that exceeded the
// remaining budget, verifying Undo restores the window's tokens.
func TestUndo(t *testing.T) {
	tests := []struct {
		name     string
		cfg      *RateLimitConfig
		step     time.Duration // time advance between iterations
		request  []int64       // requested tokens, cycled over by index
		expected []int64       // expected Remaining() before each request
	}{
		{
			name: "constant usage",
			cfg: &RateLimitConfig{
				Limit:  config.Size(1024),
				Period: config.Duration(5 * time.Minute),
			},
			step:     time.Minute,
			request:  []int64{300},
			expected: []int64{1024, 724, 424, 124, 124, 1024, 724, 424, 124, 124},
		},
		{
			name: "variable usage",
			cfg: &RateLimitConfig{
				Limit:  config.Size(1024),
				Period: config.Duration(5 * time.Minute),
			},
			step:     time.Minute,
			request:  []int64{256, 128, 512, 64, 64, 1024, 0, 0, 0, 0, 128, 4096, 4096, 4096, 4096, 4096},
			expected: []int64{1024, 768, 640, 128, 64, 1024, 0, 0, 0, 0, 1024, 896, 896, 896, 896, 1024},
		},
	}

	// Run the test with timestamps aligned to step multiples
	for _, tt := range tests {
		t.Run(tt.name+" at period", func(t *testing.T) {
			// Setup the limiter
			limiter, err := tt.cfg.CreateRateLimiter()
			require.NoError(t, err)

			// Compute the actual values
			start := time.Now().Truncate(tt.step)
			for i, expected := range tt.expected {
				ts := start.Add(time.Duration(i) * tt.step)
				remaining := limiter.Remaining(ts)
				use := min(remaining, tt.request[i%len(tt.request)])
				require.Equalf(t, expected, remaining, "mismatch at index %d", i)
				limiter.Accept(ts, use)
				// Undo too large operations
				if tt.request[i%len(tt.request)] > remaining {
					limiter.Undo(ts, use)
				}
			}
		})
	}

	// Run the test with timestamps offset from step multiples by one second
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// Setup the limiter
			limiter, err := tt.cfg.CreateRateLimiter()
			require.NoError(t, err)

			// Compute the actual values
			start := time.Now().Truncate(tt.step).Add(1 * time.Second)
			for i, expected := range tt.expected {
				ts := start.Add(time.Duration(i) * tt.step)
				remaining := limiter.Remaining(ts)
				use := min(remaining, tt.request[i%len(tt.request)])
				require.Equalf(t, expected, remaining, "mismatch at index %d", i)
				limiter.Accept(ts, use)
				// Undo too large operations
				if tt.request[i%len(tt.request)] > remaining {
					limiter.Undo(ts, use)
				}
			}
		})
	}
}
|
100
plugins/common/ratelimiter/serializers.go
Normal file
100
plugins/common/ratelimiter/serializers.go
Normal file
|
@ -0,0 +1,100 @@
|
|||
package ratelimiter
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"math"
|
||||
|
||||
"github.com/influxdata/telegraf"
|
||||
"github.com/influxdata/telegraf/internal"
|
||||
)
|
||||
|
||||
// Serializer interface abstracting the different implementations of a
// limited-size serializer
type Serializer interface {
	// Serialize encodes a single metric, failing with
	// internal.ErrSizeLimitReached if the encoded form exceeds limit bytes.
	Serialize(metric telegraf.Metric, limit int64) ([]byte, error)
	// SerializeBatch encodes as many metrics as fit into limit bytes and
	// reports the remainder through the returned error.
	SerializeBatch(metrics []telegraf.Metric, limit int64) ([]byte, error)
}

// Individual serializers do serialize each metric individually using the
// serializer's Serialize() function and add the resulting output to the buffer
// until the limit is reached. This only works for serializers NOT requiring
// the serialization of a batch as-a-whole.
type IndividualSerializer struct {
	serializer telegraf.Serializer // underlying per-metric serializer
	buffer     *bytes.Buffer       // reusable output buffer for batch runs
}
|
||||
|
||||
func NewIndividualSerializer(s telegraf.Serializer) *IndividualSerializer {
|
||||
return &IndividualSerializer{
|
||||
serializer: s,
|
||||
buffer: &bytes.Buffer{},
|
||||
}
|
||||
}
|
||||
|
||||
func (s *IndividualSerializer) Serialize(metric telegraf.Metric, limit int64) ([]byte, error) {
|
||||
// Do the serialization
|
||||
buf, err := s.serializer.Serialize(metric)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// The serialized metric fits into the limit, so output it
|
||||
if buflen := int64(len(buf)); buflen <= limit {
|
||||
return buf, nil
|
||||
}
|
||||
|
||||
// The serialized metric exceeds the limit
|
||||
return nil, internal.ErrSizeLimitReached
|
||||
}
|
||||
|
||||
// SerializeBatch serializes the metrics one-by-one, appending each result to
// an internal buffer as long as the total stays within limit bytes. Metrics
// that do not fit (or fail to serialize) are reported through a
// *internal.PartialWriteError carrying the accepted and rejected indices.
// If not even the first metric fits, a plain internal.ErrSizeLimitReached is
// returned without any output.
func (s *IndividualSerializer) SerializeBatch(metrics []telegraf.Metric, limit int64) ([]byte, error) {
	// Grow the buffer so it can hold at least the required size. This will
	// save us from reallocating often
	s.buffer.Reset()
	if limit > 0 && limit > int64(s.buffer.Cap()) && limit < int64(math.MaxInt) {
		s.buffer.Grow(int(limit))
	}

	// Prepare a potential write error and be optimistic
	werr := &internal.PartialWriteError{
		MetricsAccept: make([]int, 0, len(metrics)),
	}

	// Iterate through the metrics, serialize them and add them to the output
	// buffer if they are within the size limit.
	var used int64
	for i, m := range metrics {
		buf, err := s.serializer.Serialize(m)
		if err != nil {
			// Failing serialization is a fatal error so mark the metric as such
			werr.Err = internal.ErrSerialization
			werr.MetricsReject = append(werr.MetricsReject, i)
			werr.MetricsRejectErrors = append(werr.MetricsRejectErrors, err)
			continue
		}

		// The serialized metric fits into the limit, so add it to the output
		if usedAdded := used + int64(len(buf)); usedAdded <= limit {
			if _, err := s.buffer.Write(buf); err != nil {
				return nil, err
			}
			werr.MetricsAccept = append(werr.MetricsAccept, i)
			used = usedAdded
			continue
		}

		// Return only the size-limit-reached error if all metrics failed.
		if used == 0 {
			return nil, internal.ErrSizeLimitReached
		}

		// Adding the serialized metric would exceed the limit so exit with a
		// PartialWriteError and fill in the required information
		werr.Err = internal.ErrSizeLimitReached
		break
	}
	if werr.Err != nil {
		return s.buffer.Bytes(), werr
	}
	return s.buffer.Bytes(), nil
}
|
352
plugins/common/ratelimiter/serializers_test.go
Normal file
352
plugins/common/ratelimiter/serializers_test.go
Normal file
|
@ -0,0 +1,352 @@
|
|||
package ratelimiter
|
||||
|
||||
import (
|
||||
"math"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/influxdata/telegraf"
|
||||
"github.com/influxdata/telegraf/internal"
|
||||
"github.com/influxdata/telegraf/metric"
|
||||
"github.com/influxdata/telegraf/plugins/serializers/influx"
|
||||
)
|
||||
|
||||
// TestIndividualSerializer verifies that batch serialization with a size
// limit accepts metrics up to the limit, reports the remainder via a
// PartialWriteError, and that repeated calls on the leftover metrics
// eventually drain the whole batch.
func TestIndividualSerializer(t *testing.T) {
	input := []telegraf.Metric{
		metric.New(
			"serializer_test",
			map[string]string{
				"source":   "localhost",
				"location": "factory_north",
				"machine":  "A",
				"status":   "ok",
			},
			map[string]interface{}{
				"operating_hours": 123,
				"temperature":     25.0,
				"pressure":        1023.4,
			},
			time.Unix(1722443551, 0),
		),
		metric.New(
			"serializer_test",
			map[string]string{
				"source":   "localhost",
				"location": "factory_north",
				"machine":  "B",
				"status":   "failed",
			},
			map[string]interface{}{
				"operating_hours": 8430,
				"temperature":     65.2,
				"pressure":        985.9,
			},
			time.Unix(1722443554, 0),
		),
		metric.New(
			"serializer_test",
			map[string]string{
				"source":   "localhost",
				"location": "factory_north",
				"machine":  "C",
				"status":   "warning",
			},
			map[string]interface{}{
				"operating_hours": 6765,
				"temperature":     42.5,
				"pressure":        986.1,
			},
			time.Unix(1722443555, 0),
		),
		metric.New(
			"device",
			map[string]string{
				"source":   "localhost",
				"location": "factory_north",
			},
			map[string]interface{}{
				"status": "ok",
			},
			time.Unix(1722443556, 0),
		),
		metric.New(
			"serializer_test",
			map[string]string{
				"source":   "gateway_af43e",
				"location": "factory_south",
				"machine":  "A",
				"status":   "ok",
			},
			map[string]interface{}{
				"operating_hours": 5544,
				"temperature":     18.6,
				"pressure":        1069.4,
			},
			time.Unix(1722443552, 0),
		),
		metric.New(
			"serializer_test",
			map[string]string{
				"source":   "gateway_af43e",
				"location": "factory_south",
				"machine":  "B",
				"status":   "ok",
			},
			map[string]interface{}{
				"operating_hours": 65,
				"temperature":     29.7,
				"pressure":        1101.2,
			},
			time.Unix(1722443553, 0),
		),
		metric.New(
			"device",
			map[string]string{
				"source":   "gateway_af43e",
				"location": "factory_south",
			},
			map[string]interface{}{
				"status": "ok",
			},
			time.Unix(1722443559, 0),
		),
		metric.New(
			"serializer_test",
			map[string]string{
				"source":   "gateway_af43e",
				"location": "factory_south",
				"machine":  "C",
				"status":   "off",
			},
			map[string]interface{}{
				"operating_hours": 0,
				"temperature":     0.0,
				"pressure":        0.0,
			},
			time.Unix(1722443562, 0),
		),
	}
	// Each entry is the expected output of one 400-byte-limited run
	//nolint:lll // Resulting metrics should not be wrapped for readability
	expected := []string{
		"serializer_test,location=factory_north,machine=A,source=localhost,status=ok operating_hours=123i,pressure=1023.4,temperature=25 1722443551000000000\n" +
			"serializer_test,location=factory_north,machine=B,source=localhost,status=failed operating_hours=8430i,pressure=985.9,temperature=65.2 1722443554000000000\n",
		"serializer_test,location=factory_north,machine=C,source=localhost,status=warning operating_hours=6765i,pressure=986.1,temperature=42.5 1722443555000000000\n" +
			"device,location=factory_north,source=localhost status=\"ok\" 1722443556000000000\n" +
			"serializer_test,location=factory_south,machine=A,source=gateway_af43e,status=ok operating_hours=5544i,pressure=1069.4,temperature=18.6 1722443552000000000\n",
		"serializer_test,location=factory_south,machine=B,source=gateway_af43e,status=ok operating_hours=65i,pressure=1101.2,temperature=29.7 1722443553000000000\n" +
			"device,location=factory_south,source=gateway_af43e status=\"ok\" 1722443559000000000\n" +
			"serializer_test,location=factory_south,machine=C,source=gateway_af43e,status=off operating_hours=0i,pressure=0,temperature=0 1722443562000000000\n",
	}

	// Setup the limited serializer
	s := &influx.Serializer{SortFields: true}
	require.NoError(t, s.Init())
	serializer := NewIndividualSerializer(s)

	var werr *internal.PartialWriteError

	// Do the first serialization runs with all metrics
	buf, err := serializer.SerializeBatch(input, 400)
	require.ErrorAs(t, err, &werr)
	require.ErrorIs(t, werr.Err, internal.ErrSizeLimitReached)
	require.EqualValues(t, []int{0, 1}, werr.MetricsAccept)
	require.Empty(t, werr.MetricsReject)
	require.Equal(t, expected[0], string(buf))

	// Run again with the successful metrics removed
	buf, err = serializer.SerializeBatch(input[2:], 400)
	require.ErrorAs(t, err, &werr)
	require.ErrorIs(t, werr.Err, internal.ErrSizeLimitReached)
	require.EqualValues(t, []int{0, 1, 2}, werr.MetricsAccept)
	require.Empty(t, werr.MetricsReject)
	require.Equal(t, expected[1], string(buf))

	// Final run with the successful metrics removed
	buf, err = serializer.SerializeBatch(input[5:], 400)
	require.NoError(t, err)
	require.Equal(t, expected[2], string(buf))
}
|
||||
|
||||
// TestIndividualSerializerFirstTooBig verifies the shortcut path: when even
// the very first metric exceeds the size limit, SerializeBatch returns a
// plain internal.ErrSizeLimitReached (not a PartialWriteError) and no output.
func TestIndividualSerializerFirstTooBig(t *testing.T) {
	input := []telegraf.Metric{
		metric.New(
			"serializer_test",
			map[string]string{
				"source":   "localhost",
				"location": "factory_north",
				"machine":  "A",
				"status":   "ok",
			},
			map[string]interface{}{
				"operating_hours": 123,
				"temperature":     25.0,
				"pressure":        1023.4,
			},
			time.Unix(1722443551, 0),
		),
		metric.New(
			"serializer_test",
			map[string]string{
				"source":   "localhost",
				"location": "factory_north",
				"machine":  "B",
				"status":   "failed",
			},
			map[string]interface{}{
				"operating_hours": 8430,
				"temperature":     65.2,
				"pressure":        985.9,
			},
			time.Unix(1722443554, 0),
		),
	}

	// Setup the limited serializer
	s := &influx.Serializer{SortFields: true}
	require.NoError(t, s.Init())
	serializer := NewIndividualSerializer(s)

	// The first metric will already exceed the size so all metrics fail and
	// we expect a shortcut error.
	buf, err := serializer.SerializeBatch(input, 100)
	require.ErrorIs(t, err, internal.ErrSizeLimitReached)
	require.Empty(t, buf)
}
|
||||
|
||||
// TestIndividualSerializerUnlimited verifies that with an effectively
// unlimited size budget all metrics are serialized in one run, in input
// order, without any error.
func TestIndividualSerializerUnlimited(t *testing.T) {
	input := []telegraf.Metric{
		metric.New(
			"serializer_test",
			map[string]string{
				"source":   "localhost",
				"location": "factory_north",
				"machine":  "A",
				"status":   "ok",
			},
			map[string]interface{}{
				"operating_hours": 123,
				"temperature":     25.0,
				"pressure":        1023.4,
			},
			time.Unix(1722443551, 0),
		),
		metric.New(
			"serializer_test",
			map[string]string{
				"source":   "localhost",
				"location": "factory_north",
				"machine":  "B",
				"status":   "failed",
			},
			map[string]interface{}{
				"operating_hours": 8430,
				"temperature":     65.2,
				"pressure":        985.9,
			},
			time.Unix(1722443554, 0),
		),
		metric.New(
			"serializer_test",
			map[string]string{
				"source":   "localhost",
				"location": "factory_north",
				"machine":  "C",
				"status":   "warning",
			},
			map[string]interface{}{
				"operating_hours": 6765,
				"temperature":     42.5,
				"pressure":        986.1,
			},
			time.Unix(1722443555, 0),
		),
		metric.New(
			"device",
			map[string]string{
				"source":   "localhost",
				"location": "factory_north",
			},
			map[string]interface{}{
				"status": "ok",
			},
			time.Unix(1722443556, 0),
		),
		metric.New(
			"serializer_test",
			map[string]string{
				"source":   "gateway_af43e",
				"location": "factory_south",
				"machine":  "A",
				"status":   "ok",
			},
			map[string]interface{}{
				"operating_hours": 5544,
				"temperature":     18.6,
				"pressure":        1069.4,
			},
			time.Unix(1722443552, 0),
		),
		metric.New(
			"serializer_test",
			map[string]string{
				"source":   "gateway_af43e",
				"location": "factory_south",
				"machine":  "B",
				"status":   "ok",
			},
			map[string]interface{}{
				"operating_hours": 65,
				"temperature":     29.7,
				"pressure":        1101.2,
			},
			time.Unix(1722443553, 0),
		),
		metric.New(
			"device",
			map[string]string{
				"source":   "gateway_af43e",
				"location": "factory_south",
			},
			map[string]interface{}{
				"status": "ok",
			},
			time.Unix(1722443559, 0),
		),
		metric.New(
			"serializer_test",
			map[string]string{
				"source":   "gateway_af43e",
				"location": "factory_south",
				"machine":  "C",
				"status":   "off",
			},
			map[string]interface{}{
				"operating_hours": 0,
				"temperature":     0.0,
				"pressure":        0.0,
			},
			time.Unix(1722443562, 0),
		),
	}
	//nolint:lll // Resulting metrics should not be wrapped for readability
	expected := "serializer_test,location=factory_north,machine=A,source=localhost,status=ok operating_hours=123i,pressure=1023.4,temperature=25 1722443551000000000\n" +
		"serializer_test,location=factory_north,machine=B,source=localhost,status=failed operating_hours=8430i,pressure=985.9,temperature=65.2 1722443554000000000\n" +
		"serializer_test,location=factory_north,machine=C,source=localhost,status=warning operating_hours=6765i,pressure=986.1,temperature=42.5 1722443555000000000\n" +
		"device,location=factory_north,source=localhost status=\"ok\" 1722443556000000000\n" +
		"serializer_test,location=factory_south,machine=A,source=gateway_af43e,status=ok operating_hours=5544i,pressure=1069.4,temperature=18.6 1722443552000000000\n" +
		"serializer_test,location=factory_south,machine=B,source=gateway_af43e,status=ok operating_hours=65i,pressure=1101.2,temperature=29.7 1722443553000000000\n" +
		"device,location=factory_south,source=gateway_af43e status=\"ok\" 1722443559000000000\n" +
		"serializer_test,location=factory_south,machine=C,source=gateway_af43e,status=off operating_hours=0i,pressure=0,temperature=0 1722443562000000000\n"

	// Setup the limited serializer
	s := &influx.Serializer{SortFields: true}
	require.NoError(t, s.Init())
	serializer := NewIndividualSerializer(s)

	// Do the first serialization runs with all metrics
	buf, err := serializer.SerializeBatch(input, math.MaxInt64)
	require.NoError(t, err)
	require.Equal(t, expected, string(buf))
}
|
64
plugins/common/shim/README.md
Normal file
64
plugins/common/shim/README.md
Normal file
|
@ -0,0 +1,64 @@
|
|||
# Telegraf Execd Go Shim
|
||||
|
||||
The goal of this _shim_ is to make it trivial to extract an internal input,
|
||||
processor, or output plugin from the main Telegraf repo out to a stand-alone
|
||||
repo. This allows anyone to build and run it as a separate app using one of the
|
||||
execd plugins:
|
||||
|
||||
- [inputs.execd](/plugins/inputs/execd)
|
||||
- [processors.execd](/plugins/processors/execd)
|
||||
- [outputs.execd](/plugins/outputs/execd)
|
||||
|
||||
## Steps to externalize a plugin
|
||||
|
||||
1. Move the project to an external repo. It's recommended to preserve the path
   structure (but not strictly necessary), e.g. if your plugin was at
   `plugins/inputs/cpu`, it's recommended that it also be under `plugins/inputs/cpu`
|
||||
in the new repo. For a further example of what this might look like, take a
|
||||
look at [ssoroka/rand](https://github.com/ssoroka/rand) or
|
||||
[danielnelson/telegraf-plugins](https://github.com/danielnelson/telegraf-plugins)
|
||||
1. Copy [main.go](./example/cmd/main.go) into your project under the `cmd` folder.
|
||||
This will be the entrypoint to the plugin when run as a stand-alone program, and
|
||||
it will call the shim code for you to make that happen. It's recommended to
|
||||
have only one plugin per repo, as the shim is not designed to run multiple
|
||||
plugins at the same time (it would vastly complicate things).
|
||||
1. Edit the main.go file to import your plugin. Within Telegraf this would have
|
||||
been done in an all.go file, but here we don't split the two apart, and the change
|
||||
just goes in the top of main.go. If you skip this step, your plugin will do nothing.
|
||||
eg: `_ "github.com/me/my-plugin-telegraf/plugins/inputs/cpu"`
|
||||
1. Optionally add a [plugin.conf](./example/cmd/plugin.conf) for configuration
|
||||
specific to your plugin. Note that this config file **must be separate from the
|
||||
rest of the config for Telegraf, and must not be in a shared directory where
|
||||
Telegraf is expecting to load all configs**. If Telegraf reads this config file
|
||||
it will not know which plugin it relates to. Telegraf instead uses an execd config
|
||||
block to look for this plugin.
|
||||
|
||||
## Steps to build and run your plugin
|
||||
|
||||
1. Build the cmd/main.go. For my rand project this looks like `go build -o rand cmd/main.go`
|
||||
1. If you're building an input, you can test out the binary just by running it.
|
||||
eg `./rand -config plugin.conf`
|
||||
Depending on your polling settings and whether you implemented a service plugin or
|
||||
an input gathering plugin, you may see data right away, or you may have to hit enter
|
||||
first, or wait for your poll duration to elapse, but the metrics will be written to
|
||||
STDOUT. Ctrl-C to end your test.
|
||||
If you're testing a processor or output manually, you can still do this but you
|
||||
will need to feed valid metrics in on STDIN to verify that it is doing what you
|
||||
want. This can be a very valuable debugging technique before hooking it up to
|
||||
Telegraf.
|
||||
1. Configure Telegraf to call your new plugin binary. For an input, this would
|
||||
look something like:
|
||||
|
||||
```toml
|
||||
[[inputs.execd]]
|
||||
command = ["/path/to/rand", "-config", "/path/to/plugin.conf"]
|
||||
signal = "none"
|
||||
```
|
||||
|
||||
Refer to the execd plugin readmes for more information.
|
||||
|
||||
## Congratulations
|
||||
|
||||
You've done it! Consider publishing your plugin to github and open a Pull Request
|
||||
back to the Telegraf repo letting us know about the availability of your
|
||||
[external plugin](https://github.com/influxdata/telegraf/blob/master/EXTERNAL_PLUGINS.md).
|
170
plugins/common/shim/config.go
Normal file
170
plugins/common/shim/config.go
Normal file
|
@ -0,0 +1,170 @@
|
|||
package shim
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"log" //nolint:depguard // Allow exceptional but valid use of log here.
|
||||
"os"
|
||||
|
||||
"github.com/BurntSushi/toml"
|
||||
|
||||
"github.com/influxdata/telegraf"
|
||||
"github.com/influxdata/telegraf/plugins/inputs"
|
||||
"github.com/influxdata/telegraf/plugins/outputs"
|
||||
"github.com/influxdata/telegraf/plugins/processors"
|
||||
)
|
||||
|
||||
// config mirrors the plugin sections of a Telegraf TOML configuration file.
// The TOML values are kept as undecoded primitives so they can be decoded
// into the concrete plugin struct once it has been created.
type config struct {
	Inputs     map[string][]toml.Primitive
	Processors map[string][]toml.Primitive
	Outputs    map[string][]toml.Primitive
}

// loadedConfig holds the single plugin instantiated from the configuration;
// at most one of the fields is non-nil.
type loadedConfig struct {
	Input     telegraf.Input
	Processor telegraf.StreamingProcessor
	Output    telegraf.Output
}
|
||||
|
||||
// LoadConfig Adds plugins to the shim
|
||||
func (s *Shim) LoadConfig(filePath *string) error {
|
||||
conf, err := LoadConfig(filePath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if conf.Input != nil {
|
||||
if err = s.AddInput(conf.Input); err != nil {
|
||||
return fmt.Errorf("failed to add Input: %w", err)
|
||||
}
|
||||
} else if conf.Processor != nil {
|
||||
if err = s.AddStreamingProcessor(conf.Processor); err != nil {
|
||||
return fmt.Errorf("failed to add Processor: %w", err)
|
||||
}
|
||||
} else if conf.Output != nil {
|
||||
if err = s.AddOutput(conf.Output); err != nil {
|
||||
return fmt.Errorf("failed to add Output: %w", err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// LoadConfig loads the config and returns inputs that later need to be loaded.
// When filePath is nil or empty, the configuration falls back to whatever
// plugins registered themselves at import time (DefaultImportedPlugins).
// Environment-variable references in the file are expanded before decoding.
func LoadConfig(filePath *string) (loaded loadedConfig, err error) {
	var data string
	conf := config{}
	if filePath != nil && *filePath != "" {
		b, err := os.ReadFile(*filePath)
		if err != nil {
			return loadedConfig{}, err
		}

		// Substitute ${VAR} references with escaped environment values
		data = expandEnvVars(b)
	} else {
		conf = DefaultImportedPlugins()
	}

	// Decoding empty data leaves the default configuration untouched
	md, err := toml.Decode(data, &conf)
	if err != nil {
		return loadedConfig{}, err
	}

	return createPluginsWithTomlConfig(md, conf)
}
|
||||
|
||||
func expandEnvVars(contents []byte) string {
|
||||
return os.Expand(string(contents), getEnv)
|
||||
}
|
||||
|
||||
func getEnv(key string) string {
|
||||
v := os.Getenv(key)
|
||||
|
||||
return envVarEscaper.Replace(v)
|
||||
}
|
||||
|
||||
// createPluginsWithTomlConfig instantiates at most one input, one processor
// and one output from the decoded configuration and applies the plugin's own
// TOML section (the first primitive, if present) to the created instance.
// Each loop breaks after the first entry because the shim is designed to run
// a single plugin of each kind.
func createPluginsWithTomlConfig(md toml.MetaData, conf config) (loadedConfig, error) {
	loadedConf := loadedConfig{}

	for name, primitives := range conf.Inputs {
		creator, ok := inputs.Inputs[name]
		if !ok {
			return loadedConf, errors.New("unknown input " + name)
		}

		plugin := creator()
		if len(primitives) > 0 {
			// Only the first configuration block is applied
			primitive := primitives[0]
			if err := md.PrimitiveDecode(primitive, plugin); err != nil {
				return loadedConf, err
			}
		}

		loadedConf.Input = plugin
		break
	}

	for name, primitives := range conf.Processors {
		creator, ok := processors.Processors[name]
		if !ok {
			return loadedConf, errors.New("unknown processor " + name)
		}

		plugin := creator()
		if len(primitives) > 0 {
			primitive := primitives[0]
			// Decode into the wrapped processor when the plugin is a wrapper
			var p telegraf.PluginDescriber = plugin
			if processor, ok := plugin.(processors.HasUnwrap); ok {
				p = processor.Unwrap()
			}
			if err := md.PrimitiveDecode(primitive, p); err != nil {
				return loadedConf, err
			}
		}
		loadedConf.Processor = plugin
		break
	}

	for name, primitives := range conf.Outputs {
		creator, ok := outputs.Outputs[name]
		if !ok {
			return loadedConf, errors.New("unknown output " + name)
		}

		plugin := creator()
		if len(primitives) > 0 {
			primitive := primitives[0]
			if err := md.PrimitiveDecode(primitive, plugin); err != nil {
				return loadedConf, err
			}
		}
		loadedConf.Output = plugin
		break
	}
	return loadedConf, nil
}
|
||||
|
||||
// DefaultImportedPlugins defaults to whatever plugins happen to be loaded and
// have registered themselves with the registry. This makes loading plugins
// without having to define a config dead easy.
func DefaultImportedPlugins() config {
	conf := config{
		Inputs:     make(map[string][]toml.Primitive, len(inputs.Inputs)),
		Processors: make(map[string][]toml.Primitive, len(processors.Processors)),
		Outputs:    make(map[string][]toml.Primitive, len(outputs.Outputs)),
	}
	// Each loop returns after the first registered plugin because the shim
	// runs exactly one plugin; with several registered, map iteration order
	// makes the pick arbitrary.
	for name := range inputs.Inputs {
		log.Println("No config found. Loading default config for plugin", name)
		conf.Inputs[name] = make([]toml.Primitive, 0)
		return conf
	}
	for name := range processors.Processors {
		log.Println("No config found. Loading default config for plugin", name)
		conf.Processors[name] = make([]toml.Primitive, 0)
		return conf
	}
	for name := range outputs.Outputs {
		log.Println("No config found. Loading default config for plugin", name)
		conf.Outputs[name] = make([]toml.Primitive, 0)
		return conf
	}
	return conf
}
|
87
plugins/common/shim/config_test.go
Normal file
87
plugins/common/shim/config_test.go
Normal file
|
@ -0,0 +1,87 @@
|
|||
package shim
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/influxdata/telegraf"
|
||||
cfg "github.com/influxdata/telegraf/config"
|
||||
"github.com/influxdata/telegraf/plugins/inputs"
|
||||
"github.com/influxdata/telegraf/plugins/processors"
|
||||
)
|
||||
|
||||
// TestLoadConfig verifies that a plugin config file is decoded into the
// registered input and that ${VAR} environment references are expanded,
// including escaping of quotes and backslashes in the values.
func TestLoadConfig(t *testing.T) {
	t.Setenv("SECRET_TOKEN", "xxxxxxxxxx")
	t.Setenv("SECRET_VALUE", `test"\test`)

	// Register the input under the name referenced by the test config
	inputs.Add("test", func() telegraf.Input {
		return &serviceInput{}
	})

	c := "./testdata/plugin.conf"
	conf, err := LoadConfig(&c)
	require.NoError(t, err)

	inp := conf.Input.(*serviceInput)

	require.Equal(t, "awesome name", inp.ServiceName)
	require.Equal(t, "xxxxxxxxxx", inp.SecretToken)
	require.Equal(t, `test"\test`, inp.SecretValue)
}
|
||||
|
||||
// TestLoadingSpecialTypes verifies that Telegraf's special config types
// (Duration, Size) and hexadecimal integers decode correctly from TOML.
func TestLoadingSpecialTypes(t *testing.T) {
	inputs.Add("test", func() telegraf.Input {
		return &testDurationInput{}
	})

	c := "./testdata/special.conf"
	conf, err := LoadConfig(&c)
	require.NoError(t, err)

	inp := conf.Input.(*testDurationInput)

	require.EqualValues(t, 3*time.Second, inp.Duration)
	require.EqualValues(t, 3*1000*1000, inp.Size)
	require.EqualValues(t, 52, inp.Hex)
}
|
||||
|
||||
// TestLoadingProcessorWithConfig verifies that a processor's TOML section is
// decoded into the processor instance returned by its creator function.
func TestLoadingProcessorWithConfig(t *testing.T) {
	proc := &testConfigProcessor{}
	processors.Add("test_config_load", func() telegraf.Processor {
		return proc
	})

	c := "./testdata/processor.conf"
	_, err := LoadConfig(&c)
	require.NoError(t, err)

	require.EqualValues(t, "yep", proc.Loaded)
}
|
||||
|
||||
// testDurationInput is a minimal input used to exercise decoding of the
// special config types Duration, Size and hexadecimal integers.
type testDurationInput struct {
	Duration cfg.Duration `toml:"duration"`
	Size     cfg.Size     `toml:"size"`
	Hex      int64        `toml:"hex"`
}

// SampleConfig satisfies the telegraf.Input interface; no sample needed.
func (*testDurationInput) SampleConfig() string {
	return ""
}

// Gather satisfies the telegraf.Input interface; the test never gathers.
func (*testDurationInput) Gather(telegraf.Accumulator) error {
	return nil
}

// testConfigProcessor is a minimal processor used to verify that its TOML
// section is decoded into the instance.
type testConfigProcessor struct {
	Loaded string `toml:"loaded"`
}

// SampleConfig satisfies the telegraf.Processor interface; no sample needed.
func (*testConfigProcessor) SampleConfig() string {
	return ""
}

// Apply passes metrics through unchanged.
func (*testConfigProcessor) Apply(metrics ...telegraf.Metric) []telegraf.Metric {
	return metrics
}
|
64
plugins/common/shim/example/cmd/main.go
Normal file
64
plugins/common/shim/example/cmd/main.go
Normal file
|
@ -0,0 +1,64 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"fmt"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
// TODO: import your plugins
|
||||
_ "github.com/influxdata/tail" // Example external package for showing where you can import your plugins
|
||||
|
||||
"github.com/influxdata/telegraf/plugins/common/shim"
|
||||
)
|
||||
|
||||
// pollInterval controls how often the shim triggers a Gather of the wrapped
// input plugin.
var pollInterval = flag.Duration("poll_interval", 1*time.Second, "how often to send metrics")

// pollIntervalDisabled turns off timed polling for service-style plugins
// that push metrics on their own schedule.
var pollIntervalDisabled = flag.Bool(
	"poll_interval_disabled",
	false,
	"set to true to disable polling. You want to use this when you are sending metrics on your own schedule",
)

// configFile optionally points at a TOML file configuring the single plugin.
var configFile = flag.String("config", "", "path to the config file for this plugin")

// NOTE(review): package-level error variable assigned by main; a locally
// scoped error would avoid mutable package state.
var err error

// This is designed to be simple; Just change the import above, and you're good.
//
// However, if you want to do all your config in code, you can like so:
//
//	// initialize your plugin with any settings you want
//	myInput := &mypluginname.MyPlugin{
//		DefaultSettingHere: 3,
//	}
//
//	shim := shim.New()
//
//	shim.AddInput(myInput)
//
//	// now the shim.Run() call as below. Note the shim is only intended to run a single plugin.
|
||||
func main() {
|
||||
// parse command line options
|
||||
flag.Parse()
|
||||
if *pollIntervalDisabled {
|
||||
*pollInterval = shim.PollIntervalDisabled
|
||||
}
|
||||
|
||||
// create the shim. This is what will run your plugins.
|
||||
shimLayer := shim.New()
|
||||
|
||||
// If no config is specified, all imported plugins are loaded.
|
||||
// otherwise, follow what the config asks for.
|
||||
// Check for settings from a config toml file,
|
||||
// (or just use whatever plugins were imported above)
|
||||
if err = shimLayer.LoadConfig(configFile); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Err loading input: %s\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
// run a single plugin until stdin closes, or we receive a termination signal
|
||||
if err = shimLayer.Run(*pollInterval); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Err: %s\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
2
plugins/common/shim/example/cmd/plugin.conf
Normal file
2
plugins/common/shim/example/cmd/plugin.conf
Normal file
|
@ -0,0 +1,2 @@
|
|||
[[inputs.my_plugin_name]]
|
||||
value_name = "value"
|
145
plugins/common/shim/goshim.go
Normal file
145
plugins/common/shim/goshim.go
Normal file
|
@ -0,0 +1,145 @@
|
|||
package shim
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"os/signal"
|
||||
"strings"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/influxdata/telegraf"
|
||||
"github.com/influxdata/telegraf/logger"
|
||||
"github.com/influxdata/telegraf/plugins/serializers/influx"
|
||||
)
|
||||
|
||||
// empty is a zero-size token used for pure signaling channels.
type empty struct{}

var (
	// forever approximates an unbounded duration for "wait indefinitely"
	// timers (100 years).
	forever = 100 * 365 * 24 * time.Hour
	// envVarEscaper escapes quotes and backslashes in environment-variable
	// values so they can be embedded into TOML strings.
	envVarEscaper = strings.NewReplacer(
		`"`, `\"`,
		`\`, `\\`,
	)
)

const (
	// PollIntervalDisabled is used to indicate that you want to disable polling,
	// as opposed to duration 0 meaning poll constantly.
	PollIntervalDisabled = time.Duration(0)
)
|
||||
|
||||
// Shim allows you to wrap your inputs and run them as if they were part of Telegraf,
|
||||
// except built externally.
|
||||
type Shim struct {
|
||||
Input telegraf.Input
|
||||
Processor telegraf.StreamingProcessor
|
||||
Output telegraf.Output
|
||||
|
||||
log telegraf.Logger
|
||||
|
||||
// streams
|
||||
stdin io.Reader
|
||||
stdout io.Writer
|
||||
stderr io.Writer
|
||||
|
||||
// outgoing metric channel
|
||||
metricCh chan telegraf.Metric
|
||||
|
||||
// input only
|
||||
gatherPromptCh chan empty
|
||||
}
|
||||
|
||||
// New creates a new shim interface
|
||||
func New() *Shim {
|
||||
return &Shim{
|
||||
metricCh: make(chan telegraf.Metric, 1),
|
||||
stdin: os.Stdin,
|
||||
stdout: os.Stdout,
|
||||
stderr: os.Stderr,
|
||||
log: logger.New("", "", ""),
|
||||
}
|
||||
}
|
||||
|
||||
func (*Shim) watchForShutdown(cancel context.CancelFunc) {
|
||||
quit := make(chan os.Signal, 1)
|
||||
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
|
||||
go func() {
|
||||
<-quit // user-triggered quit
|
||||
// cancel, but keep looping until the metric channel closes.
|
||||
cancel()
|
||||
}()
|
||||
}
|
||||
|
||||
// Run the input plugins..
|
||||
func (s *Shim) Run(pollInterval time.Duration) error {
|
||||
if s.Input != nil {
|
||||
err := s.RunInput(pollInterval)
|
||||
if err != nil {
|
||||
return fmt.Errorf("running input failed: %w", err)
|
||||
}
|
||||
} else if s.Processor != nil {
|
||||
err := s.RunProcessor()
|
||||
if err != nil {
|
||||
return fmt.Errorf("running processor failed: %w", err)
|
||||
}
|
||||
} else if s.Output != nil {
|
||||
err := s.RunOutput()
|
||||
if err != nil {
|
||||
return fmt.Errorf("running output failed: %w", err)
|
||||
}
|
||||
} else {
|
||||
return errors.New("nothing to run")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func hasQuit(ctx context.Context) bool {
|
||||
return ctx.Err() != nil
|
||||
}
|
||||
|
||||
func (s *Shim) writeProcessedMetrics() error {
|
||||
serializer := &influx.Serializer{}
|
||||
if err := serializer.Init(); err != nil {
|
||||
return fmt.Errorf("creating serializer failed: %w", err)
|
||||
}
|
||||
for { //nolint:staticcheck // for-select used on purpose
|
||||
select {
|
||||
case m, open := <-s.metricCh:
|
||||
if !open {
|
||||
return nil
|
||||
}
|
||||
b, err := serializer.Serialize(m)
|
||||
if err != nil {
|
||||
m.Reject()
|
||||
return fmt.Errorf("failed to serialize metric: %w", err)
|
||||
}
|
||||
// Write this to stdout
|
||||
_, err = fmt.Fprint(s.stdout, string(b))
|
||||
if err != nil {
|
||||
m.Drop()
|
||||
return fmt.Errorf("failed to write metric: %w", err)
|
||||
}
|
||||
m.Accept()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// LogName satisfies the MetricMaker interface
|
||||
func (*Shim) LogName() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
// MakeMetric satisfies the MetricMaker interface
|
||||
func (*Shim) MakeMetric(m telegraf.Metric) telegraf.Metric {
|
||||
return m // don't need to do anything to it.
|
||||
}
|
||||
|
||||
// Log satisfies the MetricMaker interface
|
||||
func (s *Shim) Log() telegraf.Logger {
|
||||
return s.log
|
||||
}
|
78
plugins/common/shim/goshim_test.go
Normal file
78
plugins/common/shim/goshim_test.go
Normal file
|
@ -0,0 +1,78 @@
|
|||
package shim

import (
	"bufio"
	"errors"
	"io"
	"testing"
	"time"

	"github.com/stretchr/testify/require"

	"github.com/influxdata/telegraf"
	"github.com/influxdata/telegraf/logger"
)

// TestShimSetsUpLogger verifies that plugin errors reported via AddError are
// routed through the shim's logger onto the stderr stream.
func TestShimSetsUpLogger(t *testing.T) {
	stderrReader, stderrWriter := io.Pipe()
	stdinReader, stdinWriter := io.Pipe()

	runErroringInputPlugin(t, 40*time.Second, stdinReader, nil, stderrWriter)

	_, err := stdinWriter.Write([]byte("\n"))
	require.NoError(t, err)

	reader := bufio.NewReader(stderrReader)
	line, err := reader.ReadString('\n')
	require.NoError(t, err)
	require.Contains(t, line, "Error in plugin: intentional")

	require.NoError(t, stdinWriter.Close())
}

// runErroringInputPlugin wires an erroringInput into a fresh shim using the
// given streams and starts it on a background goroutine.
func runErroringInputPlugin(t *testing.T, interval time.Duration, stdin io.Reader, stdout, stderr io.Writer) (processed, exited chan bool) {
	processed = make(chan bool, 1)
	exited = make(chan bool, 1)

	s := New()
	if stdin != nil {
		s.stdin = stdin
	}
	if stdout != nil {
		s.stdout = stdout
	}
	if stderr != nil {
		s.stderr = stderr
		logger.RedirectLogging(stderr)
	}

	require.NoError(t, s.AddInput(&erroringInput{}))
	go func(done chan bool) {
		if err := s.Run(interval); err != nil {
			t.Error(err)
		}
		done <- true
	}(exited)
	return processed, exited
}

// erroringInput is a minimal service input whose Gather always reports an error.
type erroringInput struct{}

func (*erroringInput) SampleConfig() string { return "" }

func (*erroringInput) Gather(acc telegraf.Accumulator) error {
	acc.AddError(errors.New("intentional"))
	return nil
}

func (*erroringInput) Start(telegraf.Accumulator) error { return nil }

func (*erroringInput) Stop() {}
|
116
plugins/common/shim/input.go
Normal file
116
plugins/common/shim/input.go
Normal file
|
@ -0,0 +1,116 @@
|
|||
package shim

import (
	"bufio"
	"context"
	"fmt"
	"sync"
	"time"

	"github.com/influxdata/telegraf"
	"github.com/influxdata/telegraf/agent"
	"github.com/influxdata/telegraf/models"
)

// AddInput adds the input to the shim. Later calls to Run() will run this input.
func (s *Shim) AddInput(input telegraf.Input) error {
	models.SetLoggerOnPlugin(input, s.Log())
	if p, ok := input.(telegraf.Initializer); ok {
		if err := p.Init(); err != nil {
			return fmt.Errorf("failed to init input: %w", err)
		}
	}

	s.Input = input
	return nil
}

// RunInput runs the registered input, gathering on every poll interval tick
// and on every line received on stdin, until stdin closes or a termination
// signal is received. Serialized metrics are written to stdout.
func (s *Shim) RunInput(pollInterval time.Duration) error {
	// context is used only to close the stdin reader. everything else cascades
	// from that point and closes cleanly when it's done.
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	s.watchForShutdown(cancel)

	acc := agent.NewAccumulator(s, s.metricCh)
	acc.SetPrecision(time.Nanosecond)

	if serviceInput, ok := s.Input.(telegraf.ServiceInput); ok {
		if err := serviceInput.Start(acc); err != nil {
			return fmt.Errorf("failed to start input: %w", err)
		}
	}
	s.gatherPromptCh = make(chan empty, 1)
	go func() {
		s.startGathering(ctx, s.Input, acc, pollInterval)
		if serviceInput, ok := s.Input.(telegraf.ServiceInput); ok {
			serviceInput.Stop()
		}
		// closing the metric channel gracefully stops writing to stdout
		close(s.metricCh)
	}()

	wg := sync.WaitGroup{}
	wg.Add(1)
	go func() {
		defer wg.Done()
		if err := s.writeProcessedMetrics(); err != nil {
			s.log.Warn(err.Error())
		}
	}()

	go func() {
		scanner := bufio.NewScanner(s.stdin)
		for scanner.Scan() {
			// push a non-blocking message to trigger metric collection.
			s.pushCollectMetricsRequest()
		}

		cancel() // cancel gracefully stops gathering
	}()

	wg.Wait() // wait for writing to stdout to finish
	return nil
}

// startGathering collects metrics from the input on every tick of the poll
// interval and on every external prompt, until the context is canceled.
func (s *Shim) startGathering(ctx context.Context, input telegraf.Input, acc telegraf.Accumulator, pollInterval time.Duration) {
	if pollInterval == PollIntervalDisabled {
		pollInterval = forever
	}
	t := time.NewTicker(pollInterval)
	defer t.Stop()

	// gather performs one collection cycle and reports failures on stderr.
	// Extracted to avoid duplicating the identical prompt/tick branches.
	gather := func() {
		if err := input.Gather(acc); err != nil {
			fmt.Fprintf(s.stderr, "failed to gather metrics: %s\n", err)
		}
	}

	for {
		// give priority to stopping.
		if hasQuit(ctx) {
			return
		}
		// see what's up
		select {
		case <-ctx.Done():
			return
		case <-s.gatherPromptCh:
			gather()
		case <-t.C:
			gather()
		}
	}
}

// pushCollectMetricsRequest pushes a non-blocking (nil) message to the
// gatherPromptCh channel to trigger metric collection.
// The channel is defined with a buffer of 1, so while it's full, subsequent
// requests are discarded.
func (s *Shim) pushCollectMetricsRequest() {
	// push a message out to each channel to collect metrics. don't block.
	select {
	case s.gatherPromptCh <- empty{}:
	default:
	}
}
|
140
plugins/common/shim/input_test.go
Normal file
140
plugins/common/shim/input_test.go
Normal file
|
@ -0,0 +1,140 @@
|
|||
package shim

import (
	"bufio"
	"io"
	"strings"
	"testing"
	"time"

	"github.com/stretchr/testify/require"

	"github.com/influxdata/telegraf"
)

// TestInputShimTimer checks that metrics gathered on the poll interval are
// serialized to stdout as influx line protocol.
func TestInputShimTimer(t *testing.T) {
	stdoutReader, stdoutWriter := io.Pipe()

	stdin, _ := io.Pipe() // hold the stdin pipe open

	metricProcessed, _ := runInputPlugin(t, 10*time.Millisecond, stdin, stdoutWriter, nil)

	<-metricProcessed
	reader := bufio.NewReader(stdoutReader)
	line, err := reader.ReadString('\n')
	require.NoError(t, err)
	require.Contains(t, line, "\n")
	firstLine := strings.Split(line, "\n")[0]
	require.Equal(t, "measurement,tag=tag field=1i 1234000005678", firstLine)
}

// TestInputShimStdinSignalingWorks checks that a newline on stdin triggers a
// gather cycle and that closing stdin shuts the shim down cleanly.
func TestInputShimStdinSignalingWorks(t *testing.T) {
	stdinReader, stdinWriter := io.Pipe()
	stdoutReader, stdoutWriter := io.Pipe()

	metricProcessed, exited := runInputPlugin(t, 40*time.Second, stdinReader, stdoutWriter, nil)

	_, err := stdinWriter.Write([]byte("\n"))
	require.NoError(t, err)

	<-metricProcessed

	reader := bufio.NewReader(stdoutReader)
	line, err := reader.ReadString('\n')
	require.NoError(t, err)
	require.Equal(t, "measurement,tag=tag field=1i 1234000005678\n", line)

	require.NoError(t, stdinWriter.Close())
	go func() {
		if _, err := io.ReadAll(reader); err != nil {
			t.Error(err)
		}
	}()
	// check that it exits cleanly
	<-exited
}

// runInputPlugin wires a testInput into a fresh shim using the given streams
// and starts it on a background goroutine.
func runInputPlugin(t *testing.T, interval time.Duration, stdin io.Reader, stdout, stderr io.Writer) (processed, exited chan bool) {
	processed = make(chan bool, 1)
	exited = make(chan bool, 1)
	input := &testInput{metricProcessed: processed}

	s := New()
	if stdin != nil {
		s.stdin = stdin
	}
	if stdout != nil {
		s.stdout = stdout
	}
	if stderr != nil {
		s.stderr = stderr
	}
	require.NoError(t, s.AddInput(input))
	go func(done chan bool) {
		if err := s.Run(interval); err != nil {
			t.Error(err)
		}
		done <- true
	}(exited)
	return processed, exited
}

// testInput emits one fixed metric per Gather and signals each collection.
type testInput struct {
	metricProcessed chan bool
}

func (*testInput) SampleConfig() string { return "" }

func (i *testInput) Gather(acc telegraf.Accumulator) error {
	fields := map[string]interface{}{"field": 1}
	tags := map[string]string{"tag": "tag"}
	acc.AddFields("measurement", fields, tags, time.Unix(1234, 5678))
	i.metricProcessed <- true
	return nil
}

func (*testInput) Start(telegraf.Accumulator) error { return nil }

func (*testInput) Stop() {}

// serviceInput exercises config loading of custom fields in other tests.
type serviceInput struct {
	ServiceName string `toml:"service_name"`
	SecretToken string `toml:"secret_token"`
	SecretValue string `toml:"secret_value"`
}

func (*serviceInput) SampleConfig() string { return "" }

func (*serviceInput) Gather(acc telegraf.Accumulator) error {
	fields := map[string]interface{}{"field": 1}
	tags := map[string]string{"tag": "tag"}
	acc.AddFields("measurement", fields, tags, time.Unix(1234, 5678))
	return nil
}

func (*serviceInput) Start(telegraf.Accumulator) error { return nil }

func (*serviceInput) Stop() {}
|
54
plugins/common/shim/output.go
Normal file
54
plugins/common/shim/output.go
Normal file
|
@ -0,0 +1,54 @@
|
|||
package shim

import (
	"bufio"
	"fmt"

	"github.com/influxdata/telegraf"
	"github.com/influxdata/telegraf/models"
	"github.com/influxdata/telegraf/plugins/parsers/influx"
)

// AddOutput adds the output to the shim. Later calls to Run() will run this.
func (s *Shim) AddOutput(output telegraf.Output) error {
	models.SetLoggerOnPlugin(output, s.Log())
	if p, ok := output.(telegraf.Initializer); ok {
		if err := p.Init(); err != nil {
			// BUG FIX: the original message said "input" for an output plugin.
			return fmt.Errorf("failed to init output: %w", err)
		}
	}

	s.Output = output
	return nil
}

// RunOutput reads influx line-protocol metrics from stdin and writes each one
// to the registered output until stdin closes. Parse and write failures are
// reported on stderr without stopping the loop.
func (s *Shim) RunOutput() error {
	parser := influx.Parser{}
	if err := parser.Init(); err != nil {
		return fmt.Errorf("failed to create new parser: %w", err)
	}

	if err := s.Output.Connect(); err != nil {
		// BUG FIX: the original message said "failed to start processor".
		return fmt.Errorf("failed to connect output: %w", err)
	}
	defer s.Output.Close()

	scanner := bufio.NewScanner(s.stdin)
	for scanner.Scan() {
		m, err := parser.ParseLine(scanner.Text())
		if err != nil {
			fmt.Fprintf(s.stderr, "Failed to parse metric: %s\n", err)
			continue
		}
		if err := s.Output.Write([]telegraf.Metric{m}); err != nil {
			fmt.Fprintf(s.stderr, "Failed to write metric: %s\n", err)
		}
	}

	return nil
}
|
81
plugins/common/shim/output_test.go
Normal file
81
plugins/common/shim/output_test.go
Normal file
|
@ -0,0 +1,81 @@
|
|||
package shim

import (
	"io"
	"sync"
	"testing"
	"time"

	"github.com/stretchr/testify/require"

	"github.com/influxdata/telegraf"
	"github.com/influxdata/telegraf/metric"
	"github.com/influxdata/telegraf/plugins/serializers/influx"
	"github.com/influxdata/telegraf/testutil"
)

// TestOutputShim round-trips a metric: serialized to line protocol, fed to
// the shim's stdin, and checked against what the output plugin received.
func TestOutputShim(t *testing.T) {
	out := &testOutput{}

	stdinReader, stdinWriter := io.Pipe()

	s := New()
	s.stdin = stdinReader
	require.NoError(t, s.AddOutput(out))

	wg := sync.WaitGroup{}
	wg.Add(1)
	go func() {
		defer wg.Done()
		if err := s.RunOutput(); err != nil {
			t.Error(err)
		}
	}()

	serializer := &influx.Serializer{}
	require.NoError(t, serializer.Init())

	tags := map[string]string{"a": "b"}
	fields := map[string]interface{}{"v": 1}
	m := metric.New("thing", tags, fields, time.Now())
	payload, err := serializer.Serialize(m)
	require.NoError(t, err)
	_, err = stdinWriter.Write(payload)
	require.NoError(t, err)
	require.NoError(t, stdinWriter.Close())

	wg.Wait()

	require.Len(t, out.MetricsWritten, 1)
	testutil.RequireMetricEqual(t, m, out.MetricsWritten[0])
}

// testOutput records every metric written to it.
type testOutput struct {
	MetricsWritten []telegraf.Metric
}

func (*testOutput) Connect() error { return nil }

func (*testOutput) Close() error { return nil }

func (o *testOutput) Write(metrics []telegraf.Metric) error {
	o.MetricsWritten = append(o.MetricsWritten, metrics...)
	return nil
}

func (*testOutput) SampleConfig() string { return "" }
|
81
plugins/common/shim/processor.go
Normal file
81
plugins/common/shim/processor.go
Normal file
|
@ -0,0 +1,81 @@
|
|||
package shim

import (
	"errors"
	"fmt"
	"sync"
	"time"

	"github.com/influxdata/telegraf"
	"github.com/influxdata/telegraf/agent"
	"github.com/influxdata/telegraf/models"
	"github.com/influxdata/telegraf/plugins/parsers/influx"
	"github.com/influxdata/telegraf/plugins/processors"
)

// AddProcessor adds the processor to the shim. Later calls to Run() will run this.
func (s *Shim) AddProcessor(processor telegraf.Processor) error {
	models.SetLoggerOnPlugin(processor, s.Log())
	p := processors.NewStreamingProcessorFromProcessor(processor)
	return s.AddStreamingProcessor(p)
}

// AddStreamingProcessor adds the processor to the shim. Later calls to Run() will run this.
func (s *Shim) AddStreamingProcessor(processor telegraf.StreamingProcessor) error {
	models.SetLoggerOnPlugin(processor, s.Log())
	if p, ok := processor.(telegraf.Initializer); ok {
		if err := p.Init(); err != nil {
			// BUG FIX: the original message said "input" for a processor.
			return fmt.Errorf("failed to init processor: %w", err)
		}
	}

	s.Processor = processor
	return nil
}

// RunProcessor streams influx line-protocol metrics from stdin through the
// registered processor, writing processed metrics to stdout, until stdin ends.
func (s *Shim) RunProcessor() error {
	acc := agent.NewAccumulator(s, s.metricCh)
	acc.SetPrecision(time.Nanosecond)

	if err := s.Processor.Start(acc); err != nil {
		return fmt.Errorf("failed to start processor: %w", err)
	}

	wg := sync.WaitGroup{}
	wg.Add(1)
	go func() {
		defer wg.Done()
		if err := s.writeProcessedMetrics(); err != nil {
			s.log.Warn(err.Error())
		}
	}()

	parser := influx.NewStreamParser(s.stdin)
	for {
		m, err := parser.Next()
		if err != nil {
			if errors.Is(err, influx.EOF) {
				break // stream ended
			}
			var parseErr *influx.ParseError
			if errors.As(err, &parseErr) {
				// BUG FIX: the format string ended in "\b" (backspace)
				// instead of a newline; same for the two messages below.
				fmt.Fprintf(s.stderr, "Failed to parse metric: %s\n", parseErr)
				continue
			}
			fmt.Fprintf(s.stderr, "Failure during reading stdin: %s\n", err)
			continue
		}

		if err := s.Processor.Add(m, acc); err != nil {
			fmt.Fprintf(s.stderr, "Failure during processing metric by processor: %v\n", err)
		}
	}

	// Closing the channel stops the stdout writer; then stop the processor
	// and wait for the writer goroutine to drain.
	close(s.metricCh)
	s.Processor.Stop()
	wg.Wait()
	return nil
}
|
113
plugins/common/shim/processor_test.go
Normal file
113
plugins/common/shim/processor_test.go
Normal file
|
@ -0,0 +1,113 @@
|
|||
package shim

import (
	"bufio"
	"io"
	"math/rand"
	"sync"
	"testing"
	"time"

	"github.com/stretchr/testify/require"

	"github.com/influxdata/telegraf"
	"github.com/influxdata/telegraf/metric"
	"github.com/influxdata/telegraf/plugins/parsers/influx"
	serializers_influx "github.com/influxdata/telegraf/plugins/serializers/influx"
)

// TestProcessorShim round-trips one small metric through the processor shim.
func TestProcessorShim(t *testing.T) {
	testSendAndReceive(t, "f1", "fv1")
}

// TestProcessorShimWithLargerThanDefaultScannerBufferSize verifies metrics
// larger than bufio's default token size survive the round trip.
func TestProcessorShimWithLargerThanDefaultScannerBufferSize(t *testing.T) {
	const alphabet = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
	letters := []rune(alphabet)
	payload := make([]rune, 0, bufio.MaxScanTokenSize*2)
	for i := 0; i < bufio.MaxScanTokenSize*2; i++ {
		payload = append(payload, letters[rand.Intn(len(letters))])
	}

	testSendAndReceive(t, "f1", string(payload))
}

// testSendAndReceive pushes one metric through the processor shim and checks
// that the processor's tag was added and the field survived unchanged.
func testSendAndReceive(t *testing.T, fieldKey, fieldValue string) {
	proc := &testProcessor{"hi", "mom"}

	stdinReader, stdinWriter := io.Pipe()
	stdoutReader, stdoutWriter := io.Pipe()

	s := New()
	// inject test into shim
	s.stdin = stdinReader
	s.stdout = stdoutWriter
	require.NoError(t, s.AddProcessor(proc))

	wg := sync.WaitGroup{}
	wg.Add(1)
	go func() {
		defer wg.Done()
		if err := s.RunProcessor(); err != nil {
			t.Error(err)
		}
	}()

	serializer := &serializers_influx.Serializer{}
	require.NoError(t, serializer.Init())

	parser := influx.Parser{}
	require.NoError(t, parser.Init())

	tags := map[string]string{"a": "b"}
	fields := map[string]interface{}{
		"v":      1,
		fieldKey: fieldValue,
	}
	m := metric.New("thing", tags, fields, time.Now())
	serialized, err := serializer.Serialize(m)
	require.NoError(t, err)
	_, err = stdinWriter.Write(serialized)
	require.NoError(t, err)
	require.NoError(t, stdinWriter.Close())

	reader := bufio.NewReader(stdoutReader)
	line, err := reader.ReadString('\n')
	require.NoError(t, err)
	mOut, err := parser.ParseLine(line)
	require.NoError(t, err)

	tagValue, ok := mOut.GetTag(proc.tagName)
	require.True(t, ok)
	require.Equal(t, proc.tagValue, tagValue)
	gotField, ok := mOut.Fields()[fieldKey]
	require.True(t, ok)
	require.Equal(t, fieldValue, gotField)
	go func() {
		if _, err := io.ReadAll(reader); err != nil {
			t.Error(err)
		}
	}()
	wg.Wait()
}

// testProcessor adds a fixed tag to every metric it sees.
type testProcessor struct {
	tagName  string
	tagValue string
}

func (p *testProcessor) Apply(in ...telegraf.Metric) []telegraf.Metric {
	for _, m := range in {
		m.AddTag(p.tagName, p.tagValue)
	}
	return in
}

func (*testProcessor) SampleConfig() string { return "" }
|
4
plugins/common/shim/testdata/plugin.conf
vendored
Normal file
4
plugins/common/shim/testdata/plugin.conf
vendored
Normal file
|
@ -0,0 +1,4 @@
|
|||
[[inputs.test]]
|
||||
service_name = "awesome name"
|
||||
secret_token = "${SECRET_TOKEN}"
|
||||
secret_value = "$SECRET_VALUE"
|
2
plugins/common/shim/testdata/processor.conf
vendored
Normal file
2
plugins/common/shim/testdata/processor.conf
vendored
Normal file
|
@ -0,0 +1,2 @@
|
|||
[[processors.test_config_load]]
|
||||
loaded = "yep"
|
5
plugins/common/shim/testdata/special.conf
vendored
Normal file
5
plugins/common/shim/testdata/special.conf
vendored
Normal file
|
@ -0,0 +1,5 @@
|
|||
# testing custom field types
|
||||
[[inputs.test]]
|
||||
duration = "3s"
|
||||
size = "3MB"
|
||||
hex = 0x34
|
267
plugins/common/socket/datagram.go
Normal file
267
plugins/common/socket/datagram.go
Normal file
|
@ -0,0 +1,267 @@
|
|||
package socket

import (
	"errors"
	"fmt"
	"io"
	"net"
	"net/url"
	"os"
	"path/filepath"
	"runtime"
	"strconv"
	"strings"
	"sync"
	"time"

	"github.com/alitto/pond"

	"github.com/influxdata/telegraf"
	"github.com/influxdata/telegraf/config"
	"github.com/influxdata/telegraf/internal"
)

// packetListener receives datagram-style traffic (UDP, unixgram, raw IP),
// decodes each packet through a pooled content decoder and hands the result
// to a worker pool for parsing.
type packetListener struct {
	Encoding             string
	MaxDecompressionSize int64
	SocketMode           string
	ReadBufferSize       int
	Log                  telegraf.Logger

	conn      net.PacketConn
	decoders  sync.Pool
	path      string // filesystem path for unixgram sockets; empty otherwise
	wg        sync.WaitGroup
	parsePool *pond.WorkerPool
}

// newPacketListener creates a listener with a bounded parsing worker pool.
func newPacketListener(encoding string, maxDecompressionSize config.Size, maxWorkers int) *packetListener {
	return &packetListener{
		Encoding:             encoding,
		MaxDecompressionSize: int64(maxDecompressionSize),
		parsePool:            pond.New(maxWorkers, 0, pond.MinWorkers(maxWorkers/2+1)),
	}
}

// listenData reads packets until the connection closes, decoding each one on
// the worker pool and delivering it via onData.
func (l *packetListener) listenData(onData CallbackData, onError CallbackError) {
	l.wg.Add(1)

	go func() {
		defer l.wg.Done()

		buf := make([]byte, l.ReadBufferSize)
		for {
			n, src, err := l.conn.ReadFrom(buf)
			receiveTime := time.Now()
			if err != nil {
				// "use of closed network connection" is the normal shutdown
				// path; suppress it.
				if !strings.HasSuffix(err.Error(), ": use of closed network connection") {
					if onError != nil {
						onError(err)
					}
				}
				break
			}

			// Copy the payload out of the shared read buffer before handing
			// it to the asynchronous worker.
			d := make([]byte, n)
			copy(d, buf[:n])
			l.parsePool.Submit(func() {
				decoder := l.decoders.Get().(internal.ContentDecoder)
				// Not possible to return the decoder to the pool immediately
				// after Decode, because some decoders return a reference to
				// their internal buffers; that would cause data races.
				defer l.decoders.Put(decoder)
				body, err := decoder.Decode(d)
				if err != nil {
					if onError != nil {
						onError(fmt.Errorf("unable to decode incoming packet: %w", err))
					}
					// BUG FIX: skip the packet on decode failure instead of
					// forwarding an undecoded/nil body to the callback.
					return
				}

				// Workaround to provide remote endpoints for Unix-type sockets
				if l.path != "" {
					src = &net.UnixAddr{Name: l.path, Net: "unixgram"}
				}

				onData(src, body, receiveTime)
			})
		}
	}()
}

// listenConnection reads packets until the connection closes, exposing each
// decoded payload to the caller as a streamed connection via onConnection.
func (l *packetListener) listenConnection(onConnection CallbackConnection, onError CallbackError) {
	l.wg.Add(1)
	go func() {
		defer l.wg.Done()
		defer l.conn.Close()

		buf := make([]byte, l.ReadBufferSize)
		for {
			// Wait for packets and read them
			n, src, err := l.conn.ReadFrom(buf)
			if err != nil {
				if !strings.HasSuffix(err.Error(), ": use of closed network connection") {
					if onError != nil {
						onError(err)
					}
				}
				break
			}

			d := make([]byte, n)
			copy(d, buf[:n])
			l.parsePool.Submit(func() {
				// Decode the contents depending on the given encoding
				decoder := l.decoders.Get().(internal.ContentDecoder)
				// Not possible to immediately return the decoder to the Pool after calling Decode, because some
				// decoders return a reference to their internal buffers. This would cause data races.
				defer l.decoders.Put(decoder)
				// NOTE: d already has length n, so no re-slice is needed.
				body, err := decoder.Decode(d)
				if err != nil {
					if onError != nil {
						onError(fmt.Errorf("unable to decode incoming packet: %w", err))
					}
					// BUG FIX: skip the packet on decode failure instead of
					// streaming an undecoded/nil body to the consumer.
					return
				}

				// Workaround to provide remote endpoints for Unix-type sockets
				if l.path != "" {
					src = &net.UnixAddr{Name: l.path, Net: "unixgram"}
				}

				// Create a pipe and notify the caller via Callback that new data is
				// available. Afterwards write the data. Please note: Write() will
				// block until all data is consumed!
				reader, writer := io.Pipe()
				go onConnection(src, reader)
				if _, err := writer.Write(body); err != nil && onError != nil {
					onError(err)
				}
				writer.Close()
			})
		}
	}()
}

// setupUnixgram binds a unixgram socket at the URL's path, applying the given
// socket permissions and read-buffer size.
func (l *packetListener) setupUnixgram(u *url.URL, socketMode string, bufferSize int) error {
	l.path = filepath.FromSlash(u.Path)
	// Windows: strip the leading backslash that FromSlash leaves in front of
	// drive-letter paths like `\C:\...`.
	if runtime.GOOS == "windows" && strings.Contains(l.path, ":") {
		l.path = strings.TrimPrefix(l.path, `\`)
	}
	if err := os.Remove(l.path); err != nil && !errors.Is(err, os.ErrNotExist) {
		return fmt.Errorf("removing socket failed: %w", err)
	}

	conn, err := net.ListenPacket(u.Scheme, l.path)
	if err != nil {
		return fmt.Errorf("listening (unixgram) failed: %w", err)
	}
	l.conn = conn

	// Set permissions on socket
	if socketMode != "" {
		// Convert from octal in string to int
		i, err := strconv.ParseUint(socketMode, 8, 32)
		if err != nil {
			return fmt.Errorf("converting socket mode failed: %w", err)
		}

		perm := os.FileMode(uint32(i))
		// BUG FIX: chmod the normalized l.path (used for listening) rather
		// than the raw u.Path, which differs on Windows.
		if err := os.Chmod(l.path, perm); err != nil {
			return fmt.Errorf("changing socket permissions failed: %w", err)
		}
	}

	if bufferSize > 0 {
		l.ReadBufferSize = bufferSize
	} else {
		l.ReadBufferSize = 64 * 1024 // 64kb - IP packet size
	}

	return l.setupDecoder()
}

// setupUDP binds a UDP (or multicast UDP) socket for the given URL, optionally
// joining the multicast group on the named interface.
func (l *packetListener) setupUDP(u *url.URL, ifname string, bufferSize int) error {
	var conn *net.UDPConn

	addr, err := net.ResolveUDPAddr(u.Scheme, u.Host)
	if err != nil {
		return fmt.Errorf("resolving UDP address failed: %w", err)
	}
	if addr.IP.IsMulticast() {
		var iface *net.Interface
		if ifname != "" {
			var err error
			iface, err = net.InterfaceByName(ifname)
			if err != nil {
				return fmt.Errorf("resolving address of %q failed: %w", ifname, err)
			}
		}
		conn, err = net.ListenMulticastUDP(u.Scheme, iface, addr)
		if err != nil {
			return fmt.Errorf("listening (udp multicast) failed: %w", err)
		}
	} else {
		conn, err = net.ListenUDP(u.Scheme, addr)
		if err != nil {
			return fmt.Errorf("listening (udp) failed: %w", err)
		}
	}

	if bufferSize > 0 {
		if err := conn.SetReadBuffer(bufferSize); err != nil {
			l.Log.Warnf("Setting read buffer on %s socket failed: %v", u.Scheme, err)
		}
	}

	l.ReadBufferSize = 64 * 1024 // 64kb - IP packet size
	l.conn = conn
	return l.setupDecoder()
}

// setupIP binds a raw IP packet socket for the given URL.
func (l *packetListener) setupIP(u *url.URL) error {
	conn, err := net.ListenPacket(u.Scheme, u.Host)
	if err != nil {
		return fmt.Errorf("listening (ip) failed: %w", err)
	}

	l.ReadBufferSize = 64 * 1024 // 64kb - IP packet size
	l.conn = conn
	return l.setupDecoder()
}

// setupDecoder initializes the pool of content decoders for the configured
// encoding, honoring the maximum decompression size if set.
func (l *packetListener) setupDecoder() error {
	// Create a decoder for the given encoding
	var options []internal.DecodingOption
	if l.MaxDecompressionSize > 0 {
		options = append(options, internal.WithMaxDecompressionSize(l.MaxDecompressionSize))
	}

	l.decoders = sync.Pool{New: func() any {
		decoder, err := internal.NewContentDecoder(l.Encoding, options...)
		if err != nil {
			l.Log.Errorf("creating decoder failed: %v", err)
			return nil
		}

		return decoder
	}}

	return nil
}

// address returns the local address the listener is bound to.
func (l *packetListener) address() net.Addr {
	return l.conn.LocalAddr()
}

// close shuts down the connection, waits for the reader goroutine, removes a
// unixgram socket file if one was created and drains the worker pool.
func (l *packetListener) close() error {
	if err := l.conn.Close(); err != nil {
		return err
	}
	l.wg.Wait()

	if l.path != "" {
		fn := filepath.FromSlash(l.path)
		if runtime.GOOS == "windows" && strings.Contains(fn, ":") {
			fn = strings.TrimPrefix(fn, `\`)
		}
		// Ignore file-not-exists errors when removing the socket
		if err := os.Remove(fn); err != nil && !errors.Is(err, os.ErrNotExist) {
			return err
		}
	}

	l.parsePool.StopAndWait()

	return nil
}
|
37
plugins/common/socket/socket.conf
Normal file
37
plugins/common/socket/socket.conf
Normal file
|
@ -0,0 +1,37 @@
|
|||
## Permission for unix sockets (only available on unix sockets)
|
||||
## This setting may not be respected by some platforms. To safely restrict
|
||||
## permissions it is recommended to place the socket into a previously
|
||||
## created directory with the desired permissions.
|
||||
## ex: socket_mode = "777"
|
||||
# socket_mode = ""
|
||||
|
||||
## Maximum number of concurrent connections (only available on stream sockets like TCP)
|
||||
## Zero means unlimited.
|
||||
# max_connections = 0
|
||||
|
||||
## Read timeout (only available on stream sockets like TCP)
|
||||
## Zero means unlimited.
|
||||
# read_timeout = "0s"
|
||||
|
||||
## Optional TLS configuration (only available on stream sockets like TCP)
|
||||
# tls_cert = "/etc/telegraf/cert.pem"
|
||||
# tls_key = "/etc/telegraf/key.pem"
|
||||
## Enables client authentication if set.
|
||||
# tls_allowed_cacerts = ["/etc/telegraf/clientca.pem"]
|
||||
|
||||
## Maximum socket buffer size (in bytes when no unit specified)
|
||||
## For stream sockets, once the buffer fills up, the sender will start
|
||||
## backing up. For datagram sockets, once the buffer fills up, metrics will
|
||||
## start dropping. Defaults to the OS default.
|
||||
# read_buffer_size = "64KiB"
|
||||
|
||||
## Period between keep alive probes (only applies to TCP sockets)
|
||||
## Zero disables keep alive probes. Defaults to the OS configuration.
|
||||
# keep_alive_period = "5m"
|
||||
|
||||
## Content encoding for message payloads
|
||||
## Can be set to "gzip" for compressed payloads or "identity" for no encoding.
|
||||
# content_encoding = "identity"
|
||||
|
||||
## Maximum size of decoded packet (in bytes when no unit specified)
|
||||
# max_decompression_size = "500MB"
|
181
plugins/common/socket/socket.go
Normal file
181
plugins/common/socket/socket.go
Normal file
|
@ -0,0 +1,181 @@
|
|||
package socket
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/url"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/influxdata/telegraf"
|
||||
"github.com/influxdata/telegraf/config"
|
||||
common_tls "github.com/influxdata/telegraf/plugins/common/tls"
|
||||
)
|
||||
|
||||
// CallbackData is invoked for every complete message received, together
// with the sender address and the receive timestamp.
type CallbackData func(net.Addr, []byte, time.Time)

// CallbackConnection is invoked once per accepted stream connection,
// handing the connection's reader to the callback.
type CallbackConnection func(net.Addr, io.ReadCloser)

// CallbackError is invoked for asynchronous errors encountered while
// listening.
type CallbackError func(error)

// listener abstracts the stream- and packet-based listener
// implementations behind a common interface.
type listener interface {
	address() net.Addr
	listenData(CallbackData, CallbackError)
	listenConnection(CallbackConnection, CallbackError)
	close() error
}

// Config holds the socket settings shared by socket-based plugins.
// The per-option semantics below follow the sample configuration files.
type Config struct {
	MaxConnections       uint64           `toml:"max_connections"`   // stream sockets only; zero means unlimited
	ReadBufferSize       config.Size      `toml:"read_buffer_size"`  // OS socket buffer size; zero keeps the OS default
	ReadTimeout          config.Duration  `toml:"read_timeout"`      // stream sockets only; zero means unlimited
	KeepAlivePeriod      *config.Duration `toml:"keep_alive_period"` // TCP only; nil keeps the OS default, zero disables probes
	SocketMode           string           `toml:"socket_mode"`       // permissions for unix sockets, e.g. "777"
	ContentEncoding      string           `toml:"content_encoding"`  // "gzip" or "identity"
	MaxDecompressionSize config.Size      `toml:"max_decompression_size"`
	MaxParallelParsers   int              `toml:"max_parallel_parsers"`
	common_tls.ServerConfig
}

// Socket wraps a stream or packet listener created from a Config and an
// address; create it via Config.NewSocket and activate it with Setup.
type Socket struct {
	Config

	url           *url.URL
	interfaceName string // network interface extracted from a "%iface" address suffix
	tlsCfg        *tls.Config
	log           telegraf.Logger

	splitter bufio.SplitFunc
	listener listener
}
|
||||
|
||||
// NewSocket creates a Socket for the given address (e.g.
// "tcp://127.0.0.1:8094" or "unix:///var/run/telegraf.sock") with an
// optional message-splitting configuration for stream sockets. The
// address may carry a "%interface" suffix (e.g. an IPv6 zone) which is
// stripped and remembered. The returned Socket is not listening yet;
// call Setup followed by Listen or ListenConnection.
func (cfg *Config) NewSocket(address string, splitcfg *SplitConfig, logger telegraf.Logger) (*Socket, error) {
	s := &Socket{
		Config: *cfg,
		log:    logger,
	}

	// Setup the splitter if given
	if splitcfg != nil {
		splitter, err := splitcfg.NewSplitter()
		if err != nil {
			return nil, err
		}
		s.splitter = splitter
	}

	// Resolve the interface to an address if any given
	// NOTE(review): the regexp could be compiled once at package level;
	// this constructor is not hot, so left as-is.
	ifregex := regexp.MustCompile(`%([\w\.]+)`)
	if matches := ifregex.FindStringSubmatch(address); len(matches) == 2 {
		s.interfaceName = matches[1]
		address = strings.Replace(address, "%"+s.interfaceName, "", 1)
	}

	// Preparing TLS configuration
	tlsCfg, err := s.ServerConfig.TLSConfig()
	if err != nil {
		return nil, fmt.Errorf("getting TLS config failed: %w", err)
	}
	s.tlsCfg = tlsCfg

	// Parse and check the address
	u, err := url.Parse(address)
	if err != nil {
		return nil, fmt.Errorf("parsing address failed: %w", err)
	}
	s.url = u

	// Reject unsupported schemes early so Setup cannot be reached with one.
	switch s.url.Scheme {
	case "tcp", "tcp4", "tcp6", "unix", "unixpacket",
		"udp", "udp4", "udp6", "ip", "ip4", "ip6", "unixgram", "vsock":
	default:
		return nil, fmt.Errorf("unknown protocol %q in %q", u.Scheme, address)
	}

	return s, nil
}
|
||||
|
||||
func (s *Socket) Setup() error {
|
||||
s.MaxParallelParsers = max(s.MaxParallelParsers, 1)
|
||||
switch s.url.Scheme {
|
||||
case "tcp", "tcp4", "tcp6":
|
||||
l := newStreamListener(
|
||||
s.Config,
|
||||
s.splitter,
|
||||
s.log,
|
||||
)
|
||||
|
||||
if err := l.setupTCP(s.url, s.tlsCfg); err != nil {
|
||||
return err
|
||||
}
|
||||
s.listener = l
|
||||
case "unix", "unixpacket":
|
||||
l := newStreamListener(
|
||||
s.Config,
|
||||
s.splitter,
|
||||
s.log,
|
||||
)
|
||||
|
||||
if err := l.setupUnix(s.url, s.tlsCfg, s.SocketMode); err != nil {
|
||||
return err
|
||||
}
|
||||
s.listener = l
|
||||
case "udp", "udp4", "udp6":
|
||||
l := newPacketListener(s.ContentEncoding, s.MaxDecompressionSize, s.MaxParallelParsers)
|
||||
if err := l.setupUDP(s.url, s.interfaceName, int(s.ReadBufferSize)); err != nil {
|
||||
return err
|
||||
}
|
||||
s.listener = l
|
||||
case "ip", "ip4", "ip6":
|
||||
l := newPacketListener(s.ContentEncoding, s.MaxDecompressionSize, s.MaxParallelParsers)
|
||||
if err := l.setupIP(s.url); err != nil {
|
||||
return err
|
||||
}
|
||||
s.listener = l
|
||||
case "unixgram":
|
||||
l := newPacketListener(s.ContentEncoding, s.MaxDecompressionSize, s.MaxParallelParsers)
|
||||
if err := l.setupUnixgram(s.url, s.SocketMode, int(s.ReadBufferSize)); err != nil {
|
||||
return err
|
||||
}
|
||||
s.listener = l
|
||||
case "vsock":
|
||||
l := newStreamListener(
|
||||
s.Config,
|
||||
s.splitter,
|
||||
s.log,
|
||||
)
|
||||
|
||||
if err := l.setupVsock(s.url); err != nil {
|
||||
return err
|
||||
}
|
||||
s.listener = l
|
||||
default:
|
||||
return fmt.Errorf("unknown protocol %q", s.url.Scheme)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Listen starts receiving complete messages, delivering each to onData;
// asynchronous errors are reported through onError. Call after Setup.
func (s *Socket) Listen(onData CallbackData, onError CallbackError) {
	s.listener.listenData(onData, onError)
}
|
||||
|
||||
// ListenConnection starts accepting connections, delivering each to
// onConnection; asynchronous errors are reported through onError.
// Call after Setup.
func (s *Socket) ListenConnection(onConnection CallbackConnection, onError CallbackError) {
	s.listener.listenConnection(onConnection, onError)
}
|
||||
|
||||
// Close shuts down the underlying listener, if any. It is safe to call
// multiple times; subsequent calls are no-ops.
func (s *Socket) Close() {
	if s.listener != nil {
		// We cannot recover from a failed close, so only log a warning
		// instead of propagating the error.
		if err := s.listener.close(); err != nil {
			s.log.Warnf("Closing socket failed: %v", err)
		}
		s.listener = nil
	}
}
|
||||
|
||||
// Address returns the local address the listener is bound to. It must
// only be called after a successful Setup and before Close; otherwise
// s.listener is nil and this panics.
func (s *Socket) Address() net.Addr {
	return s.listener.address()
}
|
845
plugins/common/socket/socket_test.go
Normal file
845
plugins/common/socket/socket_test.go
Normal file
|
@ -0,0 +1,845 @@
|
|||
package socket
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"os"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/influxdata/telegraf"
|
||||
"github.com/influxdata/telegraf/config"
|
||||
"github.com/influxdata/telegraf/internal"
|
||||
"github.com/influxdata/telegraf/metric"
|
||||
_ "github.com/influxdata/telegraf/plugins/parsers/all"
|
||||
"github.com/influxdata/telegraf/plugins/parsers/influx"
|
||||
"github.com/influxdata/telegraf/testutil"
|
||||
)
|
||||
|
||||
// pki provides the shared TLS server/client test certificates.
var pki = testutil.NewPKI("../../../testutil/pki")
|
||||
|
||||
// TestListenData exercises the message-based Listen path across all
// socket schemes (TCP/UDP/unix/unixgram, with and without TLS and gzip)
// and verifies the received metrics including the added "source" tag.
func TestListenData(t *testing.T) {
	messages := [][]byte{
		[]byte("test,foo=bar v=1i 123456789\ntest,foo=baz v=2i 123456790\n"),
		[]byte("test,foo=zab v=3i 123456791\n"),
	}
	expectedTemplates := []telegraf.Metric{
		metric.New(
			"test",
			map[string]string{"foo": "bar"},
			map[string]interface{}{"v": int64(1)},
			time.Unix(0, 123456789),
		),
		metric.New(
			"test",
			map[string]string{"foo": "baz"},
			map[string]interface{}{"v": int64(2)},
			time.Unix(0, 123456790),
		),
		metric.New(
			"test",
			map[string]string{"foo": "zab"},
			map[string]interface{}{"v": int64(3)},
			time.Unix(0, 123456791),
		),
	}

	tests := []struct {
		name       string
		schema     string
		buffersize config.Size
		encoding   string
	}{
		{
			name:       "TCP",
			schema:     "tcp",
			buffersize: config.Size(1024),
		},
		{
			name:   "TCP with TLS",
			schema: "tcp+tls",
		},
		{
			name:       "TCP with gzip encoding",
			schema:     "tcp",
			buffersize: config.Size(1024),
			encoding:   "gzip",
		},
		{
			name:       "UDP",
			schema:     "udp",
			buffersize: config.Size(1024),
		},
		{
			name:       "UDP with gzip encoding",
			schema:     "udp",
			buffersize: config.Size(1024),
			encoding:   "gzip",
		},
		{
			name:       "unix socket",
			schema:     "unix",
			buffersize: config.Size(1024),
		},
		{
			name:   "unix socket with TLS",
			schema: "unix+tls",
		},
		{
			name:     "unix socket with gzip encoding",
			schema:   "unix",
			encoding: "gzip",
		},
		{
			name:       "unixgram socket",
			schema:     "unixgram",
			buffersize: config.Size(1024),
		},
	}

	serverTLS := pki.TLSServerConfig()
	clientTLS := pki.TLSClientConfig()

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			proto := strings.TrimSuffix(tt.schema, "+tls")

			// Prepare the address and socket if needed
			var sockPath string
			var serviceAddress string
			var tlsCfg *tls.Config
			switch proto {
			case "tcp", "udp":
				serviceAddress = proto + "://" + "127.0.0.1:0"
			case "unix", "unixgram":
				if runtime.GOOS == "windows" {
					t.Skip("Skipping on Windows, as unixgram sockets are not supported")
				}

				// Create a socket
				sockPath = testutil.TempSocket(t)
				f, err := os.Create(sockPath)
				require.NoError(t, err)
				defer f.Close()
				serviceAddress = proto + "://" + sockPath
			}

			// Setup the configuration according to test specification
			cfg := &Config{
				ContentEncoding: tt.encoding,
				ReadBufferSize:  tt.buffersize,
			}
			if strings.HasSuffix(tt.schema, "tls") {
				cfg.ServerConfig = *serverTLS
				var err error
				tlsCfg, err = clientTLS.TLSConfig()
				require.NoError(t, err)
			}

			// Create the socket
			sock, err := cfg.NewSocket(serviceAddress, &SplitConfig{}, &testutil.Logger{})
			require.NoError(t, err)

			// Create callbacks
			parser := &influx.Parser{}
			require.NoError(t, parser.Init())

			// NOTE(review): require.* inside this callback runs on a
			// listener goroutine; testify's FailNow from a non-test
			// goroutine does not abort the test cleanly — confirm intended.
			var acc testutil.Accumulator
			onData := func(remote net.Addr, data []byte, _ time.Time) {
				m, err := parser.Parse(data)
				require.NoError(t, err)
				addr, _, err := net.SplitHostPort(remote.String())
				if err != nil {
					addr = remote.String()
				}
				for i := range m {
					m[i].AddTag("source", addr)
				}
				acc.AddMetrics(m)
			}
			onError := func(err error) {
				acc.AddError(err)
			}

			// Start the listener
			require.NoError(t, sock.Setup())
			sock.Listen(onData, onError)
			defer sock.Close()

			addr := sock.Address()

			// Create a noop client
			// Server is async, so verify no errors at the end.
			client, err := createClient(serviceAddress, addr, tlsCfg)
			require.NoError(t, err)
			require.NoError(t, client.Close())

			// Setup the client for submitting data
			client, err = createClient(serviceAddress, addr, tlsCfg)
			require.NoError(t, err)

			// Conditionally add the source address to the expectation
			expected := make([]telegraf.Metric, 0, len(expectedTemplates))
			for _, tmpl := range expectedTemplates {
				m := tmpl.Copy()
				switch proto {
				case "tcp", "udp":
					laddr := client.LocalAddr().String()
					addr, _, err := net.SplitHostPort(laddr)
					if err != nil {
						addr = laddr
					}
					m.AddTag("source", addr)
				case "unix", "unixgram":
					m.AddTag("source", sockPath)
				}
				expected = append(expected, m)
			}

			// Send the data with the correct encoding
			encoder, err := internal.NewContentEncoder(tt.encoding)
			require.NoError(t, err)

			for i, msg := range messages {
				m, err := encoder.Encode(msg)
				require.NoErrorf(t, err, "encoding failed for msg %d", i)
				_, err = client.Write(m)
				require.NoErrorf(t, err, "sending msg %d failed", i)
			}

			// Test the resulting metrics and compare against expected results
			require.Eventuallyf(t, func() bool {
				acc.Lock()
				defer acc.Unlock()
				return acc.NMetrics() >= uint64(len(expected))
			}, time.Second, 100*time.Millisecond, "did not receive metrics (%d)", acc.NMetrics())

			actual := acc.GetTelegrafMetrics()
			testutil.RequireMetricsEqual(t, expected, actual, testutil.SortMetrics())
		})
	}
}
|
||||
|
||||
// TestListenConnection mirrors TestListenData but exercises the
// connection-based ListenConnection path: each connection's stream is
// read to EOF in the callback before parsing.
func TestListenConnection(t *testing.T) {
	messages := [][]byte{
		[]byte("test,foo=bar v=1i 123456789\ntest,foo=baz v=2i 123456790\n"),
		[]byte("test,foo=zab v=3i 123456791\n"),
	}
	expectedTemplates := []telegraf.Metric{
		metric.New(
			"test",
			map[string]string{"foo": "bar"},
			map[string]interface{}{"v": int64(1)},
			time.Unix(0, 123456789),
		),
		metric.New(
			"test",
			map[string]string{"foo": "baz"},
			map[string]interface{}{"v": int64(2)},
			time.Unix(0, 123456790),
		),
		metric.New(
			"test",
			map[string]string{"foo": "zab"},
			map[string]interface{}{"v": int64(3)},
			time.Unix(0, 123456791),
		),
	}

	tests := []struct {
		name       string
		schema     string
		buffersize config.Size
		encoding   string
	}{
		{
			name:       "TCP",
			schema:     "tcp",
			buffersize: config.Size(1024),
		},
		{
			name:   "TCP with TLS",
			schema: "tcp+tls",
		},
		{
			name:       "TCP with gzip encoding",
			schema:     "tcp",
			buffersize: config.Size(1024),
			encoding:   "gzip",
		},
		{
			name:       "UDP",
			schema:     "udp",
			buffersize: config.Size(1024),
		},
		{
			name:       "UDP with gzip encoding",
			schema:     "udp",
			buffersize: config.Size(1024),
			encoding:   "gzip",
		},
		{
			name:       "unix socket",
			schema:     "unix",
			buffersize: config.Size(1024),
		},
		{
			name:   "unix socket with TLS",
			schema: "unix+tls",
		},
		{
			name:     "unix socket with gzip encoding",
			schema:   "unix",
			encoding: "gzip",
		},
		{
			name:       "unixgram socket",
			schema:     "unixgram",
			buffersize: config.Size(1024),
		},
	}

	serverTLS := pki.TLSServerConfig()
	clientTLS := pki.TLSClientConfig()

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			proto := strings.TrimSuffix(tt.schema, "+tls")

			// Prepare the address and socket if needed
			var sockPath string
			var serviceAddress string
			var tlsCfg *tls.Config
			switch proto {
			case "tcp", "udp":
				serviceAddress = proto + "://" + "127.0.0.1:0"
			case "unix", "unixgram":
				if runtime.GOOS == "windows" {
					t.Skip("Skipping on Windows, as unixgram sockets are not supported")
				}

				// Create a socket
				sockPath = testutil.TempSocket(t)
				f, err := os.Create(sockPath)
				require.NoError(t, err)
				defer f.Close()
				serviceAddress = proto + "://" + sockPath
			}

			// Setup the configuration according to test specification
			cfg := &Config{
				ContentEncoding: tt.encoding,
				ReadBufferSize:  tt.buffersize,
			}
			if strings.HasSuffix(tt.schema, "tls") {
				cfg.ServerConfig = *serverTLS
				var err error
				tlsCfg, err = clientTLS.TLSConfig()
				require.NoError(t, err)
			}

			// Create the socket
			sock, err := cfg.NewSocket(serviceAddress, &SplitConfig{}, &testutil.Logger{})
			require.NoError(t, err)

			// Create callbacks
			parser := &influx.Parser{}
			require.NoError(t, parser.Init())

			// NOTE(review): require.* inside this callback runs on a
			// listener goroutine; see the equivalent note in TestListenData.
			var acc testutil.Accumulator
			onConnection := func(remote net.Addr, reader io.ReadCloser) {
				data, err := io.ReadAll(reader)
				require.NoError(t, err)
				m, err := parser.Parse(data)
				require.NoError(t, err)
				addr, _, err := net.SplitHostPort(remote.String())
				if err != nil {
					addr = remote.String()
				}
				for i := range m {
					m[i].AddTag("source", addr)
				}
				acc.AddMetrics(m)
			}
			onError := func(err error) {
				acc.AddError(err)
			}

			// Start the listener
			require.NoError(t, sock.Setup())
			sock.ListenConnection(onConnection, onError)
			defer sock.Close()

			addr := sock.Address()

			// Create a noop client
			// Server is async, so verify no errors at the end.
			client, err := createClient(serviceAddress, addr, tlsCfg)
			require.NoError(t, err)
			require.NoError(t, client.Close())

			// Setup the client for submitting data
			client, err = createClient(serviceAddress, addr, tlsCfg)
			require.NoError(t, err)

			// Conditionally add the source address to the expectation
			expected := make([]telegraf.Metric, 0, len(expectedTemplates))
			for _, tmpl := range expectedTemplates {
				m := tmpl.Copy()
				switch proto {
				case "tcp", "udp":
					laddr := client.LocalAddr().String()
					addr, _, err := net.SplitHostPort(laddr)
					if err != nil {
						addr = laddr
					}
					m.AddTag("source", addr)
				case "unix", "unixgram":
					m.AddTag("source", sockPath)
				}
				expected = append(expected, m)
			}

			// Send the data with the correct encoding
			encoder, err := internal.NewContentEncoder(tt.encoding)
			require.NoError(t, err)

			for i, msg := range messages {
				m, err := encoder.Encode(msg)
				require.NoErrorf(t, err, "encoding failed for msg %d", i)
				_, err = client.Write(m)
				require.NoErrorf(t, err, "sending msg %d failed", i)
			}
			// Close so the server-side io.ReadAll sees EOF.
			client.Close()

			// Test the resulting metrics and compare against expected results
			require.Eventuallyf(t, func() bool {
				acc.Lock()
				defer acc.Unlock()
				return acc.NMetrics() >= uint64(len(expected))
			}, time.Second, 100*time.Millisecond, "did not receive metrics (%d)", acc.NMetrics())

			actual := acc.GetTelegrafMetrics()
			testutil.RequireMetricsEqual(t, expected, actual, testutil.SortMetrics())
		})
	}
}
|
||||
|
||||
func TestClosingConnections(t *testing.T) {
|
||||
// Setup the configuration
|
||||
cfg := &Config{
|
||||
ReadBufferSize: 1024,
|
||||
}
|
||||
|
||||
// Create the socket
|
||||
serviceAddress := "tcp://127.0.0.1:0"
|
||||
logger := &testutil.CaptureLogger{}
|
||||
sock, err := cfg.NewSocket(serviceAddress, &SplitConfig{}, logger)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create callbacks
|
||||
parser := &influx.Parser{}
|
||||
require.NoError(t, parser.Init())
|
||||
|
||||
var acc testutil.Accumulator
|
||||
onData := func(_ net.Addr, data []byte, _ time.Time) {
|
||||
m, err := parser.Parse(data)
|
||||
require.NoError(t, err)
|
||||
acc.AddMetrics(m)
|
||||
}
|
||||
onError := func(err error) {
|
||||
acc.AddError(err)
|
||||
}
|
||||
|
||||
// Start the listener
|
||||
require.NoError(t, sock.Setup())
|
||||
sock.Listen(onData, onError)
|
||||
defer sock.Close()
|
||||
|
||||
addr := sock.Address()
|
||||
|
||||
// Create a noop client
|
||||
client, err := createClient(serviceAddress, addr, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = client.Write([]byte("test value=42i\n"))
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Eventually(t, func() bool {
|
||||
acc.Lock()
|
||||
defer acc.Unlock()
|
||||
return acc.NMetrics() >= 1
|
||||
}, time.Second, 100*time.Millisecond, "did not receive metric")
|
||||
|
||||
// This has to be a stream-listener...
|
||||
listener, ok := sock.listener.(*streamListener)
|
||||
require.True(t, ok)
|
||||
listener.Lock()
|
||||
conns := listener.connections
|
||||
listener.Unlock()
|
||||
require.NotZero(t, conns)
|
||||
|
||||
sock.Close()
|
||||
|
||||
// Verify that plugin.Stop() closed the client's connection
|
||||
require.NoError(t, client.SetReadDeadline(time.Now().Add(time.Second)))
|
||||
buf := []byte{1}
|
||||
_, err = client.Read(buf)
|
||||
require.Equal(t, err, io.EOF)
|
||||
|
||||
require.Empty(t, logger.Errors())
|
||||
require.Empty(t, logger.Warnings())
|
||||
}
|
||||
// TestMaxConnections verifies the stream listener's connection limit:
// up to MaxConnections clients are served, the next one is rejected
// with a "too many connections" error, and freeing a slot allows a new
// client in. Finally the internal connection counter must drop to zero.
func TestMaxConnections(t *testing.T) {
	if runtime.GOOS == "darwin" {
		t.Skip("Skipping on darwin due to missing socket options")
	}

	// Setup the configuration
	period := config.Duration(10 * time.Millisecond)
	cfg := &Config{
		MaxConnections:  5,
		KeepAlivePeriod: &period,
	}

	// Create the socket
	serviceAddress := "tcp://127.0.0.1:0"
	sock, err := cfg.NewSocket(serviceAddress, nil, &testutil.Logger{})
	require.NoError(t, err)

	// Create callback; errors are collected under a mutex because they
	// arrive from listener goroutines.
	var errs []error
	var mu sync.Mutex
	onData := func(_ net.Addr, _ []byte, _ time.Time) {}
	onError := func(err error) {
		mu.Lock()
		errs = append(errs, err)
		mu.Unlock()
	}

	// Start the listener
	require.NoError(t, sock.Setup())
	sock.Listen(onData, onError)
	defer sock.Close()

	addr := sock.Address()

	// Create maximum number of connections and write some data. All of this
	// should succeed...
	clients := make([]*net.TCPConn, 0, cfg.MaxConnections)
	for i := 0; i < int(cfg.MaxConnections); i++ {
		c, err := net.DialTCP("tcp", nil, addr.(*net.TCPAddr))
		require.NoError(t, err)
		require.NoError(t, c.SetWriteBuffer(0))
		require.NoError(t, c.SetNoDelay(true))
		clients = append(clients, c)

		_, err = c.Write([]byte("test value=42i\n"))
		require.NoError(t, err)
	}

	func() {
		mu.Lock()
		defer mu.Unlock()
		require.Empty(t, errs)
	}()

	// Create another client. This should fail because we already reached the
	// connection limit and the connection should be closed...
	client, err := net.DialTCP("tcp", nil, addr.(*net.TCPAddr))
	require.NoError(t, err)
	require.NoError(t, client.SetWriteBuffer(0))
	require.NoError(t, client.SetNoDelay(true))

	require.Eventually(t, func() bool {
		mu.Lock()
		defer mu.Unlock()
		return len(errs) > 0
	}, 3*time.Second, 100*time.Millisecond)
	func() {
		mu.Lock()
		defer mu.Unlock()
		require.Len(t, errs, 1)
		require.ErrorContains(t, errs[0], "too many connections")
		errs = make([]error, 0)
	}()

	// Writes eventually fail once the server has closed the connection;
	// the first write may still succeed due to TCP buffering.
	require.Eventually(t, func() bool {
		_, err := client.Write([]byte("fail\n"))
		return err != nil
	}, 3*time.Second, 100*time.Millisecond)
	_, err = client.Write([]byte("test\n"))
	require.Error(t, err)

	// Check other connections are still good
	for _, c := range clients {
		_, err := c.Write([]byte("test\n"))
		require.NoError(t, err)
	}
	func() {
		mu.Lock()
		defer mu.Unlock()
		require.Empty(t, errs)
	}()

	// Close the first client and check if we can connect now
	require.NoError(t, clients[0].Close())
	client, err = net.DialTCP("tcp", nil, addr.(*net.TCPAddr))
	require.NoError(t, err)
	require.NoError(t, client.SetWriteBuffer(0))
	require.NoError(t, client.SetNoDelay(true))
	_, err = client.Write([]byte("success\n"))
	require.NoError(t, err)

	// Close all connections
	require.NoError(t, client.Close())
	for _, c := range clients[1:] {
		require.NoError(t, c.Close())
	}

	// Close the clients and check the connection counter
	listener, ok := sock.listener.(*streamListener)
	require.True(t, ok)
	require.Eventually(t, func() bool {
		listener.Lock()
		conns := listener.connections
		listener.Unlock()
		return conns == 0
	}, 3*time.Second, 100*time.Millisecond)

	// Close the socket and check again...
	sock.Close()
	listener.Lock()
	conns := listener.connections
	listener.Unlock()
	require.Zero(t, conns)
}
|
||||
|
||||
// TestNoSplitter verifies that with a nil splitter configuration a
// stream connection is consumed as one unsplit byte stream: data
// arriving in arbitrary chunks is only parsed after the client closes.
func TestNoSplitter(t *testing.T) {
	// The first message ends mid-line on purpose to prove no line
	// splitting happens before EOF.
	messages := [][]byte{
		[]byte("test,foo=bar v"),
		[]byte("=1i 123456789\ntest,foo=baz v=2i 123456790\ntest,foo=zab v=3i 123456791\n"),
	}
	expectedTemplates := []telegraf.Metric{
		metric.New(
			"test",
			map[string]string{"foo": "bar"},
			map[string]interface{}{"v": int64(1)},
			time.Unix(0, 123456789),
		),
		metric.New(
			"test",
			map[string]string{"foo": "baz"},
			map[string]interface{}{"v": int64(2)},
			time.Unix(0, 123456790),
		),
		metric.New(
			"test",
			map[string]string{"foo": "zab"},
			map[string]interface{}{"v": int64(3)},
			time.Unix(0, 123456791),
		),
	}

	// Prepare the address and socket if needed
	serviceAddress := "tcp://127.0.0.1:0"

	// Setup the configuration according to test specification
	cfg := &Config{}

	// Create the socket; nil split-config selects the unsplit stream path
	sock, err := cfg.NewSocket(serviceAddress, nil, &testutil.Logger{})
	require.NoError(t, err)

	// Create callbacks
	parser := &influx.Parser{}
	require.NoError(t, parser.Init())

	var acc testutil.Accumulator
	onConnection := func(remote net.Addr, reader io.ReadCloser) {
		data, err := io.ReadAll(reader)
		require.NoError(t, err)
		m, err := parser.Parse(data)
		require.NoError(t, err)
		addr, _, err := net.SplitHostPort(remote.String())
		if err != nil {
			addr = remote.String()
		}
		for i := range m {
			m[i].AddTag("source", addr)
		}
		acc.AddMetrics(m)
	}
	onError := func(err error) {
		acc.AddError(err)
	}

	// Start the listener
	require.NoError(t, sock.Setup())
	sock.ListenConnection(onConnection, onError)
	defer sock.Close()

	addr := sock.Address()

	// Create a noop client
	// Server is async, so verify no errors at the end.
	client, err := createClient(serviceAddress, addr, nil)
	require.NoError(t, err)
	require.NoError(t, client.Close())

	// Setup the client for submitting data
	client, err = createClient(serviceAddress, addr, nil)
	require.NoError(t, err)

	// Conditionally add the source address to the expectation
	expected := make([]telegraf.Metric, 0, len(expectedTemplates))
	for _, tmpl := range expectedTemplates {
		m := tmpl.Copy()
		laddr := client.LocalAddr().String()
		addr, _, err := net.SplitHostPort(laddr)
		if err != nil {
			addr = laddr
		}
		m.AddTag("source", addr)
		expected = append(expected, m)
	}

	// Send the data in two chunks with a pause between them so the
	// server observes two separate reads.
	for i, msg := range messages {
		_, err = client.Write(msg)
		time.Sleep(100 * time.Millisecond)
		require.NoErrorf(t, err, "sending msg %d failed", i)
	}
	// Closing triggers EOF on the server-side io.ReadAll.
	client.Close()

	// Test the resulting metrics and compare against expected results
	require.Eventuallyf(t, func() bool {
		acc.Lock()
		defer acc.Unlock()
		return acc.NMetrics() >= uint64(len(expected))
	}, time.Second, 100*time.Millisecond, "did not receive metrics (%d)", acc.NMetrics())

	actual := acc.GetTelegrafMetrics()
	testutil.RequireMetricsEqual(t, expected, actual, testutil.SortMetrics())
}
|
||||
|
||||
// TestTLSMemLeak guards against the TLS connection leak reported in
// https://github.com/influxdata/telegraf/issues/15509: after many short
// TLS connections the number of live heap objects must stay bounded.
func TestTLSMemLeak(t *testing.T) {
	// For issue https://github.com/influxdata/telegraf/issues/15509

	// Prepare the address and socket if needed
	serviceAddress := "tcp://127.0.0.1:0"

	// Setup a TLS socket to trigger the issue
	cfg := &Config{
		ServerConfig: *pki.TLSServerConfig(),
	}

	// Create the socket
	sock, err := cfg.NewSocket(serviceAddress, nil, &testutil.Logger{})
	require.NoError(t, err)

	// Create callbacks
	onConnection := func(_ net.Addr, reader io.ReadCloser) {
		//nolint:errcheck // We are not interested in the data so ignore all errors
		io.Copy(io.Discard, reader)
	}

	// Start the listener
	require.NoError(t, sock.Setup())
	sock.ListenConnection(onConnection, nil)
	defer sock.Close()

	addr := sock.Address()

	// Setup the client side TLS
	tlsCfg, err := pki.TLSClientConfig().TLSConfig()
	require.NoError(t, err)

	// Define a single client write sequence
	data := []byte("test value=42i")
	write := func() error {
		conn, err := tls.Dial("tcp", addr.String(), tlsCfg)
		if err != nil {
			return err
		}
		defer conn.Close()

		_, err = conn.Write(data)
		return err
	}

	// Define a test with the given number of connections; at most
	// GOMAXPROCS writes run concurrently before waiting for the batch.
	maxConcurrency := runtime.GOMAXPROCS(0)
	testCycle := func(connections int) (uint64, error) {
		var mu sync.Mutex
		var errs []error
		var wg sync.WaitGroup
		for count := 1; count < connections; count++ {
			wg.Add(1)
			go func() {
				defer wg.Done()
				if err := write(); err != nil {
					mu.Lock()
					errs = append(errs, err)
					mu.Unlock()
				}
			}()
			if count%maxConcurrency == 0 {
				wg.Wait()
				mu.Lock()
				if len(errs) > 0 {
					mu.Unlock()
					return 0, errors.Join(errs...)
				}
				mu.Unlock()
			}
		}
		//nolint:revive // We need to actively run the garbage collector to get reliable measurements
		runtime.GC()

		var stats runtime.MemStats
		runtime.ReadMemStats(&stats)
		return stats.HeapObjects, nil
	}

	// Measure the memory usage after a short warmup and after some time.
	// The final number of heap objects should not exceed the number of
	// runs by a save margin

	// Warmup, do a low number of runs to initialize all data structures
	// taking them out of the equation.
	initial, err := testCycle(100)
	require.NoError(t, err)

	// Do some more runs and make sure the memory growth is bound
	final, err := testCycle(2000)
	require.NoError(t, err)

	require.Less(t, final, 3*initial)
}
|
||||
|
||||
// createClient dials the given address with the protocol extracted from the
// endpoint URL scheme, optionally wrapping the connection with TLS. For
// unix-domain sockets certificate verification is disabled as there is no
// hostname to verify against.
func createClient(endpoint string, addr net.Addr, tlsCfg *tls.Config) (net.Conn, error) {
	// Determine the protocol in a crude fashion from the URL scheme
	protocol, _, found := strings.Cut(endpoint, "://")
	if !found {
		return nil, fmt.Errorf("invalid endpoint %q", endpoint)
	}

	// Plain, unencrypted connection
	if tlsCfg == nil {
		return net.Dial(protocol, addr.String())
	}

	// TLS connection; skip verification for unix sockets as they have no hostname
	if protocol == "unix" {
		tlsCfg.InsecureSkipVerify = true
	}
	return tls.Dial(protocol, addr.String(), tlsCfg)
}
|
37
plugins/common/socket/splitter.conf
Normal file
37
plugins/common/socket/splitter.conf
Normal file
|
@ -0,0 +1,37 @@
|
|||
## Message splitting strategy and corresponding settings for stream sockets
|
||||
## (tcp, tcp4, tcp6, unix or unixpacket). The setting is ignored for packet
|
||||
## listeners such as udp.
|
||||
## Available strategies are:
|
||||
## newline -- split at newlines (default)
|
||||
## null -- split at null bytes
|
||||
## delimiter -- split at delimiter byte-sequence in hex-format
|
||||
## given in `splitting_delimiter`
|
||||
## fixed length -- split after number of bytes given in `splitting_length`
|
||||
## variable length -- split depending on length information received in the
|
||||
## data. The length field information is specified in
|
||||
## `splitting_length_field`.
|
||||
# splitting_strategy = "newline"
|
||||
|
||||
## Delimiter used to split received data to messages consumed by the parser.
|
||||
## The delimiter is a hex byte-sequence marking the end of a message
|
||||
## e.g. "0x0D0A", "x0d0a" or "0d0a" marks a Windows line-break (CR LF).
|
||||
## The value is case-insensitive and can be specified with "0x" or "x" prefix
|
||||
## or without.
|
||||
## Note: This setting is only used for splitting_strategy = "delimiter".
|
||||
# splitting_delimiter = ""
|
||||
|
||||
## Fixed length of a message in bytes.
|
||||
## Note: This setting is only used for splitting_strategy = "fixed length".
|
||||
# splitting_length = 0
|
||||
|
||||
## Specification of the length field contained in the data to split messages
|
||||
## with variable length. The specification contains the following fields:
|
||||
## offset -- start of length field in bytes from begin of data
|
||||
## bytes -- length of length field in bytes
|
||||
## endianness -- endianness of the value, either "be" for big endian or
|
||||
## "le" for little endian
|
||||
## header_length -- total length of header to be skipped when passing
|
||||
## data on to the parser. If zero (default), the header
|
||||
## is passed on to the parser together with the message.
|
||||
## Note: This setting is only used for splitting_strategy = "variable length".
|
||||
# splitting_length_field = {offset = 0, bytes = 0, endianness = "be", header_length = 0}
|
167
plugins/common/socket/splitters.go
Normal file
167
plugins/common/socket/splitters.go
Normal file
|
@ -0,0 +1,167 @@
|
|||
package socket
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// lengthFieldSpec describes where and how a message-length field is encoded
// in the received data for the "variable length" splitting strategy.
type lengthFieldSpec struct {
	Offset       int64  `toml:"offset"`        // start of the length field in bytes from begin of data
	Bytes        int64  `toml:"bytes"`         // size of the length field in bytes
	Endianness   string `toml:"endianness"`    // byte order of the value, "be" (default) or "le"
	HeaderLength int64  `toml:"header_length"` // total header length to skip when passing data to the parser
	converter    func([]byte) int              // set by NewSplitter according to Bytes and Endianness
}

// SplitConfig contains the user-facing settings selecting and parameterizing
// the message-splitting strategy for stream sockets.
type SplitConfig struct {
	SplittingStrategy    string          `toml:"splitting_strategy"`
	SplittingDelimiter   string          `toml:"splitting_delimiter"`
	SplittingLength      int             `toml:"splitting_length"`
	SplittingLengthField lengthFieldSpec `toml:"splitting_length_field"`
}
|
||||
|
||||
// NewSplitter creates a bufio.SplitFunc according to the configured splitting
// strategy. For the "variable length" strategy it also installs the converter
// turning the raw length-field bytes into an integer.
func (cfg *SplitConfig) NewSplitter() (bufio.SplitFunc, error) {
	switch cfg.SplittingStrategy {
	case "", "newline":
		// Default strategy: split at newlines
		return bufio.ScanLines, nil
	case "null":
		return scanNull, nil
	case "delimiter":
		// Strip whitespace and "0x"/"x" prefixes from the hex byte-sequence
		// before decoding it
		re := regexp.MustCompile(`(\s*0?x)`)
		d := re.ReplaceAllString(strings.ToLower(cfg.SplittingDelimiter), "")
		delimiter, err := hex.DecodeString(d)
		if err != nil {
			return nil, fmt.Errorf("decoding delimiter failed: %w", err)
		}
		return createScanDelimiter(delimiter), nil
	case "fixed length":
		return createScanFixedLength(cfg.SplittingLength), nil
	case "variable length":
		// Create the converter function
		var order binary.ByteOrder
		switch strings.ToLower(cfg.SplittingLengthField.Endianness) {
		case "", "be":
			order = binary.BigEndian
		case "le":
			order = binary.LittleEndian
		default:
			return nil, fmt.Errorf("invalid 'endianness' %q", cfg.SplittingLengthField.Endianness)
		}

		// Use a fixed-width conversion where the field size matches a native
		// integer width; all other sizes are padded into an 8-byte buffer
		switch cfg.SplittingLengthField.Bytes {
		case 1:
			cfg.SplittingLengthField.converter = func(b []byte) int {
				return int(b[0])
			}
		case 2:
			cfg.SplittingLengthField.converter = func(b []byte) int {
				return int(order.Uint16(b))
			}
		case 4:
			cfg.SplittingLengthField.converter = func(b []byte) int {
				return int(order.Uint32(b))
			}
		case 8:
			cfg.SplittingLengthField.converter = func(b []byte) int {
				return int(order.Uint64(b))
			}
		default:
			cfg.SplittingLengthField.converter = func(b []byte) int {
				// Pad the field into 8 bytes respecting the byte order, then
				// convert it as a 64-bit value
				buf := make([]byte, 8)
				start := 0
				if order == binary.BigEndian {
					start = 8 - len(b)
				}
				for i := 0; i < len(b); i++ {
					buf[start+i] = b[i]
				}
				return int(order.Uint64(buf))
			}
		}

		// Check if we have enough bytes in the header
		return createScanVariableLength(cfg.SplittingLengthField), nil
	}

	return nil, fmt.Errorf("unknown 'splitting_strategy' %q", cfg.SplittingStrategy)
}
|
||||
|
||||
func scanNull(data []byte, atEOF bool) (advance int, token []byte, err error) {
|
||||
if atEOF && len(data) == 0 {
|
||||
return 0, nil, nil
|
||||
}
|
||||
if i := bytes.IndexByte(data, 0); i >= 0 {
|
||||
return i + 1, data[:i], nil
|
||||
}
|
||||
if atEOF {
|
||||
return len(data), data, nil
|
||||
}
|
||||
// Request more data.
|
||||
return 0, nil, nil
|
||||
}
|
||||
|
||||
func createScanDelimiter(delimiter []byte) bufio.SplitFunc {
|
||||
return func(data []byte, atEOF bool) (advance int, token []byte, err error) {
|
||||
if atEOF && len(data) == 0 {
|
||||
return 0, nil, nil
|
||||
}
|
||||
if i := bytes.Index(data, delimiter); i >= 0 {
|
||||
return i + len(delimiter), data[:i], nil
|
||||
}
|
||||
if atEOF {
|
||||
return len(data), data, nil
|
||||
}
|
||||
// Request more data.
|
||||
return 0, nil, nil
|
||||
}
|
||||
}
|
||||
|
||||
func createScanFixedLength(length int) bufio.SplitFunc {
|
||||
return func(data []byte, atEOF bool) (advance int, token []byte, err error) {
|
||||
if atEOF && len(data) == 0 {
|
||||
return 0, nil, nil
|
||||
}
|
||||
if len(data) >= length {
|
||||
return length, data[:length], nil
|
||||
}
|
||||
if atEOF {
|
||||
return len(data), data, nil
|
||||
}
|
||||
// Request more data.
|
||||
return 0, nil, nil
|
||||
}
|
||||
}
|
||||
|
||||
// createScanVariableLength returns a bufio.SplitFunc splitting messages whose
// length is encoded in a header field described by the given spec. The spec's
// converter must be initialized (done by NewSplitter) before use.
func createScanVariableLength(spec lengthFieldSpec) bufio.SplitFunc {
	// Minimum number of buffered bytes required to read the length field
	minlen := int(spec.Offset)
	minlen += int(spec.Bytes)
	headerLen := int(spec.HeaderLength)

	return func(data []byte, atEOF bool) (advance int, token []byte, err error) {
		if atEOF && len(data) == 0 {
			return 0, nil, nil
		}
		dataLen := len(data)
		if dataLen >= minlen {
			// Extract the length field and convert it to a number
			lf := data[spec.Offset : spec.Offset+spec.Bytes]
			length := spec.converter(lf)
			start := headerLen
			end := length + headerLen
			// If we have enough data return it without the header
			if end <= dataLen {
				return end, data[start:end], nil
			}
		}
		// NOTE(review): at EOF any incomplete trailing data — including the
		// header — is returned verbatim to the parser; confirm this is the
		// intended behavior for truncated streams.
		if atEOF {
			return len(data), data, nil
		}
		// Request more data.
		return 0, nil, nil
	}
}
|
476
plugins/common/socket/stream.go
Normal file
476
plugins/common/socket/stream.go
Normal file
|
@ -0,0 +1,476 @@
|
|||
package socket
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"math"
|
||||
"net"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/alitto/pond"
|
||||
"github.com/mdlayher/vsock"
|
||||
|
||||
"github.com/influxdata/telegraf"
|
||||
"github.com/influxdata/telegraf/config"
|
||||
"github.com/influxdata/telegraf/internal"
|
||||
)
|
||||
|
||||
// hasSetReadBuffer is implemented by connection types allowing to resize the
// kernel receive buffer (e.g. *net.TCPConn and *net.UnixConn).
type hasSetReadBuffer interface {
	SetReadBuffer(bytes int) error
}

// streamListener accepts and handles connections of stream-oriented sockets
// (tcp, unix, vsock) and forwards received data to the configured callbacks.
type streamListener struct {
	Encoding        string           // content encoding used to decode received data
	ReadBufferSize  int              // kernel receive-buffer size; zero keeps the OS default
	MaxConnections  uint64           // maximum parallel connections; zero means unlimited
	ReadTimeout     config.Duration  // per-read deadline; zero disables the deadline
	KeepAlivePeriod *config.Duration // TCP keep-alive period; nil keeps OS default, zero disables keep-alive
	Splitter        bufio.SplitFunc  // message splitter; nil means "read connection to EOF as one message"
	Log             telegraf.Logger

	listener    net.Listener
	connections uint64 // number of currently active connections, guarded by the embedded mutex
	path        string // filesystem path for unix-domain sockets, removed on close
	cancel      context.CancelFunc
	parsePool   *pond.WorkerPool // worker pool for asynchronous parsing of received messages

	wg sync.WaitGroup
	sync.Mutex
}
|
||||
|
||||
// newStreamListener creates a listener for stream sockets using the given
// splitter to separate messages in the received data stream.
func newStreamListener(conf Config, splitter bufio.SplitFunc, log telegraf.Logger) *streamListener {
	return &streamListener{
		ReadBufferSize:  int(conf.ReadBufferSize),
		ReadTimeout:     conf.ReadTimeout,
		KeepAlivePeriod: conf.KeepAlivePeriod,
		MaxConnections:  conf.MaxConnections,
		Encoding:        conf.ContentEncoding,
		Splitter:        splitter,
		Log:             log,

		// Keep some workers alive to avoid re-spawning them for every message
		parsePool: pond.New(
			conf.MaxParallelParsers,
			0,
			pond.MinWorkers(conf.MaxParallelParsers/2+1)),
	}
}
|
||||
|
||||
func (l *streamListener) setupTCP(u *url.URL, tlsCfg *tls.Config) error {
|
||||
var err error
|
||||
if tlsCfg == nil {
|
||||
l.listener, err = net.Listen(u.Scheme, u.Host)
|
||||
} else {
|
||||
l.listener, err = tls.Listen(u.Scheme, u.Host, tlsCfg)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (l *streamListener) setupUnix(u *url.URL, tlsCfg *tls.Config, socketMode string) error {
|
||||
l.path = filepath.FromSlash(u.Path)
|
||||
if runtime.GOOS == "windows" && strings.Contains(l.path, ":") {
|
||||
l.path = strings.TrimPrefix(l.path, `\`)
|
||||
}
|
||||
if err := os.Remove(l.path); err != nil && !errors.Is(err, os.ErrNotExist) {
|
||||
return fmt.Errorf("removing socket failed: %w", err)
|
||||
}
|
||||
|
||||
var err error
|
||||
if tlsCfg == nil {
|
||||
l.listener, err = net.Listen(u.Scheme, l.path)
|
||||
} else {
|
||||
l.listener, err = tls.Listen(u.Scheme, l.path, tlsCfg)
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Set permissions on socket
|
||||
if socketMode != "" {
|
||||
// Convert from octal in string to int
|
||||
i, err := strconv.ParseUint(socketMode, 8, 32)
|
||||
if err != nil {
|
||||
return fmt.Errorf("converting socket mode failed: %w", err)
|
||||
}
|
||||
|
||||
perm := os.FileMode(uint32(i))
|
||||
if err := os.Chmod(u.Path, perm); err != nil {
|
||||
return fmt.Errorf("changing socket permissions failed: %w", err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (l *streamListener) setupVsock(u *url.URL) error {
|
||||
var err error
|
||||
|
||||
addrTuple := strings.SplitN(u.String(), ":", 2)
|
||||
|
||||
// Check address string for containing two tokens
|
||||
if len(addrTuple) < 2 {
|
||||
return errors.New("port and/or CID number missing")
|
||||
}
|
||||
// Parse CID and port number from address string both being 32-bit
|
||||
// source: https://man7.org/linux/man-pages/man7/vsock.7.html
|
||||
cid, err := strconv.ParseUint(addrTuple[0], 10, 32)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse CID %s: %w", addrTuple[0], err)
|
||||
}
|
||||
if (cid >= uint64(math.Pow(2, 32))-1) && (cid <= 0) {
|
||||
return fmt.Errorf("value of CID %d is out of range", cid)
|
||||
}
|
||||
port, err := strconv.ParseUint(addrTuple[1], 10, 32)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse port number %s: %w", addrTuple[1], err)
|
||||
}
|
||||
if (port >= uint64(math.Pow(2, 32))-1) && (port <= 0) {
|
||||
return fmt.Errorf("port number %d is out of range", port)
|
||||
}
|
||||
|
||||
l.listener, err = vsock.Listen(uint32(port), nil)
|
||||
return err
|
||||
}
|
||||
|
||||
// setupConnection prepares a freshly accepted connection: it enforces the
// connection limit, applies the configured receive-buffer size and sets up
// TCP keep-alive where applicable. On limit violation the connection is
// closed and an error is returned.
func (l *streamListener) setupConnection(conn net.Conn) error {
	addr := conn.RemoteAddr().String()
	l.Lock()
	if l.MaxConnections > 0 && l.connections >= l.MaxConnections {
		l.Unlock()
		// Ignore the returned error as we cannot do anything about it anyway
		_ = conn.Close()
		return fmt.Errorf("unable to accept connection from %q: too many connections", addr)
	}
	l.connections++
	l.Unlock()

	if l.ReadBufferSize > 0 {
		if rb, ok := conn.(hasSetReadBuffer); ok {
			if err := rb.SetReadBuffer(l.ReadBufferSize); err != nil {
				l.Log.Warnf("Setting read buffer on socket failed: %v", err)
			}
		} else {
			l.Log.Warn("Cannot set read buffer on socket of this type")
		}
	}

	// Set keep alive handlings
	if l.KeepAlivePeriod != nil {
		// Unwrap TLS connections to reach the underlying TCP connection
		if c, ok := conn.(*tls.Conn); ok {
			conn = c.NetConn()
		}
		tcpConn, ok := conn.(*net.TCPConn)
		if !ok {
			// NOTE(review): tcpConn is nil here but the calls below are still
			// executed; this relies on the net package returning an error for
			// nil receivers (logged as warnings) rather than skipping the
			// keep-alive setup — verify this is intended.
			l.Log.Warnf("connection not a TCP connection (%T)", conn)
		}
		// A zero period disables keep-alive, any other value enables it with
		// the given period
		if *l.KeepAlivePeriod == 0 {
			if err := tcpConn.SetKeepAlive(false); err != nil {
				l.Log.Warnf("Cannot set keep-alive: %v", err)
			}
		} else {
			if err := tcpConn.SetKeepAlive(true); err != nil {
				l.Log.Warnf("Cannot set keep-alive: %v", err)
			}
			err := tcpConn.SetKeepAlivePeriod(time.Duration(*l.KeepAlivePeriod))
			if err != nil {
				l.Log.Warnf("Cannot set keep-alive period: %v", err)
			}
		}
	}

	return nil
}
|
||||
|
||||
// closeConnection closes the given connection and decrements the active
// connection counter.
func (l *streamListener) closeConnection(conn net.Conn) {
	// Fallback to enforce blocked reads on connections to end immediately
	//nolint:errcheck // Ignore errors as this is a fallback only
	conn.SetReadDeadline(time.Now())

	addr := conn.RemoteAddr().String()
	if err := conn.Close(); err != nil && !errors.Is(err, net.ErrClosed) && !errors.Is(err, syscall.EPIPE) {
		l.Log.Warnf("Cannot close connection to %q: %v", addr, err)
	} else {
		// NOTE(review): closing an already-closed connection yields
		// net.ErrClosed and also takes this branch, so invoking
		// closeConnection twice on the same connection decrements the counter
		// twice — verify every caller closes each connection exactly once.
		l.Lock()
		l.connections--
		l.Unlock()
	}
}
|
||||
|
||||
// address returns the local address the listener is bound to.
func (l *streamListener) address() net.Addr {
	return l.listener.Addr()
}
|
||||
|
||||
// close shuts down the listener, cancels all connection handlers and waits
// for them to finish. For unix-domain sockets the socket file is removed.
func (l *streamListener) close() error {
	if l.listener != nil {
		// Continue even if we cannot close the listener in order to at least
		// close all active connections
		if err := l.listener.Close(); err != nil {
			l.Log.Errorf("Cannot close listener: %v", err)
		}
	}

	// Cancel the handler contexts; the AfterFunc hooks installed by the
	// handlers then close the remaining connections
	if l.cancel != nil {
		l.cancel()
		l.cancel = nil
	}
	l.wg.Wait()

	// Remove the socket file of unix-domain sockets
	if l.path != "" {
		fn := filepath.FromSlash(l.path)
		// On Windows paths like "/C:/foo.sock" carry a bogus leading slash
		if runtime.GOOS == "windows" && strings.Contains(fn, ":") {
			fn = strings.TrimPrefix(fn, `\`)
		}
		// Ignore file-not-exists errors when removing the socket
		if err := os.Remove(fn); err != nil && !errors.Is(err, os.ErrNotExist) {
			return err
		}
	}

	// Wait for all pending parse jobs to finish
	l.parsePool.StopAndWait()

	return nil
}
|
||||
|
||||
// listenData accepts connections in a background goroutine and reads split
// messages from each connection, forwarding them to the onData callback.
// Errors are reported through the optional onError callback.
func (l *streamListener) listenData(onData CallbackData, onError CallbackError) {
	ctx, cancel := context.WithCancel(context.Background())
	l.cancel = cancel

	l.wg.Add(1)
	go func() {
		defer l.wg.Done()

		for {
			conn, err := l.listener.Accept()
			if err != nil {
				// Accept fails with net.ErrClosed during shutdown; do not
				// report that as an error
				if !errors.Is(err, net.ErrClosed) && onError != nil {
					onError(err)
				}
				break
			}

			// setupConnection closes the connection itself on failure
			if err := l.setupConnection(conn); err != nil && onError != nil {
				onError(err)
				continue
			}

			l.wg.Add(1)
			go l.handleReaderConn(ctx, conn, onData, onError)
		}
	}()
}
|
||||
|
||||
// handleReaderConn reads from a single connection until it is closed or a
// read error occurs, forwarding messages to onData. The connection is closed
// when the handler exits or the given context is canceled.
func (l *streamListener) handleReaderConn(ctx context.Context, conn net.Conn, onData CallbackData, onError CallbackError) {
	defer l.wg.Done()

	// Make sure the connection is closed both on exit and on cancellation of
	// the surrounding context (e.g. listener shutdown)
	localCtx, cancel := context.WithCancel(ctx)
	defer cancel()
	defer l.closeConnection(conn)
	stopFunc := context.AfterFunc(localCtx, func() { l.closeConnection(conn) })
	defer stopFunc()

	// Without a splitter the whole stream is read as a single message
	reader := l.read
	if l.Splitter == nil {
		reader = l.readAll
	}
	if err := reader(conn, onData); err != nil {
		// EOF and connection resets are part of the normal connection life
		// cycle and are not reported
		if !errors.Is(err, io.EOF) && !errors.Is(err, syscall.ECONNRESET) {
			if onError != nil {
				onError(err)
			}
		}
	}
}
|
||||
|
||||
// listenConnection accepts connections in a background goroutine and hands
// the raw data stream of each connection to the onConnection callback.
// Errors are reported through the optional onError callback.
func (l *streamListener) listenConnection(onConnection CallbackConnection, onError CallbackError) {
	ctx, cancel := context.WithCancel(context.Background())
	l.cancel = cancel

	l.wg.Add(1)
	go func() {
		defer l.wg.Done()

		for {
			conn, err := l.listener.Accept()
			if err != nil {
				// Accept fails with net.ErrClosed during shutdown; do not
				// report that as an error
				if !errors.Is(err, net.ErrClosed) && onError != nil {
					onError(err)
				}
				break
			}
			// setupConnection closes the connection itself on failure
			if err := l.setupConnection(conn); err != nil && onError != nil {
				onError(err)
				continue
			}

			// handleConnection calls l.wg.Done() when finished
			l.wg.Add(1)
			go func(c net.Conn) {
				if err := l.handleConnection(ctx, c, onConnection); err != nil {
					// EOF and connection resets are part of the normal
					// connection life cycle and are not reported
					if !errors.Is(err, io.EOF) && !errors.Is(err, syscall.ECONNRESET) {
						if onError != nil {
							onError(err)
						}
					}
				}
			}(conn)
		}
	}()
}
|
||||
|
||||
// read reads from the connection, splits the decoded stream into messages
// using the configured splitter and hands each message to the onData callback
// via the parser worker pool. It returns when the connection is closed, the
// read deadline expires or an unrecoverable error occurs.
func (l *streamListener) read(conn net.Conn, onData CallbackData) error {
	decoder, err := internal.NewStreamContentDecoder(l.Encoding, conn)
	if err != nil {
		return fmt.Errorf("creating decoder failed: %w", err)
	}

	timeout := time.Duration(l.ReadTimeout)

	scanner := bufio.NewScanner(decoder)
	// Grow the scanner buffer when the user configured reads larger than the
	// default maximum token size
	if l.ReadBufferSize > bufio.MaxScanTokenSize {
		scanner.Buffer(make([]byte, l.ReadBufferSize), l.ReadBufferSize)
	}
	scanner.Split(l.Splitter)
	for {
		// Set the read deadline, if any, then start reading. The read
		// will accept the deadline and return if no or insufficient data
		// arrived in time. We need to set the deadline in every cycle as
		// it is an ABSOLUTE time and not a timeout.
		if timeout > 0 {
			deadline := time.Now().Add(timeout)
			if err := conn.SetReadDeadline(deadline); err != nil {
				return fmt.Errorf("setting read deadline failed: %w", err)
			}
		}
		if !scanner.Scan() {
			// Exit if no data arrived e.g. due to timeout or closed connection
			break
		}

		receiveTime := time.Now()
		src := conn.RemoteAddr()
		// For unix-domain sockets report the socket path as source address
		if l.path != "" {
			src = &net.UnixAddr{Name: l.path, Net: "unix"}
		}

		// Copy the message out of the scanner's buffer as the scanner reuses
		// it on the next Scan call while the pool processes asynchronously
		data := scanner.Bytes()
		d := make([]byte, len(data))
		copy(d, data)
		l.parsePool.Submit(func() {
			onData(src, d, receiveTime)
		})
	}

	if err := scanner.Err(); err != nil {
		if errors.Is(err, os.ErrDeadlineExceeded) {
			// Ignore the timeout and silently close the connection
			l.Log.Debug(err)
			return nil
		}
		if errors.Is(err, net.ErrClosed) {
			// Ignore the connection closing of the remote side
			return nil
		}
		return err
	}
	return nil
}
|
||||
|
||||
// readAll reads the connection to EOF as one single message and passes the
// decoded content to the onData callback. Used when no splitter is
// configured.
func (l *streamListener) readAll(conn net.Conn, onData CallbackData) error {
	src := conn.RemoteAddr()
	// For unix-domain sockets report the socket path as source address
	if l.path != "" {
		src = &net.UnixAddr{Name: l.path, Net: "unix"}
	}

	decoder, err := internal.NewStreamContentDecoder(l.Encoding, conn)
	if err != nil {
		return fmt.Errorf("creating decoder failed: %w", err)
	}

	timeout := time.Duration(l.ReadTimeout)
	// Set the read deadline, if any, then start reading. The read
	// will accept the deadline and return if no or insufficient data
	// arrived in time. We need to set the deadline in every cycle as
	// it is an ABSOLUTE time and not a timeout.
	if timeout > 0 {
		deadline := time.Now().Add(timeout)
		if err := conn.SetReadDeadline(deadline); err != nil {
			return fmt.Errorf("setting read deadline failed: %w", err)
		}
	}
	buf, err := io.ReadAll(decoder)
	if err != nil {
		return fmt.Errorf("read on %s failed: %w", src, err)
	}

	receiveTime := time.Now()
	l.parsePool.Submit(func() {
		onData(src, buf, receiveTime)
	})

	return nil
}
|
||||
|
||||
// handleConnection streams the raw, decoded data of a connection to the
// onConnection callback through an io.Pipe. It returns when the connection is
// closed, the read deadline expires or the pipe reader gives up. The caller
// must have incremented l.wg before starting this handler.
func (l *streamListener) handleConnection(ctx context.Context, conn net.Conn, onConnection CallbackConnection) error {
	defer l.wg.Done()

	// Make sure the connection is closed both on exit and on cancellation of
	// the surrounding context (e.g. listener shutdown)
	localCtx, cancel := context.WithCancel(ctx)
	defer cancel()
	defer l.closeConnection(conn)
	stopFunc := context.AfterFunc(localCtx, func() { l.closeConnection(conn) })
	defer stopFunc()

	// Prepare the data decoder for the connection
	decoder, err := internal.NewStreamContentDecoder(l.Encoding, conn)
	if err != nil {
		return fmt.Errorf("creating decoder failed: %w", err)
	}

	// Get the remote address
	src := conn.RemoteAddr()
	if l.path != "" {
		src = &net.UnixAddr{Name: l.path, Net: "unix"}
	}

	// Create a pipe and feed it to the callback
	reader, writer := io.Pipe()
	defer writer.Close()
	go onConnection(src, reader)

	timeout := time.Duration(l.ReadTimeout)
	buf := make([]byte, 4096) // 4kb
	for {
		// Set the read deadline, if any, then start reading. The read
		// will accept the deadline and return if no or insufficient data
		// arrived in time. We need to set the deadline in every cycle as
		// it is an ABSOLUTE time and not a timeout.
		if timeout > 0 {
			deadline := time.Now().Add(timeout)
			if err := conn.SetReadDeadline(deadline); err != nil {
				return fmt.Errorf("setting read deadline failed: %w", err)
			}
		}

		// Copy the data
		n, err := decoder.Read(buf)
		if err != nil {
			// NOTE(review): the inner condition forwards the error to the
			// pipe reader only when err IS net.ErrClosed (and not a
			// deadline); this looks inverted — verify whether unexpected
			// read errors should be the ones propagated via CloseWithError.
			if !strings.HasSuffix(err.Error(), ": use of closed network connection") {
				if !errors.Is(err, os.ErrDeadlineExceeded) && errors.Is(err, net.ErrClosed) {
					writer.CloseWithError(err)
				}
			}
			return nil
		}
		if _, err := writer.Write(buf[:n]); err != nil {
			return err
		}
	}
}
|
305
plugins/common/starlark/builtins.go
Normal file
305
plugins/common/starlark/builtins.go
Normal file
|
@ -0,0 +1,305 @@
|
|||
package starlark
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sort"
|
||||
"time"
|
||||
|
||||
"go.starlark.net/starlark"
|
||||
|
||||
"github.com/influxdata/telegraf"
|
||||
"github.com/influxdata/telegraf/metric"
|
||||
)
|
||||
|
||||
// newMetric implements the Starlark built-in Metric(name, tags, fields)
// creating a new metric with the current time as timestamp.
func newMetric(_ *starlark.Thread, _ *starlark.Builtin, args starlark.Tuple, kwargs []starlark.Tuple) (starlark.Value, error) {
	var (
		name         starlark.String
		tags, fields starlark.Value
	)
	if err := starlark.UnpackArgs("Metric", args, kwargs, "name", &name, "tags?", &tags, "fields?", &fields); err != nil {
		return nil, err
	}

	// Convert the Starlark mappings into native Go maps
	allFields, err := toFields(fields)
	if err != nil {
		return nil, err
	}
	allTags, err := toTags(tags)
	if err != nil {
		return nil, err
	}

	m := metric.New(string(name), allTags, allFields, time.Now())

	return &Metric{metric: m}, nil
}
|
||||
|
||||
func toString(value starlark.Value, errorMsg string) (string, error) {
|
||||
if value, ok := value.(starlark.String); ok {
|
||||
return string(value), nil
|
||||
}
|
||||
return "", fmt.Errorf(errorMsg, value)
|
||||
}
|
||||
|
||||
func items(value starlark.Value, errorMsg string) ([]starlark.Tuple, error) {
|
||||
if iter, ok := value.(starlark.IterableMapping); ok {
|
||||
return iter.Items(), nil
|
||||
}
|
||||
return nil, fmt.Errorf(errorMsg, value)
|
||||
}
|
||||
|
||||
// deepcopy implements the Starlark built-in deepcopy(source, track=False)
// returning a deep copy of the given metric. Tracking information is only
// preserved when 'track' is set to true.
func deepcopy(_ *starlark.Thread, _ *starlark.Builtin, args starlark.Tuple, kwargs []starlark.Tuple) (starlark.Value, error) {
	var sm *Metric
	var track bool
	if err := starlark.UnpackArgs("deepcopy", args, kwargs, "source", &sm, "track?", &track); err != nil {
		return nil, err
	}

	// In case we copy a tracking metric but do not want to track the result,
	// we have to strip the tracking information. This can be done by unwrapping
	// the metric.
	if tm, ok := sm.metric.(telegraf.TrackingMetric); ok && !track {
		return &Metric{metric: tm.Unwrap().Copy()}, nil
	}

	// Copy the whole metric including potential tracking information
	return &Metric{metric: sm.metric.Copy()}, nil
}
|
||||
|
||||
// catch(f) evaluates f() and returns its evaluation error message
// if it failed or None if it succeeded.
func catch(thread *starlark.Thread, _ *starlark.Builtin, args starlark.Tuple, kwargs []starlark.Tuple) (starlark.Value, error) {
	var fn starlark.Callable
	if err := starlark.UnpackArgs("catch", args, kwargs, "fn", &fn); err != nil {
		return nil, err
	}
	if _, err := starlark.Call(thread, fn, nil, nil); err != nil {
		//nolint:nilerr // nil returned on purpose, error put inside starlark.Value
		return starlark.String(err.Error()), nil
	}
	return starlark.None, nil
}
|
||||
|
||||
// builtinMethod is the signature of a method implementation usable with
// builtinAttr; the receiver is available via b.Receiver().
type builtinMethod func(b *starlark.Builtin, args starlark.Tuple, kwargs []starlark.Tuple) (starlark.Value, error)

// builtinAttr looks up the named method in the given method table and binds
// it to the receiver, returning it as a callable Starlark value.
func builtinAttr(recv starlark.Value, name string, methods map[string]builtinMethod) (starlark.Value, error) {
	method := methods[name]
	if method == nil {
		return starlark.None, fmt.Errorf("no such method %q", name)
	}

	// Allocate a closure over 'method'.
	impl := func(_ *starlark.Thread, b *starlark.Builtin, args starlark.Tuple, kwargs []starlark.Tuple) (starlark.Value, error) {
		return method(b, args, kwargs)
	}
	return starlark.NewBuiltin(name, impl).BindReceiver(recv), nil
}
|
||||
|
||||
// builtinAttrNames returns the sorted names of all methods in the table.
func builtinAttrNames(methods map[string]builtinMethod) []string {
	names := make([]string, 0, len(methods))
	for name := range methods {
		names = append(names, name)
	}
	sort.Strings(names)
	return names
}
|
||||
|
||||
// --- dictionary methods ---
|
||||
|
||||
// dictClear implements dict.clear() removing all entries from the receiver.
// https://github.com/google/starlark-go/blob/master/doc/spec.md#dict·clear
func dictClear(b *starlark.Builtin, args starlark.Tuple, kwargs []starlark.Tuple) (starlark.Value, error) {
	if err := starlark.UnpackPositionalArgs(b.Name(), args, kwargs, 0); err != nil {
		return starlark.None, fmt.Errorf("%s: %w", b.Name(), err)
	}

	type HasClear interface {
		Clear() error
	}
	return starlark.None, b.Receiver().(HasClear).Clear()
}
|
||||
|
||||
// dictPop implements dict.pop(key[, default]) removing and returning the
// value for the given key, the default if missing, or an error if neither.
// https://github.com/google/starlark-go/blob/master/doc/spec.md#dict·pop
func dictPop(b *starlark.Builtin, args starlark.Tuple, kwargs []starlark.Tuple) (starlark.Value, error) {
	var k, d starlark.Value
	if err := starlark.UnpackPositionalArgs(b.Name(), args, kwargs, 1, &k, &d); err != nil {
		return starlark.None, fmt.Errorf("%s: %w", b.Name(), err)
	}

	type HasDelete interface {
		Delete(k starlark.Value) (starlark.Value, bool, error)
	}
	if v, found, err := b.Receiver().(HasDelete).Delete(k); err != nil {
		return starlark.None, fmt.Errorf("%s: %w", b.Name(), err) // dict is frozen or key is unhashable
	} else if found {
		return v, nil
	} else if d != nil {
		return d, nil
	}
	return starlark.None, fmt.Errorf("%s: missing key", b.Name())
}
|
||||
|
||||
// dictPopitem implements dict.popitem() removing and returning an arbitrary
// key/value pair from the receiver.
// https://github.com/google/starlark-go/blob/master/doc/spec.md#dict·popitem
func dictPopitem(b *starlark.Builtin, args starlark.Tuple, kwargs []starlark.Tuple) (starlark.Value, error) {
	if err := starlark.UnpackPositionalArgs(b.Name(), args, kwargs, 0); err != nil {
		return starlark.None, fmt.Errorf("%s: %w", b.Name(), err)
	}

	type HasPopItem interface {
		PopItem() (starlark.Value, error)
	}
	return b.Receiver().(HasPopItem).PopItem()
}
|
||||
|
||||
// dictGet implements dict.get(key[, default]) returning the value for the
// given key, the default if missing, or None if neither is available.
// https://github.com/google/starlark-go/blob/master/doc/spec.md#dict·get
func dictGet(b *starlark.Builtin, args starlark.Tuple, kwargs []starlark.Tuple) (starlark.Value, error) {
	var key, dflt starlark.Value
	if err := starlark.UnpackPositionalArgs(b.Name(), args, kwargs, 1, &key, &dflt); err != nil {
		return starlark.None, fmt.Errorf("%s: %w", b.Name(), err)
	}
	if v, ok, err := b.Receiver().(starlark.Mapping).Get(key); err != nil {
		return starlark.None, fmt.Errorf("%s: %w", b.Name(), err)
	} else if ok {
		return v, nil
	} else if dflt != nil {
		return dflt, nil
	}
	return starlark.None, nil
}
|
||||
|
||||
// dictSetdefault implements dict.setdefault(key[, default]) returning the
// value for the key and inserting the default (None if unset) when missing.
// https://github.com/google/starlark-go/blob/master/doc/spec.md#dict·setdefault
func dictSetdefault(b *starlark.Builtin, args starlark.Tuple, kwargs []starlark.Tuple) (starlark.Value, error) {
	var key, dflt starlark.Value = nil, starlark.None
	if err := starlark.UnpackPositionalArgs(b.Name(), args, kwargs, 1, &key, &dflt); err != nil {
		return starlark.None, fmt.Errorf("%s: %w", b.Name(), err)
	}

	recv := b.Receiver().(starlark.HasSetKey)
	v, found, err := recv.Get(key)
	if err != nil {
		return starlark.None, fmt.Errorf("%s: %w", b.Name(), err)
	}
	// Insert the default value when the key is absent
	if !found {
		v = dflt
		if err := recv.SetKey(key, dflt); err != nil {
			return starlark.None, fmt.Errorf("%s: %w", b.Name(), err)
		}
	}
	return v, nil
}
|
||||
|
||||
// dictUpdate implements dict.update([pairs], **kwargs) merging a mapping or a
// sequence of key/value pairs plus any keyword arguments into the receiver.
// https://github.com/google/starlark-go/blob/master/doc/spec.md#dict·update
func dictUpdate(b *starlark.Builtin, args starlark.Tuple, kwargs []starlark.Tuple) (starlark.Value, error) {
	// Unpack the arguments
	if len(args) > 1 {
		return nil, fmt.Errorf("update: got %d arguments, want at most 1", len(args))
	}

	// Get the target
	dict := b.Receiver().(starlark.HasSetKey)

	if len(args) == 1 {
		switch updates := args[0].(type) {
		case starlark.IterableMapping:
			// Iterate over dict's key/value pairs, not just keys.
			for _, item := range updates.Items() {
				if err := dict.SetKey(item[0], item[1]); err != nil {
					return nil, err // dict is frozen
				}
			}
		default:
			// all other sequences
			iter := starlark.Iterate(updates)
			if iter == nil {
				return nil, fmt.Errorf("got %s, want iterable", updates.Type())
			}
			defer iter.Done()
			var pair starlark.Value
			for i := 0; iter.Next(&pair); i++ {
				// Each element must itself be an iterable of exactly two
				// values; the closure scopes the inner iterator's Done call
				// to a single element.
				iterErr := func() error {
					iter2 := starlark.Iterate(pair)
					if iter2 == nil {
						return fmt.Errorf("dictionary update sequence element #%d is not iterable (%s)", i, pair.Type())
					}
					defer iter2.Done()
					length := starlark.Len(pair)
					if length < 0 {
						return fmt.Errorf("dictionary update sequence element #%d has unknown length (%s)", i, pair.Type())
					} else if length != 2 {
						return fmt.Errorf("dictionary update sequence element #%d has length %d, want 2", i, length)
					}
					var k, v starlark.Value
					iter2.Next(&k)
					iter2.Next(&v)

					return dict.SetKey(k, v)
				}()

				if iterErr != nil {
					return nil, iterErr
				}
			}
		}
	}

	// Then add the kwargs.
	before := starlark.Len(dict)
	for _, pair := range kwargs {
		if err := dict.SetKey(pair[0], pair[1]); err != nil {
			return nil, err // dict is frozen
		}
	}
	// In the common case, each kwarg will add another dict entry.
	// If that's not so, check whether it is because there was a duplicate kwarg.
	if starlark.Len(dict) < before+len(kwargs) {
		keys := make(map[starlark.String]bool, len(kwargs))
		for _, kv := range kwargs {
			k := kv[0].(starlark.String)
			if keys[k] {
				return nil, fmt.Errorf("duplicate keyword arg: %v", k)
			}
			keys[k] = true
		}
	}

	return starlark.None, nil
}
|
||||
|
||||
// https://github.com/google/starlark-go/blob/master/doc/spec.md#dict·items
|
||||
func dictItems(b *starlark.Builtin, args starlark.Tuple, kwargs []starlark.Tuple) (starlark.Value, error) {
|
||||
if err := starlark.UnpackPositionalArgs(b.Name(), args, kwargs, 0); err != nil {
|
||||
return starlark.None, fmt.Errorf("%s: %w", b.Name(), err)
|
||||
}
|
||||
items := b.Receiver().(starlark.IterableMapping).Items()
|
||||
res := make([]starlark.Value, 0, len(items))
|
||||
for _, item := range items {
|
||||
res = append(res, item) // convert [2]starlark.Value to starlark.Value
|
||||
}
|
||||
return starlark.NewList(res), nil
|
||||
}
|
||||
|
||||
// https://github.com/google/starlark-go/blob/master/doc/spec.md#dict·keys
|
||||
func dictKeys(b *starlark.Builtin, args starlark.Tuple, kwargs []starlark.Tuple) (starlark.Value, error) {
|
||||
if err := starlark.UnpackPositionalArgs(b.Name(), args, kwargs, 0); err != nil {
|
||||
return starlark.None, fmt.Errorf("%s: %w", b.Name(), err)
|
||||
}
|
||||
|
||||
items := b.Receiver().(starlark.IterableMapping).Items()
|
||||
res := make([]starlark.Value, 0, len(items))
|
||||
for _, item := range items {
|
||||
res = append(res, item[0])
|
||||
}
|
||||
return starlark.NewList(res), nil
|
||||
}
|
||||
|
||||
// https://github.com/google/starlark-go/blob/master/doc/spec.md#dict·values
//
// dictValues returns a list of the values of the receiver mapping.
// (The spec link previously pointed at dict·update by mistake.)
func dictValues(b *starlark.Builtin, args starlark.Tuple, kwargs []starlark.Tuple) (starlark.Value, error) {
	// values() accepts no arguments.
	if err := starlark.UnpackPositionalArgs(b.Name(), args, kwargs, 0); err != nil {
		return starlark.None, fmt.Errorf("%s: %w", b.Name(), err)
	}
	// Collect the value of every key/value pair in insertion order.
	items := b.Receiver().(starlark.IterableMapping).Items()
	res := make([]starlark.Value, 0, len(items))
	for _, item := range items {
		res = append(res, item[1])
	}
	return starlark.NewList(res), nil
}
|
308
plugins/common/starlark/field_dict.go
Normal file
308
plugins/common/starlark/field_dict.go
Normal file
|
@ -0,0 +1,308 @@
|
|||
package starlark
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
|
||||
"go.starlark.net/starlark"
|
||||
|
||||
"github.com/influxdata/telegraf"
|
||||
)
|
||||
|
||||
// FieldDict is a starlark.Value for the metric fields. It is heavily based on the
|
||||
// starlark.Dict.
|
||||
type FieldDict struct {
|
||||
*Metric
|
||||
}
|
||||
|
||||
func (d FieldDict) String() string {
|
||||
buf := new(strings.Builder)
|
||||
buf.WriteString("{")
|
||||
sep := ""
|
||||
for _, item := range d.Items() {
|
||||
k, v := item[0], item[1]
|
||||
buf.WriteString(sep)
|
||||
buf.WriteString(k.String())
|
||||
buf.WriteString(": ")
|
||||
buf.WriteString(v.String())
|
||||
sep = ", "
|
||||
}
|
||||
buf.WriteString("}")
|
||||
return buf.String()
|
||||
}
|
||||
|
||||
func (FieldDict) Type() string {
|
||||
return "Fields"
|
||||
}
|
||||
|
||||
func (d FieldDict) Freeze() {
|
||||
// Disable linter check as the frozen variable is modified despite
|
||||
// passing a value instead of a pointer, because `FieldDict` holds
|
||||
// a pointer to the underlying metric containing the `frozen` field.
|
||||
//revive:disable:modifies-value-receiver
|
||||
d.frozen = true
|
||||
}
|
||||
|
||||
func (d FieldDict) Truth() starlark.Bool {
|
||||
return len(d.metric.FieldList()) != 0
|
||||
}
|
||||
|
||||
func (FieldDict) Hash() (uint32, error) {
|
||||
return 0, errors.New("not hashable")
|
||||
}
|
||||
|
||||
// AttrNames implements the starlark.HasAttrs interface.
|
||||
func (FieldDict) AttrNames() []string {
|
||||
return builtinAttrNames(FieldDictMethods)
|
||||
}
|
||||
|
||||
// Attr implements the starlark.HasAttrs interface.
|
||||
func (d FieldDict) Attr(name string) (starlark.Value, error) {
|
||||
return builtinAttr(d, name, FieldDictMethods)
|
||||
}
|
||||
|
||||
// FieldDictMethods maps the dict-style method names available on a FieldDict
// to their shared builtin implementations.
var FieldDictMethods = map[string]builtinMethod{
	"clear":      dictClear,
	"get":        dictGet,
	"items":      dictItems,
	"keys":       dictKeys,
	"pop":        dictPop,
	"popitem":    dictPopitem,
	"setdefault": dictSetdefault,
	"update":     dictUpdate,
	"values":     dictValues,
}
|
||||
|
||||
// Get implements the starlark.Mapping interface.
|
||||
func (d FieldDict) Get(key starlark.Value) (v starlark.Value, found bool, err error) {
|
||||
if k, ok := key.(starlark.String); ok {
|
||||
gv, found := d.metric.GetField(k.GoString())
|
||||
if !found {
|
||||
return starlark.None, false, nil
|
||||
}
|
||||
|
||||
v, err := asStarlarkValue(gv)
|
||||
if err != nil {
|
||||
return starlark.None, false, err
|
||||
}
|
||||
return v, true, nil
|
||||
}
|
||||
|
||||
return starlark.None, false, errors.New("key must be of type 'str'")
|
||||
}
|
||||
|
||||
// SetKey implements the starlark.HasSetKey interface to support map update
|
||||
// using x[k]=v syntax, like a dictionary.
|
||||
func (d FieldDict) SetKey(k, v starlark.Value) error {
|
||||
if d.fieldIterCount > 0 {
|
||||
return errors.New("cannot insert during iteration")
|
||||
}
|
||||
|
||||
key, ok := k.(starlark.String)
|
||||
if !ok {
|
||||
return errors.New("field key must be of type 'str'")
|
||||
}
|
||||
|
||||
gv, err := asGoValue(v)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
d.metric.AddField(key.GoString(), gv)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Items implements the starlark.IterableMapping interface.
|
||||
func (d FieldDict) Items() []starlark.Tuple {
|
||||
items := make([]starlark.Tuple, 0, len(d.metric.FieldList()))
|
||||
for _, field := range d.metric.FieldList() {
|
||||
key := starlark.String(field.Key)
|
||||
sv, err := asStarlarkValue(field.Value)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
pair := starlark.Tuple{key, sv}
|
||||
items = append(items, pair)
|
||||
}
|
||||
return items
|
||||
}
|
||||
|
||||
func (d FieldDict) Clear() error {
|
||||
if d.fieldIterCount > 0 {
|
||||
return errors.New("cannot delete during iteration")
|
||||
}
|
||||
|
||||
keys := make([]string, 0, len(d.metric.FieldList()))
|
||||
for _, field := range d.metric.FieldList() {
|
||||
keys = append(keys, field.Key)
|
||||
}
|
||||
|
||||
for _, key := range keys {
|
||||
d.metric.RemoveField(key)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d FieldDict) PopItem() (starlark.Value, error) {
|
||||
if d.fieldIterCount > 0 {
|
||||
return nil, errors.New("cannot delete during iteration")
|
||||
}
|
||||
|
||||
if len(d.metric.FieldList()) == 0 {
|
||||
return nil, errors.New("popitem(): field dictionary is empty")
|
||||
}
|
||||
|
||||
field := d.metric.FieldList()[0]
|
||||
k := field.Key
|
||||
v := field.Value
|
||||
|
||||
d.metric.RemoveField(k)
|
||||
|
||||
sk := starlark.String(k)
|
||||
sv, err := asStarlarkValue(v)
|
||||
if err != nil {
|
||||
return nil, errors.New("could not convert to starlark value")
|
||||
}
|
||||
|
||||
return starlark.Tuple{sk, sv}, nil
|
||||
}
|
||||
|
||||
func (d FieldDict) Delete(k starlark.Value) (v starlark.Value, found bool, err error) {
|
||||
if d.fieldIterCount > 0 {
|
||||
return nil, false, errors.New("cannot delete during iteration")
|
||||
}
|
||||
|
||||
if key, ok := k.(starlark.String); ok {
|
||||
value, ok := d.metric.GetField(key.GoString())
|
||||
if ok {
|
||||
d.metric.RemoveField(key.GoString())
|
||||
sv, err := asStarlarkValue(value)
|
||||
return sv, ok, err
|
||||
}
|
||||
return starlark.None, false, nil
|
||||
}
|
||||
|
||||
return starlark.None, false, errors.New("key must be of type 'str'")
|
||||
}
|
||||
|
||||
// Iterate implements the starlark.Iterator interface.
|
||||
func (d FieldDict) Iterate() starlark.Iterator {
|
||||
d.fieldIterCount++
|
||||
return &FieldIterator{Metric: d.Metric, fields: d.metric.FieldList()}
|
||||
}
|
||||
|
||||
type FieldIterator struct {
|
||||
*Metric
|
||||
fields []*telegraf.Field
|
||||
}
|
||||
|
||||
// Next implements the starlark.Iterator interface.
|
||||
func (i *FieldIterator) Next(p *starlark.Value) bool {
|
||||
if len(i.fields) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
field := i.fields[0]
|
||||
i.fields = i.fields[1:]
|
||||
*p = starlark.String(field.Key)
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// Done implements the starlark.Iterator interface.
|
||||
func (i *FieldIterator) Done() {
|
||||
i.fieldIterCount--
|
||||
}
|
||||
|
||||
// asStarlarkValue converts a Go field value to a starlark.Value.
// Slices and maps are converted recursively; any value whose kind is not
// a slice, map, float, integer, string, or bool yields an error.
func asStarlarkValue(value interface{}) (starlark.Value, error) {
	v := reflect.ValueOf(value)
	switch v.Kind() {
	case reflect.Slice:
		// Recursively convert every element into a starlark list.
		length := v.Len()
		array := make([]starlark.Value, 0, length)
		for i := 0; i < length; i++ {
			sVal, err := asStarlarkValue(v.Index(i).Interface())
			if err != nil {
				return starlark.None, err
			}
			array = append(array, sVal)
		}
		return starlark.NewList(array), nil
	case reflect.Map:
		// Recursively convert keys and values into a starlark dict.
		dict := starlark.NewDict(v.Len())
		iter := v.MapRange()
		for iter.Next() {
			sKey, err := asStarlarkValue(iter.Key().Interface())
			if err != nil {
				return starlark.None, err
			}
			sValue, err := asStarlarkValue(iter.Value().Interface())
			if err != nil {
				return starlark.None, err
			}
			if err := dict.SetKey(sKey, sValue); err != nil {
				return starlark.None, err
			}
		}
		return dict, nil
	case reflect.Float32, reflect.Float64:
		return starlark.Float(v.Float()), nil
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		return starlark.MakeInt64(v.Int()), nil
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
		return starlark.MakeUint64(v.Uint()), nil
	case reflect.String:
		return starlark.String(v.String()), nil
	case reflect.Bool:
		return starlark.Bool(v.Bool()), nil
	}

	return nil, fmt.Errorf("invalid type %T", value)
}
|
||||
|
||||
// asGoValue converts a starlark.Value back to its native Go representation.
// Only scalar starlark types (float, int, string, bool) are supported;
// collections and all other types yield an error.
func asGoValue(value interface{}) (interface{}, error) {
	switch v := value.(type) {
	case starlark.Float:
		return float64(v), nil
	case starlark.Int:
		// Reject big integers that do not fit into an int64.
		n, ok := v.Int64()
		if !ok {
			return nil, fmt.Errorf("cannot represent integer %v as int64", v)
		}
		return n, nil
	case starlark.String:
		return string(v), nil
	case starlark.Bool:
		return bool(v), nil
	}

	return nil, fmt.Errorf("invalid starlark type %T", value)
}
|
||||
|
||||
// ToFields converts a starlark.Value to a map of values.
|
||||
func toFields(value starlark.Value) (map[string]interface{}, error) {
|
||||
if value == nil {
|
||||
return nil, nil
|
||||
}
|
||||
items, err := items(value, "The type %T is unsupported as type of collection of fields")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result := make(map[string]interface{}, len(items))
|
||||
for _, item := range items {
|
||||
key, err := toString(item[0], "The type %T is unsupported as type of key for fields")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
value, err := asGoValue(item[1])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result[key] = value
|
||||
}
|
||||
return result, nil
|
||||
}
|
48
plugins/common/starlark/logging.go
Normal file
48
plugins/common/starlark/logging.go
Normal file
|
@ -0,0 +1,48 @@
|
|||
package starlark
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"go.starlark.net/starlark"
|
||||
"go.starlark.net/starlarkstruct"
|
||||
|
||||
"github.com/influxdata/telegraf"
|
||||
)
|
||||
|
||||
// Builds a module that defines all the supported logging functions which will log using the provided logger
|
||||
func LogModule(logger telegraf.Logger) *starlarkstruct.Module {
|
||||
var logFunc = func(_ *starlark.Thread, b *starlark.Builtin, args starlark.Tuple, kwargs []starlark.Tuple) (starlark.Value, error) {
|
||||
return log(b, args, kwargs, logger)
|
||||
}
|
||||
return &starlarkstruct.Module{
|
||||
Name: "log",
|
||||
Members: starlark.StringDict{
|
||||
"debug": starlark.NewBuiltin("log.debug", logFunc),
|
||||
"info": starlark.NewBuiltin("log.info", logFunc),
|
||||
"warn": starlark.NewBuiltin("log.warn", logFunc),
|
||||
"error": starlark.NewBuiltin("log.error", logFunc),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Logs the provided message according to the level chosen
|
||||
func log(b *starlark.Builtin, args starlark.Tuple, kwargs []starlark.Tuple, logger telegraf.Logger) (starlark.Value, error) {
|
||||
var msg starlark.String
|
||||
if err := starlark.UnpackPositionalArgs(b.Name(), args, kwargs, 1, &msg); err != nil {
|
||||
return starlark.None, fmt.Errorf("%s: %w", b.Name(), err)
|
||||
}
|
||||
switch b.Name() {
|
||||
case "log.debug":
|
||||
logger.Debug(string(msg))
|
||||
case "log.info":
|
||||
logger.Info(string(msg))
|
||||
case "log.warn":
|
||||
logger.Warn(string(msg))
|
||||
case "log.error":
|
||||
logger.Error(string(msg))
|
||||
default:
|
||||
return nil, errors.New("method " + b.Name() + " is unknown")
|
||||
}
|
||||
return starlark.None, nil
|
||||
}
|
156
plugins/common/starlark/metric.go
Normal file
156
plugins/common/starlark/metric.go
Normal file
|
@ -0,0 +1,156 @@
|
|||
package starlark
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"go.starlark.net/starlark"
|
||||
|
||||
"github.com/influxdata/telegraf"
|
||||
)
|
||||
|
||||
// Metric wraps a telegraf.Metric so it can be used as a starlark value.
type Metric struct {
	// ID holds the tracking ID if the wrapped metric is a tracking metric.
	ID telegraf.TrackingID
	// metric is the wrapped telegraf metric.
	metric telegraf.Metric
	// tagIterCount and fieldIterCount track the number of active iterators
	// so mutation during iteration can be rejected.
	tagIterCount   int
	fieldIterCount int
	// frozen marks the metric immutable after a starlark Freeze.
	frozen bool
}
|
||||
|
||||
// Wrap updates the starlark.Metric to wrap a new telegraf.Metric.
|
||||
func (m *Metric) Wrap(metric telegraf.Metric) {
|
||||
if tm, ok := metric.(telegraf.TrackingMetric); ok {
|
||||
m.ID = tm.TrackingID()
|
||||
}
|
||||
m.metric = metric
|
||||
m.tagIterCount = 0
|
||||
m.fieldIterCount = 0
|
||||
m.frozen = false
|
||||
}
|
||||
|
||||
// Unwrap removes the telegraf.Metric from the startlark.Metric.
|
||||
func (m *Metric) Unwrap() telegraf.Metric {
|
||||
return m.metric
|
||||
}
|
||||
|
||||
// String returns the starlark representation of the Metric.
|
||||
//
|
||||
// The String function is called by both the repr() and str() functions, and so
|
||||
// it behaves more like the repr function would in Python.
|
||||
func (m *Metric) String() string {
|
||||
buf := new(strings.Builder)
|
||||
buf.WriteString("Metric(")
|
||||
buf.WriteString(m.Name().String())
|
||||
buf.WriteString(", tags=")
|
||||
buf.WriteString(m.Tags().String())
|
||||
buf.WriteString(", fields=")
|
||||
buf.WriteString(m.Fields().String())
|
||||
buf.WriteString(", time=")
|
||||
buf.WriteString(m.Time().String())
|
||||
buf.WriteString(")")
|
||||
if m.ID != 0 {
|
||||
fmt.Fprintf(buf, "[tracking ID=%v]", m.ID)
|
||||
}
|
||||
return buf.String()
|
||||
}
|
||||
|
||||
// Type implements the starlark.Value interface.
func (*Metric) Type() string {
	return "Metric"
}

// Freeze marks the metric immutable for starlark.
func (m *Metric) Freeze() {
	m.frozen = true
}

// Truth reports a metric as always truthy.
func (*Metric) Truth() starlark.Bool {
	return true
}

// Hash implements the starlark.Value interface; metrics are not hashable.
func (*Metric) Hash() (uint32, error) {
	return 0, errors.New("not hashable")
}

// AttrNames implements the starlark.HasAttrs interface.
func (*Metric) AttrNames() []string {
	return []string{"name", "tags", "fields", "time"}
}
|
||||
|
||||
// Attr implements the starlark.HasAttrs interface.
|
||||
func (m *Metric) Attr(name string) (starlark.Value, error) {
|
||||
switch name {
|
||||
case "name":
|
||||
return m.Name(), nil
|
||||
case "tags":
|
||||
return m.Tags(), nil
|
||||
case "fields":
|
||||
return m.Fields(), nil
|
||||
case "time":
|
||||
return m.Time(), nil
|
||||
default:
|
||||
// Returning nil, nil indicates "no such field or method"
|
||||
return nil, nil
|
||||
}
|
||||
}
|
||||
|
||||
// SetField implements the starlark.HasSetField interface.
|
||||
func (m *Metric) SetField(name string, value starlark.Value) error {
|
||||
if m.frozen {
|
||||
return errors.New("cannot modify frozen metric")
|
||||
}
|
||||
|
||||
switch name {
|
||||
case "name":
|
||||
return m.SetName(value)
|
||||
case "time":
|
||||
return m.SetTime(value)
|
||||
case "tags":
|
||||
return errors.New("cannot set tags")
|
||||
case "fields":
|
||||
return errors.New("cannot set fields")
|
||||
default:
|
||||
return starlark.NoSuchAttrError(
|
||||
fmt.Sprintf("cannot assign to field %q", name))
|
||||
}
|
||||
}
|
||||
|
||||
// Name returns the metric measurement name as a starlark string.
func (m *Metric) Name() starlark.String {
	return starlark.String(m.metric.Name())
}

// SetName sets the measurement name; the value must be a starlark string.
func (m *Metric) SetName(value starlark.Value) error {
	if str, ok := value.(starlark.String); ok {
		m.metric.SetName(str.GoString())
		return nil
	}

	return errors.New("type error")
}

// Tags exposes the metric tags as a dict-like starlark value.
func (m *Metric) Tags() TagDict {
	return TagDict{m}
}

// Fields exposes the metric fields as a dict-like starlark value.
func (m *Metric) Fields() FieldDict {
	return FieldDict{m}
}

// Time returns the metric timestamp in nanoseconds since the Unix epoch.
func (m *Metric) Time() starlark.Int {
	return starlark.MakeInt64(m.metric.Time().UnixNano())
}
|
||||
|
||||
func (m *Metric) SetTime(value starlark.Value) error {
|
||||
switch v := value.(type) {
|
||||
case starlark.Int:
|
||||
ns, ok := v.Int64()
|
||||
if !ok {
|
||||
return errors.New("type error: unrepresentable time")
|
||||
}
|
||||
tm := time.Unix(0, ns)
|
||||
m.metric.SetTime(tm)
|
||||
return nil
|
||||
default:
|
||||
return errors.New("type error")
|
||||
}
|
||||
}
|
282
plugins/common/starlark/starlark.go
Normal file
282
plugins/common/starlark/starlark.go
Normal file
|
@ -0,0 +1,282 @@
|
|||
package starlark
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/gob"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"go.starlark.net/lib/json"
|
||||
"go.starlark.net/lib/math"
|
||||
"go.starlark.net/lib/time"
|
||||
"go.starlark.net/starlark"
|
||||
"go.starlark.net/syntax"
|
||||
|
||||
"github.com/influxdata/telegraf"
|
||||
)
|
||||
|
||||
// Common holds the shared configuration and runtime state of the
// starlark-based plugins.
type Common struct {
	// Source is the inline script text; Script is a path to a script file.
	// Exactly one of the two must be configured.
	Source    string                 `toml:"source"`
	Script    string                 `toml:"script"`
	Constants map[string]interface{} `toml:"constants"`

	Log              telegraf.Logger `toml:"-"`
	StarlarkLoadFunc func(module string, logger telegraf.Logger) (starlark.StringDict, error)

	// Compiled program state.
	thread     *starlark.Thread
	builtins   starlark.StringDict
	globals    starlark.StringDict
	functions  map[string]*starlark.Function
	parameters map[string]starlark.Tuple
	// state is the persisted 'state' dict exposed to the script.
	state *starlark.Dict
}
|
||||
|
||||
// GetState serializes the starlark 'state' dict into a GOB-encoded byte
// slice for the state persister. Items that are malformed or cannot be
// converted are logged and skipped; on encoding failure an empty byte
// slice is returned so the persister still receives the correct type.
func (s *Common) GetState() interface{} {
	// Return the actual byte-type instead of nil allowing the persister
	// to guess instantiate variable of the appropriate type
	if s.state == nil {
		return make([]byte, 0)
	}

	// Convert the starlark dict into a golang dictionary for serialization
	state := make(map[string]interface{}, s.state.Len())
	items := s.state.Items()
	for _, item := range items {
		if len(item) != 2 {
			// We do expect key-value pairs in the state so there should be
			// two items.
			s.Log.Errorf("state item %+v does not contain a key-value pair", item)
			continue
		}
		// Keys must be starlark strings to become Go map keys.
		k, ok := item.Index(0).(starlark.String)
		if !ok {
			s.Log.Errorf("state item %+v has invalid key type %T", item, item.Index(0))
			continue
		}
		v, err := asGoValue(item.Index(1))
		if err != nil {
			s.Log.Errorf("state item %+v value cannot be converted: %v", item, err)
			continue
		}
		state[k.GoString()] = v
	}

	// Do a binary GOB encoding to preserve types
	var buf bytes.Buffer
	if err := gob.NewEncoder(&buf).Encode(state); err != nil {
		s.Log.Errorf("encoding state failed: %v", err)
		return make([]byte, 0)
	}

	return buf.Bytes()
}
|
||||
|
||||
// SetState restores a previously persisted state (a GOB-encoded byte slice
// produced by GetState) into the starlark 'state' dict and re-initializes
// the program so the script sees the restored values.
func (s *Common) SetState(state interface{}) error {
	data, ok := state.([]byte)
	if !ok {
		return fmt.Errorf("unexpected type %T for state", state)
	}
	// An empty buffer means there is no state to restore.
	if len(data) == 0 {
		return nil
	}

	// Decode the binary GOB encoding
	var dict map[string]interface{}
	if err := gob.NewDecoder(bytes.NewBuffer(data)).Decode(&dict); err != nil {
		return fmt.Errorf("decoding state failed: %w", err)
	}

	// Convert the golang dict back to starlark types
	s.state = starlark.NewDict(len(dict))
	for k, v := range dict {
		sv, err := asStarlarkValue(v)
		if err != nil {
			return fmt.Errorf("value %v of state item %q cannot be set: %w", v, k, err)
		}
		if err := s.state.SetKey(starlark.String(k), sv); err != nil {
			return fmt.Errorf("state item %q cannot be set: %w", k, err)
		}
	}
	// Expose the restored dict to the script as the 'state' builtin.
	s.builtins["state"] = s.state

	return s.InitProgram()
}
|
||||
|
||||
// Init validates the configuration, registers the builtins available to the
// script, and compiles and runs the program once.
func (s *Common) Init() error {
	if s.Source == "" && s.Script == "" {
		return errors.New("one of source or script must be set")
	}
	if s.Source != "" && s.Script != "" {
		return errors.New("both source or script cannot be set")
	}

	// Builtins the script may reference in addition to the starlark universe.
	s.builtins = starlark.StringDict{}
	s.builtins["Metric"] = starlark.NewBuiltin("Metric", newMetric)
	s.builtins["deepcopy"] = starlark.NewBuiltin("deepcopy", deepcopy)
	s.builtins["catch"] = starlark.NewBuiltin("catch", catch)

	if err := s.addConstants(&s.builtins); err != nil {
		return err
	}

	// Initialize the program
	if err := s.InitProgram(); err != nil {
		// Try again with a declared state. This might be necessary for
		// state persistence.
		s.state = starlark.NewDict(0)
		s.builtins["state"] = s.state
		if serr := s.InitProgram(); serr != nil {
			// Report the first error; it is more likely the real cause.
			return err
		}
	}

	s.functions = make(map[string]*starlark.Function)
	s.parameters = make(map[string]starlark.Tuple)

	return nil
}
|
||||
|
||||
// InitProgram compiles the configured script and executes its top-level
// statements, capturing and freezing the resulting globals.
func (s *Common) InitProgram() error {
	// Load the program. In case of an error we can try to insert the state
	// which can be used implicitly e.g. when persisting states
	program, err := s.sourceProgram(s.builtins)
	if err != nil {
		return err
	}

	// Execute source
	s.thread = &starlark.Thread{
		// Route the script's print() through the plugin's debug log.
		Print: func(_ *starlark.Thread, msg string) { s.Log.Debug(msg) },
		Load: func(_ *starlark.Thread, module string) (starlark.StringDict, error) {
			return s.StarlarkLoadFunc(module, s.Log)
		},
	}
	globals, err := program.Init(s.thread, s.builtins)
	if err != nil {
		return err
	}

	// In case the program declares a global "state" we should insert it to
	// avoid warnings about inserting into a frozen variable
	if _, found := globals["state"]; found {
		globals["state"] = starlark.NewDict(0)
	}

	// Freeze the global state. This prevents modifications to the processor
	// state and prevents scripts from containing errors storing tracking
	// metrics. Tasks that require global state will not be possible due to
	// this, so maybe we should relax this in the future.
	globals.Freeze()
	s.globals = globals

	return nil
}
|
||||
|
||||
func (s *Common) GetParameters(name string) (starlark.Tuple, bool) {
|
||||
parameters, found := s.parameters[name]
|
||||
return parameters, found
|
||||
}
|
||||
|
||||
func (s *Common) AddFunction(name string, params ...starlark.Value) error {
|
||||
globalFn, found := s.globals[name]
|
||||
if !found {
|
||||
return fmt.Errorf("%s is not defined", name)
|
||||
}
|
||||
|
||||
fn, found := globalFn.(*starlark.Function)
|
||||
if !found {
|
||||
return fmt.Errorf("%s is not a function", name)
|
||||
}
|
||||
|
||||
if fn.NumParams() != len(params) {
|
||||
return fmt.Errorf("%s function must take %d parameter(s)", name, len(params))
|
||||
}
|
||||
p := make(starlark.Tuple, len(params))
|
||||
copy(p, params)
|
||||
|
||||
s.functions[name] = fn
|
||||
s.parameters[name] = params
|
||||
return nil
|
||||
}
|
||||
|
||||
// Add all the constants defined in the plugin as constants of the script
|
||||
func (s *Common) addConstants(builtins *starlark.StringDict) error {
|
||||
for key, val := range s.Constants {
|
||||
if key == "state" {
|
||||
return errors.New("'state' constant uses reserved name")
|
||||
}
|
||||
sVal, err := asStarlarkValue(val)
|
||||
if err != nil {
|
||||
return fmt.Errorf("converting type %T failed: %w", val, err)
|
||||
}
|
||||
(*builtins)[key] = sVal
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// sourceProgram parses the inline source (or, if src stays nil, the file
// named by s.Script) into a compiled starlark program without executing it.
func (s *Common) sourceProgram(builtins starlark.StringDict) (*starlark.Program, error) {
	var src interface{}
	if s.Source != "" {
		src = s.Source
	}

	// AllowFloat - obsolete, no effect
	// AllowNestedDef - always on https://github.com/google/starlark-go/pull/328
	// AllowLambda - always on https://github.com/google/starlark-go/pull/328
	options := syntax.FileOptions{
		Recursion:      true,
		GlobalReassign: true,
		Set:            true,
	}

	_, program, err := starlark.SourceProgramOptions(&options, s.Script, src, builtins.Has)
	return program, err
}
|
||||
|
||||
// Call calls the function corresponding to the given name.
|
||||
func (s *Common) Call(name string) (starlark.Value, error) {
|
||||
fn, ok := s.functions[name]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("function %q does not exist", name)
|
||||
}
|
||||
args, ok := s.parameters[name]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("params for function %q do not exist", name)
|
||||
}
|
||||
return starlark.Call(s.thread, fn, args, nil)
|
||||
}
|
||||
|
||||
func (s *Common) LogError(err error) {
|
||||
var evalErr *starlark.EvalError
|
||||
if errors.As(err, &evalErr) {
|
||||
for _, line := range strings.Split(evalErr.Backtrace(), "\n") {
|
||||
s.Log.Error(line)
|
||||
}
|
||||
} else {
|
||||
s.Log.Error(err)
|
||||
}
|
||||
}
|
||||
|
||||
func LoadFunc(module string, logger telegraf.Logger) (starlark.StringDict, error) {
|
||||
switch module {
|
||||
case "json.star":
|
||||
return starlark.StringDict{
|
||||
"json": json.Module,
|
||||
}, nil
|
||||
case "logging.star":
|
||||
return starlark.StringDict{
|
||||
"log": LogModule(logger),
|
||||
}, nil
|
||||
case "math.star":
|
||||
return starlark.StringDict{
|
||||
"math": math.Module,
|
||||
}, nil
|
||||
case "time.star":
|
||||
return starlark.StringDict{
|
||||
"time": time.Module,
|
||||
}, nil
|
||||
default:
|
||||
return nil, errors.New("module " + module + " is not available")
|
||||
}
|
||||
}
|
226
plugins/common/starlark/tag_dict.go
Normal file
226
plugins/common/starlark/tag_dict.go
Normal file
|
@ -0,0 +1,226 @@
|
|||
package starlark
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"strings"
|
||||
|
||||
"go.starlark.net/starlark"
|
||||
|
||||
"github.com/influxdata/telegraf"
|
||||
)
|
||||
|
||||
// TagDict is a starlark.Value for the metric tags. It is heavily based on the
|
||||
// starlark.Dict.
|
||||
type TagDict struct {
|
||||
*Metric
|
||||
}
|
||||
|
||||
func (d TagDict) String() string {
|
||||
buf := new(strings.Builder)
|
||||
buf.WriteString("{")
|
||||
sep := ""
|
||||
for _, item := range d.Items() {
|
||||
k, v := item[0], item[1]
|
||||
buf.WriteString(sep)
|
||||
buf.WriteString(k.String())
|
||||
buf.WriteString(": ")
|
||||
buf.WriteString(v.String())
|
||||
sep = ", "
|
||||
}
|
||||
buf.WriteString("}")
|
||||
return buf.String()
|
||||
}
|
||||
|
||||
func (TagDict) Type() string {
|
||||
return "Tags"
|
||||
}
|
||||
|
||||
func (d TagDict) Freeze() {
|
||||
// Disable linter check as the frozen variable is modified despite
|
||||
// passing a value instead of a pointer, because `TagDict` holds
|
||||
// a pointer to the underlying metric containing the `frozen` field.
|
||||
//revive:disable:modifies-value-receiver
|
||||
d.frozen = true
|
||||
}
|
||||
|
||||
func (d TagDict) Truth() starlark.Bool {
|
||||
return len(d.metric.TagList()) != 0
|
||||
}
|
||||
|
||||
func (TagDict) Hash() (uint32, error) {
|
||||
return 0, errors.New("not hashable")
|
||||
}
|
||||
|
||||
// AttrNames implements the starlark.HasAttrs interface.
|
||||
func (TagDict) AttrNames() []string {
|
||||
return builtinAttrNames(TagDictMethods)
|
||||
}
|
||||
|
||||
// Attr implements the starlark.HasAttrs interface.
|
||||
func (d TagDict) Attr(name string) (starlark.Value, error) {
|
||||
return builtinAttr(d, name, TagDictMethods)
|
||||
}
|
||||
|
||||
// TagDictMethods maps the dict-style method names available on a TagDict
// to their shared builtin implementations.
var TagDictMethods = map[string]builtinMethod{
	"clear":      dictClear,
	"get":        dictGet,
	"items":      dictItems,
	"keys":       dictKeys,
	"pop":        dictPop,
	"popitem":    dictPopitem,
	"setdefault": dictSetdefault,
	"update":     dictUpdate,
	"values":     dictValues,
}
|
||||
|
||||
// Get implements the starlark.Mapping interface.
|
||||
func (d TagDict) Get(key starlark.Value) (v starlark.Value, found bool, err error) {
|
||||
if k, ok := key.(starlark.String); ok {
|
||||
gv, found := d.metric.GetTag(k.GoString())
|
||||
if !found {
|
||||
return starlark.None, false, nil
|
||||
}
|
||||
return starlark.String(gv), true, err
|
||||
}
|
||||
|
||||
return starlark.None, false, errors.New("key must be of type 'str'")
|
||||
}
|
||||
|
||||
// SetKey implements the starlark.HasSetKey interface to support map update
|
||||
// using x[k]=v syntax, like a dictionary.
|
||||
func (d TagDict) SetKey(k, v starlark.Value) error {
|
||||
if d.tagIterCount > 0 {
|
||||
return errors.New("cannot insert during iteration")
|
||||
}
|
||||
|
||||
key, ok := k.(starlark.String)
|
||||
if !ok {
|
||||
return errors.New("tag key must be of type 'str'")
|
||||
}
|
||||
|
||||
value, ok := v.(starlark.String)
|
||||
if !ok {
|
||||
return errors.New("tag value must be of type 'str'")
|
||||
}
|
||||
|
||||
d.metric.AddTag(key.GoString(), value.GoString())
|
||||
return nil
|
||||
}
|
||||
|
||||
// Items implements the starlark.IterableMapping interface.
|
||||
func (d TagDict) Items() []starlark.Tuple {
|
||||
items := make([]starlark.Tuple, 0, len(d.metric.TagList()))
|
||||
for _, tag := range d.metric.TagList() {
|
||||
key := starlark.String(tag.Key)
|
||||
value := starlark.String(tag.Value)
|
||||
pair := starlark.Tuple{key, value}
|
||||
items = append(items, pair)
|
||||
}
|
||||
return items
|
||||
}
|
||||
|
||||
func (d TagDict) Clear() error {
|
||||
if d.tagIterCount > 0 {
|
||||
return errors.New("cannot delete during iteration")
|
||||
}
|
||||
|
||||
keys := make([]string, 0, len(d.metric.TagList()))
|
||||
for _, tag := range d.metric.TagList() {
|
||||
keys = append(keys, tag.Key)
|
||||
}
|
||||
|
||||
for _, key := range keys {
|
||||
d.metric.RemoveTag(key)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d TagDict) PopItem() (v starlark.Value, err error) {
|
||||
if d.tagIterCount > 0 {
|
||||
return nil, errors.New("cannot delete during iteration")
|
||||
}
|
||||
|
||||
for _, tag := range d.metric.TagList() {
|
||||
k := tag.Key
|
||||
v := tag.Value
|
||||
|
||||
d.metric.RemoveTag(k)
|
||||
|
||||
sk := starlark.String(k)
|
||||
sv := starlark.String(v)
|
||||
return starlark.Tuple{sk, sv}, nil
|
||||
}
|
||||
|
||||
return nil, errors.New("popitem(): tag dictionary is empty")
|
||||
}
|
||||
|
||||
func (d TagDict) Delete(k starlark.Value) (v starlark.Value, found bool, err error) {
|
||||
if d.tagIterCount > 0 {
|
||||
return nil, false, errors.New("cannot delete during iteration")
|
||||
}
|
||||
|
||||
if key, ok := k.(starlark.String); ok {
|
||||
value, ok := d.metric.GetTag(key.GoString())
|
||||
if ok {
|
||||
d.metric.RemoveTag(key.GoString())
|
||||
v := starlark.String(value)
|
||||
return v, ok, err
|
||||
}
|
||||
return starlark.None, false, nil
|
||||
}
|
||||
|
||||
return starlark.None, false, errors.New("key must be of type 'str'")
|
||||
}
|
||||
|
||||
// Iterate implements the starlark.Iterator interface. It bumps the active
// iteration counter (used by SetKey/Clear/PopItem/Delete to reject mutation
// while iterating) and hands the iterator a snapshot of the tag list.
// NOTE(review): the receiver is a value; the increment only persists if
// tagIterCount lives behind the embedded Metric pointer — confirm against
// the TagDict/Metric declarations.
func (d TagDict) Iterate() starlark.Iterator {
	d.tagIterCount++
	return &TagIterator{Metric: d.Metric, tags: d.metric.TagList()}
}
|
||||
|
||||
// TagIterator iterates over the tag keys of a metric. It embeds *Metric so
// Done can decrement the metric's active-iteration counter.
type TagIterator struct {
	*Metric
	// tags is the remaining slice of tags still to be yielded; Next consumes
	// it from the front.
	tags []*telegraf.Tag
}
|
||||
|
||||
// Next implements the starlark.Iterator interface.
|
||||
func (i *TagIterator) Next(p *starlark.Value) bool {
|
||||
if len(i.tags) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
tag := i.tags[0]
|
||||
i.tags = i.tags[1:]
|
||||
*p = starlark.String(tag.Key)
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// Done implements the starlark.Iterator interface. It decrements the active
// iteration counter, re-enabling mutation of the tag dictionary once all
// outstanding iterators have finished.
func (i *TagIterator) Done() {
	i.tagIterCount--
}
|
||||
|
||||
// toTags converts a starlark.Value holding a key/value collection into a
// map of string tags. A nil value yields a nil map without error. Both keys
// and values must convert to strings; any other type is an error.
// It relies on the sibling helpers items() and toString() defined elsewhere
// in this file.
func toTags(value starlark.Value) (map[string]string, error) {
	if value == nil {
		return nil, nil
	}
	// Normalize the collection into key/value pairs first.
	items, err := items(value, "The type %T is unsupported as type of collection of tags")
	if err != nil {
		return nil, err
	}
	result := make(map[string]string, len(items))
	for _, item := range items {
		key, err := toString(item[0], "The type %T is unsupported as type of key for tags")
		if err != nil {
			return nil, err
		}
		value, err := toString(item[1], "The type %T is unsupported as type of value for tags")
		if err != nil {
			return nil, err
		}
		result[key] = value
	}
	return result, nil
}
|
24
plugins/common/tls/client.conf
Normal file
24
plugins/common/tls/client.conf
Normal file
|
@ -0,0 +1,24 @@
|
|||
## Set to true/false to enforce TLS being enabled/disabled. If not set,
|
||||
## enable TLS only if any of the other options are specified.
|
||||
# tls_enable =
|
||||
## Trusted root certificates for server
|
||||
# tls_ca = "/path/to/cafile"
|
||||
## Used for TLS client certificate authentication
|
||||
# tls_cert = "/path/to/certfile"
|
||||
## Used for TLS client certificate authentication
|
||||
# tls_key = "/path/to/keyfile"
|
||||
## Password for the key file if it is encrypted
|
||||
# tls_key_pwd = ""
|
||||
## Send the specified TLS server name via SNI
|
||||
# tls_server_name = "kubernetes.example.com"
|
||||
## Minimal TLS version to accept by the client
|
||||
# tls_min_version = "TLS12"
|
||||
## List of ciphers to accept, by default all secure ciphers will be accepted
|
||||
## See https://pkg.go.dev/crypto/tls#pkg-constants for supported values.
|
||||
## Use "all", "secure" and "insecure" to add all supported ciphers, secure
|
||||
## suites or insecure suites respectively.
|
||||
# tls_cipher_suites = ["secure"]
|
||||
## Renegotiation method, "never", "once" or "freely"
|
||||
# tls_renegotiation_method = "never"
|
||||
## Use TLS but skip chain & host verification
|
||||
# insecure_skip_verify = false
|
34
plugins/common/tls/common.go
Normal file
34
plugins/common/tls/common.go
Normal file
|
@ -0,0 +1,34 @@
|
|||
// Package tls provides shared TLS client/server configuration helpers for
// telegraf plugins.
package tls

import (
	"crypto/tls"
	"sync"
)

// tlsVersionMap maps the textual TLS version names used in plugin
// configuration (e.g. "TLS12") to the crypto/tls version constants.
var tlsVersionMap = map[string]uint16{
	"TLS10": tls.VersionTLS10,
	"TLS11": tls.VersionTLS11,
	"TLS12": tls.VersionTLS12,
	"TLS13": tls.VersionTLS13,
}

// Cipher-suite name -> ID lookup tables, populated once at package
// initialization from the lists exposed by crypto/tls.
var tlsCipherMapInit sync.Once
var tlsCipherMapSecure map[string]uint16
var tlsCipherMapInsecure map[string]uint16

func init() {
	// NOTE(review): init() already runs exactly once per process, so the
	// sync.Once wrapper appears redundant unless tlsCipherMapInit.Do is also
	// invoked elsewhere in the package — confirm before simplifying.
	tlsCipherMapInit.Do(func() {
		// Initialize the secure suites
		suites := tls.CipherSuites()
		tlsCipherMapSecure = make(map[string]uint16, len(suites))
		for _, s := range suites {
			tlsCipherMapSecure[s.Name] = s.ID
		}

		suites = tls.InsecureCipherSuites()
		tlsCipherMapInsecure = make(map[string]uint16, len(suites))
		for _, s := range suites {
			tlsCipherMapInsecure[s.Name] = s.ID
		}
	})
}
|
298
plugins/common/tls/config.go
Normal file
298
plugins/common/tls/config.go
Normal file
|
@ -0,0 +1,298 @@
|
|||
package tls
|
||||
|
||||
import (
|
||||
"crypto/rsa"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"encoding/pem"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"go.step.sm/crypto/pemutil"
|
||||
|
||||
"github.com/influxdata/telegraf/internal/choice"
|
||||
)
|
||||
|
||||
// TLSMinVersionDefault is the minimal TLS protocol version accepted by
// default, applied to both client and server configs (see the TLSConfig
// methods below).
const TLSMinVersionDefault = tls.VersionTLS12

// ClientConfig represents the standard client TLS config.
type ClientConfig struct {
	TLSCA               string   `toml:"tls_ca"`                   // path to trusted root certificate file(s)
	TLSCert             string   `toml:"tls_cert"`                 // client certificate for mutual TLS
	TLSKey              string   `toml:"tls_key"`                  // private key matching TLSCert
	TLSKeyPwd           string   `toml:"tls_key_pwd"`              // password for an encrypted PKCS#8 key
	TLSMinVersion       string   `toml:"tls_min_version"`          // e.g. "TLS12"; empty means TLSMinVersionDefault
	TLSCipherSuites     []string `toml:"tls_cipher_suites"`        // cipher-suite names accepted by the client
	InsecureSkipVerify  bool     `toml:"insecure_skip_verify"`     // skip chain & host verification
	ServerName          string   `toml:"tls_server_name"`          // SNI server name to send
	RenegotiationMethod string   `toml:"tls_renegotiation_method"` // "never" (default), "once" or "freely"
	Enable              *bool    `toml:"tls_enable"`               // tri-state: nil = auto, true/false = force on/off

	SSLCA   string `toml:"ssl_ca" deprecated:"1.7.0;1.35.0;use 'tls_ca' instead"`
	SSLCert string `toml:"ssl_cert" deprecated:"1.7.0;1.35.0;use 'tls_cert' instead"`
	SSLKey  string `toml:"ssl_key" deprecated:"1.7.0;1.35.0;use 'tls_key' instead"`
}
|
||||
|
||||
// ServerConfig represents the standard server TLS config.
type ServerConfig struct {
	TLSCert            string   `toml:"tls_cert"`              // server certificate
	TLSKey             string   `toml:"tls_key"`               // private key matching TLSCert
	TLSKeyPwd          string   `toml:"tls_key_pwd"`           // password for an encrypted PKCS#8 key
	TLSAllowedCACerts  []string `toml:"tls_allowed_cacerts"`   // CAs for verifying client certs; enables mutual TLS
	TLSCipherSuites    []string `toml:"tls_cipher_suites"`     // cipher-suite names accepted by the server
	TLSMinVersion      string   `toml:"tls_min_version"`       // empty means TLSMinVersionDefault
	TLSMaxVersion      string   `toml:"tls_max_version"`       // empty means no upper bound
	TLSAllowedDNSNames []string `toml:"tls_allowed_dns_names"` // extra allow-list checked against client cert DNS names
}
|
||||
|
||||
// TLSConfig returns a tls.Config, may be nil without error if TLS is not
// configured.
//
// The tri-state Enable flag takes precedence: false forces nil, true with an
// otherwise empty config forces an empty (system-default) tls.Config.
// Note: this method mutates the receiver when copying the deprecated ssl_*
// fields onto their tls_* replacements.
func (c *ClientConfig) TLSConfig() (*tls.Config, error) {
	// Check if TLS config is forcefully disabled
	if c.Enable != nil && !*c.Enable {
		return nil, nil
	}

	// Support deprecated variable names
	if c.TLSCA == "" && c.SSLCA != "" {
		c.TLSCA = c.SSLCA
	}
	if c.TLSCert == "" && c.SSLCert != "" {
		c.TLSCert = c.SSLCert
	}
	if c.TLSKey == "" && c.SSLKey != "" {
		c.TLSKey = c.SSLKey
	}

	// This check returns a nil (aka "disabled") or an empty config
	// (aka, "use the default") if no field is set that would have an effect on
	// a TLS connection. That is, any of:
	// * client certificate settings,
	// * peer certificate authorities,
	// * disabled security,
	// * an SNI server name, or
	// * empty/never renegotiation method
	empty := c.TLSCA == "" && c.TLSKey == "" && c.TLSCert == ""
	empty = empty && !c.InsecureSkipVerify && c.ServerName == ""
	empty = empty && (c.RenegotiationMethod == "" || c.RenegotiationMethod == "never")

	if empty {
		// Check if TLS config is forcefully enabled and supposed to
		// use the system defaults.
		if c.Enable != nil && *c.Enable {
			return &tls.Config{}, nil
		}

		return nil, nil
	}

	// Map the configured renegotiation method name onto the crypto/tls
	// constant; an unknown name is a configuration error.
	var renegotiationMethod tls.RenegotiationSupport
	switch c.RenegotiationMethod {
	case "", "never":
		renegotiationMethod = tls.RenegotiateNever
	case "once":
		renegotiationMethod = tls.RenegotiateOnceAsClient
	case "freely":
		renegotiationMethod = tls.RenegotiateFreelyAsClient
	default:
		return nil, fmt.Errorf("unrecognized renegotiation method %q, choose from: 'never', 'once', 'freely'", c.RenegotiationMethod)
	}

	tlsConfig := &tls.Config{
		InsecureSkipVerify: c.InsecureSkipVerify,
		Renegotiation:      renegotiationMethod,
	}

	if c.TLSCA != "" {
		pool, err := makeCertPool([]string{c.TLSCA})
		if err != nil {
			return nil, err
		}
		tlsConfig.RootCAs = pool
	}

	// A client keypair is only loaded when both halves are configured;
	// setting just one of cert/key silently skips client authentication.
	if c.TLSCert != "" && c.TLSKey != "" {
		err := loadCertificate(tlsConfig, c.TLSCert, c.TLSKey, c.TLSKeyPwd)
		if err != nil {
			return nil, err
		}
	}

	// Explicitly and consistently set the minimal accepted version using the
	// defined default. We use this setting for both clients and servers
	// instead of relying on Golang's default that is different for clients
	// and servers and might change over time.
	tlsConfig.MinVersion = TLSMinVersionDefault
	if c.TLSMinVersion != "" {
		version, err := ParseTLSVersion(c.TLSMinVersion)
		if err != nil {
			return nil, fmt.Errorf("could not parse tls min version %q: %w", c.TLSMinVersion, err)
		}
		tlsConfig.MinVersion = version
	}

	if c.ServerName != "" {
		tlsConfig.ServerName = c.ServerName
	}

	if len(c.TLSCipherSuites) != 0 {
		cipherSuites, err := ParseCiphers(c.TLSCipherSuites)
		if err != nil {
			return nil, fmt.Errorf("could not parse client cipher suites: %w", err)
		}
		tlsConfig.CipherSuites = cipherSuites
	}

	return tlsConfig, nil
}
|
||||
|
||||
// TLSConfig returns a tls.Config, may be nil without error if TLS is not
|
||||
// configured.
|
||||
func (c *ServerConfig) TLSConfig() (*tls.Config, error) {
|
||||
if c.TLSCert == "" && c.TLSKey == "" && len(c.TLSAllowedCACerts) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
tlsConfig := &tls.Config{}
|
||||
|
||||
if len(c.TLSAllowedCACerts) != 0 {
|
||||
pool, err := makeCertPool(c.TLSAllowedCACerts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tlsConfig.ClientCAs = pool
|
||||
tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert
|
||||
}
|
||||
|
||||
if c.TLSCert != "" && c.TLSKey != "" {
|
||||
err := loadCertificate(tlsConfig, c.TLSCert, c.TLSKey, c.TLSKeyPwd)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if len(c.TLSCipherSuites) != 0 {
|
||||
cipherSuites, err := ParseCiphers(c.TLSCipherSuites)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not parse server cipher suites: %w", err)
|
||||
}
|
||||
tlsConfig.CipherSuites = cipherSuites
|
||||
}
|
||||
|
||||
if c.TLSMaxVersion != "" {
|
||||
version, err := ParseTLSVersion(c.TLSMaxVersion)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf(
|
||||
"could not parse tls max version %q: %w", c.TLSMaxVersion, err)
|
||||
}
|
||||
tlsConfig.MaxVersion = version
|
||||
}
|
||||
|
||||
// Explicitly and consistently set the minimal accepted version using the
|
||||
// defined default. We use this setting for both clients and servers
|
||||
// instead of relying on Golang's default that is different for clients
|
||||
// and servers and might change over time.
|
||||
tlsConfig.MinVersion = TLSMinVersionDefault
|
||||
if c.TLSMinVersion != "" {
|
||||
version, err := ParseTLSVersion(c.TLSMinVersion)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not parse tls min version %q: %w", c.TLSMinVersion, err)
|
||||
}
|
||||
tlsConfig.MinVersion = version
|
||||
}
|
||||
|
||||
if tlsConfig.MinVersion != 0 && tlsConfig.MaxVersion != 0 && tlsConfig.MinVersion > tlsConfig.MaxVersion {
|
||||
return nil, fmt.Errorf("tls min version %q can't be greater than tls max version %q", tlsConfig.MinVersion, tlsConfig.MaxVersion)
|
||||
}
|
||||
|
||||
// Since clientAuth is tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert
|
||||
// there must be certs to validate.
|
||||
if len(c.TLSAllowedCACerts) > 0 && len(c.TLSAllowedDNSNames) > 0 {
|
||||
tlsConfig.VerifyPeerCertificate = c.verifyPeerCertificate
|
||||
}
|
||||
|
||||
return tlsConfig, nil
|
||||
}
|
||||
|
||||
// makeCertPool reads the given PEM certificate files and collects them into
// a fresh x509.CertPool. It fails on the first file that cannot be read or
// that contains no parseable PEM certificate.
func makeCertPool(certFiles []string) (*x509.CertPool, error) {
	pool := x509.NewCertPool()
	for _, certFile := range certFiles {
		cert, err := os.ReadFile(certFile)
		if err != nil {
			return nil, fmt.Errorf("could not read certificate %q: %w", certFile, err)
		}
		if !pool.AppendCertsFromPEM(cert) {
			// ReadFile succeeded here, so err is nil; do not wrap it with %w
			// (the original wrapped a nil error). AppendCertsFromPEM only
			// reports success/failure, there is no error value to propagate.
			return nil, fmt.Errorf("could not parse any PEM certificates %q", certFile)
		}
	}
	return pool, nil
}
|
||||
|
||||
// loadCertificate reads a certificate/key pair from disk and installs it
// into config.Certificates. Three key formats are handled:
//   - encrypted PKCS#8 ("ENCRYPTED PRIVATE KEY" PEM block): decrypted with
//     privateKeyPassphrase via pemutil, then re-encoded for tls.X509KeyPair;
//     only RSA keys are supported on this path.
//   - encrypted PKCS#1 (legacy "Proc-Type: 4,ENCRYPTED" header): rejected,
//     because x509.DecryptPEMBlock is deprecated.
//   - anything else: passed straight to tls.X509KeyPair.
func loadCertificate(config *tls.Config, certFile, keyFile, privateKeyPassphrase string) error {
	certBytes, err := os.ReadFile(certFile)
	if err != nil {
		return fmt.Errorf("could not load certificate %q: %w", certFile, err)
	}

	keyBytes, err := os.ReadFile(keyFile)
	if err != nil {
		return fmt.Errorf("could not load private key %q: %w", keyFile, err)
	}

	// Only the first PEM block of the key file is inspected.
	keyPEMBlock, _ := pem.Decode(keyBytes)
	if keyPEMBlock == nil {
		return errors.New("failed to decode private key: no PEM data found")
	}

	var cert tls.Certificate
	if keyPEMBlock.Type == "ENCRYPTED PRIVATE KEY" {
		if privateKeyPassphrase == "" {
			return errors.New("missing password for PKCS#8 encrypted private key")
		}
		rawDecryptedKey, err := pemutil.DecryptPKCS8PrivateKey(keyPEMBlock.Bytes, []byte(privateKeyPassphrase))
		if err != nil {
			return fmt.Errorf("failed to decrypt PKCS#8 private key: %w", err)
		}
		decryptedKey, err := x509.ParsePKCS8PrivateKey(rawDecryptedKey)
		if err != nil {
			return fmt.Errorf("failed to parse decrypted PKCS#8 private key: %w", err)
		}
		privateKey, ok := decryptedKey.(*rsa.PrivateKey)
		if !ok {
			return fmt.Errorf("decrypted key is not a RSA private key: %T", decryptedKey)
		}
		// Re-encode the decrypted RSA key as PEM so tls.X509KeyPair can
		// consume it alongside the certificate bytes.
		cert, err = tls.X509KeyPair(certBytes, pem.EncodeToMemory(&pem.Block{Type: keyPEMBlock.Type, Bytes: x509.MarshalPKCS1PrivateKey(privateKey)}))
		if err != nil {
			return fmt.Errorf("failed to load cert/key pair: %w", err)
		}
	} else if keyPEMBlock.Headers["Proc-Type"] == "4,ENCRYPTED" {
		// The key is an encrypted private key with the DEK-Info header.
		// This is currently unsupported because of the deprecation of x509.IsEncryptedPEMBlock and x509.DecryptPEMBlock.
		return errors.New("password-protected keys in pkcs#1 format are not supported")
	} else {
		cert, err = tls.X509KeyPair(certBytes, keyBytes)
		if err != nil {
			return fmt.Errorf("failed to load cert/key pair: %w", err)
		}
	}
	config.Certificates = []tls.Certificate{cert}
	return nil
}
|
||||
|
||||
func (c *ServerConfig) verifyPeerCertificate(rawCerts [][]byte, _ [][]*x509.Certificate) error {
|
||||
// The certificate chain is client + intermediate + root.
|
||||
// Let's review the client certificate.
|
||||
cert, err := x509.ParseCertificate(rawCerts[0])
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not validate peer certificate: %w", err)
|
||||
}
|
||||
|
||||
for _, name := range cert.DNSNames {
|
||||
if choice.Contains(name, c.TLSAllowedDNSNames) {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
return fmt.Errorf("peer certificate not in allowed DNS Name list: %v", cert.DNSNames)
|
||||
}
|
614
plugins/common/tls/config_test.go
Normal file
614
plugins/common/tls/config_test.go
Normal file
|
@ -0,0 +1,614 @@
|
|||
package tls_test
|
||||
|
||||
import (
|
||||
cryptotls "crypto/tls"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/influxdata/telegraf/plugins/common/tls"
|
||||
"github.com/influxdata/telegraf/testutil"
|
||||
)
|
||||
|
||||
var pki = testutil.NewPKI("../../../testutil/pki")
|
||||
|
||||
// TestClientConfig is a table-driven test of ClientConfig.TLSConfig covering
// the nil-config ("unset") case, successful cert/key/CA combinations, the
// PKCS#8 key paths (plain, encrypted with/without/with-wrong password), the
// rejected encrypted PKCS#1 key, invalid file contents, and the deprecated
// ssl_* field names. expNil asserts a nil returned config, expErr a non-nil
// error.
func TestClientConfig(t *testing.T) {
	tests := []struct {
		name   string
		client tls.ClientConfig
		expNil bool
		expErr bool
	}{
		{
			name:   "unset",
			client: tls.ClientConfig{},
			expNil: true,
		},
		{
			name: "success",
			client: tls.ClientConfig{
				TLSCA:   pki.CACertPath(),
				TLSCert: pki.ClientCertPath(),
				TLSKey:  pki.ClientKeyPath(),
			},
		},
		{
			name: "success with tls key password set",
			client: tls.ClientConfig{
				TLSCA:     pki.CACertPath(),
				TLSCert:   pki.ClientCertPath(),
				TLSKey:    pki.ClientKeyPath(),
				TLSKeyPwd: "",
			},
		},
		{
			name: "success with unencrypted pkcs#8 key",
			client: tls.ClientConfig{
				TLSCA:   pki.CACertPath(),
				TLSCert: pki.ClientCertPath(),
				TLSKey:  pki.ClientPKCS8KeyPath(),
			},
		},
		{
			name: "encrypted pkcs#8 key but missing password",
			client: tls.ClientConfig{
				TLSCA:   pki.CACertPath(),
				TLSCert: pki.ClientCertPath(),
				TLSKey:  pki.ClientEncPKCS8KeyPath(),
			},
			expNil: true,
			expErr: true,
		},
		{
			name: "encrypted pkcs#8 key and incorrect password",
			client: tls.ClientConfig{
				TLSCA:     pki.CACertPath(),
				TLSCert:   pki.ClientCertPath(),
				TLSKey:    pki.ClientEncPKCS8KeyPath(),
				TLSKeyPwd: "incorrect",
			},
			expNil: true,
			expErr: true,
		},
		{
			name: "success with encrypted pkcs#8 key and password set",
			client: tls.ClientConfig{
				TLSCA:     pki.CACertPath(),
				TLSCert:   pki.ClientCertPath(),
				TLSKey:    pki.ClientEncPKCS8KeyPath(),
				TLSKeyPwd: "changeme",
			},
		},
		{
			// Encrypted PKCS#1 keys are explicitly unsupported (see
			// loadCertificate), so this must fail even with the password.
			name: "error with encrypted pkcs#1 key and password set",
			client: tls.ClientConfig{
				TLSCA:     pki.CACertPath(),
				TLSCert:   pki.ClientCertPath(),
				TLSKey:    pki.ClientEncKeyPath(),
				TLSKeyPwd: "changeme",
			},
			expNil: true,
			expErr: true,
		},

		{
			name: "invalid ca",
			client: tls.ClientConfig{
				TLSCA:   pki.ClientKeyPath(),
				TLSCert: pki.ClientCertPath(),
				TLSKey:  pki.ClientKeyPath(),
			},
			expNil: true,
			expErr: true,
		},
		{
			name: "missing ca is okay",
			client: tls.ClientConfig{
				TLSCert: pki.ClientCertPath(),
				TLSKey:  pki.ClientKeyPath(),
			},
		},
		{
			name: "invalid cert",
			client: tls.ClientConfig{
				TLSCA:   pki.CACertPath(),
				TLSCert: pki.ClientKeyPath(),
				TLSKey:  pki.ClientKeyPath(),
			},
			expNil: true,
			expErr: true,
		},
		{
			name: "missing cert skips client keypair",
			client: tls.ClientConfig{
				TLSCA:  pki.CACertPath(),
				TLSKey: pki.ClientKeyPath(),
			},
			expNil: false,
			expErr: false,
		},
		{
			name: "missing key skips client keypair",
			client: tls.ClientConfig{
				TLSCA:   pki.CACertPath(),
				TLSCert: pki.ClientCertPath(),
			},
			expNil: false,
			expErr: false,
		},
		{
			name: "support deprecated ssl field names",
			client: tls.ClientConfig{
				SSLCA:   pki.CACertPath(),
				SSLCert: pki.ClientCertPath(),
				SSLKey:  pki.ClientKeyPath(),
			},
		},
		{
			name: "set SNI server name",
			client: tls.ClientConfig{
				ServerName: "foo.example.com",
			},
			expNil: false,
			expErr: false,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			tlsConfig, err := tt.client.TLSConfig()
			if !tt.expNil {
				require.NotNil(t, tlsConfig)
			} else {
				require.Nil(t, tlsConfig)
			}

			if !tt.expErr {
				require.NoError(t, err)
			} else {
				require.Error(t, err)
			}
		})
	}
}
|
||||
|
||||
// TestServerConfig is a table-driven test of ServerConfig.TLSConfig covering
// the nil-config ("unset") case, valid cert/key/CA combinations with and
// without cipher-suite and min/max version settings, and the error paths
// (invalid files, missing cert/key, unparsable cipher suites, inverted
// min/max versions). expNil asserts a nil returned config, expErr a non-nil
// error.
func TestServerConfig(t *testing.T) {
	tests := []struct {
		name   string
		server tls.ServerConfig
		expNil bool
		expErr bool
	}{
		{
			name:   "unset",
			server: tls.ServerConfig{},
			expNil: true,
		},
		{
			name: "success",
			server: tls.ServerConfig{
				TLSCert:            pki.ServerCertPath(),
				TLSKey:             pki.ServerKeyPath(),
				TLSAllowedCACerts:  []string{pki.CACertPath()},
				TLSCipherSuites:    []string{pki.CipherSuite()},
				TLSAllowedDNSNames: []string{"localhost", "127.0.0.1"},
				TLSMinVersion:      pki.TLSMinVersion(),
				TLSMaxVersion:      pki.TLSMaxVersion(),
			},
		},
		{
			name: "success with tls key password set",
			server: tls.ServerConfig{
				TLSCert:           pki.ServerCertPath(),
				TLSKey:            pki.ServerKeyPath(),
				TLSKeyPwd:         "",
				TLSAllowedCACerts: []string{pki.CACertPath()},
				TLSCipherSuites:   []string{pki.CipherSuite()},
				TLSMinVersion:     pki.TLSMinVersion(),
				TLSMaxVersion:     pki.TLSMaxVersion(),
			},
		},
		{
			// NOTE(review): the case name says "missing tls cipher suites"
			// but TLSCipherSuites is set — name and fixture disagree.
			name: "missing tls cipher suites is okay",
			server: tls.ServerConfig{
				TLSCert:           pki.ServerCertPath(),
				TLSKey:            pki.ServerKeyPath(),
				TLSAllowedCACerts: []string{pki.CACertPath()},
				TLSCipherSuites:   []string{pki.CipherSuite()},
			},
		},
		{
			name: "missing tls max version is okay",
			server: tls.ServerConfig{
				TLSCert:           pki.ServerCertPath(),
				TLSKey:            pki.ServerKeyPath(),
				TLSAllowedCACerts: []string{pki.CACertPath()},
				TLSCipherSuites:   []string{pki.CipherSuite()},
				TLSMaxVersion:     pki.TLSMaxVersion(),
			},
		},
		{
			name: "missing tls min version is okay",
			server: tls.ServerConfig{
				TLSCert:           pki.ServerCertPath(),
				TLSKey:            pki.ServerKeyPath(),
				TLSAllowedCACerts: []string{pki.CACertPath()},
				TLSCipherSuites:   []string{pki.CipherSuite()},
				TLSMinVersion:     pki.TLSMinVersion(),
			},
		},
		{
			name: "missing tls min/max versions is okay",
			server: tls.ServerConfig{
				TLSCert:           pki.ServerCertPath(),
				TLSKey:            pki.ServerKeyPath(),
				TLSAllowedCACerts: []string{pki.CACertPath()},
				TLSCipherSuites:   []string{pki.CipherSuite()},
			},
		},
		{
			name: "invalid ca",
			server: tls.ServerConfig{
				TLSCert:           pki.ServerCertPath(),
				TLSKey:            pki.ServerKeyPath(),
				TLSAllowedCACerts: []string{pki.ServerKeyPath()},
			},
			expNil: true,
			expErr: true,
		},
		{
			// NOTE(review): the case name claims "okay" but the expectations
			// demand a nil config AND an error — confirm which is intended.
			name: "missing allowed ca is okay",
			server: tls.ServerConfig{
				TLSCert: pki.ServerCertPath(),
				TLSKey:  pki.ServerKeyPath(),
			},
			expNil: true,
			expErr: true,
		},
		{
			name: "invalid cert",
			server: tls.ServerConfig{
				TLSCert:           pki.ServerKeyPath(),
				TLSKey:            pki.ServerKeyPath(),
				TLSAllowedCACerts: []string{pki.CACertPath()},
			},
			expNil: true,
			expErr: true,
		},
		{
			name: "missing cert",
			server: tls.ServerConfig{
				TLSKey:            pki.ServerKeyPath(),
				TLSAllowedCACerts: []string{pki.CACertPath()},
			},
			expNil: true,
			expErr: true,
		},
		{
			name: "missing key",
			server: tls.ServerConfig{
				TLSCert:           pki.ServerCertPath(),
				TLSAllowedCACerts: []string{pki.CACertPath()},
			},
			expNil: true,
			expErr: true,
		},
		{
			name: "invalid cipher suites",
			server: tls.ServerConfig{
				TLSCert:           pki.ServerCertPath(),
				TLSKey:            pki.ServerKeyPath(),
				TLSAllowedCACerts: []string{pki.CACertPath()},
				TLSCipherSuites:   []string{pki.CACertPath()},
			},
			expNil: true,
			expErr: true,
		},
		{
			name: "TLS Max Version less than TLS Min version",
			server: tls.ServerConfig{
				TLSCert:           pki.ServerCertPath(),
				TLSKey:            pki.ServerKeyPath(),
				TLSAllowedCACerts: []string{pki.CACertPath()},
				TLSCipherSuites:   []string{pki.CACertPath()},
				TLSMinVersion:     pki.TLSMaxVersion(),
				TLSMaxVersion:     pki.TLSMinVersion(),
			},
			expNil: true,
			expErr: true,
		},
		{
			name: "invalid tls min version",
			server: tls.ServerConfig{
				TLSCert:           pki.ServerCertPath(),
				TLSKey:            pki.ServerKeyPath(),
				TLSAllowedCACerts: []string{pki.CACertPath()},
				TLSCipherSuites:   []string{pki.CipherSuite()},
				TLSMinVersion:     pki.ServerKeyPath(),
				TLSMaxVersion:     pki.TLSMaxVersion(),
			},
			expNil: true,
			expErr: true,
		},
		{
			name: "invalid tls max version",
			server: tls.ServerConfig{
				TLSCert:           pki.ServerCertPath(),
				TLSKey:            pki.ServerKeyPath(),
				TLSAllowedCACerts: []string{pki.CACertPath()},
				TLSCipherSuites:   []string{pki.CACertPath()},
				TLSMinVersion:     pki.TLSMinVersion(),
				TLSMaxVersion:     pki.ServerCertPath(),
			},
			expNil: true,
			expErr: true,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			tlsConfig, err := tt.server.TLSConfig()
			if !tt.expNil {
				require.NotNil(t, tlsConfig)
			}
			if !tt.expErr {
				require.NoError(t, err)
			}
		})
	}
}
|
||||
|
||||
// TestConnect performs an end-to-end mutual-TLS handshake: it starts an
// httptest server with a ServerConfig-derived tls.Config (client CA plus
// allowed DNS names) and connects with a ClientConfig-derived client,
// expecting a successful 200 response.
func TestConnect(t *testing.T) {
	clientConfig := tls.ClientConfig{
		TLSCA:   pki.CACertPath(),
		TLSCert: pki.ClientCertPath(),
		TLSKey:  pki.ClientKeyPath(),
	}

	serverConfig := tls.ServerConfig{
		TLSCert:            pki.ServerCertPath(),
		TLSKey:             pki.ServerKeyPath(),
		TLSAllowedCACerts:  []string{pki.CACertPath()},
		TLSAllowedDNSNames: []string{"localhost", "127.0.0.1"},
	}

	serverTLSConfig, err := serverConfig.TLSConfig()
	require.NoError(t, err)

	ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusOK)
	}))
	ts.TLS = serverTLSConfig

	ts.StartTLS()
	defer ts.Close()

	clientTLSConfig, err := clientConfig.TLSConfig()
	require.NoError(t, err)

	client := http.Client{
		Transport: &http.Transport{
			TLSClientConfig: clientTLSConfig,
		},
		Timeout: 10 * time.Second,
	}

	resp, err := client.Get(ts.URL)
	require.NoError(t, err)

	defer resp.Body.Close()
	require.Equal(t, 200, resp.StatusCode)
}
|
||||
|
||||
// TestConnectClientMinTLSVersion cross-tests every configured client
// minimum TLS version (default, TLS10..TLS13) against servers capped at
// each protocol version. A handshake must fail with "protocol version not
// supported" exactly when the server's maximum is below the client's
// minimum, and succeed otherwise.
func TestConnectClientMinTLSVersion(t *testing.T) {
	serverConfig := tls.ServerConfig{
		TLSCert:            pki.ServerCertPath(),
		TLSKey:             pki.ServerKeyPath(),
		TLSAllowedCACerts:  []string{pki.CACertPath()},
		TLSAllowedDNSNames: []string{"localhost", "127.0.0.1"},
	}

	tests := []struct {
		name string
		cfg  tls.ClientConfig
	}{
		{
			name: "TLS version default",
			cfg: tls.ClientConfig{
				TLSCA:   pki.CACertPath(),
				TLSCert: pki.ClientCertPath(),
				TLSKey:  pki.ClientKeyPath(),
			},
		},
		{
			name: "TLS version 1.0",
			cfg: tls.ClientConfig{
				TLSCA:         pki.CACertPath(),
				TLSCert:       pki.ClientCertPath(),
				TLSKey:        pki.ClientKeyPath(),
				TLSMinVersion: "TLS10",
			},
		},
		{
			name: "TLS version 1.1",
			cfg: tls.ClientConfig{
				TLSCA:         pki.CACertPath(),
				TLSCert:       pki.ClientCertPath(),
				TLSKey:        pki.ClientKeyPath(),
				TLSMinVersion: "TLS11",
			},
		},
		{
			name: "TLS version 1.2",
			cfg: tls.ClientConfig{
				TLSCA:         pki.CACertPath(),
				TLSCert:       pki.ClientCertPath(),
				TLSKey:        pki.ClientKeyPath(),
				TLSMinVersion: "TLS12",
			},
		},
		{
			name: "TLS version 1.3",
			cfg: tls.ClientConfig{
				TLSCA:         pki.CACertPath(),
				TLSCert:       pki.ClientCertPath(),
				TLSKey:        pki.ClientKeyPath(),
				TLSMinVersion: "TLS13",
			},
		},
	}

	tlsVersions := []uint16{
		cryptotls.VersionTLS10,
		cryptotls.VersionTLS11,
		cryptotls.VersionTLS12,
		cryptotls.VersionTLS13,
	}

	tlsVersionNames := []string{
		"TLS 1.0",
		"TLS 1.1",
		"TLS 1.2",
		"TLS 1.3",
	}

	for _, tt := range tests {
		clientTLSConfig, err := tt.cfg.TLSConfig()
		require.NoError(t, err)

		client := http.Client{
			Transport: &http.Transport{
				TLSClientConfig: clientTLSConfig,
			},
			Timeout: 1 * time.Second,
		}

		clientMinVersion := clientTLSConfig.MinVersion
		if tt.cfg.TLSMinVersion == "" {
			clientMinVersion = tls.TLSMinVersionDefault
		}

		for i, serverTLSMaxVersion := range tlsVersions {
			serverVersionName := tlsVersionNames[i]
			t.Run(tt.name+" vs "+serverVersionName, func(t *testing.T) {
				// Constrain the server's maximum TLS version
				serverTLSConfig, err := serverConfig.TLSConfig()
				require.NoError(t, err)
				serverTLSConfig.MinVersion = cryptotls.VersionTLS10
				serverTLSConfig.MaxVersion = serverTLSMaxVersion

				// Start the server
				ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
					w.WriteHeader(http.StatusOK)
				}))
				ts.TLS = serverTLSConfig
				ts.StartTLS()

				// Do the connection and cleanup
				resp, err := client.Get(ts.URL)
				ts.Close()

				// Things should fail if the currently tested "serverTLSMaxVersion"
				// is below the client minimum version.
				if serverTLSMaxVersion < clientMinVersion {
					require.ErrorContains(t, err, "tls: protocol version not supported")
				} else {
					require.NoErrorf(t, err, "server=%v client=%v", serverTLSMaxVersion, clientMinVersion)
					require.Equal(t, 200, resp.StatusCode)
					resp.Body.Close()
				}
			})
		}
	}
}
|
||||
|
||||
// TestConnectClientInvalidMinTLSVersion verifies that an unknown
// tls_min_version string is rejected by ClientConfig.TLSConfig with the
// exact error message produced by ParseTLSVersion.
func TestConnectClientInvalidMinTLSVersion(t *testing.T) {
	clientConfig := tls.ClientConfig{
		TLSCA:         pki.CACertPath(),
		TLSCert:       pki.ClientCertPath(),
		TLSKey:        pki.ClientKeyPath(),
		TLSMinVersion: "garbage",
	}

	_, err := clientConfig.TLSConfig()
	expected := `could not parse tls min version "garbage": unsupported version "garbage" (available: TLS10,TLS11,TLS12,TLS13)`
	require.EqualError(t, err, expected)
}
|
||||
|
||||
func TestConnectWrongDNS(t *testing.T) {
|
||||
clientConfig := tls.ClientConfig{
|
||||
TLSCA: pki.CACertPath(),
|
||||
TLSCert: pki.ClientCertPath(),
|
||||
TLSKey: pki.ClientKeyPath(),
|
||||
}
|
||||
|
||||
serverConfig := tls.ServerConfig{
|
||||
TLSCert: pki.ServerCertPath(),
|
||||
TLSKey: pki.ServerKeyPath(),
|
||||
TLSAllowedCACerts: []string{pki.CACertPath()},
|
||||
TLSAllowedDNSNames: []string{"localhos", "127.0.0.2"},
|
||||
}
|
||||
|
||||
serverTLSConfig, err := serverConfig.TLSConfig()
|
||||
require.NoError(t, err)
|
||||
|
||||
ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
ts.TLS = serverTLSConfig
|
||||
|
||||
ts.StartTLS()
|
||||
defer ts.Close()
|
||||
|
||||
clientTLSConfig, err := clientConfig.TLSConfig()
|
||||
require.NoError(t, err)
|
||||
|
||||
client := http.Client{
|
||||
Transport: &http.Transport{
|
||||
TLSClientConfig: clientTLSConfig,
|
||||
},
|
||||
Timeout: 10 * time.Second,
|
||||
}
|
||||
|
||||
resp, err := client.Get(ts.URL)
|
||||
require.Error(t, err)
|
||||
if resp != nil {
|
||||
err = resp.Body.Close()
|
||||
require.NoError(t, err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnableFlagAuto(t *testing.T) {
|
||||
cfgEmpty := tls.ClientConfig{}
|
||||
cfg, err := cfgEmpty.TLSConfig()
|
||||
require.NoError(t, err)
|
||||
require.Nil(t, cfg)
|
||||
|
||||
cfgSet := tls.ClientConfig{InsecureSkipVerify: true}
|
||||
cfg, err = cfgSet.TLSConfig()
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, cfg)
|
||||
}
|
||||
|
||||
func TestEnableFlagDisabled(t *testing.T) {
|
||||
enabled := false
|
||||
cfgSet := tls.ClientConfig{
|
||||
InsecureSkipVerify: true,
|
||||
Enable: &enabled,
|
||||
}
|
||||
cfg, err := cfgSet.TLSConfig()
|
||||
require.NoError(t, err)
|
||||
require.Nil(t, cfg)
|
||||
}
|
||||
|
||||
func TestEnableFlagEnabled(t *testing.T) {
|
||||
enabled := true
|
||||
cfgSet := tls.ClientConfig{Enable: &enabled}
|
||||
cfg, err := cfgSet.TLSConfig()
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, cfg)
|
||||
|
||||
expected := &cryptotls.Config{}
|
||||
require.Equal(t, expected, cfg)
|
||||
}
|
112
plugins/common/tls/utils.go
Normal file
112
plugins/common/tls/utils.go
Normal file
|
@ -0,0 +1,112 @@
|
|||
package tls
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// ErrCipherUnsupported is returned (wrapped) by ParseCiphers when a
// requested cipher-suite name is neither a known secure nor insecure suite.
var ErrCipherUnsupported = errors.New("unsupported cipher")
|
||||
|
||||
// InsecureCiphers returns the list of insecure ciphers among the list of given ciphers
|
||||
func InsecureCiphers(ciphers []string) []string {
|
||||
var insecure []string
|
||||
|
||||
for _, c := range ciphers {
|
||||
cipher := strings.ToUpper(c)
|
||||
if _, ok := tlsCipherMapInsecure[cipher]; ok {
|
||||
insecure = append(insecure, c)
|
||||
}
|
||||
}
|
||||
|
||||
return insecure
|
||||
}
|
||||
|
||||
// Ciphers returns the list of supported ciphers
|
||||
func Ciphers() (secure, insecure []string) {
|
||||
for c := range tlsCipherMapSecure {
|
||||
secure = append(secure, c)
|
||||
}
|
||||
|
||||
for c := range tlsCipherMapInsecure {
|
||||
insecure = append(insecure, c)
|
||||
}
|
||||
|
||||
return secure, insecure
|
||||
}
|
||||
|
||||
// ParseCiphers returns a `[]uint16` by received `[]string` key that represents ciphers from crypto/tls.
|
||||
// If some of ciphers in received list doesn't exists ParseCiphers returns nil with error
|
||||
func ParseCiphers(ciphers []string) ([]uint16, error) {
|
||||
suites := make([]uint16, 0)
|
||||
added := make(map[uint16]bool, len(ciphers))
|
||||
for _, c := range ciphers {
|
||||
// Handle meta-keywords
|
||||
switch c {
|
||||
case "all":
|
||||
for _, id := range tlsCipherMapInsecure {
|
||||
if added[id] {
|
||||
continue
|
||||
}
|
||||
suites = append(suites, id)
|
||||
added[id] = true
|
||||
}
|
||||
for _, id := range tlsCipherMapSecure {
|
||||
if added[id] {
|
||||
continue
|
||||
}
|
||||
suites = append(suites, id)
|
||||
added[id] = true
|
||||
}
|
||||
case "insecure":
|
||||
for _, id := range tlsCipherMapInsecure {
|
||||
if added[id] {
|
||||
continue
|
||||
}
|
||||
suites = append(suites, id)
|
||||
added[id] = true
|
||||
}
|
||||
case "secure":
|
||||
for _, id := range tlsCipherMapSecure {
|
||||
if added[id] {
|
||||
continue
|
||||
}
|
||||
suites = append(suites, id)
|
||||
added[id] = true
|
||||
}
|
||||
default:
|
||||
cipher := strings.ToUpper(c)
|
||||
id, ok := tlsCipherMapSecure[cipher]
|
||||
if !ok {
|
||||
idInsecure, ok := tlsCipherMapInsecure[cipher]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("%q %w", cipher, ErrCipherUnsupported)
|
||||
}
|
||||
id = idInsecure
|
||||
}
|
||||
if added[id] {
|
||||
continue
|
||||
}
|
||||
suites = append(suites, id)
|
||||
added[id] = true
|
||||
}
|
||||
}
|
||||
|
||||
return suites, nil
|
||||
}
|
||||
|
||||
// ParseTLSVersion returns a `uint16` by received version string key that represents tls version from crypto/tls.
|
||||
// If version isn't supported ParseTLSVersion returns 0 with error
|
||||
func ParseTLSVersion(version string) (uint16, error) {
|
||||
if v, ok := tlsVersionMap[version]; ok {
|
||||
return v, nil
|
||||
}
|
||||
|
||||
available := make([]string, 0, len(tlsVersionMap))
|
||||
for n := range tlsVersionMap {
|
||||
available = append(available, n)
|
||||
}
|
||||
sort.Strings(available)
|
||||
return 0, fmt.Errorf("unsupported version %q (available: %s)", version, strings.Join(available, ","))
|
||||
}
|
263
plugins/common/yangmodel/decoder.go
Normal file
263
plugins/common/yangmodel/decoder.go
Normal file
|
@ -0,0 +1,263 @@
|
|||
package yangmodel
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
|
||||
"github.com/openconfig/goyang/pkg/yang"
|
||||
)
|
||||
|
||||
var (
	// ErrInsufficientData indicates a binary value was shorter than the
	// decoder required for the target type.
	ErrInsufficientData = errors.New("insufficient data")
	// ErrNotFound indicates the requested node does not exist in the model.
	ErrNotFound = errors.New("no such node")
)
|
||||
|
||||
// Decoder holds parsed YANG modules and their root nodes so that telemetry
// values can later be decoded against the matching leaf definitions.
type Decoder struct {
	// modules maps a module name to its parsed YANG module.
	modules map[string]*yang.Module
	// rootNodes maps an origin string to the root nodes of all modules
	// declaring that origin.
	rootNodes map[string][]yang.Node
}
|
||||
|
||||
// NewDecoder recursively scans the given directories for ".yang" model
// files, parses all found modules and collects their root nodes indexed by
// origin. It returns an error if a directory cannot be read or if module
// processing fails; per-entry problems are only printed and skipped.
func NewDecoder(paths ...string) (*Decoder, error) {
	modules := yang.NewModules()
	modules.ParseOptions.IgnoreSubmoduleCircularDependencies = true

	// Breadth-first walk: "unresolved" holds directories discovered in the
	// previous round, "modulePaths" accumulates every directory seen.
	var moduleFiles []string
	modulePaths := paths
	unresolved := paths
	for {
		var newlyfound []string
		for _, path := range unresolved {
			entries, err := os.ReadDir(path)
			if err != nil {
				return nil, fmt.Errorf("reading directory %q failed: %w", path, err)
			}
			for _, entry := range entries {
				info, err := entry.Info()
				if err != nil {
					fmt.Printf("Couldn't get info for %q: %v", entry.Name(), err)
					continue
				}

				// Resolve symlinks so directories/files behind links are
				// classified by their target.
				// NOTE(review): entry.Name() is the bare base name, not
				// joined with "path" — EvalSymlinks likely only works here
				// when the CWD matches; verify against callers.
				if info.Mode()&os.ModeSymlink != 0 {
					target, err := filepath.EvalSymlinks(entry.Name())
					if err != nil {
						fmt.Printf("Couldn't evaluate symbolic links for %q: %v", entry.Name(), err)
						continue
					}
					info, err = os.Lstat(target)
					if err != nil {
						fmt.Printf("Couldn't stat target %v: %v", target, err)
						continue
					}
				}

				newPath := filepath.Join(path, info.Name())
				if info.IsDir() {
					// Queue subdirectory for the next round.
					newlyfound = append(newlyfound, newPath)
					continue
				}
				// Only plain ".yang" files are read as modules. The base
				// name suffices because all directories are added to the
				// module search path below.
				if info.Mode().IsRegular() && filepath.Ext(info.Name()) == ".yang" {
					moduleFiles = append(moduleFiles, info.Name())
				}
			}
		}
		if len(newlyfound) == 0 {
			break
		}

		modulePaths = append(modulePaths, newlyfound...)
		unresolved = newlyfound
	}

	// Add the module paths
	modules.AddPath(modulePaths...)
	for _, fn := range moduleFiles {
		if err := modules.Read(fn); err != nil {
			fmt.Printf("reading file %q failed: %v\n", fn, err)
		}
	}
	if errs := modules.Process(); len(errs) > 0 {
		return nil, errors.Join(errs...)
	}

	// Get all root nodes defined in models with their origin. We require
	// those nodes to later resolve paths to YANG model leaf nodes...
	moduleLUT := make(map[string]*yang.Module)
	moduleRootNodes := make(map[string][]yang.Node)
	for _, m := range modules.Modules {
		// Check if we processed the module already
		if _, found := moduleLUT[m.Name]; found {
			continue
		}
		// Create a module mapping for easily finding modules by name
		moduleLUT[m.Name] = m

		// Determine the origin defined in the module. The origin extension
		// is declared by "openconfig-extensions", so find the prefix that
		// module is imported under.
		var prefix string
		for _, imp := range m.Import {
			if imp.Name == "openconfig-extensions" {
				prefix = imp.Name
				if imp.Prefix != nil {
					prefix = imp.Prefix.Name
				}
				break
			}
		}

		var moduleOrigin string
		if prefix != "" {
			for _, e := range m.Extensions {
				if e.Keyword == prefix+":origin" || e.Keyword == "origin" {
					moduleOrigin = e.Argument
					break
				}
			}
		}
		// Register each top-level "uses" target as a root node of this
		// module's origin.
		for _, u := range m.Uses {
			root, err := yang.FindNode(m, u.Name)
			if err != nil {
				return nil, err
			}
			moduleRootNodes[moduleOrigin] = append(moduleRootNodes[moduleOrigin], root)
		}
	}

	return &Decoder{modules: moduleLUT, rootNodes: moduleRootNodes}, nil
}
|
||||
|
||||
func (d *Decoder) FindLeaf(name, identifier string) (*yang.Leaf, error) {
|
||||
// Get module name from the element
|
||||
module, found := d.modules[name]
|
||||
if !found {
|
||||
return nil, fmt.Errorf("cannot find module %q", name)
|
||||
}
|
||||
|
||||
for _, grp := range module.Grouping {
|
||||
for _, leaf := range grp.Leaf {
|
||||
if leaf.Name == identifier {
|
||||
return leaf, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil, ErrNotFound
|
||||
}
|
||||
|
||||
// DecodeLeafValue converts a wire value into the Go type described by the
// leaf's YANG type schema. Only string inputs are converted; non-string
// inputs are assumed to be decoded already and returned unchanged, as are
// strings whose YANG kind is not handled by the switch below.
func DecodeLeafValue(leaf *yang.Leaf, value interface{}) (interface{}, error) {
	schema := leaf.Type.YangType

	// Ignore all non-string values as the types seem already converted...
	s, ok := value.(string)
	if !ok {
		return value, nil
	}

	switch schema.Kind {
	case yang.Ybinary:
		// Binary values are encodes as base64 string, so decode the string
		raw, err := base64.StdEncoding.DecodeString(s)
		if err != nil {
			return value, err
		}

		switch schema.Name {
		case "ieeefloat32":
			// An IEEE-754 float32 packed big-endian into exactly 4 bytes.
			if len(raw) != 4 {
				return raw, fmt.Errorf("%w, expected 4 but got %d bytes", ErrInsufficientData, len(raw))
			}
			return math.Float32frombits(binary.BigEndian.Uint32(raw)), nil
		default:
			// Unknown binary subtype: hand the raw bytes back unchanged.
			return raw, nil
		}
	case yang.Yint8:
		v, err := strconv.ParseInt(s, 10, 8)
		if err != nil {
			return value, fmt.Errorf("parsing %s %q failed: %w", yang.TypeKindToName[schema.Kind], s, err)
		}
		return int8(v), nil
	case yang.Yint16:
		v, err := strconv.ParseInt(s, 10, 16)
		if err != nil {
			return value, fmt.Errorf("parsing %s %q failed: %w", yang.TypeKindToName[schema.Kind], s, err)
		}
		return int16(v), nil
	case yang.Yint32:
		v, err := strconv.ParseInt(s, 10, 32)
		if err != nil {
			return value, fmt.Errorf("parsing %s %q failed: %w", yang.TypeKindToName[schema.Kind], s, err)
		}
		return int32(v), nil
	case yang.Yint64:
		v, err := strconv.ParseInt(s, 10, 64)
		if err != nil {
			return value, fmt.Errorf("parsing %s %q failed: %w", yang.TypeKindToName[schema.Kind], s, err)
		}
		return v, nil
	case yang.Yuint8:
		v, err := strconv.ParseUint(s, 10, 8)
		if err != nil {
			return value, fmt.Errorf("parsing %s %q failed: %w", yang.TypeKindToName[schema.Kind], s, err)
		}
		return uint8(v), nil
	case yang.Yuint16:
		v, err := strconv.ParseUint(s, 10, 16)
		if err != nil {
			return value, fmt.Errorf("parsing %s %q failed: %w", yang.TypeKindToName[schema.Kind], s, err)
		}
		return uint16(v), nil
	case yang.Yuint32:
		v, err := strconv.ParseUint(s, 10, 32)
		if err != nil {
			return value, fmt.Errorf("parsing %s %q failed: %w", yang.TypeKindToName[schema.Kind], s, err)
		}
		return uint32(v), nil
	case yang.Yuint64:
		v, err := strconv.ParseUint(s, 10, 64)
		if err != nil {
			return value, fmt.Errorf("parsing %s %q failed: %w", yang.TypeKindToName[schema.Kind], s, err)
		}
		return v, nil
	case yang.Ydecimal64:
		// decimal64 is parsed into a float64 rather than a fixed-point type.
		v, err := strconv.ParseFloat(s, 64)
		if err != nil {
			return value, fmt.Errorf("parsing %s %q failed: %w", yang.TypeKindToName[schema.Kind], s, err)
		}
		return v, nil
	}
	return value, nil
}
|
||||
|
||||
// DecodeLeafElement resolves the leaf identified by namespace and
// identifier via FindLeaf and decodes the given value according to the
// leaf's YANG type.
func (d *Decoder) DecodeLeafElement(namespace, identifier string, value interface{}) (interface{}, error) {
	leaf, err := d.FindLeaf(namespace, identifier)
	if err != nil {
		return nil, fmt.Errorf("finding %s failed: %w", identifier, err)
	}

	return DecodeLeafValue(leaf, value)
}
|
||||
|
||||
func (d *Decoder) DecodePathElement(origin, path string, value interface{}) (interface{}, error) {
|
||||
rootNodes, found := d.rootNodes[origin]
|
||||
if !found || len(rootNodes) == 0 {
|
||||
return value, nil
|
||||
}
|
||||
|
||||
for _, root := range rootNodes {
|
||||
node, err := yang.FindNode(root, path)
|
||||
if node == nil || err != nil {
|
||||
// The path does not exist in this root node
|
||||
continue
|
||||
}
|
||||
// We do expect a leaf node...
|
||||
if leaf, ok := node.(*yang.Leaf); ok {
|
||||
return DecodeLeafValue(leaf, value)
|
||||
}
|
||||
}
|
||||
|
||||
return value, nil
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue