diff --git a/.forgejo/workflows/test.yml b/.forgejo/workflows/test.yml new file mode 100644 index 0000000..0424bf4 --- /dev/null +++ b/.forgejo/workflows/test.yml @@ -0,0 +1,45 @@ +name: test +on: + - push + - pull_request + +jobs: + test: + runs-on: docker-global-bookworm + container: + image: 'code.forgejo.org/oci/node:20-bookworm' + services: + redis: + image: registry.redict.io/redict:7 + pgsql: + image: 'code.forgejo.org/oci/postgres:16' + env: + POSTGRES_DB: testsession + POSTGRES_PASSWORD: postgres + mysql: + image: mariadb:11 + env: + MARIADB_ALLOW_EMPTY_ROOT_PASSWORD: yes + MARIADB_DATABASE: testsession + # + # See also https://codeberg.org/forgejo/forgejo/issues/976 + # + MARIADB_EXTRA_FLAGS: --innodb-adaptive-flushing=OFF --innodb-buffer-pool-size=4G --innodb-log-buffer-size=128M --innodb-flush-log-at-trx-commit=0 --innodb-flush-log-at-timeout=30 --innodb-flush-method=nosync --innodb-fsync-threshold=1000000000 + memcached: + image: memcached:1.6-alpine + steps: + - uses: https://code.forgejo.org/actions/checkout@v4 + - uses: https://code.forgejo.org/actions/setup-go@v5 + with: + go-version-file: "go.mod" + - name: lint + run: go run github.com/golangci/golangci-lint/cmd/golangci-lint@v1.60.3 run + - name: create database tables + run: | + apt-get update -qq + apt-get install -y -qq postgresql-client mariadb-client + PGPASSWORD=postgres psql -h pgsql -U postgres -c 'CREATE TABLE IF NOT EXISTS session (key TEXT PRIMARY KEY, data BYTEA, expiry BIGINT)' testsession + mariadb -h mysql -u root -D testsession -e 'CREATE TABLE IF NOT EXISTS session (`key` varchar(255) PRIMARY KEY, data BLOB, expiry BIGINT)' + + - name: test + run: go test -v ./... diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..5049f84 --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +/vendor +/.idea diff --git a/.golangci.yaml b/.golangci.yaml new file mode 100644 index 0000000..25287ce --- /dev/null +++ b/.golangci.yaml @@ -0,0 +1,27 @@ +linters: + enable-all: false + disable-all: true + fast: false + enable: + - bidichk + - dupl + - errcheck + - forbidigo + - gocritic + - gofmt + - gofumpt + - gosimple + - govet + - ineffassign + - nakedret + - nolintlint + - revive + - staticcheck + - stylecheck + - tenv + - testifylint + - typecheck + - unconvert + - unused + - unparam + - wastedassign diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..8405e89 --- /dev/null +++ b/LICENSE @@ -0,0 +1,191 @@ +Apache License +Version 2.0, January 2004 +http://www.apache.org/licenses/ + +TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + +1. Definitions. + +"License" shall mean the terms and conditions for use, reproduction, and +distribution as defined by Sections 1 through 9 of this document. + +"Licensor" shall mean the copyright owner or entity authorized by the copyright +owner that is granting the License. + +"Legal Entity" shall mean the union of the acting entity and all other entities +that control, are controlled by, or are under common control with that entity. +For the purposes of this definition, "control" means (i) the power, direct or +indirect, to cause the direction or management of such entity, whether by +contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the +outstanding shares, or (iii) beneficial ownership of such entity. + +"You" (or "Your") shall mean an individual or Legal Entity exercising +permissions granted by this License. + +"Source" form shall mean the preferred form for making modifications, including +but not limited to software source code, documentation source, and configuration +files. + +"Object" form shall mean any form resulting from mechanical transformation or +translation of a Source form, including but not limited to compiled object code, +generated documentation, and conversions to other media types. + +"Work" shall mean the work of authorship, whether in Source or Object form, made +available under the License, as indicated by a copyright notice that is included +in or attached to the work (an example is provided in the Appendix below). + +"Derivative Works" shall mean any work, whether in Source or Object form, that +is based on (or derived from) the Work and for which the editorial revisions, +annotations, elaborations, or other modifications represent, as a whole, an +original work of authorship. For the purposes of this License, Derivative Works +shall not include works that remain separable from, or merely link (or bind by +name) to the interfaces of, the Work and Derivative Works thereof. + +"Contribution" shall mean any work of authorship, including the original version +of the Work and any modifications or additions to that Work or Derivative Works +thereof, that is intentionally submitted to Licensor for inclusion in the Work +by the copyright owner or by an individual or Legal Entity authorized to submit +on behalf of the copyright owner. For the purposes of this definition, +"submitted" means any form of electronic, verbal, or written communication sent +to the Licensor or its representatives, including but not limited to +communication on electronic mailing lists, source code control systems, and +issue tracking systems that are managed by, or on behalf of, the Licensor for +the purpose of discussing and improving the Work, but excluding communication +that is conspicuously marked or otherwise designated in writing by the copyright +owner as "Not a Contribution." + +"Contributor" shall mean Licensor and any individual or Legal Entity on behalf +of whom a Contribution has been received by Licensor and subsequently +incorporated within the Work. + +2. Grant of Copyright License. + +Subject to the terms and conditions of this License, each Contributor hereby +grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, +irrevocable copyright license to reproduce, prepare Derivative Works of, +publicly display, publicly perform, sublicense, and distribute the Work and such +Derivative Works in Source or Object form. + +3. Grant of Patent License. + +Subject to the terms and conditions of this License, each Contributor hereby +grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, +irrevocable (except as stated in this section) patent license to make, have +made, use, offer to sell, sell, import, and otherwise transfer the Work, where +such license applies only to those patent claims licensable by such Contributor +that are necessarily infringed by their Contribution(s) alone or by combination +of their Contribution(s) with the Work to which such Contribution(s) was +submitted. If You institute patent litigation against any entity (including a +cross-claim or counterclaim in a lawsuit) alleging that the Work or a +Contribution incorporated within the Work constitutes direct or contributory +patent infringement, then any patent licenses granted to You under this License +for that Work shall terminate as of the date such litigation is filed. + +4. Redistribution. + +You may reproduce and distribute copies of the Work or Derivative Works thereof +in any medium, with or without modifications, and in Source or Object form, +provided that You meet the following conditions: + +You must give any other recipients of the Work or Derivative Works a copy of +this License; and +You must cause any modified files to carry prominent notices stating that You +changed the files; and +You must retain, in the Source form of any Derivative Works that You distribute, +all copyright, patent, trademark, and attribution notices from the Source form +of the Work, excluding those notices that do not pertain to any part of the +Derivative Works; and +If the Work includes a "NOTICE" text file as part of its distribution, then any +Derivative Works that You distribute must include a readable copy of the +attribution notices contained within such NOTICE file, excluding those notices +that do not pertain to any part of the Derivative Works, in at least one of the +following places: within a NOTICE text file distributed as part of the +Derivative Works; within the Source form or documentation, if provided along +with the Derivative Works; or, within a display generated by the Derivative +Works, if and wherever such third-party notices normally appear. The contents of +the NOTICE file are for informational purposes only and do not modify the +License. You may add Your own attribution notices within Derivative Works that +You distribute, alongside or as an addendum to the NOTICE text from the Work, +provided that such additional attribution notices cannot be construed as +modifying the License. +You may add Your own copyright statement to Your modifications and may provide +additional or different license terms and conditions for use, reproduction, or +distribution of Your modifications, or for any such Derivative Works as a whole, +provided Your use, reproduction, and distribution of the Work otherwise complies +with the conditions stated in this License. + +5. Submission of Contributions. + +Unless You explicitly state otherwise, any Contribution intentionally submitted +for inclusion in the Work by You to the Licensor shall be under the terms and +conditions of this License, without any additional terms or conditions. +Notwithstanding the above, nothing herein shall supersede or modify the terms of +any separate license agreement you may have executed with Licensor regarding +such Contributions. + +6. Trademarks. + +This License does not grant permission to use the trade names, trademarks, +service marks, or product names of the Licensor, except as required for +reasonable and customary use in describing the origin of the Work and +reproducing the content of the NOTICE file. + +7. Disclaimer of Warranty. + +Unless required by applicable law or agreed to in writing, Licensor provides the +Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, +including, without limitation, any warranties or conditions of TITLE, +NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are +solely responsible for determining the appropriateness of using or +redistributing the Work and assume any risks associated with Your exercise of +permissions under this License. + +8. Limitation of Liability. + +In no event and under no legal theory, whether in tort (including negligence), +contract, or otherwise, unless required by applicable law (such as deliberate +and grossly negligent acts) or agreed to in writing, shall any Contributor be +liable to You for damages, including any direct, indirect, special, incidental, +or consequential damages of any character arising as a result of this License or +out of the use or inability to use the Work (including but not limited to +damages for loss of goodwill, work stoppage, computer failure or malfunction, or +any and all other commercial damages or losses), even if such Contributor has +been advised of the possibility of such damages. + +9. Accepting Warranty or Additional Liability. + +While redistributing the Work or Derivative Works thereof, You may choose to +offer, and charge a fee for, acceptance of support, warranty, indemnity, or +other liability obligations and/or rights consistent with this License. However, +in accepting such obligations, You may act only on Your own behalf and on Your +sole responsibility, not on behalf of any other Contributor, and only if You +agree to indemnify, defend, and hold each Contributor harmless for any liability +incurred by, or claims asserted against, such Contributor by reason of your +accepting any such warranty or additional liability. + +END OF TERMS AND CONDITIONS + +APPENDIX: How to apply the Apache License to your work + +To apply the Apache License to your work, attach the following boilerplate +notice, with the fields enclosed by brackets "[]" replaced with your own +identifying information. (Don't include the brackets!) The text should be +enclosed in the appropriate comment syntax for the file format. We also +recommend that a file or class name and description of purpose be included on +the same "printed page" as the copyright notice for easier identification within +third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. \ No newline at end of file diff --git a/README.md b/README.md new file mode 100644 index 0000000..8a560d9 --- /dev/null +++ b/README.md @@ -0,0 +1,17 @@ +# Session + +Middleware session provides session management which based on a [fork](https://gitea.com/go-chi/session) of a [fork](https://gitea.com/macaron/session) of [Macaron Session](https://github.com/go-macaron/session) for [go-chi](https://github.com/go-chi/chi). It can use many session providers, including memory, file, Redis, Memcache, PostgreSQL and MySQL. + +## Installation + +``` +go get code.forgejo.org/go-chi/session +``` + +## Credits + +This package is a modified version of [go-macaron/session](https://github.com/go-macaron/session). + +## License + +This project is under the Apache License, Version 2.0. See the [LICENSE](LICENSE) file for the full license text. diff --git a/file.go b/file.go new file mode 100644 index 0000000..000a47f --- /dev/null +++ b/file.go @@ -0,0 +1,276 @@ +// Copyright 2013 Beego Authors +// Copyright 2014 The Macaron Authors +// Copyright 2024 The Forgejo Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"): you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +// License for the specific language governing permissions and limitations +// under the License. + +package session + +import ( + "encoding/gob" + "fmt" + "io" + "log" + "os" + "path" + "path/filepath" + "sync" + "time" +) + +// FileStore represents a file session store implementation. +type FileStore struct { + p *FileProvider + sid string + lock sync.RWMutex + data map[interface{}]interface{} +} + +// NewFileStore creates and returns a file session store. +func NewFileStore(p *FileProvider, sid string, kv map[interface{}]interface{}) *FileStore { + return &FileStore{ + p: p, + sid: sid, + data: kv, + } +} + +// Set sets value to given key in session. +func (s *FileStore) Set(key, val interface{}) error { + s.lock.Lock() + defer s.lock.Unlock() + + s.data[key] = val + return nil +} + +// Get gets value by given key in session. +func (s *FileStore) Get(key interface{}) interface{} { + s.lock.RLock() + defer s.lock.RUnlock() + + return s.data[key] +} + +// Delete delete a key from session. +func (s *FileStore) Delete(key interface{}) error { + s.lock.Lock() + defer s.lock.Unlock() + + delete(s.data, key) + return nil +} + +// ID returns current session ID. +func (s *FileStore) ID() string { + return s.sid +} + +// Release releases resource and save data to provider. +func (s *FileStore) Release() error { + s.p.lock.Lock() + defer s.p.lock.Unlock() + + // Skip encoding if the data is empty + if len(s.data) == 0 { + return nil + } + + data, err := EncodeGob(s.data) + if err != nil { + return err + } + + return os.WriteFile(s.p.filepath(s.sid), data, 0o600) +} + +// Flush deletes all session data. +func (s *FileStore) Flush() error { + s.lock.Lock() + defer s.lock.Unlock() + + s.data = make(map[interface{}]interface{}) + return nil +} + +// FileProvider represents a file session provider implementation. +type FileProvider struct { + lock sync.RWMutex + maxlifetime int64 + rootPath string +} + +// Init initializes file session provider with given root path. +func (p *FileProvider) Init(maxlifetime int64, rootPath string) error { + p.lock.Lock() + p.maxlifetime = maxlifetime + p.rootPath = rootPath + p.lock.Unlock() + return nil +} + +func (p *FileProvider) filepath(sid string) string { + return path.Join(p.rootPath, string(sid[0]), string(sid[1]), sid) +} + +// Read returns raw session store by session ID. +func (p *FileProvider) Read(sid string) (_ RawStore, err error) { + filename := p.filepath(sid) + if err = os.MkdirAll(path.Dir(filename), 0o700); err != nil { + return nil, err + } + p.lock.RLock() + defer p.lock.RUnlock() + + var f *os.File + fStat, err := os.Stat(filename) + if err != nil || (fStat.ModTime().Unix()+p.maxlifetime) < time.Now().Unix() { + if err != nil && !os.IsNotExist(err) { + return nil, err + } + f, err = os.Create(filename) + } else { + f, err = os.OpenFile(filename, os.O_RDONLY, 0o600) + } + if err != nil { + return nil, err + } + defer f.Close() + + if err = os.Chtimes(filename, time.Now(), time.Now()); err != nil { + return nil, err + } + + var kv map[any]any + err = gob.NewDecoder(f).Decode(&kv) + if err != nil { + if err != io.EOF { + return nil, err + } + // the session file has been truncated and is now invalid - therefore all session data is lost + kv = make(map[any]any) + + } + return NewFileStore(p, sid, kv), nil +} + +// Exist returns true if session with given ID exists. +func (p *FileProvider) Exist(sid string) bool { + p.lock.RLock() + defer p.lock.RUnlock() + _, err := os.Stat(p.filepath(sid)) + return err == nil || os.IsExist(err) +} + +// Destroy deletes a session by session ID. +func (p *FileProvider) Destroy(sid string) error { + p.lock.Lock() + defer p.lock.Unlock() + return os.Remove(p.filepath(sid)) +} + +func (p *FileProvider) regenerate(oldsid, sid string) (err error) { + p.lock.Lock() + defer p.lock.Unlock() + + filename := p.filepath(sid) + _, err = os.Stat(p.filepath(filename)) + if err == nil || os.IsExist(err) { + return fmt.Errorf("new sid '%s' already exists", sid) + } + + oldname := p.filepath(oldsid) + fStat, err := os.Stat(oldname) + if err != nil || fStat.IsDir() { + data, err := EncodeGob(make(map[interface{}]interface{})) + if err != nil { + return err + } + if err = os.MkdirAll(path.Dir(oldname), 0o700); err != nil { + return err + } + if err = os.WriteFile(oldname, data, 0o600); err != nil { + return err + } + } + + if err = os.MkdirAll(path.Dir(filename), 0o700); err != nil { + return err + } + if err = os.Rename(oldname, filename); err != nil { + return err + } + return nil +} + +// Regenerate regenerates a session store from old session ID to new one. +func (p *FileProvider) Regenerate(oldsid, sid string) (_ RawStore, err error) { + if err := p.regenerate(oldsid, sid); err != nil { + return nil, err + } + + return p.Read(sid) +} + +// Count counts and returns number of sessions. +func (p *FileProvider) Count() int { + count := 0 + if err := filepath.WalkDir(p.rootPath, func(_ string, d os.DirEntry, err error) error { + if err != nil { + return err + } + + if !d.IsDir() { + count++ + } + return nil + }); err != nil { + log.Printf("error counting session files: %v", err) + return 0 + } + return count +} + +// GC calls GC to clean expired sessions. +func (p *FileProvider) GC() { + p.lock.RLock() + defer p.lock.RUnlock() + + _, err := os.Stat(p.rootPath) + if err != nil && os.IsNotExist(err) { + return + } + + if err := filepath.WalkDir(p.rootPath, func(path string, d os.DirEntry, err error) error { + if err != nil { + return err + } + + info, err := d.Info() + if err != nil { + return err + } + + if !d.IsDir() && + (info.ModTime().Unix()+p.maxlifetime) < time.Now().Unix() { + return os.Remove(path) + } + return nil + }); err != nil { + log.Printf("error garbage collecting session files: %v", err) + } +} + +func init() { + Register("file", &FileProvider{}) +} diff --git a/file_test.go b/file_test.go new file mode 100644 index 0000000..7105cb0 --- /dev/null +++ b/file_test.go @@ -0,0 +1,24 @@ +// Copyright 2014 The Macaron Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"): you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +// License for the specific language governing permissions and limitations +// under the License. + +package session + +import "testing" + +func Test_FileProvider(t *testing.T) { + testProvider(t, Options{ + Provider: "file", + ProviderConfig: t.TempDir(), + }) +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..895f29d --- /dev/null +++ b/go.mod @@ -0,0 +1,23 @@ +module code.forgejo.org/go-chi/session + +go 1.23 + +toolchain go1.23.4 + +require ( + github.com/bradfitz/gomemcache v0.0.0-20230905024940-24af94b03874 + github.com/go-chi/chi/v5 v5.1.0 + github.com/go-sql-driver/mysql v1.8.1 + github.com/lib/pq v1.10.9 + github.com/redis/go-redis/v9 v9.7.0 + github.com/stretchr/testify v1.10.0 +) + +require ( + filippo.io/edwards25519 v1.1.0 // indirect + github.com/cespare/xxhash/v2 v2.2.0 // indirect + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..2dc7cf6 --- /dev/null +++ b/go.sum @@ -0,0 +1,30 @@ +filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA= +filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4= +github.com/bradfitz/gomemcache v0.0.0-20230905024940-24af94b03874 h1:N7oVaKyGp8bttX0bfZGmcGkjz7DLQXhAn3DNd3T0ous= +github.com/bradfitz/gomemcache v0.0.0-20230905024940-24af94b03874/go.mod h1:r5xuitiExdLAJ09PR7vBVENGvp4ZuTBeWTGtxuX3K+c= +github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs= +github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c= +github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA= +github.com/bsm/gomega v1.27.10/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0H+O0= +github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44= +github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= +github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= +github.com/go-chi/chi/v5 v5.1.0 h1:acVI1TYaD+hhedDJ3r54HyA6sExp3HfXq7QWEEY/xMw= +github.com/go-chi/chi/v5 v5.1.0/go.mod h1:DslCQbL2OYiznFReuXYUmQ2hGd1aDpCnlMNITLSKoi8= +github.com/go-sql-driver/mysql v1.8.1 h1:LedoTUt/eveggdHS9qUFC1EFSa8bU2+1pZjSRpvNJ1Y= +github.com/go-sql-driver/mysql v1.8.1/go.mod h1:wEBSXgmK//2ZFJyE+qWnIsVGmvmEKlqwuVSjsCm7DZg= +github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= +github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/redis/go-redis/v9 v9.7.0 h1:HhLSs+B6O021gwzl+locl0zEDnyNkxMtf/Z3NNBMa9E= +github.com/redis/go-redis/v9 v9.7.0/go.mod h1:f6zhXITC7JUJIlPEiBOTXxJgPLdZcA93GewI7inzyWw= +github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= +github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/internal/test/provider.go b/internal/test/provider.go new file mode 100644 index 0000000..c22602c --- /dev/null +++ b/internal/test/provider.go @@ -0,0 +1,149 @@ +package test + +import ( + "net/http" + "net/http/httptest" + "testing" + "time" + + "code.forgejo.org/go-chi/session" + "github.com/go-chi/chi/v5" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Provider(t *testing.T, opt session.Options) { + t.Run("Basic operation", func(t *testing.T) { + c := chi.NewRouter() + c.Use(session.Sessioner(opt)) + var initialSid string + + c.Get("/", func(_ http.ResponseWriter, req *http.Request) { + sess := session.GetSession(req) + assert.NoError(t, sess.Set("uname", "unknwon")) + initialSid = sess.ID() + }) + c.Get("/reg", func(resp http.ResponseWriter, req *http.Request) { + sess := session.GetSession(req) + assert.EqualValues(t, initialSid, sess.ID()) + raw, err := session.RegenerateSession(resp, req) + assert.NoError(t, err) + assert.NotNil(t, sess) + assert.EqualValues(t, sess, raw) + + assert.NotEqualValues(t, initialSid, sess.ID()) + + uname := sess.Get("uname") + assert.NotNil(t, uname) + assert.EqualValues(t, "unknwon", uname) + + assert.NoError(t, sess.Set("uname", "lunny")) + uname = sess.Get("uname") + assert.NotNil(t, uname) + assert.EqualValues(t, "lunny", uname) + }) + c.Get("/get", func(resp http.ResponseWriter, req *http.Request) { + sess := session.GetSession(req) + sid := sess.ID() + assert.NotEmpty(t, sid) + + raw, err := sess.Read(sid) + assert.NoError(t, err) + assert.NotNil(t, raw) + + uname := sess.Get("uname") + assert.NotNil(t, uname) + assert.EqualValues(t, "lunny", uname) + + assert.NoError(t, sess.Delete("uname")) + assert.Nil(t, sess.Get("uname")) + + assert.NoError(t, sess.Destroy(resp, req)) + }) + + resp := httptest.NewRecorder() + req, err := http.NewRequest("GET", "/", nil) + require.NoError(t, err) + c.ServeHTTP(resp, req) + + cookie := resp.Header().Get("Set-Cookie") + + resp = httptest.NewRecorder() + req, err = http.NewRequest("GET", "/reg", nil) + require.NoError(t, err) + req.Header.Set("Cookie", cookie) + c.ServeHTTP(resp, req) + + cookie = resp.Header().Get("Set-Cookie") + + resp = httptest.NewRecorder() + req, err = http.NewRequest("GET", "/get", nil) + require.NoError(t, err) + req.Header.Set("Cookie", cookie) + c.ServeHTTP(resp, req) + }) + + t.Run("Regenerate empty session", func(t *testing.T) { + c := chi.NewRouter() + c.Use(session.Sessioner(opt)) + c.Get("/", func(resp http.ResponseWriter, req *http.Request) { + sess := session.GetSession(req) + raw, err := sess.RegenerateID(resp, req) + assert.NoError(t, err) + assert.NotNil(t, raw) + }) + + resp := httptest.NewRecorder() + req, err := http.NewRequest("GET", "/", nil) + require.NoError(t, err) + req.Header.Set("Cookie", "MacaronSession=ad2c7e3cbecfcf48; Path=/;") + c.ServeHTTP(resp, req) + }) + + t.Run("GC session", func(t *testing.T) { + if opt.Provider == "redis" || opt.Provider == "memcache" { + t.Skip("Doesn't implement GC") + } + + c := chi.NewRouter() + opt2 := opt + opt2.Gclifetime = 1 + c.Use(session.Sessioner(opt2)) + + c.Get("/", func(_ http.ResponseWriter, req *http.Request) { + sess := session.GetSession(req) + assert.NoError(t, sess.Set("uname", "unknwon")) + assert.NotEmpty(t, sess.ID()) + uname := sess.Get("uname") + assert.NotNil(t, uname) + assert.EqualValues(t, "unknwon", uname) + + assert.NoError(t, sess.Flush()) + assert.Nil(t, sess.Get("uname")) + + time.Sleep(2 * time.Second) + sess.GC() + assert.Zero(t, sess.Count()) + }) + + resp := httptest.NewRecorder() + req, err := http.NewRequest("GET", "/", nil) + require.NoError(t, err) + c.ServeHTTP(resp, req) + }) + t.Run("Detect invalid sid", func(t *testing.T) { + c := chi.NewRouter() + c.Use(session.Sessioner(opt)) + c.Get("/", func(_ http.ResponseWriter, req *http.Request) { + sess := session.GetSession(req) + raw, err := sess.Read("../session/ad2c7e3cbecfcf486") + assert.Contains(t, err.Error(), "invalid 'sid'") + assert.Nil(t, raw) + }) + + resp := httptest.NewRecorder() + req, err := http.NewRequest("GET", "/", nil) + require.NoError(t, err) + c.ServeHTTP(resp, req) + }) +} diff --git a/memcache/memcache.go b/memcache/memcache.go new file mode 100644 index 0000000..c940905 --- /dev/null +++ b/memcache/memcache.go @@ -0,0 +1,203 @@ +// Copyright 2013 Beego Authors +// Copyright 2014 The Macaron Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"): you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +// License for the specific language governing permissions and limitations +// under the License. + +package session + +import ( + "fmt" + "strings" + "sync" + + "code.forgejo.org/go-chi/session" + "github.com/bradfitz/gomemcache/memcache" +) + +// MemcacheStore represents a memcache session store implementation. +type MemcacheStore struct { + c *memcache.Client + sid string + expire int32 + lock sync.RWMutex + data map[interface{}]interface{} +} + +// NewMemcacheStore creates and returns a memcache session store. +func NewMemcacheStore(c *memcache.Client, sid string, expire int32, kv map[interface{}]interface{}) *MemcacheStore { + return &MemcacheStore{ + c: c, + sid: sid, + expire: expire, + data: kv, + } +} + +func NewItem(sid string, data []byte, expire int32) *memcache.Item { + return &memcache.Item{ + Key: sid, + Value: data, + Expiration: expire, + } +} + +// Set sets value to given key in session. +func (s *MemcacheStore) Set(key, val interface{}) error { + s.lock.Lock() + defer s.lock.Unlock() + + s.data[key] = val + return nil +} + +// Get gets value by given key in session. +func (s *MemcacheStore) Get(key interface{}) interface{} { + s.lock.RLock() + defer s.lock.RUnlock() + + return s.data[key] +} + +// Delete delete a key from session. +func (s *MemcacheStore) Delete(key interface{}) error { + s.lock.Lock() + defer s.lock.Unlock() + + delete(s.data, key) + return nil +} + +// ID returns current session ID. +func (s *MemcacheStore) ID() string { + return s.sid +} + +// Release releases resource and save data to provider. +func (s *MemcacheStore) Release() error { + // Skip encoding if the data is empty + if len(s.data) == 0 { + return nil + } + + data, err := session.EncodeGob(s.data) + if err != nil { + return err + } + + return s.c.Set(NewItem(s.sid, data, s.expire)) +} + +// Flush deletes all session data. +func (s *MemcacheStore) Flush() error { + s.lock.Lock() + defer s.lock.Unlock() + + s.data = make(map[interface{}]interface{}) + return nil +} + +// MemcacheProvider represents a memcache session provider implementation. +type MemcacheProvider struct { + c *memcache.Client + expire int32 +} + +// Init initializes memcache session provider. +// connStrs: 127.0.0.1:9090;127.0.0.1:9091 +func (p *MemcacheProvider) Init(expire int64, connStrs string) error { + p.expire = int32(expire) + p.c = memcache.New(strings.Split(connStrs, ";")...) + return nil +} + +// Read returns raw session store by session ID. +func (p *MemcacheProvider) Read(sid string) (session.RawStore, error) { + if !p.Exist(sid) { + if err := p.c.Set(NewItem(sid, []byte(""), p.expire)); err != nil { + return nil, err + } + } + + var kv map[interface{}]interface{} + item, err := p.c.Get(sid) + if err != nil { + return nil, err + } + if len(item.Value) == 0 { + kv = make(map[interface{}]interface{}) + } else { + kv, err = session.DecodeGob(item.Value) + if err != nil { + return nil, err + } + } + + return NewMemcacheStore(p.c, sid, p.expire, kv), nil +} + +// Exist returns true if session with given ID exists. +func (p *MemcacheProvider) Exist(sid string) bool { + _, err := p.c.Get(sid) + return err == nil +} + +// Destroy deletes a session by session ID. +func (p *MemcacheProvider) Destroy(sid string) error { + return p.c.Delete(sid) +} + +// Regenerate regenerates a session store from old session ID to new one. +func (p *MemcacheProvider) Regenerate(oldsid, sid string) (_ session.RawStore, err error) { + if p.Exist(sid) { + return nil, fmt.Errorf("new sid '%s' already exists", sid) + } + + item := NewItem(sid, []byte(""), p.expire) + if p.Exist(oldsid) { + item, err = p.c.Get(oldsid) + if err != nil { + return nil, err + } else if err = p.c.Delete(oldsid); err != nil { + return nil, err + } + item.Key = sid + } + if err = p.c.Set(item); err != nil { + return nil, err + } + + var kv map[interface{}]interface{} + if len(item.Value) == 0 { + kv = make(map[interface{}]interface{}) + } else { + kv, err = session.DecodeGob(item.Value) + if err != nil { + return nil, err + } + } + + return NewMemcacheStore(p.c, sid, p.expire, kv), nil +} + +// Count counts and returns number of sessions. +func (p *MemcacheProvider) Count() int { + // FIXME: how come this library does not have Stats method? + return -1 +} + +// GC calls GC to clean expired sessions. +func (p *MemcacheProvider) GC() {} + +func init() { + session.Register("memcache", &MemcacheProvider{}) +} diff --git a/memcache/memcache_test.go b/memcache/memcache_test.go new file mode 100644 index 0000000..5d1ed0d --- /dev/null +++ b/memcache/memcache_test.go @@ -0,0 +1,30 @@ +// Copyright 2014 The Macaron Authors +// Copyright 2024 The Forgejo Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"): you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +// License for the specific language governing permissions and limitations +// under the License. + +package session + +import ( + "testing" + + "code.forgejo.org/go-chi/session" + "code.forgejo.org/go-chi/session/internal/test" +) + +func Test_MemcacheProvider(t *testing.T) { + test.Provider(t, session.Options{ + Provider: "memcache", + ProviderConfig: "memcached:11211", + }) +} diff --git a/memory.go b/memory.go new file mode 100644 index 0000000..8ffe090 --- /dev/null +++ b/memory.go @@ -0,0 +1,223 @@ +// Copyright 2013 Beego Authors +// Copyright 2014 The Macaron Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"): you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +// License for the specific language governing permissions and limitations +// under the License. + +package session + +import ( + "container/list" + "fmt" + "sync" + "time" +) + +// MemStore represents a in-memory session store implementation. +type MemStore struct { + sid string + lock sync.RWMutex + data map[interface{}]interface{} + lastAccess time.Time +} + +// NewMemStore creates and returns a memory session store. +func NewMemStore(sid string) *MemStore { + return &MemStore{ + sid: sid, + data: make(map[interface{}]interface{}), + lastAccess: time.Now(), + } +} + +// Set sets value to given key in session. +func (s *MemStore) Set(key, val interface{}) error { + s.lock.Lock() + defer s.lock.Unlock() + + s.data[key] = val + return nil +} + +// Get gets value by given key in session. +func (s *MemStore) Get(key interface{}) interface{} { + s.lock.RLock() + defer s.lock.RUnlock() + + return s.data[key] +} + +// Delete deletes a key from session. +func (s *MemStore) Delete(key interface{}) error { + s.lock.Lock() + defer s.lock.Unlock() + + delete(s.data, key) + return nil +} + +// ID returns current session ID. +func (s *MemStore) ID() string { + return s.sid +} + +// Release releases resource and save data to provider. +func (*MemStore) Release() error { + return nil +} + +// Flush deletes all session data. +func (s *MemStore) Flush() error { + s.lock.Lock() + defer s.lock.Unlock() + + s.data = make(map[interface{}]interface{}) + return nil +} + +// MemProvider represents a in-memory session provider implementation. +type MemProvider struct { + lock sync.RWMutex + maxLifetime int64 + data map[string]*list.Element + // A priority list whose lastAccess newer gets higher priority. + list *list.List +} + +// Init initializes memory session provider. +func (p *MemProvider) Init(maxLifetime int64, _ string) error { + p.lock.Lock() + p.list = list.New() + p.data = make(map[string]*list.Element) + p.maxLifetime = maxLifetime + p.lock.Unlock() + return nil +} + +// update expands time of session store by given ID. +func (p *MemProvider) update(sid string) error { + p.lock.Lock() + defer p.lock.Unlock() + + if e, ok := p.data[sid]; ok { + e.Value.(*MemStore).lastAccess = time.Now() + p.list.MoveToFront(e) + return nil + } + return nil +} + +// Read returns raw session store by session ID. +func (p *MemProvider) Read(sid string) (_ RawStore, err error) { + p.lock.RLock() + e, ok := p.data[sid] + p.lock.RUnlock() + + // Only restore if the session is still alive. + if ok && (e.Value.(*MemStore).lastAccess.Unix()+p.maxLifetime) >= time.Now().Unix() { + if err = p.update(sid); err != nil { + return nil, err + } + return e.Value.(*MemStore), nil + } + + // Create a new session. + p.lock.Lock() + defer p.lock.Unlock() + if ok { + p.list.Remove(e) + delete(p.data, sid) + } + s := NewMemStore(sid) + p.data[sid] = p.list.PushBack(s) + return s, nil +} + +// Exist returns true if session with given ID exists. +func (p *MemProvider) Exist(sid string) bool { + p.lock.RLock() + defer p.lock.RUnlock() + + _, ok := p.data[sid] + return ok +} + +// Destroy deletes a session by session ID. +func (p *MemProvider) Destroy(sid string) error { + p.lock.Lock() + defer p.lock.Unlock() + + e, ok := p.data[sid] + if !ok { + return nil + } + + p.list.Remove(e) + delete(p.data, sid) + return nil +} + +// Regenerate regenerates a session store from old session ID to new one. +func (p *MemProvider) Regenerate(oldsid, sid string) (RawStore, error) { + if p.Exist(sid) { + return nil, fmt.Errorf("new sid '%s' already exists", sid) + } + + s, err := p.Read(oldsid) + if err != nil { + return nil, err + } + + if err = p.Destroy(oldsid); err != nil { + return nil, err + } + + s.(*MemStore).sid = sid + + p.lock.Lock() + defer p.lock.Unlock() + p.data[sid] = p.list.PushBack(s) + return s, nil +} + +// Count counts and returns number of sessions. +func (p *MemProvider) Count() int { + return p.list.Len() +} + +// GC calls GC to clean expired sessions. +func (p *MemProvider) GC() { + p.lock.RLock() + for { + // No session in the list. + e := p.list.Back() + if e == nil { + break + } + + if (e.Value.(*MemStore).lastAccess.Unix() + p.maxLifetime) < time.Now().Unix() { + p.lock.RUnlock() + p.lock.Lock() + p.list.Remove(e) + delete(p.data, e.Value.(*MemStore).sid) + p.lock.Unlock() + p.lock.RLock() + } else { + break + } + } + p.lock.RUnlock() +} + +func init() { + Register("memory", &MemProvider{}) +} diff --git a/memory_test.go b/memory_test.go new file mode 100644 index 0000000..20b529f --- /dev/null +++ b/memory_test.go @@ -0,0 +1,21 @@ +// Copyright 2014 The Macaron Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"): you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +// License for the specific language governing permissions and limitations +// under the License. + +package session + +import "testing" + +func Test_MemProvider(t *testing.T) { + testProvider(t, Options{}) +} diff --git a/mysql/mysql.go b/mysql/mysql.go new file mode 100644 index 0000000..7637144 --- /dev/null +++ b/mysql/mysql.go @@ -0,0 +1,201 @@ +// Copyright 2013 Beego Authors +// Copyright 2014 The Macaron Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"): you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +// License for the specific language governing permissions and limitations +// under the License. + +package session + +import ( + "database/sql" + "fmt" + "log" + "sync" + "time" + + "code.forgejo.org/go-chi/session" + _ "github.com/go-sql-driver/mysql" // Needed for the MySQL driver +) + +// MysqlStore represents a mysql session store implementation. +type MysqlStore struct { + c *sql.DB + sid string + lock sync.RWMutex + data map[interface{}]interface{} +} + +// NewMysqlStore creates and returns a mysql session store. +func NewMysqlStore(c *sql.DB, sid string, kv map[interface{}]interface{}) *MysqlStore { + return &MysqlStore{ + c: c, + sid: sid, + data: kv, + } +} + +// Set sets value to given key in session. +func (s *MysqlStore) Set(key, val interface{}) error { + s.lock.Lock() + defer s.lock.Unlock() + + s.data[key] = val + return nil +} + +// Get gets value by given key in session. +func (s *MysqlStore) Get(key interface{}) interface{} { + s.lock.RLock() + defer s.lock.RUnlock() + + return s.data[key] +} + +// Delete delete a key from session. +func (s *MysqlStore) Delete(key interface{}) error { + s.lock.Lock() + defer s.lock.Unlock() + + delete(s.data, key) + return nil +} + +// ID returns current session ID. +func (s *MysqlStore) ID() string { + return s.sid +} + +// Release releases resource and save data to provider. +func (s *MysqlStore) Release() error { + // Skip encoding if the data is empty + if len(s.data) == 0 { + return nil + } + + data, err := session.EncodeGob(s.data) + if err != nil { + return err + } + + _, err = s.c.Exec("UPDATE session SET data=?, expiry=? WHERE `key`=?", + data, time.Now().Unix(), s.sid) + return err +} + +// Flush deletes all session data. +func (s *MysqlStore) Flush() error { + s.lock.Lock() + defer s.lock.Unlock() + + s.data = make(map[interface{}]interface{}) + return nil +} + +// MysqlProvider represents a mysql session provider implementation. +type MysqlProvider struct { + c *sql.DB + expire int64 +} + +// Init initializes mysql session provider. +// connStr: username:password@protocol(address)/dbname?param=value +func (p *MysqlProvider) Init(expire int64, connStr string) (err error) { + p.expire = expire + + p.c, err = sql.Open("mysql", connStr) + if err != nil { + return err + } + return p.c.Ping() +} + +// Read returns raw session store by session ID. +func (p *MysqlProvider) Read(sid string) (session.RawStore, error) { + now := time.Now().Unix() + var data []byte + expiry := now + err := p.c.QueryRow("SELECT data, expiry FROM session WHERE `key`=?", sid).Scan(&data, &expiry) + if err == sql.ErrNoRows { + _, err = p.c.Exec("INSERT INTO session(`key`,data,expiry) VALUES(?,?,?)", + sid, "", now) + } + if err != nil { + return nil, err + } + + var kv map[interface{}]interface{} + if len(data) == 0 || expiry+p.expire <= now { + kv = make(map[interface{}]interface{}) + } else { + kv, err = session.DecodeGob(data) + if err != nil { + return nil, err + } + } + + return NewMysqlStore(p.c, sid, kv), nil +} + +// Exist returns true if session with given ID exists. +func (p *MysqlProvider) Exist(sid string) bool { + var data []byte + err := p.c.QueryRow("SELECT data FROM session WHERE `key`=?", sid).Scan(&data) + if err != nil && err != sql.ErrNoRows { + panic("session/mysql: error checking existence: " + err.Error()) + } + return err != sql.ErrNoRows +} + +// Destroy deletes a session by session ID. +func (p *MysqlProvider) Destroy(sid string) error { + _, err := p.c.Exec("DELETE FROM session WHERE `key`=?", sid) + return err +} + +// Regenerate regenerates a session store from old session ID to new one. +func (p *MysqlProvider) Regenerate(oldsid, sid string) (_ session.RawStore, err error) { + if p.Exist(sid) { + return nil, fmt.Errorf("new sid '%s' already exists", sid) + } + + if !p.Exist(oldsid) { + if _, err = p.c.Exec("INSERT INTO session(`key`,data,expiry) VALUES(?,?,?)", + oldsid, "", time.Now().Unix()); err != nil { + return nil, err + } + } + + if _, err = p.c.Exec("UPDATE session SET `key`=? WHERE `key`=?", sid, oldsid); err != nil { + return nil, err + } + + return p.Read(sid) +} + +// Count counts and returns number of sessions. +func (p *MysqlProvider) Count() (total int) { + if err := p.c.QueryRow("SELECT COUNT(*) AS NUM FROM session").Scan(&total); err != nil { + panic("session/mysql: error counting records: " + err.Error()) + } + return total +} + +// GC calls GC to clean expired sessions. +func (p *MysqlProvider) GC() { + if _, err := p.c.Exec("DELETE FROM session WHERE expiry + ? <= UNIX_TIMESTAMP(NOW())", p.expire); err != nil { + log.Printf("session/mysql: error garbage collecting: %v", err) + } +} + +func init() { + session.Register("mysql", &MysqlProvider{}) +} diff --git a/mysql/mysql_test.go b/mysql/mysql_test.go new file mode 100644 index 0000000..f05445b --- /dev/null +++ b/mysql/mysql_test.go @@ -0,0 +1,30 @@ +// Copyright 2014 The Macaron Authors +// Copyright 2024 The Forgejo Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"): you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +// License for the specific language governing permissions and limitations +// under the License. + +package session + +import ( + "testing" + + "code.forgejo.org/go-chi/session" + "code.forgejo.org/go-chi/session/internal/test" +) + +func Test_MysqlProvider(t *testing.T) { + test.Provider(t, session.Options{ + Provider: "mysql", + ProviderConfig: "root:@tcp(mysql:3306)/testsession?charset=utf8", + }) +} diff --git a/postgres/postgres.go b/postgres/postgres.go new file mode 100644 index 0000000..d2633ee --- /dev/null +++ b/postgres/postgres.go @@ -0,0 +1,203 @@ +// Copyright 2013 Beego Authors +// Copyright 2014 The Macaron Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"): you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +// License for the specific language governing permissions and limitations +// under the License. + +package session + +import ( + "database/sql" + "fmt" + "log" + "sync" + "time" + + "code.forgejo.org/go-chi/session" + _ "github.com/lib/pq" // Needed for the Postgresql driver +) + +// PostgresStore represents a postgres session store implementation. +type PostgresStore struct { + c *sql.DB + sid string + lock sync.RWMutex + data map[interface{}]interface{} +} + +// NewPostgresStore creates and returns a postgres session store. +func NewPostgresStore(c *sql.DB, sid string, kv map[interface{}]interface{}) *PostgresStore { + return &PostgresStore{ + c: c, + sid: sid, + data: kv, + } +} + +// Set sets value to given key in session. +func (s *PostgresStore) Set(key, value interface{}) error { + s.lock.Lock() + defer s.lock.Unlock() + + s.data[key] = value + return nil +} + +// Get gets value by given key in session. +func (s *PostgresStore) Get(key interface{}) interface{} { + s.lock.RLock() + defer s.lock.RUnlock() + + return s.data[key] +} + +// Delete delete a key from session. +func (s *PostgresStore) Delete(key interface{}) error { + s.lock.Lock() + defer s.lock.Unlock() + + delete(s.data, key) + return nil +} + +// ID returns current session ID. +func (s *PostgresStore) ID() string { + return s.sid +} + +// save postgres session values to database. +// must call this method to save values to database. +func (s *PostgresStore) Release() error { + // Skip encoding if the data is empty + if len(s.data) == 0 { + return nil + } + + data, err := session.EncodeGob(s.data) + if err != nil { + return err + } + + _, err = s.c.Exec("UPDATE session SET data=$1, expiry=$2 WHERE key=$3", + data, time.Now().Unix(), s.sid) + return err +} + +// Flush deletes all session data. +func (s *PostgresStore) Flush() error { + s.lock.Lock() + defer s.lock.Unlock() + + s.data = make(map[interface{}]interface{}) + return nil +} + +// PostgresProvider represents a postgres session provider implementation. +type PostgresProvider struct { + c *sql.DB + maxlifetime int64 +} + +// Init initializes postgres session provider. +// connStr: user=a password=b host=localhost port=5432 dbname=c sslmode=disable +func (p *PostgresProvider) Init(maxlifetime int64, connStr string) (err error) { + p.maxlifetime = maxlifetime + + p.c, err = sql.Open("postgres", connStr) + if err != nil { + return err + } + + return p.c.Ping() +} + +// Read returns raw session store by session ID. +func (p *PostgresProvider) Read(sid string) (session.RawStore, error) { + now := time.Now().Unix() + var data []byte + expiry := now + err := p.c.QueryRow("SELECT data, expiry FROM session WHERE key=$1", sid).Scan(&data, &expiry) + if err == sql.ErrNoRows { + _, err = p.c.Exec("INSERT INTO session(key,data,expiry) VALUES($1,$2,$3)", + sid, "", now) + } + if err != nil { + return nil, err + } + + var kv map[interface{}]interface{} + if len(data) == 0 || expiry+p.maxlifetime <= now { + kv = make(map[interface{}]interface{}) + } else { + kv, err = session.DecodeGob(data) + if err != nil { + return nil, err + } + } + + return NewPostgresStore(p.c, sid, kv), nil +} + +// Exist returns true if session with given ID exists. +func (p *PostgresProvider) Exist(sid string) bool { + var data []byte + err := p.c.QueryRow("SELECT data FROM session WHERE key=$1", sid).Scan(&data) + if err != nil && err != sql.ErrNoRows { + panic("session/postgres: error checking existence: " + err.Error()) + } + return err != sql.ErrNoRows +} + +// Destroy deletes a session by session ID. +func (p *PostgresProvider) Destroy(sid string) error { + _, err := p.c.Exec("DELETE FROM session WHERE key=$1", sid) + return err +} + +// Regenerate regenerates a session store from old session ID to new one. +func (p *PostgresProvider) Regenerate(oldsid, sid string) (_ session.RawStore, err error) { + if p.Exist(sid) { + return nil, fmt.Errorf("new sid '%s' already exists", sid) + } + + if !p.Exist(oldsid) { + if _, err = p.c.Exec("INSERT INTO session(key,data,expiry) VALUES($1,$2,$3)", + oldsid, "", time.Now().Unix()); err != nil { + return nil, err + } + } + + if _, err = p.c.Exec("UPDATE session SET key=$1 WHERE key=$2", sid, oldsid); err != nil { + return nil, err + } + + return p.Read(sid) +} + +// Count counts and returns number of sessions. +func (p *PostgresProvider) Count() (total int) { + if err := p.c.QueryRow("SELECT COUNT(*) AS NUM FROM session").Scan(&total); err != nil { + panic("session/postgres: error counting records: " + err.Error()) + } + return total +} + +// GC calls GC to clean expired sessions. +func (p *PostgresProvider) GC() { + if _, err := p.c.Exec("DELETE FROM session WHERE EXTRACT(EPOCH FROM NOW()) - expiry > $1", p.maxlifetime); err != nil { + log.Printf("session/postgres: error garbage collecting: %v", err) + } +} + +func init() { + session.Register("postgres", &PostgresProvider{}) +} diff --git a/postgres/postgres_test.go b/postgres/postgres_test.go new file mode 100644 index 0000000..94b3c41 --- /dev/null +++ b/postgres/postgres_test.go @@ -0,0 +1,30 @@ +// Copyright 2014 The Macaron Authors +// Copyright 2024 The Forgejo Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"): you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +// License for the specific language governing permissions and limitations +// under the License. + +package session + +import ( + "testing" + + "code.forgejo.org/go-chi/session" + "code.forgejo.org/go-chi/session/internal/test" +) + +func Test_PostgresProvider(t *testing.T) { + test.Provider(t, session.Options{ + Provider: "postgres", + ProviderConfig: "user=postgres password=postgres host=pgsql dbname=testsession port=5432 sslmode=disable", + }) +} diff --git a/redis/redis.go b/redis/redis.go new file mode 100644 index 0000000..71e6c22 --- /dev/null +++ b/redis/redis.go @@ -0,0 +1,256 @@ +// Copyright 2013 Beego Authors +// Copyright 2014 The Macaron Authors +// Copyright 2024 The Forgejo Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"): you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +// License for the specific language governing permissions and limitations +// under the License. + +package session + +import ( + "context" + "fmt" + "strconv" + "strings" + "sync" + "time" + + "code.forgejo.org/go-chi/session" + "github.com/redis/go-redis/v9" +) + +// since we do not use context define global once +var ctx = context.TODO() + +// RedisStore represents a redis session store implementation. +type RedisStore struct { + c *redis.Client + prefix, sid string + duration time.Duration + lock sync.RWMutex + data map[interface{}]interface{} +} + +// NewRedisStore creates and returns a redis session store. +func NewRedisStore(c *redis.Client, prefix, sid string, dur time.Duration, kv map[interface{}]interface{}) *RedisStore { + return &RedisStore{ + c: c, + prefix: prefix, + sid: sid, + duration: dur, + data: kv, + } +} + +// Set sets value to given key in session. +func (s *RedisStore) Set(key, val interface{}) error { + s.lock.Lock() + defer s.lock.Unlock() + + s.data[key] = val + return nil +} + +// Get gets value by given key in session. +func (s *RedisStore) Get(key interface{}) interface{} { + s.lock.RLock() + defer s.lock.RUnlock() + + return s.data[key] +} + +// Delete delete a key from session. +func (s *RedisStore) Delete(key interface{}) error { + s.lock.Lock() + defer s.lock.Unlock() + + delete(s.data, key) + return nil +} + +// ID returns current session ID. +func (s *RedisStore) ID() string { + return s.sid +} + +// Release releases resource and save data to provider. +func (s *RedisStore) Release() error { + // Skip encoding if the data is empty + if len(s.data) == 0 { + return nil + } + + data, err := session.EncodeGob(s.data) + if err != nil { + return err + } + + return s.c.Set(ctx, s.prefix+s.sid, string(data), s.duration).Err() +} + +// Flush deletes all session data. +func (s *RedisStore) Flush() error { + s.lock.Lock() + defer s.lock.Unlock() + + s.data = make(map[interface{}]interface{}) + return nil +} + +// RedisProvider represents a redis session provider implementation. +type RedisProvider struct { + c *redis.Client + duration time.Duration + prefix string +} + +// Init initializes redis session provider. +// config: network=tcp,addr=:6379,password=macaron,db=0,pool_size=100,idle_timeout=180,prefix=session; +func (p *RedisProvider) Init(maxlifetime int64, config string) (err error) { + p.duration, err = time.ParseDuration(fmt.Sprintf("%ds", maxlifetime)) + if err != nil { + return err + } + + settings := strings.Split(config, ",") + + opt := &redis.Options{ + Network: "tcp", + } + for _, setting := range settings { + k, v, found := strings.Cut(setting, "=") + if !found { + return fmt.Errorf("session/redis: cannot find '=': %q", setting) + } + k, v = strings.TrimSpace(k), strings.TrimSpace(v) + + switch k { + case "network": + opt.Network = v + case "addr": + opt.Addr = v + case "password": + opt.Password = v + case "db": + opt.DB, err = strconv.Atoi(v) + if err != nil { + return fmt.Errorf("error parsing db: %w", err) + } + case "pool_size": + opt.PoolSize, err = strconv.Atoi(v) + if err != nil { + return fmt.Errorf("error parsing pool_size: %w", err) + } + case "idle_timeout": + opt.ConnMaxIdleTime, err = time.ParseDuration(v + "s") + if err != nil { + return fmt.Errorf("error parsing idle timeout: %w", err) + } + case "prefix": + p.prefix = v + default: + return fmt.Errorf("session/redis: unsupported option '%s'", k) + } + } + + p.c = redis.NewClient(opt) + return p.c.Ping(ctx).Err() +} + +// Read returns raw session store by session ID. +func (p *RedisProvider) Read(sid string) (session.RawStore, error) { + psid := p.prefix + sid + if !p.Exist(sid) { + if err := p.c.Set(ctx, psid, "", p.duration).Err(); err != nil { + return nil, err + } + } + + var kv map[interface{}]interface{} + kvs, err := p.c.Get(ctx, psid).Result() + if err != nil { + return nil, err + } + if len(kvs) == 0 { + kv = make(map[interface{}]interface{}) + } else { + kv, err = session.DecodeGob([]byte(kvs)) + if err != nil { + return nil, err + } + } + + return NewRedisStore(p.c, p.prefix, sid, p.duration, kv), nil +} + +// Exist returns true if session with given ID exists. +func (p *RedisProvider) Exist(sid string) bool { + count, err := p.c.Exists(ctx, p.prefix+sid).Result() + return err == nil && count == 1 +} + +// Destroy deletes a session by session ID. +func (p *RedisProvider) Destroy(sid string) error { + return p.c.Del(ctx, p.prefix+sid).Err() +} + +// Regenerate regenerates a session store from old session ID to new one. +func (p *RedisProvider) Regenerate(oldsid, sid string) (_ session.RawStore, err error) { + poldsid := p.prefix + oldsid + psid := p.prefix + sid + + if p.Exist(sid) { + return nil, fmt.Errorf("new sid '%s' already exists", sid) + } else if !p.Exist(oldsid) { + // Make a fake old session. + if err = p.c.Set(ctx, poldsid, "", p.duration).Err(); err != nil { + return nil, err + } + } + + if err = p.c.Rename(ctx, poldsid, psid).Err(); err != nil { + return nil, err + } + + var kv map[interface{}]interface{} + kvs, err := p.c.Get(ctx, psid).Result() + if err != nil { + return nil, err + } + + if len(kvs) == 0 { + kv = make(map[interface{}]interface{}) + } else { + kv, err = session.DecodeGob([]byte(kvs)) + if err != nil { + return nil, err + } + } + + return NewRedisStore(p.c, p.prefix, sid, p.duration, kv), nil +} + +// Count counts and returns number of sessions. +func (p *RedisProvider) Count() int { + count, err := p.c.DBSize(ctx).Result() + if err != nil { + return 0 + } + return int(count) +} + +// GC calls GC to clean expired sessions. +func (*RedisProvider) GC() {} + +func init() { + session.Register("redis", &RedisProvider{}) +} diff --git a/redis/redis_test.go b/redis/redis_test.go new file mode 100644 index 0000000..eacb012 --- /dev/null +++ b/redis/redis_test.go @@ -0,0 +1,30 @@ +// Copyright 2014 The Macaron Authors +// Copyright 2024 The Forgejo Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"): you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +// License for the specific language governing permissions and limitations +// under the License. + +package session + +import ( + "testing" + + "code.forgejo.org/go-chi/session" + "code.forgejo.org/go-chi/session/internal/test" +) + +func Test_RedisProvider(t *testing.T) { + test.Provider(t, session.Options{ + Provider: "redis", + ProviderConfig: "addr=redis:6379, prefix = session:", + }) +} diff --git a/renovate.json b/renovate.json new file mode 100644 index 0000000..3146f63 --- /dev/null +++ b/renovate.json @@ -0,0 +1,4 @@ +{ + "$schema": "https://docs.renovatebot.com/renovate-schema.json", + "extends": ["go-chi/renovate-config"] +} diff --git a/session.go b/session.go new file mode 100644 index 0000000..1b409e4 --- /dev/null +++ b/session.go @@ -0,0 +1,420 @@ +// Copyright 2013 Beego Authors +// Copyright 2014 The Macaron Authors +// Copyright 2024 The Forgejo Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"): you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +// License for the specific language governing permissions and limitations +// under the License. + +// Package session a middleware that provides the session management of Macaron. +package session + +import ( + "context" + "crypto/rand" + "encoding/hex" + "errors" + "fmt" + "net/http" + "net/url" + "reflect" + "time" +) + +// RawStore is the interface that operates the session data. +type RawStore interface { + // Set sets value to given key in session. + Set(interface{}, interface{}) error + // Get gets value by given key in session. + Get(interface{}) interface{} + // Delete deletes a key from session. + Delete(interface{}) error + // ID returns current session ID. + ID() string + // Release releases session resource and save data to provider. + Release() error + // Flush deletes all session data. + Flush() error +} + +// Store is the interface that contains all data for one session process with specific ID. +type Store interface { + RawStore + // Read returns raw session store by session ID. + Read(string) (RawStore, error) + // Destroy deletes a session. + Destroy(http.ResponseWriter, *http.Request) error + // RegenerateID regenerates a session store from old session ID to new one. + RegenerateID(http.ResponseWriter, *http.Request) (RawStore, error) + // Count counts and returns number of sessions. + Count() int + // GC calls GC to clean expired sessions. + GC() +} + +type store struct { + RawStore + *Manager +} + +var _ Store = &store{} + +// Options represents a struct for specifying configuration options for the session middleware. +type Options struct { + // Name of provider. Default is "memory". + Provider string + // Provider configuration, it's corresponding to provider. + ProviderConfig string + // Cookie name to save session ID. Default is "MacaronSession". + CookieName string + // Cookie path to store. Default is "/". + CookiePath string + // GC interval time in seconds. Default is 3600. + Gclifetime int64 + // Max life time in seconds. Default is whatever GC interval time is. + Maxlifetime int64 + // Use HTTPS only. Default is false. + Secure bool + // Cookie life time. Default is 0. + CookieLifeTime int + // SameSite set the cookie SameSite + SameSite http.SameSite + // Cookie domain name. Default is empty. + Domain string + // Session ID length. Default is 16. + IDLength int + // Ignore release for websocket. Default is false. + IgnoreReleaseForWebSocket bool +} + +// PrepareOptions gives some default values for options +func PrepareOptions(options []Options) Options { + var opt Options + if len(options) > 0 { + opt = options[0] + } + + if len(opt.Provider) == 0 { + opt.Provider = "memory" + } + if len(opt.ProviderConfig) == 0 { + opt.ProviderConfig = "data/sessions" + } + if len(opt.CookieName) == 0 { + opt.CookieName = "MacaronSession" + } + if len(opt.CookiePath) == 0 { + opt.CookiePath = "/" + } + if opt.Gclifetime == 0 { + opt.Gclifetime = 3600 + } + if opt.Maxlifetime == 0 { + opt.Maxlifetime = opt.Gclifetime + } + if !opt.Secure { + opt.Secure = false + } + if opt.IDLength == 0 { + opt.IDLength = 16 + } + + return opt +} + +// GetCookie returns given cookie value from request header. +func GetCookie(req *http.Request, name string) string { + cookie, err := req.Cookie(name) + if err != nil { + return "" + } + val, _ := url.QueryUnescape(cookie.Value) + return val +} + +// Sessioner is a middleware that maps a session.SessionStore service into the Macaron handler chain. +// An single variadic session.Options struct can be optionally provided to configure. +func Sessioner(options ...Options) func(next http.Handler) http.Handler { + opt := PrepareOptions(options) + manager, err := NewManager(opt.Provider, opt) + if err != nil { + panic(err) + } + go manager.startGC() + + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + sess, err := manager.Start(w, req) + if err != nil { + panic("session(start): " + err.Error()) + } + + s := store{ + RawStore: sess, + Manager: manager, + } + + req = req.WithContext(context.WithValue(req.Context(), interface{}("Session"), &s)) //nolint:staticcheck + + next.ServeHTTP(w, req) + + if manager.opt.IgnoreReleaseForWebSocket && req.Header.Get("Upgrade") == "websocket" { + return + } + + if err = s.RawStore.Release(); err != nil { + panic("session(release): " + err.Error()) + } + }) + } +} + +// GetSession returns session store +func GetSession(req *http.Request) Store { + sessCtx := req.Context().Value("Session") + sess, _ := sessCtx.(*store) + return sess +} + +// RegenerateSession +func RegenerateSession(resp http.ResponseWriter, req *http.Request) (Store, error) { + sess, ok := GetSession(req).(*store) + if !ok { + return nil, errors.New("no session in request context") + } + + oldRawStore := sess.RawStore + if err := oldRawStore.Release(); err != nil { + return nil, err + } + + store, err := sess.RegenerateID(resp, req) + if err != nil { + return nil, err + } + sess.RawStore = store + return sess, nil +} + +// Provider is the interface that provides session manipulations. +type Provider interface { + // Init initializes session provider. + Init(gclifetime int64, config string) error + // Read returns raw session store by session ID. + Read(sid string) (RawStore, error) + // Exist returns true if session with given ID exists. + Exist(sid string) bool + // Destroy deletes a session by session ID. + Destroy(sid string) error + // Regenerate regenerates a session store from old session ID to new one. + Regenerate(oldsid, sid string) (RawStore, error) + // Count counts and returns number of sessions. + Count() int + // GC calls GC to clean expired sessions. + GC() +} + +var providers = make(map[string]func() Provider) + +// Register registers a provider. +func Register(name string, provider Provider) { + if reflect.TypeOf(provider).Kind() == reflect.Ptr { + // Pointer: + RegisterFn(name, func() Provider { + return reflect.New(reflect.ValueOf(provider).Elem().Type()).Interface().(Provider) + }) + return + } + + // Not a Pointer + RegisterFn(name, func() Provider { + return reflect.New(reflect.TypeOf(provider)).Elem().Interface().(Provider) + }) +} + +// RegisterFn registers a provider function. +func RegisterFn(name string, providerfn func() Provider) { + if providerfn == nil { + panic("session: cannot register provider with nil value") + } + if _, dup := providers[name]; dup { + panic(fmt.Errorf("session: cannot register provider '%s' twice", name)) + } + + providers[name] = providerfn +} + +// Manager represents a struct that contains session provider and its configuration. +type Manager struct { + provider Provider + opt Options +} + +// NewManager creates and returns a new session manager by given provider name and configuration. +// It returns an error when requested provider name isn't registered. +func NewManager(name string, opt Options) (*Manager, error) { + fn, ok := providers[name] + if !ok { + return nil, fmt.Errorf("session: unknown provider '%s'(forgotten import?)", name) + } + + p := fn() + + return &Manager{p, opt}, p.Init(opt.Maxlifetime, opt.ProviderConfig) +} + +// sessionID generates a new session ID. +// Gives half of the ID length amount of entropy. +func (m *Manager) sessionID() string { + buf := make([]byte, m.opt.IDLength/2) + if _, err := rand.Read(buf); err != nil { + panic(err) + } + + return hex.EncodeToString(buf) +} + +// validSessionID tests whether a provided session ID is a valid session ID. +func (m *Manager) validSessionID(sid string) (bool, error) { + if len(sid) != m.opt.IDLength { + return false, fmt.Errorf("invalid 'sid': %s %d != %d", sid, len(sid), m.opt.IDLength) + } + + for i := range sid { + switch { + case '0' <= sid[i] && sid[i] <= '9': + case 'a' <= sid[i] && sid[i] <= 'f': + default: + return false, errors.New("invalid 'sid': " + sid) + } + } + return true, nil +} + +// Start starts a session by generating new one +// or retrieve existence one by reading session ID from HTTP request if it's valid. +func (m *Manager) Start(resp http.ResponseWriter, req *http.Request) (RawStore, error) { + sid := GetCookie(req, m.opt.CookieName) + valid, _ := m.validSessionID(sid) + if len(sid) > 0 && valid && m.provider.Exist(sid) { + return m.provider.Read(sid) + } + + sid = m.sessionID() + sess, err := m.provider.Read(sid) + if err != nil { + return nil, err + } + + cookie := &http.Cookie{ + Name: m.opt.CookieName, + Value: sid, + Path: m.opt.CookiePath, + HttpOnly: true, + Secure: m.opt.Secure, + Domain: m.opt.Domain, + SameSite: m.opt.SameSite, + } + if m.opt.CookieLifeTime >= 0 { + cookie.MaxAge = m.opt.CookieLifeTime + } + http.SetCookie(resp, cookie) + req.AddCookie(cookie) + return sess, nil +} + +// Read returns raw session store by session ID. +func (m *Manager) Read(sid string) (RawStore, error) { + // Ensure we're trying to read a valid session ID + if _, err := m.validSessionID(sid); err != nil { + return nil, err + } + + return m.provider.Read(sid) +} + +// Destroy deletes a session by given ID. +func (m *Manager) Destroy(resp http.ResponseWriter, req *http.Request) error { + sid := GetCookie(req, m.opt.CookieName) + if len(sid) == 0 { + return nil + } + + if _, err := m.validSessionID(sid); err != nil { + return err + } + + if err := m.provider.Destroy(sid); err != nil { + return err + } + cookie := &http.Cookie{ + Name: m.opt.CookieName, + Domain: m.opt.Domain, + Path: m.opt.CookiePath, + HttpOnly: true, + Secure: m.opt.Secure, + SameSite: m.opt.SameSite, + Expires: time.Now(), + MaxAge: -1, + } + http.SetCookie(resp, cookie) + return nil +} + +// RegenerateID regenerates a session store from old session ID to new one. +func (m *Manager) RegenerateID(resp http.ResponseWriter, req *http.Request) (sess RawStore, err error) { + sid := m.sessionID() + oldsid := GetCookie(req, m.opt.CookieName) + _, err = m.validSessionID(oldsid) + if err != nil { + return nil, err + } + sess, err = m.provider.Regenerate(oldsid, sid) + if err != nil { + return nil, err + } + cookie := &http.Cookie{ + Name: m.opt.CookieName, + Value: sid, + Path: m.opt.CookiePath, + HttpOnly: true, + Secure: m.opt.Secure, + Domain: m.opt.Domain, + SameSite: m.opt.SameSite, + } + if m.opt.CookieLifeTime >= 0 { + cookie.MaxAge = m.opt.CookieLifeTime + } + http.SetCookie(resp, cookie) + req.AddCookie(cookie) + return sess, nil +} + +// Count counts and returns number of sessions. +func (m *Manager) Count() int { + return m.provider.Count() +} + +// GC starts GC job in a certain period. +func (m *Manager) GC() { + m.provider.GC() +} + +// startGC starts GC job in a certain period. +func (m *Manager) startGC() { + m.GC() + time.AfterFunc(time.Duration(m.opt.Gclifetime)*time.Second, func() { m.startGC() }) +} + +// SetSecure indicates whether to set cookie with HTTPS or not. +func (m *Manager) SetSecure(secure bool) { + m.opt.Secure = secure +} diff --git a/session_test.go b/session_test.go new file mode 100644 index 0000000..a27c9b2 --- /dev/null +++ b/session_test.go @@ -0,0 +1,196 @@ +// Copyright 2014 The Macaron Authors +// Copyright 2024 The Forgejo Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"): you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +// License for the specific language governing permissions and limitations +// under the License. + +package session + +import ( + "net/http" + "net/http/httptest" + "testing" + "time" + + chi "github.com/go-chi/chi/v5" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_Sessioner(t *testing.T) { + t.Run("Use session middleware", func(t *testing.T) { + c := chi.NewRouter() + c.Use(Sessioner()) + c.Get("/", func(_ http.ResponseWriter, _ *http.Request) {}) + + resp := httptest.NewRecorder() + req, err := http.NewRequest("GET", "/", nil) + require.NoError(t, err) + + c.ServeHTTP(resp, req) + }) + + t.Run("Register invalid provider", func(t *testing.T) { + t.Run("Provider not exists", func(t *testing.T) { + assert.Panics(t, func() { + c := chi.NewRouter() + c.Use(Sessioner(Options{ + Provider: "fake", + })) + }) + }) + + t.Run("Provider value is nil", func(t *testing.T) { + assert.Panics(t, func() { + Register("fake", nil) + }) + }) + + t.Run("Register twice", func(t *testing.T) { + assert.Panics(t, func() { + Register("memory", &MemProvider{}) + }) + }) + }) +} + +func testProvider(t *testing.T, opt Options) { + t.Run("Basic operation", func(t *testing.T) { + c := chi.NewRouter() + c.Use(Sessioner(opt)) + var initialSid string + + c.Get("/", func(_ http.ResponseWriter, req *http.Request) { + sess := GetSession(req) + assert.NoError(t, sess.Set("uname", "unknwon")) + initialSid = sess.ID() + }) + c.Get("/reg", func(resp http.ResponseWriter, req *http.Request) { + sess := GetSession(req) + assert.EqualValues(t, initialSid, sess.ID()) + raw, err := RegenerateSession(resp, req) + assert.NoError(t, err) + assert.NotNil(t, sess) + assert.EqualValues(t, sess, raw) + + assert.NotEqualValues(t, initialSid, sess.ID()) + + uname := sess.Get("uname") + assert.NotNil(t, uname) + assert.EqualValues(t, "unknwon", uname) + + assert.NoError(t, sess.Set("uname", "lunny")) + uname = sess.Get("uname") + assert.NotNil(t, uname) + assert.EqualValues(t, "lunny", uname) + }) + c.Get("/get", func(resp http.ResponseWriter, req *http.Request) { + sess := GetSession(req) + sid := sess.ID() + assert.NotEmpty(t, sid) + + raw, err := sess.Read(sid) + assert.NoError(t, err) + assert.NotNil(t, raw) + + uname := sess.Get("uname") + assert.NotNil(t, uname) + assert.EqualValues(t, "lunny", uname) + + assert.NoError(t, sess.Delete("uname")) + assert.Nil(t, sess.Get("uname")) + + assert.NoError(t, sess.Destroy(resp, req)) + }) + + resp := httptest.NewRecorder() + req, err := http.NewRequest("GET", "/", nil) + require.NoError(t, err) + c.ServeHTTP(resp, req) + + cookie := resp.Header().Get("Set-Cookie") + + resp = httptest.NewRecorder() + req, err = http.NewRequest("GET", "/reg", nil) + require.NoError(t, err) + req.Header.Set("Cookie", cookie) + c.ServeHTTP(resp, req) + + cookie = resp.Header().Get("Set-Cookie") + + resp = httptest.NewRecorder() + req, err = http.NewRequest("GET", "/get", nil) + require.NoError(t, err) + req.Header.Set("Cookie", cookie) + c.ServeHTTP(resp, req) + }) + + t.Run("Regenerate empty session", func(t *testing.T) { + c := chi.NewRouter() + c.Use(Sessioner(opt)) + c.Get("/", func(resp http.ResponseWriter, req *http.Request) { + sess := GetSession(req) + raw, err := sess.RegenerateID(resp, req) + assert.NoError(t, err) + assert.NotNil(t, raw) + }) + + resp := httptest.NewRecorder() + req, err := http.NewRequest("GET", "/", nil) + require.NoError(t, err) + req.Header.Set("Cookie", "MacaronSession=ad2c7e3cbecfcf48; Path=/;") + c.ServeHTTP(resp, req) + }) + + t.Run("GC session", func(t *testing.T) { + c := chi.NewRouter() + opt2 := opt + opt2.Gclifetime = 1 + c.Use(Sessioner(opt2)) + + c.Get("/", func(_ http.ResponseWriter, req *http.Request) { + sess := GetSession(req) + assert.NoError(t, sess.Set("uname", "unknwon")) + assert.NotEmpty(t, sess.ID()) + uname := sess.Get("uname") + assert.NotNil(t, uname) + assert.EqualValues(t, "unknwon", uname) + + assert.NoError(t, sess.Flush()) + assert.Nil(t, sess.Get("uname")) + + time.Sleep(2 * time.Second) + sess.GC() + assert.Zero(t, sess.Count()) + }) + + resp := httptest.NewRecorder() + req, err := http.NewRequest("GET", "/", nil) + require.NoError(t, err) + c.ServeHTTP(resp, req) + }) + t.Run("Detect invalid sid", func(t *testing.T) { + c := chi.NewRouter() + c.Use(Sessioner(opt)) + c.Get("/", func(_ http.ResponseWriter, req *http.Request) { + sess := GetSession(req) + raw, err := sess.Read("../session/ad2c7e3cbecfcf486") + assert.Contains(t, err.Error(), "invalid 'sid'") + assert.Nil(t, raw) + }) + + resp := httptest.NewRecorder() + req, err := http.NewRequest("GET", "/", nil) + require.NoError(t, err) + c.ServeHTTP(resp, req) + }) +} diff --git a/utils.go b/utils.go new file mode 100644 index 0000000..9ad600e --- /dev/null +++ b/utils.go @@ -0,0 +1,49 @@ +// Copyright 2013 Beego Authors +// Copyright 2014 The Macaron Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"): you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +// License for the specific language governing permissions and limitations +// under the License. + +package session + +import ( + "bytes" + "encoding/gob" +) + +func init() { + gob.Register([]interface{}{}) + gob.Register(map[int]interface{}{}) + gob.Register(map[string]interface{}{}) + gob.Register(map[interface{}]interface{}{}) + gob.Register(map[string]string{}) + gob.Register(map[int]string{}) + gob.Register(map[int]int{}) + gob.Register(map[int]int64{}) +} + +// EncodeGob encodes obj with gob +func EncodeGob(obj map[interface{}]interface{}) ([]byte, error) { + for _, v := range obj { + gob.Register(v) + } + buf := bytes.NewBuffer(nil) + err := gob.NewEncoder(buf).Encode(obj) + return buf.Bytes(), err +} + +// DecodeGob decodes bytes to obj +func DecodeGob(encoded []byte) (out map[interface{}]interface{}, err error) { + buf := bytes.NewBuffer(encoded) + err = gob.NewDecoder(buf).Decode(&out) + return out, err +}