Adding upstream version 2.52.6.
Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
parent
a960158181
commit
6d002e9543
441 changed files with 95392 additions and 0 deletions
171
middleware/adaptor/adaptor.go
Normal file
171
middleware/adaptor/adaptor.go
Normal file
|
@ -0,0 +1,171 @@
|
|||
package adaptor
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"reflect"
|
||||
"unsafe"
|
||||
|
||||
"github.com/valyala/fasthttp"
|
||||
"github.com/valyala/fasthttp/fasthttpadaptor"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/gofiber/fiber/v2/utils"
|
||||
)
|
||||
|
||||
// HTTPHandlerFunc wraps net/http handler func to fiber handler
|
||||
func HTTPHandlerFunc(h http.HandlerFunc) fiber.Handler {
|
||||
return HTTPHandler(h)
|
||||
}
|
||||
|
||||
// HTTPHandler wraps net/http handler to fiber handler
|
||||
func HTTPHandler(h http.Handler) fiber.Handler {
|
||||
return func(c *fiber.Ctx) error {
|
||||
handler := fasthttpadaptor.NewFastHTTPHandler(h)
|
||||
handler(c.Context())
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// ConvertRequest converts a fiber.Ctx to a http.Request.
|
||||
// forServer should be set to true when the http.Request is going to be passed to a http.Handler.
|
||||
func ConvertRequest(c *fiber.Ctx, forServer bool) (*http.Request, error) {
|
||||
var req http.Request
|
||||
if err := fasthttpadaptor.ConvertRequest(c.Context(), &req, forServer); err != nil {
|
||||
return nil, err //nolint:wrapcheck // This must not be wrapped
|
||||
}
|
||||
return &req, nil
|
||||
}
|
||||
|
||||
// CopyContextToFiberContext copies the values of context.Context to a fasthttp.RequestCtx
|
||||
func CopyContextToFiberContext(context interface{}, requestContext *fasthttp.RequestCtx) {
|
||||
contextValues := reflect.ValueOf(context).Elem()
|
||||
contextKeys := reflect.TypeOf(context).Elem()
|
||||
if contextKeys.Kind() == reflect.Struct {
|
||||
var lastKey interface{}
|
||||
for i := 0; i < contextValues.NumField(); i++ {
|
||||
reflectValue := contextValues.Field(i)
|
||||
/* #nosec */
|
||||
reflectValue = reflect.NewAt(reflectValue.Type(), unsafe.Pointer(reflectValue.UnsafeAddr())).Elem()
|
||||
|
||||
reflectField := contextKeys.Field(i)
|
||||
|
||||
if reflectField.Name == "noCopy" {
|
||||
break
|
||||
} else if reflectField.Name == "Context" {
|
||||
CopyContextToFiberContext(reflectValue.Interface(), requestContext)
|
||||
} else if reflectField.Name == "key" {
|
||||
lastKey = reflectValue.Interface()
|
||||
} else if lastKey != nil && reflectField.Name == "val" {
|
||||
requestContext.SetUserValue(lastKey, reflectValue.Interface())
|
||||
} else {
|
||||
lastKey = nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// HTTPMiddleware wraps net/http middleware to fiber middleware
|
||||
func HTTPMiddleware(mw func(http.Handler) http.Handler) fiber.Handler {
|
||||
return func(c *fiber.Ctx) error {
|
||||
var next bool
|
||||
nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
next = true
|
||||
// Convert again in case request may modify by middleware
|
||||
c.Request().Header.SetMethod(r.Method)
|
||||
c.Request().SetRequestURI(r.RequestURI)
|
||||
c.Request().SetHost(r.Host)
|
||||
c.Request().Header.SetHost(r.Host)
|
||||
for key, val := range r.Header {
|
||||
for _, v := range val {
|
||||
c.Request().Header.Set(key, v)
|
||||
}
|
||||
}
|
||||
CopyContextToFiberContext(r.Context(), c.Context())
|
||||
})
|
||||
|
||||
if err := HTTPHandler(mw(nextHandler))(c); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if next {
|
||||
return c.Next()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// FiberHandler wraps fiber handler to net/http handler
|
||||
func FiberHandler(h fiber.Handler) http.Handler {
|
||||
return FiberHandlerFunc(h)
|
||||
}
|
||||
|
||||
// FiberHandlerFunc wraps fiber handler to net/http handler func
|
||||
func FiberHandlerFunc(h fiber.Handler) http.HandlerFunc {
|
||||
return handlerFunc(fiber.New(), h)
|
||||
}
|
||||
|
||||
// FiberApp wraps fiber app to net/http handler func
|
||||
func FiberApp(app *fiber.App) http.HandlerFunc {
|
||||
return handlerFunc(app)
|
||||
}
|
||||
|
||||
func handlerFunc(app *fiber.App, h ...fiber.Handler) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
// New fasthttp request
|
||||
req := fasthttp.AcquireRequest()
|
||||
defer fasthttp.ReleaseRequest(req)
|
||||
// Convert net/http -> fasthttp request
|
||||
if r.Body != nil {
|
||||
n, err := io.Copy(req.BodyWriter(), r.Body)
|
||||
req.Header.SetContentLength(int(n))
|
||||
|
||||
if err != nil {
|
||||
http.Error(w, utils.StatusMessage(fiber.StatusInternalServerError), fiber.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
}
|
||||
req.Header.SetMethod(r.Method)
|
||||
req.SetRequestURI(r.RequestURI)
|
||||
req.SetHost(r.Host)
|
||||
req.Header.SetHost(r.Host)
|
||||
for key, val := range r.Header {
|
||||
for _, v := range val {
|
||||
req.Header.Set(key, v)
|
||||
}
|
||||
}
|
||||
if _, _, err := net.SplitHostPort(r.RemoteAddr); err != nil && err.(*net.AddrError).Err == "missing port in address" { //nolint:errorlint, forcetypeassert // overlinting
|
||||
r.RemoteAddr = net.JoinHostPort(r.RemoteAddr, "80")
|
||||
}
|
||||
remoteAddr, err := net.ResolveTCPAddr("tcp", r.RemoteAddr)
|
||||
if err != nil {
|
||||
http.Error(w, utils.StatusMessage(fiber.StatusInternalServerError), fiber.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// New fasthttp Ctx
|
||||
var fctx fasthttp.RequestCtx
|
||||
fctx.Init(req, remoteAddr, nil)
|
||||
if len(h) > 0 {
|
||||
// New fiber Ctx
|
||||
ctx := app.AcquireCtx(&fctx)
|
||||
defer app.ReleaseCtx(ctx)
|
||||
// Execute fiber Ctx
|
||||
err := h[0](ctx)
|
||||
if err != nil {
|
||||
_ = app.Config().ErrorHandler(ctx, err) //nolint:errcheck // not needed
|
||||
}
|
||||
} else {
|
||||
// Execute fasthttp Ctx though app.Handler
|
||||
app.Handler()(&fctx)
|
||||
}
|
||||
|
||||
// Convert fasthttp Ctx > net/http
|
||||
fctx.Response.Header.VisitAll(func(k, v []byte) {
|
||||
w.Header().Add(string(k), string(v))
|
||||
})
|
||||
w.WriteHeader(fctx.Response.StatusCode())
|
||||
_, _ = w.Write(fctx.Response.Body()) //nolint:errcheck // not needed
|
||||
}
|
||||
}
|
492
middleware/adaptor/adaptor_test.go
Normal file
492
middleware/adaptor/adaptor_test.go
Normal file
|
@ -0,0 +1,492 @@
|
|||
//nolint:bodyclose, contextcheck, revive // Much easier to just ignore memory leaks in tests
|
||||
package adaptor
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"testing"
|
||||
|
||||
"github.com/valyala/fasthttp"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/gofiber/fiber/v2/utils"
|
||||
)
|
||||
|
||||
func Test_HTTPHandler(t *testing.T) {
|
||||
expectedMethod := fiber.MethodPost
|
||||
expectedProto := "HTTP/1.1"
|
||||
expectedProtoMajor := 1
|
||||
expectedProtoMinor := 1
|
||||
expectedRequestURI := "/foo/bar?baz=123"
|
||||
expectedBody := "body 123 foo bar baz"
|
||||
expectedContentLength := len(expectedBody)
|
||||
expectedHost := "foobar.com"
|
||||
expectedRemoteAddr := "1.2.3.4:6789"
|
||||
expectedHeader := map[string]string{
|
||||
"Foo-Bar": "baz",
|
||||
"Abc": "defg",
|
||||
"XXX-Remote-Addr": "123.43.4543.345",
|
||||
}
|
||||
expectedURL, err := url.ParseRequestURI(expectedRequestURI)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
|
||||
type contextKeyType string
|
||||
expectedContextKey := contextKeyType("contextKey")
|
||||
expectedContextValue := "contextValue"
|
||||
|
||||
callsCount := 0
|
||||
nethttpH := func(w http.ResponseWriter, r *http.Request) {
|
||||
callsCount++
|
||||
utils.AssertEqual(t, expectedMethod, r.Method, "Method")
|
||||
utils.AssertEqual(t, expectedProto, r.Proto, "Proto")
|
||||
utils.AssertEqual(t, expectedProtoMajor, r.ProtoMajor, "ProtoMajor")
|
||||
utils.AssertEqual(t, expectedProtoMinor, r.ProtoMinor, "ProtoMinor")
|
||||
utils.AssertEqual(t, expectedRequestURI, r.RequestURI, "RequestURI")
|
||||
utils.AssertEqual(t, expectedContentLength, int(r.ContentLength), "ContentLength")
|
||||
utils.AssertEqual(t, 0, len(r.TransferEncoding), "TransferEncoding")
|
||||
utils.AssertEqual(t, expectedHost, r.Host, "Host")
|
||||
utils.AssertEqual(t, expectedRemoteAddr, r.RemoteAddr, "RemoteAddr")
|
||||
|
||||
body, err := io.ReadAll(r.Body)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, expectedBody, string(body), "Body")
|
||||
utils.AssertEqual(t, expectedURL, r.URL, "URL")
|
||||
utils.AssertEqual(t, expectedContextValue, r.Context().Value(expectedContextKey), "Context")
|
||||
|
||||
for k, expectedV := range expectedHeader {
|
||||
v := r.Header.Get(k)
|
||||
utils.AssertEqual(t, expectedV, v, "Header")
|
||||
}
|
||||
|
||||
w.Header().Set("Header1", "value1")
|
||||
w.Header().Set("Header2", "value2")
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
fmt.Fprintf(w, "request body is %q", body)
|
||||
}
|
||||
fiberH := HTTPHandlerFunc(http.HandlerFunc(nethttpH))
|
||||
fiberH = setFiberContextValueMiddleware(fiberH, expectedContextKey, expectedContextValue)
|
||||
|
||||
var fctx fasthttp.RequestCtx
|
||||
var req fasthttp.Request
|
||||
|
||||
req.Header.SetMethod(expectedMethod)
|
||||
req.SetRequestURI(expectedRequestURI)
|
||||
req.Header.SetHost(expectedHost)
|
||||
req.BodyWriter().Write([]byte(expectedBody)) //nolint:errcheck, gosec // not needed
|
||||
for k, v := range expectedHeader {
|
||||
req.Header.Set(k, v)
|
||||
}
|
||||
|
||||
remoteAddr, err := net.ResolveTCPAddr("tcp", expectedRemoteAddr)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
|
||||
fctx.Init(&req, remoteAddr, nil)
|
||||
app := fiber.New()
|
||||
ctx := app.AcquireCtx(&fctx)
|
||||
defer app.ReleaseCtx(ctx)
|
||||
|
||||
err = fiberH(ctx)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, 1, callsCount, "callsCount")
|
||||
|
||||
resp := &fctx.Response
|
||||
utils.AssertEqual(t, http.StatusBadRequest, resp.StatusCode(), "StatusCode")
|
||||
utils.AssertEqual(t, "value1", string(resp.Header.Peek("Header1")), "Header1")
|
||||
utils.AssertEqual(t, "value2", string(resp.Header.Peek("Header2")), "Header2")
|
||||
|
||||
expectedResponseBody := fmt.Sprintf("request body is %q", expectedBody)
|
||||
utils.AssertEqual(t, expectedResponseBody, string(resp.Body()), "Body")
|
||||
}
|
||||
|
||||
type contextKey string
|
||||
|
||||
func (c contextKey) String() string {
|
||||
return "test-" + string(c)
|
||||
}
|
||||
|
||||
var (
|
||||
TestContextKey = contextKey("TestContextKey")
|
||||
TestContextSecondKey = contextKey("TestContextSecondKey")
|
||||
)
|
||||
|
||||
func Test_HTTPMiddleware(t *testing.T) {
|
||||
const expectedHost = "foobar.com"
|
||||
tests := []struct {
|
||||
name string
|
||||
url string
|
||||
method string
|
||||
statusCode int
|
||||
}{
|
||||
{
|
||||
name: "Should return 200",
|
||||
url: "/",
|
||||
method: "POST",
|
||||
statusCode: 200,
|
||||
},
|
||||
{
|
||||
name: "Should return 405",
|
||||
url: "/",
|
||||
method: "GET",
|
||||
statusCode: 405,
|
||||
},
|
||||
{
|
||||
name: "Should return 400",
|
||||
url: "/unknown",
|
||||
method: "POST",
|
||||
statusCode: 404,
|
||||
},
|
||||
}
|
||||
|
||||
nethttpMW := func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
w.WriteHeader(http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
r = r.WithContext(context.WithValue(r.Context(), TestContextKey, "okay"))
|
||||
r = r.WithContext(context.WithValue(r.Context(), TestContextSecondKey, "not_okay"))
|
||||
r = r.WithContext(context.WithValue(r.Context(), TestContextSecondKey, "okay"))
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
app := fiber.New()
|
||||
app.Use(HTTPMiddleware(nethttpMW))
|
||||
app.Post("/", func(c *fiber.Ctx) error {
|
||||
value := c.Context().Value(TestContextKey)
|
||||
val, ok := value.(string)
|
||||
if !ok {
|
||||
t.Error("unexpected error on type-assertion")
|
||||
}
|
||||
if value != nil {
|
||||
c.Set("context_okay", val)
|
||||
}
|
||||
value = c.Context().Value(TestContextSecondKey)
|
||||
if value != nil {
|
||||
val, ok := value.(string)
|
||||
if !ok {
|
||||
t.Error("unexpected error on type-assertion")
|
||||
}
|
||||
c.Set("context_second_okay", val)
|
||||
}
|
||||
return c.SendStatus(fiber.StatusOK)
|
||||
})
|
||||
|
||||
for _, tt := range tests {
|
||||
req, err := http.NewRequestWithContext(context.Background(), tt.method, tt.url, nil)
|
||||
req.Host = expectedHost
|
||||
utils.AssertEqual(t, nil, err)
|
||||
|
||||
resp, err := app.Test(req)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, tt.statusCode, resp.StatusCode, "StatusCode")
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(context.Background(), fiber.MethodPost, "/", nil)
|
||||
req.Host = expectedHost
|
||||
utils.AssertEqual(t, nil, err)
|
||||
|
||||
resp, err := app.Test(req)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, resp.Header.Get("context_okay"), "okay")
|
||||
utils.AssertEqual(t, resp.Header.Get("context_second_okay"), "okay")
|
||||
}
|
||||
|
||||
func Test_FiberHandler(t *testing.T) {
|
||||
testFiberToHandlerFunc(t, false)
|
||||
}
|
||||
|
||||
func Test_FiberApp(t *testing.T) {
|
||||
testFiberToHandlerFunc(t, false, fiber.New())
|
||||
}
|
||||
|
||||
func Test_FiberHandlerDefaultPort(t *testing.T) {
|
||||
testFiberToHandlerFunc(t, true)
|
||||
}
|
||||
|
||||
func Test_FiberAppDefaultPort(t *testing.T) {
|
||||
testFiberToHandlerFunc(t, true, fiber.New())
|
||||
}
|
||||
|
||||
func testFiberToHandlerFunc(t *testing.T, checkDefaultPort bool, app ...*fiber.App) {
|
||||
t.Helper()
|
||||
|
||||
expectedMethod := fiber.MethodPost
|
||||
expectedRequestURI := "/foo/bar?baz=123"
|
||||
expectedBody := "body 123 foo bar baz"
|
||||
expectedContentLength := len(expectedBody)
|
||||
expectedHost := "foobar.com"
|
||||
expectedRemoteAddr := "1.2.3.4:6789"
|
||||
if checkDefaultPort {
|
||||
expectedRemoteAddr = "1.2.3.4:80"
|
||||
}
|
||||
expectedHeader := map[string]string{
|
||||
"Foo-Bar": "baz",
|
||||
"Abc": "defg",
|
||||
"XXX-Remote-Addr": "123.43.4543.345",
|
||||
}
|
||||
expectedURL, err := url.ParseRequestURI(expectedRequestURI)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
|
||||
callsCount := 0
|
||||
fiberH := func(c *fiber.Ctx) error {
|
||||
callsCount++
|
||||
utils.AssertEqual(t, expectedMethod, c.Method(), "Method")
|
||||
utils.AssertEqual(t, expectedRequestURI, string(c.Context().RequestURI()), "RequestURI")
|
||||
utils.AssertEqual(t, expectedContentLength, c.Context().Request.Header.ContentLength(), "ContentLength")
|
||||
utils.AssertEqual(t, expectedHost, c.Hostname(), "Host")
|
||||
utils.AssertEqual(t, expectedHost, string(c.Request().Header.Host()), "Host")
|
||||
utils.AssertEqual(t, "http://"+expectedHost, c.BaseURL(), "BaseURL")
|
||||
utils.AssertEqual(t, expectedRemoteAddr, c.Context().RemoteAddr().String(), "RemoteAddr")
|
||||
|
||||
body := string(c.Body())
|
||||
utils.AssertEqual(t, expectedBody, body, "Body")
|
||||
utils.AssertEqual(t, expectedURL.String(), c.OriginalURL(), "URL")
|
||||
|
||||
for k, expectedV := range expectedHeader {
|
||||
v := c.Get(k)
|
||||
utils.AssertEqual(t, expectedV, v, "Header")
|
||||
}
|
||||
|
||||
c.Set("Header1", "value1")
|
||||
c.Set("Header2", "value2")
|
||||
c.Status(fiber.StatusBadRequest)
|
||||
_, err := c.Write([]byte(fmt.Sprintf("request body is %q", body)))
|
||||
return err
|
||||
}
|
||||
|
||||
var handlerFunc http.HandlerFunc
|
||||
if len(app) > 0 {
|
||||
app[0].Post("/foo/bar", fiberH)
|
||||
handlerFunc = FiberApp(app[0])
|
||||
} else {
|
||||
handlerFunc = FiberHandlerFunc(fiberH)
|
||||
}
|
||||
|
||||
var r http.Request
|
||||
|
||||
r.Method = expectedMethod
|
||||
r.Body = &netHTTPBody{[]byte(expectedBody)}
|
||||
r.RequestURI = expectedRequestURI
|
||||
r.ContentLength = int64(expectedContentLength)
|
||||
r.Host = expectedHost
|
||||
r.RemoteAddr = expectedRemoteAddr
|
||||
if checkDefaultPort {
|
||||
r.RemoteAddr = "1.2.3.4"
|
||||
}
|
||||
|
||||
hdr := make(http.Header)
|
||||
for k, v := range expectedHeader {
|
||||
hdr.Set(k, v)
|
||||
}
|
||||
r.Header = hdr
|
||||
|
||||
var w netHTTPResponseWriter
|
||||
handlerFunc.ServeHTTP(&w, &r)
|
||||
|
||||
utils.AssertEqual(t, http.StatusBadRequest, w.StatusCode(), "StatusCode")
|
||||
utils.AssertEqual(t, "value1", w.Header().Get("Header1"), "Header1")
|
||||
utils.AssertEqual(t, "value2", w.Header().Get("Header2"), "Header2")
|
||||
|
||||
expectedResponseBody := fmt.Sprintf("request body is %q", expectedBody)
|
||||
utils.AssertEqual(t, expectedResponseBody, string(w.body), "Body")
|
||||
}
|
||||
|
||||
func setFiberContextValueMiddleware(next fiber.Handler, key, value interface{}) fiber.Handler {
|
||||
return func(c *fiber.Ctx) error {
|
||||
c.Locals(key, value)
|
||||
return next(c)
|
||||
}
|
||||
}
|
||||
|
||||
func Test_FiberHandler_RequestNilBody(t *testing.T) {
|
||||
expectedMethod := fiber.MethodGet
|
||||
expectedRequestURI := "/foo/bar"
|
||||
expectedContentLength := 0
|
||||
|
||||
callsCount := 0
|
||||
fiberH := func(c *fiber.Ctx) error {
|
||||
callsCount++
|
||||
utils.AssertEqual(t, expectedMethod, c.Method(), "Method")
|
||||
utils.AssertEqual(t, expectedRequestURI, string(c.Context().RequestURI()), "RequestURI")
|
||||
utils.AssertEqual(t, expectedContentLength, c.Context().Request.Header.ContentLength(), "ContentLength")
|
||||
|
||||
_, err := c.Write([]byte("request body is nil"))
|
||||
return err
|
||||
}
|
||||
nethttpH := FiberHandler(fiberH)
|
||||
|
||||
var r http.Request
|
||||
|
||||
r.Method = expectedMethod
|
||||
r.RequestURI = expectedRequestURI
|
||||
|
||||
var w netHTTPResponseWriter
|
||||
nethttpH.ServeHTTP(&w, &r)
|
||||
|
||||
expectedResponseBody := "request body is nil"
|
||||
utils.AssertEqual(t, expectedResponseBody, string(w.body), "Body")
|
||||
}
|
||||
|
||||
type netHTTPBody struct {
|
||||
b []byte
|
||||
}
|
||||
|
||||
func (r *netHTTPBody) Read(p []byte) (int, error) {
|
||||
if len(r.b) == 0 {
|
||||
return 0, io.EOF
|
||||
}
|
||||
n := copy(p, r.b)
|
||||
r.b = r.b[n:]
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func (r *netHTTPBody) Close() error {
|
||||
r.b = r.b[:0]
|
||||
return nil
|
||||
}
|
||||
|
||||
type netHTTPResponseWriter struct {
|
||||
statusCode int
|
||||
h http.Header
|
||||
body []byte
|
||||
}
|
||||
|
||||
func (w *netHTTPResponseWriter) StatusCode() int {
|
||||
if w.statusCode == 0 {
|
||||
return http.StatusOK
|
||||
}
|
||||
return w.statusCode
|
||||
}
|
||||
|
||||
func (w *netHTTPResponseWriter) Header() http.Header {
|
||||
if w.h == nil {
|
||||
w.h = make(http.Header)
|
||||
}
|
||||
return w.h
|
||||
}
|
||||
|
||||
func (w *netHTTPResponseWriter) WriteHeader(statusCode int) {
|
||||
w.statusCode = statusCode
|
||||
}
|
||||
|
||||
func (w *netHTTPResponseWriter) Write(p []byte) (int, error) {
|
||||
w.body = append(w.body, p...)
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
func Test_ConvertRequest(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
app := fiber.New()
|
||||
|
||||
app.Get("/test", func(c *fiber.Ctx) error {
|
||||
httpReq, err := ConvertRequest(c, false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return c.SendString("Request URL: " + httpReq.URL.String())
|
||||
})
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/test?hello=world&another=test", http.NoBody))
|
||||
utils.AssertEqual(t, nil, err, "app.Test(req)")
|
||||
utils.AssertEqual(t, http.StatusOK, resp.StatusCode, "Status code")
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, "Request URL: /test?hello=world&another=test", string(body))
|
||||
}
|
||||
|
||||
// Benchmark for FiberHandlerFunc
|
||||
func Benchmark_FiberHandlerFunc_1MB(b *testing.B) {
|
||||
fiberH := func(c *fiber.Ctx) error {
|
||||
return c.SendStatus(fiber.StatusOK)
|
||||
}
|
||||
handlerFunc := FiberHandlerFunc(fiberH)
|
||||
|
||||
// Create body content
|
||||
bodyContent := make([]byte, 1*1024*1024)
|
||||
bodyBuffer := bytes.NewBuffer(bodyContent)
|
||||
|
||||
r := http.Request{
|
||||
Method: http.MethodPost,
|
||||
Body: http.NoBody,
|
||||
}
|
||||
|
||||
// Replace the empty Body with our buffer
|
||||
r.Body = io.NopCloser(bodyBuffer)
|
||||
defer r.Body.Close() //nolint:errcheck // not needed
|
||||
|
||||
// Create recorder
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
handlerFunc.ServeHTTP(w, &r)
|
||||
}
|
||||
}
|
||||
|
||||
func Benchmark_FiberHandlerFunc_10MB(b *testing.B) {
|
||||
fiberH := func(c *fiber.Ctx) error {
|
||||
return c.SendStatus(fiber.StatusOK)
|
||||
}
|
||||
handlerFunc := FiberHandlerFunc(fiberH)
|
||||
|
||||
// Create body content
|
||||
bodyContent := make([]byte, 10*1024*1024)
|
||||
bodyBuffer := bytes.NewBuffer(bodyContent)
|
||||
|
||||
r := http.Request{
|
||||
Method: http.MethodPost,
|
||||
Body: http.NoBody,
|
||||
}
|
||||
|
||||
// Replace the empty Body with our buffer
|
||||
r.Body = io.NopCloser(bodyBuffer)
|
||||
defer r.Body.Close() //nolint:errcheck // not needed
|
||||
|
||||
// Create recorder
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
handlerFunc.ServeHTTP(w, &r)
|
||||
}
|
||||
}
|
||||
|
||||
func Benchmark_FiberHandlerFunc_50MB(b *testing.B) {
|
||||
fiberH := func(c *fiber.Ctx) error {
|
||||
return c.SendStatus(fiber.StatusOK)
|
||||
}
|
||||
handlerFunc := FiberHandlerFunc(fiberH)
|
||||
|
||||
// Create body content
|
||||
bodyContent := make([]byte, 50*1024*1024)
|
||||
bodyBuffer := bytes.NewBuffer(bodyContent)
|
||||
|
||||
r := http.Request{
|
||||
Method: http.MethodPost,
|
||||
Body: http.NoBody,
|
||||
}
|
||||
|
||||
// Replace the empty Body with our buffer
|
||||
r.Body = io.NopCloser(bodyBuffer)
|
||||
defer r.Body.Close() //nolint:errcheck // not needed
|
||||
|
||||
// Create recorder
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
handlerFunc.ServeHTTP(w, &r)
|
||||
}
|
||||
}
|
60
middleware/basicauth/basicauth.go
Normal file
60
middleware/basicauth/basicauth.go
Normal file
|
@ -0,0 +1,60 @@
|
|||
package basicauth
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"strings"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/gofiber/fiber/v2/utils"
|
||||
)
|
||||
|
||||
// New creates a new middleware handler
|
||||
func New(config Config) fiber.Handler {
|
||||
// Set default config
|
||||
cfg := configDefault(config)
|
||||
|
||||
// Return new handler
|
||||
return func(c *fiber.Ctx) error {
|
||||
// Don't execute middleware if Next returns true
|
||||
if cfg.Next != nil && cfg.Next(c) {
|
||||
return c.Next()
|
||||
}
|
||||
|
||||
// Get authorization header
|
||||
auth := c.Get(fiber.HeaderAuthorization)
|
||||
|
||||
// Check if the header contains content besides "basic".
|
||||
if len(auth) <= 6 || !utils.EqualFold(auth[:6], "basic ") {
|
||||
return cfg.Unauthorized(c)
|
||||
}
|
||||
|
||||
// Decode the header contents
|
||||
raw, err := base64.StdEncoding.DecodeString(auth[6:])
|
||||
if err != nil {
|
||||
return cfg.Unauthorized(c)
|
||||
}
|
||||
|
||||
// Get the credentials
|
||||
creds := utils.UnsafeString(raw)
|
||||
|
||||
// Check if the credentials are in the correct form
|
||||
// which is "username:password".
|
||||
index := strings.Index(creds, ":")
|
||||
if index == -1 {
|
||||
return cfg.Unauthorized(c)
|
||||
}
|
||||
|
||||
// Get the username and password
|
||||
username := creds[:index]
|
||||
password := creds[index+1:]
|
||||
|
||||
if cfg.Authorizer(username, password) {
|
||||
c.Locals(cfg.ContextUsername, username)
|
||||
c.Locals(cfg.ContextPassword, password)
|
||||
return c.Next()
|
||||
}
|
||||
|
||||
// Authentication failed
|
||||
return cfg.Unauthorized(c)
|
||||
}
|
||||
}
|
154
middleware/basicauth/basicauth_test.go
Normal file
154
middleware/basicauth/basicauth_test.go
Normal file
|
@ -0,0 +1,154 @@
|
|||
package basicauth
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/gofiber/fiber/v2/utils"
|
||||
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
// go test -run Test_BasicAuth_Next
|
||||
func Test_BasicAuth_Next(t *testing.T) {
|
||||
t.Parallel()
|
||||
app := fiber.New()
|
||||
app.Use(New(Config{
|
||||
Next: func(_ *fiber.Ctx) bool {
|
||||
return true
|
||||
},
|
||||
}))
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, fiber.StatusNotFound, resp.StatusCode)
|
||||
}
|
||||
|
||||
func Test_Middleware_BasicAuth(t *testing.T) {
|
||||
t.Parallel()
|
||||
app := fiber.New()
|
||||
|
||||
app.Use(New(Config{
|
||||
Users: map[string]string{
|
||||
"john": "doe",
|
||||
"admin": "123456",
|
||||
},
|
||||
}))
|
||||
|
||||
//nolint:forcetypeassert,errcheck // TODO: Do not force-type assert
|
||||
app.Get("/testauth", func(c *fiber.Ctx) error {
|
||||
username := c.Locals("username").(string)
|
||||
password := c.Locals("password").(string)
|
||||
|
||||
return c.SendString(username + password)
|
||||
})
|
||||
|
||||
tests := []struct {
|
||||
url string
|
||||
statusCode int
|
||||
username string
|
||||
password string
|
||||
}{
|
||||
{
|
||||
url: "/testauth",
|
||||
statusCode: 200,
|
||||
username: "john",
|
||||
password: "doe",
|
||||
},
|
||||
{
|
||||
url: "/testauth",
|
||||
statusCode: 200,
|
||||
username: "admin",
|
||||
password: "123456",
|
||||
},
|
||||
{
|
||||
url: "/testauth",
|
||||
statusCode: 401,
|
||||
username: "ee",
|
||||
password: "123456",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
// Base64 encode credentials for http auth header
|
||||
creds := base64.StdEncoding.EncodeToString([]byte(fmt.Sprintf("%s:%s", tt.username, tt.password)))
|
||||
|
||||
req := httptest.NewRequest(fiber.MethodGet, "/testauth", nil)
|
||||
req.Header.Add("Authorization", "Basic "+creds)
|
||||
resp, err := app.Test(req)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, tt.statusCode, resp.StatusCode)
|
||||
|
||||
if tt.statusCode == 200 {
|
||||
utils.AssertEqual(t, fmt.Sprintf("%s%s", tt.username, tt.password), string(body))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// go test -v -run=^$ -bench=Benchmark_Middleware_BasicAuth -benchmem -count=4
|
||||
func Benchmark_Middleware_BasicAuth(b *testing.B) {
|
||||
app := fiber.New()
|
||||
|
||||
app.Use(New(Config{
|
||||
Users: map[string]string{
|
||||
"john": "doe",
|
||||
},
|
||||
}))
|
||||
app.Get("/", func(c *fiber.Ctx) error {
|
||||
return c.SendStatus(fiber.StatusTeapot)
|
||||
})
|
||||
|
||||
h := app.Handler()
|
||||
|
||||
fctx := &fasthttp.RequestCtx{}
|
||||
fctx.Request.Header.SetMethod(fiber.MethodGet)
|
||||
fctx.Request.SetRequestURI("/")
|
||||
fctx.Request.Header.Set(fiber.HeaderAuthorization, "basic am9objpkb2U=") // john:doe
|
||||
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
|
||||
for n := 0; n < b.N; n++ {
|
||||
h(fctx)
|
||||
}
|
||||
|
||||
utils.AssertEqual(b, fiber.StatusTeapot, fctx.Response.Header.StatusCode())
|
||||
}
|
||||
|
||||
// go test -v -run=^$ -bench=Benchmark_Middleware_BasicAuth -benchmem -count=4
|
||||
func Benchmark_Middleware_BasicAuth_Upper(b *testing.B) {
|
||||
app := fiber.New()
|
||||
|
||||
app.Use(New(Config{
|
||||
Users: map[string]string{
|
||||
"john": "doe",
|
||||
},
|
||||
}))
|
||||
app.Get("/", func(c *fiber.Ctx) error {
|
||||
return c.SendStatus(fiber.StatusTeapot)
|
||||
})
|
||||
|
||||
h := app.Handler()
|
||||
|
||||
fctx := &fasthttp.RequestCtx{}
|
||||
fctx.Request.Header.SetMethod(fiber.MethodGet)
|
||||
fctx.Request.SetRequestURI("/")
|
||||
fctx.Request.Header.Set(fiber.HeaderAuthorization, "Basic am9objpkb2U=") // john:doe
|
||||
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
|
||||
for n := 0; n < b.N; n++ {
|
||||
h(fctx)
|
||||
}
|
||||
|
||||
utils.AssertEqual(b, fiber.StatusTeapot, fctx.Response.Header.StatusCode())
|
||||
}
|
105
middleware/basicauth/config.go
Normal file
105
middleware/basicauth/config.go
Normal file
|
@ -0,0 +1,105 @@
|
|||
package basicauth
|
||||
|
||||
import (
|
||||
"crypto/subtle"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/gofiber/fiber/v2/utils"
|
||||
)
|
||||
|
||||
// Config defines the config for middleware.
|
||||
type Config struct {
|
||||
// Next defines a function to skip this middleware when returned true.
|
||||
//
|
||||
// Optional. Default: nil
|
||||
Next func(c *fiber.Ctx) bool
|
||||
|
||||
// Users defines the allowed credentials
|
||||
//
|
||||
// Required. Default: map[string]string{}
|
||||
Users map[string]string
|
||||
|
||||
// Realm is a string to define realm attribute of BasicAuth.
|
||||
// the realm identifies the system to authenticate against
|
||||
// and can be used by clients to save credentials
|
||||
//
|
||||
// Optional. Default: "Restricted".
|
||||
Realm string
|
||||
|
||||
// Authorizer defines a function you can pass
|
||||
// to check the credentials however you want.
|
||||
// It will be called with a username and password
|
||||
// and is expected to return true or false to indicate
|
||||
// that the credentials were approved or not.
|
||||
//
|
||||
// Optional. Default: nil.
|
||||
Authorizer func(string, string) bool
|
||||
|
||||
// Unauthorized defines the response body for unauthorized responses.
|
||||
// By default it will return with a 401 Unauthorized and the correct WWW-Auth header
|
||||
//
|
||||
// Optional. Default: nil
|
||||
Unauthorized fiber.Handler
|
||||
|
||||
// ContextUser is the key to store the username in Locals
|
||||
//
|
||||
// Optional. Default: "username"
|
||||
ContextUsername interface{}
|
||||
|
||||
// ContextPass is the key to store the password in Locals
|
||||
//
|
||||
// Optional. Default: "password"
|
||||
ContextPassword interface{}
|
||||
}
|
||||
|
||||
// ConfigDefault is the default config
|
||||
var ConfigDefault = Config{
|
||||
Next: nil,
|
||||
Users: map[string]string{},
|
||||
Realm: "Restricted",
|
||||
Authorizer: nil,
|
||||
Unauthorized: nil,
|
||||
ContextUsername: "username",
|
||||
ContextPassword: "password",
|
||||
}
|
||||
|
||||
// Helper function to set default values
|
||||
func configDefault(config ...Config) Config {
|
||||
// Return default config if nothing provided
|
||||
if len(config) < 1 {
|
||||
return ConfigDefault
|
||||
}
|
||||
|
||||
// Override default config
|
||||
cfg := config[0]
|
||||
|
||||
// Set default values
|
||||
if cfg.Next == nil {
|
||||
cfg.Next = ConfigDefault.Next
|
||||
}
|
||||
if cfg.Users == nil {
|
||||
cfg.Users = ConfigDefault.Users
|
||||
}
|
||||
if cfg.Realm == "" {
|
||||
cfg.Realm = ConfigDefault.Realm
|
||||
}
|
||||
if cfg.Authorizer == nil {
|
||||
cfg.Authorizer = func(user, pass string) bool {
|
||||
userPwd, exist := cfg.Users[user]
|
||||
return exist && subtle.ConstantTimeCompare(utils.UnsafeBytes(userPwd), utils.UnsafeBytes(pass)) == 1
|
||||
}
|
||||
}
|
||||
if cfg.Unauthorized == nil {
|
||||
cfg.Unauthorized = func(c *fiber.Ctx) error {
|
||||
c.Set(fiber.HeaderWWWAuthenticate, "basic realm="+cfg.Realm)
|
||||
return c.SendStatus(fiber.StatusUnauthorized)
|
||||
}
|
||||
}
|
||||
if cfg.ContextUsername == nil {
|
||||
cfg.ContextUsername = ConfigDefault.ContextUsername
|
||||
}
|
||||
if cfg.ContextPassword == nil {
|
||||
cfg.ContextPassword = ConfigDefault.ContextPassword
|
||||
}
|
||||
return cfg
|
||||
}
|
252
middleware/cache/cache.go
vendored
Normal file
252
middleware/cache/cache.go
vendored
Normal file
|
@ -0,0 +1,252 @@
|
|||
// Special thanks to @codemicro for moving this to fiber core
|
||||
// Original middleware: github.com/codemicro/fiber-cache
|
||||
package cache
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/gofiber/fiber/v2/utils"
|
||||
)
|
||||
|
||||
// timestampUpdatePeriod is the period which is used to check the cache expiration.
|
||||
// It should not be too long to provide more or less acceptable expiration error, and in the same
|
||||
// time it should not be too short to avoid overwhelming of the system
|
||||
const timestampUpdatePeriod = 300 * time.Millisecond
|
||||
|
||||
// cache status
|
||||
// unreachable: when cache is bypass, or invalid
|
||||
// hit: cache is served
|
||||
// miss: do not have cache record
|
||||
const (
|
||||
cacheUnreachable = "unreachable"
|
||||
cacheHit = "hit"
|
||||
cacheMiss = "miss"
|
||||
)
|
||||
|
||||
// directives
|
||||
const (
|
||||
noCache = "no-cache"
|
||||
noStore = "no-store"
|
||||
)
|
||||
|
||||
var ignoreHeaders = map[string]interface{}{
|
||||
"Connection": nil,
|
||||
"Keep-Alive": nil,
|
||||
"Proxy-Authenticate": nil,
|
||||
"Proxy-Authorization": nil,
|
||||
"TE": nil,
|
||||
"Trailers": nil,
|
||||
"Transfer-Encoding": nil,
|
||||
"Upgrade": nil,
|
||||
"Content-Type": nil, // already stored explicitly by the cache manager
|
||||
"Content-Encoding": nil, // already stored explicitly by the cache manager
|
||||
}
|
||||
|
||||
// New creates a new middleware handler
|
||||
func New(config ...Config) fiber.Handler {
|
||||
// Set default config
|
||||
cfg := configDefault(config...)
|
||||
|
||||
// Nothing to cache
|
||||
if int(cfg.Expiration.Seconds()) < 0 {
|
||||
return func(c *fiber.Ctx) error {
|
||||
return c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
var (
|
||||
// Cache settings
|
||||
mux = &sync.RWMutex{}
|
||||
timestamp = uint64(time.Now().Unix())
|
||||
)
|
||||
// Create manager to simplify storage operations ( see manager.go )
|
||||
manager := newManager(cfg.Storage)
|
||||
// Create indexed heap for tracking expirations ( see heap.go )
|
||||
heap := &indexedHeap{}
|
||||
// count stored bytes (sizes of response bodies)
|
||||
var storedBytes uint
|
||||
|
||||
// Update timestamp in the configured interval
|
||||
go func() {
|
||||
for {
|
||||
atomic.StoreUint64(×tamp, uint64(time.Now().Unix()))
|
||||
time.Sleep(timestampUpdatePeriod)
|
||||
}
|
||||
}()
|
||||
|
||||
// Delete key from both manager and storage
|
||||
deleteKey := func(dkey string) {
|
||||
manager.del(dkey)
|
||||
// External storage saves body data with different key
|
||||
if cfg.Storage != nil {
|
||||
manager.del(dkey + "_body")
|
||||
}
|
||||
}
|
||||
|
||||
// Return new handler
|
||||
return func(c *fiber.Ctx) error {
|
||||
// Refrain from caching
|
||||
if hasRequestDirective(c, noStore) {
|
||||
return c.Next()
|
||||
}
|
||||
|
||||
// Only cache selected methods
|
||||
var isExists bool
|
||||
for _, method := range cfg.Methods {
|
||||
if c.Method() == method {
|
||||
isExists = true
|
||||
}
|
||||
}
|
||||
|
||||
if !isExists {
|
||||
c.Set(cfg.CacheHeader, cacheUnreachable)
|
||||
return c.Next()
|
||||
}
|
||||
|
||||
// Get key from request
|
||||
// TODO(allocation optimization): try to minimize the allocation from 2 to 1
|
||||
key := cfg.KeyGenerator(c) + "_" + c.Method()
|
||||
|
||||
// Get entry from pool
|
||||
e := manager.get(key)
|
||||
|
||||
// Lock entry
|
||||
mux.Lock()
|
||||
|
||||
// Get timestamp
|
||||
ts := atomic.LoadUint64(×tamp)
|
||||
|
||||
// Check if entry is expired
|
||||
if e.exp != 0 && ts >= e.exp {
|
||||
deleteKey(key)
|
||||
if cfg.MaxBytes > 0 {
|
||||
_, size := heap.remove(e.heapidx)
|
||||
storedBytes -= size
|
||||
}
|
||||
} else if e.exp != 0 && !hasRequestDirective(c, noCache) {
|
||||
// Separate body value to avoid msgp serialization
|
||||
// We can store raw bytes with Storage 👍
|
||||
if cfg.Storage != nil {
|
||||
e.body = manager.getRaw(key + "_body")
|
||||
}
|
||||
// Set response headers from cache
|
||||
c.Response().SetBodyRaw(e.body)
|
||||
c.Response().SetStatusCode(e.status)
|
||||
c.Response().Header.SetContentTypeBytes(e.ctype)
|
||||
if len(e.cencoding) > 0 {
|
||||
c.Response().Header.SetBytesV(fiber.HeaderContentEncoding, e.cencoding)
|
||||
}
|
||||
for k, v := range e.headers {
|
||||
c.Response().Header.SetBytesV(k, v)
|
||||
}
|
||||
// Set Cache-Control header if enabled
|
||||
if cfg.CacheControl {
|
||||
maxAge := strconv.FormatUint(e.exp-ts, 10)
|
||||
c.Set(fiber.HeaderCacheControl, "public, max-age="+maxAge)
|
||||
}
|
||||
|
||||
c.Set(cfg.CacheHeader, cacheHit)
|
||||
|
||||
mux.Unlock()
|
||||
|
||||
// Return response
|
||||
return nil
|
||||
}
|
||||
|
||||
// make sure we're not blocking concurrent requests - do unlock
|
||||
mux.Unlock()
|
||||
|
||||
// Continue stack, return err to Fiber if exist
|
||||
if err := c.Next(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// lock entry back and unlock on finish
|
||||
mux.Lock()
|
||||
defer mux.Unlock()
|
||||
|
||||
// Don't cache response if Next returns true
|
||||
if cfg.Next != nil && cfg.Next(c) {
|
||||
c.Set(cfg.CacheHeader, cacheUnreachable)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Don't try to cache if body won't fit into cache
|
||||
bodySize := uint(len(c.Response().Body()))
|
||||
if cfg.MaxBytes > 0 && bodySize > cfg.MaxBytes {
|
||||
c.Set(cfg.CacheHeader, cacheUnreachable)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Remove oldest to make room for new
|
||||
if cfg.MaxBytes > 0 {
|
||||
for storedBytes+bodySize > cfg.MaxBytes {
|
||||
key, size := heap.removeFirst()
|
||||
deleteKey(key)
|
||||
storedBytes -= size
|
||||
}
|
||||
}
|
||||
|
||||
// Cache response
|
||||
e.body = utils.CopyBytes(c.Response().Body())
|
||||
e.status = c.Response().StatusCode()
|
||||
e.ctype = utils.CopyBytes(c.Response().Header.ContentType())
|
||||
e.cencoding = utils.CopyBytes(c.Response().Header.Peek(fiber.HeaderContentEncoding))
|
||||
|
||||
// Store all response headers
|
||||
// (more: https://datatracker.ietf.org/doc/html/rfc2616#section-13.5.1)
|
||||
if cfg.StoreResponseHeaders {
|
||||
e.headers = make(map[string][]byte)
|
||||
c.Response().Header.VisitAll(
|
||||
func(key, value []byte) {
|
||||
// create real copy
|
||||
keyS := string(key)
|
||||
if _, ok := ignoreHeaders[keyS]; !ok {
|
||||
e.headers[keyS] = utils.CopyBytes(value)
|
||||
}
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
// default cache expiration
|
||||
expiration := cfg.Expiration
|
||||
// Calculate expiration by response header or other setting
|
||||
if cfg.ExpirationGenerator != nil {
|
||||
expiration = cfg.ExpirationGenerator(c, &cfg)
|
||||
}
|
||||
e.exp = ts + uint64(expiration.Seconds())
|
||||
|
||||
// Store entry in heap
|
||||
if cfg.MaxBytes > 0 {
|
||||
e.heapidx = heap.put(key, e.exp, bodySize)
|
||||
storedBytes += bodySize
|
||||
}
|
||||
|
||||
// For external Storage we store raw body separated
|
||||
if cfg.Storage != nil {
|
||||
manager.setRaw(key+"_body", e.body, expiration)
|
||||
// avoid body msgp encoding
|
||||
e.body = nil
|
||||
manager.set(key, e, expiration)
|
||||
manager.release(e)
|
||||
} else {
|
||||
// Store entry in memory
|
||||
manager.set(key, e, expiration)
|
||||
}
|
||||
|
||||
c.Set(cfg.CacheHeader, cacheMiss)
|
||||
|
||||
// Finish response
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// Check if request has directive
|
||||
func hasRequestDirective(c *fiber.Ctx, directive string) bool {
|
||||
return strings.Contains(c.Get(fiber.HeaderCacheControl), directive)
|
||||
}
|
901
middleware/cache/cache_test.go
vendored
Normal file
901
middleware/cache/cache_test.go
vendored
Normal file
|
@ -0,0 +1,901 @@
|
|||
// Special thanks to @codemicro for moving this to fiber core
|
||||
// Original middleware: github.com/codemicro/fiber-cache
|
||||
package cache
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"math"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"strconv"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/gofiber/fiber/v2/internal/storage/memory"
|
||||
"github.com/gofiber/fiber/v2/middleware/etag"
|
||||
"github.com/gofiber/fiber/v2/utils"
|
||||
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
func Test_Cache_CacheControl(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
app := fiber.New()
|
||||
|
||||
app.Use(New(Config{
|
||||
CacheControl: true,
|
||||
Expiration: 10 * time.Second,
|
||||
}))
|
||||
|
||||
app.Get("/", func(c *fiber.Ctx) error {
|
||||
return c.SendString("Hello, World!")
|
||||
})
|
||||
|
||||
_, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, "public, max-age=10", resp.Header.Get(fiber.HeaderCacheControl))
|
||||
}
|
||||
|
||||
func Test_Cache_Expired(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
app := fiber.New()
|
||||
app.Use(New(Config{Expiration: 2 * time.Second}))
|
||||
|
||||
app.Get("/", func(c *fiber.Ctx) error {
|
||||
return c.SendString(fmt.Sprintf("%d", time.Now().UnixNano()))
|
||||
})
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
|
||||
// Sleep until the cache is expired
|
||||
time.Sleep(3 * time.Second)
|
||||
|
||||
respCached, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
bodyCached, err := io.ReadAll(respCached.Body)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
|
||||
if bytes.Equal(body, bodyCached) {
|
||||
t.Errorf("Cache should have expired: %s, %s", body, bodyCached)
|
||||
}
|
||||
|
||||
// Next response should be also cached
|
||||
respCachedNextRound, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
bodyCachedNextRound, err := io.ReadAll(respCachedNextRound.Body)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
|
||||
if !bytes.Equal(bodyCachedNextRound, bodyCached) {
|
||||
t.Errorf("Cache should not have expired: %s, %s", bodyCached, bodyCachedNextRound)
|
||||
}
|
||||
}
|
||||
|
||||
func Test_Cache(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
app := fiber.New()
|
||||
app.Use(New())
|
||||
|
||||
app.Get("/", func(c *fiber.Ctx) error {
|
||||
now := fmt.Sprintf("%d", time.Now().UnixNano())
|
||||
return c.SendString(now)
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(fiber.MethodGet, "/", nil)
|
||||
resp, err := app.Test(req)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
|
||||
cachedReq := httptest.NewRequest(fiber.MethodGet, "/", nil)
|
||||
cachedResp, err := app.Test(cachedReq)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
cachedBody, err := io.ReadAll(cachedResp.Body)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
|
||||
utils.AssertEqual(t, cachedBody, body)
|
||||
}
|
||||
|
||||
// go test -run Test_Cache_WithNoCacheRequestDirective
|
||||
func Test_Cache_WithNoCacheRequestDirective(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
app := fiber.New()
|
||||
app.Use(New())
|
||||
|
||||
app.Get("/", func(c *fiber.Ctx) error {
|
||||
return c.SendString(c.Query("id", "1"))
|
||||
})
|
||||
|
||||
// Request id = 1
|
||||
req := httptest.NewRequest(fiber.MethodGet, "/", nil)
|
||||
resp, err := app.Test(req)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, cacheMiss, resp.Header.Get("X-Cache"))
|
||||
utils.AssertEqual(t, []byte("1"), body)
|
||||
// Response cached, entry id = 1
|
||||
|
||||
// Request id = 2 without Cache-Control: no-cache
|
||||
cachedReq := httptest.NewRequest(fiber.MethodGet, "/?id=2", nil)
|
||||
cachedResp, err := app.Test(cachedReq)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
cachedBody, err := io.ReadAll(cachedResp.Body)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, cacheHit, cachedResp.Header.Get("X-Cache"))
|
||||
utils.AssertEqual(t, []byte("1"), cachedBody)
|
||||
// Response not cached, returns cached response, entry id = 1
|
||||
|
||||
// Request id = 2 with Cache-Control: no-cache
|
||||
noCacheReq := httptest.NewRequest(fiber.MethodGet, "/?id=2", nil)
|
||||
noCacheReq.Header.Set(fiber.HeaderCacheControl, noCache)
|
||||
noCacheResp, err := app.Test(noCacheReq)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
noCacheBody, err := io.ReadAll(noCacheResp.Body)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, cacheMiss, noCacheResp.Header.Get("X-Cache"))
|
||||
utils.AssertEqual(t, []byte("2"), noCacheBody)
|
||||
// Response cached, returns updated response, entry = 2
|
||||
|
||||
/* Check Test_Cache_WithETagAndNoCacheRequestDirective */
|
||||
// Request id = 2 with Cache-Control: no-cache again
|
||||
noCacheReq1 := httptest.NewRequest(fiber.MethodGet, "/?id=2", nil)
|
||||
noCacheReq1.Header.Set(fiber.HeaderCacheControl, noCache)
|
||||
noCacheResp1, err := app.Test(noCacheReq1)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
noCacheBody1, err := io.ReadAll(noCacheResp1.Body)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, cacheMiss, noCacheResp1.Header.Get("X-Cache"))
|
||||
utils.AssertEqual(t, []byte("2"), noCacheBody1)
|
||||
// Response cached, returns updated response, entry = 2
|
||||
|
||||
// Request id = 1 without Cache-Control: no-cache
|
||||
cachedReq1 := httptest.NewRequest(fiber.MethodGet, "/", nil)
|
||||
cachedResp1, err := app.Test(cachedReq1)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
cachedBody1, err := io.ReadAll(cachedResp1.Body)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, cacheHit, cachedResp1.Header.Get("X-Cache"))
|
||||
utils.AssertEqual(t, []byte("2"), cachedBody1)
|
||||
// Response not cached, returns cached response, entry id = 2
|
||||
}
|
||||
|
||||
// go test -run Test_Cache_WithETagAndNoCacheRequestDirective
|
||||
func Test_Cache_WithETagAndNoCacheRequestDirective(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
app := fiber.New()
|
||||
app.Use(
|
||||
etag.New(),
|
||||
New(),
|
||||
)
|
||||
|
||||
app.Get("/", func(c *fiber.Ctx) error {
|
||||
return c.SendString(c.Query("id", "1"))
|
||||
})
|
||||
|
||||
// Request id = 1
|
||||
req := httptest.NewRequest(fiber.MethodGet, "/", nil)
|
||||
resp, err := app.Test(req)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, cacheMiss, resp.Header.Get("X-Cache"))
|
||||
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
|
||||
// Response cached, entry id = 1
|
||||
|
||||
// If response status 200
|
||||
etagToken := resp.Header.Get("Etag")
|
||||
|
||||
// Request id = 2 with ETag but without Cache-Control: no-cache
|
||||
cachedReq := httptest.NewRequest(fiber.MethodGet, "/?id=2", nil)
|
||||
cachedReq.Header.Set(fiber.HeaderIfNoneMatch, etagToken)
|
||||
cachedResp, err := app.Test(cachedReq)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, cacheHit, cachedResp.Header.Get("X-Cache"))
|
||||
utils.AssertEqual(t, fiber.StatusNotModified, cachedResp.StatusCode)
|
||||
// Response not cached, returns cached response, entry id = 1, status not modified
|
||||
|
||||
// Request id = 2 with ETag and Cache-Control: no-cache
|
||||
noCacheReq := httptest.NewRequest(fiber.MethodGet, "/?id=2", nil)
|
||||
noCacheReq.Header.Set(fiber.HeaderCacheControl, noCache)
|
||||
noCacheReq.Header.Set(fiber.HeaderIfNoneMatch, etagToken)
|
||||
noCacheResp, err := app.Test(noCacheReq)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, cacheMiss, noCacheResp.Header.Get("X-Cache"))
|
||||
utils.AssertEqual(t, fiber.StatusOK, noCacheResp.StatusCode)
|
||||
// Response cached, returns updated response, entry id = 2
|
||||
|
||||
// If response status 200
|
||||
etagToken = noCacheResp.Header.Get("Etag")
|
||||
|
||||
// Request id = 2 with ETag and Cache-Control: no-cache again
|
||||
noCacheReq1 := httptest.NewRequest(fiber.MethodGet, "/?id=2", nil)
|
||||
noCacheReq1.Header.Set(fiber.HeaderCacheControl, noCache)
|
||||
noCacheReq1.Header.Set(fiber.HeaderIfNoneMatch, etagToken)
|
||||
noCacheResp1, err := app.Test(noCacheReq1)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, cacheMiss, noCacheResp1.Header.Get("X-Cache"))
|
||||
utils.AssertEqual(t, fiber.StatusNotModified, noCacheResp1.StatusCode)
|
||||
// Response cached, returns updated response, entry id = 2, status not modified
|
||||
|
||||
// Request id = 1 without ETag and Cache-Control: no-cache
|
||||
cachedReq1 := httptest.NewRequest(fiber.MethodGet, "/", nil)
|
||||
cachedResp1, err := app.Test(cachedReq1)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, cacheHit, cachedResp1.Header.Get("X-Cache"))
|
||||
utils.AssertEqual(t, fiber.StatusOK, cachedResp1.StatusCode)
|
||||
// Response not cached, returns cached response, entry id = 2
|
||||
}
|
||||
|
||||
// go test -run Test_Cache_WithNoStoreRequestDirective
|
||||
func Test_Cache_WithNoStoreRequestDirective(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
app := fiber.New()
|
||||
app.Use(New())
|
||||
|
||||
app.Get("/", func(c *fiber.Ctx) error {
|
||||
return c.SendString(c.Query("id", "1"))
|
||||
})
|
||||
|
||||
// Request id = 2
|
||||
noStoreReq := httptest.NewRequest(fiber.MethodGet, "/?id=2", nil)
|
||||
noStoreReq.Header.Set(fiber.HeaderCacheControl, noStore)
|
||||
noStoreResp, err := app.Test(noStoreReq)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
noStoreBody, err := io.ReadAll(noStoreResp.Body)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, []byte("2"), noStoreBody)
|
||||
// Response not cached, returns updated response
|
||||
}
|
||||
|
||||
func Test_Cache_WithSeveralRequests(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
app := fiber.New()
|
||||
|
||||
app.Use(New(Config{
|
||||
CacheControl: true,
|
||||
Expiration: 10 * time.Second,
|
||||
}))
|
||||
|
||||
app.Get("/:id", func(c *fiber.Ctx) error {
|
||||
return c.SendString(c.Params("id"))
|
||||
})
|
||||
|
||||
for runs := 0; runs < 10; runs++ {
|
||||
for i := 0; i < 10; i++ {
|
||||
func(id int) {
|
||||
rsp, err := app.Test(httptest.NewRequest(fiber.MethodGet, fmt.Sprintf("/%d", id), nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
|
||||
defer func(body io.ReadCloser) {
|
||||
err := body.Close()
|
||||
utils.AssertEqual(t, nil, err)
|
||||
}(rsp.Body)
|
||||
|
||||
idFromServ, err := io.ReadAll(rsp.Body)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
|
||||
a, err := strconv.Atoi(string(idFromServ))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
|
||||
// SomeTimes,The id is not equal with a
|
||||
utils.AssertEqual(t, id, a)
|
||||
}(i)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func Test_Cache_Invalid_Expiration(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
app := fiber.New()
|
||||
cache := New(Config{Expiration: 0 * time.Second})
|
||||
app.Use(cache)
|
||||
|
||||
app.Get("/", func(c *fiber.Ctx) error {
|
||||
now := fmt.Sprintf("%d", time.Now().UnixNano())
|
||||
return c.SendString(now)
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(fiber.MethodGet, "/", nil)
|
||||
resp, err := app.Test(req)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
|
||||
cachedReq := httptest.NewRequest(fiber.MethodGet, "/", nil)
|
||||
cachedResp, err := app.Test(cachedReq)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
cachedBody, err := io.ReadAll(cachedResp.Body)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
|
||||
utils.AssertEqual(t, cachedBody, body)
|
||||
}
|
||||
|
||||
func Test_Cache_Get(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
app := fiber.New()
|
||||
|
||||
app.Use(New())
|
||||
|
||||
app.Post("/", func(c *fiber.Ctx) error {
|
||||
return c.SendString(c.Query("cache"))
|
||||
})
|
||||
|
||||
app.Get("/get", func(c *fiber.Ctx) error {
|
||||
return c.SendString(c.Query("cache"))
|
||||
})
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodPost, "/?cache=123", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, "123", string(body))
|
||||
|
||||
resp, err = app.Test(httptest.NewRequest(fiber.MethodPost, "/?cache=12345", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
body, err = io.ReadAll(resp.Body)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, "12345", string(body))
|
||||
|
||||
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/get?cache=123", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
body, err = io.ReadAll(resp.Body)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, "123", string(body))
|
||||
|
||||
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/get?cache=12345", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
body, err = io.ReadAll(resp.Body)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, "123", string(body))
|
||||
}
|
||||
|
||||
func Test_Cache_Post(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
app := fiber.New()
|
||||
|
||||
app.Use(New(Config{
|
||||
Methods: []string{fiber.MethodPost},
|
||||
}))
|
||||
|
||||
app.Post("/", func(c *fiber.Ctx) error {
|
||||
return c.SendString(c.Query("cache"))
|
||||
})
|
||||
|
||||
app.Get("/get", func(c *fiber.Ctx) error {
|
||||
return c.SendString(c.Query("cache"))
|
||||
})
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodPost, "/?cache=123", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, "123", string(body))
|
||||
|
||||
resp, err = app.Test(httptest.NewRequest(fiber.MethodPost, "/?cache=12345", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
body, err = io.ReadAll(resp.Body)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, "123", string(body))
|
||||
|
||||
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/get?cache=123", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
body, err = io.ReadAll(resp.Body)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, "123", string(body))
|
||||
|
||||
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/get?cache=12345", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
body, err = io.ReadAll(resp.Body)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, "12345", string(body))
|
||||
}
|
||||
|
||||
func Test_Cache_NothingToCache(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
app := fiber.New()
|
||||
|
||||
app.Use(New(Config{Expiration: -(time.Second * 1)}))
|
||||
|
||||
app.Get("/", func(c *fiber.Ctx) error {
|
||||
return c.SendString(time.Now().String())
|
||||
})
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
|
||||
respCached, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
bodyCached, err := io.ReadAll(respCached.Body)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
|
||||
if bytes.Equal(body, bodyCached) {
|
||||
t.Errorf("Cache should have expired: %s, %s", body, bodyCached)
|
||||
}
|
||||
}
|
||||
|
||||
func Test_Cache_CustomNext(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
app := fiber.New()
|
||||
|
||||
app.Use(New(Config{
|
||||
Next: func(c *fiber.Ctx) bool {
|
||||
return c.Response().StatusCode() != fiber.StatusOK
|
||||
},
|
||||
CacheControl: true,
|
||||
}))
|
||||
|
||||
app.Get("/", func(c *fiber.Ctx) error {
|
||||
return c.SendString(time.Now().String())
|
||||
})
|
||||
|
||||
app.Get("/error", func(c *fiber.Ctx) error {
|
||||
return c.Status(fiber.StatusInternalServerError).SendString(time.Now().String())
|
||||
})
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
|
||||
respCached, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
bodyCached, err := io.ReadAll(respCached.Body)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, true, bytes.Equal(body, bodyCached))
|
||||
utils.AssertEqual(t, true, respCached.Header.Get(fiber.HeaderCacheControl) != "")
|
||||
|
||||
_, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/error", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
|
||||
errRespCached, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/error", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, true, errRespCached.Header.Get(fiber.HeaderCacheControl) == "")
|
||||
}
|
||||
|
||||
func Test_CustomKey(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
app := fiber.New()
|
||||
var called bool
|
||||
app.Use(New(Config{KeyGenerator: func(c *fiber.Ctx) string {
|
||||
called = true
|
||||
return utils.CopyString(c.Path())
|
||||
}}))
|
||||
|
||||
app.Get("/", func(c *fiber.Ctx) error {
|
||||
return c.SendString("hi")
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(fiber.MethodGet, "/", nil)
|
||||
_, err := app.Test(req)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, true, called)
|
||||
}
|
||||
|
||||
func Test_CustomExpiration(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
app := fiber.New()
|
||||
var called bool
|
||||
var newCacheTime int
|
||||
app.Use(New(Config{ExpirationGenerator: func(c *fiber.Ctx, cfg *Config) time.Duration {
|
||||
called = true
|
||||
var err error
|
||||
newCacheTime, err = strconv.Atoi(c.GetRespHeader("Cache-Time", "600"))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
return time.Second * time.Duration(newCacheTime)
|
||||
}}))
|
||||
|
||||
app.Get("/", func(c *fiber.Ctx) error {
|
||||
c.Response().Header.Add("Cache-Time", "1")
|
||||
now := fmt.Sprintf("%d", time.Now().UnixNano())
|
||||
return c.SendString(now)
|
||||
})
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, true, called)
|
||||
utils.AssertEqual(t, 1, newCacheTime)
|
||||
|
||||
// Sleep until the cache is expired
|
||||
time.Sleep(1 * time.Second)
|
||||
|
||||
cachedResp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
cachedBody, err := io.ReadAll(cachedResp.Body)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
|
||||
if bytes.Equal(body, cachedBody) {
|
||||
t.Errorf("Cache should have expired: %s, %s", body, cachedBody)
|
||||
}
|
||||
|
||||
// Next response should be cached
|
||||
cachedRespNextRound, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
cachedBodyNextRound, err := io.ReadAll(cachedRespNextRound.Body)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
|
||||
if !bytes.Equal(cachedBodyNextRound, cachedBody) {
|
||||
t.Errorf("Cache should not have expired: %s, %s", cachedBodyNextRound, cachedBody)
|
||||
}
|
||||
}
|
||||
|
||||
func Test_AdditionalE2EResponseHeaders(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
app := fiber.New()
|
||||
app.Use(New(Config{
|
||||
StoreResponseHeaders: true,
|
||||
}))
|
||||
|
||||
app.Get("/", func(c *fiber.Ctx) error {
|
||||
c.Response().Header.Add("X-Foobar", "foobar")
|
||||
return c.SendString("hi")
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(fiber.MethodGet, "/", nil)
|
||||
resp, err := app.Test(req)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, "foobar", resp.Header.Get("X-Foobar"))
|
||||
|
||||
req = httptest.NewRequest(fiber.MethodGet, "/", nil)
|
||||
resp, err = app.Test(req)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, "foobar", resp.Header.Get("X-Foobar"))
|
||||
}
|
||||
|
||||
func Test_CacheHeader(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
app := fiber.New()
|
||||
|
||||
app.Use(New(Config{
|
||||
Expiration: 10 * time.Second,
|
||||
Next: func(c *fiber.Ctx) bool {
|
||||
return c.Response().StatusCode() != fiber.StatusOK
|
||||
},
|
||||
}))
|
||||
|
||||
app.Get("/", func(c *fiber.Ctx) error {
|
||||
return c.SendString("Hello, World!")
|
||||
})
|
||||
|
||||
app.Post("/", func(c *fiber.Ctx) error {
|
||||
return c.SendString(c.Query("cache"))
|
||||
})
|
||||
|
||||
app.Get("/error", func(c *fiber.Ctx) error {
|
||||
return c.Status(fiber.StatusInternalServerError).SendString(time.Now().String())
|
||||
})
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, cacheMiss, resp.Header.Get("X-Cache"))
|
||||
|
||||
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, cacheHit, resp.Header.Get("X-Cache"))
|
||||
|
||||
resp, err = app.Test(httptest.NewRequest(fiber.MethodPost, "/?cache=12345", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, cacheUnreachable, resp.Header.Get("X-Cache"))
|
||||
|
||||
errRespCached, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/error", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, cacheUnreachable, errRespCached.Header.Get("X-Cache"))
|
||||
}
|
||||
|
||||
func Test_Cache_WithHead(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
app := fiber.New()
|
||||
app.Use(New())
|
||||
|
||||
app.Get("/", func(c *fiber.Ctx) error {
|
||||
now := fmt.Sprintf("%d", time.Now().UnixNano())
|
||||
return c.SendString(now)
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(fiber.MethodHead, "/", nil)
|
||||
resp, err := app.Test(req)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, cacheMiss, resp.Header.Get("X-Cache"))
|
||||
|
||||
cachedReq := httptest.NewRequest(fiber.MethodHead, "/", nil)
|
||||
cachedResp, err := app.Test(cachedReq)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, cacheHit, cachedResp.Header.Get("X-Cache"))
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
cachedBody, err := io.ReadAll(cachedResp.Body)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
|
||||
utils.AssertEqual(t, cachedBody, body)
|
||||
}
|
||||
|
||||
func Test_Cache_WithHeadThenGet(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
app := fiber.New()
|
||||
app.Use(New())
|
||||
app.Get("/", func(c *fiber.Ctx) error {
|
||||
return c.SendString(c.Query("cache"))
|
||||
})
|
||||
|
||||
headResp, err := app.Test(httptest.NewRequest(fiber.MethodHead, "/?cache=123", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
headBody, err := io.ReadAll(headResp.Body)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, "", string(headBody))
|
||||
utils.AssertEqual(t, cacheMiss, headResp.Header.Get("X-Cache"))
|
||||
|
||||
headResp, err = app.Test(httptest.NewRequest(fiber.MethodHead, "/?cache=123", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
headBody, err = io.ReadAll(headResp.Body)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, "", string(headBody))
|
||||
utils.AssertEqual(t, cacheHit, headResp.Header.Get("X-Cache"))
|
||||
|
||||
getResp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/?cache=123", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
getBody, err := io.ReadAll(getResp.Body)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, "123", string(getBody))
|
||||
utils.AssertEqual(t, cacheMiss, getResp.Header.Get("X-Cache"))
|
||||
|
||||
getResp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/?cache=123", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
getBody, err = io.ReadAll(getResp.Body)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, "123", string(getBody))
|
||||
utils.AssertEqual(t, cacheHit, getResp.Header.Get("X-Cache"))
|
||||
}
|
||||
|
||||
func Test_CustomCacheHeader(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
app := fiber.New()
|
||||
|
||||
app.Use(New(Config{
|
||||
CacheHeader: "Cache-Status",
|
||||
}))
|
||||
|
||||
app.Get("/", func(c *fiber.Ctx) error {
|
||||
return c.SendString("Hello, World!")
|
||||
})
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, cacheMiss, resp.Header.Get("Cache-Status"))
|
||||
}
|
||||
|
||||
// Because time points are updated once every X milliseconds, entries in tests can often have
|
||||
// equal expiration times and thus be in an random order. This closure hands out increasing
|
||||
// time intervals to maintain strong ascending order of expiration
|
||||
func stableAscendingExpiration() func(c1 *fiber.Ctx, c2 *Config) time.Duration {
|
||||
i := 0
|
||||
return func(c1 *fiber.Ctx, c2 *Config) time.Duration {
|
||||
i++
|
||||
return time.Hour * time.Duration(i)
|
||||
}
|
||||
}
|
||||
|
||||
func Test_Cache_MaxBytesOrder(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
app := fiber.New()
|
||||
app.Use(New(Config{
|
||||
MaxBytes: 2,
|
||||
ExpirationGenerator: stableAscendingExpiration(),
|
||||
}))
|
||||
|
||||
app.Get("/*", func(c *fiber.Ctx) error {
|
||||
return c.SendString("1")
|
||||
})
|
||||
|
||||
cases := [][]string{
|
||||
// Insert a, b into cache of size 2 bytes (responses are 1 byte)
|
||||
{"/a", cacheMiss},
|
||||
{"/b", cacheMiss},
|
||||
{"/a", cacheHit},
|
||||
{"/b", cacheHit},
|
||||
// Add c -> a evicted
|
||||
{"/c", cacheMiss},
|
||||
{"/b", cacheHit},
|
||||
// Add a again -> b evicted
|
||||
{"/a", cacheMiss},
|
||||
{"/c", cacheHit},
|
||||
// Add b -> c evicted
|
||||
{"/b", cacheMiss},
|
||||
{"/c", cacheMiss},
|
||||
}
|
||||
|
||||
for idx, tcase := range cases {
|
||||
rsp, err := app.Test(httptest.NewRequest(fiber.MethodGet, tcase[0], nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, tcase[1], rsp.Header.Get("X-Cache"), fmt.Sprintf("Case %v", idx))
|
||||
}
|
||||
}
|
||||
|
||||
func Test_Cache_MaxBytesSizes(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
app := fiber.New()
|
||||
|
||||
app.Use(New(Config{
|
||||
MaxBytes: 7,
|
||||
ExpirationGenerator: stableAscendingExpiration(),
|
||||
}))
|
||||
|
||||
app.Get("/*", func(c *fiber.Ctx) error {
|
||||
path := c.Context().URI().LastPathSegment()
|
||||
size, err := strconv.Atoi(string(path))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
return c.Send(make([]byte, size))
|
||||
})
|
||||
|
||||
cases := [][]string{
|
||||
{"/1", cacheMiss},
|
||||
{"/2", cacheMiss},
|
||||
{"/3", cacheMiss},
|
||||
{"/4", cacheMiss}, // 1+2+3+4 > 7 => 1,2 are evicted now
|
||||
{"/3", cacheHit},
|
||||
{"/1", cacheMiss},
|
||||
{"/2", cacheMiss},
|
||||
{"/8", cacheUnreachable}, // too big to cache -> unreachable
|
||||
}
|
||||
|
||||
for idx, tcase := range cases {
|
||||
rsp, err := app.Test(httptest.NewRequest(fiber.MethodGet, tcase[0], nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, tcase[1], rsp.Header.Get("X-Cache"), fmt.Sprintf("Case %v", idx))
|
||||
}
|
||||
}
|
||||
|
||||
// go test -v -run=^$ -bench=Benchmark_Cache -benchmem -count=4
|
||||
func Benchmark_Cache(b *testing.B) {
|
||||
app := fiber.New()
|
||||
|
||||
app.Use(New())
|
||||
|
||||
app.Get("/demo", func(c *fiber.Ctx) error {
|
||||
data, _ := os.ReadFile("../../.github/README.md") //nolint:errcheck // We're inside a benchmark
|
||||
return c.Status(fiber.StatusTeapot).Send(data)
|
||||
})
|
||||
|
||||
h := app.Handler()
|
||||
|
||||
fctx := &fasthttp.RequestCtx{}
|
||||
fctx.Request.Header.SetMethod(fiber.MethodGet)
|
||||
fctx.Request.SetRequestURI("/demo")
|
||||
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
|
||||
for n := 0; n < b.N; n++ {
|
||||
h(fctx)
|
||||
}
|
||||
|
||||
utils.AssertEqual(b, fiber.StatusTeapot, fctx.Response.Header.StatusCode())
|
||||
utils.AssertEqual(b, true, len(fctx.Response.Body()) > 30000)
|
||||
}
|
||||
|
||||
// go test -v -run=^$ -bench=Benchmark_Cache_Storage -benchmem -count=4
|
||||
func Benchmark_Cache_Storage(b *testing.B) {
|
||||
app := fiber.New()
|
||||
|
||||
app.Use(New(Config{
|
||||
Storage: memory.New(),
|
||||
}))
|
||||
|
||||
app.Get("/demo", func(c *fiber.Ctx) error {
|
||||
data, _ := os.ReadFile("../../.github/README.md") //nolint:errcheck // We're inside a benchmark
|
||||
return c.Status(fiber.StatusTeapot).Send(data)
|
||||
})
|
||||
|
||||
h := app.Handler()
|
||||
|
||||
fctx := &fasthttp.RequestCtx{}
|
||||
fctx.Request.Header.SetMethod(fiber.MethodGet)
|
||||
fctx.Request.SetRequestURI("/demo")
|
||||
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
|
||||
for n := 0; n < b.N; n++ {
|
||||
h(fctx)
|
||||
}
|
||||
|
||||
utils.AssertEqual(b, fiber.StatusTeapot, fctx.Response.Header.StatusCode())
|
||||
utils.AssertEqual(b, true, len(fctx.Response.Body()) > 30000)
|
||||
}
|
||||
|
||||
func Benchmark_Cache_AdditionalHeaders(b *testing.B) {
|
||||
app := fiber.New()
|
||||
app.Use(New(Config{
|
||||
StoreResponseHeaders: true,
|
||||
}))
|
||||
|
||||
app.Get("/demo", func(c *fiber.Ctx) error {
|
||||
c.Response().Header.Add("X-Foobar", "foobar")
|
||||
return c.SendStatus(418)
|
||||
})
|
||||
|
||||
h := app.Handler()
|
||||
|
||||
fctx := &fasthttp.RequestCtx{}
|
||||
fctx.Request.Header.SetMethod(fiber.MethodGet)
|
||||
fctx.Request.SetRequestURI("/demo")
|
||||
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
|
||||
for n := 0; n < b.N; n++ {
|
||||
h(fctx)
|
||||
}
|
||||
|
||||
utils.AssertEqual(b, fiber.StatusTeapot, fctx.Response.Header.StatusCode())
|
||||
utils.AssertEqual(b, []byte("foobar"), fctx.Response.Header.Peek("X-Foobar"))
|
||||
}
|
||||
|
||||
func Benchmark_Cache_MaxSize(b *testing.B) {
|
||||
// The benchmark is run with three different MaxSize parameters
|
||||
// 1) 0: Tracking is disabled = no overhead
|
||||
// 2) MaxInt32: Enough to store all entries = no removals
|
||||
// 3) 100: Small size = constant insertions and removals
|
||||
cases := []uint{0, math.MaxUint32, 100}
|
||||
names := []string{"Disabled", "Unlim", "LowBounded"}
|
||||
for i, size := range cases {
|
||||
b.Run(names[i], func(b *testing.B) {
|
||||
app := fiber.New()
|
||||
app.Use(New(Config{MaxBytes: size}))
|
||||
|
||||
app.Get("/*", func(c *fiber.Ctx) error {
|
||||
return c.Status(fiber.StatusTeapot).SendString("1")
|
||||
})
|
||||
|
||||
h := app.Handler()
|
||||
fctx := &fasthttp.RequestCtx{}
|
||||
fctx.Request.Header.SetMethod(fiber.MethodGet)
|
||||
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
|
||||
for n := 0; n < b.N; n++ {
|
||||
fctx.Request.SetRequestURI(fmt.Sprintf("/%v", n))
|
||||
h(fctx)
|
||||
}
|
||||
|
||||
utils.AssertEqual(b, fiber.StatusTeapot, fctx.Response.Header.StatusCode())
|
||||
})
|
||||
}
|
||||
}
|
128
middleware/cache/config.go
vendored
Normal file
128
middleware/cache/config.go
vendored
Normal file
|
@ -0,0 +1,128 @@
|
|||
package cache
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/gofiber/fiber/v2/log"
|
||||
"github.com/gofiber/fiber/v2/utils"
|
||||
)
|
||||
|
||||
// Config defines the config for middleware.
|
||||
type Config struct {
|
||||
// Next defines a function to skip this middleware when returned true.
|
||||
//
|
||||
// Optional. Default: nil
|
||||
Next func(c *fiber.Ctx) bool
|
||||
|
||||
// Expiration is the time that an cached response will live
|
||||
//
|
||||
// Optional. Default: 1 * time.Minute
|
||||
Expiration time.Duration
|
||||
|
||||
// CacheHeader header on response header, indicate cache status, with the following possible return value
|
||||
//
|
||||
// hit, miss, unreachable
|
||||
//
|
||||
// Optional. Default: X-Cache
|
||||
CacheHeader string
|
||||
|
||||
// CacheControl enables client side caching if set to true
|
||||
//
|
||||
// Optional. Default: false
|
||||
CacheControl bool
|
||||
|
||||
// Key allows you to generate custom keys, by default c.Path() is used
|
||||
//
|
||||
// Default: func(c *fiber.Ctx) string {
|
||||
// return utils.CopyString(c.Path())
|
||||
// }
|
||||
KeyGenerator func(*fiber.Ctx) string
|
||||
|
||||
// allows you to generate custom Expiration Key By Key, default is Expiration (Optional)
|
||||
//
|
||||
// Default: nil
|
||||
ExpirationGenerator func(*fiber.Ctx, *Config) time.Duration
|
||||
|
||||
// Store is used to store the state of the middleware
|
||||
//
|
||||
// Default: an in memory store for this process only
|
||||
Storage fiber.Storage
|
||||
|
||||
// Deprecated: Use Storage instead
|
||||
Store fiber.Storage
|
||||
|
||||
// Deprecated: Use KeyGenerator instead
|
||||
Key func(*fiber.Ctx) string
|
||||
|
||||
// allows you to store additional headers generated by next middlewares & handler
|
||||
//
|
||||
// Default: false
|
||||
StoreResponseHeaders bool
|
||||
|
||||
// Max number of bytes of response bodies simultaneously stored in cache. When limit is reached,
|
||||
// entries with the nearest expiration are deleted to make room for new.
|
||||
// 0 means no limit
|
||||
//
|
||||
// Default: 0
|
||||
MaxBytes uint
|
||||
|
||||
// You can specify HTTP methods to cache.
|
||||
// The middleware just caches the routes of its methods in this slice.
|
||||
//
|
||||
// Default: []string{fiber.MethodGet, fiber.MethodHead}
|
||||
Methods []string
|
||||
}
|
||||
|
||||
// ConfigDefault is the default config
|
||||
var ConfigDefault = Config{
|
||||
Next: nil,
|
||||
Expiration: 1 * time.Minute,
|
||||
CacheHeader: "X-Cache",
|
||||
CacheControl: false,
|
||||
KeyGenerator: func(c *fiber.Ctx) string {
|
||||
return utils.CopyString(c.Path())
|
||||
},
|
||||
ExpirationGenerator: nil,
|
||||
StoreResponseHeaders: false,
|
||||
Storage: nil,
|
||||
MaxBytes: 0,
|
||||
Methods: []string{fiber.MethodGet, fiber.MethodHead},
|
||||
}
|
||||
|
||||
// Helper function to set default values
|
||||
func configDefault(config ...Config) Config {
|
||||
// Return default config if nothing provided
|
||||
if len(config) < 1 {
|
||||
return ConfigDefault
|
||||
}
|
||||
|
||||
// Override default config
|
||||
cfg := config[0]
|
||||
|
||||
// Set default values
|
||||
if cfg.Store != nil {
|
||||
log.Warn("[CACHE] Store is deprecated, please use Storage")
|
||||
cfg.Storage = cfg.Store
|
||||
}
|
||||
if cfg.Key != nil {
|
||||
log.Warn("[CACHE] Key is deprecated, please use KeyGenerator")
|
||||
cfg.KeyGenerator = cfg.Key
|
||||
}
|
||||
if cfg.Next == nil {
|
||||
cfg.Next = ConfigDefault.Next
|
||||
}
|
||||
if int(cfg.Expiration.Seconds()) == 0 {
|
||||
cfg.Expiration = ConfigDefault.Expiration
|
||||
}
|
||||
if cfg.CacheHeader == "" {
|
||||
cfg.CacheHeader = ConfigDefault.CacheHeader
|
||||
}
|
||||
if cfg.KeyGenerator == nil {
|
||||
cfg.KeyGenerator = ConfigDefault.KeyGenerator
|
||||
}
|
||||
if len(cfg.Methods) == 0 {
|
||||
cfg.Methods = ConfigDefault.Methods
|
||||
}
|
||||
return cfg
|
||||
}
|
92
middleware/cache/heap.go
vendored
Normal file
92
middleware/cache/heap.go
vendored
Normal file
|
@ -0,0 +1,92 @@
|
|||
package cache
|
||||
|
||||
import (
|
||||
"container/heap"
|
||||
)
|
||||
|
||||
type heapEntry struct {
|
||||
key string
|
||||
exp uint64
|
||||
bytes uint
|
||||
idx int
|
||||
}
|
||||
|
||||
// indexedHeap is a regular min-heap that allows finding
|
||||
// elements in constant time. It does so by handing out special indices
|
||||
// and tracking entry movement.
|
||||
//
|
||||
// indexdedHeap is used for quickly finding entries with the lowest
|
||||
// expiration timestamp and deleting arbitrary entries.
|
||||
type indexedHeap struct {
|
||||
// Slice the heap is built on
|
||||
entries []heapEntry
|
||||
// Mapping "index" to position in heap slice
|
||||
indices []int
|
||||
// Max index handed out
|
||||
maxidx int
|
||||
}
|
||||
|
||||
func (h indexedHeap) Len() int {
|
||||
return len(h.entries)
|
||||
}
|
||||
|
||||
func (h indexedHeap) Less(i, j int) bool {
|
||||
return h.entries[i].exp < h.entries[j].exp
|
||||
}
|
||||
|
||||
func (h indexedHeap) Swap(i, j int) {
|
||||
h.entries[i], h.entries[j] = h.entries[j], h.entries[i]
|
||||
h.indices[h.entries[i].idx] = i
|
||||
h.indices[h.entries[j].idx] = j
|
||||
}
|
||||
|
||||
func (h *indexedHeap) Push(x interface{}) {
|
||||
h.pushInternal(x.(heapEntry)) //nolint:forcetypeassert // Forced type assertion required to implement the heap.Interface interface
|
||||
}
|
||||
|
||||
func (h *indexedHeap) Pop() interface{} {
|
||||
n := len(h.entries)
|
||||
h.entries = h.entries[0 : n-1]
|
||||
return h.entries[0:n][n-1]
|
||||
}
|
||||
|
||||
func (h *indexedHeap) pushInternal(entry heapEntry) {
|
||||
h.indices[entry.idx] = len(h.entries)
|
||||
h.entries = append(h.entries, entry)
|
||||
}
|
||||
|
||||
// Returns index to track entry
|
||||
func (h *indexedHeap) put(key string, exp uint64, bytes uint) int {
|
||||
idx := 0
|
||||
if len(h.entries) < h.maxidx {
|
||||
// Steal index from previously removed entry
|
||||
// capacity > size is guaranteed
|
||||
n := len(h.entries)
|
||||
idx = h.entries[:n+1][n].idx
|
||||
} else {
|
||||
idx = h.maxidx
|
||||
h.maxidx++
|
||||
h.indices = append(h.indices, idx)
|
||||
}
|
||||
// Push manually to avoid allocation
|
||||
h.pushInternal(heapEntry{
|
||||
key: key, exp: exp, idx: idx, bytes: bytes,
|
||||
})
|
||||
heap.Fix(h, h.Len()-1)
|
||||
return idx
|
||||
}
|
||||
|
||||
func (h *indexedHeap) removeInternal(realIdx int) (string, uint) {
|
||||
x := heap.Remove(h, realIdx).(heapEntry) //nolint:forcetypeassert,errcheck // Forced type assertion required to implement the heap.Interface interface
|
||||
return x.key, x.bytes
|
||||
}
|
||||
|
||||
// Remove entry by index
|
||||
func (h *indexedHeap) remove(idx int) (string, uint) {
|
||||
return h.removeInternal(h.indices[idx])
|
||||
}
|
||||
|
||||
// Remove entry with lowest expiration time
|
||||
func (h *indexedHeap) removeFirst() (string, uint) {
|
||||
return h.removeInternal(0)
|
||||
}
|
132
middleware/cache/manager.go
vendored
Normal file
132
middleware/cache/manager.go
vendored
Normal file
|
@ -0,0 +1,132 @@
|
|||
package cache
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/gofiber/fiber/v2/internal/memory"
|
||||
)
|
||||
|
||||
// go:generate msgp
|
||||
// msgp -file="manager.go" -o="manager_msgp.go" -tests=false -unexported
|
||||
type item struct {
|
||||
body []byte
|
||||
ctype []byte
|
||||
cencoding []byte
|
||||
status int
|
||||
exp uint64
|
||||
headers map[string][]byte
|
||||
// used for finding the item in an indexed heap
|
||||
heapidx int
|
||||
}
|
||||
|
||||
//msgp:ignore manager
|
||||
type manager struct {
|
||||
pool sync.Pool
|
||||
memory *memory.Storage
|
||||
storage fiber.Storage
|
||||
}
|
||||
|
||||
func newManager(storage fiber.Storage) *manager {
|
||||
// Create new storage handler
|
||||
manager := &manager{
|
||||
pool: sync.Pool{
|
||||
New: func() interface{} {
|
||||
return new(item)
|
||||
},
|
||||
},
|
||||
}
|
||||
if storage != nil {
|
||||
// Use provided storage if provided
|
||||
manager.storage = storage
|
||||
} else {
|
||||
// Fallback to memory storage
|
||||
manager.memory = memory.New()
|
||||
}
|
||||
return manager
|
||||
}
|
||||
|
||||
// acquire returns an *entry from the sync.Pool
|
||||
func (m *manager) acquire() *item {
|
||||
return m.pool.Get().(*item) //nolint:forcetypeassert // We store nothing else in the pool
|
||||
}
|
||||
|
||||
// release and reset *entry to sync.Pool
|
||||
func (m *manager) release(e *item) {
|
||||
// don't release item if we using memory storage
|
||||
if m.storage != nil {
|
||||
return
|
||||
}
|
||||
e.body = nil
|
||||
e.ctype = nil
|
||||
e.status = 0
|
||||
e.exp = 0
|
||||
e.headers = nil
|
||||
m.pool.Put(e)
|
||||
}
|
||||
|
||||
// get data from storage or memory
|
||||
func (m *manager) get(key string) *item {
|
||||
var it *item
|
||||
if m.storage != nil {
|
||||
it = m.acquire()
|
||||
raw, err := m.storage.Get(key)
|
||||
if err != nil {
|
||||
return it
|
||||
}
|
||||
if raw != nil {
|
||||
if _, err := it.UnmarshalMsg(raw); err != nil {
|
||||
return it
|
||||
}
|
||||
}
|
||||
return it
|
||||
}
|
||||
if it, _ = m.memory.Get(key).(*item); it == nil { //nolint:errcheck // We store nothing else in the pool
|
||||
it = m.acquire()
|
||||
return it
|
||||
}
|
||||
return it
|
||||
}
|
||||
|
||||
// get raw data from storage or memory
|
||||
func (m *manager) getRaw(key string) []byte {
|
||||
var raw []byte
|
||||
if m.storage != nil {
|
||||
raw, _ = m.storage.Get(key) //nolint:errcheck // TODO: Handle error here
|
||||
} else {
|
||||
raw, _ = m.memory.Get(key).([]byte) //nolint:errcheck // TODO: Handle error here
|
||||
}
|
||||
return raw
|
||||
}
|
||||
|
||||
// set data to storage or memory
|
||||
func (m *manager) set(key string, it *item, exp time.Duration) {
|
||||
if m.storage != nil {
|
||||
if raw, err := it.MarshalMsg(nil); err == nil {
|
||||
_ = m.storage.Set(key, raw, exp) //nolint:errcheck // TODO: Handle error here
|
||||
}
|
||||
// we can release data because it's serialized to database
|
||||
m.release(it)
|
||||
} else {
|
||||
m.memory.Set(key, it, exp)
|
||||
}
|
||||
}
|
||||
|
||||
// set data to storage or memory
|
||||
func (m *manager) setRaw(key string, raw []byte, exp time.Duration) {
|
||||
if m.storage != nil {
|
||||
_ = m.storage.Set(key, raw, exp) //nolint:errcheck // TODO: Handle error here
|
||||
} else {
|
||||
m.memory.Set(key, raw, exp)
|
||||
}
|
||||
}
|
||||
|
||||
// delete data from storage or memory
|
||||
func (m *manager) del(key string) {
|
||||
if m.storage != nil {
|
||||
_ = m.storage.Delete(key) //nolint:errcheck // TODO: Handle error here
|
||||
} else {
|
||||
m.memory.Delete(key)
|
||||
}
|
||||
}
|
300
middleware/cache/manager_msgp.go
vendored
Normal file
300
middleware/cache/manager_msgp.go
vendored
Normal file
|
@ -0,0 +1,300 @@
|
|||
package cache
|
||||
|
||||
// NOTE: THIS FILE WAS PRODUCED BY THE
|
||||
// MSGP CODE GENERATION TOOL (github.com/tinylib/msgp)
|
||||
// DO NOT EDIT
|
||||
|
||||
import (
|
||||
"github.com/tinylib/msgp/msgp"
|
||||
)
|
||||
|
||||
// DecodeMsg implements msgp.Decodable
|
||||
func (z *item) DecodeMsg(dc *msgp.Reader) (err error) {
|
||||
var field []byte
|
||||
_ = field
|
||||
var zbai uint32
|
||||
zbai, err = dc.ReadMapHeader()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
for zbai > 0 {
|
||||
zbai--
|
||||
field, err = dc.ReadMapKeyPtr()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
switch msgp.UnsafeString(field) {
|
||||
case "body":
|
||||
z.body, err = dc.ReadBytes(z.body)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
case "ctype":
|
||||
z.ctype, err = dc.ReadBytes(z.ctype)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
case "cencoding":
|
||||
z.cencoding, err = dc.ReadBytes(z.cencoding)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
case "status":
|
||||
z.status, err = dc.ReadInt()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
case "exp":
|
||||
z.exp, err = dc.ReadUint64()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
case "headers":
|
||||
var zcmr uint32
|
||||
zcmr, err = dc.ReadMapHeader()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if z.headers == nil && zcmr > 0 {
|
||||
z.headers = make(map[string][]byte, zcmr)
|
||||
} else if len(z.headers) > 0 {
|
||||
for key := range z.headers {
|
||||
delete(z.headers, key)
|
||||
}
|
||||
}
|
||||
for zcmr > 0 {
|
||||
zcmr--
|
||||
var zxvk string
|
||||
var zbzg []byte
|
||||
zxvk, err = dc.ReadString()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
zbzg, err = dc.ReadBytes(zbzg)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
z.headers[zxvk] = zbzg
|
||||
}
|
||||
case "heapidx":
|
||||
z.heapidx, err = dc.ReadInt()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
default:
|
||||
err = dc.Skip()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// EncodeMsg implements msgp.Encodable
|
||||
func (z *item) EncodeMsg(en *msgp.Writer) (err error) {
|
||||
// map header, size 7
|
||||
// write "body"
|
||||
err = en.Append(0x87, 0xa4, 0x62, 0x6f, 0x64, 0x79)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = en.WriteBytes(z.body)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
// write "ctype"
|
||||
err = en.Append(0xa5, 0x63, 0x74, 0x79, 0x70, 0x65)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = en.WriteBytes(z.ctype)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
// write "cencoding"
|
||||
err = en.Append(0xa9, 0x63, 0x65, 0x6e, 0x63, 0x6f, 0x64, 0x69, 0x6e, 0x67)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = en.WriteBytes(z.cencoding)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
// write "status"
|
||||
err = en.Append(0xa6, 0x73, 0x74, 0x61, 0x74, 0x75, 0x73)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = en.WriteInt(z.status)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
// write "exp"
|
||||
err = en.Append(0xa3, 0x65, 0x78, 0x70)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = en.WriteUint64(z.exp)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
// write "headers"
|
||||
err = en.Append(0xa7, 0x68, 0x65, 0x61, 0x64, 0x65, 0x72, 0x73)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = en.WriteMapHeader(uint32(len(z.headers)))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
for zxvk, zbzg := range z.headers {
|
||||
err = en.WriteString(zxvk)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
err = en.WriteBytes(zbzg)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
// write "heapidx"
|
||||
err = en.Append(0xa7, 0x68, 0x65, 0x61, 0x70, 0x69, 0x64, 0x78)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = en.WriteInt(z.heapidx)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// MarshalMsg implements msgp.Marshaler
|
||||
func (z *item) MarshalMsg(b []byte) (o []byte, err error) {
|
||||
o = msgp.Require(b, z.Msgsize())
|
||||
// map header, size 7
|
||||
// string "body"
|
||||
o = append(o, 0x87, 0xa4, 0x62, 0x6f, 0x64, 0x79)
|
||||
o = msgp.AppendBytes(o, z.body)
|
||||
// string "ctype"
|
||||
o = append(o, 0xa5, 0x63, 0x74, 0x79, 0x70, 0x65)
|
||||
o = msgp.AppendBytes(o, z.ctype)
|
||||
// string "cencoding"
|
||||
o = append(o, 0xa9, 0x63, 0x65, 0x6e, 0x63, 0x6f, 0x64, 0x69, 0x6e, 0x67)
|
||||
o = msgp.AppendBytes(o, z.cencoding)
|
||||
// string "status"
|
||||
o = append(o, 0xa6, 0x73, 0x74, 0x61, 0x74, 0x75, 0x73)
|
||||
o = msgp.AppendInt(o, z.status)
|
||||
// string "exp"
|
||||
o = append(o, 0xa3, 0x65, 0x78, 0x70)
|
||||
o = msgp.AppendUint64(o, z.exp)
|
||||
// string "headers"
|
||||
o = append(o, 0xa7, 0x68, 0x65, 0x61, 0x64, 0x65, 0x72, 0x73)
|
||||
o = msgp.AppendMapHeader(o, uint32(len(z.headers)))
|
||||
for zxvk, zbzg := range z.headers {
|
||||
o = msgp.AppendString(o, zxvk)
|
||||
o = msgp.AppendBytes(o, zbzg)
|
||||
}
|
||||
// string "heapidx"
|
||||
o = append(o, 0xa7, 0x68, 0x65, 0x61, 0x70, 0x69, 0x64, 0x78)
|
||||
o = msgp.AppendInt(o, z.heapidx)
|
||||
return
|
||||
}
|
||||
|
||||
// UnmarshalMsg implements msgp.Unmarshaler
|
||||
func (z *item) UnmarshalMsg(bts []byte) (o []byte, err error) {
|
||||
var field []byte
|
||||
_ = field
|
||||
var zajw uint32
|
||||
zajw, bts, err = msgp.ReadMapHeaderBytes(bts)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
for zajw > 0 {
|
||||
zajw--
|
||||
field, bts, err = msgp.ReadMapKeyZC(bts)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
switch msgp.UnsafeString(field) {
|
||||
case "body":
|
||||
z.body, bts, err = msgp.ReadBytesBytes(bts, z.body)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
case "ctype":
|
||||
z.ctype, bts, err = msgp.ReadBytesBytes(bts, z.ctype)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
case "cencoding":
|
||||
z.cencoding, bts, err = msgp.ReadBytesBytes(bts, z.cencoding)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
case "status":
|
||||
z.status, bts, err = msgp.ReadIntBytes(bts)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
case "exp":
|
||||
z.exp, bts, err = msgp.ReadUint64Bytes(bts)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
case "headers":
|
||||
var zwht uint32
|
||||
zwht, bts, err = msgp.ReadMapHeaderBytes(bts)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if z.headers == nil && zwht > 0 {
|
||||
z.headers = make(map[string][]byte, zwht)
|
||||
} else if len(z.headers) > 0 {
|
||||
for key := range z.headers {
|
||||
delete(z.headers, key)
|
||||
}
|
||||
}
|
||||
for zwht > 0 {
|
||||
var zxvk string
|
||||
var zbzg []byte
|
||||
zwht--
|
||||
zxvk, bts, err = msgp.ReadStringBytes(bts)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
zbzg, bts, err = msgp.ReadBytesBytes(bts, zbzg)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
z.headers[zxvk] = zbzg
|
||||
}
|
||||
case "heapidx":
|
||||
z.heapidx, bts, err = msgp.ReadIntBytes(bts)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
default:
|
||||
bts, err = msgp.Skip(bts)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
o = bts
|
||||
return
|
||||
}
|
||||
|
||||
// Msgsize returns an upper bound estimate of the number of bytes occupied by the serialized message
|
||||
func (z *item) Msgsize() (s int) {
|
||||
s = 1 + 5 + msgp.BytesPrefixSize + len(z.body) + 6 + msgp.BytesPrefixSize + len(z.ctype) + 10 + msgp.BytesPrefixSize + len(z.cencoding) + 7 + msgp.IntSize + 4 + msgp.Uint64Size + 8 + msgp.MapHeaderSize
|
||||
if z.headers != nil {
|
||||
for zxvk, zbzg := range z.headers {
|
||||
_ = zbzg
|
||||
s += msgp.StringPrefixSize + len(zxvk) + msgp.BytesPrefixSize + len(zbzg)
|
||||
}
|
||||
}
|
||||
s += 8 + msgp.IntSize
|
||||
return
|
||||
}
|
65
middleware/compress/compress.go
Normal file
65
middleware/compress/compress.go
Normal file
|
@ -0,0 +1,65 @@
|
|||
package compress
|
||||
|
||||
import (
|
||||
"github.com/gofiber/fiber/v2"
|
||||
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
// New creates a new middleware handler
|
||||
func New(config ...Config) fiber.Handler {
|
||||
// Set default config
|
||||
cfg := configDefault(config...)
|
||||
|
||||
// Setup request handlers
|
||||
var (
|
||||
fctx = func(c *fasthttp.RequestCtx) {}
|
||||
compressor fasthttp.RequestHandler
|
||||
)
|
||||
|
||||
// Setup compression algorithm
|
||||
switch cfg.Level {
|
||||
case LevelDefault:
|
||||
// LevelDefault
|
||||
compressor = fasthttp.CompressHandlerBrotliLevel(fctx,
|
||||
fasthttp.CompressBrotliDefaultCompression,
|
||||
fasthttp.CompressDefaultCompression,
|
||||
)
|
||||
case LevelBestSpeed:
|
||||
// LevelBestSpeed
|
||||
compressor = fasthttp.CompressHandlerBrotliLevel(fctx,
|
||||
fasthttp.CompressBrotliBestSpeed,
|
||||
fasthttp.CompressBestSpeed,
|
||||
)
|
||||
case LevelBestCompression:
|
||||
// LevelBestCompression
|
||||
compressor = fasthttp.CompressHandlerBrotliLevel(fctx,
|
||||
fasthttp.CompressBrotliBestCompression,
|
||||
fasthttp.CompressBestCompression,
|
||||
)
|
||||
default:
|
||||
// LevelDisabled
|
||||
return func(c *fiber.Ctx) error {
|
||||
return c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// Return new handler
|
||||
return func(c *fiber.Ctx) error {
|
||||
// Don't execute middleware if Next returns true
|
||||
if cfg.Next != nil && cfg.Next(c) {
|
||||
return c.Next()
|
||||
}
|
||||
|
||||
// Continue stack
|
||||
if err := c.Next(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Compress response
|
||||
compressor(c.Context())
|
||||
|
||||
// Return from handler
|
||||
return nil
|
||||
}
|
||||
}
|
192
middleware/compress/compress_test.go
Normal file
192
middleware/compress/compress_test.go
Normal file
|
@ -0,0 +1,192 @@
|
|||
package compress
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/gofiber/fiber/v2/utils"
|
||||
)
|
||||
|
||||
var filedata []byte
|
||||
|
||||
func init() {
|
||||
dat, err := os.ReadFile("../../.github/README.md")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
filedata = dat
|
||||
}
|
||||
|
||||
// go test -run Test_Compress_Gzip
|
||||
func Test_Compress_Gzip(t *testing.T) {
|
||||
t.Parallel()
|
||||
app := fiber.New()
|
||||
|
||||
app.Use(New())
|
||||
|
||||
app.Get("/", func(c *fiber.Ctx) error {
|
||||
c.Set(fiber.HeaderContentType, fiber.MIMETextPlainCharsetUTF8)
|
||||
return c.Send(filedata)
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(fiber.MethodGet, "/", nil)
|
||||
req.Header.Set("Accept-Encoding", "gzip")
|
||||
|
||||
resp, err := app.Test(req)
|
||||
utils.AssertEqual(t, nil, err, "app.Test(req)")
|
||||
utils.AssertEqual(t, 200, resp.StatusCode, "Status code")
|
||||
utils.AssertEqual(t, "gzip", resp.Header.Get(fiber.HeaderContentEncoding))
|
||||
|
||||
// Validate that the file size has shrunk
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, true, len(body) < len(filedata))
|
||||
}
|
||||
|
||||
// go test -run Test_Compress_Different_Level
|
||||
func Test_Compress_Different_Level(t *testing.T) {
|
||||
t.Parallel()
|
||||
levels := []Level{LevelBestSpeed, LevelBestCompression}
|
||||
for _, level := range levels {
|
||||
level := level
|
||||
t.Run(fmt.Sprintf("level %d", level), func(t *testing.T) {
|
||||
t.Parallel()
|
||||
app := fiber.New()
|
||||
|
||||
app.Use(New(Config{Level: level}))
|
||||
|
||||
app.Get("/", func(c *fiber.Ctx) error {
|
||||
c.Set(fiber.HeaderContentType, fiber.MIMETextPlainCharsetUTF8)
|
||||
return c.Send(filedata)
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(fiber.MethodGet, "/", nil)
|
||||
req.Header.Set("Accept-Encoding", "gzip")
|
||||
|
||||
resp, err := app.Test(req)
|
||||
utils.AssertEqual(t, nil, err, "app.Test(req)")
|
||||
utils.AssertEqual(t, 200, resp.StatusCode, "Status code")
|
||||
utils.AssertEqual(t, "gzip", resp.Header.Get(fiber.HeaderContentEncoding))
|
||||
|
||||
// Validate that the file size has shrunk
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, true, len(body) < len(filedata))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_Compress_Deflate(t *testing.T) {
|
||||
t.Parallel()
|
||||
app := fiber.New()
|
||||
|
||||
app.Use(New())
|
||||
|
||||
app.Get("/", func(c *fiber.Ctx) error {
|
||||
return c.Send(filedata)
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(fiber.MethodGet, "/", nil)
|
||||
req.Header.Set("Accept-Encoding", "deflate")
|
||||
|
||||
resp, err := app.Test(req)
|
||||
utils.AssertEqual(t, nil, err, "app.Test(req)")
|
||||
utils.AssertEqual(t, 200, resp.StatusCode, "Status code")
|
||||
utils.AssertEqual(t, "deflate", resp.Header.Get(fiber.HeaderContentEncoding))
|
||||
|
||||
// Validate that the file size has shrunk
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, true, len(body) < len(filedata))
|
||||
}
|
||||
|
||||
func Test_Compress_Brotli(t *testing.T) {
|
||||
t.Parallel()
|
||||
app := fiber.New()
|
||||
|
||||
app.Use(New())
|
||||
|
||||
app.Get("/", func(c *fiber.Ctx) error {
|
||||
return c.Send(filedata)
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(fiber.MethodGet, "/", nil)
|
||||
req.Header.Set("Accept-Encoding", "br")
|
||||
|
||||
resp, err := app.Test(req, 10000)
|
||||
utils.AssertEqual(t, nil, err, "app.Test(req)")
|
||||
utils.AssertEqual(t, 200, resp.StatusCode, "Status code")
|
||||
utils.AssertEqual(t, "br", resp.Header.Get(fiber.HeaderContentEncoding))
|
||||
|
||||
// Validate that the file size has shrunk
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, true, len(body) < len(filedata))
|
||||
}
|
||||
|
||||
func Test_Compress_Disabled(t *testing.T) {
|
||||
t.Parallel()
|
||||
app := fiber.New()
|
||||
|
||||
app.Use(New(Config{Level: LevelDisabled}))
|
||||
|
||||
app.Get("/", func(c *fiber.Ctx) error {
|
||||
return c.Send(filedata)
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(fiber.MethodGet, "/", nil)
|
||||
req.Header.Set("Accept-Encoding", "br")
|
||||
|
||||
resp, err := app.Test(req)
|
||||
utils.AssertEqual(t, nil, err, "app.Test(req)")
|
||||
utils.AssertEqual(t, 200, resp.StatusCode, "Status code")
|
||||
utils.AssertEqual(t, "", resp.Header.Get(fiber.HeaderContentEncoding))
|
||||
|
||||
// Validate the file size is not shrunk
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, true, len(body) == len(filedata))
|
||||
}
|
||||
|
||||
func Test_Compress_Next_Error(t *testing.T) {
|
||||
t.Parallel()
|
||||
app := fiber.New()
|
||||
|
||||
app.Use(New())
|
||||
|
||||
app.Get("/", func(c *fiber.Ctx) error {
|
||||
return errors.New("next error")
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(fiber.MethodGet, "/", nil)
|
||||
req.Header.Set("Accept-Encoding", "gzip")
|
||||
|
||||
resp, err := app.Test(req)
|
||||
utils.AssertEqual(t, nil, err, "app.Test(req)")
|
||||
utils.AssertEqual(t, 500, resp.StatusCode, "Status code")
|
||||
utils.AssertEqual(t, "", resp.Header.Get(fiber.HeaderContentEncoding))
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, "next error", string(body))
|
||||
}
|
||||
|
||||
// go test -run Test_Compress_Next
|
||||
func Test_Compress_Next(t *testing.T) {
|
||||
t.Parallel()
|
||||
app := fiber.New()
|
||||
app.Use(New(Config{
|
||||
Next: func(_ *fiber.Ctx) bool {
|
||||
return true
|
||||
},
|
||||
}))
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, fiber.StatusNotFound, resp.StatusCode)
|
||||
}
|
56
middleware/compress/config.go
Normal file
56
middleware/compress/config.go
Normal file
|
@ -0,0 +1,56 @@
|
|||
package compress
|
||||
|
||||
import (
|
||||
"github.com/gofiber/fiber/v2"
|
||||
)
|
||||
|
||||
// Config defines the config for middleware.
|
||||
type Config struct {
|
||||
// Next defines a function to skip this middleware when returned true.
|
||||
//
|
||||
// Optional. Default: nil
|
||||
Next func(c *fiber.Ctx) bool
|
||||
|
||||
// Level determines the compression algorithm
|
||||
//
|
||||
// Optional. Default: LevelDefault
|
||||
// LevelDisabled: -1
|
||||
// LevelDefault: 0
|
||||
// LevelBestSpeed: 1
|
||||
// LevelBestCompression: 2
|
||||
Level Level
|
||||
}
|
||||
|
||||
// Level is numeric representation of compression level
|
||||
type Level int
|
||||
|
||||
// Represents compression level that will be used in the middleware
|
||||
const (
|
||||
LevelDisabled Level = -1
|
||||
LevelDefault Level = 0
|
||||
LevelBestSpeed Level = 1
|
||||
LevelBestCompression Level = 2
|
||||
)
|
||||
|
||||
// ConfigDefault is the default config
|
||||
var ConfigDefault = Config{
|
||||
Next: nil,
|
||||
Level: LevelDefault,
|
||||
}
|
||||
|
||||
// Helper function to set default values
|
||||
func configDefault(config ...Config) Config {
|
||||
// Return default config if nothing provided
|
||||
if len(config) < 1 {
|
||||
return ConfigDefault
|
||||
}
|
||||
|
||||
// Override default config
|
||||
cfg := config[0]
|
||||
|
||||
// Set default values
|
||||
if cfg.Level < LevelDisabled || cfg.Level > LevelBestCompression {
|
||||
cfg.Level = ConfigDefault.Level
|
||||
}
|
||||
return cfg
|
||||
}
|
289
middleware/cors/cors.go
Normal file
289
middleware/cors/cors.go
Normal file
|
@ -0,0 +1,289 @@
|
|||
package cors
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/gofiber/fiber/v2/log"
|
||||
)
|
||||
|
||||
// Config defines the config for middleware.
|
||||
type Config struct {
|
||||
// Next defines a function to skip this middleware when returned true.
|
||||
//
|
||||
// Optional. Default: nil
|
||||
Next func(c *fiber.Ctx) bool
|
||||
|
||||
// AllowOriginsFunc defines a function that will set the 'Access-Control-Allow-Origin'
|
||||
// response header to the 'origin' request header when returned true. This allows for
|
||||
// dynamic evaluation of allowed origins. Note if AllowCredentials is true, wildcard origins
|
||||
// will be not have the 'Access-Control-Allow-Credentials' header set to 'true'.
|
||||
//
|
||||
// Optional. Default: nil
|
||||
AllowOriginsFunc func(origin string) bool
|
||||
|
||||
// AllowOrigin defines a comma separated list of origins that may access the resource.
|
||||
//
|
||||
// Optional. Default value "*"
|
||||
AllowOrigins string
|
||||
|
||||
// AllowMethods defines a list methods allowed when accessing the resource.
|
||||
// This is used in response to a preflight request.
|
||||
//
|
||||
// Optional. Default value "GET,POST,HEAD,PUT,DELETE,PATCH"
|
||||
AllowMethods string
|
||||
|
||||
// AllowHeaders defines a list of request headers that can be used when
|
||||
// making the actual request. This is in response to a preflight request.
|
||||
//
|
||||
// Optional. Default value "".
|
||||
AllowHeaders string
|
||||
|
||||
// AllowCredentials indicates whether or not the response to the request
|
||||
// can be exposed when the credentials flag is true. When used as part of
|
||||
// a response to a preflight request, this indicates whether or not the
|
||||
// actual request can be made using credentials. Note: If true, AllowOrigins
|
||||
// cannot be set to a wildcard ("*") to prevent security vulnerabilities.
|
||||
//
|
||||
// Optional. Default value false.
|
||||
AllowCredentials bool
|
||||
|
||||
// ExposeHeaders defines a whitelist headers that clients are allowed to
|
||||
// access.
|
||||
//
|
||||
// Optional. Default value "".
|
||||
ExposeHeaders string
|
||||
|
||||
// MaxAge indicates how long (in seconds) the results of a preflight request
|
||||
// can be cached.
|
||||
// If you pass MaxAge 0, Access-Control-Max-Age header will not be added and
|
||||
// browser will use 5 seconds by default.
|
||||
// To disable caching completely, pass MaxAge value negative. It will set the Access-Control-Max-Age header 0.
|
||||
//
|
||||
// Optional. Default value 0.
|
||||
MaxAge int
|
||||
}
|
||||
|
||||
// ConfigDefault is the default config
|
||||
var ConfigDefault = Config{
|
||||
Next: nil,
|
||||
AllowOriginsFunc: nil,
|
||||
AllowOrigins: "*",
|
||||
AllowMethods: strings.Join([]string{
|
||||
fiber.MethodGet,
|
||||
fiber.MethodPost,
|
||||
fiber.MethodHead,
|
||||
fiber.MethodPut,
|
||||
fiber.MethodDelete,
|
||||
fiber.MethodPatch,
|
||||
}, ","),
|
||||
AllowHeaders: "",
|
||||
AllowCredentials: false,
|
||||
ExposeHeaders: "",
|
||||
MaxAge: 0,
|
||||
}
|
||||
|
||||
// New creates a new middleware handler
|
||||
func New(config ...Config) fiber.Handler {
|
||||
// Set default config
|
||||
cfg := ConfigDefault
|
||||
|
||||
// Override config if provided
|
||||
if len(config) > 0 {
|
||||
cfg = config[0]
|
||||
|
||||
// Set default values
|
||||
if cfg.AllowMethods == "" {
|
||||
cfg.AllowMethods = ConfigDefault.AllowMethods
|
||||
}
|
||||
// When none of the AllowOrigins or AllowOriginsFunc config was defined, set the default AllowOrigins value with "*"
|
||||
if cfg.AllowOrigins == "" && cfg.AllowOriginsFunc == nil {
|
||||
cfg.AllowOrigins = ConfigDefault.AllowOrigins
|
||||
}
|
||||
}
|
||||
|
||||
// Warning logs if both AllowOrigins and AllowOriginsFunc are set
|
||||
if cfg.AllowOrigins != "" && cfg.AllowOriginsFunc != nil {
|
||||
log.Warn("[CORS] Both 'AllowOrigins' and 'AllowOriginsFunc' have been defined.")
|
||||
}
|
||||
|
||||
// Validate CORS credentials configuration
|
||||
if cfg.AllowCredentials && cfg.AllowOrigins == "*" {
|
||||
panic("[CORS] Insecure setup, 'AllowCredentials' is set to true, and 'AllowOrigins' is set to a wildcard.")
|
||||
}
|
||||
|
||||
// allowOrigins is a slice of strings that contains the allowed origins
|
||||
// defined in the 'AllowOrigins' configuration.
|
||||
allowOrigins := []string{}
|
||||
allowSOrigins := []subdomain{}
|
||||
allowAllOrigins := false
|
||||
|
||||
// Validate and normalize static AllowOrigins
|
||||
if cfg.AllowOrigins != "" && cfg.AllowOrigins != "*" {
|
||||
origins := strings.Split(cfg.AllowOrigins, ",")
|
||||
for _, origin := range origins {
|
||||
if i := strings.Index(origin, "://*."); i != -1 {
|
||||
trimmedOrigin := strings.TrimSpace(origin[:i+3] + origin[i+4:])
|
||||
isValid, normalizedOrigin := normalizeOrigin(trimmedOrigin)
|
||||
if !isValid {
|
||||
panic("[CORS] Invalid origin format in configuration: " + trimmedOrigin)
|
||||
}
|
||||
sd := subdomain{prefix: normalizedOrigin[:i+3], suffix: normalizedOrigin[i+3:]}
|
||||
allowSOrigins = append(allowSOrigins, sd)
|
||||
} else {
|
||||
trimmedOrigin := strings.TrimSpace(origin)
|
||||
isValid, normalizedOrigin := normalizeOrigin(trimmedOrigin)
|
||||
if !isValid {
|
||||
panic("[CORS] Invalid origin format in configuration: " + trimmedOrigin)
|
||||
}
|
||||
allowOrigins = append(allowOrigins, normalizedOrigin)
|
||||
}
|
||||
}
|
||||
} else if cfg.AllowOrigins == "*" {
|
||||
allowAllOrigins = true
|
||||
}
|
||||
|
||||
// Strip white spaces
|
||||
allowMethods := strings.ReplaceAll(cfg.AllowMethods, " ", "")
|
||||
allowHeaders := strings.ReplaceAll(cfg.AllowHeaders, " ", "")
|
||||
exposeHeaders := strings.ReplaceAll(cfg.ExposeHeaders, " ", "")
|
||||
|
||||
// Convert int to string
|
||||
maxAge := strconv.Itoa(cfg.MaxAge)
|
||||
|
||||
// Return new handler
|
||||
return func(c *fiber.Ctx) error {
|
||||
// Don't execute middleware if Next returns true
|
||||
if cfg.Next != nil && cfg.Next(c) {
|
||||
return c.Next()
|
||||
}
|
||||
|
||||
// Get originHeader header
|
||||
originHeader := strings.ToLower(c.Get(fiber.HeaderOrigin))
|
||||
|
||||
// If the request does not have Origin header, the request is outside the scope of CORS
|
||||
if originHeader == "" {
|
||||
// See https://fetch.spec.whatwg.org/#cors-protocol-and-http-caches
|
||||
// Unless all origins are allowed, we include the Vary header to cache the response correctly
|
||||
if !allowAllOrigins {
|
||||
c.Vary(fiber.HeaderOrigin)
|
||||
}
|
||||
|
||||
return c.Next()
|
||||
}
|
||||
|
||||
// If it's a preflight request and doesn't have Access-Control-Request-Method header, it's outside the scope of CORS
|
||||
if c.Method() == fiber.MethodOptions && c.Get(fiber.HeaderAccessControlRequestMethod) == "" {
|
||||
// Response to OPTIONS request should not be cached but,
|
||||
// some caching can be configured to cache such responses.
|
||||
// To Avoid poisoning the cache, we include the Vary header
|
||||
// for non-CORS OPTIONS requests:
|
||||
c.Vary(fiber.HeaderOrigin)
|
||||
return c.Next()
|
||||
}
|
||||
|
||||
// Set default allowOrigin to empty string
|
||||
allowOrigin := ""
|
||||
|
||||
// Check allowed origins
|
||||
if allowAllOrigins {
|
||||
allowOrigin = "*"
|
||||
} else {
|
||||
// Check if the origin is in the list of allowed origins
|
||||
for _, origin := range allowOrigins {
|
||||
if origin == originHeader {
|
||||
allowOrigin = originHeader
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// Check if the origin is in the list of allowed subdomains
|
||||
if allowOrigin == "" {
|
||||
for _, sOrigin := range allowSOrigins {
|
||||
if sOrigin.match(originHeader) {
|
||||
allowOrigin = originHeader
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Run AllowOriginsFunc if the logic for
|
||||
// handling the value in 'AllowOrigins' does
|
||||
// not result in allowOrigin being set.
|
||||
if allowOrigin == "" && cfg.AllowOriginsFunc != nil && cfg.AllowOriginsFunc(originHeader) {
|
||||
allowOrigin = originHeader
|
||||
}
|
||||
|
||||
// Simple request
|
||||
// Ommit allowMethods and allowHeaders, only used for pre-flight requests
|
||||
if c.Method() != fiber.MethodOptions {
|
||||
if !allowAllOrigins {
|
||||
// See https://fetch.spec.whatwg.org/#cors-protocol-and-http-caches
|
||||
c.Vary(fiber.HeaderOrigin)
|
||||
}
|
||||
setCORSHeaders(c, allowOrigin, "", "", exposeHeaders, maxAge, cfg)
|
||||
return c.Next()
|
||||
}
|
||||
|
||||
// Pre-flight request
|
||||
|
||||
// Response to OPTIONS request should not be cached but,
|
||||
// some caching can be configured to cache such responses.
|
||||
// To Avoid poisoning the cache, we include the Vary header
|
||||
// of preflight responses:
|
||||
c.Vary(fiber.HeaderAccessControlRequestMethod)
|
||||
c.Vary(fiber.HeaderAccessControlRequestHeaders)
|
||||
c.Vary(fiber.HeaderOrigin)
|
||||
|
||||
setCORSHeaders(c, allowOrigin, allowMethods, allowHeaders, exposeHeaders, maxAge, cfg)
|
||||
|
||||
// Send 204 No Content
|
||||
return c.SendStatus(fiber.StatusNoContent)
|
||||
}
|
||||
}
|
||||
|
||||
// Function to set CORS headers
|
||||
func setCORSHeaders(c *fiber.Ctx, allowOrigin, allowMethods, allowHeaders, exposeHeaders, maxAge string, cfg Config) {
|
||||
if cfg.AllowCredentials {
|
||||
// When AllowCredentials is true, set the Access-Control-Allow-Origin to the specific origin instead of '*'
|
||||
if allowOrigin == "*" {
|
||||
c.Set(fiber.HeaderAccessControlAllowOrigin, allowOrigin)
|
||||
log.Warn("[CORS] 'AllowCredentials' is true, but 'AllowOrigins' cannot be set to '*'.")
|
||||
} else if allowOrigin != "" {
|
||||
c.Set(fiber.HeaderAccessControlAllowOrigin, allowOrigin)
|
||||
c.Set(fiber.HeaderAccessControlAllowCredentials, "true")
|
||||
}
|
||||
} else if allowOrigin != "" {
|
||||
// For non-credential requests, it's safe to set to '*' or specific origins
|
||||
c.Set(fiber.HeaderAccessControlAllowOrigin, allowOrigin)
|
||||
}
|
||||
|
||||
// Set Allow-Methods if not empty
|
||||
if allowMethods != "" {
|
||||
c.Set(fiber.HeaderAccessControlAllowMethods, allowMethods)
|
||||
}
|
||||
|
||||
// Set Allow-Headers if not empty
|
||||
if allowHeaders != "" {
|
||||
c.Set(fiber.HeaderAccessControlAllowHeaders, allowHeaders)
|
||||
} else {
|
||||
h := c.Get(fiber.HeaderAccessControlRequestHeaders)
|
||||
if h != "" {
|
||||
c.Set(fiber.HeaderAccessControlAllowHeaders, h)
|
||||
}
|
||||
}
|
||||
|
||||
// Set MaxAge if set
|
||||
if cfg.MaxAge > 0 {
|
||||
c.Set(fiber.HeaderAccessControlMaxAge, maxAge)
|
||||
} else if cfg.MaxAge < 0 {
|
||||
c.Set(fiber.HeaderAccessControlMaxAge, "0")
|
||||
}
|
||||
|
||||
// Set Expose-Headers if not empty
|
||||
if exposeHeaders != "" {
|
||||
c.Set(fiber.HeaderAccessControlExposeHeaders, exposeHeaders)
|
||||
}
|
||||
}
|
1335
middleware/cors/cors_test.go
Normal file
1335
middleware/cors/cors_test.go
Normal file
File diff suppressed because it is too large
Load diff
66
middleware/cors/utils.go
Normal file
66
middleware/cors/utils.go
Normal file
|
@ -0,0 +1,66 @@
|
|||
package cors
|
||||
|
||||
import (
|
||||
"net/url"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// matchScheme compares the scheme of the domain and pattern
|
||||
func matchScheme(domain, pattern string) bool {
|
||||
didx := strings.Index(domain, ":")
|
||||
pidx := strings.Index(pattern, ":")
|
||||
return didx != -1 && pidx != -1 && domain[:didx] == pattern[:pidx]
|
||||
}
|
||||
|
||||
// normalizeDomain removes the scheme and port from the input domain
|
||||
func normalizeDomain(input string) string {
|
||||
// Remove scheme
|
||||
input = strings.TrimPrefix(strings.TrimPrefix(input, "http://"), "https://")
|
||||
|
||||
// Find and remove port, if present
|
||||
if len(input) > 0 && input[0] != '[' {
|
||||
if portIndex := strings.Index(input, ":"); portIndex != -1 {
|
||||
input = input[:portIndex]
|
||||
}
|
||||
}
|
||||
|
||||
return input
|
||||
}
|
||||
|
||||
// normalizeOrigin checks if the provided origin is in a correct format
|
||||
// and normalizes it by removing any path or trailing slash.
|
||||
// It returns a boolean indicating whether the origin is valid
|
||||
// and the normalized origin.
|
||||
func normalizeOrigin(origin string) (bool, string) {
|
||||
parsedOrigin, err := url.Parse(origin)
|
||||
if err != nil {
|
||||
return false, ""
|
||||
}
|
||||
|
||||
// Don't allow a wildcard with a protocol
|
||||
// wildcards cannot be used within any other value. For example, the following header is not valid:
|
||||
// Access-Control-Allow-Origin: https://*
|
||||
if strings.Contains(parsedOrigin.Host, "*") {
|
||||
return false, ""
|
||||
}
|
||||
|
||||
// Validate there is a host present. The presence of a path, query, or fragment components
|
||||
// is checked, but a trailing "/" (indicative of the root) is allowed for the path and will be normalized
|
||||
if parsedOrigin.Host == "" || (parsedOrigin.Path != "" && parsedOrigin.Path != "/") || parsedOrigin.RawQuery != "" || parsedOrigin.Fragment != "" {
|
||||
return false, ""
|
||||
}
|
||||
|
||||
// Normalize the origin by constructing it from the scheme and host.
|
||||
// The path or trailing slash is not included in the normalized origin.
|
||||
return true, strings.ToLower(parsedOrigin.Scheme) + "://" + strings.ToLower(parsedOrigin.Host)
|
||||
}
|
||||
|
||||
type subdomain struct {
|
||||
// The wildcard pattern
|
||||
prefix string
|
||||
suffix string
|
||||
}
|
||||
|
||||
func (s subdomain) match(o string) bool {
|
||||
return len(o) >= len(s.prefix)+len(s.suffix) && strings.HasPrefix(o, s.prefix) && strings.HasSuffix(o, s.suffix)
|
||||
}
|
196
middleware/cors/utils_test.go
Normal file
196
middleware/cors/utils_test.go
Normal file
|
@ -0,0 +1,196 @@
|
|||
package cors
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/gofiber/fiber/v2/utils"
|
||||
)
|
||||
|
||||
// go test -run -v Test_normalizeOrigin
|
||||
func Test_normalizeOrigin(t *testing.T) {
|
||||
testCases := []struct {
|
||||
origin string
|
||||
expectedValid bool
|
||||
expectedOrigin string
|
||||
}{
|
||||
{origin: "http://example.com", expectedValid: true, expectedOrigin: "http://example.com"}, // Simple case should work.
|
||||
{origin: "http://example.com/", expectedValid: true, expectedOrigin: "http://example.com"}, // Trailing slash should be removed.
|
||||
{origin: "http://example.com:3000", expectedValid: true, expectedOrigin: "http://example.com:3000"}, // Port should be preserved.
|
||||
{origin: "http://example.com:3000/", expectedValid: true, expectedOrigin: "http://example.com:3000"}, // Trailing slash should be removed.
|
||||
{origin: "app://example.com/", expectedValid: true, expectedOrigin: "app://example.com"}, // App scheme should be accepted.
|
||||
{origin: "http://", expectedValid: false, expectedOrigin: ""}, // Invalid origin should not be accepted.
|
||||
{origin: "file:///etc/passwd", expectedValid: false, expectedOrigin: ""}, // File scheme should not be accepted.
|
||||
{origin: "https://*example.com", expectedValid: false, expectedOrigin: ""}, // Wildcard domain should not be accepted.
|
||||
{origin: "http://*.example.com", expectedValid: false, expectedOrigin: ""}, // Wildcard subdomain should not be accepted.
|
||||
{origin: "http://example.com/path", expectedValid: false, expectedOrigin: ""}, // Path should not be accepted.
|
||||
{origin: "http://example.com?query=123", expectedValid: false, expectedOrigin: ""}, // Query should not be accepted.
|
||||
{origin: "http://example.com#fragment", expectedValid: false, expectedOrigin: ""}, // Fragment should not be accepted.
|
||||
{origin: "http://localhost", expectedValid: true, expectedOrigin: "http://localhost"}, // Localhost should be accepted.
|
||||
{origin: "http://127.0.0.1", expectedValid: true, expectedOrigin: "http://127.0.0.1"}, // IPv4 address should be accepted.
|
||||
{origin: "http://[::1]", expectedValid: true, expectedOrigin: "http://[::1]"}, // IPv6 address should be accepted.
|
||||
{origin: "http://[::1]:8080", expectedValid: true, expectedOrigin: "http://[::1]:8080"}, // IPv6 address with port should be accepted.
|
||||
{origin: "http://[::1]:8080/", expectedValid: true, expectedOrigin: "http://[::1]:8080"}, // IPv6 address with port and trailing slash should be accepted.
|
||||
{origin: "http://[::1]:8080/path", expectedValid: false, expectedOrigin: ""}, // IPv6 address with port and path should not be accepted.
|
||||
{origin: "http://[::1]:8080?query=123", expectedValid: false, expectedOrigin: ""}, // IPv6 address with port and query should not be accepted.
|
||||
{origin: "http://[::1]:8080#fragment", expectedValid: false, expectedOrigin: ""}, // IPv6 address with port and fragment should not be accepted.
|
||||
{origin: "http://[::1]:8080/path?query=123#fragment", expectedValid: false, expectedOrigin: ""}, // IPv6 address with port, path, query, and fragment should not be accepted.
|
||||
{origin: "http://[::1]:8080/path?query=123#fragment/", expectedValid: false, expectedOrigin: ""}, // IPv6 address with port, path, query, fragment, and trailing slash should not be accepted.
|
||||
{origin: "http://[::1]:8080/path?query=123#fragment/invalid", expectedValid: false, expectedOrigin: ""}, // IPv6 address with port, path, query, fragment, trailing slash, and invalid segment should not be accepted.
|
||||
{origin: "http://[::1]:8080/path?query=123#fragment/invalid/", expectedValid: false, expectedOrigin: ""}, // IPv6 address with port, path, query, fragment, trailing slash, and invalid segment with trailing slash should not be accepted.
|
||||
{origin: "http://[::1]:8080/path?query=123#fragment/invalid/segment", expectedValid: false, expectedOrigin: ""}, // IPv6 address with port, path, query, fragment, trailing slash, and invalid segment with additional segment should not be accepted.
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
valid, normalizedOrigin := normalizeOrigin(tc.origin)
|
||||
|
||||
if valid != tc.expectedValid {
|
||||
t.Errorf("Expected origin '%s' to be valid: %v, but got: %v", tc.origin, tc.expectedValid, valid)
|
||||
}
|
||||
|
||||
if normalizedOrigin != tc.expectedOrigin {
|
||||
t.Errorf("Expected normalized origin '%s' for origin '%s', but got: '%s'", tc.expectedOrigin, tc.origin, normalizedOrigin)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// go test -run -v Test_matchScheme
|
||||
func Test_matchScheme(t *testing.T) {
|
||||
testCases := []struct {
|
||||
domain string
|
||||
pattern string
|
||||
expected bool
|
||||
}{
|
||||
{"http://example.com", "http://example.com", true}, // Exact match should work.
|
||||
{"https://example.com", "http://example.com", false}, // Scheme mismatch should matter.
|
||||
{"http://example.com", "https://example.com", false}, // Scheme mismatch should matter.
|
||||
{"http://example.com", "http://example.org", true}, // Different domains should not matter.
|
||||
{"http://example.com", "http://example.com:8080", true}, // Port should not matter.
|
||||
{"http://example.com:8080", "http://example.com", true}, // Port should not matter.
|
||||
{"http://example.com:8080", "http://example.com:8081", true}, // Different ports should not matter.
|
||||
{"http://localhost", "http://localhost", true}, // Localhost should match.
|
||||
{"http://127.0.0.1", "http://127.0.0.1", true}, // IPv4 address should match.
|
||||
{"http://[::1]", "http://[::1]", true}, // IPv6 address should match.
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
result := matchScheme(tc.domain, tc.pattern)
|
||||
|
||||
if result != tc.expected {
|
||||
t.Errorf("Expected matchScheme('%s', '%s') to be %v, but got %v", tc.domain, tc.pattern, tc.expected, result)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// go test -run -v Test_normalizeDomain
|
||||
func Test_normalizeDomain(t *testing.T) {
|
||||
testCases := []struct {
|
||||
input string
|
||||
expectedOutput string
|
||||
}{
|
||||
{"http://example.com", "example.com"}, // Simple case with http scheme.
|
||||
{"https://example.com", "example.com"}, // Simple case with https scheme.
|
||||
{"http://example.com:3000", "example.com"}, // Case with port.
|
||||
{"https://example.com:3000", "example.com"}, // Case with port and https scheme.
|
||||
{"http://example.com/path", "example.com/path"}, // Case with path.
|
||||
{"http://example.com?query=123", "example.com?query=123"}, // Case with query.
|
||||
{"http://example.com#fragment", "example.com#fragment"}, // Case with fragment.
|
||||
{"example.com", "example.com"}, // Case without scheme.
|
||||
{"example.com:8080", "example.com"}, // Case without scheme but with port.
|
||||
{"sub.example.com", "sub.example.com"}, // Case with subdomain.
|
||||
{"sub.sub.example.com", "sub.sub.example.com"}, // Case with nested subdomain.
|
||||
{"http://localhost", "localhost"}, // Case with localhost.
|
||||
{"http://127.0.0.1", "127.0.0.1"}, // Case with IPv4 address.
|
||||
{"http://[::1]", "[::1]"}, // Case with IPv6 address.
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
output := normalizeDomain(tc.input)
|
||||
|
||||
if output != tc.expectedOutput {
|
||||
t.Errorf("Expected normalized domain '%s' for input '%s', but got: '%s'", tc.expectedOutput, tc.input, output)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// go test -v -run=^$ -bench=Benchmark_CORS_SubdomainMatch -benchmem -count=4
|
||||
func Benchmark_CORS_SubdomainMatch(b *testing.B) {
|
||||
s := subdomain{
|
||||
prefix: "www",
|
||||
suffix: ".example.com",
|
||||
}
|
||||
|
||||
o := "www.example.com"
|
||||
|
||||
b.ResetTimer()
|
||||
b.ReportAllocs()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
s.match(o)
|
||||
}
|
||||
}
|
||||
|
||||
func Test_CORS_SubdomainMatch(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
sub subdomain
|
||||
origin string
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "match with different scheme",
|
||||
sub: subdomain{prefix: "http://api.", suffix: ".example.com"},
|
||||
origin: "https://api.service.example.com",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "match with different scheme",
|
||||
sub: subdomain{prefix: "https://", suffix: ".example.com"},
|
||||
origin: "http://api.service.example.com",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "match with valid subdomain",
|
||||
sub: subdomain{prefix: "https://", suffix: ".example.com"},
|
||||
origin: "https://api.service.example.com",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "match with valid nested subdomain",
|
||||
sub: subdomain{prefix: "https://", suffix: ".example.com"},
|
||||
origin: "https://1.2.api.service.example.com",
|
||||
expected: true,
|
||||
},
|
||||
|
||||
{
|
||||
name: "no match with invalid prefix",
|
||||
sub: subdomain{prefix: "https://abc.", suffix: ".example.com"},
|
||||
origin: "https://service.example.com",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "no match with invalid suffix",
|
||||
sub: subdomain{prefix: "https://", suffix: ".example.com"},
|
||||
origin: "https://api.example.org",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "no match with empty origin",
|
||||
sub: subdomain{prefix: "https://", suffix: ".example.com"},
|
||||
origin: "",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "partial match not considered a match",
|
||||
sub: subdomain{prefix: "https://service.", suffix: ".example.com"},
|
||||
origin: "https://api.example.com",
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := tt.sub.match(tt.origin)
|
||||
utils.AssertEqual(t, tt.expected, got, "subdomain.match()")
|
||||
})
|
||||
}
|
||||
}
|
243
middleware/csrf/config.go
Normal file
243
middleware/csrf/config.go
Normal file
|
@ -0,0 +1,243 @@
|
|||
package csrf
|
||||
|
||||
import (
|
||||
"net/textproto"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/gofiber/fiber/v2/log"
|
||||
"github.com/gofiber/fiber/v2/middleware/session"
|
||||
"github.com/gofiber/fiber/v2/utils"
|
||||
)
|
||||
|
||||
// Config defines the config for middleware.
|
||||
type Config struct {
|
||||
// Next defines a function to skip this middleware when returned true.
|
||||
//
|
||||
// Optional. Default: nil
|
||||
Next func(c *fiber.Ctx) bool
|
||||
|
||||
// KeyLookup is a string in the form of "<source>:<key>" that is used
|
||||
// to create an Extractor that extracts the token from the request.
|
||||
// Possible values:
|
||||
// - "header:<name>"
|
||||
// - "query:<name>"
|
||||
// - "param:<name>"
|
||||
// - "form:<name>"
|
||||
// - "cookie:<name>"
|
||||
//
|
||||
// Ignored if an Extractor is explicitly set.
|
||||
//
|
||||
// Optional. Default: "header:X-Csrf-Token"
|
||||
KeyLookup string
|
||||
|
||||
// Name of the session cookie. This cookie will store session key.
|
||||
// Optional. Default value "csrf_".
|
||||
// Overridden if KeyLookup == "cookie:<name>"
|
||||
CookieName string
|
||||
|
||||
// Domain of the CSRF cookie.
|
||||
// Optional. Default value "".
|
||||
CookieDomain string
|
||||
|
||||
// Path of the CSRF cookie.
|
||||
// Optional. Default value "".
|
||||
CookiePath string
|
||||
|
||||
// Indicates if CSRF cookie is secure.
|
||||
// Optional. Default value false.
|
||||
CookieSecure bool
|
||||
|
||||
// Indicates if CSRF cookie is HTTP only.
|
||||
// Optional. Default value false.
|
||||
CookieHTTPOnly bool
|
||||
|
||||
// Value of SameSite cookie.
|
||||
// Optional. Default value "Lax".
|
||||
CookieSameSite string
|
||||
|
||||
// Decides whether cookie should last for only the browser sesison.
|
||||
// Ignores Expiration if set to true
|
||||
CookieSessionOnly bool
|
||||
|
||||
// Expiration is the duration before csrf token will expire
|
||||
//
|
||||
// Optional. Default: 1 * time.Hour
|
||||
Expiration time.Duration
|
||||
|
||||
// SingleUseToken indicates if the CSRF token be destroyed
|
||||
// and a new one generated on each use.
|
||||
//
|
||||
// Optional. Default: false
|
||||
SingleUseToken bool
|
||||
|
||||
// Store is used to store the state of the middleware
|
||||
//
|
||||
// Optional. Default: memory.New()
|
||||
// Ignored if Session is set.
|
||||
Storage fiber.Storage
|
||||
|
||||
// Session is used to store the state of the middleware
|
||||
//
|
||||
// Optional. Default: nil
|
||||
// If set, the middleware will use the session store instead of the storage
|
||||
Session *session.Store
|
||||
|
||||
// SessionKey is the key used to store the token in the session
|
||||
//
|
||||
// Default: "fiber.csrf.token"
|
||||
SessionKey string
|
||||
|
||||
// Context key to store generated CSRF token into context.
|
||||
// If left empty, token will not be stored in context.
|
||||
//
|
||||
// Optional. Default: ""
|
||||
ContextKey interface{}
|
||||
|
||||
// KeyGenerator creates a new CSRF token
|
||||
//
|
||||
// Optional. Default: utils.UUID
|
||||
KeyGenerator func() string
|
||||
|
||||
// Deprecated: Please use Expiration
|
||||
CookieExpires time.Duration
|
||||
|
||||
// Deprecated: Please use Cookie* related fields
|
||||
Cookie *fiber.Cookie
|
||||
|
||||
// Deprecated: Please use KeyLookup
|
||||
TokenLookup string
|
||||
|
||||
// ErrorHandler is executed when an error is returned from fiber.Handler.
|
||||
//
|
||||
// Optional. Default: DefaultErrorHandler
|
||||
ErrorHandler fiber.ErrorHandler
|
||||
|
||||
// Extractor returns the csrf token
|
||||
//
|
||||
// If set this will be used in place of an Extractor based on KeyLookup.
|
||||
//
|
||||
// Optional. Default will create an Extractor based on KeyLookup.
|
||||
Extractor func(c *fiber.Ctx) (string, error)
|
||||
|
||||
// HandlerContextKey is used to store the CSRF Handler into context
|
||||
//
|
||||
// Default: "fiber.csrf.handler"
|
||||
HandlerContextKey interface{}
|
||||
}
|
||||
|
||||
const HeaderName = "X-Csrf-Token"
|
||||
|
||||
// ConfigDefault is the default config
|
||||
var ConfigDefault = Config{
|
||||
KeyLookup: "header:" + HeaderName,
|
||||
CookieName: "csrf_",
|
||||
CookieSameSite: "Lax",
|
||||
Expiration: 1 * time.Hour,
|
||||
KeyGenerator: utils.UUIDv4,
|
||||
ErrorHandler: defaultErrorHandler,
|
||||
Extractor: CsrfFromHeader(HeaderName),
|
||||
SessionKey: "fiber.csrf.token",
|
||||
HandlerContextKey: "fiber.csrf.handler",
|
||||
}
|
||||
|
||||
// default ErrorHandler that process return error from fiber.Handler
|
||||
func defaultErrorHandler(_ *fiber.Ctx, _ error) error {
|
||||
return fiber.ErrForbidden
|
||||
}
|
||||
|
||||
// Helper function to set default values
|
||||
func configDefault(config ...Config) Config {
|
||||
// Return default config if nothing provided
|
||||
if len(config) < 1 {
|
||||
return ConfigDefault
|
||||
}
|
||||
|
||||
// Override default config
|
||||
cfg := config[0]
|
||||
|
||||
// Set default values
|
||||
if cfg.TokenLookup != "" {
|
||||
log.Warn("[CSRF] TokenLookup is deprecated, please use KeyLookup")
|
||||
cfg.KeyLookup = cfg.TokenLookup
|
||||
}
|
||||
if int(cfg.CookieExpires.Seconds()) > 0 {
|
||||
log.Warn("[CSRF] CookieExpires is deprecated, please use Expiration")
|
||||
cfg.Expiration = cfg.CookieExpires
|
||||
}
|
||||
if cfg.Cookie != nil {
|
||||
log.Warn("[CSRF] Cookie is deprecated, please use Cookie* related fields")
|
||||
if cfg.Cookie.Name != "" {
|
||||
cfg.CookieName = cfg.Cookie.Name
|
||||
}
|
||||
if cfg.Cookie.Domain != "" {
|
||||
cfg.CookieDomain = cfg.Cookie.Domain
|
||||
}
|
||||
if cfg.Cookie.Path != "" {
|
||||
cfg.CookiePath = cfg.Cookie.Path
|
||||
}
|
||||
cfg.CookieSecure = cfg.Cookie.Secure
|
||||
cfg.CookieHTTPOnly = cfg.Cookie.HTTPOnly
|
||||
if cfg.Cookie.SameSite != "" {
|
||||
cfg.CookieSameSite = cfg.Cookie.SameSite
|
||||
}
|
||||
}
|
||||
if cfg.KeyLookup == "" {
|
||||
cfg.KeyLookup = ConfigDefault.KeyLookup
|
||||
}
|
||||
if int(cfg.Expiration.Seconds()) <= 0 {
|
||||
cfg.Expiration = ConfigDefault.Expiration
|
||||
}
|
||||
if cfg.CookieName == "" {
|
||||
cfg.CookieName = ConfigDefault.CookieName
|
||||
}
|
||||
if cfg.CookieSameSite == "" {
|
||||
cfg.CookieSameSite = ConfigDefault.CookieSameSite
|
||||
}
|
||||
if cfg.KeyGenerator == nil {
|
||||
cfg.KeyGenerator = ConfigDefault.KeyGenerator
|
||||
}
|
||||
if cfg.ErrorHandler == nil {
|
||||
cfg.ErrorHandler = ConfigDefault.ErrorHandler
|
||||
}
|
||||
if cfg.SessionKey == "" {
|
||||
cfg.SessionKey = ConfigDefault.SessionKey
|
||||
}
|
||||
if cfg.HandlerContextKey == nil {
|
||||
cfg.HandlerContextKey = ConfigDefault.HandlerContextKey
|
||||
}
|
||||
|
||||
// Generate the correct extractor to get the token from the correct location
|
||||
selectors := strings.Split(cfg.KeyLookup, ":")
|
||||
|
||||
const numParts = 2
|
||||
if len(selectors) != numParts {
|
||||
panic("[CSRF] KeyLookup must in the form of <source>:<key>")
|
||||
}
|
||||
|
||||
if cfg.Extractor == nil {
|
||||
// By default we extract from a header
|
||||
cfg.Extractor = CsrfFromHeader(textproto.CanonicalMIMEHeaderKey(selectors[1]))
|
||||
|
||||
switch selectors[0] {
|
||||
case "form":
|
||||
cfg.Extractor = CsrfFromForm(selectors[1])
|
||||
case "query":
|
||||
cfg.Extractor = CsrfFromQuery(selectors[1])
|
||||
case "param":
|
||||
cfg.Extractor = CsrfFromParam(selectors[1])
|
||||
case "cookie":
|
||||
if cfg.Session == nil {
|
||||
log.Warn("[CSRF] Cookie extractor is not recommended without a session store")
|
||||
}
|
||||
if cfg.CookieSameSite == "None" || cfg.CookieSameSite != "Lax" && cfg.CookieSameSite != "Strict" {
|
||||
log.Warn("[CSRF] Cookie extractor is only recommended for use with SameSite=Lax or SameSite=Strict")
|
||||
}
|
||||
cfg.Extractor = CsrfFromCookie(selectors[1])
|
||||
cfg.CookieName = selectors[1] // Cookie name is the same as the key
|
||||
}
|
||||
}
|
||||
|
||||
return cfg
|
||||
}
|
239
middleware/csrf/csrf.go
Normal file
239
middleware/csrf/csrf.go
Normal file
|
@ -0,0 +1,239 @@
|
|||
package csrf
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/url"
|
||||
"reflect"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrTokenNotFound = errors.New("csrf token not found")
|
||||
ErrTokenInvalid = errors.New("csrf token invalid")
|
||||
ErrNoReferer = errors.New("referer not supplied")
|
||||
ErrBadReferer = errors.New("referer invalid")
|
||||
dummyValue = []byte{'+'}
|
||||
)
|
||||
|
||||
type CSRFHandler struct {
|
||||
config *Config
|
||||
sessionManager *sessionManager
|
||||
storageManager *storageManager
|
||||
}
|
||||
|
||||
// New creates a new middleware handler
|
||||
func New(config ...Config) fiber.Handler {
|
||||
// Set default config
|
||||
cfg := configDefault(config...)
|
||||
|
||||
// Create manager to simplify storage operations ( see *_manager.go )
|
||||
var sessionManager *sessionManager
|
||||
var storageManager *storageManager
|
||||
if cfg.Session != nil {
|
||||
// Register the Token struct in the session store
|
||||
cfg.Session.RegisterType(Token{})
|
||||
|
||||
sessionManager = newSessionManager(cfg.Session, cfg.SessionKey)
|
||||
} else {
|
||||
storageManager = newStorageManager(cfg.Storage)
|
||||
}
|
||||
|
||||
// Return new handler
|
||||
return func(c *fiber.Ctx) error {
|
||||
// Don't execute middleware if Next returns true
|
||||
if cfg.Next != nil && cfg.Next(c) {
|
||||
return c.Next()
|
||||
}
|
||||
|
||||
// Store the CSRF handler in the context if a context key is specified
|
||||
if cfg.HandlerContextKey != "" {
|
||||
c.Locals(cfg.HandlerContextKey, &CSRFHandler{
|
||||
config: &cfg,
|
||||
sessionManager: sessionManager,
|
||||
storageManager: storageManager,
|
||||
})
|
||||
}
|
||||
|
||||
var token string
|
||||
|
||||
// Action depends on the HTTP method
|
||||
switch c.Method() {
|
||||
case fiber.MethodGet, fiber.MethodHead, fiber.MethodOptions, fiber.MethodTrace:
|
||||
cookieToken := c.Cookies(cfg.CookieName)
|
||||
|
||||
if cookieToken != "" {
|
||||
raw := getRawFromStorage(c, cookieToken, cfg, sessionManager, storageManager)
|
||||
|
||||
if raw != nil {
|
||||
token = cookieToken // Token is valid, safe to set it
|
||||
}
|
||||
}
|
||||
default:
|
||||
// Assume that anything not defined as 'safe' by RFC7231 needs protection
|
||||
|
||||
// Enforce an origin check for HTTPS connections.
|
||||
if c.Protocol() == "https" {
|
||||
if err := refererMatchesHost(c); err != nil {
|
||||
return cfg.ErrorHandler(c, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Extract token from client request i.e. header, query, param, form or cookie
|
||||
extractedToken, err := cfg.Extractor(c)
|
||||
if err != nil {
|
||||
return cfg.ErrorHandler(c, err)
|
||||
}
|
||||
|
||||
if extractedToken == "" {
|
||||
return cfg.ErrorHandler(c, ErrTokenNotFound)
|
||||
}
|
||||
|
||||
// If not using CsrfFromCookie extractor, check that the token matches the cookie
|
||||
// This is to prevent CSRF attacks by using a Double Submit Cookie method
|
||||
// Useful when we do not have access to the users Session
|
||||
if !isCsrfFromCookie(cfg.Extractor) && !compareStrings(extractedToken, c.Cookies(cfg.CookieName)) {
|
||||
return cfg.ErrorHandler(c, ErrTokenInvalid)
|
||||
}
|
||||
|
||||
raw := getRawFromStorage(c, extractedToken, cfg, sessionManager, storageManager)
|
||||
|
||||
if raw == nil {
|
||||
// If token is not in storage, expire the cookie
|
||||
expireCSRFCookie(c, cfg)
|
||||
// and return an error
|
||||
return cfg.ErrorHandler(c, ErrTokenNotFound)
|
||||
}
|
||||
if cfg.SingleUseToken {
|
||||
// If token is single use, delete it from storage
|
||||
deleteTokenFromStorage(c, extractedToken, cfg, sessionManager, storageManager)
|
||||
} else {
|
||||
token = extractedToken // Token is valid, safe to set it
|
||||
}
|
||||
}
|
||||
|
||||
// Generate CSRF token if not exist
|
||||
if token == "" {
|
||||
// And generate a new token
|
||||
token = cfg.KeyGenerator()
|
||||
}
|
||||
|
||||
// Create or extend the token in the storage
|
||||
createOrExtendTokenInStorage(c, token, cfg, sessionManager, storageManager)
|
||||
|
||||
// Update the CSRF cookie
|
||||
updateCSRFCookie(c, cfg, token)
|
||||
|
||||
// Tell the browser that a new header value is generated
|
||||
c.Vary(fiber.HeaderCookie)
|
||||
|
||||
// Store the token in the context if a context key is specified
|
||||
if cfg.ContextKey != nil {
|
||||
c.Locals(cfg.ContextKey, token)
|
||||
}
|
||||
|
||||
// Continue stack
|
||||
return c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// getRawFromStorage returns the raw value from the storage for the given token
|
||||
// returns nil if the token does not exist, is expired or is invalid
|
||||
func getRawFromStorage(c *fiber.Ctx, token string, cfg Config, sessionManager *sessionManager, storageManager *storageManager) []byte {
|
||||
if cfg.Session != nil {
|
||||
return sessionManager.getRaw(c, token, dummyValue)
|
||||
}
|
||||
return storageManager.getRaw(token)
|
||||
}
|
||||
|
||||
// createOrExtendTokenInStorage creates or extends the token in the storage
|
||||
func createOrExtendTokenInStorage(c *fiber.Ctx, token string, cfg Config, sessionManager *sessionManager, storageManager *storageManager) {
|
||||
if cfg.Session != nil {
|
||||
sessionManager.setRaw(c, token, dummyValue, cfg.Expiration)
|
||||
} else {
|
||||
storageManager.setRaw(token, dummyValue, cfg.Expiration)
|
||||
}
|
||||
}
|
||||
|
||||
func deleteTokenFromStorage(c *fiber.Ctx, token string, cfg Config, sessionManager *sessionManager, storageManager *storageManager) {
|
||||
if cfg.Session != nil {
|
||||
sessionManager.delRaw(c)
|
||||
} else {
|
||||
storageManager.delRaw(token)
|
||||
}
|
||||
}
|
||||
|
||||
// Update CSRF cookie
|
||||
// if expireCookie is true, the cookie will expire immediately
|
||||
func updateCSRFCookie(c *fiber.Ctx, cfg Config, token string) {
|
||||
setCSRFCookie(c, cfg, token, cfg.Expiration)
|
||||
}
|
||||
|
||||
func expireCSRFCookie(c *fiber.Ctx, cfg Config) {
|
||||
setCSRFCookie(c, cfg, "", -time.Hour)
|
||||
}
|
||||
|
||||
func setCSRFCookie(c *fiber.Ctx, cfg Config, token string, expiry time.Duration) {
|
||||
cookie := &fiber.Cookie{
|
||||
Name: cfg.CookieName,
|
||||
Value: token,
|
||||
Domain: cfg.CookieDomain,
|
||||
Path: cfg.CookiePath,
|
||||
Secure: cfg.CookieSecure,
|
||||
HTTPOnly: cfg.CookieHTTPOnly,
|
||||
SameSite: cfg.CookieSameSite,
|
||||
SessionOnly: cfg.CookieSessionOnly,
|
||||
Expires: time.Now().Add(expiry),
|
||||
}
|
||||
|
||||
// Set the CSRF cookie to the response
|
||||
c.Cookie(cookie)
|
||||
}
|
||||
|
||||
// DeleteToken removes the token found in the context from the storage
|
||||
// and expires the CSRF cookie
|
||||
func (handler *CSRFHandler) DeleteToken(c *fiber.Ctx) error {
|
||||
// Get the config from the context
|
||||
config := handler.config
|
||||
if config == nil {
|
||||
panic("CSRFHandler config not found in context")
|
||||
}
|
||||
// Extract token from the client request cookie
|
||||
cookieToken := c.Cookies(config.CookieName)
|
||||
if cookieToken == "" {
|
||||
return config.ErrorHandler(c, ErrTokenNotFound)
|
||||
}
|
||||
// Remove the token from storage
|
||||
deleteTokenFromStorage(c, cookieToken, *config, handler.sessionManager, handler.storageManager)
|
||||
// Expire the cookie
|
||||
expireCSRFCookie(c, *config)
|
||||
return nil
|
||||
}
|
||||
|
||||
// isCsrfFromCookie checks if the extractor is set to ExtractFromCookie
|
||||
func isCsrfFromCookie(extractor interface{}) bool {
|
||||
return reflect.ValueOf(extractor).Pointer() == reflect.ValueOf(CsrfFromCookie).Pointer()
|
||||
}
|
||||
|
||||
// refererMatchesHost checks that the referer header matches the host header
|
||||
// returns an error if the referer header is not present or is invalid
|
||||
// returns nil if the referer header is valid
|
||||
func refererMatchesHost(c *fiber.Ctx) error {
|
||||
referer := strings.ToLower(c.Get(fiber.HeaderReferer))
|
||||
if referer == "" {
|
||||
return ErrNoReferer
|
||||
}
|
||||
|
||||
refererURL, err := url.Parse(referer)
|
||||
if err != nil {
|
||||
return ErrBadReferer
|
||||
}
|
||||
|
||||
if refererURL.Scheme == c.Protocol() && refererURL.Host == c.Hostname() {
|
||||
return nil
|
||||
}
|
||||
|
||||
return ErrBadReferer
|
||||
}
|
1060
middleware/csrf/csrf_test.go
Normal file
1060
middleware/csrf/csrf_test.go
Normal file
File diff suppressed because it is too large
Load diff
70
middleware/csrf/extractors.go
Normal file
70
middleware/csrf/extractors.go
Normal file
|
@ -0,0 +1,70 @@
|
|||
package csrf
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrMissingHeader = errors.New("missing csrf token in header")
|
||||
ErrMissingQuery = errors.New("missing csrf token in query")
|
||||
ErrMissingParam = errors.New("missing csrf token in param")
|
||||
ErrMissingForm = errors.New("missing csrf token in form")
|
||||
ErrMissingCookie = errors.New("missing csrf token in cookie")
|
||||
)
|
||||
|
||||
// csrfFromParam returns a function that extracts token from the url param string.
|
||||
func CsrfFromParam(param string) func(c *fiber.Ctx) (string, error) {
|
||||
return func(c *fiber.Ctx) (string, error) {
|
||||
token := c.Params(param)
|
||||
if token == "" {
|
||||
return "", ErrMissingParam
|
||||
}
|
||||
return token, nil
|
||||
}
|
||||
}
|
||||
|
||||
// csrfFromForm returns a function that extracts a token from a multipart-form.
|
||||
func CsrfFromForm(param string) func(c *fiber.Ctx) (string, error) {
|
||||
return func(c *fiber.Ctx) (string, error) {
|
||||
token := c.FormValue(param)
|
||||
if token == "" {
|
||||
return "", ErrMissingForm
|
||||
}
|
||||
return token, nil
|
||||
}
|
||||
}
|
||||
|
||||
// csrfFromCookie returns a function that extracts token from the cookie header.
|
||||
func CsrfFromCookie(param string) func(c *fiber.Ctx) (string, error) {
|
||||
return func(c *fiber.Ctx) (string, error) {
|
||||
token := c.Cookies(param)
|
||||
if token == "" {
|
||||
return "", ErrMissingCookie
|
||||
}
|
||||
return token, nil
|
||||
}
|
||||
}
|
||||
|
||||
// csrfFromHeader returns a function that extracts token from the request header.
|
||||
func CsrfFromHeader(param string) func(c *fiber.Ctx) (string, error) {
|
||||
return func(c *fiber.Ctx) (string, error) {
|
||||
token := c.Get(param)
|
||||
if token == "" {
|
||||
return "", ErrMissingHeader
|
||||
}
|
||||
return token, nil
|
||||
}
|
||||
}
|
||||
|
||||
// csrfFromQuery returns a function that extracts token from the query string.
|
||||
func CsrfFromQuery(param string) func(c *fiber.Ctx) (string, error) {
|
||||
return func(c *fiber.Ctx) (string, error) {
|
||||
token := c.Query(param)
|
||||
if token == "" {
|
||||
return "", ErrMissingQuery
|
||||
}
|
||||
return token, nil
|
||||
}
|
||||
}
|
13
middleware/csrf/helpers.go
Normal file
13
middleware/csrf/helpers.go
Normal file
|
@ -0,0 +1,13 @@
|
|||
package csrf
|
||||
|
||||
import (
|
||||
"crypto/subtle"
|
||||
)
|
||||
|
||||
func compareTokens(a, b []byte) bool {
|
||||
return subtle.ConstantTimeCompare(a, b) == 1
|
||||
}
|
||||
|
||||
func compareStrings(a, b string) bool {
|
||||
return subtle.ConstantTimeCompare([]byte(a), []byte(b)) == 1
|
||||
}
|
68
middleware/csrf/session_manager.go
Normal file
68
middleware/csrf/session_manager.go
Normal file
|
@ -0,0 +1,68 @@
|
|||
package csrf
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/gofiber/fiber/v2/log"
|
||||
"github.com/gofiber/fiber/v2/middleware/session"
|
||||
)
|
||||
|
||||
type sessionManager struct {
|
||||
key string
|
||||
session *session.Store
|
||||
}
|
||||
|
||||
func newSessionManager(s *session.Store, k string) *sessionManager {
|
||||
// Create new storage handler
|
||||
sessionManager := &sessionManager{
|
||||
key: k,
|
||||
}
|
||||
if s != nil {
|
||||
// Use provided storage if provided
|
||||
sessionManager.session = s
|
||||
}
|
||||
return sessionManager
|
||||
}
|
||||
|
||||
// get token from session
|
||||
func (m *sessionManager) getRaw(c *fiber.Ctx, key string, raw []byte) []byte {
|
||||
sess, err := m.session.Get(c)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
token, ok := sess.Get(m.key).(Token)
|
||||
if ok {
|
||||
if token.Expiration.Before(time.Now()) || key != token.Key || !compareTokens(raw, token.Raw) {
|
||||
return nil
|
||||
}
|
||||
return token.Raw
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// set token in session
|
||||
func (m *sessionManager) setRaw(c *fiber.Ctx, key string, raw []byte, exp time.Duration) {
|
||||
sess, err := m.session.Get(c)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
// the key is crucial in crsf and sometimes a reference to another value which can be reused later(pool/unsafe values concept), so a copy is made here
|
||||
sess.Set(m.key, &Token{key, raw, time.Now().Add(exp)})
|
||||
if err := sess.Save(); err != nil {
|
||||
log.Warn("csrf: failed to save session: ", err)
|
||||
}
|
||||
}
|
||||
|
||||
// delete token from session
|
||||
func (m *sessionManager) delRaw(c *fiber.Ctx) {
|
||||
sess, err := m.session.Get(c)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
sess.Delete(m.key)
|
||||
if err := sess.Save(); err != nil {
|
||||
log.Warn("csrf: failed to save session: ", err)
|
||||
}
|
||||
}
|
70
middleware/csrf/storage_manager.go
Normal file
70
middleware/csrf/storage_manager.go
Normal file
|
@ -0,0 +1,70 @@
|
|||
package csrf
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/gofiber/fiber/v2/internal/memory"
|
||||
"github.com/gofiber/fiber/v2/utils"
|
||||
)
|
||||
|
||||
// go:generate msgp
|
||||
// msgp -file="storage_manager.go" -o="storage_manager_msgp.go" -tests=false -unexported
|
||||
type item struct{}
|
||||
|
||||
//msgp:ignore manager
|
||||
type storageManager struct {
|
||||
pool sync.Pool
|
||||
memory *memory.Storage
|
||||
storage fiber.Storage
|
||||
}
|
||||
|
||||
func newStorageManager(storage fiber.Storage) *storageManager {
|
||||
// Create new storage handler
|
||||
storageManager := &storageManager{
|
||||
pool: sync.Pool{
|
||||
New: func() interface{} {
|
||||
return new(item)
|
||||
},
|
||||
},
|
||||
}
|
||||
if storage != nil {
|
||||
// Use provided storage if provided
|
||||
storageManager.storage = storage
|
||||
} else {
|
||||
// Fallback too memory storage
|
||||
storageManager.memory = memory.New()
|
||||
}
|
||||
return storageManager
|
||||
}
|
||||
|
||||
// get raw data from storage or memory
|
||||
func (m *storageManager) getRaw(key string) []byte {
|
||||
var raw []byte
|
||||
if m.storage != nil {
|
||||
raw, _ = m.storage.Get(key) //nolint:errcheck // TODO: Do not ignore error
|
||||
} else {
|
||||
raw, _ = m.memory.Get(key).([]byte) //nolint:errcheck // TODO: Do not ignore error
|
||||
}
|
||||
return raw
|
||||
}
|
||||
|
||||
// set data to storage or memory
|
||||
func (m *storageManager) setRaw(key string, raw []byte, exp time.Duration) {
|
||||
if m.storage != nil {
|
||||
_ = m.storage.Set(key, raw, exp) //nolint:errcheck // TODO: Do not ignore error
|
||||
} else {
|
||||
// the key is crucial in crsf and sometimes a reference to another value which can be reused later(pool/unsafe values concept), so a copy is made here
|
||||
m.memory.Set(utils.CopyString(key), raw, exp)
|
||||
}
|
||||
}
|
||||
|
||||
// delete data from storage or memory
|
||||
func (m *storageManager) delRaw(key string) {
|
||||
if m.storage != nil {
|
||||
_ = m.storage.Delete(key) //nolint:errcheck // TODO: Do not ignore error
|
||||
} else {
|
||||
m.memory.Delete(key)
|
||||
}
|
||||
}
|
90
middleware/csrf/storage_manager_msgp.go
Normal file
90
middleware/csrf/storage_manager_msgp.go
Normal file
|
@ -0,0 +1,90 @@
|
|||
package csrf
|
||||
|
||||
// Code generated by github.com/tinylib/msgp DO NOT EDIT.
|
||||
|
||||
import (
|
||||
"github.com/tinylib/msgp/msgp"
|
||||
)
|
||||
|
||||
// DecodeMsg implements msgp.Decodable
|
||||
func (z *item) DecodeMsg(dc *msgp.Reader) (err error) {
|
||||
var field []byte
|
||||
_ = field
|
||||
var zb0001 uint32
|
||||
zb0001, err = dc.ReadMapHeader()
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err)
|
||||
return
|
||||
}
|
||||
for zb0001 > 0 {
|
||||
zb0001--
|
||||
field, err = dc.ReadMapKeyPtr()
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err)
|
||||
return
|
||||
}
|
||||
switch msgp.UnsafeString(field) {
|
||||
default:
|
||||
err = dc.Skip()
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// EncodeMsg implements msgp.Encodable
|
||||
func (z item) EncodeMsg(en *msgp.Writer) (err error) {
|
||||
// map header, size 0
|
||||
err = en.Append(0x80)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// MarshalMsg implements msgp.Marshaler
|
||||
func (z item) MarshalMsg(b []byte) (o []byte, err error) {
|
||||
o = msgp.Require(b, z.Msgsize())
|
||||
// map header, size 0
|
||||
o = append(o, 0x80)
|
||||
return
|
||||
}
|
||||
|
||||
// UnmarshalMsg implements msgp.Unmarshaler
|
||||
func (z *item) UnmarshalMsg(bts []byte) (o []byte, err error) {
|
||||
var field []byte
|
||||
_ = field
|
||||
var zb0001 uint32
|
||||
zb0001, bts, err = msgp.ReadMapHeaderBytes(bts)
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err)
|
||||
return
|
||||
}
|
||||
for zb0001 > 0 {
|
||||
zb0001--
|
||||
field, bts, err = msgp.ReadMapKeyZC(bts)
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err)
|
||||
return
|
||||
}
|
||||
switch msgp.UnsafeString(field) {
|
||||
default:
|
||||
bts, err = msgp.Skip(bts)
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
o = bts
|
||||
return
|
||||
}
|
||||
|
||||
// Msgsize returns an upper bound estimate of the number of bytes occupied by the serialized message
|
||||
func (z item) Msgsize() (s int) {
|
||||
s = 1
|
||||
return
|
||||
}
|
11
middleware/csrf/token.go
Normal file
11
middleware/csrf/token.go
Normal file
|
@ -0,0 +1,11 @@
|
|||
package csrf
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
type Token struct {
|
||||
Key string `json:"key"`
|
||||
Raw []byte `json:"raw"`
|
||||
Expiration time.Time `json:"expiration"`
|
||||
}
|
73
middleware/earlydata/config.go
Normal file
73
middleware/earlydata/config.go
Normal file
|
@ -0,0 +1,73 @@
|
|||
package earlydata
|
||||
|
||||
import (
|
||||
"github.com/gofiber/fiber/v2"
|
||||
)
|
||||
|
||||
const (
|
||||
DefaultHeaderName = "Early-Data"
|
||||
DefaultHeaderTrueValue = "1"
|
||||
)
|
||||
|
||||
// Config defines the config for middleware.
|
||||
type Config struct {
|
||||
// Next defines a function to skip this middleware when returned true.
|
||||
//
|
||||
// Optional. Default: nil
|
||||
Next func(c *fiber.Ctx) bool
|
||||
|
||||
// IsEarlyData returns whether the request is an early-data request.
|
||||
//
|
||||
// Optional. Default: a function which checks if the "Early-Data" request header equals "1".
|
||||
IsEarlyData func(c *fiber.Ctx) bool
|
||||
|
||||
// AllowEarlyData returns whether the early-data request should be allowed or rejected.
|
||||
//
|
||||
// Optional. Default: a function which rejects the request on unsafe and allows the request on safe HTTP request methods.
|
||||
AllowEarlyData func(c *fiber.Ctx) bool
|
||||
|
||||
// Error is returned in case an early-data request is rejected.
|
||||
//
|
||||
// Optional. Default: fiber.ErrTooEarly.
|
||||
Error error
|
||||
}
|
||||
|
||||
// ConfigDefault is the default config
|
||||
var ConfigDefault = Config{
|
||||
IsEarlyData: func(c *fiber.Ctx) bool {
|
||||
return c.Get(DefaultHeaderName) == DefaultHeaderTrueValue
|
||||
},
|
||||
|
||||
AllowEarlyData: func(c *fiber.Ctx) bool {
|
||||
return fiber.IsMethodSafe(c.Method())
|
||||
},
|
||||
|
||||
Error: fiber.ErrTooEarly,
|
||||
}
|
||||
|
||||
// Helper function to set default values
|
||||
func configDefault(config ...Config) Config {
|
||||
// Return default config if nothing provided
|
||||
if len(config) < 1 {
|
||||
return ConfigDefault
|
||||
}
|
||||
|
||||
// Override default config
|
||||
cfg := config[0]
|
||||
|
||||
// Set default values
|
||||
|
||||
if cfg.IsEarlyData == nil {
|
||||
cfg.IsEarlyData = ConfigDefault.IsEarlyData
|
||||
}
|
||||
|
||||
if cfg.AllowEarlyData == nil {
|
||||
cfg.AllowEarlyData = ConfigDefault.AllowEarlyData
|
||||
}
|
||||
|
||||
if cfg.Error == nil {
|
||||
cfg.Error = ConfigDefault.Error
|
||||
}
|
||||
|
||||
return cfg
|
||||
}
|
47
middleware/earlydata/earlydata.go
Normal file
47
middleware/earlydata/earlydata.go
Normal file
|
@ -0,0 +1,47 @@
|
|||
package earlydata
|
||||
|
||||
import (
|
||||
"github.com/gofiber/fiber/v2"
|
||||
)
|
||||
|
||||
const (
|
||||
localsKeyAllowed = "earlydata_allowed"
|
||||
)
|
||||
|
||||
func IsEarly(c *fiber.Ctx) bool {
|
||||
return c.Locals(localsKeyAllowed) != nil
|
||||
}
|
||||
|
||||
// New creates a new middleware handler
|
||||
// https://datatracker.ietf.org/doc/html/rfc8470#section-5.1
|
||||
func New(config ...Config) fiber.Handler {
|
||||
// Set default config
|
||||
cfg := configDefault(config...)
|
||||
|
||||
// Return new handler
|
||||
return func(c *fiber.Ctx) error {
|
||||
// Don't execute middleware if Next returns true
|
||||
if cfg.Next != nil && cfg.Next(c) {
|
||||
return c.Next()
|
||||
}
|
||||
|
||||
// Abort if we can't trust the early-data header
|
||||
if !c.IsProxyTrusted() {
|
||||
return cfg.Error
|
||||
}
|
||||
|
||||
// Continue stack if request is not an early-data request
|
||||
if !cfg.IsEarlyData(c) {
|
||||
return c.Next()
|
||||
}
|
||||
|
||||
// Continue stack if we allow early-data for this request
|
||||
if cfg.AllowEarlyData(c) {
|
||||
_ = c.Locals(localsKeyAllowed, true)
|
||||
return c.Next()
|
||||
}
|
||||
|
||||
// Else return our error
|
||||
return cfg.Error
|
||||
}
|
||||
}
|
193
middleware/earlydata/earlydata_test.go
Normal file
193
middleware/earlydata/earlydata_test.go
Normal file
|
@ -0,0 +1,193 @@
|
|||
//nolint:bodyclose // Much easier to just ignore memory leaks in tests
|
||||
package earlydata_test
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/gofiber/fiber/v2/middleware/earlydata"
|
||||
"github.com/gofiber/fiber/v2/utils"
|
||||
)
|
||||
|
||||
const (
|
||||
headerName = "Early-Data"
|
||||
headerValOn = "1"
|
||||
headerValOff = "0"
|
||||
)
|
||||
|
||||
func appWithConfig(t *testing.T, c *fiber.Config) *fiber.App {
|
||||
t.Helper()
|
||||
t.Parallel()
|
||||
|
||||
var app *fiber.App
|
||||
if c == nil {
|
||||
app = fiber.New()
|
||||
} else {
|
||||
app = fiber.New(*c)
|
||||
}
|
||||
|
||||
app.Use(earlydata.New())
|
||||
|
||||
// Middleware to test IsEarly func
|
||||
const localsKeyTestValid = "earlydata_testvalid"
|
||||
app.Use(func(c *fiber.Ctx) error {
|
||||
isEarly := earlydata.IsEarly(c)
|
||||
|
||||
switch h := c.Get(headerName); h {
|
||||
case "", headerValOff:
|
||||
if isEarly {
|
||||
return errors.New("is early-data even though it's not")
|
||||
}
|
||||
|
||||
case headerValOn:
|
||||
switch {
|
||||
case fiber.IsMethodSafe(c.Method()):
|
||||
if !isEarly {
|
||||
return errors.New("should be early-data on safe HTTP methods")
|
||||
}
|
||||
default:
|
||||
if isEarly {
|
||||
return errors.New("early-data unsuported on unsafe HTTP methods")
|
||||
}
|
||||
}
|
||||
|
||||
default:
|
||||
return fmt.Errorf("header has unsupported value: %s", h)
|
||||
}
|
||||
|
||||
_ = c.Locals(localsKeyTestValid, true)
|
||||
|
||||
return c.Next()
|
||||
})
|
||||
|
||||
{
|
||||
{
|
||||
handler := func(c *fiber.Ctx) error {
|
||||
if !c.Locals(localsKeyTestValid).(bool) { //nolint:forcetypeassert // We store nothing else in the pool
|
||||
return errors.New("handler called even though validation failed")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
app.Get("/", handler)
|
||||
app.Post("/", handler)
|
||||
}
|
||||
}
|
||||
|
||||
return app
|
||||
}
|
||||
|
||||
// go test -run Test_EarlyData
|
||||
func Test_EarlyData(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
trustedRun := func(t *testing.T, app *fiber.App) {
|
||||
t.Helper()
|
||||
|
||||
{
|
||||
req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
|
||||
|
||||
resp, err := app.Test(req)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
|
||||
|
||||
req.Header.Set(headerName, headerValOff)
|
||||
resp, err = app.Test(req)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
|
||||
|
||||
req.Header.Set(headerName, headerValOn)
|
||||
resp, err = app.Test(req)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
|
||||
}
|
||||
|
||||
{
|
||||
req := httptest.NewRequest(fiber.MethodPost, "/", http.NoBody)
|
||||
|
||||
resp, err := app.Test(req)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
|
||||
|
||||
req.Header.Set(headerName, headerValOff)
|
||||
resp, err = app.Test(req)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
|
||||
|
||||
req.Header.Set(headerName, headerValOn)
|
||||
resp, err = app.Test(req)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, fiber.StatusTooEarly, resp.StatusCode)
|
||||
}
|
||||
}
|
||||
|
||||
untrustedRun := func(t *testing.T, app *fiber.App) {
|
||||
t.Helper()
|
||||
|
||||
{
|
||||
req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
|
||||
|
||||
resp, err := app.Test(req)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, fiber.StatusTooEarly, resp.StatusCode)
|
||||
|
||||
req.Header.Set(headerName, headerValOff)
|
||||
resp, err = app.Test(req)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, fiber.StatusTooEarly, resp.StatusCode)
|
||||
|
||||
req.Header.Set(headerName, headerValOn)
|
||||
resp, err = app.Test(req)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, fiber.StatusTooEarly, resp.StatusCode)
|
||||
}
|
||||
|
||||
{
|
||||
req := httptest.NewRequest(fiber.MethodPost, "/", http.NoBody)
|
||||
|
||||
resp, err := app.Test(req)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, fiber.StatusTooEarly, resp.StatusCode)
|
||||
|
||||
req.Header.Set(headerName, headerValOff)
|
||||
resp, err = app.Test(req)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, fiber.StatusTooEarly, resp.StatusCode)
|
||||
|
||||
req.Header.Set(headerName, headerValOn)
|
||||
resp, err = app.Test(req)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, fiber.StatusTooEarly, resp.StatusCode)
|
||||
}
|
||||
}
|
||||
|
||||
t.Run("empty config", func(t *testing.T) {
|
||||
app := appWithConfig(t, nil)
|
||||
trustedRun(t, app)
|
||||
})
|
||||
t.Run("default config", func(t *testing.T) {
|
||||
app := appWithConfig(t, &fiber.Config{})
|
||||
trustedRun(t, app)
|
||||
})
|
||||
|
||||
t.Run("config with EnableTrustedProxyCheck", func(t *testing.T) {
|
||||
app := appWithConfig(t, &fiber.Config{
|
||||
EnableTrustedProxyCheck: true,
|
||||
})
|
||||
untrustedRun(t, app)
|
||||
})
|
||||
t.Run("config with EnableTrustedProxyCheck and trusted TrustedProxies", func(t *testing.T) {
|
||||
app := appWithConfig(t, &fiber.Config{
|
||||
EnableTrustedProxyCheck: true,
|
||||
TrustedProxies: []string{
|
||||
"0.0.0.0",
|
||||
},
|
||||
})
|
||||
trustedRun(t, app)
|
||||
})
|
||||
}
|
78
middleware/encryptcookie/config.go
Normal file
78
middleware/encryptcookie/config.go
Normal file
|
@ -0,0 +1,78 @@
|
|||
package encryptcookie
|
||||
|
||||
import (
|
||||
"github.com/gofiber/fiber/v2"
|
||||
)
|
||||
|
||||
// Config defines the config for middleware.
|
||||
type Config struct {
|
||||
// Next defines a function to skip this middleware when returned true.
|
||||
//
|
||||
// Optional. Default: nil
|
||||
Next func(c *fiber.Ctx) bool
|
||||
|
||||
// Array of cookie keys that should not be encrypted.
|
||||
//
|
||||
// Optional. Default: []
|
||||
Except []string
|
||||
|
||||
// Base64 encoded unique key to encode & decode cookies.
|
||||
//
|
||||
// Required. Key length should be 32 characters.
|
||||
// You may use `encryptcookie.GenerateKey()` to generate a new key.
|
||||
Key string
|
||||
|
||||
// Custom function to encrypt cookies.
|
||||
//
|
||||
// Optional. Default: EncryptCookie
|
||||
Encryptor func(decryptedString, key string) (string, error)
|
||||
|
||||
// Custom function to decrypt cookies.
|
||||
//
|
||||
// Optional. Default: DecryptCookie
|
||||
Decryptor func(encryptedString, key string) (string, error)
|
||||
}
|
||||
|
||||
// ConfigDefault is the default config
|
||||
var ConfigDefault = Config{
|
||||
Next: nil,
|
||||
Except: []string{},
|
||||
Key: "",
|
||||
Encryptor: EncryptCookie,
|
||||
Decryptor: DecryptCookie,
|
||||
}
|
||||
|
||||
// Helper function to set default values
|
||||
func configDefault(config ...Config) Config {
|
||||
// Set default config
|
||||
cfg := ConfigDefault
|
||||
|
||||
// Override config if provided
|
||||
if len(config) > 0 {
|
||||
cfg = config[0]
|
||||
|
||||
// Set default values
|
||||
|
||||
if cfg.Next == nil {
|
||||
cfg.Next = ConfigDefault.Next
|
||||
}
|
||||
|
||||
if cfg.Except == nil {
|
||||
cfg.Except = ConfigDefault.Except
|
||||
}
|
||||
|
||||
if cfg.Encryptor == nil {
|
||||
cfg.Encryptor = ConfigDefault.Encryptor
|
||||
}
|
||||
|
||||
if cfg.Decryptor == nil {
|
||||
cfg.Decryptor = ConfigDefault.Decryptor
|
||||
}
|
||||
}
|
||||
|
||||
if cfg.Key == "" {
|
||||
panic("fiber: encrypt cookie middleware requires key")
|
||||
}
|
||||
|
||||
return cfg
|
||||
}
|
57
middleware/encryptcookie/encryptcookie.go
Normal file
57
middleware/encryptcookie/encryptcookie.go
Normal file
|
@ -0,0 +1,57 @@
|
|||
package encryptcookie
|
||||
|
||||
import (
|
||||
"github.com/gofiber/fiber/v2"
|
||||
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
// New creates a new middleware handler
|
||||
func New(config ...Config) fiber.Handler {
|
||||
// Set default config
|
||||
cfg := configDefault(config...)
|
||||
|
||||
// Return new handler
|
||||
return func(c *fiber.Ctx) error {
|
||||
// Don't execute middleware if Next returns true
|
||||
if cfg.Next != nil && cfg.Next(c) {
|
||||
return c.Next()
|
||||
}
|
||||
|
||||
// Decrypt request cookies
|
||||
c.Request().Header.VisitAllCookie(func(key, value []byte) {
|
||||
keyString := string(key)
|
||||
if !isDisabled(keyString, cfg.Except) {
|
||||
decryptedValue, err := cfg.Decryptor(string(value), cfg.Key)
|
||||
if err != nil {
|
||||
c.Request().Header.SetCookieBytesKV(key, nil)
|
||||
} else {
|
||||
c.Request().Header.SetCookie(string(key), decryptedValue)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
// Continue stack
|
||||
err := c.Next()
|
||||
|
||||
// Encrypt response cookies
|
||||
c.Response().Header.VisitAllCookie(func(key, value []byte) {
|
||||
keyString := string(key)
|
||||
if !isDisabled(keyString, cfg.Except) {
|
||||
cookieValue := fasthttp.Cookie{}
|
||||
cookieValue.SetKeyBytes(key)
|
||||
if c.Response().Header.Cookie(&cookieValue) {
|
||||
encryptedValue, err := cfg.Encryptor(string(cookieValue.Value()), cfg.Key)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
cookieValue.SetValue(encryptedValue)
|
||||
c.Response().Header.SetCookie(&cookieValue)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
return err
|
||||
}
|
||||
}
|
192
middleware/encryptcookie/encryptcookie_test.go
Normal file
192
middleware/encryptcookie/encryptcookie_test.go
Normal file
|
@ -0,0 +1,192 @@
|
|||
package encryptcookie
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/gofiber/fiber/v2/utils"
|
||||
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
var testKey = GenerateKey()
|
||||
|
||||
func Test_Middleware_Encrypt_Cookie(t *testing.T) {
|
||||
t.Parallel()
|
||||
app := fiber.New()
|
||||
|
||||
app.Use(New(Config{
|
||||
Key: testKey,
|
||||
}))
|
||||
|
||||
app.Get("/", func(c *fiber.Ctx) error {
|
||||
return c.SendString("value=" + c.Cookies("test"))
|
||||
})
|
||||
app.Post("/", func(c *fiber.Ctx) error {
|
||||
c.Cookie(&fiber.Cookie{
|
||||
Name: "test",
|
||||
Value: "SomeThing",
|
||||
})
|
||||
return nil
|
||||
})
|
||||
|
||||
h := app.Handler()
|
||||
|
||||
// Test empty cookie
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
ctx.Request.Header.SetMethod(fiber.MethodGet)
|
||||
h(ctx)
|
||||
utils.AssertEqual(t, 200, ctx.Response.StatusCode())
|
||||
utils.AssertEqual(t, "value=", string(ctx.Response.Body()))
|
||||
|
||||
// Test invalid cookie
|
||||
ctx = &fasthttp.RequestCtx{}
|
||||
ctx.Request.Header.SetMethod(fiber.MethodGet)
|
||||
ctx.Request.Header.SetCookie("test", "Invalid")
|
||||
h(ctx)
|
||||
utils.AssertEqual(t, 200, ctx.Response.StatusCode())
|
||||
utils.AssertEqual(t, "value=", string(ctx.Response.Body()))
|
||||
ctx.Request.Header.SetCookie("test", "ixQURE2XOyZUs0WAOh2ehjWcP7oZb07JvnhWOsmeNUhPsj4+RyI=")
|
||||
h(ctx)
|
||||
utils.AssertEqual(t, 200, ctx.Response.StatusCode())
|
||||
utils.AssertEqual(t, "value=", string(ctx.Response.Body()))
|
||||
|
||||
// Test valid cookie
|
||||
ctx = &fasthttp.RequestCtx{}
|
||||
ctx.Request.Header.SetMethod(fiber.MethodPost)
|
||||
h(ctx)
|
||||
utils.AssertEqual(t, 200, ctx.Response.StatusCode())
|
||||
|
||||
encryptedCookie := fasthttp.Cookie{}
|
||||
encryptedCookie.SetKey("test")
|
||||
utils.AssertEqual(t, true, ctx.Response.Header.Cookie(&encryptedCookie), "Get cookie value")
|
||||
decryptedCookieValue, err := DecryptCookie(string(encryptedCookie.Value()), testKey)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, "SomeThing", decryptedCookieValue)
|
||||
|
||||
ctx = &fasthttp.RequestCtx{}
|
||||
ctx.Request.Header.SetMethod(fiber.MethodGet)
|
||||
ctx.Request.Header.SetCookie("test", string(encryptedCookie.Value()))
|
||||
h(ctx)
|
||||
utils.AssertEqual(t, 200, ctx.Response.StatusCode())
|
||||
utils.AssertEqual(t, "value=SomeThing", string(ctx.Response.Body()))
|
||||
}
|
||||
|
||||
func Test_Encrypt_Cookie_Next(t *testing.T) {
|
||||
t.Parallel()
|
||||
app := fiber.New()
|
||||
|
||||
app.Use(New(Config{
|
||||
Key: testKey,
|
||||
Next: func(_ *fiber.Ctx) bool {
|
||||
return true
|
||||
},
|
||||
}))
|
||||
|
||||
app.Get("/", func(c *fiber.Ctx) error {
|
||||
c.Cookie(&fiber.Cookie{
|
||||
Name: "test",
|
||||
Value: "SomeThing",
|
||||
})
|
||||
return nil
|
||||
})
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, "SomeThing", resp.Cookies()[0].Value)
|
||||
}
|
||||
|
||||
func Test_Encrypt_Cookie_Except(t *testing.T) {
|
||||
t.Parallel()
|
||||
app := fiber.New()
|
||||
|
||||
app.Use(New(Config{
|
||||
Key: testKey,
|
||||
Except: []string{
|
||||
"test1",
|
||||
},
|
||||
}))
|
||||
|
||||
app.Get("/", func(c *fiber.Ctx) error {
|
||||
c.Cookie(&fiber.Cookie{
|
||||
Name: "test1",
|
||||
Value: "SomeThing",
|
||||
})
|
||||
c.Cookie(&fiber.Cookie{
|
||||
Name: "test2",
|
||||
Value: "SomeThing",
|
||||
})
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
h := app.Handler()
|
||||
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
ctx.Request.Header.SetMethod(fiber.MethodGet)
|
||||
h(ctx)
|
||||
utils.AssertEqual(t, 200, ctx.Response.StatusCode())
|
||||
|
||||
rawCookie := fasthttp.Cookie{}
|
||||
rawCookie.SetKey("test1")
|
||||
utils.AssertEqual(t, true, ctx.Response.Header.Cookie(&rawCookie), "Get cookie value")
|
||||
utils.AssertEqual(t, "SomeThing", string(rawCookie.Value()))
|
||||
|
||||
encryptedCookie := fasthttp.Cookie{}
|
||||
encryptedCookie.SetKey("test2")
|
||||
utils.AssertEqual(t, true, ctx.Response.Header.Cookie(&encryptedCookie), "Get cookie value")
|
||||
decryptedCookieValue, err := DecryptCookie(string(encryptedCookie.Value()), testKey)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, "SomeThing", decryptedCookieValue)
|
||||
}
|
||||
|
||||
func Test_Encrypt_Cookie_Custom_Encryptor(t *testing.T) {
|
||||
t.Parallel()
|
||||
app := fiber.New()
|
||||
|
||||
app.Use(New(Config{
|
||||
Key: testKey,
|
||||
Encryptor: func(decryptedString, _ string) (string, error) {
|
||||
return base64.StdEncoding.EncodeToString([]byte(decryptedString)), nil
|
||||
},
|
||||
Decryptor: func(encryptedString, _ string) (string, error) {
|
||||
decodedBytes, err := base64.StdEncoding.DecodeString(encryptedString)
|
||||
return string(decodedBytes), err
|
||||
},
|
||||
}))
|
||||
|
||||
app.Get("/", func(c *fiber.Ctx) error {
|
||||
return c.SendString("value=" + c.Cookies("test"))
|
||||
})
|
||||
app.Post("/", func(c *fiber.Ctx) error {
|
||||
c.Cookie(&fiber.Cookie{
|
||||
Name: "test",
|
||||
Value: "SomeThing",
|
||||
})
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
h := app.Handler()
|
||||
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
ctx.Request.Header.SetMethod(fiber.MethodPost)
|
||||
h(ctx)
|
||||
utils.AssertEqual(t, 200, ctx.Response.StatusCode())
|
||||
|
||||
encryptedCookie := fasthttp.Cookie{}
|
||||
encryptedCookie.SetKey("test")
|
||||
utils.AssertEqual(t, true, ctx.Response.Header.Cookie(&encryptedCookie), "Get cookie value")
|
||||
decodedBytes, err := base64.StdEncoding.DecodeString(string(encryptedCookie.Value()))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, "SomeThing", string(decodedBytes))
|
||||
|
||||
ctx = &fasthttp.RequestCtx{}
|
||||
ctx.Request.Header.SetMethod(fiber.MethodGet)
|
||||
ctx.Request.Header.SetCookie("test", string(encryptedCookie.Value()))
|
||||
h(ctx)
|
||||
utils.AssertEqual(t, 200, ctx.Response.StatusCode())
|
||||
utils.AssertEqual(t, "value=SomeThing", string(ctx.Response.Body()))
|
||||
}
|
98
middleware/encryptcookie/utils.go
Normal file
98
middleware/encryptcookie/utils.go
Normal file
|
@ -0,0 +1,98 @@
|
|||
package encryptcookie
|
||||
|
||||
import (
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
)
|
||||
|
||||
// EncryptCookie Encrypts a cookie value with specific encryption key
|
||||
func EncryptCookie(value, key string) (string, error) {
|
||||
keyDecoded, err := base64.StdEncoding.DecodeString(key)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to base64-decode key: %w", err)
|
||||
}
|
||||
|
||||
block, err := aes.NewCipher(keyDecoded)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create AES cipher: %w", err)
|
||||
}
|
||||
|
||||
gcm, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create GCM mode: %w", err)
|
||||
}
|
||||
|
||||
nonce := make([]byte, gcm.NonceSize())
|
||||
if _, err = io.ReadFull(rand.Reader, nonce); err != nil {
|
||||
return "", fmt.Errorf("failed to read: %w", err)
|
||||
}
|
||||
|
||||
ciphertext := gcm.Seal(nonce, nonce, []byte(value), nil)
|
||||
|
||||
return base64.StdEncoding.EncodeToString(ciphertext), nil
|
||||
}
|
||||
|
||||
// DecryptCookie Decrypts a cookie value with specific encryption key
|
||||
func DecryptCookie(value, key string) (string, error) {
|
||||
keyDecoded, err := base64.StdEncoding.DecodeString(key)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to base64-decode key: %w", err)
|
||||
}
|
||||
enc, err := base64.StdEncoding.DecodeString(value)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to base64-decode value: %w", err)
|
||||
}
|
||||
|
||||
block, err := aes.NewCipher(keyDecoded)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create AES cipher: %w", err)
|
||||
}
|
||||
|
||||
gcm, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create GCM mode: %w", err)
|
||||
}
|
||||
|
||||
nonceSize := gcm.NonceSize()
|
||||
|
||||
if len(enc) < nonceSize {
|
||||
return "", errors.New("encrypted value is not valid")
|
||||
}
|
||||
|
||||
nonce, ciphertext := enc[:nonceSize], enc[nonceSize:]
|
||||
|
||||
plaintext, err := gcm.Open(nil, nonce, ciphertext, nil)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to decrypt ciphertext: %w", err)
|
||||
}
|
||||
|
||||
return string(plaintext), nil
|
||||
}
|
||||
|
||||
// GenerateKey Generates an encryption key
|
||||
func GenerateKey() string {
|
||||
const keyLen = 32
|
||||
ret := make([]byte, keyLen)
|
||||
|
||||
if _, err := rand.Read(ret); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return base64.StdEncoding.EncodeToString(ret)
|
||||
}
|
||||
|
||||
// Check given cookie key is disabled for encryption or not
|
||||
func isDisabled(key string, except []string) bool {
|
||||
for _, k := range except {
|
||||
if key == k {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
68
middleware/envvar/envvar.go
Normal file
68
middleware/envvar/envvar.go
Normal file
|
@ -0,0 +1,68 @@
|
|||
package envvar
|
||||
|
||||
import (
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
)
|
||||
|
||||
// Config defines the config for middleware.
|
||||
type Config struct {
|
||||
// ExportVars specifies the environment variables that should export
|
||||
ExportVars map[string]string
|
||||
// ExcludeVars specifies the environment variables that should not export
|
||||
ExcludeVars map[string]string
|
||||
}
|
||||
|
||||
type EnvVar struct {
|
||||
Vars map[string]string `json:"vars"`
|
||||
}
|
||||
|
||||
func (envVar *EnvVar) set(key, val string) {
|
||||
envVar.Vars[key] = val
|
||||
}
|
||||
|
||||
func New(config ...Config) fiber.Handler {
|
||||
var cfg Config
|
||||
if len(config) > 0 {
|
||||
cfg = config[0]
|
||||
}
|
||||
|
||||
return func(c *fiber.Ctx) error {
|
||||
if c.Method() != fiber.MethodGet {
|
||||
return fiber.ErrMethodNotAllowed
|
||||
}
|
||||
|
||||
envVar := newEnvVar(cfg)
|
||||
varsByte, err := c.App().Config().JSONEncoder(envVar)
|
||||
if err != nil {
|
||||
return c.Status(fiber.StatusInternalServerError).SendString(err.Error())
|
||||
}
|
||||
c.Set(fiber.HeaderContentType, fiber.MIMEApplicationJSONCharsetUTF8)
|
||||
return c.Send(varsByte)
|
||||
}
|
||||
}
|
||||
|
||||
func newEnvVar(cfg Config) *EnvVar {
|
||||
vars := &EnvVar{Vars: make(map[string]string)}
|
||||
|
||||
if len(cfg.ExportVars) > 0 {
|
||||
for key, defaultVal := range cfg.ExportVars {
|
||||
vars.set(key, defaultVal)
|
||||
if envVal, exists := os.LookupEnv(key); exists {
|
||||
vars.set(key, envVal)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
const numElems = 2
|
||||
for _, envVal := range os.Environ() {
|
||||
keyVal := strings.SplitN(envVal, "=", numElems)
|
||||
if _, exists := cfg.ExcludeVars[keyVal[0]]; !exists {
|
||||
vars.set(keyVal[0], keyVal[1])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return vars
|
||||
}
|
172
middleware/envvar/envvar_test.go
Normal file
172
middleware/envvar/envvar_test.go
Normal file
|
@ -0,0 +1,172 @@
|
|||
//nolint:bodyclose // Much easier to just ignore memory leaks in tests
|
||||
package envvar
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/gofiber/fiber/v2/utils"
|
||||
)
|
||||
|
||||
func TestEnvVarStructWithExportVarsExcludeVars(t *testing.T) {
|
||||
err := os.Setenv("testKey", "testEnvValue")
|
||||
utils.AssertEqual(t, nil, err)
|
||||
err = os.Setenv("anotherEnvKey", "anotherEnvVal")
|
||||
utils.AssertEqual(t, nil, err)
|
||||
err = os.Setenv("excludeKey", "excludeEnvValue")
|
||||
utils.AssertEqual(t, nil, err)
|
||||
defer func() {
|
||||
err := os.Unsetenv("testKey")
|
||||
utils.AssertEqual(t, nil, err)
|
||||
err = os.Unsetenv("anotherEnvKey")
|
||||
utils.AssertEqual(t, nil, err)
|
||||
err = os.Unsetenv("excludeKey")
|
||||
utils.AssertEqual(t, nil, err)
|
||||
}()
|
||||
|
||||
vars := newEnvVar(Config{
|
||||
ExportVars: map[string]string{"testKey": "", "testDefaultKey": "testDefaultVal"},
|
||||
ExcludeVars: map[string]string{"excludeKey": ""},
|
||||
})
|
||||
|
||||
utils.AssertEqual(t, vars.Vars["testKey"], "testEnvValue")
|
||||
utils.AssertEqual(t, vars.Vars["testDefaultKey"], "testDefaultVal")
|
||||
utils.AssertEqual(t, vars.Vars["excludeKey"], "")
|
||||
utils.AssertEqual(t, vars.Vars["anotherEnvKey"], "")
|
||||
}
|
||||
|
||||
func TestEnvVarHandler(t *testing.T) {
|
||||
err := os.Setenv("testKey", "testVal")
|
||||
utils.AssertEqual(t, nil, err)
|
||||
defer func() {
|
||||
err := os.Unsetenv("testKey")
|
||||
utils.AssertEqual(t, nil, err)
|
||||
}()
|
||||
|
||||
expectedEnvVarResponse, err := json.Marshal(
|
||||
struct {
|
||||
Vars map[string]string `json:"vars"`
|
||||
}{
|
||||
map[string]string{"testKey": "testVal"},
|
||||
})
|
||||
utils.AssertEqual(t, nil, err)
|
||||
|
||||
app := fiber.New()
|
||||
app.Use("/envvars", New(Config{
|
||||
ExportVars: map[string]string{"testKey": ""},
|
||||
}))
|
||||
|
||||
req, err := http.NewRequestWithContext(context.Background(), fiber.MethodGet, "http://localhost/envvars", nil)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
resp, err := app.Test(req)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
|
||||
utils.AssertEqual(t, expectedEnvVarResponse, respBody)
|
||||
}
|
||||
|
||||
func TestEnvVarHandlerNotMatched(t *testing.T) {
|
||||
app := fiber.New()
|
||||
app.Use("/envvars", New(Config{
|
||||
ExportVars: map[string]string{"testKey": ""},
|
||||
}))
|
||||
|
||||
app.Get("/another-path", func(ctx *fiber.Ctx) error {
|
||||
utils.AssertEqual(t, nil, ctx.SendString("OK"))
|
||||
return nil
|
||||
})
|
||||
|
||||
req, err := http.NewRequestWithContext(context.Background(), fiber.MethodGet, "http://localhost/another-path", nil)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
resp, err := app.Test(req)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
|
||||
utils.AssertEqual(t, []byte("OK"), respBody)
|
||||
}
|
||||
|
||||
func TestEnvVarHandlerDefaultConfig(t *testing.T) {
|
||||
err := os.Setenv("testEnvKey", "testEnvVal")
|
||||
utils.AssertEqual(t, nil, err)
|
||||
defer func() {
|
||||
err := os.Unsetenv("testEnvKey")
|
||||
utils.AssertEqual(t, nil, err)
|
||||
}()
|
||||
|
||||
app := fiber.New()
|
||||
app.Use("/envvars", New())
|
||||
|
||||
req, err := http.NewRequestWithContext(context.Background(), fiber.MethodGet, "http://localhost/envvars", nil)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
resp, err := app.Test(req)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
|
||||
var envVars EnvVar
|
||||
utils.AssertEqual(t, nil, json.Unmarshal(respBody, &envVars))
|
||||
val := envVars.Vars["testEnvKey"]
|
||||
utils.AssertEqual(t, "testEnvVal", val)
|
||||
}
|
||||
|
||||
func TestEnvVarHandlerMethod(t *testing.T) {
|
||||
app := fiber.New()
|
||||
app.Use("/envvars", New())
|
||||
|
||||
req, err := http.NewRequestWithContext(context.Background(), fiber.MethodPost, "http://localhost/envvars", nil)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
resp, err := app.Test(req)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, fiber.StatusMethodNotAllowed, resp.StatusCode)
|
||||
}
|
||||
|
||||
func TestEnvVarHandlerSpecialValue(t *testing.T) {
|
||||
testEnvKey := "testEnvKey"
|
||||
fakeBase64 := "testBase64:TQ=="
|
||||
err := os.Setenv(testEnvKey, fakeBase64)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
defer func() {
|
||||
err := os.Unsetenv(testEnvKey)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
}()
|
||||
|
||||
app := fiber.New()
|
||||
app.Use("/envvars", New())
|
||||
app.Use("/envvars/export", New(Config{ExportVars: map[string]string{testEnvKey: ""}}))
|
||||
|
||||
req, err := http.NewRequestWithContext(context.Background(), fiber.MethodGet, "http://localhost/envvars", nil)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
resp, err := app.Test(req)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
|
||||
var envVars EnvVar
|
||||
utils.AssertEqual(t, nil, json.Unmarshal(respBody, &envVars))
|
||||
val := envVars.Vars[testEnvKey]
|
||||
utils.AssertEqual(t, fakeBase64, val)
|
||||
|
||||
req, err = http.NewRequestWithContext(context.Background(), fiber.MethodGet, "http://localhost/envvars/export", nil)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
resp, err = app.Test(req)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
|
||||
respBody, err = io.ReadAll(resp.Body)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
|
||||
var envVarsExport EnvVar
|
||||
utils.AssertEqual(t, nil, json.Unmarshal(respBody, &envVarsExport))
|
||||
val = envVarsExport.Vars[testEnvKey]
|
||||
utils.AssertEqual(t, fakeBase64, val)
|
||||
}
|
44
middleware/etag/config.go
Normal file
44
middleware/etag/config.go
Normal file
|
@ -0,0 +1,44 @@
|
|||
package etag
|
||||
|
||||
import (
|
||||
"github.com/gofiber/fiber/v2"
|
||||
)
|
||||
|
||||
// Config defines the config for middleware.
|
||||
type Config struct {
|
||||
// Weak indicates that a weak validator is used. Weak etags are easy
|
||||
// to generate, but are far less useful for comparisons. Strong
|
||||
// validators are ideal for comparisons but can be very difficult
|
||||
// to generate efficiently. Weak ETag values of two representations
|
||||
// of the same resources might be semantically equivalent, but not
|
||||
// byte-for-byte identical. This means weak etags prevent caching
|
||||
// when byte range requests are used, but strong etags mean range
|
||||
// requests can still be cached.
|
||||
Weak bool
|
||||
|
||||
// Next defines a function to skip this middleware when returned true.
|
||||
//
|
||||
// Optional. Default: nil
|
||||
Next func(c *fiber.Ctx) bool
|
||||
}
|
||||
|
||||
// ConfigDefault is the default config
|
||||
var ConfigDefault = Config{
|
||||
Weak: false,
|
||||
Next: nil,
|
||||
}
|
||||
|
||||
// Helper function to set default values
|
||||
func configDefault(config ...Config) Config {
|
||||
// Return default config if nothing provided
|
||||
if len(config) < 1 {
|
||||
return ConfigDefault
|
||||
}
|
||||
|
||||
// Override default config
|
||||
cfg := config[0]
|
||||
|
||||
// Set default values
|
||||
|
||||
return cfg
|
||||
}
|
116
middleware/etag/etag.go
Normal file
116
middleware/etag/etag.go
Normal file
|
@ -0,0 +1,116 @@
|
|||
package etag
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"hash/crc32"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
|
||||
"github.com/valyala/bytebufferpool"
|
||||
)
|
||||
|
||||
// New creates a new middleware handler
|
||||
func New(config ...Config) fiber.Handler {
|
||||
// Set default config
|
||||
cfg := configDefault(config...)
|
||||
|
||||
var (
|
||||
normalizedHeaderETag = []byte("Etag")
|
||||
weakPrefix = []byte("W/")
|
||||
)
|
||||
|
||||
const crcPol = 0xD5828281
|
||||
crc32q := crc32.MakeTable(crcPol)
|
||||
|
||||
// Return new handler
|
||||
return func(c *fiber.Ctx) error {
|
||||
// Don't execute middleware if Next returns true
|
||||
if cfg.Next != nil && cfg.Next(c) {
|
||||
return c.Next()
|
||||
}
|
||||
|
||||
// Return err if next handler returns one
|
||||
if err := c.Next(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Don't generate ETags for invalid responses
|
||||
if c.Response().StatusCode() != fiber.StatusOK {
|
||||
return nil
|
||||
}
|
||||
body := c.Response().Body()
|
||||
// Skips ETag if no response body is present
|
||||
if len(body) == 0 {
|
||||
return nil
|
||||
}
|
||||
// Skip ETag if header is already present
|
||||
if c.Response().Header.PeekBytes(normalizedHeaderETag) != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Generate ETag for response
|
||||
bb := bytebufferpool.Get()
|
||||
defer bytebufferpool.Put(bb)
|
||||
|
||||
// Enable weak tag
|
||||
if cfg.Weak {
|
||||
_, _ = bb.Write(weakPrefix) //nolint:errcheck // This will never fail
|
||||
}
|
||||
|
||||
_ = bb.WriteByte('"') //nolint:errcheck // This will never fail
|
||||
bb.B = appendUint(bb.Bytes(), uint32(len(body)))
|
||||
_ = bb.WriteByte('-') //nolint:errcheck // This will never fail
|
||||
bb.B = appendUint(bb.Bytes(), crc32.Checksum(body, crc32q))
|
||||
_ = bb.WriteByte('"') //nolint:errcheck // This will never fail
|
||||
|
||||
etag := bb.Bytes()
|
||||
|
||||
// Get ETag header from request
|
||||
clientEtag := c.Request().Header.Peek(fiber.HeaderIfNoneMatch)
|
||||
|
||||
// Check if client's ETag is weak
|
||||
if bytes.HasPrefix(clientEtag, weakPrefix) {
|
||||
// Check if server's ETag is weak
|
||||
if bytes.Equal(clientEtag[2:], etag) || bytes.Equal(clientEtag[2:], etag[2:]) {
|
||||
// W/1 == 1 || W/1 == W/1
|
||||
c.Context().ResetBody()
|
||||
|
||||
return c.SendStatus(fiber.StatusNotModified)
|
||||
}
|
||||
// W/1 != W/2 || W/1 != 2
|
||||
c.Response().Header.SetCanonical(normalizedHeaderETag, etag)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
if bytes.Contains(clientEtag, etag) {
|
||||
// 1 == 1
|
||||
c.Context().ResetBody()
|
||||
|
||||
return c.SendStatus(fiber.StatusNotModified)
|
||||
}
|
||||
// 1 != 2
|
||||
c.Response().Header.SetCanonical(normalizedHeaderETag, etag)
|
||||
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// appendUint appends n to dst and returns the extended dst.
|
||||
func appendUint(dst []byte, n uint32) []byte {
|
||||
var b [20]byte
|
||||
buf := b[:]
|
||||
i := len(buf)
|
||||
var q uint32
|
||||
for n >= 10 {
|
||||
i--
|
||||
q = n / 10
|
||||
buf[i] = '0' + byte(n-q*10)
|
||||
n = q
|
||||
}
|
||||
i--
|
||||
buf[i] = '0' + byte(n)
|
||||
|
||||
dst = append(dst, buf[i:]...)
|
||||
return dst
|
||||
}
|
291
middleware/etag/etag_test.go
Normal file
291
middleware/etag/etag_test.go
Normal file
|
@ -0,0 +1,291 @@
|
|||
package etag
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/gofiber/fiber/v2/utils"
|
||||
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
// go test -run Test_ETag_Next
|
||||
func Test_ETag_Next(t *testing.T) {
|
||||
t.Parallel()
|
||||
app := fiber.New()
|
||||
app.Use(New(Config{
|
||||
Next: func(_ *fiber.Ctx) bool {
|
||||
return true
|
||||
},
|
||||
}))
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, fiber.StatusNotFound, resp.StatusCode)
|
||||
}
|
||||
|
||||
// go test -run Test_ETag_SkipError
|
||||
func Test_ETag_SkipError(t *testing.T) {
|
||||
t.Parallel()
|
||||
app := fiber.New()
|
||||
|
||||
app.Use(New())
|
||||
|
||||
app.Get("/", func(c *fiber.Ctx) error {
|
||||
return fiber.ErrForbidden
|
||||
})
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, fiber.StatusForbidden, resp.StatusCode)
|
||||
}
|
||||
|
||||
// go test -run Test_ETag_NotStatusOK
|
||||
func Test_ETag_NotStatusOK(t *testing.T) {
|
||||
t.Parallel()
|
||||
app := fiber.New()
|
||||
|
||||
app.Use(New())
|
||||
|
||||
app.Get("/", func(c *fiber.Ctx) error {
|
||||
return c.SendStatus(fiber.StatusCreated)
|
||||
})
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, fiber.StatusCreated, resp.StatusCode)
|
||||
}
|
||||
|
||||
// go test -run Test_ETag_NoBody
|
||||
func Test_ETag_NoBody(t *testing.T) {
|
||||
t.Parallel()
|
||||
app := fiber.New()
|
||||
|
||||
app.Use(New())
|
||||
|
||||
app.Get("/", func(c *fiber.Ctx) error {
|
||||
return nil
|
||||
})
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
|
||||
}
|
||||
|
||||
// go test -run Test_ETag_NewEtag
|
||||
func Test_ETag_NewEtag(t *testing.T) {
|
||||
t.Parallel()
|
||||
t.Run("without HeaderIfNoneMatch", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
testETagNewEtag(t, false, false)
|
||||
})
|
||||
t.Run("with HeaderIfNoneMatch and not matched", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
testETagNewEtag(t, true, false)
|
||||
})
|
||||
t.Run("with HeaderIfNoneMatch and matched", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
testETagNewEtag(t, true, true)
|
||||
})
|
||||
}
|
||||
|
||||
func testETagNewEtag(t *testing.T, headerIfNoneMatch, matched bool) { //nolint:revive // We're in a test, so using bools as a flow-control is fine
|
||||
t.Helper()
|
||||
|
||||
app := fiber.New()
|
||||
|
||||
app.Use(New())
|
||||
|
||||
app.Get("/", func(c *fiber.Ctx) error {
|
||||
return c.SendString("Hello, World!")
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(fiber.MethodGet, "/", nil)
|
||||
if headerIfNoneMatch {
|
||||
etag := `"non-match"`
|
||||
if matched {
|
||||
etag = `"13-1831710635"`
|
||||
}
|
||||
req.Header.Set(fiber.HeaderIfNoneMatch, etag)
|
||||
}
|
||||
|
||||
resp, err := app.Test(req)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
|
||||
if !headerIfNoneMatch || !matched {
|
||||
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
|
||||
utils.AssertEqual(t, `"13-1831710635"`, resp.Header.Get(fiber.HeaderETag))
|
||||
return
|
||||
}
|
||||
|
||||
if matched {
|
||||
utils.AssertEqual(t, fiber.StatusNotModified, resp.StatusCode)
|
||||
b, err := io.ReadAll(resp.Body)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, 0, len(b))
|
||||
}
|
||||
}
|
||||
|
||||
// go test -run Test_ETag_WeakEtag
|
||||
func Test_ETag_WeakEtag(t *testing.T) {
|
||||
t.Parallel()
|
||||
t.Run("without HeaderIfNoneMatch", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
testETagWeakEtag(t, false, false)
|
||||
})
|
||||
t.Run("with HeaderIfNoneMatch and not matched", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
testETagWeakEtag(t, true, false)
|
||||
})
|
||||
t.Run("with HeaderIfNoneMatch and matched", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
testETagWeakEtag(t, true, true)
|
||||
})
|
||||
}
|
||||
|
||||
func testETagWeakEtag(t *testing.T, headerIfNoneMatch, matched bool) { //nolint:revive // We're in a test, so using bools as a flow-control is fine
|
||||
t.Helper()
|
||||
|
||||
app := fiber.New()
|
||||
|
||||
app.Use(New(Config{Weak: true}))
|
||||
|
||||
app.Get("/", func(c *fiber.Ctx) error {
|
||||
return c.SendString("Hello, World!")
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(fiber.MethodGet, "/", nil)
|
||||
if headerIfNoneMatch {
|
||||
etag := `W/"non-match"`
|
||||
if matched {
|
||||
etag = `W/"13-1831710635"`
|
||||
}
|
||||
req.Header.Set(fiber.HeaderIfNoneMatch, etag)
|
||||
}
|
||||
|
||||
resp, err := app.Test(req)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
|
||||
if !headerIfNoneMatch || !matched {
|
||||
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
|
||||
utils.AssertEqual(t, `W/"13-1831710635"`, resp.Header.Get(fiber.HeaderETag))
|
||||
return
|
||||
}
|
||||
|
||||
if matched {
|
||||
utils.AssertEqual(t, fiber.StatusNotModified, resp.StatusCode)
|
||||
b, err := io.ReadAll(resp.Body)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, 0, len(b))
|
||||
}
|
||||
}
|
||||
|
||||
// go test -run Test_ETag_CustomEtag
|
||||
func Test_ETag_CustomEtag(t *testing.T) {
|
||||
t.Parallel()
|
||||
t.Run("without HeaderIfNoneMatch", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
testETagCustomEtag(t, false, false)
|
||||
})
|
||||
t.Run("with HeaderIfNoneMatch and not matched", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
testETagCustomEtag(t, true, false)
|
||||
})
|
||||
t.Run("with HeaderIfNoneMatch and matched", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
testETagCustomEtag(t, true, true)
|
||||
})
|
||||
}
|
||||
|
||||
func testETagCustomEtag(t *testing.T, headerIfNoneMatch, matched bool) { //nolint:revive // We're in a test, so using bools as a flow-control is fine
|
||||
t.Helper()
|
||||
|
||||
app := fiber.New()
|
||||
|
||||
app.Use(New())
|
||||
|
||||
app.Get("/", func(c *fiber.Ctx) error {
|
||||
c.Set(fiber.HeaderETag, `"custom"`)
|
||||
if bytes.Equal(c.Request().Header.Peek(fiber.HeaderIfNoneMatch), []byte(`"custom"`)) {
|
||||
return c.SendStatus(fiber.StatusNotModified)
|
||||
}
|
||||
return c.SendString("Hello, World!")
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(fiber.MethodGet, "/", nil)
|
||||
if headerIfNoneMatch {
|
||||
etag := `"non-match"`
|
||||
if matched {
|
||||
etag = `"custom"`
|
||||
}
|
||||
req.Header.Set(fiber.HeaderIfNoneMatch, etag)
|
||||
}
|
||||
|
||||
resp, err := app.Test(req)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
|
||||
if !headerIfNoneMatch || !matched {
|
||||
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
|
||||
utils.AssertEqual(t, `"custom"`, resp.Header.Get(fiber.HeaderETag))
|
||||
return
|
||||
}
|
||||
|
||||
if matched {
|
||||
utils.AssertEqual(t, fiber.StatusNotModified, resp.StatusCode)
|
||||
b, err := io.ReadAll(resp.Body)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, 0, len(b))
|
||||
}
|
||||
}
|
||||
|
||||
// go test -run Test_ETag_CustomEtagPut
|
||||
func Test_ETag_CustomEtagPut(t *testing.T) {
|
||||
t.Parallel()
|
||||
app := fiber.New()
|
||||
|
||||
app.Use(New())
|
||||
|
||||
app.Put("/", func(c *fiber.Ctx) error {
|
||||
c.Set(fiber.HeaderETag, `"custom"`)
|
||||
if !bytes.Equal(c.Request().Header.Peek(fiber.HeaderIfMatch), []byte(`"custom"`)) {
|
||||
return c.SendStatus(fiber.StatusPreconditionFailed)
|
||||
}
|
||||
return c.SendString("Hello, World!")
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(fiber.MethodPut, "/", nil)
|
||||
req.Header.Set(fiber.HeaderIfMatch, `"non-match"`)
|
||||
resp, err := app.Test(req)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, fiber.StatusPreconditionFailed, resp.StatusCode)
|
||||
}
|
||||
|
||||
// go test -v -run=^$ -bench=Benchmark_Etag -benchmem -count=4
|
||||
func Benchmark_Etag(b *testing.B) {
|
||||
app := fiber.New()
|
||||
|
||||
app.Use(New())
|
||||
|
||||
app.Get("/", func(c *fiber.Ctx) error {
|
||||
return c.SendString("Hello, World!")
|
||||
})
|
||||
|
||||
h := app.Handler()
|
||||
|
||||
fctx := &fasthttp.RequestCtx{}
|
||||
fctx.Request.Header.SetMethod(fiber.MethodGet)
|
||||
fctx.Request.SetRequestURI("/")
|
||||
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
|
||||
for n := 0; n < b.N; n++ {
|
||||
h(fctx)
|
||||
}
|
||||
|
||||
utils.AssertEqual(b, 200, fctx.Response.Header.StatusCode())
|
||||
utils.AssertEqual(b, `"13-1831710635"`, string(fctx.Response.Header.Peek(fiber.HeaderETag)))
|
||||
}
|
34
middleware/expvar/config.go
Normal file
34
middleware/expvar/config.go
Normal file
|
@ -0,0 +1,34 @@
|
|||
package expvar
|
||||
|
||||
import (
|
||||
"github.com/gofiber/fiber/v2"
|
||||
)
|
||||
|
||||
// Config defines the config for middleware.
|
||||
type Config struct {
|
||||
// Next defines a function to skip this middleware when returned true.
|
||||
//
|
||||
// Optional. Default: nil
|
||||
Next func(c *fiber.Ctx) bool
|
||||
}
|
||||
|
||||
var ConfigDefault = Config{
|
||||
Next: nil,
|
||||
}
|
||||
|
||||
func configDefault(config ...Config) Config {
|
||||
// Return default config if nothing provided
|
||||
if len(config) < 1 {
|
||||
return ConfigDefault
|
||||
}
|
||||
|
||||
// Override default config
|
||||
cfg := config[0]
|
||||
|
||||
// Set default values
|
||||
if cfg.Next == nil {
|
||||
cfg.Next = ConfigDefault.Next
|
||||
}
|
||||
|
||||
return cfg
|
||||
}
|
35
middleware/expvar/expvar.go
Normal file
35
middleware/expvar/expvar.go
Normal file
|
@ -0,0 +1,35 @@
|
|||
package expvar
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
|
||||
"github.com/valyala/fasthttp/expvarhandler"
|
||||
)
|
||||
|
||||
// New creates a new middleware handler
|
||||
func New(config ...Config) fiber.Handler {
|
||||
// Set default config
|
||||
cfg := configDefault(config...)
|
||||
|
||||
// Return new handler
|
||||
return func(c *fiber.Ctx) error {
|
||||
// Don't execute middleware if Next returns true
|
||||
if cfg.Next != nil && cfg.Next(c) {
|
||||
return c.Next()
|
||||
}
|
||||
|
||||
path := c.Path()
|
||||
// We are only interested in /debug/vars routes
|
||||
if len(path) < 11 || !strings.HasPrefix(path, "/debug/vars") {
|
||||
return c.Next()
|
||||
}
|
||||
if path == "/debug/vars" {
|
||||
expvarhandler.ExpvarHandler(c.Context())
|
||||
return nil
|
||||
}
|
||||
|
||||
return c.Redirect("/debug/vars", fiber.StatusFound)
|
||||
}
|
||||
}
|
103
middleware/expvar/expvar_test.go
Normal file
103
middleware/expvar/expvar_test.go
Normal file
|
@ -0,0 +1,103 @@
|
|||
package expvar
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/gofiber/fiber/v2/utils"
|
||||
)
|
||||
|
||||
func Test_Non_Expvar_Path(t *testing.T) {
|
||||
t.Parallel()
|
||||
app := fiber.New()
|
||||
|
||||
app.Use(New())
|
||||
|
||||
app.Get("/", func(c *fiber.Ctx) error {
|
||||
return c.SendString("escaped")
|
||||
})
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, 200, resp.StatusCode)
|
||||
|
||||
b, err := io.ReadAll(resp.Body)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, "escaped", string(b))
|
||||
}
|
||||
|
||||
func Test_Expvar_Index(t *testing.T) {
|
||||
t.Parallel()
|
||||
app := fiber.New()
|
||||
|
||||
app.Use(New())
|
||||
|
||||
app.Get("/", func(c *fiber.Ctx) error {
|
||||
return c.SendString("escaped")
|
||||
})
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/debug/vars", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, 200, resp.StatusCode)
|
||||
utils.AssertEqual(t, fiber.MIMEApplicationJSONCharsetUTF8, resp.Header.Get(fiber.HeaderContentType))
|
||||
|
||||
b, err := io.ReadAll(resp.Body)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, true, bytes.Contains(b, []byte("cmdline")))
|
||||
utils.AssertEqual(t, true, bytes.Contains(b, []byte("memstat")))
|
||||
}
|
||||
|
||||
func Test_Expvar_Filter(t *testing.T) {
|
||||
t.Parallel()
|
||||
app := fiber.New()
|
||||
|
||||
app.Use(New())
|
||||
|
||||
app.Get("/", func(c *fiber.Ctx) error {
|
||||
return c.SendString("escaped")
|
||||
})
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/debug/vars?r=cmd", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, 200, resp.StatusCode)
|
||||
utils.AssertEqual(t, fiber.MIMEApplicationJSONCharsetUTF8, resp.Header.Get(fiber.HeaderContentType))
|
||||
|
||||
b, err := io.ReadAll(resp.Body)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, true, bytes.Contains(b, []byte("cmdline")))
|
||||
utils.AssertEqual(t, false, bytes.Contains(b, []byte("memstat")))
|
||||
}
|
||||
|
||||
func Test_Expvar_Other_Path(t *testing.T) {
|
||||
t.Parallel()
|
||||
app := fiber.New()
|
||||
|
||||
app.Use(New())
|
||||
|
||||
app.Get("/", func(c *fiber.Ctx) error {
|
||||
return c.SendString("escaped")
|
||||
})
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/debug/vars/302", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, 302, resp.StatusCode)
|
||||
}
|
||||
|
||||
// go test -run Test_Expvar_Next
|
||||
func Test_Expvar_Next(t *testing.T) {
|
||||
t.Parallel()
|
||||
app := fiber.New()
|
||||
|
||||
app.Use(New(Config{
|
||||
Next: func(_ *fiber.Ctx) bool {
|
||||
return true
|
||||
},
|
||||
}))
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/debug/vars", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, 404, resp.StatusCode)
|
||||
}
|
146
middleware/favicon/favicon.go
Normal file
146
middleware/favicon/favicon.go
Normal file
|
@ -0,0 +1,146 @@
|
|||
package favicon
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"strconv"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
)
|
||||
|
||||
// Config defines the config for middleware.
|
||||
type Config struct {
|
||||
// Next defines a function to skip this middleware when returned true.
|
||||
//
|
||||
// Optional. Default: nil
|
||||
Next func(c *fiber.Ctx) bool
|
||||
|
||||
// Raw data of the favicon file
|
||||
//
|
||||
// Optional. Default: nil
|
||||
Data []byte `json:"-"`
|
||||
|
||||
// File holds the path to an actual favicon that will be cached
|
||||
//
|
||||
// Optional. Default: ""
|
||||
File string `json:"file"`
|
||||
|
||||
// URL for favicon handler
|
||||
//
|
||||
// Optional. Default: "/favicon.ico"
|
||||
URL string `json:"url"`
|
||||
|
||||
// FileSystem is an optional alternate filesystem to search for the favicon in.
|
||||
// An example of this could be an embedded or network filesystem
|
||||
//
|
||||
// Optional. Default: nil
|
||||
FileSystem http.FileSystem `json:"-"`
|
||||
|
||||
// CacheControl defines how the Cache-Control header in the response should be set
|
||||
//
|
||||
// Optional. Default: "public, max-age=31536000"
|
||||
CacheControl string `json:"cache_control"`
|
||||
}
|
||||
|
||||
// ConfigDefault is the default config
|
||||
var ConfigDefault = Config{
|
||||
Next: nil,
|
||||
File: "",
|
||||
URL: fPath,
|
||||
CacheControl: "public, max-age=31536000",
|
||||
}
|
||||
|
||||
const (
|
||||
fPath = "/favicon.ico"
|
||||
hType = "image/x-icon"
|
||||
hAllow = "GET, HEAD, OPTIONS"
|
||||
hZero = "0"
|
||||
)
|
||||
|
||||
// New creates a new middleware handler
|
||||
func New(config ...Config) fiber.Handler {
|
||||
// Set default config
|
||||
cfg := ConfigDefault
|
||||
|
||||
// Override config if provided
|
||||
if len(config) > 0 {
|
||||
cfg = config[0]
|
||||
|
||||
// Set default values
|
||||
if cfg.Next == nil {
|
||||
cfg.Next = ConfigDefault.Next
|
||||
}
|
||||
if cfg.URL == "" {
|
||||
cfg.URL = ConfigDefault.URL
|
||||
}
|
||||
if cfg.File == "" {
|
||||
cfg.File = ConfigDefault.File
|
||||
}
|
||||
if cfg.CacheControl == "" {
|
||||
cfg.CacheControl = ConfigDefault.CacheControl
|
||||
}
|
||||
}
|
||||
|
||||
// Load icon if provided
|
||||
var (
|
||||
err error
|
||||
icon []byte
|
||||
iconLen string
|
||||
)
|
||||
if cfg.Data != nil {
|
||||
// use the provided favicon data
|
||||
icon = cfg.Data
|
||||
iconLen = strconv.Itoa(len(cfg.Data))
|
||||
} else if cfg.File != "" {
|
||||
// read from configured filesystem if present
|
||||
if cfg.FileSystem != nil {
|
||||
f, err := cfg.FileSystem.Open(cfg.File)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
if icon, err = io.ReadAll(f); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
} else if icon, err = os.ReadFile(cfg.File); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
iconLen = strconv.Itoa(len(icon))
|
||||
}
|
||||
|
||||
// Return new handler
|
||||
return func(c *fiber.Ctx) error {
|
||||
// Don't execute middleware if Next returns true
|
||||
if cfg.Next != nil && cfg.Next(c) {
|
||||
return c.Next()
|
||||
}
|
||||
|
||||
// Only respond to favicon requests
|
||||
if c.Path() != cfg.URL {
|
||||
return c.Next()
|
||||
}
|
||||
|
||||
// Only allow GET, HEAD and OPTIONS requests
|
||||
if c.Method() != fiber.MethodGet && c.Method() != fiber.MethodHead {
|
||||
if c.Method() != fiber.MethodOptions {
|
||||
c.Status(fiber.StatusMethodNotAllowed)
|
||||
} else {
|
||||
c.Status(fiber.StatusOK)
|
||||
}
|
||||
c.Set(fiber.HeaderAllow, hAllow)
|
||||
c.Set(fiber.HeaderContentLength, hZero)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Serve cached favicon
|
||||
if len(icon) > 0 {
|
||||
c.Set(fiber.HeaderContentLength, iconLen)
|
||||
c.Set(fiber.HeaderContentType, hType)
|
||||
c.Set(fiber.HeaderCacheControl, cfg.CacheControl)
|
||||
return c.Status(fiber.StatusOK).Send(icon)
|
||||
}
|
||||
|
||||
return c.SendStatus(fiber.StatusNoContent)
|
||||
}
|
||||
}
|
208
middleware/favicon/favicon_test.go
Normal file
208
middleware/favicon/favicon_test.go
Normal file
|
@ -0,0 +1,208 @@
|
|||
//nolint:bodyclose // Much easier to just ignore memory leaks in tests
|
||||
package favicon
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/gofiber/fiber/v2/utils"
|
||||
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
// go test -run Test_Middleware_Favicon
|
||||
func Test_Middleware_Favicon(t *testing.T) {
|
||||
t.Parallel()
|
||||
app := fiber.New()
|
||||
|
||||
app.Use(New())
|
||||
|
||||
app.Get("/", func(c *fiber.Ctx) error {
|
||||
return nil
|
||||
})
|
||||
|
||||
// Skip Favicon middleware
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||
utils.AssertEqual(t, nil, err, "app.Test(req)")
|
||||
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode, "Status code")
|
||||
|
||||
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/favicon.ico", nil))
|
||||
utils.AssertEqual(t, nil, err, "app.Test(req)")
|
||||
utils.AssertEqual(t, fiber.StatusNoContent, resp.StatusCode, "Status code")
|
||||
|
||||
resp, err = app.Test(httptest.NewRequest(fiber.MethodOptions, "/favicon.ico", nil))
|
||||
utils.AssertEqual(t, nil, err, "app.Test(req)")
|
||||
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode, "Status code")
|
||||
|
||||
resp, err = app.Test(httptest.NewRequest(fiber.MethodPut, "/favicon.ico", nil))
|
||||
utils.AssertEqual(t, nil, err, "app.Test(req)")
|
||||
utils.AssertEqual(t, fiber.StatusMethodNotAllowed, resp.StatusCode, "Status code")
|
||||
utils.AssertEqual(t, strings.Join([]string{fiber.MethodGet, fiber.MethodHead, fiber.MethodOptions}, ", "), resp.Header.Get(fiber.HeaderAllow))
|
||||
}
|
||||
|
||||
// go test -run Test_Middleware_Favicon_Not_Found
|
||||
func Test_Middleware_Favicon_Not_Found(t *testing.T) {
|
||||
t.Parallel()
|
||||
defer func() {
|
||||
if err := recover(); err == nil {
|
||||
t.Fatal("should cache panic")
|
||||
}
|
||||
}()
|
||||
|
||||
fiber.New().Use(New(Config{
|
||||
File: "non-exist.ico",
|
||||
}))
|
||||
}
|
||||
|
||||
// go test -run Test_Middleware_Favicon_Found
|
||||
func Test_Middleware_Favicon_Found(t *testing.T) {
|
||||
t.Parallel()
|
||||
app := fiber.New()
|
||||
|
||||
app.Use(New(Config{
|
||||
File: "../../.github/testdata/favicon.ico",
|
||||
}))
|
||||
|
||||
app.Get("/", func(c *fiber.Ctx) error {
|
||||
return nil
|
||||
})
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/favicon.ico", nil))
|
||||
utils.AssertEqual(t, nil, err, "app.Test(req)")
|
||||
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode, "Status code")
|
||||
utils.AssertEqual(t, "image/x-icon", resp.Header.Get(fiber.HeaderContentType))
|
||||
utils.AssertEqual(t, "public, max-age=31536000", resp.Header.Get(fiber.HeaderCacheControl), "CacheControl Control")
|
||||
}
|
||||
|
||||
// go test -run Test_Custom_Favicon_Url
|
||||
func Test_Custom_Favicon_Url(t *testing.T) {
|
||||
app := fiber.New()
|
||||
const customURL = "/favicon.svg"
|
||||
app.Use(New(Config{
|
||||
File: "../../.github/testdata/favicon.ico",
|
||||
URL: customURL,
|
||||
}))
|
||||
|
||||
app.Get("/", func(c *fiber.Ctx) error {
|
||||
return nil
|
||||
})
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest(http.MethodGet, customURL, nil))
|
||||
|
||||
utils.AssertEqual(t, nil, err, "app.Test(req)")
|
||||
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode, "Status code")
|
||||
utils.AssertEqual(t, "image/x-icon", resp.Header.Get(fiber.HeaderContentType))
|
||||
}
|
||||
|
||||
// go test -run Test_Custom_Favicon_Data
|
||||
func Test_Custom_Favicon_Data(t *testing.T) {
|
||||
data, err := os.ReadFile("../../.github/testdata/favicon.ico")
|
||||
utils.AssertEqual(t, nil, err)
|
||||
|
||||
app := fiber.New()
|
||||
|
||||
app.Use(New(Config{
|
||||
Data: data,
|
||||
}))
|
||||
|
||||
app.Get("/", func(c *fiber.Ctx) error {
|
||||
return nil
|
||||
})
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/favicon.ico", nil))
|
||||
utils.AssertEqual(t, nil, err, "app.Test(req)")
|
||||
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode, "Status code")
|
||||
utils.AssertEqual(t, "image/x-icon", resp.Header.Get(fiber.HeaderContentType))
|
||||
utils.AssertEqual(t, "public, max-age=31536000", resp.Header.Get(fiber.HeaderCacheControl), "CacheControl Control")
|
||||
}
|
||||
|
||||
// mockFS wraps local filesystem for the purposes of
|
||||
// Test_Middleware_Favicon_FileSystem located below
|
||||
// TODO use os.Dir if fiber upgrades to 1.16
|
||||
type mockFS struct{}
|
||||
|
||||
func (mockFS) Open(name string) (http.File, error) {
|
||||
if name == "/" {
|
||||
name = "."
|
||||
} else {
|
||||
name = strings.TrimPrefix(name, "/")
|
||||
}
|
||||
file, err := os.Open(name) //nolint:gosec // We're in a test func, so this is fine
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to open: %w", err)
|
||||
}
|
||||
return file, nil
|
||||
}
|
||||
|
||||
// go test -run Test_Middleware_Favicon_FileSystem
|
||||
func Test_Middleware_Favicon_FileSystem(t *testing.T) {
|
||||
t.Parallel()
|
||||
app := fiber.New()
|
||||
|
||||
app.Use(New(Config{
|
||||
File: "../../.github/testdata/favicon.ico",
|
||||
FileSystem: mockFS{},
|
||||
}))
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/favicon.ico", nil))
|
||||
utils.AssertEqual(t, nil, err, "app.Test(req)")
|
||||
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode, "Status code")
|
||||
utils.AssertEqual(t, "image/x-icon", resp.Header.Get(fiber.HeaderContentType))
|
||||
utils.AssertEqual(t, "public, max-age=31536000", resp.Header.Get(fiber.HeaderCacheControl), "CacheControl Control")
|
||||
}
|
||||
|
||||
// go test -run Test_Middleware_Favicon_CacheControl
|
||||
func Test_Middleware_Favicon_CacheControl(t *testing.T) {
|
||||
t.Parallel()
|
||||
app := fiber.New()
|
||||
|
||||
app.Use(New(Config{
|
||||
CacheControl: "public, max-age=100",
|
||||
File: "../../.github/testdata/favicon.ico",
|
||||
}))
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/favicon.ico", nil))
|
||||
utils.AssertEqual(t, nil, err, "app.Test(req)")
|
||||
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode, "Status code")
|
||||
utils.AssertEqual(t, "image/x-icon", resp.Header.Get(fiber.HeaderContentType))
|
||||
utils.AssertEqual(t, "public, max-age=100", resp.Header.Get(fiber.HeaderCacheControl), "CacheControl Control")
|
||||
}
|
||||
|
||||
// go test -v -run=^$ -bench=Benchmark_Middleware_Favicon -benchmem -count=4
|
||||
func Benchmark_Middleware_Favicon(b *testing.B) {
|
||||
app := fiber.New()
|
||||
app.Use(New())
|
||||
app.Get("/", func(c *fiber.Ctx) error {
|
||||
return nil
|
||||
})
|
||||
handler := app.Handler()
|
||||
|
||||
c := &fasthttp.RequestCtx{}
|
||||
c.Request.SetRequestURI("/")
|
||||
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
for n := 0; n < b.N; n++ {
|
||||
handler(c)
|
||||
}
|
||||
}
|
||||
|
||||
// go test -run Test_Favicon_Next
|
||||
func Test_Favicon_Next(t *testing.T) {
|
||||
t.Parallel()
|
||||
app := fiber.New()
|
||||
app.Use(New(Config{
|
||||
Next: func(_ *fiber.Ctx) bool {
|
||||
return true
|
||||
},
|
||||
}))
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, fiber.StatusNotFound, resp.StatusCode)
|
||||
}
|
287
middleware/filesystem/filesystem.go
Normal file
287
middleware/filesystem/filesystem.go
Normal file
|
@ -0,0 +1,287 @@
|
|||
package filesystem
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/gofiber/fiber/v2/utils"
|
||||
)
|
||||
|
||||
// Config defines the config for middleware.
|
||||
type Config struct {
|
||||
// Next defines a function to skip this middleware when returned true.
|
||||
//
|
||||
// Optional. Default: nil
|
||||
Next func(c *fiber.Ctx) bool
|
||||
|
||||
// Root is a FileSystem that provides access
|
||||
// to a collection of files and directories.
|
||||
//
|
||||
// Required. Default: nil
|
||||
Root http.FileSystem `json:"-"`
|
||||
|
||||
// PathPrefix defines a prefix to be added to a filepath when
|
||||
// reading a file from the FileSystem.
|
||||
//
|
||||
// Use when using Go 1.16 embed.FS
|
||||
//
|
||||
// Optional. Default ""
|
||||
PathPrefix string `json:"path_prefix"`
|
||||
|
||||
// Enable directory browsing.
|
||||
//
|
||||
// Optional. Default: false
|
||||
Browse bool `json:"browse"`
|
||||
|
||||
// Index file for serving a directory.
|
||||
//
|
||||
// Optional. Default: "index.html"
|
||||
Index string `json:"index"`
|
||||
|
||||
// The value for the Cache-Control HTTP-header
|
||||
// that is set on the file response. MaxAge is defined in seconds.
|
||||
//
|
||||
// Optional. Default value 0.
|
||||
MaxAge int `json:"max_age"`
|
||||
|
||||
// File to return if path is not found. Useful for SPA's.
|
||||
//
|
||||
// Optional. Default: ""
|
||||
NotFoundFile string `json:"not_found_file"`
|
||||
|
||||
// The value for the Content-Type HTTP-header
|
||||
// that is set on the file response
|
||||
//
|
||||
// Optional. Default: ""
|
||||
ContentTypeCharset string `json:"content_type_charset"`
|
||||
}
|
||||
|
||||
// ConfigDefault is the default config
|
||||
var ConfigDefault = Config{
|
||||
Next: nil,
|
||||
Root: nil,
|
||||
PathPrefix: "",
|
||||
Browse: false,
|
||||
Index: "/index.html",
|
||||
MaxAge: 0,
|
||||
ContentTypeCharset: "",
|
||||
}
|
||||
|
||||
// New creates a new middleware handler.
|
||||
//
|
||||
// filesystem does not handle url encoded values (for example spaces)
|
||||
// on it's own. If you need that functionality, set "UnescapePath"
|
||||
// in fiber.Config
|
||||
func New(config ...Config) fiber.Handler {
|
||||
// Set default config
|
||||
cfg := ConfigDefault
|
||||
|
||||
// Override config if provided
|
||||
if len(config) > 0 {
|
||||
cfg = config[0]
|
||||
|
||||
// Set default values
|
||||
if cfg.Index == "" {
|
||||
cfg.Index = ConfigDefault.Index
|
||||
}
|
||||
if !strings.HasPrefix(cfg.Index, "/") {
|
||||
cfg.Index = "/" + cfg.Index
|
||||
}
|
||||
if cfg.NotFoundFile != "" && !strings.HasPrefix(cfg.NotFoundFile, "/") {
|
||||
cfg.NotFoundFile = "/" + cfg.NotFoundFile
|
||||
}
|
||||
}
|
||||
|
||||
if cfg.Root == nil {
|
||||
panic("filesystem: Root cannot be nil")
|
||||
}
|
||||
|
||||
if cfg.PathPrefix != "" && !strings.HasPrefix(cfg.PathPrefix, "/") {
|
||||
cfg.PathPrefix = "/" + cfg.PathPrefix
|
||||
}
|
||||
|
||||
var once sync.Once
|
||||
var prefix string
|
||||
cacheControlStr := "public, max-age=" + strconv.Itoa(cfg.MaxAge)
|
||||
|
||||
// Return new handler
|
||||
return func(c *fiber.Ctx) error {
|
||||
// Don't execute middleware if Next returns true
|
||||
if cfg.Next != nil && cfg.Next(c) {
|
||||
return c.Next()
|
||||
}
|
||||
|
||||
method := c.Method()
|
||||
|
||||
// We only serve static assets on GET or HEAD methods
|
||||
if method != fiber.MethodGet && method != fiber.MethodHead {
|
||||
return c.Next()
|
||||
}
|
||||
|
||||
// Set prefix once
|
||||
once.Do(func() {
|
||||
prefix = c.Route().Path
|
||||
})
|
||||
|
||||
// Strip prefix
|
||||
path := strings.TrimPrefix(c.Path(), prefix)
|
||||
if !strings.HasPrefix(path, "/") {
|
||||
path = "/" + path
|
||||
}
|
||||
// Add PathPrefix
|
||||
if cfg.PathPrefix != "" {
|
||||
// PathPrefix already has a "/" prefix
|
||||
path = cfg.PathPrefix + path
|
||||
}
|
||||
|
||||
if len(path) > 1 {
|
||||
path = utils.TrimRight(path, '/')
|
||||
}
|
||||
file, err := cfg.Root.Open(path)
|
||||
if err != nil && errors.Is(err, fs.ErrNotExist) && cfg.NotFoundFile != "" {
|
||||
file, err = cfg.Root.Open(cfg.NotFoundFile)
|
||||
}
|
||||
if err != nil {
|
||||
if errors.Is(err, fs.ErrNotExist) {
|
||||
return c.Status(fiber.StatusNotFound).Next()
|
||||
}
|
||||
return fmt.Errorf("failed to open: %w", err)
|
||||
}
|
||||
|
||||
stat, err := file.Stat()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to stat: %w", err)
|
||||
}
|
||||
|
||||
// Serve index if path is directory
|
||||
if stat.IsDir() {
|
||||
indexPath := utils.TrimRight(path, '/') + cfg.Index
|
||||
index, err := cfg.Root.Open(indexPath)
|
||||
if err == nil {
|
||||
indexStat, err := index.Stat()
|
||||
if err == nil {
|
||||
file = index
|
||||
stat = indexStat
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Browse directory if no index found and browsing is enabled
|
||||
if stat.IsDir() {
|
||||
if cfg.Browse {
|
||||
return dirList(c, file)
|
||||
}
|
||||
return fiber.ErrForbidden
|
||||
}
|
||||
|
||||
c.Status(fiber.StatusOK)
|
||||
|
||||
modTime := stat.ModTime()
|
||||
contentLength := int(stat.Size())
|
||||
|
||||
// Set Content Type header
|
||||
if cfg.ContentTypeCharset == "" {
|
||||
c.Type(getFileExtension(stat.Name()))
|
||||
} else {
|
||||
c.Type(getFileExtension(stat.Name()), cfg.ContentTypeCharset)
|
||||
}
|
||||
|
||||
// Set Last Modified header
|
||||
if !modTime.IsZero() {
|
||||
c.Set(fiber.HeaderLastModified, modTime.UTC().Format(http.TimeFormat))
|
||||
}
|
||||
|
||||
if method == fiber.MethodGet {
|
||||
if cfg.MaxAge > 0 {
|
||||
c.Set(fiber.HeaderCacheControl, cacheControlStr)
|
||||
}
|
||||
c.Response().SetBodyStream(file, contentLength)
|
||||
return nil
|
||||
}
|
||||
if method == fiber.MethodHead {
|
||||
c.Request().ResetBody()
|
||||
// Fasthttp should skipbody by default if HEAD?
|
||||
c.Response().SkipBody = true
|
||||
c.Response().Header.SetContentLength(contentLength)
|
||||
if err := file.Close(); err != nil {
|
||||
return fmt.Errorf("failed to close: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
return c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// SendFile serves a file from an HTTP file system at the specified path.
|
||||
// It handles content serving, sets appropriate headers, and returns errors when needed.
|
||||
// Usage: err := SendFile(ctx, fs, "/path/to/file.txt")
|
||||
func SendFile(c *fiber.Ctx, filesystem http.FileSystem, path string) error {
|
||||
file, err := filesystem.Open(path)
|
||||
if err != nil {
|
||||
if errors.Is(err, fs.ErrNotExist) {
|
||||
return fiber.ErrNotFound
|
||||
}
|
||||
return fmt.Errorf("failed to open: %w", err)
|
||||
}
|
||||
|
||||
stat, err := file.Stat()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to stat: %w", err)
|
||||
}
|
||||
|
||||
// Serve index if path is directory
|
||||
if stat.IsDir() {
|
||||
indexPath := utils.TrimRight(path, '/') + ConfigDefault.Index
|
||||
index, err := filesystem.Open(indexPath)
|
||||
if err == nil {
|
||||
indexStat, err := index.Stat()
|
||||
if err == nil {
|
||||
file = index
|
||||
stat = indexStat
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Return forbidden if no index found
|
||||
if stat.IsDir() {
|
||||
return fiber.ErrForbidden
|
||||
}
|
||||
|
||||
c.Status(fiber.StatusOK)
|
||||
|
||||
modTime := stat.ModTime()
|
||||
contentLength := int(stat.Size())
|
||||
|
||||
// Set Content Type header
|
||||
c.Type(getFileExtension(stat.Name()))
|
||||
|
||||
// Set Last Modified header
|
||||
if !modTime.IsZero() {
|
||||
c.Set(fiber.HeaderLastModified, modTime.UTC().Format(http.TimeFormat))
|
||||
}
|
||||
|
||||
method := c.Method()
|
||||
if method == fiber.MethodGet {
|
||||
c.Response().SetBodyStream(file, contentLength)
|
||||
return nil
|
||||
}
|
||||
if method == fiber.MethodHead {
|
||||
c.Request().ResetBody()
|
||||
// Fasthttp should skipbody by default if HEAD?
|
||||
c.Response().SkipBody = true
|
||||
c.Response().Header.SetContentLength(contentLength)
|
||||
if err := file.Close(); err != nil {
|
||||
return fmt.Errorf("failed to close: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
235
middleware/filesystem/filesystem_test.go
Normal file
235
middleware/filesystem/filesystem_test.go
Normal file
|
@ -0,0 +1,235 @@
|
|||
//nolint:bodyclose // Much easier to just ignore memory leaks in tests
|
||||
package filesystem
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/gofiber/fiber/v2/utils"
|
||||
)
|
||||
|
||||
// go test -run Test_FileSystem
|
||||
func Test_FileSystem(t *testing.T) {
|
||||
t.Parallel()
|
||||
app := fiber.New()
|
||||
|
||||
app.Use("/test", New(Config{
|
||||
Root: http.Dir("../../.github/testdata/fs"),
|
||||
}))
|
||||
|
||||
app.Use("/dir", New(Config{
|
||||
Root: http.Dir("../../.github/testdata/fs"),
|
||||
Browse: true,
|
||||
}))
|
||||
|
||||
app.Get("/", func(c *fiber.Ctx) error {
|
||||
return c.SendString("Hello, World!")
|
||||
})
|
||||
|
||||
app.Use("/spatest", New(Config{
|
||||
Root: http.Dir("../../.github/testdata/fs"),
|
||||
Index: "index.html",
|
||||
NotFoundFile: "index.html",
|
||||
}))
|
||||
|
||||
app.Use("/prefix", New(Config{
|
||||
Root: http.Dir("../../.github/testdata/fs"),
|
||||
PathPrefix: "img",
|
||||
}))
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
url string
|
||||
statusCode int
|
||||
contentType string
|
||||
modifiedTime string
|
||||
}{
|
||||
{
|
||||
name: "Should be returns status 200 with suitable content-type",
|
||||
url: "/test/index.html",
|
||||
statusCode: 200,
|
||||
contentType: "text/html",
|
||||
},
|
||||
{
|
||||
name: "Should be returns status 200 with suitable content-type",
|
||||
url: "/test",
|
||||
statusCode: 200,
|
||||
contentType: "text/html",
|
||||
},
|
||||
{
|
||||
name: "Should be returns status 200 with suitable content-type",
|
||||
url: "/test/css/style.css",
|
||||
statusCode: 200,
|
||||
contentType: "text/css",
|
||||
},
|
||||
{
|
||||
name: "Should be returns status 404",
|
||||
url: "/test/nofile.js",
|
||||
statusCode: 404,
|
||||
},
|
||||
{
|
||||
name: "Should be returns status 404",
|
||||
url: "/test/nofile",
|
||||
statusCode: 404,
|
||||
},
|
||||
{
|
||||
name: "Should be returns status 200",
|
||||
url: "/",
|
||||
statusCode: 200,
|
||||
contentType: "text/plain; charset=utf-8",
|
||||
},
|
||||
{
|
||||
name: "Should be returns status 403",
|
||||
url: "/test/img",
|
||||
statusCode: 403,
|
||||
},
|
||||
{
|
||||
name: "Should list the directory contents",
|
||||
url: "/dir/img",
|
||||
statusCode: 200,
|
||||
contentType: "text/html",
|
||||
},
|
||||
{
|
||||
name: "Should list the directory contents",
|
||||
url: "/dir/img/",
|
||||
statusCode: 200,
|
||||
contentType: "text/html",
|
||||
},
|
||||
{
|
||||
name: "Should be returns status 200",
|
||||
url: "/dir/img/fiber.png",
|
||||
statusCode: 200,
|
||||
contentType: "image/png",
|
||||
},
|
||||
{
|
||||
name: "Should be return status 200",
|
||||
url: "/spatest/doesnotexist",
|
||||
statusCode: 200,
|
||||
contentType: "text/html",
|
||||
},
|
||||
{
|
||||
name: "PathPrefix should be applied",
|
||||
url: "/prefix/fiber.png",
|
||||
statusCode: 200,
|
||||
contentType: "image/png",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, tt.url, nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, tt.statusCode, resp.StatusCode)
|
||||
|
||||
if tt.contentType != "" {
|
||||
ct := resp.Header.Get("Content-Type")
|
||||
utils.AssertEqual(t, tt.contentType, ct)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// go test -run Test_FileSystem_Next
|
||||
func Test_FileSystem_Next(t *testing.T) {
|
||||
t.Parallel()
|
||||
app := fiber.New()
|
||||
app.Use(New(Config{
|
||||
Root: http.Dir("../../.github/testdata/fs"),
|
||||
Next: func(_ *fiber.Ctx) bool {
|
||||
return true
|
||||
},
|
||||
}))
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, fiber.StatusNotFound, resp.StatusCode)
|
||||
}
|
||||
|
||||
func Test_FileSystem_NonGetAndHead(t *testing.T) {
|
||||
t.Parallel()
|
||||
app := fiber.New()
|
||||
|
||||
app.Use("/test", New(Config{
|
||||
Root: http.Dir("../../.github/testdata/fs"),
|
||||
}))
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodPost, "/test", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, 404, resp.StatusCode)
|
||||
}
|
||||
|
||||
func Test_FileSystem_Head(t *testing.T) {
|
||||
t.Parallel()
|
||||
app := fiber.New()
|
||||
|
||||
app.Use("/test", New(Config{
|
||||
Root: http.Dir("../../.github/testdata/fs"),
|
||||
}))
|
||||
|
||||
req, err := http.NewRequestWithContext(context.Background(), fiber.MethodHead, "/test", nil)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
resp, err := app.Test(req)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, 200, resp.StatusCode)
|
||||
}
|
||||
|
||||
func Test_FileSystem_NoRoot(t *testing.T) {
|
||||
t.Parallel()
|
||||
defer func() {
|
||||
utils.AssertEqual(t, "filesystem: Root cannot be nil", recover())
|
||||
}()
|
||||
|
||||
app := fiber.New()
|
||||
app.Use(New())
|
||||
_, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
}
|
||||
|
||||
func Test_FileSystem_UsingParam(t *testing.T) {
|
||||
t.Parallel()
|
||||
app := fiber.New()
|
||||
|
||||
app.Use("/:path", func(c *fiber.Ctx) error {
|
||||
return SendFile(c, http.Dir("../../.github/testdata/fs"), c.Params("path")+".html")
|
||||
})
|
||||
|
||||
req, err := http.NewRequestWithContext(context.Background(), fiber.MethodHead, "/index", nil)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
resp, err := app.Test(req)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, 200, resp.StatusCode)
|
||||
}
|
||||
|
||||
func Test_FileSystem_UsingParam_NonFile(t *testing.T) {
|
||||
t.Parallel()
|
||||
app := fiber.New()
|
||||
|
||||
app.Use("/:path", func(c *fiber.Ctx) error {
|
||||
return SendFile(c, http.Dir("../../.github/testdata/fs"), c.Params("path")+".html")
|
||||
})
|
||||
|
||||
req, err := http.NewRequestWithContext(context.Background(), fiber.MethodHead, "/template", nil)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
resp, err := app.Test(req)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, 404, resp.StatusCode)
|
||||
}
|
||||
|
||||
func Test_FileSystem_UsingContentTypeCharset(t *testing.T) {
|
||||
t.Parallel()
|
||||
app := fiber.New()
|
||||
app.Use(New(Config{
|
||||
Root: http.Dir("../../.github/testdata/fs/index.html"),
|
||||
ContentTypeCharset: "UTF-8",
|
||||
}))
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, 200, resp.StatusCode)
|
||||
utils.AssertEqual(t, "text/html; charset=UTF-8", resp.Header.Get("Content-Type"))
|
||||
}
|
66
middleware/filesystem/utils.go
Normal file
66
middleware/filesystem/utils.go
Normal file
|
@ -0,0 +1,66 @@
|
|||
package filesystem
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"html"
|
||||
"net/http"
|
||||
"os"
|
||||
"path"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/gofiber/fiber/v2/utils"
|
||||
)
|
||||
|
||||
func getFileExtension(p string) string {
|
||||
n := strings.LastIndexByte(p, '.')
|
||||
if n < 0 {
|
||||
return ""
|
||||
}
|
||||
return p[n:]
|
||||
}
|
||||
|
||||
func dirList(c *fiber.Ctx, f http.File) error {
|
||||
fileinfos, err := f.Readdir(-1)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read dir: %w", err)
|
||||
}
|
||||
|
||||
fm := make(map[string]os.FileInfo, len(fileinfos))
|
||||
filenames := make([]string, 0, len(fileinfos))
|
||||
for _, fi := range fileinfos {
|
||||
name := fi.Name()
|
||||
fm[name] = fi
|
||||
filenames = append(filenames, name)
|
||||
}
|
||||
|
||||
basePathEscaped := html.EscapeString(c.Path())
|
||||
_, _ = fmt.Fprintf(c, "<html><head><title>%s</title><style>.dir { font-weight: bold }</style></head><body>", basePathEscaped)
|
||||
_, _ = fmt.Fprintf(c, "<h1>%s</h1>", basePathEscaped)
|
||||
_, _ = fmt.Fprint(c, "<ul>")
|
||||
|
||||
if len(basePathEscaped) > 1 {
|
||||
parentPathEscaped := html.EscapeString(utils.TrimRight(c.Path(), '/') + "/..")
|
||||
_, _ = fmt.Fprintf(c, `<li><a href="%s" class="dir">..</a></li>`, parentPathEscaped)
|
||||
}
|
||||
|
||||
sort.Strings(filenames)
|
||||
for _, name := range filenames {
|
||||
pathEscaped := html.EscapeString(path.Join(c.Path() + "/" + name))
|
||||
fi := fm[name]
|
||||
auxStr := "dir"
|
||||
className := "dir"
|
||||
if !fi.IsDir() {
|
||||
auxStr = fmt.Sprintf("file, %d bytes", fi.Size())
|
||||
className = "file"
|
||||
}
|
||||
_, _ = fmt.Fprintf(c, `<li><a href="%s" class="%s">%s</a>, %s, last modified %s</li>`,
|
||||
pathEscaped, className, html.EscapeString(name), auxStr, fi.ModTime())
|
||||
}
|
||||
_, _ = fmt.Fprint(c, "</ul></body></html>")
|
||||
|
||||
c.Type("html")
|
||||
|
||||
return nil
|
||||
}
|
84
middleware/healthcheck/config.go
Normal file
84
middleware/healthcheck/config.go
Normal file
|
@ -0,0 +1,84 @@
|
|||
package healthcheck
|
||||
|
||||
import (
|
||||
"github.com/gofiber/fiber/v2"
|
||||
)
|
||||
|
||||
// Config defines the configuration options for the healthcheck middleware.
|
||||
type Config struct {
|
||||
// Next defines a function to skip this middleware when returned true.
|
||||
//
|
||||
// Optional. Default: nil
|
||||
Next func(c *fiber.Ctx) bool
|
||||
|
||||
// Function used for checking the liveness of the application. Returns true if the application
|
||||
// is running and false if it is not. The liveness probe is typically used to indicate if
|
||||
// the application is in a state where it can handle requests (e.g., the server is up and running).
|
||||
//
|
||||
// Optional. Default: func(c *fiber.Ctx) bool { return true }
|
||||
LivenessProbe HealthChecker
|
||||
|
||||
// HTTP endpoint at which the liveness probe will be available.
|
||||
//
|
||||
// Optional. Default: "/livez"
|
||||
LivenessEndpoint string
|
||||
|
||||
// Function used for checking the readiness of the application. Returns true if the application
|
||||
// is ready to process requests and false otherwise. The readiness probe typically checks if all necessary
|
||||
// services, databases, and other dependencies are available for the application to function correctly.
|
||||
//
|
||||
// Optional. Default: func(c *fiber.Ctx) bool { return true }
|
||||
ReadinessProbe HealthChecker
|
||||
|
||||
// HTTP endpoint at which the readiness probe will be available.
|
||||
// Optional. Default: "/readyz"
|
||||
ReadinessEndpoint string
|
||||
}
|
||||
|
||||
const (
|
||||
DefaultLivenessEndpoint = "/livez"
|
||||
DefaultReadinessEndpoint = "/readyz"
|
||||
)
|
||||
|
||||
func defaultLivenessProbe(*fiber.Ctx) bool { return true }
|
||||
|
||||
func defaultReadinessProbe(*fiber.Ctx) bool { return true }
|
||||
|
||||
// ConfigDefault is the default config
|
||||
var ConfigDefault = Config{
|
||||
LivenessProbe: defaultLivenessProbe,
|
||||
ReadinessProbe: defaultReadinessProbe,
|
||||
LivenessEndpoint: DefaultLivenessEndpoint,
|
||||
ReadinessEndpoint: DefaultReadinessEndpoint,
|
||||
}
|
||||
|
||||
// defaultConfig returns a default config for the healthcheck middleware.
|
||||
func defaultConfig(config ...Config) Config {
|
||||
if len(config) < 1 {
|
||||
return ConfigDefault
|
||||
}
|
||||
|
||||
cfg := config[0]
|
||||
|
||||
if cfg.Next == nil {
|
||||
cfg.Next = ConfigDefault.Next
|
||||
}
|
||||
|
||||
if cfg.LivenessProbe == nil {
|
||||
cfg.LivenessProbe = defaultLivenessProbe
|
||||
}
|
||||
|
||||
if cfg.ReadinessProbe == nil {
|
||||
cfg.ReadinessProbe = defaultReadinessProbe
|
||||
}
|
||||
|
||||
if cfg.LivenessEndpoint == "" {
|
||||
cfg.LivenessEndpoint = DefaultLivenessEndpoint
|
||||
}
|
||||
|
||||
if cfg.ReadinessEndpoint == "" {
|
||||
cfg.ReadinessEndpoint = DefaultReadinessEndpoint
|
||||
}
|
||||
|
||||
return cfg
|
||||
}
|
61
middleware/healthcheck/healthcheck.go
Normal file
61
middleware/healthcheck/healthcheck.go
Normal file
|
@ -0,0 +1,61 @@
|
|||
package healthcheck
|
||||
|
||||
import (
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/gofiber/fiber/v2/utils"
|
||||
)
|
||||
|
||||
// HealthChecker defines a function to check liveness or readiness of the application
|
||||
type HealthChecker func(*fiber.Ctx) bool
|
||||
|
||||
// ProbeCheckerHandler defines a function that returns a ProbeChecker
|
||||
type HealthCheckerHandler func(HealthChecker) fiber.Handler
|
||||
|
||||
func healthCheckerHandler(checker HealthChecker) fiber.Handler {
|
||||
return func(c *fiber.Ctx) error {
|
||||
if checker == nil {
|
||||
return c.Next()
|
||||
}
|
||||
|
||||
if checker(c) {
|
||||
return c.SendStatus(fiber.StatusOK)
|
||||
}
|
||||
|
||||
return c.SendStatus(fiber.StatusServiceUnavailable)
|
||||
}
|
||||
}
|
||||
|
||||
func New(config ...Config) fiber.Handler {
|
||||
cfg := defaultConfig(config...)
|
||||
|
||||
isLiveHandler := healthCheckerHandler(cfg.LivenessProbe)
|
||||
isReadyHandler := healthCheckerHandler(cfg.ReadinessProbe)
|
||||
|
||||
return func(c *fiber.Ctx) error {
|
||||
// Don't execute middleware if Next returns true
|
||||
if cfg.Next != nil && cfg.Next(c) {
|
||||
return c.Next()
|
||||
}
|
||||
|
||||
if c.Method() != fiber.MethodGet {
|
||||
return c.Next()
|
||||
}
|
||||
|
||||
prefixCount := len(utils.TrimRight(c.Route().Path, '/'))
|
||||
if len(c.Path()) >= prefixCount {
|
||||
checkPath := c.Path()[prefixCount:]
|
||||
checkPathTrimmed := checkPath
|
||||
if !c.App().Config().StrictRouting {
|
||||
checkPathTrimmed = utils.TrimRight(checkPath, '/')
|
||||
}
|
||||
switch {
|
||||
case checkPath == cfg.ReadinessEndpoint || checkPathTrimmed == cfg.ReadinessEndpoint:
|
||||
return isReadyHandler(c)
|
||||
case checkPath == cfg.LivenessEndpoint || checkPathTrimmed == cfg.LivenessEndpoint:
|
||||
return isLiveHandler(c)
|
||||
}
|
||||
}
|
||||
|
||||
return c.Next()
|
||||
}
|
||||
}
|
237
middleware/healthcheck/healthcheck_test.go
Normal file
237
middleware/healthcheck/healthcheck_test.go
Normal file
|
@ -0,0 +1,237 @@
|
|||
package healthcheck
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/gofiber/fiber/v2/utils"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
func shouldGiveStatus(t *testing.T, app *fiber.App, path string, expectedStatus int) {
|
||||
t.Helper()
|
||||
req, err := app.Test(httptest.NewRequest(fiber.MethodGet, path, nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, expectedStatus, req.StatusCode, "path: "+path+" should match "+fmt.Sprint(expectedStatus))
|
||||
}
|
||||
|
||||
func shouldGiveOK(t *testing.T, app *fiber.App, path string) {
|
||||
t.Helper()
|
||||
shouldGiveStatus(t, app, path, fiber.StatusOK)
|
||||
}
|
||||
|
||||
func shouldGiveNotFound(t *testing.T, app *fiber.App, path string) {
|
||||
t.Helper()
|
||||
shouldGiveStatus(t, app, path, fiber.StatusNotFound)
|
||||
}
|
||||
|
||||
func Test_HealthCheck_Strict_Routing_Default(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
app := fiber.New(fiber.Config{
|
||||
StrictRouting: true,
|
||||
})
|
||||
|
||||
app.Use(New())
|
||||
|
||||
shouldGiveOK(t, app, "/readyz")
|
||||
shouldGiveOK(t, app, "/livez")
|
||||
shouldGiveNotFound(t, app, "/readyz/")
|
||||
shouldGiveNotFound(t, app, "/livez/")
|
||||
shouldGiveNotFound(t, app, "/notDefined/readyz")
|
||||
shouldGiveNotFound(t, app, "/notDefined/livez")
|
||||
}
|
||||
|
||||
func Test_HealthCheck_Group_Default(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
app := fiber.New()
|
||||
app.Group("/v1", New())
|
||||
v2Group := app.Group("/v2/")
|
||||
customer := v2Group.Group("/customer/")
|
||||
customer.Use(New())
|
||||
|
||||
v3Group := app.Group("/v3/")
|
||||
v3Group.Group("/todos/", New(Config{ReadinessEndpoint: "/readyz/", LivenessEndpoint: "/livez/"}))
|
||||
|
||||
shouldGiveOK(t, app, "/v1/readyz")
|
||||
shouldGiveOK(t, app, "/v1/livez")
|
||||
shouldGiveOK(t, app, "/v1/readyz/")
|
||||
shouldGiveOK(t, app, "/v1/livez/")
|
||||
shouldGiveOK(t, app, "/v2/customer/readyz")
|
||||
shouldGiveOK(t, app, "/v2/customer/livez")
|
||||
shouldGiveOK(t, app, "/v2/customer/readyz/")
|
||||
shouldGiveOK(t, app, "/v2/customer/livez/")
|
||||
shouldGiveNotFound(t, app, "/v3/todos/readyz")
|
||||
shouldGiveNotFound(t, app, "/v3/todos/livez")
|
||||
shouldGiveOK(t, app, "/v3/todos/readyz/")
|
||||
shouldGiveOK(t, app, "/v3/todos/livez/")
|
||||
shouldGiveNotFound(t, app, "/notDefined/readyz")
|
||||
shouldGiveNotFound(t, app, "/notDefined/livez")
|
||||
shouldGiveNotFound(t, app, "/notDefined/readyz/")
|
||||
shouldGiveNotFound(t, app, "/notDefined/livez/")
|
||||
|
||||
// strict routing
|
||||
app = fiber.New(fiber.Config{
|
||||
StrictRouting: true,
|
||||
})
|
||||
app.Group("/v1", New())
|
||||
v2Group = app.Group("/v2/")
|
||||
customer = v2Group.Group("/customer/")
|
||||
customer.Use(New())
|
||||
|
||||
v3Group = app.Group("/v3/")
|
||||
v3Group.Group("/todos/", New(Config{ReadinessEndpoint: "/readyz/", LivenessEndpoint: "/livez/"}))
|
||||
|
||||
shouldGiveOK(t, app, "/v1/readyz")
|
||||
shouldGiveOK(t, app, "/v1/livez")
|
||||
shouldGiveNotFound(t, app, "/v1/readyz/")
|
||||
shouldGiveNotFound(t, app, "/v1/livez/")
|
||||
shouldGiveOK(t, app, "/v2/customer/readyz")
|
||||
shouldGiveOK(t, app, "/v2/customer/livez")
|
||||
shouldGiveNotFound(t, app, "/v2/customer/readyz/")
|
||||
shouldGiveNotFound(t, app, "/v2/customer/livez/")
|
||||
shouldGiveNotFound(t, app, "/v3/todos/readyz")
|
||||
shouldGiveNotFound(t, app, "/v3/todos/livez")
|
||||
shouldGiveOK(t, app, "/v3/todos/readyz/")
|
||||
shouldGiveOK(t, app, "/v3/todos/livez/")
|
||||
shouldGiveNotFound(t, app, "/notDefined/readyz")
|
||||
shouldGiveNotFound(t, app, "/notDefined/livez")
|
||||
shouldGiveNotFound(t, app, "/notDefined/readyz/")
|
||||
shouldGiveNotFound(t, app, "/notDefined/livez/")
|
||||
}
|
||||
|
||||
func Test_HealthCheck_Default(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
app := fiber.New()
|
||||
app.Use(New())
|
||||
|
||||
shouldGiveOK(t, app, "/readyz")
|
||||
shouldGiveOK(t, app, "/livez")
|
||||
shouldGiveOK(t, app, "/readyz/")
|
||||
shouldGiveOK(t, app, "/livez/")
|
||||
shouldGiveNotFound(t, app, "/notDefined/readyz")
|
||||
shouldGiveNotFound(t, app, "/notDefined/livez")
|
||||
}
|
||||
|
||||
func Test_HealthCheck_Custom(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
app := fiber.New()
|
||||
|
||||
c1 := make(chan struct{}, 1)
|
||||
app.Use(New(Config{
|
||||
LivenessProbe: func(c *fiber.Ctx) bool {
|
||||
return true
|
||||
},
|
||||
LivenessEndpoint: "/live",
|
||||
ReadinessProbe: func(c *fiber.Ctx) bool {
|
||||
select {
|
||||
case <-c1:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
},
|
||||
ReadinessEndpoint: "/ready",
|
||||
}))
|
||||
|
||||
// Live should return 200 with GET request
|
||||
shouldGiveOK(t, app, "/live")
|
||||
// Live should return 404 with POST request
|
||||
req, err := app.Test(httptest.NewRequest(fiber.MethodPost, "/live", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, fiber.StatusNotFound, req.StatusCode)
|
||||
|
||||
// Ready should return 404 with POST request
|
||||
req, err = app.Test(httptest.NewRequest(fiber.MethodPost, "/ready", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, fiber.StatusNotFound, req.StatusCode)
|
||||
|
||||
// Ready should return 503 with GET request before the channel is closed
|
||||
shouldGiveStatus(t, app, "/ready", fiber.StatusServiceUnavailable)
|
||||
|
||||
// Ready should return 200 with GET request after the channel is closed
|
||||
c1 <- struct{}{}
|
||||
shouldGiveOK(t, app, "/ready")
|
||||
}
|
||||
|
||||
func Test_HealthCheck_Custom_Nested(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
app := fiber.New()
|
||||
|
||||
c1 := make(chan struct{}, 1)
|
||||
|
||||
app.Use(New(Config{
|
||||
LivenessProbe: func(c *fiber.Ctx) bool {
|
||||
return true
|
||||
},
|
||||
LivenessEndpoint: "/probe/live",
|
||||
ReadinessProbe: func(c *fiber.Ctx) bool {
|
||||
select {
|
||||
case <-c1:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
},
|
||||
ReadinessEndpoint: "/probe/ready",
|
||||
}))
|
||||
|
||||
shouldGiveOK(t, app, "/probe/live")
|
||||
shouldGiveStatus(t, app, "/probe/ready", fiber.StatusServiceUnavailable)
|
||||
shouldGiveOK(t, app, "/probe/live/")
|
||||
shouldGiveStatus(t, app, "/probe/ready/", fiber.StatusServiceUnavailable)
|
||||
shouldGiveNotFound(t, app, "/probe/livez")
|
||||
shouldGiveNotFound(t, app, "/probe/readyz")
|
||||
shouldGiveNotFound(t, app, "/probe/livez/")
|
||||
shouldGiveNotFound(t, app, "/probe/readyz/")
|
||||
shouldGiveNotFound(t, app, "/livez")
|
||||
shouldGiveNotFound(t, app, "/readyz")
|
||||
shouldGiveNotFound(t, app, "/readyz/")
|
||||
shouldGiveNotFound(t, app, "/livez/")
|
||||
|
||||
c1 <- struct{}{}
|
||||
shouldGiveOK(t, app, "/probe/ready")
|
||||
c1 <- struct{}{}
|
||||
shouldGiveOK(t, app, "/probe/ready/")
|
||||
}
|
||||
|
||||
func Test_HealthCheck_Next(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
app := fiber.New()
|
||||
|
||||
app.Use(New(Config{
|
||||
Next: func(c *fiber.Ctx) bool {
|
||||
return true
|
||||
},
|
||||
}))
|
||||
|
||||
shouldGiveNotFound(t, app, "/readyz")
|
||||
shouldGiveNotFound(t, app, "/livez")
|
||||
}
|
||||
|
||||
func Benchmark_HealthCheck(b *testing.B) {
|
||||
app := fiber.New()
|
||||
|
||||
app.Use(New())
|
||||
|
||||
h := app.Handler()
|
||||
fctx := &fasthttp.RequestCtx{}
|
||||
fctx.Request.Header.SetMethod(fiber.MethodGet)
|
||||
fctx.Request.SetRequestURI("/livez")
|
||||
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
h(fctx)
|
||||
}
|
||||
|
||||
utils.AssertEqual(b, fiber.StatusOK, fctx.Response.Header.StatusCode())
|
||||
}
|
154
middleware/helmet/config.go
Normal file
154
middleware/helmet/config.go
Normal file
|
@ -0,0 +1,154 @@
|
|||
package helmet
|
||||
|
||||
import (
|
||||
"github.com/gofiber/fiber/v2"
|
||||
)
|
||||
|
||||
// Config defines the config for middleware.
|
||||
type Config struct {
|
||||
// Next defines a function to skip middleware.
|
||||
// Optional. Default: nil
|
||||
Next func(*fiber.Ctx) bool
|
||||
|
||||
// XSSProtection
|
||||
// Optional. Default value "0".
|
||||
XSSProtection string
|
||||
|
||||
// ContentTypeNosniff
|
||||
// Optional. Default value "nosniff".
|
||||
ContentTypeNosniff string
|
||||
|
||||
// XFrameOptions
|
||||
// Optional. Default value "SAMEORIGIN".
|
||||
// Possible values: "SAMEORIGIN", "DENY", "ALLOW-FROM uri"
|
||||
XFrameOptions string
|
||||
|
||||
// HSTSMaxAge
|
||||
// Optional. Default value 0.
|
||||
HSTSMaxAge int
|
||||
|
||||
// HSTSExcludeSubdomains
|
||||
// Optional. Default value false.
|
||||
HSTSExcludeSubdomains bool
|
||||
|
||||
// ContentSecurityPolicy
|
||||
// Optional. Default value "".
|
||||
ContentSecurityPolicy string
|
||||
|
||||
// CSPReportOnly
|
||||
// Optional. Default value false.
|
||||
CSPReportOnly bool
|
||||
|
||||
// HSTSPreloadEnabled
|
||||
// Optional. Default value false.
|
||||
HSTSPreloadEnabled bool
|
||||
|
||||
// ReferrerPolicy
|
||||
// Optional. Default value "ReferrerPolicy".
|
||||
ReferrerPolicy string
|
||||
|
||||
// Permissions-Policy
|
||||
// Optional. Default value "".
|
||||
PermissionPolicy string
|
||||
|
||||
// Cross-Origin-Embedder-Policy
|
||||
// Optional. Default value "require-corp".
|
||||
CrossOriginEmbedderPolicy string
|
||||
|
||||
// Cross-Origin-Opener-Policy
|
||||
// Optional. Default value "same-origin".
|
||||
CrossOriginOpenerPolicy string
|
||||
|
||||
// Cross-Origin-Resource-Policy
|
||||
// Optional. Default value "same-origin".
|
||||
CrossOriginResourcePolicy string
|
||||
|
||||
// Origin-Agent-Cluster
|
||||
// Optional. Default value "?1".
|
||||
OriginAgentCluster string
|
||||
|
||||
// X-DNS-Prefetch-Control
|
||||
// Optional. Default value "off".
|
||||
XDNSPrefetchControl string
|
||||
|
||||
// X-Download-Options
|
||||
// Optional. Default value "noopen".
|
||||
XDownloadOptions string
|
||||
|
||||
// X-Permitted-Cross-Domain-Policies
|
||||
// Optional. Default value "none".
|
||||
XPermittedCrossDomain string
|
||||
}
|
||||
|
||||
// ConfigDefault is the default config
|
||||
var ConfigDefault = Config{
|
||||
XSSProtection: "0",
|
||||
ContentTypeNosniff: "nosniff",
|
||||
XFrameOptions: "SAMEORIGIN",
|
||||
ReferrerPolicy: "no-referrer",
|
||||
CrossOriginEmbedderPolicy: "require-corp",
|
||||
CrossOriginOpenerPolicy: "same-origin",
|
||||
CrossOriginResourcePolicy: "same-origin",
|
||||
OriginAgentCluster: "?1",
|
||||
XDNSPrefetchControl: "off",
|
||||
XDownloadOptions: "noopen",
|
||||
XPermittedCrossDomain: "none",
|
||||
}
|
||||
|
||||
// Helper function to set default values
|
||||
func configDefault(config ...Config) Config {
|
||||
// Return default config if nothing provided
|
||||
if len(config) < 1 {
|
||||
return ConfigDefault
|
||||
}
|
||||
|
||||
// Override default config
|
||||
cfg := config[0]
|
||||
|
||||
// Set default values
|
||||
if cfg.XSSProtection == "" {
|
||||
cfg.XSSProtection = ConfigDefault.XSSProtection
|
||||
}
|
||||
|
||||
if cfg.ContentTypeNosniff == "" {
|
||||
cfg.ContentTypeNosniff = ConfigDefault.ContentTypeNosniff
|
||||
}
|
||||
|
||||
if cfg.XFrameOptions == "" {
|
||||
cfg.XFrameOptions = ConfigDefault.XFrameOptions
|
||||
}
|
||||
|
||||
if cfg.ReferrerPolicy == "" {
|
||||
cfg.ReferrerPolicy = ConfigDefault.ReferrerPolicy
|
||||
}
|
||||
|
||||
if cfg.CrossOriginEmbedderPolicy == "" {
|
||||
cfg.CrossOriginEmbedderPolicy = ConfigDefault.CrossOriginEmbedderPolicy
|
||||
}
|
||||
|
||||
if cfg.CrossOriginOpenerPolicy == "" {
|
||||
cfg.CrossOriginOpenerPolicy = ConfigDefault.CrossOriginOpenerPolicy
|
||||
}
|
||||
|
||||
if cfg.CrossOriginResourcePolicy == "" {
|
||||
cfg.CrossOriginResourcePolicy = ConfigDefault.CrossOriginResourcePolicy
|
||||
}
|
||||
|
||||
if cfg.OriginAgentCluster == "" {
|
||||
cfg.OriginAgentCluster = ConfigDefault.OriginAgentCluster
|
||||
}
|
||||
|
||||
if cfg.XDNSPrefetchControl == "" {
|
||||
cfg.XDNSPrefetchControl = ConfigDefault.XDNSPrefetchControl
|
||||
}
|
||||
|
||||
if cfg.XDownloadOptions == "" {
|
||||
cfg.XDownloadOptions = ConfigDefault.XDownloadOptions
|
||||
}
|
||||
|
||||
if cfg.XPermittedCrossDomain == "" {
|
||||
cfg.XPermittedCrossDomain = ConfigDefault.XPermittedCrossDomain
|
||||
}
|
||||
|
||||
return cfg
|
||||
}
|
94
middleware/helmet/helmet.go
Normal file
94
middleware/helmet/helmet.go
Normal file
|
@ -0,0 +1,94 @@
|
|||
package helmet
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
)
|
||||
|
||||
// New creates a new middleware handler
|
||||
func New(config ...Config) fiber.Handler {
|
||||
// Init config
|
||||
cfg := configDefault(config...)
|
||||
|
||||
// Return middleware handler
|
||||
return func(c *fiber.Ctx) error {
|
||||
// Next request to skip middleware
|
||||
if cfg.Next != nil && cfg.Next(c) {
|
||||
return c.Next()
|
||||
}
|
||||
|
||||
// Set headers
|
||||
if cfg.XSSProtection != "" {
|
||||
c.Set(fiber.HeaderXXSSProtection, cfg.XSSProtection)
|
||||
}
|
||||
|
||||
if cfg.ContentTypeNosniff != "" {
|
||||
c.Set(fiber.HeaderXContentTypeOptions, cfg.ContentTypeNosniff)
|
||||
}
|
||||
|
||||
if cfg.XFrameOptions != "" {
|
||||
c.Set(fiber.HeaderXFrameOptions, cfg.XFrameOptions)
|
||||
}
|
||||
|
||||
if cfg.CrossOriginEmbedderPolicy != "" {
|
||||
c.Set("Cross-Origin-Embedder-Policy", cfg.CrossOriginEmbedderPolicy)
|
||||
}
|
||||
|
||||
if cfg.CrossOriginOpenerPolicy != "" {
|
||||
c.Set("Cross-Origin-Opener-Policy", cfg.CrossOriginOpenerPolicy)
|
||||
}
|
||||
|
||||
if cfg.CrossOriginResourcePolicy != "" {
|
||||
c.Set("Cross-Origin-Resource-Policy", cfg.CrossOriginResourcePolicy)
|
||||
}
|
||||
|
||||
if cfg.OriginAgentCluster != "" {
|
||||
c.Set("Origin-Agent-Cluster", cfg.OriginAgentCluster)
|
||||
}
|
||||
|
||||
if cfg.ReferrerPolicy != "" {
|
||||
c.Set("Referrer-Policy", cfg.ReferrerPolicy)
|
||||
}
|
||||
|
||||
if cfg.XDNSPrefetchControl != "" {
|
||||
c.Set("X-DNS-Prefetch-Control", cfg.XDNSPrefetchControl)
|
||||
}
|
||||
|
||||
if cfg.XDownloadOptions != "" {
|
||||
c.Set("X-Download-Options", cfg.XDownloadOptions)
|
||||
}
|
||||
|
||||
if cfg.XPermittedCrossDomain != "" {
|
||||
c.Set("X-Permitted-Cross-Domain-Policies", cfg.XPermittedCrossDomain)
|
||||
}
|
||||
|
||||
// Handle HSTS headers
|
||||
if c.Protocol() == "https" && cfg.HSTSMaxAge != 0 {
|
||||
subdomains := ""
|
||||
if !cfg.HSTSExcludeSubdomains {
|
||||
subdomains = "; includeSubDomains"
|
||||
}
|
||||
if cfg.HSTSPreloadEnabled {
|
||||
subdomains = fmt.Sprintf("%s; preload", subdomains)
|
||||
}
|
||||
c.Set(fiber.HeaderStrictTransportSecurity, fmt.Sprintf("max-age=%d%s", cfg.HSTSMaxAge, subdomains))
|
||||
}
|
||||
|
||||
// Handle Content-Security-Policy headers
|
||||
if cfg.ContentSecurityPolicy != "" {
|
||||
if cfg.CSPReportOnly {
|
||||
c.Set(fiber.HeaderContentSecurityPolicyReportOnly, cfg.ContentSecurityPolicy)
|
||||
} else {
|
||||
c.Set(fiber.HeaderContentSecurityPolicy, cfg.ContentSecurityPolicy)
|
||||
}
|
||||
}
|
||||
|
||||
// Handle Permissions-Policy headers
|
||||
if cfg.PermissionPolicy != "" {
|
||||
c.Set(fiber.HeaderPermissionsPolicy, cfg.PermissionPolicy)
|
||||
}
|
||||
|
||||
return c.Next()
|
||||
}
|
||||
}
|
201
middleware/helmet/helmet_test.go
Normal file
201
middleware/helmet/helmet_test.go
Normal file
|
@ -0,0 +1,201 @@
|
|||
package helmet
|
||||
|
||||
import (
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/gofiber/fiber/v2/utils"
|
||||
)
|
||||
|
||||
func Test_Default(t *testing.T) {
|
||||
app := fiber.New()
|
||||
|
||||
app.Use(New())
|
||||
|
||||
app.Get("/", func(c *fiber.Ctx) error {
|
||||
return c.SendString("Hello, World!")
|
||||
})
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, "0", resp.Header.Get(fiber.HeaderXXSSProtection))
|
||||
utils.AssertEqual(t, "nosniff", resp.Header.Get(fiber.HeaderXContentTypeOptions))
|
||||
utils.AssertEqual(t, "SAMEORIGIN", resp.Header.Get(fiber.HeaderXFrameOptions))
|
||||
utils.AssertEqual(t, "", resp.Header.Get(fiber.HeaderContentSecurityPolicy))
|
||||
utils.AssertEqual(t, "no-referrer", resp.Header.Get(fiber.HeaderReferrerPolicy))
|
||||
utils.AssertEqual(t, "", resp.Header.Get(fiber.HeaderPermissionsPolicy))
|
||||
utils.AssertEqual(t, "require-corp", resp.Header.Get("Cross-Origin-Embedder-Policy"))
|
||||
utils.AssertEqual(t, "same-origin", resp.Header.Get("Cross-Origin-Opener-Policy"))
|
||||
utils.AssertEqual(t, "same-origin", resp.Header.Get("Cross-Origin-Resource-Policy"))
|
||||
utils.AssertEqual(t, "?1", resp.Header.Get("Origin-Agent-Cluster"))
|
||||
utils.AssertEqual(t, "off", resp.Header.Get("X-DNS-Prefetch-Control"))
|
||||
utils.AssertEqual(t, "noopen", resp.Header.Get("X-Download-Options"))
|
||||
utils.AssertEqual(t, "none", resp.Header.Get("X-Permitted-Cross-Domain-Policies"))
|
||||
}
|
||||
|
||||
func Test_CustomValues_AllHeaders(t *testing.T) {
|
||||
app := fiber.New()
|
||||
|
||||
app.Use(New(Config{
|
||||
// Custom values for all headers
|
||||
XSSProtection: "0",
|
||||
ContentTypeNosniff: "custom-nosniff",
|
||||
XFrameOptions: "DENY",
|
||||
HSTSExcludeSubdomains: true,
|
||||
ContentSecurityPolicy: "default-src 'none'",
|
||||
CSPReportOnly: true,
|
||||
HSTSPreloadEnabled: true,
|
||||
ReferrerPolicy: "origin",
|
||||
PermissionPolicy: "geolocation=(self)",
|
||||
CrossOriginEmbedderPolicy: "custom-value",
|
||||
CrossOriginOpenerPolicy: "custom-value",
|
||||
CrossOriginResourcePolicy: "custom-value",
|
||||
OriginAgentCluster: "custom-value",
|
||||
XDNSPrefetchControl: "custom-control",
|
||||
XDownloadOptions: "custom-options",
|
||||
XPermittedCrossDomain: "custom-policies",
|
||||
}))
|
||||
|
||||
app.Get("/", func(c *fiber.Ctx) error {
|
||||
return c.SendString("Hello, World!")
|
||||
})
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
// Assertions for custom header values
|
||||
utils.AssertEqual(t, "0", resp.Header.Get(fiber.HeaderXXSSProtection))
|
||||
utils.AssertEqual(t, "custom-nosniff", resp.Header.Get(fiber.HeaderXContentTypeOptions))
|
||||
utils.AssertEqual(t, "DENY", resp.Header.Get(fiber.HeaderXFrameOptions))
|
||||
utils.AssertEqual(t, "default-src 'none'", resp.Header.Get(fiber.HeaderContentSecurityPolicyReportOnly))
|
||||
utils.AssertEqual(t, "origin", resp.Header.Get(fiber.HeaderReferrerPolicy))
|
||||
utils.AssertEqual(t, "geolocation=(self)", resp.Header.Get(fiber.HeaderPermissionsPolicy))
|
||||
utils.AssertEqual(t, "custom-value", resp.Header.Get("Cross-Origin-Embedder-Policy"))
|
||||
utils.AssertEqual(t, "custom-value", resp.Header.Get("Cross-Origin-Opener-Policy"))
|
||||
utils.AssertEqual(t, "custom-value", resp.Header.Get("Cross-Origin-Resource-Policy"))
|
||||
utils.AssertEqual(t, "custom-value", resp.Header.Get("Origin-Agent-Cluster"))
|
||||
utils.AssertEqual(t, "custom-control", resp.Header.Get("X-DNS-Prefetch-Control"))
|
||||
utils.AssertEqual(t, "custom-options", resp.Header.Get("X-Download-Options"))
|
||||
utils.AssertEqual(t, "custom-policies", resp.Header.Get("X-Permitted-Cross-Domain-Policies"))
|
||||
}
|
||||
|
||||
func Test_RealWorldValues_AllHeaders(t *testing.T) {
|
||||
app := fiber.New()
|
||||
|
||||
app.Use(New(Config{
|
||||
// Real-world values for all headers
|
||||
XSSProtection: "0",
|
||||
ContentTypeNosniff: "nosniff",
|
||||
XFrameOptions: "SAMEORIGIN",
|
||||
HSTSExcludeSubdomains: false,
|
||||
ContentSecurityPolicy: "default-src 'self';base-uri 'self';font-src 'self' https: data:;form-action 'self';frame-ancestors 'self';img-src 'self' data:;object-src 'none';script-src 'self';script-src-attr 'none';style-src 'self' https: 'unsafe-inline';upgrade-insecure-requests",
|
||||
CSPReportOnly: false,
|
||||
HSTSPreloadEnabled: true,
|
||||
ReferrerPolicy: "no-referrer",
|
||||
PermissionPolicy: "geolocation=(self)",
|
||||
CrossOriginEmbedderPolicy: "require-corp",
|
||||
CrossOriginOpenerPolicy: "same-origin",
|
||||
CrossOriginResourcePolicy: "same-origin",
|
||||
OriginAgentCluster: "?1",
|
||||
XDNSPrefetchControl: "off",
|
||||
XDownloadOptions: "noopen",
|
||||
XPermittedCrossDomain: "none",
|
||||
}))
|
||||
|
||||
app.Get("/", func(c *fiber.Ctx) error {
|
||||
return c.SendString("Hello, World!")
|
||||
})
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
// Assertions for real-world header values
|
||||
utils.AssertEqual(t, "0", resp.Header.Get(fiber.HeaderXXSSProtection))
|
||||
utils.AssertEqual(t, "nosniff", resp.Header.Get(fiber.HeaderXContentTypeOptions))
|
||||
utils.AssertEqual(t, "SAMEORIGIN", resp.Header.Get(fiber.HeaderXFrameOptions))
|
||||
utils.AssertEqual(t, "default-src 'self';base-uri 'self';font-src 'self' https: data:;form-action 'self';frame-ancestors 'self';img-src 'self' data:;object-src 'none';script-src 'self';script-src-attr 'none';style-src 'self' https: 'unsafe-inline';upgrade-insecure-requests", resp.Header.Get(fiber.HeaderContentSecurityPolicy))
|
||||
utils.AssertEqual(t, "no-referrer", resp.Header.Get(fiber.HeaderReferrerPolicy))
|
||||
utils.AssertEqual(t, "geolocation=(self)", resp.Header.Get(fiber.HeaderPermissionsPolicy))
|
||||
utils.AssertEqual(t, "require-corp", resp.Header.Get("Cross-Origin-Embedder-Policy"))
|
||||
utils.AssertEqual(t, "same-origin", resp.Header.Get("Cross-Origin-Opener-Policy"))
|
||||
utils.AssertEqual(t, "same-origin", resp.Header.Get("Cross-Origin-Resource-Policy"))
|
||||
utils.AssertEqual(t, "?1", resp.Header.Get("Origin-Agent-Cluster"))
|
||||
utils.AssertEqual(t, "off", resp.Header.Get("X-DNS-Prefetch-Control"))
|
||||
utils.AssertEqual(t, "noopen", resp.Header.Get("X-Download-Options"))
|
||||
utils.AssertEqual(t, "none", resp.Header.Get("X-Permitted-Cross-Domain-Policies"))
|
||||
}
|
||||
|
||||
func Test_Next(t *testing.T) {
|
||||
app := fiber.New()
|
||||
|
||||
app.Use(New(Config{
|
||||
Next: func(ctx *fiber.Ctx) bool {
|
||||
return ctx.Path() == "/next"
|
||||
},
|
||||
ReferrerPolicy: "no-referrer",
|
||||
}))
|
||||
|
||||
app.Get("/", func(c *fiber.Ctx) error {
|
||||
return c.SendString("Hello, World!")
|
||||
})
|
||||
app.Get("/next", func(c *fiber.Ctx) error {
|
||||
return c.SendString("Skipped!")
|
||||
})
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, "no-referrer", resp.Header.Get(fiber.HeaderReferrerPolicy))
|
||||
|
||||
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/next", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, "", resp.Header.Get(fiber.HeaderReferrerPolicy))
|
||||
}
|
||||
|
||||
func Test_ContentSecurityPolicy(t *testing.T) {
|
||||
app := fiber.New()
|
||||
|
||||
app.Use(New(Config{
|
||||
ContentSecurityPolicy: "default-src 'none'",
|
||||
}))
|
||||
|
||||
app.Get("/", func(c *fiber.Ctx) error {
|
||||
return c.SendString("Hello, World!")
|
||||
})
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, "default-src 'none'", resp.Header.Get(fiber.HeaderContentSecurityPolicy))
|
||||
}
|
||||
|
||||
func Test_ContentSecurityPolicyReportOnly(t *testing.T) {
|
||||
app := fiber.New()
|
||||
|
||||
app.Use(New(Config{
|
||||
ContentSecurityPolicy: "default-src 'none'",
|
||||
CSPReportOnly: true,
|
||||
}))
|
||||
|
||||
app.Get("/", func(c *fiber.Ctx) error {
|
||||
return c.SendString("Hello, World!")
|
||||
})
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, "default-src 'none'", resp.Header.Get(fiber.HeaderContentSecurityPolicyReportOnly))
|
||||
utils.AssertEqual(t, "", resp.Header.Get(fiber.HeaderContentSecurityPolicy))
|
||||
}
|
||||
|
||||
func Test_PermissionsPolicy(t *testing.T) {
|
||||
app := fiber.New()
|
||||
|
||||
app.Use(New(Config{
|
||||
PermissionPolicy: "microphone=()",
|
||||
}))
|
||||
|
||||
app.Get("/", func(c *fiber.Ctx) error {
|
||||
return c.SendString("Hello, World!")
|
||||
})
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, "microphone=()", resp.Header.Get(fiber.HeaderPermissionsPolicy))
|
||||
}
|
125
middleware/idempotency/config.go
Normal file
125
middleware/idempotency/config.go
Normal file
|
@ -0,0 +1,125 @@
|
|||
package idempotency
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/gofiber/fiber/v2/internal/storage/memory"
|
||||
)
|
||||
|
||||
var ErrInvalidIdempotencyKey = errors.New("invalid idempotency key")
|
||||
|
||||
// Config defines the config for middleware.
|
||||
type Config struct {
|
||||
// Next defines a function to skip this middleware when returned true.
|
||||
//
|
||||
// Optional. Default: a function which skips the middleware on safe HTTP request method.
|
||||
Next func(c *fiber.Ctx) bool
|
||||
|
||||
// Lifetime is the maximum lifetime of an idempotency key.
|
||||
//
|
||||
// Optional. Default: 30 * time.Minute
|
||||
Lifetime time.Duration
|
||||
|
||||
// KeyHeader is the name of the header that contains the idempotency key.
|
||||
//
|
||||
// Optional. Default: X-Idempotency-Key
|
||||
KeyHeader string
|
||||
// KeyHeaderValidate defines a function to validate the syntax of the idempotency header.
|
||||
//
|
||||
// Optional. Default: a function which ensures the header is 36 characters long (the size of an UUID).
|
||||
KeyHeaderValidate func(string) error
|
||||
|
||||
// KeepResponseHeaders is a list of headers that should be kept from the original response.
|
||||
//
|
||||
// Optional. Default: nil (to keep all headers)
|
||||
KeepResponseHeaders []string
|
||||
|
||||
// Lock locks an idempotency key.
|
||||
//
|
||||
// Optional. Default: an in-memory locker for this process only.
|
||||
Lock Locker
|
||||
|
||||
// Storage stores response data by idempotency key.
|
||||
//
|
||||
// Optional. Default: an in-memory storage for this process only.
|
||||
Storage fiber.Storage
|
||||
}
|
||||
|
||||
// ConfigDefault is the default config
|
||||
var ConfigDefault = Config{
|
||||
Next: func(c *fiber.Ctx) bool {
|
||||
// Skip middleware if the request was done using a safe HTTP method
|
||||
return fiber.IsMethodSafe(c.Method())
|
||||
},
|
||||
|
||||
Lifetime: 30 * time.Minute,
|
||||
|
||||
KeyHeader: "X-Idempotency-Key",
|
||||
KeyHeaderValidate: func(k string) error {
|
||||
if l, wl := len(k), 36; l != wl { // UUID length is 36 chars
|
||||
return fmt.Errorf("%w: invalid length: %d != %d", ErrInvalidIdempotencyKey, l, wl)
|
||||
}
|
||||
|
||||
return nil
|
||||
},
|
||||
|
||||
KeepResponseHeaders: nil,
|
||||
|
||||
Lock: nil, // Set in configDefault so we don't allocate data here.
|
||||
|
||||
Storage: nil, // Set in configDefault so we don't allocate data here.
|
||||
}
|
||||
|
||||
// Helper function to set default values
|
||||
func configDefault(config ...Config) Config {
|
||||
// Return default config if nothing provided
|
||||
if len(config) < 1 {
|
||||
cfg := ConfigDefault
|
||||
|
||||
cfg.Lock = NewMemoryLock()
|
||||
cfg.Storage = memory.New(memory.Config{
|
||||
GCInterval: cfg.Lifetime / 2, // Half the lifetime interval
|
||||
})
|
||||
|
||||
return cfg
|
||||
}
|
||||
|
||||
// Override default config
|
||||
cfg := config[0]
|
||||
|
||||
// Set default values
|
||||
|
||||
if cfg.Next == nil {
|
||||
cfg.Next = ConfigDefault.Next
|
||||
}
|
||||
|
||||
if cfg.Lifetime.Nanoseconds() == 0 {
|
||||
cfg.Lifetime = ConfigDefault.Lifetime
|
||||
}
|
||||
|
||||
if cfg.KeyHeader == "" {
|
||||
cfg.KeyHeader = ConfigDefault.KeyHeader
|
||||
}
|
||||
if cfg.KeyHeaderValidate == nil {
|
||||
cfg.KeyHeaderValidate = ConfigDefault.KeyHeaderValidate
|
||||
}
|
||||
|
||||
if cfg.KeepResponseHeaders != nil && len(cfg.KeepResponseHeaders) == 0 {
|
||||
cfg.KeepResponseHeaders = ConfigDefault.KeepResponseHeaders
|
||||
}
|
||||
|
||||
if cfg.Lock == nil {
|
||||
cfg.Lock = NewMemoryLock()
|
||||
}
|
||||
|
||||
if cfg.Storage == nil {
|
||||
cfg.Storage = memory.New(memory.Config{
|
||||
GCInterval: cfg.Lifetime / 2, // Half the lifetime interval
|
||||
})
|
||||
}
|
||||
|
||||
return cfg
|
||||
}
|
153
middleware/idempotency/idempotency.go
Normal file
153
middleware/idempotency/idempotency.go
Normal file
|
@ -0,0 +1,153 @@
|
|||
package idempotency
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/gofiber/fiber/v2/log"
|
||||
"github.com/gofiber/fiber/v2/utils"
|
||||
)
|
||||
|
||||
// Inspired by https://datatracker.ietf.org/doc/html/draft-ietf-httpapi-idempotency-key-header-02
|
||||
// and https://github.com/penguin-statistics/backend-next/blob/f2f7d5ba54fc8a58f168d153baa17b2ad4a14e45/internal/pkg/middlewares/idempotency.go
|
||||
|
||||
type localsKeys string
|
||||
|
||||
const (
|
||||
localsKeyIsFromCache localsKeys = "idempotency_isfromcache"
|
||||
localsKeyWasPutToCache localsKeys = "idempotency_wasputtocache"
|
||||
)
|
||||
|
||||
func IsFromCache(c *fiber.Ctx) bool {
|
||||
return c.Locals(localsKeyIsFromCache) != nil
|
||||
}
|
||||
|
||||
func WasPutToCache(c *fiber.Ctx) bool {
|
||||
return c.Locals(localsKeyWasPutToCache) != nil
|
||||
}
|
||||
|
||||
func New(config ...Config) fiber.Handler {
|
||||
// Set default config
|
||||
cfg := configDefault(config...)
|
||||
|
||||
keepResponseHeadersMap := make(map[string]struct{}, len(cfg.KeepResponseHeaders))
|
||||
for _, h := range cfg.KeepResponseHeaders {
|
||||
keepResponseHeadersMap[strings.ToLower(h)] = struct{}{}
|
||||
}
|
||||
|
||||
maybeWriteCachedResponse := func(c *fiber.Ctx, key string) (bool, error) {
|
||||
if val, err := cfg.Storage.Get(key); err != nil {
|
||||
return false, fmt.Errorf("failed to read response: %w", err)
|
||||
} else if val != nil {
|
||||
var res response
|
||||
if _, err := res.UnmarshalMsg(val); err != nil {
|
||||
return false, fmt.Errorf("failed to unmarshal response: %w", err)
|
||||
}
|
||||
|
||||
_ = c.Status(res.StatusCode)
|
||||
|
||||
for header, vals := range res.Headers {
|
||||
for _, val := range vals {
|
||||
c.Context().Response.Header.Add(header, val)
|
||||
}
|
||||
}
|
||||
|
||||
if len(res.Body) != 0 {
|
||||
if err := c.Send(res.Body); err != nil {
|
||||
return true, err
|
||||
}
|
||||
}
|
||||
|
||||
_ = c.Locals(localsKeyIsFromCache, true)
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
return false, nil
|
||||
}
|
||||
|
||||
return func(c *fiber.Ctx) error {
|
||||
// Don't execute middleware if Next returns true
|
||||
if cfg.Next != nil && cfg.Next(c) {
|
||||
return c.Next()
|
||||
}
|
||||
|
||||
// Don't execute middleware if the idempotency key is empty
|
||||
key := utils.CopyString(c.Get(cfg.KeyHeader))
|
||||
if key == "" {
|
||||
return c.Next()
|
||||
}
|
||||
|
||||
// Validate key
|
||||
if err := cfg.KeyHeaderValidate(key); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// First-pass: if the idempotency key is in the storage, get and return the response
|
||||
if ok, err := maybeWriteCachedResponse(c, key); err != nil {
|
||||
return fmt.Errorf("failed to write cached response at fastpath: %w", err)
|
||||
} else if ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := cfg.Lock.Lock(key); err != nil {
|
||||
return fmt.Errorf("failed to lock: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := cfg.Lock.Unlock(key); err != nil {
|
||||
log.Errorf("[IDEMPOTENCY] failed to unlock key %q: %v", key, err)
|
||||
}
|
||||
}()
|
||||
|
||||
// Lock acquired. If the idempotency key now is in the storage, get and return the response
|
||||
if ok, err := maybeWriteCachedResponse(c, key); err != nil {
|
||||
return fmt.Errorf("failed to write cached response while locked: %w", err)
|
||||
} else if ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Execute the request handler
|
||||
if err := c.Next(); err != nil {
|
||||
// If the request handler returned an error, return it and skip idempotency
|
||||
return err
|
||||
}
|
||||
|
||||
// Construct response
|
||||
res := &response{
|
||||
StatusCode: c.Response().StatusCode(),
|
||||
|
||||
Body: utils.CopyBytes(c.Response().Body()),
|
||||
}
|
||||
{
|
||||
headers := c.GetRespHeaders()
|
||||
if cfg.KeepResponseHeaders == nil {
|
||||
// Keep all
|
||||
res.Headers = headers
|
||||
} else {
|
||||
// Filter
|
||||
res.Headers = make(map[string][]string)
|
||||
for h := range headers {
|
||||
if _, ok := keepResponseHeadersMap[utils.ToLower(h)]; ok {
|
||||
res.Headers[h] = headers[h]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Marshal response
|
||||
bs, err := res.MarshalMsg(nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal response: %w", err)
|
||||
}
|
||||
|
||||
// Store response
|
||||
if err := cfg.Storage.Set(key, bs, cfg.Lifetime); err != nil {
|
||||
return fmt.Errorf("failed to save response: %w", err)
|
||||
}
|
||||
|
||||
_ = c.Locals(localsKeyWasPutToCache, true)
|
||||
|
||||
return nil
|
||||
}
|
||||
}
|
177
middleware/idempotency/idempotency_test.go
Normal file
177
middleware/idempotency/idempotency_test.go
Normal file
|
@ -0,0 +1,177 @@
|
|||
//nolint:bodyclose // Much easier to just ignore memory leaks in tests
|
||||
package idempotency_test
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strconv"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/gofiber/fiber/v2/middleware/idempotency"
|
||||
"github.com/gofiber/fiber/v2/utils"
|
||||
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
// go test -run Test_Idempotency
|
||||
func Test_Idempotency(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
app := fiber.New()
|
||||
|
||||
app.Use(func(c *fiber.Ctx) error {
|
||||
if err := c.Next(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
isMethodSafe := fiber.IsMethodSafe(c.Method())
|
||||
isIdempotent := idempotency.IsFromCache(c) || idempotency.WasPutToCache(c)
|
||||
hasReqHeader := c.Get("X-Idempotency-Key") != ""
|
||||
|
||||
if isMethodSafe {
|
||||
if isIdempotent {
|
||||
return errors.New("request with safe HTTP method should not be idempotent")
|
||||
}
|
||||
} else {
|
||||
// Unsafe
|
||||
if hasReqHeader {
|
||||
if !isIdempotent {
|
||||
return errors.New("request with unsafe HTTP method should be idempotent if X-Idempotency-Key request header is set")
|
||||
}
|
||||
} else {
|
||||
// No request header
|
||||
if isIdempotent {
|
||||
return errors.New("request with unsafe HTTP method should not be idempotent if X-Idempotency-Key request header is not set")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
// Needs to be at least a second as the memory storage doesn't support shorter durations.
|
||||
const lifetime = 1 * time.Second
|
||||
|
||||
app.Use(idempotency.New(idempotency.Config{
|
||||
Lifetime: lifetime,
|
||||
}))
|
||||
|
||||
nextCount := func() func() int {
|
||||
var count int32
|
||||
return func() int {
|
||||
return int(atomic.AddInt32(&count, 1))
|
||||
}
|
||||
}()
|
||||
|
||||
{
|
||||
handler := func(c *fiber.Ctx) error {
|
||||
return c.SendString(strconv.Itoa(nextCount()))
|
||||
}
|
||||
|
||||
app.Get("/", handler)
|
||||
app.Post("/", handler)
|
||||
}
|
||||
|
||||
app.Post("/slow", func(c *fiber.Ctx) error {
|
||||
time.Sleep(2 * lifetime)
|
||||
|
||||
return c.SendString(strconv.Itoa(nextCount()))
|
||||
})
|
||||
|
||||
doReq := func(method, route, idempotencyKey string) string {
|
||||
req := httptest.NewRequest(method, route, http.NoBody)
|
||||
if idempotencyKey != "" {
|
||||
req.Header.Set("X-Idempotency-Key", idempotencyKey)
|
||||
}
|
||||
resp, err := app.Test(req, 3*int(lifetime.Milliseconds()))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode, string(body))
|
||||
return string(body)
|
||||
}
|
||||
|
||||
utils.AssertEqual(t, "1", doReq(fiber.MethodGet, "/", ""))
|
||||
utils.AssertEqual(t, "2", doReq(fiber.MethodGet, "/", ""))
|
||||
|
||||
utils.AssertEqual(t, "3", doReq(fiber.MethodPost, "/", ""))
|
||||
utils.AssertEqual(t, "4", doReq(fiber.MethodPost, "/", ""))
|
||||
|
||||
utils.AssertEqual(t, "5", doReq(fiber.MethodGet, "/", "00000000-0000-0000-0000-000000000000"))
|
||||
utils.AssertEqual(t, "6", doReq(fiber.MethodGet, "/", "00000000-0000-0000-0000-000000000000"))
|
||||
|
||||
utils.AssertEqual(t, "7", doReq(fiber.MethodPost, "/", "00000000-0000-0000-0000-000000000000"))
|
||||
utils.AssertEqual(t, "7", doReq(fiber.MethodPost, "/", "00000000-0000-0000-0000-000000000000"))
|
||||
utils.AssertEqual(t, "8", doReq(fiber.MethodPost, "/", ""))
|
||||
utils.AssertEqual(t, "9", doReq(fiber.MethodPost, "/", "11111111-1111-1111-1111-111111111111"))
|
||||
|
||||
utils.AssertEqual(t, "7", doReq(fiber.MethodPost, "/", "00000000-0000-0000-0000-000000000000"))
|
||||
time.Sleep(2 * lifetime)
|
||||
utils.AssertEqual(t, "10", doReq(fiber.MethodPost, "/", "00000000-0000-0000-0000-000000000000"))
|
||||
utils.AssertEqual(t, "10", doReq(fiber.MethodPost, "/", "00000000-0000-0000-0000-000000000000"))
|
||||
|
||||
// Test raciness
|
||||
{
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < 100; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
utils.AssertEqual(t, "11", doReq(fiber.MethodPost, "/slow", "22222222-2222-2222-2222-222222222222"))
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
utils.AssertEqual(t, "11", doReq(fiber.MethodPost, "/slow", "22222222-2222-2222-2222-222222222222"))
|
||||
}
|
||||
time.Sleep(2 * lifetime)
|
||||
utils.AssertEqual(t, "12", doReq(fiber.MethodPost, "/slow", "22222222-2222-2222-2222-222222222222"))
|
||||
}
|
||||
|
||||
// go test -v -run=^$ -bench=Benchmark_Idempotency -benchmem -count=4
|
||||
func Benchmark_Idempotency(b *testing.B) {
|
||||
app := fiber.New()
|
||||
|
||||
// Needs to be at least a second as the memory storage doesn't support shorter durations.
|
||||
const lifetime = 1 * time.Second
|
||||
|
||||
app.Use(idempotency.New(idempotency.Config{
|
||||
Lifetime: lifetime,
|
||||
}))
|
||||
|
||||
app.Post("/", func(c *fiber.Ctx) error {
|
||||
return nil
|
||||
})
|
||||
|
||||
h := app.Handler()
|
||||
|
||||
b.Run("hit", func(b *testing.B) {
|
||||
c := &fasthttp.RequestCtx{}
|
||||
c.Request.Header.SetMethod(fiber.MethodPost)
|
||||
c.Request.SetRequestURI("/")
|
||||
c.Request.Header.Set("X-Idempotency-Key", "00000000-0000-0000-0000-000000000000")
|
||||
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
for n := 0; n < b.N; n++ {
|
||||
h(c)
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("skip", func(b *testing.B) {
|
||||
c := &fasthttp.RequestCtx{}
|
||||
c.Request.Header.SetMethod(fiber.MethodPost)
|
||||
c.Request.SetRequestURI("/")
|
||||
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
for n := 0; n < b.N; n++ {
|
||||
h(c)
|
||||
}
|
||||
})
|
||||
}
|
53
middleware/idempotency/locker.go
Normal file
53
middleware/idempotency/locker.go
Normal file
|
@ -0,0 +1,53 @@
|
|||
package idempotency
|
||||
|
||||
import (
|
||||
"sync"
|
||||
)
|
||||
|
||||
// Locker implements a spinlock for a string key.
|
||||
type Locker interface {
|
||||
Lock(key string) error
|
||||
Unlock(key string) error
|
||||
}
|
||||
|
||||
type MemoryLock struct {
|
||||
mu sync.Mutex
|
||||
|
||||
keys map[string]*sync.Mutex
|
||||
}
|
||||
|
||||
func (l *MemoryLock) Lock(key string) error {
|
||||
l.mu.Lock()
|
||||
mu, ok := l.keys[key]
|
||||
if !ok {
|
||||
mu = new(sync.Mutex)
|
||||
l.keys[key] = mu
|
||||
}
|
||||
l.mu.Unlock()
|
||||
|
||||
mu.Lock()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (l *MemoryLock) Unlock(key string) error {
|
||||
l.mu.Lock()
|
||||
mu, ok := l.keys[key]
|
||||
l.mu.Unlock()
|
||||
if !ok {
|
||||
// This happens if we try to unlock an unknown key
|
||||
return nil
|
||||
}
|
||||
|
||||
mu.Unlock()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func NewMemoryLock() *MemoryLock {
|
||||
return &MemoryLock{
|
||||
keys: make(map[string]*sync.Mutex),
|
||||
}
|
||||
}
|
||||
|
||||
var _ Locker = (*MemoryLock)(nil)
|
59
middleware/idempotency/locker_test.go
Normal file
59
middleware/idempotency/locker_test.go
Normal file
|
@ -0,0 +1,59 @@
|
|||
package idempotency_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gofiber/fiber/v2/middleware/idempotency"
|
||||
"github.com/gofiber/fiber/v2/utils"
|
||||
)
|
||||
|
||||
// go test -run Test_MemoryLock
|
||||
func Test_MemoryLock(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
l := idempotency.NewMemoryLock()
|
||||
|
||||
{
|
||||
err := l.Lock("a")
|
||||
utils.AssertEqual(t, nil, err)
|
||||
}
|
||||
{
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer close(done)
|
||||
|
||||
err := l.Lock("a")
|
||||
utils.AssertEqual(t, nil, err)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
t.Fatal("lock acquired again")
|
||||
case <-time.After(time.Second):
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
err := l.Lock("b")
|
||||
utils.AssertEqual(t, nil, err)
|
||||
}
|
||||
{
|
||||
err := l.Unlock("b")
|
||||
utils.AssertEqual(t, nil, err)
|
||||
}
|
||||
{
|
||||
err := l.Lock("b")
|
||||
utils.AssertEqual(t, nil, err)
|
||||
}
|
||||
|
||||
{
|
||||
err := l.Unlock("c")
|
||||
utils.AssertEqual(t, nil, err)
|
||||
}
|
||||
|
||||
{
|
||||
err := l.Lock("d")
|
||||
utils.AssertEqual(t, nil, err)
|
||||
}
|
||||
}
|
10
middleware/idempotency/response.go
Normal file
10
middleware/idempotency/response.go
Normal file
|
@ -0,0 +1,10 @@
|
|||
package idempotency
|
||||
|
||||
//go:generate msgp -o=response_msgp.go -io=false -unexported
|
||||
type response struct {
|
||||
StatusCode int `msg:"sc"`
|
||||
|
||||
Headers map[string][]string `msg:"hs"`
|
||||
|
||||
Body []byte `msg:"b"`
|
||||
}
|
131
middleware/idempotency/response_msgp.go
Normal file
131
middleware/idempotency/response_msgp.go
Normal file
|
@ -0,0 +1,131 @@
|
|||
package idempotency
|
||||
|
||||
// Code generated by github.com/tinylib/msgp DO NOT EDIT.
|
||||
|
||||
import (
|
||||
"github.com/tinylib/msgp/msgp"
|
||||
)
|
||||
|
||||
// MarshalMsg implements msgp.Marshaler
|
||||
func (z *response) MarshalMsg(b []byte) (o []byte, err error) {
|
||||
o = msgp.Require(b, z.Msgsize())
|
||||
// map header, size 3
|
||||
// string "sc"
|
||||
o = append(o, 0x83, 0xa2, 0x73, 0x63)
|
||||
o = msgp.AppendInt(o, z.StatusCode)
|
||||
// string "hs"
|
||||
o = append(o, 0xa2, 0x68, 0x73)
|
||||
o = msgp.AppendMapHeader(o, uint32(len(z.Headers)))
|
||||
for za0001, za0002 := range z.Headers {
|
||||
o = msgp.AppendString(o, za0001)
|
||||
o = msgp.AppendArrayHeader(o, uint32(len(za0002)))
|
||||
for za0003 := range za0002 {
|
||||
o = msgp.AppendString(o, za0002[za0003])
|
||||
}
|
||||
}
|
||||
// string "b"
|
||||
o = append(o, 0xa1, 0x62)
|
||||
o = msgp.AppendBytes(o, z.Body)
|
||||
return
|
||||
}
|
||||
|
||||
// UnmarshalMsg implements msgp.Unmarshaler
|
||||
func (z *response) UnmarshalMsg(bts []byte) (o []byte, err error) {
|
||||
var field []byte
|
||||
_ = field
|
||||
var zb0001 uint32
|
||||
zb0001, bts, err = msgp.ReadMapHeaderBytes(bts)
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err)
|
||||
return
|
||||
}
|
||||
for zb0001 > 0 {
|
||||
zb0001--
|
||||
field, bts, err = msgp.ReadMapKeyZC(bts)
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err)
|
||||
return
|
||||
}
|
||||
switch msgp.UnsafeString(field) {
|
||||
case "sc":
|
||||
z.StatusCode, bts, err = msgp.ReadIntBytes(bts)
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err, "StatusCode")
|
||||
return
|
||||
}
|
||||
case "hs":
|
||||
var zb0002 uint32
|
||||
zb0002, bts, err = msgp.ReadMapHeaderBytes(bts)
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err, "Headers")
|
||||
return
|
||||
}
|
||||
if z.Headers == nil {
|
||||
z.Headers = make(map[string][]string, zb0002)
|
||||
} else if len(z.Headers) > 0 {
|
||||
for key := range z.Headers {
|
||||
delete(z.Headers, key)
|
||||
}
|
||||
}
|
||||
for zb0002 > 0 {
|
||||
var za0001 string
|
||||
var za0002 []string
|
||||
zb0002--
|
||||
za0001, bts, err = msgp.ReadStringBytes(bts)
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err, "Headers")
|
||||
return
|
||||
}
|
||||
var zb0003 uint32
|
||||
zb0003, bts, err = msgp.ReadArrayHeaderBytes(bts)
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err, "Headers", za0001)
|
||||
return
|
||||
}
|
||||
if cap(za0002) >= int(zb0003) {
|
||||
za0002 = (za0002)[:zb0003]
|
||||
} else {
|
||||
za0002 = make([]string, zb0003)
|
||||
}
|
||||
for za0003 := range za0002 {
|
||||
za0002[za0003], bts, err = msgp.ReadStringBytes(bts)
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err, "Headers", za0001, za0003)
|
||||
return
|
||||
}
|
||||
}
|
||||
z.Headers[za0001] = za0002
|
||||
}
|
||||
case "b":
|
||||
z.Body, bts, err = msgp.ReadBytesBytes(bts, z.Body)
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err, "Body")
|
||||
return
|
||||
}
|
||||
default:
|
||||
bts, err = msgp.Skip(bts)
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
o = bts
|
||||
return
|
||||
}
|
||||
|
||||
// Msgsize returns an upper bound estimate of the number of bytes occupied by the serialized message
|
||||
func (z *response) Msgsize() (s int) {
|
||||
s = 1 + 3 + msgp.IntSize + 3 + msgp.MapHeaderSize
|
||||
if z.Headers != nil {
|
||||
for za0001, za0002 := range z.Headers {
|
||||
_ = za0002
|
||||
s += msgp.StringPrefixSize + len(za0001) + msgp.ArrayHeaderSize
|
||||
for za0003 := range za0002 {
|
||||
s += msgp.StringPrefixSize + len(za0002[za0003])
|
||||
}
|
||||
}
|
||||
}
|
||||
s += 2 + msgp.BytesPrefixSize + len(z.Body)
|
||||
return
|
||||
}
|
67
middleware/idempotency/response_msgp_test.go
Normal file
67
middleware/idempotency/response_msgp_test.go
Normal file
|
@ -0,0 +1,67 @@
|
|||
package idempotency
|
||||
|
||||
// Code generated by github.com/tinylib/msgp DO NOT EDIT.
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/tinylib/msgp/msgp"
|
||||
)
|
||||
|
||||
func TestMarshalUnmarshalresponse(t *testing.T) {
|
||||
v := response{}
|
||||
bts, err := v.MarshalMsg(nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
left, err := v.UnmarshalMsg(bts)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(left) > 0 {
|
||||
t.Errorf("%d bytes left over after UnmarshalMsg(): %q", len(left), left)
|
||||
}
|
||||
|
||||
left, err = msgp.Skip(bts)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(left) > 0 {
|
||||
t.Errorf("%d bytes left over after Skip(): %q", len(left), left)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkMarshalMsgresponse(b *testing.B) {
|
||||
v := response{}
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
v.MarshalMsg(nil)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkAppendMsgresponse(b *testing.B) {
|
||||
v := response{}
|
||||
bts := make([]byte, 0, v.Msgsize())
|
||||
bts, _ = v.MarshalMsg(bts[0:0])
|
||||
b.SetBytes(int64(len(bts)))
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
bts, _ = v.MarshalMsg(bts[0:0])
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkUnmarshalresponse(b *testing.B) {
|
||||
v := response{}
|
||||
bts, _ := v.MarshalMsg(nil)
|
||||
b.ReportAllocs()
|
||||
b.SetBytes(int64(len(bts)))
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, err := v.UnmarshalMsg(bts)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
95
middleware/keyauth/config.go
Normal file
95
middleware/keyauth/config.go
Normal file
|
@ -0,0 +1,95 @@
|
|||
package keyauth
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
)
|
||||
|
||||
// Config defines the config for middleware.
|
||||
type Config struct {
|
||||
// Next defines a function to skip middleware.
|
||||
// Optional. Default: nil
|
||||
Next func(*fiber.Ctx) bool
|
||||
|
||||
// SuccessHandler defines a function which is executed for a valid key.
|
||||
// Optional. Default: nil
|
||||
SuccessHandler fiber.Handler
|
||||
|
||||
// ErrorHandler defines a function which is executed for an invalid key.
|
||||
// It may be used to define a custom error.
|
||||
// Optional. Default: 401 Invalid or expired key
|
||||
ErrorHandler fiber.ErrorHandler
|
||||
|
||||
// KeyLookup is a string in the form of "<source>:<name>" that is used
|
||||
// to extract key from the request.
|
||||
// Optional. Default value "header:Authorization".
|
||||
// Possible values:
|
||||
// - "header:<name>"
|
||||
// - "query:<name>"
|
||||
// - "form:<name>"
|
||||
// - "param:<name>"
|
||||
// - "cookie:<name>"
|
||||
KeyLookup string
|
||||
|
||||
// AuthScheme to be used in the Authorization header.
|
||||
// Optional. Default value "Bearer".
|
||||
AuthScheme string
|
||||
|
||||
// Validator is a function to validate key.
|
||||
Validator func(*fiber.Ctx, string) (bool, error)
|
||||
|
||||
// Context key to store the bearertoken from the token into context.
|
||||
// Optional. Default: "token".
|
||||
ContextKey interface{}
|
||||
}
|
||||
|
||||
// ConfigDefault is the default config
|
||||
var ConfigDefault = Config{
|
||||
SuccessHandler: func(c *fiber.Ctx) error {
|
||||
return c.Next()
|
||||
},
|
||||
ErrorHandler: func(c *fiber.Ctx, err error) error {
|
||||
if errors.Is(err, ErrMissingOrMalformedAPIKey) {
|
||||
return c.Status(fiber.StatusUnauthorized).SendString(err.Error())
|
||||
}
|
||||
return c.Status(fiber.StatusUnauthorized).SendString("Invalid or expired API Key")
|
||||
},
|
||||
KeyLookup: "header:" + fiber.HeaderAuthorization,
|
||||
AuthScheme: "Bearer",
|
||||
ContextKey: "token",
|
||||
}
|
||||
|
||||
// Helper function to set default values
|
||||
func configDefault(config ...Config) Config {
|
||||
// Return default config if nothing provided
|
||||
if len(config) < 1 {
|
||||
return ConfigDefault
|
||||
}
|
||||
|
||||
// Override default config
|
||||
cfg := config[0]
|
||||
|
||||
// Set default values
|
||||
if cfg.SuccessHandler == nil {
|
||||
cfg.SuccessHandler = ConfigDefault.SuccessHandler
|
||||
}
|
||||
if cfg.ErrorHandler == nil {
|
||||
cfg.ErrorHandler = ConfigDefault.ErrorHandler
|
||||
}
|
||||
if cfg.KeyLookup == "" {
|
||||
cfg.KeyLookup = ConfigDefault.KeyLookup
|
||||
// set AuthScheme as "Bearer" only if KeyLookup is set to default.
|
||||
if cfg.AuthScheme == "" {
|
||||
cfg.AuthScheme = ConfigDefault.AuthScheme
|
||||
}
|
||||
}
|
||||
if cfg.Validator == nil {
|
||||
panic("fiber: keyauth middleware requires a validator function")
|
||||
}
|
||||
if cfg.ContextKey == nil {
|
||||
cfg.ContextKey = ConfigDefault.ContextKey
|
||||
}
|
||||
|
||||
return cfg
|
||||
}
|
121
middleware/keyauth/keyauth.go
Normal file
121
middleware/keyauth/keyauth.go
Normal file
|
@ -0,0 +1,121 @@
|
|||
// Special thanks to Echo: https://github.com/labstack/echo/blob/master/middleware/key_auth.go
|
||||
package keyauth
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
)
|
||||
|
||||
// When there is no request of the key thrown ErrMissingOrMalformedAPIKey
|
||||
var ErrMissingOrMalformedAPIKey = errors.New("missing or malformed API Key")
|
||||
|
||||
const (
|
||||
query = "query"
|
||||
form = "form"
|
||||
param = "param"
|
||||
cookie = "cookie"
|
||||
)
|
||||
|
||||
// New creates a new middleware handler
|
||||
func New(config ...Config) fiber.Handler {
|
||||
// Init config
|
||||
cfg := configDefault(config...)
|
||||
|
||||
// Initialize
|
||||
parts := strings.Split(cfg.KeyLookup, ":")
|
||||
extractor := keyFromHeader(parts[1], cfg.AuthScheme)
|
||||
switch parts[0] {
|
||||
case query:
|
||||
extractor = keyFromQuery(parts[1])
|
||||
case form:
|
||||
extractor = keyFromForm(parts[1])
|
||||
case param:
|
||||
extractor = keyFromParam(parts[1])
|
||||
case cookie:
|
||||
extractor = keyFromCookie(parts[1])
|
||||
}
|
||||
|
||||
// Return middleware handler
|
||||
return func(c *fiber.Ctx) error {
|
||||
// Filter request to skip middleware
|
||||
if cfg.Next != nil && cfg.Next(c) {
|
||||
return c.Next()
|
||||
}
|
||||
|
||||
// Extract and verify key
|
||||
key, err := extractor(c)
|
||||
if err != nil {
|
||||
return cfg.ErrorHandler(c, err)
|
||||
}
|
||||
|
||||
valid, err := cfg.Validator(c, key)
|
||||
|
||||
if err == nil && valid {
|
||||
c.Locals(cfg.ContextKey, key)
|
||||
return cfg.SuccessHandler(c)
|
||||
}
|
||||
return cfg.ErrorHandler(c, err)
|
||||
}
|
||||
}
|
||||
|
||||
// keyFromHeader returns a function that extracts api key from the request header.
|
||||
func keyFromHeader(header, authScheme string) func(c *fiber.Ctx) (string, error) {
|
||||
return func(c *fiber.Ctx) (string, error) {
|
||||
auth := c.Get(header)
|
||||
l := len(authScheme)
|
||||
if len(auth) > 0 && l == 0 {
|
||||
return auth, nil
|
||||
}
|
||||
if len(auth) > l+1 && auth[:l] == authScheme {
|
||||
return auth[l+1:], nil
|
||||
}
|
||||
return "", ErrMissingOrMalformedAPIKey
|
||||
}
|
||||
}
|
||||
|
||||
// keyFromQuery returns a function that extracts api key from the query string.
|
||||
func keyFromQuery(param string) func(c *fiber.Ctx) (string, error) {
|
||||
return func(c *fiber.Ctx) (string, error) {
|
||||
key := c.Query(param)
|
||||
if key == "" {
|
||||
return "", ErrMissingOrMalformedAPIKey
|
||||
}
|
||||
return key, nil
|
||||
}
|
||||
}
|
||||
|
||||
// keyFromForm returns a function that extracts api key from the form.
|
||||
func keyFromForm(param string) func(c *fiber.Ctx) (string, error) {
|
||||
return func(c *fiber.Ctx) (string, error) {
|
||||
key := c.FormValue(param)
|
||||
if key == "" {
|
||||
return "", ErrMissingOrMalformedAPIKey
|
||||
}
|
||||
return key, nil
|
||||
}
|
||||
}
|
||||
|
||||
// keyFromParam returns a function that extracts api key from the url param string.
|
||||
func keyFromParam(param string) func(c *fiber.Ctx) (string, error) {
|
||||
return func(c *fiber.Ctx) (string, error) {
|
||||
key, err := url.PathUnescape(c.Params(param))
|
||||
if err != nil {
|
||||
return "", ErrMissingOrMalformedAPIKey
|
||||
}
|
||||
return key, nil
|
||||
}
|
||||
}
|
||||
|
||||
// keyFromCookie returns a function that extracts api key from the named cookie.
|
||||
func keyFromCookie(name string) func(c *fiber.Ctx) (string, error) {
|
||||
return func(c *fiber.Ctx) (string, error) {
|
||||
key := c.Cookies(name)
|
||||
if key == "" {
|
||||
return "", ErrMissingOrMalformedAPIKey
|
||||
}
|
||||
return key, nil
|
||||
}
|
||||
}
|
461
middleware/keyauth/keyauth_test.go
Normal file
461
middleware/keyauth/keyauth_test.go
Normal file
|
@ -0,0 +1,461 @@
|
|||
//nolint:bodyclose // Much easier to just ignore memory leaks in tests
|
||||
package keyauth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"testing"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/gofiber/fiber/v2/utils"
|
||||
)
|
||||
|
||||
const CorrectKey = "specials: !$%,.#\"!?~`<>@$^*(){}[]|/\\123"
|
||||
|
||||
func TestAuthSources(t *testing.T) {
|
||||
// define test cases
|
||||
testSources := []string{"header", "cookie", "query", "param", "form"}
|
||||
|
||||
tests := []struct {
|
||||
route string
|
||||
authTokenName string
|
||||
description string
|
||||
APIKey string
|
||||
expectedCode int
|
||||
expectedBody string
|
||||
}{
|
||||
{
|
||||
route: "/",
|
||||
authTokenName: "access_token",
|
||||
description: "auth with correct key",
|
||||
APIKey: CorrectKey,
|
||||
expectedCode: 200,
|
||||
expectedBody: "Success!",
|
||||
},
|
||||
{
|
||||
route: "/",
|
||||
authTokenName: "access_token",
|
||||
description: "auth with no key",
|
||||
APIKey: "",
|
||||
expectedCode: 401, // 404 in case of param authentication
|
||||
expectedBody: "missing or malformed API Key",
|
||||
},
|
||||
{
|
||||
route: "/",
|
||||
authTokenName: "access_token",
|
||||
description: "auth with wrong key",
|
||||
APIKey: "WRONGKEY",
|
||||
expectedCode: 401,
|
||||
expectedBody: "missing or malformed API Key",
|
||||
},
|
||||
}
|
||||
|
||||
for _, authSource := range testSources {
|
||||
t.Run(authSource, func(t *testing.T) {
|
||||
for _, test := range tests {
|
||||
// setup the fiber endpoint
|
||||
// note that if UnescapePath: false (the default)
|
||||
// escaped characters (such as `\"`) will not be handled correctly in the tests
|
||||
app := fiber.New(fiber.Config{UnescapePath: true})
|
||||
|
||||
authMiddleware := New(Config{
|
||||
KeyLookup: authSource + ":" + test.authTokenName,
|
||||
Validator: func(c *fiber.Ctx, key string) (bool, error) {
|
||||
if key == CorrectKey {
|
||||
return true, nil
|
||||
}
|
||||
return false, ErrMissingOrMalformedAPIKey
|
||||
},
|
||||
})
|
||||
|
||||
var route string
|
||||
if authSource == param {
|
||||
route = test.route + ":" + test.authTokenName
|
||||
app.Use(route, authMiddleware)
|
||||
} else {
|
||||
route = test.route
|
||||
app.Use(authMiddleware)
|
||||
}
|
||||
|
||||
app.Get(route, func(c *fiber.Ctx) error {
|
||||
return c.SendString("Success!")
|
||||
})
|
||||
|
||||
// construct the test HTTP request
|
||||
var req *http.Request
|
||||
req, err := http.NewRequestWithContext(context.Background(), fiber.MethodGet, test.route, nil)
|
||||
utils.AssertEqual(t, err, nil)
|
||||
|
||||
// setup the apikey for the different auth schemes
|
||||
if authSource == "header" {
|
||||
req.Header.Set(test.authTokenName, test.APIKey)
|
||||
} else if authSource == "cookie" {
|
||||
req.Header.Set("Cookie", test.authTokenName+"="+test.APIKey)
|
||||
} else if authSource == "query" || authSource == "form" {
|
||||
q := req.URL.Query()
|
||||
q.Add(test.authTokenName, test.APIKey)
|
||||
req.URL.RawQuery = q.Encode()
|
||||
} else if authSource == "param" {
|
||||
r := req.URL.Path
|
||||
r += url.PathEscape(test.APIKey)
|
||||
req.URL.Path = r
|
||||
}
|
||||
|
||||
res, err := app.Test(req, -1)
|
||||
|
||||
utils.AssertEqual(t, nil, err, test.description)
|
||||
|
||||
// test the body of the request
|
||||
body, err := io.ReadAll(res.Body)
|
||||
// for param authentication, the route would be /:access_token
|
||||
// when the access_token is empty, it leads to a 404 (not found)
|
||||
// not a 401 (auth error)
|
||||
if authSource == "param" && test.APIKey == "" {
|
||||
test.expectedCode = 404
|
||||
test.expectedBody = "Cannot GET /"
|
||||
}
|
||||
utils.AssertEqual(t, test.expectedCode, res.StatusCode, test.description)
|
||||
|
||||
// body
|
||||
utils.AssertEqual(t, nil, err, test.description)
|
||||
utils.AssertEqual(t, test.expectedBody, string(body), test.description)
|
||||
|
||||
err = res.Body.Close()
|
||||
utils.AssertEqual(t, err, nil)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMultipleKeyAuth(t *testing.T) {
|
||||
// setup the fiber endpoint
|
||||
app := fiber.New()
|
||||
|
||||
// setup keyauth for /auth1
|
||||
app.Use(New(Config{
|
||||
Next: func(c *fiber.Ctx) bool {
|
||||
return c.OriginalURL() != "/auth1"
|
||||
},
|
||||
KeyLookup: "header:key",
|
||||
Validator: func(c *fiber.Ctx, key string) (bool, error) {
|
||||
if key == "password1" {
|
||||
return true, nil
|
||||
}
|
||||
return false, ErrMissingOrMalformedAPIKey
|
||||
},
|
||||
}))
|
||||
|
||||
// setup keyauth for /auth2
|
||||
app.Use(New(Config{
|
||||
Next: func(c *fiber.Ctx) bool {
|
||||
return c.OriginalURL() != "/auth2"
|
||||
},
|
||||
KeyLookup: "header:key",
|
||||
Validator: func(c *fiber.Ctx, key string) (bool, error) {
|
||||
if key == "password2" {
|
||||
return true, nil
|
||||
}
|
||||
return false, ErrMissingOrMalformedAPIKey
|
||||
},
|
||||
}))
|
||||
|
||||
app.Get("/", func(c *fiber.Ctx) error {
|
||||
return c.SendString("No auth needed!")
|
||||
})
|
||||
|
||||
app.Get("/auth1", func(c *fiber.Ctx) error {
|
||||
return c.SendString("Successfully authenticated for auth1!")
|
||||
})
|
||||
|
||||
app.Get("/auth2", func(c *fiber.Ctx) error {
|
||||
return c.SendString("Successfully authenticated for auth2!")
|
||||
})
|
||||
|
||||
// define test cases
|
||||
tests := []struct {
|
||||
route string
|
||||
description string
|
||||
APIKey string
|
||||
expectedCode int
|
||||
expectedBody string
|
||||
}{
|
||||
// No auth needed for /
|
||||
{
|
||||
route: "/",
|
||||
description: "No password needed",
|
||||
APIKey: "",
|
||||
expectedCode: 200,
|
||||
expectedBody: "No auth needed!",
|
||||
},
|
||||
|
||||
// auth needed for auth1
|
||||
{
|
||||
route: "/auth1",
|
||||
description: "Normal Authentication Case",
|
||||
APIKey: "password1",
|
||||
expectedCode: 200,
|
||||
expectedBody: "Successfully authenticated for auth1!",
|
||||
},
|
||||
{
|
||||
route: "/auth1",
|
||||
description: "Wrong API Key",
|
||||
APIKey: "WRONG KEY",
|
||||
expectedCode: 401,
|
||||
expectedBody: "missing or malformed API Key",
|
||||
},
|
||||
{
|
||||
route: "/auth1",
|
||||
description: "Wrong API Key",
|
||||
APIKey: "", // NO KEY
|
||||
expectedCode: 401,
|
||||
expectedBody: "missing or malformed API Key",
|
||||
},
|
||||
|
||||
// Auth 2 has a different password
|
||||
{
|
||||
route: "/auth2",
|
||||
description: "Normal Authentication Case for auth2",
|
||||
APIKey: "password2",
|
||||
expectedCode: 200,
|
||||
expectedBody: "Successfully authenticated for auth2!",
|
||||
},
|
||||
{
|
||||
route: "/auth2",
|
||||
description: "Wrong API Key",
|
||||
APIKey: "WRONG KEY",
|
||||
expectedCode: 401,
|
||||
expectedBody: "missing or malformed API Key",
|
||||
},
|
||||
{
|
||||
route: "/auth2",
|
||||
description: "Wrong API Key",
|
||||
APIKey: "", // NO KEY
|
||||
expectedCode: 401,
|
||||
expectedBody: "missing or malformed API Key",
|
||||
},
|
||||
}
|
||||
|
||||
// run the tests
|
||||
for _, test := range tests {
|
||||
var req *http.Request
|
||||
req, err := http.NewRequestWithContext(context.Background(), fiber.MethodGet, test.route, nil)
|
||||
utils.AssertEqual(t, err, nil)
|
||||
if test.APIKey != "" {
|
||||
req.Header.Set("key", test.APIKey)
|
||||
}
|
||||
|
||||
res, err := app.Test(req, -1)
|
||||
|
||||
utils.AssertEqual(t, nil, err, test.description)
|
||||
|
||||
// test the body of the request
|
||||
body, err := io.ReadAll(res.Body)
|
||||
utils.AssertEqual(t, test.expectedCode, res.StatusCode, test.description)
|
||||
|
||||
// body
|
||||
utils.AssertEqual(t, nil, err, test.description)
|
||||
utils.AssertEqual(t, test.expectedBody, string(body), test.description)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCustomSuccessAndFailureHandlers(t *testing.T) {
|
||||
app := fiber.New()
|
||||
|
||||
app.Use(New(Config{
|
||||
SuccessHandler: func(c *fiber.Ctx) error {
|
||||
return c.Status(fiber.StatusOK).SendString("API key is valid and request was handled by custom success handler")
|
||||
},
|
||||
ErrorHandler: func(c *fiber.Ctx, err error) error {
|
||||
return c.Status(fiber.StatusUnauthorized).SendString("API key is invalid and request was handled by custom error handler")
|
||||
},
|
||||
Validator: func(c *fiber.Ctx, key string) (bool, error) {
|
||||
if key == CorrectKey {
|
||||
return true, nil
|
||||
}
|
||||
return false, ErrMissingOrMalformedAPIKey
|
||||
},
|
||||
}))
|
||||
|
||||
// Define a test handler that should not be called
|
||||
app.Get("/", func(c *fiber.Ctx) error {
|
||||
t.Error("Test handler should not be called")
|
||||
return nil
|
||||
})
|
||||
|
||||
// Create a request without an API key and send it to the app
|
||||
res, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||
utils.AssertEqual(t, err, nil)
|
||||
|
||||
// Read the response body into a string
|
||||
body, err := io.ReadAll(res.Body)
|
||||
utils.AssertEqual(t, err, nil)
|
||||
|
||||
// Check that the response has the expected status code and body
|
||||
utils.AssertEqual(t, res.StatusCode, http.StatusUnauthorized)
|
||||
utils.AssertEqual(t, string(body), "API key is invalid and request was handled by custom error handler")
|
||||
|
||||
// Create a request with a valid API key in the Authorization header
|
||||
req := httptest.NewRequest(fiber.MethodGet, "/", nil)
|
||||
req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", CorrectKey))
|
||||
|
||||
// Send the request to the app
|
||||
res, err = app.Test(req)
|
||||
utils.AssertEqual(t, err, nil)
|
||||
|
||||
// Read the response body into a string
|
||||
body, err = io.ReadAll(res.Body)
|
||||
utils.AssertEqual(t, err, nil)
|
||||
|
||||
// Check that the response has the expected status code and body
|
||||
utils.AssertEqual(t, res.StatusCode, http.StatusOK)
|
||||
utils.AssertEqual(t, string(body), "API key is valid and request was handled by custom success handler")
|
||||
}
|
||||
|
||||
func TestCustomNextFunc(t *testing.T) {
|
||||
app := fiber.New()
|
||||
|
||||
app.Use(New(Config{
|
||||
Next: func(c *fiber.Ctx) bool {
|
||||
return c.Path() == "/allowed"
|
||||
},
|
||||
Validator: func(c *fiber.Ctx, key string) (bool, error) {
|
||||
if key == CorrectKey {
|
||||
return true, nil
|
||||
}
|
||||
return false, ErrMissingOrMalformedAPIKey
|
||||
},
|
||||
}))
|
||||
|
||||
// Define a test handler
|
||||
app.Get("/allowed", func(c *fiber.Ctx) error {
|
||||
return c.SendString("API key is valid and request was allowed by custom filter")
|
||||
})
|
||||
|
||||
// Create a request with the "/allowed" path and send it to the app
|
||||
req := httptest.NewRequest(fiber.MethodGet, "/allowed", nil)
|
||||
res, err := app.Test(req)
|
||||
utils.AssertEqual(t, err, nil)
|
||||
|
||||
// Read the response body into a string
|
||||
body, err := io.ReadAll(res.Body)
|
||||
utils.AssertEqual(t, err, nil)
|
||||
|
||||
// Check that the response has the expected status code and body
|
||||
utils.AssertEqual(t, res.StatusCode, http.StatusOK)
|
||||
utils.AssertEqual(t, string(body), "API key is valid and request was allowed by custom filter")
|
||||
|
||||
// Create a request with a different path and send it to the app without correct key
|
||||
req = httptest.NewRequest(fiber.MethodGet, "/not-allowed", nil)
|
||||
res, err = app.Test(req)
|
||||
utils.AssertEqual(t, err, nil)
|
||||
|
||||
// Read the response body into a string
|
||||
body, err = io.ReadAll(res.Body)
|
||||
utils.AssertEqual(t, err, nil)
|
||||
|
||||
// Check that the response has the expected status code and body
|
||||
utils.AssertEqual(t, res.StatusCode, http.StatusUnauthorized)
|
||||
utils.AssertEqual(t, string(body), ErrMissingOrMalformedAPIKey.Error())
|
||||
|
||||
// Create a request with a different path and send it to the app with correct key
|
||||
req = httptest.NewRequest(fiber.MethodGet, "/not-allowed", nil)
|
||||
req.Header.Add("Authorization", fmt.Sprintf("Basic %s", CorrectKey))
|
||||
|
||||
res, err = app.Test(req)
|
||||
utils.AssertEqual(t, err, nil)
|
||||
|
||||
// Read the response body into a string
|
||||
body, err = io.ReadAll(res.Body)
|
||||
utils.AssertEqual(t, err, nil)
|
||||
|
||||
// Check that the response has the expected status code and body
|
||||
utils.AssertEqual(t, res.StatusCode, http.StatusUnauthorized)
|
||||
utils.AssertEqual(t, string(body), ErrMissingOrMalformedAPIKey.Error())
|
||||
}
|
||||
|
||||
func TestAuthSchemeToken(t *testing.T) {
|
||||
app := fiber.New()
|
||||
|
||||
app.Use(New(Config{
|
||||
AuthScheme: "Token",
|
||||
Validator: func(c *fiber.Ctx, key string) (bool, error) {
|
||||
if key == CorrectKey {
|
||||
return true, nil
|
||||
}
|
||||
return false, ErrMissingOrMalformedAPIKey
|
||||
},
|
||||
}))
|
||||
|
||||
// Define a test handler
|
||||
app.Get("/", func(c *fiber.Ctx) error {
|
||||
return c.SendString("API key is valid")
|
||||
})
|
||||
|
||||
// Create a request with a valid API key in the "Token" Authorization header
|
||||
req := httptest.NewRequest(fiber.MethodGet, "/", nil)
|
||||
req.Header.Add("Authorization", fmt.Sprintf("Token %s", CorrectKey))
|
||||
|
||||
// Send the request to the app
|
||||
res, err := app.Test(req)
|
||||
utils.AssertEqual(t, err, nil)
|
||||
|
||||
// Read the response body into a string
|
||||
body, err := io.ReadAll(res.Body)
|
||||
utils.AssertEqual(t, err, nil)
|
||||
|
||||
// Check that the response has the expected status code and body
|
||||
utils.AssertEqual(t, res.StatusCode, http.StatusOK)
|
||||
utils.AssertEqual(t, string(body), "API key is valid")
|
||||
}
|
||||
|
||||
func TestAuthSchemeBasic(t *testing.T) {
|
||||
app := fiber.New()
|
||||
|
||||
app.Use(New(Config{
|
||||
KeyLookup: "header:Authorization",
|
||||
AuthScheme: "Basic",
|
||||
Validator: func(c *fiber.Ctx, key string) (bool, error) {
|
||||
if key == CorrectKey {
|
||||
return true, nil
|
||||
}
|
||||
return false, ErrMissingOrMalformedAPIKey
|
||||
},
|
||||
}))
|
||||
|
||||
// Define a test handler
|
||||
app.Get("/", func(c *fiber.Ctx) error {
|
||||
return c.SendString("API key is valid")
|
||||
})
|
||||
|
||||
// Create a request without an API key and Send the request to the app
|
||||
res, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||
utils.AssertEqual(t, err, nil)
|
||||
|
||||
// Read the response body into a string
|
||||
body, err := io.ReadAll(res.Body)
|
||||
utils.AssertEqual(t, err, nil)
|
||||
|
||||
// Check that the response has the expected status code and body
|
||||
utils.AssertEqual(t, res.StatusCode, http.StatusUnauthorized)
|
||||
utils.AssertEqual(t, string(body), ErrMissingOrMalformedAPIKey.Error())
|
||||
|
||||
// Create a request with a valid API key in the "Authorization" header using the "Basic" scheme
|
||||
req := httptest.NewRequest(fiber.MethodGet, "/", nil)
|
||||
req.Header.Add("Authorization", fmt.Sprintf("Basic %s", CorrectKey))
|
||||
|
||||
// Send the request to the app
|
||||
res, err = app.Test(req)
|
||||
utils.AssertEqual(t, err, nil)
|
||||
|
||||
// Read the response body into a string
|
||||
body, err = io.ReadAll(res.Body)
|
||||
utils.AssertEqual(t, err, nil)
|
||||
|
||||
// Check that the response has the expected status code and body
|
||||
utils.AssertEqual(t, res.StatusCode, http.StatusOK)
|
||||
utils.AssertEqual(t, string(body), "API key is valid")
|
||||
}
|
128
middleware/limiter/config.go
Normal file
128
middleware/limiter/config.go
Normal file
|
@ -0,0 +1,128 @@
|
|||
package limiter
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/gofiber/fiber/v2/log"
|
||||
)
|
||||
|
||||
// Config defines the config for middleware.
|
||||
type Config struct {
|
||||
// Next defines a function to skip this middleware when returned true.
|
||||
//
|
||||
// Optional. Default: nil
|
||||
Next func(c *fiber.Ctx) bool
|
||||
|
||||
// Max number of recent connections during `Expiration` seconds before sending a 429 response
|
||||
//
|
||||
// Default: 5
|
||||
Max int
|
||||
|
||||
// KeyGenerator allows you to generate custom keys, by default c.IP() is used
|
||||
//
|
||||
// Default: func(c *fiber.Ctx) string {
|
||||
// return c.IP()
|
||||
// }
|
||||
KeyGenerator func(*fiber.Ctx) string
|
||||
|
||||
// Expiration is the time on how long to keep records of requests in memory
|
||||
//
|
||||
// Default: 1 * time.Minute
|
||||
Expiration time.Duration
|
||||
|
||||
// LimitReached is called when a request hits the limit
|
||||
//
|
||||
// Default: func(c *fiber.Ctx) error {
|
||||
// return c.SendStatus(fiber.StatusTooManyRequests)
|
||||
// }
|
||||
LimitReached fiber.Handler
|
||||
|
||||
// When set to true, requests with StatusCode >= 400 won't be counted.
|
||||
//
|
||||
// Default: false
|
||||
SkipFailedRequests bool
|
||||
|
||||
// When set to true, requests with StatusCode < 400 won't be counted.
|
||||
//
|
||||
// Default: false
|
||||
SkipSuccessfulRequests bool
|
||||
|
||||
// Store is used to store the state of the middleware
|
||||
//
|
||||
// Default: an in memory store for this process only
|
||||
Storage fiber.Storage
|
||||
|
||||
// LimiterMiddleware is the struct that implements a limiter middleware.
|
||||
//
|
||||
// Default: a new Fixed Window Rate Limiter
|
||||
LimiterMiddleware LimiterHandler
|
||||
|
||||
// Deprecated: Use Expiration instead
|
||||
Duration time.Duration
|
||||
|
||||
// Deprecated: Use Storage instead
|
||||
Store fiber.Storage
|
||||
|
||||
// Deprecated: Use KeyGenerator instead
|
||||
Key func(*fiber.Ctx) string
|
||||
}
|
||||
|
||||
// ConfigDefault is the default config
|
||||
var ConfigDefault = Config{
|
||||
Max: 5,
|
||||
Expiration: 1 * time.Minute,
|
||||
KeyGenerator: func(c *fiber.Ctx) string {
|
||||
return c.IP()
|
||||
},
|
||||
LimitReached: func(c *fiber.Ctx) error {
|
||||
return c.SendStatus(fiber.StatusTooManyRequests)
|
||||
},
|
||||
SkipFailedRequests: false,
|
||||
SkipSuccessfulRequests: false,
|
||||
LimiterMiddleware: FixedWindow{},
|
||||
}
|
||||
|
||||
// Helper function to set default values
|
||||
func configDefault(config ...Config) Config {
|
||||
// Return default config if nothing provided
|
||||
if len(config) < 1 {
|
||||
return ConfigDefault
|
||||
}
|
||||
|
||||
// Override default config
|
||||
cfg := config[0]
|
||||
|
||||
// Set default values
|
||||
if int(cfg.Duration.Seconds()) > 0 {
|
||||
log.Warn("[LIMITER] Duration is deprecated, please use Expiration")
|
||||
cfg.Expiration = cfg.Duration
|
||||
}
|
||||
if cfg.Key != nil {
|
||||
log.Warn("[LIMITER] Key is deprecated, please us KeyGenerator")
|
||||
cfg.KeyGenerator = cfg.Key
|
||||
}
|
||||
if cfg.Store != nil {
|
||||
log.Warn("[LIMITER] Store is deprecated, please use Storage")
|
||||
cfg.Storage = cfg.Store
|
||||
}
|
||||
if cfg.Next == nil {
|
||||
cfg.Next = ConfigDefault.Next
|
||||
}
|
||||
if cfg.Max <= 0 {
|
||||
cfg.Max = ConfigDefault.Max
|
||||
}
|
||||
if int(cfg.Expiration.Seconds()) <= 0 {
|
||||
cfg.Expiration = ConfigDefault.Expiration
|
||||
}
|
||||
if cfg.KeyGenerator == nil {
|
||||
cfg.KeyGenerator = ConfigDefault.KeyGenerator
|
||||
}
|
||||
if cfg.LimitReached == nil {
|
||||
cfg.LimitReached = ConfigDefault.LimitReached
|
||||
}
|
||||
if cfg.LimiterMiddleware == nil {
|
||||
cfg.LimiterMiddleware = ConfigDefault.LimiterMiddleware
|
||||
}
|
||||
return cfg
|
||||
}
|
25
middleware/limiter/limiter.go
Normal file
25
middleware/limiter/limiter.go
Normal file
|
@ -0,0 +1,25 @@
|
|||
package limiter
|
||||
|
||||
import (
|
||||
"github.com/gofiber/fiber/v2"
|
||||
)
|
||||
|
||||
const (
|
||||
// X-RateLimit-* headers
|
||||
xRateLimitLimit = "X-RateLimit-Limit"
|
||||
xRateLimitRemaining = "X-RateLimit-Remaining"
|
||||
xRateLimitReset = "X-RateLimit-Reset"
|
||||
)
|
||||
|
||||
type LimiterHandler interface {
|
||||
New(config Config) fiber.Handler
|
||||
}
|
||||
|
||||
// New creates a new middleware handler
|
||||
func New(config ...Config) fiber.Handler {
|
||||
// Set default config
|
||||
cfg := configDefault(config...)
|
||||
|
||||
// Return the specified middleware handler.
|
||||
return cfg.LimiterMiddleware.New(cfg)
|
||||
}
|
106
middleware/limiter/limiter_fixed.go
Normal file
106
middleware/limiter/limiter_fixed.go
Normal file
|
@ -0,0 +1,106 @@
|
|||
package limiter
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/gofiber/fiber/v2/utils"
|
||||
)
|
||||
|
||||
type FixedWindow struct{}
|
||||
|
||||
// New creates a new fixed window middleware handler
|
||||
func (FixedWindow) New(cfg Config) fiber.Handler {
|
||||
var (
|
||||
// Limiter variables
|
||||
mux = &sync.RWMutex{}
|
||||
max = strconv.Itoa(cfg.Max)
|
||||
expiration = uint64(cfg.Expiration.Seconds())
|
||||
)
|
||||
|
||||
// Create manager to simplify storage operations ( see manager.go )
|
||||
manager := newManager(cfg.Storage)
|
||||
|
||||
// Update timestamp every second
|
||||
utils.StartTimeStampUpdater()
|
||||
|
||||
// Return new handler
|
||||
return func(c *fiber.Ctx) error {
|
||||
// Don't execute middleware if Next returns true
|
||||
if cfg.Next != nil && cfg.Next(c) {
|
||||
return c.Next()
|
||||
}
|
||||
|
||||
// Get key from request
|
||||
key := cfg.KeyGenerator(c)
|
||||
|
||||
// Lock entry
|
||||
mux.Lock()
|
||||
|
||||
// Get entry from pool and release when finished
|
||||
e := manager.get(key)
|
||||
|
||||
// Get timestamp
|
||||
ts := uint64(atomic.LoadUint32(&utils.Timestamp))
|
||||
|
||||
// Set expiration if entry does not exist
|
||||
if e.exp == 0 {
|
||||
e.exp = ts + expiration
|
||||
} else if ts >= e.exp {
|
||||
// Check if entry is expired
|
||||
e.currHits = 0
|
||||
e.exp = ts + expiration
|
||||
}
|
||||
|
||||
// Increment hits
|
||||
e.currHits++
|
||||
|
||||
// Calculate when it resets in seconds
|
||||
resetInSec := e.exp - ts
|
||||
|
||||
// Set how many hits we have left
|
||||
remaining := cfg.Max - e.currHits
|
||||
|
||||
// Update storage
|
||||
manager.set(key, e, cfg.Expiration)
|
||||
|
||||
// Unlock entry
|
||||
mux.Unlock()
|
||||
|
||||
// Check if hits exceed the cfg.Max
|
||||
if remaining < 0 {
|
||||
// Return response with Retry-After header
|
||||
// https://tools.ietf.org/html/rfc6584
|
||||
c.Set(fiber.HeaderRetryAfter, strconv.FormatUint(resetInSec, 10))
|
||||
|
||||
// Call LimitReached handler
|
||||
return cfg.LimitReached(c)
|
||||
}
|
||||
|
||||
// Continue stack for reaching c.Response().StatusCode()
|
||||
// Store err for returning
|
||||
err := c.Next()
|
||||
|
||||
// Check for SkipFailedRequests and SkipSuccessfulRequests
|
||||
if (cfg.SkipSuccessfulRequests && c.Response().StatusCode() < fiber.StatusBadRequest) ||
|
||||
(cfg.SkipFailedRequests && c.Response().StatusCode() >= fiber.StatusBadRequest) {
|
||||
// Lock entry
|
||||
mux.Lock()
|
||||
e = manager.get(key)
|
||||
e.currHits--
|
||||
remaining++
|
||||
manager.set(key, e, cfg.Expiration)
|
||||
// Unlock entry
|
||||
mux.Unlock()
|
||||
}
|
||||
|
||||
// We can continue, update RateLimit headers
|
||||
c.Set(xRateLimitLimit, max)
|
||||
c.Set(xRateLimitRemaining, strconv.Itoa(remaining))
|
||||
c.Set(xRateLimitReset, strconv.FormatUint(resetInSec, 10))
|
||||
|
||||
return err
|
||||
}
|
||||
}
|
137
middleware/limiter/limiter_sliding.go
Normal file
137
middleware/limiter/limiter_sliding.go
Normal file
|
@ -0,0 +1,137 @@
|
|||
package limiter
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/gofiber/fiber/v2/utils"
|
||||
)
|
||||
|
||||
type SlidingWindow struct{}
|
||||
|
||||
// New creates a new sliding window middleware handler
|
||||
func (SlidingWindow) New(cfg Config) fiber.Handler {
|
||||
var (
|
||||
// Limiter variables
|
||||
mux = &sync.RWMutex{}
|
||||
max = strconv.Itoa(cfg.Max)
|
||||
expiration = uint64(cfg.Expiration.Seconds())
|
||||
)
|
||||
|
||||
// Create manager to simplify storage operations ( see manager.go )
|
||||
manager := newManager(cfg.Storage)
|
||||
|
||||
// Update timestamp every second
|
||||
utils.StartTimeStampUpdater()
|
||||
|
||||
// Return new handler
|
||||
return func(c *fiber.Ctx) error {
|
||||
// Don't execute middleware if Next returns true
|
||||
if cfg.Next != nil && cfg.Next(c) {
|
||||
return c.Next()
|
||||
}
|
||||
|
||||
// Get key from request
|
||||
key := cfg.KeyGenerator(c)
|
||||
|
||||
// Lock entry
|
||||
mux.Lock()
|
||||
|
||||
// Get entry from pool and release when finished
|
||||
e := manager.get(key)
|
||||
|
||||
// Get timestamp
|
||||
ts := uint64(atomic.LoadUint32(&utils.Timestamp))
|
||||
|
||||
// Set expiration if entry does not exist
|
||||
if e.exp == 0 {
|
||||
e.exp = ts + expiration
|
||||
} else if ts >= e.exp {
|
||||
// The entry has expired, handle the expiration.
|
||||
// Set the prevHits to the current hits and reset the hits to 0.
|
||||
e.prevHits = e.currHits
|
||||
|
||||
// Reset the current hits to 0.
|
||||
e.currHits = 0
|
||||
|
||||
// Check how much into the current window it currently is and sets the
|
||||
// expiry based on that, otherwise this would only reset on
|
||||
// the next request and not show the correct expiry.
|
||||
elapsed := ts - e.exp
|
||||
if elapsed >= expiration {
|
||||
e.exp = ts + expiration
|
||||
} else {
|
||||
e.exp = ts + expiration - elapsed
|
||||
}
|
||||
}
|
||||
|
||||
// Increment hits
|
||||
e.currHits++
|
||||
|
||||
// Calculate when it resets in seconds
|
||||
resetInSec := e.exp - ts
|
||||
|
||||
// weight = time until current window reset / total window length
|
||||
weight := float64(resetInSec) / float64(expiration)
|
||||
|
||||
// rate = request count in previous window - weight + request count in current window
|
||||
rate := int(float64(e.prevHits)*weight) + e.currHits
|
||||
|
||||
// Calculate how many hits can be made based on the current rate
|
||||
remaining := cfg.Max - rate
|
||||
|
||||
// Update storage. Garbage collect when the next window ends.
|
||||
// |--------------------------|--------------------------|
|
||||
// ^ ^ ^ ^
|
||||
// ts e.exp End sample window End next window
|
||||
// <------------>
|
||||
// resetInSec
|
||||
// resetInSec = e.exp - ts - time until end of current window.
|
||||
// duration + expiration = end of next window.
|
||||
// Because we don't want to garbage collect in the middle of a window
|
||||
// we add the expiration to the duration.
|
||||
// Otherwise after the end of "sample window", attackers could launch
|
||||
// a new request with the full window length.
|
||||
manager.set(key, e, time.Duration(resetInSec+expiration)*time.Second)
|
||||
|
||||
// Unlock entry
|
||||
mux.Unlock()
|
||||
|
||||
// Check if hits exceed the cfg.Max
|
||||
if remaining < 0 {
|
||||
// Return response with Retry-After header
|
||||
// https://tools.ietf.org/html/rfc6584
|
||||
c.Set(fiber.HeaderRetryAfter, strconv.FormatUint(resetInSec, 10))
|
||||
|
||||
// Call LimitReached handler
|
||||
return cfg.LimitReached(c)
|
||||
}
|
||||
|
||||
// Continue stack for reaching c.Response().StatusCode()
|
||||
// Store err for returning
|
||||
err := c.Next()
|
||||
|
||||
// Check for SkipFailedRequests and SkipSuccessfulRequests
|
||||
if (cfg.SkipSuccessfulRequests && c.Response().StatusCode() < fiber.StatusBadRequest) ||
|
||||
(cfg.SkipFailedRequests && c.Response().StatusCode() >= fiber.StatusBadRequest) {
|
||||
// Lock entry
|
||||
mux.Lock()
|
||||
e = manager.get(key)
|
||||
e.currHits--
|
||||
remaining++
|
||||
manager.set(key, e, cfg.Expiration)
|
||||
// Unlock entry
|
||||
mux.Unlock()
|
||||
}
|
||||
|
||||
// We can continue, update RateLimit headers
|
||||
c.Set(xRateLimitLimit, max)
|
||||
c.Set(xRateLimitRemaining, strconv.Itoa(remaining))
|
||||
c.Set(xRateLimitReset, strconv.FormatUint(resetInSec, 10))
|
||||
|
||||
return err
|
||||
}
|
||||
}
|
727
middleware/limiter/limiter_test.go
Normal file
727
middleware/limiter/limiter_test.go
Normal file
|
@ -0,0 +1,727 @@
|
|||
package limiter
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net/http/httptest"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/gofiber/fiber/v2/internal/storage/memory"
|
||||
"github.com/gofiber/fiber/v2/utils"
|
||||
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
// go test -run Test_Limiter_Concurrency_Store -race -v
|
||||
func Test_Limiter_Concurrency_Store(t *testing.T) {
|
||||
t.Parallel()
|
||||
// Test concurrency using a custom store
|
||||
|
||||
app := fiber.New()
|
||||
|
||||
app.Use(New(Config{
|
||||
Max: 50,
|
||||
Expiration: 2 * time.Second,
|
||||
Storage: memory.New(),
|
||||
}))
|
||||
|
||||
app.Get("/", func(c *fiber.Ctx) error {
|
||||
return c.SendString("Hello tester!")
|
||||
})
|
||||
|
||||
var wg sync.WaitGroup
|
||||
singleRequest := func(wg *sync.WaitGroup) {
|
||||
defer wg.Done()
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, "Hello tester!", string(body))
|
||||
}
|
||||
|
||||
for i := 0; i <= 49; i++ {
|
||||
wg.Add(1)
|
||||
go singleRequest(&wg)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, 429, resp.StatusCode)
|
||||
|
||||
time.Sleep(3 * time.Second)
|
||||
|
||||
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, 200, resp.StatusCode)
|
||||
}
|
||||
|
||||
// go test -run Test_Limiter_Concurrency -race -v
|
||||
func Test_Limiter_Concurrency(t *testing.T) {
|
||||
t.Parallel()
|
||||
// Test concurrency using a default store
|
||||
|
||||
app := fiber.New()
|
||||
|
||||
app.Use(New(Config{
|
||||
Max: 50,
|
||||
Expiration: 2 * time.Second,
|
||||
}))
|
||||
|
||||
app.Get("/", func(c *fiber.Ctx) error {
|
||||
return c.SendString("Hello tester!")
|
||||
})
|
||||
|
||||
var wg sync.WaitGroup
|
||||
singleRequest := func(wg *sync.WaitGroup) {
|
||||
defer wg.Done()
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, "Hello tester!", string(body))
|
||||
}
|
||||
|
||||
for i := 0; i <= 49; i++ {
|
||||
wg.Add(1)
|
||||
go singleRequest(&wg)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, 429, resp.StatusCode)
|
||||
|
||||
time.Sleep(3 * time.Second)
|
||||
|
||||
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, 200, resp.StatusCode)
|
||||
}
|
||||
|
||||
// go test -run Test_Limiter_Fixed_Window_No_Skip_Choices -v
|
||||
func Test_Limiter_Fixed_Window_No_Skip_Choices(t *testing.T) {
|
||||
t.Parallel()
|
||||
app := fiber.New()
|
||||
|
||||
app.Use(New(Config{
|
||||
Max: 2,
|
||||
Expiration: 2 * time.Second,
|
||||
SkipFailedRequests: false,
|
||||
SkipSuccessfulRequests: false,
|
||||
LimiterMiddleware: FixedWindow{},
|
||||
}))
|
||||
|
||||
app.Get("/:status", func(c *fiber.Ctx) error {
|
||||
if c.Params("status") == "fail" { //nolint:goconst // False positive
|
||||
return c.SendStatus(400)
|
||||
}
|
||||
return c.SendStatus(200)
|
||||
})
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/fail", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, 400, resp.StatusCode)
|
||||
|
||||
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/success", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, 200, resp.StatusCode)
|
||||
|
||||
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/success", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, 429, resp.StatusCode)
|
||||
|
||||
time.Sleep(3 * time.Second)
|
||||
|
||||
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/success", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, 200, resp.StatusCode)
|
||||
}
|
||||
|
||||
// go test -run Test_Limiter_Fixed_Window_Custom_Storage_No_Skip_Choices -v
|
||||
func Test_Limiter_Fixed_Window_Custom_Storage_No_Skip_Choices(t *testing.T) {
|
||||
t.Parallel()
|
||||
app := fiber.New()
|
||||
|
||||
app.Use(New(Config{
|
||||
Max: 2,
|
||||
Expiration: 2 * time.Second,
|
||||
SkipFailedRequests: false,
|
||||
SkipSuccessfulRequests: false,
|
||||
Storage: memory.New(),
|
||||
LimiterMiddleware: FixedWindow{},
|
||||
}))
|
||||
|
||||
app.Get("/:status", func(c *fiber.Ctx) error {
|
||||
if c.Params("status") == "fail" {
|
||||
return c.SendStatus(400)
|
||||
}
|
||||
return c.SendStatus(200)
|
||||
})
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/fail", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, 400, resp.StatusCode)
|
||||
|
||||
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/success", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, 200, resp.StatusCode)
|
||||
|
||||
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/success", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, 429, resp.StatusCode)
|
||||
|
||||
time.Sleep(3 * time.Second)
|
||||
|
||||
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/success", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, 200, resp.StatusCode)
|
||||
}
|
||||
|
||||
// go test -run Test_Limiter_Sliding_Window_No_Skip_Choices -v
|
||||
func Test_Limiter_Sliding_Window_No_Skip_Choices(t *testing.T) {
|
||||
t.Parallel()
|
||||
app := fiber.New()
|
||||
|
||||
app.Use(New(Config{
|
||||
Max: 2,
|
||||
Expiration: 2 * time.Second,
|
||||
SkipFailedRequests: false,
|
||||
SkipSuccessfulRequests: false,
|
||||
LimiterMiddleware: SlidingWindow{},
|
||||
}))
|
||||
|
||||
app.Get("/:status", func(c *fiber.Ctx) error {
|
||||
if c.Params("status") == "fail" {
|
||||
return c.SendStatus(400)
|
||||
}
|
||||
return c.SendStatus(200)
|
||||
})
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/fail", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, 400, resp.StatusCode)
|
||||
|
||||
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/success", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, 200, resp.StatusCode)
|
||||
|
||||
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/success", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, 429, resp.StatusCode)
|
||||
|
||||
time.Sleep(4*time.Second + 500*time.Millisecond)
|
||||
|
||||
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/success", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, 200, resp.StatusCode)
|
||||
}
|
||||
|
||||
// go test -run Test_Limiter_Sliding_Window_Custom_Storage_No_Skip_Choices -v
|
||||
func Test_Limiter_Sliding_Window_Custom_Storage_No_Skip_Choices(t *testing.T) {
|
||||
t.Parallel()
|
||||
app := fiber.New()
|
||||
|
||||
app.Use(New(Config{
|
||||
Max: 2,
|
||||
Expiration: 2 * time.Second,
|
||||
SkipFailedRequests: false,
|
||||
SkipSuccessfulRequests: false,
|
||||
Storage: memory.New(),
|
||||
LimiterMiddleware: SlidingWindow{},
|
||||
}))
|
||||
|
||||
app.Get("/:status", func(c *fiber.Ctx) error {
|
||||
if c.Params("status") == "fail" {
|
||||
return c.SendStatus(400)
|
||||
}
|
||||
return c.SendStatus(200)
|
||||
})
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/fail", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, 400, resp.StatusCode)
|
||||
|
||||
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/success", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, 200, resp.StatusCode)
|
||||
|
||||
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/success", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, 429, resp.StatusCode)
|
||||
|
||||
time.Sleep(4*time.Second + 500*time.Millisecond)
|
||||
|
||||
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/success", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, 200, resp.StatusCode)
|
||||
}
|
||||
|
||||
// go test -run Test_Limiter_Fixed_Window_Skip_Failed_Requests -v
|
||||
func Test_Limiter_Fixed_Window_Skip_Failed_Requests(t *testing.T) {
|
||||
t.Parallel()
|
||||
app := fiber.New()
|
||||
|
||||
app.Use(New(Config{
|
||||
Max: 1,
|
||||
Expiration: 2 * time.Second,
|
||||
SkipFailedRequests: true,
|
||||
LimiterMiddleware: FixedWindow{},
|
||||
}))
|
||||
|
||||
app.Get("/:status", func(c *fiber.Ctx) error {
|
||||
if c.Params("status") == "fail" {
|
||||
return c.SendStatus(400)
|
||||
}
|
||||
return c.SendStatus(200)
|
||||
})
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/fail", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, 400, resp.StatusCode)
|
||||
|
||||
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/success", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, 200, resp.StatusCode)
|
||||
|
||||
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/success", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, 429, resp.StatusCode)
|
||||
|
||||
time.Sleep(3 * time.Second)
|
||||
|
||||
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/success", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, 200, resp.StatusCode)
|
||||
}
|
||||
|
||||
// go test -run Test_Limiter_Fixed_Window_Custom_Storage_Skip_Failed_Requests -v
|
||||
func Test_Limiter_Fixed_Window_Custom_Storage_Skip_Failed_Requests(t *testing.T) {
|
||||
t.Parallel()
|
||||
app := fiber.New()
|
||||
|
||||
app.Use(New(Config{
|
||||
Max: 1,
|
||||
Expiration: 2 * time.Second,
|
||||
Storage: memory.New(),
|
||||
SkipFailedRequests: true,
|
||||
LimiterMiddleware: FixedWindow{},
|
||||
}))
|
||||
|
||||
app.Get("/:status", func(c *fiber.Ctx) error {
|
||||
if c.Params("status") == "fail" {
|
||||
return c.SendStatus(400)
|
||||
}
|
||||
return c.SendStatus(200)
|
||||
})
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/fail", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, 400, resp.StatusCode)
|
||||
|
||||
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/success", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, 200, resp.StatusCode)
|
||||
|
||||
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/success", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, 429, resp.StatusCode)
|
||||
|
||||
time.Sleep(3 * time.Second)
|
||||
|
||||
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/success", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, 200, resp.StatusCode)
|
||||
}
|
||||
|
||||
// go test -run Test_Limiter_Sliding_Window_Skip_Failed_Requests -v
|
||||
func Test_Limiter_Sliding_Window_Skip_Failed_Requests(t *testing.T) {
|
||||
t.Parallel()
|
||||
app := fiber.New()
|
||||
|
||||
app.Use(New(Config{
|
||||
Max: 1,
|
||||
Expiration: 2 * time.Second,
|
||||
SkipFailedRequests: true,
|
||||
LimiterMiddleware: SlidingWindow{},
|
||||
}))
|
||||
|
||||
app.Get("/:status", func(c *fiber.Ctx) error {
|
||||
if c.Params("status") == "fail" {
|
||||
return c.SendStatus(400)
|
||||
}
|
||||
return c.SendStatus(200)
|
||||
})
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/fail", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, 400, resp.StatusCode)
|
||||
|
||||
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/success", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, 200, resp.StatusCode)
|
||||
|
||||
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/success", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, 429, resp.StatusCode)
|
||||
|
||||
time.Sleep(4*time.Second + 500*time.Millisecond)
|
||||
|
||||
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/success", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, 200, resp.StatusCode)
|
||||
}
|
||||
|
||||
// go test -run Test_Limiter_Sliding_Window_Custom_Storage_Skip_Failed_Requests -v
|
||||
func Test_Limiter_Sliding_Window_Custom_Storage_Skip_Failed_Requests(t *testing.T) {
|
||||
t.Parallel()
|
||||
app := fiber.New()
|
||||
|
||||
app.Use(New(Config{
|
||||
Max: 1,
|
||||
Expiration: 2 * time.Second,
|
||||
Storage: memory.New(),
|
||||
SkipFailedRequests: true,
|
||||
LimiterMiddleware: SlidingWindow{},
|
||||
}))
|
||||
|
||||
app.Get("/:status", func(c *fiber.Ctx) error {
|
||||
if c.Params("status") == "fail" {
|
||||
return c.SendStatus(400)
|
||||
}
|
||||
return c.SendStatus(200)
|
||||
})
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/fail", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, 400, resp.StatusCode)
|
||||
|
||||
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/success", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, 200, resp.StatusCode)
|
||||
|
||||
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/success", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, 429, resp.StatusCode)
|
||||
|
||||
time.Sleep(4*time.Second + 500*time.Millisecond)
|
||||
|
||||
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/success", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, 200, resp.StatusCode)
|
||||
}
|
||||
|
||||
// go test -run Test_Limiter_Fixed_Window_Skip_Successful_Requests -v
|
||||
func Test_Limiter_Fixed_Window_Skip_Successful_Requests(t *testing.T) {
|
||||
t.Parallel()
|
||||
// Test concurrency using a default store
|
||||
|
||||
app := fiber.New()
|
||||
|
||||
app.Use(New(Config{
|
||||
Max: 1,
|
||||
Expiration: 2 * time.Second,
|
||||
SkipSuccessfulRequests: true,
|
||||
LimiterMiddleware: FixedWindow{},
|
||||
}))
|
||||
|
||||
app.Get("/:status", func(c *fiber.Ctx) error {
|
||||
if c.Params("status") == "fail" {
|
||||
return c.SendStatus(400)
|
||||
}
|
||||
return c.SendStatus(200)
|
||||
})
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/success", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, 200, resp.StatusCode)
|
||||
|
||||
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/fail", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, 400, resp.StatusCode)
|
||||
|
||||
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/fail", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, 429, resp.StatusCode)
|
||||
|
||||
time.Sleep(3 * time.Second)
|
||||
|
||||
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/fail", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, 400, resp.StatusCode)
|
||||
}
|
||||
|
||||
// go test -run Test_Limiter_Fixed_Window_Custom_Storage_Skip_Successful_Requests -v
|
||||
func Test_Limiter_Fixed_Window_Custom_Storage_Skip_Successful_Requests(t *testing.T) {
|
||||
t.Parallel()
|
||||
// Test concurrency using a default store
|
||||
|
||||
app := fiber.New()
|
||||
|
||||
app.Use(New(Config{
|
||||
Max: 1,
|
||||
Expiration: 2 * time.Second,
|
||||
Storage: memory.New(),
|
||||
SkipSuccessfulRequests: true,
|
||||
LimiterMiddleware: FixedWindow{},
|
||||
}))
|
||||
|
||||
app.Get("/:status", func(c *fiber.Ctx) error {
|
||||
if c.Params("status") == "fail" {
|
||||
return c.SendStatus(400)
|
||||
}
|
||||
return c.SendStatus(200)
|
||||
})
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/success", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, 200, resp.StatusCode)
|
||||
|
||||
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/fail", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, 400, resp.StatusCode)
|
||||
|
||||
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/fail", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, 429, resp.StatusCode)
|
||||
|
||||
time.Sleep(3 * time.Second)
|
||||
|
||||
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/fail", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, 400, resp.StatusCode)
|
||||
}
|
||||
|
||||
// go test -run Test_Limiter_Sliding_Window_Skip_Successful_Requests -v
|
||||
func Test_Limiter_Sliding_Window_Skip_Successful_Requests(t *testing.T) {
|
||||
t.Parallel()
|
||||
// Test concurrency using a default store
|
||||
|
||||
app := fiber.New()
|
||||
|
||||
app.Use(New(Config{
|
||||
Max: 1,
|
||||
Expiration: 2 * time.Second,
|
||||
SkipSuccessfulRequests: true,
|
||||
LimiterMiddleware: SlidingWindow{},
|
||||
}))
|
||||
|
||||
app.Get("/:status", func(c *fiber.Ctx) error {
|
||||
if c.Params("status") == "fail" {
|
||||
return c.SendStatus(400)
|
||||
}
|
||||
return c.SendStatus(200)
|
||||
})
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/success", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, 200, resp.StatusCode)
|
||||
|
||||
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/fail", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, 400, resp.StatusCode)
|
||||
|
||||
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/fail", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, 429, resp.StatusCode)
|
||||
|
||||
time.Sleep(4*time.Second + 500*time.Millisecond)
|
||||
|
||||
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/fail", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, 400, resp.StatusCode)
|
||||
}
|
||||
|
||||
// go test -run Test_Limiter_Sliding_Window_Custom_Storage_Skip_Successful_Requests -v
|
||||
func Test_Limiter_Sliding_Window_Custom_Storage_Skip_Successful_Requests(t *testing.T) {
|
||||
t.Parallel()
|
||||
// Test concurrency using a default store
|
||||
|
||||
app := fiber.New()
|
||||
|
||||
app.Use(New(Config{
|
||||
Max: 1,
|
||||
Expiration: 2 * time.Second,
|
||||
Storage: memory.New(),
|
||||
SkipSuccessfulRequests: true,
|
||||
LimiterMiddleware: SlidingWindow{},
|
||||
}))
|
||||
|
||||
app.Get("/:status", func(c *fiber.Ctx) error {
|
||||
if c.Params("status") == "fail" {
|
||||
return c.SendStatus(400)
|
||||
}
|
||||
return c.SendStatus(200)
|
||||
})
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/success", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, 200, resp.StatusCode)
|
||||
|
||||
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/fail", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, 400, resp.StatusCode)
|
||||
|
||||
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/fail", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, 429, resp.StatusCode)
|
||||
|
||||
time.Sleep(4*time.Second + 500*time.Millisecond)
|
||||
|
||||
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/fail", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, 400, resp.StatusCode)
|
||||
}
|
||||
|
||||
// go test -v -run=^$ -bench=Benchmark_Limiter_Custom_Store -benchmem -count=4
|
||||
func Benchmark_Limiter_Custom_Store(b *testing.B) {
|
||||
app := fiber.New()
|
||||
|
||||
app.Use(New(Config{
|
||||
Max: 100,
|
||||
Expiration: 60 * time.Second,
|
||||
Storage: memory.New(),
|
||||
}))
|
||||
|
||||
app.Get("/", func(c *fiber.Ctx) error {
|
||||
return c.SendString("Hello, World!")
|
||||
})
|
||||
|
||||
h := app.Handler()
|
||||
|
||||
fctx := &fasthttp.RequestCtx{}
|
||||
fctx.Request.Header.SetMethod(fiber.MethodGet)
|
||||
fctx.Request.SetRequestURI("/")
|
||||
|
||||
b.ResetTimer()
|
||||
|
||||
for n := 0; n < b.N; n++ {
|
||||
h(fctx)
|
||||
}
|
||||
}
|
||||
|
||||
// go test -run Test_Limiter_Next
|
||||
func Test_Limiter_Next(t *testing.T) {
|
||||
t.Parallel()
|
||||
app := fiber.New()
|
||||
app.Use(New(Config{
|
||||
Next: func(_ *fiber.Ctx) bool {
|
||||
return true
|
||||
},
|
||||
}))
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, fiber.StatusNotFound, resp.StatusCode)
|
||||
}
|
||||
|
||||
func Test_Limiter_Headers(t *testing.T) {
|
||||
t.Parallel()
|
||||
app := fiber.New()
|
||||
|
||||
app.Use(New(Config{
|
||||
Max: 50,
|
||||
Expiration: 2 * time.Second,
|
||||
}))
|
||||
|
||||
app.Get("/", func(c *fiber.Ctx) error {
|
||||
return c.SendString("Hello tester!")
|
||||
})
|
||||
|
||||
fctx := &fasthttp.RequestCtx{}
|
||||
fctx.Request.Header.SetMethod(fiber.MethodGet)
|
||||
fctx.Request.SetRequestURI("/")
|
||||
|
||||
app.Handler()(fctx)
|
||||
|
||||
utils.AssertEqual(t, "50", string(fctx.Response.Header.Peek("X-RateLimit-Limit")))
|
||||
if v := string(fctx.Response.Header.Peek("X-RateLimit-Remaining")); v == "" {
|
||||
t.Errorf("The X-RateLimit-Remaining header is not set correctly - value is an empty string.")
|
||||
}
|
||||
if v := string(fctx.Response.Header.Peek("X-RateLimit-Reset")); !(v == "1" || v == "2") {
|
||||
t.Errorf("The X-RateLimit-Reset header is not set correctly - value is out of bounds.")
|
||||
}
|
||||
}
|
||||
|
||||
// go test -v -run=^$ -bench=Benchmark_Limiter -benchmem -count=4
|
||||
func Benchmark_Limiter(b *testing.B) {
|
||||
app := fiber.New()
|
||||
|
||||
app.Use(New(Config{
|
||||
Max: 100,
|
||||
Expiration: 60 * time.Second,
|
||||
}))
|
||||
|
||||
app.Get("/", func(c *fiber.Ctx) error {
|
||||
return c.SendString("Hello, World!")
|
||||
})
|
||||
|
||||
h := app.Handler()
|
||||
|
||||
fctx := &fasthttp.RequestCtx{}
|
||||
fctx.Request.Header.SetMethod(fiber.MethodGet)
|
||||
fctx.Request.SetRequestURI("/")
|
||||
|
||||
b.ResetTimer()
|
||||
|
||||
for n := 0; n < b.N; n++ {
|
||||
h(fctx)
|
||||
}
|
||||
}
|
||||
|
||||
// go test -run Test_Sliding_Window -race -v
|
||||
func Test_Sliding_Window(t *testing.T) {
|
||||
t.Parallel()
|
||||
app := fiber.New()
|
||||
app.Use(New(Config{
|
||||
Max: 10,
|
||||
Expiration: 2 * time.Second,
|
||||
Storage: memory.New(),
|
||||
LimiterMiddleware: SlidingWindow{},
|
||||
}))
|
||||
|
||||
app.Get("/", func(c *fiber.Ctx) error {
|
||||
return c.SendString("Hello tester!")
|
||||
})
|
||||
|
||||
singleRequest := func(shouldFail bool) {
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||
if shouldFail {
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, 429, resp.StatusCode)
|
||||
} else {
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
|
||||
}
|
||||
}
|
||||
|
||||
for i := 0; i < 5; i++ {
|
||||
singleRequest(false)
|
||||
}
|
||||
|
||||
time.Sleep(2 * time.Second)
|
||||
|
||||
for i := 0; i < 5; i++ {
|
||||
singleRequest(false)
|
||||
}
|
||||
|
||||
time.Sleep(3 * time.Second)
|
||||
|
||||
for i := 0; i < 5; i++ {
|
||||
singleRequest(false)
|
||||
}
|
||||
|
||||
time.Sleep(4 * time.Second)
|
||||
|
||||
for i := 0; i < 9; i++ {
|
||||
singleRequest(false)
|
||||
}
|
||||
}
|
92
middleware/limiter/manager.go
Normal file
92
middleware/limiter/manager.go
Normal file
|
@ -0,0 +1,92 @@
|
|||
package limiter
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/gofiber/fiber/v2/internal/memory"
|
||||
)
|
||||
|
||||
// go:generate msgp
|
||||
// msgp -file="manager.go" -o="manager_msgp.go" -tests=false -unexported
|
||||
type item struct {
|
||||
currHits int
|
||||
prevHits int
|
||||
exp uint64
|
||||
}
|
||||
|
||||
//msgp:ignore manager
|
||||
type manager struct {
|
||||
pool sync.Pool
|
||||
memory *memory.Storage
|
||||
storage fiber.Storage
|
||||
}
|
||||
|
||||
func newManager(storage fiber.Storage) *manager {
|
||||
// Create new storage handler
|
||||
manager := &manager{
|
||||
pool: sync.Pool{
|
||||
New: func() interface{} {
|
||||
return new(item)
|
||||
},
|
||||
},
|
||||
}
|
||||
if storage != nil {
|
||||
// Use provided storage if provided
|
||||
manager.storage = storage
|
||||
} else {
|
||||
// Fallback too memory storage
|
||||
manager.memory = memory.New()
|
||||
}
|
||||
return manager
|
||||
}
|
||||
|
||||
// acquire returns an *entry from the sync.Pool
|
||||
func (m *manager) acquire() *item {
|
||||
return m.pool.Get().(*item) //nolint:forcetypeassert // We store nothing else in the pool
|
||||
}
|
||||
|
||||
// release and reset *entry to sync.Pool
|
||||
func (m *manager) release(e *item) {
|
||||
e.prevHits = 0
|
||||
e.currHits = 0
|
||||
e.exp = 0
|
||||
m.pool.Put(e)
|
||||
}
|
||||
|
||||
// get data from storage or memory
|
||||
func (m *manager) get(key string) *item {
|
||||
var it *item
|
||||
if m.storage != nil {
|
||||
it = m.acquire()
|
||||
raw, err := m.storage.Get(key)
|
||||
if err != nil {
|
||||
return it
|
||||
}
|
||||
if raw != nil {
|
||||
if _, err := it.UnmarshalMsg(raw); err != nil {
|
||||
return it
|
||||
}
|
||||
}
|
||||
return it
|
||||
}
|
||||
if it, _ = m.memory.Get(key).(*item); it == nil { //nolint:errcheck // We store nothing else in the pool
|
||||
it = m.acquire()
|
||||
return it
|
||||
}
|
||||
return it
|
||||
}
|
||||
|
||||
// set data to storage or memory
|
||||
func (m *manager) set(key string, it *item, exp time.Duration) {
|
||||
if m.storage != nil {
|
||||
if raw, err := it.MarshalMsg(nil); err == nil {
|
||||
_ = m.storage.Set(key, raw, exp) //nolint:errcheck // TODO: Handle error here
|
||||
}
|
||||
// we can release data because it's serialized to database
|
||||
m.release(it)
|
||||
} else {
|
||||
m.memory.Set(key, it, exp)
|
||||
}
|
||||
}
|
160
middleware/limiter/manager_msgp.go
Normal file
160
middleware/limiter/manager_msgp.go
Normal file
|
@ -0,0 +1,160 @@
|
|||
package limiter
|
||||
|
||||
// Code generated by github.com/tinylib/msgp DO NOT EDIT.
|
||||
|
||||
import (
|
||||
"github.com/tinylib/msgp/msgp"
|
||||
)
|
||||
|
||||
// DecodeMsg implements msgp.Decodable
|
||||
func (z *item) DecodeMsg(dc *msgp.Reader) (err error) {
|
||||
var field []byte
|
||||
_ = field
|
||||
var zb0001 uint32
|
||||
zb0001, err = dc.ReadMapHeader()
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err)
|
||||
return
|
||||
}
|
||||
for zb0001 > 0 {
|
||||
zb0001--
|
||||
field, err = dc.ReadMapKeyPtr()
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err)
|
||||
return
|
||||
}
|
||||
switch msgp.UnsafeString(field) {
|
||||
case "currHits":
|
||||
z.currHits, err = dc.ReadInt()
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err, "currHits")
|
||||
return
|
||||
}
|
||||
case "prevHits":
|
||||
z.prevHits, err = dc.ReadInt()
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err, "prevHits")
|
||||
return
|
||||
}
|
||||
case "exp":
|
||||
z.exp, err = dc.ReadUint64()
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err, "exp")
|
||||
return
|
||||
}
|
||||
default:
|
||||
err = dc.Skip()
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// EncodeMsg implements msgp.Encodable
|
||||
func (z item) EncodeMsg(en *msgp.Writer) (err error) {
|
||||
// map header, size 3
|
||||
// write "currHits"
|
||||
err = en.Append(0x83, 0xa8, 0x63, 0x75, 0x72, 0x72, 0x48, 0x69, 0x74, 0x73)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
err = en.WriteInt(z.currHits)
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err, "currHits")
|
||||
return
|
||||
}
|
||||
// write "prevHits"
|
||||
err = en.Append(0xa8, 0x70, 0x72, 0x65, 0x76, 0x48, 0x69, 0x74, 0x73)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
err = en.WriteInt(z.prevHits)
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err, "prevHits")
|
||||
return
|
||||
}
|
||||
// write "exp"
|
||||
err = en.Append(0xa3, 0x65, 0x78, 0x70)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
err = en.WriteUint64(z.exp)
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err, "exp")
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// MarshalMsg implements msgp.Marshaler
|
||||
func (z item) MarshalMsg(b []byte) (o []byte, err error) {
|
||||
o = msgp.Require(b, z.Msgsize())
|
||||
// map header, size 3
|
||||
// string "currHits"
|
||||
o = append(o, 0x83, 0xa8, 0x63, 0x75, 0x72, 0x72, 0x48, 0x69, 0x74, 0x73)
|
||||
o = msgp.AppendInt(o, z.currHits)
|
||||
// string "prevHits"
|
||||
o = append(o, 0xa8, 0x70, 0x72, 0x65, 0x76, 0x48, 0x69, 0x74, 0x73)
|
||||
o = msgp.AppendInt(o, z.prevHits)
|
||||
// string "exp"
|
||||
o = append(o, 0xa3, 0x65, 0x78, 0x70)
|
||||
o = msgp.AppendUint64(o, z.exp)
|
||||
return
|
||||
}
|
||||
|
||||
// UnmarshalMsg implements msgp.Unmarshaler
|
||||
func (z *item) UnmarshalMsg(bts []byte) (o []byte, err error) {
|
||||
var field []byte
|
||||
_ = field
|
||||
var zb0001 uint32
|
||||
zb0001, bts, err = msgp.ReadMapHeaderBytes(bts)
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err)
|
||||
return
|
||||
}
|
||||
for zb0001 > 0 {
|
||||
zb0001--
|
||||
field, bts, err = msgp.ReadMapKeyZC(bts)
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err)
|
||||
return
|
||||
}
|
||||
switch msgp.UnsafeString(field) {
|
||||
case "currHits":
|
||||
z.currHits, bts, err = msgp.ReadIntBytes(bts)
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err, "currHits")
|
||||
return
|
||||
}
|
||||
case "prevHits":
|
||||
z.prevHits, bts, err = msgp.ReadIntBytes(bts)
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err, "prevHits")
|
||||
return
|
||||
}
|
||||
case "exp":
|
||||
z.exp, bts, err = msgp.ReadUint64Bytes(bts)
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err, "exp")
|
||||
return
|
||||
}
|
||||
default:
|
||||
bts, err = msgp.Skip(bts)
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
o = bts
|
||||
return
|
||||
}
|
||||
|
||||
// Msgsize returns an upper bound estimate of the number of bytes occupied by the serialized message
|
||||
func (z item) Msgsize() (s int) {
|
||||
s = 1 + 9 + msgp.IntSize + 9 + msgp.IntSize + 4 + msgp.Uint64Size
|
||||
return
|
||||
}
|
136
middleware/logger/config.go
Normal file
136
middleware/logger/config.go
Normal file
|
@ -0,0 +1,136 @@
|
|||
package logger
|
||||
|
||||
import (
|
||||
"io"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
)
|
||||
|
||||
// Config defines the config for middleware.
|
||||
type Config struct {
|
||||
// Next defines a function to skip this middleware when returned true.
|
||||
//
|
||||
// Optional. Default: nil
|
||||
Next func(c *fiber.Ctx) bool
|
||||
|
||||
// Done is a function that is called after the log string for a request is written to Output,
|
||||
// and pass the log string as parameter.
|
||||
//
|
||||
// Optional. Default: nil
|
||||
Done func(c *fiber.Ctx, logString []byte)
|
||||
|
||||
// tagFunctions defines the custom tag action
|
||||
//
|
||||
// Optional. Default: map[string]LogFunc
|
||||
CustomTags map[string]LogFunc
|
||||
|
||||
// Format defines the logging tags
|
||||
//
|
||||
// Optional. Default: ${time} | ${status} | ${latency} | ${ip} | ${method} | ${path} | ${error}\n
|
||||
Format string
|
||||
|
||||
// TimeFormat https://programming.guide/go/format-parse-string-time-date-example.html
|
||||
//
|
||||
// Optional. Default: 15:04:05
|
||||
TimeFormat string
|
||||
|
||||
// TimeZone can be specified, such as "UTC" and "America/New_York" and "Asia/Chongqing", etc
|
||||
//
|
||||
// Optional. Default: "Local"
|
||||
TimeZone string
|
||||
|
||||
// TimeInterval is the delay before the timestamp is updated
|
||||
//
|
||||
// Optional. Default: 500 * time.Millisecond
|
||||
TimeInterval time.Duration
|
||||
|
||||
// Output is a writer where logs are written
|
||||
//
|
||||
// Default: os.Stdout
|
||||
Output io.Writer
|
||||
|
||||
// DisableColors defines if the logs output should be colorized
|
||||
//
|
||||
// Default: false
|
||||
DisableColors bool
|
||||
|
||||
enableColors bool
|
||||
enableLatency bool
|
||||
timeZoneLocation *time.Location
|
||||
}
|
||||
|
||||
const (
|
||||
startTag = "${"
|
||||
endTag = "}"
|
||||
paramSeparator = ":"
|
||||
)
|
||||
|
||||
type Buffer interface {
|
||||
Len() int
|
||||
ReadFrom(r io.Reader) (int64, error)
|
||||
WriteTo(w io.Writer) (int64, error)
|
||||
Bytes() []byte
|
||||
Write(p []byte) (int, error)
|
||||
WriteByte(c byte) error
|
||||
WriteString(s string) (int, error)
|
||||
Set(p []byte)
|
||||
SetString(s string)
|
||||
String() string
|
||||
}
|
||||
|
||||
type LogFunc func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error)
|
||||
|
||||
// ConfigDefault is the default config
|
||||
var ConfigDefault = Config{
|
||||
Next: nil,
|
||||
Done: nil,
|
||||
Format: "${time} | ${status} | ${latency} | ${ip} | ${method} | ${path} | ${error}\n",
|
||||
TimeFormat: "15:04:05",
|
||||
TimeZone: "Local",
|
||||
TimeInterval: 500 * time.Millisecond,
|
||||
Output: os.Stdout,
|
||||
DisableColors: false,
|
||||
enableColors: true,
|
||||
}
|
||||
|
||||
// Helper function to set default values
|
||||
func configDefault(config ...Config) Config {
|
||||
// Return default config if nothing provided
|
||||
if len(config) < 1 {
|
||||
return ConfigDefault
|
||||
}
|
||||
|
||||
// Override default config
|
||||
cfg := config[0]
|
||||
|
||||
// Set default values
|
||||
if cfg.Next == nil {
|
||||
cfg.Next = ConfigDefault.Next
|
||||
}
|
||||
if cfg.Done == nil {
|
||||
cfg.Done = ConfigDefault.Done
|
||||
}
|
||||
if cfg.Format == "" {
|
||||
cfg.Format = ConfigDefault.Format
|
||||
}
|
||||
if cfg.TimeZone == "" {
|
||||
cfg.TimeZone = ConfigDefault.TimeZone
|
||||
}
|
||||
if cfg.TimeFormat == "" {
|
||||
cfg.TimeFormat = ConfigDefault.TimeFormat
|
||||
}
|
||||
if int(cfg.TimeInterval) <= 0 {
|
||||
cfg.TimeInterval = ConfigDefault.TimeInterval
|
||||
}
|
||||
if cfg.Output == nil {
|
||||
cfg.Output = ConfigDefault.Output
|
||||
}
|
||||
|
||||
if !cfg.DisableColors && cfg.Output == ConfigDefault.Output {
|
||||
cfg.enableColors = true
|
||||
}
|
||||
|
||||
return cfg
|
||||
}
|
16
middleware/logger/data.go
Normal file
16
middleware/logger/data.go
Normal file
|
@ -0,0 +1,16 @@
|
|||
package logger
|
||||
|
||||
import (
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Data is a struct to define some variables to use in custom logger function.
|
||||
type Data struct {
|
||||
Pid string
|
||||
ErrPaddingStr string
|
||||
ChainErr error
|
||||
Start time.Time
|
||||
Stop time.Time
|
||||
Timestamp atomic.Value
|
||||
}
|
182
middleware/logger/logger.go
Normal file
182
middleware/logger/logger.go
Normal file
|
@ -0,0 +1,182 @@
|
|||
package logger
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/gofiber/fiber/v2/utils"
|
||||
|
||||
"github.com/mattn/go-colorable"
|
||||
"github.com/mattn/go-isatty"
|
||||
"github.com/valyala/bytebufferpool"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
// New creates a new middleware handler
|
||||
func New(config ...Config) fiber.Handler {
|
||||
// Set default config
|
||||
cfg := configDefault(config...)
|
||||
|
||||
// Get timezone location
|
||||
tz, err := time.LoadLocation(cfg.TimeZone)
|
||||
if err != nil || tz == nil {
|
||||
cfg.timeZoneLocation = time.Local
|
||||
} else {
|
||||
cfg.timeZoneLocation = tz
|
||||
}
|
||||
|
||||
// Check if format contains latency
|
||||
cfg.enableLatency = strings.Contains(cfg.Format, "${"+TagLatency+"}")
|
||||
|
||||
var timestamp atomic.Value
|
||||
// Create correct timeformat
|
||||
timestamp.Store(time.Now().In(cfg.timeZoneLocation).Format(cfg.TimeFormat))
|
||||
|
||||
// Update date/time every 500 milliseconds in a separate go routine
|
||||
if strings.Contains(cfg.Format, "${"+TagTime+"}") {
|
||||
go func() {
|
||||
for {
|
||||
time.Sleep(cfg.TimeInterval)
|
||||
timestamp.Store(time.Now().In(cfg.timeZoneLocation).Format(cfg.TimeFormat))
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// Set PID once
|
||||
pid := strconv.Itoa(os.Getpid())
|
||||
|
||||
// Set variables
|
||||
var (
|
||||
once sync.Once
|
||||
mu sync.Mutex
|
||||
errHandler fiber.ErrorHandler
|
||||
|
||||
dataPool = sync.Pool{New: func() interface{} { return new(Data) }}
|
||||
)
|
||||
|
||||
// If colors are enabled, check terminal compatibility
|
||||
if cfg.enableColors {
|
||||
cfg.Output = colorable.NewColorableStdout()
|
||||
if os.Getenv("TERM") == "dumb" || os.Getenv("NO_COLOR") == "1" || (!isatty.IsTerminal(os.Stdout.Fd()) && !isatty.IsCygwinTerminal(os.Stdout.Fd())) {
|
||||
cfg.Output = colorable.NewNonColorable(os.Stdout)
|
||||
}
|
||||
}
|
||||
|
||||
errPadding := 15
|
||||
errPaddingStr := strconv.Itoa(errPadding)
|
||||
|
||||
// instead of analyzing the template inside(handler) each time, this is done once before
|
||||
// and we create several slices of the same length with the functions to be executed and fixed parts.
|
||||
templateChain, logFunChain, err := buildLogFuncChain(&cfg, createTagMap(&cfg))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
// Return new handler
|
||||
return func(c *fiber.Ctx) error {
|
||||
// Don't execute middleware if Next returns true
|
||||
if cfg.Next != nil && cfg.Next(c) {
|
||||
return c.Next()
|
||||
}
|
||||
|
||||
// Set error handler once
|
||||
once.Do(func() {
|
||||
// get longested possible path
|
||||
stack := c.App().Stack()
|
||||
for m := range stack {
|
||||
for r := range stack[m] {
|
||||
if len(stack[m][r].Path) > errPadding {
|
||||
errPadding = len(stack[m][r].Path)
|
||||
errPaddingStr = strconv.Itoa(errPadding)
|
||||
}
|
||||
}
|
||||
}
|
||||
// override error handler
|
||||
errHandler = c.App().ErrorHandler
|
||||
})
|
||||
|
||||
// Logger data
|
||||
data := dataPool.Get().(*Data) //nolint:forcetypeassert,errcheck // We store nothing else in the pool
|
||||
// no need for a reset, as long as we always override everything
|
||||
data.Pid = pid
|
||||
data.ErrPaddingStr = errPaddingStr
|
||||
data.Timestamp = timestamp
|
||||
// put data back in the pool
|
||||
defer dataPool.Put(data)
|
||||
|
||||
// Set latency start time
|
||||
if cfg.enableLatency {
|
||||
data.Start = time.Now()
|
||||
}
|
||||
|
||||
// Handle request, store err for logging
|
||||
chainErr := c.Next()
|
||||
|
||||
data.ChainErr = chainErr
|
||||
// Manually call error handler
|
||||
if chainErr != nil {
|
||||
if err := errHandler(c, chainErr); err != nil {
|
||||
_ = c.SendStatus(fiber.StatusInternalServerError) //nolint:errcheck // TODO: Explain why we ignore the error here
|
||||
}
|
||||
}
|
||||
|
||||
// Set latency stop time
|
||||
if cfg.enableLatency {
|
||||
data.Stop = time.Now()
|
||||
}
|
||||
|
||||
// Get new buffer
|
||||
buf := bytebufferpool.Get()
|
||||
|
||||
var err error
|
||||
// Loop over template parts execute dynamic parts and add fixed parts to the buffer
|
||||
for i, logFunc := range logFunChain {
|
||||
if logFunc == nil {
|
||||
_, _ = buf.Write(templateChain[i]) //nolint:errcheck // This will never fail
|
||||
} else if templateChain[i] == nil {
|
||||
_, err = logFunc(buf, c, data, "")
|
||||
} else {
|
||||
_, err = logFunc(buf, c, data, utils.UnsafeString(templateChain[i]))
|
||||
}
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// Also write errors to the buffer
|
||||
if err != nil {
|
||||
_, _ = buf.WriteString(err.Error()) //nolint:errcheck // This will never fail
|
||||
}
|
||||
mu.Lock()
|
||||
// Write buffer to output
|
||||
if _, err := cfg.Output.Write(buf.Bytes()); err != nil {
|
||||
// Write error to output
|
||||
if _, err := cfg.Output.Write([]byte(err.Error())); err != nil {
|
||||
// There is something wrong with the given io.Writer
|
||||
_, _ = fmt.Fprintf(os.Stderr, "Failed to write to log, %v\n", err)
|
||||
}
|
||||
}
|
||||
mu.Unlock()
|
||||
|
||||
if cfg.Done != nil {
|
||||
cfg.Done(c, buf.Bytes())
|
||||
}
|
||||
|
||||
// Put buffer back to pool
|
||||
bytebufferpool.Put(buf)
|
||||
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func appendInt(output Buffer, v int) (int, error) {
|
||||
old := output.Len()
|
||||
output.Set(fasthttp.AppendUint(output.Bytes(), v))
|
||||
return output.Len() - old, nil
|
||||
}
|
650
middleware/logger/logger_test.go
Normal file
650
middleware/logger/logger_test.go
Normal file
|
@ -0,0 +1,650 @@
|
|||
//nolint:bodyclose // Much easier to just ignore memory leaks in tests
|
||||
package logger
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"runtime"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/gofiber/fiber/v2/middleware/requestid"
|
||||
"github.com/gofiber/fiber/v2/utils"
|
||||
|
||||
"github.com/valyala/bytebufferpool"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
// go test -run Test_Logger
|
||||
func Test_Logger(t *testing.T) {
|
||||
t.Parallel()
|
||||
app := fiber.New()
|
||||
|
||||
buf := bytebufferpool.Get()
|
||||
defer bytebufferpool.Put(buf)
|
||||
|
||||
app.Use(New(Config{
|
||||
Format: "${error}",
|
||||
Output: buf,
|
||||
}))
|
||||
|
||||
app.Get("/", func(c *fiber.Ctx) error {
|
||||
return errors.New("some random error")
|
||||
})
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, fiber.StatusInternalServerError, resp.StatusCode)
|
||||
utils.AssertEqual(t, "some random error", buf.String())
|
||||
}
|
||||
|
||||
// go test -run Test_Logger_locals
|
||||
func Test_Logger_locals(t *testing.T) {
|
||||
t.Parallel()
|
||||
app := fiber.New()
|
||||
|
||||
buf := bytebufferpool.Get()
|
||||
defer bytebufferpool.Put(buf)
|
||||
|
||||
app.Use(New(Config{
|
||||
Format: "${locals:demo}",
|
||||
Output: buf,
|
||||
}))
|
||||
|
||||
app.Get("/", func(c *fiber.Ctx) error {
|
||||
c.Locals("demo", "johndoe")
|
||||
return c.SendStatus(fiber.StatusOK)
|
||||
})
|
||||
|
||||
app.Get("/int", func(c *fiber.Ctx) error {
|
||||
c.Locals("demo", 55)
|
||||
return c.SendStatus(fiber.StatusOK)
|
||||
})
|
||||
|
||||
app.Get("/empty", func(c *fiber.Ctx) error {
|
||||
return c.SendStatus(fiber.StatusOK)
|
||||
})
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
|
||||
utils.AssertEqual(t, "johndoe", buf.String())
|
||||
|
||||
buf.Reset()
|
||||
|
||||
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/int", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
|
||||
utils.AssertEqual(t, "55", buf.String())
|
||||
|
||||
buf.Reset()
|
||||
|
||||
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/empty", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
|
||||
utils.AssertEqual(t, "", buf.String())
|
||||
}
|
||||
|
||||
// go test -run Test_Logger_Next
|
||||
func Test_Logger_Next(t *testing.T) {
|
||||
t.Parallel()
|
||||
app := fiber.New()
|
||||
app.Use(New(Config{
|
||||
Next: func(_ *fiber.Ctx) bool {
|
||||
return true
|
||||
},
|
||||
}))
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, fiber.StatusNotFound, resp.StatusCode)
|
||||
}
|
||||
|
||||
// go test -run Test_Logger_Done
|
||||
func Test_Logger_Done(t *testing.T) {
|
||||
t.Parallel()
|
||||
buf := bytes.NewBuffer(nil)
|
||||
app := fiber.New()
|
||||
app.Use(New(Config{
|
||||
Done: func(c *fiber.Ctx, logString []byte) {
|
||||
if c.Response().StatusCode() == fiber.StatusOK {
|
||||
_, err := buf.Write(logString)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
}
|
||||
},
|
||||
})).Get("/logging", func(ctx *fiber.Ctx) error {
|
||||
return ctx.SendStatus(fiber.StatusOK)
|
||||
})
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/logging", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
|
||||
utils.AssertEqual(t, true, buf.Len() > 0)
|
||||
}
|
||||
|
||||
// go test -run Test_Logger_ErrorTimeZone
|
||||
func Test_Logger_ErrorTimeZone(t *testing.T) {
|
||||
t.Parallel()
|
||||
app := fiber.New()
|
||||
app.Use(New(Config{
|
||||
TimeZone: "invalid",
|
||||
}))
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, fiber.StatusNotFound, resp.StatusCode)
|
||||
}
|
||||
|
||||
type fakeOutput int
|
||||
|
||||
func (o *fakeOutput) Write([]byte) (int, error) {
|
||||
*o++
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
// go test -run Test_Logger_ErrorOutput_WithoutColor
|
||||
func Test_Logger_ErrorOutput_WithoutColor(t *testing.T) {
|
||||
t.Parallel()
|
||||
o := new(fakeOutput)
|
||||
app := fiber.New()
|
||||
app.Use(New(Config{
|
||||
Output: o,
|
||||
DisableColors: true,
|
||||
}))
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, fiber.StatusNotFound, resp.StatusCode)
|
||||
|
||||
utils.AssertEqual(t, 1, int(*o))
|
||||
}
|
||||
|
||||
// go test -run Test_Logger_ErrorOutput
|
||||
func Test_Logger_ErrorOutput(t *testing.T) {
|
||||
t.Parallel()
|
||||
o := new(fakeOutput)
|
||||
app := fiber.New()
|
||||
app.Use(New(Config{
|
||||
Output: o,
|
||||
}))
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, fiber.StatusNotFound, resp.StatusCode)
|
||||
|
||||
utils.AssertEqual(t, 1, int(*o))
|
||||
}
|
||||
|
||||
// go test -run Test_Logger_All
|
||||
func Test_Logger_All(t *testing.T) {
|
||||
t.Parallel()
|
||||
buf := bytebufferpool.Get()
|
||||
defer bytebufferpool.Put(buf)
|
||||
|
||||
app := fiber.New()
|
||||
app.Use(New(Config{
|
||||
Format: "${pid}${reqHeaders}${referer}${protocol}${ip}${ips}${host}${url}${ua}${body}${route}${black}${red}${green}${yellow}${blue}${magenta}${cyan}${white}${reset}${error}${header:test}${query:test}${form:test}${cookie:test}${non}",
|
||||
Output: buf,
|
||||
}))
|
||||
|
||||
// Alias colors
|
||||
colors := app.Config().ColorScheme
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/?foo=bar", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, fiber.StatusNotFound, resp.StatusCode)
|
||||
|
||||
expected := fmt.Sprintf("%dHost=example.comhttp0.0.0.0example.com/?foo=bar/%s%s%s%s%s%s%s%s%sCannot GET /", os.Getpid(), colors.Black, colors.Red, colors.Green, colors.Yellow, colors.Blue, colors.Magenta, colors.Cyan, colors.White, colors.Reset)
|
||||
utils.AssertEqual(t, expected, buf.String())
|
||||
}
|
||||
|
||||
func getLatencyTimeUnits() []struct {
|
||||
unit string
|
||||
div time.Duration
|
||||
} {
|
||||
// windows does not support µs sleep precision
|
||||
// https://github.com/golang/go/issues/29485
|
||||
if runtime.GOOS == "windows" {
|
||||
return []struct {
|
||||
unit string
|
||||
div time.Duration
|
||||
}{
|
||||
{"ms", time.Millisecond},
|
||||
{"s", time.Second},
|
||||
}
|
||||
}
|
||||
return []struct {
|
||||
unit string
|
||||
div time.Duration
|
||||
}{
|
||||
{"µs", time.Microsecond},
|
||||
{"ms", time.Millisecond},
|
||||
{"s", time.Second},
|
||||
}
|
||||
}
|
||||
|
||||
// go test -run Test_Logger_WithLatency
|
||||
func Test_Logger_WithLatency(t *testing.T) {
|
||||
t.Parallel()
|
||||
buff := bytebufferpool.Get()
|
||||
defer bytebufferpool.Put(buff)
|
||||
app := fiber.New()
|
||||
logger := New(Config{
|
||||
Output: buff,
|
||||
Format: "${latency}",
|
||||
})
|
||||
app.Use(logger)
|
||||
|
||||
// Define a list of time units to test
|
||||
timeUnits := getLatencyTimeUnits()
|
||||
|
||||
// Initialize a new time unit
|
||||
sleepDuration := 1 * time.Nanosecond
|
||||
|
||||
// Define a test route that sleeps
|
||||
app.Get("/test", func(c *fiber.Ctx) error {
|
||||
time.Sleep(sleepDuration)
|
||||
return c.SendStatus(fiber.StatusOK)
|
||||
})
|
||||
|
||||
// Loop through each time unit and assert that the log output contains the expected latency value
|
||||
for _, tu := range timeUnits {
|
||||
// Update the sleep duration for the next iteration
|
||||
sleepDuration = 1 * tu.div
|
||||
|
||||
// Create a new HTTP request to the test route
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", nil), int(2*time.Second))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
|
||||
|
||||
// Assert that the log output contains the expected latency value in the current time unit
|
||||
utils.AssertEqual(t, bytes.HasSuffix(buff.Bytes(), []byte(tu.unit)), true, fmt.Sprintf("Expected latency to be in %s, got %s", tu.unit, buff.String()))
|
||||
|
||||
// Reset the buffer
|
||||
buff.Reset()
|
||||
}
|
||||
}
|
||||
|
||||
// go test -run Test_Logger_WithLatency_DefaultFormat
|
||||
func Test_Logger_WithLatency_DefaultFormat(t *testing.T) {
|
||||
t.Parallel()
|
||||
buff := bytebufferpool.Get()
|
||||
defer bytebufferpool.Put(buff)
|
||||
app := fiber.New()
|
||||
logger := New(Config{
|
||||
Output: buff,
|
||||
})
|
||||
app.Use(logger)
|
||||
|
||||
// Define a list of time units to test
|
||||
timeUnits := getLatencyTimeUnits()
|
||||
|
||||
// Initialize a new time unit
|
||||
sleepDuration := 1 * time.Nanosecond
|
||||
|
||||
// Define a test route that sleeps
|
||||
app.Get("/test", func(c *fiber.Ctx) error {
|
||||
time.Sleep(sleepDuration)
|
||||
return c.SendStatus(fiber.StatusOK)
|
||||
})
|
||||
|
||||
// Loop through each time unit and assert that the log output contains the expected latency value
|
||||
for _, tu := range timeUnits {
|
||||
// Update the sleep duration for the next iteration
|
||||
sleepDuration = 1 * tu.div
|
||||
|
||||
// Create a new HTTP request to the test route
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", nil), int(2*time.Second))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
|
||||
|
||||
// Assert that the log output contains the expected latency value in the current time unit
|
||||
// parse out the latency value from the log output
|
||||
latency := bytes.Split(buff.Bytes(), []byte(" | "))[2]
|
||||
// Assert that the latency value is in the current time unit
|
||||
utils.AssertEqual(t, bytes.HasSuffix(latency, []byte(tu.unit)), true, fmt.Sprintf("Expected latency to be in %s, got %s", tu.unit, latency))
|
||||
|
||||
// Reset the buffer
|
||||
buff.Reset()
|
||||
}
|
||||
}
|
||||
|
||||
// go test -run Test_Query_Params
|
||||
func Test_Query_Params(t *testing.T) {
|
||||
t.Parallel()
|
||||
buf := bytebufferpool.Get()
|
||||
defer bytebufferpool.Put(buf)
|
||||
|
||||
app := fiber.New()
|
||||
app.Use(New(Config{
|
||||
Format: "${queryParams}",
|
||||
Output: buf,
|
||||
}))
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/?foo=bar&baz=moz", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, fiber.StatusNotFound, resp.StatusCode)
|
||||
|
||||
expected := "foo=bar&baz=moz"
|
||||
utils.AssertEqual(t, expected, buf.String())
|
||||
}
|
||||
|
||||
// go test -run Test_Response_Body
|
||||
func Test_Response_Body(t *testing.T) {
|
||||
t.Parallel()
|
||||
buf := bytebufferpool.Get()
|
||||
defer bytebufferpool.Put(buf)
|
||||
|
||||
app := fiber.New()
|
||||
app.Use(New(Config{
|
||||
Format: "${resBody}",
|
||||
Output: buf,
|
||||
}))
|
||||
|
||||
app.Get("/", func(c *fiber.Ctx) error {
|
||||
return c.SendString("Sample response body")
|
||||
})
|
||||
|
||||
app.Post("/test", func(c *fiber.Ctx) error {
|
||||
return c.Send([]byte("Post in test"))
|
||||
})
|
||||
|
||||
_, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
|
||||
expectedGetResponse := "Sample response body"
|
||||
utils.AssertEqual(t, expectedGetResponse, buf.String())
|
||||
|
||||
buf.Reset() // Reset buffer to test POST
|
||||
|
||||
_, err = app.Test(httptest.NewRequest(fiber.MethodPost, "/test", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
|
||||
expectedPostResponse := "Post in test"
|
||||
utils.AssertEqual(t, expectedPostResponse, buf.String())
|
||||
}
|
||||
|
||||
// go test -run Test_Logger_AppendUint
|
||||
func Test_Logger_AppendUint(t *testing.T) {
|
||||
t.Parallel()
|
||||
app := fiber.New()
|
||||
|
||||
buf := bytebufferpool.Get()
|
||||
defer bytebufferpool.Put(buf)
|
||||
|
||||
app.Use(New(Config{
|
||||
Format: "${bytesReceived} ${bytesSent} ${status}",
|
||||
Output: buf,
|
||||
}))
|
||||
|
||||
app.Get("/", func(c *fiber.Ctx) error {
|
||||
c.Response().Header.SetContentLength(5)
|
||||
return c.SendString("hello")
|
||||
})
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
|
||||
utils.AssertEqual(t, "-2 5 200", buf.String())
|
||||
}
|
||||
|
||||
// go test -run Test_Logger_Data_Race -race
|
||||
func Test_Logger_Data_Race(t *testing.T) {
|
||||
t.Parallel()
|
||||
app := fiber.New()
|
||||
|
||||
buf := bytebufferpool.Get()
|
||||
defer bytebufferpool.Put(buf)
|
||||
|
||||
app.Use(New(ConfigDefault))
|
||||
app.Use(New(Config{
|
||||
Format: "${time} | ${pid} | ${locals:requestid} | ${status} | ${latency} | ${method} | ${path}\n",
|
||||
}))
|
||||
|
||||
app.Get("/", func(c *fiber.Ctx) error {
|
||||
return c.SendString("hello")
|
||||
})
|
||||
|
||||
var (
|
||||
resp1, resp2 *http.Response
|
||||
err1, err2 error
|
||||
)
|
||||
wg := &sync.WaitGroup{}
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
resp1, err1 = app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||
wg.Done()
|
||||
}()
|
||||
resp2, err2 = app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||
wg.Wait()
|
||||
utils.AssertEqual(t, nil, err1)
|
||||
utils.AssertEqual(t, fiber.StatusOK, resp1.StatusCode)
|
||||
utils.AssertEqual(t, nil, err2)
|
||||
utils.AssertEqual(t, fiber.StatusOK, resp2.StatusCode)
|
||||
}
|
||||
|
||||
// go test -v -run=^$ -bench=Benchmark_Logger -benchmem -count=4
|
||||
func Benchmark_Logger(b *testing.B) {
|
||||
benchSetup := func(b *testing.B, app *fiber.App) {
|
||||
b.Helper()
|
||||
|
||||
h := app.Handler()
|
||||
|
||||
fctx := &fasthttp.RequestCtx{}
|
||||
fctx.Request.Header.SetMethod(fiber.MethodGet)
|
||||
fctx.Request.SetRequestURI("/")
|
||||
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
|
||||
for n := 0; n < b.N; n++ {
|
||||
h(fctx)
|
||||
}
|
||||
|
||||
utils.AssertEqual(b, 200, fctx.Response.Header.StatusCode())
|
||||
}
|
||||
|
||||
b.Run("Base", func(bb *testing.B) {
|
||||
app := fiber.New()
|
||||
app.Use(New(Config{
|
||||
Format: "${bytesReceived} ${bytesSent} ${status}",
|
||||
Output: io.Discard,
|
||||
}))
|
||||
app.Get("/", func(c *fiber.Ctx) error {
|
||||
c.Set("test", "test")
|
||||
return c.SendString("Hello, World!")
|
||||
})
|
||||
benchSetup(bb, app)
|
||||
})
|
||||
|
||||
b.Run("DefaultFormat", func(bb *testing.B) {
|
||||
app := fiber.New()
|
||||
app.Use(New(Config{
|
||||
Output: io.Discard,
|
||||
}))
|
||||
app.Get("/", func(c *fiber.Ctx) error {
|
||||
return c.SendString("Hello, World!")
|
||||
})
|
||||
benchSetup(bb, app)
|
||||
})
|
||||
|
||||
b.Run("WithTagParameter", func(bb *testing.B) {
|
||||
app := fiber.New()
|
||||
app.Use(New(Config{
|
||||
Format: "${bytesReceived} ${bytesSent} ${status} ${reqHeader:test}",
|
||||
Output: io.Discard,
|
||||
}))
|
||||
app.Get("/", func(c *fiber.Ctx) error {
|
||||
c.Set("test", "test")
|
||||
return c.SendString("Hello, World!")
|
||||
})
|
||||
benchSetup(bb, app)
|
||||
})
|
||||
}
|
||||
|
||||
// go test -run Test_Response_Header
|
||||
func Test_Response_Header(t *testing.T) {
|
||||
t.Parallel()
|
||||
buf := bytebufferpool.Get()
|
||||
defer bytebufferpool.Put(buf)
|
||||
|
||||
app := fiber.New()
|
||||
app.Use(requestid.New(requestid.Config{
|
||||
Next: nil,
|
||||
Header: fiber.HeaderXRequestID,
|
||||
Generator: func() string { return "Hello fiber!" },
|
||||
ContextKey: "requestid",
|
||||
}))
|
||||
app.Use(New(Config{
|
||||
Format: "${respHeader:X-Request-ID}",
|
||||
Output: buf,
|
||||
}))
|
||||
app.Get("/", func(c *fiber.Ctx) error {
|
||||
return c.SendString("Hello fiber!")
|
||||
})
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
|
||||
utils.AssertEqual(t, "Hello fiber!", buf.String())
|
||||
}
|
||||
|
||||
// go test -run Test_Req_Header
|
||||
func Test_Req_Header(t *testing.T) {
|
||||
t.Parallel()
|
||||
buf := bytebufferpool.Get()
|
||||
defer bytebufferpool.Put(buf)
|
||||
|
||||
app := fiber.New()
|
||||
app.Use(New(Config{
|
||||
Format: "${header:test}",
|
||||
Output: buf,
|
||||
}))
|
||||
app.Get("/", func(c *fiber.Ctx) error {
|
||||
return c.SendString("Hello fiber!")
|
||||
})
|
||||
headerReq := httptest.NewRequest(fiber.MethodGet, "/", nil)
|
||||
headerReq.Header.Add("test", "Hello fiber!")
|
||||
|
||||
resp, err := app.Test(headerReq)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
|
||||
utils.AssertEqual(t, "Hello fiber!", buf.String())
|
||||
}
|
||||
|
||||
// go test -run Test_ReqHeader_Header
|
||||
func Test_ReqHeader_Header(t *testing.T) {
|
||||
t.Parallel()
|
||||
buf := bytebufferpool.Get()
|
||||
defer bytebufferpool.Put(buf)
|
||||
|
||||
app := fiber.New()
|
||||
app.Use(New(Config{
|
||||
Format: "${reqHeader:test}",
|
||||
Output: buf,
|
||||
}))
|
||||
app.Get("/", func(c *fiber.Ctx) error {
|
||||
return c.SendString("Hello fiber!")
|
||||
})
|
||||
reqHeaderReq := httptest.NewRequest(fiber.MethodGet, "/", nil)
|
||||
reqHeaderReq.Header.Add("test", "Hello fiber!")
|
||||
|
||||
resp, err := app.Test(reqHeaderReq)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
|
||||
utils.AssertEqual(t, "Hello fiber!", buf.String())
|
||||
}
|
||||
|
||||
// go test -run Test_CustomTags
|
||||
func Test_CustomTags(t *testing.T) {
|
||||
t.Parallel()
|
||||
customTag := "it is a custom tag"
|
||||
|
||||
buf := bytebufferpool.Get()
|
||||
defer bytebufferpool.Put(buf)
|
||||
|
||||
app := fiber.New()
|
||||
app.Use(New(Config{
|
||||
Format: "${custom_tag}",
|
||||
CustomTags: map[string]LogFunc{
|
||||
"custom_tag": func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
|
||||
return output.WriteString(customTag)
|
||||
},
|
||||
},
|
||||
Output: buf,
|
||||
}))
|
||||
app.Get("/", func(c *fiber.Ctx) error {
|
||||
return c.SendString("Hello fiber!")
|
||||
})
|
||||
reqHeaderReq := httptest.NewRequest(fiber.MethodGet, "/", nil)
|
||||
reqHeaderReq.Header.Add("test", "Hello fiber!")
|
||||
|
||||
resp, err := app.Test(reqHeaderReq)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
|
||||
utils.AssertEqual(t, customTag, buf.String())
|
||||
}
|
||||
|
||||
// go test -run Test_Logger_ByteSent_Streaming
|
||||
func Test_Logger_ByteSent_Streaming(t *testing.T) {
|
||||
t.Parallel()
|
||||
app := fiber.New()
|
||||
|
||||
buf := bytebufferpool.Get()
|
||||
defer bytebufferpool.Put(buf)
|
||||
|
||||
app.Use(New(Config{
|
||||
Format: "${bytesReceived} ${bytesSent} ${status}",
|
||||
Output: buf,
|
||||
}))
|
||||
|
||||
app.Get("/", func(c *fiber.Ctx) error {
|
||||
c.Set("Connection", "keep-alive")
|
||||
c.Set("Transfer-Encoding", "chunked")
|
||||
c.Context().SetBodyStreamWriter(func(w *bufio.Writer) {
|
||||
var i int
|
||||
for {
|
||||
i++
|
||||
msg := fmt.Sprintf("%d - the time is %v", i, time.Now())
|
||||
fmt.Fprintf(w, "data: Message: %s\n\n", msg)
|
||||
err := w.Flush()
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
if i == 10 {
|
||||
break
|
||||
}
|
||||
}
|
||||
})
|
||||
return nil
|
||||
})
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
|
||||
utils.AssertEqual(t, "-2 -1 200", buf.String())
|
||||
}
|
||||
|
||||
// go test -run Test_Logger_EnableColors
|
||||
func Test_Logger_EnableColors(t *testing.T) {
|
||||
t.Parallel()
|
||||
o := new(fakeOutput)
|
||||
app := fiber.New()
|
||||
app.Use(New(Config{
|
||||
Output: o,
|
||||
}))
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, fiber.StatusNotFound, resp.StatusCode)
|
||||
|
||||
utils.AssertEqual(t, 1, int(*o))
|
||||
}
|
209
middleware/logger/tags.go
Normal file
209
middleware/logger/tags.go
Normal file
|
@ -0,0 +1,209 @@
|
|||
package logger
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
)
|
||||
|
||||
// Logger variables
|
||||
const (
|
||||
TagPid = "pid"
|
||||
TagTime = "time"
|
||||
TagReferer = "referer"
|
||||
TagProtocol = "protocol"
|
||||
TagPort = "port"
|
||||
TagIP = "ip"
|
||||
TagIPs = "ips"
|
||||
TagHost = "host"
|
||||
TagMethod = "method"
|
||||
TagPath = "path"
|
||||
TagURL = "url"
|
||||
TagUA = "ua"
|
||||
TagLatency = "latency"
|
||||
TagStatus = "status"
|
||||
TagResBody = "resBody"
|
||||
TagReqHeaders = "reqHeaders"
|
||||
TagQueryStringParams = "queryParams"
|
||||
TagBody = "body"
|
||||
TagBytesSent = "bytesSent"
|
||||
TagBytesReceived = "bytesReceived"
|
||||
TagRoute = "route"
|
||||
TagError = "error"
|
||||
// Deprecated: Use TagReqHeader instead
|
||||
TagHeader = "header:"
|
||||
TagReqHeader = "reqHeader:"
|
||||
TagRespHeader = "respHeader:"
|
||||
TagLocals = "locals:"
|
||||
TagQuery = "query:"
|
||||
TagForm = "form:"
|
||||
TagCookie = "cookie:"
|
||||
TagBlack = "black"
|
||||
TagRed = "red"
|
||||
TagGreen = "green"
|
||||
TagYellow = "yellow"
|
||||
TagBlue = "blue"
|
||||
TagMagenta = "magenta"
|
||||
TagCyan = "cyan"
|
||||
TagWhite = "white"
|
||||
TagReset = "reset"
|
||||
)
|
||||
|
||||
// createTagMap function merged the default with the custom tags
|
||||
func createTagMap(cfg *Config) map[string]LogFunc {
|
||||
// Set default tags
|
||||
tagFunctions := map[string]LogFunc{
|
||||
TagReferer: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
|
||||
return output.WriteString(c.Get(fiber.HeaderReferer))
|
||||
},
|
||||
TagProtocol: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
|
||||
return output.WriteString(c.Protocol())
|
||||
},
|
||||
TagPort: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
|
||||
return output.WriteString(c.Port())
|
||||
},
|
||||
TagIP: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
|
||||
return output.WriteString(c.IP())
|
||||
},
|
||||
TagIPs: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
|
||||
return output.WriteString(c.Get(fiber.HeaderXForwardedFor))
|
||||
},
|
||||
TagHost: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
|
||||
return output.WriteString(c.Hostname())
|
||||
},
|
||||
TagPath: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
|
||||
return output.WriteString(c.Path())
|
||||
},
|
||||
TagURL: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
|
||||
return output.WriteString(c.OriginalURL())
|
||||
},
|
||||
TagUA: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
|
||||
return output.WriteString(c.Get(fiber.HeaderUserAgent))
|
||||
},
|
||||
TagBody: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
|
||||
return output.Write(c.Body())
|
||||
},
|
||||
TagBytesReceived: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
|
||||
return output.WriteString(strconv.Itoa((c.Request().Header.ContentLength())))
|
||||
},
|
||||
TagBytesSent: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
|
||||
return output.WriteString(strconv.Itoa((c.Response().Header.ContentLength())))
|
||||
},
|
||||
TagRoute: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
|
||||
return output.WriteString(c.Route().Path)
|
||||
},
|
||||
TagResBody: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
|
||||
return output.Write(c.Response().Body())
|
||||
},
|
||||
TagReqHeaders: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
|
||||
reqHeaders := make([]string, 0)
|
||||
for k, v := range c.GetReqHeaders() {
|
||||
reqHeaders = append(reqHeaders, k+"="+strings.Join(v, ","))
|
||||
}
|
||||
return output.Write([]byte(strings.Join(reqHeaders, "&")))
|
||||
},
|
||||
TagQueryStringParams: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
|
||||
return output.WriteString(c.Request().URI().QueryArgs().String())
|
||||
},
|
||||
|
||||
TagBlack: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
|
||||
return output.WriteString(c.App().Config().ColorScheme.Black)
|
||||
},
|
||||
TagRed: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
|
||||
return output.WriteString(c.App().Config().ColorScheme.Red)
|
||||
},
|
||||
TagGreen: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
|
||||
return output.WriteString(c.App().Config().ColorScheme.Green)
|
||||
},
|
||||
TagYellow: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
|
||||
return output.WriteString(c.App().Config().ColorScheme.Yellow)
|
||||
},
|
||||
TagBlue: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
|
||||
return output.WriteString(c.App().Config().ColorScheme.Blue)
|
||||
},
|
||||
TagMagenta: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
|
||||
return output.WriteString(c.App().Config().ColorScheme.Magenta)
|
||||
},
|
||||
TagCyan: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
|
||||
return output.WriteString(c.App().Config().ColorScheme.Cyan)
|
||||
},
|
||||
TagWhite: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
|
||||
return output.WriteString(c.App().Config().ColorScheme.White)
|
||||
},
|
||||
TagReset: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
|
||||
return output.WriteString(c.App().Config().ColorScheme.Reset)
|
||||
},
|
||||
TagError: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
|
||||
if data.ChainErr != nil {
|
||||
if cfg.enableColors {
|
||||
colors := c.App().Config().ColorScheme
|
||||
return output.WriteString(fmt.Sprintf("%s%s%s", colors.Red, data.ChainErr.Error(), colors.Reset))
|
||||
}
|
||||
return output.WriteString(data.ChainErr.Error())
|
||||
}
|
||||
return output.WriteString("-")
|
||||
},
|
||||
TagReqHeader: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
|
||||
return output.WriteString(c.Get(extraParam))
|
||||
},
|
||||
TagHeader: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
|
||||
return output.WriteString(c.Get(extraParam))
|
||||
},
|
||||
TagRespHeader: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
|
||||
return output.WriteString(c.GetRespHeader(extraParam))
|
||||
},
|
||||
TagQuery: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
|
||||
return output.WriteString(c.Query(extraParam))
|
||||
},
|
||||
TagForm: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
|
||||
return output.WriteString(c.FormValue(extraParam))
|
||||
},
|
||||
TagCookie: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
|
||||
return output.WriteString(c.Cookies(extraParam))
|
||||
},
|
||||
TagLocals: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
|
||||
switch v := c.Locals(extraParam).(type) {
|
||||
case []byte:
|
||||
return output.Write(v)
|
||||
case string:
|
||||
return output.WriteString(v)
|
||||
case nil:
|
||||
return 0, nil
|
||||
default:
|
||||
return output.WriteString(fmt.Sprintf("%v", v))
|
||||
}
|
||||
},
|
||||
TagStatus: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
|
||||
if cfg.enableColors {
|
||||
colors := c.App().Config().ColorScheme
|
||||
return output.WriteString(fmt.Sprintf("%s%3d%s", statusColor(c.Response().StatusCode(), colors), c.Response().StatusCode(), colors.Reset))
|
||||
}
|
||||
return appendInt(output, c.Response().StatusCode())
|
||||
},
|
||||
TagMethod: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
|
||||
if cfg.enableColors {
|
||||
colors := c.App().Config().ColorScheme
|
||||
return output.WriteString(fmt.Sprintf("%s%s%s", methodColor(c.Method(), colors), c.Method(), colors.Reset))
|
||||
}
|
||||
return output.WriteString(c.Method())
|
||||
},
|
||||
TagPid: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
|
||||
return output.WriteString(data.Pid)
|
||||
},
|
||||
TagLatency: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
|
||||
latency := data.Stop.Sub(data.Start)
|
||||
return output.WriteString(fmt.Sprintf("%13v", latency))
|
||||
},
|
||||
TagTime: func(output Buffer, c *fiber.Ctx, data *Data, extraParam string) (int, error) {
|
||||
return output.WriteString(data.Timestamp.Load().(string)) //nolint:forcetypeassert // We always store a string in here
|
||||
},
|
||||
}
|
||||
// merge with custom tags from user
|
||||
for k, v := range cfg.CustomTags {
|
||||
tagFunctions[k] = v
|
||||
}
|
||||
|
||||
return tagFunctions
|
||||
}
|
70
middleware/logger/template_chain.go
Normal file
70
middleware/logger/template_chain.go
Normal file
|
@ -0,0 +1,70 @@
|
|||
package logger
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
|
||||
"github.com/gofiber/fiber/v2/utils"
|
||||
)
|
||||
|
||||
// buildLogFuncChain analyzes the template and creates slices with the functions for execution and
|
||||
// slices with the fixed parts of the template and the parameters
|
||||
//
|
||||
// fixParts contains the fixed parts of the template or parameters if a function is stored in the funcChain at this position
|
||||
// funcChain contains for the parts which exist the functions for the dynamic parts
|
||||
// funcChain and fixParts always have the same length and contain nil for the parts where no data is required in the chain,
|
||||
// if a function exists for the part, a parameter for it can also exist in the fixParts slice
|
||||
func buildLogFuncChain(cfg *Config, tagFunctions map[string]LogFunc) ([][]byte, []LogFunc, error) {
|
||||
// process flow is copied from the fasttemplate flow https://github.com/valyala/fasttemplate/blob/2a2d1afadadf9715bfa19683cdaeac8347e5d9f9/template.go#L23-L62
|
||||
templateB := utils.UnsafeBytes(cfg.Format)
|
||||
startTagB := utils.UnsafeBytes(startTag)
|
||||
endTagB := utils.UnsafeBytes(endTag)
|
||||
paramSeparatorB := utils.UnsafeBytes(paramSeparator)
|
||||
|
||||
var fixParts [][]byte
|
||||
var funcChain []LogFunc
|
||||
|
||||
for {
|
||||
currentPos := bytes.Index(templateB, startTagB)
|
||||
if currentPos < 0 {
|
||||
// no starting tag found in the existing template part
|
||||
break
|
||||
}
|
||||
// add fixed part
|
||||
funcChain = append(funcChain, nil)
|
||||
fixParts = append(fixParts, templateB[:currentPos])
|
||||
|
||||
templateB = templateB[currentPos+len(startTagB):]
|
||||
currentPos = bytes.Index(templateB, endTagB)
|
||||
if currentPos < 0 {
|
||||
// cannot find end tag - just write it to the output.
|
||||
funcChain = append(funcChain, nil)
|
||||
fixParts = append(fixParts, startTagB)
|
||||
break
|
||||
}
|
||||
// ## function block ##
|
||||
// first check for tags with parameters
|
||||
if index := bytes.Index(templateB[:currentPos], paramSeparatorB); index != -1 {
|
||||
logFunc, ok := tagFunctions[utils.UnsafeString(templateB[:index+1])]
|
||||
if !ok {
|
||||
return nil, nil, errors.New("No parameter found in \"" + utils.UnsafeString(templateB[:currentPos]) + "\"")
|
||||
}
|
||||
funcChain = append(funcChain, logFunc)
|
||||
// add param to the fixParts
|
||||
fixParts = append(fixParts, templateB[index+1:currentPos])
|
||||
} else if logFunc, ok := tagFunctions[utils.UnsafeString(templateB[:currentPos])]; ok {
|
||||
// add functions without parameter
|
||||
funcChain = append(funcChain, logFunc)
|
||||
fixParts = append(fixParts, nil)
|
||||
}
|
||||
// ## function block end ##
|
||||
|
||||
// reduce the template string
|
||||
templateB = templateB[currentPos+len(endTagB):]
|
||||
}
|
||||
// set the rest
|
||||
funcChain = append(funcChain, nil)
|
||||
fixParts = append(fixParts, templateB)
|
||||
|
||||
return fixParts, funcChain, nil
|
||||
}
|
39
middleware/logger/utils.go
Normal file
39
middleware/logger/utils.go
Normal file
|
@ -0,0 +1,39 @@
|
|||
package logger
|
||||
|
||||
import (
|
||||
"github.com/gofiber/fiber/v2"
|
||||
)
|
||||
|
||||
func methodColor(method string, colors fiber.Colors) string {
|
||||
switch method {
|
||||
case fiber.MethodGet:
|
||||
return colors.Cyan
|
||||
case fiber.MethodPost:
|
||||
return colors.Green
|
||||
case fiber.MethodPut:
|
||||
return colors.Yellow
|
||||
case fiber.MethodDelete:
|
||||
return colors.Red
|
||||
case fiber.MethodPatch:
|
||||
return colors.White
|
||||
case fiber.MethodHead:
|
||||
return colors.Magenta
|
||||
case fiber.MethodOptions:
|
||||
return colors.Blue
|
||||
default:
|
||||
return colors.Reset
|
||||
}
|
||||
}
|
||||
|
||||
func statusColor(code int, colors fiber.Colors) string {
|
||||
switch {
|
||||
case code >= fiber.StatusOK && code < fiber.StatusMultipleChoices:
|
||||
return colors.Green
|
||||
case code >= fiber.StatusMultipleChoices && code < fiber.StatusBadRequest:
|
||||
return colors.Blue
|
||||
case code >= fiber.StatusBadRequest && code < fiber.StatusInternalServerError:
|
||||
return colors.Yellow
|
||||
default:
|
||||
return colors.Red
|
||||
}
|
||||
}
|
132
middleware/monitor/config.go
Normal file
132
middleware/monitor/config.go
Normal file
|
@ -0,0 +1,132 @@
|
|||
package monitor
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
)
|
||||
|
||||
// Config defines the config for middleware.
|
||||
type Config struct {
|
||||
// Metrics page title
|
||||
//
|
||||
// Optional. Default: "Fiber Monitor"
|
||||
Title string
|
||||
|
||||
// Refresh period
|
||||
//
|
||||
// Optional. Default: 3 seconds
|
||||
Refresh time.Duration
|
||||
|
||||
// Whether the service should expose only the monitoring API.
|
||||
//
|
||||
// Optional. Default: false
|
||||
APIOnly bool
|
||||
|
||||
// Next defines a function to skip this middleware when returned true.
|
||||
//
|
||||
// Optional. Default: nil
|
||||
Next func(c *fiber.Ctx) bool
|
||||
|
||||
// Custom HTML Code to Head Section(Before End)
|
||||
//
|
||||
// Optional. Default: empty
|
||||
CustomHead string
|
||||
|
||||
// FontURL for specify font resource path or URL . also you can use relative path
|
||||
//
|
||||
// Optional. Default: https://fonts.googleapis.com/css2?family=Roboto:wght@400;900&display=swap
|
||||
FontURL string
|
||||
|
||||
// ChartJsURL for specify ChartJS library path or URL . also you can use relative path
|
||||
//
|
||||
// Optional. Default: https://cdn.jsdelivr.net/npm/chart.js@2.9/dist/Chart.bundle.min.js
|
||||
ChartJsURL string // TODO: Rename to "ChartJSURL" in v3
|
||||
|
||||
index string
|
||||
}
|
||||
|
||||
var ConfigDefault = Config{
|
||||
Title: defaultTitle,
|
||||
Refresh: defaultRefresh,
|
||||
FontURL: defaultFontURL,
|
||||
ChartJsURL: defaultChartJSURL,
|
||||
CustomHead: defaultCustomHead,
|
||||
APIOnly: false,
|
||||
Next: nil,
|
||||
index: newIndex(viewBag{
|
||||
defaultTitle,
|
||||
defaultRefresh,
|
||||
defaultFontURL,
|
||||
defaultChartJSURL,
|
||||
defaultCustomHead,
|
||||
}),
|
||||
}
|
||||
|
||||
func configDefault(config ...Config) Config {
|
||||
// Users can change ConfigDefault.Title/Refresh which then
|
||||
// become incompatible with ConfigDefault.index
|
||||
if ConfigDefault.Title != defaultTitle ||
|
||||
ConfigDefault.Refresh != defaultRefresh ||
|
||||
ConfigDefault.FontURL != defaultFontURL ||
|
||||
ConfigDefault.ChartJsURL != defaultChartJSURL ||
|
||||
ConfigDefault.CustomHead != defaultCustomHead {
|
||||
if ConfigDefault.Refresh < minRefresh {
|
||||
ConfigDefault.Refresh = minRefresh
|
||||
}
|
||||
// update default index with new default title/refresh
|
||||
ConfigDefault.index = newIndex(viewBag{
|
||||
ConfigDefault.Title,
|
||||
ConfigDefault.Refresh,
|
||||
ConfigDefault.FontURL,
|
||||
ConfigDefault.ChartJsURL,
|
||||
ConfigDefault.CustomHead,
|
||||
})
|
||||
}
|
||||
|
||||
// Return default config if nothing provided
|
||||
if len(config) < 1 {
|
||||
return ConfigDefault
|
||||
}
|
||||
|
||||
// Override default config
|
||||
cfg := config[0]
|
||||
|
||||
// Set default values
|
||||
if cfg.Title == "" {
|
||||
cfg.Title = ConfigDefault.Title
|
||||
}
|
||||
|
||||
if cfg.Refresh == 0 {
|
||||
cfg.Refresh = ConfigDefault.Refresh
|
||||
}
|
||||
if cfg.FontURL == "" {
|
||||
cfg.FontURL = defaultFontURL
|
||||
}
|
||||
|
||||
if cfg.ChartJsURL == "" {
|
||||
cfg.ChartJsURL = defaultChartJSURL
|
||||
}
|
||||
if cfg.Refresh < minRefresh {
|
||||
cfg.Refresh = minRefresh
|
||||
}
|
||||
|
||||
if cfg.Next == nil {
|
||||
cfg.Next = ConfigDefault.Next
|
||||
}
|
||||
|
||||
if !cfg.APIOnly {
|
||||
cfg.APIOnly = ConfigDefault.APIOnly
|
||||
}
|
||||
|
||||
// update cfg.index with custom title/refresh
|
||||
cfg.index = newIndex(viewBag{
|
||||
title: cfg.Title,
|
||||
refresh: cfg.Refresh,
|
||||
fontURL: cfg.FontURL,
|
||||
chartJSURL: cfg.ChartJsURL,
|
||||
customHead: cfg.CustomHead,
|
||||
})
|
||||
|
||||
return cfg
|
||||
}
|
163
middleware/monitor/config_test.go
Normal file
163
middleware/monitor/config_test.go
Normal file
|
@ -0,0 +1,163 @@
|
|||
package monitor
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/gofiber/fiber/v2/utils"
|
||||
)
|
||||
|
||||
func Test_Config_Default(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("use default", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
cfg := configDefault()
|
||||
|
||||
utils.AssertEqual(t, defaultTitle, cfg.Title)
|
||||
utils.AssertEqual(t, defaultRefresh, cfg.Refresh)
|
||||
utils.AssertEqual(t, defaultFontURL, cfg.FontURL)
|
||||
utils.AssertEqual(t, defaultChartJSURL, cfg.ChartJsURL)
|
||||
utils.AssertEqual(t, defaultCustomHead, cfg.CustomHead)
|
||||
utils.AssertEqual(t, false, cfg.APIOnly)
|
||||
utils.AssertEqual(t, (func(*fiber.Ctx) bool)(nil), cfg.Next)
|
||||
utils.AssertEqual(t, newIndex(viewBag{defaultTitle, defaultRefresh, defaultFontURL, defaultChartJSURL, defaultCustomHead}), cfg.index)
|
||||
})
|
||||
|
||||
t.Run("set title", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
title := "title"
|
||||
cfg := configDefault(Config{
|
||||
Title: title,
|
||||
})
|
||||
|
||||
utils.AssertEqual(t, title, cfg.Title)
|
||||
utils.AssertEqual(t, defaultRefresh, cfg.Refresh)
|
||||
utils.AssertEqual(t, defaultFontURL, cfg.FontURL)
|
||||
utils.AssertEqual(t, defaultChartJSURL, cfg.ChartJsURL)
|
||||
utils.AssertEqual(t, defaultCustomHead, cfg.CustomHead)
|
||||
utils.AssertEqual(t, false, cfg.APIOnly)
|
||||
utils.AssertEqual(t, (func(*fiber.Ctx) bool)(nil), cfg.Next)
|
||||
utils.AssertEqual(t, newIndex(viewBag{title, defaultRefresh, defaultFontURL, defaultChartJSURL, defaultCustomHead}), cfg.index)
|
||||
})
|
||||
|
||||
t.Run("set refresh less than default", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
cfg := configDefault(Config{
|
||||
Refresh: 100 * time.Millisecond,
|
||||
})
|
||||
|
||||
utils.AssertEqual(t, defaultTitle, cfg.Title)
|
||||
utils.AssertEqual(t, minRefresh, cfg.Refresh)
|
||||
utils.AssertEqual(t, defaultFontURL, cfg.FontURL)
|
||||
utils.AssertEqual(t, defaultChartJSURL, cfg.ChartJsURL)
|
||||
utils.AssertEqual(t, defaultCustomHead, cfg.CustomHead)
|
||||
utils.AssertEqual(t, false, cfg.APIOnly)
|
||||
utils.AssertEqual(t, (func(*fiber.Ctx) bool)(nil), cfg.Next)
|
||||
utils.AssertEqual(t, newIndex(viewBag{defaultTitle, minRefresh, defaultFontURL, defaultChartJSURL, defaultCustomHead}), cfg.index)
|
||||
})
|
||||
|
||||
t.Run("set refresh", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
refresh := time.Second
|
||||
cfg := configDefault(Config{
|
||||
Refresh: refresh,
|
||||
})
|
||||
|
||||
utils.AssertEqual(t, defaultTitle, cfg.Title)
|
||||
utils.AssertEqual(t, refresh, cfg.Refresh)
|
||||
utils.AssertEqual(t, defaultFontURL, cfg.FontURL)
|
||||
utils.AssertEqual(t, defaultChartJSURL, cfg.ChartJsURL)
|
||||
utils.AssertEqual(t, defaultCustomHead, cfg.CustomHead)
|
||||
utils.AssertEqual(t, false, cfg.APIOnly)
|
||||
utils.AssertEqual(t, (func(*fiber.Ctx) bool)(nil), cfg.Next)
|
||||
utils.AssertEqual(t, newIndex(viewBag{defaultTitle, refresh, defaultFontURL, defaultChartJSURL, defaultCustomHead}), cfg.index)
|
||||
})
|
||||
|
||||
t.Run("set font url", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
fontURL := "https://example.com"
|
||||
cfg := configDefault(Config{
|
||||
FontURL: fontURL,
|
||||
})
|
||||
|
||||
utils.AssertEqual(t, defaultTitle, cfg.Title)
|
||||
utils.AssertEqual(t, defaultRefresh, cfg.Refresh)
|
||||
utils.AssertEqual(t, fontURL, cfg.FontURL)
|
||||
utils.AssertEqual(t, defaultChartJSURL, cfg.ChartJsURL)
|
||||
utils.AssertEqual(t, defaultCustomHead, cfg.CustomHead)
|
||||
utils.AssertEqual(t, false, cfg.APIOnly)
|
||||
utils.AssertEqual(t, (func(*fiber.Ctx) bool)(nil), cfg.Next)
|
||||
utils.AssertEqual(t, newIndex(viewBag{defaultTitle, defaultRefresh, fontURL, defaultChartJSURL, defaultCustomHead}), cfg.index)
|
||||
})
|
||||
|
||||
t.Run("set chart js url", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
chartURL := "http://example.com"
|
||||
cfg := configDefault(Config{
|
||||
ChartJsURL: chartURL,
|
||||
})
|
||||
|
||||
utils.AssertEqual(t, defaultTitle, cfg.Title)
|
||||
utils.AssertEqual(t, defaultRefresh, cfg.Refresh)
|
||||
utils.AssertEqual(t, defaultFontURL, cfg.FontURL)
|
||||
utils.AssertEqual(t, chartURL, cfg.ChartJsURL)
|
||||
utils.AssertEqual(t, defaultCustomHead, cfg.CustomHead)
|
||||
utils.AssertEqual(t, false, cfg.APIOnly)
|
||||
utils.AssertEqual(t, (func(*fiber.Ctx) bool)(nil), cfg.Next)
|
||||
utils.AssertEqual(t, newIndex(viewBag{defaultTitle, defaultRefresh, defaultFontURL, chartURL, defaultCustomHead}), cfg.index)
|
||||
})
|
||||
|
||||
t.Run("set custom head", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
head := "head"
|
||||
cfg := configDefault(Config{
|
||||
CustomHead: head,
|
||||
})
|
||||
|
||||
utils.AssertEqual(t, defaultTitle, cfg.Title)
|
||||
utils.AssertEqual(t, defaultRefresh, cfg.Refresh)
|
||||
utils.AssertEqual(t, defaultFontURL, cfg.FontURL)
|
||||
utils.AssertEqual(t, defaultChartJSURL, cfg.ChartJsURL)
|
||||
utils.AssertEqual(t, head, cfg.CustomHead)
|
||||
utils.AssertEqual(t, false, cfg.APIOnly)
|
||||
utils.AssertEqual(t, (func(*fiber.Ctx) bool)(nil), cfg.Next)
|
||||
utils.AssertEqual(t, newIndex(viewBag{defaultTitle, defaultRefresh, defaultFontURL, defaultChartJSURL, head}), cfg.index)
|
||||
})
|
||||
|
||||
t.Run("set api only", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
cfg := configDefault(Config{
|
||||
APIOnly: true,
|
||||
})
|
||||
|
||||
utils.AssertEqual(t, defaultTitle, cfg.Title)
|
||||
utils.AssertEqual(t, defaultRefresh, cfg.Refresh)
|
||||
utils.AssertEqual(t, defaultFontURL, cfg.FontURL)
|
||||
utils.AssertEqual(t, defaultChartJSURL, cfg.ChartJsURL)
|
||||
utils.AssertEqual(t, defaultCustomHead, cfg.CustomHead)
|
||||
utils.AssertEqual(t, true, cfg.APIOnly)
|
||||
utils.AssertEqual(t, (func(*fiber.Ctx) bool)(nil), cfg.Next)
|
||||
utils.AssertEqual(t, newIndex(viewBag{defaultTitle, defaultRefresh, defaultFontURL, defaultChartJSURL, defaultCustomHead}), cfg.index)
|
||||
})
|
||||
|
||||
t.Run("set next", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
f := func(c *fiber.Ctx) bool {
|
||||
return true
|
||||
}
|
||||
cfg := configDefault(Config{
|
||||
Next: f,
|
||||
})
|
||||
|
||||
utils.AssertEqual(t, defaultTitle, cfg.Title)
|
||||
utils.AssertEqual(t, defaultRefresh, cfg.Refresh)
|
||||
utils.AssertEqual(t, defaultFontURL, cfg.FontURL)
|
||||
utils.AssertEqual(t, defaultChartJSURL, cfg.ChartJsURL)
|
||||
utils.AssertEqual(t, defaultCustomHead, cfg.CustomHead)
|
||||
utils.AssertEqual(t, false, cfg.APIOnly)
|
||||
utils.AssertEqual(t, f(nil), cfg.Next(nil))
|
||||
utils.AssertEqual(t, newIndex(viewBag{defaultTitle, defaultRefresh, defaultFontURL, defaultChartJSURL, defaultCustomHead}), cfg.index)
|
||||
})
|
||||
}
|
271
middleware/monitor/index.go
Normal file
271
middleware/monitor/index.go
Normal file
|
@ -0,0 +1,271 @@
|
|||
package monitor
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
type viewBag struct {
|
||||
title string
|
||||
refresh time.Duration
|
||||
fontURL string
|
||||
chartJSURL string
|
||||
customHead string
|
||||
}
|
||||
|
||||
// returns index with new title/refresh
|
||||
func newIndex(dat viewBag) string {
|
||||
timeout := dat.refresh.Milliseconds() - timeoutDiff
|
||||
if timeout < timeoutDiff {
|
||||
timeout = timeoutDiff
|
||||
}
|
||||
ts := strconv.FormatInt(timeout, 10)
|
||||
replacer := strings.NewReplacer("$TITLE", dat.title, "$TIMEOUT", ts,
|
||||
"$FONT_URL", dat.fontURL, "$CHART_JS_URL", dat.chartJSURL, "$CUSTOM_HEAD", dat.customHead,
|
||||
)
|
||||
return replacer.Replace(indexHTML)
|
||||
}
|
||||
|
||||
const (
|
||||
defaultTitle = "Fiber Monitor"
|
||||
|
||||
defaultRefresh = 3 * time.Second
|
||||
timeoutDiff = 200 // timeout will be Refresh (in milliseconds) - timeoutDiff
|
||||
minRefresh = timeoutDiff * time.Millisecond
|
||||
defaultFontURL = `https://fonts.googleapis.com/css2?family=Roboto:wght@400;900&display=swap`
|
||||
defaultChartJSURL = `https://cdn.jsdelivr.net/npm/chart.js@2.9/dist/Chart.bundle.min.js`
|
||||
defaultCustomHead = ``
|
||||
|
||||
// parametrized by $TITLE and $TIMEOUT
|
||||
indexHTML = `<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<link href="$FONT_URL" rel="stylesheet">
|
||||
<script src="$CHART_JS_URL"></script>
|
||||
|
||||
<title>$TITLE</title>
|
||||
<style>
|
||||
body {
|
||||
margin: 0;
|
||||
font: 16px / 1.6 'Roboto', sans-serif;
|
||||
}
|
||||
.wrapper {
|
||||
max-width: 900px;
|
||||
margin: 0 auto;
|
||||
padding: 30px 0;
|
||||
}
|
||||
.title {
|
||||
text-align: center;
|
||||
margin-bottom: 2em;
|
||||
}
|
||||
.title h1 {
|
||||
font-size: 1.8em;
|
||||
padding: 0;
|
||||
margin: 0;
|
||||
}
|
||||
.row {
|
||||
display: flex;
|
||||
margin-bottom: 20px;
|
||||
align-items: center;
|
||||
}
|
||||
.row .column:first-child { width: 35%; }
|
||||
.row .column:last-child { width: 65%; }
|
||||
.metric {
|
||||
color: #777;
|
||||
font-weight: 900;
|
||||
}
|
||||
h2 {
|
||||
padding: 0;
|
||||
margin: 0;
|
||||
font-size: 2.2em;
|
||||
}
|
||||
h2 span {
|
||||
font-size: 12px;
|
||||
color: #777;
|
||||
}
|
||||
h2 span.ram_os { color: rgba(255, 150, 0, .8); }
|
||||
h2 span.ram_total { color: rgba(0, 200, 0, .8); }
|
||||
canvas {
|
||||
width: 200px;
|
||||
height: 180px;
|
||||
}
|
||||
$CUSTOM_HEAD
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<section class="wrapper">
|
||||
<div class="title"><h1>$TITLE</h1></div>
|
||||
<section class="charts">
|
||||
<div class="row">
|
||||
<div class="column">
|
||||
<div class="metric">CPU Usage</div>
|
||||
<h2 id="cpuMetric">0.00%</h2>
|
||||
</div>
|
||||
<div class="column">
|
||||
<canvas id="cpuChart"></canvas>
|
||||
</div>
|
||||
</div>
|
||||
<div class="row">
|
||||
<div class="column">
|
||||
<div class="metric">Memory Usage</div>
|
||||
<h2 id="ramMetric" title="PID used / OS used / OS total">0.00 MB</h2>
|
||||
</div>
|
||||
<div class="column">
|
||||
<canvas id="ramChart"></canvas>
|
||||
</div>
|
||||
</div>
|
||||
<div class="row">
|
||||
<div class="column">
|
||||
<div class="metric">Response Time</div>
|
||||
<h2 id="rtimeMetric">0ms</h2>
|
||||
</div>
|
||||
<div class="column">
|
||||
<canvas id="rtimeChart"></canvas>
|
||||
</div>
|
||||
</div>
|
||||
<div class="row">
|
||||
<div class="column">
|
||||
<div class="metric">Open Connections</div>
|
||||
<h2 id="connsMetric">0</h2>
|
||||
</div>
|
||||
<div class="column">
|
||||
<canvas id="connsChart"></canvas>
|
||||
</div>
|
||||
</div>
|
||||
</section>
|
||||
</section>
|
||||
<script>
|
||||
function formatBytes(bytes, decimals = 1) {
|
||||
if (bytes === 0) return '0 Bytes';
|
||||
|
||||
const k = 1024;
|
||||
const dm = decimals < 0 ? 0 : decimals;
|
||||
const sizes = ['Bytes', 'KB', 'MB', 'GB', 'TB', 'PB', 'EB', 'ZB', 'YB'];
|
||||
|
||||
const i = Math.floor(Math.log(bytes) / Math.log(k));
|
||||
|
||||
return parseFloat((bytes / Math.pow(k, i)).toFixed(dm)) + ' ' + sizes[i];
|
||||
}
|
||||
Chart.defaults.global.legend.display = false;
|
||||
Chart.defaults.global.defaultFontSize = 8;
|
||||
Chart.defaults.global.animation.duration = 1000;
|
||||
Chart.defaults.global.animation.easing = 'easeOutQuart';
|
||||
Chart.defaults.global.elements.line.backgroundColor = 'rgba(0, 172, 215, 0.25)';
|
||||
Chart.defaults.global.elements.line.borderColor = 'rgba(0, 172, 215, 1)';
|
||||
Chart.defaults.global.elements.line.borderWidth = 2;
|
||||
|
||||
const options = {
|
||||
scales: {
|
||||
yAxes: [{ ticks: { beginAtZero: true }}],
|
||||
xAxes: [{
|
||||
type: 'time',
|
||||
time: {
|
||||
unitStepSize: 30,
|
||||
unit: 'second'
|
||||
},
|
||||
gridlines: { display: false }
|
||||
}]
|
||||
},
|
||||
tooltips: { enabled: false },
|
||||
responsive: true,
|
||||
maintainAspectRatio: false,
|
||||
animation: false
|
||||
};
|
||||
const cpuMetric = document.querySelector('#cpuMetric');
|
||||
const ramMetric = document.querySelector('#ramMetric');
|
||||
const rtimeMetric = document.querySelector('#rtimeMetric');
|
||||
const connsMetric = document.querySelector('#connsMetric');
|
||||
|
||||
const cpuChartCtx = document.querySelector('#cpuChart').getContext('2d');
|
||||
const ramChartCtx = document.querySelector('#ramChart').getContext('2d');
|
||||
const rtimeChartCtx = document.querySelector('#rtimeChart').getContext('2d');
|
||||
const connsChartCtx = document.querySelector('#connsChart').getContext('2d');
|
||||
|
||||
const cpuChart = createChart(cpuChartCtx);
|
||||
const ramChart = createChart(ramChartCtx);
|
||||
const rtimeChart = createChart(rtimeChartCtx);
|
||||
const connsChart = createChart(connsChartCtx);
|
||||
|
||||
const charts = [cpuChart, ramChart, rtimeChart, connsChart];
|
||||
|
||||
function createChart(ctx) {
|
||||
return new Chart(ctx, {
|
||||
type: 'line',
|
||||
data: {
|
||||
labels: [],
|
||||
datasets: [{
|
||||
label: '',
|
||||
data: [],
|
||||
lineTension: 0.2,
|
||||
pointRadius: 0,
|
||||
}]
|
||||
},
|
||||
options
|
||||
});
|
||||
}
|
||||
ramChart.data.datasets.push({
|
||||
data: [],
|
||||
lineTension: 0.2,
|
||||
pointRadius: 0,
|
||||
backgroundColor: 'rgba(255, 200, 0, .6)',
|
||||
borderColor: 'rgba(255, 150, 0, .8)',
|
||||
})
|
||||
ramChart.data.datasets.push({
|
||||
data: [],
|
||||
lineTension: 0.2,
|
||||
pointRadius: 0,
|
||||
backgroundColor: 'rgba(0, 255, 0, .4)',
|
||||
borderColor: 'rgba(0, 200, 0, .8)',
|
||||
})
|
||||
function update(json, rtime) {
|
||||
cpu = json.pid.cpu.toFixed(1);
|
||||
cpuOS = json.os.cpu.toFixed(1);
|
||||
|
||||
cpuMetric.innerHTML = cpu + '% <span>' + cpuOS + '%</span>';
|
||||
ramMetric.innerHTML = formatBytes(json.pid.ram) + '<span> / </span><span class="ram_os">' + formatBytes(json.os.ram) +
|
||||
'<span><span> / </span><span class="ram_total">' + formatBytes(json.os.total_ram) + '</span>';
|
||||
rtimeMetric.innerHTML = rtime + 'ms <span>client</span>';
|
||||
connsMetric.innerHTML = json.pid.conns + ' <span>' + json.os.conns + '</span>';
|
||||
|
||||
cpuChart.data.datasets[0].data.push(cpu);
|
||||
ramChart.data.datasets[2].data.push((json.os.total_ram / 1e6).toFixed(2));
|
||||
ramChart.data.datasets[1].data.push((json.os.ram / 1e6).toFixed(2));
|
||||
ramChart.data.datasets[0].data.push((json.pid.ram / 1e6).toFixed(2));
|
||||
rtimeChart.data.datasets[0].data.push(rtime);
|
||||
connsChart.data.datasets[0].data.push(json.pid.conns);
|
||||
|
||||
const timestamp = new Date().getTime();
|
||||
|
||||
charts.forEach(chart => {
|
||||
if (chart.data.labels.length > 50) {
|
||||
chart.data.datasets.forEach(function (dataset) { dataset.data.shift(); });
|
||||
chart.data.labels.shift();
|
||||
}
|
||||
chart.data.labels.push(timestamp);
|
||||
chart.update();
|
||||
});
|
||||
setTimeout(fetchJSON, $TIMEOUT)
|
||||
}
|
||||
function fetchJSON() {
|
||||
var t1 = ''
|
||||
var t0 = performance.now()
|
||||
fetch(window.location.href, {
|
||||
headers: { 'Accept': 'application/json' },
|
||||
credentials: 'same-origin'
|
||||
})
|
||||
.then(res => {
|
||||
t1 = performance.now()
|
||||
return res.json()
|
||||
})
|
||||
.then(res => { update(res, Math.round(t1 - t0)) })
|
||||
.catch(console.error);
|
||||
}
|
||||
fetchJSON()
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
||||
`
|
||||
)
|
137
middleware/monitor/monitor.go
Normal file
137
middleware/monitor/monitor.go
Normal file
|
@ -0,0 +1,137 @@
|
|||
package monitor
|
||||
|
||||
import (
|
||||
"os"
|
||||
"runtime"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/gofiber/fiber/v2/internal/gopsutil/cpu"
|
||||
"github.com/gofiber/fiber/v2/internal/gopsutil/load"
|
||||
"github.com/gofiber/fiber/v2/internal/gopsutil/mem"
|
||||
"github.com/gofiber/fiber/v2/internal/gopsutil/net"
|
||||
"github.com/gofiber/fiber/v2/internal/gopsutil/process"
|
||||
)
|
||||
|
||||
type stats struct {
|
||||
PID statsPID `json:"pid"`
|
||||
OS statsOS `json:"os"`
|
||||
}
|
||||
|
||||
type statsPID struct {
|
||||
CPU float64 `json:"cpu"`
|
||||
RAM uint64 `json:"ram"`
|
||||
Conns int `json:"conns"`
|
||||
}
|
||||
|
||||
type statsOS struct {
|
||||
CPU float64 `json:"cpu"`
|
||||
RAM uint64 `json:"ram"`
|
||||
TotalRAM uint64 `json:"total_ram"`
|
||||
LoadAvg float64 `json:"load_avg"`
|
||||
Conns int `json:"conns"`
|
||||
}
|
||||
|
||||
var (
|
||||
monitPIDCPU atomic.Value
|
||||
monitPIDRAM atomic.Value
|
||||
monitPIDConns atomic.Value
|
||||
|
||||
monitOSCPU atomic.Value
|
||||
monitOSRAM atomic.Value
|
||||
monitOSTotalRAM atomic.Value
|
||||
monitOSLoadAvg atomic.Value
|
||||
monitOSConns atomic.Value
|
||||
)
|
||||
|
||||
var (
|
||||
mutex sync.RWMutex
|
||||
once sync.Once
|
||||
data = &stats{}
|
||||
)
|
||||
|
||||
// New creates a new middleware handler
|
||||
func New(config ...Config) fiber.Handler {
|
||||
// Set default config
|
||||
cfg := configDefault(config...)
|
||||
|
||||
// Start routine to update statistics
|
||||
once.Do(func() {
|
||||
p, _ := process.NewProcess(int32(os.Getpid())) //nolint:errcheck // TODO: Handle error
|
||||
numcpu := runtime.NumCPU()
|
||||
updateStatistics(p, numcpu)
|
||||
|
||||
go func() {
|
||||
for {
|
||||
time.Sleep(cfg.Refresh)
|
||||
|
||||
updateStatistics(p, numcpu)
|
||||
}
|
||||
}()
|
||||
})
|
||||
|
||||
// Return new handler
|
||||
//nolint:errcheck // Ignore the type-assertion errors
|
||||
return func(c *fiber.Ctx) error {
|
||||
// Don't execute middleware if Next returns true
|
||||
if cfg.Next != nil && cfg.Next(c) {
|
||||
return c.Next()
|
||||
}
|
||||
|
||||
if c.Method() != fiber.MethodGet {
|
||||
return fiber.ErrMethodNotAllowed
|
||||
}
|
||||
if c.Get(fiber.HeaderAccept) == fiber.MIMEApplicationJSON || cfg.APIOnly {
|
||||
mutex.Lock()
|
||||
data.PID.CPU, _ = monitPIDCPU.Load().(float64)
|
||||
data.PID.RAM, _ = monitPIDRAM.Load().(uint64)
|
||||
data.PID.Conns, _ = monitPIDConns.Load().(int)
|
||||
|
||||
data.OS.CPU, _ = monitOSCPU.Load().(float64)
|
||||
data.OS.RAM, _ = monitOSRAM.Load().(uint64)
|
||||
data.OS.TotalRAM, _ = monitOSTotalRAM.Load().(uint64)
|
||||
data.OS.LoadAvg, _ = monitOSLoadAvg.Load().(float64)
|
||||
data.OS.Conns, _ = monitOSConns.Load().(int)
|
||||
mutex.Unlock()
|
||||
return c.Status(fiber.StatusOK).JSON(data)
|
||||
}
|
||||
c.Set(fiber.HeaderContentType, fiber.MIMETextHTMLCharsetUTF8)
|
||||
return c.Status(fiber.StatusOK).SendString(cfg.index)
|
||||
}
|
||||
}
|
||||
|
||||
func updateStatistics(p *process.Process, numcpu int) {
|
||||
pidCPU, err := p.Percent(0)
|
||||
if err == nil {
|
||||
monitPIDCPU.Store(pidCPU / float64(numcpu))
|
||||
}
|
||||
|
||||
if osCPU, err := cpu.Percent(0, false); err == nil && len(osCPU) > 0 {
|
||||
monitOSCPU.Store(osCPU[0])
|
||||
}
|
||||
|
||||
if pidRAM, err := p.MemoryInfo(); err == nil && pidRAM != nil {
|
||||
monitPIDRAM.Store(pidRAM.RSS)
|
||||
}
|
||||
|
||||
if osRAM, err := mem.VirtualMemory(); err == nil && osRAM != nil {
|
||||
monitOSRAM.Store(osRAM.Used)
|
||||
monitOSTotalRAM.Store(osRAM.Total)
|
||||
}
|
||||
|
||||
if loadAvg, err := load.Avg(); err == nil && loadAvg != nil {
|
||||
monitOSLoadAvg.Store(loadAvg.Load1)
|
||||
}
|
||||
|
||||
pidConns, err := net.ConnectionsPid("tcp", p.Pid)
|
||||
if err == nil {
|
||||
monitPIDConns.Store(len(pidConns))
|
||||
}
|
||||
|
||||
osConns, err := net.Connections("tcp")
|
||||
if err == nil {
|
||||
monitOSConns.Store(len(osConns))
|
||||
}
|
||||
}
|
198
middleware/monitor/monitor_test.go
Normal file
198
middleware/monitor/monitor_test.go
Normal file
|
@ -0,0 +1,198 @@
|
|||
package monitor
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/gofiber/fiber/v2/utils"
|
||||
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
func Test_Monitor_405(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
app := fiber.New()
|
||||
|
||||
app.Use("/", New())
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodPost, "/", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, 405, resp.StatusCode)
|
||||
}
|
||||
|
||||
func Test_Monitor_Html(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
app := fiber.New()
|
||||
|
||||
// defaults
|
||||
app.Get("/", New())
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, 200, resp.StatusCode)
|
||||
utils.AssertEqual(t, fiber.MIMETextHTMLCharsetUTF8,
|
||||
resp.Header.Get(fiber.HeaderContentType))
|
||||
buf, err := io.ReadAll(resp.Body)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, true, bytes.Contains(buf, []byte("<title>"+defaultTitle+"</title>")))
|
||||
timeoutLine := fmt.Sprintf("setTimeout(fetchJSON, %d)",
|
||||
defaultRefresh.Milliseconds()-timeoutDiff)
|
||||
utils.AssertEqual(t, true, bytes.Contains(buf, []byte(timeoutLine)))
|
||||
|
||||
// custom config
|
||||
conf := Config{Title: "New " + defaultTitle, Refresh: defaultRefresh + time.Second}
|
||||
app.Get("/custom", New(conf))
|
||||
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/custom", nil))
|
||||
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, 200, resp.StatusCode)
|
||||
utils.AssertEqual(t, fiber.MIMETextHTMLCharsetUTF8,
|
||||
resp.Header.Get(fiber.HeaderContentType))
|
||||
buf, err = io.ReadAll(resp.Body)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, true, bytes.Contains(buf, []byte("<title>"+conf.Title+"</title>")))
|
||||
timeoutLine = fmt.Sprintf("setTimeout(fetchJSON, %d)",
|
||||
conf.Refresh.Milliseconds()-timeoutDiff)
|
||||
utils.AssertEqual(t, true, bytes.Contains(buf, []byte(timeoutLine)))
|
||||
}
|
||||
|
||||
func Test_Monitor_Html_CustomCodes(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
app := fiber.New()
|
||||
|
||||
// defaults
|
||||
app.Get("/", New())
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, 200, resp.StatusCode)
|
||||
utils.AssertEqual(t, fiber.MIMETextHTMLCharsetUTF8,
|
||||
resp.Header.Get(fiber.HeaderContentType))
|
||||
buf, err := io.ReadAll(resp.Body)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, true, bytes.Contains(buf, []byte("<title>"+defaultTitle+"</title>")))
|
||||
timeoutLine := fmt.Sprintf("setTimeout(fetchJSON, %d)",
|
||||
defaultRefresh.Milliseconds()-timeoutDiff)
|
||||
utils.AssertEqual(t, true, bytes.Contains(buf, []byte(timeoutLine)))
|
||||
|
||||
// custom config
|
||||
conf := Config{
|
||||
Title: "New " + defaultTitle,
|
||||
Refresh: defaultRefresh + time.Second,
|
||||
ChartJsURL: "https://cdnjs.com/libraries/Chart.js",
|
||||
FontURL: "/public/my-font.css",
|
||||
CustomHead: `<style>body{background:#fff}</style>`,
|
||||
}
|
||||
app.Get("/custom", New(conf))
|
||||
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/custom", nil))
|
||||
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, 200, resp.StatusCode)
|
||||
utils.AssertEqual(t, fiber.MIMETextHTMLCharsetUTF8,
|
||||
resp.Header.Get(fiber.HeaderContentType))
|
||||
buf, err = io.ReadAll(resp.Body)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, true, bytes.Contains(buf, []byte("<title>"+conf.Title+"</title>")))
|
||||
utils.AssertEqual(t, true, bytes.Contains(buf, []byte("https://cdnjs.com/libraries/Chart.js")))
|
||||
utils.AssertEqual(t, true, bytes.Contains(buf, []byte("/public/my-font.css")))
|
||||
utils.AssertEqual(t, true, bytes.Contains(buf, []byte(conf.CustomHead)))
|
||||
|
||||
timeoutLine = fmt.Sprintf("setTimeout(fetchJSON, %d)",
|
||||
conf.Refresh.Milliseconds()-timeoutDiff)
|
||||
utils.AssertEqual(t, true, bytes.Contains(buf, []byte(timeoutLine)))
|
||||
}
|
||||
|
||||
// go test -run Test_Monitor_JSON -race
|
||||
func Test_Monitor_JSON(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
app := fiber.New()
|
||||
|
||||
app.Get("/", New())
|
||||
|
||||
req := httptest.NewRequest(fiber.MethodGet, "/", nil)
|
||||
req.Header.Set(fiber.HeaderAccept, fiber.MIMEApplicationJSON)
|
||||
resp, err := app.Test(req)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, 200, resp.StatusCode)
|
||||
utils.AssertEqual(t, fiber.MIMEApplicationJSON, resp.Header.Get(fiber.HeaderContentType))
|
||||
|
||||
b, err := io.ReadAll(resp.Body)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, true, bytes.Contains(b, []byte("pid")))
|
||||
utils.AssertEqual(t, true, bytes.Contains(b, []byte("os")))
|
||||
}
|
||||
|
||||
// go test -v -run=^$ -bench=Benchmark_Monitor -benchmem -count=4
|
||||
func Benchmark_Monitor(b *testing.B) {
|
||||
app := fiber.New()
|
||||
|
||||
app.Get("/", New())
|
||||
|
||||
h := app.Handler()
|
||||
|
||||
fctx := &fasthttp.RequestCtx{}
|
||||
fctx.Request.Header.SetMethod(fiber.MethodGet)
|
||||
fctx.Request.SetRequestURI("/")
|
||||
fctx.Request.Header.Set(fiber.HeaderAccept, fiber.MIMEApplicationJSON)
|
||||
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
for pb.Next() {
|
||||
h(fctx)
|
||||
}
|
||||
})
|
||||
|
||||
utils.AssertEqual(b, 200, fctx.Response.Header.StatusCode())
|
||||
utils.AssertEqual(b,
|
||||
fiber.MIMEApplicationJSON,
|
||||
string(fctx.Response.Header.Peek(fiber.HeaderContentType)))
|
||||
}
|
||||
|
||||
// go test -run Test_Monitor_Next
|
||||
func Test_Monitor_Next(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
app := fiber.New()
|
||||
|
||||
app.Use("/", New(Config{
|
||||
Next: func(_ *fiber.Ctx) bool {
|
||||
return true
|
||||
},
|
||||
}))
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodPost, "/", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, 404, resp.StatusCode)
|
||||
}
|
||||
|
||||
// go test -run Test_Monitor_APIOnly -race
|
||||
func Test_Monitor_APIOnly(t *testing.T) {
|
||||
app := fiber.New()
|
||||
|
||||
app.Get("/", New(Config{
|
||||
APIOnly: true,
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest(fiber.MethodGet, "/", nil)
|
||||
req.Header.Set(fiber.HeaderAccept, fiber.MIMEApplicationJSON)
|
||||
resp, err := app.Test(req)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, 200, resp.StatusCode)
|
||||
utils.AssertEqual(t, fiber.MIMEApplicationJSON, resp.Header.Get(fiber.HeaderContentType))
|
||||
|
||||
b, err := io.ReadAll(resp.Body)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, true, bytes.Contains(b, []byte("pid")))
|
||||
utils.AssertEqual(t, true, bytes.Contains(b, []byte("os")))
|
||||
}
|
41
middleware/pprof/config.go
Normal file
41
middleware/pprof/config.go
Normal file
|
@ -0,0 +1,41 @@
|
|||
package pprof
|
||||
|
||||
import (
|
||||
"github.com/gofiber/fiber/v2"
|
||||
)
|
||||
|
||||
// Config defines the config for middleware.
|
||||
type Config struct {
|
||||
// Next defines a function to skip this middleware when returned true.
|
||||
//
|
||||
// Optional. Default: nil
|
||||
Next func(c *fiber.Ctx) bool
|
||||
|
||||
// Prefix defines a URL prefix added before "/debug/pprof".
|
||||
// Note that it should start with (but not end with) a slash.
|
||||
// Example: "/federated-fiber"
|
||||
//
|
||||
// Optional. Default: ""
|
||||
Prefix string
|
||||
}
|
||||
|
||||
var ConfigDefault = Config{
|
||||
Next: nil,
|
||||
}
|
||||
|
||||
func configDefault(config ...Config) Config {
|
||||
// Return default config if nothing provided
|
||||
if len(config) < 1 {
|
||||
return ConfigDefault
|
||||
}
|
||||
|
||||
// Override default config
|
||||
cfg := config[0]
|
||||
|
||||
// Set default values
|
||||
if cfg.Next == nil {
|
||||
cfg.Next = ConfigDefault.Next
|
||||
}
|
||||
|
||||
return cfg
|
||||
}
|
95
middleware/pprof/pprof.go
Normal file
95
middleware/pprof/pprof.go
Normal file
|
@ -0,0 +1,95 @@
|
|||
package pprof
|
||||
|
||||
import (
|
||||
"net/http/pprof"
|
||||
"strings"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
|
||||
"github.com/valyala/fasthttp/fasthttpadaptor"
|
||||
)
|
||||
|
||||
// New creates a new middleware handler
|
||||
func New(config ...Config) fiber.Handler {
|
||||
// Set default config
|
||||
cfg := configDefault(config...)
|
||||
|
||||
// Set pprof adaptors
|
||||
var (
|
||||
pprofIndex = fasthttpadaptor.NewFastHTTPHandlerFunc(pprof.Index)
|
||||
pprofCmdline = fasthttpadaptor.NewFastHTTPHandlerFunc(pprof.Cmdline)
|
||||
pprofProfile = fasthttpadaptor.NewFastHTTPHandlerFunc(pprof.Profile)
|
||||
pprofSymbol = fasthttpadaptor.NewFastHTTPHandlerFunc(pprof.Symbol)
|
||||
pprofTrace = fasthttpadaptor.NewFastHTTPHandlerFunc(pprof.Trace)
|
||||
pprofAllocs = fasthttpadaptor.NewFastHTTPHandlerFunc(pprof.Handler("allocs").ServeHTTP)
|
||||
pprofBlock = fasthttpadaptor.NewFastHTTPHandlerFunc(pprof.Handler("block").ServeHTTP)
|
||||
pprofGoroutine = fasthttpadaptor.NewFastHTTPHandlerFunc(pprof.Handler("goroutine").ServeHTTP)
|
||||
pprofHeap = fasthttpadaptor.NewFastHTTPHandlerFunc(pprof.Handler("heap").ServeHTTP)
|
||||
pprofMutex = fasthttpadaptor.NewFastHTTPHandlerFunc(pprof.Handler("mutex").ServeHTTP)
|
||||
pprofThreadcreate = fasthttpadaptor.NewFastHTTPHandlerFunc(pprof.Handler("threadcreate").ServeHTTP)
|
||||
)
|
||||
|
||||
// Construct actual prefix
|
||||
prefix := cfg.Prefix + "/debug/pprof"
|
||||
|
||||
// Return new handler
|
||||
return func(c *fiber.Ctx) error {
|
||||
// Don't execute middleware if Next returns true
|
||||
if cfg.Next != nil && cfg.Next(c) {
|
||||
return c.Next()
|
||||
}
|
||||
|
||||
path := c.Path()
|
||||
// We are only interested in /debug/pprof routes
|
||||
path, found := cutPrefix(path, prefix)
|
||||
if !found {
|
||||
return c.Next()
|
||||
}
|
||||
// Switch on trimmed path against constant strings
|
||||
switch path {
|
||||
case "/":
|
||||
pprofIndex(c.Context())
|
||||
case "/cmdline":
|
||||
pprofCmdline(c.Context())
|
||||
case "/profile":
|
||||
pprofProfile(c.Context())
|
||||
case "/symbol":
|
||||
pprofSymbol(c.Context())
|
||||
case "/trace":
|
||||
pprofTrace(c.Context())
|
||||
case "/allocs":
|
||||
pprofAllocs(c.Context())
|
||||
case "/block":
|
||||
pprofBlock(c.Context())
|
||||
case "/goroutine":
|
||||
pprofGoroutine(c.Context())
|
||||
case "/heap":
|
||||
pprofHeap(c.Context())
|
||||
case "/mutex":
|
||||
pprofMutex(c.Context())
|
||||
case "/threadcreate":
|
||||
pprofThreadcreate(c.Context())
|
||||
default:
|
||||
// pprof index only works with trailing slash
|
||||
if strings.HasSuffix(path, "/") {
|
||||
path = strings.TrimRight(path, "/")
|
||||
} else {
|
||||
path = prefix + "/"
|
||||
}
|
||||
|
||||
return c.Redirect(path, fiber.StatusFound)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// cutPrefix is a copy of [strings.CutPrefix] added in Go 1.20.
|
||||
// Remove this function when we drop support for Go 1.19.
|
||||
//
|
||||
//nolint:nonamedreturns // Align with its original form in std.
|
||||
func cutPrefix(s, prefix string) (after string, found bool) {
|
||||
if !strings.HasPrefix(s, prefix) {
|
||||
return s, false
|
||||
}
|
||||
return s[len(prefix):], true
|
||||
}
|
200
middleware/pprof/pprof_test.go
Normal file
200
middleware/pprof/pprof_test.go
Normal file
|
@ -0,0 +1,200 @@
|
|||
package pprof
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/gofiber/fiber/v2/utils"
|
||||
)
|
||||
|
||||
func Test_Non_Pprof_Path(t *testing.T) {
|
||||
app := fiber.New(fiber.Config{DisableStartupMessage: true})
|
||||
|
||||
app.Use(New())
|
||||
|
||||
app.Get("/", func(c *fiber.Ctx) error {
|
||||
return c.SendString("escaped")
|
||||
})
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, 200, resp.StatusCode)
|
||||
|
||||
b, err := io.ReadAll(resp.Body)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, "escaped", string(b))
|
||||
}
|
||||
|
||||
func Test_Non_Pprof_Path_WithPrefix(t *testing.T) {
|
||||
app := fiber.New(fiber.Config{DisableStartupMessage: true})
|
||||
|
||||
app.Use(New(Config{Prefix: "/federated-fiber"}))
|
||||
|
||||
app.Get("/", func(c *fiber.Ctx) error {
|
||||
return c.SendString("escaped")
|
||||
})
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, 200, resp.StatusCode)
|
||||
|
||||
b, err := io.ReadAll(resp.Body)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, "escaped", string(b))
|
||||
}
|
||||
|
||||
func Test_Pprof_Index(t *testing.T) {
|
||||
app := fiber.New(fiber.Config{DisableStartupMessage: true})
|
||||
|
||||
app.Use(New())
|
||||
|
||||
app.Get("/", func(c *fiber.Ctx) error {
|
||||
return c.SendString("escaped")
|
||||
})
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/debug/pprof/", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, 200, resp.StatusCode)
|
||||
utils.AssertEqual(t, fiber.MIMETextHTMLCharsetUTF8, resp.Header.Get(fiber.HeaderContentType))
|
||||
|
||||
b, err := io.ReadAll(resp.Body)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, true, bytes.Contains(b, []byte("<title>/debug/pprof/</title>")))
|
||||
}
|
||||
|
||||
func Test_Pprof_Index_WithPrefix(t *testing.T) {
|
||||
app := fiber.New(fiber.Config{DisableStartupMessage: true})
|
||||
|
||||
app.Use(New(Config{Prefix: "/federated-fiber"}))
|
||||
|
||||
app.Get("/", func(c *fiber.Ctx) error {
|
||||
return c.SendString("escaped")
|
||||
})
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/federated-fiber/debug/pprof/", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, 200, resp.StatusCode)
|
||||
utils.AssertEqual(t, fiber.MIMETextHTMLCharsetUTF8, resp.Header.Get(fiber.HeaderContentType))
|
||||
|
||||
b, err := io.ReadAll(resp.Body)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, true, bytes.Contains(b, []byte("<title>/debug/pprof/</title>")))
|
||||
}
|
||||
|
||||
func Test_Pprof_Subs(t *testing.T) {
|
||||
app := fiber.New(fiber.Config{DisableStartupMessage: true})
|
||||
|
||||
app.Use(New())
|
||||
|
||||
app.Get("/", func(c *fiber.Ctx) error {
|
||||
return c.SendString("escaped")
|
||||
})
|
||||
|
||||
subs := []string{
|
||||
"cmdline", "profile", "symbol", "trace", "allocs", "block",
|
||||
"goroutine", "heap", "mutex", "threadcreate",
|
||||
}
|
||||
|
||||
for _, sub := range subs {
|
||||
sub := sub
|
||||
t.Run(sub, func(t *testing.T) {
|
||||
target := "/debug/pprof/" + sub
|
||||
if sub == "profile" {
|
||||
target += "?seconds=1"
|
||||
}
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, target, nil), 5000)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, 200, resp.StatusCode)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_Pprof_Subs_WithPrefix(t *testing.T) {
|
||||
app := fiber.New(fiber.Config{DisableStartupMessage: true})
|
||||
|
||||
app.Use(New(Config{Prefix: "/federated-fiber"}))
|
||||
|
||||
app.Get("/", func(c *fiber.Ctx) error {
|
||||
return c.SendString("escaped")
|
||||
})
|
||||
|
||||
subs := []string{
|
||||
"cmdline", "profile", "symbol", "trace", "allocs", "block",
|
||||
"goroutine", "heap", "mutex", "threadcreate",
|
||||
}
|
||||
|
||||
for _, sub := range subs {
|
||||
sub := sub
|
||||
t.Run(sub, func(t *testing.T) {
|
||||
target := "/federated-fiber/debug/pprof/" + sub
|
||||
if sub == "profile" {
|
||||
target += "?seconds=1"
|
||||
}
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, target, nil), 5000)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, 200, resp.StatusCode)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_Pprof_Other(t *testing.T) {
|
||||
app := fiber.New(fiber.Config{DisableStartupMessage: true})
|
||||
|
||||
app.Use(New())
|
||||
|
||||
app.Get("/", func(c *fiber.Ctx) error {
|
||||
return c.SendString("escaped")
|
||||
})
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/debug/pprof/302", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, 302, resp.StatusCode)
|
||||
}
|
||||
|
||||
func Test_Pprof_Other_WithPrefix(t *testing.T) {
|
||||
app := fiber.New(fiber.Config{DisableStartupMessage: true})
|
||||
|
||||
app.Use(New(Config{Prefix: "/federated-fiber"}))
|
||||
|
||||
app.Get("/", func(c *fiber.Ctx) error {
|
||||
return c.SendString("escaped")
|
||||
})
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/federated-fiber/debug/pprof/302", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, 302, resp.StatusCode)
|
||||
}
|
||||
|
||||
// go test -run Test_Pprof_Next
|
||||
func Test_Pprof_Next(t *testing.T) {
|
||||
app := fiber.New()
|
||||
|
||||
app.Use(New(Config{
|
||||
Next: func(_ *fiber.Ctx) bool {
|
||||
return true
|
||||
},
|
||||
}))
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/debug/pprof/", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, 404, resp.StatusCode)
|
||||
}
|
||||
|
||||
// go test -run Test_Pprof_Next_WithPrefix
|
||||
func Test_Pprof_Next_WithPrefix(t *testing.T) {
|
||||
app := fiber.New()
|
||||
|
||||
app.Use(New(Config{
|
||||
Next: func(_ *fiber.Ctx) bool {
|
||||
return true
|
||||
},
|
||||
Prefix: "/federated-fiber",
|
||||
}))
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/federated-fiber/debug/pprof/", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, 404, resp.StatusCode)
|
||||
}
|
88
middleware/proxy/config.go
Normal file
88
middleware/proxy/config.go
Normal file
|
@ -0,0 +1,88 @@
|
|||
package proxy
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"time"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
// Config defines the config for middleware.
|
||||
type Config struct {
|
||||
// Next defines a function to skip this middleware when returned true.
|
||||
//
|
||||
// Optional. Default: nil
|
||||
Next func(c *fiber.Ctx) bool
|
||||
|
||||
// Servers defines a list of <scheme>://<host> HTTP servers,
|
||||
//
|
||||
// which are used in a round-robin manner.
|
||||
// i.e.: "https://foobar.com, http://www.foobar.com"
|
||||
//
|
||||
// Required
|
||||
Servers []string
|
||||
|
||||
// ModifyRequest allows you to alter the request
|
||||
//
|
||||
// Optional. Default: nil
|
||||
ModifyRequest fiber.Handler
|
||||
|
||||
// ModifyResponse allows you to alter the response
|
||||
//
|
||||
// Optional. Default: nil
|
||||
ModifyResponse fiber.Handler
|
||||
|
||||
// Timeout is the request timeout used when calling the proxy client
|
||||
//
|
||||
// Optional. Default: 1 second
|
||||
Timeout time.Duration
|
||||
|
||||
// Per-connection buffer size for requests' reading.
|
||||
// This also limits the maximum header size.
|
||||
// Increase this buffer if your clients send multi-KB RequestURIs
|
||||
// and/or multi-KB headers (for example, BIG cookies).
|
||||
ReadBufferSize int
|
||||
|
||||
// Per-connection buffer size for responses' writing.
|
||||
WriteBufferSize int
|
||||
|
||||
// tls config for the http client.
|
||||
TlsConfig *tls.Config //nolint:stylecheck,revive // TODO: Rename to "TLSConfig" in v3
|
||||
|
||||
// Client is custom client when client config is complex.
|
||||
// Note that Servers, Timeout, WriteBufferSize, ReadBufferSize and TlsConfig
|
||||
// will not be used if the client are set.
|
||||
Client *fasthttp.LBClient
|
||||
}
|
||||
|
||||
// ConfigDefault is the default config
|
||||
var ConfigDefault = Config{
|
||||
Next: nil,
|
||||
ModifyRequest: nil,
|
||||
ModifyResponse: nil,
|
||||
Timeout: fasthttp.DefaultLBClientTimeout,
|
||||
}
|
||||
|
||||
// configDefault function to set default values
|
||||
func configDefault(config ...Config) Config {
|
||||
// Return default config if nothing provided
|
||||
if len(config) < 1 {
|
||||
return ConfigDefault
|
||||
}
|
||||
|
||||
// Override default config
|
||||
cfg := config[0]
|
||||
|
||||
// Set default values
|
||||
if cfg.Timeout <= 0 {
|
||||
cfg.Timeout = ConfigDefault.Timeout
|
||||
}
|
||||
|
||||
// Set default values
|
||||
if len(cfg.Servers) == 0 && cfg.Client == nil {
|
||||
panic("Servers cannot be empty")
|
||||
}
|
||||
return cfg
|
||||
}
|
267
middleware/proxy/proxy.go
Normal file
267
middleware/proxy/proxy.go
Normal file
|
@ -0,0 +1,267 @@
|
|||
package proxy
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/tls"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/gofiber/fiber/v2/log"
|
||||
"github.com/gofiber/fiber/v2/utils"
|
||||
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
// New is deprecated
|
||||
func New(config Config) fiber.Handler {
|
||||
log.Warn("[PROXY] proxy.New is deprecated, please use proxy.Balancer instead")
|
||||
return Balancer(config)
|
||||
}
|
||||
|
||||
// Balancer creates a load balancer among multiple upstream servers
|
||||
func Balancer(config Config) fiber.Handler {
|
||||
// Set default config
|
||||
cfg := configDefault(config)
|
||||
|
||||
// Load balanced client
|
||||
lbc := &fasthttp.LBClient{}
|
||||
// Note that Servers, Timeout, WriteBufferSize, ReadBufferSize and TlsConfig
|
||||
// will not be used if the client are set.
|
||||
if config.Client == nil {
|
||||
// Set timeout
|
||||
lbc.Timeout = cfg.Timeout
|
||||
// Scheme must be provided, falls back to http
|
||||
for _, server := range cfg.Servers {
|
||||
if !strings.HasPrefix(server, "http") {
|
||||
server = "http://" + server
|
||||
}
|
||||
|
||||
u, err := url.Parse(server)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
client := &fasthttp.HostClient{
|
||||
NoDefaultUserAgentHeader: true,
|
||||
DisablePathNormalizing: true,
|
||||
Addr: u.Host,
|
||||
|
||||
ReadBufferSize: config.ReadBufferSize,
|
||||
WriteBufferSize: config.WriteBufferSize,
|
||||
|
||||
TLSConfig: config.TlsConfig,
|
||||
}
|
||||
|
||||
lbc.Clients = append(lbc.Clients, client)
|
||||
}
|
||||
} else {
|
||||
// Set custom client
|
||||
lbc = config.Client
|
||||
}
|
||||
|
||||
// Return new handler
|
||||
return func(c *fiber.Ctx) error {
|
||||
// Don't execute middleware if Next returns true
|
||||
if cfg.Next != nil && cfg.Next(c) {
|
||||
return c.Next()
|
||||
}
|
||||
|
||||
// Set request and response
|
||||
req := c.Request()
|
||||
res := c.Response()
|
||||
|
||||
// Don't proxy "Connection" header
|
||||
req.Header.Del(fiber.HeaderConnection)
|
||||
|
||||
// Modify request
|
||||
if cfg.ModifyRequest != nil {
|
||||
if err := cfg.ModifyRequest(c); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
req.SetRequestURI(utils.UnsafeString(req.RequestURI()))
|
||||
|
||||
// Forward request
|
||||
if err := lbc.Do(req, res); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Don't proxy "Connection" header
|
||||
res.Header.Del(fiber.HeaderConnection)
|
||||
|
||||
// Modify response
|
||||
if cfg.ModifyResponse != nil {
|
||||
if err := cfg.ModifyResponse(c); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Return nil to end proxying if no error
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
var client = &fasthttp.Client{
|
||||
NoDefaultUserAgentHeader: true,
|
||||
DisablePathNormalizing: true,
|
||||
}
|
||||
|
||||
var lock sync.RWMutex
|
||||
|
||||
// WithTlsConfig update http client with a user specified tls.config
|
||||
// This function should be called before Do and Forward.
|
||||
// Deprecated: use WithClient instead.
|
||||
//
|
||||
//nolint:stylecheck,revive // TODO: Rename to "WithTLSConfig" in v3
|
||||
func WithTlsConfig(tlsConfig *tls.Config) {
|
||||
client.TLSConfig = tlsConfig
|
||||
}
|
||||
|
||||
// WithClient sets the global proxy client.
|
||||
// This function should be called before Do and Forward.
|
||||
func WithClient(cli *fasthttp.Client) {
|
||||
lock.Lock()
|
||||
defer lock.Unlock()
|
||||
client = cli
|
||||
}
|
||||
|
||||
// Forward performs the given http request and fills the given http response.
|
||||
// This method will return an fiber.Handler
|
||||
func Forward(addr string, clients ...*fasthttp.Client) fiber.Handler {
|
||||
return func(c *fiber.Ctx) error {
|
||||
return Do(c, addr, clients...)
|
||||
}
|
||||
}
|
||||
|
||||
// Do performs the given http request and fills the given http response.
|
||||
// This method can be used within a fiber.Handler
|
||||
func Do(c *fiber.Ctx, addr string, clients ...*fasthttp.Client) error {
|
||||
return doAction(c, addr, func(cli *fasthttp.Client, req *fasthttp.Request, resp *fasthttp.Response) error {
|
||||
return cli.Do(req, resp)
|
||||
}, clients...)
|
||||
}
|
||||
|
||||
// DoRedirects performs the given http request and fills the given http response, following up to maxRedirectsCount redirects.
|
||||
// When the redirect count exceeds maxRedirectsCount, ErrTooManyRedirects is returned.
|
||||
// This method can be used within a fiber.Handler
|
||||
func DoRedirects(c *fiber.Ctx, addr string, maxRedirectsCount int, clients ...*fasthttp.Client) error {
|
||||
return doAction(c, addr, func(cli *fasthttp.Client, req *fasthttp.Request, resp *fasthttp.Response) error {
|
||||
return cli.DoRedirects(req, resp, maxRedirectsCount)
|
||||
}, clients...)
|
||||
}
|
||||
|
||||
// DoDeadline performs the given request and waits for response until the given deadline.
|
||||
// This method can be used within a fiber.Handler
|
||||
func DoDeadline(c *fiber.Ctx, addr string, deadline time.Time, clients ...*fasthttp.Client) error {
|
||||
return doAction(c, addr, func(cli *fasthttp.Client, req *fasthttp.Request, resp *fasthttp.Response) error {
|
||||
return cli.DoDeadline(req, resp, deadline)
|
||||
}, clients...)
|
||||
}
|
||||
|
||||
// DoTimeout performs the given request and waits for response during the given timeout duration.
|
||||
// This method can be used within a fiber.Handler
|
||||
func DoTimeout(c *fiber.Ctx, addr string, timeout time.Duration, clients ...*fasthttp.Client) error {
|
||||
return doAction(c, addr, func(cli *fasthttp.Client, req *fasthttp.Request, resp *fasthttp.Response) error {
|
||||
return cli.DoTimeout(req, resp, timeout)
|
||||
}, clients...)
|
||||
}
|
||||
|
||||
func doAction(
|
||||
c *fiber.Ctx,
|
||||
addr string,
|
||||
action func(cli *fasthttp.Client, req *fasthttp.Request, resp *fasthttp.Response) error,
|
||||
clients ...*fasthttp.Client,
|
||||
) error {
|
||||
var cli *fasthttp.Client
|
||||
|
||||
// set local or global client
|
||||
if len(clients) != 0 {
|
||||
cli = clients[0]
|
||||
} else {
|
||||
lock.RLock()
|
||||
cli = client
|
||||
lock.RUnlock()
|
||||
}
|
||||
|
||||
req := c.Request()
|
||||
res := c.Response()
|
||||
originalURL := utils.CopyString(c.OriginalURL())
|
||||
defer req.SetRequestURI(originalURL)
|
||||
|
||||
copiedURL := utils.CopyString(addr)
|
||||
req.SetRequestURI(copiedURL)
|
||||
// NOTE: if req.isTLS is true, SetRequestURI keeps the scheme as https.
|
||||
// Reference: https://github.com/gofiber/fiber/issues/1762
|
||||
if scheme := getScheme(utils.UnsafeBytes(copiedURL)); len(scheme) > 0 {
|
||||
req.URI().SetSchemeBytes(scheme)
|
||||
}
|
||||
|
||||
req.Header.Del(fiber.HeaderConnection)
|
||||
if err := action(cli, req, res); err != nil {
|
||||
return err
|
||||
}
|
||||
res.Header.Del(fiber.HeaderConnection)
|
||||
return nil
|
||||
}
|
||||
|
||||
func getScheme(uri []byte) []byte {
|
||||
i := bytes.IndexByte(uri, '/')
|
||||
if i < 1 || uri[i-1] != ':' || i == len(uri)-1 || uri[i+1] != '/' {
|
||||
return nil
|
||||
}
|
||||
return uri[:i-1]
|
||||
}
|
||||
|
||||
// DomainForward performs an http request based on the given domain and populates the given http response.
|
||||
// This method will return an fiber.Handler
|
||||
func DomainForward(hostname, addr string, clients ...*fasthttp.Client) fiber.Handler {
|
||||
return func(c *fiber.Ctx) error {
|
||||
host := string(c.Request().Host())
|
||||
if host == hostname {
|
||||
return Do(c, addr+c.OriginalURL(), clients...)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
type roundrobin struct {
|
||||
sync.Mutex
|
||||
|
||||
current int
|
||||
pool []string
|
||||
}
|
||||
|
||||
// this method will return a string of addr server from list server.
|
||||
func (r *roundrobin) get() string {
|
||||
r.Lock()
|
||||
defer r.Unlock()
|
||||
|
||||
if r.current >= len(r.pool) {
|
||||
r.current %= len(r.pool)
|
||||
}
|
||||
|
||||
result := r.pool[r.current]
|
||||
r.current++
|
||||
return result
|
||||
}
|
||||
|
||||
// BalancerForward Forward performs the given http request with round robin algorithm to server and fills the given http response.
|
||||
// This method will return an fiber.Handler
|
||||
func BalancerForward(servers []string, clients ...*fasthttp.Client) fiber.Handler {
|
||||
r := &roundrobin{
|
||||
current: 0,
|
||||
pool: servers,
|
||||
}
|
||||
return func(c *fiber.Ctx) error {
|
||||
server := r.get()
|
||||
if !strings.HasPrefix(server, "http") {
|
||||
server = "http://" + server
|
||||
}
|
||||
c.Request().Header.Add("X-Real-IP", c.IP())
|
||||
return Do(c, server+c.OriginalURL(), clients...)
|
||||
}
|
||||
}
|
689
middleware/proxy/proxy_test.go
Normal file
689
middleware/proxy/proxy_test.go
Normal file
|
@ -0,0 +1,689 @@
|
|||
package proxy
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/gofiber/fiber/v2/internal/tlstest"
|
||||
"github.com/gofiber/fiber/v2/utils"
|
||||
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
func createProxyTestServer(t *testing.T, handler fiber.Handler) (*fiber.App, string) {
|
||||
t.Helper()
|
||||
|
||||
target := fiber.New(fiber.Config{DisableStartupMessage: true})
|
||||
target.Get("/", handler)
|
||||
|
||||
ln, err := net.Listen(fiber.NetworkTCP4, "127.0.0.1:0")
|
||||
utils.AssertEqual(t, nil, err)
|
||||
|
||||
go func() {
|
||||
utils.AssertEqual(t, nil, target.Listener(ln))
|
||||
}()
|
||||
|
||||
time.Sleep(2 * time.Second)
|
||||
addr := ln.Addr().String()
|
||||
|
||||
return target, addr
|
||||
}
|
||||
|
||||
// go test -run Test_Proxy_Empty_Host
|
||||
func Test_Proxy_Empty_Upstream_Servers(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
utils.AssertEqual(t, "Servers cannot be empty", r)
|
||||
}
|
||||
}()
|
||||
app := fiber.New()
|
||||
app.Use(Balancer(Config{Servers: []string{}}))
|
||||
}
|
||||
|
||||
// go test -run Test_Proxy_Empty_Config
|
||||
func Test_Proxy_Empty_Config(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
utils.AssertEqual(t, "Servers cannot be empty", r)
|
||||
}
|
||||
}()
|
||||
app := fiber.New()
|
||||
app.Use(New(Config{}))
|
||||
}
|
||||
|
||||
// go test -run Test_Proxy_Next
|
||||
func Test_Proxy_Next(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
app := fiber.New()
|
||||
app.Use(Balancer(Config{
|
||||
Servers: []string{"127.0.0.1"},
|
||||
Next: func(_ *fiber.Ctx) bool {
|
||||
return true
|
||||
},
|
||||
}))
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, fiber.StatusNotFound, resp.StatusCode)
|
||||
}
|
||||
|
||||
// go test -run Test_Proxy
|
||||
func Test_Proxy(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
target, addr := createProxyTestServer(t, func(c *fiber.Ctx) error {
|
||||
return c.SendStatus(fiber.StatusTeapot)
|
||||
})
|
||||
|
||||
resp, err := target.Test(httptest.NewRequest(fiber.MethodGet, "/", nil), 2000)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, fiber.StatusTeapot, resp.StatusCode)
|
||||
|
||||
app := fiber.New(fiber.Config{DisableStartupMessage: true})
|
||||
|
||||
app.Use(Balancer(Config{Servers: []string{addr}}))
|
||||
|
||||
req := httptest.NewRequest(fiber.MethodGet, "/", nil)
|
||||
req.Host = addr
|
||||
resp, err = app.Test(req)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, fiber.StatusTeapot, resp.StatusCode)
|
||||
}
|
||||
|
||||
// go test -run Test_Proxy_Balancer_WithTlsConfig
|
||||
func Test_Proxy_Balancer_WithTlsConfig(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
serverTLSConf, _, err := tlstest.GetTLSConfigs()
|
||||
utils.AssertEqual(t, nil, err)
|
||||
|
||||
ln, err := net.Listen(fiber.NetworkTCP4, "127.0.0.1:0")
|
||||
utils.AssertEqual(t, nil, err)
|
||||
|
||||
ln = tls.NewListener(ln, serverTLSConf)
|
||||
|
||||
app := fiber.New(fiber.Config{DisableStartupMessage: true})
|
||||
|
||||
app.Get("/tlsbalaner", func(c *fiber.Ctx) error {
|
||||
return c.SendString("tls balancer")
|
||||
})
|
||||
|
||||
addr := ln.Addr().String()
|
||||
clientTLSConf := &tls.Config{InsecureSkipVerify: true} //nolint:gosec // We're in a test func, so this is fine
|
||||
|
||||
// disable certificate verification in Balancer
|
||||
app.Use(Balancer(Config{
|
||||
Servers: []string{addr},
|
||||
TlsConfig: clientTLSConf,
|
||||
}))
|
||||
|
||||
go func() { utils.AssertEqual(t, nil, app.Listener(ln)) }()
|
||||
|
||||
code, body, errs := fiber.Get("https://" + addr + "/tlsbalaner").TLSConfig(clientTLSConf).String()
|
||||
|
||||
utils.AssertEqual(t, 0, len(errs))
|
||||
utils.AssertEqual(t, fiber.StatusOK, code)
|
||||
utils.AssertEqual(t, "tls balancer", body)
|
||||
}
|
||||
|
||||
// go test -run Test_Proxy_Forward_WithTlsConfig_To_Http
|
||||
func Test_Proxy_Forward_WithTlsConfig_To_Http(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
_, targetAddr := createProxyTestServer(t, func(c *fiber.Ctx) error {
|
||||
return c.SendString("hello from target")
|
||||
})
|
||||
|
||||
proxyServerTLSConf, _, err := tlstest.GetTLSConfigs()
|
||||
utils.AssertEqual(t, nil, err)
|
||||
|
||||
proxyServerLn, err := net.Listen(fiber.NetworkTCP4, "127.0.0.1:0")
|
||||
utils.AssertEqual(t, nil, err)
|
||||
|
||||
proxyServerLn = tls.NewListener(proxyServerLn, proxyServerTLSConf)
|
||||
|
||||
app := fiber.New(fiber.Config{DisableStartupMessage: true})
|
||||
|
||||
proxyAddr := proxyServerLn.Addr().String()
|
||||
|
||||
app.Use(Forward("http://" + targetAddr))
|
||||
|
||||
go func() { utils.AssertEqual(t, nil, app.Listener(proxyServerLn)) }()
|
||||
|
||||
code, body, errs := fiber.Get("https://" + proxyAddr).
|
||||
InsecureSkipVerify().
|
||||
Timeout(5 * time.Second).
|
||||
String()
|
||||
|
||||
utils.AssertEqual(t, 0, len(errs))
|
||||
utils.AssertEqual(t, fiber.StatusOK, code)
|
||||
utils.AssertEqual(t, "hello from target", body)
|
||||
}
|
||||
|
||||
// go test -run Test_Proxy_Forward
|
||||
func Test_Proxy_Forward(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
app := fiber.New()
|
||||
|
||||
_, addr := createProxyTestServer(t, func(c *fiber.Ctx) error {
|
||||
return c.SendString("forwarded")
|
||||
})
|
||||
|
||||
app.Use(Forward("http://" + addr))
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
|
||||
|
||||
b, err := io.ReadAll(resp.Body)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, "forwarded", string(b))
|
||||
}
|
||||
|
||||
// go test -run Test_Proxy_Forward_WithTlsConfig
|
||||
func Test_Proxy_Forward_WithTlsConfig(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
serverTLSConf, _, err := tlstest.GetTLSConfigs()
|
||||
utils.AssertEqual(t, nil, err)
|
||||
|
||||
ln, err := net.Listen(fiber.NetworkTCP4, "127.0.0.1:0")
|
||||
utils.AssertEqual(t, nil, err)
|
||||
|
||||
ln = tls.NewListener(ln, serverTLSConf)
|
||||
|
||||
app := fiber.New(fiber.Config{DisableStartupMessage: true})
|
||||
|
||||
app.Get("/tlsfwd", func(c *fiber.Ctx) error {
|
||||
return c.SendString("tls forward")
|
||||
})
|
||||
|
||||
addr := ln.Addr().String()
|
||||
clientTLSConf := &tls.Config{InsecureSkipVerify: true} //nolint:gosec // We're in a test func, so this is fine
|
||||
|
||||
// disable certificate verification
|
||||
WithTlsConfig(clientTLSConf)
|
||||
app.Use(Forward("https://" + addr + "/tlsfwd"))
|
||||
|
||||
go func() { utils.AssertEqual(t, nil, app.Listener(ln)) }()
|
||||
|
||||
code, body, errs := fiber.Get("https://" + addr).TLSConfig(clientTLSConf).String()
|
||||
|
||||
utils.AssertEqual(t, 0, len(errs))
|
||||
utils.AssertEqual(t, fiber.StatusOK, code)
|
||||
utils.AssertEqual(t, "tls forward", body)
|
||||
}
|
||||
|
||||
// go test -run Test_Proxy_Modify_Response
|
||||
func Test_Proxy_Modify_Response(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
_, addr := createProxyTestServer(t, func(c *fiber.Ctx) error {
|
||||
return c.Status(500).SendString("not modified")
|
||||
})
|
||||
|
||||
app := fiber.New()
|
||||
app.Use(Balancer(Config{
|
||||
Servers: []string{addr},
|
||||
ModifyResponse: func(c *fiber.Ctx) error {
|
||||
c.Response().SetStatusCode(fiber.StatusOK)
|
||||
return c.SendString("modified response")
|
||||
},
|
||||
}))
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
|
||||
|
||||
b, err := io.ReadAll(resp.Body)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, "modified response", string(b))
|
||||
}
|
||||
|
||||
// go test -run Test_Proxy_Modify_Request
|
||||
func Test_Proxy_Modify_Request(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
_, addr := createProxyTestServer(t, func(c *fiber.Ctx) error {
|
||||
b := c.Request().Body()
|
||||
return c.SendString(string(b))
|
||||
})
|
||||
|
||||
app := fiber.New()
|
||||
app.Use(Balancer(Config{
|
||||
Servers: []string{addr},
|
||||
ModifyRequest: func(c *fiber.Ctx) error {
|
||||
c.Request().SetBody([]byte("modified request"))
|
||||
return nil
|
||||
},
|
||||
}))
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
|
||||
|
||||
b, err := io.ReadAll(resp.Body)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, "modified request", string(b))
|
||||
}
|
||||
|
||||
// go test -run Test_Proxy_Timeout_Slow_Server
|
||||
func Test_Proxy_Timeout_Slow_Server(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
_, addr := createProxyTestServer(t, func(c *fiber.Ctx) error {
|
||||
time.Sleep(2 * time.Second)
|
||||
return c.SendString("fiber is awesome")
|
||||
})
|
||||
|
||||
app := fiber.New()
|
||||
app.Use(Balancer(Config{
|
||||
Servers: []string{addr},
|
||||
Timeout: 3 * time.Second,
|
||||
}))
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil), 5000)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
|
||||
|
||||
b, err := io.ReadAll(resp.Body)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, "fiber is awesome", string(b))
|
||||
}
|
||||
|
||||
// go test -run Test_Proxy_With_Timeout
|
||||
func Test_Proxy_With_Timeout(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
_, addr := createProxyTestServer(t, func(c *fiber.Ctx) error {
|
||||
time.Sleep(1 * time.Second)
|
||||
return c.SendString("fiber is awesome")
|
||||
})
|
||||
|
||||
app := fiber.New()
|
||||
app.Use(Balancer(Config{
|
||||
Servers: []string{addr},
|
||||
Timeout: 100 * time.Millisecond,
|
||||
}))
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil), 2000)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, fiber.StatusInternalServerError, resp.StatusCode)
|
||||
|
||||
b, err := io.ReadAll(resp.Body)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, "timeout", string(b))
|
||||
}
|
||||
|
||||
// go test -run Test_Proxy_Buffer_Size_Response
|
||||
func Test_Proxy_Buffer_Size_Response(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
_, addr := createProxyTestServer(t, func(c *fiber.Ctx) error {
|
||||
long := strings.Join(make([]string, 5000), "-")
|
||||
c.Set("Very-Long-Header", long)
|
||||
return c.SendString("ok")
|
||||
})
|
||||
|
||||
app := fiber.New()
|
||||
app.Use(Balancer(Config{Servers: []string{addr}}))
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, fiber.StatusInternalServerError, resp.StatusCode)
|
||||
|
||||
app = fiber.New()
|
||||
app.Use(Balancer(Config{
|
||||
Servers: []string{addr},
|
||||
ReadBufferSize: 1024 * 8,
|
||||
}))
|
||||
|
||||
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
|
||||
}
|
||||
|
||||
// go test -race -run Test_Proxy_Do_RestoreOriginalURL
|
||||
func Test_Proxy_Do_RestoreOriginalURL(t *testing.T) {
|
||||
t.Parallel()
|
||||
_, addr := createProxyTestServer(t, func(c *fiber.Ctx) error {
|
||||
return c.SendString("proxied")
|
||||
})
|
||||
|
||||
app := fiber.New()
|
||||
app.Get("/test", func(c *fiber.Ctx) error {
|
||||
return Do(c, "http://"+addr)
|
||||
})
|
||||
resp, err1 := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", nil))
|
||||
utils.AssertEqual(t, nil, err1)
|
||||
utils.AssertEqual(t, "/test", resp.Request.URL.String())
|
||||
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, "proxied", string(body))
|
||||
}
|
||||
|
||||
// go test -race -run Test_Proxy_Do_WithRealURL
|
||||
func Test_Proxy_Do_WithRealURL(t *testing.T) {
|
||||
t.Parallel()
|
||||
app := fiber.New()
|
||||
app.Get("/test", func(c *fiber.Ctx) error {
|
||||
return Do(c, "https://www.google.com")
|
||||
})
|
||||
|
||||
resp, err1 := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", nil))
|
||||
utils.AssertEqual(t, nil, err1)
|
||||
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
|
||||
utils.AssertEqual(t, "/test", resp.Request.URL.String())
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, true, strings.Contains(string(body), "https://www.google.com/"))
|
||||
}
|
||||
|
||||
// go test -race -run Test_Proxy_Do_WithRedirect
|
||||
func Test_Proxy_Do_WithRedirect(t *testing.T) {
|
||||
t.Parallel()
|
||||
app := fiber.New()
|
||||
app.Get("/test", func(c *fiber.Ctx) error {
|
||||
return Do(c, "https://google.com")
|
||||
})
|
||||
|
||||
resp, err1 := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", nil))
|
||||
utils.AssertEqual(t, nil, err1)
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, true, strings.Contains(string(body), "https://www.google.com/"))
|
||||
utils.AssertEqual(t, 301, resp.StatusCode)
|
||||
}
|
||||
|
||||
// go test -race -run Test_Proxy_DoRedirects_RestoreOriginalURL
|
||||
func Test_Proxy_DoRedirects_RestoreOriginalURL(t *testing.T) {
|
||||
t.Parallel()
|
||||
app := fiber.New()
|
||||
app.Get("/test", func(c *fiber.Ctx) error {
|
||||
return DoRedirects(c, "http://google.com", 1)
|
||||
})
|
||||
|
||||
resp, err1 := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", nil))
|
||||
utils.AssertEqual(t, nil, err1)
|
||||
_, err := io.ReadAll(resp.Body)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
|
||||
utils.AssertEqual(t, "/test", resp.Request.URL.String())
|
||||
}
|
||||
|
||||
// go test -race -run Test_Proxy_DoRedirects_TooManyRedirects
|
||||
func Test_Proxy_DoRedirects_TooManyRedirects(t *testing.T) {
|
||||
t.Parallel()
|
||||
app := fiber.New()
|
||||
app.Get("/test", func(c *fiber.Ctx) error {
|
||||
return DoRedirects(c, "http://google.com", 0)
|
||||
})
|
||||
|
||||
resp, err1 := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", nil))
|
||||
utils.AssertEqual(t, nil, err1)
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, "too many redirects detected when doing the request", string(body))
|
||||
utils.AssertEqual(t, fiber.StatusInternalServerError, resp.StatusCode)
|
||||
utils.AssertEqual(t, "/test", resp.Request.URL.String())
|
||||
}
|
||||
|
||||
// go test -race -run Test_Proxy_DoTimeout_RestoreOriginalURL
|
||||
func Test_Proxy_DoTimeout_RestoreOriginalURL(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
_, addr := createProxyTestServer(t, func(c *fiber.Ctx) error {
|
||||
return c.SendString("proxied")
|
||||
})
|
||||
|
||||
app := fiber.New()
|
||||
app.Get("/test", func(c *fiber.Ctx) error {
|
||||
return DoTimeout(c, "http://"+addr, time.Second)
|
||||
})
|
||||
|
||||
resp, err1 := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", nil))
|
||||
utils.AssertEqual(t, nil, err1)
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, "proxied", string(body))
|
||||
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
|
||||
utils.AssertEqual(t, "/test", resp.Request.URL.String())
|
||||
}
|
||||
|
||||
// go test -race -run Test_Proxy_DoTimeout_Timeout
|
||||
func Test_Proxy_DoTimeout_Timeout(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
_, addr := createProxyTestServer(t, func(c *fiber.Ctx) error {
|
||||
time.Sleep(time.Second * 5)
|
||||
return c.SendString("proxied")
|
||||
})
|
||||
|
||||
app := fiber.New()
|
||||
app.Get("/test", func(c *fiber.Ctx) error {
|
||||
return DoTimeout(c, "http://"+addr, time.Second)
|
||||
})
|
||||
|
||||
_, err1 := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", nil))
|
||||
utils.AssertEqual(t, errors.New("test: timeout error 1000ms"), err1)
|
||||
}
|
||||
|
||||
// go test -race -run Test_Proxy_DoDeadline_RestoreOriginalURL
|
||||
func Test_Proxy_DoDeadline_RestoreOriginalURL(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
_, addr := createProxyTestServer(t, func(c *fiber.Ctx) error {
|
||||
return c.SendString("proxied")
|
||||
})
|
||||
|
||||
app := fiber.New()
|
||||
app.Get("/test", func(c *fiber.Ctx) error {
|
||||
return DoDeadline(c, "http://"+addr, time.Now().Add(time.Second))
|
||||
})
|
||||
|
||||
resp, err1 := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", nil))
|
||||
utils.AssertEqual(t, nil, err1)
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, "proxied", string(body))
|
||||
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
|
||||
utils.AssertEqual(t, "/test", resp.Request.URL.String())
|
||||
}
|
||||
|
||||
// go test -race -run Test_Proxy_DoDeadline_PastDeadline
|
||||
func Test_Proxy_DoDeadline_PastDeadline(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
_, addr := createProxyTestServer(t, func(c *fiber.Ctx) error {
|
||||
time.Sleep(time.Second * 5)
|
||||
return c.SendString("proxied")
|
||||
})
|
||||
|
||||
app := fiber.New()
|
||||
app.Get("/test", func(c *fiber.Ctx) error {
|
||||
return DoDeadline(c, "http://"+addr, time.Now().Add(time.Second))
|
||||
})
|
||||
|
||||
_, err1 := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", nil))
|
||||
utils.AssertEqual(t, errors.New("test: timeout error 1000ms"), err1)
|
||||
}
|
||||
|
||||
// go test -race -run Test_Proxy_Do_HTTP_Prefix_URL
|
||||
func Test_Proxy_Do_HTTP_Prefix_URL(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
_, addr := createProxyTestServer(t, func(c *fiber.Ctx) error {
|
||||
return c.SendString("hello world")
|
||||
})
|
||||
|
||||
app := fiber.New(fiber.Config{DisableStartupMessage: true})
|
||||
app.Get("/*", func(c *fiber.Ctx) error {
|
||||
path := c.OriginalURL()
|
||||
url := strings.TrimPrefix(path, "/")
|
||||
|
||||
utils.AssertEqual(t, "http://"+addr, url)
|
||||
if err := Do(c, url); err != nil {
|
||||
return err
|
||||
}
|
||||
c.Response().Header.Del(fiber.HeaderServer)
|
||||
return nil
|
||||
})
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/http://"+addr, nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
s, err := io.ReadAll(resp.Body)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, "hello world", string(s))
|
||||
}
|
||||
|
||||
// go test -race -run Test_Proxy_Forward_Global_Client
|
||||
func Test_Proxy_Forward_Global_Client(t *testing.T) {
|
||||
t.Parallel()
|
||||
ln, err := net.Listen(fiber.NetworkTCP4, "127.0.0.1:0")
|
||||
utils.AssertEqual(t, nil, err)
|
||||
WithClient(&fasthttp.Client{
|
||||
NoDefaultUserAgentHeader: true,
|
||||
DisablePathNormalizing: true,
|
||||
})
|
||||
app := fiber.New(fiber.Config{DisableStartupMessage: true})
|
||||
app.Get("/test_global_client", func(c *fiber.Ctx) error {
|
||||
return c.SendString("test_global_client")
|
||||
})
|
||||
|
||||
addr := ln.Addr().String()
|
||||
app.Use(Forward("http://" + addr + "/test_global_client"))
|
||||
go func() { utils.AssertEqual(t, nil, app.Listener(ln)) }()
|
||||
|
||||
code, body, errs := fiber.Get("http://" + addr).String()
|
||||
utils.AssertEqual(t, 0, len(errs))
|
||||
utils.AssertEqual(t, fiber.StatusOK, code)
|
||||
utils.AssertEqual(t, "test_global_client", body)
|
||||
}
|
||||
|
||||
// go test -race -run Test_Proxy_Forward_Local_Client
|
||||
func Test_Proxy_Forward_Local_Client(t *testing.T) {
|
||||
t.Parallel()
|
||||
ln, err := net.Listen(fiber.NetworkTCP4, "127.0.0.1:0")
|
||||
utils.AssertEqual(t, nil, err)
|
||||
app := fiber.New(fiber.Config{DisableStartupMessage: true})
|
||||
app.Get("/test_local_client", func(c *fiber.Ctx) error {
|
||||
return c.SendString("test_local_client")
|
||||
})
|
||||
|
||||
addr := ln.Addr().String()
|
||||
app.Use(Forward("http://"+addr+"/test_local_client", &fasthttp.Client{
|
||||
NoDefaultUserAgentHeader: true,
|
||||
DisablePathNormalizing: true,
|
||||
|
||||
Dial: fasthttp.Dial,
|
||||
}))
|
||||
go func() { utils.AssertEqual(t, nil, app.Listener(ln)) }()
|
||||
|
||||
code, body, errs := fiber.Get("http://" + addr).String()
|
||||
utils.AssertEqual(t, 0, len(errs))
|
||||
utils.AssertEqual(t, fiber.StatusOK, code)
|
||||
utils.AssertEqual(t, "test_local_client", body)
|
||||
}
|
||||
|
||||
// go test -run Test_ProxyBalancer_Custom_Client
|
||||
func Test_ProxyBalancer_Custom_Client(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
target, addr := createProxyTestServer(t, func(c *fiber.Ctx) error {
|
||||
return c.SendStatus(fiber.StatusTeapot)
|
||||
})
|
||||
|
||||
resp, err := target.Test(httptest.NewRequest(fiber.MethodGet, "/", nil), 2000)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, fiber.StatusTeapot, resp.StatusCode)
|
||||
|
||||
app := fiber.New(fiber.Config{DisableStartupMessage: true})
|
||||
|
||||
app.Use(Balancer(Config{Client: &fasthttp.LBClient{
|
||||
Clients: []fasthttp.BalancingClient{
|
||||
&fasthttp.HostClient{
|
||||
NoDefaultUserAgentHeader: true,
|
||||
DisablePathNormalizing: true,
|
||||
Addr: addr,
|
||||
},
|
||||
},
|
||||
Timeout: time.Second,
|
||||
}}))
|
||||
|
||||
req := httptest.NewRequest(fiber.MethodGet, "/", nil)
|
||||
req.Host = addr
|
||||
resp, err = app.Test(req)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, fiber.StatusTeapot, resp.StatusCode)
|
||||
}
|
||||
|
||||
// go test -run Test_Proxy_Domain_Forward_Local
|
||||
func Test_Proxy_Domain_Forward_Local(t *testing.T) {
|
||||
t.Parallel()
|
||||
ln, err := net.Listen(fiber.NetworkTCP4, "127.0.0.1:0")
|
||||
utils.AssertEqual(t, nil, err)
|
||||
app := fiber.New(fiber.Config{DisableStartupMessage: true})
|
||||
|
||||
// target server
|
||||
ln1, err := net.Listen(fiber.NetworkTCP4, "127.0.0.1:0")
|
||||
utils.AssertEqual(t, nil, err)
|
||||
app1 := fiber.New(fiber.Config{DisableStartupMessage: true})
|
||||
|
||||
app1.Get("/test", func(c *fiber.Ctx) error {
|
||||
return c.SendString("test_local_client:" + c.Query("query_test"))
|
||||
})
|
||||
|
||||
proxyAddr := ln.Addr().String()
|
||||
targetAddr := ln1.Addr().String()
|
||||
localDomain := strings.Replace(proxyAddr, "127.0.0.1", "localhost", 1)
|
||||
app.Use(DomainForward(localDomain, "http://"+targetAddr, &fasthttp.Client{
|
||||
NoDefaultUserAgentHeader: true,
|
||||
DisablePathNormalizing: true,
|
||||
|
||||
Dial: fasthttp.Dial,
|
||||
}))
|
||||
|
||||
go func() { utils.AssertEqual(t, nil, app.Listener(ln)) }()
|
||||
go func() { utils.AssertEqual(t, nil, app1.Listener(ln1)) }()
|
||||
|
||||
code, body, errs := fiber.Get("http://" + localDomain + "/test?query_test=true").String()
|
||||
utils.AssertEqual(t, 0, len(errs))
|
||||
utils.AssertEqual(t, fiber.StatusOK, code)
|
||||
utils.AssertEqual(t, "test_local_client:true", body)
|
||||
}
|
||||
|
||||
// go test -run Test_Proxy_Balancer_Forward_Local
|
||||
func Test_Proxy_Balancer_Forward_Local(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
app := fiber.New()
|
||||
|
||||
_, addr := createProxyTestServer(t, func(c *fiber.Ctx) error {
|
||||
return c.SendString("forwarded")
|
||||
})
|
||||
|
||||
app.Use(BalancerForward([]string{addr}))
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
|
||||
|
||||
b, err := io.ReadAll(resp.Body)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
|
||||
utils.AssertEqual(t, string(b), "forwarded")
|
||||
}
|
47
middleware/recover/config.go
Normal file
47
middleware/recover/config.go
Normal file
|
@ -0,0 +1,47 @@
|
|||
package recover //nolint:predeclared // TODO: Rename to some non-builtin
|
||||
|
||||
import (
|
||||
"github.com/gofiber/fiber/v2"
|
||||
)
|
||||
|
||||
// Config defines the config for middleware.
|
||||
type Config struct {
|
||||
// Next defines a function to skip this middleware when returned true.
|
||||
//
|
||||
// Optional. Default: nil
|
||||
Next func(c *fiber.Ctx) bool
|
||||
|
||||
// EnableStackTrace enables handling stack trace
|
||||
//
|
||||
// Optional. Default: false
|
||||
EnableStackTrace bool
|
||||
|
||||
// StackTraceHandler defines a function to handle stack trace
|
||||
//
|
||||
// Optional. Default: defaultStackTraceHandler
|
||||
StackTraceHandler func(c *fiber.Ctx, e interface{})
|
||||
}
|
||||
|
||||
// ConfigDefault is the default config
|
||||
var ConfigDefault = Config{
|
||||
Next: nil,
|
||||
EnableStackTrace: false,
|
||||
StackTraceHandler: defaultStackTraceHandler,
|
||||
}
|
||||
|
||||
// Helper function to set default values
|
||||
func configDefault(config ...Config) Config {
|
||||
// Return default config if nothing provided
|
||||
if len(config) < 1 {
|
||||
return ConfigDefault
|
||||
}
|
||||
|
||||
// Override default config
|
||||
cfg := config[0]
|
||||
|
||||
if cfg.EnableStackTrace && cfg.StackTraceHandler == nil {
|
||||
cfg.StackTraceHandler = defaultStackTraceHandler
|
||||
}
|
||||
|
||||
return cfg
|
||||
}
|
45
middleware/recover/recover.go
Normal file
45
middleware/recover/recover.go
Normal file
|
@ -0,0 +1,45 @@
|
|||
package recover //nolint:predeclared // TODO: Rename to some non-builtin
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"runtime/debug"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
)
|
||||
|
||||
func defaultStackTraceHandler(_ *fiber.Ctx, e interface{}) {
|
||||
_, _ = os.Stderr.WriteString(fmt.Sprintf("panic: %v\n%s\n", e, debug.Stack())) //nolint:errcheck // This will never fail
|
||||
}
|
||||
|
||||
// New creates a new middleware handler
|
||||
func New(config ...Config) fiber.Handler {
|
||||
// Set default config
|
||||
cfg := configDefault(config...)
|
||||
|
||||
// Return new handler
|
||||
return func(c *fiber.Ctx) (err error) { //nolint:nonamedreturns // Uses recover() to overwrite the error
|
||||
// Don't execute middleware if Next returns true
|
||||
if cfg.Next != nil && cfg.Next(c) {
|
||||
return c.Next()
|
||||
}
|
||||
|
||||
// Catch panics
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
if cfg.EnableStackTrace {
|
||||
cfg.StackTraceHandler(c, r)
|
||||
}
|
||||
|
||||
var ok bool
|
||||
if err, ok = r.(error); !ok {
|
||||
// Set error that will call the global error handler
|
||||
err = fmt.Errorf("%v", r)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// Return err if exist, else move to next handler
|
||||
return c.Next()
|
||||
}
|
||||
}
|
61
middleware/recover/recover_test.go
Normal file
61
middleware/recover/recover_test.go
Normal file
|
@ -0,0 +1,61 @@
|
|||
package recover //nolint:predeclared // TODO: Rename to some non-builtin
|
||||
|
||||
import (
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/gofiber/fiber/v2/utils"
|
||||
)
|
||||
|
||||
// go test -run Test_Recover
|
||||
func Test_Recover(t *testing.T) {
|
||||
t.Parallel()
|
||||
app := fiber.New(fiber.Config{
|
||||
ErrorHandler: func(c *fiber.Ctx, err error) error {
|
||||
utils.AssertEqual(t, "Hi, I'm an error!", err.Error())
|
||||
return c.SendStatus(fiber.StatusTeapot)
|
||||
},
|
||||
})
|
||||
|
||||
app.Use(New())
|
||||
|
||||
app.Get("/panic", func(c *fiber.Ctx) error {
|
||||
panic("Hi, I'm an error!")
|
||||
})
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/panic", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, fiber.StatusTeapot, resp.StatusCode)
|
||||
}
|
||||
|
||||
// go test -run Test_Recover_Next
|
||||
func Test_Recover_Next(t *testing.T) {
|
||||
t.Parallel()
|
||||
app := fiber.New()
|
||||
app.Use(New(Config{
|
||||
Next: func(_ *fiber.Ctx) bool {
|
||||
return true
|
||||
},
|
||||
}))
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, fiber.StatusNotFound, resp.StatusCode)
|
||||
}
|
||||
|
||||
func Test_Recover_EnableStackTrace(t *testing.T) {
|
||||
t.Parallel()
|
||||
app := fiber.New()
|
||||
app.Use(New(Config{
|
||||
EnableStackTrace: true,
|
||||
}))
|
||||
|
||||
app.Get("/panic", func(c *fiber.Ctx) error {
|
||||
panic("Hi, I'm an error!")
|
||||
})
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/panic", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, fiber.StatusInternalServerError, resp.StatusCode)
|
||||
}
|
53
middleware/redirect/config.go
Normal file
53
middleware/redirect/config.go
Normal file
|
@ -0,0 +1,53 @@
|
|||
package redirect
|
||||
|
||||
import (
|
||||
"regexp"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
)
|
||||
|
||||
// Config defines the config for middleware.
|
||||
type Config struct {
|
||||
// Filter defines a function to skip middleware.
|
||||
// Optional. Default: nil
|
||||
Next func(*fiber.Ctx) bool
|
||||
|
||||
// Rules defines the URL path rewrite rules. The values captured in asterisk can be
|
||||
// retrieved by index e.g. $1, $2 and so on.
|
||||
// Required. Example:
|
||||
// "/old": "/new",
|
||||
// "/api/*": "/$1",
|
||||
// "/js/*": "/public/javascripts/$1",
|
||||
// "/users/*/orders/*": "/user/$1/order/$2",
|
||||
Rules map[string]string
|
||||
|
||||
// The status code when redirecting
|
||||
// This is ignored if Redirect is disabled
|
||||
// Optional. Default: 302 Temporary Redirect
|
||||
StatusCode int
|
||||
|
||||
rulesRegex map[*regexp.Regexp]string
|
||||
}
|
||||
|
||||
// ConfigDefault is the default config
|
||||
var ConfigDefault = Config{
|
||||
StatusCode: fiber.StatusFound,
|
||||
}
|
||||
|
||||
// Helper function to set default values
|
||||
func configDefault(config ...Config) Config {
|
||||
// Return default config if nothing provided
|
||||
if len(config) < 1 {
|
||||
return ConfigDefault
|
||||
}
|
||||
|
||||
// Override default config
|
||||
cfg := config[0]
|
||||
|
||||
// Set default values
|
||||
if cfg.StatusCode == 0 {
|
||||
cfg.StatusCode = ConfigDefault.StatusCode
|
||||
}
|
||||
|
||||
return cfg
|
||||
}
|
61
middleware/redirect/redirect.go
Normal file
61
middleware/redirect/redirect.go
Normal file
|
@ -0,0 +1,61 @@
|
|||
package redirect
|
||||
|
||||
import (
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
)
|
||||
|
||||
// New creates a new middleware handler
|
||||
func New(config ...Config) fiber.Handler {
|
||||
cfg := configDefault(config...)
|
||||
|
||||
// Initialize
|
||||
cfg.rulesRegex = map[*regexp.Regexp]string{}
|
||||
for k, v := range cfg.Rules {
|
||||
k = strings.ReplaceAll(k, "*", "(.*)")
|
||||
k += "$"
|
||||
cfg.rulesRegex[regexp.MustCompile(k)] = v
|
||||
}
|
||||
|
||||
// Middleware function
|
||||
return func(c *fiber.Ctx) error {
|
||||
// Next request to skip middleware
|
||||
if cfg.Next != nil && cfg.Next(c) {
|
||||
return c.Next()
|
||||
}
|
||||
// Rewrite
|
||||
for k, v := range cfg.rulesRegex {
|
||||
replacer := captureTokens(k, c.Path())
|
||||
if replacer != nil {
|
||||
queryString := string(c.Context().QueryArgs().QueryString())
|
||||
if queryString != "" {
|
||||
queryString = "?" + queryString
|
||||
}
|
||||
return c.Redirect(replacer.Replace(v)+queryString, cfg.StatusCode)
|
||||
}
|
||||
}
|
||||
return c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// https://github.com/labstack/echo/blob/master/middleware/rewrite.go
|
||||
func captureTokens(pattern *regexp.Regexp, input string) *strings.Replacer {
|
||||
if len(input) > 1 {
|
||||
input = strings.TrimSuffix(input, "/")
|
||||
}
|
||||
groups := pattern.FindAllStringSubmatch(input, -1)
|
||||
if groups == nil {
|
||||
return nil
|
||||
}
|
||||
values := groups[0][1:]
|
||||
replace := make([]string, 2*len(values))
|
||||
for i, v := range values {
|
||||
j := 2 * i
|
||||
replace[j] = "$" + strconv.Itoa(i+1)
|
||||
replace[j+1] = v
|
||||
}
|
||||
return strings.NewReplacer(replace...)
|
||||
}
|
295
middleware/redirect/redirect_test.go
Normal file
295
middleware/redirect/redirect_test.go
Normal file
|
@ -0,0 +1,295 @@
|
|||
//nolint:bodyclose // Much easier to just ignore memory leaks in tests
|
||||
package redirect
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/gofiber/fiber/v2/utils"
|
||||
)
|
||||
|
||||
func Test_Redirect(t *testing.T) {
|
||||
app := *fiber.New()
|
||||
|
||||
app.Use(New(Config{
|
||||
Rules: map[string]string{
|
||||
"/default": "google.com",
|
||||
},
|
||||
StatusCode: fiber.StatusMovedPermanently,
|
||||
}))
|
||||
app.Use(New(Config{
|
||||
Rules: map[string]string{
|
||||
"/default/*": "fiber.wiki",
|
||||
},
|
||||
StatusCode: fiber.StatusTemporaryRedirect,
|
||||
}))
|
||||
app.Use(New(Config{
|
||||
Rules: map[string]string{
|
||||
"/redirect/*": "$1",
|
||||
},
|
||||
StatusCode: fiber.StatusSeeOther,
|
||||
}))
|
||||
app.Use(New(Config{
|
||||
Rules: map[string]string{
|
||||
"/pattern/*": "golang.org",
|
||||
},
|
||||
StatusCode: fiber.StatusFound,
|
||||
}))
|
||||
|
||||
app.Use(New(Config{
|
||||
Rules: map[string]string{
|
||||
"/": "/swagger",
|
||||
},
|
||||
StatusCode: fiber.StatusMovedPermanently,
|
||||
}))
|
||||
app.Use(New(Config{
|
||||
Rules: map[string]string{
|
||||
"/params": "/with_params",
|
||||
},
|
||||
StatusCode: fiber.StatusMovedPermanently,
|
||||
}))
|
||||
|
||||
app.Get("/api/*", func(c *fiber.Ctx) error {
|
||||
return c.SendString("API")
|
||||
})
|
||||
|
||||
app.Get("/new", func(c *fiber.Ctx) error {
|
||||
return c.SendString("Hello, World!")
|
||||
})
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
url string
|
||||
redirectTo string
|
||||
statusCode int
|
||||
}{
|
||||
{
|
||||
name: "should be returns status StatusFound without a wildcard",
|
||||
url: "/default",
|
||||
redirectTo: "google.com",
|
||||
statusCode: fiber.StatusMovedPermanently,
|
||||
},
|
||||
{
|
||||
name: "should be returns status StatusTemporaryRedirect using wildcard",
|
||||
url: "/default/xyz",
|
||||
redirectTo: "fiber.wiki",
|
||||
statusCode: fiber.StatusTemporaryRedirect,
|
||||
},
|
||||
{
|
||||
name: "should be returns status StatusSeeOther without set redirectTo to use the default",
|
||||
url: "/redirect/github.com/gofiber/redirect",
|
||||
redirectTo: "github.com/gofiber/redirect",
|
||||
statusCode: fiber.StatusSeeOther,
|
||||
},
|
||||
{
|
||||
name: "should return the status code default",
|
||||
url: "/pattern/xyz",
|
||||
redirectTo: "golang.org",
|
||||
statusCode: fiber.StatusFound,
|
||||
},
|
||||
{
|
||||
name: "access URL without rule",
|
||||
url: "/new",
|
||||
statusCode: fiber.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "redirect to swagger route",
|
||||
url: "/",
|
||||
redirectTo: "/swagger",
|
||||
statusCode: fiber.StatusMovedPermanently,
|
||||
},
|
||||
{
|
||||
name: "no redirect to swagger route",
|
||||
url: "/api/",
|
||||
statusCode: fiber.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "no redirect to swagger route #2",
|
||||
url: "/api/test",
|
||||
statusCode: fiber.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "redirect with query params",
|
||||
url: "/params?query=abc",
|
||||
redirectTo: "/with_params?query=abc",
|
||||
statusCode: fiber.StatusMovedPermanently,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req, err := http.NewRequestWithContext(context.Background(), fiber.MethodGet, tt.url, nil)
|
||||
utils.AssertEqual(t, err, nil)
|
||||
req.Header.Set("Location", "github.com/gofiber/redirect")
|
||||
resp, err := app.Test(req)
|
||||
|
||||
utils.AssertEqual(t, err, nil)
|
||||
utils.AssertEqual(t, tt.statusCode, resp.StatusCode)
|
||||
utils.AssertEqual(t, tt.redirectTo, resp.Header.Get("Location"))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_Next(t *testing.T) {
|
||||
// Case 1 : Next function always returns true
|
||||
app := *fiber.New()
|
||||
app.Use(New(Config{
|
||||
Next: func(*fiber.Ctx) bool {
|
||||
return true
|
||||
},
|
||||
Rules: map[string]string{
|
||||
"/default": "google.com",
|
||||
},
|
||||
StatusCode: fiber.StatusMovedPermanently,
|
||||
}))
|
||||
|
||||
app.Use(func(c *fiber.Ctx) error {
|
||||
return c.SendStatus(fiber.StatusOK)
|
||||
})
|
||||
|
||||
req, err := http.NewRequestWithContext(context.Background(), fiber.MethodGet, "/default", nil)
|
||||
utils.AssertEqual(t, err, nil)
|
||||
resp, err := app.Test(req)
|
||||
utils.AssertEqual(t, err, nil)
|
||||
|
||||
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
|
||||
|
||||
// Case 2 : Next function always returns false
|
||||
app = *fiber.New()
|
||||
app.Use(New(Config{
|
||||
Next: func(*fiber.Ctx) bool {
|
||||
return false
|
||||
},
|
||||
Rules: map[string]string{
|
||||
"/default": "google.com",
|
||||
},
|
||||
StatusCode: fiber.StatusMovedPermanently,
|
||||
}))
|
||||
|
||||
req, err = http.NewRequestWithContext(context.Background(), fiber.MethodGet, "/default", nil)
|
||||
utils.AssertEqual(t, err, nil)
|
||||
resp, err = app.Test(req)
|
||||
utils.AssertEqual(t, err, nil)
|
||||
|
||||
utils.AssertEqual(t, fiber.StatusMovedPermanently, resp.StatusCode)
|
||||
utils.AssertEqual(t, "google.com", resp.Header.Get("Location"))
|
||||
}
|
||||
|
||||
func Test_NoRules(t *testing.T) {
|
||||
// Case 1: No rules with default route defined
|
||||
app := *fiber.New()
|
||||
|
||||
app.Use(New(Config{
|
||||
StatusCode: fiber.StatusMovedPermanently,
|
||||
}))
|
||||
|
||||
app.Use(func(c *fiber.Ctx) error {
|
||||
return c.SendStatus(fiber.StatusOK)
|
||||
})
|
||||
|
||||
req, err := http.NewRequestWithContext(context.Background(), fiber.MethodGet, "/default", nil)
|
||||
utils.AssertEqual(t, err, nil)
|
||||
resp, err := app.Test(req)
|
||||
utils.AssertEqual(t, err, nil)
|
||||
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
|
||||
|
||||
// Case 2: No rules and no default route defined
|
||||
app = *fiber.New()
|
||||
|
||||
app.Use(New(Config{
|
||||
StatusCode: fiber.StatusMovedPermanently,
|
||||
}))
|
||||
|
||||
req, err = http.NewRequestWithContext(context.Background(), fiber.MethodGet, "/default", nil)
|
||||
utils.AssertEqual(t, err, nil)
|
||||
resp, err = app.Test(req)
|
||||
utils.AssertEqual(t, err, nil)
|
||||
utils.AssertEqual(t, fiber.StatusNotFound, resp.StatusCode)
|
||||
}
|
||||
|
||||
func Test_DefaultConfig(t *testing.T) {
|
||||
// Case 1: Default config and no default route
|
||||
app := *fiber.New()
|
||||
|
||||
app.Use(New())
|
||||
|
||||
req, err := http.NewRequestWithContext(context.Background(), fiber.MethodGet, "/default", nil)
|
||||
utils.AssertEqual(t, err, nil)
|
||||
resp, err := app.Test(req)
|
||||
|
||||
utils.AssertEqual(t, err, nil)
|
||||
utils.AssertEqual(t, fiber.StatusNotFound, resp.StatusCode)
|
||||
|
||||
// Case 2: Default config and default route
|
||||
app = *fiber.New()
|
||||
|
||||
app.Use(New())
|
||||
app.Use(func(c *fiber.Ctx) error {
|
||||
return c.SendStatus(fiber.StatusOK)
|
||||
})
|
||||
|
||||
req, err = http.NewRequestWithContext(context.Background(), fiber.MethodGet, "/default", nil)
|
||||
utils.AssertEqual(t, err, nil)
|
||||
resp, err = app.Test(req)
|
||||
|
||||
utils.AssertEqual(t, err, nil)
|
||||
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
|
||||
}
|
||||
|
||||
func Test_RegexRules(t *testing.T) {
|
||||
// Case 1: Rules regex is empty
|
||||
app := *fiber.New()
|
||||
app.Use(New(Config{
|
||||
Rules: map[string]string{},
|
||||
StatusCode: fiber.StatusMovedPermanently,
|
||||
}))
|
||||
|
||||
app.Use(func(c *fiber.Ctx) error {
|
||||
return c.SendStatus(fiber.StatusOK)
|
||||
})
|
||||
|
||||
req, err := http.NewRequestWithContext(context.Background(), fiber.MethodGet, "/default", nil)
|
||||
utils.AssertEqual(t, err, nil)
|
||||
resp, err := app.Test(req)
|
||||
|
||||
utils.AssertEqual(t, err, nil)
|
||||
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
|
||||
|
||||
// Case 2: Rules regex map contains valid regex and well-formed replacement URLs
|
||||
app = *fiber.New()
|
||||
app.Use(New(Config{
|
||||
Rules: map[string]string{
|
||||
"/default": "google.com",
|
||||
},
|
||||
StatusCode: fiber.StatusMovedPermanently,
|
||||
}))
|
||||
|
||||
app.Use(func(c *fiber.Ctx) error {
|
||||
return c.SendStatus(fiber.StatusOK)
|
||||
})
|
||||
|
||||
req, err = http.NewRequestWithContext(context.Background(), fiber.MethodGet, "/default", nil)
|
||||
utils.AssertEqual(t, err, nil)
|
||||
resp, err = app.Test(req)
|
||||
|
||||
utils.AssertEqual(t, err, nil)
|
||||
utils.AssertEqual(t, fiber.StatusMovedPermanently, resp.StatusCode)
|
||||
utils.AssertEqual(t, "google.com", resp.Header.Get("Location"))
|
||||
|
||||
// Case 3: Test invalid regex throws panic
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
t.Log("Recovered from invalid regex: ", r)
|
||||
}
|
||||
}()
|
||||
|
||||
app = *fiber.New()
|
||||
app.Use(New(Config{
|
||||
Rules: map[string]string{
|
||||
"(": "google.com",
|
||||
},
|
||||
StatusCode: fiber.StatusMovedPermanently,
|
||||
}))
|
||||
t.Error("Expected panic, got nil")
|
||||
}
|
66
middleware/requestid/config.go
Normal file
66
middleware/requestid/config.go
Normal file
|
@ -0,0 +1,66 @@
|
|||
package requestid
|
||||
|
||||
import (
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/gofiber/fiber/v2/utils"
|
||||
)
|
||||
|
||||
// Config defines the config for middleware.
|
||||
type Config struct {
|
||||
// Next defines a function to skip this middleware when returned true.
|
||||
//
|
||||
// Optional. Default: nil
|
||||
Next func(c *fiber.Ctx) bool
|
||||
|
||||
// Header is the header key where to get/set the unique request ID
|
||||
//
|
||||
// Optional. Default: "X-Request-ID"
|
||||
Header string
|
||||
|
||||
// Generator defines a function to generate the unique identifier.
|
||||
//
|
||||
// Optional. Default: utils.UUID
|
||||
Generator func() string
|
||||
|
||||
// ContextKey defines the key used when storing the request ID in
|
||||
// the locals for a specific request.
|
||||
// Should be a private type instead of string, but too many apps probably
|
||||
// rely on this exact value.
|
||||
//
|
||||
// Optional. Default: "requestid"
|
||||
ContextKey interface{}
|
||||
}
|
||||
|
||||
// ConfigDefault is the default config
|
||||
// It uses a fast UUID generator which will expose the number of
|
||||
// requests made to the server. To conceal this value for better
|
||||
// privacy, use the "utils.UUIDv4" generator.
|
||||
var ConfigDefault = Config{
|
||||
Next: nil,
|
||||
Header: fiber.HeaderXRequestID,
|
||||
Generator: utils.UUID,
|
||||
ContextKey: "requestid",
|
||||
}
|
||||
|
||||
// Helper function to set default values
|
||||
func configDefault(config ...Config) Config {
|
||||
// Return default config if nothing provided
|
||||
if len(config) < 1 {
|
||||
return ConfigDefault
|
||||
}
|
||||
|
||||
// Override default config
|
||||
cfg := config[0]
|
||||
|
||||
// Set default values
|
||||
if cfg.Header == "" {
|
||||
cfg.Header = ConfigDefault.Header
|
||||
}
|
||||
if cfg.Generator == nil {
|
||||
cfg.Generator = ConfigDefault.Generator
|
||||
}
|
||||
if cfg.ContextKey == nil {
|
||||
cfg.ContextKey = ConfigDefault.ContextKey
|
||||
}
|
||||
return cfg
|
||||
}
|
33
middleware/requestid/requestid.go
Normal file
33
middleware/requestid/requestid.go
Normal file
|
@ -0,0 +1,33 @@
|
|||
package requestid
|
||||
|
||||
import (
|
||||
"github.com/gofiber/fiber/v2"
|
||||
)
|
||||
|
||||
// New creates a new middleware handler
|
||||
func New(config ...Config) fiber.Handler {
|
||||
// Set default config
|
||||
cfg := configDefault(config...)
|
||||
|
||||
// Return new handler
|
||||
return func(c *fiber.Ctx) error {
|
||||
// Don't execute middleware if Next returns true
|
||||
if cfg.Next != nil && cfg.Next(c) {
|
||||
return c.Next()
|
||||
}
|
||||
// Get id from request, else we generate one
|
||||
rid := c.Get(cfg.Header)
|
||||
if rid == "" {
|
||||
rid = cfg.Generator()
|
||||
}
|
||||
|
||||
// Set new id to response header
|
||||
c.Set(cfg.Header, rid)
|
||||
|
||||
// Add the request ID to locals
|
||||
c.Locals(cfg.ContextKey, rid)
|
||||
|
||||
// Continue stack
|
||||
return c.Next()
|
||||
}
|
||||
}
|
103
middleware/requestid/requestid_test.go
Normal file
103
middleware/requestid/requestid_test.go
Normal file
|
@ -0,0 +1,103 @@
|
|||
package requestid
|
||||
|
||||
import (
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/gofiber/fiber/v2/utils"
|
||||
)
|
||||
|
||||
// go test -run Test_RequestID
|
||||
func Test_RequestID(t *testing.T) {
|
||||
t.Parallel()
|
||||
app := fiber.New()
|
||||
|
||||
app.Use(New())
|
||||
|
||||
app.Get("/", func(c *fiber.Ctx) error {
|
||||
return c.SendString("Hello, World 👋!")
|
||||
})
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
|
||||
|
||||
reqid := resp.Header.Get(fiber.HeaderXRequestID)
|
||||
utils.AssertEqual(t, 36, len(reqid))
|
||||
|
||||
req := httptest.NewRequest(fiber.MethodGet, "/", nil)
|
||||
req.Header.Add(fiber.HeaderXRequestID, reqid)
|
||||
|
||||
resp, err = app.Test(req)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
|
||||
utils.AssertEqual(t, reqid, resp.Header.Get(fiber.HeaderXRequestID))
|
||||
}
|
||||
|
||||
// go test -run Test_RequestID_Next
|
||||
func Test_RequestID_Next(t *testing.T) {
|
||||
t.Parallel()
|
||||
app := fiber.New()
|
||||
app.Use(New(Config{
|
||||
Next: func(_ *fiber.Ctx) bool {
|
||||
return true
|
||||
},
|
||||
}))
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, resp.Header.Get(fiber.HeaderXRequestID), "")
|
||||
utils.AssertEqual(t, fiber.StatusNotFound, resp.StatusCode)
|
||||
}
|
||||
|
||||
// go test -run Test_RequestID_Locals
|
||||
func Test_RequestID_Locals(t *testing.T) {
|
||||
t.Parallel()
|
||||
reqID := "ThisIsARequestId"
|
||||
type ContextKey int
|
||||
const requestContextKey ContextKey = iota
|
||||
|
||||
app := fiber.New()
|
||||
app.Use(New(Config{
|
||||
Generator: func() string {
|
||||
return reqID
|
||||
},
|
||||
ContextKey: requestContextKey,
|
||||
}))
|
||||
|
||||
var ctxVal string
|
||||
|
||||
app.Use(func(c *fiber.Ctx) error {
|
||||
ctxVal = c.Locals(requestContextKey).(string) //nolint:forcetypeassert,errcheck // We always store a string in here
|
||||
return c.Next()
|
||||
})
|
||||
|
||||
_, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, reqID, ctxVal)
|
||||
}
|
||||
|
||||
// go test -run Test_RequestID_DefaultKey
|
||||
func Test_RequestID_DefaultKey(t *testing.T) {
|
||||
t.Parallel()
|
||||
reqID := "ThisIsARequestId"
|
||||
|
||||
app := fiber.New()
|
||||
app.Use(New(Config{
|
||||
Generator: func() string {
|
||||
return reqID
|
||||
},
|
||||
}))
|
||||
|
||||
var ctxVal string
|
||||
|
||||
app.Use(func(c *fiber.Ctx) error {
|
||||
ctxVal = c.Locals("requestid").(string) //nolint:forcetypeassert,errcheck // We always store a string in here
|
||||
return c.Next()
|
||||
})
|
||||
|
||||
_, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, reqID, ctxVal)
|
||||
}
|
38
middleware/rewrite/config.go
Normal file
38
middleware/rewrite/config.go
Normal file
|
@ -0,0 +1,38 @@
|
|||
package rewrite
|
||||
|
||||
import (
|
||||
"regexp"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
)
|
||||
|
||||
// Config defines the config for middleware.
|
||||
type Config struct {
|
||||
// Next defines a function to skip middleware.
|
||||
// Optional. Default: nil
|
||||
Next func(*fiber.Ctx) bool
|
||||
|
||||
// Rules defines the URL path rewrite rules. The values captured in asterisk can be
|
||||
// retrieved by index e.g. $1, $2 and so on.
|
||||
// Required. Example:
|
||||
// "/old": "/new",
|
||||
// "/api/*": "/$1",
|
||||
// "/js/*": "/public/javascripts/$1",
|
||||
// "/users/*/orders/*": "/user/$1/order/$2",
|
||||
Rules map[string]string
|
||||
|
||||
rulesRegex map[*regexp.Regexp]string
|
||||
}
|
||||
|
||||
// Helper function to set default values
|
||||
func configDefault(config ...Config) Config {
|
||||
// Return default config if nothing provided
|
||||
if len(config) < 1 {
|
||||
return Config{}
|
||||
}
|
||||
|
||||
// Override default config
|
||||
cfg := config[0]
|
||||
|
||||
return cfg
|
||||
}
|
54
middleware/rewrite/rewrite.go
Normal file
54
middleware/rewrite/rewrite.go
Normal file
|
@ -0,0 +1,54 @@
|
|||
package rewrite
|
||||
|
||||
import (
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
)
|
||||
|
||||
// New creates a new middleware handler
|
||||
func New(config ...Config) fiber.Handler {
|
||||
cfg := configDefault(config...)
|
||||
|
||||
// Initialize
|
||||
cfg.rulesRegex = map[*regexp.Regexp]string{}
|
||||
for k, v := range cfg.Rules {
|
||||
k = strings.ReplaceAll(k, "*", "(.*)")
|
||||
k += "$"
|
||||
cfg.rulesRegex[regexp.MustCompile(k)] = v
|
||||
}
|
||||
// Middleware function
|
||||
return func(c *fiber.Ctx) error {
|
||||
// Next request to skip middleware
|
||||
if cfg.Next != nil && cfg.Next(c) {
|
||||
return c.Next()
|
||||
}
|
||||
// Rewrite
|
||||
for k, v := range cfg.rulesRegex {
|
||||
replacer := captureTokens(k, c.Path())
|
||||
if replacer != nil {
|
||||
c.Path(replacer.Replace(v))
|
||||
break
|
||||
}
|
||||
}
|
||||
return c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// https://github.com/labstack/echo/blob/master/middleware/rewrite.go
|
||||
func captureTokens(pattern *regexp.Regexp, input string) *strings.Replacer {
|
||||
groups := pattern.FindAllStringSubmatch(input, -1)
|
||||
if groups == nil {
|
||||
return nil
|
||||
}
|
||||
values := groups[0][1:]
|
||||
replace := make([]string, 2*len(values))
|
||||
for i, v := range values {
|
||||
j := 2 * i
|
||||
replace[j] = "$" + strconv.Itoa(i+1)
|
||||
replace[j+1] = v
|
||||
}
|
||||
return strings.NewReplacer(replace...)
|
||||
}
|
Some files were not shown because too many files have changed in this diff Show more
Loading…
Add table
Add a link
Reference in a new issue