1
0
Fork 0

Adding upstream version 2.52.6.

Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
Daniel Baumann 2025-05-17 06:50:16 +02:00
parent a960158181
commit 6d002e9543
Signed by: daniel
GPG key ID: FBB4F0E80A80222F
441 changed files with 95392 additions and 0 deletions

View file

@ -0,0 +1,128 @@
package session
import (
"fmt"
"strings"
"time"
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/log"
"github.com/gofiber/fiber/v2/utils"
)
// Config defines the config for middleware.
type Config struct {
// Allowed session duration
// Optional. Default value 24 * time.Hour
Expiration time.Duration
// Storage interface to store the session data
// Optional. Default value memory.New()
Storage fiber.Storage
// KeyLookup is a string in the form of "<source>:<name>" that is used
// to extract session id from the request.
// Possible values: "header:<name>", "query:<name>" or "cookie:<name>"
// Optional. Default value "cookie:session_id".
KeyLookup string
// Domain of the cookie.
// Optional. Default value "".
CookieDomain string
// Path of the cookie.
// Optional. Default value "".
CookiePath string
// Indicates if cookie is secure.
// Optional. Default value false.
CookieSecure bool
// Indicates if cookie is HTTP only.
// Optional. Default value false.
CookieHTTPOnly bool
// Value of SameSite cookie.
// Optional. Default value "Lax".
CookieSameSite string
// Decides whether cookie should last for only the browser sesison.
// Ignores Expiration if set to true
// Optional. Default value false.
CookieSessionOnly bool
// KeyGenerator generates the session key.
// Optional. Default value utils.UUIDv4
KeyGenerator func() string
// Deprecated: Please use KeyLookup
CookieName string
// Source defines where to obtain the session id
source Source
// The session name
sessionName string
}
type Source string
const (
SourceCookie Source = "cookie"
SourceHeader Source = "header"
SourceURLQuery Source = "query"
)
// ConfigDefault is the default config
var ConfigDefault = Config{
Expiration: 24 * time.Hour,
KeyLookup: "cookie:session_id",
KeyGenerator: utils.UUIDv4,
source: "cookie",
sessionName: "session_id",
}
// Helper function to set default values
func configDefault(config ...Config) Config {
// Return default config if nothing provided
if len(config) < 1 {
return ConfigDefault
}
// Override default config
cfg := config[0]
// Set default values
if int(cfg.Expiration.Seconds()) <= 0 {
cfg.Expiration = ConfigDefault.Expiration
}
if cfg.CookieName != "" {
log.Warn("[SESSION] CookieName is deprecated, please use KeyLookup")
cfg.KeyLookup = fmt.Sprintf("cookie:%s", cfg.CookieName)
}
if cfg.KeyLookup == "" {
cfg.KeyLookup = ConfigDefault.KeyLookup
}
if cfg.KeyGenerator == nil {
cfg.KeyGenerator = ConfigDefault.KeyGenerator
}
selectors := strings.Split(cfg.KeyLookup, ":")
const numSelectors = 2
if len(selectors) != numSelectors {
panic("[session] KeyLookup must in the form of <source>:<name>")
}
switch Source(selectors[0]) {
case SourceCookie:
cfg.source = SourceCookie
case SourceHeader:
cfg.source = SourceHeader
case SourceURLQuery:
cfg.source = SourceURLQuery
default:
panic("[session] source is not supported")
}
cfg.sessionName = selectors[1]
return cfg
}

View file

@ -0,0 +1,63 @@
package session
import (
"sync"
)
// go:generate msgp
// msgp -file="data.go" -o="data_msgp.go" -tests=false -unexported
type data struct {
sync.RWMutex
Data map[string]interface{}
}
var dataPool = sync.Pool{
New: func() interface{} {
d := new(data)
d.Data = make(map[string]interface{})
return d
},
}
func acquireData() *data {
return dataPool.Get().(*data) //nolint:forcetypeassert // We store nothing else in the pool
}
func (d *data) Reset() {
d.Lock()
d.Data = make(map[string]interface{})
d.Unlock()
}
func (d *data) Get(key string) interface{} {
d.RLock()
v := d.Data[key]
d.RUnlock()
return v
}
func (d *data) Set(key string, value interface{}) {
d.Lock()
d.Data[key] = value
d.Unlock()
}
func (d *data) Delete(key string) {
d.Lock()
delete(d.Data, key)
d.Unlock()
}
func (d *data) Keys() []string {
d.Lock()
keys := make([]string, 0, len(d.Data))
for k := range d.Data {
keys = append(keys, k)
}
d.Unlock()
return keys
}
func (d *data) Len() int {
return len(d.Data)
}

View file

@ -0,0 +1,311 @@
package session
import (
"bytes"
"encoding/gob"
"fmt"
"sync"
"time"
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/utils"
"github.com/valyala/fasthttp"
)
type Session struct {
mu sync.RWMutex // Mutex to protect non-data fields
id string // session id
fresh bool // if new session
ctx *fiber.Ctx // fiber context
config *Store // store configuration
data *data // key value data
byteBuffer *bytes.Buffer // byte buffer for the en- and decode
exp time.Duration // expiration of this session
}
var sessionPool = sync.Pool{
New: func() interface{} {
return new(Session)
},
}
func acquireSession() *Session {
s := sessionPool.Get().(*Session) //nolint:forcetypeassert,errcheck // We store nothing else in the pool
if s.data == nil {
s.data = acquireData()
}
if s.byteBuffer == nil {
s.byteBuffer = new(bytes.Buffer)
}
s.fresh = true
return s
}
func releaseSession(s *Session) {
s.mu.Lock()
s.id = ""
s.exp = 0
s.ctx = nil
s.config = nil
if s.data != nil {
s.data.Reset()
}
if s.byteBuffer != nil {
s.byteBuffer.Reset()
}
s.mu.Unlock()
sessionPool.Put(s)
}
// Fresh is true if the current session is new
func (s *Session) Fresh() bool {
s.mu.RLock()
defer s.mu.RUnlock()
return s.fresh
}
// ID returns the session id
func (s *Session) ID() string {
s.mu.RLock()
defer s.mu.RUnlock()
return s.id
}
// Get will return the value
func (s *Session) Get(key string) interface{} {
// Better safe than sorry
if s.data == nil {
return nil
}
return s.data.Get(key)
}
// Set will update or create a new key value
func (s *Session) Set(key string, val interface{}) {
// Better safe than sorry
if s.data == nil {
return
}
s.data.Set(key, val)
}
// Delete will delete the value
func (s *Session) Delete(key string) {
// Better safe than sorry
if s.data == nil {
return
}
s.data.Delete(key)
}
// Destroy will delete the session from Storage and expire session cookie
func (s *Session) Destroy() error {
// Better safe than sorry
if s.data == nil {
return nil
}
// Reset local data
s.data.Reset()
s.mu.RLock()
defer s.mu.RUnlock()
// Use external Storage if exist
if err := s.config.Storage.Delete(s.id); err != nil {
return err
}
// Expire session
s.delSession()
return nil
}
// Regenerate generates a new session id and delete the old one from Storage
func (s *Session) Regenerate() error {
s.mu.Lock()
defer s.mu.Unlock()
// Delete old id from storage
if err := s.config.Storage.Delete(s.id); err != nil {
return err
}
// Generate a new session, and set session.fresh to true
s.refresh()
return nil
}
// Reset generates a new session id, deletes the old one from storage, and resets the associated data
func (s *Session) Reset() error {
// Reset local data
if s.data != nil {
s.data.Reset()
}
s.mu.Lock()
defer s.mu.Unlock()
// Reset byte buffer
if s.byteBuffer != nil {
s.byteBuffer.Reset()
}
// Reset expiration
s.exp = 0
// Delete old id from storage
if err := s.config.Storage.Delete(s.id); err != nil {
return err
}
// Expire session
s.delSession()
// Generate a new session, and set session.fresh to true
s.refresh()
return nil
}
// refresh generates a new session, and set session.fresh to be true
func (s *Session) refresh() {
s.id = s.config.KeyGenerator()
s.fresh = true
}
// Save will update the storage and client cookie
//
// sess.Save() will save the session data to the storage and update the
// client cookie, and it will release the session after saving.
//
// It's not safe to use the session after calling Save().
func (s *Session) Save() error {
// Better safe than sorry
if s.data == nil {
return nil
}
s.mu.Lock()
// Check if session has your own expiration, otherwise use default value
if s.exp <= 0 {
s.exp = s.config.Expiration
}
// Update client cookie
s.setSession()
// Convert data to bytes
encCache := gob.NewEncoder(s.byteBuffer)
err := encCache.Encode(&s.data.Data)
if err != nil {
return fmt.Errorf("failed to encode data: %w", err)
}
// Copy the data in buffer
encodedBytes := make([]byte, s.byteBuffer.Len())
copy(encodedBytes, s.byteBuffer.Bytes())
// Pass copied bytes with session id to provider
if err := s.config.Storage.Set(s.id, encodedBytes, s.exp); err != nil {
return err
}
s.mu.Unlock()
// Release session
// TODO: It's not safe to use the Session after calling Save()
releaseSession(s)
return nil
}
// Keys will retrieve all keys in current session
func (s *Session) Keys() []string {
if s.data == nil {
return []string{}
}
return s.data.Keys()
}
// SetExpiry sets a specific expiration for this session
func (s *Session) SetExpiry(exp time.Duration) {
s.mu.Lock()
defer s.mu.Unlock()
s.exp = exp
}
func (s *Session) setSession() {
if s.config.source == SourceHeader {
s.ctx.Request().Header.SetBytesV(s.config.sessionName, []byte(s.id))
s.ctx.Response().Header.SetBytesV(s.config.sessionName, []byte(s.id))
} else {
fcookie := fasthttp.AcquireCookie()
fcookie.SetKey(s.config.sessionName)
fcookie.SetValue(s.id)
fcookie.SetPath(s.config.CookiePath)
fcookie.SetDomain(s.config.CookieDomain)
// Cookies are also session cookies if they do not specify the Expires or Max-Age attribute.
// refer: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Set-Cookie
if !s.config.CookieSessionOnly {
fcookie.SetMaxAge(int(s.exp.Seconds()))
fcookie.SetExpire(time.Now().Add(s.exp))
}
fcookie.SetSecure(s.config.CookieSecure)
fcookie.SetHTTPOnly(s.config.CookieHTTPOnly)
switch utils.ToLower(s.config.CookieSameSite) {
case "strict":
fcookie.SetSameSite(fasthttp.CookieSameSiteStrictMode)
case "none":
fcookie.SetSameSite(fasthttp.CookieSameSiteNoneMode)
default:
fcookie.SetSameSite(fasthttp.CookieSameSiteLaxMode)
}
s.ctx.Response().Header.SetCookie(fcookie)
fasthttp.ReleaseCookie(fcookie)
}
}
func (s *Session) delSession() {
if s.config.source == SourceHeader {
s.ctx.Request().Header.Del(s.config.sessionName)
s.ctx.Response().Header.Del(s.config.sessionName)
} else {
s.ctx.Request().Header.DelCookie(s.config.sessionName)
s.ctx.Response().Header.DelCookie(s.config.sessionName)
fcookie := fasthttp.AcquireCookie()
fcookie.SetKey(s.config.sessionName)
fcookie.SetPath(s.config.CookiePath)
fcookie.SetDomain(s.config.CookieDomain)
fcookie.SetMaxAge(-1)
fcookie.SetExpire(time.Now().Add(-1 * time.Minute))
fcookie.SetSecure(s.config.CookieSecure)
fcookie.SetHTTPOnly(s.config.CookieHTTPOnly)
switch utils.ToLower(s.config.CookieSameSite) {
case "strict":
fcookie.SetSameSite(fasthttp.CookieSameSiteStrictMode)
case "none":
fcookie.SetSameSite(fasthttp.CookieSameSiteNoneMode)
default:
fcookie.SetSameSite(fasthttp.CookieSameSiteLaxMode)
}
s.ctx.Response().Header.SetCookie(fcookie)
fasthttp.ReleaseCookie(fcookie)
}
}
// decodeSessionData decodes the session data from raw bytes.
func (s *Session) decodeSessionData(rawData []byte) error {
_, _ = s.byteBuffer.Write(rawData) //nolint:errcheck // This will never fail
encCache := gob.NewDecoder(s.byteBuffer)
if err := encCache.Decode(&s.data.Data); err != nil {
return fmt.Errorf("failed to decode session data: %w", err)
}
return nil
}

View file

@ -0,0 +1,904 @@
package session
import (
"errors"
"sync"
"testing"
"time"
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/internal/storage/memory"
"github.com/gofiber/fiber/v2/utils"
"github.com/valyala/fasthttp"
)
// go test -run Test_Session
func Test_Session(t *testing.T) {
t.Parallel()
// session store
store := New()
// fiber instance
app := fiber.New()
// fiber context
ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
defer app.ReleaseCtx(ctx)
// Get a new session
sess, err := store.Get(ctx)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, true, sess.Fresh())
token := sess.ID()
err = sess.Save()
utils.AssertEqual(t, nil, err)
app.ReleaseCtx(ctx)
ctx = app.AcquireCtx(&fasthttp.RequestCtx{})
// set session
ctx.Request().Header.SetCookie(store.sessionName, token)
// get session
sess, err = store.Get(ctx)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, false, sess.Fresh())
// get keys
keys := sess.Keys()
utils.AssertEqual(t, []string{}, keys)
// get value
name := sess.Get("name")
utils.AssertEqual(t, nil, name)
// set value
sess.Set("name", "john")
// get value
name = sess.Get("name")
utils.AssertEqual(t, "john", name)
keys = sess.Keys()
utils.AssertEqual(t, []string{"name"}, keys)
// delete key
sess.Delete("name")
// get value
name = sess.Get("name")
utils.AssertEqual(t, nil, name)
// get keys
keys = sess.Keys()
utils.AssertEqual(t, []string{}, keys)
// get id
id := sess.ID()
utils.AssertEqual(t, token, id)
// save the old session first
err = sess.Save()
utils.AssertEqual(t, nil, err)
app.ReleaseCtx(ctx)
// requesting entirely new context to prevent falsy tests
ctx = app.AcquireCtx(&fasthttp.RequestCtx{})
defer app.ReleaseCtx(ctx)
sess, err = store.Get(ctx)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, true, sess.Fresh())
// this id should be randomly generated as session key was deleted
utils.AssertEqual(t, 36, len(sess.ID()))
// when we use the original session for the second time
// the session be should be same if the session is not expired
ctx = app.AcquireCtx(&fasthttp.RequestCtx{})
defer app.ReleaseCtx(ctx)
// request the server with the old session
ctx.Request().Header.SetCookie(store.sessionName, id)
sess, err = store.Get(ctx)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, false, sess.Fresh())
utils.AssertEqual(t, sess.id, id)
}
// go test -run Test_Session_Types
//
//nolint:forcetypeassert // TODO: Do not force-type assert
func Test_Session_Types(t *testing.T) {
t.Parallel()
// session store
store := New()
// fiber instance
app := fiber.New()
// fiber context
ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
// set cookie
ctx.Request().Header.SetCookie(store.sessionName, "123")
// get session
sess, err := store.Get(ctx)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, true, sess.Fresh())
// the session string is no longer be 123
newSessionIDString := sess.ID()
type User struct {
Name string
}
store.RegisterType(User{})
vuser := User{
Name: "John",
}
// set value
var (
vbool = true
vstring = "str"
vint = 13
vint8 int8 = 13
vint16 int16 = 13
vint32 int32 = 13
vint64 int64 = 13
vuint uint = 13
vuint8 uint8 = 13
vuint16 uint16 = 13
vuint32 uint32 = 13
vuint64 uint64 = 13
vuintptr uintptr = 13
vbyte byte = 'k'
vrune = 'k'
vfloat32 float32 = 13
vfloat64 float64 = 13
vcomplex64 complex64 = 13
vcomplex128 complex128 = 13
)
sess.Set("vuser", vuser)
sess.Set("vbool", vbool)
sess.Set("vstring", vstring)
sess.Set("vint", vint)
sess.Set("vint8", vint8)
sess.Set("vint16", vint16)
sess.Set("vint32", vint32)
sess.Set("vint64", vint64)
sess.Set("vuint", vuint)
sess.Set("vuint8", vuint8)
sess.Set("vuint16", vuint16)
sess.Set("vuint32", vuint32)
sess.Set("vuint32", vuint32)
sess.Set("vuint64", vuint64)
sess.Set("vuintptr", vuintptr)
sess.Set("vbyte", vbyte)
sess.Set("vrune", vrune)
sess.Set("vfloat32", vfloat32)
sess.Set("vfloat64", vfloat64)
sess.Set("vcomplex64", vcomplex64)
sess.Set("vcomplex128", vcomplex128)
// save session
err = sess.Save()
utils.AssertEqual(t, nil, err)
app.ReleaseCtx(ctx)
ctx = app.AcquireCtx(&fasthttp.RequestCtx{})
ctx.Request().Header.SetCookie(store.sessionName, newSessionIDString)
// get session
sess, err = store.Get(ctx)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, false, sess.Fresh())
// get value
utils.AssertEqual(t, vuser, sess.Get("vuser").(User))
utils.AssertEqual(t, vbool, sess.Get("vbool").(bool))
utils.AssertEqual(t, vstring, sess.Get("vstring").(string))
utils.AssertEqual(t, vint, sess.Get("vint").(int))
utils.AssertEqual(t, vint8, sess.Get("vint8").(int8))
utils.AssertEqual(t, vint16, sess.Get("vint16").(int16))
utils.AssertEqual(t, vint32, sess.Get("vint32").(int32))
utils.AssertEqual(t, vint64, sess.Get("vint64").(int64))
utils.AssertEqual(t, vuint, sess.Get("vuint").(uint))
utils.AssertEqual(t, vuint8, sess.Get("vuint8").(uint8))
utils.AssertEqual(t, vuint16, sess.Get("vuint16").(uint16))
utils.AssertEqual(t, vuint32, sess.Get("vuint32").(uint32))
utils.AssertEqual(t, vuint64, sess.Get("vuint64").(uint64))
utils.AssertEqual(t, vuintptr, sess.Get("vuintptr").(uintptr))
utils.AssertEqual(t, vbyte, sess.Get("vbyte").(byte))
utils.AssertEqual(t, vrune, sess.Get("vrune").(rune))
utils.AssertEqual(t, vfloat32, sess.Get("vfloat32").(float32))
utils.AssertEqual(t, vfloat64, sess.Get("vfloat64").(float64))
utils.AssertEqual(t, vcomplex64, sess.Get("vcomplex64").(complex64))
utils.AssertEqual(t, vcomplex128, sess.Get("vcomplex128").(complex128))
app.ReleaseCtx(ctx)
}
// go test -run Test_Session_Store_Reset
func Test_Session_Store_Reset(t *testing.T) {
t.Parallel()
// session store
store := New()
// fiber instance
app := fiber.New()
// fiber context
ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
// get session
sess, err := store.Get(ctx)
utils.AssertEqual(t, nil, err)
// make sure its new
utils.AssertEqual(t, true, sess.Fresh())
// set value & save
sess.Set("hello", "world")
ctx.Request().Header.SetCookie(store.sessionName, sess.ID())
utils.AssertEqual(t, nil, sess.Save())
// reset store
utils.AssertEqual(t, nil, store.Reset())
id := sess.ID()
app.ReleaseCtx(ctx)
ctx = app.AcquireCtx(&fasthttp.RequestCtx{})
defer app.ReleaseCtx(ctx)
ctx.Request().Header.SetCookie(store.sessionName, id)
// make sure the session is recreated
sess, err = store.Get(ctx)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, true, sess.Fresh())
utils.AssertEqual(t, nil, sess.Get("hello"))
}
// go test -run Test_Session_Save
func Test_Session_Save(t *testing.T) {
t.Parallel()
t.Run("save to cookie", func(t *testing.T) {
// session store
store := New()
// fiber instance
app := fiber.New()
// fiber context
ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
defer app.ReleaseCtx(ctx)
// get session
sess, err := store.Get(ctx)
utils.AssertEqual(t, nil, err)
// set value
sess.Set("name", "john")
// save session
err = sess.Save()
utils.AssertEqual(t, nil, err)
})
t.Run("save to header", func(t *testing.T) {
// session store
store := New(Config{
KeyLookup: "header:session_id",
})
// fiber instance
app := fiber.New()
// fiber context
ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
defer app.ReleaseCtx(ctx)
// get session
sess, err := store.Get(ctx)
utils.AssertEqual(t, nil, err)
// set value
sess.Set("name", "john")
// save session
err = sess.Save()
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, store.getSessionID(ctx), string(ctx.Response().Header.Peek(store.sessionName)))
utils.AssertEqual(t, store.getSessionID(ctx), string(ctx.Request().Header.Peek(store.sessionName)))
})
}
func Test_Session_Save_Expiration(t *testing.T) {
t.Parallel()
t.Run("save to cookie", func(t *testing.T) {
const sessionDuration = 5 * time.Second
t.Parallel()
// session store
store := New()
// fiber instance
app := fiber.New()
// fiber context
ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
defer app.ReleaseCtx(ctx)
// get session
sess, err := store.Get(ctx)
utils.AssertEqual(t, nil, err)
// set value
sess.Set("name", "john")
token := sess.ID()
// expire this session in 5 seconds
sess.SetExpiry(sessionDuration)
// save session
err = sess.Save()
utils.AssertEqual(t, nil, err)
app.ReleaseCtx(ctx)
ctx = app.AcquireCtx(&fasthttp.RequestCtx{})
// here you need to get the old session yet
ctx.Request().Header.SetCookie(store.sessionName, token)
sess, err = store.Get(ctx)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, "john", sess.Get("name"))
// just to make sure the session has been expired
time.Sleep(sessionDuration + (10 * time.Millisecond))
app.ReleaseCtx(ctx)
ctx = app.AcquireCtx(&fasthttp.RequestCtx{})
defer app.ReleaseCtx(ctx)
// here you should get a new session
ctx.Request().Header.SetCookie(store.sessionName, token)
sess, err = store.Get(ctx)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, nil, sess.Get("name"))
utils.AssertEqual(t, true, sess.ID() != token)
})
}
// go test -run Test_Session_Destroy
func Test_Session_Destroy(t *testing.T) {
t.Parallel()
t.Run("destroy from cookie", func(t *testing.T) {
t.Parallel()
// session store
store := New()
// fiber instance
app := fiber.New()
// fiber context
ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
defer app.ReleaseCtx(ctx)
// get session
sess, err := store.Get(ctx)
utils.AssertEqual(t, nil, err)
sess.Set("name", "fenny")
utils.AssertEqual(t, nil, sess.Destroy())
name := sess.Get("name")
utils.AssertEqual(t, nil, name)
})
t.Run("destroy from header", func(t *testing.T) {
t.Parallel()
// session store
store := New(Config{
KeyLookup: "header:session_id",
})
// fiber instance
app := fiber.New()
// fiber context
ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
defer app.ReleaseCtx(ctx)
// get session
sess, err := store.Get(ctx)
utils.AssertEqual(t, nil, err)
// set value & save
sess.Set("name", "fenny")
id := sess.ID()
utils.AssertEqual(t, nil, sess.Save())
app.ReleaseCtx(ctx)
ctx = app.AcquireCtx(&fasthttp.RequestCtx{})
defer app.ReleaseCtx(ctx)
// get session
ctx.Request().Header.Set(store.sessionName, id)
sess, err = store.Get(ctx)
utils.AssertEqual(t, nil, err)
err = sess.Destroy()
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, "", string(ctx.Response().Header.Peek(store.sessionName)))
utils.AssertEqual(t, "", string(ctx.Request().Header.Peek(store.sessionName)))
})
}
// go test -run Test_Session_Custom_Config
func Test_Session_Custom_Config(t *testing.T) {
t.Parallel()
store := New(Config{Expiration: time.Hour, KeyGenerator: func() string { return "very random" }})
utils.AssertEqual(t, time.Hour, store.Expiration)
utils.AssertEqual(t, "very random", store.KeyGenerator())
store = New(Config{Expiration: 0})
utils.AssertEqual(t, ConfigDefault.Expiration, store.Expiration)
}
// go test -run Test_Session_Cookie
func Test_Session_Cookie(t *testing.T) {
t.Parallel()
// session store
store := New()
// fiber instance
app := fiber.New()
// fiber context
ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
defer app.ReleaseCtx(ctx)
// get session
sess, err := store.Get(ctx)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, nil, sess.Save())
// cookie should be set on Save ( even if empty data )
utils.AssertEqual(t, 84, len(ctx.Response().Header.PeekCookie(store.sessionName)))
}
// go test -run Test_Session_Cookie_In_Response
// Regression: https://github.com/gofiber/fiber/pull/1191
func Test_Session_Cookie_In_Middleware_Chain(t *testing.T) {
t.Parallel()
store := New()
app := fiber.New()
// fiber context
ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
defer app.ReleaseCtx(ctx)
// get session
sess, err := store.Get(ctx)
utils.AssertEqual(t, nil, err)
sess.Set("id", "1")
id := sess.ID()
utils.AssertEqual(t, true, sess.Fresh())
utils.AssertEqual(t, nil, sess.Save())
sess, err = store.Get(ctx)
utils.AssertEqual(t, nil, err)
sess.Set("name", "john")
utils.AssertEqual(t, true, sess.Fresh())
utils.AssertEqual(t, id, sess.ID()) // session id should be the same
utils.AssertEqual(t, sess.ID() != "1", true)
utils.AssertEqual(t, "john", sess.Get("name"))
}
// go test -run Test_Session_Deletes_Single_Key
// Regression: https://github.com/gofiber/fiber/issues/1365
func Test_Session_Deletes_Single_Key(t *testing.T) {
t.Parallel()
store := New()
app := fiber.New()
ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
sess, err := store.Get(ctx)
utils.AssertEqual(t, nil, err)
id := sess.ID()
sess.Set("id", "1")
utils.AssertEqual(t, nil, sess.Save())
app.ReleaseCtx(ctx)
ctx = app.AcquireCtx(&fasthttp.RequestCtx{})
ctx.Request().Header.SetCookie(store.sessionName, id)
sess, err = store.Get(ctx)
utils.AssertEqual(t, nil, err)
sess.Delete("id")
utils.AssertEqual(t, nil, sess.Save())
app.ReleaseCtx(ctx)
ctx = app.AcquireCtx(&fasthttp.RequestCtx{})
ctx.Request().Header.SetCookie(store.sessionName, id)
sess, err = store.Get(ctx)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, false, sess.Fresh())
utils.AssertEqual(t, nil, sess.Get("id"))
app.ReleaseCtx(ctx)
}
// go test -run Test_Session_Reset
func Test_Session_Reset(t *testing.T) {
t.Parallel()
// fiber instance
app := fiber.New()
// session store
store := New()
// fiber context
ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
defer app.ReleaseCtx(ctx)
t.Run("reset session data and id, and set fresh to be true", func(t *testing.T) {
t.Parallel()
// fiber context
ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
// a random session uuid
originalSessionUUIDString := ""
// now the session is in the storage
freshSession, err := store.Get(ctx)
utils.AssertEqual(t, nil, err)
originalSessionUUIDString = freshSession.ID()
// set a value
freshSession.Set("name", "fenny")
freshSession.Set("email", "fenny@example.com")
err = freshSession.Save()
utils.AssertEqual(t, nil, err)
app.ReleaseCtx(ctx)
ctx = app.AcquireCtx(&fasthttp.RequestCtx{})
// set cookie
ctx.Request().Header.SetCookie(store.sessionName, originalSessionUUIDString)
// as the session is in the storage, session.fresh should be false
acquiredSession, err := store.Get(ctx)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, false, acquiredSession.Fresh())
err = acquiredSession.Reset()
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, false, acquiredSession.ID() == originalSessionUUIDString)
utils.AssertEqual(t, false, acquiredSession.ID() == "")
// acquiredSession.fresh should be true after resetting
utils.AssertEqual(t, true, acquiredSession.Fresh())
// Check that the session data has been reset
keys := acquiredSession.Keys()
utils.AssertEqual(t, []string{}, keys)
// Set a new value for 'name' and check that it's updated
acquiredSession.Set("name", "john")
utils.AssertEqual(t, "john", acquiredSession.Get("name"))
utils.AssertEqual(t, nil, acquiredSession.Get("email"))
// Save after resetting
err = acquiredSession.Save()
utils.AssertEqual(t, nil, err)
// Check that the session id is not in the header or cookie anymore
utils.AssertEqual(t, "", string(ctx.Response().Header.Peek(store.sessionName)))
utils.AssertEqual(t, "", string(ctx.Request().Header.Peek(store.sessionName)))
app.ReleaseCtx(ctx)
})
}
// go test -run Test_Session_Regenerate
// Regression: https://github.com/gofiber/fiber/issues/1395
func Test_Session_Regenerate(t *testing.T) {
t.Parallel()
// fiber instance
app := fiber.New()
t.Run("set fresh to be true when regenerating a session", func(t *testing.T) {
// session store
store := New()
// a random session uuid
originalSessionUUIDString := ""
// fiber context
ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
defer app.ReleaseCtx(ctx)
// now the session is in the storage
freshSession, err := store.Get(ctx)
utils.AssertEqual(t, nil, err)
originalSessionUUIDString = freshSession.ID()
err = freshSession.Save()
utils.AssertEqual(t, nil, err)
// release the context
app.ReleaseCtx(ctx)
// acquire a new context
ctx = app.AcquireCtx(&fasthttp.RequestCtx{})
// set cookie
ctx.Request().Header.SetCookie(store.sessionName, originalSessionUUIDString)
// as the session is in the storage, session.fresh should be false
acquiredSession, err := store.Get(ctx)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, false, acquiredSession.Fresh())
err = acquiredSession.Regenerate()
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, false, acquiredSession.ID() == originalSessionUUIDString)
// acquiredSession.fresh should be true after regenerating
utils.AssertEqual(t, true, acquiredSession.Fresh())
// release the context
app.ReleaseCtx(ctx)
})
}
// go test -v -run=^$ -bench=Benchmark_Session -benchmem -count=4
func Benchmark_Session(b *testing.B) {
app, store := fiber.New(), New()
c := app.AcquireCtx(&fasthttp.RequestCtx{})
defer app.ReleaseCtx(c)
c.Request().Header.SetCookie(store.sessionName, "12356789")
var err error
b.Run("default", func(b *testing.B) {
b.ReportAllocs()
b.ResetTimer()
for n := 0; n < b.N; n++ {
sess, _ := store.Get(c) //nolint:errcheck // We're inside a benchmark
sess.Set("john", "doe")
err = sess.Save()
}
utils.AssertEqual(b, nil, err)
})
b.Run("storage", func(b *testing.B) {
store = New(Config{
Storage: memory.New(),
})
b.ReportAllocs()
b.ResetTimer()
for n := 0; n < b.N; n++ {
sess, _ := store.Get(c) //nolint:errcheck // We're inside a benchmark
sess.Set("john", "doe")
err = sess.Save()
}
utils.AssertEqual(b, nil, err)
})
}
// go test -v -run=^$ -bench=Benchmark_Session_Parallel -benchmem -count=4
func Benchmark_Session_Parallel(b *testing.B) {
b.Run("default", func(b *testing.B) {
app, store := fiber.New(), New()
b.ReportAllocs()
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
c := app.AcquireCtx(&fasthttp.RequestCtx{})
c.Request().Header.SetCookie(store.sessionName, "12356789")
sess, _ := store.Get(c) //nolint:errcheck // We're inside a benchmark
sess.Set("john", "doe")
_ = sess.Save() //nolint:errcheck // We're inside a benchmark
app.ReleaseCtx(c)
}
})
})
b.Run("storage", func(b *testing.B) {
app := fiber.New()
store := New(Config{
Storage: memory.New(),
})
b.ReportAllocs()
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
c := app.AcquireCtx(&fasthttp.RequestCtx{})
c.Request().Header.SetCookie(store.sessionName, "12356789")
sess, _ := store.Get(c) //nolint:errcheck // We're inside a benchmark
sess.Set("john", "doe")
_ = sess.Save() //nolint:errcheck // We're inside a benchmark
app.ReleaseCtx(c)
}
})
})
}
// go test -v -run=^$ -bench=Benchmark_Session_Asserted -benchmem -count=4
func Benchmark_Session_Asserted(b *testing.B) {
b.Run("default", func(b *testing.B) {
app, store := fiber.New(), New()
c := app.AcquireCtx(&fasthttp.RequestCtx{})
defer app.ReleaseCtx(c)
c.Request().Header.SetCookie(store.sessionName, "12356789")
b.ReportAllocs()
b.ResetTimer()
for n := 0; n < b.N; n++ {
sess, err := store.Get(c)
utils.AssertEqual(b, nil, err)
sess.Set("john", "doe")
err = sess.Save()
utils.AssertEqual(b, nil, err)
}
})
b.Run("storage", func(b *testing.B) {
app := fiber.New()
store := New(Config{
Storage: memory.New(),
})
c := app.AcquireCtx(&fasthttp.RequestCtx{})
defer app.ReleaseCtx(c)
c.Request().Header.SetCookie(store.sessionName, "12356789")
b.ReportAllocs()
b.ResetTimer()
for n := 0; n < b.N; n++ {
sess, err := store.Get(c)
utils.AssertEqual(b, nil, err)
sess.Set("john", "doe")
err = sess.Save()
utils.AssertEqual(b, nil, err)
}
})
}
// go test -v -run=^$ -bench=Benchmark_Session_Asserted_Parallel -benchmem -count=4
func Benchmark_Session_Asserted_Parallel(b *testing.B) {
b.Run("default", func(b *testing.B) {
app, store := fiber.New(), New()
b.ReportAllocs()
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
c := app.AcquireCtx(&fasthttp.RequestCtx{})
c.Request().Header.SetCookie(store.sessionName, "12356789")
sess, err := store.Get(c)
utils.AssertEqual(b, nil, err)
sess.Set("john", "doe")
utils.AssertEqual(b, nil, sess.Save())
app.ReleaseCtx(c)
}
})
})
b.Run("storage", func(b *testing.B) {
app := fiber.New()
store := New(Config{
Storage: memory.New(),
})
b.ReportAllocs()
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
c := app.AcquireCtx(&fasthttp.RequestCtx{})
c.Request().Header.SetCookie(store.sessionName, "12356789")
sess, err := store.Get(c)
utils.AssertEqual(b, nil, err)
sess.Set("john", "doe")
utils.AssertEqual(b, nil, sess.Save())
app.ReleaseCtx(c)
}
})
})
}
// go test -v -race -run Test_Session_Concurrency ./...
func Test_Session_Concurrency(t *testing.T) {
t.Parallel()
app := fiber.New()
store := New()
var wg sync.WaitGroup
errChan := make(chan error, 10) // Buffered channel to collect errors
const numGoroutines = 10 // Number of concurrent goroutines to test
// Start numGoroutines goroutines
for i := 0; i < numGoroutines; i++ {
wg.Add(1)
go func() {
defer wg.Done()
localCtx := app.AcquireCtx(&fasthttp.RequestCtx{})
sess, err := store.Get(localCtx)
if err != nil {
errChan <- err
return
}
// Set a value
sess.Set("name", "john")
// get the session id
id := sess.ID()
// Check if the session is fresh
if !sess.Fresh() {
errChan <- errors.New("session should be fresh")
return
}
// Save the session
if err := sess.Save(); err != nil {
errChan <- err
return
}
// Release the context
app.ReleaseCtx(localCtx)
// Acquire a new context
localCtx = app.AcquireCtx(&fasthttp.RequestCtx{})
defer app.ReleaseCtx(localCtx)
// Set the session id in the header
localCtx.Request().Header.SetCookie(store.sessionName, id)
// Get the session
sess, err = store.Get(localCtx)
if err != nil {
errChan <- err
return
}
// Get the value
name := sess.Get("name")
if name != "john" {
errChan <- errors.New("name should be john")
return
}
// Get ID from the session
if sess.ID() != id {
errChan <- errors.New("id should be the same")
return
}
// Check if the session is fresh
if sess.Fresh() {
errChan <- errors.New("session should not be fresh")
return
}
// Delete the key
sess.Delete("name")
// Get the value
name = sess.Get("name")
if name != nil {
errChan <- errors.New("name should be nil")
return
}
// Destroy the session
if err := sess.Destroy(); err != nil {
errChan <- err
return
}
}()
}
wg.Wait() // Wait for all goroutines to finish
close(errChan) // Close the channel to signal no more errors will be sent
// Check for errors sent to errChan
for err := range errChan {
utils.AssertEqual(t, nil, err)
}
}

136
middleware/session/store.go Normal file
View file

@ -0,0 +1,136 @@
package session
import (
"encoding/gob"
"errors"
"fmt"
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/internal/storage/memory"
"github.com/gofiber/fiber/v2/utils"
)
// ErrEmptySessionID is an error that occurs when the session ID is empty.
var ErrEmptySessionID = errors.New("session id cannot be empty")
// sessionIDKey is the local key type used to store and retrieve the session ID in context.
type sessionIDKey int
const (
// sessionIDContextKey is the key used to store the session ID in the context locals.
sessionIDContextKey sessionIDKey = iota
)
type Store struct {
Config
}
// New creates a new session store with the provided configuration.
func New(config ...Config) *Store {
// Set default config
cfg := configDefault(config...)
if cfg.Storage == nil {
cfg.Storage = memory.New()
}
return &Store{
cfg,
}
}
// RegisterType registers a custom type for encoding/decoding into any storage provider.
func (*Store) RegisterType(i interface{}) {
gob.Register(i)
}
// Get retrieves or creates a session for the given context.
func (s *Store) Get(c *fiber.Ctx) (*Session, error) {
var rawData []byte
var err error
id, ok := c.Locals(sessionIDContextKey).(string)
if !ok {
id = s.getSessionID(c)
}
fresh := ok // Assume the session is fresh if the ID is found in locals
// Attempt to fetch session data if an ID is provided
if id != "" {
rawData, err = s.Storage.Get(id)
if err != nil {
return nil, err
}
if rawData == nil {
// Data not found, prepare to generate a new session
id = ""
}
}
// Generate a new ID if needed
if id == "" {
fresh = true // The session is fresh if a new ID is generated
id = s.KeyGenerator()
c.Locals(sessionIDContextKey, id)
}
// Create session object
sess := acquireSession()
sess.mu.Lock()
defer sess.mu.Unlock()
sess.ctx = c
sess.config = s
sess.id = id
sess.fresh = fresh
// Decode session data if found
if rawData != nil {
sess.data.Lock()
defer sess.data.Unlock()
if err := sess.decodeSessionData(rawData); err != nil {
return nil, fmt.Errorf("failed to decode session data: %w", err)
}
}
return sess, nil
}
// getSessionID returns the session ID from cookies, headers, or query string.
func (s *Store) getSessionID(c *fiber.Ctx) string {
id := c.Cookies(s.sessionName)
if len(id) > 0 {
return utils.CopyString(id)
}
if s.source == SourceHeader {
id = string(c.Request().Header.Peek(s.sessionName))
if len(id) > 0 {
return id
}
}
if s.source == SourceURLQuery {
id = c.Query(s.sessionName)
if len(id) > 0 {
return utils.CopyString(id)
}
}
return ""
}
// Reset deletes all sessions from the storage.
func (s *Store) Reset() error {
return s.Storage.Reset()
}
// Delete deletes a session by its ID.
func (s *Store) Delete(id string) error {
if id == "" {
return ErrEmptySessionID
}
return s.Storage.Delete(id)
}

