package websocket import ( "net/http" "net/http/httptest" "strconv" "strings" "testing" "time" ws "github.com/gorilla/websocket" "github.com/stretchr/testify/require" "github.com/influxdata/telegraf" "github.com/influxdata/telegraf/config" "github.com/influxdata/telegraf/testutil" ) // testSerializer serializes to a number of metrics to simplify tests here. type testSerializer struct{} func newTestSerializer() *testSerializer { return &testSerializer{} } func (testSerializer) Serialize(_ telegraf.Metric) ([]byte, error) { return []byte("1"), nil } func (testSerializer) SerializeBatch(metrics []telegraf.Metric) ([]byte, error) { return []byte(strconv.Itoa(len(metrics))), nil } type testServer struct { *httptest.Server t *testing.T messages chan []byte upgradeDelay time.Duration expectTextFrames bool } func newTestServer(t *testing.T, messages chan []byte, tls bool) *testServer { s := &testServer{} s.t = t if tls { s.Server = httptest.NewTLSServer(s) } else { s.Server = httptest.NewServer(s) } s.URL = makeWsProto(s.Server.URL) s.messages = messages return s } func makeWsProto(s string) string { return "ws" + strings.TrimPrefix(s, "http") } const ( testHeaderName = "X-Telegraf-Test" testHeaderValue = "1" ) var testUpgrader = ws.Upgrader{} func (s *testServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { if r.Header.Get(testHeaderName) != testHeaderValue { s.t.Fatalf("expected test header found in request, got: %#v", r.Header) } if s.upgradeDelay > 0 { // Emulate long handshake. select { case <-r.Context().Done(): return case <-time.After(s.upgradeDelay): } } conn, err := testUpgrader.Upgrade(w, r, http.Header{}) if err != nil { return } defer func() { _ = conn.Close() }() for { messageType, data, err := conn.ReadMessage() if err != nil { break } if s.expectTextFrames && messageType != ws.TextMessage { s.t.Fatalf("unexpected frame type: %d", messageType) } select { case s.messages <- data: case <-time.After(5 * time.Second): s.t.Fatal("timeout writing to messages channel, make sure there are readers") } } } func initWebSocket(s *testServer) *WebSocket { w := newWebSocket() w.Log = testutil.Logger{} w.URL = s.URL headerSecret := config.NewSecret([]byte(testHeaderValue)) w.Headers = map[string]*config.Secret{testHeaderName: &headerSecret} w.SetSerializer(newTestSerializer()) return w } func connect(t *testing.T, w *WebSocket) { err := w.Connect() require.NoError(t, err) } func TestWebSocket_NoURL(t *testing.T) { w := newWebSocket() err := w.Init() require.ErrorIs(t, err, errInvalidURL) } func TestWebSocket_Connect_Timeout(t *testing.T) { s := newTestServer(t, nil, false) s.upgradeDelay = time.Second defer s.Close() w := initWebSocket(s) w.ConnectTimeout = config.Duration(10 * time.Millisecond) err := w.Connect() require.Error(t, err) } func TestWebSocket_Connect_OK(t *testing.T) { s := newTestServer(t, nil, false) defer s.Close() w := initWebSocket(s) connect(t, w) } func TestWebSocket_ConnectTLS_OK(t *testing.T) { s := newTestServer(t, nil, true) defer s.Close() w := initWebSocket(s) w.ClientConfig.InsecureSkipVerify = true connect(t, w) } func TestWebSocket_Write_OK(t *testing.T) { messages := make(chan []byte, 1) s := newTestServer(t, messages, false) defer s.Close() w := initWebSocket(s) connect(t, w) metrics := []telegraf.Metric{ testutil.TestMetric(0.4, "test"), testutil.TestMetric(0.5, "test"), } err := w.Write(metrics) require.NoError(t, err) select { case data := <-messages: require.Equal(t, []byte("2"), data) case <-time.After(time.Second): t.Fatal("timeout receiving data") } } func TestWebSocket_Write_Error(t *testing.T) { s := newTestServer(t, nil, false) defer s.Close() w := initWebSocket(s) connect(t, w) require.NoError(t, w.conn.Close()) metrics := []telegraf.Metric{testutil.TestMetric(0.4, "test")} err := w.Write(metrics) require.Error(t, err) require.Nil(t, w.conn) } func TestWebSocket_Write_Reconnect(t *testing.T) { messages := make(chan []byte, 1) s := newTestServer(t, messages, false) s.expectTextFrames = true // Also use text frames in this test. defer s.Close() w := initWebSocket(s) w.UseTextFrames = true connect(t, w) metrics := []telegraf.Metric{testutil.TestMetric(0.4, "test")} require.NoError(t, w.conn.Close()) err := w.Write(metrics) require.Error(t, err) require.Nil(t, w.conn) err = w.Write(metrics) require.NoError(t, err) select { case data := <-messages: require.Equal(t, []byte("1"), data) case <-time.After(time.Second): t.Fatal("timeout receiving data") } } func TestWebSocket_Close(t *testing.T) { s := newTestServer(t, nil, false) defer s.Close() w := initWebSocket(s) connect(t, w) require.NoError(t, w.Close()) // Check no error on second close. require.NoError(t, w.Close()) }