1
0
Fork 0

Adding upstream version 2.52.6.

Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
Daniel Baumann 2025-05-17 06:50:16 +02:00
parent a960158181
commit 6d002e9543
Signed by: daniel
GPG key ID: FBB4F0E80A80222F
441 changed files with 95392 additions and 0 deletions

View file

@ -0,0 +1,171 @@
package adaptor
import (
"io"
"net"
"net/http"
"reflect"
"unsafe"
"github.com/valyala/fasthttp"
"github.com/valyala/fasthttp/fasthttpadaptor"
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/utils"
)
// HTTPHandlerFunc wraps net/http handler func to fiber handler
func HTTPHandlerFunc(h http.HandlerFunc) fiber.Handler {
return HTTPHandler(h)
}
// HTTPHandler wraps net/http handler to fiber handler
func HTTPHandler(h http.Handler) fiber.Handler {
return func(c *fiber.Ctx) error {
handler := fasthttpadaptor.NewFastHTTPHandler(h)
handler(c.Context())
return nil
}
}
// ConvertRequest converts a fiber.Ctx to a http.Request.
// forServer should be set to true when the http.Request is going to be passed to a http.Handler.
func ConvertRequest(c *fiber.Ctx, forServer bool) (*http.Request, error) {
var req http.Request
if err := fasthttpadaptor.ConvertRequest(c.Context(), &req, forServer); err != nil {
return nil, err //nolint:wrapcheck // This must not be wrapped
}
return &req, nil
}
// CopyContextToFiberContext copies the values of context.Context to a fasthttp.RequestCtx
func CopyContextToFiberContext(context interface{}, requestContext *fasthttp.RequestCtx) {
contextValues := reflect.ValueOf(context).Elem()
contextKeys := reflect.TypeOf(context).Elem()
if contextKeys.Kind() == reflect.Struct {
var lastKey interface{}
for i := 0; i < contextValues.NumField(); i++ {
reflectValue := contextValues.Field(i)
/* #nosec */
reflectValue = reflect.NewAt(reflectValue.Type(), unsafe.Pointer(reflectValue.UnsafeAddr())).Elem()
reflectField := contextKeys.Field(i)
if reflectField.Name == "noCopy" {
break
} else if reflectField.Name == "Context" {
CopyContextToFiberContext(reflectValue.Interface(), requestContext)
} else if reflectField.Name == "key" {
lastKey = reflectValue.Interface()
} else if lastKey != nil && reflectField.Name == "val" {
requestContext.SetUserValue(lastKey, reflectValue.Interface())
} else {
lastKey = nil
}
}
}
}
// HTTPMiddleware wraps net/http middleware to fiber middleware
func HTTPMiddleware(mw func(http.Handler) http.Handler) fiber.Handler {
return func(c *fiber.Ctx) error {
var next bool
nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
next = true
// Convert again in case request may modify by middleware
c.Request().Header.SetMethod(r.Method)
c.Request().SetRequestURI(r.RequestURI)
c.Request().SetHost(r.Host)
c.Request().Header.SetHost(r.Host)
for key, val := range r.Header {
for _, v := range val {
c.Request().Header.Set(key, v)
}
}
CopyContextToFiberContext(r.Context(), c.Context())
})
if err := HTTPHandler(mw(nextHandler))(c); err != nil {
return err
}
if next {
return c.Next()
}
return nil
}
}
// FiberHandler wraps fiber handler to net/http handler
func FiberHandler(h fiber.Handler) http.Handler {
return FiberHandlerFunc(h)
}
// FiberHandlerFunc wraps fiber handler to net/http handler func
func FiberHandlerFunc(h fiber.Handler) http.HandlerFunc {
return handlerFunc(fiber.New(), h)
}
// FiberApp wraps fiber app to net/http handler func
func FiberApp(app *fiber.App) http.HandlerFunc {
return handlerFunc(app)
}
func handlerFunc(app *fiber.App, h ...fiber.Handler) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
// New fasthttp request
req := fasthttp.AcquireRequest()
defer fasthttp.ReleaseRequest(req)
// Convert net/http -> fasthttp request
if r.Body != nil {
n, err := io.Copy(req.BodyWriter(), r.Body)
req.Header.SetContentLength(int(n))
if err != nil {
http.Error(w, utils.StatusMessage(fiber.StatusInternalServerError), fiber.StatusInternalServerError)
return
}
}
req.Header.SetMethod(r.Method)
req.SetRequestURI(r.RequestURI)
req.SetHost(r.Host)
req.Header.SetHost(r.Host)
for key, val := range r.Header {
for _, v := range val {
req.Header.Set(key, v)
}
}
if _, _, err := net.SplitHostPort(r.RemoteAddr); err != nil && err.(*net.AddrError).Err == "missing port in address" { //nolint:errorlint, forcetypeassert // overlinting
r.RemoteAddr = net.JoinHostPort(r.RemoteAddr, "80")
}
remoteAddr, err := net.ResolveTCPAddr("tcp", r.RemoteAddr)
if err != nil {
http.Error(w, utils.StatusMessage(fiber.StatusInternalServerError), fiber.StatusInternalServerError)
return
}
// New fasthttp Ctx
var fctx fasthttp.RequestCtx
fctx.Init(req, remoteAddr, nil)
if len(h) > 0 {
// New fiber Ctx
ctx := app.AcquireCtx(&fctx)
defer app.ReleaseCtx(ctx)
// Execute fiber Ctx
err := h[0](ctx)
if err != nil {
_ = app.Config().ErrorHandler(ctx, err) //nolint:errcheck // not needed
}
} else {
// Execute fasthttp Ctx though app.Handler
app.Handler()(&fctx)
}
// Convert fasthttp Ctx > net/http
fctx.Response.Header.VisitAll(func(k, v []byte) {
w.Header().Add(string(k), string(v))
})
w.WriteHeader(fctx.Response.StatusCode())
_, _ = w.Write(fctx.Response.Body()) //nolint:errcheck // not needed
}
}

View file

@ -0,0 +1,492 @@
//nolint:bodyclose, contextcheck, revive // Much easier to just ignore memory leaks in tests
package adaptor
import (
"bytes"
"context"
"fmt"
"io"
"net"
"net/http"
"net/http/httptest"
"net/url"
"testing"
"github.com/valyala/fasthttp"
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/utils"
)
func Test_HTTPHandler(t *testing.T) {
expectedMethod := fiber.MethodPost
expectedProto := "HTTP/1.1"
expectedProtoMajor := 1
expectedProtoMinor := 1
expectedRequestURI := "/foo/bar?baz=123"
expectedBody := "body 123 foo bar baz"
expectedContentLength := len(expectedBody)
expectedHost := "foobar.com"
expectedRemoteAddr := "1.2.3.4:6789"
expectedHeader := map[string]string{
"Foo-Bar": "baz",
"Abc": "defg",
"XXX-Remote-Addr": "123.43.4543.345",
}
expectedURL, err := url.ParseRequestURI(expectedRequestURI)
utils.AssertEqual(t, nil, err)
type contextKeyType string
expectedContextKey := contextKeyType("contextKey")
expectedContextValue := "contextValue"
callsCount := 0
nethttpH := func(w http.ResponseWriter, r *http.Request) {
callsCount++
utils.AssertEqual(t, expectedMethod, r.Method, "Method")
utils.AssertEqual(t, expectedProto, r.Proto, "Proto")
utils.AssertEqual(t, expectedProtoMajor, r.ProtoMajor, "ProtoMajor")
utils.AssertEqual(t, expectedProtoMinor, r.ProtoMinor, "ProtoMinor")
utils.AssertEqual(t, expectedRequestURI, r.RequestURI, "RequestURI")
utils.AssertEqual(t, expectedContentLength, int(r.ContentLength), "ContentLength")
utils.AssertEqual(t, 0, len(r.TransferEncoding), "TransferEncoding")
utils.AssertEqual(t, expectedHost, r.Host, "Host")
utils.AssertEqual(t, expectedRemoteAddr, r.RemoteAddr, "RemoteAddr")
body, err := io.ReadAll(r.Body)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, expectedBody, string(body), "Body")
utils.AssertEqual(t, expectedURL, r.URL, "URL")
utils.AssertEqual(t, expectedContextValue, r.Context().Value(expectedContextKey), "Context")
for k, expectedV := range expectedHeader {
v := r.Header.Get(k)
utils.AssertEqual(t, expectedV, v, "Header")
}
w.Header().Set("Header1", "value1")
w.Header().Set("Header2", "value2")
w.WriteHeader(http.StatusBadRequest)
fmt.Fprintf(w, "request body is %q", body)
}
fiberH := HTTPHandlerFunc(http.HandlerFunc(nethttpH))
fiberH = setFiberContextValueMiddleware(fiberH, expectedContextKey, expectedContextValue)
var fctx fasthttp.RequestCtx
var req fasthttp.Request
req.Header.SetMethod(expectedMethod)
req.SetRequestURI(expectedRequestURI)
req.Header.SetHost(expectedHost)
req.BodyWriter().Write([]byte(expectedBody)) //nolint:errcheck, gosec // not needed
for k, v := range expectedHeader {
req.Header.Set(k, v)
}
remoteAddr, err := net.ResolveTCPAddr("tcp", expectedRemoteAddr)
utils.AssertEqual(t, nil, err)
fctx.Init(&req, remoteAddr, nil)
app := fiber.New()
ctx := app.AcquireCtx(&fctx)
defer app.ReleaseCtx(ctx)
err = fiberH(ctx)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 1, callsCount, "callsCount")
resp := &fctx.Response
utils.AssertEqual(t, http.StatusBadRequest, resp.StatusCode(), "StatusCode")
utils.AssertEqual(t, "value1", string(resp.Header.Peek("Header1")), "Header1")
utils.AssertEqual(t, "value2", string(resp.Header.Peek("Header2")), "Header2")
expectedResponseBody := fmt.Sprintf("request body is %q", expectedBody)
utils.AssertEqual(t, expectedResponseBody, string(resp.Body()), "Body")
}
type contextKey string
func (c contextKey) String() string {
return "test-" + string(c)
}
var (
TestContextKey = contextKey("TestContextKey")
TestContextSecondKey = contextKey("TestContextSecondKey")
)
func Test_HTTPMiddleware(t *testing.T) {
const expectedHost = "foobar.com"
tests := []struct {
name string
url string
method string
statusCode int
}{
{
name: "Should return 200",
url: "/",
method: "POST",
statusCode: 200,
},
{
name: "Should return 405",
url: "/",
method: "GET",
statusCode: 405,
},
{
name: "Should return 400",
url: "/unknown",
method: "POST",
statusCode: 404,
},
}
nethttpMW := func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
w.WriteHeader(http.StatusMethodNotAllowed)
return
}
r = r.WithContext(context.WithValue(r.Context(), TestContextKey, "okay"))
r = r.WithContext(context.WithValue(r.Context(), TestContextSecondKey, "not_okay"))
r = r.WithContext(context.WithValue(r.Context(), TestContextSecondKey, "okay"))
next.ServeHTTP(w, r)
})
}
app := fiber.New()
app.Use(HTTPMiddleware(nethttpMW))
app.Post("/", func(c *fiber.Ctx) error {
value := c.Context().Value(TestContextKey)
val, ok := value.(string)
if !ok {
t.Error("unexpected error on type-assertion")
}
if value != nil {
c.Set("context_okay", val)
}
value = c.Context().Value(TestContextSecondKey)
if value != nil {
val, ok := value.(string)
if !ok {
t.Error("unexpected error on type-assertion")
}
c.Set("context_second_okay", val)
}
return c.SendStatus(fiber.StatusOK)
})
for _, tt := range tests {
req, err := http.NewRequestWithContext(context.Background(), tt.method, tt.url, nil)
req.Host = expectedHost
utils.AssertEqual(t, nil, err)
resp, err := app.Test(req)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, tt.statusCode, resp.StatusCode, "StatusCode")
}
req, err := http.NewRequestWithContext(context.Background(), fiber.MethodPost, "/", nil)
req.Host = expectedHost
utils.AssertEqual(t, nil, err)
resp, err := app.Test(req)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, resp.Header.Get("context_okay"), "okay")
utils.AssertEqual(t, resp.Header.Get("context_second_okay"), "okay")
}
func Test_FiberHandler(t *testing.T) {
testFiberToHandlerFunc(t, false)
}
func Test_FiberApp(t *testing.T) {
testFiberToHandlerFunc(t, false, fiber.New())
}
func Test_FiberHandlerDefaultPort(t *testing.T) {
testFiberToHandlerFunc(t, true)
}
func Test_FiberAppDefaultPort(t *testing.T) {
testFiberToHandlerFunc(t, true, fiber.New())
}
func testFiberToHandlerFunc(t *testing.T, checkDefaultPort bool, app ...*fiber.App) {
t.Helper()
expectedMethod := fiber.MethodPost
expectedRequestURI := "/foo/bar?baz=123"
expectedBody := "body 123 foo bar baz"
expectedContentLength := len(expectedBody)
expectedHost := "foobar.com"
expectedRemoteAddr := "1.2.3.4:6789"
if checkDefaultPort {
expectedRemoteAddr = "1.2.3.4:80"
}
expectedHeader := map[string]string{
"Foo-Bar": "baz",
"Abc": "defg",
"XXX-Remote-Addr": "123.43.4543.345",
}
expectedURL, err := url.ParseRequestURI(expectedRequestURI)
utils.AssertEqual(t, nil, err)
callsCount := 0
fiberH := func(c *fiber.Ctx) error {
callsCount++
utils.AssertEqual(t, expectedMethod, c.Method(), "Method")
utils.AssertEqual(t, expectedRequestURI, string(c.Context().RequestURI()), "RequestURI")
utils.AssertEqual(t, expectedContentLength, c.Context().Request.Header.ContentLength(), "ContentLength")
utils.AssertEqual(t, expectedHost, c.Hostname(), "Host")
utils.AssertEqual(t, expectedHost, string(c.Request().Header.Host()), "Host")
utils.AssertEqual(t, "http://"+expectedHost, c.BaseURL(), "BaseURL")
utils.AssertEqual(t, expectedRemoteAddr, c.Context().RemoteAddr().String(), "RemoteAddr")
body := string(c.Body())
utils.AssertEqual(t, expectedBody, body, "Body")
utils.AssertEqual(t, expectedURL.String(), c.OriginalURL(), "URL")
for k, expectedV := range expectedHeader {
v := c.Get(k)
utils.AssertEqual(t, expectedV, v, "Header")
}
c.Set("Header1", "value1")
c.Set("Header2", "value2")
c.Status(fiber.StatusBadRequest)
_, err := c.Write([]byte(fmt.Sprintf("request body is %q", body)))
return err
}
var handlerFunc http.HandlerFunc
if len(app) > 0 {
app[0].Post("/foo/bar", fiberH)
handlerFunc = FiberApp(app[0])
} else {
handlerFunc = FiberHandlerFunc(fiberH)
}
var r http.Request
r.Method = expectedMethod
r.Body = &netHTTPBody{[]byte(expectedBody)}
r.RequestURI = expectedRequestURI
r.ContentLength = int64(expectedContentLength)
r.Host = expectedHost
r.RemoteAddr = expectedRemoteAddr
if checkDefaultPort {
r.RemoteAddr = "1.2.3.4"
}
hdr := make(http.Header)
for k, v := range expectedHeader {
hdr.Set(k, v)
}
r.Header = hdr
var w netHTTPResponseWriter
handlerFunc.ServeHTTP(&w, &r)
utils.AssertEqual(t, http.StatusBadRequest, w.StatusCode(), "StatusCode")
utils.AssertEqual(t, "value1", w.Header().Get("Header1"), "Header1")
utils.AssertEqual(t, "value2", w.Header().Get("Header2"), "Header2")
expectedResponseBody := fmt.Sprintf("request body is %q", expectedBody)
utils.AssertEqual(t, expectedResponseBody, string(w.body), "Body")
}
func setFiberContextValueMiddleware(next fiber.Handler, key, value interface{}) fiber.Handler {
return func(c *fiber.Ctx) error {
c.Locals(key, value)
return next(c)
}
}
func Test_FiberHandler_RequestNilBody(t *testing.T) {
expectedMethod := fiber.MethodGet
expectedRequestURI := "/foo/bar"
expectedContentLength := 0
callsCount := 0
fiberH := func(c *fiber.Ctx) error {
callsCount++
utils.AssertEqual(t, expectedMethod, c.Method(), "Method")
utils.AssertEqual(t, expectedRequestURI, string(c.Context().RequestURI()), "RequestURI")
utils.AssertEqual(t, expectedContentLength, c.Context().Request.Header.ContentLength(), "ContentLength")
_, err := c.Write([]byte("request body is nil"))
return err
}
nethttpH := FiberHandler(fiberH)
var r http.Request
r.Method = expectedMethod
r.RequestURI = expectedRequestURI
var w netHTTPResponseWriter
nethttpH.ServeHTTP(&w, &r)
expectedResponseBody := "request body is nil"
utils.AssertEqual(t, expectedResponseBody, string(w.body), "Body")
}
type netHTTPBody struct {
b []byte
}
func (r *netHTTPBody) Read(p []byte) (int, error) {
if len(r.b) == 0 {
return 0, io.EOF
}
n := copy(p, r.b)
r.b = r.b[n:]
return n, nil
}
func (r *netHTTPBody) Close() error {
r.b = r.b[:0]
return nil
}
type netHTTPResponseWriter struct {
statusCode int
h http.Header
body []byte
}
func (w *netHTTPResponseWriter) StatusCode() int {
if w.statusCode == 0 {
return http.StatusOK
}
return w.statusCode
}
func (w *netHTTPResponseWriter) Header() http.Header {
if w.h == nil {
w.h = make(http.Header)
}
return w.h
}
func (w *netHTTPResponseWriter) WriteHeader(statusCode int) {
w.statusCode = statusCode
}
func (w *netHTTPResponseWriter) Write(p []byte) (int, error) {
w.body = append(w.body, p...)
return len(p), nil
}
func Test_ConvertRequest(t *testing.T) {
t.Parallel()
app := fiber.New()
app.Get("/test", func(c *fiber.Ctx) error {
httpReq, err := ConvertRequest(c, false)
if err != nil {
return err
}
return c.SendString("Request URL: " + httpReq.URL.String())
})
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/test?hello=world&another=test", http.NoBody))
utils.AssertEqual(t, nil, err, "app.Test(req)")
utils.AssertEqual(t, http.StatusOK, resp.StatusCode, "Status code")
body, err := io.ReadAll(resp.Body)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, "Request URL: /test?hello=world&another=test", string(body))
}
// Benchmark for FiberHandlerFunc
func Benchmark_FiberHandlerFunc_1MB(b *testing.B) {
fiberH := func(c *fiber.Ctx) error {
return c.SendStatus(fiber.StatusOK)
}
handlerFunc := FiberHandlerFunc(fiberH)
// Create body content
bodyContent := make([]byte, 1*1024*1024)
bodyBuffer := bytes.NewBuffer(bodyContent)
r := http.Request{
Method: http.MethodPost,
Body: http.NoBody,
}
// Replace the empty Body with our buffer
r.Body = io.NopCloser(bodyBuffer)
defer r.Body.Close() //nolint:errcheck // not needed
// Create recorder
w := httptest.NewRecorder()
b.ResetTimer()
for i := 0; i < b.N; i++ {
handlerFunc.ServeHTTP(w, &r)
}
}
func Benchmark_FiberHandlerFunc_10MB(b *testing.B) {
fiberH := func(c *fiber.Ctx) error {
return c.SendStatus(fiber.StatusOK)
}
handlerFunc := FiberHandlerFunc(fiberH)
// Create body content
bodyContent := make([]byte, 10*1024*1024)
bodyBuffer := bytes.NewBuffer(bodyContent)
r := http.Request{
Method: http.MethodPost,
Body: http.NoBody,
}
// Replace the empty Body with our buffer
r.Body = io.NopCloser(bodyBuffer)
defer r.Body.Close() //nolint:errcheck // not needed
// Create recorder
w := httptest.NewRecorder()
b.ResetTimer()
for i := 0; i < b.N; i++ {
handlerFunc.ServeHTTP(w, &r)
}
}
func Benchmark_FiberHandlerFunc_50MB(b *testing.B) {
fiberH := func(c *fiber.Ctx) error {
return c.SendStatus(fiber.StatusOK)
}
handlerFunc := FiberHandlerFunc(fiberH)
// Create body content
bodyContent := make([]byte, 50*1024*1024)
bodyBuffer := bytes.NewBuffer(bodyContent)
r := http.Request{
Method: http.MethodPost,
Body: http.NoBody,
}
// Replace the empty Body with our buffer
r.Body = io.NopCloser(bodyBuffer)
defer r.Body.Close() //nolint:errcheck // not needed
// Create recorder
w := httptest.NewRecorder()
b.ResetTimer()
for i := 0; i < b.N; i++ {
handlerFunc.ServeHTTP(w, &r)
}
}

View file

@ -0,0 +1,60 @@
package basicauth
import (
"encoding/base64"
"strings"
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/utils"
)
// New creates a new middleware handler
func New(config Config) fiber.Handler {
// Set default config
cfg := configDefault(config)
// Return new handler
return func(c *fiber.Ctx) error {
// Don't execute middleware if Next returns true
if cfg.Next != nil && cfg.Next(c) {
return c.Next()
}
// Get authorization header
auth := c.Get(fiber.HeaderAuthorization)
// Check if the header contains content besides "basic".
if len(auth) <= 6 || !utils.EqualFold(auth[:6], "basic ") {
return cfg.Unauthorized(c)
}
// Decode the header contents
raw, err := base64.StdEncoding.DecodeString(auth[6:])
if err != nil {
return cfg.Unauthorized(c)
}
// Get the credentials
creds := utils.UnsafeString(raw)
// Check if the credentials are in the correct form
// which is "username:password".
index := strings.Index(creds, ":")
if index == -1 {
return cfg.Unauthorized(c)
}
// Get the username and password
username := creds[:index]
password := creds[index+1:]
if cfg.Authorizer(username, password) {
c.Locals(cfg.ContextUsername, username)
c.Locals(cfg.ContextPassword, password)
return c.Next()
}
// Authentication failed
return cfg.Unauthorized(c)
}
}

View file

@ -0,0 +1,154 @@
package basicauth
import (
"encoding/base64"
"fmt"
"io"
"net/http/httptest"
"testing"
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/utils"
"github.com/valyala/fasthttp"
)
// go test -run Test_BasicAuth_Next
func Test_BasicAuth_Next(t *testing.T) {
t.Parallel()
app := fiber.New()
app.Use(New(Config{
Next: func(_ *fiber.Ctx) bool {
return true
},
}))
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusNotFound, resp.StatusCode)
}
func Test_Middleware_BasicAuth(t *testing.T) {
t.Parallel()
app := fiber.New()
app.Use(New(Config{
Users: map[string]string{
"john": "doe",
"admin": "123456",
},
}))
//nolint:forcetypeassert,errcheck // TODO: Do not force-type assert
app.Get("/testauth", func(c *fiber.Ctx) error {
username := c.Locals("username").(string)
password := c.Locals("password").(string)
return c.SendString(username + password)
})
tests := []struct {
url string
statusCode int
username string
password string
}{
{
url: "/testauth",
statusCode: 200,
username: "john",
password: "doe",
},
{
url: "/testauth",
statusCode: 200,
username: "admin",
password: "123456",
},
{
url: "/testauth",
statusCode: 401,
username: "ee",
password: "123456",
},
}
for _, tt := range tests {
// Base64 encode credentials for http auth header
creds := base64.StdEncoding.EncodeToString([]byte(fmt.Sprintf("%s:%s", tt.username, tt.password)))
req := httptest.NewRequest(fiber.MethodGet, "/testauth", nil)
req.Header.Add("Authorization", "Basic "+creds)
resp, err := app.Test(req)
utils.AssertEqual(t, nil, err)
body, err := io.ReadAll(resp.Body)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, tt.statusCode, resp.StatusCode)
if tt.statusCode == 200 {
utils.AssertEqual(t, fmt.Sprintf("%s%s", tt.username, tt.password), string(body))
}
}
}
// go test -v -run=^$ -bench=Benchmark_Middleware_BasicAuth -benchmem -count=4
func Benchmark_Middleware_BasicAuth(b *testing.B) {
app := fiber.New()
app.Use(New(Config{
Users: map[string]string{
"john": "doe",
},
}))
app.Get("/", func(c *fiber.Ctx) error {
return c.SendStatus(fiber.StatusTeapot)
})
h := app.Handler()
fctx := &fasthttp.RequestCtx{}
fctx.Request.Header.SetMethod(fiber.MethodGet)
fctx.Request.SetRequestURI("/")
fctx.Request.Header.Set(fiber.HeaderAuthorization, "basic am9objpkb2U=") // john:doe
b.ReportAllocs()
b.ResetTimer()
for n := 0; n < b.N; n++ {
h(fctx)
}
utils.AssertEqual(b, fiber.StatusTeapot, fctx.Response.Header.StatusCode())
}
// go test -v -run=^$ -bench=Benchmark_Middleware_BasicAuth -benchmem -count=4
func Benchmark_Middleware_BasicAuth_Upper(b *testing.B) {
app := fiber.New()
app.Use(New(Config{
Users: map[string]string{
"john": "doe",
},
}))
app.Get("/", func(c *fiber.Ctx) error {
return c.SendStatus(fiber.StatusTeapot)
})
h := app.Handler()
fctx := &fasthttp.RequestCtx{}
fctx.Request.Header.SetMethod(fiber.MethodGet)
fctx.Request.SetRequestURI("/")
fctx.Request.Header.Set(fiber.HeaderAuthorization, "Basic am9objpkb2U=") // john:doe
b.ReportAllocs()
b.ResetTimer()
for n := 0; n < b.N; n++ {
h(fctx)
}
utils.AssertEqual(b, fiber.StatusTeapot, fctx.Response.Header.StatusCode())
}

View file

@ -0,0 +1,105 @@
package basicauth
import (
"crypto/subtle"
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/utils"
)
// Config defines the config for middleware.
type Config struct {
// Next defines a function to skip this middleware when returned true.
//
// Optional. Default: nil
Next func(c *fiber.Ctx) bool
// Users defines the allowed credentials
//
// Required. Default: map[string]string{}
Users map[string]string
// Realm is a string to define realm attribute of BasicAuth.
// the realm identifies the system to authenticate against
// and can be used by clients to save credentials
//
// Optional. Default: "Restricted".
Realm string
// Authorizer defines a function you can pass
// to check the credentials however you want.
// It will be called with a username and password
// and is expected to return true or false to indicate
// that the credentials were approved or not.
//
// Optional. Default: nil.
Authorizer func(string, string) bool
// Unauthorized defines the response body for unauthorized responses.
// By default it will return with a 401 Unauthorized and the correct WWW-Auth header
//
// Optional. Default: nil
Unauthorized fiber.Handler
// ContextUser is the key to store the username in Locals
//
// Optional. Default: "username"
ContextUsername interface{}
// ContextPass is the key to store the password in Locals
//
// Optional. Default: "password"
ContextPassword interface{}
}
// ConfigDefault is the default config
var ConfigDefault = Config{
Next: nil,
Users: map[string]string{},
Realm: "Restricted",
Authorizer: nil,
Unauthorized: nil,
ContextUsername: "username",
ContextPassword: "password",
}
// Helper function to set default values
func configDefault(config ...Config) Config {
// Return default config if nothing provided
if len(config) < 1 {
return ConfigDefault
}
// Override default config
cfg := config[0]
// Set default values
if cfg.Next == nil {
cfg.Next = ConfigDefault.Next
}
if cfg.Users == nil {
cfg.Users = ConfigDefault.Users
}
if cfg.Realm == "" {
cfg.Realm = ConfigDefault.Realm
}
if cfg.Authorizer == nil {
cfg.Authorizer = func(user, pass string) bool {
userPwd, exist := cfg.Users[user]
return exist && subtle.ConstantTimeCompare(utils.UnsafeBytes(userPwd), utils.UnsafeBytes(pass)) == 1
}
}
if cfg.Unauthorized == nil {
cfg.Unauthorized = func(c *fiber.Ctx) error {
c.Set(fiber.HeaderWWWAuthenticate, "basic realm="+cfg.Realm)
return c.SendStatus(fiber.StatusUnauthorized)
}
}
if cfg.ContextUsername == nil {
cfg.ContextUsername = ConfigDefault.ContextUsername
}
if cfg.ContextPassword == nil {
cfg.ContextPassword = ConfigDefault.ContextPassword
}
return cfg
}

252
middleware/cache/cache.go vendored Normal file
View file

@ -0,0 +1,252 @@
// Special thanks to @codemicro for moving this to fiber core
// Original middleware: github.com/codemicro/fiber-cache
package cache
import (
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/utils"
)
// timestampUpdatePeriod is the period which is used to check the cache expiration.
// It should not be too long to provide more or less acceptable expiration error, and in the same
// time it should not be too short to avoid overwhelming of the system
const timestampUpdatePeriod = 300 * time.Millisecond
// cache status
// unreachable: when cache is bypass, or invalid
// hit: cache is served
// miss: do not have cache record
const (
cacheUnreachable = "unreachable"
cacheHit = "hit"
cacheMiss = "miss"
)
// directives
const (
noCache = "no-cache"
noStore = "no-store"
)
var ignoreHeaders = map[string]interface{}{
"Connection": nil,
"Keep-Alive": nil,
"Proxy-Authenticate": nil,
"Proxy-Authorization": nil,
"TE": nil,
"Trailers": nil,
"Transfer-Encoding": nil,
"Upgrade": nil,
"Content-Type": nil, // already stored explicitly by the cache manager
"Content-Encoding": nil, // already stored explicitly by the cache manager
}
// New creates a new middleware handler
func New(config ...Config) fiber.Handler {
// Set default config
cfg := configDefault(config...)
// Nothing to cache
if int(cfg.Expiration.Seconds()) < 0 {
return func(c *fiber.Ctx) error {
return c.Next()
}
}
var (
// Cache settings
mux = &sync.RWMutex{}
timestamp = uint64(time.Now().Unix())
)
// Create manager to simplify storage operations ( see manager.go )
manager := newManager(cfg.Storage)
// Create indexed heap for tracking expirations ( see heap.go )
heap := &indexedHeap{}
// count stored bytes (sizes of response bodies)
var storedBytes uint
// Update timestamp in the configured interval
go func() {
for {
atomic.StoreUint64(&timestamp, uint64(time.Now().Unix()))
time.Sleep(timestampUpdatePeriod)
}
}()
// Delete key from both manager and storage
deleteKey := func(dkey string) {
manager.del(dkey)
// External storage saves body data with different key
if cfg.Storage != nil {
manager.del(dkey + "_body")
}
}
// Return new handler
return func(c *fiber.Ctx) error {
// Refrain from caching
if hasRequestDirective(c, noStore) {
return c.Next()
}
// Only cache selected methods
var isExists bool
for _, method := range cfg.Methods {
if c.Method() == method {
isExists = true
}
}
if !isExists {
c.Set(cfg.CacheHeader, cacheUnreachable)
return c.Next()
}
// Get key from request
// TODO(allocation optimization): try to minimize the allocation from 2 to 1
key := cfg.KeyGenerator(c) + "_" + c.Method()
// Get entry from pool
e := manager.get(key)
// Lock entry
mux.Lock()
// Get timestamp
ts := atomic.LoadUint64(&timestamp)
// Check if entry is expired
if e.exp != 0 && ts >= e.exp {
deleteKey(key)
if cfg.MaxBytes > 0 {
_, size := heap.remove(e.heapidx)
storedBytes -= size
}
} else if e.exp != 0 && !hasRequestDirective(c, noCache) {
// Separate body value to avoid msgp serialization
// We can store raw bytes with Storage 👍
if cfg.Storage != nil {
e.body = manager.getRaw(key + "_body")
}
// Set response headers from cache
c.Response().SetBodyRaw(e.body)
c.Response().SetStatusCode(e.status)
c.Response().Header.SetContentTypeBytes(e.ctype)
if len(e.cencoding) > 0 {
c.Response().Header.SetBytesV(fiber.HeaderContentEncoding, e.cencoding)
}
for k, v := range e.headers {
c.Response().Header.SetBytesV(k, v)
}
// Set Cache-Control header if enabled
if cfg.CacheControl {
maxAge := strconv.FormatUint(e.exp-ts, 10)
c.Set(fiber.HeaderCacheControl, "public, max-age="+maxAge)
}
c.Set(cfg.CacheHeader, cacheHit)
mux.Unlock()
// Return response
return nil
}
// make sure we're not blocking concurrent requests - do unlock
mux.Unlock()
// Continue stack, return err to Fiber if exist
if err := c.Next(); err != nil {
return err
}
// lock entry back and unlock on finish
mux.Lock()
defer mux.Unlock()
// Don't cache response if Next returns true
if cfg.Next != nil && cfg.Next(c) {
c.Set(cfg.CacheHeader, cacheUnreachable)
return nil
}
// Don't try to cache if body won't fit into cache
bodySize := uint(len(c.Response().Body()))
if cfg.MaxBytes > 0 && bodySize > cfg.MaxBytes {
c.Set(cfg.CacheHeader, cacheUnreachable)
return nil
}
// Remove oldest to make room for new
if cfg.MaxBytes > 0 {
for storedBytes+bodySize > cfg.MaxBytes {
key, size := heap.removeFirst()
deleteKey(key)
storedBytes -= size
}
}
// Cache response
e.body = utils.CopyBytes(c.Response().Body())
e.status = c.Response().StatusCode()
e.ctype = utils.CopyBytes(c.Response().Header.ContentType())
e.cencoding = utils.CopyBytes(c.Response().Header.Peek(fiber.HeaderContentEncoding))
// Store all response headers
// (more: https://datatracker.ietf.org/doc/html/rfc2616#section-13.5.1)
if cfg.StoreResponseHeaders {
e.headers = make(map[string][]byte)
c.Response().Header.VisitAll(
func(key, value []byte) {
// create real copy
keyS := string(key)
if _, ok := ignoreHeaders[keyS]; !ok {
e.headers[keyS] = utils.CopyBytes(value)
}
},
)
}
// default cache expiration
expiration := cfg.Expiration
// Calculate expiration by response header or other setting
if cfg.ExpirationGenerator != nil {
expiration = cfg.ExpirationGenerator(c, &cfg)
}
e.exp = ts + uint64(expiration.Seconds())
// Store entry in heap
if cfg.MaxBytes > 0 {
e.heapidx = heap.put(key, e.exp, bodySize)
storedBytes += bodySize
}
// For external Storage we store raw body separated
if cfg.Storage != nil {
manager.setRaw(key+"_body", e.body, expiration)
// avoid body msgp encoding
e.body = nil
manager.set(key, e, expiration)
manager.release(e)
} else {
// Store entry in memory
manager.set(key, e, expiration)
}
c.Set(cfg.CacheHeader, cacheMiss)
// Finish response
return nil
}
}
// Check if request has directive
func hasRequestDirective(c *fiber.Ctx, directive string) bool {
return strings.Contains(c.Get(fiber.HeaderCacheControl), directive)
}

901
middleware/cache/cache_test.go vendored Normal file
View file

@ -0,0 +1,901 @@
// Special thanks to @codemicro for moving this to fiber core
// Original middleware: github.com/codemicro/fiber-cache
package cache
import (
"bytes"
"fmt"
"io"
"math"
"net/http/httptest"
"os"
"strconv"
"testing"
"time"
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/internal/storage/memory"
"github.com/gofiber/fiber/v2/middleware/etag"
"github.com/gofiber/fiber/v2/utils"
"github.com/valyala/fasthttp"
)
func Test_Cache_CacheControl(t *testing.T) {
t.Parallel()
app := fiber.New()
app.Use(New(Config{
CacheControl: true,
Expiration: 10 * time.Second,
}))
app.Get("/", func(c *fiber.Ctx) error {
return c.SendString("Hello, World!")
})
_, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err)
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, "public, max-age=10", resp.Header.Get(fiber.HeaderCacheControl))
}
func Test_Cache_Expired(t *testing.T) {
t.Parallel()
app := fiber.New()
app.Use(New(Config{Expiration: 2 * time.Second}))
app.Get("/", func(c *fiber.Ctx) error {
return c.SendString(fmt.Sprintf("%d", time.Now().UnixNano()))
})
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err)
body, err := io.ReadAll(resp.Body)
utils.AssertEqual(t, nil, err)
// Sleep until the cache is expired
time.Sleep(3 * time.Second)
respCached, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err)
bodyCached, err := io.ReadAll(respCached.Body)
utils.AssertEqual(t, nil, err)
if bytes.Equal(body, bodyCached) {
t.Errorf("Cache should have expired: %s, %s", body, bodyCached)
}
// Next response should be also cached
respCachedNextRound, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err)
bodyCachedNextRound, err := io.ReadAll(respCachedNextRound.Body)
utils.AssertEqual(t, nil, err)
if !bytes.Equal(bodyCachedNextRound, bodyCached) {
t.Errorf("Cache should not have expired: %s, %s", bodyCached, bodyCachedNextRound)
}
}
func Test_Cache(t *testing.T) {
t.Parallel()
app := fiber.New()
app.Use(New())
app.Get("/", func(c *fiber.Ctx) error {
now := fmt.Sprintf("%d", time.Now().UnixNano())
return c.SendString(now)
})
req := httptest.NewRequest(fiber.MethodGet, "/", nil)
resp, err := app.Test(req)
utils.AssertEqual(t, nil, err)
cachedReq := httptest.NewRequest(fiber.MethodGet, "/", nil)
cachedResp, err := app.Test(cachedReq)
utils.AssertEqual(t, nil, err)
body, err := io.ReadAll(resp.Body)
utils.AssertEqual(t, nil, err)
cachedBody, err := io.ReadAll(cachedResp.Body)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, cachedBody, body)
}
// go test -run Test_Cache_WithNoCacheRequestDirective
func Test_Cache_WithNoCacheRequestDirective(t *testing.T) {
t.Parallel()
app := fiber.New()
app.Use(New())
app.Get("/", func(c *fiber.Ctx) error {
return c.SendString(c.Query("id", "1"))
})
// Request id = 1
req := httptest.NewRequest(fiber.MethodGet, "/", nil)
resp, err := app.Test(req)
utils.AssertEqual(t, nil, err)
body, err := io.ReadAll(resp.Body)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, cacheMiss, resp.Header.Get("X-Cache"))
utils.AssertEqual(t, []byte("1"), body)
// Response cached, entry id = 1
// Request id = 2 without Cache-Control: no-cache
cachedReq := httptest.NewRequest(fiber.MethodGet, "/?id=2", nil)
cachedResp, err := app.Test(cachedReq)
utils.AssertEqual(t, nil, err)
cachedBody, err := io.ReadAll(cachedResp.Body)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, cacheHit, cachedResp.Header.Get("X-Cache"))
utils.AssertEqual(t, []byte("1"), cachedBody)
// Response not cached, returns cached response, entry id = 1
// Request id = 2 with Cache-Control: no-cache
noCacheReq := httptest.NewRequest(fiber.MethodGet, "/?id=2", nil)
noCacheReq.Header.Set(fiber.HeaderCacheControl, noCache)
noCacheResp, err := app.Test(noCacheReq)
utils.AssertEqual(t, nil, err)
noCacheBody, err := io.ReadAll(noCacheResp.Body)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, cacheMiss, noCacheResp.Header.Get("X-Cache"))
utils.AssertEqual(t, []byte("2"), noCacheBody)
// Response cached, returns updated response, entry = 2
/* Check Test_Cache_WithETagAndNoCacheRequestDirective */
// Request id = 2 with Cache-Control: no-cache again
noCacheReq1 := httptest.NewRequest(fiber.MethodGet, "/?id=2", nil)
noCacheReq1.Header.Set(fiber.HeaderCacheControl, noCache)
noCacheResp1, err := app.Test(noCacheReq1)
utils.AssertEqual(t, nil, err)
noCacheBody1, err := io.ReadAll(noCacheResp1.Body)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, cacheMiss, noCacheResp1.Header.Get("X-Cache"))
utils.AssertEqual(t, []byte("2"), noCacheBody1)
// Response cached, returns updated response, entry = 2
// Request id = 1 without Cache-Control: no-cache
cachedReq1 := httptest.NewRequest(fiber.MethodGet, "/", nil)
cachedResp1, err := app.Test(cachedReq1)
utils.AssertEqual(t, nil, err)
cachedBody1, err := io.ReadAll(cachedResp1.Body)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, cacheHit, cachedResp1.Header.Get("X-Cache"))
utils.AssertEqual(t, []byte("2"), cachedBody1)
// Response not cached, returns cached response, entry id = 2
}
// go test -run Test_Cache_WithETagAndNoCacheRequestDirective
func Test_Cache_WithETagAndNoCacheRequestDirective(t *testing.T) {
t.Parallel()
app := fiber.New()
app.Use(
etag.New(),
New(),
)
app.Get("/", func(c *fiber.Ctx) error {
return c.SendString(c.Query("id", "1"))
})
// Request id = 1
req := httptest.NewRequest(fiber.MethodGet, "/", nil)
resp, err := app.Test(req)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, cacheMiss, resp.Header.Get("X-Cache"))
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
// Response cached, entry id = 1
// If response status 200
etagToken := resp.Header.Get("Etag")
// Request id = 2 with ETag but without Cache-Control: no-cache
cachedReq := httptest.NewRequest(fiber.MethodGet, "/?id=2", nil)
cachedReq.Header.Set(fiber.HeaderIfNoneMatch, etagToken)
cachedResp, err := app.Test(cachedReq)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, cacheHit, cachedResp.Header.Get("X-Cache"))
utils.AssertEqual(t, fiber.StatusNotModified, cachedResp.StatusCode)
// Response not cached, returns cached response, entry id = 1, status not modified
// Request id = 2 with ETag and Cache-Control: no-cache
noCacheReq := httptest.NewRequest(fiber.MethodGet, "/?id=2", nil)
noCacheReq.Header.Set(fiber.HeaderCacheControl, noCache)
noCacheReq.Header.Set(fiber.HeaderIfNoneMatch, etagToken)
noCacheResp, err := app.Test(noCacheReq)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, cacheMiss, noCacheResp.Header.Get("X-Cache"))
utils.AssertEqual(t, fiber.StatusOK, noCacheResp.StatusCode)
// Response cached, returns updated response, entry id = 2
// If response status 200
etagToken = noCacheResp.Header.Get("Etag")
// Request id = 2 with ETag and Cache-Control: no-cache again
noCacheReq1 := httptest.NewRequest(fiber.MethodGet, "/?id=2", nil)
noCacheReq1.Header.Set(fiber.HeaderCacheControl, noCache)
noCacheReq1.Header.Set(fiber.HeaderIfNoneMatch, etagToken)
noCacheResp1, err := app.Test(noCacheReq1)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, cacheMiss, noCacheResp1.Header.Get("X-Cache"))
utils.AssertEqual(t, fiber.StatusNotModified, noCacheResp1.StatusCode)
// Response cached, returns updated response, entry id = 2, status not modified
// Request id = 1 without ETag and Cache-Control: no-cache
cachedReq1 := httptest.NewRequest(fiber.MethodGet, "/", nil)
cachedResp1, err := app.Test(cachedReq1)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, cacheHit, cachedResp1.Header.Get("X-Cache"))
utils.AssertEqual(t, fiber.StatusOK, cachedResp1.StatusCode)
// Response not cached, returns cached response, entry id = 2
}
// go test -run Test_Cache_WithNoStoreRequestDirective
func Test_Cache_WithNoStoreRequestDirective(t *testing.T) {
t.Parallel()
app := fiber.New()
app.Use(New())
app.Get("/", func(c *fiber.Ctx) error {
return c.SendString(c.Query("id", "1"))
})
// Request id = 2
noStoreReq := httptest.NewRequest(fiber.MethodGet, "/?id=2", nil)
noStoreReq.Header.Set(fiber.HeaderCacheControl, noStore)
noStoreResp, err := app.Test(noStoreReq)
utils.AssertEqual(t, nil, err)
noStoreBody, err := io.ReadAll(noStoreResp.Body)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, []byte("2"), noStoreBody)
// Response not cached, returns updated response
}
func Test_Cache_WithSeveralRequests(t *testing.T) {
t.Parallel()
app := fiber.New()
app.Use(New(Config{
CacheControl: true,
Expiration: 10 * time.Second,
}))
app.Get("/:id", func(c *fiber.Ctx) error {
return c.SendString(c.Params("id"))
})
for runs := 0; runs < 10; runs++ {
for i := 0; i < 10; i++ {
func(id int) {
rsp, err := app.Test(httptest.NewRequest(fiber.MethodGet, fmt.Sprintf("/%d", id), nil))
utils.AssertEqual(t, nil, err)
defer func(body io.ReadCloser) {
err := body.Close()
utils.AssertEqual(t, nil, err)
}(rsp.Body)
idFromServ, err := io.ReadAll(rsp.Body)
utils.AssertEqual(t, nil, err)
a, err := strconv.Atoi(string(idFromServ))
utils.AssertEqual(t, nil, err)
// SomeTimes,The id is not equal with a
utils.AssertEqual(t, id, a)
}(i)
}
}
}
func Test_Cache_Invalid_Expiration(t *testing.T) {
t.Parallel()
app := fiber.New()
cache := New(Config{Expiration: 0 * time.Second})
app.Use(cache)
app.Get("/", func(c *fiber.Ctx) error {
now := fmt.Sprintf("%d", time.Now().UnixNano())
return c.SendString(now)
})
req := httptest.NewRequest(fiber.MethodGet, "/", nil)
resp, err := app.Test(req)
utils.AssertEqual(t, nil, err)
cachedReq := httptest.NewRequest(fiber.MethodGet, "/", nil)
cachedResp, err := app.Test(cachedReq)
utils.AssertEqual(t, nil, err)
body, err := io.ReadAll(resp.Body)
utils.AssertEqual(t, nil, err)
cachedBody, err := io.ReadAll(cachedResp.Body)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, cachedBody, body)
}
func Test_Cache_Get(t *testing.T) {
t.Parallel()
app := fiber.New()
app.Use(New())
app.Post("/", func(c *fiber.Ctx) error {
return c.SendString(c.Query("cache"))
})
app.Get("/get", func(c *fiber.Ctx) error {
return c.SendString(c.Query("cache"))
})
resp, err := app.Test(httptest.NewRequest(fiber.MethodPost, "/?cache=123", nil))
utils.AssertEqual(t, nil, err)
body, err := io.ReadAll(resp.Body)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, "123", string(body))
resp, err = app.Test(httptest.NewRequest(fiber.MethodPost, "/?cache=12345", nil))
utils.AssertEqual(t, nil, err)
body, err = io.ReadAll(resp.Body)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, "12345", string(body))
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/get?cache=123", nil))
utils.AssertEqual(t, nil, err)
body, err = io.ReadAll(resp.Body)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, "123", string(body))
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/get?cache=12345", nil))
utils.AssertEqual(t, nil, err)
body, err = io.ReadAll(resp.Body)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, "123", string(body))
}
func Test_Cache_Post(t *testing.T) {
t.Parallel()
app := fiber.New()
app.Use(New(Config{
Methods: []string{fiber.MethodPost},
}))
app.Post("/", func(c *fiber.Ctx) error {
return c.SendString(c.Query("cache"))
})
app.Get("/get", func(c *fiber.Ctx) error {
return c.SendString(c.Query("cache"))
})
resp, err := app.Test(httptest.NewRequest(fiber.MethodPost, "/?cache=123", nil))
utils.AssertEqual(t, nil, err)
body, err := io.ReadAll(resp.Body)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, "123", string(body))
resp, err = app.Test(httptest.NewRequest(fiber.MethodPost, "/?cache=12345", nil))
utils.AssertEqual(t, nil, err)
body, err = io.ReadAll(resp.Body)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, "123", string(body))
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/get?cache=123", nil))
utils.AssertEqual(t, nil, err)
body, err = io.ReadAll(resp.Body)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, "123", string(body))
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/get?cache=12345", nil))
utils.AssertEqual(t, nil, err)
body, err = io.ReadAll(resp.Body)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, "12345", string(body))
}
func Test_Cache_NothingToCache(t *testing.T) {
t.Parallel()
app := fiber.New()
app.Use(New(Config{Expiration: -(time.Second * 1)}))
app.Get("/", func(c *fiber.Ctx) error {
return c.SendString(time.Now().String())
})
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err)
body, err := io.ReadAll(resp.Body)
utils.AssertEqual(t, nil, err)
time.Sleep(500 * time.Millisecond)
respCached, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err)
bodyCached, err := io.ReadAll(respCached.Body)
utils.AssertEqual(t, nil, err)
if bytes.Equal(body, bodyCached) {
t.Errorf("Cache should have expired: %s, %s", body, bodyCached)
}
}
func Test_Cache_CustomNext(t *testing.T) {
t.Parallel()
app := fiber.New()
app.Use(New(Config{
Next: func(c *fiber.Ctx) bool {
return c.Response().StatusCode() != fiber.StatusOK
},
CacheControl: true,
}))
app.Get("/", func(c *fiber.Ctx) error {
return c.SendString(time.Now().String())
})
app.Get("/error", func(c *fiber.Ctx) error {
return c.Status(fiber.StatusInternalServerError).SendString(time.Now().String())
})
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err)
body, err := io.ReadAll(resp.Body)
utils.AssertEqual(t, nil, err)
respCached, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err)
bodyCached, err := io.ReadAll(respCached.Body)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, true, bytes.Equal(body, bodyCached))
utils.AssertEqual(t, true, respCached.Header.Get(fiber.HeaderCacheControl) != "")
_, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/error", nil))
utils.AssertEqual(t, nil, err)
errRespCached, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/error", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, true, errRespCached.Header.Get(fiber.HeaderCacheControl) == "")
}
func Test_CustomKey(t *testing.T) {
t.Parallel()
app := fiber.New()
var called bool
app.Use(New(Config{KeyGenerator: func(c *fiber.Ctx) string {
called = true
return utils.CopyString(c.Path())
}}))
app.Get("/", func(c *fiber.Ctx) error {
return c.SendString("hi")
})
req := httptest.NewRequest(fiber.MethodGet, "/", nil)
_, err := app.Test(req)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, true, called)
}
func Test_CustomExpiration(t *testing.T) {
t.Parallel()
app := fiber.New()
var called bool
var newCacheTime int
app.Use(New(Config{ExpirationGenerator: func(c *fiber.Ctx, cfg *Config) time.Duration {
called = true
var err error
newCacheTime, err = strconv.Atoi(c.GetRespHeader("Cache-Time", "600"))
utils.AssertEqual(t, nil, err)
return time.Second * time.Duration(newCacheTime)
}}))
app.Get("/", func(c *fiber.Ctx) error {
c.Response().Header.Add("Cache-Time", "1")
now := fmt.Sprintf("%d", time.Now().UnixNano())
return c.SendString(now)
})
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, true, called)
utils.AssertEqual(t, 1, newCacheTime)
// Sleep until the cache is expired
time.Sleep(1 * time.Second)
cachedResp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err)
body, err := io.ReadAll(resp.Body)
utils.AssertEqual(t, nil, err)
cachedBody, err := io.ReadAll(cachedResp.Body)
utils.AssertEqual(t, nil, err)
if bytes.Equal(body, cachedBody) {
t.Errorf("Cache should have expired: %s, %s", body, cachedBody)
}
// Next response should be cached
cachedRespNextRound, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err)
cachedBodyNextRound, err := io.ReadAll(cachedRespNextRound.Body)
utils.AssertEqual(t, nil, err)
if !bytes.Equal(cachedBodyNextRound, cachedBody) {
t.Errorf("Cache should not have expired: %s, %s", cachedBodyNextRound, cachedBody)
}
}
func Test_AdditionalE2EResponseHeaders(t *testing.T) {
t.Parallel()
app := fiber.New()
app.Use(New(Config{
StoreResponseHeaders: true,
}))
app.Get("/", func(c *fiber.Ctx) error {
c.Response().Header.Add("X-Foobar", "foobar")
return c.SendString("hi")
})
req := httptest.NewRequest(fiber.MethodGet, "/", nil)
resp, err := app.Test(req)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, "foobar", resp.Header.Get("X-Foobar"))
req = httptest.NewRequest(fiber.MethodGet, "/", nil)
resp, err = app.Test(req)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, "foobar", resp.Header.Get("X-Foobar"))
}
func Test_CacheHeader(t *testing.T) {
t.Parallel()
app := fiber.New()
app.Use(New(Config{
Expiration: 10 * time.Second,
Next: func(c *fiber.Ctx) bool {
return c.Response().StatusCode() != fiber.StatusOK
},
}))
app.Get("/", func(c *fiber.Ctx) error {
return c.SendString("Hello, World!")
})
app.Post("/", func(c *fiber.Ctx) error {
return c.SendString(c.Query("cache"))
})
app.Get("/error", func(c *fiber.Ctx) error {
return c.Status(fiber.StatusInternalServerError).SendString(time.Now().String())
})
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, cacheMiss, resp.Header.Get("X-Cache"))
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, cacheHit, resp.Header.Get("X-Cache"))
resp, err = app.Test(httptest.NewRequest(fiber.MethodPost, "/?cache=12345", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, cacheUnreachable, resp.Header.Get("X-Cache"))
errRespCached, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/error", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, cacheUnreachable, errRespCached.Header.Get("X-Cache"))
}
func Test_Cache_WithHead(t *testing.T) {
t.Parallel()
app := fiber.New()
app.Use(New())
app.Get("/", func(c *fiber.Ctx) error {
now := fmt.Sprintf("%d", time.Now().UnixNano())
return c.SendString(now)
})
req := httptest.NewRequest(fiber.MethodHead, "/", nil)
resp, err := app.Test(req)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, cacheMiss, resp.Header.Get("X-Cache"))
cachedReq := httptest.NewRequest(fiber.MethodHead, "/", nil)
cachedResp, err := app.Test(cachedReq)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, cacheHit, cachedResp.Header.Get("X-Cache"))
body, err := io.ReadAll(resp.Body)
utils.AssertEqual(t, nil, err)
cachedBody, err := io.ReadAll(cachedResp.Body)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, cachedBody, body)
}
func Test_Cache_WithHeadThenGet(t *testing.T) {
t.Parallel()
app := fiber.New()
app.Use(New())
app.Get("/", func(c *fiber.Ctx) error {
return c.SendString(c.Query("cache"))
})
headResp, err := app.Test(httptest.NewRequest(fiber.MethodHead, "/?cache=123", nil))
utils.AssertEqual(t, nil, err)
headBody, err := io.ReadAll(headResp.Body)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, "", string(headBody))
utils.AssertEqual(t, cacheMiss, headResp.Header.Get("X-Cache"))
headResp, err = app.Test(httptest.NewRequest(fiber.MethodHead, "/?cache=123", nil))
utils.AssertEqual(t, nil, err)
headBody, err = io.ReadAll(headResp.Body)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, "", string(headBody))
utils.AssertEqual(t, cacheHit, headResp.Header.Get("X-Cache"))
getResp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/?cache=123", nil))
utils.AssertEqual(t, nil, err)
getBody, err := io.ReadAll(getResp.Body)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, "123", string(getBody))
utils.AssertEqual(t, cacheMiss, getResp.Header.Get("X-Cache"))
getResp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/?cache=123", nil))
utils.AssertEqual(t, nil, err)
getBody, err = io.ReadAll(getResp.Body)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, "123", string(getBody))
utils.AssertEqual(t, cacheHit, getResp.Header.Get("X-Cache"))
}
func Test_CustomCacheHeader(t *testing.T) {
t.Parallel()
app := fiber.New()
app.Use(New(Config{
CacheHeader: "Cache-Status",
}))
app.Get("/", func(c *fiber.Ctx) error {
return c.SendString("Hello, World!")
})
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, cacheMiss, resp.Header.Get("Cache-Status"))
}
// Because time points are updated once every X milliseconds, entries in tests can often have
// equal expiration times and thus be in an random order. This closure hands out increasing
// time intervals to maintain strong ascending order of expiration
func stableAscendingExpiration() func(c1 *fiber.Ctx, c2 *Config) time.Duration {
i := 0
return func(c1 *fiber.Ctx, c2 *Config) time.Duration {
i++
return time.Hour * time.Duration(i)
}
}
func Test_Cache_MaxBytesOrder(t *testing.T) {
t.Parallel()
app := fiber.New()
app.Use(New(Config{
MaxBytes: 2,
ExpirationGenerator: stableAscendingExpiration(),
}))
app.Get("/*", func(c *fiber.Ctx) error {
return c.SendString("1")
})
cases := [][]string{
// Insert a, b into cache of size 2 bytes (responses are 1 byte)
{"/a", cacheMiss},
{"/b", cacheMiss},
{"/a", cacheHit},
{"/b", cacheHit},
// Add c -> a evicted
{"/c", cacheMiss},
{"/b", cacheHit},
// Add a again -> b evicted
{"/a", cacheMiss},
{"/c", cacheHit},
// Add b -> c evicted
{"/b", cacheMiss},
{"/c", cacheMiss},
}
for idx, tcase := range cases {
rsp, err := app.Test(httptest.NewRequest(fiber.MethodGet, tcase[0], nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, tcase[1], rsp.Header.Get("X-Cache"), fmt.Sprintf("Case %v", idx))
}
}
func Test_Cache_MaxBytesSizes(t *testing.T) {
t.Parallel()
app := fiber.New()
app.Use(New(Config{
MaxBytes: 7,
ExpirationGenerator: stableAscendingExpiration(),
}))
app.Get("/*", func(c *fiber.Ctx) error {
path := c.Context().URI().LastPathSegment()
size, err := strconv.Atoi(string(path))
utils.AssertEqual(t, nil, err)
return c.Send(make([]byte, size))
})
cases := [][]string{
{"/1", cacheMiss},
{"/2", cacheMiss},
{"/3", cacheMiss},
{"/4", cacheMiss}, // 1+2+3+4 > 7 => 1,2 are evicted now
{"/3", cacheHit},
{"/1", cacheMiss},
{"/2", cacheMiss},
{"/8", cacheUnreachable}, // too big to cache -> unreachable
}
for idx, tcase := range cases {
rsp, err := app.Test(httptest.NewRequest(fiber.MethodGet, tcase[0], nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, tcase[1], rsp.Header.Get("X-Cache"), fmt.Sprintf("Case %v", idx))
}
}
// go test -v -run=^$ -bench=Benchmark_Cache -benchmem -count=4
func Benchmark_Cache(b *testing.B) {
app := fiber.New()
app.Use(New())
app.Get("/demo", func(c *fiber.Ctx) error {
data, _ := os.ReadFile("../../.github/README.md") //nolint:errcheck // We're inside a benchmark
return c.Status(fiber.StatusTeapot).Send(data)
})
h := app.Handler()
fctx := &fasthttp.RequestCtx{}
fctx.Request.Header.SetMethod(fiber.MethodGet)
fctx.Request.SetRequestURI("/demo")
b.ReportAllocs()
b.ResetTimer()
for n := 0; n < b.N; n++ {
h(fctx)
}
utils.AssertEqual(b, fiber.StatusTeapot, fctx.Response.Header.StatusCode())
utils.AssertEqual(b, true, len(fctx.Response.Body()) > 30000)
}
// go test -v -run=^$ -bench=Benchmark_Cache_Storage -benchmem -count=4
func Benchmark_Cache_Storage(b *testing.B) {
app := fiber.New()
app.Use(New(Config{
Storage: memory.New(),
}))
app.Get("/demo", func(c *fiber.Ctx) error {
data, _ := os.ReadFile("../../.github/README.md") //nolint:errcheck // We're inside a benchmark
return c.Status(fiber.StatusTeapot).Send(data)
})
h := app.Handler()
fctx := &fasthttp.RequestCtx{}
fctx.Request.Header.SetMethod(fiber.MethodGet)
fctx.Request.SetRequestURI("/demo")
b.ReportAllocs()
b.ResetTimer()
for n := 0; n < b.N; n++ {
h(fctx)
}
utils.AssertEqual(b, fiber.StatusTeapot, fctx.Response.Header.StatusCode())
utils.AssertEqual(b, true, len(fctx.Response.Body()) > 30000)
}
func Benchmark_Cache_AdditionalHeaders(b *testing.B) {
app := fiber.New()
app.Use(New(Config{
StoreResponseHeaders: true,
}))
app.Get("/demo", func(c *fiber.Ctx) error {
c.Response().Header.Add("X-Foobar", "foobar")
return c.SendStatus(418)
})
h := app.Handler()
fctx := &fasthttp.RequestCtx{}
fctx.Request.Header.SetMethod(fiber.MethodGet)
fctx.Request.SetRequestURI("/demo")
b.ReportAllocs()
b.ResetTimer()
for n := 0; n < b.N; n++ {
h(fctx)
}
utils.AssertEqual(b, fiber.StatusTeapot, fctx.Response.Header.StatusCode())
utils.AssertEqual(b, []byte("foobar"), fctx.Response.Header.Peek("X-Foobar"))
}
func Benchmark_Cache_MaxSize(b *testing.B) {
// The benchmark is run with three different MaxSize parameters
// 1) 0: Tracking is disabled = no overhead
// 2) MaxInt32: Enough to store all entries = no removals
// 3) 100: Small size = constant insertions and removals
cases := []uint{0, math.MaxUint32, 100}
names := []string{"Disabled", "Unlim", "LowBounded"}
for i, size := range cases {
b.Run(names[i], func(b *testing.B) {
app := fiber.New()
app.Use(New(Config{MaxBytes: size}))
app.Get("/*", func(c *fiber.Ctx) error {
return c.Status(fiber.StatusTeapot).SendString("1")
})
h := app.Handler()
fctx := &fasthttp.RequestCtx{}
fctx.Request.Header.SetMethod(fiber.MethodGet)
b.ReportAllocs()
b.ResetTimer()
for n := 0; n < b.N; n++ {
fctx.Request.SetRequestURI(fmt.Sprintf("/%v", n))
h(fctx)
}
utils.AssertEqual(b, fiber.StatusTeapot, fctx.Response.Header.StatusCode())
})
}
}

128
middleware/cache/config.go vendored Normal file
View file

@ -0,0 +1,128 @@
package cache
import (
"time"
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/log"
"github.com/gofiber/fiber/v2/utils"
)
// Config defines the config for middleware.
type Config struct {
// Next defines a function to skip this middleware when returned true.
//
// Optional. Default: nil
Next func(c *fiber.Ctx) bool
// Expiration is the time that an cached response will live
//
// Optional. Default: 1 * time.Minute
Expiration time.Duration
// CacheHeader header on response header, indicate cache status, with the following possible return value
//
// hit, miss, unreachable
//
// Optional. Default: X-Cache
CacheHeader string
// CacheControl enables client side caching if set to true
//
// Optional. Default: false
CacheControl bool
// Key allows you to generate custom keys, by default c.Path() is used
//
// Default: func(c *fiber.Ctx) string {
// return utils.CopyString(c.Path())
// }
KeyGenerator func(*fiber.Ctx) string
// allows you to generate custom Expiration Key By Key, default is Expiration (Optional)
//
// Default: nil
ExpirationGenerator func(*fiber.Ctx, *Config) time.Duration
// Store is used to store the state of the middleware
//
// Default: an in memory store for this process only
Storage fiber.Storage
// Deprecated: Use Storage instead
Store fiber.Storage
// Deprecated: Use KeyGenerator instead
Key func(*fiber.Ctx) string
// allows you to store additional headers generated by next middlewares & handler
//
// Default: false
StoreResponseHeaders bool
// Max number of bytes of response bodies simultaneously stored in cache. When limit is reached,
// entries with the nearest expiration are deleted to make room for new.
// 0 means no limit
//
// Default: 0
MaxBytes uint
// You can specify HTTP methods to cache.
// The middleware just caches the routes of its methods in this slice.
//
// Default: []string{fiber.MethodGet, fiber.MethodHead}
Methods []string
}
// ConfigDefault is the default config
var ConfigDefault = Config{
Next: nil,
Expiration: 1 * time.Minute,
CacheHeader: "X-Cache",
CacheControl: false,
KeyGenerator: func(c *fiber.Ctx) string {
return utils.CopyString(c.Path())
},
ExpirationGenerator: nil,
StoreResponseHeaders: false,
Storage: nil,
MaxBytes: 0,
Methods: []string{fiber.MethodGet, fiber.MethodHead},
}
// Helper function to set default values
func configDefault(config ...Config) Config {
// Return default config if nothing provided
if len(config) < 1 {
return ConfigDefault
}
// Override default config
cfg := config[0]
// Set default values
if cfg.Store != nil {
log.Warn("[CACHE] Store is deprecated, please use Storage")
cfg.Storage = cfg.Store
}
if cfg.Key != nil {
log.Warn("[CACHE] Key is deprecated, please use KeyGenerator")
cfg.KeyGenerator = cfg.Key
}
if cfg.Next == nil {
cfg.Next = ConfigDefault.Next
}
if int(cfg.Expiration.Seconds()) == 0 {
cfg.Expiration = ConfigDefault.Expiration
}
if cfg.CacheHeader == "" {
cfg.CacheHeader = ConfigDefault.CacheHeader
}
if cfg.KeyGenerator == nil {
cfg.KeyGenerator = ConfigDefault.KeyGenerator
}
if len(cfg.Methods) == 0 {
cfg.Methods = ConfigDefault.Methods
}
return cfg
}

92
middleware/cache/heap.go vendored Normal file
View file

@ -0,0 +1,92 @@
package cache
import (
"container/heap"
)
type heapEntry struct {
key string
exp uint64
bytes uint
idx int
}
// indexedHeap is a regular min-heap that allows finding
// elements in constant time. It does so by handing out special indices
// and tracking entry movement.
//
// indexdedHeap is used for quickly finding entries with the lowest
// expiration timestamp and deleting arbitrary entries.
type indexedHeap struct {
// Slice the heap is built on
entries []heapEntry
// Mapping "index" to position in heap slice
indices []int
// Max index handed out
maxidx int
}
func (h indexedHeap) Len() int {
return len(h.entries)
}
func (h indexedHeap) Less(i, j int) bool {
return h.entries[i].exp < h.entries[j].exp
}
func (h indexedHeap) Swap(i, j int) {
h.entries[i], h.entries[j] = h.entries[j], h.entries[i]
h.indices[h.entries[i].idx] = i
h.indices[h.entries[j].idx] = j
}
func (h *indexedHeap) Push(x interface{}) {
h.pushInternal(x.(heapEntry)) //nolint:forcetypeassert // Forced type assertion required to implement the heap.Interface interface
}
func (h *indexedHeap) Pop() interface{} {
n := len(h.entries)
h.entries = h.entries[0 : n-1]
return h.entries[0:n][n-1]
}
func (h *indexedHeap) pushInternal(entry heapEntry) {
h.indices[entry.idx] = len(h.entries)
h.entries = append(h.entries, entry)
}
// Returns index to track entry
func (h *indexedHeap) put(key string, exp uint64, bytes uint) int {
idx := 0
if len(h.entries) < h.maxidx {
// Steal index from previously removed entry
// capacity > size is guaranteed
n := len(h.entries)
idx = h.entries[:n+1][n].idx
} else {
idx = h.maxidx
h.maxidx++
h.indices = append(h.indices, idx)
}
// Push manually to avoid allocation
h.pushInternal(heapEntry{
key: key, exp: exp, idx: idx, bytes: bytes,
})
heap.Fix(h, h.Len()-1)
return idx
}
func (h *indexedHeap) removeInternal(realIdx int) (string, uint) {
x := heap.Remove(h, realIdx).(heapEntry) //nolint:forcetypeassert,errcheck // Forced type assertion required to implement the heap.Interface interface
return x.key, x.bytes
}
// Remove entry by index
func (h *indexedHeap) remove(idx int) (string, uint) {
return h.removeInternal(h.indices[idx])
}
// Remove entry with lowest expiration time
func (h *indexedHeap) removeFirst() (string, uint) {
return h.removeInternal(0)
}

132
middleware/cache/manager.go vendored Normal file
View file

@ -0,0 +1,132 @@
package cache
import (
"sync"
"time"
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/internal/memory"
)
// go:generate msgp
// msgp -file="manager.go" -o="manager_msgp.go" -tests=false -unexported
type item struct {
body []byte
ctype []byte
cencoding []byte
status int
exp uint64
headers map[string][]byte
// used for finding the item in an indexed heap
heapidx int
}
//msgp:ignore manager
type manager struct {
pool sync.Pool
memory *memory.Storage
storage fiber.Storage
}
func newManager(storage fiber.Storage) *manager {
// Create new storage handler
manager := &manager{
pool: sync.Pool{
New: func() interface{} {
return new(item)
},
},
}
if storage != nil {
// Use provided storage if provided
manager.storage = storage
} else {
// Fallback to memory storage
manager.memory = memory.New()
}
return manager
}
// acquire returns an *entry from the sync.Pool
func (m *manager) acquire() *item {
return m.pool.Get().(*item) //nolint:forcetypeassert // We store nothing else in the pool
}
// release and reset *entry to sync.Pool
func (m *manager) release(e *item) {
// don't release item if we using memory storage
if m.storage != nil {
return
}
e.body = nil
e.ctype = nil
e.status = 0
e.exp = 0
e.headers = nil
m.pool.Put(e)
}
// get data from storage or memory
func (m *manager) get(key string) *item {
var it *item
if m.storage != nil {
it = m.acquire()
raw, err := m.storage.Get(key)
if err != nil {
return it
}
if raw != nil {
if _, err := it.UnmarshalMsg(raw); err != nil {
return it
}
}
return it
}
if it, _ = m.memory.Get(key).(*item); it == nil { //nolint:errcheck // We store nothing else in the pool
it = m.acquire()
return it
}
return it
}
// get raw data from storage or memory
func (m *manager) getRaw(key string) []byte {
var raw []byte
if m.storage != nil {
raw, _ = m.storage.Get(key) //nolint:errcheck // TODO: Handle error here
} else {
raw, _ = m.memory.Get(key).([]byte) //nolint:errcheck // TODO: Handle error here
}
return raw
}
// set data to storage or memory
func (m *manager) set(key string, it *item, exp time.Duration) {
if m.storage != nil {
if raw, err := it.MarshalMsg(nil); err == nil {
_ = m.storage.Set(key, raw, exp) //nolint:errcheck // TODO: Handle error here
}
// we can release data because it's serialized to database
m.release(it)
} else {
m.memory.Set(key, it, exp)
}
}
// set data to storage or memory
func (m *manager) setRaw(key string, raw []byte, exp time.Duration) {
if m.storage != nil {
_ = m.storage.Set(key, raw, exp) //nolint:errcheck // TODO: Handle error here
} else {
m.memory.Set(key, raw, exp)
}
}
// delete data from storage or memory
func (m *manager) del(key string) {
if m.storage != nil {
_ = m.storage.Delete(key) //nolint:errcheck // TODO: Handle error here
} else {
m.memory.Delete(key)
}
}

300
middleware/cache/manager_msgp.go vendored Normal file
View file

@ -0,0 +1,300 @@
package cache
// NOTE: THIS FILE WAS PRODUCED BY THE
// MSGP CODE GENERATION TOOL (github.com/tinylib/msgp)
// DO NOT EDIT
import (
"github.com/tinylib/msgp/msgp"
)
// DecodeMsg implements msgp.Decodable
func (z *item) DecodeMsg(dc *msgp.Reader) (err error) {
var field []byte
_ = field
var zbai uint32
zbai, err = dc.ReadMapHeader()
if err != nil {
return
}
for zbai > 0 {
zbai--
field, err = dc.ReadMapKeyPtr()
if err != nil {
return
}
switch msgp.UnsafeString(field) {
case "body":
z.body, err = dc.ReadBytes(z.body)
if err != nil {
return
}
case "ctype":
z.ctype, err = dc.ReadBytes(z.ctype)
if err != nil {
return
}
case "cencoding":
z.cencoding, err = dc.ReadBytes(z.cencoding)
if err != nil {
return
}
case "status":
z.status, err = dc.ReadInt()
if err != nil {
return
}
case "exp":
z.exp, err = dc.ReadUint64()
if err != nil {
return
}
case "headers":
var zcmr uint32
zcmr, err = dc.ReadMapHeader()
if err != nil {
return
}
if z.headers == nil && zcmr > 0 {
z.headers = make(map[string][]byte, zcmr)
} else if len(z.headers) > 0 {
for key := range z.headers {
delete(z.headers, key)
}
}
for zcmr > 0 {
zcmr--
var zxvk string
var zbzg []byte
zxvk, err = dc.ReadString()
if err != nil {
return
}
zbzg, err = dc.ReadBytes(zbzg)
if err != nil {
return
}
z.headers[zxvk] = zbzg
}
case "heapidx":
z.heapidx, err = dc.ReadInt()
if err != nil {
return
}
default:
err = dc.Skip()
if err != nil {
return
}
}
}
return
}
// EncodeMsg implements msgp.Encodable
func (z *item) EncodeMsg(en *msgp.Writer) (err error) {
// map header, size 7
// write "body"
err = en.Append(0x87, 0xa4, 0x62, 0x6f, 0x64, 0x79)
if err != nil {
return err
}
err = en.WriteBytes(z.body)
if err != nil {
return
}
// write "ctype"
err = en.Append(0xa5, 0x63, 0x74, 0x79, 0x70, 0x65)
if err != nil {
return err
}
err = en.WriteBytes(z.ctype)
if err != nil {
return
}
// write "cencoding"
err = en.Append(0xa9, 0x63, 0x65, 0x6e, 0x63, 0x6f, 0x64, 0x69, 0x6e, 0x67)
if err != nil {
return err
}
err = en.WriteBytes(z.cencoding)
if err != nil {
return
}
// write "status"
err = en.Append(0xa6, 0x73, 0x74, 0x61, 0x74, 0x75, 0x73)
if err != nil {
return err
}
err = en.WriteInt(z.status)
if err != nil {
return
}
// write "exp"
err = en.Append(0xa3, 0x65, 0x78, 0x70)
if err != nil {
return err
}
err = en.WriteUint64(z.exp)
if err != nil {
return
}
// write "headers"
err = en.Append(0xa7, 0x68, 0x65, 0x61, 0x64, 0x65, 0x72, 0x73)
if err != nil {
return err
}
err = en.WriteMapHeader(uint32(len(z.headers)))
if err != nil {
return
}
for zxvk, zbzg := range z.headers {
err = en.WriteString(zxvk)
if err != nil {
return
}
err = en.WriteBytes(zbzg)
if err != nil {
return
}
}
// write "heapidx"
err = en.Append(0xa7, 0x68, 0x65, 0x61, 0x70, 0x69, 0x64, 0x78)
if err != nil {
return err
}
err = en.WriteInt(z.heapidx)
if err != nil {
return
}
return
}
// MarshalMsg implements msgp.Marshaler
func (z *item) MarshalMsg(b []byte) (o []byte, err error) {
o = msgp.Require(b, z.Msgsize())
// map header, size 7
// string "body"
o = append(o, 0x87, 0xa4, 0x62, 0x6f, 0x64, 0x79)
o = msgp.AppendBytes(o, z.body)
// string "ctype"
o = append(o, 0xa5, 0x63, 0x74, 0x79, 0x70, 0x65)
o = msgp.AppendBytes(o, z.ctype)
// string "cencoding"
o = append(o, 0xa9, 0x63, 0x65, 0x6e, 0x63, 0x6f, 0x64, 0x69, 0x6e, 0x67)
o = msgp.AppendBytes(o, z.cencoding)
// string "status"
o = append(o, 0xa6, 0x73, 0x74, 0x61, 0x74, 0x75, 0x73)
o = msgp.AppendInt(o, z.status)
// string "exp"
o = append(o, 0xa3, 0x65, 0x78, 0x70)
o = msgp.AppendUint64(o, z.exp)
// string "headers"
o = append(o, 0xa7, 0x68, 0x65, 0x61, 0x64, 0x65, 0x72, 0x73)
o = msgp.AppendMapHeader(o, uint32(len(z.headers)))
for zxvk, zbzg := range z.headers {
o = msgp.AppendString(o, zxvk)
o = msgp.AppendBytes(o, zbzg)
}
// string "heapidx"
o = append(o, 0xa7, 0x68, 0x65, 0x61, 0x70, 0x69, 0x64, 0x78)
o = msgp.AppendInt(o, z.heapidx)
return
}
// UnmarshalMsg implements msgp.Unmarshaler
func (z *item) UnmarshalMsg(bts []byte) (o []byte, err error) {
var field []byte
_ = field
var zajw uint32
zajw, bts, err = msgp.ReadMapHeaderBytes(bts)
if err != nil {
return
}
for zajw > 0 {
zajw--
field, bts, err = msgp.ReadMapKeyZC(bts)
if err != nil {
return
}
switch msgp.UnsafeString(field) {
case "body":
z.body, bts, err = msgp.ReadBytesBytes(bts, z.body)
if err != nil {
return
}
case "ctype":
z.ctype, bts, err = msgp.ReadBytesBytes(bts, z.ctype)
if err != nil {
return
}
case "cencoding":
z.cencoding, bts, err = msgp.ReadBytesBytes(bts, z.cencoding)
if err != nil {
return
}
case "status":
z.status, bts, err = msgp.ReadIntBytes(bts)
if err != nil {
return
}
case "exp":
z.exp, bts, err = msgp.ReadUint64Bytes(bts)
if err != nil {
return
}
case "headers":
var zwht uint32
zwht, bts, err = msgp.ReadMapHeaderBytes(bts)
if err != nil {
return
}
if z.headers == nil && zwht > 0 {
z.headers = make(map[string][]byte, zwht)
} else if len(z.headers) > 0 {
for key := range z.headers {
delete(z.headers, key)
}
}
for zwht > 0 {
var zxvk string
var zbzg []byte
zwht--
zxvk, bts, err = msgp.ReadStringBytes(bts)
if err != nil {
return
}
zbzg, bts, err = msgp.ReadBytesBytes(bts, zbzg)
if err != nil {
return
}
z.headers[zxvk] = zbzg
}
case "heapidx":
z.heapidx, bts, err = msgp.ReadIntBytes(bts)
if err != nil {
return
}
default:
bts, err = msgp.Skip(bts)
if err != nil {
return
}
}
}
o = bts
return
}
// Msgsize returns an upper bound estimate of the number of bytes occupied by the serialized message
func (z *item) Msgsize() (s int) {
s = 1 + 5 + msgp.BytesPrefixSize + len(z.body) + 6 + msgp.BytesPrefixSize + len(z.ctype) + 10 + msgp.BytesPrefixSize + len(z.cencoding) + 7 + msgp.IntSize + 4 + msgp.Uint64Size + 8 + msgp.MapHeaderSize
if z.headers != nil {
for zxvk, zbzg := range z.headers {
_ = zbzg
s += msgp.StringPrefixSize + len(zxvk) + msgp.BytesPrefixSize + len(zbzg)
}
}
s += 8 + msgp.IntSize
return
}

View file

@ -0,0 +1,65 @@
package compress
import (
"github.com/gofiber/fiber/v2"
"github.com/valyala/fasthttp"
)
// New creates a new middleware handler
func New(config ...Config) fiber.Handler {
// Set default config
cfg := configDefault(config...)
// Setup request handlers
var (
fctx = func(c *fasthttp.RequestCtx) {}
compressor fasthttp.RequestHandler
)
// Setup compression algorithm
switch cfg.Level {
case LevelDefault:
// LevelDefault
compressor = fasthttp.CompressHandlerBrotliLevel(fctx,
fasthttp.CompressBrotliDefaultCompression,
fasthttp.CompressDefaultCompression,
)
case LevelBestSpeed:
// LevelBestSpeed
compressor = fasthttp.CompressHandlerBrotliLevel(fctx,
fasthttp.CompressBrotliBestSpeed,
fasthttp.CompressBestSpeed,
)
case LevelBestCompression:
// LevelBestCompression
compressor = fasthttp.CompressHandlerBrotliLevel(fctx,
fasthttp.CompressBrotliBestCompression,
fasthttp.CompressBestCompression,
)
default:
// LevelDisabled
return func(c *fiber.Ctx) error {
return c.Next()
}
}
// Return new handler
return func(c *fiber.Ctx) error {
// Don't execute middleware if Next returns true
if cfg.Next != nil && cfg.Next(c) {
return c.Next()
}
// Continue stack
if err := c.Next(); err != nil {
return err
}
// Compress response
compressor(c.Context())
// Return from handler
return nil
}
}

View file

@ -0,0 +1,192 @@
package compress
import (
"errors"
"fmt"
"io"
"net/http/httptest"
"os"
"testing"
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/utils"
)
var filedata []byte
func init() {
dat, err := os.ReadFile("../../.github/README.md")
if err != nil {
panic(err)
}
filedata = dat
}
// go test -run Test_Compress_Gzip
func Test_Compress_Gzip(t *testing.T) {
t.Parallel()
app := fiber.New()
app.Use(New())
app.Get("/", func(c *fiber.Ctx) error {
c.Set(fiber.HeaderContentType, fiber.MIMETextPlainCharsetUTF8)
return c.Send(filedata)
})
req := httptest.NewRequest(fiber.MethodGet, "/", nil)
req.Header.Set("Accept-Encoding", "gzip")
resp, err := app.Test(req)
utils.AssertEqual(t, nil, err, "app.Test(req)")
utils.AssertEqual(t, 200, resp.StatusCode, "Status code")
utils.AssertEqual(t, "gzip", resp.Header.Get(fiber.HeaderContentEncoding))
// Validate that the file size has shrunk
body, err := io.ReadAll(resp.Body)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, true, len(body) < len(filedata))
}
// go test -run Test_Compress_Different_Level
func Test_Compress_Different_Level(t *testing.T) {
t.Parallel()
levels := []Level{LevelBestSpeed, LevelBestCompression}
for _, level := range levels {
level := level
t.Run(fmt.Sprintf("level %d", level), func(t *testing.T) {
t.Parallel()
app := fiber.New()
app.Use(New(Config{Level: level}))
app.Get("/", func(c *fiber.Ctx) error {
c.Set(fiber.HeaderContentType, fiber.MIMETextPlainCharsetUTF8)
return c.Send(filedata)
})
req := httptest.NewRequest(fiber.MethodGet, "/", nil)
req.Header.Set("Accept-Encoding", "gzip")
resp, err := app.Test(req)
utils.AssertEqual(t, nil, err, "app.Test(req)")
utils.AssertEqual(t, 200, resp.StatusCode, "Status code")
utils.AssertEqual(t, "gzip", resp.Header.Get(fiber.HeaderContentEncoding))
// Validate that the file size has shrunk
body, err := io.ReadAll(resp.Body)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, true, len(body) < len(filedata))
})
}
}
func Test_Compress_Deflate(t *testing.T) {
t.Parallel()
app := fiber.New()
app.Use(New())
app.Get("/", func(c *fiber.Ctx) error {
return c.Send(filedata)
})
req := httptest.NewRequest(fiber.MethodGet, "/", nil)
req.Header.Set("Accept-Encoding", "deflate")
resp, err := app.Test(req)
utils.AssertEqual(t, nil, err, "app.Test(req)")
utils.AssertEqual(t, 200, resp.StatusCode, "Status code")
utils.AssertEqual(t, "deflate", resp.Header.Get(fiber.HeaderContentEncoding))
// Validate that the file size has shrunk
body, err := io.ReadAll(resp.Body)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, true, len(body) < len(filedata))
}
func Test_Compress_Brotli(t *testing.T) {
t.Parallel()
app := fiber.New()
app.Use(New())
app.Get("/", func(c *fiber.Ctx) error {
return c.Send(filedata)
})
req := httptest.NewRequest(fiber.MethodGet, "/", nil)
req.Header.Set("Accept-Encoding", "br")
resp, err := app.Test(req, 10000)
utils.AssertEqual(t, nil, err, "app.Test(req)")
utils.AssertEqual(t, 200, resp.StatusCode, "Status code")
utils.AssertEqual(t, "br", resp.Header.Get(fiber.HeaderContentEncoding))
// Validate that the file size has shrunk
body, err := io.ReadAll(resp.Body)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, true, len(body) < len(filedata))
}
func Test_Compress_Disabled(t *testing.T) {
t.Parallel()
app := fiber.New()
app.Use(New(Config{Level: LevelDisabled}))
app.Get("/", func(c *fiber.Ctx) error {
return c.Send(filedata)
})
req := httptest.NewRequest(fiber.MethodGet, "/", nil)
req.Header.Set("Accept-Encoding", "br")
resp, err := app.Test(req)
utils.AssertEqual(t, nil, err, "app.Test(req)")
utils.AssertEqual(t, 200, resp.StatusCode, "Status code")
utils.AssertEqual(t, "", resp.Header.Get(fiber.HeaderContentEncoding))
// Validate the file size is not shrunk
body, err := io.ReadAll(resp.Body)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, true, len(body) == len(filedata))
}
func Test_Compress_Next_Error(t *testing.T) {
t.Parallel()
app := fiber.New()
app.Use(New())
app.Get("/", func(c *fiber.Ctx) error {
return errors.New("next error")
})
req := httptest.NewRequest(fiber.MethodGet, "/", nil)
req.Header.Set("Accept-Encoding", "gzip")
resp, err := app.Test(req)
utils.AssertEqual(t, nil, err, "app.Test(req)")
utils.AssertEqual(t, 500, resp.StatusCode, "Status code")
utils.AssertEqual(t, "", resp.Header.Get(fiber.HeaderContentEncoding))
body, err := io.ReadAll(resp.Body)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, "next error", string(body))
}
// go test -run Test_Compress_Next
func Test_Compress_Next(t *testing.T) {
t.Parallel()
app := fiber.New()
app.Use(New(Config{
Next: func(_ *fiber.Ctx) bool {
return true
},
}))
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusNotFound, resp.StatusCode)
}

View file

@ -0,0 +1,56 @@
package compress
import (
"github.com/gofiber/fiber/v2"
)
// Config defines the config for middleware.
type Config struct {
// Next defines a function to skip this middleware when returned true.
//
// Optional. Default: nil
Next func(c *fiber.Ctx) bool
// Level determines the compression algorithm
//
// Optional. Default: LevelDefault
// LevelDisabled: -1
// LevelDefault: 0
// LevelBestSpeed: 1
// LevelBestCompression: 2
Level Level
}
// Level is numeric representation of compression level
type Level int
// Represents compression level that will be used in the middleware
const (
LevelDisabled Level = -1
LevelDefault Level = 0
LevelBestSpeed Level = 1
LevelBestCompression Level = 2
)
// ConfigDefault is the default config
var ConfigDefault = Config{
Next: nil,
Level: LevelDefault,
}
// Helper function to set default values
func configDefault(config ...Config) Config {
// Return default config if nothing provided
if len(config) < 1 {
return ConfigDefault
}
// Override default config
cfg := config[0]
// Set default values
if cfg.Level < LevelDisabled || cfg.Level > LevelBestCompression {
cfg.Level = ConfigDefault.Level
}
return cfg
}

289
middleware/cors/cors.go Normal file
View file

@ -0,0 +1,289 @@
package cors
import (
"strconv"
"strings"
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/log"
)
// Config defines the config for middleware.
type Config struct {
// Next defines a function to skip this middleware when returned true.
//
// Optional. Default: nil
Next func(c *fiber.Ctx) bool
// AllowOriginsFunc defines a function that will set the 'Access-Control-Allow-Origin'
// response header to the 'origin' request header when returned true. This allows for
// dynamic evaluation of allowed origins. Note if AllowCredentials is true, wildcard origins
// will be not have the 'Access-Control-Allow-Credentials' header set to 'true'.
//
// Optional. Default: nil
AllowOriginsFunc func(origin string) bool
// AllowOrigin defines a comma separated list of origins that may access the resource.
//
// Optional. Default value "*"
AllowOrigins string
// AllowMethods defines a list methods allowed when accessing the resource.
// This is used in response to a preflight request.
//
// Optional. Default value "GET,POST,HEAD,PUT,DELETE,PATCH"
AllowMethods string
// AllowHeaders defines a list of request headers that can be used when
// making the actual request. This is in response to a preflight request.
//
// Optional. Default value "".
AllowHeaders string
// AllowCredentials indicates whether or not the response to the request
// can be exposed when the credentials flag is true. When used as part of
// a response to a preflight request, this indicates whether or not the
// actual request can be made using credentials. Note: If true, AllowOrigins
// cannot be set to a wildcard ("*") to prevent security vulnerabilities.
//
// Optional. Default value false.
AllowCredentials bool
// ExposeHeaders defines a whitelist headers that clients are allowed to
// access.
//
// Optional. Default value "".
ExposeHeaders string
// MaxAge indicates how long (in seconds) the results of a preflight request
// can be cached.
// If you pass MaxAge 0, Access-Control-Max-Age header will not be added and
// browser will use 5 seconds by default.
// To disable caching completely, pass MaxAge value negative. It will set the Access-Control-Max-Age header 0.
//
// Optional. Default value 0.
MaxAge int
}
// ConfigDefault is the default config
var ConfigDefault = Config{
Next: nil,
AllowOriginsFunc: nil,
AllowOrigins: "*",
AllowMethods: strings.Join([]string{
fiber.MethodGet,
fiber.MethodPost,
fiber.MethodHead,
fiber.MethodPut,
fiber.MethodDelete,
fiber.MethodPatch,
}, ","),
AllowHeaders: "",
AllowCredentials: false,
ExposeHeaders: "",
MaxAge: 0,
}
// New creates a new middleware handler
func New(config ...Config) fiber.Handler {
// Set default config
cfg := ConfigDefault
// Override config if provided
if len(config) > 0 {
cfg = config[0]
// Set default values
if cfg.AllowMethods == "" {
cfg.AllowMethods = ConfigDefault.AllowMethods
}
// When none of the AllowOrigins or AllowOriginsFunc config was defined, set the default AllowOrigins value with "*"
if cfg.AllowOrigins == "" && cfg.AllowOriginsFunc == nil {
cfg.AllowOrigins = ConfigDefault.AllowOrigins
}
}
// Warning logs if both AllowOrigins and AllowOriginsFunc are set
if cfg.AllowOrigins != "" && cfg.AllowOriginsFunc != nil {
log.Warn("[CORS] Both 'AllowOrigins' and 'AllowOriginsFunc' have been defined.")
}
// Validate CORS credentials configuration
if cfg.AllowCredentials && cfg.AllowOrigins == "*" {
panic("[CORS] Insecure setup, 'AllowCredentials' is set to true, and 'AllowOrigins' is set to a wildcard.")
}
// allowOrigins is a slice of strings that contains the allowed origins
// defined in the 'AllowOrigins' configuration.
allowOrigins := []string{}
allowSOrigins := []subdomain{}
allowAllOrigins := false
// Validate and normalize static AllowOrigins
if cfg.AllowOrigins != "" && cfg.AllowOrigins != "*" {
origins := strings.Split(cfg.AllowOrigins, ",")
for _, origin := range origins {
if i := strings.Index(origin, "://*."); i != -1 {
trimmedOrigin := strings.TrimSpace(origin[:i+3] + origin[i+4:])
isValid, normalizedOrigin := normalizeOrigin(trimmedOrigin)
if !isValid {
panic("[CORS] Invalid origin format in configuration: " + trimmedOrigin)
}
sd := subdomain{prefix: normalizedOrigin[:i+3], suffix: normalizedOrigin[i+3:]}
allowSOrigins = append(allowSOrigins, sd)
} else {
trimmedOrigin := strings.TrimSpace(origin)
isValid, normalizedOrigin := normalizeOrigin(trimmedOrigin)
if !isValid {
panic("[CORS] Invalid origin format in configuration: " + trimmedOrigin)
}
allowOrigins = append(allowOrigins, normalizedOrigin)
}
}
} else if cfg.AllowOrigins == "*" {
allowAllOrigins = true
}
// Strip white spaces
allowMethods := strings.ReplaceAll(cfg.AllowMethods, " ", "")
allowHeaders := strings.ReplaceAll(cfg.AllowHeaders, " ", "")
exposeHeaders := strings.ReplaceAll(cfg.ExposeHeaders, " ", "")
// Convert int to string
maxAge := strconv.Itoa(cfg.MaxAge)
// Return new handler
return func(c *fiber.Ctx) error {
// Don't execute middleware if Next returns true
if cfg.Next != nil && cfg.Next(c) {
return c.Next()
}
// Get originHeader header
originHeader := strings.ToLower(c.Get(fiber.HeaderOrigin))
// If the request does not have Origin header, the request is outside the scope of CORS
if originHeader == "" {
// See https://fetch.spec.whatwg.org/#cors-protocol-and-http-caches
// Unless all origins are allowed, we include the Vary header to cache the response correctly
if !allowAllOrigins {
c.Vary(fiber.HeaderOrigin)
}
return c.Next()
}
// If it's a preflight request and doesn't have Access-Control-Request-Method header, it's outside the scope of CORS
if c.Method() == fiber.MethodOptions && c.Get(fiber.HeaderAccessControlRequestMethod) == "" {
// Response to OPTIONS request should not be cached but,
// some caching can be configured to cache such responses.
// To Avoid poisoning the cache, we include the Vary header
// for non-CORS OPTIONS requests:
c.Vary(fiber.HeaderOrigin)
return c.Next()
}
// Set default allowOrigin to empty string
allowOrigin := ""
// Check allowed origins
if allowAllOrigins {
allowOrigin = "*"
} else {
// Check if the origin is in the list of allowed origins
for _, origin := range allowOrigins {
if origin == originHeader {
allowOrigin = originHeader
break
}
}
// Check if the origin is in the list of allowed subdomains
if allowOrigin == "" {
for _, sOrigin := range allowSOrigins {
if sOrigin.match(originHeader) {
allowOrigin = originHeader
break
}
}
}
}
// Run AllowOriginsFunc if the logic for
// handling the value in 'AllowOrigins' does
// not result in allowOrigin being set.
if allowOrigin == "" && cfg.AllowOriginsFunc != nil && cfg.AllowOriginsFunc(originHeader) {
allowOrigin = originHeader
}
// Simple request
// Ommit allowMethods and allowHeaders, only used for pre-flight requests
if c.Method() != fiber.MethodOptions {
if !allowAllOrigins {
// See https://fetch.spec.whatwg.org/#cors-protocol-and-http-caches
c.Vary(fiber.HeaderOrigin)
}
setCORSHeaders(c, allowOrigin, "", "", exposeHeaders, maxAge, cfg)
return c.Next()
}
// Pre-flight request
// Response to OPTIONS request should not be cached but,
// some caching can be configured to cache such responses.
// To Avoid poisoning the cache, we include the Vary header
// of preflight responses:
c.Vary(fiber.HeaderAccessControlRequestMethod)
c.Vary(fiber.HeaderAccessControlRequestHeaders)
c.Vary(fiber.HeaderOrigin)
setCORSHeaders(c, allowOrigin, allowMethods, allowHeaders, exposeHeaders, maxAge, cfg)
// Send 204 No Content
return c.SendStatus(fiber.StatusNoContent)
}
}
// Function to set CORS headers
func setCORSHeaders(c *fiber.Ctx, allowOrigin, allowMethods, allowHeaders, exposeHeaders, maxAge string, cfg Config) {
if cfg.AllowCredentials {
// When AllowCredentials is true, set the Access-Control-Allow-Origin to the specific origin instead of '*'
if allowOrigin == "*" {
c.Set(fiber.HeaderAccessControlAllowOrigin, allowOrigin)
log.Warn("[CORS] 'AllowCredentials' is true, but 'AllowOrigins' cannot be set to '*'.")
} else if allowOrigin != "" {
c.Set(fiber.HeaderAccessControlAllowOrigin, allowOrigin)
c.Set(fiber.HeaderAccessControlAllowCredentials, "true")
}
} else if allowOrigin != "" {
// For non-credential requests, it's safe to set to '*' or specific origins
c.Set(fiber.HeaderAccessControlAllowOrigin, allowOrigin)
}
// Set Allow-Methods if not empty
if allowMethods != "" {
c.Set(fiber.HeaderAccessControlAllowMethods, allowMethods)
}
// Set Allow-Headers if not empty
if allowHeaders != "" {
c.Set(fiber.HeaderAccessControlAllowHeaders, allowHeaders)
} else {
h := c.Get(fiber.HeaderAccessControlRequestHeaders)
if h != "" {
c.Set(fiber.HeaderAccessControlAllowHeaders, h)
}
}
// Set MaxAge if set
if cfg.MaxAge > 0 {
c.Set(fiber.HeaderAccessControlMaxAge, maxAge)
} else if cfg.MaxAge < 0 {
c.Set(fiber.HeaderAccessControlMaxAge, "0")
}
// Set Expose-Headers if not empty
if exposeHeaders != "" {
c.Set(fiber.HeaderAccessControlExposeHeaders, exposeHeaders)
}
}

1335
middleware/cors/cors_test.go Normal file

File diff suppressed because it is too large Load diff

66
middleware/cors/utils.go Normal file
View file

@ -0,0 +1,66 @@
package cors
import (
"net/url"
"strings"
)
// matchScheme compares the scheme of the domain and pattern
func matchScheme(domain, pattern string) bool {
didx := strings.Index(domain, ":")
pidx := strings.Index(pattern, ":")
return didx != -1 && pidx != -1 && domain[:didx] == pattern[:pidx]
}
// normalizeDomain removes the scheme and port from the input domain
func normalizeDomain(input string) string {
// Remove scheme
input = strings.TrimPrefix(strings.TrimPrefix(input, "http://"), "https://")
// Find and remove port, if present
if len(input) > 0 && input[0] != '[' {
if portIndex := strings.Index(input, ":"); portIndex != -1 {
input = input[:portIndex]
}
}
return input
}
// normalizeOrigin checks if the provided origin is in a correct format
// and normalizes it by removing any path or trailing slash.
// It returns a boolean indicating whether the origin is valid
// and the normalized origin.
func normalizeOrigin(origin string) (bool, string) {
parsedOrigin, err := url.Parse(origin)
if err != nil {
return false, ""
}
// Don't allow a wildcard with a protocol
// wildcards cannot be used within any other value. For example, the following header is not valid:
// Access-Control-Allow-Origin: https://*
if strings.Contains(parsedOrigin.Host, "*") {
return false, ""
}
// Validate there is a host present. The presence of a path, query, or fragment components
// is checked, but a trailing "/" (indicative of the root) is allowed for the path and will be normalized
if parsedOrigin.Host == "" || (parsedOrigin.Path != "" && parsedOrigin.Path != "/") || parsedOrigin.RawQuery != "" || parsedOrigin.Fragment != "" {
return false, ""
}
// Normalize the origin by constructing it from the scheme and host.
// The path or trailing slash is not included in the normalized origin.
return true, strings.ToLower(parsedOrigin.Scheme) + "://" + strings.ToLower(parsedOrigin.Host)
}
type subdomain struct {
// The wildcard pattern
prefix string
suffix string
}
func (s subdomain) match(o string) bool {
return len(o) >= len(s.prefix)+len(s.suffix) && strings.HasPrefix(o, s.prefix) && strings.HasSuffix(o, s.suffix)
}

View file

@ -0,0 +1,196 @@
package cors
import (
"testing"
"github.com/gofiber/fiber/v2/utils"
)
// go test -run -v Test_normalizeOrigin
func Test_normalizeOrigin(t *testing.T) {
testCases := []struct {
origin string
expectedValid bool
expectedOrigin string
}{
{origin: "http://example.com", expectedValid: true, expectedOrigin: "http://example.com"}, // Simple case should work.
{origin: "http://example.com/", expectedValid: true, expectedOrigin: "http://example.com"}, // Trailing slash should be removed.
{origin: "http://example.com:3000", expectedValid: true, expectedOrigin: "http://example.com:3000"}, // Port should be preserved.
{origin: "http://example.com:3000/", expectedValid: true, expectedOrigin: "http://example.com:3000"}, // Trailing slash should be removed.
{origin: "app://example.com/", expectedValid: true, expectedOrigin: "app://example.com"}, // App scheme should be accepted.
{origin: "http://", expectedValid: false, expectedOrigin: ""}, // Invalid origin should not be accepted.
{origin: "file:///etc/passwd", expectedValid: false, expectedOrigin: ""}, // File scheme should not be accepted.
{origin: "https://*example.com", expectedValid: false, expectedOrigin: ""}, // Wildcard domain should not be accepted.
{origin: "http://*.example.com", expectedValid: false, expectedOrigin: ""}, // Wildcard subdomain should not be accepted.
{origin: "http://example.com/path", expectedValid: false, expectedOrigin: ""}, // Path should not be accepted.
{origin: "http://example.com?query=123", expectedValid: false, expectedOrigin: ""}, // Query should not be accepted.
{origin: "http://example.com#fragment", expectedValid: false, expectedOrigin: ""}, // Fragment should not be accepted.
{origin: "http://localhost", expectedValid: true, expectedOrigin: "http://localhost"}, // Localhost should be accepted.
{origin: "http://127.0.0.1", expectedValid: true, expectedOrigin: "http://127.0.0.1"}, // IPv4 address should be accepted.
{origin: "http://[::1]", expectedValid: true, expectedOrigin: "http://[::1]"}, // IPv6 address should be accepted.
{origin: "http://[::1]:8080", expectedValid: true, expectedOrigin: "http://[::1]:8080"}, // IPv6 address with port should be accepted.
{origin: "http://[::1]:8080/", expectedValid: true, expectedOrigin: "http://[::1]:8080"}, // IPv6 address with port and trailing slash should be accepted.
{origin: "http://[::1]:8080/path", expectedValid: false, expectedOrigin: ""}, // IPv6 address with port and path should not be accepted.
{origin: "http://[::1]:8080?query=123", expectedValid: false, expectedOrigin: ""}, // IPv6 address with port and query should not be accepted.
{origin: "http://[::1]:8080#fragment", expectedValid: false, expectedOrigin: ""}, // IPv6 address with port and fragment should not be accepted.
{origin: "http://[::1]:8080/path?query=123#fragment", expectedValid: false, expectedOrigin: ""}, // IPv6 address with port, path, query, and fragment should not be accepted.
{origin: "http://[::1]:8080/path?query=123#fragment/", expectedValid: false, expectedOrigin: ""}, // IPv6 address with port, path, query, fragment, and trailing slash should not be accepted.
{origin: "http://[::1]:8080/path?query=123#fragment/invalid", expectedValid: false, expectedOrigin: ""}, // IPv6 address with port, path, query, fragment, trailing slash, and invalid segment should not be accepted.
{origin: "http://[::1]:8080/path?query=123#fragment/invalid/", expectedValid: false, expectedOrigin: ""}, // IPv6 address with port, path, query, fragment, trailing slash, and invalid segment with trailing slash should not be accepted.
{origin: "http://[::1]:8080/path?query=123#fragment/invalid/segment", expectedValid: false, expectedOrigin: ""}, // IPv6 address with port, path, query, fragment, trailing slash, and invalid segment with additional segment should not be accepted.
}
for _, tc := range testCases {
valid, normalizedOrigin := normalizeOrigin(tc.origin)
if valid != tc.expectedValid {
t.Errorf("Expected origin '%s' to be valid: %v, but got: %v", tc.origin, tc.expectedValid, valid)
}
if normalizedOrigin != tc.expectedOrigin {
t.Errorf("Expected normalized origin '%s' for origin '%s', but got: '%s'", tc.expectedOrigin, tc.origin, normalizedOrigin)
}
}
}
// go test -run -v Test_matchScheme
func Test_matchScheme(t *testing.T) {
testCases := []struct {
domain string
pattern string
expected bool
}{
{"http://example.com", "http://example.com", true}, // Exact match should work.
{"https://example.com", "http://example.com", false}, // Scheme mismatch should matter.
{"http://example.com", "https://example.com", false}, // Scheme mismatch should matter.
{"http://example.com", "http://example.org", true}, // Different domains should not matter.
{"http://example.com", "http://example.com:8080", true}, // Port should not matter.
{"http://example.com:8080", "http://example.com", true}, // Port should not matter.
{"http://example.com:8080", "http://example.com:8081", true}, // Different ports should not matter.
{"http://localhost", "http://localhost", true}, // Localhost should match.
{"http://127.0.0.1", "http://127.0.0.1", true}, // IPv4 address should match.
{"http://[::1]", "http://[::1]", true}, // IPv6 address should match.
}
for _, tc := range testCases {
result := matchScheme(tc.domain, tc.pattern)
if result != tc.expected {
t.Errorf("Expected matchScheme('%s', '%s') to be %v, but got %v", tc.domain, tc.pattern, tc.expected, result)
}
}
}
// go test -run -v Test_normalizeDomain
func Test_normalizeDomain(t *testing.T) {
testCases := []struct {
input string
expectedOutput string
}{
{"http://example.com", "example.com"}, // Simple case with http scheme.
{"https://example.com", "example.com"}, // Simple case with https scheme.
{"http://example.com:3000", "example.com"}, // Case with port.
{"https://example.com:3000", "example.com"}, // Case with port and https scheme.
{"http://example.com/path", "example.com/path"}, // Case with path.
{"http://example.com?query=123", "example.com?query=123"}, // Case with query.
{"http://example.com#fragment", "example.com#fragment"}, // Case with fragment.
{"example.com", "example.com"}, // Case without scheme.
{"example.com:8080", "example.com"}, // Case without scheme but with port.
{"sub.example.com", "sub.example.com"}, // Case with subdomain.
{"sub.sub.example.com", "sub.sub.example.com"}, // Case with nested subdomain.
{"http://localhost", "localhost"}, // Case with localhost.
{"http://127.0.0.1", "127.0.0.1"}, // Case with IPv4 address.
{"http://[::1]", "[::1]"}, // Case with IPv6 address.
}
for _, tc := range testCases {
output := normalizeDomain(tc.input)
if output != tc.expectedOutput {
t.Errorf("Expected normalized domain '%s' for input '%s', but got: '%s'", tc.expectedOutput, tc.input, output)
}
}
}
// go test -v -run=^$ -bench=Benchmark_CORS_SubdomainMatch -benchmem -count=4
func Benchmark_CORS_SubdomainMatch(b *testing.B) {
s := subdomain{
prefix: "www",
suffix: ".example.com",
}
o := "www.example.com"
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
s.match(o)
}
}
func Test_CORS_SubdomainMatch(t *testing.T) {
tests := []struct {
name string
sub subdomain
origin string
expected bool
}{
{
name: "match with different scheme",
sub: subdomain{prefix: "http://api.", suffix: ".example.com"},
origin: "https://api.service.example.com",
expected: false,
},
{
name: "match with different scheme",
sub: subdomain{prefix: "https://", suffix: ".example.com"},
origin: "http://api.service.example.com",
expected: false,
},
{
name: "match with valid subdomain",
sub: subdomain{prefix: "https://", suffix: ".example.com"},
origin: "https://api.service.example.com",
expected: true,
},
{
name: "match with valid nested subdomain",
sub: subdomain{prefix: "https://", suffix: ".example.com"},
origin: "https://1.2.api.service.example.com",
expected: true,
},
{
name: "no match with invalid prefix",
sub: subdomain{prefix: "https://abc.", suffix: ".example.com"},
origin: "https://service.example.com",
expected: false,
},
{
name: "no match with invalid suffix",
sub: subdomain{prefix: "https://", suffix: ".example.com"},
origin: "https://api.example.org",
expected: false,
},
{
name: "no match with empty origin",
sub: subdomain{prefix: "https://", suffix: ".example.com"},
origin: "",
expected: false,
},
{
name: "partial match not considered a match",
sub: subdomain{prefix: "https://service.", suffix: ".example.com"},
origin: "https://api.example.com",
expected: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := tt.sub.match(tt.origin)
utils.AssertEqual(t, tt.expected, got, "subdomain.match()")
})
}
}

243
middleware/csrf/config.go Normal file
View file

@ -0,0 +1,243 @@
package csrf
import (
"net/textproto"
"strings"
"time"
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/log"
"github.com/gofiber/fiber/v2/middleware/session"
"github.com/gofiber/fiber/v2/utils"
)
// Config defines the config for middleware.
type Config struct {
// Next defines a function to skip this middleware when returned true.
//
// Optional. Default: nil
Next func(c *fiber.Ctx) bool
// KeyLookup is a string in the form of "<source>:<key>" that is used
// to create an Extractor that extracts the token from the request.
// Possible values:
// - "header:<name>"
// - "query:<name>"
// - "param:<name>"
// - "form:<name>"
// - "cookie:<name>"
//
// Ignored if an Extractor is explicitly set.
//
// Optional. Default: "header:X-Csrf-Token"
KeyLookup string
// Name of the session cookie. This cookie will store session key.
// Optional. Default value "csrf_".
// Overridden if KeyLookup == "cookie:<name>"
CookieName string
// Domain of the CSRF cookie.
// Optional. Default value "".
CookieDomain string
// Path of the CSRF cookie.
// Optional. Default value "".
CookiePath string
// Indicates if CSRF cookie is secure.
// Optional. Default value false.
CookieSecure bool
// Indicates if CSRF cookie is HTTP only.
// Optional. Default value false.
CookieHTTPOnly bool
// Value of SameSite cookie.
// Optional. Default value "Lax".
CookieSameSite string
// Decides whether cookie should last for only the browser sesison.
// Ignores Expiration if set to true
CookieSessionOnly bool
// Expiration is the duration before csrf token will expire
//
// Optional. Default: 1 * time.Hour
Expiration time.Duration
// SingleUseToken indicates if the CSRF token be destroyed
// and a new one generated on each use.
//
// Optional. Default: false
SingleUseToken bool
// Store is used to store the state of the middleware
//
// Optional. Default: memory.New()
// Ignored if Session is set.
Storage fiber.Storage
// Session is used to store the state of the middleware
//
// Optional. Default: nil
// If set, the middleware will use the session store instead of the storage
Session *session.Store
// SessionKey is the key used to store the token in the session
//
// Default: "fiber.csrf.token"
SessionKey string
// Context key to store generated CSRF token into context.
// If left empty, token will not be stored in context.
//
// Optional. Default: ""
ContextKey interface{}
// KeyGenerator creates a new CSRF token
//
// Optional. Default: utils.UUID
KeyGenerator func() string
// Deprecated: Please use Expiration
CookieExpires time.Duration
// Deprecated: Please use Cookie* related fields
Cookie *fiber.Cookie
// Deprecated: Please use KeyLookup
TokenLookup string
// ErrorHandler is executed when an error is returned from fiber.Handler.
//
// Optional. Default: DefaultErrorHandler
ErrorHandler fiber.ErrorHandler
// Extractor returns the csrf token
//
// If set this will be used in place of an Extractor based on KeyLookup.
//
// Optional. Default will create an Extractor based on KeyLookup.
Extractor func(c *fiber.Ctx) (string, error)
// HandlerContextKey is used to store the CSRF Handler into context
//
// Default: "fiber.csrf.handler"
HandlerContextKey interface{}
}
const HeaderName = "X-Csrf-Token"
// ConfigDefault is the default config
var ConfigDefault = Config{
KeyLookup: "header:" + HeaderName,
CookieName: "csrf_",
CookieSameSite: "Lax",
Expiration: 1 * time.Hour,
KeyGenerator: utils.UUIDv4,
ErrorHandler: defaultErrorHandler,
Extractor: CsrfFromHeader(HeaderName),
SessionKey: "fiber.csrf.token",
HandlerContextKey: "fiber.csrf.handler",
}
// default ErrorHandler that process return error from fiber.Handler
func defaultErrorHandler(_ *fiber.Ctx, _ error) error {
return fiber.ErrForbidden
}
// Helper function to set default values
func configDefault(config ...Config) Config {
// Return default config if nothing provided
if len(config) < 1 {
return ConfigDefault
}
// Override default config
cfg := config[0]
// Set default values
if cfg.TokenLookup != "" {
log.Warn("[CSRF] TokenLookup is deprecated, please use KeyLookup")
cfg.KeyLookup = cfg.TokenLookup
}
if int(cfg.CookieExpires.Seconds()) > 0 {
log.Warn("[CSRF] CookieExpires is deprecated, please use Expiration")
cfg.Expiration = cfg.CookieExpires
}
if cfg.Cookie != nil {
log.Warn("[CSRF] Cookie is deprecated, please use Cookie* related fields")
if cfg.Cookie.Name != "" {
cfg.CookieName = cfg.Cookie.Name
}
if cfg.Cookie.Domain != "" {
cfg.CookieDomain = cfg.Cookie.Domain
}
if cfg.Cookie.Path != "" {
cfg.CookiePath = cfg.Cookie.Path
}
cfg.CookieSecure = cfg.Cookie.Secure
cfg.CookieHTTPOnly = cfg.Cookie.HTTPOnly
if cfg.Cookie.SameSite != "" {
cfg.CookieSameSite = cfg.Cookie.SameSite
}
}
if cfg.KeyLookup == "" {
cfg.KeyLookup = ConfigDefault.KeyLookup
}
if int(cfg.Expiration.Seconds()) <= 0 {
cfg.Expiration = ConfigDefault.Expiration
}
if cfg.CookieName == "" {
cfg.CookieName = ConfigDefault.CookieName
}
if cfg.CookieSameSite == "" {
cfg.CookieSameSite = ConfigDefault.CookieSameSite
}
if cfg.KeyGenerator == nil {
cfg.KeyGenerator = ConfigDefault.KeyGenerator
}
if cfg.ErrorHandler == nil {
cfg.ErrorHandler = ConfigDefault.ErrorHandler
}
if cfg.SessionKey == "" {
cfg.SessionKey = ConfigDefault.SessionKey
}
if cfg.HandlerContextKey == nil {
cfg.HandlerContextKey = ConfigDefault.HandlerContextKey
}
// Generate the correct extractor to get the token from the correct location
selectors := strings.Split(cfg.KeyLookup, ":")
const numParts = 2
if len(selectors) != numParts {
panic("[CSRF] KeyLookup must in the form of <source>:<key>")
}
if cfg.Extractor == nil {
// By default we extract from a header
cfg.Extractor = CsrfFromHeader(textproto.CanonicalMIMEHeaderKey(selectors[1]))
switch selectors[0] {
case "form":
cfg.Extractor = CsrfFromForm(selectors[1])
case "query":
cfg.Extractor = CsrfFromQuery(selectors[1])
case "param":
cfg.Extractor = CsrfFromParam(selectors[1])
case "cookie":
if cfg.Session == nil {
log.Warn("[CSRF] Cookie extractor is not recommended without a session store")
}
if cfg.CookieSameSite == "None" || cfg.CookieSameSite != "Lax" && cfg.CookieSameSite != "Strict" {
log.Warn("[CSRF] Cookie extractor is only recommended for use with SameSite=Lax or SameSite=Strict")
}
cfg.Extractor = CsrfFromCookie(selectors[1])
cfg.CookieName = selectors[1] // Cookie name is the same as the key
}
}
return cfg
}

239
middleware/csrf/csrf.go Normal file
View file

@ -0,0 +1,239 @@
package csrf
import (
"errors"
"net/url"
"reflect"
"strings"
"time"
"github.com/gofiber/fiber/v2"
)
var (
ErrTokenNotFound = errors.New("csrf token not found")
ErrTokenInvalid = errors.New("csrf token invalid")
ErrNoReferer = errors.New("referer not supplied")
ErrBadReferer = errors.New("referer invalid")
dummyValue = []byte{'+'}
)
type CSRFHandler struct {
config *Config
sessionManager *sessionManager
storageManager *storageManager
}
// New creates a new middleware handler
func New(config ...Config) fiber.Handler {
// Set default config
cfg := configDefault(config...)
// Create manager to simplify storage operations ( see *_manager.go )
var sessionManager *sessionManager
var storageManager *storageManager
if cfg.Session != nil {
// Register the Token struct in the session store
cfg.Session.RegisterType(Token{})
sessionManager = newSessionManager(cfg.Session, cfg.SessionKey)
} else {
storageManager = newStorageManager(cfg.Storage)
}
// Return new handler
return func(c *fiber.Ctx) error {
// Don't execute middleware if Next returns true
if cfg.Next != nil && cfg.Next(c) {
return c.Next()
}
// Store the CSRF handler in the context if a context key is specified
if cfg.HandlerContextKey != "" {
c.Locals(cfg.HandlerContextKey, &CSRFHandler{
config: &cfg,
sessionManager: sessionManager,
storageManager: storageManager,
})
}
var token string
// Action depends on the HTTP method
switch c.Method() {
case fiber.MethodGet, fiber.MethodHead, fiber.MethodOptions, fiber.MethodTrace:
cookieToken := c.Cookies(cfg.CookieName)
if cookieToken != "" {
raw := getRawFromStorage(c, cookieToken, cfg, sessionManager, storageManager)
if raw != nil {
token = cookieToken // Token is valid, safe to set it
}
}
default:
// Assume that anything not defined as 'safe' by RFC7231 needs protection
// Enforce an origin check for HTTPS connections.
if c.Protocol() == "https" {
if err := refererMatchesHost(c); err != nil {
return cfg.ErrorHandler(c, err)
}
}
// Extract token from client request i.e. header, query, param, form or cookie
extractedToken, err := cfg.Extractor(c)
if err != nil {
return cfg.ErrorHandler(c, err)
}
if extractedToken == "" {
return cfg.ErrorHandler(c, ErrTokenNotFound)
}
// If not using CsrfFromCookie extractor, check that the token matches the cookie
// This is to prevent CSRF attacks by using a Double Submit Cookie method
// Useful when we do not have access to the users Session
if !isCsrfFromCookie(cfg.Extractor) && !compareStrings(extractedToken, c.Cookies(cfg.CookieName)) {
return cfg.ErrorHandler(c, ErrTokenInvalid)
}
raw := getRawFromStorage(c, extractedToken, cfg, sessionManager, storageManager)
if raw == nil {
// If token is not in storage, expire the cookie
expireCSRFCookie(c, cfg)
// and return an error
return cfg.ErrorHandler(c, ErrTokenNotFound)
}
if cfg.SingleUseToken {
// If token is single use, delete it from storage
deleteTokenFromStorage(c, extractedToken, cfg, sessionManager, storageManager)
} else {
token = extractedToken // Token is valid, safe to set it
}
}
// Generate CSRF token if not exist
if token == "" {
// And generate a new token
token = cfg.KeyGenerator()
}
// Create or extend the token in the storage
createOrExtendTokenInStorage(c, token, cfg, sessionManager, storageManager)
// Update the CSRF cookie
updateCSRFCookie(c, cfg, token)
// Tell the browser that a new header value is generated
c.Vary(fiber.HeaderCookie)
// Store the token in the context if a context key is specified
if cfg.ContextKey != nil {
c.Locals(cfg.ContextKey, token)
}
// Continue stack
return c.Next()
}
}
// getRawFromStorage returns the raw value from the storage for the given token
// returns nil if the token does not exist, is expired or is invalid
func getRawFromStorage(c *fiber.Ctx, token string, cfg Config, sessionManager *sessionManager, storageManager *storageManager) []byte {
if cfg.Session != nil {
return sessionManager.getRaw(c, token, dummyValue)
}
return storageManager.getRaw(token)
}
// createOrExtendTokenInStorage creates or extends the token in the storage
func createOrExtendTokenInStorage(c *fiber.Ctx, token string, cfg Config, sessionManager *sessionManager, storageManager *storageManager) {
if cfg.Session != nil {
sessionManager.setRaw(c, token, dummyValue, cfg.Expiration)
} else {
storageManager.setRaw(token, dummyValue, cfg.Expiration)
}
}
func deleteTokenFromStorage(c *fiber.Ctx, token string, cfg Config, sessionManager *sessionManager, storageManager *storageManager) {
if cfg.Session != nil {
sessionManager.delRaw(c)
} else {
storageManager.delRaw(token)
}
}
// Update CSRF cookie
// if expireCookie is true, the cookie will expire immediately
func updateCSRFCookie(c *fiber.Ctx, cfg Config, token string) {
setCSRFCookie(c, cfg, token, cfg.Expiration)
}
func expireCSRFCookie(c *fiber.Ctx, cfg Config) {
setCSRFCookie(c, cfg, "", -time.Hour)
}
func setCSRFCookie(c *fiber.Ctx, cfg Config, token string, expiry time.Duration) {
cookie := &fiber.Cookie{
Name: cfg.CookieName,
Value: token,
Domain: cfg.CookieDomain,
Path: cfg.CookiePath,
Secure: cfg.CookieSecure,
HTTPOnly: cfg.CookieHTTPOnly,
SameSite: cfg.CookieSameSite,
SessionOnly: cfg.CookieSessionOnly,
Expires: time.Now().Add(expiry),
}
// Set the CSRF cookie to the response
c.Cookie(cookie)
}
// DeleteToken removes the token found in the context from the storage
// and expires the CSRF cookie
func (handler *CSRFHandler) DeleteToken(c *fiber.Ctx) error {
// Get the config from the context
config := handler.config
if config == nil {
panic("CSRFHandler config not found in context")
}
// Extract token from the client request cookie
cookieToken := c.Cookies(config.CookieName)
if cookieToken == "" {
return config.ErrorHandler(c, ErrTokenNotFound)
}
// Remove the token from storage
deleteTokenFromStorage(c, cookieToken, *config, handler.sessionManager, handler.storageManager)
// Expire the cookie
expireCSRFCookie(c, *config)
return nil
}
// isCsrfFromCookie checks if the extractor is set to ExtractFromCookie
func isCsrfFromCookie(extractor interface{}) bool {
return reflect.ValueOf(extractor).Pointer() == reflect.ValueOf(CsrfFromCookie).Pointer()
}
// refererMatchesHost checks that the referer header matches the host header
// returns an error if the referer header is not present or is invalid
// returns nil if the referer header is valid
func refererMatchesHost(c *fiber.Ctx) error {
referer := strings.ToLower(c.Get(fiber.HeaderReferer))
if referer == "" {
return ErrNoReferer
}
refererURL, err := url.Parse(referer)
if err != nil {
return ErrBadReferer
}
if refererURL.Scheme == c.Protocol() && refererURL.Host == c.Hostname() {
return nil
}
return ErrBadReferer
}

1060
middleware/csrf/csrf_test.go Normal file

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,70 @@
package csrf
import (
"errors"
"github.com/gofiber/fiber/v2"
)
var (
ErrMissingHeader = errors.New("missing csrf token in header")
ErrMissingQuery = errors.New("missing csrf token in query")
ErrMissingParam = errors.New("missing csrf token in param")
ErrMissingForm = errors.New("missing csrf token in form")
ErrMissingCookie = errors.New("missing csrf token in cookie")
)
// csrfFromParam returns a function that extracts token from the url param string.
func CsrfFromParam(param string) func(c *fiber.Ctx) (string, error) {
return func(c *fiber.Ctx) (string, error) {
token := c.Params(param)
if token == "" {
return "", ErrMissingParam
}
return token, nil
}
}
// csrfFromForm returns a function that extracts a token from a multipart-form.
func CsrfFromForm(param string) func(c *fiber.Ctx) (string, error) {
return func(c *fiber.Ctx) (string, error) {
token := c.FormValue(param)
if token == "" {
return "", ErrMissingForm
}
return token, nil
}
}
// csrfFromCookie returns a function that extracts token from the cookie header.
func CsrfFromCookie(param string) func(c *fiber.Ctx) (string, error) {
return func(c *fiber.Ctx) (string, error) {
token := c.Cookies(param)
if token == "" {
return "", ErrMissingCookie
}
return token, nil
}
}
// csrfFromHeader returns a function that extracts token from the request header.
func CsrfFromHeader(param string) func(c *fiber.Ctx) (string, error) {
return func(c *fiber.Ctx) (string, error) {
token := c.Get(param)
if token == "" {
return "", ErrMissingHeader
}
return token, nil
}
}
// csrfFromQuery returns a function that extracts token from the query string.
func CsrfFromQuery(param string) func(c *fiber.Ctx) (string, error) {
return func(c *fiber.Ctx) (string, error) {
token := c.Query(param)
if token == "" {
return "", ErrMissingQuery
}
return token, nil
}
}

View file

@ -0,0 +1,13 @@
package csrf
import (
"crypto/subtle"
)
func compareTokens(a, b []byte) bool {
return subtle.ConstantTimeCompare(a, b) == 1
}
func compareStrings(a, b string) bool {
return subtle.ConstantTimeCompare([]byte(a), []byte(b)) == 1
}

View file

@ -0,0 +1,68 @@
package csrf
import (
"time"
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/log"
"github.com/gofiber/fiber/v2/middleware/session"
)
type sessionManager struct {
key string
session *session.Store
}
func newSessionManager(s *session.Store, k string) *sessionManager {
// Create new storage handler
sessionManager := &sessionManager{
key: k,
}
if s != nil {
// Use provided storage if provided
sessionManager.session = s
}
return sessionManager
}
// get token from session
func (m *sessionManager) getRaw(c *fiber.Ctx, key string, raw []byte) []byte {
sess, err := m.session.Get(c)
if err != nil {
return nil
}
token, ok := sess.Get(m.key).(Token)
if ok {
if token.Expiration.Before(time.Now()) || key != token.Key || !compareTokens(raw, token.Raw) {
return nil
}
return token.Raw
}
return nil
}
// set token in session
func (m *sessionManager) setRaw(c *fiber.Ctx, key string, raw []byte, exp time.Duration) {
sess, err := m.session.Get(c)
if err != nil {
return
}
// the key is crucial in crsf and sometimes a reference to another value which can be reused later(pool/unsafe values concept), so a copy is made here
sess.Set(m.key, &Token{key, raw, time.Now().Add(exp)})
if err := sess.Save(); err != nil {
log.Warn("csrf: failed to save session: ", err)
}
}
// delete token from session
func (m *sessionManager) delRaw(c *fiber.Ctx) {
sess, err := m.session.Get(c)
if err != nil {
return
}
sess.Delete(m.key)
if err := sess.Save(); err != nil {
log.Warn("csrf: failed to save session: ", err)
}
}

View file

@ -0,0 +1,70 @@
package csrf
import (
"sync"
"time"
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/internal/memory"
"github.com/gofiber/fiber/v2/utils"
)
// go:generate msgp
// msgp -file="storage_manager.go" -o="storage_manager_msgp.go" -tests=false -unexported
type item struct{}
//msgp:ignore manager
type storageManager struct {
pool sync.Pool
memory *memory.Storage
storage fiber.Storage
}
func newStorageManager(storage fiber.Storage) *storageManager {
// Create new storage handler
storageManager := &storageManager{
pool: sync.Pool{
New: func() interface{} {
return new(item)
},
},
}
if storage != nil {
// Use provided storage if provided
storageManager.storage = storage
} else {
// Fallback too memory storage
storageManager.memory = memory.New()
}
return storageManager
}
// get raw data from storage or memory
func (m *storageManager) getRaw(key string) []byte {
var raw []byte
if m.storage != nil {
raw, _ = m.storage.Get(key) //nolint:errcheck // TODO: Do not ignore error
} else {
raw, _ = m.memory.Get(key).([]byte) //nolint:errcheck // TODO: Do not ignore error
}
return raw
}
// set data to storage or memory
func (m *storageManager) setRaw(key string, raw []byte, exp time.Duration) {
if m.storage != nil {
_ = m.storage.Set(key, raw, exp) //nolint:errcheck // TODO: Do not ignore error
} else {
// the key is crucial in crsf and sometimes a reference to another value which can be reused later(pool/unsafe values concept), so a copy is made here
m.memory.Set(utils.CopyString(key), raw, exp)
}
}
// delete data from storage or memory
func (m *storageManager) delRaw(key string) {
if m.storage != nil {
_ = m.storage.Delete(key) //nolint:errcheck // TODO: Do not ignore error
} else {
m.memory.Delete(key)
}
}

View file

@ -0,0 +1,90 @@
package csrf
// Code generated by github.com/tinylib/msgp DO NOT EDIT.
import (
"github.com/tinylib/msgp/msgp"
)
// DecodeMsg implements msgp.Decodable
func (z *item) DecodeMsg(dc *msgp.Reader) (err error) {
var field []byte
_ = field
var zb0001 uint32
zb0001, err = dc.ReadMapHeader()
if err != nil {
err = msgp.WrapError(err)
return
}
for zb0001 > 0 {
zb0001--
field, err = dc.ReadMapKeyPtr()
if err != nil {
err = msgp.WrapError(err)
return
}
switch msgp.UnsafeString(field) {
default:
err = dc.Skip()
if err != nil {
err = msgp.WrapError(err)
return
}
}
}
return
}
// EncodeMsg implements msgp.Encodable
func (z item) EncodeMsg(en *msgp.Writer) (err error) {
// map header, size 0
err = en.Append(0x80)
if err != nil {
return
}
return
}
// MarshalMsg implements msgp.Marshaler
func (z item) MarshalMsg(b []byte) (o []byte, err error) {
o = msgp.Require(b, z.Msgsize())
// map header, size 0
o = append(o, 0x80)
return
}
// UnmarshalMsg implements msgp.Unmarshaler
func (z *item) UnmarshalMsg(bts []byte) (o []byte, err error) {
var field []byte
_ = field
var zb0001 uint32
zb0001, bts, err = msgp.ReadMapHeaderBytes(bts)
if err != nil {
err = msgp.WrapError(err)
return
}
for zb0001 > 0 {
zb0001--
field, bts, err = msgp.ReadMapKeyZC(bts)
if err != nil {
err = msgp.WrapError(err)
return
}
switch msgp.UnsafeString(field) {
default:
bts, err = msgp.Skip(bts)
if err != nil {
err = msgp.WrapError(err)
return
}
}
}
o = bts
return
}
// Msgsize returns an upper bound estimate of the number of bytes occupied by the serialized message
func (z item) Msgsize() (s int) {
s = 1
return
}

11
middleware/csrf/token.go Normal file
View file

@ -0,0 +1,11 @@
package csrf
import (
"time"
)
type Token struct {
Key string `json:"key"`
Raw []byte `json:"raw"`
Expiration time.Time `json:"expiration"`
}

View file

@ -0,0 +1,73 @@
package earlydata
import (
"github.com/gofiber/fiber/v2"
)
const (
DefaultHeaderName = "Early-Data"
DefaultHeaderTrueValue = "1"
)
// Config defines the config for middleware.
type Config struct {
// Next defines a function to skip this middleware when returned true.
//
// Optional. Default: nil
Next func(c *fiber.Ctx) bool
// IsEarlyData returns whether the request is an early-data request.
//
// Optional. Default: a function which checks if the "Early-Data" request header equals "1".
IsEarlyData func(c *fiber.Ctx) bool
// AllowEarlyData returns whether the early-data request should be allowed or rejected.
//
// Optional. Default: a function which rejects the request on unsafe and allows the request on safe HTTP request methods.
AllowEarlyData func(c *fiber.Ctx) bool
// Error is returned in case an early-data request is rejected.
//
// Optional. Default: fiber.ErrTooEarly.
Error error
}
// ConfigDefault is the default config
var ConfigDefault = Config{
IsEarlyData: func(c *fiber.Ctx) bool {
return c.Get(DefaultHeaderName) == DefaultHeaderTrueValue
},
AllowEarlyData: func(c *fiber.Ctx) bool {
return fiber.IsMethodSafe(c.Method())
},
Error: fiber.ErrTooEarly,
}
// Helper function to set default values
func configDefault(config ...Config) Config {
// Return default config if nothing provided
if len(config) < 1 {
return ConfigDefault
}
// Override default config
cfg := config[0]
// Set default values
if cfg.IsEarlyData == nil {
cfg.IsEarlyData = ConfigDefault.IsEarlyData
}
if cfg.AllowEarlyData == nil {
cfg.AllowEarlyData = ConfigDefault.AllowEarlyData
}
if cfg.Error == nil {
cfg.Error = ConfigDefault.Error
}
return cfg
}

View file

@ -0,0 +1,47 @@
package earlydata
import (
"github.com/gofiber/fiber/v2"
)
const (
localsKeyAllowed = "earlydata_allowed"
)
func IsEarly(c *fiber.Ctx) bool {
return c.Locals(localsKeyAllowed) != nil
}
// New creates a new middleware handler
// https://datatracker.ietf.org/doc/html/rfc8470#section-5.1
func New(config ...Config) fiber.Handler {
// Set default config
cfg := configDefault(config...)
// Return new handler
return func(c *fiber.Ctx) error {
// Don't execute middleware if Next returns true
if cfg.Next != nil && cfg.Next(c) {
return c.Next()
}
// Abort if we can't trust the early-data header
if !c.IsProxyTrusted() {
return cfg.Error
}
// Continue stack if request is not an early-data request
if !cfg.IsEarlyData(c) {
return c.Next()
}
// Continue stack if we allow early-data for this request
if cfg.AllowEarlyData(c) {
_ = c.Locals(localsKeyAllowed, true)
return c.Next()
}
// Else return our error
return cfg.Error
}
}

View file

@ -0,0 +1,193 @@
//nolint:bodyclose // Much easier to just ignore memory leaks in tests
package earlydata_test
import (
"errors"
"fmt"
"net/http"
"net/http/httptest"
"testing"
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/middleware/earlydata"
"github.com/gofiber/fiber/v2/utils"
)
const (
headerName = "Early-Data"
headerValOn = "1"
headerValOff = "0"
)
func appWithConfig(t *testing.T, c *fiber.Config) *fiber.App {
t.Helper()
t.Parallel()
var app *fiber.App
if c == nil {
app = fiber.New()
} else {
app = fiber.New(*c)
}
app.Use(earlydata.New())
// Middleware to test IsEarly func
const localsKeyTestValid = "earlydata_testvalid"
app.Use(func(c *fiber.Ctx) error {
isEarly := earlydata.IsEarly(c)
switch h := c.Get(headerName); h {
case "", headerValOff:
if isEarly {
return errors.New("is early-data even though it's not")
}
case headerValOn:
switch {
case fiber.IsMethodSafe(c.Method()):
if !isEarly {
return errors.New("should be early-data on safe HTTP methods")
}
default:
if isEarly {
return errors.New("early-data unsuported on unsafe HTTP methods")
}
}
default:
return fmt.Errorf("header has unsupported value: %s", h)
}
_ = c.Locals(localsKeyTestValid, true)
return c.Next()
})
{
{
handler := func(c *fiber.Ctx) error {
if !c.Locals(localsKeyTestValid).(bool) { //nolint:forcetypeassert // We store nothing else in the pool
return errors.New("handler called even though validation failed")
}
return nil
}
app.Get("/", handler)
app.Post("/", handler)
}
}
return app
}
// go test -run Test_EarlyData
func Test_EarlyData(t *testing.T) {
t.Parallel()
trustedRun := func(t *testing.T, app *fiber.App) {
t.Helper()
{
req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
resp, err := app.Test(req)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
req.Header.Set(headerName, headerValOff)
resp, err = app.Test(req)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
req.Header.Set(headerName, headerValOn)
resp, err = app.Test(req)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
}
{
req := httptest.NewRequest(fiber.MethodPost, "/", http.NoBody)
resp, err := app.Test(req)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
req.Header.Set(headerName, headerValOff)
resp, err = app.Test(req)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
req.Header.Set(headerName, headerValOn)
resp, err = app.Test(req)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusTooEarly, resp.StatusCode)
}
}
untrustedRun := func(t *testing.T, app *fiber.App) {
t.Helper()
{
req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
resp, err := app.Test(req)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusTooEarly, resp.StatusCode)
req.Header.Set(headerName, headerValOff)
resp, err = app.Test(req)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusTooEarly, resp.StatusCode)
req.Header.Set(headerName, headerValOn)
resp, err = app.Test(req)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusTooEarly, resp.StatusCode)
}
{
req := httptest.NewRequest(fiber.MethodPost, "/", http.NoBody)
resp, err := app.Test(req)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusTooEarly, resp.StatusCode)
req.Header.Set(headerName, headerValOff)
resp, err = app.Test(req)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusTooEarly, resp.StatusCode)
req.Header.Set(headerName, headerValOn)
resp, err = app.Test(req)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusTooEarly, resp.StatusCode)
}
}
t.Run("empty config", func(t *testing.T) {
app := appWithConfig(t, nil)
trustedRun(t, app)
})
t.Run("default config", func(t *testing.T) {
app := appWithConfig(t, &fiber.Config{})
trustedRun(t, app)
})
t.Run("config with EnableTrustedProxyCheck", func(t *testing.T) {
app := appWithConfig(t, &fiber.Config{
EnableTrustedProxyCheck: true,
})
untrustedRun(t, app)
})
t.Run("config with EnableTrustedProxyCheck and trusted TrustedProxies", func(t *testing.T) {
app := appWithConfig(t, &fiber.Config{
EnableTrustedProxyCheck: true,
TrustedProxies: []string{
"0.0.0.0",
},
})
trustedRun(t, app)
})
}

View file

@ -0,0 +1,78 @@
package encryptcookie
import (
"github.com/gofiber/fiber/v2"
)
// Config defines the config for middleware.
type Config struct {
// Next defines a function to skip this middleware when returned true.
//
// Optional. Default: nil
Next func(c *fiber.Ctx) bool
// Array of cookie keys that should not be encrypted.
//
// Optional. Default: []
Except []string
// Base64 encoded unique key to encode & decode cookies.
//
// Required. Key length should be 32 characters.
// You may use `encryptcookie.GenerateKey()` to generate a new key.
Key string
// Custom function to encrypt cookies.
//
// Optional. Default: EncryptCookie
Encryptor func(decryptedString, key string) (string, error)
// Custom function to decrypt cookies.
//
// Optional. Default: DecryptCookie
Decryptor func(encryptedString, key string) (string, error)
}
// ConfigDefault is the default config
var ConfigDefault = Config{
Next: nil,
Except: []string{},
Key: "",
Encryptor: EncryptCookie,
Decryptor: DecryptCookie,
}
// Helper function to set default values
func configDefault(config ...Config) Config {
// Set default config
cfg := ConfigDefault
// Override config if provided
if len(config) > 0 {
cfg = config[0]
// Set default values
if cfg.Next == nil {
cfg.Next = ConfigDefault.Next
}
if cfg.Except == nil {
cfg.Except = ConfigDefault.Except
}
if cfg.Encryptor == nil {
cfg.Encryptor = ConfigDefault.Encryptor
}
if cfg.Decryptor == nil {
cfg.Decryptor = ConfigDefault.Decryptor
}
}
if cfg.Key == "" {
panic("fiber: encrypt cookie middleware requires key")
}
return cfg
}

View file

@ -0,0 +1,57 @@
package encryptcookie
import (
"github.com/gofiber/fiber/v2"
"github.com/valyala/fasthttp"
)
// New creates a new middleware handler
func New(config ...Config) fiber.Handler {
// Set default config
cfg := configDefault(config...)
// Return new handler
return func(c *fiber.Ctx) error {
// Don't execute middleware if Next returns true
if cfg.Next != nil && cfg.Next(c) {
return c.Next()
}
// Decrypt request cookies
c.Request().Header.VisitAllCookie(func(key, value []byte) {
keyString := string(key)
if !isDisabled(keyString, cfg.Except) {
decryptedValue, err := cfg.Decryptor(string(value), cfg.Key)
if err != nil {
c.Request().Header.SetCookieBytesKV(key, nil)
} else {
c.Request().Header.SetCookie(string(key), decryptedValue)
}
}
})
// Continue stack
err := c.Next()
// Encrypt response cookies
c.Response().Header.VisitAllCookie(func(key, value []byte) {
keyString := string(key)
if !isDisabled(keyString, cfg.Except) {
cookieValue := fasthttp.Cookie{}
cookieValue.SetKeyBytes(key)
if c.Response().Header.Cookie(&cookieValue) {
encryptedValue, err := cfg.Encryptor(string(cookieValue.Value()), cfg.Key)
if err != nil {
panic(err)
}
cookieValue.SetValue(encryptedValue)
c.Response().Header.SetCookie(&cookieValue)
}
}
})
return err
}
}

View file

@ -0,0 +1,192 @@
package encryptcookie
import (
"encoding/base64"
"net/http/httptest"
"testing"
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/utils"
"github.com/valyala/fasthttp"
)
var testKey = GenerateKey()
func Test_Middleware_Encrypt_Cookie(t *testing.T) {
t.Parallel()
app := fiber.New()
app.Use(New(Config{
Key: testKey,
}))
app.Get("/", func(c *fiber.Ctx) error {
return c.SendString("value=" + c.Cookies("test"))
})
app.Post("/", func(c *fiber.Ctx) error {
c.Cookie(&fiber.Cookie{
Name: "test",
Value: "SomeThing",
})
return nil
})
h := app.Handler()
// Test empty cookie
ctx := &fasthttp.RequestCtx{}
ctx.Request.Header.SetMethod(fiber.MethodGet)
h(ctx)
utils.AssertEqual(t, 200, ctx.Response.StatusCode())
utils.AssertEqual(t, "value=", string(ctx.Response.Body()))
// Test invalid cookie
ctx = &fasthttp.RequestCtx{}
ctx.Request.Header.SetMethod(fiber.MethodGet)
ctx.Request.Header.SetCookie("test", "Invalid")
h(ctx)
utils.AssertEqual(t, 200, ctx.Response.StatusCode())
utils.AssertEqual(t, "value=", string(ctx.Response.Body()))
ctx.Request.Header.SetCookie("test", "ixQURE2XOyZUs0WAOh2ehjWcP7oZb07JvnhWOsmeNUhPsj4+RyI=")
h(ctx)
utils.AssertEqual(t, 200, ctx.Response.StatusCode())
utils.AssertEqual(t, "value=", string(ctx.Response.Body()))
// Test valid cookie
ctx = &fasthttp.RequestCtx{}
ctx.Request.Header.SetMethod(fiber.MethodPost)
h(ctx)
utils.AssertEqual(t, 200, ctx.Response.StatusCode())
encryptedCookie := fasthttp.Cookie{}
encryptedCookie.SetKey("test")
utils.AssertEqual(t, true, ctx.Response.Header.Cookie(&encryptedCookie), "Get cookie value")
decryptedCookieValue, err := DecryptCookie(string(encryptedCookie.Value()), testKey)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, "SomeThing", decryptedCookieValue)
ctx = &fasthttp.RequestCtx{}
ctx.Request.Header.SetMethod(fiber.MethodGet)
ctx.Request.Header.SetCookie("test", string(encryptedCookie.Value()))
h(ctx)
utils.AssertEqual(t, 200, ctx.Response.StatusCode())
utils.AssertEqual(t, "value=SomeThing", string(ctx.Response.Body()))
}
func Test_Encrypt_Cookie_Next(t *testing.T) {
t.Parallel()
app := fiber.New()
app.Use(New(Config{
Key: testKey,
Next: func(_ *fiber.Ctx) bool {
return true
},
}))
app.Get("/", func(c *fiber.Ctx) error {
c.Cookie(&fiber.Cookie{
Name: "test",
Value: "SomeThing",
})
return nil
})
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, "SomeThing", resp.Cookies()[0].Value)
}
func Test_Encrypt_Cookie_Except(t *testing.T) {
t.Parallel()
app := fiber.New()
app.Use(New(Config{
Key: testKey,
Except: []string{
"test1",
},
}))
app.Get("/", func(c *fiber.Ctx) error {
c.Cookie(&fiber.Cookie{
Name: "test1",
Value: "SomeThing",
})
c.Cookie(&fiber.Cookie{
Name: "test2",
Value: "SomeThing",
})
return nil
})
h := app.Handler()
ctx := &fasthttp.RequestCtx{}
ctx.Request.Header.SetMethod(fiber.MethodGet)
h(ctx)
utils.AssertEqual(t, 200, ctx.Response.StatusCode())
rawCookie := fasthttp.Cookie{}
rawCookie.SetKey("test1")
utils.AssertEqual(t, true, ctx.Response.Header.Cookie(&rawCookie), "Get cookie value")
utils.AssertEqual(t, "SomeThing", string(rawCookie.Value()))
encryptedCookie := fasthttp.Cookie{}
encryptedCookie.SetKey("test2")
utils.AssertEqual(t, true, ctx.Response.Header.Cookie(&encryptedCookie), "Get cookie value")
decryptedCookieValue, err := DecryptCookie(string(encryptedCookie.Value()), testKey)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, "SomeThing", decryptedCookieValue)
}
func Test_Encrypt_Cookie_Custom_Encryptor(t *testing.T) {
t.Parallel()
app := fiber.New()
app.Use(New(Config{
Key: testKey,
Encryptor: func(decryptedString, _ string) (string, error) {
return base64.StdEncoding.EncodeToString([]byte(decryptedString)), nil
},
Decryptor: func(encryptedString, _ string) (string, error) {
decodedBytes, err := base64.StdEncoding.DecodeString(encryptedString)
return string(decodedBytes), err
},
}))
app.Get("/", func(c *fiber.Ctx) error {
return c.SendString("value=" + c.Cookies("test"))
})
app.Post("/", func(c *fiber.Ctx) error {
c.Cookie(&fiber.Cookie{
Name: "test",
Value: "SomeThing",
})
return nil
})
h := app.Handler()
ctx := &fasthttp.RequestCtx{}
ctx.Request.Header.SetMethod(fiber.MethodPost)
h(ctx)
utils.AssertEqual(t, 200, ctx.Response.StatusCode())
encryptedCookie := fasthttp.Cookie{}
encryptedCookie.SetKey("test")
utils.AssertEqual(t, true, ctx.Response.Header.Cookie(&encryptedCookie), "Get cookie value")
decodedBytes, err := base64.StdEncoding.DecodeString(string(encryptedCookie.Value()))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, "SomeThing", string(decodedBytes))
ctx = &fasthttp.RequestCtx{}
ctx.Request.Header.SetMethod(fiber.MethodGet)
ctx.Request.Header.SetCookie("test", string(encryptedCookie.Value()))
h(ctx)
utils.AssertEqual(t, 200, ctx.Response.StatusCode())
utils.AssertEqual(t, "value=SomeThing", string(ctx.Response.Body()))
}

View file

@ -0,0 +1,98 @@
package encryptcookie
import (
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"encoding/base64"
"errors"
"fmt"
"io"
)
// EncryptCookie Encrypts a cookie value with specific encryption key
func EncryptCookie(value, key string) (string, error) {
keyDecoded, err := base64.StdEncoding.DecodeString(key)
if err != nil {
return "", fmt.Errorf("failed to base64-decode key: %w", err)
}
block, err := aes.NewCipher(keyDecoded)
if err != nil {
return "", fmt.Errorf("failed to create AES cipher: %w", err)
}
gcm, err := cipher.NewGCM(block)
if err != nil {
return "", fmt.Errorf("failed to create GCM mode: %w", err)
}
nonce := make([]byte, gcm.NonceSize())
if _, err = io.ReadFull(rand.Reader, nonce); err != nil {
return "", fmt.Errorf("failed to read: %w", err)
}
ciphertext := gcm.Seal(nonce, nonce, []byte(value), nil)
return base64.StdEncoding.EncodeToString(ciphertext), nil
}
// DecryptCookie Decrypts a cookie value with specific encryption key
func DecryptCookie(value, key string) (string, error) {
keyDecoded, err := base64.StdEncoding.DecodeString(key)
if err != nil {
return "", fmt.Errorf("failed to base64-decode key: %w", err)
}
enc, err := base64.StdEncoding.DecodeString(value)
if err != nil {
return "", fmt.Errorf("failed to base64-decode value: %w", err)
}
block, err := aes.NewCipher(keyDecoded)
if err != nil {
return "", fmt.Errorf("failed to create AES cipher: %w", err)
}
gcm, err := cipher.NewGCM(block)
if err != nil {
return "", fmt.Errorf("failed to create GCM mode: %w", err)
}
nonceSize := gcm.NonceSize()
if len(enc) < nonceSize {
return "", errors.New("encrypted value is not valid")
}
nonce, ciphertext := enc[:nonceSize], enc[nonceSize:]
plaintext, err := gcm.Open(nil, nonce, ciphertext, nil)
if err != nil {
return "", fmt.Errorf("failed to decrypt ciphertext: %w", err)
}
return string(plaintext), nil
}
// GenerateKey Generates an encryption key
func GenerateKey() string {
const keyLen = 32
ret := make([]byte, keyLen)
if _, err := rand.Read(ret); err != nil {
panic(err)
}
return base64.StdEncoding.EncodeToString(ret)
}
// Check given cookie key is disabled for encryption or not
func isDisabled(key string, except []string) bool {
for _, k := range except {
if key == k {
return true
}
}
return false
}

View file

@ -0,0 +1,68 @@
package envvar
import (
"os"
"strings"
"github.com/gofiber/fiber/v2"
)
// Config defines the config for middleware.
type Config struct {
// ExportVars specifies the environment variables that should export
ExportVars map[string]string
// ExcludeVars specifies the environment variables that should not export
ExcludeVars map[string]string
}
type EnvVar struct {
Vars map[string]string `json:"vars"`
}
func (envVar *EnvVar) set(key, val string) {
envVar.Vars[key] = val
}
func New(config ...Config) fiber.Handler {
var cfg Config
if len(config) > 0 {
cfg = config[0]
}
return func(c *fiber.Ctx) error {
if c.Method() != fiber.MethodGet {
return fiber.ErrMethodNotAllowed
}
envVar := newEnvVar(cfg)
varsByte, err := c.App().Config().JSONEncoder(envVar)
if err != nil {
return c.Status(fiber.StatusInternalServerError).SendString(err.Error())
}
c.Set(fiber.HeaderContentType, fiber.MIMEApplicationJSONCharsetUTF8)
return c.Send(varsByte)
}
}
func newEnvVar(cfg Config) *EnvVar {
vars := &EnvVar{Vars: make(map[string]string)}
if len(cfg.ExportVars) > 0 {
for key, defaultVal := range cfg.ExportVars {
vars.set(key, defaultVal)
if envVal, exists := os.LookupEnv(key); exists {
vars.set(key, envVal)
}
}
} else {
const numElems = 2
for _, envVal := range os.Environ() {
keyVal := strings.SplitN(envVal, "=", numElems)
if _, exists := cfg.ExcludeVars[keyVal[0]]; !exists {
vars.set(keyVal[0], keyVal[1])
}
}
}
return vars
}

View file

@ -0,0 +1,172 @@
//nolint:bodyclose // Much easier to just ignore memory leaks in tests
package envvar
import (
"context"
"encoding/json"
"io"
"net/http"
"os"
"testing"
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/utils"
)
func TestEnvVarStructWithExportVarsExcludeVars(t *testing.T) {
err := os.Setenv("testKey", "testEnvValue")
utils.AssertEqual(t, nil, err)
err = os.Setenv("anotherEnvKey", "anotherEnvVal")
utils.AssertEqual(t, nil, err)
err = os.Setenv("excludeKey", "excludeEnvValue")
utils.AssertEqual(t, nil, err)
defer func() {
err := os.Unsetenv("testKey")
utils.AssertEqual(t, nil, err)
err = os.Unsetenv("anotherEnvKey")
utils.AssertEqual(t, nil, err)
err = os.Unsetenv("excludeKey")
utils.AssertEqual(t, nil, err)
}()
vars := newEnvVar(Config{
ExportVars: map[string]string{"testKey": "", "testDefaultKey": "testDefaultVal"},
ExcludeVars: map[string]string{"excludeKey": ""},
})
utils.AssertEqual(t, vars.Vars["testKey"], "testEnvValue")
utils.AssertEqual(t, vars.Vars["testDefaultKey"], "testDefaultVal")
utils.AssertEqual(t, vars.Vars["excludeKey"], "")
utils.AssertEqual(t, vars.Vars["anotherEnvKey"], "")
}
func TestEnvVarHandler(t *testing.T) {
err := os.Setenv("testKey", "testVal")
utils.AssertEqual(t, nil, err)
defer func() {
err := os.Unsetenv("testKey")
utils.AssertEqual(t, nil, err)
}()
expectedEnvVarResponse, err := json.Marshal(
struct {
Vars map[string]string `json:"vars"`
}{
map[string]string{"testKey": "testVal"},
})
utils.AssertEqual(t, nil, err)
app := fiber.New()
app.Use("/envvars", New(Config{
ExportVars: map[string]string{"testKey": ""},
}))
req, err := http.NewRequestWithContext(context.Background(), fiber.MethodGet, "http://localhost/envvars", nil)
utils.AssertEqual(t, nil, err)
resp, err := app.Test(req)
utils.AssertEqual(t, nil, err)
respBody, err := io.ReadAll(resp.Body)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, expectedEnvVarResponse, respBody)
}
func TestEnvVarHandlerNotMatched(t *testing.T) {
app := fiber.New()
app.Use("/envvars", New(Config{
ExportVars: map[string]string{"testKey": ""},
}))
app.Get("/another-path", func(ctx *fiber.Ctx) error {
utils.AssertEqual(t, nil, ctx.SendString("OK"))
return nil
})
req, err := http.NewRequestWithContext(context.Background(), fiber.MethodGet, "http://localhost/another-path", nil)
utils.AssertEqual(t, nil, err)
resp, err := app.Test(req)
utils.AssertEqual(t, nil, err)
respBody, err := io.ReadAll(resp.Body)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, []byte("OK"), respBody)
}
func TestEnvVarHandlerDefaultConfig(t *testing.T) {
err := os.Setenv("testEnvKey", "testEnvVal")
utils.AssertEqual(t, nil, err)
defer func() {
err := os.Unsetenv("testEnvKey")
utils.AssertEqual(t, nil, err)
}()
app := fiber.New()
app.Use("/envvars", New())
req, err := http.NewRequestWithContext(context.Background(), fiber.MethodGet, "http://localhost/envvars", nil)
utils.AssertEqual(t, nil, err)
resp, err := app.Test(req)
utils.AssertEqual(t, nil, err)
respBody, err := io.ReadAll(resp.Body)
utils.AssertEqual(t, nil, err)
var envVars EnvVar
utils.AssertEqual(t, nil, json.Unmarshal(respBody, &envVars))
val := envVars.Vars["testEnvKey"]
utils.AssertEqual(t, "testEnvVal", val)
}
func TestEnvVarHandlerMethod(t *testing.T) {
app := fiber.New()
app.Use("/envvars", New())
req, err := http.NewRequestWithContext(context.Background(), fiber.MethodPost, "http://localhost/envvars", nil)
utils.AssertEqual(t, nil, err)
resp, err := app.Test(req)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusMethodNotAllowed, resp.StatusCode)
}
func TestEnvVarHandlerSpecialValue(t *testing.T) {
testEnvKey := "testEnvKey"
fakeBase64 := "testBase64:TQ=="
err := os.Setenv(testEnvKey, fakeBase64)
utils.AssertEqual(t, nil, err)
defer func() {
err := os.Unsetenv(testEnvKey)
utils.AssertEqual(t, nil, err)
}()
app := fiber.New()
app.Use("/envvars", New())
app.Use("/envvars/export", New(Config{ExportVars: map[string]string{testEnvKey: ""}}))
req, err := http.NewRequestWithContext(context.Background(), fiber.MethodGet, "http://localhost/envvars", nil)
utils.AssertEqual(t, nil, err)
resp, err := app.Test(req)
utils.AssertEqual(t, nil, err)
respBody, err := io.ReadAll(resp.Body)
utils.AssertEqual(t, nil, err)
var envVars EnvVar
utils.AssertEqual(t, nil, json.Unmarshal(respBody, &envVars))
val := envVars.Vars[testEnvKey]
utils.AssertEqual(t, fakeBase64, val)
req, err = http.NewRequestWithContext(context.Background(), fiber.MethodGet, "http://localhost/envvars/export", nil)
utils.AssertEqual(t, nil, err)
resp, err = app.Test(req)
utils.AssertEqual(t, nil, err)
respBody, err = io.ReadAll(resp.Body)
utils.AssertEqual(t, nil, err)
var envVarsExport EnvVar
utils.AssertEqual(t, nil, json.Unmarshal(respBody, &envVarsExport))
val = envVarsExport.Vars[testEnvKey]
utils.AssertEqual(t, fakeBase64, val)
}

44
middleware/etag/config.go Normal file
View file

@ -0,0 +1,44 @@
package etag
import (
"github.com/gofiber/fiber/v2"
)
// Config defines the config for middleware.
type Config struct {
// Weak indicates that a weak validator is used. Weak etags are easy
// to generate, but are far less useful for comparisons. Strong
// validators are ideal for comparisons but can be very difficult
// to generate efficiently. Weak ETag values of two representations
// of the same resources might be semantically equivalent, but not
// byte-for-byte identical. This means weak etags prevent caching
// when byte range requests are used, but strong etags mean range
// requests can still be cached.
Weak bool
// Next defines a function to skip this middleware when returned true.
//
// Optional. Default: nil
Next func(c *fiber.Ctx) bool
}
// ConfigDefault is the default config
var ConfigDefault = Config{
Weak: false,
Next: nil,
}
// Helper function to set default values
func configDefault(config ...Config) Config {
// Return default config if nothing provided
if len(config) < 1 {
return ConfigDefault
}
// Override default config
cfg := config[0]
// Set default values
return cfg
}

116
middleware/etag/etag.go Normal file
View file

@ -0,0 +1,116 @@
package etag
import (
"bytes"
"hash/crc32"
"github.com/gofiber/fiber/v2"
"github.com/valyala/bytebufferpool"
)
// New creates a new middleware handler
func New(config ...Config) fiber.Handler {
// Set default config
cfg := configDefault(config...)
var (
normalizedHeaderETag = []byte("Etag")
weakPrefix = []byte("W/")
)
const crcPol = 0xD5828281
crc32q := crc32.MakeTable(crcPol)
// Return new handler
return func(c *fiber.Ctx) error {
// Don't execute middleware if Next returns true
if cfg.Next != nil && cfg.Next(c) {
return c.Next()
}
// Return err if next handler returns one
if err := c.Next(); err != nil {
return err
}
// Don't generate ETags for invalid responses
if c.Response().StatusCode() != fiber.StatusOK {
return nil
}
body := c.Response().Body()
// Skips ETag if no response body is present
if len(body) == 0 {
return nil
}
// Skip ETag if header is already present
if c.Response().Header.PeekBytes(normalizedHeaderETag) != nil {
return nil
}
// Generate ETag for response
bb := bytebufferpool.Get()
defer bytebufferpool.Put(bb)
// Enable weak tag
if cfg.Weak {
_, _ = bb.Write(weakPrefix) //nolint:errcheck // This will never fail
}
_ = bb.WriteByte('"') //nolint:errcheck // This will never fail
bb.B = appendUint(bb.Bytes(), uint32(len(body)))
_ = bb.WriteByte('-') //nolint:errcheck // This will never fail
bb.B = appendUint(bb.Bytes(), crc32.Checksum(body, crc32q))
_ = bb.WriteByte('"') //nolint:errcheck // This will never fail
etag := bb.Bytes()
// Get ETag header from request
clientEtag := c.Request().Header.Peek(fiber.HeaderIfNoneMatch)
// Check if client's ETag is weak
if bytes.HasPrefix(clientEtag, weakPrefix) {
// Check if server's ETag is weak
if bytes.Equal(clientEtag[2:], etag) || bytes.Equal(clientEtag[2:], etag[2:]) {
// W/1 == 1 || W/1 == W/1
c.Context().ResetBody()
return c.SendStatus(fiber.StatusNotModified)
}
// W/1 != W/2 || W/1 != 2
c.Response().Header.SetCanonical(normalizedHeaderETag, etag)
return nil
}
if bytes.Contains(clientEtag, etag) {
// 1 == 1
c.Context().ResetBody()
return c.SendStatus(fiber.StatusNotModified)
}
// 1 != 2
c.Response().Header.SetCanonical(normalizedHeaderETag, etag)
return nil
}
}
// appendUint appends n to dst and returns the extended dst.
func appendUint(dst []byte, n uint32) []byte {
var b [20]byte
buf := b[:]
i := len(buf)
var q uint32
for n >= 10 {
i--
q = n / 10
buf[i] = '0' + byte(n-q*10)
n = q
}
i--
buf[i] = '0' + byte(n)
dst = append(dst, buf[i:]...)
return dst
}

View file

@ -0,0 +1,291 @@
package etag
import (
"bytes"
"io"
"net/http/httptest"
"testing"
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/utils"
"github.com/valyala/fasthttp"
)
// go test -run Test_ETag_Next
func Test_ETag_Next(t *testing.T) {
t.Parallel()
app := fiber.New()
app.Use(New(Config{
Next: func(_ *fiber.Ctx) bool {
return true
},
}))
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusNotFound, resp.StatusCode)
}
// go test -run Test_ETag_SkipError
func Test_ETag_SkipError(t *testing.T) {
t.Parallel()
app := fiber.New()
app.Use(New())
app.Get("/", func(c *fiber.Ctx) error {
return fiber.ErrForbidden
})
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusForbidden, resp.StatusCode)
}
// go test -run Test_ETag_NotStatusOK
func Test_ETag_NotStatusOK(t *testing.T) {
t.Parallel()
app := fiber.New()
app.Use(New())
app.Get("/", func(c *fiber.Ctx) error {
return c.SendStatus(fiber.StatusCreated)
})
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusCreated, resp.StatusCode)
}
// go test -run Test_ETag_NoBody
func Test_ETag_NoBody(t *testing.T) {
t.Parallel()
app := fiber.New()
app.Use(New())
app.Get("/", func(c *fiber.Ctx) error {
return nil
})
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
}
// go test -run Test_ETag_NewEtag
func Test_ETag_NewEtag(t *testing.T) {
t.Parallel()
t.Run("without HeaderIfNoneMatch", func(t *testing.T) {
t.Parallel()
testETagNewEtag(t, false, false)
})
t.Run("with HeaderIfNoneMatch and not matched", func(t *testing.T) {
t.Parallel()
testETagNewEtag(t, true, false)
})
t.Run("with HeaderIfNoneMatch and matched", func(t *testing.T) {
t.Parallel()
testETagNewEtag(t, true, true)
})
}
func testETagNewEtag(t *testing.T, headerIfNoneMatch, matched bool) { //nolint:revive // We're in a test, so using bools as a flow-control is fine
t.Helper()
app := fiber.New()
app.Use(New())
app.Get("/", func(c *fiber.Ctx) error {
return c.SendString("Hello, World!")
})
req := httptest.NewRequest(fiber.MethodGet, "/", nil)
if headerIfNoneMatch {
etag := `"non-match"`
if matched {
etag = `"13-1831710635"`
}
req.Header.Set(fiber.HeaderIfNoneMatch, etag)
}
resp, err := app.Test(req)
utils.AssertEqual(t, nil, err)
if !headerIfNoneMatch || !matched {
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
utils.AssertEqual(t, `"13-1831710635"`, resp.Header.Get(fiber.HeaderETag))
return
}
if matched {
utils.AssertEqual(t, fiber.StatusNotModified, resp.StatusCode)
b, err := io.ReadAll(resp.Body)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 0, len(b))
}
}
// go test -run Test_ETag_WeakEtag
func Test_ETag_WeakEtag(t *testing.T) {
t.Parallel()
t.Run("without HeaderIfNoneMatch", func(t *testing.T) {
t.Parallel()
testETagWeakEtag(t, false, false)
})
t.Run("with HeaderIfNoneMatch and not matched", func(t *testing.T) {
t.Parallel()
testETagWeakEtag(t, true, false)
})
t.Run("with HeaderIfNoneMatch and matched", func(t *testing.T) {
t.Parallel()
testETagWeakEtag(t, true, true)
})
}
func testETagWeakEtag(t *testing.T, headerIfNoneMatch, matched bool) { //nolint:revive // We're in a test, so using bools as a flow-control is fine
t.Helper()
app := fiber.New()
app.Use(New(Config{Weak: true}))
app.Get("/", func(c *fiber.Ctx) error {
return c.SendString("Hello, World!")
})
req := httptest.NewRequest(fiber.MethodGet, "/", nil)
if headerIfNoneMatch {
etag := `W/"non-match"`
if matched {
etag = `W/"13-1831710635"`
}
req.Header.Set(fiber.HeaderIfNoneMatch, etag)
}
resp, err := app.Test(req)
utils.AssertEqual(t, nil, err)
if !headerIfNoneMatch || !matched {
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
utils.AssertEqual(t, `W/"13-1831710635"`, resp.Header.Get(fiber.HeaderETag))
return
}
if matched {
utils.AssertEqual(t, fiber.StatusNotModified, resp.StatusCode)
b, err := io.ReadAll(resp.Body)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 0, len(b))
}
}
// go test -run Test_ETag_CustomEtag
func Test_ETag_CustomEtag(t *testing.T) {
t.Parallel()
t.Run("without HeaderIfNoneMatch", func(t *testing.T) {
t.Parallel()
testETagCustomEtag(t, false, false)
})
t.Run("with HeaderIfNoneMatch and not matched", func(t *testing.T) {
t.Parallel()
testETagCustomEtag(t, true, false)
})
t.Run("with HeaderIfNoneMatch and matched", func(t *testing.T) {
t.Parallel()
testETagCustomEtag(t, true, true)
})
}
func testETagCustomEtag(t *testing.T, headerIfNoneMatch, matched bool) { //nolint:revive // We're in a test, so using bools as a flow-control is fine
t.Helper()
app := fiber.New()
app.Use(New())
app.Get("/", func(c *fiber.Ctx) error {
c.Set(fiber.HeaderETag, `"custom"`)
if bytes.Equal(c.Request().Header.Peek(fiber.HeaderIfNoneMatch), []byte(`"custom"`)) {
return c.SendStatus(fiber.StatusNotModified)
}
return c.SendString("Hello, World!")
})
req := httptest.NewRequest(fiber.MethodGet, "/", nil)
if headerIfNoneMatch {
etag := `"non-match"`
if matched {
etag = `"custom"`
}
req.Header.Set(fiber.HeaderIfNoneMatch, etag)
}
resp, err := app.Test(req)
utils.AssertEqual(t, nil, err)
if !headerIfNoneMatch || !matched {
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
utils.AssertEqual(t, `"custom"`, resp.Header.Get(fiber.HeaderETag))
return
}
if matched {
utils.AssertEqual(t, fiber.StatusNotModified, resp.StatusCode)
b, err := io.ReadAll(resp.Body)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 0, len(b))
}
}
// go test -run Test_ETag_CustomEtagPut
func Test_ETag_CustomEtagPut(t *testing.T) {
t.Parallel()
app := fiber.New()
app.Use(New())
app.Put("/", func(c *fiber.Ctx) error {
c.Set(fiber.HeaderETag, `"custom"`)
if !bytes.Equal(c.Request().Header.Peek(fiber.HeaderIfMatch), []byte(`"custom"`)) {
return c.SendStatus(fiber.StatusPreconditionFailed)
}
return c.SendString("Hello, World!")
})
req := httptest.NewRequest(fiber.MethodPut, "/", nil)
req.Header.Set(fiber.HeaderIfMatch, `"non-match"`)
resp, err := app.Test(req)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusPreconditionFailed, resp.StatusCode)
}
// go test -v -run=^$ -bench=Benchmark_Etag -benchmem -count=4
func Benchmark_Etag(b *testing.B) {
app := fiber.New()
app.Use(New())
app.Get("/", func(c *fiber.Ctx) error {
return c.SendString("Hello, World!")
})
h := app.Handler()
fctx := &fasthttp.RequestCtx{}
fctx.Request.Header.SetMethod(fiber.MethodGet)
fctx.Request.SetRequestURI("/")
b.ReportAllocs()
b.ResetTimer()
for n := 0; n < b.N; n++ {
h(fctx)
}
utils.AssertEqual(b, 200, fctx.Response.Header.StatusCode())
utils.AssertEqual(b, `"13-1831710635"`, string(fctx.Response.Header.Peek(fiber.HeaderETag)))
}

View file

@ -0,0 +1,34 @@
package expvar
import (
"github.com/gofiber/fiber/v2"
)
// Config defines the config for middleware.
type Config struct {
// Next defines a function to skip this middleware when returned true.
//
// Optional. Default: nil
Next func(c *fiber.Ctx) bool
}
var ConfigDefault = Config{
Next: nil,
}
func configDefault(config ...Config) Config {
// Return default config if nothing provided
if len(config) < 1 {
return ConfigDefault
}
// Override default config
cfg := config[0]
// Set default values
if cfg.Next == nil {
cfg.Next = ConfigDefault.Next
}
return cfg
}

View file

@ -0,0 +1,35 @@
package expvar
import (
"strings"
"github.com/gofiber/fiber/v2"
"github.com/valyala/fasthttp/expvarhandler"
)
// New creates a new middleware handler
func New(config ...Config) fiber.Handler {
// Set default config
cfg := configDefault(config...)
// Return new handler
return func(c *fiber.Ctx) error {
// Don't execute middleware if Next returns true
if cfg.Next != nil && cfg.Next(c) {
return c.Next()
}
path := c.Path()
// We are only interested in /debug/vars routes
if len(path) < 11 || !strings.HasPrefix(path, "/debug/vars") {
return c.Next()
}
if path == "/debug/vars" {
expvarhandler.ExpvarHandler(c.Context())
return nil
}
return c.Redirect("/debug/vars", fiber.StatusFound)
}
}

View file

@ -0,0 +1,103 @@
package expvar
import (
"bytes"
"io"
"net/http/httptest"
"testing"
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/utils"
)
func Test_Non_Expvar_Path(t *testing.T) {
t.Parallel()
app := fiber.New()
app.Use(New())
app.Get("/", func(c *fiber.Ctx) error {
return c.SendString("escaped")
})
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 200, resp.StatusCode)
b, err := io.ReadAll(resp.Body)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, "escaped", string(b))
}
func Test_Expvar_Index(t *testing.T) {
t.Parallel()
app := fiber.New()
app.Use(New())
app.Get("/", func(c *fiber.Ctx) error {
return c.SendString("escaped")
})
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/debug/vars", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 200, resp.StatusCode)
utils.AssertEqual(t, fiber.MIMEApplicationJSONCharsetUTF8, resp.Header.Get(fiber.HeaderContentType))
b, err := io.ReadAll(resp.Body)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, true, bytes.Contains(b, []byte("cmdline")))
utils.AssertEqual(t, true, bytes.Contains(b, []byte("memstat")))
}
func Test_Expvar_Filter(t *testing.T) {
t.Parallel()
app := fiber.New()
app.Use(New())
app.Get("/", func(c *fiber.Ctx) error {
return c.SendString("escaped")
})
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/debug/vars?r=cmd", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 200, resp.StatusCode)
utils.AssertEqual(t, fiber.MIMEApplicationJSONCharsetUTF8, resp.Header.Get(fiber.HeaderContentType))
b, err := io.ReadAll(resp.Body)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, true, bytes.Contains(b, []byte("cmdline")))
utils.AssertEqual(t, false, bytes.Contains(b, []byte("memstat")))
}
func Test_Expvar_Other_Path(t *testing.T) {
t.Parallel()
app := fiber.New()
app.Use(New())
app.Get("/", func(c *fiber.Ctx) error {
return c.SendString("escaped")
})
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/debug/vars/302", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 302, resp.StatusCode)
}
// go test -run Test_Expvar_Next
func Test_Expvar_Next(t *testing.T) {
t.Parallel()
app := fiber.New()
app.Use(New(Config{
Next: func(_ *fiber.Ctx) bool {
return true
},
}))
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/debug/vars", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 404, resp.StatusCode)
}

View file

@ -0,0 +1,146 @@
package favicon
import (
"io"
"net/http"
"os"
"strconv"
"github.com/gofiber/fiber/v2"
)
// Config defines the config for middleware.
type Config struct {
// Next defines a function to skip this middleware when returned true.
//
// Optional. Default: nil
Next func(c *fiber.Ctx) bool
// Raw data of the favicon file
//
// Optional. Default: nil
Data []byte `json:"-"`
// File holds the path to an actual favicon that will be cached
//
// Optional. Default: ""
File string `json:"file"`
// URL for favicon handler
//
// Optional. Default: "/favicon.ico"
URL string `json:"url"`
// FileSystem is an optional alternate filesystem to search for the favicon in.
// An example of this could be an embedded or network filesystem
//
// Optional. Default: nil
FileSystem http.FileSystem `json:"-"`
// CacheControl defines how the Cache-Control header in the response should be set
//
// Optional. Default: "public, max-age=31536000"
CacheControl string `json:"cache_control"`
}
// ConfigDefault is the default config
var ConfigDefault = Config{
Next: nil,
File: "",
URL: fPath,
CacheControl: "public, max-age=31536000",
}
const (
fPath = "/favicon.ico"
hType = "image/x-icon"
hAllow = "GET, HEAD, OPTIONS"
hZero = "0"
)
// New creates a new middleware handler
func New(config ...Config) fiber.Handler {
// Set default config
cfg := ConfigDefault
// Override config if provided
if len(config) > 0 {
cfg = config[0]
// Set default values
if cfg.Next == nil {
cfg.Next = ConfigDefault.Next
}
if cfg.URL == "" {
cfg.URL = ConfigDefault.URL
}
if cfg.File == "" {
cfg.File = ConfigDefault.File
}
if cfg.CacheControl == "" {
cfg.CacheControl = ConfigDefault.CacheControl
}
}
// Load icon if provided
var (
err error
icon []byte
iconLen string
)
if cfg.Data != nil {
// use the provided favicon data
icon = cfg.Data
iconLen = strconv.Itoa(len(cfg.Data))
} else if cfg.File != "" {
// read from configured filesystem if present
if cfg.FileSystem != nil {
f, err := cfg.FileSystem.Open(cfg.File)
if err != nil {
panic(err)
}
if icon, err = io.ReadAll(f); err != nil {
panic(err)
}
} else if icon, err = os.ReadFile(cfg.File); err != nil {
panic(err)
}
iconLen = strconv.Itoa(len(icon))
}
// Return new handler
return func(c *fiber.Ctx) error {
// Don't execute middleware if Next returns true
if cfg.Next != nil && cfg.Next(c) {
return c.Next()
}
// Only respond to favicon requests
if c.Path() != cfg.URL {
return c.Next()
}
// Only allow GET, HEAD and OPTIONS requests
if c.Method() != fiber.MethodGet && c.Method() != fiber.MethodHead {
if c.Method() != fiber.MethodOptions {
c.Status(fiber.StatusMethodNotAllowed)
} else {
c.Status(fiber.StatusOK)
}
c.Set(fiber.HeaderAllow, hAllow)
c.Set(fiber.HeaderContentLength, hZero)
return nil
}
// Serve cached favicon
if len(icon) > 0 {
c.Set(fiber.HeaderContentLength, iconLen)
c.Set(fiber.HeaderContentType, hType)
c.Set(fiber.HeaderCacheControl, cfg.CacheControl)
return c.Status(fiber.StatusOK).Send(icon)
}
return c.SendStatus(fiber.StatusNoContent)
}
}

View file

@ -0,0 +1,208 @@
//nolint:bodyclose // Much easier to just ignore memory leaks in tests
package favicon
import (
"fmt"
"net/http"
"net/http/httptest"
"os"
"strings"
"testing"
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/utils"
"github.com/valyala/fasthttp"
)
// go test -run Test_Middleware_Favicon
func Test_Middleware_Favicon(t *testing.T) {
t.Parallel()
app := fiber.New()
app.Use(New())
app.Get("/", func(c *fiber.Ctx) error {
return nil
})
// Skip Favicon middleware
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err, "app.Test(req)")
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode, "Status code")
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/favicon.ico", nil))
utils.AssertEqual(t, nil, err, "app.Test(req)")
utils.AssertEqual(t, fiber.StatusNoContent, resp.StatusCode, "Status code")
resp, err = app.Test(httptest.NewRequest(fiber.MethodOptions, "/favicon.ico", nil))
utils.AssertEqual(t, nil, err, "app.Test(req)")
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode, "Status code")
resp, err = app.Test(httptest.NewRequest(fiber.MethodPut, "/favicon.ico", nil))
utils.AssertEqual(t, nil, err, "app.Test(req)")
utils.AssertEqual(t, fiber.StatusMethodNotAllowed, resp.StatusCode, "Status code")
utils.AssertEqual(t, strings.Join([]string{fiber.MethodGet, fiber.MethodHead, fiber.MethodOptions}, ", "), resp.Header.Get(fiber.HeaderAllow))
}
// go test -run Test_Middleware_Favicon_Not_Found
func Test_Middleware_Favicon_Not_Found(t *testing.T) {
t.Parallel()
defer func() {
if err := recover(); err == nil {
t.Fatal("should cache panic")
}
}()
fiber.New().Use(New(Config{
File: "non-exist.ico",
}))
}
// go test -run Test_Middleware_Favicon_Found
func Test_Middleware_Favicon_Found(t *testing.T) {
t.Parallel()
app := fiber.New()
app.Use(New(Config{
File: "../../.github/testdata/favicon.ico",
}))
app.Get("/", func(c *fiber.Ctx) error {
return nil
})
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/favicon.ico", nil))
utils.AssertEqual(t, nil, err, "app.Test(req)")
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode, "Status code")
utils.AssertEqual(t, "image/x-icon", resp.Header.Get(fiber.HeaderContentType))
utils.AssertEqual(t, "public, max-age=31536000", resp.Header.Get(fiber.HeaderCacheControl), "CacheControl Control")
}
// go test -run Test_Custom_Favicon_Url
func Test_Custom_Favicon_Url(t *testing.T) {
app := fiber.New()
const customURL = "/favicon.svg"
app.Use(New(Config{
File: "../../.github/testdata/favicon.ico",
URL: customURL,
}))
app.Get("/", func(c *fiber.Ctx) error {
return nil
})
resp, err := app.Test(httptest.NewRequest(http.MethodGet, customURL, nil))
utils.AssertEqual(t, nil, err, "app.Test(req)")
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode, "Status code")
utils.AssertEqual(t, "image/x-icon", resp.Header.Get(fiber.HeaderContentType))
}
// go test -run Test_Custom_Favicon_Data
func Test_Custom_Favicon_Data(t *testing.T) {
data, err := os.ReadFile("../../.github/testdata/favicon.ico")
utils.AssertEqual(t, nil, err)
app := fiber.New()
app.Use(New(Config{
Data: data,
}))
app.Get("/", func(c *fiber.Ctx) error {
return nil
})
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/favicon.ico", nil))
utils.AssertEqual(t, nil, err, "app.Test(req)")
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode, "Status code")
utils.AssertEqual(t, "image/x-icon", resp.Header.Get(fiber.HeaderContentType))
utils.AssertEqual(t, "public, max-age=31536000", resp.Header.Get(fiber.HeaderCacheControl), "CacheControl Control")
}
// mockFS wraps local filesystem for the purposes of
// Test_Middleware_Favicon_FileSystem located below
// TODO use os.Dir if fiber upgrades to 1.16
type mockFS struct{}
func (mockFS) Open(name string) (http.File, error) {
if name == "/" {
name = "."
} else {
name = strings.TrimPrefix(name, "/")
}
file, err := os.Open(name) //nolint:gosec // We're in a test func, so this is fine
if err != nil {
return nil, fmt.Errorf("failed to open: %w", err)
}
return file, nil
}
// go test -run Test_Middleware_Favicon_FileSystem
func Test_Middleware_Favicon_FileSystem(t *testing.T) {
t.Parallel()
app := fiber.New()
app.Use(New(Config{
File: "../../.github/testdata/favicon.ico",
FileSystem: mockFS{},
}))
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/favicon.ico", nil))
utils.AssertEqual(t, nil, err, "app.Test(req)")
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode, "Status code")
utils.AssertEqual(t, "image/x-icon", resp.Header.Get(fiber.HeaderContentType))
utils.AssertEqual(t, "public, max-age=31536000", resp.Header.Get(fiber.HeaderCacheControl), "CacheControl Control")
}
// go test -run Test_Middleware_Favicon_CacheControl
func Test_Middleware_Favicon_CacheControl(t *testing.T) {
t.Parallel()
app := fiber.New()
app.Use(New(Config{
CacheControl: "public, max-age=100",
File: "../../.github/testdata/favicon.ico",
}))
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/favicon.ico", nil))
utils.AssertEqual(t, nil, err, "app.Test(req)")
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode, "Status code")
utils.AssertEqual(t, "image/x-icon", resp.Header.Get(fiber.HeaderContentType))
utils.AssertEqual(t, "public, max-age=100", resp.Header.Get(fiber.HeaderCacheControl), "CacheControl Control")
}
// go test -v -run=^$ -bench=Benchmark_Middleware_Favicon -benchmem -count=4
func Benchmark_Middleware_Favicon(b *testing.B) {
app := fiber.New()
app.Use(New())
app.Get("/", func(c *fiber.Ctx) error {
return nil
})
handler := app.Handler()
c := &fasthttp.RequestCtx{}
c.Request.SetRequestURI("/")
b.ReportAllocs()
b.ResetTimer()
for n := 0; n < b.N; n++ {
handler(c)
}
}
// go test -run Test_Favicon_Next
func Test_Favicon_Next(t *testing.T) {
t.Parallel()
app := fiber.New()
app.Use(New(Config{
Next: func(_ *fiber.Ctx) bool {
return true
},
}))
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusNotFound, resp.StatusCode)
}

View file

@ -0,0 +1,287 @@
package filesystem
import (
"errors"
"fmt"
"io/fs"
"net/http"
"strconv"
"strings"
"sync"
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/utils"
)
// Config defines the config for middleware.
type Config struct {
// Next defines a function to skip this middleware when returned true.
//
// Optional. Default: nil
Next func(c *fiber.Ctx) bool
// Root is a FileSystem that provides access
// to a collection of files and directories.
//
// Required. Default: nil
Root http.FileSystem `json:"-"`
// PathPrefix defines a prefix to be added to a filepath when
// reading a file from the FileSystem.
//
// Use when using Go 1.16 embed.FS
//
// Optional. Default ""
PathPrefix string `json:"path_prefix"`
// Enable directory browsing.
//
// Optional. Default: false
Browse bool `json:"browse"`
// Index file for serving a directory.
//
// Optional. Default: "index.html"
Index string `json:"index"`
// The value for the Cache-Control HTTP-header
// that is set on the file response. MaxAge is defined in seconds.
//
// Optional. Default value 0.
MaxAge int `json:"max_age"`
// File to return if path is not found. Useful for SPA's.
//
// Optional. Default: ""
NotFoundFile string `json:"not_found_file"`
// The value for the Content-Type HTTP-header
// that is set on the file response
//
// Optional. Default: ""
ContentTypeCharset string `json:"content_type_charset"`
}
// ConfigDefault is the default config
var ConfigDefault = Config{
Next: nil,
Root: nil,
PathPrefix: "",
Browse: false,
Index: "/index.html",
MaxAge: 0,
ContentTypeCharset: "",
}
// New creates a new middleware handler.
//
// filesystem does not handle url encoded values (for example spaces)
// on it's own. If you need that functionality, set "UnescapePath"
// in fiber.Config
func New(config ...Config) fiber.Handler {
// Set default config
cfg := ConfigDefault
// Override config if provided
if len(config) > 0 {
cfg = config[0]
// Set default values
if cfg.Index == "" {
cfg.Index = ConfigDefault.Index
}
if !strings.HasPrefix(cfg.Index, "/") {
cfg.Index = "/" + cfg.Index
}
if cfg.NotFoundFile != "" && !strings.HasPrefix(cfg.NotFoundFile, "/") {
cfg.NotFoundFile = "/" + cfg.NotFoundFile
}
}
if cfg.Root == nil {
panic("filesystem: Root cannot be nil")
}
if cfg.PathPrefix != "" && !strings.HasPrefix(cfg.PathPrefix, "/") {
cfg.PathPrefix = "/" + cfg.PathPrefix
}
var once sync.Once
var prefix string
cacheControlStr := "public, max-age=" + strconv.Itoa(cfg.MaxAge)
// Return new handler
return func(c *fiber.Ctx) error {
// Don't execute middleware if Next returns true
if cfg.Next != nil && cfg.Next(c) {
return c.Next()
}
method := c.Method()
// We only serve static assets on GET or HEAD methods
if method != fiber.MethodGet && method != fiber.MethodHead {
return c.Next()
}
// Set prefix once
once.Do(func() {
prefix = c.Route().Path
})
// Strip prefix
path := strings.TrimPrefix(c.Path(), prefix)
if !strings.HasPrefix(path, "/") {
path = "/" + path
}
// Add PathPrefix
if cfg.PathPrefix != "" {
// PathPrefix already has a "/" prefix
path = cfg.PathPrefix + path
}
if len(path) > 1 {
path = utils.TrimRight(path, '/')
}
file, err := cfg.Root.Open(path)
if err != nil && errors.Is(err, fs.ErrNotExist) && cfg.NotFoundFile != "" {
file, err = cfg.Root.Open(cfg.NotFoundFile)
}
if err != nil {
if errors.Is(err, fs.ErrNotExist) {
return c.Status(fiber.StatusNotFound).Next()
}
return fmt.Errorf("failed to open: %w", err)
}
stat, err := file.Stat()
if err != nil {
return fmt.Errorf("failed to stat: %w", err)
}
// Serve index if path is directory
if stat.IsDir() {
indexPath := utils.TrimRight(path, '/') + cfg.Index
index, err := cfg.Root.Open(indexPath)
if err == nil {
indexStat, err := index.Stat()
if err == nil {
file = index
stat = indexStat
}
}
}
// Browse directory if no index found and browsing is enabled
if stat.IsDir() {
if cfg.Browse {
return dirList(c, file)
}
return fiber.ErrForbidden
}
c.Status(fiber.StatusOK)
modTime := stat.ModTime()
contentLength := int(stat.Size())
// Set Content Type header
if cfg.ContentTypeCharset == "" {
c.Type(getFileExtension(stat.Name()))
} else {
c.Type(getFileExtension(stat.Name()), cfg.ContentTypeCharset)
}
// Set Last Modified header
if !modTime.IsZero() {
c.Set(fiber.HeaderLastModified, modTime.UTC().Format(http.TimeFormat))
}
if method == fiber.MethodGet {
if cfg.MaxAge > 0 {
c.Set(fiber.HeaderCacheControl, cacheControlStr)
}
c.Response().SetBodyStream(file, contentLength)
return nil
}
if method == fiber.MethodHead {
c.Request().ResetBody()
// Fasthttp should skipbody by default if HEAD?
c.Response().SkipBody = true
c.Response().Header.SetContentLength(contentLength)
if err := file.Close(); err != nil {
return fmt.Errorf("failed to close: %w", err)
}
return nil
}
return c.Next()
}
}
// SendFile serves a file from an HTTP file system at the specified path.
// It handles content serving, sets appropriate headers, and returns errors when needed.
// Usage: err := SendFile(ctx, fs, "/path/to/file.txt")
func SendFile(c *fiber.Ctx, filesystem http.FileSystem, path string) error {
file, err := filesystem.Open(path)
if err != nil {
if errors.Is(err, fs.ErrNotExist) {
return fiber.ErrNotFound
}
return fmt.Errorf("failed to open: %w", err)
}
stat, err := file.Stat()
if err != nil {
return fmt.Errorf("failed to stat: %w", err)
}
// Serve index if path is directory
if stat.IsDir() {
indexPath := utils.TrimRight(path, '/') + ConfigDefault.Index
index, err := filesystem.Open(indexPath)
if err == nil {
indexStat, err := index.Stat()
if err == nil {
file = index
stat = indexStat
}
}
}
// Return forbidden if no index found
if stat.IsDir() {
return fiber.ErrForbidden
}
c.Status(fiber.StatusOK)
modTime := stat.ModTime()
contentLength := int(stat.Size())
// Set Content Type header
c.Type(getFileExtension(stat.Name()))
// Set Last Modified header
if !modTime.IsZero() {
c.Set(fiber.HeaderLastModified, modTime.UTC().Format(http.TimeFormat))
}
method := c.Method()
if method == fiber.MethodGet {
c.Response().SetBodyStream(file, contentLength)
return nil
}
if method == fiber.MethodHead {
c.Request().ResetBody()
// Fasthttp should skipbody by default if HEAD?
c.Response().SkipBody = true
c.Response().Header.SetContentLength(contentLength)
if err := file.Close(); err != nil {
return fmt.Errorf("failed to close: %w", err)
}
return nil
}
return nil
}

View file

@ -0,0 +1,235 @@
//nolint:bodyclose // Much easier to just ignore memory leaks in tests
package filesystem
import (
"context"
"net/http"
"net/http/httptest"
"testing"
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/utils"
)
// go test -run Test_FileSystem
func Test_FileSystem(t *testing.T) {
t.Parallel()
app := fiber.New()
app.Use("/test", New(Config{
Root: http.Dir("../../.github/testdata/fs"),
}))
app.Use("/dir", New(Config{
Root: http.Dir("../../.github/testdata/fs"),
Browse: true,
}))
app.Get("/", func(c *fiber.Ctx) error {
return c.SendString("Hello, World!")
})
app.Use("/spatest", New(Config{
Root: http.Dir("../../.github/testdata/fs"),
Index: "index.html",
NotFoundFile: "index.html",
}))
app.Use("/prefix", New(Config{
Root: http.Dir("../../.github/testdata/fs"),
PathPrefix: "img",
}))
tests := []struct {
name string
url string
statusCode int
contentType string
modifiedTime string
}{
{
name: "Should be returns status 200 with suitable content-type",
url: "/test/index.html",
statusCode: 200,
contentType: "text/html",
},
{
name: "Should be returns status 200 with suitable content-type",
url: "/test",
statusCode: 200,
contentType: "text/html",
},
{
name: "Should be returns status 200 with suitable content-type",
url: "/test/css/style.css",
statusCode: 200,
contentType: "text/css",
},
{
name: "Should be returns status 404",
url: "/test/nofile.js",
statusCode: 404,
},
{
name: "Should be returns status 404",
url: "/test/nofile",
statusCode: 404,
},
{
name: "Should be returns status 200",
url: "/",
statusCode: 200,
contentType: "text/plain; charset=utf-8",
},
{
name: "Should be returns status 403",
url: "/test/img",
statusCode: 403,
},
{
name: "Should list the directory contents",
url: "/dir/img",
statusCode: 200,
contentType: "text/html",
},
{
name: "Should list the directory contents",
url: "/dir/img/",
statusCode: 200,
contentType: "text/html",
},
{
name: "Should be returns status 200",
url: "/dir/img/fiber.png",
statusCode: 200,
contentType: "image/png",
},
{
name: "Should be return status 200",
url: "/spatest/doesnotexist",
statusCode: 200,
contentType: "text/html",
},
{
name: "PathPrefix should be applied",
url: "/prefix/fiber.png",
statusCode: 200,
contentType: "image/png",
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, tt.url, nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, tt.statusCode, resp.StatusCode)
if tt.contentType != "" {
ct := resp.Header.Get("Content-Type")
utils.AssertEqual(t, tt.contentType, ct)
}
})
}
}
// go test -run Test_FileSystem_Next
func Test_FileSystem_Next(t *testing.T) {
t.Parallel()
app := fiber.New()
app.Use(New(Config{
Root: http.Dir("../../.github/testdata/fs"),
Next: func(_ *fiber.Ctx) bool {
return true
},
}))
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusNotFound, resp.StatusCode)
}
func Test_FileSystem_NonGetAndHead(t *testing.T) {
t.Parallel()
app := fiber.New()
app.Use("/test", New(Config{
Root: http.Dir("../../.github/testdata/fs"),
}))
resp, err := app.Test(httptest.NewRequest(fiber.MethodPost, "/test", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 404, resp.StatusCode)
}
func Test_FileSystem_Head(t *testing.T) {
t.Parallel()
app := fiber.New()
app.Use("/test", New(Config{
Root: http.Dir("../../.github/testdata/fs"),
}))
req, err := http.NewRequestWithContext(context.Background(), fiber.MethodHead, "/test", nil)
utils.AssertEqual(t, nil, err)
resp, err := app.Test(req)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 200, resp.StatusCode)
}
func Test_FileSystem_NoRoot(t *testing.T) {
t.Parallel()
defer func() {
utils.AssertEqual(t, "filesystem: Root cannot be nil", recover())
}()
app := fiber.New()
app.Use(New())
_, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err)
}
func Test_FileSystem_UsingParam(t *testing.T) {
t.Parallel()
app := fiber.New()
app.Use("/:path", func(c *fiber.Ctx) error {
return SendFile(c, http.Dir("../../.github/testdata/fs"), c.Params("path")+".html")
})
req, err := http.NewRequestWithContext(context.Background(), fiber.MethodHead, "/index", nil)
utils.AssertEqual(t, nil, err)
resp, err := app.Test(req)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 200, resp.StatusCode)
}
func Test_FileSystem_UsingParam_NonFile(t *testing.T) {
t.Parallel()
app := fiber.New()
app.Use("/:path", func(c *fiber.Ctx) error {
return SendFile(c, http.Dir("../../.github/testdata/fs"), c.Params("path")+".html")
})
req, err := http.NewRequestWithContext(context.Background(), fiber.MethodHead, "/template", nil)
utils.AssertEqual(t, nil, err)
resp, err := app.Test(req)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 404, resp.StatusCode)
}
func Test_FileSystem_UsingContentTypeCharset(t *testing.T) {
t.Parallel()
app := fiber.New()
app.Use(New(Config{
Root: http.Dir("../../.github/testdata/fs/index.html"),
ContentTypeCharset: "UTF-8",
}))
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 200, resp.StatusCode)
utils.AssertEqual(t, "text/html; charset=UTF-8", resp.Header.Get("Content-Type"))
}

View file

@ -0,0 +1,66 @@
package filesystem
import (
"fmt"
"html"
"net/http"
"os"
"path"
"sort"
"strings"
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/utils"
)
func getFileExtension(p string) string {
n := strings.LastIndexByte(p, '.')
if n < 0 {
return ""
}
return p[n:]
}
func dirList(c *fiber.Ctx, f http.File) error {
fileinfos, err := f.Readdir(-1)
if err != nil {
return fmt.Errorf("failed to read dir: %w", err)
}
fm := make(map[string]os.FileInfo, len(fileinfos))
filenames := make([]string, 0, len(fileinfos))
for _, fi := range fileinfos {
name := fi.Name()
fm[name] = fi
filenames = append(filenames, name)
}
basePathEscaped := html.EscapeString(c.Path())
_, _ = fmt.Fprintf(c, "<html><head><title>%s</title><style>.dir { font-weight: bold }</style></head><body>", basePathEscaped)
_, _ = fmt.Fprintf(c, "<h1>%s</h1>", basePathEscaped)
_, _ = fmt.Fprint(c, "<ul>")
if len(basePathEscaped) > 1 {
parentPathEscaped := html.EscapeString(utils.TrimRight(c.Path(), '/') + "/..")
_, _ = fmt.Fprintf(c, `<li><a href="%s" class="dir">..</a></li>`, parentPathEscaped)
}
sort.Strings(filenames)
for _, name := range filenames {
pathEscaped := html.EscapeString(path.Join(c.Path() + "/" + name))
fi := fm[name]
auxStr := "dir"
className := "dir"
if !fi.IsDir() {
auxStr = fmt.Sprintf("file, %d bytes", fi.Size())
className = "file"
}
_, _ = fmt.Fprintf(c, `<li><a href="%s" class="%s">%s</a>, %s, last modified %s</li>`,
pathEscaped, className, html.EscapeString(name), auxStr, fi.ModTime())
}
_, _ = fmt.Fprint(c, "</ul></body></html>")
c.Type("html")
return nil
}

View file

@ -0,0 +1,84 @@
package healthcheck
import (
"github.com/gofiber/fiber/v2"
)
// Config defines the configuration options for the healthcheck middleware.
type Config struct {
// Next defines a function to skip this middleware when returned true.
//
// Optional. Default: nil
Next func(c *fiber.Ctx) bool
// Function used for checking the liveness of the application. Returns true if the application
// is running and false if it is not. The liveness probe is typically used to indicate if
// the application is in a state where it can handle requests (e.g., the server is up and running).
//
// Optional. Default: func(c *fiber.Ctx) bool { return true }
LivenessProbe HealthChecker
// HTTP endpoint at which the liveness probe will be available.
//
// Optional. Default: "/livez"
LivenessEndpoint string
// Function used for checking the readiness of the application. Returns true if the application
// is ready to process requests and false otherwise. The readiness probe typically checks if all necessary
// services, databases, and other dependencies are available for the application to function correctly.
//
// Optional. Default: func(c *fiber.Ctx) bool { return true }
ReadinessProbe HealthChecker
// HTTP endpoint at which the readiness probe will be available.
// Optional. Default: "/readyz"
ReadinessEndpoint string
}
const (
DefaultLivenessEndpoint = "/livez"
DefaultReadinessEndpoint = "/readyz"
)
func defaultLivenessProbe(*fiber.Ctx) bool { return true }
func defaultReadinessProbe(*fiber.Ctx) bool { return true }
// ConfigDefault is the default config
var ConfigDefault = Config{
LivenessProbe: defaultLivenessProbe,
ReadinessProbe: defaultReadinessProbe,
LivenessEndpoint: DefaultLivenessEndpoint,
ReadinessEndpoint: DefaultReadinessEndpoint,
}
// defaultConfig returns a default config for the healthcheck middleware.
func defaultConfig(config ...Config) Config {
if len(config) < 1 {
return ConfigDefault
}
cfg := config[0]
if cfg.Next == nil {
cfg.Next = ConfigDefault.Next
}
if cfg.LivenessProbe == nil {
cfg.LivenessProbe = defaultLivenessProbe
}
if cfg.ReadinessProbe == nil {
cfg.ReadinessProbe = defaultReadinessProbe
}
if cfg.LivenessEndpoint == "" {
cfg.LivenessEndpoint = DefaultLivenessEndpoint
}
if cfg.ReadinessEndpoint == "" {
cfg.ReadinessEndpoint = DefaultReadinessEndpoint
}
return cfg
}

View file

@ -0,0 +1,61 @@
package healthcheck
import (
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/utils"
)
// HealthChecker defines a function to check liveness or readiness of the application
type HealthChecker func(*fiber.Ctx) bool
// ProbeCheckerHandler defines a function that returns a ProbeChecker
type HealthCheckerHandler func(HealthChecker) fiber.Handler
func healthCheckerHandler(checker HealthChecker) fiber.Handler {
return func(c *fiber.Ctx) error {
if checker == nil {
return c.Next()
}
if checker(c) {
return c.SendStatus(fiber.StatusOK)
}
return c.SendStatus(fiber.StatusServiceUnavailable)
}
}
func New(config ...Config) fiber.Handler {
cfg := defaultConfig(config...)
isLiveHandler := healthCheckerHandler(cfg.LivenessProbe)
isReadyHandler := healthCheckerHandler(cfg.ReadinessProbe)
return func(c *fiber.Ctx) error {
// Don't execute middleware if Next returns true
if cfg.Next != nil && cfg.Next(c) {
return c.Next()
}
if c.Method() != fiber.MethodGet {
return c.Next()
}
prefixCount := len(utils.TrimRight(c.Route().Path, '/'))
if len(c.Path()) >= prefixCount {
checkPath := c.Path()[prefixCount:]
checkPathTrimmed := checkPath
if !c.App().Config().StrictRouting {
checkPathTrimmed = utils.TrimRight(checkPath, '/')
}
switch {
case checkPath == cfg.ReadinessEndpoint || checkPathTrimmed == cfg.ReadinessEndpoint:
return isReadyHandler(c)
case checkPath == cfg.LivenessEndpoint || checkPathTrimmed == cfg.LivenessEndpoint:
return isLiveHandler(c)
}
}
return c.Next()
}
}

View file

@ -0,0 +1,237 @@
package healthcheck
import (
"fmt"
"net/http/httptest"
"testing"
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/utils"
"github.com/valyala/fasthttp"
)
func shouldGiveStatus(t *testing.T, app *fiber.App, path string, expectedStatus int) {
t.Helper()
req, err := app.Test(httptest.NewRequest(fiber.MethodGet, path, nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, expectedStatus, req.StatusCode, "path: "+path+" should match "+fmt.Sprint(expectedStatus))
}
func shouldGiveOK(t *testing.T, app *fiber.App, path string) {
t.Helper()
shouldGiveStatus(t, app, path, fiber.StatusOK)
}
func shouldGiveNotFound(t *testing.T, app *fiber.App, path string) {
t.Helper()
shouldGiveStatus(t, app, path, fiber.StatusNotFound)
}
func Test_HealthCheck_Strict_Routing_Default(t *testing.T) {
t.Parallel()
app := fiber.New(fiber.Config{
StrictRouting: true,
})
app.Use(New())
shouldGiveOK(t, app, "/readyz")
shouldGiveOK(t, app, "/livez")
shouldGiveNotFound(t, app, "/readyz/")
shouldGiveNotFound(t, app, "/livez/")
shouldGiveNotFound(t, app, "/notDefined/readyz")
shouldGiveNotFound(t, app, "/notDefined/livez")
}
func Test_HealthCheck_Group_Default(t *testing.T) {
t.Parallel()
app := fiber.New()
app.Group("/v1", New())
v2Group := app.Group("/v2/")
customer := v2Group.Group("/customer/")
customer.Use(New())
v3Group := app.Group("/v3/")
v3Group.Group("/todos/", New(Config{ReadinessEndpoint: "/readyz/", LivenessEndpoint: "/livez/"}))
shouldGiveOK(t, app, "/v1/readyz")
shouldGiveOK(t, app, "/v1/livez")
shouldGiveOK(t, app, "/v1/readyz/")
shouldGiveOK(t, app, "/v1/livez/")
shouldGiveOK(t, app, "/v2/customer/readyz")
shouldGiveOK(t, app, "/v2/customer/livez")
shouldGiveOK(t, app, "/v2/customer/readyz/")
shouldGiveOK(t, app, "/v2/customer/livez/")
shouldGiveNotFound(t, app, "/v3/todos/readyz")
shouldGiveNotFound(t, app, "/v3/todos/livez")
shouldGiveOK(t, app, "/v3/todos/readyz/")
shouldGiveOK(t, app, "/v3/todos/livez/")
shouldGiveNotFound(t, app, "/notDefined/readyz")
shouldGiveNotFound(t, app, "/notDefined/livez")
shouldGiveNotFound(t, app, "/notDefined/readyz/")
shouldGiveNotFound(t, app, "/notDefined/livez/")
// strict routing
app = fiber.New(fiber.Config{
StrictRouting: true,
})
app.Group("/v1", New())
v2Group = app.Group("/v2/")
customer = v2Group.Group("/customer/")
customer.Use(New())
v3Group = app.Group("/v3/")
v3Group.Group("/todos/", New(Config{ReadinessEndpoint: "/readyz/", LivenessEndpoint: "/livez/"}))
shouldGiveOK(t, app, "/v1/readyz")
shouldGiveOK(t, app, "/v1/livez")
shouldGiveNotFound(t, app, "/v1/readyz/")
shouldGiveNotFound(t, app, "/v1/livez/")
shouldGiveOK(t, app, "/v2/customer/readyz")
shouldGiveOK(t, app, "/v2/customer/livez")
shouldGiveNotFound(t, app, "/v2/customer/readyz/")
shouldGiveNotFound(t, app, "/v2/customer/livez/")
shouldGiveNotFound(t, app, "/v3/todos/readyz")
shouldGiveNotFound(t, app, "/v3/todos/livez")
shouldGiveOK(t, app, "/v3/todos/readyz/")
shouldGiveOK(t, app, "/v3/todos/livez/")
shouldGiveNotFound(t, app, "/notDefined/readyz")
shouldGiveNotFound(t, app, "/notDefined/livez")
shouldGiveNotFound(t, app, "/notDefined/readyz/")
shouldGiveNotFound(t, app, "/notDefined/livez/")
}
func Test_HealthCheck_Default(t *testing.T) {
t.Parallel()
app := fiber.New()
app.Use(New())
shouldGiveOK(t, app, "/readyz")
shouldGiveOK(t, app, "/livez")
shouldGiveOK(t, app, "/readyz/")
shouldGiveOK(t, app, "/livez/")
shouldGiveNotFound(t, app, "/notDefined/readyz")
shouldGiveNotFound(t, app, "/notDefined/livez")
}
func Test_HealthCheck_Custom(t *testing.T) {
t.Parallel()
app := fiber.New()
c1 := make(chan struct{}, 1)
app.Use(New(Config{
LivenessProbe: func(c *fiber.Ctx) bool {
return true
},
LivenessEndpoint: "/live",
ReadinessProbe: func(c *fiber.Ctx) bool {
select {
case <-c1:
return true
default:
return false
}
},
ReadinessEndpoint: "/ready",
}))
// Live should return 200 with GET request
shouldGiveOK(t, app, "/live")
// Live should return 404 with POST request
req, err := app.Test(httptest.NewRequest(fiber.MethodPost, "/live", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusNotFound, req.StatusCode)
// Ready should return 404 with POST request
req, err = app.Test(httptest.NewRequest(fiber.MethodPost, "/ready", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusNotFound, req.StatusCode)
// Ready should return 503 with GET request before the channel is closed
shouldGiveStatus(t, app, "/ready", fiber.StatusServiceUnavailable)
// Ready should return 200 with GET request after the channel is closed
c1 <- struct{}{}
shouldGiveOK(t, app, "/ready")
}
func Test_HealthCheck_Custom_Nested(t *testing.T) {
t.Parallel()
app := fiber.New()
c1 := make(chan struct{}, 1)
app.Use(New(Config{
LivenessProbe: func(c *fiber.Ctx) bool {
return true
},
LivenessEndpoint: "/probe/live",
ReadinessProbe: func(c *fiber.Ctx) bool {
select {
case <-c1:
return true
default:
return false
}
},
ReadinessEndpoint: "/probe/ready",
}))
shouldGiveOK(t, app, "/probe/live")
shouldGiveStatus(t, app, "/probe/ready", fiber.StatusServiceUnavailable)
shouldGiveOK(t, app, "/probe/live/")
shouldGiveStatus(t, app, "/probe/ready/", fiber.StatusServiceUnavailable)
shouldGiveNotFound(t, app, "/probe/livez")
shouldGiveNotFound(t, app, "/probe/readyz")
shouldGiveNotFound(t, app, "/probe/livez/")
shouldGiveNotFound(t, app, "/probe/readyz/")
shouldGiveNotFound(t, app, "/livez")
shouldGiveNotFound(t, app, "/readyz")
shouldGiveNotFound(t, app, "/readyz/")
shouldGiveNotFound(t, app, "/livez/")
c1 <- struct{}{}
shouldGiveOK(t, app, "/probe/ready")
c1 <- struct{}{}
shouldGiveOK(t, app, "/probe/ready/")
}
func Test_HealthCheck_Next(t *testing.T) {
t.Parallel()
app := fiber.New()
app.Use(New(Config{
Next: func(c *fiber.Ctx) bool {
return true
},
}))
shouldGiveNotFound(t, app, "/readyz")
shouldGiveNotFound(t, app, "/livez")
}
func Benchmark_HealthCheck(b *testing.B) {
app := fiber.New()
app.Use(New())
h := app.Handler()
fctx := &fasthttp.RequestCtx{}
fctx.Request.Header.SetMethod(fiber.MethodGet)
fctx.Request.SetRequestURI("/livez")
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
h(fctx)
}
utils.AssertEqual(b, fiber.StatusOK, fctx.Response.Header.StatusCode())
}

154
middleware/helmet/config.go Normal file
View file

@ -0,0 +1,154 @@
package helmet
import (
"github.com/gofiber/fiber/v2"
)
// Config defines the config for middleware.
type Config struct {
// Next defines a function to skip middleware.
// Optional. Default: nil
Next func(*fiber.Ctx) bool
// XSSProtection
// Optional. Default value "0".
XSSProtection string
// ContentTypeNosniff
// Optional. Default value "nosniff".
ContentTypeNosniff string
// XFrameOptions
// Optional. Default value "SAMEORIGIN".
// Possible values: "SAMEORIGIN", "DENY", "ALLOW-FROM uri"
XFrameOptions string
// HSTSMaxAge
// Optional. Default value 0.
HSTSMaxAge int
// HSTSExcludeSubdomains
// Optional. Default value false.
HSTSExcludeSubdomains bool
// ContentSecurityPolicy
// Optional. Default value "".
ContentSecurityPolicy string
// CSPReportOnly
// Optional. Default value false.
CSPReportOnly bool
// HSTSPreloadEnabled
// Optional. Default value false.
HSTSPreloadEnabled bool
// ReferrerPolicy
// Optional. Default value "ReferrerPolicy".
ReferrerPolicy string
// Permissions-Policy
// Optional. Default value "".
PermissionPolicy string
// Cross-Origin-Embedder-Policy
// Optional. Default value "require-corp".
CrossOriginEmbedderPolicy string
// Cross-Origin-Opener-Policy
// Optional. Default value "same-origin".
CrossOriginOpenerPolicy string
// Cross-Origin-Resource-Policy
// Optional. Default value "same-origin".
CrossOriginResourcePolicy string
// Origin-Agent-Cluster
// Optional. Default value "?1".
OriginAgentCluster string
// X-DNS-Prefetch-Control
// Optional. Default value "off".
XDNSPrefetchControl string
// X-Download-Options
// Optional. Default value "noopen".
XDownloadOptions string
// X-Permitted-Cross-Domain-Policies
// Optional. Default value "none".
XPermittedCrossDomain string
}
// ConfigDefault is the default config
var ConfigDefault = Config{
XSSProtection: "0",
ContentTypeNosniff: "nosniff",
XFrameOptions: "SAMEORIGIN",
ReferrerPolicy: "no-referrer",
CrossOriginEmbedderPolicy: "require-corp",
CrossOriginOpenerPolicy: "same-origin",
CrossOriginResourcePolicy: "same-origin",
OriginAgentCluster: "?1",
XDNSPrefetchControl: "off",
XDownloadOptions: "noopen",
XPermittedCrossDomain: "none",
}
// Helper function to set default values
func configDefault(config ...Config) Config {
// Return default config if nothing provided
if len(config) < 1 {
return ConfigDefault
}
// Override default config
cfg := config[0]
// Set default values
if cfg.XSSProtection == "" {
cfg.XSSProtection = ConfigDefault.XSSProtection
}
if cfg.ContentTypeNosniff == "" {
cfg.ContentTypeNosniff = ConfigDefault.ContentTypeNosniff
}
if cfg.XFrameOptions == "" {
cfg.XFrameOptions = ConfigDefault.XFrameOptions
}
if cfg.ReferrerPolicy == "" {
cfg.ReferrerPolicy = ConfigDefault.ReferrerPolicy
}
if cfg.CrossOriginEmbedderPolicy == "" {
cfg.CrossOriginEmbedderPolicy = ConfigDefault.CrossOriginEmbedderPolicy
}
if cfg.CrossOriginOpenerPolicy == "" {
cfg.CrossOriginOpenerPolicy = ConfigDefault.CrossOriginOpenerPolicy
}
if cfg.CrossOriginResourcePolicy == "" {
cfg.CrossOriginResourcePolicy = ConfigDefault.CrossOriginResourcePolicy
}
if cfg.OriginAgentCluster == "" {
cfg.OriginAgentCluster = ConfigDefault.OriginAgentCluster
}
if cfg.XDNSPrefetchControl == "" {
cfg.XDNSPrefetchControl = ConfigDefault.XDNSPrefetchControl
}
if cfg.XDownloadOptions == "" {
cfg.XDownloadOptions = ConfigDefault.XDownloadOptions
}
if cfg.XPermittedCrossDomain == "" {
cfg.XPermittedCrossDomain = ConfigDefault.XPermittedCrossDomain
}
return cfg
}

View file

@ -0,0 +1,94 @@
package helmet
import (
"fmt"
"github.com/gofiber/fiber/v2"
)
// New creates a new middleware handler
func New(config ...Config) fiber.Handler {
// Init config
cfg := configDefault(config...)
// Return middleware handler
return func(c *fiber.Ctx) error {
// Next request to skip middleware
if cfg.Next != nil && cfg.Next(c) {
return c.Next()
}
// Set headers
if cfg.XSSProtection != "" {
c.Set(fiber.HeaderXXSSProtection, cfg.XSSProtection)
}
if cfg.ContentTypeNosniff != "" {
c.Set(fiber.HeaderXContentTypeOptions, cfg.ContentTypeNosniff)
}
if cfg.XFrameOptions != "" {
c.Set(fiber.HeaderXFrameOptions, cfg.XFrameOptions)
}
if cfg.CrossOriginEmbedderPolicy != "" {
c.Set("Cross-Origin-Embedder-Policy", cfg.CrossOriginEmbedderPolicy)
}
if cfg.CrossOriginOpenerPolicy != "" {
c.Set("Cross-Origin-Opener-Policy", cfg.CrossOriginOpenerPolicy)
}
if cfg.CrossOriginResourcePolicy != "" {
c.Set("Cross-Origin-Resource-Policy", cfg.CrossOriginResourcePolicy)
}
if cfg.OriginAgentCluster != "" {
c.Set("Origin-Agent-Cluster", cfg.OriginAgentCluster)
}
if cfg.ReferrerPolicy != "" {
c.Set("Referrer-Policy", cfg.ReferrerPolicy)
}
if cfg.XDNSPrefetchControl != "" {
c.Set("X-DNS-Prefetch-Control", cfg.XDNSPrefetchControl)
}
if cfg.XDownloadOptions != "" {
c.Set("X-Download-Options", cfg.XDownloadOptions)
}
if cfg.XPermittedCrossDomain != "" {
c.Set("X-Permitted-Cross-Domain-Policies", cfg.XPermittedCrossDomain)
}
// Handle HSTS headers
if c.Protocol() == "https" && cfg.HSTSMaxAge != 0 {
subdomains := ""
if !cfg.HSTSExcludeSubdomains {
subdomains = "; includeSubDomains"
}
if cfg.HSTSPreloadEnabled {
subdomains = fmt.Sprintf("%s; preload", subdomains)
}
c.Set(fiber.HeaderStrictTransportSecurity, fmt.Sprintf("max-age=%d%s", cfg.HSTSMaxAge, subdomains))
}
// Handle Content-Security-Policy headers
if cfg.ContentSecurityPolicy != "" {
if cfg.CSPReportOnly {
c.Set(fiber.HeaderContentSecurityPolicyReportOnly, cfg.ContentSecurityPolicy)
} else {
c.Set(fiber.HeaderContentSecurityPolicy, cfg.ContentSecurityPolicy)
}
}
// Handle Permissions-Policy headers
if cfg.PermissionPolicy != "" {
c.Set(fiber.HeaderPermissionsPolicy, cfg.PermissionPolicy)
}
return c.Next()
}
}

View file

@ -0,0 +1,201 @@
package helmet
import (
"net/http/httptest"
"testing"
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/utils"
)
func Test_Default(t *testing.T) {
app := fiber.New()
app.Use(New())
app.Get("/", func(c *fiber.Ctx) error {
return c.SendString("Hello, World!")
})
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, "0", resp.Header.Get(fiber.HeaderXXSSProtection))
utils.AssertEqual(t, "nosniff", resp.Header.Get(fiber.HeaderXContentTypeOptions))
utils.AssertEqual(t, "SAMEORIGIN", resp.Header.Get(fiber.HeaderXFrameOptions))
utils.AssertEqual(t, "", resp.Header.Get(fiber.HeaderContentSecurityPolicy))
utils.AssertEqual(t, "no-referrer", resp.Header.Get(fiber.HeaderReferrerPolicy))
utils.AssertEqual(t, "", resp.Header.Get(fiber.HeaderPermissionsPolicy))
utils.AssertEqual(t, "require-corp", resp.Header.Get("Cross-Origin-Embedder-Policy"))
utils.AssertEqual(t, "same-origin", resp.Header.Get("Cross-Origin-Opener-Policy"))
utils.AssertEqual(t, "same-origin", resp.Header.Get("Cross-Origin-Resource-Policy"))
utils.AssertEqual(t, "?1", resp.Header.Get("Origin-Agent-Cluster"))
utils.AssertEqual(t, "off", resp.Header.Get("X-DNS-Prefetch-Control"))
utils.AssertEqual(t, "noopen", resp.Header.Get("X-Download-Options"))
utils.AssertEqual(t, "none", resp.Header.Get("X-Permitted-Cross-Domain-Policies"))
}
func Test_CustomValues_AllHeaders(t *testing.T) {
app := fiber.New()
app.Use(New(Config{
// Custom values for all headers
XSSProtection: "0",
ContentTypeNosniff: "custom-nosniff",
XFrameOptions: "DENY",
HSTSExcludeSubdomains: true,
ContentSecurityPolicy: "default-src 'none'",
CSPReportOnly: true,
HSTSPreloadEnabled: true,
ReferrerPolicy: "origin",
PermissionPolicy: "geolocation=(self)",
CrossOriginEmbedderPolicy: "custom-value",
CrossOriginOpenerPolicy: "custom-value",
CrossOriginResourcePolicy: "custom-value",
OriginAgentCluster: "custom-value",
XDNSPrefetchControl: "custom-control",
XDownloadOptions: "custom-options",
XPermittedCrossDomain: "custom-policies",
}))
app.Get("/", func(c *fiber.Ctx) error {
return c.SendString("Hello, World!")
})
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err)
// Assertions for custom header values
utils.AssertEqual(t, "0", resp.Header.Get(fiber.HeaderXXSSProtection))
utils.AssertEqual(t, "custom-nosniff", resp.Header.Get(fiber.HeaderXContentTypeOptions))
utils.AssertEqual(t, "DENY", resp.Header.Get(fiber.HeaderXFrameOptions))
utils.AssertEqual(t, "default-src 'none'", resp.Header.Get(fiber.HeaderContentSecurityPolicyReportOnly))
utils.AssertEqual(t, "origin", resp.Header.Get(fiber.HeaderReferrerPolicy))
utils.AssertEqual(t, "geolocation=(self)", resp.Header.Get(fiber.HeaderPermissionsPolicy))
utils.AssertEqual(t, "custom-value", resp.Header.Get("Cross-Origin-Embedder-Policy"))
utils.AssertEqual(t, "custom-value", resp.Header.Get("Cross-Origin-Opener-Policy"))
utils.AssertEqual(t, "custom-value", resp.Header.Get("Cross-Origin-Resource-Policy"))
utils.AssertEqual(t, "custom-value", resp.Header.Get("Origin-Agent-Cluster"))
utils.AssertEqual(t, "custom-control", resp.Header.Get("X-DNS-Prefetch-Control"))
utils.AssertEqual(t, "custom-options", resp.Header.Get("X-Download-Options"))
utils.AssertEqual(t, "custom-policies", resp.Header.Get("X-Permitted-Cross-Domain-Policies"))
}
func Test_RealWorldValues_AllHeaders(t *testing.T) {
app := fiber.New()
app.Use(New(Config{
// Real-world values for all headers
XSSProtection: "0",
ContentTypeNosniff: "nosniff",
XFrameOptions: "SAMEORIGIN",
HSTSExcludeSubdomains: false,
ContentSecurityPolicy: "default-src 'self';base-uri 'self';font-src 'self' https: data:;form-action 'self';frame-ancestors 'self';img-src 'self' data:;object-src 'none';script-src 'self';script-src-attr 'none';style-src 'self' https: 'unsafe-inline';upgrade-insecure-requests",
CSPReportOnly: false,
HSTSPreloadEnabled: true,
ReferrerPolicy: "no-referrer",
PermissionPolicy: "geolocation=(self)",
CrossOriginEmbedderPolicy: "require-corp",
CrossOriginOpenerPolicy: "same-origin",
CrossOriginResourcePolicy: "same-origin",
OriginAgentCluster: "?1",
XDNSPrefetchControl: "off",
XDownloadOptions: "noopen",
XPermittedCrossDomain: "none",
}))
app.Get("/", func(c *fiber.Ctx) error {
return c.SendString("Hello, World!")
})
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err)
// Assertions for real-world header values
utils.AssertEqual(t, "0", resp.Header.Get(fiber.HeaderXXSSProtection))
utils.AssertEqual(t, "nosniff", resp.Header.Get(fiber.HeaderXContentTypeOptions))
utils.AssertEqual(t, "SAMEORIGIN", resp.Header.Get(fiber.HeaderXFrameOptions))
utils.AssertEqual(t, "default-src 'self';base-uri 'self';font-src 'self' https: data:;form-action 'self';frame-ancestors 'self';img-src 'self' data:;object-src 'none';script-src 'self';script-src-attr 'none';style-src 'self' https: 'unsafe-inline';upgrade-insecure-requests", resp.Header.Get(fiber.HeaderContentSecurityPolicy))
utils.AssertEqual(t, "no-referrer", resp.Header.Get(fiber.HeaderReferrerPolicy))
utils.AssertEqual(t, "geolocation=(self)", resp.Header.Get(fiber.HeaderPermissionsPolicy))
utils.AssertEqual(t, "require-corp", resp.Header.Get("Cross-Origin-Embedder-Policy"))
utils.AssertEqual(t, "same-origin", resp.Header.Get("Cross-Origin-Opener-Policy"))
utils.AssertEqual(t, "same-origin", resp.Header.Get("Cross-Origin-Resource-Policy"))
utils.AssertEqual(t, "?1", resp.Header.Get("Origin-Agent-Cluster"))
utils.AssertEqual(t, "off", resp.Header.Get("X-DNS-Prefetch-Control"))
utils.AssertEqual(t, "noopen", resp.Header.Get("X-Download-Options"))
utils.AssertEqual(t, "none", resp.Header.Get("X-Permitted-Cross-Domain-Policies"))
}
func Test_Next(t *testing.T) {
app := fiber.New()
app.Use(New(Config{
Next: func(ctx *fiber.Ctx) bool {
return ctx.Path() == "/next"
},
ReferrerPolicy: "no-referrer",
}))
app.Get("/", func(c *fiber.Ctx) error {
return c.SendString("Hello, World!")
})
app.Get("/next", func(c *fiber.Ctx) error {
return c.SendString("Skipped!")
})
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, "no-referrer", resp.Header.Get(fiber.HeaderReferrerPolicy))
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/next", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, "", resp.Header.Get(fiber.HeaderReferrerPolicy))
}
func Test_ContentSecurityPolicy(t *testing.T) {
app := fiber.New()
app.Use(New(Config{
ContentSecurityPolicy: "default-src 'none'",
}))
app.Get("/", func(c *fiber.Ctx) error {
return c.SendString("Hello, World!")
})
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, "default-src 'none'", resp.Header.Get(fiber.HeaderContentSecurityPolicy))
}
func Test_ContentSecurityPolicyReportOnly(t *testing.T) {
app := fiber.New()
app.Use(New(Config{
ContentSecurityPolicy: "default-src 'none'",
CSPReportOnly: true,
}))
app.Get("/", func(c *fiber.Ctx) error {
return c.SendString("Hello, World!")
})
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, "default-src 'none'", resp.Header.Get(fiber.HeaderContentSecurityPolicyReportOnly))
utils.AssertEqual(t, "", resp.Header.Get(fiber.HeaderContentSecurityPolicy))
}
func Test_PermissionsPolicy(t *testing.T) {
app := fiber.New()
app.Use(New(Config{
PermissionPolicy: "microphone=()",
}))
app.Get("/", func(c *fiber.Ctx) error {
return c.SendString("Hello, World!")
})
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, "microphone=()", resp.Header.Get(fiber.HeaderPermissionsPolicy))
}

View file

@ -0,0 +1,125 @@
package idempotency
import (
"errors"
"fmt"
"time"
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/internal/storage/memory"
)
var ErrInvalidIdempotencyKey = errors.New("invalid idempotency key")
// Config defines the config for middleware.
type Config struct {
// Next defines a function to skip this middleware when returned true.
//
// Optional. Default: a function which skips the middleware on safe HTTP request method.
Next func(c *fiber.Ctx) bool
// Lifetime is the maximum lifetime of an idempotency key.
//
// Optional. Default: 30 * time.Minute
Lifetime time.Duration
// KeyHeader is the name of the header that contains the idempotency key.
//
// Optional. Default: X-Idempotency-Key
KeyHeader string
// KeyHeaderValidate defines a function to validate the syntax of the idempotency header.
//
// Optional. Default: a function which ensures the header is 36 characters long (the size of an UUID).
KeyHeaderValidate func(string) error
// KeepResponseHeaders is a list of headers that should be kept from the original response.
//
// Optional. Default: nil (to keep all headers)
KeepResponseHeaders []string
// Lock locks an idempotency key.
//
// Optional. Default: an in-memory locker for this process only.
Lock Locker
// Storage stores response data by idempotency key.
//
// Optional. Default: an in-memory storage for this process only.
Storage fiber.Storage
}
// ConfigDefault is the default config
var ConfigDefault = Config{
Next: func(c *fiber.Ctx) bool {
// Skip middleware if the request was done using a safe HTTP method
return fiber.IsMethodSafe(c.Method())
},
Lifetime: 30 * time.Minute,
KeyHeader: "X-Idempotency-Key",
KeyHeaderValidate: func(k string) error {
if l, wl := len(k), 36; l != wl { // UUID length is 36 chars
return fmt.Errorf("%w: invalid length: %d != %d", ErrInvalidIdempotencyKey, l, wl)
}
return nil
},
KeepResponseHeaders: nil,
Lock: nil, // Set in configDefault so we don't allocate data here.
Storage: nil, // Set in configDefault so we don't allocate data here.
}
// Helper function to set default values
func configDefault(config ...Config) Config {
// Return default config if nothing provided
if len(config) < 1 {
cfg := ConfigDefault
cfg.Lock = NewMemoryLock()
cfg.Storage = memory.New(memory.Config{
GCInterval: cfg.Lifetime / 2, // Half the lifetime interval
})
return cfg
}
// Override default config
cfg := config[0]
// Set default values
if cfg.Next == nil {
cfg.Next = ConfigDefault.Next
}
if cfg.Lifetime.Nanoseconds() == 0 {
cfg.Lifetime = ConfigDefault.Lifetime
}
if cfg.KeyHeader == "" {
cfg.KeyHeader = ConfigDefault.KeyHeader
}
if cfg.KeyHeaderValidate == nil {
cfg.KeyHeaderValidate = ConfigDefault.KeyHeaderValidate
}
if cfg.KeepResponseHeaders != nil && len(cfg.KeepResponseHeaders) == 0 {
cfg.KeepResponseHeaders = ConfigDefault.KeepResponseHeaders
}
if cfg.Lock == nil {
cfg.Lock = NewMemoryLock()
}
if cfg.Storage == nil {
cfg.Storage = memory.New(memory.Config{
GCInterval: cfg.Lifetime / 2, // Half the lifetime interval
})
}
return cfg
}

View file

@ -0,0 +1,153 @@
package idempotency
import (
"fmt"
"strings"
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/log"
"github.com/gofiber/fiber/v2/utils"
)
// Inspired by https://datatracker.ietf.org/doc/html/draft-ietf-httpapi-idempotency-key-header-02
// and https://github.com/penguin-statistics/backend-next/blob/f2f7d5ba54fc8a58f168d153baa17b2ad4a14e45/internal/pkg/middlewares/idempotency.go
type localsKeys string
const (
localsKeyIsFromCache localsKeys = "idempotency_isfromcache"
localsKeyWasPutToCache localsKeys = "idempotency_wasputtocache"
)
func IsFromCache(c *fiber.Ctx) bool {
return c.Locals(localsKeyIsFromCache) != nil
}
func WasPutToCache(c *fiber.Ctx) bool {
return c.Locals(localsKeyWasPutToCache) != nil
}
func New(config ...Config) fiber.Handler {
// Set default config
cfg := configDefault(config...)
keepResponseHeadersMap := make(map[string]struct{}, len(cfg.KeepResponseHeaders))
for _, h := range cfg.KeepResponseHeaders {
keepResponseHeadersMap[strings.ToLower(h)] = struct{}{}
}
maybeWriteCachedResponse := func(c *fiber.Ctx, key string) (bool, error) {
if val, err := cfg.Storage.Get(key); err != nil {
return false, fmt.Errorf("failed to read response: %w", err)
} else if val != nil {
var res response
if _, err := res.UnmarshalMsg(val); err != nil {
return false, fmt.Errorf("failed to unmarshal response: %w", err)
}
_ = c.Status(res.StatusCode)
for header, vals := range res.Headers {
for _, val := range vals {
c.Context().Response.Header.Add(header, val)
}
}
if len(res.Body) != 0 {
if err := c.Send(res.Body); err != nil {
return true, err
}
}
_ = c.Locals(localsKeyIsFromCache, true)
return true, nil
}
return false, nil
}
return func(c *fiber.Ctx) error {
// Don't execute middleware if Next returns true
if cfg.Next != nil && cfg.Next(c) {
return c.Next()
}
// Don't execute middleware if the idempotency key is empty
key := utils.CopyString(c.Get(cfg.KeyHeader))
if key == "" {
return c.Next()
}
// Validate key
if err := cfg.KeyHeaderValidate(key); err != nil {
return err
}
// First-pass: if the idempotency key is in the storage, get and return the response
if ok, err := maybeWriteCachedResponse(c, key); err != nil {
return fmt.Errorf("failed to write cached response at fastpath: %w", err)
} else if ok {
return nil
}
if err := cfg.Lock.Lock(key); err != nil {
return fmt.Errorf("failed to lock: %w", err)
}
defer func() {
if err := cfg.Lock.Unlock(key); err != nil {
log.Errorf("[IDEMPOTENCY] failed to unlock key %q: %v", key, err)
}
}()
// Lock acquired. If the idempotency key now is in the storage, get and return the response
if ok, err := maybeWriteCachedResponse(c, key); err != nil {
return fmt.Errorf("failed to write cached response while locked: %w", err)
} else if ok {
return nil
}
// Execute the request handler
if err := c.Next(); err != nil {
// If the request handler returned an error, return it and skip idempotency
return err
}
// Construct response
res := &response{
StatusCode: c.Response().StatusCode(),
Body: utils.CopyBytes(c.Response().Body()),
}
{
headers := c.GetRespHeaders()
if cfg.KeepResponseHeaders == nil {
// Keep all
res.Headers = headers
} else {
// Filter
res.Headers = make(map[string][]string)
for h := range headers {
if _, ok := keepResponseHeadersMap[utils.ToLower(h)]; ok {
res.Headers[h] = headers[h]
}
}
}
}
// Marshal response
bs, err := res.MarshalMsg(nil)
if err != nil {
return fmt.Errorf("failed to marshal response: %w", err)
}
// Store response
if err := cfg.Storage.Set(key, bs, cfg.Lifetime); err != nil {
return fmt.Errorf("failed to save response: %w", err)
}
_ = c.Locals(localsKeyWasPutToCache, true)
return nil
}
}

View file

@ -0,0 +1,177 @@
//nolint:bodyclose // Much easier to just ignore memory leaks in tests
package idempotency_test
import (
"errors"
"io"
"net/http"
"net/http/httptest"
"strconv"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/middleware/idempotency"
"github.com/gofiber/fiber/v2/utils"
"github.com/valyala/fasthttp"
)
// go test -run Test_Idempotency
func Test_Idempotency(t *testing.T) {
t.Parallel()
app := fiber.New()
app.Use(func(c *fiber.Ctx) error {
if err := c.Next(); err != nil {
return err
}
isMethodSafe := fiber.IsMethodSafe(c.Method())
isIdempotent := idempotency.IsFromCache(c) || idempotency.WasPutToCache(c)
hasReqHeader := c.Get("X-Idempotency-Key") != ""
if isMethodSafe {
if isIdempotent {
return errors.New("request with safe HTTP method should not be idempotent")
}
} else {
// Unsafe
if hasReqHeader {
if !isIdempotent {
return errors.New("request with unsafe HTTP method should be idempotent if X-Idempotency-Key request header is set")
}
} else {
// No request header
if isIdempotent {
return errors.New("request with unsafe HTTP method should not be idempotent if X-Idempotency-Key request header is not set")
}
}
}
return nil
})
// Needs to be at least a second as the memory storage doesn't support shorter durations.
const lifetime = 1 * time.Second
app.Use(idempotency.New(idempotency.Config{
Lifetime: lifetime,
}))
nextCount := func() func() int {
var count int32
return func() int {
return int(atomic.AddInt32(&count, 1))
}
}()
{
handler := func(c *fiber.Ctx) error {
return c.SendString(strconv.Itoa(nextCount()))
}
app.Get("/", handler)
app.Post("/", handler)
}
app.Post("/slow", func(c *fiber.Ctx) error {
time.Sleep(2 * lifetime)
return c.SendString(strconv.Itoa(nextCount()))
})
doReq := func(method, route, idempotencyKey string) string {
req := httptest.NewRequest(method, route, http.NoBody)
if idempotencyKey != "" {
req.Header.Set("X-Idempotency-Key", idempotencyKey)
}
resp, err := app.Test(req, 3*int(lifetime.Milliseconds()))
utils.AssertEqual(t, nil, err)
body, err := io.ReadAll(resp.Body)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode, string(body))
return string(body)
}
utils.AssertEqual(t, "1", doReq(fiber.MethodGet, "/", ""))
utils.AssertEqual(t, "2", doReq(fiber.MethodGet, "/", ""))
utils.AssertEqual(t, "3", doReq(fiber.MethodPost, "/", ""))
utils.AssertEqual(t, "4", doReq(fiber.MethodPost, "/", ""))
utils.AssertEqual(t, "5", doReq(fiber.MethodGet, "/", "00000000-0000-0000-0000-000000000000"))
utils.AssertEqual(t, "6", doReq(fiber.MethodGet, "/", "00000000-0000-0000-0000-000000000000"))
utils.AssertEqual(t, "7", doReq(fiber.MethodPost, "/", "00000000-0000-0000-0000-000000000000"))
utils.AssertEqual(t, "7", doReq(fiber.MethodPost, "/", "00000000-0000-0000-0000-000000000000"))
utils.AssertEqual(t, "8", doReq(fiber.MethodPost, "/", ""))
utils.AssertEqual(t, "9", doReq(fiber.MethodPost, "/", "11111111-1111-1111-1111-111111111111"))
utils.AssertEqual(t, "7", doReq(fiber.MethodPost, "/", "00000000-0000-0000-0000-000000000000"))
time.Sleep(2 * lifetime)
utils.AssertEqual(t, "10", doReq(fiber.MethodPost, "/", "00000000-0000-0000-0000-000000000000"))
utils.AssertEqual(t, "10", doReq(fiber.MethodPost, "/", "00000000-0000-0000-0000-000000000000"))
// Test raciness
{
var wg sync.WaitGroup
for i := 0; i < 100; i++ {
wg.Add(1)
go func() {
defer wg.Done()
utils.AssertEqual(t, "11", doReq(fiber.MethodPost, "/slow", "22222222-2222-2222-2222-222222222222"))
}()
}
wg.Wait()
utils.AssertEqual(t, "11", doReq(fiber.MethodPost, "/slow", "22222222-2222-2222-2222-222222222222"))
}
time.Sleep(2 * lifetime)
utils.AssertEqual(t, "12", doReq(fiber.MethodPost, "/slow", "22222222-2222-2222-2222-222222222222"))
}
// go test -v -run=^$ -bench=Benchmark_Idempotency -benchmem -count=4
func Benchmark_Idempotency(b *testing.B) {
app := fiber.New()
// Needs to be at least a second as the memory storage doesn't support shorter durations.
const lifetime = 1 * time.Second
app.Use(idempotency.New(idempotency.Config{
Lifetime: lifetime,
}))
app.Post("/", func(c *fiber.Ctx) error {
return nil
})
h := app.Handler()
b.Run("hit", func(b *testing.B) {
c := &fasthttp.RequestCtx{}
c.Request.Header.SetMethod(fiber.MethodPost)
c.Request.SetRequestURI("/")
c.Request.Header.Set("X-Idempotency-Key", "00000000-0000-0000-0000-000000000000")
b.ReportAllocs()
b.ResetTimer()
for n := 0; n < b.N; n++ {
h(c)
}
})
b.Run("skip", func(b *testing.B) {
c := &fasthttp.RequestCtx{}
c.Request.Header.SetMethod(fiber.MethodPost)
c.Request.SetRequestURI("/")
b.ReportAllocs()
b.ResetTimer()
for n := 0; n < b.N; n++ {
h(c)
}
})
}

View file

@ -0,0 +1,53 @@
package idempotency
import (
"sync"
)
// Locker implements a spinlock for a string key.
type Locker interface {
Lock(key string) error
Unlock(key string) error
}
type MemoryLock struct {
mu sync.Mutex
keys map[string]*sync.Mutex
}
func (l *MemoryLock) Lock(key string) error {
l.mu.Lock()
mu, ok := l.keys[key]
if !ok {
mu = new(sync.Mutex)
l.keys[key] = mu
}
l.mu.Unlock()
mu.Lock()
return nil
}
func (l *MemoryLock) Unlock(key string) error {
l.mu.Lock()
mu, ok := l.keys[key]
l.mu.Unlock()
if !ok {
// This happens if we try to unlock an unknown key
return nil
}
mu.Unlock()
return nil
}
func NewMemoryLock() *MemoryLock {
return &MemoryLock{
keys: make(map[string]*sync.Mutex),
}
}
var _ Locker = (*MemoryLock)(nil)

View file

@ -0,0 +1,59 @@
package idempotency_test
import (
"testing"
"time"
"github.com/gofiber/fiber/v2/middleware/idempotency"
"github.com/gofiber/fiber/v2/utils"
)
// go test -run Test_MemoryLock
func Test_MemoryLock(t *testing.T) {
t.Parallel()
l := idempotency.NewMemoryLock()
{
err := l.Lock("a")
utils.AssertEqual(t, nil, err)
}
{
done := make(chan struct{})
go func() {
defer close(done)
err := l.Lock("a")
utils.AssertEqual(t, nil, err)
}()
select {
case <-done:
t.Fatal("lock acquired again")
case <-time.After(time.Second):
}
}
{
err := l.Lock("b")
utils.AssertEqual(t, nil, err)
}
{
err := l.Unlock("b")
utils.AssertEqual(t, nil, err)
}
{
err := l.Lock("b")
utils.AssertEqual(t, nil, err)
}
{
err := l.Unlock("c")
utils.AssertEqual(t, nil, err)
}
{
err := l.Lock("d")
utils.AssertEqual(t, nil, err)
}
}

View file

@ -0,0 +1,10 @@
package idempotency
//go:generate msgp -o=response_msgp.go -io=false -unexported
type response struct {
StatusCode int `msg:"sc"`
Headers map[string][]string `msg:"hs"`
Body []byte `msg:"b"`
}

View file

@ -0,0 +1,131 @@
package idempotency
// Code generated by github.com/tinylib/msgp DO NOT EDIT.
import (
"github.com/tinylib/msgp/msgp"
)
// MarshalMsg implements msgp.Marshaler
func (z *response) MarshalMsg(b []byte) (o []byte, err error) {
o = msgp.Require(b, z.Msgsize())
// map header, size 3
// string "sc"
o = append(o, 0x83, 0xa2, 0x73, 0x63)
o = msgp.AppendInt(o, z.StatusCode)
// string "hs"
o = append(o, 0xa2, 0x68, 0x73)
o = msgp.AppendMapHeader(o, uint32(len(z.Headers)))
for za0001, za0002 := range z.Headers {
o = msgp.AppendString(o, za0001)
o = msgp.AppendArrayHeader(o, uint32(len(za0002)))
for za0003 := range za0002 {
o = msgp.AppendString(o, za0002[za0003])
}
}
// string "b"
o = append(o, 0xa1, 0x62)
o = msgp.AppendBytes(o, z.Body)
return
}
// UnmarshalMsg implements msgp.Unmarshaler
func (z *response) UnmarshalMsg(bts []byte) (o []byte, err error) {
var field []byte
_ = field
var zb0001 uint32
zb0001, bts, err = msgp.ReadMapHeaderBytes(bts)
if err != nil {
err = msgp.WrapError(err)
return
}
for zb0001 > 0 {
zb0001--
field, bts, err = msgp.ReadMapKeyZC(bts)
if err != nil {
err = msgp.WrapError(err)
return
}
switch msgp.UnsafeString(field) {
case "sc":
z.StatusCode, bts, err = msgp.ReadIntBytes(bts)
if err != nil {
err = msgp.WrapError(err, "StatusCode")
return
}
case "hs":
var zb0002 uint32
zb0002, bts, err = msgp.ReadMapHeaderBytes(bts)
if err != nil {
err = msgp.WrapError(err, "Headers")
return
}
if z.Headers == nil {
z.Headers = make(map[string][]string, zb0002)
} else if len(z.Headers) > 0 {
for key := range z.Headers {
delete(z.Headers, key)
}
}
for zb0002 > 0 {
var za0001 string
var za0002 []string
zb0002--
za0001, bts, err = msgp.ReadStringBytes(bts)
if err != nil {
err = msgp.WrapError(err, "Headers")
return
}
var zb0003 uint32
zb0003, bts, err = msgp.ReadArrayHeaderBytes(bts)
if err != nil {
err = msgp.WrapError(err, "Headers", za0001)
return
}
if cap(za0002) >= int(zb0003) {
za0002 = (za0002)[:zb0003]
} else {
za0002 = make([]string, zb0003)
}
for za0003 := range za0002 {
za0002[za0003], bts, err = msgp.ReadStringBytes(bts)
if err != nil {
err = msgp.WrapError(err, "Headers", za0001, za0003)
return
}
}
z.Headers[za0001] = za0002
}
case "b":
z.Body, bts, err = msgp.ReadBytesBytes(bts, z.Body)
if err != nil {
err = msgp.WrapError(err, "Body")
return
}
default:
bts, err = msgp.Skip(bts)
if err != nil {
err = msgp.WrapError(err)
return
}
}
}
o = bts
return
}
// Msgsize returns an upper bound estimate of the number of bytes occupied by the serialized message
func (z *response) Msgsize() (s int) {
s = 1 + 3 + msgp.IntSize + 3 + msgp.MapHeaderSize
if z.Headers != nil {
for za0001, za0002 := range z.Headers {
_ = za0002
s += msgp.StringPrefixSize + len(za0001) + msgp.ArrayHeaderSize
for za0003 := range za0002 {
s += msgp.StringPrefixSize + len(za0002[za0003])
}
}
}
s += 2 + msgp.BytesPrefixSize + len(z.Body)
return
}

View file

@ -0,0 +1,67 @@
package idempotency
// Code generated by github.com/tinylib/msgp DO NOT EDIT.
import (
"testing"
"github.com/tinylib/msgp/msgp"
)
func TestMarshalUnmarshalresponse(t *testing.T) {
v := response{}
bts, err := v.MarshalMsg(nil)
if err != nil {
t.Fatal(err)
}
left, err := v.UnmarshalMsg(bts)
if err != nil {
t.Fatal(err)
}
if len(left) > 0 {
t.Errorf("%d bytes left over after UnmarshalMsg(): %q", len(left), left)
}
left, err = msgp.Skip(bts)
if err != nil {
t.Fatal(err)
}
if len(left) > 0 {
t.Errorf("%d bytes left over after Skip(): %q", len(left), left)
}
}
func BenchmarkMarshalMsgresponse(b *testing.B) {
v := response{}
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
v.MarshalMsg(nil)
}
}
func BenchmarkAppendMsgresponse(b *testing.B) {
v := response{}
bts := make([]byte, 0, v.Msgsize())
bts, _ = v.MarshalMsg(bts[0:0])
b.SetBytes(int64(len(bts)))
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
bts, _ = v.MarshalMsg(bts[0:0])
}
}
func BenchmarkUnmarshalresponse(b *testing.B) {
v := response{}
bts, _ := v.MarshalMsg(nil)
b.ReportAllocs()
b.SetBytes(int64(len(bts)))
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, err := v.UnmarshalMsg(bts)
if err != nil {
b.Fatal(err)
}
}
}

View file

@ -0,0 +1,95 @@
package keyauth
import (
"errors"
"github.com/gofiber/fiber/v2"
)
// Config defines the config for middleware.
type Config struct {
// Next defines a function to skip middleware.
// Optional. Default: nil
Next func(*fiber.Ctx) bool
// SuccessHandler defines a function which is executed for a valid key.
// Optional. Default: nil
SuccessHandler fiber.Handler
// ErrorHandler defines a function which is executed for an invalid key.
// It may be used to define a custom error.
// Optional. Default: 401 Invalid or expired key
ErrorHandler fiber.ErrorHandler
// KeyLookup is a string in the form of "<source>:<name>" that is used
// to extract key from the request.
// Optional. Default value "header:Authorization".
// Possible values:
// - "header:<name>"
// - "query:<name>"
// - "form:<name>"
// - "param:<name>"
// - "cookie:<name>"
KeyLookup string
// AuthScheme to be used in the Authorization header.
// Optional. Default value "Bearer".
AuthScheme string
// Validator is a function to validate key.
Validator func(*fiber.Ctx, string) (bool, error)
// Context key to store the bearertoken from the token into context.
// Optional. Default: "token".
ContextKey interface{}
}
// ConfigDefault is the default config
var ConfigDefault = Config{
SuccessHandler: func(c *fiber.Ctx) error {
return c.Next()
},
ErrorHandler: func(c *fiber.Ctx, err error) error {
if errors.Is(err, ErrMissingOrMalformedAPIKey) {
return c.Status(fiber.StatusUnauthorized).SendString(err.Error())
}
return c.Status(fiber.StatusUnauthorized).SendString("Invalid or expired API Key")
},
KeyLookup: "header:" + fiber.HeaderAuthorization,
AuthScheme: "Bearer",
ContextKey: "token",
}
// Helper function to set default values
func configDefault(config ...Config) Config {
// Return default config if nothing provided
if len(config) < 1 {
return ConfigDefault
}
// Override default config
cfg := config[0]
// Set default values
if cfg.SuccessHandler == nil {
cfg.SuccessHandler = ConfigDefault.SuccessHandler
}
if cfg.ErrorHandler == nil {
cfg.ErrorHandler = ConfigDefault.ErrorHandler
}
if cfg.KeyLookup == "" {
cfg.KeyLookup = ConfigDefault.KeyLookup
// set AuthScheme as "Bearer" only if KeyLookup is set to default.
if cfg.AuthScheme == "" {
cfg.AuthScheme = ConfigDefault.AuthScheme
}
}
if cfg.Validator == nil {
panic("fiber: keyauth middleware requires a validator function")
}
if cfg.ContextKey == nil {
cfg.ContextKey = ConfigDefault.ContextKey
}
return cfg
}

View file

@ -0,0 +1,121 @@
// Special thanks to Echo: https://github.com/labstack/echo/blob/master/middleware/key_auth.go
package keyauth
import (
"errors"
"net/url"
"strings"
"github.com/gofiber/fiber/v2"
)
// When there is no request of the key thrown ErrMissingOrMalformedAPIKey
var ErrMissingOrMalformedAPIKey = errors.New("missing or malformed API Key")
const (
query = "query"
form = "form"
param = "param"
cookie = "cookie"
)
// New creates a new middleware handler
func New(config ...Config) fiber.Handler {
// Init config
cfg := configDefault(config...)
// Initialize
parts := strings.Split(cfg.KeyLookup, ":")
extractor := keyFromHeader(parts[1], cfg.AuthScheme)
switch parts[0] {
case query:
extractor = keyFromQuery(parts[1])
case form:
extractor = keyFromForm(parts[1])
case param:
extractor = keyFromParam(parts[1])
case cookie:
extractor = keyFromCookie(parts[1])
}
// Return middleware handler
return func(c *fiber.Ctx) error {
// Filter request to skip middleware
if cfg.Next != nil && cfg.Next(c) {
return c.Next()
}
// Extract and verify key
key, err := extractor(c)
if err != nil {
return cfg.ErrorHandler(c, err)
}
valid, err := cfg.Validator(c, key)
if err == nil && valid {
c.Locals(cfg.ContextKey, key)
return cfg.SuccessHandler(c)
}
return cfg.ErrorHandler(c, err)
}
}
// keyFromHeader returns a function that extracts api key from the request header.
func keyFromHeader(header, authScheme string) func(c *fiber.Ctx) (string, error) {
return func(c *fiber.Ctx) (string, error) {
auth := c.Get(header)
l := len(authScheme)
if len(auth) > 0 && l == 0 {
return auth, nil
}
if len(auth) > l+1 && auth[:l] == authScheme {
return auth[l+1:], nil
}
return "", ErrMissingOrMalformedAPIKey
}
}
// keyFromQuery returns a function that extracts api key from the query string.
func keyFromQuery(param string) func(c *fiber.Ctx) (string, error) {
return func(c *fiber.Ctx) (string, error) {
key := c.Query(param)
if key == "" {
return "", ErrMissingOrMalformedAPIKey
}
return key, nil
}
}
// keyFromForm returns a function that extracts api key from the form.
func keyFromForm(param string) func(c *fiber.Ctx) (string, error) {
return func(c *fiber.Ctx) (string, error) {
key := c.FormValue(param)
if key == "" {
return "", ErrMissingOrMalformedAPIKey
}
return key, nil
}
}
// keyFromParam returns a function that extracts api key from the url param string.
func keyFromParam(param string) func(c *fiber.Ctx) (string, error) {
return func(c *fiber.Ctx) (string, error) {
key, err := url.PathUnescape(c.Params(param))
if err != nil {
return "", ErrMissingOrMalformedAPIKey
}
return key, nil
}
}
// keyFromCookie returns a function that extracts api key from the named cookie.
func keyFromCookie(name string) func(c *fiber.Ctx) (string, error) {
return func(c *fiber.Ctx) (string, error) {
key := c.Cookies(name)
if key == "" {
return "", ErrMissingOrMalformedAPIKey
}
return key, nil
}
}

View file

@ -0,0 +1,461 @@
//nolint:bodyclose // Much easier to just ignore memory leaks in tests
package keyauth
import (
"context"
"fmt"
"io"
"net/http"
"net/http/httptest"
"net/url"
"testing"
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/utils"
)
const CorrectKey = "specials: !$%,.#\"!?~`<>@$^*(){}[]|/\\123"
func TestAuthSources(t *testing.T) {
// define test cases
testSources := []string{"header", "cookie", "query", "param", "form"}
tests := []struct {
route string
authTokenName string
description string
APIKey string
expectedCode int
expectedBody string
}{
{
route: "/",
authTokenName: "access_token",
description: "auth with correct key",
APIKey: CorrectKey,
expectedCode: 200,
expectedBody: "Success!",
},
{
route: "/",
authTokenName: "access_token",
description: "auth with no key",
APIKey: "",
expectedCode: 401, // 404 in case of param authentication
expectedBody: "missing or malformed API Key",
},
{
route: "/",
authTokenName: "access_token",
description: "auth with wrong key",
APIKey: "WRONGKEY",
expectedCode: 401,
expectedBody: "missing or malformed API Key",
},
}
for _, authSource := range testSources {
t.Run(authSource, func(t *testing.T) {
for _, test := range tests {
// setup the fiber endpoint
// note that if UnescapePath: false (the default)
// escaped characters (such as `\"`) will not be handled correctly in the tests
app := fiber.New(fiber.Config{UnescapePath: true})
authMiddleware := New(Config{
KeyLookup: authSource + ":" + test.authTokenName,
Validator: func(c *fiber.Ctx, key string) (bool, error) {
if key == CorrectKey {
return true, nil
}
return false, ErrMissingOrMalformedAPIKey
},
})
var route string
if authSource == param {
route = test.route + ":" + test.authTokenName
app.Use(route, authMiddleware)
} else {
route = test.route
app.Use(authMiddleware)
}
app.Get(route, func(c *fiber.Ctx) error {
return c.SendString("Success!")
})
// construct the test HTTP request
var req *http.Request
req, err := http.NewRequestWithContext(context.Background(), fiber.MethodGet, test.route, nil)
utils.AssertEqual(t, err, nil)
// setup the apikey for the different auth schemes
if authSource == "header" {
req.Header.Set(test.authTokenName, test.APIKey)
} else if authSource == "cookie" {
req.Header.Set("Cookie", test.authTokenName+"="+test.APIKey)
} else if authSource == "query" || authSource == "form" {
q := req.URL.Query()
q.Add(test.authTokenName, test.APIKey)
req.URL.RawQuery = q.Encode()
} else if authSource == "param" {
r := req.URL.Path
r += url.PathEscape(test.APIKey)
req.URL.Path = r
}
res, err := app.Test(req, -1)
utils.AssertEqual(t, nil, err, test.description)
// test the body of the request
body, err := io.ReadAll(res.Body)
// for param authentication, the route would be /:access_token
// when the access_token is empty, it leads to a 404 (not found)
// not a 401 (auth error)
if authSource == "param" && test.APIKey == "" {
test.expectedCode = 404
test.expectedBody = "Cannot GET /"
}
utils.AssertEqual(t, test.expectedCode, res.StatusCode, test.description)
// body
utils.AssertEqual(t, nil, err, test.description)
utils.AssertEqual(t, test.expectedBody, string(body), test.description)
err = res.Body.Close()
utils.AssertEqual(t, err, nil)
}
})
}
}
func TestMultipleKeyAuth(t *testing.T) {
// setup the fiber endpoint
app := fiber.New()
// setup keyauth for /auth1
app.Use(New(Config{
Next: func(c *fiber.Ctx) bool {
return c.OriginalURL() != "/auth1"
},
KeyLookup: "header:key",
Validator: func(c *fiber.Ctx, key string) (bool, error) {
if key == "password1" {
return true, nil
}
return false, ErrMissingOrMalformedAPIKey
},
}))
// setup keyauth for /auth2
app.Use(New(Config{
Next: func(c *fiber.Ctx) bool {
return c.OriginalURL() != "/auth2"
},
KeyLookup: "header:key",
Validator: func(c *fiber.Ctx, key string) (bool, error) {
if key == "password2" {
return true, nil
}
return false, ErrMissingOrMalformedAPIKey
},
}))
app.Get("/", func(c *fiber.Ctx) error {
return c.SendString("No auth needed!")
})
app.Get("/auth1", func(c *fiber.Ctx) error {
return c.SendString("Successfully authenticated for auth1!")
})
app.Get("/auth2", func(c *fiber.Ctx) error {
return c.SendString("Successfully authenticated for auth2!")
})
// define test cases
tests := []struct {
route string
description string
APIKey string
expectedCode int
expectedBody string
}{
// No auth needed for /
{
route: "/",
description: "No password needed",
APIKey: "",
expectedCode: 200,
expectedBody: "No auth needed!",
},
// auth needed for auth1
{
route: "/auth1",
description: "Normal Authentication Case",
APIKey: "password1",
expectedCode: 200,
expectedBody: "Successfully authenticated for auth1!",
},
{
route: "/auth1",
description: "Wrong API Key",
APIKey: "WRONG KEY",
expectedCode: 401,
expectedBody: "missing or malformed API Key",
},
{
route: "/auth1",
description: "Wrong API Key",
APIKey: "", // NO KEY
expectedCode: 401,
expectedBody: "missing or malformed API Key",
},
// Auth 2 has a different password
{
route: "/auth2",
description: "Normal Authentication Case for auth2",
APIKey: "password2",
expectedCode: 200,
expectedBody: "Successfully authenticated for auth2!",
},
{
route: "/auth2",
description: "Wrong API Key",
APIKey: "WRONG KEY",
expectedCode: 401,
expectedBody: "missing or malformed API Key",
},
{
route: "/auth2",
description: "Wrong API Key",
APIKey: "", // NO KEY
expectedCode: 401,
expectedBody: "missing or malformed API Key",
},
}
// run the tests
for _, test := range tests {
var req *http.Request
req, err := http.NewRequestWithContext(context.Background(), fiber.MethodGet, test.route, nil)
utils.AssertEqual(t, err, nil)
if test.APIKey != "" {
req.Header.Set("key", test.APIKey)
}
res, err := app.Test(req, -1)
utils.AssertEqual(t, nil, err, test.description)
// test the body of the request
body, err := io.ReadAll(res.Body)
utils.AssertEqual(t, test.expectedCode, res.StatusCode, test.description)
// body
utils.AssertEqual(t, nil, err, test.description)
utils.AssertEqual(t, test.expectedBody, string(body), test.description)
}
}
func TestCustomSuccessAndFailureHandlers(t *testing.T) {
app := fiber.New()
app.Use(New(Config{
SuccessHandler: func(c *fiber.Ctx) error {
return c.Status(fiber.StatusOK).SendString("API key is valid and request was handled by custom success handler")
},
ErrorHandler: func(c *fiber.Ctx, err error) error {
return c.Status(fiber.StatusUnauthorized).SendString("API key is invalid and request was handled by custom error handler")
},
Validator: func(c *fiber.Ctx, key string) (bool, error) {
if key == CorrectKey {
return true, nil
}
return false, ErrMissingOrMalformedAPIKey
},
}))
// Define a test handler that should not be called
app.Get("/", func(c *fiber.Ctx) error {
t.Error("Test handler should not be called")
return nil
})
// Create a request without an API key and send it to the app
res, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, err, nil)
// Read the response body into a string
body, err := io.ReadAll(res.Body)
utils.AssertEqual(t, err, nil)
// Check that the response has the expected status code and body
utils.AssertEqual(t, res.StatusCode, http.StatusUnauthorized)
utils.AssertEqual(t, string(body), "API key is invalid and request was handled by custom error handler")
// Create a request with a valid API key in the Authorization header
req := httptest.NewRequest(fiber.MethodGet, "/", nil)
req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", CorrectKey))
// Send the request to the app
res, err = app.Test(req)
utils.AssertEqual(t, err, nil)
// Read the response body into a string
body, err = io.ReadAll(res.Body)
utils.AssertEqual(t, err, nil)
// Check that the response has the expected status code and body
utils.AssertEqual(t, res.StatusCode, http.StatusOK)
utils.AssertEqual(t, string(body), "API key is valid and request was handled by custom success handler")
}
func TestCustomNextFunc(t *testing.T) {
app := fiber.New()
app.Use(New(Config{
Next: func(c *fiber.Ctx) bool {
return c.Path() == "/allowed"
},
Validator: func(c *fiber.Ctx, key string) (bool, error) {
if key == CorrectKey {
return true, nil
}
return false, ErrMissingOrMalformedAPIKey
},
}))
// Define a test handler
app.Get("/allowed", func(c *fiber.Ctx) error {
return c.SendString("API key is valid and request was allowed by custom filter")
})
// Create a request with the "/allowed" path and send it to the app
req := httptest.NewRequest(fiber.MethodGet, "/allowed", nil)
res, err := app.Test(req)
utils.AssertEqual(t, err, nil)
// Read the response body into a string
body, err := io.ReadAll(res.Body)
utils.AssertEqual(t, err, nil)
// Check that the response has the expected status code and body
utils.AssertEqual(t, res.StatusCode, http.StatusOK)
utils.AssertEqual(t, string(body), "API key is valid and request was allowed by custom filter")
// Create a request with a different path and send it to the app without correct key
req = httptest.NewRequest(fiber.MethodGet, "/not-allowed", nil)
res, err = app.Test(req)
utils.AssertEqual(t, err, nil)
// Read the response body into a string
body, err = io.ReadAll(res.Body)
utils.AssertEqual(t, err, nil)
// Check that the response has the expected status code and body
utils.AssertEqual(t, res.StatusCode, http.StatusUnauthorized)
utils.AssertEqual(t, string(body), ErrMissingOrMalformedAPIKey.Error())
// Create a request with a different path and send it to the app with correct key
req = httptest.NewRequest(fiber.MethodGet, "/not-allowed", nil)
req.Header.Add("Authorization", fmt.Sprintf("Basic %s", CorrectKey))
res, err = app.Test(req)
utils.AssertEqual(t, err, nil)
// Read the response body into a string
body, err = io.ReadAll(res.Body)
utils.AssertEqual(t, err, nil)
// Check that the response has the expected status code and body
utils.AssertEqual(t, res.StatusCode, http.StatusUnauthorized)
utils.AssertEqual(t, string(body), ErrMissingOrMalformedAPIKey.Error())
}
func TestAuthSchemeToken(t *testing.T) {
app := fiber.New()
app.Use(New(Config{
AuthScheme: "Token",
Validator: func(c *fiber.Ctx, key string) (bool, error) {
if key == CorrectKey {
return true, nil
}
return false, ErrMissingOrMalformedAPIKey
},
}))
// Define a test handler
app.Get("/", func(c *fiber.Ctx) error {
return c.SendString("API key is valid")
})
// Create a request with a valid API key in the "Token" Authorization header
req := httptest.NewRequest(fiber.MethodGet, "/", nil)
req.Header.Add("Authorization", fmt.Sprintf("Token %s", CorrectKey))
// Send the request to the app
res, err := app.Test(req)
utils.AssertEqual(t, err, nil)
// Read the response body into a string
body, err := io.ReadAll(res.Body)
utils.AssertEqual(t, err, nil)
// Check that the response has the expected status code and body
utils.AssertEqual(t, res.StatusCode, http.StatusOK)
utils.AssertEqual(t, string(body), "API key is valid")
}
func TestAuthSchemeBasic(t *testing.T) {
app := fiber.New()
app.Use(New(Config{
KeyLookup: "header:Authorization",
AuthScheme: "Basic",
Validator: func(c *fiber.Ctx, key string) (bool, error) {
if key == CorrectKey {
return true, nil
}
return false, ErrMissingOrMalformedAPIKey
},
}))
// Define a test handler
app.Get("/", func(c *fiber.Ctx) error {
return c.SendString("API key is valid")
})
// Create a request without an API key and Send the request to the app
res, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, err, nil)
// Read the response body into a string
body, err := io.ReadAll(res.Body)
utils.AssertEqual(t, err, nil)
// Check that the response has the expected status code and body
utils.AssertEqual(t, res.StatusCode, http.StatusUnauthorized)
utils.AssertEqual(t, string(body), ErrMissingOrMalformedAPIKey.Error())
// Create a request with a valid API key in the "Authorization" header using the "Basic" scheme
req := httptest.NewRequest(fiber.MethodGet, "/", nil)
req.Header.Add("Authorization", fmt.Sprintf("Basic %s", CorrectKey))
// Send the request to the app
res, err = app.Test(req)
utils.AssertEqual(t, err, nil)
// Read the response body into a string
body, err = io.ReadAll(res.Body)
utils.AssertEqual(t, err, nil)
// Check that the response has the expected status code and body
utils.AssertEqual(t, res.StatusCode, http.StatusOK)
utils.AssertEqual(t, string(body), "API key is valid")
}

View file

@ -0,0 +1,128 @@
package limiter
import (
"time"
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/log"
)
// Config defines the config for middleware.
type Config struct {
// Next defines a function to skip this middleware when returned true.
//
// Optional. Default: nil
Next func(c *fiber.Ctx) bool
// Max number of recent connections during `Expiration` seconds before sending a 429 response
//
// Default: 5
Max int
// KeyGenerator allows you to generate custom keys, by default c.IP() is used
//
// Default: func(c *fiber.Ctx) string {
// return c.IP()
// }
KeyGenerator func(*fiber.Ctx) string
// Expiration is the time on how long to keep records of requests in memory
//
// Default: 1 * time.Minute
Expiration time.Duration
// LimitReached is called when a request hits the limit
//
// Default: func(c *fiber.Ctx) error {
// return c.SendStatus(fiber.StatusTooManyRequests)
// }
LimitReached fiber.Handler
// When set to true, requests with StatusCode >= 400 won't be counted.
//
// Default: false
SkipFailedRequests bool
// When set to true, requests with StatusCode < 400 won't be counted.
//
// Default: false
SkipSuccessfulRequests bool
// Store is used to store the state of the middleware
//
// Default: an in memory store for this process only
Storage fiber.Storage
// LimiterMiddleware is the struct that implements a limiter middleware.
//
// Default: a new Fixed Window Rate Limiter
LimiterMiddleware LimiterHandler
// Deprecated: Use Expiration instead
Duration time.Duration
// Deprecated: Use Storage instead
Store fiber.Storage
// Deprecated: Use KeyGenerator instead
Key func(*fiber.Ctx) string
}
// ConfigDefault is the default config
var ConfigDefault = Config{
Max: 5,
Expiration: 1 * time.Minute,
KeyGenerator: func(c *fiber.Ctx) string {
return c.IP()
},
LimitReached: func(c *fiber.Ctx) error {
return c.SendStatus(fiber.StatusTooManyRequests)
},
SkipFailedRequests: false,
SkipSuccessfulRequests: false,
LimiterMiddleware: FixedWindow{},
}
// Helper function to set default values
func configDefault(config ...Config) Config {
// Return default config if nothing provided
if len(config) < 1 {
return ConfigDefault
}
// Override default config
cfg := config[0]
// Set default values
if int(cfg.Duration.Seconds()) > 0 {
log.Warn("[LIMITER] Duration is deprecated, please use Expiration")
cfg.Expiration = cfg.Duration
}
if cfg.Key != nil {
log.Warn("[LIMITER] Key is deprecated, please us KeyGenerator")
cfg.KeyGenerator = cfg.Key
}
if cfg.Store != nil {
log.Warn("[LIMITER] Store is deprecated, please use Storage")
cfg.Storage = cfg.Store
}
if cfg.Next == nil {
cfg.Next = ConfigDefault.Next
}
if cfg.Max <= 0 {
cfg.Max = ConfigDefault.Max
}
if int(cfg.Expiration.Seconds()) <= 0 {
cfg.Expiration = ConfigDefault.Expiration
}
if cfg.KeyGenerator == nil {
cfg.KeyGenerator = ConfigDefault.KeyGenerator
}
if cfg.LimitReached == nil {
cfg.LimitReached = ConfigDefault.LimitReached
}
if cfg.LimiterMiddleware == nil {
cfg.LimiterMiddleware = ConfigDefault.LimiterMiddleware
}
return cfg
}

View file

@ -0,0 +1,25 @@
package limiter
import (
"github.com/gofiber/fiber/v2"
)
const (
// X-RateLimit-* headers
xRateLimitLimit = "X-RateLimit-Limit"
xRateLimitRemaining = "X-RateLimit-Remaining"
xRateLimitReset = "X-RateLimit-Reset"
)
type LimiterHandler interface {
New(config Config) fiber.Handler
}
// New creates a new middleware handler
func New(config ...Config) fiber.Handler {
// Set default config
cfg := configDefault(config...)
// Return the specified middleware handler.
return cfg.LimiterMiddleware.New(cfg)
}

View file

@ -0,0 +1,106 @@
package limiter
import (
"strconv"
"sync"
"sync/atomic"
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/utils"
)
type FixedWindow struct{}
// New creates a new fixed window middleware handler
func (FixedWindow) New(cfg Config) fiber.Handler {
var (
// Limiter variables
mux = &sync.RWMutex{}
max = strconv.Itoa(cfg.Max)
expiration = uint64(cfg.Expiration.Seconds())
)
// Create manager to simplify storage operations ( see manager.go )
manager := newManager(cfg.Storage)
// Update timestamp every second
utils.StartTimeStampUpdater()
// Return new handler
return func(c *fiber.Ctx) error {
// Don't execute middleware if Next returns true
if cfg.Next != nil && cfg.Next(c) {
return c.Next()
}
// Get key from request
key := cfg.KeyGenerator(c)
// Lock entry
mux.Lock()
// Get entry from pool and release when finished
e := manager.get(key)
// Get timestamp
ts := uint64(atomic.LoadUint32(&utils.Timestamp))
// Set expiration if entry does not exist
if e.exp == 0 {
e.exp = ts + expiration
} else if ts >= e.exp {
// Check if entry is expired
e.currHits = 0
e.exp = ts + expiration
}
// Increment hits
e.currHits++
// Calculate when it resets in seconds
resetInSec := e.exp - ts
// Set how many hits we have left
remaining := cfg.Max - e.currHits
// Update storage
manager.set(key, e, cfg.Expiration)
// Unlock entry
mux.Unlock()
// Check if hits exceed the cfg.Max
if remaining < 0 {
// Return response with Retry-After header
// https://tools.ietf.org/html/rfc6584
c.Set(fiber.HeaderRetryAfter, strconv.FormatUint(resetInSec, 10))
// Call LimitReached handler
return cfg.LimitReached(c)
}
// Continue stack for reaching c.Response().StatusCode()
// Store err for returning
err := c.Next()
// Check for SkipFailedRequests and SkipSuccessfulRequests
if (cfg.SkipSuccessfulRequests && c.Response().StatusCode() < fiber.StatusBadRequest) ||
(cfg.SkipFailedRequests && c.Response().StatusCode() >= fiber.StatusBadRequest) {
// Lock entry
mux.Lock()
e = manager.get(key)
e.currHits--
remaining++
manager.set(key, e, cfg.Expiration)
// Unlock entry
mux.Unlock()
}
// We can continue, update RateLimit headers
c.Set(xRateLimitLimit, max)
c.Set(xRateLimitRemaining, strconv.Itoa(remaining))
c.Set(xRateLimitReset, strconv.FormatUint(resetInSec, 10))
return err
}
}

View file

@ -0,0 +1,137 @@
package limiter
import (
"strconv"
"sync"
"sync/atomic"
"time"
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/utils"
)
type SlidingWindow struct{}
// New creates a new sliding window middleware handler
func (SlidingWindow) New(cfg Config) fiber.Handler {
var (
// Limiter variables
mux = &sync.RWMutex{}
max = strconv.Itoa(cfg.Max)
expiration = uint64(cfg.Expiration.Seconds())
)
// Create manager to simplify storage operations ( see manager.go )
manager := newManager(cfg.Storage)
// Update timestamp every second
utils.StartTimeStampUpdater()
// Return new handler
return func(c *fiber.Ctx) error {
// Don't execute middleware if Next returns true
if cfg.Next != nil && cfg.Next(c) {
return c.Next()
}
// Get key from request
key := cfg.KeyGenerator(c)
// Lock entry
mux.Lock()
// Get entry from pool and release when finished
e := manager.get(key)
// Get timestamp
ts := uint64(atomic.LoadUint32(&utils.Timestamp))
// Set expiration if entry does not exist
if e.exp == 0 {
e.exp = ts + expiration
} else if ts >= e.exp {
// The entry has expired, handle the expiration.
// Set the prevHits to the current hits and reset the hits to 0.
e.prevHits = e.currHits
// Reset the current hits to 0.
e.currHits = 0
// Check how much into the current window it currently is and sets the
// expiry based on that, otherwise this would only reset on
// the next request and not show the correct expiry.
elapsed := ts - e.exp
if elapsed >= expiration {
e.exp = ts + expiration
} else {
e.exp = ts + expiration - elapsed
}
}
// Increment hits
e.currHits++
// Calculate when it resets in seconds
resetInSec := e.exp - ts
// weight = time until current window reset / total window length
weight := float64(resetInSec) / float64(expiration)
// rate = request count in previous window - weight + request count in current window
rate := int(float64(e.prevHits)*weight) + e.currHits
// Calculate how many hits can be made based on the current rate
remaining := cfg.Max - rate
// Update storage. Garbage collect when the next window ends.
// |--------------------------|--------------------------|
// ^ ^ ^ ^
// ts e.exp End sample window End next window
// <------------>
// resetInSec
// resetInSec = e.exp - ts - time until end of current window.
// duration + expiration = end of next window.
// Because we don't want to garbage collect in the middle of a window
// we add the expiration to the duration.
// Otherwise after the end of "sample window", attackers could launch
// a new request with the full window length.
manager.set(key, e, time.Duration(resetInSec+expiration)*time.Second)
// Unlock entry
mux.Unlock()
// Check if hits exceed the cfg.Max
if remaining < 0 {
// Return response with Retry-After header
// https://tools.ietf.org/html/rfc6584
c.Set(fiber.HeaderRetryAfter, strconv.FormatUint(resetInSec, 10))
// Call LimitReached handler
return cfg.LimitReached(c)
}
// Continue stack for reaching c.Response().StatusCode()
// Store err for returning
err := c.Next()
// Check for SkipFailedRequests and SkipSuccessfulRequests
if (cfg.SkipSuccessfulRequests && c.Response().StatusCode() < fiber.StatusBadRequest) ||
(cfg.SkipFailedRequests && c.Response().StatusCode() >= fiber.StatusBadRequest) {
// Lock entry
mux.Lock()
e = manager.get(key)
e.currHits--
remaining++
manager.set(key, e, cfg.Expiration)
// Unlock entry
mux.Unlock()
}
// We can continue, update RateLimit headers
c.Set(xRateLimitLimit, max)
c.Set(xRateLimitRemaining, strconv.Itoa(remaining))
c.Set(xRateLimitReset, strconv.FormatUint(resetInSec, 10))
return err
}
}

View file

@ -0,0 +1,727 @@
package limiter
import (
"io"
"net/http/httptest"
"sync"
"testing"
"time"
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/internal/storage/memory"
"github.com/gofiber/fiber/v2/utils"
"github.com/valyala/fasthttp"
)
// go test -run Test_Limiter_Concurrency_Store -race -v
func Test_Limiter_Concurrency_Store(t *testing.T) {
t.Parallel()
// Test concurrency using a custom store
app := fiber.New()
app.Use(New(Config{
Max: 50,
Expiration: 2 * time.Second,
Storage: memory.New(),
}))
app.Get("/", func(c *fiber.Ctx) error {
return c.SendString("Hello tester!")
})
var wg sync.WaitGroup
singleRequest := func(wg *sync.WaitGroup) {
defer wg.Done()
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
body, err := io.ReadAll(resp.Body)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, "Hello tester!", string(body))
}
for i := 0; i <= 49; i++ {
wg.Add(1)
go singleRequest(&wg)
}
wg.Wait()
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 429, resp.StatusCode)
time.Sleep(3 * time.Second)
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 200, resp.StatusCode)
}
// go test -run Test_Limiter_Concurrency -race -v
func Test_Limiter_Concurrency(t *testing.T) {
t.Parallel()
// Test concurrency using a default store
app := fiber.New()
app.Use(New(Config{
Max: 50,
Expiration: 2 * time.Second,
}))
app.Get("/", func(c *fiber.Ctx) error {
return c.SendString("Hello tester!")
})
var wg sync.WaitGroup
singleRequest := func(wg *sync.WaitGroup) {
defer wg.Done()
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
body, err := io.ReadAll(resp.Body)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, "Hello tester!", string(body))
}
for i := 0; i <= 49; i++ {
wg.Add(1)
go singleRequest(&wg)
}
wg.Wait()
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 429, resp.StatusCode)
time.Sleep(3 * time.Second)
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 200, resp.StatusCode)
}
// go test -run Test_Limiter_Fixed_Window_No_Skip_Choices -v
func Test_Limiter_Fixed_Window_No_Skip_Choices(t *testing.T) {
t.Parallel()
app := fiber.New()
app.Use(New(Config{
Max: 2,
Expiration: 2 * time.Second,
SkipFailedRequests: false,
SkipSuccessfulRequests: false,
LimiterMiddleware: FixedWindow{},
}))
app.Get("/:status", func(c *fiber.Ctx) error {
if c.Params("status") == "fail" { //nolint:goconst // False positive
return c.SendStatus(400)
}
return c.SendStatus(200)
})
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/fail", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 400, resp.StatusCode)
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/success", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 200, resp.StatusCode)
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/success", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 429, resp.StatusCode)
time.Sleep(3 * time.Second)
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/success", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 200, resp.StatusCode)
}
// go test -run Test_Limiter_Fixed_Window_Custom_Storage_No_Skip_Choices -v
func Test_Limiter_Fixed_Window_Custom_Storage_No_Skip_Choices(t *testing.T) {
t.Parallel()
app := fiber.New()
app.Use(New(Config{
Max: 2,
Expiration: 2 * time.Second,
SkipFailedRequests: false,
SkipSuccessfulRequests: false,
Storage: memory.New(),
LimiterMiddleware: FixedWindow{},
}))
app.Get("/:status", func(c *fiber.Ctx) error {
if c.Params("status") == "fail" {
return c.SendStatus(400)
}
return c.SendStatus(200)
})
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/fail", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 400, resp.StatusCode)
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/success", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 200, resp.StatusCode)
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/success", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 429, resp.StatusCode)
time.Sleep(3 * time.Second)
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/success", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 200, resp.StatusCode)
}
// go test -run Test_Limiter_Sliding_Window_No_Skip_Choices -v
func Test_Limiter_Sliding_Window_No_Skip_Choices(t *testing.T) {
t.Parallel()
app := fiber.New()
app.Use(New(Config{
Max: 2,
Expiration: 2 * time.Second,
SkipFailedRequests: false,
SkipSuccessfulRequests: false,
LimiterMiddleware: SlidingWindow{},
}))
app.Get("/:status", func(c *fiber.Ctx) error {
if c.Params("status") == "fail" {
return c.SendStatus(400)
}
return c.SendStatus(200)
})
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/fail", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 400, resp.StatusCode)
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/success", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 200, resp.StatusCode)
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/success", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 429, resp.StatusCode)
time.Sleep(4*time.Second + 500*time.Millisecond)
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/success", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 200, resp.StatusCode)
}
// go test -run Test_Limiter_Sliding_Window_Custom_Storage_No_Skip_Choices -v
func Test_Limiter_Sliding_Window_Custom_Storage_No_Skip_Choices(t *testing.T) {
t.Parallel()
app := fiber.New()
app.Use(New(Config{
Max: 2,
Expiration: 2 * time.Second,
SkipFailedRequests: false,
SkipSuccessfulRequests: false,
Storage: memory.New(),
LimiterMiddleware: SlidingWindow{},
}))
app.Get("/:status", func(c *fiber.Ctx) error {
if c.Params("status") == "fail" {
return c.SendStatus(400)
}
return c.SendStatus(200)
})
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/fail", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 400, resp.StatusCode)
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/success", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 200, resp.StatusCode)
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/success", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 429, resp.StatusCode)
time.Sleep(4*time.Second + 500*time.Millisecond)
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/success", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 200, resp.StatusCode)
}
// go test -run Test_Limiter_Fixed_Window_Skip_Failed_Requests -v
func Test_Limiter_Fixed_Window_Skip_Failed_Requests(t *testing.T) {
t.Parallel()
app := fiber.New()
app.Use(New(Config{
Max: 1,
Expiration: 2 * time.Second,
SkipFailedRequests: true,
LimiterMiddleware: FixedWindow{},
}))
app.Get("/:status", func(c *fiber.Ctx) error {
if c.Params("status") == "fail" {
return c.SendStatus(400)
}
return c.SendStatus(200)
})
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/fail", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 400, resp.StatusCode)
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/success", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 200, resp.StatusCode)
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/success", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 429, resp.StatusCode)
time.Sleep(3 * time.Second)
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/success", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 200, resp.StatusCode)
}
// go test -run Test_Limiter_Fixed_Window_Custom_Storage_Skip_Failed_Requests -v
func Test_Limiter_Fixed_Window_Custom_Storage_Skip_Failed_Requests(t *testing.T) {
t.Parallel()
app := fiber.New()
app.Use(New(Config{
Max: 1,
Expiration: 2 * time.Second,
Storage: memory.New(),
SkipFailedRequests: true,
LimiterMiddleware: FixedWindow{},
}))
app.Get("/:status", func(c *fiber.Ctx) error {
if c.Params("status") == "fail" {
return c.SendStatus(400)
}
return c.SendStatus(200)
})
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/fail", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 400, resp.StatusCode)
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/success", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 200, resp.StatusCode)
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/success", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 429, resp.StatusCode)
time.Sleep(3 * time.Second)
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/success", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 200, resp.StatusCode)
}
// go test -run Test_Limiter_Sliding_Window_Skip_Failed_Requests -v
func Test_Limiter_Sliding_Window_Skip_Failed_Requests(t *testing.T) {
t.Parallel()
app := fiber.New()
app.Use(New(Config{
Max: 1,
Expiration: 2 * time.Second,
SkipFailedRequests: true,
LimiterMiddleware: SlidingWindow{},
}))
app.Get("/:status", func(c *fiber.Ctx) error {
if c.Params("status") == "fail" {
return c.SendStatus(400)
}
return c.SendStatus(200)
})
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/fail", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 400, resp.StatusCode)
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/success", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 200, resp.StatusCode)
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/success", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 429, resp.StatusCode)
time.Sleep(4*time.Second + 500*time.Millisecond)
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/success", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 200, resp.StatusCode)
}
// go test -run Test_Limiter_Sliding_Window_Custom_Storage_Skip_Failed_Requests -v
func Test_Limiter_Sliding_Window_Custom_Storage_Skip_Failed_Requests(t *testing.T) {
t.Parallel()
app := fiber.New()
app.Use(New(Config{
Max: 1,
Expiration: 2 * time.Second,
Storage: memory.New(),
SkipFailedRequests: true,
LimiterMiddleware: SlidingWindow{},
}))
app.Get("/:status", func(c *fiber.Ctx) error {
if c.Params("status") == "fail" {
return c.SendStatus(400)
}
return c.SendStatus(200)
})
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/fail", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 400, resp.StatusCode)
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/success", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 200, resp.StatusCode)
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/success", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 429, resp.StatusCode)
time.Sleep(4*time.Second + 500*time.Millisecond)
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/success", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 200, resp.StatusCode)
}
// go test -run Test_Limiter_Fixed_Window_Skip_Successful_Requests -v
func Test_Limiter_Fixed_Window_Skip_Successful_Requests(t *testing.T) {
t.Parallel()
// Test concurrency using a default store
app := fiber.New()
app.Use(New(Config{
Max: 1,
Expiration: 2 * time.Second,
SkipSuccessfulRequests: true,
LimiterMiddleware: FixedWindow{},
}))
app.Get("/:status", func(c *fiber.Ctx) error {
if c.Params("status") == "fail" {
return c.SendStatus(400)
}
return c.SendStatus(200)
})
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/success", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 200, resp.StatusCode)
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/fail", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 400, resp.StatusCode)
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/fail", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 429, resp.StatusCode)
time.Sleep(3 * time.Second)
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/fail", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 400, resp.StatusCode)
}
// go test -run Test_Limiter_Fixed_Window_Custom_Storage_Skip_Successful_Requests -v
func Test_Limiter_Fixed_Window_Custom_Storage_Skip_Successful_Requests(t *testing.T) {
t.Parallel()
// Test concurrency using a default store
app := fiber.New()
app.Use(New(Config{
Max: 1,
Expiration: 2 * time.Second,
Storage: memory.New(),
SkipSuccessfulRequests: true,
LimiterMiddleware: FixedWindow{},
}))
app.Get("/:status", func(c *fiber.Ctx) error {
if c.Params("status") == "fail" {
return c.SendStatus(400)
}
return c.SendStatus(200)
})
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/success", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 200, resp.StatusCode)
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/fail", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 400, resp.StatusCode)
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/fail", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 429, resp.StatusCode)
time.Sleep(3 * time.Second)
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/fail", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 400, resp.StatusCode)
}
// go test -run Test_Limiter_Sliding_Window_Skip_Successful_Requests -v
func Test_Limiter_Sliding_Window_Skip_Successful_Requests(t *testing.T) {
t.Parallel()
// Test concurrency using a default store
app := fiber.New()
app.Use(New(Config{
Max: 1,
Expiration: 2 * time.Second,
SkipSuccessfulRequests: true,
LimiterMiddleware: SlidingWindow{},
}))
app.Get("/:status", func(c *fiber.Ctx) error {
if c.Params("status") == "fail" {
return c.SendStatus(400)
}
return c.SendStatus(200)
})
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/success", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 200, resp.StatusCode)
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/fail", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 400, resp.StatusCode)
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/fail", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 429, resp.StatusCode)
time.Sleep(4*time.Second + 500*time.Millisecond)
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/fail", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 400, resp.StatusCode)
}
// go test -run Test_Limiter_Sliding_Window_Custom_Storage_Skip_Successful_Requests -v
func Test_Limiter_Sliding_Window_Custom_Storage_Skip_Successful_Requests(t *testing.T) {
t.Parallel()
// Test concurrency using a default store
app := fiber.New()
app.Use(New(Config{
Max: 1,
Expiration: 2 * time.Second,
Storage: memory.New(),
SkipSuccessfulRequests: true,
LimiterMiddleware: SlidingWindow{},
}))
app.Get("/:status", func(c *fiber.Ctx) error {
if c.Params("status") == "fail" {
return c.SendStatus(400)
}
return c.SendStatus(200)
})
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/success", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 200, resp.StatusCode)
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/fail", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 400, resp.StatusCode)
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/fail", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 429, resp.StatusCode)
time.Sleep(4*time.Second + 500*time.Millisecond)
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/fail", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 400, resp.StatusCode)
}
// go test -v -run=^$ -bench=Benchmark_Limiter_Custom_Store -benchmem -count=4
func Benchmark_Limiter_Custom_Store(b *testing.B) {
app := fiber.New()
app.Use(New(Config{
Max: 100,
Expiration: 60 * time.Second,
Storage: memory.New(),
}))
app.Get("/", func(c *fiber.Ctx) error {
return c.SendString("Hello, World!")
})
h := app.Handler()
fctx := &fasthttp.RequestCtx{}
fctx.Request.Header.SetMethod(fiber.MethodGet)
fctx.Request.SetRequestURI("/")
b.ResetTimer()
for n := 0; n < b.N; n++ {
h(fctx)
}
}
// go test -run Test_Limiter_Next
func Test_Limiter_Next(t *testing.T) {
t.Parallel()
app := fiber.New()
app.Use(New(Config{
Next: func(_ *fiber.Ctx) bool {
return true
},
}))
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusNotFound, resp.StatusCode)
}
func Test_Limiter_Headers(t *testing.T) {
t.Parallel()
app := fiber.New()
app.Use(New(Config{
Max: 50,
Expiration: 2 * time.Second,
}))
app.Get("/", func(c *fiber.Ctx) error {
return c.SendString("Hello tester!")
})
fctx := &fasthttp.RequestCtx{}
fctx.Request.Header.SetMethod(fiber.MethodGet)
fctx.Request.SetRequestURI("/")
app.Handler()(fctx)
utils.AssertEqual(t, "50", string(fctx.Response.Header.Peek("X-RateLimit-Limit")))
if v := string(fctx.Response.Header.Peek("X-RateLimit-Remaining")); v == "" {
t.Errorf("The X-RateLimit-Remaining header is not set correctly - value is an empty string.")
}
if v := string(fctx.Response.Header.Peek("X-RateLimit-Reset")); !(v == "1" || v == "2") {
t.Errorf("The X-RateLimit-Reset header is not set correctly - value is out of bounds.")
}
}
// go test -v -run=^$ -bench=Benchmark_Limiter -benchmem -count=4
func Benchmark_Limiter(b *testing.B) {
app := fiber.New()
app.Use(New(Config{
Max: 100,
Expiration: 60 * time.Second,
}))
app.Get("/", func(c *fiber.Ctx) error {
return c.SendString("Hello, World!")
})
h := app.Handler()
fctx := &fasthttp.RequestCtx{}
fctx.Request.Header.SetMethod(fiber.MethodGet)
fctx.Request.SetRequestURI("/")
b.ResetTimer()
for n := 0; n < b.N; n++ {
h(fctx)
}
}
// go test -run Test_Sliding_Window -race -v
func Test_Sliding_Window(t *testing.T) {
t.Parallel()
app := fiber.New()
app.Use(New(Config{
Max: 10,
Expiration: 2 * time.Second,
Storage: memory.New(),
LimiterMiddleware: SlidingWindow{},
}))
app.Get("/", func(c *fiber.Ctx) error {
return c.SendString("Hello tester!")
})
singleRequest := func(shouldFail bool) {
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
if shouldFail {
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 429, resp.StatusCode)
} else {
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
}
}
for i := 0; i < 5; i++ {
singleRequest(false)
}
time.Sleep(2 * time.Second)
for i := 0; i < 5; i++ {
singleRequest(false)
}
time.Sleep(3 * time.Second)
for i := 0; i < 5; i++ {
singleRequest(false)
}
time.Sleep(4 * time.Second)
for i := 0; i < 9; i++ {
singleRequest(false)
}
}

View file

@ -0,0 +1,92 @@
package limiter
import (
"sync"
"time"
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/internal/memory"
)
// go:generate msgp
// msgp -file="manager.go" -o="manager_msgp.go" -tests=false -unexported
type item struct {
currHits int
prevHits int
exp uint64
}
//msgp:ignore manager
type manager struct {
pool sync.Pool
memory *memory.Storage
storage fiber.Storage
}
func newManager(storage fiber.Storage) *manager {
// Create new storage handler
manager := &manager{
pool: sync.Pool{
New: func() interface{} {
return new(item)
},
},
}
if storage != nil {
// Use provided storage if provided
manager.storage = storage
} else {
// Fallback too memory storage
manager.memory = memory.New()
}
return manager
}
// acquire returns an *entry from the sync.Pool
func (m *manager) acquire() *item {
return m.pool.Get().(*item) //nolint:forcetypeassert // We store nothing else in the pool
}
// release and reset *entry to sync.Pool
func (m *manager) release(e *item) {
e.prevHits = 0
e.currHits = 0
e.exp = 0
m.pool.Put(e)
}
// get data from storage or memory
func (m *manager) get(key string) *item {
var it *item
if m.storage != nil {
it = m.acquire()
raw, err := m.storage.Get(key)
if err != nil {
return it
}
if raw != nil {
if _, err := it.UnmarshalMsg(raw); err != nil {
return it
}
}
return it
}
if it, _ = m.memory.Get(key).(*item); it == nil { //nolint:errcheck // We store nothing else in the pool
it = m.acquire()
return it
}
return it
}
// set data to storage or memory
func (m *manager) set(key string, it *item, exp time.Duration) {
if m.storage != nil {
if raw, err := it.MarshalMsg(nil); err == nil {
_ = m.storage.Set(key, raw, exp) //nolint:errcheck // TODO: Handle error here
}
// we can release data because it's serialized to database
m.release(it)
} else {
m.memory.Set(key, it, exp)
}
}

View file

@ -0,0 +1,160 @@
package limiter
// Code generated by github.com/tinylib/msgp DO NOT EDIT.
import (
"github.com/tinylib/msgp/msgp"
)
// DecodeMsg implements msgp.Decodable
func (z *item) DecodeMsg(dc *msgp.Reader) (err error) {
var field []byte
_ = field
var zb0001 uint32
zb0001, err = dc.ReadMapHeader()
if err != nil {
err = msgp.WrapError(err)
return
}
for zb0001 > 0 {
zb0001--
field, err = dc.ReadMapKeyPtr()
if err != nil {
err = msgp.WrapError(err)
return
}
switch msgp.UnsafeString(field) {
case "currHits":
z.currHits, err = dc.ReadInt()
if err != nil {
err = msgp.WrapError(err, "currHits")
return
}
case "prevHits":
z.prevHits, err = dc.ReadInt()
if err != nil {
err = msgp.WrapError(err, "prevHits")
return
}
case "exp":
z.exp, err = dc.ReadUint64()
if err != nil {
err = msgp.WrapError(err, "exp")
return
}
default:
err = dc.Skip()
if err != nil {
err = msgp.WrapError(err)
return
}
}
}
return
}
// EncodeMsg implements msgp.Encodable
func (z item) EncodeMsg(en *msgp.Writer) (err error) {
// map header, size 3
// write "currHits"
err = en.Append(0x83, 0xa8, 0x63, 0x75, 0x72, 0x72, 0x48, 0x69, 0x74, 0x73)
if err != nil {
return
}
err = en.WriteInt(z.currHits)
if err != nil {
err = msgp.WrapError(err, "currHits")
return
}
// write "prevHits"
err = en.Append(0xa8, 0x70, 0x72, 0x65, 0x76, 0x48, 0x69, 0x74, 0x73)
if err != nil {
return
}
err = en.WriteInt(z.prevHits)
if err != nil {
err = msgp.WrapError(err, "prevHits")
return
}
// write "exp"
err = en.Append(0xa3, 0x65, 0x78, 0x70)
if err != nil {
return
}
err = en.WriteUint64(z.exp)
if err != nil {
err = msgp.WrapError(err, "exp")
return
}
return
}
// MarshalMsg implements msgp.Marshaler
func (z item) MarshalMsg(b []byte) (o []byte, err error) {
o = msgp.Require(b, z.Msgsize())
// map header, size 3
// string "currHits"
o = append(o, 0x83, 0xa8, 0x63, 0x75, 0x72, 0x72, 0x48, 0x69, 0x74, 0x73)
o = msgp.AppendInt(o, z.currHits)
// string "prevHits"
o = append(o, 0xa8, 0x70, 0x72, 0x65, 0x76, 0x48, 0x69, 0x74, 0x73)
o = msgp.AppendInt(o, z.prevHits)
// string "exp"
o = append(o, 0xa3, 0x65, 0x78, 0x70)
o = msgp.AppendUint64(o, z.exp)
return
}
// UnmarshalMsg implements msgp.Unmarshaler
func (z *item) UnmarshalMsg(bts []byte) (o []byte, err error) {
var field []byte
_ = field
var zb0001 uint32
zb0001, bts, err = msgp.ReadMapHeaderBytes(bts)
if err != nil {
err = msgp.WrapError(err)
return
}
for zb0001 > 0 {
zb0001--
field, bts, err = msgp.ReadMapKeyZC(bts)
if err != nil {
err = msgp.WrapError(err)
return
}
switch msgp.UnsafeString(field) {
case "currHits":
z.currHits, bts, err = msgp.ReadIntBytes(bts)
if err != nil {
err = msgp.WrapError(err, "currHits")
return
}
case "prevHits":
z.prevHits, bts, err = msgp.ReadIntBytes(bts)
if err != nil {
err = msgp.WrapError(err, "prevHits")
return
}
case "exp":
z.exp, bts, err = msgp.ReadUint64Bytes(bts)
if err != nil {
err = msgp.WrapError(err, "exp")
return
}
default:
bts, err = msgp.Skip(bts)
if err != nil {
err = msgp.WrapError(err)
return
}
}
}
o = bts
return
}
// Msgsize returns an upper bound estimate of the number of bytes occupied by the serialized message
func (z item) Msgsize() (s int) {
s = 1 + 9 + msgp.IntSize + 9 + msgp.IntSize + 4 + msgp.Uint64Size
return
}

136
middleware/logger/config.go Normal file
View file

@ -0,0 +1,136 @@
package logger
import (
"io"
"os"
"time"
"github.com/gofiber/fiber/v2"
)
// Config defines the config for middleware.
type Config struct {
// Next defines a function to skip this middleware when returned true.
//
// Optional. Default: nil
Next func(c *fiber.Ctx) bool
// Done is a function that is called after the log string for a request is written to Output,
// and pass the log string as parameter.
//
// Optional. Default: nil
Done func(c *fiber.Ctx, logString []byte)
// tagFunctions defines the custom tag action
//
// Optional. Default: map[string]LogFunc
CustomTags map[string]LogFunc
// Format defines the logging tags
//
// Optional. Default: ${time} | ${status} | ${latency} | ${ip} | ${method} | ${path} | ${error}\n
Format string
// TimeFormat https://programming.guide/go/format-parse-string-time-date-example.html
//
// Optional. Default: 15:04:05
TimeFormat string
// TimeZone can be specified, such as "UTC" and "America/New_York" and "Asia/Chongqing", etc
//
// Optional. Default: "Local"
TimeZone string
// TimeInterval is the delay before the timestamp is updated
//
// Optional. Default: 500 * time.Millisecond
TimeInterval time.Duration
// Output is a writer where logs are written
//
// Default: os.Stdout
Output io.Writer
// DisableColors defines if the logs output should be colorized
//
// Default: false
DisableColors bool
enableColors bool
enableLatency bool
timeZoneLocation *time.Location
}
const (
startTag = "${"
endTag = "}"
paramSeparator = ":"
)
type Buffer interface {
Len() int
ReadFrom(r io.Reader) (int64, error)
WriteTo(w io.Writer) (int64, error)
Bytes() []byte
Write(p []byte) (int, error)
WriteByte(c byte) error
WriteString(s string) (int, error)
Set(p []byte)
SetString(s string)
String() string
}
type LogFunc func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error)
// ConfigDefault is the default config
var ConfigDefault = Config{
Next: nil,
Done: nil,
Format: "${time} | ${status} | ${latency} | ${ip} | ${method} | ${path} | ${error}\n",
TimeFormat: "15:04:05",
TimeZone: "Local",
TimeInterval: 500 * time.Millisecond,
Output: os.Stdout,
DisableColors: false,
enableColors: true,
}
// Helper function to set default values
func configDefault(config ...Config) Config {
// Return default config if nothing provided
if len(config) < 1 {
return ConfigDefault
}
// Override default config
cfg := config[0]
// Set default values
if cfg.Next == nil {
cfg.Next = ConfigDefault.Next
}
if cfg.Done == nil {
cfg.Done = ConfigDefault.Done
}
if cfg.Format == "" {
cfg.Format = ConfigDefault.Format
}
if cfg.TimeZone == "" {
cfg.TimeZone = ConfigDefault.TimeZone
}
if cfg.TimeFormat == "" {
cfg.TimeFormat = ConfigDefault.TimeFormat
}
if int(cfg.TimeInterval) <= 0 {
cfg.TimeInterval = ConfigDefault.TimeInterval
}
if cfg.Output == nil {
cfg.Output = ConfigDefault.Output
}
if !cfg.DisableColors && cfg.Output == ConfigDefault.Output {
cfg.enableColors = true
}
return cfg
}

16
middleware/logger/data.go Normal file
View file

@ -0,0 +1,16 @@
package logger
import (
"sync/atomic"
"time"
)
// Data is a struct to define some variables to use in custom logger function.
type Data struct {
Pid string
ErrPaddingStr string
ChainErr error
Start time.Time
Stop time.Time
Timestamp atomic.Value
}

182
middleware/logger/logger.go Normal file
View file

@ -0,0 +1,182 @@
package logger
import (
"fmt"
"os"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/utils"
"github.com/mattn/go-colorable"
"github.com/mattn/go-isatty"
"github.com/valyala/bytebufferpool"
"github.com/valyala/fasthttp"
)
// New creates a new middleware handler
func New(config ...Config) fiber.Handler {
// Set default config
cfg := configDefault(config...)
// Get timezone location
tz, err := time.LoadLocation(cfg.TimeZone)
if err != nil || tz == nil {
cfg.timeZoneLocation = time.Local
} else {
cfg.timeZoneLocation = tz
}
// Check if format contains latency
cfg.enableLatency = strings.Contains(cfg.Format, "${"+TagLatency+"}")
var timestamp atomic.Value
// Create correct timeformat
timestamp.Store(time.Now().In(cfg.timeZoneLocation).Format(cfg.TimeFormat))
// Update date/time every 500 milliseconds in a separate go routine
if strings.Contains(cfg.Format, "${"+TagTime+"}") {
go func() {
for {
time.Sleep(cfg.TimeInterval)
timestamp.Store(time.Now().In(cfg.timeZoneLocation).Format(cfg.TimeFormat))
}
}()
}
// Set PID once
pid := strconv.Itoa(os.Getpid())
// Set variables
var (
once sync.Once
mu sync.Mutex
errHandler fiber.ErrorHandler
dataPool = sync.Pool{New: func() interface{} { return new(Data) }}
)
// If colors are enabled, check terminal compatibility
if cfg.enableColors {
cfg.Output = colorable.NewColorableStdout()
if os.Getenv("TERM") == "dumb" || os.Getenv("NO_COLOR") == "1" || (!isatty.IsTerminal(os.Stdout.Fd()) && !isatty.IsCygwinTerminal(os.Stdout.Fd())) {
cfg.Output = colorable.NewNonColorable(os.Stdout)
}
}
errPadding := 15
errPaddingStr := strconv.Itoa(errPadding)
// instead of analyzing the template inside(handler) each time, this is done once before
// and we create several slices of the same length with the functions to be executed and fixed parts.
templateChain, logFunChain, err := buildLogFuncChain(&cfg, createTagMap(&cfg))
if err != nil {
panic(err)
}
// Return new handler
return func(c *fiber.Ctx) error {
// Don't execute middleware if Next returns true
if cfg.Next != nil && cfg.Next(c) {
return c.Next()
}
// Set error handler once
once.Do(func() {
// get longested possible path
stack := c.App().Stack()
for m := range stack {
for r := range stack[m] {
if len(stack[m][r].Path) > errPadding {
errPadding = len(stack[m][r].Path)
errPaddingStr = strconv.Itoa(errPadding)
}
}
}
// override error handler
errHandler = c.App().ErrorHandler
})
// Logger data
data := dataPool.Get().(*Data) //nolint:forcetypeassert,errcheck // We store nothing else in the pool
// no need for a reset, as long as we always override everything
data.Pid = pid
data.ErrPaddingStr = errPaddingStr
data.Timestamp = timestamp
// put data back in the pool
defer dataPool.Put(data)
// Set latency start time
if cfg.enableLatency {
data.Start = time.Now()
}
// Handle request, store err for logging
chainErr := c.Next()
data.ChainErr = chainErr
// Manually call error handler
if chainErr != nil {
if err := errHandler(c, chainErr); err != nil {
_ = c.SendStatus(fiber.StatusInternalServerError) //nolint:errcheck // TODO: Explain why we ignore the error here
}
}
// Set latency stop time
if cfg.enableLatency {
data.Stop = time.Now()
}
// Get new buffer
buf := bytebufferpool.Get()
var err error
// Loop over template parts execute dynamic parts and add fixed parts to the buffer
for i, logFunc := range logFunChain {
if logFunc == nil {
_, _ = buf.Write(templateChain[i]) //nolint:errcheck // This will never fail
} else if templateChain[i] == nil {
_, err = logFunc(buf, c, data, "")
} else {
_, err = logFunc(buf, c, data, utils.UnsafeString(templateChain[i]))
}
if err != nil {
break
}
}
// Also write errors to the buffer
if err != nil {
_, _ = buf.WriteString(err.Error()) //nolint:errcheck // This will never fail
}
mu.Lock()
// Write buffer to output
if _, err := cfg.Output.Write(buf.Bytes()); err != nil {
// Write error to output
if _, err := cfg.Output.Write([]byte(err.Error())); err != nil {
// There is something wrong with the given io.Writer
_, _ = fmt.Fprintf(os.Stderr, "Failed to write to log, %v\n", err)
}
}
mu.Unlock()
if cfg.Done != nil {
cfg.Done(c, buf.Bytes())
}
// Put buffer back to pool
bytebufferpool.Put(buf)
return nil
}
}
func appendInt(output Buffer, v int) (int, error) {
old := output.Len()
output.Set(fasthttp.AppendUint(output.Bytes(), v))
return output.Len() - old, nil
}

View file

@ -0,0 +1,650 @@
//nolint:bodyclose // Much easier to just ignore memory leaks in tests
package logger
import (
"bufio"
"bytes"
"errors"
"fmt"
"io"
"net/http"
"net/http/httptest"
"os"
"runtime"
"sync"
"testing"
"time"
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/middleware/requestid"
"github.com/gofiber/fiber/v2/utils"
"github.com/valyala/bytebufferpool"
"github.com/valyala/fasthttp"
)
// go test -run Test_Logger
func Test_Logger(t *testing.T) {
t.Parallel()
app := fiber.New()
buf := bytebufferpool.Get()
defer bytebufferpool.Put(buf)
app.Use(New(Config{
Format: "${error}",
Output: buf,
}))
app.Get("/", func(c *fiber.Ctx) error {
return errors.New("some random error")
})
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusInternalServerError, resp.StatusCode)
utils.AssertEqual(t, "some random error", buf.String())
}
// go test -run Test_Logger_locals
func Test_Logger_locals(t *testing.T) {
t.Parallel()
app := fiber.New()
buf := bytebufferpool.Get()
defer bytebufferpool.Put(buf)
app.Use(New(Config{
Format: "${locals:demo}",
Output: buf,
}))
app.Get("/", func(c *fiber.Ctx) error {
c.Locals("demo", "johndoe")
return c.SendStatus(fiber.StatusOK)
})
app.Get("/int", func(c *fiber.Ctx) error {
c.Locals("demo", 55)
return c.SendStatus(fiber.StatusOK)
})
app.Get("/empty", func(c *fiber.Ctx) error {
return c.SendStatus(fiber.StatusOK)
})
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
utils.AssertEqual(t, "johndoe", buf.String())
buf.Reset()
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/int", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
utils.AssertEqual(t, "55", buf.String())
buf.Reset()
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/empty", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
utils.AssertEqual(t, "", buf.String())
}
// go test -run Test_Logger_Next
func Test_Logger_Next(t *testing.T) {
t.Parallel()
app := fiber.New()
app.Use(New(Config{
Next: func(_ *fiber.Ctx) bool {
return true
},
}))
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusNotFound, resp.StatusCode)
}
// go test -run Test_Logger_Done
func Test_Logger_Done(t *testing.T) {
t.Parallel()
buf := bytes.NewBuffer(nil)
app := fiber.New()
app.Use(New(Config{
Done: func(c *fiber.Ctx, logString []byte) {
if c.Response().StatusCode() == fiber.StatusOK {
_, err := buf.Write(logString)
utils.AssertEqual(t, nil, err)
}
},
})).Get("/logging", func(ctx *fiber.Ctx) error {
return ctx.SendStatus(fiber.StatusOK)
})
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/logging", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
utils.AssertEqual(t, true, buf.Len() > 0)
}
// go test -run Test_Logger_ErrorTimeZone
func Test_Logger_ErrorTimeZone(t *testing.T) {
t.Parallel()
app := fiber.New()
app.Use(New(Config{
TimeZone: "invalid",
}))
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusNotFound, resp.StatusCode)
}
type fakeOutput int
func (o *fakeOutput) Write([]byte) (int, error) {
*o++
return 0, nil
}
// go test -run Test_Logger_ErrorOutput_WithoutColor
func Test_Logger_ErrorOutput_WithoutColor(t *testing.T) {
t.Parallel()
o := new(fakeOutput)
app := fiber.New()
app.Use(New(Config{
Output: o,
DisableColors: true,
}))
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusNotFound, resp.StatusCode)
utils.AssertEqual(t, 1, int(*o))
}
// go test -run Test_Logger_ErrorOutput
func Test_Logger_ErrorOutput(t *testing.T) {
t.Parallel()
o := new(fakeOutput)
app := fiber.New()
app.Use(New(Config{
Output: o,
}))
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusNotFound, resp.StatusCode)
utils.AssertEqual(t, 1, int(*o))
}
// go test -run Test_Logger_All
func Test_Logger_All(t *testing.T) {
t.Parallel()
buf := bytebufferpool.Get()
defer bytebufferpool.Put(buf)
app := fiber.New()
app.Use(New(Config{
Format: "${pid}${reqHeaders}${referer}${protocol}${ip}${ips}${host}${url}${ua}${body}${route}${black}${red}${green}${yellow}${blue}${magenta}${cyan}${white}${reset}${error}${header:test}${query:test}${form:test}${cookie:test}${non}",
Output: buf,
}))
// Alias colors
colors := app.Config().ColorScheme
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/?foo=bar", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusNotFound, resp.StatusCode)
expected := fmt.Sprintf("%dHost=example.comhttp0.0.0.0example.com/?foo=bar/%s%s%s%s%s%s%s%s%sCannot GET /", os.Getpid(), colors.Black, colors.Red, colors.Green, colors.Yellow, colors.Blue, colors.Magenta, colors.Cyan, colors.White, colors.Reset)
utils.AssertEqual(t, expected, buf.String())
}
func getLatencyTimeUnits() []struct {
unit string
div time.Duration
} {
// windows does not support µs sleep precision
// https://github.com/golang/go/issues/29485
if runtime.GOOS == "windows" {
return []struct {
unit string
div time.Duration
}{
{"ms", time.Millisecond},
{"s", time.Second},
}
}
return []struct {
unit string
div time.Duration
}{
{"µs", time.Microsecond},
{"ms", time.Millisecond},
{"s", time.Second},
}
}
// go test -run Test_Logger_WithLatency
func Test_Logger_WithLatency(t *testing.T) {
t.Parallel()
buff := bytebufferpool.Get()
defer bytebufferpool.Put(buff)
app := fiber.New()
logger := New(Config{
Output: buff,
Format: "${latency}",
})
app.Use(logger)
// Define a list of time units to test
timeUnits := getLatencyTimeUnits()
// Initialize a new time unit
sleepDuration := 1 * time.Nanosecond
// Define a test route that sleeps
app.Get("/test", func(c *fiber.Ctx) error {
time.Sleep(sleepDuration)
return c.SendStatus(fiber.StatusOK)
})
// Loop through each time unit and assert that the log output contains the expected latency value
for _, tu := range timeUnits {
// Update the sleep duration for the next iteration
sleepDuration = 1 * tu.div
// Create a new HTTP request to the test route
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", nil), int(2*time.Second))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
// Assert that the log output contains the expected latency value in the current time unit
utils.AssertEqual(t, bytes.HasSuffix(buff.Bytes(), []byte(tu.unit)), true, fmt.Sprintf("Expected latency to be in %s, got %s", tu.unit, buff.String()))
// Reset the buffer
buff.Reset()
}
}
// go test -run Test_Logger_WithLatency_DefaultFormat
func Test_Logger_WithLatency_DefaultFormat(t *testing.T) {
t.Parallel()
buff := bytebufferpool.Get()
defer bytebufferpool.Put(buff)
app := fiber.New()
logger := New(Config{
Output: buff,
})
app.Use(logger)
// Define a list of time units to test
timeUnits := getLatencyTimeUnits()
// Initialize a new time unit
sleepDuration := 1 * time.Nanosecond
// Define a test route that sleeps
app.Get("/test", func(c *fiber.Ctx) error {
time.Sleep(sleepDuration)
return c.SendStatus(fiber.StatusOK)
})
// Loop through each time unit and assert that the log output contains the expected latency value
for _, tu := range timeUnits {
// Update the sleep duration for the next iteration
sleepDuration = 1 * tu.div
// Create a new HTTP request to the test route
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", nil), int(2*time.Second))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
// Assert that the log output contains the expected latency value in the current time unit
// parse out the latency value from the log output
latency := bytes.Split(buff.Bytes(), []byte(" | "))[2]
// Assert that the latency value is in the current time unit
utils.AssertEqual(t, bytes.HasSuffix(latency, []byte(tu.unit)), true, fmt.Sprintf("Expected latency to be in %s, got %s", tu.unit, latency))
// Reset the buffer
buff.Reset()
}
}
// go test -run Test_Query_Params
func Test_Query_Params(t *testing.T) {
t.Parallel()
buf := bytebufferpool.Get()
defer bytebufferpool.Put(buf)
app := fiber.New()
app.Use(New(Config{
Format: "${queryParams}",
Output: buf,
}))
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/?foo=bar&baz=moz", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusNotFound, resp.StatusCode)
expected := "foo=bar&baz=moz"
utils.AssertEqual(t, expected, buf.String())
}
// go test -run Test_Response_Body
func Test_Response_Body(t *testing.T) {
t.Parallel()
buf := bytebufferpool.Get()
defer bytebufferpool.Put(buf)
app := fiber.New()
app.Use(New(Config{
Format: "${resBody}",
Output: buf,
}))
app.Get("/", func(c *fiber.Ctx) error {
return c.SendString("Sample response body")
})
app.Post("/test", func(c *fiber.Ctx) error {
return c.Send([]byte("Post in test"))
})
_, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err)
expectedGetResponse := "Sample response body"
utils.AssertEqual(t, expectedGetResponse, buf.String())
buf.Reset() // Reset buffer to test POST
_, err = app.Test(httptest.NewRequest(fiber.MethodPost, "/test", nil))
utils.AssertEqual(t, nil, err)
expectedPostResponse := "Post in test"
utils.AssertEqual(t, expectedPostResponse, buf.String())
}
// go test -run Test_Logger_AppendUint
func Test_Logger_AppendUint(t *testing.T) {
t.Parallel()
app := fiber.New()
buf := bytebufferpool.Get()
defer bytebufferpool.Put(buf)
app.Use(New(Config{
Format: "${bytesReceived} ${bytesSent} ${status}",
Output: buf,
}))
app.Get("/", func(c *fiber.Ctx) error {
c.Response().Header.SetContentLength(5)
return c.SendString("hello")
})
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
utils.AssertEqual(t, "-2 5 200", buf.String())
}
// go test -run Test_Logger_Data_Race -race
func Test_Logger_Data_Race(t *testing.T) {
t.Parallel()
app := fiber.New()
buf := bytebufferpool.Get()
defer bytebufferpool.Put(buf)
app.Use(New(ConfigDefault))
app.Use(New(Config{
Format: "${time} | ${pid} | ${locals:requestid} | ${status} | ${latency} | ${method} | ${path}\n",
}))
app.Get("/", func(c *fiber.Ctx) error {
return c.SendString("hello")
})
var (
resp1, resp2 *http.Response
err1, err2 error
)
wg := &sync.WaitGroup{}
wg.Add(1)
go func() {
resp1, err1 = app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
wg.Done()
}()
resp2, err2 = app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
wg.Wait()
utils.AssertEqual(t, nil, err1)
utils.AssertEqual(t, fiber.StatusOK, resp1.StatusCode)
utils.AssertEqual(t, nil, err2)
utils.AssertEqual(t, fiber.StatusOK, resp2.StatusCode)
}
// go test -v -run=^$ -bench=Benchmark_Logger -benchmem -count=4
func Benchmark_Logger(b *testing.B) {
benchSetup := func(b *testing.B, app *fiber.App) {
b.Helper()
h := app.Handler()
fctx := &fasthttp.RequestCtx{}
fctx.Request.Header.SetMethod(fiber.MethodGet)
fctx.Request.SetRequestURI("/")
b.ReportAllocs()
b.ResetTimer()
for n := 0; n < b.N; n++ {
h(fctx)
}
utils.AssertEqual(b, 200, fctx.Response.Header.StatusCode())
}
b.Run("Base", func(bb *testing.B) {
app := fiber.New()
app.Use(New(Config{
Format: "${bytesReceived} ${bytesSent} ${status}",
Output: io.Discard,
}))
app.Get("/", func(c *fiber.Ctx) error {
c.Set("test", "test")
return c.SendString("Hello, World!")
})
benchSetup(bb, app)
})
b.Run("DefaultFormat", func(bb *testing.B) {
app := fiber.New()
app.Use(New(Config{
Output: io.Discard,
}))
app.Get("/", func(c *fiber.Ctx) error {
return c.SendString("Hello, World!")
})
benchSetup(bb, app)
})
b.Run("WithTagParameter", func(bb *testing.B) {
app := fiber.New()
app.Use(New(Config{
Format: "${bytesReceived} ${bytesSent} ${status} ${reqHeader:test}",
Output: io.Discard,
}))
app.Get("/", func(c *fiber.Ctx) error {
c.Set("test", "test")
return c.SendString("Hello, World!")
})
benchSetup(bb, app)
})
}
// go test -run Test_Response_Header
func Test_Response_Header(t *testing.T) {
t.Parallel()
buf := bytebufferpool.Get()
defer bytebufferpool.Put(buf)
app := fiber.New()
app.Use(requestid.New(requestid.Config{
Next: nil,
Header: fiber.HeaderXRequestID,
Generator: func() string { return "Hello fiber!" },
ContextKey: "requestid",
}))
app.Use(New(Config{
Format: "${respHeader:X-Request-ID}",
Output: buf,
}))
app.Get("/", func(c *fiber.Ctx) error {
return c.SendString("Hello fiber!")
})
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
utils.AssertEqual(t, "Hello fiber!", buf.String())
}
// go test -run Test_Req_Header
func Test_Req_Header(t *testing.T) {
t.Parallel()
buf := bytebufferpool.Get()
defer bytebufferpool.Put(buf)
app := fiber.New()
app.Use(New(Config{
Format: "${header:test}",
Output: buf,
}))
app.Get("/", func(c *fiber.Ctx) error {
return c.SendString("Hello fiber!")
})
headerReq := httptest.NewRequest(fiber.MethodGet, "/", nil)
headerReq.Header.Add("test", "Hello fiber!")
resp, err := app.Test(headerReq)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
utils.AssertEqual(t, "Hello fiber!", buf.String())
}
// go test -run Test_ReqHeader_Header
func Test_ReqHeader_Header(t *testing.T) {
t.Parallel()
buf := bytebufferpool.Get()
defer bytebufferpool.Put(buf)
app := fiber.New()
app.Use(New(Config{
Format: "${reqHeader:test}",
Output: buf,
}))
app.Get("/", func(c *fiber.Ctx) error {
return c.SendString("Hello fiber!")
})
reqHeaderReq := httptest.NewRequest(fiber.MethodGet, "/", nil)
reqHeaderReq.Header.Add("test", "Hello fiber!")
resp, err := app.Test(reqHeaderReq)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
utils.AssertEqual(t, "Hello fiber!", buf.String())
}
// go test -run Test_CustomTags
func Test_CustomTags(t *testing.T) {
t.Parallel()
customTag := "it is a custom tag"
buf := bytebufferpool.Get()
defer bytebufferpool.Put(buf)
app := fiber.New()
app.Use(New(Config{
Format: "${custom_tag}",
CustomTags: map[string]LogFunc{
"custom_tag": func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
return output.WriteString(customTag)
},
},
Output: buf,
}))
app.Get("/", func(c *fiber.Ctx) error {
return c.SendString("Hello fiber!")
})
reqHeaderReq := httptest.NewRequest(fiber.MethodGet, "/", nil)
reqHeaderReq.Header.Add("test", "Hello fiber!")
resp, err := app.Test(reqHeaderReq)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
utils.AssertEqual(t, customTag, buf.String())
}
// go test -run Test_Logger_ByteSent_Streaming
func Test_Logger_ByteSent_Streaming(t *testing.T) {
t.Parallel()
app := fiber.New()
buf := bytebufferpool.Get()
defer bytebufferpool.Put(buf)
app.Use(New(Config{
Format: "${bytesReceived} ${bytesSent} ${status}",
Output: buf,
}))
app.Get("/", func(c *fiber.Ctx) error {
c.Set("Connection", "keep-alive")
c.Set("Transfer-Encoding", "chunked")
c.Context().SetBodyStreamWriter(func(w *bufio.Writer) {
var i int
for {
i++
msg := fmt.Sprintf("%d - the time is %v", i, time.Now())
fmt.Fprintf(w, "data: Message: %s\n\n", msg)
err := w.Flush()
if err != nil {
break
}
if i == 10 {
break
}
}
})
return nil
})
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
utils.AssertEqual(t, "-2 -1 200", buf.String())
}
// go test -run Test_Logger_EnableColors
func Test_Logger_EnableColors(t *testing.T) {
t.Parallel()
o := new(fakeOutput)
app := fiber.New()
app.Use(New(Config{
Output: o,
}))
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusNotFound, resp.StatusCode)
utils.AssertEqual(t, 1, int(*o))
}

209
middleware/logger/tags.go Normal file
View file

@ -0,0 +1,209 @@
package logger
import (
"fmt"
"strconv"
"strings"
"github.com/gofiber/fiber/v2"
)
// Logger variables
const (
TagPid = "pid"
TagTime = "time"
TagReferer = "referer"
TagProtocol = "protocol"
TagPort = "port"
TagIP = "ip"
TagIPs = "ips"
TagHost = "host"
TagMethod = "method"
TagPath = "path"
TagURL = "url"
TagUA = "ua"
TagLatency = "latency"
TagStatus = "status"
TagResBody = "resBody"
TagReqHeaders = "reqHeaders"
TagQueryStringParams = "queryParams"
TagBody = "body"
TagBytesSent = "bytesSent"
TagBytesReceived = "bytesReceived"
TagRoute = "route"
TagError = "error"
// Deprecated: Use TagReqHeader instead
TagHeader = "header:"
TagReqHeader = "reqHeader:"
TagRespHeader = "respHeader:"
TagLocals = "locals:"
TagQuery = "query:"
TagForm = "form:"
TagCookie = "cookie:"
TagBlack = "black"
TagRed = "red"
TagGreen = "green"
TagYellow = "yellow"
TagBlue = "blue"
TagMagenta = "magenta"
TagCyan = "cyan"
TagWhite = "white"
TagReset = "reset"
)
// createTagMap function merged the default with the custom tags
func createTagMap(cfg *Config) map[string]LogFunc {
// Set default tags
tagFunctions := map[string]LogFunc{
TagReferer: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
return output.WriteString(c.Get(fiber.HeaderReferer))
},
TagProtocol: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
return output.WriteString(c.Protocol())
},
TagPort: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
return output.WriteString(c.Port())
},
TagIP: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
return output.WriteString(c.IP())
},
TagIPs: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
return output.WriteString(c.Get(fiber.HeaderXForwardedFor))
},
TagHost: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
return output.WriteString(c.Hostname())
},
TagPath: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
return output.WriteString(c.Path())
},
TagURL: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
return output.WriteString(c.OriginalURL())
},
TagUA: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
return output.WriteString(c.Get(fiber.HeaderUserAgent))
},
TagBody: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
return output.Write(c.Body())
},
TagBytesReceived: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
return output.WriteString(strconv.Itoa((c.Request().Header.ContentLength())))
},
TagBytesSent: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
return output.WriteString(strconv.Itoa((c.Response().Header.ContentLength())))
},
TagRoute: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
return output.WriteString(c.Route().Path)
},
TagResBody: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
return output.Write(c.Response().Body())
},
TagReqHeaders: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
reqHeaders := make([]string, 0)
for k, v := range c.GetReqHeaders() {
reqHeaders = append(reqHeaders, k+"="+strings.Join(v, ","))
}
return output.Write([]byte(strings.Join(reqHeaders, "&")))
},
TagQueryStringParams: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
return output.WriteString(c.Request().URI().QueryArgs().String())
},
TagBlack: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
return output.WriteString(c.App().Config().ColorScheme.Black)
},
TagRed: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
return output.WriteString(c.App().Config().ColorScheme.Red)
},
TagGreen: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
return output.WriteString(c.App().Config().ColorScheme.Green)
},
TagYellow: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
return output.WriteString(c.App().Config().ColorScheme.Yellow)
},
TagBlue: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
return output.WriteString(c.App().Config().ColorScheme.Blue)
},
TagMagenta: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
return output.WriteString(c.App().Config().ColorScheme.Magenta)
},
TagCyan: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
return output.WriteString(c.App().Config().ColorScheme.Cyan)
},
TagWhite: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
return output.WriteString(c.App().Config().ColorScheme.White)
},
TagReset: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
return output.WriteString(c.App().Config().ColorScheme.Reset)
},
TagError: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
if data.ChainErr != nil {
if cfg.enableColors {
colors := c.App().Config().ColorScheme
return output.WriteString(fmt.Sprintf("%s%s%s", colors.Red, data.ChainErr.Error(), colors.Reset))
}
return output.WriteString(data.ChainErr.Error())
}
return output.WriteString("-")
},
TagReqHeader: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
return output.WriteString(c.Get(extraParam))
},
TagHeader: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
return output.WriteString(c.Get(extraParam))
},
TagRespHeader: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
return output.WriteString(c.GetRespHeader(extraParam))
},
TagQuery: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
return output.WriteString(c.Query(extraParam))
},
TagForm: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
return output.WriteString(c.FormValue(extraParam))
},
TagCookie: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
return output.WriteString(c.Cookies(extraParam))
},
TagLocals: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
switch v := c.Locals(extraParam).(type) {
case []byte:
return output.Write(v)
case string:
return output.WriteString(v)
case nil:
return 0, nil
default:
return output.WriteString(fmt.Sprintf("%v", v))
}
},
TagStatus: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
if cfg.enableColors {
colors := c.App().Config().ColorScheme
return output.WriteString(fmt.Sprintf("%s%3d%s", statusColor(c.Response().StatusCode(), colors), c.Response().StatusCode(), colors.Reset))
}
return appendInt(output, c.Response().StatusCode())
},
TagMethod: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
if cfg.enableColors {
colors := c.App().Config().ColorScheme
return output.WriteString(fmt.Sprintf("%s%s%s", methodColor(c.Method(), colors), c.Method(), colors.Reset))
}
return output.WriteString(c.Method())
},
TagPid: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
return output.WriteString(data.Pid)
},
TagLatency: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
latency := data.Stop.Sub(data.Start)
return output.WriteString(fmt.Sprintf("%13v", latency))
},
TagTime: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
return output.WriteString(data.Timestamp.Load().(string)) //nolint:forcetypeassert // We always store a string in here
},
}
// merge with custom tags from user
for k, v := range cfg.CustomTags {
tagFunctions[k] = v
}
return tagFunctions
}

View file

@ -0,0 +1,70 @@
package logger
import (
"bytes"
"errors"
"github.com/gofiber/fiber/v2/utils"
)
// buildLogFuncChain analyzes the template and creates slices with the functions for execution and
// slices with the fixed parts of the template and the parameters
//
// fixParts contains the fixed parts of the template or parameters if a function is stored in the funcChain at this position
// funcChain contains for the parts which exist the functions for the dynamic parts
// funcChain and fixParts always have the same length and contain nil for the parts where no data is required in the chain,
// if a function exists for the part, a parameter for it can also exist in the fixParts slice
func buildLogFuncChain(cfg *Config, tagFunctions map[string]LogFunc) ([][]byte, []LogFunc, error) {
// process flow is copied from the fasttemplate flow https://github.com/valyala/fasttemplate/blob/2a2d1afadadf9715bfa19683cdaeac8347e5d9f9/template.go#L23-L62
templateB := utils.UnsafeBytes(cfg.Format)
startTagB := utils.UnsafeBytes(startTag)
endTagB := utils.UnsafeBytes(endTag)
paramSeparatorB := utils.UnsafeBytes(paramSeparator)
var fixParts [][]byte
var funcChain []LogFunc
for {
currentPos := bytes.Index(templateB, startTagB)
if currentPos < 0 {
// no starting tag found in the existing template part
break
}
// add fixed part
funcChain = append(funcChain, nil)
fixParts = append(fixParts, templateB[:currentPos])
templateB = templateB[currentPos+len(startTagB):]
currentPos = bytes.Index(templateB, endTagB)
if currentPos < 0 {
// cannot find end tag - just write it to the output.
funcChain = append(funcChain, nil)
fixParts = append(fixParts, startTagB)
break
}
// ## function block ##
// first check for tags with parameters
if index := bytes.Index(templateB[:currentPos], paramSeparatorB); index != -1 {
logFunc, ok := tagFunctions[utils.UnsafeString(templateB[:index+1])]
if !ok {
return nil, nil, errors.New("No parameter found in \"" + utils.UnsafeString(templateB[:currentPos]) + "\"")
}
funcChain = append(funcChain, logFunc)
// add param to the fixParts
fixParts = append(fixParts, templateB[index+1:currentPos])
} else if logFunc, ok := tagFunctions[utils.UnsafeString(templateB[:currentPos])]; ok {
// add functions without parameter
funcChain = append(funcChain, logFunc)
fixParts = append(fixParts, nil)
}
// ## function block end ##
// reduce the template string
templateB = templateB[currentPos+len(endTagB):]
}
// set the rest
funcChain = append(funcChain, nil)
fixParts = append(fixParts, templateB)
return fixParts, funcChain, nil
}

View file

@ -0,0 +1,39 @@
package logger
import (
"github.com/gofiber/fiber/v2"
)
func methodColor(method string, colors fiber.Colors) string {
switch method {
case fiber.MethodGet:
return colors.Cyan
case fiber.MethodPost:
return colors.Green
case fiber.MethodPut:
return colors.Yellow
case fiber.MethodDelete:
return colors.Red
case fiber.MethodPatch:
return colors.White
case fiber.MethodHead:
return colors.Magenta
case fiber.MethodOptions:
return colors.Blue
default:
return colors.Reset
}
}
func statusColor(code int, colors fiber.Colors) string {
switch {
case code >= fiber.StatusOK && code < fiber.StatusMultipleChoices:
return colors.Green
case code >= fiber.StatusMultipleChoices && code < fiber.StatusBadRequest:
return colors.Blue
case code >= fiber.StatusBadRequest && code < fiber.StatusInternalServerError:
return colors.Yellow
default:
return colors.Red
}
}

View file

@ -0,0 +1,132 @@
package monitor
import (
"time"
"github.com/gofiber/fiber/v2"
)
// Config defines the config for middleware.
type Config struct {
// Metrics page title
//
// Optional. Default: "Fiber Monitor"
Title string
// Refresh period
//
// Optional. Default: 3 seconds
Refresh time.Duration
// Whether the service should expose only the monitoring API.
//
// Optional. Default: false
APIOnly bool
// Next defines a function to skip this middleware when returned true.
//
// Optional. Default: nil
Next func(c *fiber.Ctx) bool
// Custom HTML Code to Head Section(Before End)
//
// Optional. Default: empty
CustomHead string
// FontURL for specify font resource path or URL . also you can use relative path
//
// Optional. Default: https://fonts.googleapis.com/css2?family=Roboto:wght@400;900&display=swap
FontURL string
// ChartJsURL for specify ChartJS library path or URL . also you can use relative path
//
// Optional. Default: https://cdn.jsdelivr.net/npm/chart.js@2.9/dist/Chart.bundle.min.js
ChartJsURL string // TODO: Rename to "ChartJSURL" in v3
index string
}
var ConfigDefault = Config{
Title: defaultTitle,
Refresh: defaultRefresh,
FontURL: defaultFontURL,
ChartJsURL: defaultChartJSURL,
CustomHead: defaultCustomHead,
APIOnly: false,
Next: nil,
index: newIndex(viewBag{
defaultTitle,
defaultRefresh,
defaultFontURL,
defaultChartJSURL,
defaultCustomHead,
}),
}
func configDefault(config ...Config) Config {
// Users can change ConfigDefault.Title/Refresh which then
// become incompatible with ConfigDefault.index
if ConfigDefault.Title != defaultTitle ||
ConfigDefault.Refresh != defaultRefresh ||
ConfigDefault.FontURL != defaultFontURL ||
ConfigDefault.ChartJsURL != defaultChartJSURL ||
ConfigDefault.CustomHead != defaultCustomHead {
if ConfigDefault.Refresh < minRefresh {
ConfigDefault.Refresh = minRefresh
}
// update default index with new default title/refresh
ConfigDefault.index = newIndex(viewBag{
ConfigDefault.Title,
ConfigDefault.Refresh,
ConfigDefault.FontURL,
ConfigDefault.ChartJsURL,
ConfigDefault.CustomHead,
})
}
// Return default config if nothing provided
if len(config) < 1 {
return ConfigDefault
}
// Override default config
cfg := config[0]
// Set default values
if cfg.Title == "" {
cfg.Title = ConfigDefault.Title
}
if cfg.Refresh == 0 {
cfg.Refresh = ConfigDefault.Refresh
}
if cfg.FontURL == "" {
cfg.FontURL = defaultFontURL
}
if cfg.ChartJsURL == "" {
cfg.ChartJsURL = defaultChartJSURL
}
if cfg.Refresh < minRefresh {
cfg.Refresh = minRefresh
}
if cfg.Next == nil {
cfg.Next = ConfigDefault.Next
}
if !cfg.APIOnly {
cfg.APIOnly = ConfigDefault.APIOnly
}
// update cfg.index with custom title/refresh
cfg.index = newIndex(viewBag{
title: cfg.Title,
refresh: cfg.Refresh,
fontURL: cfg.FontURL,
chartJSURL: cfg.ChartJsURL,
customHead: cfg.CustomHead,
})
return cfg
}

View file

@ -0,0 +1,163 @@
package monitor
import (
"testing"
"time"
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/utils"
)
func Test_Config_Default(t *testing.T) {
t.Parallel()
t.Run("use default", func(t *testing.T) {
t.Parallel()
cfg := configDefault()
utils.AssertEqual(t, defaultTitle, cfg.Title)
utils.AssertEqual(t, defaultRefresh, cfg.Refresh)
utils.AssertEqual(t, defaultFontURL, cfg.FontURL)
utils.AssertEqual(t, defaultChartJSURL, cfg.ChartJsURL)
utils.AssertEqual(t, defaultCustomHead, cfg.CustomHead)
utils.AssertEqual(t, false, cfg.APIOnly)
utils.AssertEqual(t, (func(*fiber.Ctx) bool)(nil), cfg.Next)
utils.AssertEqual(t, newIndex(viewBag{defaultTitle, defaultRefresh, defaultFontURL, defaultChartJSURL, defaultCustomHead}), cfg.index)
})
t.Run("set title", func(t *testing.T) {
t.Parallel()
title := "title"
cfg := configDefault(Config{
Title: title,
})
utils.AssertEqual(t, title, cfg.Title)
utils.AssertEqual(t, defaultRefresh, cfg.Refresh)
utils.AssertEqual(t, defaultFontURL, cfg.FontURL)
utils.AssertEqual(t, defaultChartJSURL, cfg.ChartJsURL)
utils.AssertEqual(t, defaultCustomHead, cfg.CustomHead)
utils.AssertEqual(t, false, cfg.APIOnly)
utils.AssertEqual(t, (func(*fiber.Ctx) bool)(nil), cfg.Next)
utils.AssertEqual(t, newIndex(viewBag{title, defaultRefresh, defaultFontURL, defaultChartJSURL, defaultCustomHead}), cfg.index)
})
t.Run("set refresh less than default", func(t *testing.T) {
t.Parallel()
cfg := configDefault(Config{
Refresh: 100 * time.Millisecond,
})
utils.AssertEqual(t, defaultTitle, cfg.Title)
utils.AssertEqual(t, minRefresh, cfg.Refresh)
utils.AssertEqual(t, defaultFontURL, cfg.FontURL)
utils.AssertEqual(t, defaultChartJSURL, cfg.ChartJsURL)
utils.AssertEqual(t, defaultCustomHead, cfg.CustomHead)
utils.AssertEqual(t, false, cfg.APIOnly)
utils.AssertEqual(t, (func(*fiber.Ctx) bool)(nil), cfg.Next)
utils.AssertEqual(t, newIndex(viewBag{defaultTitle, minRefresh, defaultFontURL, defaultChartJSURL, defaultCustomHead}), cfg.index)
})
t.Run("set refresh", func(t *testing.T) {
t.Parallel()
refresh := time.Second
cfg := configDefault(Config{
Refresh: refresh,
})
utils.AssertEqual(t, defaultTitle, cfg.Title)
utils.AssertEqual(t, refresh, cfg.Refresh)
utils.AssertEqual(t, defaultFontURL, cfg.FontURL)
utils.AssertEqual(t, defaultChartJSURL, cfg.ChartJsURL)
utils.AssertEqual(t, defaultCustomHead, cfg.CustomHead)
utils.AssertEqual(t, false, cfg.APIOnly)
utils.AssertEqual(t, (func(*fiber.Ctx) bool)(nil), cfg.Next)
utils.AssertEqual(t, newIndex(viewBag{defaultTitle, refresh, defaultFontURL, defaultChartJSURL, defaultCustomHead}), cfg.index)
})
t.Run("set font url", func(t *testing.T) {
t.Parallel()
fontURL := "https://example.com"
cfg := configDefault(Config{
FontURL: fontURL,
})
utils.AssertEqual(t, defaultTitle, cfg.Title)
utils.AssertEqual(t, defaultRefresh, cfg.Refresh)
utils.AssertEqual(t, fontURL, cfg.FontURL)
utils.AssertEqual(t, defaultChartJSURL, cfg.ChartJsURL)
utils.AssertEqual(t, defaultCustomHead, cfg.CustomHead)
utils.AssertEqual(t, false, cfg.APIOnly)
utils.AssertEqual(t, (func(*fiber.Ctx) bool)(nil), cfg.Next)
utils.AssertEqual(t, newIndex(viewBag{defaultTitle, defaultRefresh, fontURL, defaultChartJSURL, defaultCustomHead}), cfg.index)
})
t.Run("set chart js url", func(t *testing.T) {
t.Parallel()
chartURL := "http://example.com"
cfg := configDefault(Config{
ChartJsURL: chartURL,
})
utils.AssertEqual(t, defaultTitle, cfg.Title)
utils.AssertEqual(t, defaultRefresh, cfg.Refresh)
utils.AssertEqual(t, defaultFontURL, cfg.FontURL)
utils.AssertEqual(t, chartURL, cfg.ChartJsURL)
utils.AssertEqual(t, defaultCustomHead, cfg.CustomHead)
utils.AssertEqual(t, false, cfg.APIOnly)
utils.AssertEqual(t, (func(*fiber.Ctx) bool)(nil), cfg.Next)
utils.AssertEqual(t, newIndex(viewBag{defaultTitle, defaultRefresh, defaultFontURL, chartURL, defaultCustomHead}), cfg.index)
})
t.Run("set custom head", func(t *testing.T) {
t.Parallel()
head := "head"
cfg := configDefault(Config{
CustomHead: head,
})
utils.AssertEqual(t, defaultTitle, cfg.Title)
utils.AssertEqual(t, defaultRefresh, cfg.Refresh)
utils.AssertEqual(t, defaultFontURL, cfg.FontURL)
utils.AssertEqual(t, defaultChartJSURL, cfg.ChartJsURL)
utils.AssertEqual(t, head, cfg.CustomHead)
utils.AssertEqual(t, false, cfg.APIOnly)
utils.AssertEqual(t, (func(*fiber.Ctx) bool)(nil), cfg.Next)
utils.AssertEqual(t, newIndex(viewBag{defaultTitle, defaultRefresh, defaultFontURL, defaultChartJSURL, head}), cfg.index)
})
t.Run("set api only", func(t *testing.T) {
t.Parallel()
cfg := configDefault(Config{
APIOnly: true,
})
utils.AssertEqual(t, defaultTitle, cfg.Title)
utils.AssertEqual(t, defaultRefresh, cfg.Refresh)
utils.AssertEqual(t, defaultFontURL, cfg.FontURL)
utils.AssertEqual(t, defaultChartJSURL, cfg.ChartJsURL)
utils.AssertEqual(t, defaultCustomHead, cfg.CustomHead)
utils.AssertEqual(t, true, cfg.APIOnly)
utils.AssertEqual(t, (func(*fiber.Ctx) bool)(nil), cfg.Next)
utils.AssertEqual(t, newIndex(viewBag{defaultTitle, defaultRefresh, defaultFontURL, defaultChartJSURL, defaultCustomHead}), cfg.index)
})
t.Run("set next", func(t *testing.T) {
t.Parallel()
f := func(c *fiber.Ctx) bool {
return true
}
cfg := configDefault(Config{
Next: f,
})
utils.AssertEqual(t, defaultTitle, cfg.Title)
utils.AssertEqual(t, defaultRefresh, cfg.Refresh)
utils.AssertEqual(t, defaultFontURL, cfg.FontURL)
utils.AssertEqual(t, defaultChartJSURL, cfg.ChartJsURL)
utils.AssertEqual(t, defaultCustomHead, cfg.CustomHead)
utils.AssertEqual(t, false, cfg.APIOnly)
utils.AssertEqual(t, f(nil), cfg.Next(nil))
utils.AssertEqual(t, newIndex(viewBag{defaultTitle, defaultRefresh, defaultFontURL, defaultChartJSURL, defaultCustomHead}), cfg.index)
})
}

271
middleware/monitor/index.go Normal file
View file

@ -0,0 +1,271 @@
package monitor
import (
"strconv"
"strings"
"time"
)
type viewBag struct {
title string
refresh time.Duration
fontURL string
chartJSURL string
customHead string
}
// returns index with new title/refresh
func newIndex(dat viewBag) string {
timeout := dat.refresh.Milliseconds() - timeoutDiff
if timeout < timeoutDiff {
timeout = timeoutDiff
}
ts := strconv.FormatInt(timeout, 10)
replacer := strings.NewReplacer("$TITLE", dat.title, "$TIMEOUT", ts,
"$FONT_URL", dat.fontURL, "$CHART_JS_URL", dat.chartJSURL, "$CUSTOM_HEAD", dat.customHead,
)
return replacer.Replace(indexHTML)
}
const (
defaultTitle = "Fiber Monitor"
defaultRefresh = 3 * time.Second
timeoutDiff = 200 // timeout will be Refresh (in milliseconds) - timeoutDiff
minRefresh = timeoutDiff * time.Millisecond
defaultFontURL = `https://fonts.googleapis.com/css2?family=Roboto:wght@400;900&display=swap`
defaultChartJSURL = `https://cdn.jsdelivr.net/npm/chart.js@2.9/dist/Chart.bundle.min.js`
defaultCustomHead = ``
// parametrized by $TITLE and $TIMEOUT
indexHTML = `<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<link href="$FONT_URL" rel="stylesheet">
<script src="$CHART_JS_URL"></script>
<title>$TITLE</title>
<style>
body {
margin: 0;
font: 16px / 1.6 'Roboto', sans-serif;
}
.wrapper {
max-width: 900px;
margin: 0 auto;
padding: 30px 0;
}
.title {
text-align: center;
margin-bottom: 2em;
}
.title h1 {
font-size: 1.8em;
padding: 0;
margin: 0;
}
.row {
display: flex;
margin-bottom: 20px;
align-items: center;
}
.row .column:first-child { width: 35%; }
.row .column:last-child { width: 65%; }
.metric {
color: #777;
font-weight: 900;
}
h2 {
padding: 0;
margin: 0;
font-size: 2.2em;
}
h2 span {
font-size: 12px;
color: #777;
}
h2 span.ram_os { color: rgba(255, 150, 0, .8); }
h2 span.ram_total { color: rgba(0, 200, 0, .8); }
canvas {
width: 200px;
height: 180px;
}
$CUSTOM_HEAD
</style>
</head>
<body>
<section class="wrapper">
<div class="title"><h1>$TITLE</h1></div>
<section class="charts">
<div class="row">
<div class="column">
<div class="metric">CPU Usage</div>
<h2 id="cpuMetric">0.00%</h2>
</div>
<div class="column">
<canvas id="cpuChart"></canvas>
</div>
</div>
<div class="row">
<div class="column">
<div class="metric">Memory Usage</div>
<h2 id="ramMetric" title="PID used / OS used / OS total">0.00 MB</h2>
</div>
<div class="column">
<canvas id="ramChart"></canvas>
</div>
</div>
<div class="row">
<div class="column">
<div class="metric">Response Time</div>
<h2 id="rtimeMetric">0ms</h2>
</div>
<div class="column">
<canvas id="rtimeChart"></canvas>
</div>
</div>
<div class="row">
<div class="column">
<div class="metric">Open Connections</div>
<h2 id="connsMetric">0</h2>
</div>
<div class="column">
<canvas id="connsChart"></canvas>
</div>
</div>
</section>
</section>
<script>
function formatBytes(bytes, decimals = 1) {
if (bytes === 0) return '0 Bytes';
const k = 1024;
const dm = decimals < 0 ? 0 : decimals;
const sizes = ['Bytes', 'KB', 'MB', 'GB', 'TB', 'PB', 'EB', 'ZB', 'YB'];
const i = Math.floor(Math.log(bytes) / Math.log(k));
return parseFloat((bytes / Math.pow(k, i)).toFixed(dm)) + ' ' + sizes[i];
}
Chart.defaults.global.legend.display = false;
Chart.defaults.global.defaultFontSize = 8;
Chart.defaults.global.animation.duration = 1000;
Chart.defaults.global.animation.easing = 'easeOutQuart';
Chart.defaults.global.elements.line.backgroundColor = 'rgba(0, 172, 215, 0.25)';
Chart.defaults.global.elements.line.borderColor = 'rgba(0, 172, 215, 1)';
Chart.defaults.global.elements.line.borderWidth = 2;
const options = {
scales: {
yAxes: [{ ticks: { beginAtZero: true }}],
xAxes: [{
type: 'time',
time: {
unitStepSize: 30,
unit: 'second'
},
gridlines: { display: false }
}]
},
tooltips: { enabled: false },
responsive: true,
maintainAspectRatio: false,
animation: false
};
const cpuMetric = document.querySelector('#cpuMetric');
const ramMetric = document.querySelector('#ramMetric');
const rtimeMetric = document.querySelector('#rtimeMetric');
const connsMetric = document.querySelector('#connsMetric');
const cpuChartCtx = document.querySelector('#cpuChart').getContext('2d');
const ramChartCtx = document.querySelector('#ramChart').getContext('2d');
const rtimeChartCtx = document.querySelector('#rtimeChart').getContext('2d');
const connsChartCtx = document.querySelector('#connsChart').getContext('2d');
const cpuChart = createChart(cpuChartCtx);
const ramChart = createChart(ramChartCtx);
const rtimeChart = createChart(rtimeChartCtx);
const connsChart = createChart(connsChartCtx);
const charts = [cpuChart, ramChart, rtimeChart, connsChart];
function createChart(ctx) {
return new Chart(ctx, {
type: 'line',
data: {
labels: [],
datasets: [{
label: '',
data: [],
lineTension: 0.2,
pointRadius: 0,
}]
},
options
});
}
ramChart.data.datasets.push({
data: [],
lineTension: 0.2,
pointRadius: 0,
backgroundColor: 'rgba(255, 200, 0, .6)',
borderColor: 'rgba(255, 150, 0, .8)',
})
ramChart.data.datasets.push({
data: [],
lineTension: 0.2,
pointRadius: 0,
backgroundColor: 'rgba(0, 255, 0, .4)',
borderColor: 'rgba(0, 200, 0, .8)',
})
function update(json, rtime) {
cpu = json.pid.cpu.toFixed(1);
cpuOS = json.os.cpu.toFixed(1);
cpuMetric.innerHTML = cpu + '% <span>' + cpuOS + '%</span>';
ramMetric.innerHTML = formatBytes(json.pid.ram) + '<span> / </span><span class="ram_os">' + formatBytes(json.os.ram) +
'<span><span> / </span><span class="ram_total">' + formatBytes(json.os.total_ram) + '</span>';
rtimeMetric.innerHTML = rtime + 'ms <span>client</span>';
connsMetric.innerHTML = json.pid.conns + ' <span>' + json.os.conns + '</span>';
cpuChart.data.datasets[0].data.push(cpu);
ramChart.data.datasets[2].data.push((json.os.total_ram / 1e6).toFixed(2));
ramChart.data.datasets[1].data.push((json.os.ram / 1e6).toFixed(2));
ramChart.data.datasets[0].data.push((json.pid.ram / 1e6).toFixed(2));
rtimeChart.data.datasets[0].data.push(rtime);
connsChart.data.datasets[0].data.push(json.pid.conns);
const timestamp = new Date().getTime();
charts.forEach(chart => {
if (chart.data.labels.length > 50) {
chart.data.datasets.forEach(function (dataset) { dataset.data.shift(); });
chart.data.labels.shift();
}
chart.data.labels.push(timestamp);
chart.update();
});
setTimeout(fetchJSON, $TIMEOUT)
}
function fetchJSON() {
var t1 = ''
var t0 = performance.now()
fetch(window.location.href, {
headers: { 'Accept': 'application/json' },
credentials: 'same-origin'
})
.then(res => {
t1 = performance.now()
return res.json()
})
.then(res => { update(res, Math.round(t1 - t0)) })
.catch(console.error);
}
fetchJSON()
</script>
</body>
</html>
`
)

View file

@ -0,0 +1,137 @@
package monitor
import (
"os"
"runtime"
"sync"
"sync/atomic"
"time"
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/internal/gopsutil/cpu"
"github.com/gofiber/fiber/v2/internal/gopsutil/load"
"github.com/gofiber/fiber/v2/internal/gopsutil/mem"
"github.com/gofiber/fiber/v2/internal/gopsutil/net"
"github.com/gofiber/fiber/v2/internal/gopsutil/process"
)
type stats struct {
PID statsPID `json:"pid"`
OS statsOS `json:"os"`
}
type statsPID struct {
CPU float64 `json:"cpu"`
RAM uint64 `json:"ram"`
Conns int `json:"conns"`
}
type statsOS struct {
CPU float64 `json:"cpu"`
RAM uint64 `json:"ram"`
TotalRAM uint64 `json:"total_ram"`
LoadAvg float64 `json:"load_avg"`
Conns int `json:"conns"`
}
var (
monitPIDCPU atomic.Value
monitPIDRAM atomic.Value
monitPIDConns atomic.Value
monitOSCPU atomic.Value
monitOSRAM atomic.Value
monitOSTotalRAM atomic.Value
monitOSLoadAvg atomic.Value
monitOSConns atomic.Value
)
var (
mutex sync.RWMutex
once sync.Once
data = &stats{}
)
// New creates a new middleware handler
func New(config ...Config) fiber.Handler {
// Set default config
cfg := configDefault(config...)
// Start routine to update statistics
once.Do(func() {
p, _ := process.NewProcess(int32(os.Getpid())) //nolint:errcheck // TODO: Handle error
numcpu := runtime.NumCPU()
updateStatistics(p, numcpu)
go func() {
for {
time.Sleep(cfg.Refresh)
updateStatistics(p, numcpu)
}
}()
})
// Return new handler
//nolint:errcheck // Ignore the type-assertion errors
return func(c *fiber.Ctx) error {
// Don't execute middleware if Next returns true
if cfg.Next != nil && cfg.Next(c) {
return c.Next()
}
if c.Method() != fiber.MethodGet {
return fiber.ErrMethodNotAllowed
}
if c.Get(fiber.HeaderAccept) == fiber.MIMEApplicationJSON || cfg.APIOnly {
mutex.Lock()
data.PID.CPU, _ = monitPIDCPU.Load().(float64)
data.PID.RAM, _ = monitPIDRAM.Load().(uint64)
data.PID.Conns, _ = monitPIDConns.Load().(int)
data.OS.CPU, _ = monitOSCPU.Load().(float64)
data.OS.RAM, _ = monitOSRAM.Load().(uint64)
data.OS.TotalRAM, _ = monitOSTotalRAM.Load().(uint64)
data.OS.LoadAvg, _ = monitOSLoadAvg.Load().(float64)
data.OS.Conns, _ = monitOSConns.Load().(int)
mutex.Unlock()
return c.Status(fiber.StatusOK).JSON(data)
}
c.Set(fiber.HeaderContentType, fiber.MIMETextHTMLCharsetUTF8)
return c.Status(fiber.StatusOK).SendString(cfg.index)
}
}
func updateStatistics(p *process.Process, numcpu int) {
pidCPU, err := p.Percent(0)
if err == nil {
monitPIDCPU.Store(pidCPU / float64(numcpu))
}
if osCPU, err := cpu.Percent(0, false); err == nil && len(osCPU) > 0 {
monitOSCPU.Store(osCPU[0])
}
if pidRAM, err := p.MemoryInfo(); err == nil && pidRAM != nil {
monitPIDRAM.Store(pidRAM.RSS)
}
if osRAM, err := mem.VirtualMemory(); err == nil && osRAM != nil {
monitOSRAM.Store(osRAM.Used)
monitOSTotalRAM.Store(osRAM.Total)
}
if loadAvg, err := load.Avg(); err == nil && loadAvg != nil {
monitOSLoadAvg.Store(loadAvg.Load1)
}
pidConns, err := net.ConnectionsPid("tcp", p.Pid)
if err == nil {
monitPIDConns.Store(len(pidConns))
}
osConns, err := net.Connections("tcp")
if err == nil {
monitOSConns.Store(len(osConns))
}
}

View file

@ -0,0 +1,198 @@
package monitor
import (
"bytes"
"fmt"
"io"
"net/http/httptest"
"testing"
"time"
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/utils"
"github.com/valyala/fasthttp"
)
func Test_Monitor_405(t *testing.T) {
t.Parallel()
app := fiber.New()
app.Use("/", New())
resp, err := app.Test(httptest.NewRequest(fiber.MethodPost, "/", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 405, resp.StatusCode)
}
func Test_Monitor_Html(t *testing.T) {
t.Parallel()
app := fiber.New()
// defaults
app.Get("/", New())
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 200, resp.StatusCode)
utils.AssertEqual(t, fiber.MIMETextHTMLCharsetUTF8,
resp.Header.Get(fiber.HeaderContentType))
buf, err := io.ReadAll(resp.Body)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, true, bytes.Contains(buf, []byte("<title>"+defaultTitle+"</title>")))
timeoutLine := fmt.Sprintf("setTimeout(fetchJSON, %d)",
defaultRefresh.Milliseconds()-timeoutDiff)
utils.AssertEqual(t, true, bytes.Contains(buf, []byte(timeoutLine)))
// custom config
conf := Config{Title: "New " + defaultTitle, Refresh: defaultRefresh + time.Second}
app.Get("/custom", New(conf))
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/custom", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 200, resp.StatusCode)
utils.AssertEqual(t, fiber.MIMETextHTMLCharsetUTF8,
resp.Header.Get(fiber.HeaderContentType))
buf, err = io.ReadAll(resp.Body)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, true, bytes.Contains(buf, []byte("<title>"+conf.Title+"</title>")))
timeoutLine = fmt.Sprintf("setTimeout(fetchJSON, %d)",
conf.Refresh.Milliseconds()-timeoutDiff)
utils.AssertEqual(t, true, bytes.Contains(buf, []byte(timeoutLine)))
}
func Test_Monitor_Html_CustomCodes(t *testing.T) {
t.Parallel()
app := fiber.New()
// defaults
app.Get("/", New())
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 200, resp.StatusCode)
utils.AssertEqual(t, fiber.MIMETextHTMLCharsetUTF8,
resp.Header.Get(fiber.HeaderContentType))
buf, err := io.ReadAll(resp.Body)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, true, bytes.Contains(buf, []byte("<title>"+defaultTitle+"</title>")))
timeoutLine := fmt.Sprintf("setTimeout(fetchJSON, %d)",
defaultRefresh.Milliseconds()-timeoutDiff)
utils.AssertEqual(t, true, bytes.Contains(buf, []byte(timeoutLine)))
// custom config
conf := Config{
Title: "New " + defaultTitle,
Refresh: defaultRefresh + time.Second,
ChartJsURL: "https://cdnjs.com/libraries/Chart.js",
FontURL: "/public/my-font.css",
CustomHead: `<style>body{background:#fff}</style>`,
}
app.Get("/custom", New(conf))
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/custom", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 200, resp.StatusCode)
utils.AssertEqual(t, fiber.MIMETextHTMLCharsetUTF8,
resp.Header.Get(fiber.HeaderContentType))
buf, err = io.ReadAll(resp.Body)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, true, bytes.Contains(buf, []byte("<title>"+conf.Title+"</title>")))
utils.AssertEqual(t, true, bytes.Contains(buf, []byte("https://cdnjs.com/libraries/Chart.js")))
utils.AssertEqual(t, true, bytes.Contains(buf, []byte("/public/my-font.css")))
utils.AssertEqual(t, true, bytes.Contains(buf, []byte(conf.CustomHead)))
timeoutLine = fmt.Sprintf("setTimeout(fetchJSON, %d)",
conf.Refresh.Milliseconds()-timeoutDiff)
utils.AssertEqual(t, true, bytes.Contains(buf, []byte(timeoutLine)))
}
// go test -run Test_Monitor_JSON -race
func Test_Monitor_JSON(t *testing.T) {
t.Parallel()
app := fiber.New()
app.Get("/", New())
req := httptest.NewRequest(fiber.MethodGet, "/", nil)
req.Header.Set(fiber.HeaderAccept, fiber.MIMEApplicationJSON)
resp, err := app.Test(req)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 200, resp.StatusCode)
utils.AssertEqual(t, fiber.MIMEApplicationJSON, resp.Header.Get(fiber.HeaderContentType))
b, err := io.ReadAll(resp.Body)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, true, bytes.Contains(b, []byte("pid")))
utils.AssertEqual(t, true, bytes.Contains(b, []byte("os")))
}
// go test -v -run=^$ -bench=Benchmark_Monitor -benchmem -count=4
func Benchmark_Monitor(b *testing.B) {
app := fiber.New()
app.Get("/", New())
h := app.Handler()
fctx := &fasthttp.RequestCtx{}
fctx.Request.Header.SetMethod(fiber.MethodGet)
fctx.Request.SetRequestURI("/")
fctx.Request.Header.Set(fiber.HeaderAccept, fiber.MIMEApplicationJSON)
b.ReportAllocs()
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
h(fctx)
}
})
utils.AssertEqual(b, 200, fctx.Response.Header.StatusCode())
utils.AssertEqual(b,
fiber.MIMEApplicationJSON,
string(fctx.Response.Header.Peek(fiber.HeaderContentType)))
}
// go test -run Test_Monitor_Next
func Test_Monitor_Next(t *testing.T) {
t.Parallel()
app := fiber.New()
app.Use("/", New(Config{
Next: func(_ *fiber.Ctx) bool {
return true
},
}))
resp, err := app.Test(httptest.NewRequest(fiber.MethodPost, "/", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 404, resp.StatusCode)
}
// go test -run Test_Monitor_APIOnly -race
func Test_Monitor_APIOnly(t *testing.T) {
app := fiber.New()
app.Get("/", New(Config{
APIOnly: true,
}))
req := httptest.NewRequest(fiber.MethodGet, "/", nil)
req.Header.Set(fiber.HeaderAccept, fiber.MIMEApplicationJSON)
resp, err := app.Test(req)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 200, resp.StatusCode)
utils.AssertEqual(t, fiber.MIMEApplicationJSON, resp.Header.Get(fiber.HeaderContentType))
b, err := io.ReadAll(resp.Body)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, true, bytes.Contains(b, []byte("pid")))
utils.AssertEqual(t, true, bytes.Contains(b, []byte("os")))
}

View file

@ -0,0 +1,41 @@
package pprof
import (
"github.com/gofiber/fiber/v2"
)
// Config defines the config for middleware.
type Config struct {
// Next defines a function to skip this middleware when returned true.
//
// Optional. Default: nil
Next func(c *fiber.Ctx) bool
// Prefix defines a URL prefix added before "/debug/pprof".
// Note that it should start with (but not end with) a slash.
// Example: "/federated-fiber"
//
// Optional. Default: ""
Prefix string
}
var ConfigDefault = Config{
Next: nil,
}
func configDefault(config ...Config) Config {
// Return default config if nothing provided
if len(config) < 1 {
return ConfigDefault
}
// Override default config
cfg := config[0]
// Set default values
if cfg.Next == nil {
cfg.Next = ConfigDefault.Next
}
return cfg
}

95
middleware/pprof/pprof.go Normal file
View file

@ -0,0 +1,95 @@
package pprof
import (
"net/http/pprof"
"strings"
"github.com/gofiber/fiber/v2"
"github.com/valyala/fasthttp/fasthttpadaptor"
)
// New creates a new middleware handler
func New(config ...Config) fiber.Handler {
// Set default config
cfg := configDefault(config...)
// Set pprof adaptors
var (
pprofIndex = fasthttpadaptor.NewFastHTTPHandlerFunc(pprof.Index)
pprofCmdline = fasthttpadaptor.NewFastHTTPHandlerFunc(pprof.Cmdline)
pprofProfile = fasthttpadaptor.NewFastHTTPHandlerFunc(pprof.Profile)
pprofSymbol = fasthttpadaptor.NewFastHTTPHandlerFunc(pprof.Symbol)
pprofTrace = fasthttpadaptor.NewFastHTTPHandlerFunc(pprof.Trace)
pprofAllocs = fasthttpadaptor.NewFastHTTPHandlerFunc(pprof.Handler("allocs").ServeHTTP)
pprofBlock = fasthttpadaptor.NewFastHTTPHandlerFunc(pprof.Handler("block").ServeHTTP)
pprofGoroutine = fasthttpadaptor.NewFastHTTPHandlerFunc(pprof.Handler("goroutine").ServeHTTP)
pprofHeap = fasthttpadaptor.NewFastHTTPHandlerFunc(pprof.Handler("heap").ServeHTTP)
pprofMutex = fasthttpadaptor.NewFastHTTPHandlerFunc(pprof.Handler("mutex").ServeHTTP)
pprofThreadcreate = fasthttpadaptor.NewFastHTTPHandlerFunc(pprof.Handler("threadcreate").ServeHTTP)
)
// Construct actual prefix
prefix := cfg.Prefix + "/debug/pprof"
// Return new handler
return func(c *fiber.Ctx) error {
// Don't execute middleware if Next returns true
if cfg.Next != nil && cfg.Next(c) {
return c.Next()
}
path := c.Path()
// We are only interested in /debug/pprof routes
path, found := cutPrefix(path, prefix)
if !found {
return c.Next()
}
// Switch on trimmed path against constant strings
switch path {
case "/":
pprofIndex(c.Context())
case "/cmdline":
pprofCmdline(c.Context())
case "/profile":
pprofProfile(c.Context())
case "/symbol":
pprofSymbol(c.Context())
case "/trace":
pprofTrace(c.Context())
case "/allocs":
pprofAllocs(c.Context())
case "/block":
pprofBlock(c.Context())
case "/goroutine":
pprofGoroutine(c.Context())
case "/heap":
pprofHeap(c.Context())
case "/mutex":
pprofMutex(c.Context())
case "/threadcreate":
pprofThreadcreate(c.Context())
default:
// pprof index only works with trailing slash
if strings.HasSuffix(path, "/") {
path = strings.TrimRight(path, "/")
} else {
path = prefix + "/"
}
return c.Redirect(path, fiber.StatusFound)
}
return nil
}
}
// cutPrefix is a copy of [strings.CutPrefix] added in Go 1.20.
// Remove this function when we drop support for Go 1.19.
//
//nolint:nonamedreturns // Align with its original form in std.
func cutPrefix(s, prefix string) (after string, found bool) {
if !strings.HasPrefix(s, prefix) {
return s, false
}
return s[len(prefix):], true
}

View file

@ -0,0 +1,200 @@
package pprof
import (
"bytes"
"io"
"net/http/httptest"
"testing"
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/utils"
)
func Test_Non_Pprof_Path(t *testing.T) {
app := fiber.New(fiber.Config{DisableStartupMessage: true})
app.Use(New())
app.Get("/", func(c *fiber.Ctx) error {
return c.SendString("escaped")
})
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 200, resp.StatusCode)
b, err := io.ReadAll(resp.Body)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, "escaped", string(b))
}
func Test_Non_Pprof_Path_WithPrefix(t *testing.T) {
app := fiber.New(fiber.Config{DisableStartupMessage: true})
app.Use(New(Config{Prefix: "/federated-fiber"}))
app.Get("/", func(c *fiber.Ctx) error {
return c.SendString("escaped")
})
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 200, resp.StatusCode)
b, err := io.ReadAll(resp.Body)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, "escaped", string(b))
}
func Test_Pprof_Index(t *testing.T) {
app := fiber.New(fiber.Config{DisableStartupMessage: true})
app.Use(New())
app.Get("/", func(c *fiber.Ctx) error {
return c.SendString("escaped")
})
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/debug/pprof/", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 200, resp.StatusCode)
utils.AssertEqual(t, fiber.MIMETextHTMLCharsetUTF8, resp.Header.Get(fiber.HeaderContentType))
b, err := io.ReadAll(resp.Body)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, true, bytes.Contains(b, []byte("<title>/debug/pprof/</title>")))
}
func Test_Pprof_Index_WithPrefix(t *testing.T) {
app := fiber.New(fiber.Config{DisableStartupMessage: true})
app.Use(New(Config{Prefix: "/federated-fiber"}))
app.Get("/", func(c *fiber.Ctx) error {
return c.SendString("escaped")
})
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/federated-fiber/debug/pprof/", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 200, resp.StatusCode)
utils.AssertEqual(t, fiber.MIMETextHTMLCharsetUTF8, resp.Header.Get(fiber.HeaderContentType))
b, err := io.ReadAll(resp.Body)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, true, bytes.Contains(b, []byte("<title>/debug/pprof/</title>")))
}
func Test_Pprof_Subs(t *testing.T) {
app := fiber.New(fiber.Config{DisableStartupMessage: true})
app.Use(New())
app.Get("/", func(c *fiber.Ctx) error {
return c.SendString("escaped")
})
subs := []string{
"cmdline", "profile", "symbol", "trace", "allocs", "block",
"goroutine", "heap", "mutex", "threadcreate",
}
for _, sub := range subs {
sub := sub
t.Run(sub, func(t *testing.T) {
target := "/debug/pprof/" + sub
if sub == "profile" {
target += "?seconds=1"
}
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, target, nil), 5000)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 200, resp.StatusCode)
})
}
}
func Test_Pprof_Subs_WithPrefix(t *testing.T) {
app := fiber.New(fiber.Config{DisableStartupMessage: true})
app.Use(New(Config{Prefix: "/federated-fiber"}))
app.Get("/", func(c *fiber.Ctx) error {
return c.SendString("escaped")
})
subs := []string{
"cmdline", "profile", "symbol", "trace", "allocs", "block",
"goroutine", "heap", "mutex", "threadcreate",
}
for _, sub := range subs {
sub := sub
t.Run(sub, func(t *testing.T) {
target := "/federated-fiber/debug/pprof/" + sub
if sub == "profile" {
target += "?seconds=1"
}
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, target, nil), 5000)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 200, resp.StatusCode)
})
}
}
func Test_Pprof_Other(t *testing.T) {
app := fiber.New(fiber.Config{DisableStartupMessage: true})
app.Use(New())
app.Get("/", func(c *fiber.Ctx) error {
return c.SendString("escaped")
})
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/debug/pprof/302", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 302, resp.StatusCode)
}
func Test_Pprof_Other_WithPrefix(t *testing.T) {
app := fiber.New(fiber.Config{DisableStartupMessage: true})
app.Use(New(Config{Prefix: "/federated-fiber"}))
app.Get("/", func(c *fiber.Ctx) error {
return c.SendString("escaped")
})
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/federated-fiber/debug/pprof/302", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 302, resp.StatusCode)
}
// go test -run Test_Pprof_Next
func Test_Pprof_Next(t *testing.T) {
app := fiber.New()
app.Use(New(Config{
Next: func(_ *fiber.Ctx) bool {
return true
},
}))
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/debug/pprof/", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 404, resp.StatusCode)
}
// go test -run Test_Pprof_Next_WithPrefix
func Test_Pprof_Next_WithPrefix(t *testing.T) {
app := fiber.New()
app.Use(New(Config{
Next: func(_ *fiber.Ctx) bool {
return true
},
Prefix: "/federated-fiber",
}))
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/federated-fiber/debug/pprof/", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 404, resp.StatusCode)
}

View file

@ -0,0 +1,88 @@
package proxy
import (
"crypto/tls"
"time"
"github.com/gofiber/fiber/v2"
"github.com/valyala/fasthttp"
)
// Config defines the config for middleware.
type Config struct {
// Next defines a function to skip this middleware when returned true.
//
// Optional. Default: nil
Next func(c *fiber.Ctx) bool
// Servers defines a list of <scheme>://<host> HTTP servers,
//
// which are used in a round-robin manner.
// i.e.: "https://foobar.com, http://www.foobar.com"
//
// Required
Servers []string
// ModifyRequest allows you to alter the request
//
// Optional. Default: nil
ModifyRequest fiber.Handler
// ModifyResponse allows you to alter the response
//
// Optional. Default: nil
ModifyResponse fiber.Handler
// Timeout is the request timeout used when calling the proxy client
//
// Optional. Default: 1 second
Timeout time.Duration
// Per-connection buffer size for requests' reading.
// This also limits the maximum header size.
// Increase this buffer if your clients send multi-KB RequestURIs
// and/or multi-KB headers (for example, BIG cookies).
ReadBufferSize int
// Per-connection buffer size for responses' writing.
WriteBufferSize int
// tls config for the http client.
TlsConfig *tls.Config //nolint:stylecheck,revive // TODO: Rename to "TLSConfig" in v3
// Client is custom client when client config is complex.
// Note that Servers, Timeout, WriteBufferSize, ReadBufferSize and TlsConfig
// will not be used if the client are set.
Client *fasthttp.LBClient
}
// ConfigDefault is the default config
var ConfigDefault = Config{
Next: nil,
ModifyRequest: nil,
ModifyResponse: nil,
Timeout: fasthttp.DefaultLBClientTimeout,
}
// configDefault function to set default values
func configDefault(config ...Config) Config {
// Return default config if nothing provided
if len(config) < 1 {
return ConfigDefault
}
// Override default config
cfg := config[0]
// Set default values
if cfg.Timeout <= 0 {
cfg.Timeout = ConfigDefault.Timeout
}
// Set default values
if len(cfg.Servers) == 0 && cfg.Client == nil {
panic("Servers cannot be empty")
}
return cfg
}

267
middleware/proxy/proxy.go Normal file
View file

@ -0,0 +1,267 @@
package proxy
import (
"bytes"
"crypto/tls"
"net/url"
"strings"
"sync"
"time"
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/log"
"github.com/gofiber/fiber/v2/utils"
"github.com/valyala/fasthttp"
)
// New is deprecated
func New(config Config) fiber.Handler {
log.Warn("[PROXY] proxy.New is deprecated, please use proxy.Balancer instead")
return Balancer(config)
}
// Balancer creates a load balancer among multiple upstream servers
func Balancer(config Config) fiber.Handler {
// Set default config
cfg := configDefault(config)
// Load balanced client
lbc := &fasthttp.LBClient{}
// Note that Servers, Timeout, WriteBufferSize, ReadBufferSize and TlsConfig
// will not be used if the client are set.
if config.Client == nil {
// Set timeout
lbc.Timeout = cfg.Timeout
// Scheme must be provided, falls back to http
for _, server := range cfg.Servers {
if !strings.HasPrefix(server, "http") {
server = "http://" + server
}
u, err := url.Parse(server)
if err != nil {
panic(err)
}
client := &fasthttp.HostClient{
NoDefaultUserAgentHeader: true,
DisablePathNormalizing: true,
Addr: u.Host,
ReadBufferSize: config.ReadBufferSize,
WriteBufferSize: config.WriteBufferSize,
TLSConfig: config.TlsConfig,
}
lbc.Clients = append(lbc.Clients, client)
}
} else {
// Set custom client
lbc = config.Client
}
// Return new handler
return func(c *fiber.Ctx) error {
// Don't execute middleware if Next returns true
if cfg.Next != nil && cfg.Next(c) {
return c.Next()
}
// Set request and response
req := c.Request()
res := c.Response()
// Don't proxy "Connection" header
req.Header.Del(fiber.HeaderConnection)
// Modify request
if cfg.ModifyRequest != nil {
if err := cfg.ModifyRequest(c); err != nil {
return err
}
}
req.SetRequestURI(utils.UnsafeString(req.RequestURI()))
// Forward request
if err := lbc.Do(req, res); err != nil {
return err
}
// Don't proxy "Connection" header
res.Header.Del(fiber.HeaderConnection)
// Modify response
if cfg.ModifyResponse != nil {
if err := cfg.ModifyResponse(c); err != nil {
return err
}
}
// Return nil to end proxying if no error
return nil
}
}
var client = &fasthttp.Client{
NoDefaultUserAgentHeader: true,
DisablePathNormalizing: true,
}
var lock sync.RWMutex
// WithTlsConfig update http client with a user specified tls.config
// This function should be called before Do and Forward.
// Deprecated: use WithClient instead.
//
//nolint:stylecheck,revive // TODO: Rename to "WithTLSConfig" in v3
func WithTlsConfig(tlsConfig *tls.Config) {
client.TLSConfig = tlsConfig
}
// WithClient sets the global proxy client.
// This function should be called before Do and Forward.
func WithClient(cli *fasthttp.Client) {
lock.Lock()
defer lock.Unlock()
client = cli
}
// Forward performs the given http request and fills the given http response.
// This method will return an fiber.Handler
func Forward(addr string, clients ...*fasthttp.Client) fiber.Handler {
return func(c *fiber.Ctx) error {
return Do(c, addr, clients...)
}
}
// Do performs the given http request and fills the given http response.
// This method can be used within a fiber.Handler
func Do(c *fiber.Ctx, addr string, clients ...*fasthttp.Client) error {
return doAction(c, addr, func(cli *fasthttp.Client, req *fasthttp.Request, resp *fasthttp.Response) error {
return cli.Do(req, resp)
}, clients...)
}
// DoRedirects performs the given http request and fills the given http response, following up to maxRedirectsCount redirects.
// When the redirect count exceeds maxRedirectsCount, ErrTooManyRedirects is returned.
// This method can be used within a fiber.Handler
func DoRedirects(c *fiber.Ctx, addr string, maxRedirectsCount int, clients ...*fasthttp.Client) error {
return doAction(c, addr, func(cli *fasthttp.Client, req *fasthttp.Request, resp *fasthttp.Response) error {
return cli.DoRedirects(req, resp, maxRedirectsCount)
}, clients...)
}
// DoDeadline performs the given request and waits for response until the given deadline.
// This method can be used within a fiber.Handler
func DoDeadline(c *fiber.Ctx, addr string, deadline time.Time, clients ...*fasthttp.Client) error {
return doAction(c, addr, func(cli *fasthttp.Client, req *fasthttp.Request, resp *fasthttp.Response) error {
return cli.DoDeadline(req, resp, deadline)
}, clients...)
}
// DoTimeout performs the given request and waits for response during the given timeout duration.
// This method can be used within a fiber.Handler
func DoTimeout(c *fiber.Ctx, addr string, timeout time.Duration, clients ...*fasthttp.Client) error {
return doAction(c, addr, func(cli *fasthttp.Client, req *fasthttp.Request, resp *fasthttp.Response) error {
return cli.DoTimeout(req, resp, timeout)
}, clients...)
}
func doAction(
c *fiber.Ctx,
addr string,
action func(cli *fasthttp.Client, req *fasthttp.Request, resp *fasthttp.Response) error,
clients ...*fasthttp.Client,
) error {
var cli *fasthttp.Client
// set local or global client
if len(clients) != 0 {
cli = clients[0]
} else {
lock.RLock()
cli = client
lock.RUnlock()
}
req := c.Request()
res := c.Response()
originalURL := utils.CopyString(c.OriginalURL())
defer req.SetRequestURI(originalURL)
copiedURL := utils.CopyString(addr)
req.SetRequestURI(copiedURL)
// NOTE: if req.isTLS is true, SetRequestURI keeps the scheme as https.
// Reference: https://github.com/gofiber/fiber/issues/1762
if scheme := getScheme(utils.UnsafeBytes(copiedURL)); len(scheme) > 0 {
req.URI().SetSchemeBytes(scheme)
}
req.Header.Del(fiber.HeaderConnection)
if err := action(cli, req, res); err != nil {
return err
}
res.Header.Del(fiber.HeaderConnection)
return nil
}
func getScheme(uri []byte) []byte {
i := bytes.IndexByte(uri, '/')
if i < 1 || uri[i-1] != ':' || i == len(uri)-1 || uri[i+1] != '/' {
return nil
}
return uri[:i-1]
}
// DomainForward performs an http request based on the given domain and populates the given http response.
// This method will return an fiber.Handler
func DomainForward(hostname, addr string, clients ...*fasthttp.Client) fiber.Handler {
return func(c *fiber.Ctx) error {
host := string(c.Request().Host())
if host == hostname {
return Do(c, addr+c.OriginalURL(), clients...)
}
return nil
}
}
type roundrobin struct {
sync.Mutex
current int
pool []string
}
// this method will return a string of addr server from list server.
func (r *roundrobin) get() string {
r.Lock()
defer r.Unlock()
if r.current >= len(r.pool) {
r.current %= len(r.pool)
}
result := r.pool[r.current]
r.current++
return result
}
// BalancerForward Forward performs the given http request with round robin algorithm to server and fills the given http response.
// This method will return an fiber.Handler
func BalancerForward(servers []string, clients ...*fasthttp.Client) fiber.Handler {
r := &roundrobin{
current: 0,
pool: servers,
}
return func(c *fiber.Ctx) error {
server := r.get()
if !strings.HasPrefix(server, "http") {
server = "http://" + server
}
c.Request().Header.Add("X-Real-IP", c.IP())
return Do(c, server+c.OriginalURL(), clients...)
}
}

View file

@ -0,0 +1,689 @@
package proxy
import (
"crypto/tls"
"errors"
"io"
"net"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/internal/tlstest"
"github.com/gofiber/fiber/v2/utils"
"github.com/valyala/fasthttp"
)
func createProxyTestServer(t *testing.T, handler fiber.Handler) (*fiber.App, string) {
t.Helper()
target := fiber.New(fiber.Config{DisableStartupMessage: true})
target.Get("/", handler)
ln, err := net.Listen(fiber.NetworkTCP4, "127.0.0.1:0")
utils.AssertEqual(t, nil, err)
go func() {
utils.AssertEqual(t, nil, target.Listener(ln))
}()
time.Sleep(2 * time.Second)
addr := ln.Addr().String()
return target, addr
}
// go test -run Test_Proxy_Empty_Host
func Test_Proxy_Empty_Upstream_Servers(t *testing.T) {
t.Parallel()
defer func() {
if r := recover(); r != nil {
utils.AssertEqual(t, "Servers cannot be empty", r)
}
}()
app := fiber.New()
app.Use(Balancer(Config{Servers: []string{}}))
}
// go test -run Test_Proxy_Empty_Config
func Test_Proxy_Empty_Config(t *testing.T) {
t.Parallel()
defer func() {
if r := recover(); r != nil {
utils.AssertEqual(t, "Servers cannot be empty", r)
}
}()
app := fiber.New()
app.Use(New(Config{}))
}
// go test -run Test_Proxy_Next
func Test_Proxy_Next(t *testing.T) {
t.Parallel()
app := fiber.New()
app.Use(Balancer(Config{
Servers: []string{"127.0.0.1"},
Next: func(_ *fiber.Ctx) bool {
return true
},
}))
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusNotFound, resp.StatusCode)
}
// go test -run Test_Proxy
func Test_Proxy(t *testing.T) {
t.Parallel()
target, addr := createProxyTestServer(t, func(c *fiber.Ctx) error {
return c.SendStatus(fiber.StatusTeapot)
})
resp, err := target.Test(httptest.NewRequest(fiber.MethodGet, "/", nil), 2000)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusTeapot, resp.StatusCode)
app := fiber.New(fiber.Config{DisableStartupMessage: true})
app.Use(Balancer(Config{Servers: []string{addr}}))
req := httptest.NewRequest(fiber.MethodGet, "/", nil)
req.Host = addr
resp, err = app.Test(req)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusTeapot, resp.StatusCode)
}
// go test -run Test_Proxy_Balancer_WithTlsConfig
func Test_Proxy_Balancer_WithTlsConfig(t *testing.T) {
t.Parallel()
serverTLSConf, _, err := tlstest.GetTLSConfigs()
utils.AssertEqual(t, nil, err)
ln, err := net.Listen(fiber.NetworkTCP4, "127.0.0.1:0")
utils.AssertEqual(t, nil, err)
ln = tls.NewListener(ln, serverTLSConf)
app := fiber.New(fiber.Config{DisableStartupMessage: true})
app.Get("/tlsbalaner", func(c *fiber.Ctx) error {
return c.SendString("tls balancer")
})
addr := ln.Addr().String()
clientTLSConf := &tls.Config{InsecureSkipVerify: true} //nolint:gosec // We're in a test func, so this is fine
// disable certificate verification in Balancer
app.Use(Balancer(Config{
Servers: []string{addr},
TlsConfig: clientTLSConf,
}))
go func() { utils.AssertEqual(t, nil, app.Listener(ln)) }()
code, body, errs := fiber.Get("https://" + addr + "/tlsbalaner").TLSConfig(clientTLSConf).String()
utils.AssertEqual(t, 0, len(errs))
utils.AssertEqual(t, fiber.StatusOK, code)
utils.AssertEqual(t, "tls balancer", body)
}
// go test -run Test_Proxy_Forward_WithTlsConfig_To_Http
func Test_Proxy_Forward_WithTlsConfig_To_Http(t *testing.T) {
t.Parallel()
_, targetAddr := createProxyTestServer(t, func(c *fiber.Ctx) error {
return c.SendString("hello from target")
})
proxyServerTLSConf, _, err := tlstest.GetTLSConfigs()
utils.AssertEqual(t, nil, err)
proxyServerLn, err := net.Listen(fiber.NetworkTCP4, "127.0.0.1:0")
utils.AssertEqual(t, nil, err)
proxyServerLn = tls.NewListener(proxyServerLn, proxyServerTLSConf)
app := fiber.New(fiber.Config{DisableStartupMessage: true})
proxyAddr := proxyServerLn.Addr().String()
app.Use(Forward("http://" + targetAddr))
go func() { utils.AssertEqual(t, nil, app.Listener(proxyServerLn)) }()
code, body, errs := fiber.Get("https://" + proxyAddr).
InsecureSkipVerify().
Timeout(5 * time.Second).
String()
utils.AssertEqual(t, 0, len(errs))
utils.AssertEqual(t, fiber.StatusOK, code)
utils.AssertEqual(t, "hello from target", body)
}
// go test -run Test_Proxy_Forward
func Test_Proxy_Forward(t *testing.T) {
t.Parallel()
app := fiber.New()
_, addr := createProxyTestServer(t, func(c *fiber.Ctx) error {
return c.SendString("forwarded")
})
app.Use(Forward("http://" + addr))
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
b, err := io.ReadAll(resp.Body)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, "forwarded", string(b))
}
// go test -run Test_Proxy_Forward_WithTlsConfig
func Test_Proxy_Forward_WithTlsConfig(t *testing.T) {
t.Parallel()
serverTLSConf, _, err := tlstest.GetTLSConfigs()
utils.AssertEqual(t, nil, err)
ln, err := net.Listen(fiber.NetworkTCP4, "127.0.0.1:0")
utils.AssertEqual(t, nil, err)
ln = tls.NewListener(ln, serverTLSConf)
app := fiber.New(fiber.Config{DisableStartupMessage: true})
app.Get("/tlsfwd", func(c *fiber.Ctx) error {
return c.SendString("tls forward")
})
addr := ln.Addr().String()
clientTLSConf := &tls.Config{InsecureSkipVerify: true} //nolint:gosec // We're in a test func, so this is fine
// disable certificate verification
WithTlsConfig(clientTLSConf)
app.Use(Forward("https://" + addr + "/tlsfwd"))
go func() { utils.AssertEqual(t, nil, app.Listener(ln)) }()
code, body, errs := fiber.Get("https://" + addr).TLSConfig(clientTLSConf).String()
utils.AssertEqual(t, 0, len(errs))
utils.AssertEqual(t, fiber.StatusOK, code)
utils.AssertEqual(t, "tls forward", body)
}
// go test -run Test_Proxy_Modify_Response
func Test_Proxy_Modify_Response(t *testing.T) {
t.Parallel()
_, addr := createProxyTestServer(t, func(c *fiber.Ctx) error {
return c.Status(500).SendString("not modified")
})
app := fiber.New()
app.Use(Balancer(Config{
Servers: []string{addr},
ModifyResponse: func(c *fiber.Ctx) error {
c.Response().SetStatusCode(fiber.StatusOK)
return c.SendString("modified response")
},
}))
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
b, err := io.ReadAll(resp.Body)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, "modified response", string(b))
}
// go test -run Test_Proxy_Modify_Request
func Test_Proxy_Modify_Request(t *testing.T) {
t.Parallel()
_, addr := createProxyTestServer(t, func(c *fiber.Ctx) error {
b := c.Request().Body()
return c.SendString(string(b))
})
app := fiber.New()
app.Use(Balancer(Config{
Servers: []string{addr},
ModifyRequest: func(c *fiber.Ctx) error {
c.Request().SetBody([]byte("modified request"))
return nil
},
}))
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
b, err := io.ReadAll(resp.Body)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, "modified request", string(b))
}
// go test -run Test_Proxy_Timeout_Slow_Server
func Test_Proxy_Timeout_Slow_Server(t *testing.T) {
t.Parallel()
_, addr := createProxyTestServer(t, func(c *fiber.Ctx) error {
time.Sleep(2 * time.Second)
return c.SendString("fiber is awesome")
})
app := fiber.New()
app.Use(Balancer(Config{
Servers: []string{addr},
Timeout: 3 * time.Second,
}))
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil), 5000)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
b, err := io.ReadAll(resp.Body)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, "fiber is awesome", string(b))
}
// go test -run Test_Proxy_With_Timeout
func Test_Proxy_With_Timeout(t *testing.T) {
t.Parallel()
_, addr := createProxyTestServer(t, func(c *fiber.Ctx) error {
time.Sleep(1 * time.Second)
return c.SendString("fiber is awesome")
})
app := fiber.New()
app.Use(Balancer(Config{
Servers: []string{addr},
Timeout: 100 * time.Millisecond,
}))
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil), 2000)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusInternalServerError, resp.StatusCode)
b, err := io.ReadAll(resp.Body)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, "timeout", string(b))
}
// go test -run Test_Proxy_Buffer_Size_Response
func Test_Proxy_Buffer_Size_Response(t *testing.T) {
t.Parallel()
_, addr := createProxyTestServer(t, func(c *fiber.Ctx) error {
long := strings.Join(make([]string, 5000), "-")
c.Set("Very-Long-Header", long)
return c.SendString("ok")
})
app := fiber.New()
app.Use(Balancer(Config{Servers: []string{addr}}))
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusInternalServerError, resp.StatusCode)
app = fiber.New()
app.Use(Balancer(Config{
Servers: []string{addr},
ReadBufferSize: 1024 * 8,
}))
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
}
// go test -race -run Test_Proxy_Do_RestoreOriginalURL
func Test_Proxy_Do_RestoreOriginalURL(t *testing.T) {
t.Parallel()
_, addr := createProxyTestServer(t, func(c *fiber.Ctx) error {
return c.SendString("proxied")
})
app := fiber.New()
app.Get("/test", func(c *fiber.Ctx) error {
return Do(c, "http://"+addr)
})
resp, err1 := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", nil))
utils.AssertEqual(t, nil, err1)
utils.AssertEqual(t, "/test", resp.Request.URL.String())
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
body, err := io.ReadAll(resp.Body)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, "proxied", string(body))
}
// go test -race -run Test_Proxy_Do_WithRealURL
func Test_Proxy_Do_WithRealURL(t *testing.T) {
t.Parallel()
app := fiber.New()
app.Get("/test", func(c *fiber.Ctx) error {
return Do(c, "https://www.google.com")
})
resp, err1 := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", nil))
utils.AssertEqual(t, nil, err1)
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
utils.AssertEqual(t, "/test", resp.Request.URL.String())
body, err := io.ReadAll(resp.Body)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, true, strings.Contains(string(body), "https://www.google.com/"))
}
// go test -race -run Test_Proxy_Do_WithRedirect
func Test_Proxy_Do_WithRedirect(t *testing.T) {
t.Parallel()
app := fiber.New()
app.Get("/test", func(c *fiber.Ctx) error {
return Do(c, "https://google.com")
})
resp, err1 := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", nil))
utils.AssertEqual(t, nil, err1)
body, err := io.ReadAll(resp.Body)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, true, strings.Contains(string(body), "https://www.google.com/"))
utils.AssertEqual(t, 301, resp.StatusCode)
}
// go test -race -run Test_Proxy_DoRedirects_RestoreOriginalURL
func Test_Proxy_DoRedirects_RestoreOriginalURL(t *testing.T) {
t.Parallel()
app := fiber.New()
app.Get("/test", func(c *fiber.Ctx) error {
return DoRedirects(c, "http://google.com", 1)
})
resp, err1 := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", nil))
utils.AssertEqual(t, nil, err1)
_, err := io.ReadAll(resp.Body)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
utils.AssertEqual(t, "/test", resp.Request.URL.String())
}
// go test -race -run Test_Proxy_DoRedirects_TooManyRedirects
func Test_Proxy_DoRedirects_TooManyRedirects(t *testing.T) {
t.Parallel()
app := fiber.New()
app.Get("/test", func(c *fiber.Ctx) error {
return DoRedirects(c, "http://google.com", 0)
})
resp, err1 := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", nil))
utils.AssertEqual(t, nil, err1)
body, err := io.ReadAll(resp.Body)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, "too many redirects detected when doing the request", string(body))
utils.AssertEqual(t, fiber.StatusInternalServerError, resp.StatusCode)
utils.AssertEqual(t, "/test", resp.Request.URL.String())
}
// go test -race -run Test_Proxy_DoTimeout_RestoreOriginalURL
func Test_Proxy_DoTimeout_RestoreOriginalURL(t *testing.T) {
t.Parallel()
_, addr := createProxyTestServer(t, func(c *fiber.Ctx) error {
return c.SendString("proxied")
})
app := fiber.New()
app.Get("/test", func(c *fiber.Ctx) error {
return DoTimeout(c, "http://"+addr, time.Second)
})
resp, err1 := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", nil))
utils.AssertEqual(t, nil, err1)
body, err := io.ReadAll(resp.Body)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, "proxied", string(body))
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
utils.AssertEqual(t, "/test", resp.Request.URL.String())
}
// go test -race -run Test_Proxy_DoTimeout_Timeout
func Test_Proxy_DoTimeout_Timeout(t *testing.T) {
t.Parallel()
_, addr := createProxyTestServer(t, func(c *fiber.Ctx) error {
time.Sleep(time.Second * 5)
return c.SendString("proxied")
})
app := fiber.New()
app.Get("/test", func(c *fiber.Ctx) error {
return DoTimeout(c, "http://"+addr, time.Second)
})
_, err1 := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", nil))
utils.AssertEqual(t, errors.New("test: timeout error 1000ms"), err1)
}
// go test -race -run Test_Proxy_DoDeadline_RestoreOriginalURL
func Test_Proxy_DoDeadline_RestoreOriginalURL(t *testing.T) {
t.Parallel()
_, addr := createProxyTestServer(t, func(c *fiber.Ctx) error {
return c.SendString("proxied")
})
app := fiber.New()
app.Get("/test", func(c *fiber.Ctx) error {
return DoDeadline(c, "http://"+addr, time.Now().Add(time.Second))
})
resp, err1 := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", nil))
utils.AssertEqual(t, nil, err1)
body, err := io.ReadAll(resp.Body)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, "proxied", string(body))
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
utils.AssertEqual(t, "/test", resp.Request.URL.String())
}
// go test -race -run Test_Proxy_DoDeadline_PastDeadline
func Test_Proxy_DoDeadline_PastDeadline(t *testing.T) {
t.Parallel()
_, addr := createProxyTestServer(t, func(c *fiber.Ctx) error {
time.Sleep(time.Second * 5)
return c.SendString("proxied")
})
app := fiber.New()
app.Get("/test", func(c *fiber.Ctx) error {
return DoDeadline(c, "http://"+addr, time.Now().Add(time.Second))
})
_, err1 := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", nil))
utils.AssertEqual(t, errors.New("test: timeout error 1000ms"), err1)
}
// go test -race -run Test_Proxy_Do_HTTP_Prefix_URL
func Test_Proxy_Do_HTTP_Prefix_URL(t *testing.T) {
t.Parallel()
_, addr := createProxyTestServer(t, func(c *fiber.Ctx) error {
return c.SendString("hello world")
})
app := fiber.New(fiber.Config{DisableStartupMessage: true})
app.Get("/*", func(c *fiber.Ctx) error {
path := c.OriginalURL()
url := strings.TrimPrefix(path, "/")
utils.AssertEqual(t, "http://"+addr, url)
if err := Do(c, url); err != nil {
return err
}
c.Response().Header.Del(fiber.HeaderServer)
return nil
})
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/http://"+addr, nil))
utils.AssertEqual(t, nil, err)
s, err := io.ReadAll(resp.Body)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, "hello world", string(s))
}
// go test -race -run Test_Proxy_Forward_Global_Client
func Test_Proxy_Forward_Global_Client(t *testing.T) {
t.Parallel()
ln, err := net.Listen(fiber.NetworkTCP4, "127.0.0.1:0")
utils.AssertEqual(t, nil, err)
WithClient(&fasthttp.Client{
NoDefaultUserAgentHeader: true,
DisablePathNormalizing: true,
})
app := fiber.New(fiber.Config{DisableStartupMessage: true})
app.Get("/test_global_client", func(c *fiber.Ctx) error {
return c.SendString("test_global_client")
})
addr := ln.Addr().String()
app.Use(Forward("http://" + addr + "/test_global_client"))
go func() { utils.AssertEqual(t, nil, app.Listener(ln)) }()
code, body, errs := fiber.Get("http://" + addr).String()
utils.AssertEqual(t, 0, len(errs))
utils.AssertEqual(t, fiber.StatusOK, code)
utils.AssertEqual(t, "test_global_client", body)
}
// go test -race -run Test_Proxy_Forward_Local_Client
func Test_Proxy_Forward_Local_Client(t *testing.T) {
t.Parallel()
ln, err := net.Listen(fiber.NetworkTCP4, "127.0.0.1:0")
utils.AssertEqual(t, nil, err)
app := fiber.New(fiber.Config{DisableStartupMessage: true})
app.Get("/test_local_client", func(c *fiber.Ctx) error {
return c.SendString("test_local_client")
})
addr := ln.Addr().String()
app.Use(Forward("http://"+addr+"/test_local_client", &fasthttp.Client{
NoDefaultUserAgentHeader: true,
DisablePathNormalizing: true,
Dial: fasthttp.Dial,
}))
go func() { utils.AssertEqual(t, nil, app.Listener(ln)) }()
code, body, errs := fiber.Get("http://" + addr).String()
utils.AssertEqual(t, 0, len(errs))
utils.AssertEqual(t, fiber.StatusOK, code)
utils.AssertEqual(t, "test_local_client", body)
}
// go test -run Test_ProxyBalancer_Custom_Client
func Test_ProxyBalancer_Custom_Client(t *testing.T) {
t.Parallel()
target, addr := createProxyTestServer(t, func(c *fiber.Ctx) error {
return c.SendStatus(fiber.StatusTeapot)
})
resp, err := target.Test(httptest.NewRequest(fiber.MethodGet, "/", nil), 2000)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusTeapot, resp.StatusCode)
app := fiber.New(fiber.Config{DisableStartupMessage: true})
app.Use(Balancer(Config{Client: &fasthttp.LBClient{
Clients: []fasthttp.BalancingClient{
&fasthttp.HostClient{
NoDefaultUserAgentHeader: true,
DisablePathNormalizing: true,
Addr: addr,
},
},
Timeout: time.Second,
}}))
req := httptest.NewRequest(fiber.MethodGet, "/", nil)
req.Host = addr
resp, err = app.Test(req)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusTeapot, resp.StatusCode)
}
// go test -run Test_Proxy_Domain_Forward_Local
func Test_Proxy_Domain_Forward_Local(t *testing.T) {
t.Parallel()
ln, err := net.Listen(fiber.NetworkTCP4, "127.0.0.1:0")
utils.AssertEqual(t, nil, err)
app := fiber.New(fiber.Config{DisableStartupMessage: true})
// target server
ln1, err := net.Listen(fiber.NetworkTCP4, "127.0.0.1:0")
utils.AssertEqual(t, nil, err)
app1 := fiber.New(fiber.Config{DisableStartupMessage: true})
app1.Get("/test", func(c *fiber.Ctx) error {
return c.SendString("test_local_client:" + c.Query("query_test"))
})
proxyAddr := ln.Addr().String()
targetAddr := ln1.Addr().String()
localDomain := strings.Replace(proxyAddr, "127.0.0.1", "localhost", 1)
app.Use(DomainForward(localDomain, "http://"+targetAddr, &fasthttp.Client{
NoDefaultUserAgentHeader: true,
DisablePathNormalizing: true,
Dial: fasthttp.Dial,
}))
go func() { utils.AssertEqual(t, nil, app.Listener(ln)) }()
go func() { utils.AssertEqual(t, nil, app1.Listener(ln1)) }()
code, body, errs := fiber.Get("http://" + localDomain + "/test?query_test=true").String()
utils.AssertEqual(t, 0, len(errs))
utils.AssertEqual(t, fiber.StatusOK, code)
utils.AssertEqual(t, "test_local_client:true", body)
}
// go test -run Test_Proxy_Balancer_Forward_Local
func Test_Proxy_Balancer_Forward_Local(t *testing.T) {
t.Parallel()
app := fiber.New()
_, addr := createProxyTestServer(t, func(c *fiber.Ctx) error {
return c.SendString("forwarded")
})
app.Use(BalancerForward([]string{addr}))
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
b, err := io.ReadAll(resp.Body)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, string(b), "forwarded")
}

View file

@ -0,0 +1,47 @@
package recover //nolint:predeclared // TODO: Rename to some non-builtin
import (
"github.com/gofiber/fiber/v2"
)
// Config defines the config for middleware.
type Config struct {
// Next defines a function to skip this middleware when returned true.
//
// Optional. Default: nil
Next func(c *fiber.Ctx) bool
// EnableStackTrace enables handling stack trace
//
// Optional. Default: false
EnableStackTrace bool
// StackTraceHandler defines a function to handle stack trace
//
// Optional. Default: defaultStackTraceHandler
StackTraceHandler func(c *fiber.Ctx, e interface{})
}
// ConfigDefault is the default config
var ConfigDefault = Config{
Next: nil,
EnableStackTrace: false,
StackTraceHandler: defaultStackTraceHandler,
}
// Helper function to set default values
func configDefault(config ...Config) Config {
// Return default config if nothing provided
if len(config) < 1 {
return ConfigDefault
}
// Override default config
cfg := config[0]
if cfg.EnableStackTrace && cfg.StackTraceHandler == nil {
cfg.StackTraceHandler = defaultStackTraceHandler
}
return cfg
}

View file

@ -0,0 +1,45 @@
package recover //nolint:predeclared // TODO: Rename to some non-builtin
import (
"fmt"
"os"
"runtime/debug"
"github.com/gofiber/fiber/v2"
)
func defaultStackTraceHandler(_ *fiber.Ctx, e interface{}) {
_, _ = os.Stderr.WriteString(fmt.Sprintf("panic: %v\n%s\n", e, debug.Stack())) //nolint:errcheck // This will never fail
}
// New creates a new middleware handler
func New(config ...Config) fiber.Handler {
// Set default config
cfg := configDefault(config...)
// Return new handler
return func(c *fiber.Ctx) (err error) { //nolint:nonamedreturns // Uses recover() to overwrite the error
// Don't execute middleware if Next returns true
if cfg.Next != nil && cfg.Next(c) {
return c.Next()
}
// Catch panics
defer func() {
if r := recover(); r != nil {
if cfg.EnableStackTrace {
cfg.StackTraceHandler(c, r)
}
var ok bool
if err, ok = r.(error); !ok {
// Set error that will call the global error handler
err = fmt.Errorf("%v", r)
}
}
}()
// Return err if exist, else move to next handler
return c.Next()
}
}

View file

@ -0,0 +1,61 @@
package recover //nolint:predeclared // TODO: Rename to some non-builtin
import (
"net/http/httptest"
"testing"
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/utils"
)
// go test -run Test_Recover
func Test_Recover(t *testing.T) {
t.Parallel()
app := fiber.New(fiber.Config{
ErrorHandler: func(c *fiber.Ctx, err error) error {
utils.AssertEqual(t, "Hi, I'm an error!", err.Error())
return c.SendStatus(fiber.StatusTeapot)
},
})
app.Use(New())
app.Get("/panic", func(c *fiber.Ctx) error {
panic("Hi, I'm an error!")
})
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/panic", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusTeapot, resp.StatusCode)
}
// go test -run Test_Recover_Next
func Test_Recover_Next(t *testing.T) {
t.Parallel()
app := fiber.New()
app.Use(New(Config{
Next: func(_ *fiber.Ctx) bool {
return true
},
}))
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusNotFound, resp.StatusCode)
}
func Test_Recover_EnableStackTrace(t *testing.T) {
t.Parallel()
app := fiber.New()
app.Use(New(Config{
EnableStackTrace: true,
}))
app.Get("/panic", func(c *fiber.Ctx) error {
panic("Hi, I'm an error!")
})
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/panic", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusInternalServerError, resp.StatusCode)
}

View file

@ -0,0 +1,53 @@
package redirect
import (
"regexp"
"github.com/gofiber/fiber/v2"
)
// Config defines the config for middleware.
type Config struct {
// Filter defines a function to skip middleware.
// Optional. Default: nil
Next func(*fiber.Ctx) bool
// Rules defines the URL path rewrite rules. The values captured in asterisk can be
// retrieved by index e.g. $1, $2 and so on.
// Required. Example:
// "/old": "/new",
// "/api/*": "/$1",
// "/js/*": "/public/javascripts/$1",
// "/users/*/orders/*": "/user/$1/order/$2",
Rules map[string]string
// The status code when redirecting
// This is ignored if Redirect is disabled
// Optional. Default: 302 Temporary Redirect
StatusCode int
rulesRegex map[*regexp.Regexp]string
}
// ConfigDefault is the default config
var ConfigDefault = Config{
StatusCode: fiber.StatusFound,
}
// Helper function to set default values
func configDefault(config ...Config) Config {
// Return default config if nothing provided
if len(config) < 1 {
return ConfigDefault
}
// Override default config
cfg := config[0]
// Set default values
if cfg.StatusCode == 0 {
cfg.StatusCode = ConfigDefault.StatusCode
}
return cfg
}

View file

@ -0,0 +1,61 @@
package redirect
import (
"regexp"
"strconv"
"strings"
"github.com/gofiber/fiber/v2"
)
// New creates a new middleware handler
func New(config ...Config) fiber.Handler {
cfg := configDefault(config...)
// Initialize
cfg.rulesRegex = map[*regexp.Regexp]string{}
for k, v := range cfg.Rules {
k = strings.ReplaceAll(k, "*", "(.*)")
k += "$"
cfg.rulesRegex[regexp.MustCompile(k)] = v
}
// Middleware function
return func(c *fiber.Ctx) error {
// Next request to skip middleware
if cfg.Next != nil && cfg.Next(c) {
return c.Next()
}
// Rewrite
for k, v := range cfg.rulesRegex {
replacer := captureTokens(k, c.Path())
if replacer != nil {
queryString := string(c.Context().QueryArgs().QueryString())
if queryString != "" {
queryString = "?" + queryString
}
return c.Redirect(replacer.Replace(v)+queryString, cfg.StatusCode)
}
}
return c.Next()
}
}
// https://github.com/labstack/echo/blob/master/middleware/rewrite.go
func captureTokens(pattern *regexp.Regexp, input string) *strings.Replacer {
if len(input) > 1 {
input = strings.TrimSuffix(input, "/")
}
groups := pattern.FindAllStringSubmatch(input, -1)
if groups == nil {
return nil
}
values := groups[0][1:]
replace := make([]string, 2*len(values))
for i, v := range values {
j := 2 * i
replace[j] = "$" + strconv.Itoa(i+1)
replace[j+1] = v
}
return strings.NewReplacer(replace...)
}

View file

@ -0,0 +1,295 @@
//nolint:bodyclose // Much easier to just ignore memory leaks in tests
package redirect
import (
"context"
"net/http"
"testing"
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/utils"
)
func Test_Redirect(t *testing.T) {
app := *fiber.New()
app.Use(New(Config{
Rules: map[string]string{
"/default": "google.com",
},
StatusCode: fiber.StatusMovedPermanently,
}))
app.Use(New(Config{
Rules: map[string]string{
"/default/*": "fiber.wiki",
},
StatusCode: fiber.StatusTemporaryRedirect,
}))
app.Use(New(Config{
Rules: map[string]string{
"/redirect/*": "$1",
},
StatusCode: fiber.StatusSeeOther,
}))
app.Use(New(Config{
Rules: map[string]string{
"/pattern/*": "golang.org",
},
StatusCode: fiber.StatusFound,
}))
app.Use(New(Config{
Rules: map[string]string{
"/": "/swagger",
},
StatusCode: fiber.StatusMovedPermanently,
}))
app.Use(New(Config{
Rules: map[string]string{
"/params": "/with_params",
},
StatusCode: fiber.StatusMovedPermanently,
}))
app.Get("/api/*", func(c *fiber.Ctx) error {
return c.SendString("API")
})
app.Get("/new", func(c *fiber.Ctx) error {
return c.SendString("Hello, World!")
})
tests := []struct {
name string
url string
redirectTo string
statusCode int
}{
{
name: "should be returns status StatusFound without a wildcard",
url: "/default",
redirectTo: "google.com",
statusCode: fiber.StatusMovedPermanently,
},
{
name: "should be returns status StatusTemporaryRedirect using wildcard",
url: "/default/xyz",
redirectTo: "fiber.wiki",
statusCode: fiber.StatusTemporaryRedirect,
},
{
name: "should be returns status StatusSeeOther without set redirectTo to use the default",
url: "/redirect/github.com/gofiber/redirect",
redirectTo: "github.com/gofiber/redirect",
statusCode: fiber.StatusSeeOther,
},
{
name: "should return the status code default",
url: "/pattern/xyz",
redirectTo: "golang.org",
statusCode: fiber.StatusFound,
},
{
name: "access URL without rule",
url: "/new",
statusCode: fiber.StatusOK,
},
{
name: "redirect to swagger route",
url: "/",
redirectTo: "/swagger",
statusCode: fiber.StatusMovedPermanently,
},
{
name: "no redirect to swagger route",
url: "/api/",
statusCode: fiber.StatusOK,
},
{
name: "no redirect to swagger route #2",
url: "/api/test",
statusCode: fiber.StatusOK,
},
{
name: "redirect with query params",
url: "/params?query=abc",
redirectTo: "/with_params?query=abc",
statusCode: fiber.StatusMovedPermanently,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req, err := http.NewRequestWithContext(context.Background(), fiber.MethodGet, tt.url, nil)
utils.AssertEqual(t, err, nil)
req.Header.Set("Location", "github.com/gofiber/redirect")
resp, err := app.Test(req)
utils.AssertEqual(t, err, nil)
utils.AssertEqual(t, tt.statusCode, resp.StatusCode)
utils.AssertEqual(t, tt.redirectTo, resp.Header.Get("Location"))
})
}
}
func Test_Next(t *testing.T) {
// Case 1 : Next function always returns true
app := *fiber.New()
app.Use(New(Config{
Next: func(*fiber.Ctx) bool {
return true
},
Rules: map[string]string{
"/default": "google.com",
},
StatusCode: fiber.StatusMovedPermanently,
}))
app.Use(func(c *fiber.Ctx) error {
return c.SendStatus(fiber.StatusOK)
})
req, err := http.NewRequestWithContext(context.Background(), fiber.MethodGet, "/default", nil)
utils.AssertEqual(t, err, nil)
resp, err := app.Test(req)
utils.AssertEqual(t, err, nil)
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
// Case 2 : Next function always returns false
app = *fiber.New()
app.Use(New(Config{
Next: func(*fiber.Ctx) bool {
return false
},
Rules: map[string]string{
"/default": "google.com",
},
StatusCode: fiber.StatusMovedPermanently,
}))
req, err = http.NewRequestWithContext(context.Background(), fiber.MethodGet, "/default", nil)
utils.AssertEqual(t, err, nil)
resp, err = app.Test(req)
utils.AssertEqual(t, err, nil)
utils.AssertEqual(t, fiber.StatusMovedPermanently, resp.StatusCode)
utils.AssertEqual(t, "google.com", resp.Header.Get("Location"))
}
func Test_NoRules(t *testing.T) {
// Case 1: No rules with default route defined
app := *fiber.New()
app.Use(New(Config{
StatusCode: fiber.StatusMovedPermanently,
}))
app.Use(func(c *fiber.Ctx) error {
return c.SendStatus(fiber.StatusOK)
})
req, err := http.NewRequestWithContext(context.Background(), fiber.MethodGet, "/default", nil)
utils.AssertEqual(t, err, nil)
resp, err := app.Test(req)
utils.AssertEqual(t, err, nil)
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
// Case 2: No rules and no default route defined
app = *fiber.New()
app.Use(New(Config{
StatusCode: fiber.StatusMovedPermanently,
}))
req, err = http.NewRequestWithContext(context.Background(), fiber.MethodGet, "/default", nil)
utils.AssertEqual(t, err, nil)
resp, err = app.Test(req)
utils.AssertEqual(t, err, nil)
utils.AssertEqual(t, fiber.StatusNotFound, resp.StatusCode)
}
func Test_DefaultConfig(t *testing.T) {
// Case 1: Default config and no default route
app := *fiber.New()
app.Use(New())
req, err := http.NewRequestWithContext(context.Background(), fiber.MethodGet, "/default", nil)
utils.AssertEqual(t, err, nil)
resp, err := app.Test(req)
utils.AssertEqual(t, err, nil)
utils.AssertEqual(t, fiber.StatusNotFound, resp.StatusCode)
// Case 2: Default config and default route
app = *fiber.New()
app.Use(New())
app.Use(func(c *fiber.Ctx) error {
return c.SendStatus(fiber.StatusOK)
})
req, err = http.NewRequestWithContext(context.Background(), fiber.MethodGet, "/default", nil)
utils.AssertEqual(t, err, nil)
resp, err = app.Test(req)
utils.AssertEqual(t, err, nil)
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
}
func Test_RegexRules(t *testing.T) {
// Case 1: Rules regex is empty
app := *fiber.New()
app.Use(New(Config{
Rules: map[string]string{},
StatusCode: fiber.StatusMovedPermanently,
}))
app.Use(func(c *fiber.Ctx) error {
return c.SendStatus(fiber.StatusOK)
})
req, err := http.NewRequestWithContext(context.Background(), fiber.MethodGet, "/default", nil)
utils.AssertEqual(t, err, nil)
resp, err := app.Test(req)
utils.AssertEqual(t, err, nil)
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
// Case 2: Rules regex map contains valid regex and well-formed replacement URLs
app = *fiber.New()
app.Use(New(Config{
Rules: map[string]string{
"/default": "google.com",
},
StatusCode: fiber.StatusMovedPermanently,
}))
app.Use(func(c *fiber.Ctx) error {
return c.SendStatus(fiber.StatusOK)
})
req, err = http.NewRequestWithContext(context.Background(), fiber.MethodGet, "/default", nil)
utils.AssertEqual(t, err, nil)
resp, err = app.Test(req)
utils.AssertEqual(t, err, nil)
utils.AssertEqual(t, fiber.StatusMovedPermanently, resp.StatusCode)
utils.AssertEqual(t, "google.com", resp.Header.Get("Location"))
// Case 3: Test invalid regex throws panic
defer func() {
if r := recover(); r != nil {
t.Log("Recovered from invalid regex: ", r)
}
}()
app = *fiber.New()
app.Use(New(Config{
Rules: map[string]string{
"(": "google.com",
},
StatusCode: fiber.StatusMovedPermanently,
}))
t.Error("Expected panic, got nil")
}

View file

@ -0,0 +1,66 @@
package requestid
import (
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/utils"
)
// Config defines the config for middleware.
type Config struct {
// Next defines a function to skip this middleware when returned true.
//
// Optional. Default: nil
Next func(c *fiber.Ctx) bool
// Header is the header key where to get/set the unique request ID
//
// Optional. Default: "X-Request-ID"
Header string
// Generator defines a function to generate the unique identifier.
//
// Optional. Default: utils.UUID
Generator func() string
// ContextKey defines the key used when storing the request ID in
// the locals for a specific request.
// Should be a private type instead of string, but too many apps probably
// rely on this exact value.
//
// Optional. Default: "requestid"
ContextKey interface{}
}
// ConfigDefault is the default config
// It uses a fast UUID generator which will expose the number of
// requests made to the server. To conceal this value for better
// privacy, use the "utils.UUIDv4" generator.
var ConfigDefault = Config{
Next: nil,
Header: fiber.HeaderXRequestID,
Generator: utils.UUID,
ContextKey: "requestid",
}
// Helper function to set default values
func configDefault(config ...Config) Config {
// Return default config if nothing provided
if len(config) < 1 {
return ConfigDefault
}
// Override default config
cfg := config[0]
// Set default values
if cfg.Header == "" {
cfg.Header = ConfigDefault.Header
}
if cfg.Generator == nil {
cfg.Generator = ConfigDefault.Generator
}
if cfg.ContextKey == nil {
cfg.ContextKey = ConfigDefault.ContextKey
}
return cfg
}

View file

@ -0,0 +1,33 @@
package requestid
import (
"github.com/gofiber/fiber/v2"
)
// New creates a new middleware handler
func New(config ...Config) fiber.Handler {
// Set default config
cfg := configDefault(config...)
// Return new handler
return func(c *fiber.Ctx) error {
// Don't execute middleware if Next returns true
if cfg.Next != nil && cfg.Next(c) {
return c.Next()
}
// Get id from request, else we generate one
rid := c.Get(cfg.Header)
if rid == "" {
rid = cfg.Generator()
}
// Set new id to response header
c.Set(cfg.Header, rid)
// Add the request ID to locals
c.Locals(cfg.ContextKey, rid)
// Continue stack
return c.Next()
}
}

View file

@ -0,0 +1,103 @@
package requestid
import (
"net/http/httptest"
"testing"
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/utils"
)
// go test -run Test_RequestID
func Test_RequestID(t *testing.T) {
t.Parallel()
app := fiber.New()
app.Use(New())
app.Get("/", func(c *fiber.Ctx) error {
return c.SendString("Hello, World 👋!")
})
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
reqid := resp.Header.Get(fiber.HeaderXRequestID)
utils.AssertEqual(t, 36, len(reqid))
req := httptest.NewRequest(fiber.MethodGet, "/", nil)
req.Header.Add(fiber.HeaderXRequestID, reqid)
resp, err = app.Test(req)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
utils.AssertEqual(t, reqid, resp.Header.Get(fiber.HeaderXRequestID))
}
// go test -run Test_RequestID_Next
func Test_RequestID_Next(t *testing.T) {
t.Parallel()
app := fiber.New()
app.Use(New(Config{
Next: func(_ *fiber.Ctx) bool {
return true
},
}))
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, resp.Header.Get(fiber.HeaderXRequestID), "")
utils.AssertEqual(t, fiber.StatusNotFound, resp.StatusCode)
}
// go test -run Test_RequestID_Locals
func Test_RequestID_Locals(t *testing.T) {
t.Parallel()
reqID := "ThisIsARequestId"
type ContextKey int
const requestContextKey ContextKey = iota
app := fiber.New()
app.Use(New(Config{
Generator: func() string {
return reqID
},
ContextKey: requestContextKey,
}))
var ctxVal string
app.Use(func(c *fiber.Ctx) error {
ctxVal = c.Locals(requestContextKey).(string) //nolint:forcetypeassert,errcheck // We always store a string in here
return c.Next()
})
_, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, reqID, ctxVal)
}
// go test -run Test_RequestID_DefaultKey
func Test_RequestID_DefaultKey(t *testing.T) {
t.Parallel()
reqID := "ThisIsARequestId"
app := fiber.New()
app.Use(New(Config{
Generator: func() string {
return reqID
},
}))
var ctxVal string
app.Use(func(c *fiber.Ctx) error {
ctxVal = c.Locals("requestid").(string) //nolint:forcetypeassert,errcheck // We always store a string in here
return c.Next()
})
_, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, reqID, ctxVal)
}

View file

@ -0,0 +1,38 @@
package rewrite
import (
"regexp"
"github.com/gofiber/fiber/v2"
)
// Config defines the config for middleware.
type Config struct {
// Next defines a function to skip middleware.
// Optional. Default: nil
Next func(*fiber.Ctx) bool
// Rules defines the URL path rewrite rules. The values captured in asterisk can be
// retrieved by index e.g. $1, $2 and so on.
// Required. Example:
// "/old": "/new",
// "/api/*": "/$1",
// "/js/*": "/public/javascripts/$1",
// "/users/*/orders/*": "/user/$1/order/$2",
Rules map[string]string
rulesRegex map[*regexp.Regexp]string
}
// Helper function to set default values
func configDefault(config ...Config) Config {
// Return default config if nothing provided
if len(config) < 1 {
return Config{}
}
// Override default config
cfg := config[0]
return cfg
}

View file

@ -0,0 +1,54 @@
package rewrite
import (
"regexp"
"strconv"
"strings"
"github.com/gofiber/fiber/v2"
)
// New creates a new middleware handler
func New(config ...Config) fiber.Handler {
cfg := configDefault(config...)
// Initialize
cfg.rulesRegex = map[*regexp.Regexp]string{}
for k, v := range cfg.Rules {
k = strings.ReplaceAll(k, "*", "(.*)")
k += "$"
cfg.rulesRegex[regexp.MustCompile(k)] = v
}
// Middleware function
return func(c *fiber.Ctx) error {
// Next request to skip middleware
if cfg.Next != nil && cfg.Next(c) {
return c.Next()
}
// Rewrite
for k, v := range cfg.rulesRegex {
replacer := captureTokens(k, c.Path())
if replacer != nil {
c.Path(replacer.Replace(v))
break
}
}
return c.Next()
}
}
// https://github.com/labstack/echo/blob/master/middleware/rewrite.go
func captureTokens(pattern *regexp.Regexp, input string) *strings.Replacer {
groups := pattern.FindAllStringSubmatch(input, -1)
if groups == nil {
return nil
}
values := groups[0][1:]
replace := make([]string, 2*len(values))
for i, v := range values {
j := 2 * i
replace[j] = "$" + strconv.Itoa(i+1)
replace[j+1] = v
}
return strings.NewReplacer(replace...)
}

Some files were not shown because too many files have changed in this diff Show more