Adding upstream version 2.52.6.
Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
parent
a960158181
commit
6d002e9543
441 changed files with 95392 additions and 0 deletions
128
middleware/limiter/config.go
Normal file
128
middleware/limiter/config.go
Normal file
|
@ -0,0 +1,128 @@
|
|||
package limiter
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/gofiber/fiber/v2/log"
|
||||
)
|
||||
|
||||
// Config defines the config for middleware.
|
||||
type Config struct {
|
||||
// Next defines a function to skip this middleware when returned true.
|
||||
//
|
||||
// Optional. Default: nil
|
||||
Next func(c *fiber.Ctx) bool
|
||||
|
||||
// Max number of recent connections during `Expiration` seconds before sending a 429 response
|
||||
//
|
||||
// Default: 5
|
||||
Max int
|
||||
|
||||
// KeyGenerator allows you to generate custom keys, by default c.IP() is used
|
||||
//
|
||||
// Default: func(c *fiber.Ctx) string {
|
||||
// return c.IP()
|
||||
// }
|
||||
KeyGenerator func(*fiber.Ctx) string
|
||||
|
||||
// Expiration is the time on how long to keep records of requests in memory
|
||||
//
|
||||
// Default: 1 * time.Minute
|
||||
Expiration time.Duration
|
||||
|
||||
// LimitReached is called when a request hits the limit
|
||||
//
|
||||
// Default: func(c *fiber.Ctx) error {
|
||||
// return c.SendStatus(fiber.StatusTooManyRequests)
|
||||
// }
|
||||
LimitReached fiber.Handler
|
||||
|
||||
// When set to true, requests with StatusCode >= 400 won't be counted.
|
||||
//
|
||||
// Default: false
|
||||
SkipFailedRequests bool
|
||||
|
||||
// When set to true, requests with StatusCode < 400 won't be counted.
|
||||
//
|
||||
// Default: false
|
||||
SkipSuccessfulRequests bool
|
||||
|
||||
// Store is used to store the state of the middleware
|
||||
//
|
||||
// Default: an in memory store for this process only
|
||||
Storage fiber.Storage
|
||||
|
||||
// LimiterMiddleware is the struct that implements a limiter middleware.
|
||||
//
|
||||
// Default: a new Fixed Window Rate Limiter
|
||||
LimiterMiddleware LimiterHandler
|
||||
|
||||
// Deprecated: Use Expiration instead
|
||||
Duration time.Duration
|
||||
|
||||
// Deprecated: Use Storage instead
|
||||
Store fiber.Storage
|
||||
|
||||
// Deprecated: Use KeyGenerator instead
|
||||
Key func(*fiber.Ctx) string
|
||||
}
|
||||
|
||||
// ConfigDefault is the default config
|
||||
var ConfigDefault = Config{
|
||||
Max: 5,
|
||||
Expiration: 1 * time.Minute,
|
||||
KeyGenerator: func(c *fiber.Ctx) string {
|
||||
return c.IP()
|
||||
},
|
||||
LimitReached: func(c *fiber.Ctx) error {
|
||||
return c.SendStatus(fiber.StatusTooManyRequests)
|
||||
},
|
||||
SkipFailedRequests: false,
|
||||
SkipSuccessfulRequests: false,
|
||||
LimiterMiddleware: FixedWindow{},
|
||||
}
|
||||
|
||||
// Helper function to set default values
|
||||
func configDefault(config ...Config) Config {
|
||||
// Return default config if nothing provided
|
||||
if len(config) < 1 {
|
||||
return ConfigDefault
|
||||
}
|
||||
|
||||
// Override default config
|
||||
cfg := config[0]
|
||||
|
||||
// Set default values
|
||||
if int(cfg.Duration.Seconds()) > 0 {
|
||||
log.Warn("[LIMITER] Duration is deprecated, please use Expiration")
|
||||
cfg.Expiration = cfg.Duration
|
||||
}
|
||||
if cfg.Key != nil {
|
||||
log.Warn("[LIMITER] Key is deprecated, please us KeyGenerator")
|
||||
cfg.KeyGenerator = cfg.Key
|
||||
}
|
||||
if cfg.Store != nil {
|
||||
log.Warn("[LIMITER] Store is deprecated, please use Storage")
|
||||
cfg.Storage = cfg.Store
|
||||
}
|
||||
if cfg.Next == nil {
|
||||
cfg.Next = ConfigDefault.Next
|
||||
}
|
||||
if cfg.Max <= 0 {
|
||||
cfg.Max = ConfigDefault.Max
|
||||
}
|
||||
if int(cfg.Expiration.Seconds()) <= 0 {
|
||||
cfg.Expiration = ConfigDefault.Expiration
|
||||
}
|
||||
if cfg.KeyGenerator == nil {
|
||||
cfg.KeyGenerator = ConfigDefault.KeyGenerator
|
||||
}
|
||||
if cfg.LimitReached == nil {
|
||||
cfg.LimitReached = ConfigDefault.LimitReached
|
||||
}
|
||||
if cfg.LimiterMiddleware == nil {
|
||||
cfg.LimiterMiddleware = ConfigDefault.LimiterMiddleware
|
||||
}
|
||||
return cfg
|
||||
}
|
25
middleware/limiter/limiter.go
Normal file
25
middleware/limiter/limiter.go
Normal file
|
@ -0,0 +1,25 @@
|
|||
package limiter
|
||||
|
||||
import (
|
||||
"github.com/gofiber/fiber/v2"
|
||||
)
|
||||
|
||||
const (
|
||||
// X-RateLimit-* headers
|
||||
xRateLimitLimit = "X-RateLimit-Limit"
|
||||
xRateLimitRemaining = "X-RateLimit-Remaining"
|
||||
xRateLimitReset = "X-RateLimit-Reset"
|
||||
)
|
||||
|
||||
type LimiterHandler interface {
|
||||
New(config Config) fiber.Handler
|
||||
}
|
||||
|
||||
// New creates a new middleware handler
|
||||
func New(config ...Config) fiber.Handler {
|
||||
// Set default config
|
||||
cfg := configDefault(config...)
|
||||
|
||||
// Return the specified middleware handler.
|
||||
return cfg.LimiterMiddleware.New(cfg)
|
||||
}
|
106
middleware/limiter/limiter_fixed.go
Normal file
106
middleware/limiter/limiter_fixed.go
Normal file
|
@ -0,0 +1,106 @@
|
|||
package limiter
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/gofiber/fiber/v2/utils"
|
||||
)
|
||||
|
||||
type FixedWindow struct{}
|
||||
|
||||
// New creates a new fixed window middleware handler
|
||||
func (FixedWindow) New(cfg Config) fiber.Handler {
|
||||
var (
|
||||
// Limiter variables
|
||||
mux = &sync.RWMutex{}
|
||||
max = strconv.Itoa(cfg.Max)
|
||||
expiration = uint64(cfg.Expiration.Seconds())
|
||||
)
|
||||
|
||||
// Create manager to simplify storage operations ( see manager.go )
|
||||
manager := newManager(cfg.Storage)
|
||||
|
||||
// Update timestamp every second
|
||||
utils.StartTimeStampUpdater()
|
||||
|
||||
// Return new handler
|
||||
return func(c *fiber.Ctx) error {
|
||||
// Don't execute middleware if Next returns true
|
||||
if cfg.Next != nil && cfg.Next(c) {
|
||||
return c.Next()
|
||||
}
|
||||
|
||||
// Get key from request
|
||||
key := cfg.KeyGenerator(c)
|
||||
|
||||
// Lock entry
|
||||
mux.Lock()
|
||||
|
||||
// Get entry from pool and release when finished
|
||||
e := manager.get(key)
|
||||
|
||||
// Get timestamp
|
||||
ts := uint64(atomic.LoadUint32(&utils.Timestamp))
|
||||
|
||||
// Set expiration if entry does not exist
|
||||
if e.exp == 0 {
|
||||
e.exp = ts + expiration
|
||||
} else if ts >= e.exp {
|
||||
// Check if entry is expired
|
||||
e.currHits = 0
|
||||
e.exp = ts + expiration
|
||||
}
|
||||
|
||||
// Increment hits
|
||||
e.currHits++
|
||||
|
||||
// Calculate when it resets in seconds
|
||||
resetInSec := e.exp - ts
|
||||
|
||||
// Set how many hits we have left
|
||||
remaining := cfg.Max - e.currHits
|
||||
|
||||
// Update storage
|
||||
manager.set(key, e, cfg.Expiration)
|
||||
|
||||
// Unlock entry
|
||||
mux.Unlock()
|
||||
|
||||
// Check if hits exceed the cfg.Max
|
||||
if remaining < 0 {
|
||||
// Return response with Retry-After header
|
||||
// https://tools.ietf.org/html/rfc6584
|
||||
c.Set(fiber.HeaderRetryAfter, strconv.FormatUint(resetInSec, 10))
|
||||
|
||||
// Call LimitReached handler
|
||||
return cfg.LimitReached(c)
|
||||
}
|
||||
|
||||
// Continue stack for reaching c.Response().StatusCode()
|
||||
// Store err for returning
|
||||
err := c.Next()
|
||||
|
||||
// Check for SkipFailedRequests and SkipSuccessfulRequests
|
||||
if (cfg.SkipSuccessfulRequests && c.Response().StatusCode() < fiber.StatusBadRequest) ||
|
||||
(cfg.SkipFailedRequests && c.Response().StatusCode() >= fiber.StatusBadRequest) {
|
||||
// Lock entry
|
||||
mux.Lock()
|
||||
e = manager.get(key)
|
||||
e.currHits--
|
||||
remaining++
|
||||
manager.set(key, e, cfg.Expiration)
|
||||
// Unlock entry
|
||||
mux.Unlock()
|
||||
}
|
||||
|
||||
// We can continue, update RateLimit headers
|
||||
c.Set(xRateLimitLimit, max)
|
||||
c.Set(xRateLimitRemaining, strconv.Itoa(remaining))
|
||||
c.Set(xRateLimitReset, strconv.FormatUint(resetInSec, 10))
|
||||
|
||||
return err
|
||||
}
|
||||
}
|
137
middleware/limiter/limiter_sliding.go
Normal file
137
middleware/limiter/limiter_sliding.go
Normal file
|
@ -0,0 +1,137 @@
|
|||
package limiter
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/gofiber/fiber/v2/utils"
|
||||
)
|
||||
|
||||
type SlidingWindow struct{}
|
||||
|
||||
// New creates a new sliding window middleware handler
|
||||
func (SlidingWindow) New(cfg Config) fiber.Handler {
|
||||
var (
|
||||
// Limiter variables
|
||||
mux = &sync.RWMutex{}
|
||||
max = strconv.Itoa(cfg.Max)
|
||||
expiration = uint64(cfg.Expiration.Seconds())
|
||||
)
|
||||
|
||||
// Create manager to simplify storage operations ( see manager.go )
|
||||
manager := newManager(cfg.Storage)
|
||||
|
||||
// Update timestamp every second
|
||||
utils.StartTimeStampUpdater()
|
||||
|
||||
// Return new handler
|
||||
return func(c *fiber.Ctx) error {
|
||||
// Don't execute middleware if Next returns true
|
||||
if cfg.Next != nil && cfg.Next(c) {
|
||||
return c.Next()
|
||||
}
|
||||
|
||||
// Get key from request
|
||||
key := cfg.KeyGenerator(c)
|
||||
|
||||
// Lock entry
|
||||
mux.Lock()
|
||||
|
||||
// Get entry from pool and release when finished
|
||||
e := manager.get(key)
|
||||
|
||||
// Get timestamp
|
||||
ts := uint64(atomic.LoadUint32(&utils.Timestamp))
|
||||
|
||||
// Set expiration if entry does not exist
|
||||
if e.exp == 0 {
|
||||
e.exp = ts + expiration
|
||||
} else if ts >= e.exp {
|
||||
// The entry has expired, handle the expiration.
|
||||
// Set the prevHits to the current hits and reset the hits to 0.
|
||||
e.prevHits = e.currHits
|
||||
|
||||
// Reset the current hits to 0.
|
||||
e.currHits = 0
|
||||
|
||||
// Check how much into the current window it currently is and sets the
|
||||
// expiry based on that, otherwise this would only reset on
|
||||
// the next request and not show the correct expiry.
|
||||
elapsed := ts - e.exp
|
||||
if elapsed >= expiration {
|
||||
e.exp = ts + expiration
|
||||
} else {
|
||||
e.exp = ts + expiration - elapsed
|
||||
}
|
||||
}
|
||||
|
||||
// Increment hits
|
||||
e.currHits++
|
||||
|
||||
// Calculate when it resets in seconds
|
||||
resetInSec := e.exp - ts
|
||||
|
||||
// weight = time until current window reset / total window length
|
||||
weight := float64(resetInSec) / float64(expiration)
|
||||
|
||||
// rate = request count in previous window - weight + request count in current window
|
||||
rate := int(float64(e.prevHits)*weight) + e.currHits
|
||||
|
||||
// Calculate how many hits can be made based on the current rate
|
||||
remaining := cfg.Max - rate
|
||||
|
||||
// Update storage. Garbage collect when the next window ends.
|
||||
// |--------------------------|--------------------------|
|
||||
// ^ ^ ^ ^
|
||||
// ts e.exp End sample window End next window
|
||||
// <------------>
|
||||
// resetInSec
|
||||
// resetInSec = e.exp - ts - time until end of current window.
|
||||
// duration + expiration = end of next window.
|
||||
// Because we don't want to garbage collect in the middle of a window
|
||||
// we add the expiration to the duration.
|
||||
// Otherwise after the end of "sample window", attackers could launch
|
||||
// a new request with the full window length.
|
||||
manager.set(key, e, time.Duration(resetInSec+expiration)*time.Second)
|
||||
|
||||
// Unlock entry
|
||||
mux.Unlock()
|
||||
|
||||
// Check if hits exceed the cfg.Max
|
||||
if remaining < 0 {
|
||||
// Return response with Retry-After header
|
||||
// https://tools.ietf.org/html/rfc6584
|
||||
c.Set(fiber.HeaderRetryAfter, strconv.FormatUint(resetInSec, 10))
|
||||
|
||||
// Call LimitReached handler
|
||||
return cfg.LimitReached(c)
|
||||
}
|
||||
|
||||
// Continue stack for reaching c.Response().StatusCode()
|
||||
// Store err for returning
|
||||
err := c.Next()
|
||||
|
||||
// Check for SkipFailedRequests and SkipSuccessfulRequests
|
||||
if (cfg.SkipSuccessfulRequests && c.Response().StatusCode() < fiber.StatusBadRequest) ||
|
||||
(cfg.SkipFailedRequests && c.Response().StatusCode() >= fiber.StatusBadRequest) {
|
||||
// Lock entry
|
||||
mux.Lock()
|
||||
e = manager.get(key)
|
||||
e.currHits--
|
||||
remaining++
|
||||
manager.set(key, e, cfg.Expiration)
|
||||
// Unlock entry
|
||||
mux.Unlock()
|
||||
}
|
||||
|
||||
// We can continue, update RateLimit headers
|
||||
c.Set(xRateLimitLimit, max)
|
||||
c.Set(xRateLimitRemaining, strconv.Itoa(remaining))
|
||||
c.Set(xRateLimitReset, strconv.FormatUint(resetInSec, 10))
|
||||
|
||||
return err
|
||||
}
|
||||
}
|
727
middleware/limiter/limiter_test.go
Normal file
727
middleware/limiter/limiter_test.go
Normal file
|
@ -0,0 +1,727 @@
|
|||
package limiter
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net/http/httptest"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/gofiber/fiber/v2/internal/storage/memory"
|
||||
"github.com/gofiber/fiber/v2/utils"
|
||||
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
// go test -run Test_Limiter_Concurrency_Store -race -v
|
||||
func Test_Limiter_Concurrency_Store(t *testing.T) {
|
||||
t.Parallel()
|
||||
// Test concurrency using a custom store
|
||||
|
||||
app := fiber.New()
|
||||
|
||||
app.Use(New(Config{
|
||||
Max: 50,
|
||||
Expiration: 2 * time.Second,
|
||||
Storage: memory.New(),
|
||||
}))
|
||||
|
||||
app.Get("/", func(c *fiber.Ctx) error {
|
||||
return c.SendString("Hello tester!")
|
||||
})
|
||||
|
||||
var wg sync.WaitGroup
|
||||
singleRequest := func(wg *sync.WaitGroup) {
|
||||
defer wg.Done()
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, "Hello tester!", string(body))
|
||||
}
|
||||
|
||||
for i := 0; i <= 49; i++ {
|
||||
wg.Add(1)
|
||||
go singleRequest(&wg)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, 429, resp.StatusCode)
|
||||
|
||||
time.Sleep(3 * time.Second)
|
||||
|
||||
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, 200, resp.StatusCode)
|
||||
}
|
||||
|
||||
// go test -run Test_Limiter_Concurrency -race -v
|
||||
func Test_Limiter_Concurrency(t *testing.T) {
|
||||
t.Parallel()
|
||||
// Test concurrency using a default store
|
||||
|
||||
app := fiber.New()
|
||||
|
||||
app.Use(New(Config{
|
||||
Max: 50,
|
||||
Expiration: 2 * time.Second,
|
||||
}))
|
||||
|
||||
app.Get("/", func(c *fiber.Ctx) error {
|
||||
return c.SendString("Hello tester!")
|
||||
})
|
||||
|
||||
var wg sync.WaitGroup
|
||||
singleRequest := func(wg *sync.WaitGroup) {
|
||||
defer wg.Done()
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, "Hello tester!", string(body))
|
||||
}
|
||||
|
||||
for i := 0; i <= 49; i++ {
|
||||
wg.Add(1)
|
||||
go singleRequest(&wg)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, 429, resp.StatusCode)
|
||||
|
||||
time.Sleep(3 * time.Second)
|
||||
|
||||
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, 200, resp.StatusCode)
|
||||
}
|
||||
|
||||
// go test -run Test_Limiter_Fixed_Window_No_Skip_Choices -v
|
||||
func Test_Limiter_Fixed_Window_No_Skip_Choices(t *testing.T) {
|
||||
t.Parallel()
|
||||
app := fiber.New()
|
||||
|
||||
app.Use(New(Config{
|
||||
Max: 2,
|
||||
Expiration: 2 * time.Second,
|
||||
SkipFailedRequests: false,
|
||||
SkipSuccessfulRequests: false,
|
||||
LimiterMiddleware: FixedWindow{},
|
||||
}))
|
||||
|
||||
app.Get("/:status", func(c *fiber.Ctx) error {
|
||||
if c.Params("status") == "fail" { //nolint:goconst // False positive
|
||||
return c.SendStatus(400)
|
||||
}
|
||||
return c.SendStatus(200)
|
||||
})
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/fail", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, 400, resp.StatusCode)
|
||||
|
||||
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/success", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, 200, resp.StatusCode)
|
||||
|
||||
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/success", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, 429, resp.StatusCode)
|
||||
|
||||
time.Sleep(3 * time.Second)
|
||||
|
||||
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/success", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, 200, resp.StatusCode)
|
||||
}
|
||||
|
||||
// go test -run Test_Limiter_Fixed_Window_Custom_Storage_No_Skip_Choices -v
|
||||
func Test_Limiter_Fixed_Window_Custom_Storage_No_Skip_Choices(t *testing.T) {
|
||||
t.Parallel()
|
||||
app := fiber.New()
|
||||
|
||||
app.Use(New(Config{
|
||||
Max: 2,
|
||||
Expiration: 2 * time.Second,
|
||||
SkipFailedRequests: false,
|
||||
SkipSuccessfulRequests: false,
|
||||
Storage: memory.New(),
|
||||
LimiterMiddleware: FixedWindow{},
|
||||
}))
|
||||
|
||||
app.Get("/:status", func(c *fiber.Ctx) error {
|
||||
if c.Params("status") == "fail" {
|
||||
return c.SendStatus(400)
|
||||
}
|
||||
return c.SendStatus(200)
|
||||
})
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/fail", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, 400, resp.StatusCode)
|
||||
|
||||
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/success", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, 200, resp.StatusCode)
|
||||
|
||||
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/success", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, 429, resp.StatusCode)
|
||||
|
||||
time.Sleep(3 * time.Second)
|
||||
|
||||
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/success", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, 200, resp.StatusCode)
|
||||
}
|
||||
|
||||
// go test -run Test_Limiter_Sliding_Window_No_Skip_Choices -v
|
||||
func Test_Limiter_Sliding_Window_No_Skip_Choices(t *testing.T) {
|
||||
t.Parallel()
|
||||
app := fiber.New()
|
||||
|
||||
app.Use(New(Config{
|
||||
Max: 2,
|
||||
Expiration: 2 * time.Second,
|
||||
SkipFailedRequests: false,
|
||||
SkipSuccessfulRequests: false,
|
||||
LimiterMiddleware: SlidingWindow{},
|
||||
}))
|
||||
|
||||
app.Get("/:status", func(c *fiber.Ctx) error {
|
||||
if c.Params("status") == "fail" {
|
||||
return c.SendStatus(400)
|
||||
}
|
||||
return c.SendStatus(200)
|
||||
})
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/fail", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, 400, resp.StatusCode)
|
||||
|
||||
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/success", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, 200, resp.StatusCode)
|
||||
|
||||
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/success", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, 429, resp.StatusCode)
|
||||
|
||||
time.Sleep(4*time.Second + 500*time.Millisecond)
|
||||
|
||||
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/success", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, 200, resp.StatusCode)
|
||||
}
|
||||
|
||||
// go test -run Test_Limiter_Sliding_Window_Custom_Storage_No_Skip_Choices -v
|
||||
func Test_Limiter_Sliding_Window_Custom_Storage_No_Skip_Choices(t *testing.T) {
|
||||
t.Parallel()
|
||||
app := fiber.New()
|
||||
|
||||
app.Use(New(Config{
|
||||
Max: 2,
|
||||
Expiration: 2 * time.Second,
|
||||
SkipFailedRequests: false,
|
||||
SkipSuccessfulRequests: false,
|
||||
Storage: memory.New(),
|
||||
LimiterMiddleware: SlidingWindow{},
|
||||
}))
|
||||
|
||||
app.Get("/:status", func(c *fiber.Ctx) error {
|
||||
if c.Params("status") == "fail" {
|
||||
return c.SendStatus(400)
|
||||
}
|
||||
return c.SendStatus(200)
|
||||
})
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/fail", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, 400, resp.StatusCode)
|
||||
|
||||
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/success", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, 200, resp.StatusCode)
|
||||
|
||||
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/success", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, 429, resp.StatusCode)
|
||||
|
||||
time.Sleep(4*time.Second + 500*time.Millisecond)
|
||||
|
||||
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/success", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, 200, resp.StatusCode)
|
||||
}
|
||||
|
||||
// go test -run Test_Limiter_Fixed_Window_Skip_Failed_Requests -v
|
||||
func Test_Limiter_Fixed_Window_Skip_Failed_Requests(t *testing.T) {
|
||||
t.Parallel()
|
||||
app := fiber.New()
|
||||
|
||||
app.Use(New(Config{
|
||||
Max: 1,
|
||||
Expiration: 2 * time.Second,
|
||||
SkipFailedRequests: true,
|
||||
LimiterMiddleware: FixedWindow{},
|
||||
}))
|
||||
|
||||
app.Get("/:status", func(c *fiber.Ctx) error {
|
||||
if c.Params("status") == "fail" {
|
||||
return c.SendStatus(400)
|
||||
}
|
||||
return c.SendStatus(200)
|
||||
})
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/fail", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, 400, resp.StatusCode)
|
||||
|
||||
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/success", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, 200, resp.StatusCode)
|
||||
|
||||
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/success", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, 429, resp.StatusCode)
|
||||
|
||||
time.Sleep(3 * time.Second)
|
||||
|
||||
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/success", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, 200, resp.StatusCode)
|
||||
}
|
||||
|
||||
// go test -run Test_Limiter_Fixed_Window_Custom_Storage_Skip_Failed_Requests -v
|
||||
func Test_Limiter_Fixed_Window_Custom_Storage_Skip_Failed_Requests(t *testing.T) {
|
||||
t.Parallel()
|
||||
app := fiber.New()
|
||||
|
||||
app.Use(New(Config{
|
||||
Max: 1,
|
||||
Expiration: 2 * time.Second,
|
||||
Storage: memory.New(),
|
||||
SkipFailedRequests: true,
|
||||
LimiterMiddleware: FixedWindow{},
|
||||
}))
|
||||
|
||||
app.Get("/:status", func(c *fiber.Ctx) error {
|
||||
if c.Params("status") == "fail" {
|
||||
return c.SendStatus(400)
|
||||
}
|
||||
return c.SendStatus(200)
|
||||
})
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/fail", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, 400, resp.StatusCode)
|
||||
|
||||
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/success", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, 200, resp.StatusCode)
|
||||
|
||||
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/success", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, 429, resp.StatusCode)
|
||||
|
||||
time.Sleep(3 * time.Second)
|
||||
|
||||
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/success", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, 200, resp.StatusCode)
|
||||
}
|
||||
|
||||
// go test -run Test_Limiter_Sliding_Window_Skip_Failed_Requests -v
|
||||
func Test_Limiter_Sliding_Window_Skip_Failed_Requests(t *testing.T) {
|
||||
t.Parallel()
|
||||
app := fiber.New()
|
||||
|
||||
app.Use(New(Config{
|
||||
Max: 1,
|
||||
Expiration: 2 * time.Second,
|
||||
SkipFailedRequests: true,
|
||||
LimiterMiddleware: SlidingWindow{},
|
||||
}))
|
||||
|
||||
app.Get("/:status", func(c *fiber.Ctx) error {
|
||||
if c.Params("status") == "fail" {
|
||||
return c.SendStatus(400)
|
||||
}
|
||||
return c.SendStatus(200)
|
||||
})
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/fail", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, 400, resp.StatusCode)
|
||||
|
||||
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/success", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, 200, resp.StatusCode)
|
||||
|
||||
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/success", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, 429, resp.StatusCode)
|
||||
|
||||
time.Sleep(4*time.Second + 500*time.Millisecond)
|
||||
|
||||
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/success", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, 200, resp.StatusCode)
|
||||
}
|
||||
|
||||
// go test -run Test_Limiter_Sliding_Window_Custom_Storage_Skip_Failed_Requests -v
|
||||
func Test_Limiter_Sliding_Window_Custom_Storage_Skip_Failed_Requests(t *testing.T) {
|
||||
t.Parallel()
|
||||
app := fiber.New()
|
||||
|
||||
app.Use(New(Config{
|
||||
Max: 1,
|
||||
Expiration: 2 * time.Second,
|
||||
Storage: memory.New(),
|
||||
SkipFailedRequests: true,
|
||||
LimiterMiddleware: SlidingWindow{},
|
||||
}))
|
||||
|
||||
app.Get("/:status", func(c *fiber.Ctx) error {
|
||||
if c.Params("status") == "fail" {
|
||||
return c.SendStatus(400)
|
||||
}
|
||||
return c.SendStatus(200)
|
||||
})
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/fail", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, 400, resp.StatusCode)
|
||||
|
||||
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/success", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, 200, resp.StatusCode)
|
||||
|
||||
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/success", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, 429, resp.StatusCode)
|
||||
|
||||
time.Sleep(4*time.Second + 500*time.Millisecond)
|
||||
|
||||
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/success", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, 200, resp.StatusCode)
|
||||
}
|
||||
|
||||
// go test -run Test_Limiter_Fixed_Window_Skip_Successful_Requests -v
|
||||
func Test_Limiter_Fixed_Window_Skip_Successful_Requests(t *testing.T) {
|
||||
t.Parallel()
|
||||
// Test concurrency using a default store
|
||||
|
||||
app := fiber.New()
|
||||
|
||||
app.Use(New(Config{
|
||||
Max: 1,
|
||||
Expiration: 2 * time.Second,
|
||||
SkipSuccessfulRequests: true,
|
||||
LimiterMiddleware: FixedWindow{},
|
||||
}))
|
||||
|
||||
app.Get("/:status", func(c *fiber.Ctx) error {
|
||||
if c.Params("status") == "fail" {
|
||||
return c.SendStatus(400)
|
||||
}
|
||||
return c.SendStatus(200)
|
||||
})
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/success", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, 200, resp.StatusCode)
|
||||
|
||||
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/fail", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, 400, resp.StatusCode)
|
||||
|
||||
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/fail", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, 429, resp.StatusCode)
|
||||
|
||||
time.Sleep(3 * time.Second)
|
||||
|
||||
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/fail", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, 400, resp.StatusCode)
|
||||
}
|
||||
|
||||
// go test -run Test_Limiter_Fixed_Window_Custom_Storage_Skip_Successful_Requests -v
|
||||
func Test_Limiter_Fixed_Window_Custom_Storage_Skip_Successful_Requests(t *testing.T) {
|
||||
t.Parallel()
|
||||
// Test concurrency using a default store
|
||||
|
||||
app := fiber.New()
|
||||
|
||||
app.Use(New(Config{
|
||||
Max: 1,
|
||||
Expiration: 2 * time.Second,
|
||||
Storage: memory.New(),
|
||||
SkipSuccessfulRequests: true,
|
||||
LimiterMiddleware: FixedWindow{},
|
||||
}))
|
||||
|
||||
app.Get("/:status", func(c *fiber.Ctx) error {
|
||||
if c.Params("status") == "fail" {
|
||||
return c.SendStatus(400)
|
||||
}
|
||||
return c.SendStatus(200)
|
||||
})
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/success", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, 200, resp.StatusCode)
|
||||
|
||||
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/fail", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, 400, resp.StatusCode)
|
||||
|
||||
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/fail", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, 429, resp.StatusCode)
|
||||
|
||||
time.Sleep(3 * time.Second)
|
||||
|
||||
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/fail", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, 400, resp.StatusCode)
|
||||
}
|
||||
|
||||
// go test -run Test_Limiter_Sliding_Window_Skip_Successful_Requests -v
|
||||
func Test_Limiter_Sliding_Window_Skip_Successful_Requests(t *testing.T) {
|
||||
t.Parallel()
|
||||
// Test concurrency using a default store
|
||||
|
||||
app := fiber.New()
|
||||
|
||||
app.Use(New(Config{
|
||||
Max: 1,
|
||||
Expiration: 2 * time.Second,
|
||||
SkipSuccessfulRequests: true,
|
||||
LimiterMiddleware: SlidingWindow{},
|
||||
}))
|
||||
|
||||
app.Get("/:status", func(c *fiber.Ctx) error {
|
||||
if c.Params("status") == "fail" {
|
||||
return c.SendStatus(400)
|
||||
}
|
||||
return c.SendStatus(200)
|
||||
})
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/success", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, 200, resp.StatusCode)
|
||||
|
||||
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/fail", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, 400, resp.StatusCode)
|
||||
|
||||
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/fail", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, 429, resp.StatusCode)
|
||||
|
||||
time.Sleep(4*time.Second + 500*time.Millisecond)
|
||||
|
||||
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/fail", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, 400, resp.StatusCode)
|
||||
}
|
||||
|
||||
// go test -run Test_Limiter_Sliding_Window_Custom_Storage_Skip_Successful_Requests -v
|
||||
func Test_Limiter_Sliding_Window_Custom_Storage_Skip_Successful_Requests(t *testing.T) {
|
||||
t.Parallel()
|
||||
// Test concurrency using a default store
|
||||
|
||||
app := fiber.New()
|
||||
|
||||
app.Use(New(Config{
|
||||
Max: 1,
|
||||
Expiration: 2 * time.Second,
|
||||
Storage: memory.New(),
|
||||
SkipSuccessfulRequests: true,
|
||||
LimiterMiddleware: SlidingWindow{},
|
||||
}))
|
||||
|
||||
app.Get("/:status", func(c *fiber.Ctx) error {
|
||||
if c.Params("status") == "fail" {
|
||||
return c.SendStatus(400)
|
||||
}
|
||||
return c.SendStatus(200)
|
||||
})
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/success", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, 200, resp.StatusCode)
|
||||
|
||||
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/fail", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, 400, resp.StatusCode)
|
||||
|
||||
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/fail", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, 429, resp.StatusCode)
|
||||
|
||||
time.Sleep(4*time.Second + 500*time.Millisecond)
|
||||
|
||||
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/fail", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, 400, resp.StatusCode)
|
||||
}
|
||||
|
||||
// go test -v -run=^$ -bench=Benchmark_Limiter_Custom_Store -benchmem -count=4
|
||||
func Benchmark_Limiter_Custom_Store(b *testing.B) {
|
||||
app := fiber.New()
|
||||
|
||||
app.Use(New(Config{
|
||||
Max: 100,
|
||||
Expiration: 60 * time.Second,
|
||||
Storage: memory.New(),
|
||||
}))
|
||||
|
||||
app.Get("/", func(c *fiber.Ctx) error {
|
||||
return c.SendString("Hello, World!")
|
||||
})
|
||||
|
||||
h := app.Handler()
|
||||
|
||||
fctx := &fasthttp.RequestCtx{}
|
||||
fctx.Request.Header.SetMethod(fiber.MethodGet)
|
||||
fctx.Request.SetRequestURI("/")
|
||||
|
||||
b.ResetTimer()
|
||||
|
||||
for n := 0; n < b.N; n++ {
|
||||
h(fctx)
|
||||
}
|
||||
}
|
||||
|
||||
// go test -run Test_Limiter_Next
|
||||
func Test_Limiter_Next(t *testing.T) {
|
||||
t.Parallel()
|
||||
app := fiber.New()
|
||||
app.Use(New(Config{
|
||||
Next: func(_ *fiber.Ctx) bool {
|
||||
return true
|
||||
},
|
||||
}))
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, fiber.StatusNotFound, resp.StatusCode)
|
||||
}
|
||||
|
||||
func Test_Limiter_Headers(t *testing.T) {
|
||||
t.Parallel()
|
||||
app := fiber.New()
|
||||
|
||||
app.Use(New(Config{
|
||||
Max: 50,
|
||||
Expiration: 2 * time.Second,
|
||||
}))
|
||||
|
||||
app.Get("/", func(c *fiber.Ctx) error {
|
||||
return c.SendString("Hello tester!")
|
||||
})
|
||||
|
||||
fctx := &fasthttp.RequestCtx{}
|
||||
fctx.Request.Header.SetMethod(fiber.MethodGet)
|
||||
fctx.Request.SetRequestURI("/")
|
||||
|
||||
app.Handler()(fctx)
|
||||
|
||||
utils.AssertEqual(t, "50", string(fctx.Response.Header.Peek("X-RateLimit-Limit")))
|
||||
if v := string(fctx.Response.Header.Peek("X-RateLimit-Remaining")); v == "" {
|
||||
t.Errorf("The X-RateLimit-Remaining header is not set correctly - value is an empty string.")
|
||||
}
|
||||
if v := string(fctx.Response.Header.Peek("X-RateLimit-Reset")); !(v == "1" || v == "2") {
|
||||
t.Errorf("The X-RateLimit-Reset header is not set correctly - value is out of bounds.")
|
||||
}
|
||||
}
|
||||
|
||||
// go test -v -run=^$ -bench=Benchmark_Limiter -benchmem -count=4
|
||||
func Benchmark_Limiter(b *testing.B) {
|
||||
app := fiber.New()
|
||||
|
||||
app.Use(New(Config{
|
||||
Max: 100,
|
||||
Expiration: 60 * time.Second,
|
||||
}))
|
||||
|
||||
app.Get("/", func(c *fiber.Ctx) error {
|
||||
return c.SendString("Hello, World!")
|
||||
})
|
||||
|
||||
h := app.Handler()
|
||||
|
||||
fctx := &fasthttp.RequestCtx{}
|
||||
fctx.Request.Header.SetMethod(fiber.MethodGet)
|
||||
fctx.Request.SetRequestURI("/")
|
||||
|
||||
b.ResetTimer()
|
||||
|
||||
for n := 0; n < b.N; n++ {
|
||||
h(fctx)
|
||||
}
|
||||
}
|
||||
|
||||
// go test -run Test_Sliding_Window -race -v
|
||||
func Test_Sliding_Window(t *testing.T) {
|
||||
t.Parallel()
|
||||
app := fiber.New()
|
||||
app.Use(New(Config{
|
||||
Max: 10,
|
||||
Expiration: 2 * time.Second,
|
||||
Storage: memory.New(),
|
||||
LimiterMiddleware: SlidingWindow{},
|
||||
}))
|
||||
|
||||
app.Get("/", func(c *fiber.Ctx) error {
|
||||
return c.SendString("Hello tester!")
|
||||
})
|
||||
|
||||
singleRequest := func(shouldFail bool) {
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||
if shouldFail {
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, 429, resp.StatusCode)
|
||||
} else {
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
|
||||
}
|
||||
}
|
||||
|
||||
for i := 0; i < 5; i++ {
|
||||
singleRequest(false)
|
||||
}
|
||||
|
||||
time.Sleep(2 * time.Second)
|
||||
|
||||
for i := 0; i < 5; i++ {
|
||||
singleRequest(false)
|
||||
}
|
||||
|
||||
time.Sleep(3 * time.Second)
|
||||
|
||||
for i := 0; i < 5; i++ {
|
||||
singleRequest(false)
|
||||
}
|
||||
|
||||
time.Sleep(4 * time.Second)
|
||||
|
||||
for i := 0; i < 9; i++ {
|
||||
singleRequest(false)
|
||||
}
|
||||
}
|
92
middleware/limiter/manager.go
Normal file
92
middleware/limiter/manager.go
Normal file
|
@ -0,0 +1,92 @@
|
|||
package limiter
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/gofiber/fiber/v2/internal/memory"
|
||||
)
|
||||
|
||||
// go:generate msgp
|
||||
// msgp -file="manager.go" -o="manager_msgp.go" -tests=false -unexported
|
||||
type item struct {
|
||||
currHits int
|
||||
prevHits int
|
||||
exp uint64
|
||||
}
|
||||
|
||||
//msgp:ignore manager
|
||||
type manager struct {
|
||||
pool sync.Pool
|
||||
memory *memory.Storage
|
||||
storage fiber.Storage
|
||||
}
|
||||
|
||||
func newManager(storage fiber.Storage) *manager {
|
||||
// Create new storage handler
|
||||
manager := &manager{
|
||||
pool: sync.Pool{
|
||||
New: func() interface{} {
|
||||
return new(item)
|
||||
},
|
||||
},
|
||||
}
|
||||
if storage != nil {
|
||||
// Use provided storage if provided
|
||||
manager.storage = storage
|
||||
} else {
|
||||
// Fallback too memory storage
|
||||
manager.memory = memory.New()
|
||||
}
|
||||
return manager
|
||||
}
|
||||
|
||||
// acquire returns an *entry from the sync.Pool
|
||||
func (m *manager) acquire() *item {
|
||||
return m.pool.Get().(*item) //nolint:forcetypeassert // We store nothing else in the pool
|
||||
}
|
||||
|
||||
// release and reset *entry to sync.Pool
|
||||
func (m *manager) release(e *item) {
|
||||
e.prevHits = 0
|
||||
e.currHits = 0
|
||||
e.exp = 0
|
||||
m.pool.Put(e)
|
||||
}
|
||||
|
||||
// get data from storage or memory
|
||||
func (m *manager) get(key string) *item {
|
||||
var it *item
|
||||
if m.storage != nil {
|
||||
it = m.acquire()
|
||||
raw, err := m.storage.Get(key)
|
||||
if err != nil {
|
||||
return it
|
||||
}
|
||||
if raw != nil {
|
||||
if _, err := it.UnmarshalMsg(raw); err != nil {
|
||||
return it
|
||||
}
|
||||
}
|
||||
return it
|
||||
}
|
||||
if it, _ = m.memory.Get(key).(*item); it == nil { //nolint:errcheck // We store nothing else in the pool
|
||||
it = m.acquire()
|
||||
return it
|
||||
}
|
||||
return it
|
||||
}
|
||||
|
||||
// set data to storage or memory
|
||||
func (m *manager) set(key string, it *item, exp time.Duration) {
|
||||
if m.storage != nil {
|
||||
if raw, err := it.MarshalMsg(nil); err == nil {
|
||||
_ = m.storage.Set(key, raw, exp) //nolint:errcheck // TODO: Handle error here
|
||||
}
|
||||
// we can release data because it's serialized to database
|
||||
m.release(it)
|
||||
} else {
|
||||
m.memory.Set(key, it, exp)
|
||||
}
|
||||
}
|
160
middleware/limiter/manager_msgp.go
Normal file
160
middleware/limiter/manager_msgp.go
Normal file
|
@ -0,0 +1,160 @@
|
|||
package limiter
|
||||
|
||||
// Code generated by github.com/tinylib/msgp DO NOT EDIT.
|
||||
|
||||
import (
|
||||
"github.com/tinylib/msgp/msgp"
|
||||
)
|
||||
|
||||
// DecodeMsg implements msgp.Decodable
|
||||
func (z *item) DecodeMsg(dc *msgp.Reader) (err error) {
|
||||
var field []byte
|
||||
_ = field
|
||||
var zb0001 uint32
|
||||
zb0001, err = dc.ReadMapHeader()
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err)
|
||||
return
|
||||
}
|
||||
for zb0001 > 0 {
|
||||
zb0001--
|
||||
field, err = dc.ReadMapKeyPtr()
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err)
|
||||
return
|
||||
}
|
||||
switch msgp.UnsafeString(field) {
|
||||
case "currHits":
|
||||
z.currHits, err = dc.ReadInt()
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err, "currHits")
|
||||
return
|
||||
}
|
||||
case "prevHits":
|
||||
z.prevHits, err = dc.ReadInt()
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err, "prevHits")
|
||||
return
|
||||
}
|
||||
case "exp":
|
||||
z.exp, err = dc.ReadUint64()
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err, "exp")
|
||||
return
|
||||
}
|
||||
default:
|
||||
err = dc.Skip()
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// EncodeMsg implements msgp.Encodable
|
||||
func (z item) EncodeMsg(en *msgp.Writer) (err error) {
|
||||
// map header, size 3
|
||||
// write "currHits"
|
||||
err = en.Append(0x83, 0xa8, 0x63, 0x75, 0x72, 0x72, 0x48, 0x69, 0x74, 0x73)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
err = en.WriteInt(z.currHits)
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err, "currHits")
|
||||
return
|
||||
}
|
||||
// write "prevHits"
|
||||
err = en.Append(0xa8, 0x70, 0x72, 0x65, 0x76, 0x48, 0x69, 0x74, 0x73)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
err = en.WriteInt(z.prevHits)
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err, "prevHits")
|
||||
return
|
||||
}
|
||||
// write "exp"
|
||||
err = en.Append(0xa3, 0x65, 0x78, 0x70)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
err = en.WriteUint64(z.exp)
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err, "exp")
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// MarshalMsg implements msgp.Marshaler
|
||||
func (z item) MarshalMsg(b []byte) (o []byte, err error) {
|
||||
o = msgp.Require(b, z.Msgsize())
|
||||
// map header, size 3
|
||||
// string "currHits"
|
||||
o = append(o, 0x83, 0xa8, 0x63, 0x75, 0x72, 0x72, 0x48, 0x69, 0x74, 0x73)
|
||||
o = msgp.AppendInt(o, z.currHits)
|
||||
// string "prevHits"
|
||||
o = append(o, 0xa8, 0x70, 0x72, 0x65, 0x76, 0x48, 0x69, 0x74, 0x73)
|
||||
o = msgp.AppendInt(o, z.prevHits)
|
||||
// string "exp"
|
||||
o = append(o, 0xa3, 0x65, 0x78, 0x70)
|
||||
o = msgp.AppendUint64(o, z.exp)
|
||||
return
|
||||
}
|
||||
|
||||
// UnmarshalMsg implements msgp.Unmarshaler
|
||||
func (z *item) UnmarshalMsg(bts []byte) (o []byte, err error) {
|
||||
var field []byte
|
||||
_ = field
|
||||
var zb0001 uint32
|
||||
zb0001, bts, err = msgp.ReadMapHeaderBytes(bts)
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err)
|
||||
return
|
||||
}
|
||||
for zb0001 > 0 {
|
||||
zb0001--
|
||||
field, bts, err = msgp.ReadMapKeyZC(bts)
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err)
|
||||
return
|
||||
}
|
||||
switch msgp.UnsafeString(field) {
|
||||
case "currHits":
|
||||
z.currHits, bts, err = msgp.ReadIntBytes(bts)
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err, "currHits")
|
||||
return
|
||||
}
|
||||
case "prevHits":
|
||||
z.prevHits, bts, err = msgp.ReadIntBytes(bts)
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err, "prevHits")
|
||||
return
|
||||
}
|
||||
case "exp":
|
||||
z.exp, bts, err = msgp.ReadUint64Bytes(bts)
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err, "exp")
|
||||
return
|
||||
}
|
||||
default:
|
||||
bts, err = msgp.Skip(bts)
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
o = bts
|
||||
return
|
||||
}
|
||||
|
||||
// Msgsize returns an upper bound estimate of the number of bytes occupied by the serialized message
|
||||
func (z item) Msgsize() (s int) {
|
||||
s = 1 + 9 + msgp.IntSize + 9 + msgp.IntSize + 4 + msgp.Uint64Size
|
||||
return
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue