
Adding upstream version 2.52.6.

Signed-off-by: Daniel Baumann <daniel@debian.org>
Daniel Baumann 2025-05-17 06:50:16 +02:00
parent a960158181
commit 6d002e9543
Signed by: daniel
GPG key ID: FBB4F0E80A80222F
441 changed files with 95392 additions and 0 deletions

@@ -0,0 +1,128 @@
package limiter

import (
	"time"

	"github.com/gofiber/fiber/v2"
	"github.com/gofiber/fiber/v2/log"
)

// Config defines the config for middleware.
type Config struct {
	// Next defines a function to skip this middleware when returned true.
	//
	// Optional. Default: nil
	Next func(c *fiber.Ctx) bool

	// Max number of recent connections during `Expiration` seconds before sending a 429 response
	//
	// Default: 5
	Max int

	// KeyGenerator allows you to generate custom keys, by default c.IP() is used
	//
	// Default: func(c *fiber.Ctx) string {
	//   return c.IP()
	// }
	KeyGenerator func(*fiber.Ctx) string

	// Expiration is the time on how long to keep records of requests in memory
	//
	// Default: 1 * time.Minute
	Expiration time.Duration

	// LimitReached is called when a request hits the limit
	//
	// Default: func(c *fiber.Ctx) error {
	//   return c.SendStatus(fiber.StatusTooManyRequests)
	// }
	LimitReached fiber.Handler

	// When set to true, requests with StatusCode >= 400 won't be counted.
	//
	// Default: false
	SkipFailedRequests bool

	// When set to true, requests with StatusCode < 400 won't be counted.
	//
	// Default: false
	SkipSuccessfulRequests bool

	// Storage is used to store the state of the middleware
	//
	// Default: an in memory store for this process only
	Storage fiber.Storage

	// LimiterMiddleware is the struct that implements a limiter middleware.
	//
	// Default: a new Fixed Window Rate Limiter
	LimiterMiddleware LimiterHandler

	// Deprecated: Use Expiration instead
	Duration time.Duration

	// Deprecated: Use Storage instead
	Store fiber.Storage

	// Deprecated: Use KeyGenerator instead
	Key func(*fiber.Ctx) string
}

// ConfigDefault is the default config
var ConfigDefault = Config{
	Max:        5,
	Expiration: 1 * time.Minute,
	KeyGenerator: func(c *fiber.Ctx) string {
		return c.IP()
	},
	LimitReached: func(c *fiber.Ctx) error {
		return c.SendStatus(fiber.StatusTooManyRequests)
	},
	SkipFailedRequests:     false,
	SkipSuccessfulRequests: false,
	LimiterMiddleware:      FixedWindow{},
}

// Helper function to set default values
func configDefault(config ...Config) Config {
	// Return default config if nothing provided
	if len(config) < 1 {
		return ConfigDefault
	}

	// Override default config
	cfg := config[0]

	// Set default values
	if int(cfg.Duration.Seconds()) > 0 {
		log.Warn("[LIMITER] Duration is deprecated, please use Expiration")
		cfg.Expiration = cfg.Duration
	}
	if cfg.Key != nil {
		log.Warn("[LIMITER] Key is deprecated, please use KeyGenerator")
		cfg.KeyGenerator = cfg.Key
	}
	if cfg.Store != nil {
		log.Warn("[LIMITER] Store is deprecated, please use Storage")
		cfg.Storage = cfg.Store
	}
	if cfg.Next == nil {
		cfg.Next = ConfigDefault.Next
	}
	if cfg.Max <= 0 {
		cfg.Max = ConfigDefault.Max
	}
	if int(cfg.Expiration.Seconds()) <= 0 {
		cfg.Expiration = ConfigDefault.Expiration
	}
	if cfg.KeyGenerator == nil {
		cfg.KeyGenerator = ConfigDefault.KeyGenerator
	}
	if cfg.LimitReached == nil {
		cfg.LimitReached = ConfigDefault.LimitReached
	}
	if cfg.LimiterMiddleware == nil {
		cfg.LimiterMiddleware = ConfigDefault.LimiterMiddleware
	}
	return cfg
}
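
For orientation (not part of this diff): a minimal usage sketch of the config above, assuming the package is imported from its upstream path github.com/gofiber/fiber/v2/middleware/limiter. Only the fields set explicitly are overridden; everything else falls back to ConfigDefault via configDefault.

package main

import (
	"time"

	"github.com/gofiber/fiber/v2"
	"github.com/gofiber/fiber/v2/middleware/limiter"
)

func main() {
	app := fiber.New()

	// Allow 20 requests per client key every 30 seconds; Max, Expiration and
	// LimitReached are overridden, the rest keeps its default.
	app.Use(limiter.New(limiter.Config{
		Max:        20,
		Expiration: 30 * time.Second,
		LimitReached: func(c *fiber.Ctx) error {
			return c.Status(fiber.StatusTooManyRequests).SendString("slow down")
		},
	}))

	app.Get("/", func(c *fiber.Ctx) error {
		return c.SendString("ok")
	})

	_ = app.Listen(":3000")
}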

@@ -0,0 +1,25 @@
package limiter

import (
	"github.com/gofiber/fiber/v2"
)

const (
	// X-RateLimit-* headers
	xRateLimitLimit     = "X-RateLimit-Limit"
	xRateLimitRemaining = "X-RateLimit-Remaining"
	xRateLimitReset     = "X-RateLimit-Reset"
)

type LimiterHandler interface {
	New(config Config) fiber.Handler
}

// New creates a new middleware handler
func New(config ...Config) fiber.Handler {
	// Set default config
	cfg := configDefault(config...)

	// Return the specified middleware handler.
	return cfg.LimiterMiddleware.New(cfg)
}
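
LimiterHandler is the seam New dispatches through; FixedWindow and SlidingWindow below are the two implementations in this change. As an illustration only, a hypothetical third implementation (not part of this diff) would plug in like this:

// allowAll is a hypothetical LimiterHandler used purely to illustrate the
// interface shape consumed by New above; it never rejects a request.
type allowAll struct{}

func (allowAll) New(cfg Config) fiber.Handler {
	return func(c *fiber.Ctx) error {
		// Honor Next just like the real implementations do.
		if cfg.Next != nil && cfg.Next(c) {
			return c.Next()
		}
		// A real algorithm would track hits in cfg.Storage and compare them
		// against cfg.Max here before deciding to call cfg.LimitReached.
		return c.Next()
	}
}

// Selected via the config, e.g.:
//	app.Use(limiter.New(limiter.Config{LimiterMiddleware: allowAll{}}))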

@@ -0,0 +1,106 @@
package limiter

import (
	"strconv"
	"sync"
	"sync/atomic"

	"github.com/gofiber/fiber/v2"
	"github.com/gofiber/fiber/v2/utils"
)

type FixedWindow struct{}

// New creates a new fixed window middleware handler
func (FixedWindow) New(cfg Config) fiber.Handler {
	var (
		// Limiter variables
		mux        = &sync.RWMutex{}
		max        = strconv.Itoa(cfg.Max)
		expiration = uint64(cfg.Expiration.Seconds())
	)

	// Create manager to simplify storage operations ( see manager.go )
	manager := newManager(cfg.Storage)

	// Update timestamp every second
	utils.StartTimeStampUpdater()

	// Return new handler
	return func(c *fiber.Ctx) error {
		// Don't execute middleware if Next returns true
		if cfg.Next != nil && cfg.Next(c) {
			return c.Next()
		}

		// Get key from request
		key := cfg.KeyGenerator(c)

		// Lock entry
		mux.Lock()

		// Get entry from pool and release when finished
		e := manager.get(key)

		// Get timestamp
		ts := uint64(atomic.LoadUint32(&utils.Timestamp))

		// Set expiration if entry does not exist
		if e.exp == 0 {
			e.exp = ts + expiration
		} else if ts >= e.exp {
			// The entry has expired: reset the hit counter and start a new window
			e.currHits = 0
			e.exp = ts + expiration
		}

		// Increment hits
		e.currHits++

		// Calculate when it resets in seconds
		resetInSec := e.exp - ts

		// Set how many hits we have left
		remaining := cfg.Max - e.currHits
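
		// Worked example with illustrative numbers (not taken from the code or
		// tests): cfg.Max = 5, expiration = 60s, currHits = 2 (including this
		// request) and resetInSec = 45 give remaining = 5 - 2 = 3, i.e. three
		// more requests are allowed before this window resets.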

		// Update storage
		manager.set(key, e, cfg.Expiration)

		// Unlock entry
		mux.Unlock()

		// Check if hits exceed the cfg.Max
		if remaining < 0 {
			// Return response with Retry-After header
			// https://tools.ietf.org/html/rfc6585
			c.Set(fiber.HeaderRetryAfter, strconv.FormatUint(resetInSec, 10))

			// Call LimitReached handler
			return cfg.LimitReached(c)
		}

		// Continue stack for reaching c.Response().StatusCode()
		// Store err for returning
		err := c.Next()

		// Check for SkipFailedRequests and SkipSuccessfulRequests
		if (cfg.SkipSuccessfulRequests && c.Response().StatusCode() < fiber.StatusBadRequest) ||
			(cfg.SkipFailedRequests && c.Response().StatusCode() >= fiber.StatusBadRequest) {
			// Lock entry
			mux.Lock()
			e = manager.get(key)
			e.currHits--
			remaining++
			manager.set(key, e, cfg.Expiration)
			// Unlock entry
			mux.Unlock()
		}

		// We can continue, update RateLimit headers
		c.Set(xRateLimitLimit, max)
		c.Set(xRateLimitRemaining, strconv.Itoa(remaining))
		c.Set(xRateLimitReset, strconv.FormatUint(resetInSec, 10))

		return err
	}
}

@@ -0,0 +1,137 @@
package limiter

import (
	"strconv"
	"sync"
	"sync/atomic"
	"time"

	"github.com/gofiber/fiber/v2"
	"github.com/gofiber/fiber/v2/utils"
)

type SlidingWindow struct{}

// New creates a new sliding window middleware handler
func (SlidingWindow) New(cfg Config) fiber.Handler {
	var (
		// Limiter variables
		mux        = &sync.RWMutex{}
		max        = strconv.Itoa(cfg.Max)
		expiration = uint64(cfg.Expiration.Seconds())
	)

	// Create manager to simplify storage operations ( see manager.go )
	manager := newManager(cfg.Storage)

	// Update timestamp every second
	utils.StartTimeStampUpdater()

	// Return new handler
	return func(c *fiber.Ctx) error {
		// Don't execute middleware if Next returns true
		if cfg.Next != nil && cfg.Next(c) {
			return c.Next()
		}

		// Get key from request
		key := cfg.KeyGenerator(c)

		// Lock entry
		mux.Lock()

		// Get entry from pool and release when finished
		e := manager.get(key)

		// Get timestamp
		ts := uint64(atomic.LoadUint32(&utils.Timestamp))

		// Set expiration if entry does not exist
		if e.exp == 0 {
			e.exp = ts + expiration
		} else if ts >= e.exp {
			// The entry has expired: carry the hits over to the previous
			// window and reset the current hit counter.
			e.prevHits = e.currHits
			e.currHits = 0

			// Check how far we already are into the new window and set the
			// expiry based on that; otherwise the expiry would only be reset
			// on the next request and would not reflect the correct value.
			elapsed := ts - e.exp
			if elapsed >= expiration {
				e.exp = ts + expiration
			} else {
				e.exp = ts + expiration - elapsed
			}
		}

		// Increment hits
		e.currHits++

		// Calculate when it resets in seconds
		resetInSec := e.exp - ts

		// weight = time until current window reset / total window length
		weight := float64(resetInSec) / float64(expiration)

		// rate = request count in previous window * weight + request count in current window
		rate := int(float64(e.prevHits)*weight) + e.currHits

		// Calculate how many hits can be made based on the current rate
		remaining := cfg.Max - rate
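
		// Worked example with illustrative numbers (not taken from the code or
		// tests): cfg.Max = 10, expiration = 60s, prevHits = 8, currHits = 3
		// (including this request) and resetInSec = 30 give weight = 30/60 = 0.5,
		// so rate = int(8*0.5) + 3 = 7 and remaining = 10 - 7 = 3.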

		// Update storage. Garbage collect when the next window ends.
		//
		// |--------------------------|--------------------------|
		//          ^                 ^                          ^
		//          ts              e.exp                End next window
		//                    (end of sample window)
		//          <----------------->
		//               resetInSec
		//
		// resetInSec = e.exp - ts, i.e. the time until the end of the current window.
		// resetInSec + expiration = end of the next window.
		// Because we don't want to garbage collect in the middle of a window
		// we add the expiration to the duration.
		// Otherwise, right after the end of the sample window an attacker
		// could start a fresh window with the full request allowance.
		manager.set(key, e, time.Duration(resetInSec+expiration)*time.Second)

		// Unlock entry
		mux.Unlock()

		// Check if hits exceed the cfg.Max
		if remaining < 0 {
			// Return response with Retry-After header
			// https://tools.ietf.org/html/rfc6585
			c.Set(fiber.HeaderRetryAfter, strconv.FormatUint(resetInSec, 10))

			// Call LimitReached handler
			return cfg.LimitReached(c)
		}

		// Continue stack for reaching c.Response().StatusCode()
		// Store err for returning
		err := c.Next()

		// Check for SkipFailedRequests and SkipSuccessfulRequests
		if (cfg.SkipSuccessfulRequests && c.Response().StatusCode() < fiber.StatusBadRequest) ||
			(cfg.SkipFailedRequests && c.Response().StatusCode() >= fiber.StatusBadRequest) {
			// Lock entry
			mux.Lock()
			e = manager.get(key)
			e.currHits--
			remaining++
			manager.set(key, e, cfg.Expiration)
			// Unlock entry
			mux.Unlock()
		}

		// We can continue, update RateLimit headers
		c.Set(xRateLimitLimit, max)
		c.Set(xRateLimitRemaining, strconv.Itoa(remaining))
		c.Set(xRateLimitReset, strconv.FormatUint(resetInSec, 10))

		return err
	}
}

@@ -0,0 +1,727 @@
package limiter
import (
"io"
"net/http/httptest"
"sync"
"testing"
"time"
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/internal/storage/memory"
"github.com/gofiber/fiber/v2/utils"
"github.com/valyala/fasthttp"
)
// go test -run Test_Limiter_Concurrency_Store -race -v
func Test_Limiter_Concurrency_Store(t *testing.T) {
t.Parallel()
// Test concurrency using a custom store
app := fiber.New()
app.Use(New(Config{
Max: 50,
Expiration: 2 * time.Second,
Storage: memory.New(),
}))
app.Get("/", func(c *fiber.Ctx) error {
return c.SendString("Hello tester!")
})
var wg sync.WaitGroup
singleRequest := func(wg *sync.WaitGroup) {
defer wg.Done()
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
body, err := io.ReadAll(resp.Body)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, "Hello tester!", string(body))
}
for i := 0; i <= 49; i++ {
wg.Add(1)
go singleRequest(&wg)
}
wg.Wait()
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 429, resp.StatusCode)
time.Sleep(3 * time.Second)
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 200, resp.StatusCode)
}
// go test -run Test_Limiter_Concurrency -race -v
func Test_Limiter_Concurrency(t *testing.T) {
t.Parallel()
// Test concurrency using a default store
app := fiber.New()
app.Use(New(Config{
Max: 50,
Expiration: 2 * time.Second,
}))
app.Get("/", func(c *fiber.Ctx) error {
return c.SendString("Hello tester!")
})
var wg sync.WaitGroup
singleRequest := func(wg *sync.WaitGroup) {
defer wg.Done()
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
body, err := io.ReadAll(resp.Body)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, "Hello tester!", string(body))
}
for i := 0; i <= 49; i++ {
wg.Add(1)
go singleRequest(&wg)
}
wg.Wait()
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 429, resp.StatusCode)
time.Sleep(3 * time.Second)
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 200, resp.StatusCode)
}
// go test -run Test_Limiter_Fixed_Window_No_Skip_Choices -v
func Test_Limiter_Fixed_Window_No_Skip_Choices(t *testing.T) {
t.Parallel()
app := fiber.New()
app.Use(New(Config{
Max: 2,
Expiration: 2 * time.Second,
SkipFailedRequests: false,
SkipSuccessfulRequests: false,
LimiterMiddleware: FixedWindow{},
}))
app.Get("/:status", func(c *fiber.Ctx) error {
if c.Params("status") == "fail" { //nolint:goconst // False positive
return c.SendStatus(400)
}
return c.SendStatus(200)
})
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/fail", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 400, resp.StatusCode)
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/success", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 200, resp.StatusCode)
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/success", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 429, resp.StatusCode)
time.Sleep(3 * time.Second)
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/success", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 200, resp.StatusCode)
}
// go test -run Test_Limiter_Fixed_Window_Custom_Storage_No_Skip_Choices -v
func Test_Limiter_Fixed_Window_Custom_Storage_No_Skip_Choices(t *testing.T) {
t.Parallel()
app := fiber.New()
app.Use(New(Config{
Max: 2,
Expiration: 2 * time.Second,
SkipFailedRequests: false,
SkipSuccessfulRequests: false,
Storage: memory.New(),
LimiterMiddleware: FixedWindow{},
}))
app.Get("/:status", func(c *fiber.Ctx) error {
if c.Params("status") == "fail" {
return c.SendStatus(400)
}
return c.SendStatus(200)
})
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/fail", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 400, resp.StatusCode)
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/success", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 200, resp.StatusCode)
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/success", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 429, resp.StatusCode)
time.Sleep(3 * time.Second)
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/success", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 200, resp.StatusCode)
}
// go test -run Test_Limiter_Sliding_Window_No_Skip_Choices -v
func Test_Limiter_Sliding_Window_No_Skip_Choices(t *testing.T) {
t.Parallel()
app := fiber.New()
app.Use(New(Config{
Max: 2,
Expiration: 2 * time.Second,
SkipFailedRequests: false,
SkipSuccessfulRequests: false,
LimiterMiddleware: SlidingWindow{},
}))
app.Get("/:status", func(c *fiber.Ctx) error {
if c.Params("status") == "fail" {
return c.SendStatus(400)
}
return c.SendStatus(200)
})
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/fail", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 400, resp.StatusCode)
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/success", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 200, resp.StatusCode)
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/success", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 429, resp.StatusCode)
time.Sleep(4*time.Second + 500*time.Millisecond)
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/success", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 200, resp.StatusCode)
}
// go test -run Test_Limiter_Sliding_Window_Custom_Storage_No_Skip_Choices -v
func Test_Limiter_Sliding_Window_Custom_Storage_No_Skip_Choices(t *testing.T) {
t.Parallel()
app := fiber.New()
app.Use(New(Config{
Max: 2,
Expiration: 2 * time.Second,
SkipFailedRequests: false,
SkipSuccessfulRequests: false,
Storage: memory.New(),
LimiterMiddleware: SlidingWindow{},
}))
app.Get("/:status", func(c *fiber.Ctx) error {
if c.Params("status") == "fail" {
return c.SendStatus(400)
}
return c.SendStatus(200)
})
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/fail", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 400, resp.StatusCode)
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/success", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 200, resp.StatusCode)
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/success", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 429, resp.StatusCode)
time.Sleep(4*time.Second + 500*time.Millisecond)
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/success", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 200, resp.StatusCode)
}
// go test -run Test_Limiter_Fixed_Window_Skip_Failed_Requests -v
func Test_Limiter_Fixed_Window_Skip_Failed_Requests(t *testing.T) {
t.Parallel()
app := fiber.New()
app.Use(New(Config{
Max: 1,
Expiration: 2 * time.Second,
SkipFailedRequests: true,
LimiterMiddleware: FixedWindow{},
}))
app.Get("/:status", func(c *fiber.Ctx) error {
if c.Params("status") == "fail" {
return c.SendStatus(400)
}
return c.SendStatus(200)
})
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/fail", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 400, resp.StatusCode)
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/success", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 200, resp.StatusCode)
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/success", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 429, resp.StatusCode)
time.Sleep(3 * time.Second)
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/success", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 200, resp.StatusCode)
}
// go test -run Test_Limiter_Fixed_Window_Custom_Storage_Skip_Failed_Requests -v
func Test_Limiter_Fixed_Window_Custom_Storage_Skip_Failed_Requests(t *testing.T) {
t.Parallel()
app := fiber.New()
app.Use(New(Config{
Max: 1,
Expiration: 2 * time.Second,
Storage: memory.New(),
SkipFailedRequests: true,
LimiterMiddleware: FixedWindow{},
}))
app.Get("/:status", func(c *fiber.Ctx) error {
if c.Params("status") == "fail" {
return c.SendStatus(400)
}
return c.SendStatus(200)
})
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/fail", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 400, resp.StatusCode)
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/success", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 200, resp.StatusCode)
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/success", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 429, resp.StatusCode)
time.Sleep(3 * time.Second)
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/success", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 200, resp.StatusCode)
}
// go test -run Test_Limiter_Sliding_Window_Skip_Failed_Requests -v
func Test_Limiter_Sliding_Window_Skip_Failed_Requests(t *testing.T) {
t.Parallel()
app := fiber.New()
app.Use(New(Config{
Max: 1,
Expiration: 2 * time.Second,
SkipFailedRequests: true,
LimiterMiddleware: SlidingWindow{},
}))
app.Get("/:status", func(c *fiber.Ctx) error {
if c.Params("status") == "fail" {
return c.SendStatus(400)
}
return c.SendStatus(200)
})
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/fail", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 400, resp.StatusCode)
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/success", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 200, resp.StatusCode)
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/success", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 429, resp.StatusCode)
time.Sleep(4*time.Second + 500*time.Millisecond)
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/success", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 200, resp.StatusCode)
}
// go test -run Test_Limiter_Sliding_Window_Custom_Storage_Skip_Failed_Requests -v
func Test_Limiter_Sliding_Window_Custom_Storage_Skip_Failed_Requests(t *testing.T) {
t.Parallel()
app := fiber.New()
app.Use(New(Config{
Max: 1,
Expiration: 2 * time.Second,
Storage: memory.New(),
SkipFailedRequests: true,
LimiterMiddleware: SlidingWindow{},
}))
app.Get("/:status", func(c *fiber.Ctx) error {
if c.Params("status") == "fail" {
return c.SendStatus(400)
}
return c.SendStatus(200)
})
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/fail", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 400, resp.StatusCode)
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/success", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 200, resp.StatusCode)
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/success", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 429, resp.StatusCode)
time.Sleep(4*time.Second + 500*time.Millisecond)
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/success", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 200, resp.StatusCode)
}
// go test -run Test_Limiter_Fixed_Window_Skip_Successful_Requests -v
func Test_Limiter_Fixed_Window_Skip_Successful_Requests(t *testing.T) {
t.Parallel()
// Test SkipSuccessfulRequests using the default store
app := fiber.New()
app.Use(New(Config{
Max: 1,
Expiration: 2 * time.Second,
SkipSuccessfulRequests: true,
LimiterMiddleware: FixedWindow{},
}))
app.Get("/:status", func(c *fiber.Ctx) error {
if c.Params("status") == "fail" {
return c.SendStatus(400)
}
return c.SendStatus(200)
})
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/success", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 200, resp.StatusCode)
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/fail", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 400, resp.StatusCode)
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/fail", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 429, resp.StatusCode)
time.Sleep(3 * time.Second)
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/fail", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 400, resp.StatusCode)
}
// go test -run Test_Limiter_Fixed_Window_Custom_Storage_Skip_Successful_Requests -v
func Test_Limiter_Fixed_Window_Custom_Storage_Skip_Successful_Requests(t *testing.T) {
t.Parallel()
// Test SkipSuccessfulRequests using a custom store
app := fiber.New()
app.Use(New(Config{
Max: 1,
Expiration: 2 * time.Second,
Storage: memory.New(),
SkipSuccessfulRequests: true,
LimiterMiddleware: FixedWindow{},
}))
app.Get("/:status", func(c *fiber.Ctx) error {
if c.Params("status") == "fail" {
return c.SendStatus(400)
}
return c.SendStatus(200)
})
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/success", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 200, resp.StatusCode)
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/fail", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 400, resp.StatusCode)
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/fail", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 429, resp.StatusCode)
time.Sleep(3 * time.Second)
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/fail", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 400, resp.StatusCode)
}
// go test -run Test_Limiter_Sliding_Window_Skip_Successful_Requests -v
func Test_Limiter_Sliding_Window_Skip_Successful_Requests(t *testing.T) {
t.Parallel()
// Test SkipSuccessfulRequests using the default store
app := fiber.New()
app.Use(New(Config{
Max: 1,
Expiration: 2 * time.Second,
SkipSuccessfulRequests: true,
LimiterMiddleware: SlidingWindow{},
}))
app.Get("/:status", func(c *fiber.Ctx) error {
if c.Params("status") == "fail" {
return c.SendStatus(400)
}
return c.SendStatus(200)
})
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/success", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 200, resp.StatusCode)
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/fail", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 400, resp.StatusCode)
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/fail", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 429, resp.StatusCode)
time.Sleep(4*time.Second + 500*time.Millisecond)
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/fail", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 400, resp.StatusCode)
}
// go test -run Test_Limiter_Sliding_Window_Custom_Storage_Skip_Successful_Requests -v
func Test_Limiter_Sliding_Window_Custom_Storage_Skip_Successful_Requests(t *testing.T) {
t.Parallel()
// Test SkipSuccessfulRequests using a custom store
app := fiber.New()
app.Use(New(Config{
Max: 1,
Expiration: 2 * time.Second,
Storage: memory.New(),
SkipSuccessfulRequests: true,
LimiterMiddleware: SlidingWindow{},
}))
app.Get("/:status", func(c *fiber.Ctx) error {
if c.Params("status") == "fail" {
return c.SendStatus(400)
}
return c.SendStatus(200)
})
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/success", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 200, resp.StatusCode)
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/fail", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 400, resp.StatusCode)
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/fail", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 429, resp.StatusCode)
time.Sleep(4*time.Second + 500*time.Millisecond)
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/fail", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 400, resp.StatusCode)
}
// go test -v -run=^$ -bench=Benchmark_Limiter_Custom_Store -benchmem -count=4
func Benchmark_Limiter_Custom_Store(b *testing.B) {
app := fiber.New()
app.Use(New(Config{
Max: 100,
Expiration: 60 * time.Second,
Storage: memory.New(),
}))
app.Get("/", func(c *fiber.Ctx) error {
return c.SendString("Hello, World!")
})
h := app.Handler()
fctx := &fasthttp.RequestCtx{}
fctx.Request.Header.SetMethod(fiber.MethodGet)
fctx.Request.SetRequestURI("/")
b.ResetTimer()
for n := 0; n < b.N; n++ {
h(fctx)
}
}
// go test -run Test_Limiter_Next
func Test_Limiter_Next(t *testing.T) {
t.Parallel()
app := fiber.New()
app.Use(New(Config{
Next: func(_ *fiber.Ctx) bool {
return true
},
}))
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusNotFound, resp.StatusCode)
}
func Test_Limiter_Headers(t *testing.T) {
t.Parallel()
app := fiber.New()
app.Use(New(Config{
Max: 50,
Expiration: 2 * time.Second,
}))
app.Get("/", func(c *fiber.Ctx) error {
return c.SendString("Hello tester!")
})
fctx := &fasthttp.RequestCtx{}
fctx.Request.Header.SetMethod(fiber.MethodGet)
fctx.Request.SetRequestURI("/")
app.Handler()(fctx)
utils.AssertEqual(t, "50", string(fctx.Response.Header.Peek("X-RateLimit-Limit")))
if v := string(fctx.Response.Header.Peek("X-RateLimit-Remaining")); v == "" {
t.Errorf("The X-RateLimit-Remaining header is not set correctly - value is an empty string.")
}
if v := string(fctx.Response.Header.Peek("X-RateLimit-Reset")); !(v == "1" || v == "2") {
t.Errorf("The X-RateLimit-Reset header is not set correctly - value is out of bounds.")
}
}
// go test -v -run=^$ -bench=Benchmark_Limiter -benchmem -count=4
func Benchmark_Limiter(b *testing.B) {
app := fiber.New()
app.Use(New(Config{
Max: 100,
Expiration: 60 * time.Second,
}))
app.Get("/", func(c *fiber.Ctx) error {
return c.SendString("Hello, World!")
})
h := app.Handler()
fctx := &fasthttp.RequestCtx{}
fctx.Request.Header.SetMethod(fiber.MethodGet)
fctx.Request.SetRequestURI("/")
b.ResetTimer()
for n := 0; n < b.N; n++ {
h(fctx)
}
}
// go test -run Test_Sliding_Window -race -v
func Test_Sliding_Window(t *testing.T) {
t.Parallel()
app := fiber.New()
app.Use(New(Config{
Max: 10,
Expiration: 2 * time.Second,
Storage: memory.New(),
LimiterMiddleware: SlidingWindow{},
}))
app.Get("/", func(c *fiber.Ctx) error {
return c.SendString("Hello tester!")
})
singleRequest := func(shouldFail bool) {
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
if shouldFail {
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 429, resp.StatusCode)
} else {
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
}
}
for i := 0; i < 5; i++ {
singleRequest(false)
}
time.Sleep(2 * time.Second)
for i := 0; i < 5; i++ {
singleRequest(false)
}
time.Sleep(3 * time.Second)
for i := 0; i < 5; i++ {
singleRequest(false)
}
time.Sleep(4 * time.Second)
for i := 0; i < 9; i++ {
singleRequest(false)
}
}
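
Not part of this change: a sketch of an additional test exercising a custom KeyGenerator (here keyed on a hypothetical X-API-Key header), written in the same style and with the same imports as the suite above.

// go test -run Test_Limiter_Custom_KeyGenerator_Sketch -v
func Test_Limiter_Custom_KeyGenerator_Sketch(t *testing.T) {
	t.Parallel()
	app := fiber.New()
	app.Use(New(Config{
		Max:        1,
		Expiration: 2 * time.Second,
		// Rate limit per API key instead of per IP.
		KeyGenerator: func(c *fiber.Ctx) string {
			return c.Get("X-API-Key")
		},
	}))
	app.Get("/", func(c *fiber.Ctx) error {
		return c.SendString("Hello tester!")
	})

	// First request for "alice" passes, the second exceeds Max=1.
	req1 := httptest.NewRequest(fiber.MethodGet, "/", nil)
	req1.Header.Set("X-API-Key", "alice")
	resp, err := app.Test(req1)
	utils.AssertEqual(t, nil, err)
	utils.AssertEqual(t, 200, resp.StatusCode)

	req2 := httptest.NewRequest(fiber.MethodGet, "/", nil)
	req2.Header.Set("X-API-Key", "alice")
	resp, err = app.Test(req2)
	utils.AssertEqual(t, nil, err)
	utils.AssertEqual(t, 429, resp.StatusCode)

	// A different key gets its own counter.
	req3 := httptest.NewRequest(fiber.MethodGet, "/", nil)
	req3.Header.Set("X-API-Key", "bob")
	resp, err = app.Test(req3)
	utils.AssertEqual(t, nil, err)
	utils.AssertEqual(t, 200, resp.StatusCode)
}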

@@ -0,0 +1,92 @@
package limiter

import (
	"sync"
	"time"

	"github.com/gofiber/fiber/v2"
	"github.com/gofiber/fiber/v2/internal/memory"
)

// go:generate msgp
// msgp -file="manager.go" -o="manager_msgp.go" -tests=false -unexported
type item struct {
	currHits int
	prevHits int
	exp      uint64
}

//msgp:ignore manager
type manager struct {
	pool    sync.Pool
	memory  *memory.Storage
	storage fiber.Storage
}

func newManager(storage fiber.Storage) *manager {
	// Create new storage handler
	manager := &manager{
		pool: sync.Pool{
			New: func() interface{} {
				return new(item)
			},
		},
	}
	if storage != nil {
		// Use the provided storage if given
		manager.storage = storage
	} else {
		// Fall back to memory storage
		manager.memory = memory.New()
	}
	return manager
}

// acquire returns an *item from the sync.Pool
func (m *manager) acquire() *item {
	return m.pool.Get().(*item) //nolint:forcetypeassert // We store nothing else in the pool
}

// release resets an *item and returns it to the sync.Pool
func (m *manager) release(e *item) {
	e.prevHits = 0
	e.currHits = 0
	e.exp = 0
	m.pool.Put(e)
}

// get data from storage or memory
func (m *manager) get(key string) *item {
	var it *item
	if m.storage != nil {
		it = m.acquire()
		raw, err := m.storage.Get(key)
		if err != nil {
			return it
		}
		if raw != nil {
			if _, err := it.UnmarshalMsg(raw); err != nil {
				return it
			}
		}
		return it
	}
	if it, _ = m.memory.Get(key).(*item); it == nil { //nolint:errcheck // We store nothing else in the pool
		it = m.acquire()
		return it
	}
	return it
}

// set data to storage or memory
func (m *manager) set(key string, it *item, exp time.Duration) {
	if m.storage != nil {
		if raw, err := it.MarshalMsg(nil); err == nil {
			_ = m.storage.Set(key, raw, exp) //nolint:errcheck // TODO: Handle error here
		}
		// we can release data because it's serialized to database
		m.release(it)
	} else {
		m.memory.Set(key, it, exp)
	}
}
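
When cfg.Storage is set, the manager serializes each item with msgp and round-trips it through the fiber.Storage interface (Get/Set/Delete/Reset/Close in Fiber v2). A minimal, hypothetical in-memory implementation, shown only to illustrate what set and get above expect from cfg.Storage (in practice the bundled memory store or a gofiber/storage driver would be used):

// mapStorage is a hypothetical, non-expiring fiber.Storage; illustration only.
type mapStorage struct {
	mu   sync.Mutex
	data map[string][]byte
}

func newMapStorage() *mapStorage {
	return &mapStorage{data: make(map[string][]byte)}
}

func (s *mapStorage) Get(key string) ([]byte, error) {
	s.mu.Lock()
	defer s.mu.Unlock()
	return s.data[key], nil // nil means "no entry yet"
}

func (s *mapStorage) Set(key string, val []byte, _ time.Duration) error {
	s.mu.Lock()
	defer s.mu.Unlock()
	s.data[key] = val // expiration ignored in this sketch
	return nil
}

func (s *mapStorage) Delete(key string) error {
	s.mu.Lock()
	defer s.mu.Unlock()
	delete(s.data, key)
	return nil
}

func (s *mapStorage) Reset() error {
	s.mu.Lock()
	defer s.mu.Unlock()
	s.data = make(map[string][]byte)
	return nil
}

func (s *mapStorage) Close() error { return nil }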

@@ -0,0 +1,160 @@
package limiter
// Code generated by github.com/tinylib/msgp DO NOT EDIT.
import (
"github.com/tinylib/msgp/msgp"
)
// DecodeMsg implements msgp.Decodable
func (z *item) DecodeMsg(dc *msgp.Reader) (err error) {
var field []byte
_ = field
var zb0001 uint32
zb0001, err = dc.ReadMapHeader()
if err != nil {
err = msgp.WrapError(err)
return
}
for zb0001 > 0 {
zb0001--
field, err = dc.ReadMapKeyPtr()
if err != nil {
err = msgp.WrapError(err)
return
}
switch msgp.UnsafeString(field) {
case "currHits":
z.currHits, err = dc.ReadInt()
if err != nil {
err = msgp.WrapError(err, "currHits")
return
}
case "prevHits":
z.prevHits, err = dc.ReadInt()
if err != nil {
err = msgp.WrapError(err, "prevHits")
return
}
case "exp":
z.exp, err = dc.ReadUint64()
if err != nil {
err = msgp.WrapError(err, "exp")
return
}
default:
err = dc.Skip()
if err != nil {
err = msgp.WrapError(err)
return
}
}
}
return
}
// EncodeMsg implements msgp.Encodable
func (z item) EncodeMsg(en *msgp.Writer) (err error) {
// map header, size 3
// write "currHits"
err = en.Append(0x83, 0xa8, 0x63, 0x75, 0x72, 0x72, 0x48, 0x69, 0x74, 0x73)
if err != nil {
return
}
err = en.WriteInt(z.currHits)
if err != nil {
err = msgp.WrapError(err, "currHits")
return
}
// write "prevHits"
err = en.Append(0xa8, 0x70, 0x72, 0x65, 0x76, 0x48, 0x69, 0x74, 0x73)
if err != nil {
return
}
err = en.WriteInt(z.prevHits)
if err != nil {
err = msgp.WrapError(err, "prevHits")
return
}
// write "exp"
err = en.Append(0xa3, 0x65, 0x78, 0x70)
if err != nil {
return
}
err = en.WriteUint64(z.exp)
if err != nil {
err = msgp.WrapError(err, "exp")
return
}
return
}
// MarshalMsg implements msgp.Marshaler
func (z item) MarshalMsg(b []byte) (o []byte, err error) {
o = msgp.Require(b, z.Msgsize())
// map header, size 3
// string "currHits"
o = append(o, 0x83, 0xa8, 0x63, 0x75, 0x72, 0x72, 0x48, 0x69, 0x74, 0x73)
o = msgp.AppendInt(o, z.currHits)
// string "prevHits"
o = append(o, 0xa8, 0x70, 0x72, 0x65, 0x76, 0x48, 0x69, 0x74, 0x73)
o = msgp.AppendInt(o, z.prevHits)
// string "exp"
o = append(o, 0xa3, 0x65, 0x78, 0x70)
o = msgp.AppendUint64(o, z.exp)
return
}
// UnmarshalMsg implements msgp.Unmarshaler
func (z *item) UnmarshalMsg(bts []byte) (o []byte, err error) {
var field []byte
_ = field
var zb0001 uint32
zb0001, bts, err = msgp.ReadMapHeaderBytes(bts)
if err != nil {
err = msgp.WrapError(err)
return
}
for zb0001 > 0 {
zb0001--
field, bts, err = msgp.ReadMapKeyZC(bts)
if err != nil {
err = msgp.WrapError(err)
return
}
switch msgp.UnsafeString(field) {
case "currHits":
z.currHits, bts, err = msgp.ReadIntBytes(bts)
if err != nil {
err = msgp.WrapError(err, "currHits")
return
}
case "prevHits":
z.prevHits, bts, err = msgp.ReadIntBytes(bts)
if err != nil {
err = msgp.WrapError(err, "prevHits")
return
}
case "exp":
z.exp, bts, err = msgp.ReadUint64Bytes(bts)
if err != nil {
err = msgp.WrapError(err, "exp")
return
}
default:
bts, err = msgp.Skip(bts)
if err != nil {
err = msgp.WrapError(err)
return
}
}
}
o = bts
return
}
// Msgsize returns an upper bound estimate of the number of bytes occupied by the serialized message
func (z item) Msgsize() (s int) {
s = 1 + 9 + msgp.IntSize + 9 + msgp.IntSize + 4 + msgp.Uint64Size
return
}
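
For reference, a small round-trip sketch (illustration only, not part of this diff) of the generated MarshalMsg/UnmarshalMsg pair that manager.set and manager.get rely on:

func exampleItemRoundTrip() {
	src := item{currHits: 3, prevHits: 8, exp: 1715000000}

	// Serialize to MessagePack bytes (as manager.set does before storage.Set).
	raw, err := src.MarshalMsg(nil)
	if err != nil {
		panic(err)
	}

	// Deserialize into a fresh item (as manager.get does after storage.Get).
	var dst item
	if _, err := dst.UnmarshalMsg(raw); err != nil {
		panic(err)
	}
	// dst now equals src: {currHits: 3, prevHits: 8, exp: 1715000000}
}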