1
0
Fork 0

Adding upstream version 0.28.1.

Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
Daniel Baumann 2025-05-22 10:57:38 +02:00
parent 88f1d47ab6
commit e28c88ef14
Signed by: daniel
GPG key ID: FBB4F0E80A80222F
933 changed files with 194711 additions and 0 deletions

231
tools/router/error.go Normal file
View file

@ -0,0 +1,231 @@
package router
import (
"database/sql"
"errors"
"io/fs"
"net/http"
"strings"
validation "github.com/go-ozzo/ozzo-validation/v4"
"github.com/pocketbase/pocketbase/tools/inflector"
)
// SafeErrorItem defines a common error interface for a printable public safe error.
type SafeErrorItem interface {
// Code represents a fixed unique identifier of the error (usually used as translation key).
Code() string
// Error is the default English human readable error message that will be returned.
Error() string
}
// SafeErrorParamsResolver defines an optional interface for specifying dynamic error parameters.
type SafeErrorParamsResolver interface {
// Params defines a map with dynamic parameters to return as part of the public safe error view.
Params() map[string]any
}
// SafeErrorResolver defines an error interface for resolving the public safe error fields.
type SafeErrorResolver interface {
// Resolve allows modifying and returning a new public safe error data map.
Resolve(errData map[string]any) any
}
// ApiError defines the struct for a basic api error response.
type ApiError struct {
rawData any
Data map[string]any `json:"data"`
Message string `json:"message"`
Status int `json:"status"`
}
// Error makes it compatible with the `error` interface.
func (e *ApiError) Error() string {
return e.Message
}
// RawData returns the unformatted error data (could be an internal error, text, etc.)
func (e *ApiError) RawData() any {
return e.rawData
}
// Is reports whether the current ApiError wraps the target.
func (e *ApiError) Is(target error) bool {
err, ok := e.rawData.(error)
if ok {
return errors.Is(err, target)
}
apiErr, ok := target.(*ApiError)
return ok && e == apiErr
}
// NewNotFoundError creates and returns 404 ApiError.
func NewNotFoundError(message string, rawErrData any) *ApiError {
if message == "" {
message = "The requested resource wasn't found."
}
return NewApiError(http.StatusNotFound, message, rawErrData)
}
// NewBadRequestError creates and returns 400 ApiError.
func NewBadRequestError(message string, rawErrData any) *ApiError {
if message == "" {
message = "Something went wrong while processing your request."
}
return NewApiError(http.StatusBadRequest, message, rawErrData)
}
// NewForbiddenError creates and returns 403 ApiError.
func NewForbiddenError(message string, rawErrData any) *ApiError {
if message == "" {
message = "You are not allowed to perform this request."
}
return NewApiError(http.StatusForbidden, message, rawErrData)
}
// NewUnauthorizedError creates and returns 401 ApiError.
func NewUnauthorizedError(message string, rawErrData any) *ApiError {
if message == "" {
message = "Missing or invalid authentication."
}
return NewApiError(http.StatusUnauthorized, message, rawErrData)
}
// NewInternalServerError creates and returns 500 ApiError.
func NewInternalServerError(message string, rawErrData any) *ApiError {
if message == "" {
message = "Something went wrong while processing your request."
}
return NewApiError(http.StatusInternalServerError, message, rawErrData)
}
func NewTooManyRequestsError(message string, rawErrData any) *ApiError {
if message == "" {
message = "Too Many Requests."
}
return NewApiError(http.StatusTooManyRequests, message, rawErrData)
}
// NewApiError creates and returns new normalized ApiError instance.
func NewApiError(status int, message string, rawErrData any) *ApiError {
if message == "" {
message = http.StatusText(status)
}
return &ApiError{
rawData: rawErrData,
Data: safeErrorsData(rawErrData),
Status: status,
Message: strings.TrimSpace(inflector.Sentenize(message)),
}
}
// ToApiError wraps err into ApiError instance (if not already).
func ToApiError(err error) *ApiError {
var apiErr *ApiError
if !errors.As(err, &apiErr) {
// no ApiError found -> assign a generic one
if errors.Is(err, sql.ErrNoRows) || errors.Is(err, fs.ErrNotExist) {
apiErr = NewNotFoundError("", err)
} else {
apiErr = NewBadRequestError("", err)
}
}
return apiErr
}
// -------------------------------------------------------------------
func safeErrorsData(data any) map[string]any {
switch v := data.(type) {
case validation.Errors:
return resolveSafeErrorsData(v)
case error:
validationErrors := validation.Errors{}
if errors.As(v, &validationErrors) {
return resolveSafeErrorsData(validationErrors)
}
return map[string]any{} // not nil to ensure that is json serialized as object
case map[string]validation.Error:
return resolveSafeErrorsData(v)
case map[string]SafeErrorItem:
return resolveSafeErrorsData(v)
case map[string]error:
return resolveSafeErrorsData(v)
case map[string]string:
return resolveSafeErrorsData(v)
case map[string]any:
return resolveSafeErrorsData(v)
default:
return map[string]any{} // not nil to ensure that is json serialized as object
}
}
func resolveSafeErrorsData[T any](data map[string]T) map[string]any {
result := map[string]any{}
for name, err := range data {
if isNestedError(err) {
result[name] = safeErrorsData(err)
} else {
result[name] = resolveSafeErrorItem(err)
}
}
return result
}
func isNestedError(err any) bool {
switch err.(type) {
case validation.Errors,
map[string]validation.Error,
map[string]SafeErrorItem,
map[string]error,
map[string]string,
map[string]any:
return true
}
return false
}
// resolveSafeErrorItem extracts from each validation error its
// public safe error code and message.
func resolveSafeErrorItem(err any) any {
data := map[string]any{}
if obj, ok := err.(SafeErrorItem); ok {
// extract the specific error code and message
data["code"] = obj.Code()
data["message"] = inflector.Sentenize(obj.Error())
} else {
// fallback to the default public safe values
data["code"] = "validation_invalid_value"
data["message"] = "Invalid value."
}
if s, ok := err.(SafeErrorParamsResolver); ok {
params := s.Params()
if len(params) > 0 {
data["params"] = params
}
}
if s, ok := err.(SafeErrorResolver); ok {
return s.Resolve(data)
}
return data
}

358
tools/router/error_test.go Normal file
View file

@ -0,0 +1,358 @@
package router_test
import (
"database/sql"
"encoding/json"
"errors"
"fmt"
"io/fs"
"strconv"
"testing"
validation "github.com/go-ozzo/ozzo-validation/v4"
"github.com/pocketbase/pocketbase/tools/router"
)
func TestNewApiErrorWithRawData(t *testing.T) {
t.Parallel()
e := router.NewApiError(
300,
"message_test",
"rawData_test",
)
result, _ := json.Marshal(e)
expected := `{"data":{},"message":"Message_test.","status":300}`
if string(result) != expected {
t.Errorf("Expected\n%v\ngot\n%v", expected, string(result))
}
if e.Error() != "Message_test." {
t.Errorf("Expected %q, got %q", "Message_test.", e.Error())
}
if e.RawData() != "rawData_test" {
t.Errorf("Expected rawData\n%v\ngot\n%v", "rawData_test", e.RawData())
}
}
func TestNewApiErrorWithValidationData(t *testing.T) {
t.Parallel()
e := router.NewApiError(
300,
"message_test",
map[string]any{
"err1": errors.New("test error"), // should be normalized
"err2": validation.ErrRequired,
"err3": validation.Errors{
"err3.1": errors.New("test error"), // should be normalized
"err3.2": validation.ErrRequired,
"err3.3": validation.Errors{
"err3.3.1": validation.ErrRequired,
},
},
"err4": &mockSafeErrorItem{},
"err5": map[string]error{
"err5.1": validation.ErrRequired,
},
},
)
result, _ := json.Marshal(e)
expected := `{"data":{"err1":{"code":"validation_invalid_value","message":"Invalid value."},"err2":{"code":"validation_required","message":"Cannot be blank."},"err3":{"err3.1":{"code":"validation_invalid_value","message":"Invalid value."},"err3.2":{"code":"validation_required","message":"Cannot be blank."},"err3.3":{"err3.3.1":{"code":"validation_required","message":"Cannot be blank."}}},"err4":{"code":"mock_code","message":"Mock_error.","mock_resolve":123},"err5":{"err5.1":{"code":"validation_required","message":"Cannot be blank."}}},"message":"Message_test.","status":300}`
if string(result) != expected {
t.Errorf("Expected \n%v, \ngot \n%v", expected, string(result))
}
if e.Error() != "Message_test." {
t.Errorf("Expected %q, got %q", "Message_test.", e.Error())
}
if e.RawData() == nil {
t.Error("Expected non-nil rawData")
}
}
func TestNewNotFoundError(t *testing.T) {
t.Parallel()
scenarios := []struct {
message string
data any
expected string
}{
{"", nil, `{"data":{},"message":"The requested resource wasn't found.","status":404}`},
{"demo", "rawData_test", `{"data":{},"message":"Demo.","status":404}`},
{"demo", validation.Errors{"err1": validation.NewError("test_code", "test_message")}, `{"data":{"err1":{"code":"test_code","message":"Test_message."}},"message":"Demo.","status":404}`},
}
for i, s := range scenarios {
t.Run(strconv.Itoa(i), func(t *testing.T) {
e := router.NewNotFoundError(s.message, s.data)
result, _ := json.Marshal(e)
if str := string(result); str != s.expected {
t.Fatalf("Expected\n%v\ngot\n%v", s.expected, str)
}
})
}
}
func TestNewBadRequestError(t *testing.T) {
t.Parallel()
scenarios := []struct {
message string
data any
expected string
}{
{"", nil, `{"data":{},"message":"Something went wrong while processing your request.","status":400}`},
{"demo", "rawData_test", `{"data":{},"message":"Demo.","status":400}`},
{"demo", validation.Errors{"err1": validation.NewError("test_code", "test_message")}, `{"data":{"err1":{"code":"test_code","message":"Test_message."}},"message":"Demo.","status":400}`},
}
for i, s := range scenarios {
t.Run(strconv.Itoa(i), func(t *testing.T) {
e := router.NewBadRequestError(s.message, s.data)
result, _ := json.Marshal(e)
if str := string(result); str != s.expected {
t.Fatalf("Expected\n%v\ngot\n%v", s.expected, str)
}
})
}
}
func TestNewForbiddenError(t *testing.T) {
t.Parallel()
scenarios := []struct {
message string
data any
expected string
}{
{"", nil, `{"data":{},"message":"You are not allowed to perform this request.","status":403}`},
{"demo", "rawData_test", `{"data":{},"message":"Demo.","status":403}`},
{"demo", validation.Errors{"err1": validation.NewError("test_code", "test_message")}, `{"data":{"err1":{"code":"test_code","message":"Test_message."}},"message":"Demo.","status":403}`},
}
for i, s := range scenarios {
t.Run(strconv.Itoa(i), func(t *testing.T) {
e := router.NewForbiddenError(s.message, s.data)
result, _ := json.Marshal(e)
if str := string(result); str != s.expected {
t.Fatalf("Expected\n%v\ngot\n%v", s.expected, str)
}
})
}
}
func TestNewUnauthorizedError(t *testing.T) {
t.Parallel()
scenarios := []struct {
message string
data any
expected string
}{
{"", nil, `{"data":{},"message":"Missing or invalid authentication.","status":401}`},
{"demo", "rawData_test", `{"data":{},"message":"Demo.","status":401}`},
{"demo", validation.Errors{"err1": validation.NewError("test_code", "test_message")}, `{"data":{"err1":{"code":"test_code","message":"Test_message."}},"message":"Demo.","status":401}`},
}
for i, s := range scenarios {
t.Run(strconv.Itoa(i), func(t *testing.T) {
e := router.NewUnauthorizedError(s.message, s.data)
result, _ := json.Marshal(e)
if str := string(result); str != s.expected {
t.Fatalf("Expected\n%v\ngot\n%v", s.expected, str)
}
})
}
}
func TestNewInternalServerError(t *testing.T) {
t.Parallel()
scenarios := []struct {
message string
data any
expected string
}{
{"", nil, `{"data":{},"message":"Something went wrong while processing your request.","status":500}`},
{"demo", "rawData_test", `{"data":{},"message":"Demo.","status":500}`},
{"demo", validation.Errors{"err1": validation.NewError("test_code", "test_message")}, `{"data":{"err1":{"code":"test_code","message":"Test_message."}},"message":"Demo.","status":500}`},
}
for i, s := range scenarios {
t.Run(strconv.Itoa(i), func(t *testing.T) {
e := router.NewInternalServerError(s.message, s.data)
result, _ := json.Marshal(e)
if str := string(result); str != s.expected {
t.Fatalf("Expected\n%v\ngot\n%v", s.expected, str)
}
})
}
}
func TestNewTooManyRequestsError(t *testing.T) {
t.Parallel()
scenarios := []struct {
message string
data any
expected string
}{
{"", nil, `{"data":{},"message":"Too Many Requests.","status":429}`},
{"demo", "rawData_test", `{"data":{},"message":"Demo.","status":429}`},
{"demo", validation.Errors{"err1": validation.NewError("test_code", "test_message").SetParams(map[string]any{"test": 123})}, `{"data":{"err1":{"code":"test_code","message":"Test_message.","params":{"test":123}}},"message":"Demo.","status":429}`},
}
for i, s := range scenarios {
t.Run(strconv.Itoa(i), func(t *testing.T) {
e := router.NewTooManyRequestsError(s.message, s.data)
result, _ := json.Marshal(e)
if str := string(result); str != s.expected {
t.Fatalf("Expected\n%v\ngot\n%v", s.expected, str)
}
})
}
}
func TestApiErrorIs(t *testing.T) {
t.Parallel()
err0 := router.NewInternalServerError("", nil)
err1 := router.NewInternalServerError("", nil)
err2 := errors.New("test")
err3 := fmt.Errorf("wrapped: %w", err0)
scenarios := []struct {
name string
err error
target error
expected bool
}{
{
"nil error",
err0,
nil,
false,
},
{
"non ApiError",
err0,
err1,
false,
},
{
"different ApiError",
err0,
err2,
false,
},
{
"same ApiError",
err0,
err0,
true,
},
{
"wrapped ApiError",
err3,
err0,
true,
},
}
for _, s := range scenarios {
t.Run(s.name, func(t *testing.T) {
is := errors.Is(s.err, s.target)
if is != s.expected {
t.Fatalf("Expected %v, got %v", s.expected, is)
}
})
}
}
func TestToApiError(t *testing.T) {
t.Parallel()
scenarios := []struct {
name string
err error
expected string
}{
{
"regular error",
errors.New("test"),
`{"data":{},"message":"Something went wrong while processing your request.","status":400}`,
},
{
"fs.ErrNotExist",
fs.ErrNotExist,
`{"data":{},"message":"The requested resource wasn't found.","status":404}`,
},
{
"sql.ErrNoRows",
sql.ErrNoRows,
`{"data":{},"message":"The requested resource wasn't found.","status":404}`,
},
{
"ApiError",
router.NewForbiddenError("test", nil),
`{"data":{},"message":"Test.","status":403}`,
},
{
"wrapped ApiError",
fmt.Errorf("wrapped: %w", router.NewForbiddenError("test", nil)),
`{"data":{},"message":"Test.","status":403}`,
},
}
for _, s := range scenarios {
t.Run(s.name, func(t *testing.T) {
raw, err := json.Marshal(router.ToApiError(s.err))
if err != nil {
t.Fatal(err)
}
rawStr := string(raw)
if rawStr != s.expected {
t.Fatalf("Expected error\n%vgot\n%v", s.expected, rawStr)
}
})
}
}
// -------------------------------------------------------------------
var (
_ router.SafeErrorItem = (*mockSafeErrorItem)(nil)
_ router.SafeErrorResolver = (*mockSafeErrorItem)(nil)
)
type mockSafeErrorItem struct {
}
func (m *mockSafeErrorItem) Code() string {
return "mock_code"
}
func (m *mockSafeErrorItem) Error() string {
return "mock_error"
}
func (m *mockSafeErrorItem) Resolve(errData map[string]any) any {
errData["mock_resolve"] = 123
return errData
}

398
tools/router/event.go Normal file
View file

@ -0,0 +1,398 @@
package router
import (
"encoding/json"
"encoding/xml"
"errors"
"io"
"io/fs"
"net"
"net/http"
"net/netip"
"path/filepath"
"strings"
"github.com/pocketbase/pocketbase/tools/filesystem"
"github.com/pocketbase/pocketbase/tools/hook"
"github.com/pocketbase/pocketbase/tools/picker"
"github.com/pocketbase/pocketbase/tools/store"
)
var ErrUnsupportedContentType = NewBadRequestError("Unsupported Content-Type", nil)
var ErrInvalidRedirectStatusCode = NewInternalServerError("Invalid redirect status code", nil)
var ErrFileNotFound = NewNotFoundError("File not found", nil)
const IndexPage = "index.html"
// Event specifies based Route handler event that is usually intended
// to be embedded as part of a custom event struct.
//
// NB! It is expected that the Response and Request fields are always set.
type Event struct {
Response http.ResponseWriter
Request *http.Request
hook.Event
data store.Store[string, any]
}
// RWUnwrapper specifies that an http.ResponseWriter could be "unwrapped"
// (usually used with [http.ResponseController]).
type RWUnwrapper interface {
Unwrap() http.ResponseWriter
}
// Written reports whether the current response has already been written.
//
// This method always returns false if e.ResponseWritter doesn't implement the WriteTracker interface
// (all router package handlers receives a ResponseWritter that implements it unless explicitly replaced with a custom one).
func (e *Event) Written() bool {
written, _ := getWritten(e.Response)
return written
}
// Status reports the status code of the current response.
//
// This method always returns 0 if e.Response doesn't implement the StatusTracker interface
// (all router package handlers receives a ResponseWritter that implements it unless explicitly replaced with a custom one).
func (e *Event) Status() int {
status, _ := getStatus(e.Response)
return status
}
// Flush flushes buffered data to the current response.
//
// Returns [http.ErrNotSupported] if e.Response doesn't implement the [http.Flusher] interface
// (all router package handlers receives a ResponseWritter that implements it unless explicitly replaced with a custom one).
func (e *Event) Flush() error {
return http.NewResponseController(e.Response).Flush()
}
// IsTLS reports whether the connection on which the request was received is TLS.
func (e *Event) IsTLS() bool {
return e.Request.TLS != nil
}
// SetCookie is an alias for [http.SetCookie].
//
// SetCookie adds a Set-Cookie header to the current response's headers.
// The provided cookie must have a valid Name.
// Invalid cookies may be silently dropped.
func (e *Event) SetCookie(cookie *http.Cookie) {
http.SetCookie(e.Response, cookie)
}
// RemoteIP returns the IP address of the client that sent the request.
//
// IPv6 addresses are returned expanded.
// For example, "2001:db8::1" becomes "2001:0db8:0000:0000:0000:0000:0000:0001".
//
// Note that if you are behind reverse proxy(ies), this method returns
// the IP of the last connecting proxy.
func (e *Event) RemoteIP() string {
ip, _, _ := net.SplitHostPort(e.Request.RemoteAddr)
parsed, _ := netip.ParseAddr(ip)
return parsed.StringExpanded()
}
// FindUploadedFiles extracts all form files of "key" from a http request
// and returns a slice with filesystem.File instances (if any).
func (e *Event) FindUploadedFiles(key string) ([]*filesystem.File, error) {
if e.Request.MultipartForm == nil {
err := e.Request.ParseMultipartForm(DefaultMaxMemory)
if err != nil {
return nil, err
}
}
if e.Request.MultipartForm == nil || e.Request.MultipartForm.File == nil || len(e.Request.MultipartForm.File[key]) == 0 {
return nil, http.ErrMissingFile
}
result := make([]*filesystem.File, 0, len(e.Request.MultipartForm.File[key]))
for _, fh := range e.Request.MultipartForm.File[key] {
file, err := filesystem.NewFileFromMultipart(fh)
if err != nil {
return nil, err
}
result = append(result, file)
}
return result, nil
}
// Store
// -------------------------------------------------------------------
// Get retrieves single value from the current event data store.
func (e *Event) Get(key string) any {
return e.data.Get(key)
}
// GetAll returns a copy of the current event data store.
func (e *Event) GetAll() map[string]any {
return e.data.GetAll()
}
// Set saves single value into the current event data store.
func (e *Event) Set(key string, value any) {
e.data.Set(key, value)
}
// SetAll saves all items from m into the current event data store.
func (e *Event) SetAll(m map[string]any) {
for k, v := range m {
e.Set(k, v)
}
}
// Response writers
// -------------------------------------------------------------------
const headerContentType = "Content-Type"
func (e *Event) setResponseHeaderIfEmpty(key, value string) {
header := e.Response.Header()
if header.Get(key) == "" {
header.Set(key, value)
}
}
// String writes a plain string response.
func (e *Event) String(status int, data string) error {
e.setResponseHeaderIfEmpty(headerContentType, "text/plain; charset=utf-8")
e.Response.WriteHeader(status)
_, err := e.Response.Write([]byte(data))
return err
}
// HTML writes an HTML response.
func (e *Event) HTML(status int, data string) error {
e.setResponseHeaderIfEmpty(headerContentType, "text/html; charset=utf-8")
e.Response.WriteHeader(status)
_, err := e.Response.Write([]byte(data))
return err
}
const jsonFieldsParam = "fields"
// JSON writes a JSON response.
//
// It also provides a generic response data fields picker if the "fields" query parameter is set.
// For example, if you are requesting `?fields=a,b` for `e.JSON(200, map[string]int{ "a":1, "b":2, "c":3 })`,
// it should result in a JSON response like: `{"a":1, "b": 2}`.
func (e *Event) JSON(status int, data any) error {
e.setResponseHeaderIfEmpty(headerContentType, "application/json")
e.Response.WriteHeader(status)
rawFields := e.Request.URL.Query().Get(jsonFieldsParam)
// error response or no fields to pick
if rawFields == "" || status < 200 || status > 299 {
return json.NewEncoder(e.Response).Encode(data)
}
// pick only the requested fields
modified, err := picker.Pick(data, rawFields)
if err != nil {
return err
}
return json.NewEncoder(e.Response).Encode(modified)
}
// XML writes an XML response.
// It automatically prepends the generic [xml.Header] string to the response.
func (e *Event) XML(status int, data any) error {
e.setResponseHeaderIfEmpty(headerContentType, "application/xml; charset=utf-8")
e.Response.WriteHeader(status)
if _, err := e.Response.Write([]byte(xml.Header)); err != nil {
return err
}
return xml.NewEncoder(e.Response).Encode(data)
}
// Stream streams the specified reader into the response.
func (e *Event) Stream(status int, contentType string, reader io.Reader) error {
e.Response.Header().Set(headerContentType, contentType)
e.Response.WriteHeader(status)
_, err := io.Copy(e.Response, reader)
return err
}
// Blob writes a blob (bytes slice) response.
func (e *Event) Blob(status int, contentType string, b []byte) error {
e.setResponseHeaderIfEmpty(headerContentType, contentType)
e.Response.WriteHeader(status)
_, err := e.Response.Write(b)
return err
}
// FileFS serves the specified filename from fsys.
//
// It is similar to [echo.FileFS] for consistency with earlier versions.
func (e *Event) FileFS(fsys fs.FS, filename string) error {
f, err := fsys.Open(filename)
if err != nil {
return ErrFileNotFound
}
defer f.Close()
fi, err := f.Stat()
if err != nil {
return err
}
// if it is a directory try to open its index.html file
if fi.IsDir() {
filename = filepath.ToSlash(filepath.Join(filename, IndexPage))
f, err = fsys.Open(filename)
if err != nil {
return ErrFileNotFound
}
defer f.Close()
fi, err = f.Stat()
if err != nil {
return err
}
}
ff, ok := f.(io.ReadSeeker)
if !ok {
return errors.New("[FileFS] file does not implement io.ReadSeeker")
}
http.ServeContent(e.Response, e.Request, fi.Name(), fi.ModTime(), ff)
return nil
}
// NoContent writes a response with no body (ex. 204).
func (e *Event) NoContent(status int) error {
e.Response.WriteHeader(status)
return nil
}
// Redirect writes a redirect response to the specified url.
// The status code must be in between 300 399 range.
func (e *Event) Redirect(status int, url string) error {
if status < 300 || status > 399 {
return ErrInvalidRedirectStatusCode
}
e.Response.Header().Set("Location", url)
e.Response.WriteHeader(status)
return nil
}
// ApiError helpers
// -------------------------------------------------------------------
func (e *Event) Error(status int, message string, errData any) *ApiError {
return NewApiError(status, message, errData)
}
func (e *Event) BadRequestError(message string, errData any) *ApiError {
return NewBadRequestError(message, errData)
}
func (e *Event) NotFoundError(message string, errData any) *ApiError {
return NewNotFoundError(message, errData)
}
func (e *Event) ForbiddenError(message string, errData any) *ApiError {
return NewForbiddenError(message, errData)
}
func (e *Event) UnauthorizedError(message string, errData any) *ApiError {
return NewUnauthorizedError(message, errData)
}
func (e *Event) TooManyRequestsError(message string, errData any) *ApiError {
return NewTooManyRequestsError(message, errData)
}
func (e *Event) InternalServerError(message string, errData any) *ApiError {
return NewInternalServerError(message, errData)
}
// Binders
// -------------------------------------------------------------------
const DefaultMaxMemory = 32 << 20 // 32mb
// BindBody unmarshal the request body into the provided dst.
//
// dst must be either a struct pointer or map[string]any.
//
// The rules how the body will be scanned depends on the request Content-Type.
//
// Currently the following Content-Types are supported:
// - application/json
// - text/xml, application/xml
// - multipart/form-data, application/x-www-form-urlencoded
//
// Respectively the following struct tags are supported (again, which one will be used depends on the Content-Type):
// - "json" (json body)- uses the builtin Go json package for unmarshaling.
// - "xml" (xml body) - uses the builtin Go xml package for unmarshaling.
// - "form" (form data) - utilizes the custom [router.UnmarshalRequestData] method.
//
// NB! When dst is a struct make sure that it doesn't have public fields
// that shouldn't be bindable and it is advisible such fields to be unexported
// or have a separate struct just for the binding. For example:
//
// data := struct{
// somethingPrivate string
//
// Title string `json:"title" form:"title"`
// Total int `json:"total" form:"total"`
// }
// err := e.BindBody(&data)
func (e *Event) BindBody(dst any) error {
if e.Request.ContentLength == 0 {
return nil
}
contentType := e.Request.Header.Get(headerContentType)
if strings.HasPrefix(contentType, "application/json") {
dec := json.NewDecoder(e.Request.Body)
err := dec.Decode(dst)
if err == nil {
// manually call Reread because single call of json.Decoder.Decode()
// doesn't ensure that the entire body is a valid json string
// and it is not guaranteed that it will reach EOF to trigger the reread reset
// (ex. in case of trailing spaces or invalid trailing parts like: `{"test":1},something`)
if body, ok := e.Request.Body.(Rereader); ok {
body.Reread()
}
}
return err
}
if strings.HasPrefix(contentType, "multipart/form-data") {
if err := e.Request.ParseMultipartForm(DefaultMaxMemory); err != nil {
return err
}
return UnmarshalRequestData(e.Request.Form, dst, "", "")
}
if strings.HasPrefix(contentType, "application/x-www-form-urlencoded") {
if err := e.Request.ParseForm(); err != nil {
return err
}
return UnmarshalRequestData(e.Request.Form, dst, "", "")
}
if strings.HasPrefix(contentType, "text/xml") ||
strings.HasPrefix(contentType, "application/xml") {
return xml.NewDecoder(e.Request.Body).Decode(dst)
}
return ErrUnsupportedContentType
}

959
tools/router/event_test.go Normal file
View file

@ -0,0 +1,959 @@
package router_test
import (
"bytes"
"crypto/tls"
"encoding/json"
"encoding/xml"
"errors"
"fmt"
"io"
"mime/multipart"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"regexp"
"strconv"
"strings"
"testing"
validation "github.com/go-ozzo/ozzo-validation/v4"
"github.com/pocketbase/pocketbase/tools/router"
)
type unwrapTester struct {
http.ResponseWriter
}
func (ut unwrapTester) Unwrap() http.ResponseWriter {
return ut.ResponseWriter
}
func TestEventWritten(t *testing.T) {
t.Parallel()
res1 := httptest.NewRecorder()
res2 := httptest.NewRecorder()
res2.Write([]byte("test"))
res3 := &router.ResponseWriter{ResponseWriter: unwrapTester{httptest.NewRecorder()}}
res4 := &router.ResponseWriter{ResponseWriter: unwrapTester{httptest.NewRecorder()}}
res4.Write([]byte("test"))
scenarios := []struct {
name string
response http.ResponseWriter
expected bool
}{
{
name: "non-written non-WriteTracker",
response: res1,
expected: false,
},
{
name: "written non-WriteTracker",
response: res2,
expected: false,
},
{
name: "non-written WriteTracker",
response: res3,
expected: false,
},
{
name: "written WriteTracker",
response: res4,
expected: true,
},
}
for _, s := range scenarios {
t.Run(s.name, func(t *testing.T) {
event := router.Event{
Response: s.response,
}
result := event.Written()
if result != s.expected {
t.Fatalf("Expected %v, got %v", s.expected, result)
}
})
}
}
func TestEventStatus(t *testing.T) {
t.Parallel()
res1 := httptest.NewRecorder()
res2 := httptest.NewRecorder()
res2.WriteHeader(123)
res3 := &router.ResponseWriter{ResponseWriter: unwrapTester{httptest.NewRecorder()}}
res4 := &router.ResponseWriter{ResponseWriter: unwrapTester{httptest.NewRecorder()}}
res4.WriteHeader(123)
scenarios := []struct {
name string
response http.ResponseWriter
expected int
}{
{
name: "non-written non-StatusTracker",
response: res1,
expected: 0,
},
{
name: "written non-StatusTracker",
response: res2,
expected: 0,
},
{
name: "non-written StatusTracker",
response: res3,
expected: 0,
},
{
name: "written StatusTracker",
response: res4,
expected: 123,
},
}
for _, s := range scenarios {
t.Run(s.name, func(t *testing.T) {
event := router.Event{
Response: s.response,
}
result := event.Status()
if result != s.expected {
t.Fatalf("Expected %d, got %d", s.expected, result)
}
})
}
}
func TestEventIsTLS(t *testing.T) {
t.Parallel()
req, err := http.NewRequest(http.MethodGet, "/", nil)
if err != nil {
t.Fatal(err)
}
event := router.Event{Request: req}
// without TLS
if event.IsTLS() {
t.Fatalf("Expected IsTLS false")
}
// dummy TLS state
req.TLS = new(tls.ConnectionState)
// with TLS
if !event.IsTLS() {
t.Fatalf("Expected IsTLS true")
}
}
func TestEventSetCookie(t *testing.T) {
t.Parallel()
event := router.Event{
Response: httptest.NewRecorder(),
}
cookie := event.Response.Header().Get("set-cookie")
if cookie != "" {
t.Fatalf("Expected empty cookie string, got %q", cookie)
}
event.SetCookie(&http.Cookie{Name: "test", Value: "a"})
expected := "test=a"
cookie = event.Response.Header().Get("set-cookie")
if cookie != expected {
t.Fatalf("Expected cookie %q, got %q", expected, cookie)
}
}
func TestEventRemoteIP(t *testing.T) {
t.Parallel()
scenarios := []struct {
remoteAddr string
expected string
}{
{"", "invalid IP"},
{"1.2.3.4", "invalid IP"},
{"1.2.3.4:8090", "1.2.3.4"},
{"[0000:0000:0000:0000:0000:0000:0000:0002]:80", "0000:0000:0000:0000:0000:0000:0000:0002"},
{"[::2]:80", "0000:0000:0000:0000:0000:0000:0000:0002"}, // should always return the expanded version
}
for _, s := range scenarios {
t.Run(s.remoteAddr, func(t *testing.T) {
req, err := http.NewRequest(http.MethodGet, "/", nil)
if err != nil {
t.Fatal(err)
}
req.RemoteAddr = s.remoteAddr
event := router.Event{Request: req}
ip := event.RemoteIP()
if ip != s.expected {
t.Fatalf("Expected IP %q, got %q", s.expected, ip)
}
})
}
}
func TestFindUploadedFiles(t *testing.T) {
scenarios := []struct {
filename string
expectedPattern string
}{
{"ab.png", `^ab\w{10}_\w{10}\.png$`},
{"test", `^test_\w{10}\.txt$`},
{"a b c d!@$.j!@$pg", `^a_b_c_d_\w{10}\.jpg$`},
{strings.Repeat("a", 150), `^a{100}_\w{10}\.txt$`},
}
for _, s := range scenarios {
t.Run(s.filename, func(t *testing.T) {
// create multipart form file body
body := new(bytes.Buffer)
mp := multipart.NewWriter(body)
w, err := mp.CreateFormFile("test", s.filename)
if err != nil {
t.Fatal(err)
}
w.Write([]byte("test"))
mp.Close()
// ---
req := httptest.NewRequest(http.MethodPost, "/", body)
req.Header.Add("Content-Type", mp.FormDataContentType())
event := router.Event{Request: req}
result, err := event.FindUploadedFiles("test")
if err != nil {
t.Fatal(err)
}
if len(result) != 1 {
t.Fatalf("Expected 1 file, got %d", len(result))
}
if result[0].Size != 4 {
t.Fatalf("Expected the file size to be 4 bytes, got %d", result[0].Size)
}
pattern, err := regexp.Compile(s.expectedPattern)
if err != nil {
t.Fatalf("Invalid filename pattern %q: %v", s.expectedPattern, err)
}
if !pattern.MatchString(result[0].Name) {
t.Fatalf("Expected filename to match %s, got filename %s", s.expectedPattern, result[0].Name)
}
})
}
}
func TestFindUploadedFilesMissing(t *testing.T) {
body := new(bytes.Buffer)
mp := multipart.NewWriter(body)
mp.Close()
req := httptest.NewRequest(http.MethodPost, "/", body)
req.Header.Add("Content-Type", mp.FormDataContentType())
event := router.Event{Request: req}
result, err := event.FindUploadedFiles("test")
if err == nil {
t.Error("Expected error, got nil")
}
if result != nil {
t.Errorf("Expected result to be nil, got %v", result)
}
}
func TestEventSetGet(t *testing.T) {
event := router.Event{}
// get before any set (ensures that doesn't panic)
if v := event.Get("test"); v != nil {
t.Fatalf("Expected nil value, got %v", v)
}
event.Set("a", 123)
event.Set("b", 456)
scenarios := []struct {
key string
expected any
}{
{"", nil},
{"missing", nil},
{"a", 123},
{"b", 456},
}
for i, s := range scenarios {
t.Run(fmt.Sprintf("%d_%s", i, s.key), func(t *testing.T) {
result := event.Get(s.key)
if result != s.expected {
t.Fatalf("Expected %v, got %v", s.expected, result)
}
})
}
}
func TestEventSetAllGetAll(t *testing.T) {
data := map[string]any{
"a": 123,
"b": 456,
}
rawData, err := json.Marshal(data)
if err != nil {
t.Fatal(err)
}
event := router.Event{}
event.SetAll(data)
// modify the data to ensure that the map was shallow coppied
data["c"] = 789
result := event.GetAll()
rawResult, err := json.Marshal(result)
if err != nil {
t.Fatal(err)
}
if len(rawResult) == 0 || !bytes.Equal(rawData, rawResult) {
t.Fatalf("Expected\n%v\ngot\n%v", rawData, rawResult)
}
}
func TestEventString(t *testing.T) {
scenarios := []testResponseWriteScenario[string]{
{
name: "no explicit content-type",
status: 123,
headers: nil,
body: "test",
expectedStatus: 123,
expectedHeaders: map[string]string{"content-type": "text/plain; charset=utf-8"},
expectedBody: "test",
},
{
name: "with explicit content-type",
status: 123,
headers: map[string]string{"content-type": "text/test"},
body: "test",
expectedStatus: 123,
expectedHeaders: map[string]string{"content-type": "text/test"},
expectedBody: "test",
},
}
for _, s := range scenarios {
testEventResponseWrite(t, s, func(e *router.Event) error {
return e.String(s.status, s.body)
})
}
}
func TestEventHTML(t *testing.T) {
scenarios := []testResponseWriteScenario[string]{
{
name: "no explicit content-type",
status: 123,
headers: nil,
body: "test",
expectedStatus: 123,
expectedHeaders: map[string]string{"content-type": "text/html; charset=utf-8"},
expectedBody: "test",
},
{
name: "with explicit content-type",
status: 123,
headers: map[string]string{"content-type": "text/test"},
body: "test",
expectedStatus: 123,
expectedHeaders: map[string]string{"content-type": "text/test"},
expectedBody: "test",
},
}
for _, s := range scenarios {
testEventResponseWrite(t, s, func(e *router.Event) error {
return e.HTML(s.status, s.body)
})
}
}
func TestEventJSON(t *testing.T) {
body := map[string]any{"a": 123, "b": 456, "c": "test"}
expectedPickedBody := `{"a":123,"c":"test"}` + "\n"
expectedFullBody := `{"a":123,"b":456,"c":"test"}` + "\n"
scenarios := []testResponseWriteScenario[any]{
{
name: "no explicit content-type",
status: 200,
headers: nil,
body: body,
expectedStatus: 200,
expectedHeaders: map[string]string{"content-type": "application/json"},
expectedBody: expectedPickedBody,
},
{
name: "with explicit content-type (200)",
status: 200,
headers: map[string]string{"content-type": "application/test"},
body: body,
expectedStatus: 200,
expectedHeaders: map[string]string{"content-type": "application/test"},
expectedBody: expectedPickedBody,
},
{
name: "with explicit content-type (400)", // no fields picker
status: 400,
headers: map[string]string{"content-type": "application/test"},
body: body,
expectedStatus: 400,
expectedHeaders: map[string]string{"content-type": "application/test"},
expectedBody: expectedFullBody,
},
}
for _, s := range scenarios {
testEventResponseWrite(t, s, func(e *router.Event) error {
e.Request.URL.RawQuery = "fields=a,c" // ensures that the picker is invoked
return e.JSON(s.status, s.body)
})
}
}
func TestEventXML(t *testing.T) {
scenarios := []testResponseWriteScenario[string]{
{
name: "no explicit content-type",
status: 234,
headers: nil,
body: "test",
expectedStatus: 234,
expectedHeaders: map[string]string{"content-type": "application/xml; charset=utf-8"},
expectedBody: xml.Header + "<string>test</string>",
},
{
name: "with explicit content-type",
status: 234,
headers: map[string]string{"content-type": "text/test"},
body: "test",
expectedStatus: 234,
expectedHeaders: map[string]string{"content-type": "text/test"},
expectedBody: xml.Header + "<string>test</string>",
},
}
for _, s := range scenarios {
testEventResponseWrite(t, s, func(e *router.Event) error {
return e.XML(s.status, s.body)
})
}
}
func TestEventStream(t *testing.T) {
scenarios := []testResponseWriteScenario[string]{
{
name: "stream",
status: 234,
headers: map[string]string{"content-type": "text/test"},
body: "test",
expectedStatus: 234,
expectedHeaders: map[string]string{"content-type": "text/test"},
expectedBody: "test",
},
}
for _, s := range scenarios {
testEventResponseWrite(t, s, func(e *router.Event) error {
return e.Stream(s.status, s.headers["content-type"], strings.NewReader(s.body))
})
}
}
func TestEventBlob(t *testing.T) {
scenarios := []testResponseWriteScenario[[]byte]{
{
name: "blob",
status: 234,
headers: map[string]string{"content-type": "text/test"},
body: []byte("test"),
expectedStatus: 234,
expectedHeaders: map[string]string{"content-type": "text/test"},
expectedBody: "test",
},
}
for _, s := range scenarios {
testEventResponseWrite(t, s, func(e *router.Event) error {
return e.Blob(s.status, s.headers["content-type"], s.body)
})
}
}
func TestEventNoContent(t *testing.T) {
s := testResponseWriteScenario[any]{
name: "no content",
status: 234,
headers: map[string]string{"content-type": "text/test"},
body: nil,
expectedStatus: 234,
expectedHeaders: map[string]string{"content-type": "text/test"},
expectedBody: "",
}
testEventResponseWrite(t, s, func(e *router.Event) error {
return e.NoContent(s.status)
})
}
func TestEventFlush(t *testing.T) {
rec := httptest.NewRecorder()
event := &router.Event{
Response: unwrapTester{&router.ResponseWriter{ResponseWriter: rec}},
}
event.Response.Write([]byte("test"))
event.Flush()
if !rec.Flushed {
t.Fatal("Expected response to be flushed")
}
}
func TestEventRedirect(t *testing.T) {
scenarios := []testResponseWriteScenario[any]{
{
name: "non-30x status",
status: 200,
expectedStatus: 200,
expectedError: router.ErrInvalidRedirectStatusCode,
},
{
name: "30x status",
status: 302,
headers: map[string]string{"location": "test"}, // should be overwritten with the argument
expectedStatus: 302,
expectedHeaders: map[string]string{"location": "example"},
},
}
for _, s := range scenarios {
testEventResponseWrite(t, s, func(e *router.Event) error {
return e.Redirect(s.status, "example")
})
}
}
func TestEventFileFS(t *testing.T) {
// stub test files
// ---
dir, err := os.MkdirTemp("", "EventFileFS")
if err != nil {
t.Fatal(err)
}
defer os.RemoveAll(dir)
err = os.WriteFile(filepath.Join(dir, "index.html"), []byte("index"), 0644)
if err != nil {
t.Fatal(err)
}
err = os.WriteFile(filepath.Join(dir, "test.txt"), []byte("test"), 0644)
if err != nil {
t.Fatal(err)
}
// create sub directory with an index.html file inside it
err = os.MkdirAll(filepath.Join(dir, "sub1"), os.ModePerm)
if err != nil {
t.Fatal(err)
}
err = os.WriteFile(filepath.Join(dir, "sub1", "index.html"), []byte("sub1 index"), 0644)
if err != nil {
t.Fatal(err)
}
err = os.MkdirAll(filepath.Join(dir, "sub2"), os.ModePerm)
if err != nil {
t.Fatal(err)
}
err = os.WriteFile(filepath.Join(dir, "sub2", "test.txt"), []byte("sub2 test"), 0644)
if err != nil {
t.Fatal(err)
}
// ---
scenarios := []struct {
name string
path string
expected string
}{
{"missing file", "", ""},
{"root with no explicit file", "", ""},
{"root with explicit file", "test.txt", "test"},
{"sub dir with no explicit file", "sub1", "sub1 index"},
{"sub dir with no explicit file (no index.html)", "sub2", ""},
{"sub dir explicit file", "sub2/test.txt", "sub2 test"},
}
for _, s := range scenarios {
t.Run(s.name, func(t *testing.T) {
req, err := http.NewRequest(http.MethodGet, "/", nil)
if err != nil {
t.Fatal(err)
}
rec := httptest.NewRecorder()
event := &router.Event{
Request: req,
Response: rec,
}
err = event.FileFS(os.DirFS(dir), s.path)
hasErr := err != nil
expectErr := s.expected == ""
if hasErr != expectErr {
t.Fatalf("Expected hasErr %v, got %v (%v)", expectErr, hasErr, err)
}
result := rec.Result()
raw, err := io.ReadAll(result.Body)
result.Body.Close()
if err != nil {
t.Fatal(err)
}
if string(raw) != s.expected {
t.Fatalf("Expected body\n%s\ngot\n%s", s.expected, raw)
}
// ensure that the proper file headers are added
// (aka. http.ServeContent is invoked)
length, _ := strconv.Atoi(result.Header.Get("content-length"))
if length != len(s.expected) {
t.Fatalf("Expected Content-Length %d, got %d", len(s.expected), length)
}
})
}
}
func TestEventError(t *testing.T) {
err := new(router.Event).Error(123, "message_test", map[string]any{"a": validation.Required, "b": "test"})
result, _ := json.Marshal(err)
expected := `{"data":{"a":{"code":"validation_invalid_value","message":"Invalid value."},"b":{"code":"validation_invalid_value","message":"Invalid value."}},"message":"Message_test.","status":123}`
if string(result) != expected {
t.Errorf("Expected\n%s\ngot\n%s", expected, result)
}
}
func TestEventBadRequestError(t *testing.T) {
err := new(router.Event).BadRequestError("message_test", map[string]any{"a": validation.Required, "b": "test"})
result, _ := json.Marshal(err)
expected := `{"data":{"a":{"code":"validation_invalid_value","message":"Invalid value."},"b":{"code":"validation_invalid_value","message":"Invalid value."}},"message":"Message_test.","status":400}`
if string(result) != expected {
t.Errorf("Expected\n%s\ngot\n%s", expected, result)
}
}
func TestEventNotFoundError(t *testing.T) {
err := new(router.Event).NotFoundError("message_test", map[string]any{"a": validation.Required, "b": "test"})
result, _ := json.Marshal(err)
expected := `{"data":{"a":{"code":"validation_invalid_value","message":"Invalid value."},"b":{"code":"validation_invalid_value","message":"Invalid value."}},"message":"Message_test.","status":404}`
if string(result) != expected {
t.Errorf("Expected\n%s\ngot\n%s", expected, result)
}
}
func TestEventForbiddenError(t *testing.T) {
err := new(router.Event).ForbiddenError("message_test", map[string]any{"a": validation.Required, "b": "test"})
result, _ := json.Marshal(err)
expected := `{"data":{"a":{"code":"validation_invalid_value","message":"Invalid value."},"b":{"code":"validation_invalid_value","message":"Invalid value."}},"message":"Message_test.","status":403}`
if string(result) != expected {
t.Errorf("Expected\n%s\ngot\n%s", expected, result)
}
}
func TestEventUnauthorizedError(t *testing.T) {
err := new(router.Event).UnauthorizedError("message_test", map[string]any{"a": validation.Required, "b": "test"})
result, _ := json.Marshal(err)
expected := `{"data":{"a":{"code":"validation_invalid_value","message":"Invalid value."},"b":{"code":"validation_invalid_value","message":"Invalid value."}},"message":"Message_test.","status":401}`
if string(result) != expected {
t.Errorf("Expected\n%s\ngot\n%s", expected, result)
}
}
func TestEventTooManyRequestsError(t *testing.T) {
err := new(router.Event).TooManyRequestsError("message_test", map[string]any{"a": validation.Required, "b": "test"})
result, _ := json.Marshal(err)
expected := `{"data":{"a":{"code":"validation_invalid_value","message":"Invalid value."},"b":{"code":"validation_invalid_value","message":"Invalid value."}},"message":"Message_test.","status":429}`
if string(result) != expected {
t.Errorf("Expected\n%s\ngot\n%s", expected, result)
}
}
func TestEventInternalServerError(t *testing.T) {
err := new(router.Event).InternalServerError("message_test", map[string]any{"a": validation.Required, "b": "test"})
result, _ := json.Marshal(err)
expected := `{"data":{"a":{"code":"validation_invalid_value","message":"Invalid value."},"b":{"code":"validation_invalid_value","message":"Invalid value."}},"message":"Message_test.","status":500}`
if string(result) != expected {
t.Errorf("Expected\n%s\ngot\n%s", expected, result)
}
}
func TestEventBindBody(t *testing.T) {
type testDstStruct struct {
A int `json:"a" xml:"a" form:"a"`
B int `json:"b" xml:"b" form:"b"`
C string `json:"c" xml:"c" form:"c"`
}
emptyDst := `{"a":0,"b":0,"c":""}`
queryDst := `a=123&b=-456&c=test`
xmlDst := `
<?xml version="1.0" encoding="UTF-8" ?>
<root>
<a>123</a>
<b>-456</b>
<c>test</c>
</root>
`
jsonDst := `{"a":123,"b":-456,"c":"test"}`
// multipart
mpBody := &bytes.Buffer{}
mpWriter := multipart.NewWriter(mpBody)
mpWriter.WriteField("@jsonPayload", `{"a":123}`)
mpWriter.WriteField("b", "-456")
mpWriter.WriteField("c", "test")
if err := mpWriter.Close(); err != nil {
t.Fatal(err)
}
scenarios := []struct {
contentType string
body io.Reader
expectDst string
expectError bool
}{
{
contentType: "",
body: strings.NewReader(jsonDst),
expectDst: emptyDst,
expectError: true,
},
{
contentType: "application/rtf", // unsupported
body: strings.NewReader(jsonDst),
expectDst: emptyDst,
expectError: true,
},
// empty body
{
contentType: "application/json;charset=emptybody",
body: strings.NewReader(""),
expectDst: emptyDst,
},
// json
{
contentType: "application/json",
body: strings.NewReader(jsonDst),
expectDst: jsonDst,
},
{
contentType: "application/json;charset=abc",
body: strings.NewReader(jsonDst),
expectDst: jsonDst,
},
// xml
{
contentType: "text/xml",
body: strings.NewReader(xmlDst),
expectDst: jsonDst,
},
{
contentType: "text/xml;charset=abc",
body: strings.NewReader(xmlDst),
expectDst: jsonDst,
},
{
contentType: "application/xml",
body: strings.NewReader(xmlDst),
expectDst: jsonDst,
},
{
contentType: "application/xml;charset=abc",
body: strings.NewReader(xmlDst),
expectDst: jsonDst,
},
// x-www-form-urlencoded
{
contentType: "application/x-www-form-urlencoded",
body: strings.NewReader(queryDst),
expectDst: jsonDst,
},
{
contentType: "application/x-www-form-urlencoded;charset=abc",
body: strings.NewReader(queryDst),
expectDst: jsonDst,
},
// multipart
{
contentType: mpWriter.FormDataContentType(),
body: mpBody,
expectDst: jsonDst,
},
}
for _, s := range scenarios {
t.Run(s.contentType, func(t *testing.T) {
req, err := http.NewRequest(http.MethodPost, "/", s.body)
if err != nil {
t.Fatal(err)
}
req.Header.Add("content-type", s.contentType)
event := &router.Event{Request: req}
dst := testDstStruct{}
err = event.BindBody(&dst)
hasErr := err != nil
if hasErr != s.expectError {
t.Fatalf("Expected hasErr %v, got %v (%v)", s.expectError, hasErr, err)
}
dstRaw, err := json.Marshal(dst)
if err != nil {
t.Fatal(err)
}
if string(dstRaw) != s.expectDst {
t.Fatalf("Expected dst\n%s\ngot\n%s", s.expectDst, dstRaw)
}
})
}
}
// -------------------------------------------------------------------
type testResponseWriteScenario[T any] struct {
name string
status int
headers map[string]string
body T
expectedStatus int
expectedHeaders map[string]string
expectedBody string
expectedError error
}
func testEventResponseWrite[T any](
t *testing.T,
scenario testResponseWriteScenario[T],
writeFunc func(e *router.Event) error,
) {
t.Run(scenario.name, func(t *testing.T) {
req, err := http.NewRequest(http.MethodGet, "/", nil)
if err != nil {
t.Fatal(err)
}
rec := httptest.NewRecorder()
event := &router.Event{
Request: req,
Response: &router.ResponseWriter{ResponseWriter: rec},
}
for k, v := range scenario.headers {
event.Response.Header().Add(k, v)
}
err = writeFunc(event)
if (scenario.expectedError != nil || err != nil) && !errors.Is(err, scenario.expectedError) {
t.Fatalf("Expected error %v, got %v", scenario.expectedError, err)
}
result := rec.Result()
if result.StatusCode != scenario.expectedStatus {
t.Fatalf("Expected status code %d, got %d", scenario.expectedStatus, result.StatusCode)
}
resultBody, err := io.ReadAll(result.Body)
result.Body.Close()
if err != nil {
t.Fatalf("Failed to read response body: %v", err)
}
resultBody, err = json.Marshal(string(resultBody))
if err != nil {
t.Fatal(err)
}
expectedBody, err := json.Marshal(scenario.expectedBody)
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(resultBody, expectedBody) {
t.Fatalf("Expected body\n%s\ngot\n%s", expectedBody, resultBody)
}
for k, ev := range scenario.expectedHeaders {
if v := result.Header.Get(k); v != ev {
t.Fatalf("Expected %q header to be %q, got %q", k, ev, v)
}
}
})
}

231
tools/router/group.go Normal file
View file

@ -0,0 +1,231 @@
package router
import (
"net/http"
"regexp"
"strings"
"github.com/pocketbase/pocketbase/tools/hook"
)
// (note: the struct is named RouterGroup instead of Group so that it can
// be embedded in the Router without conflicting with the Group method)
// RouterGroup represents a collection of routes and other sub groups
// that share common pattern prefix and middlewares.
type RouterGroup[T hook.Resolver] struct {
excludedMiddlewares map[string]struct{}
children []any // Route or RouterGroup
Prefix string
Middlewares []*hook.Handler[T]
}
// Group creates and register a new child Group into the current one
// with the specified prefix.
//
// The prefix follows the standard Go net/http ServeMux pattern format ("[HOST]/[PATH]")
// and will be concatenated recursively into the final route path, meaning that
// only the root level group could have HOST as part of the prefix.
//
// Returns the newly created group to allow chaining and registering
// sub-routes and group specific middlewares.
func (group *RouterGroup[T]) Group(prefix string) *RouterGroup[T] {
newGroup := &RouterGroup[T]{}
newGroup.Prefix = prefix
group.children = append(group.children, newGroup)
return newGroup
}
// BindFunc registers one or multiple middleware functions to the current group.
//
// The registered middleware functions are "anonymous" and with default priority,
// aka. executes in the order they were registered.
//
// If you need to specify a named middleware (ex. so that it can be removed)
// or middleware with custom exec prirority, use [RouterGroup.Bind] method.
func (group *RouterGroup[T]) BindFunc(middlewareFuncs ...func(e T) error) *RouterGroup[T] {
for _, m := range middlewareFuncs {
group.Middlewares = append(group.Middlewares, &hook.Handler[T]{Func: m})
}
return group
}
// Bind registers one or multiple middleware handlers to the current group.
func (group *RouterGroup[T]) Bind(middlewares ...*hook.Handler[T]) *RouterGroup[T] {
group.Middlewares = append(group.Middlewares, middlewares...)
// unmark the newly added middlewares in case they were previously "excluded"
if group.excludedMiddlewares != nil {
for _, m := range middlewares {
if m.Id != "" {
delete(group.excludedMiddlewares, m.Id)
}
}
}
return group
}
// Unbind removes one or more middlewares with the specified id(s)
// from the current group and its children (if any).
//
// Anonymous middlewares are not removable, aka. this method does nothing
// if the middleware id is an empty string.
func (group *RouterGroup[T]) Unbind(middlewareIds ...string) *RouterGroup[T] {
for _, middlewareId := range middlewareIds {
if middlewareId == "" {
continue
}
// remove from the group middlwares
for i := len(group.Middlewares) - 1; i >= 0; i-- {
if group.Middlewares[i].Id == middlewareId {
group.Middlewares = append(group.Middlewares[:i], group.Middlewares[i+1:]...)
}
}
// remove from the group children
for i := len(group.children) - 1; i >= 0; i-- {
switch v := group.children[i].(type) {
case *RouterGroup[T]:
v.Unbind(middlewareId)
case *Route[T]:
v.Unbind(middlewareId)
}
}
// add to the exclude list
if group.excludedMiddlewares == nil {
group.excludedMiddlewares = map[string]struct{}{}
}
group.excludedMiddlewares[middlewareId] = struct{}{}
}
return group
}
// Route registers a single route into the current group.
//
// Note that the final route path will be the concatenation of all parent groups prefixes + the route path.
// The path follows the standard Go net/http ServeMux format ("[HOST]/[PATH]"),
// meaning that only a top level group route could have HOST as part of the prefix.
//
// Returns the newly created route to allow attaching route-only middlewares.
func (group *RouterGroup[T]) Route(method string, path string, action func(e T) error) *Route[T] {
route := &Route[T]{
Method: method,
Path: path,
Action: action,
}
group.children = append(group.children, route)
return route
}
// Any is a shorthand for [RouterGroup.AddRoute] with "" as route method (aka. matches any method).
func (group *RouterGroup[T]) Any(path string, action func(e T) error) *Route[T] {
return group.Route("", path, action)
}
// GET is a shorthand for [RouterGroup.AddRoute] with GET as route method.
func (group *RouterGroup[T]) GET(path string, action func(e T) error) *Route[T] {
return group.Route(http.MethodGet, path, action)
}
// SEARCH is a shorthand for [RouterGroup.AddRoute] with SEARCH as route method.
func (group *RouterGroup[T]) SEARCH(path string, action func(e T) error) *Route[T] {
return group.Route("SEARCH", path, action)
}
// POST is a shorthand for [RouterGroup.AddRoute] with POST as route method.
func (group *RouterGroup[T]) POST(path string, action func(e T) error) *Route[T] {
return group.Route(http.MethodPost, path, action)
}
// DELETE is a shorthand for [RouterGroup.AddRoute] with DELETE as route method.
func (group *RouterGroup[T]) DELETE(path string, action func(e T) error) *Route[T] {
return group.Route(http.MethodDelete, path, action)
}
// PATCH is a shorthand for [RouterGroup.AddRoute] with PATCH as route method.
func (group *RouterGroup[T]) PATCH(path string, action func(e T) error) *Route[T] {
return group.Route(http.MethodPatch, path, action)
}
// PUT is a shorthand for [RouterGroup.AddRoute] with PUT as route method.
func (group *RouterGroup[T]) PUT(path string, action func(e T) error) *Route[T] {
return group.Route(http.MethodPut, path, action)
}
// HEAD is a shorthand for [RouterGroup.AddRoute] with HEAD as route method.
func (group *RouterGroup[T]) HEAD(path string, action func(e T) error) *Route[T] {
return group.Route(http.MethodHead, path, action)
}
// OPTIONS is a shorthand for [RouterGroup.AddRoute] with OPTIONS as route method.
func (group *RouterGroup[T]) OPTIONS(path string, action func(e T) error) *Route[T] {
return group.Route(http.MethodOptions, path, action)
}
// HasRoute checks whether the specified route pattern (method + path)
// is registered in the current group or its children.
//
// This could be useful to conditionally register and checks for routes
// in order prevent panic on duplicated routes.
//
// Note that routes with anonymous and named wildcard placeholder are treated as equal,
// aka. "GET /abc/" is considered the same as "GET /abc/{something...}".
func (group *RouterGroup[T]) HasRoute(method string, path string) bool {
pattern := path
if method != "" {
pattern = strings.ToUpper(method) + " " + pattern
}
return group.hasRoute(pattern, nil)
}
func (group *RouterGroup[T]) hasRoute(pattern string, parents []*RouterGroup[T]) bool {
for _, child := range group.children {
switch v := child.(type) {
case *RouterGroup[T]:
if v.hasRoute(pattern, append(parents, group)) {
return true
}
case *Route[T]:
var result string
if v.Method != "" {
result += v.Method + " "
}
// add parent groups prefixes
for _, p := range parents {
result += p.Prefix
}
// add current group prefix
result += group.Prefix
// add current route path
result += v.Path
if result == pattern || // direct match
// compares without the named wildcard, aka. /abc/{test...} is equal to /abc/
stripWildcard(result) == stripWildcard(pattern) {
return true
}
}
}
return false
}
var wildcardPlaceholderRegex = regexp.MustCompile(`/{.+\.\.\.}$`)
func stripWildcard(pattern string) string {
return wildcardPlaceholderRegex.ReplaceAllString(pattern, "/")
}

430
tools/router/group_test.go Normal file
View file

@ -0,0 +1,430 @@
package router
import (
"errors"
"fmt"
"net/http"
"slices"
"testing"
"github.com/pocketbase/pocketbase/tools/hook"
)
func TestRouterGroupGroup(t *testing.T) {
t.Parallel()
g0 := RouterGroup[*Event]{}
g1 := g0.Group("test1")
g2 := g0.Group("test2")
if total := len(g0.children); total != 2 {
t.Fatalf("Expected %d child groups, got %d", 2, total)
}
if g1.Prefix != "test1" {
t.Fatalf("Expected g1 with prefix %q, got %q", "test1", g1.Prefix)
}
if g2.Prefix != "test2" {
t.Fatalf("Expected g2 with prefix %q, got %q", "test2", g2.Prefix)
}
}
func TestRouterGroupBindFunc(t *testing.T) {
t.Parallel()
g := RouterGroup[*Event]{}
calls := ""
// append one function
g.BindFunc(func(e *Event) error {
calls += "a"
return nil
})
// append multiple functions
g.BindFunc(
func(e *Event) error {
calls += "b"
return nil
},
func(e *Event) error {
calls += "c"
return nil
},
)
if total := len(g.Middlewares); total != 3 {
t.Fatalf("Expected %d middlewares, got %v", 3, total)
}
for _, h := range g.Middlewares {
_ = h.Func(nil)
}
if calls != "abc" {
t.Fatalf("Expected calls sequence %q, got %q", "abc", calls)
}
}
func TestRouterGroupBind(t *testing.T) {
t.Parallel()
g := RouterGroup[*Event]{
// mock excluded middlewares to check whether the entry will be deleted
excludedMiddlewares: map[string]struct{}{"test2": {}},
}
calls := ""
// append one handler
g.Bind(&hook.Handler[*Event]{
Func: func(e *Event) error {
calls += "a"
return nil
},
})
// append multiple handlers
g.Bind(
&hook.Handler[*Event]{
Id: "test1",
Func: func(e *Event) error {
calls += "b"
return nil
},
},
&hook.Handler[*Event]{
Id: "test2",
Func: func(e *Event) error {
calls += "c"
return nil
},
},
)
if total := len(g.Middlewares); total != 3 {
t.Fatalf("Expected %d middlewares, got %v", 3, total)
}
for _, h := range g.Middlewares {
_ = h.Func(nil)
}
if calls != "abc" {
t.Fatalf("Expected calls %q, got %q", "abc", calls)
}
// ensures that the previously excluded middleware was removed
if len(g.excludedMiddlewares) != 0 {
t.Fatalf("Expected test2 to be removed from the excludedMiddlewares list, got %v", g.excludedMiddlewares)
}
}
func TestRouterGroupUnbind(t *testing.T) {
t.Parallel()
g := RouterGroup[*Event]{}
calls := ""
// anonymous middlewares
g.Bind(&hook.Handler[*Event]{
Func: func(e *Event) error {
calls += "a"
return nil // unused value
},
})
// middlewares with id
g.Bind(&hook.Handler[*Event]{
Id: "test1",
Func: func(e *Event) error {
calls += "b"
return nil // unused value
},
})
g.Bind(&hook.Handler[*Event]{
Id: "test2",
Func: func(e *Event) error {
calls += "c"
return nil // unused value
},
})
g.Bind(&hook.Handler[*Event]{
Id: "test3",
Func: func(e *Event) error {
calls += "d"
return nil // unused value
},
})
// remove
g.Unbind("") // should be no-op
g.Unbind("test1", "test3")
if total := len(g.Middlewares); total != 2 {
t.Fatalf("Expected %d middlewares, got %v", 2, total)
}
for _, h := range g.Middlewares {
if err := h.Func(nil); err != nil {
continue
}
}
if calls != "ac" {
t.Fatalf("Expected calls %q, got %q", "ac", calls)
}
// ensure that the ids were added in the exclude list
excluded := []string{"test1", "test3"}
if len(g.excludedMiddlewares) != len(excluded) {
t.Fatalf("Expected excludes %v, got %v", excluded, g.excludedMiddlewares)
}
for id := range g.excludedMiddlewares {
if !slices.Contains(excluded, id) {
t.Fatalf("Expected %q to be marked as excluded", id)
}
}
}
func TestRouterGroupRoute(t *testing.T) {
t.Parallel()
group := RouterGroup[*Event]{}
sub := group.Group("sub")
var called bool
route := group.Route(http.MethodPost, "/test", func(e *Event) error {
called = true
return nil
})
// ensure that the route was registered only to the main one
// ---
if len(sub.children) != 0 {
t.Fatalf("Expected no sub children, got %d", len(sub.children))
}
if len(group.children) != 2 {
t.Fatalf("Expected %d group children, got %d", 2, len(group.children))
}
// ---
// check the registered route
// ---
if route != group.children[1] {
t.Fatalf("Expected group children %v, got %v", route, group.children[1])
}
if route.Method != http.MethodPost {
t.Fatalf("Expected route method %q, got %q", http.MethodPost, route.Method)
}
if route.Path != "/test" {
t.Fatalf("Expected route path %q, got %q", "/test", route.Path)
}
route.Action(nil)
if !called {
t.Fatal("Expected route action to be called")
}
}
func TestRouterGroupRouteAliases(t *testing.T) {
t.Parallel()
group := RouterGroup[*Event]{}
testErr := errors.New("test")
testAction := func(e *Event) error {
return testErr
}
scenarios := []struct {
route *Route[*Event]
expectMethod string
expectPath string
}{
{
group.Any("/test", testAction),
"",
"/test",
},
{
group.GET("/test", testAction),
http.MethodGet,
"/test",
},
{
group.SEARCH("/test", testAction),
"SEARCH",
"/test",
},
{
group.POST("/test", testAction),
http.MethodPost,
"/test",
},
{
group.DELETE("/test", testAction),
http.MethodDelete,
"/test",
},
{
group.PATCH("/test", testAction),
http.MethodPatch,
"/test",
},
{
group.PUT("/test", testAction),
http.MethodPut,
"/test",
},
{
group.HEAD("/test", testAction),
http.MethodHead,
"/test",
},
{
group.OPTIONS("/test", testAction),
http.MethodOptions,
"/test",
},
}
for i, s := range scenarios {
t.Run(fmt.Sprintf("%d_%s_%s", i, s.expectMethod, s.expectPath), func(t *testing.T) {
if s.route.Method != s.expectMethod {
t.Fatalf("Expected method %q, got %q", s.expectMethod, s.route.Method)
}
if s.route.Path != s.expectPath {
t.Fatalf("Expected path %q, got %q", s.expectPath, s.route.Path)
}
if err := s.route.Action(nil); !errors.Is(err, testErr) {
t.Fatal("Expected test action")
}
})
}
}
func TestRouterGroupHasRoute(t *testing.T) {
t.Parallel()
group := RouterGroup[*Event]{}
group.Any("/any", nil)
group.GET("/base", nil)
group.DELETE("/base", nil)
sub := group.Group("/sub1")
sub.GET("/a", nil)
sub.POST("/a", nil)
sub2 := sub.Group("/sub2")
sub2.GET("/b", nil)
sub2.GET("/b/{test}", nil)
// special cases to test the normalizations
group.GET("/c/", nil) // the same as /c/{test...}
group.GET("/d/{test...}", nil) // the same as /d/
scenarios := []struct {
method string
path string
expected bool
}{
{
http.MethodGet,
"",
false,
},
{
"",
"/any",
true,
},
{
http.MethodPost,
"/base",
false,
},
{
http.MethodGet,
"/base",
true,
},
{
http.MethodDelete,
"/base",
true,
},
{
http.MethodGet,
"/sub1",
false,
},
{
http.MethodGet,
"/sub1/a",
true,
},
{
http.MethodPost,
"/sub1/a",
true,
},
{
http.MethodDelete,
"/sub1/a",
false,
},
{
http.MethodGet,
"/sub2/b",
false,
},
{
http.MethodGet,
"/sub1/sub2/b",
true,
},
{
http.MethodGet,
"/sub1/sub2/b/{test}",
true,
},
{
http.MethodGet,
"/sub1/sub2/b/{test2}",
false,
},
{
http.MethodGet,
"/c/{test...}",
true,
},
{
http.MethodGet,
"/d/",
true,
},
}
for _, s := range scenarios {
t.Run(s.method+"_"+s.path, func(t *testing.T) {
has := group.HasRoute(s.method, s.path)
if has != s.expected {
t.Fatalf("Expected %v, got %v", s.expected, has)
}
})
}
}

View file

@ -0,0 +1,60 @@
package router
import (
"bytes"
"io"
)
var (
_ io.ReadCloser = (*RereadableReadCloser)(nil)
_ Rereader = (*RereadableReadCloser)(nil)
)
// Rereader defines an interface for rewindable readers.
type Rereader interface {
Reread()
}
// RereadableReadCloser defines a wrapper around a io.ReadCloser reader
// allowing to read the original reader multiple times.
type RereadableReadCloser struct {
io.ReadCloser
copy *bytes.Buffer
active io.Reader
}
// Read implements the standard io.Reader interface.
//
// It reads up to len(b) bytes into b and at at the same time writes
// the read data into an internal bytes buffer.
//
// On EOF the r is "rewinded" to allow reading from r multiple times.
func (r *RereadableReadCloser) Read(b []byte) (int, error) {
if r.active == nil {
if r.copy == nil {
r.copy = &bytes.Buffer{}
}
r.active = io.TeeReader(r.ReadCloser, r.copy)
}
n, err := r.active.Read(b)
if err == io.EOF {
r.Reread()
}
return n, err
}
// Reread satisfies the [Rereader] interface and resets the r internal state to allow rereads.
//
// note: not named Reset to avoid conflicts with other reader interfaces.
func (r *RereadableReadCloser) Reread() {
if r.copy == nil || r.copy.Len() == 0 {
return // nothing to reset or it has been already reset
}
oldCopy := r.copy
r.copy = &bytes.Buffer{}
r.active = io.TeeReader(oldCopy, r.copy)
}

View file

@ -0,0 +1,28 @@
package router_test
import (
"io"
"strings"
"testing"
"github.com/pocketbase/pocketbase/tools/router"
)
func TestRereadableReadCloser(t *testing.T) {
content := "test"
rereadable := &router.RereadableReadCloser{
ReadCloser: io.NopCloser(strings.NewReader(content)),
}
// read multiple times
for i := 0; i < 3; i++ {
result, err := io.ReadAll(rereadable)
if err != nil {
t.Fatalf("[read:%d] %v", i, err)
}
if str := string(result); str != content {
t.Fatalf("[read:%d] Expected %q, got %q", i, content, result)
}
}
}

73
tools/router/route.go Normal file
View file

@ -0,0 +1,73 @@
package router
import "github.com/pocketbase/pocketbase/tools/hook"
type Route[T hook.Resolver] struct {
excludedMiddlewares map[string]struct{}
Action func(e T) error
Method string
Path string
Middlewares []*hook.Handler[T]
}
// BindFunc registers one or multiple middleware functions to the current route.
//
// The registered middleware functions are "anonymous" and with default priority,
// aka. executes in the order they were registered.
//
// If you need to specify a named middleware (ex. so that it can be removed)
// or middleware with custom exec prirority, use the [Route.Bind] method.
func (route *Route[T]) BindFunc(middlewareFuncs ...func(e T) error) *Route[T] {
for _, m := range middlewareFuncs {
route.Middlewares = append(route.Middlewares, &hook.Handler[T]{Func: m})
}
return route
}
// Bind registers one or multiple middleware handlers to the current route.
func (route *Route[T]) Bind(middlewares ...*hook.Handler[T]) *Route[T] {
route.Middlewares = append(route.Middlewares, middlewares...)
// unmark the newly added middlewares in case they were previously "excluded"
if route.excludedMiddlewares != nil {
for _, m := range middlewares {
if m.Id != "" {
delete(route.excludedMiddlewares, m.Id)
}
}
}
return route
}
// Unbind removes one or more middlewares with the specified id(s) from the current route.
//
// It also adds the removed middleware ids to an exclude list so that they could be skipped from
// the execution chain in case the middleware is registered in a parent group.
//
// Anonymous middlewares are considered non-removable, aka. this method
// does nothing if the middleware id is an empty string.
func (route *Route[T]) Unbind(middlewareIds ...string) *Route[T] {
for _, middlewareId := range middlewareIds {
if middlewareId == "" {
continue
}
// remove from the route's middlewares
for i := len(route.Middlewares) - 1; i >= 0; i-- {
if route.Middlewares[i].Id == middlewareId {
route.Middlewares = append(route.Middlewares[:i], route.Middlewares[i+1:]...)
}
}
// add to the exclude list
if route.excludedMiddlewares == nil {
route.excludedMiddlewares = map[string]struct{}{}
}
route.excludedMiddlewares[middlewareId] = struct{}{}
}
return route
}

168
tools/router/route_test.go Normal file
View file

@ -0,0 +1,168 @@
package router
import (
"slices"
"testing"
"github.com/pocketbase/pocketbase/tools/hook"
)
func TestRouteBindFunc(t *testing.T) {
t.Parallel()
r := Route[*Event]{}
calls := ""
// append one function
r.BindFunc(func(e *Event) error {
calls += "a"
return nil
})
// append multiple functions
r.BindFunc(
func(e *Event) error {
calls += "b"
return nil
},
func(e *Event) error {
calls += "c"
return nil
},
)
if total := len(r.Middlewares); total != 3 {
t.Fatalf("Expected %d middlewares, got %v", 3, total)
}
for _, h := range r.Middlewares {
_ = h.Func(nil)
}
if calls != "abc" {
t.Fatalf("Expected calls sequence %q, got %q", "abc", calls)
}
}
func TestRouteBind(t *testing.T) {
t.Parallel()
r := Route[*Event]{
// mock excluded middlewares to check whether the entry will be deleted
excludedMiddlewares: map[string]struct{}{"test2": {}},
}
calls := ""
// append one handler
r.Bind(&hook.Handler[*Event]{
Func: func(e *Event) error {
calls += "a"
return nil
},
})
// append multiple handlers
r.Bind(
&hook.Handler[*Event]{
Id: "test1",
Func: func(e *Event) error {
calls += "b"
return nil
},
},
&hook.Handler[*Event]{
Id: "test2",
Func: func(e *Event) error {
calls += "c"
return nil
},
},
)
if total := len(r.Middlewares); total != 3 {
t.Fatalf("Expected %d middlewares, got %v", 3, total)
}
for _, h := range r.Middlewares {
_ = h.Func(nil)
}
if calls != "abc" {
t.Fatalf("Expected calls %q, got %q", "abc", calls)
}
// ensures that the previously excluded middleware was removed
if len(r.excludedMiddlewares) != 0 {
t.Fatalf("Expected test2 to be removed from the excludedMiddlewares list, got %v", r.excludedMiddlewares)
}
}
func TestRouteUnbind(t *testing.T) {
t.Parallel()
r := Route[*Event]{}
calls := ""
// anonymous middlewares
r.Bind(&hook.Handler[*Event]{
Func: func(e *Event) error {
calls += "a"
return nil // unused value
},
})
// middlewares with id
r.Bind(&hook.Handler[*Event]{
Id: "test1",
Func: func(e *Event) error {
calls += "b"
return nil // unused value
},
})
r.Bind(&hook.Handler[*Event]{
Id: "test2",
Func: func(e *Event) error {
calls += "c"
return nil // unused value
},
})
r.Bind(&hook.Handler[*Event]{
Id: "test3",
Func: func(e *Event) error {
calls += "d"
return nil // unused value
},
})
// remove
r.Unbind("") // should be no-op
r.Unbind("test1", "test3")
if total := len(r.Middlewares); total != 2 {
t.Fatalf("Expected %d middlewares, got %v", 2, total)
}
for _, h := range r.Middlewares {
if err := h.Func(nil); err != nil {
continue
}
}
if calls != "ac" {
t.Fatalf("Expected calls %q, got %q", "ac", calls)
}
// ensure that the id was added in the exclude list
excluded := []string{"test1", "test3"}
if len(r.excludedMiddlewares) != len(excluded) {
t.Fatalf("Expected excludes %v, got %v", excluded, r.excludedMiddlewares)
}
for id := range r.excludedMiddlewares {
if !slices.Contains(excluded, id) {
t.Fatalf("Expected %q to be marked as excluded", id)
}
}
}

329
tools/router/router.go Normal file
View file

@ -0,0 +1,329 @@
package router
import (
"bufio"
"encoding/json"
"errors"
"io"
"log"
"net"
"net/http"
"github.com/pocketbase/pocketbase/tools/hook"
)
type EventCleanupFunc func()
// EventFactoryFunc defines the function responsible for creating a Route specific event
// based on the provided request handler ServeHTTP data.
//
// Optionally return a clean up function that will be invoked right after the route execution.
type EventFactoryFunc[T hook.Resolver] func(w http.ResponseWriter, r *http.Request) (T, EventCleanupFunc)
// Router defines a thin wrapper around the standard Go [http.ServeMux] by
// adding support for routing sub-groups, middlewares and other common utils.
//
// Example:
//
// r := NewRouter[*MyEvent](eventFactory)
//
// // middlewares
// r.BindFunc(m1, m2)
//
// // routes
// r.GET("/test", handler1)
//
// // sub-routers/groups
// api := r.Group("/api")
// api.GET("/admins", handler2)
//
// // generate a http.ServeMux instance based on the router configurations
// mux, _ := r.BuildMux()
//
// http.ListenAndServe("localhost:8090", mux)
type Router[T hook.Resolver] struct {
// @todo consider renaming the type to just Group and replace the embed type
// with an alias after Go 1.24 adds support for generic type aliases
*RouterGroup[T]
eventFactory EventFactoryFunc[T]
}
// NewRouter creates a new Router instance with the provided event factory function.
func NewRouter[T hook.Resolver](eventFactory EventFactoryFunc[T]) *Router[T] {
return &Router[T]{
RouterGroup: &RouterGroup[T]{},
eventFactory: eventFactory,
}
}
// BuildMux constructs a new mux [http.Handler] instance from the current router configurations.
func (r *Router[T]) BuildMux() (http.Handler, error) {
// Note that some of the default std Go handlers like the [http.NotFoundHandler]
// cannot be currently extended and requires defining a custom "catch-all" route
// so that the group middlewares could be executed.
//
// https://github.com/golang/go/issues/65648
if !r.HasRoute("", "/") {
r.Route("", "/", func(e T) error {
return NewNotFoundError("", nil)
})
}
mux := http.NewServeMux()
if err := r.loadMux(mux, r.RouterGroup, nil); err != nil {
return nil, err
}
return mux, nil
}
func (r *Router[T]) loadMux(mux *http.ServeMux, group *RouterGroup[T], parents []*RouterGroup[T]) error {
for _, child := range group.children {
switch v := child.(type) {
case *RouterGroup[T]:
if err := r.loadMux(mux, v, append(parents, group)); err != nil {
return err
}
case *Route[T]:
routeHook := &hook.Hook[T]{}
var pattern string
if v.Method != "" {
pattern = v.Method + " "
}
// add parent groups middlewares
for _, p := range parents {
pattern += p.Prefix
for _, h := range p.Middlewares {
if _, ok := p.excludedMiddlewares[h.Id]; !ok {
if _, ok = group.excludedMiddlewares[h.Id]; !ok {
if _, ok = v.excludedMiddlewares[h.Id]; !ok {
routeHook.Bind(h)
}
}
}
}
}
// add current groups middlewares
pattern += group.Prefix
for _, h := range group.Middlewares {
if _, ok := group.excludedMiddlewares[h.Id]; !ok {
if _, ok = v.excludedMiddlewares[h.Id]; !ok {
routeHook.Bind(h)
}
}
}
// add current route middlewares
pattern += v.Path
for _, h := range v.Middlewares {
if _, ok := v.excludedMiddlewares[h.Id]; !ok {
routeHook.Bind(h)
}
}
mux.HandleFunc(pattern, func(resp http.ResponseWriter, req *http.Request) {
// wrap the response to add write and status tracking
resp = &ResponseWriter{ResponseWriter: resp}
// wrap the request body to allow multiple reads
req.Body = &RereadableReadCloser{ReadCloser: req.Body}
event, cleanupFunc := r.eventFactory(resp, req)
// trigger the handler hook chain
err := routeHook.Trigger(event, v.Action)
if err != nil {
ErrorHandler(resp, req, err)
}
if cleanupFunc != nil {
cleanupFunc()
}
})
default:
return errors.New("invalid Group item type")
}
}
return nil
}
func ErrorHandler(resp http.ResponseWriter, req *http.Request, err error) {
if err == nil {
return
}
if ok, _ := getWritten(resp); ok {
return // a response was already written (aka. already handled)
}
header := resp.Header()
if header.Get("Content-Type") == "" {
header.Set("Content-Type", "application/json")
}
apiErr := ToApiError(err)
resp.WriteHeader(apiErr.Status)
if req.Method != http.MethodHead {
if jsonErr := json.NewEncoder(resp).Encode(apiErr); jsonErr != nil {
log.Println(jsonErr) // truly rare case, log to stderr only for dev purposes
}
}
}
// -------------------------------------------------------------------
type WriteTracker interface {
// Written reports whether a write operation has occurred.
Written() bool
}
type StatusTracker interface {
// Status reports the written response status code.
Status() int
}
type flushErrorer interface {
FlushError() error
}
var (
_ WriteTracker = (*ResponseWriter)(nil)
_ StatusTracker = (*ResponseWriter)(nil)
_ http.Flusher = (*ResponseWriter)(nil)
_ http.Hijacker = (*ResponseWriter)(nil)
_ http.Pusher = (*ResponseWriter)(nil)
_ io.ReaderFrom = (*ResponseWriter)(nil)
_ flushErrorer = (*ResponseWriter)(nil)
)
// ResponseWriter wraps a http.ResponseWriter to track its write state.
type ResponseWriter struct {
http.ResponseWriter
written bool
status int
}
func (rw *ResponseWriter) WriteHeader(status int) {
if rw.written {
return
}
rw.written = true
rw.status = status
rw.ResponseWriter.WriteHeader(status)
}
func (rw *ResponseWriter) Write(b []byte) (int, error) {
if !rw.written {
rw.WriteHeader(http.StatusOK)
}
return rw.ResponseWriter.Write(b)
}
// Written implements [WriteTracker] and returns whether the current response body has been already written.
func (rw *ResponseWriter) Written() bool {
return rw.written
}
// Written implements [StatusTracker] and returns the written status code of the current response.
func (rw *ResponseWriter) Status() int {
return rw.status
}
// Flush implements [http.Flusher] and allows an HTTP handler to flush buffered data to the client.
// This method is no-op if the wrapped writer doesn't support it.
func (rw *ResponseWriter) Flush() {
_ = rw.FlushError()
}
// FlushError is similar to [Flush] but returns [http.ErrNotSupported]
// if the wrapped writer doesn't support it.
func (rw *ResponseWriter) FlushError() error {
err := http.NewResponseController(rw.ResponseWriter).Flush()
if err == nil || !errors.Is(err, http.ErrNotSupported) {
rw.written = true
}
return err
}
// Hijack implements [http.Hijacker] and allows an HTTP handler to take over the current connection.
func (rw *ResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
return http.NewResponseController(rw.ResponseWriter).Hijack()
}
// Pusher implements [http.Pusher] to indicate HTTP/2 server push support.
func (rw *ResponseWriter) Push(target string, opts *http.PushOptions) error {
w := rw.ResponseWriter
for {
switch p := w.(type) {
case http.Pusher:
return p.Push(target, opts)
case RWUnwrapper:
w = p.Unwrap()
default:
return http.ErrNotSupported
}
}
}
// ReaderFrom implements [io.ReaderFrom] by checking if the underlying writer supports it.
// Otherwise calls [io.Copy].
func (rw *ResponseWriter) ReadFrom(r io.Reader) (n int64, err error) {
if !rw.written {
rw.WriteHeader(http.StatusOK)
}
w := rw.ResponseWriter
for {
switch rf := w.(type) {
case io.ReaderFrom:
return rf.ReadFrom(r)
case RWUnwrapper:
w = rf.Unwrap()
default:
return io.Copy(rw.ResponseWriter, r)
}
}
}
// Unwrap returns the underlying ResponseWritter instance (usually used by [http.ResponseController]).
func (rw *ResponseWriter) Unwrap() http.ResponseWriter {
return rw.ResponseWriter
}
func getWritten(rw http.ResponseWriter) (bool, error) {
for {
switch w := rw.(type) {
case WriteTracker:
return w.Written(), nil
case RWUnwrapper:
rw = w.Unwrap()
default:
return false, http.ErrNotSupported
}
}
}
func getStatus(rw http.ResponseWriter) (int, error) {
for {
switch w := rw.(type) {
case StatusTracker:
return w.Status(), nil
case RWUnwrapper:
rw = w.Unwrap()
default:
return 0, http.ErrNotSupported
}
}
}

253
tools/router/router_test.go Normal file
View file

@ -0,0 +1,253 @@
package router_test
import (
"errors"
"net/http"
"net/http/httptest"
"testing"
"github.com/pocketbase/pocketbase/tools/hook"
"github.com/pocketbase/pocketbase/tools/router"
)
func TestRouter(t *testing.T) {
calls := ""
r := router.NewRouter(func(w http.ResponseWriter, r *http.Request) (*router.Event, router.EventCleanupFunc) {
return &router.Event{
Response: w,
Request: r,
},
func() {
calls += ":cleanup"
}
})
r.BindFunc(func(e *router.Event) error {
calls += "root_m:"
err := e.Next()
if err != nil {
calls += "/error"
}
return err
})
r.Any("/any", func(e *router.Event) error {
calls += "/any"
return nil
})
r.GET("/a", func(e *router.Event) error {
calls += "/a"
return nil
})
g1 := r.Group("/a/b").BindFunc(func(e *router.Event) error {
calls += "a_b_group_m:"
return e.Next()
})
g1.GET("/1", func(e *router.Event) error {
calls += "/1_get"
return nil
}).BindFunc(func(e *router.Event) error {
calls += "1_get_m:"
return e.Next()
})
g1.POST("/1", func(e *router.Event) error {
calls += "/1_post"
return nil
})
g1.GET("/{param}", func(e *router.Event) error {
calls += "/" + e.Request.PathValue("param")
return errors.New("test") // should be normalized to an ApiError
})
mux, err := r.BuildMux()
if err != nil {
t.Fatal(err)
}
ts := httptest.NewServer(mux)
defer ts.Close()
client := ts.Client()
scenarios := []struct {
method string
path string
calls string
}{
{http.MethodGet, "/any", "root_m:/any:cleanup"},
{http.MethodOptions, "/any", "root_m:/any:cleanup"},
{http.MethodPatch, "/any", "root_m:/any:cleanup"},
{http.MethodPut, "/any", "root_m:/any:cleanup"},
{http.MethodPost, "/any", "root_m:/any:cleanup"},
{http.MethodDelete, "/any", "root_m:/any:cleanup"},
// ---
{http.MethodPost, "/a", "root_m:/error:cleanup"}, // missing
{http.MethodGet, "/a", "root_m:/a:cleanup"},
{http.MethodHead, "/a", "root_m:/a:cleanup"}, // auto registered with the GET
{http.MethodGet, "/a/b/1", "root_m:a_b_group_m:1_get_m:/1_get:cleanup"},
{http.MethodHead, "/a/b/1", "root_m:a_b_group_m:1_get_m:/1_get:cleanup"},
{http.MethodPost, "/a/b/1", "root_m:a_b_group_m:/1_post:cleanup"},
{http.MethodGet, "/a/b/456", "root_m:a_b_group_m:/456/error:cleanup"},
}
for _, s := range scenarios {
t.Run(s.method+"_"+s.path, func(t *testing.T) {
calls = "" // reset
req, err := http.NewRequest(s.method, ts.URL+s.path, nil)
if err != nil {
t.Fatal(err)
}
_, err = client.Do(req)
if err != nil {
t.Fatal(err)
}
if calls != s.calls {
t.Fatalf("Expected calls\n%q\ngot\n%q", s.calls, calls)
}
})
}
}
func TestRouterUnbind(t *testing.T) {
calls := ""
r := router.NewRouter(func(w http.ResponseWriter, r *http.Request) (*router.Event, router.EventCleanupFunc) {
return &router.Event{
Response: w,
Request: r,
},
func() {
calls += ":cleanup"
}
})
r.Bind(&hook.Handler[*router.Event]{
Id: "root_1",
Func: func(e *router.Event) error {
calls += "root_1:"
return e.Next()
},
})
r.Bind(&hook.Handler[*router.Event]{
Id: "root_2",
Func: func(e *router.Event) error {
calls += "root_2:"
return e.Next()
},
})
r.Bind(&hook.Handler[*router.Event]{
Id: "root_3",
Func: func(e *router.Event) error {
calls += "root_3:"
return e.Next()
},
})
r.GET("/action", func(e *router.Event) error {
calls += "root_action"
return nil
}).Unbind("root_1")
ga := r.Group("/group_a")
ga.Unbind("root_1")
ga.Bind(&hook.Handler[*router.Event]{
Id: "group_a_1",
Func: func(e *router.Event) error {
calls += "group_a_1:"
return e.Next()
},
})
ga.Bind(&hook.Handler[*router.Event]{
Id: "group_a_2",
Func: func(e *router.Event) error {
calls += "group_a_2:"
return e.Next()
},
})
ga.Bind(&hook.Handler[*router.Event]{
Id: "group_a_3",
Func: func(e *router.Event) error {
calls += "group_a_3:"
return e.Next()
},
})
ga.GET("/action", func(e *router.Event) error {
calls += "group_a_action"
return nil
}).Unbind("root_2", "group_b_1", "group_a_1")
gb := r.Group("/group_b")
gb.Unbind("root_2")
gb.Bind(&hook.Handler[*router.Event]{
Id: "group_b_1",
Func: func(e *router.Event) error {
calls += "group_b_1:"
return e.Next()
},
})
gb.Bind(&hook.Handler[*router.Event]{
Id: "group_b_2",
Func: func(e *router.Event) error {
calls += "group_b_2:"
return e.Next()
},
})
gb.Bind(&hook.Handler[*router.Event]{
Id: "group_b_3",
Func: func(e *router.Event) error {
calls += "group_b_3:"
return e.Next()
},
})
gb.GET("/action", func(e *router.Event) error {
calls += "group_b_action"
return nil
}).Unbind("group_b_3", "group_a_3", "root_3")
mux, err := r.BuildMux()
if err != nil {
t.Fatal(err)
}
ts := httptest.NewServer(mux)
defer ts.Close()
client := ts.Client()
scenarios := []struct {
method string
path string
calls string
}{
{http.MethodGet, "/action", "root_2:root_3:root_action:cleanup"},
{http.MethodGet, "/group_a/action", "root_3:group_a_2:group_a_3:group_a_action:cleanup"},
{http.MethodGet, "/group_b/action", "root_1:group_b_1:group_b_2:group_b_action:cleanup"},
}
for _, s := range scenarios {
t.Run(s.method+"_"+s.path, func(t *testing.T) {
calls = "" // reset
req, err := http.NewRequest(s.method, ts.URL+s.path, nil)
if err != nil {
t.Fatal(err)
}
_, err = client.Do(req)
if err != nil {
t.Fatal(err)
}
if calls != s.calls {
t.Fatalf("Expected calls\n%q\ngot\n%q", s.calls, calls)
}
})
}
}

View file

@ -0,0 +1,338 @@
package router
import (
"encoding"
"encoding/json"
"errors"
"reflect"
"regexp"
"strconv"
)
var textUnmarshalerType = reflect.TypeOf((*encoding.TextUnmarshaler)(nil)).Elem()
// JSONPayloadKey is the key for the special UnmarshalRequestData case
// used for reading serialized json payload without normalization.
const JSONPayloadKey string = "@jsonPayload"
// UnmarshalRequestData unmarshals url.Values type of data (query, multipart/form-data, etc.) into dst.
//
// dst must be a pointer to a map[string]any or struct.
//
// If dst is a map[string]any, each data value will be inferred and
// converted to its bool, numeric, or string equivalent value
// (refer to inferValue() for the exact rules).
//
// If dst is a struct, the following field types are supported:
// - bool
// - string
// - int, int8, int16, int32, int64
// - uint, uint8, uint16, uint32, uint64
// - float32, float64
// - serialized json string if submitted under the special "@jsonPayload" key
// - encoding.TextUnmarshaler
// - pointer and slice variations of the above primitives (ex. *string, []string, *[]string []*string, etc.)
// - named/anonymous struct fields
// Dot-notation is used to target nested fields, ex. "nestedStructField.title".
// - embedded struct fields
// The embedded struct fields are treated by default as if they were defined in their parent struct.
// If the embedded struct has a tag matching structTagKey then to set its fields the data keys must be prefixed with that tag
// similar to the regular nested struct fields.
//
// structTagKey and structPrefix are used only when dst is a struct.
//
// structTagKey represents the tag to use to match a data entry with a struct field (defaults to "form").
// If the struct field doesn't have the structTagKey tag, then the exported struct field name will be used as it is.
//
// structPrefix could be provided if all of the data keys are prefixed with a common string
// and you want the struct field to match only the value without the structPrefix
// (ex. for "user.name", "user.email" data keys and structPrefix "user", it will match "name" and "email" struct fields).
//
// Note that while the method was inspired by binders from echo, gorrila/schema, ozzo-routing
// and other similar common routing packages, it is not intended to be a drop-in replacement.
//
// @todo Consider adding support for dot-notation keys, in addition to the prefix, (ex. parent.child.title) to express nested object keys.
func UnmarshalRequestData(data map[string][]string, dst any, structTagKey string, structPrefix string) error {
if len(data) == 0 {
return nil // nothing to unmarshal
}
dstValue := reflect.ValueOf(dst)
if dstValue.Kind() != reflect.Pointer {
return errors.New("dst must be a pointer")
}
dstValue = dereference(dstValue)
dstType := dstValue.Type()
switch dstType.Kind() {
case reflect.Map: // map[string]any
if dstType.Elem().Kind() != reflect.Interface {
return errors.New("dst map value type must be any/interface{}")
}
for k, v := range data {
if k == JSONPayloadKey {
continue // unmarshaled separately
}
total := len(v)
if total == 1 {
dstValue.SetMapIndex(reflect.ValueOf(k), reflect.ValueOf(inferValue(v[0])))
} else {
normalized := make([]any, total)
for i, vItem := range v {
normalized[i] = inferValue(vItem)
}
dstValue.SetMapIndex(reflect.ValueOf(k), reflect.ValueOf(normalized))
}
}
case reflect.Struct:
// set a default tag key
if structTagKey == "" {
structTagKey = "form"
}
err := unmarshalInStructValue(data, dstValue, structTagKey, structPrefix)
if err != nil {
return err
}
default:
return errors.New("dst must be a map[string]any or struct")
}
// @jsonPayload
//
// Special case to scan serialized json string without
// normalization alongside the other data values
// ---------------------------------------------------------------
jsonPayloadValues := data[JSONPayloadKey]
for _, payload := range jsonPayloadValues {
if err := json.Unmarshal([]byte(payload), dst); err != nil {
return err
}
}
return nil
}
// unmarshalInStructValue unmarshals data into the provided struct reflect.Value fields.
func unmarshalInStructValue(
data map[string][]string,
dstStructValue reflect.Value,
structTagKey string,
structPrefix string,
) error {
dstStructType := dstStructValue.Type()
for i := 0; i < dstStructValue.NumField(); i++ {
fieldType := dstStructType.Field(i)
tag := fieldType.Tag.Get(structTagKey)
if tag == "-" || (!fieldType.Anonymous && !fieldType.IsExported()) {
continue // disabled or unexported non-anonymous struct field
}
fieldValue := dereference(dstStructValue.Field(i))
ft := fieldType.Type
if ft.Kind() == reflect.Ptr {
ft = ft.Elem()
}
isSlice := ft.Kind() == reflect.Slice
if isSlice {
ft = ft.Elem()
}
name := tag
if name == "" && !fieldType.Anonymous {
name = fieldType.Name
}
if name != "" && structPrefix != "" {
name = structPrefix + "." + name
}
// (*)encoding.TextUnmarshaler field
// ---
if ft.Implements(textUnmarshalerType) || reflect.PointerTo(ft).Implements(textUnmarshalerType) {
values, ok := data[name]
if !ok || len(values) == 0 || !fieldValue.CanSet() {
continue // no value to load or the field cannot be set
}
if isSlice {
n := len(values)
slice := reflect.MakeSlice(fieldValue.Type(), n, n)
for i, v := range values {
unmarshaler, ok := dereference(slice.Index(i)).Addr().Interface().(encoding.TextUnmarshaler)
if ok {
if err := unmarshaler.UnmarshalText([]byte(v)); err != nil {
return err
}
}
}
fieldValue.Set(slice)
} else {
unmarshaler, ok := fieldValue.Addr().Interface().(encoding.TextUnmarshaler)
if ok {
if err := unmarshaler.UnmarshalText([]byte(values[0])); err != nil {
return err
}
}
}
continue
}
// "regular" field
// ---
if ft.Kind() != reflect.Struct {
values, ok := data[name]
if !ok || len(values) == 0 || !fieldValue.CanSet() {
continue // no value to load
}
if isSlice {
n := len(values)
slice := reflect.MakeSlice(fieldValue.Type(), n, n)
for i, v := range values {
if err := setRegularReflectedValue(dereference(slice.Index(i)), v); err != nil {
return err
}
}
fieldValue.Set(slice)
} else {
if err := setRegularReflectedValue(fieldValue, values[0]); err != nil {
return err
}
}
continue
}
// structs (embedded or nested)
// ---
// slice of structs
if isSlice {
// populating slice of structs is not supported at the moment
// because the filling rules are ambiguous
continue
}
if tag != "" {
structPrefix = tag
} else {
structPrefix = name // name is empty for anonymous structs -> no prefix
}
if err := unmarshalInStructValue(data, fieldValue, structTagKey, structPrefix); err != nil {
return err
}
}
return nil
}
// dereference returns the underlying value v points to.
func dereference(v reflect.Value) reflect.Value {
for v.Kind() == reflect.Ptr {
if v.IsNil() {
// initialize with a new value and continue searching
v.Set(reflect.New(v.Type().Elem()))
}
v = v.Elem()
}
return v
}
// setRegularReflectedValue sets and casts value into rv.
func setRegularReflectedValue(rv reflect.Value, value string) error {
switch rv.Kind() {
case reflect.String:
rv.SetString(value)
case reflect.Bool:
if value == "" {
value = "f"
}
v, err := strconv.ParseBool(value)
if err != nil {
return err
}
rv.SetBool(v)
case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int:
if value == "" {
value = "0"
}
v, err := strconv.ParseInt(value, 0, 64)
if err != nil {
return err
}
rv.SetInt(v)
case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint:
if value == "" {
value = "0"
}
v, err := strconv.ParseUint(value, 0, 64)
if err != nil {
return err
}
rv.SetUint(v)
case reflect.Float32, reflect.Float64:
if value == "" {
value = "0"
}
v, err := strconv.ParseFloat(value, 64)
if err != nil {
return err
}
rv.SetFloat(v)
default:
return errors.New("unknown value type " + rv.Kind().String())
}
return nil
}
var inferNumberCharsRegex = regexp.MustCompile(`^[\-\.\d]+$`)
// In order to support more seamlessly both json and multipart/form-data requests,
// the following normalization rules are applied for plain multipart string values:
// - "true" is converted to the json "true"
// - "false" is converted to the json "false"
// - numeric strings are converted to json number ONLY if the resulted
// minimal number string representation is the same as the provided raw string
// (aka. scientific notations, "Infinity", "0.0", "0001", etc. are kept as string)
// - any other string (empty string too) is left as it is
func inferValue(raw string) any {
switch raw {
case "":
return raw
case "true":
return true
case "false":
return false
default:
// try to convert to number
//
// note: expects the provided raw string to match exactly with the minimal string representation of the parsed float
if (raw[0] == '-' || (raw[0] >= '0' && raw[0] <= '9')) &&
inferNumberCharsRegex.Match([]byte(raw)) {
v, err := strconv.ParseFloat(raw, 64)
if err == nil && strconv.FormatFloat(v, 'f', -1, 64) == raw {
return v
}
}
return raw
}
}

View file

@ -0,0 +1,471 @@
package router_test
import (
"bytes"
"encoding/json"
"testing"
"time"
"github.com/pocketbase/pocketbase/tools/router"
)
func pointer[T any](val T) *T {
return &val
}
func TestUnmarshalRequestData(t *testing.T) {
t.Parallel()
mapData := map[string][]string{
"number1": {"1"},
"number2": {"2", "3"},
"number3": {"2.1", "-3.4"},
"number4": {"0", "-0", "0.0001"},
"string0": {""},
"string1": {"a"},
"string2": {"b", "c"},
"string3": {
"0.0",
"-0.0",
"000.1",
"000001",
"-000001",
"1.6E-35",
"-1.6E-35",
"10e100",
"1_000_000",
"1.000.000",
" 123 ",
"0b1",
"0xFF",
"1234A",
"Infinity",
"-Infinity",
"undefined",
"null",
},
"bool1": {"true"},
"bool2": {"true", "false"},
"mixed": {"true", "123", "test"},
"@jsonPayload": {`{"json_a":null,"json_b":123}`, `{"json_c":[1,2,3]}`},
}
structData := map[string][]string{
"stringTag": {"a", "b"},
"StringPtr": {"b"},
"StringSlice": {"a", "b", "c", ""},
"stringSlicePtrTag": {"d", "e"},
"StringSliceOfPtr": {"f", "g"},
"boolTag": {"true"},
"BoolPtr": {"true"},
"BoolSlice": {"true", "false", ""},
"boolSlicePtrTag": {"false", "false", "true"},
"BoolSliceOfPtr": {"false", "true", "false"},
"int8Tag": {"-1", "2"},
"Int8Ptr": {"3"},
"Int8Slice": {"4", "5", ""},
"int8SlicePtrTag": {"5", "6"},
"Int8SliceOfPtr": {"7", "8"},
"int16Tag": {"-1", "2"},
"Int16Ptr": {"3"},
"Int16Slice": {"4", "5", ""},
"int16SlicePtrTag": {"5", "6"},
"Int16SliceOfPtr": {"7", "8"},
"int32Tag": {"-1", "2"},
"Int32Ptr": {"3"},
"Int32Slice": {"4", "5", ""},
"int32SlicePtrTag": {"5", "6"},
"Int32SliceOfPtr": {"7", "8"},
"int64Tag": {"-1", "2"},
"Int64Ptr": {"3"},
"Int64Slice": {"4", "5", ""},
"int64SlicePtrTag": {"5", "6"},
"Int64SliceOfPtr": {"7", "8"},
"intTag": {"-1", "2"},
"IntPtr": {"3"},
"IntSlice": {"4", "5", ""},
"intSlicePtrTag": {"5", "6"},
"IntSliceOfPtr": {"7", "8"},
"uint8Tag": {"1", "2"},
"Uint8Ptr": {"3"},
"Uint8Slice": {"4", "5", ""},
"uint8SlicePtrTag": {"5", "6"},
"Uint8SliceOfPtr": {"7", "8"},
"uint16Tag": {"1", "2"},
"Uint16Ptr": {"3"},
"Uint16Slice": {"4", "5", ""},
"uint16SlicePtrTag": {"5", "6"},
"Uint16SliceOfPtr": {"7", "8"},
"uint32Tag": {"1", "2"},
"Uint32Ptr": {"3"},
"Uint32Slice": {"4", "5", ""},
"uint32SlicePtrTag": {"5", "6"},
"Uint32SliceOfPtr": {"7", "8"},
"uint64Tag": {"1", "2"},
"Uint64Ptr": {"3"},
"Uint64Slice": {"4", "5", ""},
"uint64SlicePtrTag": {"5", "6"},
"Uint64SliceOfPtr": {"7", "8"},
"uintTag": {"1", "2"},
"UintPtr": {"3"},
"UintSlice": {"4", "5", ""},
"uintSlicePtrTag": {"5", "6"},
"UintSliceOfPtr": {"7", "8"},
"float32Tag": {"-1.2"},
"Float32Ptr": {"1.5", "2.0"},
"Float32Slice": {"1", "2.3", "-0.3", ""},
"float32SlicePtrTag": {"-1.3", "3"},
"Float32SliceOfPtr": {"0", "1.2"},
"float64Tag": {"-1.2"},
"Float64Ptr": {"1.5", "2.0"},
"Float64Slice": {"1", "2.3", "-0.3", ""},
"float64SlicePtrTag": {"-1.3", "3"},
"Float64SliceOfPtr": {"0", "1.2"},
"timeTag": {"2009-11-10T15:00:00Z"},
"TimePtr": {"2009-11-10T14:00:00Z", "2009-11-10T15:00:00Z"},
"TimeSlice": {"2009-11-10T14:00:00Z", "2009-11-10T15:00:00Z"},
"timeSlicePtrTag": {"2009-11-10T15:00:00Z", "2009-11-10T16:00:00Z"},
"TimeSliceOfPtr": {"2009-11-10T17:00:00Z", "2009-11-10T18:00:00Z"},
// @jsonPayload fields
"@jsonPayload": {
`{"payloadA":"test", "shouldBeIgnored": "abc"}`,
`{"payloadB":[1,2,3], "payloadC":true}`,
},
// unexported fields or `-` tags
"unexperted": {"test"},
"SkipExported": {"test"},
"unexportedStructFieldWithoutTag.Name": {"test"},
"unexportedStruct.Name": {"test"},
// structs
"StructWithoutTag.Name": {"test1"},
"exportedStruct.Name": {"test2"},
// embedded
"embed_name": {"test3"},
"embed2.embed_name2": {"test4"},
}
type embed1 struct {
Name string `form:"embed_name" json:"embed_name"`
}
type embed2 struct {
Name string `form:"embed_name2" json:"embed_name2"`
}
//nolint
type TestStruct struct {
String string `form:"stringTag" query:"stringTag2"`
StringPtr *string
StringSlice []string
StringSlicePtr *[]string `form:"stringSlicePtrTag"`
StringSliceOfPtr []*string
Bool bool `form:"boolTag" query:"boolTag2"`
BoolPtr *bool
BoolSlice []bool
BoolSlicePtr *[]bool `form:"boolSlicePtrTag"`
BoolSliceOfPtr []*bool
Int8 int8 `form:"int8Tag" query:"int8Tag2"`
Int8Ptr *int8
Int8Slice []int8
Int8SlicePtr *[]int8 `form:"int8SlicePtrTag"`
Int8SliceOfPtr []*int8
Int16 int16 `form:"int16Tag" query:"int16Tag2"`
Int16Ptr *int16
Int16Slice []int16
Int16SlicePtr *[]int16 `form:"int16SlicePtrTag"`
Int16SliceOfPtr []*int16
Int32 int32 `form:"int32Tag" query:"int32Tag2"`
Int32Ptr *int32
Int32Slice []int32
Int32SlicePtr *[]int32 `form:"int32SlicePtrTag"`
Int32SliceOfPtr []*int32
Int64 int64 `form:"int64Tag" query:"int64Tag2"`
Int64Ptr *int64
Int64Slice []int64
Int64SlicePtr *[]int64 `form:"int64SlicePtrTag"`
Int64SliceOfPtr []*int64
Int int `form:"intTag" query:"intTag2"`
IntPtr *int
IntSlice []int
IntSlicePtr *[]int `form:"intSlicePtrTag"`
IntSliceOfPtr []*int
Uint8 uint8 `form:"uint8Tag" query:"uint8Tag2"`
Uint8Ptr *uint8
Uint8Slice []uint8
Uint8SlicePtr *[]uint8 `form:"uint8SlicePtrTag"`
Uint8SliceOfPtr []*uint8
Uint16 uint16 `form:"uint16Tag" query:"uint16Tag2"`
Uint16Ptr *uint16
Uint16Slice []uint16
Uint16SlicePtr *[]uint16 `form:"uint16SlicePtrTag"`
Uint16SliceOfPtr []*uint16
Uint32 uint32 `form:"uint32Tag" query:"uint32Tag2"`
Uint32Ptr *uint32
Uint32Slice []uint32
Uint32SlicePtr *[]uint32 `form:"uint32SlicePtrTag"`
Uint32SliceOfPtr []*uint32
Uint64 uint64 `form:"uint64Tag" query:"uint64Tag2"`
Uint64Ptr *uint64
Uint64Slice []uint64
Uint64SlicePtr *[]uint64 `form:"uint64SlicePtrTag"`
Uint64SliceOfPtr []*uint64
Uint uint `form:"uintTag" query:"uintTag2"`
UintPtr *uint
UintSlice []uint
UintSlicePtr *[]uint `form:"uintSlicePtrTag"`
UintSliceOfPtr []*uint
Float32 float32 `form:"float32Tag" query:"float32Tag2"`
Float32Ptr *float32
Float32Slice []float32
Float32SlicePtr *[]float32 `form:"float32SlicePtrTag"`
Float32SliceOfPtr []*float32
Float64 float64 `form:"float64Tag" query:"float64Tag2"`
Float64Ptr *float64
Float64Slice []float64
Float64SlicePtr *[]float64 `form:"float64SlicePtrTag"`
Float64SliceOfPtr []*float64
// encoding.TextUnmarshaler
Time time.Time `form:"timeTag" query:"timeTag2"`
TimePtr *time.Time
TimeSlice []time.Time
TimeSlicePtr *[]time.Time `form:"timeSlicePtrTag"`
TimeSliceOfPtr []*time.Time
// @jsonPayload fields
JSONPayloadA string `form:"shouldBeIgnored" json:"payloadA"`
JSONPayloadB []int `json:"payloadB"`
JSONPayloadC bool `json:"-"`
// unexported fields or `-` tags
unexported string
SkipExported string `form:"-"`
unexportedStructFieldWithoutTag struct {
Name string `json:"unexportedStructFieldWithoutTag_name"`
}
unexportedStructFieldWithTag struct {
Name string `json:"unexportedStructFieldWithTag_name"`
} `form:"unexportedStruct"`
// structs
StructWithoutTag struct {
Name string `json:"StructWithoutTag_name"`
}
StructWithTag struct {
Name string `json:"StructWithTag_name"`
} `form:"exportedStruct"`
// embedded
embed1
embed2 `form:"embed2"`
}
scenarios := []struct {
name string
data map[string][]string
dst any
tag string
prefix string
error bool
result string
}{
{
name: "nil data",
data: nil,
dst: pointer(map[string]any{}),
error: false,
result: `{}`,
},
{
name: "non-pointer map[string]any",
data: mapData,
dst: map[string]any{},
error: true,
},
{
name: "unsupported *map[string]string",
data: mapData,
dst: pointer(map[string]string{}),
error: true,
},
{
name: "unsupported *map[string][]string",
data: mapData,
dst: pointer(map[string][]string{}),
error: true,
},
{
name: "*map[string]any",
data: mapData,
dst: pointer(map[string]any{}),
result: `{"bool1":true,"bool2":[true,false],"json_a":null,"json_b":123,"json_c":[1,2,3],"mixed":[true,123,"test"],"number1":1,"number2":[2,3],"number3":[2.1,-3.4],"number4":[0,-0,0.0001],"string0":"","string1":"a","string2":["b","c"],"string3":["0.0","-0.0","000.1","000001","-000001","1.6E-35","-1.6E-35","10e100","1_000_000","1.000.000"," 123 ","0b1","0xFF","1234A","Infinity","-Infinity","undefined","null"]}`,
},
{
name: "valid pointer struct (all fields)",
data: structData,
dst: &TestStruct{},
result: `{"String":"a","StringPtr":"b","StringSlice":["a","b","c",""],"StringSlicePtr":["d","e"],"StringSliceOfPtr":["f","g"],"Bool":true,"BoolPtr":true,"BoolSlice":[true,false,false],"BoolSlicePtr":[false,false,true],"BoolSliceOfPtr":[false,true,false],"Int8":-1,"Int8Ptr":3,"Int8Slice":[4,5,0],"Int8SlicePtr":[5,6],"Int8SliceOfPtr":[7,8],"Int16":-1,"Int16Ptr":3,"Int16Slice":[4,5,0],"Int16SlicePtr":[5,6],"Int16SliceOfPtr":[7,8],"Int32":-1,"Int32Ptr":3,"Int32Slice":[4,5,0],"Int32SlicePtr":[5,6],"Int32SliceOfPtr":[7,8],"Int64":-1,"Int64Ptr":3,"Int64Slice":[4,5,0],"Int64SlicePtr":[5,6],"Int64SliceOfPtr":[7,8],"Int":-1,"IntPtr":3,"IntSlice":[4,5,0],"IntSlicePtr":[5,6],"IntSliceOfPtr":[7,8],"Uint8":1,"Uint8Ptr":3,"Uint8Slice":"BAUA","Uint8SlicePtr":"BQY=","Uint8SliceOfPtr":[7,8],"Uint16":1,"Uint16Ptr":3,"Uint16Slice":[4,5,0],"Uint16SlicePtr":[5,6],"Uint16SliceOfPtr":[7,8],"Uint32":1,"Uint32Ptr":3,"Uint32Slice":[4,5,0],"Uint32SlicePtr":[5,6],"Uint32SliceOfPtr":[7,8],"Uint64":1,"Uint64Ptr":3,"Uint64Slice":[4,5,0],"Uint64SlicePtr":[5,6],"Uint64SliceOfPtr":[7,8],"Uint":1,"UintPtr":3,"UintSlice":[4,5,0],"UintSlicePtr":[5,6],"UintSliceOfPtr":[7,8],"Float32":-1.2,"Float32Ptr":1.5,"Float32Slice":[1,2.3,-0.3,0],"Float32SlicePtr":[-1.3,3],"Float32SliceOfPtr":[0,1.2],"Float64":-1.2,"Float64Ptr":1.5,"Float64Slice":[1,2.3,-0.3,0],"Float64SlicePtr":[-1.3,3],"Float64SliceOfPtr":[0,1.2],"Time":"2009-11-10T15:00:00Z","TimePtr":"2009-11-10T14:00:00Z","TimeSlice":["2009-11-10T14:00:00Z","2009-11-10T15:00:00Z"],"TimeSlicePtr":["2009-11-10T15:00:00Z","2009-11-10T16:00:00Z"],"TimeSliceOfPtr":["2009-11-10T17:00:00Z","2009-11-10T18:00:00Z"],"payloadA":"test","payloadB":[1,2,3],"SkipExported":"","StructWithoutTag":{"StructWithoutTag_name":"test1"},"StructWithTag":{"StructWithTag_name":"test2"},"embed_name":"test3","embed_name2":"test4"}`,
},
{
name: "non-pointer struct",
data: structData,
dst: TestStruct{},
error: true,
},
{
name: "invalid struct uint value",
data: map[string][]string{"uintTag": {"-1"}},
dst: &TestStruct{},
error: true,
},
{
name: "invalid struct int value",
data: map[string][]string{"intTag": {"abc"}},
dst: &TestStruct{},
error: true,
},
{
name: "invalid struct bool value",
data: map[string][]string{"boolTag": {"abc"}},
dst: &TestStruct{},
error: true,
},
{
name: "invalid struct float value",
data: map[string][]string{"float64Tag": {"abc"}},
dst: &TestStruct{},
error: true,
},
{
name: "invalid struct TextUnmarshaler value",
data: map[string][]string{"timeTag": {"123"}},
dst: &TestStruct{},
error: true,
},
{
name: "custom tagKey",
data: map[string][]string{
"tag1": {"a"},
"tag2": {"b"},
"tag3": {"c"},
"Item": {"d"},
},
dst: &struct {
Item string `form:"tag1" query:"tag2" json:"tag2"`
}{},
tag: "query",
result: `{"tag2":"b"}`,
},
{
name: "custom prefix",
data: map[string][]string{
"test.A": {"1"},
"A": {"2"},
"test.alias": {"3"},
},
dst: &struct {
A string
B string `form:"alias"`
}{},
prefix: "test",
result: `{"A":"1","B":"3"}`,
},
}
for _, s := range scenarios {
t.Run(s.name, func(t *testing.T) {
err := router.UnmarshalRequestData(s.data, s.dst, s.tag, s.prefix)
hasErr := err != nil
if hasErr != s.error {
t.Fatalf("Expected hasErr %v, got %v (%v)", s.error, hasErr, err)
}
if hasErr {
return
}
raw, err := json.Marshal(s.dst)
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(raw, []byte(s.result)) {
t.Fatalf("Expected dst \n%s\ngot\n%s", s.result, raw)
}
})
}
}
// note: extra unexported checks in addition to the above test as there
// is no easy way to print nested structs with all their fields.
func TestUnmarshalRequestDataUnexportedFields(t *testing.T) {
t.Parallel()
//nolint:all
type TestStruct struct {
Exported string
unexported string
// to ensure that the reflection doesn't take tags with higher priority than the exported state
unexportedWithTag string `form:"unexportedWithTag" json:"unexportedWithTag"`
}
dst := &TestStruct{}
err := router.UnmarshalRequestData(map[string][]string{
"Exported": {"test"}, // just for reference
"Unexported": {"test"},
"unexported": {"test"},
"UnexportedWithTag": {"test"},
"unexportedWithTag": {"test"},
}, dst, "", "")
if err != nil {
t.Fatal(err)
}
if dst.Exported != "test" {
t.Fatalf("Expected the Exported field to be %q, got %q", "test", dst.Exported)
}
if dst.unexported != "" {
t.Fatalf("Expected the unexported field to remain empty, got %q", dst.unexported)
}
if dst.unexportedWithTag != "" {
t.Fatalf("Expected the unexportedWithTag field to remain empty, got %q", dst.unexportedWithTag)
}
}