diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 0000000..44db581 --- /dev/null +++ b/.gitattributes @@ -0,0 +1 @@ +testdata/dos-lines eol=crlf diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..e402af1 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,42 @@ +on: [push, pull_request] +name: Test +jobs: + lint: + runs-on: ubuntu-latest + steps: + - name: Install Go + uses: WillAbides/setup-go-faster@main + with: + go-version: 1.21.x + - uses: actions/checkout@v4 + with: + path: './src/github.com/iamFrancescoFerro/ssh_config' + # staticcheck needs this for GOPATH + - run: | + echo "GO111MODULE=off" >> $GITHUB_ENV + echo "GOPATH=$GITHUB_WORKSPACE" >> $GITHUB_ENV + echo "PATH=$GITHUB_WORKSPACE/bin:$PATH" >> $GITHUB_ENV + - name: Run tests + run: make lint + working-directory: './src/github.com/iamFrancescoFerro/ssh_config' + + test: + strategy: + matrix: + go-version: [1.17.x, 1.18.x, 1.19.x, 1.20.x, 1.21.x] + runs-on: ubuntu-latest + steps: + - name: Install Go + uses: WillAbides/setup-go-faster@main + with: + go-version: ${{ matrix.go-version }} + - uses: actions/checkout@v4 + with: + path: './src/github.com/iamFrancescoFerro/ssh_config' + - run: | + echo "GO111MODULE=off" >> $GITHUB_ENV + echo "GOPATH=$GITHUB_WORKSPACE" >> $GITHUB_ENV + echo "PATH=$GITHUB_WORKSPACE/bin:$PATH" >> $GITHUB_ENV + - name: Run tests with race detector on + run: make race-test + working-directory: './src/github.com/iamFrancescoFerro/ssh_config' diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..e69de29 diff --git a/.mailmap b/.mailmap new file mode 100644 index 0000000..253406b --- /dev/null +++ b/.mailmap @@ -0,0 +1 @@ +Kevin Burke Kevin Burke diff --git a/AUTHORS.txt b/AUTHORS.txt new file mode 100644 index 0000000..311aeb1 --- /dev/null +++ b/AUTHORS.txt @@ -0,0 +1,9 @@ +Carlos A Becker +Dustin Spicuzza +Eugene Terentev +Kevin Burke +Mark Nevill +Scott Lessans +Sergey Lukjanov +Wayne Ashley Berry +santosh653 <70637961+santosh653@users.noreply.github.com> diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 0000000..d32a3f5 --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,19 @@ +# Changes + +## Version 1.2 + +Previously, if a Host declaration or a value had trailing whitespace, that +whitespace would have been included as part of the value. This led to unexpected +consequences. For example: + +``` +Host example # A comment + HostName example.com # Another comment +``` + +Prior to version 1.2, the value for Host would have been "example " and the +value for HostName would have been "example.com ". Both of these are +unintuitive. + +Instead, we strip the trailing whitespace in the configuration, which leads to +more intuitive behavior. diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..b9a770a --- /dev/null +++ b/LICENSE @@ -0,0 +1,49 @@ +Copyright (c) 2017 Kevin Burke. + +Permission is hereby granted, free of charge, to any person +obtaining a copy of this software and associated documentation +files (the "Software"), to deal in the Software without +restriction, including without limitation the rights to use, +copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following +conditions: + +The above copyright notice and this permission notice shall be +included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES +OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT +HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, +WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR +OTHER DEALINGS IN THE SOFTWARE. + +=================== + +The lexer and parser borrow heavily from github.com/pelletier/go-toml. The +license for that project is copied below. + +The MIT License (MIT) + +Copyright (c) 2013 - 2017 Thomas Pelletier, Eric Anderton + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..d30c5e5 --- /dev/null +++ b/Makefile @@ -0,0 +1,33 @@ +BUMP_VERSION := $(GOPATH)/bin/bump_version +STATICCHECK := $(GOPATH)/bin/staticcheck +WRITE_MAILMAP := $(GOPATH)/bin/write_mailmap + +$(STATICCHECK): + go get honnef.co/go/tools/cmd/staticcheck + +lint: $(STATICCHECK) + go vet ./... + $(STATICCHECK) + +test: + @# the timeout helps guard against infinite recursion + go test -timeout=250ms ./... + +race-test: + go test -timeout=500ms -race ./... + +$(BUMP_VERSION): + go get -u github.com/dimonomid/bump_version + +$(WRITE_MAILMAP): + go get -u github.com/dimonomid/write_mailmap + +release: test | $(BUMP_VERSION) + $(BUMP_VERSION) --tag-prefix=v minor config.go + +force: ; + +AUTHORS.txt: force | $(WRITE_MAILMAP) + $(WRITE_MAILMAP) > AUTHORS.txt + +authors: AUTHORS.txt diff --git a/README.md b/README.md new file mode 100644 index 0000000..b3199c7 --- /dev/null +++ b/README.md @@ -0,0 +1,92 @@ +# ssh_config + +This is a Go parser for `ssh_config` files. Importantly, this parser attempts +to preserve comments in a given file, so you can manipulate a `ssh_config` file +from a program, if your heart desires. + +It's designed to be used with the excellent +[x/crypto/ssh](https://golang.org/x/crypto/ssh) package, which handles SSH +negotiation but isn't very easy to configure. + +The `ssh_config` `Get()` and `GetStrict()` functions will attempt to read values +from `$HOME/.ssh/config` and fall back to `/etc/ssh/ssh_config`. The first +argument is the host name to match on, and the second argument is the key you +want to retrieve. + +```go +port := ssh_config.Get("myhost", "Port") +``` + +Certain directives can occur multiple times for a host (such as `IdentityFile`), +so you should use the `GetAll` or `GetAllStrict` directive to retrieve those +instead. + +```go +files := ssh_config.GetAll("myhost", "IdentityFile") +``` + +You can also load a config file and read values from it. + +```go +var config = ` +Host *.test + Compression yes +` + +cfg, err := ssh_config.Decode(strings.NewReader(config)) +fmt.Println(cfg.Get("example.test", "Port")) +``` + +Some SSH arguments have default values - for example, the default value for +`KeyboardAuthentication` is `"yes"`. If you call Get(), and no value for the +given Host/keyword pair exists in the config, we'll return a default for the +keyword if one exists. + +### Manipulating SSH config files + +Here's how you can manipulate an SSH config file, and then write it back to +disk. + +```go +f, _ := os.Open(filepath.Join(os.Getenv("HOME"), ".ssh", "config")) +cfg, _ := ssh_config.Decode(f) +for _, host := range cfg.Hosts { + fmt.Println("patterns:", host.Patterns) + for _, node := range host.Nodes { + // Manipulate the nodes as you see fit, or use a type switch to + // distinguish between Empty, KV, and Include nodes. + fmt.Println(node.String()) + } +} + +// Print the config to stdout: +fmt.Println(cfg.String()) +``` + +## Spec compliance + +Wherever possible we try to implement the specification as documented in +the `ssh_config` manpage. Unimplemented features should be present in the +[issues][issues] list. + +Notably, the `Match` directive is currently unsupported. + +[issues]: https://github.com/iamFrancescoFerro/ssh_config/issues + +## Errata + +This is the second [comment-preserving configuration parser][blog] I've written, after +[an /etc/hosts parser][hostsfile]. Eventually, I will write one for every Linux +file format. + +[blog]: https://kev.inburke.com/kevin/more-comment-preserving-configuration-parsers/ +[hostsfile]: https://github.com/iamFrancescoFerro/hostsfile + +## Sponsorships + +Thank you very much to Tailscale and Indeed for sponsoring development of this +library. [Sponsors][sponsors] will get their names featured in the README. + +You can also reach out about a consulting engagement: https://burke.services + +[sponsors]: https://github.com/sponsors/kevinburke diff --git a/config.go b/config.go new file mode 100644 index 0000000..b112118 --- /dev/null +++ b/config.go @@ -0,0 +1,846 @@ +// Package ssh_config provides tools for manipulating SSH config files. +// +// Importantly, this parser attempts to preserve comments in a given file, so +// you can manipulate a `ssh_config` file from a program, if your heart desires. +// +// The Get() and GetStrict() functions will attempt to read values from +// $HOME/.ssh/config, falling back to /etc/ssh/ssh_config. The first argument is +// the host name to match on ("example.com"), and the second argument is the key +// you want to retrieve ("Port"). The keywords are case insensitive. +// +// port := ssh_config.Get("myhost", "Port") +// +// You can also manipulate an SSH config file and then print it or write it back +// to disk. +// +// f, _ := os.Open(filepath.Join(os.Getenv("HOME"), ".ssh", "config")) +// cfg, _ := ssh_config.Decode(f) +// for _, host := range cfg.Hosts { +// fmt.Println("patterns:", host.Patterns) +// for _, node := range host.Nodes { +// fmt.Println(node.String()) +// } +// } +// +// // Write the cfg back to disk: +// fmt.Println(cfg.String()) +// +// BUG: the Match directive is currently unsupported; parsing a config with +// a Match directive will trigger an error. +package ssh_config + +import ( + "bytes" + "errors" + "fmt" + "io" + "os" + osuser "os/user" + "path/filepath" + "regexp" + "runtime" + "strings" + "sync" +) + +const version = "1.2" + +var _ = version + +type configFinder func() string + +// UserSettings checks ~/.ssh and /etc/ssh for configuration files. The config +// files are parsed and cached the first time Get() or GetStrict() is called. +type UserSettings struct { + IgnoreErrors bool + IgnoreMatchDirective bool + customConfig *Config + customConfigFinder configFinder + systemConfig *Config + systemConfigFinder configFinder + userConfig *Config + userConfigFinder configFinder + loadConfigs sync.Once + onceErr error +} + +func homedir() string { + user, err := osuser.Current() + if err == nil { + return user.HomeDir + } else { + return os.Getenv("HOME") + } +} + +func userConfigFinder() string { + return filepath.Join(homedir(), ".ssh", "config") +} + +// DefaultUserSettings is the default UserSettings and is used by Get and +// GetStrict. It checks both $HOME/.ssh/config and /etc/ssh/ssh_config for keys, +// and it will return parse errors (if any) instead of swallowing them. +var DefaultUserSettings = &UserSettings{ + IgnoreErrors: false, + IgnoreMatchDirective: false, + systemConfigFinder: systemConfigFinder, + userConfigFinder: userConfigFinder, +} + +func systemConfigFinder() string { + return filepath.Join("/", "etc", "ssh", "ssh_config") +} + +func findVal(c *Config, alias, key string) (string, error) { + if c == nil { + return "", nil + } + val, err := c.Get(alias, key) + if err != nil || val == "" { + return "", err + } + if err := validate(key, val); err != nil { + return "", err + } + return val, nil +} + +func findAll(c *Config, alias, key string) ([]string, error) { + if c == nil { + return nil, nil + } + return c.GetAll(alias, key) +} + +// Get finds the first value for key within a declaration that matches the +// alias. Get returns the empty string if no value was found, or if IgnoreErrors +// is false and we could not parse the configuration file. Use GetStrict to +// disambiguate the latter cases. +// +// The match for key is case insensitive. +// +// Get is a wrapper around DefaultUserSettings.Get. +func Get(alias, key string) string { + return DefaultUserSettings.Get(alias, key) +} + +// GetAll retrieves zero or more directives for key for the given alias. GetAll +// returns nil if no value was found, or if IgnoreErrors is false and we could +// not parse the configuration file. Use GetAllStrict to disambiguate the +// latter cases. +// +// In most cases you want to use Get or GetStrict, which returns a single value. +// However, a subset of ssh configuration values (IdentityFile, for example) +// allow you to specify multiple directives. +// +// The match for key is case insensitive. +// +// GetAll is a wrapper around DefaultUserSettings.GetAll. +func GetAll(alias, key string) []string { + return DefaultUserSettings.GetAll(alias, key) +} + +// GetStrict finds the first value for key within a declaration that matches the +// alias. If key has a default value and no matching configuration is found, the +// default will be returned. For more information on default values and the way +// patterns are matched, see the manpage for ssh_config. +// +// The returned error will be non-nil if and only if a user's configuration file +// or the system configuration file could not be parsed, and u.IgnoreErrors is +// false. +// +// GetStrict is a wrapper around DefaultUserSettings.GetStrict. +func GetStrict(alias, key string) (string, error) { + return DefaultUserSettings.GetStrict(alias, key) +} + +// GetAllStrict retrieves zero or more directives for key for the given alias. +// +// In most cases you want to use Get or GetStrict, which returns a single value. +// However, a subset of ssh configuration values (IdentityFile, for example) +// allow you to specify multiple directives. +// +// The returned error will be non-nil if and only if a user's configuration file +// or the system configuration file could not be parsed, and u.IgnoreErrors is +// false. +// +// GetAllStrict is a wrapper around DefaultUserSettings.GetAllStrict. +func GetAllStrict(alias, key string) ([]string, error) { + return DefaultUserSettings.GetAllStrict(alias, key) +} + +// Get finds the first value for key within a declaration that matches the +// alias. Get returns the empty string if no value was found, or if IgnoreErrors +// is false and we could not parse the configuration file. Use GetStrict to +// disambiguate the latter cases. +// +// The match for key is case insensitive. +func (u *UserSettings) Get(alias, key string) string { + val, err := u.GetStrict(alias, key) + if err != nil { + return "" + } + return val +} + +// GetAll retrieves zero or more directives for key for the given alias. GetAll +// returns nil if no value was found, or if IgnoreErrors is false and we could +// not parse the configuration file. Use GetStrict to disambiguate the latter +// cases. +// +// The match for key is case insensitive. +func (u *UserSettings) GetAll(alias, key string) []string { + val, _ := u.GetAllStrict(alias, key) + return val +} + +// GetStrict finds the first value for key within a declaration that matches the +// alias. If key has a default value and no matching configuration is found, the +// default will be returned. For more information on default values and the way +// patterns are matched, see the manpage for ssh_config. +// +// error will be non-nil if and only if a user's configuration file or the +// system configuration file could not be parsed, and u.IgnoreErrors is false. +func (u *UserSettings) GetStrict(alias, key string) (string, error) { + u.doLoadConfigs() + //lint:ignore S1002 I prefer it this way + if u.onceErr != nil && u.IgnoreErrors == false { + return "", u.onceErr + } + // TODO this is getting repetitive + if u.customConfig != nil { + val, err := findVal(u.customConfig, alias, key) + if err != nil || val != "" { + return val, err + } + } + val, err := findVal(u.userConfig, alias, key) + if err != nil || val != "" { + return val, err + } + val2, err2 := findVal(u.systemConfig, alias, key) + if err2 != nil || val2 != "" { + return val2, err2 + } + return Default(key), nil +} + +// GetAllStrict retrieves zero or more directives for key for the given alias. +// If key has a default value and no matching configuration is found, the +// default will be returned. For more information on default values and the way +// patterns are matched, see the manpage for ssh_config. +// +// The returned error will be non-nil if and only if a user's configuration file +// or the system configuration file could not be parsed, and u.IgnoreErrors is +// false. +func (u *UserSettings) GetAllStrict(alias, key string) ([]string, error) { + u.doLoadConfigs() + //lint:ignore S1002 I prefer it this way + if u.onceErr != nil && u.IgnoreErrors == false { + return nil, u.onceErr + } + if u.customConfig != nil { + val, err := findAll(u.customConfig, alias, key) + if err != nil || val != nil { + return val, err + } + } + val, err := findAll(u.userConfig, alias, key) + if err != nil || val != nil { + return val, err + } + val2, err2 := findAll(u.systemConfig, alias, key) + if err2 != nil || val2 != nil { + return val2, err2 + } + // TODO: IdentityFile has multiple default values that we should return. + if def := Default(key); def != "" { + return []string{def}, nil + } + return []string{}, nil +} + +// ConfigFinder will invoke f to try to find a ssh config file in a custom +// location on disk, instead of in /etc/ssh or $HOME/.ssh. f should return the +// name of a file containing SSH configuration. +// +// ConfigFinder must be invoked before any calls to Get or GetStrict and panics +// if f is nil. Most users should not need to use this function. +func (u *UserSettings) ConfigFinder(f func() string) { + if f == nil { + panic("cannot call ConfigFinder with nil function") + } + u.customConfigFinder = f +} + +func (u *UserSettings) doLoadConfigs() { + u.loadConfigs.Do(func() { + var filename string + var err error + if u.customConfigFinder != nil { + filename = u.customConfigFinder() + u.customConfig, err = parseFile(filename, u.IgnoreMatchDirective) + // IsNotExist should be returned because a user specified this + // function - not existing likely means they made an error + // We should also respect the ignore flag + if err != nil && !u.IgnoreErrors { + u.onceErr = err + } + return + } + if u.userConfigFinder == nil { + filename = userConfigFinder() + } else { + filename = u.userConfigFinder() + } + u.userConfig, err = parseFile(filename, u.IgnoreMatchDirective) + //lint:ignore S1002 I prefer it this way + if err != nil && os.IsNotExist(err) == false { + u.onceErr = err + return + } + if u.systemConfigFinder == nil { + filename = systemConfigFinder() + } else { + filename = u.systemConfigFinder() + } + u.systemConfig, err = parseFile(filename, u.IgnoreMatchDirective) + //lint:ignore S1002 I prefer it this way + if err != nil && os.IsNotExist(err) == false { + u.onceErr = err + return + } + }, + ) +} + +func parseFile(filename string, ignoreMatchDirective bool) (*Config, error) { + return parseWithDepth(filename, ignoreMatchDirective, 0) +} + +func parseWithDepth(filename string, ignoreMatchDirective bool, depth uint8) (*Config, error) { + b, err := os.ReadFile(filename) + if err != nil { + return nil, err + } + return decodeBytes(b, isSystem(filename), ignoreMatchDirective, depth) +} + +func isSystem(filename string) bool { + // TODO: not sure this is the best way to detect a system repo + return strings.HasPrefix(filepath.Clean(filename), "/etc/ssh") +} + +// Decode reads r into a Config, or returns an error if r could not be parsed as +// an SSH config file. +func Decode(r io.Reader, ignoreMatchDirective bool) (*Config, error) { + b, err := io.ReadAll(r) + if err != nil { + return nil, err + } + return decodeBytes(b, false, ignoreMatchDirective, 0) +} + +// DecodeBytes reads b into a Config, or returns an error if r could not be +// parsed as an SSH config file. +func DecodeBytes(b []byte, ignoreMatchDirective bool) (*Config, error) { + return decodeBytes(b, false, ignoreMatchDirective, 0) +} + +func decodeBytes(b []byte, system, ignoreMatchDirective bool, depth uint8) (c *Config, err error) { + defer func() { + if r := recover(); r != nil { + if _, ok := r.(runtime.Error); ok { + panic(r) + } + if e, ok := r.(error); ok && e == ErrDepthExceeded { + err = e + return + } + err = errors.New(r.(string)) + } + }() + + c = parseSSH(lexSSH(b), system, ignoreMatchDirective, depth) + return c, err +} + +// Config represents an SSH config file. +type Config struct { + // A list of hosts to match against. The file begins with an implicit + // "Host *" declaration matching all hosts. + Hosts []*Host + depth uint8 + position Position + ignoreMatchDirective bool +} + +// Get finds the first value in the configuration that matches the alias and +// contains key. Get returns the empty string if no value was found, or if the +// Config contains an invalid conditional Include value. +// +// The match for key is case insensitive. +func (c *Config) Get(alias, key string) (string, error) { + lowerKey := strings.ToLower(key) + for _, host := range c.Hosts { + if !host.Matches(alias) { + continue + } + for _, node := range host.Nodes { + switch t := node.(type) { + case *Empty: + continue + case *KV: + // "keys are case insensitive" per the spec + lkey := strings.ToLower(t.Key) + if lkey == "match" && !c.ignoreMatchDirective { + panic("can't handle Match directives") + } + if lkey == lowerKey { + return t.Value, nil + } + case *Include: + val := t.Get(alias, key) + if val != "" { + return val, nil + } + default: + return "", fmt.Errorf("unknown Node type %v", t) + } + } + } + return "", nil +} + +// GetAll returns all values in the configuration that match the alias and +// contains key, or nil if none are present. +func (c *Config) GetAll(alias, key string) ([]string, error) { + lowerKey := strings.ToLower(key) + all := []string(nil) + for _, host := range c.Hosts { + if !host.Matches(alias) { + continue + } + for _, node := range host.Nodes { + switch t := node.(type) { + case *Empty: + continue + case *KV: + // "keys are case insensitive" per the spec + lkey := strings.ToLower(t.Key) + if lkey == "match" { + panic("can't handle Match directives") + } + if lkey == lowerKey { + all = append(all, t.Value) + } + case *Include: + val, _ := t.GetAll(alias, key) + if len(val) > 0 { + all = append(all, val...) + } + default: + return nil, fmt.Errorf("unknown Node type %v", t) + } + } + } + + return all, nil +} + +// String returns a string representation of the Config file. +func (c Config) String() string { + return marshal(c).String() +} + +func (c Config) MarshalText() ([]byte, error) { + return marshal(c).Bytes(), nil +} + +func marshal(c Config) *bytes.Buffer { + var buf bytes.Buffer + for i := range c.Hosts { + buf.WriteString(c.Hosts[i].String()) + } + return &buf +} + +// Pattern is a pattern in a Host declaration. Patterns are read-only values; +// create a new one with NewPattern(). +type Pattern struct { + str string // Its appearance in the file, not the value that gets compiled. + regex *regexp.Regexp + not bool // True if this is a negated match +} + +// String prints the string representation of the pattern. +func (p Pattern) String() string { + return p.str +} + +// Copied from regexp.go with * and ? removed. +var specialBytes = []byte(`\.+()|[]{}^$`) + +func special(b byte) bool { + return bytes.IndexByte(specialBytes, b) >= 0 +} + +// NewPattern creates a new Pattern for matching hosts. NewPattern("*") creates +// a Pattern that matches all hosts. +// +// From the manpage, a pattern consists of zero or more non-whitespace +// characters, `*' (a wildcard that matches zero or more characters), or `?' (a +// wildcard that matches exactly one character). For example, to specify a set +// of declarations for any host in the ".co.uk" set of domains, the following +// pattern could be used: +// +// Host *.co.uk +// +// The following pattern would match any host in the 192.168.0.[0-9] network range: +// +// Host 192.168.0.? +func NewPattern(s string) (*Pattern, error) { + if s == "" { + return nil, errors.New("ssh_config: empty pattern") + } + negated := false + if s[0] == '!' { + negated = true + s = s[1:] + } + var buf bytes.Buffer + buf.WriteByte('^') + for i := 0; i < len(s); i++ { + // A byte loop is correct because all metacharacters are ASCII. + switch b := s[i]; b { + case '*': + buf.WriteString(".*") + case '?': + buf.WriteString(".?") + default: + // borrowing from QuoteMeta here. + if special(b) { + buf.WriteByte('\\') + } + buf.WriteByte(b) + } + } + buf.WriteByte('$') + r, err := regexp.Compile(buf.String()) + if err != nil { + return nil, err + } + return &Pattern{str: s, regex: r, not: negated}, nil +} + +// Host describes a Host directive and the keywords that follow it. +type Host struct { + // A list of host patterns that should match this host. + Patterns []*Pattern + // A Node is either a key/value pair or a comment line. + Nodes []Node + // EOLComment is the comment (if any) terminating the Host line. + EOLComment string + // Whitespace if any between the Host declaration and a trailing comment. + spaceBeforeComment string + + hasEquals bool + leadingSpace int // TODO: handle spaces vs tabs here. + // The file starts with an implicit "Host *" declaration. + implicit bool +} + +// Matches returns true if the Host matches for the given alias. For +// a description of the rules that provide a match, see the manpage for +// ssh_config. +func (h *Host) Matches(alias string) bool { + found := false + for i := range h.Patterns { + if h.Patterns[i].regex.MatchString(alias) { + if h.Patterns[i].not { + // Negated match. "A pattern entry may be negated by prefixing + // it with an exclamation mark (`!'). If a negated entry is + // matched, then the Host entry is ignored, regardless of + // whether any other patterns on the line match. Negated matches + // are therefore useful to provide exceptions for wildcard + // matches." + return false + } + found = true + } + } + return found +} + +// String prints h as it would appear in a config file. Minor tweaks may be +// present in the whitespace in the printed file. +func (h *Host) String() string { + var buf strings.Builder + //lint:ignore S1002 I prefer to write it this way + if h.implicit == false { + buf.WriteString(strings.Repeat(" ", int(h.leadingSpace))) + buf.WriteString("Host") + if h.hasEquals { + buf.WriteString(" = ") + } else { + buf.WriteString(" ") + } + for i, pat := range h.Patterns { + buf.WriteString(pat.String()) + if i < len(h.Patterns)-1 { + buf.WriteString(" ") + } + } + buf.WriteString(h.spaceBeforeComment) + if h.EOLComment != "" { + buf.WriteByte('#') + buf.WriteString(h.EOLComment) + } + buf.WriteByte('\n') + } + for i := range h.Nodes { + buf.WriteString(h.Nodes[i].String()) + buf.WriteByte('\n') + } + return buf.String() +} + +// Node represents a line in a Config. +type Node interface { + Pos() Position + String() string +} + +// KV is a line in the config file that contains a key, a value, and possibly +// a comment. +type KV struct { + Key string + Value string + // Whitespace after the value but before any comment + spaceAfterValue string + Comment string + hasEquals bool + leadingSpace int // Space before the key. TODO handle spaces vs tabs. + position Position +} + +// Pos returns k's Position. +func (k *KV) Pos() Position { + return k.position +} + +// String prints k as it was parsed in the config file. +func (k *KV) String() string { + if k == nil { + return "" + } + equals := " " + if k.hasEquals { + equals = " = " + } + line := strings.Repeat(" ", int(k.leadingSpace)) + k.Key + equals + k.Value + k.spaceAfterValue + if k.Comment != "" { + line += "#" + k.Comment + } + return line +} + +// Empty is a line in the config file that contains only whitespace or comments. +type Empty struct { + Comment string + leadingSpace int // TODO handle spaces vs tabs. + position Position +} + +// Pos returns e's Position. +func (e *Empty) Pos() Position { + return e.position +} + +// String prints e as it was parsed in the config file. +func (e *Empty) String() string { + if e == nil { + return "" + } + if e.Comment == "" { + return "" + } + return fmt.Sprintf("%s#%s", strings.Repeat(" ", int(e.leadingSpace)), e.Comment) +} + +// Include holds the result of an Include directive, including the config files +// that have been parsed as part of that directive. At most 5 levels of Include +// statements will be parsed. +type Include struct { + // Comment is the contents of any comment at the end of the Include + // statement. + Comment string + // an include directive can include several different files, and wildcards + directives []string + + mu sync.Mutex + // 1:1 mapping between matches and keys in files array; matches preserves + // ordering + matches []string + // actual filenames are listed here + files map[string]*Config + leadingSpace int + position Position + depth uint8 + hasEquals bool +} + +const maxRecurseDepth = 5 + +// ErrDepthExceeded is returned if too many Include directives are parsed. +// Usually this indicates a recursive loop (an Include directive pointing to the +// file it contains). +var ErrDepthExceeded = errors.New("ssh_config: max recurse depth exceeded") + +func removeDups(arr []string) []string { + // Use map to record duplicates as we find them. + encountered := make(map[string]bool, len(arr)) + result := make([]string, 0) + + for v := range arr { + //lint:ignore S1002 I prefer it this way + if encountered[arr[v]] == false { + encountered[arr[v]] = true + result = append(result, arr[v]) + } + } + return result +} + +// NewInclude creates a new Include with a list of file globs to include. +// Configuration files are parsed greedily (e.g. as soon as this function runs). +// Any error encountered while parsing nested configuration files will be +// returned. +func NewInclude(directives []string, hasEquals bool, pos Position, comment string, system, ignoreMatchDirective bool, depth uint8, +) (*Include, error) { + if depth > maxRecurseDepth { + return nil, ErrDepthExceeded + } + inc := &Include{ + Comment: comment, + directives: directives, + files: make(map[string]*Config), + position: pos, + leadingSpace: pos.Col - 1, + depth: depth, + hasEquals: hasEquals, + } + // no need for inc.mu.Lock() since nothing else can access this inc + matches := make([]string, 0) + for i := range directives { + var path string + if filepath.IsAbs(directives[i]) { + path = directives[i] + } else if system { + path = filepath.Join("/etc/ssh", directives[i]) + } else { + path = filepath.Join(homedir(), ".ssh", directives[i]) + } + theseMatches, err := filepath.Glob(path) + if err != nil { + return nil, err + } + matches = append(matches, theseMatches...) + } + matches = removeDups(matches) + inc.matches = matches + for i := range matches { + config, err := parseWithDepth(matches[i], ignoreMatchDirective, depth) + if err != nil { + return nil, err + } + inc.files[matches[i]] = config + } + return inc, nil +} + +// Pos returns the position of the Include directive in the larger file. +func (i *Include) Pos() Position { + return i.position +} + +// Get finds the first value in the Include statement matching the alias and the +// given key. +func (inc *Include) Get(alias, key string) string { + inc.mu.Lock() + defer inc.mu.Unlock() + // TODO: we search files in any order which is not correct + for i := range inc.matches { + cfg := inc.files[inc.matches[i]] + if cfg == nil { + panic("nil cfg") + } + val, err := cfg.Get(alias, key) + if err == nil && val != "" { + return val + } + } + return "" +} + +// GetAll finds all values in the Include statement matching the alias and the +// given key. +func (inc *Include) GetAll(alias, key string) ([]string, error) { + inc.mu.Lock() + defer inc.mu.Unlock() + var vals []string + + // TODO: we search files in any order which is not correct + for i := range inc.matches { + cfg := inc.files[inc.matches[i]] + if cfg == nil { + panic("nil cfg") + } + val, err := cfg.GetAll(alias, key) + if err == nil && len(val) != 0 { + // In theory if SupportsMultiple was false for this key we could + // stop looking here. But the caller has asked us to find all + // instances of the keyword (and could use Get() if they wanted) so + // let's keep looking. + vals = append(vals, val...) + } + } + return vals, nil +} + +// String prints out a string representation of this Include directive. Note +// included Config files are not printed as part of this representation. +func (inc *Include) String() string { + equals := " " + if inc.hasEquals { + equals = " = " + } + line := fmt.Sprintf("%sInclude%s%s", strings.Repeat(" ", int(inc.leadingSpace)), equals, strings.Join(inc.directives, " ")) + if inc.Comment != "" { + line += " #" + inc.Comment + } + return line +} + +var matchAll *Pattern + +func init() { + var err error + matchAll, err = NewPattern("*") + if err != nil { + panic(err) + } +} + +func newConfig() *Config { + return &Config{ + Hosts: []*Host{ + &Host{ + implicit: true, + Patterns: []*Pattern{matchAll}, + Nodes: make([]Node, 0), + }, + }, + depth: 0, + } +} diff --git a/config_test.go b/config_test.go new file mode 100644 index 0000000..db8360b --- /dev/null +++ b/config_test.go @@ -0,0 +1,495 @@ +package ssh_config + +import ( + "bytes" + "log" + "os" + "path/filepath" + "strings" + "testing" +) + +func loadFile(t *testing.T, filename string) []byte { + t.Helper() + data, err := os.ReadFile(filename) + if err != nil { + t.Fatal(err) + } + return data +} + +var files = []string{ + "testdata/config1", + "testdata/config2", + "testdata/eol-comments", +} + +func TestDecode(t *testing.T) { + for _, filename := range files { + data := loadFile(t, filename) + cfg, err := Decode(bytes.NewReader(data), false) + if err != nil { + t.Fatal(err) + } + out := cfg.String() + if out != string(data) { + t.Errorf("%s out != data: got:\n%s\nwant:\n%s\n", filename, out, string(data)) + } + } +} + +func testConfigFinder(filename string) func() string { + return func() string { return filename } +} + +func nullConfigFinder() string { + return "" +} + +func TestGet(t *testing.T) { + us := &UserSettings{ + userConfigFinder: testConfigFinder("testdata/config1"), + } + + val := us.Get("wap", "User") + if val != "root" { + t.Errorf("expected to find User root, got %q", val) + } +} + +func TestGetWithDefault(t *testing.T) { + us := &UserSettings{ + userConfigFinder: testConfigFinder("testdata/config1"), + } + + val, err := us.GetStrict("wap", "PasswordAuthentication") + if err != nil { + t.Fatalf("expected nil err, got %v", err) + } + if val != "yes" { + t.Errorf("expected to get PasswordAuthentication yes, got %q", val) + } +} + +func TestGetAllWithDefault(t *testing.T) { + us := &UserSettings{ + userConfigFinder: testConfigFinder("testdata/config1"), + } + + val, err := us.GetAllStrict("wap", "PasswordAuthentication") + if err != nil { + t.Fatalf("expected nil err, got %v", err) + } + if len(val) != 1 || val[0] != "yes" { + t.Errorf("expected to get PasswordAuthentication yes, got %q", val) + } +} + +func TestGetIdentities(t *testing.T) { + us := &UserSettings{ + userConfigFinder: testConfigFinder("testdata/identities"), + } + + val, err := us.GetAllStrict("hasidentity", "IdentityFile") + if err != nil { + t.Errorf("expected nil err, got %v", err) + } + if len(val) != 1 || val[0] != "file1" { + t.Errorf(`expected ["file1"], got %v`, val) + } + + val, err = us.GetAllStrict("has2identity", "IdentityFile") + if err != nil { + t.Errorf("expected nil err, got %v", err) + } + if len(val) != 2 || val[0] != "f1" || val[1] != "f2" { + t.Errorf(`expected [\"f1\", \"f2\"], got %v`, val) + } + + val, err = us.GetAllStrict("randomhost", "IdentityFile") + if err != nil { + t.Errorf("expected nil err, got %v", err) + } + if len(val) != len(defaultProtocol2Identities) { + // TODO: return the right values here. + log.Printf("expected defaults, got %v", val) + } else { + for i, v := range defaultProtocol2Identities { + if val[i] != v { + t.Errorf("invalid %d in val, expected %s got %s", i, v, val[i]) + } + } + } + + val, err = us.GetAllStrict("protocol1", "IdentityFile") + if err != nil { + t.Errorf("expected nil err, got %v", err) + } + if len(val) != 1 || val[0] != "~/.ssh/identity" { + t.Errorf("expected [\"~/.ssh/identity\"], got %v", val) + } +} + +func TestGetInvalidPort(t *testing.T) { + us := &UserSettings{ + userConfigFinder: testConfigFinder("testdata/invalid-port"), + } + + val, err := us.GetStrict("test.test", "Port") + if err == nil { + t.Fatalf("expected non-nil err, got nil") + } + if val != "" { + t.Errorf("expected to get '' for val, got %q", val) + } + if err.Error() != `ssh_config: strconv.ParseUint: parsing "notanumber": invalid syntax` { + t.Errorf("wrong error: got %v", err) + } +} + +func TestGetNotFoundNoDefault(t *testing.T) { + us := &UserSettings{ + userConfigFinder: testConfigFinder("testdata/config1"), + } + + val, err := us.GetStrict("wap", "CanonicalDomains") + if err != nil { + t.Fatalf("expected nil err, got %v", err) + } + if val != "" { + t.Errorf("expected to get CanonicalDomains '', got %q", val) + } +} + +func TestGetAllNotFoundNoDefault(t *testing.T) { + us := &UserSettings{ + userConfigFinder: testConfigFinder("testdata/config1"), + } + + val, err := us.GetAllStrict("wap", "CanonicalDomains") + if err != nil { + t.Fatalf("expected nil err, got %v", err) + } + if len(val) != 0 { + t.Errorf("expected to get CanonicalDomains '', got %q", val) + } +} + +func TestGetWildcard(t *testing.T) { + us := &UserSettings{ + userConfigFinder: testConfigFinder("testdata/config3"), + } + + val := us.Get("bastion.stage.i.us.example.net", "Port") + if val != "22" { + t.Errorf("expected to find Port 22, got %q", val) + } + + val = us.Get("bastion.net", "Port") + if val != "25" { + t.Errorf("expected to find Port 24, got %q", val) + } + + val = us.Get("10.2.3.4", "Port") + if val != "23" { + t.Errorf("expected to find Port 23, got %q", val) + } + val = us.Get("101.2.3.4", "Port") + if val != "25" { + t.Errorf("expected to find Port 24, got %q", val) + } + val = us.Get("20.20.20.4", "Port") + if val != "24" { + t.Errorf("expected to find Port 24, got %q", val) + } + val = us.Get("20.20.20.20", "Port") + if val != "25" { + t.Errorf("expected to find Port 25, got %q", val) + } +} + +func TestGetExtraSpaces(t *testing.T) { + us := &UserSettings{ + userConfigFinder: testConfigFinder("testdata/extraspace"), + } + + val := us.Get("test.test", "Port") + if val != "1234" { + t.Errorf("expected to find Port 1234, got %q", val) + } +} + +func TestGetCaseInsensitive(t *testing.T) { + us := &UserSettings{ + userConfigFinder: testConfigFinder("testdata/config1"), + } + + val := us.Get("wap", "uSER") + if val != "root" { + t.Errorf("expected to find User root, got %q", val) + } +} + +func TestGetEmpty(t *testing.T) { + us := &UserSettings{ + userConfigFinder: nullConfigFinder, + systemConfigFinder: nullConfigFinder, + } + val, err := us.GetStrict("wap", "User") + if err != nil { + t.Errorf("expected nil error, got %v", err) + } + if val != "" { + t.Errorf("expected to get empty string, got %q", val) + } +} + +func TestGetEqsign(t *testing.T) { + us := &UserSettings{ + userConfigFinder: testConfigFinder("testdata/eqsign"), + } + + val := us.Get("test.test", "Port") + if val != "1234" { + t.Errorf("expected to find Port 1234, got %q", val) + } + val = us.Get("test.test", "Port2") + if val != "5678" { + t.Errorf("expected to find Port2 5678, got %q", val) + } +} + +var includeFile = []byte(` +# This host should not exist, so we can use it for test purposes / it won't +# interfere with any other configurations. +Host kevinburke.ssh_config.test.example.com + Port 4567 +`) + +func TestInclude(t *testing.T) { + if testing.Short() { + t.Skip("skipping fs write in short mode") + } + testPath := filepath.Join(homedir(), ".ssh", "kevinburke-ssh-config-test-file") + err := os.WriteFile(testPath, includeFile, 0644) + if err != nil { + t.Skipf("couldn't write SSH config file: %v", err.Error()) + } + defer os.Remove(testPath) + us := &UserSettings{ + userConfigFinder: testConfigFinder("testdata/include"), + } + val := us.Get("kevinburke.ssh_config.test.example.com", "Port") + if val != "4567" { + t.Errorf("expected to find Port=4567 in included file, got %q", val) + } +} + +func TestIncludeSystem(t *testing.T) { + if testing.Short() { + t.Skip("skipping fs write in short mode") + } + testPath := filepath.Join("/", "etc", "ssh", "kevinburke-ssh-config-test-file") + err := os.WriteFile(testPath, includeFile, 0644) + if err != nil { + t.Skipf("couldn't write SSH config file: %v", err.Error()) + } + defer os.Remove(testPath) + us := &UserSettings{ + systemConfigFinder: testConfigFinder("testdata/include"), + } + val := us.Get("kevinburke.ssh_config.test.example.com", "Port") + if val != "4567" { + t.Errorf("expected to find Port=4567 in included file, got %q", val) + } +} + +var recursiveIncludeFile = []byte(` +Host kevinburke.ssh_config.test.example.com + Include kevinburke-ssh-config-recursive-include +`) + +func TestIncludeRecursive(t *testing.T) { + if testing.Short() { + t.Skip("skipping fs write in short mode") + } + testPath := filepath.Join(homedir(), ".ssh", "kevinburke-ssh-config-recursive-include") + err := os.WriteFile(testPath, recursiveIncludeFile, 0644) + if err != nil { + t.Skipf("couldn't write SSH config file: %v", err.Error()) + } + defer os.Remove(testPath) + us := &UserSettings{ + userConfigFinder: testConfigFinder("testdata/include-recursive"), + } + val, err := us.GetStrict("kevinburke.ssh_config.test.example.com", "Port") + if err != ErrDepthExceeded { + t.Errorf("Recursive include: expected ErrDepthExceeded, got %v", err) + } + if val != "" { + t.Errorf("non-empty string value %s", val) + } +} + +func TestIncludeString(t *testing.T) { + if testing.Short() { + t.Skip("skipping fs write in short mode") + } + data, err := os.ReadFile("testdata/include") + if err != nil { + log.Fatal(err) + } + c, err := Decode(bytes.NewReader(data), false) + if err != nil { + t.Fatal(err) + } + s := c.String() + if s != string(data) { + t.Errorf("mismatch: got %q\nwant %q", s, string(data)) + } +} + +var matchTests = []struct { + in []string + alias string + want bool +}{ + {[]string{"*"}, "any.test", true}, + {[]string{"a", "b", "*", "c"}, "any.test", true}, + {[]string{"a", "b", "c"}, "any.test", false}, + {[]string{"any.test"}, "any1test", false}, + {[]string{"192.168.0.?"}, "192.168.0.1", true}, + {[]string{"192.168.0.?"}, "192.168.0.10", false}, + {[]string{"*.co.uk"}, "bbc.co.uk", true}, + {[]string{"*.co.uk"}, "subdomain.bbc.co.uk", true}, + {[]string{"*.*.co.uk"}, "bbc.co.uk", false}, + {[]string{"*.*.co.uk"}, "subdomain.bbc.co.uk", true}, + {[]string{"*.example.com", "!*.dialup.example.com", "foo.dialup.example.com"}, "foo.dialup.example.com", false}, + {[]string{"test.*", "!test.host"}, "test.host", false}, +} + +func TestMatches(t *testing.T) { + for _, tt := range matchTests { + patterns := make([]*Pattern, len(tt.in)) + for i := range tt.in { + pat, err := NewPattern(tt.in[i]) + if err != nil { + t.Fatalf("error compiling pattern %s: %v", tt.in[i], err) + } + patterns[i] = pat + } + host := &Host{ + Patterns: patterns, + } + got := host.Matches(tt.alias) + if got != tt.want { + t.Errorf("host(%q).Matches(%q): got %v, want %v", tt.in, tt.alias, got, tt.want) + } + } +} + +func TestMatchUnsupported(t *testing.T) { + us := &UserSettings{ + userConfigFinder: testConfigFinder("testdata/match-directive"), + } + + _, err := us.GetStrict("test.test", "Port") + if err == nil { + t.Fatal("expected Match directive to error, didn't") + } + if !strings.Contains(err.Error(), "ssh_config: Match directive parsing is unsupported") { + t.Errorf("wrong error: %v", err) + } +} + +func TestIndexInRange(t *testing.T) { + us := &UserSettings{ + userConfigFinder: testConfigFinder("testdata/config4"), + } + + user, err := us.GetStrict("wap", "User") + if err != nil { + t.Fatal(err) + } + if user != "root" { + t.Errorf("expected User to be %q, got %q", "root", user) + } +} + +func TestDosLinesEndingsDecode(t *testing.T) { + us := &UserSettings{ + userConfigFinder: testConfigFinder("testdata/dos-lines"), + } + + user, err := us.GetStrict("wap", "User") + if err != nil { + t.Fatal(err) + } + + if user != "root" { + t.Errorf("expected User to be %q, got %q", "root", user) + } + + host, err := us.GetStrict("wap2", "HostName") + if err != nil { + t.Fatal(err) + } + + if host != "8.8.8.8" { + t.Errorf("expected HostName to be %q, got %q", "8.8.8.8", host) + } +} + +func TestNoTrailingNewline(t *testing.T) { + us := &UserSettings{ + userConfigFinder: testConfigFinder("testdata/config-no-ending-newline"), + systemConfigFinder: nullConfigFinder, + } + + port, err := us.GetStrict("example", "Port") + if err != nil { + t.Fatal(err) + } + + if port != "4242" { + t.Errorf("wrong port: got %q want 4242", port) + } +} + +func TestCustomFinder(t *testing.T) { + us := &UserSettings{} + us.ConfigFinder(func() string { + return "testdata/config1" + }) + + val := us.Get("wap", "User") + if val != "root" { + t.Errorf("expected to find User root, got %q", val) + } +} + +func TestCustomFinderWhenIgnoringMatchDirective(t *testing.T) { + us := &UserSettings{ + IgnoreMatchDirective: true, + } + us.ConfigFinder(func() string { + return "testdata/config1-with-match-directive" + }) + + val := us.Get("git.yahoo.com", "HostName") + if val != "git.proxy.com" { + t.Errorf("expected to find Hostname git.proxy.com, got %q", val) + } +} + +func TestCustomFinderWhenNotIgnoringMatchDirective(t *testing.T) { + us := &UserSettings{} + us.ConfigFinder(func() string { + return "testdata/config1-with-match-directive" + }) + + val := us.Get("git.yahoo.com", "HostName") + if val != "" { + t.Errorf("expected to find Hostname empty %q", val) + } +} diff --git a/example_test.go b/example_test.go new file mode 100644 index 0000000..0720858 --- /dev/null +++ b/example_test.go @@ -0,0 +1,59 @@ +package ssh_config_test + +import ( + "fmt" + "path/filepath" + "strings" + + "github.com/dimonomid/ssh_config" +) + +func ExampleHost_Matches() { + pat, _ := ssh_config.NewPattern("test.*.example.com") + host := &ssh_config.Host{Patterns: []*ssh_config.Pattern{pat}} + fmt.Println(host.Matches("test.stage.example.com")) + fmt.Println(host.Matches("othersubdomain.example.com")) + // Output: + // true + // false +} + +func ExamplePattern() { + pat, _ := ssh_config.NewPattern("*") + host := &ssh_config.Host{Patterns: []*ssh_config.Pattern{pat}} + fmt.Println(host.Matches("test.stage.example.com")) + fmt.Println(host.Matches("othersubdomain.any.any")) + // Output: + // true + // true +} + +func ExampleDecode() { + var config = ` +Host *.example.com + Compression yes +` + + cfg, _ := ssh_config.Decode(strings.NewReader(config), false) + val, _ := cfg.Get("test.example.com", "Compression") + fmt.Println(val) + // Output: yes +} + +func ExampleDefault() { + fmt.Println(ssh_config.Default("Port")) + fmt.Println(ssh_config.Default("UnknownVar")) + // Output: + // 22 + // +} + +func ExampleUserSettings_ConfigFinder() { + // This can be used to test SSH config parsing. + u := ssh_config.UserSettings{} + u.ConfigFinder(func() string { + return filepath.Join("testdata", "test_config") + }, + ) + u.Get("example.com", "Host") +} diff --git a/fuzz_test.go b/fuzz_test.go new file mode 100644 index 0000000..e3a168f --- /dev/null +++ b/fuzz_test.go @@ -0,0 +1,18 @@ +//go:build 1.18 +// +build 1.18 + +package ssh_config + +import ( + "bytes" + "testing" +) + +func FuzzDecode(f *testing.F) { + f.Fuzz(func(t *testing.T, in []byte) { + _, err := Decode(bytes.NewReader(in)) + if err != nil { + t.Fatalf("decode %q: %v", string(in), err) + } + }) +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..ff6a0b7 --- /dev/null +++ b/go.mod @@ -0,0 +1,3 @@ +module github.com/dimonomid/ssh_config + +go 1.18 diff --git a/lexer.go b/lexer.go new file mode 100644 index 0000000..11680b4 --- /dev/null +++ b/lexer.go @@ -0,0 +1,240 @@ +package ssh_config + +import ( + "bytes" +) + +// Define state functions +type sshLexStateFn func() sshLexStateFn + +type sshLexer struct { + inputIdx int + input []rune // Textual source + + buffer []rune // Runes composing the current token + tokens chan token + line int + col int + endbufferLine int + endbufferCol int +} + +func (s *sshLexer) lexComment(previousState sshLexStateFn) sshLexStateFn { + return func() sshLexStateFn { + growingString := "" + for next := s.peek(); next != '\n' && next != eof; next = s.peek() { + if next == '\r' && s.follow("\r\n") { + break + } + growingString += string(next) + s.next() + } + s.emitWithValue(tokenComment, growingString) + s.skip() + return previousState + } +} + +// lex the space after an equals sign in a function +func (s *sshLexer) lexRspace() sshLexStateFn { + for { + next := s.peek() + if !isSpace(next) { + break + } + s.skip() + } + return s.lexRvalue +} + +func (s *sshLexer) lexEquals() sshLexStateFn { + for { + next := s.peek() + if next == '=' { + s.emit(tokenEquals) + s.skip() + return s.lexRspace + } + // TODO error handling here; newline eof etc. + if !isSpace(next) { + break + } + s.skip() + } + return s.lexRvalue +} + +func (s *sshLexer) lexKey() sshLexStateFn { + growingString := "" + + for r := s.peek(); isKeyChar(r); r = s.peek() { + // simplified a lot here + if isSpace(r) || r == '=' { + s.emitWithValue(tokenKey, growingString) + s.skip() + return s.lexEquals + } + growingString += string(r) + s.next() + } + s.emitWithValue(tokenKey, growingString) + return s.lexEquals +} + +func (s *sshLexer) lexRvalue() sshLexStateFn { + growingString := "" + for { + next := s.peek() + switch next { + case '\r': + if s.follow("\r\n") { + s.emitWithValue(tokenString, growingString) + s.skip() + return s.lexVoid + } + case '\n': + s.emitWithValue(tokenString, growingString) + s.skip() + return s.lexVoid + case '#': + s.emitWithValue(tokenString, growingString) + s.skip() + return s.lexComment(s.lexVoid) + case eof: + s.next() + } + if next == eof { + break + } + growingString += string(next) + s.next() + } + s.emit(tokenEOF) + return nil +} + +func (s *sshLexer) read() rune { + r := s.peek() + if r == '\n' { + s.endbufferLine++ + s.endbufferCol = 1 + } else { + s.endbufferCol++ + } + s.inputIdx++ + return r +} + +func (s *sshLexer) next() rune { + r := s.read() + + if r != eof { + s.buffer = append(s.buffer, r) + } + return r +} + +func (s *sshLexer) lexVoid() sshLexStateFn { + for { + next := s.peek() + switch next { + case '#': + s.skip() + return s.lexComment(s.lexVoid) + case '\r': + fallthrough + case '\n': + s.emit(tokenEmptyLine) + s.skip() + continue + } + + if isSpace(next) { + s.skip() + } + + if isKeyStartChar(next) { + return s.lexKey + } + + // removed IsKeyStartChar and lexKey. probably will need to readd + + if next == eof { + s.next() + break + } + } + + s.emit(tokenEOF) + return nil +} + +func (s *sshLexer) ignore() { + s.buffer = make([]rune, 0) + s.line = s.endbufferLine + s.col = s.endbufferCol +} + +func (s *sshLexer) skip() { + s.next() + s.ignore() +} + +func (s *sshLexer) emit(t tokenType) { + s.emitWithValue(t, string(s.buffer)) +} + +func (s *sshLexer) emitWithValue(t tokenType, value string) { + tok := token{ + Position: Position{s.line, s.col}, + typ: t, + val: value, + } + s.tokens <- tok + s.ignore() +} + +func (s *sshLexer) peek() rune { + if s.inputIdx >= len(s.input) { + return eof + } + + r := s.input[s.inputIdx] + return r +} + +func (s *sshLexer) follow(next string) bool { + inputIdx := s.inputIdx + for _, expectedRune := range next { + if inputIdx >= len(s.input) { + return false + } + r := s.input[inputIdx] + inputIdx++ + if expectedRune != r { + return false + } + } + return true +} + +func (s *sshLexer) run() { + for state := s.lexVoid; state != nil; { + state = state() + } + close(s.tokens) +} + +func lexSSH(input []byte) chan token { + runes := bytes.Runes(input) + l := &sshLexer{ + input: runes, + tokens: make(chan token), + line: 1, + col: 1, + endbufferLine: 1, + endbufferCol: 1, + } + go l.run() + return l.tokens +} diff --git a/parser.go b/parser.go new file mode 100644 index 0000000..e59c500 --- /dev/null +++ b/parser.go @@ -0,0 +1,205 @@ +package ssh_config + +import ( + "fmt" + "strings" + "unicode" +) + +type sshParser struct { + ignoreMatchDirective bool + flow chan token + config *Config + tokensBuffer []token + currentTable []string + seenTableKeys []string + // /etc/ssh parser or local parser - used to find the default for relative + // filepaths in the Include directive + system bool + depth uint8 +} + +type sshParserStateFn func() sshParserStateFn + +// Formats and panics an error message based on a token +func (p *sshParser) raiseErrorf(tok *token, msg string, args ...interface{}) { + // TODO this format is ugly + panic(tok.Position.String() + ": " + fmt.Sprintf(msg, args...)) +} + +func (p *sshParser) raiseError(tok *token, err error) { + if err == ErrDepthExceeded { + panic(err) + } + // TODO this format is ugly + panic(tok.Position.String() + ": " + err.Error()) +} + +func (p *sshParser) run() { + for state := p.parseStart; state != nil; { + state = state() + } +} + +func (p *sshParser) peek() *token { + if len(p.tokensBuffer) != 0 { + return &(p.tokensBuffer[0]) + } + + tok, ok := <-p.flow + if !ok { + return nil + } + p.tokensBuffer = append(p.tokensBuffer, tok) + return &tok +} + +func (p *sshParser) getToken() *token { + if len(p.tokensBuffer) != 0 { + tok := p.tokensBuffer[0] + p.tokensBuffer = p.tokensBuffer[1:] + return &tok + } + tok, ok := <-p.flow + if !ok { + return nil + } + return &tok +} + +func (p *sshParser) parseStart() sshParserStateFn { + tok := p.peek() + + // end of stream, parsing is finished + if tok == nil { + return nil + } + + switch tok.typ { + case tokenComment, tokenEmptyLine: + return p.parseComment + case tokenKey: + return p.parseKV + case tokenEOF: + return nil + default: + p.raiseErrorf(tok, fmt.Sprintf("unexpected token %q\n", tok)) + } + return nil +} + +func (p *sshParser) parseKV() sshParserStateFn { + key := p.getToken() + hasEquals := false + val := p.getToken() + if val.typ == tokenEquals { + hasEquals = true + val = p.getToken() + } + comment := "" + tok := p.peek() + if tok == nil { + tok = &token{typ: tokenEOF} + } + if tok.typ == tokenComment && tok.Position.Line == val.Position.Line { + tok = p.getToken() + comment = tok.val + } + if strings.ToLower(key.val) == "match" && !p.ignoreMatchDirective { + // https://github.com/kevinburke/ssh_config/issues/6 + p.raiseErrorf(val, "ssh_config: Match directive parsing is unsupported") + return nil + } + if strings.ToLower(key.val) == "host" { + strPatterns := strings.Split(val.val, " ") + patterns := make([]*Pattern, 0) + for i := range strPatterns { + if strPatterns[i] == "" { + continue + } + pat, err := NewPattern(strPatterns[i]) + if err != nil { + p.raiseErrorf(val, "Invalid host pattern: %v", err) + return nil + } + patterns = append(patterns, pat) + } + // val.val at this point could be e.g. "example.com " + hostval := strings.TrimRightFunc(val.val, unicode.IsSpace) + spaceBeforeComment := val.val[len(hostval):] + val.val = hostval + p.config.ignoreMatchDirective = p.ignoreMatchDirective + p.config.Hosts = append(p.config.Hosts, &Host{ + Patterns: patterns, + Nodes: make([]Node, 0), + EOLComment: comment, + spaceBeforeComment: spaceBeforeComment, + hasEquals: hasEquals, + }, + ) + return p.parseStart + } + lastHost := p.config.Hosts[len(p.config.Hosts)-1] + if strings.ToLower(key.val) == "include" { + inc, err := NewInclude(strings.Split(val.val, " "), hasEquals, key.Position, comment, p.system, p.ignoreMatchDirective, p.depth+1) + if err == ErrDepthExceeded { + p.raiseError(val, err) + return nil + } + if err != nil { + p.raiseErrorf(val, "Error parsing Include directive: %v", err) + return nil + } + lastHost.Nodes = append(lastHost.Nodes, inc) + return p.parseStart + } + shortval := strings.TrimRightFunc(val.val, unicode.IsSpace) + spaceAfterValue := val.val[len(shortval):] + kv := &KV{ + Key: key.val, + Value: shortval, + spaceAfterValue: spaceAfterValue, + Comment: comment, + hasEquals: hasEquals, + leadingSpace: key.Position.Col - 1, + position: key.Position, + } + lastHost.Nodes = append(lastHost.Nodes, kv) + return p.parseStart +} + +func (p *sshParser) parseComment() sshParserStateFn { + comment := p.getToken() + lastHost := p.config.Hosts[len(p.config.Hosts)-1] + lastHost.Nodes = append(lastHost.Nodes, &Empty{ + Comment: comment.val, + // account for the "#" as well + leadingSpace: comment.Position.Col - 2, + position: comment.Position, + }, + ) + return p.parseStart +} + +func parseSSH(flow chan token, system, ignoreMatchDirective bool, depth uint8) *Config { + // Ensure we consume tokens to completion even if parser exits early + defer func() { + for range flow { + } + }() + + result := newConfig() + result.position = Position{1, 1} + parser := &sshParser{ + ignoreMatchDirective: ignoreMatchDirective, + flow: flow, + config: result, + tokensBuffer: make([]token, 0), + currentTable: make([]string, 0), + seenTableKeys: make([]string, 0), + system: system, + depth: depth, + } + parser.run() + return result +} diff --git a/parser_test.go b/parser_test.go new file mode 100644 index 0000000..dc679a4 --- /dev/null +++ b/parser_test.go @@ -0,0 +1,24 @@ +package ssh_config + +import ( + "errors" + "testing" +) + +type errReader struct { +} + +func (b *errReader) Read(p []byte) (n int, err error) { + return 0, errors.New("read error occurred") +} + +func TestIOError(t *testing.T) { + buf := &errReader{} + _, err := Decode(buf, false) + if err == nil { + t.Fatal("expected non-nil err, got nil") + } + if err.Error() != "read error occurred" { + t.Errorf("expected read error msg, got %v", err) + } +} diff --git a/position.go b/position.go new file mode 100644 index 0000000..e0b5e3f --- /dev/null +++ b/position.go @@ -0,0 +1,25 @@ +package ssh_config + +import "fmt" + +// Position of a document element within a SSH document. +// +// Line and Col are both 1-indexed positions for the element's line number and +// column number, respectively. Values of zero or less will cause Invalid(), +// to return true. +type Position struct { + Line int // line within the document + Col int // column within the line +} + +// String representation of the position. +// Displays 1-indexed line and column numbers. +func (p Position) String() string { + return fmt.Sprintf("(%d, %d)", p.Line, p.Col) +} + +// Invalid returns whether or not the position is valid (i.e. with negative or +// null values) +func (p Position) Invalid() bool { + return p.Line <= 0 || p.Col <= 0 +} diff --git a/testdata/anotherfile b/testdata/anotherfile new file mode 100644 index 0000000..c4de676 --- /dev/null +++ b/testdata/anotherfile @@ -0,0 +1,3 @@ +# Not sure that this actually works; Include might need to be relative to the +# load directory. +Compression yes diff --git a/testdata/config-no-ending-newline b/testdata/config-no-ending-newline new file mode 100644 index 0000000..74347a4 --- /dev/null +++ b/testdata/config-no-ending-newline @@ -0,0 +1,3 @@ +Host example + HostName example.com + Port 4242 \ No newline at end of file diff --git a/testdata/config1 b/testdata/config1 new file mode 100644 index 0000000..66ee5a5 --- /dev/null +++ b/testdata/config1 @@ -0,0 +1,39 @@ +Host localhost 127.0.0.1 # A comment at the end of a host line. + NoHostAuthenticationForLocalhost yes + +# A comment + # A comment with leading spaces. + +Host wap + User root + KexAlgorithms diffie-hellman-group1-sha1 + +Host [some stuff behind a NAT] + Compression yes + ProxyCommand ssh -qW %h:%p [NATrouter] + +Host wopr # there are 2 proxies available for this one... + User root + ProxyCommand sh -c "ssh proxy1 -qW %h:22 || ssh proxy2 -qW %h:22" + +Host dhcp-?? + UserKnownHostsFile /dev/null + StrictHostKeyChecking no + User root + +Host [my boxes] [*.mydomain] + ForwardAgent yes + ForwardX11 yes + ForwardX11Trusted yes + +Host * + #ControlMaster auto + #ControlPath /tmp/ssh-master-%C + #ControlPath /tmp/ssh-%u-%r@%h:%p + #ControlPersist yes + ForwardX11Timeout 52w + XAuthLocation /usr/bin/xauth + SendEnv LANG LC_* + HostKeyAlgorithms ssh-ed25519,ssh-rsa + AddressFamily inet + #UpdateHostKeys ask diff --git a/testdata/config1-with-match-directive b/testdata/config1-with-match-directive new file mode 100644 index 0000000..ce7fc1a --- /dev/null +++ b/testdata/config1-with-match-directive @@ -0,0 +1,6 @@ +Match all + Include ~/.ssh +Host * + User usr +Host git.yahoo.com + HostName git.proxy.com diff --git a/testdata/config2 b/testdata/config2 new file mode 100644 index 0000000..90fb63f --- /dev/null +++ b/testdata/config2 @@ -0,0 +1,50 @@ +# $OpenBSD: ssh_config,v 1.30 2016/02/20 23:06:23 sobrado Exp $ + +# This is the ssh client system-wide configuration file. See +# ssh_config(5) for more information. This file provides defaults for +# users, and the values can be changed in per-user configuration files +# or on the command line. + +# Configuration data is parsed as follows: +# 1. command line options +# 2. user-specific file +# 3. system-wide file +# Any configuration value is only changed the first time it is set. +# Thus, host-specific definitions should be at the beginning of the +# configuration file, and defaults at the end. + +# Site-wide defaults for some commonly used options. For a comprehensive +# list of available options, their meanings and defaults, please see the +# ssh_config(5) man page. + +# Host * +# ForwardAgent no +# ForwardX11 no +# RhostsRSAAuthentication no +# RSAAuthentication yes +# PasswordAuthentication yes +# HostbasedAuthentication no +# GSSAPIAuthentication no +# GSSAPIDelegateCredentials no +# BatchMode no +# CheckHostIP yes +# AddressFamily any +# ConnectTimeout 0 +# StrictHostKeyChecking ask +# IdentityFile ~/.ssh/identity +# IdentityFile ~/.ssh/id_rsa +# IdentityFile ~/.ssh/id_dsa +# IdentityFile ~/.ssh/id_ecdsa +# IdentityFile ~/.ssh/id_ed25519 +# Port 22 +# Protocol 2 +# Cipher 3des +# Ciphers aes128-ctr,aes192-ctr,aes256-ctr,arcfour256,arcfour128,aes128-cbc,3des-cbc +# MACs hmac-md5,hmac-sha1,umac-64@openssh.com,hmac-ripemd160 +# EscapeChar ~ +# Tunnel no +# TunnelDevice any:any +# PermitLocalCommand no +# VisualHostKey no +# ProxyCommand ssh -q -W %h:%p gateway.example.com +# RekeyLimit 1G 1h diff --git a/testdata/config3 b/testdata/config3 new file mode 100644 index 0000000..8c15654 --- /dev/null +++ b/testdata/config3 @@ -0,0 +1,31 @@ +Host bastion.*.i.*.example.net + User simon.thulbourn + Port 22 + ForwardAgent yes + IdentityFile /Users/%u/.ssh/example.net/%r/id_rsa + UseKeychain yes + +Host 10.* + User simon.thulbourn + Port 23 + ForwardAgent yes + StrictHostKeyChecking no + UserKnownHostsFile /dev/null + IdentityFile /Users/%u/.ssh/example.net/%r/id_rsa + UseKeychain yes + ProxyCommand >&1; h="%h"; exec ssh -q $(ssh-bastion -ip $h) nc %h %p + +Host 20.20.20.? + User simon.thulbourn + Port 24 + ForwardAgent yes + StrictHostKeyChecking no + UserKnownHostsFile /dev/null + IdentityFile /Users/%u/.ssh/example.net/%r/id_rsa + UseKeychain yes + ProxyCommand >&1; h="%h"; exec ssh -q $(ssh-bastion -ip $h) nc %h %p + +Host * + IdentityFile /Users/%u/.ssh/%h/%r/id_rsa + UseKeychain yes + Port 25 diff --git a/testdata/config4 b/testdata/config4 new file mode 100644 index 0000000..ea10ac0 --- /dev/null +++ b/testdata/config4 @@ -0,0 +1,4 @@ +# Extra space at end of line is important. +Host wap + User root + KexAlgorithms diffie-hellman-group1-sha1 diff --git a/testdata/dos-lines b/testdata/dos-lines new file mode 100644 index 0000000..1001df4 --- /dev/null +++ b/testdata/dos-lines @@ -0,0 +1,10 @@ +# Config file with dos line endings +Host wap + HostName wap.example.org + Port 22 + User root + KexAlgorithms diffie-hellman-group1-sha1 + +Host wap2 + HostName 8.8.8.8 + User google diff --git a/testdata/eol-comments b/testdata/eol-comments new file mode 100644 index 0000000..cf8376d --- /dev/null +++ b/testdata/eol-comments @@ -0,0 +1,7 @@ +Host example # this comment terminates a Host line + HostName example.com # aligned eol comment 1 + ForwardX11Timeout 52w # aligned eol comment 2 +# This comment takes up a whole line + # This comment is offset and takes up a whole line + AddressFamily inet # aligned eol comment 3 + Port 4242 #compact comment diff --git a/testdata/eqsign b/testdata/eqsign new file mode 100644 index 0000000..6332b85 --- /dev/null +++ b/testdata/eqsign @@ -0,0 +1,4 @@ +Host=test.test + Port =1234 + Port2= 5678 + Compression yes diff --git a/testdata/extraspace b/testdata/extraspace new file mode 100644 index 0000000..e9ce2f8 --- /dev/null +++ b/testdata/extraspace @@ -0,0 +1,2 @@ +Host test.test + Port 1234 diff --git a/testdata/fuzz/FuzzDecode/3cfc035ae4867ca13fa7bfaf2793731f05fd4d59c3af8761ea365c7485c752fd b/testdata/fuzz/FuzzDecode/3cfc035ae4867ca13fa7bfaf2793731f05fd4d59c3af8761ea365c7485c752fd new file mode 100644 index 0000000..3497812 --- /dev/null +++ b/testdata/fuzz/FuzzDecode/3cfc035ae4867ca13fa7bfaf2793731f05fd4d59c3af8761ea365c7485c752fd @@ -0,0 +1,2 @@ +go test fuzz v1 +[]byte("#\t$OpenBSD: ssh_config,v 1.30 2016/02/20 23:06:23 sobrado Exp $\n\n# This is the ssh client system-wide configuration file. See\n# ssh_config(5) for more information. This file provides defaults for\n# users, and the values can be changed in per-user configuration files\n# or on the command line.\n\n# Configuration data is parsed as follows:\n# 1. command line options\n# 2. user-specific file\n# 3. system-wide file\n# Any configuration value is only changed the first time it is set.\n# Thus, host-specific definitions should be at the beginning of the\n# configuration file, and defaults at the end.\n\n# Site-wide defaults for some commonly used options. For a comprehensive\n# list of available options, their meanings and defaults, please see the\n# ssh_config(5) man page.\n\n# Host *\n# ForwardAgent no\n# ForwardX11 no\n# RhostsRSAAuthentication no\n# RSAAuthentication yes\n# PasswordAuthentication yes\n# HostbasedAuthentication no\n# GSSAPIAuthentication no\n# GSSAPIDelegateCredentials no\n# BatchMode no\n# CheckHostIP yes\n# AddressFamily any\n# ConnectTimeout 0\n# StrictHostKeyChecking ask\n# IdentityFile ~/.ssh/identity\n# IdentityFile ~/.ssh/id_rsa\n# IdentityFile ~/.ssh/id_dsa\n# IdentityFile ~/.ssh/id_ecdsa\n# IdentityFile ~/.ssh/id_ed25519\n# Port 22\n# Protocol 2\n# Cipher 3des\n# Ciphers aes128-ctr,aes192-ctr,aes256-ctr,arcfour256,arcfour128,aes128-cbc,3des-cbc\n# MACs hmac-md5,hmac-sha1,umac-64@openssh.com,hmac-ripemd160\n# EscapeChar ~\n# Tunnel no\n# TunnelDevice any:any\n# PermitLocalCommand no\n# VisualHostKey no\n# ProxyCommand ssh -q -W %h:%p gateway.example.com\n# RekeyLimit 1G 1h\n") \ No newline at end of file diff --git a/testdata/fuzz/FuzzDecode/4f8b378d89916e9b4fd796f74f5b12efb5cd85faaba9fea8fbe419d6af63add8 b/testdata/fuzz/FuzzDecode/4f8b378d89916e9b4fd796f74f5b12efb5cd85faaba9fea8fbe419d6af63add8 new file mode 100644 index 0000000..5f3337e --- /dev/null +++ b/testdata/fuzz/FuzzDecode/4f8b378d89916e9b4fd796f74f5b12efb5cd85faaba9fea8fbe419d6af63add8 @@ -0,0 +1,2 @@ +go test fuzz v1 +[]byte("Host localhost 127.0.0.1 # A comment at the end of a host line.\n NoHostAuthenticationForLocalhost yes\n\n# A comment\n # A comment with leading spaces.\n\nHost wap\n User root\n KexAlgorithms diffie-hellman-group1-sha1\n\nHost [some stuff behind a NAT]\n Compression yes\n ProxyCommand ssh -qW %h:%p [NATrouter]\n\nHost wopr # there are 2 proxies available for this one...\n User root\n ProxyCommand sh -c \"ssh proxy1 -qW %h:22 || ssh proxy2 -qW %h:22\"\n\nHost dhcp-??\n UserKnownHostsFile /dev/null\n StrictHostKeyChecking no\n User root\n\nHost [my boxes] [*.mydomain]\n ForwardAgent yes\n ForwardX11 yes\n ForwardX11Trusted yes\n\nHost *\n #ControlMaster auto\n #ControlPath /tmp/ssh-master-%C\n #ControlPath /tmp/ssh-%u-%r@%h:%p\n #ControlPersist yes\n ForwardX11Timeout 52w\n XAuthLocation /usr/bin/xauth\n SendEnv LANG LC_*\n HostKeyAlgorithms ssh-ed25519,ssh-rsa\n AddressFamily inet\n #UpdateHostKeys ask\n") \ No newline at end of file diff --git a/testdata/identities b/testdata/identities new file mode 100644 index 0000000..14971f7 --- /dev/null +++ b/testdata/identities @@ -0,0 +1,11 @@ + +Host hasidentity + IdentityFile file1 + +Host has2identity + IdentityFile f1 + IdentityFile f2 + +Host protocol1 + Protocol 1 + diff --git a/testdata/include b/testdata/include new file mode 100644 index 0000000..ee238dd --- /dev/null +++ b/testdata/include @@ -0,0 +1,4 @@ +Host kevinburke.ssh_config.test.example.com + # This file (or files) needs to be found in ~/.ssh or /etc/ssh, depending on + # the test. + Include kevinburke-ssh-config-*-file diff --git a/testdata/include-recursive b/testdata/include-recursive new file mode 100644 index 0000000..8a3cd3d --- /dev/null +++ b/testdata/include-recursive @@ -0,0 +1,4 @@ +Host kevinburke.ssh_config.test.example.com + # This file (or files) needs to be found in ~/.ssh or /etc/ssh, depending on + # the test. It should include itself. + Include kevinburke-ssh-config-recursive-include diff --git a/testdata/invalid-port b/testdata/invalid-port new file mode 100644 index 0000000..845b918 --- /dev/null +++ b/testdata/invalid-port @@ -0,0 +1,2 @@ +Host test.test + Port notanumber diff --git a/testdata/match-directive b/testdata/match-directive new file mode 100644 index 0000000..93bf382 --- /dev/null +++ b/testdata/match-directive @@ -0,0 +1,2 @@ +Match all + Port 4567 diff --git a/testdata/negated b/testdata/negated new file mode 100644 index 0000000..82df3c8 --- /dev/null +++ b/testdata/negated @@ -0,0 +1,5 @@ +Host *.example.com !*.dialup.example.com + Port 1234 + +Host * + Port 5678 diff --git a/testdata/system-include b/testdata/system-include new file mode 100644 index 0000000..e69de29 diff --git a/token.go b/token.go new file mode 100644 index 0000000..a0ecbb2 --- /dev/null +++ b/token.go @@ -0,0 +1,49 @@ +package ssh_config + +import "fmt" + +type token struct { + Position + typ tokenType + val string +} + +func (t token) String() string { + switch t.typ { + case tokenEOF: + return "EOF" + } + return fmt.Sprintf("%q", t.val) +} + +type tokenType int + +const ( + eof = -(iota + 1) +) + +const ( + tokenError tokenType = iota + tokenEOF + tokenEmptyLine + tokenComment + tokenKey + tokenEquals + tokenString +) + +func isSpace(r rune) bool { + return r == ' ' || r == '\t' +} + +func isKeyStartChar(r rune) bool { + return !(isSpace(r) || r == '\r' || r == '\n' || r == eof) +} + +// I'm not sure that this is correct +func isKeyChar(r rune) bool { + // Keys start with the first character that isn't whitespace or [ and end + // with the last non-whitespace character before the equals sign. Keys + // cannot contain a # character." + return !(r == '\r' || r == '\n' || r == eof || r == '=') +} diff --git a/validators.go b/validators.go new file mode 100644 index 0000000..5977f90 --- /dev/null +++ b/validators.go @@ -0,0 +1,186 @@ +package ssh_config + +import ( + "fmt" + "strconv" + "strings" +) + +// Default returns the default value for the given keyword, for example "22" if +// the keyword is "Port". Default returns the empty string if the keyword has no +// default, or if the keyword is unknown. Keyword matching is case-insensitive. +// +// Default values are provided by OpenSSH_7.4p1 on a Mac. +func Default(keyword string) string { + return defaults[strings.ToLower(keyword)] +} + +// Arguments where the value must be "yes" or "no" and *only* yes or no. +var yesnos = map[string]bool{ + strings.ToLower("BatchMode"): true, + strings.ToLower("CanonicalizeFallbackLocal"): true, + strings.ToLower("ChallengeResponseAuthentication"): true, + strings.ToLower("CheckHostIP"): true, + strings.ToLower("ClearAllForwardings"): true, + strings.ToLower("Compression"): true, + strings.ToLower("EnableSSHKeysign"): true, + strings.ToLower("ExitOnForwardFailure"): true, + strings.ToLower("ForwardAgent"): true, + strings.ToLower("ForwardX11"): true, + strings.ToLower("ForwardX11Trusted"): true, + strings.ToLower("GatewayPorts"): true, + strings.ToLower("GSSAPIAuthentication"): true, + strings.ToLower("GSSAPIDelegateCredentials"): true, + strings.ToLower("HostbasedAuthentication"): true, + strings.ToLower("IdentitiesOnly"): true, + strings.ToLower("KbdInteractiveAuthentication"): true, + strings.ToLower("NoHostAuthenticationForLocalhost"): true, + strings.ToLower("PasswordAuthentication"): true, + strings.ToLower("PermitLocalCommand"): true, + strings.ToLower("PubkeyAuthentication"): true, + strings.ToLower("RhostsRSAAuthentication"): true, + strings.ToLower("RSAAuthentication"): true, + strings.ToLower("StreamLocalBindUnlink"): true, + strings.ToLower("TCPKeepAlive"): true, + strings.ToLower("UseKeychain"): true, + strings.ToLower("UsePrivilegedPort"): true, + strings.ToLower("VisualHostKey"): true, +} + +var uints = map[string]bool{ + strings.ToLower("CanonicalizeMaxDots"): true, + strings.ToLower("CompressionLevel"): true, // 1 to 9 + strings.ToLower("ConnectionAttempts"): true, + strings.ToLower("ConnectTimeout"): true, + strings.ToLower("NumberOfPasswordPrompts"): true, + strings.ToLower("Port"): true, + strings.ToLower("ServerAliveCountMax"): true, + strings.ToLower("ServerAliveInterval"): true, +} + +func mustBeYesOrNo(lkey string) bool { + return yesnos[lkey] +} + +func mustBeUint(lkey string) bool { + return uints[lkey] +} + +func validate(key, val string) error { + lkey := strings.ToLower(key) + if mustBeYesOrNo(lkey) && (val != "yes" && val != "no") { + return fmt.Errorf("ssh_config: value for key %q must be 'yes' or 'no', got %q", key, val) + } + if mustBeUint(lkey) { + _, err := strconv.ParseUint(val, 10, 64) + if err != nil { + return fmt.Errorf("ssh_config: %v", err) + } + } + return nil +} + +var defaults = map[string]string{ + strings.ToLower("AddKeysToAgent"): "no", + strings.ToLower("AddressFamily"): "any", + strings.ToLower("BatchMode"): "no", + strings.ToLower("CanonicalizeFallbackLocal"): "yes", + strings.ToLower("CanonicalizeHostname"): "no", + strings.ToLower("CanonicalizeMaxDots"): "1", + strings.ToLower("ChallengeResponseAuthentication"): "yes", + strings.ToLower("CheckHostIP"): "yes", + // TODO is this still the correct cipher + strings.ToLower("Cipher"): "3des", + strings.ToLower("Ciphers"): "chacha20-poly1305@openssh.com,aes128-ctr,aes192-ctr,aes256-ctr,aes128-gcm@openssh.com,aes256-gcm@openssh.com,aes128-cbc,aes192-cbc,aes256-cbc", + strings.ToLower("ClearAllForwardings"): "no", + strings.ToLower("Compression"): "no", + strings.ToLower("CompressionLevel"): "6", + strings.ToLower("ConnectionAttempts"): "1", + strings.ToLower("ControlMaster"): "no", + strings.ToLower("EnableSSHKeysign"): "no", + strings.ToLower("EscapeChar"): "~", + strings.ToLower("ExitOnForwardFailure"): "no", + strings.ToLower("FingerprintHash"): "sha256", + strings.ToLower("ForwardAgent"): "no", + strings.ToLower("ForwardX11"): "no", + strings.ToLower("ForwardX11Timeout"): "20m", + strings.ToLower("ForwardX11Trusted"): "no", + strings.ToLower("GatewayPorts"): "no", + strings.ToLower("GlobalKnownHostsFile"): "/etc/ssh/ssh_known_hosts /etc/ssh/ssh_known_hosts2", + strings.ToLower("GSSAPIAuthentication"): "no", + strings.ToLower("GSSAPIDelegateCredentials"): "no", + strings.ToLower("HashKnownHosts"): "no", + strings.ToLower("HostbasedAuthentication"): "no", + + strings.ToLower("HostbasedKeyTypes"): "ecdsa-sha2-nistp256-cert-v01@openssh.com,ecdsa-sha2-nistp384-cert-v01@openssh.com,ecdsa-sha2-nistp521-cert-v01@openssh.com,ssh-ed25519-cert-v01@openssh.com,ssh-rsa-cert-v01@openssh.com,ecdsa-sha2-nistp256,ecdsa-sha2-nistp384,ecdsa-sha2-nistp521,ssh-ed25519,ssh-rsa", + strings.ToLower("HostKeyAlgorithms"): "ecdsa-sha2-nistp256-cert-v01@openssh.com,ecdsa-sha2-nistp384-cert-v01@openssh.com,ecdsa-sha2-nistp521-cert-v01@openssh.com,ssh-ed25519-cert-v01@openssh.com,ssh-rsa-cert-v01@openssh.com,ecdsa-sha2-nistp256,ecdsa-sha2-nistp384,ecdsa-sha2-nistp521,ssh-ed25519,ssh-rsa", + // HostName has a dynamic default (the value passed at the command line). + + strings.ToLower("IdentitiesOnly"): "no", + strings.ToLower("IdentityFile"): "~/.ssh/identity", + + // IPQoS has a dynamic default based on interactive or non-interactive + // sessions. + + strings.ToLower("KbdInteractiveAuthentication"): "yes", + + strings.ToLower("KexAlgorithms"): "curve25519-sha256,curve25519-sha256@libssh.org,ecdh-sha2-nistp256,ecdh-sha2-nistp384,ecdh-sha2-nistp521,diffie-hellman-group-exchange-sha256,diffie-hellman-group-exchange-sha1,diffie-hellman-group14-sha1", + strings.ToLower("LogLevel"): "INFO", + strings.ToLower("MACs"): "umac-64-etm@openssh.com,umac-128-etm@openssh.com,hmac-sha2-256-etm@openssh.com,hmac-sha2-512-etm@openssh.com,hmac-sha1-etm@openssh.com,umac-64@openssh.com,umac-128@openssh.com,hmac-sha2-256,hmac-sha2-512,hmac-sha1", + + strings.ToLower("NoHostAuthenticationForLocalhost"): "no", + strings.ToLower("NumberOfPasswordPrompts"): "3", + strings.ToLower("PasswordAuthentication"): "yes", + strings.ToLower("PermitLocalCommand"): "no", + strings.ToLower("Port"): "22", + + strings.ToLower("PreferredAuthentications"): "gssapi-with-mic,hostbased,publickey,keyboard-interactive,password", + strings.ToLower("Protocol"): "2", + strings.ToLower("ProxyUseFdpass"): "no", + strings.ToLower("PubkeyAcceptedKeyTypes"): "ecdsa-sha2-nistp256-cert-v01@openssh.com,ecdsa-sha2-nistp384-cert-v01@openssh.com,ecdsa-sha2-nistp521-cert-v01@openssh.com,ssh-ed25519-cert-v01@openssh.com,ssh-rsa-cert-v01@openssh.com,ecdsa-sha2-nistp256,ecdsa-sha2-nistp384,ecdsa-sha2-nistp521,ssh-ed25519,ssh-rsa", + strings.ToLower("PubkeyAuthentication"): "yes", + strings.ToLower("RekeyLimit"): "default none", + strings.ToLower("RhostsRSAAuthentication"): "no", + strings.ToLower("RSAAuthentication"): "yes", + + strings.ToLower("ServerAliveCountMax"): "3", + strings.ToLower("ServerAliveInterval"): "0", + strings.ToLower("StreamLocalBindMask"): "0177", + strings.ToLower("StreamLocalBindUnlink"): "no", + strings.ToLower("StrictHostKeyChecking"): "ask", + strings.ToLower("TCPKeepAlive"): "yes", + strings.ToLower("Tunnel"): "no", + strings.ToLower("TunnelDevice"): "any:any", + strings.ToLower("UpdateHostKeys"): "no", + strings.ToLower("UseKeychain"): "no", + strings.ToLower("UsePrivilegedPort"): "no", + + strings.ToLower("UserKnownHostsFile"): "~/.ssh/known_hosts ~/.ssh/known_hosts2", + strings.ToLower("VerifyHostKeyDNS"): "no", + strings.ToLower("VisualHostKey"): "no", + strings.ToLower("XAuthLocation"): "/usr/X11R6/bin/xauth", +} + +// these identities are used for SSH protocol 2 +var defaultProtocol2Identities = []string{ + "~/.ssh/id_dsa", + "~/.ssh/id_ecdsa", + "~/.ssh/id_ed25519", + "~/.ssh/id_rsa", +} + +// these directives support multiple items that can be collected +// across multiple files +var pluralDirectives = map[string]bool{ + "CertificateFile": true, + "IdentityFile": true, + "DynamicForward": true, + "RemoteForward": true, + "SendEnv": true, + "SetEnv": true, +} + +// SupportsMultiple reports whether a directive can be specified multiple times. +func SupportsMultiple(key string) bool { + return pluralDirectives[strings.ToLower(key)] +} diff --git a/validators_test.go b/validators_test.go new file mode 100644 index 0000000..ac34554 --- /dev/null +++ b/validators_test.go @@ -0,0 +1,44 @@ +package ssh_config + +import ( + "testing" +) + +var validateTests = []struct { + key string + val string + err string +}{ + {"IdentitiesOnly", "yes", ""}, + {"IdentitiesOnly", "Yes", `ssh_config: value for key "IdentitiesOnly" must be 'yes' or 'no', got "Yes"`}, + {"Port", "22", ``}, + {"Port", "yes", `ssh_config: strconv.ParseUint: parsing "yes": invalid syntax`}, +} + +func TestValidate(t *testing.T) { + for _, tt := range validateTests { + err := validate(tt.key, tt.val) + if tt.err == "" && err != nil { + t.Errorf("validate(%q, %q): got %v, want nil", tt.key, tt.val, err) + } + if tt.err != "" { + if err == nil { + t.Errorf("validate(%q, %q): got nil error, want %v", tt.key, tt.val, tt.err) + } else if err.Error() != tt.err { + t.Errorf("validate(%q, %q): got err %v, want %v", tt.key, tt.val, err, tt.err) + } + } + } +} + +func TestDefault(t *testing.T) { + if v := Default("VisualHostKey"); v != "no" { + t.Errorf("Default(%q): got %v, want 'no'", "VisualHostKey", v) + } + if v := Default("visualhostkey"); v != "no" { + t.Errorf("Default(%q): got %v, want 'no'", "visualhostkey", v) + } + if v := Default("notfound"); v != "" { + t.Errorf("Default(%q): got %v, want ''", "notfound", v) + } +}