//go:build windows package win_services import ( "bytes" "errors" "fmt" "log" "testing" "github.com/stretchr/testify/require" "golang.org/x/sys/windows/svc" "golang.org/x/sys/windows/svc/mgr" "github.com/influxdata/telegraf/testutil" ) // testData is DD wrapper for unit testing of WinServices type testData struct { // collection that will be returned in listServices if service array passed into WinServices constructor is empty queryServiceList []string mgrConnectError error mgrListServicesError error services []serviceTestInfo } type serviceTestInfo struct { serviceOpenError error serviceQueryError error serviceConfigError error serviceName string displayName string state int startUpMode int } type FakeSvcMgr struct { testData testData } func (*FakeSvcMgr) disconnect() error { return nil } func (m *FakeSvcMgr) openService(name string) (winService, error) { for _, s := range m.testData.services { if s.serviceName == name { if s.serviceOpenError != nil { return nil, s.serviceOpenError } return &fakeWinSvc{s}, nil } } return nil, fmt.Errorf("cannot find service %q", name) } func (m *FakeSvcMgr) listServices() ([]string, error) { if m.testData.mgrListServicesError != nil { return nil, m.testData.mgrListServicesError } return m.testData.queryServiceList, nil } type FakeMgProvider struct { testData testData } func (m *FakeMgProvider) connect() (winServiceManager, error) { if m.testData.mgrConnectError != nil { return nil, m.testData.mgrConnectError } return &FakeSvcMgr{m.testData}, nil } type fakeWinSvc struct { testData serviceTestInfo } func (*fakeWinSvc) Close() error { return nil } func (m *fakeWinSvc) Config() (mgr.Config, error) { if m.testData.serviceConfigError != nil { return mgr.Config{}, m.testData.serviceConfigError } return mgr.Config{ ServiceType: 0, StartType: uint32(m.testData.startUpMode), ErrorControl: 0, BinaryPathName: "", LoadOrderGroup: "", TagId: 0, Dependencies: nil, ServiceStartName: m.testData.serviceName, DisplayName: m.testData.displayName, Password: "", Description: "", }, nil } func (m *fakeWinSvc) Query() (svc.Status, error) { if m.testData.serviceQueryError != nil { return svc.Status{}, m.testData.serviceQueryError } return svc.Status{ State: svc.State(m.testData.state), Accepts: 0, CheckPoint: 0, WaitHint: 0, }, nil } var testErrors = []testData{ {nil, errors.New("fake mgr connect error"), nil, nil}, {nil, nil, errors.New("fake mgr list services error"), nil}, {[]string{"Fake service 1", "Fake service 2", "Fake service 3"}, nil, nil, []serviceTestInfo{ {errors.New("fake srv open error"), nil, nil, "Fake service 1", "", 0, 0}, {nil, errors.New("fake srv query error"), nil, "Fake service 2", "", 0, 0}, {nil, nil, errors.New("fake srv config error"), "Fake service 3", "", 0, 0}, }}, {[]string{"Fake service 1"}, nil, nil, []serviceTestInfo{ {errors.New("fake srv open error"), nil, nil, "Fake service 1", "", 0, 0}, }}, } func TestMgrErrors(t *testing.T) { // mgr.connect error winServices := &WinServices{ Log: testutil.Logger{}, mgrProvider: &FakeMgProvider{testErrors[0]}, } var acc1 testutil.Accumulator err := winServices.Gather(&acc1) require.Error(t, err) require.Contains(t, err.Error(), testErrors[0].mgrConnectError.Error()) // mgr.listServices error winServices = &WinServices{ Log: testutil.Logger{}, mgrProvider: &FakeMgProvider{testErrors[1]}, } var acc2 testutil.Accumulator err = winServices.Gather(&acc2) require.Error(t, err) require.Contains(t, err.Error(), testErrors[1].mgrListServicesError.Error()) // mgr.listServices error 2 winServices = &WinServices{ Log: testutil.Logger{}, ServiceNames: []string{"Fake service 1"}, mgrProvider: &FakeMgProvider{testErrors[3]}, } err = winServices.Init() require.NoError(t, err) var acc3 testutil.Accumulator buf := &bytes.Buffer{} log.SetOutput(buf) require.NoError(t, winServices.Gather(&acc3)) require.Contains(t, buf.String(), testErrors[2].services[0].serviceOpenError.Error()) } func TestServiceErrors(t *testing.T) { winServices := &WinServices{ Log: testutil.Logger{}, mgrProvider: &FakeMgProvider{testErrors[2]}, } err := winServices.Init() require.NoError(t, err) var acc1 testutil.Accumulator buf := &bytes.Buffer{} log.SetOutput(buf) require.NoError(t, winServices.Gather(&acc1)) // open service error require.Contains(t, buf.String(), testErrors[2].services[0].serviceOpenError.Error()) // query service error require.Contains(t, buf.String(), testErrors[2].services[1].serviceQueryError.Error()) // config service error require.Contains(t, buf.String(), testErrors[2].services[2].serviceConfigError.Error()) } var testSimpleData = []testData{ {[]string{"Service 1", "Service 2"}, nil, nil, []serviceTestInfo{ {nil, nil, nil, "Service 1", "Fake service 1", 1, 2}, {nil, nil, nil, "Service 2", "Fake service 2", 1, 2}, }}, } func TestGatherContainsTag(t *testing.T) { winServices := &WinServices{ Log: testutil.Logger{}, ServiceNames: []string{"Service*"}, mgrProvider: &FakeMgProvider{testSimpleData[0]}, } err := winServices.Init() require.NoError(t, err) var acc1 testutil.Accumulator require.NoError(t, winServices.Gather(&acc1)) require.Empty(t, acc1.Errors, "There should be no errors after gather") for _, s := range testSimpleData[0].services { fields := make(map[string]interface{}) tags := make(map[string]string) fields["state"] = s.state fields["startup_mode"] = s.startUpMode tags["service_name"] = s.serviceName tags["display_name"] = s.displayName acc1.AssertContainsTaggedFields(t, "win_services", fields, tags) } } func TestExcludingNamesTag(t *testing.T) { winServices := &WinServices{ Log: testutil.Logger{}, ServiceNamesExcluded: []string{"Service*"}, mgrProvider: &FakeMgProvider{testSimpleData[0]}, } err := winServices.Init() require.NoError(t, err) var acc1 testutil.Accumulator require.NoError(t, winServices.Gather(&acc1)) for _, s := range testSimpleData[0].services { fields := make(map[string]interface{}) tags := make(map[string]string) fields["state"] = s.state fields["startup_mode"] = s.startUpMode tags["service_name"] = s.serviceName tags["display_name"] = s.displayName acc1.AssertDoesNotContainsTaggedFields(t, "win_services", fields, tags) } }