Adding upstream version 2.52.6.
Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
parent
a960158181
commit
6d002e9543
441 changed files with 95392 additions and 0 deletions
128
middleware/session/config.go
Normal file
128
middleware/session/config.go
Normal file
|
@ -0,0 +1,128 @@
|
|||
package session
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/gofiber/fiber/v2/log"
|
||||
"github.com/gofiber/fiber/v2/utils"
|
||||
)
|
||||
|
||||
// Config defines the config for middleware.
|
||||
type Config struct {
|
||||
// Allowed session duration
|
||||
// Optional. Default value 24 * time.Hour
|
||||
Expiration time.Duration
|
||||
|
||||
// Storage interface to store the session data
|
||||
// Optional. Default value memory.New()
|
||||
Storage fiber.Storage
|
||||
|
||||
// KeyLookup is a string in the form of "<source>:<name>" that is used
|
||||
// to extract session id from the request.
|
||||
// Possible values: "header:<name>", "query:<name>" or "cookie:<name>"
|
||||
// Optional. Default value "cookie:session_id".
|
||||
KeyLookup string
|
||||
|
||||
// Domain of the cookie.
|
||||
// Optional. Default value "".
|
||||
CookieDomain string
|
||||
|
||||
// Path of the cookie.
|
||||
// Optional. Default value "".
|
||||
CookiePath string
|
||||
|
||||
// Indicates if cookie is secure.
|
||||
// Optional. Default value false.
|
||||
CookieSecure bool
|
||||
|
||||
// Indicates if cookie is HTTP only.
|
||||
// Optional. Default value false.
|
||||
CookieHTTPOnly bool
|
||||
|
||||
// Value of SameSite cookie.
|
||||
// Optional. Default value "Lax".
|
||||
CookieSameSite string
|
||||
|
||||
// Decides whether cookie should last for only the browser sesison.
|
||||
// Ignores Expiration if set to true
|
||||
// Optional. Default value false.
|
||||
CookieSessionOnly bool
|
||||
|
||||
// KeyGenerator generates the session key.
|
||||
// Optional. Default value utils.UUIDv4
|
||||
KeyGenerator func() string
|
||||
|
||||
// Deprecated: Please use KeyLookup
|
||||
CookieName string
|
||||
|
||||
// Source defines where to obtain the session id
|
||||
source Source
|
||||
|
||||
// The session name
|
||||
sessionName string
|
||||
}
|
||||
|
||||
type Source string
|
||||
|
||||
const (
|
||||
SourceCookie Source = "cookie"
|
||||
SourceHeader Source = "header"
|
||||
SourceURLQuery Source = "query"
|
||||
)
|
||||
|
||||
// ConfigDefault is the default config
|
||||
var ConfigDefault = Config{
|
||||
Expiration: 24 * time.Hour,
|
||||
KeyLookup: "cookie:session_id",
|
||||
KeyGenerator: utils.UUIDv4,
|
||||
source: "cookie",
|
||||
sessionName: "session_id",
|
||||
}
|
||||
|
||||
// Helper function to set default values
|
||||
func configDefault(config ...Config) Config {
|
||||
// Return default config if nothing provided
|
||||
if len(config) < 1 {
|
||||
return ConfigDefault
|
||||
}
|
||||
|
||||
// Override default config
|
||||
cfg := config[0]
|
||||
|
||||
// Set default values
|
||||
if int(cfg.Expiration.Seconds()) <= 0 {
|
||||
cfg.Expiration = ConfigDefault.Expiration
|
||||
}
|
||||
if cfg.CookieName != "" {
|
||||
log.Warn("[SESSION] CookieName is deprecated, please use KeyLookup")
|
||||
cfg.KeyLookup = fmt.Sprintf("cookie:%s", cfg.CookieName)
|
||||
}
|
||||
if cfg.KeyLookup == "" {
|
||||
cfg.KeyLookup = ConfigDefault.KeyLookup
|
||||
}
|
||||
if cfg.KeyGenerator == nil {
|
||||
cfg.KeyGenerator = ConfigDefault.KeyGenerator
|
||||
}
|
||||
|
||||
selectors := strings.Split(cfg.KeyLookup, ":")
|
||||
const numSelectors = 2
|
||||
if len(selectors) != numSelectors {
|
||||
panic("[session] KeyLookup must in the form of <source>:<name>")
|
||||
}
|
||||
switch Source(selectors[0]) {
|
||||
case SourceCookie:
|
||||
cfg.source = SourceCookie
|
||||
case SourceHeader:
|
||||
cfg.source = SourceHeader
|
||||
case SourceURLQuery:
|
||||
cfg.source = SourceURLQuery
|
||||
default:
|
||||
panic("[session] source is not supported")
|
||||
}
|
||||
cfg.sessionName = selectors[1]
|
||||
|
||||
return cfg
|
||||
}
|
63
middleware/session/data.go
Normal file
63
middleware/session/data.go
Normal file
|
@ -0,0 +1,63 @@
|
|||
package session
|
||||
|
||||
import (
|
||||
"sync"
|
||||
)
|
||||
|
||||
// go:generate msgp
|
||||
// msgp -file="data.go" -o="data_msgp.go" -tests=false -unexported
|
||||
type data struct {
|
||||
sync.RWMutex
|
||||
Data map[string]interface{}
|
||||
}
|
||||
|
||||
var dataPool = sync.Pool{
|
||||
New: func() interface{} {
|
||||
d := new(data)
|
||||
d.Data = make(map[string]interface{})
|
||||
return d
|
||||
},
|
||||
}
|
||||
|
||||
func acquireData() *data {
|
||||
return dataPool.Get().(*data) //nolint:forcetypeassert // We store nothing else in the pool
|
||||
}
|
||||
|
||||
func (d *data) Reset() {
|
||||
d.Lock()
|
||||
d.Data = make(map[string]interface{})
|
||||
d.Unlock()
|
||||
}
|
||||
|
||||
func (d *data) Get(key string) interface{} {
|
||||
d.RLock()
|
||||
v := d.Data[key]
|
||||
d.RUnlock()
|
||||
return v
|
||||
}
|
||||
|
||||
func (d *data) Set(key string, value interface{}) {
|
||||
d.Lock()
|
||||
d.Data[key] = value
|
||||
d.Unlock()
|
||||
}
|
||||
|
||||
func (d *data) Delete(key string) {
|
||||
d.Lock()
|
||||
delete(d.Data, key)
|
||||
d.Unlock()
|
||||
}
|
||||
|
||||
func (d *data) Keys() []string {
|
||||
d.Lock()
|
||||
keys := make([]string, 0, len(d.Data))
|
||||
for k := range d.Data {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
d.Unlock()
|
||||
return keys
|
||||
}
|
||||
|
||||
func (d *data) Len() int {
|
||||
return len(d.Data)
|
||||
}
|
311
middleware/session/session.go
Normal file
311
middleware/session/session.go
Normal file
|
@ -0,0 +1,311 @@
|
|||
package session
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/gob"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/gofiber/fiber/v2/utils"
|
||||
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
type Session struct {
|
||||
mu sync.RWMutex // Mutex to protect non-data fields
|
||||
id string // session id
|
||||
fresh bool // if new session
|
||||
ctx *fiber.Ctx // fiber context
|
||||
config *Store // store configuration
|
||||
data *data // key value data
|
||||
byteBuffer *bytes.Buffer // byte buffer for the en- and decode
|
||||
exp time.Duration // expiration of this session
|
||||
}
|
||||
|
||||
var sessionPool = sync.Pool{
|
||||
New: func() interface{} {
|
||||
return new(Session)
|
||||
},
|
||||
}
|
||||
|
||||
func acquireSession() *Session {
|
||||
s := sessionPool.Get().(*Session) //nolint:forcetypeassert,errcheck // We store nothing else in the pool
|
||||
if s.data == nil {
|
||||
s.data = acquireData()
|
||||
}
|
||||
if s.byteBuffer == nil {
|
||||
s.byteBuffer = new(bytes.Buffer)
|
||||
}
|
||||
s.fresh = true
|
||||
return s
|
||||
}
|
||||
|
||||
func releaseSession(s *Session) {
|
||||
s.mu.Lock()
|
||||
s.id = ""
|
||||
s.exp = 0
|
||||
s.ctx = nil
|
||||
s.config = nil
|
||||
if s.data != nil {
|
||||
s.data.Reset()
|
||||
}
|
||||
if s.byteBuffer != nil {
|
||||
s.byteBuffer.Reset()
|
||||
}
|
||||
s.mu.Unlock()
|
||||
sessionPool.Put(s)
|
||||
}
|
||||
|
||||
// Fresh is true if the current session is new
|
||||
func (s *Session) Fresh() bool {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
return s.fresh
|
||||
}
|
||||
|
||||
// ID returns the session id
|
||||
func (s *Session) ID() string {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
return s.id
|
||||
}
|
||||
|
||||
// Get will return the value
|
||||
func (s *Session) Get(key string) interface{} {
|
||||
// Better safe than sorry
|
||||
if s.data == nil {
|
||||
return nil
|
||||
}
|
||||
return s.data.Get(key)
|
||||
}
|
||||
|
||||
// Set will update or create a new key value
|
||||
func (s *Session) Set(key string, val interface{}) {
|
||||
// Better safe than sorry
|
||||
if s.data == nil {
|
||||
return
|
||||
}
|
||||
s.data.Set(key, val)
|
||||
}
|
||||
|
||||
// Delete will delete the value
|
||||
func (s *Session) Delete(key string) {
|
||||
// Better safe than sorry
|
||||
if s.data == nil {
|
||||
return
|
||||
}
|
||||
s.data.Delete(key)
|
||||
}
|
||||
|
||||
// Destroy will delete the session from Storage and expire session cookie
|
||||
func (s *Session) Destroy() error {
|
||||
// Better safe than sorry
|
||||
if s.data == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Reset local data
|
||||
s.data.Reset()
|
||||
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
// Use external Storage if exist
|
||||
if err := s.config.Storage.Delete(s.id); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Expire session
|
||||
s.delSession()
|
||||
return nil
|
||||
}
|
||||
|
||||
// Regenerate generates a new session id and delete the old one from Storage
|
||||
func (s *Session) Regenerate() error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
// Delete old id from storage
|
||||
if err := s.config.Storage.Delete(s.id); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Generate a new session, and set session.fresh to true
|
||||
s.refresh()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Reset generates a new session id, deletes the old one from storage, and resets the associated data
|
||||
func (s *Session) Reset() error {
|
||||
// Reset local data
|
||||
if s.data != nil {
|
||||
s.data.Reset()
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
// Reset byte buffer
|
||||
if s.byteBuffer != nil {
|
||||
s.byteBuffer.Reset()
|
||||
}
|
||||
// Reset expiration
|
||||
s.exp = 0
|
||||
|
||||
// Delete old id from storage
|
||||
if err := s.config.Storage.Delete(s.id); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Expire session
|
||||
s.delSession()
|
||||
|
||||
// Generate a new session, and set session.fresh to true
|
||||
s.refresh()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// refresh generates a new session, and set session.fresh to be true
|
||||
func (s *Session) refresh() {
|
||||
s.id = s.config.KeyGenerator()
|
||||
s.fresh = true
|
||||
}
|
||||
|
||||
// Save will update the storage and client cookie
|
||||
//
|
||||
// sess.Save() will save the session data to the storage and update the
|
||||
// client cookie, and it will release the session after saving.
|
||||
//
|
||||
// It's not safe to use the session after calling Save().
|
||||
func (s *Session) Save() error {
|
||||
// Better safe than sorry
|
||||
if s.data == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
|
||||
// Check if session has your own expiration, otherwise use default value
|
||||
if s.exp <= 0 {
|
||||
s.exp = s.config.Expiration
|
||||
}
|
||||
|
||||
// Update client cookie
|
||||
s.setSession()
|
||||
|
||||
// Convert data to bytes
|
||||
encCache := gob.NewEncoder(s.byteBuffer)
|
||||
err := encCache.Encode(&s.data.Data)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to encode data: %w", err)
|
||||
}
|
||||
|
||||
// Copy the data in buffer
|
||||
encodedBytes := make([]byte, s.byteBuffer.Len())
|
||||
copy(encodedBytes, s.byteBuffer.Bytes())
|
||||
|
||||
// Pass copied bytes with session id to provider
|
||||
if err := s.config.Storage.Set(s.id, encodedBytes, s.exp); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s.mu.Unlock()
|
||||
|
||||
// Release session
|
||||
// TODO: It's not safe to use the Session after calling Save()
|
||||
releaseSession(s)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Keys will retrieve all keys in current session
|
||||
func (s *Session) Keys() []string {
|
||||
if s.data == nil {
|
||||
return []string{}
|
||||
}
|
||||
return s.data.Keys()
|
||||
}
|
||||
|
||||
// SetExpiry sets a specific expiration for this session
|
||||
func (s *Session) SetExpiry(exp time.Duration) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.exp = exp
|
||||
}
|
||||
|
||||
func (s *Session) setSession() {
|
||||
if s.config.source == SourceHeader {
|
||||
s.ctx.Request().Header.SetBytesV(s.config.sessionName, []byte(s.id))
|
||||
s.ctx.Response().Header.SetBytesV(s.config.sessionName, []byte(s.id))
|
||||
} else {
|
||||
fcookie := fasthttp.AcquireCookie()
|
||||
fcookie.SetKey(s.config.sessionName)
|
||||
fcookie.SetValue(s.id)
|
||||
fcookie.SetPath(s.config.CookiePath)
|
||||
fcookie.SetDomain(s.config.CookieDomain)
|
||||
// Cookies are also session cookies if they do not specify the Expires or Max-Age attribute.
|
||||
// refer: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Set-Cookie
|
||||
if !s.config.CookieSessionOnly {
|
||||
fcookie.SetMaxAge(int(s.exp.Seconds()))
|
||||
fcookie.SetExpire(time.Now().Add(s.exp))
|
||||
}
|
||||
fcookie.SetSecure(s.config.CookieSecure)
|
||||
fcookie.SetHTTPOnly(s.config.CookieHTTPOnly)
|
||||
|
||||
switch utils.ToLower(s.config.CookieSameSite) {
|
||||
case "strict":
|
||||
fcookie.SetSameSite(fasthttp.CookieSameSiteStrictMode)
|
||||
case "none":
|
||||
fcookie.SetSameSite(fasthttp.CookieSameSiteNoneMode)
|
||||
default:
|
||||
fcookie.SetSameSite(fasthttp.CookieSameSiteLaxMode)
|
||||
}
|
||||
s.ctx.Response().Header.SetCookie(fcookie)
|
||||
fasthttp.ReleaseCookie(fcookie)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Session) delSession() {
|
||||
if s.config.source == SourceHeader {
|
||||
s.ctx.Request().Header.Del(s.config.sessionName)
|
||||
s.ctx.Response().Header.Del(s.config.sessionName)
|
||||
} else {
|
||||
s.ctx.Request().Header.DelCookie(s.config.sessionName)
|
||||
s.ctx.Response().Header.DelCookie(s.config.sessionName)
|
||||
|
||||
fcookie := fasthttp.AcquireCookie()
|
||||
fcookie.SetKey(s.config.sessionName)
|
||||
fcookie.SetPath(s.config.CookiePath)
|
||||
fcookie.SetDomain(s.config.CookieDomain)
|
||||
fcookie.SetMaxAge(-1)
|
||||
fcookie.SetExpire(time.Now().Add(-1 * time.Minute))
|
||||
fcookie.SetSecure(s.config.CookieSecure)
|
||||
fcookie.SetHTTPOnly(s.config.CookieHTTPOnly)
|
||||
|
||||
switch utils.ToLower(s.config.CookieSameSite) {
|
||||
case "strict":
|
||||
fcookie.SetSameSite(fasthttp.CookieSameSiteStrictMode)
|
||||
case "none":
|
||||
fcookie.SetSameSite(fasthttp.CookieSameSiteNoneMode)
|
||||
default:
|
||||
fcookie.SetSameSite(fasthttp.CookieSameSiteLaxMode)
|
||||
}
|
||||
|
||||
s.ctx.Response().Header.SetCookie(fcookie)
|
||||
fasthttp.ReleaseCookie(fcookie)
|
||||
}
|
||||
}
|
||||
|
||||
// decodeSessionData decodes the session data from raw bytes.
|
||||
func (s *Session) decodeSessionData(rawData []byte) error {
|
||||
_, _ = s.byteBuffer.Write(rawData) //nolint:errcheck // This will never fail
|
||||
encCache := gob.NewDecoder(s.byteBuffer)
|
||||
if err := encCache.Decode(&s.data.Data); err != nil {
|
||||
return fmt.Errorf("failed to decode session data: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
904
middleware/session/session_test.go
Normal file
904
middleware/session/session_test.go
Normal file
|
@ -0,0 +1,904 @@
|
|||
package session
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/gofiber/fiber/v2/internal/storage/memory"
|
||||
"github.com/gofiber/fiber/v2/utils"
|
||||
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
// go test -run Test_Session
|
||||
func Test_Session(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// session store
|
||||
store := New()
|
||||
|
||||
// fiber instance
|
||||
app := fiber.New()
|
||||
|
||||
// fiber context
|
||||
ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
|
||||
defer app.ReleaseCtx(ctx)
|
||||
|
||||
// Get a new session
|
||||
sess, err := store.Get(ctx)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, true, sess.Fresh())
|
||||
token := sess.ID()
|
||||
err = sess.Save()
|
||||
utils.AssertEqual(t, nil, err)
|
||||
|
||||
app.ReleaseCtx(ctx)
|
||||
ctx = app.AcquireCtx(&fasthttp.RequestCtx{})
|
||||
|
||||
// set session
|
||||
ctx.Request().Header.SetCookie(store.sessionName, token)
|
||||
|
||||
// get session
|
||||
sess, err = store.Get(ctx)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, false, sess.Fresh())
|
||||
|
||||
// get keys
|
||||
keys := sess.Keys()
|
||||
utils.AssertEqual(t, []string{}, keys)
|
||||
|
||||
// get value
|
||||
name := sess.Get("name")
|
||||
utils.AssertEqual(t, nil, name)
|
||||
|
||||
// set value
|
||||
sess.Set("name", "john")
|
||||
|
||||
// get value
|
||||
name = sess.Get("name")
|
||||
utils.AssertEqual(t, "john", name)
|
||||
|
||||
keys = sess.Keys()
|
||||
utils.AssertEqual(t, []string{"name"}, keys)
|
||||
|
||||
// delete key
|
||||
sess.Delete("name")
|
||||
|
||||
// get value
|
||||
name = sess.Get("name")
|
||||
utils.AssertEqual(t, nil, name)
|
||||
|
||||
// get keys
|
||||
keys = sess.Keys()
|
||||
utils.AssertEqual(t, []string{}, keys)
|
||||
|
||||
// get id
|
||||
id := sess.ID()
|
||||
utils.AssertEqual(t, token, id)
|
||||
|
||||
// save the old session first
|
||||
err = sess.Save()
|
||||
utils.AssertEqual(t, nil, err)
|
||||
|
||||
app.ReleaseCtx(ctx)
|
||||
|
||||
// requesting entirely new context to prevent falsy tests
|
||||
ctx = app.AcquireCtx(&fasthttp.RequestCtx{})
|
||||
defer app.ReleaseCtx(ctx)
|
||||
|
||||
sess, err = store.Get(ctx)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, true, sess.Fresh())
|
||||
|
||||
// this id should be randomly generated as session key was deleted
|
||||
utils.AssertEqual(t, 36, len(sess.ID()))
|
||||
|
||||
// when we use the original session for the second time
|
||||
// the session be should be same if the session is not expired
|
||||
ctx = app.AcquireCtx(&fasthttp.RequestCtx{})
|
||||
defer app.ReleaseCtx(ctx)
|
||||
|
||||
// request the server with the old session
|
||||
ctx.Request().Header.SetCookie(store.sessionName, id)
|
||||
sess, err = store.Get(ctx)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, false, sess.Fresh())
|
||||
utils.AssertEqual(t, sess.id, id)
|
||||
}
|
||||
|
||||
// go test -run Test_Session_Types
|
||||
//
|
||||
//nolint:forcetypeassert // TODO: Do not force-type assert
|
||||
func Test_Session_Types(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// session store
|
||||
store := New()
|
||||
|
||||
// fiber instance
|
||||
app := fiber.New()
|
||||
|
||||
// fiber context
|
||||
ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
|
||||
|
||||
// set cookie
|
||||
ctx.Request().Header.SetCookie(store.sessionName, "123")
|
||||
|
||||
// get session
|
||||
sess, err := store.Get(ctx)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, true, sess.Fresh())
|
||||
|
||||
// the session string is no longer be 123
|
||||
newSessionIDString := sess.ID()
|
||||
|
||||
type User struct {
|
||||
Name string
|
||||
}
|
||||
store.RegisterType(User{})
|
||||
vuser := User{
|
||||
Name: "John",
|
||||
}
|
||||
// set value
|
||||
var (
|
||||
vbool = true
|
||||
vstring = "str"
|
||||
vint = 13
|
||||
vint8 int8 = 13
|
||||
vint16 int16 = 13
|
||||
vint32 int32 = 13
|
||||
vint64 int64 = 13
|
||||
vuint uint = 13
|
||||
vuint8 uint8 = 13
|
||||
vuint16 uint16 = 13
|
||||
vuint32 uint32 = 13
|
||||
vuint64 uint64 = 13
|
||||
vuintptr uintptr = 13
|
||||
vbyte byte = 'k'
|
||||
vrune = 'k'
|
||||
vfloat32 float32 = 13
|
||||
vfloat64 float64 = 13
|
||||
vcomplex64 complex64 = 13
|
||||
vcomplex128 complex128 = 13
|
||||
)
|
||||
sess.Set("vuser", vuser)
|
||||
sess.Set("vbool", vbool)
|
||||
sess.Set("vstring", vstring)
|
||||
sess.Set("vint", vint)
|
||||
sess.Set("vint8", vint8)
|
||||
sess.Set("vint16", vint16)
|
||||
sess.Set("vint32", vint32)
|
||||
sess.Set("vint64", vint64)
|
||||
sess.Set("vuint", vuint)
|
||||
sess.Set("vuint8", vuint8)
|
||||
sess.Set("vuint16", vuint16)
|
||||
sess.Set("vuint32", vuint32)
|
||||
sess.Set("vuint32", vuint32)
|
||||
sess.Set("vuint64", vuint64)
|
||||
sess.Set("vuintptr", vuintptr)
|
||||
sess.Set("vbyte", vbyte)
|
||||
sess.Set("vrune", vrune)
|
||||
sess.Set("vfloat32", vfloat32)
|
||||
sess.Set("vfloat64", vfloat64)
|
||||
sess.Set("vcomplex64", vcomplex64)
|
||||
sess.Set("vcomplex128", vcomplex128)
|
||||
|
||||
// save session
|
||||
err = sess.Save()
|
||||
utils.AssertEqual(t, nil, err)
|
||||
|
||||
app.ReleaseCtx(ctx)
|
||||
ctx = app.AcquireCtx(&fasthttp.RequestCtx{})
|
||||
|
||||
ctx.Request().Header.SetCookie(store.sessionName, newSessionIDString)
|
||||
|
||||
// get session
|
||||
sess, err = store.Get(ctx)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, false, sess.Fresh())
|
||||
|
||||
// get value
|
||||
utils.AssertEqual(t, vuser, sess.Get("vuser").(User))
|
||||
utils.AssertEqual(t, vbool, sess.Get("vbool").(bool))
|
||||
utils.AssertEqual(t, vstring, sess.Get("vstring").(string))
|
||||
utils.AssertEqual(t, vint, sess.Get("vint").(int))
|
||||
utils.AssertEqual(t, vint8, sess.Get("vint8").(int8))
|
||||
utils.AssertEqual(t, vint16, sess.Get("vint16").(int16))
|
||||
utils.AssertEqual(t, vint32, sess.Get("vint32").(int32))
|
||||
utils.AssertEqual(t, vint64, sess.Get("vint64").(int64))
|
||||
utils.AssertEqual(t, vuint, sess.Get("vuint").(uint))
|
||||
utils.AssertEqual(t, vuint8, sess.Get("vuint8").(uint8))
|
||||
utils.AssertEqual(t, vuint16, sess.Get("vuint16").(uint16))
|
||||
utils.AssertEqual(t, vuint32, sess.Get("vuint32").(uint32))
|
||||
utils.AssertEqual(t, vuint64, sess.Get("vuint64").(uint64))
|
||||
utils.AssertEqual(t, vuintptr, sess.Get("vuintptr").(uintptr))
|
||||
utils.AssertEqual(t, vbyte, sess.Get("vbyte").(byte))
|
||||
utils.AssertEqual(t, vrune, sess.Get("vrune").(rune))
|
||||
utils.AssertEqual(t, vfloat32, sess.Get("vfloat32").(float32))
|
||||
utils.AssertEqual(t, vfloat64, sess.Get("vfloat64").(float64))
|
||||
utils.AssertEqual(t, vcomplex64, sess.Get("vcomplex64").(complex64))
|
||||
utils.AssertEqual(t, vcomplex128, sess.Get("vcomplex128").(complex128))
|
||||
app.ReleaseCtx(ctx)
|
||||
}
|
||||
|
||||
// go test -run Test_Session_Store_Reset
|
||||
func Test_Session_Store_Reset(t *testing.T) {
|
||||
t.Parallel()
|
||||
// session store
|
||||
store := New()
|
||||
// fiber instance
|
||||
app := fiber.New()
|
||||
// fiber context
|
||||
ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
|
||||
|
||||
// get session
|
||||
sess, err := store.Get(ctx)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
// make sure its new
|
||||
utils.AssertEqual(t, true, sess.Fresh())
|
||||
// set value & save
|
||||
sess.Set("hello", "world")
|
||||
ctx.Request().Header.SetCookie(store.sessionName, sess.ID())
|
||||
utils.AssertEqual(t, nil, sess.Save())
|
||||
|
||||
// reset store
|
||||
utils.AssertEqual(t, nil, store.Reset())
|
||||
id := sess.ID()
|
||||
|
||||
app.ReleaseCtx(ctx)
|
||||
ctx = app.AcquireCtx(&fasthttp.RequestCtx{})
|
||||
defer app.ReleaseCtx(ctx)
|
||||
ctx.Request().Header.SetCookie(store.sessionName, id)
|
||||
|
||||
// make sure the session is recreated
|
||||
sess, err = store.Get(ctx)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, true, sess.Fresh())
|
||||
utils.AssertEqual(t, nil, sess.Get("hello"))
|
||||
}
|
||||
|
||||
// go test -run Test_Session_Save
|
||||
func Test_Session_Save(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("save to cookie", func(t *testing.T) {
|
||||
// session store
|
||||
store := New()
|
||||
// fiber instance
|
||||
app := fiber.New()
|
||||
// fiber context
|
||||
ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
|
||||
defer app.ReleaseCtx(ctx)
|
||||
// get session
|
||||
sess, err := store.Get(ctx)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
// set value
|
||||
sess.Set("name", "john")
|
||||
|
||||
// save session
|
||||
err = sess.Save()
|
||||
utils.AssertEqual(t, nil, err)
|
||||
})
|
||||
|
||||
t.Run("save to header", func(t *testing.T) {
|
||||
// session store
|
||||
store := New(Config{
|
||||
KeyLookup: "header:session_id",
|
||||
})
|
||||
// fiber instance
|
||||
app := fiber.New()
|
||||
// fiber context
|
||||
ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
|
||||
defer app.ReleaseCtx(ctx)
|
||||
// get session
|
||||
sess, err := store.Get(ctx)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
// set value
|
||||
sess.Set("name", "john")
|
||||
|
||||
// save session
|
||||
err = sess.Save()
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, store.getSessionID(ctx), string(ctx.Response().Header.Peek(store.sessionName)))
|
||||
utils.AssertEqual(t, store.getSessionID(ctx), string(ctx.Request().Header.Peek(store.sessionName)))
|
||||
})
|
||||
}
|
||||
|
||||
func Test_Session_Save_Expiration(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("save to cookie", func(t *testing.T) {
|
||||
const sessionDuration = 5 * time.Second
|
||||
t.Parallel()
|
||||
// session store
|
||||
store := New()
|
||||
// fiber instance
|
||||
app := fiber.New()
|
||||
// fiber context
|
||||
ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
|
||||
defer app.ReleaseCtx(ctx)
|
||||
// get session
|
||||
sess, err := store.Get(ctx)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
// set value
|
||||
sess.Set("name", "john")
|
||||
|
||||
token := sess.ID()
|
||||
|
||||
// expire this session in 5 seconds
|
||||
sess.SetExpiry(sessionDuration)
|
||||
|
||||
// save session
|
||||
err = sess.Save()
|
||||
utils.AssertEqual(t, nil, err)
|
||||
|
||||
app.ReleaseCtx(ctx)
|
||||
ctx = app.AcquireCtx(&fasthttp.RequestCtx{})
|
||||
|
||||
// here you need to get the old session yet
|
||||
ctx.Request().Header.SetCookie(store.sessionName, token)
|
||||
sess, err = store.Get(ctx)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, "john", sess.Get("name"))
|
||||
|
||||
// just to make sure the session has been expired
|
||||
time.Sleep(sessionDuration + (10 * time.Millisecond))
|
||||
|
||||
app.ReleaseCtx(ctx)
|
||||
ctx = app.AcquireCtx(&fasthttp.RequestCtx{})
|
||||
defer app.ReleaseCtx(ctx)
|
||||
|
||||
// here you should get a new session
|
||||
ctx.Request().Header.SetCookie(store.sessionName, token)
|
||||
sess, err = store.Get(ctx)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, nil, sess.Get("name"))
|
||||
utils.AssertEqual(t, true, sess.ID() != token)
|
||||
})
|
||||
}
|
||||
|
||||
// go test -run Test_Session_Destroy
|
||||
func Test_Session_Destroy(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("destroy from cookie", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
// session store
|
||||
store := New()
|
||||
// fiber instance
|
||||
app := fiber.New()
|
||||
// fiber context
|
||||
ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
|
||||
defer app.ReleaseCtx(ctx)
|
||||
// get session
|
||||
sess, err := store.Get(ctx)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
|
||||
sess.Set("name", "fenny")
|
||||
utils.AssertEqual(t, nil, sess.Destroy())
|
||||
name := sess.Get("name")
|
||||
utils.AssertEqual(t, nil, name)
|
||||
})
|
||||
|
||||
t.Run("destroy from header", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
// session store
|
||||
store := New(Config{
|
||||
KeyLookup: "header:session_id",
|
||||
})
|
||||
// fiber instance
|
||||
app := fiber.New()
|
||||
// fiber context
|
||||
ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
|
||||
defer app.ReleaseCtx(ctx)
|
||||
// get session
|
||||
sess, err := store.Get(ctx)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
|
||||
// set value & save
|
||||
sess.Set("name", "fenny")
|
||||
id := sess.ID()
|
||||
utils.AssertEqual(t, nil, sess.Save())
|
||||
|
||||
app.ReleaseCtx(ctx)
|
||||
ctx = app.AcquireCtx(&fasthttp.RequestCtx{})
|
||||
defer app.ReleaseCtx(ctx)
|
||||
|
||||
// get session
|
||||
ctx.Request().Header.Set(store.sessionName, id)
|
||||
sess, err = store.Get(ctx)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
|
||||
err = sess.Destroy()
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, "", string(ctx.Response().Header.Peek(store.sessionName)))
|
||||
utils.AssertEqual(t, "", string(ctx.Request().Header.Peek(store.sessionName)))
|
||||
})
|
||||
}
|
||||
|
||||
// go test -run Test_Session_Custom_Config
|
||||
func Test_Session_Custom_Config(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
store := New(Config{Expiration: time.Hour, KeyGenerator: func() string { return "very random" }})
|
||||
utils.AssertEqual(t, time.Hour, store.Expiration)
|
||||
utils.AssertEqual(t, "very random", store.KeyGenerator())
|
||||
|
||||
store = New(Config{Expiration: 0})
|
||||
utils.AssertEqual(t, ConfigDefault.Expiration, store.Expiration)
|
||||
}
|
||||
|
||||
// go test -run Test_Session_Cookie
|
||||
func Test_Session_Cookie(t *testing.T) {
|
||||
t.Parallel()
|
||||
// session store
|
||||
store := New()
|
||||
// fiber instance
|
||||
app := fiber.New()
|
||||
// fiber context
|
||||
ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
|
||||
defer app.ReleaseCtx(ctx)
|
||||
|
||||
// get session
|
||||
sess, err := store.Get(ctx)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, nil, sess.Save())
|
||||
|
||||
// cookie should be set on Save ( even if empty data )
|
||||
utils.AssertEqual(t, 84, len(ctx.Response().Header.PeekCookie(store.sessionName)))
|
||||
}
|
||||
|
||||
// go test -run Test_Session_Cookie_In_Response
|
||||
// Regression: https://github.com/gofiber/fiber/pull/1191
|
||||
func Test_Session_Cookie_In_Middleware_Chain(t *testing.T) {
|
||||
t.Parallel()
|
||||
store := New()
|
||||
app := fiber.New()
|
||||
|
||||
// fiber context
|
||||
ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
|
||||
defer app.ReleaseCtx(ctx)
|
||||
|
||||
// get session
|
||||
sess, err := store.Get(ctx)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
sess.Set("id", "1")
|
||||
id := sess.ID()
|
||||
utils.AssertEqual(t, true, sess.Fresh())
|
||||
utils.AssertEqual(t, nil, sess.Save())
|
||||
|
||||
sess, err = store.Get(ctx)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
sess.Set("name", "john")
|
||||
utils.AssertEqual(t, true, sess.Fresh())
|
||||
utils.AssertEqual(t, id, sess.ID()) // session id should be the same
|
||||
|
||||
utils.AssertEqual(t, sess.ID() != "1", true)
|
||||
utils.AssertEqual(t, "john", sess.Get("name"))
|
||||
}
|
||||
|
||||
// go test -run Test_Session_Deletes_Single_Key
|
||||
// Regression: https://github.com/gofiber/fiber/issues/1365
|
||||
func Test_Session_Deletes_Single_Key(t *testing.T) {
|
||||
t.Parallel()
|
||||
store := New()
|
||||
app := fiber.New()
|
||||
|
||||
ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
|
||||
|
||||
sess, err := store.Get(ctx)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
id := sess.ID()
|
||||
sess.Set("id", "1")
|
||||
utils.AssertEqual(t, nil, sess.Save())
|
||||
|
||||
app.ReleaseCtx(ctx)
|
||||
ctx = app.AcquireCtx(&fasthttp.RequestCtx{})
|
||||
ctx.Request().Header.SetCookie(store.sessionName, id)
|
||||
|
||||
sess, err = store.Get(ctx)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
sess.Delete("id")
|
||||
utils.AssertEqual(t, nil, sess.Save())
|
||||
|
||||
app.ReleaseCtx(ctx)
|
||||
ctx = app.AcquireCtx(&fasthttp.RequestCtx{})
|
||||
ctx.Request().Header.SetCookie(store.sessionName, id)
|
||||
|
||||
sess, err = store.Get(ctx)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, false, sess.Fresh())
|
||||
utils.AssertEqual(t, nil, sess.Get("id"))
|
||||
app.ReleaseCtx(ctx)
|
||||
}
|
||||
|
||||
// go test -run Test_Session_Reset
|
||||
func Test_Session_Reset(t *testing.T) {
|
||||
t.Parallel()
|
||||
// fiber instance
|
||||
app := fiber.New()
|
||||
|
||||
// session store
|
||||
store := New()
|
||||
|
||||
// fiber context
|
||||
ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
|
||||
defer app.ReleaseCtx(ctx)
|
||||
|
||||
t.Run("reset session data and id, and set fresh to be true", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
// fiber context
|
||||
ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
|
||||
// a random session uuid
|
||||
originalSessionUUIDString := ""
|
||||
|
||||
// now the session is in the storage
|
||||
freshSession, err := store.Get(ctx)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
|
||||
originalSessionUUIDString = freshSession.ID()
|
||||
|
||||
// set a value
|
||||
freshSession.Set("name", "fenny")
|
||||
freshSession.Set("email", "fenny@example.com")
|
||||
|
||||
err = freshSession.Save()
|
||||
utils.AssertEqual(t, nil, err)
|
||||
|
||||
app.ReleaseCtx(ctx)
|
||||
ctx = app.AcquireCtx(&fasthttp.RequestCtx{})
|
||||
|
||||
// set cookie
|
||||
ctx.Request().Header.SetCookie(store.sessionName, originalSessionUUIDString)
|
||||
|
||||
// as the session is in the storage, session.fresh should be false
|
||||
acquiredSession, err := store.Get(ctx)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, false, acquiredSession.Fresh())
|
||||
|
||||
err = acquiredSession.Reset()
|
||||
utils.AssertEqual(t, nil, err)
|
||||
|
||||
utils.AssertEqual(t, false, acquiredSession.ID() == originalSessionUUIDString)
|
||||
utils.AssertEqual(t, false, acquiredSession.ID() == "")
|
||||
|
||||
// acquiredSession.fresh should be true after resetting
|
||||
utils.AssertEqual(t, true, acquiredSession.Fresh())
|
||||
|
||||
// Check that the session data has been reset
|
||||
keys := acquiredSession.Keys()
|
||||
utils.AssertEqual(t, []string{}, keys)
|
||||
|
||||
// Set a new value for 'name' and check that it's updated
|
||||
acquiredSession.Set("name", "john")
|
||||
utils.AssertEqual(t, "john", acquiredSession.Get("name"))
|
||||
utils.AssertEqual(t, nil, acquiredSession.Get("email"))
|
||||
|
||||
// Save after resetting
|
||||
err = acquiredSession.Save()
|
||||
utils.AssertEqual(t, nil, err)
|
||||
|
||||
// Check that the session id is not in the header or cookie anymore
|
||||
utils.AssertEqual(t, "", string(ctx.Response().Header.Peek(store.sessionName)))
|
||||
utils.AssertEqual(t, "", string(ctx.Request().Header.Peek(store.sessionName)))
|
||||
|
||||
app.ReleaseCtx(ctx)
|
||||
})
|
||||
}
|
||||
|
||||
// go test -run Test_Session_Regenerate
|
||||
// Regression: https://github.com/gofiber/fiber/issues/1395
|
||||
func Test_Session_Regenerate(t *testing.T) {
|
||||
t.Parallel()
|
||||
// fiber instance
|
||||
app := fiber.New()
|
||||
t.Run("set fresh to be true when regenerating a session", func(t *testing.T) {
|
||||
// session store
|
||||
store := New()
|
||||
// a random session uuid
|
||||
originalSessionUUIDString := ""
|
||||
// fiber context
|
||||
ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
|
||||
defer app.ReleaseCtx(ctx)
|
||||
|
||||
// now the session is in the storage
|
||||
freshSession, err := store.Get(ctx)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
|
||||
originalSessionUUIDString = freshSession.ID()
|
||||
|
||||
err = freshSession.Save()
|
||||
utils.AssertEqual(t, nil, err)
|
||||
|
||||
// release the context
|
||||
app.ReleaseCtx(ctx)
|
||||
|
||||
// acquire a new context
|
||||
ctx = app.AcquireCtx(&fasthttp.RequestCtx{})
|
||||
|
||||
// set cookie
|
||||
ctx.Request().Header.SetCookie(store.sessionName, originalSessionUUIDString)
|
||||
|
||||
// as the session is in the storage, session.fresh should be false
|
||||
acquiredSession, err := store.Get(ctx)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, false, acquiredSession.Fresh())
|
||||
|
||||
err = acquiredSession.Regenerate()
|
||||
utils.AssertEqual(t, nil, err)
|
||||
|
||||
utils.AssertEqual(t, false, acquiredSession.ID() == originalSessionUUIDString)
|
||||
|
||||
// acquiredSession.fresh should be true after regenerating
|
||||
utils.AssertEqual(t, true, acquiredSession.Fresh())
|
||||
|
||||
// release the context
|
||||
app.ReleaseCtx(ctx)
|
||||
})
|
||||
}
|
||||
|
||||
// go test -v -run=^$ -bench=Benchmark_Session -benchmem -count=4
|
||||
func Benchmark_Session(b *testing.B) {
|
||||
app, store := fiber.New(), New()
|
||||
c := app.AcquireCtx(&fasthttp.RequestCtx{})
|
||||
defer app.ReleaseCtx(c)
|
||||
c.Request().Header.SetCookie(store.sessionName, "12356789")
|
||||
|
||||
var err error
|
||||
b.Run("default", func(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
for n := 0; n < b.N; n++ {
|
||||
sess, _ := store.Get(c) //nolint:errcheck // We're inside a benchmark
|
||||
sess.Set("john", "doe")
|
||||
err = sess.Save()
|
||||
}
|
||||
|
||||
utils.AssertEqual(b, nil, err)
|
||||
})
|
||||
|
||||
b.Run("storage", func(b *testing.B) {
|
||||
store = New(Config{
|
||||
Storage: memory.New(),
|
||||
})
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
for n := 0; n < b.N; n++ {
|
||||
sess, _ := store.Get(c) //nolint:errcheck // We're inside a benchmark
|
||||
sess.Set("john", "doe")
|
||||
err = sess.Save()
|
||||
}
|
||||
|
||||
utils.AssertEqual(b, nil, err)
|
||||
})
|
||||
}
|
||||
|
||||
// go test -v -run=^$ -bench=Benchmark_Session_Parallel -benchmem -count=4
|
||||
func Benchmark_Session_Parallel(b *testing.B) {
|
||||
b.Run("default", func(b *testing.B) {
|
||||
app, store := fiber.New(), New()
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
for pb.Next() {
|
||||
c := app.AcquireCtx(&fasthttp.RequestCtx{})
|
||||
c.Request().Header.SetCookie(store.sessionName, "12356789")
|
||||
|
||||
sess, _ := store.Get(c) //nolint:errcheck // We're inside a benchmark
|
||||
sess.Set("john", "doe")
|
||||
_ = sess.Save() //nolint:errcheck // We're inside a benchmark
|
||||
app.ReleaseCtx(c)
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
b.Run("storage", func(b *testing.B) {
|
||||
app := fiber.New()
|
||||
store := New(Config{
|
||||
Storage: memory.New(),
|
||||
})
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
for pb.Next() {
|
||||
c := app.AcquireCtx(&fasthttp.RequestCtx{})
|
||||
c.Request().Header.SetCookie(store.sessionName, "12356789")
|
||||
|
||||
sess, _ := store.Get(c) //nolint:errcheck // We're inside a benchmark
|
||||
sess.Set("john", "doe")
|
||||
_ = sess.Save() //nolint:errcheck // We're inside a benchmark
|
||||
app.ReleaseCtx(c)
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// go test -v -run=^$ -bench=Benchmark_Session_Asserted -benchmem -count=4
|
||||
func Benchmark_Session_Asserted(b *testing.B) {
|
||||
b.Run("default", func(b *testing.B) {
|
||||
app, store := fiber.New(), New()
|
||||
c := app.AcquireCtx(&fasthttp.RequestCtx{})
|
||||
defer app.ReleaseCtx(c)
|
||||
c.Request().Header.SetCookie(store.sessionName, "12356789")
|
||||
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
for n := 0; n < b.N; n++ {
|
||||
sess, err := store.Get(c)
|
||||
utils.AssertEqual(b, nil, err)
|
||||
sess.Set("john", "doe")
|
||||
err = sess.Save()
|
||||
utils.AssertEqual(b, nil, err)
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("storage", func(b *testing.B) {
|
||||
app := fiber.New()
|
||||
store := New(Config{
|
||||
Storage: memory.New(),
|
||||
})
|
||||
c := app.AcquireCtx(&fasthttp.RequestCtx{})
|
||||
defer app.ReleaseCtx(c)
|
||||
c.Request().Header.SetCookie(store.sessionName, "12356789")
|
||||
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
for n := 0; n < b.N; n++ {
|
||||
sess, err := store.Get(c)
|
||||
utils.AssertEqual(b, nil, err)
|
||||
sess.Set("john", "doe")
|
||||
err = sess.Save()
|
||||
utils.AssertEqual(b, nil, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// go test -v -run=^$ -bench=Benchmark_Session_Asserted_Parallel -benchmem -count=4
|
||||
func Benchmark_Session_Asserted_Parallel(b *testing.B) {
|
||||
b.Run("default", func(b *testing.B) {
|
||||
app, store := fiber.New(), New()
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
for pb.Next() {
|
||||
c := app.AcquireCtx(&fasthttp.RequestCtx{})
|
||||
c.Request().Header.SetCookie(store.sessionName, "12356789")
|
||||
|
||||
sess, err := store.Get(c)
|
||||
utils.AssertEqual(b, nil, err)
|
||||
sess.Set("john", "doe")
|
||||
utils.AssertEqual(b, nil, sess.Save())
|
||||
app.ReleaseCtx(c)
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
b.Run("storage", func(b *testing.B) {
|
||||
app := fiber.New()
|
||||
store := New(Config{
|
||||
Storage: memory.New(),
|
||||
})
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
for pb.Next() {
|
||||
c := app.AcquireCtx(&fasthttp.RequestCtx{})
|
||||
c.Request().Header.SetCookie(store.sessionName, "12356789")
|
||||
|
||||
sess, err := store.Get(c)
|
||||
utils.AssertEqual(b, nil, err)
|
||||
sess.Set("john", "doe")
|
||||
utils.AssertEqual(b, nil, sess.Save())
|
||||
app.ReleaseCtx(c)
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// go test -v -race -run Test_Session_Concurrency ./...
|
||||
func Test_Session_Concurrency(t *testing.T) {
|
||||
t.Parallel()
|
||||
app := fiber.New()
|
||||
store := New()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
errChan := make(chan error, 10) // Buffered channel to collect errors
|
||||
const numGoroutines = 10 // Number of concurrent goroutines to test
|
||||
|
||||
// Start numGoroutines goroutines
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
||||
localCtx := app.AcquireCtx(&fasthttp.RequestCtx{})
|
||||
|
||||
sess, err := store.Get(localCtx)
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
|
||||
// Set a value
|
||||
sess.Set("name", "john")
|
||||
|
||||
// get the session id
|
||||
id := sess.ID()
|
||||
|
||||
// Check if the session is fresh
|
||||
if !sess.Fresh() {
|
||||
errChan <- errors.New("session should be fresh")
|
||||
return
|
||||
}
|
||||
|
||||
// Save the session
|
||||
if err := sess.Save(); err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
|
||||
// Release the context
|
||||
app.ReleaseCtx(localCtx)
|
||||
|
||||
// Acquire a new context
|
||||
localCtx = app.AcquireCtx(&fasthttp.RequestCtx{})
|
||||
defer app.ReleaseCtx(localCtx)
|
||||
|
||||
// Set the session id in the header
|
||||
localCtx.Request().Header.SetCookie(store.sessionName, id)
|
||||
|
||||
// Get the session
|
||||
sess, err = store.Get(localCtx)
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
|
||||
// Get the value
|
||||
name := sess.Get("name")
|
||||
if name != "john" {
|
||||
errChan <- errors.New("name should be john")
|
||||
return
|
||||
}
|
||||
|
||||
// Get ID from the session
|
||||
if sess.ID() != id {
|
||||
errChan <- errors.New("id should be the same")
|
||||
return
|
||||
}
|
||||
|
||||
// Check if the session is fresh
|
||||
if sess.Fresh() {
|
||||
errChan <- errors.New("session should not be fresh")
|
||||
return
|
||||
}
|
||||
|
||||
// Delete the key
|
||||
sess.Delete("name")
|
||||
|
||||
// Get the value
|
||||
name = sess.Get("name")
|
||||
if name != nil {
|
||||
errChan <- errors.New("name should be nil")
|
||||
return
|
||||
}
|
||||
|
||||
// Destroy the session
|
||||
if err := sess.Destroy(); err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait() // Wait for all goroutines to finish
|
||||
close(errChan) // Close the channel to signal no more errors will be sent
|
||||
|
||||
// Check for errors sent to errChan
|
||||
for err := range errChan {
|
||||
utils.AssertEqual(t, nil, err)
|
||||
}
|
||||
}
|
136
middleware/session/store.go
Normal file
136
middleware/session/store.go
Normal file
|
@ -0,0 +1,136 @@
|
|||
package session
|
||||
|
||||
import (
|
||||
"encoding/gob"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/gofiber/fiber/v2/internal/storage/memory"
|
||||
"github.com/gofiber/fiber/v2/utils"
|
||||
)
|
||||
|
||||
// ErrEmptySessionID is an error that occurs when the session ID is empty.
|
||||
var ErrEmptySessionID = errors.New("session id cannot be empty")
|
||||
|
||||
// sessionIDKey is the local key type used to store and retrieve the session ID in context.
|
||||
type sessionIDKey int
|
||||
|
||||
const (
|
||||
// sessionIDContextKey is the key used to store the session ID in the context locals.
|
||||
sessionIDContextKey sessionIDKey = iota
|
||||
)
|
||||
|
||||
type Store struct {
|
||||
Config
|
||||
}
|
||||
|
||||
// New creates a new session store with the provided configuration.
|
||||
func New(config ...Config) *Store {
|
||||
// Set default config
|
||||
cfg := configDefault(config...)
|
||||
|
||||
if cfg.Storage == nil {
|
||||
cfg.Storage = memory.New()
|
||||
}
|
||||
|
||||
return &Store{
|
||||
cfg,
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterType registers a custom type for encoding/decoding into any storage provider.
|
||||
func (*Store) RegisterType(i interface{}) {
|
||||
gob.Register(i)
|
||||
}
|
||||
|
||||
// Get retrieves or creates a session for the given context.
|
||||
func (s *Store) Get(c *fiber.Ctx) (*Session, error) {
|
||||
var rawData []byte
|
||||
var err error
|
||||
|
||||
id, ok := c.Locals(sessionIDContextKey).(string)
|
||||
if !ok {
|
||||
id = s.getSessionID(c)
|
||||
}
|
||||
|
||||
fresh := ok // Assume the session is fresh if the ID is found in locals
|
||||
|
||||
// Attempt to fetch session data if an ID is provided
|
||||
if id != "" {
|
||||
rawData, err = s.Storage.Get(id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if rawData == nil {
|
||||
// Data not found, prepare to generate a new session
|
||||
id = ""
|
||||
}
|
||||
}
|
||||
|
||||
// Generate a new ID if needed
|
||||
if id == "" {
|
||||
fresh = true // The session is fresh if a new ID is generated
|
||||
id = s.KeyGenerator()
|
||||
c.Locals(sessionIDContextKey, id)
|
||||
}
|
||||
|
||||
// Create session object
|
||||
sess := acquireSession()
|
||||
|
||||
sess.mu.Lock()
|
||||
defer sess.mu.Unlock()
|
||||
|
||||
sess.ctx = c
|
||||
sess.config = s
|
||||
sess.id = id
|
||||
sess.fresh = fresh
|
||||
|
||||
// Decode session data if found
|
||||
if rawData != nil {
|
||||
sess.data.Lock()
|
||||
defer sess.data.Unlock()
|
||||
if err := sess.decodeSessionData(rawData); err != nil {
|
||||
return nil, fmt.Errorf("failed to decode session data: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return sess, nil
|
||||
}
|
||||
|
||||
// getSessionID returns the session ID from cookies, headers, or query string.
|
||||
func (s *Store) getSessionID(c *fiber.Ctx) string {
|
||||
id := c.Cookies(s.sessionName)
|
||||
if len(id) > 0 {
|
||||
return utils.CopyString(id)
|
||||
}
|
||||
|
||||
if s.source == SourceHeader {
|
||||
id = string(c.Request().Header.Peek(s.sessionName))
|
||||
if len(id) > 0 {
|
||||
return id
|
||||
}
|
||||
}
|
||||
|
||||
if s.source == SourceURLQuery {
|
||||
id = c.Query(s.sessionName)
|
||||
if len(id) > 0 {
|
||||
return utils.CopyString(id)
|
||||
}
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
// Reset deletes all sessions from the storage.
|
||||
func (s *Store) Reset() error {
|
||||
return s.Storage.Reset()
|
||||
}
|
||||
|
||||
// Delete deletes a session by its ID.
|
||||
func (s *Store) Delete(id string) error {
|
||||
if id == "" {
|
||||
return ErrEmptySessionID
|
||||
}
|
||||
return s.Storage.Delete(id)
|
||||
}
|
119
middleware/session/store_test.go
Normal file
119
middleware/session/store_test.go
Normal file
|
@ -0,0 +1,119 @@
|
|||
package session
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/gofiber/fiber/v2/utils"
|
||||
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
// go test -run TestStore_getSessionID
|
||||
func TestStore_getSessionID(t *testing.T) {
|
||||
t.Parallel()
|
||||
expectedID := "test-session-id"
|
||||
|
||||
// fiber instance
|
||||
app := fiber.New()
|
||||
|
||||
t.Run("from cookie", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
// session store
|
||||
store := New()
|
||||
// fiber context
|
||||
ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
|
||||
defer app.ReleaseCtx(ctx)
|
||||
// set cookie
|
||||
ctx.Request().Header.SetCookie(store.sessionName, expectedID)
|
||||
|
||||
utils.AssertEqual(t, expectedID, store.getSessionID(ctx))
|
||||
})
|
||||
|
||||
t.Run("from header", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
// session store
|
||||
store := New(Config{
|
||||
KeyLookup: "header:session_id",
|
||||
})
|
||||
// fiber context
|
||||
ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
|
||||
defer app.ReleaseCtx(ctx)
|
||||
// set header
|
||||
ctx.Request().Header.Set(store.sessionName, expectedID)
|
||||
|
||||
utils.AssertEqual(t, expectedID, store.getSessionID(ctx))
|
||||
})
|
||||
|
||||
t.Run("from url query", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
// session store
|
||||
store := New(Config{
|
||||
KeyLookup: "query:session_id",
|
||||
})
|
||||
// fiber context
|
||||
ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
|
||||
defer app.ReleaseCtx(ctx)
|
||||
// set url parameter
|
||||
ctx.Request().SetRequestURI(fmt.Sprintf("/path?%s=%s", store.sessionName, expectedID))
|
||||
|
||||
utils.AssertEqual(t, expectedID, store.getSessionID(ctx))
|
||||
})
|
||||
}
|
||||
|
||||
// go test -run TestStore_Get
|
||||
// Regression: https://github.com/gofiber/fiber/issues/1408
|
||||
// Regression: https://github.com/gofiber/fiber/security/advisories/GHSA-98j2-3j3p-fw2v
|
||||
func TestStore_Get(t *testing.T) {
|
||||
t.Parallel()
|
||||
unexpectedID := "test-session-id"
|
||||
// fiber instance
|
||||
app := fiber.New()
|
||||
t.Run("session should be re-generated if it is invalid", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
// session store
|
||||
store := New()
|
||||
// fiber context
|
||||
ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
|
||||
defer app.ReleaseCtx(ctx)
|
||||
// set cookie
|
||||
ctx.Request().Header.SetCookie(store.sessionName, unexpectedID)
|
||||
|
||||
acquiredSession, err := store.Get(ctx)
|
||||
utils.AssertEqual(t, err, nil)
|
||||
|
||||
utils.AssertEqual(t, acquiredSession.ID() != unexpectedID, true)
|
||||
})
|
||||
}
|
||||
|
||||
// go test -run TestStore_DeleteSession
|
||||
func TestStore_DeleteSession(t *testing.T) {
|
||||
t.Parallel()
|
||||
// fiber instance
|
||||
app := fiber.New()
|
||||
// session store
|
||||
store := New()
|
||||
|
||||
// fiber context
|
||||
ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
|
||||
defer app.ReleaseCtx(ctx)
|
||||
|
||||
// Create a new session
|
||||
session, err := store.Get(ctx)
|
||||
utils.AssertEqual(t, err, nil)
|
||||
|
||||
// Save the session ID
|
||||
sessionID := session.ID()
|
||||
|
||||
// Delete the session
|
||||
err = store.Delete(sessionID)
|
||||
utils.AssertEqual(t, err, nil)
|
||||
|
||||
// Try to get the session again
|
||||
session, err = store.Get(ctx)
|
||||
utils.AssertEqual(t, err, nil)
|
||||
|
||||
// The session ID should be different now, because the old session was deleted
|
||||
utils.AssertEqual(t, session.ID() == sessionID, false)
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue