224 lines
4.8 KiB
Go
224 lines
4.8 KiB
Go
|
package websocket
|
||
|
|
||
|
import (
|
||
|
"net/http"
|
||
|
"net/http/httptest"
|
||
|
"strconv"
|
||
|
"strings"
|
||
|
"testing"
|
||
|
"time"
|
||
|
|
||
|
ws "github.com/gorilla/websocket"
|
||
|
"github.com/stretchr/testify/require"
|
||
|
|
||
|
"github.com/influxdata/telegraf"
|
||
|
"github.com/influxdata/telegraf/config"
|
||
|
"github.com/influxdata/telegraf/testutil"
|
||
|
)
|
||
|
|
||
|
// testSerializer is a stub serializer whose output is simply the number of
// metrics it was asked to serialize, which keeps the assertions in these
// tests simple.
type testSerializer struct{}
|
||
|
|
||
|
func newTestSerializer() *testSerializer {
|
||
|
return &testSerializer{}
|
||
|
}
|
||
|
|
||
|
func (testSerializer) Serialize(_ telegraf.Metric) ([]byte, error) {
|
||
|
return []byte("1"), nil
|
||
|
}
|
||
|
|
||
|
func (testSerializer) SerializeBatch(metrics []telegraf.Metric) ([]byte, error) {
|
||
|
return []byte(strconv.Itoa(len(metrics))), nil
|
||
|
}
|
||
|
|
||
|
// testServer is an in-process HTTP server that upgrades incoming requests to
// websocket connections and forwards every message it reads to a channel so
// that tests can inspect what the plugin sent.
type testServer struct {
	// Embedded HTTP test server; its URL is rewritten to a ws/wss URL by
	// newTestServer.
	*httptest.Server
	// t is used by the handler to report failures to the owning test.
	t *testing.T
	// messages receives the payload of each message read from a connection.
	messages chan []byte
	// upgradeDelay, when positive, makes the handler wait that long (or until
	// the request context is done) before upgrading the connection.
	upgradeDelay time.Duration
	// expectTextFrames, when true, makes the handler report a failure for any
	// frame that is not a text frame.
	expectTextFrames bool
}
|
||
|
|
||
|
func newTestServer(t *testing.T, messages chan []byte, tls bool) *testServer {
|
||
|
s := &testServer{}
|
||
|
s.t = t
|
||
|
if tls {
|
||
|
s.Server = httptest.NewTLSServer(s)
|
||
|
} else {
|
||
|
s.Server = httptest.NewServer(s)
|
||
|
}
|
||
|
s.URL = makeWsProto(s.Server.URL)
|
||
|
s.messages = messages
|
||
|
return s
|
||
|
}
|
||
|
|
||
|
// makeWsProto turns an http(s) URL into its websocket counterpart by swapping
// the leading "http" for "ws": "http://..." becomes "ws://..." and
// "https://..." becomes "wss://...". Input without the "http" prefix simply
// gets "ws" prepended.
func makeWsProto(s string) string {
	const httpPrefix = "http"

	rest := s
	if strings.HasPrefix(s, httpPrefix) {
		rest = s[len(httpPrefix):]
	}
	return "ws" + rest
}
|
||
|
|
||
|
// Header that the test server requires on every incoming request; the plugin
// under test is configured to send it by initWebSocket.
const (
	// testHeaderName is the name of the required request header.
	testHeaderName = "X-Telegraf-Test"
	// testHeaderValue is the value the required header must carry.
	testHeaderValue = "1"
)
|
||
|
|
||
|
// testUpgrader is the zero-value upgrader the test server uses to turn HTTP
// requests into websocket connections.
var testUpgrader = ws.Upgrader{}
|
||
|
|
||
|
func (s *testServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||
|
if r.Header.Get(testHeaderName) != testHeaderValue {
|
||
|
s.t.Fatalf("expected test header found in request, got: %#v", r.Header)
|
||
|
}
|
||
|
if s.upgradeDelay > 0 {
|
||
|
// Emulate long handshake.
|
||
|
select {
|
||
|
case <-r.Context().Done():
|
||
|
return
|
||
|
case <-time.After(s.upgradeDelay):
|
||
|
}
|
||
|
}
|
||
|
conn, err := testUpgrader.Upgrade(w, r, http.Header{})
|
||
|
if err != nil {
|
||
|
return
|
||
|
}
|
||
|
defer func() { _ = conn.Close() }()
|
||
|
|
||
|
for {
|
||
|
messageType, data, err := conn.ReadMessage()
|
||
|
if err != nil {
|
||
|
break
|
||
|
}
|
||
|
if s.expectTextFrames && messageType != ws.TextMessage {
|
||
|
s.t.Fatalf("unexpected frame type: %d", messageType)
|
||
|
}
|
||
|
select {
|
||
|
case s.messages <- data:
|
||
|
case <-time.After(5 * time.Second):
|
||
|
s.t.Fatal("timeout writing to messages channel, make sure there are readers")
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func initWebSocket(s *testServer) *WebSocket {
|
||
|
w := newWebSocket()
|
||
|
w.Log = testutil.Logger{}
|
||
|
w.URL = s.URL
|
||
|
headerSecret := config.NewSecret([]byte(testHeaderValue))
|
||
|
w.Headers = map[string]*config.Secret{testHeaderName: &headerSecret}
|
||
|
w.SetSerializer(newTestSerializer())
|
||
|
return w
|
||
|
}
|
||
|
|
||
|
func connect(t *testing.T, w *WebSocket) {
|
||
|
err := w.Connect()
|
||
|
require.NoError(t, err)
|
||
|
}
|
||
|
|
||
|
func TestWebSocket_NoURL(t *testing.T) {
|
||
|
w := newWebSocket()
|
||
|
err := w.Init()
|
||
|
require.ErrorIs(t, err, errInvalidURL)
|
||
|
}
|
||
|
|
||
|
func TestWebSocket_Connect_Timeout(t *testing.T) {
|
||
|
s := newTestServer(t, nil, false)
|
||
|
s.upgradeDelay = time.Second
|
||
|
defer s.Close()
|
||
|
w := initWebSocket(s)
|
||
|
w.ConnectTimeout = config.Duration(10 * time.Millisecond)
|
||
|
err := w.Connect()
|
||
|
require.Error(t, err)
|
||
|
}
|
||
|
|
||
|
func TestWebSocket_Connect_OK(t *testing.T) {
|
||
|
s := newTestServer(t, nil, false)
|
||
|
defer s.Close()
|
||
|
w := initWebSocket(s)
|
||
|
connect(t, w)
|
||
|
}
|
||
|
|
||
|
func TestWebSocket_ConnectTLS_OK(t *testing.T) {
|
||
|
s := newTestServer(t, nil, true)
|
||
|
defer s.Close()
|
||
|
w := initWebSocket(s)
|
||
|
w.ClientConfig.InsecureSkipVerify = true
|
||
|
connect(t, w)
|
||
|
}
|
||
|
|
||
|
func TestWebSocket_Write_OK(t *testing.T) {
|
||
|
messages := make(chan []byte, 1)
|
||
|
|
||
|
s := newTestServer(t, messages, false)
|
||
|
defer s.Close()
|
||
|
|
||
|
w := initWebSocket(s)
|
||
|
connect(t, w)
|
||
|
|
||
|
metrics := []telegraf.Metric{
|
||
|
testutil.TestMetric(0.4, "test"),
|
||
|
testutil.TestMetric(0.5, "test"),
|
||
|
}
|
||
|
err := w.Write(metrics)
|
||
|
require.NoError(t, err)
|
||
|
|
||
|
select {
|
||
|
case data := <-messages:
|
||
|
require.Equal(t, []byte("2"), data)
|
||
|
case <-time.After(time.Second):
|
||
|
t.Fatal("timeout receiving data")
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func TestWebSocket_Write_Error(t *testing.T) {
|
||
|
s := newTestServer(t, nil, false)
|
||
|
defer s.Close()
|
||
|
|
||
|
w := initWebSocket(s)
|
||
|
connect(t, w)
|
||
|
|
||
|
require.NoError(t, w.conn.Close())
|
||
|
|
||
|
metrics := []telegraf.Metric{testutil.TestMetric(0.4, "test")}
|
||
|
err := w.Write(metrics)
|
||
|
require.Error(t, err)
|
||
|
require.Nil(t, w.conn)
|
||
|
}
|
||
|
|
||
|
func TestWebSocket_Write_Reconnect(t *testing.T) {
|
||
|
messages := make(chan []byte, 1)
|
||
|
s := newTestServer(t, messages, false)
|
||
|
s.expectTextFrames = true // Also use text frames in this test.
|
||
|
defer s.Close()
|
||
|
|
||
|
w := initWebSocket(s)
|
||
|
w.UseTextFrames = true
|
||
|
connect(t, w)
|
||
|
|
||
|
metrics := []telegraf.Metric{testutil.TestMetric(0.4, "test")}
|
||
|
|
||
|
require.NoError(t, w.conn.Close())
|
||
|
|
||
|
err := w.Write(metrics)
|
||
|
require.Error(t, err)
|
||
|
require.Nil(t, w.conn)
|
||
|
|
||
|
err = w.Write(metrics)
|
||
|
require.NoError(t, err)
|
||
|
|
||
|
select {
|
||
|
case data := <-messages:
|
||
|
require.Equal(t, []byte("1"), data)
|
||
|
case <-time.After(time.Second):
|
||
|
t.Fatal("timeout receiving data")
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func TestWebSocket_Close(t *testing.T) {
|
||
|
s := newTestServer(t, nil, false)
|
||
|
defer s.Close()
|
||
|
|
||
|
w := initWebSocket(s)
|
||
|
connect(t, w)
|
||
|
require.NoError(t, w.Close())
|
||
|
// Check no error on second close.
|
||
|
require.NoError(t, w.Close())
|
||
|
}
|