197 lines
5.1 KiB
Go
197 lines
5.1 KiB
Go
|
// Copyright 2014 The Macaron Authors
|
||
|
// Copyright 2024 The Forgejo Authors
|
||
|
//
|
||
|
// Licensed under the Apache License, Version 2.0 (the "License"): you may
|
||
|
// not use this file except in compliance with the License. You may obtain
|
||
|
// a copy of the License at
|
||
|
//
|
||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||
|
//
|
||
|
// Unless required by applicable law or agreed to in writing, software
|
||
|
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||
|
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||
|
// License for the specific language governing permissions and limitations
|
||
|
// under the License.
|
||
|
|
||
|
package session
|
||
|
|
||
|
import (
|
||
|
"net/http"
|
||
|
"net/http/httptest"
|
||
|
"testing"
|
||
|
"time"
|
||
|
|
||
|
chi "github.com/go-chi/chi/v5"
|
||
|
"github.com/stretchr/testify/assert"
|
||
|
"github.com/stretchr/testify/require"
|
||
|
)
|
||
|
|
||
|
func Test_Sessioner(t *testing.T) {
|
||
|
t.Run("Use session middleware", func(t *testing.T) {
|
||
|
c := chi.NewRouter()
|
||
|
c.Use(Sessioner())
|
||
|
c.Get("/", func(_ http.ResponseWriter, _ *http.Request) {})
|
||
|
|
||
|
resp := httptest.NewRecorder()
|
||
|
req, err := http.NewRequest("GET", "/", nil)
|
||
|
require.NoError(t, err)
|
||
|
|
||
|
c.ServeHTTP(resp, req)
|
||
|
})
|
||
|
|
||
|
t.Run("Register invalid provider", func(t *testing.T) {
|
||
|
t.Run("Provider not exists", func(t *testing.T) {
|
||
|
assert.Panics(t, func() {
|
||
|
c := chi.NewRouter()
|
||
|
c.Use(Sessioner(Options{
|
||
|
Provider: "fake",
|
||
|
}))
|
||
|
})
|
||
|
})
|
||
|
|
||
|
t.Run("Provider value is nil", func(t *testing.T) {
|
||
|
assert.Panics(t, func() {
|
||
|
Register("fake", nil)
|
||
|
})
|
||
|
})
|
||
|
|
||
|
t.Run("Register twice", func(t *testing.T) {
|
||
|
assert.Panics(t, func() {
|
||
|
Register("memory", &MemProvider{})
|
||
|
})
|
||
|
})
|
||
|
})
|
||
|
}
|
||
|
|
||
|
func testProvider(t *testing.T, opt Options) {
|
||
|
t.Run("Basic operation", func(t *testing.T) {
|
||
|
c := chi.NewRouter()
|
||
|
c.Use(Sessioner(opt))
|
||
|
var initialSid string
|
||
|
|
||
|
c.Get("/", func(_ http.ResponseWriter, req *http.Request) {
|
||
|
sess := GetSession(req)
|
||
|
assert.NoError(t, sess.Set("uname", "unknwon"))
|
||
|
initialSid = sess.ID()
|
||
|
})
|
||
|
c.Get("/reg", func(resp http.ResponseWriter, req *http.Request) {
|
||
|
sess := GetSession(req)
|
||
|
assert.EqualValues(t, initialSid, sess.ID())
|
||
|
raw, err := RegenerateSession(resp, req)
|
||
|
assert.NoError(t, err)
|
||
|
assert.NotNil(t, sess)
|
||
|
assert.EqualValues(t, sess, raw)
|
||
|
|
||
|
assert.NotEqualValues(t, initialSid, sess.ID())
|
||
|
|
||
|
uname := sess.Get("uname")
|
||
|
assert.NotNil(t, uname)
|
||
|
assert.EqualValues(t, "unknwon", uname)
|
||
|
|
||
|
assert.NoError(t, sess.Set("uname", "lunny"))
|
||
|
uname = sess.Get("uname")
|
||
|
assert.NotNil(t, uname)
|
||
|
assert.EqualValues(t, "lunny", uname)
|
||
|
})
|
||
|
c.Get("/get", func(resp http.ResponseWriter, req *http.Request) {
|
||
|
sess := GetSession(req)
|
||
|
sid := sess.ID()
|
||
|
assert.NotEmpty(t, sid)
|
||
|
|
||
|
raw, err := sess.Read(sid)
|
||
|
assert.NoError(t, err)
|
||
|
assert.NotNil(t, raw)
|
||
|
|
||
|
uname := sess.Get("uname")
|
||
|
assert.NotNil(t, uname)
|
||
|
assert.EqualValues(t, "lunny", uname)
|
||
|
|
||
|
assert.NoError(t, sess.Delete("uname"))
|
||
|
assert.Nil(t, sess.Get("uname"))
|
||
|
|
||
|
assert.NoError(t, sess.Destroy(resp, req))
|
||
|
})
|
||
|
|
||
|
resp := httptest.NewRecorder()
|
||
|
req, err := http.NewRequest("GET", "/", nil)
|
||
|
require.NoError(t, err)
|
||
|
c.ServeHTTP(resp, req)
|
||
|
|
||
|
cookie := resp.Header().Get("Set-Cookie")
|
||
|
|
||
|
resp = httptest.NewRecorder()
|
||
|
req, err = http.NewRequest("GET", "/reg", nil)
|
||
|
require.NoError(t, err)
|
||
|
req.Header.Set("Cookie", cookie)
|
||
|
c.ServeHTTP(resp, req)
|
||
|
|
||
|
cookie = resp.Header().Get("Set-Cookie")
|
||
|
|
||
|
resp = httptest.NewRecorder()
|
||
|
req, err = http.NewRequest("GET", "/get", nil)
|
||
|
require.NoError(t, err)
|
||
|
req.Header.Set("Cookie", cookie)
|
||
|
c.ServeHTTP(resp, req)
|
||
|
})
|
||
|
|
||
|
t.Run("Regenerate empty session", func(t *testing.T) {
|
||
|
c := chi.NewRouter()
|
||
|
c.Use(Sessioner(opt))
|
||
|
c.Get("/", func(resp http.ResponseWriter, req *http.Request) {
|
||
|
sess := GetSession(req)
|
||
|
raw, err := sess.RegenerateID(resp, req)
|
||
|
assert.NoError(t, err)
|
||
|
assert.NotNil(t, raw)
|
||
|
})
|
||
|
|
||
|
resp := httptest.NewRecorder()
|
||
|
req, err := http.NewRequest("GET", "/", nil)
|
||
|
require.NoError(t, err)
|
||
|
req.Header.Set("Cookie", "MacaronSession=ad2c7e3cbecfcf48; Path=/;")
|
||
|
c.ServeHTTP(resp, req)
|
||
|
})
|
||
|
|
||
|
t.Run("GC session", func(t *testing.T) {
|
||
|
c := chi.NewRouter()
|
||
|
opt2 := opt
|
||
|
opt2.Gclifetime = 1
|
||
|
c.Use(Sessioner(opt2))
|
||
|
|
||
|
c.Get("/", func(_ http.ResponseWriter, req *http.Request) {
|
||
|
sess := GetSession(req)
|
||
|
assert.NoError(t, sess.Set("uname", "unknwon"))
|
||
|
assert.NotEmpty(t, sess.ID())
|
||
|
uname := sess.Get("uname")
|
||
|
assert.NotNil(t, uname)
|
||
|
assert.EqualValues(t, "unknwon", uname)
|
||
|
|
||
|
assert.NoError(t, sess.Flush())
|
||
|
assert.Nil(t, sess.Get("uname"))
|
||
|
|
||
|
time.Sleep(2 * time.Second)
|
||
|
sess.GC()
|
||
|
assert.Zero(t, sess.Count())
|
||
|
})
|
||
|
|
||
|
resp := httptest.NewRecorder()
|
||
|
req, err := http.NewRequest("GET", "/", nil)
|
||
|
require.NoError(t, err)
|
||
|
c.ServeHTTP(resp, req)
|
||
|
})
|
||
|
t.Run("Detect invalid sid", func(t *testing.T) {
|
||
|
c := chi.NewRouter()
|
||
|
c.Use(Sessioner(opt))
|
||
|
c.Get("/", func(_ http.ResponseWriter, req *http.Request) {
|
||
|
sess := GetSession(req)
|
||
|
raw, err := sess.Read("../session/ad2c7e3cbecfcf486")
|
||
|
assert.Contains(t, err.Error(), "invalid 'sid'")
|
||
|
assert.Nil(t, raw)
|
||
|
})
|
||
|
|
||
|
resp := httptest.NewRecorder()
|
||
|
req, err := http.NewRequest("GET", "/", nil)
|
||
|
require.NoError(t, err)
|
||
|
c.ServeHTTP(resp, req)
|
||
|
})
|
||
|
}
|