Adding upstream version 0.28.1.
Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
parent
88f1d47ab6
commit
e28c88ef14
933 changed files with 194711 additions and 0 deletions
231
tools/router/error.go
Normal file
231
tools/router/error.go
Normal file
|
@ -0,0 +1,231 @@
|
|||
package router
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"io/fs"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
validation "github.com/go-ozzo/ozzo-validation/v4"
|
||||
"github.com/pocketbase/pocketbase/tools/inflector"
|
||||
)
|
||||
|
||||
// SafeErrorItem defines a common error interface for a printable public safe error.
|
||||
type SafeErrorItem interface {
|
||||
// Code represents a fixed unique identifier of the error (usually used as translation key).
|
||||
Code() string
|
||||
|
||||
// Error is the default English human readable error message that will be returned.
|
||||
Error() string
|
||||
}
|
||||
|
||||
// SafeErrorParamsResolver defines an optional interface for specifying dynamic error parameters.
|
||||
type SafeErrorParamsResolver interface {
|
||||
// Params defines a map with dynamic parameters to return as part of the public safe error view.
|
||||
Params() map[string]any
|
||||
}
|
||||
|
||||
// SafeErrorResolver defines an error interface for resolving the public safe error fields.
|
||||
type SafeErrorResolver interface {
|
||||
// Resolve allows modifying and returning a new public safe error data map.
|
||||
Resolve(errData map[string]any) any
|
||||
}
|
||||
|
||||
// ApiError defines the struct for a basic api error response.
|
||||
type ApiError struct {
|
||||
rawData any
|
||||
|
||||
Data map[string]any `json:"data"`
|
||||
Message string `json:"message"`
|
||||
Status int `json:"status"`
|
||||
}
|
||||
|
||||
// Error makes it compatible with the `error` interface.
|
||||
func (e *ApiError) Error() string {
|
||||
return e.Message
|
||||
}
|
||||
|
||||
// RawData returns the unformatted error data (could be an internal error, text, etc.)
|
||||
func (e *ApiError) RawData() any {
|
||||
return e.rawData
|
||||
}
|
||||
|
||||
// Is reports whether the current ApiError wraps the target.
|
||||
func (e *ApiError) Is(target error) bool {
|
||||
err, ok := e.rawData.(error)
|
||||
if ok {
|
||||
return errors.Is(err, target)
|
||||
}
|
||||
|
||||
apiErr, ok := target.(*ApiError)
|
||||
|
||||
return ok && e == apiErr
|
||||
}
|
||||
|
||||
// NewNotFoundError creates and returns 404 ApiError.
|
||||
func NewNotFoundError(message string, rawErrData any) *ApiError {
|
||||
if message == "" {
|
||||
message = "The requested resource wasn't found."
|
||||
}
|
||||
|
||||
return NewApiError(http.StatusNotFound, message, rawErrData)
|
||||
}
|
||||
|
||||
// NewBadRequestError creates and returns 400 ApiError.
|
||||
func NewBadRequestError(message string, rawErrData any) *ApiError {
|
||||
if message == "" {
|
||||
message = "Something went wrong while processing your request."
|
||||
}
|
||||
|
||||
return NewApiError(http.StatusBadRequest, message, rawErrData)
|
||||
}
|
||||
|
||||
// NewForbiddenError creates and returns 403 ApiError.
|
||||
func NewForbiddenError(message string, rawErrData any) *ApiError {
|
||||
if message == "" {
|
||||
message = "You are not allowed to perform this request."
|
||||
}
|
||||
|
||||
return NewApiError(http.StatusForbidden, message, rawErrData)
|
||||
}
|
||||
|
||||
// NewUnauthorizedError creates and returns 401 ApiError.
|
||||
func NewUnauthorizedError(message string, rawErrData any) *ApiError {
|
||||
if message == "" {
|
||||
message = "Missing or invalid authentication."
|
||||
}
|
||||
|
||||
return NewApiError(http.StatusUnauthorized, message, rawErrData)
|
||||
}
|
||||
|
||||
// NewInternalServerError creates and returns 500 ApiError.
|
||||
func NewInternalServerError(message string, rawErrData any) *ApiError {
|
||||
if message == "" {
|
||||
message = "Something went wrong while processing your request."
|
||||
}
|
||||
|
||||
return NewApiError(http.StatusInternalServerError, message, rawErrData)
|
||||
}
|
||||
|
||||
func NewTooManyRequestsError(message string, rawErrData any) *ApiError {
|
||||
if message == "" {
|
||||
message = "Too Many Requests."
|
||||
}
|
||||
|
||||
return NewApiError(http.StatusTooManyRequests, message, rawErrData)
|
||||
}
|
||||
|
||||
// NewApiError creates and returns new normalized ApiError instance.
|
||||
func NewApiError(status int, message string, rawErrData any) *ApiError {
|
||||
if message == "" {
|
||||
message = http.StatusText(status)
|
||||
}
|
||||
|
||||
return &ApiError{
|
||||
rawData: rawErrData,
|
||||
Data: safeErrorsData(rawErrData),
|
||||
Status: status,
|
||||
Message: strings.TrimSpace(inflector.Sentenize(message)),
|
||||
}
|
||||
}
|
||||
|
||||
// ToApiError wraps err into ApiError instance (if not already).
|
||||
func ToApiError(err error) *ApiError {
|
||||
var apiErr *ApiError
|
||||
|
||||
if !errors.As(err, &apiErr) {
|
||||
// no ApiError found -> assign a generic one
|
||||
if errors.Is(err, sql.ErrNoRows) || errors.Is(err, fs.ErrNotExist) {
|
||||
apiErr = NewNotFoundError("", err)
|
||||
} else {
|
||||
apiErr = NewBadRequestError("", err)
|
||||
}
|
||||
}
|
||||
|
||||
return apiErr
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
|
||||
func safeErrorsData(data any) map[string]any {
|
||||
switch v := data.(type) {
|
||||
case validation.Errors:
|
||||
return resolveSafeErrorsData(v)
|
||||
case error:
|
||||
validationErrors := validation.Errors{}
|
||||
if errors.As(v, &validationErrors) {
|
||||
return resolveSafeErrorsData(validationErrors)
|
||||
}
|
||||
return map[string]any{} // not nil to ensure that is json serialized as object
|
||||
case map[string]validation.Error:
|
||||
return resolveSafeErrorsData(v)
|
||||
case map[string]SafeErrorItem:
|
||||
return resolveSafeErrorsData(v)
|
||||
case map[string]error:
|
||||
return resolveSafeErrorsData(v)
|
||||
case map[string]string:
|
||||
return resolveSafeErrorsData(v)
|
||||
case map[string]any:
|
||||
return resolveSafeErrorsData(v)
|
||||
default:
|
||||
return map[string]any{} // not nil to ensure that is json serialized as object
|
||||
}
|
||||
}
|
||||
|
||||
func resolveSafeErrorsData[T any](data map[string]T) map[string]any {
|
||||
result := map[string]any{}
|
||||
|
||||
for name, err := range data {
|
||||
if isNestedError(err) {
|
||||
result[name] = safeErrorsData(err)
|
||||
} else {
|
||||
result[name] = resolveSafeErrorItem(err)
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
func isNestedError(err any) bool {
|
||||
switch err.(type) {
|
||||
case validation.Errors,
|
||||
map[string]validation.Error,
|
||||
map[string]SafeErrorItem,
|
||||
map[string]error,
|
||||
map[string]string,
|
||||
map[string]any:
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// resolveSafeErrorItem extracts from each validation error its
|
||||
// public safe error code and message.
|
||||
func resolveSafeErrorItem(err any) any {
|
||||
data := map[string]any{}
|
||||
|
||||
if obj, ok := err.(SafeErrorItem); ok {
|
||||
// extract the specific error code and message
|
||||
data["code"] = obj.Code()
|
||||
data["message"] = inflector.Sentenize(obj.Error())
|
||||
} else {
|
||||
// fallback to the default public safe values
|
||||
data["code"] = "validation_invalid_value"
|
||||
data["message"] = "Invalid value."
|
||||
}
|
||||
|
||||
if s, ok := err.(SafeErrorParamsResolver); ok {
|
||||
params := s.Params()
|
||||
if len(params) > 0 {
|
||||
data["params"] = params
|
||||
}
|
||||
}
|
||||
|
||||
if s, ok := err.(SafeErrorResolver); ok {
|
||||
return s.Resolve(data)
|
||||
}
|
||||
|
||||
return data
|
||||
}
|
358
tools/router/error_test.go
Normal file
358
tools/router/error_test.go
Normal file
|
@ -0,0 +1,358 @@
|
|||
package router_test
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"strconv"
|
||||
"testing"
|
||||
|
||||
validation "github.com/go-ozzo/ozzo-validation/v4"
|
||||
"github.com/pocketbase/pocketbase/tools/router"
|
||||
)
|
||||
|
||||
func TestNewApiErrorWithRawData(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
e := router.NewApiError(
|
||||
300,
|
||||
"message_test",
|
||||
"rawData_test",
|
||||
)
|
||||
|
||||
result, _ := json.Marshal(e)
|
||||
expected := `{"data":{},"message":"Message_test.","status":300}`
|
||||
|
||||
if string(result) != expected {
|
||||
t.Errorf("Expected\n%v\ngot\n%v", expected, string(result))
|
||||
}
|
||||
|
||||
if e.Error() != "Message_test." {
|
||||
t.Errorf("Expected %q, got %q", "Message_test.", e.Error())
|
||||
}
|
||||
|
||||
if e.RawData() != "rawData_test" {
|
||||
t.Errorf("Expected rawData\n%v\ngot\n%v", "rawData_test", e.RawData())
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewApiErrorWithValidationData(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
e := router.NewApiError(
|
||||
300,
|
||||
"message_test",
|
||||
map[string]any{
|
||||
"err1": errors.New("test error"), // should be normalized
|
||||
"err2": validation.ErrRequired,
|
||||
"err3": validation.Errors{
|
||||
"err3.1": errors.New("test error"), // should be normalized
|
||||
"err3.2": validation.ErrRequired,
|
||||
"err3.3": validation.Errors{
|
||||
"err3.3.1": validation.ErrRequired,
|
||||
},
|
||||
},
|
||||
"err4": &mockSafeErrorItem{},
|
||||
"err5": map[string]error{
|
||||
"err5.1": validation.ErrRequired,
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
result, _ := json.Marshal(e)
|
||||
expected := `{"data":{"err1":{"code":"validation_invalid_value","message":"Invalid value."},"err2":{"code":"validation_required","message":"Cannot be blank."},"err3":{"err3.1":{"code":"validation_invalid_value","message":"Invalid value."},"err3.2":{"code":"validation_required","message":"Cannot be blank."},"err3.3":{"err3.3.1":{"code":"validation_required","message":"Cannot be blank."}}},"err4":{"code":"mock_code","message":"Mock_error.","mock_resolve":123},"err5":{"err5.1":{"code":"validation_required","message":"Cannot be blank."}}},"message":"Message_test.","status":300}`
|
||||
|
||||
if string(result) != expected {
|
||||
t.Errorf("Expected \n%v, \ngot \n%v", expected, string(result))
|
||||
}
|
||||
|
||||
if e.Error() != "Message_test." {
|
||||
t.Errorf("Expected %q, got %q", "Message_test.", e.Error())
|
||||
}
|
||||
|
||||
if e.RawData() == nil {
|
||||
t.Error("Expected non-nil rawData")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewNotFoundError(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
scenarios := []struct {
|
||||
message string
|
||||
data any
|
||||
expected string
|
||||
}{
|
||||
{"", nil, `{"data":{},"message":"The requested resource wasn't found.","status":404}`},
|
||||
{"demo", "rawData_test", `{"data":{},"message":"Demo.","status":404}`},
|
||||
{"demo", validation.Errors{"err1": validation.NewError("test_code", "test_message")}, `{"data":{"err1":{"code":"test_code","message":"Test_message."}},"message":"Demo.","status":404}`},
|
||||
}
|
||||
|
||||
for i, s := range scenarios {
|
||||
t.Run(strconv.Itoa(i), func(t *testing.T) {
|
||||
e := router.NewNotFoundError(s.message, s.data)
|
||||
result, _ := json.Marshal(e)
|
||||
|
||||
if str := string(result); str != s.expected {
|
||||
t.Fatalf("Expected\n%v\ngot\n%v", s.expected, str)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewBadRequestError(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
scenarios := []struct {
|
||||
message string
|
||||
data any
|
||||
expected string
|
||||
}{
|
||||
{"", nil, `{"data":{},"message":"Something went wrong while processing your request.","status":400}`},
|
||||
{"demo", "rawData_test", `{"data":{},"message":"Demo.","status":400}`},
|
||||
{"demo", validation.Errors{"err1": validation.NewError("test_code", "test_message")}, `{"data":{"err1":{"code":"test_code","message":"Test_message."}},"message":"Demo.","status":400}`},
|
||||
}
|
||||
|
||||
for i, s := range scenarios {
|
||||
t.Run(strconv.Itoa(i), func(t *testing.T) {
|
||||
e := router.NewBadRequestError(s.message, s.data)
|
||||
result, _ := json.Marshal(e)
|
||||
|
||||
if str := string(result); str != s.expected {
|
||||
t.Fatalf("Expected\n%v\ngot\n%v", s.expected, str)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewForbiddenError(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
scenarios := []struct {
|
||||
message string
|
||||
data any
|
||||
expected string
|
||||
}{
|
||||
{"", nil, `{"data":{},"message":"You are not allowed to perform this request.","status":403}`},
|
||||
{"demo", "rawData_test", `{"data":{},"message":"Demo.","status":403}`},
|
||||
{"demo", validation.Errors{"err1": validation.NewError("test_code", "test_message")}, `{"data":{"err1":{"code":"test_code","message":"Test_message."}},"message":"Demo.","status":403}`},
|
||||
}
|
||||
|
||||
for i, s := range scenarios {
|
||||
t.Run(strconv.Itoa(i), func(t *testing.T) {
|
||||
e := router.NewForbiddenError(s.message, s.data)
|
||||
result, _ := json.Marshal(e)
|
||||
|
||||
if str := string(result); str != s.expected {
|
||||
t.Fatalf("Expected\n%v\ngot\n%v", s.expected, str)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewUnauthorizedError(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
scenarios := []struct {
|
||||
message string
|
||||
data any
|
||||
expected string
|
||||
}{
|
||||
{"", nil, `{"data":{},"message":"Missing or invalid authentication.","status":401}`},
|
||||
{"demo", "rawData_test", `{"data":{},"message":"Demo.","status":401}`},
|
||||
{"demo", validation.Errors{"err1": validation.NewError("test_code", "test_message")}, `{"data":{"err1":{"code":"test_code","message":"Test_message."}},"message":"Demo.","status":401}`},
|
||||
}
|
||||
|
||||
for i, s := range scenarios {
|
||||
t.Run(strconv.Itoa(i), func(t *testing.T) {
|
||||
e := router.NewUnauthorizedError(s.message, s.data)
|
||||
result, _ := json.Marshal(e)
|
||||
|
||||
if str := string(result); str != s.expected {
|
||||
t.Fatalf("Expected\n%v\ngot\n%v", s.expected, str)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewInternalServerError(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
scenarios := []struct {
|
||||
message string
|
||||
data any
|
||||
expected string
|
||||
}{
|
||||
{"", nil, `{"data":{},"message":"Something went wrong while processing your request.","status":500}`},
|
||||
{"demo", "rawData_test", `{"data":{},"message":"Demo.","status":500}`},
|
||||
{"demo", validation.Errors{"err1": validation.NewError("test_code", "test_message")}, `{"data":{"err1":{"code":"test_code","message":"Test_message."}},"message":"Demo.","status":500}`},
|
||||
}
|
||||
|
||||
for i, s := range scenarios {
|
||||
t.Run(strconv.Itoa(i), func(t *testing.T) {
|
||||
e := router.NewInternalServerError(s.message, s.data)
|
||||
result, _ := json.Marshal(e)
|
||||
|
||||
if str := string(result); str != s.expected {
|
||||
t.Fatalf("Expected\n%v\ngot\n%v", s.expected, str)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewTooManyRequestsError(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
scenarios := []struct {
|
||||
message string
|
||||
data any
|
||||
expected string
|
||||
}{
|
||||
{"", nil, `{"data":{},"message":"Too Many Requests.","status":429}`},
|
||||
{"demo", "rawData_test", `{"data":{},"message":"Demo.","status":429}`},
|
||||
{"demo", validation.Errors{"err1": validation.NewError("test_code", "test_message").SetParams(map[string]any{"test": 123})}, `{"data":{"err1":{"code":"test_code","message":"Test_message.","params":{"test":123}}},"message":"Demo.","status":429}`},
|
||||
}
|
||||
|
||||
for i, s := range scenarios {
|
||||
t.Run(strconv.Itoa(i), func(t *testing.T) {
|
||||
e := router.NewTooManyRequestsError(s.message, s.data)
|
||||
result, _ := json.Marshal(e)
|
||||
|
||||
if str := string(result); str != s.expected {
|
||||
t.Fatalf("Expected\n%v\ngot\n%v", s.expected, str)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestApiErrorIs(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
err0 := router.NewInternalServerError("", nil)
|
||||
err1 := router.NewInternalServerError("", nil)
|
||||
err2 := errors.New("test")
|
||||
err3 := fmt.Errorf("wrapped: %w", err0)
|
||||
|
||||
scenarios := []struct {
|
||||
name string
|
||||
err error
|
||||
target error
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
"nil error",
|
||||
err0,
|
||||
nil,
|
||||
false,
|
||||
},
|
||||
{
|
||||
"non ApiError",
|
||||
err0,
|
||||
err1,
|
||||
false,
|
||||
},
|
||||
{
|
||||
"different ApiError",
|
||||
err0,
|
||||
err2,
|
||||
false,
|
||||
},
|
||||
{
|
||||
"same ApiError",
|
||||
err0,
|
||||
err0,
|
||||
true,
|
||||
},
|
||||
{
|
||||
"wrapped ApiError",
|
||||
err3,
|
||||
err0,
|
||||
true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
t.Run(s.name, func(t *testing.T) {
|
||||
is := errors.Is(s.err, s.target)
|
||||
|
||||
if is != s.expected {
|
||||
t.Fatalf("Expected %v, got %v", s.expected, is)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestToApiError(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
scenarios := []struct {
|
||||
name string
|
||||
err error
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
"regular error",
|
||||
errors.New("test"),
|
||||
`{"data":{},"message":"Something went wrong while processing your request.","status":400}`,
|
||||
},
|
||||
{
|
||||
"fs.ErrNotExist",
|
||||
fs.ErrNotExist,
|
||||
`{"data":{},"message":"The requested resource wasn't found.","status":404}`,
|
||||
},
|
||||
{
|
||||
"sql.ErrNoRows",
|
||||
sql.ErrNoRows,
|
||||
`{"data":{},"message":"The requested resource wasn't found.","status":404}`,
|
||||
},
|
||||
{
|
||||
"ApiError",
|
||||
router.NewForbiddenError("test", nil),
|
||||
`{"data":{},"message":"Test.","status":403}`,
|
||||
},
|
||||
{
|
||||
"wrapped ApiError",
|
||||
fmt.Errorf("wrapped: %w", router.NewForbiddenError("test", nil)),
|
||||
`{"data":{},"message":"Test.","status":403}`,
|
||||
},
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
t.Run(s.name, func(t *testing.T) {
|
||||
raw, err := json.Marshal(router.ToApiError(s.err))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
rawStr := string(raw)
|
||||
|
||||
if rawStr != s.expected {
|
||||
t.Fatalf("Expected error\n%vgot\n%v", s.expected, rawStr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
|
||||
var (
|
||||
_ router.SafeErrorItem = (*mockSafeErrorItem)(nil)
|
||||
_ router.SafeErrorResolver = (*mockSafeErrorItem)(nil)
|
||||
)
|
||||
|
||||
type mockSafeErrorItem struct {
|
||||
}
|
||||
|
||||
func (m *mockSafeErrorItem) Code() string {
|
||||
return "mock_code"
|
||||
}
|
||||
|
||||
func (m *mockSafeErrorItem) Error() string {
|
||||
return "mock_error"
|
||||
}
|
||||
|
||||
func (m *mockSafeErrorItem) Resolve(errData map[string]any) any {
|
||||
errData["mock_resolve"] = 123
|
||||
return errData
|
||||
}
|
398
tools/router/event.go
Normal file
398
tools/router/event.go
Normal file
|
@ -0,0 +1,398 @@
|
|||
package router
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"encoding/xml"
|
||||
"errors"
|
||||
"io"
|
||||
"io/fs"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/pocketbase/pocketbase/tools/filesystem"
|
||||
"github.com/pocketbase/pocketbase/tools/hook"
|
||||
"github.com/pocketbase/pocketbase/tools/picker"
|
||||
"github.com/pocketbase/pocketbase/tools/store"
|
||||
)
|
||||
|
||||
var ErrUnsupportedContentType = NewBadRequestError("Unsupported Content-Type", nil)
|
||||
var ErrInvalidRedirectStatusCode = NewInternalServerError("Invalid redirect status code", nil)
|
||||
var ErrFileNotFound = NewNotFoundError("File not found", nil)
|
||||
|
||||
const IndexPage = "index.html"
|
||||
|
||||
// Event specifies based Route handler event that is usually intended
|
||||
// to be embedded as part of a custom event struct.
|
||||
//
|
||||
// NB! It is expected that the Response and Request fields are always set.
|
||||
type Event struct {
|
||||
Response http.ResponseWriter
|
||||
Request *http.Request
|
||||
|
||||
hook.Event
|
||||
|
||||
data store.Store[string, any]
|
||||
}
|
||||
|
||||
// RWUnwrapper specifies that an http.ResponseWriter could be "unwrapped"
|
||||
// (usually used with [http.ResponseController]).
|
||||
type RWUnwrapper interface {
|
||||
Unwrap() http.ResponseWriter
|
||||
}
|
||||
|
||||
// Written reports whether the current response has already been written.
|
||||
//
|
||||
// This method always returns false if e.ResponseWritter doesn't implement the WriteTracker interface
|
||||
// (all router package handlers receives a ResponseWritter that implements it unless explicitly replaced with a custom one).
|
||||
func (e *Event) Written() bool {
|
||||
written, _ := getWritten(e.Response)
|
||||
return written
|
||||
}
|
||||
|
||||
// Status reports the status code of the current response.
|
||||
//
|
||||
// This method always returns 0 if e.Response doesn't implement the StatusTracker interface
|
||||
// (all router package handlers receives a ResponseWritter that implements it unless explicitly replaced with a custom one).
|
||||
func (e *Event) Status() int {
|
||||
status, _ := getStatus(e.Response)
|
||||
return status
|
||||
}
|
||||
|
||||
// Flush flushes buffered data to the current response.
|
||||
//
|
||||
// Returns [http.ErrNotSupported] if e.Response doesn't implement the [http.Flusher] interface
|
||||
// (all router package handlers receives a ResponseWritter that implements it unless explicitly replaced with a custom one).
|
||||
func (e *Event) Flush() error {
|
||||
return http.NewResponseController(e.Response).Flush()
|
||||
}
|
||||
|
||||
// IsTLS reports whether the connection on which the request was received is TLS.
|
||||
func (e *Event) IsTLS() bool {
|
||||
return e.Request.TLS != nil
|
||||
}
|
||||
|
||||
// SetCookie is an alias for [http.SetCookie].
|
||||
//
|
||||
// SetCookie adds a Set-Cookie header to the current response's headers.
|
||||
// The provided cookie must have a valid Name.
|
||||
// Invalid cookies may be silently dropped.
|
||||
func (e *Event) SetCookie(cookie *http.Cookie) {
|
||||
http.SetCookie(e.Response, cookie)
|
||||
}
|
||||
|
||||
// RemoteIP returns the IP address of the client that sent the request.
|
||||
//
|
||||
// IPv6 addresses are returned expanded.
|
||||
// For example, "2001:db8::1" becomes "2001:0db8:0000:0000:0000:0000:0000:0001".
|
||||
//
|
||||
// Note that if you are behind reverse proxy(ies), this method returns
|
||||
// the IP of the last connecting proxy.
|
||||
func (e *Event) RemoteIP() string {
|
||||
ip, _, _ := net.SplitHostPort(e.Request.RemoteAddr)
|
||||
parsed, _ := netip.ParseAddr(ip)
|
||||
return parsed.StringExpanded()
|
||||
}
|
||||
|
||||
// FindUploadedFiles extracts all form files of "key" from a http request
|
||||
// and returns a slice with filesystem.File instances (if any).
|
||||
func (e *Event) FindUploadedFiles(key string) ([]*filesystem.File, error) {
|
||||
if e.Request.MultipartForm == nil {
|
||||
err := e.Request.ParseMultipartForm(DefaultMaxMemory)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if e.Request.MultipartForm == nil || e.Request.MultipartForm.File == nil || len(e.Request.MultipartForm.File[key]) == 0 {
|
||||
return nil, http.ErrMissingFile
|
||||
}
|
||||
|
||||
result := make([]*filesystem.File, 0, len(e.Request.MultipartForm.File[key]))
|
||||
|
||||
for _, fh := range e.Request.MultipartForm.File[key] {
|
||||
file, err := filesystem.NewFileFromMultipart(fh)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result = append(result, file)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// Store
|
||||
// -------------------------------------------------------------------
|
||||
|
||||
// Get retrieves single value from the current event data store.
|
||||
func (e *Event) Get(key string) any {
|
||||
return e.data.Get(key)
|
||||
}
|
||||
|
||||
// GetAll returns a copy of the current event data store.
|
||||
func (e *Event) GetAll() map[string]any {
|
||||
return e.data.GetAll()
|
||||
}
|
||||
|
||||
// Set saves single value into the current event data store.
|
||||
func (e *Event) Set(key string, value any) {
|
||||
e.data.Set(key, value)
|
||||
}
|
||||
|
||||
// SetAll saves all items from m into the current event data store.
|
||||
func (e *Event) SetAll(m map[string]any) {
|
||||
for k, v := range m {
|
||||
e.Set(k, v)
|
||||
}
|
||||
}
|
||||
|
||||
// Response writers
|
||||
// -------------------------------------------------------------------
|
||||
|
||||
const headerContentType = "Content-Type"
|
||||
|
||||
func (e *Event) setResponseHeaderIfEmpty(key, value string) {
|
||||
header := e.Response.Header()
|
||||
if header.Get(key) == "" {
|
||||
header.Set(key, value)
|
||||
}
|
||||
}
|
||||
|
||||
// String writes a plain string response.
|
||||
func (e *Event) String(status int, data string) error {
|
||||
e.setResponseHeaderIfEmpty(headerContentType, "text/plain; charset=utf-8")
|
||||
e.Response.WriteHeader(status)
|
||||
_, err := e.Response.Write([]byte(data))
|
||||
return err
|
||||
}
|
||||
|
||||
// HTML writes an HTML response.
|
||||
func (e *Event) HTML(status int, data string) error {
|
||||
e.setResponseHeaderIfEmpty(headerContentType, "text/html; charset=utf-8")
|
||||
e.Response.WriteHeader(status)
|
||||
_, err := e.Response.Write([]byte(data))
|
||||
return err
|
||||
}
|
||||
|
||||
const jsonFieldsParam = "fields"
|
||||
|
||||
// JSON writes a JSON response.
|
||||
//
|
||||
// It also provides a generic response data fields picker if the "fields" query parameter is set.
|
||||
// For example, if you are requesting `?fields=a,b` for `e.JSON(200, map[string]int{ "a":1, "b":2, "c":3 })`,
|
||||
// it should result in a JSON response like: `{"a":1, "b": 2}`.
|
||||
func (e *Event) JSON(status int, data any) error {
|
||||
e.setResponseHeaderIfEmpty(headerContentType, "application/json")
|
||||
e.Response.WriteHeader(status)
|
||||
|
||||
rawFields := e.Request.URL.Query().Get(jsonFieldsParam)
|
||||
|
||||
// error response or no fields to pick
|
||||
if rawFields == "" || status < 200 || status > 299 {
|
||||
return json.NewEncoder(e.Response).Encode(data)
|
||||
}
|
||||
|
||||
// pick only the requested fields
|
||||
modified, err := picker.Pick(data, rawFields)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return json.NewEncoder(e.Response).Encode(modified)
|
||||
}
|
||||
|
||||
// XML writes an XML response.
|
||||
// It automatically prepends the generic [xml.Header] string to the response.
|
||||
func (e *Event) XML(status int, data any) error {
|
||||
e.setResponseHeaderIfEmpty(headerContentType, "application/xml; charset=utf-8")
|
||||
e.Response.WriteHeader(status)
|
||||
if _, err := e.Response.Write([]byte(xml.Header)); err != nil {
|
||||
return err
|
||||
}
|
||||
return xml.NewEncoder(e.Response).Encode(data)
|
||||
}
|
||||
|
||||
// Stream streams the specified reader into the response.
|
||||
func (e *Event) Stream(status int, contentType string, reader io.Reader) error {
|
||||
e.Response.Header().Set(headerContentType, contentType)
|
||||
e.Response.WriteHeader(status)
|
||||
_, err := io.Copy(e.Response, reader)
|
||||
return err
|
||||
}
|
||||
|
||||
// Blob writes a blob (bytes slice) response.
|
||||
func (e *Event) Blob(status int, contentType string, b []byte) error {
|
||||
e.setResponseHeaderIfEmpty(headerContentType, contentType)
|
||||
e.Response.WriteHeader(status)
|
||||
_, err := e.Response.Write(b)
|
||||
return err
|
||||
}
|
||||
|
||||
// FileFS serves the specified filename from fsys.
|
||||
//
|
||||
// It is similar to [echo.FileFS] for consistency with earlier versions.
|
||||
func (e *Event) FileFS(fsys fs.FS, filename string) error {
|
||||
f, err := fsys.Open(filename)
|
||||
if err != nil {
|
||||
return ErrFileNotFound
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
fi, err := f.Stat()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// if it is a directory try to open its index.html file
|
||||
if fi.IsDir() {
|
||||
filename = filepath.ToSlash(filepath.Join(filename, IndexPage))
|
||||
f, err = fsys.Open(filename)
|
||||
if err != nil {
|
||||
return ErrFileNotFound
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
fi, err = f.Stat()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
ff, ok := f.(io.ReadSeeker)
|
||||
if !ok {
|
||||
return errors.New("[FileFS] file does not implement io.ReadSeeker")
|
||||
}
|
||||
|
||||
http.ServeContent(e.Response, e.Request, fi.Name(), fi.ModTime(), ff)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// NoContent writes a response with no body (ex. 204).
|
||||
func (e *Event) NoContent(status int) error {
|
||||
e.Response.WriteHeader(status)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Redirect writes a redirect response to the specified url.
|
||||
// The status code must be in between 300 – 399 range.
|
||||
func (e *Event) Redirect(status int, url string) error {
|
||||
if status < 300 || status > 399 {
|
||||
return ErrInvalidRedirectStatusCode
|
||||
}
|
||||
e.Response.Header().Set("Location", url)
|
||||
e.Response.WriteHeader(status)
|
||||
return nil
|
||||
}
|
||||
|
||||
// ApiError helpers
|
||||
// -------------------------------------------------------------------
|
||||
|
||||
func (e *Event) Error(status int, message string, errData any) *ApiError {
|
||||
return NewApiError(status, message, errData)
|
||||
}
|
||||
|
||||
func (e *Event) BadRequestError(message string, errData any) *ApiError {
|
||||
return NewBadRequestError(message, errData)
|
||||
}
|
||||
|
||||
func (e *Event) NotFoundError(message string, errData any) *ApiError {
|
||||
return NewNotFoundError(message, errData)
|
||||
}
|
||||
|
||||
func (e *Event) ForbiddenError(message string, errData any) *ApiError {
|
||||
return NewForbiddenError(message, errData)
|
||||
}
|
||||
|
||||
func (e *Event) UnauthorizedError(message string, errData any) *ApiError {
|
||||
return NewUnauthorizedError(message, errData)
|
||||
}
|
||||
|
||||
func (e *Event) TooManyRequestsError(message string, errData any) *ApiError {
|
||||
return NewTooManyRequestsError(message, errData)
|
||||
}
|
||||
|
||||
func (e *Event) InternalServerError(message string, errData any) *ApiError {
|
||||
return NewInternalServerError(message, errData)
|
||||
}
|
||||
|
||||
// Binders
|
||||
// -------------------------------------------------------------------
|
||||
|
||||
const DefaultMaxMemory = 32 << 20 // 32mb
|
||||
|
||||
// BindBody unmarshal the request body into the provided dst.
|
||||
//
|
||||
// dst must be either a struct pointer or map[string]any.
|
||||
//
|
||||
// The rules how the body will be scanned depends on the request Content-Type.
|
||||
//
|
||||
// Currently the following Content-Types are supported:
|
||||
// - application/json
|
||||
// - text/xml, application/xml
|
||||
// - multipart/form-data, application/x-www-form-urlencoded
|
||||
//
|
||||
// Respectively the following struct tags are supported (again, which one will be used depends on the Content-Type):
|
||||
// - "json" (json body)- uses the builtin Go json package for unmarshaling.
|
||||
// - "xml" (xml body) - uses the builtin Go xml package for unmarshaling.
|
||||
// - "form" (form data) - utilizes the custom [router.UnmarshalRequestData] method.
|
||||
//
|
||||
// NB! When dst is a struct make sure that it doesn't have public fields
|
||||
// that shouldn't be bindable and it is advisible such fields to be unexported
|
||||
// or have a separate struct just for the binding. For example:
|
||||
//
|
||||
// data := struct{
|
||||
// somethingPrivate string
|
||||
//
|
||||
// Title string `json:"title" form:"title"`
|
||||
// Total int `json:"total" form:"total"`
|
||||
// }
|
||||
// err := e.BindBody(&data)
|
||||
func (e *Event) BindBody(dst any) error {
|
||||
if e.Request.ContentLength == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
contentType := e.Request.Header.Get(headerContentType)
|
||||
|
||||
if strings.HasPrefix(contentType, "application/json") {
|
||||
dec := json.NewDecoder(e.Request.Body)
|
||||
err := dec.Decode(dst)
|
||||
if err == nil {
|
||||
// manually call Reread because single call of json.Decoder.Decode()
|
||||
// doesn't ensure that the entire body is a valid json string
|
||||
// and it is not guaranteed that it will reach EOF to trigger the reread reset
|
||||
// (ex. in case of trailing spaces or invalid trailing parts like: `{"test":1},something`)
|
||||
if body, ok := e.Request.Body.(Rereader); ok {
|
||||
body.Reread()
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
if strings.HasPrefix(contentType, "multipart/form-data") {
|
||||
if err := e.Request.ParseMultipartForm(DefaultMaxMemory); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return UnmarshalRequestData(e.Request.Form, dst, "", "")
|
||||
}
|
||||
|
||||
if strings.HasPrefix(contentType, "application/x-www-form-urlencoded") {
|
||||
if err := e.Request.ParseForm(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return UnmarshalRequestData(e.Request.Form, dst, "", "")
|
||||
}
|
||||
|
||||
if strings.HasPrefix(contentType, "text/xml") ||
|
||||
strings.HasPrefix(contentType, "application/xml") {
|
||||
return xml.NewDecoder(e.Request.Body).Decode(dst)
|
||||
}
|
||||
|
||||
return ErrUnsupportedContentType
|
||||
}
|
959
tools/router/event_test.go
Normal file
959
tools/router/event_test.go
Normal file
|
@ -0,0 +1,959 @@
|
|||
package router_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/tls"
|
||||
"encoding/json"
|
||||
"encoding/xml"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
validation "github.com/go-ozzo/ozzo-validation/v4"
|
||||
"github.com/pocketbase/pocketbase/tools/router"
|
||||
)
|
||||
|
||||
type unwrapTester struct {
|
||||
http.ResponseWriter
|
||||
}
|
||||
|
||||
func (ut unwrapTester) Unwrap() http.ResponseWriter {
|
||||
return ut.ResponseWriter
|
||||
}
|
||||
|
||||
func TestEventWritten(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
res1 := httptest.NewRecorder()
|
||||
|
||||
res2 := httptest.NewRecorder()
|
||||
res2.Write([]byte("test"))
|
||||
|
||||
res3 := &router.ResponseWriter{ResponseWriter: unwrapTester{httptest.NewRecorder()}}
|
||||
|
||||
res4 := &router.ResponseWriter{ResponseWriter: unwrapTester{httptest.NewRecorder()}}
|
||||
res4.Write([]byte("test"))
|
||||
|
||||
scenarios := []struct {
|
||||
name string
|
||||
response http.ResponseWriter
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "non-written non-WriteTracker",
|
||||
response: res1,
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "written non-WriteTracker",
|
||||
response: res2,
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "non-written WriteTracker",
|
||||
response: res3,
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "written WriteTracker",
|
||||
response: res4,
|
||||
expected: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
t.Run(s.name, func(t *testing.T) {
|
||||
event := router.Event{
|
||||
Response: s.response,
|
||||
}
|
||||
|
||||
result := event.Written()
|
||||
|
||||
if result != s.expected {
|
||||
t.Fatalf("Expected %v, got %v", s.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEventStatus(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
res1 := httptest.NewRecorder()
|
||||
|
||||
res2 := httptest.NewRecorder()
|
||||
res2.WriteHeader(123)
|
||||
|
||||
res3 := &router.ResponseWriter{ResponseWriter: unwrapTester{httptest.NewRecorder()}}
|
||||
|
||||
res4 := &router.ResponseWriter{ResponseWriter: unwrapTester{httptest.NewRecorder()}}
|
||||
res4.WriteHeader(123)
|
||||
|
||||
scenarios := []struct {
|
||||
name string
|
||||
response http.ResponseWriter
|
||||
expected int
|
||||
}{
|
||||
{
|
||||
name: "non-written non-StatusTracker",
|
||||
response: res1,
|
||||
expected: 0,
|
||||
},
|
||||
{
|
||||
name: "written non-StatusTracker",
|
||||
response: res2,
|
||||
expected: 0,
|
||||
},
|
||||
{
|
||||
name: "non-written StatusTracker",
|
||||
response: res3,
|
||||
expected: 0,
|
||||
},
|
||||
{
|
||||
name: "written StatusTracker",
|
||||
response: res4,
|
||||
expected: 123,
|
||||
},
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
t.Run(s.name, func(t *testing.T) {
|
||||
event := router.Event{
|
||||
Response: s.response,
|
||||
}
|
||||
|
||||
result := event.Status()
|
||||
|
||||
if result != s.expected {
|
||||
t.Fatalf("Expected %d, got %d", s.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEventIsTLS(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
req, err := http.NewRequest(http.MethodGet, "/", nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
event := router.Event{Request: req}
|
||||
|
||||
// without TLS
|
||||
if event.IsTLS() {
|
||||
t.Fatalf("Expected IsTLS false")
|
||||
}
|
||||
|
||||
// dummy TLS state
|
||||
req.TLS = new(tls.ConnectionState)
|
||||
|
||||
// with TLS
|
||||
if !event.IsTLS() {
|
||||
t.Fatalf("Expected IsTLS true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEventSetCookie(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
event := router.Event{
|
||||
Response: httptest.NewRecorder(),
|
||||
}
|
||||
|
||||
cookie := event.Response.Header().Get("set-cookie")
|
||||
if cookie != "" {
|
||||
t.Fatalf("Expected empty cookie string, got %q", cookie)
|
||||
}
|
||||
|
||||
event.SetCookie(&http.Cookie{Name: "test", Value: "a"})
|
||||
|
||||
expected := "test=a"
|
||||
|
||||
cookie = event.Response.Header().Get("set-cookie")
|
||||
if cookie != expected {
|
||||
t.Fatalf("Expected cookie %q, got %q", expected, cookie)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEventRemoteIP(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
scenarios := []struct {
|
||||
remoteAddr string
|
||||
expected string
|
||||
}{
|
||||
{"", "invalid IP"},
|
||||
{"1.2.3.4", "invalid IP"},
|
||||
{"1.2.3.4:8090", "1.2.3.4"},
|
||||
{"[0000:0000:0000:0000:0000:0000:0000:0002]:80", "0000:0000:0000:0000:0000:0000:0000:0002"},
|
||||
{"[::2]:80", "0000:0000:0000:0000:0000:0000:0000:0002"}, // should always return the expanded version
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
t.Run(s.remoteAddr, func(t *testing.T) {
|
||||
req, err := http.NewRequest(http.MethodGet, "/", nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
req.RemoteAddr = s.remoteAddr
|
||||
|
||||
event := router.Event{Request: req}
|
||||
|
||||
ip := event.RemoteIP()
|
||||
|
||||
if ip != s.expected {
|
||||
t.Fatalf("Expected IP %q, got %q", s.expected, ip)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFindUploadedFiles(t *testing.T) {
|
||||
scenarios := []struct {
|
||||
filename string
|
||||
expectedPattern string
|
||||
}{
|
||||
{"ab.png", `^ab\w{10}_\w{10}\.png$`},
|
||||
{"test", `^test_\w{10}\.txt$`},
|
||||
{"a b c d!@$.j!@$pg", `^a_b_c_d_\w{10}\.jpg$`},
|
||||
{strings.Repeat("a", 150), `^a{100}_\w{10}\.txt$`},
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
t.Run(s.filename, func(t *testing.T) {
|
||||
// create multipart form file body
|
||||
body := new(bytes.Buffer)
|
||||
mp := multipart.NewWriter(body)
|
||||
w, err := mp.CreateFormFile("test", s.filename)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
w.Write([]byte("test"))
|
||||
mp.Close()
|
||||
// ---
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/", body)
|
||||
req.Header.Add("Content-Type", mp.FormDataContentType())
|
||||
|
||||
event := router.Event{Request: req}
|
||||
|
||||
result, err := event.FindUploadedFiles("test")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if len(result) != 1 {
|
||||
t.Fatalf("Expected 1 file, got %d", len(result))
|
||||
}
|
||||
|
||||
if result[0].Size != 4 {
|
||||
t.Fatalf("Expected the file size to be 4 bytes, got %d", result[0].Size)
|
||||
}
|
||||
|
||||
pattern, err := regexp.Compile(s.expectedPattern)
|
||||
if err != nil {
|
||||
t.Fatalf("Invalid filename pattern %q: %v", s.expectedPattern, err)
|
||||
}
|
||||
if !pattern.MatchString(result[0].Name) {
|
||||
t.Fatalf("Expected filename to match %s, got filename %s", s.expectedPattern, result[0].Name)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFindUploadedFilesMissing(t *testing.T) {
|
||||
body := new(bytes.Buffer)
|
||||
mp := multipart.NewWriter(body)
|
||||
mp.Close()
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/", body)
|
||||
req.Header.Add("Content-Type", mp.FormDataContentType())
|
||||
|
||||
event := router.Event{Request: req}
|
||||
|
||||
result, err := event.FindUploadedFiles("test")
|
||||
if err == nil {
|
||||
t.Error("Expected error, got nil")
|
||||
}
|
||||
|
||||
if result != nil {
|
||||
t.Errorf("Expected result to be nil, got %v", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEventSetGet(t *testing.T) {
|
||||
event := router.Event{}
|
||||
|
||||
// get before any set (ensures that doesn't panic)
|
||||
if v := event.Get("test"); v != nil {
|
||||
t.Fatalf("Expected nil value, got %v", v)
|
||||
}
|
||||
|
||||
event.Set("a", 123)
|
||||
event.Set("b", 456)
|
||||
|
||||
scenarios := []struct {
|
||||
key string
|
||||
expected any
|
||||
}{
|
||||
{"", nil},
|
||||
{"missing", nil},
|
||||
{"a", 123},
|
||||
{"b", 456},
|
||||
}
|
||||
|
||||
for i, s := range scenarios {
|
||||
t.Run(fmt.Sprintf("%d_%s", i, s.key), func(t *testing.T) {
|
||||
result := event.Get(s.key)
|
||||
if result != s.expected {
|
||||
t.Fatalf("Expected %v, got %v", s.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEventSetAllGetAll(t *testing.T) {
|
||||
data := map[string]any{
|
||||
"a": 123,
|
||||
"b": 456,
|
||||
}
|
||||
rawData, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
event := router.Event{}
|
||||
event.SetAll(data)
|
||||
|
||||
// modify the data to ensure that the map was shallow coppied
|
||||
data["c"] = 789
|
||||
|
||||
result := event.GetAll()
|
||||
rawResult, err := json.Marshal(result)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if len(rawResult) == 0 || !bytes.Equal(rawData, rawResult) {
|
||||
t.Fatalf("Expected\n%v\ngot\n%v", rawData, rawResult)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEventString(t *testing.T) {
|
||||
scenarios := []testResponseWriteScenario[string]{
|
||||
{
|
||||
name: "no explicit content-type",
|
||||
status: 123,
|
||||
headers: nil,
|
||||
body: "test",
|
||||
expectedStatus: 123,
|
||||
expectedHeaders: map[string]string{"content-type": "text/plain; charset=utf-8"},
|
||||
expectedBody: "test",
|
||||
},
|
||||
{
|
||||
name: "with explicit content-type",
|
||||
status: 123,
|
||||
headers: map[string]string{"content-type": "text/test"},
|
||||
body: "test",
|
||||
expectedStatus: 123,
|
||||
expectedHeaders: map[string]string{"content-type": "text/test"},
|
||||
expectedBody: "test",
|
||||
},
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
testEventResponseWrite(t, s, func(e *router.Event) error {
|
||||
return e.String(s.status, s.body)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEventHTML(t *testing.T) {
|
||||
scenarios := []testResponseWriteScenario[string]{
|
||||
{
|
||||
name: "no explicit content-type",
|
||||
status: 123,
|
||||
headers: nil,
|
||||
body: "test",
|
||||
expectedStatus: 123,
|
||||
expectedHeaders: map[string]string{"content-type": "text/html; charset=utf-8"},
|
||||
expectedBody: "test",
|
||||
},
|
||||
{
|
||||
name: "with explicit content-type",
|
||||
status: 123,
|
||||
headers: map[string]string{"content-type": "text/test"},
|
||||
body: "test",
|
||||
expectedStatus: 123,
|
||||
expectedHeaders: map[string]string{"content-type": "text/test"},
|
||||
expectedBody: "test",
|
||||
},
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
testEventResponseWrite(t, s, func(e *router.Event) error {
|
||||
return e.HTML(s.status, s.body)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEventJSON(t *testing.T) {
|
||||
body := map[string]any{"a": 123, "b": 456, "c": "test"}
|
||||
expectedPickedBody := `{"a":123,"c":"test"}` + "\n"
|
||||
expectedFullBody := `{"a":123,"b":456,"c":"test"}` + "\n"
|
||||
|
||||
scenarios := []testResponseWriteScenario[any]{
|
||||
{
|
||||
name: "no explicit content-type",
|
||||
status: 200,
|
||||
headers: nil,
|
||||
body: body,
|
||||
expectedStatus: 200,
|
||||
expectedHeaders: map[string]string{"content-type": "application/json"},
|
||||
expectedBody: expectedPickedBody,
|
||||
},
|
||||
{
|
||||
name: "with explicit content-type (200)",
|
||||
status: 200,
|
||||
headers: map[string]string{"content-type": "application/test"},
|
||||
body: body,
|
||||
expectedStatus: 200,
|
||||
expectedHeaders: map[string]string{"content-type": "application/test"},
|
||||
expectedBody: expectedPickedBody,
|
||||
},
|
||||
{
|
||||
name: "with explicit content-type (400)", // no fields picker
|
||||
status: 400,
|
||||
headers: map[string]string{"content-type": "application/test"},
|
||||
body: body,
|
||||
expectedStatus: 400,
|
||||
expectedHeaders: map[string]string{"content-type": "application/test"},
|
||||
expectedBody: expectedFullBody,
|
||||
},
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
testEventResponseWrite(t, s, func(e *router.Event) error {
|
||||
e.Request.URL.RawQuery = "fields=a,c" // ensures that the picker is invoked
|
||||
return e.JSON(s.status, s.body)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEventXML(t *testing.T) {
|
||||
scenarios := []testResponseWriteScenario[string]{
|
||||
{
|
||||
name: "no explicit content-type",
|
||||
status: 234,
|
||||
headers: nil,
|
||||
body: "test",
|
||||
expectedStatus: 234,
|
||||
expectedHeaders: map[string]string{"content-type": "application/xml; charset=utf-8"},
|
||||
expectedBody: xml.Header + "<string>test</string>",
|
||||
},
|
||||
{
|
||||
name: "with explicit content-type",
|
||||
status: 234,
|
||||
headers: map[string]string{"content-type": "text/test"},
|
||||
body: "test",
|
||||
expectedStatus: 234,
|
||||
expectedHeaders: map[string]string{"content-type": "text/test"},
|
||||
expectedBody: xml.Header + "<string>test</string>",
|
||||
},
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
testEventResponseWrite(t, s, func(e *router.Event) error {
|
||||
return e.XML(s.status, s.body)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEventStream(t *testing.T) {
|
||||
scenarios := []testResponseWriteScenario[string]{
|
||||
{
|
||||
name: "stream",
|
||||
status: 234,
|
||||
headers: map[string]string{"content-type": "text/test"},
|
||||
body: "test",
|
||||
expectedStatus: 234,
|
||||
expectedHeaders: map[string]string{"content-type": "text/test"},
|
||||
expectedBody: "test",
|
||||
},
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
testEventResponseWrite(t, s, func(e *router.Event) error {
|
||||
return e.Stream(s.status, s.headers["content-type"], strings.NewReader(s.body))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEventBlob(t *testing.T) {
|
||||
scenarios := []testResponseWriteScenario[[]byte]{
|
||||
{
|
||||
name: "blob",
|
||||
status: 234,
|
||||
headers: map[string]string{"content-type": "text/test"},
|
||||
body: []byte("test"),
|
||||
expectedStatus: 234,
|
||||
expectedHeaders: map[string]string{"content-type": "text/test"},
|
||||
expectedBody: "test",
|
||||
},
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
testEventResponseWrite(t, s, func(e *router.Event) error {
|
||||
return e.Blob(s.status, s.headers["content-type"], s.body)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEventNoContent(t *testing.T) {
|
||||
s := testResponseWriteScenario[any]{
|
||||
name: "no content",
|
||||
status: 234,
|
||||
headers: map[string]string{"content-type": "text/test"},
|
||||
body: nil,
|
||||
expectedStatus: 234,
|
||||
expectedHeaders: map[string]string{"content-type": "text/test"},
|
||||
expectedBody: "",
|
||||
}
|
||||
|
||||
testEventResponseWrite(t, s, func(e *router.Event) error {
|
||||
return e.NoContent(s.status)
|
||||
})
|
||||
}
|
||||
|
||||
func TestEventFlush(t *testing.T) {
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
event := &router.Event{
|
||||
Response: unwrapTester{&router.ResponseWriter{ResponseWriter: rec}},
|
||||
}
|
||||
event.Response.Write([]byte("test"))
|
||||
event.Flush()
|
||||
|
||||
if !rec.Flushed {
|
||||
t.Fatal("Expected response to be flushed")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEventRedirect(t *testing.T) {
|
||||
scenarios := []testResponseWriteScenario[any]{
|
||||
{
|
||||
name: "non-30x status",
|
||||
status: 200,
|
||||
expectedStatus: 200,
|
||||
expectedError: router.ErrInvalidRedirectStatusCode,
|
||||
},
|
||||
{
|
||||
name: "30x status",
|
||||
status: 302,
|
||||
headers: map[string]string{"location": "test"}, // should be overwritten with the argument
|
||||
expectedStatus: 302,
|
||||
expectedHeaders: map[string]string{"location": "example"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
testEventResponseWrite(t, s, func(e *router.Event) error {
|
||||
return e.Redirect(s.status, "example")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEventFileFS(t *testing.T) {
|
||||
// stub test files
|
||||
// ---
|
||||
dir, err := os.MkdirTemp("", "EventFileFS")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer os.RemoveAll(dir)
|
||||
|
||||
err = os.WriteFile(filepath.Join(dir, "index.html"), []byte("index"), 0644)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = os.WriteFile(filepath.Join(dir, "test.txt"), []byte("test"), 0644)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// create sub directory with an index.html file inside it
|
||||
err = os.MkdirAll(filepath.Join(dir, "sub1"), os.ModePerm)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
err = os.WriteFile(filepath.Join(dir, "sub1", "index.html"), []byte("sub1 index"), 0644)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = os.MkdirAll(filepath.Join(dir, "sub2"), os.ModePerm)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
err = os.WriteFile(filepath.Join(dir, "sub2", "test.txt"), []byte("sub2 test"), 0644)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
// ---
|
||||
|
||||
scenarios := []struct {
|
||||
name string
|
||||
path string
|
||||
expected string
|
||||
}{
|
||||
{"missing file", "", ""},
|
||||
{"root with no explicit file", "", ""},
|
||||
{"root with explicit file", "test.txt", "test"},
|
||||
{"sub dir with no explicit file", "sub1", "sub1 index"},
|
||||
{"sub dir with no explicit file (no index.html)", "sub2", ""},
|
||||
{"sub dir explicit file", "sub2/test.txt", "sub2 test"},
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
t.Run(s.name, func(t *testing.T) {
|
||||
req, err := http.NewRequest(http.MethodGet, "/", nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
event := &router.Event{
|
||||
Request: req,
|
||||
Response: rec,
|
||||
}
|
||||
|
||||
err = event.FileFS(os.DirFS(dir), s.path)
|
||||
|
||||
hasErr := err != nil
|
||||
expectErr := s.expected == ""
|
||||
if hasErr != expectErr {
|
||||
t.Fatalf("Expected hasErr %v, got %v (%v)", expectErr, hasErr, err)
|
||||
}
|
||||
|
||||
result := rec.Result()
|
||||
|
||||
raw, err := io.ReadAll(result.Body)
|
||||
result.Body.Close()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if string(raw) != s.expected {
|
||||
t.Fatalf("Expected body\n%s\ngot\n%s", s.expected, raw)
|
||||
}
|
||||
|
||||
// ensure that the proper file headers are added
|
||||
// (aka. http.ServeContent is invoked)
|
||||
length, _ := strconv.Atoi(result.Header.Get("content-length"))
|
||||
if length != len(s.expected) {
|
||||
t.Fatalf("Expected Content-Length %d, got %d", len(s.expected), length)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEventError(t *testing.T) {
|
||||
err := new(router.Event).Error(123, "message_test", map[string]any{"a": validation.Required, "b": "test"})
|
||||
|
||||
result, _ := json.Marshal(err)
|
||||
expected := `{"data":{"a":{"code":"validation_invalid_value","message":"Invalid value."},"b":{"code":"validation_invalid_value","message":"Invalid value."}},"message":"Message_test.","status":123}`
|
||||
|
||||
if string(result) != expected {
|
||||
t.Errorf("Expected\n%s\ngot\n%s", expected, result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEventBadRequestError(t *testing.T) {
|
||||
err := new(router.Event).BadRequestError("message_test", map[string]any{"a": validation.Required, "b": "test"})
|
||||
|
||||
result, _ := json.Marshal(err)
|
||||
expected := `{"data":{"a":{"code":"validation_invalid_value","message":"Invalid value."},"b":{"code":"validation_invalid_value","message":"Invalid value."}},"message":"Message_test.","status":400}`
|
||||
|
||||
if string(result) != expected {
|
||||
t.Errorf("Expected\n%s\ngot\n%s", expected, result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEventNotFoundError(t *testing.T) {
|
||||
err := new(router.Event).NotFoundError("message_test", map[string]any{"a": validation.Required, "b": "test"})
|
||||
|
||||
result, _ := json.Marshal(err)
|
||||
expected := `{"data":{"a":{"code":"validation_invalid_value","message":"Invalid value."},"b":{"code":"validation_invalid_value","message":"Invalid value."}},"message":"Message_test.","status":404}`
|
||||
|
||||
if string(result) != expected {
|
||||
t.Errorf("Expected\n%s\ngot\n%s", expected, result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEventForbiddenError(t *testing.T) {
|
||||
err := new(router.Event).ForbiddenError("message_test", map[string]any{"a": validation.Required, "b": "test"})
|
||||
|
||||
result, _ := json.Marshal(err)
|
||||
expected := `{"data":{"a":{"code":"validation_invalid_value","message":"Invalid value."},"b":{"code":"validation_invalid_value","message":"Invalid value."}},"message":"Message_test.","status":403}`
|
||||
|
||||
if string(result) != expected {
|
||||
t.Errorf("Expected\n%s\ngot\n%s", expected, result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEventUnauthorizedError(t *testing.T) {
|
||||
err := new(router.Event).UnauthorizedError("message_test", map[string]any{"a": validation.Required, "b": "test"})
|
||||
|
||||
result, _ := json.Marshal(err)
|
||||
expected := `{"data":{"a":{"code":"validation_invalid_value","message":"Invalid value."},"b":{"code":"validation_invalid_value","message":"Invalid value."}},"message":"Message_test.","status":401}`
|
||||
|
||||
if string(result) != expected {
|
||||
t.Errorf("Expected\n%s\ngot\n%s", expected, result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEventTooManyRequestsError(t *testing.T) {
|
||||
err := new(router.Event).TooManyRequestsError("message_test", map[string]any{"a": validation.Required, "b": "test"})
|
||||
|
||||
result, _ := json.Marshal(err)
|
||||
expected := `{"data":{"a":{"code":"validation_invalid_value","message":"Invalid value."},"b":{"code":"validation_invalid_value","message":"Invalid value."}},"message":"Message_test.","status":429}`
|
||||
|
||||
if string(result) != expected {
|
||||
t.Errorf("Expected\n%s\ngot\n%s", expected, result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEventInternalServerError(t *testing.T) {
|
||||
err := new(router.Event).InternalServerError("message_test", map[string]any{"a": validation.Required, "b": "test"})
|
||||
|
||||
result, _ := json.Marshal(err)
|
||||
expected := `{"data":{"a":{"code":"validation_invalid_value","message":"Invalid value."},"b":{"code":"validation_invalid_value","message":"Invalid value."}},"message":"Message_test.","status":500}`
|
||||
|
||||
if string(result) != expected {
|
||||
t.Errorf("Expected\n%s\ngot\n%s", expected, result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEventBindBody(t *testing.T) {
|
||||
type testDstStruct struct {
|
||||
A int `json:"a" xml:"a" form:"a"`
|
||||
B int `json:"b" xml:"b" form:"b"`
|
||||
C string `json:"c" xml:"c" form:"c"`
|
||||
}
|
||||
|
||||
emptyDst := `{"a":0,"b":0,"c":""}`
|
||||
|
||||
queryDst := `a=123&b=-456&c=test`
|
||||
|
||||
xmlDst := `
|
||||
<?xml version="1.0" encoding="UTF-8" ?>
|
||||
<root>
|
||||
<a>123</a>
|
||||
<b>-456</b>
|
||||
<c>test</c>
|
||||
</root>
|
||||
`
|
||||
|
||||
jsonDst := `{"a":123,"b":-456,"c":"test"}`
|
||||
|
||||
// multipart
|
||||
mpBody := &bytes.Buffer{}
|
||||
mpWriter := multipart.NewWriter(mpBody)
|
||||
mpWriter.WriteField("@jsonPayload", `{"a":123}`)
|
||||
mpWriter.WriteField("b", "-456")
|
||||
mpWriter.WriteField("c", "test")
|
||||
if err := mpWriter.Close(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
scenarios := []struct {
|
||||
contentType string
|
||||
body io.Reader
|
||||
expectDst string
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
contentType: "",
|
||||
body: strings.NewReader(jsonDst),
|
||||
expectDst: emptyDst,
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
contentType: "application/rtf", // unsupported
|
||||
body: strings.NewReader(jsonDst),
|
||||
expectDst: emptyDst,
|
||||
expectError: true,
|
||||
},
|
||||
// empty body
|
||||
{
|
||||
contentType: "application/json;charset=emptybody",
|
||||
body: strings.NewReader(""),
|
||||
expectDst: emptyDst,
|
||||
},
|
||||
// json
|
||||
{
|
||||
contentType: "application/json",
|
||||
body: strings.NewReader(jsonDst),
|
||||
expectDst: jsonDst,
|
||||
},
|
||||
{
|
||||
contentType: "application/json;charset=abc",
|
||||
body: strings.NewReader(jsonDst),
|
||||
expectDst: jsonDst,
|
||||
},
|
||||
// xml
|
||||
{
|
||||
contentType: "text/xml",
|
||||
body: strings.NewReader(xmlDst),
|
||||
expectDst: jsonDst,
|
||||
},
|
||||
{
|
||||
contentType: "text/xml;charset=abc",
|
||||
body: strings.NewReader(xmlDst),
|
||||
expectDst: jsonDst,
|
||||
},
|
||||
{
|
||||
contentType: "application/xml",
|
||||
body: strings.NewReader(xmlDst),
|
||||
expectDst: jsonDst,
|
||||
},
|
||||
{
|
||||
contentType: "application/xml;charset=abc",
|
||||
body: strings.NewReader(xmlDst),
|
||||
expectDst: jsonDst,
|
||||
},
|
||||
// x-www-form-urlencoded
|
||||
{
|
||||
contentType: "application/x-www-form-urlencoded",
|
||||
body: strings.NewReader(queryDst),
|
||||
expectDst: jsonDst,
|
||||
},
|
||||
{
|
||||
contentType: "application/x-www-form-urlencoded;charset=abc",
|
||||
body: strings.NewReader(queryDst),
|
||||
expectDst: jsonDst,
|
||||
},
|
||||
// multipart
|
||||
{
|
||||
contentType: mpWriter.FormDataContentType(),
|
||||
body: mpBody,
|
||||
expectDst: jsonDst,
|
||||
},
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
t.Run(s.contentType, func(t *testing.T) {
|
||||
req, err := http.NewRequest(http.MethodPost, "/", s.body)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
req.Header.Add("content-type", s.contentType)
|
||||
|
||||
event := &router.Event{Request: req}
|
||||
|
||||
dst := testDstStruct{}
|
||||
|
||||
err = event.BindBody(&dst)
|
||||
|
||||
hasErr := err != nil
|
||||
if hasErr != s.expectError {
|
||||
t.Fatalf("Expected hasErr %v, got %v (%v)", s.expectError, hasErr, err)
|
||||
}
|
||||
|
||||
dstRaw, err := json.Marshal(dst)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if string(dstRaw) != s.expectDst {
|
||||
t.Fatalf("Expected dst\n%s\ngot\n%s", s.expectDst, dstRaw)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
|
||||
type testResponseWriteScenario[T any] struct {
|
||||
name string
|
||||
status int
|
||||
headers map[string]string
|
||||
body T
|
||||
expectedStatus int
|
||||
expectedHeaders map[string]string
|
||||
expectedBody string
|
||||
expectedError error
|
||||
}
|
||||
|
||||
func testEventResponseWrite[T any](
|
||||
t *testing.T,
|
||||
scenario testResponseWriteScenario[T],
|
||||
writeFunc func(e *router.Event) error,
|
||||
) {
|
||||
t.Run(scenario.name, func(t *testing.T) {
|
||||
req, err := http.NewRequest(http.MethodGet, "/", nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
event := &router.Event{
|
||||
Request: req,
|
||||
Response: &router.ResponseWriter{ResponseWriter: rec},
|
||||
}
|
||||
|
||||
for k, v := range scenario.headers {
|
||||
event.Response.Header().Add(k, v)
|
||||
}
|
||||
|
||||
err = writeFunc(event)
|
||||
if (scenario.expectedError != nil || err != nil) && !errors.Is(err, scenario.expectedError) {
|
||||
t.Fatalf("Expected error %v, got %v", scenario.expectedError, err)
|
||||
}
|
||||
|
||||
result := rec.Result()
|
||||
|
||||
if result.StatusCode != scenario.expectedStatus {
|
||||
t.Fatalf("Expected status code %d, got %d", scenario.expectedStatus, result.StatusCode)
|
||||
}
|
||||
|
||||
resultBody, err := io.ReadAll(result.Body)
|
||||
result.Body.Close()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to read response body: %v", err)
|
||||
}
|
||||
|
||||
resultBody, err = json.Marshal(string(resultBody))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
expectedBody, err := json.Marshal(scenario.expectedBody)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if !bytes.Equal(resultBody, expectedBody) {
|
||||
t.Fatalf("Expected body\n%s\ngot\n%s", expectedBody, resultBody)
|
||||
}
|
||||
|
||||
for k, ev := range scenario.expectedHeaders {
|
||||
if v := result.Header.Get(k); v != ev {
|
||||
t.Fatalf("Expected %q header to be %q, got %q", k, ev, v)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
231
tools/router/group.go
Normal file
231
tools/router/group.go
Normal file
|
@ -0,0 +1,231 @@
|
|||
package router
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/pocketbase/pocketbase/tools/hook"
|
||||
)
|
||||
|
||||
// (note: the struct is named RouterGroup instead of Group so that it can
|
||||
// be embedded in the Router without conflicting with the Group method)
|
||||
|
||||
// RouterGroup represents a collection of routes and other sub groups
|
||||
// that share common pattern prefix and middlewares.
|
||||
type RouterGroup[T hook.Resolver] struct {
|
||||
excludedMiddlewares map[string]struct{}
|
||||
children []any // Route or RouterGroup
|
||||
|
||||
Prefix string
|
||||
Middlewares []*hook.Handler[T]
|
||||
}
|
||||
|
||||
// Group creates and register a new child Group into the current one
|
||||
// with the specified prefix.
|
||||
//
|
||||
// The prefix follows the standard Go net/http ServeMux pattern format ("[HOST]/[PATH]")
|
||||
// and will be concatenated recursively into the final route path, meaning that
|
||||
// only the root level group could have HOST as part of the prefix.
|
||||
//
|
||||
// Returns the newly created group to allow chaining and registering
|
||||
// sub-routes and group specific middlewares.
|
||||
func (group *RouterGroup[T]) Group(prefix string) *RouterGroup[T] {
|
||||
newGroup := &RouterGroup[T]{}
|
||||
newGroup.Prefix = prefix
|
||||
|
||||
group.children = append(group.children, newGroup)
|
||||
|
||||
return newGroup
|
||||
}
|
||||
|
||||
// BindFunc registers one or multiple middleware functions to the current group.
|
||||
//
|
||||
// The registered middleware functions are "anonymous" and with default priority,
|
||||
// aka. executes in the order they were registered.
|
||||
//
|
||||
// If you need to specify a named middleware (ex. so that it can be removed)
|
||||
// or middleware with custom exec prirority, use [RouterGroup.Bind] method.
|
||||
func (group *RouterGroup[T]) BindFunc(middlewareFuncs ...func(e T) error) *RouterGroup[T] {
|
||||
for _, m := range middlewareFuncs {
|
||||
group.Middlewares = append(group.Middlewares, &hook.Handler[T]{Func: m})
|
||||
}
|
||||
|
||||
return group
|
||||
}
|
||||
|
||||
// Bind registers one or multiple middleware handlers to the current group.
|
||||
func (group *RouterGroup[T]) Bind(middlewares ...*hook.Handler[T]) *RouterGroup[T] {
|
||||
group.Middlewares = append(group.Middlewares, middlewares...)
|
||||
|
||||
// unmark the newly added middlewares in case they were previously "excluded"
|
||||
if group.excludedMiddlewares != nil {
|
||||
for _, m := range middlewares {
|
||||
if m.Id != "" {
|
||||
delete(group.excludedMiddlewares, m.Id)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return group
|
||||
}
|
||||
|
||||
// Unbind removes one or more middlewares with the specified id(s)
|
||||
// from the current group and its children (if any).
|
||||
//
|
||||
// Anonymous middlewares are not removable, aka. this method does nothing
|
||||
// if the middleware id is an empty string.
|
||||
func (group *RouterGroup[T]) Unbind(middlewareIds ...string) *RouterGroup[T] {
|
||||
for _, middlewareId := range middlewareIds {
|
||||
if middlewareId == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
// remove from the group middlwares
|
||||
for i := len(group.Middlewares) - 1; i >= 0; i-- {
|
||||
if group.Middlewares[i].Id == middlewareId {
|
||||
group.Middlewares = append(group.Middlewares[:i], group.Middlewares[i+1:]...)
|
||||
}
|
||||
}
|
||||
|
||||
// remove from the group children
|
||||
for i := len(group.children) - 1; i >= 0; i-- {
|
||||
switch v := group.children[i].(type) {
|
||||
case *RouterGroup[T]:
|
||||
v.Unbind(middlewareId)
|
||||
case *Route[T]:
|
||||
v.Unbind(middlewareId)
|
||||
}
|
||||
}
|
||||
|
||||
// add to the exclude list
|
||||
if group.excludedMiddlewares == nil {
|
||||
group.excludedMiddlewares = map[string]struct{}{}
|
||||
}
|
||||
group.excludedMiddlewares[middlewareId] = struct{}{}
|
||||
}
|
||||
|
||||
return group
|
||||
}
|
||||
|
||||
// Route registers a single route into the current group.
|
||||
//
|
||||
// Note that the final route path will be the concatenation of all parent groups prefixes + the route path.
|
||||
// The path follows the standard Go net/http ServeMux format ("[HOST]/[PATH]"),
|
||||
// meaning that only a top level group route could have HOST as part of the prefix.
|
||||
//
|
||||
// Returns the newly created route to allow attaching route-only middlewares.
|
||||
func (group *RouterGroup[T]) Route(method string, path string, action func(e T) error) *Route[T] {
|
||||
route := &Route[T]{
|
||||
Method: method,
|
||||
Path: path,
|
||||
Action: action,
|
||||
}
|
||||
|
||||
group.children = append(group.children, route)
|
||||
|
||||
return route
|
||||
}
|
||||
|
||||
// Any is a shorthand for [RouterGroup.AddRoute] with "" as route method (aka. matches any method).
|
||||
func (group *RouterGroup[T]) Any(path string, action func(e T) error) *Route[T] {
|
||||
return group.Route("", path, action)
|
||||
}
|
||||
|
||||
// GET is a shorthand for [RouterGroup.AddRoute] with GET as route method.
|
||||
func (group *RouterGroup[T]) GET(path string, action func(e T) error) *Route[T] {
|
||||
return group.Route(http.MethodGet, path, action)
|
||||
}
|
||||
|
||||
// SEARCH is a shorthand for [RouterGroup.AddRoute] with SEARCH as route method.
|
||||
func (group *RouterGroup[T]) SEARCH(path string, action func(e T) error) *Route[T] {
|
||||
return group.Route("SEARCH", path, action)
|
||||
}
|
||||
|
||||
// POST is a shorthand for [RouterGroup.AddRoute] with POST as route method.
|
||||
func (group *RouterGroup[T]) POST(path string, action func(e T) error) *Route[T] {
|
||||
return group.Route(http.MethodPost, path, action)
|
||||
}
|
||||
|
||||
// DELETE is a shorthand for [RouterGroup.AddRoute] with DELETE as route method.
|
||||
func (group *RouterGroup[T]) DELETE(path string, action func(e T) error) *Route[T] {
|
||||
return group.Route(http.MethodDelete, path, action)
|
||||
}
|
||||
|
||||
// PATCH is a shorthand for [RouterGroup.AddRoute] with PATCH as route method.
|
||||
func (group *RouterGroup[T]) PATCH(path string, action func(e T) error) *Route[T] {
|
||||
return group.Route(http.MethodPatch, path, action)
|
||||
}
|
||||
|
||||
// PUT is a shorthand for [RouterGroup.AddRoute] with PUT as route method.
|
||||
func (group *RouterGroup[T]) PUT(path string, action func(e T) error) *Route[T] {
|
||||
return group.Route(http.MethodPut, path, action)
|
||||
}
|
||||
|
||||
// HEAD is a shorthand for [RouterGroup.AddRoute] with HEAD as route method.
|
||||
func (group *RouterGroup[T]) HEAD(path string, action func(e T) error) *Route[T] {
|
||||
return group.Route(http.MethodHead, path, action)
|
||||
}
|
||||
|
||||
// OPTIONS is a shorthand for [RouterGroup.AddRoute] with OPTIONS as route method.
|
||||
func (group *RouterGroup[T]) OPTIONS(path string, action func(e T) error) *Route[T] {
|
||||
return group.Route(http.MethodOptions, path, action)
|
||||
}
|
||||
|
||||
// HasRoute checks whether the specified route pattern (method + path)
|
||||
// is registered in the current group or its children.
|
||||
//
|
||||
// This could be useful to conditionally register and checks for routes
|
||||
// in order prevent panic on duplicated routes.
|
||||
//
|
||||
// Note that routes with anonymous and named wildcard placeholder are treated as equal,
|
||||
// aka. "GET /abc/" is considered the same as "GET /abc/{something...}".
|
||||
func (group *RouterGroup[T]) HasRoute(method string, path string) bool {
|
||||
pattern := path
|
||||
if method != "" {
|
||||
pattern = strings.ToUpper(method) + " " + pattern
|
||||
}
|
||||
|
||||
return group.hasRoute(pattern, nil)
|
||||
}
|
||||
|
||||
func (group *RouterGroup[T]) hasRoute(pattern string, parents []*RouterGroup[T]) bool {
|
||||
for _, child := range group.children {
|
||||
switch v := child.(type) {
|
||||
case *RouterGroup[T]:
|
||||
if v.hasRoute(pattern, append(parents, group)) {
|
||||
return true
|
||||
}
|
||||
case *Route[T]:
|
||||
var result string
|
||||
|
||||
if v.Method != "" {
|
||||
result += v.Method + " "
|
||||
}
|
||||
|
||||
// add parent groups prefixes
|
||||
for _, p := range parents {
|
||||
result += p.Prefix
|
||||
}
|
||||
|
||||
// add current group prefix
|
||||
result += group.Prefix
|
||||
|
||||
// add current route path
|
||||
result += v.Path
|
||||
|
||||
if result == pattern || // direct match
|
||||
// compares without the named wildcard, aka. /abc/{test...} is equal to /abc/
|
||||
stripWildcard(result) == stripWildcard(pattern) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
var wildcardPlaceholderRegex = regexp.MustCompile(`/{.+\.\.\.}$`)
|
||||
|
||||
func stripWildcard(pattern string) string {
|
||||
return wildcardPlaceholderRegex.ReplaceAllString(pattern, "/")
|
||||
}
|
430
tools/router/group_test.go
Normal file
430
tools/router/group_test.go
Normal file
|
@ -0,0 +1,430 @@
|
|||
package router
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"slices"
|
||||
"testing"
|
||||
|
||||
"github.com/pocketbase/pocketbase/tools/hook"
|
||||
)
|
||||
|
||||
func TestRouterGroupGroup(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
g0 := RouterGroup[*Event]{}
|
||||
|
||||
g1 := g0.Group("test1")
|
||||
g2 := g0.Group("test2")
|
||||
|
||||
if total := len(g0.children); total != 2 {
|
||||
t.Fatalf("Expected %d child groups, got %d", 2, total)
|
||||
}
|
||||
|
||||
if g1.Prefix != "test1" {
|
||||
t.Fatalf("Expected g1 with prefix %q, got %q", "test1", g1.Prefix)
|
||||
}
|
||||
if g2.Prefix != "test2" {
|
||||
t.Fatalf("Expected g2 with prefix %q, got %q", "test2", g2.Prefix)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouterGroupBindFunc(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
g := RouterGroup[*Event]{}
|
||||
|
||||
calls := ""
|
||||
|
||||
// append one function
|
||||
g.BindFunc(func(e *Event) error {
|
||||
calls += "a"
|
||||
return nil
|
||||
})
|
||||
|
||||
// append multiple functions
|
||||
g.BindFunc(
|
||||
func(e *Event) error {
|
||||
calls += "b"
|
||||
return nil
|
||||
},
|
||||
func(e *Event) error {
|
||||
calls += "c"
|
||||
return nil
|
||||
},
|
||||
)
|
||||
|
||||
if total := len(g.Middlewares); total != 3 {
|
||||
t.Fatalf("Expected %d middlewares, got %v", 3, total)
|
||||
}
|
||||
|
||||
for _, h := range g.Middlewares {
|
||||
_ = h.Func(nil)
|
||||
}
|
||||
|
||||
if calls != "abc" {
|
||||
t.Fatalf("Expected calls sequence %q, got %q", "abc", calls)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouterGroupBind(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
g := RouterGroup[*Event]{
|
||||
// mock excluded middlewares to check whether the entry will be deleted
|
||||
excludedMiddlewares: map[string]struct{}{"test2": {}},
|
||||
}
|
||||
|
||||
calls := ""
|
||||
|
||||
// append one handler
|
||||
g.Bind(&hook.Handler[*Event]{
|
||||
Func: func(e *Event) error {
|
||||
calls += "a"
|
||||
return nil
|
||||
},
|
||||
})
|
||||
|
||||
// append multiple handlers
|
||||
g.Bind(
|
||||
&hook.Handler[*Event]{
|
||||
Id: "test1",
|
||||
Func: func(e *Event) error {
|
||||
calls += "b"
|
||||
return nil
|
||||
},
|
||||
},
|
||||
&hook.Handler[*Event]{
|
||||
Id: "test2",
|
||||
Func: func(e *Event) error {
|
||||
calls += "c"
|
||||
return nil
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
if total := len(g.Middlewares); total != 3 {
|
||||
t.Fatalf("Expected %d middlewares, got %v", 3, total)
|
||||
}
|
||||
|
||||
for _, h := range g.Middlewares {
|
||||
_ = h.Func(nil)
|
||||
}
|
||||
|
||||
if calls != "abc" {
|
||||
t.Fatalf("Expected calls %q, got %q", "abc", calls)
|
||||
}
|
||||
|
||||
// ensures that the previously excluded middleware was removed
|
||||
if len(g.excludedMiddlewares) != 0 {
|
||||
t.Fatalf("Expected test2 to be removed from the excludedMiddlewares list, got %v", g.excludedMiddlewares)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouterGroupUnbind(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
g := RouterGroup[*Event]{}
|
||||
|
||||
calls := ""
|
||||
|
||||
// anonymous middlewares
|
||||
g.Bind(&hook.Handler[*Event]{
|
||||
Func: func(e *Event) error {
|
||||
calls += "a"
|
||||
return nil // unused value
|
||||
},
|
||||
})
|
||||
|
||||
// middlewares with id
|
||||
g.Bind(&hook.Handler[*Event]{
|
||||
Id: "test1",
|
||||
Func: func(e *Event) error {
|
||||
calls += "b"
|
||||
return nil // unused value
|
||||
},
|
||||
})
|
||||
g.Bind(&hook.Handler[*Event]{
|
||||
Id: "test2",
|
||||
Func: func(e *Event) error {
|
||||
calls += "c"
|
||||
return nil // unused value
|
||||
},
|
||||
})
|
||||
g.Bind(&hook.Handler[*Event]{
|
||||
Id: "test3",
|
||||
Func: func(e *Event) error {
|
||||
calls += "d"
|
||||
return nil // unused value
|
||||
},
|
||||
})
|
||||
|
||||
// remove
|
||||
g.Unbind("") // should be no-op
|
||||
g.Unbind("test1", "test3")
|
||||
|
||||
if total := len(g.Middlewares); total != 2 {
|
||||
t.Fatalf("Expected %d middlewares, got %v", 2, total)
|
||||
}
|
||||
|
||||
for _, h := range g.Middlewares {
|
||||
if err := h.Func(nil); err != nil {
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
if calls != "ac" {
|
||||
t.Fatalf("Expected calls %q, got %q", "ac", calls)
|
||||
}
|
||||
|
||||
// ensure that the ids were added in the exclude list
|
||||
excluded := []string{"test1", "test3"}
|
||||
if len(g.excludedMiddlewares) != len(excluded) {
|
||||
t.Fatalf("Expected excludes %v, got %v", excluded, g.excludedMiddlewares)
|
||||
}
|
||||
for id := range g.excludedMiddlewares {
|
||||
if !slices.Contains(excluded, id) {
|
||||
t.Fatalf("Expected %q to be marked as excluded", id)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouterGroupRoute(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
group := RouterGroup[*Event]{}
|
||||
|
||||
sub := group.Group("sub")
|
||||
|
||||
var called bool
|
||||
route := group.Route(http.MethodPost, "/test", func(e *Event) error {
|
||||
called = true
|
||||
return nil
|
||||
})
|
||||
|
||||
// ensure that the route was registered only to the main one
|
||||
// ---
|
||||
if len(sub.children) != 0 {
|
||||
t.Fatalf("Expected no sub children, got %d", len(sub.children))
|
||||
}
|
||||
|
||||
if len(group.children) != 2 {
|
||||
t.Fatalf("Expected %d group children, got %d", 2, len(group.children))
|
||||
}
|
||||
// ---
|
||||
|
||||
// check the registered route
|
||||
// ---
|
||||
if route != group.children[1] {
|
||||
t.Fatalf("Expected group children %v, got %v", route, group.children[1])
|
||||
}
|
||||
|
||||
if route.Method != http.MethodPost {
|
||||
t.Fatalf("Expected route method %q, got %q", http.MethodPost, route.Method)
|
||||
}
|
||||
|
||||
if route.Path != "/test" {
|
||||
t.Fatalf("Expected route path %q, got %q", "/test", route.Path)
|
||||
}
|
||||
|
||||
route.Action(nil)
|
||||
if !called {
|
||||
t.Fatal("Expected route action to be called")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouterGroupRouteAliases(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
group := RouterGroup[*Event]{}
|
||||
|
||||
testErr := errors.New("test")
|
||||
|
||||
testAction := func(e *Event) error {
|
||||
return testErr
|
||||
}
|
||||
|
||||
scenarios := []struct {
|
||||
route *Route[*Event]
|
||||
expectMethod string
|
||||
expectPath string
|
||||
}{
|
||||
{
|
||||
group.Any("/test", testAction),
|
||||
"",
|
||||
"/test",
|
||||
},
|
||||
{
|
||||
group.GET("/test", testAction),
|
||||
http.MethodGet,
|
||||
"/test",
|
||||
},
|
||||
{
|
||||
group.SEARCH("/test", testAction),
|
||||
"SEARCH",
|
||||
"/test",
|
||||
},
|
||||
{
|
||||
group.POST("/test", testAction),
|
||||
http.MethodPost,
|
||||
"/test",
|
||||
},
|
||||
{
|
||||
group.DELETE("/test", testAction),
|
||||
http.MethodDelete,
|
||||
"/test",
|
||||
},
|
||||
{
|
||||
group.PATCH("/test", testAction),
|
||||
http.MethodPatch,
|
||||
"/test",
|
||||
},
|
||||
{
|
||||
group.PUT("/test", testAction),
|
||||
http.MethodPut,
|
||||
"/test",
|
||||
},
|
||||
{
|
||||
group.HEAD("/test", testAction),
|
||||
http.MethodHead,
|
||||
"/test",
|
||||
},
|
||||
{
|
||||
group.OPTIONS("/test", testAction),
|
||||
http.MethodOptions,
|
||||
"/test",
|
||||
},
|
||||
}
|
||||
|
||||
for i, s := range scenarios {
|
||||
t.Run(fmt.Sprintf("%d_%s_%s", i, s.expectMethod, s.expectPath), func(t *testing.T) {
|
||||
if s.route.Method != s.expectMethod {
|
||||
t.Fatalf("Expected method %q, got %q", s.expectMethod, s.route.Method)
|
||||
}
|
||||
|
||||
if s.route.Path != s.expectPath {
|
||||
t.Fatalf("Expected path %q, got %q", s.expectPath, s.route.Path)
|
||||
}
|
||||
|
||||
if err := s.route.Action(nil); !errors.Is(err, testErr) {
|
||||
t.Fatal("Expected test action")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouterGroupHasRoute(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
group := RouterGroup[*Event]{}
|
||||
|
||||
group.Any("/any", nil)
|
||||
|
||||
group.GET("/base", nil)
|
||||
group.DELETE("/base", nil)
|
||||
|
||||
sub := group.Group("/sub1")
|
||||
sub.GET("/a", nil)
|
||||
sub.POST("/a", nil)
|
||||
|
||||
sub2 := sub.Group("/sub2")
|
||||
sub2.GET("/b", nil)
|
||||
sub2.GET("/b/{test}", nil)
|
||||
|
||||
// special cases to test the normalizations
|
||||
group.GET("/c/", nil) // the same as /c/{test...}
|
||||
group.GET("/d/{test...}", nil) // the same as /d/
|
||||
|
||||
scenarios := []struct {
|
||||
method string
|
||||
path string
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
http.MethodGet,
|
||||
"",
|
||||
false,
|
||||
},
|
||||
{
|
||||
"",
|
||||
"/any",
|
||||
true,
|
||||
},
|
||||
{
|
||||
http.MethodPost,
|
||||
"/base",
|
||||
false,
|
||||
},
|
||||
{
|
||||
http.MethodGet,
|
||||
"/base",
|
||||
true,
|
||||
},
|
||||
{
|
||||
http.MethodDelete,
|
||||
"/base",
|
||||
true,
|
||||
},
|
||||
{
|
||||
http.MethodGet,
|
||||
"/sub1",
|
||||
false,
|
||||
},
|
||||
{
|
||||
http.MethodGet,
|
||||
"/sub1/a",
|
||||
true,
|
||||
},
|
||||
{
|
||||
http.MethodPost,
|
||||
"/sub1/a",
|
||||
true,
|
||||
},
|
||||
{
|
||||
http.MethodDelete,
|
||||
"/sub1/a",
|
||||
false,
|
||||
},
|
||||
{
|
||||
http.MethodGet,
|
||||
"/sub2/b",
|
||||
false,
|
||||
},
|
||||
{
|
||||
http.MethodGet,
|
||||
"/sub1/sub2/b",
|
||||
true,
|
||||
},
|
||||
{
|
||||
http.MethodGet,
|
||||
"/sub1/sub2/b/{test}",
|
||||
true,
|
||||
},
|
||||
{
|
||||
http.MethodGet,
|
||||
"/sub1/sub2/b/{test2}",
|
||||
false,
|
||||
},
|
||||
{
|
||||
http.MethodGet,
|
||||
"/c/{test...}",
|
||||
true,
|
||||
},
|
||||
{
|
||||
http.MethodGet,
|
||||
"/d/",
|
||||
true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
t.Run(s.method+"_"+s.path, func(t *testing.T) {
|
||||
has := group.HasRoute(s.method, s.path)
|
||||
|
||||
if has != s.expected {
|
||||
t.Fatalf("Expected %v, got %v", s.expected, has)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
60
tools/router/rereadable_read_closer.go
Normal file
60
tools/router/rereadable_read_closer.go
Normal file
|
@ -0,0 +1,60 @@
|
|||
package router
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
)
|
||||
|
||||
var (
|
||||
_ io.ReadCloser = (*RereadableReadCloser)(nil)
|
||||
_ Rereader = (*RereadableReadCloser)(nil)
|
||||
)
|
||||
|
||||
// Rereader defines an interface for rewindable readers.
|
||||
type Rereader interface {
|
||||
Reread()
|
||||
}
|
||||
|
||||
// RereadableReadCloser defines a wrapper around a io.ReadCloser reader
|
||||
// allowing to read the original reader multiple times.
|
||||
type RereadableReadCloser struct {
|
||||
io.ReadCloser
|
||||
|
||||
copy *bytes.Buffer
|
||||
active io.Reader
|
||||
}
|
||||
|
||||
// Read implements the standard io.Reader interface.
|
||||
//
|
||||
// It reads up to len(b) bytes into b and at at the same time writes
|
||||
// the read data into an internal bytes buffer.
|
||||
//
|
||||
// On EOF the r is "rewinded" to allow reading from r multiple times.
|
||||
func (r *RereadableReadCloser) Read(b []byte) (int, error) {
|
||||
if r.active == nil {
|
||||
if r.copy == nil {
|
||||
r.copy = &bytes.Buffer{}
|
||||
}
|
||||
r.active = io.TeeReader(r.ReadCloser, r.copy)
|
||||
}
|
||||
|
||||
n, err := r.active.Read(b)
|
||||
if err == io.EOF {
|
||||
r.Reread()
|
||||
}
|
||||
|
||||
return n, err
|
||||
}
|
||||
|
||||
// Reread satisfies the [Rereader] interface and resets the r internal state to allow rereads.
|
||||
//
|
||||
// note: not named Reset to avoid conflicts with other reader interfaces.
|
||||
func (r *RereadableReadCloser) Reread() {
|
||||
if r.copy == nil || r.copy.Len() == 0 {
|
||||
return // nothing to reset or it has been already reset
|
||||
}
|
||||
|
||||
oldCopy := r.copy
|
||||
r.copy = &bytes.Buffer{}
|
||||
r.active = io.TeeReader(oldCopy, r.copy)
|
||||
}
|
28
tools/router/rereadable_read_closer_test.go
Normal file
28
tools/router/rereadable_read_closer_test.go
Normal file
|
@ -0,0 +1,28 @@
|
|||
package router_test
|
||||
|
||||
import (
|
||||
"io"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/pocketbase/pocketbase/tools/router"
|
||||
)
|
||||
|
||||
func TestRereadableReadCloser(t *testing.T) {
|
||||
content := "test"
|
||||
|
||||
rereadable := &router.RereadableReadCloser{
|
||||
ReadCloser: io.NopCloser(strings.NewReader(content)),
|
||||
}
|
||||
|
||||
// read multiple times
|
||||
for i := 0; i < 3; i++ {
|
||||
result, err := io.ReadAll(rereadable)
|
||||
if err != nil {
|
||||
t.Fatalf("[read:%d] %v", i, err)
|
||||
}
|
||||
if str := string(result); str != content {
|
||||
t.Fatalf("[read:%d] Expected %q, got %q", i, content, result)
|
||||
}
|
||||
}
|
||||
}
|
73
tools/router/route.go
Normal file
73
tools/router/route.go
Normal file
|
@ -0,0 +1,73 @@
|
|||
package router
|
||||
|
||||
import "github.com/pocketbase/pocketbase/tools/hook"
|
||||
|
||||
type Route[T hook.Resolver] struct {
|
||||
excludedMiddlewares map[string]struct{}
|
||||
|
||||
Action func(e T) error
|
||||
Method string
|
||||
Path string
|
||||
Middlewares []*hook.Handler[T]
|
||||
}
|
||||
|
||||
// BindFunc registers one or multiple middleware functions to the current route.
|
||||
//
|
||||
// The registered middleware functions are "anonymous" and with default priority,
|
||||
// aka. executes in the order they were registered.
|
||||
//
|
||||
// If you need to specify a named middleware (ex. so that it can be removed)
|
||||
// or middleware with custom exec prirority, use the [Route.Bind] method.
|
||||
func (route *Route[T]) BindFunc(middlewareFuncs ...func(e T) error) *Route[T] {
|
||||
for _, m := range middlewareFuncs {
|
||||
route.Middlewares = append(route.Middlewares, &hook.Handler[T]{Func: m})
|
||||
}
|
||||
|
||||
return route
|
||||
}
|
||||
|
||||
// Bind registers one or multiple middleware handlers to the current route.
|
||||
func (route *Route[T]) Bind(middlewares ...*hook.Handler[T]) *Route[T] {
|
||||
route.Middlewares = append(route.Middlewares, middlewares...)
|
||||
|
||||
// unmark the newly added middlewares in case they were previously "excluded"
|
||||
if route.excludedMiddlewares != nil {
|
||||
for _, m := range middlewares {
|
||||
if m.Id != "" {
|
||||
delete(route.excludedMiddlewares, m.Id)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return route
|
||||
}
|
||||
|
||||
// Unbind removes one or more middlewares with the specified id(s) from the current route.
|
||||
//
|
||||
// It also adds the removed middleware ids to an exclude list so that they could be skipped from
|
||||
// the execution chain in case the middleware is registered in a parent group.
|
||||
//
|
||||
// Anonymous middlewares are considered non-removable, aka. this method
|
||||
// does nothing if the middleware id is an empty string.
|
||||
func (route *Route[T]) Unbind(middlewareIds ...string) *Route[T] {
|
||||
for _, middlewareId := range middlewareIds {
|
||||
if middlewareId == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
// remove from the route's middlewares
|
||||
for i := len(route.Middlewares) - 1; i >= 0; i-- {
|
||||
if route.Middlewares[i].Id == middlewareId {
|
||||
route.Middlewares = append(route.Middlewares[:i], route.Middlewares[i+1:]...)
|
||||
}
|
||||
}
|
||||
|
||||
// add to the exclude list
|
||||
if route.excludedMiddlewares == nil {
|
||||
route.excludedMiddlewares = map[string]struct{}{}
|
||||
}
|
||||
route.excludedMiddlewares[middlewareId] = struct{}{}
|
||||
}
|
||||
|
||||
return route
|
||||
}
|
168
tools/router/route_test.go
Normal file
168
tools/router/route_test.go
Normal file
|
@ -0,0 +1,168 @@
|
|||
package router
|
||||
|
||||
import (
|
||||
"slices"
|
||||
"testing"
|
||||
|
||||
"github.com/pocketbase/pocketbase/tools/hook"
|
||||
)
|
||||
|
||||
func TestRouteBindFunc(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
r := Route[*Event]{}
|
||||
|
||||
calls := ""
|
||||
|
||||
// append one function
|
||||
r.BindFunc(func(e *Event) error {
|
||||
calls += "a"
|
||||
return nil
|
||||
})
|
||||
|
||||
// append multiple functions
|
||||
r.BindFunc(
|
||||
func(e *Event) error {
|
||||
calls += "b"
|
||||
return nil
|
||||
},
|
||||
func(e *Event) error {
|
||||
calls += "c"
|
||||
return nil
|
||||
},
|
||||
)
|
||||
|
||||
if total := len(r.Middlewares); total != 3 {
|
||||
t.Fatalf("Expected %d middlewares, got %v", 3, total)
|
||||
}
|
||||
|
||||
for _, h := range r.Middlewares {
|
||||
_ = h.Func(nil)
|
||||
}
|
||||
|
||||
if calls != "abc" {
|
||||
t.Fatalf("Expected calls sequence %q, got %q", "abc", calls)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouteBind(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
r := Route[*Event]{
|
||||
// mock excluded middlewares to check whether the entry will be deleted
|
||||
excludedMiddlewares: map[string]struct{}{"test2": {}},
|
||||
}
|
||||
|
||||
calls := ""
|
||||
|
||||
// append one handler
|
||||
r.Bind(&hook.Handler[*Event]{
|
||||
Func: func(e *Event) error {
|
||||
calls += "a"
|
||||
return nil
|
||||
},
|
||||
})
|
||||
|
||||
// append multiple handlers
|
||||
r.Bind(
|
||||
&hook.Handler[*Event]{
|
||||
Id: "test1",
|
||||
Func: func(e *Event) error {
|
||||
calls += "b"
|
||||
return nil
|
||||
},
|
||||
},
|
||||
&hook.Handler[*Event]{
|
||||
Id: "test2",
|
||||
Func: func(e *Event) error {
|
||||
calls += "c"
|
||||
return nil
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
if total := len(r.Middlewares); total != 3 {
|
||||
t.Fatalf("Expected %d middlewares, got %v", 3, total)
|
||||
}
|
||||
|
||||
for _, h := range r.Middlewares {
|
||||
_ = h.Func(nil)
|
||||
}
|
||||
|
||||
if calls != "abc" {
|
||||
t.Fatalf("Expected calls %q, got %q", "abc", calls)
|
||||
}
|
||||
|
||||
// ensures that the previously excluded middleware was removed
|
||||
if len(r.excludedMiddlewares) != 0 {
|
||||
t.Fatalf("Expected test2 to be removed from the excludedMiddlewares list, got %v", r.excludedMiddlewares)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouteUnbind(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
r := Route[*Event]{}
|
||||
|
||||
calls := ""
|
||||
|
||||
// anonymous middlewares
|
||||
r.Bind(&hook.Handler[*Event]{
|
||||
Func: func(e *Event) error {
|
||||
calls += "a"
|
||||
return nil // unused value
|
||||
},
|
||||
})
|
||||
|
||||
// middlewares with id
|
||||
r.Bind(&hook.Handler[*Event]{
|
||||
Id: "test1",
|
||||
Func: func(e *Event) error {
|
||||
calls += "b"
|
||||
return nil // unused value
|
||||
},
|
||||
})
|
||||
r.Bind(&hook.Handler[*Event]{
|
||||
Id: "test2",
|
||||
Func: func(e *Event) error {
|
||||
calls += "c"
|
||||
return nil // unused value
|
||||
},
|
||||
})
|
||||
r.Bind(&hook.Handler[*Event]{
|
||||
Id: "test3",
|
||||
Func: func(e *Event) error {
|
||||
calls += "d"
|
||||
return nil // unused value
|
||||
},
|
||||
})
|
||||
|
||||
// remove
|
||||
r.Unbind("") // should be no-op
|
||||
r.Unbind("test1", "test3")
|
||||
|
||||
if total := len(r.Middlewares); total != 2 {
|
||||
t.Fatalf("Expected %d middlewares, got %v", 2, total)
|
||||
}
|
||||
|
||||
for _, h := range r.Middlewares {
|
||||
if err := h.Func(nil); err != nil {
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
if calls != "ac" {
|
||||
t.Fatalf("Expected calls %q, got %q", "ac", calls)
|
||||
}
|
||||
|
||||
// ensure that the id was added in the exclude list
|
||||
excluded := []string{"test1", "test3"}
|
||||
if len(r.excludedMiddlewares) != len(excluded) {
|
||||
t.Fatalf("Expected excludes %v, got %v", excluded, r.excludedMiddlewares)
|
||||
}
|
||||
for id := range r.excludedMiddlewares {
|
||||
if !slices.Contains(excluded, id) {
|
||||
t.Fatalf("Expected %q to be marked as excluded", id)
|
||||
}
|
||||
}
|
||||
}
|
329
tools/router/router.go
Normal file
329
tools/router/router.go
Normal file
|
@ -0,0 +1,329 @@
|
|||
package router
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
|
||||
"github.com/pocketbase/pocketbase/tools/hook"
|
||||
)
|
||||
|
||||
type EventCleanupFunc func()
|
||||
|
||||
// EventFactoryFunc defines the function responsible for creating a Route specific event
|
||||
// based on the provided request handler ServeHTTP data.
|
||||
//
|
||||
// Optionally return a clean up function that will be invoked right after the route execution.
|
||||
type EventFactoryFunc[T hook.Resolver] func(w http.ResponseWriter, r *http.Request) (T, EventCleanupFunc)
|
||||
|
||||
// Router defines a thin wrapper around the standard Go [http.ServeMux] by
|
||||
// adding support for routing sub-groups, middlewares and other common utils.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// r := NewRouter[*MyEvent](eventFactory)
|
||||
//
|
||||
// // middlewares
|
||||
// r.BindFunc(m1, m2)
|
||||
//
|
||||
// // routes
|
||||
// r.GET("/test", handler1)
|
||||
//
|
||||
// // sub-routers/groups
|
||||
// api := r.Group("/api")
|
||||
// api.GET("/admins", handler2)
|
||||
//
|
||||
// // generate a http.ServeMux instance based on the router configurations
|
||||
// mux, _ := r.BuildMux()
|
||||
//
|
||||
// http.ListenAndServe("localhost:8090", mux)
|
||||
type Router[T hook.Resolver] struct {
|
||||
// @todo consider renaming the type to just Group and replace the embed type
|
||||
// with an alias after Go 1.24 adds support for generic type aliases
|
||||
*RouterGroup[T]
|
||||
|
||||
eventFactory EventFactoryFunc[T]
|
||||
}
|
||||
|
||||
// NewRouter creates a new Router instance with the provided event factory function.
|
||||
func NewRouter[T hook.Resolver](eventFactory EventFactoryFunc[T]) *Router[T] {
|
||||
return &Router[T]{
|
||||
RouterGroup: &RouterGroup[T]{},
|
||||
eventFactory: eventFactory,
|
||||
}
|
||||
}
|
||||
|
||||
// BuildMux constructs a new mux [http.Handler] instance from the current router configurations.
|
||||
func (r *Router[T]) BuildMux() (http.Handler, error) {
|
||||
// Note that some of the default std Go handlers like the [http.NotFoundHandler]
|
||||
// cannot be currently extended and requires defining a custom "catch-all" route
|
||||
// so that the group middlewares could be executed.
|
||||
//
|
||||
// https://github.com/golang/go/issues/65648
|
||||
if !r.HasRoute("", "/") {
|
||||
r.Route("", "/", func(e T) error {
|
||||
return NewNotFoundError("", nil)
|
||||
})
|
||||
}
|
||||
|
||||
mux := http.NewServeMux()
|
||||
|
||||
if err := r.loadMux(mux, r.RouterGroup, nil); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return mux, nil
|
||||
}
|
||||
|
||||
func (r *Router[T]) loadMux(mux *http.ServeMux, group *RouterGroup[T], parents []*RouterGroup[T]) error {
|
||||
for _, child := range group.children {
|
||||
switch v := child.(type) {
|
||||
case *RouterGroup[T]:
|
||||
if err := r.loadMux(mux, v, append(parents, group)); err != nil {
|
||||
return err
|
||||
}
|
||||
case *Route[T]:
|
||||
routeHook := &hook.Hook[T]{}
|
||||
|
||||
var pattern string
|
||||
|
||||
if v.Method != "" {
|
||||
pattern = v.Method + " "
|
||||
}
|
||||
|
||||
// add parent groups middlewares
|
||||
for _, p := range parents {
|
||||
pattern += p.Prefix
|
||||
for _, h := range p.Middlewares {
|
||||
if _, ok := p.excludedMiddlewares[h.Id]; !ok {
|
||||
if _, ok = group.excludedMiddlewares[h.Id]; !ok {
|
||||
if _, ok = v.excludedMiddlewares[h.Id]; !ok {
|
||||
routeHook.Bind(h)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// add current groups middlewares
|
||||
pattern += group.Prefix
|
||||
for _, h := range group.Middlewares {
|
||||
if _, ok := group.excludedMiddlewares[h.Id]; !ok {
|
||||
if _, ok = v.excludedMiddlewares[h.Id]; !ok {
|
||||
routeHook.Bind(h)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// add current route middlewares
|
||||
pattern += v.Path
|
||||
for _, h := range v.Middlewares {
|
||||
if _, ok := v.excludedMiddlewares[h.Id]; !ok {
|
||||
routeHook.Bind(h)
|
||||
}
|
||||
}
|
||||
|
||||
mux.HandleFunc(pattern, func(resp http.ResponseWriter, req *http.Request) {
|
||||
// wrap the response to add write and status tracking
|
||||
resp = &ResponseWriter{ResponseWriter: resp}
|
||||
|
||||
// wrap the request body to allow multiple reads
|
||||
req.Body = &RereadableReadCloser{ReadCloser: req.Body}
|
||||
|
||||
event, cleanupFunc := r.eventFactory(resp, req)
|
||||
|
||||
// trigger the handler hook chain
|
||||
err := routeHook.Trigger(event, v.Action)
|
||||
if err != nil {
|
||||
ErrorHandler(resp, req, err)
|
||||
}
|
||||
|
||||
if cleanupFunc != nil {
|
||||
cleanupFunc()
|
||||
}
|
||||
})
|
||||
default:
|
||||
return errors.New("invalid Group item type")
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func ErrorHandler(resp http.ResponseWriter, req *http.Request, err error) {
|
||||
if err == nil {
|
||||
return
|
||||
}
|
||||
|
||||
if ok, _ := getWritten(resp); ok {
|
||||
return // a response was already written (aka. already handled)
|
||||
}
|
||||
|
||||
header := resp.Header()
|
||||
if header.Get("Content-Type") == "" {
|
||||
header.Set("Content-Type", "application/json")
|
||||
}
|
||||
|
||||
apiErr := ToApiError(err)
|
||||
|
||||
resp.WriteHeader(apiErr.Status)
|
||||
|
||||
if req.Method != http.MethodHead {
|
||||
if jsonErr := json.NewEncoder(resp).Encode(apiErr); jsonErr != nil {
|
||||
log.Println(jsonErr) // truly rare case, log to stderr only for dev purposes
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
|
||||
type WriteTracker interface {
|
||||
// Written reports whether a write operation has occurred.
|
||||
Written() bool
|
||||
}
|
||||
|
||||
type StatusTracker interface {
|
||||
// Status reports the written response status code.
|
||||
Status() int
|
||||
}
|
||||
|
||||
type flushErrorer interface {
|
||||
FlushError() error
|
||||
}
|
||||
|
||||
var (
|
||||
_ WriteTracker = (*ResponseWriter)(nil)
|
||||
_ StatusTracker = (*ResponseWriter)(nil)
|
||||
_ http.Flusher = (*ResponseWriter)(nil)
|
||||
_ http.Hijacker = (*ResponseWriter)(nil)
|
||||
_ http.Pusher = (*ResponseWriter)(nil)
|
||||
_ io.ReaderFrom = (*ResponseWriter)(nil)
|
||||
_ flushErrorer = (*ResponseWriter)(nil)
|
||||
)
|
||||
|
||||
// ResponseWriter wraps a http.ResponseWriter to track its write state.
|
||||
type ResponseWriter struct {
|
||||
http.ResponseWriter
|
||||
|
||||
written bool
|
||||
status int
|
||||
}
|
||||
|
||||
func (rw *ResponseWriter) WriteHeader(status int) {
|
||||
if rw.written {
|
||||
return
|
||||
}
|
||||
|
||||
rw.written = true
|
||||
rw.status = status
|
||||
rw.ResponseWriter.WriteHeader(status)
|
||||
}
|
||||
|
||||
func (rw *ResponseWriter) Write(b []byte) (int, error) {
|
||||
if !rw.written {
|
||||
rw.WriteHeader(http.StatusOK)
|
||||
}
|
||||
|
||||
return rw.ResponseWriter.Write(b)
|
||||
}
|
||||
|
||||
// Written implements [WriteTracker] and returns whether the current response body has been already written.
|
||||
func (rw *ResponseWriter) Written() bool {
|
||||
return rw.written
|
||||
}
|
||||
|
||||
// Written implements [StatusTracker] and returns the written status code of the current response.
|
||||
func (rw *ResponseWriter) Status() int {
|
||||
return rw.status
|
||||
}
|
||||
|
||||
// Flush implements [http.Flusher] and allows an HTTP handler to flush buffered data to the client.
|
||||
// This method is no-op if the wrapped writer doesn't support it.
|
||||
func (rw *ResponseWriter) Flush() {
|
||||
_ = rw.FlushError()
|
||||
}
|
||||
|
||||
// FlushError is similar to [Flush] but returns [http.ErrNotSupported]
|
||||
// if the wrapped writer doesn't support it.
|
||||
func (rw *ResponseWriter) FlushError() error {
|
||||
err := http.NewResponseController(rw.ResponseWriter).Flush()
|
||||
if err == nil || !errors.Is(err, http.ErrNotSupported) {
|
||||
rw.written = true
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// Hijack implements [http.Hijacker] and allows an HTTP handler to take over the current connection.
|
||||
func (rw *ResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||
return http.NewResponseController(rw.ResponseWriter).Hijack()
|
||||
}
|
||||
|
||||
// Pusher implements [http.Pusher] to indicate HTTP/2 server push support.
|
||||
func (rw *ResponseWriter) Push(target string, opts *http.PushOptions) error {
|
||||
w := rw.ResponseWriter
|
||||
for {
|
||||
switch p := w.(type) {
|
||||
case http.Pusher:
|
||||
return p.Push(target, opts)
|
||||
case RWUnwrapper:
|
||||
w = p.Unwrap()
|
||||
default:
|
||||
return http.ErrNotSupported
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ReaderFrom implements [io.ReaderFrom] by checking if the underlying writer supports it.
|
||||
// Otherwise calls [io.Copy].
|
||||
func (rw *ResponseWriter) ReadFrom(r io.Reader) (n int64, err error) {
|
||||
if !rw.written {
|
||||
rw.WriteHeader(http.StatusOK)
|
||||
}
|
||||
|
||||
w := rw.ResponseWriter
|
||||
for {
|
||||
switch rf := w.(type) {
|
||||
case io.ReaderFrom:
|
||||
return rf.ReadFrom(r)
|
||||
case RWUnwrapper:
|
||||
w = rf.Unwrap()
|
||||
default:
|
||||
return io.Copy(rw.ResponseWriter, r)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Unwrap returns the underlying ResponseWritter instance (usually used by [http.ResponseController]).
|
||||
func (rw *ResponseWriter) Unwrap() http.ResponseWriter {
|
||||
return rw.ResponseWriter
|
||||
}
|
||||
|
||||
func getWritten(rw http.ResponseWriter) (bool, error) {
|
||||
for {
|
||||
switch w := rw.(type) {
|
||||
case WriteTracker:
|
||||
return w.Written(), nil
|
||||
case RWUnwrapper:
|
||||
rw = w.Unwrap()
|
||||
default:
|
||||
return false, http.ErrNotSupported
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func getStatus(rw http.ResponseWriter) (int, error) {
|
||||
for {
|
||||
switch w := rw.(type) {
|
||||
case StatusTracker:
|
||||
return w.Status(), nil
|
||||
case RWUnwrapper:
|
||||
rw = w.Unwrap()
|
||||
default:
|
||||
return 0, http.ErrNotSupported
|
||||
}
|
||||
}
|
||||
}
|
253
tools/router/router_test.go
Normal file
253
tools/router/router_test.go
Normal file
|
@ -0,0 +1,253 @@
|
|||
package router_test
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/pocketbase/pocketbase/tools/hook"
|
||||
"github.com/pocketbase/pocketbase/tools/router"
|
||||
)
|
||||
|
||||
func TestRouter(t *testing.T) {
|
||||
calls := ""
|
||||
|
||||
r := router.NewRouter(func(w http.ResponseWriter, r *http.Request) (*router.Event, router.EventCleanupFunc) {
|
||||
return &router.Event{
|
||||
Response: w,
|
||||
Request: r,
|
||||
},
|
||||
func() {
|
||||
calls += ":cleanup"
|
||||
}
|
||||
})
|
||||
|
||||
r.BindFunc(func(e *router.Event) error {
|
||||
calls += "root_m:"
|
||||
|
||||
err := e.Next()
|
||||
|
||||
if err != nil {
|
||||
calls += "/error"
|
||||
}
|
||||
|
||||
return err
|
||||
})
|
||||
|
||||
r.Any("/any", func(e *router.Event) error {
|
||||
calls += "/any"
|
||||
return nil
|
||||
})
|
||||
|
||||
r.GET("/a", func(e *router.Event) error {
|
||||
calls += "/a"
|
||||
return nil
|
||||
})
|
||||
|
||||
g1 := r.Group("/a/b").BindFunc(func(e *router.Event) error {
|
||||
calls += "a_b_group_m:"
|
||||
return e.Next()
|
||||
})
|
||||
g1.GET("/1", func(e *router.Event) error {
|
||||
calls += "/1_get"
|
||||
return nil
|
||||
}).BindFunc(func(e *router.Event) error {
|
||||
calls += "1_get_m:"
|
||||
return e.Next()
|
||||
})
|
||||
g1.POST("/1", func(e *router.Event) error {
|
||||
calls += "/1_post"
|
||||
return nil
|
||||
})
|
||||
g1.GET("/{param}", func(e *router.Event) error {
|
||||
calls += "/" + e.Request.PathValue("param")
|
||||
return errors.New("test") // should be normalized to an ApiError
|
||||
})
|
||||
|
||||
mux, err := r.BuildMux()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
ts := httptest.NewServer(mux)
|
||||
defer ts.Close()
|
||||
|
||||
client := ts.Client()
|
||||
|
||||
scenarios := []struct {
|
||||
method string
|
||||
path string
|
||||
calls string
|
||||
}{
|
||||
{http.MethodGet, "/any", "root_m:/any:cleanup"},
|
||||
{http.MethodOptions, "/any", "root_m:/any:cleanup"},
|
||||
{http.MethodPatch, "/any", "root_m:/any:cleanup"},
|
||||
{http.MethodPut, "/any", "root_m:/any:cleanup"},
|
||||
{http.MethodPost, "/any", "root_m:/any:cleanup"},
|
||||
{http.MethodDelete, "/any", "root_m:/any:cleanup"},
|
||||
// ---
|
||||
{http.MethodPost, "/a", "root_m:/error:cleanup"}, // missing
|
||||
{http.MethodGet, "/a", "root_m:/a:cleanup"},
|
||||
{http.MethodHead, "/a", "root_m:/a:cleanup"}, // auto registered with the GET
|
||||
{http.MethodGet, "/a/b/1", "root_m:a_b_group_m:1_get_m:/1_get:cleanup"},
|
||||
{http.MethodHead, "/a/b/1", "root_m:a_b_group_m:1_get_m:/1_get:cleanup"},
|
||||
{http.MethodPost, "/a/b/1", "root_m:a_b_group_m:/1_post:cleanup"},
|
||||
{http.MethodGet, "/a/b/456", "root_m:a_b_group_m:/456/error:cleanup"},
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
t.Run(s.method+"_"+s.path, func(t *testing.T) {
|
||||
calls = "" // reset
|
||||
|
||||
req, err := http.NewRequest(s.method, ts.URL+s.path, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
_, err = client.Do(req)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if calls != s.calls {
|
||||
t.Fatalf("Expected calls\n%q\ngot\n%q", s.calls, calls)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouterUnbind(t *testing.T) {
|
||||
calls := ""
|
||||
|
||||
r := router.NewRouter(func(w http.ResponseWriter, r *http.Request) (*router.Event, router.EventCleanupFunc) {
|
||||
return &router.Event{
|
||||
Response: w,
|
||||
Request: r,
|
||||
},
|
||||
func() {
|
||||
calls += ":cleanup"
|
||||
}
|
||||
})
|
||||
r.Bind(&hook.Handler[*router.Event]{
|
||||
Id: "root_1",
|
||||
Func: func(e *router.Event) error {
|
||||
calls += "root_1:"
|
||||
return e.Next()
|
||||
},
|
||||
})
|
||||
r.Bind(&hook.Handler[*router.Event]{
|
||||
Id: "root_2",
|
||||
Func: func(e *router.Event) error {
|
||||
calls += "root_2:"
|
||||
return e.Next()
|
||||
},
|
||||
})
|
||||
r.Bind(&hook.Handler[*router.Event]{
|
||||
Id: "root_3",
|
||||
Func: func(e *router.Event) error {
|
||||
calls += "root_3:"
|
||||
return e.Next()
|
||||
},
|
||||
})
|
||||
r.GET("/action", func(e *router.Event) error {
|
||||
calls += "root_action"
|
||||
return nil
|
||||
}).Unbind("root_1")
|
||||
|
||||
ga := r.Group("/group_a")
|
||||
ga.Unbind("root_1")
|
||||
ga.Bind(&hook.Handler[*router.Event]{
|
||||
Id: "group_a_1",
|
||||
Func: func(e *router.Event) error {
|
||||
calls += "group_a_1:"
|
||||
return e.Next()
|
||||
},
|
||||
})
|
||||
ga.Bind(&hook.Handler[*router.Event]{
|
||||
Id: "group_a_2",
|
||||
Func: func(e *router.Event) error {
|
||||
calls += "group_a_2:"
|
||||
return e.Next()
|
||||
},
|
||||
})
|
||||
ga.Bind(&hook.Handler[*router.Event]{
|
||||
Id: "group_a_3",
|
||||
Func: func(e *router.Event) error {
|
||||
calls += "group_a_3:"
|
||||
return e.Next()
|
||||
},
|
||||
})
|
||||
ga.GET("/action", func(e *router.Event) error {
|
||||
calls += "group_a_action"
|
||||
return nil
|
||||
}).Unbind("root_2", "group_b_1", "group_a_1")
|
||||
|
||||
gb := r.Group("/group_b")
|
||||
gb.Unbind("root_2")
|
||||
gb.Bind(&hook.Handler[*router.Event]{
|
||||
Id: "group_b_1",
|
||||
Func: func(e *router.Event) error {
|
||||
calls += "group_b_1:"
|
||||
return e.Next()
|
||||
},
|
||||
})
|
||||
gb.Bind(&hook.Handler[*router.Event]{
|
||||
Id: "group_b_2",
|
||||
Func: func(e *router.Event) error {
|
||||
calls += "group_b_2:"
|
||||
return e.Next()
|
||||
},
|
||||
})
|
||||
gb.Bind(&hook.Handler[*router.Event]{
|
||||
Id: "group_b_3",
|
||||
Func: func(e *router.Event) error {
|
||||
calls += "group_b_3:"
|
||||
return e.Next()
|
||||
},
|
||||
})
|
||||
gb.GET("/action", func(e *router.Event) error {
|
||||
calls += "group_b_action"
|
||||
return nil
|
||||
}).Unbind("group_b_3", "group_a_3", "root_3")
|
||||
|
||||
mux, err := r.BuildMux()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
ts := httptest.NewServer(mux)
|
||||
defer ts.Close()
|
||||
|
||||
client := ts.Client()
|
||||
|
||||
scenarios := []struct {
|
||||
method string
|
||||
path string
|
||||
calls string
|
||||
}{
|
||||
{http.MethodGet, "/action", "root_2:root_3:root_action:cleanup"},
|
||||
{http.MethodGet, "/group_a/action", "root_3:group_a_2:group_a_3:group_a_action:cleanup"},
|
||||
{http.MethodGet, "/group_b/action", "root_1:group_b_1:group_b_2:group_b_action:cleanup"},
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
t.Run(s.method+"_"+s.path, func(t *testing.T) {
|
||||
calls = "" // reset
|
||||
|
||||
req, err := http.NewRequest(s.method, ts.URL+s.path, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
_, err = client.Do(req)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if calls != s.calls {
|
||||
t.Fatalf("Expected calls\n%q\ngot\n%q", s.calls, calls)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
338
tools/router/unmarshal_request_data.go
Normal file
338
tools/router/unmarshal_request_data.go
Normal file
|
@ -0,0 +1,338 @@
|
|||
package router
|
||||
|
||||
import (
|
||||
"encoding"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"reflect"
|
||||
"regexp"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
var textUnmarshalerType = reflect.TypeOf((*encoding.TextUnmarshaler)(nil)).Elem()
|
||||
|
||||
// JSONPayloadKey is the key for the special UnmarshalRequestData case
|
||||
// used for reading serialized json payload without normalization.
|
||||
const JSONPayloadKey string = "@jsonPayload"
|
||||
|
||||
// UnmarshalRequestData unmarshals url.Values type of data (query, multipart/form-data, etc.) into dst.
|
||||
//
|
||||
// dst must be a pointer to a map[string]any or struct.
|
||||
//
|
||||
// If dst is a map[string]any, each data value will be inferred and
|
||||
// converted to its bool, numeric, or string equivalent value
|
||||
// (refer to inferValue() for the exact rules).
|
||||
//
|
||||
// If dst is a struct, the following field types are supported:
|
||||
// - bool
|
||||
// - string
|
||||
// - int, int8, int16, int32, int64
|
||||
// - uint, uint8, uint16, uint32, uint64
|
||||
// - float32, float64
|
||||
// - serialized json string if submitted under the special "@jsonPayload" key
|
||||
// - encoding.TextUnmarshaler
|
||||
// - pointer and slice variations of the above primitives (ex. *string, []string, *[]string []*string, etc.)
|
||||
// - named/anonymous struct fields
|
||||
// Dot-notation is used to target nested fields, ex. "nestedStructField.title".
|
||||
// - embedded struct fields
|
||||
// The embedded struct fields are treated by default as if they were defined in their parent struct.
|
||||
// If the embedded struct has a tag matching structTagKey then to set its fields the data keys must be prefixed with that tag
|
||||
// similar to the regular nested struct fields.
|
||||
//
|
||||
// structTagKey and structPrefix are used only when dst is a struct.
|
||||
//
|
||||
// structTagKey represents the tag to use to match a data entry with a struct field (defaults to "form").
|
||||
// If the struct field doesn't have the structTagKey tag, then the exported struct field name will be used as it is.
|
||||
//
|
||||
// structPrefix could be provided if all of the data keys are prefixed with a common string
|
||||
// and you want the struct field to match only the value without the structPrefix
|
||||
// (ex. for "user.name", "user.email" data keys and structPrefix "user", it will match "name" and "email" struct fields).
|
||||
//
|
||||
// Note that while the method was inspired by binders from echo, gorrila/schema, ozzo-routing
|
||||
// and other similar common routing packages, it is not intended to be a drop-in replacement.
|
||||
//
|
||||
// @todo Consider adding support for dot-notation keys, in addition to the prefix, (ex. parent.child.title) to express nested object keys.
|
||||
func UnmarshalRequestData(data map[string][]string, dst any, structTagKey string, structPrefix string) error {
|
||||
if len(data) == 0 {
|
||||
return nil // nothing to unmarshal
|
||||
}
|
||||
|
||||
dstValue := reflect.ValueOf(dst)
|
||||
if dstValue.Kind() != reflect.Pointer {
|
||||
return errors.New("dst must be a pointer")
|
||||
}
|
||||
|
||||
dstValue = dereference(dstValue)
|
||||
|
||||
dstType := dstValue.Type()
|
||||
|
||||
switch dstType.Kind() {
|
||||
case reflect.Map: // map[string]any
|
||||
if dstType.Elem().Kind() != reflect.Interface {
|
||||
return errors.New("dst map value type must be any/interface{}")
|
||||
}
|
||||
|
||||
for k, v := range data {
|
||||
if k == JSONPayloadKey {
|
||||
continue // unmarshaled separately
|
||||
}
|
||||
|
||||
total := len(v)
|
||||
|
||||
if total == 1 {
|
||||
dstValue.SetMapIndex(reflect.ValueOf(k), reflect.ValueOf(inferValue(v[0])))
|
||||
} else {
|
||||
normalized := make([]any, total)
|
||||
for i, vItem := range v {
|
||||
normalized[i] = inferValue(vItem)
|
||||
}
|
||||
dstValue.SetMapIndex(reflect.ValueOf(k), reflect.ValueOf(normalized))
|
||||
}
|
||||
}
|
||||
case reflect.Struct:
|
||||
// set a default tag key
|
||||
if structTagKey == "" {
|
||||
structTagKey = "form"
|
||||
}
|
||||
|
||||
err := unmarshalInStructValue(data, dstValue, structTagKey, structPrefix)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
default:
|
||||
return errors.New("dst must be a map[string]any or struct")
|
||||
}
|
||||
|
||||
// @jsonPayload
|
||||
//
|
||||
// Special case to scan serialized json string without
|
||||
// normalization alongside the other data values
|
||||
// ---------------------------------------------------------------
|
||||
jsonPayloadValues := data[JSONPayloadKey]
|
||||
for _, payload := range jsonPayloadValues {
|
||||
if err := json.Unmarshal([]byte(payload), dst); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// unmarshalInStructValue unmarshals data into the provided struct reflect.Value fields.
|
||||
func unmarshalInStructValue(
|
||||
data map[string][]string,
|
||||
dstStructValue reflect.Value,
|
||||
structTagKey string,
|
||||
structPrefix string,
|
||||
) error {
|
||||
dstStructType := dstStructValue.Type()
|
||||
|
||||
for i := 0; i < dstStructValue.NumField(); i++ {
|
||||
fieldType := dstStructType.Field(i)
|
||||
|
||||
tag := fieldType.Tag.Get(structTagKey)
|
||||
|
||||
if tag == "-" || (!fieldType.Anonymous && !fieldType.IsExported()) {
|
||||
continue // disabled or unexported non-anonymous struct field
|
||||
}
|
||||
|
||||
fieldValue := dereference(dstStructValue.Field(i))
|
||||
|
||||
ft := fieldType.Type
|
||||
if ft.Kind() == reflect.Ptr {
|
||||
ft = ft.Elem()
|
||||
}
|
||||
|
||||
isSlice := ft.Kind() == reflect.Slice
|
||||
if isSlice {
|
||||
ft = ft.Elem()
|
||||
}
|
||||
|
||||
name := tag
|
||||
if name == "" && !fieldType.Anonymous {
|
||||
name = fieldType.Name
|
||||
}
|
||||
if name != "" && structPrefix != "" {
|
||||
name = structPrefix + "." + name
|
||||
}
|
||||
|
||||
// (*)encoding.TextUnmarshaler field
|
||||
// ---
|
||||
if ft.Implements(textUnmarshalerType) || reflect.PointerTo(ft).Implements(textUnmarshalerType) {
|
||||
values, ok := data[name]
|
||||
if !ok || len(values) == 0 || !fieldValue.CanSet() {
|
||||
continue // no value to load or the field cannot be set
|
||||
}
|
||||
|
||||
if isSlice {
|
||||
n := len(values)
|
||||
slice := reflect.MakeSlice(fieldValue.Type(), n, n)
|
||||
for i, v := range values {
|
||||
unmarshaler, ok := dereference(slice.Index(i)).Addr().Interface().(encoding.TextUnmarshaler)
|
||||
if ok {
|
||||
if err := unmarshaler.UnmarshalText([]byte(v)); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
fieldValue.Set(slice)
|
||||
} else {
|
||||
unmarshaler, ok := fieldValue.Addr().Interface().(encoding.TextUnmarshaler)
|
||||
if ok {
|
||||
if err := unmarshaler.UnmarshalText([]byte(values[0])); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// "regular" field
|
||||
// ---
|
||||
if ft.Kind() != reflect.Struct {
|
||||
values, ok := data[name]
|
||||
if !ok || len(values) == 0 || !fieldValue.CanSet() {
|
||||
continue // no value to load
|
||||
}
|
||||
|
||||
if isSlice {
|
||||
n := len(values)
|
||||
slice := reflect.MakeSlice(fieldValue.Type(), n, n)
|
||||
for i, v := range values {
|
||||
if err := setRegularReflectedValue(dereference(slice.Index(i)), v); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
fieldValue.Set(slice)
|
||||
} else {
|
||||
if err := setRegularReflectedValue(fieldValue, values[0]); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// structs (embedded or nested)
|
||||
// ---
|
||||
// slice of structs
|
||||
if isSlice {
|
||||
// populating slice of structs is not supported at the moment
|
||||
// because the filling rules are ambiguous
|
||||
continue
|
||||
}
|
||||
|
||||
if tag != "" {
|
||||
structPrefix = tag
|
||||
} else {
|
||||
structPrefix = name // name is empty for anonymous structs -> no prefix
|
||||
}
|
||||
|
||||
if err := unmarshalInStructValue(data, fieldValue, structTagKey, structPrefix); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// dereference returns the underlying value v points to.
|
||||
func dereference(v reflect.Value) reflect.Value {
|
||||
for v.Kind() == reflect.Ptr {
|
||||
if v.IsNil() {
|
||||
// initialize with a new value and continue searching
|
||||
v.Set(reflect.New(v.Type().Elem()))
|
||||
}
|
||||
v = v.Elem()
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
// setRegularReflectedValue sets and casts value into rv.
|
||||
func setRegularReflectedValue(rv reflect.Value, value string) error {
|
||||
switch rv.Kind() {
|
||||
case reflect.String:
|
||||
rv.SetString(value)
|
||||
case reflect.Bool:
|
||||
if value == "" {
|
||||
value = "f"
|
||||
}
|
||||
|
||||
v, err := strconv.ParseBool(value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
rv.SetBool(v)
|
||||
case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int:
|
||||
if value == "" {
|
||||
value = "0"
|
||||
}
|
||||
|
||||
v, err := strconv.ParseInt(value, 0, 64)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
rv.SetInt(v)
|
||||
case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint:
|
||||
if value == "" {
|
||||
value = "0"
|
||||
}
|
||||
|
||||
v, err := strconv.ParseUint(value, 0, 64)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
rv.SetUint(v)
|
||||
case reflect.Float32, reflect.Float64:
|
||||
if value == "" {
|
||||
value = "0"
|
||||
}
|
||||
|
||||
v, err := strconv.ParseFloat(value, 64)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
rv.SetFloat(v)
|
||||
default:
|
||||
return errors.New("unknown value type " + rv.Kind().String())
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
var inferNumberCharsRegex = regexp.MustCompile(`^[\-\.\d]+$`)
|
||||
|
||||
// In order to support more seamlessly both json and multipart/form-data requests,
|
||||
// the following normalization rules are applied for plain multipart string values:
|
||||
// - "true" is converted to the json "true"
|
||||
// - "false" is converted to the json "false"
|
||||
// - numeric strings are converted to json number ONLY if the resulted
|
||||
// minimal number string representation is the same as the provided raw string
|
||||
// (aka. scientific notations, "Infinity", "0.0", "0001", etc. are kept as string)
|
||||
// - any other string (empty string too) is left as it is
|
||||
func inferValue(raw string) any {
|
||||
switch raw {
|
||||
case "":
|
||||
return raw
|
||||
case "true":
|
||||
return true
|
||||
case "false":
|
||||
return false
|
||||
default:
|
||||
// try to convert to number
|
||||
//
|
||||
// note: expects the provided raw string to match exactly with the minimal string representation of the parsed float
|
||||
if (raw[0] == '-' || (raw[0] >= '0' && raw[0] <= '9')) &&
|
||||
inferNumberCharsRegex.Match([]byte(raw)) {
|
||||
v, err := strconv.ParseFloat(raw, 64)
|
||||
if err == nil && strconv.FormatFloat(v, 'f', -1, 64) == raw {
|
||||
return v
|
||||
}
|
||||
}
|
||||
|
||||
return raw
|
||||
}
|
||||
}
|
471
tools/router/unmarshal_request_data_test.go
Normal file
471
tools/router/unmarshal_request_data_test.go
Normal file
|
@ -0,0 +1,471 @@
|
|||
package router_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/pocketbase/pocketbase/tools/router"
|
||||
)
|
||||
|
||||
func pointer[T any](val T) *T {
|
||||
return &val
|
||||
}
|
||||
|
||||
func TestUnmarshalRequestData(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mapData := map[string][]string{
|
||||
"number1": {"1"},
|
||||
"number2": {"2", "3"},
|
||||
"number3": {"2.1", "-3.4"},
|
||||
"number4": {"0", "-0", "0.0001"},
|
||||
"string0": {""},
|
||||
"string1": {"a"},
|
||||
"string2": {"b", "c"},
|
||||
"string3": {
|
||||
"0.0",
|
||||
"-0.0",
|
||||
"000.1",
|
||||
"000001",
|
||||
"-000001",
|
||||
"1.6E-35",
|
||||
"-1.6E-35",
|
||||
"10e100",
|
||||
"1_000_000",
|
||||
"1.000.000",
|
||||
" 123 ",
|
||||
"0b1",
|
||||
"0xFF",
|
||||
"1234A",
|
||||
"Infinity",
|
||||
"-Infinity",
|
||||
"undefined",
|
||||
"null",
|
||||
},
|
||||
"bool1": {"true"},
|
||||
"bool2": {"true", "false"},
|
||||
"mixed": {"true", "123", "test"},
|
||||
"@jsonPayload": {`{"json_a":null,"json_b":123}`, `{"json_c":[1,2,3]}`},
|
||||
}
|
||||
|
||||
structData := map[string][]string{
|
||||
"stringTag": {"a", "b"},
|
||||
"StringPtr": {"b"},
|
||||
"StringSlice": {"a", "b", "c", ""},
|
||||
"stringSlicePtrTag": {"d", "e"},
|
||||
"StringSliceOfPtr": {"f", "g"},
|
||||
|
||||
"boolTag": {"true"},
|
||||
"BoolPtr": {"true"},
|
||||
"BoolSlice": {"true", "false", ""},
|
||||
"boolSlicePtrTag": {"false", "false", "true"},
|
||||
"BoolSliceOfPtr": {"false", "true", "false"},
|
||||
|
||||
"int8Tag": {"-1", "2"},
|
||||
"Int8Ptr": {"3"},
|
||||
"Int8Slice": {"4", "5", ""},
|
||||
"int8SlicePtrTag": {"5", "6"},
|
||||
"Int8SliceOfPtr": {"7", "8"},
|
||||
|
||||
"int16Tag": {"-1", "2"},
|
||||
"Int16Ptr": {"3"},
|
||||
"Int16Slice": {"4", "5", ""},
|
||||
"int16SlicePtrTag": {"5", "6"},
|
||||
"Int16SliceOfPtr": {"7", "8"},
|
||||
|
||||
"int32Tag": {"-1", "2"},
|
||||
"Int32Ptr": {"3"},
|
||||
"Int32Slice": {"4", "5", ""},
|
||||
"int32SlicePtrTag": {"5", "6"},
|
||||
"Int32SliceOfPtr": {"7", "8"},
|
||||
|
||||
"int64Tag": {"-1", "2"},
|
||||
"Int64Ptr": {"3"},
|
||||
"Int64Slice": {"4", "5", ""},
|
||||
"int64SlicePtrTag": {"5", "6"},
|
||||
"Int64SliceOfPtr": {"7", "8"},
|
||||
|
||||
"intTag": {"-1", "2"},
|
||||
"IntPtr": {"3"},
|
||||
"IntSlice": {"4", "5", ""},
|
||||
"intSlicePtrTag": {"5", "6"},
|
||||
"IntSliceOfPtr": {"7", "8"},
|
||||
|
||||
"uint8Tag": {"1", "2"},
|
||||
"Uint8Ptr": {"3"},
|
||||
"Uint8Slice": {"4", "5", ""},
|
||||
"uint8SlicePtrTag": {"5", "6"},
|
||||
"Uint8SliceOfPtr": {"7", "8"},
|
||||
|
||||
"uint16Tag": {"1", "2"},
|
||||
"Uint16Ptr": {"3"},
|
||||
"Uint16Slice": {"4", "5", ""},
|
||||
"uint16SlicePtrTag": {"5", "6"},
|
||||
"Uint16SliceOfPtr": {"7", "8"},
|
||||
|
||||
"uint32Tag": {"1", "2"},
|
||||
"Uint32Ptr": {"3"},
|
||||
"Uint32Slice": {"4", "5", ""},
|
||||
"uint32SlicePtrTag": {"5", "6"},
|
||||
"Uint32SliceOfPtr": {"7", "8"},
|
||||
|
||||
"uint64Tag": {"1", "2"},
|
||||
"Uint64Ptr": {"3"},
|
||||
"Uint64Slice": {"4", "5", ""},
|
||||
"uint64SlicePtrTag": {"5", "6"},
|
||||
"Uint64SliceOfPtr": {"7", "8"},
|
||||
|
||||
"uintTag": {"1", "2"},
|
||||
"UintPtr": {"3"},
|
||||
"UintSlice": {"4", "5", ""},
|
||||
"uintSlicePtrTag": {"5", "6"},
|
||||
"UintSliceOfPtr": {"7", "8"},
|
||||
|
||||
"float32Tag": {"-1.2"},
|
||||
"Float32Ptr": {"1.5", "2.0"},
|
||||
"Float32Slice": {"1", "2.3", "-0.3", ""},
|
||||
"float32SlicePtrTag": {"-1.3", "3"},
|
||||
"Float32SliceOfPtr": {"0", "1.2"},
|
||||
|
||||
"float64Tag": {"-1.2"},
|
||||
"Float64Ptr": {"1.5", "2.0"},
|
||||
"Float64Slice": {"1", "2.3", "-0.3", ""},
|
||||
"float64SlicePtrTag": {"-1.3", "3"},
|
||||
"Float64SliceOfPtr": {"0", "1.2"},
|
||||
|
||||
"timeTag": {"2009-11-10T15:00:00Z"},
|
||||
"TimePtr": {"2009-11-10T14:00:00Z", "2009-11-10T15:00:00Z"},
|
||||
"TimeSlice": {"2009-11-10T14:00:00Z", "2009-11-10T15:00:00Z"},
|
||||
"timeSlicePtrTag": {"2009-11-10T15:00:00Z", "2009-11-10T16:00:00Z"},
|
||||
"TimeSliceOfPtr": {"2009-11-10T17:00:00Z", "2009-11-10T18:00:00Z"},
|
||||
|
||||
// @jsonPayload fields
|
||||
"@jsonPayload": {
|
||||
`{"payloadA":"test", "shouldBeIgnored": "abc"}`,
|
||||
`{"payloadB":[1,2,3], "payloadC":true}`,
|
||||
},
|
||||
|
||||
// unexported fields or `-` tags
|
||||
"unexperted": {"test"},
|
||||
"SkipExported": {"test"},
|
||||
"unexportedStructFieldWithoutTag.Name": {"test"},
|
||||
"unexportedStruct.Name": {"test"},
|
||||
|
||||
// structs
|
||||
"StructWithoutTag.Name": {"test1"},
|
||||
"exportedStruct.Name": {"test2"},
|
||||
|
||||
// embedded
|
||||
"embed_name": {"test3"},
|
||||
"embed2.embed_name2": {"test4"},
|
||||
}
|
||||
|
||||
type embed1 struct {
|
||||
Name string `form:"embed_name" json:"embed_name"`
|
||||
}
|
||||
|
||||
type embed2 struct {
|
||||
Name string `form:"embed_name2" json:"embed_name2"`
|
||||
}
|
||||
|
||||
//nolint
|
||||
type TestStruct struct {
|
||||
String string `form:"stringTag" query:"stringTag2"`
|
||||
StringPtr *string
|
||||
StringSlice []string
|
||||
StringSlicePtr *[]string `form:"stringSlicePtrTag"`
|
||||
StringSliceOfPtr []*string
|
||||
|
||||
Bool bool `form:"boolTag" query:"boolTag2"`
|
||||
BoolPtr *bool
|
||||
BoolSlice []bool
|
||||
BoolSlicePtr *[]bool `form:"boolSlicePtrTag"`
|
||||
BoolSliceOfPtr []*bool
|
||||
|
||||
Int8 int8 `form:"int8Tag" query:"int8Tag2"`
|
||||
Int8Ptr *int8
|
||||
Int8Slice []int8
|
||||
Int8SlicePtr *[]int8 `form:"int8SlicePtrTag"`
|
||||
Int8SliceOfPtr []*int8
|
||||
|
||||
Int16 int16 `form:"int16Tag" query:"int16Tag2"`
|
||||
Int16Ptr *int16
|
||||
Int16Slice []int16
|
||||
Int16SlicePtr *[]int16 `form:"int16SlicePtrTag"`
|
||||
Int16SliceOfPtr []*int16
|
||||
|
||||
Int32 int32 `form:"int32Tag" query:"int32Tag2"`
|
||||
Int32Ptr *int32
|
||||
Int32Slice []int32
|
||||
Int32SlicePtr *[]int32 `form:"int32SlicePtrTag"`
|
||||
Int32SliceOfPtr []*int32
|
||||
|
||||
Int64 int64 `form:"int64Tag" query:"int64Tag2"`
|
||||
Int64Ptr *int64
|
||||
Int64Slice []int64
|
||||
Int64SlicePtr *[]int64 `form:"int64SlicePtrTag"`
|
||||
Int64SliceOfPtr []*int64
|
||||
|
||||
Int int `form:"intTag" query:"intTag2"`
|
||||
IntPtr *int
|
||||
IntSlice []int
|
||||
IntSlicePtr *[]int `form:"intSlicePtrTag"`
|
||||
IntSliceOfPtr []*int
|
||||
|
||||
Uint8 uint8 `form:"uint8Tag" query:"uint8Tag2"`
|
||||
Uint8Ptr *uint8
|
||||
Uint8Slice []uint8
|
||||
Uint8SlicePtr *[]uint8 `form:"uint8SlicePtrTag"`
|
||||
Uint8SliceOfPtr []*uint8
|
||||
|
||||
Uint16 uint16 `form:"uint16Tag" query:"uint16Tag2"`
|
||||
Uint16Ptr *uint16
|
||||
Uint16Slice []uint16
|
||||
Uint16SlicePtr *[]uint16 `form:"uint16SlicePtrTag"`
|
||||
Uint16SliceOfPtr []*uint16
|
||||
|
||||
Uint32 uint32 `form:"uint32Tag" query:"uint32Tag2"`
|
||||
Uint32Ptr *uint32
|
||||
Uint32Slice []uint32
|
||||
Uint32SlicePtr *[]uint32 `form:"uint32SlicePtrTag"`
|
||||
Uint32SliceOfPtr []*uint32
|
||||
|
||||
Uint64 uint64 `form:"uint64Tag" query:"uint64Tag2"`
|
||||
Uint64Ptr *uint64
|
||||
Uint64Slice []uint64
|
||||
Uint64SlicePtr *[]uint64 `form:"uint64SlicePtrTag"`
|
||||
Uint64SliceOfPtr []*uint64
|
||||
|
||||
Uint uint `form:"uintTag" query:"uintTag2"`
|
||||
UintPtr *uint
|
||||
UintSlice []uint
|
||||
UintSlicePtr *[]uint `form:"uintSlicePtrTag"`
|
||||
UintSliceOfPtr []*uint
|
||||
|
||||
Float32 float32 `form:"float32Tag" query:"float32Tag2"`
|
||||
Float32Ptr *float32
|
||||
Float32Slice []float32
|
||||
Float32SlicePtr *[]float32 `form:"float32SlicePtrTag"`
|
||||
Float32SliceOfPtr []*float32
|
||||
|
||||
Float64 float64 `form:"float64Tag" query:"float64Tag2"`
|
||||
Float64Ptr *float64
|
||||
Float64Slice []float64
|
||||
Float64SlicePtr *[]float64 `form:"float64SlicePtrTag"`
|
||||
Float64SliceOfPtr []*float64
|
||||
|
||||
// encoding.TextUnmarshaler
|
||||
Time time.Time `form:"timeTag" query:"timeTag2"`
|
||||
TimePtr *time.Time
|
||||
TimeSlice []time.Time
|
||||
TimeSlicePtr *[]time.Time `form:"timeSlicePtrTag"`
|
||||
TimeSliceOfPtr []*time.Time
|
||||
|
||||
// @jsonPayload fields
|
||||
JSONPayloadA string `form:"shouldBeIgnored" json:"payloadA"`
|
||||
JSONPayloadB []int `json:"payloadB"`
|
||||
JSONPayloadC bool `json:"-"`
|
||||
|
||||
// unexported fields or `-` tags
|
||||
unexported string
|
||||
SkipExported string `form:"-"`
|
||||
unexportedStructFieldWithoutTag struct {
|
||||
Name string `json:"unexportedStructFieldWithoutTag_name"`
|
||||
}
|
||||
unexportedStructFieldWithTag struct {
|
||||
Name string `json:"unexportedStructFieldWithTag_name"`
|
||||
} `form:"unexportedStruct"`
|
||||
|
||||
// structs
|
||||
StructWithoutTag struct {
|
||||
Name string `json:"StructWithoutTag_name"`
|
||||
}
|
||||
StructWithTag struct {
|
||||
Name string `json:"StructWithTag_name"`
|
||||
} `form:"exportedStruct"`
|
||||
|
||||
// embedded
|
||||
embed1
|
||||
embed2 `form:"embed2"`
|
||||
}
|
||||
|
||||
scenarios := []struct {
|
||||
name string
|
||||
data map[string][]string
|
||||
dst any
|
||||
tag string
|
||||
prefix string
|
||||
error bool
|
||||
result string
|
||||
}{
|
||||
{
|
||||
name: "nil data",
|
||||
data: nil,
|
||||
dst: pointer(map[string]any{}),
|
||||
error: false,
|
||||
result: `{}`,
|
||||
},
|
||||
{
|
||||
name: "non-pointer map[string]any",
|
||||
data: mapData,
|
||||
dst: map[string]any{},
|
||||
error: true,
|
||||
},
|
||||
{
|
||||
name: "unsupported *map[string]string",
|
||||
data: mapData,
|
||||
dst: pointer(map[string]string{}),
|
||||
error: true,
|
||||
},
|
||||
{
|
||||
name: "unsupported *map[string][]string",
|
||||
data: mapData,
|
||||
dst: pointer(map[string][]string{}),
|
||||
error: true,
|
||||
},
|
||||
{
|
||||
name: "*map[string]any",
|
||||
data: mapData,
|
||||
dst: pointer(map[string]any{}),
|
||||
result: `{"bool1":true,"bool2":[true,false],"json_a":null,"json_b":123,"json_c":[1,2,3],"mixed":[true,123,"test"],"number1":1,"number2":[2,3],"number3":[2.1,-3.4],"number4":[0,-0,0.0001],"string0":"","string1":"a","string2":["b","c"],"string3":["0.0","-0.0","000.1","000001","-000001","1.6E-35","-1.6E-35","10e100","1_000_000","1.000.000"," 123 ","0b1","0xFF","1234A","Infinity","-Infinity","undefined","null"]}`,
|
||||
},
|
||||
{
|
||||
name: "valid pointer struct (all fields)",
|
||||
data: structData,
|
||||
dst: &TestStruct{},
|
||||
result: `{"String":"a","StringPtr":"b","StringSlice":["a","b","c",""],"StringSlicePtr":["d","e"],"StringSliceOfPtr":["f","g"],"Bool":true,"BoolPtr":true,"BoolSlice":[true,false,false],"BoolSlicePtr":[false,false,true],"BoolSliceOfPtr":[false,true,false],"Int8":-1,"Int8Ptr":3,"Int8Slice":[4,5,0],"Int8SlicePtr":[5,6],"Int8SliceOfPtr":[7,8],"Int16":-1,"Int16Ptr":3,"Int16Slice":[4,5,0],"Int16SlicePtr":[5,6],"Int16SliceOfPtr":[7,8],"Int32":-1,"Int32Ptr":3,"Int32Slice":[4,5,0],"Int32SlicePtr":[5,6],"Int32SliceOfPtr":[7,8],"Int64":-1,"Int64Ptr":3,"Int64Slice":[4,5,0],"Int64SlicePtr":[5,6],"Int64SliceOfPtr":[7,8],"Int":-1,"IntPtr":3,"IntSlice":[4,5,0],"IntSlicePtr":[5,6],"IntSliceOfPtr":[7,8],"Uint8":1,"Uint8Ptr":3,"Uint8Slice":"BAUA","Uint8SlicePtr":"BQY=","Uint8SliceOfPtr":[7,8],"Uint16":1,"Uint16Ptr":3,"Uint16Slice":[4,5,0],"Uint16SlicePtr":[5,6],"Uint16SliceOfPtr":[7,8],"Uint32":1,"Uint32Ptr":3,"Uint32Slice":[4,5,0],"Uint32SlicePtr":[5,6],"Uint32SliceOfPtr":[7,8],"Uint64":1,"Uint64Ptr":3,"Uint64Slice":[4,5,0],"Uint64SlicePtr":[5,6],"Uint64SliceOfPtr":[7,8],"Uint":1,"UintPtr":3,"UintSlice":[4,5,0],"UintSlicePtr":[5,6],"UintSliceOfPtr":[7,8],"Float32":-1.2,"Float32Ptr":1.5,"Float32Slice":[1,2.3,-0.3,0],"Float32SlicePtr":[-1.3,3],"Float32SliceOfPtr":[0,1.2],"Float64":-1.2,"Float64Ptr":1.5,"Float64Slice":[1,2.3,-0.3,0],"Float64SlicePtr":[-1.3,3],"Float64SliceOfPtr":[0,1.2],"Time":"2009-11-10T15:00:00Z","TimePtr":"2009-11-10T14:00:00Z","TimeSlice":["2009-11-10T14:00:00Z","2009-11-10T15:00:00Z"],"TimeSlicePtr":["2009-11-10T15:00:00Z","2009-11-10T16:00:00Z"],"TimeSliceOfPtr":["2009-11-10T17:00:00Z","2009-11-10T18:00:00Z"],"payloadA":"test","payloadB":[1,2,3],"SkipExported":"","StructWithoutTag":{"StructWithoutTag_name":"test1"},"StructWithTag":{"StructWithTag_name":"test2"},"embed_name":"test3","embed_name2":"test4"}`,
|
||||
},
|
||||
{
|
||||
name: "non-pointer struct",
|
||||
data: structData,
|
||||
dst: TestStruct{},
|
||||
error: true,
|
||||
},
|
||||
{
|
||||
name: "invalid struct uint value",
|
||||
data: map[string][]string{"uintTag": {"-1"}},
|
||||
dst: &TestStruct{},
|
||||
error: true,
|
||||
},
|
||||
{
|
||||
name: "invalid struct int value",
|
||||
data: map[string][]string{"intTag": {"abc"}},
|
||||
dst: &TestStruct{},
|
||||
error: true,
|
||||
},
|
||||
{
|
||||
name: "invalid struct bool value",
|
||||
data: map[string][]string{"boolTag": {"abc"}},
|
||||
dst: &TestStruct{},
|
||||
error: true,
|
||||
},
|
||||
{
|
||||
name: "invalid struct float value",
|
||||
data: map[string][]string{"float64Tag": {"abc"}},
|
||||
dst: &TestStruct{},
|
||||
error: true,
|
||||
},
|
||||
{
|
||||
name: "invalid struct TextUnmarshaler value",
|
||||
data: map[string][]string{"timeTag": {"123"}},
|
||||
dst: &TestStruct{},
|
||||
error: true,
|
||||
},
|
||||
{
|
||||
name: "custom tagKey",
|
||||
data: map[string][]string{
|
||||
"tag1": {"a"},
|
||||
"tag2": {"b"},
|
||||
"tag3": {"c"},
|
||||
"Item": {"d"},
|
||||
},
|
||||
dst: &struct {
|
||||
Item string `form:"tag1" query:"tag2" json:"tag2"`
|
||||
}{},
|
||||
tag: "query",
|
||||
result: `{"tag2":"b"}`,
|
||||
},
|
||||
{
|
||||
name: "custom prefix",
|
||||
data: map[string][]string{
|
||||
"test.A": {"1"},
|
||||
"A": {"2"},
|
||||
"test.alias": {"3"},
|
||||
},
|
||||
dst: &struct {
|
||||
A string
|
||||
B string `form:"alias"`
|
||||
}{},
|
||||
prefix: "test",
|
||||
result: `{"A":"1","B":"3"}`,
|
||||
},
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
t.Run(s.name, func(t *testing.T) {
|
||||
err := router.UnmarshalRequestData(s.data, s.dst, s.tag, s.prefix)
|
||||
|
||||
hasErr := err != nil
|
||||
if hasErr != s.error {
|
||||
t.Fatalf("Expected hasErr %v, got %v (%v)", s.error, hasErr, err)
|
||||
}
|
||||
|
||||
if hasErr {
|
||||
return
|
||||
}
|
||||
|
||||
raw, err := json.Marshal(s.dst)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if !bytes.Equal(raw, []byte(s.result)) {
|
||||
t.Fatalf("Expected dst \n%s\ngot\n%s", s.result, raw)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// note: extra unexported checks in addition to the above test as there
|
||||
// is no easy way to print nested structs with all their fields.
|
||||
func TestUnmarshalRequestDataUnexportedFields(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
//nolint:all
|
||||
type TestStruct struct {
|
||||
Exported string
|
||||
|
||||
unexported string
|
||||
// to ensure that the reflection doesn't take tags with higher priority than the exported state
|
||||
unexportedWithTag string `form:"unexportedWithTag" json:"unexportedWithTag"`
|
||||
}
|
||||
|
||||
dst := &TestStruct{}
|
||||
|
||||
err := router.UnmarshalRequestData(map[string][]string{
|
||||
"Exported": {"test"}, // just for reference
|
||||
|
||||
"Unexported": {"test"},
|
||||
"unexported": {"test"},
|
||||
"UnexportedWithTag": {"test"},
|
||||
"unexportedWithTag": {"test"},
|
||||
}, dst, "", "")
|
||||
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if dst.Exported != "test" {
|
||||
t.Fatalf("Expected the Exported field to be %q, got %q", "test", dst.Exported)
|
||||
}
|
||||
|
||||
if dst.unexported != "" {
|
||||
t.Fatalf("Expected the unexported field to remain empty, got %q", dst.unexported)
|
||||
}
|
||||
|
||||
if dst.unexportedWithTag != "" {
|
||||
t.Fatalf("Expected the unexportedWithTag field to remain empty, got %q", dst.unexportedWithTag)
|
||||
}
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue