Adding upstream version 0.8.9.
Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
parent
3b2c48b5e4
commit
c0c4addb85
285 changed files with 25880 additions and 0 deletions
219
pkg/generators/basic/basic.go
Normal file
219
pkg/generators/basic/basic.go
Normal file
|
@ -0,0 +1,219 @@
|
|||
package basic
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/fatih/color"
|
||||
|
||||
"github.com/nicholas-fedor/shoutrrr/pkg/format"
|
||||
"github.com/nicholas-fedor/shoutrrr/pkg/types"
|
||||
)
|
||||
|
||||
// Errors defined as static variables for better error handling.
|
||||
var (
|
||||
ErrInvalidConfigType = errors.New("config does not implement types.ServiceConfig")
|
||||
ErrInvalidConfigField = errors.New("config field is invalid or nil")
|
||||
ErrRequiredFieldMissing = errors.New("field is required and has no default value")
|
||||
)
|
||||
|
||||
// Generator is the Basic Generator implementation for creating service configurations.
|
||||
type Generator struct{}
|
||||
|
||||
// Generate creates a service configuration by prompting the user for field values or using provided properties.
|
||||
func (g *Generator) Generate(
|
||||
service types.Service,
|
||||
props map[string]string,
|
||||
_ []string,
|
||||
) (types.ServiceConfig, error) {
|
||||
configPtr := reflect.ValueOf(service).Elem().FieldByName("Config")
|
||||
if !configPtr.IsValid() || configPtr.IsNil() {
|
||||
return nil, ErrInvalidConfigField
|
||||
}
|
||||
|
||||
scanner := bufio.NewScanner(os.Stdin)
|
||||
if err := g.promptUserForFields(configPtr, props, scanner); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if config, ok := configPtr.Interface().(types.ServiceConfig); ok {
|
||||
return config, nil
|
||||
}
|
||||
|
||||
return nil, ErrInvalidConfigType
|
||||
}
|
||||
|
||||
// promptUserForFields iterates over config fields, prompting the user or using props to set values.
|
||||
func (g *Generator) promptUserForFields(
|
||||
configPtr reflect.Value,
|
||||
props map[string]string,
|
||||
scanner *bufio.Scanner,
|
||||
) error {
|
||||
serviceConfig, ok := configPtr.Interface().(types.ServiceConfig)
|
||||
if !ok {
|
||||
return ErrInvalidConfigType
|
||||
}
|
||||
|
||||
configNode := format.GetConfigFormat(serviceConfig)
|
||||
config := configPtr.Elem() // Dereference for setting fields
|
||||
|
||||
for _, item := range configNode.Items {
|
||||
field := item.Field()
|
||||
propKey := strings.ToLower(field.Name)
|
||||
|
||||
for {
|
||||
inputValue, err := g.getInputValue(field, propKey, props, scanner)
|
||||
if err != nil {
|
||||
return err // Propagate the error immediately
|
||||
}
|
||||
|
||||
if valid, err := g.setFieldValue(config, field, inputValue); valid {
|
||||
break
|
||||
} else if err != nil {
|
||||
g.printError(field.Name, err.Error())
|
||||
} else {
|
||||
g.printInvalidType(field.Name, field.Type.Kind().String())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// getInputValue retrieves the value for a field from props or user input.
|
||||
func (g *Generator) getInputValue(
|
||||
field *format.FieldInfo,
|
||||
propKey string,
|
||||
props map[string]string,
|
||||
scanner *bufio.Scanner,
|
||||
) (string, error) {
|
||||
if propValue, ok := props[propKey]; ok && len(propValue) > 0 {
|
||||
_, _ = fmt.Fprint(
|
||||
color.Output,
|
||||
"Using property ",
|
||||
color.HiCyanString(propValue),
|
||||
" for ",
|
||||
color.HiMagentaString(field.Name),
|
||||
" field\n",
|
||||
)
|
||||
props[propKey] = ""
|
||||
|
||||
return propValue, nil
|
||||
}
|
||||
|
||||
prompt := g.formatPrompt(field)
|
||||
_, _ = fmt.Fprint(color.Output, prompt)
|
||||
|
||||
if scanner.Scan() {
|
||||
input := scanner.Text()
|
||||
if len(input) == 0 {
|
||||
if len(field.DefaultValue) > 0 {
|
||||
return field.DefaultValue, nil
|
||||
}
|
||||
|
||||
if field.Required {
|
||||
return "", fmt.Errorf("%s: %w", field.Name, ErrRequiredFieldMissing)
|
||||
}
|
||||
|
||||
return "", nil
|
||||
}
|
||||
|
||||
// More specific type validation
|
||||
if field.Type != nil {
|
||||
kind := field.Type.Kind()
|
||||
if kind == reflect.Int || kind == reflect.Int8 || kind == reflect.Int16 ||
|
||||
kind == reflect.Int32 || kind == reflect.Int64 {
|
||||
if _, err := strconv.ParseInt(input, 10, field.Type.Bits()); err != nil {
|
||||
return "", fmt.Errorf("invalid integer value for %s: %w", field.Name, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return input, nil
|
||||
} else if scanErr := scanner.Err(); scanErr != nil {
|
||||
return "", fmt.Errorf("scanner error: %w", scanErr)
|
||||
}
|
||||
|
||||
return field.DefaultValue, nil
|
||||
}
|
||||
|
||||
// formatPrompt creates a user prompt based on the field’s name and default value.
|
||||
func (g *Generator) formatPrompt(field *format.FieldInfo) string {
|
||||
if len(field.DefaultValue) > 0 {
|
||||
return fmt.Sprintf("%s[%s]: ", color.HiWhiteString(field.Name), field.DefaultValue)
|
||||
}
|
||||
|
||||
return color.HiWhiteString(field.Name) + ": "
|
||||
}
|
||||
|
||||
// setFieldValue attempts to set a field’s value and handles required field validation.
|
||||
func (g *Generator) setFieldValue(
|
||||
config reflect.Value,
|
||||
field *format.FieldInfo,
|
||||
inputValue string,
|
||||
) (bool, error) {
|
||||
if len(inputValue) == 0 {
|
||||
if field.Required {
|
||||
_, _ = fmt.Fprint(
|
||||
color.Output,
|
||||
"Field ",
|
||||
color.HiCyanString(field.Name),
|
||||
" is required!\n\n",
|
||||
)
|
||||
|
||||
return false, nil
|
||||
}
|
||||
|
||||
if len(field.DefaultValue) == 0 {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
inputValue = field.DefaultValue
|
||||
}
|
||||
|
||||
valid, err := format.SetConfigField(config, *field, inputValue)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to set field %s: %w", field.Name, err)
|
||||
}
|
||||
|
||||
return valid, nil
|
||||
}
|
||||
|
||||
// printError displays an error message for an invalid field value.
|
||||
func (g *Generator) printError(fieldName, errorMsg string) {
|
||||
_, _ = fmt.Fprint(
|
||||
color.Output,
|
||||
"Invalid format for field ",
|
||||
color.HiCyanString(fieldName),
|
||||
": ",
|
||||
errorMsg,
|
||||
"\n\n",
|
||||
)
|
||||
}
|
||||
|
||||
// printInvalidType displays a type mismatch error for a field.
|
||||
func (g *Generator) printInvalidType(fieldName, typeName string) {
|
||||
_, _ = fmt.Fprint(
|
||||
color.Output,
|
||||
"Invalid type ",
|
||||
color.HiYellowString(typeName),
|
||||
" for field ",
|
||||
color.HiCyanString(fieldName),
|
||||
"\n\n",
|
||||
)
|
||||
}
|
||||
|
||||
// validateAndReturnConfig ensures the config implements ServiceConfig and returns it.
|
||||
func (g *Generator) validateAndReturnConfig(config reflect.Value) (types.ServiceConfig, error) {
|
||||
configInterface := config.Interface()
|
||||
if serviceConfig, ok := configInterface.(types.ServiceConfig); ok {
|
||||
return serviceConfig, nil
|
||||
}
|
||||
|
||||
return nil, ErrInvalidConfigType
|
||||
}
|
543
pkg/generators/basic/basic_test.go
Normal file
543
pkg/generators/basic/basic_test.go
Normal file
|
@ -0,0 +1,543 @@
|
|||
package basic
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"os"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
"text/template"
|
||||
|
||||
"github.com/fatih/color"
|
||||
|
||||
"github.com/nicholas-fedor/shoutrrr/pkg/format"
|
||||
"github.com/nicholas-fedor/shoutrrr/pkg/types"
|
||||
)
|
||||
|
||||
// mockConfig implements types.ServiceConfig.
|
||||
type mockConfig struct {
|
||||
Host string `default:"localhost" key:"host"`
|
||||
Port int `default:"8080" key:"port" required:"true"`
|
||||
url *url.URL
|
||||
}
|
||||
|
||||
func (m *mockConfig) Enums() map[string]types.EnumFormatter {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockConfig) GetURL() *url.URL {
|
||||
if m.url == nil {
|
||||
u, _ := url.Parse("mock://url")
|
||||
m.url = u
|
||||
}
|
||||
|
||||
return m.url
|
||||
}
|
||||
|
||||
func (m *mockConfig) SetURL(u *url.URL) error {
|
||||
m.url = u
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockConfig) SetTemplateFile(_ string, _ string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockConfig) SetTemplateString(_ string, _ string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockConfig) SetLogger(_ types.StdLogger) {
|
||||
// Minimal implementation, no-op
|
||||
}
|
||||
|
||||
// ConfigQueryResolver methods.
|
||||
func (m *mockConfig) Get(key string) (string, error) {
|
||||
switch strings.ToLower(key) {
|
||||
case "host":
|
||||
return m.Host, nil
|
||||
case "port":
|
||||
return strconv.Itoa(m.Port), nil
|
||||
default:
|
||||
return "", fmt.Errorf("unknown key: %s", key)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *mockConfig) Set(key string, value string) error {
|
||||
switch strings.ToLower(key) {
|
||||
case "host":
|
||||
m.Host = value
|
||||
|
||||
return nil
|
||||
case "port":
|
||||
port, err := strconv.Atoi(value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
m.Port = port
|
||||
|
||||
return nil
|
||||
default:
|
||||
return fmt.Errorf("unknown key: %s", key)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *mockConfig) QueryFields() []string {
|
||||
return []string{"host", "port"}
|
||||
}
|
||||
|
||||
// mockServiceConfig is a test implementation of Service.
|
||||
type mockServiceConfig struct {
|
||||
Config *mockConfig
|
||||
}
|
||||
|
||||
func (m *mockServiceConfig) GetID() string {
|
||||
return "mockID"
|
||||
}
|
||||
|
||||
func (m *mockServiceConfig) GetTemplate(_ string) (*template.Template, bool) {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
func (m *mockServiceConfig) SetTemplateFile(_ string, _ string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockServiceConfig) SetTemplateString(_ string, _ string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockServiceConfig) Initialize(_ *url.URL, _ types.StdLogger) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockServiceConfig) Send(_ string, _ *types.Params) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockServiceConfig) SetLogger(_ types.StdLogger) {}
|
||||
|
||||
// ConfigProp methods.
|
||||
func (m *mockConfig) SetFromProp(propValue string) error {
|
||||
// Minimal implementation for testing; typically parses propValue
|
||||
parts := strings.SplitN(propValue, ":", 2)
|
||||
if len(parts) == 2 {
|
||||
m.Host = parts[0]
|
||||
|
||||
port, err := strconv.Atoi(parts[1])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
m.Port = port
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockConfig) GetPropValue() (string, error) {
|
||||
// Minimal implementation for testing
|
||||
return fmt.Sprintf("%s:%d", m.Host, m.Port), nil
|
||||
}
|
||||
|
||||
// newMockServiceConfig creates a new mockServiceConfig with an initialized Config.
|
||||
func newMockServiceConfig() *mockServiceConfig {
|
||||
return &mockServiceConfig{
|
||||
Config: &mockConfig{},
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerator_Generate(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
props map[string]string
|
||||
input string
|
||||
want types.ServiceConfig
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "successful generation with defaults",
|
||||
props: map[string]string{},
|
||||
input: "\n8080\n",
|
||||
want: &mockConfig{
|
||||
Host: "localhost",
|
||||
Port: 8080,
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "successful generation with props",
|
||||
props: map[string]string{"host": "example.com", "port": "9090"},
|
||||
input: "",
|
||||
want: &mockConfig{
|
||||
Host: "example.com",
|
||||
Port: 9090,
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "error_on_invalid_port",
|
||||
props: map[string]string{},
|
||||
input: "\ninvalid\n",
|
||||
want: nil,
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
g := &Generator{}
|
||||
|
||||
// Set up pipe for stdin
|
||||
r, w, err := os.Pipe()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
originalStdin := os.Stdin
|
||||
os.Stdin = r
|
||||
|
||||
defer func() {
|
||||
os.Stdin = originalStdin
|
||||
|
||||
w.Close()
|
||||
}()
|
||||
|
||||
// Write input to the pipe
|
||||
_, err = w.WriteString(tt.input)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
w.Close()
|
||||
|
||||
service := newMockServiceConfig()
|
||||
color.NoColor = true
|
||||
|
||||
got, err := g.Generate(service, tt.props, nil)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("Generate() error = %v, wantErr %v", err, tt.wantErr)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if !tt.wantErr && !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("Generate() = %+v, want %+v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerator_promptUserForFields(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
config reflect.Value
|
||||
props map[string]string
|
||||
input string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "valid input with defaults",
|
||||
config: reflect.ValueOf(newMockServiceConfig().Config), // Pass *mockConfig
|
||||
props: map[string]string{},
|
||||
input: "\n8080\n",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "valid props",
|
||||
config: reflect.ValueOf(newMockServiceConfig().Config), // Pass *mockConfig
|
||||
props: map[string]string{"host": "test.com", "port": "1234"},
|
||||
input: "",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "invalid config type",
|
||||
config: reflect.ValueOf("not a config"),
|
||||
props: map[string]string{},
|
||||
input: "",
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
g := &Generator{}
|
||||
scanner := bufio.NewScanner(strings.NewReader(tt.input))
|
||||
color.NoColor = true
|
||||
|
||||
err := g.promptUserForFields(tt.config, tt.props, scanner)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("promptUserForFields() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
|
||||
if err == nil && tt.config.Kind() == reflect.Ptr &&
|
||||
tt.config.Type().Elem().Kind() == reflect.Struct {
|
||||
got := tt.config.Interface().(*mockConfig)
|
||||
if tt.props["host"] != "" && got.Host != tt.props["host"] {
|
||||
t.Errorf("promptUserForFields() host = %v, want %v", got.Host, tt.props["host"])
|
||||
}
|
||||
|
||||
if tt.props["port"] != "" {
|
||||
wantPort := atoiOrZero(tt.props["port"])
|
||||
if got.Port != wantPort {
|
||||
t.Errorf("promptUserForFields() port = %v, want %v", got.Port, wantPort)
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerator_getInputValue(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
field *format.FieldInfo
|
||||
propKey string
|
||||
props map[string]string
|
||||
input string
|
||||
want string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "from props",
|
||||
field: &format.FieldInfo{Name: "Host"},
|
||||
propKey: "host",
|
||||
props: map[string]string{"host": "example.com"},
|
||||
input: "",
|
||||
want: "example.com",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "from user input",
|
||||
field: &format.FieldInfo{Name: "Port", Type: reflect.TypeOf(0)}, // Add Type
|
||||
propKey: "port",
|
||||
props: map[string]string{},
|
||||
input: "8080\n",
|
||||
want: "8080",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "default value",
|
||||
field: &format.FieldInfo{Name: "Host", DefaultValue: "localhost"},
|
||||
propKey: "host",
|
||||
props: map[string]string{},
|
||||
input: "\n",
|
||||
want: "localhost",
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
g := &Generator{}
|
||||
scanner := bufio.NewScanner(strings.NewReader(tt.input))
|
||||
color.NoColor = true
|
||||
|
||||
got, err := g.getInputValue(tt.field, tt.propKey, tt.props, scanner)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("getInputValue() error = %v, wantErr %v", err, tt.wantErr)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if got != tt.want {
|
||||
t.Errorf("getInputValue() = %v, want %v", got, tt.want)
|
||||
}
|
||||
|
||||
if tt.props[tt.propKey] != "" {
|
||||
t.Errorf("getInputValue() did not clear prop, got %v", tt.props[tt.propKey])
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerator_formatPrompt(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
field *format.FieldInfo
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "field with default",
|
||||
field: &format.FieldInfo{Name: "Host", DefaultValue: "localhost"},
|
||||
want: "\x1b[97mHost\x1b[0m[localhost]: ",
|
||||
},
|
||||
{
|
||||
name: "field without default",
|
||||
field: &format.FieldInfo{Name: "Port"},
|
||||
want: "\x1b[97mPort\x1b[0m: ",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
g := &Generator{}
|
||||
color.NoColor = false
|
||||
|
||||
got := g.formatPrompt(tt.field)
|
||||
if got != tt.want {
|
||||
t.Errorf("formatPrompt() = %q, want %q", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerator_setFieldValue(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
config reflect.Value
|
||||
field *format.FieldInfo
|
||||
inputValue string
|
||||
want bool
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "valid value",
|
||||
config: reflect.ValueOf(newMockServiceConfig().Config).Elem(),
|
||||
field: &format.FieldInfo{Name: "Port", Type: reflect.TypeOf(0), Required: true},
|
||||
inputValue: "8080",
|
||||
want: true,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "required field empty",
|
||||
config: reflect.ValueOf(newMockServiceConfig().Config).Elem(),
|
||||
field: &format.FieldInfo{Name: "Port", Type: reflect.TypeOf(0), Required: true},
|
||||
inputValue: "",
|
||||
want: false,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "invalid value",
|
||||
config: reflect.ValueOf(newMockServiceConfig().Config).Elem(),
|
||||
field: &format.FieldInfo{Name: "Port", Type: reflect.TypeOf(0)},
|
||||
inputValue: "invalid",
|
||||
want: false,
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
g := &Generator{}
|
||||
color.NoColor = true
|
||||
|
||||
got, err := g.setFieldValue(tt.config, tt.field, tt.inputValue)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("setFieldValue() error = %v, wantErr %v", err, tt.wantErr)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if got != tt.want {
|
||||
t.Errorf("setFieldValue() = %v, want %v", got, tt.want)
|
||||
}
|
||||
|
||||
if got && !tt.wantErr {
|
||||
if tt.field.Name == "Port" {
|
||||
wantPort := atoiOrZero(tt.inputValue)
|
||||
if gotPort := tt.config.FieldByName("Port").Int(); int(gotPort) != wantPort {
|
||||
t.Errorf("setFieldValue() set Port = %v, want %v", gotPort, wantPort)
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerator_printError(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
fieldName string
|
||||
errorMsg string
|
||||
}{
|
||||
{
|
||||
name: "basic error",
|
||||
fieldName: "Port",
|
||||
errorMsg: "invalid format",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(*testing.T) {
|
||||
g := &Generator{}
|
||||
color.NoColor = true
|
||||
|
||||
g.printError(tt.fieldName, tt.errorMsg)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerator_printInvalidType(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
fieldName string
|
||||
typeName string
|
||||
}{
|
||||
{
|
||||
name: "invalid type",
|
||||
fieldName: "Port",
|
||||
typeName: "int",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(*testing.T) {
|
||||
g := &Generator{}
|
||||
color.NoColor = true
|
||||
|
||||
g.printInvalidType(tt.fieldName, tt.typeName)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerator_validateAndReturnConfig(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
config reflect.Value
|
||||
want types.ServiceConfig
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "valid config",
|
||||
config: reflect.ValueOf(&mockConfig{Host: "test", Port: 1234}),
|
||||
want: &mockConfig{Host: "test", Port: 1234},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "invalid config type",
|
||||
config: reflect.ValueOf("not a config"),
|
||||
want: nil,
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
g := &Generator{}
|
||||
|
||||
got, err := g.validateAndReturnConfig(tt.config)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("validateAndReturnConfig() error = %v, wantErr %v", err, tt.wantErr)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("validateAndReturnConfig() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// atoiOrZero converts a string to an int, returning 0 on error.
|
||||
func atoiOrZero(s string) int {
|
||||
i, _ := strconv.Atoi(s)
|
||||
|
||||
return i
|
||||
}
|
44
pkg/generators/router.go
Normal file
44
pkg/generators/router.go
Normal file
|
@ -0,0 +1,44 @@
|
|||
package generators
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/nicholas-fedor/shoutrrr/pkg/generators/basic"
|
||||
"github.com/nicholas-fedor/shoutrrr/pkg/generators/xouath2"
|
||||
"github.com/nicholas-fedor/shoutrrr/pkg/services/telegram"
|
||||
"github.com/nicholas-fedor/shoutrrr/pkg/types"
|
||||
)
|
||||
|
||||
var ErrUnknownGenerator = errors.New("unknown generator")
|
||||
|
||||
var generatorMap = map[string]func() types.Generator{
|
||||
"basic": func() types.Generator { return &basic.Generator{} },
|
||||
"oauth2": func() types.Generator { return &xouath2.Generator{} },
|
||||
"telegram": func() types.Generator { return &telegram.Generator{} },
|
||||
}
|
||||
|
||||
// NewGenerator creates an instance of the generator that corresponds to the provided identifier.
|
||||
func NewGenerator(identifier string) (types.Generator, error) {
|
||||
generatorFactory, valid := generatorMap[strings.ToLower(identifier)]
|
||||
if !valid {
|
||||
return nil, fmt.Errorf("%w: %q", ErrUnknownGenerator, identifier)
|
||||
}
|
||||
|
||||
return generatorFactory(), nil
|
||||
}
|
||||
|
||||
// ListGenerators lists all available generators.
|
||||
func ListGenerators() []string {
|
||||
generators := make([]string, len(generatorMap))
|
||||
|
||||
i := 0
|
||||
|
||||
for key := range generatorMap {
|
||||
generators[i] = key
|
||||
i++
|
||||
}
|
||||
|
||||
return generators
|
||||
}
|
266
pkg/generators/xouath2/xoauth2.go
Normal file
266
pkg/generators/xouath2/xoauth2.go
Normal file
|
@ -0,0 +1,266 @@
|
|||
//go:generate stringer -type=URLPart -trimprefix URL
|
||||
|
||||
package xouath2
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"golang.org/x/net/context"
|
||||
"golang.org/x/oauth2"
|
||||
"golang.org/x/oauth2/google"
|
||||
|
||||
"github.com/nicholas-fedor/shoutrrr/pkg/services/smtp"
|
||||
"github.com/nicholas-fedor/shoutrrr/pkg/types"
|
||||
)
|
||||
|
||||
// SMTP port constants.
|
||||
const (
|
||||
DefaultSMTPPort uint16 = 25 // Standard SMTP port without encryption
|
||||
GmailSMTPPortStartTLS uint16 = 587 // Gmail SMTP port with STARTTLS
|
||||
)
|
||||
|
||||
const StateLength int = 16 // Length in bytes for OAuth 2.0 state randomness (128 bits)
|
||||
|
||||
// Errors.
|
||||
var (
|
||||
ErrReadFileFailed = errors.New("failed to read file")
|
||||
ErrUnmarshalFailed = errors.New("failed to unmarshal JSON")
|
||||
ErrScanFailed = errors.New("failed to scan input")
|
||||
ErrTokenExchangeFailed = errors.New("failed to exchange token")
|
||||
)
|
||||
|
||||
// Generator is the XOAuth2 Generator implementation.
|
||||
type Generator struct{}
|
||||
|
||||
// Generate generates a service URL from a set of user questions/answers.
|
||||
func (g *Generator) Generate(
|
||||
_ types.Service,
|
||||
props map[string]string,
|
||||
args []string,
|
||||
) (types.ServiceConfig, error) {
|
||||
if provider, found := props["provider"]; found {
|
||||
if provider == "gmail" {
|
||||
return oauth2GeneratorGmail(args[0])
|
||||
}
|
||||
}
|
||||
|
||||
if len(args) > 0 {
|
||||
return oauth2GeneratorFile(args[0])
|
||||
}
|
||||
|
||||
return oauth2Generator()
|
||||
}
|
||||
|
||||
func oauth2GeneratorFile(file string) (*smtp.Config, error) {
|
||||
jsonData, err := os.ReadFile(file)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%s: %w", file, ErrReadFileFailed)
|
||||
}
|
||||
|
||||
var providerConfig struct {
|
||||
ClientID string `json:"client_id"`
|
||||
ClientSecret string `json:"client_secret"`
|
||||
RedirectURL string `json:"redirect_url"`
|
||||
AuthURL string `json:"auth_url"`
|
||||
TokenURL string `json:"token_url"`
|
||||
Hostname string `json:"smtp_hostname"`
|
||||
Scopes []string `json:"scopes"`
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(jsonData, &providerConfig); err != nil {
|
||||
return nil, fmt.Errorf("%s: %w", file, ErrUnmarshalFailed)
|
||||
}
|
||||
|
||||
conf := oauth2.Config{
|
||||
ClientID: providerConfig.ClientID,
|
||||
ClientSecret: providerConfig.ClientSecret,
|
||||
Endpoint: oauth2.Endpoint{
|
||||
AuthURL: providerConfig.AuthURL,
|
||||
TokenURL: providerConfig.TokenURL,
|
||||
AuthStyle: oauth2.AuthStyleAutoDetect,
|
||||
},
|
||||
RedirectURL: providerConfig.RedirectURL,
|
||||
Scopes: providerConfig.Scopes,
|
||||
}
|
||||
|
||||
return generateOauth2Config(&conf, providerConfig.Hostname)
|
||||
}
|
||||
|
||||
func oauth2Generator() (*smtp.Config, error) {
|
||||
scanner := bufio.NewScanner(os.Stdin)
|
||||
|
||||
var clientID string
|
||||
|
||||
fmt.Fprint(os.Stdout, "ClientID: ")
|
||||
|
||||
if scanner.Scan() {
|
||||
clientID = scanner.Text()
|
||||
} else {
|
||||
return nil, fmt.Errorf("clientID: %w", ErrScanFailed)
|
||||
}
|
||||
|
||||
var clientSecret string
|
||||
|
||||
fmt.Fprint(os.Stdout, "ClientSecret: ")
|
||||
|
||||
if scanner.Scan() {
|
||||
clientSecret = scanner.Text()
|
||||
} else {
|
||||
return nil, fmt.Errorf("clientSecret: %w", ErrScanFailed)
|
||||
}
|
||||
|
||||
var authURL string
|
||||
|
||||
fmt.Fprint(os.Stdout, "AuthURL: ")
|
||||
|
||||
if scanner.Scan() {
|
||||
authURL = scanner.Text()
|
||||
} else {
|
||||
return nil, fmt.Errorf("authURL: %w", ErrScanFailed)
|
||||
}
|
||||
|
||||
var tokenURL string
|
||||
|
||||
fmt.Fprint(os.Stdout, "TokenURL: ")
|
||||
|
||||
if scanner.Scan() {
|
||||
tokenURL = scanner.Text()
|
||||
} else {
|
||||
return nil, fmt.Errorf("tokenURL: %w", ErrScanFailed)
|
||||
}
|
||||
|
||||
var redirectURL string
|
||||
|
||||
fmt.Fprint(os.Stdout, "RedirectURL: ")
|
||||
|
||||
if scanner.Scan() {
|
||||
redirectURL = scanner.Text()
|
||||
} else {
|
||||
return nil, fmt.Errorf("redirectURL: %w", ErrScanFailed)
|
||||
}
|
||||
|
||||
var scopes string
|
||||
|
||||
fmt.Fprint(os.Stdout, "Scopes: ")
|
||||
|
||||
if scanner.Scan() {
|
||||
scopes = scanner.Text()
|
||||
} else {
|
||||
return nil, fmt.Errorf("scopes: %w", ErrScanFailed)
|
||||
}
|
||||
|
||||
var hostname string
|
||||
|
||||
fmt.Fprint(os.Stdout, "SMTP Hostname: ")
|
||||
|
||||
if scanner.Scan() {
|
||||
hostname = scanner.Text()
|
||||
} else {
|
||||
return nil, fmt.Errorf("hostname: %w", ErrScanFailed)
|
||||
}
|
||||
|
||||
conf := oauth2.Config{
|
||||
ClientID: clientID,
|
||||
ClientSecret: clientSecret,
|
||||
Endpoint: oauth2.Endpoint{
|
||||
AuthURL: authURL,
|
||||
TokenURL: tokenURL,
|
||||
AuthStyle: oauth2.AuthStyleAutoDetect,
|
||||
},
|
||||
RedirectURL: redirectURL,
|
||||
Scopes: strings.Split(scopes, ","),
|
||||
}
|
||||
|
||||
return generateOauth2Config(&conf, hostname)
|
||||
}
|
||||
|
||||
func oauth2GeneratorGmail(credFile string) (*smtp.Config, error) {
|
||||
data, err := os.ReadFile(credFile)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%s: %w", credFile, ErrReadFileFailed)
|
||||
}
|
||||
|
||||
conf, err := google.ConfigFromJSON(data, "https://mail.google.com/")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf(
|
||||
"%s: %w",
|
||||
credFile,
|
||||
err,
|
||||
) // google.ConfigFromJSON error doesn't need custom wrapping
|
||||
}
|
||||
|
||||
return generateOauth2Config(conf, "smtp.gmail.com")
|
||||
}
|
||||
|
||||
func generateOauth2Config(conf *oauth2.Config, host string) (*smtp.Config, error) {
|
||||
scanner := bufio.NewScanner(os.Stdin)
|
||||
|
||||
// Generate a random state value
|
||||
stateBytes := make([]byte, StateLength)
|
||||
if _, err := rand.Read(stateBytes); err != nil {
|
||||
return nil, fmt.Errorf("generating random state: %w", err)
|
||||
}
|
||||
|
||||
state := base64.URLEncoding.EncodeToString(stateBytes)
|
||||
|
||||
fmt.Fprintf(
|
||||
os.Stdout,
|
||||
"Visit the following URL to authenticate:\n%s\n\n",
|
||||
conf.AuthCodeURL(state),
|
||||
)
|
||||
|
||||
var verCode string
|
||||
|
||||
fmt.Fprint(os.Stdout, "Enter verification code: ")
|
||||
|
||||
if scanner.Scan() {
|
||||
verCode = scanner.Text()
|
||||
} else {
|
||||
return nil, fmt.Errorf("verification code: %w", ErrScanFailed)
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
token, err := conf.Exchange(ctx, verCode)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%s: %w", verCode, ErrTokenExchangeFailed)
|
||||
}
|
||||
|
||||
var sender string
|
||||
|
||||
fmt.Fprint(os.Stdout, "Enter sender e-mail: ")
|
||||
|
||||
if scanner.Scan() {
|
||||
sender = scanner.Text()
|
||||
} else {
|
||||
return nil, fmt.Errorf("sender email: %w", ErrScanFailed)
|
||||
}
|
||||
|
||||
// Determine the appropriate port based on the host
|
||||
port := DefaultSMTPPort
|
||||
if host == "smtp.gmail.com" {
|
||||
port = GmailSMTPPortStartTLS // Use 587 for Gmail with STARTTLS
|
||||
}
|
||||
|
||||
svcConf := &smtp.Config{
|
||||
Host: host,
|
||||
Port: port,
|
||||
Username: sender,
|
||||
Password: token.AccessToken,
|
||||
FromAddress: sender,
|
||||
FromName: "Shoutrrr",
|
||||
ToAddresses: []string{sender},
|
||||
Auth: smtp.AuthTypes.OAuth2,
|
||||
UseStartTLS: true,
|
||||
UseHTML: true,
|
||||
}
|
||||
|
||||
return svcConf, nil
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue