1
0
Fork 0

Adding upstream version 0.28.1.

Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
Daniel Baumann 2025-05-22 10:57:38 +02:00
parent 88f1d47ab6
commit e28c88ef14
Signed by: daniel
GPG key ID: FBB4F0E80A80222F
933 changed files with 194711 additions and 0 deletions

47
apis/api_error_aliases.go Normal file
View file

@ -0,0 +1,47 @@
package apis
import "github.com/pocketbase/pocketbase/tools/router"
// ApiError aliases to minimize the breaking changes with earlier versions
// and for consistency with the JSVM binds.
// -------------------------------------------------------------------
// ToApiError wraps err into ApiError instance (if not already).
func ToApiError(err error) *router.ApiError {
return router.ToApiError(err)
}
// NewApiError is an alias for [router.NewApiError].
func NewApiError(status int, message string, errData any) *router.ApiError {
return router.NewApiError(status, message, errData)
}
// NewBadRequestError is an alias for [router.NewBadRequestError].
func NewBadRequestError(message string, errData any) *router.ApiError {
return router.NewBadRequestError(message, errData)
}
// NewNotFoundError is an alias for [router.NewNotFoundError].
func NewNotFoundError(message string, errData any) *router.ApiError {
return router.NewNotFoundError(message, errData)
}
// NewForbiddenError is an alias for [router.NewForbiddenError].
func NewForbiddenError(message string, errData any) *router.ApiError {
return router.NewForbiddenError(message, errData)
}
// NewUnauthorizedError is an alias for [router.NewUnauthorizedError].
func NewUnauthorizedError(message string, errData any) *router.ApiError {
return router.NewUnauthorizedError(message, errData)
}
// NewTooManyRequestsError is an alias for [router.NewTooManyRequestsError].
func NewTooManyRequestsError(message string, errData any) *router.ApiError {
return router.NewTooManyRequestsError(message, errData)
}
// NewInternalServerError is an alias for [router.NewInternalServerError].
func NewInternalServerError(message string, errData any) *router.ApiError {
return router.NewInternalServerError(message, errData)
}

155
apis/backup.go Normal file
View file

@ -0,0 +1,155 @@
package apis
import (
"context"
"net/http"
"path/filepath"
"time"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tools/router"
"github.com/pocketbase/pocketbase/tools/routine"
"github.com/pocketbase/pocketbase/tools/types"
"github.com/spf13/cast"
)
// bindBackupApi registers the file api endpoints and the corresponding handlers.
func bindBackupApi(app core.App, rg *router.RouterGroup[*core.RequestEvent]) {
sub := rg.Group("/backups")
sub.GET("", backupsList).Bind(RequireSuperuserAuth())
sub.POST("", backupCreate).Bind(RequireSuperuserAuth())
sub.POST("/upload", backupUpload).Bind(BodyLimit(0), RequireSuperuserAuth())
sub.GET("/{key}", backupDownload) // relies on superuser file token
sub.DELETE("/{key}", backupDelete).Bind(RequireSuperuserAuth())
sub.POST("/{key}/restore", backupRestore).Bind(RequireSuperuserAuth())
}
type backupFileInfo struct {
Modified types.DateTime `json:"modified"`
Key string `json:"key"`
Size int64 `json:"size"`
}
func backupsList(e *core.RequestEvent) error {
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
fsys, err := e.App.NewBackupsFilesystem()
if err != nil {
return e.BadRequestError("Failed to load backups filesystem.", err)
}
defer fsys.Close()
fsys.SetContext(ctx)
backups, err := fsys.List("")
if err != nil {
return e.BadRequestError("Failed to retrieve backup items. Raw error: \n"+err.Error(), nil)
}
result := make([]backupFileInfo, len(backups))
for i, obj := range backups {
modified, _ := types.ParseDateTime(obj.ModTime)
result[i] = backupFileInfo{
Key: obj.Key,
Size: obj.Size,
Modified: modified,
}
}
return e.JSON(http.StatusOK, result)
}
func backupDownload(e *core.RequestEvent) error {
fileToken := e.Request.URL.Query().Get("token")
authRecord, err := e.App.FindAuthRecordByToken(fileToken, core.TokenTypeFile)
if err != nil || !authRecord.IsSuperuser() {
return e.ForbiddenError("Insufficient permissions to access the resource.", err)
}
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute)
defer cancel()
fsys, err := e.App.NewBackupsFilesystem()
if err != nil {
return e.InternalServerError("Failed to load backups filesystem.", err)
}
defer fsys.Close()
fsys.SetContext(ctx)
key := e.Request.PathValue("key")
return fsys.Serve(
e.Response,
e.Request,
key,
filepath.Base(key), // without the path prefix (if any)
)
}
func backupDelete(e *core.RequestEvent) error {
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
fsys, err := e.App.NewBackupsFilesystem()
if err != nil {
return e.InternalServerError("Failed to load backups filesystem.", err)
}
defer fsys.Close()
fsys.SetContext(ctx)
key := e.Request.PathValue("key")
if key != "" && cast.ToString(e.App.Store().Get(core.StoreKeyActiveBackup)) == key {
return e.BadRequestError("The backup is currently being used and cannot be deleted.", nil)
}
if err := fsys.Delete(key); err != nil {
return e.BadRequestError("Invalid or already deleted backup file. Raw error: \n"+err.Error(), nil)
}
return e.NoContent(http.StatusNoContent)
}
func backupRestore(e *core.RequestEvent) error {
if e.App.Store().Has(core.StoreKeyActiveBackup) {
return e.BadRequestError("Try again later - another backup/restore process has already been started.", nil)
}
key := e.Request.PathValue("key")
existsCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
fsys, err := e.App.NewBackupsFilesystem()
if err != nil {
return e.InternalServerError("Failed to load backups filesystem.", err)
}
defer fsys.Close()
fsys.SetContext(existsCtx)
if exists, err := fsys.Exists(key); !exists {
return e.BadRequestError("Missing or invalid backup file.", err)
}
routine.FireAndForget(func() {
// give some optimistic time to write the response before restarting the app
time.Sleep(1 * time.Second)
// wait max 10 minutes to fetch the backup
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute)
defer cancel()
if err := e.App.RestoreBackup(ctx, key); err != nil {
e.App.Logger().Error("Failed to restore backup", "key", key, "error", err.Error())
}
})
return e.NoContent(http.StatusNoContent)
}

78
apis/backup_create.go Normal file
View file

@ -0,0 +1,78 @@
package apis
import (
"context"
"net/http"
"regexp"
validation "github.com/go-ozzo/ozzo-validation/v4"
"github.com/pocketbase/pocketbase/core"
)
func backupCreate(e *core.RequestEvent) error {
if e.App.Store().Has(core.StoreKeyActiveBackup) {
return e.BadRequestError("Try again later - another backup/restore process has already been started", nil)
}
form := new(backupCreateForm)
form.app = e.App
err := e.BindBody(form)
if err != nil {
return e.BadRequestError("An error occurred while loading the submitted data.", err)
}
err = form.validate()
if err != nil {
return e.BadRequestError("An error occurred while validating the submitted data.", err)
}
err = e.App.CreateBackup(context.Background(), form.Name)
if err != nil {
return e.BadRequestError("Failed to create backup.", err)
}
// we don't retrieve the generated backup file because it may not be
// available yet due to the eventually consistent nature of some S3 providers
return e.NoContent(http.StatusNoContent)
}
// -------------------------------------------------------------------
var backupNameRegex = regexp.MustCompile(`^[a-z0-9_-]+\.zip$`)
type backupCreateForm struct {
app core.App
Name string `form:"name" json:"name"`
}
func (form *backupCreateForm) validate() error {
return validation.ValidateStruct(form,
validation.Field(
&form.Name,
validation.Length(1, 150),
validation.Match(backupNameRegex),
validation.By(form.checkUniqueName),
),
)
}
func (form *backupCreateForm) checkUniqueName(value any) error {
v, _ := value.(string)
if v == "" {
return nil // nothing to check
}
fsys, err := form.app.NewBackupsFilesystem()
if err != nil {
return err
}
defer fsys.Close()
if exists, err := fsys.Exists(v); err != nil || exists {
return validation.NewError("validation_backup_name_exists", "The backup file name is invalid or already exists.")
}
return nil
}

823
apis/backup_test.go Normal file
View file

@ -0,0 +1,823 @@
package apis_test
import (
"archive/zip"
"bytes"
"context"
"io"
"mime/multipart"
"net/http"
"strings"
"testing"
"github.com/pocketbase/pocketbase/apis"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tests"
"github.com/pocketbase/pocketbase/tools/filesystem/blob"
)
func TestBackupsList(t *testing.T) {
t.Parallel()
scenarios := []tests.ApiScenario{
{
Name: "unauthorized",
Method: http.MethodGet,
URL: "/api/backups",
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
if err := createTestBackups(app); err != nil {
t.Fatal(err)
}
},
ExpectedStatus: 401,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "authorized as regular user",
Method: http.MethodGet,
URL: "/api/backups",
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
},
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
if err := createTestBackups(app); err != nil {
t.Fatal(err)
}
},
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "authorized as superuser (empty list)",
Method: http.MethodGet,
URL: "/api/backups",
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
},
ExpectedStatus: 200,
ExpectedContent: []string{`[]`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "authorized as superuser",
Method: http.MethodGet,
URL: "/api/backups",
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
},
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
if err := createTestBackups(app); err != nil {
t.Fatal(err)
}
},
ExpectedStatus: 200,
ExpectedContent: []string{
`"test1.zip"`,
`"test2.zip"`,
`"test3.zip"`,
},
ExpectedEvents: map[string]int{"*": 0},
},
}
for _, scenario := range scenarios {
scenario.Test(t)
}
}
func TestBackupsCreate(t *testing.T) {
t.Parallel()
scenarios := []tests.ApiScenario{
{
Name: "unauthorized",
Method: http.MethodPost,
URL: "/api/backups",
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
ensureNoBackups(t, app)
},
ExpectedStatus: 401,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "authorized as regular user",
Method: http.MethodPost,
URL: "/api/backups",
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
},
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
ensureNoBackups(t, app)
},
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "authorized as superuser (pending backup)",
Method: http.MethodPost,
URL: "/api/backups",
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
},
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
app.Store().Set(core.StoreKeyActiveBackup, "")
},
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
ensureNoBackups(t, app)
},
ExpectedStatus: 400,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "authorized as superuser (autogenerated name)",
Method: http.MethodPost,
URL: "/api/backups",
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
},
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
files, err := getBackupFiles(app)
if err != nil {
t.Fatal(err)
}
if total := len(files); total != 1 {
t.Fatalf("Expected 1 backup file, got %d", total)
}
expected := "pb_backup_"
if !strings.HasPrefix(files[0].Key, expected) {
t.Fatalf("Expected backup file with prefix %q, got %q", expected, files[0].Key)
}
},
ExpectedStatus: 204,
ExpectedEvents: map[string]int{
"*": 0,
"OnBackupCreate": 1,
},
},
{
Name: "authorized as superuser (invalid name)",
Method: http.MethodPost,
URL: "/api/backups",
Body: strings.NewReader(`{"name":"!test.zip"}`),
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
},
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
ensureNoBackups(t, app)
},
ExpectedStatus: 400,
ExpectedContent: []string{
`"data":{`,
`"name":{"code":"validation_match_invalid"`,
},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "authorized as superuser (valid name)",
Method: http.MethodPost,
URL: "/api/backups",
Body: strings.NewReader(`{"name":"test.zip"}`),
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
},
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
files, err := getBackupFiles(app)
if err != nil {
t.Fatal(err)
}
if total := len(files); total != 1 {
t.Fatalf("Expected 1 backup file, got %d", total)
}
expected := "test.zip"
if files[0].Key != expected {
t.Fatalf("Expected backup file %q, got %q", expected, files[0].Key)
}
},
ExpectedStatus: 204,
ExpectedEvents: map[string]int{
"*": 0,
"OnBackupCreate": 1,
},
},
}
for _, scenario := range scenarios {
scenario.Test(t)
}
}
func TestBackupUpload(t *testing.T) {
t.Parallel()
// create dummy form data bodies
type body struct {
buffer io.Reader
contentType string
}
bodies := make([]body, 10)
for i := 0; i < 10; i++ {
func() {
zb := new(bytes.Buffer)
zw := zip.NewWriter(zb)
if err := zw.Close(); err != nil {
t.Fatal(err)
}
b := new(bytes.Buffer)
mw := multipart.NewWriter(b)
mfw, err := mw.CreateFormFile("file", "test")
if err != nil {
t.Fatal(err)
}
if _, err := io.Copy(mfw, zb); err != nil {
t.Fatal(err)
}
mw.Close()
bodies[i] = body{
buffer: b,
contentType: mw.FormDataContentType(),
}
}()
}
// ---
scenarios := []tests.ApiScenario{
{
Name: "unauthorized",
Method: http.MethodPost,
URL: "/api/backups/upload",
Body: bodies[0].buffer,
Headers: map[string]string{
"Content-Type": bodies[0].contentType,
},
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
ensureNoBackups(t, app)
},
ExpectedStatus: 401,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "authorized as regular user",
Method: http.MethodPost,
URL: "/api/backups/upload",
Body: bodies[1].buffer,
Headers: map[string]string{
"Content-Type": bodies[1].contentType,
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
},
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
ensureNoBackups(t, app)
},
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "authorized as superuser (missing file)",
Method: http.MethodPost,
URL: "/api/backups/upload",
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
},
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
ensureNoBackups(t, app)
},
ExpectedStatus: 400,
ExpectedContent: []string{`"data":{`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "authorized as superuser (existing backup name)",
Method: http.MethodPost,
URL: "/api/backups/upload",
Body: bodies[3].buffer,
Headers: map[string]string{
"Content-Type": bodies[3].contentType,
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
},
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
fsys, err := app.NewBackupsFilesystem()
if err != nil {
t.Fatal(err)
}
defer fsys.Close()
// create a dummy backup file to simulate existing backups
if err := fsys.Upload([]byte("123"), "test"); err != nil {
t.Fatal(err)
}
},
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
files, _ := getBackupFiles(app)
if total := len(files); total != 1 {
t.Fatalf("Expected %d backup file, got %d", 1, total)
}
},
ExpectedStatus: 400,
ExpectedContent: []string{`"data":{"file":{`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "authorized as superuser (valid file)",
Method: http.MethodPost,
URL: "/api/backups/upload",
Body: bodies[4].buffer,
Headers: map[string]string{
"Content-Type": bodies[4].contentType,
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
},
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
files, _ := getBackupFiles(app)
if total := len(files); total != 1 {
t.Fatalf("Expected %d backup file, got %d", 1, total)
}
},
ExpectedStatus: 204,
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "ensure that the default body limit is skipped",
Method: http.MethodPost,
URL: "/api/backups/upload",
Body: bytes.NewBuffer(make([]byte, apis.DefaultMaxBodySize+100)),
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
},
ExpectedStatus: 400, // it doesn't matter as long as it is not 413
ExpectedContent: []string{`"data":{`},
NotExpectedContent: []string{"entity too large"},
ExpectedEvents: map[string]int{"*": 0},
},
}
for _, scenario := range scenarios {
scenario.Test(t)
}
}
func TestBackupsDownload(t *testing.T) {
t.Parallel()
scenarios := []tests.ApiScenario{
{
Name: "unauthorized",
Method: http.MethodGet,
URL: "/api/backups/test1.zip",
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
if err := createTestBackups(app); err != nil {
t.Fatal(err)
}
},
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "with record auth header",
Method: http.MethodGet,
URL: "/api/backups/test1.zip",
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
},
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
if err := createTestBackups(app); err != nil {
t.Fatal(err)
}
},
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "with superuser auth header",
Method: http.MethodGet,
URL: "/api/backups/test1.zip",
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
},
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
if err := createTestBackups(app); err != nil {
t.Fatal(err)
}
},
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "with empty or invalid token",
Method: http.MethodGet,
URL: "/api/backups/test1.zip?token=",
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
if err := createTestBackups(app); err != nil {
t.Fatal(err)
}
},
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "with valid record auth token",
Method: http.MethodGet,
URL: "/api/backups/test1.zip?token=eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
if err := createTestBackups(app); err != nil {
t.Fatal(err)
}
},
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "with valid record file token",
Method: http.MethodGet,
URL: "/api/backups/test1.zip?token=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImV4cCI6MjUyNDYwNDQ2MSwidHlwZSI6ImZpbGUiLCJjb2xsZWN0aW9uSWQiOiJfcGJfdXNlcnNfYXV0aF8ifQ.nSTLuCPcGpWn2K2l-BFkC3Vlzc-ZTDPByYq8dN1oPSo",
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
if err := createTestBackups(app); err != nil {
t.Fatal(err)
}
},
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "with valid superuser auth token",
Method: http.MethodGet,
URL: "/api/backups/test1.zip?token=eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
if err := createTestBackups(app); err != nil {
t.Fatal(err)
}
},
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "with expired superuser file token",
Method: http.MethodGet,
URL: "/api/backups/test1.zip?token=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsImV4cCI6MTY0MDk5MTY2MSwidHlwZSI6ImZpbGUiLCJjb2xsZWN0aW9uSWQiOiJwYmNfMzE0MjYzNTgyMyJ9.nqqtqpPhxU0045F4XP_ruAkzAidYBc5oPy9ErN3XBq0",
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
if err := createTestBackups(app); err != nil {
t.Fatal(err)
}
},
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "with valid superuser file token but missing backup name",
Method: http.MethodGet,
URL: "/api/backups/missing?token=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsImV4cCI6MjUyNDYwNDQ2MSwidHlwZSI6ImZpbGUiLCJjb2xsZWN0aW9uSWQiOiJwYmNfMzE0MjYzNTgyMyJ9.Lupz541xRvrktwkrl55p5pPCF77T69ZRsohsIcb2dxc",
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
if err := createTestBackups(app); err != nil {
t.Fatal(err)
}
},
ExpectedStatus: 404,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "with valid superuser file token",
Method: http.MethodGet,
URL: "/api/backups/test1.zip?token=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsImV4cCI6MjUyNDYwNDQ2MSwidHlwZSI6ImZpbGUiLCJjb2xsZWN0aW9uSWQiOiJwYmNfMzE0MjYzNTgyMyJ9.Lupz541xRvrktwkrl55p5pPCF77T69ZRsohsIcb2dxc",
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
if err := createTestBackups(app); err != nil {
t.Fatal(err)
}
},
ExpectedStatus: 200,
ExpectedContent: []string{
"storage/",
"data.db",
"auxiliary.db",
},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "with valid superuser file token and backup name with escaped char",
Method: http.MethodGet,
URL: "/api/backups/%40test4.zip?token=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsImV4cCI6MjUyNDYwNDQ2MSwidHlwZSI6ImZpbGUiLCJjb2xsZWN0aW9uSWQiOiJwYmNfMzE0MjYzNTgyMyJ9.Lupz541xRvrktwkrl55p5pPCF77T69ZRsohsIcb2dxc",
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
if err := createTestBackups(app); err != nil {
t.Fatal(err)
}
},
ExpectedStatus: 200,
ExpectedContent: []string{
"storage/",
"data.db",
"auxiliary.db",
},
ExpectedEvents: map[string]int{"*": 0},
},
}
for _, scenario := range scenarios {
scenario.Test(t)
}
}
func TestBackupsDelete(t *testing.T) {
t.Parallel()
noTestBackupFilesChanges := func(t testing.TB, app *tests.TestApp) {
files, err := getBackupFiles(app)
if err != nil {
t.Fatal(err)
}
expected := 4
if total := len(files); total != expected {
t.Fatalf("Expected %d backup(s), got %d", expected, total)
}
}
scenarios := []tests.ApiScenario{
{
Name: "unauthorized",
Method: http.MethodDelete,
URL: "/api/backups/test1.zip",
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
if err := createTestBackups(app); err != nil {
t.Fatal(err)
}
},
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
noTestBackupFilesChanges(t, app)
},
ExpectedStatus: 401,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "authorized as regular user",
Method: http.MethodDelete,
URL: "/api/backups/test1.zip",
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
},
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
if err := createTestBackups(app); err != nil {
t.Fatal(err)
}
},
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
noTestBackupFilesChanges(t, app)
},
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "authorized as superuser (missing file)",
Method: http.MethodDelete,
URL: "/api/backups/missing.zip",
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
},
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
if err := createTestBackups(app); err != nil {
t.Fatal(err)
}
},
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
noTestBackupFilesChanges(t, app)
},
ExpectedStatus: 400,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "authorized as superuser (existing file with matching active backup)",
Method: http.MethodDelete,
URL: "/api/backups/test1.zip",
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
},
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
if err := createTestBackups(app); err != nil {
t.Fatal(err)
}
// mock active backup with the same name to delete
app.Store().Set(core.StoreKeyActiveBackup, "test1.zip")
},
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
noTestBackupFilesChanges(t, app)
},
ExpectedStatus: 400,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "authorized as superuser (existing file and no matching active backup)",
Method: http.MethodDelete,
URL: "/api/backups/test1.zip",
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
},
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
if err := createTestBackups(app); err != nil {
t.Fatal(err)
}
// mock active backup with different name
app.Store().Set(core.StoreKeyActiveBackup, "new.zip")
},
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
files, err := getBackupFiles(app)
if err != nil {
t.Fatal(err)
}
if total := len(files); total != 3 {
t.Fatalf("Expected %d backup files, got %d", 3, total)
}
deletedFile := "test1.zip"
for _, f := range files {
if f.Key == deletedFile {
t.Fatalf("Expected backup %q to be deleted", deletedFile)
}
}
},
ExpectedStatus: 204,
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "authorized as superuser (backup with escaped character)",
Method: http.MethodDelete,
URL: "/api/backups/%40test4.zip",
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
},
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
if err := createTestBackups(app); err != nil {
t.Fatal(err)
}
},
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
files, err := getBackupFiles(app)
if err != nil {
t.Fatal(err)
}
if total := len(files); total != 3 {
t.Fatalf("Expected %d backup files, got %d", 3, total)
}
deletedFile := "@test4.zip"
for _, f := range files {
if f.Key == deletedFile {
t.Fatalf("Expected backup %q to be deleted", deletedFile)
}
}
},
ExpectedStatus: 204,
ExpectedEvents: map[string]int{"*": 0},
},
}
for _, scenario := range scenarios {
scenario.Test(t)
}
}
func TestBackupsRestore(t *testing.T) {
t.Parallel()
scenarios := []tests.ApiScenario{
{
Name: "unauthorized",
Method: http.MethodPost,
URL: "/api/backups/test1.zip/restore",
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
if err := createTestBackups(app); err != nil {
t.Fatal(err)
}
},
ExpectedStatus: 401,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "authorized as regular user",
Method: http.MethodPost,
URL: "/api/backups/test1.zip/restore",
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
},
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
if err := createTestBackups(app); err != nil {
t.Fatal(err)
}
},
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "authorized as superuser (missing file)",
Method: http.MethodPost,
URL: "/api/backups/missing.zip/restore",
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
},
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
if err := createTestBackups(app); err != nil {
t.Fatal(err)
}
},
ExpectedStatus: 400,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "authorized as superuser (active backup process)",
Method: http.MethodPost,
URL: "/api/backups/test1.zip/restore",
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
},
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
if err := createTestBackups(app); err != nil {
t.Fatal(err)
}
app.Store().Set(core.StoreKeyActiveBackup, "")
},
ExpectedStatus: 400,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
}
for _, scenario := range scenarios {
scenario.Test(t)
}
}
// -------------------------------------------------------------------
func createTestBackups(app core.App) error {
ctx := context.Background()
if err := app.CreateBackup(ctx, "test1.zip"); err != nil {
return err
}
if err := app.CreateBackup(ctx, "test2.zip"); err != nil {
return err
}
if err := app.CreateBackup(ctx, "test3.zip"); err != nil {
return err
}
if err := app.CreateBackup(ctx, "@test4.zip"); err != nil {
return err
}
return nil
}
func getBackupFiles(app core.App) ([]*blob.ListObject, error) {
fsys, err := app.NewBackupsFilesystem()
if err != nil {
return nil, err
}
defer fsys.Close()
return fsys.List("")
}
func ensureNoBackups(t testing.TB, app *tests.TestApp) {
files, err := getBackupFiles(app)
if err != nil {
t.Fatal(err)
}
if total := len(files); total != 0 {
t.Fatalf("Expected 0 backup files, got %d", total)
}
}

72
apis/backup_upload.go Normal file
View file

@ -0,0 +1,72 @@
package apis
import (
"net/http"
validation "github.com/go-ozzo/ozzo-validation/v4"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/core/validators"
"github.com/pocketbase/pocketbase/tools/filesystem"
)
func backupUpload(e *core.RequestEvent) error {
fsys, err := e.App.NewBackupsFilesystem()
if err != nil {
return err
}
defer fsys.Close()
form := new(backupUploadForm)
form.fsys = fsys
files, _ := e.FindUploadedFiles("file")
if len(files) > 0 {
form.File = files[0]
}
err = form.validate()
if err != nil {
return e.BadRequestError("An error occurred while validating the submitted data.", err)
}
err = fsys.UploadFile(form.File, form.File.OriginalName)
if err != nil {
return e.BadRequestError("Failed to upload backup.", err)
}
// we don't retrieve the generated backup file because it may not be
// available yet due to the eventually consistent nature of some S3 providers
return e.NoContent(http.StatusNoContent)
}
// -------------------------------------------------------------------
type backupUploadForm struct {
fsys *filesystem.System
File *filesystem.File `json:"file"`
}
func (form *backupUploadForm) validate() error {
return validation.ValidateStruct(form,
validation.Field(
&form.File,
validation.Required,
validation.By(validators.UploadedFileMimeType([]string{"application/zip"})),
validation.By(form.checkUniqueName),
),
)
}
func (form *backupUploadForm) checkUniqueName(value any) error {
v, _ := value.(*filesystem.File)
if v == nil {
return nil // nothing to check
}
// note: we use the original name because that is what we upload
if exists, err := form.fsys.Exists(v.OriginalName); err != nil || exists {
return validation.NewError("validation_backup_name_exists", "Backup file with the specified name already exists.")
}
return nil
}

174
apis/base.go Normal file
View file

@ -0,0 +1,174 @@
package apis
import (
"errors"
"fmt"
"io/fs"
"net/http"
"path/filepath"
"strings"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tools/router"
)
// StaticWildcardParam is the name of Static handler wildcard parameter.
const StaticWildcardParam = "path"
// NewRouter returns a new router instance loaded with the default app middlewares and api routes.
func NewRouter(app core.App) (*router.Router[*core.RequestEvent], error) {
pbRouter := router.NewRouter(func(w http.ResponseWriter, r *http.Request) (*core.RequestEvent, router.EventCleanupFunc) {
event := new(core.RequestEvent)
event.Response = w
event.Request = r
event.App = app
return event, nil
})
// register default middlewares
pbRouter.Bind(activityLogger())
pbRouter.Bind(panicRecover())
pbRouter.Bind(rateLimit())
pbRouter.Bind(loadAuthToken())
pbRouter.Bind(securityHeaders())
pbRouter.Bind(BodyLimit(DefaultMaxBodySize))
apiGroup := pbRouter.Group("/api")
bindSettingsApi(app, apiGroup)
bindCollectionApi(app, apiGroup)
bindRecordCrudApi(app, apiGroup)
bindRecordAuthApi(app, apiGroup)
bindLogsApi(app, apiGroup)
bindBackupApi(app, apiGroup)
bindCronApi(app, apiGroup)
bindFileApi(app, apiGroup)
bindBatchApi(app, apiGroup)
bindRealtimeApi(app, apiGroup)
bindHealthApi(app, apiGroup)
return pbRouter, nil
}
// WrapStdHandler wraps Go [http.Handler] into a PocketBase handler func.
func WrapStdHandler(h http.Handler) func(*core.RequestEvent) error {
return func(e *core.RequestEvent) error {
h.ServeHTTP(e.Response, e.Request)
return nil
}
}
// WrapStdMiddleware wraps Go [func(http.Handler) http.Handle] into a PocketBase middleware func.
func WrapStdMiddleware(m func(http.Handler) http.Handler) func(*core.RequestEvent) error {
return func(e *core.RequestEvent) (err error) {
m(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
e.Response = w
e.Request = r
err = e.Next()
})).ServeHTTP(e.Response, e.Request)
return err
}
}
// MustSubFS returns an [fs.FS] corresponding to the subtree rooted at fsys's dir.
//
// This is similar to [fs.Sub] but panics on failure.
func MustSubFS(fsys fs.FS, dir string) fs.FS {
dir = filepath.ToSlash(filepath.Clean(dir)) // ToSlash in case of Windows path
sub, err := fs.Sub(fsys, dir)
if err != nil {
panic(fmt.Errorf("failed to create sub FS: %w", err))
}
return sub
}
// Static is a handler function to serve static directory content from fsys.
//
// If a file resource is missing and indexFallback is set, the request
// will be forwarded to the base index.html (useful for SPA with pretty urls).
//
// NB! Expects the route to have a "{path...}" wildcard parameter.
//
// Special redirects:
// - if "path" is a file that ends in index.html, it is redirected to its non-index.html version (eg. /test/index.html -> /test/)
// - if "path" is a directory that has index.html, the index.html file is rendered,
// otherwise if missing - returns 404 or fallback to the root index.html if indexFallback is set
//
// Example:
//
// fsys := os.DirFS("./pb_public")
// router.GET("/files/{path...}", apis.Static(fsys, false))
func Static(fsys fs.FS, indexFallback bool) func(*core.RequestEvent) error {
if fsys == nil {
panic("Static: the provided fs.FS argument is nil")
}
return func(e *core.RequestEvent) error {
// disable the activity logger to avoid flooding with messages
//
// note: errors are still logged
if e.Get(requestEventKeySkipSuccessActivityLog) == nil {
e.Set(requestEventKeySkipSuccessActivityLog, true)
}
filename := e.Request.PathValue(StaticWildcardParam)
filename = filepath.ToSlash(filepath.Clean(strings.TrimPrefix(filename, "/")))
// eagerly check for directory traversal
//
// note: this is just out of an abundance of caution because the fs.FS implementation could be non-std,
// but usually shouldn't be necessary since os.DirFS.Open is expected to fail if the filename starts with dots
if len(filename) > 2 && filename[0] == '.' && filename[1] == '.' && (filename[2] == '/' || filename[2] == '\\') {
if indexFallback && filename != router.IndexPage {
return e.FileFS(fsys, router.IndexPage)
}
return router.ErrFileNotFound
}
fi, err := fs.Stat(fsys, filename)
if err != nil {
if indexFallback && filename != router.IndexPage {
return e.FileFS(fsys, router.IndexPage)
}
return router.ErrFileNotFound
}
if fi.IsDir() {
// redirect to a canonical dir url, aka. with trailing slash
if !strings.HasSuffix(e.Request.URL.Path, "/") {
return e.Redirect(http.StatusMovedPermanently, safeRedirectPath(e.Request.URL.Path+"/"))
}
} else {
urlPath := e.Request.URL.Path
if strings.HasSuffix(urlPath, "/") {
// redirect to a non-trailing slash file route
urlPath = strings.TrimRight(urlPath, "/")
if len(urlPath) > 0 {
return e.Redirect(http.StatusMovedPermanently, safeRedirectPath(urlPath))
}
} else if stripped, ok := strings.CutSuffix(urlPath, router.IndexPage); ok {
// redirect without the index.html
return e.Redirect(http.StatusMovedPermanently, safeRedirectPath(stripped))
}
}
fileErr := e.FileFS(fsys, filename)
if fileErr != nil && indexFallback && filename != router.IndexPage && errors.Is(fileErr, router.ErrFileNotFound) {
return e.FileFS(fsys, router.IndexPage)
}
return fileErr
}
}
// safeRedirectPath normalizes the path string by replacing all beginning slashes
// (`\\`, `//`, `\/`) with a single forward slash to prevent open redirect attacks
func safeRedirectPath(path string) string {
if len(path) > 1 && (path[0] == '\\' || path[0] == '/') && (path[1] == '\\' || path[1] == '/') {
path = "/" + strings.TrimLeft(path, `/\`)
}
return path
}

313
apis/base_test.go Normal file
View file

@ -0,0 +1,313 @@
package apis_test
import (
"fmt"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"testing"
"github.com/pocketbase/pocketbase/apis"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tests"
"github.com/pocketbase/pocketbase/tools/router"
)
func TestWrapStdHandler(t *testing.T) {
t.Parallel()
app, _ := tests.NewTestApp()
defer app.Cleanup()
req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
e := new(core.RequestEvent)
e.App = app
e.Request = req
e.Response = rec
err := apis.WrapStdHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("test"))
}))(e)
if err != nil {
t.Fatal(err)
}
if body := rec.Body.String(); body != "test" {
t.Fatalf("Expected body %q, got %q", "test", body)
}
}
func TestWrapStdMiddleware(t *testing.T) {
t.Parallel()
app, _ := tests.NewTestApp()
defer app.Cleanup()
req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
e := new(core.RequestEvent)
e.App = app
e.Request = req
e.Response = rec
err := apis.WrapStdMiddleware(func(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("test"))
})
})(e)
if err != nil {
t.Fatal(err)
}
if body := rec.Body.String(); body != "test" {
t.Fatalf("Expected body %q, got %q", "test", body)
}
}
func TestStatic(t *testing.T) {
t.Parallel()
app, _ := tests.NewTestApp()
defer app.Cleanup()
dir := createTestDir(t)
defer os.RemoveAll(dir)
fsys := os.DirFS(filepath.Join(dir, "sub"))
type staticScenario struct {
path string
indexFallback bool
expectedStatus int
expectBody string
expectError bool
}
scenarios := []staticScenario{
{
path: "",
indexFallback: false,
expectedStatus: 200,
expectBody: "sub index.html",
expectError: false,
},
{
path: "missing/a/b/c",
indexFallback: false,
expectedStatus: 404,
expectBody: "",
expectError: true,
},
{
path: "missing/a/b/c",
indexFallback: true,
expectedStatus: 200,
expectBody: "sub index.html",
expectError: false,
},
{
path: "testroot", // parent directory file
indexFallback: false,
expectedStatus: 404,
expectBody: "",
expectError: true,
},
{
path: "test",
indexFallback: false,
expectedStatus: 200,
expectBody: "sub test",
expectError: false,
},
{
path: "sub2",
indexFallback: false,
expectedStatus: 301,
expectBody: "",
expectError: false,
},
{
path: "sub2/",
indexFallback: false,
expectedStatus: 200,
expectBody: "sub2 index.html",
expectError: false,
},
{
path: "sub2/test",
indexFallback: false,
expectedStatus: 200,
expectBody: "sub2 test",
expectError: false,
},
{
path: "sub2/test/",
indexFallback: false,
expectedStatus: 301,
expectBody: "",
expectError: false,
},
}
// extra directory traversal checks
dtp := []string{
"/../",
"\\../",
"../",
"../../",
"..\\",
"..\\..\\",
"../..\\",
"..\\..//",
`%2e%2e%2f`,
`%2e%2e%2f%2e%2e%2f`,
`%2e%2e/`,
`%2e%2e/%2e%2e/`,
`..%2f`,
`..%2f..%2f`,
`%2e%2e%5c`,
`%2e%2e%5c%2e%2e%5c`,
`%2e%2e\`,
`%2e%2e\%2e%2e\`,
`..%5c`,
`..%5c..%5c`,
`%252e%252e%255c`,
`%252e%252e%255c%252e%252e%255c`,
`..%255c`,
`..%255c..%255c`,
}
for _, p := range dtp {
scenarios = append(scenarios,
staticScenario{
path: p + "testroot",
indexFallback: false,
expectedStatus: 404,
expectBody: "",
expectError: true,
},
staticScenario{
path: p + "testroot",
indexFallback: true,
expectedStatus: 200,
expectBody: "sub index.html",
expectError: false,
},
)
}
for i, s := range scenarios {
t.Run(fmt.Sprintf("%d_%s_%v", i, s.path, s.indexFallback), func(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "/"+s.path, nil)
req.SetPathValue(apis.StaticWildcardParam, s.path)
rec := httptest.NewRecorder()
e := new(core.RequestEvent)
e.App = app
e.Request = req
e.Response = rec
err := apis.Static(fsys, s.indexFallback)(e)
hasErr := err != nil
if hasErr != s.expectError {
t.Fatalf("Expected hasErr %v, got %v (%v)", s.expectError, hasErr, err)
}
body := rec.Body.String()
if body != s.expectBody {
t.Fatalf("Expected body %q, got %q", s.expectBody, body)
}
if hasErr {
apiErr := router.ToApiError(err)
if apiErr.Status != s.expectedStatus {
t.Fatalf("Expected status code %d, got %d", s.expectedStatus, apiErr.Status)
}
}
})
}
}
func TestMustSubFS(t *testing.T) {
t.Parallel()
dir := createTestDir(t)
defer os.RemoveAll(dir)
// invalid path (no beginning and ending slashes)
if !hasPanicked(func() {
apis.MustSubFS(os.DirFS(dir), "/test/")
}) {
t.Fatalf("Expected to panic")
}
// valid path
if hasPanicked(func() {
apis.MustSubFS(os.DirFS(dir), "./////a/b/c") // checks if ToSlash was called
}) {
t.Fatalf("Didn't expect to panic")
}
// check sub content
sub := apis.MustSubFS(os.DirFS(dir), "sub")
_, err := sub.Open("test")
if err != nil {
t.Fatalf("Missing expected file sub/test")
}
}
// -------------------------------------------------------------------
func hasPanicked(f func()) (didPanic bool) {
defer func() {
if r := recover(); r != nil {
didPanic = true
}
}()
f()
return
}
// note: make sure to call os.RemoveAll(dir) after you are done
// working with the created test dir.
func createTestDir(t *testing.T) string {
dir, err := os.MkdirTemp(os.TempDir(), "test_dir")
if err != nil {
t.Fatal(err)
}
if err := os.WriteFile(filepath.Join(dir, "index.html"), []byte("root index.html"), 0644); err != nil {
t.Fatal(err)
}
if err := os.WriteFile(filepath.Join(dir, "testroot"), []byte("root test"), 0644); err != nil {
t.Fatal(err)
}
if err := os.MkdirAll(filepath.Join(dir, "sub"), os.ModePerm); err != nil {
t.Fatal(err)
}
if err := os.WriteFile(filepath.Join(dir, "sub/index.html"), []byte("sub index.html"), 0644); err != nil {
t.Fatal(err)
}
if err := os.WriteFile(filepath.Join(dir, "sub/test"), []byte("sub test"), 0644); err != nil {
t.Fatal(err)
}
if err := os.MkdirAll(filepath.Join(dir, "sub", "sub2"), os.ModePerm); err != nil {
t.Fatal(err)
}
if err := os.WriteFile(filepath.Join(dir, "sub/sub2/index.html"), []byte("sub2 index.html"), 0644); err != nil {
t.Fatal(err)
}
if err := os.WriteFile(filepath.Join(dir, "sub/sub2/test"), []byte("sub2 test"), 0644); err != nil {
t.Fatal(err)
}
return dir
}

548
apis/batch.go Normal file
View file

@ -0,0 +1,548 @@
package apis
import (
"bytes"
"encoding/json"
"errors"
"io"
"mime/multipart"
"net/http"
"net/http/httptest"
"regexp"
"slices"
"strconv"
"strings"
"time"
validation "github.com/go-ozzo/ozzo-validation/v4"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tools/filesystem"
"github.com/pocketbase/pocketbase/tools/router"
"github.com/pocketbase/pocketbase/tools/types"
"github.com/spf13/cast"
)
func bindBatchApi(app core.App, rg *router.RouterGroup[*core.RequestEvent]) {
sub := rg.Group("/batch")
sub.POST("", batchTransaction).Unbind(DefaultBodyLimitMiddlewareId) // the body limit is inlined
}
type HandleFunc func(e *core.RequestEvent) error
type BatchActionHandlerFunc func(app core.App, ir *core.InternalRequest, params map[string]string, next func(data any) error) HandleFunc
// ValidBatchActions defines a map with the supported batch InternalRequest actions.
//
// Note: when adding new routes make sure that their middlewares are inlined!
var ValidBatchActions = map[*regexp.Regexp]BatchActionHandlerFunc{
// "upsert" handler
regexp.MustCompile(`^PUT /api/collections/(?P<collection>[^\/\?]+)/records(?P<query>\?.*)?$`): func(app core.App, ir *core.InternalRequest, params map[string]string, next func(any) error) HandleFunc {
var id string
if len(ir.Body) > 0 && ir.Body["id"] != "" {
id = cast.ToString(ir.Body["id"])
}
if id != "" {
_, err := app.FindRecordById(params["collection"], id)
if err == nil {
// update
// ---
params["id"] = id // required for the path value
ir.Method = "PATCH"
ir.URL = "/api/collections/" + params["collection"] + "/records/" + id + params["query"]
return recordUpdate(false, next)
}
}
// create
// ---
ir.Method = "POST"
ir.URL = "/api/collections/" + params["collection"] + "/records" + params["query"]
return recordCreate(false, next)
},
regexp.MustCompile(`^POST /api/collections/(?P<collection>[^\/\?]+)/records(\?.*)?$`): func(app core.App, ir *core.InternalRequest, params map[string]string, next func(any) error) HandleFunc {
return recordCreate(false, next)
},
regexp.MustCompile(`^PATCH /api/collections/(?P<collection>[^\/\?]+)/records/(?P<id>[^\/\?]+)(\?.*)?$`): func(app core.App, ir *core.InternalRequest, params map[string]string, next func(any) error) HandleFunc {
return recordUpdate(false, next)
},
regexp.MustCompile(`^DELETE /api/collections/(?P<collection>[^\/\?]+)/records/(?P<id>[^\/\?]+)(\?.*)?$`): func(app core.App, ir *core.InternalRequest, params map[string]string, next func(any) error) HandleFunc {
return recordDelete(false, next)
},
}
type BatchRequestResult struct {
Body any `json:"body"`
Status int `json:"status"`
}
type batchRequestsForm struct {
Requests []*core.InternalRequest `form:"requests" json:"requests"`
max int
}
func (brs batchRequestsForm) validate() error {
return validation.ValidateStruct(&brs,
validation.Field(&brs.Requests, validation.Required, validation.Length(0, brs.max)),
)
}
// NB! When the request is submitted as multipart/form-data,
// the regular fields data is expected to be submitted as serailized
// json under the @jsonPayload field and file keys need to follow the
// pattern "requests.N.fileField" or requests[N].fileField.
func batchTransaction(e *core.RequestEvent) error {
maxRequests := e.App.Settings().Batch.MaxRequests
if !e.App.Settings().Batch.Enabled || maxRequests <= 0 {
return e.ForbiddenError("Batch requests are not allowed.", nil)
}
txTimeout := time.Duration(e.App.Settings().Batch.Timeout) * time.Second
if txTimeout <= 0 {
txTimeout = 3 * time.Second // for now always limit
}
maxBodySize := e.App.Settings().Batch.MaxBodySize
if maxBodySize <= 0 {
maxBodySize = 128 << 20
}
err := applyBodyLimit(e, maxBodySize)
if err != nil {
return err
}
form := &batchRequestsForm{max: maxRequests}
// load base requests data
err = e.BindBody(form)
if err != nil {
return e.BadRequestError("Failed to read the submitted batch data.", err)
}
// load uploaded files into each request item
// note: expects the files to be under "requests.N.fileField" or "requests[N].fileField" format
// (the other regular fields must be put under `@jsonPayload` as serialized json)
if strings.HasPrefix(e.Request.Header.Get("Content-Type"), "multipart/form-data") {
for i, ir := range form.Requests {
iStr := strconv.Itoa(i)
files, err := extractPrefixedFiles(e.Request, "requests."+iStr+".", "requests["+iStr+"].")
if err != nil {
return e.BadRequestError("Failed to read the submitted batch files data.", err)
}
for key, files := range files {
if ir.Body == nil {
ir.Body = map[string]any{}
}
ir.Body[key] = files
}
}
}
// validate batch request form
err = form.validate()
if err != nil {
return e.BadRequestError("Invalid batch request data.", err)
}
event := new(core.BatchRequestEvent)
event.RequestEvent = e
event.Batch = form.Requests
return e.App.OnBatchRequest().Trigger(event, func(e *core.BatchRequestEvent) error {
bp := batchProcessor{
app: e.App,
baseEvent: e.RequestEvent,
infoContext: core.RequestInfoContextBatch,
}
if err := bp.Process(e.Batch, txTimeout); err != nil {
return firstApiError(err, e.BadRequestError("Batch transaction failed.", err))
}
return e.JSON(http.StatusOK, bp.results)
})
}
type batchProcessor struct {
app core.App
baseEvent *core.RequestEvent
infoContext string
results []*BatchRequestResult
failedIndex int
errCh chan error
stopCh chan struct{}
}
func (p *batchProcessor) Process(batch []*core.InternalRequest, timeout time.Duration) error {
p.results = make([]*BatchRequestResult, 0, len(batch))
if p.stopCh != nil {
close(p.stopCh)
}
p.stopCh = make(chan struct{}, 1)
if p.errCh != nil {
close(p.errCh)
}
p.errCh = make(chan error, 1)
return p.app.RunInTransaction(func(txApp core.App) error {
// used to interupts the recursive processing calls in case of a timeout or connection close
defer func() {
p.stopCh <- struct{}{}
}()
go func() {
err := p.process(txApp, batch, 0)
if err != nil {
err = validation.Errors{
"requests": validation.Errors{
strconv.Itoa(p.failedIndex): &BatchResponseError{
code: "batch_request_failed",
message: "Batch request failed.",
err: router.ToApiError(err),
},
},
}
}
// note: to avoid copying and due to the process recursion the final results order is reversed
if err == nil {
slices.Reverse(p.results)
}
p.errCh <- err
}()
select {
case responseErr := <-p.errCh:
return responseErr
case <-time.After(timeout):
// note: we don't return 408 Reques Timeout error because
// some browsers perform automatic retry behind the scenes
// which are hard to debug and unnecessary
return errors.New("batch transaction timeout")
case <-p.baseEvent.Request.Context().Done():
return errors.New("batch request interrupted")
}
})
}
func (p *batchProcessor) process(activeApp core.App, batch []*core.InternalRequest, i int) error {
select {
case <-p.stopCh:
return nil
default:
if len(batch) == 0 {
return nil
}
result, err := processInternalRequest(
activeApp,
p.baseEvent,
batch[0],
p.infoContext,
func(_ any) error {
if len(batch) == 1 {
return nil
}
err := p.process(activeApp, batch[1:], i+1)
// update the failed batch index (if not already)
if err != nil && p.failedIndex == 0 {
p.failedIndex = i + 1
}
return err
},
)
if err != nil {
return err
}
p.results = append(p.results, result)
return nil
}
}
func processInternalRequest(
activeApp core.App,
baseEvent *core.RequestEvent,
ir *core.InternalRequest,
infoContext string,
optNext func(data any) error,
) (*BatchRequestResult, error) {
handle, params, ok := prepareInternalAction(activeApp, ir, optNext)
if !ok {
return nil, errors.New("unknown batch request action")
}
// construct a new http.Request
// ---------------------------------------------------------------
buf, mw, err := multipartDataFromInternalRequest(ir)
if err != nil {
return nil, err
}
r, err := http.NewRequest(strings.ToUpper(ir.Method), ir.URL, buf)
if err != nil {
return nil, err
}
// cleanup multipart temp files
defer func() {
if r.MultipartForm != nil {
if err := r.MultipartForm.RemoveAll(); err != nil {
activeApp.Logger().Warn("failed to cleanup temp batch files", "error", err)
}
}
}()
// load batch request path params
// ---
for k, v := range params {
r.SetPathValue(k, v)
}
// clone original request
// ---
r.RequestURI = r.URL.RequestURI()
r.Proto = baseEvent.Request.Proto
r.ProtoMajor = baseEvent.Request.ProtoMajor
r.ProtoMinor = baseEvent.Request.ProtoMinor
r.Host = baseEvent.Request.Host
r.RemoteAddr = baseEvent.Request.RemoteAddr
r.TLS = baseEvent.Request.TLS
if s := baseEvent.Request.TransferEncoding; s != nil {
s2 := make([]string, len(s))
copy(s2, s)
r.TransferEncoding = s2
}
if baseEvent.Request.Trailer != nil {
r.Trailer = baseEvent.Request.Trailer.Clone()
}
if baseEvent.Request.Header != nil {
r.Header = baseEvent.Request.Header.Clone()
}
// apply batch request specific headers
// ---
for k, v := range ir.Headers {
// individual Authorization header keys don't have affect
// because the auth state is populated from the base event
if strings.EqualFold(k, "authorization") {
continue
}
r.Header.Set(k, v)
}
r.Header.Set("Content-Type", mw.FormDataContentType())
// construct a new RequestEvent
// ---------------------------------------------------------------
event := &core.RequestEvent{}
event.App = activeApp
event.Auth = baseEvent.Auth
event.SetAll(baseEvent.GetAll())
// load RequestInfo context
if infoContext == "" {
infoContext = core.RequestInfoContextDefault
}
event.Set(core.RequestEventKeyInfoContext, infoContext)
// assign request
event.Request = r
event.Request.Body = &router.RereadableReadCloser{ReadCloser: r.Body} // enables multiple reads
// assign response
rec := httptest.NewRecorder()
event.Response = &router.ResponseWriter{ResponseWriter: rec} // enables status and write tracking
// execute
// ---------------------------------------------------------------
if err := handle(event); err != nil {
return nil, err
}
result := rec.Result()
defer result.Body.Close()
body, _ := types.ParseJSONRaw(rec.Body.Bytes())
return &BatchRequestResult{
Status: result.StatusCode,
Body: body,
}, nil
}
func multipartDataFromInternalRequest(ir *core.InternalRequest) (*bytes.Buffer, *multipart.Writer, error) {
buf := &bytes.Buffer{}
mw := multipart.NewWriter(buf)
regularFields := map[string]any{}
fileFields := map[string][]*filesystem.File{}
// separate regular fields from files
// ---
for k, rawV := range ir.Body {
switch v := rawV.(type) {
case *filesystem.File:
fileFields[k] = append(fileFields[k], v)
case []*filesystem.File:
fileFields[k] = append(fileFields[k], v...)
default:
regularFields[k] = v
}
}
// submit regularFields as @jsonPayload
// ---
rawBody, err := json.Marshal(regularFields)
if err != nil {
return nil, nil, errors.Join(err, mw.Close())
}
jsonPayload, err := mw.CreateFormField("@jsonPayload")
if err != nil {
return nil, nil, errors.Join(err, mw.Close())
}
_, err = jsonPayload.Write(rawBody)
if err != nil {
return nil, nil, errors.Join(err, mw.Close())
}
// submit fileFields as multipart files
// ---
for key, files := range fileFields {
for _, file := range files {
part, err := mw.CreateFormFile(key, file.Name)
if err != nil {
return nil, nil, errors.Join(err, mw.Close())
}
fr, err := file.Reader.Open()
if err != nil {
return nil, nil, errors.Join(err, mw.Close())
}
_, err = io.Copy(part, fr)
if err != nil {
return nil, nil, errors.Join(err, fr.Close(), mw.Close())
}
err = fr.Close()
if err != nil {
return nil, nil, errors.Join(err, mw.Close())
}
}
}
return buf, mw, mw.Close()
}
func extractPrefixedFiles(request *http.Request, prefixes ...string) (map[string][]*filesystem.File, error) {
if request.MultipartForm == nil {
if err := request.ParseMultipartForm(router.DefaultMaxMemory); err != nil {
return nil, err
}
}
result := make(map[string][]*filesystem.File)
for k, fhs := range request.MultipartForm.File {
for _, p := range prefixes {
if strings.HasPrefix(k, p) {
resultKey := strings.TrimPrefix(k, p)
for _, fh := range fhs {
file, err := filesystem.NewFileFromMultipart(fh)
if err != nil {
return nil, err
}
result[resultKey] = append(result[resultKey], file)
}
}
}
}
return result, nil
}
func prepareInternalAction(activeApp core.App, ir *core.InternalRequest, optNext func(data any) error) (HandleFunc, map[string]string, bool) {
full := strings.ToUpper(ir.Method) + " " + ir.URL
for re, actionFactory := range ValidBatchActions {
params, ok := findNamedMatches(re, full)
if ok {
return actionFactory(activeApp, ir, params, optNext), params, true
}
}
return nil, nil, false
}
func findNamedMatches(re *regexp.Regexp, str string) (map[string]string, bool) {
match := re.FindStringSubmatch(str)
if match == nil {
return nil, false
}
result := map[string]string{}
names := re.SubexpNames()
for i, m := range match {
if names[i] != "" {
result[names[i]] = m
}
}
return result, true
}
// -------------------------------------------------------------------
var (
_ router.SafeErrorItem = (*BatchResponseError)(nil)
_ router.SafeErrorResolver = (*BatchResponseError)(nil)
)
type BatchResponseError struct {
err *router.ApiError
code string
message string
}
func (e *BatchResponseError) Error() string {
return e.message
}
func (e *BatchResponseError) Code() string {
return e.code
}
func (e *BatchResponseError) Resolve(errData map[string]any) any {
errData["response"] = e.err
return errData
}
func (e BatchResponseError) MarshalJSON() ([]byte, error) {
return json.Marshal(map[string]any{
"message": e.message,
"code": e.code,
"response": e.err,
})
}

691
apis/batch_test.go Normal file
View file

@ -0,0 +1,691 @@
package apis_test
import (
"net/http"
"strings"
"testing"
"time"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tests"
"github.com/pocketbase/pocketbase/tools/router"
)
func TestBatchRequest(t *testing.T) {
t.Parallel()
formData, mp, err := tests.MockMultipartData(
map[string]string{
router.JSONPayloadKey: `{
"requests":[
{"method":"POST", "url":"/api/collections/demo3/records", "body": {"title": "batch1"}},
{"method":"POST", "url":"/api/collections/demo3/records", "body": {"title": "batch2"}},
{"method":"POST", "url":"/api/collections/demo3/records", "body": {"title": "batch3"}},
{"method":"PATCH", "url":"/api/collections/demo3/records/lcl9d87w22ml6jy", "body": {"files-": "test_FLurQTgrY8.txt"}}
]
}`,
},
"requests.0.files",
"requests.0.files",
"requests.0.files",
"requests[2].files",
)
if err != nil {
t.Fatal(err)
}
scenarios := []tests.ApiScenario{
{
Name: "disabled batch requets",
Method: http.MethodPost,
URL: "/api/batch",
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
app.Settings().Batch.Enabled = false
},
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "max request limits reached",
Method: http.MethodPost,
URL: "/api/batch",
Body: strings.NewReader(`{
"requests": [
{"method":"GET", "url":"/test1"},
{"method":"GET", "url":"/test2"},
{"method":"GET", "url":"/test3"}
]
}`),
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
app.Settings().Batch.Enabled = true
app.Settings().Batch.MaxRequests = 2
},
ExpectedStatus: 400,
ExpectedContent: []string{
`"data":{`,
`"requests":{"code":"validation_length_too_long"`,
},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "trigger requests validations",
Method: http.MethodPost,
URL: "/api/batch",
Body: strings.NewReader(`{
"requests": [
{},
{"method":"GET", "url":"/valid"},
{"method":"invalid", "url":"/valid"},
{"method":"POST", "url":"` + strings.Repeat("a", 2001) + `"}
]
}`),
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
app.Settings().Batch.Enabled = true
app.Settings().Batch.MaxRequests = 100
},
ExpectedStatus: 400,
ExpectedContent: []string{
`"data":{`,
`"requests":{`,
`"0":{"method":{"code":"validation_required"`,
`"2":{"method":{"code":"validation_in_invalid"`,
`"3":{"url":{"code":"validation_length_too_long"`,
},
NotExpectedContent: []string{
`"1":`,
},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "unknown batch request action",
Method: http.MethodPost,
URL: "/api/batch",
Body: strings.NewReader(`{
"requests": [
{"method":"GET", "url":"/api/health"}
]
}`),
ExpectedStatus: 400,
ExpectedContent: []string{
`"data":{`,
`"requests":{`,
`0":{"code":"batch_request_failed"`,
`"response":{`,
},
ExpectedEvents: map[string]int{
"*": 0,
"OnBatchRequest": 1,
},
},
{
Name: "base 2 successful and 1 failed (public collection)",
Method: http.MethodPost,
URL: "/api/batch",
Body: strings.NewReader(`{
"requests": [
{"method":"POST", "url":"/api/collections/demo2/records", "body": {"title": "batch1"}},
{"method":"POST", "url":"/api/collections/demo2/records", "body": {"title": "batch2"}},
{"method":"POST", "url":"/api/collections/demo2/records", "body": {"title": ""}}
]
}`),
ExpectedStatus: 400,
ExpectedContent: []string{
`"data":{`,
`"response":{`,
`"2":{"code":"batch_request_failed"`,
`"response":{"data":{"title":{"code":"validation_required"`,
},
NotExpectedContent: []string{
`"0":`,
`"1":`,
},
ExpectedEvents: map[string]int{
"*": 0,
"OnBatchRequest": 1,
"OnRecordCreateRequest": 3,
"OnModelCreate": 3,
"OnModelCreateExecute": 2,
"OnModelAfterCreateError": 3,
"OnModelValidate": 3,
"OnRecordCreate": 3,
"OnRecordCreateExecute": 2,
"OnRecordAfterCreateError": 3,
"OnRecordValidate": 3,
"OnRecordEnrich": 2,
},
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
records, err := app.FindRecordsByFilter("demo2", `title~"batch"`, "", 0, 0)
if err != nil {
t.Fatal(err)
}
if len(records) != 0 {
t.Fatalf("Expected no batch records to be persisted, got %d", len(records))
}
},
},
{
Name: "base 4 successful (public collection)",
Method: http.MethodPost,
URL: "/api/batch",
Body: strings.NewReader(`{
"requests": [
{"method":"POST", "url":"/api/collections/demo2/records", "body": {"title": "batch1"}},
{"method":"POST", "url":"/api/collections/demo2/records", "body": {"title": "batch2"}},
{"method":"PUT", "url":"/api/collections/demo2/records", "body": {"title": "batch3"}},
{"method":"PUT", "url":"/api/collections/demo2/records?fields=*,id:excerpt(4,true)", "body": {"id":"achvryl401bhse3","title": "batch4"}}
]
}`),
ExpectedStatus: 200,
ExpectedContent: []string{
`"title":"batch1"`,
`"title":"batch2"`,
`"title":"batch3"`,
`"title":"batch4"`,
`"id":"achv..."`,
`"active":false`,
`"active":true`,
`"status":200`,
`"body":{`,
},
ExpectedEvents: map[string]int{
"*": 0,
"OnBatchRequest": 1,
"OnModelValidate": 4,
"OnRecordValidate": 4,
"OnRecordEnrich": 4,
"OnRecordCreateRequest": 3,
"OnModelCreate": 3,
"OnModelCreateExecute": 3,
"OnModelAfterCreateSuccess": 3,
"OnRecordCreate": 3,
"OnRecordCreateExecute": 3,
"OnRecordAfterCreateSuccess": 3,
"OnRecordUpdateRequest": 1,
"OnModelUpdate": 1,
"OnModelUpdateExecute": 1,
"OnModelAfterUpdateSuccess": 1,
"OnRecordUpdate": 1,
"OnRecordUpdateExecute": 1,
"OnRecordAfterUpdateSuccess": 1,
},
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
records, err := app.FindRecordsByFilter("demo2", `title~"batch"`, "", 0, 0)
if err != nil {
t.Fatal(err)
}
if len(records) != 4 {
t.Fatalf("Expected %d batch records to be persisted, got %d", 3, len(records))
}
},
},
{
Name: "mixed create/update/delete (rules failure)",
Method: http.MethodPost,
URL: "/api/batch",
Body: strings.NewReader(`{
"requests": [
{"method":"POST", "url":"/api/collections/demo2/records", "body": {"title": "batch_create"}},
{"method":"DELETE", "url":"/api/collections/demo2/records/achvryl401bhse3"},
{"method":"PATCH", "url":"/api/collections/demo3/records/1tmknxy2868d869", "body": {"title": "batch_update"}}
]
}`),
ExpectedStatus: 400,
ExpectedContent: []string{
`"data":{`,
`"requests":{`,
`"2":{"code":"batch_request_failed"`,
`"response":{`,
},
NotExpectedContent: []string{
// only demo3 requires authentication
`"0":`,
`"1":`,
},
ExpectedEvents: map[string]int{
"*": 0,
"OnBatchRequest": 1,
"OnModelCreate": 1,
"OnModelCreateExecute": 1,
"OnModelAfterCreateError": 1,
"OnModelDelete": 1,
"OnModelDeleteExecute": 1,
"OnModelAfterDeleteError": 1,
"OnModelValidate": 1,
"OnRecordCreateRequest": 1,
"OnRecordCreate": 1,
"OnRecordCreateExecute": 1,
"OnRecordAfterCreateError": 1,
"OnRecordDeleteRequest": 1,
"OnRecordDelete": 1,
"OnRecordDeleteExecute": 1,
"OnRecordAfterDeleteError": 1,
"OnRecordEnrich": 1,
"OnRecordValidate": 1,
},
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
_, err := app.FindFirstRecordByFilter("demo2", `title="batch_create"`)
if err == nil {
t.Fatal("Expected record to not be created")
}
_, err = app.FindFirstRecordByFilter("demo3", `title="batch_update"`)
if err == nil {
t.Fatal("Expected record to not be updated")
}
_, err = app.FindRecordById("demo2", "achvryl401bhse3")
if err != nil {
t.Fatal("Expected record to not be deleted")
}
},
},
{
Name: "mixed create/update/delete (rules success)",
Method: http.MethodPost,
URL: "/api/batch",
Headers: map[string]string{
// test@example.com, clients
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6ImdrMzkwcWVnczR5NDd3biIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoidjg1MXE0cjc5MHJoa25sIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.0ONnm_BsvPRZyDNT31GN1CKUB6uQRxvVvQ-Wc9AZfG0",
},
Body: strings.NewReader(`{
"requests": [
{"method":"POST", "url":"/api/collections/demo2/records", "body": {"title": "batch_create"}, "headers": {"Authorization": "ignored"}},
{"method":"DELETE", "url":"/api/collections/demo2/records/achvryl401bhse3", "headers": {"Authorization": "ignored"}},
{"method":"PATCH", "url":"/api/collections/demo3/records/1tmknxy2868d869", "body": {"title": "batch_update"}, "headers": {"Authorization": "ignored"}}
]
}`),
ExpectedStatus: 200,
ExpectedContent: []string{
`"title":"batch_create"`,
`"title":"batch_update"`,
`"status":200`,
`"status":204`,
`"body":{`,
`"body":null`,
},
ExpectedEvents: map[string]int{
"*": 0,
"OnBatchRequest": 1,
// ---
"OnModelCreate": 1,
"OnModelCreateExecute": 1,
"OnModelAfterCreateSuccess": 1,
"OnModelDelete": 1,
"OnModelDeleteExecute": 1,
"OnModelAfterDeleteSuccess": 1,
"OnModelUpdate": 1,
"OnModelUpdateExecute": 1,
"OnModelAfterUpdateSuccess": 1,
"OnModelValidate": 2,
// ---
"OnRecordCreateRequest": 1,
"OnRecordCreate": 1,
"OnRecordCreateExecute": 1,
"OnRecordAfterCreateSuccess": 1,
"OnRecordDeleteRequest": 1,
"OnRecordDelete": 1,
"OnRecordDeleteExecute": 1,
"OnRecordAfterDeleteSuccess": 1,
"OnRecordUpdateRequest": 1,
"OnRecordUpdate": 1,
"OnRecordUpdateExecute": 1,
"OnRecordAfterUpdateSuccess": 1,
"OnRecordValidate": 2,
"OnRecordEnrich": 2,
},
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
_, err := app.FindFirstRecordByFilter("demo2", `title="batch_create"`)
if err != nil {
t.Fatal(err)
}
_, err = app.FindFirstRecordByFilter("demo3", `title="batch_update"`)
if err != nil {
t.Fatal(err)
}
_, err = app.FindRecordById("demo2", "achvryl401bhse3")
if err == nil {
t.Fatal("Expected record to be deleted")
}
},
},
{
Name: "mixed create/update/delete (superuser auth)",
Method: http.MethodPost,
URL: "/api/batch",
Headers: map[string]string{
// test@example.com, _superusers
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
},
Body: strings.NewReader(`{
"requests": [
{"method":"POST", "url":"/api/collections/demo2/records", "body": {"title": "batch_create"}},
{"method":"DELETE", "url":"/api/collections/demo2/records/achvryl401bhse3"},
{"method":"PATCH", "url":"/api/collections/demo3/records/1tmknxy2868d869", "body": {"title": "batch_update"}}
]
}`),
ExpectedStatus: 200,
ExpectedContent: []string{
`"title":"batch_create"`,
`"title":"batch_update"`,
`"status":200`,
`"status":204`,
`"body":{`,
`"body":null`,
},
ExpectedEvents: map[string]int{
"*": 0,
"OnBatchRequest": 1,
// ---
"OnModelCreate": 1,
"OnModelCreateExecute": 1,
"OnModelAfterCreateSuccess": 1,
"OnModelDelete": 1,
"OnModelDeleteExecute": 1,
"OnModelAfterDeleteSuccess": 1,
"OnModelUpdate": 1,
"OnModelUpdateExecute": 1,
"OnModelAfterUpdateSuccess": 1,
"OnModelValidate": 2,
// ---
"OnRecordCreateRequest": 1,
"OnRecordCreate": 1,
"OnRecordCreateExecute": 1,
"OnRecordAfterCreateSuccess": 1,
"OnRecordDeleteRequest": 1,
"OnRecordDelete": 1,
"OnRecordDeleteExecute": 1,
"OnRecordAfterDeleteSuccess": 1,
"OnRecordUpdateRequest": 1,
"OnRecordUpdate": 1,
"OnRecordUpdateExecute": 1,
"OnRecordAfterUpdateSuccess": 1,
"OnRecordValidate": 2,
"OnRecordEnrich": 2,
},
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
_, err := app.FindFirstRecordByFilter("demo2", `title="batch_create"`)
if err != nil {
t.Fatal(err)
}
_, err = app.FindFirstRecordByFilter("demo3", `title="batch_update"`)
if err != nil {
t.Fatal(err)
}
_, err = app.FindRecordById("demo2", "achvryl401bhse3")
if err == nil {
t.Fatal("Expected record to be deleted")
}
},
},
{
Name: "cascade delete/update",
Method: http.MethodPost,
URL: "/api/batch",
Headers: map[string]string{
// test@example.com, _superusers
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
},
Body: strings.NewReader(`{
"requests": [
{"method":"DELETE", "url":"/api/collections/demo3/records/1tmknxy2868d869"},
{"method":"DELETE", "url":"/api/collections/demo3/records/mk5fmymtx4wsprk"}
]
}`),
ExpectedStatus: 200,
ExpectedContent: []string{
`"status":204`,
`"body":null`,
},
NotExpectedContent: []string{
`"status":200`,
`"body":{`,
},
ExpectedEvents: map[string]int{
"*": 0,
"OnBatchRequest": 1,
// ---
"OnModelDelete": 3, // 2 batch + 1 cascade delete
"OnModelDeleteExecute": 3,
"OnModelAfterDeleteSuccess": 3,
"OnModelUpdate": 5, // 5 cascade update
"OnModelUpdateExecute": 5,
"OnModelAfterUpdateSuccess": 5,
// ---
"OnRecordDeleteRequest": 2,
"OnRecordDelete": 3,
"OnRecordDeleteExecute": 3,
"OnRecordAfterDeleteSuccess": 3,
"OnRecordUpdate": 5,
"OnRecordUpdateExecute": 5,
"OnRecordAfterUpdateSuccess": 5,
},
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
ids := []string{
"1tmknxy2868d869",
"mk5fmymtx4wsprk",
"qzaqccwrmva4o1n",
}
for _, id := range ids {
_, err := app.FindRecordById("demo2", id)
if err == nil {
t.Fatalf("Expected record %q to be deleted", id)
}
}
},
},
{
Name: "transaction timeout",
Method: http.MethodPost,
URL: "/api/batch",
Body: strings.NewReader(`{
"requests": [
{"method":"POST", "url":"/api/collections/demo2/records", "body": {"title": "batch1"}},
{"method":"POST", "url":"/api/collections/demo2/records", "body": {"title": "batch2"}}
]
}`),
Headers: map[string]string{
// test@example.com, _superusers
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
},
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
app.Settings().Batch.Timeout = 1
app.OnRecordCreateRequest("demo2").BindFunc(func(e *core.RecordRequestEvent) error {
time.Sleep(600 * time.Millisecond) // < 1s so that the first request can succeed
return e.Next()
})
},
ExpectedStatus: 400,
ExpectedContent: []string{
`"data":{}`,
},
ExpectedEvents: map[string]int{
"*": 0,
"OnBatchRequest": 1,
"OnRecordCreateRequest": 2,
"OnModelCreate": 1,
"OnModelCreateExecute": 1,
"OnModelAfterCreateError": 1,
"OnModelValidate": 1,
"OnRecordCreate": 1,
"OnRecordCreateExecute": 1,
"OnRecordAfterCreateError": 1,
"OnRecordEnrich": 1,
"OnRecordValidate": 1,
},
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
records, err := app.FindRecordsByFilter("demo2", `title~"batch"`, "", 0, 0)
if err != nil {
t.Fatal(err)
}
if len(records) != 0 {
t.Fatalf("Expected %d batch records to be persisted, got %d", 0, len(records))
}
},
},
{
Name: "multipart/form-data + file upload",
Method: http.MethodPost,
URL: "/api/batch",
Body: formData,
Headers: map[string]string{
// test@example.com, clients
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6ImdrMzkwcWVnczR5NDd3biIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoidjg1MXE0cjc5MHJoa25sIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.0ONnm_BsvPRZyDNT31GN1CKUB6uQRxvVvQ-Wc9AZfG0",
"Content-Type": mp.FormDataContentType(),
},
ExpectedStatus: 200,
ExpectedContent: []string{
`"title":"batch1"`,
`"title":"batch2"`,
`"title":"batch3"`,
`"id":"lcl9d87w22ml6jy"`,
`"files":["300_UhLKX91HVb.png"]`,
`"tmpfile_`,
`"status":200`,
`"body":{`,
},
ExpectedEvents: map[string]int{
"*": 0,
"OnBatchRequest": 1,
// ---
"OnModelCreate": 3,
"OnModelCreateExecute": 3,
"OnModelAfterCreateSuccess": 3,
"OnModelUpdate": 1,
"OnModelUpdateExecute": 1,
"OnModelAfterUpdateSuccess": 1,
"OnModelValidate": 4,
// ---
"OnRecordCreateRequest": 3,
"OnRecordUpdateRequest": 1,
"OnRecordCreate": 3,
"OnRecordCreateExecute": 3,
"OnRecordAfterCreateSuccess": 3,
"OnRecordUpdate": 1,
"OnRecordUpdateExecute": 1,
"OnRecordAfterUpdateSuccess": 1,
"OnRecordValidate": 4,
"OnRecordEnrich": 4,
},
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
batch1, err := app.FindFirstRecordByFilter("demo3", `title="batch1"`)
if err != nil {
t.Fatalf("missing batch1: %v", err)
}
batch1Files := batch1.GetStringSlice("files")
if len(batch1Files) != 3 {
t.Fatalf("Expected %d batch1 file(s), got %d", 3, len(batch1Files))
}
batch2, err := app.FindFirstRecordByFilter("demo3", `title="batch2"`)
if err != nil {
t.Fatalf("missing batch2: %v", err)
}
batch2Files := batch2.GetStringSlice("files")
if len(batch2Files) != 0 {
t.Fatalf("Expected %d batch2 file(s), got %d", 0, len(batch2Files))
}
batch3, err := app.FindFirstRecordByFilter("demo3", `title="batch3"`)
if err != nil {
t.Fatalf("missing batch3: %v", err)
}
batch3Files := batch3.GetStringSlice("files")
if len(batch3Files) != 1 {
t.Fatalf("Expected %d batch3 file(s), got %d", 1, len(batch3Files))
}
batch4, err := app.FindRecordById("demo3", "lcl9d87w22ml6jy")
if err != nil {
t.Fatalf("missing batch4: %v", err)
}
batch4Files := batch4.GetStringSlice("files")
if len(batch4Files) != 1 {
t.Fatalf("Expected %d batch4 file(s), got %d", 1, len(batch4Files))
}
},
},
{
Name: "create/update with expand query params",
Method: http.MethodPost,
URL: "/api/batch",
Headers: map[string]string{
// test@example.com, _superusers
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
},
Body: strings.NewReader(`{
"requests": [
{"method":"POST", "url":"/api/collections/demo5/records?expand=rel_one", "body": {"total": 9, "rel_one":"qzaqccwrmva4o1n"}},
{"method":"PATCH", "url":"/api/collections/demo5/records/qjeql998mtp1azp?expand=rel_many", "body": {"total": 10}}
]
}`),
ExpectedStatus: 200,
ExpectedContent: []string{
`"body":{`,
`"id":"qjeql998mtp1azp"`,
`"id":"qzaqccwrmva4o1n"`,
`"id":"i9naidtvr6qsgb4"`,
`"expand":{"rel_one"`,
`"expand":{"rel_many"`,
},
ExpectedEvents: map[string]int{
"*": 0,
"OnBatchRequest": 1,
// ---
"OnModelCreate": 1,
"OnModelCreateExecute": 1,
"OnModelAfterCreateSuccess": 1,
"OnModelUpdate": 1,
"OnModelUpdateExecute": 1,
"OnModelAfterUpdateSuccess": 1,
"OnModelValidate": 2,
// ---
"OnRecordCreateRequest": 1,
"OnRecordUpdateRequest": 1,
"OnRecordCreate": 1,
"OnRecordCreateExecute": 1,
"OnRecordAfterCreateSuccess": 1,
"OnRecordUpdate": 1,
"OnRecordUpdateExecute": 1,
"OnRecordAfterUpdateSuccess": 1,
"OnRecordValidate": 2,
"OnRecordEnrich": 5,
},
},
{
Name: "check body limit middleware",
Method: http.MethodPost,
URL: "/api/batch",
Headers: map[string]string{
// test@example.com, _superusers
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
},
Body: strings.NewReader(`{
"requests": [
{"method":"POST", "url":"/api/collections/demo5/records?expand=rel_one", "body": {"total": 9, "rel_one":"qzaqccwrmva4o1n"}},
{"method":"PATCH", "url":"/api/collections/demo5/records/qjeql998mtp1azp?expand=rel_many", "body": {"total": 10}}
]
}`),
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
app.Settings().Batch.MaxBodySize = 10
},
ExpectedStatus: 413,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
}
for _, scenario := range scenarios {
scenario.Test(t)
}
}

206
apis/collection.go Normal file
View file

@ -0,0 +1,206 @@
package apis
import (
"errors"
"net/http"
"strings"
validation "github.com/go-ozzo/ozzo-validation/v4"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tools/router"
"github.com/pocketbase/pocketbase/tools/search"
)
// bindCollectionApi registers the collection api endpoints and the corresponding handlers.
func bindCollectionApi(app core.App, rg *router.RouterGroup[*core.RequestEvent]) {
subGroup := rg.Group("/collections").Bind(RequireSuperuserAuth())
subGroup.GET("", collectionsList)
subGroup.POST("", collectionCreate)
subGroup.GET("/{collection}", collectionView)
subGroup.PATCH("/{collection}", collectionUpdate)
subGroup.DELETE("/{collection}", collectionDelete)
subGroup.DELETE("/{collection}/truncate", collectionTruncate)
subGroup.PUT("/import", collectionsImport)
subGroup.GET("/meta/scaffolds", collectionScaffolds)
}
func collectionsList(e *core.RequestEvent) error {
fieldResolver := search.NewSimpleFieldResolver(
"id", "created", "updated", "name", "system", "type",
)
collections := []*core.Collection{}
result, err := search.NewProvider(fieldResolver).
Query(e.App.CollectionQuery()).
ParseAndExec(e.Request.URL.Query().Encode(), &collections)
if err != nil {
return e.BadRequestError("", err)
}
event := new(core.CollectionsListRequestEvent)
event.RequestEvent = e
event.Collections = collections
event.Result = result
return event.App.OnCollectionsListRequest().Trigger(event, func(e *core.CollectionsListRequestEvent) error {
return execAfterSuccessTx(true, e.App, func() error {
return e.JSON(http.StatusOK, e.Result)
})
})
}
func collectionView(e *core.RequestEvent) error {
collection, err := e.App.FindCachedCollectionByNameOrId(e.Request.PathValue("collection"))
if err != nil || collection == nil {
return e.NotFoundError("", err)
}
event := new(core.CollectionRequestEvent)
event.RequestEvent = e
event.Collection = collection
return e.App.OnCollectionViewRequest().Trigger(event, func(e *core.CollectionRequestEvent) error {
return execAfterSuccessTx(true, e.App, func() error {
return e.JSON(http.StatusOK, e.Collection)
})
})
}
func collectionCreate(e *core.RequestEvent) error {
// populate the minimal required factory collection data (if any)
factoryExtract := struct {
Type string `form:"type" json:"type"`
Name string `form:"name" json:"name"`
}{}
if err := e.BindBody(&factoryExtract); err != nil {
return e.BadRequestError("Failed to load the collection type data due to invalid formatting.", err)
}
// create scaffold
collection := core.NewCollection(factoryExtract.Type, factoryExtract.Name)
// merge the scaffold with the submitted request data
if err := e.BindBody(collection); err != nil {
return e.BadRequestError("Failed to load the submitted data due to invalid formatting.", err)
}
event := new(core.CollectionRequestEvent)
event.RequestEvent = e
event.Collection = collection
return e.App.OnCollectionCreateRequest().Trigger(event, func(e *core.CollectionRequestEvent) error {
if err := e.App.Save(e.Collection); err != nil {
// validation failure
var validationErrors validation.Errors
if errors.As(err, &validationErrors) {
return e.BadRequestError("Failed to create collection.", validationErrors)
}
// other generic db error
return e.BadRequestError("Failed to create collection. Raw error: \n"+err.Error(), nil)
}
return execAfterSuccessTx(true, e.App, func() error {
return e.JSON(http.StatusOK, e.Collection)
})
})
}
func collectionUpdate(e *core.RequestEvent) error {
collection, err := e.App.FindCollectionByNameOrId(e.Request.PathValue("collection"))
if err != nil || collection == nil {
return e.NotFoundError("", err)
}
if err := e.BindBody(collection); err != nil {
return e.BadRequestError("Failed to load the submitted data due to invalid formatting.", err)
}
event := new(core.CollectionRequestEvent)
event.RequestEvent = e
event.Collection = collection
return event.App.OnCollectionUpdateRequest().Trigger(event, func(e *core.CollectionRequestEvent) error {
if err := e.App.Save(e.Collection); err != nil {
// validation failure
var validationErrors validation.Errors
if errors.As(err, &validationErrors) {
return e.BadRequestError("Failed to update collection.", validationErrors)
}
// other generic db error
return e.BadRequestError("Failed to update collection. Raw error: \n"+err.Error(), nil)
}
return execAfterSuccessTx(true, e.App, func() error {
return e.JSON(http.StatusOK, e.Collection)
})
})
}
func collectionDelete(e *core.RequestEvent) error {
collection, err := e.App.FindCachedCollectionByNameOrId(e.Request.PathValue("collection"))
if err != nil || collection == nil {
return e.NotFoundError("", err)
}
event := new(core.CollectionRequestEvent)
event.RequestEvent = e
event.Collection = collection
return e.App.OnCollectionDeleteRequest().Trigger(event, func(e *core.CollectionRequestEvent) error {
if err := e.App.Delete(e.Collection); err != nil {
msg := "Failed to delete collection"
// check fo references
refs, _ := e.App.FindCollectionReferences(e.Collection, e.Collection.Id)
if len(refs) > 0 {
names := make([]string, 0, len(refs))
for ref := range refs {
names = append(names, ref.Name)
}
msg += " probably due to existing reference in " + strings.Join(names, ", ")
}
return e.BadRequestError(msg, err)
}
return execAfterSuccessTx(true, e.App, func() error {
return e.NoContent(http.StatusNoContent)
})
})
}
func collectionTruncate(e *core.RequestEvent) error {
collection, err := e.App.FindCachedCollectionByNameOrId(e.Request.PathValue("collection"))
if err != nil || collection == nil {
return e.NotFoundError("", err)
}
if collection.IsView() {
return e.BadRequestError("View collections cannot be truncated since they don't store their own records.", nil)
}
err = e.App.TruncateCollection(collection)
if err != nil {
return e.BadRequestError("Failed to truncate collection (most likely due to required cascade delete record references).", err)
}
return e.NoContent(http.StatusNoContent)
}
func collectionScaffolds(e *core.RequestEvent) error {
collections := map[string]*core.Collection{
core.CollectionTypeBase: core.NewBaseCollection(""),
core.CollectionTypeAuth: core.NewAuthCollection(""),
core.CollectionTypeView: core.NewViewCollection(""),
}
for _, c := range collections {
c.Id = "" // clear autogenerated id
}
return e.JSON(http.StatusOK, collections)
}

62
apis/collection_import.go Normal file
View file

@ -0,0 +1,62 @@
package apis
import (
"errors"
"net/http"
validation "github.com/go-ozzo/ozzo-validation/v4"
"github.com/pocketbase/pocketbase/core"
)
func collectionsImport(e *core.RequestEvent) error {
form := new(collectionsImportForm)
err := e.BindBody(form)
if err != nil {
return firstApiError(err, e.BadRequestError("An error occurred while loading the submitted data.", err))
}
err = form.validate()
if err != nil {
return firstApiError(err, e.BadRequestError("An error occurred while validating the submitted data.", err))
}
event := new(core.CollectionsImportRequestEvent)
event.RequestEvent = e
event.CollectionsData = form.Collections
event.DeleteMissing = form.DeleteMissing
return event.App.OnCollectionsImportRequest().Trigger(event, func(e *core.CollectionsImportRequestEvent) error {
importErr := e.App.ImportCollections(e.CollectionsData, form.DeleteMissing)
if importErr == nil {
return execAfterSuccessTx(true, e.App, func() error {
return e.NoContent(http.StatusNoContent)
})
}
// validation failure
var validationErrors validation.Errors
if errors.As(importErr, &validationErrors) {
return e.BadRequestError("Failed to import collections.", validationErrors)
}
// generic/db failure
return e.BadRequestError("Failed to import collections.", validation.Errors{"collections": validation.NewError(
"validation_collections_import_failure",
"Failed to import the collections configuration. Raw error:\n"+importErr.Error(),
)})
})
}
// -------------------------------------------------------------------
type collectionsImportForm struct {
Collections []map[string]any `form:"collections" json:"collections"`
DeleteMissing bool `form:"deleteMissing" json:"deleteMissing"`
}
func (form *collectionsImportForm) validate() error {
return validation.ValidateStruct(form,
validation.Field(&form.Collections, validation.Required),
)
}

View file

@ -0,0 +1,369 @@
package apis_test
import (
"errors"
"net/http"
"strings"
"testing"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tests"
)
func TestCollectionsImport(t *testing.T) {
t.Parallel()
totalCollections := 16
scenarios := []tests.ApiScenario{
{
Name: "unauthorized",
Method: http.MethodPut,
URL: "/api/collections/import",
ExpectedStatus: 401,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "authorized as regular user",
Method: http.MethodPut,
URL: "/api/collections/import",
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
},
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "authorized as superuser + empty collections",
Method: http.MethodPut,
URL: "/api/collections/import",
Body: strings.NewReader(`{"collections":[]}`),
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
},
ExpectedStatus: 400,
ExpectedContent: []string{
`"data":{`,
`"collections":{"code":"validation_required"`,
},
ExpectedEvents: map[string]int{"*": 0},
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
collections := []*core.Collection{}
if err := app.CollectionQuery().All(&collections); err != nil {
t.Fatal(err)
}
expected := totalCollections
if len(collections) != expected {
t.Fatalf("Expected %d collections, got %d", expected, len(collections))
}
},
},
{
Name: "authorized as superuser + collections validator failure",
Method: http.MethodPut,
URL: "/api/collections/import",
Body: strings.NewReader(`{
"collections":[
{"name": "import1"},
{
"name": "import2",
"fields": [
{
"id": "koih1lqx",
"name": "expand",
"type": "text"
}
]
}
]
}`),
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
},
ExpectedStatus: 400,
ExpectedContent: []string{
`"data":{`,
`"collections":{"code":"validation_collections_import_failure"`,
`import2`,
`fields`,
},
NotExpectedContent: []string{"Raw error:"},
ExpectedEvents: map[string]int{
"*": 0,
"OnCollectionsImportRequest": 1,
"OnCollectionCreate": 2,
"OnCollectionCreateExecute": 2,
"OnCollectionAfterCreateError": 2,
"OnModelCreate": 2,
"OnModelCreateExecute": 2,
"OnModelAfterCreateError": 2,
},
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
collections := []*core.Collection{}
if err := app.CollectionQuery().All(&collections); err != nil {
t.Fatal(err)
}
expected := totalCollections
if len(collections) != expected {
t.Fatalf("Expected %d collections, got %d", expected, len(collections))
}
},
},
{
Name: "authorized as superuser + non-validator failure",
Method: http.MethodPut,
URL: "/api/collections/import",
Body: strings.NewReader(`{
"collections":[
{
"name": "import1",
"fields": [
{
"id": "koih1lqx",
"name": "test",
"type": "text"
}
]
},
{
"name": "import2",
"fields": [
{
"id": "koih1lqx",
"name": "test",
"type": "text"
}
],
"indexes": [
"create index idx_test on import2 (test)"
]
}
]
}`),
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
},
ExpectedStatus: 400,
ExpectedContent: []string{
`"data":{`,
`"collections":{"code":"validation_collections_import_failure"`,
`Raw error:`,
`custom_error`,
},
ExpectedEvents: map[string]int{
"*": 0,
"OnCollectionsImportRequest": 1,
"OnCollectionCreate": 1,
"OnCollectionAfterCreateError": 1,
"OnModelCreate": 1,
"OnModelAfterCreateError": 1,
},
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
app.OnCollectionCreate().BindFunc(func(e *core.CollectionEvent) error {
return errors.New("custom_error")
})
},
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
collections := []*core.Collection{}
if err := app.CollectionQuery().All(&collections); err != nil {
t.Fatal(err)
}
expected := totalCollections
if len(collections) != expected {
t.Fatalf("Expected %d collections, got %d", expected, len(collections))
}
},
},
{
Name: "authorized as superuser + successful collections create",
Method: http.MethodPut,
URL: "/api/collections/import",
Body: strings.NewReader(`{
"collections":[
{
"name": "import1",
"fields": [
{
"id": "koih1lqx",
"name": "test",
"type": "text"
}
]
},
{
"name": "import2",
"fields": [
{
"id": "koih1lqx",
"name": "test",
"type": "text"
}
],
"indexes": [
"create index idx_test on import2 (test)"
]
},
{
"name": "auth_without_fields",
"type": "auth"
}
]
}`),
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
},
ExpectedStatus: 204,
ExpectedEvents: map[string]int{
"*": 0,
"OnCollectionsImportRequest": 1,
"OnCollectionCreate": 3,
"OnCollectionCreateExecute": 3,
"OnCollectionAfterCreateSuccess": 3,
"OnModelCreate": 3,
"OnModelCreateExecute": 3,
"OnModelAfterCreateSuccess": 3,
},
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
collections := []*core.Collection{}
if err := app.CollectionQuery().All(&collections); err != nil {
t.Fatal(err)
}
expected := totalCollections + 3
if len(collections) != expected {
t.Fatalf("Expected %d collections, got %d", expected, len(collections))
}
indexes, err := app.TableIndexes("import2")
if err != nil || indexes["idx_test"] == "" {
t.Fatalf("Missing index %s (%v)", "idx_test", err)
}
},
},
{
Name: "authorized as superuser + create/update/delete",
Method: http.MethodPut,
URL: "/api/collections/import",
Body: strings.NewReader(`{
"deleteMissing": true,
"collections":[
{"name": "test123"},
{
"id":"wsmn24bux7wo113",
"name":"demo1",
"fields":[
{
"id":"_2hlxbmp",
"name":"title",
"type":"text",
"required":true
}
],
"indexes": []
}
]
}`),
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
},
ExpectedStatus: 204,
ExpectedEvents: map[string]int{
"*": 0,
"OnCollectionsImportRequest": 1,
// ---
"OnModelCreate": 1,
"OnModelCreateExecute": 1,
"OnModelAfterCreateSuccess": 1,
"OnCollectionCreate": 1,
"OnCollectionCreateExecute": 1,
"OnCollectionAfterCreateSuccess": 1,
// ---
"OnModelUpdate": 1,
"OnModelUpdateExecute": 1,
"OnModelAfterUpdateSuccess": 1,
"OnCollectionUpdate": 1,
"OnCollectionUpdateExecute": 1,
"OnCollectionAfterUpdateSuccess": 1,
// ---
"OnModelDelete": 14,
"OnModelAfterDeleteSuccess": 14,
"OnModelDeleteExecute": 14,
"OnCollectionDelete": 9,
"OnCollectionDeleteExecute": 9,
"OnCollectionAfterDeleteSuccess": 9,
"OnRecordAfterDeleteSuccess": 5,
"OnRecordDelete": 5,
"OnRecordDeleteExecute": 5,
},
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
collections := []*core.Collection{}
if err := app.CollectionQuery().All(&collections); err != nil {
t.Fatal(err)
}
systemCollections := 0
for _, c := range collections {
if c.System {
systemCollections++
}
}
expected := systemCollections + 2
if len(collections) != expected {
t.Fatalf("Expected %d collections, got %d", expected, len(collections))
}
},
},
{
Name: "OnCollectionsImportRequest tx body write check",
Method: http.MethodPut,
URL: "/api/collections/import",
Body: strings.NewReader(`{
"deleteMissing": true,
"collections":[
{"name": "test123"},
{
"id":"wsmn24bux7wo113",
"name":"demo1",
"fields":[
{
"id":"_2hlxbmp",
"name":"title",
"type":"text",
"required":true
}
],
"indexes": []
}
]
}`),
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
},
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
app.OnCollectionsImportRequest().BindFunc(func(e *core.CollectionsImportRequestEvent) error {
original := e.App
return e.App.RunInTransaction(func(txApp core.App) error {
e.App = txApp
defer func() { e.App = original }()
if err := e.Next(); err != nil {
return err
}
return e.BadRequestError("TX_ERROR", nil)
})
})
},
ExpectedStatus: 400,
ExpectedEvents: map[string]int{"OnCollectionsImportRequest": 1},
ExpectedContent: []string{"TX_ERROR"},
},
}
for _, scenario := range scenarios {
scenario.Test(t)
}
}

1586
apis/collection_test.go Normal file

File diff suppressed because it is too large Load diff

59
apis/cron.go Normal file
View file

@ -0,0 +1,59 @@
package apis
import (
"net/http"
"slices"
"strings"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tools/cron"
"github.com/pocketbase/pocketbase/tools/router"
"github.com/pocketbase/pocketbase/tools/routine"
)
// bindCronApi registers the crons api endpoint.
func bindCronApi(app core.App, rg *router.RouterGroup[*core.RequestEvent]) {
subGroup := rg.Group("/crons").Bind(RequireSuperuserAuth())
subGroup.GET("", cronsList)
subGroup.POST("/{id}", cronRun)
}
func cronsList(e *core.RequestEvent) error {
jobs := e.App.Cron().Jobs()
slices.SortStableFunc(jobs, func(a, b *cron.Job) int {
if strings.HasPrefix(a.Id(), "__pb") {
return 1
}
if strings.HasPrefix(b.Id(), "__pb") {
return -1
}
return strings.Compare(a.Id(), b.Id())
})
return e.JSON(http.StatusOK, jobs)
}
func cronRun(e *core.RequestEvent) error {
cronId := e.Request.PathValue("id")
var foundJob *cron.Job
jobs := e.App.Cron().Jobs()
for _, j := range jobs {
if j.Id() == cronId {
foundJob = j
break
}
}
if foundJob == nil {
return e.NotFoundError("Missing or invalid cron job", nil)
}
routine.FireAndForget(func() {
foundJob.Run()
})
return e.NoContent(http.StatusNoContent)
}

149
apis/cron_test.go Normal file
View file

@ -0,0 +1,149 @@
package apis_test
import (
"net/http"
"testing"
"time"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tests"
"github.com/spf13/cast"
)
func TestCronsList(t *testing.T) {
t.Parallel()
scenarios := []tests.ApiScenario{
{
Name: "unauthorized",
Method: http.MethodGet,
URL: "/api/crons",
ExpectedStatus: 401,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "authorized as regular user",
Method: http.MethodGet,
URL: "/api/crons",
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
},
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "authorized as superuser (empty list)",
Method: http.MethodGet,
URL: "/api/crons",
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
},
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
app.Cron().RemoveAll()
},
ExpectedStatus: 200,
ExpectedContent: []string{`[]`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "authorized as superuser",
Method: http.MethodGet,
URL: "/api/crons",
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
},
ExpectedStatus: 200,
ExpectedContent: []string{
`{"id":"__pbLogsCleanup__","expression":"0 */6 * * *"}`,
`{"id":"__pbDBOptimize__","expression":"0 0 * * *"}`,
`{"id":"__pbMFACleanup__","expression":"0 * * * *"}`,
`{"id":"__pbOTPCleanup__","expression":"0 * * * *"}`,
},
ExpectedEvents: map[string]int{"*": 0},
},
}
for _, scenario := range scenarios {
scenario.Test(t)
}
}
func TestCronsRun(t *testing.T) {
t.Parallel()
beforeTestFunc := func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
app.Cron().Add("test", "* * * * *", func() {
app.Store().Set("testJobCalls", cast.ToInt(app.Store().Get("testJobCalls"))+1)
})
}
expectedCalls := func(expected int) func(t testing.TB, app *tests.TestApp, res *http.Response) {
return func(t testing.TB, app *tests.TestApp, res *http.Response) {
total := cast.ToInt(app.Store().Get("testJobCalls"))
if total != expected {
t.Fatalf("Expected total testJobCalls %d, got %d", expected, total)
}
}
}
scenarios := []tests.ApiScenario{
{
Name: "unauthorized",
Method: http.MethodPost,
URL: "/api/crons/test",
Delay: 50 * time.Millisecond,
BeforeTestFunc: beforeTestFunc,
AfterTestFunc: expectedCalls(0),
ExpectedStatus: 401,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "authorized as regular user",
Method: http.MethodPost,
URL: "/api/crons/test",
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
},
Delay: 50 * time.Millisecond,
BeforeTestFunc: beforeTestFunc,
AfterTestFunc: expectedCalls(0),
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "authorized as superuser (missing job)",
Method: http.MethodPost,
URL: "/api/crons/missing",
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
},
Delay: 50 * time.Millisecond,
BeforeTestFunc: beforeTestFunc,
AfterTestFunc: expectedCalls(0),
ExpectedStatus: 404,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "authorized as superuser (existing job)",
Method: http.MethodPost,
URL: "/api/crons/test",
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
},
Delay: 50 * time.Millisecond,
BeforeTestFunc: beforeTestFunc,
AfterTestFunc: expectedCalls(1),
ExpectedStatus: 204,
ExpectedEvents: map[string]int{"*": 0},
},
}
for _, scenario := range scenarios {
scenario.Test(t)
}
}

230
apis/file.go Normal file
View file

@ -0,0 +1,230 @@
package apis
import (
"context"
"errors"
"fmt"
"log/slog"
"net/http"
"os"
"runtime"
"time"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tools/filesystem"
"github.com/pocketbase/pocketbase/tools/list"
"github.com/pocketbase/pocketbase/tools/router"
"github.com/spf13/cast"
"golang.org/x/sync/semaphore"
"golang.org/x/sync/singleflight"
)
var imageContentTypes = []string{"image/png", "image/jpg", "image/jpeg", "image/gif", "image/webp"}
var defaultThumbSizes = []string{"100x100"}
// bindFileApi registers the file api endpoints and the corresponding handlers.
func bindFileApi(app core.App, rg *router.RouterGroup[*core.RequestEvent]) {
maxWorkers := cast.ToInt64(os.Getenv("PB_THUMBS_MAX_WORKERS"))
if maxWorkers <= 0 {
maxWorkers = int64(runtime.NumCPU() + 2) // the value is arbitrary chosen and may change in the future
}
maxWait := cast.ToInt64(os.Getenv("PB_THUMBS_MAX_WAIT"))
if maxWait <= 0 {
maxWait = 60
}
api := fileApi{
thumbGenPending: new(singleflight.Group),
thumbGenSem: semaphore.NewWeighted(maxWorkers),
thumbGenMaxWait: time.Duration(maxWait) * time.Second,
}
sub := rg.Group("/files")
sub.POST("/token", api.fileToken).Bind(RequireAuth())
sub.GET("/{collection}/{recordId}/{filename}", api.download).Bind(collectionPathRateLimit("", "file"))
}
type fileApi struct {
// thumbGenSem is a semaphore to prevent too much concurrent
// requests generating new thumbs at the same time.
thumbGenSem *semaphore.Weighted
// thumbGenPending represents a group of currently pending
// thumb generation processes.
thumbGenPending *singleflight.Group
// thumbGenMaxWait is the maximum waiting time for starting a new
// thumb generation process.
thumbGenMaxWait time.Duration
}
func (api *fileApi) fileToken(e *core.RequestEvent) error {
if e.Auth == nil {
return e.UnauthorizedError("Missing auth context.", nil)
}
token, err := e.Auth.NewFileToken()
if err != nil {
return e.InternalServerError("Failed to generate file token", err)
}
event := new(core.FileTokenRequestEvent)
event.RequestEvent = e
event.Token = token
event.Record = e.Auth
return e.App.OnFileTokenRequest().Trigger(event, func(e *core.FileTokenRequestEvent) error {
return execAfterSuccessTx(true, e.App, func() error {
return e.JSON(http.StatusOK, map[string]string{"token": e.Token})
})
})
}
func (api *fileApi) download(e *core.RequestEvent) error {
collection, err := e.App.FindCachedCollectionByNameOrId(e.Request.PathValue("collection"))
if err != nil {
return e.NotFoundError("", nil)
}
recordId := e.Request.PathValue("recordId")
if recordId == "" {
return e.NotFoundError("", nil)
}
record, err := e.App.FindRecordById(collection, recordId)
if err != nil {
return e.NotFoundError("", err)
}
filename := e.Request.PathValue("filename")
fileField := record.FindFileFieldByFile(filename)
if fileField == nil {
return e.NotFoundError("", nil)
}
// check whether the request is authorized to view the protected file
if fileField.Protected {
originalRequestInfo, err := e.RequestInfo()
if err != nil {
return e.InternalServerError("Failed to load request info", err)
}
token := e.Request.URL.Query().Get("token")
authRecord, _ := e.App.FindAuthRecordByToken(token, core.TokenTypeFile)
// create a shallow copy of the cached request data and adjust it to the current auth record (if any)
requestInfo := *originalRequestInfo
requestInfo.Context = core.RequestInfoContextProtectedFile
requestInfo.Auth = authRecord
if ok, _ := e.App.CanAccessRecord(record, &requestInfo, record.Collection().ViewRule); !ok {
return e.NotFoundError("", errors.New("insufficient permissions to access the file resource"))
}
}
baseFilesPath := record.BaseFilesPath()
// fetch the original view file field related record
if collection.IsView() {
fileRecord, err := e.App.FindRecordByViewFile(collection.Id, fileField.Name, filename)
if err != nil {
return e.NotFoundError("", fmt.Errorf("failed to fetch view file field record: %w", err))
}
baseFilesPath = fileRecord.BaseFilesPath()
}
fsys, err := e.App.NewFilesystem()
if err != nil {
return e.InternalServerError("Filesystem initialization failure.", err)
}
defer fsys.Close()
originalPath := baseFilesPath + "/" + filename
servedPath := originalPath
servedName := filename
// check for valid thumb size param
thumbSize := e.Request.URL.Query().Get("thumb")
if thumbSize != "" && (list.ExistInSlice(thumbSize, defaultThumbSizes) || list.ExistInSlice(thumbSize, fileField.Thumbs)) {
// extract the original file meta attributes and check it existence
oAttrs, oAttrsErr := fsys.Attributes(originalPath)
if oAttrsErr != nil {
return e.NotFoundError("", err)
}
// check if it is an image
if list.ExistInSlice(oAttrs.ContentType, imageContentTypes) {
// add thumb size as file suffix
servedName = thumbSize + "_" + filename
servedPath = baseFilesPath + "/thumbs_" + filename + "/" + servedName
// create a new thumb if it doesn't exist
if exists, _ := fsys.Exists(servedPath); !exists {
if err := api.createThumb(e, fsys, originalPath, servedPath, thumbSize); err != nil {
e.App.Logger().Warn(
"Fallback to original - failed to create thumb "+servedName,
slog.Any("error", err),
slog.String("original", originalPath),
slog.String("thumb", servedPath),
)
// fallback to the original
servedName = filename
servedPath = originalPath
}
}
}
}
event := new(core.FileDownloadRequestEvent)
event.RequestEvent = e
event.Collection = collection
event.Record = record
event.FileField = fileField
event.ServedPath = servedPath
event.ServedName = servedName
// clickjacking shouldn't be a concern when serving uploaded files,
// so it safe to unset the global X-Frame-Options to allow files embedding
// (note: it is out of the hook to allow users to customize the behavior)
e.Response.Header().Del("X-Frame-Options")
return e.App.OnFileDownloadRequest().Trigger(event, func(e *core.FileDownloadRequestEvent) error {
err = execAfterSuccessTx(true, e.App, func() error {
return fsys.Serve(e.Response, e.Request, e.ServedPath, e.ServedName)
})
if err != nil {
return e.NotFoundError("", err)
}
return nil
})
}
func (api *fileApi) createThumb(
e *core.RequestEvent,
fsys *filesystem.System,
originalPath string,
thumbPath string,
thumbSize string,
) error {
ch := api.thumbGenPending.DoChan(thumbPath, func() (any, error) {
ctx, cancel := context.WithTimeout(e.Request.Context(), api.thumbGenMaxWait)
defer cancel()
if err := api.thumbGenSem.Acquire(ctx, 1); err != nil {
return nil, err
}
defer api.thumbGenSem.Release(1)
return nil, fsys.CreateThumb(originalPath, thumbPath, thumbSize)
})
res := <-ch
api.thumbGenPending.Forget(thumbPath)
return res.Err
}

504
apis/file_test.go Normal file
View file

@ -0,0 +1,504 @@
package apis_test
import (
"net/http"
"net/http/httptest"
"os"
"path"
"path/filepath"
"runtime"
"sync"
"testing"
"github.com/pocketbase/pocketbase/apis"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tests"
"github.com/pocketbase/pocketbase/tools/types"
)
func TestFileToken(t *testing.T) {
t.Parallel()
scenarios := []tests.ApiScenario{
{
Name: "unauthorized",
Method: http.MethodPost,
URL: "/api/files/token",
ExpectedStatus: 401,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "regular user",
Method: http.MethodPost,
URL: "/api/files/token",
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
},
ExpectedStatus: 200,
ExpectedContent: []string{
`"token":"`,
},
ExpectedEvents: map[string]int{
"*": 0,
"OnFileTokenRequest": 1,
},
},
{
Name: "superuser",
Method: http.MethodPost,
URL: "/api/files/token",
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
},
ExpectedStatus: 200,
ExpectedContent: []string{
`"token":"`,
},
ExpectedEvents: map[string]int{
"*": 0,
"OnFileTokenRequest": 1,
},
},
{
Name: "hook token overwrite",
Method: http.MethodPost,
URL: "/api/files/token",
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
},
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
app.OnFileTokenRequest().BindFunc(func(e *core.FileTokenRequestEvent) error {
e.Token = "test"
return e.Next()
})
},
ExpectedStatus: 200,
ExpectedContent: []string{
`"token":"test"`,
},
ExpectedEvents: map[string]int{
"*": 0,
"OnFileTokenRequest": 1,
},
},
}
for _, scenario := range scenarios {
scenario.Test(t)
}
}
func TestFileDownload(t *testing.T) {
t.Parallel()
_, currentFile, _, _ := runtime.Caller(0)
dataDirRelPath := "../tests/data/"
testFilePath := filepath.Join(path.Dir(currentFile), dataDirRelPath, "storage/_pb_users_auth_/oap640cot4yru2s/test_kfd2wYLxkz.txt")
testImgPath := filepath.Join(path.Dir(currentFile), dataDirRelPath, "storage/_pb_users_auth_/4q1xlclmfloku33/300_1SEi6Q6U72.png")
testThumbCropCenterPath := filepath.Join(path.Dir(currentFile), dataDirRelPath, "storage/_pb_users_auth_/4q1xlclmfloku33/thumbs_300_1SEi6Q6U72.png/70x50_300_1SEi6Q6U72.png")
testThumbCropTopPath := filepath.Join(path.Dir(currentFile), dataDirRelPath, "storage/_pb_users_auth_/4q1xlclmfloku33/thumbs_300_1SEi6Q6U72.png/70x50t_300_1SEi6Q6U72.png")
testThumbCropBottomPath := filepath.Join(path.Dir(currentFile), dataDirRelPath, "storage/_pb_users_auth_/4q1xlclmfloku33/thumbs_300_1SEi6Q6U72.png/70x50b_300_1SEi6Q6U72.png")
testThumbFitPath := filepath.Join(path.Dir(currentFile), dataDirRelPath, "storage/_pb_users_auth_/4q1xlclmfloku33/thumbs_300_1SEi6Q6U72.png/70x50f_300_1SEi6Q6U72.png")
testThumbZeroWidthPath := filepath.Join(path.Dir(currentFile), dataDirRelPath, "storage/_pb_users_auth_/4q1xlclmfloku33/thumbs_300_1SEi6Q6U72.png/0x50_300_1SEi6Q6U72.png")
testThumbZeroHeightPath := filepath.Join(path.Dir(currentFile), dataDirRelPath, "storage/_pb_users_auth_/4q1xlclmfloku33/thumbs_300_1SEi6Q6U72.png/70x0_300_1SEi6Q6U72.png")
testFile, fileErr := os.ReadFile(testFilePath)
if fileErr != nil {
t.Fatal(fileErr)
}
testImg, imgErr := os.ReadFile(testImgPath)
if imgErr != nil {
t.Fatal(imgErr)
}
testThumbCropCenter, thumbErr := os.ReadFile(testThumbCropCenterPath)
if thumbErr != nil {
t.Fatal(thumbErr)
}
testThumbCropTop, thumbErr := os.ReadFile(testThumbCropTopPath)
if thumbErr != nil {
t.Fatal(thumbErr)
}
testThumbCropBottom, thumbErr := os.ReadFile(testThumbCropBottomPath)
if thumbErr != nil {
t.Fatal(thumbErr)
}
testThumbFit, thumbErr := os.ReadFile(testThumbFitPath)
if thumbErr != nil {
t.Fatal(thumbErr)
}
testThumbZeroWidth, thumbErr := os.ReadFile(testThumbZeroWidthPath)
if thumbErr != nil {
t.Fatal(thumbErr)
}
testThumbZeroHeight, thumbErr := os.ReadFile(testThumbZeroHeightPath)
if thumbErr != nil {
t.Fatal(thumbErr)
}
scenarios := []tests.ApiScenario{
{
Name: "missing collection",
Method: http.MethodGet,
URL: "/api/files/missing/4q1xlclmfloku33/300_1SEi6Q6U72.png",
ExpectedStatus: 404,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "missing record",
Method: http.MethodGet,
URL: "/api/files/_pb_users_auth_/missing/300_1SEi6Q6U72.png",
ExpectedStatus: 404,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "missing file",
Method: http.MethodGet,
URL: "/api/files/_pb_users_auth_/4q1xlclmfloku33/missing.png",
ExpectedStatus: 404,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "existing image",
Method: http.MethodGet,
URL: "/api/files/_pb_users_auth_/4q1xlclmfloku33/300_1SEi6Q6U72.png",
ExpectedStatus: 200,
ExpectedContent: []string{string(testImg)},
ExpectedEvents: map[string]int{
"*": 0,
"OnFileDownloadRequest": 1,
},
},
{
Name: "existing image - missing thumb (should fallback to the original)",
Method: http.MethodGet,
URL: "/api/files/_pb_users_auth_/4q1xlclmfloku33/300_1SEi6Q6U72.png?thumb=999x999",
ExpectedStatus: 200,
ExpectedContent: []string{string(testImg)},
ExpectedEvents: map[string]int{
"*": 0,
"OnFileDownloadRequest": 1,
},
},
{
Name: "existing image - existing thumb (crop center)",
Method: http.MethodGet,
URL: "/api/files/_pb_users_auth_/4q1xlclmfloku33/300_1SEi6Q6U72.png?thumb=70x50",
ExpectedStatus: 200,
ExpectedContent: []string{string(testThumbCropCenter)},
ExpectedEvents: map[string]int{
"*": 0,
"OnFileDownloadRequest": 1,
},
},
{
Name: "existing image - existing thumb (crop top)",
Method: http.MethodGet,
URL: "/api/files/_pb_users_auth_/4q1xlclmfloku33/300_1SEi6Q6U72.png?thumb=70x50t",
ExpectedStatus: 200,
ExpectedContent: []string{string(testThumbCropTop)},
ExpectedEvents: map[string]int{
"*": 0,
"OnFileDownloadRequest": 1,
},
},
{
Name: "existing image - existing thumb (crop bottom)",
Method: http.MethodGet,
URL: "/api/files/_pb_users_auth_/4q1xlclmfloku33/300_1SEi6Q6U72.png?thumb=70x50b",
ExpectedStatus: 200,
ExpectedContent: []string{string(testThumbCropBottom)},
ExpectedEvents: map[string]int{
"*": 0,
"OnFileDownloadRequest": 1,
},
},
{
Name: "existing image - existing thumb (fit)",
Method: http.MethodGet,
URL: "/api/files/_pb_users_auth_/4q1xlclmfloku33/300_1SEi6Q6U72.png?thumb=70x50f",
ExpectedStatus: 200,
ExpectedContent: []string{string(testThumbFit)},
ExpectedEvents: map[string]int{
"*": 0,
"OnFileDownloadRequest": 1,
},
},
{
Name: "existing image - existing thumb (zero width)",
Method: http.MethodGet,
URL: "/api/files/_pb_users_auth_/4q1xlclmfloku33/300_1SEi6Q6U72.png?thumb=0x50",
ExpectedStatus: 200,
ExpectedContent: []string{string(testThumbZeroWidth)},
ExpectedEvents: map[string]int{
"*": 0,
"OnFileDownloadRequest": 1,
},
},
{
Name: "existing image - existing thumb (zero height)",
Method: http.MethodGet,
URL: "/api/files/_pb_users_auth_/4q1xlclmfloku33/300_1SEi6Q6U72.png?thumb=70x0",
ExpectedStatus: 200,
ExpectedContent: []string{string(testThumbZeroHeight)},
ExpectedEvents: map[string]int{
"*": 0,
"OnFileDownloadRequest": 1,
},
},
{
Name: "existing non image file - thumb parameter should be ignored",
Method: http.MethodGet,
URL: "/api/files/_pb_users_auth_/oap640cot4yru2s/test_kfd2wYLxkz.txt?thumb=100x100",
ExpectedStatus: 200,
ExpectedContent: []string{string(testFile)},
ExpectedEvents: map[string]int{
"*": 0,
"OnFileDownloadRequest": 1,
},
},
// protected file access checks
{
Name: "protected file - superuser with expired file token",
Method: http.MethodGet,
URL: "/api/files/demo1/al1h9ijdeojtsjy/300_Jsjq7RdBgA.png?token=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsImV4cCI6MTY0MDk5MTY2MSwidHlwZSI6ImZpbGUiLCJjb2xsZWN0aW9uSWQiOiJwYmNfMzE0MjYzNTgyMyJ9.nqqtqpPhxU0045F4XP_ruAkzAidYBc5oPy9ErN3XBq0",
ExpectedStatus: 404,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "protected file - superuser with valid file token",
Method: http.MethodGet,
URL: "/api/files/demo1/al1h9ijdeojtsjy/300_Jsjq7RdBgA.png?token=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsImV4cCI6MjUyNDYwNDQ2MSwidHlwZSI6ImZpbGUiLCJjb2xsZWN0aW9uSWQiOiJwYmNfMzE0MjYzNTgyMyJ9.Lupz541xRvrktwkrl55p5pPCF77T69ZRsohsIcb2dxc",
ExpectedStatus: 200,
ExpectedContent: []string{"PNG"},
ExpectedEvents: map[string]int{
"*": 0,
"OnFileDownloadRequest": 1,
},
},
{
Name: "protected file - guest without view access",
Method: http.MethodGet,
URL: "/api/files/demo1/al1h9ijdeojtsjy/300_Jsjq7RdBgA.png",
ExpectedStatus: 404,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "protected file - guest with view access",
Method: http.MethodGet,
URL: "/api/files/demo1/al1h9ijdeojtsjy/300_Jsjq7RdBgA.png",
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
// mock public view access
c, err := app.FindCachedCollectionByNameOrId("demo1")
if err != nil {
t.Fatalf("Failed to fetch mock collection: %v", err)
}
c.ViewRule = types.Pointer("")
if err := app.UnsafeWithoutHooks().Save(c); err != nil {
t.Fatalf("Failed to update mock collection: %v", err)
}
},
ExpectedStatus: 200,
ExpectedContent: []string{"PNG"},
ExpectedEvents: map[string]int{
"*": 0,
"OnFileDownloadRequest": 1,
},
},
{
Name: "protected file - auth record without view access",
Method: http.MethodGet,
URL: "/api/files/demo1/al1h9ijdeojtsjy/300_Jsjq7RdBgA.png?token=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImV4cCI6MjUyNDYwNDQ2MSwidHlwZSI6ImZpbGUiLCJjb2xsZWN0aW9uSWQiOiJfcGJfdXNlcnNfYXV0aF8ifQ.nSTLuCPcGpWn2K2l-BFkC3Vlzc-ZTDPByYq8dN1oPSo",
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
// mock restricted user view access
c, err := app.FindCachedCollectionByNameOrId("demo1")
if err != nil {
t.Fatalf("Failed to fetch mock collection: %v", err)
}
c.ViewRule = types.Pointer("@request.auth.verified = true")
if err := app.UnsafeWithoutHooks().Save(c); err != nil {
t.Fatalf("Failed to update mock collection: %v", err)
}
},
ExpectedStatus: 404,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "protected file - auth record with view access",
Method: http.MethodGet,
URL: "/api/files/demo1/al1h9ijdeojtsjy/300_Jsjq7RdBgA.png?token=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImV4cCI6MjUyNDYwNDQ2MSwidHlwZSI6ImZpbGUiLCJjb2xsZWN0aW9uSWQiOiJfcGJfdXNlcnNfYXV0aF8ifQ.nSTLuCPcGpWn2K2l-BFkC3Vlzc-ZTDPByYq8dN1oPSo",
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
// mock user view access
c, err := app.FindCachedCollectionByNameOrId("demo1")
if err != nil {
t.Fatalf("Failed to fetch mock collection: %v", err)
}
c.ViewRule = types.Pointer("@request.auth.verified = false")
if err := app.UnsafeWithoutHooks().Save(c); err != nil {
t.Fatalf("Failed to update mock collection: %v", err)
}
},
ExpectedStatus: 200,
ExpectedContent: []string{"PNG"},
ExpectedEvents: map[string]int{
"*": 0,
"OnFileDownloadRequest": 1,
},
},
{
Name: "protected file in view (view's View API rule failure)",
Method: http.MethodGet,
URL: "/api/files/view1/al1h9ijdeojtsjy/300_Jsjq7RdBgA.png?token=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImV4cCI6MjUyNDYwNDQ2MSwidHlwZSI6ImZpbGUiLCJjb2xsZWN0aW9uSWQiOiJfcGJfdXNlcnNfYXV0aF8ifQ.nSTLuCPcGpWn2K2l-BFkC3Vlzc-ZTDPByYq8dN1oPSo",
ExpectedStatus: 404,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "protected file in view (view's View API rule success)",
Method: http.MethodGet,
URL: "/api/files/view1/84nmscqy84lsi1t/test_d61b33QdDU.txt?token=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImV4cCI6MjUyNDYwNDQ2MSwidHlwZSI6ImZpbGUiLCJjb2xsZWN0aW9uSWQiOiJfcGJfdXNlcnNfYXV0aF8ifQ.nSTLuCPcGpWn2K2l-BFkC3Vlzc-ZTDPByYq8dN1oPSo",
ExpectedStatus: 200,
ExpectedContent: []string{"test"},
ExpectedEvents: map[string]int{
"*": 0,
"OnFileDownloadRequest": 1,
},
},
// rate limit checks
// -----------------------------------------------------------
{
Name: "RateLimit rule - users:file",
Method: http.MethodGet,
URL: "/api/files/_pb_users_auth_/4q1xlclmfloku33/300_1SEi6Q6U72.png",
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
app.Settings().RateLimits.Enabled = true
app.Settings().RateLimits.Rules = []core.RateLimitRule{
{MaxRequests: 100, Label: "abc"},
{MaxRequests: 100, Label: "*:file"},
{MaxRequests: 0, Label: "users:file"},
}
},
ExpectedStatus: 429,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "RateLimit rule - *:file",
Method: http.MethodGet,
URL: "/api/files/_pb_users_auth_/4q1xlclmfloku33/300_1SEi6Q6U72.png",
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
app.Settings().RateLimits.Enabled = true
app.Settings().RateLimits.Rules = []core.RateLimitRule{
{MaxRequests: 100, Label: "abc"},
{MaxRequests: 0, Label: "*:file"},
}
},
ExpectedStatus: 429,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
}
for _, scenario := range scenarios {
// clone for the HEAD test (the same as the original scenario but without body)
head := scenario
head.Method = http.MethodHead
head.Name = ("(HEAD) " + scenario.Name)
head.ExpectedContent = nil
head.Test(t)
// regular request test
scenario.Test(t)
}
}
func TestConcurrentThumbsGeneration(t *testing.T) {
t.Parallel()
app, err := tests.NewTestApp()
if err != nil {
t.Fatal(err)
}
defer app.Cleanup()
fsys, err := app.NewFilesystem()
if err != nil {
t.Fatal(err)
}
defer fsys.Close()
// create a dummy file field collection
demo1, err := app.FindCollectionByNameOrId("demo1")
if err != nil {
t.Fatal(err)
}
fileField := demo1.Fields.GetByName("file_one").(*core.FileField)
fileField.Protected = false
fileField.MaxSelect = 1
fileField.MaxSize = 999999
// new thumbs
fileField.Thumbs = []string{"111x111", "111x222", "111x333"}
demo1.Fields.Add(fileField)
if err = app.Save(demo1); err != nil {
t.Fatal(err)
}
fileKey := "wsmn24bux7wo113/al1h9ijdeojtsjy/300_Jsjq7RdBgA.png"
urls := []string{
"/api/files/" + fileKey + "?thumb=111x111",
"/api/files/" + fileKey + "?thumb=111x111", // should still result in single thumb
"/api/files/" + fileKey + "?thumb=111x222",
"/api/files/" + fileKey + "?thumb=111x333",
}
var wg sync.WaitGroup
wg.Add(len(urls))
for _, url := range urls {
go func() {
defer wg.Done()
recorder := httptest.NewRecorder()
req := httptest.NewRequest("GET", url, nil)
pbRouter, _ := apis.NewRouter(app)
mux, _ := pbRouter.BuildMux()
if mux != nil {
mux.ServeHTTP(recorder, req)
}
}()
}
wg.Wait()
// ensure that all new requested thumbs were created
thumbKeys := []string{
"wsmn24bux7wo113/al1h9ijdeojtsjy/thumbs_300_Jsjq7RdBgA.png/111x111_" + filepath.Base(fileKey),
"wsmn24bux7wo113/al1h9ijdeojtsjy/thumbs_300_Jsjq7RdBgA.png/111x222_" + filepath.Base(fileKey),
"wsmn24bux7wo113/al1h9ijdeojtsjy/thumbs_300_Jsjq7RdBgA.png/111x333_" + filepath.Base(fileKey),
}
for _, k := range thumbKeys {
if exists, _ := fsys.Exists(k); !exists {
t.Fatalf("Missing thumb %q: %v", k, err)
}
}
}

53
apis/health.go Normal file
View file

@ -0,0 +1,53 @@
package apis
import (
"net/http"
"slices"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tools/router"
)
// bindHealthApi registers the health api endpoint.
func bindHealthApi(app core.App, rg *router.RouterGroup[*core.RequestEvent]) {
subGroup := rg.Group("/health")
subGroup.GET("", healthCheck)
}
// healthCheck returns a 200 OK response if the server is healthy.
func healthCheck(e *core.RequestEvent) error {
resp := struct {
Message string `json:"message"`
Code int `json:"code"`
Data map[string]any `json:"data"`
}{
Code: http.StatusOK,
Message: "API is healthy.",
}
if e.HasSuperuserAuth() {
resp.Data = make(map[string]any, 3)
resp.Data["canBackup"] = !e.App.Store().Has(core.StoreKeyActiveBackup)
resp.Data["realIP"] = e.RealIP()
// loosely check if behind a reverse proxy
// (usually used in the dashboard to remind superusers in case deployed behind reverse-proxy)
possibleProxyHeader := ""
headersToCheck := append(
slices.Clone(e.App.Settings().TrustedProxy.Headers),
// common proxy headers
"CF-Connecting-IP", "Fly-Client-IP", "XForwarded-For",
)
for _, header := range headersToCheck {
if e.Request.Header.Get(header) != "" {
possibleProxyHeader = header
break
}
}
resp.Data["possibleProxyHeader"] = possibleProxyHeader
} else {
resp.Data = map[string]any{} // ensure that it is returned as object
}
return e.JSON(http.StatusOK, resp)
}

71
apis/health_test.go Normal file
View file

@ -0,0 +1,71 @@
package apis_test
import (
"net/http"
"testing"
"github.com/pocketbase/pocketbase/tests"
)
func TestHealthAPI(t *testing.T) {
t.Parallel()
scenarios := []tests.ApiScenario{
{
Name: "GET health status (guest)",
Method: http.MethodGet, // automatically matches also HEAD as a side-effect of the Go std mux
URL: "/api/health",
ExpectedStatus: 200,
ExpectedContent: []string{
`"code":200`,
`"data":{}`,
},
NotExpectedContent: []string{
"canBackup",
"realIP",
"possibleProxyHeader",
},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "GET health status (regular user)",
Method: http.MethodGet,
URL: "/api/health",
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
},
ExpectedStatus: 200,
ExpectedContent: []string{
`"code":200`,
`"data":{}`,
},
NotExpectedContent: []string{
"canBackup",
"realIP",
"possibleProxyHeader",
},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "GET health status (superuser)",
Method: http.MethodGet,
URL: "/api/health",
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
},
ExpectedStatus: 200,
ExpectedContent: []string{
`"code":200`,
`"data":{`,
`"canBackup":true`,
`"realIP"`,
`"possibleProxyHeader"`,
},
ExpectedEvents: map[string]int{"*": 0},
},
}
for _, scenario := range scenarios {
scenario.Test(t)
}
}

88
apis/installer.go Normal file
View file

@ -0,0 +1,88 @@
package apis
import (
"database/sql"
"errors"
"fmt"
"os"
"strings"
"time"
"github.com/fatih/color"
"github.com/pocketbase/dbx"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tools/osutils"
)
// DefaultInstallerFunc is the default PocketBase installer function.
//
// It will attempt to open a link in the browser (with a short-lived auth
// token for the systemSuperuser) to the installer UI so that users can
// create their own custom superuser record.
//
// See https://github.com/pocketbase/pocketbase/discussions/5814.
func DefaultInstallerFunc(app core.App, systemSuperuser *core.Record, baseURL string) error {
token, err := systemSuperuser.NewStaticAuthToken(30 * time.Minute)
if err != nil {
return err
}
// launch url (ignore errors and always print a help text as fallback)
url := fmt.Sprintf("%s/_/#/pbinstal/%s", strings.TrimRight(baseURL, "/"), token)
_ = osutils.LaunchURL(url)
color.Magenta("\n(!) Launch the URL below in the browser if it hasn't been open already to create your first superuser account:")
color.New(color.Bold).Add(color.FgCyan).Println(url)
color.New(color.FgHiBlack, color.Italic).Printf("(you can also create your first superuser by running: %s superuser upsert EMAIL PASS)\n\n", os.Args[0])
return nil
}
func loadInstaller(
app core.App,
baseURL string,
installerFunc func(app core.App, systemSuperuser *core.Record, baseURL string) error,
) error {
if installerFunc == nil || !needInstallerSuperuser(app) {
return nil
}
superuser, err := findOrCreateInstallerSuperuser(app)
if err != nil {
return err
}
return installerFunc(app, superuser, baseURL)
}
func needInstallerSuperuser(app core.App) bool {
total, err := app.CountRecords(core.CollectionNameSuperusers, dbx.Not(dbx.HashExp{
"email": core.DefaultInstallerEmail,
}))
return err == nil && total == 0
}
func findOrCreateInstallerSuperuser(app core.App) (*core.Record, error) {
col, err := app.FindCachedCollectionByNameOrId(core.CollectionNameSuperusers)
if err != nil {
return nil, err
}
record, err := app.FindAuthRecordByEmail(col, core.DefaultInstallerEmail)
if err != nil {
if !errors.Is(err, sql.ErrNoRows) {
return nil, err
}
record = core.NewRecord(col)
record.SetEmail(core.DefaultInstallerEmail)
record.SetRandomPassword()
err = app.Save(record)
if err != nil {
return nil, err
}
}
return record, nil
}

73
apis/logs.go Normal file
View file

@ -0,0 +1,73 @@
package apis
import (
"net/http"
"github.com/pocketbase/dbx"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tools/router"
"github.com/pocketbase/pocketbase/tools/search"
)
// bindLogsApi registers the request logs api endpoints.
func bindLogsApi(app core.App, rg *router.RouterGroup[*core.RequestEvent]) {
sub := rg.Group("/logs").Bind(RequireSuperuserAuth(), SkipSuccessActivityLog())
sub.GET("", logsList)
sub.GET("/stats", logsStats)
sub.GET("/{id}", logsView)
}
var logFilterFields = []string{
"id", "created", "level", "message", "data",
`^data\.[\w\.\:]*\w+$`,
}
func logsList(e *core.RequestEvent) error {
fieldResolver := search.NewSimpleFieldResolver(logFilterFields...)
result, err := search.NewProvider(fieldResolver).
Query(e.App.AuxModelQuery(&core.Log{})).
ParseAndExec(e.Request.URL.Query().Encode(), &[]*core.Log{})
if err != nil {
return e.BadRequestError("", err)
}
return e.JSON(http.StatusOK, result)
}
func logsStats(e *core.RequestEvent) error {
fieldResolver := search.NewSimpleFieldResolver(logFilterFields...)
filter := e.Request.URL.Query().Get(search.FilterQueryParam)
var expr dbx.Expression
if filter != "" {
var err error
expr, err = search.FilterData(filter).BuildExpr(fieldResolver)
if err != nil {
return e.BadRequestError("Invalid filter format.", err)
}
}
stats, err := e.App.LogsStats(expr)
if err != nil {
return e.BadRequestError("Failed to generate logs stats.", err)
}
return e.JSON(http.StatusOK, stats)
}
func logsView(e *core.RequestEvent) error {
id := e.Request.PathValue("id")
if id == "" {
return e.NotFoundError("", nil)
}
log, err := e.App.FindLogById(id)
if err != nil || log == nil {
return e.NotFoundError("", err)
}
return e.JSON(http.StatusOK, log)
}

212
apis/logs_test.go Normal file
View file

@ -0,0 +1,212 @@
package apis_test
import (
"net/http"
"testing"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tests"
)
func TestLogsList(t *testing.T) {
t.Parallel()
scenarios := []tests.ApiScenario{
{
Name: "unauthorized",
Method: http.MethodGet,
URL: "/api/logs",
ExpectedStatus: 401,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "authorized as regular user",
Method: http.MethodGet,
URL: "/api/logs",
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
},
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "authorized as superuser",
Method: http.MethodGet,
URL: "/api/logs",
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
},
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
if err := tests.StubLogsData(app); err != nil {
t.Fatal(err)
}
},
ExpectedStatus: 200,
ExpectedContent: []string{
`"page":1`,
`"perPage":30`,
`"totalItems":2`,
`"items":[{`,
`"id":"873f2133-9f38-44fb-bf82-c8f53b310d91"`,
`"id":"f2133873-44fb-9f38-bf82-c918f53b310d"`,
},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "authorized as superuser + filter",
Method: http.MethodGet,
URL: "/api/logs?filter=data.status>200",
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
},
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
if err := tests.StubLogsData(app); err != nil {
t.Fatal(err)
}
},
ExpectedStatus: 200,
ExpectedContent: []string{
`"page":1`,
`"perPage":30`,
`"totalItems":1`,
`"items":[{`,
`"id":"f2133873-44fb-9f38-bf82-c918f53b310d"`,
},
ExpectedEvents: map[string]int{"*": 0},
},
}
for _, scenario := range scenarios {
scenario.Test(t)
}
}
func TestLogView(t *testing.T) {
t.Parallel()
scenarios := []tests.ApiScenario{
{
Name: "unauthorized",
Method: http.MethodGet,
URL: "/api/logs/873f2133-9f38-44fb-bf82-c8f53b310d91",
ExpectedStatus: 401,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "authorized as regular user",
Method: http.MethodGet,
URL: "/api/logs/873f2133-9f38-44fb-bf82-c8f53b310d91",
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
},
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "authorized as superuser (nonexisting request log)",
Method: http.MethodGet,
URL: "/api/logs/missing1-9f38-44fb-bf82-c8f53b310d91",
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
},
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
if err := tests.StubLogsData(app); err != nil {
t.Fatal(err)
}
},
ExpectedStatus: 404,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "authorized as superuser (existing request log)",
Method: http.MethodGet,
URL: "/api/logs/873f2133-9f38-44fb-bf82-c8f53b310d91",
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
},
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
if err := tests.StubLogsData(app); err != nil {
t.Fatal(err)
}
},
ExpectedStatus: 200,
ExpectedContent: []string{
`"id":"873f2133-9f38-44fb-bf82-c8f53b310d91"`,
},
ExpectedEvents: map[string]int{"*": 0},
},
}
for _, scenario := range scenarios {
scenario.Test(t)
}
}
func TestLogsStats(t *testing.T) {
t.Parallel()
scenarios := []tests.ApiScenario{
{
Name: "unauthorized",
Method: http.MethodGet,
URL: "/api/logs/stats",
ExpectedStatus: 401,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "authorized as regular user",
Method: http.MethodGet,
URL: "/api/logs/stats",
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
},
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "authorized as superuser",
Method: http.MethodGet,
URL: "/api/logs/stats",
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
},
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
if err := tests.StubLogsData(app); err != nil {
t.Fatal(err)
}
},
ExpectedStatus: 200,
ExpectedContent: []string{
`[{"date":"2022-05-01 10:00:00.000Z","total":1},{"date":"2022-05-02 10:00:00.000Z","total":1}]`,
},
},
{
Name: "authorized as superuser + filter",
Method: http.MethodGet,
URL: "/api/logs/stats?filter=data.status>200",
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
},
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
if err := tests.StubLogsData(app); err != nil {
t.Fatal(err)
}
},
ExpectedStatus: 200,
ExpectedContent: []string{
`[{"date":"2022-05-02 10:00:00.000Z","total":1}]`,
},
},
}
for _, scenario := range scenarios {
scenario.Test(t)
}
}

444
apis/middlewares.go Normal file
View file

@ -0,0 +1,444 @@
package apis
import (
"errors"
"fmt"
"log/slog"
"net/http"
"net/url"
"runtime"
"slices"
"strings"
"time"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tools/hook"
"github.com/pocketbase/pocketbase/tools/list"
"github.com/pocketbase/pocketbase/tools/router"
"github.com/pocketbase/pocketbase/tools/routine"
"github.com/spf13/cast"
)
// Common request event store keys used by the middlewares and api handlers.
const (
RequestEventKeyLogMeta = "pbLogMeta" // extra data to store with the request activity log
requestEventKeyExecStart = "__execStart" // the value must be time.Time
requestEventKeySkipSuccessActivityLog = "__skipSuccessActivityLogger" // the value must be bool
)
const (
DefaultWWWRedirectMiddlewarePriority = -99999
DefaultWWWRedirectMiddlewareId = "pbWWWRedirect"
DefaultActivityLoggerMiddlewarePriority = DefaultRateLimitMiddlewarePriority - 40
DefaultActivityLoggerMiddlewareId = "pbActivityLogger"
DefaultSkipSuccessActivityLogMiddlewareId = "pbSkipSuccessActivityLog"
DefaultEnableAuthIdActivityLog = "pbEnableAuthIdActivityLog"
DefaultPanicRecoverMiddlewarePriority = DefaultRateLimitMiddlewarePriority - 30
DefaultPanicRecoverMiddlewareId = "pbPanicRecover"
DefaultLoadAuthTokenMiddlewarePriority = DefaultRateLimitMiddlewarePriority - 20
DefaultLoadAuthTokenMiddlewareId = "pbLoadAuthToken"
DefaultSecurityHeadersMiddlewarePriority = DefaultRateLimitMiddlewarePriority - 10
DefaultSecurityHeadersMiddlewareId = "pbSecurityHeaders"
DefaultRequireGuestOnlyMiddlewareId = "pbRequireGuestOnly"
DefaultRequireAuthMiddlewareId = "pbRequireAuth"
DefaultRequireSuperuserAuthMiddlewareId = "pbRequireSuperuserAuth"
DefaultRequireSuperuserOrOwnerAuthMiddlewareId = "pbRequireSuperuserOrOwnerAuth"
DefaultRequireSameCollectionContextAuthMiddlewareId = "pbRequireSameCollectionContextAuth"
)
// RequireGuestOnly middleware requires a request to NOT have a valid
// Authorization header.
//
// This middleware is the opposite of [apis.RequireAuth()].
func RequireGuestOnly() *hook.Handler[*core.RequestEvent] {
return &hook.Handler[*core.RequestEvent]{
Id: DefaultRequireGuestOnlyMiddlewareId,
Func: func(e *core.RequestEvent) error {
if e.Auth != nil {
return router.NewBadRequestError("The request can be accessed only by guests.", nil)
}
return e.Next()
},
}
}
// RequireAuth middleware requires a request to have a valid record Authorization header.
//
// The auth record could be from any collection.
// You can further filter the allowed record auth collections by specifying their names.
//
// Example:
//
// apis.RequireAuth() // any auth collection
// apis.RequireAuth("_superusers", "users") // only the listed auth collections
func RequireAuth(optCollectionNames ...string) *hook.Handler[*core.RequestEvent] {
return &hook.Handler[*core.RequestEvent]{
Id: DefaultRequireAuthMiddlewareId,
Func: requireAuth(optCollectionNames...),
}
}
func requireAuth(optCollectionNames ...string) func(*core.RequestEvent) error {
return func(e *core.RequestEvent) error {
if e.Auth == nil {
return e.UnauthorizedError("The request requires valid record authorization token.", nil)
}
// check record collection name
if len(optCollectionNames) > 0 && !slices.Contains(optCollectionNames, e.Auth.Collection().Name) {
return e.ForbiddenError("The authorized record is not allowed to perform this action.", nil)
}
return e.Next()
}
}
// RequireSuperuserAuth middleware requires a request to have
// a valid superuser Authorization header.
func RequireSuperuserAuth() *hook.Handler[*core.RequestEvent] {
return &hook.Handler[*core.RequestEvent]{
Id: DefaultRequireSuperuserAuthMiddlewareId,
Func: requireAuth(core.CollectionNameSuperusers),
}
}
// RequireSuperuserOrOwnerAuth middleware requires a request to have
// a valid superuser or regular record owner Authorization header set.
//
// This middleware is similar to [apis.RequireAuth()] but
// for the auth record token expects to have the same id as the path
// parameter ownerIdPathParam (default to "id" if empty).
func RequireSuperuserOrOwnerAuth(ownerIdPathParam string) *hook.Handler[*core.RequestEvent] {
return &hook.Handler[*core.RequestEvent]{
Id: DefaultRequireSuperuserOrOwnerAuthMiddlewareId,
Func: func(e *core.RequestEvent) error {
if e.Auth == nil {
return e.UnauthorizedError("The request requires superuser or record authorization token.", nil)
}
if e.Auth.IsSuperuser() {
return e.Next()
}
if ownerIdPathParam == "" {
ownerIdPathParam = "id"
}
ownerId := e.Request.PathValue(ownerIdPathParam)
// note: it is considered "safe" to compare only the record id
// since the auth record ids are treated as unique across all auth collections
if e.Auth.Id != ownerId {
return e.ForbiddenError("You are not allowed to perform this request.", nil)
}
return e.Next()
},
}
}
// RequireSameCollectionContextAuth middleware requires a request to have
// a valid record Authorization header and the auth record's collection to
// match the one from the route path parameter (default to "collection" if collectionParam is empty).
func RequireSameCollectionContextAuth(collectionPathParam string) *hook.Handler[*core.RequestEvent] {
return &hook.Handler[*core.RequestEvent]{
Id: DefaultRequireSameCollectionContextAuthMiddlewareId,
Func: func(e *core.RequestEvent) error {
if e.Auth == nil {
return e.UnauthorizedError("The request requires valid record authorization token.", nil)
}
if collectionPathParam == "" {
collectionPathParam = "collection"
}
collection, _ := e.App.FindCachedCollectionByNameOrId(e.Request.PathValue(collectionPathParam))
if collection == nil || e.Auth.Collection().Id != collection.Id {
return e.ForbiddenError(fmt.Sprintf("The request requires auth record from %s collection.", e.Auth.Collection().Name), nil)
}
return e.Next()
},
}
}
// loadAuthToken attempts to load the auth context based on the "Authorization: TOKEN" header value.
//
// This middleware does nothing in case of:
// - missing, invalid or expired token
// - e.Auth is already loaded by another middleware
//
// This middleware is registered by default for all routes.
//
// Note: We don't throw an error on invalid or expired token to allow
// users to extend with their own custom handling in external middleware(s).
func loadAuthToken() *hook.Handler[*core.RequestEvent] {
return &hook.Handler[*core.RequestEvent]{
Id: DefaultLoadAuthTokenMiddlewareId,
Priority: DefaultLoadAuthTokenMiddlewarePriority,
Func: func(e *core.RequestEvent) error {
// already loaded by another middleware
if e.Auth != nil {
return e.Next()
}
token := getAuthTokenFromRequest(e)
if token == "" {
return e.Next()
}
record, err := e.App.FindAuthRecordByToken(token, core.TokenTypeAuth)
if err != nil {
e.App.Logger().Debug("loadAuthToken failure", "error", err)
} else if record != nil {
e.Auth = record
}
return e.Next()
},
}
}
func getAuthTokenFromRequest(e *core.RequestEvent) string {
token := e.Request.Header.Get("Authorization")
if token != "" {
// the schema prefix is not required and it is only for
// compatibility with the defaults of some HTTP clients
token = strings.TrimPrefix(token, "Bearer ")
}
return token
}
// wwwRedirect performs www->non-www redirect(s) if the request host
// matches with one of the values in redirectHosts.
//
// This middleware is registered by default on Serve for all routes.
func wwwRedirect(redirectHosts []string) *hook.Handler[*core.RequestEvent] {
return &hook.Handler[*core.RequestEvent]{
Id: DefaultWWWRedirectMiddlewareId,
Priority: DefaultWWWRedirectMiddlewarePriority,
Func: func(e *core.RequestEvent) error {
host := e.Request.Host
if strings.HasPrefix(host, "www.") && list.ExistInSlice(host, redirectHosts) {
// note: e.Request.URL.Scheme would be empty
schema := "http://"
if e.IsTLS() {
schema = "https://"
}
return e.Redirect(
http.StatusTemporaryRedirect,
(schema + host[4:] + e.Request.RequestURI),
)
}
return e.Next()
},
}
}
// panicRecover returns a default panic-recover handler.
func panicRecover() *hook.Handler[*core.RequestEvent] {
return &hook.Handler[*core.RequestEvent]{
Id: DefaultPanicRecoverMiddlewareId,
Priority: DefaultPanicRecoverMiddlewarePriority,
Func: func(e *core.RequestEvent) (err error) {
// panic-recover
defer func() {
recoverResult := recover()
if recoverResult == nil {
return
}
recoverErr, ok := recoverResult.(error)
if !ok {
recoverErr = fmt.Errorf("%v", recoverResult)
} else if errors.Is(recoverErr, http.ErrAbortHandler) {
// don't recover ErrAbortHandler so the response to the client can be aborted
panic(recoverResult)
}
stack := make([]byte, 2<<10) // 2 KB
length := runtime.Stack(stack, true)
err = e.InternalServerError("", fmt.Errorf("[PANIC RECOVER] %w %s", recoverErr, stack[:length]))
}()
err = e.Next()
return err
},
}
}
// securityHeaders middleware adds common security headers to the response.
//
// This middleware is registered by default for all routes.
func securityHeaders() *hook.Handler[*core.RequestEvent] {
return &hook.Handler[*core.RequestEvent]{
Id: DefaultSecurityHeadersMiddlewareId,
Priority: DefaultSecurityHeadersMiddlewarePriority,
Func: func(e *core.RequestEvent) error {
e.Response.Header().Set("X-XSS-Protection", "1; mode=block")
e.Response.Header().Set("X-Content-Type-Options", "nosniff")
e.Response.Header().Set("X-Frame-Options", "SAMEORIGIN")
// @todo consider a default HSTS?
// (see also https://webkit.org/blog/8146/protecting-against-hsts-abuse/)
return e.Next()
},
}
}
// SkipSuccessActivityLog is a helper middleware that instructs the global
// activity logger to log only requests that have failed/returned an error.
func SkipSuccessActivityLog() *hook.Handler[*core.RequestEvent] {
return &hook.Handler[*core.RequestEvent]{
Id: DefaultSkipSuccessActivityLogMiddlewareId,
Func: func(e *core.RequestEvent) error {
e.Set(requestEventKeySkipSuccessActivityLog, true)
return e.Next()
},
}
}
// activityLogger middleware takes care to save the request information
// into the logs database.
//
// This middleware is registered by default for all routes.
//
// The middleware does nothing if the app logs retention period is zero
// (aka. app.Settings().Logs.MaxDays = 0).
//
// Users can attach the [apis.SkipSuccessActivityLog()] middleware if
// you want to log only the failed requests.
func activityLogger() *hook.Handler[*core.RequestEvent] {
return &hook.Handler[*core.RequestEvent]{
Id: DefaultActivityLoggerMiddlewareId,
Priority: DefaultActivityLoggerMiddlewarePriority,
Func: func(e *core.RequestEvent) error {
e.Set(requestEventKeyExecStart, time.Now())
err := e.Next()
logRequest(e, err)
return err
},
}
}
func logRequest(event *core.RequestEvent, err error) {
// no logs retention
if event.App.Settings().Logs.MaxDays == 0 {
return
}
// the non-error route has explicitly disabled the activity logger
if err == nil && event.Get(requestEventKeySkipSuccessActivityLog) != nil {
return
}
attrs := make([]any, 0, 15)
attrs = append(attrs, slog.String("type", "request"))
started := cast.ToTime(event.Get(requestEventKeyExecStart))
if !started.IsZero() {
attrs = append(attrs, slog.Float64("execTime", float64(time.Since(started))/float64(time.Millisecond)))
}
if meta := event.Get(RequestEventKeyLogMeta); meta != nil {
attrs = append(attrs, slog.Any("meta", meta))
}
status := event.Status()
method := cutStr(strings.ToUpper(event.Request.Method), 50)
requestUri := cutStr(event.Request.URL.RequestURI(), 3000)
// parse the request error
if err != nil {
apiErr, isPlainApiError := err.(*router.ApiError)
if isPlainApiError || errors.As(err, &apiErr) {
// the status header wasn't written yet
if status == 0 {
status = apiErr.Status
}
var errMsg string
if isPlainApiError {
errMsg = apiErr.Message
} else {
// wrapped ApiError -> add the full serialized version
// of the original error since it could contain more information
errMsg = err.Error()
}
attrs = append(
attrs,
slog.String("error", errMsg),
slog.Any("details", apiErr.RawData()),
)
} else {
attrs = append(attrs, slog.String("error", err.Error()))
}
}
attrs = append(
attrs,
slog.String("url", requestUri),
slog.String("method", method),
slog.Int("status", status),
slog.String("referer", cutStr(event.Request.Referer(), 2000)),
slog.String("userAgent", cutStr(event.Request.UserAgent(), 2000)),
)
if event.Auth != nil {
attrs = append(attrs, slog.String("auth", event.Auth.Collection().Name))
if event.App.Settings().Logs.LogAuthId {
attrs = append(attrs, slog.String("authId", event.Auth.Id))
}
} else {
attrs = append(attrs, slog.String("auth", ""))
}
if event.App.Settings().Logs.LogIP {
attrs = append(
attrs,
slog.String("userIP", event.RealIP()),
slog.String("remoteIP", event.RemoteIP()),
)
}
// don't block on logs write
routine.FireAndForget(func() {
message := method + " "
if escaped, unescapeErr := url.PathUnescape(requestUri); unescapeErr == nil {
message += escaped
} else {
message += requestUri
}
if err != nil {
event.App.Logger().Error(message, attrs...)
} else {
event.App.Logger().Info(message, attrs...)
}
})
}
func cutStr(str string, max int) string {
if len(str) > max {
return str[:max] + "..."
}
return str
}

View file

@ -0,0 +1,120 @@
package apis
import (
"io"
"net/http"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tools/hook"
"github.com/pocketbase/pocketbase/tools/router"
)
var ErrRequestEntityTooLarge = router.NewApiError(http.StatusRequestEntityTooLarge, "Request entity too large", nil)
const DefaultMaxBodySize int64 = 32 << 20
const (
DefaultBodyLimitMiddlewareId = "pbBodyLimit"
DefaultBodyLimitMiddlewarePriority = DefaultRateLimitMiddlewarePriority + 10
)
// BodyLimit returns a middleware handler that changes the default request body size limit.
//
// If limitBytes <= 0, no limit is applied.
//
// Otherwise, if the request body size exceeds the configured limitBytes,
// it sends 413 error response.
func BodyLimit(limitBytes int64) *hook.Handler[*core.RequestEvent] {
return &hook.Handler[*core.RequestEvent]{
Id: DefaultBodyLimitMiddlewareId,
Priority: DefaultBodyLimitMiddlewarePriority,
Func: func(e *core.RequestEvent) error {
err := applyBodyLimit(e, limitBytes)
if err != nil {
return err
}
return e.Next()
},
}
}
func dynamicCollectionBodyLimit(collectionPathParam string) *hook.Handler[*core.RequestEvent] {
if collectionPathParam == "" {
collectionPathParam = "collection"
}
return &hook.Handler[*core.RequestEvent]{
Id: DefaultBodyLimitMiddlewareId,
Priority: DefaultBodyLimitMiddlewarePriority,
Func: func(e *core.RequestEvent) error {
collection, err := e.App.FindCachedCollectionByNameOrId(e.Request.PathValue(collectionPathParam))
if err != nil {
return e.NotFoundError("Missing or invalid collection context.", err)
}
limitBytes := DefaultMaxBodySize
if !collection.IsView() {
for _, f := range collection.Fields {
if calc, ok := f.(core.MaxBodySizeCalculator); ok {
limitBytes += calc.CalculateMaxBodySize()
}
}
}
err = applyBodyLimit(e, limitBytes)
if err != nil {
return err
}
return e.Next()
},
}
}
func applyBodyLimit(e *core.RequestEvent, limitBytes int64) error {
// no limit
if limitBytes <= 0 {
return nil
}
// optimistically check the submitted request content length
if e.Request.ContentLength > limitBytes {
return ErrRequestEntityTooLarge
}
// replace the request body
//
// note: we don't use sync.Pool since the size of the elements could vary too much
// and it might not be efficient (see https://github.com/golang/go/issues/23199)
e.Request.Body = &limitedReader{ReadCloser: e.Request.Body, limit: limitBytes}
return nil
}
type limitedReader struct {
io.ReadCloser
limit int64
totalRead int64
}
func (r *limitedReader) Read(b []byte) (int, error) {
n, err := r.ReadCloser.Read(b)
if err != nil {
return n, err
}
r.totalRead += int64(n)
if r.totalRead > r.limit {
return n, ErrRequestEntityTooLarge
}
return n, nil
}
func (r *limitedReader) Reread() {
rr, ok := r.ReadCloser.(router.Rereader)
if ok {
rr.Reread()
}
}

View file

@ -0,0 +1,60 @@
package apis_test
import (
"bytes"
"fmt"
"net/http/httptest"
"testing"
"github.com/pocketbase/pocketbase/apis"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tests"
)
func TestBodyLimitMiddleware(t *testing.T) {
app, _ := tests.NewTestApp()
defer app.Cleanup()
pbRouter, err := apis.NewRouter(app)
if err != nil {
t.Fatal(err)
}
pbRouter.POST("/a", func(e *core.RequestEvent) error {
return e.String(200, "a")
}) // default global BodyLimit check
pbRouter.POST("/b", func(e *core.RequestEvent) error {
return e.String(200, "b")
}).Bind(apis.BodyLimit(20))
mux, err := pbRouter.BuildMux()
if err != nil {
t.Fatal(err)
}
scenarios := []struct {
url string
size int64
expectedStatus int
}{
{"/a", 21, 200},
{"/a", apis.DefaultMaxBodySize + 1, 413},
{"/b", 20, 200},
{"/b", 21, 413},
}
for _, s := range scenarios {
t.Run(fmt.Sprintf("%s_%d", s.url, s.size), func(t *testing.T) {
rec := httptest.NewRecorder()
req := httptest.NewRequest("POST", s.url, bytes.NewReader(make([]byte, s.size)))
mux.ServeHTTP(rec, req)
result := rec.Result()
defer result.Body.Close()
if result.StatusCode != s.expectedStatus {
t.Fatalf("Expected response status %d, got %d", s.expectedStatus, result.StatusCode)
}
})
}
}

327
apis/middlewares_cors.go Normal file
View file

@ -0,0 +1,327 @@
package apis
// -------------------------------------------------------------------
// This middleware is ported from echo/middleware to minimize the breaking
// changes and differences in the API behavior from earlier PocketBase versions
// (https://github.com/labstack/echo/blob/ec5b858dab6105ab4c3ed2627d1ebdfb6ae1ecb8/middleware/cors.go).
//
// I doubt that this would matter for most cases, but the only major difference
// is that for non-supported routes this middleware doesn't return 405 and fallbacks
// to the default catch-all PocketBase route (aka. returns 404) to avoid
// the extra overhead of further hijacking and wrapping the Go default mux
// (https://github.com/golang/go/issues/65648#issuecomment-1955328807).
// -------------------------------------------------------------------
import (
"log"
"net/http"
"regexp"
"strconv"
"strings"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tools/hook"
)
const (
DefaultCorsMiddlewareId = "pbCors"
DefaultCorsMiddlewarePriority = DefaultActivityLoggerMiddlewarePriority - 1 // before the activity logger and rate limit so that OPTIONS preflight requests are not counted
)
// CORSConfig defines the config for CORS middleware.
type CORSConfig struct {
// AllowOrigins determines the value of the Access-Control-Allow-Origin
// response header. This header defines a list of origins that may access the
// resource. The wildcard characters '*' and '?' are supported and are
// converted to regex fragments '.*' and '.' accordingly.
//
// Security: use extreme caution when handling the origin, and carefully
// validate any logic. Remember that attackers may register hostile domain names.
// See https://blog.portswigger.net/2016/10/exploiting-cors-misconfigurations-for.html
//
// Optional. Default value []string{"*"}.
//
// See also: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Origin
AllowOrigins []string
// AllowOriginFunc is a custom function to validate the origin. It takes the
// origin as an argument and returns true if allowed or false otherwise. If
// an error is returned, it is returned by the handler. If this option is
// set, AllowOrigins is ignored.
//
// Security: use extreme caution when handling the origin, and carefully
// validate any logic. Remember that attackers may register hostile domain names.
// See https://blog.portswigger.net/2016/10/exploiting-cors-misconfigurations-for.html
//
// Optional.
AllowOriginFunc func(origin string) (bool, error)
// AllowMethods determines the value of the Access-Control-Allow-Methods
// response header. This header specified the list of methods allowed when
// accessing the resource. This is used in response to a preflight request.
//
// Optional. Default value DefaultCORSConfig.AllowMethods.
//
// See also: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Methods
AllowMethods []string
// AllowHeaders determines the value of the Access-Control-Allow-Headers
// response header. This header is used in response to a preflight request to
// indicate which HTTP headers can be used when making the actual request.
//
// Optional. Default value []string{}.
//
// See also: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Headers
AllowHeaders []string
// AllowCredentials determines the value of the
// Access-Control-Allow-Credentials response header. This header indicates
// whether or not the response to the request can be exposed when the
// credentials mode (Request.credentials) is true. When used as part of a
// response to a preflight request, this indicates whether or not the actual
// request can be made using credentials. See also
// [MDN: Access-Control-Allow-Credentials].
//
// Optional. Default value false, in which case the header is not set.
//
// Security: avoid using `AllowCredentials = true` with `AllowOrigins = *`.
// See "Exploiting CORS misconfigurations for Bitcoins and bounties",
// https://blog.portswigger.net/2016/10/exploiting-cors-misconfigurations-for.html
//
// See also: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Credentials
AllowCredentials bool
// UnsafeWildcardOriginWithAllowCredentials UNSAFE/INSECURE: allows wildcard '*' origin to be used with AllowCredentials
// flag. In that case we consider any origin allowed and send it back to the client with `Access-Control-Allow-Origin` header.
//
// This is INSECURE and potentially leads to [cross-origin](https://portswigger.net/research/exploiting-cors-misconfigurations-for-bitcoins-and-bounties)
// attacks. See: https://github.com/labstack/echo/issues/2400 for discussion on the subject.
//
// Optional. Default value is false.
UnsafeWildcardOriginWithAllowCredentials bool
// ExposeHeaders determines the value of Access-Control-Expose-Headers, which
// defines a list of headers that clients are allowed to access.
//
// Optional. Default value []string{}, in which case the header is not set.
//
// See also: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Expose-Header
ExposeHeaders []string
// MaxAge determines the value of the Access-Control-Max-Age response header.
// This header indicates how long (in seconds) the results of a preflight
// request can be cached.
// The header is set only if MaxAge != 0, negative value sends "0" which instructs browsers not to cache that response.
//
// Optional. Default value 0 - meaning header is not sent.
//
// See also: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Max-Age
MaxAge int
}
// DefaultCORSConfig is the default CORS middleware config.
var DefaultCORSConfig = CORSConfig{
AllowOrigins: []string{"*"},
AllowMethods: []string{http.MethodGet, http.MethodHead, http.MethodPut, http.MethodPatch, http.MethodPost, http.MethodDelete},
}
// CORS returns a CORS middleware.
func CORS(config CORSConfig) *hook.Handler[*core.RequestEvent] {
// Defaults
if len(config.AllowOrigins) == 0 {
config.AllowOrigins = DefaultCORSConfig.AllowOrigins
}
if len(config.AllowMethods) == 0 {
config.AllowMethods = DefaultCORSConfig.AllowMethods
}
allowOriginPatterns := make([]*regexp.Regexp, 0, len(config.AllowOrigins))
for _, origin := range config.AllowOrigins {
if origin == "*" {
continue // "*" is handled differently and does not need regexp
}
pattern := regexp.QuoteMeta(origin)
pattern = strings.ReplaceAll(pattern, "\\*", ".*")
pattern = strings.ReplaceAll(pattern, "\\?", ".")
pattern = "^" + pattern + "$"
re, err := regexp.Compile(pattern)
if err != nil {
// This is to preserve previous behaviour - invalid patterns were just ignored.
// If we would turn this to panic, users with invalid patterns
// would have applications crashing in production due unrecovered panic.
log.Println("invalid AllowOrigins pattern", origin)
continue
}
allowOriginPatterns = append(allowOriginPatterns, re)
}
allowMethods := strings.Join(config.AllowMethods, ",")
allowHeaders := strings.Join(config.AllowHeaders, ",")
exposeHeaders := strings.Join(config.ExposeHeaders, ",")
maxAge := "0"
if config.MaxAge > 0 {
maxAge = strconv.Itoa(config.MaxAge)
}
return &hook.Handler[*core.RequestEvent]{
Id: DefaultCorsMiddlewareId,
Priority: DefaultCorsMiddlewarePriority,
Func: func(e *core.RequestEvent) error {
req := e.Request
res := e.Response
origin := req.Header.Get("Origin")
allowOrigin := ""
res.Header().Add("Vary", "Origin")
// Preflight request is an OPTIONS request, using three HTTP request headers: Access-Control-Request-Method,
// Access-Control-Request-Headers, and the Origin header. See: https://developer.mozilla.org/en-US/docs/Glossary/Preflight_request
// For simplicity we just consider method type and later `Origin` header.
preflight := req.Method == http.MethodOptions
// No Origin provided. This is (probably) not request from actual browser - proceed executing middleware chain
if origin == "" {
if !preflight {
return e.Next()
}
return e.NoContent(http.StatusNoContent)
}
if config.AllowOriginFunc != nil {
allowed, err := config.AllowOriginFunc(origin)
if err != nil {
return err
}
if allowed {
allowOrigin = origin
}
} else {
// Check allowed origins
for _, o := range config.AllowOrigins {
if o == "*" && config.AllowCredentials && config.UnsafeWildcardOriginWithAllowCredentials {
allowOrigin = origin
break
}
if o == "*" || o == origin {
allowOrigin = o
break
}
if matchSubdomain(origin, o) {
allowOrigin = origin
break
}
}
checkPatterns := false
if allowOrigin == "" {
// to avoid regex cost by invalid (long) domains (253 is domain name max limit)
if len(origin) <= (253+3+5) && strings.Contains(origin, "://") {
checkPatterns = true
}
}
if checkPatterns {
for _, re := range allowOriginPatterns {
if match := re.MatchString(origin); match {
allowOrigin = origin
break
}
}
}
}
// Origin not allowed
if allowOrigin == "" {
if !preflight {
return e.Next()
}
return e.NoContent(http.StatusNoContent)
}
res.Header().Set("Access-Control-Allow-Origin", allowOrigin)
if config.AllowCredentials {
res.Header().Set("Access-Control-Allow-Credentials", "true")
}
// Simple request
if !preflight {
if exposeHeaders != "" {
res.Header().Set("Access-Control-Expose-Headers", exposeHeaders)
}
return e.Next()
}
// Preflight request
res.Header().Add("Vary", "Access-Control-Request-Method")
res.Header().Add("Vary", "Access-Control-Request-Headers")
res.Header().Set("Access-Control-Allow-Methods", allowMethods)
if allowHeaders != "" {
res.Header().Set("Access-Control-Allow-Headers", allowHeaders)
} else {
h := req.Header.Get("Access-Control-Request-Headers")
if h != "" {
res.Header().Set("Access-Control-Allow-Headers", h)
}
}
if config.MaxAge != 0 {
res.Header().Set("Access-Control-Max-Age", maxAge)
}
return e.NoContent(http.StatusNoContent)
},
}
}
func matchScheme(domain, pattern string) bool {
didx := strings.Index(domain, ":")
pidx := strings.Index(pattern, ":")
return didx != -1 && pidx != -1 && domain[:didx] == pattern[:pidx]
}
// matchSubdomain compares authority with wildcard
func matchSubdomain(domain, pattern string) bool {
if !matchScheme(domain, pattern) {
return false
}
didx := strings.Index(domain, "://")
pidx := strings.Index(pattern, "://")
if didx == -1 || pidx == -1 {
return false
}
domAuth := domain[didx+3:]
// to avoid long loop by invalid long domain
if len(domAuth) > 253 {
return false
}
patAuth := pattern[pidx+3:]
domComp := strings.Split(domAuth, ".")
patComp := strings.Split(patAuth, ".")
for i := len(domComp)/2 - 1; i >= 0; i-- {
opp := len(domComp) - 1 - i
domComp[i], domComp[opp] = domComp[opp], domComp[i]
}
for i := len(patComp)/2 - 1; i >= 0; i-- {
opp := len(patComp) - 1 - i
patComp[i], patComp[opp] = patComp[opp], patComp[i]
}
for i, v := range domComp {
if len(patComp) <= i {
return false
}
p := patComp[i]
if p == "*" {
return true
}
if p != v {
return false
}
}
return false
}

247
apis/middlewares_gzip.go Normal file
View file

@ -0,0 +1,247 @@
package apis
// -------------------------------------------------------------------
// This middleware is ported from echo/middleware to minimize the breaking
// changes and differences in the API behavior from earlier PocketBase versions
// (https://github.com/labstack/echo/blob/ec5b858dab6105ab4c3ed2627d1ebdfb6ae1ecb8/middleware/compress.go).
// -------------------------------------------------------------------
import (
"bufio"
"bytes"
"compress/gzip"
"errors"
"io"
"net"
"net/http"
"strings"
"sync"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tools/hook"
"github.com/pocketbase/pocketbase/tools/router"
)
const (
gzipScheme = "gzip"
)
const (
DefaultGzipMiddlewareId = "pbGzip"
)
// GzipConfig defines the config for Gzip middleware.
type GzipConfig struct {
// Gzip compression level.
// Optional. Default value -1.
Level int
// Length threshold before gzip compression is applied.
// Optional. Default value 0.
//
// Most of the time you will not need to change the default. Compressing
// a short response might increase the transmitted data because of the
// gzip format overhead. Compressing the response will also consume CPU
// and time on the server and the client (for decompressing). Depending on
// your use case such a threshold might be useful.
//
// See also:
// https://webmasters.stackexchange.com/questions/31750/what-is-recommended-minimum-object-size-for-gzip-performance-benefits
MinLength int
}
// Gzip returns a middleware which compresses HTTP response using Gzip compression scheme.
func Gzip() *hook.Handler[*core.RequestEvent] {
return GzipWithConfig(GzipConfig{})
}
// GzipWithConfig returns a middleware which compresses HTTP response using gzip compression scheme.
func GzipWithConfig(config GzipConfig) *hook.Handler[*core.RequestEvent] {
if config.Level < -2 || config.Level > 9 { // these are consts: gzip.HuffmanOnly and gzip.BestCompression
panic(errors.New("invalid gzip level"))
}
if config.Level == 0 {
config.Level = -1
}
if config.MinLength < 0 {
config.MinLength = 0
}
pool := sync.Pool{
New: func() interface{} {
w, err := gzip.NewWriterLevel(io.Discard, config.Level)
if err != nil {
return err
}
return w
},
}
bpool := sync.Pool{
New: func() interface{} {
b := &bytes.Buffer{}
return b
},
}
return &hook.Handler[*core.RequestEvent]{
Id: DefaultGzipMiddlewareId,
Func: func(e *core.RequestEvent) error {
e.Response.Header().Add("Vary", "Accept-Encoding")
if strings.Contains(e.Request.Header.Get("Accept-Encoding"), gzipScheme) {
w, ok := pool.Get().(*gzip.Writer)
if !ok {
return e.InternalServerError("", errors.New("failed to get gzip.Writer"))
}
rw := e.Response
w.Reset(rw)
buf := bpool.Get().(*bytes.Buffer)
buf.Reset()
grw := &gzipResponseWriter{Writer: w, ResponseWriter: rw, minLength: config.MinLength, buffer: buf}
defer func() {
// There are different reasons for cases when we have not yet written response to the client and now need to do so.
// a) handler response had only response code and no response body (ala 404 or redirects etc). Response code need to be written now.
// b) body is shorter than our minimum length threshold and being buffered currently and needs to be written
if !grw.wroteBody {
if rw.Header().Get("Content-Encoding") == gzipScheme {
rw.Header().Del("Content-Encoding")
}
if grw.wroteHeader {
rw.WriteHeader(grw.code)
}
// We have to reset response to it's pristine state when
// nothing is written to body or error is returned.
// See issue echo#424, echo#407.
e.Response = rw
w.Reset(io.Discard)
} else if !grw.minLengthExceeded {
// Write uncompressed response
e.Response = rw
if grw.wroteHeader {
rw.WriteHeader(grw.code)
}
grw.buffer.WriteTo(rw)
w.Reset(io.Discard)
}
w.Close()
bpool.Put(buf)
pool.Put(w)
}()
e.Response = grw
}
return e.Next()
},
}
}
type gzipResponseWriter struct {
http.ResponseWriter
io.Writer
buffer *bytes.Buffer
minLength int
code int
wroteHeader bool
wroteBody bool
minLengthExceeded bool
}
func (w *gzipResponseWriter) WriteHeader(code int) {
w.Header().Del("Content-Length") // Issue echo#444
w.wroteHeader = true
// Delay writing of the header until we know if we'll actually compress the response
w.code = code
}
func (w *gzipResponseWriter) Write(b []byte) (int, error) {
if w.Header().Get("Content-Type") == "" {
w.Header().Set("Content-Type", http.DetectContentType(b))
}
w.wroteBody = true
if !w.minLengthExceeded {
n, err := w.buffer.Write(b)
if w.buffer.Len() >= w.minLength {
w.minLengthExceeded = true
// The minimum length is exceeded, add Content-Encoding header and write the header
w.Header().Set("Content-Encoding", gzipScheme)
if w.wroteHeader {
w.ResponseWriter.WriteHeader(w.code)
}
return w.Writer.Write(w.buffer.Bytes())
}
return n, err
}
return w.Writer.Write(b)
}
func (w *gzipResponseWriter) Flush() {
if !w.minLengthExceeded {
// Enforce compression because we will not know how much more data will come
w.minLengthExceeded = true
w.Header().Set("Content-Encoding", gzipScheme)
if w.wroteHeader {
w.ResponseWriter.WriteHeader(w.code)
}
_, _ = w.Writer.Write(w.buffer.Bytes())
}
_ = w.Writer.(*gzip.Writer).Flush()
_ = http.NewResponseController(w.ResponseWriter).Flush()
}
func (w *gzipResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
return http.NewResponseController(w.ResponseWriter).Hijack()
}
func (w *gzipResponseWriter) Push(target string, opts *http.PushOptions) error {
rw := w.ResponseWriter
for {
switch p := rw.(type) {
case http.Pusher:
return p.Push(target, opts)
case router.RWUnwrapper:
rw = p.Unwrap()
default:
return http.ErrNotSupported
}
}
}
// Note: Disable the implementation for now because in case the platform
// supports the sendfile fast-path it won't run gzipResponseWriter.Write,
// preventing compression on the fly.
//
// func (w *gzipResponseWriter) ReadFrom(r io.Reader) (n int64, err error) {
// if w.wroteHeader {
// w.ResponseWriter.WriteHeader(w.code)
// }
// rw := w.ResponseWriter
// for {
// switch rf := rw.(type) {
// case io.ReaderFrom:
// return rf.ReadFrom(r)
// case router.RWUnwrapper:
// rw = rf.Unwrap()
// default:
// return io.Copy(w.ResponseWriter, r)
// }
// }
// }
func (w *gzipResponseWriter) Unwrap() http.ResponseWriter {
return w.ResponseWriter
}

View file

@ -0,0 +1,356 @@
package apis
import (
"sync"
"time"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tools/hook"
"github.com/pocketbase/pocketbase/tools/store"
)
const (
DefaultRateLimitMiddlewareId = "pbRateLimit"
DefaultRateLimitMiddlewarePriority = -1000
)
const (
rateLimitersStoreKey = "__pbRateLimiters__"
rateLimitersCronKey = "__pbRateLimitersCleanup__"
rateLimitersSettingsHookId = "__pbRateLimitersSettingsHook__"
)
// rateLimit defines the global rate limit middleware.
//
// This middleware is registered by default for all routes.
func rateLimit() *hook.Handler[*core.RequestEvent] {
return &hook.Handler[*core.RequestEvent]{
Id: DefaultRateLimitMiddlewareId,
Priority: DefaultRateLimitMiddlewarePriority,
Func: func(e *core.RequestEvent) error {
if skipRateLimit(e) {
return e.Next()
}
rule, ok := e.App.Settings().RateLimits.FindRateLimitRule(
defaultRateLimitLabels(e),
defaultRateLimitAudience(e)...,
)
if ok {
err := checkRateLimit(e, rule.Label+rule.Audience, rule)
if err != nil {
return err
}
}
return e.Next()
},
}
}
// collectionPathRateLimit defines a rate limit middleware for the internal collection handlers.
func collectionPathRateLimit(collectionPathParam string, baseTags ...string) *hook.Handler[*core.RequestEvent] {
if collectionPathParam == "" {
collectionPathParam = "collection"
}
return &hook.Handler[*core.RequestEvent]{
Id: DefaultRateLimitMiddlewareId,
Priority: DefaultRateLimitMiddlewarePriority,
Func: func(e *core.RequestEvent) error {
collection, err := e.App.FindCachedCollectionByNameOrId(e.Request.PathValue(collectionPathParam))
if err != nil {
return e.NotFoundError("Missing or invalid collection context.", err)
}
if err := checkCollectionRateLimit(e, collection, baseTags...); err != nil {
return err
}
return e.Next()
},
}
}
// checkCollectionRateLimit checks whether the current request satisfy the
// rate limit configuration for the specific collection.
//
// Each baseTags entry will be prefixed with the collection name and its wildcard variant.
func checkCollectionRateLimit(e *core.RequestEvent, collection *core.Collection, baseTags ...string) error {
if skipRateLimit(e) {
return nil
}
labels := make([]string, 0, 2+len(baseTags)*2)
rtId := collection.Id + e.Request.Pattern
// add first the primary labels (aka. ["collectionName:action1", "collectionName:action2"])
for _, baseTag := range baseTags {
rtId += baseTag
labels = append(labels, collection.Name+":"+baseTag)
}
// add the wildcard labels (aka. [..., "*:action1","*:action2", "*"])
for _, baseTag := range baseTags {
labels = append(labels, "*:"+baseTag)
}
labels = append(labels, defaultRateLimitLabels(e)...)
rule, ok := e.App.Settings().RateLimits.FindRateLimitRule(labels, defaultRateLimitAudience(e)...)
if ok {
return checkRateLimit(e, rtId+rule.Audience, rule)
}
return nil
}
// -------------------------------------------------------------------
// @todo consider exporting as helper?
//
//nolint:unused
func isClientRateLimited(e *core.RequestEvent, rtId string) bool {
rateLimiters, ok := e.App.Store().Get(rateLimitersStoreKey).(*store.Store[string, *rateLimiter])
if !ok || rateLimiters == nil {
return false
}
rt, ok := rateLimiters.GetOk(rtId)
if !ok || rt == nil {
return false
}
client, ok := rt.getClient(e.RealIP())
if !ok || client == nil {
return false
}
return client.available <= 0 && time.Now().Unix()-client.lastConsume < client.interval
}
// @todo consider exporting as helper?
func checkRateLimit(e *core.RequestEvent, rtId string, rule core.RateLimitRule) error {
switch rule.Audience {
case core.RateLimitRuleAudienceAll:
// valid for both guest and regular users
case core.RateLimitRuleAudienceGuest:
if e.Auth != nil {
return nil
}
case core.RateLimitRuleAudienceAuth:
if e.Auth == nil {
return nil
}
}
rateLimiters := e.App.Store().GetOrSet(rateLimitersStoreKey, func() any {
return initRateLimitersStore(e.App)
}).(*store.Store[string, *rateLimiter])
if rateLimiters == nil {
e.App.Logger().Warn("Failed to retrieve app rate limiters store")
return nil
}
rt := rateLimiters.GetOrSet(rtId, func() *rateLimiter {
return newRateLimiter(rule.MaxRequests, rule.Duration, rule.Duration+1800)
})
if rt == nil {
e.App.Logger().Warn("Failed to retrieve app rate limiter", "id", rtId)
return nil
}
key := e.RealIP()
if key == "" {
e.App.Logger().Warn("Empty rate limit client key")
return nil
}
if !rt.isAllowed(key) {
return e.TooManyRequestsError("", nil)
}
return nil
}
func skipRateLimit(e *core.RequestEvent) bool {
return !e.App.Settings().RateLimits.Enabled || e.HasSuperuserAuth()
}
var defaultAuthAudience = []string{core.RateLimitRuleAudienceAll, core.RateLimitRuleAudienceAuth}
var defaultGuestAudience = []string{core.RateLimitRuleAudienceAll, core.RateLimitRuleAudienceGuest}
func defaultRateLimitAudience(e *core.RequestEvent) []string {
if e.Auth != nil {
return defaultAuthAudience
}
return defaultGuestAudience
}
func defaultRateLimitLabels(e *core.RequestEvent) []string {
return []string{e.Request.Method + " " + e.Request.URL.Path, e.Request.URL.Path}
}
func destroyRateLimitersStore(app core.App) {
app.OnSettingsReload().Unbind(rateLimitersSettingsHookId)
app.Cron().Remove(rateLimitersCronKey)
app.Store().Remove(rateLimitersStoreKey)
}
func initRateLimitersStore(app core.App) *store.Store[string, *rateLimiter] {
app.Cron().Add(rateLimitersCronKey, "2 * * * *", func() { // offset a little since too many cleanup tasks execute at 00
limitersStore, ok := app.Store().Get(rateLimitersStoreKey).(*store.Store[string, *rateLimiter])
if !ok {
return
}
limiters := limitersStore.GetAll()
for _, limiter := range limiters {
limiter.clean()
}
})
app.OnSettingsReload().Bind(&hook.Handler[*core.SettingsReloadEvent]{
Id: rateLimitersSettingsHookId,
Func: func(e *core.SettingsReloadEvent) error {
err := e.Next()
if err != nil {
return err
}
// reset
destroyRateLimitersStore(e.App)
return nil
},
})
return store.New[string, *rateLimiter](nil)
}
func newRateLimiter(maxAllowed int, intervalInSec int64, minDeleteIntervalInSec int64) *rateLimiter {
return &rateLimiter{
maxAllowed: maxAllowed,
interval: intervalInSec,
minDeleteInterval: minDeleteIntervalInSec,
clients: map[string]*fixedWindow{},
}
}
type rateLimiter struct {
clients map[string]*fixedWindow
maxAllowed int
interval int64
minDeleteInterval int64
totalDeleted int64
sync.RWMutex
}
//nolint:unused
func (rt *rateLimiter) getClient(key string) (*fixedWindow, bool) {
rt.RLock()
client, ok := rt.clients[key]
rt.RUnlock()
return client, ok
}
func (rt *rateLimiter) isAllowed(key string) bool {
// lock only reads to minimize locks contention
rt.RLock()
client, ok := rt.clients[key]
rt.RUnlock()
if !ok {
rt.Lock()
// check again in case the client was added by another request
client, ok = rt.clients[key]
if !ok {
client = newFixedWindow(rt.maxAllowed, rt.interval)
rt.clients[key] = client
}
rt.Unlock()
}
return client.consume()
}
func (rt *rateLimiter) clean() {
rt.Lock()
defer rt.Unlock()
nowUnix := time.Now().Unix()
for k, client := range rt.clients {
if client.hasExpired(nowUnix, rt.minDeleteInterval) {
delete(rt.clients, k)
rt.totalDeleted++
}
}
// "shrink" the map if too may items were deleted
//
// @todo remove after https://github.com/golang/go/issues/20135
if rt.totalDeleted >= 300 {
shrunk := make(map[string]*fixedWindow, len(rt.clients))
for k, v := range rt.clients {
shrunk[k] = v
}
rt.clients = shrunk
rt.totalDeleted = 0
}
}
func newFixedWindow(maxAllowed int, intervalInSec int64) *fixedWindow {
return &fixedWindow{
maxAllowed: maxAllowed,
interval: intervalInSec,
}
}
type fixedWindow struct {
// use plain Mutex instead of RWMutex since the operations are expected
// to be mostly writes (e.g. consume()) and it should perform better
sync.Mutex
maxAllowed int // the max allowed tokens per interval
available int // the total available tokens
interval int64 // in seconds
lastConsume int64 // the time of the last consume
}
// hasExpired checks whether it has been at least minElapsed seconds since the lastConsume time.
// (usually used to perform periodic cleanup of staled instances).
func (l *fixedWindow) hasExpired(relativeNow int64, minElapsed int64) bool {
l.Lock()
defer l.Unlock()
return relativeNow-l.lastConsume > minElapsed
}
// consume decrease the current window allowance with 1 (if not exhausted already).
//
// It returns false if the allowance has been already exhausted and the user
// has to wait until it resets back to its maxAllowed value.
func (l *fixedWindow) consume() bool {
l.Lock()
defer l.Unlock()
nowUnix := time.Now().Unix()
// reset consumed counter
if nowUnix-l.lastConsume >= l.interval {
l.available = l.maxAllowed
}
if l.available > 0 {
l.available--
l.lastConsume = nowUnix
return true
}
return false
}

View file

@ -0,0 +1,159 @@
package apis_test
import (
"net/http/httptest"
"testing"
"time"
"github.com/pocketbase/pocketbase/apis"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tests"
)
func TestDefaultRateLimitMiddleware(t *testing.T) {
app, _ := tests.NewTestApp()
defer app.Cleanup()
app.Settings().RateLimits.Enabled = true
app.Settings().RateLimits.Rules = []core.RateLimitRule{
{
Label: "/rate/",
MaxRequests: 2,
Duration: 1,
},
{
Label: "/rate/b",
MaxRequests: 3,
Duration: 1,
},
{
Label: "POST /rate/b",
MaxRequests: 1,
Duration: 1,
},
{
Label: "/rate/guest",
MaxRequests: 1,
Duration: 1,
Audience: core.RateLimitRuleAudienceGuest,
},
{
Label: "/rate/auth",
MaxRequests: 1,
Duration: 1,
Audience: core.RateLimitRuleAudienceAuth,
},
}
pbRouter, err := apis.NewRouter(app)
if err != nil {
t.Fatal(err)
}
pbRouter.GET("/norate", func(e *core.RequestEvent) error {
return e.String(200, "norate")
}).BindFunc(func(e *core.RequestEvent) error {
return e.Next()
})
pbRouter.GET("/rate/a", func(e *core.RequestEvent) error {
return e.String(200, "a")
})
pbRouter.GET("/rate/b", func(e *core.RequestEvent) error {
return e.String(200, "b")
})
pbRouter.GET("/rate/guest", func(e *core.RequestEvent) error {
return e.String(200, "guest")
})
pbRouter.GET("/rate/auth", func(e *core.RequestEvent) error {
return e.String(200, "auth")
})
mux, err := pbRouter.BuildMux()
if err != nil {
t.Fatal(err)
}
scenarios := []struct {
url string
wait float64
authenticated bool
expectedStatus int
}{
{"/norate", 0, false, 200},
{"/norate", 0, false, 200},
{"/norate", 0, false, 200},
{"/norate", 0, false, 200},
{"/norate", 0, false, 200},
{"/rate/a", 0, false, 200},
{"/rate/a", 0, false, 200},
{"/rate/a", 0, false, 429},
{"/rate/a", 0, false, 429},
{"/rate/a", 1.1, false, 200},
{"/rate/a", 0, false, 200},
{"/rate/a", 0, false, 429},
{"/rate/b", 0, false, 200},
{"/rate/b", 0, false, 200},
{"/rate/b", 0, false, 200},
{"/rate/b", 0, false, 429},
{"/rate/b", 1.1, false, 200},
{"/rate/b", 0, false, 200},
{"/rate/b", 0, false, 200},
{"/rate/b", 0, false, 429},
// "auth" with guest (should fallback to the /rate/ rule)
{"/rate/auth", 0, false, 200},
{"/rate/auth", 0, false, 200},
{"/rate/auth", 0, false, 429},
{"/rate/auth", 0, false, 429},
// "auth" rule with regular user (should match the /rate/auth rule)
{"/rate/auth", 0, true, 200},
{"/rate/auth", 0, true, 429},
{"/rate/auth", 0, true, 429},
// "guest" with guest (should match the /rate/guest rule)
{"/rate/guest", 0, false, 200},
{"/rate/guest", 0, false, 429},
{"/rate/guest", 0, false, 429},
// "guest" rule with regular user (should fallback to the /rate/ rule)
{"/rate/guest", 1, true, 200},
{"/rate/guest", 0, true, 200},
{"/rate/guest", 0, true, 429},
{"/rate/guest", 0, true, 429},
}
for _, s := range scenarios {
t.Run(s.url, func(t *testing.T) {
if s.wait > 0 {
time.Sleep(time.Duration(s.wait) * time.Second)
}
rec := httptest.NewRecorder()
req := httptest.NewRequest("GET", s.url, nil)
if s.authenticated {
auth, err := app.FindAuthRecordByEmail("users", "test@example.com")
if err != nil {
t.Fatal(err)
}
token, err := auth.NewAuthToken()
if err != nil {
t.Fatal(err)
}
req.Header.Add("Authorization", token)
}
mux.ServeHTTP(rec, req)
result := rec.Result()
if result.StatusCode != s.expectedStatus {
t.Fatalf("Expected response status %d, got %d", s.expectedStatus, result.StatusCode)
}
})
}
}

539
apis/middlewares_test.go Normal file
View file

@ -0,0 +1,539 @@
package apis_test
import (
"net/http"
"testing"
"github.com/pocketbase/pocketbase/apis"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tests"
)
func TestPanicRecover(t *testing.T) {
t.Parallel()
scenarios := []tests.ApiScenario{
{
Name: "panic from route",
Method: http.MethodGet,
URL: "/my/test",
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
e.Router.GET("/my/test", func(e *core.RequestEvent) error {
panic("123")
})
},
ExpectedStatus: 500,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "panic from middleware",
Method: http.MethodGet,
URL: "/my/test",
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
e.Router.GET("/my/test", func(e *core.RequestEvent) error {
return e.String(http.StatusOK, "test")
}).BindFunc(func(e *core.RequestEvent) error {
panic(123)
})
},
ExpectedStatus: 500,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
}
for _, scenario := range scenarios {
scenario.Test(t)
}
}
func TestRequireGuestOnly(t *testing.T) {
t.Parallel()
beforeTestFunc := func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
e.Router.GET("/my/test", func(e *core.RequestEvent) error {
return e.String(200, "test123")
}).Bind(apis.RequireGuestOnly())
}
scenarios := []tests.ApiScenario{
{
Name: "valid regular user token",
Method: http.MethodGet,
URL: "/my/test",
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
},
BeforeTestFunc: beforeTestFunc,
ExpectedStatus: 400,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "valid superuser auth token",
Method: http.MethodGet,
URL: "/my/test",
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
},
BeforeTestFunc: beforeTestFunc,
ExpectedStatus: 400,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "expired/invalid token",
Method: http.MethodGet,
URL: "/my/test",
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoxNjQwOTkxNjYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.2D3tmqPn3vc5LoqqCz8V-iCDVXo9soYiH0d32G7FQT4",
},
BeforeTestFunc: beforeTestFunc,
ExpectedStatus: 200,
ExpectedContent: []string{"test123"},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "guest",
Method: http.MethodGet,
URL: "/my/test",
BeforeTestFunc: beforeTestFunc,
ExpectedStatus: 200,
ExpectedContent: []string{"test123"},
ExpectedEvents: map[string]int{"*": 0},
},
}
for _, scenario := range scenarios {
scenario.Test(t)
}
}
func TestRequireAuth(t *testing.T) {
t.Parallel()
scenarios := []tests.ApiScenario{
{
Name: "guest",
Method: http.MethodGet,
URL: "/my/test",
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
e.Router.GET("/my/test", func(e *core.RequestEvent) error {
return e.String(200, "test123")
}).Bind(apis.RequireAuth())
},
ExpectedStatus: 401,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "expired token",
Method: http.MethodGet,
URL: "/my/test",
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoxNjQwOTkxNjYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.2D3tmqPn3vc5LoqqCz8V-iCDVXo9soYiH0d32G7FQT4",
},
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
e.Router.GET("/my/test", func(e *core.RequestEvent) error {
return e.String(200, "test123")
}).Bind(apis.RequireAuth())
},
ExpectedStatus: 401,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "invalid token",
Method: http.MethodGet,
URL: "/my/test",
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsImV4cCI6MjUyNDYwNDQ2MSwidHlwZSI6ImZpbGUiLCJjb2xsZWN0aW9uSWQiOiJwYmNfMzE0MjYzNTgyMyJ9.Lupz541xRvrktwkrl55p5pPCF77T69ZRsohsIcb2dxc",
},
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
e.Router.GET("/my/test", func(e *core.RequestEvent) error {
return e.String(200, "test123")
}).Bind(apis.RequireAuth())
},
ExpectedStatus: 401,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "valid record auth token with no collection restrictions",
Method: http.MethodGet,
URL: "/my/test",
Headers: map[string]string{
// regular user
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
},
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
e.Router.GET("/my/test", func(e *core.RequestEvent) error {
return e.String(200, "test123")
}).Bind(apis.RequireAuth())
},
ExpectedStatus: 200,
ExpectedContent: []string{"test123"},
},
{
Name: "valid record static auth token",
Method: http.MethodGet,
URL: "/my/test",
Headers: map[string]string{
// regular user
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6ZmFsc2V9.4IsO6YMsR19crhwl_YWzvRH8pfq2Ri4Gv2dzGyneLak",
},
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
e.Router.GET("/my/test", func(e *core.RequestEvent) error {
return e.String(200, "test123")
}).Bind(apis.RequireAuth())
},
ExpectedStatus: 200,
ExpectedContent: []string{"test123"},
},
{
Name: "valid record auth token with collection not in the restricted list",
Method: http.MethodGet,
URL: "/my/test",
Headers: map[string]string{
// superuser
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
},
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
e.Router.GET("/my/test", func(e *core.RequestEvent) error {
return e.String(200, "test123")
}).Bind(apis.RequireAuth("users", "demo1"))
},
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "valid record auth token with collection in the restricted list",
Method: http.MethodGet,
URL: "/my/test",
Headers: map[string]string{
// superuser
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
},
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
e.Router.GET("/my/test", func(e *core.RequestEvent) error {
return e.String(200, "test123")
}).Bind(apis.RequireAuth("users", core.CollectionNameSuperusers))
},
ExpectedStatus: 200,
ExpectedContent: []string{"test123"},
},
}
for _, scenario := range scenarios {
scenario.Test(t)
}
}
func TestRequireSuperuserAuth(t *testing.T) {
t.Parallel()
scenarios := []tests.ApiScenario{
{
Name: "guest",
Method: http.MethodGet,
URL: "/my/test",
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
e.Router.GET("/my/test", func(e *core.RequestEvent) error {
return e.String(200, "test123")
}).Bind(apis.RequireSuperuserAuth())
},
ExpectedStatus: 401,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "expired/invalid token",
Method: http.MethodGet,
URL: "/my/test",
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjE2NDA5OTE2NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.0pDcBPGDpL2Khh76ivlRi7ugiLBSYvasct3qpHV3rfs",
},
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
e.Router.GET("/my/test", func(e *core.RequestEvent) error {
return e.String(200, "test123")
}).Bind(apis.RequireSuperuserAuth())
},
ExpectedStatus: 401,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "valid regular user auth token",
Method: http.MethodGet,
URL: "/my/test",
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
},
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
e.Router.GET("/my/test", func(e *core.RequestEvent) error {
return e.String(200, "test123")
}).Bind(apis.RequireSuperuserAuth())
},
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "valid superuser auth token",
Method: http.MethodGet,
URL: "/my/test",
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
},
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
e.Router.GET("/my/test", func(e *core.RequestEvent) error {
return e.String(200, "test123")
}).Bind(apis.RequireSuperuserAuth())
},
ExpectedStatus: 200,
ExpectedContent: []string{"test123"},
},
}
for _, scenario := range scenarios {
scenario.Test(t)
}
}
func TestRequireSuperuserOrOwnerAuth(t *testing.T) {
t.Parallel()
scenarios := []tests.ApiScenario{
{
Name: "guest",
Method: http.MethodGet,
URL: "/my/test/4q1xlclmfloku33",
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
e.Router.GET("/my/test/{id}", func(e *core.RequestEvent) error {
return e.String(200, "test123")
}).Bind(apis.RequireSuperuserOrOwnerAuth(""))
},
ExpectedStatus: 401,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "expired/invalid token",
Method: http.MethodGet,
URL: "/my/test/4q1xlclmfloku33",
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjE2NDA5OTE2NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.0pDcBPGDpL2Khh76ivlRi7ugiLBSYvasct3qpHV3rfs",
},
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
e.Router.GET("/my/test/{id}", func(e *core.RequestEvent) error {
return e.String(200, "test123")
}).Bind(apis.RequireSuperuserOrOwnerAuth(""))
},
ExpectedStatus: 401,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "valid record auth token (different user)",
Method: http.MethodGet,
URL: "/my/test/oap640cot4yru2s",
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
},
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
e.Router.GET("/my/test/{id}", func(e *core.RequestEvent) error {
return e.String(200, "test123")
}).Bind(apis.RequireSuperuserOrOwnerAuth(""))
},
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "valid record auth token (owner)",
Method: http.MethodGet,
URL: "/my/test/4q1xlclmfloku33",
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
},
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
e.Router.GET("/my/test/{id}", func(e *core.RequestEvent) error {
return e.String(200, "test123")
}).Bind(apis.RequireSuperuserOrOwnerAuth(""))
},
ExpectedStatus: 200,
ExpectedContent: []string{"test123"},
},
{
Name: "valid record auth token (owner + non-matching custom owner param)",
Method: http.MethodGet,
URL: "/my/test/4q1xlclmfloku33",
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
},
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
e.Router.GET("/my/test/{id}", func(e *core.RequestEvent) error {
return e.String(200, "test123")
}).Bind(apis.RequireSuperuserOrOwnerAuth("test"))
},
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "valid record auth token (owner + matching custom owner param)",
Method: http.MethodGet,
URL: "/my/test/4q1xlclmfloku33",
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
},
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
e.Router.GET("/my/test/{test}", func(e *core.RequestEvent) error {
return e.String(200, "test123")
}).Bind(apis.RequireSuperuserOrOwnerAuth("test"))
},
ExpectedStatus: 200,
ExpectedContent: []string{"test123"},
},
{
Name: "valid superuser auth token",
Method: http.MethodGet,
URL: "/my/test/4q1xlclmfloku33",
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
},
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
e.Router.GET("/my/test/{id}", func(e *core.RequestEvent) error {
return e.String(200, "test123")
}).Bind(apis.RequireSuperuserOrOwnerAuth(""))
},
ExpectedStatus: 200,
ExpectedContent: []string{"test123"},
},
}
for _, scenario := range scenarios {
scenario.Test(t)
}
}
func TestRequireSameCollectionContextAuth(t *testing.T) {
t.Parallel()
scenarios := []tests.ApiScenario{
{
Name: "guest",
Method: http.MethodGet,
URL: "/my/test/_pb_users_auth_",
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
e.Router.GET("/my/test/{collection}", func(e *core.RequestEvent) error {
return e.String(200, "test123")
}).Bind(apis.RequireSameCollectionContextAuth(""))
},
ExpectedStatus: 401,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "expired/invalid token",
Method: http.MethodGet,
URL: "/my/test/_pb_users_auth_",
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoxNjQwOTkxNjYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.2D3tmqPn3vc5LoqqCz8V-iCDVXo9soYiH0d32G7FQT4",
},
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
e.Router.GET("/my/test/{collection}", func(e *core.RequestEvent) error {
return e.String(200, "test123")
}).Bind(apis.RequireSameCollectionContextAuth(""))
},
ExpectedStatus: 401,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "valid record auth token (different collection)",
Method: http.MethodGet,
URL: "/my/test/clients",
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
},
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
e.Router.GET("/my/test/{collection}", func(e *core.RequestEvent) error {
return e.String(200, "test123")
}).Bind(apis.RequireSameCollectionContextAuth(""))
},
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "valid record auth token (same collection)",
Method: http.MethodGet,
URL: "/my/test/_pb_users_auth_",
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
},
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
e.Router.GET("/my/test/{collection}", func(e *core.RequestEvent) error {
return e.String(200, "test123")
}).Bind(apis.RequireSameCollectionContextAuth(""))
},
ExpectedStatus: 200,
ExpectedContent: []string{"test123"},
},
{
Name: "valid record auth token (non-matching/missing collection param)",
Method: http.MethodGet,
URL: "/my/test/_pb_users_auth_",
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
},
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
e.Router.GET("/my/test/{id}", func(e *core.RequestEvent) error {
return e.String(200, "test123")
}).Bind(apis.RequireSuperuserOrOwnerAuth(""))
},
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "valid record auth token (matching custom collection param)",
Method: http.MethodGet,
URL: "/my/test/_pb_users_auth_",
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
},
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
e.Router.GET("/my/test/{test}", func(e *core.RequestEvent) error {
return e.String(200, "test123")
}).Bind(apis.RequireSuperuserOrOwnerAuth("test"))
},
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "superuser no exception check",
Method: http.MethodGet,
URL: "/my/test/_pb_users_auth_",
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
},
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
e.Router.GET("/my/test/{collection}", func(e *core.RequestEvent) error {
return e.String(200, "test123")
}).Bind(apis.RequireSameCollectionContextAuth(""))
},
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
}
for _, scenario := range scenarios {
scenario.Test(t)
}
}

777
apis/realtime.go Normal file
View file

@ -0,0 +1,777 @@
package apis
import (
"context"
"encoding/json"
"errors"
"fmt"
"log/slog"
"net/http"
"strings"
"time"
validation "github.com/go-ozzo/ozzo-validation/v4"
"github.com/pocketbase/dbx"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tools/hook"
"github.com/pocketbase/pocketbase/tools/picker"
"github.com/pocketbase/pocketbase/tools/router"
"github.com/pocketbase/pocketbase/tools/routine"
"github.com/pocketbase/pocketbase/tools/search"
"github.com/pocketbase/pocketbase/tools/subscriptions"
"golang.org/x/sync/errgroup"
)
// note: the chunk size is arbitrary chosen and may change in the future
const clientsChunkSize = 150
// RealtimeClientAuthKey is the name of the realtime client store key that holds its auth state.
const RealtimeClientAuthKey = "auth"
// bindRealtimeApi registers the realtime api endpoints.
func bindRealtimeApi(app core.App, rg *router.RouterGroup[*core.RequestEvent]) {
sub := rg.Group("/realtime")
sub.GET("", realtimeConnect).Bind(SkipSuccessActivityLog())
sub.POST("", realtimeSetSubscriptions)
bindRealtimeEvents(app)
}
func realtimeConnect(e *core.RequestEvent) error {
// disable global write deadline for the SSE connection
rc := http.NewResponseController(e.Response)
writeDeadlineErr := rc.SetWriteDeadline(time.Time{})
if writeDeadlineErr != nil {
if !errors.Is(writeDeadlineErr, http.ErrNotSupported) {
return e.InternalServerError("Failed to initialize SSE connection.", writeDeadlineErr)
}
// only log since there are valid cases where it may not be implement (e.g. httptest.ResponseRecorder)
e.App.Logger().Warn("SetWriteDeadline is not supported, fallback to the default server WriteTimeout")
}
// create cancellable request
cancelCtx, cancelRequest := context.WithCancel(e.Request.Context())
defer cancelRequest()
e.Request = e.Request.Clone(cancelCtx)
e.Response.Header().Set("Content-Type", "text/event-stream")
e.Response.Header().Set("Cache-Control", "no-store")
// https://github.com/pocketbase/pocketbase/discussions/480#discussioncomment-3657640
// https://nginx.org/en/docs/http/ngx_http_proxy_module.html#proxy_buffering
e.Response.Header().Set("X-Accel-Buffering", "no")
connectEvent := new(core.RealtimeConnectRequestEvent)
connectEvent.RequestEvent = e
connectEvent.Client = subscriptions.NewDefaultClient()
connectEvent.IdleTimeout = 5 * time.Minute
return e.App.OnRealtimeConnectRequest().Trigger(connectEvent, func(ce *core.RealtimeConnectRequestEvent) error {
// register new subscription client
ce.App.SubscriptionsBroker().Register(ce.Client)
defer func() {
e.App.SubscriptionsBroker().Unregister(ce.Client.Id())
}()
ce.App.Logger().Debug("Realtime connection established.", slog.String("clientId", ce.Client.Id()))
// signalize established connection (aka. fire "connect" message)
connectMsgEvent := new(core.RealtimeMessageEvent)
connectMsgEvent.RequestEvent = ce.RequestEvent
connectMsgEvent.Client = ce.Client
connectMsgEvent.Message = &subscriptions.Message{
Name: "PB_CONNECT",
Data: []byte(`{"clientId":"` + ce.Client.Id() + `"}`),
}
connectMsgErr := ce.App.OnRealtimeMessageSend().Trigger(connectMsgEvent, func(me *core.RealtimeMessageEvent) error {
err := me.Message.WriteSSE(me.Response, me.Client.Id())
if err != nil {
return err
}
return me.Flush()
})
if connectMsgErr != nil {
ce.App.Logger().Debug(
"Realtime connection closed (failed to deliver PB_CONNECT)",
slog.String("clientId", ce.Client.Id()),
slog.String("error", connectMsgErr.Error()),
)
return nil
}
// start an idle timer to keep track of inactive/forgotten connections
idleTimer := time.NewTimer(ce.IdleTimeout)
defer idleTimer.Stop()
for {
select {
case <-idleTimer.C:
cancelRequest()
case msg, ok := <-ce.Client.Channel():
if !ok {
// channel is closed
ce.App.Logger().Debug(
"Realtime connection closed (closed channel)",
slog.String("clientId", ce.Client.Id()),
)
return nil
}
msgEvent := new(core.RealtimeMessageEvent)
msgEvent.RequestEvent = ce.RequestEvent
msgEvent.Client = ce.Client
msgEvent.Message = &msg
msgErr := ce.App.OnRealtimeMessageSend().Trigger(msgEvent, func(me *core.RealtimeMessageEvent) error {
err := me.Message.WriteSSE(me.Response, me.Client.Id())
if err != nil {
return err
}
return me.Flush()
})
if msgErr != nil {
ce.App.Logger().Debug(
"Realtime connection closed (failed to deliver message)",
slog.String("clientId", ce.Client.Id()),
slog.String("error", msgErr.Error()),
)
return nil
}
idleTimer.Stop()
idleTimer.Reset(ce.IdleTimeout)
case <-ce.Request.Context().Done():
// connection is closed
ce.App.Logger().Debug(
"Realtime connection closed (cancelled request)",
slog.String("clientId", ce.Client.Id()),
)
return nil
}
}
})
}
type realtimeSubscribeForm struct {
ClientId string `form:"clientId" json:"clientId"`
Subscriptions []string `form:"subscriptions" json:"subscriptions"`
}
func (form *realtimeSubscribeForm) validate() error {
return validation.ValidateStruct(form,
validation.Field(&form.ClientId, validation.Required, validation.Length(1, 255)),
validation.Field(&form.Subscriptions,
validation.Length(0, 1000),
validation.Each(validation.Length(0, 2500)),
),
)
}
// note: in case of reconnect, clients will have to resubmit all subscriptions again
func realtimeSetSubscriptions(e *core.RequestEvent) error {
form := new(realtimeSubscribeForm)
err := e.BindBody(form)
if err != nil {
return e.BadRequestError("", err)
}
err = form.validate()
if err != nil {
return e.BadRequestError("", err)
}
// find subscription client
client, err := e.App.SubscriptionsBroker().ClientById(form.ClientId)
if err != nil {
return e.NotFoundError("Missing or invalid client id.", err)
}
// for now allow only guest->auth upgrades and any other auth change is forbidden
clientAuth, _ := client.Get(RealtimeClientAuthKey).(*core.Record)
if clientAuth != nil && !isSameAuth(clientAuth, e.Auth) {
return e.ForbiddenError("The current and the previous request authorization don't match.", nil)
}
event := new(core.RealtimeSubscribeRequestEvent)
event.RequestEvent = e
event.Client = client
event.Subscriptions = form.Subscriptions
return e.App.OnRealtimeSubscribeRequest().Trigger(event, func(e *core.RealtimeSubscribeRequestEvent) error {
// update auth state
e.Client.Set(RealtimeClientAuthKey, e.Auth)
// unsubscribe from any previous existing subscriptions
e.Client.Unsubscribe()
// subscribe to the new subscriptions
e.Client.Subscribe(e.Subscriptions...)
e.App.Logger().Debug(
"Realtime subscriptions updated.",
slog.String("clientId", e.Client.Id()),
slog.Any("subscriptions", e.Subscriptions),
)
return execAfterSuccessTx(true, e.App, func() error {
return e.NoContent(http.StatusNoContent)
})
})
}
// updateClientsAuth updates the existing clients auth record with the new one (matched by ID).
func realtimeUpdateClientsAuth(app core.App, newAuthRecord *core.Record) error {
chunks := app.SubscriptionsBroker().ChunkedClients(clientsChunkSize)
group := new(errgroup.Group)
for _, chunk := range chunks {
group.Go(func() error {
for _, client := range chunk {
clientAuth, _ := client.Get(RealtimeClientAuthKey).(*core.Record)
if clientAuth != nil &&
clientAuth.Id == newAuthRecord.Id &&
clientAuth.Collection().Name == newAuthRecord.Collection().Name {
client.Set(RealtimeClientAuthKey, newAuthRecord)
}
}
return nil
})
}
return group.Wait()
}
// realtimeUnsetClientsAuthState unsets the auth state of all clients that have the provided auth model.
func realtimeUnsetClientsAuthState(app core.App, authModel core.Model) error {
chunks := app.SubscriptionsBroker().ChunkedClients(clientsChunkSize)
group := new(errgroup.Group)
for _, chunk := range chunks {
group.Go(func() error {
for _, client := range chunk {
clientAuth, _ := client.Get(RealtimeClientAuthKey).(*core.Record)
if clientAuth != nil &&
clientAuth.Id == authModel.PK() &&
clientAuth.Collection().Name == authModel.TableName() {
client.Unset(RealtimeClientAuthKey)
}
}
return nil
})
}
return group.Wait()
}
func bindRealtimeEvents(app core.App) {
// update the clients that has auth record association
app.OnModelAfterUpdateSuccess().Bind(&hook.Handler[*core.ModelEvent]{
Func: func(e *core.ModelEvent) error {
authRecord := realtimeResolveRecord(e.App, e.Model, core.CollectionTypeAuth)
if authRecord != nil {
if err := realtimeUpdateClientsAuth(e.App, authRecord); err != nil {
app.Logger().Warn(
"Failed to update client(s) associated to the updated auth record",
slog.Any("id", authRecord.Id),
slog.String("collectionName", authRecord.Collection().Name),
slog.String("error", err.Error()),
)
}
}
return e.Next()
},
Priority: -99,
})
// remove the client(s) associated to the deleted auth model
// (note: works also with custom model for backward compatibility)
app.OnModelAfterDeleteSuccess().Bind(&hook.Handler[*core.ModelEvent]{
Func: func(e *core.ModelEvent) error {
collection := realtimeResolveRecordCollection(e.App, e.Model)
if collection != nil && collection.IsAuth() {
if err := realtimeUnsetClientsAuthState(e.App, e.Model); err != nil {
app.Logger().Warn(
"Failed to remove client(s) associated to the deleted auth model",
slog.Any("id", e.Model.PK()),
slog.String("collectionName", e.Model.TableName()),
slog.String("error", err.Error()),
)
}
}
return e.Next()
},
Priority: -99,
})
app.OnModelAfterCreateSuccess().Bind(&hook.Handler[*core.ModelEvent]{
Func: func(e *core.ModelEvent) error {
record := realtimeResolveRecord(e.App, e.Model, "")
if record != nil {
err := realtimeBroadcastRecord(e.App, "create", record, false)
if err != nil {
app.Logger().Debug(
"Failed to broadcast record create",
slog.String("id", record.Id),
slog.String("collectionName", record.Collection().Name),
slog.String("error", err.Error()),
)
}
}
return e.Next()
},
Priority: -99,
})
app.OnModelAfterUpdateSuccess().Bind(&hook.Handler[*core.ModelEvent]{
Func: func(e *core.ModelEvent) error {
record := realtimeResolveRecord(e.App, e.Model, "")
if record != nil {
err := realtimeBroadcastRecord(e.App, "update", record, false)
if err != nil {
app.Logger().Debug(
"Failed to broadcast record update",
slog.String("id", record.Id),
slog.String("collectionName", record.Collection().Name),
slog.String("error", err.Error()),
)
}
}
return e.Next()
},
Priority: -99,
})
// delete: dry cache
app.OnModelDelete().Bind(&hook.Handler[*core.ModelEvent]{
Func: func(e *core.ModelEvent) error {
record := realtimeResolveRecord(e.App, e.Model, "")
if record != nil {
// note: use the outside scoped app instance for the access checks so that the API rules
// are performed out of the delete transaction ensuring that they would still work even if
// a cascade-deleted record's API rule relies on an already deleted parent record
err := realtimeBroadcastRecord(e.App, "delete", record, true, app)
if err != nil {
app.Logger().Debug(
"Failed to dry cache record delete",
slog.String("id", record.Id),
slog.String("collectionName", record.Collection().Name),
slog.String("error", err.Error()),
)
}
}
return e.Next()
},
Priority: 99, // execute as later as possible
})
// delete: broadcast
app.OnModelAfterDeleteSuccess().Bind(&hook.Handler[*core.ModelEvent]{
Func: func(e *core.ModelEvent) error {
// note: only ensure that it is a collection record
// and don't use realtimeResolveRecord because in case of a
// custom model it'll fail to resolve since the record is already deleted
collection := realtimeResolveRecordCollection(e.App, e.Model)
if collection != nil {
err := realtimeBroadcastDryCacheKey(e.App, getDryCacheKey("delete", e.Model))
if err != nil {
app.Logger().Debug(
"Failed to broadcast record delete",
slog.Any("id", e.Model.PK()),
slog.String("collectionName", collection.Name),
slog.String("error", err.Error()),
)
}
}
return e.Next()
},
Priority: -99,
})
// delete: failure
app.OnModelAfterDeleteError().Bind(&hook.Handler[*core.ModelErrorEvent]{
Func: func(e *core.ModelErrorEvent) error {
record := realtimeResolveRecord(e.App, e.Model, "")
if record != nil {
err := realtimeUnsetDryCacheKey(e.App, getDryCacheKey("delete", record))
if err != nil {
app.Logger().Debug(
"Failed to cleanup after broadcast record delete failure",
slog.String("id", record.Id),
slog.String("collectionName", record.Collection().Name),
slog.String("error", err.Error()),
)
}
}
return e.Next()
},
Priority: -99,
})
}
// resolveRecord converts *if possible* the provided model interface to a Record.
// This is usually helpful if the provided model is a custom Record model struct.
func realtimeResolveRecord(app core.App, model core.Model, optCollectionType string) *core.Record {
var record *core.Record
switch m := model.(type) {
case *core.Record:
record = m
case core.RecordProxy:
record = m.ProxyRecord()
}
if record != nil {
if optCollectionType == "" || record.Collection().Type == optCollectionType {
return record
}
return nil
}
tblName := model.TableName()
// skip Log model checks
if tblName == core.LogsTableName {
return nil
}
// check if it is custom Record model struct
collection, _ := app.FindCachedCollectionByNameOrId(tblName)
if collection != nil && (optCollectionType == "" || collection.Type == optCollectionType) {
if id, ok := model.PK().(string); ok {
record, _ = app.FindRecordById(collection, id)
}
}
return record
}
// realtimeResolveRecordCollection extracts *if possible* the Collection model from the provided model interface.
// This is usually helpful if the provided model is a custom Record model struct.
func realtimeResolveRecordCollection(app core.App, model core.Model) (collection *core.Collection) {
switch m := model.(type) {
case *core.Record:
return m.Collection()
case core.RecordProxy:
return m.ProxyRecord().Collection()
default:
// check if it is custom Record model struct
collection, err := app.FindCachedCollectionByNameOrId(model.TableName())
if err == nil {
return collection
}
}
return nil
}
// recordData represents the broadcasted record subscrition message data.
type recordData struct {
Record any `json:"record"` /* map or core.Record */
Action string `json:"action"`
}
// Note: the optAccessCheckApp is there in case you want the access check
// to be performed against different db app context (e.g. out of a transaction).
// If set, it is expected that optAccessCheckApp instance is used for read-only operations to avoid deadlocks.
// If not set, it fallbacks to app.
func realtimeBroadcastRecord(app core.App, action string, record *core.Record, dryCache bool, optAccessCheckApp ...core.App) error {
collection := record.Collection()
if collection == nil {
return errors.New("[broadcastRecord] Record collection not set")
}
chunks := app.SubscriptionsBroker().ChunkedClients(clientsChunkSize)
if len(chunks) == 0 {
return nil // no subscribers
}
subscriptionRuleMap := map[string]*string{
(collection.Name + "/" + record.Id + "?"): collection.ViewRule,
(collection.Id + "/" + record.Id + "?"): collection.ViewRule,
(collection.Name + "/*?"): collection.ListRule,
(collection.Id + "/*?"): collection.ListRule,
// @deprecated: the same as the wildcard topic but kept for backward compatibility
(collection.Name + "?"): collection.ListRule,
(collection.Id + "?"): collection.ListRule,
}
dryCacheKey := getDryCacheKey(action, record)
group := new(errgroup.Group)
accessCheckApp := app
if len(optAccessCheckApp) > 0 {
accessCheckApp = optAccessCheckApp[0]
}
for _, chunk := range chunks {
group.Go(func() error {
var clientAuth *core.Record
for _, client := range chunk {
// note: not executed concurrently to avoid races and to ensure
// that the access checks are applied for the current record db state
for prefix, rule := range subscriptionRuleMap {
subs := client.Subscriptions(prefix)
if len(subs) == 0 {
continue
}
clientAuth, _ = client.Get(RealtimeClientAuthKey).(*core.Record)
for sub, options := range subs {
// mock request data
requestInfo := &core.RequestInfo{
Context: core.RequestInfoContextRealtime,
Method: "GET",
Query: options.Query,
Headers: options.Headers,
Auth: clientAuth,
}
if !realtimeCanAccessRecord(accessCheckApp, record, requestInfo, rule) {
continue
}
// create a clean record copy without expand and unknown fields because we don't know yet
// which exact fields the client subscription requested or has permissions to access
cleanRecord := record.Fresh()
// trigger the enrich hooks
enrichErr := triggerRecordEnrichHooks(app, requestInfo, []*core.Record{cleanRecord}, func() error {
// apply expand
rawExpand := options.Query[expandQueryParam]
if rawExpand != "" {
expandErrs := app.ExpandRecord(cleanRecord, strings.Split(rawExpand, ","), expandFetch(app, requestInfo))
if len(expandErrs) > 0 {
app.Logger().Debug(
"[broadcastRecord] expand errors",
slog.String("id", cleanRecord.Id),
slog.String("collectionName", cleanRecord.Collection().Name),
slog.String("sub", sub),
slog.String("expand", rawExpand),
slog.Any("errors", expandErrs),
)
}
}
// ignore the auth record email visibility checks
// for auth owner, superuser or manager
if collection.IsAuth() {
if isSameAuth(clientAuth, cleanRecord) ||
realtimeCanAccessRecord(accessCheckApp, cleanRecord, requestInfo, collection.ManageRule) {
cleanRecord.IgnoreEmailVisibility(true)
}
}
return nil
})
if enrichErr != nil {
app.Logger().Debug(
"[broadcastRecord] record enrich error",
slog.String("id", cleanRecord.Id),
slog.String("collectionName", cleanRecord.Collection().Name),
slog.String("sub", sub),
slog.Any("error", enrichErr),
)
continue
}
data := &recordData{
Action: action,
Record: cleanRecord,
}
// check fields
rawFields := options.Query[fieldsQueryParam]
if rawFields != "" {
decoded, err := picker.Pick(cleanRecord, rawFields)
if err == nil {
data.Record = decoded
} else {
app.Logger().Debug(
"[broadcastRecord] pick fields error",
slog.String("id", cleanRecord.Id),
slog.String("collectionName", cleanRecord.Collection().Name),
slog.String("sub", sub),
slog.String("fields", rawFields),
slog.String("error", err.Error()),
)
}
}
dataBytes, err := json.Marshal(data)
if err != nil {
app.Logger().Debug(
"[broadcastRecord] data marshal error",
slog.String("id", cleanRecord.Id),
slog.String("collectionName", cleanRecord.Collection().Name),
slog.String("error", err.Error()),
)
continue
}
msg := subscriptions.Message{
Name: sub,
Data: dataBytes,
}
if dryCache {
messages, ok := client.Get(dryCacheKey).([]subscriptions.Message)
if !ok {
messages = []subscriptions.Message{msg}
} else {
messages = append(messages, msg)
}
client.Set(dryCacheKey, messages)
} else {
routine.FireAndForget(func() {
client.Send(msg)
})
}
}
}
}
return nil
})
}
return group.Wait()
}
// realtimeBroadcastDryCacheKey broadcasts the dry cached key related messages.
func realtimeBroadcastDryCacheKey(app core.App, key string) error {
chunks := app.SubscriptionsBroker().ChunkedClients(clientsChunkSize)
if len(chunks) == 0 {
return nil // no subscribers
}
group := new(errgroup.Group)
for _, chunk := range chunks {
group.Go(func() error {
for _, client := range chunk {
messages, ok := client.Get(key).([]subscriptions.Message)
if !ok {
continue
}
client.Unset(key)
client := client
routine.FireAndForget(func() {
for _, msg := range messages {
client.Send(msg)
}
})
}
return nil
})
}
return group.Wait()
}
// realtimeUnsetDryCacheKey removes the dry cached key related messages.
func realtimeUnsetDryCacheKey(app core.App, key string) error {
chunks := app.SubscriptionsBroker().ChunkedClients(clientsChunkSize)
if len(chunks) == 0 {
return nil // no subscribers
}
group := new(errgroup.Group)
for _, chunk := range chunks {
group.Go(func() error {
for _, client := range chunk {
if client.Get(key) != nil {
client.Unset(key)
}
}
return nil
})
}
return group.Wait()
}
func getDryCacheKey(action string, model core.Model) string {
pkStr, ok := model.PK().(string)
if !ok {
pkStr = fmt.Sprintf("%v", model.PK())
}
return action + "/" + model.TableName() + "/" + pkStr
}
func isSameAuth(authA, authB *core.Record) bool {
if authA == nil {
return authB == nil
}
if authB == nil {
return false
}
return authA.Id == authB.Id && authA.Collection().Id == authB.Collection().Id
}
// realtimeCanAccessRecord checks if the subscription client has access to the specified record model.
func realtimeCanAccessRecord(
app core.App,
record *core.Record,
requestInfo *core.RequestInfo,
accessRule *string,
) bool {
// check the access rule
// ---
if ok, _ := app.CanAccessRecord(record, requestInfo, accessRule); !ok {
return false
}
// check the subscription client-side filter (if any)
// ---
filter := requestInfo.Query[search.FilterQueryParam]
if filter == "" {
return true // no further checks needed
}
err := checkForSuperuserOnlyRuleFields(requestInfo)
if err != nil {
return false
}
var exists int
q := app.ConcurrentDB().Select("(1)").
From(record.Collection().Name).
AndWhere(dbx.HashExp{record.Collection().Name + ".id": record.Id})
resolver := core.NewRecordFieldResolver(app, record.Collection(), requestInfo, false)
expr, err := search.FilterData(filter).BuildExpr(resolver)
if err != nil {
return false
}
q.AndWhere(expr)
resolver.UpdateQuery(q)
err = q.Limit(1).Row(&exists)
return err == nil && exists > 0
}

885
apis/realtime_test.go Normal file
View file

@ -0,0 +1,885 @@
package apis_test
import (
"context"
"encoding/json"
"errors"
"fmt"
"net/http"
"slices"
"strings"
"sync"
"testing"
"time"
"github.com/pocketbase/dbx"
"github.com/pocketbase/pocketbase/apis"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tests"
"github.com/pocketbase/pocketbase/tools/subscriptions"
"github.com/pocketbase/pocketbase/tools/types"
)
func TestRealtimeConnect(t *testing.T) {
scenarios := []tests.ApiScenario{
{
Method: http.MethodGet,
URL: "/api/realtime",
Timeout: 100 * time.Millisecond,
ExpectedStatus: 200,
ExpectedContent: []string{
`id:`,
`event:PB_CONNECT`,
`data:{"clientId":`,
},
ExpectedEvents: map[string]int{
"*": 0,
"OnRealtimeConnectRequest": 1,
"OnRealtimeMessageSend": 1,
},
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
if len(app.SubscriptionsBroker().Clients()) != 0 {
t.Errorf("Expected the subscribers to be removed after connection close, found %d", len(app.SubscriptionsBroker().Clients()))
}
},
},
{
Name: "PB_CONNECT interrupt",
Method: http.MethodGet,
URL: "/api/realtime",
Timeout: 100 * time.Millisecond,
ExpectedStatus: 200,
ExpectedEvents: map[string]int{
"*": 0,
"OnRealtimeConnectRequest": 1,
"OnRealtimeMessageSend": 1,
},
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
app.OnRealtimeMessageSend().BindFunc(func(e *core.RealtimeMessageEvent) error {
if e.Message.Name == "PB_CONNECT" {
return errors.New("PB_CONNECT error")
}
return e.Next()
})
},
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
if len(app.SubscriptionsBroker().Clients()) != 0 {
t.Errorf("Expected the subscribers to be removed after connection close, found %d", len(app.SubscriptionsBroker().Clients()))
}
},
},
{
Name: "Skipping/ignoring messages",
Method: http.MethodGet,
URL: "/api/realtime",
Timeout: 100 * time.Millisecond,
ExpectedStatus: 200,
ExpectedEvents: map[string]int{
"*": 0,
"OnRealtimeConnectRequest": 1,
"OnRealtimeMessageSend": 1,
},
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
app.OnRealtimeMessageSend().BindFunc(func(e *core.RealtimeMessageEvent) error {
return nil
})
},
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
if len(app.SubscriptionsBroker().Clients()) != 0 {
t.Errorf("Expected the subscribers to be removed after connection close, found %d", len(app.SubscriptionsBroker().Clients()))
}
},
},
}
for _, scenario := range scenarios {
scenario.Test(t)
}
}
func TestRealtimeSubscribe(t *testing.T) {
client := subscriptions.NewDefaultClient()
resetClient := func() {
client.Unsubscribe()
client.Set(apis.RealtimeClientAuthKey, nil)
}
validSubscriptionsLimit := make([]string, 1000)
for i := 0; i < len(validSubscriptionsLimit); i++ {
validSubscriptionsLimit[i] = fmt.Sprintf(`"%d"`, i)
}
invalidSubscriptionsLimit := make([]string, 1001)
for i := 0; i < len(invalidSubscriptionsLimit); i++ {
invalidSubscriptionsLimit[i] = fmt.Sprintf(`"%d"`, i)
}
scenarios := []tests.ApiScenario{
{
Name: "missing client",
Method: http.MethodPost,
URL: "/api/realtime",
Body: strings.NewReader(`{"clientId":"missing","subscriptions":["test1", "test2"]}`),
ExpectedStatus: 404,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "empty data",
Method: http.MethodPost,
URL: "/api/realtime",
Body: strings.NewReader(`{}`),
ExpectedStatus: 400,
ExpectedContent: []string{
`"data":{`,
`"clientId":{"code":"validation_required`,
},
NotExpectedContent: []string{
`"subscriptions"`,
},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "existing client with invalid subscriptions limit",
Method: http.MethodPost,
URL: "/api/realtime",
Body: strings.NewReader(`{
"clientId": "` + client.Id() + `",
"subscriptions": [` + strings.Join(invalidSubscriptionsLimit, ",") + `]
}`),
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
app.SubscriptionsBroker().Register(client)
},
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
resetClient()
},
ExpectedStatus: 400,
ExpectedContent: []string{
`"data":{`,
`"subscriptions":{"code":"validation_length_too_long"`,
},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "existing client with valid subscriptions limit",
Method: http.MethodPost,
URL: "/api/realtime",
Body: strings.NewReader(`{
"clientId": "` + client.Id() + `",
"subscriptions": [` + strings.Join(validSubscriptionsLimit, ",") + `]
}`),
ExpectedStatus: 204,
ExpectedEvents: map[string]int{
"*": 0,
"OnRealtimeSubscribeRequest": 1,
},
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
client.Subscribe("test0") // should be replaced
app.SubscriptionsBroker().Register(client)
},
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
if len(client.Subscriptions()) != len(validSubscriptionsLimit) {
t.Errorf("Expected %d subscriptions, got %d", len(validSubscriptionsLimit), len(client.Subscriptions()))
}
if client.HasSubscription("test0") {
t.Errorf("Expected old subscriptions to be replaced")
}
resetClient()
},
},
{
Name: "existing client with invalid topic length",
Method: http.MethodPost,
URL: "/api/realtime",
Body: strings.NewReader(`{
"clientId": "` + client.Id() + `",
"subscriptions": ["abc", "` + strings.Repeat("a", 2501) + `"]
}`),
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
app.SubscriptionsBroker().Register(client)
},
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
resetClient()
},
ExpectedStatus: 400,
ExpectedContent: []string{
`"data":{`,
`"subscriptions":{"1":{"code":"validation_length_too_long"`,
},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "existing client with valid topic length",
Method: http.MethodPost,
URL: "/api/realtime",
Body: strings.NewReader(`{
"clientId": "` + client.Id() + `",
"subscriptions": ["abc", "` + strings.Repeat("a", 2500) + `"]
}`),
ExpectedStatus: 204,
ExpectedEvents: map[string]int{
"*": 0,
"OnRealtimeSubscribeRequest": 1,
},
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
client.Subscribe("test0")
app.SubscriptionsBroker().Register(client)
},
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
if len(client.Subscriptions()) != 2 {
t.Errorf("Expected %d subscriptions, got %d", 2, len(client.Subscriptions()))
}
if client.HasSubscription("test0") {
t.Errorf("Expected old subscriptions to be replaced")
}
resetClient()
},
},
{
Name: "existing client - empty subscriptions",
Method: http.MethodPost,
URL: "/api/realtime",
Body: strings.NewReader(`{"clientId":"` + client.Id() + `","subscriptions":[]}`),
ExpectedStatus: 204,
ExpectedEvents: map[string]int{
"*": 0,
"OnRealtimeSubscribeRequest": 1,
},
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
client.Subscribe("test0")
app.SubscriptionsBroker().Register(client)
},
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
if len(client.Subscriptions()) != 0 {
t.Errorf("Expected no subscriptions, got %d", len(client.Subscriptions()))
}
resetClient()
},
},
{
Name: "existing client - 2 new subscriptions",
Method: http.MethodPost,
URL: "/api/realtime",
Body: strings.NewReader(`{"clientId":"` + client.Id() + `","subscriptions":["test1", "test2"]}`),
ExpectedStatus: 204,
ExpectedEvents: map[string]int{
"*": 0,
"OnRealtimeSubscribeRequest": 1,
},
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
client.Subscribe("test0")
app.SubscriptionsBroker().Register(client)
},
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
expectedSubs := []string{"test1", "test2"}
if len(expectedSubs) != len(client.Subscriptions()) {
t.Errorf("Expected subscriptions %v, got %v", expectedSubs, client.Subscriptions())
}
for _, s := range expectedSubs {
if !client.HasSubscription(s) {
t.Errorf("Cannot find %q subscription in %v", s, client.Subscriptions())
}
}
resetClient()
},
},
{
Name: "existing client - guest -> authorized superuser",
Method: http.MethodPost,
URL: "/api/realtime",
Body: strings.NewReader(`{"clientId":"` + client.Id() + `","subscriptions":["test1", "test2"]}`),
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
},
ExpectedStatus: 204,
ExpectedEvents: map[string]int{
"*": 0,
"OnRealtimeSubscribeRequest": 1,
},
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
app.SubscriptionsBroker().Register(client)
},
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
authRecord, _ := client.Get(apis.RealtimeClientAuthKey).(*core.Record)
if authRecord == nil || !authRecord.IsSuperuser() {
t.Errorf("Expected superuser auth record, got %v", authRecord)
}
resetClient()
},
},
{
Name: "existing client - guest -> authorized regular auth record",
Method: http.MethodPost,
URL: "/api/realtime",
Body: strings.NewReader(`{"clientId":"` + client.Id() + `","subscriptions":["test1", "test2"]}`),
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
},
ExpectedStatus: 204,
ExpectedEvents: map[string]int{
"*": 0,
"OnRealtimeSubscribeRequest": 1,
},
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
app.SubscriptionsBroker().Register(client)
},
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
authRecord, _ := client.Get(apis.RealtimeClientAuthKey).(*core.Record)
if authRecord == nil {
t.Errorf("Expected regular user auth record, got %v", authRecord)
}
resetClient()
},
},
{
Name: "existing client - same auth",
Method: http.MethodPost,
URL: "/api/realtime",
Body: strings.NewReader(`{"clientId":"` + client.Id() + `","subscriptions":["test1", "test2"]}`),
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
},
ExpectedStatus: 204,
ExpectedEvents: map[string]int{
"*": 0,
"OnRealtimeSubscribeRequest": 1,
},
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
// the same user as the auth token
user, err := app.FindAuthRecordByEmail("users", "test@example.com")
if err != nil {
t.Fatal(err)
}
client.Set(apis.RealtimeClientAuthKey, user)
app.SubscriptionsBroker().Register(client)
},
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
authRecord, _ := client.Get(apis.RealtimeClientAuthKey).(*core.Record)
if authRecord == nil {
t.Errorf("Expected auth record model, got nil")
}
resetClient()
},
},
{
Name: "existing client - mismatched auth",
Method: http.MethodPost,
URL: "/api/realtime",
Body: strings.NewReader(`{"clientId":"` + client.Id() + `","subscriptions":["test1", "test2"]}`),
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
},
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
user, err := app.FindAuthRecordByEmail("users", "test2@example.com")
if err != nil {
t.Fatal(err)
}
client.Set(apis.RealtimeClientAuthKey, user)
app.SubscriptionsBroker().Register(client)
},
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
authRecord, _ := client.Get(apis.RealtimeClientAuthKey).(*core.Record)
if authRecord == nil {
t.Errorf("Expected auth record model, got nil")
}
resetClient()
},
},
{
Name: "existing client - unauthorized client",
Method: http.MethodPost,
URL: "/api/realtime",
Body: strings.NewReader(`{"clientId":"` + client.Id() + `","subscriptions":["test1", "test2"]}`),
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
user, err := app.FindAuthRecordByEmail("users", "test2@example.com")
if err != nil {
t.Fatal(err)
}
client.Set(apis.RealtimeClientAuthKey, user)
app.SubscriptionsBroker().Register(client)
},
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
authRecord, _ := client.Get(apis.RealtimeClientAuthKey).(*core.Record)
if authRecord == nil {
t.Errorf("Expected auth record model, got nil")
}
resetClient()
},
},
}
for _, scenario := range scenarios {
scenario.Test(t)
}
}
func TestRealtimeAuthRecordDeleteEvent(t *testing.T) {
testApp, _ := tests.NewTestApp()
defer testApp.Cleanup()
// init realtime handlers
apis.NewRouter(testApp)
authRecord1, err := testApp.FindAuthRecordByEmail("users", "test@example.com")
if err != nil {
t.Fatal(err)
}
authRecord2, err := testApp.FindAuthRecordByEmail("users", "test2@example.com")
if err != nil {
t.Fatal(err)
}
client1 := subscriptions.NewDefaultClient()
client1.Set(apis.RealtimeClientAuthKey, authRecord1)
testApp.SubscriptionsBroker().Register(client1)
client2 := subscriptions.NewDefaultClient()
client2.Set(apis.RealtimeClientAuthKey, authRecord1)
testApp.SubscriptionsBroker().Register(client2)
client3 := subscriptions.NewDefaultClient()
client3.Set(apis.RealtimeClientAuthKey, authRecord2)
testApp.SubscriptionsBroker().Register(client3)
// mock delete event
e := new(core.ModelEvent)
e.App = testApp
e.Type = core.ModelEventTypeDelete
e.Context = context.Background()
e.Model = authRecord1
testApp.OnModelAfterDeleteSuccess().Trigger(e)
if total := len(testApp.SubscriptionsBroker().Clients()); total != 3 {
t.Fatalf("Expected %d subscription clients, found %d", 3, total)
}
if auth := client1.Get(apis.RealtimeClientAuthKey); auth != nil {
t.Fatalf("[client1] Expected the auth state to be unset, found %#v", auth)
}
if auth := client2.Get(apis.RealtimeClientAuthKey); auth != nil {
t.Fatalf("[client2] Expected the auth state to be unset, found %#v", auth)
}
if auth := client3.Get(apis.RealtimeClientAuthKey); auth == nil || auth.(*core.Record).Id != authRecord2.Id {
t.Fatalf("[client3] Expected the auth state to be left unchanged, found %#v", auth)
}
}
func TestRealtimeAuthRecordUpdateEvent(t *testing.T) {
testApp, _ := tests.NewTestApp()
defer testApp.Cleanup()
// init realtime handlers
apis.NewRouter(testApp)
authRecord1, err := testApp.FindAuthRecordByEmail("users", "test@example.com")
if err != nil {
t.Fatal(err)
}
client := subscriptions.NewDefaultClient()
client.Set(apis.RealtimeClientAuthKey, authRecord1)
testApp.SubscriptionsBroker().Register(client)
// refetch the authRecord and change its email
authRecord2, err := testApp.FindAuthRecordByEmail("users", "test@example.com")
if err != nil {
t.Fatal(err)
}
authRecord2.SetEmail("new@example.com")
// mock update event
e := new(core.ModelEvent)
e.App = testApp
e.Type = core.ModelEventTypeUpdate
e.Context = context.Background()
e.Model = authRecord2
testApp.OnModelAfterUpdateSuccess().Trigger(e)
clientAuthRecord, _ := client.Get(apis.RealtimeClientAuthKey).(*core.Record)
if clientAuthRecord.Email() != authRecord2.Email() {
t.Fatalf("Expected authRecord with email %q, got %q", authRecord2.Email(), clientAuthRecord.Email())
}
}
// Custom auth record model struct
// -------------------------------------------------------------------
var _ core.Model = (*CustomUser)(nil)
type CustomUser struct {
core.BaseModel
Email string `db:"email" json:"email"`
}
func (m *CustomUser) TableName() string {
return "users"
}
func findCustomUserByEmail(app core.App, email string) (*CustomUser, error) {
model := &CustomUser{}
err := app.ModelQuery(model).
AndWhere(dbx.HashExp{"email": email}).
Limit(1).
One(model)
if err != nil {
return nil, err
}
return model, nil
}
func TestRealtimeCustomAuthModelDeleteEvent(t *testing.T) {
testApp, _ := tests.NewTestApp()
defer testApp.Cleanup()
// init realtime handlers
apis.NewRouter(testApp)
authRecord1, err := testApp.FindAuthRecordByEmail("users", "test@example.com")
if err != nil {
t.Fatal(err)
}
authRecord2, err := testApp.FindAuthRecordByEmail("users", "test2@example.com")
if err != nil {
t.Fatal(err)
}
client1 := subscriptions.NewDefaultClient()
client1.Set(apis.RealtimeClientAuthKey, authRecord1)
testApp.SubscriptionsBroker().Register(client1)
client2 := subscriptions.NewDefaultClient()
client2.Set(apis.RealtimeClientAuthKey, authRecord1)
testApp.SubscriptionsBroker().Register(client2)
client3 := subscriptions.NewDefaultClient()
client3.Set(apis.RealtimeClientAuthKey, authRecord2)
testApp.SubscriptionsBroker().Register(client3)
// refetch the authRecord as CustomUser
customUser, err := findCustomUserByEmail(testApp, authRecord1.Email())
if err != nil {
t.Fatal(err)
}
// delete the custom user (should unset the client auth record)
if err := testApp.Delete(customUser); err != nil {
t.Fatal(err)
}
if total := len(testApp.SubscriptionsBroker().Clients()); total != 3 {
t.Fatalf("Expected %d subscription clients, found %d", 3, total)
}
if auth := client1.Get(apis.RealtimeClientAuthKey); auth != nil {
t.Fatalf("[client1] Expected the auth state to be unset, found %#v", auth)
}
if auth := client2.Get(apis.RealtimeClientAuthKey); auth != nil {
t.Fatalf("[client2] Expected the auth state to be unset, found %#v", auth)
}
if auth := client3.Get(apis.RealtimeClientAuthKey); auth == nil || auth.(*core.Record).Id != authRecord2.Id {
t.Fatalf("[client3] Expected the auth state to be left unchanged, found %#v", auth)
}
}
func TestRealtimeCustomAuthModelUpdateEvent(t *testing.T) {
testApp, _ := tests.NewTestApp()
defer testApp.Cleanup()
// init realtime handlers
apis.NewRouter(testApp)
authRecord, err := testApp.FindAuthRecordByEmail("users", "test@example.com")
if err != nil {
t.Fatal(err)
}
client := subscriptions.NewDefaultClient()
client.Set(apis.RealtimeClientAuthKey, authRecord)
testApp.SubscriptionsBroker().Register(client)
// refetch the authRecord as CustomUser
customUser, err := findCustomUserByEmail(testApp, "test@example.com")
if err != nil {
t.Fatal(err)
}
// change its email
customUser.Email = "new@example.com"
if err := testApp.Save(customUser); err != nil {
t.Fatal(err)
}
clientAuthRecord, _ := client.Get(apis.RealtimeClientAuthKey).(*core.Record)
if clientAuthRecord.Email() != customUser.Email {
t.Fatalf("Expected authRecord with email %q, got %q", customUser.Email, clientAuthRecord.Email())
}
}
// -------------------------------------------------------------------
var _ core.Model = (*CustomModelResolve)(nil)
type CustomModelResolve struct {
core.BaseModel
tableName string
Created string `db:"created"`
}
func (m *CustomModelResolve) TableName() string {
return m.tableName
}
func TestRealtimeRecordResolve(t *testing.T) {
t.Parallel()
const testCollectionName = "realtime_test_collection"
testRecordId := core.GenerateDefaultRandomId()
client0 := subscriptions.NewDefaultClient()
client0.Subscribe(testCollectionName + "/*")
client0.Discard()
// ---
client1 := subscriptions.NewDefaultClient()
client1.Subscribe(testCollectionName + "/*")
// ---
client2 := subscriptions.NewDefaultClient()
client2.Subscribe(testCollectionName + "/" + testRecordId)
// ---
client3 := subscriptions.NewDefaultClient()
client3.Subscribe("demo1/*")
scenarios := []struct {
name string
op func(testApp core.App) error
expected map[string][]string // clientId -> [events]
}{
{
"core.Record",
func(testApp core.App) error {
c, err := testApp.FindCollectionByNameOrId(testCollectionName)
if err != nil {
return err
}
r := core.NewRecord(c)
r.Id = testRecordId
// create
err = testApp.Save(r)
if err != nil {
return err
}
// update
err = testApp.Save(r)
if err != nil {
return err
}
// delete
err = testApp.Delete(r)
if err != nil {
return err
}
return nil
},
map[string][]string{
client1.Id(): {"create", "update", "delete"},
client2.Id(): {"create", "update", "delete"},
},
},
{
"core.RecordProxy",
func(testApp core.App) error {
c, err := testApp.FindCollectionByNameOrId(testCollectionName)
if err != nil {
return err
}
r := core.NewRecord(c)
proxy := &struct {
core.BaseRecordProxy
}{}
proxy.SetProxyRecord(r)
proxy.Id = testRecordId
// create
err = testApp.Save(proxy)
if err != nil {
return err
}
// update
err = testApp.Save(proxy)
if err != nil {
return err
}
// delete
err = testApp.Delete(proxy)
if err != nil {
return err
}
return nil
},
map[string][]string{
client1.Id(): {"create", "update", "delete"},
client2.Id(): {"create", "update", "delete"},
},
},
{
"custom model struct",
func(testApp core.App) error {
m := &CustomModelResolve{tableName: testCollectionName}
m.Id = testRecordId
// create
err := testApp.Save(m)
if err != nil {
return err
}
// update
m.Created = "123"
err = testApp.Save(m)
if err != nil {
return err
}
// delete
err = testApp.Delete(m)
if err != nil {
return err
}
return nil
},
map[string][]string{
client1.Id(): {"create", "update", "delete"},
client2.Id(): {"create", "update", "delete"},
},
},
}
for _, s := range scenarios {
t.Run(s.name, func(t *testing.T) {
testApp, _ := tests.NewTestApp()
defer testApp.Cleanup()
// init realtime handlers
apis.NewRouter(testApp)
// create new test collection with public read access
testCollection := core.NewBaseCollection(testCollectionName)
testCollection.Fields.Add(&core.AutodateField{Name: "created", OnCreate: true, OnUpdate: true})
testCollection.ListRule = types.Pointer("")
testCollection.ViewRule = types.Pointer("")
err := testApp.Save(testCollection)
if err != nil {
t.Fatal(err)
}
testApp.SubscriptionsBroker().Register(client0)
testApp.SubscriptionsBroker().Register(client1)
testApp.SubscriptionsBroker().Register(client2)
testApp.SubscriptionsBroker().Register(client3)
var wg sync.WaitGroup
var notifications = map[string][]string{}
var mu sync.Mutex
notify := func(clientId string, eventData []byte) {
data := struct{ Action string }{}
_ = json.Unmarshal(eventData, &data)
mu.Lock()
defer mu.Unlock()
if notifications[clientId] == nil {
notifications[clientId] = []string{}
}
notifications[clientId] = append(notifications[clientId], data.Action)
}
wg.Add(1)
go func() {
defer wg.Done()
timeout := time.After(250 * time.Millisecond)
for {
select {
case e, ok := <-client0.Channel():
if ok {
notify(client0.Id(), e.Data)
}
case e, ok := <-client1.Channel():
if ok {
notify(client1.Id(), e.Data)
}
case e, ok := <-client2.Channel():
if ok {
notify(client2.Id(), e.Data)
}
case e, ok := <-client3.Channel():
if ok {
notify(client3.Id(), e.Data)
}
case <-timeout:
return
}
}
}()
err = s.op(testApp)
if err != nil {
t.Fatal(err)
}
wg.Wait()
if len(s.expected) != len(notifications) {
t.Fatalf("Expected %d notified clients, got %d:\n%v", len(s.expected), len(notifications), notifications)
}
for id, events := range s.expected {
if len(events) != len(notifications[id]) {
t.Fatalf("[%s] Expected %d events, got %d:\n%v\n%v", id, len(events), len(notifications[id]), s.expected, notifications)
}
for _, event := range events {
if !slices.Contains(notifications[id], event) {
t.Fatalf("[%s] Missing expected event %q in %v", id, event, notifications[id])
}
}
}
})
}
}

79
apis/record_auth.go Normal file
View file

@ -0,0 +1,79 @@
package apis
import (
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tools/router"
)
// bindRecordAuthApi registers the auth record api endpoints and
// the corresponding handlers.
func bindRecordAuthApi(app core.App, rg *router.RouterGroup[*core.RequestEvent]) {
// global oauth2 subscription redirect handler
rg.GET("/oauth2-redirect", oauth2SubscriptionRedirect).Bind(
SkipSuccessActivityLog(), // skip success log as it could contain sensitive information in the url
)
// add again as POST in case of response_mode=form_post
rg.POST("/oauth2-redirect", oauth2SubscriptionRedirect).Bind(
SkipSuccessActivityLog(), // skip success log as it could contain sensitive information in the url
)
sub := rg.Group("/collections/{collection}")
sub.GET("/auth-methods", recordAuthMethods).Bind(
collectionPathRateLimit("", "listAuthMethods"),
)
sub.POST("/auth-refresh", recordAuthRefresh).Bind(
collectionPathRateLimit("", "authRefresh"),
RequireSameCollectionContextAuth(""),
)
sub.POST("/auth-with-password", recordAuthWithPassword).Bind(
collectionPathRateLimit("", "authWithPassword", "auth"),
)
sub.POST("/auth-with-oauth2", recordAuthWithOAuth2).Bind(
collectionPathRateLimit("", "authWithOAuth2", "auth"),
)
sub.POST("/request-otp", recordRequestOTP).Bind(
collectionPathRateLimit("", "requestOTP"),
)
sub.POST("/auth-with-otp", recordAuthWithOTP).Bind(
collectionPathRateLimit("", "authWithOTP", "auth"),
)
sub.POST("/request-password-reset", recordRequestPasswordReset).Bind(
collectionPathRateLimit("", "requestPasswordReset"),
)
sub.POST("/confirm-password-reset", recordConfirmPasswordReset).Bind(
collectionPathRateLimit("", "confirmPasswordReset"),
)
sub.POST("/request-verification", recordRequestVerification).Bind(
collectionPathRateLimit("", "requestVerification"),
)
sub.POST("/confirm-verification", recordConfirmVerification).Bind(
collectionPathRateLimit("", "confirmVerification"),
)
sub.POST("/request-email-change", recordRequestEmailChange).Bind(
collectionPathRateLimit("", "requestEmailChange"),
RequireSameCollectionContextAuth(""),
)
sub.POST("/confirm-email-change", recordConfirmEmailChange).Bind(
collectionPathRateLimit("", "confirmEmailChange"),
)
sub.POST("/impersonate/{id}", recordAuthImpersonate).Bind(RequireSuperuserAuth())
}
func findAuthCollection(e *core.RequestEvent) (*core.Collection, error) {
collection, err := e.App.FindCachedCollectionByNameOrId(e.Request.PathValue("collection"))
if err != nil || !collection.IsAuth() {
return nil, e.NotFoundError("Missing or invalid auth collection context.", err)
}
return collection, nil
}

View file

@ -0,0 +1,122 @@
package apis
import (
"net/http"
validation "github.com/go-ozzo/ozzo-validation/v4"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tools/security"
)
func recordConfirmEmailChange(e *core.RequestEvent) error {
collection, err := findAuthCollection(e)
if err != nil {
return err
}
if collection.Name == core.CollectionNameSuperusers {
return e.BadRequestError("All superusers can change their emails directly.", nil)
}
form := newEmailChangeConfirmForm(e.App, collection)
if err = e.BindBody(form); err != nil {
return firstApiError(err, e.BadRequestError("An error occurred while loading the submitted data.", err))
}
if err = form.validate(); err != nil {
return firstApiError(err, e.BadRequestError("An error occurred while validating the submitted data.", err))
}
authRecord, newEmail, err := form.parseToken()
if err != nil {
return firstApiError(err, e.BadRequestError("Invalid or expired token.", err))
}
event := new(core.RecordConfirmEmailChangeRequestEvent)
event.RequestEvent = e
event.Collection = collection
event.Record = authRecord
event.NewEmail = newEmail
return e.App.OnRecordConfirmEmailChangeRequest().Trigger(event, func(e *core.RecordConfirmEmailChangeRequestEvent) error {
e.Record.SetEmail(e.NewEmail)
e.Record.SetVerified(true)
if err := e.App.Save(e.Record); err != nil {
return firstApiError(err, e.BadRequestError("Failed to confirm email change.", err))
}
return execAfterSuccessTx(true, e.App, func() error {
return e.NoContent(http.StatusNoContent)
})
})
}
// -------------------------------------------------------------------
func newEmailChangeConfirmForm(app core.App, collection *core.Collection) *EmailChangeConfirmForm {
return &EmailChangeConfirmForm{
app: app,
collection: collection,
}
}
type EmailChangeConfirmForm struct {
app core.App
collection *core.Collection
Token string `form:"token" json:"token"`
Password string `form:"password" json:"password"`
}
func (form *EmailChangeConfirmForm) validate() error {
return validation.ValidateStruct(form,
validation.Field(&form.Token, validation.Required, validation.By(form.checkToken)),
validation.Field(&form.Password, validation.Required, validation.Length(1, 100), validation.By(form.checkPassword)),
)
}
func (form *EmailChangeConfirmForm) checkToken(value any) error {
_, _, err := form.parseToken()
return err
}
func (form *EmailChangeConfirmForm) checkPassword(value any) error {
v, _ := value.(string)
if v == "" {
return nil // nothing to check
}
authRecord, _, _ := form.parseToken()
if authRecord == nil || !authRecord.ValidatePassword(v) {
return validation.NewError("validation_invalid_password", "Missing or invalid auth record password.")
}
return nil
}
func (form *EmailChangeConfirmForm) parseToken() (*core.Record, string, error) {
// check token payload
claims, _ := security.ParseUnverifiedJWT(form.Token)
newEmail, _ := claims[core.TokenClaimNewEmail].(string)
if newEmail == "" {
return nil, "", validation.NewError("validation_invalid_token_payload", "Invalid token payload - newEmail must be set.")
}
// ensure that there aren't other users with the new email
_, err := form.app.FindAuthRecordByEmail(form.collection, newEmail)
if err == nil {
return nil, "", validation.NewError("validation_existing_token_email", "The new email address is already registered: "+newEmail)
}
// verify that the token is not expired and its signature is valid
authRecord, err := form.app.FindAuthRecordByToken(form.Token, core.TokenTypeEmailChange)
if err != nil {
return nil, "", validation.NewError("validation_invalid_token", "Invalid or expired token.")
}
if authRecord.Collection().Id != form.collection.Id {
return nil, "", validation.NewError("validation_token_collection_mismatch", "The provided token is for different auth collection.")
}
return authRecord, newEmail, nil
}

View file

@ -0,0 +1,211 @@
package apis_test
import (
"net/http"
"strings"
"testing"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tests"
)
func TestRecordConfirmEmailChange(t *testing.T) {
t.Parallel()
scenarios := []tests.ApiScenario{
{
Name: "not an auth collection",
Method: http.MethodPost,
URL: "/api/collections/demo1/confirm-email-change",
ExpectedStatus: 404,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "empty data",
Method: http.MethodPost,
URL: "/api/collections/users/confirm-email-change",
Body: strings.NewReader(``),
ExpectedStatus: 400,
ExpectedContent: []string{
`"data":`,
`"token":{"code":"validation_required"`,
`"password":{"code":"validation_required"`,
},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "invalid data",
Method: http.MethodPost,
URL: "/api/collections/users/confirm-email-change",
Body: strings.NewReader(`{"token`),
ExpectedStatus: 400,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "expired token and correct password",
Method: http.MethodPost,
URL: "/api/collections/users/confirm-email-change",
Body: strings.NewReader(`{
"token":"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImNvbGxlY3Rpb25JZCI6Il9wYl91c2Vyc19hdXRoXyIsInR5cGUiOiJlbWFpbENoYW5nZSIsImVtYWlsIjoidGVzdEBleGFtcGxlLmNvbSIsIm5ld0VtYWlsIjoiY2hhbmdlQGV4YW1wbGUuY29tIiwiZXhwIjoxNjQwOTkxNjYxfQ.dff842MO0mgRTHY8dktp0dqG9-7LGQOgRuiAbQpYBls",
"password":"1234567890"
}`),
ExpectedStatus: 400,
ExpectedContent: []string{
`"data":{`,
`"token":{`,
`"code":"validation_invalid_token"`,
},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "non-email change token",
Method: http.MethodPost,
URL: "/api/collections/users/confirm-email-change",
Body: strings.NewReader(`{
"token":"eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
"password":"1234567890"
}`),
ExpectedStatus: 400,
ExpectedContent: []string{
`"data":{`,
`"token":{`,
`"code":"validation_invalid_token_payload"`,
},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "valid token and incorrect password",
Method: http.MethodPost,
URL: "/api/collections/users/confirm-email-change",
Body: strings.NewReader(`{
"token":"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImNvbGxlY3Rpb25JZCI6Il9wYl91c2Vyc19hdXRoXyIsInR5cGUiOiJlbWFpbENoYW5nZSIsImVtYWlsIjoidGVzdEBleGFtcGxlLmNvbSIsIm5ld0VtYWlsIjoiY2hhbmdlQGV4YW1wbGUuY29tIiwiZXhwIjoyNTI0NjA0NDYxfQ.Y7mVlaEPhJiNPoIvIqbIosZU4c4lEhwysOrRR8c95iU",
"password":"1234567891"
}`),
ExpectedStatus: 400,
ExpectedContent: []string{
`"data":{`,
`"password":{`,
`"code":"validation_invalid_password"`,
},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "valid token and correct password",
Method: http.MethodPost,
URL: "/api/collections/users/confirm-email-change",
Body: strings.NewReader(`{
"token":"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImNvbGxlY3Rpb25JZCI6Il9wYl91c2Vyc19hdXRoXyIsInR5cGUiOiJlbWFpbENoYW5nZSIsImVtYWlsIjoidGVzdEBleGFtcGxlLmNvbSIsIm5ld0VtYWlsIjoiY2hhbmdlQGV4YW1wbGUuY29tIiwiZXhwIjoyNTI0NjA0NDYxfQ.Y7mVlaEPhJiNPoIvIqbIosZU4c4lEhwysOrRR8c95iU",
"password":"1234567890"
}`),
ExpectedStatus: 204,
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordConfirmEmailChangeRequest": 1,
"OnModelUpdate": 1,
"OnModelUpdateExecute": 1,
"OnModelAfterUpdateSuccess": 1,
"OnModelValidate": 1,
"OnRecordUpdate": 1,
"OnRecordUpdateExecute": 1,
"OnRecordAfterUpdateSuccess": 1,
"OnRecordValidate": 1,
},
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
_, err := app.FindAuthRecordByEmail("users", "change@example.com")
if err != nil {
t.Fatalf("Expected to find user with email %q, got error: %v", "change@example.com", err)
}
},
},
{
Name: "valid token in different auth collection",
Method: http.MethodPost,
URL: "/api/collections/clients/confirm-email-change",
Body: strings.NewReader(`{
"token":"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImNvbGxlY3Rpb25JZCI6Il9wYl91c2Vyc19hdXRoXyIsInR5cGUiOiJlbWFpbENoYW5nZSIsImVtYWlsIjoidGVzdEBleGFtcGxlLmNvbSIsIm5ld0VtYWlsIjoiY2hhbmdlQGV4YW1wbGUuY29tIiwiZXhwIjoyNTI0NjA0NDYxfQ.Y7mVlaEPhJiNPoIvIqbIosZU4c4lEhwysOrRR8c95iU",
"password":"1234567890"
}`),
ExpectedStatus: 400,
ExpectedContent: []string{
`"data":{`,
`"token":{"code":"validation_token_collection_mismatch"`,
},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "OnRecordConfirmEmailChangeRequest tx body write check",
Method: http.MethodPost,
URL: "/api/collections/users/confirm-email-change",
Body: strings.NewReader(`{
"token":"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImNvbGxlY3Rpb25JZCI6Il9wYl91c2Vyc19hdXRoXyIsInR5cGUiOiJlbWFpbENoYW5nZSIsImVtYWlsIjoidGVzdEBleGFtcGxlLmNvbSIsIm5ld0VtYWlsIjoiY2hhbmdlQGV4YW1wbGUuY29tIiwiZXhwIjoyNTI0NjA0NDYxfQ.Y7mVlaEPhJiNPoIvIqbIosZU4c4lEhwysOrRR8c95iU",
"password":"1234567890"
}`),
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
app.OnRecordConfirmEmailChangeRequest().BindFunc(func(e *core.RecordConfirmEmailChangeRequestEvent) error {
original := e.App
return e.App.RunInTransaction(func(txApp core.App) error {
e.App = txApp
defer func() { e.App = original }()
if err := e.Next(); err != nil {
return err
}
return e.BadRequestError("TX_ERROR", nil)
})
})
},
ExpectedStatus: 400,
ExpectedEvents: map[string]int{"OnRecordConfirmEmailChangeRequest": 1},
ExpectedContent: []string{"TX_ERROR"},
},
// rate limit checks
// -----------------------------------------------------------
{
Name: "RateLimit rule - users:confirmEmailChange",
Method: http.MethodPost,
URL: "/api/collections/users/confirm-email-change",
Body: strings.NewReader(`{
"token":"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImNvbGxlY3Rpb25JZCI6Il9wYl91c2Vyc19hdXRoXyIsInR5cGUiOiJlbWFpbENoYW5nZSIsImVtYWlsIjoidGVzdEBleGFtcGxlLmNvbSIsIm5ld0VtYWlsIjoiY2hhbmdlQGV4YW1wbGUuY29tIiwiZXhwIjoyNTI0NjA0NDYxfQ.Y7mVlaEPhJiNPoIvIqbIosZU4c4lEhwysOrRR8c95iU",
"password":"1234567890"
}`),
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
app.Settings().RateLimits.Enabled = true
app.Settings().RateLimits.Rules = []core.RateLimitRule{
{MaxRequests: 100, Label: "abc"},
{MaxRequests: 100, Label: "*:confirmEmailChange"},
{MaxRequests: 0, Label: "users:confirmEmailChange"},
}
},
ExpectedStatus: 429,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "RateLimit rule - *:confirmEmailChange",
Method: http.MethodPost,
URL: "/api/collections/users/confirm-email-change",
Body: strings.NewReader(`{
"token":"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImNvbGxlY3Rpb25JZCI6Il9wYl91c2Vyc19hdXRoXyIsInR5cGUiOiJlbWFpbENoYW5nZSIsImVtYWlsIjoidGVzdEBleGFtcGxlLmNvbSIsIm5ld0VtYWlsIjoiY2hhbmdlQGV4YW1wbGUuY29tIiwiZXhwIjoyNTI0NjA0NDYxfQ.Y7mVlaEPhJiNPoIvIqbIosZU4c4lEhwysOrRR8c95iU",
"password":"1234567890"
}`),
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
app.Settings().RateLimits.Enabled = true
app.Settings().RateLimits.Rules = []core.RateLimitRule{
{MaxRequests: 100, Label: "abc"},
{MaxRequests: 0, Label: "*:confirmEmailChange"},
}
},
ExpectedStatus: 429,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
}
for _, scenario := range scenarios {
scenario.Test(t)
}
}

View file

@ -0,0 +1,92 @@
package apis
import (
"net/http"
validation "github.com/go-ozzo/ozzo-validation/v4"
"github.com/go-ozzo/ozzo-validation/v4/is"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/mails"
)
func recordRequestEmailChange(e *core.RequestEvent) error {
collection, err := findAuthCollection(e)
if err != nil {
return err
}
if collection.Name == core.CollectionNameSuperusers {
return e.BadRequestError("All superusers can change their emails directly.", nil)
}
record := e.Auth
if record == nil {
return e.UnauthorizedError("The request requires valid auth record.", nil)
}
form := newEmailChangeRequestForm(e.App, record)
if err = e.BindBody(form); err != nil {
return firstApiError(err, e.BadRequestError("An error occurred while loading the submitted data.", err))
}
if err = form.validate(); err != nil {
return firstApiError(err, e.BadRequestError("An error occurred while validating the submitted data.", err))
}
event := new(core.RecordRequestEmailChangeRequestEvent)
event.RequestEvent = e
event.Collection = collection
event.Record = record
event.NewEmail = form.NewEmail
return e.App.OnRecordRequestEmailChangeRequest().Trigger(event, func(e *core.RecordRequestEmailChangeRequestEvent) error {
if err := mails.SendRecordChangeEmail(e.App, e.Record, e.NewEmail); err != nil {
return firstApiError(err, e.BadRequestError("Failed to request email change.", err))
}
return execAfterSuccessTx(true, e.App, func() error {
return e.NoContent(http.StatusNoContent)
})
})
}
// -------------------------------------------------------------------
func newEmailChangeRequestForm(app core.App, record *core.Record) *emailChangeRequestForm {
return &emailChangeRequestForm{
app: app,
record: record,
}
}
type emailChangeRequestForm struct {
app core.App
record *core.Record
NewEmail string `form:"newEmail" json:"newEmail"`
}
func (form *emailChangeRequestForm) validate() error {
return validation.ValidateStruct(form,
validation.Field(&form.NewEmail,
validation.Required,
validation.Length(1, 255),
is.EmailFormat,
validation.NotIn(form.record.Email()),
validation.By(form.checkUniqueEmail),
),
)
}
func (form *emailChangeRequestForm) checkUniqueEmail(value any) error {
v, _ := value.(string)
if v == "" {
return nil
}
found, _ := form.app.FindAuthRecordByEmail(form.record.Collection(), v)
if found != nil && found.Id != form.record.Id {
return validation.NewError("validation_invalid_new_email", "Invalid new email address.")
}
return nil
}

View file

@ -0,0 +1,195 @@
package apis_test
import (
"net/http"
"strings"
"testing"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tests"
)
func TestRecordRequestEmailChange(t *testing.T) {
t.Parallel()
scenarios := []tests.ApiScenario{
{
Name: "unauthorized",
Method: http.MethodPost,
URL: "/api/collections/users/request-email-change",
Body: strings.NewReader(`{"newEmail":"change@example.com"}`),
ExpectedStatus: 401,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "not an auth collection",
Method: http.MethodPost,
URL: "/api/collections/demo1/request-email-change",
Body: strings.NewReader(`{"newEmail":"change@example.com"}`),
ExpectedStatus: 401,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "record authentication but from different auth collection",
Method: http.MethodPost,
URL: "/api/collections/clients/request-email-change",
Body: strings.NewReader(`{"newEmail":"change@example.com"}`),
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
},
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "superuser authentication",
Method: http.MethodPost,
URL: "/api/collections/users/request-email-change",
Body: strings.NewReader(`{"newEmail":"change@example.com"}`),
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
},
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "invalid data",
Method: http.MethodPost,
URL: "/api/collections/users/request-email-change",
Body: strings.NewReader(`{"newEmail`),
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
},
ExpectedStatus: 400,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "empty data",
Method: http.MethodPost,
URL: "/api/collections/users/request-email-change",
Body: strings.NewReader(`{}`),
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
},
ExpectedStatus: 400,
ExpectedContent: []string{
`"data":`,
`"newEmail":{"code":"validation_required"`,
},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "valid data (existing email)",
Method: http.MethodPost,
URL: "/api/collections/users/request-email-change",
Body: strings.NewReader(`{"newEmail":"test2@example.com"}`),
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
},
ExpectedStatus: 400,
ExpectedContent: []string{
`"data":`,
`"newEmail":{"code":"validation_invalid_new_email"`,
},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "valid data (new email)",
Method: http.MethodPost,
URL: "/api/collections/users/request-email-change",
Body: strings.NewReader(`{"newEmail":"change@example.com"}`),
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
},
ExpectedStatus: 204,
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordRequestEmailChangeRequest": 1,
"OnMailerSend": 1,
"OnMailerRecordEmailChangeSend": 1,
},
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
if !strings.Contains(app.TestMailer.LastMessage().HTML, "/auth/confirm-email-change") {
t.Fatalf("Expected email change email, got\n%v", app.TestMailer.LastMessage().HTML)
}
},
},
{
Name: "OnRecordRequestEmailChangeRequest tx body write check",
Method: http.MethodPost,
URL: "/api/collections/users/request-email-change",
Body: strings.NewReader(`{"newEmail":"change@example.com"}`),
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
},
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
app.OnRecordRequestEmailChangeRequest().BindFunc(func(e *core.RecordRequestEmailChangeRequestEvent) error {
original := e.App
return e.App.RunInTransaction(func(txApp core.App) error {
e.App = txApp
defer func() { e.App = original }()
if err := e.Next(); err != nil {
return err
}
return e.BadRequestError("TX_ERROR", nil)
})
})
},
ExpectedStatus: 400,
ExpectedEvents: map[string]int{"OnRecordRequestEmailChangeRequest": 1},
ExpectedContent: []string{"TX_ERROR"},
},
// rate limit checks
// -----------------------------------------------------------
{
Name: "RateLimit rule - users:requestEmailChange",
Method: http.MethodPost,
URL: "/api/collections/users/request-email-change",
Body: strings.NewReader(`{"newEmail":"change@example.com"}`),
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
},
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
app.Settings().RateLimits.Enabled = true
app.Settings().RateLimits.Rules = []core.RateLimitRule{
{MaxRequests: 100, Label: "abc"},
{MaxRequests: 100, Label: "*:requestEmailChange"},
{MaxRequests: 0, Label: "users:requestEmailChange"},
}
},
ExpectedStatus: 429,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "RateLimit rule - *:requestEmailChange",
Method: http.MethodPost,
URL: "/api/collections/users/request-email-change",
Body: strings.NewReader(`{"newEmail":"change@example.com"}`),
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
},
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
app.Settings().RateLimits.Enabled = true
app.Settings().RateLimits.Rules = []core.RateLimitRule{
{MaxRequests: 100, Label: "abc"},
{MaxRequests: 0, Label: "*:requestEmailChange"},
}
},
ExpectedStatus: 429,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
}
for _, scenario := range scenarios {
scenario.Test(t)
}
}

View file

@ -0,0 +1,54 @@
package apis
import (
"time"
validation "github.com/go-ozzo/ozzo-validation/v4"
"github.com/pocketbase/pocketbase/core"
)
// note: for now allow superusers but it may change in the future to allow access
// also to users with "Manage API" rule access depending on the use cases that will arise
func recordAuthImpersonate(e *core.RequestEvent) error {
if !e.HasSuperuserAuth() {
return e.ForbiddenError("", nil)
}
collection, err := findAuthCollection(e)
if err != nil {
return err
}
record, err := e.App.FindRecordById(collection, e.Request.PathValue("id"))
if err != nil {
return e.NotFoundError("", err)
}
form := &impersonateForm{}
if err = e.BindBody(form); err != nil {
return e.BadRequestError("An error occurred while loading the submitted data.", err)
}
if err = form.validate(); err != nil {
return e.BadRequestError("An error occurred while validating the submitted data.", err)
}
token, err := record.NewStaticAuthToken(time.Duration(form.Duration) * time.Second)
if err != nil {
e.InternalServerError("Failed to generate static auth token", err)
}
return recordAuthResponse(e, record, token, "", nil)
}
// -------------------------------------------------------------------
type impersonateForm struct {
// Duration is the optional custom token duration in seconds.
Duration int64 `form:"duration" json:"duration"`
}
func (form *impersonateForm) validate() error {
return validation.ValidateStruct(form,
validation.Field(&form.Duration, validation.Min(0)),
)
}

View file

@ -0,0 +1,109 @@
package apis_test
import (
"net/http"
"strings"
"testing"
"github.com/pocketbase/pocketbase/tests"
)
func TestRecordAuthImpersonate(t *testing.T) {
t.Parallel()
scenarios := []tests.ApiScenario{
{
Name: "unauthorized",
Method: http.MethodPost,
URL: "/api/collections/users/impersonate/4q1xlclmfloku33",
ExpectedStatus: 401,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "authorized as different user",
Method: http.MethodPost,
URL: "/api/collections/users/impersonate/4q1xlclmfloku33",
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6Im9hcDY0MGNvdDR5cnUycyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.GfJo6EHIobgas_AXt-M-tj5IoQendPnrkMSe9ExuSEY",
},
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "authorized as the same user",
Method: http.MethodPost,
URL: "/api/collections/users/impersonate/4q1xlclmfloku33",
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
},
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "authorized as superuser",
Method: http.MethodPost,
URL: "/api/collections/users/impersonate/4q1xlclmfloku33",
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
},
ExpectedStatus: 200,
ExpectedContent: []string{
`"token":"`,
`"id":"4q1xlclmfloku33"`,
`"record":{`,
},
NotExpectedContent: []string{
// hidden fields should remain hidden even though we are authenticated as superuser
`"tokenKey"`,
`"password"`,
},
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordAuthRequest": 1,
"OnRecordEnrich": 1,
},
},
{
Name: "authorized as superuser with custom invalid duration",
Method: http.MethodPost,
URL: "/api/collections/users/impersonate/4q1xlclmfloku33",
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
},
Body: strings.NewReader(`{"duration":-1}`),
ExpectedStatus: 400,
ExpectedContent: []string{
`"data":{`,
`"duration":{`,
},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "authorized as superuser with custom valid duration",
Method: http.MethodPost,
URL: "/api/collections/users/impersonate/4q1xlclmfloku33",
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
},
Body: strings.NewReader(`{"duration":100}`),
ExpectedStatus: 200,
ExpectedContent: []string{
`"token":"`,
`"id":"4q1xlclmfloku33"`,
`"record":{`,
},
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordAuthRequest": 1,
"OnRecordEnrich": 1,
},
},
}
for _, scenario := range scenarios {
scenario.Test(t)
}
}

170
apis/record_auth_methods.go Normal file
View file

@ -0,0 +1,170 @@
package apis
import (
"log/slog"
"net/http"
"slices"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tools/auth"
"github.com/pocketbase/pocketbase/tools/security"
"golang.org/x/oauth2"
)
type otpResponse struct {
Enabled bool `json:"enabled"`
Duration int64 `json:"duration"` // in seconds
}
type mfaResponse struct {
Enabled bool `json:"enabled"`
Duration int64 `json:"duration"` // in seconds
}
type passwordResponse struct {
IdentityFields []string `json:"identityFields"`
Enabled bool `json:"enabled"`
}
type oauth2Response struct {
Providers []providerInfo `json:"providers"`
Enabled bool `json:"enabled"`
}
type providerInfo struct {
Name string `json:"name"`
DisplayName string `json:"displayName"`
State string `json:"state"`
AuthURL string `json:"authURL"`
// @todo
// deprecated: use AuthURL instead
// AuthUrl will be removed after dropping v0.22 support
AuthUrl string `json:"authUrl"`
// technically could be omitted if the provider doesn't support PKCE,
// but to avoid breaking existing typed clients we'll return them as empty string
CodeVerifier string `json:"codeVerifier"`
CodeChallenge string `json:"codeChallenge"`
CodeChallengeMethod string `json:"codeChallengeMethod"`
}
type authMethodsResponse struct {
Password passwordResponse `json:"password"`
OAuth2 oauth2Response `json:"oauth2"`
MFA mfaResponse `json:"mfa"`
OTP otpResponse `json:"otp"`
// legacy fields
// @todo remove after dropping v0.22 support
AuthProviders []providerInfo `json:"authProviders"`
UsernamePassword bool `json:"usernamePassword"`
EmailPassword bool `json:"emailPassword"`
}
func (amr *authMethodsResponse) fillLegacyFields() {
amr.EmailPassword = amr.Password.Enabled && slices.Contains(amr.Password.IdentityFields, "email")
amr.UsernamePassword = amr.Password.Enabled && slices.Contains(amr.Password.IdentityFields, "username")
if amr.OAuth2.Enabled {
amr.AuthProviders = amr.OAuth2.Providers
}
}
func recordAuthMethods(e *core.RequestEvent) error {
collection, err := findAuthCollection(e)
if err != nil {
return err
}
result := authMethodsResponse{
Password: passwordResponse{
IdentityFields: make([]string, 0, len(collection.PasswordAuth.IdentityFields)),
},
OAuth2: oauth2Response{
Providers: make([]providerInfo, 0, len(collection.OAuth2.Providers)),
},
OTP: otpResponse{
Enabled: collection.OTP.Enabled,
},
MFA: mfaResponse{
Enabled: collection.MFA.Enabled,
},
}
if collection.PasswordAuth.Enabled {
result.Password.Enabled = true
result.Password.IdentityFields = collection.PasswordAuth.IdentityFields
}
if collection.OTP.Enabled {
result.OTP.Duration = collection.OTP.Duration
}
if collection.MFA.Enabled {
result.MFA.Duration = collection.MFA.Duration
}
if !collection.OAuth2.Enabled {
result.fillLegacyFields()
return e.JSON(http.StatusOK, result)
}
result.OAuth2.Enabled = true
for _, config := range collection.OAuth2.Providers {
provider, err := config.InitProvider()
if err != nil {
e.App.Logger().Debug(
"Failed to setup OAuth2 provider",
slog.String("name", config.Name),
slog.String("error", err.Error()),
)
continue // skip provider
}
info := providerInfo{
Name: config.Name,
DisplayName: provider.DisplayName(),
State: security.RandomString(30),
}
if info.DisplayName == "" {
info.DisplayName = config.Name
}
urlOpts := []oauth2.AuthCodeOption{}
// custom providers url options
switch config.Name {
case auth.NameApple:
// see https://developer.apple.com/documentation/sign_in_with_apple/sign_in_with_apple_js/incorporating_sign_in_with_apple_into_other_platforms#3332113
urlOpts = append(urlOpts, oauth2.SetAuthURLParam("response_mode", "form_post"))
}
if provider.PKCE() {
info.CodeVerifier = security.RandomString(43)
info.CodeChallenge = security.S256Challenge(info.CodeVerifier)
info.CodeChallengeMethod = "S256"
urlOpts = append(urlOpts,
oauth2.SetAuthURLParam("code_challenge", info.CodeChallenge),
oauth2.SetAuthURLParam("code_challenge_method", info.CodeChallengeMethod),
)
}
info.AuthURL = provider.BuildAuthURL(
info.State,
urlOpts...,
) + "&redirect_uri=" // empty redirect_uri so that users can append their redirect url
info.AuthUrl = info.AuthURL
result.OAuth2.Providers = append(result.OAuth2.Providers, info)
}
result.fillLegacyFields()
return e.JSON(http.StatusOK, result)
}

View file

@ -0,0 +1,106 @@
package apis_test
import (
"net/http"
"testing"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tests"
)
func TestRecordAuthMethodsList(t *testing.T) {
t.Parallel()
scenarios := []tests.ApiScenario{
{
Name: "missing collection",
Method: http.MethodGet,
URL: "/api/collections/missing/auth-methods",
ExpectedStatus: 404,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "non auth collection",
Method: http.MethodGet,
URL: "/api/collections/demo1/auth-methods",
ExpectedStatus: 404,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "auth collection with none auth methods allowed",
Method: http.MethodGet,
URL: "/api/collections/nologin/auth-methods",
ExpectedStatus: 200,
ExpectedContent: []string{
`"password":{"identityFields":[],"enabled":false}`,
`"oauth2":{"providers":[],"enabled":false}`,
`"mfa":{"enabled":false,"duration":0}`,
`"otp":{"enabled":false,"duration":0}`,
},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "auth collection with all auth methods allowed",
Method: http.MethodGet,
URL: "/api/collections/users/auth-methods",
ExpectedStatus: 200,
ExpectedContent: []string{
`"password":{"identityFields":["email","username"],"enabled":true}`,
`"mfa":{"enabled":true,"duration":1800}`,
`"otp":{"enabled":true,"duration":300}`,
`"oauth2":{`,
`"providers":[{`,
`"name":"google"`,
`"name":"gitlab"`,
`"state":`,
`"displayName":`,
`"codeVerifier":`,
`"codeChallenge":`,
`"codeChallengeMethod":`,
`"authURL":`,
`redirect_uri="`, // ensures that the redirect_uri is the last url param
},
ExpectedEvents: map[string]int{"*": 0},
},
// rate limit checks
// -----------------------------------------------------------
{
Name: "RateLimit rule - nologin:listAuthMethods",
Method: http.MethodGet,
URL: "/api/collections/nologin/auth-methods",
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
app.Settings().RateLimits.Enabled = true
app.Settings().RateLimits.Rules = []core.RateLimitRule{
{MaxRequests: 100, Label: "abc"},
{MaxRequests: 100, Label: "*:listAuthMethods"},
{MaxRequests: 0, Label: "nologin:listAuthMethods"},
}
},
ExpectedStatus: 429,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "RateLimit rule - *:listAuthMethods",
Method: http.MethodGet,
URL: "/api/collections/nologin/auth-methods",
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
app.Settings().RateLimits.Enabled = true
app.Settings().RateLimits.Rules = []core.RateLimitRule{
{MaxRequests: 100, Label: "abc"},
{MaxRequests: 0, Label: "*:listAuthMethods"},
}
},
ExpectedStatus: 429,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
}
for _, scenario := range scenarios {
scenario.Test(t)
}
}

View file

@ -0,0 +1,127 @@
package apis
import (
"database/sql"
"errors"
"fmt"
"net/http"
validation "github.com/go-ozzo/ozzo-validation/v4"
"github.com/go-ozzo/ozzo-validation/v4/is"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/mails"
"github.com/pocketbase/pocketbase/tools/routine"
"github.com/pocketbase/pocketbase/tools/security"
)
func recordRequestOTP(e *core.RequestEvent) error {
collection, err := findAuthCollection(e)
if err != nil {
return err
}
if !collection.OTP.Enabled {
return e.ForbiddenError("The collection is not configured to allow OTP authentication.", nil)
}
form := &createOTPForm{}
if err = e.BindBody(form); err != nil {
return firstApiError(err, e.BadRequestError("An error occurred while loading the submitted data.", err))
}
if err = form.validate(); err != nil {
return firstApiError(err, e.BadRequestError("An error occurred while validating the submitted data.", err))
}
record, err := e.App.FindAuthRecordByEmail(collection, form.Email)
// ignore not found errors to allow custom record find implementations
if err != nil && !errors.Is(err, sql.ErrNoRows) {
return e.InternalServerError("", err)
}
event := new(core.RecordCreateOTPRequestEvent)
event.RequestEvent = e
event.Password = security.RandomStringWithAlphabet(collection.OTP.Length, "1234567890")
event.Collection = collection
event.Record = record
originalApp := e.App
return e.App.OnRecordRequestOTPRequest().Trigger(event, func(e *core.RecordCreateOTPRequestEvent) error {
if e.Record == nil {
// write a dummy 200 response as a very rudimentary emails enumeration "protection"
e.JSON(http.StatusOK, map[string]string{
"otpId": core.GenerateDefaultRandomId(),
})
return fmt.Errorf("missing or invalid %s OTP auth record with email %s", collection.Name, form.Email)
}
var otp *core.OTP
// limit the new OTP creations for a single user
if !e.App.IsDev() {
otps, err := e.App.FindAllOTPsByRecord(e.Record)
if err != nil {
return firstApiError(err, e.InternalServerError("Failed to fetch previous record OTPs.", err))
}
totalRecent := 0
for _, existingOTP := range otps {
if !existingOTP.HasExpired(collection.OTP.DurationTime()) {
totalRecent++
}
// use the last issued one
if totalRecent > 9 {
otp = otps[0] // otps are DESC sorted
e.App.Logger().Warn(
"Too many OTP requests - reusing the last issued",
"email", form.Email,
"recordId", e.Record.Id,
"otpId", existingOTP.Id,
)
break
}
}
}
if otp == nil {
// create new OTP
// ---
otp = core.NewOTP(e.App)
otp.SetCollectionRef(e.Record.Collection().Id)
otp.SetRecordRef(e.Record.Id)
otp.SetPassword(e.Password)
err = e.App.Save(otp)
if err != nil {
return err
}
// send OTP email
// (in the background as a very basic timing attacks and emails enumeration protection)
// ---
routine.FireAndForget(func() {
err = mails.SendRecordOTP(originalApp, e.Record, otp.Id, e.Password)
if err != nil {
originalApp.Logger().Error("Failed to send OTP email", "error", errors.Join(err, originalApp.Delete(otp)))
}
})
}
return execAfterSuccessTx(true, e.App, func() error {
return e.JSON(http.StatusOK, map[string]string{"otpId": otp.Id})
})
})
}
// -------------------------------------------------------------------
type createOTPForm struct {
Email string `form:"email" json:"email"`
}
func (form createOTPForm) validate() error {
return validation.ValidateStruct(&form,
validation.Field(&form.Email, validation.Required, validation.Length(1, 255), is.EmailFormat),
)
}

View file

@ -0,0 +1,316 @@
package apis_test
import (
"net/http"
"strconv"
"strings"
"testing"
"time"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tests"
"github.com/pocketbase/pocketbase/tools/types"
)
func TestRecordRequestOTP(t *testing.T) {
t.Parallel()
scenarios := []tests.ApiScenario{
{
Name: "not an auth collection",
Method: http.MethodPost,
URL: "/api/collections/demo1/request-otp",
Body: strings.NewReader(`{"email":"test@example.com"}`),
ExpectedStatus: 404,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "auth collection with disabled otp",
Method: http.MethodPost,
URL: "/api/collections/users/request-otp",
Body: strings.NewReader(`{"email":"test@example.com"}`),
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
usersCol, err := app.FindCollectionByNameOrId("users")
if err != nil {
t.Fatal(err)
}
usersCol.OTP.Enabled = false
if err := app.Save(usersCol); err != nil {
t.Fatal(err)
}
},
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "empty body",
Method: http.MethodPost,
URL: "/api/collections/users/request-otp",
Body: strings.NewReader(``),
ExpectedStatus: 400,
ExpectedContent: []string{`"data":{"email":{"code":"validation_required","message":"Cannot be blank."}}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "invalid body",
Method: http.MethodPost,
URL: "/api/collections/users/request-otp",
Body: strings.NewReader(`{"email`),
ExpectedStatus: 400,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "invalid request data",
Method: http.MethodPost,
URL: "/api/collections/users/request-otp",
Body: strings.NewReader(`{"email":"invalid"}`),
ExpectedStatus: 400,
ExpectedContent: []string{
`"data":{`,
`"email":{"code":"validation_is_email`,
},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "missing auth record",
Method: http.MethodPost,
URL: "/api/collections/users/request-otp",
Body: strings.NewReader(`{"email":"missing@example.com"}`),
Delay: 100 * time.Millisecond,
ExpectedStatus: 200,
ExpectedContent: []string{
`"otpId":"`, // some fake random generated string
},
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordRequestOTPRequest": 1,
},
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
if app.TestMailer.TotalSend() != 0 {
t.Fatalf("Expected zero emails, got %d", app.TestMailer.TotalSend())
}
},
},
{
Name: "existing auth record (with < 9 non-expired)",
Method: http.MethodPost,
URL: "/api/collections/users/request-otp",
Body: strings.NewReader(`{"email":"test@example.com"}`),
Delay: 100 * time.Millisecond,
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
user, err := app.FindAuthRecordByEmail("users", "test@example.com")
if err != nil {
t.Fatal(err)
}
// insert 8 non-expired and 2 expired
for i := 0; i < 10; i++ {
otp := core.NewOTP(app)
otp.Id = "otp_" + strconv.Itoa(i)
otp.SetCollectionRef(user.Collection().Id)
otp.SetRecordRef(user.Id)
otp.SetPassword("123456")
if i >= 8 {
expiredDate := types.NowDateTime().AddDate(-3, 0, 0)
otp.SetRaw("created", expiredDate)
otp.SetRaw("updated", expiredDate)
}
if err := app.SaveNoValidate(otp); err != nil {
t.Fatal(err)
}
}
},
ExpectedStatus: 200,
ExpectedContent: []string{
`"otpId":"`,
},
NotExpectedContent: []string{
`"otpId":"otp_`,
},
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordRequestOTPRequest": 1,
"OnMailerSend": 1,
"OnMailerRecordOTPSend": 1,
"OnModelCreate": 1,
"OnModelCreateExecute": 1,
"OnModelAfterCreateSuccess": 1,
"OnModelValidate": 2, // + 1 for the OTP update after the email send
"OnRecordCreate": 1,
"OnRecordCreateExecute": 1,
"OnRecordAfterCreateSuccess": 1,
"OnRecordValidate": 2,
// OTP update
"OnModelUpdate": 1,
"OnModelUpdateExecute": 1,
"OnModelAfterUpdateSuccess": 1,
"OnRecordUpdate": 1,
"OnRecordUpdateExecute": 1,
"OnRecordAfterUpdateSuccess": 1,
},
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
if app.TestMailer.TotalSend() != 1 {
t.Fatalf("Expected 1 email, got %d", app.TestMailer.TotalSend())
}
// ensure that sentTo is set
otps, err := app.FindRecordsByFilter(core.CollectionNameOTPs, "sentTo='test@example.com'", "", 0, 0)
if err != nil || len(otps) != 1 {
t.Fatalf("Expected to find 1 OTP with sentTo %q, found %d", "test@example.com", len(otps))
}
},
},
{
Name: "existing auth record with intercepted email (with < 9 non-expired)",
Method: http.MethodPost,
URL: "/api/collections/users/request-otp",
Body: strings.NewReader(`{"email":"test@example.com"}`),
Delay: 100 * time.Millisecond,
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
// prevent email sent
app.OnMailerRecordOTPSend("users").BindFunc(func(e *core.MailerRecordEvent) error {
return nil
})
},
ExpectedStatus: 200,
ExpectedContent: []string{
`"otpId":"`,
},
NotExpectedContent: []string{
`"otpId":"otp_`,
},
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordRequestOTPRequest": 1,
"OnMailerRecordOTPSend": 1,
"OnModelCreate": 1,
"OnModelCreateExecute": 1,
"OnModelAfterCreateSuccess": 1,
"OnModelValidate": 1,
"OnRecordCreate": 1,
"OnRecordCreateExecute": 1,
"OnRecordAfterCreateSuccess": 1,
"OnRecordValidate": 1,
},
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
if app.TestMailer.TotalSend() != 0 {
t.Fatalf("Expected 0 emails, got %d", app.TestMailer.TotalSend())
}
// ensure that there is no OTP with user email as sentTo
otps, err := app.FindRecordsByFilter(core.CollectionNameOTPs, "sentTo='test@example.com'", "", 0, 0)
if err != nil || len(otps) != 0 {
t.Fatalf("Expected to find 0 OTPs with sentTo %q, found %d", "test@example.com", len(otps))
}
},
},
{
Name: "existing auth record (with > 9 non-expired)",
Method: http.MethodPost,
URL: "/api/collections/users/request-otp",
Body: strings.NewReader(`{"email":"test@example.com"}`),
Delay: 100 * time.Millisecond,
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
user, err := app.FindAuthRecordByEmail("users", "test@example.com")
if err != nil {
t.Fatal(err)
}
// insert 10 non-expired
for i := 0; i < 10; i++ {
otp := core.NewOTP(app)
otp.Id = "otp_" + strconv.Itoa(i)
otp.SetCollectionRef(user.Collection().Id)
otp.SetRecordRef(user.Id)
otp.SetPassword("123456")
if err := app.SaveNoValidate(otp); err != nil {
t.Fatal(err)
}
}
},
ExpectedStatus: 200,
ExpectedContent: []string{
`"otpId":"otp_9"`,
},
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordRequestOTPRequest": 1,
},
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
if app.TestMailer.TotalSend() != 0 {
t.Fatalf("Expected 0 sent emails, got %d", app.TestMailer.TotalSend())
}
},
},
{
Name: "OnRecordRequestOTPRequest tx body write check",
Method: http.MethodPost,
URL: "/api/collections/users/request-otp",
Body: strings.NewReader(`{"email":"test@example.com"}`),
Delay: 100 * time.Millisecond,
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
app.OnRecordRequestOTPRequest().BindFunc(func(e *core.RecordCreateOTPRequestEvent) error {
original := e.App
return e.App.RunInTransaction(func(txApp core.App) error {
e.App = txApp
defer func() { e.App = original }()
if err := e.Next(); err != nil {
return err
}
return e.BadRequestError("TX_ERROR", nil)
})
})
},
ExpectedStatus: 400,
ExpectedEvents: map[string]int{"OnRecordRequestOTPRequest": 1},
ExpectedContent: []string{"TX_ERROR"},
},
// rate limit checks
// -----------------------------------------------------------
{
Name: "RateLimit rule - users:requestOTP",
Method: http.MethodPost,
URL: "/api/collections/users/request-otp",
Body: strings.NewReader(`{"email":"test@example.com"}`),
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
app.Settings().RateLimits.Enabled = true
app.Settings().RateLimits.Rules = []core.RateLimitRule{
{MaxRequests: 100, Label: "abc"},
{MaxRequests: 100, Label: "*:requestOTP"},
{MaxRequests: 0, Label: "users:requestOTP"},
}
},
ExpectedStatus: 429,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "RateLimit rule - *:requestOTP",
Method: http.MethodPost,
URL: "/api/collections/users/request-otp",
Body: strings.NewReader(`{"email":"test@example.com"}`),
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
app.Settings().RateLimits.Enabled = true
app.Settings().RateLimits.Rules = []core.RateLimitRule{
{MaxRequests: 100, Label: "abc"},
{MaxRequests: 0, Label: "*:requestOTP"},
}
},
ExpectedStatus: 429,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
}
for _, scenario := range scenarios {
scenario.Test(t)
}
}

View file

@ -0,0 +1,104 @@
package apis
import (
"net/http"
validation "github.com/go-ozzo/ozzo-validation/v4"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/core/validators"
"github.com/pocketbase/pocketbase/tools/security"
"github.com/spf13/cast"
)
func recordConfirmPasswordReset(e *core.RequestEvent) error {
collection, err := findAuthCollection(e)
if err != nil {
return err
}
form := new(recordConfirmPasswordResetForm)
form.app = e.App
form.collection = collection
if err = e.BindBody(form); err != nil {
return e.BadRequestError("An error occurred while loading the submitted data.", err)
}
if err = form.validate(); err != nil {
return firstApiError(err, e.BadRequestError("An error occurred while validating the submitted data.", err))
}
authRecord, err := e.App.FindAuthRecordByToken(form.Token, core.TokenTypePasswordReset)
if err != nil {
return firstApiError(err, e.BadRequestError("Invalid or expired password reset token.", err))
}
event := new(core.RecordConfirmPasswordResetRequestEvent)
event.RequestEvent = e
event.Collection = collection
event.Record = authRecord
return e.App.OnRecordConfirmPasswordResetRequest().Trigger(event, func(e *core.RecordConfirmPasswordResetRequestEvent) error {
authRecord.SetPassword(form.Password)
if !authRecord.Verified() {
payload, err := security.ParseUnverifiedJWT(form.Token)
if err == nil && authRecord.Email() == cast.ToString(payload[core.TokenClaimEmail]) {
// mark as verified if the email hasn't changed
authRecord.SetVerified(true)
}
}
err = e.App.Save(authRecord)
if err != nil {
return firstApiError(err, e.BadRequestError("Failed to set new password.", err))
}
e.App.Store().Remove(getPasswordResetResendKey(authRecord))
return execAfterSuccessTx(true, e.App, func() error {
return e.NoContent(http.StatusNoContent)
})
})
}
// -------------------------------------------------------------------
type recordConfirmPasswordResetForm struct {
app core.App
collection *core.Collection
Token string `form:"token" json:"token"`
Password string `form:"password" json:"password"`
PasswordConfirm string `form:"passwordConfirm" json:"passwordConfirm"`
}
func (form *recordConfirmPasswordResetForm) validate() error {
min := 1
passField, ok := form.collection.Fields.GetByName(core.FieldNamePassword).(*core.PasswordField)
if ok && passField != nil && passField.Min > 0 {
min = passField.Min
}
return validation.ValidateStruct(form,
validation.Field(&form.Token, validation.Required, validation.By(form.checkToken)),
validation.Field(&form.Password, validation.Required, validation.Length(min, 255)), // the FieldPassword validator will check further the specicic length constraints
validation.Field(&form.PasswordConfirm, validation.Required, validation.By(validators.Equal(form.Password))),
)
}
func (form *recordConfirmPasswordResetForm) checkToken(value any) error {
v, _ := value.(string)
if v == "" {
return nil
}
record, err := form.app.FindAuthRecordByToken(v, core.TokenTypePasswordReset)
if err != nil || record == nil {
return validation.NewError("validation_invalid_token", "Invalid or expired token.")
}
if record.Collection().Id != form.collection.Id {
return validation.NewError("validation_token_collection_mismatch", "The provided token is for different auth collection.")
}
return nil
}

View file

@ -0,0 +1,360 @@
package apis_test
import (
"net/http"
"strings"
"testing"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tests"
)
func TestRecordConfirmPasswordReset(t *testing.T) {
t.Parallel()
scenarios := []tests.ApiScenario{
{
Name: "empty data",
Method: http.MethodPost,
URL: "/api/collections/users/confirm-password-reset",
Body: strings.NewReader(``),
ExpectedStatus: 400,
ExpectedContent: []string{
`"data":{`,
`"password":{"code":"validation_required"`,
`"passwordConfirm":{"code":"validation_required"`,
`"token":{"code":"validation_required"`,
},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "invalid data format",
Method: http.MethodPost,
URL: "/api/collections/users/confirm-password-reset",
Body: strings.NewReader(`{"password`),
ExpectedStatus: 400,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "expired token and invalid password",
Method: http.MethodPost,
URL: "/api/collections/users/confirm-password-reset",
Body: strings.NewReader(`{
"token":"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImV4cCI6MTY0MDk5MTY2MSwidHlwZSI6InBhc3N3b3JkUmVzZXQiLCJjb2xsZWN0aW9uSWQiOiJfcGJfdXNlcnNfYXV0aF8iLCJlbWFpbCI6InRlc3RAZXhhbXBsZS5jb20ifQ.5Tm6_6amQqOlX3urAnXlEdmxwG5qQJfiTg6U0hHR1hk",
"password":"1234567",
"passwordConfirm":"7654321"
}`),
ExpectedStatus: 400,
ExpectedContent: []string{
`"data":{`,
`"token":{"code":"validation_invalid_token"`,
`"password":{"code":"validation_length_out_of_range"`,
`"passwordConfirm":{"code":"validation_values_mismatch"`,
},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "non-password reset token",
Method: http.MethodPost,
URL: "/api/collections/users/confirm-password-reset",
Body: strings.NewReader(`{
"token":"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImV4cCI6MjUyNDYwNDQ2MSwidHlwZSI6InZlcmlmaWNhdGlvbiIsImNvbGxlY3Rpb25JZCI6Il9wYl91c2Vyc19hdXRoXyIsImVtYWlsIjoidGVzdEBleGFtcGxlLmNvbSJ9.SetHpu2H-x-q4TIUz-xiQjwi7MNwLCLvSs4O0hUSp0E",
"password":"1234567!",
"passwordConfirm":"1234567!"
}`),
ExpectedStatus: 400,
ExpectedContent: []string{
`"data":{`,
`"token":{"code":"validation_invalid_token"`,
},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "non auth collection",
Method: http.MethodPost,
URL: "/api/collections/demo1/confirm-password-reset?expand=rel,missing",
Body: strings.NewReader(`{
"token":"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImV4cCI6MjUyNDYwNDQ2MSwidHlwZSI6InBhc3N3b3JkUmVzZXQiLCJjb2xsZWN0aW9uSWQiOiJfcGJfdXNlcnNfYXV0aF8iLCJlbWFpbCI6InRlc3RAZXhhbXBsZS5jb20ifQ.xR-xq1oHDy0D8Q4NDOAEyYKGHWd_swzoiSoL8FLFBHY",
"password":"1234567!",
"passwordConfirm":"1234567!"
}`),
ExpectedStatus: 404,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "different auth collection",
Method: http.MethodPost,
URL: "/api/collections/clients/confirm-password-reset?expand=rel,missing",
Body: strings.NewReader(`{
"token":"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImV4cCI6MjUyNDYwNDQ2MSwidHlwZSI6InBhc3N3b3JkUmVzZXQiLCJjb2xsZWN0aW9uSWQiOiJfcGJfdXNlcnNfYXV0aF8iLCJlbWFpbCI6InRlc3RAZXhhbXBsZS5jb20ifQ.xR-xq1oHDy0D8Q4NDOAEyYKGHWd_swzoiSoL8FLFBHY",
"password":"1234567!",
"passwordConfirm":"1234567!"
}`),
ExpectedStatus: 400,
ExpectedContent: []string{
`"data":{"token":{"code":"validation_token_collection_mismatch"`,
},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "valid token and data (unverified user)",
Method: http.MethodPost,
URL: "/api/collections/users/confirm-password-reset",
Body: strings.NewReader(`{
"token":"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImV4cCI6MjUyNDYwNDQ2MSwidHlwZSI6InBhc3N3b3JkUmVzZXQiLCJjb2xsZWN0aW9uSWQiOiJfcGJfdXNlcnNfYXV0aF8iLCJlbWFpbCI6InRlc3RAZXhhbXBsZS5jb20ifQ.xR-xq1oHDy0D8Q4NDOAEyYKGHWd_swzoiSoL8FLFBHY",
"password":"1234567!",
"passwordConfirm":"1234567!"
}`),
ExpectedStatus: 204,
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordConfirmPasswordResetRequest": 1,
"OnModelUpdate": 1,
"OnModelUpdateExecute": 1,
"OnModelAfterUpdateSuccess": 1,
"OnModelValidate": 1,
"OnRecordUpdate": 1,
"OnRecordUpdateExecute": 1,
"OnRecordAfterUpdateSuccess": 1,
"OnRecordValidate": 1,
},
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
user, err := app.FindAuthRecordByEmail("users", "test@example.com")
if err != nil {
t.Fatalf("Failed to fetch confirm password user: %v", err)
}
if user.Verified() {
t.Fatal("Expected the user to be unverified")
}
},
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
_, err := app.FindAuthRecordByToken(
"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImV4cCI6MjUyNDYwNDQ2MSwidHlwZSI6InBhc3N3b3JkUmVzZXQiLCJjb2xsZWN0aW9uSWQiOiJfcGJfdXNlcnNfYXV0aF8iLCJlbWFpbCI6InRlc3RAZXhhbXBsZS5jb20ifQ.xR-xq1oHDy0D8Q4NDOAEyYKGHWd_swzoiSoL8FLFBHY",
core.TokenTypePasswordReset,
)
if err == nil {
t.Fatal("Expected the password reset token to be invalidated")
}
user, err := app.FindAuthRecordByEmail("users", "test@example.com")
if err != nil {
t.Fatalf("Failed to fetch confirm password user: %v", err)
}
if !user.Verified() {
t.Fatal("Expected the user to be marked as verified")
}
if !user.ValidatePassword("1234567!") {
t.Fatal("Password wasn't changed")
}
},
},
{
Name: "valid token and data (unverified user with different email from the one in the token)",
Method: http.MethodPost,
URL: "/api/collections/users/confirm-password-reset",
Body: strings.NewReader(`{
"token":"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImV4cCI6MjUyNDYwNDQ2MSwidHlwZSI6InBhc3N3b3JkUmVzZXQiLCJjb2xsZWN0aW9uSWQiOiJfcGJfdXNlcnNfYXV0aF8iLCJlbWFpbCI6InRlc3RAZXhhbXBsZS5jb20ifQ.xR-xq1oHDy0D8Q4NDOAEyYKGHWd_swzoiSoL8FLFBHY",
"password":"1234567!",
"passwordConfirm":"1234567!"
}`),
ExpectedStatus: 204,
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordConfirmPasswordResetRequest": 1,
"OnModelUpdate": 1,
"OnModelUpdateExecute": 1,
"OnModelAfterUpdateSuccess": 1,
"OnModelValidate": 1,
"OnRecordUpdate": 1,
"OnRecordUpdateExecute": 1,
"OnRecordAfterUpdateSuccess": 1,
"OnRecordValidate": 1,
},
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
user, err := app.FindAuthRecordByEmail("users", "test@example.com")
if err != nil {
t.Fatalf("Failed to fetch confirm password user: %v", err)
}
if user.Verified() {
t.Fatal("Expected the user to be unverified")
}
oldTokenKey := user.TokenKey()
// manually change the email to check whether the verified state will be updated
user.SetEmail("test_update@example.com")
if err = app.Save(user); err != nil {
t.Fatalf("Failed to update user test email: %v", err)
}
// resave with the old token key since the email change above
// would change it and will make the password token invalid
user.SetTokenKey(oldTokenKey)
if err = app.Save(user); err != nil {
t.Fatalf("Failed to restore original user tokenKey: %v", err)
}
},
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
_, err := app.FindAuthRecordByToken(
"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImV4cCI6MjUyNDYwNDQ2MSwidHlwZSI6InBhc3N3b3JkUmVzZXQiLCJjb2xsZWN0aW9uSWQiOiJfcGJfdXNlcnNfYXV0aF8iLCJlbWFpbCI6InRlc3RAZXhhbXBsZS5jb20ifQ.xR-xq1oHDy0D8Q4NDOAEyYKGHWd_swzoiSoL8FLFBHY",
core.TokenTypePasswordReset,
)
if err == nil {
t.Fatalf("Expected the password reset token to be invalidated")
}
user, err := app.FindAuthRecordByEmail("users", "test_update@example.com")
if err != nil {
t.Fatalf("Failed to fetch confirm password user: %v", err)
}
if user.Verified() {
t.Fatal("Expected the user to remain unverified")
}
if !user.ValidatePassword("1234567!") {
t.Fatal("Password wasn't changed")
}
},
},
{
Name: "valid token and data (verified user)",
Method: http.MethodPost,
URL: "/api/collections/users/confirm-password-reset",
Body: strings.NewReader(`{
"token":"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImV4cCI6MjUyNDYwNDQ2MSwidHlwZSI6InBhc3N3b3JkUmVzZXQiLCJjb2xsZWN0aW9uSWQiOiJfcGJfdXNlcnNfYXV0aF8iLCJlbWFpbCI6InRlc3RAZXhhbXBsZS5jb20ifQ.xR-xq1oHDy0D8Q4NDOAEyYKGHWd_swzoiSoL8FLFBHY",
"password":"1234567!",
"passwordConfirm":"1234567!"
}`),
ExpectedStatus: 204,
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordConfirmPasswordResetRequest": 1,
"OnModelUpdate": 1,
"OnModelUpdateExecute": 1,
"OnModelAfterUpdateSuccess": 1,
"OnModelValidate": 1,
"OnRecordUpdate": 1,
"OnRecordUpdateExecute": 1,
"OnRecordAfterUpdateSuccess": 1,
"OnRecordValidate": 1,
},
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
user, err := app.FindAuthRecordByEmail("users", "test@example.com")
if err != nil {
t.Fatalf("Failed to fetch confirm password user: %v", err)
}
// ensure that the user is already verified
user.SetVerified(true)
if err := app.Save(user); err != nil {
t.Fatalf("Failed to update user verified state")
}
},
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
_, err := app.FindAuthRecordByToken(
"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImV4cCI6MjUyNDYwNDQ2MSwidHlwZSI6InBhc3N3b3JkUmVzZXQiLCJjb2xsZWN0aW9uSWQiOiJfcGJfdXNlcnNfYXV0aF8iLCJlbWFpbCI6InRlc3RAZXhhbXBsZS5jb20ifQ.xR-xq1oHDy0D8Q4NDOAEyYKGHWd_swzoiSoL8FLFBHY",
core.TokenTypePasswordReset,
)
if err == nil {
t.Fatal("Expected the password reset token to be invalidated")
}
user, err := app.FindAuthRecordByEmail("users", "test@example.com")
if err != nil {
t.Fatalf("Failed to fetch confirm password user: %v", err)
}
if !user.Verified() {
t.Fatal("Expected the user to remain verified")
}
if !user.ValidatePassword("1234567!") {
t.Fatal("Password wasn't changed")
}
},
},
{
Name: "OnRecordConfirmPasswordResetRequest tx body write check",
Method: http.MethodPost,
URL: "/api/collections/users/confirm-password-reset",
Body: strings.NewReader(`{
"token":"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImV4cCI6MjUyNDYwNDQ2MSwidHlwZSI6InBhc3N3b3JkUmVzZXQiLCJjb2xsZWN0aW9uSWQiOiJfcGJfdXNlcnNfYXV0aF8iLCJlbWFpbCI6InRlc3RAZXhhbXBsZS5jb20ifQ.xR-xq1oHDy0D8Q4NDOAEyYKGHWd_swzoiSoL8FLFBHY",
"password":"1234567!",
"passwordConfirm":"1234567!"
}`),
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
app.OnRecordConfirmPasswordResetRequest().BindFunc(func(e *core.RecordConfirmPasswordResetRequestEvent) error {
original := e.App
return e.App.RunInTransaction(func(txApp core.App) error {
e.App = txApp
defer func() { e.App = original }()
if err := e.Next(); err != nil {
return err
}
return e.BadRequestError("TX_ERROR", nil)
})
})
},
ExpectedStatus: 400,
ExpectedEvents: map[string]int{"OnRecordConfirmPasswordResetRequest": 1},
ExpectedContent: []string{"TX_ERROR"},
},
// rate limit checks
// -----------------------------------------------------------
{
Name: "RateLimit rule - users:confirmPasswordReset",
Method: http.MethodPost,
URL: "/api/collections/users/confirm-password-reset",
Body: strings.NewReader(`{
"token":"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImV4cCI6MjUyNDYwNDQ2MSwidHlwZSI6InBhc3N3b3JkUmVzZXQiLCJjb2xsZWN0aW9uSWQiOiJfcGJfdXNlcnNfYXV0aF8iLCJlbWFpbCI6InRlc3RAZXhhbXBsZS5jb20ifQ.xR-xq1oHDy0D8Q4NDOAEyYKGHWd_swzoiSoL8FLFBHY",
"password":"1234567!",
"passwordConfirm":"1234567!"
}`),
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
app.Settings().RateLimits.Enabled = true
app.Settings().RateLimits.Rules = []core.RateLimitRule{
{MaxRequests: 100, Label: "abc"},
{MaxRequests: 100, Label: "*:confirmPasswordReset"},
{MaxRequests: 0, Label: "users:confirmPasswordReset"},
}
},
ExpectedStatus: 429,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "RateLimit rule - *:confirmPasswordReset",
Method: http.MethodPost,
URL: "/api/collections/users/confirm-password-reset",
Body: strings.NewReader(`{
"token":"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImV4cCI6MjUyNDYwNDQ2MSwidHlwZSI6InBhc3N3b3JkUmVzZXQiLCJjb2xsZWN0aW9uSWQiOiJfcGJfdXNlcnNfYXV0aF8iLCJlbWFpbCI6InRlc3RAZXhhbXBsZS5jb20ifQ.xR-xq1oHDy0D8Q4NDOAEyYKGHWd_swzoiSoL8FLFBHY",
"password":"1234567!",
"passwordConfirm":"1234567!"
}`),
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
app.Settings().RateLimits.Enabled = true
app.Settings().RateLimits.Rules = []core.RateLimitRule{
{MaxRequests: 100, Label: "abc"},
{MaxRequests: 0, Label: "*:confirmPasswordReset"},
}
},
ExpectedStatus: 429,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
}
for _, scenario := range scenarios {
scenario.Test(t)
}
}

View file

@ -0,0 +1,88 @@
package apis
import (
"errors"
"fmt"
"net/http"
"time"
validation "github.com/go-ozzo/ozzo-validation/v4"
"github.com/go-ozzo/ozzo-validation/v4/is"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/mails"
"github.com/pocketbase/pocketbase/tools/routine"
)
func recordRequestPasswordReset(e *core.RequestEvent) error {
collection, err := findAuthCollection(e)
if err != nil {
return err
}
if !collection.PasswordAuth.Enabled {
return e.BadRequestError("The collection is not configured to allow password authentication.", nil)
}
form := new(recordRequestPasswordResetForm)
if err = e.BindBody(form); err != nil {
return firstApiError(err, e.BadRequestError("An error occurred while loading the submitted data.", err))
}
if err = form.validate(); err != nil {
return firstApiError(err, e.BadRequestError("An error occurred while validating the submitted data.", err))
}
record, err := e.App.FindAuthRecordByEmail(collection, form.Email)
if err != nil {
// eagerly write 204 response as a very basic measure against emails enumeration
e.NoContent(http.StatusNoContent)
return fmt.Errorf("failed to fetch %s record with email %s: %w", collection.Name, form.Email, err)
}
resendKey := getPasswordResetResendKey(record)
if e.App.Store().Has(resendKey) {
// eagerly write 204 response as a very basic measure against emails enumeration
e.NoContent(http.StatusNoContent)
return errors.New("try again later - you've already requested a password reset email")
}
event := new(core.RecordRequestPasswordResetRequestEvent)
event.RequestEvent = e
event.Collection = collection
event.Record = record
return e.App.OnRecordRequestPasswordResetRequest().Trigger(event, func(e *core.RecordRequestPasswordResetRequestEvent) error {
// run in background because we don't need to show the result to the client
app := e.App
routine.FireAndForget(func() {
if err := mails.SendRecordPasswordReset(app, e.Record); err != nil {
app.Logger().Error("Failed to send password reset email", "error", err)
return
}
app.Store().Set(resendKey, struct{}{})
time.AfterFunc(2*time.Minute, func() {
app.Store().Remove(resendKey)
})
})
return execAfterSuccessTx(true, e.App, func() error {
return e.NoContent(http.StatusNoContent)
})
})
}
// -------------------------------------------------------------------
type recordRequestPasswordResetForm struct {
Email string `form:"email" json:"email"`
}
func (form *recordRequestPasswordResetForm) validate() error {
return validation.ValidateStruct(form,
validation.Field(&form.Email, validation.Required, validation.Length(1, 255), is.EmailFormat),
)
}
func getPasswordResetResendKey(record *core.Record) string {
return "@limitPasswordResetEmail_" + record.Collection().Id + record.Id
}

View file

@ -0,0 +1,169 @@
package apis_test
import (
"net/http"
"strings"
"testing"
"time"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tests"
)
func TestRecordRequestPasswordReset(t *testing.T) {
t.Parallel()
scenarios := []tests.ApiScenario{
{
Name: "not an auth collection",
Method: http.MethodPost,
URL: "/api/collections/demo1/request-password-reset",
Body: strings.NewReader(``),
ExpectedStatus: 404,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "empty data",
Method: http.MethodPost,
URL: "/api/collections/users/request-password-reset",
Body: strings.NewReader(``),
ExpectedStatus: 400,
ExpectedContent: []string{`"data":{"email":{"code":"validation_required","message":"Cannot be blank."}}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "invalid data",
Method: http.MethodPost,
URL: "/api/collections/users/request-password-reset",
Body: strings.NewReader(`{"email`),
ExpectedStatus: 400,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "existing auth record in a collection with disabled password login",
Method: http.MethodPost,
URL: "/api/collections/nologin/request-password-reset",
Body: strings.NewReader(`{"email":"test@example.com"}`),
ExpectedStatus: 400,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "missing auth record",
Method: http.MethodPost,
URL: "/api/collections/users/request-password-reset",
Body: strings.NewReader(`{"email":"missing@example.com"}`),
Delay: 100 * time.Millisecond,
ExpectedStatus: 204,
ExpectedEvents: map[string]int{"*": 0},
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
if app.TestMailer.TotalSend() != 0 {
t.Fatalf("Expected zero emails, got %d", app.TestMailer.TotalSend())
}
},
},
{
Name: "existing auth record",
Method: http.MethodPost,
URL: "/api/collections/users/request-password-reset",
Body: strings.NewReader(`{"email":"test@example.com"}`),
Delay: 100 * time.Millisecond,
ExpectedStatus: 204,
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordRequestPasswordResetRequest": 1,
"OnMailerSend": 1,
"OnMailerRecordPasswordResetSend": 1,
},
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
if !strings.Contains(app.TestMailer.LastMessage().HTML, "/auth/confirm-password-reset") {
t.Fatalf("Expected password reset email, got\n%v", app.TestMailer.LastMessage().HTML)
}
},
},
{
Name: "existing auth record (after already sent)",
Method: http.MethodPost,
URL: "/api/collections/users/request-password-reset",
Body: strings.NewReader(`{"email":"test@example.com"}`),
Delay: 100 * time.Millisecond,
ExpectedStatus: 204,
ExpectedEvents: map[string]int{"*": 0},
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
// simulate recent verification sent
authRecord, err := app.FindAuthRecordByEmail("users", "test@example.com")
if err != nil {
t.Fatal(err)
}
resendKey := "@limitPasswordResetEmail_" + authRecord.Collection().Id + authRecord.Id
app.Store().Set(resendKey, struct{}{})
},
},
{
Name: "OnRecordRequestPasswordResetRequest tx body write check",
Method: http.MethodPost,
URL: "/api/collections/users/request-password-reset",
Body: strings.NewReader(`{"email":"test@example.com"}`),
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
app.OnRecordRequestPasswordResetRequest().BindFunc(func(e *core.RecordRequestPasswordResetRequestEvent) error {
original := e.App
return e.App.RunInTransaction(func(txApp core.App) error {
e.App = txApp
defer func() { e.App = original }()
if err := e.Next(); err != nil {
return err
}
return e.BadRequestError("TX_ERROR", nil)
})
})
},
ExpectedStatus: 400,
ExpectedEvents: map[string]int{"OnRecordRequestPasswordResetRequest": 1},
ExpectedContent: []string{"TX_ERROR"},
},
// rate limit checks
// -----------------------------------------------------------
{
Name: "RateLimit rule - users:requestPasswordReset",
Method: http.MethodPost,
URL: "/api/collections/users/request-password-reset",
Body: strings.NewReader(`{"email":"missing@example.com"}`),
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
app.Settings().RateLimits.Enabled = true
app.Settings().RateLimits.Rules = []core.RateLimitRule{
{MaxRequests: 100, Label: "abc"},
{MaxRequests: 100, Label: "*:requestPasswordReset"},
{MaxRequests: 0, Label: "users:requestPasswordReset"},
}
},
ExpectedStatus: 429,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "RateLimit rule - *:requestPasswordReset",
Method: http.MethodPost,
URL: "/api/collections/users/request-password-reset",
Body: strings.NewReader(`{"email":"missing@example.com"}`),
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
app.Settings().RateLimits.Enabled = true
app.Settings().RateLimits.Rules = []core.RateLimitRule{
{MaxRequests: 100, Label: "abc"},
{MaxRequests: 0, Label: "*:requestPasswordReset"},
}
},
ExpectedStatus: 429,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
}
for _, scenario := range scenarios {
scenario.Test(t)
}
}

View file

@ -0,0 +1,29 @@
package apis
import (
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tools/security"
"github.com/spf13/cast"
)
func recordAuthRefresh(e *core.RequestEvent) error {
record := e.Auth
if record == nil {
return e.NotFoundError("Missing auth record context.", nil)
}
currentToken := getAuthTokenFromRequest(e)
claims, _ := security.ParseUnverifiedJWT(currentToken)
if v, ok := claims[core.TokenClaimRefreshable]; !ok || !cast.ToBool(v) {
return e.ForbiddenError("The current auth token is not refreshable.", nil)
}
event := new(core.RecordAuthRefreshRequestEvent)
event.RequestEvent = e
event.Collection = record.Collection()
event.Record = record
return e.App.OnRecordAuthRefreshRequest().Trigger(event, func(e *core.RecordAuthRefreshRequestEvent) error {
return RecordAuthResponse(e.RequestEvent, e.Record, "", nil)
})
}

View file

@ -0,0 +1,202 @@
package apis_test
import (
"net/http"
"testing"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tests"
)
func TestRecordAuthRefresh(t *testing.T) {
t.Parallel()
scenarios := []tests.ApiScenario{
{
Name: "unauthorized",
Method: http.MethodPost,
URL: "/api/collections/users/auth-refresh",
ExpectedStatus: 401,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "superuser trying to refresh the auth of another auth collection",
Method: http.MethodPost,
URL: "/api/collections/users/auth-refresh",
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
},
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "auth record + not an auth collection",
Method: http.MethodPost,
URL: "/api/collections/demo1/auth-refresh",
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
},
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "auth record + different auth collection",
Method: http.MethodPost,
URL: "/api/collections/clients/auth-refresh?expand=rel,missing",
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
},
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "auth record + same auth collection as the token",
Method: http.MethodPost,
URL: "/api/collections/users/auth-refresh?expand=rel,missing",
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
},
ExpectedStatus: 200,
ExpectedContent: []string{
`"token":`,
`"record":`,
`"id":"4q1xlclmfloku33"`,
`"emailVisibility":false`,
`"email":"test@example.com"`, // the owner can always view their email address
`"expand":`,
`"rel":`,
`"id":"llvuca81nly1qls"`,
},
NotExpectedContent: []string{
`"missing":`,
},
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordAuthRefreshRequest": 1,
"OnRecordAuthRequest": 1,
"OnRecordEnrich": 2,
},
},
{
Name: "auth record + same auth collection as the token but static/unrefreshable",
Method: http.MethodPost,
URL: "/api/collections/users/auth-refresh",
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6ZmFsc2V9.4IsO6YMsR19crhwl_YWzvRH8pfq2Ri4Gv2dzGyneLak",
},
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "unverified auth record in onlyVerified collection",
Method: http.MethodPost,
URL: "/api/collections/clients/auth-refresh",
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6Im8xeTBkZDBzcGQ3ODZtZCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoidjg1MXE0cjc5MHJoa25sIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.Zi0yXE-CNmnbTdVaQEzYZVuECqRdn3LgEM6pmB3XWBE",
},
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordAuthRefreshRequest": 1,
},
},
{
Name: "verified auth record in onlyVerified collection",
Method: http.MethodPost,
URL: "/api/collections/clients/auth-refresh",
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6ImdrMzkwcWVnczR5NDd3biIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoidjg1MXE0cjc5MHJoa25sIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.0ONnm_BsvPRZyDNT31GN1CKUB6uQRxvVvQ-Wc9AZfG0",
},
ExpectedStatus: 200,
ExpectedContent: []string{
`"token":`,
`"record":`,
`"id":"gk390qegs4y47wn"`,
`"verified":true`,
`"email":"test@example.com"`,
},
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordAuthRefreshRequest": 1,
"OnRecordAuthRequest": 1,
"OnRecordEnrich": 1,
},
},
{
Name: "OnRecordAuthRefreshRequest tx body write check",
Method: http.MethodPost,
URL: "/api/collections/users/auth-refresh",
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
},
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
app.OnRecordAuthRefreshRequest().BindFunc(func(e *core.RecordAuthRefreshRequestEvent) error {
original := e.App
return e.App.RunInTransaction(func(txApp core.App) error {
e.App = txApp
defer func() { e.App = original }()
if err := e.Next(); err != nil {
return err
}
return e.BadRequestError("TX_ERROR", nil)
})
})
},
ExpectedStatus: 400,
ExpectedEvents: map[string]int{"OnRecordAuthRefreshRequest": 1},
ExpectedContent: []string{"TX_ERROR"},
},
// rate limit checks
// -----------------------------------------------------------
{
Name: "RateLimit rule - users:authRefresh",
Method: http.MethodPost,
URL: "/api/collections/users/auth-refresh",
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
},
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
app.Settings().RateLimits.Enabled = true
app.Settings().RateLimits.Rules = []core.RateLimitRule{
{MaxRequests: 100, Label: "abc"},
{MaxRequests: 100, Label: "*:authRefresh"},
{MaxRequests: 0, Label: "users:authRefresh"},
}
},
ExpectedStatus: 429,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "RateLimit rule - *:authRefresh",
Method: http.MethodPost,
URL: "/api/collections/users/auth-refresh",
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
},
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
app.Settings().RateLimits.Enabled = true
app.Settings().RateLimits.Rules = []core.RateLimitRule{
{MaxRequests: 100, Label: "abc"},
{MaxRequests: 0, Label: "*:authRefresh"},
}
},
ExpectedStatus: 429,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
}
for _, scenario := range scenarios {
scenario.Test(t)
}
}

View file

@ -0,0 +1,102 @@
package apis
import (
"net/http"
validation "github.com/go-ozzo/ozzo-validation/v4"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tools/security"
"github.com/spf13/cast"
)
func recordConfirmVerification(e *core.RequestEvent) error {
collection, err := findAuthCollection(e)
if err != nil {
return err
}
if collection.Name == core.CollectionNameSuperusers {
return e.BadRequestError("All superusers are verified by default.", nil)
}
form := new(recordConfirmVerificationForm)
form.app = e.App
form.collection = collection
if err = e.BindBody(form); err != nil {
return firstApiError(err, e.BadRequestError("An error occurred while loading the submitted data.", err))
}
if err = form.validate(); err != nil {
return firstApiError(err, e.BadRequestError("An error occurred while validating the submitted data.", err))
}
record, err := form.app.FindAuthRecordByToken(form.Token, core.TokenTypeVerification)
if err != nil {
return e.BadRequestError("Invalid or expired verification token.", err)
}
wasVerified := record.Verified()
event := new(core.RecordConfirmVerificationRequestEvent)
event.RequestEvent = e
event.Collection = collection
event.Record = record
return e.App.OnRecordConfirmVerificationRequest().Trigger(event, func(e *core.RecordConfirmVerificationRequestEvent) error {
if !wasVerified {
e.Record.SetVerified(true)
if err := e.App.Save(e.Record); err != nil {
return firstApiError(err, e.BadRequestError("An error occurred while saving the verified state.", err))
}
}
e.App.Store().Remove(getVerificationResendKey(e.Record))
return execAfterSuccessTx(true, e.App, func() error {
return e.NoContent(http.StatusNoContent)
})
})
}
// -------------------------------------------------------------------
type recordConfirmVerificationForm struct {
app core.App
collection *core.Collection
Token string `form:"token" json:"token"`
}
func (form *recordConfirmVerificationForm) validate() error {
return validation.ValidateStruct(form,
validation.Field(&form.Token, validation.Required, validation.By(form.checkToken)),
)
}
func (form *recordConfirmVerificationForm) checkToken(value any) error {
v, _ := value.(string)
if v == "" {
return nil // nothing to check
}
claims, _ := security.ParseUnverifiedJWT(v)
email := cast.ToString(claims["email"])
if email == "" {
return validation.NewError("validation_invalid_token_claims", "Missing email token claim.")
}
record, err := form.app.FindAuthRecordByToken(v, core.TokenTypeVerification)
if err != nil || record == nil {
return validation.NewError("validation_invalid_token", "Invalid or expired token.")
}
if record.Collection().Id != form.collection.Id {
return validation.NewError("validation_token_collection_mismatch", "The provided token is for different auth collection.")
}
if record.Email() != email {
return validation.NewError("validation_token_email_mismatch", "The record email doesn't match with the requested token claims.")
}
return nil
}

View file

@ -0,0 +1,216 @@
package apis_test
import (
"net/http"
"strings"
"testing"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tests"
)
func TestRecordConfirmVerification(t *testing.T) {
t.Parallel()
scenarios := []tests.ApiScenario{
{
Name: "empty data",
Method: http.MethodPost,
URL: "/api/collections/users/confirm-verification",
Body: strings.NewReader(``),
ExpectedStatus: 400,
ExpectedContent: []string{
`"data":{`,
`"token":{"code":"validation_required"`,
},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "invalid data format",
Method: http.MethodPost,
URL: "/api/collections/users/confirm-verification",
Body: strings.NewReader(`{"password`),
ExpectedStatus: 400,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "expired token",
Method: http.MethodPost,
URL: "/api/collections/users/confirm-verification",
Body: strings.NewReader(`{
"token":"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImV4cCI6MTY0MDk5MTY2MSwidHlwZSI6InZlcmlmaWNhdGlvbiIsImNvbGxlY3Rpb25JZCI6Il9wYl91c2Vyc19hdXRoXyIsImVtYWlsIjoidGVzdEBleGFtcGxlLmNvbSJ9.qqelNNL2Udl6K_TJ282sNHYCpASgA6SIuSVKGfBHMZU"
}`),
ExpectedStatus: 400,
ExpectedContent: []string{
`"data":{`,
`"token":{"code":"validation_invalid_token"`,
},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "non-verification token",
Method: http.MethodPost,
URL: "/api/collections/users/confirm-verification",
Body: strings.NewReader(`{
"token":"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImV4cCI6MjUyNDYwNDQ2MSwidHlwZSI6InBhc3N3b3JkUmVzZXQiLCJjb2xsZWN0aW9uSWQiOiJfcGJfdXNlcnNfYXV0aF8iLCJlbWFpbCI6InRlc3RAZXhhbXBsZS5jb20ifQ.xR-xq1oHDy0D8Q4NDOAEyYKGHWd_swzoiSoL8FLFBHY"
}`),
ExpectedStatus: 400,
ExpectedContent: []string{
`"data":{`,
`"token":{"code":"validation_invalid_token"`,
},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "non auth collection",
Method: http.MethodPost,
URL: "/api/collections/demo1/confirm-verification?expand=rel,missing",
Body: strings.NewReader(`{
"token":"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImV4cCI6MjUyNDYwNDQ2MSwidHlwZSI6InZlcmlmaWNhdGlvbiIsImNvbGxlY3Rpb25JZCI6Il9wYl91c2Vyc19hdXRoXyIsImVtYWlsIjoidGVzdEBleGFtcGxlLmNvbSJ9.SetHpu2H-x-q4TIUz-xiQjwi7MNwLCLvSs4O0hUSp0E"
}`),
ExpectedStatus: 404,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "different auth collection",
Method: http.MethodPost,
URL: "/api/collections/clients/confirm-verification?expand=rel,missing",
Body: strings.NewReader(`{
"token":"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImV4cCI6MjUyNDYwNDQ2MSwidHlwZSI6InZlcmlmaWNhdGlvbiIsImNvbGxlY3Rpb25JZCI6Il9wYl91c2Vyc19hdXRoXyIsImVtYWlsIjoidGVzdEBleGFtcGxlLmNvbSJ9.SetHpu2H-x-q4TIUz-xiQjwi7MNwLCLvSs4O0hUSp0E"
}`),
ExpectedStatus: 400,
ExpectedContent: []string{
`"data":{"token":{"code":"validation_token_collection_mismatch"`,
},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "valid token",
Method: http.MethodPost,
URL: "/api/collections/users/confirm-verification",
Body: strings.NewReader(`{
"token":"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImV4cCI6MjUyNDYwNDQ2MSwidHlwZSI6InZlcmlmaWNhdGlvbiIsImNvbGxlY3Rpb25JZCI6Il9wYl91c2Vyc19hdXRoXyIsImVtYWlsIjoidGVzdEBleGFtcGxlLmNvbSJ9.SetHpu2H-x-q4TIUz-xiQjwi7MNwLCLvSs4O0hUSp0E"
}`),
ExpectedStatus: 204,
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordConfirmVerificationRequest": 1,
"OnModelUpdate": 1,
"OnModelValidate": 1,
"OnModelUpdateExecute": 1,
"OnModelAfterUpdateSuccess": 1,
"OnRecordUpdate": 1,
"OnRecordValidate": 1,
"OnRecordUpdateExecute": 1,
"OnRecordAfterUpdateSuccess": 1,
},
},
{
Name: "valid token (already verified)",
Method: http.MethodPost,
URL: "/api/collections/users/confirm-verification",
Body: strings.NewReader(`{
"token":"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6Im9hcDY0MGNvdDR5cnUycyIsImV4cCI6MjUyNDYwNDQ2MSwidHlwZSI6InZlcmlmaWNhdGlvbiIsImNvbGxlY3Rpb25JZCI6Il9wYl91c2Vyc19hdXRoXyIsImVtYWlsIjoidGVzdDJAZXhhbXBsZS5jb20ifQ.QQmM3odNFVk6u4J4-5H8IBM3dfk9YCD7mPW-8PhBAI8"
}`),
ExpectedStatus: 204,
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordConfirmVerificationRequest": 1,
},
},
{
Name: "valid verification token from a collection without allowed login",
Method: http.MethodPost,
URL: "/api/collections/nologin/confirm-verification",
Body: strings.NewReader(`{
"token":"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6ImRjNDlrNmpnZWpuNDBoMyIsImV4cCI6MjUyNDYwNDQ2MSwidHlwZSI6InZlcmlmaWNhdGlvbiIsImNvbGxlY3Rpb25JZCI6ImtwdjcwOXNrMmxxYnFrOCIsImVtYWlsIjoidGVzdEBleGFtcGxlLmNvbSJ9.5GmuZr4vmwk3Cb_3ZZWNxwbE75KZC-j71xxIPR9AsVw"
}`),
ExpectedStatus: 204,
ExpectedContent: []string{},
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordConfirmVerificationRequest": 1,
"OnModelUpdate": 1,
"OnModelValidate": 1,
"OnModelUpdateExecute": 1,
"OnModelAfterUpdateSuccess": 1,
"OnRecordUpdate": 1,
"OnRecordValidate": 1,
"OnRecordUpdateExecute": 1,
"OnRecordAfterUpdateSuccess": 1,
},
},
{
Name: "OnRecordConfirmVerificationRequest tx body write check",
Method: http.MethodPost,
URL: "/api/collections/users/confirm-verification",
Body: strings.NewReader(`{
"token":"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImV4cCI6MjUyNDYwNDQ2MSwidHlwZSI6InZlcmlmaWNhdGlvbiIsImNvbGxlY3Rpb25JZCI6Il9wYl91c2Vyc19hdXRoXyIsImVtYWlsIjoidGVzdEBleGFtcGxlLmNvbSJ9.SetHpu2H-x-q4TIUz-xiQjwi7MNwLCLvSs4O0hUSp0E"
}`),
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
app.OnRecordConfirmVerificationRequest().BindFunc(func(e *core.RecordConfirmVerificationRequestEvent) error {
original := e.App
return e.App.RunInTransaction(func(txApp core.App) error {
e.App = txApp
defer func() { e.App = original }()
if err := e.Next(); err != nil {
return err
}
return e.BadRequestError("TX_ERROR", nil)
})
})
},
ExpectedStatus: 400,
ExpectedEvents: map[string]int{"OnRecordConfirmVerificationRequest": 1},
ExpectedContent: []string{"TX_ERROR"},
},
// rate limit checks
// -----------------------------------------------------------
{
Name: "RateLimit rule - nologin:confirmVerification",
Method: http.MethodPost,
URL: "/api/collections/nologin/confirm-verification",
Body: strings.NewReader(`{
"token":"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6ImRjNDlrNmpnZWpuNDBoMyIsImV4cCI6MjUyNDYwNDQ2MSwidHlwZSI6InZlcmlmaWNhdGlvbiIsImNvbGxlY3Rpb25JZCI6ImtwdjcwOXNrMmxxYnFrOCIsImVtYWlsIjoidGVzdEBleGFtcGxlLmNvbSJ9.5GmuZr4vmwk3Cb_3ZZWNxwbE75KZC-j71xxIPR9AsVw"
}`),
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
app.Settings().RateLimits.Enabled = true
app.Settings().RateLimits.Rules = []core.RateLimitRule{
{MaxRequests: 100, Label: "abc"},
{MaxRequests: 100, Label: "*:confirmVerification"},
{MaxRequests: 0, Label: "nologin:confirmVerification"},
}
},
ExpectedStatus: 429,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "RateLimit rule - *:confirmVerification",
Method: http.MethodPost,
URL: "/api/collections/nologin/confirm-verification",
Body: strings.NewReader(`{
"token":"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6ImRjNDlrNmpnZWpuNDBoMyIsImV4cCI6MjUyNDYwNDQ2MSwidHlwZSI6InZlcmlmaWNhdGlvbiIsImNvbGxlY3Rpb25JZCI6ImtwdjcwOXNrMmxxYnFrOCIsImVtYWlsIjoidGVzdEBleGFtcGxlLmNvbSJ9.5GmuZr4vmwk3Cb_3ZZWNxwbE75KZC-j71xxIPR9AsVw"
}`),
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
app.Settings().RateLimits.Enabled = true
app.Settings().RateLimits.Rules = []core.RateLimitRule{
{MaxRequests: 100, Label: "abc"},
{MaxRequests: 0, Label: "*:confirmVerification"},
}
},
ExpectedStatus: 429,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
}
for _, scenario := range scenarios {
scenario.Test(t)
}
}

View file

@ -0,0 +1,91 @@
package apis
import (
"errors"
"fmt"
"net/http"
"time"
validation "github.com/go-ozzo/ozzo-validation/v4"
"github.com/go-ozzo/ozzo-validation/v4/is"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/mails"
"github.com/pocketbase/pocketbase/tools/routine"
)
func recordRequestVerification(e *core.RequestEvent) error {
collection, err := findAuthCollection(e)
if err != nil {
return err
}
if collection.Name == core.CollectionNameSuperusers {
return e.BadRequestError("All superusers are verified by default.", nil)
}
form := new(recordRequestVerificationForm)
if err = e.BindBody(form); err != nil {
return firstApiError(err, e.BadRequestError("An error occurred while loading the submitted data.", err))
}
if err = form.validate(); err != nil {
return firstApiError(err, e.BadRequestError("An error occurred while validating the submitted data.", err))
}
record, err := e.App.FindAuthRecordByEmail(collection, form.Email)
if err != nil {
// eagerly write 204 response as a very basic measure against emails enumeration
e.NoContent(http.StatusNoContent)
return fmt.Errorf("failed to fetch %s record with email %s: %w", collection.Name, form.Email, err)
}
resendKey := getVerificationResendKey(record)
if !record.Verified() && e.App.Store().Has(resendKey) {
// eagerly write 204 response as a very basic measure against emails enumeration
e.NoContent(http.StatusNoContent)
return errors.New("try again later - you've already requested a verification email")
}
event := new(core.RecordRequestVerificationRequestEvent)
event.RequestEvent = e
event.Collection = collection
event.Record = record
return e.App.OnRecordRequestVerificationRequest().Trigger(event, func(e *core.RecordRequestVerificationRequestEvent) error {
if e.Record.Verified() {
return e.NoContent(http.StatusNoContent)
}
// run in background because we don't need to show the result to the client
app := e.App
routine.FireAndForget(func() {
if err := mails.SendRecordVerification(app, e.Record); err != nil {
app.Logger().Error("Failed to send verification email", "error", err)
}
app.Store().Set(resendKey, struct{}{})
time.AfterFunc(2*time.Minute, func() {
app.Store().Remove(resendKey)
})
})
return execAfterSuccessTx(true, e.App, func() error {
return e.NoContent(http.StatusNoContent)
})
})
}
// -------------------------------------------------------------------
type recordRequestVerificationForm struct {
Email string `form:"email" json:"email"`
}
func (form *recordRequestVerificationForm) validate() error {
return validation.ValidateStruct(form,
validation.Field(&form.Email, validation.Required, validation.Length(1, 255), is.EmailFormat),
)
}
func getVerificationResendKey(record *core.Record) string {
return "@limitVerificationEmail_" + record.Collection().Id + record.Id
}

View file

@ -0,0 +1,186 @@
package apis_test
import (
"net/http"
"strings"
"testing"
"time"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tests"
)
func TestRecordRequestVerification(t *testing.T) {
t.Parallel()
scenarios := []tests.ApiScenario{
{
Name: "not an auth collection",
Method: http.MethodPost,
URL: "/api/collections/demo1/request-verification",
Body: strings.NewReader(``),
ExpectedStatus: 404,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "empty data",
Method: http.MethodPost,
URL: "/api/collections/users/request-verification",
Body: strings.NewReader(``),
ExpectedStatus: 400,
ExpectedContent: []string{`"data":{"email":{"code":"validation_required","message":"Cannot be blank."}}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "invalid data",
Method: http.MethodPost,
URL: "/api/collections/users/request-verification",
Body: strings.NewReader(`{"email`),
ExpectedStatus: 400,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "missing auth record",
Method: http.MethodPost,
URL: "/api/collections/users/request-verification",
Body: strings.NewReader(`{"email":"missing@example.com"}`),
Delay: 100 * time.Millisecond,
ExpectedStatus: 204,
ExpectedEvents: map[string]int{"*": 0},
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
if app.TestMailer.TotalSend() != 0 {
t.Fatalf("Expected zero emails, got %d", app.TestMailer.TotalSend())
}
},
},
{
Name: "already verified auth record",
Method: http.MethodPost,
URL: "/api/collections/users/request-verification",
Body: strings.NewReader(`{"email":"test2@example.com"}`),
Delay: 100 * time.Millisecond,
ExpectedStatus: 204,
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordRequestVerificationRequest": 1,
},
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
if app.TestMailer.TotalSend() != 0 {
t.Fatalf("Expected zero emails, got %d", app.TestMailer.TotalSend())
}
},
},
{
Name: "existing auth record",
Method: http.MethodPost,
URL: "/api/collections/users/request-verification",
Body: strings.NewReader(`{"email":"test@example.com"}`),
Delay: 100 * time.Millisecond,
ExpectedStatus: 204,
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordRequestVerificationRequest": 1,
"OnMailerSend": 1,
"OnMailerRecordVerificationSend": 1,
},
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
if !strings.Contains(app.TestMailer.LastMessage().HTML, "/auth/confirm-verification") {
t.Fatalf("Expected verification email, got\n%v", app.TestMailer.LastMessage().HTML)
}
},
},
{
Name: "existing auth record (after already sent)",
Method: http.MethodPost,
URL: "/api/collections/users/request-verification",
Body: strings.NewReader(`{"email":"test@example.com"}`),
Delay: 100 * time.Millisecond,
ExpectedStatus: 204,
ExpectedEvents: map[string]int{
"*": 0,
// terminated before firing the event
// "OnRecordRequestVerificationRequest": 1,
},
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
// simulate recent verification sent
authRecord, err := app.FindFirstRecordByData("users", "email", "test@example.com")
if err != nil {
t.Fatal(err)
}
resendKey := "@limitVerificationEmail_" + authRecord.Collection().Id + authRecord.Id
app.Store().Set(resendKey, struct{}{})
},
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
if app.TestMailer.TotalSend() != 0 {
t.Fatalf("Expected zero emails, got %d", app.TestMailer.TotalSend())
}
},
},
{
Name: "OnRecordRequestVerificationRequest tx body write check",
Method: http.MethodPost,
URL: "/api/collections/users/request-verification",
Body: strings.NewReader(`{"email":"test@example.com"}`),
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
app.OnRecordRequestVerificationRequest().BindFunc(func(e *core.RecordRequestVerificationRequestEvent) error {
original := e.App
return e.App.RunInTransaction(func(txApp core.App) error {
e.App = txApp
defer func() { e.App = original }()
if err := e.Next(); err != nil {
return err
}
return e.BadRequestError("TX_ERROR", nil)
})
})
},
ExpectedStatus: 400,
ExpectedEvents: map[string]int{"OnRecordRequestVerificationRequest": 1},
ExpectedContent: []string{"TX_ERROR"},
},
// rate limit checks
// -----------------------------------------------------------
{
Name: "RateLimit rule - users:requestVerification",
Method: http.MethodPost,
URL: "/api/collections/users/request-verification",
Body: strings.NewReader(`{"email":"test@example.com"}`),
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
app.Settings().RateLimits.Enabled = true
app.Settings().RateLimits.Rules = []core.RateLimitRule{
{MaxRequests: 100, Label: "abc"},
{MaxRequests: 100, Label: "*:requestVerification"},
{MaxRequests: 0, Label: "users:requestVerification"},
}
},
ExpectedStatus: 429,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "RateLimit rule - *:requestVerification",
Method: http.MethodPost,
URL: "/api/collections/users/request-verification",
Body: strings.NewReader(`{"email":"test@example.com"}`),
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
app.Settings().RateLimits.Enabled = true
app.Settings().RateLimits.Rules = []core.RateLimitRule{
{MaxRequests: 100, Label: "abc"},
{MaxRequests: 0, Label: "*:requestVerification"},
}
},
ExpectedStatus: 429,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
}
for _, scenario := range scenarios {
scenario.Test(t)
}
}

View file

@ -0,0 +1,374 @@
package apis
import (
"context"
"database/sql"
"encoding/json"
"errors"
"fmt"
"log/slog"
"maps"
"net/http"
"strings"
"time"
validation "github.com/go-ozzo/ozzo-validation/v4"
"github.com/pocketbase/dbx"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tools/dbutils"
"github.com/pocketbase/pocketbase/tools/filesystem"
"golang.org/x/oauth2"
)
func recordAuthWithOAuth2(e *core.RequestEvent) error {
collection, err := findAuthCollection(e)
if err != nil {
return err
}
if !collection.OAuth2.Enabled {
return e.ForbiddenError("The collection is not configured to allow OAuth2 authentication.", nil)
}
var fallbackAuthRecord *core.Record
if e.Auth != nil && e.Auth.Collection().Id == collection.Id {
fallbackAuthRecord = e.Auth
}
e.Set(core.RequestEventKeyInfoContext, core.RequestInfoContextOAuth2)
form := new(recordOAuth2LoginForm)
form.collection = collection
if err = e.BindBody(form); err != nil {
return firstApiError(err, e.BadRequestError("An error occurred while loading the submitted data.", err))
}
if form.RedirectUrl != "" && form.RedirectURL == "" {
e.App.Logger().Warn("[recordAuthWithOAuth2] redirectUrl body param is deprecated and will be removed in the future. Please replace it with redirectURL.")
form.RedirectURL = form.RedirectUrl
}
if err = form.validate(); err != nil {
return firstApiError(err, e.BadRequestError("An error occurred while loading the submitted data.", err))
}
// exchange token for OAuth2 user info and locate existing ExternalAuth rel
// ---------------------------------------------------------------
// load provider configuration
providerConfig, ok := collection.OAuth2.GetProviderConfig(form.Provider)
if !ok {
return e.InternalServerError("Missing or invalid provider config.", nil)
}
provider, err := providerConfig.InitProvider()
if err != nil {
return firstApiError(err, e.InternalServerError("Failed to init provider "+form.Provider, err))
}
ctx, cancel := context.WithTimeout(e.Request.Context(), 30*time.Second)
defer cancel()
provider.SetContext(ctx)
provider.SetRedirectURL(form.RedirectURL)
var opts []oauth2.AuthCodeOption
if provider.PKCE() {
opts = append(opts, oauth2.SetAuthURLParam("code_verifier", form.CodeVerifier))
}
// fetch token
token, err := provider.FetchToken(form.Code, opts...)
if err != nil {
return firstApiError(err, e.BadRequestError("Failed to fetch OAuth2 token.", err))
}
// fetch external auth user
authUser, err := provider.FetchAuthUser(token)
if err != nil {
return firstApiError(err, e.BadRequestError("Failed to fetch OAuth2 user.", err))
}
var authRecord *core.Record
// check for existing relation with the auth collection
externalAuthRel, err := e.App.FindFirstExternalAuthByExpr(dbx.HashExp{
"collectionRef": form.collection.Id,
"provider": form.Provider,
"providerId": authUser.Id,
})
if err != nil && !errors.Is(err, sql.ErrNoRows) {
return e.InternalServerError("Failed OAuth2 relation check.", err)
}
switch {
case err == nil && externalAuthRel != nil:
authRecord, err = e.App.FindRecordById(form.collection, externalAuthRel.RecordRef())
if err != nil {
return err
}
case fallbackAuthRecord != nil && fallbackAuthRecord.Collection().Id == form.collection.Id:
// fallback to the logged auth record (if any)
authRecord = fallbackAuthRecord
case authUser.Email != "":
// look for an existing auth record by the external auth record's email
authRecord, err = e.App.FindAuthRecordByEmail(form.collection.Id, authUser.Email)
if err != nil && !errors.Is(err, sql.ErrNoRows) {
return e.InternalServerError("Failed OAuth2 auth record check.", err)
}
}
// ---------------------------------------------------------------
event := new(core.RecordAuthWithOAuth2RequestEvent)
event.RequestEvent = e
event.Collection = collection
event.ProviderName = form.Provider
event.ProviderClient = provider
event.OAuth2User = authUser
event.CreateData = form.CreateData
event.Record = authRecord
event.IsNewRecord = authRecord == nil
return e.App.OnRecordAuthWithOAuth2Request().Trigger(event, func(e *core.RecordAuthWithOAuth2RequestEvent) error {
if err := oauth2Submit(e, externalAuthRel); err != nil {
return firstApiError(err, e.BadRequestError("Failed to authenticate.", err))
}
// @todo revert back to struct after removing the custom auth.AuthUser marshalization
meta := map[string]any{}
rawOAuth2User, err := json.Marshal(e.OAuth2User)
if err != nil {
return err
}
err = json.Unmarshal(rawOAuth2User, &meta)
if err != nil {
return err
}
meta["isNew"] = e.IsNewRecord
return RecordAuthResponse(e.RequestEvent, e.Record, core.MFAMethodOAuth2, meta)
})
}
// -------------------------------------------------------------------
type recordOAuth2LoginForm struct {
collection *core.Collection
// Additional data that will be used for creating a new auth record
// if an existing OAuth2 account doesn't exist.
CreateData map[string]any `form:"createData" json:"createData"`
// The name of the OAuth2 client provider (eg. "google")
Provider string `form:"provider" json:"provider"`
// The authorization code returned from the initial request.
Code string `form:"code" json:"code"`
// The optional PKCE code verifier as part of the code_challenge sent with the initial request.
CodeVerifier string `form:"codeVerifier" json:"codeVerifier"`
// The redirect url sent with the initial request.
RedirectURL string `form:"redirectURL" json:"redirectURL"`
// @todo
// deprecated: use RedirectURL instead
// RedirectUrl will be removed after dropping v0.22 support
RedirectUrl string `form:"redirectUrl" json:"redirectUrl"`
}
func (form *recordOAuth2LoginForm) validate() error {
return validation.ValidateStruct(form,
validation.Field(&form.Provider, validation.Required, validation.Length(0, 100), validation.By(form.checkProviderName)),
validation.Field(&form.Code, validation.Required),
validation.Field(&form.RedirectURL, validation.Required),
)
}
func (form *recordOAuth2LoginForm) checkProviderName(value any) error {
name, _ := value.(string)
_, ok := form.collection.OAuth2.GetProviderConfig(name)
if !ok {
return validation.NewError("validation_invalid_provider", "Provider with name {{.name}} is missing or is not enabled.").
SetParams(map[string]any{"name": name})
}
return nil
}
func oldCanAssignUsername(txApp core.App, collection *core.Collection, username string) bool {
// ensure that username is unique
index, hasUniqueue := dbutils.FindSingleColumnUniqueIndex(collection.Indexes, collection.OAuth2.MappedFields.Username)
if hasUniqueue {
var expr dbx.Expression
if strings.EqualFold(index.Columns[0].Collate, "nocase") {
// case-insensitive search
expr = dbx.NewExp("username = {:username} COLLATE NOCASE", dbx.Params{"username": username})
} else {
expr = dbx.HashExp{"username": username}
}
var exists int
_ = txApp.RecordQuery(collection).Select("(1)").AndWhere(expr).Limit(1).Row(&exists)
if exists > 0 {
return false
}
}
// ensure that the value matches the pattern of the username field (if text)
txtField, _ := collection.Fields.GetByName(collection.OAuth2.MappedFields.Username).(*core.TextField)
return txtField != nil && txtField.ValidatePlainValue(username) == nil
}
func oauth2Submit(e *core.RecordAuthWithOAuth2RequestEvent, optExternalAuth *core.ExternalAuth) error {
return e.App.RunInTransaction(func(txApp core.App) error {
if e.Record == nil {
// extra check to prevent creating a superuser record via
// OAuth2 in case the method is used by another action
if e.Collection.Name == core.CollectionNameSuperusers {
return errors.New("superusers are not allowed to sign-up with OAuth2")
}
payload := maps.Clone(e.CreateData)
if payload == nil {
payload = map[string]any{}
}
// assign the OAuth2 user email only if the user hasn't submitted one
// (ignore empty/invalid values for consistency with the OAuth2->existing user update flow)
if v, _ := payload[core.FieldNameEmail].(string); v == "" {
payload[core.FieldNameEmail] = e.OAuth2User.Email
}
// map known fields (unless the field was explicitly submitted as part of CreateData)
if _, ok := payload[e.Collection.OAuth2.MappedFields.Id]; !ok && e.Collection.OAuth2.MappedFields.Id != "" {
payload[e.Collection.OAuth2.MappedFields.Id] = e.OAuth2User.Id
}
if _, ok := payload[e.Collection.OAuth2.MappedFields.Name]; !ok && e.Collection.OAuth2.MappedFields.Name != "" {
payload[e.Collection.OAuth2.MappedFields.Name] = e.OAuth2User.Name
}
if _, ok := payload[e.Collection.OAuth2.MappedFields.Username]; !ok &&
// no explicit username payload value and existing OAuth2 mapping
e.Collection.OAuth2.MappedFields.Username != "" &&
// extra checks for backward compatibility with earlier versions
oldCanAssignUsername(txApp, e.Collection, e.OAuth2User.Username) {
payload[e.Collection.OAuth2.MappedFields.Username] = e.OAuth2User.Username
}
if _, ok := payload[e.Collection.OAuth2.MappedFields.AvatarURL]; !ok &&
// no explicit avatar payload value and existing OAuth2 mapping
e.Collection.OAuth2.MappedFields.AvatarURL != "" &&
// non-empty OAuth2 avatar url
e.OAuth2User.AvatarURL != "" {
mappedField := e.Collection.Fields.GetByName(e.Collection.OAuth2.MappedFields.AvatarURL)
if mappedField != nil && mappedField.Type() == core.FieldTypeFile {
// download the avatar if the mapped field is a file
avatarFile, err := func() (*filesystem.File, error) {
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
defer cancel()
return filesystem.NewFileFromURL(ctx, e.OAuth2User.AvatarURL)
}()
if err != nil {
txApp.Logger().Warn("Failed to retrieve OAuth2 avatar", slog.String("error", err.Error()))
} else {
payload[e.Collection.OAuth2.MappedFields.AvatarURL] = avatarFile
}
} else {
// otherwise - assign the url string
payload[e.Collection.OAuth2.MappedFields.AvatarURL] = e.OAuth2User.AvatarURL
}
}
createdRecord, err := sendOAuth2RecordCreateRequest(txApp, e, payload)
if err != nil {
return err
}
e.Record = createdRecord
if e.Record.Email() == e.OAuth2User.Email && !e.Record.Verified() {
// mark as verified as long as it matches the OAuth2 data (even if the email is empty)
e.Record.SetVerified(true)
if err := txApp.Save(e.Record); err != nil {
return err
}
}
} else {
var needUpdate bool
isLoggedAuthRecord := e.Auth != nil &&
e.Auth.Id == e.Record.Id &&
e.Auth.Collection().Id == e.Record.Collection().Id
// set random password for users with unverified email
// (this is in case a malicious actor has registered previously with the user email)
if !isLoggedAuthRecord && e.Record.Email() != "" && !e.Record.Verified() {
e.Record.SetRandomPassword()
needUpdate = true
}
// update the existing auth record empty email if the data.OAuth2User has one
// (this is in case previously the auth record was created
// with an OAuth2 provider that didn't return an email address)
if e.Record.Email() == "" && e.OAuth2User.Email != "" {
e.Record.SetEmail(e.OAuth2User.Email)
needUpdate = true
}
// update the existing auth record verified state
// (only if the auth record doesn't have an email or the auth record email match with the one in data.OAuth2User)
if !e.Record.Verified() && (e.Record.Email() == "" || e.Record.Email() == e.OAuth2User.Email) {
e.Record.SetVerified(true)
needUpdate = true
}
if needUpdate {
if err := txApp.Save(e.Record); err != nil {
return err
}
}
}
// create ExternalAuth relation if missing
if optExternalAuth == nil {
optExternalAuth = core.NewExternalAuth(txApp)
optExternalAuth.SetCollectionRef(e.Record.Collection().Id)
optExternalAuth.SetRecordRef(e.Record.Id)
optExternalAuth.SetProvider(e.ProviderName)
optExternalAuth.SetProviderId(e.OAuth2User.Id)
if err := txApp.Save(optExternalAuth); err != nil {
return fmt.Errorf("failed to save linked rel: %w", err)
}
}
return nil
})
}
func sendOAuth2RecordCreateRequest(txApp core.App, e *core.RecordAuthWithOAuth2RequestEvent, payload map[string]any) (*core.Record, error) {
ir := &core.InternalRequest{
Method: http.MethodPost,
URL: "/api/collections/" + e.Collection.Name + "/records",
Body: payload,
}
var createdRecord *core.Record
response, err := processInternalRequest(txApp, e.RequestEvent, ir, core.RequestInfoContextOAuth2, func(data any) error {
createdRecord, _ = data.(*core.Record)
return nil
})
if err != nil {
return nil, err
}
if response.Status != http.StatusOK || createdRecord == nil {
return nil, errors.New("failed to create OAuth2 auth record")
}
return createdRecord, nil
}

View file

@ -0,0 +1,74 @@
package apis
import (
"encoding/json"
"net/http"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tools/subscriptions"
)
const (
oauth2SubscriptionTopic string = "@oauth2"
oauth2RedirectFailurePath string = "../_/#/auth/oauth2-redirect-failure"
oauth2RedirectSuccessPath string = "../_/#/auth/oauth2-redirect-success"
)
type oauth2RedirectData struct {
State string `form:"state" json:"state"`
Code string `form:"code" json:"code"`
Error string `form:"error" json:"error,omitempty"`
}
func oauth2SubscriptionRedirect(e *core.RequestEvent) error {
redirectStatusCode := http.StatusTemporaryRedirect
if e.Request.Method != http.MethodGet {
redirectStatusCode = http.StatusSeeOther
}
data := oauth2RedirectData{}
if e.Request.Method == http.MethodPost {
if err := e.BindBody(&data); err != nil {
e.App.Logger().Debug("Failed to read OAuth2 redirect data", "error", err)
return e.Redirect(redirectStatusCode, oauth2RedirectFailurePath)
}
} else {
query := e.Request.URL.Query()
data.State = query.Get("state")
data.Code = query.Get("code")
data.Error = query.Get("error")
}
if data.State == "" {
e.App.Logger().Debug("Missing OAuth2 state parameter")
return e.Redirect(redirectStatusCode, oauth2RedirectFailurePath)
}
client, err := e.App.SubscriptionsBroker().ClientById(data.State)
if err != nil || client.IsDiscarded() || !client.HasSubscription(oauth2SubscriptionTopic) {
e.App.Logger().Debug("Missing or invalid OAuth2 subscription client", "error", err, "clientId", data.State)
return e.Redirect(redirectStatusCode, oauth2RedirectFailurePath)
}
defer client.Unsubscribe(oauth2SubscriptionTopic)
encodedData, err := json.Marshal(data)
if err != nil {
e.App.Logger().Debug("Failed to marshalize OAuth2 redirect data", "error", err)
return e.Redirect(redirectStatusCode, oauth2RedirectFailurePath)
}
msg := subscriptions.Message{
Name: oauth2SubscriptionTopic,
Data: encodedData,
}
client.Send(msg)
if data.Error != "" || data.Code == "" {
e.App.Logger().Debug("Failed OAuth2 redirect due to an error or missing code parameter", "error", data.Error, "clientId", data.State)
return e.Redirect(redirectStatusCode, oauth2RedirectFailurePath)
}
return e.Redirect(redirectStatusCode, oauth2RedirectSuccessPath)
}

View file

@ -0,0 +1,274 @@
package apis_test
import (
"context"
"net/http"
"strings"
"testing"
"time"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tests"
"github.com/pocketbase/pocketbase/tools/subscriptions"
)
func TestRecordAuthWithOAuth2Redirect(t *testing.T) {
t.Parallel()
clientStubs := make([]map[string]subscriptions.Client, 0, 10)
for i := 0; i < 10; i++ {
c1 := subscriptions.NewDefaultClient()
c2 := subscriptions.NewDefaultClient()
c2.Subscribe("@oauth2")
c3 := subscriptions.NewDefaultClient()
c3.Subscribe("test1", "@oauth2")
c4 := subscriptions.NewDefaultClient()
c4.Subscribe("test1", "test2")
c5 := subscriptions.NewDefaultClient()
c5.Subscribe("@oauth2")
c5.Discard()
clientStubs = append(clientStubs, map[string]subscriptions.Client{
"c1": c1,
"c2": c2,
"c3": c3,
"c4": c4,
"c5": c5,
})
}
checkFailureRedirect := func(t testing.TB, app *tests.TestApp, res *http.Response) {
loc := res.Header.Get("Location")
if !strings.Contains(loc, "/oauth2-redirect-failure") {
t.Fatalf("Expected failure redirect, got %q", loc)
}
}
checkSuccessRedirect := func(t testing.TB, app *tests.TestApp, res *http.Response) {
loc := res.Header.Get("Location")
if !strings.Contains(loc, "/oauth2-redirect-success") {
t.Fatalf("Expected success redirect, got %q", loc)
}
}
// note: don't exit because it is usually called as part of a separate goroutine
checkClientMessages := func(t testing.TB, clientId string, msg subscriptions.Message, expectedMessages map[string][]string) {
if len(expectedMessages[clientId]) == 0 {
t.Errorf("Unexpected client %q message, got %q:\n%q", clientId, msg.Name, msg.Data)
return
}
if msg.Name != "@oauth2" {
t.Errorf("Expected @oauth2 msg.Name, got %q", msg.Name)
return
}
for _, txt := range expectedMessages[clientId] {
if !strings.Contains(string(msg.Data), txt) {
t.Errorf("Failed to find %q in \n%s", txt, msg.Data)
return
}
}
}
beforeTestFunc := func(
clients map[string]subscriptions.Client,
expectedMessages map[string][]string,
) func(testing.TB, *tests.TestApp, *core.ServeEvent) {
return func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
for _, client := range clients {
app.SubscriptionsBroker().Register(client)
}
ctx, cancelFunc := context.WithTimeout(context.Background(), 100*time.Millisecond)
// add to the app store so that it can be cancelled manually after test completion
app.Store().Set("cancelFunc", cancelFunc)
go func() {
defer cancelFunc()
for {
select {
case msg, ok := <-clients["c1"].Channel():
if ok {
checkClientMessages(t, "c1", msg, expectedMessages)
} else {
t.Errorf("Unexpected c1 closed channel")
}
case msg, ok := <-clients["c2"].Channel():
if ok {
checkClientMessages(t, "c2", msg, expectedMessages)
} else {
t.Errorf("Unexpected c2 closed channel")
}
case msg, ok := <-clients["c3"].Channel():
if ok {
checkClientMessages(t, "c3", msg, expectedMessages)
} else {
t.Errorf("Unexpected c3 closed channel")
}
case msg, ok := <-clients["c4"].Channel():
if ok {
checkClientMessages(t, "c4", msg, expectedMessages)
} else {
t.Errorf("Unexpected c4 closed channel")
}
case _, ok := <-clients["c5"].Channel():
if ok {
t.Errorf("Expected c5 channel to be closed")
}
case <-ctx.Done():
for _, c := range clients {
c.Discard()
}
return
}
}
}()
}
}
scenarios := []tests.ApiScenario{
{
Name: "no state query param",
Method: http.MethodGet,
URL: "/api/oauth2-redirect?code=123",
BeforeTestFunc: beforeTestFunc(clientStubs[0], nil),
ExpectedStatus: http.StatusTemporaryRedirect,
ExpectedEvents: map[string]int{"*": 0},
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
app.Store().Get("cancelFunc").(context.CancelFunc)()
checkFailureRedirect(t, app, res)
},
},
{
Name: "invalid or missing client",
Method: http.MethodGet,
URL: "/api/oauth2-redirect?code=123&state=missing",
BeforeTestFunc: beforeTestFunc(clientStubs[1], nil),
ExpectedStatus: http.StatusTemporaryRedirect,
ExpectedEvents: map[string]int{"*": 0},
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
app.Store().Get("cancelFunc").(context.CancelFunc)()
checkFailureRedirect(t, app, res)
},
},
{
Name: "no code query param",
Method: http.MethodGet,
URL: "/api/oauth2-redirect?state=" + clientStubs[2]["c3"].Id(),
BeforeTestFunc: beforeTestFunc(clientStubs[2], map[string][]string{
"c3": {`"state":"` + clientStubs[2]["c3"].Id(), `"code":""`},
}),
ExpectedStatus: http.StatusTemporaryRedirect,
ExpectedEvents: map[string]int{"*": 0},
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
app.Store().Get("cancelFunc").(context.CancelFunc)()
checkFailureRedirect(t, app, res)
if clientStubs[2]["c3"].HasSubscription("@oauth2") {
t.Fatalf("Expected oauth2 subscription to be removed")
}
},
},
{
Name: "error query param",
Method: http.MethodGet,
URL: "/api/oauth2-redirect?error=example&code=123&state=" + clientStubs[3]["c3"].Id(),
BeforeTestFunc: beforeTestFunc(clientStubs[3], map[string][]string{
"c3": {`"state":"` + clientStubs[3]["c3"].Id(), `"code":"123"`, `"error":"example"`},
}),
ExpectedStatus: http.StatusTemporaryRedirect,
ExpectedEvents: map[string]int{"*": 0},
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
app.Store().Get("cancelFunc").(context.CancelFunc)()
checkFailureRedirect(t, app, res)
if clientStubs[3]["c3"].HasSubscription("@oauth2") {
t.Fatalf("Expected oauth2 subscription to be removed")
}
},
},
{
Name: "discarded client with @oauth2 subscription",
Method: http.MethodGet,
URL: "/api/oauth2-redirect?code=123&state=" + clientStubs[4]["c5"].Id(),
BeforeTestFunc: beforeTestFunc(clientStubs[4], nil),
ExpectedStatus: http.StatusTemporaryRedirect,
ExpectedEvents: map[string]int{"*": 0},
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
app.Store().Get("cancelFunc").(context.CancelFunc)()
checkFailureRedirect(t, app, res)
},
},
{
Name: "client without @oauth2 subscription",
Method: http.MethodGet,
URL: "/api/oauth2-redirect?code=123&state=" + clientStubs[4]["c4"].Id(),
BeforeTestFunc: beforeTestFunc(clientStubs[5], nil),
ExpectedStatus: http.StatusTemporaryRedirect,
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
app.Store().Get("cancelFunc").(context.CancelFunc)()
checkFailureRedirect(t, app, res)
},
},
{
Name: "client with @oauth2 subscription",
Method: http.MethodGet,
URL: "/api/oauth2-redirect?code=123&state=" + clientStubs[6]["c3"].Id(),
BeforeTestFunc: beforeTestFunc(clientStubs[6], map[string][]string{
"c3": {`"state":"` + clientStubs[6]["c3"].Id(), `"code":"123"`},
}),
ExpectedStatus: http.StatusTemporaryRedirect,
ExpectedEvents: map[string]int{"*": 0},
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
app.Store().Get("cancelFunc").(context.CancelFunc)()
checkSuccessRedirect(t, app, res)
if clientStubs[6]["c3"].HasSubscription("@oauth2") {
t.Fatalf("Expected oauth2 subscription to be removed")
}
},
},
{
Name: "(POST) client with @oauth2 subscription",
Method: http.MethodPost,
URL: "/api/oauth2-redirect",
Body: strings.NewReader("code=123&state=" + clientStubs[7]["c3"].Id()),
Headers: map[string]string{
"content-type": "application/x-www-form-urlencoded",
},
BeforeTestFunc: beforeTestFunc(clientStubs[7], map[string][]string{
"c3": {`"state":"` + clientStubs[7]["c3"].Id(), `"code":"123"`},
}),
ExpectedStatus: http.StatusSeeOther,
ExpectedEvents: map[string]int{"*": 0},
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
app.Store().Get("cancelFunc").(context.CancelFunc)()
checkSuccessRedirect(t, app, res)
if clientStubs[7]["c3"].HasSubscription("@oauth2") {
t.Fatalf("Expected oauth2 subscription to be removed")
}
},
},
}
for _, scenario := range scenarios {
scenario.Test(t)
}
}

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,106 @@
package apis
import (
"errors"
"fmt"
validation "github.com/go-ozzo/ozzo-validation/v4"
"github.com/pocketbase/pocketbase/core"
)
func recordAuthWithOTP(e *core.RequestEvent) error {
collection, err := findAuthCollection(e)
if err != nil {
return err
}
if !collection.OTP.Enabled {
return e.ForbiddenError("The collection is not configured to allow OTP authentication.", nil)
}
form := &authWithOTPForm{}
if err = e.BindBody(form); err != nil {
return firstApiError(err, e.BadRequestError("An error occurred while loading the submitted data.", err))
}
if err = form.validate(); err != nil {
return firstApiError(err, e.BadRequestError("An error occurred while validating the submitted data.", err))
}
e.Set(core.RequestEventKeyInfoContext, core.RequestInfoContextOTP)
event := new(core.RecordAuthWithOTPRequestEvent)
event.RequestEvent = e
event.Collection = collection
// extra validations
// (note: returns a generic 400 as a very basic OTPs enumeration protection)
// ---
event.OTP, err = e.App.FindOTPById(form.OTPId)
if err != nil {
return e.BadRequestError("Invalid or expired OTP", err)
}
if event.OTP.CollectionRef() != collection.Id {
return e.BadRequestError("Invalid or expired OTP", errors.New("the OTP is for a different collection"))
}
if event.OTP.HasExpired(collection.OTP.DurationTime()) {
return e.BadRequestError("Invalid or expired OTP", errors.New("the OTP is expired"))
}
event.Record, err = e.App.FindRecordById(event.OTP.CollectionRef(), event.OTP.RecordRef())
if err != nil {
return e.BadRequestError("Invalid or expired OTP", fmt.Errorf("missing auth record: %w", err))
}
// since otps are usually simple digit numbers, enforce an extra rate limit rule as basic enumaration protection
err = checkRateLimit(e, "@pb_otp_"+event.Record.Id, core.RateLimitRule{MaxRequests: 5, Duration: 180})
if err != nil {
return e.TooManyRequestsError("Too many attempts, please try again later with a new OTP.", nil)
}
if !event.OTP.ValidatePassword(form.Password) {
return e.BadRequestError("Invalid or expired OTP", errors.New("incorrect password"))
}
// ---
return e.App.OnRecordAuthWithOTPRequest().Trigger(event, func(e *core.RecordAuthWithOTPRequestEvent) error {
// update the user email verified state in case the OTP originate from an email address matching the current record one
//
// note: don't wait for success auth response (it could fail because of MFA) and because we already validated the OTP above
otpSentTo := e.OTP.SentTo()
if !e.Record.Verified() && otpSentTo != "" && e.Record.Email() == otpSentTo {
e.Record.SetVerified(true)
err = e.App.Save(e.Record)
if err != nil {
e.App.Logger().Error("Failed to update record verified state after successful OTP validation",
"error", err,
"otpId", e.OTP.Id,
"recordId", e.Record.Id,
)
}
}
// try to delete the used otp
err = e.App.Delete(e.OTP)
if err != nil {
e.App.Logger().Error("Failed to delete used OTP", "error", err, "otpId", e.OTP.Id)
}
return RecordAuthResponse(e.RequestEvent, e.Record, core.MFAMethodOTP, nil)
})
}
// -------------------------------------------------------------------
type authWithOTPForm struct {
OTPId string `form:"otpId" json:"otpId"`
Password string `form:"password" json:"password"`
}
func (form *authWithOTPForm) validate() error {
return validation.ValidateStruct(form,
validation.Field(&form.OTPId, validation.Required, validation.Length(1, 255)),
validation.Field(&form.Password, validation.Required, validation.Length(1, 71)),
)
}

View file

@ -0,0 +1,608 @@
package apis_test
import (
"net/http"
"strings"
"testing"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tests"
"github.com/pocketbase/pocketbase/tools/types"
)
func TestRecordAuthWithOTP(t *testing.T) {
t.Parallel()
scenarios := []tests.ApiScenario{
{
Name: "not an auth collection",
Method: http.MethodPost,
URL: "/api/collections/demo1/auth-with-otp",
Body: strings.NewReader(`{"otpId":"test","password":"123456"}`),
ExpectedStatus: 404,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "auth collection with disabled otp",
Method: http.MethodPost,
URL: "/api/collections/users/auth-with-otp",
Body: strings.NewReader(`{"otpId":"test","password":"123456"}`),
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
usersCol, err := app.FindCollectionByNameOrId("users")
if err != nil {
t.Fatal(err)
}
usersCol.OTP.Enabled = false
if err := app.Save(usersCol); err != nil {
t.Fatal(err)
}
},
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "invalid body",
Method: http.MethodPost,
URL: "/api/collections/users/auth-with-otp",
Body: strings.NewReader(`{"email`),
ExpectedStatus: 400,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "empty body",
Method: http.MethodPost,
URL: "/api/collections/users/auth-with-otp",
Body: strings.NewReader(``),
ExpectedStatus: 400,
ExpectedContent: []string{
`"data":{`,
`"otpId":{"code":"validation_required"`,
`"password":{"code":"validation_required"`,
},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "invalid request data",
Method: http.MethodPost,
URL: "/api/collections/users/auth-with-otp",
Body: strings.NewReader(`{
"otpId":"` + strings.Repeat("a", 256) + `",
"password":"` + strings.Repeat("a", 72) + `"
}`),
ExpectedStatus: 400,
ExpectedContent: []string{
`"data":{`,
`"otpId":{"code":"validation_length_out_of_range"`,
`"password":{"code":"validation_length_out_of_range"`,
},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "missing otp",
Method: http.MethodPost,
URL: "/api/collections/users/auth-with-otp",
Body: strings.NewReader(`{
"otpId":"missing",
"password":"123456"
}`),
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
user, err := app.FindAuthRecordByEmail("users", "test@example.com")
if err != nil {
t.Fatal(err)
}
otp := core.NewOTP(app)
otp.Id = strings.Repeat("a", 15)
otp.SetCollectionRef(user.Collection().Id)
otp.SetRecordRef(user.Id)
otp.SetPassword("123456")
if err := app.Save(otp); err != nil {
t.Fatal(err)
}
},
ExpectedStatus: 400,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "otp for different collection",
Method: http.MethodPost,
URL: "/api/collections/users/auth-with-otp",
Body: strings.NewReader(`{
"otpId":"` + strings.Repeat("a", 15) + `",
"password":"123456"
}`),
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
client, err := app.FindAuthRecordByEmail("clients", "test@example.com")
if err != nil {
t.Fatal(err)
}
otp := core.NewOTP(app)
otp.Id = strings.Repeat("a", 15)
otp.SetCollectionRef(client.Collection().Id)
otp.SetRecordRef(client.Id)
otp.SetPassword("123456")
if err := app.Save(otp); err != nil {
t.Fatal(err)
}
},
ExpectedStatus: 400,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "otp with wrong password",
Method: http.MethodPost,
URL: "/api/collections/users/auth-with-otp",
Body: strings.NewReader(`{
"otpId":"` + strings.Repeat("a", 15) + `",
"password":"123456"
}`),
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
user, err := app.FindAuthRecordByEmail("users", "test@example.com")
if err != nil {
t.Fatal(err)
}
otp := core.NewOTP(app)
otp.Id = strings.Repeat("a", 15)
otp.SetCollectionRef(user.Collection().Id)
otp.SetRecordRef(user.Id)
otp.SetPassword("1234567890")
if err := app.Save(otp); err != nil {
t.Fatal(err)
}
},
ExpectedStatus: 400,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "expired otp with valid password",
Method: http.MethodPost,
URL: "/api/collections/users/auth-with-otp",
Body: strings.NewReader(`{
"otpId":"` + strings.Repeat("a", 15) + `",
"password":"123456"
}`),
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
user, err := app.FindAuthRecordByEmail("users", "test@example.com")
if err != nil {
t.Fatal(err)
}
otp := core.NewOTP(app)
otp.Id = strings.Repeat("a", 15)
otp.SetCollectionRef(user.Collection().Id)
otp.SetRecordRef(user.Id)
otp.SetPassword("123456")
expiredDate := types.NowDateTime().AddDate(-3, 0, 0)
otp.SetRaw("created", expiredDate)
otp.SetRaw("updated", expiredDate)
if err := app.Save(otp); err != nil {
t.Fatal(err)
}
},
ExpectedStatus: 400,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "valid otp with valid password (enabled MFA)",
Method: http.MethodPost,
URL: "/api/collections/users/auth-with-otp",
Body: strings.NewReader(`{
"otpId":"` + strings.Repeat("a", 15) + `",
"password":"123456"
}`),
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
user, err := app.FindAuthRecordByEmail("users", "test@example.com")
if err != nil {
t.Fatal(err)
}
otp := core.NewOTP(app)
otp.Id = strings.Repeat("a", 15)
otp.SetCollectionRef(user.Collection().Id)
otp.SetRecordRef(user.Id)
otp.SetPassword("123456")
if err := app.Save(otp); err != nil {
t.Fatal(err)
}
},
ExpectedStatus: 401,
ExpectedContent: []string{`"mfaId":"`},
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordAuthWithOTPRequest": 1,
"OnRecordAuthRequest": 1,
// ---
"OnModelValidate": 1,
"OnModelCreate": 1, // mfa record
"OnModelCreateExecute": 1,
"OnModelAfterCreateSuccess": 1,
"OnModelDelete": 1, // otp delete
"OnModelDeleteExecute": 1,
"OnModelAfterDeleteSuccess": 1,
// ---
"OnRecordValidate": 1,
"OnRecordCreate": 1,
"OnRecordCreateExecute": 1,
"OnRecordAfterCreateSuccess": 1,
"OnRecordDelete": 1,
"OnRecordDeleteExecute": 1,
"OnRecordAfterDeleteSuccess": 1,
},
},
{
Name: "valid otp with valid password and empty sentTo (disabled MFA)",
Method: http.MethodPost,
URL: "/api/collections/users/auth-with-otp",
Body: strings.NewReader(`{
"otpId":"` + strings.Repeat("a", 15) + `",
"password":"123456"
}`),
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
user, err := app.FindAuthRecordByEmail("users", "test@example.com")
if err != nil {
t.Fatal(err)
}
// ensure that the user is unverified
user.SetVerified(false)
if err = app.Save(user); err != nil {
t.Fatal(err)
}
// disable MFA
user.Collection().MFA.Enabled = false
if err = app.Save(user.Collection()); err != nil {
t.Fatal(err)
}
otp := core.NewOTP(app)
otp.Id = strings.Repeat("a", 15)
otp.SetCollectionRef(user.Collection().Id)
otp.SetRecordRef(user.Id)
otp.SetPassword("123456")
if err := app.Save(otp); err != nil {
t.Fatal(err)
}
// test at least once that the correct request info context is properly loaded
app.OnRecordAuthRequest().BindFunc(func(e *core.RecordAuthRequestEvent) error {
info, err := e.RequestInfo()
if err != nil {
t.Fatal(err)
}
if info.Context != core.RequestInfoContextOTP {
t.Fatalf("Expected request context %q, got %q", core.RequestInfoContextOTP, info.Context)
}
return e.Next()
})
},
ExpectedStatus: 200,
ExpectedContent: []string{
`"token":"`,
`"record":{`,
`"email":"test@example.com"`,
},
NotExpectedContent: []string{
`"meta":`,
// hidden fields
`"tokenKey"`,
`"password"`,
},
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordAuthWithOTPRequest": 1,
"OnRecordAuthRequest": 1,
"OnRecordEnrich": 1,
// ---
"OnModelValidate": 1,
"OnModelCreate": 1, // authOrigin
"OnModelCreateExecute": 1,
"OnModelAfterCreateSuccess": 1,
"OnModelDelete": 1, // otp delete
"OnModelDeleteExecute": 1,
"OnModelAfterDeleteSuccess": 1,
// ---
"OnRecordValidate": 1,
"OnRecordCreate": 1,
"OnRecordCreateExecute": 1,
"OnRecordAfterCreateSuccess": 1,
"OnRecordDelete": 1,
"OnRecordDeleteExecute": 1,
"OnRecordAfterDeleteSuccess": 1,
},
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
user, err := app.FindAuthRecordByEmail("users", "test@example.com")
if err != nil {
t.Fatal(err)
}
if user.Verified() {
t.Fatal("Expected the user to remain unverified because sentTo != email")
}
},
},
{
Name: "valid otp with valid password and nonempty sentTo=email (disabled MFA)",
Method: http.MethodPost,
URL: "/api/collections/users/auth-with-otp",
Body: strings.NewReader(`{
"otpId":"` + strings.Repeat("a", 15) + `",
"password":"123456"
}`),
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
user, err := app.FindAuthRecordByEmail("users", "test@example.com")
if err != nil {
t.Fatal(err)
}
// ensure that the user is unverified
user.SetVerified(false)
if err = app.Save(user); err != nil {
t.Fatal(err)
}
// disable MFA
user.Collection().MFA.Enabled = false
if err = app.Save(user.Collection()); err != nil {
t.Fatal(err)
}
otp := core.NewOTP(app)
otp.Id = strings.Repeat("a", 15)
otp.SetCollectionRef(user.Collection().Id)
otp.SetRecordRef(user.Id)
otp.SetPassword("123456")
otp.SetSentTo(user.Email())
if err := app.Save(otp); err != nil {
t.Fatal(err)
}
},
ExpectedStatus: 200,
ExpectedContent: []string{
`"token":"`,
`"record":{`,
`"email":"test@example.com"`,
},
NotExpectedContent: []string{
`"meta":`,
// hidden fields
`"tokenKey"`,
`"password"`,
},
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordAuthWithOTPRequest": 1,
"OnRecordAuthRequest": 1,
"OnRecordEnrich": 1,
// ---
"OnModelValidate": 2, // +1 because of the verified user update
// authOrigin create
"OnModelCreate": 1,
"OnModelCreateExecute": 1,
"OnModelAfterCreateSuccess": 1,
// OTP delete
"OnModelDelete": 1,
"OnModelDeleteExecute": 1,
"OnModelAfterDeleteSuccess": 1,
// user verified update
"OnModelUpdate": 1,
"OnModelUpdateExecute": 1,
"OnModelAfterUpdateSuccess": 1,
// ---
"OnRecordValidate": 2,
"OnRecordCreate": 1,
"OnRecordCreateExecute": 1,
"OnRecordAfterCreateSuccess": 1,
"OnRecordDelete": 1,
"OnRecordDeleteExecute": 1,
"OnRecordAfterDeleteSuccess": 1,
"OnRecordUpdate": 1,
"OnRecordUpdateExecute": 1,
"OnRecordAfterUpdateSuccess": 1,
},
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
user, err := app.FindAuthRecordByEmail("users", "test@example.com")
if err != nil {
t.Fatal(err)
}
if !user.Verified() {
t.Fatal("Expected the user to be marked as verified")
}
},
},
{
Name: "OnRecordAuthWithOTPRequest tx body write check",
Method: http.MethodPost,
URL: "/api/collections/users/auth-with-otp",
Body: strings.NewReader(`{
"otpId":"` + strings.Repeat("a", 15) + `",
"password":"123456"
}`),
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
user, err := app.FindAuthRecordByEmail("users", "test@example.com")
if err != nil {
t.Fatal(err)
}
// disable MFA
user.Collection().MFA.Enabled = false
if err = app.Save(user.Collection()); err != nil {
t.Fatal(err)
}
otp := core.NewOTP(app)
otp.Id = strings.Repeat("a", 15)
otp.SetCollectionRef(user.Collection().Id)
otp.SetRecordRef(user.Id)
otp.SetPassword("123456")
if err := app.Save(otp); err != nil {
t.Fatal(err)
}
app.OnRecordAuthWithOTPRequest().BindFunc(func(e *core.RecordAuthWithOTPRequestEvent) error {
original := e.App
return e.App.RunInTransaction(func(txApp core.App) error {
e.App = txApp
defer func() { e.App = original }()
if err := e.Next(); err != nil {
return err
}
return e.BadRequestError("TX_ERROR", nil)
})
})
},
ExpectedStatus: 400,
ExpectedEvents: map[string]int{"OnRecordAuthWithOTPRequest": 1},
ExpectedContent: []string{"TX_ERROR"},
},
// rate limit checks
// -----------------------------------------------------------
{
Name: "RateLimit rule - users:authWithOTP",
Method: http.MethodPost,
URL: "/api/collections/users/auth-with-otp",
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
app.Settings().RateLimits.Enabled = true
app.Settings().RateLimits.Rules = []core.RateLimitRule{
{MaxRequests: 100, Label: "abc"},
{MaxRequests: 100, Label: "*:authWithOTP"},
{MaxRequests: 100, Label: "users:auth"},
{MaxRequests: 0, Label: "users:authWithOTP"},
}
},
ExpectedStatus: 429,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "RateLimit rule - *:authWithOTP",
Method: http.MethodPost,
URL: "/api/collections/users/auth-with-otp",
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
app.Settings().RateLimits.Enabled = true
app.Settings().RateLimits.Rules = []core.RateLimitRule{
{MaxRequests: 100, Label: "abc"},
{MaxRequests: 100, Label: "*:auth"},
{MaxRequests: 0, Label: "*:authWithOTP"},
}
},
ExpectedStatus: 429,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "RateLimit rule - users:auth",
Method: http.MethodPost,
URL: "/api/collections/users/auth-with-otp",
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
app.Settings().RateLimits.Enabled = true
app.Settings().RateLimits.Rules = []core.RateLimitRule{
{MaxRequests: 100, Label: "abc"},
{MaxRequests: 100, Label: "*:authWithOTP"},
{MaxRequests: 0, Label: "users:auth"},
}
},
ExpectedStatus: 429,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "RateLimit rule - *:auth",
Method: http.MethodPost,
URL: "/api/collections/users/auth-with-otp",
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
app.Settings().RateLimits.Enabled = true
app.Settings().RateLimits.Rules = []core.RateLimitRule{
{MaxRequests: 100, Label: "abc"},
{MaxRequests: 0, Label: "*:auth"},
}
},
ExpectedStatus: 429,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
}
for _, scenario := range scenarios {
scenario.Test(t)
}
}
func TestRecordAuthWithOTPManualRateLimiterCheck(t *testing.T) {
t.Parallel()
var storeCache map[string]any
otpAId := strings.Repeat("a", 15)
otpBId := strings.Repeat("b", 15)
scenarios := []struct {
otpId string
password string
expectedStatus int
}{
{otpAId, "12345", 400},
{otpAId, "12345", 400},
{otpBId, "12345", 400},
{otpBId, "12345", 400},
{otpBId, "12345", 400},
{otpAId, "12345", 429},
{otpAId, "123456", 429}, // reject even if it is correct
{otpAId, "123456", 429},
{otpBId, "123456", 429},
}
for _, s := range scenarios {
(&tests.ApiScenario{
Method: http.MethodPost,
URL: "/api/collections/users/auth-with-otp",
Body: strings.NewReader(`{
"otpId":"` + s.otpId + `",
"password":"` + s.password + `"
}`),
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
for k, v := range storeCache {
app.Store().Set(k, v)
}
user, err := app.FindAuthRecordByEmail("users", "test@example.com")
if err != nil {
t.Fatal(err)
}
user.Collection().MFA.Enabled = false
if err := app.Save(user.Collection()); err != nil {
t.Fatal(err)
}
for _, id := range []string{otpAId, otpBId} {
otp := core.NewOTP(app)
otp.Id = id
otp.SetCollectionRef(user.Collection().Id)
otp.SetRecordRef(user.Id)
otp.SetPassword("123456")
if err := app.Save(otp); err != nil {
t.Fatal(err)
}
}
},
ExpectedStatus: s.expectedStatus,
ExpectedContent: []string{`"`}, // it doesn't matter anything non-empty
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
storeCache = app.Store().GetAll()
},
}).Test(t)
}
}

View file

@ -0,0 +1,135 @@
package apis
import (
"database/sql"
"errors"
"slices"
"strings"
validation "github.com/go-ozzo/ozzo-validation/v4"
"github.com/go-ozzo/ozzo-validation/v4/is"
"github.com/pocketbase/dbx"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tools/dbutils"
"github.com/pocketbase/pocketbase/tools/list"
)
func recordAuthWithPassword(e *core.RequestEvent) error {
collection, err := findAuthCollection(e)
if err != nil {
return err
}
if !collection.PasswordAuth.Enabled {
return e.ForbiddenError("The collection is not configured to allow password authentication.", nil)
}
form := &authWithPasswordForm{}
if err = e.BindBody(form); err != nil {
return firstApiError(err, e.BadRequestError("An error occurred while loading the submitted data.", err))
}
if err = form.validate(collection); err != nil {
return firstApiError(err, e.BadRequestError("An error occurred while validating the submitted data.", err))
}
e.Set(core.RequestEventKeyInfoContext, core.RequestInfoContextPasswordAuth)
var foundRecord *core.Record
var foundErr error
if form.IdentityField != "" {
foundRecord, foundErr = findRecordByIdentityField(e.App, collection, form.IdentityField, form.Identity)
} else {
// prioritize email lookup
isEmail := is.EmailFormat.Validate(form.Identity) == nil
if isEmail && list.ExistInSlice(core.FieldNameEmail, collection.PasswordAuth.IdentityFields) {
foundRecord, foundErr = findRecordByIdentityField(e.App, collection, core.FieldNameEmail, form.Identity)
}
// search by the other identity fields
if !isEmail || foundErr != nil {
for _, name := range collection.PasswordAuth.IdentityFields {
if !isEmail && name == core.FieldNameEmail {
continue // no need to search by the email field if it is not an email
}
foundRecord, foundErr = findRecordByIdentityField(e.App, collection, name, form.Identity)
if foundErr == nil {
break
}
}
}
}
// ignore not found errors to allow custom record find implementations
if foundErr != nil && !errors.Is(foundErr, sql.ErrNoRows) {
return e.InternalServerError("", foundErr)
}
event := new(core.RecordAuthWithPasswordRequestEvent)
event.RequestEvent = e
event.Collection = collection
event.Record = foundRecord
event.Identity = form.Identity
event.Password = form.Password
event.IdentityField = form.IdentityField
return e.App.OnRecordAuthWithPasswordRequest().Trigger(event, func(e *core.RecordAuthWithPasswordRequestEvent) error {
if e.Record == nil || !e.Record.ValidatePassword(e.Password) {
return e.BadRequestError("Failed to authenticate.", errors.New("invalid login credentials"))
}
return RecordAuthResponse(e.RequestEvent, e.Record, core.MFAMethodPassword, nil)
})
}
// -------------------------------------------------------------------
type authWithPasswordForm struct {
Identity string `form:"identity" json:"identity"`
Password string `form:"password" json:"password"`
// IdentityField specifies the field to use to search for the identity
// (leave it empty for "auto" detection).
IdentityField string `form:"identityField" json:"identityField"`
}
func (form *authWithPasswordForm) validate(collection *core.Collection) error {
return validation.ValidateStruct(form,
validation.Field(&form.Identity, validation.Required, validation.Length(1, 255)),
validation.Field(&form.Password, validation.Required, validation.Length(1, 255)),
validation.Field(
&form.IdentityField,
validation.Length(1, 255),
validation.In(list.ToInterfaceSlice(collection.PasswordAuth.IdentityFields)...),
),
)
}
func findRecordByIdentityField(app core.App, collection *core.Collection, field string, value any) (*core.Record, error) {
if !slices.Contains(collection.PasswordAuth.IdentityFields, field) {
return nil, errors.New("invalid identity field " + field)
}
index, ok := dbutils.FindSingleColumnUniqueIndex(collection.Indexes, field)
if !ok {
return nil, errors.New("missing " + field + " unique index constraint")
}
var expr dbx.Expression
if strings.EqualFold(index.Columns[0].Collate, "nocase") {
// case-insensitive search
expr = dbx.NewExp("[["+field+"]] = {:identity} COLLATE NOCASE", dbx.Params{"identity": value})
} else {
expr = dbx.HashExp{field: value}
}
record := &core.Record{}
err := app.RecordQuery(collection).AndWhere(expr).Limit(1).One(record)
if err != nil {
return nil, err
}
return record, nil
}

View file

@ -0,0 +1,713 @@
package apis_test
import (
"net/http"
"strings"
"testing"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tests"
"github.com/pocketbase/pocketbase/tools/dbutils"
)
func TestRecordAuthWithPassword(t *testing.T) {
t.Parallel()
updateIdentityIndex := func(collectionIdOrName string, fieldCollateMap map[string]string) func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
return func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
collection, err := app.FindCollectionByNameOrId("clients")
if err != nil {
t.Fatal(err)
}
for column, collate := range fieldCollateMap {
index, ok := dbutils.FindSingleColumnUniqueIndex(collection.Indexes, column)
if !ok {
t.Fatalf("Missing unique identityField index for column %q", column)
}
index.Columns[0].Collate = collate
collection.RemoveIndex(index.IndexName)
collection.Indexes = append(collection.Indexes, index.Build())
}
err = app.Save(collection)
if err != nil {
t.Fatalf("Failed to update identityField index: %v", err)
}
}
}
scenarios := []tests.ApiScenario{
{
Name: "disabled password auth",
Method: http.MethodPost,
URL: "/api/collections/nologin/auth-with-password",
Body: strings.NewReader(`{"identity":"test@example.com","password":"1234567890"}`),
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "non-auth collection",
Method: http.MethodPost,
URL: "/api/collections/demo1/auth-with-password",
Body: strings.NewReader(`{"identity":"test@example.com","password":"1234567890"}`),
ExpectedStatus: 404,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "invalid body format",
Method: http.MethodPost,
URL: "/api/collections/clients/auth-with-password",
Body: strings.NewReader(`{"identity`),
ExpectedStatus: 400,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "empty body params",
Method: http.MethodPost,
URL: "/api/collections/clients/auth-with-password",
Body: strings.NewReader(`{"identity":"","password":""}`),
ExpectedStatus: 400,
ExpectedContent: []string{
`"data":{`,
`"identity":{`,
`"password":{`,
},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "OnRecordAuthWithPasswordRequest tx body write check",
Method: http.MethodPost,
URL: "/api/collections/clients/auth-with-password",
Body: strings.NewReader(`{
"identity":"test@example.com",
"password":"1234567890"
}`),
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
app.OnRecordAuthWithPasswordRequest().BindFunc(func(e *core.RecordAuthWithPasswordRequestEvent) error {
original := e.App
return e.App.RunInTransaction(func(txApp core.App) error {
e.App = txApp
defer func() { e.App = original }()
if err := e.Next(); err != nil {
return err
}
return e.BadRequestError("TX_ERROR", nil)
})
})
},
ExpectedStatus: 400,
ExpectedEvents: map[string]int{"OnRecordAuthWithPasswordRequest": 1},
ExpectedContent: []string{"TX_ERROR"},
},
{
Name: "valid identity field and invalid password",
Method: http.MethodPost,
URL: "/api/collections/clients/auth-with-password",
Body: strings.NewReader(`{
"identity":"test@example.com",
"password":"invalid"
}`),
ExpectedStatus: 400,
ExpectedContent: []string{
`"data":{}`,
},
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordAuthWithPasswordRequest": 1,
},
},
{
Name: "valid identity field (email) and valid password",
Method: http.MethodPost,
URL: "/api/collections/clients/auth-with-password",
Body: strings.NewReader(`{
"identity":"test@example.com",
"password":"1234567890"
}`),
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
// test at least once that the correct request info context is properly loaded
app.OnRecordAuthRequest().BindFunc(func(e *core.RecordAuthRequestEvent) error {
info, err := e.RequestInfo()
if err != nil {
t.Fatal(err)
}
if info.Context != core.RequestInfoContextPasswordAuth {
t.Fatalf("Expected request context %q, got %q", core.RequestInfoContextPasswordAuth, info.Context)
}
return e.Next()
})
},
ExpectedStatus: 200,
ExpectedContent: []string{
`"email":"test@example.com"`,
`"token":`,
},
NotExpectedContent: []string{
// hidden fields
`"tokenKey"`,
`"password"`,
},
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordAuthWithPasswordRequest": 1,
"OnRecordAuthRequest": 1,
"OnRecordEnrich": 1,
// authOrigin track
"OnModelCreate": 1,
"OnModelCreateExecute": 1,
"OnModelAfterCreateSuccess": 1,
"OnModelValidate": 1,
"OnRecordCreate": 1,
"OnRecordCreateExecute": 1,
"OnRecordAfterCreateSuccess": 1,
"OnRecordValidate": 1,
"OnMailerSend": 1,
"OnMailerRecordAuthAlertSend": 1,
},
},
{
Name: "valid identity field (username) and valid password",
Method: http.MethodPost,
URL: "/api/collections/clients/auth-with-password",
Body: strings.NewReader(`{
"identity":"clients57772",
"password":"1234567890"
}`),
ExpectedStatus: 200,
ExpectedContent: []string{
`"email":"test@example.com"`,
`"username":"clients57772"`,
`"token":`,
},
NotExpectedContent: []string{
// hidden fields
`"tokenKey"`,
`"password"`,
},
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordAuthWithPasswordRequest": 1,
"OnRecordAuthRequest": 1,
"OnRecordEnrich": 1,
// authOrigin track
"OnModelCreate": 1,
"OnModelCreateExecute": 1,
"OnModelAfterCreateSuccess": 1,
"OnModelValidate": 1,
"OnRecordCreate": 1,
"OnRecordCreateExecute": 1,
"OnRecordAfterCreateSuccess": 1,
"OnRecordValidate": 1,
"OnMailerSend": 1,
"OnMailerRecordAuthAlertSend": 1,
},
},
{
Name: "unknown explicit identityField",
Method: http.MethodPost,
URL: "/api/collections/clients/auth-with-password",
Body: strings.NewReader(`{
"identityField": "created",
"identity":"test@example.com",
"password":"1234567890"
}`),
ExpectedStatus: 400,
ExpectedContent: []string{
`"data":{`,
`"identityField":{"code":"validation_in_invalid"`,
},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "valid identity field and valid password with mismatched explicit identityField",
Method: http.MethodPost,
URL: "/api/collections/clients/auth-with-password",
Body: strings.NewReader(`{
"identityField": "username",
"identity":"test@example.com",
"password":"1234567890"
}`),
ExpectedStatus: 400,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordAuthWithPasswordRequest": 1,
},
},
{
Name: "valid identity field and valid password with matched explicit identityField",
Method: http.MethodPost,
URL: "/api/collections/clients/auth-with-password",
Body: strings.NewReader(`{
"identityField": "username",
"identity":"clients57772",
"password":"1234567890"
}`),
ExpectedStatus: 200,
ExpectedContent: []string{
`"email":"test@example.com"`,
`"username":"clients57772"`,
`"token":`,
},
NotExpectedContent: []string{
// hidden fields
`"tokenKey"`,
`"password"`,
},
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordAuthWithPasswordRequest": 1,
"OnRecordAuthRequest": 1,
"OnRecordEnrich": 1,
// authOrigin track
"OnModelCreate": 1,
"OnModelCreateExecute": 1,
"OnModelAfterCreateSuccess": 1,
"OnModelValidate": 1,
"OnRecordCreate": 1,
"OnRecordCreateExecute": 1,
"OnRecordAfterCreateSuccess": 1,
"OnRecordValidate": 1,
"OnMailerSend": 1,
"OnMailerRecordAuthAlertSend": 1,
},
},
{
Name: "valid identity (unverified) and valid password in onlyVerified collection",
Method: http.MethodPost,
URL: "/api/collections/clients/auth-with-password",
Body: strings.NewReader(`{
"identity":"test2@example.com",
"password":"1234567890"
}`),
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordAuthWithPasswordRequest": 1,
},
},
{
Name: "already authenticated record",
Method: http.MethodPost,
URL: "/api/collections/clients/auth-with-password",
Body: strings.NewReader(`{
"identity":"test@example.com",
"password":"1234567890"
}`),
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
},
ExpectedStatus: 200,
ExpectedContent: []string{
`"id":"gk390qegs4y47wn"`,
`"email":"test@example.com"`,
`"token":`,
},
NotExpectedContent: []string{
// hidden fields
`"tokenKey"`,
`"password"`,
},
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordAuthWithPasswordRequest": 1,
"OnRecordAuthRequest": 1,
"OnRecordEnrich": 1,
// authOrigin track
"OnModelCreate": 1,
"OnModelCreateExecute": 1,
"OnModelAfterCreateSuccess": 1,
"OnModelValidate": 1,
"OnRecordCreate": 1,
"OnRecordCreateExecute": 1,
"OnRecordAfterCreateSuccess": 1,
"OnRecordValidate": 1,
"OnMailerSend": 1,
"OnMailerRecordAuthAlertSend": 1,
},
},
{
Name: "with mfa first auth check",
Method: http.MethodPost,
URL: "/api/collections/users/auth-with-password",
Body: strings.NewReader(`{
"identity":"test@example.com",
"password":"1234567890"
}`),
ExpectedStatus: 401,
ExpectedContent: []string{
`"mfaId":"`,
},
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordAuthWithPasswordRequest": 1,
"OnRecordAuthRequest": 1,
// mfa create
"OnModelCreate": 1,
"OnModelCreateExecute": 1,
"OnModelAfterCreateSuccess": 1,
"OnModelValidate": 1,
"OnRecordCreate": 1,
"OnRecordCreateExecute": 1,
"OnRecordAfterCreateSuccess": 1,
"OnRecordValidate": 1,
},
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
user, err := app.FindAuthRecordByEmail("users", "test@example.com")
if err != nil {
t.Fatal(err)
}
mfas, err := app.FindAllMFAsByRecord(user)
if err != nil {
t.Fatal(err)
}
if v := len(mfas); v != 1 {
t.Fatalf("Expected 1 mfa record to be created, got %d", v)
}
},
},
{
Name: "with mfa second auth check",
Method: http.MethodPost,
URL: "/api/collections/users/auth-with-password",
Body: strings.NewReader(`{
"mfaId": "` + strings.Repeat("a", 15) + `",
"identity":"test@example.com",
"password":"1234567890"
}`),
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
user, err := app.FindAuthRecordByEmail("users", "test@example.com")
if err != nil {
t.Fatal(err)
}
// insert a dummy mfa record
mfa := core.NewMFA(app)
mfa.Id = strings.Repeat("a", 15)
mfa.SetCollectionRef(user.Collection().Id)
mfa.SetRecordRef(user.Id)
mfa.SetMethod("test")
if err := app.Save(mfa); err != nil {
t.Fatal(err)
}
},
ExpectedStatus: 200,
ExpectedContent: []string{
`"email":"test@example.com"`,
`"token":`,
},
NotExpectedContent: []string{
// hidden fields
`"tokenKey"`,
`"password"`,
},
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordAuthWithPasswordRequest": 1,
"OnRecordAuthRequest": 1,
"OnRecordEnrich": 1,
// authOrigin track
"OnModelCreate": 1,
"OnModelCreateExecute": 1,
"OnModelAfterCreateSuccess": 1,
"OnModelValidate": 1,
"OnRecordCreate": 1,
"OnRecordCreateExecute": 1,
"OnRecordAfterCreateSuccess": 1,
"OnRecordValidate": 1,
"OnMailerSend": 0, // disabled auth email alerts
"OnMailerRecordAuthAlertSend": 0,
// mfa delete
"OnModelDelete": 1,
"OnModelDeleteExecute": 1,
"OnModelAfterDeleteSuccess": 1,
"OnRecordDelete": 1,
"OnRecordDeleteExecute": 1,
"OnRecordAfterDeleteSuccess": 1,
},
},
{
Name: "with enabled mfa but unsatisfied mfa rule (aka. skip the mfa check)",
Method: http.MethodPost,
URL: "/api/collections/users/auth-with-password",
Body: strings.NewReader(`{
"identity":"test@example.com",
"password":"1234567890"
}`),
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
users, err := app.FindCollectionByNameOrId("users")
if err != nil {
t.Fatal(err)
}
users.MFA.Enabled = true
users.MFA.Rule = "1=2"
if err := app.Save(users); err != nil {
t.Fatal(err)
}
},
ExpectedStatus: 200,
ExpectedContent: []string{
`"email":"test@example.com"`,
`"token":`,
},
NotExpectedContent: []string{
// hidden fields
`"tokenKey"`,
`"password"`,
},
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordAuthWithPasswordRequest": 1,
"OnRecordAuthRequest": 1,
"OnRecordEnrich": 1,
// authOrigin track
"OnModelCreate": 1,
"OnModelCreateExecute": 1,
"OnModelAfterCreateSuccess": 1,
"OnModelValidate": 1,
"OnRecordCreate": 1,
"OnRecordCreateExecute": 1,
"OnRecordAfterCreateSuccess": 1,
"OnRecordValidate": 1,
"OnMailerSend": 0, // disabled auth email alerts
"OnMailerRecordAuthAlertSend": 0,
},
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
user, err := app.FindAuthRecordByEmail("users", "test@example.com")
if err != nil {
t.Fatal(err)
}
mfas, err := app.FindAllMFAsByRecord(user)
if err != nil {
t.Fatal(err)
}
if v := len(mfas); v != 0 {
t.Fatalf("Expected no mfa records to be created, got %d", v)
}
},
},
// case sensitivity checks
// -----------------------------------------------------------
{
Name: "with explicit identityField (case-sensitive)",
Method: http.MethodPost,
URL: "/api/collections/clients/auth-with-password",
Body: strings.NewReader(`{
"identityField": "username",
"identity":"Clients57772",
"password":"1234567890"
}`),
BeforeTestFunc: updateIdentityIndex("clients", map[string]string{"username": ""}),
ExpectedStatus: 400,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordAuthWithPasswordRequest": 1,
},
},
{
Name: "with explicit identityField (case-insensitive)",
Method: http.MethodPost,
URL: "/api/collections/clients/auth-with-password",
Body: strings.NewReader(`{
"identityField": "username",
"identity":"Clients57772",
"password":"1234567890"
}`),
BeforeTestFunc: updateIdentityIndex("clients", map[string]string{"username": "nocase"}),
ExpectedStatus: 200,
ExpectedContent: []string{
`"email":"test@example.com"`,
`"username":"clients57772"`,
`"token":`,
},
NotExpectedContent: []string{
// hidden fields
`"tokenKey"`,
`"password"`,
},
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordAuthWithPasswordRequest": 1,
"OnRecordAuthRequest": 1,
"OnRecordEnrich": 1,
// authOrigin track
"OnModelCreate": 1,
"OnModelCreateExecute": 1,
"OnModelAfterCreateSuccess": 1,
"OnModelValidate": 1,
"OnRecordCreate": 1,
"OnRecordCreateExecute": 1,
"OnRecordAfterCreateSuccess": 1,
"OnRecordValidate": 1,
"OnMailerSend": 1,
"OnMailerRecordAuthAlertSend": 1,
},
},
{
Name: "without explicit identityField and non-email field (case-insensitive)",
Method: http.MethodPost,
URL: "/api/collections/clients/auth-with-password",
Body: strings.NewReader(`{
"identity":"Clients57772",
"password":"1234567890"
}`),
BeforeTestFunc: updateIdentityIndex("clients", map[string]string{"username": "nocase"}),
ExpectedStatus: 200,
ExpectedContent: []string{
`"email":"test@example.com"`,
`"username":"clients57772"`,
`"token":`,
},
NotExpectedContent: []string{
// hidden fields
`"tokenKey"`,
`"password"`,
},
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordAuthWithPasswordRequest": 1,
"OnRecordAuthRequest": 1,
"OnRecordEnrich": 1,
// authOrigin track
"OnModelCreate": 1,
"OnModelCreateExecute": 1,
"OnModelAfterCreateSuccess": 1,
"OnModelValidate": 1,
"OnRecordCreate": 1,
"OnRecordCreateExecute": 1,
"OnRecordAfterCreateSuccess": 1,
"OnRecordValidate": 1,
"OnMailerSend": 1,
"OnMailerRecordAuthAlertSend": 1,
},
},
{
Name: "without explicit identityField and email field (case-insensitive)",
Method: http.MethodPost,
URL: "/api/collections/clients/auth-with-password",
Body: strings.NewReader(`{
"identity":"tESt@example.com",
"password":"1234567890"
}`),
BeforeTestFunc: updateIdentityIndex("clients", map[string]string{"email": "nocase"}),
ExpectedStatus: 200,
ExpectedContent: []string{
`"email":"test@example.com"`,
`"username":"clients57772"`,
`"token":`,
},
NotExpectedContent: []string{
// hidden fields
`"tokenKey"`,
`"password"`,
},
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordAuthWithPasswordRequest": 1,
"OnRecordAuthRequest": 1,
"OnRecordEnrich": 1,
// authOrigin track
"OnModelCreate": 1,
"OnModelCreateExecute": 1,
"OnModelAfterCreateSuccess": 1,
"OnModelValidate": 1,
"OnRecordCreate": 1,
"OnRecordCreateExecute": 1,
"OnRecordAfterCreateSuccess": 1,
"OnRecordValidate": 1,
"OnMailerSend": 1,
"OnMailerRecordAuthAlertSend": 1,
},
},
// rate limit checks
// -----------------------------------------------------------
{
Name: "RateLimit rule - users:authWithPassword",
Method: http.MethodPost,
URL: "/api/collections/users/auth-with-password",
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
app.Settings().RateLimits.Enabled = true
app.Settings().RateLimits.Rules = []core.RateLimitRule{
{MaxRequests: 100, Label: "abc"},
{MaxRequests: 100, Label: "*:authWithPassword"},
{MaxRequests: 100, Label: "users:auth"},
{MaxRequests: 0, Label: "users:authWithPassword"},
}
},
ExpectedStatus: 429,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "RateLimit rule - *:authWithPassword",
Method: http.MethodPost,
URL: "/api/collections/users/auth-with-password",
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
app.Settings().RateLimits.Enabled = true
app.Settings().RateLimits.Rules = []core.RateLimitRule{
{MaxRequests: 100, Label: "abc"},
{MaxRequests: 100, Label: "*:auth"},
{MaxRequests: 0, Label: "*:authWithPassword"},
}
},
ExpectedStatus: 429,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "RateLimit rule - users:auth",
Method: http.MethodPost,
URL: "/api/collections/users/auth-with-password",
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
app.Settings().RateLimits.Enabled = true
app.Settings().RateLimits.Rules = []core.RateLimitRule{
{MaxRequests: 100, Label: "abc"},
{MaxRequests: 100, Label: "*:authWithPassword"},
{MaxRequests: 0, Label: "users:auth"},
}
},
ExpectedStatus: 429,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "RateLimit rule - *:auth",
Method: http.MethodPost,
URL: "/api/collections/users/auth-with-password",
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
app.Settings().RateLimits.Enabled = true
app.Settings().RateLimits.Rules = []core.RateLimitRule{
{MaxRequests: 100, Label: "abc"},
{MaxRequests: 0, Label: "*:auth"},
}
},
ExpectedStatus: 429,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
}
for _, scenario := range scenarios {
scenario.Test(t)
}
}

742
apis/record_crud.go Normal file
View file

@ -0,0 +1,742 @@
package apis
import (
cryptoRand "crypto/rand"
"errors"
"fmt"
"math/big"
"net/http"
"strings"
"time"
"github.com/pocketbase/dbx"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/forms"
"github.com/pocketbase/pocketbase/tools/filesystem"
"github.com/pocketbase/pocketbase/tools/inflector"
"github.com/pocketbase/pocketbase/tools/list"
"github.com/pocketbase/pocketbase/tools/router"
"github.com/pocketbase/pocketbase/tools/search"
"github.com/pocketbase/pocketbase/tools/security"
)
// bindRecordCrudApi registers the record crud api endpoints and
// the corresponding handlers.
//
// note: the rate limiter is "inlined" because some of the crud actions are also used in the batch APIs
func bindRecordCrudApi(app core.App, rg *router.RouterGroup[*core.RequestEvent]) {
subGroup := rg.Group("/collections/{collection}/records").Unbind(DefaultRateLimitMiddlewareId)
subGroup.GET("", recordsList)
subGroup.GET("/{id}", recordView)
subGroup.POST("", recordCreate(true, nil)).Bind(dynamicCollectionBodyLimit(""))
subGroup.PATCH("/{id}", recordUpdate(true, nil)).Bind(dynamicCollectionBodyLimit(""))
subGroup.DELETE("/{id}", recordDelete(true, nil))
}
func recordsList(e *core.RequestEvent) error {
collection, err := e.App.FindCachedCollectionByNameOrId(e.Request.PathValue("collection"))
if err != nil || collection == nil {
return e.NotFoundError("Missing collection context.", err)
}
err = checkCollectionRateLimit(e, collection, "list")
if err != nil {
return err
}
requestInfo, err := e.RequestInfo()
if err != nil {
return firstApiError(err, e.BadRequestError("", err))
}
if collection.ListRule == nil && !requestInfo.HasSuperuserAuth() {
return e.ForbiddenError("Only superusers can perform this action.", nil)
}
// forbid users and guests to query special filter/sort fields
err = checkForSuperuserOnlyRuleFields(requestInfo)
if err != nil {
return err
}
query := e.App.RecordQuery(collection)
fieldsResolver := core.NewRecordFieldResolver(e.App, collection, requestInfo, true)
if !requestInfo.HasSuperuserAuth() && collection.ListRule != nil && *collection.ListRule != "" {
expr, err := search.FilterData(*collection.ListRule).BuildExpr(fieldsResolver)
if err != nil {
return err
}
query.AndWhere(expr)
// will be applied by the search provider right before executing the query
// fieldsResolver.UpdateQuery(query)
}
// hidden fields are searchable only by superusers
fieldsResolver.SetAllowHiddenFields(requestInfo.HasSuperuserAuth())
searchProvider := search.NewProvider(fieldsResolver).Query(query)
// use rowid when available to minimize the need of a covering index with the "id" field
if !collection.IsView() {
searchProvider.CountCol("_rowid_")
}
records := []*core.Record{}
result, err := searchProvider.ParseAndExec(e.Request.URL.Query().Encode(), &records)
if err != nil {
return firstApiError(err, e.BadRequestError("", err))
}
event := new(core.RecordsListRequestEvent)
event.RequestEvent = e
event.Collection = collection
event.Records = records
event.Result = result
return e.App.OnRecordsListRequest().Trigger(event, func(e *core.RecordsListRequestEvent) error {
if err := EnrichRecords(e.RequestEvent, e.Records); err != nil {
return firstApiError(err, e.InternalServerError("Failed to enrich records", err))
}
// Add a randomized throttle in case of too many empty search filter attempts.
//
// This is just for extra precaution since security researches raised concern regarding the possibility of eventual
// timing attacks because the List API rule acts also as filter and executes in a single run with the client-side filters.
// This is by design and it is an accepted trade off between performance, usability and correctness.
//
// While technically the below doesn't fully guarantee protection against filter timing attacks, in practice combined with the network latency it makes them even less feasible.
// A properly configured rate limiter or individual fields Hidden checks are better suited if you are really concerned about eventual information disclosure by side-channel attacks.
//
// In all cases it doesn't really matter that much because it doesn't affect the builtin PocketBase security sensitive fields (e.g. password and tokenKey) since they
// are not client-side filterable and in the few places where they need to be compared against an external value, a constant time check is used.
if !e.HasSuperuserAuth() &&
(collection.ListRule != nil && *collection.ListRule != "") &&
(requestInfo.Query["filter"] != "") &&
len(e.Records) == 0 &&
checkRateLimit(e.RequestEvent, "@pb_list_timing_check_"+collection.Id, listTimingRateLimitRule) != nil {
e.App.Logger().Debug("Randomized throttle because of too many failed searches", "collectionId", collection.Id)
randomizedThrottle(150)
}
return execAfterSuccessTx(true, e.App, func() error {
return e.JSON(http.StatusOK, e.Result)
})
})
}
var listTimingRateLimitRule = core.RateLimitRule{MaxRequests: 3, Duration: 3}
func randomizedThrottle(softMax int64) {
var timeout int64
randRange, err := cryptoRand.Int(cryptoRand.Reader, big.NewInt(softMax))
if err == nil {
timeout = randRange.Int64()
} else {
timeout = softMax
}
time.Sleep(time.Duration(timeout) * time.Millisecond)
}
func recordView(e *core.RequestEvent) error {
collection, err := e.App.FindCachedCollectionByNameOrId(e.Request.PathValue("collection"))
if err != nil || collection == nil {
return e.NotFoundError("Missing collection context.", err)
}
err = checkCollectionRateLimit(e, collection, "view")
if err != nil {
return err
}
recordId := e.Request.PathValue("id")
if recordId == "" {
return e.NotFoundError("", nil)
}
requestInfo, err := e.RequestInfo()
if err != nil {
return firstApiError(err, e.BadRequestError("", err))
}
if collection.ViewRule == nil && !requestInfo.HasSuperuserAuth() {
return e.ForbiddenError("Only superusers can perform this action.", nil)
}
ruleFunc := func(q *dbx.SelectQuery) error {
if !requestInfo.HasSuperuserAuth() && collection.ViewRule != nil && *collection.ViewRule != "" {
resolver := core.NewRecordFieldResolver(e.App, collection, requestInfo, true)
expr, err := search.FilterData(*collection.ViewRule).BuildExpr(resolver)
if err != nil {
return err
}
resolver.UpdateQuery(q)
q.AndWhere(expr)
}
return nil
}
record, fetchErr := e.App.FindRecordById(collection, recordId, ruleFunc)
if fetchErr != nil || record == nil {
return firstApiError(err, e.NotFoundError("", fetchErr))
}
event := new(core.RecordRequestEvent)
event.RequestEvent = e
event.Collection = collection
event.Record = record
return e.App.OnRecordViewRequest().Trigger(event, func(e *core.RecordRequestEvent) error {
if err := EnrichRecord(e.RequestEvent, e.Record); err != nil {
return firstApiError(err, e.InternalServerError("Failed to enrich record", err))
}
return execAfterSuccessTx(true, e.App, func() error {
return e.JSON(http.StatusOK, e.Record)
})
})
}
func recordCreate(responseWriteAfterTx bool, optFinalizer func(data any) error) func(e *core.RequestEvent) error {
return func(e *core.RequestEvent) error {
collection, err := e.App.FindCachedCollectionByNameOrId(e.Request.PathValue("collection"))
if err != nil || collection == nil {
return e.NotFoundError("Missing collection context.", err)
}
if collection.IsView() {
return e.BadRequestError("Unsupported collection type.", nil)
}
err = checkCollectionRateLimit(e, collection, "create")
if err != nil {
return err
}
requestInfo, err := e.RequestInfo()
if err != nil {
return firstApiError(err, e.BadRequestError("", err))
}
hasSuperuserAuth := requestInfo.HasSuperuserAuth()
if !hasSuperuserAuth && collection.CreateRule == nil {
return e.ForbiddenError("Only superusers can perform this action.", nil)
}
record := core.NewRecord(collection)
data, err := recordDataFromRequest(e, record)
if err != nil {
return firstApiError(err, e.BadRequestError("Failed to read the submitted data.", err))
}
// set a random password for the OAuth2 ignoring its plain password validators
var skipPlainPasswordRecordValidators bool
if requestInfo.Context == core.RequestInfoContextOAuth2 {
if _, ok := data[core.FieldNamePassword]; !ok {
data[core.FieldNamePassword] = security.RandomString(30)
data[core.FieldNamePassword+"Confirm"] = data[core.FieldNamePassword]
skipPlainPasswordRecordValidators = true
}
}
// replace modifiers fields so that the resolved value is always
// available when accessing requestInfo.Body
requestInfo.Body = data
form := forms.NewRecordUpsert(e.App, record)
if hasSuperuserAuth {
form.GrantSuperuserAccess()
}
form.Load(data)
if skipPlainPasswordRecordValidators {
// unset the plain value to skip the plain password field validators
if raw, ok := record.GetRaw(core.FieldNamePassword).(*core.PasswordFieldValue); ok {
raw.Plain = ""
}
}
// check the request and record data against the create and manage rules
if !hasSuperuserAuth && collection.CreateRule != nil {
dummyRecord := record.Clone()
dummyRandomPart := "__pb_create__" + security.PseudorandomString(6)
// set an id if it doesn't have already
// (the value doesn't matter; it is used only to minimize the breaking changes with earlier versions)
if dummyRecord.Id == "" {
dummyRecord.Id = "__temp_id__" + dummyRandomPart
}
// unset the verified field to prevent manage API rule misuse in case the rule relies on it
dummyRecord.SetVerified(false)
// export the dummy record data into db params
dummyExport, err := dummyRecord.DBExport(e.App)
if err != nil {
return e.BadRequestError("Failed to create record", fmt.Errorf("dummy DBExport error: %w", err))
}
dummyParams := make(dbx.Params, len(dummyExport))
selects := make([]string, 0, len(dummyExport))
var param string
for k, v := range dummyExport {
k = inflector.Columnify(k) // columnify is just as extra measure in case of custom fields
param = "__pb_create__" + k
dummyParams[param] = v
selects = append(selects, "{:"+param+"} AS [["+k+"]]")
}
// shallow clone the current collection
dummyCollection := *collection
dummyCollection.Id += dummyRandomPart
dummyCollection.Name += inflector.Columnify(dummyRandomPart)
withFrom := fmt.Sprintf("WITH {{%s}} as (SELECT %s)", dummyCollection.Name, strings.Join(selects, ","))
// check non-empty create rule
if *dummyCollection.CreateRule != "" {
ruleQuery := e.App.ConcurrentDB().Select("(1)").PreFragment(withFrom).From(dummyCollection.Name).AndBind(dummyParams)
resolver := core.NewRecordFieldResolver(e.App, &dummyCollection, requestInfo, true)
expr, err := search.FilterData(*dummyCollection.CreateRule).BuildExpr(resolver)
if err != nil {
return e.BadRequestError("Failed to create record", fmt.Errorf("create rule build expression failure: %w", err))
}
ruleQuery.AndWhere(expr)
resolver.UpdateQuery(ruleQuery)
var exists int
err = ruleQuery.Limit(1).Row(&exists)
if err != nil || exists == 0 {
return e.BadRequestError("Failed to create record", fmt.Errorf("create rule failure: %w", err))
}
}
// check for manage rule access
manageRuleQuery := e.App.ConcurrentDB().Select("(1)").PreFragment(withFrom).From(dummyCollection.Name).AndBind(dummyParams)
if !form.HasManageAccess() &&
hasAuthManageAccess(e.App, requestInfo, &dummyCollection, manageRuleQuery) {
form.GrantManagerAccess()
}
}
var isOptFinalizerCalled bool
event := new(core.RecordRequestEvent)
event.RequestEvent = e
event.Collection = collection
event.Record = record
hookErr := e.App.OnRecordCreateRequest().Trigger(event, func(e *core.RecordRequestEvent) error {
form.SetApp(e.App)
form.SetRecord(e.Record)
err := form.Submit()
if err != nil {
return firstApiError(err, e.BadRequestError("Failed to create record", err))
}
err = EnrichRecord(e.RequestEvent, e.Record)
if err != nil {
return firstApiError(err, e.InternalServerError("Failed to enrich record", err))
}
err = execAfterSuccessTx(responseWriteAfterTx, e.App, func() error {
return e.JSON(http.StatusOK, e.Record)
})
if err != nil {
return err
}
if optFinalizer != nil {
isOptFinalizerCalled = true
err = optFinalizer(e.Record)
if err != nil {
return firstApiError(err, e.InternalServerError("", err))
}
}
return nil
})
if hookErr != nil {
return hookErr
}
// e.g. in case the regular hook chain was stopped and the finalizer cannot be executed as part of the last e.Next() task
if !isOptFinalizerCalled && optFinalizer != nil {
if err := optFinalizer(event.Record); err != nil {
return firstApiError(err, e.InternalServerError("", err))
}
}
return nil
}
}
func recordUpdate(responseWriteAfterTx bool, optFinalizer func(data any) error) func(e *core.RequestEvent) error {
return func(e *core.RequestEvent) error {
collection, err := e.App.FindCachedCollectionByNameOrId(e.Request.PathValue("collection"))
if err != nil || collection == nil {
return e.NotFoundError("Missing collection context.", err)
}
if collection.IsView() {
return e.BadRequestError("Unsupported collection type.", nil)
}
err = checkCollectionRateLimit(e, collection, "update")
if err != nil {
return err
}
recordId := e.Request.PathValue("id")
if recordId == "" {
return e.NotFoundError("", nil)
}
requestInfo, err := e.RequestInfo()
if err != nil {
return firstApiError(err, e.BadRequestError("", err))
}
hasSuperuserAuth := requestInfo.HasSuperuserAuth()
if !hasSuperuserAuth && collection.UpdateRule == nil {
return firstApiError(err, e.ForbiddenError("Only superusers can perform this action.", nil))
}
// eager fetch the record so that the modifiers field values can be resolved
record, err := e.App.FindRecordById(collection, recordId)
if err != nil {
return firstApiError(err, e.NotFoundError("", err))
}
data, err := recordDataFromRequest(e, record)
if err != nil {
return firstApiError(err, e.BadRequestError("Failed to read the submitted data.", err))
}
// replace modifiers fields so that the resolved value is always
// available when accessing requestInfo.Body
requestInfo.Body = data
ruleFunc := func(q *dbx.SelectQuery) error {
if !hasSuperuserAuth && collection.UpdateRule != nil && *collection.UpdateRule != "" {
resolver := core.NewRecordFieldResolver(e.App, collection, requestInfo, true)
expr, err := search.FilterData(*collection.UpdateRule).BuildExpr(resolver)
if err != nil {
return err
}
resolver.UpdateQuery(q)
q.AndWhere(expr)
}
return nil
}
// refetch with access checks
record, err = e.App.FindRecordById(collection, recordId, ruleFunc)
if err != nil {
return firstApiError(err, e.NotFoundError("", err))
}
form := forms.NewRecordUpsert(e.App, record)
if hasSuperuserAuth {
form.GrantSuperuserAccess()
}
form.Load(data)
manageRuleQuery := e.App.ConcurrentDB().Select("(1)").From(collection.Name).AndWhere(dbx.HashExp{
collection.Name + ".id": record.Id,
})
if !form.HasManageAccess() &&
hasAuthManageAccess(e.App, requestInfo, collection, manageRuleQuery) {
form.GrantManagerAccess()
}
var isOptFinalizerCalled bool
event := new(core.RecordRequestEvent)
event.RequestEvent = e
event.Collection = collection
event.Record = record
hookErr := e.App.OnRecordUpdateRequest().Trigger(event, func(e *core.RecordRequestEvent) error {
form.SetApp(e.App)
form.SetRecord(e.Record)
err := form.Submit()
if err != nil {
return firstApiError(err, e.BadRequestError("Failed to update record.", err))
}
err = EnrichRecord(e.RequestEvent, e.Record)
if err != nil {
return firstApiError(err, e.InternalServerError("Failed to enrich record", err))
}
err = execAfterSuccessTx(responseWriteAfterTx, e.App, func() error {
return e.JSON(http.StatusOK, e.Record)
})
if err != nil {
return err
}
if optFinalizer != nil {
isOptFinalizerCalled = true
err = optFinalizer(e.Record)
if err != nil {
return firstApiError(err, e.InternalServerError("", fmt.Errorf("update optFinalizer error: %w", err)))
}
}
return nil
})
if hookErr != nil {
return hookErr
}
// e.g. in case the regular hook chain was stopped and the finalizer cannot be executed as part of the last e.Next() task
if !isOptFinalizerCalled && optFinalizer != nil {
if err := optFinalizer(event.Record); err != nil {
return firstApiError(err, e.InternalServerError("", fmt.Errorf("update optFinalizer error: %w", err)))
}
}
return nil
}
}
func recordDelete(responseWriteAfterTx bool, optFinalizer func(data any) error) func(e *core.RequestEvent) error {
return func(e *core.RequestEvent) error {
collection, err := e.App.FindCachedCollectionByNameOrId(e.Request.PathValue("collection"))
if err != nil || collection == nil {
return e.NotFoundError("Missing collection context.", err)
}
if collection.IsView() {
return e.BadRequestError("Unsupported collection type.", nil)
}
err = checkCollectionRateLimit(e, collection, "delete")
if err != nil {
return err
}
recordId := e.Request.PathValue("id")
if recordId == "" {
return e.NotFoundError("", nil)
}
requestInfo, err := e.RequestInfo()
if err != nil {
return firstApiError(err, e.BadRequestError("", err))
}
if !requestInfo.HasSuperuserAuth() && collection.DeleteRule == nil {
return e.ForbiddenError("Only superusers can perform this action.", nil)
}
ruleFunc := func(q *dbx.SelectQuery) error {
if !requestInfo.HasSuperuserAuth() && collection.DeleteRule != nil && *collection.DeleteRule != "" {
resolver := core.NewRecordFieldResolver(e.App, collection, requestInfo, true)
expr, err := search.FilterData(*collection.DeleteRule).BuildExpr(resolver)
if err != nil {
return err
}
resolver.UpdateQuery(q)
q.AndWhere(expr)
}
return nil
}
record, err := e.App.FindRecordById(collection, recordId, ruleFunc)
if err != nil || record == nil {
return e.NotFoundError("", err)
}
var isOptFinalizerCalled bool
event := new(core.RecordRequestEvent)
event.RequestEvent = e
event.Collection = collection
event.Record = record
hookErr := e.App.OnRecordDeleteRequest().Trigger(event, func(e *core.RecordRequestEvent) error {
if err := e.App.Delete(e.Record); err != nil {
return firstApiError(err, e.BadRequestError("Failed to delete record. Make sure that the record is not part of a required relation reference.", err))
}
err = execAfterSuccessTx(responseWriteAfterTx, e.App, func() error {
return e.NoContent(http.StatusNoContent)
})
if err != nil {
return err
}
if optFinalizer != nil {
isOptFinalizerCalled = true
err = optFinalizer(e.Record)
if err != nil {
return firstApiError(err, e.InternalServerError("", fmt.Errorf("delete optFinalizer error: %w", err)))
}
}
return nil
})
if hookErr != nil {
return hookErr
}
// e.g. in case the regular hook chain was stopped and the finalizer cannot be executed as part of the last e.Next() task
if !isOptFinalizerCalled && optFinalizer != nil {
if err := optFinalizer(event.Record); err != nil {
return firstApiError(err, e.InternalServerError("", fmt.Errorf("delete optFinalizer error: %w", err)))
}
}
return nil
}
}
// -------------------------------------------------------------------
func recordDataFromRequest(e *core.RequestEvent, record *core.Record) (map[string]any, error) {
info, err := e.RequestInfo()
if err != nil {
return nil, err
}
// resolve regular fields
result := record.ReplaceModifiers(info.Body)
// resolve uploaded files
uploadedFiles, err := extractUploadedFiles(e, record.Collection(), "")
if err != nil {
return nil, err
}
if len(uploadedFiles) > 0 {
for k, files := range uploadedFiles {
uploaded := make([]any, 0, len(files))
// if not remove/prepend/append -> merge with the submitted
// info.Body values to prevent accidental old files deletion
if info.Body[k] != nil &&
!strings.HasPrefix(k, "+") &&
!strings.HasSuffix(k, "+") &&
!strings.HasSuffix(k, "-") {
existing := list.ToUniqueStringSlice(info.Body[k])
for _, name := range existing {
uploaded = append(uploaded, name)
}
}
for _, file := range files {
uploaded = append(uploaded, file)
}
result[k] = uploaded
}
result = record.ReplaceModifiers(result)
}
isAuth := record.Collection().IsAuth()
// unset hidden fields for non-superusers
if !info.HasSuperuserAuth() {
for _, f := range record.Collection().Fields {
if f.GetHidden() {
// exception for the auth collection "password" field
if isAuth && f.GetName() == core.FieldNamePassword {
continue
}
delete(result, f.GetName())
}
}
}
return result, nil
}
func extractUploadedFiles(re *core.RequestEvent, collection *core.Collection, prefix string) (map[string][]*filesystem.File, error) {
contentType := re.Request.Header.Get("content-type")
if !strings.HasPrefix(contentType, "multipart/form-data") {
return nil, nil // not multipart/form-data request
}
result := map[string][]*filesystem.File{}
for _, field := range collection.Fields {
if field.Type() != core.FieldTypeFile {
continue
}
baseKey := field.GetName()
keys := []string{
baseKey,
// prepend and append modifiers
"+" + baseKey,
baseKey + "+",
}
for _, k := range keys {
if prefix != "" {
k = prefix + "." + k
}
files, err := re.FindUploadedFiles(k)
if err != nil && !errors.Is(err, http.ErrMissingFile) {
return nil, err
}
if len(files) > 0 {
result[k] = files
}
}
}
return result, nil
}
// hasAuthManageAccess checks whether the client is allowed to have
// [forms.RecordUpsert] auth management permissions
// (e.g. allowing to change system auth fields without oldPassword).
func hasAuthManageAccess(app core.App, requestInfo *core.RequestInfo, collection *core.Collection, query *dbx.SelectQuery) bool {
if !collection.IsAuth() {
return false
}
manageRule := collection.ManageRule
if manageRule == nil || *manageRule == "" {
return false // only for superusers (manageRule can't be empty)
}
if requestInfo == nil || requestInfo.Auth == nil {
return false // no auth record
}
resolver := core.NewRecordFieldResolver(app, collection, requestInfo, true)
expr, err := search.FilterData(*manageRule).BuildExpr(resolver)
if err != nil {
app.Logger().Error("Manage rule build expression error", "error", err, "collectionId", collection.Id)
return false
}
query.AndWhere(expr)
resolver.UpdateQuery(query)
var exists int
err = query.Limit(1).Row(&exists)
return err == nil && exists > 0
}

View file

@ -0,0 +1,314 @@
package apis_test
import (
"net/http"
"strings"
"testing"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tests"
)
func TestRecordCrudAuthOriginList(t *testing.T) {
t.Parallel()
scenarios := []tests.ApiScenario{
{
Name: "guest",
Method: http.MethodGet,
URL: "/api/collections/" + core.CollectionNameAuthOrigins + "/records",
ExpectedStatus: 200,
ExpectedContent: []string{
`"page":1`,
`"perPage":30`,
`"totalItems":0`,
`"totalPages":0`,
`"items":[]`,
},
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordsListRequest": 1,
},
},
{
Name: "regular auth with authOrigins",
Method: http.MethodGet,
URL: "/api/collections/" + core.CollectionNameAuthOrigins + "/records",
Headers: map[string]string{
// clients, test@example.com
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6ImdrMzkwcWVnczR5NDd3biIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoidjg1MXE0cjc5MHJoa25sIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.0ONnm_BsvPRZyDNT31GN1CKUB6uQRxvVvQ-Wc9AZfG0",
},
ExpectedStatus: 200,
ExpectedContent: []string{
`"page":1`,
`"perPage":30`,
`"totalItems":1`,
`"totalPages":1`,
`"id":"9r2j0m74260ur8i"`,
},
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordsListRequest": 1,
"OnRecordEnrich": 1,
},
},
{
Name: "regular auth without authOrigins",
Method: http.MethodGet,
URL: "/api/collections/" + core.CollectionNameAuthOrigins + "/records",
Headers: map[string]string{
// users, test@example.com
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
},
ExpectedStatus: 200,
ExpectedContent: []string{
`"page":1`,
`"perPage":30`,
`"totalItems":0`,
`"totalPages":0`,
`"items":[]`,
},
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordsListRequest": 1,
},
},
}
for _, scenario := range scenarios {
scenario.Test(t)
}
}
func TestRecordCrudAuthOriginView(t *testing.T) {
t.Parallel()
scenarios := []tests.ApiScenario{
{
Name: "guest",
Method: http.MethodGet,
URL: "/api/collections/" + core.CollectionNameAuthOrigins + "/records/9r2j0m74260ur8i",
ExpectedStatus: 404,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "non-owner",
Method: http.MethodGet,
URL: "/api/collections/" + core.CollectionNameAuthOrigins + "/records/9r2j0m74260ur8i",
Headers: map[string]string{
// users, test@example.com
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
},
ExpectedStatus: 404,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "owner",
Method: http.MethodGet,
URL: "/api/collections/" + core.CollectionNameAuthOrigins + "/records/9r2j0m74260ur8i",
Headers: map[string]string{
// clients, test@example.com
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6ImdrMzkwcWVnczR5NDd3biIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoidjg1MXE0cjc5MHJoa25sIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.0ONnm_BsvPRZyDNT31GN1CKUB6uQRxvVvQ-Wc9AZfG0",
},
ExpectedStatus: 200,
ExpectedContent: []string{`"id":"9r2j0m74260ur8i"`},
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordViewRequest": 1,
"OnRecordEnrich": 1,
},
},
}
for _, scenario := range scenarios {
scenario.Test(t)
}
}
func TestRecordCrudAuthOriginDelete(t *testing.T) {
t.Parallel()
scenarios := []tests.ApiScenario{
{
Name: "guest",
Method: http.MethodDelete,
URL: "/api/collections/" + core.CollectionNameAuthOrigins + "/records/9r2j0m74260ur8i",
ExpectedStatus: 404,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "non-owner",
Method: http.MethodDelete,
URL: "/api/collections/" + core.CollectionNameAuthOrigins + "/records/9r2j0m74260ur8i",
Headers: map[string]string{
// users, test@example.com
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
},
ExpectedStatus: 404,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "owner",
Method: http.MethodDelete,
URL: "/api/collections/" + core.CollectionNameAuthOrigins + "/records/9r2j0m74260ur8i",
Headers: map[string]string{
// clients, test@example.com
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6ImdrMzkwcWVnczR5NDd3biIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoidjg1MXE0cjc5MHJoa25sIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.0ONnm_BsvPRZyDNT31GN1CKUB6uQRxvVvQ-Wc9AZfG0",
},
ExpectedStatus: 204,
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordDeleteRequest": 1,
"OnModelDelete": 1,
"OnModelDeleteExecute": 1,
"OnModelAfterDeleteSuccess": 1,
"OnRecordDelete": 1,
"OnRecordDeleteExecute": 1,
"OnRecordAfterDeleteSuccess": 1,
},
},
}
for _, scenario := range scenarios {
scenario.Test(t)
}
}
func TestRecordCrudAuthOriginCreate(t *testing.T) {
t.Parallel()
body := func() *strings.Reader {
return strings.NewReader(`{
"recordRef": "4q1xlclmfloku33",
"collectionRef": "_pb_users_auth_",
"fingerprint": "abc"
}`)
}
scenarios := []tests.ApiScenario{
{
Name: "guest",
Method: http.MethodPost,
URL: "/api/collections/" + core.CollectionNameAuthOrigins + "/records",
Body: body(),
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "owner regular auth",
Method: http.MethodPost,
URL: "/api/collections/" + core.CollectionNameAuthOrigins + "/records",
Headers: map[string]string{
// users, test@example.com
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
},
Body: body(),
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "superusers auth",
Method: http.MethodPost,
URL: "/api/collections/" + core.CollectionNameAuthOrigins + "/records",
Headers: map[string]string{
// superusers, test@example.com
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
},
Body: body(),
ExpectedContent: []string{
`"fingerprint":"abc"`,
},
ExpectedStatus: 200,
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordCreateRequest": 1,
"OnRecordEnrich": 1,
"OnModelCreate": 1,
"OnModelCreateExecute": 1,
"OnModelAfterCreateSuccess": 1,
"OnModelValidate": 1,
"OnRecordCreate": 1,
"OnRecordCreateExecute": 1,
"OnRecordAfterCreateSuccess": 1,
"OnRecordValidate": 1,
},
},
}
for _, scenario := range scenarios {
scenario.Test(t)
}
}
func TestRecordCrudAuthOriginUpdate(t *testing.T) {
t.Parallel()
body := func() *strings.Reader {
return strings.NewReader(`{
"fingerprint":"abc"
}`)
}
scenarios := []tests.ApiScenario{
{
Name: "guest",
Method: http.MethodPatch,
URL: "/api/collections/" + core.CollectionNameAuthOrigins + "/records/9r2j0m74260ur8i",
Body: body(),
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "owner regular auth",
Method: http.MethodPatch,
URL: "/api/collections/" + core.CollectionNameAuthOrigins + "/records/9r2j0m74260ur8i",
Headers: map[string]string{
// clients, test@example.com
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6ImdrMzkwcWVnczR5NDd3biIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoidjg1MXE0cjc5MHJoa25sIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.0ONnm_BsvPRZyDNT31GN1CKUB6uQRxvVvQ-Wc9AZfG0",
},
Body: body(),
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "superusers auth",
Method: http.MethodPatch,
URL: "/api/collections/" + core.CollectionNameAuthOrigins + "/records/9r2j0m74260ur8i",
Headers: map[string]string{
// superusers, test@example.com
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
},
Body: body(),
ExpectedContent: []string{
`"id":"9r2j0m74260ur8i"`,
`"fingerprint":"abc"`,
},
ExpectedStatus: 200,
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordUpdateRequest": 1,
"OnRecordEnrich": 1,
"OnModelUpdate": 1,
"OnModelUpdateExecute": 1,
"OnModelAfterUpdateSuccess": 1,
"OnModelValidate": 1,
"OnRecordUpdate": 1,
"OnRecordUpdateExecute": 1,
"OnRecordAfterUpdateSuccess": 1,
"OnRecordValidate": 1,
},
},
}
for _, scenario := range scenarios {
scenario.Test(t)
}
}

View file

@ -0,0 +1,316 @@
package apis_test
import (
"net/http"
"strings"
"testing"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tests"
)
func TestRecordCrudExternalAuthList(t *testing.T) {
t.Parallel()
scenarios := []tests.ApiScenario{
{
Name: "guest",
Method: http.MethodGet,
URL: "/api/collections/" + core.CollectionNameExternalAuths + "/records",
ExpectedStatus: 200,
ExpectedContent: []string{
`"page":1`,
`"perPage":30`,
`"totalItems":0`,
`"totalPages":0`,
`"items":[]`,
},
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordsListRequest": 1,
},
},
{
Name: "regular auth with externalAuths",
Method: http.MethodGet,
URL: "/api/collections/" + core.CollectionNameExternalAuths + "/records",
Headers: map[string]string{
// clients, test@example.com
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6ImdrMzkwcWVnczR5NDd3biIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoidjg1MXE0cjc5MHJoa25sIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.0ONnm_BsvPRZyDNT31GN1CKUB6uQRxvVvQ-Wc9AZfG0",
},
ExpectedStatus: 200,
ExpectedContent: []string{
`"page":1`,
`"perPage":30`,
`"totalItems":1`,
`"totalPages":1`,
`"id":"f1z5b3843pzc964"`,
},
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordsListRequest": 1,
"OnRecordEnrich": 1,
},
},
{
Name: "regular auth without externalAuths",
Method: http.MethodGet,
URL: "/api/collections/" + core.CollectionNameExternalAuths + "/records",
Headers: map[string]string{
// users, test2@example.com
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6Im9hcDY0MGNvdDR5cnUycyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.GfJo6EHIobgas_AXt-M-tj5IoQendPnrkMSe9ExuSEY",
},
ExpectedStatus: 200,
ExpectedContent: []string{
`"page":1`,
`"perPage":30`,
`"totalItems":0`,
`"totalPages":0`,
`"items":[]`,
},
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordsListRequest": 1,
},
},
}
for _, scenario := range scenarios {
scenario.Test(t)
}
}
func TestRecordCrudExternalAuthView(t *testing.T) {
t.Parallel()
scenarios := []tests.ApiScenario{
{
Name: "guest",
Method: http.MethodGet,
URL: "/api/collections/" + core.CollectionNameExternalAuths + "/records/dlmflokuq1xl342",
ExpectedStatus: 404,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "non-owner",
Method: http.MethodGet,
URL: "/api/collections/" + core.CollectionNameExternalAuths + "/records/dlmflokuq1xl342",
Headers: map[string]string{
// clients, test@example.com
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6ImdrMzkwcWVnczR5NDd3biIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoidjg1MXE0cjc5MHJoa25sIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.0ONnm_BsvPRZyDNT31GN1CKUB6uQRxvVvQ-Wc9AZfG0",
},
ExpectedStatus: 404,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "owner",
Method: http.MethodGet,
URL: "/api/collections/" + core.CollectionNameExternalAuths + "/records/dlmflokuq1xl342",
Headers: map[string]string{
// users, test@example.com
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
},
ExpectedStatus: 200,
ExpectedContent: []string{`"id":"dlmflokuq1xl342"`},
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordViewRequest": 1,
"OnRecordEnrich": 1,
},
},
}
for _, scenario := range scenarios {
scenario.Test(t)
}
}
func TestRecordCrudExternalAuthDelete(t *testing.T) {
t.Parallel()
scenarios := []tests.ApiScenario{
{
Name: "guest",
Method: http.MethodDelete,
URL: "/api/collections/" + core.CollectionNameExternalAuths + "/records/dlmflokuq1xl342",
ExpectedStatus: 404,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "non-owner",
Method: http.MethodDelete,
URL: "/api/collections/" + core.CollectionNameExternalAuths + "/records/dlmflokuq1xl342",
Headers: map[string]string{
// clients, test@example.com
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6ImdrMzkwcWVnczR5NDd3biIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoidjg1MXE0cjc5MHJoa25sIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.0ONnm_BsvPRZyDNT31GN1CKUB6uQRxvVvQ-Wc9AZfG0",
},
ExpectedStatus: 404,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "owner",
Method: http.MethodDelete,
URL: "/api/collections/" + core.CollectionNameExternalAuths + "/records/dlmflokuq1xl342",
Headers: map[string]string{
// users, test@example.com
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
},
ExpectedStatus: 204,
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordDeleteRequest": 1,
"OnModelDelete": 1,
"OnModelDeleteExecute": 1,
"OnModelAfterDeleteSuccess": 1,
"OnRecordDelete": 1,
"OnRecordDeleteExecute": 1,
"OnRecordAfterDeleteSuccess": 1,
},
},
}
for _, scenario := range scenarios {
scenario.Test(t)
}
}
func TestRecordCrudExternalAuthCreate(t *testing.T) {
t.Parallel()
body := func() *strings.Reader {
return strings.NewReader(`{
"recordRef": "4q1xlclmfloku33",
"collectionRef": "_pb_users_auth_",
"provider": "github",
"providerId": "abc"
}`)
}
scenarios := []tests.ApiScenario{
{
Name: "guest",
Method: http.MethodPost,
URL: "/api/collections/" + core.CollectionNameExternalAuths + "/records",
Body: body(),
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "owner regular auth",
Method: http.MethodPost,
URL: "/api/collections/" + core.CollectionNameExternalAuths + "/records",
Headers: map[string]string{
// users, test@example.com
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
},
Body: body(),
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "superusers auth",
Method: http.MethodPost,
URL: "/api/collections/" + core.CollectionNameExternalAuths + "/records",
Headers: map[string]string{
// superusers, test@example.com
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
},
Body: body(),
ExpectedContent: []string{
`"recordRef":"4q1xlclmfloku33"`,
`"providerId":"abc"`,
},
ExpectedStatus: 200,
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordCreateRequest": 1,
"OnRecordEnrich": 1,
"OnModelCreate": 1,
"OnModelCreateExecute": 1,
"OnModelAfterCreateSuccess": 1,
"OnModelValidate": 1,
"OnRecordCreate": 1,
"OnRecordCreateExecute": 1,
"OnRecordAfterCreateSuccess": 1,
"OnRecordValidate": 1,
},
},
}
for _, scenario := range scenarios {
scenario.Test(t)
}
}
func TestRecordCrudExternalAuthUpdate(t *testing.T) {
t.Parallel()
body := func() *strings.Reader {
return strings.NewReader(`{
"providerId": "abc"
}`)
}
scenarios := []tests.ApiScenario{
{
Name: "guest",
Method: http.MethodPatch,
URL: "/api/collections/" + core.CollectionNameExternalAuths + "/records/dlmflokuq1xl342",
Body: body(),
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "owner regular auth",
Method: http.MethodPatch,
URL: "/api/collections/" + core.CollectionNameExternalAuths + "/records/dlmflokuq1xl342",
Headers: map[string]string{
// clients, test@example.com
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6ImdrMzkwcWVnczR5NDd3biIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoidjg1MXE0cjc5MHJoa25sIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.0ONnm_BsvPRZyDNT31GN1CKUB6uQRxvVvQ-Wc9AZfG0",
},
Body: body(),
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "superusers auth",
Method: http.MethodPatch,
URL: "/api/collections/" + core.CollectionNameExternalAuths + "/records/dlmflokuq1xl342",
Headers: map[string]string{
// superusers, test@example.com
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
},
Body: body(),
ExpectedContent: []string{
`"id":"dlmflokuq1xl342"`,
`"providerId":"abc"`,
},
ExpectedStatus: 200,
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordUpdateRequest": 1,
"OnRecordEnrich": 1,
"OnModelUpdate": 1,
"OnModelUpdateExecute": 1,
"OnModelAfterUpdateSuccess": 1,
"OnModelValidate": 1,
"OnRecordUpdate": 1,
"OnRecordUpdateExecute": 1,
"OnRecordAfterUpdateSuccess": 1,
"OnRecordValidate": 1,
},
},
}
for _, scenario := range scenarios {
scenario.Test(t)
}
}

View file

@ -0,0 +1,405 @@
package apis_test
import (
"net/http"
"strings"
"testing"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tests"
)
func TestRecordCrudMFAList(t *testing.T) {
t.Parallel()
scenarios := []tests.ApiScenario{
{
Name: "guest",
Method: http.MethodGet,
URL: "/api/collections/" + core.CollectionNameMFAs + "/records",
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
if err := tests.StubMFARecords(app); err != nil {
t.Fatal(err)
}
},
ExpectedStatus: 200,
ExpectedContent: []string{
`"page":1`,
`"perPage":30`,
`"totalItems":0`,
`"totalPages":0`,
`"items":[]`,
},
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordsListRequest": 1,
},
},
{
Name: "regular auth with mfas",
Method: http.MethodGet,
URL: "/api/collections/" + core.CollectionNameMFAs + "/records",
Headers: map[string]string{
// users, test@example.com
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
},
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
if err := tests.StubMFARecords(app); err != nil {
t.Fatal(err)
}
},
ExpectedStatus: 200,
ExpectedContent: []string{
`"page":1`,
`"perPage":30`,
`"totalItems":1`,
`"totalPages":1`,
`"id":"user1_0"`,
},
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordsListRequest": 1,
"OnRecordEnrich": 1,
},
},
{
Name: "regular auth without mfas",
Method: http.MethodGet,
URL: "/api/collections/" + core.CollectionNameMFAs + "/records",
Headers: map[string]string{
// clients, test@example.com
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6ImdrMzkwcWVnczR5NDd3biIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoidjg1MXE0cjc5MHJoa25sIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.0ONnm_BsvPRZyDNT31GN1CKUB6uQRxvVvQ-Wc9AZfG0",
},
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
if err := tests.StubMFARecords(app); err != nil {
t.Fatal(err)
}
},
ExpectedStatus: 200,
ExpectedContent: []string{
`"page":1`,
`"perPage":30`,
`"totalItems":0`,
`"totalPages":0`,
`"items":[]`,
},
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordsListRequest": 1,
},
},
}
for _, scenario := range scenarios {
scenario.Test(t)
}
}
func TestRecordCrudMFAView(t *testing.T) {
t.Parallel()
scenarios := []tests.ApiScenario{
{
Name: "guest",
Method: http.MethodGet,
URL: "/api/collections/" + core.CollectionNameMFAs + "/records/user1_0",
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
if err := tests.StubMFARecords(app); err != nil {
t.Fatal(err)
}
},
ExpectedStatus: 404,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "non-owner",
Method: http.MethodGet,
URL: "/api/collections/" + core.CollectionNameMFAs + "/records/user1_0",
Headers: map[string]string{
// clients, test@example.com
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6ImdrMzkwcWVnczR5NDd3biIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoidjg1MXE0cjc5MHJoa25sIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.0ONnm_BsvPRZyDNT31GN1CKUB6uQRxvVvQ-Wc9AZfG0",
},
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
if err := tests.StubMFARecords(app); err != nil {
t.Fatal(err)
}
},
ExpectedStatus: 404,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "owner",
Method: http.MethodGet,
URL: "/api/collections/" + core.CollectionNameMFAs + "/records/user1_0",
Headers: map[string]string{
// users, test@example.com
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
},
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
if err := tests.StubMFARecords(app); err != nil {
t.Fatal(err)
}
},
ExpectedStatus: 200,
ExpectedContent: []string{`"id":"user1_0"`},
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordViewRequest": 1,
"OnRecordEnrich": 1,
},
},
}
for _, scenario := range scenarios {
scenario.Test(t)
}
}
func TestRecordCrudMFADelete(t *testing.T) {
t.Parallel()
scenarios := []tests.ApiScenario{
{
Name: "guest",
Method: http.MethodDelete,
URL: "/api/collections/" + core.CollectionNameMFAs + "/records/user1_0",
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
if err := tests.StubMFARecords(app); err != nil {
t.Fatal(err)
}
},
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "non-owner",
Method: http.MethodDelete,
URL: "/api/collections/" + core.CollectionNameMFAs + "/records/user1_0",
Headers: map[string]string{
// clients, test@example.com
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6ImdrMzkwcWVnczR5NDd3biIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoidjg1MXE0cjc5MHJoa25sIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.0ONnm_BsvPRZyDNT31GN1CKUB6uQRxvVvQ-Wc9AZfG0",
},
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
if err := tests.StubMFARecords(app); err != nil {
t.Fatal(err)
}
},
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "owner",
Method: http.MethodDelete,
URL: "/api/collections/" + core.CollectionNameMFAs + "/records/user1_0",
Headers: map[string]string{
// users, test@example.com
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
},
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
if err := tests.StubMFARecords(app); err != nil {
t.Fatal(err)
}
},
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "superusers auth",
Method: http.MethodDelete,
URL: "/api/collections/" + core.CollectionNameMFAs + "/records/user1_0",
Headers: map[string]string{
// superusers, test@example.com
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
},
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
if err := tests.StubMFARecords(app); err != nil {
t.Fatal(err)
}
},
ExpectedStatus: 204,
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordDeleteRequest": 1,
"OnModelDelete": 1,
"OnModelDeleteExecute": 1,
"OnModelAfterDeleteSuccess": 1,
"OnRecordDelete": 1,
"OnRecordDeleteExecute": 1,
"OnRecordAfterDeleteSuccess": 1,
},
},
}
for _, scenario := range scenarios {
scenario.Test(t)
}
}
func TestRecordCrudMFACreate(t *testing.T) {
t.Parallel()
body := func() *strings.Reader {
return strings.NewReader(`{
"recordRef": "4q1xlclmfloku33",
"collectionRef": "_pb_users_auth_",
"method": "abc"
}`)
}
scenarios := []tests.ApiScenario{
{
Name: "guest",
Method: http.MethodPost,
URL: "/api/collections/" + core.CollectionNameMFAs + "/records",
Body: body(),
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
if err := tests.StubMFARecords(app); err != nil {
t.Fatal(err)
}
},
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "owner regular auth",
Method: http.MethodPost,
URL: "/api/collections/" + core.CollectionNameMFAs + "/records",
Headers: map[string]string{
// users, test@example.com
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
},
Body: body(),
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
if err := tests.StubMFARecords(app); err != nil {
t.Fatal(err)
}
},
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "superusers auth",
Method: http.MethodPost,
URL: "/api/collections/" + core.CollectionNameMFAs + "/records",
Headers: map[string]string{
// superusers, test@example.com
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
},
Body: body(),
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
if err := tests.StubMFARecords(app); err != nil {
t.Fatal(err)
}
},
ExpectedContent: []string{
`"recordRef":"4q1xlclmfloku33"`,
},
ExpectedStatus: 200,
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordCreateRequest": 1,
"OnRecordEnrich": 1,
"OnModelCreate": 1,
"OnModelCreateExecute": 1,
"OnModelAfterCreateSuccess": 1,
"OnModelValidate": 1,
"OnRecordCreate": 1,
"OnRecordCreateExecute": 1,
"OnRecordAfterCreateSuccess": 1,
"OnRecordValidate": 1,
},
},
}
for _, scenario := range scenarios {
scenario.Test(t)
}
}
func TestRecordCrudMFAUpdate(t *testing.T) {
t.Parallel()
body := func() *strings.Reader {
return strings.NewReader(`{
"method":"abc"
}`)
}
scenarios := []tests.ApiScenario{
{
Name: "guest",
Method: http.MethodPatch,
URL: "/api/collections/" + core.CollectionNameMFAs + "/records/user1_0",
Body: body(),
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
if err := tests.StubMFARecords(app); err != nil {
t.Fatal(err)
}
},
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "owner regular auth",
Method: http.MethodPatch,
URL: "/api/collections/" + core.CollectionNameMFAs + "/records/user1_0",
Headers: map[string]string{
// users, test@example.com
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
},
Body: body(),
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
if err := tests.StubMFARecords(app); err != nil {
t.Fatal(err)
}
},
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "superusers auth",
Method: http.MethodPatch,
URL: "/api/collections/" + core.CollectionNameMFAs + "/records/user1_0",
Headers: map[string]string{
// superusers, test@example.com
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
},
Body: body(),
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
if err := tests.StubMFARecords(app); err != nil {
t.Fatal(err)
}
},
ExpectedContent: []string{
`"id":"user1_0"`,
},
ExpectedStatus: 200,
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordUpdateRequest": 1,
"OnRecordEnrich": 1,
"OnModelUpdate": 1,
"OnModelUpdateExecute": 1,
"OnModelAfterUpdateSuccess": 1,
"OnModelValidate": 1,
"OnRecordUpdate": 1,
"OnRecordUpdateExecute": 1,
"OnRecordAfterUpdateSuccess": 1,
"OnRecordValidate": 1,
},
},
}
for _, scenario := range scenarios {
scenario.Test(t)
}
}

View file

@ -0,0 +1,405 @@
package apis_test
import (
"net/http"
"strings"
"testing"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tests"
)
func TestRecordCrudOTPList(t *testing.T) {
t.Parallel()
scenarios := []tests.ApiScenario{
{
Name: "guest",
Method: http.MethodGet,
URL: "/api/collections/" + core.CollectionNameOTPs + "/records",
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
if err := tests.StubOTPRecords(app); err != nil {
t.Fatal(err)
}
},
ExpectedStatus: 200,
ExpectedContent: []string{
`"page":1`,
`"perPage":30`,
`"totalItems":0`,
`"totalPages":0`,
`"items":[]`,
},
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordsListRequest": 1,
},
},
{
Name: "regular auth with otps",
Method: http.MethodGet,
URL: "/api/collections/" + core.CollectionNameOTPs + "/records",
Headers: map[string]string{
// users, test@example.com
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
},
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
if err := tests.StubOTPRecords(app); err != nil {
t.Fatal(err)
}
},
ExpectedStatus: 200,
ExpectedContent: []string{
`"page":1`,
`"perPage":30`,
`"totalItems":1`,
`"totalPages":1`,
`"id":"user1_0"`,
},
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordsListRequest": 1,
"OnRecordEnrich": 1,
},
},
{
Name: "regular auth without otps",
Method: http.MethodGet,
URL: "/api/collections/" + core.CollectionNameOTPs + "/records",
Headers: map[string]string{
// clients, test@example.com
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6ImdrMzkwcWVnczR5NDd3biIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoidjg1MXE0cjc5MHJoa25sIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.0ONnm_BsvPRZyDNT31GN1CKUB6uQRxvVvQ-Wc9AZfG0",
},
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
if err := tests.StubOTPRecords(app); err != nil {
t.Fatal(err)
}
},
ExpectedStatus: 200,
ExpectedContent: []string{
`"page":1`,
`"perPage":30`,
`"totalItems":0`,
`"totalPages":0`,
`"items":[]`,
},
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordsListRequest": 1,
},
},
}
for _, scenario := range scenarios {
scenario.Test(t)
}
}
func TestRecordCrudOTPView(t *testing.T) {
t.Parallel()
scenarios := []tests.ApiScenario{
{
Name: "guest",
Method: http.MethodGet,
URL: "/api/collections/" + core.CollectionNameOTPs + "/records/user1_0",
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
if err := tests.StubOTPRecords(app); err != nil {
t.Fatal(err)
}
},
ExpectedStatus: 404,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "non-owner",
Method: http.MethodGet,
URL: "/api/collections/" + core.CollectionNameOTPs + "/records/user1_0",
Headers: map[string]string{
// clients, test@example.com
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6ImdrMzkwcWVnczR5NDd3biIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoidjg1MXE0cjc5MHJoa25sIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.0ONnm_BsvPRZyDNT31GN1CKUB6uQRxvVvQ-Wc9AZfG0",
},
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
if err := tests.StubOTPRecords(app); err != nil {
t.Fatal(err)
}
},
ExpectedStatus: 404,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "owner",
Method: http.MethodGet,
URL: "/api/collections/" + core.CollectionNameOTPs + "/records/user1_0",
Headers: map[string]string{
// users, test@example.com
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
},
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
if err := tests.StubOTPRecords(app); err != nil {
t.Fatal(err)
}
},
ExpectedStatus: 200,
ExpectedContent: []string{`"id":"user1_0"`},
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordViewRequest": 1,
"OnRecordEnrich": 1,
},
},
}
for _, scenario := range scenarios {
scenario.Test(t)
}
}
func TestRecordCrudOTPDelete(t *testing.T) {
t.Parallel()
scenarios := []tests.ApiScenario{
{
Name: "guest",
Method: http.MethodDelete,
URL: "/api/collections/" + core.CollectionNameOTPs + "/records/user1_0",
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
if err := tests.StubOTPRecords(app); err != nil {
t.Fatal(err)
}
},
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "non-owner auth",
Method: http.MethodDelete,
URL: "/api/collections/" + core.CollectionNameOTPs + "/records/user1_0",
Headers: map[string]string{
// clients, test@example.com
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6ImdrMzkwcWVnczR5NDd3biIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoidjg1MXE0cjc5MHJoa25sIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.0ONnm_BsvPRZyDNT31GN1CKUB6uQRxvVvQ-Wc9AZfG0",
},
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
if err := tests.StubOTPRecords(app); err != nil {
t.Fatal(err)
}
},
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "owner regular auth",
Method: http.MethodDelete,
URL: "/api/collections/" + core.CollectionNameOTPs + "/records/user1_0",
Headers: map[string]string{
// users, test@example.com
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
},
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
if err := tests.StubOTPRecords(app); err != nil {
t.Fatal(err)
}
},
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "superusers auth",
Method: http.MethodDelete,
URL: "/api/collections/" + core.CollectionNameOTPs + "/records/user1_0",
Headers: map[string]string{
// superusers, test@example.com
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
},
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
if err := tests.StubOTPRecords(app); err != nil {
t.Fatal(err)
}
},
ExpectedStatus: 204,
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordDeleteRequest": 1,
"OnModelDelete": 1,
"OnModelDeleteExecute": 1,
"OnModelAfterDeleteSuccess": 1,
"OnRecordDelete": 1,
"OnRecordDeleteExecute": 1,
"OnRecordAfterDeleteSuccess": 1,
},
},
}
for _, scenario := range scenarios {
scenario.Test(t)
}
}
func TestRecordCrudOTPCreate(t *testing.T) {
t.Parallel()
body := func() *strings.Reader {
return strings.NewReader(`{
"recordRef": "4q1xlclmfloku33",
"collectionRef": "_pb_users_auth_",
"password": "abc"
}`)
}
scenarios := []tests.ApiScenario{
{
Name: "guest",
Method: http.MethodPost,
URL: "/api/collections/" + core.CollectionNameOTPs + "/records",
Body: body(),
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
if err := tests.StubOTPRecords(app); err != nil {
t.Fatal(err)
}
},
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "owner regular auth",
Method: http.MethodPost,
URL: "/api/collections/" + core.CollectionNameOTPs + "/records",
Headers: map[string]string{
// users, test@example.com
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
},
Body: body(),
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
if err := tests.StubOTPRecords(app); err != nil {
t.Fatal(err)
}
},
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "superusers auth",
Method: http.MethodPost,
URL: "/api/collections/" + core.CollectionNameOTPs + "/records",
Headers: map[string]string{
// superusers, test@example.com
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
},
Body: body(),
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
if err := tests.StubOTPRecords(app); err != nil {
t.Fatal(err)
}
},
ExpectedContent: []string{
`"recordRef":"4q1xlclmfloku33"`,
},
ExpectedStatus: 200,
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordCreateRequest": 1,
"OnRecordEnrich": 1,
"OnModelCreate": 1,
"OnModelCreateExecute": 1,
"OnModelAfterCreateSuccess": 1,
"OnModelValidate": 1,
"OnRecordCreate": 1,
"OnRecordCreateExecute": 1,
"OnRecordAfterCreateSuccess": 1,
"OnRecordValidate": 1,
},
},
}
for _, scenario := range scenarios {
scenario.Test(t)
}
}
func TestRecordCrudOTPUpdate(t *testing.T) {
t.Parallel()
body := func() *strings.Reader {
return strings.NewReader(`{
"password":"abc"
}`)
}
scenarios := []tests.ApiScenario{
{
Name: "guest",
Method: http.MethodPatch,
URL: "/api/collections/" + core.CollectionNameOTPs + "/records/user1_0",
Body: body(),
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
if err := tests.StubOTPRecords(app); err != nil {
t.Fatal(err)
}
},
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "owner regular auth",
Method: http.MethodPatch,
URL: "/api/collections/" + core.CollectionNameOTPs + "/records/user1_0",
Headers: map[string]string{
// users, test@example.com
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
},
Body: body(),
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
if err := tests.StubOTPRecords(app); err != nil {
t.Fatal(err)
}
},
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "superusers auth",
Method: http.MethodPatch,
URL: "/api/collections/" + core.CollectionNameOTPs + "/records/user1_0",
Headers: map[string]string{
// superusers, test@example.com
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
},
Body: body(),
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
if err := tests.StubOTPRecords(app); err != nil {
t.Fatal(err)
}
},
ExpectedContent: []string{
`"id":"user1_0"`,
},
ExpectedStatus: 200,
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordUpdateRequest": 1,
"OnRecordEnrich": 1,
"OnModelUpdate": 1,
"OnModelUpdateExecute": 1,
"OnModelAfterUpdateSuccess": 1,
"OnModelValidate": 1,
"OnRecordUpdate": 1,
"OnRecordUpdateExecute": 1,
"OnRecordAfterUpdateSuccess": 1,
"OnRecordValidate": 1,
},
},
}
for _, scenario := range scenarios {
scenario.Test(t)
}
}

View file

@ -0,0 +1,336 @@
package apis_test
import (
"net/http"
"strings"
"testing"
"github.com/pocketbase/dbx"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tests"
)
func TestRecordCrudSuperuserList(t *testing.T) {
t.Parallel()
scenarios := []tests.ApiScenario{
{
Name: "guest",
Method: http.MethodGet,
URL: "/api/collections/" + core.CollectionNameSuperusers + "/records",
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "non-superusers auth",
Method: http.MethodGet,
URL: "/api/collections/" + core.CollectionNameSuperusers + "/records",
Headers: map[string]string{
// users, test@example.com
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
},
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "superusers auth",
Method: http.MethodGet,
URL: "/api/collections/" + core.CollectionNameSuperusers + "/records",
Headers: map[string]string{
// _superusers, test@example.com
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
},
ExpectedStatus: 200,
ExpectedContent: []string{
`"page":1`,
`"perPage":30`,
`"totalPages":1`,
`"totalItems":4`,
`"items":[{`,
},
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordsListRequest": 1,
"OnRecordEnrich": 4,
},
},
}
for _, scenario := range scenarios {
scenario.Test(t)
}
}
func TestRecordCrudSuperuserView(t *testing.T) {
t.Parallel()
scenarios := []tests.ApiScenario{
{
Name: "guest",
Method: http.MethodGet,
URL: "/api/collections/" + core.CollectionNameSuperusers + "/records/sywbhecnh46rhm0",
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "non-superusers auth",
Method: http.MethodGet,
URL: "/api/collections/" + core.CollectionNameSuperusers + "/records/sywbhecnh46rhm0",
Headers: map[string]string{
// users, test@example.com
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
},
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "superusers auth",
Method: http.MethodGet,
URL: "/api/collections/" + core.CollectionNameSuperusers + "/records/sywbhecnh46rhm0",
Headers: map[string]string{
// _superusers, test@example.com
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
},
ExpectedStatus: 200,
ExpectedContent: []string{
`"id":"sywbhecnh46rhm0"`,
},
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordViewRequest": 1,
"OnRecordEnrich": 1,
},
},
}
for _, scenario := range scenarios {
scenario.Test(t)
}
}
func TestRecordCrudSuperuserDelete(t *testing.T) {
t.Parallel()
scenarios := []tests.ApiScenario{
{
Name: "guest",
Method: http.MethodDelete,
URL: "/api/collections/" + core.CollectionNameSuperusers + "/records/sbmbsdb40jyxf7h",
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "non-superusers auth",
Method: http.MethodDelete,
URL: "/api/collections/" + core.CollectionNameSuperusers + "/records/sbmbsdb40jyxf7h",
Headers: map[string]string{
// users, test@example.com
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
},
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "superusers auth",
Method: http.MethodDelete,
URL: "/api/collections/" + core.CollectionNameSuperusers + "/records/sbmbsdb40jyxf7h",
Headers: map[string]string{
// _superusers, test@example.com
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
},
ExpectedStatus: 204,
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordDeleteRequest": 1,
"OnModelDelete": 4, // + 3 AuthOrigins
"OnModelDeleteExecute": 4,
"OnModelAfterDeleteSuccess": 4,
"OnRecordDelete": 4,
"OnRecordDeleteExecute": 4,
"OnRecordAfterDeleteSuccess": 4,
},
},
{
Name: "delete the last superuser",
Method: http.MethodDelete,
URL: "/api/collections/" + core.CollectionNameSuperusers + "/records/sywbhecnh46rhm0",
Headers: map[string]string{
// _superusers, test@example.com
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
},
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
// delete all other superusers
superusers, err := app.FindAllRecords(core.CollectionNameSuperusers, dbx.Not(dbx.HashExp{"id": "sywbhecnh46rhm0"}))
if err != nil {
t.Fatal(err)
}
for _, superuser := range superusers {
if err = app.Delete(superuser); err != nil {
t.Fatal(err)
}
}
},
ExpectedStatus: 400,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordDeleteRequest": 1,
"OnModelDelete": 1,
"OnModelAfterDeleteError": 1,
"OnRecordDelete": 1,
"OnRecordAfterDeleteError": 1,
},
},
}
for _, scenario := range scenarios {
scenario.Test(t)
}
}
func TestRecordCrudSuperuserCreate(t *testing.T) {
t.Parallel()
body := func() *strings.Reader {
return strings.NewReader(`{
"email": "test_new@example.com",
"password": "1234567890",
"passwordConfirm": "1234567890",
"verified": false
}`)
}
scenarios := []tests.ApiScenario{
{
Name: "guest",
Method: http.MethodPost,
URL: "/api/collections/" + core.CollectionNameSuperusers + "/records",
Body: body(),
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "non-superusers auth",
Method: http.MethodPost,
URL: "/api/collections/" + core.CollectionNameSuperusers + "/records",
Headers: map[string]string{
// users, test@example.com
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
},
Body: body(),
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "superusers auth",
Method: http.MethodPost,
URL: "/api/collections/" + core.CollectionNameSuperusers + "/records",
Headers: map[string]string{
// _superusers, test@example.com
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
},
Body: body(),
ExpectedContent: []string{
`"collectionName":"_superusers"`,
`"email":"test_new@example.com"`,
`"verified":true`,
},
ExpectedStatus: 200,
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordCreateRequest": 1,
"OnRecordEnrich": 1,
"OnModelCreate": 1,
"OnModelCreateExecute": 1,
"OnModelAfterCreateSuccess": 1,
"OnModelValidate": 1,
"OnRecordCreate": 1,
"OnRecordCreateExecute": 1,
"OnRecordAfterCreateSuccess": 1,
"OnRecordValidate": 1,
},
},
}
for _, scenario := range scenarios {
scenario.Test(t)
}
}
func TestRecordCrudSuperuserUpdate(t *testing.T) {
t.Parallel()
body := func() *strings.Reader {
return strings.NewReader(`{
"email": "test_new@example.com",
"verified": true
}`)
}
scenarios := []tests.ApiScenario{
{
Name: "guest",
Method: http.MethodPatch,
URL: "/api/collections/" + core.CollectionNameSuperusers + "/records/sywbhecnh46rhm0",
Body: body(),
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "non-superusers auth",
Method: http.MethodPatch,
URL: "/api/collections/" + core.CollectionNameSuperusers + "/records/sywbhecnh46rhm0",
Headers: map[string]string{
// users, test@example.com
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
},
Body: body(),
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "superusers auth",
Method: http.MethodPatch,
URL: "/api/collections/" + core.CollectionNameSuperusers + "/records/sywbhecnh46rhm0",
Headers: map[string]string{
// _superusers, test@example.com
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
},
Body: body(),
ExpectedContent: []string{
`"collectionName":"_superusers"`,
`"id":"sywbhecnh46rhm0"`,
`"email":"test_new@example.com"`,
`"verified":true`,
},
ExpectedStatus: 200,
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordUpdateRequest": 1,
"OnRecordEnrich": 1,
"OnModelUpdate": 1,
"OnModelUpdateExecute": 1,
"OnModelAfterUpdateSuccess": 1,
"OnModelValidate": 1,
"OnRecordUpdate": 1,
"OnRecordUpdateExecute": 1,
"OnRecordAfterUpdateSuccess": 1,
"OnRecordValidate": 1,
},
},
}
for _, scenario := range scenarios {
scenario.Test(t)
}
}

3584
apis/record_crud_test.go Normal file

File diff suppressed because it is too large Load diff

636
apis/record_helpers.go Normal file
View file

@ -0,0 +1,636 @@
package apis
import (
"database/sql"
"errors"
"fmt"
"net/http"
"strings"
"time"
"github.com/pocketbase/dbx"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/mails"
"github.com/pocketbase/pocketbase/tools/router"
"github.com/pocketbase/pocketbase/tools/routine"
"github.com/pocketbase/pocketbase/tools/search"
"github.com/pocketbase/pocketbase/tools/security"
)
const (
expandQueryParam = "expand"
fieldsQueryParam = "fields"
)
var ErrMFA = errors.New("mfa required")
// RecordAuthResponse writes standardized json record auth response
// into the specified request context.
//
// The authMethod argument specify the name of the current authentication method (eg. password, oauth2, etc.)
// that it is used primarily as an auth identifier during MFA and for login alerts.
//
// Set authMethod to empty string if you want to ignore the MFA checks and the login alerts
// (can be also adjusted additionally via the OnRecordAuthRequest hook).
func RecordAuthResponse(e *core.RequestEvent, authRecord *core.Record, authMethod string, meta any) error {
token, tokenErr := authRecord.NewAuthToken()
if tokenErr != nil {
return e.InternalServerError("Failed to create auth token.", tokenErr)
}
return recordAuthResponse(e, authRecord, token, authMethod, meta)
}
func recordAuthResponse(e *core.RequestEvent, authRecord *core.Record, token string, authMethod string, meta any) error {
originalRequestInfo, err := e.RequestInfo()
if err != nil {
return err
}
ok, err := e.App.CanAccessRecord(authRecord, originalRequestInfo, authRecord.Collection().AuthRule)
if !ok {
return firstApiError(err, e.ForbiddenError("The request doesn't satisfy the collection requirements to authenticate.", err))
}
event := new(core.RecordAuthRequestEvent)
event.RequestEvent = e
event.Collection = authRecord.Collection()
event.Record = authRecord
event.Token = token
event.Meta = meta
event.AuthMethod = authMethod
return e.App.OnRecordAuthRequest().Trigger(event, func(e *core.RecordAuthRequestEvent) error {
if e.Written() {
return nil
}
// MFA
// ---
mfaId, err := checkMFA(e.RequestEvent, e.Record, e.AuthMethod)
if err != nil {
return err
}
// require additional authentication
if mfaId != "" {
// eagerly write the mfa response and return an err so that
// external middlewars are aware that the auth response requires an extra step
e.JSON(http.StatusUnauthorized, map[string]string{
"mfaId": mfaId,
})
return ErrMFA
}
// ---
// create a shallow copy of the cached request data and adjust it to the current auth record
requestInfo := *originalRequestInfo
requestInfo.Auth = e.Record
err = triggerRecordEnrichHooks(e.App, &requestInfo, []*core.Record{e.Record}, func() error {
if e.Record.IsSuperuser() {
e.Record.Unhide(e.Record.Collection().Fields.FieldNames()...)
}
// allow always returning the email address of the authenticated model
e.Record.IgnoreEmailVisibility(true)
// expand record relations
expands := strings.Split(e.Request.URL.Query().Get(expandQueryParam), ",")
if len(expands) > 0 {
failed := e.App.ExpandRecord(e.Record, expands, expandFetch(e.App, &requestInfo))
if len(failed) > 0 {
e.App.Logger().Warn("[recordAuthResponse] Failed to expand relations", "error", failed)
}
}
return nil
})
if err != nil {
return err
}
if e.AuthMethod != "" && authRecord.Collection().AuthAlert.Enabled {
if err = authAlert(e.RequestEvent, e.Record); err != nil {
e.App.Logger().Warn("[recordAuthResponse] Failed to send login alert", "error", err)
}
}
result := struct {
Meta any `json:"meta,omitempty"`
Record *core.Record `json:"record"`
Token string `json:"token"`
}{
Token: e.Token,
Record: e.Record,
}
if e.Meta != nil {
result.Meta = e.Meta
}
return execAfterSuccessTx(true, e.App, func() error {
return e.JSON(http.StatusOK, result)
})
})
}
// wantsMFA checks whether to enable MFA for the specified auth record based on its MFA rule
// (note: returns true even in case of an error as a safer default).
func wantsMFA(e *core.RequestEvent, record *core.Record) (bool, error) {
rule := record.Collection().MFA.Rule
if rule == "" {
return true, nil
}
requestInfo, err := e.RequestInfo()
if err != nil {
return true, err
}
var exists int
query := e.App.RecordQuery(record.Collection()).
Select("(1)").
AndWhere(dbx.HashExp{record.Collection().Name + ".id": record.Id})
// parse and apply the access rule filter
resolver := core.NewRecordFieldResolver(e.App, record.Collection(), requestInfo, true)
expr, err := search.FilterData(rule).BuildExpr(resolver)
if err != nil {
return true, err
}
resolver.UpdateQuery(query)
err = query.AndWhere(expr).Limit(1).Row(&exists)
if err != nil && !errors.Is(err, sql.ErrNoRows) {
return true, err
}
return exists > 0, nil
}
// checkMFA handles any MFA auth checks that needs to be performed for the specified request event.
// Returns the mfaId that needs to be written as response to the user.
//
// (note: all auth methods are treated as equal and there is no requirement for "pairing").
func checkMFA(e *core.RequestEvent, authRecord *core.Record, currentAuthMethod string) (string, error) {
if !authRecord.Collection().MFA.Enabled || currentAuthMethod == "" {
return "", nil
}
ok, err := wantsMFA(e, authRecord)
if err != nil {
return "", e.BadRequestError("Failed to authenticate.", fmt.Errorf("MFA rule failure: %w", err))
}
if !ok {
return "", nil // no mfa needed for this auth record
}
// read the mfaId either from the qyery params or request body
mfaId := e.Request.URL.Query().Get("mfaId")
if mfaId == "" {
// check the body
data := struct {
MfaId string `form:"mfaId" json:"mfaId" xml:"mfaId"`
}{}
if err := e.BindBody(&data); err != nil {
return "", firstApiError(err, e.BadRequestError("Failed to read MFA Id", err))
}
mfaId = data.MfaId
}
// first-time auth
// ---
if mfaId == "" {
mfa := core.NewMFA(e.App)
mfa.SetCollectionRef(authRecord.Collection().Id)
mfa.SetRecordRef(authRecord.Id)
mfa.SetMethod(currentAuthMethod)
if err := e.App.Save(mfa); err != nil {
return "", firstApiError(err, e.InternalServerError("Failed to create MFA record", err))
}
return mfa.Id, nil
}
// second-time auth
// ---
mfa, err := e.App.FindMFAById(mfaId)
deleteMFA := func() {
// try to delete the expired mfa
if mfa != nil {
if deleteErr := e.App.Delete(mfa); deleteErr != nil {
e.App.Logger().Warn("Failed to delete expired MFA record", "error", deleteErr, "mfaId", mfa.Id)
}
}
}
if err != nil || mfa.HasExpired(authRecord.Collection().MFA.DurationTime()) {
deleteMFA()
return "", e.BadRequestError("Invalid or expired MFA session.", err)
}
if mfa.RecordRef() != authRecord.Id || mfa.CollectionRef() != authRecord.Collection().Id {
return "", e.BadRequestError("Invalid MFA session.", nil)
}
if mfa.Method() == currentAuthMethod {
return "", e.BadRequestError("A different authentication method is required.", nil)
}
deleteMFA()
return "", nil
}
// EnrichRecord parses the request context and enrich the provided record:
// - expands relations (if defaultExpands and/or ?expand query param is set)
// - ensures that the emails of the auth record and its expanded auth relations
// are visible only for the current logged superuser, record owner or record with manage access
func EnrichRecord(e *core.RequestEvent, record *core.Record, defaultExpands ...string) error {
return EnrichRecords(e, []*core.Record{record}, defaultExpands...)
}
// EnrichRecords parses the request context and enriches the provided records:
// - expands relations (if defaultExpands and/or ?expand query param is set)
// - ensures that the emails of the auth records and their expanded auth relations
// are visible only for the current logged superuser, record owner or record with manage access
//
// Note: Expects all records to be from the same collection!
func EnrichRecords(e *core.RequestEvent, records []*core.Record, defaultExpands ...string) error {
if len(records) == 0 {
return nil
}
info, err := e.RequestInfo()
if err != nil {
return err
}
return triggerRecordEnrichHooks(e.App, info, records, func() error {
expands := defaultExpands
if param := info.Query[expandQueryParam]; param != "" {
expands = append(expands, strings.Split(param, ",")...)
}
err := defaultEnrichRecords(e.App, info, records, expands...)
if err != nil {
// only log because it is not critical
e.App.Logger().Warn("failed to apply default enriching", "error", err)
}
return nil
})
}
type iterator[T any] struct {
items []T
index int
}
func (ri *iterator[T]) next() T {
var item T
if ri.index < len(ri.items) {
item = ri.items[ri.index]
ri.index++
}
return item
}
func triggerRecordEnrichHooks(app core.App, requestInfo *core.RequestInfo, records []*core.Record, finalizer func() error) error {
it := iterator[*core.Record]{items: records}
enrichHook := app.OnRecordEnrich()
event := new(core.RecordEnrichEvent)
event.App = app
event.RequestInfo = requestInfo
var iterate func(record *core.Record) error
iterate = func(record *core.Record) error {
if record == nil {
return nil
}
event.Record = record
return enrichHook.Trigger(event, func(ee *core.RecordEnrichEvent) error {
next := it.next()
if next == nil {
if finalizer != nil {
return finalizer()
}
return nil
}
event.App = ee.App // in case it was replaced with a transaction
event.Record = next
err := iterate(next)
event.App = app
event.Record = record
return err
})
}
return iterate(it.next())
}
func defaultEnrichRecords(app core.App, requestInfo *core.RequestInfo, records []*core.Record, expands ...string) error {
err := autoResolveRecordsFlags(app, records, requestInfo)
if err != nil {
return fmt.Errorf("failed to resolve records flags: %w", err)
}
if len(expands) > 0 {
expandErrs := app.ExpandRecords(records, expands, expandFetch(app, requestInfo))
if len(expandErrs) > 0 {
errsSlice := make([]error, 0, len(expandErrs))
for key, err := range expandErrs {
errsSlice = append(errsSlice, fmt.Errorf("failed to expand %q: %w", key, err))
}
return fmt.Errorf("failed to expand records: %w", errors.Join(errsSlice...))
}
}
return nil
}
// expandFetch is the records fetch function that is used to expand related records.
func expandFetch(app core.App, originalRequestInfo *core.RequestInfo) core.ExpandFetchFunc {
// shallow clone the provided request info to set an "expand" context
requestInfoClone := *originalRequestInfo
requestInfoPtr := &requestInfoClone
requestInfoPtr.Context = core.RequestInfoContextExpand
return func(relCollection *core.Collection, relIds []string) ([]*core.Record, error) {
records, findErr := app.FindRecordsByIds(relCollection.Id, relIds, func(q *dbx.SelectQuery) error {
if requestInfoPtr.Auth != nil && requestInfoPtr.Auth.IsSuperuser() {
return nil // superusers can access everything
}
if relCollection.ViewRule == nil {
return fmt.Errorf("only superusers can view collection %q records", relCollection.Name)
}
if *relCollection.ViewRule != "" {
resolver := core.NewRecordFieldResolver(app, relCollection, requestInfoPtr, true)
expr, err := search.FilterData(*(relCollection.ViewRule)).BuildExpr(resolver)
if err != nil {
return err
}
resolver.UpdateQuery(q)
q.AndWhere(expr)
}
return nil
})
if findErr != nil {
return nil, findErr
}
enrichErr := triggerRecordEnrichHooks(app, requestInfoPtr, records, func() error {
if err := autoResolveRecordsFlags(app, records, requestInfoPtr); err != nil {
// non-critical error
app.Logger().Warn("Failed to apply autoResolveRecordsFlags for the expanded records", "error", err)
}
return nil
})
if enrichErr != nil {
return nil, enrichErr
}
return records, nil
}
}
// autoResolveRecordsFlags resolves various visibility flags of the provided records.
//
// Currently it enables:
// - export of hidden fields if the current auth model is a superuser
// - email export ignoring the emailVisibity checks if the current auth model is superuser, owner or a "manager".
//
// Note: Expects all records to be from the same collection!
func autoResolveRecordsFlags(app core.App, records []*core.Record, requestInfo *core.RequestInfo) error {
if len(records) == 0 {
return nil // nothing to resolve
}
if requestInfo.HasSuperuserAuth() {
hiddenFields := records[0].Collection().Fields.FieldNames()
for _, rec := range records {
rec.Unhide(hiddenFields...)
rec.IgnoreEmailVisibility(true)
}
}
// additional emailVisibility checks
// ---------------------------------------------------------------
if !records[0].Collection().IsAuth() {
return nil // not auth collection records
}
collection := records[0].Collection()
mappedRecords := make(map[string]*core.Record, len(records))
recordIds := make([]any, len(records))
for i, rec := range records {
mappedRecords[rec.Id] = rec
recordIds[i] = rec.Id
}
if requestInfo.Auth != nil && mappedRecords[requestInfo.Auth.Id] != nil {
mappedRecords[requestInfo.Auth.Id].IgnoreEmailVisibility(true)
}
if collection.ManageRule == nil || *collection.ManageRule == "" {
return nil // no manage rule to check
}
// fetch the ids of the managed records
// ---
managedIds := []string{}
query := app.RecordQuery(collection).
Select(app.ConcurrentDB().QuoteSimpleColumnName(collection.Name) + ".id").
AndWhere(dbx.In(app.ConcurrentDB().QuoteSimpleColumnName(collection.Name)+".id", recordIds...))
resolver := core.NewRecordFieldResolver(app, collection, requestInfo, true)
expr, err := search.FilterData(*collection.ManageRule).BuildExpr(resolver)
if err != nil {
return err
}
resolver.UpdateQuery(query)
query.AndWhere(expr)
if err := query.Column(&managedIds); err != nil {
return err
}
// ---
// ignore the email visibility check for the managed records
for _, id := range managedIds {
if rec, ok := mappedRecords[id]; ok {
rec.IgnoreEmailVisibility(true)
}
}
return nil
}
var ruleQueryParams = []string{search.FilterQueryParam, search.SortQueryParam}
var superuserOnlyRuleFields = []string{"@collection.", "@request."}
// checkForSuperuserOnlyRuleFields loosely checks and returns an error if
// the provided RequestInfo contains rule fields that only the superuser can use.
func checkForSuperuserOnlyRuleFields(requestInfo *core.RequestInfo) error {
if len(requestInfo.Query) == 0 || requestInfo.HasSuperuserAuth() {
return nil // superuser or nothing to check
}
for _, param := range ruleQueryParams {
v := requestInfo.Query[param]
if v == "" {
continue
}
for _, field := range superuserOnlyRuleFields {
if strings.Contains(v, field) {
return router.NewForbiddenError("Only superusers can filter by "+field, nil)
}
}
}
return nil
}
// firstApiError returns the first ApiError from the errors list
// (this is used usually to prevent unnecessary wraping and to allow bubling ApiError from nested hooks)
//
// If no ApiError is found, returns a default "Internal server" error.
func firstApiError(errs ...error) *router.ApiError {
var apiErr *router.ApiError
var ok bool
for _, err := range errs {
if err == nil {
continue
}
// quick assert to avoid the reflection checks
apiErr, ok = err.(*router.ApiError)
if ok {
return apiErr
}
// nested/wrapped errors
if errors.As(err, &apiErr) {
return apiErr
}
}
return router.NewInternalServerError("", errors.Join(errs...))
}
// execAfterSuccessTx ensures that fn is executed only after a succesul transaction.
//
// If the current app instance is not a transactional or checkTx is false,
// then fn is directly executed.
//
// It could be usually used to allow propagating an error or writing
// custom response from within the wrapped transaction block.
func execAfterSuccessTx(checkTx bool, app core.App, fn func() error) error {
if txInfo := app.TxInfo(); txInfo != nil && checkTx {
txInfo.OnComplete(func(txErr error) error {
if txErr == nil {
return fn()
}
return nil
})
return nil
}
return fn()
}
// -------------------------------------------------------------------
const maxAuthOrigins = 5
func authAlert(e *core.RequestEvent, authRecord *core.Record) error {
// generating fingerprint
// ---
userAgent := e.Request.UserAgent()
if len(userAgent) > 300 {
userAgent = userAgent[:300]
}
fingerprint := security.MD5(e.RealIP() + userAgent)
// ---
origins, err := e.App.FindAllAuthOriginsByRecord(authRecord)
if err != nil {
return err
}
isFirstLogin := len(origins) == 0
var currentOrigin *core.AuthOrigin
for _, origin := range origins {
if origin.Fingerprint() == fingerprint {
currentOrigin = origin
break
}
}
if currentOrigin == nil {
currentOrigin = core.NewAuthOrigin(e.App)
currentOrigin.SetCollectionRef(authRecord.Collection().Id)
currentOrigin.SetRecordRef(authRecord.Id)
currentOrigin.SetFingerprint(fingerprint)
}
// send email alert for the new origin auth (skip first login)
//
// Note: The "fake" timeout is a temp solution to avoid blocking
// for too long when the SMTP server is not accessible, due
// to the lack of context cancellation support in the underlying
// mailer and net/smtp package.
// The goroutine technically "leaks" but we assume that the OS will
// terminate the connection after some time (usually after 3-4 mins).
if !isFirstLogin && currentOrigin.IsNew() && authRecord.Email() != "" {
mailSent := make(chan error, 1)
timer := time.AfterFunc(15*time.Second, func() {
mailSent <- errors.New("auth alert mail send wait timeout reached")
})
routine.FireAndForget(func() {
err := mails.SendRecordAuthAlert(e.App, authRecord)
timer.Stop()
mailSent <- err
})
err = <-mailSent
if err != nil {
return err
}
}
// try to keep only up to maxAuthOrigins
// (pop the last used ones; it is not executed in a transaction to avoid unnecessary locks)
if currentOrigin.IsNew() && len(origins) >= maxAuthOrigins {
for i := len(origins) - 1; i >= maxAuthOrigins-1; i-- {
if err := e.App.Delete(origins[i]); err != nil {
// treat as non-critical error, just log for now
e.App.Logger().Warn("Failed to delete old AuthOrigin record", "error", err, "authOriginId", origins[i].Id)
}
}
}
// create/update the origin fingerprint
return e.App.Save(currentOrigin)
}

761
apis/record_helpers_test.go Normal file
View file

@ -0,0 +1,761 @@
package apis_test
import (
"encoding/json"
"errors"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/pocketbase/pocketbase/apis"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tests"
"github.com/pocketbase/pocketbase/tools/router"
"github.com/pocketbase/pocketbase/tools/types"
)
func TestEnrichRecords(t *testing.T) {
t.Parallel()
// mock test data
// ---
app, _ := tests.NewTestApp()
defer app.Cleanup()
freshRecords := func(records []*core.Record) []*core.Record {
result := make([]*core.Record, len(records))
for i, r := range records {
result[i] = r.Fresh()
}
return result
}
user, err := app.FindAuthRecordByEmail("users", "test@example.com")
if err != nil {
t.Fatal(err)
}
superuser, err := app.FindAuthRecordByEmail(core.CollectionNameSuperusers, "test@example.com")
if err != nil {
t.Fatal(err)
}
usersRecords, err := app.FindRecordsByIds("users", []string{"4q1xlclmfloku33", "bgs820n361vj1qd"})
if err != nil {
t.Fatal(err)
}
nologinRecords, err := app.FindRecordsByIds("nologin", []string{"dc49k6jgejn40h3", "oos036e9xvqeexy"})
if err != nil {
t.Fatal(err)
}
demo1Records, err := app.FindRecordsByIds("demo1", []string{"al1h9ijdeojtsjy", "84nmscqy84lsi1t"})
if err != nil {
t.Fatal(err)
}
demo5Records, err := app.FindRecordsByIds("demo5", []string{"la4y2w4o98acwuj", "qjeql998mtp1azp"})
if err != nil {
t.Fatal(err)
}
// temp update the view rule to ensure that request context is set to "expand"
demo4, err := app.FindCollectionByNameOrId("demo4")
if err != nil {
t.Fatal(err)
}
demo4.ViewRule = types.Pointer("@request.context = 'expand'")
if err := app.Save(demo4); err != nil {
t.Fatal(err)
}
// ---
scenarios := []struct {
name string
auth *core.Record
records []*core.Record
queryExpand string
defaultExpands []string
expected []string
notExpected []string
}{
// email visibility checks
{
name: "[emailVisibility] guest",
auth: nil,
records: freshRecords(usersRecords),
queryExpand: "",
defaultExpands: nil,
expected: []string{
`"customField":"123"`,
`"test3@example.com"`, // emailVisibility=true
},
notExpected: []string{
`"test@example.com"`,
},
},
{
name: "[emailVisibility] owner",
auth: user,
records: freshRecords(usersRecords),
queryExpand: "",
defaultExpands: nil,
expected: []string{
`"customField":"123"`,
`"test3@example.com"`, // emailVisibility=true
`"test@example.com"`, // owner
},
},
{
name: "[emailVisibility] manager",
auth: user,
records: freshRecords(nologinRecords),
queryExpand: "",
defaultExpands: nil,
expected: []string{
`"customField":"123"`,
`"test3@example.com"`,
`"test@example.com"`,
},
},
{
name: "[emailVisibility] superuser",
auth: superuser,
records: freshRecords(nologinRecords),
queryExpand: "",
defaultExpands: nil,
expected: []string{
`"customField":"123"`,
`"test3@example.com"`,
`"test@example.com"`,
},
},
{
name: "[emailVisibility + expand] recursive auth rule checks (regular user)",
auth: user,
records: freshRecords(demo1Records),
queryExpand: "",
defaultExpands: []string{"rel_many"},
expected: []string{
`"customField":"123"`,
`"expand":{"rel_many"`,
`"expand":{}`,
`"test@example.com"`,
},
notExpected: []string{
`"id":"bgs820n361vj1qd"`,
`"id":"oap640cot4yru2s"`,
},
},
{
name: "[emailVisibility + expand] recursive auth rule checks (superuser)",
auth: superuser,
records: freshRecords(demo1Records),
queryExpand: "",
defaultExpands: []string{"rel_many"},
expected: []string{
`"customField":"123"`,
`"test@example.com"`,
`"expand":{"rel_many"`,
`"id":"bgs820n361vj1qd"`,
`"id":"4q1xlclmfloku33"`,
`"id":"oap640cot4yru2s"`,
},
notExpected: []string{
`"expand":{}`,
},
},
// expand checks
{
name: "[expand] guest (query)",
auth: nil,
records: freshRecords(usersRecords),
queryExpand: "rel",
defaultExpands: nil,
expected: []string{
`"customField":"123"`,
`"expand":{"rel"`,
`"id":"llvuca81nly1qls"`,
`"id":"0yxhwia2amd8gec"`,
},
notExpected: []string{
`"expand":{}`,
},
},
{
name: "[expand] guest (default expands)",
auth: nil,
records: freshRecords(usersRecords),
queryExpand: "",
defaultExpands: []string{"rel"},
expected: []string{
`"customField":"123"`,
`"expand":{"rel"`,
`"id":"llvuca81nly1qls"`,
`"id":"0yxhwia2amd8gec"`,
},
},
{
name: "[expand] @request.context=expand check",
auth: nil,
records: freshRecords(demo5Records),
queryExpand: "rel_one",
defaultExpands: []string{"rel_many"},
expected: []string{
`"customField":"123"`,
`"expand":{}`,
`"expand":{"`,
`"rel_many":[{`,
`"rel_one":{`,
`"id":"i9naidtvr6qsgb4"`,
`"id":"qzaqccwrmva4o1n"`,
},
},
}
for _, s := range scenarios {
t.Run(s.name, func(t *testing.T) {
app, _ := tests.NewTestApp()
defer app.Cleanup()
app.OnRecordEnrich().BindFunc(func(e *core.RecordEnrichEvent) error {
e.Record.WithCustomData(true)
e.Record.Set("customField", "123")
return e.Next()
})
req := httptest.NewRequest(http.MethodGet, "/?expand="+s.queryExpand, nil)
rec := httptest.NewRecorder()
requestEvent := new(core.RequestEvent)
requestEvent.App = app
requestEvent.Request = req
requestEvent.Response = rec
requestEvent.Auth = s.auth
err := apis.EnrichRecords(requestEvent, s.records, s.defaultExpands...)
if err != nil {
t.Fatal(err)
}
raw, err := json.Marshal(s.records)
if err != nil {
t.Fatal(err)
}
rawStr := string(raw)
for _, str := range s.expected {
if !strings.Contains(rawStr, str) {
t.Fatalf("Expected\n%q\nin\n%v", str, rawStr)
}
}
for _, str := range s.notExpected {
if strings.Contains(rawStr, str) {
t.Fatalf("Didn't expected\n%q\nin\n%v", str, rawStr)
}
}
})
}
}
func TestRecordAuthResponseAuthRuleCheck(t *testing.T) {
app, _ := tests.NewTestApp()
defer app.Cleanup()
event := new(core.RequestEvent)
event.App = app
event.Request = httptest.NewRequest(http.MethodGet, "/", nil)
event.Response = httptest.NewRecorder()
user, err := app.FindAuthRecordByEmail("users", "test@example.com")
if err != nil {
t.Fatal(err)
}
scenarios := []struct {
name string
rule *string
expectError bool
}{
{
"admin only rule",
nil,
true,
},
{
"empty rule",
types.Pointer(""),
false,
},
{
"false rule",
types.Pointer("1=2"),
true,
},
{
"true rule",
types.Pointer("1=1"),
false,
},
}
for _, s := range scenarios {
t.Run(s.name, func(t *testing.T) {
user.Collection().AuthRule = s.rule
err := apis.RecordAuthResponse(event, user, "", nil)
hasErr := err != nil
if s.expectError != hasErr {
t.Fatalf("Expected hasErr %v, got %v (%v)", s.expectError, hasErr, err)
}
// in all cases login alert shouldn't be send because of the empty auth method
if app.TestMailer.TotalSend() != 0 {
t.Fatalf("Expected no emails send, got %d:\n%v", app.TestMailer.TotalSend(), app.TestMailer.LastMessage().HTML)
}
if !hasErr {
return
}
apiErr, ok := err.(*router.ApiError)
if !ok || apiErr == nil {
t.Fatalf("Expected ApiError, got %v", apiErr)
}
if apiErr.Status != http.StatusForbidden {
t.Fatalf("Expected ApiError.Status %d, got %d", http.StatusForbidden, apiErr.Status)
}
})
}
}
func TestRecordAuthResponseAuthAlertCheck(t *testing.T) {
const testFingerprint = "d0f88d6c87767262ba8e93d6acccd784"
scenarios := []struct {
name string
devices []string // mock existing device fingerprints
expectDevices []string
enabled bool
expectEmail bool
}{
{
name: "first login",
devices: nil,
expectDevices: []string{testFingerprint},
enabled: true,
expectEmail: false,
},
{
name: "existing device",
devices: []string{"1", testFingerprint},
expectDevices: []string{"1", testFingerprint},
enabled: true,
expectEmail: false,
},
{
name: "new device (< 5)",
devices: []string{"1", "2"},
expectDevices: []string{"1", "2", testFingerprint},
enabled: true,
expectEmail: true,
},
{
name: "new device (>= 5)",
devices: []string{"1", "2", "3", "4", "5"},
expectDevices: []string{"2", "3", "4", "5", testFingerprint},
enabled: true,
expectEmail: true,
},
{
name: "with disabled auth alert collection flag",
devices: []string{"1", "2"},
expectDevices: []string{"1", "2"},
enabled: false,
expectEmail: false,
},
}
for _, s := range scenarios {
t.Run(s.name, func(t *testing.T) {
app, _ := tests.NewTestApp()
defer app.Cleanup()
event := new(core.RequestEvent)
event.App = app
event.Request = httptest.NewRequest(http.MethodGet, "/", nil)
event.Response = httptest.NewRecorder()
user, err := app.FindAuthRecordByEmail("users", "test@example.com")
if err != nil {
t.Fatal(err)
}
user.Collection().MFA.Enabled = false
user.Collection().AuthRule = types.Pointer("")
user.Collection().AuthAlert.Enabled = s.enabled
// ensure that there are no other auth origins
err = app.DeleteAllAuthOriginsByRecord(user)
if err != nil {
t.Fatal(err)
}
mockCreated := types.NowDateTime().Add(-time.Duration(len(s.devices)+1) * time.Second)
// insert the mock devices
for _, fingerprint := range s.devices {
mockCreated = mockCreated.Add(1 * time.Second)
d := core.NewAuthOrigin(app)
d.SetCollectionRef(user.Collection().Id)
d.SetRecordRef(user.Id)
d.SetFingerprint(fingerprint)
d.SetRaw("created", mockCreated)
d.SetRaw("updated", mockCreated)
if err = app.Save(d); err != nil {
t.Fatal(err)
}
}
err = apis.RecordAuthResponse(event, user, "example", nil)
if err != nil {
t.Fatalf("Failed to resolve auth response: %v", err)
}
var expectTotalSend int
if s.expectEmail {
expectTotalSend = 1
}
if total := app.TestMailer.TotalSend(); total != expectTotalSend {
t.Fatalf("Expected %d sent emails, got %d", expectTotalSend, total)
}
devices, err := app.FindAllAuthOriginsByRecord(user)
if err != nil {
t.Fatalf("Failed to retrieve auth origins: %v", err)
}
if len(devices) != len(s.expectDevices) {
t.Fatalf("Expected %d devices, got %d", len(s.expectDevices), len(devices))
}
for _, fingerprint := range s.expectDevices {
var exists bool
fingerprints := make([]string, 0, len(devices))
for _, d := range devices {
if d.Fingerprint() == fingerprint {
exists = true
break
}
fingerprints = append(fingerprints, d.Fingerprint())
}
if !exists {
t.Fatalf("Missing device with fingerprint %q:\n%v", fingerprint, fingerprints)
}
}
})
}
}
func TestRecordAuthResponseMFACheck(t *testing.T) {
app, _ := tests.NewTestApp()
defer app.Cleanup()
user, err := app.FindAuthRecordByEmail("users", "test@example.com")
if err != nil {
t.Fatal(err)
}
user2, err := app.FindAuthRecordByEmail("users", "test2@example.com")
if err != nil {
t.Fatal(err)
}
rec := httptest.NewRecorder()
event := new(core.RequestEvent)
event.App = app
event.Request = httptest.NewRequest(http.MethodGet, "/", nil)
event.Response = rec
resetMFAs := func(authRecord *core.Record) {
// ensure that mfa is enabled
user.Collection().MFA.Enabled = true
user.Collection().MFA.Duration = 5
user.Collection().MFA.Rule = ""
mfas, err := app.FindAllMFAsByRecord(authRecord)
if err != nil {
t.Fatalf("Failed to retrieve mfas: %v", err)
}
for _, mfa := range mfas {
if err := app.Delete(mfa); err != nil {
t.Fatalf("Failed to delete mfa %q: %v", mfa.Id, err)
}
}
// reset response
rec = httptest.NewRecorder()
event.Response = rec
}
totalMFAs := func(authRecord *core.Record) int {
mfas, err := app.FindAllMFAsByRecord(authRecord)
if err != nil {
t.Fatalf("Failed to retrieve mfas: %v", err)
}
return len(mfas)
}
t.Run("no collection MFA enabled", func(t *testing.T) {
resetMFAs(user)
user.Collection().MFA.Enabled = false
err = apis.RecordAuthResponse(event, user, "example", nil)
if err != nil {
t.Fatalf("Expected nil, got error: %v", err)
}
body := rec.Body.String()
if strings.Contains(body, "mfaId") {
t.Fatalf("Expected no mfaId in the response body, got\n%v", body)
}
if !strings.Contains(body, "token") {
t.Fatalf("Expected auth token in the response body, got\n%v", body)
}
if total := totalMFAs(user); total != 0 {
t.Fatalf("Expected no mfa records to be created, got %d", total)
}
})
t.Run("no explicit auth method", func(t *testing.T) {
resetMFAs(user)
err = apis.RecordAuthResponse(event, user, "", nil)
if err != nil {
t.Fatalf("Expected nil, got error: %v", err)
}
body := rec.Body.String()
if strings.Contains(body, "mfaId") {
t.Fatalf("Expected no mfaId in the response body, got\n%v", body)
}
if !strings.Contains(body, "token") {
t.Fatalf("Expected auth token in the response body, got\n%v", body)
}
if total := totalMFAs(user); total != 0 {
t.Fatalf("Expected no mfa records to be created, got %d", total)
}
})
t.Run("no mfa wanted (mfa rule check failure)", func(t *testing.T) {
resetMFAs(user)
user.Collection().MFA.Rule = "1=2"
err = apis.RecordAuthResponse(event, user, "example", nil)
if err != nil {
t.Fatalf("Expected nil, got error: %v", err)
}
body := rec.Body.String()
if strings.Contains(body, "mfaId") {
t.Fatalf("Expected no mfaId in the response body, got\n%v", body)
}
if !strings.Contains(body, "token") {
t.Fatalf("Expected auth token in the response body, got\n%v", body)
}
if total := totalMFAs(user); total != 0 {
t.Fatalf("Expected no mfa records to be created, got %d", total)
}
})
t.Run("mfa wanted (mfa rule check success)", func(t *testing.T) {
resetMFAs(user)
user.Collection().MFA.Rule = "1=1"
err = apis.RecordAuthResponse(event, user, "example", nil)
if !errors.Is(err, apis.ErrMFA) {
t.Fatalf("Expected ErrMFA, got: %v", err)
}
body := rec.Body.String()
if !strings.Contains(body, "mfaId") {
t.Fatalf("Expected the created mfaId to be returned in the response body, got\n%v", body)
}
if total := totalMFAs(user); total != 1 {
t.Fatalf("Expected a single mfa record to be created, got %d", total)
}
})
t.Run("mfa first-time", func(t *testing.T) {
resetMFAs(user)
err = apis.RecordAuthResponse(event, user, "example", nil)
if !errors.Is(err, apis.ErrMFA) {
t.Fatalf("Expected ErrMFA, got: %v", err)
}
body := rec.Body.String()
if !strings.Contains(body, "mfaId") {
t.Fatalf("Expected the created mfaId to be returned in the response body, got\n%v", body)
}
if total := totalMFAs(user); total != 1 {
t.Fatalf("Expected a single mfa record to be created, got %d", total)
}
})
t.Run("mfa second-time with the same auth method", func(t *testing.T) {
resetMFAs(user)
// create a dummy mfa record
mfa := core.NewMFA(app)
mfa.SetCollectionRef(user.Collection().Id)
mfa.SetRecordRef(user.Id)
mfa.SetMethod("example")
if err = app.Save(mfa); err != nil {
t.Fatal(err)
}
event.Request = httptest.NewRequest(http.MethodGet, "/?mfaId="+mfa.Id, nil)
err = apis.RecordAuthResponse(event, user, "example", nil)
if err == nil {
t.Fatal("Expected error, got nil")
}
if total := totalMFAs(user); total != 1 {
t.Fatalf("Expected only 1 mfa record (the existing one), got %d", total)
}
})
t.Run("mfa second-time with the different auth method (query param)", func(t *testing.T) {
resetMFAs(user)
// create a dummy mfa record
mfa := core.NewMFA(app)
mfa.SetCollectionRef(user.Collection().Id)
mfa.SetRecordRef(user.Id)
mfa.SetMethod("example1")
if err = app.Save(mfa); err != nil {
t.Fatal(err)
}
event.Request = httptest.NewRequest(http.MethodGet, "/?mfaId="+mfa.Id, nil)
err = apis.RecordAuthResponse(event, user, "example2", nil)
if err != nil {
t.Fatalf("Expected nil, got error: %v", err)
}
if total := totalMFAs(user); total != 0 {
t.Fatalf("Expected the dummy mfa record to be deleted, found %d", total)
}
})
t.Run("mfa second-time with the different auth method (body param)", func(t *testing.T) {
resetMFAs(user)
// create a dummy mfa record
mfa := core.NewMFA(app)
mfa.SetCollectionRef(user.Collection().Id)
mfa.SetRecordRef(user.Id)
mfa.SetMethod("example1")
if err = app.Save(mfa); err != nil {
t.Fatal(err)
}
event.Request = httptest.NewRequest(http.MethodGet, "/", strings.NewReader(`{"mfaId":"`+mfa.Id+`"}`))
event.Request.Header.Add("content-type", "application/json")
err = apis.RecordAuthResponse(event, user, "example2", nil)
if err != nil {
t.Fatalf("Expected nil, got error: %v", err)
}
if total := totalMFAs(user); total != 0 {
t.Fatalf("Expected the dummy mfa record to be deleted, found %d", total)
}
})
t.Run("missing mfa", func(t *testing.T) {
resetMFAs(user)
event.Request = httptest.NewRequest(http.MethodGet, "/?mfaId=missing", nil)
err = apis.RecordAuthResponse(event, user, "example2", nil)
if err == nil {
t.Fatal("Expected error, got nil")
}
if total := totalMFAs(user); total != 0 {
t.Fatalf("Expected 0 mfa records, got %d", total)
}
})
t.Run("expired mfa", func(t *testing.T) {
resetMFAs(user)
// create a dummy expired mfa record
mfa := core.NewMFA(app)
mfa.SetCollectionRef(user.Collection().Id)
mfa.SetRecordRef(user.Id)
mfa.SetMethod("example1")
mfa.SetRaw("created", types.NowDateTime().Add(-1*time.Hour))
mfa.SetRaw("updated", types.NowDateTime().Add(-1*time.Hour))
if err = app.Save(mfa); err != nil {
t.Fatal(err)
}
event.Request = httptest.NewRequest(http.MethodGet, "/?mfaId="+mfa.Id, nil)
err = apis.RecordAuthResponse(event, user, "example2", nil)
if err == nil {
t.Fatal("Expected error, got nil")
}
if totalMFAs(user) != 0 {
t.Fatal("Expected the expired mfa record to be deleted")
}
})
t.Run("mfa for different auth record", func(t *testing.T) {
resetMFAs(user)
// create a dummy expired mfa record
mfa := core.NewMFA(app)
mfa.SetCollectionRef(user2.Collection().Id)
mfa.SetRecordRef(user2.Id)
mfa.SetMethod("example1")
if err = app.Save(mfa); err != nil {
t.Fatal(err)
}
event.Request = httptest.NewRequest(http.MethodGet, "/?mfaId="+mfa.Id, nil)
err = apis.RecordAuthResponse(event, user, "example2", nil)
if err == nil {
t.Fatal("Expected error, got nil")
}
if total := totalMFAs(user); total != 0 {
t.Fatalf("Expected no user mfas, got %d", total)
}
if total := totalMFAs(user2); total != 1 {
t.Fatalf("Expected only 1 user2 mfa, got %d", total)
}
})
}

318
apis/serve.go Normal file
View file

@ -0,0 +1,318 @@
package apis
import (
"context"
"crypto/tls"
"errors"
"log"
"net"
"net/http"
"path/filepath"
"strings"
"sync"
"time"
"github.com/fatih/color"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tools/hook"
"github.com/pocketbase/pocketbase/tools/list"
"github.com/pocketbase/pocketbase/tools/routine"
"github.com/pocketbase/pocketbase/ui"
"golang.org/x/crypto/acme"
"golang.org/x/crypto/acme/autocert"
)
// ServeConfig defines a configuration struct for apis.Serve().
type ServeConfig struct {
// ShowStartBanner indicates whether to show or hide the server start console message.
ShowStartBanner bool
// HttpAddr is the TCP address to listen for the HTTP server (eg. "127.0.0.1:80").
HttpAddr string
// HttpsAddr is the TCP address to listen for the HTTPS server (eg. "127.0.0.1:443").
HttpsAddr string
// Optional domains list to use when issuing the TLS certificate.
//
// If not set, the host from the bound server address will be used.
//
// For convenience, for each "non-www" domain a "www" entry and
// redirect will be automatically added.
CertificateDomains []string
// AllowedOrigins is an optional list of CORS origins (default to "*").
AllowedOrigins []string
}
// Serve starts a new app web server.
//
// NB! The app should be bootstrapped before starting the web server.
//
// Example:
//
// app.Bootstrap()
// apis.Serve(app, apis.ServeConfig{
// HttpAddr: "127.0.0.1:8080",
// ShowStartBanner: false,
// })
func Serve(app core.App, config ServeConfig) error {
if len(config.AllowedOrigins) == 0 {
config.AllowedOrigins = []string{"*"}
}
// ensure that the latest migrations are applied before starting the server
err := app.RunAllMigrations()
if err != nil {
return err
}
pbRouter, err := NewRouter(app)
if err != nil {
return err
}
pbRouter.Bind(CORS(CORSConfig{
AllowOrigins: config.AllowedOrigins,
AllowMethods: []string{http.MethodGet, http.MethodHead, http.MethodPut, http.MethodPatch, http.MethodPost, http.MethodDelete},
}))
pbRouter.GET("/_/{path...}", Static(ui.DistDirFS, false)).
BindFunc(func(e *core.RequestEvent) error {
// ignore root path
if e.Request.PathValue(StaticWildcardParam) != "" {
e.Response.Header().Set("Cache-Control", "max-age=1209600, stale-while-revalidate=86400")
}
// add a default CSP
if e.Response.Header().Get("Content-Security-Policy") == "" {
e.Response.Header().Set("Content-Security-Policy", "default-src 'self'; style-src 'self' 'unsafe-inline'; img-src 'self' http://127.0.0.1:* https://tile.openstreetmap.org data: blob:; connect-src 'self' http://127.0.0.1:* https://nominatim.openstreetmap.org; script-src 'self' 'sha256-GRUzBA7PzKYug7pqxv5rJaec5bwDCw1Vo6/IXwvD3Tc='")
}
return e.Next()
}).
Bind(Gzip())
// start http server
// ---
mainAddr := config.HttpAddr
if config.HttpsAddr != "" {
mainAddr = config.HttpsAddr
}
var wwwRedirects []string
// extract the host names for the certificate host policy
hostNames := config.CertificateDomains
if len(hostNames) == 0 {
host, _, _ := net.SplitHostPort(mainAddr)
hostNames = append(hostNames, host)
}
for _, host := range hostNames {
if strings.HasPrefix(host, "www.") {
continue // explicitly set www host
}
wwwHost := "www." + host
if !list.ExistInSlice(wwwHost, hostNames) {
hostNames = append(hostNames, wwwHost)
wwwRedirects = append(wwwRedirects, wwwHost)
}
}
// implicit www->non-www redirect(s)
if len(wwwRedirects) > 0 {
pbRouter.Bind(wwwRedirect(wwwRedirects))
}
certManager := &autocert.Manager{
Prompt: autocert.AcceptTOS,
Cache: autocert.DirCache(filepath.Join(app.DataDir(), core.LocalAutocertCacheDirName)),
HostPolicy: autocert.HostWhitelist(hostNames...),
}
// base request context used for cancelling long running requests
// like the SSE connections
baseCtx, cancelBaseCtx := context.WithCancel(context.Background())
defer cancelBaseCtx()
server := &http.Server{
TLSConfig: &tls.Config{
MinVersion: tls.VersionTLS12,
GetCertificate: certManager.GetCertificate,
NextProtos: []string{acme.ALPNProto},
},
// higher defaults to accommodate large file uploads/downloads
WriteTimeout: 5 * time.Minute,
ReadTimeout: 5 * time.Minute,
ReadHeaderTimeout: 1 * time.Minute,
Addr: mainAddr,
BaseContext: func(l net.Listener) context.Context {
return baseCtx
},
ErrorLog: log.New(&serverErrorLogWriter{app: app}, "", 0),
}
serveEvent := new(core.ServeEvent)
serveEvent.App = app
serveEvent.Router = pbRouter
serveEvent.Server = server
serveEvent.CertManager = certManager
serveEvent.InstallerFunc = DefaultInstallerFunc
var listener net.Listener
// graceful shutdown
// ---------------------------------------------------------------
// WaitGroup to block until server.ShutDown() returns because Serve and similar methods exit immediately.
// Note that the WaitGroup would do nothing if the app.OnTerminate() hook isn't triggered.
var wg sync.WaitGroup
// try to gracefully shutdown the server on app termination
app.OnTerminate().Bind(&hook.Handler[*core.TerminateEvent]{
Id: "pbGracefulShutdown",
Func: func(te *core.TerminateEvent) error {
cancelBaseCtx()
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
defer cancel()
wg.Add(1)
_ = server.Shutdown(ctx)
if te.IsRestart {
// wait for execve and other handlers up to 3 seconds before exit
time.AfterFunc(3*time.Second, func() {
wg.Done()
})
} else {
wg.Done()
}
return te.Next()
},
Priority: -9999,
})
// wait for the graceful shutdown to complete before exit
defer func() {
wg.Wait()
if listener != nil {
_ = listener.Close()
}
}()
// ---------------------------------------------------------------
var baseURL string
// trigger the OnServe hook and start the tcp listener
serveHookErr := app.OnServe().Trigger(serveEvent, func(e *core.ServeEvent) error {
handler, err := e.Router.BuildMux()
if err != nil {
return err
}
e.Server.Handler = handler
if config.HttpsAddr == "" {
baseURL = "http://" + serverAddrToHost(serveEvent.Server.Addr)
} else {
baseURL = "https://"
if len(config.CertificateDomains) > 0 {
baseURL += config.CertificateDomains[0]
} else {
baseURL += serverAddrToHost(serveEvent.Server.Addr)
}
}
addr := e.Server.Addr
if addr == "" {
// fallback similar to the std Server.ListenAndServe/ListenAndServeTLS
if config.HttpsAddr != "" {
addr = ":https"
} else {
addr = ":http"
}
}
listener, err = net.Listen("tcp", addr)
if err != nil {
return err
}
if e.InstallerFunc != nil {
app := e.App
installerFunc := e.InstallerFunc
routine.FireAndForget(func() {
if err := loadInstaller(app, baseURL, installerFunc); err != nil {
app.Logger().Warn("Failed to initialize installer", "error", err)
}
})
}
return nil
})
if serveHookErr != nil {
return serveHookErr
}
if listener == nil {
//nolint:staticcheck
return errors.New("The OnServe finalizer wasn't invoked. Did you forget to call the ServeEvent.Next() method?")
}
if config.ShowStartBanner {
date := new(strings.Builder)
log.New(date, "", log.LstdFlags).Print()
bold := color.New(color.Bold).Add(color.FgGreen)
bold.Printf(
"%s Server started at %s\n",
strings.TrimSpace(date.String()),
color.CyanString("%s", baseURL),
)
regular := color.New()
regular.Printf("├─ REST API: %s\n", color.CyanString("%s/api/", baseURL))
regular.Printf("└─ Dashboard: %s\n", color.CyanString("%s/_/", baseURL))
}
var serveErr error
if config.HttpsAddr != "" {
if config.HttpAddr != "" {
// start an additional HTTP server for redirecting the traffic to the HTTPS version
go http.ListenAndServe(config.HttpAddr, certManager.HTTPHandler(nil))
}
// start HTTPS server
serveErr = serveEvent.Server.ServeTLS(listener, "", "")
} else {
// OR start HTTP server
serveErr = serveEvent.Server.Serve(listener)
}
if serveErr != nil && !errors.Is(serveErr, http.ErrServerClosed) {
return serveErr
}
return nil
}
// serverAddrToHost loosely converts http.Server.Addr string into a host to print.
func serverAddrToHost(addr string) string {
if addr == "" || strings.HasSuffix(addr, ":http") || strings.HasSuffix(addr, ":https") {
return "127.0.0.1"
}
return addr
}
type serverErrorLogWriter struct {
app core.App
}
func (s *serverErrorLogWriter) Write(p []byte) (int, error) {
s.app.Logger().Debug(strings.TrimSpace(string(p)))
return len(p), nil
}

143
apis/settings.go Normal file
View file

@ -0,0 +1,143 @@
package apis
import (
"net/http"
validation "github.com/go-ozzo/ozzo-validation/v4"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/forms"
"github.com/pocketbase/pocketbase/tools/router"
)
// bindSettingsApi registers the settings api endpoints.
func bindSettingsApi(app core.App, rg *router.RouterGroup[*core.RequestEvent]) {
subGroup := rg.Group("/settings").Bind(RequireSuperuserAuth())
subGroup.GET("", settingsList)
subGroup.PATCH("", settingsSet)
subGroup.POST("/test/s3", settingsTestS3)
subGroup.POST("/test/email", settingsTestEmail)
subGroup.POST("/apple/generate-client-secret", settingsGenerateAppleClientSecret)
}
func settingsList(e *core.RequestEvent) error {
clone, err := e.App.Settings().Clone()
if err != nil {
return e.InternalServerError("", err)
}
event := new(core.SettingsListRequestEvent)
event.RequestEvent = e
event.Settings = clone
return e.App.OnSettingsListRequest().Trigger(event, func(e *core.SettingsListRequestEvent) error {
return execAfterSuccessTx(true, e.App, func() error {
return e.JSON(http.StatusOK, e.Settings)
})
})
}
func settingsSet(e *core.RequestEvent) error {
event := new(core.SettingsUpdateRequestEvent)
event.RequestEvent = e
if clone, err := e.App.Settings().Clone(); err == nil {
event.OldSettings = clone
} else {
return e.BadRequestError("", err)
}
if clone, err := e.App.Settings().Clone(); err == nil {
event.NewSettings = clone
} else {
return e.BadRequestError("", err)
}
if err := e.BindBody(&event.NewSettings); err != nil {
return e.BadRequestError("An error occurred while loading the submitted data.", err)
}
return e.App.OnSettingsUpdateRequest().Trigger(event, func(e *core.SettingsUpdateRequestEvent) error {
err := e.App.Save(e.NewSettings)
if err != nil {
return e.BadRequestError("An error occurred while saving the new settings.", err)
}
appSettings, err := e.App.Settings().Clone()
if err != nil {
return e.InternalServerError("Failed to clone app settings.", err)
}
return execAfterSuccessTx(true, e.App, func() error {
return e.JSON(http.StatusOK, appSettings)
})
})
}
func settingsTestS3(e *core.RequestEvent) error {
form := forms.NewTestS3Filesystem(e.App)
// load request
if err := e.BindBody(form); err != nil {
return e.BadRequestError("An error occurred while loading the submitted data.", err)
}
// send
if err := form.Submit(); err != nil {
// form error
if fErr, ok := err.(validation.Errors); ok {
return e.BadRequestError("Failed to test the S3 filesystem.", fErr)
}
// mailer error
return e.BadRequestError("Failed to test the S3 filesystem. Raw error: \n"+err.Error(), nil)
}
return e.NoContent(http.StatusNoContent)
}
func settingsTestEmail(e *core.RequestEvent) error {
form := forms.NewTestEmailSend(e.App)
// load request
if err := e.BindBody(form); err != nil {
return e.BadRequestError("An error occurred while loading the submitted data.", err)
}
// send
if err := form.Submit(); err != nil {
// form error
if fErr, ok := err.(validation.Errors); ok {
return e.BadRequestError("Failed to send the test email.", fErr)
}
// mailer error
return e.BadRequestError("Failed to send the test email. Raw error: \n"+err.Error(), nil)
}
return e.NoContent(http.StatusNoContent)
}
func settingsGenerateAppleClientSecret(e *core.RequestEvent) error {
form := forms.NewAppleClientSecretCreate(e.App)
// load request
if err := e.BindBody(form); err != nil {
return e.BadRequestError("An error occurred while loading the submitted data.", err)
}
// generate
secret, err := form.Submit()
if err != nil {
// form error
if fErr, ok := err.(validation.Errors); ok {
return e.BadRequestError("Invalid client secret data.", fErr)
}
// secret generation error
return e.BadRequestError("Failed to generate client secret. Raw error: \n"+err.Error(), nil)
}
return e.JSON(http.StatusOK, map[string]string{
"secret": secret,
})
}

641
apis/settings_test.go Normal file
View file

@ -0,0 +1,641 @@
package apis_test
import (
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/x509"
"encoding/pem"
"fmt"
"net/http"
"strings"
"testing"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tests"
)
func TestSettingsList(t *testing.T) {
t.Parallel()
scenarios := []tests.ApiScenario{
{
Name: "unauthorized",
Method: http.MethodGet,
URL: "/api/settings",
ExpectedStatus: 401,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "authorized as regular user",
Method: http.MethodGet,
URL: "/api/settings",
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
},
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "authorized as superuser",
Method: http.MethodGet,
URL: "/api/settings",
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
},
ExpectedStatus: 200,
ExpectedContent: []string{
`"meta":{`,
`"logs":{`,
`"smtp":{`,
`"s3":{`,
`"backups":{`,
`"batch":{`,
},
ExpectedEvents: map[string]int{
"*": 0,
"OnSettingsListRequest": 1,
},
},
{
Name: "OnSettingsListRequest tx body write check",
Method: http.MethodGet,
URL: "/api/settings",
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
},
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
app.OnSettingsListRequest().BindFunc(func(e *core.SettingsListRequestEvent) error {
original := e.App
return e.App.RunInTransaction(func(txApp core.App) error {
e.App = txApp
defer func() { e.App = original }()
if err := e.Next(); err != nil {
return err
}
return e.BadRequestError("TX_ERROR", nil)
})
})
},
ExpectedStatus: 400,
ExpectedEvents: map[string]int{"OnSettingsListRequest": 1},
ExpectedContent: []string{"TX_ERROR"},
},
}
for _, scenario := range scenarios {
scenario.Test(t)
}
}
func TestSettingsSet(t *testing.T) {
t.Parallel()
validData := `{
"meta":{"appName":"update_test"},
"s3":{"secret": "s3_secret"},
"backups":{"s3":{"secret":"backups_s3_secret"}}
}`
scenarios := []tests.ApiScenario{
{
Name: "unauthorized",
Method: http.MethodPatch,
URL: "/api/settings",
Body: strings.NewReader(validData),
ExpectedStatus: 401,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "authorized as regular user",
Method: http.MethodPatch,
URL: "/api/settings",
Body: strings.NewReader(validData),
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
},
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "authorized as superuser submitting empty data",
Method: http.MethodPatch,
URL: "/api/settings",
Body: strings.NewReader(``),
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
},
ExpectedStatus: 200,
ExpectedContent: []string{
`"meta":{`,
`"logs":{`,
`"smtp":{`,
`"s3":{`,
`"backups":{`,
`"batch":{`,
},
ExpectedEvents: map[string]int{
"*": 0,
"OnSettingsUpdateRequest": 1,
"OnModelUpdate": 1,
"OnModelUpdateExecute": 1,
"OnModelAfterUpdateSuccess": 1,
"OnModelValidate": 1,
"OnSettingsReload": 1,
},
},
{
Name: "authorized as superuser submitting invalid data",
Method: http.MethodPatch,
URL: "/api/settings",
Body: strings.NewReader(`{"meta":{"appName":""}}`),
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
},
ExpectedStatus: 400,
ExpectedContent: []string{
`"data":{`,
`"meta":{"appName":{"code":"validation_required"`,
},
ExpectedEvents: map[string]int{
"*": 0,
"OnModelUpdate": 1,
"OnModelAfterUpdateError": 1,
"OnModelValidate": 1,
"OnSettingsUpdateRequest": 1,
},
},
{
Name: "authorized as superuser submitting valid data",
Method: http.MethodPatch,
URL: "/api/settings",
Body: strings.NewReader(validData),
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
},
ExpectedStatus: 200,
ExpectedContent: []string{
`"meta":{`,
`"logs":{`,
`"smtp":{`,
`"s3":{`,
`"backups":{`,
`"batch":{`,
`"appName":"update_test"`,
},
NotExpectedContent: []string{
"secret",
"password",
},
ExpectedEvents: map[string]int{
"*": 0,
"OnSettingsUpdateRequest": 1,
"OnModelUpdate": 1,
"OnModelUpdateExecute": 1,
"OnModelAfterUpdateSuccess": 1,
"OnModelValidate": 1,
"OnSettingsReload": 1,
},
},
{
Name: "OnSettingsUpdateRequest tx body write check",
Method: http.MethodPatch,
URL: "/api/settings",
Body: strings.NewReader(validData),
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
},
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
app.OnSettingsUpdateRequest().BindFunc(func(e *core.SettingsUpdateRequestEvent) error {
original := e.App
return e.App.RunInTransaction(func(txApp core.App) error {
e.App = txApp
defer func() { e.App = original }()
if err := e.Next(); err != nil {
return err
}
return e.BadRequestError("TX_ERROR", nil)
})
})
},
ExpectedStatus: 400,
ExpectedEvents: map[string]int{"OnSettingsUpdateRequest": 1},
ExpectedContent: []string{"TX_ERROR"},
},
}
for _, scenario := range scenarios {
scenario.Test(t)
}
}
func TestSettingsTestS3(t *testing.T) {
t.Parallel()
scenarios := []tests.ApiScenario{
{
Name: "unauthorized",
Method: http.MethodPost,
URL: "/api/settings/test/s3",
ExpectedStatus: 401,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "authorized as regular user",
Method: http.MethodPost,
URL: "/api/settings/test/s3",
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
},
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "authorized as superuser (missing body + no s3)",
Method: http.MethodPost,
URL: "/api/settings/test/s3",
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
},
ExpectedStatus: 400,
ExpectedContent: []string{
`"data":{`,
`"filesystem":{`,
},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "authorized as superuser (invalid filesystem)",
Method: http.MethodPost,
URL: "/api/settings/test/s3",
Body: strings.NewReader(`{"filesystem":"invalid"}`),
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
},
ExpectedStatus: 400,
ExpectedContent: []string{
`"data":{`,
`"filesystem":{`,
},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "authorized as superuser (valid filesystem and no s3)",
Method: http.MethodPost,
URL: "/api/settings/test/s3",
Body: strings.NewReader(`{"filesystem":"storage"}`),
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
},
ExpectedStatus: 400,
ExpectedContent: []string{
`"data":{}`,
},
ExpectedEvents: map[string]int{"*": 0},
},
}
for _, scenario := range scenarios {
scenario.Test(t)
}
}
func TestSettingsTestEmail(t *testing.T) {
t.Parallel()
scenarios := []tests.ApiScenario{
{
Name: "unauthorized",
Method: http.MethodPost,
URL: "/api/settings/test/email",
Body: strings.NewReader(`{
"template": "verification",
"email": "test@example.com"
}`),
ExpectedStatus: 401,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "authorized as regular user",
Method: http.MethodPost,
URL: "/api/settings/test/email",
Body: strings.NewReader(`{
"template": "verification",
"email": "test@example.com"
}`),
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
},
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "authorized as superuser (invalid body)",
Method: http.MethodPost,
URL: "/api/settings/test/email",
Body: strings.NewReader(`{`),
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
},
ExpectedStatus: 400,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "authorized as superuser (empty json)",
Method: http.MethodPost,
URL: "/api/settings/test/email",
Body: strings.NewReader(`{}`),
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
},
ExpectedStatus: 400,
ExpectedContent: []string{
`"email":{"code":"validation_required"`,
`"template":{"code":"validation_required"`,
},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "authorized as superuser (verifiation template)",
Method: http.MethodPost,
URL: "/api/settings/test/email",
Body: strings.NewReader(`{
"template": "verification",
"email": "test@example.com"
}`),
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
},
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
if app.TestMailer.TotalSend() != 1 {
t.Fatalf("[verification] Expected 1 sent email, got %d", app.TestMailer.TotalSend())
}
if len(app.TestMailer.LastMessage().To) != 1 {
t.Fatalf("[verification] Expected 1 recipient, got %v", app.TestMailer.LastMessage().To)
}
if app.TestMailer.LastMessage().To[0].Address != "test@example.com" {
t.Fatalf("[verification] Expected the email to be sent to %s, got %s", "test@example.com", app.TestMailer.LastMessage().To[0].Address)
}
if !strings.Contains(app.TestMailer.LastMessage().HTML, "Verify") {
t.Fatalf("[verification] Expected to sent a verification email, got \n%v\n%v", app.TestMailer.LastMessage().Subject, app.TestMailer.LastMessage().HTML)
}
},
ExpectedStatus: 204,
ExpectedContent: []string{},
ExpectedEvents: map[string]int{
"*": 0,
"OnMailerSend": 1,
"OnMailerRecordVerificationSend": 1,
},
},
{
Name: "authorized as superuser (password reset template)",
Method: http.MethodPost,
URL: "/api/settings/test/email",
Body: strings.NewReader(`{
"template": "password-reset",
"email": "test@example.com"
}`),
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
},
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
if app.TestMailer.TotalSend() != 1 {
t.Fatalf("[password-reset] Expected 1 sent email, got %d", app.TestMailer.TotalSend())
}
if len(app.TestMailer.LastMessage().To) != 1 {
t.Fatalf("[password-reset] Expected 1 recipient, got %v", app.TestMailer.LastMessage().To)
}
if app.TestMailer.LastMessage().To[0].Address != "test@example.com" {
t.Fatalf("[password-reset] Expected the email to be sent to %s, got %s", "test@example.com", app.TestMailer.LastMessage().To[0].Address)
}
if !strings.Contains(app.TestMailer.LastMessage().HTML, "Reset password") {
t.Fatalf("[password-reset] Expected to sent a password-reset email, got \n%v\n%v", app.TestMailer.LastMessage().Subject, app.TestMailer.LastMessage().HTML)
}
},
ExpectedStatus: 204,
ExpectedContent: []string{},
ExpectedEvents: map[string]int{
"*": 0,
"OnMailerSend": 1,
"OnMailerRecordPasswordResetSend": 1,
},
},
{
Name: "authorized as superuser (email change)",
Method: http.MethodPost,
URL: "/api/settings/test/email",
Body: strings.NewReader(`{
"template": "email-change",
"email": "test@example.com"
}`),
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
},
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
if app.TestMailer.TotalSend() != 1 {
t.Fatalf("[email-change] Expected 1 sent email, got %d", app.TestMailer.TotalSend())
}
if len(app.TestMailer.LastMessage().To) != 1 {
t.Fatalf("[email-change] Expected 1 recipient, got %v", app.TestMailer.LastMessage().To)
}
if app.TestMailer.LastMessage().To[0].Address != "test@example.com" {
t.Fatalf("[email-change] Expected the email to be sent to %s, got %s", "test@example.com", app.TestMailer.LastMessage().To[0].Address)
}
if !strings.Contains(app.TestMailer.LastMessage().HTML, "Confirm new email") {
t.Fatalf("[email-change] Expected to sent a confirm new email email, got \n%v\n%v", app.TestMailer.LastMessage().Subject, app.TestMailer.LastMessage().HTML)
}
},
ExpectedStatus: 204,
ExpectedContent: []string{},
ExpectedEvents: map[string]int{
"*": 0,
"OnMailerSend": 1,
"OnMailerRecordEmailChangeSend": 1,
},
},
{
Name: "authorized as superuser (otp)",
Method: http.MethodPost,
URL: "/api/settings/test/email",
Body: strings.NewReader(`{
"template": "otp",
"email": "test@example.com"
}`),
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
},
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
if app.TestMailer.TotalSend() != 1 {
t.Fatalf("[otp] Expected 1 sent email, got %d", app.TestMailer.TotalSend())
}
if len(app.TestMailer.LastMessage().To) != 1 {
t.Fatalf("[otp] Expected 1 recipient, got %v", app.TestMailer.LastMessage().To)
}
if app.TestMailer.LastMessage().To[0].Address != "test@example.com" {
t.Fatalf("[otp] Expected the email to be sent to %s, got %s", "test@example.com", app.TestMailer.LastMessage().To[0].Address)
}
if !strings.Contains(app.TestMailer.LastMessage().HTML, "one-time password") {
t.Fatalf("[otp] Expected to sent OTP email, got \n%v\n%v", app.TestMailer.LastMessage().Subject, app.TestMailer.LastMessage().HTML)
}
},
ExpectedStatus: 204,
ExpectedContent: []string{},
ExpectedEvents: map[string]int{
"*": 0,
"OnMailerSend": 1,
"OnMailerRecordOTPSend": 1,
},
},
}
for _, scenario := range scenarios {
scenario.Test(t)
}
}
func TestGenerateAppleClientSecret(t *testing.T) {
t.Parallel()
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
t.Fatal(err)
}
encodedKey, err := x509.MarshalPKCS8PrivateKey(key)
if err != nil {
t.Fatal(err)
}
privatePem := pem.EncodeToMemory(
&pem.Block{
Type: "PRIVATE KEY",
Bytes: encodedKey,
},
)
scenarios := []tests.ApiScenario{
{
Name: "unauthorized",
Method: http.MethodPost,
URL: "/api/settings/apple/generate-client-secret",
ExpectedStatus: 401,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "authorized as regular user",
Method: http.MethodPost,
URL: "/api/settings/apple/generate-client-secret",
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
},
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "authorized as superuser (invalid body)",
Method: http.MethodPost,
URL: "/api/settings/apple/generate-client-secret",
Body: strings.NewReader(`{`),
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
},
ExpectedStatus: 400,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "authorized as superuser (empty json)",
Method: http.MethodPost,
URL: "/api/settings/apple/generate-client-secret",
Body: strings.NewReader(`{}`),
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
},
ExpectedStatus: 400,
ExpectedContent: []string{
`"clientId":{"code":"validation_required"`,
`"teamId":{"code":"validation_required"`,
`"keyId":{"code":"validation_required"`,
`"privateKey":{"code":"validation_required"`,
`"duration":{"code":"validation_required"`,
},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "authorized as superuser (invalid data)",
Method: http.MethodPost,
URL: "/api/settings/apple/generate-client-secret",
Body: strings.NewReader(`{
"clientId": "",
"teamId": "123456789",
"keyId": "123456789",
"privateKey": "invalid",
"duration": -1
}`),
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
},
ExpectedStatus: 400,
ExpectedContent: []string{
`"clientId":{"code":"validation_required"`,
`"teamId":{"code":"validation_length_invalid"`,
`"keyId":{"code":"validation_length_invalid"`,
`"privateKey":{"code":"validation_match_invalid"`,
`"duration":{"code":"validation_min_greater_equal_than_required"`,
},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "authorized as superuser (valid data)",
Method: http.MethodPost,
URL: "/api/settings/apple/generate-client-secret",
Body: strings.NewReader(fmt.Sprintf(`{
"clientId": "123",
"teamId": "1234567890",
"keyId": "1234567891",
"privateKey": %q,
"duration": 1
}`, privatePem)),
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
},
ExpectedStatus: 200,
ExpectedContent: []string{
`"secret":"`,
},
ExpectedEvents: map[string]int{"*": 0},
},
}
for _, scenario := range scenarios {
scenario.Test(t)
}
}