View file

@ -0,0 +1,119 @@
package session
import (
"fmt"
"testing"
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/utils"
"github.com/valyala/fasthttp"
)
// go test -run TestStore_getSessionID
func TestStore_getSessionID(t *testing.T) {
t.Parallel()
expectedID := "test-session-id"
// fiber instance
app := fiber.New()
t.Run("from cookie", func(t *testing.T) {
t.Parallel()
// session store
store := New()
// fiber context
ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
defer app.ReleaseCtx(ctx)
// set cookie
ctx.Request().Header.SetCookie(store.sessionName, expectedID)
utils.AssertEqual(t, expectedID, store.getSessionID(ctx))
})
t.Run("from header", func(t *testing.T) {
t.Parallel()
// session store
store := New(Config{
KeyLookup: "header:session_id",
})
// fiber context
ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
defer app.ReleaseCtx(ctx)
// set header
ctx.Request().Header.Set(store.sessionName, expectedID)
utils.AssertEqual(t, expectedID, store.getSessionID(ctx))
})
t.Run("from url query", func(t *testing.T) {
t.Parallel()
// session store
store := New(Config{
KeyLookup: "query:session_id",
})
// fiber context
ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
defer app.ReleaseCtx(ctx)
// set url parameter
ctx.Request().SetRequestURI(fmt.Sprintf("/path?%s=%s", store.sessionName, expectedID))
utils.AssertEqual(t, expectedID, store.getSessionID(ctx))
})
}
// go test -run TestStore_Get
// Regression: https://github.com/gofiber/fiber/issues/1408
// Regression: https://github.com/gofiber/fiber/security/advisories/GHSA-98j2-3j3p-fw2v
func TestStore_Get(t *testing.T) {
t.Parallel()
unexpectedID := "test-session-id"
// fiber instance
app := fiber.New()
t.Run("session should be re-generated if it is invalid", func(t *testing.T) {
t.Parallel()
// session store
store := New()
// fiber context
ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
defer app.ReleaseCtx(ctx)
// set cookie
ctx.Request().Header.SetCookie(store.sessionName, unexpectedID)
acquiredSession, err := store.Get(ctx)
utils.AssertEqual(t, err, nil)
utils.AssertEqual(t, acquiredSession.ID() != unexpectedID, true)
})
}
// go test -run TestStore_DeleteSession
func TestStore_DeleteSession(t *testing.T) {
t.Parallel()
// fiber instance
app := fiber.New()
// session store
store := New()
// fiber context
ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
defer app.ReleaseCtx(ctx)
// Create a new session
session, err := store.Get(ctx)
utils.AssertEqual(t, err, nil)
// Save the session ID
sessionID := session.ID()
// Delete the session
err = store.Delete(sessionID)
utils.AssertEqual(t, err, nil)
// Try to get the session again
session, err = store.Get(ctx)
utils.AssertEqual(t, err, nil)
// The session ID should be different now, because the old session was deleted
utils.AssertEqual(t, session.ID() == sessionID, false)
}