Adding upstream version 0.28.1.
Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
parent
88f1d47ab6
commit
e28c88ef14
933 changed files with 194711 additions and 0 deletions
47
apis/api_error_aliases.go
Normal file
47
apis/api_error_aliases.go
Normal file
|
@ -0,0 +1,47 @@
|
|||
package apis
|
||||
|
||||
import "github.com/pocketbase/pocketbase/tools/router"
|
||||
|
||||
// ApiError aliases to minimize the breaking changes with earlier versions
|
||||
// and for consistency with the JSVM binds.
|
||||
// -------------------------------------------------------------------
|
||||
|
||||
// ToApiError wraps err into ApiError instance (if not already).
|
||||
func ToApiError(err error) *router.ApiError {
|
||||
return router.ToApiError(err)
|
||||
}
|
||||
|
||||
// NewApiError is an alias for [router.NewApiError].
|
||||
func NewApiError(status int, message string, errData any) *router.ApiError {
|
||||
return router.NewApiError(status, message, errData)
|
||||
}
|
||||
|
||||
// NewBadRequestError is an alias for [router.NewBadRequestError].
|
||||
func NewBadRequestError(message string, errData any) *router.ApiError {
|
||||
return router.NewBadRequestError(message, errData)
|
||||
}
|
||||
|
||||
// NewNotFoundError is an alias for [router.NewNotFoundError].
|
||||
func NewNotFoundError(message string, errData any) *router.ApiError {
|
||||
return router.NewNotFoundError(message, errData)
|
||||
}
|
||||
|
||||
// NewForbiddenError is an alias for [router.NewForbiddenError].
|
||||
func NewForbiddenError(message string, errData any) *router.ApiError {
|
||||
return router.NewForbiddenError(message, errData)
|
||||
}
|
||||
|
||||
// NewUnauthorizedError is an alias for [router.NewUnauthorizedError].
|
||||
func NewUnauthorizedError(message string, errData any) *router.ApiError {
|
||||
return router.NewUnauthorizedError(message, errData)
|
||||
}
|
||||
|
||||
// NewTooManyRequestsError is an alias for [router.NewTooManyRequestsError].
|
||||
func NewTooManyRequestsError(message string, errData any) *router.ApiError {
|
||||
return router.NewTooManyRequestsError(message, errData)
|
||||
}
|
||||
|
||||
// NewInternalServerError is an alias for [router.NewInternalServerError].
|
||||
func NewInternalServerError(message string, errData any) *router.ApiError {
|
||||
return router.NewInternalServerError(message, errData)
|
||||
}
|
155
apis/backup.go
Normal file
155
apis/backup.go
Normal file
|
@ -0,0 +1,155 @@
|
|||
package apis
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tools/router"
|
||||
"github.com/pocketbase/pocketbase/tools/routine"
|
||||
"github.com/pocketbase/pocketbase/tools/types"
|
||||
"github.com/spf13/cast"
|
||||
)
|
||||
|
||||
// bindBackupApi registers the file api endpoints and the corresponding handlers.
|
||||
func bindBackupApi(app core.App, rg *router.RouterGroup[*core.RequestEvent]) {
|
||||
sub := rg.Group("/backups")
|
||||
sub.GET("", backupsList).Bind(RequireSuperuserAuth())
|
||||
sub.POST("", backupCreate).Bind(RequireSuperuserAuth())
|
||||
sub.POST("/upload", backupUpload).Bind(BodyLimit(0), RequireSuperuserAuth())
|
||||
sub.GET("/{key}", backupDownload) // relies on superuser file token
|
||||
sub.DELETE("/{key}", backupDelete).Bind(RequireSuperuserAuth())
|
||||
sub.POST("/{key}/restore", backupRestore).Bind(RequireSuperuserAuth())
|
||||
}
|
||||
|
||||
type backupFileInfo struct {
|
||||
Modified types.DateTime `json:"modified"`
|
||||
Key string `json:"key"`
|
||||
Size int64 `json:"size"`
|
||||
}
|
||||
|
||||
func backupsList(e *core.RequestEvent) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
fsys, err := e.App.NewBackupsFilesystem()
|
||||
if err != nil {
|
||||
return e.BadRequestError("Failed to load backups filesystem.", err)
|
||||
}
|
||||
defer fsys.Close()
|
||||
|
||||
fsys.SetContext(ctx)
|
||||
|
||||
backups, err := fsys.List("")
|
||||
if err != nil {
|
||||
return e.BadRequestError("Failed to retrieve backup items. Raw error: \n"+err.Error(), nil)
|
||||
}
|
||||
|
||||
result := make([]backupFileInfo, len(backups))
|
||||
|
||||
for i, obj := range backups {
|
||||
modified, _ := types.ParseDateTime(obj.ModTime)
|
||||
|
||||
result[i] = backupFileInfo{
|
||||
Key: obj.Key,
|
||||
Size: obj.Size,
|
||||
Modified: modified,
|
||||
}
|
||||
}
|
||||
|
||||
return e.JSON(http.StatusOK, result)
|
||||
}
|
||||
|
||||
func backupDownload(e *core.RequestEvent) error {
|
||||
fileToken := e.Request.URL.Query().Get("token")
|
||||
|
||||
authRecord, err := e.App.FindAuthRecordByToken(fileToken, core.TokenTypeFile)
|
||||
if err != nil || !authRecord.IsSuperuser() {
|
||||
return e.ForbiddenError("Insufficient permissions to access the resource.", err)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
fsys, err := e.App.NewBackupsFilesystem()
|
||||
if err != nil {
|
||||
return e.InternalServerError("Failed to load backups filesystem.", err)
|
||||
}
|
||||
defer fsys.Close()
|
||||
|
||||
fsys.SetContext(ctx)
|
||||
|
||||
key := e.Request.PathValue("key")
|
||||
|
||||
return fsys.Serve(
|
||||
e.Response,
|
||||
e.Request,
|
||||
key,
|
||||
filepath.Base(key), // without the path prefix (if any)
|
||||
)
|
||||
}
|
||||
|
||||
func backupDelete(e *core.RequestEvent) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
fsys, err := e.App.NewBackupsFilesystem()
|
||||
if err != nil {
|
||||
return e.InternalServerError("Failed to load backups filesystem.", err)
|
||||
}
|
||||
defer fsys.Close()
|
||||
|
||||
fsys.SetContext(ctx)
|
||||
|
||||
key := e.Request.PathValue("key")
|
||||
|
||||
if key != "" && cast.ToString(e.App.Store().Get(core.StoreKeyActiveBackup)) == key {
|
||||
return e.BadRequestError("The backup is currently being used and cannot be deleted.", nil)
|
||||
}
|
||||
|
||||
if err := fsys.Delete(key); err != nil {
|
||||
return e.BadRequestError("Invalid or already deleted backup file. Raw error: \n"+err.Error(), nil)
|
||||
}
|
||||
|
||||
return e.NoContent(http.StatusNoContent)
|
||||
}
|
||||
|
||||
func backupRestore(e *core.RequestEvent) error {
|
||||
if e.App.Store().Has(core.StoreKeyActiveBackup) {
|
||||
return e.BadRequestError("Try again later - another backup/restore process has already been started.", nil)
|
||||
}
|
||||
|
||||
key := e.Request.PathValue("key")
|
||||
|
||||
existsCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
fsys, err := e.App.NewBackupsFilesystem()
|
||||
if err != nil {
|
||||
return e.InternalServerError("Failed to load backups filesystem.", err)
|
||||
}
|
||||
defer fsys.Close()
|
||||
|
||||
fsys.SetContext(existsCtx)
|
||||
|
||||
if exists, err := fsys.Exists(key); !exists {
|
||||
return e.BadRequestError("Missing or invalid backup file.", err)
|
||||
}
|
||||
|
||||
routine.FireAndForget(func() {
|
||||
// give some optimistic time to write the response before restarting the app
|
||||
time.Sleep(1 * time.Second)
|
||||
|
||||
// wait max 10 minutes to fetch the backup
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
if err := e.App.RestoreBackup(ctx, key); err != nil {
|
||||
e.App.Logger().Error("Failed to restore backup", "key", key, "error", err.Error())
|
||||
}
|
||||
})
|
||||
|
||||
return e.NoContent(http.StatusNoContent)
|
||||
}
|
78
apis/backup_create.go
Normal file
78
apis/backup_create.go
Normal file
|
@ -0,0 +1,78 @@
|
|||
package apis
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"regexp"
|
||||
|
||||
validation "github.com/go-ozzo/ozzo-validation/v4"
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
)
|
||||
|
||||
func backupCreate(e *core.RequestEvent) error {
|
||||
if e.App.Store().Has(core.StoreKeyActiveBackup) {
|
||||
return e.BadRequestError("Try again later - another backup/restore process has already been started", nil)
|
||||
}
|
||||
|
||||
form := new(backupCreateForm)
|
||||
form.app = e.App
|
||||
|
||||
err := e.BindBody(form)
|
||||
if err != nil {
|
||||
return e.BadRequestError("An error occurred while loading the submitted data.", err)
|
||||
}
|
||||
|
||||
err = form.validate()
|
||||
if err != nil {
|
||||
return e.BadRequestError("An error occurred while validating the submitted data.", err)
|
||||
}
|
||||
|
||||
err = e.App.CreateBackup(context.Background(), form.Name)
|
||||
if err != nil {
|
||||
return e.BadRequestError("Failed to create backup.", err)
|
||||
}
|
||||
|
||||
// we don't retrieve the generated backup file because it may not be
|
||||
// available yet due to the eventually consistent nature of some S3 providers
|
||||
return e.NoContent(http.StatusNoContent)
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
|
||||
var backupNameRegex = regexp.MustCompile(`^[a-z0-9_-]+\.zip$`)
|
||||
|
||||
type backupCreateForm struct {
|
||||
app core.App
|
||||
|
||||
Name string `form:"name" json:"name"`
|
||||
}
|
||||
|
||||
func (form *backupCreateForm) validate() error {
|
||||
return validation.ValidateStruct(form,
|
||||
validation.Field(
|
||||
&form.Name,
|
||||
validation.Length(1, 150),
|
||||
validation.Match(backupNameRegex),
|
||||
validation.By(form.checkUniqueName),
|
||||
),
|
||||
)
|
||||
}
|
||||
|
||||
func (form *backupCreateForm) checkUniqueName(value any) error {
|
||||
v, _ := value.(string)
|
||||
if v == "" {
|
||||
return nil // nothing to check
|
||||
}
|
||||
|
||||
fsys, err := form.app.NewBackupsFilesystem()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer fsys.Close()
|
||||
|
||||
if exists, err := fsys.Exists(v); err != nil || exists {
|
||||
return validation.NewError("validation_backup_name_exists", "The backup file name is invalid or already exists.")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
823
apis/backup_test.go
Normal file
823
apis/backup_test.go
Normal file
|
@ -0,0 +1,823 @@
|
|||
package apis_test
|
||||
|
||||
import (
|
||||
"archive/zip"
|
||||
"bytes"
|
||||
"context"
|
||||
"io"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/pocketbase/pocketbase/apis"
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tests"
|
||||
"github.com/pocketbase/pocketbase/tools/filesystem/blob"
|
||||
)
|
||||
|
||||
func TestBackupsList(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "unauthorized",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/backups",
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := createTestBackups(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 401,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "authorized as regular user",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/backups",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := createTestBackups(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "authorized as superuser (empty list)",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/backups",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{`[]`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "authorized as superuser",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/backups",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := createTestBackups(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
`"test1.zip"`,
|
||||
`"test2.zip"`,
|
||||
`"test3.zip"`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBackupsCreate(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "unauthorized",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/backups",
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
ensureNoBackups(t, app)
|
||||
},
|
||||
ExpectedStatus: 401,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "authorized as regular user",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/backups",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
ensureNoBackups(t, app)
|
||||
},
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "authorized as superuser (pending backup)",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/backups",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
app.Store().Set(core.StoreKeyActiveBackup, "")
|
||||
},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
ensureNoBackups(t, app)
|
||||
},
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "authorized as superuser (autogenerated name)",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/backups",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
|
||||
},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
files, err := getBackupFiles(app)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if total := len(files); total != 1 {
|
||||
t.Fatalf("Expected 1 backup file, got %d", total)
|
||||
}
|
||||
|
||||
expected := "pb_backup_"
|
||||
if !strings.HasPrefix(files[0].Key, expected) {
|
||||
t.Fatalf("Expected backup file with prefix %q, got %q", expected, files[0].Key)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 204,
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnBackupCreate": 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "authorized as superuser (invalid name)",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/backups",
|
||||
Body: strings.NewReader(`{"name":"!test.zip"}`),
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
|
||||
},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
ensureNoBackups(t, app)
|
||||
},
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{
|
||||
`"data":{`,
|
||||
`"name":{"code":"validation_match_invalid"`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "authorized as superuser (valid name)",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/backups",
|
||||
Body: strings.NewReader(`{"name":"test.zip"}`),
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
|
||||
},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
files, err := getBackupFiles(app)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if total := len(files); total != 1 {
|
||||
t.Fatalf("Expected 1 backup file, got %d", total)
|
||||
}
|
||||
|
||||
expected := "test.zip"
|
||||
if files[0].Key != expected {
|
||||
t.Fatalf("Expected backup file %q, got %q", expected, files[0].Key)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 204,
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnBackupCreate": 1,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBackupUpload(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// create dummy form data bodies
|
||||
type body struct {
|
||||
buffer io.Reader
|
||||
contentType string
|
||||
}
|
||||
bodies := make([]body, 10)
|
||||
for i := 0; i < 10; i++ {
|
||||
func() {
|
||||
zb := new(bytes.Buffer)
|
||||
zw := zip.NewWriter(zb)
|
||||
if err := zw.Close(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
b := new(bytes.Buffer)
|
||||
mw := multipart.NewWriter(b)
|
||||
|
||||
mfw, err := mw.CreateFormFile("file", "test")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if _, err := io.Copy(mfw, zb); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
mw.Close()
|
||||
|
||||
bodies[i] = body{
|
||||
buffer: b,
|
||||
contentType: mw.FormDataContentType(),
|
||||
}
|
||||
}()
|
||||
}
|
||||
// ---
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "unauthorized",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/backups/upload",
|
||||
Body: bodies[0].buffer,
|
||||
Headers: map[string]string{
|
||||
"Content-Type": bodies[0].contentType,
|
||||
},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
ensureNoBackups(t, app)
|
||||
},
|
||||
ExpectedStatus: 401,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "authorized as regular user",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/backups/upload",
|
||||
Body: bodies[1].buffer,
|
||||
Headers: map[string]string{
|
||||
"Content-Type": bodies[1].contentType,
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
ensureNoBackups(t, app)
|
||||
},
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "authorized as superuser (missing file)",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/backups/upload",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
|
||||
},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
ensureNoBackups(t, app)
|
||||
},
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{`"data":{`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "authorized as superuser (existing backup name)",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/backups/upload",
|
||||
Body: bodies[3].buffer,
|
||||
Headers: map[string]string{
|
||||
"Content-Type": bodies[3].contentType,
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
fsys, err := app.NewBackupsFilesystem()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer fsys.Close()
|
||||
// create a dummy backup file to simulate existing backups
|
||||
if err := fsys.Upload([]byte("123"), "test"); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
files, _ := getBackupFiles(app)
|
||||
if total := len(files); total != 1 {
|
||||
t.Fatalf("Expected %d backup file, got %d", 1, total)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{`"data":{"file":{`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "authorized as superuser (valid file)",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/backups/upload",
|
||||
Body: bodies[4].buffer,
|
||||
Headers: map[string]string{
|
||||
"Content-Type": bodies[4].contentType,
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
|
||||
},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
files, _ := getBackupFiles(app)
|
||||
if total := len(files); total != 1 {
|
||||
t.Fatalf("Expected %d backup file, got %d", 1, total)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 204,
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "ensure that the default body limit is skipped",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/backups/upload",
|
||||
Body: bytes.NewBuffer(make([]byte, apis.DefaultMaxBodySize+100)),
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
|
||||
},
|
||||
ExpectedStatus: 400, // it doesn't matter as long as it is not 413
|
||||
ExpectedContent: []string{`"data":{`},
|
||||
NotExpectedContent: []string{"entity too large"},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBackupsDownload(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "unauthorized",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/backups/test1.zip",
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := createTestBackups(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "with record auth header",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/backups/test1.zip",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := createTestBackups(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "with superuser auth header",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/backups/test1.zip",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := createTestBackups(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "with empty or invalid token",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/backups/test1.zip?token=",
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := createTestBackups(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "with valid record auth token",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/backups/test1.zip?token=eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := createTestBackups(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "with valid record file token",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/backups/test1.zip?token=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImV4cCI6MjUyNDYwNDQ2MSwidHlwZSI6ImZpbGUiLCJjb2xsZWN0aW9uSWQiOiJfcGJfdXNlcnNfYXV0aF8ifQ.nSTLuCPcGpWn2K2l-BFkC3Vlzc-ZTDPByYq8dN1oPSo",
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := createTestBackups(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "with valid superuser auth token",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/backups/test1.zip?token=eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := createTestBackups(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "with expired superuser file token",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/backups/test1.zip?token=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsImV4cCI6MTY0MDk5MTY2MSwidHlwZSI6ImZpbGUiLCJjb2xsZWN0aW9uSWQiOiJwYmNfMzE0MjYzNTgyMyJ9.nqqtqpPhxU0045F4XP_ruAkzAidYBc5oPy9ErN3XBq0",
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := createTestBackups(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "with valid superuser file token but missing backup name",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/backups/missing?token=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsImV4cCI6MjUyNDYwNDQ2MSwidHlwZSI6ImZpbGUiLCJjb2xsZWN0aW9uSWQiOiJwYmNfMzE0MjYzNTgyMyJ9.Lupz541xRvrktwkrl55p5pPCF77T69ZRsohsIcb2dxc",
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := createTestBackups(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 404,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "with valid superuser file token",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/backups/test1.zip?token=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsImV4cCI6MjUyNDYwNDQ2MSwidHlwZSI6ImZpbGUiLCJjb2xsZWN0aW9uSWQiOiJwYmNfMzE0MjYzNTgyMyJ9.Lupz541xRvrktwkrl55p5pPCF77T69ZRsohsIcb2dxc",
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := createTestBackups(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
"storage/",
|
||||
"data.db",
|
||||
"auxiliary.db",
|
||||
},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "with valid superuser file token and backup name with escaped char",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/backups/%40test4.zip?token=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsImV4cCI6MjUyNDYwNDQ2MSwidHlwZSI6ImZpbGUiLCJjb2xsZWN0aW9uSWQiOiJwYmNfMzE0MjYzNTgyMyJ9.Lupz541xRvrktwkrl55p5pPCF77T69ZRsohsIcb2dxc",
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := createTestBackups(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
"storage/",
|
||||
"data.db",
|
||||
"auxiliary.db",
|
||||
},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBackupsDelete(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
noTestBackupFilesChanges := func(t testing.TB, app *tests.TestApp) {
|
||||
files, err := getBackupFiles(app)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
expected := 4
|
||||
if total := len(files); total != expected {
|
||||
t.Fatalf("Expected %d backup(s), got %d", expected, total)
|
||||
}
|
||||
}
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "unauthorized",
|
||||
Method: http.MethodDelete,
|
||||
URL: "/api/backups/test1.zip",
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := createTestBackups(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
noTestBackupFilesChanges(t, app)
|
||||
},
|
||||
ExpectedStatus: 401,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "authorized as regular user",
|
||||
Method: http.MethodDelete,
|
||||
URL: "/api/backups/test1.zip",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := createTestBackups(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
noTestBackupFilesChanges(t, app)
|
||||
},
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "authorized as superuser (missing file)",
|
||||
Method: http.MethodDelete,
|
||||
URL: "/api/backups/missing.zip",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := createTestBackups(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
noTestBackupFilesChanges(t, app)
|
||||
},
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "authorized as superuser (existing file with matching active backup)",
|
||||
Method: http.MethodDelete,
|
||||
URL: "/api/backups/test1.zip",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := createTestBackups(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// mock active backup with the same name to delete
|
||||
app.Store().Set(core.StoreKeyActiveBackup, "test1.zip")
|
||||
},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
noTestBackupFilesChanges(t, app)
|
||||
},
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "authorized as superuser (existing file and no matching active backup)",
|
||||
Method: http.MethodDelete,
|
||||
URL: "/api/backups/test1.zip",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := createTestBackups(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// mock active backup with different name
|
||||
app.Store().Set(core.StoreKeyActiveBackup, "new.zip")
|
||||
},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
files, err := getBackupFiles(app)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if total := len(files); total != 3 {
|
||||
t.Fatalf("Expected %d backup files, got %d", 3, total)
|
||||
}
|
||||
|
||||
deletedFile := "test1.zip"
|
||||
|
||||
for _, f := range files {
|
||||
if f.Key == deletedFile {
|
||||
t.Fatalf("Expected backup %q to be deleted", deletedFile)
|
||||
}
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 204,
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "authorized as superuser (backup with escaped character)",
|
||||
Method: http.MethodDelete,
|
||||
URL: "/api/backups/%40test4.zip",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := createTestBackups(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
files, err := getBackupFiles(app)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if total := len(files); total != 3 {
|
||||
t.Fatalf("Expected %d backup files, got %d", 3, total)
|
||||
}
|
||||
|
||||
deletedFile := "@test4.zip"
|
||||
|
||||
for _, f := range files {
|
||||
if f.Key == deletedFile {
|
||||
t.Fatalf("Expected backup %q to be deleted", deletedFile)
|
||||
}
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 204,
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBackupsRestore(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "unauthorized",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/backups/test1.zip/restore",
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := createTestBackups(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 401,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "authorized as regular user",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/backups/test1.zip/restore",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := createTestBackups(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "authorized as superuser (missing file)",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/backups/missing.zip/restore",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := createTestBackups(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "authorized as superuser (active backup process)",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/backups/test1.zip/restore",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := createTestBackups(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
app.Store().Set(core.StoreKeyActiveBackup, "")
|
||||
},
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
|
||||
func createTestBackups(app core.App) error {
|
||||
ctx := context.Background()
|
||||
|
||||
if err := app.CreateBackup(ctx, "test1.zip"); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := app.CreateBackup(ctx, "test2.zip"); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := app.CreateBackup(ctx, "test3.zip"); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := app.CreateBackup(ctx, "@test4.zip"); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func getBackupFiles(app core.App) ([]*blob.ListObject, error) {
|
||||
fsys, err := app.NewBackupsFilesystem()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer fsys.Close()
|
||||
|
||||
return fsys.List("")
|
||||
}
|
||||
|
||||
func ensureNoBackups(t testing.TB, app *tests.TestApp) {
|
||||
files, err := getBackupFiles(app)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if total := len(files); total != 0 {
|
||||
t.Fatalf("Expected 0 backup files, got %d", total)
|
||||
}
|
||||
}
|
72
apis/backup_upload.go
Normal file
72
apis/backup_upload.go
Normal file
|
@ -0,0 +1,72 @@
|
|||
package apis
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
validation "github.com/go-ozzo/ozzo-validation/v4"
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/core/validators"
|
||||
"github.com/pocketbase/pocketbase/tools/filesystem"
|
||||
)
|
||||
|
||||
func backupUpload(e *core.RequestEvent) error {
|
||||
fsys, err := e.App.NewBackupsFilesystem()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer fsys.Close()
|
||||
|
||||
form := new(backupUploadForm)
|
||||
form.fsys = fsys
|
||||
files, _ := e.FindUploadedFiles("file")
|
||||
if len(files) > 0 {
|
||||
form.File = files[0]
|
||||
}
|
||||
|
||||
err = form.validate()
|
||||
if err != nil {
|
||||
return e.BadRequestError("An error occurred while validating the submitted data.", err)
|
||||
}
|
||||
|
||||
err = fsys.UploadFile(form.File, form.File.OriginalName)
|
||||
if err != nil {
|
||||
return e.BadRequestError("Failed to upload backup.", err)
|
||||
}
|
||||
|
||||
// we don't retrieve the generated backup file because it may not be
|
||||
// available yet due to the eventually consistent nature of some S3 providers
|
||||
return e.NoContent(http.StatusNoContent)
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
|
||||
type backupUploadForm struct {
|
||||
fsys *filesystem.System
|
||||
|
||||
File *filesystem.File `json:"file"`
|
||||
}
|
||||
|
||||
func (form *backupUploadForm) validate() error {
|
||||
return validation.ValidateStruct(form,
|
||||
validation.Field(
|
||||
&form.File,
|
||||
validation.Required,
|
||||
validation.By(validators.UploadedFileMimeType([]string{"application/zip"})),
|
||||
validation.By(form.checkUniqueName),
|
||||
),
|
||||
)
|
||||
}
|
||||
|
||||
func (form *backupUploadForm) checkUniqueName(value any) error {
|
||||
v, _ := value.(*filesystem.File)
|
||||
if v == nil {
|
||||
return nil // nothing to check
|
||||
}
|
||||
|
||||
// note: we use the original name because that is what we upload
|
||||
if exists, err := form.fsys.Exists(v.OriginalName); err != nil || exists {
|
||||
return validation.NewError("validation_backup_name_exists", "Backup file with the specified name already exists.")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
174
apis/base.go
Normal file
174
apis/base.go
Normal file
|
@ -0,0 +1,174 @@
|
|||
package apis
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"net/http"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tools/router"
|
||||
)
|
||||
|
||||
// StaticWildcardParam is the name of Static handler wildcard parameter.
|
||||
const StaticWildcardParam = "path"
|
||||
|
||||
// NewRouter returns a new router instance loaded with the default app middlewares and api routes.
|
||||
func NewRouter(app core.App) (*router.Router[*core.RequestEvent], error) {
|
||||
pbRouter := router.NewRouter(func(w http.ResponseWriter, r *http.Request) (*core.RequestEvent, router.EventCleanupFunc) {
|
||||
event := new(core.RequestEvent)
|
||||
event.Response = w
|
||||
event.Request = r
|
||||
event.App = app
|
||||
|
||||
return event, nil
|
||||
})
|
||||
|
||||
// register default middlewares
|
||||
pbRouter.Bind(activityLogger())
|
||||
pbRouter.Bind(panicRecover())
|
||||
pbRouter.Bind(rateLimit())
|
||||
pbRouter.Bind(loadAuthToken())
|
||||
pbRouter.Bind(securityHeaders())
|
||||
pbRouter.Bind(BodyLimit(DefaultMaxBodySize))
|
||||
|
||||
apiGroup := pbRouter.Group("/api")
|
||||
bindSettingsApi(app, apiGroup)
|
||||
bindCollectionApi(app, apiGroup)
|
||||
bindRecordCrudApi(app, apiGroup)
|
||||
bindRecordAuthApi(app, apiGroup)
|
||||
bindLogsApi(app, apiGroup)
|
||||
bindBackupApi(app, apiGroup)
|
||||
bindCronApi(app, apiGroup)
|
||||
bindFileApi(app, apiGroup)
|
||||
bindBatchApi(app, apiGroup)
|
||||
bindRealtimeApi(app, apiGroup)
|
||||
bindHealthApi(app, apiGroup)
|
||||
|
||||
return pbRouter, nil
|
||||
}
|
||||
|
||||
// WrapStdHandler wraps Go [http.Handler] into a PocketBase handler func.
|
||||
func WrapStdHandler(h http.Handler) func(*core.RequestEvent) error {
|
||||
return func(e *core.RequestEvent) error {
|
||||
h.ServeHTTP(e.Response, e.Request)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// WrapStdMiddleware wraps Go [func(http.Handler) http.Handle] into a PocketBase middleware func.
|
||||
func WrapStdMiddleware(m func(http.Handler) http.Handler) func(*core.RequestEvent) error {
|
||||
return func(e *core.RequestEvent) (err error) {
|
||||
m(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
e.Response = w
|
||||
e.Request = r
|
||||
err = e.Next()
|
||||
})).ServeHTTP(e.Response, e.Request)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// MustSubFS returns an [fs.FS] corresponding to the subtree rooted at fsys's dir.
|
||||
//
|
||||
// This is similar to [fs.Sub] but panics on failure.
|
||||
func MustSubFS(fsys fs.FS, dir string) fs.FS {
|
||||
dir = filepath.ToSlash(filepath.Clean(dir)) // ToSlash in case of Windows path
|
||||
|
||||
sub, err := fs.Sub(fsys, dir)
|
||||
if err != nil {
|
||||
panic(fmt.Errorf("failed to create sub FS: %w", err))
|
||||
}
|
||||
|
||||
return sub
|
||||
}
|
||||
|
||||
// Static is a handler function to serve static directory content from fsys.
|
||||
//
|
||||
// If a file resource is missing and indexFallback is set, the request
|
||||
// will be forwarded to the base index.html (useful for SPA with pretty urls).
|
||||
//
|
||||
// NB! Expects the route to have a "{path...}" wildcard parameter.
|
||||
//
|
||||
// Special redirects:
|
||||
// - if "path" is a file that ends in index.html, it is redirected to its non-index.html version (eg. /test/index.html -> /test/)
|
||||
// - if "path" is a directory that has index.html, the index.html file is rendered,
|
||||
// otherwise if missing - returns 404 or fallback to the root index.html if indexFallback is set
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// fsys := os.DirFS("./pb_public")
|
||||
// router.GET("/files/{path...}", apis.Static(fsys, false))
|
||||
func Static(fsys fs.FS, indexFallback bool) func(*core.RequestEvent) error {
|
||||
if fsys == nil {
|
||||
panic("Static: the provided fs.FS argument is nil")
|
||||
}
|
||||
|
||||
return func(e *core.RequestEvent) error {
|
||||
// disable the activity logger to avoid flooding with messages
|
||||
//
|
||||
// note: errors are still logged
|
||||
if e.Get(requestEventKeySkipSuccessActivityLog) == nil {
|
||||
e.Set(requestEventKeySkipSuccessActivityLog, true)
|
||||
}
|
||||
|
||||
filename := e.Request.PathValue(StaticWildcardParam)
|
||||
filename = filepath.ToSlash(filepath.Clean(strings.TrimPrefix(filename, "/")))
|
||||
|
||||
// eagerly check for directory traversal
|
||||
//
|
||||
// note: this is just out of an abundance of caution because the fs.FS implementation could be non-std,
|
||||
// but usually shouldn't be necessary since os.DirFS.Open is expected to fail if the filename starts with dots
|
||||
if len(filename) > 2 && filename[0] == '.' && filename[1] == '.' && (filename[2] == '/' || filename[2] == '\\') {
|
||||
if indexFallback && filename != router.IndexPage {
|
||||
return e.FileFS(fsys, router.IndexPage)
|
||||
}
|
||||
return router.ErrFileNotFound
|
||||
}
|
||||
|
||||
fi, err := fs.Stat(fsys, filename)
|
||||
if err != nil {
|
||||
if indexFallback && filename != router.IndexPage {
|
||||
return e.FileFS(fsys, router.IndexPage)
|
||||
}
|
||||
return router.ErrFileNotFound
|
||||
}
|
||||
|
||||
if fi.IsDir() {
|
||||
// redirect to a canonical dir url, aka. with trailing slash
|
||||
if !strings.HasSuffix(e.Request.URL.Path, "/") {
|
||||
return e.Redirect(http.StatusMovedPermanently, safeRedirectPath(e.Request.URL.Path+"/"))
|
||||
}
|
||||
} else {
|
||||
urlPath := e.Request.URL.Path
|
||||
if strings.HasSuffix(urlPath, "/") {
|
||||
// redirect to a non-trailing slash file route
|
||||
urlPath = strings.TrimRight(urlPath, "/")
|
||||
if len(urlPath) > 0 {
|
||||
return e.Redirect(http.StatusMovedPermanently, safeRedirectPath(urlPath))
|
||||
}
|
||||
} else if stripped, ok := strings.CutSuffix(urlPath, router.IndexPage); ok {
|
||||
// redirect without the index.html
|
||||
return e.Redirect(http.StatusMovedPermanently, safeRedirectPath(stripped))
|
||||
}
|
||||
}
|
||||
|
||||
fileErr := e.FileFS(fsys, filename)
|
||||
|
||||
if fileErr != nil && indexFallback && filename != router.IndexPage && errors.Is(fileErr, router.ErrFileNotFound) {
|
||||
return e.FileFS(fsys, router.IndexPage)
|
||||
}
|
||||
|
||||
return fileErr
|
||||
}
|
||||
}
|
||||
|
||||
// safeRedirectPath normalizes the path string by replacing all beginning slashes
|
||||
// (`\\`, `//`, `\/`) with a single forward slash to prevent open redirect attacks
|
||||
func safeRedirectPath(path string) string {
|
||||
if len(path) > 1 && (path[0] == '\\' || path[0] == '/') && (path[1] == '\\' || path[1] == '/') {
|
||||
path = "/" + strings.TrimLeft(path, `/\`)
|
||||
}
|
||||
return path
|
||||
}
|
313
apis/base_test.go
Normal file
313
apis/base_test.go
Normal file
|
@ -0,0 +1,313 @@
|
|||
package apis_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/pocketbase/pocketbase/apis"
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tests"
|
||||
"github.com/pocketbase/pocketbase/tools/router"
|
||||
)
|
||||
|
||||
func TestWrapStdHandler(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
e := new(core.RequestEvent)
|
||||
e.App = app
|
||||
e.Request = req
|
||||
e.Response = rec
|
||||
|
||||
err := apis.WrapStdHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Write([]byte("test"))
|
||||
}))(e)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if body := rec.Body.String(); body != "test" {
|
||||
t.Fatalf("Expected body %q, got %q", "test", body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWrapStdMiddleware(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
e := new(core.RequestEvent)
|
||||
e.App = app
|
||||
e.Request = req
|
||||
e.Response = rec
|
||||
|
||||
err := apis.WrapStdMiddleware(func(h http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Write([]byte("test"))
|
||||
})
|
||||
})(e)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if body := rec.Body.String(); body != "test" {
|
||||
t.Fatalf("Expected body %q, got %q", "test", body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStatic(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
dir := createTestDir(t)
|
||||
defer os.RemoveAll(dir)
|
||||
|
||||
fsys := os.DirFS(filepath.Join(dir, "sub"))
|
||||
|
||||
type staticScenario struct {
|
||||
path string
|
||||
indexFallback bool
|
||||
expectedStatus int
|
||||
expectBody string
|
||||
expectError bool
|
||||
}
|
||||
|
||||
scenarios := []staticScenario{
|
||||
{
|
||||
path: "",
|
||||
indexFallback: false,
|
||||
expectedStatus: 200,
|
||||
expectBody: "sub index.html",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
path: "missing/a/b/c",
|
||||
indexFallback: false,
|
||||
expectedStatus: 404,
|
||||
expectBody: "",
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
path: "missing/a/b/c",
|
||||
indexFallback: true,
|
||||
expectedStatus: 200,
|
||||
expectBody: "sub index.html",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
path: "testroot", // parent directory file
|
||||
indexFallback: false,
|
||||
expectedStatus: 404,
|
||||
expectBody: "",
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
path: "test",
|
||||
indexFallback: false,
|
||||
expectedStatus: 200,
|
||||
expectBody: "sub test",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
path: "sub2",
|
||||
indexFallback: false,
|
||||
expectedStatus: 301,
|
||||
expectBody: "",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
path: "sub2/",
|
||||
indexFallback: false,
|
||||
expectedStatus: 200,
|
||||
expectBody: "sub2 index.html",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
path: "sub2/test",
|
||||
indexFallback: false,
|
||||
expectedStatus: 200,
|
||||
expectBody: "sub2 test",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
path: "sub2/test/",
|
||||
indexFallback: false,
|
||||
expectedStatus: 301,
|
||||
expectBody: "",
|
||||
expectError: false,
|
||||
},
|
||||
}
|
||||
|
||||
// extra directory traversal checks
|
||||
dtp := []string{
|
||||
"/../",
|
||||
"\\../",
|
||||
"../",
|
||||
"../../",
|
||||
"..\\",
|
||||
"..\\..\\",
|
||||
"../..\\",
|
||||
"..\\..//",
|
||||
`%2e%2e%2f`,
|
||||
`%2e%2e%2f%2e%2e%2f`,
|
||||
`%2e%2e/`,
|
||||
`%2e%2e/%2e%2e/`,
|
||||
`..%2f`,
|
||||
`..%2f..%2f`,
|
||||
`%2e%2e%5c`,
|
||||
`%2e%2e%5c%2e%2e%5c`,
|
||||
`%2e%2e\`,
|
||||
`%2e%2e\%2e%2e\`,
|
||||
`..%5c`,
|
||||
`..%5c..%5c`,
|
||||
`%252e%252e%255c`,
|
||||
`%252e%252e%255c%252e%252e%255c`,
|
||||
`..%255c`,
|
||||
`..%255c..%255c`,
|
||||
}
|
||||
for _, p := range dtp {
|
||||
scenarios = append(scenarios,
|
||||
staticScenario{
|
||||
path: p + "testroot",
|
||||
indexFallback: false,
|
||||
expectedStatus: 404,
|
||||
expectBody: "",
|
||||
expectError: true,
|
||||
},
|
||||
staticScenario{
|
||||
path: p + "testroot",
|
||||
indexFallback: true,
|
||||
expectedStatus: 200,
|
||||
expectBody: "sub index.html",
|
||||
expectError: false,
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
for i, s := range scenarios {
|
||||
t.Run(fmt.Sprintf("%d_%s_%v", i, s.path, s.indexFallback), func(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodGet, "/"+s.path, nil)
|
||||
req.SetPathValue(apis.StaticWildcardParam, s.path)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
e := new(core.RequestEvent)
|
||||
e.App = app
|
||||
e.Request = req
|
||||
e.Response = rec
|
||||
|
||||
err := apis.Static(fsys, s.indexFallback)(e)
|
||||
|
||||
hasErr := err != nil
|
||||
if hasErr != s.expectError {
|
||||
t.Fatalf("Expected hasErr %v, got %v (%v)", s.expectError, hasErr, err)
|
||||
}
|
||||
|
||||
body := rec.Body.String()
|
||||
if body != s.expectBody {
|
||||
t.Fatalf("Expected body %q, got %q", s.expectBody, body)
|
||||
}
|
||||
|
||||
if hasErr {
|
||||
apiErr := router.ToApiError(err)
|
||||
if apiErr.Status != s.expectedStatus {
|
||||
t.Fatalf("Expected status code %d, got %d", s.expectedStatus, apiErr.Status)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMustSubFS(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
dir := createTestDir(t)
|
||||
defer os.RemoveAll(dir)
|
||||
|
||||
// invalid path (no beginning and ending slashes)
|
||||
if !hasPanicked(func() {
|
||||
apis.MustSubFS(os.DirFS(dir), "/test/")
|
||||
}) {
|
||||
t.Fatalf("Expected to panic")
|
||||
}
|
||||
|
||||
// valid path
|
||||
if hasPanicked(func() {
|
||||
apis.MustSubFS(os.DirFS(dir), "./////a/b/c") // checks if ToSlash was called
|
||||
}) {
|
||||
t.Fatalf("Didn't expect to panic")
|
||||
}
|
||||
|
||||
// check sub content
|
||||
sub := apis.MustSubFS(os.DirFS(dir), "sub")
|
||||
|
||||
_, err := sub.Open("test")
|
||||
if err != nil {
|
||||
t.Fatalf("Missing expected file sub/test")
|
||||
}
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
|
||||
func hasPanicked(f func()) (didPanic bool) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
didPanic = true
|
||||
}
|
||||
}()
|
||||
f()
|
||||
return
|
||||
}
|
||||
|
||||
// note: make sure to call os.RemoveAll(dir) after you are done
|
||||
// working with the created test dir.
|
||||
func createTestDir(t *testing.T) string {
|
||||
dir, err := os.MkdirTemp(os.TempDir(), "test_dir")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := os.WriteFile(filepath.Join(dir, "index.html"), []byte("root index.html"), 0644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := os.WriteFile(filepath.Join(dir, "testroot"), []byte("root test"), 0644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := os.MkdirAll(filepath.Join(dir, "sub"), os.ModePerm); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := os.WriteFile(filepath.Join(dir, "sub/index.html"), []byte("sub index.html"), 0644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := os.WriteFile(filepath.Join(dir, "sub/test"), []byte("sub test"), 0644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := os.MkdirAll(filepath.Join(dir, "sub", "sub2"), os.ModePerm); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := os.WriteFile(filepath.Join(dir, "sub/sub2/index.html"), []byte("sub2 index.html"), 0644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := os.WriteFile(filepath.Join(dir, "sub/sub2/test"), []byte("sub2 test"), 0644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
return dir
|
||||
}
|
548
apis/batch.go
Normal file
548
apis/batch.go
Normal file
|
@ -0,0 +1,548 @@
|
|||
package apis
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"regexp"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
validation "github.com/go-ozzo/ozzo-validation/v4"
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tools/filesystem"
|
||||
"github.com/pocketbase/pocketbase/tools/router"
|
||||
"github.com/pocketbase/pocketbase/tools/types"
|
||||
"github.com/spf13/cast"
|
||||
)
|
||||
|
||||
func bindBatchApi(app core.App, rg *router.RouterGroup[*core.RequestEvent]) {
|
||||
sub := rg.Group("/batch")
|
||||
sub.POST("", batchTransaction).Unbind(DefaultBodyLimitMiddlewareId) // the body limit is inlined
|
||||
}
|
||||
|
||||
type HandleFunc func(e *core.RequestEvent) error
|
||||
|
||||
type BatchActionHandlerFunc func(app core.App, ir *core.InternalRequest, params map[string]string, next func(data any) error) HandleFunc
|
||||
|
||||
// ValidBatchActions defines a map with the supported batch InternalRequest actions.
|
||||
//
|
||||
// Note: when adding new routes make sure that their middlewares are inlined!
|
||||
var ValidBatchActions = map[*regexp.Regexp]BatchActionHandlerFunc{
|
||||
// "upsert" handler
|
||||
regexp.MustCompile(`^PUT /api/collections/(?P<collection>[^\/\?]+)/records(?P<query>\?.*)?$`): func(app core.App, ir *core.InternalRequest, params map[string]string, next func(any) error) HandleFunc {
|
||||
var id string
|
||||
if len(ir.Body) > 0 && ir.Body["id"] != "" {
|
||||
id = cast.ToString(ir.Body["id"])
|
||||
}
|
||||
if id != "" {
|
||||
_, err := app.FindRecordById(params["collection"], id)
|
||||
if err == nil {
|
||||
// update
|
||||
// ---
|
||||
params["id"] = id // required for the path value
|
||||
ir.Method = "PATCH"
|
||||
ir.URL = "/api/collections/" + params["collection"] + "/records/" + id + params["query"]
|
||||
return recordUpdate(false, next)
|
||||
}
|
||||
}
|
||||
|
||||
// create
|
||||
// ---
|
||||
ir.Method = "POST"
|
||||
ir.URL = "/api/collections/" + params["collection"] + "/records" + params["query"]
|
||||
return recordCreate(false, next)
|
||||
},
|
||||
regexp.MustCompile(`^POST /api/collections/(?P<collection>[^\/\?]+)/records(\?.*)?$`): func(app core.App, ir *core.InternalRequest, params map[string]string, next func(any) error) HandleFunc {
|
||||
return recordCreate(false, next)
|
||||
},
|
||||
regexp.MustCompile(`^PATCH /api/collections/(?P<collection>[^\/\?]+)/records/(?P<id>[^\/\?]+)(\?.*)?$`): func(app core.App, ir *core.InternalRequest, params map[string]string, next func(any) error) HandleFunc {
|
||||
return recordUpdate(false, next)
|
||||
},
|
||||
regexp.MustCompile(`^DELETE /api/collections/(?P<collection>[^\/\?]+)/records/(?P<id>[^\/\?]+)(\?.*)?$`): func(app core.App, ir *core.InternalRequest, params map[string]string, next func(any) error) HandleFunc {
|
||||
return recordDelete(false, next)
|
||||
},
|
||||
}
|
||||
|
||||
type BatchRequestResult struct {
|
||||
Body any `json:"body"`
|
||||
Status int `json:"status"`
|
||||
}
|
||||
|
||||
type batchRequestsForm struct {
|
||||
Requests []*core.InternalRequest `form:"requests" json:"requests"`
|
||||
|
||||
max int
|
||||
}
|
||||
|
||||
func (brs batchRequestsForm) validate() error {
|
||||
return validation.ValidateStruct(&brs,
|
||||
validation.Field(&brs.Requests, validation.Required, validation.Length(0, brs.max)),
|
||||
)
|
||||
}
|
||||
|
||||
// NB! When the request is submitted as multipart/form-data,
|
||||
// the regular fields data is expected to be submitted as serailized
|
||||
// json under the @jsonPayload field and file keys need to follow the
|
||||
// pattern "requests.N.fileField" or requests[N].fileField.
|
||||
func batchTransaction(e *core.RequestEvent) error {
|
||||
maxRequests := e.App.Settings().Batch.MaxRequests
|
||||
if !e.App.Settings().Batch.Enabled || maxRequests <= 0 {
|
||||
return e.ForbiddenError("Batch requests are not allowed.", nil)
|
||||
}
|
||||
|
||||
txTimeout := time.Duration(e.App.Settings().Batch.Timeout) * time.Second
|
||||
if txTimeout <= 0 {
|
||||
txTimeout = 3 * time.Second // for now always limit
|
||||
}
|
||||
|
||||
maxBodySize := e.App.Settings().Batch.MaxBodySize
|
||||
if maxBodySize <= 0 {
|
||||
maxBodySize = 128 << 20
|
||||
}
|
||||
|
||||
err := applyBodyLimit(e, maxBodySize)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
form := &batchRequestsForm{max: maxRequests}
|
||||
|
||||
// load base requests data
|
||||
err = e.BindBody(form)
|
||||
if err != nil {
|
||||
return e.BadRequestError("Failed to read the submitted batch data.", err)
|
||||
}
|
||||
|
||||
// load uploaded files into each request item
|
||||
// note: expects the files to be under "requests.N.fileField" or "requests[N].fileField" format
|
||||
// (the other regular fields must be put under `@jsonPayload` as serialized json)
|
||||
if strings.HasPrefix(e.Request.Header.Get("Content-Type"), "multipart/form-data") {
|
||||
for i, ir := range form.Requests {
|
||||
iStr := strconv.Itoa(i)
|
||||
|
||||
files, err := extractPrefixedFiles(e.Request, "requests."+iStr+".", "requests["+iStr+"].")
|
||||
if err != nil {
|
||||
return e.BadRequestError("Failed to read the submitted batch files data.", err)
|
||||
}
|
||||
|
||||
for key, files := range files {
|
||||
if ir.Body == nil {
|
||||
ir.Body = map[string]any{}
|
||||
}
|
||||
ir.Body[key] = files
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// validate batch request form
|
||||
err = form.validate()
|
||||
if err != nil {
|
||||
return e.BadRequestError("Invalid batch request data.", err)
|
||||
}
|
||||
|
||||
event := new(core.BatchRequestEvent)
|
||||
event.RequestEvent = e
|
||||
event.Batch = form.Requests
|
||||
|
||||
return e.App.OnBatchRequest().Trigger(event, func(e *core.BatchRequestEvent) error {
|
||||
bp := batchProcessor{
|
||||
app: e.App,
|
||||
baseEvent: e.RequestEvent,
|
||||
infoContext: core.RequestInfoContextBatch,
|
||||
}
|
||||
|
||||
if err := bp.Process(e.Batch, txTimeout); err != nil {
|
||||
return firstApiError(err, e.BadRequestError("Batch transaction failed.", err))
|
||||
}
|
||||
|
||||
return e.JSON(http.StatusOK, bp.results)
|
||||
})
|
||||
}
|
||||
|
||||
type batchProcessor struct {
|
||||
app core.App
|
||||
baseEvent *core.RequestEvent
|
||||
infoContext string
|
||||
results []*BatchRequestResult
|
||||
failedIndex int
|
||||
errCh chan error
|
||||
stopCh chan struct{}
|
||||
}
|
||||
|
||||
func (p *batchProcessor) Process(batch []*core.InternalRequest, timeout time.Duration) error {
|
||||
p.results = make([]*BatchRequestResult, 0, len(batch))
|
||||
|
||||
if p.stopCh != nil {
|
||||
close(p.stopCh)
|
||||
}
|
||||
p.stopCh = make(chan struct{}, 1)
|
||||
|
||||
if p.errCh != nil {
|
||||
close(p.errCh)
|
||||
}
|
||||
p.errCh = make(chan error, 1)
|
||||
|
||||
return p.app.RunInTransaction(func(txApp core.App) error {
|
||||
// used to interupts the recursive processing calls in case of a timeout or connection close
|
||||
defer func() {
|
||||
p.stopCh <- struct{}{}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
err := p.process(txApp, batch, 0)
|
||||
|
||||
if err != nil {
|
||||
err = validation.Errors{
|
||||
"requests": validation.Errors{
|
||||
strconv.Itoa(p.failedIndex): &BatchResponseError{
|
||||
code: "batch_request_failed",
|
||||
message: "Batch request failed.",
|
||||
err: router.ToApiError(err),
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// note: to avoid copying and due to the process recursion the final results order is reversed
|
||||
if err == nil {
|
||||
slices.Reverse(p.results)
|
||||
}
|
||||
|
||||
p.errCh <- err
|
||||
}()
|
||||
|
||||
select {
|
||||
case responseErr := <-p.errCh:
|
||||
return responseErr
|
||||
case <-time.After(timeout):
|
||||
// note: we don't return 408 Reques Timeout error because
|
||||
// some browsers perform automatic retry behind the scenes
|
||||
// which are hard to debug and unnecessary
|
||||
return errors.New("batch transaction timeout")
|
||||
case <-p.baseEvent.Request.Context().Done():
|
||||
return errors.New("batch request interrupted")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func (p *batchProcessor) process(activeApp core.App, batch []*core.InternalRequest, i int) error {
|
||||
select {
|
||||
case <-p.stopCh:
|
||||
return nil
|
||||
default:
|
||||
if len(batch) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
result, err := processInternalRequest(
|
||||
activeApp,
|
||||
p.baseEvent,
|
||||
batch[0],
|
||||
p.infoContext,
|
||||
func(_ any) error {
|
||||
if len(batch) == 1 {
|
||||
return nil
|
||||
}
|
||||
|
||||
err := p.process(activeApp, batch[1:], i+1)
|
||||
|
||||
// update the failed batch index (if not already)
|
||||
if err != nil && p.failedIndex == 0 {
|
||||
p.failedIndex = i + 1
|
||||
}
|
||||
|
||||
return err
|
||||
},
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
p.results = append(p.results, result)
|
||||
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func processInternalRequest(
|
||||
activeApp core.App,
|
||||
baseEvent *core.RequestEvent,
|
||||
ir *core.InternalRequest,
|
||||
infoContext string,
|
||||
optNext func(data any) error,
|
||||
) (*BatchRequestResult, error) {
|
||||
handle, params, ok := prepareInternalAction(activeApp, ir, optNext)
|
||||
if !ok {
|
||||
return nil, errors.New("unknown batch request action")
|
||||
}
|
||||
|
||||
// construct a new http.Request
|
||||
// ---------------------------------------------------------------
|
||||
buf, mw, err := multipartDataFromInternalRequest(ir)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
r, err := http.NewRequest(strings.ToUpper(ir.Method), ir.URL, buf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// cleanup multipart temp files
|
||||
defer func() {
|
||||
if r.MultipartForm != nil {
|
||||
if err := r.MultipartForm.RemoveAll(); err != nil {
|
||||
activeApp.Logger().Warn("failed to cleanup temp batch files", "error", err)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// load batch request path params
|
||||
// ---
|
||||
for k, v := range params {
|
||||
r.SetPathValue(k, v)
|
||||
}
|
||||
|
||||
// clone original request
|
||||
// ---
|
||||
r.RequestURI = r.URL.RequestURI()
|
||||
r.Proto = baseEvent.Request.Proto
|
||||
r.ProtoMajor = baseEvent.Request.ProtoMajor
|
||||
r.ProtoMinor = baseEvent.Request.ProtoMinor
|
||||
r.Host = baseEvent.Request.Host
|
||||
r.RemoteAddr = baseEvent.Request.RemoteAddr
|
||||
r.TLS = baseEvent.Request.TLS
|
||||
|
||||
if s := baseEvent.Request.TransferEncoding; s != nil {
|
||||
s2 := make([]string, len(s))
|
||||
copy(s2, s)
|
||||
r.TransferEncoding = s2
|
||||
}
|
||||
|
||||
if baseEvent.Request.Trailer != nil {
|
||||
r.Trailer = baseEvent.Request.Trailer.Clone()
|
||||
}
|
||||
|
||||
if baseEvent.Request.Header != nil {
|
||||
r.Header = baseEvent.Request.Header.Clone()
|
||||
}
|
||||
|
||||
// apply batch request specific headers
|
||||
// ---
|
||||
for k, v := range ir.Headers {
|
||||
// individual Authorization header keys don't have affect
|
||||
// because the auth state is populated from the base event
|
||||
if strings.EqualFold(k, "authorization") {
|
||||
continue
|
||||
}
|
||||
|
||||
r.Header.Set(k, v)
|
||||
}
|
||||
r.Header.Set("Content-Type", mw.FormDataContentType())
|
||||
|
||||
// construct a new RequestEvent
|
||||
// ---------------------------------------------------------------
|
||||
event := &core.RequestEvent{}
|
||||
event.App = activeApp
|
||||
event.Auth = baseEvent.Auth
|
||||
event.SetAll(baseEvent.GetAll())
|
||||
|
||||
// load RequestInfo context
|
||||
if infoContext == "" {
|
||||
infoContext = core.RequestInfoContextDefault
|
||||
}
|
||||
event.Set(core.RequestEventKeyInfoContext, infoContext)
|
||||
|
||||
// assign request
|
||||
event.Request = r
|
||||
event.Request.Body = &router.RereadableReadCloser{ReadCloser: r.Body} // enables multiple reads
|
||||
|
||||
// assign response
|
||||
rec := httptest.NewRecorder()
|
||||
event.Response = &router.ResponseWriter{ResponseWriter: rec} // enables status and write tracking
|
||||
|
||||
// execute
|
||||
// ---------------------------------------------------------------
|
||||
if err := handle(event); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result := rec.Result()
|
||||
defer result.Body.Close()
|
||||
|
||||
body, _ := types.ParseJSONRaw(rec.Body.Bytes())
|
||||
|
||||
return &BatchRequestResult{
|
||||
Status: result.StatusCode,
|
||||
Body: body,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func multipartDataFromInternalRequest(ir *core.InternalRequest) (*bytes.Buffer, *multipart.Writer, error) {
|
||||
buf := &bytes.Buffer{}
|
||||
|
||||
mw := multipart.NewWriter(buf)
|
||||
|
||||
regularFields := map[string]any{}
|
||||
fileFields := map[string][]*filesystem.File{}
|
||||
|
||||
// separate regular fields from files
|
||||
// ---
|
||||
for k, rawV := range ir.Body {
|
||||
switch v := rawV.(type) {
|
||||
case *filesystem.File:
|
||||
fileFields[k] = append(fileFields[k], v)
|
||||
case []*filesystem.File:
|
||||
fileFields[k] = append(fileFields[k], v...)
|
||||
default:
|
||||
regularFields[k] = v
|
||||
}
|
||||
}
|
||||
|
||||
// submit regularFields as @jsonPayload
|
||||
// ---
|
||||
rawBody, err := json.Marshal(regularFields)
|
||||
if err != nil {
|
||||
return nil, nil, errors.Join(err, mw.Close())
|
||||
}
|
||||
|
||||
jsonPayload, err := mw.CreateFormField("@jsonPayload")
|
||||
if err != nil {
|
||||
return nil, nil, errors.Join(err, mw.Close())
|
||||
}
|
||||
_, err = jsonPayload.Write(rawBody)
|
||||
if err != nil {
|
||||
return nil, nil, errors.Join(err, mw.Close())
|
||||
}
|
||||
|
||||
// submit fileFields as multipart files
|
||||
// ---
|
||||
for key, files := range fileFields {
|
||||
for _, file := range files {
|
||||
part, err := mw.CreateFormFile(key, file.Name)
|
||||
if err != nil {
|
||||
return nil, nil, errors.Join(err, mw.Close())
|
||||
}
|
||||
|
||||
fr, err := file.Reader.Open()
|
||||
if err != nil {
|
||||
return nil, nil, errors.Join(err, mw.Close())
|
||||
}
|
||||
|
||||
_, err = io.Copy(part, fr)
|
||||
if err != nil {
|
||||
return nil, nil, errors.Join(err, fr.Close(), mw.Close())
|
||||
}
|
||||
|
||||
err = fr.Close()
|
||||
if err != nil {
|
||||
return nil, nil, errors.Join(err, mw.Close())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return buf, mw, mw.Close()
|
||||
}
|
||||
|
||||
func extractPrefixedFiles(request *http.Request, prefixes ...string) (map[string][]*filesystem.File, error) {
|
||||
if request.MultipartForm == nil {
|
||||
if err := request.ParseMultipartForm(router.DefaultMaxMemory); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
result := make(map[string][]*filesystem.File)
|
||||
|
||||
for k, fhs := range request.MultipartForm.File {
|
||||
for _, p := range prefixes {
|
||||
if strings.HasPrefix(k, p) {
|
||||
resultKey := strings.TrimPrefix(k, p)
|
||||
|
||||
for _, fh := range fhs {
|
||||
file, err := filesystem.NewFileFromMultipart(fh)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result[resultKey] = append(result[resultKey], file)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func prepareInternalAction(activeApp core.App, ir *core.InternalRequest, optNext func(data any) error) (HandleFunc, map[string]string, bool) {
|
||||
full := strings.ToUpper(ir.Method) + " " + ir.URL
|
||||
|
||||
for re, actionFactory := range ValidBatchActions {
|
||||
params, ok := findNamedMatches(re, full)
|
||||
if ok {
|
||||
return actionFactory(activeApp, ir, params, optNext), params, true
|
||||
}
|
||||
}
|
||||
|
||||
return nil, nil, false
|
||||
}
|
||||
|
||||
func findNamedMatches(re *regexp.Regexp, str string) (map[string]string, bool) {
|
||||
match := re.FindStringSubmatch(str)
|
||||
if match == nil {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
result := map[string]string{}
|
||||
|
||||
names := re.SubexpNames()
|
||||
|
||||
for i, m := range match {
|
||||
if names[i] != "" {
|
||||
result[names[i]] = m
|
||||
}
|
||||
}
|
||||
|
||||
return result, true
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
|
||||
var (
|
||||
_ router.SafeErrorItem = (*BatchResponseError)(nil)
|
||||
_ router.SafeErrorResolver = (*BatchResponseError)(nil)
|
||||
)
|
||||
|
||||
type BatchResponseError struct {
|
||||
err *router.ApiError
|
||||
code string
|
||||
message string
|
||||
}
|
||||
|
||||
func (e *BatchResponseError) Error() string {
|
||||
return e.message
|
||||
}
|
||||
|
||||
func (e *BatchResponseError) Code() string {
|
||||
return e.code
|
||||
}
|
||||
|
||||
func (e *BatchResponseError) Resolve(errData map[string]any) any {
|
||||
errData["response"] = e.err
|
||||
return errData
|
||||
}
|
||||
|
||||
func (e BatchResponseError) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(map[string]any{
|
||||
"message": e.message,
|
||||
"code": e.code,
|
||||
"response": e.err,
|
||||
})
|
||||
}
|
691
apis/batch_test.go
Normal file
691
apis/batch_test.go
Normal file
|
@ -0,0 +1,691 @@
|
|||
package apis_test
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tests"
|
||||
"github.com/pocketbase/pocketbase/tools/router"
|
||||
)
|
||||
|
||||
func TestBatchRequest(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
formData, mp, err := tests.MockMultipartData(
|
||||
map[string]string{
|
||||
router.JSONPayloadKey: `{
|
||||
"requests":[
|
||||
{"method":"POST", "url":"/api/collections/demo3/records", "body": {"title": "batch1"}},
|
||||
{"method":"POST", "url":"/api/collections/demo3/records", "body": {"title": "batch2"}},
|
||||
{"method":"POST", "url":"/api/collections/demo3/records", "body": {"title": "batch3"}},
|
||||
{"method":"PATCH", "url":"/api/collections/demo3/records/lcl9d87w22ml6jy", "body": {"files-": "test_FLurQTgrY8.txt"}}
|
||||
]
|
||||
}`,
|
||||
},
|
||||
"requests.0.files",
|
||||
"requests.0.files",
|
||||
"requests.0.files",
|
||||
"requests[2].files",
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "disabled batch requets",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/batch",
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
app.Settings().Batch.Enabled = false
|
||||
},
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "max request limits reached",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/batch",
|
||||
Body: strings.NewReader(`{
|
||||
"requests": [
|
||||
{"method":"GET", "url":"/test1"},
|
||||
{"method":"GET", "url":"/test2"},
|
||||
{"method":"GET", "url":"/test3"}
|
||||
]
|
||||
}`),
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
app.Settings().Batch.Enabled = true
|
||||
app.Settings().Batch.MaxRequests = 2
|
||||
},
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{
|
||||
`"data":{`,
|
||||
`"requests":{"code":"validation_length_too_long"`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "trigger requests validations",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/batch",
|
||||
Body: strings.NewReader(`{
|
||||
"requests": [
|
||||
{},
|
||||
{"method":"GET", "url":"/valid"},
|
||||
{"method":"invalid", "url":"/valid"},
|
||||
{"method":"POST", "url":"` + strings.Repeat("a", 2001) + `"}
|
||||
]
|
||||
}`),
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
app.Settings().Batch.Enabled = true
|
||||
app.Settings().Batch.MaxRequests = 100
|
||||
},
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{
|
||||
`"data":{`,
|
||||
`"requests":{`,
|
||||
`"0":{"method":{"code":"validation_required"`,
|
||||
`"2":{"method":{"code":"validation_in_invalid"`,
|
||||
`"3":{"url":{"code":"validation_length_too_long"`,
|
||||
},
|
||||
NotExpectedContent: []string{
|
||||
`"1":`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "unknown batch request action",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/batch",
|
||||
Body: strings.NewReader(`{
|
||||
"requests": [
|
||||
{"method":"GET", "url":"/api/health"}
|
||||
]
|
||||
}`),
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{
|
||||
`"data":{`,
|
||||
`"requests":{`,
|
||||
`0":{"code":"batch_request_failed"`,
|
||||
`"response":{`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnBatchRequest": 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "base 2 successful and 1 failed (public collection)",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/batch",
|
||||
Body: strings.NewReader(`{
|
||||
"requests": [
|
||||
{"method":"POST", "url":"/api/collections/demo2/records", "body": {"title": "batch1"}},
|
||||
{"method":"POST", "url":"/api/collections/demo2/records", "body": {"title": "batch2"}},
|
||||
{"method":"POST", "url":"/api/collections/demo2/records", "body": {"title": ""}}
|
||||
]
|
||||
}`),
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{
|
||||
`"data":{`,
|
||||
`"response":{`,
|
||||
`"2":{"code":"batch_request_failed"`,
|
||||
`"response":{"data":{"title":{"code":"validation_required"`,
|
||||
},
|
||||
NotExpectedContent: []string{
|
||||
`"0":`,
|
||||
`"1":`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnBatchRequest": 1,
|
||||
"OnRecordCreateRequest": 3,
|
||||
"OnModelCreate": 3,
|
||||
"OnModelCreateExecute": 2,
|
||||
"OnModelAfterCreateError": 3,
|
||||
"OnModelValidate": 3,
|
||||
"OnRecordCreate": 3,
|
||||
"OnRecordCreateExecute": 2,
|
||||
"OnRecordAfterCreateError": 3,
|
||||
"OnRecordValidate": 3,
|
||||
"OnRecordEnrich": 2,
|
||||
},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
records, err := app.FindRecordsByFilter("demo2", `title~"batch"`, "", 0, 0)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if len(records) != 0 {
|
||||
t.Fatalf("Expected no batch records to be persisted, got %d", len(records))
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "base 4 successful (public collection)",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/batch",
|
||||
Body: strings.NewReader(`{
|
||||
"requests": [
|
||||
{"method":"POST", "url":"/api/collections/demo2/records", "body": {"title": "batch1"}},
|
||||
{"method":"POST", "url":"/api/collections/demo2/records", "body": {"title": "batch2"}},
|
||||
{"method":"PUT", "url":"/api/collections/demo2/records", "body": {"title": "batch3"}},
|
||||
{"method":"PUT", "url":"/api/collections/demo2/records?fields=*,id:excerpt(4,true)", "body": {"id":"achvryl401bhse3","title": "batch4"}}
|
||||
]
|
||||
}`),
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
`"title":"batch1"`,
|
||||
`"title":"batch2"`,
|
||||
`"title":"batch3"`,
|
||||
`"title":"batch4"`,
|
||||
`"id":"achv..."`,
|
||||
`"active":false`,
|
||||
`"active":true`,
|
||||
`"status":200`,
|
||||
`"body":{`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnBatchRequest": 1,
|
||||
"OnModelValidate": 4,
|
||||
"OnRecordValidate": 4,
|
||||
"OnRecordEnrich": 4,
|
||||
|
||||
"OnRecordCreateRequest": 3,
|
||||
"OnModelCreate": 3,
|
||||
"OnModelCreateExecute": 3,
|
||||
"OnModelAfterCreateSuccess": 3,
|
||||
"OnRecordCreate": 3,
|
||||
"OnRecordCreateExecute": 3,
|
||||
"OnRecordAfterCreateSuccess": 3,
|
||||
|
||||
"OnRecordUpdateRequest": 1,
|
||||
"OnModelUpdate": 1,
|
||||
"OnModelUpdateExecute": 1,
|
||||
"OnModelAfterUpdateSuccess": 1,
|
||||
"OnRecordUpdate": 1,
|
||||
"OnRecordUpdateExecute": 1,
|
||||
"OnRecordAfterUpdateSuccess": 1,
|
||||
},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
records, err := app.FindRecordsByFilter("demo2", `title~"batch"`, "", 0, 0)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if len(records) != 4 {
|
||||
t.Fatalf("Expected %d batch records to be persisted, got %d", 3, len(records))
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "mixed create/update/delete (rules failure)",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/batch",
|
||||
Body: strings.NewReader(`{
|
||||
"requests": [
|
||||
{"method":"POST", "url":"/api/collections/demo2/records", "body": {"title": "batch_create"}},
|
||||
{"method":"DELETE", "url":"/api/collections/demo2/records/achvryl401bhse3"},
|
||||
{"method":"PATCH", "url":"/api/collections/demo3/records/1tmknxy2868d869", "body": {"title": "batch_update"}}
|
||||
]
|
||||
}`),
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{
|
||||
`"data":{`,
|
||||
`"requests":{`,
|
||||
`"2":{"code":"batch_request_failed"`,
|
||||
`"response":{`,
|
||||
},
|
||||
NotExpectedContent: []string{
|
||||
// only demo3 requires authentication
|
||||
`"0":`,
|
||||
`"1":`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnBatchRequest": 1,
|
||||
"OnModelCreate": 1,
|
||||
"OnModelCreateExecute": 1,
|
||||
"OnModelAfterCreateError": 1,
|
||||
"OnModelDelete": 1,
|
||||
"OnModelDeleteExecute": 1,
|
||||
"OnModelAfterDeleteError": 1,
|
||||
"OnModelValidate": 1,
|
||||
"OnRecordCreateRequest": 1,
|
||||
"OnRecordCreate": 1,
|
||||
"OnRecordCreateExecute": 1,
|
||||
"OnRecordAfterCreateError": 1,
|
||||
"OnRecordDeleteRequest": 1,
|
||||
"OnRecordDelete": 1,
|
||||
"OnRecordDeleteExecute": 1,
|
||||
"OnRecordAfterDeleteError": 1,
|
||||
"OnRecordEnrich": 1,
|
||||
"OnRecordValidate": 1,
|
||||
},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
_, err := app.FindFirstRecordByFilter("demo2", `title="batch_create"`)
|
||||
if err == nil {
|
||||
t.Fatal("Expected record to not be created")
|
||||
}
|
||||
|
||||
_, err = app.FindFirstRecordByFilter("demo3", `title="batch_update"`)
|
||||
if err == nil {
|
||||
t.Fatal("Expected record to not be updated")
|
||||
}
|
||||
|
||||
_, err = app.FindRecordById("demo2", "achvryl401bhse3")
|
||||
if err != nil {
|
||||
t.Fatal("Expected record to not be deleted")
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "mixed create/update/delete (rules success)",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/batch",
|
||||
Headers: map[string]string{
|
||||
// test@example.com, clients
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6ImdrMzkwcWVnczR5NDd3biIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoidjg1MXE0cjc5MHJoa25sIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.0ONnm_BsvPRZyDNT31GN1CKUB6uQRxvVvQ-Wc9AZfG0",
|
||||
},
|
||||
Body: strings.NewReader(`{
|
||||
"requests": [
|
||||
{"method":"POST", "url":"/api/collections/demo2/records", "body": {"title": "batch_create"}, "headers": {"Authorization": "ignored"}},
|
||||
{"method":"DELETE", "url":"/api/collections/demo2/records/achvryl401bhse3", "headers": {"Authorization": "ignored"}},
|
||||
{"method":"PATCH", "url":"/api/collections/demo3/records/1tmknxy2868d869", "body": {"title": "batch_update"}, "headers": {"Authorization": "ignored"}}
|
||||
]
|
||||
}`),
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
`"title":"batch_create"`,
|
||||
`"title":"batch_update"`,
|
||||
`"status":200`,
|
||||
`"status":204`,
|
||||
`"body":{`,
|
||||
`"body":null`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnBatchRequest": 1,
|
||||
// ---
|
||||
"OnModelCreate": 1,
|
||||
"OnModelCreateExecute": 1,
|
||||
"OnModelAfterCreateSuccess": 1,
|
||||
"OnModelDelete": 1,
|
||||
"OnModelDeleteExecute": 1,
|
||||
"OnModelAfterDeleteSuccess": 1,
|
||||
"OnModelUpdate": 1,
|
||||
"OnModelUpdateExecute": 1,
|
||||
"OnModelAfterUpdateSuccess": 1,
|
||||
"OnModelValidate": 2,
|
||||
// ---
|
||||
"OnRecordCreateRequest": 1,
|
||||
"OnRecordCreate": 1,
|
||||
"OnRecordCreateExecute": 1,
|
||||
"OnRecordAfterCreateSuccess": 1,
|
||||
"OnRecordDeleteRequest": 1,
|
||||
"OnRecordDelete": 1,
|
||||
"OnRecordDeleteExecute": 1,
|
||||
"OnRecordAfterDeleteSuccess": 1,
|
||||
"OnRecordUpdateRequest": 1,
|
||||
"OnRecordUpdate": 1,
|
||||
"OnRecordUpdateExecute": 1,
|
||||
"OnRecordAfterUpdateSuccess": 1,
|
||||
"OnRecordValidate": 2,
|
||||
"OnRecordEnrich": 2,
|
||||
},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
_, err := app.FindFirstRecordByFilter("demo2", `title="batch_create"`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
_, err = app.FindFirstRecordByFilter("demo3", `title="batch_update"`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
_, err = app.FindRecordById("demo2", "achvryl401bhse3")
|
||||
if err == nil {
|
||||
t.Fatal("Expected record to be deleted")
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "mixed create/update/delete (superuser auth)",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/batch",
|
||||
Headers: map[string]string{
|
||||
// test@example.com, _superusers
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
|
||||
},
|
||||
Body: strings.NewReader(`{
|
||||
"requests": [
|
||||
{"method":"POST", "url":"/api/collections/demo2/records", "body": {"title": "batch_create"}},
|
||||
{"method":"DELETE", "url":"/api/collections/demo2/records/achvryl401bhse3"},
|
||||
{"method":"PATCH", "url":"/api/collections/demo3/records/1tmknxy2868d869", "body": {"title": "batch_update"}}
|
||||
]
|
||||
}`),
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
`"title":"batch_create"`,
|
||||
`"title":"batch_update"`,
|
||||
`"status":200`,
|
||||
`"status":204`,
|
||||
`"body":{`,
|
||||
`"body":null`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnBatchRequest": 1,
|
||||
// ---
|
||||
"OnModelCreate": 1,
|
||||
"OnModelCreateExecute": 1,
|
||||
"OnModelAfterCreateSuccess": 1,
|
||||
"OnModelDelete": 1,
|
||||
"OnModelDeleteExecute": 1,
|
||||
"OnModelAfterDeleteSuccess": 1,
|
||||
"OnModelUpdate": 1,
|
||||
"OnModelUpdateExecute": 1,
|
||||
"OnModelAfterUpdateSuccess": 1,
|
||||
"OnModelValidate": 2,
|
||||
// ---
|
||||
"OnRecordCreateRequest": 1,
|
||||
"OnRecordCreate": 1,
|
||||
"OnRecordCreateExecute": 1,
|
||||
"OnRecordAfterCreateSuccess": 1,
|
||||
"OnRecordDeleteRequest": 1,
|
||||
"OnRecordDelete": 1,
|
||||
"OnRecordDeleteExecute": 1,
|
||||
"OnRecordAfterDeleteSuccess": 1,
|
||||
"OnRecordUpdateRequest": 1,
|
||||
"OnRecordUpdate": 1,
|
||||
"OnRecordUpdateExecute": 1,
|
||||
"OnRecordAfterUpdateSuccess": 1,
|
||||
"OnRecordValidate": 2,
|
||||
"OnRecordEnrich": 2,
|
||||
},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
_, err := app.FindFirstRecordByFilter("demo2", `title="batch_create"`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
_, err = app.FindFirstRecordByFilter("demo3", `title="batch_update"`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
_, err = app.FindRecordById("demo2", "achvryl401bhse3")
|
||||
if err == nil {
|
||||
t.Fatal("Expected record to be deleted")
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "cascade delete/update",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/batch",
|
||||
Headers: map[string]string{
|
||||
// test@example.com, _superusers
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
|
||||
},
|
||||
Body: strings.NewReader(`{
|
||||
"requests": [
|
||||
{"method":"DELETE", "url":"/api/collections/demo3/records/1tmknxy2868d869"},
|
||||
{"method":"DELETE", "url":"/api/collections/demo3/records/mk5fmymtx4wsprk"}
|
||||
]
|
||||
}`),
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
`"status":204`,
|
||||
`"body":null`,
|
||||
},
|
||||
NotExpectedContent: []string{
|
||||
`"status":200`,
|
||||
`"body":{`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnBatchRequest": 1,
|
||||
// ---
|
||||
"OnModelDelete": 3, // 2 batch + 1 cascade delete
|
||||
"OnModelDeleteExecute": 3,
|
||||
"OnModelAfterDeleteSuccess": 3,
|
||||
"OnModelUpdate": 5, // 5 cascade update
|
||||
"OnModelUpdateExecute": 5,
|
||||
"OnModelAfterUpdateSuccess": 5,
|
||||
// ---
|
||||
"OnRecordDeleteRequest": 2,
|
||||
"OnRecordDelete": 3,
|
||||
"OnRecordDeleteExecute": 3,
|
||||
"OnRecordAfterDeleteSuccess": 3,
|
||||
"OnRecordUpdate": 5,
|
||||
"OnRecordUpdateExecute": 5,
|
||||
"OnRecordAfterUpdateSuccess": 5,
|
||||
},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
ids := []string{
|
||||
"1tmknxy2868d869",
|
||||
"mk5fmymtx4wsprk",
|
||||
"qzaqccwrmva4o1n",
|
||||
}
|
||||
|
||||
for _, id := range ids {
|
||||
_, err := app.FindRecordById("demo2", id)
|
||||
if err == nil {
|
||||
t.Fatalf("Expected record %q to be deleted", id)
|
||||
}
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "transaction timeout",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/batch",
|
||||
Body: strings.NewReader(`{
|
||||
"requests": [
|
||||
{"method":"POST", "url":"/api/collections/demo2/records", "body": {"title": "batch1"}},
|
||||
{"method":"POST", "url":"/api/collections/demo2/records", "body": {"title": "batch2"}}
|
||||
]
|
||||
}`),
|
||||
Headers: map[string]string{
|
||||
// test@example.com, _superusers
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
app.Settings().Batch.Timeout = 1
|
||||
app.OnRecordCreateRequest("demo2").BindFunc(func(e *core.RecordRequestEvent) error {
|
||||
time.Sleep(600 * time.Millisecond) // < 1s so that the first request can succeed
|
||||
return e.Next()
|
||||
})
|
||||
},
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{
|
||||
`"data":{}`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnBatchRequest": 1,
|
||||
"OnRecordCreateRequest": 2,
|
||||
"OnModelCreate": 1,
|
||||
"OnModelCreateExecute": 1,
|
||||
"OnModelAfterCreateError": 1,
|
||||
"OnModelValidate": 1,
|
||||
"OnRecordCreate": 1,
|
||||
"OnRecordCreateExecute": 1,
|
||||
"OnRecordAfterCreateError": 1,
|
||||
"OnRecordEnrich": 1,
|
||||
"OnRecordValidate": 1,
|
||||
},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
records, err := app.FindRecordsByFilter("demo2", `title~"batch"`, "", 0, 0)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if len(records) != 0 {
|
||||
t.Fatalf("Expected %d batch records to be persisted, got %d", 0, len(records))
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "multipart/form-data + file upload",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/batch",
|
||||
Body: formData,
|
||||
Headers: map[string]string{
|
||||
// test@example.com, clients
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6ImdrMzkwcWVnczR5NDd3biIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoidjg1MXE0cjc5MHJoa25sIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.0ONnm_BsvPRZyDNT31GN1CKUB6uQRxvVvQ-Wc9AZfG0",
|
||||
"Content-Type": mp.FormDataContentType(),
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
`"title":"batch1"`,
|
||||
`"title":"batch2"`,
|
||||
`"title":"batch3"`,
|
||||
`"id":"lcl9d87w22ml6jy"`,
|
||||
`"files":["300_UhLKX91HVb.png"]`,
|
||||
`"tmpfile_`,
|
||||
`"status":200`,
|
||||
`"body":{`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnBatchRequest": 1,
|
||||
// ---
|
||||
"OnModelCreate": 3,
|
||||
"OnModelCreateExecute": 3,
|
||||
"OnModelAfterCreateSuccess": 3,
|
||||
"OnModelUpdate": 1,
|
||||
"OnModelUpdateExecute": 1,
|
||||
"OnModelAfterUpdateSuccess": 1,
|
||||
"OnModelValidate": 4,
|
||||
// ---
|
||||
"OnRecordCreateRequest": 3,
|
||||
"OnRecordUpdateRequest": 1,
|
||||
"OnRecordCreate": 3,
|
||||
"OnRecordCreateExecute": 3,
|
||||
"OnRecordAfterCreateSuccess": 3,
|
||||
"OnRecordUpdate": 1,
|
||||
"OnRecordUpdateExecute": 1,
|
||||
"OnRecordAfterUpdateSuccess": 1,
|
||||
"OnRecordValidate": 4,
|
||||
"OnRecordEnrich": 4,
|
||||
},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
batch1, err := app.FindFirstRecordByFilter("demo3", `title="batch1"`)
|
||||
if err != nil {
|
||||
t.Fatalf("missing batch1: %v", err)
|
||||
}
|
||||
batch1Files := batch1.GetStringSlice("files")
|
||||
if len(batch1Files) != 3 {
|
||||
t.Fatalf("Expected %d batch1 file(s), got %d", 3, len(batch1Files))
|
||||
}
|
||||
|
||||
batch2, err := app.FindFirstRecordByFilter("demo3", `title="batch2"`)
|
||||
if err != nil {
|
||||
t.Fatalf("missing batch2: %v", err)
|
||||
}
|
||||
batch2Files := batch2.GetStringSlice("files")
|
||||
if len(batch2Files) != 0 {
|
||||
t.Fatalf("Expected %d batch2 file(s), got %d", 0, len(batch2Files))
|
||||
}
|
||||
|
||||
batch3, err := app.FindFirstRecordByFilter("demo3", `title="batch3"`)
|
||||
if err != nil {
|
||||
t.Fatalf("missing batch3: %v", err)
|
||||
}
|
||||
batch3Files := batch3.GetStringSlice("files")
|
||||
if len(batch3Files) != 1 {
|
||||
t.Fatalf("Expected %d batch3 file(s), got %d", 1, len(batch3Files))
|
||||
}
|
||||
|
||||
batch4, err := app.FindRecordById("demo3", "lcl9d87w22ml6jy")
|
||||
if err != nil {
|
||||
t.Fatalf("missing batch4: %v", err)
|
||||
}
|
||||
batch4Files := batch4.GetStringSlice("files")
|
||||
if len(batch4Files) != 1 {
|
||||
t.Fatalf("Expected %d batch4 file(s), got %d", 1, len(batch4Files))
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "create/update with expand query params",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/batch",
|
||||
Headers: map[string]string{
|
||||
// test@example.com, _superusers
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
|
||||
},
|
||||
Body: strings.NewReader(`{
|
||||
"requests": [
|
||||
{"method":"POST", "url":"/api/collections/demo5/records?expand=rel_one", "body": {"total": 9, "rel_one":"qzaqccwrmva4o1n"}},
|
||||
{"method":"PATCH", "url":"/api/collections/demo5/records/qjeql998mtp1azp?expand=rel_many", "body": {"total": 10}}
|
||||
]
|
||||
}`),
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
`"body":{`,
|
||||
`"id":"qjeql998mtp1azp"`,
|
||||
`"id":"qzaqccwrmva4o1n"`,
|
||||
`"id":"i9naidtvr6qsgb4"`,
|
||||
`"expand":{"rel_one"`,
|
||||
`"expand":{"rel_many"`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnBatchRequest": 1,
|
||||
// ---
|
||||
"OnModelCreate": 1,
|
||||
"OnModelCreateExecute": 1,
|
||||
"OnModelAfterCreateSuccess": 1,
|
||||
"OnModelUpdate": 1,
|
||||
"OnModelUpdateExecute": 1,
|
||||
"OnModelAfterUpdateSuccess": 1,
|
||||
"OnModelValidate": 2,
|
||||
// ---
|
||||
"OnRecordCreateRequest": 1,
|
||||
"OnRecordUpdateRequest": 1,
|
||||
"OnRecordCreate": 1,
|
||||
"OnRecordCreateExecute": 1,
|
||||
"OnRecordAfterCreateSuccess": 1,
|
||||
"OnRecordUpdate": 1,
|
||||
"OnRecordUpdateExecute": 1,
|
||||
"OnRecordAfterUpdateSuccess": 1,
|
||||
"OnRecordValidate": 2,
|
||||
"OnRecordEnrich": 5,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "check body limit middleware",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/batch",
|
||||
Headers: map[string]string{
|
||||
// test@example.com, _superusers
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
|
||||
},
|
||||
Body: strings.NewReader(`{
|
||||
"requests": [
|
||||
{"method":"POST", "url":"/api/collections/demo5/records?expand=rel_one", "body": {"total": 9, "rel_one":"qzaqccwrmva4o1n"}},
|
||||
{"method":"PATCH", "url":"/api/collections/demo5/records/qjeql998mtp1azp?expand=rel_many", "body": {"total": 10}}
|
||||
]
|
||||
}`),
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
app.Settings().Batch.MaxBodySize = 10
|
||||
},
|
||||
ExpectedStatus: 413,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
206
apis/collection.go
Normal file
206
apis/collection.go
Normal file
|
@ -0,0 +1,206 @@
|
|||
package apis
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
validation "github.com/go-ozzo/ozzo-validation/v4"
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tools/router"
|
||||
"github.com/pocketbase/pocketbase/tools/search"
|
||||
)
|
||||
|
||||
// bindCollectionApi registers the collection api endpoints and the corresponding handlers.
|
||||
func bindCollectionApi(app core.App, rg *router.RouterGroup[*core.RequestEvent]) {
|
||||
subGroup := rg.Group("/collections").Bind(RequireSuperuserAuth())
|
||||
subGroup.GET("", collectionsList)
|
||||
subGroup.POST("", collectionCreate)
|
||||
subGroup.GET("/{collection}", collectionView)
|
||||
subGroup.PATCH("/{collection}", collectionUpdate)
|
||||
subGroup.DELETE("/{collection}", collectionDelete)
|
||||
subGroup.DELETE("/{collection}/truncate", collectionTruncate)
|
||||
subGroup.PUT("/import", collectionsImport)
|
||||
subGroup.GET("/meta/scaffolds", collectionScaffolds)
|
||||
}
|
||||
|
||||
func collectionsList(e *core.RequestEvent) error {
|
||||
fieldResolver := search.NewSimpleFieldResolver(
|
||||
"id", "created", "updated", "name", "system", "type",
|
||||
)
|
||||
|
||||
collections := []*core.Collection{}
|
||||
|
||||
result, err := search.NewProvider(fieldResolver).
|
||||
Query(e.App.CollectionQuery()).
|
||||
ParseAndExec(e.Request.URL.Query().Encode(), &collections)
|
||||
|
||||
if err != nil {
|
||||
return e.BadRequestError("", err)
|
||||
}
|
||||
|
||||
event := new(core.CollectionsListRequestEvent)
|
||||
event.RequestEvent = e
|
||||
event.Collections = collections
|
||||
event.Result = result
|
||||
|
||||
return event.App.OnCollectionsListRequest().Trigger(event, func(e *core.CollectionsListRequestEvent) error {
|
||||
return execAfterSuccessTx(true, e.App, func() error {
|
||||
return e.JSON(http.StatusOK, e.Result)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func collectionView(e *core.RequestEvent) error {
|
||||
collection, err := e.App.FindCachedCollectionByNameOrId(e.Request.PathValue("collection"))
|
||||
if err != nil || collection == nil {
|
||||
return e.NotFoundError("", err)
|
||||
}
|
||||
|
||||
event := new(core.CollectionRequestEvent)
|
||||
event.RequestEvent = e
|
||||
event.Collection = collection
|
||||
|
||||
return e.App.OnCollectionViewRequest().Trigger(event, func(e *core.CollectionRequestEvent) error {
|
||||
return execAfterSuccessTx(true, e.App, func() error {
|
||||
return e.JSON(http.StatusOK, e.Collection)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func collectionCreate(e *core.RequestEvent) error {
|
||||
// populate the minimal required factory collection data (if any)
|
||||
factoryExtract := struct {
|
||||
Type string `form:"type" json:"type"`
|
||||
Name string `form:"name" json:"name"`
|
||||
}{}
|
||||
if err := e.BindBody(&factoryExtract); err != nil {
|
||||
return e.BadRequestError("Failed to load the collection type data due to invalid formatting.", err)
|
||||
}
|
||||
|
||||
// create scaffold
|
||||
collection := core.NewCollection(factoryExtract.Type, factoryExtract.Name)
|
||||
|
||||
// merge the scaffold with the submitted request data
|
||||
if err := e.BindBody(collection); err != nil {
|
||||
return e.BadRequestError("Failed to load the submitted data due to invalid formatting.", err)
|
||||
}
|
||||
|
||||
event := new(core.CollectionRequestEvent)
|
||||
event.RequestEvent = e
|
||||
event.Collection = collection
|
||||
|
||||
return e.App.OnCollectionCreateRequest().Trigger(event, func(e *core.CollectionRequestEvent) error {
|
||||
if err := e.App.Save(e.Collection); err != nil {
|
||||
// validation failure
|
||||
var validationErrors validation.Errors
|
||||
if errors.As(err, &validationErrors) {
|
||||
return e.BadRequestError("Failed to create collection.", validationErrors)
|
||||
}
|
||||
|
||||
// other generic db error
|
||||
return e.BadRequestError("Failed to create collection. Raw error: \n"+err.Error(), nil)
|
||||
}
|
||||
|
||||
return execAfterSuccessTx(true, e.App, func() error {
|
||||
return e.JSON(http.StatusOK, e.Collection)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func collectionUpdate(e *core.RequestEvent) error {
|
||||
collection, err := e.App.FindCollectionByNameOrId(e.Request.PathValue("collection"))
|
||||
if err != nil || collection == nil {
|
||||
return e.NotFoundError("", err)
|
||||
}
|
||||
|
||||
if err := e.BindBody(collection); err != nil {
|
||||
return e.BadRequestError("Failed to load the submitted data due to invalid formatting.", err)
|
||||
}
|
||||
|
||||
event := new(core.CollectionRequestEvent)
|
||||
event.RequestEvent = e
|
||||
event.Collection = collection
|
||||
|
||||
return event.App.OnCollectionUpdateRequest().Trigger(event, func(e *core.CollectionRequestEvent) error {
|
||||
if err := e.App.Save(e.Collection); err != nil {
|
||||
// validation failure
|
||||
var validationErrors validation.Errors
|
||||
if errors.As(err, &validationErrors) {
|
||||
return e.BadRequestError("Failed to update collection.", validationErrors)
|
||||
}
|
||||
|
||||
// other generic db error
|
||||
return e.BadRequestError("Failed to update collection. Raw error: \n"+err.Error(), nil)
|
||||
}
|
||||
|
||||
return execAfterSuccessTx(true, e.App, func() error {
|
||||
return e.JSON(http.StatusOK, e.Collection)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func collectionDelete(e *core.RequestEvent) error {
|
||||
collection, err := e.App.FindCachedCollectionByNameOrId(e.Request.PathValue("collection"))
|
||||
if err != nil || collection == nil {
|
||||
return e.NotFoundError("", err)
|
||||
}
|
||||
|
||||
event := new(core.CollectionRequestEvent)
|
||||
event.RequestEvent = e
|
||||
event.Collection = collection
|
||||
|
||||
return e.App.OnCollectionDeleteRequest().Trigger(event, func(e *core.CollectionRequestEvent) error {
|
||||
if err := e.App.Delete(e.Collection); err != nil {
|
||||
msg := "Failed to delete collection"
|
||||
|
||||
// check fo references
|
||||
refs, _ := e.App.FindCollectionReferences(e.Collection, e.Collection.Id)
|
||||
if len(refs) > 0 {
|
||||
names := make([]string, 0, len(refs))
|
||||
for ref := range refs {
|
||||
names = append(names, ref.Name)
|
||||
}
|
||||
msg += " probably due to existing reference in " + strings.Join(names, ", ")
|
||||
}
|
||||
|
||||
return e.BadRequestError(msg, err)
|
||||
}
|
||||
|
||||
return execAfterSuccessTx(true, e.App, func() error {
|
||||
return e.NoContent(http.StatusNoContent)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func collectionTruncate(e *core.RequestEvent) error {
|
||||
collection, err := e.App.FindCachedCollectionByNameOrId(e.Request.PathValue("collection"))
|
||||
if err != nil || collection == nil {
|
||||
return e.NotFoundError("", err)
|
||||
}
|
||||
|
||||
if collection.IsView() {
|
||||
return e.BadRequestError("View collections cannot be truncated since they don't store their own records.", nil)
|
||||
}
|
||||
|
||||
err = e.App.TruncateCollection(collection)
|
||||
if err != nil {
|
||||
return e.BadRequestError("Failed to truncate collection (most likely due to required cascade delete record references).", err)
|
||||
}
|
||||
|
||||
return e.NoContent(http.StatusNoContent)
|
||||
}
|
||||
|
||||
func collectionScaffolds(e *core.RequestEvent) error {
|
||||
collections := map[string]*core.Collection{
|
||||
core.CollectionTypeBase: core.NewBaseCollection(""),
|
||||
core.CollectionTypeAuth: core.NewAuthCollection(""),
|
||||
core.CollectionTypeView: core.NewViewCollection(""),
|
||||
}
|
||||
|
||||
for _, c := range collections {
|
||||
c.Id = "" // clear autogenerated id
|
||||
}
|
||||
|
||||
return e.JSON(http.StatusOK, collections)
|
||||
}
|
62
apis/collection_import.go
Normal file
62
apis/collection_import.go
Normal file
|
@ -0,0 +1,62 @@
|
|||
package apis
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
|
||||
validation "github.com/go-ozzo/ozzo-validation/v4"
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
)
|
||||
|
||||
func collectionsImport(e *core.RequestEvent) error {
|
||||
form := new(collectionsImportForm)
|
||||
|
||||
err := e.BindBody(form)
|
||||
if err != nil {
|
||||
return firstApiError(err, e.BadRequestError("An error occurred while loading the submitted data.", err))
|
||||
}
|
||||
|
||||
err = form.validate()
|
||||
if err != nil {
|
||||
return firstApiError(err, e.BadRequestError("An error occurred while validating the submitted data.", err))
|
||||
}
|
||||
|
||||
event := new(core.CollectionsImportRequestEvent)
|
||||
event.RequestEvent = e
|
||||
event.CollectionsData = form.Collections
|
||||
event.DeleteMissing = form.DeleteMissing
|
||||
|
||||
return event.App.OnCollectionsImportRequest().Trigger(event, func(e *core.CollectionsImportRequestEvent) error {
|
||||
importErr := e.App.ImportCollections(e.CollectionsData, form.DeleteMissing)
|
||||
if importErr == nil {
|
||||
return execAfterSuccessTx(true, e.App, func() error {
|
||||
return e.NoContent(http.StatusNoContent)
|
||||
})
|
||||
}
|
||||
|
||||
// validation failure
|
||||
var validationErrors validation.Errors
|
||||
if errors.As(importErr, &validationErrors) {
|
||||
return e.BadRequestError("Failed to import collections.", validationErrors)
|
||||
}
|
||||
|
||||
// generic/db failure
|
||||
return e.BadRequestError("Failed to import collections.", validation.Errors{"collections": validation.NewError(
|
||||
"validation_collections_import_failure",
|
||||
"Failed to import the collections configuration. Raw error:\n"+importErr.Error(),
|
||||
)})
|
||||
})
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
|
||||
type collectionsImportForm struct {
|
||||
Collections []map[string]any `form:"collections" json:"collections"`
|
||||
DeleteMissing bool `form:"deleteMissing" json:"deleteMissing"`
|
||||
}
|
||||
|
||||
func (form *collectionsImportForm) validate() error {
|
||||
return validation.ValidateStruct(form,
|
||||
validation.Field(&form.Collections, validation.Required),
|
||||
)
|
||||
}
|
369
apis/collection_import_test.go
Normal file
369
apis/collection_import_test.go
Normal file
|
@ -0,0 +1,369 @@
|
|||
package apis_test
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tests"
|
||||
)
|
||||
|
||||
func TestCollectionsImport(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
totalCollections := 16
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "unauthorized",
|
||||
Method: http.MethodPut,
|
||||
URL: "/api/collections/import",
|
||||
ExpectedStatus: 401,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "authorized as regular user",
|
||||
Method: http.MethodPut,
|
||||
URL: "/api/collections/import",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "authorized as superuser + empty collections",
|
||||
Method: http.MethodPut,
|
||||
URL: "/api/collections/import",
|
||||
Body: strings.NewReader(`{"collections":[]}`),
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
|
||||
},
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{
|
||||
`"data":{`,
|
||||
`"collections":{"code":"validation_required"`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
collections := []*core.Collection{}
|
||||
if err := app.CollectionQuery().All(&collections); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
expected := totalCollections
|
||||
if len(collections) != expected {
|
||||
t.Fatalf("Expected %d collections, got %d", expected, len(collections))
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "authorized as superuser + collections validator failure",
|
||||
Method: http.MethodPut,
|
||||
URL: "/api/collections/import",
|
||||
Body: strings.NewReader(`{
|
||||
"collections":[
|
||||
{"name": "import1"},
|
||||
{
|
||||
"name": "import2",
|
||||
"fields": [
|
||||
{
|
||||
"id": "koih1lqx",
|
||||
"name": "expand",
|
||||
"type": "text"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}`),
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
|
||||
},
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{
|
||||
`"data":{`,
|
||||
`"collections":{"code":"validation_collections_import_failure"`,
|
||||
`import2`,
|
||||
`fields`,
|
||||
},
|
||||
NotExpectedContent: []string{"Raw error:"},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnCollectionsImportRequest": 1,
|
||||
"OnCollectionCreate": 2,
|
||||
"OnCollectionCreateExecute": 2,
|
||||
"OnCollectionAfterCreateError": 2,
|
||||
"OnModelCreate": 2,
|
||||
"OnModelCreateExecute": 2,
|
||||
"OnModelAfterCreateError": 2,
|
||||
},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
collections := []*core.Collection{}
|
||||
if err := app.CollectionQuery().All(&collections); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
expected := totalCollections
|
||||
if len(collections) != expected {
|
||||
t.Fatalf("Expected %d collections, got %d", expected, len(collections))
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "authorized as superuser + non-validator failure",
|
||||
Method: http.MethodPut,
|
||||
URL: "/api/collections/import",
|
||||
Body: strings.NewReader(`{
|
||||
"collections":[
|
||||
{
|
||||
"name": "import1",
|
||||
"fields": [
|
||||
{
|
||||
"id": "koih1lqx",
|
||||
"name": "test",
|
||||
"type": "text"
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "import2",
|
||||
"fields": [
|
||||
{
|
||||
"id": "koih1lqx",
|
||||
"name": "test",
|
||||
"type": "text"
|
||||
}
|
||||
],
|
||||
"indexes": [
|
||||
"create index idx_test on import2 (test)"
|
||||
]
|
||||
}
|
||||
]
|
||||
}`),
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
|
||||
},
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{
|
||||
`"data":{`,
|
||||
`"collections":{"code":"validation_collections_import_failure"`,
|
||||
`Raw error:`,
|
||||
`custom_error`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnCollectionsImportRequest": 1,
|
||||
"OnCollectionCreate": 1,
|
||||
"OnCollectionAfterCreateError": 1,
|
||||
"OnModelCreate": 1,
|
||||
"OnModelAfterCreateError": 1,
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
app.OnCollectionCreate().BindFunc(func(e *core.CollectionEvent) error {
|
||||
return errors.New("custom_error")
|
||||
})
|
||||
},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
collections := []*core.Collection{}
|
||||
if err := app.CollectionQuery().All(&collections); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
expected := totalCollections
|
||||
if len(collections) != expected {
|
||||
t.Fatalf("Expected %d collections, got %d", expected, len(collections))
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "authorized as superuser + successful collections create",
|
||||
Method: http.MethodPut,
|
||||
URL: "/api/collections/import",
|
||||
Body: strings.NewReader(`{
|
||||
"collections":[
|
||||
{
|
||||
"name": "import1",
|
||||
"fields": [
|
||||
{
|
||||
"id": "koih1lqx",
|
||||
"name": "test",
|
||||
"type": "text"
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "import2",
|
||||
"fields": [
|
||||
{
|
||||
"id": "koih1lqx",
|
||||
"name": "test",
|
||||
"type": "text"
|
||||
}
|
||||
],
|
||||
"indexes": [
|
||||
"create index idx_test on import2 (test)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "auth_without_fields",
|
||||
"type": "auth"
|
||||
}
|
||||
]
|
||||
}`),
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
|
||||
},
|
||||
ExpectedStatus: 204,
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnCollectionsImportRequest": 1,
|
||||
"OnCollectionCreate": 3,
|
||||
"OnCollectionCreateExecute": 3,
|
||||
"OnCollectionAfterCreateSuccess": 3,
|
||||
"OnModelCreate": 3,
|
||||
"OnModelCreateExecute": 3,
|
||||
"OnModelAfterCreateSuccess": 3,
|
||||
},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
collections := []*core.Collection{}
|
||||
if err := app.CollectionQuery().All(&collections); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
expected := totalCollections + 3
|
||||
if len(collections) != expected {
|
||||
t.Fatalf("Expected %d collections, got %d", expected, len(collections))
|
||||
}
|
||||
|
||||
indexes, err := app.TableIndexes("import2")
|
||||
if err != nil || indexes["idx_test"] == "" {
|
||||
t.Fatalf("Missing index %s (%v)", "idx_test", err)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "authorized as superuser + create/update/delete",
|
||||
Method: http.MethodPut,
|
||||
URL: "/api/collections/import",
|
||||
Body: strings.NewReader(`{
|
||||
"deleteMissing": true,
|
||||
"collections":[
|
||||
{"name": "test123"},
|
||||
{
|
||||
"id":"wsmn24bux7wo113",
|
||||
"name":"demo1",
|
||||
"fields":[
|
||||
{
|
||||
"id":"_2hlxbmp",
|
||||
"name":"title",
|
||||
"type":"text",
|
||||
"required":true
|
||||
}
|
||||
],
|
||||
"indexes": []
|
||||
}
|
||||
]
|
||||
}`),
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
|
||||
},
|
||||
ExpectedStatus: 204,
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnCollectionsImportRequest": 1,
|
||||
// ---
|
||||
"OnModelCreate": 1,
|
||||
"OnModelCreateExecute": 1,
|
||||
"OnModelAfterCreateSuccess": 1,
|
||||
"OnCollectionCreate": 1,
|
||||
"OnCollectionCreateExecute": 1,
|
||||
"OnCollectionAfterCreateSuccess": 1,
|
||||
// ---
|
||||
"OnModelUpdate": 1,
|
||||
"OnModelUpdateExecute": 1,
|
||||
"OnModelAfterUpdateSuccess": 1,
|
||||
"OnCollectionUpdate": 1,
|
||||
"OnCollectionUpdateExecute": 1,
|
||||
"OnCollectionAfterUpdateSuccess": 1,
|
||||
// ---
|
||||
"OnModelDelete": 14,
|
||||
"OnModelAfterDeleteSuccess": 14,
|
||||
"OnModelDeleteExecute": 14,
|
||||
"OnCollectionDelete": 9,
|
||||
"OnCollectionDeleteExecute": 9,
|
||||
"OnCollectionAfterDeleteSuccess": 9,
|
||||
"OnRecordAfterDeleteSuccess": 5,
|
||||
"OnRecordDelete": 5,
|
||||
"OnRecordDeleteExecute": 5,
|
||||
},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
collections := []*core.Collection{}
|
||||
if err := app.CollectionQuery().All(&collections); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
systemCollections := 0
|
||||
for _, c := range collections {
|
||||
if c.System {
|
||||
systemCollections++
|
||||
}
|
||||
}
|
||||
|
||||
expected := systemCollections + 2
|
||||
if len(collections) != expected {
|
||||
t.Fatalf("Expected %d collections, got %d", expected, len(collections))
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "OnCollectionsImportRequest tx body write check",
|
||||
Method: http.MethodPut,
|
||||
URL: "/api/collections/import",
|
||||
Body: strings.NewReader(`{
|
||||
"deleteMissing": true,
|
||||
"collections":[
|
||||
{"name": "test123"},
|
||||
{
|
||||
"id":"wsmn24bux7wo113",
|
||||
"name":"demo1",
|
||||
"fields":[
|
||||
{
|
||||
"id":"_2hlxbmp",
|
||||
"name":"title",
|
||||
"type":"text",
|
||||
"required":true
|
||||
}
|
||||
],
|
||||
"indexes": []
|
||||
}
|
||||
]
|
||||
}`),
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
app.OnCollectionsImportRequest().BindFunc(func(e *core.CollectionsImportRequestEvent) error {
|
||||
original := e.App
|
||||
return e.App.RunInTransaction(func(txApp core.App) error {
|
||||
e.App = txApp
|
||||
defer func() { e.App = original }()
|
||||
|
||||
if err := e.Next(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return e.BadRequestError("TX_ERROR", nil)
|
||||
})
|
||||
})
|
||||
},
|
||||
ExpectedStatus: 400,
|
||||
ExpectedEvents: map[string]int{"OnCollectionsImportRequest": 1},
|
||||
ExpectedContent: []string{"TX_ERROR"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
1586
apis/collection_test.go
Normal file
1586
apis/collection_test.go
Normal file
File diff suppressed because it is too large
Load diff
59
apis/cron.go
Normal file
59
apis/cron.go
Normal file
|
@ -0,0 +1,59 @@
|
|||
package apis
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tools/cron"
|
||||
"github.com/pocketbase/pocketbase/tools/router"
|
||||
"github.com/pocketbase/pocketbase/tools/routine"
|
||||
)
|
||||
|
||||
// bindCronApi registers the crons api endpoint.
|
||||
func bindCronApi(app core.App, rg *router.RouterGroup[*core.RequestEvent]) {
|
||||
subGroup := rg.Group("/crons").Bind(RequireSuperuserAuth())
|
||||
subGroup.GET("", cronsList)
|
||||
subGroup.POST("/{id}", cronRun)
|
||||
}
|
||||
|
||||
func cronsList(e *core.RequestEvent) error {
|
||||
jobs := e.App.Cron().Jobs()
|
||||
|
||||
slices.SortStableFunc(jobs, func(a, b *cron.Job) int {
|
||||
if strings.HasPrefix(a.Id(), "__pb") {
|
||||
return 1
|
||||
}
|
||||
if strings.HasPrefix(b.Id(), "__pb") {
|
||||
return -1
|
||||
}
|
||||
return strings.Compare(a.Id(), b.Id())
|
||||
})
|
||||
|
||||
return e.JSON(http.StatusOK, jobs)
|
||||
}
|
||||
|
||||
func cronRun(e *core.RequestEvent) error {
|
||||
cronId := e.Request.PathValue("id")
|
||||
|
||||
var foundJob *cron.Job
|
||||
|
||||
jobs := e.App.Cron().Jobs()
|
||||
for _, j := range jobs {
|
||||
if j.Id() == cronId {
|
||||
foundJob = j
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if foundJob == nil {
|
||||
return e.NotFoundError("Missing or invalid cron job", nil)
|
||||
}
|
||||
|
||||
routine.FireAndForget(func() {
|
||||
foundJob.Run()
|
||||
})
|
||||
|
||||
return e.NoContent(http.StatusNoContent)
|
||||
}
|
149
apis/cron_test.go
Normal file
149
apis/cron_test.go
Normal file
|
@ -0,0 +1,149 @@
|
|||
package apis_test
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tests"
|
||||
"github.com/spf13/cast"
|
||||
)
|
||||
|
||||
func TestCronsList(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "unauthorized",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/crons",
|
||||
ExpectedStatus: 401,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "authorized as regular user",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/crons",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "authorized as superuser (empty list)",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/crons",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
app.Cron().RemoveAll()
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{`[]`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "authorized as superuser",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/crons",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
`{"id":"__pbLogsCleanup__","expression":"0 */6 * * *"}`,
|
||||
`{"id":"__pbDBOptimize__","expression":"0 0 * * *"}`,
|
||||
`{"id":"__pbMFACleanup__","expression":"0 * * * *"}`,
|
||||
`{"id":"__pbOTPCleanup__","expression":"0 * * * *"}`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCronsRun(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
beforeTestFunc := func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
app.Cron().Add("test", "* * * * *", func() {
|
||||
app.Store().Set("testJobCalls", cast.ToInt(app.Store().Get("testJobCalls"))+1)
|
||||
})
|
||||
}
|
||||
|
||||
expectedCalls := func(expected int) func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
return func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
total := cast.ToInt(app.Store().Get("testJobCalls"))
|
||||
if total != expected {
|
||||
t.Fatalf("Expected total testJobCalls %d, got %d", expected, total)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "unauthorized",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/crons/test",
|
||||
Delay: 50 * time.Millisecond,
|
||||
BeforeTestFunc: beforeTestFunc,
|
||||
AfterTestFunc: expectedCalls(0),
|
||||
ExpectedStatus: 401,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "authorized as regular user",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/crons/test",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
Delay: 50 * time.Millisecond,
|
||||
BeforeTestFunc: beforeTestFunc,
|
||||
AfterTestFunc: expectedCalls(0),
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "authorized as superuser (missing job)",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/crons/missing",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
|
||||
},
|
||||
Delay: 50 * time.Millisecond,
|
||||
BeforeTestFunc: beforeTestFunc,
|
||||
AfterTestFunc: expectedCalls(0),
|
||||
ExpectedStatus: 404,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "authorized as superuser (existing job)",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/crons/test",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
|
||||
},
|
||||
Delay: 50 * time.Millisecond,
|
||||
BeforeTestFunc: beforeTestFunc,
|
||||
AfterTestFunc: expectedCalls(1),
|
||||
ExpectedStatus: 204,
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
230
apis/file.go
Normal file
230
apis/file.go
Normal file
|
@ -0,0 +1,230 @@
|
|||
package apis
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"os"
|
||||
"runtime"
|
||||
"time"
|
||||
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tools/filesystem"
|
||||
"github.com/pocketbase/pocketbase/tools/list"
|
||||
"github.com/pocketbase/pocketbase/tools/router"
|
||||
"github.com/spf13/cast"
|
||||
"golang.org/x/sync/semaphore"
|
||||
"golang.org/x/sync/singleflight"
|
||||
)
|
||||
|
||||
var imageContentTypes = []string{"image/png", "image/jpg", "image/jpeg", "image/gif", "image/webp"}
|
||||
var defaultThumbSizes = []string{"100x100"}
|
||||
|
||||
// bindFileApi registers the file api endpoints and the corresponding handlers.
|
||||
func bindFileApi(app core.App, rg *router.RouterGroup[*core.RequestEvent]) {
|
||||
maxWorkers := cast.ToInt64(os.Getenv("PB_THUMBS_MAX_WORKERS"))
|
||||
if maxWorkers <= 0 {
|
||||
maxWorkers = int64(runtime.NumCPU() + 2) // the value is arbitrary chosen and may change in the future
|
||||
}
|
||||
|
||||
maxWait := cast.ToInt64(os.Getenv("PB_THUMBS_MAX_WAIT"))
|
||||
if maxWait <= 0 {
|
||||
maxWait = 60
|
||||
}
|
||||
|
||||
api := fileApi{
|
||||
thumbGenPending: new(singleflight.Group),
|
||||
thumbGenSem: semaphore.NewWeighted(maxWorkers),
|
||||
thumbGenMaxWait: time.Duration(maxWait) * time.Second,
|
||||
}
|
||||
|
||||
sub := rg.Group("/files")
|
||||
sub.POST("/token", api.fileToken).Bind(RequireAuth())
|
||||
sub.GET("/{collection}/{recordId}/{filename}", api.download).Bind(collectionPathRateLimit("", "file"))
|
||||
}
|
||||
|
||||
type fileApi struct {
|
||||
// thumbGenSem is a semaphore to prevent too much concurrent
|
||||
// requests generating new thumbs at the same time.
|
||||
thumbGenSem *semaphore.Weighted
|
||||
|
||||
// thumbGenPending represents a group of currently pending
|
||||
// thumb generation processes.
|
||||
thumbGenPending *singleflight.Group
|
||||
|
||||
// thumbGenMaxWait is the maximum waiting time for starting a new
|
||||
// thumb generation process.
|
||||
thumbGenMaxWait time.Duration
|
||||
}
|
||||
|
||||
func (api *fileApi) fileToken(e *core.RequestEvent) error {
|
||||
if e.Auth == nil {
|
||||
return e.UnauthorizedError("Missing auth context.", nil)
|
||||
}
|
||||
|
||||
token, err := e.Auth.NewFileToken()
|
||||
if err != nil {
|
||||
return e.InternalServerError("Failed to generate file token", err)
|
||||
}
|
||||
|
||||
event := new(core.FileTokenRequestEvent)
|
||||
event.RequestEvent = e
|
||||
event.Token = token
|
||||
event.Record = e.Auth
|
||||
|
||||
return e.App.OnFileTokenRequest().Trigger(event, func(e *core.FileTokenRequestEvent) error {
|
||||
return execAfterSuccessTx(true, e.App, func() error {
|
||||
return e.JSON(http.StatusOK, map[string]string{"token": e.Token})
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func (api *fileApi) download(e *core.RequestEvent) error {
|
||||
collection, err := e.App.FindCachedCollectionByNameOrId(e.Request.PathValue("collection"))
|
||||
if err != nil {
|
||||
return e.NotFoundError("", nil)
|
||||
}
|
||||
|
||||
recordId := e.Request.PathValue("recordId")
|
||||
if recordId == "" {
|
||||
return e.NotFoundError("", nil)
|
||||
}
|
||||
|
||||
record, err := e.App.FindRecordById(collection, recordId)
|
||||
if err != nil {
|
||||
return e.NotFoundError("", err)
|
||||
}
|
||||
|
||||
filename := e.Request.PathValue("filename")
|
||||
|
||||
fileField := record.FindFileFieldByFile(filename)
|
||||
if fileField == nil {
|
||||
return e.NotFoundError("", nil)
|
||||
}
|
||||
|
||||
// check whether the request is authorized to view the protected file
|
||||
if fileField.Protected {
|
||||
originalRequestInfo, err := e.RequestInfo()
|
||||
if err != nil {
|
||||
return e.InternalServerError("Failed to load request info", err)
|
||||
}
|
||||
|
||||
token := e.Request.URL.Query().Get("token")
|
||||
authRecord, _ := e.App.FindAuthRecordByToken(token, core.TokenTypeFile)
|
||||
|
||||
// create a shallow copy of the cached request data and adjust it to the current auth record (if any)
|
||||
requestInfo := *originalRequestInfo
|
||||
requestInfo.Context = core.RequestInfoContextProtectedFile
|
||||
requestInfo.Auth = authRecord
|
||||
|
||||
if ok, _ := e.App.CanAccessRecord(record, &requestInfo, record.Collection().ViewRule); !ok {
|
||||
return e.NotFoundError("", errors.New("insufficient permissions to access the file resource"))
|
||||
}
|
||||
}
|
||||
|
||||
baseFilesPath := record.BaseFilesPath()
|
||||
|
||||
// fetch the original view file field related record
|
||||
if collection.IsView() {
|
||||
fileRecord, err := e.App.FindRecordByViewFile(collection.Id, fileField.Name, filename)
|
||||
if err != nil {
|
||||
return e.NotFoundError("", fmt.Errorf("failed to fetch view file field record: %w", err))
|
||||
}
|
||||
baseFilesPath = fileRecord.BaseFilesPath()
|
||||
}
|
||||
|
||||
fsys, err := e.App.NewFilesystem()
|
||||
if err != nil {
|
||||
return e.InternalServerError("Filesystem initialization failure.", err)
|
||||
}
|
||||
defer fsys.Close()
|
||||
|
||||
originalPath := baseFilesPath + "/" + filename
|
||||
servedPath := originalPath
|
||||
servedName := filename
|
||||
|
||||
// check for valid thumb size param
|
||||
thumbSize := e.Request.URL.Query().Get("thumb")
|
||||
if thumbSize != "" && (list.ExistInSlice(thumbSize, defaultThumbSizes) || list.ExistInSlice(thumbSize, fileField.Thumbs)) {
|
||||
// extract the original file meta attributes and check it existence
|
||||
oAttrs, oAttrsErr := fsys.Attributes(originalPath)
|
||||
if oAttrsErr != nil {
|
||||
return e.NotFoundError("", err)
|
||||
}
|
||||
|
||||
// check if it is an image
|
||||
if list.ExistInSlice(oAttrs.ContentType, imageContentTypes) {
|
||||
// add thumb size as file suffix
|
||||
servedName = thumbSize + "_" + filename
|
||||
servedPath = baseFilesPath + "/thumbs_" + filename + "/" + servedName
|
||||
|
||||
// create a new thumb if it doesn't exist
|
||||
if exists, _ := fsys.Exists(servedPath); !exists {
|
||||
if err := api.createThumb(e, fsys, originalPath, servedPath, thumbSize); err != nil {
|
||||
e.App.Logger().Warn(
|
||||
"Fallback to original - failed to create thumb "+servedName,
|
||||
slog.Any("error", err),
|
||||
slog.String("original", originalPath),
|
||||
slog.String("thumb", servedPath),
|
||||
)
|
||||
|
||||
// fallback to the original
|
||||
servedName = filename
|
||||
servedPath = originalPath
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
event := new(core.FileDownloadRequestEvent)
|
||||
event.RequestEvent = e
|
||||
event.Collection = collection
|
||||
event.Record = record
|
||||
event.FileField = fileField
|
||||
event.ServedPath = servedPath
|
||||
event.ServedName = servedName
|
||||
|
||||
// clickjacking shouldn't be a concern when serving uploaded files,
|
||||
// so it safe to unset the global X-Frame-Options to allow files embedding
|
||||
// (note: it is out of the hook to allow users to customize the behavior)
|
||||
e.Response.Header().Del("X-Frame-Options")
|
||||
|
||||
return e.App.OnFileDownloadRequest().Trigger(event, func(e *core.FileDownloadRequestEvent) error {
|
||||
err = execAfterSuccessTx(true, e.App, func() error {
|
||||
return fsys.Serve(e.Response, e.Request, e.ServedPath, e.ServedName)
|
||||
})
|
||||
if err != nil {
|
||||
return e.NotFoundError("", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func (api *fileApi) createThumb(
|
||||
e *core.RequestEvent,
|
||||
fsys *filesystem.System,
|
||||
originalPath string,
|
||||
thumbPath string,
|
||||
thumbSize string,
|
||||
) error {
|
||||
ch := api.thumbGenPending.DoChan(thumbPath, func() (any, error) {
|
||||
ctx, cancel := context.WithTimeout(e.Request.Context(), api.thumbGenMaxWait)
|
||||
defer cancel()
|
||||
|
||||
if err := api.thumbGenSem.Acquire(ctx, 1); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer api.thumbGenSem.Release(1)
|
||||
|
||||
return nil, fsys.CreateThumb(originalPath, thumbPath, thumbSize)
|
||||
})
|
||||
|
||||
res := <-ch
|
||||
|
||||
api.thumbGenPending.Forget(thumbPath)
|
||||
|
||||
return res.Err
|
||||
}
|
504
apis/file_test.go
Normal file
504
apis/file_test.go
Normal file
|
@ -0,0 +1,504 @@
|
|||
package apis_test
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/pocketbase/pocketbase/apis"
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tests"
|
||||
"github.com/pocketbase/pocketbase/tools/types"
|
||||
)
|
||||
|
||||
func TestFileToken(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "unauthorized",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/files/token",
|
||||
ExpectedStatus: 401,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "regular user",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/files/token",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
`"token":"`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnFileTokenRequest": 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "superuser",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/files/token",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
`"token":"`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnFileTokenRequest": 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "hook token overwrite",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/files/token",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
app.OnFileTokenRequest().BindFunc(func(e *core.FileTokenRequestEvent) error {
|
||||
e.Token = "test"
|
||||
return e.Next()
|
||||
})
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
`"token":"test"`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnFileTokenRequest": 1,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileDownload(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
_, currentFile, _, _ := runtime.Caller(0)
|
||||
dataDirRelPath := "../tests/data/"
|
||||
|
||||
testFilePath := filepath.Join(path.Dir(currentFile), dataDirRelPath, "storage/_pb_users_auth_/oap640cot4yru2s/test_kfd2wYLxkz.txt")
|
||||
testImgPath := filepath.Join(path.Dir(currentFile), dataDirRelPath, "storage/_pb_users_auth_/4q1xlclmfloku33/300_1SEi6Q6U72.png")
|
||||
testThumbCropCenterPath := filepath.Join(path.Dir(currentFile), dataDirRelPath, "storage/_pb_users_auth_/4q1xlclmfloku33/thumbs_300_1SEi6Q6U72.png/70x50_300_1SEi6Q6U72.png")
|
||||
testThumbCropTopPath := filepath.Join(path.Dir(currentFile), dataDirRelPath, "storage/_pb_users_auth_/4q1xlclmfloku33/thumbs_300_1SEi6Q6U72.png/70x50t_300_1SEi6Q6U72.png")
|
||||
testThumbCropBottomPath := filepath.Join(path.Dir(currentFile), dataDirRelPath, "storage/_pb_users_auth_/4q1xlclmfloku33/thumbs_300_1SEi6Q6U72.png/70x50b_300_1SEi6Q6U72.png")
|
||||
testThumbFitPath := filepath.Join(path.Dir(currentFile), dataDirRelPath, "storage/_pb_users_auth_/4q1xlclmfloku33/thumbs_300_1SEi6Q6U72.png/70x50f_300_1SEi6Q6U72.png")
|
||||
testThumbZeroWidthPath := filepath.Join(path.Dir(currentFile), dataDirRelPath, "storage/_pb_users_auth_/4q1xlclmfloku33/thumbs_300_1SEi6Q6U72.png/0x50_300_1SEi6Q6U72.png")
|
||||
testThumbZeroHeightPath := filepath.Join(path.Dir(currentFile), dataDirRelPath, "storage/_pb_users_auth_/4q1xlclmfloku33/thumbs_300_1SEi6Q6U72.png/70x0_300_1SEi6Q6U72.png")
|
||||
|
||||
testFile, fileErr := os.ReadFile(testFilePath)
|
||||
if fileErr != nil {
|
||||
t.Fatal(fileErr)
|
||||
}
|
||||
|
||||
testImg, imgErr := os.ReadFile(testImgPath)
|
||||
if imgErr != nil {
|
||||
t.Fatal(imgErr)
|
||||
}
|
||||
|
||||
testThumbCropCenter, thumbErr := os.ReadFile(testThumbCropCenterPath)
|
||||
if thumbErr != nil {
|
||||
t.Fatal(thumbErr)
|
||||
}
|
||||
|
||||
testThumbCropTop, thumbErr := os.ReadFile(testThumbCropTopPath)
|
||||
if thumbErr != nil {
|
||||
t.Fatal(thumbErr)
|
||||
}
|
||||
|
||||
testThumbCropBottom, thumbErr := os.ReadFile(testThumbCropBottomPath)
|
||||
if thumbErr != nil {
|
||||
t.Fatal(thumbErr)
|
||||
}
|
||||
|
||||
testThumbFit, thumbErr := os.ReadFile(testThumbFitPath)
|
||||
if thumbErr != nil {
|
||||
t.Fatal(thumbErr)
|
||||
}
|
||||
|
||||
testThumbZeroWidth, thumbErr := os.ReadFile(testThumbZeroWidthPath)
|
||||
if thumbErr != nil {
|
||||
t.Fatal(thumbErr)
|
||||
}
|
||||
|
||||
testThumbZeroHeight, thumbErr := os.ReadFile(testThumbZeroHeightPath)
|
||||
if thumbErr != nil {
|
||||
t.Fatal(thumbErr)
|
||||
}
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "missing collection",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/files/missing/4q1xlclmfloku33/300_1SEi6Q6U72.png",
|
||||
ExpectedStatus: 404,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "missing record",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/files/_pb_users_auth_/missing/300_1SEi6Q6U72.png",
|
||||
ExpectedStatus: 404,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "missing file",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/files/_pb_users_auth_/4q1xlclmfloku33/missing.png",
|
||||
ExpectedStatus: 404,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "existing image",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/files/_pb_users_auth_/4q1xlclmfloku33/300_1SEi6Q6U72.png",
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{string(testImg)},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnFileDownloadRequest": 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "existing image - missing thumb (should fallback to the original)",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/files/_pb_users_auth_/4q1xlclmfloku33/300_1SEi6Q6U72.png?thumb=999x999",
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{string(testImg)},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnFileDownloadRequest": 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "existing image - existing thumb (crop center)",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/files/_pb_users_auth_/4q1xlclmfloku33/300_1SEi6Q6U72.png?thumb=70x50",
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{string(testThumbCropCenter)},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnFileDownloadRequest": 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "existing image - existing thumb (crop top)",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/files/_pb_users_auth_/4q1xlclmfloku33/300_1SEi6Q6U72.png?thumb=70x50t",
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{string(testThumbCropTop)},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnFileDownloadRequest": 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "existing image - existing thumb (crop bottom)",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/files/_pb_users_auth_/4q1xlclmfloku33/300_1SEi6Q6U72.png?thumb=70x50b",
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{string(testThumbCropBottom)},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnFileDownloadRequest": 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "existing image - existing thumb (fit)",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/files/_pb_users_auth_/4q1xlclmfloku33/300_1SEi6Q6U72.png?thumb=70x50f",
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{string(testThumbFit)},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnFileDownloadRequest": 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "existing image - existing thumb (zero width)",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/files/_pb_users_auth_/4q1xlclmfloku33/300_1SEi6Q6U72.png?thumb=0x50",
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{string(testThumbZeroWidth)},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnFileDownloadRequest": 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "existing image - existing thumb (zero height)",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/files/_pb_users_auth_/4q1xlclmfloku33/300_1SEi6Q6U72.png?thumb=70x0",
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{string(testThumbZeroHeight)},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnFileDownloadRequest": 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "existing non image file - thumb parameter should be ignored",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/files/_pb_users_auth_/oap640cot4yru2s/test_kfd2wYLxkz.txt?thumb=100x100",
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{string(testFile)},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnFileDownloadRequest": 1,
|
||||
},
|
||||
},
|
||||
|
||||
// protected file access checks
|
||||
{
|
||||
Name: "protected file - superuser with expired file token",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/files/demo1/al1h9ijdeojtsjy/300_Jsjq7RdBgA.png?token=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsImV4cCI6MTY0MDk5MTY2MSwidHlwZSI6ImZpbGUiLCJjb2xsZWN0aW9uSWQiOiJwYmNfMzE0MjYzNTgyMyJ9.nqqtqpPhxU0045F4XP_ruAkzAidYBc5oPy9ErN3XBq0",
|
||||
ExpectedStatus: 404,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "protected file - superuser with valid file token",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/files/demo1/al1h9ijdeojtsjy/300_Jsjq7RdBgA.png?token=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsImV4cCI6MjUyNDYwNDQ2MSwidHlwZSI6ImZpbGUiLCJjb2xsZWN0aW9uSWQiOiJwYmNfMzE0MjYzNTgyMyJ9.Lupz541xRvrktwkrl55p5pPCF77T69ZRsohsIcb2dxc",
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{"PNG"},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnFileDownloadRequest": 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "protected file - guest without view access",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/files/demo1/al1h9ijdeojtsjy/300_Jsjq7RdBgA.png",
|
||||
ExpectedStatus: 404,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "protected file - guest with view access",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/files/demo1/al1h9ijdeojtsjy/300_Jsjq7RdBgA.png",
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
// mock public view access
|
||||
c, err := app.FindCachedCollectionByNameOrId("demo1")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to fetch mock collection: %v", err)
|
||||
}
|
||||
c.ViewRule = types.Pointer("")
|
||||
if err := app.UnsafeWithoutHooks().Save(c); err != nil {
|
||||
t.Fatalf("Failed to update mock collection: %v", err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{"PNG"},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnFileDownloadRequest": 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "protected file - auth record without view access",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/files/demo1/al1h9ijdeojtsjy/300_Jsjq7RdBgA.png?token=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImV4cCI6MjUyNDYwNDQ2MSwidHlwZSI6ImZpbGUiLCJjb2xsZWN0aW9uSWQiOiJfcGJfdXNlcnNfYXV0aF8ifQ.nSTLuCPcGpWn2K2l-BFkC3Vlzc-ZTDPByYq8dN1oPSo",
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
// mock restricted user view access
|
||||
c, err := app.FindCachedCollectionByNameOrId("demo1")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to fetch mock collection: %v", err)
|
||||
}
|
||||
c.ViewRule = types.Pointer("@request.auth.verified = true")
|
||||
if err := app.UnsafeWithoutHooks().Save(c); err != nil {
|
||||
t.Fatalf("Failed to update mock collection: %v", err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 404,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "protected file - auth record with view access",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/files/demo1/al1h9ijdeojtsjy/300_Jsjq7RdBgA.png?token=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImV4cCI6MjUyNDYwNDQ2MSwidHlwZSI6ImZpbGUiLCJjb2xsZWN0aW9uSWQiOiJfcGJfdXNlcnNfYXV0aF8ifQ.nSTLuCPcGpWn2K2l-BFkC3Vlzc-ZTDPByYq8dN1oPSo",
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
// mock user view access
|
||||
c, err := app.FindCachedCollectionByNameOrId("demo1")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to fetch mock collection: %v", err)
|
||||
}
|
||||
c.ViewRule = types.Pointer("@request.auth.verified = false")
|
||||
if err := app.UnsafeWithoutHooks().Save(c); err != nil {
|
||||
t.Fatalf("Failed to update mock collection: %v", err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{"PNG"},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnFileDownloadRequest": 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "protected file in view (view's View API rule failure)",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/files/view1/al1h9ijdeojtsjy/300_Jsjq7RdBgA.png?token=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImV4cCI6MjUyNDYwNDQ2MSwidHlwZSI6ImZpbGUiLCJjb2xsZWN0aW9uSWQiOiJfcGJfdXNlcnNfYXV0aF8ifQ.nSTLuCPcGpWn2K2l-BFkC3Vlzc-ZTDPByYq8dN1oPSo",
|
||||
ExpectedStatus: 404,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "protected file in view (view's View API rule success)",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/files/view1/84nmscqy84lsi1t/test_d61b33QdDU.txt?token=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImV4cCI6MjUyNDYwNDQ2MSwidHlwZSI6ImZpbGUiLCJjb2xsZWN0aW9uSWQiOiJfcGJfdXNlcnNfYXV0aF8ifQ.nSTLuCPcGpWn2K2l-BFkC3Vlzc-ZTDPByYq8dN1oPSo",
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{"test"},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnFileDownloadRequest": 1,
|
||||
},
|
||||
},
|
||||
|
||||
// rate limit checks
|
||||
// -----------------------------------------------------------
|
||||
{
|
||||
Name: "RateLimit rule - users:file",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/files/_pb_users_auth_/4q1xlclmfloku33/300_1SEi6Q6U72.png",
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
app.Settings().RateLimits.Enabled = true
|
||||
app.Settings().RateLimits.Rules = []core.RateLimitRule{
|
||||
{MaxRequests: 100, Label: "abc"},
|
||||
{MaxRequests: 100, Label: "*:file"},
|
||||
{MaxRequests: 0, Label: "users:file"},
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 429,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "RateLimit rule - *:file",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/files/_pb_users_auth_/4q1xlclmfloku33/300_1SEi6Q6U72.png",
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
app.Settings().RateLimits.Enabled = true
|
||||
app.Settings().RateLimits.Rules = []core.RateLimitRule{
|
||||
{MaxRequests: 100, Label: "abc"},
|
||||
{MaxRequests: 0, Label: "*:file"},
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 429,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
// clone for the HEAD test (the same as the original scenario but without body)
|
||||
head := scenario
|
||||
head.Method = http.MethodHead
|
||||
head.Name = ("(HEAD) " + scenario.Name)
|
||||
head.ExpectedContent = nil
|
||||
head.Test(t)
|
||||
|
||||
// regular request test
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConcurrentThumbsGeneration(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
app, err := tests.NewTestApp()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer app.Cleanup()
|
||||
|
||||
fsys, err := app.NewFilesystem()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer fsys.Close()
|
||||
|
||||
// create a dummy file field collection
|
||||
demo1, err := app.FindCollectionByNameOrId("demo1")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
fileField := demo1.Fields.GetByName("file_one").(*core.FileField)
|
||||
fileField.Protected = false
|
||||
fileField.MaxSelect = 1
|
||||
fileField.MaxSize = 999999
|
||||
// new thumbs
|
||||
fileField.Thumbs = []string{"111x111", "111x222", "111x333"}
|
||||
demo1.Fields.Add(fileField)
|
||||
if err = app.Save(demo1); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
fileKey := "wsmn24bux7wo113/al1h9ijdeojtsjy/300_Jsjq7RdBgA.png"
|
||||
|
||||
urls := []string{
|
||||
"/api/files/" + fileKey + "?thumb=111x111",
|
||||
"/api/files/" + fileKey + "?thumb=111x111", // should still result in single thumb
|
||||
"/api/files/" + fileKey + "?thumb=111x222",
|
||||
"/api/files/" + fileKey + "?thumb=111x333",
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
|
||||
wg.Add(len(urls))
|
||||
|
||||
for _, url := range urls {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
|
||||
req := httptest.NewRequest("GET", url, nil)
|
||||
|
||||
pbRouter, _ := apis.NewRouter(app)
|
||||
mux, _ := pbRouter.BuildMux()
|
||||
if mux != nil {
|
||||
mux.ServeHTTP(recorder, req)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// ensure that all new requested thumbs were created
|
||||
thumbKeys := []string{
|
||||
"wsmn24bux7wo113/al1h9ijdeojtsjy/thumbs_300_Jsjq7RdBgA.png/111x111_" + filepath.Base(fileKey),
|
||||
"wsmn24bux7wo113/al1h9ijdeojtsjy/thumbs_300_Jsjq7RdBgA.png/111x222_" + filepath.Base(fileKey),
|
||||
"wsmn24bux7wo113/al1h9ijdeojtsjy/thumbs_300_Jsjq7RdBgA.png/111x333_" + filepath.Base(fileKey),
|
||||
}
|
||||
for _, k := range thumbKeys {
|
||||
if exists, _ := fsys.Exists(k); !exists {
|
||||
t.Fatalf("Missing thumb %q: %v", k, err)
|
||||
}
|
||||
}
|
||||
}
|
53
apis/health.go
Normal file
53
apis/health.go
Normal file
|
@ -0,0 +1,53 @@
|
|||
package apis
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"slices"
|
||||
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tools/router"
|
||||
)
|
||||
|
||||
// bindHealthApi registers the health api endpoint.
|
||||
func bindHealthApi(app core.App, rg *router.RouterGroup[*core.RequestEvent]) {
|
||||
subGroup := rg.Group("/health")
|
||||
subGroup.GET("", healthCheck)
|
||||
}
|
||||
|
||||
// healthCheck returns a 200 OK response if the server is healthy.
|
||||
func healthCheck(e *core.RequestEvent) error {
|
||||
resp := struct {
|
||||
Message string `json:"message"`
|
||||
Code int `json:"code"`
|
||||
Data map[string]any `json:"data"`
|
||||
}{
|
||||
Code: http.StatusOK,
|
||||
Message: "API is healthy.",
|
||||
}
|
||||
|
||||
if e.HasSuperuserAuth() {
|
||||
resp.Data = make(map[string]any, 3)
|
||||
resp.Data["canBackup"] = !e.App.Store().Has(core.StoreKeyActiveBackup)
|
||||
resp.Data["realIP"] = e.RealIP()
|
||||
|
||||
// loosely check if behind a reverse proxy
|
||||
// (usually used in the dashboard to remind superusers in case deployed behind reverse-proxy)
|
||||
possibleProxyHeader := ""
|
||||
headersToCheck := append(
|
||||
slices.Clone(e.App.Settings().TrustedProxy.Headers),
|
||||
// common proxy headers
|
||||
"CF-Connecting-IP", "Fly-Client-IP", "X‑Forwarded-For",
|
||||
)
|
||||
for _, header := range headersToCheck {
|
||||
if e.Request.Header.Get(header) != "" {
|
||||
possibleProxyHeader = header
|
||||
break
|
||||
}
|
||||
}
|
||||
resp.Data["possibleProxyHeader"] = possibleProxyHeader
|
||||
} else {
|
||||
resp.Data = map[string]any{} // ensure that it is returned as object
|
||||
}
|
||||
|
||||
return e.JSON(http.StatusOK, resp)
|
||||
}
|
71
apis/health_test.go
Normal file
71
apis/health_test.go
Normal file
|
@ -0,0 +1,71 @@
|
|||
package apis_test
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/pocketbase/pocketbase/tests"
|
||||
)
|
||||
|
||||
func TestHealthAPI(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "GET health status (guest)",
|
||||
Method: http.MethodGet, // automatically matches also HEAD as a side-effect of the Go std mux
|
||||
URL: "/api/health",
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
`"code":200`,
|
||||
`"data":{}`,
|
||||
},
|
||||
NotExpectedContent: []string{
|
||||
"canBackup",
|
||||
"realIP",
|
||||
"possibleProxyHeader",
|
||||
},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "GET health status (regular user)",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/health",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
`"code":200`,
|
||||
`"data":{}`,
|
||||
},
|
||||
NotExpectedContent: []string{
|
||||
"canBackup",
|
||||
"realIP",
|
||||
"possibleProxyHeader",
|
||||
},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "GET health status (superuser)",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/health",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
`"code":200`,
|
||||
`"data":{`,
|
||||
`"canBackup":true`,
|
||||
`"realIP"`,
|
||||
`"possibleProxyHeader"`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
88
apis/installer.go
Normal file
88
apis/installer.go
Normal file
|
@ -0,0 +1,88 @@
|
|||
package apis
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/fatih/color"
|
||||
"github.com/pocketbase/dbx"
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tools/osutils"
|
||||
)
|
||||
|
||||
// DefaultInstallerFunc is the default PocketBase installer function.
|
||||
//
|
||||
// It will attempt to open a link in the browser (with a short-lived auth
|
||||
// token for the systemSuperuser) to the installer UI so that users can
|
||||
// create their own custom superuser record.
|
||||
//
|
||||
// See https://github.com/pocketbase/pocketbase/discussions/5814.
|
||||
func DefaultInstallerFunc(app core.App, systemSuperuser *core.Record, baseURL string) error {
|
||||
token, err := systemSuperuser.NewStaticAuthToken(30 * time.Minute)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// launch url (ignore errors and always print a help text as fallback)
|
||||
url := fmt.Sprintf("%s/_/#/pbinstal/%s", strings.TrimRight(baseURL, "/"), token)
|
||||
_ = osutils.LaunchURL(url)
|
||||
color.Magenta("\n(!) Launch the URL below in the browser if it hasn't been open already to create your first superuser account:")
|
||||
color.New(color.Bold).Add(color.FgCyan).Println(url)
|
||||
color.New(color.FgHiBlack, color.Italic).Printf("(you can also create your first superuser by running: %s superuser upsert EMAIL PASS)\n\n", os.Args[0])
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func loadInstaller(
|
||||
app core.App,
|
||||
baseURL string,
|
||||
installerFunc func(app core.App, systemSuperuser *core.Record, baseURL string) error,
|
||||
) error {
|
||||
if installerFunc == nil || !needInstallerSuperuser(app) {
|
||||
return nil
|
||||
}
|
||||
|
||||
superuser, err := findOrCreateInstallerSuperuser(app)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return installerFunc(app, superuser, baseURL)
|
||||
}
|
||||
|
||||
func needInstallerSuperuser(app core.App) bool {
|
||||
total, err := app.CountRecords(core.CollectionNameSuperusers, dbx.Not(dbx.HashExp{
|
||||
"email": core.DefaultInstallerEmail,
|
||||
}))
|
||||
|
||||
return err == nil && total == 0
|
||||
}
|
||||
|
||||
func findOrCreateInstallerSuperuser(app core.App) (*core.Record, error) {
|
||||
col, err := app.FindCachedCollectionByNameOrId(core.CollectionNameSuperusers)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
record, err := app.FindAuthRecordByEmail(col, core.DefaultInstallerEmail)
|
||||
if err != nil {
|
||||
if !errors.Is(err, sql.ErrNoRows) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
record = core.NewRecord(col)
|
||||
record.SetEmail(core.DefaultInstallerEmail)
|
||||
record.SetRandomPassword()
|
||||
|
||||
err = app.Save(record)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return record, nil
|
||||
}
|
73
apis/logs.go
Normal file
73
apis/logs.go
Normal file
|
@ -0,0 +1,73 @@
|
|||
package apis
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/pocketbase/dbx"
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tools/router"
|
||||
"github.com/pocketbase/pocketbase/tools/search"
|
||||
)
|
||||
|
||||
// bindLogsApi registers the request logs api endpoints.
|
||||
func bindLogsApi(app core.App, rg *router.RouterGroup[*core.RequestEvent]) {
|
||||
sub := rg.Group("/logs").Bind(RequireSuperuserAuth(), SkipSuccessActivityLog())
|
||||
sub.GET("", logsList)
|
||||
sub.GET("/stats", logsStats)
|
||||
sub.GET("/{id}", logsView)
|
||||
}
|
||||
|
||||
var logFilterFields = []string{
|
||||
"id", "created", "level", "message", "data",
|
||||
`^data\.[\w\.\:]*\w+$`,
|
||||
}
|
||||
|
||||
func logsList(e *core.RequestEvent) error {
|
||||
fieldResolver := search.NewSimpleFieldResolver(logFilterFields...)
|
||||
|
||||
result, err := search.NewProvider(fieldResolver).
|
||||
Query(e.App.AuxModelQuery(&core.Log{})).
|
||||
ParseAndExec(e.Request.URL.Query().Encode(), &[]*core.Log{})
|
||||
|
||||
if err != nil {
|
||||
return e.BadRequestError("", err)
|
||||
}
|
||||
|
||||
return e.JSON(http.StatusOK, result)
|
||||
}
|
||||
|
||||
func logsStats(e *core.RequestEvent) error {
|
||||
fieldResolver := search.NewSimpleFieldResolver(logFilterFields...)
|
||||
|
||||
filter := e.Request.URL.Query().Get(search.FilterQueryParam)
|
||||
|
||||
var expr dbx.Expression
|
||||
if filter != "" {
|
||||
var err error
|
||||
expr, err = search.FilterData(filter).BuildExpr(fieldResolver)
|
||||
if err != nil {
|
||||
return e.BadRequestError("Invalid filter format.", err)
|
||||
}
|
||||
}
|
||||
|
||||
stats, err := e.App.LogsStats(expr)
|
||||
if err != nil {
|
||||
return e.BadRequestError("Failed to generate logs stats.", err)
|
||||
}
|
||||
|
||||
return e.JSON(http.StatusOK, stats)
|
||||
}
|
||||
|
||||
func logsView(e *core.RequestEvent) error {
|
||||
id := e.Request.PathValue("id")
|
||||
if id == "" {
|
||||
return e.NotFoundError("", nil)
|
||||
}
|
||||
|
||||
log, err := e.App.FindLogById(id)
|
||||
if err != nil || log == nil {
|
||||
return e.NotFoundError("", err)
|
||||
}
|
||||
|
||||
return e.JSON(http.StatusOK, log)
|
||||
}
|
212
apis/logs_test.go
Normal file
212
apis/logs_test.go
Normal file
|
@ -0,0 +1,212 @@
|
|||
package apis_test
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tests"
|
||||
)
|
||||
|
||||
func TestLogsList(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "unauthorized",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/logs",
|
||||
ExpectedStatus: 401,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "authorized as regular user",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/logs",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "authorized as superuser",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/logs",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := tests.StubLogsData(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
`"page":1`,
|
||||
`"perPage":30`,
|
||||
`"totalItems":2`,
|
||||
`"items":[{`,
|
||||
`"id":"873f2133-9f38-44fb-bf82-c8f53b310d91"`,
|
||||
`"id":"f2133873-44fb-9f38-bf82-c918f53b310d"`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "authorized as superuser + filter",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/logs?filter=data.status>200",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := tests.StubLogsData(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
`"page":1`,
|
||||
`"perPage":30`,
|
||||
`"totalItems":1`,
|
||||
`"items":[{`,
|
||||
`"id":"f2133873-44fb-9f38-bf82-c918f53b310d"`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLogView(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "unauthorized",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/logs/873f2133-9f38-44fb-bf82-c8f53b310d91",
|
||||
ExpectedStatus: 401,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "authorized as regular user",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/logs/873f2133-9f38-44fb-bf82-c8f53b310d91",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "authorized as superuser (nonexisting request log)",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/logs/missing1-9f38-44fb-bf82-c8f53b310d91",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := tests.StubLogsData(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 404,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "authorized as superuser (existing request log)",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/logs/873f2133-9f38-44fb-bf82-c8f53b310d91",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := tests.StubLogsData(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
`"id":"873f2133-9f38-44fb-bf82-c8f53b310d91"`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLogsStats(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "unauthorized",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/logs/stats",
|
||||
ExpectedStatus: 401,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "authorized as regular user",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/logs/stats",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "authorized as superuser",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/logs/stats",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := tests.StubLogsData(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
`[{"date":"2022-05-01 10:00:00.000Z","total":1},{"date":"2022-05-02 10:00:00.000Z","total":1}]`,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "authorized as superuser + filter",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/logs/stats?filter=data.status>200",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := tests.StubLogsData(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
`[{"date":"2022-05-02 10:00:00.000Z","total":1}]`,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
444
apis/middlewares.go
Normal file
444
apis/middlewares.go
Normal file
|
@ -0,0 +1,444 @@
|
|||
package apis
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"runtime"
|
||||
"slices"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tools/hook"
|
||||
"github.com/pocketbase/pocketbase/tools/list"
|
||||
"github.com/pocketbase/pocketbase/tools/router"
|
||||
"github.com/pocketbase/pocketbase/tools/routine"
|
||||
"github.com/spf13/cast"
|
||||
)
|
||||
|
||||
// Common request event store keys used by the middlewares and api handlers.
|
||||
const (
|
||||
RequestEventKeyLogMeta = "pbLogMeta" // extra data to store with the request activity log
|
||||
|
||||
requestEventKeyExecStart = "__execStart" // the value must be time.Time
|
||||
requestEventKeySkipSuccessActivityLog = "__skipSuccessActivityLogger" // the value must be bool
|
||||
)
|
||||
|
||||
const (
|
||||
DefaultWWWRedirectMiddlewarePriority = -99999
|
||||
DefaultWWWRedirectMiddlewareId = "pbWWWRedirect"
|
||||
|
||||
DefaultActivityLoggerMiddlewarePriority = DefaultRateLimitMiddlewarePriority - 40
|
||||
DefaultActivityLoggerMiddlewareId = "pbActivityLogger"
|
||||
DefaultSkipSuccessActivityLogMiddlewareId = "pbSkipSuccessActivityLog"
|
||||
DefaultEnableAuthIdActivityLog = "pbEnableAuthIdActivityLog"
|
||||
|
||||
DefaultPanicRecoverMiddlewarePriority = DefaultRateLimitMiddlewarePriority - 30
|
||||
DefaultPanicRecoverMiddlewareId = "pbPanicRecover"
|
||||
|
||||
DefaultLoadAuthTokenMiddlewarePriority = DefaultRateLimitMiddlewarePriority - 20
|
||||
DefaultLoadAuthTokenMiddlewareId = "pbLoadAuthToken"
|
||||
|
||||
DefaultSecurityHeadersMiddlewarePriority = DefaultRateLimitMiddlewarePriority - 10
|
||||
DefaultSecurityHeadersMiddlewareId = "pbSecurityHeaders"
|
||||
|
||||
DefaultRequireGuestOnlyMiddlewareId = "pbRequireGuestOnly"
|
||||
DefaultRequireAuthMiddlewareId = "pbRequireAuth"
|
||||
DefaultRequireSuperuserAuthMiddlewareId = "pbRequireSuperuserAuth"
|
||||
DefaultRequireSuperuserOrOwnerAuthMiddlewareId = "pbRequireSuperuserOrOwnerAuth"
|
||||
DefaultRequireSameCollectionContextAuthMiddlewareId = "pbRequireSameCollectionContextAuth"
|
||||
)
|
||||
|
||||
// RequireGuestOnly middleware requires a request to NOT have a valid
|
||||
// Authorization header.
|
||||
//
|
||||
// This middleware is the opposite of [apis.RequireAuth()].
|
||||
func RequireGuestOnly() *hook.Handler[*core.RequestEvent] {
|
||||
return &hook.Handler[*core.RequestEvent]{
|
||||
Id: DefaultRequireGuestOnlyMiddlewareId,
|
||||
Func: func(e *core.RequestEvent) error {
|
||||
if e.Auth != nil {
|
||||
return router.NewBadRequestError("The request can be accessed only by guests.", nil)
|
||||
}
|
||||
|
||||
return e.Next()
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// RequireAuth middleware requires a request to have a valid record Authorization header.
|
||||
//
|
||||
// The auth record could be from any collection.
|
||||
// You can further filter the allowed record auth collections by specifying their names.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// apis.RequireAuth() // any auth collection
|
||||
// apis.RequireAuth("_superusers", "users") // only the listed auth collections
|
||||
func RequireAuth(optCollectionNames ...string) *hook.Handler[*core.RequestEvent] {
|
||||
return &hook.Handler[*core.RequestEvent]{
|
||||
Id: DefaultRequireAuthMiddlewareId,
|
||||
Func: requireAuth(optCollectionNames...),
|
||||
}
|
||||
}
|
||||
|
||||
func requireAuth(optCollectionNames ...string) func(*core.RequestEvent) error {
|
||||
return func(e *core.RequestEvent) error {
|
||||
if e.Auth == nil {
|
||||
return e.UnauthorizedError("The request requires valid record authorization token.", nil)
|
||||
}
|
||||
|
||||
// check record collection name
|
||||
if len(optCollectionNames) > 0 && !slices.Contains(optCollectionNames, e.Auth.Collection().Name) {
|
||||
return e.ForbiddenError("The authorized record is not allowed to perform this action.", nil)
|
||||
}
|
||||
|
||||
return e.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// RequireSuperuserAuth middleware requires a request to have
|
||||
// a valid superuser Authorization header.
|
||||
func RequireSuperuserAuth() *hook.Handler[*core.RequestEvent] {
|
||||
return &hook.Handler[*core.RequestEvent]{
|
||||
Id: DefaultRequireSuperuserAuthMiddlewareId,
|
||||
Func: requireAuth(core.CollectionNameSuperusers),
|
||||
}
|
||||
}
|
||||
|
||||
// RequireSuperuserOrOwnerAuth middleware requires a request to have
|
||||
// a valid superuser or regular record owner Authorization header set.
|
||||
//
|
||||
// This middleware is similar to [apis.RequireAuth()] but
|
||||
// for the auth record token expects to have the same id as the path
|
||||
// parameter ownerIdPathParam (default to "id" if empty).
|
||||
func RequireSuperuserOrOwnerAuth(ownerIdPathParam string) *hook.Handler[*core.RequestEvent] {
|
||||
return &hook.Handler[*core.RequestEvent]{
|
||||
Id: DefaultRequireSuperuserOrOwnerAuthMiddlewareId,
|
||||
Func: func(e *core.RequestEvent) error {
|
||||
if e.Auth == nil {
|
||||
return e.UnauthorizedError("The request requires superuser or record authorization token.", nil)
|
||||
}
|
||||
|
||||
if e.Auth.IsSuperuser() {
|
||||
return e.Next()
|
||||
}
|
||||
|
||||
if ownerIdPathParam == "" {
|
||||
ownerIdPathParam = "id"
|
||||
}
|
||||
ownerId := e.Request.PathValue(ownerIdPathParam)
|
||||
|
||||
// note: it is considered "safe" to compare only the record id
|
||||
// since the auth record ids are treated as unique across all auth collections
|
||||
if e.Auth.Id != ownerId {
|
||||
return e.ForbiddenError("You are not allowed to perform this request.", nil)
|
||||
}
|
||||
|
||||
return e.Next()
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// RequireSameCollectionContextAuth middleware requires a request to have
|
||||
// a valid record Authorization header and the auth record's collection to
|
||||
// match the one from the route path parameter (default to "collection" if collectionParam is empty).
|
||||
func RequireSameCollectionContextAuth(collectionPathParam string) *hook.Handler[*core.RequestEvent] {
|
||||
return &hook.Handler[*core.RequestEvent]{
|
||||
Id: DefaultRequireSameCollectionContextAuthMiddlewareId,
|
||||
Func: func(e *core.RequestEvent) error {
|
||||
if e.Auth == nil {
|
||||
return e.UnauthorizedError("The request requires valid record authorization token.", nil)
|
||||
}
|
||||
|
||||
if collectionPathParam == "" {
|
||||
collectionPathParam = "collection"
|
||||
}
|
||||
|
||||
collection, _ := e.App.FindCachedCollectionByNameOrId(e.Request.PathValue(collectionPathParam))
|
||||
if collection == nil || e.Auth.Collection().Id != collection.Id {
|
||||
return e.ForbiddenError(fmt.Sprintf("The request requires auth record from %s collection.", e.Auth.Collection().Name), nil)
|
||||
}
|
||||
|
||||
return e.Next()
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// loadAuthToken attempts to load the auth context based on the "Authorization: TOKEN" header value.
|
||||
//
|
||||
// This middleware does nothing in case of:
|
||||
// - missing, invalid or expired token
|
||||
// - e.Auth is already loaded by another middleware
|
||||
//
|
||||
// This middleware is registered by default for all routes.
|
||||
//
|
||||
// Note: We don't throw an error on invalid or expired token to allow
|
||||
// users to extend with their own custom handling in external middleware(s).
|
||||
func loadAuthToken() *hook.Handler[*core.RequestEvent] {
|
||||
return &hook.Handler[*core.RequestEvent]{
|
||||
Id: DefaultLoadAuthTokenMiddlewareId,
|
||||
Priority: DefaultLoadAuthTokenMiddlewarePriority,
|
||||
Func: func(e *core.RequestEvent) error {
|
||||
// already loaded by another middleware
|
||||
if e.Auth != nil {
|
||||
return e.Next()
|
||||
}
|
||||
|
||||
token := getAuthTokenFromRequest(e)
|
||||
if token == "" {
|
||||
return e.Next()
|
||||
}
|
||||
|
||||
record, err := e.App.FindAuthRecordByToken(token, core.TokenTypeAuth)
|
||||
if err != nil {
|
||||
e.App.Logger().Debug("loadAuthToken failure", "error", err)
|
||||
} else if record != nil {
|
||||
e.Auth = record
|
||||
}
|
||||
|
||||
return e.Next()
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func getAuthTokenFromRequest(e *core.RequestEvent) string {
|
||||
token := e.Request.Header.Get("Authorization")
|
||||
if token != "" {
|
||||
// the schema prefix is not required and it is only for
|
||||
// compatibility with the defaults of some HTTP clients
|
||||
token = strings.TrimPrefix(token, "Bearer ")
|
||||
}
|
||||
return token
|
||||
}
|
||||
|
||||
// wwwRedirect performs www->non-www redirect(s) if the request host
|
||||
// matches with one of the values in redirectHosts.
|
||||
//
|
||||
// This middleware is registered by default on Serve for all routes.
|
||||
func wwwRedirect(redirectHosts []string) *hook.Handler[*core.RequestEvent] {
|
||||
return &hook.Handler[*core.RequestEvent]{
|
||||
Id: DefaultWWWRedirectMiddlewareId,
|
||||
Priority: DefaultWWWRedirectMiddlewarePriority,
|
||||
Func: func(e *core.RequestEvent) error {
|
||||
host := e.Request.Host
|
||||
|
||||
if strings.HasPrefix(host, "www.") && list.ExistInSlice(host, redirectHosts) {
|
||||
// note: e.Request.URL.Scheme would be empty
|
||||
schema := "http://"
|
||||
if e.IsTLS() {
|
||||
schema = "https://"
|
||||
}
|
||||
|
||||
return e.Redirect(
|
||||
http.StatusTemporaryRedirect,
|
||||
(schema + host[4:] + e.Request.RequestURI),
|
||||
)
|
||||
}
|
||||
|
||||
return e.Next()
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// panicRecover returns a default panic-recover handler.
|
||||
func panicRecover() *hook.Handler[*core.RequestEvent] {
|
||||
return &hook.Handler[*core.RequestEvent]{
|
||||
Id: DefaultPanicRecoverMiddlewareId,
|
||||
Priority: DefaultPanicRecoverMiddlewarePriority,
|
||||
Func: func(e *core.RequestEvent) (err error) {
|
||||
// panic-recover
|
||||
defer func() {
|
||||
recoverResult := recover()
|
||||
if recoverResult == nil {
|
||||
return
|
||||
}
|
||||
|
||||
recoverErr, ok := recoverResult.(error)
|
||||
if !ok {
|
||||
recoverErr = fmt.Errorf("%v", recoverResult)
|
||||
} else if errors.Is(recoverErr, http.ErrAbortHandler) {
|
||||
// don't recover ErrAbortHandler so the response to the client can be aborted
|
||||
panic(recoverResult)
|
||||
}
|
||||
|
||||
stack := make([]byte, 2<<10) // 2 KB
|
||||
length := runtime.Stack(stack, true)
|
||||
err = e.InternalServerError("", fmt.Errorf("[PANIC RECOVER] %w %s", recoverErr, stack[:length]))
|
||||
}()
|
||||
|
||||
err = e.Next()
|
||||
|
||||
return err
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// securityHeaders middleware adds common security headers to the response.
|
||||
//
|
||||
// This middleware is registered by default for all routes.
|
||||
func securityHeaders() *hook.Handler[*core.RequestEvent] {
|
||||
return &hook.Handler[*core.RequestEvent]{
|
||||
Id: DefaultSecurityHeadersMiddlewareId,
|
||||
Priority: DefaultSecurityHeadersMiddlewarePriority,
|
||||
Func: func(e *core.RequestEvent) error {
|
||||
e.Response.Header().Set("X-XSS-Protection", "1; mode=block")
|
||||
e.Response.Header().Set("X-Content-Type-Options", "nosniff")
|
||||
e.Response.Header().Set("X-Frame-Options", "SAMEORIGIN")
|
||||
|
||||
// @todo consider a default HSTS?
|
||||
// (see also https://webkit.org/blog/8146/protecting-against-hsts-abuse/)
|
||||
|
||||
return e.Next()
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// SkipSuccessActivityLog is a helper middleware that instructs the global
|
||||
// activity logger to log only requests that have failed/returned an error.
|
||||
func SkipSuccessActivityLog() *hook.Handler[*core.RequestEvent] {
|
||||
return &hook.Handler[*core.RequestEvent]{
|
||||
Id: DefaultSkipSuccessActivityLogMiddlewareId,
|
||||
Func: func(e *core.RequestEvent) error {
|
||||
e.Set(requestEventKeySkipSuccessActivityLog, true)
|
||||
return e.Next()
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// activityLogger middleware takes care to save the request information
|
||||
// into the logs database.
|
||||
//
|
||||
// This middleware is registered by default for all routes.
|
||||
//
|
||||
// The middleware does nothing if the app logs retention period is zero
|
||||
// (aka. app.Settings().Logs.MaxDays = 0).
|
||||
//
|
||||
// Users can attach the [apis.SkipSuccessActivityLog()] middleware if
|
||||
// you want to log only the failed requests.
|
||||
func activityLogger() *hook.Handler[*core.RequestEvent] {
|
||||
return &hook.Handler[*core.RequestEvent]{
|
||||
Id: DefaultActivityLoggerMiddlewareId,
|
||||
Priority: DefaultActivityLoggerMiddlewarePriority,
|
||||
Func: func(e *core.RequestEvent) error {
|
||||
e.Set(requestEventKeyExecStart, time.Now())
|
||||
|
||||
err := e.Next()
|
||||
|
||||
logRequest(e, err)
|
||||
|
||||
return err
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func logRequest(event *core.RequestEvent, err error) {
|
||||
// no logs retention
|
||||
if event.App.Settings().Logs.MaxDays == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
// the non-error route has explicitly disabled the activity logger
|
||||
if err == nil && event.Get(requestEventKeySkipSuccessActivityLog) != nil {
|
||||
return
|
||||
}
|
||||
|
||||
attrs := make([]any, 0, 15)
|
||||
|
||||
attrs = append(attrs, slog.String("type", "request"))
|
||||
|
||||
started := cast.ToTime(event.Get(requestEventKeyExecStart))
|
||||
if !started.IsZero() {
|
||||
attrs = append(attrs, slog.Float64("execTime", float64(time.Since(started))/float64(time.Millisecond)))
|
||||
}
|
||||
|
||||
if meta := event.Get(RequestEventKeyLogMeta); meta != nil {
|
||||
attrs = append(attrs, slog.Any("meta", meta))
|
||||
}
|
||||
|
||||
status := event.Status()
|
||||
method := cutStr(strings.ToUpper(event.Request.Method), 50)
|
||||
requestUri := cutStr(event.Request.URL.RequestURI(), 3000)
|
||||
|
||||
// parse the request error
|
||||
if err != nil {
|
||||
apiErr, isPlainApiError := err.(*router.ApiError)
|
||||
if isPlainApiError || errors.As(err, &apiErr) {
|
||||
// the status header wasn't written yet
|
||||
if status == 0 {
|
||||
status = apiErr.Status
|
||||
}
|
||||
|
||||
var errMsg string
|
||||
if isPlainApiError {
|
||||
errMsg = apiErr.Message
|
||||
} else {
|
||||
// wrapped ApiError -> add the full serialized version
|
||||
// of the original error since it could contain more information
|
||||
errMsg = err.Error()
|
||||
}
|
||||
|
||||
attrs = append(
|
||||
attrs,
|
||||
slog.String("error", errMsg),
|
||||
slog.Any("details", apiErr.RawData()),
|
||||
)
|
||||
} else {
|
||||
attrs = append(attrs, slog.String("error", err.Error()))
|
||||
}
|
||||
}
|
||||
|
||||
attrs = append(
|
||||
attrs,
|
||||
slog.String("url", requestUri),
|
||||
slog.String("method", method),
|
||||
slog.Int("status", status),
|
||||
slog.String("referer", cutStr(event.Request.Referer(), 2000)),
|
||||
slog.String("userAgent", cutStr(event.Request.UserAgent(), 2000)),
|
||||
)
|
||||
|
||||
if event.Auth != nil {
|
||||
attrs = append(attrs, slog.String("auth", event.Auth.Collection().Name))
|
||||
|
||||
if event.App.Settings().Logs.LogAuthId {
|
||||
attrs = append(attrs, slog.String("authId", event.Auth.Id))
|
||||
}
|
||||
} else {
|
||||
attrs = append(attrs, slog.String("auth", ""))
|
||||
}
|
||||
|
||||
if event.App.Settings().Logs.LogIP {
|
||||
attrs = append(
|
||||
attrs,
|
||||
slog.String("userIP", event.RealIP()),
|
||||
slog.String("remoteIP", event.RemoteIP()),
|
||||
)
|
||||
}
|
||||
|
||||
// don't block on logs write
|
||||
routine.FireAndForget(func() {
|
||||
message := method + " "
|
||||
|
||||
if escaped, unescapeErr := url.PathUnescape(requestUri); unescapeErr == nil {
|
||||
message += escaped
|
||||
} else {
|
||||
message += requestUri
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
event.App.Logger().Error(message, attrs...)
|
||||
} else {
|
||||
event.App.Logger().Info(message, attrs...)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func cutStr(str string, max int) string {
|
||||
if len(str) > max {
|
||||
return str[:max] + "..."
|
||||
}
|
||||
return str
|
||||
}
|
120
apis/middlewares_body_limit.go
Normal file
120
apis/middlewares_body_limit.go
Normal file
|
@ -0,0 +1,120 @@
|
|||
package apis
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net/http"
|
||||
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tools/hook"
|
||||
"github.com/pocketbase/pocketbase/tools/router"
|
||||
)
|
||||
|
||||
var ErrRequestEntityTooLarge = router.NewApiError(http.StatusRequestEntityTooLarge, "Request entity too large", nil)
|
||||
|
||||
const DefaultMaxBodySize int64 = 32 << 20
|
||||
|
||||
const (
|
||||
DefaultBodyLimitMiddlewareId = "pbBodyLimit"
|
||||
DefaultBodyLimitMiddlewarePriority = DefaultRateLimitMiddlewarePriority + 10
|
||||
)
|
||||
|
||||
// BodyLimit returns a middleware handler that changes the default request body size limit.
|
||||
//
|
||||
// If limitBytes <= 0, no limit is applied.
|
||||
//
|
||||
// Otherwise, if the request body size exceeds the configured limitBytes,
|
||||
// it sends 413 error response.
|
||||
func BodyLimit(limitBytes int64) *hook.Handler[*core.RequestEvent] {
|
||||
return &hook.Handler[*core.RequestEvent]{
|
||||
Id: DefaultBodyLimitMiddlewareId,
|
||||
Priority: DefaultBodyLimitMiddlewarePriority,
|
||||
Func: func(e *core.RequestEvent) error {
|
||||
err := applyBodyLimit(e, limitBytes)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return e.Next()
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func dynamicCollectionBodyLimit(collectionPathParam string) *hook.Handler[*core.RequestEvent] {
|
||||
if collectionPathParam == "" {
|
||||
collectionPathParam = "collection"
|
||||
}
|
||||
|
||||
return &hook.Handler[*core.RequestEvent]{
|
||||
Id: DefaultBodyLimitMiddlewareId,
|
||||
Priority: DefaultBodyLimitMiddlewarePriority,
|
||||
Func: func(e *core.RequestEvent) error {
|
||||
collection, err := e.App.FindCachedCollectionByNameOrId(e.Request.PathValue(collectionPathParam))
|
||||
if err != nil {
|
||||
return e.NotFoundError("Missing or invalid collection context.", err)
|
||||
}
|
||||
|
||||
limitBytes := DefaultMaxBodySize
|
||||
if !collection.IsView() {
|
||||
for _, f := range collection.Fields {
|
||||
if calc, ok := f.(core.MaxBodySizeCalculator); ok {
|
||||
limitBytes += calc.CalculateMaxBodySize()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
err = applyBodyLimit(e, limitBytes)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return e.Next()
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func applyBodyLimit(e *core.RequestEvent, limitBytes int64) error {
|
||||
// no limit
|
||||
if limitBytes <= 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// optimistically check the submitted request content length
|
||||
if e.Request.ContentLength > limitBytes {
|
||||
return ErrRequestEntityTooLarge
|
||||
}
|
||||
|
||||
// replace the request body
|
||||
//
|
||||
// note: we don't use sync.Pool since the size of the elements could vary too much
|
||||
// and it might not be efficient (see https://github.com/golang/go/issues/23199)
|
||||
e.Request.Body = &limitedReader{ReadCloser: e.Request.Body, limit: limitBytes}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type limitedReader struct {
|
||||
io.ReadCloser
|
||||
limit int64
|
||||
totalRead int64
|
||||
}
|
||||
|
||||
func (r *limitedReader) Read(b []byte) (int, error) {
|
||||
n, err := r.ReadCloser.Read(b)
|
||||
if err != nil {
|
||||
return n, err
|
||||
}
|
||||
|
||||
r.totalRead += int64(n)
|
||||
if r.totalRead > r.limit {
|
||||
return n, ErrRequestEntityTooLarge
|
||||
}
|
||||
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func (r *limitedReader) Reread() {
|
||||
rr, ok := r.ReadCloser.(router.Rereader)
|
||||
if ok {
|
||||
rr.Reread()
|
||||
}
|
||||
}
|
60
apis/middlewares_body_limit_test.go
Normal file
60
apis/middlewares_body_limit_test.go
Normal file
|
@ -0,0 +1,60 @@
|
|||
package apis_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/pocketbase/pocketbase/apis"
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tests"
|
||||
)
|
||||
|
||||
func TestBodyLimitMiddleware(t *testing.T) {
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
pbRouter, err := apis.NewRouter(app)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
pbRouter.POST("/a", func(e *core.RequestEvent) error {
|
||||
return e.String(200, "a")
|
||||
}) // default global BodyLimit check
|
||||
|
||||
pbRouter.POST("/b", func(e *core.RequestEvent) error {
|
||||
return e.String(200, "b")
|
||||
}).Bind(apis.BodyLimit(20))
|
||||
|
||||
mux, err := pbRouter.BuildMux()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
scenarios := []struct {
|
||||
url string
|
||||
size int64
|
||||
expectedStatus int
|
||||
}{
|
||||
{"/a", 21, 200},
|
||||
{"/a", apis.DefaultMaxBodySize + 1, 413},
|
||||
{"/b", 20, 200},
|
||||
{"/b", 21, 413},
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
t.Run(fmt.Sprintf("%s_%d", s.url, s.size), func(t *testing.T) {
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest("POST", s.url, bytes.NewReader(make([]byte, s.size)))
|
||||
mux.ServeHTTP(rec, req)
|
||||
|
||||
result := rec.Result()
|
||||
defer result.Body.Close()
|
||||
|
||||
if result.StatusCode != s.expectedStatus {
|
||||
t.Fatalf("Expected response status %d, got %d", s.expectedStatus, result.StatusCode)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
327
apis/middlewares_cors.go
Normal file
327
apis/middlewares_cors.go
Normal file
|
@ -0,0 +1,327 @@
|
|||
package apis
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
// This middleware is ported from echo/middleware to minimize the breaking
|
||||
// changes and differences in the API behavior from earlier PocketBase versions
|
||||
// (https://github.com/labstack/echo/blob/ec5b858dab6105ab4c3ed2627d1ebdfb6ae1ecb8/middleware/cors.go).
|
||||
//
|
||||
// I doubt that this would matter for most cases, but the only major difference
|
||||
// is that for non-supported routes this middleware doesn't return 405 and fallbacks
|
||||
// to the default catch-all PocketBase route (aka. returns 404) to avoid
|
||||
// the extra overhead of further hijacking and wrapping the Go default mux
|
||||
// (https://github.com/golang/go/issues/65648#issuecomment-1955328807).
|
||||
// -------------------------------------------------------------------
|
||||
|
||||
import (
|
||||
"log"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tools/hook"
|
||||
)
|
||||
|
||||
const (
|
||||
DefaultCorsMiddlewareId = "pbCors"
|
||||
DefaultCorsMiddlewarePriority = DefaultActivityLoggerMiddlewarePriority - 1 // before the activity logger and rate limit so that OPTIONS preflight requests are not counted
|
||||
)
|
||||
|
||||
// CORSConfig defines the config for CORS middleware.
|
||||
type CORSConfig struct {
|
||||
// AllowOrigins determines the value of the Access-Control-Allow-Origin
|
||||
// response header. This header defines a list of origins that may access the
|
||||
// resource. The wildcard characters '*' and '?' are supported and are
|
||||
// converted to regex fragments '.*' and '.' accordingly.
|
||||
//
|
||||
// Security: use extreme caution when handling the origin, and carefully
|
||||
// validate any logic. Remember that attackers may register hostile domain names.
|
||||
// See https://blog.portswigger.net/2016/10/exploiting-cors-misconfigurations-for.html
|
||||
//
|
||||
// Optional. Default value []string{"*"}.
|
||||
//
|
||||
// See also: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Origin
|
||||
AllowOrigins []string
|
||||
|
||||
// AllowOriginFunc is a custom function to validate the origin. It takes the
|
||||
// origin as an argument and returns true if allowed or false otherwise. If
|
||||
// an error is returned, it is returned by the handler. If this option is
|
||||
// set, AllowOrigins is ignored.
|
||||
//
|
||||
// Security: use extreme caution when handling the origin, and carefully
|
||||
// validate any logic. Remember that attackers may register hostile domain names.
|
||||
// See https://blog.portswigger.net/2016/10/exploiting-cors-misconfigurations-for.html
|
||||
//
|
||||
// Optional.
|
||||
AllowOriginFunc func(origin string) (bool, error)
|
||||
|
||||
// AllowMethods determines the value of the Access-Control-Allow-Methods
|
||||
// response header. This header specified the list of methods allowed when
|
||||
// accessing the resource. This is used in response to a preflight request.
|
||||
//
|
||||
// Optional. Default value DefaultCORSConfig.AllowMethods.
|
||||
//
|
||||
// See also: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Methods
|
||||
AllowMethods []string
|
||||
|
||||
// AllowHeaders determines the value of the Access-Control-Allow-Headers
|
||||
// response header. This header is used in response to a preflight request to
|
||||
// indicate which HTTP headers can be used when making the actual request.
|
||||
//
|
||||
// Optional. Default value []string{}.
|
||||
//
|
||||
// See also: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Headers
|
||||
AllowHeaders []string
|
||||
|
||||
// AllowCredentials determines the value of the
|
||||
// Access-Control-Allow-Credentials response header. This header indicates
|
||||
// whether or not the response to the request can be exposed when the
|
||||
// credentials mode (Request.credentials) is true. When used as part of a
|
||||
// response to a preflight request, this indicates whether or not the actual
|
||||
// request can be made using credentials. See also
|
||||
// [MDN: Access-Control-Allow-Credentials].
|
||||
//
|
||||
// Optional. Default value false, in which case the header is not set.
|
||||
//
|
||||
// Security: avoid using `AllowCredentials = true` with `AllowOrigins = *`.
|
||||
// See "Exploiting CORS misconfigurations for Bitcoins and bounties",
|
||||
// https://blog.portswigger.net/2016/10/exploiting-cors-misconfigurations-for.html
|
||||
//
|
||||
// See also: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Credentials
|
||||
AllowCredentials bool
|
||||
|
||||
// UnsafeWildcardOriginWithAllowCredentials UNSAFE/INSECURE: allows wildcard '*' origin to be used with AllowCredentials
|
||||
// flag. In that case we consider any origin allowed and send it back to the client with `Access-Control-Allow-Origin` header.
|
||||
//
|
||||
// This is INSECURE and potentially leads to [cross-origin](https://portswigger.net/research/exploiting-cors-misconfigurations-for-bitcoins-and-bounties)
|
||||
// attacks. See: https://github.com/labstack/echo/issues/2400 for discussion on the subject.
|
||||
//
|
||||
// Optional. Default value is false.
|
||||
UnsafeWildcardOriginWithAllowCredentials bool
|
||||
|
||||
// ExposeHeaders determines the value of Access-Control-Expose-Headers, which
|
||||
// defines a list of headers that clients are allowed to access.
|
||||
//
|
||||
// Optional. Default value []string{}, in which case the header is not set.
|
||||
//
|
||||
// See also: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Expose-Header
|
||||
ExposeHeaders []string
|
||||
|
||||
// MaxAge determines the value of the Access-Control-Max-Age response header.
|
||||
// This header indicates how long (in seconds) the results of a preflight
|
||||
// request can be cached.
|
||||
// The header is set only if MaxAge != 0, negative value sends "0" which instructs browsers not to cache that response.
|
||||
//
|
||||
// Optional. Default value 0 - meaning header is not sent.
|
||||
//
|
||||
// See also: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Max-Age
|
||||
MaxAge int
|
||||
}
|
||||
|
||||
// DefaultCORSConfig is the default CORS middleware config.
|
||||
var DefaultCORSConfig = CORSConfig{
|
||||
AllowOrigins: []string{"*"},
|
||||
AllowMethods: []string{http.MethodGet, http.MethodHead, http.MethodPut, http.MethodPatch, http.MethodPost, http.MethodDelete},
|
||||
}
|
||||
|
||||
// CORS returns a CORS middleware.
|
||||
func CORS(config CORSConfig) *hook.Handler[*core.RequestEvent] {
|
||||
// Defaults
|
||||
if len(config.AllowOrigins) == 0 {
|
||||
config.AllowOrigins = DefaultCORSConfig.AllowOrigins
|
||||
}
|
||||
if len(config.AllowMethods) == 0 {
|
||||
config.AllowMethods = DefaultCORSConfig.AllowMethods
|
||||
}
|
||||
|
||||
allowOriginPatterns := make([]*regexp.Regexp, 0, len(config.AllowOrigins))
|
||||
for _, origin := range config.AllowOrigins {
|
||||
if origin == "*" {
|
||||
continue // "*" is handled differently and does not need regexp
|
||||
}
|
||||
|
||||
pattern := regexp.QuoteMeta(origin)
|
||||
pattern = strings.ReplaceAll(pattern, "\\*", ".*")
|
||||
pattern = strings.ReplaceAll(pattern, "\\?", ".")
|
||||
pattern = "^" + pattern + "$"
|
||||
|
||||
re, err := regexp.Compile(pattern)
|
||||
if err != nil {
|
||||
// This is to preserve previous behaviour - invalid patterns were just ignored.
|
||||
// If we would turn this to panic, users with invalid patterns
|
||||
// would have applications crashing in production due unrecovered panic.
|
||||
log.Println("invalid AllowOrigins pattern", origin)
|
||||
continue
|
||||
}
|
||||
allowOriginPatterns = append(allowOriginPatterns, re)
|
||||
}
|
||||
|
||||
allowMethods := strings.Join(config.AllowMethods, ",")
|
||||
allowHeaders := strings.Join(config.AllowHeaders, ",")
|
||||
exposeHeaders := strings.Join(config.ExposeHeaders, ",")
|
||||
|
||||
maxAge := "0"
|
||||
if config.MaxAge > 0 {
|
||||
maxAge = strconv.Itoa(config.MaxAge)
|
||||
}
|
||||
|
||||
return &hook.Handler[*core.RequestEvent]{
|
||||
Id: DefaultCorsMiddlewareId,
|
||||
Priority: DefaultCorsMiddlewarePriority,
|
||||
Func: func(e *core.RequestEvent) error {
|
||||
req := e.Request
|
||||
res := e.Response
|
||||
origin := req.Header.Get("Origin")
|
||||
allowOrigin := ""
|
||||
|
||||
res.Header().Add("Vary", "Origin")
|
||||
|
||||
// Preflight request is an OPTIONS request, using three HTTP request headers: Access-Control-Request-Method,
|
||||
// Access-Control-Request-Headers, and the Origin header. See: https://developer.mozilla.org/en-US/docs/Glossary/Preflight_request
|
||||
// For simplicity we just consider method type and later `Origin` header.
|
||||
preflight := req.Method == http.MethodOptions
|
||||
|
||||
// No Origin provided. This is (probably) not request from actual browser - proceed executing middleware chain
|
||||
if origin == "" {
|
||||
if !preflight {
|
||||
return e.Next()
|
||||
}
|
||||
return e.NoContent(http.StatusNoContent)
|
||||
}
|
||||
|
||||
if config.AllowOriginFunc != nil {
|
||||
allowed, err := config.AllowOriginFunc(origin)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if allowed {
|
||||
allowOrigin = origin
|
||||
}
|
||||
} else {
|
||||
// Check allowed origins
|
||||
for _, o := range config.AllowOrigins {
|
||||
if o == "*" && config.AllowCredentials && config.UnsafeWildcardOriginWithAllowCredentials {
|
||||
allowOrigin = origin
|
||||
break
|
||||
}
|
||||
if o == "*" || o == origin {
|
||||
allowOrigin = o
|
||||
break
|
||||
}
|
||||
if matchSubdomain(origin, o) {
|
||||
allowOrigin = origin
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
checkPatterns := false
|
||||
if allowOrigin == "" {
|
||||
// to avoid regex cost by invalid (long) domains (253 is domain name max limit)
|
||||
if len(origin) <= (253+3+5) && strings.Contains(origin, "://") {
|
||||
checkPatterns = true
|
||||
}
|
||||
}
|
||||
if checkPatterns {
|
||||
for _, re := range allowOriginPatterns {
|
||||
if match := re.MatchString(origin); match {
|
||||
allowOrigin = origin
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Origin not allowed
|
||||
if allowOrigin == "" {
|
||||
if !preflight {
|
||||
return e.Next()
|
||||
}
|
||||
return e.NoContent(http.StatusNoContent)
|
||||
}
|
||||
|
||||
res.Header().Set("Access-Control-Allow-Origin", allowOrigin)
|
||||
if config.AllowCredentials {
|
||||
res.Header().Set("Access-Control-Allow-Credentials", "true")
|
||||
}
|
||||
|
||||
// Simple request
|
||||
if !preflight {
|
||||
if exposeHeaders != "" {
|
||||
res.Header().Set("Access-Control-Expose-Headers", exposeHeaders)
|
||||
}
|
||||
return e.Next()
|
||||
}
|
||||
|
||||
// Preflight request
|
||||
res.Header().Add("Vary", "Access-Control-Request-Method")
|
||||
res.Header().Add("Vary", "Access-Control-Request-Headers")
|
||||
res.Header().Set("Access-Control-Allow-Methods", allowMethods)
|
||||
|
||||
if allowHeaders != "" {
|
||||
res.Header().Set("Access-Control-Allow-Headers", allowHeaders)
|
||||
} else {
|
||||
h := req.Header.Get("Access-Control-Request-Headers")
|
||||
if h != "" {
|
||||
res.Header().Set("Access-Control-Allow-Headers", h)
|
||||
}
|
||||
}
|
||||
if config.MaxAge != 0 {
|
||||
res.Header().Set("Access-Control-Max-Age", maxAge)
|
||||
}
|
||||
|
||||
return e.NoContent(http.StatusNoContent)
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func matchScheme(domain, pattern string) bool {
|
||||
didx := strings.Index(domain, ":")
|
||||
pidx := strings.Index(pattern, ":")
|
||||
return didx != -1 && pidx != -1 && domain[:didx] == pattern[:pidx]
|
||||
}
|
||||
|
||||
// matchSubdomain compares authority with wildcard
|
||||
func matchSubdomain(domain, pattern string) bool {
|
||||
if !matchScheme(domain, pattern) {
|
||||
return false
|
||||
}
|
||||
|
||||
didx := strings.Index(domain, "://")
|
||||
pidx := strings.Index(pattern, "://")
|
||||
if didx == -1 || pidx == -1 {
|
||||
return false
|
||||
}
|
||||
domAuth := domain[didx+3:]
|
||||
// to avoid long loop by invalid long domain
|
||||
if len(domAuth) > 253 {
|
||||
return false
|
||||
}
|
||||
patAuth := pattern[pidx+3:]
|
||||
|
||||
domComp := strings.Split(domAuth, ".")
|
||||
patComp := strings.Split(patAuth, ".")
|
||||
for i := len(domComp)/2 - 1; i >= 0; i-- {
|
||||
opp := len(domComp) - 1 - i
|
||||
domComp[i], domComp[opp] = domComp[opp], domComp[i]
|
||||
}
|
||||
for i := len(patComp)/2 - 1; i >= 0; i-- {
|
||||
opp := len(patComp) - 1 - i
|
||||
patComp[i], patComp[opp] = patComp[opp], patComp[i]
|
||||
}
|
||||
|
||||
for i, v := range domComp {
|
||||
if len(patComp) <= i {
|
||||
return false
|
||||
}
|
||||
p := patComp[i]
|
||||
if p == "*" {
|
||||
return true
|
||||
}
|
||||
if p != v {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
247
apis/middlewares_gzip.go
Normal file
247
apis/middlewares_gzip.go
Normal file
|
@ -0,0 +1,247 @@
|
|||
package apis
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
// This middleware is ported from echo/middleware to minimize the breaking
|
||||
// changes and differences in the API behavior from earlier PocketBase versions
|
||||
// (https://github.com/labstack/echo/blob/ec5b858dab6105ab4c3ed2627d1ebdfb6ae1ecb8/middleware/compress.go).
|
||||
// -------------------------------------------------------------------
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"compress/gzip"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tools/hook"
|
||||
"github.com/pocketbase/pocketbase/tools/router"
|
||||
)
|
||||
|
||||
const (
|
||||
gzipScheme = "gzip"
|
||||
)
|
||||
|
||||
const (
|
||||
DefaultGzipMiddlewareId = "pbGzip"
|
||||
)
|
||||
|
||||
// GzipConfig defines the config for Gzip middleware.
|
||||
type GzipConfig struct {
|
||||
// Gzip compression level.
|
||||
// Optional. Default value -1.
|
||||
Level int
|
||||
|
||||
// Length threshold before gzip compression is applied.
|
||||
// Optional. Default value 0.
|
||||
//
|
||||
// Most of the time you will not need to change the default. Compressing
|
||||
// a short response might increase the transmitted data because of the
|
||||
// gzip format overhead. Compressing the response will also consume CPU
|
||||
// and time on the server and the client (for decompressing). Depending on
|
||||
// your use case such a threshold might be useful.
|
||||
//
|
||||
// See also:
|
||||
// https://webmasters.stackexchange.com/questions/31750/what-is-recommended-minimum-object-size-for-gzip-performance-benefits
|
||||
MinLength int
|
||||
}
|
||||
|
||||
// Gzip returns a middleware which compresses HTTP response using Gzip compression scheme.
|
||||
func Gzip() *hook.Handler[*core.RequestEvent] {
|
||||
return GzipWithConfig(GzipConfig{})
|
||||
}
|
||||
|
||||
// GzipWithConfig returns a middleware which compresses HTTP response using gzip compression scheme.
|
||||
func GzipWithConfig(config GzipConfig) *hook.Handler[*core.RequestEvent] {
|
||||
if config.Level < -2 || config.Level > 9 { // these are consts: gzip.HuffmanOnly and gzip.BestCompression
|
||||
panic(errors.New("invalid gzip level"))
|
||||
}
|
||||
if config.Level == 0 {
|
||||
config.Level = -1
|
||||
}
|
||||
if config.MinLength < 0 {
|
||||
config.MinLength = 0
|
||||
}
|
||||
|
||||
pool := sync.Pool{
|
||||
New: func() interface{} {
|
||||
w, err := gzip.NewWriterLevel(io.Discard, config.Level)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return w
|
||||
},
|
||||
}
|
||||
|
||||
bpool := sync.Pool{
|
||||
New: func() interface{} {
|
||||
b := &bytes.Buffer{}
|
||||
return b
|
||||
},
|
||||
}
|
||||
|
||||
return &hook.Handler[*core.RequestEvent]{
|
||||
Id: DefaultGzipMiddlewareId,
|
||||
Func: func(e *core.RequestEvent) error {
|
||||
e.Response.Header().Add("Vary", "Accept-Encoding")
|
||||
if strings.Contains(e.Request.Header.Get("Accept-Encoding"), gzipScheme) {
|
||||
w, ok := pool.Get().(*gzip.Writer)
|
||||
if !ok {
|
||||
return e.InternalServerError("", errors.New("failed to get gzip.Writer"))
|
||||
}
|
||||
|
||||
rw := e.Response
|
||||
w.Reset(rw)
|
||||
|
||||
buf := bpool.Get().(*bytes.Buffer)
|
||||
buf.Reset()
|
||||
|
||||
grw := &gzipResponseWriter{Writer: w, ResponseWriter: rw, minLength: config.MinLength, buffer: buf}
|
||||
defer func() {
|
||||
// There are different reasons for cases when we have not yet written response to the client and now need to do so.
|
||||
// a) handler response had only response code and no response body (ala 404 or redirects etc). Response code need to be written now.
|
||||
// b) body is shorter than our minimum length threshold and being buffered currently and needs to be written
|
||||
if !grw.wroteBody {
|
||||
if rw.Header().Get("Content-Encoding") == gzipScheme {
|
||||
rw.Header().Del("Content-Encoding")
|
||||
}
|
||||
if grw.wroteHeader {
|
||||
rw.WriteHeader(grw.code)
|
||||
}
|
||||
// We have to reset response to it's pristine state when
|
||||
// nothing is written to body or error is returned.
|
||||
// See issue echo#424, echo#407.
|
||||
e.Response = rw
|
||||
w.Reset(io.Discard)
|
||||
} else if !grw.minLengthExceeded {
|
||||
// Write uncompressed response
|
||||
e.Response = rw
|
||||
if grw.wroteHeader {
|
||||
rw.WriteHeader(grw.code)
|
||||
}
|
||||
grw.buffer.WriteTo(rw)
|
||||
w.Reset(io.Discard)
|
||||
}
|
||||
w.Close()
|
||||
bpool.Put(buf)
|
||||
pool.Put(w)
|
||||
}()
|
||||
e.Response = grw
|
||||
}
|
||||
|
||||
return e.Next()
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
type gzipResponseWriter struct {
|
||||
http.ResponseWriter
|
||||
io.Writer
|
||||
buffer *bytes.Buffer
|
||||
minLength int
|
||||
code int
|
||||
wroteHeader bool
|
||||
wroteBody bool
|
||||
minLengthExceeded bool
|
||||
}
|
||||
|
||||
func (w *gzipResponseWriter) WriteHeader(code int) {
|
||||
w.Header().Del("Content-Length") // Issue echo#444
|
||||
|
||||
w.wroteHeader = true
|
||||
|
||||
// Delay writing of the header until we know if we'll actually compress the response
|
||||
w.code = code
|
||||
}
|
||||
|
||||
func (w *gzipResponseWriter) Write(b []byte) (int, error) {
|
||||
if w.Header().Get("Content-Type") == "" {
|
||||
w.Header().Set("Content-Type", http.DetectContentType(b))
|
||||
}
|
||||
|
||||
w.wroteBody = true
|
||||
|
||||
if !w.minLengthExceeded {
|
||||
n, err := w.buffer.Write(b)
|
||||
|
||||
if w.buffer.Len() >= w.minLength {
|
||||
w.minLengthExceeded = true
|
||||
|
||||
// The minimum length is exceeded, add Content-Encoding header and write the header
|
||||
w.Header().Set("Content-Encoding", gzipScheme)
|
||||
if w.wroteHeader {
|
||||
w.ResponseWriter.WriteHeader(w.code)
|
||||
}
|
||||
|
||||
return w.Writer.Write(w.buffer.Bytes())
|
||||
}
|
||||
|
||||
return n, err
|
||||
}
|
||||
|
||||
return w.Writer.Write(b)
|
||||
}
|
||||
|
||||
func (w *gzipResponseWriter) Flush() {
|
||||
if !w.minLengthExceeded {
|
||||
// Enforce compression because we will not know how much more data will come
|
||||
w.minLengthExceeded = true
|
||||
w.Header().Set("Content-Encoding", gzipScheme)
|
||||
if w.wroteHeader {
|
||||
w.ResponseWriter.WriteHeader(w.code)
|
||||
}
|
||||
|
||||
_, _ = w.Writer.Write(w.buffer.Bytes())
|
||||
}
|
||||
|
||||
_ = w.Writer.(*gzip.Writer).Flush()
|
||||
|
||||
_ = http.NewResponseController(w.ResponseWriter).Flush()
|
||||
}
|
||||
|
||||
func (w *gzipResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||
return http.NewResponseController(w.ResponseWriter).Hijack()
|
||||
}
|
||||
|
||||
func (w *gzipResponseWriter) Push(target string, opts *http.PushOptions) error {
|
||||
rw := w.ResponseWriter
|
||||
for {
|
||||
switch p := rw.(type) {
|
||||
case http.Pusher:
|
||||
return p.Push(target, opts)
|
||||
case router.RWUnwrapper:
|
||||
rw = p.Unwrap()
|
||||
default:
|
||||
return http.ErrNotSupported
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Note: Disable the implementation for now because in case the platform
|
||||
// supports the sendfile fast-path it won't run gzipResponseWriter.Write,
|
||||
// preventing compression on the fly.
|
||||
//
|
||||
// func (w *gzipResponseWriter) ReadFrom(r io.Reader) (n int64, err error) {
|
||||
// if w.wroteHeader {
|
||||
// w.ResponseWriter.WriteHeader(w.code)
|
||||
// }
|
||||
// rw := w.ResponseWriter
|
||||
// for {
|
||||
// switch rf := rw.(type) {
|
||||
// case io.ReaderFrom:
|
||||
// return rf.ReadFrom(r)
|
||||
// case router.RWUnwrapper:
|
||||
// rw = rf.Unwrap()
|
||||
// default:
|
||||
// return io.Copy(w.ResponseWriter, r)
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
|
||||
func (w *gzipResponseWriter) Unwrap() http.ResponseWriter {
|
||||
return w.ResponseWriter
|
||||
}
|
356
apis/middlewares_rate_limit.go
Normal file
356
apis/middlewares_rate_limit.go
Normal file
|
@ -0,0 +1,356 @@
|
|||
package apis
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tools/hook"
|
||||
"github.com/pocketbase/pocketbase/tools/store"
|
||||
)
|
||||
|
||||
const (
|
||||
DefaultRateLimitMiddlewareId = "pbRateLimit"
|
||||
DefaultRateLimitMiddlewarePriority = -1000
|
||||
)
|
||||
|
||||
const (
|
||||
rateLimitersStoreKey = "__pbRateLimiters__"
|
||||
rateLimitersCronKey = "__pbRateLimitersCleanup__"
|
||||
rateLimitersSettingsHookId = "__pbRateLimitersSettingsHook__"
|
||||
)
|
||||
|
||||
// rateLimit defines the global rate limit middleware.
|
||||
//
|
||||
// This middleware is registered by default for all routes.
|
||||
func rateLimit() *hook.Handler[*core.RequestEvent] {
|
||||
return &hook.Handler[*core.RequestEvent]{
|
||||
Id: DefaultRateLimitMiddlewareId,
|
||||
Priority: DefaultRateLimitMiddlewarePriority,
|
||||
Func: func(e *core.RequestEvent) error {
|
||||
if skipRateLimit(e) {
|
||||
return e.Next()
|
||||
}
|
||||
|
||||
rule, ok := e.App.Settings().RateLimits.FindRateLimitRule(
|
||||
defaultRateLimitLabels(e),
|
||||
defaultRateLimitAudience(e)...,
|
||||
)
|
||||
if ok {
|
||||
err := checkRateLimit(e, rule.Label+rule.Audience, rule)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return e.Next()
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// collectionPathRateLimit defines a rate limit middleware for the internal collection handlers.
|
||||
func collectionPathRateLimit(collectionPathParam string, baseTags ...string) *hook.Handler[*core.RequestEvent] {
|
||||
if collectionPathParam == "" {
|
||||
collectionPathParam = "collection"
|
||||
}
|
||||
|
||||
return &hook.Handler[*core.RequestEvent]{
|
||||
Id: DefaultRateLimitMiddlewareId,
|
||||
Priority: DefaultRateLimitMiddlewarePriority,
|
||||
Func: func(e *core.RequestEvent) error {
|
||||
collection, err := e.App.FindCachedCollectionByNameOrId(e.Request.PathValue(collectionPathParam))
|
||||
if err != nil {
|
||||
return e.NotFoundError("Missing or invalid collection context.", err)
|
||||
}
|
||||
|
||||
if err := checkCollectionRateLimit(e, collection, baseTags...); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return e.Next()
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// checkCollectionRateLimit checks whether the current request satisfy the
|
||||
// rate limit configuration for the specific collection.
|
||||
//
|
||||
// Each baseTags entry will be prefixed with the collection name and its wildcard variant.
|
||||
func checkCollectionRateLimit(e *core.RequestEvent, collection *core.Collection, baseTags ...string) error {
|
||||
if skipRateLimit(e) {
|
||||
return nil
|
||||
}
|
||||
|
||||
labels := make([]string, 0, 2+len(baseTags)*2)
|
||||
|
||||
rtId := collection.Id + e.Request.Pattern
|
||||
|
||||
// add first the primary labels (aka. ["collectionName:action1", "collectionName:action2"])
|
||||
for _, baseTag := range baseTags {
|
||||
rtId += baseTag
|
||||
labels = append(labels, collection.Name+":"+baseTag)
|
||||
}
|
||||
|
||||
// add the wildcard labels (aka. [..., "*:action1","*:action2", "*"])
|
||||
for _, baseTag := range baseTags {
|
||||
labels = append(labels, "*:"+baseTag)
|
||||
}
|
||||
labels = append(labels, defaultRateLimitLabels(e)...)
|
||||
|
||||
rule, ok := e.App.Settings().RateLimits.FindRateLimitRule(labels, defaultRateLimitAudience(e)...)
|
||||
if ok {
|
||||
return checkRateLimit(e, rtId+rule.Audience, rule)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
|
||||
// @todo consider exporting as helper?
|
||||
//
|
||||
//nolint:unused
|
||||
func isClientRateLimited(e *core.RequestEvent, rtId string) bool {
|
||||
rateLimiters, ok := e.App.Store().Get(rateLimitersStoreKey).(*store.Store[string, *rateLimiter])
|
||||
if !ok || rateLimiters == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
rt, ok := rateLimiters.GetOk(rtId)
|
||||
if !ok || rt == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
client, ok := rt.getClient(e.RealIP())
|
||||
if !ok || client == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
return client.available <= 0 && time.Now().Unix()-client.lastConsume < client.interval
|
||||
}
|
||||
|
||||
// @todo consider exporting as helper?
|
||||
func checkRateLimit(e *core.RequestEvent, rtId string, rule core.RateLimitRule) error {
|
||||
switch rule.Audience {
|
||||
case core.RateLimitRuleAudienceAll:
|
||||
// valid for both guest and regular users
|
||||
case core.RateLimitRuleAudienceGuest:
|
||||
if e.Auth != nil {
|
||||
return nil
|
||||
}
|
||||
case core.RateLimitRuleAudienceAuth:
|
||||
if e.Auth == nil {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
rateLimiters := e.App.Store().GetOrSet(rateLimitersStoreKey, func() any {
|
||||
return initRateLimitersStore(e.App)
|
||||
}).(*store.Store[string, *rateLimiter])
|
||||
if rateLimiters == nil {
|
||||
e.App.Logger().Warn("Failed to retrieve app rate limiters store")
|
||||
return nil
|
||||
}
|
||||
|
||||
rt := rateLimiters.GetOrSet(rtId, func() *rateLimiter {
|
||||
return newRateLimiter(rule.MaxRequests, rule.Duration, rule.Duration+1800)
|
||||
})
|
||||
if rt == nil {
|
||||
e.App.Logger().Warn("Failed to retrieve app rate limiter", "id", rtId)
|
||||
return nil
|
||||
}
|
||||
|
||||
key := e.RealIP()
|
||||
if key == "" {
|
||||
e.App.Logger().Warn("Empty rate limit client key")
|
||||
return nil
|
||||
}
|
||||
|
||||
if !rt.isAllowed(key) {
|
||||
return e.TooManyRequestsError("", nil)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func skipRateLimit(e *core.RequestEvent) bool {
|
||||
return !e.App.Settings().RateLimits.Enabled || e.HasSuperuserAuth()
|
||||
}
|
||||
|
||||
var defaultAuthAudience = []string{core.RateLimitRuleAudienceAll, core.RateLimitRuleAudienceAuth}
|
||||
var defaultGuestAudience = []string{core.RateLimitRuleAudienceAll, core.RateLimitRuleAudienceGuest}
|
||||
|
||||
func defaultRateLimitAudience(e *core.RequestEvent) []string {
|
||||
if e.Auth != nil {
|
||||
return defaultAuthAudience
|
||||
}
|
||||
|
||||
return defaultGuestAudience
|
||||
}
|
||||
|
||||
func defaultRateLimitLabels(e *core.RequestEvent) []string {
|
||||
return []string{e.Request.Method + " " + e.Request.URL.Path, e.Request.URL.Path}
|
||||
}
|
||||
|
||||
func destroyRateLimitersStore(app core.App) {
|
||||
app.OnSettingsReload().Unbind(rateLimitersSettingsHookId)
|
||||
app.Cron().Remove(rateLimitersCronKey)
|
||||
app.Store().Remove(rateLimitersStoreKey)
|
||||
}
|
||||
|
||||
func initRateLimitersStore(app core.App) *store.Store[string, *rateLimiter] {
|
||||
app.Cron().Add(rateLimitersCronKey, "2 * * * *", func() { // offset a little since too many cleanup tasks execute at 00
|
||||
limitersStore, ok := app.Store().Get(rateLimitersStoreKey).(*store.Store[string, *rateLimiter])
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
limiters := limitersStore.GetAll()
|
||||
for _, limiter := range limiters {
|
||||
limiter.clean()
|
||||
}
|
||||
})
|
||||
|
||||
app.OnSettingsReload().Bind(&hook.Handler[*core.SettingsReloadEvent]{
|
||||
Id: rateLimitersSettingsHookId,
|
||||
Func: func(e *core.SettingsReloadEvent) error {
|
||||
err := e.Next()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// reset
|
||||
destroyRateLimitersStore(e.App)
|
||||
|
||||
return nil
|
||||
},
|
||||
})
|
||||
|
||||
return store.New[string, *rateLimiter](nil)
|
||||
}
|
||||
|
||||
func newRateLimiter(maxAllowed int, intervalInSec int64, minDeleteIntervalInSec int64) *rateLimiter {
|
||||
return &rateLimiter{
|
||||
maxAllowed: maxAllowed,
|
||||
interval: intervalInSec,
|
||||
minDeleteInterval: minDeleteIntervalInSec,
|
||||
clients: map[string]*fixedWindow{},
|
||||
}
|
||||
}
|
||||
|
||||
type rateLimiter struct {
|
||||
clients map[string]*fixedWindow
|
||||
|
||||
maxAllowed int
|
||||
interval int64
|
||||
minDeleteInterval int64
|
||||
totalDeleted int64
|
||||
|
||||
sync.RWMutex
|
||||
}
|
||||
|
||||
//nolint:unused
|
||||
func (rt *rateLimiter) getClient(key string) (*fixedWindow, bool) {
|
||||
rt.RLock()
|
||||
client, ok := rt.clients[key]
|
||||
rt.RUnlock()
|
||||
|
||||
return client, ok
|
||||
}
|
||||
|
||||
func (rt *rateLimiter) isAllowed(key string) bool {
|
||||
// lock only reads to minimize locks contention
|
||||
rt.RLock()
|
||||
client, ok := rt.clients[key]
|
||||
rt.RUnlock()
|
||||
|
||||
if !ok {
|
||||
rt.Lock()
|
||||
// check again in case the client was added by another request
|
||||
client, ok = rt.clients[key]
|
||||
if !ok {
|
||||
client = newFixedWindow(rt.maxAllowed, rt.interval)
|
||||
rt.clients[key] = client
|
||||
}
|
||||
rt.Unlock()
|
||||
}
|
||||
|
||||
return client.consume()
|
||||
}
|
||||
|
||||
func (rt *rateLimiter) clean() {
|
||||
rt.Lock()
|
||||
defer rt.Unlock()
|
||||
|
||||
nowUnix := time.Now().Unix()
|
||||
|
||||
for k, client := range rt.clients {
|
||||
if client.hasExpired(nowUnix, rt.minDeleteInterval) {
|
||||
delete(rt.clients, k)
|
||||
rt.totalDeleted++
|
||||
}
|
||||
}
|
||||
|
||||
// "shrink" the map if too may items were deleted
|
||||
//
|
||||
// @todo remove after https://github.com/golang/go/issues/20135
|
||||
if rt.totalDeleted >= 300 {
|
||||
shrunk := make(map[string]*fixedWindow, len(rt.clients))
|
||||
for k, v := range rt.clients {
|
||||
shrunk[k] = v
|
||||
}
|
||||
rt.clients = shrunk
|
||||
rt.totalDeleted = 0
|
||||
}
|
||||
}
|
||||
|
||||
func newFixedWindow(maxAllowed int, intervalInSec int64) *fixedWindow {
|
||||
return &fixedWindow{
|
||||
maxAllowed: maxAllowed,
|
||||
interval: intervalInSec,
|
||||
}
|
||||
}
|
||||
|
||||
type fixedWindow struct {
|
||||
// use plain Mutex instead of RWMutex since the operations are expected
|
||||
// to be mostly writes (e.g. consume()) and it should perform better
|
||||
sync.Mutex
|
||||
|
||||
maxAllowed int // the max allowed tokens per interval
|
||||
available int // the total available tokens
|
||||
interval int64 // in seconds
|
||||
lastConsume int64 // the time of the last consume
|
||||
}
|
||||
|
||||
// hasExpired checks whether it has been at least minElapsed seconds since the lastConsume time.
|
||||
// (usually used to perform periodic cleanup of staled instances).
|
||||
func (l *fixedWindow) hasExpired(relativeNow int64, minElapsed int64) bool {
|
||||
l.Lock()
|
||||
defer l.Unlock()
|
||||
|
||||
return relativeNow-l.lastConsume > minElapsed
|
||||
}
|
||||
|
||||
// consume decrease the current window allowance with 1 (if not exhausted already).
|
||||
//
|
||||
// It returns false if the allowance has been already exhausted and the user
|
||||
// has to wait until it resets back to its maxAllowed value.
|
||||
func (l *fixedWindow) consume() bool {
|
||||
l.Lock()
|
||||
defer l.Unlock()
|
||||
|
||||
nowUnix := time.Now().Unix()
|
||||
|
||||
// reset consumed counter
|
||||
if nowUnix-l.lastConsume >= l.interval {
|
||||
l.available = l.maxAllowed
|
||||
}
|
||||
|
||||
if l.available > 0 {
|
||||
l.available--
|
||||
l.lastConsume = nowUnix
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
159
apis/middlewares_rate_limit_test.go
Normal file
159
apis/middlewares_rate_limit_test.go
Normal file
|
@ -0,0 +1,159 @@
|
|||
package apis_test
|
||||
|
||||
import (
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/pocketbase/pocketbase/apis"
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tests"
|
||||
)
|
||||
|
||||
func TestDefaultRateLimitMiddleware(t *testing.T) {
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
app.Settings().RateLimits.Enabled = true
|
||||
app.Settings().RateLimits.Rules = []core.RateLimitRule{
|
||||
{
|
||||
Label: "/rate/",
|
||||
MaxRequests: 2,
|
||||
Duration: 1,
|
||||
},
|
||||
{
|
||||
Label: "/rate/b",
|
||||
MaxRequests: 3,
|
||||
Duration: 1,
|
||||
},
|
||||
{
|
||||
Label: "POST /rate/b",
|
||||
MaxRequests: 1,
|
||||
Duration: 1,
|
||||
},
|
||||
{
|
||||
Label: "/rate/guest",
|
||||
MaxRequests: 1,
|
||||
Duration: 1,
|
||||
Audience: core.RateLimitRuleAudienceGuest,
|
||||
},
|
||||
{
|
||||
Label: "/rate/auth",
|
||||
MaxRequests: 1,
|
||||
Duration: 1,
|
||||
Audience: core.RateLimitRuleAudienceAuth,
|
||||
},
|
||||
}
|
||||
|
||||
pbRouter, err := apis.NewRouter(app)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
pbRouter.GET("/norate", func(e *core.RequestEvent) error {
|
||||
return e.String(200, "norate")
|
||||
}).BindFunc(func(e *core.RequestEvent) error {
|
||||
return e.Next()
|
||||
})
|
||||
pbRouter.GET("/rate/a", func(e *core.RequestEvent) error {
|
||||
return e.String(200, "a")
|
||||
})
|
||||
pbRouter.GET("/rate/b", func(e *core.RequestEvent) error {
|
||||
return e.String(200, "b")
|
||||
})
|
||||
pbRouter.GET("/rate/guest", func(e *core.RequestEvent) error {
|
||||
return e.String(200, "guest")
|
||||
})
|
||||
pbRouter.GET("/rate/auth", func(e *core.RequestEvent) error {
|
||||
return e.String(200, "auth")
|
||||
})
|
||||
|
||||
mux, err := pbRouter.BuildMux()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
scenarios := []struct {
|
||||
url string
|
||||
wait float64
|
||||
authenticated bool
|
||||
expectedStatus int
|
||||
}{
|
||||
{"/norate", 0, false, 200},
|
||||
{"/norate", 0, false, 200},
|
||||
{"/norate", 0, false, 200},
|
||||
{"/norate", 0, false, 200},
|
||||
{"/norate", 0, false, 200},
|
||||
|
||||
{"/rate/a", 0, false, 200},
|
||||
{"/rate/a", 0, false, 200},
|
||||
{"/rate/a", 0, false, 429},
|
||||
{"/rate/a", 0, false, 429},
|
||||
{"/rate/a", 1.1, false, 200},
|
||||
{"/rate/a", 0, false, 200},
|
||||
{"/rate/a", 0, false, 429},
|
||||
|
||||
{"/rate/b", 0, false, 200},
|
||||
{"/rate/b", 0, false, 200},
|
||||
{"/rate/b", 0, false, 200},
|
||||
{"/rate/b", 0, false, 429},
|
||||
{"/rate/b", 1.1, false, 200},
|
||||
{"/rate/b", 0, false, 200},
|
||||
{"/rate/b", 0, false, 200},
|
||||
{"/rate/b", 0, false, 429},
|
||||
|
||||
// "auth" with guest (should fallback to the /rate/ rule)
|
||||
{"/rate/auth", 0, false, 200},
|
||||
{"/rate/auth", 0, false, 200},
|
||||
{"/rate/auth", 0, false, 429},
|
||||
{"/rate/auth", 0, false, 429},
|
||||
|
||||
// "auth" rule with regular user (should match the /rate/auth rule)
|
||||
{"/rate/auth", 0, true, 200},
|
||||
{"/rate/auth", 0, true, 429},
|
||||
{"/rate/auth", 0, true, 429},
|
||||
|
||||
// "guest" with guest (should match the /rate/guest rule)
|
||||
{"/rate/guest", 0, false, 200},
|
||||
{"/rate/guest", 0, false, 429},
|
||||
{"/rate/guest", 0, false, 429},
|
||||
|
||||
// "guest" rule with regular user (should fallback to the /rate/ rule)
|
||||
{"/rate/guest", 1, true, 200},
|
||||
{"/rate/guest", 0, true, 200},
|
||||
{"/rate/guest", 0, true, 429},
|
||||
{"/rate/guest", 0, true, 429},
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
t.Run(s.url, func(t *testing.T) {
|
||||
if s.wait > 0 {
|
||||
time.Sleep(time.Duration(s.wait) * time.Second)
|
||||
}
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest("GET", s.url, nil)
|
||||
|
||||
if s.authenticated {
|
||||
auth, err := app.FindAuthRecordByEmail("users", "test@example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
token, err := auth.NewAuthToken()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
req.Header.Add("Authorization", token)
|
||||
}
|
||||
|
||||
mux.ServeHTTP(rec, req)
|
||||
|
||||
result := rec.Result()
|
||||
|
||||
if result.StatusCode != s.expectedStatus {
|
||||
t.Fatalf("Expected response status %d, got %d", s.expectedStatus, result.StatusCode)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
539
apis/middlewares_test.go
Normal file
539
apis/middlewares_test.go
Normal file
|
@ -0,0 +1,539 @@
|
|||
package apis_test
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/pocketbase/pocketbase/apis"
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tests"
|
||||
)
|
||||
|
||||
func TestPanicRecover(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "panic from route",
|
||||
Method: http.MethodGet,
|
||||
URL: "/my/test",
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
e.Router.GET("/my/test", func(e *core.RequestEvent) error {
|
||||
panic("123")
|
||||
})
|
||||
},
|
||||
ExpectedStatus: 500,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "panic from middleware",
|
||||
Method: http.MethodGet,
|
||||
URL: "/my/test",
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
e.Router.GET("/my/test", func(e *core.RequestEvent) error {
|
||||
return e.String(http.StatusOK, "test")
|
||||
}).BindFunc(func(e *core.RequestEvent) error {
|
||||
panic(123)
|
||||
})
|
||||
},
|
||||
ExpectedStatus: 500,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequireGuestOnly(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
beforeTestFunc := func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
e.Router.GET("/my/test", func(e *core.RequestEvent) error {
|
||||
return e.String(200, "test123")
|
||||
}).Bind(apis.RequireGuestOnly())
|
||||
}
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "valid regular user token",
|
||||
Method: http.MethodGet,
|
||||
URL: "/my/test",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
BeforeTestFunc: beforeTestFunc,
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "valid superuser auth token",
|
||||
Method: http.MethodGet,
|
||||
URL: "/my/test",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
|
||||
},
|
||||
BeforeTestFunc: beforeTestFunc,
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "expired/invalid token",
|
||||
Method: http.MethodGet,
|
||||
URL: "/my/test",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoxNjQwOTkxNjYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.2D3tmqPn3vc5LoqqCz8V-iCDVXo9soYiH0d32G7FQT4",
|
||||
},
|
||||
BeforeTestFunc: beforeTestFunc,
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{"test123"},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "guest",
|
||||
Method: http.MethodGet,
|
||||
URL: "/my/test",
|
||||
BeforeTestFunc: beforeTestFunc,
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{"test123"},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequireAuth(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "guest",
|
||||
Method: http.MethodGet,
|
||||
URL: "/my/test",
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
e.Router.GET("/my/test", func(e *core.RequestEvent) error {
|
||||
return e.String(200, "test123")
|
||||
}).Bind(apis.RequireAuth())
|
||||
},
|
||||
ExpectedStatus: 401,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "expired token",
|
||||
Method: http.MethodGet,
|
||||
URL: "/my/test",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoxNjQwOTkxNjYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.2D3tmqPn3vc5LoqqCz8V-iCDVXo9soYiH0d32G7FQT4",
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
e.Router.GET("/my/test", func(e *core.RequestEvent) error {
|
||||
return e.String(200, "test123")
|
||||
}).Bind(apis.RequireAuth())
|
||||
},
|
||||
ExpectedStatus: 401,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "invalid token",
|
||||
Method: http.MethodGet,
|
||||
URL: "/my/test",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsImV4cCI6MjUyNDYwNDQ2MSwidHlwZSI6ImZpbGUiLCJjb2xsZWN0aW9uSWQiOiJwYmNfMzE0MjYzNTgyMyJ9.Lupz541xRvrktwkrl55p5pPCF77T69ZRsohsIcb2dxc",
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
e.Router.GET("/my/test", func(e *core.RequestEvent) error {
|
||||
return e.String(200, "test123")
|
||||
}).Bind(apis.RequireAuth())
|
||||
},
|
||||
ExpectedStatus: 401,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "valid record auth token with no collection restrictions",
|
||||
Method: http.MethodGet,
|
||||
URL: "/my/test",
|
||||
Headers: map[string]string{
|
||||
// regular user
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
e.Router.GET("/my/test", func(e *core.RequestEvent) error {
|
||||
return e.String(200, "test123")
|
||||
}).Bind(apis.RequireAuth())
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{"test123"},
|
||||
},
|
||||
{
|
||||
Name: "valid record static auth token",
|
||||
Method: http.MethodGet,
|
||||
URL: "/my/test",
|
||||
Headers: map[string]string{
|
||||
// regular user
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6ZmFsc2V9.4IsO6YMsR19crhwl_YWzvRH8pfq2Ri4Gv2dzGyneLak",
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
e.Router.GET("/my/test", func(e *core.RequestEvent) error {
|
||||
return e.String(200, "test123")
|
||||
}).Bind(apis.RequireAuth())
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{"test123"},
|
||||
},
|
||||
{
|
||||
Name: "valid record auth token with collection not in the restricted list",
|
||||
Method: http.MethodGet,
|
||||
URL: "/my/test",
|
||||
Headers: map[string]string{
|
||||
// superuser
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
e.Router.GET("/my/test", func(e *core.RequestEvent) error {
|
||||
return e.String(200, "test123")
|
||||
}).Bind(apis.RequireAuth("users", "demo1"))
|
||||
},
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "valid record auth token with collection in the restricted list",
|
||||
Method: http.MethodGet,
|
||||
URL: "/my/test",
|
||||
Headers: map[string]string{
|
||||
// superuser
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
e.Router.GET("/my/test", func(e *core.RequestEvent) error {
|
||||
return e.String(200, "test123")
|
||||
}).Bind(apis.RequireAuth("users", core.CollectionNameSuperusers))
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{"test123"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequireSuperuserAuth(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "guest",
|
||||
Method: http.MethodGet,
|
||||
URL: "/my/test",
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
e.Router.GET("/my/test", func(e *core.RequestEvent) error {
|
||||
return e.String(200, "test123")
|
||||
}).Bind(apis.RequireSuperuserAuth())
|
||||
},
|
||||
ExpectedStatus: 401,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "expired/invalid token",
|
||||
Method: http.MethodGet,
|
||||
URL: "/my/test",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjE2NDA5OTE2NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.0pDcBPGDpL2Khh76ivlRi7ugiLBSYvasct3qpHV3rfs",
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
e.Router.GET("/my/test", func(e *core.RequestEvent) error {
|
||||
return e.String(200, "test123")
|
||||
}).Bind(apis.RequireSuperuserAuth())
|
||||
},
|
||||
ExpectedStatus: 401,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "valid regular user auth token",
|
||||
Method: http.MethodGet,
|
||||
URL: "/my/test",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
e.Router.GET("/my/test", func(e *core.RequestEvent) error {
|
||||
return e.String(200, "test123")
|
||||
}).Bind(apis.RequireSuperuserAuth())
|
||||
},
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "valid superuser auth token",
|
||||
Method: http.MethodGet,
|
||||
URL: "/my/test",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
e.Router.GET("/my/test", func(e *core.RequestEvent) error {
|
||||
return e.String(200, "test123")
|
||||
}).Bind(apis.RequireSuperuserAuth())
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{"test123"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequireSuperuserOrOwnerAuth(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "guest",
|
||||
Method: http.MethodGet,
|
||||
URL: "/my/test/4q1xlclmfloku33",
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
e.Router.GET("/my/test/{id}", func(e *core.RequestEvent) error {
|
||||
return e.String(200, "test123")
|
||||
}).Bind(apis.RequireSuperuserOrOwnerAuth(""))
|
||||
},
|
||||
ExpectedStatus: 401,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "expired/invalid token",
|
||||
Method: http.MethodGet,
|
||||
URL: "/my/test/4q1xlclmfloku33",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjE2NDA5OTE2NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.0pDcBPGDpL2Khh76ivlRi7ugiLBSYvasct3qpHV3rfs",
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
e.Router.GET("/my/test/{id}", func(e *core.RequestEvent) error {
|
||||
return e.String(200, "test123")
|
||||
}).Bind(apis.RequireSuperuserOrOwnerAuth(""))
|
||||
},
|
||||
ExpectedStatus: 401,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "valid record auth token (different user)",
|
||||
Method: http.MethodGet,
|
||||
URL: "/my/test/oap640cot4yru2s",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
e.Router.GET("/my/test/{id}", func(e *core.RequestEvent) error {
|
||||
return e.String(200, "test123")
|
||||
}).Bind(apis.RequireSuperuserOrOwnerAuth(""))
|
||||
},
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "valid record auth token (owner)",
|
||||
Method: http.MethodGet,
|
||||
URL: "/my/test/4q1xlclmfloku33",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
e.Router.GET("/my/test/{id}", func(e *core.RequestEvent) error {
|
||||
return e.String(200, "test123")
|
||||
}).Bind(apis.RequireSuperuserOrOwnerAuth(""))
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{"test123"},
|
||||
},
|
||||
{
|
||||
Name: "valid record auth token (owner + non-matching custom owner param)",
|
||||
Method: http.MethodGet,
|
||||
URL: "/my/test/4q1xlclmfloku33",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
e.Router.GET("/my/test/{id}", func(e *core.RequestEvent) error {
|
||||
return e.String(200, "test123")
|
||||
}).Bind(apis.RequireSuperuserOrOwnerAuth("test"))
|
||||
},
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "valid record auth token (owner + matching custom owner param)",
|
||||
Method: http.MethodGet,
|
||||
URL: "/my/test/4q1xlclmfloku33",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
e.Router.GET("/my/test/{test}", func(e *core.RequestEvent) error {
|
||||
return e.String(200, "test123")
|
||||
}).Bind(apis.RequireSuperuserOrOwnerAuth("test"))
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{"test123"},
|
||||
},
|
||||
{
|
||||
Name: "valid superuser auth token",
|
||||
Method: http.MethodGet,
|
||||
URL: "/my/test/4q1xlclmfloku33",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
e.Router.GET("/my/test/{id}", func(e *core.RequestEvent) error {
|
||||
return e.String(200, "test123")
|
||||
}).Bind(apis.RequireSuperuserOrOwnerAuth(""))
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{"test123"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequireSameCollectionContextAuth(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "guest",
|
||||
Method: http.MethodGet,
|
||||
URL: "/my/test/_pb_users_auth_",
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
e.Router.GET("/my/test/{collection}", func(e *core.RequestEvent) error {
|
||||
return e.String(200, "test123")
|
||||
}).Bind(apis.RequireSameCollectionContextAuth(""))
|
||||
},
|
||||
ExpectedStatus: 401,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "expired/invalid token",
|
||||
Method: http.MethodGet,
|
||||
URL: "/my/test/_pb_users_auth_",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoxNjQwOTkxNjYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.2D3tmqPn3vc5LoqqCz8V-iCDVXo9soYiH0d32G7FQT4",
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
e.Router.GET("/my/test/{collection}", func(e *core.RequestEvent) error {
|
||||
return e.String(200, "test123")
|
||||
}).Bind(apis.RequireSameCollectionContextAuth(""))
|
||||
},
|
||||
ExpectedStatus: 401,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "valid record auth token (different collection)",
|
||||
Method: http.MethodGet,
|
||||
URL: "/my/test/clients",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
e.Router.GET("/my/test/{collection}", func(e *core.RequestEvent) error {
|
||||
return e.String(200, "test123")
|
||||
}).Bind(apis.RequireSameCollectionContextAuth(""))
|
||||
},
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "valid record auth token (same collection)",
|
||||
Method: http.MethodGet,
|
||||
URL: "/my/test/_pb_users_auth_",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
e.Router.GET("/my/test/{collection}", func(e *core.RequestEvent) error {
|
||||
return e.String(200, "test123")
|
||||
}).Bind(apis.RequireSameCollectionContextAuth(""))
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{"test123"},
|
||||
},
|
||||
{
|
||||
Name: "valid record auth token (non-matching/missing collection param)",
|
||||
Method: http.MethodGet,
|
||||
URL: "/my/test/_pb_users_auth_",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
e.Router.GET("/my/test/{id}", func(e *core.RequestEvent) error {
|
||||
return e.String(200, "test123")
|
||||
}).Bind(apis.RequireSuperuserOrOwnerAuth(""))
|
||||
},
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "valid record auth token (matching custom collection param)",
|
||||
Method: http.MethodGet,
|
||||
URL: "/my/test/_pb_users_auth_",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
e.Router.GET("/my/test/{test}", func(e *core.RequestEvent) error {
|
||||
return e.String(200, "test123")
|
||||
}).Bind(apis.RequireSuperuserOrOwnerAuth("test"))
|
||||
},
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "superuser no exception check",
|
||||
Method: http.MethodGet,
|
||||
URL: "/my/test/_pb_users_auth_",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
e.Router.GET("/my/test/{collection}", func(e *core.RequestEvent) error {
|
||||
return e.String(200, "test123")
|
||||
}).Bind(apis.RequireSameCollectionContextAuth(""))
|
||||
},
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
777
apis/realtime.go
Normal file
777
apis/realtime.go
Normal file
|
@ -0,0 +1,777 @@
|
|||
package apis
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
validation "github.com/go-ozzo/ozzo-validation/v4"
|
||||
"github.com/pocketbase/dbx"
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tools/hook"
|
||||
"github.com/pocketbase/pocketbase/tools/picker"
|
||||
"github.com/pocketbase/pocketbase/tools/router"
|
||||
"github.com/pocketbase/pocketbase/tools/routine"
|
||||
"github.com/pocketbase/pocketbase/tools/search"
|
||||
"github.com/pocketbase/pocketbase/tools/subscriptions"
|
||||
"golang.org/x/sync/errgroup"
|
||||
)
|
||||
|
||||
// note: the chunk size is arbitrary chosen and may change in the future
|
||||
const clientsChunkSize = 150
|
||||
|
||||
// RealtimeClientAuthKey is the name of the realtime client store key that holds its auth state.
|
||||
const RealtimeClientAuthKey = "auth"
|
||||
|
||||
// bindRealtimeApi registers the realtime api endpoints.
|
||||
func bindRealtimeApi(app core.App, rg *router.RouterGroup[*core.RequestEvent]) {
|
||||
sub := rg.Group("/realtime")
|
||||
sub.GET("", realtimeConnect).Bind(SkipSuccessActivityLog())
|
||||
sub.POST("", realtimeSetSubscriptions)
|
||||
|
||||
bindRealtimeEvents(app)
|
||||
}
|
||||
|
||||
func realtimeConnect(e *core.RequestEvent) error {
|
||||
// disable global write deadline for the SSE connection
|
||||
rc := http.NewResponseController(e.Response)
|
||||
writeDeadlineErr := rc.SetWriteDeadline(time.Time{})
|
||||
if writeDeadlineErr != nil {
|
||||
if !errors.Is(writeDeadlineErr, http.ErrNotSupported) {
|
||||
return e.InternalServerError("Failed to initialize SSE connection.", writeDeadlineErr)
|
||||
}
|
||||
|
||||
// only log since there are valid cases where it may not be implement (e.g. httptest.ResponseRecorder)
|
||||
e.App.Logger().Warn("SetWriteDeadline is not supported, fallback to the default server WriteTimeout")
|
||||
}
|
||||
|
||||
// create cancellable request
|
||||
cancelCtx, cancelRequest := context.WithCancel(e.Request.Context())
|
||||
defer cancelRequest()
|
||||
e.Request = e.Request.Clone(cancelCtx)
|
||||
|
||||
e.Response.Header().Set("Content-Type", "text/event-stream")
|
||||
e.Response.Header().Set("Cache-Control", "no-store")
|
||||
// https://github.com/pocketbase/pocketbase/discussions/480#discussioncomment-3657640
|
||||
// https://nginx.org/en/docs/http/ngx_http_proxy_module.html#proxy_buffering
|
||||
e.Response.Header().Set("X-Accel-Buffering", "no")
|
||||
|
||||
connectEvent := new(core.RealtimeConnectRequestEvent)
|
||||
connectEvent.RequestEvent = e
|
||||
connectEvent.Client = subscriptions.NewDefaultClient()
|
||||
connectEvent.IdleTimeout = 5 * time.Minute
|
||||
|
||||
return e.App.OnRealtimeConnectRequest().Trigger(connectEvent, func(ce *core.RealtimeConnectRequestEvent) error {
|
||||
// register new subscription client
|
||||
ce.App.SubscriptionsBroker().Register(ce.Client)
|
||||
defer func() {
|
||||
e.App.SubscriptionsBroker().Unregister(ce.Client.Id())
|
||||
}()
|
||||
|
||||
ce.App.Logger().Debug("Realtime connection established.", slog.String("clientId", ce.Client.Id()))
|
||||
|
||||
// signalize established connection (aka. fire "connect" message)
|
||||
connectMsgEvent := new(core.RealtimeMessageEvent)
|
||||
connectMsgEvent.RequestEvent = ce.RequestEvent
|
||||
connectMsgEvent.Client = ce.Client
|
||||
connectMsgEvent.Message = &subscriptions.Message{
|
||||
Name: "PB_CONNECT",
|
||||
Data: []byte(`{"clientId":"` + ce.Client.Id() + `"}`),
|
||||
}
|
||||
connectMsgErr := ce.App.OnRealtimeMessageSend().Trigger(connectMsgEvent, func(me *core.RealtimeMessageEvent) error {
|
||||
err := me.Message.WriteSSE(me.Response, me.Client.Id())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return me.Flush()
|
||||
})
|
||||
if connectMsgErr != nil {
|
||||
ce.App.Logger().Debug(
|
||||
"Realtime connection closed (failed to deliver PB_CONNECT)",
|
||||
slog.String("clientId", ce.Client.Id()),
|
||||
slog.String("error", connectMsgErr.Error()),
|
||||
)
|
||||
return nil
|
||||
}
|
||||
|
||||
// start an idle timer to keep track of inactive/forgotten connections
|
||||
idleTimer := time.NewTimer(ce.IdleTimeout)
|
||||
defer idleTimer.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-idleTimer.C:
|
||||
cancelRequest()
|
||||
case msg, ok := <-ce.Client.Channel():
|
||||
if !ok {
|
||||
// channel is closed
|
||||
ce.App.Logger().Debug(
|
||||
"Realtime connection closed (closed channel)",
|
||||
slog.String("clientId", ce.Client.Id()),
|
||||
)
|
||||
return nil
|
||||
}
|
||||
|
||||
msgEvent := new(core.RealtimeMessageEvent)
|
||||
msgEvent.RequestEvent = ce.RequestEvent
|
||||
msgEvent.Client = ce.Client
|
||||
msgEvent.Message = &msg
|
||||
msgErr := ce.App.OnRealtimeMessageSend().Trigger(msgEvent, func(me *core.RealtimeMessageEvent) error {
|
||||
err := me.Message.WriteSSE(me.Response, me.Client.Id())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return me.Flush()
|
||||
})
|
||||
if msgErr != nil {
|
||||
ce.App.Logger().Debug(
|
||||
"Realtime connection closed (failed to deliver message)",
|
||||
slog.String("clientId", ce.Client.Id()),
|
||||
slog.String("error", msgErr.Error()),
|
||||
)
|
||||
return nil
|
||||
}
|
||||
|
||||
idleTimer.Stop()
|
||||
idleTimer.Reset(ce.IdleTimeout)
|
||||
case <-ce.Request.Context().Done():
|
||||
// connection is closed
|
||||
ce.App.Logger().Debug(
|
||||
"Realtime connection closed (cancelled request)",
|
||||
slog.String("clientId", ce.Client.Id()),
|
||||
)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
type realtimeSubscribeForm struct {
|
||||
ClientId string `form:"clientId" json:"clientId"`
|
||||
Subscriptions []string `form:"subscriptions" json:"subscriptions"`
|
||||
}
|
||||
|
||||
func (form *realtimeSubscribeForm) validate() error {
|
||||
return validation.ValidateStruct(form,
|
||||
validation.Field(&form.ClientId, validation.Required, validation.Length(1, 255)),
|
||||
validation.Field(&form.Subscriptions,
|
||||
validation.Length(0, 1000),
|
||||
validation.Each(validation.Length(0, 2500)),
|
||||
),
|
||||
)
|
||||
}
|
||||
|
||||
// note: in case of reconnect, clients will have to resubmit all subscriptions again
|
||||
func realtimeSetSubscriptions(e *core.RequestEvent) error {
|
||||
form := new(realtimeSubscribeForm)
|
||||
|
||||
err := e.BindBody(form)
|
||||
if err != nil {
|
||||
return e.BadRequestError("", err)
|
||||
}
|
||||
|
||||
err = form.validate()
|
||||
if err != nil {
|
||||
return e.BadRequestError("", err)
|
||||
}
|
||||
|
||||
// find subscription client
|
||||
client, err := e.App.SubscriptionsBroker().ClientById(form.ClientId)
|
||||
if err != nil {
|
||||
return e.NotFoundError("Missing or invalid client id.", err)
|
||||
}
|
||||
|
||||
// for now allow only guest->auth upgrades and any other auth change is forbidden
|
||||
clientAuth, _ := client.Get(RealtimeClientAuthKey).(*core.Record)
|
||||
if clientAuth != nil && !isSameAuth(clientAuth, e.Auth) {
|
||||
return e.ForbiddenError("The current and the previous request authorization don't match.", nil)
|
||||
}
|
||||
|
||||
event := new(core.RealtimeSubscribeRequestEvent)
|
||||
event.RequestEvent = e
|
||||
event.Client = client
|
||||
event.Subscriptions = form.Subscriptions
|
||||
|
||||
return e.App.OnRealtimeSubscribeRequest().Trigger(event, func(e *core.RealtimeSubscribeRequestEvent) error {
|
||||
// update auth state
|
||||
e.Client.Set(RealtimeClientAuthKey, e.Auth)
|
||||
|
||||
// unsubscribe from any previous existing subscriptions
|
||||
e.Client.Unsubscribe()
|
||||
|
||||
// subscribe to the new subscriptions
|
||||
e.Client.Subscribe(e.Subscriptions...)
|
||||
|
||||
e.App.Logger().Debug(
|
||||
"Realtime subscriptions updated.",
|
||||
slog.String("clientId", e.Client.Id()),
|
||||
slog.Any("subscriptions", e.Subscriptions),
|
||||
)
|
||||
|
||||
return execAfterSuccessTx(true, e.App, func() error {
|
||||
return e.NoContent(http.StatusNoContent)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// updateClientsAuth updates the existing clients auth record with the new one (matched by ID).
|
||||
func realtimeUpdateClientsAuth(app core.App, newAuthRecord *core.Record) error {
|
||||
chunks := app.SubscriptionsBroker().ChunkedClients(clientsChunkSize)
|
||||
|
||||
group := new(errgroup.Group)
|
||||
|
||||
for _, chunk := range chunks {
|
||||
group.Go(func() error {
|
||||
for _, client := range chunk {
|
||||
clientAuth, _ := client.Get(RealtimeClientAuthKey).(*core.Record)
|
||||
if clientAuth != nil &&
|
||||
clientAuth.Id == newAuthRecord.Id &&
|
||||
clientAuth.Collection().Name == newAuthRecord.Collection().Name {
|
||||
client.Set(RealtimeClientAuthKey, newAuthRecord)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
return group.Wait()
|
||||
}
|
||||
|
||||
// realtimeUnsetClientsAuthState unsets the auth state of all clients that have the provided auth model.
|
||||
func realtimeUnsetClientsAuthState(app core.App, authModel core.Model) error {
|
||||
chunks := app.SubscriptionsBroker().ChunkedClients(clientsChunkSize)
|
||||
|
||||
group := new(errgroup.Group)
|
||||
|
||||
for _, chunk := range chunks {
|
||||
group.Go(func() error {
|
||||
for _, client := range chunk {
|
||||
clientAuth, _ := client.Get(RealtimeClientAuthKey).(*core.Record)
|
||||
if clientAuth != nil &&
|
||||
clientAuth.Id == authModel.PK() &&
|
||||
clientAuth.Collection().Name == authModel.TableName() {
|
||||
client.Unset(RealtimeClientAuthKey)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
return group.Wait()
|
||||
}
|
||||
|
||||
func bindRealtimeEvents(app core.App) {
|
||||
// update the clients that has auth record association
|
||||
app.OnModelAfterUpdateSuccess().Bind(&hook.Handler[*core.ModelEvent]{
|
||||
Func: func(e *core.ModelEvent) error {
|
||||
authRecord := realtimeResolveRecord(e.App, e.Model, core.CollectionTypeAuth)
|
||||
if authRecord != nil {
|
||||
if err := realtimeUpdateClientsAuth(e.App, authRecord); err != nil {
|
||||
app.Logger().Warn(
|
||||
"Failed to update client(s) associated to the updated auth record",
|
||||
slog.Any("id", authRecord.Id),
|
||||
slog.String("collectionName", authRecord.Collection().Name),
|
||||
slog.String("error", err.Error()),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
return e.Next()
|
||||
},
|
||||
Priority: -99,
|
||||
})
|
||||
|
||||
// remove the client(s) associated to the deleted auth model
|
||||
// (note: works also with custom model for backward compatibility)
|
||||
app.OnModelAfterDeleteSuccess().Bind(&hook.Handler[*core.ModelEvent]{
|
||||
Func: func(e *core.ModelEvent) error {
|
||||
collection := realtimeResolveRecordCollection(e.App, e.Model)
|
||||
if collection != nil && collection.IsAuth() {
|
||||
if err := realtimeUnsetClientsAuthState(e.App, e.Model); err != nil {
|
||||
app.Logger().Warn(
|
||||
"Failed to remove client(s) associated to the deleted auth model",
|
||||
slog.Any("id", e.Model.PK()),
|
||||
slog.String("collectionName", e.Model.TableName()),
|
||||
slog.String("error", err.Error()),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
return e.Next()
|
||||
},
|
||||
Priority: -99,
|
||||
})
|
||||
|
||||
app.OnModelAfterCreateSuccess().Bind(&hook.Handler[*core.ModelEvent]{
|
||||
Func: func(e *core.ModelEvent) error {
|
||||
record := realtimeResolveRecord(e.App, e.Model, "")
|
||||
if record != nil {
|
||||
err := realtimeBroadcastRecord(e.App, "create", record, false)
|
||||
if err != nil {
|
||||
app.Logger().Debug(
|
||||
"Failed to broadcast record create",
|
||||
slog.String("id", record.Id),
|
||||
slog.String("collectionName", record.Collection().Name),
|
||||
slog.String("error", err.Error()),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
return e.Next()
|
||||
},
|
||||
Priority: -99,
|
||||
})
|
||||
|
||||
app.OnModelAfterUpdateSuccess().Bind(&hook.Handler[*core.ModelEvent]{
|
||||
Func: func(e *core.ModelEvent) error {
|
||||
record := realtimeResolveRecord(e.App, e.Model, "")
|
||||
if record != nil {
|
||||
err := realtimeBroadcastRecord(e.App, "update", record, false)
|
||||
if err != nil {
|
||||
app.Logger().Debug(
|
||||
"Failed to broadcast record update",
|
||||
slog.String("id", record.Id),
|
||||
slog.String("collectionName", record.Collection().Name),
|
||||
slog.String("error", err.Error()),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
return e.Next()
|
||||
},
|
||||
Priority: -99,
|
||||
})
|
||||
|
||||
// delete: dry cache
|
||||
app.OnModelDelete().Bind(&hook.Handler[*core.ModelEvent]{
|
||||
Func: func(e *core.ModelEvent) error {
|
||||
record := realtimeResolveRecord(e.App, e.Model, "")
|
||||
if record != nil {
|
||||
// note: use the outside scoped app instance for the access checks so that the API rules
|
||||
// are performed out of the delete transaction ensuring that they would still work even if
|
||||
// a cascade-deleted record's API rule relies on an already deleted parent record
|
||||
err := realtimeBroadcastRecord(e.App, "delete", record, true, app)
|
||||
if err != nil {
|
||||
app.Logger().Debug(
|
||||
"Failed to dry cache record delete",
|
||||
slog.String("id", record.Id),
|
||||
slog.String("collectionName", record.Collection().Name),
|
||||
slog.String("error", err.Error()),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
return e.Next()
|
||||
},
|
||||
Priority: 99, // execute as later as possible
|
||||
})
|
||||
|
||||
// delete: broadcast
|
||||
app.OnModelAfterDeleteSuccess().Bind(&hook.Handler[*core.ModelEvent]{
|
||||
Func: func(e *core.ModelEvent) error {
|
||||
// note: only ensure that it is a collection record
|
||||
// and don't use realtimeResolveRecord because in case of a
|
||||
// custom model it'll fail to resolve since the record is already deleted
|
||||
collection := realtimeResolveRecordCollection(e.App, e.Model)
|
||||
if collection != nil {
|
||||
err := realtimeBroadcastDryCacheKey(e.App, getDryCacheKey("delete", e.Model))
|
||||
if err != nil {
|
||||
app.Logger().Debug(
|
||||
"Failed to broadcast record delete",
|
||||
slog.Any("id", e.Model.PK()),
|
||||
slog.String("collectionName", collection.Name),
|
||||
slog.String("error", err.Error()),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
return e.Next()
|
||||
},
|
||||
Priority: -99,
|
||||
})
|
||||
|
||||
// delete: failure
|
||||
app.OnModelAfterDeleteError().Bind(&hook.Handler[*core.ModelErrorEvent]{
|
||||
Func: func(e *core.ModelErrorEvent) error {
|
||||
record := realtimeResolveRecord(e.App, e.Model, "")
|
||||
if record != nil {
|
||||
err := realtimeUnsetDryCacheKey(e.App, getDryCacheKey("delete", record))
|
||||
if err != nil {
|
||||
app.Logger().Debug(
|
||||
"Failed to cleanup after broadcast record delete failure",
|
||||
slog.String("id", record.Id),
|
||||
slog.String("collectionName", record.Collection().Name),
|
||||
slog.String("error", err.Error()),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
return e.Next()
|
||||
},
|
||||
Priority: -99,
|
||||
})
|
||||
}
|
||||
|
||||
// resolveRecord converts *if possible* the provided model interface to a Record.
|
||||
// This is usually helpful if the provided model is a custom Record model struct.
|
||||
func realtimeResolveRecord(app core.App, model core.Model, optCollectionType string) *core.Record {
|
||||
var record *core.Record
|
||||
switch m := model.(type) {
|
||||
case *core.Record:
|
||||
record = m
|
||||
case core.RecordProxy:
|
||||
record = m.ProxyRecord()
|
||||
}
|
||||
|
||||
if record != nil {
|
||||
if optCollectionType == "" || record.Collection().Type == optCollectionType {
|
||||
return record
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
tblName := model.TableName()
|
||||
|
||||
// skip Log model checks
|
||||
if tblName == core.LogsTableName {
|
||||
return nil
|
||||
}
|
||||
|
||||
// check if it is custom Record model struct
|
||||
collection, _ := app.FindCachedCollectionByNameOrId(tblName)
|
||||
if collection != nil && (optCollectionType == "" || collection.Type == optCollectionType) {
|
||||
if id, ok := model.PK().(string); ok {
|
||||
record, _ = app.FindRecordById(collection, id)
|
||||
}
|
||||
}
|
||||
|
||||
return record
|
||||
}
|
||||
|
||||
// realtimeResolveRecordCollection extracts *if possible* the Collection model from the provided model interface.
|
||||
// This is usually helpful if the provided model is a custom Record model struct.
|
||||
func realtimeResolveRecordCollection(app core.App, model core.Model) (collection *core.Collection) {
|
||||
switch m := model.(type) {
|
||||
case *core.Record:
|
||||
return m.Collection()
|
||||
case core.RecordProxy:
|
||||
return m.ProxyRecord().Collection()
|
||||
default:
|
||||
// check if it is custom Record model struct
|
||||
collection, err := app.FindCachedCollectionByNameOrId(model.TableName())
|
||||
if err == nil {
|
||||
return collection
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// recordData represents the broadcasted record subscrition message data.
|
||||
type recordData struct {
|
||||
Record any `json:"record"` /* map or core.Record */
|
||||
Action string `json:"action"`
|
||||
}
|
||||
|
||||
// Note: the optAccessCheckApp is there in case you want the access check
|
||||
// to be performed against different db app context (e.g. out of a transaction).
|
||||
// If set, it is expected that optAccessCheckApp instance is used for read-only operations to avoid deadlocks.
|
||||
// If not set, it fallbacks to app.
|
||||
func realtimeBroadcastRecord(app core.App, action string, record *core.Record, dryCache bool, optAccessCheckApp ...core.App) error {
|
||||
collection := record.Collection()
|
||||
if collection == nil {
|
||||
return errors.New("[broadcastRecord] Record collection not set")
|
||||
}
|
||||
|
||||
chunks := app.SubscriptionsBroker().ChunkedClients(clientsChunkSize)
|
||||
if len(chunks) == 0 {
|
||||
return nil // no subscribers
|
||||
}
|
||||
|
||||
subscriptionRuleMap := map[string]*string{
|
||||
(collection.Name + "/" + record.Id + "?"): collection.ViewRule,
|
||||
(collection.Id + "/" + record.Id + "?"): collection.ViewRule,
|
||||
(collection.Name + "/*?"): collection.ListRule,
|
||||
(collection.Id + "/*?"): collection.ListRule,
|
||||
|
||||
// @deprecated: the same as the wildcard topic but kept for backward compatibility
|
||||
(collection.Name + "?"): collection.ListRule,
|
||||
(collection.Id + "?"): collection.ListRule,
|
||||
}
|
||||
|
||||
dryCacheKey := getDryCacheKey(action, record)
|
||||
|
||||
group := new(errgroup.Group)
|
||||
|
||||
accessCheckApp := app
|
||||
if len(optAccessCheckApp) > 0 {
|
||||
accessCheckApp = optAccessCheckApp[0]
|
||||
}
|
||||
|
||||
for _, chunk := range chunks {
|
||||
group.Go(func() error {
|
||||
var clientAuth *core.Record
|
||||
|
||||
for _, client := range chunk {
|
||||
// note: not executed concurrently to avoid races and to ensure
|
||||
// that the access checks are applied for the current record db state
|
||||
for prefix, rule := range subscriptionRuleMap {
|
||||
subs := client.Subscriptions(prefix)
|
||||
if len(subs) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
clientAuth, _ = client.Get(RealtimeClientAuthKey).(*core.Record)
|
||||
|
||||
for sub, options := range subs {
|
||||
// mock request data
|
||||
requestInfo := &core.RequestInfo{
|
||||
Context: core.RequestInfoContextRealtime,
|
||||
Method: "GET",
|
||||
Query: options.Query,
|
||||
Headers: options.Headers,
|
||||
Auth: clientAuth,
|
||||
}
|
||||
|
||||
if !realtimeCanAccessRecord(accessCheckApp, record, requestInfo, rule) {
|
||||
continue
|
||||
}
|
||||
|
||||
// create a clean record copy without expand and unknown fields because we don't know yet
|
||||
// which exact fields the client subscription requested or has permissions to access
|
||||
cleanRecord := record.Fresh()
|
||||
|
||||
// trigger the enrich hooks
|
||||
enrichErr := triggerRecordEnrichHooks(app, requestInfo, []*core.Record{cleanRecord}, func() error {
|
||||
// apply expand
|
||||
rawExpand := options.Query[expandQueryParam]
|
||||
if rawExpand != "" {
|
||||
expandErrs := app.ExpandRecord(cleanRecord, strings.Split(rawExpand, ","), expandFetch(app, requestInfo))
|
||||
if len(expandErrs) > 0 {
|
||||
app.Logger().Debug(
|
||||
"[broadcastRecord] expand errors",
|
||||
slog.String("id", cleanRecord.Id),
|
||||
slog.String("collectionName", cleanRecord.Collection().Name),
|
||||
slog.String("sub", sub),
|
||||
slog.String("expand", rawExpand),
|
||||
slog.Any("errors", expandErrs),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// ignore the auth record email visibility checks
|
||||
// for auth owner, superuser or manager
|
||||
if collection.IsAuth() {
|
||||
if isSameAuth(clientAuth, cleanRecord) ||
|
||||
realtimeCanAccessRecord(accessCheckApp, cleanRecord, requestInfo, collection.ManageRule) {
|
||||
cleanRecord.IgnoreEmailVisibility(true)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
if enrichErr != nil {
|
||||
app.Logger().Debug(
|
||||
"[broadcastRecord] record enrich error",
|
||||
slog.String("id", cleanRecord.Id),
|
||||
slog.String("collectionName", cleanRecord.Collection().Name),
|
||||
slog.String("sub", sub),
|
||||
slog.Any("error", enrichErr),
|
||||
)
|
||||
continue
|
||||
}
|
||||
|
||||
data := &recordData{
|
||||
Action: action,
|
||||
Record: cleanRecord,
|
||||
}
|
||||
|
||||
// check fields
|
||||
rawFields := options.Query[fieldsQueryParam]
|
||||
if rawFields != "" {
|
||||
decoded, err := picker.Pick(cleanRecord, rawFields)
|
||||
if err == nil {
|
||||
data.Record = decoded
|
||||
} else {
|
||||
app.Logger().Debug(
|
||||
"[broadcastRecord] pick fields error",
|
||||
slog.String("id", cleanRecord.Id),
|
||||
slog.String("collectionName", cleanRecord.Collection().Name),
|
||||
slog.String("sub", sub),
|
||||
slog.String("fields", rawFields),
|
||||
slog.String("error", err.Error()),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
dataBytes, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
app.Logger().Debug(
|
||||
"[broadcastRecord] data marshal error",
|
||||
slog.String("id", cleanRecord.Id),
|
||||
slog.String("collectionName", cleanRecord.Collection().Name),
|
||||
slog.String("error", err.Error()),
|
||||
)
|
||||
continue
|
||||
}
|
||||
|
||||
msg := subscriptions.Message{
|
||||
Name: sub,
|
||||
Data: dataBytes,
|
||||
}
|
||||
|
||||
if dryCache {
|
||||
messages, ok := client.Get(dryCacheKey).([]subscriptions.Message)
|
||||
if !ok {
|
||||
messages = []subscriptions.Message{msg}
|
||||
} else {
|
||||
messages = append(messages, msg)
|
||||
}
|
||||
client.Set(dryCacheKey, messages)
|
||||
} else {
|
||||
routine.FireAndForget(func() {
|
||||
client.Send(msg)
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
return group.Wait()
|
||||
}
|
||||
|
||||
// realtimeBroadcastDryCacheKey broadcasts the dry cached key related messages.
|
||||
func realtimeBroadcastDryCacheKey(app core.App, key string) error {
|
||||
chunks := app.SubscriptionsBroker().ChunkedClients(clientsChunkSize)
|
||||
if len(chunks) == 0 {
|
||||
return nil // no subscribers
|
||||
}
|
||||
|
||||
group := new(errgroup.Group)
|
||||
|
||||
for _, chunk := range chunks {
|
||||
group.Go(func() error {
|
||||
for _, client := range chunk {
|
||||
messages, ok := client.Get(key).([]subscriptions.Message)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
client.Unset(key)
|
||||
|
||||
client := client
|
||||
|
||||
routine.FireAndForget(func() {
|
||||
for _, msg := range messages {
|
||||
client.Send(msg)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
return group.Wait()
|
||||
}
|
||||
|
||||
// realtimeUnsetDryCacheKey removes the dry cached key related messages.
|
||||
func realtimeUnsetDryCacheKey(app core.App, key string) error {
|
||||
chunks := app.SubscriptionsBroker().ChunkedClients(clientsChunkSize)
|
||||
if len(chunks) == 0 {
|
||||
return nil // no subscribers
|
||||
}
|
||||
|
||||
group := new(errgroup.Group)
|
||||
|
||||
for _, chunk := range chunks {
|
||||
group.Go(func() error {
|
||||
for _, client := range chunk {
|
||||
if client.Get(key) != nil {
|
||||
client.Unset(key)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
return group.Wait()
|
||||
}
|
||||
|
||||
func getDryCacheKey(action string, model core.Model) string {
|
||||
pkStr, ok := model.PK().(string)
|
||||
if !ok {
|
||||
pkStr = fmt.Sprintf("%v", model.PK())
|
||||
}
|
||||
|
||||
return action + "/" + model.TableName() + "/" + pkStr
|
||||
}
|
||||
|
||||
func isSameAuth(authA, authB *core.Record) bool {
|
||||
if authA == nil {
|
||||
return authB == nil
|
||||
}
|
||||
|
||||
if authB == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
return authA.Id == authB.Id && authA.Collection().Id == authB.Collection().Id
|
||||
}
|
||||
|
||||
// realtimeCanAccessRecord checks if the subscription client has access to the specified record model.
|
||||
func realtimeCanAccessRecord(
|
||||
app core.App,
|
||||
record *core.Record,
|
||||
requestInfo *core.RequestInfo,
|
||||
accessRule *string,
|
||||
) bool {
|
||||
// check the access rule
|
||||
// ---
|
||||
if ok, _ := app.CanAccessRecord(record, requestInfo, accessRule); !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
// check the subscription client-side filter (if any)
|
||||
// ---
|
||||
filter := requestInfo.Query[search.FilterQueryParam]
|
||||
if filter == "" {
|
||||
return true // no further checks needed
|
||||
}
|
||||
|
||||
err := checkForSuperuserOnlyRuleFields(requestInfo)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
var exists int
|
||||
|
||||
q := app.ConcurrentDB().Select("(1)").
|
||||
From(record.Collection().Name).
|
||||
AndWhere(dbx.HashExp{record.Collection().Name + ".id": record.Id})
|
||||
|
||||
resolver := core.NewRecordFieldResolver(app, record.Collection(), requestInfo, false)
|
||||
expr, err := search.FilterData(filter).BuildExpr(resolver)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
q.AndWhere(expr)
|
||||
resolver.UpdateQuery(q)
|
||||
|
||||
err = q.Limit(1).Row(&exists)
|
||||
|
||||
return err == nil && exists > 0
|
||||
}
|
885
apis/realtime_test.go
Normal file
885
apis/realtime_test.go
Normal file
|
@ -0,0 +1,885 @@
|
|||
package apis_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/pocketbase/dbx"
|
||||
"github.com/pocketbase/pocketbase/apis"
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tests"
|
||||
"github.com/pocketbase/pocketbase/tools/subscriptions"
|
||||
"github.com/pocketbase/pocketbase/tools/types"
|
||||
)
|
||||
|
||||
func TestRealtimeConnect(t *testing.T) {
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/realtime",
|
||||
Timeout: 100 * time.Millisecond,
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
`id:`,
|
||||
`event:PB_CONNECT`,
|
||||
`data:{"clientId":`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRealtimeConnectRequest": 1,
|
||||
"OnRealtimeMessageSend": 1,
|
||||
},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
if len(app.SubscriptionsBroker().Clients()) != 0 {
|
||||
t.Errorf("Expected the subscribers to be removed after connection close, found %d", len(app.SubscriptionsBroker().Clients()))
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "PB_CONNECT interrupt",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/realtime",
|
||||
Timeout: 100 * time.Millisecond,
|
||||
ExpectedStatus: 200,
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRealtimeConnectRequest": 1,
|
||||
"OnRealtimeMessageSend": 1,
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
app.OnRealtimeMessageSend().BindFunc(func(e *core.RealtimeMessageEvent) error {
|
||||
if e.Message.Name == "PB_CONNECT" {
|
||||
return errors.New("PB_CONNECT error")
|
||||
}
|
||||
return e.Next()
|
||||
})
|
||||
},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
if len(app.SubscriptionsBroker().Clients()) != 0 {
|
||||
t.Errorf("Expected the subscribers to be removed after connection close, found %d", len(app.SubscriptionsBroker().Clients()))
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "Skipping/ignoring messages",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/realtime",
|
||||
Timeout: 100 * time.Millisecond,
|
||||
ExpectedStatus: 200,
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRealtimeConnectRequest": 1,
|
||||
"OnRealtimeMessageSend": 1,
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
app.OnRealtimeMessageSend().BindFunc(func(e *core.RealtimeMessageEvent) error {
|
||||
return nil
|
||||
})
|
||||
},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
if len(app.SubscriptionsBroker().Clients()) != 0 {
|
||||
t.Errorf("Expected the subscribers to be removed after connection close, found %d", len(app.SubscriptionsBroker().Clients()))
|
||||
}
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRealtimeSubscribe(t *testing.T) {
|
||||
client := subscriptions.NewDefaultClient()
|
||||
|
||||
resetClient := func() {
|
||||
client.Unsubscribe()
|
||||
client.Set(apis.RealtimeClientAuthKey, nil)
|
||||
}
|
||||
|
||||
validSubscriptionsLimit := make([]string, 1000)
|
||||
for i := 0; i < len(validSubscriptionsLimit); i++ {
|
||||
validSubscriptionsLimit[i] = fmt.Sprintf(`"%d"`, i)
|
||||
}
|
||||
invalidSubscriptionsLimit := make([]string, 1001)
|
||||
for i := 0; i < len(invalidSubscriptionsLimit); i++ {
|
||||
invalidSubscriptionsLimit[i] = fmt.Sprintf(`"%d"`, i)
|
||||
}
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "missing client",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/realtime",
|
||||
Body: strings.NewReader(`{"clientId":"missing","subscriptions":["test1", "test2"]}`),
|
||||
ExpectedStatus: 404,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "empty data",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/realtime",
|
||||
Body: strings.NewReader(`{}`),
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{
|
||||
`"data":{`,
|
||||
`"clientId":{"code":"validation_required`,
|
||||
},
|
||||
NotExpectedContent: []string{
|
||||
`"subscriptions"`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "existing client with invalid subscriptions limit",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/realtime",
|
||||
Body: strings.NewReader(`{
|
||||
"clientId": "` + client.Id() + `",
|
||||
"subscriptions": [` + strings.Join(invalidSubscriptionsLimit, ",") + `]
|
||||
}`),
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
app.SubscriptionsBroker().Register(client)
|
||||
},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
resetClient()
|
||||
},
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{
|
||||
`"data":{`,
|
||||
`"subscriptions":{"code":"validation_length_too_long"`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "existing client with valid subscriptions limit",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/realtime",
|
||||
Body: strings.NewReader(`{
|
||||
"clientId": "` + client.Id() + `",
|
||||
"subscriptions": [` + strings.Join(validSubscriptionsLimit, ",") + `]
|
||||
}`),
|
||||
ExpectedStatus: 204,
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRealtimeSubscribeRequest": 1,
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
client.Subscribe("test0") // should be replaced
|
||||
app.SubscriptionsBroker().Register(client)
|
||||
},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
if len(client.Subscriptions()) != len(validSubscriptionsLimit) {
|
||||
t.Errorf("Expected %d subscriptions, got %d", len(validSubscriptionsLimit), len(client.Subscriptions()))
|
||||
}
|
||||
if client.HasSubscription("test0") {
|
||||
t.Errorf("Expected old subscriptions to be replaced")
|
||||
}
|
||||
resetClient()
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "existing client with invalid topic length",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/realtime",
|
||||
Body: strings.NewReader(`{
|
||||
"clientId": "` + client.Id() + `",
|
||||
"subscriptions": ["abc", "` + strings.Repeat("a", 2501) + `"]
|
||||
}`),
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
app.SubscriptionsBroker().Register(client)
|
||||
},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
resetClient()
|
||||
},
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{
|
||||
`"data":{`,
|
||||
`"subscriptions":{"1":{"code":"validation_length_too_long"`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "existing client with valid topic length",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/realtime",
|
||||
Body: strings.NewReader(`{
|
||||
"clientId": "` + client.Id() + `",
|
||||
"subscriptions": ["abc", "` + strings.Repeat("a", 2500) + `"]
|
||||
}`),
|
||||
ExpectedStatus: 204,
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRealtimeSubscribeRequest": 1,
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
client.Subscribe("test0")
|
||||
app.SubscriptionsBroker().Register(client)
|
||||
},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
if len(client.Subscriptions()) != 2 {
|
||||
t.Errorf("Expected %d subscriptions, got %d", 2, len(client.Subscriptions()))
|
||||
}
|
||||
if client.HasSubscription("test0") {
|
||||
t.Errorf("Expected old subscriptions to be replaced")
|
||||
}
|
||||
resetClient()
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "existing client - empty subscriptions",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/realtime",
|
||||
Body: strings.NewReader(`{"clientId":"` + client.Id() + `","subscriptions":[]}`),
|
||||
ExpectedStatus: 204,
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRealtimeSubscribeRequest": 1,
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
client.Subscribe("test0")
|
||||
app.SubscriptionsBroker().Register(client)
|
||||
},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
if len(client.Subscriptions()) != 0 {
|
||||
t.Errorf("Expected no subscriptions, got %d", len(client.Subscriptions()))
|
||||
}
|
||||
resetClient()
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "existing client - 2 new subscriptions",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/realtime",
|
||||
Body: strings.NewReader(`{"clientId":"` + client.Id() + `","subscriptions":["test1", "test2"]}`),
|
||||
ExpectedStatus: 204,
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRealtimeSubscribeRequest": 1,
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
client.Subscribe("test0")
|
||||
app.SubscriptionsBroker().Register(client)
|
||||
},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
expectedSubs := []string{"test1", "test2"}
|
||||
if len(expectedSubs) != len(client.Subscriptions()) {
|
||||
t.Errorf("Expected subscriptions %v, got %v", expectedSubs, client.Subscriptions())
|
||||
}
|
||||
|
||||
for _, s := range expectedSubs {
|
||||
if !client.HasSubscription(s) {
|
||||
t.Errorf("Cannot find %q subscription in %v", s, client.Subscriptions())
|
||||
}
|
||||
}
|
||||
resetClient()
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "existing client - guest -> authorized superuser",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/realtime",
|
||||
Body: strings.NewReader(`{"clientId":"` + client.Id() + `","subscriptions":["test1", "test2"]}`),
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
|
||||
},
|
||||
ExpectedStatus: 204,
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRealtimeSubscribeRequest": 1,
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
app.SubscriptionsBroker().Register(client)
|
||||
},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
authRecord, _ := client.Get(apis.RealtimeClientAuthKey).(*core.Record)
|
||||
if authRecord == nil || !authRecord.IsSuperuser() {
|
||||
t.Errorf("Expected superuser auth record, got %v", authRecord)
|
||||
}
|
||||
resetClient()
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "existing client - guest -> authorized regular auth record",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/realtime",
|
||||
Body: strings.NewReader(`{"clientId":"` + client.Id() + `","subscriptions":["test1", "test2"]}`),
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
ExpectedStatus: 204,
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRealtimeSubscribeRequest": 1,
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
app.SubscriptionsBroker().Register(client)
|
||||
},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
authRecord, _ := client.Get(apis.RealtimeClientAuthKey).(*core.Record)
|
||||
if authRecord == nil {
|
||||
t.Errorf("Expected regular user auth record, got %v", authRecord)
|
||||
}
|
||||
resetClient()
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "existing client - same auth",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/realtime",
|
||||
Body: strings.NewReader(`{"clientId":"` + client.Id() + `","subscriptions":["test1", "test2"]}`),
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
ExpectedStatus: 204,
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRealtimeSubscribeRequest": 1,
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
// the same user as the auth token
|
||||
user, err := app.FindAuthRecordByEmail("users", "test@example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
client.Set(apis.RealtimeClientAuthKey, user)
|
||||
|
||||
app.SubscriptionsBroker().Register(client)
|
||||
},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
authRecord, _ := client.Get(apis.RealtimeClientAuthKey).(*core.Record)
|
||||
if authRecord == nil {
|
||||
t.Errorf("Expected auth record model, got nil")
|
||||
}
|
||||
resetClient()
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "existing client - mismatched auth",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/realtime",
|
||||
Body: strings.NewReader(`{"clientId":"` + client.Id() + `","subscriptions":["test1", "test2"]}`),
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
user, err := app.FindAuthRecordByEmail("users", "test2@example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
client.Set(apis.RealtimeClientAuthKey, user)
|
||||
|
||||
app.SubscriptionsBroker().Register(client)
|
||||
},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
authRecord, _ := client.Get(apis.RealtimeClientAuthKey).(*core.Record)
|
||||
if authRecord == nil {
|
||||
t.Errorf("Expected auth record model, got nil")
|
||||
}
|
||||
resetClient()
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "existing client - unauthorized client",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/realtime",
|
||||
Body: strings.NewReader(`{"clientId":"` + client.Id() + `","subscriptions":["test1", "test2"]}`),
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
user, err := app.FindAuthRecordByEmail("users", "test2@example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
client.Set(apis.RealtimeClientAuthKey, user)
|
||||
|
||||
app.SubscriptionsBroker().Register(client)
|
||||
},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
authRecord, _ := client.Get(apis.RealtimeClientAuthKey).(*core.Record)
|
||||
if authRecord == nil {
|
||||
t.Errorf("Expected auth record model, got nil")
|
||||
}
|
||||
resetClient()
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRealtimeAuthRecordDeleteEvent(t *testing.T) {
|
||||
testApp, _ := tests.NewTestApp()
|
||||
defer testApp.Cleanup()
|
||||
|
||||
// init realtime handlers
|
||||
apis.NewRouter(testApp)
|
||||
|
||||
authRecord1, err := testApp.FindAuthRecordByEmail("users", "test@example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
authRecord2, err := testApp.FindAuthRecordByEmail("users", "test2@example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
client1 := subscriptions.NewDefaultClient()
|
||||
client1.Set(apis.RealtimeClientAuthKey, authRecord1)
|
||||
testApp.SubscriptionsBroker().Register(client1)
|
||||
|
||||
client2 := subscriptions.NewDefaultClient()
|
||||
client2.Set(apis.RealtimeClientAuthKey, authRecord1)
|
||||
testApp.SubscriptionsBroker().Register(client2)
|
||||
|
||||
client3 := subscriptions.NewDefaultClient()
|
||||
client3.Set(apis.RealtimeClientAuthKey, authRecord2)
|
||||
testApp.SubscriptionsBroker().Register(client3)
|
||||
|
||||
// mock delete event
|
||||
e := new(core.ModelEvent)
|
||||
e.App = testApp
|
||||
e.Type = core.ModelEventTypeDelete
|
||||
e.Context = context.Background()
|
||||
e.Model = authRecord1
|
||||
|
||||
testApp.OnModelAfterDeleteSuccess().Trigger(e)
|
||||
|
||||
if total := len(testApp.SubscriptionsBroker().Clients()); total != 3 {
|
||||
t.Fatalf("Expected %d subscription clients, found %d", 3, total)
|
||||
}
|
||||
|
||||
if auth := client1.Get(apis.RealtimeClientAuthKey); auth != nil {
|
||||
t.Fatalf("[client1] Expected the auth state to be unset, found %#v", auth)
|
||||
}
|
||||
|
||||
if auth := client2.Get(apis.RealtimeClientAuthKey); auth != nil {
|
||||
t.Fatalf("[client2] Expected the auth state to be unset, found %#v", auth)
|
||||
}
|
||||
|
||||
if auth := client3.Get(apis.RealtimeClientAuthKey); auth == nil || auth.(*core.Record).Id != authRecord2.Id {
|
||||
t.Fatalf("[client3] Expected the auth state to be left unchanged, found %#v", auth)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRealtimeAuthRecordUpdateEvent(t *testing.T) {
|
||||
testApp, _ := tests.NewTestApp()
|
||||
defer testApp.Cleanup()
|
||||
|
||||
// init realtime handlers
|
||||
apis.NewRouter(testApp)
|
||||
|
||||
authRecord1, err := testApp.FindAuthRecordByEmail("users", "test@example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
client := subscriptions.NewDefaultClient()
|
||||
client.Set(apis.RealtimeClientAuthKey, authRecord1)
|
||||
testApp.SubscriptionsBroker().Register(client)
|
||||
|
||||
// refetch the authRecord and change its email
|
||||
authRecord2, err := testApp.FindAuthRecordByEmail("users", "test@example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
authRecord2.SetEmail("new@example.com")
|
||||
|
||||
// mock update event
|
||||
e := new(core.ModelEvent)
|
||||
e.App = testApp
|
||||
e.Type = core.ModelEventTypeUpdate
|
||||
e.Context = context.Background()
|
||||
e.Model = authRecord2
|
||||
|
||||
testApp.OnModelAfterUpdateSuccess().Trigger(e)
|
||||
|
||||
clientAuthRecord, _ := client.Get(apis.RealtimeClientAuthKey).(*core.Record)
|
||||
if clientAuthRecord.Email() != authRecord2.Email() {
|
||||
t.Fatalf("Expected authRecord with email %q, got %q", authRecord2.Email(), clientAuthRecord.Email())
|
||||
}
|
||||
}
|
||||
|
||||
// Custom auth record model struct
|
||||
// -------------------------------------------------------------------
|
||||
var _ core.Model = (*CustomUser)(nil)
|
||||
|
||||
type CustomUser struct {
|
||||
core.BaseModel
|
||||
|
||||
Email string `db:"email" json:"email"`
|
||||
}
|
||||
|
||||
func (m *CustomUser) TableName() string {
|
||||
return "users"
|
||||
}
|
||||
|
||||
func findCustomUserByEmail(app core.App, email string) (*CustomUser, error) {
|
||||
model := &CustomUser{}
|
||||
|
||||
err := app.ModelQuery(model).
|
||||
AndWhere(dbx.HashExp{"email": email}).
|
||||
Limit(1).
|
||||
One(model)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return model, nil
|
||||
}
|
||||
|
||||
func TestRealtimeCustomAuthModelDeleteEvent(t *testing.T) {
|
||||
testApp, _ := tests.NewTestApp()
|
||||
defer testApp.Cleanup()
|
||||
|
||||
// init realtime handlers
|
||||
apis.NewRouter(testApp)
|
||||
|
||||
authRecord1, err := testApp.FindAuthRecordByEmail("users", "test@example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
authRecord2, err := testApp.FindAuthRecordByEmail("users", "test2@example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
client1 := subscriptions.NewDefaultClient()
|
||||
client1.Set(apis.RealtimeClientAuthKey, authRecord1)
|
||||
testApp.SubscriptionsBroker().Register(client1)
|
||||
|
||||
client2 := subscriptions.NewDefaultClient()
|
||||
client2.Set(apis.RealtimeClientAuthKey, authRecord1)
|
||||
testApp.SubscriptionsBroker().Register(client2)
|
||||
|
||||
client3 := subscriptions.NewDefaultClient()
|
||||
client3.Set(apis.RealtimeClientAuthKey, authRecord2)
|
||||
testApp.SubscriptionsBroker().Register(client3)
|
||||
|
||||
// refetch the authRecord as CustomUser
|
||||
customUser, err := findCustomUserByEmail(testApp, authRecord1.Email())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// delete the custom user (should unset the client auth record)
|
||||
if err := testApp.Delete(customUser); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if total := len(testApp.SubscriptionsBroker().Clients()); total != 3 {
|
||||
t.Fatalf("Expected %d subscription clients, found %d", 3, total)
|
||||
}
|
||||
|
||||
if auth := client1.Get(apis.RealtimeClientAuthKey); auth != nil {
|
||||
t.Fatalf("[client1] Expected the auth state to be unset, found %#v", auth)
|
||||
}
|
||||
|
||||
if auth := client2.Get(apis.RealtimeClientAuthKey); auth != nil {
|
||||
t.Fatalf("[client2] Expected the auth state to be unset, found %#v", auth)
|
||||
}
|
||||
|
||||
if auth := client3.Get(apis.RealtimeClientAuthKey); auth == nil || auth.(*core.Record).Id != authRecord2.Id {
|
||||
t.Fatalf("[client3] Expected the auth state to be left unchanged, found %#v", auth)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRealtimeCustomAuthModelUpdateEvent(t *testing.T) {
|
||||
testApp, _ := tests.NewTestApp()
|
||||
defer testApp.Cleanup()
|
||||
|
||||
// init realtime handlers
|
||||
apis.NewRouter(testApp)
|
||||
|
||||
authRecord, err := testApp.FindAuthRecordByEmail("users", "test@example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
client := subscriptions.NewDefaultClient()
|
||||
client.Set(apis.RealtimeClientAuthKey, authRecord)
|
||||
testApp.SubscriptionsBroker().Register(client)
|
||||
|
||||
// refetch the authRecord as CustomUser
|
||||
customUser, err := findCustomUserByEmail(testApp, "test@example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// change its email
|
||||
customUser.Email = "new@example.com"
|
||||
if err := testApp.Save(customUser); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
clientAuthRecord, _ := client.Get(apis.RealtimeClientAuthKey).(*core.Record)
|
||||
if clientAuthRecord.Email() != customUser.Email {
|
||||
t.Fatalf("Expected authRecord with email %q, got %q", customUser.Email, clientAuthRecord.Email())
|
||||
}
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
|
||||
var _ core.Model = (*CustomModelResolve)(nil)
|
||||
|
||||
type CustomModelResolve struct {
|
||||
core.BaseModel
|
||||
tableName string
|
||||
|
||||
Created string `db:"created"`
|
||||
}
|
||||
|
||||
func (m *CustomModelResolve) TableName() string {
|
||||
return m.tableName
|
||||
}
|
||||
|
||||
func TestRealtimeRecordResolve(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
const testCollectionName = "realtime_test_collection"
|
||||
|
||||
testRecordId := core.GenerateDefaultRandomId()
|
||||
|
||||
client0 := subscriptions.NewDefaultClient()
|
||||
client0.Subscribe(testCollectionName + "/*")
|
||||
client0.Discard()
|
||||
// ---
|
||||
client1 := subscriptions.NewDefaultClient()
|
||||
client1.Subscribe(testCollectionName + "/*")
|
||||
// ---
|
||||
client2 := subscriptions.NewDefaultClient()
|
||||
client2.Subscribe(testCollectionName + "/" + testRecordId)
|
||||
// ---
|
||||
client3 := subscriptions.NewDefaultClient()
|
||||
client3.Subscribe("demo1/*")
|
||||
|
||||
scenarios := []struct {
|
||||
name string
|
||||
op func(testApp core.App) error
|
||||
expected map[string][]string // clientId -> [events]
|
||||
}{
|
||||
{
|
||||
"core.Record",
|
||||
func(testApp core.App) error {
|
||||
c, err := testApp.FindCollectionByNameOrId(testCollectionName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
r := core.NewRecord(c)
|
||||
r.Id = testRecordId
|
||||
|
||||
// create
|
||||
err = testApp.Save(r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// update
|
||||
err = testApp.Save(r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// delete
|
||||
err = testApp.Delete(r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
},
|
||||
map[string][]string{
|
||||
client1.Id(): {"create", "update", "delete"},
|
||||
client2.Id(): {"create", "update", "delete"},
|
||||
},
|
||||
},
|
||||
{
|
||||
"core.RecordProxy",
|
||||
func(testApp core.App) error {
|
||||
c, err := testApp.FindCollectionByNameOrId(testCollectionName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
r := core.NewRecord(c)
|
||||
|
||||
proxy := &struct {
|
||||
core.BaseRecordProxy
|
||||
}{}
|
||||
proxy.SetProxyRecord(r)
|
||||
proxy.Id = testRecordId
|
||||
|
||||
// create
|
||||
err = testApp.Save(proxy)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// update
|
||||
err = testApp.Save(proxy)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// delete
|
||||
err = testApp.Delete(proxy)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
},
|
||||
map[string][]string{
|
||||
client1.Id(): {"create", "update", "delete"},
|
||||
client2.Id(): {"create", "update", "delete"},
|
||||
},
|
||||
},
|
||||
{
|
||||
"custom model struct",
|
||||
func(testApp core.App) error {
|
||||
m := &CustomModelResolve{tableName: testCollectionName}
|
||||
m.Id = testRecordId
|
||||
|
||||
// create
|
||||
err := testApp.Save(m)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// update
|
||||
m.Created = "123"
|
||||
err = testApp.Save(m)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// delete
|
||||
err = testApp.Delete(m)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
},
|
||||
map[string][]string{
|
||||
client1.Id(): {"create", "update", "delete"},
|
||||
client2.Id(): {"create", "update", "delete"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
t.Run(s.name, func(t *testing.T) {
|
||||
testApp, _ := tests.NewTestApp()
|
||||
defer testApp.Cleanup()
|
||||
|
||||
// init realtime handlers
|
||||
apis.NewRouter(testApp)
|
||||
|
||||
// create new test collection with public read access
|
||||
testCollection := core.NewBaseCollection(testCollectionName)
|
||||
testCollection.Fields.Add(&core.AutodateField{Name: "created", OnCreate: true, OnUpdate: true})
|
||||
testCollection.ListRule = types.Pointer("")
|
||||
testCollection.ViewRule = types.Pointer("")
|
||||
err := testApp.Save(testCollection)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
testApp.SubscriptionsBroker().Register(client0)
|
||||
testApp.SubscriptionsBroker().Register(client1)
|
||||
testApp.SubscriptionsBroker().Register(client2)
|
||||
testApp.SubscriptionsBroker().Register(client3)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
|
||||
var notifications = map[string][]string{}
|
||||
|
||||
var mu sync.Mutex
|
||||
notify := func(clientId string, eventData []byte) {
|
||||
data := struct{ Action string }{}
|
||||
_ = json.Unmarshal(eventData, &data)
|
||||
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
|
||||
if notifications[clientId] == nil {
|
||||
notifications[clientId] = []string{}
|
||||
}
|
||||
notifications[clientId] = append(notifications[clientId], data.Action)
|
||||
}
|
||||
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
||||
timeout := time.After(250 * time.Millisecond)
|
||||
|
||||
for {
|
||||
select {
|
||||
case e, ok := <-client0.Channel():
|
||||
if ok {
|
||||
notify(client0.Id(), e.Data)
|
||||
}
|
||||
case e, ok := <-client1.Channel():
|
||||
if ok {
|
||||
notify(client1.Id(), e.Data)
|
||||
}
|
||||
case e, ok := <-client2.Channel():
|
||||
if ok {
|
||||
notify(client2.Id(), e.Data)
|
||||
}
|
||||
case e, ok := <-client3.Channel():
|
||||
if ok {
|
||||
notify(client3.Id(), e.Data)
|
||||
}
|
||||
case <-timeout:
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
err = s.op(testApp)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
if len(s.expected) != len(notifications) {
|
||||
t.Fatalf("Expected %d notified clients, got %d:\n%v", len(s.expected), len(notifications), notifications)
|
||||
}
|
||||
|
||||
for id, events := range s.expected {
|
||||
if len(events) != len(notifications[id]) {
|
||||
t.Fatalf("[%s] Expected %d events, got %d:\n%v\n%v", id, len(events), len(notifications[id]), s.expected, notifications)
|
||||
}
|
||||
for _, event := range events {
|
||||
if !slices.Contains(notifications[id], event) {
|
||||
t.Fatalf("[%s] Missing expected event %q in %v", id, event, notifications[id])
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
79
apis/record_auth.go
Normal file
79
apis/record_auth.go
Normal file
|
@ -0,0 +1,79 @@
|
|||
package apis
|
||||
|
||||
import (
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tools/router"
|
||||
)
|
||||
|
||||
// bindRecordAuthApi registers the auth record api endpoints and
|
||||
// the corresponding handlers.
|
||||
func bindRecordAuthApi(app core.App, rg *router.RouterGroup[*core.RequestEvent]) {
|
||||
// global oauth2 subscription redirect handler
|
||||
rg.GET("/oauth2-redirect", oauth2SubscriptionRedirect).Bind(
|
||||
SkipSuccessActivityLog(), // skip success log as it could contain sensitive information in the url
|
||||
)
|
||||
// add again as POST in case of response_mode=form_post
|
||||
rg.POST("/oauth2-redirect", oauth2SubscriptionRedirect).Bind(
|
||||
SkipSuccessActivityLog(), // skip success log as it could contain sensitive information in the url
|
||||
)
|
||||
|
||||
sub := rg.Group("/collections/{collection}")
|
||||
|
||||
sub.GET("/auth-methods", recordAuthMethods).Bind(
|
||||
collectionPathRateLimit("", "listAuthMethods"),
|
||||
)
|
||||
|
||||
sub.POST("/auth-refresh", recordAuthRefresh).Bind(
|
||||
collectionPathRateLimit("", "authRefresh"),
|
||||
RequireSameCollectionContextAuth(""),
|
||||
)
|
||||
|
||||
sub.POST("/auth-with-password", recordAuthWithPassword).Bind(
|
||||
collectionPathRateLimit("", "authWithPassword", "auth"),
|
||||
)
|
||||
|
||||
sub.POST("/auth-with-oauth2", recordAuthWithOAuth2).Bind(
|
||||
collectionPathRateLimit("", "authWithOAuth2", "auth"),
|
||||
)
|
||||
|
||||
sub.POST("/request-otp", recordRequestOTP).Bind(
|
||||
collectionPathRateLimit("", "requestOTP"),
|
||||
)
|
||||
sub.POST("/auth-with-otp", recordAuthWithOTP).Bind(
|
||||
collectionPathRateLimit("", "authWithOTP", "auth"),
|
||||
)
|
||||
|
||||
sub.POST("/request-password-reset", recordRequestPasswordReset).Bind(
|
||||
collectionPathRateLimit("", "requestPasswordReset"),
|
||||
)
|
||||
sub.POST("/confirm-password-reset", recordConfirmPasswordReset).Bind(
|
||||
collectionPathRateLimit("", "confirmPasswordReset"),
|
||||
)
|
||||
|
||||
sub.POST("/request-verification", recordRequestVerification).Bind(
|
||||
collectionPathRateLimit("", "requestVerification"),
|
||||
)
|
||||
sub.POST("/confirm-verification", recordConfirmVerification).Bind(
|
||||
collectionPathRateLimit("", "confirmVerification"),
|
||||
)
|
||||
|
||||
sub.POST("/request-email-change", recordRequestEmailChange).Bind(
|
||||
collectionPathRateLimit("", "requestEmailChange"),
|
||||
RequireSameCollectionContextAuth(""),
|
||||
)
|
||||
sub.POST("/confirm-email-change", recordConfirmEmailChange).Bind(
|
||||
collectionPathRateLimit("", "confirmEmailChange"),
|
||||
)
|
||||
|
||||
sub.POST("/impersonate/{id}", recordAuthImpersonate).Bind(RequireSuperuserAuth())
|
||||
}
|
||||
|
||||
func findAuthCollection(e *core.RequestEvent) (*core.Collection, error) {
|
||||
collection, err := e.App.FindCachedCollectionByNameOrId(e.Request.PathValue("collection"))
|
||||
|
||||
if err != nil || !collection.IsAuth() {
|
||||
return nil, e.NotFoundError("Missing or invalid auth collection context.", err)
|
||||
}
|
||||
|
||||
return collection, nil
|
||||
}
|
122
apis/record_auth_email_change_confirm.go
Normal file
122
apis/record_auth_email_change_confirm.go
Normal file
|
@ -0,0 +1,122 @@
|
|||
package apis
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
validation "github.com/go-ozzo/ozzo-validation/v4"
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tools/security"
|
||||
)
|
||||
|
||||
func recordConfirmEmailChange(e *core.RequestEvent) error {
|
||||
collection, err := findAuthCollection(e)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if collection.Name == core.CollectionNameSuperusers {
|
||||
return e.BadRequestError("All superusers can change their emails directly.", nil)
|
||||
}
|
||||
|
||||
form := newEmailChangeConfirmForm(e.App, collection)
|
||||
if err = e.BindBody(form); err != nil {
|
||||
return firstApiError(err, e.BadRequestError("An error occurred while loading the submitted data.", err))
|
||||
}
|
||||
if err = form.validate(); err != nil {
|
||||
return firstApiError(err, e.BadRequestError("An error occurred while validating the submitted data.", err))
|
||||
}
|
||||
|
||||
authRecord, newEmail, err := form.parseToken()
|
||||
if err != nil {
|
||||
return firstApiError(err, e.BadRequestError("Invalid or expired token.", err))
|
||||
}
|
||||
|
||||
event := new(core.RecordConfirmEmailChangeRequestEvent)
|
||||
event.RequestEvent = e
|
||||
event.Collection = collection
|
||||
event.Record = authRecord
|
||||
event.NewEmail = newEmail
|
||||
|
||||
return e.App.OnRecordConfirmEmailChangeRequest().Trigger(event, func(e *core.RecordConfirmEmailChangeRequestEvent) error {
|
||||
e.Record.SetEmail(e.NewEmail)
|
||||
e.Record.SetVerified(true)
|
||||
|
||||
if err := e.App.Save(e.Record); err != nil {
|
||||
return firstApiError(err, e.BadRequestError("Failed to confirm email change.", err))
|
||||
}
|
||||
|
||||
return execAfterSuccessTx(true, e.App, func() error {
|
||||
return e.NoContent(http.StatusNoContent)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
|
||||
func newEmailChangeConfirmForm(app core.App, collection *core.Collection) *EmailChangeConfirmForm {
|
||||
return &EmailChangeConfirmForm{
|
||||
app: app,
|
||||
collection: collection,
|
||||
}
|
||||
}
|
||||
|
||||
type EmailChangeConfirmForm struct {
|
||||
app core.App
|
||||
collection *core.Collection
|
||||
|
||||
Token string `form:"token" json:"token"`
|
||||
Password string `form:"password" json:"password"`
|
||||
}
|
||||
|
||||
func (form *EmailChangeConfirmForm) validate() error {
|
||||
return validation.ValidateStruct(form,
|
||||
validation.Field(&form.Token, validation.Required, validation.By(form.checkToken)),
|
||||
validation.Field(&form.Password, validation.Required, validation.Length(1, 100), validation.By(form.checkPassword)),
|
||||
)
|
||||
}
|
||||
|
||||
func (form *EmailChangeConfirmForm) checkToken(value any) error {
|
||||
_, _, err := form.parseToken()
|
||||
return err
|
||||
}
|
||||
|
||||
func (form *EmailChangeConfirmForm) checkPassword(value any) error {
|
||||
v, _ := value.(string)
|
||||
if v == "" {
|
||||
return nil // nothing to check
|
||||
}
|
||||
|
||||
authRecord, _, _ := form.parseToken()
|
||||
if authRecord == nil || !authRecord.ValidatePassword(v) {
|
||||
return validation.NewError("validation_invalid_password", "Missing or invalid auth record password.")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (form *EmailChangeConfirmForm) parseToken() (*core.Record, string, error) {
|
||||
// check token payload
|
||||
claims, _ := security.ParseUnverifiedJWT(form.Token)
|
||||
newEmail, _ := claims[core.TokenClaimNewEmail].(string)
|
||||
if newEmail == "" {
|
||||
return nil, "", validation.NewError("validation_invalid_token_payload", "Invalid token payload - newEmail must be set.")
|
||||
}
|
||||
|
||||
// ensure that there aren't other users with the new email
|
||||
_, err := form.app.FindAuthRecordByEmail(form.collection, newEmail)
|
||||
if err == nil {
|
||||
return nil, "", validation.NewError("validation_existing_token_email", "The new email address is already registered: "+newEmail)
|
||||
}
|
||||
|
||||
// verify that the token is not expired and its signature is valid
|
||||
authRecord, err := form.app.FindAuthRecordByToken(form.Token, core.TokenTypeEmailChange)
|
||||
if err != nil {
|
||||
return nil, "", validation.NewError("validation_invalid_token", "Invalid or expired token.")
|
||||
}
|
||||
|
||||
if authRecord.Collection().Id != form.collection.Id {
|
||||
return nil, "", validation.NewError("validation_token_collection_mismatch", "The provided token is for different auth collection.")
|
||||
}
|
||||
|
||||
return authRecord, newEmail, nil
|
||||
}
|
211
apis/record_auth_email_change_confirm_test.go
Normal file
211
apis/record_auth_email_change_confirm_test.go
Normal file
|
@ -0,0 +1,211 @@
|
|||
package apis_test
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tests"
|
||||
)
|
||||
|
||||
func TestRecordConfirmEmailChange(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "not an auth collection",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/demo1/confirm-email-change",
|
||||
ExpectedStatus: 404,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "empty data",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/confirm-email-change",
|
||||
Body: strings.NewReader(``),
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{
|
||||
`"data":`,
|
||||
`"token":{"code":"validation_required"`,
|
||||
`"password":{"code":"validation_required"`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "invalid data",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/confirm-email-change",
|
||||
Body: strings.NewReader(`{"token`),
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "expired token and correct password",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/confirm-email-change",
|
||||
Body: strings.NewReader(`{
|
||||
"token":"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImNvbGxlY3Rpb25JZCI6Il9wYl91c2Vyc19hdXRoXyIsInR5cGUiOiJlbWFpbENoYW5nZSIsImVtYWlsIjoidGVzdEBleGFtcGxlLmNvbSIsIm5ld0VtYWlsIjoiY2hhbmdlQGV4YW1wbGUuY29tIiwiZXhwIjoxNjQwOTkxNjYxfQ.dff842MO0mgRTHY8dktp0dqG9-7LGQOgRuiAbQpYBls",
|
||||
"password":"1234567890"
|
||||
}`),
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{
|
||||
`"data":{`,
|
||||
`"token":{`,
|
||||
`"code":"validation_invalid_token"`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "non-email change token",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/confirm-email-change",
|
||||
Body: strings.NewReader(`{
|
||||
"token":"eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
"password":"1234567890"
|
||||
}`),
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{
|
||||
`"data":{`,
|
||||
`"token":{`,
|
||||
`"code":"validation_invalid_token_payload"`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "valid token and incorrect password",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/confirm-email-change",
|
||||
Body: strings.NewReader(`{
|
||||
"token":"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImNvbGxlY3Rpb25JZCI6Il9wYl91c2Vyc19hdXRoXyIsInR5cGUiOiJlbWFpbENoYW5nZSIsImVtYWlsIjoidGVzdEBleGFtcGxlLmNvbSIsIm5ld0VtYWlsIjoiY2hhbmdlQGV4YW1wbGUuY29tIiwiZXhwIjoyNTI0NjA0NDYxfQ.Y7mVlaEPhJiNPoIvIqbIosZU4c4lEhwysOrRR8c95iU",
|
||||
"password":"1234567891"
|
||||
}`),
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{
|
||||
`"data":{`,
|
||||
`"password":{`,
|
||||
`"code":"validation_invalid_password"`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "valid token and correct password",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/confirm-email-change",
|
||||
Body: strings.NewReader(`{
|
||||
"token":"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImNvbGxlY3Rpb25JZCI6Il9wYl91c2Vyc19hdXRoXyIsInR5cGUiOiJlbWFpbENoYW5nZSIsImVtYWlsIjoidGVzdEBleGFtcGxlLmNvbSIsIm5ld0VtYWlsIjoiY2hhbmdlQGV4YW1wbGUuY29tIiwiZXhwIjoyNTI0NjA0NDYxfQ.Y7mVlaEPhJiNPoIvIqbIosZU4c4lEhwysOrRR8c95iU",
|
||||
"password":"1234567890"
|
||||
}`),
|
||||
ExpectedStatus: 204,
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordConfirmEmailChangeRequest": 1,
|
||||
"OnModelUpdate": 1,
|
||||
"OnModelUpdateExecute": 1,
|
||||
"OnModelAfterUpdateSuccess": 1,
|
||||
"OnModelValidate": 1,
|
||||
"OnRecordUpdate": 1,
|
||||
"OnRecordUpdateExecute": 1,
|
||||
"OnRecordAfterUpdateSuccess": 1,
|
||||
"OnRecordValidate": 1,
|
||||
},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
_, err := app.FindAuthRecordByEmail("users", "change@example.com")
|
||||
if err != nil {
|
||||
t.Fatalf("Expected to find user with email %q, got error: %v", "change@example.com", err)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "valid token in different auth collection",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/clients/confirm-email-change",
|
||||
Body: strings.NewReader(`{
|
||||
"token":"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImNvbGxlY3Rpb25JZCI6Il9wYl91c2Vyc19hdXRoXyIsInR5cGUiOiJlbWFpbENoYW5nZSIsImVtYWlsIjoidGVzdEBleGFtcGxlLmNvbSIsIm5ld0VtYWlsIjoiY2hhbmdlQGV4YW1wbGUuY29tIiwiZXhwIjoyNTI0NjA0NDYxfQ.Y7mVlaEPhJiNPoIvIqbIosZU4c4lEhwysOrRR8c95iU",
|
||||
"password":"1234567890"
|
||||
}`),
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{
|
||||
`"data":{`,
|
||||
`"token":{"code":"validation_token_collection_mismatch"`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "OnRecordConfirmEmailChangeRequest tx body write check",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/confirm-email-change",
|
||||
Body: strings.NewReader(`{
|
||||
"token":"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImNvbGxlY3Rpb25JZCI6Il9wYl91c2Vyc19hdXRoXyIsInR5cGUiOiJlbWFpbENoYW5nZSIsImVtYWlsIjoidGVzdEBleGFtcGxlLmNvbSIsIm5ld0VtYWlsIjoiY2hhbmdlQGV4YW1wbGUuY29tIiwiZXhwIjoyNTI0NjA0NDYxfQ.Y7mVlaEPhJiNPoIvIqbIosZU4c4lEhwysOrRR8c95iU",
|
||||
"password":"1234567890"
|
||||
}`),
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
app.OnRecordConfirmEmailChangeRequest().BindFunc(func(e *core.RecordConfirmEmailChangeRequestEvent) error {
|
||||
original := e.App
|
||||
return e.App.RunInTransaction(func(txApp core.App) error {
|
||||
e.App = txApp
|
||||
defer func() { e.App = original }()
|
||||
|
||||
if err := e.Next(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return e.BadRequestError("TX_ERROR", nil)
|
||||
})
|
||||
})
|
||||
},
|
||||
ExpectedStatus: 400,
|
||||
ExpectedEvents: map[string]int{"OnRecordConfirmEmailChangeRequest": 1},
|
||||
ExpectedContent: []string{"TX_ERROR"},
|
||||
},
|
||||
|
||||
// rate limit checks
|
||||
// -----------------------------------------------------------
|
||||
{
|
||||
Name: "RateLimit rule - users:confirmEmailChange",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/confirm-email-change",
|
||||
Body: strings.NewReader(`{
|
||||
"token":"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImNvbGxlY3Rpb25JZCI6Il9wYl91c2Vyc19hdXRoXyIsInR5cGUiOiJlbWFpbENoYW5nZSIsImVtYWlsIjoidGVzdEBleGFtcGxlLmNvbSIsIm5ld0VtYWlsIjoiY2hhbmdlQGV4YW1wbGUuY29tIiwiZXhwIjoyNTI0NjA0NDYxfQ.Y7mVlaEPhJiNPoIvIqbIosZU4c4lEhwysOrRR8c95iU",
|
||||
"password":"1234567890"
|
||||
}`),
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
app.Settings().RateLimits.Enabled = true
|
||||
app.Settings().RateLimits.Rules = []core.RateLimitRule{
|
||||
{MaxRequests: 100, Label: "abc"},
|
||||
{MaxRequests: 100, Label: "*:confirmEmailChange"},
|
||||
{MaxRequests: 0, Label: "users:confirmEmailChange"},
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 429,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "RateLimit rule - *:confirmEmailChange",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/confirm-email-change",
|
||||
Body: strings.NewReader(`{
|
||||
"token":"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImNvbGxlY3Rpb25JZCI6Il9wYl91c2Vyc19hdXRoXyIsInR5cGUiOiJlbWFpbENoYW5nZSIsImVtYWlsIjoidGVzdEBleGFtcGxlLmNvbSIsIm5ld0VtYWlsIjoiY2hhbmdlQGV4YW1wbGUuY29tIiwiZXhwIjoyNTI0NjA0NDYxfQ.Y7mVlaEPhJiNPoIvIqbIosZU4c4lEhwysOrRR8c95iU",
|
||||
"password":"1234567890"
|
||||
}`),
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
app.Settings().RateLimits.Enabled = true
|
||||
app.Settings().RateLimits.Rules = []core.RateLimitRule{
|
||||
{MaxRequests: 100, Label: "abc"},
|
||||
{MaxRequests: 0, Label: "*:confirmEmailChange"},
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 429,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
92
apis/record_auth_email_change_request.go
Normal file
92
apis/record_auth_email_change_request.go
Normal file
|
@ -0,0 +1,92 @@
|
|||
package apis
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
validation "github.com/go-ozzo/ozzo-validation/v4"
|
||||
"github.com/go-ozzo/ozzo-validation/v4/is"
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/mails"
|
||||
)
|
||||
|
||||
func recordRequestEmailChange(e *core.RequestEvent) error {
|
||||
collection, err := findAuthCollection(e)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if collection.Name == core.CollectionNameSuperusers {
|
||||
return e.BadRequestError("All superusers can change their emails directly.", nil)
|
||||
}
|
||||
|
||||
record := e.Auth
|
||||
if record == nil {
|
||||
return e.UnauthorizedError("The request requires valid auth record.", nil)
|
||||
}
|
||||
|
||||
form := newEmailChangeRequestForm(e.App, record)
|
||||
if err = e.BindBody(form); err != nil {
|
||||
return firstApiError(err, e.BadRequestError("An error occurred while loading the submitted data.", err))
|
||||
}
|
||||
if err = form.validate(); err != nil {
|
||||
return firstApiError(err, e.BadRequestError("An error occurred while validating the submitted data.", err))
|
||||
}
|
||||
|
||||
event := new(core.RecordRequestEmailChangeRequestEvent)
|
||||
event.RequestEvent = e
|
||||
event.Collection = collection
|
||||
event.Record = record
|
||||
event.NewEmail = form.NewEmail
|
||||
|
||||
return e.App.OnRecordRequestEmailChangeRequest().Trigger(event, func(e *core.RecordRequestEmailChangeRequestEvent) error {
|
||||
if err := mails.SendRecordChangeEmail(e.App, e.Record, e.NewEmail); err != nil {
|
||||
return firstApiError(err, e.BadRequestError("Failed to request email change.", err))
|
||||
}
|
||||
|
||||
return execAfterSuccessTx(true, e.App, func() error {
|
||||
return e.NoContent(http.StatusNoContent)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
|
||||
func newEmailChangeRequestForm(app core.App, record *core.Record) *emailChangeRequestForm {
|
||||
return &emailChangeRequestForm{
|
||||
app: app,
|
||||
record: record,
|
||||
}
|
||||
}
|
||||
|
||||
type emailChangeRequestForm struct {
|
||||
app core.App
|
||||
record *core.Record
|
||||
|
||||
NewEmail string `form:"newEmail" json:"newEmail"`
|
||||
}
|
||||
|
||||
func (form *emailChangeRequestForm) validate() error {
|
||||
return validation.ValidateStruct(form,
|
||||
validation.Field(&form.NewEmail,
|
||||
validation.Required,
|
||||
validation.Length(1, 255),
|
||||
is.EmailFormat,
|
||||
validation.NotIn(form.record.Email()),
|
||||
validation.By(form.checkUniqueEmail),
|
||||
),
|
||||
)
|
||||
}
|
||||
|
||||
func (form *emailChangeRequestForm) checkUniqueEmail(value any) error {
|
||||
v, _ := value.(string)
|
||||
if v == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
found, _ := form.app.FindAuthRecordByEmail(form.record.Collection(), v)
|
||||
if found != nil && found.Id != form.record.Id {
|
||||
return validation.NewError("validation_invalid_new_email", "Invalid new email address.")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
195
apis/record_auth_email_change_request_test.go
Normal file
195
apis/record_auth_email_change_request_test.go
Normal file
|
@ -0,0 +1,195 @@
|
|||
package apis_test
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tests"
|
||||
)
|
||||
|
||||
func TestRecordRequestEmailChange(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "unauthorized",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/request-email-change",
|
||||
Body: strings.NewReader(`{"newEmail":"change@example.com"}`),
|
||||
ExpectedStatus: 401,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "not an auth collection",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/demo1/request-email-change",
|
||||
Body: strings.NewReader(`{"newEmail":"change@example.com"}`),
|
||||
ExpectedStatus: 401,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "record authentication but from different auth collection",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/clients/request-email-change",
|
||||
Body: strings.NewReader(`{"newEmail":"change@example.com"}`),
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "superuser authentication",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/request-email-change",
|
||||
Body: strings.NewReader(`{"newEmail":"change@example.com"}`),
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
|
||||
},
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "invalid data",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/request-email-change",
|
||||
Body: strings.NewReader(`{"newEmail`),
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "empty data",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/request-email-change",
|
||||
Body: strings.NewReader(`{}`),
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{
|
||||
`"data":`,
|
||||
`"newEmail":{"code":"validation_required"`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "valid data (existing email)",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/request-email-change",
|
||||
Body: strings.NewReader(`{"newEmail":"test2@example.com"}`),
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{
|
||||
`"data":`,
|
||||
`"newEmail":{"code":"validation_invalid_new_email"`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "valid data (new email)",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/request-email-change",
|
||||
Body: strings.NewReader(`{"newEmail":"change@example.com"}`),
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
ExpectedStatus: 204,
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordRequestEmailChangeRequest": 1,
|
||||
"OnMailerSend": 1,
|
||||
"OnMailerRecordEmailChangeSend": 1,
|
||||
},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
if !strings.Contains(app.TestMailer.LastMessage().HTML, "/auth/confirm-email-change") {
|
||||
t.Fatalf("Expected email change email, got\n%v", app.TestMailer.LastMessage().HTML)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "OnRecordRequestEmailChangeRequest tx body write check",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/request-email-change",
|
||||
Body: strings.NewReader(`{"newEmail":"change@example.com"}`),
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
app.OnRecordRequestEmailChangeRequest().BindFunc(func(e *core.RecordRequestEmailChangeRequestEvent) error {
|
||||
original := e.App
|
||||
return e.App.RunInTransaction(func(txApp core.App) error {
|
||||
e.App = txApp
|
||||
defer func() { e.App = original }()
|
||||
|
||||
if err := e.Next(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return e.BadRequestError("TX_ERROR", nil)
|
||||
})
|
||||
})
|
||||
},
|
||||
ExpectedStatus: 400,
|
||||
ExpectedEvents: map[string]int{"OnRecordRequestEmailChangeRequest": 1},
|
||||
ExpectedContent: []string{"TX_ERROR"},
|
||||
},
|
||||
|
||||
// rate limit checks
|
||||
// -----------------------------------------------------------
|
||||
{
|
||||
Name: "RateLimit rule - users:requestEmailChange",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/request-email-change",
|
||||
Body: strings.NewReader(`{"newEmail":"change@example.com"}`),
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
app.Settings().RateLimits.Enabled = true
|
||||
app.Settings().RateLimits.Rules = []core.RateLimitRule{
|
||||
{MaxRequests: 100, Label: "abc"},
|
||||
{MaxRequests: 100, Label: "*:requestEmailChange"},
|
||||
{MaxRequests: 0, Label: "users:requestEmailChange"},
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 429,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "RateLimit rule - *:requestEmailChange",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/request-email-change",
|
||||
Body: strings.NewReader(`{"newEmail":"change@example.com"}`),
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
app.Settings().RateLimits.Enabled = true
|
||||
app.Settings().RateLimits.Rules = []core.RateLimitRule{
|
||||
{MaxRequests: 100, Label: "abc"},
|
||||
{MaxRequests: 0, Label: "*:requestEmailChange"},
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 429,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
54
apis/record_auth_impersonate.go
Normal file
54
apis/record_auth_impersonate.go
Normal file
|
@ -0,0 +1,54 @@
|
|||
package apis
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
validation "github.com/go-ozzo/ozzo-validation/v4"
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
)
|
||||
|
||||
// note: for now allow superusers but it may change in the future to allow access
|
||||
// also to users with "Manage API" rule access depending on the use cases that will arise
|
||||
func recordAuthImpersonate(e *core.RequestEvent) error {
|
||||
if !e.HasSuperuserAuth() {
|
||||
return e.ForbiddenError("", nil)
|
||||
}
|
||||
|
||||
collection, err := findAuthCollection(e)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
record, err := e.App.FindRecordById(collection, e.Request.PathValue("id"))
|
||||
if err != nil {
|
||||
return e.NotFoundError("", err)
|
||||
}
|
||||
|
||||
form := &impersonateForm{}
|
||||
if err = e.BindBody(form); err != nil {
|
||||
return e.BadRequestError("An error occurred while loading the submitted data.", err)
|
||||
}
|
||||
if err = form.validate(); err != nil {
|
||||
return e.BadRequestError("An error occurred while validating the submitted data.", err)
|
||||
}
|
||||
|
||||
token, err := record.NewStaticAuthToken(time.Duration(form.Duration) * time.Second)
|
||||
if err != nil {
|
||||
e.InternalServerError("Failed to generate static auth token", err)
|
||||
}
|
||||
|
||||
return recordAuthResponse(e, record, token, "", nil)
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
|
||||
type impersonateForm struct {
|
||||
// Duration is the optional custom token duration in seconds.
|
||||
Duration int64 `form:"duration" json:"duration"`
|
||||
}
|
||||
|
||||
func (form *impersonateForm) validate() error {
|
||||
return validation.ValidateStruct(form,
|
||||
validation.Field(&form.Duration, validation.Min(0)),
|
||||
)
|
||||
}
|
109
apis/record_auth_impersonate_test.go
Normal file
109
apis/record_auth_impersonate_test.go
Normal file
|
@ -0,0 +1,109 @@
|
|||
package apis_test
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/pocketbase/pocketbase/tests"
|
||||
)
|
||||
|
||||
func TestRecordAuthImpersonate(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "unauthorized",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/impersonate/4q1xlclmfloku33",
|
||||
ExpectedStatus: 401,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "authorized as different user",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/impersonate/4q1xlclmfloku33",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6Im9hcDY0MGNvdDR5cnUycyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.GfJo6EHIobgas_AXt-M-tj5IoQendPnrkMSe9ExuSEY",
|
||||
},
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "authorized as the same user",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/impersonate/4q1xlclmfloku33",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "authorized as superuser",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/impersonate/4q1xlclmfloku33",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
`"token":"`,
|
||||
`"id":"4q1xlclmfloku33"`,
|
||||
`"record":{`,
|
||||
},
|
||||
NotExpectedContent: []string{
|
||||
// hidden fields should remain hidden even though we are authenticated as superuser
|
||||
`"tokenKey"`,
|
||||
`"password"`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordAuthRequest": 1,
|
||||
"OnRecordEnrich": 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "authorized as superuser with custom invalid duration",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/impersonate/4q1xlclmfloku33",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
|
||||
},
|
||||
Body: strings.NewReader(`{"duration":-1}`),
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{
|
||||
`"data":{`,
|
||||
`"duration":{`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "authorized as superuser with custom valid duration",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/impersonate/4q1xlclmfloku33",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
|
||||
},
|
||||
Body: strings.NewReader(`{"duration":100}`),
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
`"token":"`,
|
||||
`"id":"4q1xlclmfloku33"`,
|
||||
`"record":{`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordAuthRequest": 1,
|
||||
"OnRecordEnrich": 1,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
170
apis/record_auth_methods.go
Normal file
170
apis/record_auth_methods.go
Normal file
|
@ -0,0 +1,170 @@
|
|||
package apis
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"slices"
|
||||
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tools/auth"
|
||||
"github.com/pocketbase/pocketbase/tools/security"
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
type otpResponse struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
Duration int64 `json:"duration"` // in seconds
|
||||
}
|
||||
|
||||
type mfaResponse struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
Duration int64 `json:"duration"` // in seconds
|
||||
}
|
||||
|
||||
type passwordResponse struct {
|
||||
IdentityFields []string `json:"identityFields"`
|
||||
Enabled bool `json:"enabled"`
|
||||
}
|
||||
|
||||
type oauth2Response struct {
|
||||
Providers []providerInfo `json:"providers"`
|
||||
Enabled bool `json:"enabled"`
|
||||
}
|
||||
|
||||
type providerInfo struct {
|
||||
Name string `json:"name"`
|
||||
DisplayName string `json:"displayName"`
|
||||
State string `json:"state"`
|
||||
AuthURL string `json:"authURL"`
|
||||
|
||||
// @todo
|
||||
// deprecated: use AuthURL instead
|
||||
// AuthUrl will be removed after dropping v0.22 support
|
||||
AuthUrl string `json:"authUrl"`
|
||||
|
||||
// technically could be omitted if the provider doesn't support PKCE,
|
||||
// but to avoid breaking existing typed clients we'll return them as empty string
|
||||
CodeVerifier string `json:"codeVerifier"`
|
||||
CodeChallenge string `json:"codeChallenge"`
|
||||
CodeChallengeMethod string `json:"codeChallengeMethod"`
|
||||
}
|
||||
|
||||
type authMethodsResponse struct {
|
||||
Password passwordResponse `json:"password"`
|
||||
OAuth2 oauth2Response `json:"oauth2"`
|
||||
MFA mfaResponse `json:"mfa"`
|
||||
OTP otpResponse `json:"otp"`
|
||||
|
||||
// legacy fields
|
||||
// @todo remove after dropping v0.22 support
|
||||
AuthProviders []providerInfo `json:"authProviders"`
|
||||
UsernamePassword bool `json:"usernamePassword"`
|
||||
EmailPassword bool `json:"emailPassword"`
|
||||
}
|
||||
|
||||
func (amr *authMethodsResponse) fillLegacyFields() {
|
||||
amr.EmailPassword = amr.Password.Enabled && slices.Contains(amr.Password.IdentityFields, "email")
|
||||
|
||||
amr.UsernamePassword = amr.Password.Enabled && slices.Contains(amr.Password.IdentityFields, "username")
|
||||
|
||||
if amr.OAuth2.Enabled {
|
||||
amr.AuthProviders = amr.OAuth2.Providers
|
||||
}
|
||||
}
|
||||
|
||||
func recordAuthMethods(e *core.RequestEvent) error {
|
||||
collection, err := findAuthCollection(e)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
result := authMethodsResponse{
|
||||
Password: passwordResponse{
|
||||
IdentityFields: make([]string, 0, len(collection.PasswordAuth.IdentityFields)),
|
||||
},
|
||||
OAuth2: oauth2Response{
|
||||
Providers: make([]providerInfo, 0, len(collection.OAuth2.Providers)),
|
||||
},
|
||||
OTP: otpResponse{
|
||||
Enabled: collection.OTP.Enabled,
|
||||
},
|
||||
MFA: mfaResponse{
|
||||
Enabled: collection.MFA.Enabled,
|
||||
},
|
||||
}
|
||||
|
||||
if collection.PasswordAuth.Enabled {
|
||||
result.Password.Enabled = true
|
||||
result.Password.IdentityFields = collection.PasswordAuth.IdentityFields
|
||||
}
|
||||
|
||||
if collection.OTP.Enabled {
|
||||
result.OTP.Duration = collection.OTP.Duration
|
||||
}
|
||||
|
||||
if collection.MFA.Enabled {
|
||||
result.MFA.Duration = collection.MFA.Duration
|
||||
}
|
||||
|
||||
if !collection.OAuth2.Enabled {
|
||||
result.fillLegacyFields()
|
||||
|
||||
return e.JSON(http.StatusOK, result)
|
||||
}
|
||||
|
||||
result.OAuth2.Enabled = true
|
||||
|
||||
for _, config := range collection.OAuth2.Providers {
|
||||
provider, err := config.InitProvider()
|
||||
if err != nil {
|
||||
e.App.Logger().Debug(
|
||||
"Failed to setup OAuth2 provider",
|
||||
slog.String("name", config.Name),
|
||||
slog.String("error", err.Error()),
|
||||
)
|
||||
continue // skip provider
|
||||
}
|
||||
|
||||
info := providerInfo{
|
||||
Name: config.Name,
|
||||
DisplayName: provider.DisplayName(),
|
||||
State: security.RandomString(30),
|
||||
}
|
||||
|
||||
if info.DisplayName == "" {
|
||||
info.DisplayName = config.Name
|
||||
}
|
||||
|
||||
urlOpts := []oauth2.AuthCodeOption{}
|
||||
|
||||
// custom providers url options
|
||||
switch config.Name {
|
||||
case auth.NameApple:
|
||||
// see https://developer.apple.com/documentation/sign_in_with_apple/sign_in_with_apple_js/incorporating_sign_in_with_apple_into_other_platforms#3332113
|
||||
urlOpts = append(urlOpts, oauth2.SetAuthURLParam("response_mode", "form_post"))
|
||||
}
|
||||
|
||||
if provider.PKCE() {
|
||||
info.CodeVerifier = security.RandomString(43)
|
||||
info.CodeChallenge = security.S256Challenge(info.CodeVerifier)
|
||||
info.CodeChallengeMethod = "S256"
|
||||
urlOpts = append(urlOpts,
|
||||
oauth2.SetAuthURLParam("code_challenge", info.CodeChallenge),
|
||||
oauth2.SetAuthURLParam("code_challenge_method", info.CodeChallengeMethod),
|
||||
)
|
||||
}
|
||||
|
||||
info.AuthURL = provider.BuildAuthURL(
|
||||
info.State,
|
||||
urlOpts...,
|
||||
) + "&redirect_uri=" // empty redirect_uri so that users can append their redirect url
|
||||
|
||||
info.AuthUrl = info.AuthURL
|
||||
|
||||
result.OAuth2.Providers = append(result.OAuth2.Providers, info)
|
||||
}
|
||||
|
||||
result.fillLegacyFields()
|
||||
|
||||
return e.JSON(http.StatusOK, result)
|
||||
}
|
106
apis/record_auth_methods_test.go
Normal file
106
apis/record_auth_methods_test.go
Normal file
|
@ -0,0 +1,106 @@
|
|||
package apis_test
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tests"
|
||||
)
|
||||
|
||||
func TestRecordAuthMethodsList(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "missing collection",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/collections/missing/auth-methods",
|
||||
ExpectedStatus: 404,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "non auth collection",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/collections/demo1/auth-methods",
|
||||
ExpectedStatus: 404,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "auth collection with none auth methods allowed",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/collections/nologin/auth-methods",
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
`"password":{"identityFields":[],"enabled":false}`,
|
||||
`"oauth2":{"providers":[],"enabled":false}`,
|
||||
`"mfa":{"enabled":false,"duration":0}`,
|
||||
`"otp":{"enabled":false,"duration":0}`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "auth collection with all auth methods allowed",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/collections/users/auth-methods",
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
`"password":{"identityFields":["email","username"],"enabled":true}`,
|
||||
`"mfa":{"enabled":true,"duration":1800}`,
|
||||
`"otp":{"enabled":true,"duration":300}`,
|
||||
`"oauth2":{`,
|
||||
`"providers":[{`,
|
||||
`"name":"google"`,
|
||||
`"name":"gitlab"`,
|
||||
`"state":`,
|
||||
`"displayName":`,
|
||||
`"codeVerifier":`,
|
||||
`"codeChallenge":`,
|
||||
`"codeChallengeMethod":`,
|
||||
`"authURL":`,
|
||||
`redirect_uri="`, // ensures that the redirect_uri is the last url param
|
||||
},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
|
||||
// rate limit checks
|
||||
// -----------------------------------------------------------
|
||||
{
|
||||
Name: "RateLimit rule - nologin:listAuthMethods",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/collections/nologin/auth-methods",
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
app.Settings().RateLimits.Enabled = true
|
||||
app.Settings().RateLimits.Rules = []core.RateLimitRule{
|
||||
{MaxRequests: 100, Label: "abc"},
|
||||
{MaxRequests: 100, Label: "*:listAuthMethods"},
|
||||
{MaxRequests: 0, Label: "nologin:listAuthMethods"},
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 429,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "RateLimit rule - *:listAuthMethods",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/collections/nologin/auth-methods",
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
app.Settings().RateLimits.Enabled = true
|
||||
app.Settings().RateLimits.Rules = []core.RateLimitRule{
|
||||
{MaxRequests: 100, Label: "abc"},
|
||||
{MaxRequests: 0, Label: "*:listAuthMethods"},
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 429,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
127
apis/record_auth_otp_request.go
Normal file
127
apis/record_auth_otp_request.go
Normal file
|
@ -0,0 +1,127 @@
|
|||
package apis
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
validation "github.com/go-ozzo/ozzo-validation/v4"
|
||||
"github.com/go-ozzo/ozzo-validation/v4/is"
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/mails"
|
||||
"github.com/pocketbase/pocketbase/tools/routine"
|
||||
"github.com/pocketbase/pocketbase/tools/security"
|
||||
)
|
||||
|
||||
func recordRequestOTP(e *core.RequestEvent) error {
|
||||
collection, err := findAuthCollection(e)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !collection.OTP.Enabled {
|
||||
return e.ForbiddenError("The collection is not configured to allow OTP authentication.", nil)
|
||||
}
|
||||
|
||||
form := &createOTPForm{}
|
||||
if err = e.BindBody(form); err != nil {
|
||||
return firstApiError(err, e.BadRequestError("An error occurred while loading the submitted data.", err))
|
||||
}
|
||||
if err = form.validate(); err != nil {
|
||||
return firstApiError(err, e.BadRequestError("An error occurred while validating the submitted data.", err))
|
||||
}
|
||||
|
||||
record, err := e.App.FindAuthRecordByEmail(collection, form.Email)
|
||||
|
||||
// ignore not found errors to allow custom record find implementations
|
||||
if err != nil && !errors.Is(err, sql.ErrNoRows) {
|
||||
return e.InternalServerError("", err)
|
||||
}
|
||||
|
||||
event := new(core.RecordCreateOTPRequestEvent)
|
||||
event.RequestEvent = e
|
||||
event.Password = security.RandomStringWithAlphabet(collection.OTP.Length, "1234567890")
|
||||
event.Collection = collection
|
||||
event.Record = record
|
||||
|
||||
originalApp := e.App
|
||||
|
||||
return e.App.OnRecordRequestOTPRequest().Trigger(event, func(e *core.RecordCreateOTPRequestEvent) error {
|
||||
if e.Record == nil {
|
||||
// write a dummy 200 response as a very rudimentary emails enumeration "protection"
|
||||
e.JSON(http.StatusOK, map[string]string{
|
||||
"otpId": core.GenerateDefaultRandomId(),
|
||||
})
|
||||
|
||||
return fmt.Errorf("missing or invalid %s OTP auth record with email %s", collection.Name, form.Email)
|
||||
}
|
||||
|
||||
var otp *core.OTP
|
||||
|
||||
// limit the new OTP creations for a single user
|
||||
if !e.App.IsDev() {
|
||||
otps, err := e.App.FindAllOTPsByRecord(e.Record)
|
||||
if err != nil {
|
||||
return firstApiError(err, e.InternalServerError("Failed to fetch previous record OTPs.", err))
|
||||
}
|
||||
|
||||
totalRecent := 0
|
||||
for _, existingOTP := range otps {
|
||||
if !existingOTP.HasExpired(collection.OTP.DurationTime()) {
|
||||
totalRecent++
|
||||
}
|
||||
// use the last issued one
|
||||
if totalRecent > 9 {
|
||||
otp = otps[0] // otps are DESC sorted
|
||||
e.App.Logger().Warn(
|
||||
"Too many OTP requests - reusing the last issued",
|
||||
"email", form.Email,
|
||||
"recordId", e.Record.Id,
|
||||
"otpId", existingOTP.Id,
|
||||
)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if otp == nil {
|
||||
// create new OTP
|
||||
// ---
|
||||
otp = core.NewOTP(e.App)
|
||||
otp.SetCollectionRef(e.Record.Collection().Id)
|
||||
otp.SetRecordRef(e.Record.Id)
|
||||
otp.SetPassword(e.Password)
|
||||
err = e.App.Save(otp)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// send OTP email
|
||||
// (in the background as a very basic timing attacks and emails enumeration protection)
|
||||
// ---
|
||||
routine.FireAndForget(func() {
|
||||
err = mails.SendRecordOTP(originalApp, e.Record, otp.Id, e.Password)
|
||||
if err != nil {
|
||||
originalApp.Logger().Error("Failed to send OTP email", "error", errors.Join(err, originalApp.Delete(otp)))
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
return execAfterSuccessTx(true, e.App, func() error {
|
||||
return e.JSON(http.StatusOK, map[string]string{"otpId": otp.Id})
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
|
||||
type createOTPForm struct {
|
||||
Email string `form:"email" json:"email"`
|
||||
}
|
||||
|
||||
func (form createOTPForm) validate() error {
|
||||
return validation.ValidateStruct(&form,
|
||||
validation.Field(&form.Email, validation.Required, validation.Length(1, 255), is.EmailFormat),
|
||||
)
|
||||
}
|
316
apis/record_auth_otp_request_test.go
Normal file
316
apis/record_auth_otp_request_test.go
Normal file
|
@ -0,0 +1,316 @@
|
|||
package apis_test
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tests"
|
||||
"github.com/pocketbase/pocketbase/tools/types"
|
||||
)
|
||||
|
||||
func TestRecordRequestOTP(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "not an auth collection",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/demo1/request-otp",
|
||||
Body: strings.NewReader(`{"email":"test@example.com"}`),
|
||||
ExpectedStatus: 404,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "auth collection with disabled otp",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/request-otp",
|
||||
Body: strings.NewReader(`{"email":"test@example.com"}`),
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
usersCol, err := app.FindCollectionByNameOrId("users")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
usersCol.OTP.Enabled = false
|
||||
|
||||
if err := app.Save(usersCol); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "empty body",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/request-otp",
|
||||
Body: strings.NewReader(``),
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{`"data":{"email":{"code":"validation_required","message":"Cannot be blank."}}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "invalid body",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/request-otp",
|
||||
Body: strings.NewReader(`{"email`),
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "invalid request data",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/request-otp",
|
||||
Body: strings.NewReader(`{"email":"invalid"}`),
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{
|
||||
`"data":{`,
|
||||
`"email":{"code":"validation_is_email`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "missing auth record",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/request-otp",
|
||||
Body: strings.NewReader(`{"email":"missing@example.com"}`),
|
||||
Delay: 100 * time.Millisecond,
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
`"otpId":"`, // some fake random generated string
|
||||
},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordRequestOTPRequest": 1,
|
||||
},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
if app.TestMailer.TotalSend() != 0 {
|
||||
t.Fatalf("Expected zero emails, got %d", app.TestMailer.TotalSend())
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "existing auth record (with < 9 non-expired)",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/request-otp",
|
||||
Body: strings.NewReader(`{"email":"test@example.com"}`),
|
||||
Delay: 100 * time.Millisecond,
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
user, err := app.FindAuthRecordByEmail("users", "test@example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// insert 8 non-expired and 2 expired
|
||||
for i := 0; i < 10; i++ {
|
||||
otp := core.NewOTP(app)
|
||||
otp.Id = "otp_" + strconv.Itoa(i)
|
||||
otp.SetCollectionRef(user.Collection().Id)
|
||||
otp.SetRecordRef(user.Id)
|
||||
otp.SetPassword("123456")
|
||||
if i >= 8 {
|
||||
expiredDate := types.NowDateTime().AddDate(-3, 0, 0)
|
||||
otp.SetRaw("created", expiredDate)
|
||||
otp.SetRaw("updated", expiredDate)
|
||||
}
|
||||
if err := app.SaveNoValidate(otp); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
`"otpId":"`,
|
||||
},
|
||||
NotExpectedContent: []string{
|
||||
`"otpId":"otp_`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordRequestOTPRequest": 1,
|
||||
"OnMailerSend": 1,
|
||||
"OnMailerRecordOTPSend": 1,
|
||||
"OnModelCreate": 1,
|
||||
"OnModelCreateExecute": 1,
|
||||
"OnModelAfterCreateSuccess": 1,
|
||||
"OnModelValidate": 2, // + 1 for the OTP update after the email send
|
||||
"OnRecordCreate": 1,
|
||||
"OnRecordCreateExecute": 1,
|
||||
"OnRecordAfterCreateSuccess": 1,
|
||||
"OnRecordValidate": 2,
|
||||
// OTP update
|
||||
"OnModelUpdate": 1,
|
||||
"OnModelUpdateExecute": 1,
|
||||
"OnModelAfterUpdateSuccess": 1,
|
||||
"OnRecordUpdate": 1,
|
||||
"OnRecordUpdateExecute": 1,
|
||||
"OnRecordAfterUpdateSuccess": 1,
|
||||
},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
if app.TestMailer.TotalSend() != 1 {
|
||||
t.Fatalf("Expected 1 email, got %d", app.TestMailer.TotalSend())
|
||||
}
|
||||
|
||||
// ensure that sentTo is set
|
||||
otps, err := app.FindRecordsByFilter(core.CollectionNameOTPs, "sentTo='test@example.com'", "", 0, 0)
|
||||
if err != nil || len(otps) != 1 {
|
||||
t.Fatalf("Expected to find 1 OTP with sentTo %q, found %d", "test@example.com", len(otps))
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "existing auth record with intercepted email (with < 9 non-expired)",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/request-otp",
|
||||
Body: strings.NewReader(`{"email":"test@example.com"}`),
|
||||
Delay: 100 * time.Millisecond,
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
// prevent email sent
|
||||
app.OnMailerRecordOTPSend("users").BindFunc(func(e *core.MailerRecordEvent) error {
|
||||
return nil
|
||||
})
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
`"otpId":"`,
|
||||
},
|
||||
NotExpectedContent: []string{
|
||||
`"otpId":"otp_`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordRequestOTPRequest": 1,
|
||||
"OnMailerRecordOTPSend": 1,
|
||||
"OnModelCreate": 1,
|
||||
"OnModelCreateExecute": 1,
|
||||
"OnModelAfterCreateSuccess": 1,
|
||||
"OnModelValidate": 1,
|
||||
"OnRecordCreate": 1,
|
||||
"OnRecordCreateExecute": 1,
|
||||
"OnRecordAfterCreateSuccess": 1,
|
||||
"OnRecordValidate": 1,
|
||||
},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
if app.TestMailer.TotalSend() != 0 {
|
||||
t.Fatalf("Expected 0 emails, got %d", app.TestMailer.TotalSend())
|
||||
}
|
||||
|
||||
// ensure that there is no OTP with user email as sentTo
|
||||
otps, err := app.FindRecordsByFilter(core.CollectionNameOTPs, "sentTo='test@example.com'", "", 0, 0)
|
||||
if err != nil || len(otps) != 0 {
|
||||
t.Fatalf("Expected to find 0 OTPs with sentTo %q, found %d", "test@example.com", len(otps))
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "existing auth record (with > 9 non-expired)",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/request-otp",
|
||||
Body: strings.NewReader(`{"email":"test@example.com"}`),
|
||||
Delay: 100 * time.Millisecond,
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
user, err := app.FindAuthRecordByEmail("users", "test@example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// insert 10 non-expired
|
||||
for i := 0; i < 10; i++ {
|
||||
otp := core.NewOTP(app)
|
||||
otp.Id = "otp_" + strconv.Itoa(i)
|
||||
otp.SetCollectionRef(user.Collection().Id)
|
||||
otp.SetRecordRef(user.Id)
|
||||
otp.SetPassword("123456")
|
||||
if err := app.SaveNoValidate(otp); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
`"otpId":"otp_9"`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordRequestOTPRequest": 1,
|
||||
},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
if app.TestMailer.TotalSend() != 0 {
|
||||
t.Fatalf("Expected 0 sent emails, got %d", app.TestMailer.TotalSend())
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "OnRecordRequestOTPRequest tx body write check",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/request-otp",
|
||||
Body: strings.NewReader(`{"email":"test@example.com"}`),
|
||||
Delay: 100 * time.Millisecond,
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
app.OnRecordRequestOTPRequest().BindFunc(func(e *core.RecordCreateOTPRequestEvent) error {
|
||||
original := e.App
|
||||
return e.App.RunInTransaction(func(txApp core.App) error {
|
||||
e.App = txApp
|
||||
defer func() { e.App = original }()
|
||||
|
||||
if err := e.Next(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return e.BadRequestError("TX_ERROR", nil)
|
||||
})
|
||||
})
|
||||
},
|
||||
ExpectedStatus: 400,
|
||||
ExpectedEvents: map[string]int{"OnRecordRequestOTPRequest": 1},
|
||||
ExpectedContent: []string{"TX_ERROR"},
|
||||
},
|
||||
|
||||
// rate limit checks
|
||||
// -----------------------------------------------------------
|
||||
{
|
||||
Name: "RateLimit rule - users:requestOTP",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/request-otp",
|
||||
Body: strings.NewReader(`{"email":"test@example.com"}`),
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
app.Settings().RateLimits.Enabled = true
|
||||
app.Settings().RateLimits.Rules = []core.RateLimitRule{
|
||||
{MaxRequests: 100, Label: "abc"},
|
||||
{MaxRequests: 100, Label: "*:requestOTP"},
|
||||
{MaxRequests: 0, Label: "users:requestOTP"},
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 429,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "RateLimit rule - *:requestOTP",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/request-otp",
|
||||
Body: strings.NewReader(`{"email":"test@example.com"}`),
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
app.Settings().RateLimits.Enabled = true
|
||||
app.Settings().RateLimits.Rules = []core.RateLimitRule{
|
||||
{MaxRequests: 100, Label: "abc"},
|
||||
{MaxRequests: 0, Label: "*:requestOTP"},
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 429,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
104
apis/record_auth_password_reset_confirm.go
Normal file
104
apis/record_auth_password_reset_confirm.go
Normal file
|
@ -0,0 +1,104 @@
|
|||
package apis
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
validation "github.com/go-ozzo/ozzo-validation/v4"
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/core/validators"
|
||||
"github.com/pocketbase/pocketbase/tools/security"
|
||||
"github.com/spf13/cast"
|
||||
)
|
||||
|
||||
func recordConfirmPasswordReset(e *core.RequestEvent) error {
|
||||
collection, err := findAuthCollection(e)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
form := new(recordConfirmPasswordResetForm)
|
||||
form.app = e.App
|
||||
form.collection = collection
|
||||
if err = e.BindBody(form); err != nil {
|
||||
return e.BadRequestError("An error occurred while loading the submitted data.", err)
|
||||
}
|
||||
if err = form.validate(); err != nil {
|
||||
return firstApiError(err, e.BadRequestError("An error occurred while validating the submitted data.", err))
|
||||
}
|
||||
|
||||
authRecord, err := e.App.FindAuthRecordByToken(form.Token, core.TokenTypePasswordReset)
|
||||
if err != nil {
|
||||
return firstApiError(err, e.BadRequestError("Invalid or expired password reset token.", err))
|
||||
}
|
||||
|
||||
event := new(core.RecordConfirmPasswordResetRequestEvent)
|
||||
event.RequestEvent = e
|
||||
event.Collection = collection
|
||||
event.Record = authRecord
|
||||
|
||||
return e.App.OnRecordConfirmPasswordResetRequest().Trigger(event, func(e *core.RecordConfirmPasswordResetRequestEvent) error {
|
||||
authRecord.SetPassword(form.Password)
|
||||
|
||||
if !authRecord.Verified() {
|
||||
payload, err := security.ParseUnverifiedJWT(form.Token)
|
||||
if err == nil && authRecord.Email() == cast.ToString(payload[core.TokenClaimEmail]) {
|
||||
// mark as verified if the email hasn't changed
|
||||
authRecord.SetVerified(true)
|
||||
}
|
||||
}
|
||||
|
||||
err = e.App.Save(authRecord)
|
||||
if err != nil {
|
||||
return firstApiError(err, e.BadRequestError("Failed to set new password.", err))
|
||||
}
|
||||
|
||||
e.App.Store().Remove(getPasswordResetResendKey(authRecord))
|
||||
|
||||
return execAfterSuccessTx(true, e.App, func() error {
|
||||
return e.NoContent(http.StatusNoContent)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
|
||||
type recordConfirmPasswordResetForm struct {
|
||||
app core.App
|
||||
collection *core.Collection
|
||||
|
||||
Token string `form:"token" json:"token"`
|
||||
Password string `form:"password" json:"password"`
|
||||
PasswordConfirm string `form:"passwordConfirm" json:"passwordConfirm"`
|
||||
}
|
||||
|
||||
func (form *recordConfirmPasswordResetForm) validate() error {
|
||||
min := 1
|
||||
passField, ok := form.collection.Fields.GetByName(core.FieldNamePassword).(*core.PasswordField)
|
||||
if ok && passField != nil && passField.Min > 0 {
|
||||
min = passField.Min
|
||||
}
|
||||
|
||||
return validation.ValidateStruct(form,
|
||||
validation.Field(&form.Token, validation.Required, validation.By(form.checkToken)),
|
||||
validation.Field(&form.Password, validation.Required, validation.Length(min, 255)), // the FieldPassword validator will check further the specicic length constraints
|
||||
validation.Field(&form.PasswordConfirm, validation.Required, validation.By(validators.Equal(form.Password))),
|
||||
)
|
||||
}
|
||||
|
||||
func (form *recordConfirmPasswordResetForm) checkToken(value any) error {
|
||||
v, _ := value.(string)
|
||||
if v == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
record, err := form.app.FindAuthRecordByToken(v, core.TokenTypePasswordReset)
|
||||
if err != nil || record == nil {
|
||||
return validation.NewError("validation_invalid_token", "Invalid or expired token.")
|
||||
}
|
||||
|
||||
if record.Collection().Id != form.collection.Id {
|
||||
return validation.NewError("validation_token_collection_mismatch", "The provided token is for different auth collection.")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
360
apis/record_auth_password_reset_confirm_test.go
Normal file
360
apis/record_auth_password_reset_confirm_test.go
Normal file
|
@ -0,0 +1,360 @@
|
|||
package apis_test
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tests"
|
||||
)
|
||||
|
||||
func TestRecordConfirmPasswordReset(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "empty data",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/confirm-password-reset",
|
||||
Body: strings.NewReader(``),
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{
|
||||
`"data":{`,
|
||||
`"password":{"code":"validation_required"`,
|
||||
`"passwordConfirm":{"code":"validation_required"`,
|
||||
`"token":{"code":"validation_required"`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "invalid data format",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/confirm-password-reset",
|
||||
Body: strings.NewReader(`{"password`),
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "expired token and invalid password",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/confirm-password-reset",
|
||||
Body: strings.NewReader(`{
|
||||
"token":"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImV4cCI6MTY0MDk5MTY2MSwidHlwZSI6InBhc3N3b3JkUmVzZXQiLCJjb2xsZWN0aW9uSWQiOiJfcGJfdXNlcnNfYXV0aF8iLCJlbWFpbCI6InRlc3RAZXhhbXBsZS5jb20ifQ.5Tm6_6amQqOlX3urAnXlEdmxwG5qQJfiTg6U0hHR1hk",
|
||||
"password":"1234567",
|
||||
"passwordConfirm":"7654321"
|
||||
}`),
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{
|
||||
`"data":{`,
|
||||
`"token":{"code":"validation_invalid_token"`,
|
||||
`"password":{"code":"validation_length_out_of_range"`,
|
||||
`"passwordConfirm":{"code":"validation_values_mismatch"`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "non-password reset token",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/confirm-password-reset",
|
||||
Body: strings.NewReader(`{
|
||||
"token":"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImV4cCI6MjUyNDYwNDQ2MSwidHlwZSI6InZlcmlmaWNhdGlvbiIsImNvbGxlY3Rpb25JZCI6Il9wYl91c2Vyc19hdXRoXyIsImVtYWlsIjoidGVzdEBleGFtcGxlLmNvbSJ9.SetHpu2H-x-q4TIUz-xiQjwi7MNwLCLvSs4O0hUSp0E",
|
||||
"password":"1234567!",
|
||||
"passwordConfirm":"1234567!"
|
||||
}`),
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{
|
||||
`"data":{`,
|
||||
`"token":{"code":"validation_invalid_token"`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "non auth collection",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/demo1/confirm-password-reset?expand=rel,missing",
|
||||
Body: strings.NewReader(`{
|
||||
"token":"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImV4cCI6MjUyNDYwNDQ2MSwidHlwZSI6InBhc3N3b3JkUmVzZXQiLCJjb2xsZWN0aW9uSWQiOiJfcGJfdXNlcnNfYXV0aF8iLCJlbWFpbCI6InRlc3RAZXhhbXBsZS5jb20ifQ.xR-xq1oHDy0D8Q4NDOAEyYKGHWd_swzoiSoL8FLFBHY",
|
||||
"password":"1234567!",
|
||||
"passwordConfirm":"1234567!"
|
||||
}`),
|
||||
ExpectedStatus: 404,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "different auth collection",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/clients/confirm-password-reset?expand=rel,missing",
|
||||
Body: strings.NewReader(`{
|
||||
"token":"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImV4cCI6MjUyNDYwNDQ2MSwidHlwZSI6InBhc3N3b3JkUmVzZXQiLCJjb2xsZWN0aW9uSWQiOiJfcGJfdXNlcnNfYXV0aF8iLCJlbWFpbCI6InRlc3RAZXhhbXBsZS5jb20ifQ.xR-xq1oHDy0D8Q4NDOAEyYKGHWd_swzoiSoL8FLFBHY",
|
||||
"password":"1234567!",
|
||||
"passwordConfirm":"1234567!"
|
||||
}`),
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{
|
||||
`"data":{"token":{"code":"validation_token_collection_mismatch"`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "valid token and data (unverified user)",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/confirm-password-reset",
|
||||
Body: strings.NewReader(`{
|
||||
"token":"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImV4cCI6MjUyNDYwNDQ2MSwidHlwZSI6InBhc3N3b3JkUmVzZXQiLCJjb2xsZWN0aW9uSWQiOiJfcGJfdXNlcnNfYXV0aF8iLCJlbWFpbCI6InRlc3RAZXhhbXBsZS5jb20ifQ.xR-xq1oHDy0D8Q4NDOAEyYKGHWd_swzoiSoL8FLFBHY",
|
||||
"password":"1234567!",
|
||||
"passwordConfirm":"1234567!"
|
||||
}`),
|
||||
ExpectedStatus: 204,
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordConfirmPasswordResetRequest": 1,
|
||||
"OnModelUpdate": 1,
|
||||
"OnModelUpdateExecute": 1,
|
||||
"OnModelAfterUpdateSuccess": 1,
|
||||
"OnModelValidate": 1,
|
||||
"OnRecordUpdate": 1,
|
||||
"OnRecordUpdateExecute": 1,
|
||||
"OnRecordAfterUpdateSuccess": 1,
|
||||
"OnRecordValidate": 1,
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
user, err := app.FindAuthRecordByEmail("users", "test@example.com")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to fetch confirm password user: %v", err)
|
||||
}
|
||||
|
||||
if user.Verified() {
|
||||
t.Fatal("Expected the user to be unverified")
|
||||
}
|
||||
},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
_, err := app.FindAuthRecordByToken(
|
||||
"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImV4cCI6MjUyNDYwNDQ2MSwidHlwZSI6InBhc3N3b3JkUmVzZXQiLCJjb2xsZWN0aW9uSWQiOiJfcGJfdXNlcnNfYXV0aF8iLCJlbWFpbCI6InRlc3RAZXhhbXBsZS5jb20ifQ.xR-xq1oHDy0D8Q4NDOAEyYKGHWd_swzoiSoL8FLFBHY",
|
||||
core.TokenTypePasswordReset,
|
||||
)
|
||||
if err == nil {
|
||||
t.Fatal("Expected the password reset token to be invalidated")
|
||||
}
|
||||
|
||||
user, err := app.FindAuthRecordByEmail("users", "test@example.com")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to fetch confirm password user: %v", err)
|
||||
}
|
||||
|
||||
if !user.Verified() {
|
||||
t.Fatal("Expected the user to be marked as verified")
|
||||
}
|
||||
|
||||
if !user.ValidatePassword("1234567!") {
|
||||
t.Fatal("Password wasn't changed")
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "valid token and data (unverified user with different email from the one in the token)",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/confirm-password-reset",
|
||||
Body: strings.NewReader(`{
|
||||
"token":"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImV4cCI6MjUyNDYwNDQ2MSwidHlwZSI6InBhc3N3b3JkUmVzZXQiLCJjb2xsZWN0aW9uSWQiOiJfcGJfdXNlcnNfYXV0aF8iLCJlbWFpbCI6InRlc3RAZXhhbXBsZS5jb20ifQ.xR-xq1oHDy0D8Q4NDOAEyYKGHWd_swzoiSoL8FLFBHY",
|
||||
"password":"1234567!",
|
||||
"passwordConfirm":"1234567!"
|
||||
}`),
|
||||
ExpectedStatus: 204,
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordConfirmPasswordResetRequest": 1,
|
||||
"OnModelUpdate": 1,
|
||||
"OnModelUpdateExecute": 1,
|
||||
"OnModelAfterUpdateSuccess": 1,
|
||||
"OnModelValidate": 1,
|
||||
"OnRecordUpdate": 1,
|
||||
"OnRecordUpdateExecute": 1,
|
||||
"OnRecordAfterUpdateSuccess": 1,
|
||||
"OnRecordValidate": 1,
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
user, err := app.FindAuthRecordByEmail("users", "test@example.com")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to fetch confirm password user: %v", err)
|
||||
}
|
||||
|
||||
if user.Verified() {
|
||||
t.Fatal("Expected the user to be unverified")
|
||||
}
|
||||
|
||||
oldTokenKey := user.TokenKey()
|
||||
|
||||
// manually change the email to check whether the verified state will be updated
|
||||
user.SetEmail("test_update@example.com")
|
||||
if err = app.Save(user); err != nil {
|
||||
t.Fatalf("Failed to update user test email: %v", err)
|
||||
}
|
||||
|
||||
// resave with the old token key since the email change above
|
||||
// would change it and will make the password token invalid
|
||||
user.SetTokenKey(oldTokenKey)
|
||||
if err = app.Save(user); err != nil {
|
||||
t.Fatalf("Failed to restore original user tokenKey: %v", err)
|
||||
}
|
||||
},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
_, err := app.FindAuthRecordByToken(
|
||||
"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImV4cCI6MjUyNDYwNDQ2MSwidHlwZSI6InBhc3N3b3JkUmVzZXQiLCJjb2xsZWN0aW9uSWQiOiJfcGJfdXNlcnNfYXV0aF8iLCJlbWFpbCI6InRlc3RAZXhhbXBsZS5jb20ifQ.xR-xq1oHDy0D8Q4NDOAEyYKGHWd_swzoiSoL8FLFBHY",
|
||||
core.TokenTypePasswordReset,
|
||||
)
|
||||
if err == nil {
|
||||
t.Fatalf("Expected the password reset token to be invalidated")
|
||||
}
|
||||
|
||||
user, err := app.FindAuthRecordByEmail("users", "test_update@example.com")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to fetch confirm password user: %v", err)
|
||||
}
|
||||
|
||||
if user.Verified() {
|
||||
t.Fatal("Expected the user to remain unverified")
|
||||
}
|
||||
|
||||
if !user.ValidatePassword("1234567!") {
|
||||
t.Fatal("Password wasn't changed")
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "valid token and data (verified user)",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/confirm-password-reset",
|
||||
Body: strings.NewReader(`{
|
||||
"token":"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImV4cCI6MjUyNDYwNDQ2MSwidHlwZSI6InBhc3N3b3JkUmVzZXQiLCJjb2xsZWN0aW9uSWQiOiJfcGJfdXNlcnNfYXV0aF8iLCJlbWFpbCI6InRlc3RAZXhhbXBsZS5jb20ifQ.xR-xq1oHDy0D8Q4NDOAEyYKGHWd_swzoiSoL8FLFBHY",
|
||||
"password":"1234567!",
|
||||
"passwordConfirm":"1234567!"
|
||||
}`),
|
||||
ExpectedStatus: 204,
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordConfirmPasswordResetRequest": 1,
|
||||
"OnModelUpdate": 1,
|
||||
"OnModelUpdateExecute": 1,
|
||||
"OnModelAfterUpdateSuccess": 1,
|
||||
"OnModelValidate": 1,
|
||||
"OnRecordUpdate": 1,
|
||||
"OnRecordUpdateExecute": 1,
|
||||
"OnRecordAfterUpdateSuccess": 1,
|
||||
"OnRecordValidate": 1,
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
user, err := app.FindAuthRecordByEmail("users", "test@example.com")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to fetch confirm password user: %v", err)
|
||||
}
|
||||
|
||||
// ensure that the user is already verified
|
||||
user.SetVerified(true)
|
||||
if err := app.Save(user); err != nil {
|
||||
t.Fatalf("Failed to update user verified state")
|
||||
}
|
||||
},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
_, err := app.FindAuthRecordByToken(
|
||||
"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImV4cCI6MjUyNDYwNDQ2MSwidHlwZSI6InBhc3N3b3JkUmVzZXQiLCJjb2xsZWN0aW9uSWQiOiJfcGJfdXNlcnNfYXV0aF8iLCJlbWFpbCI6InRlc3RAZXhhbXBsZS5jb20ifQ.xR-xq1oHDy0D8Q4NDOAEyYKGHWd_swzoiSoL8FLFBHY",
|
||||
core.TokenTypePasswordReset,
|
||||
)
|
||||
if err == nil {
|
||||
t.Fatal("Expected the password reset token to be invalidated")
|
||||
}
|
||||
|
||||
user, err := app.FindAuthRecordByEmail("users", "test@example.com")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to fetch confirm password user: %v", err)
|
||||
}
|
||||
|
||||
if !user.Verified() {
|
||||
t.Fatal("Expected the user to remain verified")
|
||||
}
|
||||
|
||||
if !user.ValidatePassword("1234567!") {
|
||||
t.Fatal("Password wasn't changed")
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "OnRecordConfirmPasswordResetRequest tx body write check",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/confirm-password-reset",
|
||||
Body: strings.NewReader(`{
|
||||
"token":"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImV4cCI6MjUyNDYwNDQ2MSwidHlwZSI6InBhc3N3b3JkUmVzZXQiLCJjb2xsZWN0aW9uSWQiOiJfcGJfdXNlcnNfYXV0aF8iLCJlbWFpbCI6InRlc3RAZXhhbXBsZS5jb20ifQ.xR-xq1oHDy0D8Q4NDOAEyYKGHWd_swzoiSoL8FLFBHY",
|
||||
"password":"1234567!",
|
||||
"passwordConfirm":"1234567!"
|
||||
}`),
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
app.OnRecordConfirmPasswordResetRequest().BindFunc(func(e *core.RecordConfirmPasswordResetRequestEvent) error {
|
||||
original := e.App
|
||||
return e.App.RunInTransaction(func(txApp core.App) error {
|
||||
e.App = txApp
|
||||
defer func() { e.App = original }()
|
||||
|
||||
if err := e.Next(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return e.BadRequestError("TX_ERROR", nil)
|
||||
})
|
||||
})
|
||||
},
|
||||
ExpectedStatus: 400,
|
||||
ExpectedEvents: map[string]int{"OnRecordConfirmPasswordResetRequest": 1},
|
||||
ExpectedContent: []string{"TX_ERROR"},
|
||||
},
|
||||
|
||||
// rate limit checks
|
||||
// -----------------------------------------------------------
|
||||
{
|
||||
Name: "RateLimit rule - users:confirmPasswordReset",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/confirm-password-reset",
|
||||
Body: strings.NewReader(`{
|
||||
"token":"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImV4cCI6MjUyNDYwNDQ2MSwidHlwZSI6InBhc3N3b3JkUmVzZXQiLCJjb2xsZWN0aW9uSWQiOiJfcGJfdXNlcnNfYXV0aF8iLCJlbWFpbCI6InRlc3RAZXhhbXBsZS5jb20ifQ.xR-xq1oHDy0D8Q4NDOAEyYKGHWd_swzoiSoL8FLFBHY",
|
||||
"password":"1234567!",
|
||||
"passwordConfirm":"1234567!"
|
||||
}`),
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
app.Settings().RateLimits.Enabled = true
|
||||
app.Settings().RateLimits.Rules = []core.RateLimitRule{
|
||||
{MaxRequests: 100, Label: "abc"},
|
||||
{MaxRequests: 100, Label: "*:confirmPasswordReset"},
|
||||
{MaxRequests: 0, Label: "users:confirmPasswordReset"},
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 429,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "RateLimit rule - *:confirmPasswordReset",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/confirm-password-reset",
|
||||
Body: strings.NewReader(`{
|
||||
"token":"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImV4cCI6MjUyNDYwNDQ2MSwidHlwZSI6InBhc3N3b3JkUmVzZXQiLCJjb2xsZWN0aW9uSWQiOiJfcGJfdXNlcnNfYXV0aF8iLCJlbWFpbCI6InRlc3RAZXhhbXBsZS5jb20ifQ.xR-xq1oHDy0D8Q4NDOAEyYKGHWd_swzoiSoL8FLFBHY",
|
||||
"password":"1234567!",
|
||||
"passwordConfirm":"1234567!"
|
||||
}`),
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
app.Settings().RateLimits.Enabled = true
|
||||
app.Settings().RateLimits.Rules = []core.RateLimitRule{
|
||||
{MaxRequests: 100, Label: "abc"},
|
||||
{MaxRequests: 0, Label: "*:confirmPasswordReset"},
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 429,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
88
apis/record_auth_password_reset_request.go
Normal file
88
apis/record_auth_password_reset_request.go
Normal file
|
@ -0,0 +1,88 @@
|
|||
package apis
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
validation "github.com/go-ozzo/ozzo-validation/v4"
|
||||
"github.com/go-ozzo/ozzo-validation/v4/is"
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/mails"
|
||||
"github.com/pocketbase/pocketbase/tools/routine"
|
||||
)
|
||||
|
||||
func recordRequestPasswordReset(e *core.RequestEvent) error {
|
||||
collection, err := findAuthCollection(e)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !collection.PasswordAuth.Enabled {
|
||||
return e.BadRequestError("The collection is not configured to allow password authentication.", nil)
|
||||
}
|
||||
|
||||
form := new(recordRequestPasswordResetForm)
|
||||
if err = e.BindBody(form); err != nil {
|
||||
return firstApiError(err, e.BadRequestError("An error occurred while loading the submitted data.", err))
|
||||
}
|
||||
if err = form.validate(); err != nil {
|
||||
return firstApiError(err, e.BadRequestError("An error occurred while validating the submitted data.", err))
|
||||
}
|
||||
|
||||
record, err := e.App.FindAuthRecordByEmail(collection, form.Email)
|
||||
if err != nil {
|
||||
// eagerly write 204 response as a very basic measure against emails enumeration
|
||||
e.NoContent(http.StatusNoContent)
|
||||
return fmt.Errorf("failed to fetch %s record with email %s: %w", collection.Name, form.Email, err)
|
||||
}
|
||||
|
||||
resendKey := getPasswordResetResendKey(record)
|
||||
if e.App.Store().Has(resendKey) {
|
||||
// eagerly write 204 response as a very basic measure against emails enumeration
|
||||
e.NoContent(http.StatusNoContent)
|
||||
return errors.New("try again later - you've already requested a password reset email")
|
||||
}
|
||||
|
||||
event := new(core.RecordRequestPasswordResetRequestEvent)
|
||||
event.RequestEvent = e
|
||||
event.Collection = collection
|
||||
event.Record = record
|
||||
|
||||
return e.App.OnRecordRequestPasswordResetRequest().Trigger(event, func(e *core.RecordRequestPasswordResetRequestEvent) error {
|
||||
// run in background because we don't need to show the result to the client
|
||||
app := e.App
|
||||
routine.FireAndForget(func() {
|
||||
if err := mails.SendRecordPasswordReset(app, e.Record); err != nil {
|
||||
app.Logger().Error("Failed to send password reset email", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
app.Store().Set(resendKey, struct{}{})
|
||||
time.AfterFunc(2*time.Minute, func() {
|
||||
app.Store().Remove(resendKey)
|
||||
})
|
||||
})
|
||||
|
||||
return execAfterSuccessTx(true, e.App, func() error {
|
||||
return e.NoContent(http.StatusNoContent)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
|
||||
type recordRequestPasswordResetForm struct {
|
||||
Email string `form:"email" json:"email"`
|
||||
}
|
||||
|
||||
func (form *recordRequestPasswordResetForm) validate() error {
|
||||
return validation.ValidateStruct(form,
|
||||
validation.Field(&form.Email, validation.Required, validation.Length(1, 255), is.EmailFormat),
|
||||
)
|
||||
}
|
||||
|
||||
func getPasswordResetResendKey(record *core.Record) string {
|
||||
return "@limitPasswordResetEmail_" + record.Collection().Id + record.Id
|
||||
}
|
169
apis/record_auth_password_reset_request_test.go
Normal file
169
apis/record_auth_password_reset_request_test.go
Normal file
|
@ -0,0 +1,169 @@
|
|||
package apis_test
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tests"
|
||||
)
|
||||
|
||||
func TestRecordRequestPasswordReset(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "not an auth collection",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/demo1/request-password-reset",
|
||||
Body: strings.NewReader(``),
|
||||
ExpectedStatus: 404,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "empty data",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/request-password-reset",
|
||||
Body: strings.NewReader(``),
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{`"data":{"email":{"code":"validation_required","message":"Cannot be blank."}}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "invalid data",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/request-password-reset",
|
||||
Body: strings.NewReader(`{"email`),
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "existing auth record in a collection with disabled password login",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/nologin/request-password-reset",
|
||||
Body: strings.NewReader(`{"email":"test@example.com"}`),
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "missing auth record",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/request-password-reset",
|
||||
Body: strings.NewReader(`{"email":"missing@example.com"}`),
|
||||
Delay: 100 * time.Millisecond,
|
||||
ExpectedStatus: 204,
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
if app.TestMailer.TotalSend() != 0 {
|
||||
t.Fatalf("Expected zero emails, got %d", app.TestMailer.TotalSend())
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "existing auth record",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/request-password-reset",
|
||||
Body: strings.NewReader(`{"email":"test@example.com"}`),
|
||||
Delay: 100 * time.Millisecond,
|
||||
ExpectedStatus: 204,
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordRequestPasswordResetRequest": 1,
|
||||
"OnMailerSend": 1,
|
||||
"OnMailerRecordPasswordResetSend": 1,
|
||||
},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
if !strings.Contains(app.TestMailer.LastMessage().HTML, "/auth/confirm-password-reset") {
|
||||
t.Fatalf("Expected password reset email, got\n%v", app.TestMailer.LastMessage().HTML)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "existing auth record (after already sent)",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/request-password-reset",
|
||||
Body: strings.NewReader(`{"email":"test@example.com"}`),
|
||||
Delay: 100 * time.Millisecond,
|
||||
ExpectedStatus: 204,
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
// simulate recent verification sent
|
||||
authRecord, err := app.FindAuthRecordByEmail("users", "test@example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
resendKey := "@limitPasswordResetEmail_" + authRecord.Collection().Id + authRecord.Id
|
||||
app.Store().Set(resendKey, struct{}{})
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "OnRecordRequestPasswordResetRequest tx body write check",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/request-password-reset",
|
||||
Body: strings.NewReader(`{"email":"test@example.com"}`),
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
app.OnRecordRequestPasswordResetRequest().BindFunc(func(e *core.RecordRequestPasswordResetRequestEvent) error {
|
||||
original := e.App
|
||||
return e.App.RunInTransaction(func(txApp core.App) error {
|
||||
e.App = txApp
|
||||
defer func() { e.App = original }()
|
||||
|
||||
if err := e.Next(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return e.BadRequestError("TX_ERROR", nil)
|
||||
})
|
||||
})
|
||||
},
|
||||
ExpectedStatus: 400,
|
||||
ExpectedEvents: map[string]int{"OnRecordRequestPasswordResetRequest": 1},
|
||||
ExpectedContent: []string{"TX_ERROR"},
|
||||
},
|
||||
|
||||
// rate limit checks
|
||||
// -----------------------------------------------------------
|
||||
{
|
||||
Name: "RateLimit rule - users:requestPasswordReset",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/request-password-reset",
|
||||
Body: strings.NewReader(`{"email":"missing@example.com"}`),
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
app.Settings().RateLimits.Enabled = true
|
||||
app.Settings().RateLimits.Rules = []core.RateLimitRule{
|
||||
{MaxRequests: 100, Label: "abc"},
|
||||
{MaxRequests: 100, Label: "*:requestPasswordReset"},
|
||||
{MaxRequests: 0, Label: "users:requestPasswordReset"},
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 429,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "RateLimit rule - *:requestPasswordReset",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/request-password-reset",
|
||||
Body: strings.NewReader(`{"email":"missing@example.com"}`),
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
app.Settings().RateLimits.Enabled = true
|
||||
app.Settings().RateLimits.Rules = []core.RateLimitRule{
|
||||
{MaxRequests: 100, Label: "abc"},
|
||||
{MaxRequests: 0, Label: "*:requestPasswordReset"},
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 429,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
29
apis/record_auth_refresh.go
Normal file
29
apis/record_auth_refresh.go
Normal file
|
@ -0,0 +1,29 @@
|
|||
package apis
|
||||
|
||||
import (
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tools/security"
|
||||
"github.com/spf13/cast"
|
||||
)
|
||||
|
||||
func recordAuthRefresh(e *core.RequestEvent) error {
|
||||
record := e.Auth
|
||||
if record == nil {
|
||||
return e.NotFoundError("Missing auth record context.", nil)
|
||||
}
|
||||
|
||||
currentToken := getAuthTokenFromRequest(e)
|
||||
claims, _ := security.ParseUnverifiedJWT(currentToken)
|
||||
if v, ok := claims[core.TokenClaimRefreshable]; !ok || !cast.ToBool(v) {
|
||||
return e.ForbiddenError("The current auth token is not refreshable.", nil)
|
||||
}
|
||||
|
||||
event := new(core.RecordAuthRefreshRequestEvent)
|
||||
event.RequestEvent = e
|
||||
event.Collection = record.Collection()
|
||||
event.Record = record
|
||||
|
||||
return e.App.OnRecordAuthRefreshRequest().Trigger(event, func(e *core.RecordAuthRefreshRequestEvent) error {
|
||||
return RecordAuthResponse(e.RequestEvent, e.Record, "", nil)
|
||||
})
|
||||
}
|
202
apis/record_auth_refresh_test.go
Normal file
202
apis/record_auth_refresh_test.go
Normal file
|
@ -0,0 +1,202 @@
|
|||
package apis_test
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tests"
|
||||
)
|
||||
|
||||
func TestRecordAuthRefresh(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "unauthorized",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/auth-refresh",
|
||||
ExpectedStatus: 401,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "superuser trying to refresh the auth of another auth collection",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/auth-refresh",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
|
||||
},
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "auth record + not an auth collection",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/demo1/auth-refresh",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "auth record + different auth collection",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/clients/auth-refresh?expand=rel,missing",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "auth record + same auth collection as the token",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/auth-refresh?expand=rel,missing",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
`"token":`,
|
||||
`"record":`,
|
||||
`"id":"4q1xlclmfloku33"`,
|
||||
`"emailVisibility":false`,
|
||||
`"email":"test@example.com"`, // the owner can always view their email address
|
||||
`"expand":`,
|
||||
`"rel":`,
|
||||
`"id":"llvuca81nly1qls"`,
|
||||
},
|
||||
NotExpectedContent: []string{
|
||||
`"missing":`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordAuthRefreshRequest": 1,
|
||||
"OnRecordAuthRequest": 1,
|
||||
"OnRecordEnrich": 2,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "auth record + same auth collection as the token but static/unrefreshable",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/auth-refresh",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6ZmFsc2V9.4IsO6YMsR19crhwl_YWzvRH8pfq2Ri4Gv2dzGyneLak",
|
||||
},
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "unverified auth record in onlyVerified collection",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/clients/auth-refresh",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6Im8xeTBkZDBzcGQ3ODZtZCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoidjg1MXE0cjc5MHJoa25sIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.Zi0yXE-CNmnbTdVaQEzYZVuECqRdn3LgEM6pmB3XWBE",
|
||||
},
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordAuthRefreshRequest": 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "verified auth record in onlyVerified collection",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/clients/auth-refresh",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6ImdrMzkwcWVnczR5NDd3biIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoidjg1MXE0cjc5MHJoa25sIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.0ONnm_BsvPRZyDNT31GN1CKUB6uQRxvVvQ-Wc9AZfG0",
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
`"token":`,
|
||||
`"record":`,
|
||||
`"id":"gk390qegs4y47wn"`,
|
||||
`"verified":true`,
|
||||
`"email":"test@example.com"`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordAuthRefreshRequest": 1,
|
||||
"OnRecordAuthRequest": 1,
|
||||
"OnRecordEnrich": 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "OnRecordAuthRefreshRequest tx body write check",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/auth-refresh",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
app.OnRecordAuthRefreshRequest().BindFunc(func(e *core.RecordAuthRefreshRequestEvent) error {
|
||||
original := e.App
|
||||
return e.App.RunInTransaction(func(txApp core.App) error {
|
||||
e.App = txApp
|
||||
defer func() { e.App = original }()
|
||||
|
||||
if err := e.Next(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return e.BadRequestError("TX_ERROR", nil)
|
||||
})
|
||||
})
|
||||
},
|
||||
ExpectedStatus: 400,
|
||||
ExpectedEvents: map[string]int{"OnRecordAuthRefreshRequest": 1},
|
||||
ExpectedContent: []string{"TX_ERROR"},
|
||||
},
|
||||
|
||||
// rate limit checks
|
||||
// -----------------------------------------------------------
|
||||
{
|
||||
Name: "RateLimit rule - users:authRefresh",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/auth-refresh",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
app.Settings().RateLimits.Enabled = true
|
||||
app.Settings().RateLimits.Rules = []core.RateLimitRule{
|
||||
{MaxRequests: 100, Label: "abc"},
|
||||
{MaxRequests: 100, Label: "*:authRefresh"},
|
||||
{MaxRequests: 0, Label: "users:authRefresh"},
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 429,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "RateLimit rule - *:authRefresh",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/auth-refresh",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
app.Settings().RateLimits.Enabled = true
|
||||
app.Settings().RateLimits.Rules = []core.RateLimitRule{
|
||||
{MaxRequests: 100, Label: "abc"},
|
||||
{MaxRequests: 0, Label: "*:authRefresh"},
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 429,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
102
apis/record_auth_verification_confirm.go
Normal file
102
apis/record_auth_verification_confirm.go
Normal file
|
@ -0,0 +1,102 @@
|
|||
package apis
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
validation "github.com/go-ozzo/ozzo-validation/v4"
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tools/security"
|
||||
"github.com/spf13/cast"
|
||||
)
|
||||
|
||||
func recordConfirmVerification(e *core.RequestEvent) error {
|
||||
collection, err := findAuthCollection(e)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if collection.Name == core.CollectionNameSuperusers {
|
||||
return e.BadRequestError("All superusers are verified by default.", nil)
|
||||
}
|
||||
|
||||
form := new(recordConfirmVerificationForm)
|
||||
form.app = e.App
|
||||
form.collection = collection
|
||||
if err = e.BindBody(form); err != nil {
|
||||
return firstApiError(err, e.BadRequestError("An error occurred while loading the submitted data.", err))
|
||||
}
|
||||
if err = form.validate(); err != nil {
|
||||
return firstApiError(err, e.BadRequestError("An error occurred while validating the submitted data.", err))
|
||||
}
|
||||
|
||||
record, err := form.app.FindAuthRecordByToken(form.Token, core.TokenTypeVerification)
|
||||
if err != nil {
|
||||
return e.BadRequestError("Invalid or expired verification token.", err)
|
||||
}
|
||||
|
||||
wasVerified := record.Verified()
|
||||
|
||||
event := new(core.RecordConfirmVerificationRequestEvent)
|
||||
event.RequestEvent = e
|
||||
event.Collection = collection
|
||||
event.Record = record
|
||||
|
||||
return e.App.OnRecordConfirmVerificationRequest().Trigger(event, func(e *core.RecordConfirmVerificationRequestEvent) error {
|
||||
if !wasVerified {
|
||||
e.Record.SetVerified(true)
|
||||
|
||||
if err := e.App.Save(e.Record); err != nil {
|
||||
return firstApiError(err, e.BadRequestError("An error occurred while saving the verified state.", err))
|
||||
}
|
||||
}
|
||||
|
||||
e.App.Store().Remove(getVerificationResendKey(e.Record))
|
||||
|
||||
return execAfterSuccessTx(true, e.App, func() error {
|
||||
return e.NoContent(http.StatusNoContent)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
|
||||
type recordConfirmVerificationForm struct {
|
||||
app core.App
|
||||
collection *core.Collection
|
||||
|
||||
Token string `form:"token" json:"token"`
|
||||
}
|
||||
|
||||
func (form *recordConfirmVerificationForm) validate() error {
|
||||
return validation.ValidateStruct(form,
|
||||
validation.Field(&form.Token, validation.Required, validation.By(form.checkToken)),
|
||||
)
|
||||
}
|
||||
|
||||
func (form *recordConfirmVerificationForm) checkToken(value any) error {
|
||||
v, _ := value.(string)
|
||||
if v == "" {
|
||||
return nil // nothing to check
|
||||
}
|
||||
|
||||
claims, _ := security.ParseUnverifiedJWT(v)
|
||||
email := cast.ToString(claims["email"])
|
||||
if email == "" {
|
||||
return validation.NewError("validation_invalid_token_claims", "Missing email token claim.")
|
||||
}
|
||||
|
||||
record, err := form.app.FindAuthRecordByToken(v, core.TokenTypeVerification)
|
||||
if err != nil || record == nil {
|
||||
return validation.NewError("validation_invalid_token", "Invalid or expired token.")
|
||||
}
|
||||
|
||||
if record.Collection().Id != form.collection.Id {
|
||||
return validation.NewError("validation_token_collection_mismatch", "The provided token is for different auth collection.")
|
||||
}
|
||||
|
||||
if record.Email() != email {
|
||||
return validation.NewError("validation_token_email_mismatch", "The record email doesn't match with the requested token claims.")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
216
apis/record_auth_verification_confirm_test.go
Normal file
216
apis/record_auth_verification_confirm_test.go
Normal file
|
@ -0,0 +1,216 @@
|
|||
package apis_test
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tests"
|
||||
)
|
||||
|
||||
func TestRecordConfirmVerification(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "empty data",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/confirm-verification",
|
||||
Body: strings.NewReader(``),
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{
|
||||
`"data":{`,
|
||||
`"token":{"code":"validation_required"`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "invalid data format",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/confirm-verification",
|
||||
Body: strings.NewReader(`{"password`),
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "expired token",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/confirm-verification",
|
||||
Body: strings.NewReader(`{
|
||||
"token":"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImV4cCI6MTY0MDk5MTY2MSwidHlwZSI6InZlcmlmaWNhdGlvbiIsImNvbGxlY3Rpb25JZCI6Il9wYl91c2Vyc19hdXRoXyIsImVtYWlsIjoidGVzdEBleGFtcGxlLmNvbSJ9.qqelNNL2Udl6K_TJ282sNHYCpASgA6SIuSVKGfBHMZU"
|
||||
}`),
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{
|
||||
`"data":{`,
|
||||
`"token":{"code":"validation_invalid_token"`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "non-verification token",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/confirm-verification",
|
||||
Body: strings.NewReader(`{
|
||||
"token":"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImV4cCI6MjUyNDYwNDQ2MSwidHlwZSI6InBhc3N3b3JkUmVzZXQiLCJjb2xsZWN0aW9uSWQiOiJfcGJfdXNlcnNfYXV0aF8iLCJlbWFpbCI6InRlc3RAZXhhbXBsZS5jb20ifQ.xR-xq1oHDy0D8Q4NDOAEyYKGHWd_swzoiSoL8FLFBHY"
|
||||
}`),
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{
|
||||
`"data":{`,
|
||||
`"token":{"code":"validation_invalid_token"`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "non auth collection",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/demo1/confirm-verification?expand=rel,missing",
|
||||
Body: strings.NewReader(`{
|
||||
"token":"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImV4cCI6MjUyNDYwNDQ2MSwidHlwZSI6InZlcmlmaWNhdGlvbiIsImNvbGxlY3Rpb25JZCI6Il9wYl91c2Vyc19hdXRoXyIsImVtYWlsIjoidGVzdEBleGFtcGxlLmNvbSJ9.SetHpu2H-x-q4TIUz-xiQjwi7MNwLCLvSs4O0hUSp0E"
|
||||
}`),
|
||||
ExpectedStatus: 404,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "different auth collection",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/clients/confirm-verification?expand=rel,missing",
|
||||
Body: strings.NewReader(`{
|
||||
"token":"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImV4cCI6MjUyNDYwNDQ2MSwidHlwZSI6InZlcmlmaWNhdGlvbiIsImNvbGxlY3Rpb25JZCI6Il9wYl91c2Vyc19hdXRoXyIsImVtYWlsIjoidGVzdEBleGFtcGxlLmNvbSJ9.SetHpu2H-x-q4TIUz-xiQjwi7MNwLCLvSs4O0hUSp0E"
|
||||
}`),
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{
|
||||
`"data":{"token":{"code":"validation_token_collection_mismatch"`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "valid token",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/confirm-verification",
|
||||
Body: strings.NewReader(`{
|
||||
"token":"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImV4cCI6MjUyNDYwNDQ2MSwidHlwZSI6InZlcmlmaWNhdGlvbiIsImNvbGxlY3Rpb25JZCI6Il9wYl91c2Vyc19hdXRoXyIsImVtYWlsIjoidGVzdEBleGFtcGxlLmNvbSJ9.SetHpu2H-x-q4TIUz-xiQjwi7MNwLCLvSs4O0hUSp0E"
|
||||
}`),
|
||||
ExpectedStatus: 204,
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordConfirmVerificationRequest": 1,
|
||||
"OnModelUpdate": 1,
|
||||
"OnModelValidate": 1,
|
||||
"OnModelUpdateExecute": 1,
|
||||
"OnModelAfterUpdateSuccess": 1,
|
||||
"OnRecordUpdate": 1,
|
||||
"OnRecordValidate": 1,
|
||||
"OnRecordUpdateExecute": 1,
|
||||
"OnRecordAfterUpdateSuccess": 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "valid token (already verified)",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/confirm-verification",
|
||||
Body: strings.NewReader(`{
|
||||
"token":"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6Im9hcDY0MGNvdDR5cnUycyIsImV4cCI6MjUyNDYwNDQ2MSwidHlwZSI6InZlcmlmaWNhdGlvbiIsImNvbGxlY3Rpb25JZCI6Il9wYl91c2Vyc19hdXRoXyIsImVtYWlsIjoidGVzdDJAZXhhbXBsZS5jb20ifQ.QQmM3odNFVk6u4J4-5H8IBM3dfk9YCD7mPW-8PhBAI8"
|
||||
}`),
|
||||
ExpectedStatus: 204,
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordConfirmVerificationRequest": 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "valid verification token from a collection without allowed login",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/nologin/confirm-verification",
|
||||
Body: strings.NewReader(`{
|
||||
"token":"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6ImRjNDlrNmpnZWpuNDBoMyIsImV4cCI6MjUyNDYwNDQ2MSwidHlwZSI6InZlcmlmaWNhdGlvbiIsImNvbGxlY3Rpb25JZCI6ImtwdjcwOXNrMmxxYnFrOCIsImVtYWlsIjoidGVzdEBleGFtcGxlLmNvbSJ9.5GmuZr4vmwk3Cb_3ZZWNxwbE75KZC-j71xxIPR9AsVw"
|
||||
}`),
|
||||
ExpectedStatus: 204,
|
||||
ExpectedContent: []string{},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordConfirmVerificationRequest": 1,
|
||||
"OnModelUpdate": 1,
|
||||
"OnModelValidate": 1,
|
||||
"OnModelUpdateExecute": 1,
|
||||
"OnModelAfterUpdateSuccess": 1,
|
||||
"OnRecordUpdate": 1,
|
||||
"OnRecordValidate": 1,
|
||||
"OnRecordUpdateExecute": 1,
|
||||
"OnRecordAfterUpdateSuccess": 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "OnRecordConfirmVerificationRequest tx body write check",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/confirm-verification",
|
||||
Body: strings.NewReader(`{
|
||||
"token":"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImV4cCI6MjUyNDYwNDQ2MSwidHlwZSI6InZlcmlmaWNhdGlvbiIsImNvbGxlY3Rpb25JZCI6Il9wYl91c2Vyc19hdXRoXyIsImVtYWlsIjoidGVzdEBleGFtcGxlLmNvbSJ9.SetHpu2H-x-q4TIUz-xiQjwi7MNwLCLvSs4O0hUSp0E"
|
||||
}`),
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
app.OnRecordConfirmVerificationRequest().BindFunc(func(e *core.RecordConfirmVerificationRequestEvent) error {
|
||||
original := e.App
|
||||
return e.App.RunInTransaction(func(txApp core.App) error {
|
||||
e.App = txApp
|
||||
defer func() { e.App = original }()
|
||||
|
||||
if err := e.Next(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return e.BadRequestError("TX_ERROR", nil)
|
||||
})
|
||||
})
|
||||
},
|
||||
ExpectedStatus: 400,
|
||||
ExpectedEvents: map[string]int{"OnRecordConfirmVerificationRequest": 1},
|
||||
ExpectedContent: []string{"TX_ERROR"},
|
||||
},
|
||||
|
||||
// rate limit checks
|
||||
// -----------------------------------------------------------
|
||||
{
|
||||
Name: "RateLimit rule - nologin:confirmVerification",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/nologin/confirm-verification",
|
||||
Body: strings.NewReader(`{
|
||||
"token":"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6ImRjNDlrNmpnZWpuNDBoMyIsImV4cCI6MjUyNDYwNDQ2MSwidHlwZSI6InZlcmlmaWNhdGlvbiIsImNvbGxlY3Rpb25JZCI6ImtwdjcwOXNrMmxxYnFrOCIsImVtYWlsIjoidGVzdEBleGFtcGxlLmNvbSJ9.5GmuZr4vmwk3Cb_3ZZWNxwbE75KZC-j71xxIPR9AsVw"
|
||||
}`),
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
app.Settings().RateLimits.Enabled = true
|
||||
app.Settings().RateLimits.Rules = []core.RateLimitRule{
|
||||
{MaxRequests: 100, Label: "abc"},
|
||||
{MaxRequests: 100, Label: "*:confirmVerification"},
|
||||
{MaxRequests: 0, Label: "nologin:confirmVerification"},
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 429,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "RateLimit rule - *:confirmVerification",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/nologin/confirm-verification",
|
||||
Body: strings.NewReader(`{
|
||||
"token":"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6ImRjNDlrNmpnZWpuNDBoMyIsImV4cCI6MjUyNDYwNDQ2MSwidHlwZSI6InZlcmlmaWNhdGlvbiIsImNvbGxlY3Rpb25JZCI6ImtwdjcwOXNrMmxxYnFrOCIsImVtYWlsIjoidGVzdEBleGFtcGxlLmNvbSJ9.5GmuZr4vmwk3Cb_3ZZWNxwbE75KZC-j71xxIPR9AsVw"
|
||||
}`),
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
app.Settings().RateLimits.Enabled = true
|
||||
app.Settings().RateLimits.Rules = []core.RateLimitRule{
|
||||
{MaxRequests: 100, Label: "abc"},
|
||||
{MaxRequests: 0, Label: "*:confirmVerification"},
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 429,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
91
apis/record_auth_verification_request.go
Normal file
91
apis/record_auth_verification_request.go
Normal file
|
@ -0,0 +1,91 @@
|
|||
package apis
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
validation "github.com/go-ozzo/ozzo-validation/v4"
|
||||
"github.com/go-ozzo/ozzo-validation/v4/is"
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/mails"
|
||||
"github.com/pocketbase/pocketbase/tools/routine"
|
||||
)
|
||||
|
||||
func recordRequestVerification(e *core.RequestEvent) error {
|
||||
collection, err := findAuthCollection(e)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if collection.Name == core.CollectionNameSuperusers {
|
||||
return e.BadRequestError("All superusers are verified by default.", nil)
|
||||
}
|
||||
|
||||
form := new(recordRequestVerificationForm)
|
||||
if err = e.BindBody(form); err != nil {
|
||||
return firstApiError(err, e.BadRequestError("An error occurred while loading the submitted data.", err))
|
||||
}
|
||||
if err = form.validate(); err != nil {
|
||||
return firstApiError(err, e.BadRequestError("An error occurred while validating the submitted data.", err))
|
||||
}
|
||||
|
||||
record, err := e.App.FindAuthRecordByEmail(collection, form.Email)
|
||||
if err != nil {
|
||||
// eagerly write 204 response as a very basic measure against emails enumeration
|
||||
e.NoContent(http.StatusNoContent)
|
||||
return fmt.Errorf("failed to fetch %s record with email %s: %w", collection.Name, form.Email, err)
|
||||
}
|
||||
|
||||
resendKey := getVerificationResendKey(record)
|
||||
if !record.Verified() && e.App.Store().Has(resendKey) {
|
||||
// eagerly write 204 response as a very basic measure against emails enumeration
|
||||
e.NoContent(http.StatusNoContent)
|
||||
return errors.New("try again later - you've already requested a verification email")
|
||||
}
|
||||
|
||||
event := new(core.RecordRequestVerificationRequestEvent)
|
||||
event.RequestEvent = e
|
||||
event.Collection = collection
|
||||
event.Record = record
|
||||
|
||||
return e.App.OnRecordRequestVerificationRequest().Trigger(event, func(e *core.RecordRequestVerificationRequestEvent) error {
|
||||
if e.Record.Verified() {
|
||||
return e.NoContent(http.StatusNoContent)
|
||||
}
|
||||
|
||||
// run in background because we don't need to show the result to the client
|
||||
app := e.App
|
||||
routine.FireAndForget(func() {
|
||||
if err := mails.SendRecordVerification(app, e.Record); err != nil {
|
||||
app.Logger().Error("Failed to send verification email", "error", err)
|
||||
}
|
||||
|
||||
app.Store().Set(resendKey, struct{}{})
|
||||
time.AfterFunc(2*time.Minute, func() {
|
||||
app.Store().Remove(resendKey)
|
||||
})
|
||||
})
|
||||
|
||||
return execAfterSuccessTx(true, e.App, func() error {
|
||||
return e.NoContent(http.StatusNoContent)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
|
||||
type recordRequestVerificationForm struct {
|
||||
Email string `form:"email" json:"email"`
|
||||
}
|
||||
|
||||
func (form *recordRequestVerificationForm) validate() error {
|
||||
return validation.ValidateStruct(form,
|
||||
validation.Field(&form.Email, validation.Required, validation.Length(1, 255), is.EmailFormat),
|
||||
)
|
||||
}
|
||||
|
||||
func getVerificationResendKey(record *core.Record) string {
|
||||
return "@limitVerificationEmail_" + record.Collection().Id + record.Id
|
||||
}
|
186
apis/record_auth_verification_request_test.go
Normal file
186
apis/record_auth_verification_request_test.go
Normal file
|
@ -0,0 +1,186 @@
|
|||
package apis_test
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tests"
|
||||
)
|
||||
|
||||
func TestRecordRequestVerification(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "not an auth collection",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/demo1/request-verification",
|
||||
Body: strings.NewReader(``),
|
||||
ExpectedStatus: 404,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "empty data",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/request-verification",
|
||||
Body: strings.NewReader(``),
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{`"data":{"email":{"code":"validation_required","message":"Cannot be blank."}}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "invalid data",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/request-verification",
|
||||
Body: strings.NewReader(`{"email`),
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "missing auth record",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/request-verification",
|
||||
Body: strings.NewReader(`{"email":"missing@example.com"}`),
|
||||
Delay: 100 * time.Millisecond,
|
||||
ExpectedStatus: 204,
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
if app.TestMailer.TotalSend() != 0 {
|
||||
t.Fatalf("Expected zero emails, got %d", app.TestMailer.TotalSend())
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "already verified auth record",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/request-verification",
|
||||
Body: strings.NewReader(`{"email":"test2@example.com"}`),
|
||||
Delay: 100 * time.Millisecond,
|
||||
ExpectedStatus: 204,
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordRequestVerificationRequest": 1,
|
||||
},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
if app.TestMailer.TotalSend() != 0 {
|
||||
t.Fatalf("Expected zero emails, got %d", app.TestMailer.TotalSend())
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "existing auth record",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/request-verification",
|
||||
Body: strings.NewReader(`{"email":"test@example.com"}`),
|
||||
Delay: 100 * time.Millisecond,
|
||||
ExpectedStatus: 204,
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordRequestVerificationRequest": 1,
|
||||
"OnMailerSend": 1,
|
||||
"OnMailerRecordVerificationSend": 1,
|
||||
},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
if !strings.Contains(app.TestMailer.LastMessage().HTML, "/auth/confirm-verification") {
|
||||
t.Fatalf("Expected verification email, got\n%v", app.TestMailer.LastMessage().HTML)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "existing auth record (after already sent)",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/request-verification",
|
||||
Body: strings.NewReader(`{"email":"test@example.com"}`),
|
||||
Delay: 100 * time.Millisecond,
|
||||
ExpectedStatus: 204,
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
// terminated before firing the event
|
||||
// "OnRecordRequestVerificationRequest": 1,
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
// simulate recent verification sent
|
||||
authRecord, err := app.FindFirstRecordByData("users", "email", "test@example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
resendKey := "@limitVerificationEmail_" + authRecord.Collection().Id + authRecord.Id
|
||||
app.Store().Set(resendKey, struct{}{})
|
||||
},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
if app.TestMailer.TotalSend() != 0 {
|
||||
t.Fatalf("Expected zero emails, got %d", app.TestMailer.TotalSend())
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "OnRecordRequestVerificationRequest tx body write check",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/request-verification",
|
||||
Body: strings.NewReader(`{"email":"test@example.com"}`),
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
app.OnRecordRequestVerificationRequest().BindFunc(func(e *core.RecordRequestVerificationRequestEvent) error {
|
||||
original := e.App
|
||||
return e.App.RunInTransaction(func(txApp core.App) error {
|
||||
e.App = txApp
|
||||
defer func() { e.App = original }()
|
||||
|
||||
if err := e.Next(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return e.BadRequestError("TX_ERROR", nil)
|
||||
})
|
||||
})
|
||||
},
|
||||
ExpectedStatus: 400,
|
||||
ExpectedEvents: map[string]int{"OnRecordRequestVerificationRequest": 1},
|
||||
ExpectedContent: []string{"TX_ERROR"},
|
||||
},
|
||||
|
||||
// rate limit checks
|
||||
// -----------------------------------------------------------
|
||||
{
|
||||
Name: "RateLimit rule - users:requestVerification",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/request-verification",
|
||||
Body: strings.NewReader(`{"email":"test@example.com"}`),
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
app.Settings().RateLimits.Enabled = true
|
||||
app.Settings().RateLimits.Rules = []core.RateLimitRule{
|
||||
{MaxRequests: 100, Label: "abc"},
|
||||
{MaxRequests: 100, Label: "*:requestVerification"},
|
||||
{MaxRequests: 0, Label: "users:requestVerification"},
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 429,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "RateLimit rule - *:requestVerification",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/request-verification",
|
||||
Body: strings.NewReader(`{"email":"test@example.com"}`),
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
app.Settings().RateLimits.Enabled = true
|
||||
app.Settings().RateLimits.Rules = []core.RateLimitRule{
|
||||
{MaxRequests: 100, Label: "abc"},
|
||||
{MaxRequests: 0, Label: "*:requestVerification"},
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 429,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
374
apis/record_auth_with_oauth2.go
Normal file
374
apis/record_auth_with_oauth2.go
Normal file
|
@ -0,0 +1,374 @@
|
|||
package apis
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"maps"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
validation "github.com/go-ozzo/ozzo-validation/v4"
|
||||
"github.com/pocketbase/dbx"
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tools/dbutils"
|
||||
"github.com/pocketbase/pocketbase/tools/filesystem"
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
func recordAuthWithOAuth2(e *core.RequestEvent) error {
|
||||
collection, err := findAuthCollection(e)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !collection.OAuth2.Enabled {
|
||||
return e.ForbiddenError("The collection is not configured to allow OAuth2 authentication.", nil)
|
||||
}
|
||||
|
||||
var fallbackAuthRecord *core.Record
|
||||
if e.Auth != nil && e.Auth.Collection().Id == collection.Id {
|
||||
fallbackAuthRecord = e.Auth
|
||||
}
|
||||
|
||||
e.Set(core.RequestEventKeyInfoContext, core.RequestInfoContextOAuth2)
|
||||
|
||||
form := new(recordOAuth2LoginForm)
|
||||
form.collection = collection
|
||||
if err = e.BindBody(form); err != nil {
|
||||
return firstApiError(err, e.BadRequestError("An error occurred while loading the submitted data.", err))
|
||||
}
|
||||
|
||||
if form.RedirectUrl != "" && form.RedirectURL == "" {
|
||||
e.App.Logger().Warn("[recordAuthWithOAuth2] redirectUrl body param is deprecated and will be removed in the future. Please replace it with redirectURL.")
|
||||
form.RedirectURL = form.RedirectUrl
|
||||
}
|
||||
|
||||
if err = form.validate(); err != nil {
|
||||
return firstApiError(err, e.BadRequestError("An error occurred while loading the submitted data.", err))
|
||||
}
|
||||
|
||||
// exchange token for OAuth2 user info and locate existing ExternalAuth rel
|
||||
// ---------------------------------------------------------------
|
||||
|
||||
// load provider configuration
|
||||
providerConfig, ok := collection.OAuth2.GetProviderConfig(form.Provider)
|
||||
if !ok {
|
||||
return e.InternalServerError("Missing or invalid provider config.", nil)
|
||||
}
|
||||
|
||||
provider, err := providerConfig.InitProvider()
|
||||
if err != nil {
|
||||
return firstApiError(err, e.InternalServerError("Failed to init provider "+form.Provider, err))
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(e.Request.Context(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
provider.SetContext(ctx)
|
||||
provider.SetRedirectURL(form.RedirectURL)
|
||||
|
||||
var opts []oauth2.AuthCodeOption
|
||||
|
||||
if provider.PKCE() {
|
||||
opts = append(opts, oauth2.SetAuthURLParam("code_verifier", form.CodeVerifier))
|
||||
}
|
||||
|
||||
// fetch token
|
||||
token, err := provider.FetchToken(form.Code, opts...)
|
||||
if err != nil {
|
||||
return firstApiError(err, e.BadRequestError("Failed to fetch OAuth2 token.", err))
|
||||
}
|
||||
|
||||
// fetch external auth user
|
||||
authUser, err := provider.FetchAuthUser(token)
|
||||
if err != nil {
|
||||
return firstApiError(err, e.BadRequestError("Failed to fetch OAuth2 user.", err))
|
||||
}
|
||||
|
||||
var authRecord *core.Record
|
||||
|
||||
// check for existing relation with the auth collection
|
||||
externalAuthRel, err := e.App.FindFirstExternalAuthByExpr(dbx.HashExp{
|
||||
"collectionRef": form.collection.Id,
|
||||
"provider": form.Provider,
|
||||
"providerId": authUser.Id,
|
||||
})
|
||||
if err != nil && !errors.Is(err, sql.ErrNoRows) {
|
||||
return e.InternalServerError("Failed OAuth2 relation check.", err)
|
||||
}
|
||||
|
||||
switch {
|
||||
case err == nil && externalAuthRel != nil:
|
||||
authRecord, err = e.App.FindRecordById(form.collection, externalAuthRel.RecordRef())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
case fallbackAuthRecord != nil && fallbackAuthRecord.Collection().Id == form.collection.Id:
|
||||
// fallback to the logged auth record (if any)
|
||||
authRecord = fallbackAuthRecord
|
||||
case authUser.Email != "":
|
||||
// look for an existing auth record by the external auth record's email
|
||||
authRecord, err = e.App.FindAuthRecordByEmail(form.collection.Id, authUser.Email)
|
||||
if err != nil && !errors.Is(err, sql.ErrNoRows) {
|
||||
return e.InternalServerError("Failed OAuth2 auth record check.", err)
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------
|
||||
|
||||
event := new(core.RecordAuthWithOAuth2RequestEvent)
|
||||
event.RequestEvent = e
|
||||
event.Collection = collection
|
||||
event.ProviderName = form.Provider
|
||||
event.ProviderClient = provider
|
||||
event.OAuth2User = authUser
|
||||
event.CreateData = form.CreateData
|
||||
event.Record = authRecord
|
||||
event.IsNewRecord = authRecord == nil
|
||||
|
||||
return e.App.OnRecordAuthWithOAuth2Request().Trigger(event, func(e *core.RecordAuthWithOAuth2RequestEvent) error {
|
||||
if err := oauth2Submit(e, externalAuthRel); err != nil {
|
||||
return firstApiError(err, e.BadRequestError("Failed to authenticate.", err))
|
||||
}
|
||||
|
||||
// @todo revert back to struct after removing the custom auth.AuthUser marshalization
|
||||
meta := map[string]any{}
|
||||
rawOAuth2User, err := json.Marshal(e.OAuth2User)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = json.Unmarshal(rawOAuth2User, &meta)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
meta["isNew"] = e.IsNewRecord
|
||||
|
||||
return RecordAuthResponse(e.RequestEvent, e.Record, core.MFAMethodOAuth2, meta)
|
||||
})
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
|
||||
type recordOAuth2LoginForm struct {
|
||||
collection *core.Collection
|
||||
|
||||
// Additional data that will be used for creating a new auth record
|
||||
// if an existing OAuth2 account doesn't exist.
|
||||
CreateData map[string]any `form:"createData" json:"createData"`
|
||||
|
||||
// The name of the OAuth2 client provider (eg. "google")
|
||||
Provider string `form:"provider" json:"provider"`
|
||||
|
||||
// The authorization code returned from the initial request.
|
||||
Code string `form:"code" json:"code"`
|
||||
|
||||
// The optional PKCE code verifier as part of the code_challenge sent with the initial request.
|
||||
CodeVerifier string `form:"codeVerifier" json:"codeVerifier"`
|
||||
|
||||
// The redirect url sent with the initial request.
|
||||
RedirectURL string `form:"redirectURL" json:"redirectURL"`
|
||||
|
||||
// @todo
|
||||
// deprecated: use RedirectURL instead
|
||||
// RedirectUrl will be removed after dropping v0.22 support
|
||||
RedirectUrl string `form:"redirectUrl" json:"redirectUrl"`
|
||||
}
|
||||
|
||||
func (form *recordOAuth2LoginForm) validate() error {
|
||||
return validation.ValidateStruct(form,
|
||||
validation.Field(&form.Provider, validation.Required, validation.Length(0, 100), validation.By(form.checkProviderName)),
|
||||
validation.Field(&form.Code, validation.Required),
|
||||
validation.Field(&form.RedirectURL, validation.Required),
|
||||
)
|
||||
}
|
||||
|
||||
func (form *recordOAuth2LoginForm) checkProviderName(value any) error {
|
||||
name, _ := value.(string)
|
||||
|
||||
_, ok := form.collection.OAuth2.GetProviderConfig(name)
|
||||
if !ok {
|
||||
return validation.NewError("validation_invalid_provider", "Provider with name {{.name}} is missing or is not enabled.").
|
||||
SetParams(map[string]any{"name": name})
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func oldCanAssignUsername(txApp core.App, collection *core.Collection, username string) bool {
|
||||
// ensure that username is unique
|
||||
index, hasUniqueue := dbutils.FindSingleColumnUniqueIndex(collection.Indexes, collection.OAuth2.MappedFields.Username)
|
||||
if hasUniqueue {
|
||||
var expr dbx.Expression
|
||||
if strings.EqualFold(index.Columns[0].Collate, "nocase") {
|
||||
// case-insensitive search
|
||||
expr = dbx.NewExp("username = {:username} COLLATE NOCASE", dbx.Params{"username": username})
|
||||
} else {
|
||||
expr = dbx.HashExp{"username": username}
|
||||
}
|
||||
|
||||
var exists int
|
||||
_ = txApp.RecordQuery(collection).Select("(1)").AndWhere(expr).Limit(1).Row(&exists)
|
||||
if exists > 0 {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// ensure that the value matches the pattern of the username field (if text)
|
||||
txtField, _ := collection.Fields.GetByName(collection.OAuth2.MappedFields.Username).(*core.TextField)
|
||||
|
||||
return txtField != nil && txtField.ValidatePlainValue(username) == nil
|
||||
}
|
||||
|
||||
func oauth2Submit(e *core.RecordAuthWithOAuth2RequestEvent, optExternalAuth *core.ExternalAuth) error {
|
||||
return e.App.RunInTransaction(func(txApp core.App) error {
|
||||
if e.Record == nil {
|
||||
// extra check to prevent creating a superuser record via
|
||||
// OAuth2 in case the method is used by another action
|
||||
if e.Collection.Name == core.CollectionNameSuperusers {
|
||||
return errors.New("superusers are not allowed to sign-up with OAuth2")
|
||||
}
|
||||
|
||||
payload := maps.Clone(e.CreateData)
|
||||
if payload == nil {
|
||||
payload = map[string]any{}
|
||||
}
|
||||
|
||||
// assign the OAuth2 user email only if the user hasn't submitted one
|
||||
// (ignore empty/invalid values for consistency with the OAuth2->existing user update flow)
|
||||
if v, _ := payload[core.FieldNameEmail].(string); v == "" {
|
||||
payload[core.FieldNameEmail] = e.OAuth2User.Email
|
||||
}
|
||||
|
||||
// map known fields (unless the field was explicitly submitted as part of CreateData)
|
||||
if _, ok := payload[e.Collection.OAuth2.MappedFields.Id]; !ok && e.Collection.OAuth2.MappedFields.Id != "" {
|
||||
payload[e.Collection.OAuth2.MappedFields.Id] = e.OAuth2User.Id
|
||||
}
|
||||
if _, ok := payload[e.Collection.OAuth2.MappedFields.Name]; !ok && e.Collection.OAuth2.MappedFields.Name != "" {
|
||||
payload[e.Collection.OAuth2.MappedFields.Name] = e.OAuth2User.Name
|
||||
}
|
||||
if _, ok := payload[e.Collection.OAuth2.MappedFields.Username]; !ok &&
|
||||
// no explicit username payload value and existing OAuth2 mapping
|
||||
e.Collection.OAuth2.MappedFields.Username != "" &&
|
||||
// extra checks for backward compatibility with earlier versions
|
||||
oldCanAssignUsername(txApp, e.Collection, e.OAuth2User.Username) {
|
||||
payload[e.Collection.OAuth2.MappedFields.Username] = e.OAuth2User.Username
|
||||
}
|
||||
if _, ok := payload[e.Collection.OAuth2.MappedFields.AvatarURL]; !ok &&
|
||||
// no explicit avatar payload value and existing OAuth2 mapping
|
||||
e.Collection.OAuth2.MappedFields.AvatarURL != "" &&
|
||||
// non-empty OAuth2 avatar url
|
||||
e.OAuth2User.AvatarURL != "" {
|
||||
mappedField := e.Collection.Fields.GetByName(e.Collection.OAuth2.MappedFields.AvatarURL)
|
||||
if mappedField != nil && mappedField.Type() == core.FieldTypeFile {
|
||||
// download the avatar if the mapped field is a file
|
||||
avatarFile, err := func() (*filesystem.File, error) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
|
||||
defer cancel()
|
||||
return filesystem.NewFileFromURL(ctx, e.OAuth2User.AvatarURL)
|
||||
}()
|
||||
if err != nil {
|
||||
txApp.Logger().Warn("Failed to retrieve OAuth2 avatar", slog.String("error", err.Error()))
|
||||
} else {
|
||||
payload[e.Collection.OAuth2.MappedFields.AvatarURL] = avatarFile
|
||||
}
|
||||
} else {
|
||||
// otherwise - assign the url string
|
||||
payload[e.Collection.OAuth2.MappedFields.AvatarURL] = e.OAuth2User.AvatarURL
|
||||
}
|
||||
}
|
||||
|
||||
createdRecord, err := sendOAuth2RecordCreateRequest(txApp, e, payload)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
e.Record = createdRecord
|
||||
|
||||
if e.Record.Email() == e.OAuth2User.Email && !e.Record.Verified() {
|
||||
// mark as verified as long as it matches the OAuth2 data (even if the email is empty)
|
||||
e.Record.SetVerified(true)
|
||||
if err := txApp.Save(e.Record); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
} else {
|
||||
var needUpdate bool
|
||||
|
||||
isLoggedAuthRecord := e.Auth != nil &&
|
||||
e.Auth.Id == e.Record.Id &&
|
||||
e.Auth.Collection().Id == e.Record.Collection().Id
|
||||
|
||||
// set random password for users with unverified email
|
||||
// (this is in case a malicious actor has registered previously with the user email)
|
||||
if !isLoggedAuthRecord && e.Record.Email() != "" && !e.Record.Verified() {
|
||||
e.Record.SetRandomPassword()
|
||||
needUpdate = true
|
||||
}
|
||||
|
||||
// update the existing auth record empty email if the data.OAuth2User has one
|
||||
// (this is in case previously the auth record was created
|
||||
// with an OAuth2 provider that didn't return an email address)
|
||||
if e.Record.Email() == "" && e.OAuth2User.Email != "" {
|
||||
e.Record.SetEmail(e.OAuth2User.Email)
|
||||
needUpdate = true
|
||||
}
|
||||
|
||||
// update the existing auth record verified state
|
||||
// (only if the auth record doesn't have an email or the auth record email match with the one in data.OAuth2User)
|
||||
if !e.Record.Verified() && (e.Record.Email() == "" || e.Record.Email() == e.OAuth2User.Email) {
|
||||
e.Record.SetVerified(true)
|
||||
needUpdate = true
|
||||
}
|
||||
|
||||
if needUpdate {
|
||||
if err := txApp.Save(e.Record); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// create ExternalAuth relation if missing
|
||||
if optExternalAuth == nil {
|
||||
optExternalAuth = core.NewExternalAuth(txApp)
|
||||
optExternalAuth.SetCollectionRef(e.Record.Collection().Id)
|
||||
optExternalAuth.SetRecordRef(e.Record.Id)
|
||||
optExternalAuth.SetProvider(e.ProviderName)
|
||||
optExternalAuth.SetProviderId(e.OAuth2User.Id)
|
||||
|
||||
if err := txApp.Save(optExternalAuth); err != nil {
|
||||
return fmt.Errorf("failed to save linked rel: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func sendOAuth2RecordCreateRequest(txApp core.App, e *core.RecordAuthWithOAuth2RequestEvent, payload map[string]any) (*core.Record, error) {
|
||||
ir := &core.InternalRequest{
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/" + e.Collection.Name + "/records",
|
||||
Body: payload,
|
||||
}
|
||||
|
||||
var createdRecord *core.Record
|
||||
response, err := processInternalRequest(txApp, e.RequestEvent, ir, core.RequestInfoContextOAuth2, func(data any) error {
|
||||
createdRecord, _ = data.(*core.Record)
|
||||
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if response.Status != http.StatusOK || createdRecord == nil {
|
||||
return nil, errors.New("failed to create OAuth2 auth record")
|
||||
}
|
||||
|
||||
return createdRecord, nil
|
||||
}
|
74
apis/record_auth_with_oauth2_redirect.go
Normal file
74
apis/record_auth_with_oauth2_redirect.go
Normal file
|
@ -0,0 +1,74 @@
|
|||
package apis
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tools/subscriptions"
|
||||
)
|
||||
|
||||
const (
|
||||
oauth2SubscriptionTopic string = "@oauth2"
|
||||
oauth2RedirectFailurePath string = "../_/#/auth/oauth2-redirect-failure"
|
||||
oauth2RedirectSuccessPath string = "../_/#/auth/oauth2-redirect-success"
|
||||
)
|
||||
|
||||
type oauth2RedirectData struct {
|
||||
State string `form:"state" json:"state"`
|
||||
Code string `form:"code" json:"code"`
|
||||
Error string `form:"error" json:"error,omitempty"`
|
||||
}
|
||||
|
||||
func oauth2SubscriptionRedirect(e *core.RequestEvent) error {
|
||||
redirectStatusCode := http.StatusTemporaryRedirect
|
||||
if e.Request.Method != http.MethodGet {
|
||||
redirectStatusCode = http.StatusSeeOther
|
||||
}
|
||||
|
||||
data := oauth2RedirectData{}
|
||||
|
||||
if e.Request.Method == http.MethodPost {
|
||||
if err := e.BindBody(&data); err != nil {
|
||||
e.App.Logger().Debug("Failed to read OAuth2 redirect data", "error", err)
|
||||
return e.Redirect(redirectStatusCode, oauth2RedirectFailurePath)
|
||||
}
|
||||
} else {
|
||||
query := e.Request.URL.Query()
|
||||
data.State = query.Get("state")
|
||||
data.Code = query.Get("code")
|
||||
data.Error = query.Get("error")
|
||||
}
|
||||
|
||||
if data.State == "" {
|
||||
e.App.Logger().Debug("Missing OAuth2 state parameter")
|
||||
return e.Redirect(redirectStatusCode, oauth2RedirectFailurePath)
|
||||
}
|
||||
|
||||
client, err := e.App.SubscriptionsBroker().ClientById(data.State)
|
||||
if err != nil || client.IsDiscarded() || !client.HasSubscription(oauth2SubscriptionTopic) {
|
||||
e.App.Logger().Debug("Missing or invalid OAuth2 subscription client", "error", err, "clientId", data.State)
|
||||
return e.Redirect(redirectStatusCode, oauth2RedirectFailurePath)
|
||||
}
|
||||
defer client.Unsubscribe(oauth2SubscriptionTopic)
|
||||
|
||||
encodedData, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
e.App.Logger().Debug("Failed to marshalize OAuth2 redirect data", "error", err)
|
||||
return e.Redirect(redirectStatusCode, oauth2RedirectFailurePath)
|
||||
}
|
||||
|
||||
msg := subscriptions.Message{
|
||||
Name: oauth2SubscriptionTopic,
|
||||
Data: encodedData,
|
||||
}
|
||||
|
||||
client.Send(msg)
|
||||
|
||||
if data.Error != "" || data.Code == "" {
|
||||
e.App.Logger().Debug("Failed OAuth2 redirect due to an error or missing code parameter", "error", data.Error, "clientId", data.State)
|
||||
return e.Redirect(redirectStatusCode, oauth2RedirectFailurePath)
|
||||
}
|
||||
|
||||
return e.Redirect(redirectStatusCode, oauth2RedirectSuccessPath)
|
||||
}
|
274
apis/record_auth_with_oauth2_redirect_test.go
Normal file
274
apis/record_auth_with_oauth2_redirect_test.go
Normal file
|
@ -0,0 +1,274 @@
|
|||
package apis_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tests"
|
||||
"github.com/pocketbase/pocketbase/tools/subscriptions"
|
||||
)
|
||||
|
||||
func TestRecordAuthWithOAuth2Redirect(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
clientStubs := make([]map[string]subscriptions.Client, 0, 10)
|
||||
|
||||
for i := 0; i < 10; i++ {
|
||||
c1 := subscriptions.NewDefaultClient()
|
||||
|
||||
c2 := subscriptions.NewDefaultClient()
|
||||
c2.Subscribe("@oauth2")
|
||||
|
||||
c3 := subscriptions.NewDefaultClient()
|
||||
c3.Subscribe("test1", "@oauth2")
|
||||
|
||||
c4 := subscriptions.NewDefaultClient()
|
||||
c4.Subscribe("test1", "test2")
|
||||
|
||||
c5 := subscriptions.NewDefaultClient()
|
||||
c5.Subscribe("@oauth2")
|
||||
c5.Discard()
|
||||
|
||||
clientStubs = append(clientStubs, map[string]subscriptions.Client{
|
||||
"c1": c1,
|
||||
"c2": c2,
|
||||
"c3": c3,
|
||||
"c4": c4,
|
||||
"c5": c5,
|
||||
})
|
||||
}
|
||||
|
||||
checkFailureRedirect := func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
loc := res.Header.Get("Location")
|
||||
if !strings.Contains(loc, "/oauth2-redirect-failure") {
|
||||
t.Fatalf("Expected failure redirect, got %q", loc)
|
||||
}
|
||||
}
|
||||
|
||||
checkSuccessRedirect := func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
loc := res.Header.Get("Location")
|
||||
if !strings.Contains(loc, "/oauth2-redirect-success") {
|
||||
t.Fatalf("Expected success redirect, got %q", loc)
|
||||
}
|
||||
}
|
||||
|
||||
// note: don't exit because it is usually called as part of a separate goroutine
|
||||
checkClientMessages := func(t testing.TB, clientId string, msg subscriptions.Message, expectedMessages map[string][]string) {
|
||||
if len(expectedMessages[clientId]) == 0 {
|
||||
t.Errorf("Unexpected client %q message, got %q:\n%q", clientId, msg.Name, msg.Data)
|
||||
return
|
||||
}
|
||||
|
||||
if msg.Name != "@oauth2" {
|
||||
t.Errorf("Expected @oauth2 msg.Name, got %q", msg.Name)
|
||||
return
|
||||
}
|
||||
|
||||
for _, txt := range expectedMessages[clientId] {
|
||||
if !strings.Contains(string(msg.Data), txt) {
|
||||
t.Errorf("Failed to find %q in \n%s", txt, msg.Data)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
beforeTestFunc := func(
|
||||
clients map[string]subscriptions.Client,
|
||||
expectedMessages map[string][]string,
|
||||
) func(testing.TB, *tests.TestApp, *core.ServeEvent) {
|
||||
return func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
for _, client := range clients {
|
||||
app.SubscriptionsBroker().Register(client)
|
||||
}
|
||||
|
||||
ctx, cancelFunc := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
|
||||
// add to the app store so that it can be cancelled manually after test completion
|
||||
app.Store().Set("cancelFunc", cancelFunc)
|
||||
|
||||
go func() {
|
||||
defer cancelFunc()
|
||||
|
||||
for {
|
||||
select {
|
||||
case msg, ok := <-clients["c1"].Channel():
|
||||
if ok {
|
||||
checkClientMessages(t, "c1", msg, expectedMessages)
|
||||
} else {
|
||||
t.Errorf("Unexpected c1 closed channel")
|
||||
}
|
||||
case msg, ok := <-clients["c2"].Channel():
|
||||
if ok {
|
||||
checkClientMessages(t, "c2", msg, expectedMessages)
|
||||
} else {
|
||||
t.Errorf("Unexpected c2 closed channel")
|
||||
}
|
||||
case msg, ok := <-clients["c3"].Channel():
|
||||
if ok {
|
||||
checkClientMessages(t, "c3", msg, expectedMessages)
|
||||
} else {
|
||||
t.Errorf("Unexpected c3 closed channel")
|
||||
}
|
||||
case msg, ok := <-clients["c4"].Channel():
|
||||
if ok {
|
||||
checkClientMessages(t, "c4", msg, expectedMessages)
|
||||
} else {
|
||||
t.Errorf("Unexpected c4 closed channel")
|
||||
}
|
||||
case _, ok := <-clients["c5"].Channel():
|
||||
if ok {
|
||||
t.Errorf("Expected c5 channel to be closed")
|
||||
}
|
||||
case <-ctx.Done():
|
||||
for _, c := range clients {
|
||||
c.Discard()
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "no state query param",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/oauth2-redirect?code=123",
|
||||
BeforeTestFunc: beforeTestFunc(clientStubs[0], nil),
|
||||
ExpectedStatus: http.StatusTemporaryRedirect,
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
app.Store().Get("cancelFunc").(context.CancelFunc)()
|
||||
|
||||
checkFailureRedirect(t, app, res)
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "invalid or missing client",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/oauth2-redirect?code=123&state=missing",
|
||||
BeforeTestFunc: beforeTestFunc(clientStubs[1], nil),
|
||||
ExpectedStatus: http.StatusTemporaryRedirect,
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
app.Store().Get("cancelFunc").(context.CancelFunc)()
|
||||
|
||||
checkFailureRedirect(t, app, res)
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "no code query param",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/oauth2-redirect?state=" + clientStubs[2]["c3"].Id(),
|
||||
BeforeTestFunc: beforeTestFunc(clientStubs[2], map[string][]string{
|
||||
"c3": {`"state":"` + clientStubs[2]["c3"].Id(), `"code":""`},
|
||||
}),
|
||||
ExpectedStatus: http.StatusTemporaryRedirect,
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
app.Store().Get("cancelFunc").(context.CancelFunc)()
|
||||
|
||||
checkFailureRedirect(t, app, res)
|
||||
|
||||
if clientStubs[2]["c3"].HasSubscription("@oauth2") {
|
||||
t.Fatalf("Expected oauth2 subscription to be removed")
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "error query param",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/oauth2-redirect?error=example&code=123&state=" + clientStubs[3]["c3"].Id(),
|
||||
BeforeTestFunc: beforeTestFunc(clientStubs[3], map[string][]string{
|
||||
"c3": {`"state":"` + clientStubs[3]["c3"].Id(), `"code":"123"`, `"error":"example"`},
|
||||
}),
|
||||
ExpectedStatus: http.StatusTemporaryRedirect,
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
app.Store().Get("cancelFunc").(context.CancelFunc)()
|
||||
|
||||
checkFailureRedirect(t, app, res)
|
||||
|
||||
if clientStubs[3]["c3"].HasSubscription("@oauth2") {
|
||||
t.Fatalf("Expected oauth2 subscription to be removed")
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "discarded client with @oauth2 subscription",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/oauth2-redirect?code=123&state=" + clientStubs[4]["c5"].Id(),
|
||||
BeforeTestFunc: beforeTestFunc(clientStubs[4], nil),
|
||||
ExpectedStatus: http.StatusTemporaryRedirect,
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
app.Store().Get("cancelFunc").(context.CancelFunc)()
|
||||
|
||||
checkFailureRedirect(t, app, res)
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "client without @oauth2 subscription",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/oauth2-redirect?code=123&state=" + clientStubs[4]["c4"].Id(),
|
||||
BeforeTestFunc: beforeTestFunc(clientStubs[5], nil),
|
||||
ExpectedStatus: http.StatusTemporaryRedirect,
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
app.Store().Get("cancelFunc").(context.CancelFunc)()
|
||||
|
||||
checkFailureRedirect(t, app, res)
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "client with @oauth2 subscription",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/oauth2-redirect?code=123&state=" + clientStubs[6]["c3"].Id(),
|
||||
BeforeTestFunc: beforeTestFunc(clientStubs[6], map[string][]string{
|
||||
"c3": {`"state":"` + clientStubs[6]["c3"].Id(), `"code":"123"`},
|
||||
}),
|
||||
ExpectedStatus: http.StatusTemporaryRedirect,
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
app.Store().Get("cancelFunc").(context.CancelFunc)()
|
||||
|
||||
checkSuccessRedirect(t, app, res)
|
||||
|
||||
if clientStubs[6]["c3"].HasSubscription("@oauth2") {
|
||||
t.Fatalf("Expected oauth2 subscription to be removed")
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "(POST) client with @oauth2 subscription",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/oauth2-redirect",
|
||||
Body: strings.NewReader("code=123&state=" + clientStubs[7]["c3"].Id()),
|
||||
Headers: map[string]string{
|
||||
"content-type": "application/x-www-form-urlencoded",
|
||||
},
|
||||
BeforeTestFunc: beforeTestFunc(clientStubs[7], map[string][]string{
|
||||
"c3": {`"state":"` + clientStubs[7]["c3"].Id(), `"code":"123"`},
|
||||
}),
|
||||
ExpectedStatus: http.StatusSeeOther,
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
app.Store().Get("cancelFunc").(context.CancelFunc)()
|
||||
|
||||
checkSuccessRedirect(t, app, res)
|
||||
|
||||
if clientStubs[7]["c3"].HasSubscription("@oauth2") {
|
||||
t.Fatalf("Expected oauth2 subscription to be removed")
|
||||
}
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
1715
apis/record_auth_with_oauth2_test.go
Normal file
1715
apis/record_auth_with_oauth2_test.go
Normal file
File diff suppressed because it is too large
Load diff
106
apis/record_auth_with_otp.go
Normal file
106
apis/record_auth_with_otp.go
Normal file
|
@ -0,0 +1,106 @@
|
|||
package apis
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
validation "github.com/go-ozzo/ozzo-validation/v4"
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
)
|
||||
|
||||
func recordAuthWithOTP(e *core.RequestEvent) error {
|
||||
collection, err := findAuthCollection(e)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !collection.OTP.Enabled {
|
||||
return e.ForbiddenError("The collection is not configured to allow OTP authentication.", nil)
|
||||
}
|
||||
|
||||
form := &authWithOTPForm{}
|
||||
if err = e.BindBody(form); err != nil {
|
||||
return firstApiError(err, e.BadRequestError("An error occurred while loading the submitted data.", err))
|
||||
}
|
||||
if err = form.validate(); err != nil {
|
||||
return firstApiError(err, e.BadRequestError("An error occurred while validating the submitted data.", err))
|
||||
}
|
||||
|
||||
e.Set(core.RequestEventKeyInfoContext, core.RequestInfoContextOTP)
|
||||
|
||||
event := new(core.RecordAuthWithOTPRequestEvent)
|
||||
event.RequestEvent = e
|
||||
event.Collection = collection
|
||||
|
||||
// extra validations
|
||||
// (note: returns a generic 400 as a very basic OTPs enumeration protection)
|
||||
// ---
|
||||
event.OTP, err = e.App.FindOTPById(form.OTPId)
|
||||
if err != nil {
|
||||
return e.BadRequestError("Invalid or expired OTP", err)
|
||||
}
|
||||
|
||||
if event.OTP.CollectionRef() != collection.Id {
|
||||
return e.BadRequestError("Invalid or expired OTP", errors.New("the OTP is for a different collection"))
|
||||
}
|
||||
|
||||
if event.OTP.HasExpired(collection.OTP.DurationTime()) {
|
||||
return e.BadRequestError("Invalid or expired OTP", errors.New("the OTP is expired"))
|
||||
}
|
||||
|
||||
event.Record, err = e.App.FindRecordById(event.OTP.CollectionRef(), event.OTP.RecordRef())
|
||||
if err != nil {
|
||||
return e.BadRequestError("Invalid or expired OTP", fmt.Errorf("missing auth record: %w", err))
|
||||
}
|
||||
|
||||
// since otps are usually simple digit numbers, enforce an extra rate limit rule as basic enumaration protection
|
||||
err = checkRateLimit(e, "@pb_otp_"+event.Record.Id, core.RateLimitRule{MaxRequests: 5, Duration: 180})
|
||||
if err != nil {
|
||||
return e.TooManyRequestsError("Too many attempts, please try again later with a new OTP.", nil)
|
||||
}
|
||||
|
||||
if !event.OTP.ValidatePassword(form.Password) {
|
||||
return e.BadRequestError("Invalid or expired OTP", errors.New("incorrect password"))
|
||||
}
|
||||
// ---
|
||||
|
||||
return e.App.OnRecordAuthWithOTPRequest().Trigger(event, func(e *core.RecordAuthWithOTPRequestEvent) error {
|
||||
// update the user email verified state in case the OTP originate from an email address matching the current record one
|
||||
//
|
||||
// note: don't wait for success auth response (it could fail because of MFA) and because we already validated the OTP above
|
||||
otpSentTo := e.OTP.SentTo()
|
||||
if !e.Record.Verified() && otpSentTo != "" && e.Record.Email() == otpSentTo {
|
||||
e.Record.SetVerified(true)
|
||||
err = e.App.Save(e.Record)
|
||||
if err != nil {
|
||||
e.App.Logger().Error("Failed to update record verified state after successful OTP validation",
|
||||
"error", err,
|
||||
"otpId", e.OTP.Id,
|
||||
"recordId", e.Record.Id,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// try to delete the used otp
|
||||
err = e.App.Delete(e.OTP)
|
||||
if err != nil {
|
||||
e.App.Logger().Error("Failed to delete used OTP", "error", err, "otpId", e.OTP.Id)
|
||||
}
|
||||
|
||||
return RecordAuthResponse(e.RequestEvent, e.Record, core.MFAMethodOTP, nil)
|
||||
})
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
|
||||
type authWithOTPForm struct {
|
||||
OTPId string `form:"otpId" json:"otpId"`
|
||||
Password string `form:"password" json:"password"`
|
||||
}
|
||||
|
||||
func (form *authWithOTPForm) validate() error {
|
||||
return validation.ValidateStruct(form,
|
||||
validation.Field(&form.OTPId, validation.Required, validation.Length(1, 255)),
|
||||
validation.Field(&form.Password, validation.Required, validation.Length(1, 71)),
|
||||
)
|
||||
}
|
608
apis/record_auth_with_otp_test.go
Normal file
608
apis/record_auth_with_otp_test.go
Normal file
|
@ -0,0 +1,608 @@
|
|||
package apis_test
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tests"
|
||||
"github.com/pocketbase/pocketbase/tools/types"
|
||||
)
|
||||
|
||||
func TestRecordAuthWithOTP(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "not an auth collection",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/demo1/auth-with-otp",
|
||||
Body: strings.NewReader(`{"otpId":"test","password":"123456"}`),
|
||||
ExpectedStatus: 404,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "auth collection with disabled otp",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/auth-with-otp",
|
||||
Body: strings.NewReader(`{"otpId":"test","password":"123456"}`),
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
usersCol, err := app.FindCollectionByNameOrId("users")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
usersCol.OTP.Enabled = false
|
||||
|
||||
if err := app.Save(usersCol); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "invalid body",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/auth-with-otp",
|
||||
Body: strings.NewReader(`{"email`),
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "empty body",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/auth-with-otp",
|
||||
Body: strings.NewReader(``),
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{
|
||||
`"data":{`,
|
||||
`"otpId":{"code":"validation_required"`,
|
||||
`"password":{"code":"validation_required"`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "invalid request data",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/auth-with-otp",
|
||||
Body: strings.NewReader(`{
|
||||
"otpId":"` + strings.Repeat("a", 256) + `",
|
||||
"password":"` + strings.Repeat("a", 72) + `"
|
||||
}`),
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{
|
||||
`"data":{`,
|
||||
`"otpId":{"code":"validation_length_out_of_range"`,
|
||||
`"password":{"code":"validation_length_out_of_range"`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "missing otp",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/auth-with-otp",
|
||||
Body: strings.NewReader(`{
|
||||
"otpId":"missing",
|
||||
"password":"123456"
|
||||
}`),
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
user, err := app.FindAuthRecordByEmail("users", "test@example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
otp := core.NewOTP(app)
|
||||
otp.Id = strings.Repeat("a", 15)
|
||||
otp.SetCollectionRef(user.Collection().Id)
|
||||
otp.SetRecordRef(user.Id)
|
||||
otp.SetPassword("123456")
|
||||
if err := app.Save(otp); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "otp for different collection",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/auth-with-otp",
|
||||
Body: strings.NewReader(`{
|
||||
"otpId":"` + strings.Repeat("a", 15) + `",
|
||||
"password":"123456"
|
||||
}`),
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
client, err := app.FindAuthRecordByEmail("clients", "test@example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
otp := core.NewOTP(app)
|
||||
otp.Id = strings.Repeat("a", 15)
|
||||
otp.SetCollectionRef(client.Collection().Id)
|
||||
otp.SetRecordRef(client.Id)
|
||||
otp.SetPassword("123456")
|
||||
if err := app.Save(otp); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "otp with wrong password",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/auth-with-otp",
|
||||
Body: strings.NewReader(`{
|
||||
"otpId":"` + strings.Repeat("a", 15) + `",
|
||||
"password":"123456"
|
||||
}`),
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
user, err := app.FindAuthRecordByEmail("users", "test@example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
otp := core.NewOTP(app)
|
||||
otp.Id = strings.Repeat("a", 15)
|
||||
otp.SetCollectionRef(user.Collection().Id)
|
||||
otp.SetRecordRef(user.Id)
|
||||
otp.SetPassword("1234567890")
|
||||
if err := app.Save(otp); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "expired otp with valid password",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/auth-with-otp",
|
||||
Body: strings.NewReader(`{
|
||||
"otpId":"` + strings.Repeat("a", 15) + `",
|
||||
"password":"123456"
|
||||
}`),
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
user, err := app.FindAuthRecordByEmail("users", "test@example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
otp := core.NewOTP(app)
|
||||
otp.Id = strings.Repeat("a", 15)
|
||||
otp.SetCollectionRef(user.Collection().Id)
|
||||
otp.SetRecordRef(user.Id)
|
||||
otp.SetPassword("123456")
|
||||
expiredDate := types.NowDateTime().AddDate(-3, 0, 0)
|
||||
otp.SetRaw("created", expiredDate)
|
||||
otp.SetRaw("updated", expiredDate)
|
||||
if err := app.Save(otp); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "valid otp with valid password (enabled MFA)",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/auth-with-otp",
|
||||
Body: strings.NewReader(`{
|
||||
"otpId":"` + strings.Repeat("a", 15) + `",
|
||||
"password":"123456"
|
||||
}`),
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
user, err := app.FindAuthRecordByEmail("users", "test@example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
otp := core.NewOTP(app)
|
||||
otp.Id = strings.Repeat("a", 15)
|
||||
otp.SetCollectionRef(user.Collection().Id)
|
||||
otp.SetRecordRef(user.Id)
|
||||
otp.SetPassword("123456")
|
||||
if err := app.Save(otp); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 401,
|
||||
ExpectedContent: []string{`"mfaId":"`},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordAuthWithOTPRequest": 1,
|
||||
"OnRecordAuthRequest": 1,
|
||||
// ---
|
||||
"OnModelValidate": 1,
|
||||
"OnModelCreate": 1, // mfa record
|
||||
"OnModelCreateExecute": 1,
|
||||
"OnModelAfterCreateSuccess": 1,
|
||||
"OnModelDelete": 1, // otp delete
|
||||
"OnModelDeleteExecute": 1,
|
||||
"OnModelAfterDeleteSuccess": 1,
|
||||
// ---
|
||||
"OnRecordValidate": 1,
|
||||
"OnRecordCreate": 1,
|
||||
"OnRecordCreateExecute": 1,
|
||||
"OnRecordAfterCreateSuccess": 1,
|
||||
"OnRecordDelete": 1,
|
||||
"OnRecordDeleteExecute": 1,
|
||||
"OnRecordAfterDeleteSuccess": 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "valid otp with valid password and empty sentTo (disabled MFA)",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/auth-with-otp",
|
||||
Body: strings.NewReader(`{
|
||||
"otpId":"` + strings.Repeat("a", 15) + `",
|
||||
"password":"123456"
|
||||
}`),
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
user, err := app.FindAuthRecordByEmail("users", "test@example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// ensure that the user is unverified
|
||||
user.SetVerified(false)
|
||||
if err = app.Save(user); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// disable MFA
|
||||
user.Collection().MFA.Enabled = false
|
||||
if err = app.Save(user.Collection()); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
otp := core.NewOTP(app)
|
||||
otp.Id = strings.Repeat("a", 15)
|
||||
otp.SetCollectionRef(user.Collection().Id)
|
||||
otp.SetRecordRef(user.Id)
|
||||
otp.SetPassword("123456")
|
||||
if err := app.Save(otp); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// test at least once that the correct request info context is properly loaded
|
||||
app.OnRecordAuthRequest().BindFunc(func(e *core.RecordAuthRequestEvent) error {
|
||||
info, err := e.RequestInfo()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if info.Context != core.RequestInfoContextOTP {
|
||||
t.Fatalf("Expected request context %q, got %q", core.RequestInfoContextOTP, info.Context)
|
||||
}
|
||||
|
||||
return e.Next()
|
||||
})
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
`"token":"`,
|
||||
`"record":{`,
|
||||
`"email":"test@example.com"`,
|
||||
},
|
||||
NotExpectedContent: []string{
|
||||
`"meta":`,
|
||||
// hidden fields
|
||||
`"tokenKey"`,
|
||||
`"password"`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordAuthWithOTPRequest": 1,
|
||||
"OnRecordAuthRequest": 1,
|
||||
"OnRecordEnrich": 1,
|
||||
// ---
|
||||
"OnModelValidate": 1,
|
||||
"OnModelCreate": 1, // authOrigin
|
||||
"OnModelCreateExecute": 1,
|
||||
"OnModelAfterCreateSuccess": 1,
|
||||
"OnModelDelete": 1, // otp delete
|
||||
"OnModelDeleteExecute": 1,
|
||||
"OnModelAfterDeleteSuccess": 1,
|
||||
// ---
|
||||
"OnRecordValidate": 1,
|
||||
"OnRecordCreate": 1,
|
||||
"OnRecordCreateExecute": 1,
|
||||
"OnRecordAfterCreateSuccess": 1,
|
||||
"OnRecordDelete": 1,
|
||||
"OnRecordDeleteExecute": 1,
|
||||
"OnRecordAfterDeleteSuccess": 1,
|
||||
},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
user, err := app.FindAuthRecordByEmail("users", "test@example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if user.Verified() {
|
||||
t.Fatal("Expected the user to remain unverified because sentTo != email")
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "valid otp with valid password and nonempty sentTo=email (disabled MFA)",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/auth-with-otp",
|
||||
Body: strings.NewReader(`{
|
||||
"otpId":"` + strings.Repeat("a", 15) + `",
|
||||
"password":"123456"
|
||||
}`),
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
user, err := app.FindAuthRecordByEmail("users", "test@example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// ensure that the user is unverified
|
||||
user.SetVerified(false)
|
||||
if err = app.Save(user); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// disable MFA
|
||||
user.Collection().MFA.Enabled = false
|
||||
if err = app.Save(user.Collection()); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
otp := core.NewOTP(app)
|
||||
otp.Id = strings.Repeat("a", 15)
|
||||
otp.SetCollectionRef(user.Collection().Id)
|
||||
otp.SetRecordRef(user.Id)
|
||||
otp.SetPassword("123456")
|
||||
otp.SetSentTo(user.Email())
|
||||
if err := app.Save(otp); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
`"token":"`,
|
||||
`"record":{`,
|
||||
`"email":"test@example.com"`,
|
||||
},
|
||||
NotExpectedContent: []string{
|
||||
`"meta":`,
|
||||
// hidden fields
|
||||
`"tokenKey"`,
|
||||
`"password"`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordAuthWithOTPRequest": 1,
|
||||
"OnRecordAuthRequest": 1,
|
||||
"OnRecordEnrich": 1,
|
||||
// ---
|
||||
"OnModelValidate": 2, // +1 because of the verified user update
|
||||
// authOrigin create
|
||||
"OnModelCreate": 1,
|
||||
"OnModelCreateExecute": 1,
|
||||
"OnModelAfterCreateSuccess": 1,
|
||||
// OTP delete
|
||||
"OnModelDelete": 1,
|
||||
"OnModelDeleteExecute": 1,
|
||||
"OnModelAfterDeleteSuccess": 1,
|
||||
// user verified update
|
||||
"OnModelUpdate": 1,
|
||||
"OnModelUpdateExecute": 1,
|
||||
"OnModelAfterUpdateSuccess": 1,
|
||||
// ---
|
||||
"OnRecordValidate": 2,
|
||||
"OnRecordCreate": 1,
|
||||
"OnRecordCreateExecute": 1,
|
||||
"OnRecordAfterCreateSuccess": 1,
|
||||
"OnRecordDelete": 1,
|
||||
"OnRecordDeleteExecute": 1,
|
||||
"OnRecordAfterDeleteSuccess": 1,
|
||||
"OnRecordUpdate": 1,
|
||||
"OnRecordUpdateExecute": 1,
|
||||
"OnRecordAfterUpdateSuccess": 1,
|
||||
},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
user, err := app.FindAuthRecordByEmail("users", "test@example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if !user.Verified() {
|
||||
t.Fatal("Expected the user to be marked as verified")
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "OnRecordAuthWithOTPRequest tx body write check",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/auth-with-otp",
|
||||
Body: strings.NewReader(`{
|
||||
"otpId":"` + strings.Repeat("a", 15) + `",
|
||||
"password":"123456"
|
||||
}`),
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
user, err := app.FindAuthRecordByEmail("users", "test@example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// disable MFA
|
||||
user.Collection().MFA.Enabled = false
|
||||
if err = app.Save(user.Collection()); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
otp := core.NewOTP(app)
|
||||
otp.Id = strings.Repeat("a", 15)
|
||||
otp.SetCollectionRef(user.Collection().Id)
|
||||
otp.SetRecordRef(user.Id)
|
||||
otp.SetPassword("123456")
|
||||
if err := app.Save(otp); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
app.OnRecordAuthWithOTPRequest().BindFunc(func(e *core.RecordAuthWithOTPRequestEvent) error {
|
||||
original := e.App
|
||||
return e.App.RunInTransaction(func(txApp core.App) error {
|
||||
e.App = txApp
|
||||
defer func() { e.App = original }()
|
||||
|
||||
if err := e.Next(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return e.BadRequestError("TX_ERROR", nil)
|
||||
})
|
||||
})
|
||||
},
|
||||
ExpectedStatus: 400,
|
||||
ExpectedEvents: map[string]int{"OnRecordAuthWithOTPRequest": 1},
|
||||
ExpectedContent: []string{"TX_ERROR"},
|
||||
},
|
||||
|
||||
// rate limit checks
|
||||
// -----------------------------------------------------------
|
||||
{
|
||||
Name: "RateLimit rule - users:authWithOTP",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/auth-with-otp",
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
app.Settings().RateLimits.Enabled = true
|
||||
app.Settings().RateLimits.Rules = []core.RateLimitRule{
|
||||
{MaxRequests: 100, Label: "abc"},
|
||||
{MaxRequests: 100, Label: "*:authWithOTP"},
|
||||
{MaxRequests: 100, Label: "users:auth"},
|
||||
{MaxRequests: 0, Label: "users:authWithOTP"},
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 429,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "RateLimit rule - *:authWithOTP",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/auth-with-otp",
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
app.Settings().RateLimits.Enabled = true
|
||||
app.Settings().RateLimits.Rules = []core.RateLimitRule{
|
||||
{MaxRequests: 100, Label: "abc"},
|
||||
{MaxRequests: 100, Label: "*:auth"},
|
||||
{MaxRequests: 0, Label: "*:authWithOTP"},
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 429,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "RateLimit rule - users:auth",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/auth-with-otp",
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
app.Settings().RateLimits.Enabled = true
|
||||
app.Settings().RateLimits.Rules = []core.RateLimitRule{
|
||||
{MaxRequests: 100, Label: "abc"},
|
||||
{MaxRequests: 100, Label: "*:authWithOTP"},
|
||||
{MaxRequests: 0, Label: "users:auth"},
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 429,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "RateLimit rule - *:auth",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/auth-with-otp",
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
app.Settings().RateLimits.Enabled = true
|
||||
app.Settings().RateLimits.Rules = []core.RateLimitRule{
|
||||
{MaxRequests: 100, Label: "abc"},
|
||||
{MaxRequests: 0, Label: "*:auth"},
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 429,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecordAuthWithOTPManualRateLimiterCheck(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var storeCache map[string]any
|
||||
|
||||
otpAId := strings.Repeat("a", 15)
|
||||
otpBId := strings.Repeat("b", 15)
|
||||
|
||||
scenarios := []struct {
|
||||
otpId string
|
||||
password string
|
||||
expectedStatus int
|
||||
}{
|
||||
{otpAId, "12345", 400},
|
||||
{otpAId, "12345", 400},
|
||||
{otpBId, "12345", 400},
|
||||
{otpBId, "12345", 400},
|
||||
{otpBId, "12345", 400},
|
||||
{otpAId, "12345", 429},
|
||||
{otpAId, "123456", 429}, // reject even if it is correct
|
||||
{otpAId, "123456", 429},
|
||||
{otpBId, "123456", 429},
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
(&tests.ApiScenario{
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/auth-with-otp",
|
||||
Body: strings.NewReader(`{
|
||||
"otpId":"` + s.otpId + `",
|
||||
"password":"` + s.password + `"
|
||||
}`),
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
for k, v := range storeCache {
|
||||
app.Store().Set(k, v)
|
||||
}
|
||||
|
||||
user, err := app.FindAuthRecordByEmail("users", "test@example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
user.Collection().MFA.Enabled = false
|
||||
if err := app.Save(user.Collection()); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
for _, id := range []string{otpAId, otpBId} {
|
||||
otp := core.NewOTP(app)
|
||||
otp.Id = id
|
||||
otp.SetCollectionRef(user.Collection().Id)
|
||||
otp.SetRecordRef(user.Id)
|
||||
otp.SetPassword("123456")
|
||||
if err := app.Save(otp); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
},
|
||||
ExpectedStatus: s.expectedStatus,
|
||||
ExpectedContent: []string{`"`}, // it doesn't matter anything non-empty
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
storeCache = app.Store().GetAll()
|
||||
},
|
||||
}).Test(t)
|
||||
}
|
||||
}
|
135
apis/record_auth_with_password.go
Normal file
135
apis/record_auth_with_password.go
Normal file
|
@ -0,0 +1,135 @@
|
|||
package apis
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
validation "github.com/go-ozzo/ozzo-validation/v4"
|
||||
"github.com/go-ozzo/ozzo-validation/v4/is"
|
||||
"github.com/pocketbase/dbx"
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tools/dbutils"
|
||||
"github.com/pocketbase/pocketbase/tools/list"
|
||||
)
|
||||
|
||||
func recordAuthWithPassword(e *core.RequestEvent) error {
|
||||
collection, err := findAuthCollection(e)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !collection.PasswordAuth.Enabled {
|
||||
return e.ForbiddenError("The collection is not configured to allow password authentication.", nil)
|
||||
}
|
||||
|
||||
form := &authWithPasswordForm{}
|
||||
if err = e.BindBody(form); err != nil {
|
||||
return firstApiError(err, e.BadRequestError("An error occurred while loading the submitted data.", err))
|
||||
}
|
||||
if err = form.validate(collection); err != nil {
|
||||
return firstApiError(err, e.BadRequestError("An error occurred while validating the submitted data.", err))
|
||||
}
|
||||
|
||||
e.Set(core.RequestEventKeyInfoContext, core.RequestInfoContextPasswordAuth)
|
||||
|
||||
var foundRecord *core.Record
|
||||
var foundErr error
|
||||
|
||||
if form.IdentityField != "" {
|
||||
foundRecord, foundErr = findRecordByIdentityField(e.App, collection, form.IdentityField, form.Identity)
|
||||
} else {
|
||||
// prioritize email lookup
|
||||
isEmail := is.EmailFormat.Validate(form.Identity) == nil
|
||||
if isEmail && list.ExistInSlice(core.FieldNameEmail, collection.PasswordAuth.IdentityFields) {
|
||||
foundRecord, foundErr = findRecordByIdentityField(e.App, collection, core.FieldNameEmail, form.Identity)
|
||||
}
|
||||
|
||||
// search by the other identity fields
|
||||
if !isEmail || foundErr != nil {
|
||||
for _, name := range collection.PasswordAuth.IdentityFields {
|
||||
if !isEmail && name == core.FieldNameEmail {
|
||||
continue // no need to search by the email field if it is not an email
|
||||
}
|
||||
|
||||
foundRecord, foundErr = findRecordByIdentityField(e.App, collection, name, form.Identity)
|
||||
if foundErr == nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ignore not found errors to allow custom record find implementations
|
||||
if foundErr != nil && !errors.Is(foundErr, sql.ErrNoRows) {
|
||||
return e.InternalServerError("", foundErr)
|
||||
}
|
||||
|
||||
event := new(core.RecordAuthWithPasswordRequestEvent)
|
||||
event.RequestEvent = e
|
||||
event.Collection = collection
|
||||
event.Record = foundRecord
|
||||
event.Identity = form.Identity
|
||||
event.Password = form.Password
|
||||
event.IdentityField = form.IdentityField
|
||||
|
||||
return e.App.OnRecordAuthWithPasswordRequest().Trigger(event, func(e *core.RecordAuthWithPasswordRequestEvent) error {
|
||||
if e.Record == nil || !e.Record.ValidatePassword(e.Password) {
|
||||
return e.BadRequestError("Failed to authenticate.", errors.New("invalid login credentials"))
|
||||
}
|
||||
|
||||
return RecordAuthResponse(e.RequestEvent, e.Record, core.MFAMethodPassword, nil)
|
||||
})
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
|
||||
type authWithPasswordForm struct {
|
||||
Identity string `form:"identity" json:"identity"`
|
||||
Password string `form:"password" json:"password"`
|
||||
|
||||
// IdentityField specifies the field to use to search for the identity
|
||||
// (leave it empty for "auto" detection).
|
||||
IdentityField string `form:"identityField" json:"identityField"`
|
||||
}
|
||||
|
||||
func (form *authWithPasswordForm) validate(collection *core.Collection) error {
|
||||
return validation.ValidateStruct(form,
|
||||
validation.Field(&form.Identity, validation.Required, validation.Length(1, 255)),
|
||||
validation.Field(&form.Password, validation.Required, validation.Length(1, 255)),
|
||||
validation.Field(
|
||||
&form.IdentityField,
|
||||
validation.Length(1, 255),
|
||||
validation.In(list.ToInterfaceSlice(collection.PasswordAuth.IdentityFields)...),
|
||||
),
|
||||
)
|
||||
}
|
||||
|
||||
func findRecordByIdentityField(app core.App, collection *core.Collection, field string, value any) (*core.Record, error) {
|
||||
if !slices.Contains(collection.PasswordAuth.IdentityFields, field) {
|
||||
return nil, errors.New("invalid identity field " + field)
|
||||
}
|
||||
|
||||
index, ok := dbutils.FindSingleColumnUniqueIndex(collection.Indexes, field)
|
||||
if !ok {
|
||||
return nil, errors.New("missing " + field + " unique index constraint")
|
||||
}
|
||||
|
||||
var expr dbx.Expression
|
||||
if strings.EqualFold(index.Columns[0].Collate, "nocase") {
|
||||
// case-insensitive search
|
||||
expr = dbx.NewExp("[["+field+"]] = {:identity} COLLATE NOCASE", dbx.Params{"identity": value})
|
||||
} else {
|
||||
expr = dbx.HashExp{field: value}
|
||||
}
|
||||
|
||||
record := &core.Record{}
|
||||
|
||||
err := app.RecordQuery(collection).AndWhere(expr).Limit(1).One(record)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return record, nil
|
||||
}
|
713
apis/record_auth_with_password_test.go
Normal file
713
apis/record_auth_with_password_test.go
Normal file
|
@ -0,0 +1,713 @@
|
|||
package apis_test
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tests"
|
||||
"github.com/pocketbase/pocketbase/tools/dbutils"
|
||||
)
|
||||
|
||||
func TestRecordAuthWithPassword(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
updateIdentityIndex := func(collectionIdOrName string, fieldCollateMap map[string]string) func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
return func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
collection, err := app.FindCollectionByNameOrId("clients")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
for column, collate := range fieldCollateMap {
|
||||
index, ok := dbutils.FindSingleColumnUniqueIndex(collection.Indexes, column)
|
||||
if !ok {
|
||||
t.Fatalf("Missing unique identityField index for column %q", column)
|
||||
}
|
||||
|
||||
index.Columns[0].Collate = collate
|
||||
|
||||
collection.RemoveIndex(index.IndexName)
|
||||
collection.Indexes = append(collection.Indexes, index.Build())
|
||||
}
|
||||
|
||||
err = app.Save(collection)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to update identityField index: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "disabled password auth",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/nologin/auth-with-password",
|
||||
Body: strings.NewReader(`{"identity":"test@example.com","password":"1234567890"}`),
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "non-auth collection",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/demo1/auth-with-password",
|
||||
Body: strings.NewReader(`{"identity":"test@example.com","password":"1234567890"}`),
|
||||
ExpectedStatus: 404,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "invalid body format",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/clients/auth-with-password",
|
||||
Body: strings.NewReader(`{"identity`),
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "empty body params",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/clients/auth-with-password",
|
||||
Body: strings.NewReader(`{"identity":"","password":""}`),
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{
|
||||
`"data":{`,
|
||||
`"identity":{`,
|
||||
`"password":{`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "OnRecordAuthWithPasswordRequest tx body write check",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/clients/auth-with-password",
|
||||
Body: strings.NewReader(`{
|
||||
"identity":"test@example.com",
|
||||
"password":"1234567890"
|
||||
}`),
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
app.OnRecordAuthWithPasswordRequest().BindFunc(func(e *core.RecordAuthWithPasswordRequestEvent) error {
|
||||
original := e.App
|
||||
return e.App.RunInTransaction(func(txApp core.App) error {
|
||||
e.App = txApp
|
||||
defer func() { e.App = original }()
|
||||
|
||||
if err := e.Next(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return e.BadRequestError("TX_ERROR", nil)
|
||||
})
|
||||
})
|
||||
},
|
||||
ExpectedStatus: 400,
|
||||
ExpectedEvents: map[string]int{"OnRecordAuthWithPasswordRequest": 1},
|
||||
ExpectedContent: []string{"TX_ERROR"},
|
||||
},
|
||||
{
|
||||
Name: "valid identity field and invalid password",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/clients/auth-with-password",
|
||||
Body: strings.NewReader(`{
|
||||
"identity":"test@example.com",
|
||||
"password":"invalid"
|
||||
}`),
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{
|
||||
`"data":{}`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordAuthWithPasswordRequest": 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "valid identity field (email) and valid password",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/clients/auth-with-password",
|
||||
Body: strings.NewReader(`{
|
||||
"identity":"test@example.com",
|
||||
"password":"1234567890"
|
||||
}`),
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
// test at least once that the correct request info context is properly loaded
|
||||
app.OnRecordAuthRequest().BindFunc(func(e *core.RecordAuthRequestEvent) error {
|
||||
info, err := e.RequestInfo()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if info.Context != core.RequestInfoContextPasswordAuth {
|
||||
t.Fatalf("Expected request context %q, got %q", core.RequestInfoContextPasswordAuth, info.Context)
|
||||
}
|
||||
|
||||
return e.Next()
|
||||
})
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
`"email":"test@example.com"`,
|
||||
`"token":`,
|
||||
},
|
||||
NotExpectedContent: []string{
|
||||
// hidden fields
|
||||
`"tokenKey"`,
|
||||
`"password"`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordAuthWithPasswordRequest": 1,
|
||||
"OnRecordAuthRequest": 1,
|
||||
"OnRecordEnrich": 1,
|
||||
// authOrigin track
|
||||
"OnModelCreate": 1,
|
||||
"OnModelCreateExecute": 1,
|
||||
"OnModelAfterCreateSuccess": 1,
|
||||
"OnModelValidate": 1,
|
||||
"OnRecordCreate": 1,
|
||||
"OnRecordCreateExecute": 1,
|
||||
"OnRecordAfterCreateSuccess": 1,
|
||||
"OnRecordValidate": 1,
|
||||
"OnMailerSend": 1,
|
||||
"OnMailerRecordAuthAlertSend": 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "valid identity field (username) and valid password",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/clients/auth-with-password",
|
||||
Body: strings.NewReader(`{
|
||||
"identity":"clients57772",
|
||||
"password":"1234567890"
|
||||
}`),
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
`"email":"test@example.com"`,
|
||||
`"username":"clients57772"`,
|
||||
`"token":`,
|
||||
},
|
||||
NotExpectedContent: []string{
|
||||
// hidden fields
|
||||
`"tokenKey"`,
|
||||
`"password"`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordAuthWithPasswordRequest": 1,
|
||||
"OnRecordAuthRequest": 1,
|
||||
"OnRecordEnrich": 1,
|
||||
// authOrigin track
|
||||
"OnModelCreate": 1,
|
||||
"OnModelCreateExecute": 1,
|
||||
"OnModelAfterCreateSuccess": 1,
|
||||
"OnModelValidate": 1,
|
||||
"OnRecordCreate": 1,
|
||||
"OnRecordCreateExecute": 1,
|
||||
"OnRecordAfterCreateSuccess": 1,
|
||||
"OnRecordValidate": 1,
|
||||
"OnMailerSend": 1,
|
||||
"OnMailerRecordAuthAlertSend": 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "unknown explicit identityField",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/clients/auth-with-password",
|
||||
Body: strings.NewReader(`{
|
||||
"identityField": "created",
|
||||
"identity":"test@example.com",
|
||||
"password":"1234567890"
|
||||
}`),
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{
|
||||
`"data":{`,
|
||||
`"identityField":{"code":"validation_in_invalid"`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "valid identity field and valid password with mismatched explicit identityField",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/clients/auth-with-password",
|
||||
Body: strings.NewReader(`{
|
||||
"identityField": "username",
|
||||
"identity":"test@example.com",
|
||||
"password":"1234567890"
|
||||
}`),
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordAuthWithPasswordRequest": 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "valid identity field and valid password with matched explicit identityField",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/clients/auth-with-password",
|
||||
Body: strings.NewReader(`{
|
||||
"identityField": "username",
|
||||
"identity":"clients57772",
|
||||
"password":"1234567890"
|
||||
}`),
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
`"email":"test@example.com"`,
|
||||
`"username":"clients57772"`,
|
||||
`"token":`,
|
||||
},
|
||||
NotExpectedContent: []string{
|
||||
// hidden fields
|
||||
`"tokenKey"`,
|
||||
`"password"`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordAuthWithPasswordRequest": 1,
|
||||
"OnRecordAuthRequest": 1,
|
||||
"OnRecordEnrich": 1,
|
||||
// authOrigin track
|
||||
"OnModelCreate": 1,
|
||||
"OnModelCreateExecute": 1,
|
||||
"OnModelAfterCreateSuccess": 1,
|
||||
"OnModelValidate": 1,
|
||||
"OnRecordCreate": 1,
|
||||
"OnRecordCreateExecute": 1,
|
||||
"OnRecordAfterCreateSuccess": 1,
|
||||
"OnRecordValidate": 1,
|
||||
"OnMailerSend": 1,
|
||||
"OnMailerRecordAuthAlertSend": 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "valid identity (unverified) and valid password in onlyVerified collection",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/clients/auth-with-password",
|
||||
Body: strings.NewReader(`{
|
||||
"identity":"test2@example.com",
|
||||
"password":"1234567890"
|
||||
}`),
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordAuthWithPasswordRequest": 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "already authenticated record",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/clients/auth-with-password",
|
||||
Body: strings.NewReader(`{
|
||||
"identity":"test@example.com",
|
||||
"password":"1234567890"
|
||||
}`),
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
`"id":"gk390qegs4y47wn"`,
|
||||
`"email":"test@example.com"`,
|
||||
`"token":`,
|
||||
},
|
||||
NotExpectedContent: []string{
|
||||
// hidden fields
|
||||
`"tokenKey"`,
|
||||
`"password"`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordAuthWithPasswordRequest": 1,
|
||||
"OnRecordAuthRequest": 1,
|
||||
"OnRecordEnrich": 1,
|
||||
// authOrigin track
|
||||
"OnModelCreate": 1,
|
||||
"OnModelCreateExecute": 1,
|
||||
"OnModelAfterCreateSuccess": 1,
|
||||
"OnModelValidate": 1,
|
||||
"OnRecordCreate": 1,
|
||||
"OnRecordCreateExecute": 1,
|
||||
"OnRecordAfterCreateSuccess": 1,
|
||||
"OnRecordValidate": 1,
|
||||
"OnMailerSend": 1,
|
||||
"OnMailerRecordAuthAlertSend": 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "with mfa first auth check",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/auth-with-password",
|
||||
Body: strings.NewReader(`{
|
||||
"identity":"test@example.com",
|
||||
"password":"1234567890"
|
||||
}`),
|
||||
ExpectedStatus: 401,
|
||||
ExpectedContent: []string{
|
||||
`"mfaId":"`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordAuthWithPasswordRequest": 1,
|
||||
"OnRecordAuthRequest": 1,
|
||||
// mfa create
|
||||
"OnModelCreate": 1,
|
||||
"OnModelCreateExecute": 1,
|
||||
"OnModelAfterCreateSuccess": 1,
|
||||
"OnModelValidate": 1,
|
||||
"OnRecordCreate": 1,
|
||||
"OnRecordCreateExecute": 1,
|
||||
"OnRecordAfterCreateSuccess": 1,
|
||||
"OnRecordValidate": 1,
|
||||
},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
user, err := app.FindAuthRecordByEmail("users", "test@example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
mfas, err := app.FindAllMFAsByRecord(user)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if v := len(mfas); v != 1 {
|
||||
t.Fatalf("Expected 1 mfa record to be created, got %d", v)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "with mfa second auth check",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/auth-with-password",
|
||||
Body: strings.NewReader(`{
|
||||
"mfaId": "` + strings.Repeat("a", 15) + `",
|
||||
"identity":"test@example.com",
|
||||
"password":"1234567890"
|
||||
}`),
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
user, err := app.FindAuthRecordByEmail("users", "test@example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// insert a dummy mfa record
|
||||
mfa := core.NewMFA(app)
|
||||
mfa.Id = strings.Repeat("a", 15)
|
||||
mfa.SetCollectionRef(user.Collection().Id)
|
||||
mfa.SetRecordRef(user.Id)
|
||||
mfa.SetMethod("test")
|
||||
if err := app.Save(mfa); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
`"email":"test@example.com"`,
|
||||
`"token":`,
|
||||
},
|
||||
NotExpectedContent: []string{
|
||||
// hidden fields
|
||||
`"tokenKey"`,
|
||||
`"password"`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordAuthWithPasswordRequest": 1,
|
||||
"OnRecordAuthRequest": 1,
|
||||
"OnRecordEnrich": 1,
|
||||
// authOrigin track
|
||||
"OnModelCreate": 1,
|
||||
"OnModelCreateExecute": 1,
|
||||
"OnModelAfterCreateSuccess": 1,
|
||||
"OnModelValidate": 1,
|
||||
"OnRecordCreate": 1,
|
||||
"OnRecordCreateExecute": 1,
|
||||
"OnRecordAfterCreateSuccess": 1,
|
||||
"OnRecordValidate": 1,
|
||||
"OnMailerSend": 0, // disabled auth email alerts
|
||||
"OnMailerRecordAuthAlertSend": 0,
|
||||
// mfa delete
|
||||
"OnModelDelete": 1,
|
||||
"OnModelDeleteExecute": 1,
|
||||
"OnModelAfterDeleteSuccess": 1,
|
||||
"OnRecordDelete": 1,
|
||||
"OnRecordDeleteExecute": 1,
|
||||
"OnRecordAfterDeleteSuccess": 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "with enabled mfa but unsatisfied mfa rule (aka. skip the mfa check)",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/auth-with-password",
|
||||
Body: strings.NewReader(`{
|
||||
"identity":"test@example.com",
|
||||
"password":"1234567890"
|
||||
}`),
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
users, err := app.FindCollectionByNameOrId("users")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
users.MFA.Enabled = true
|
||||
users.MFA.Rule = "1=2"
|
||||
|
||||
if err := app.Save(users); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
`"email":"test@example.com"`,
|
||||
`"token":`,
|
||||
},
|
||||
NotExpectedContent: []string{
|
||||
// hidden fields
|
||||
`"tokenKey"`,
|
||||
`"password"`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordAuthWithPasswordRequest": 1,
|
||||
"OnRecordAuthRequest": 1,
|
||||
"OnRecordEnrich": 1,
|
||||
// authOrigin track
|
||||
"OnModelCreate": 1,
|
||||
"OnModelCreateExecute": 1,
|
||||
"OnModelAfterCreateSuccess": 1,
|
||||
"OnModelValidate": 1,
|
||||
"OnRecordCreate": 1,
|
||||
"OnRecordCreateExecute": 1,
|
||||
"OnRecordAfterCreateSuccess": 1,
|
||||
"OnRecordValidate": 1,
|
||||
"OnMailerSend": 0, // disabled auth email alerts
|
||||
"OnMailerRecordAuthAlertSend": 0,
|
||||
},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
user, err := app.FindAuthRecordByEmail("users", "test@example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
mfas, err := app.FindAllMFAsByRecord(user)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if v := len(mfas); v != 0 {
|
||||
t.Fatalf("Expected no mfa records to be created, got %d", v)
|
||||
}
|
||||
},
|
||||
},
|
||||
|
||||
// case sensitivity checks
|
||||
// -----------------------------------------------------------
|
||||
{
|
||||
Name: "with explicit identityField (case-sensitive)",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/clients/auth-with-password",
|
||||
Body: strings.NewReader(`{
|
||||
"identityField": "username",
|
||||
"identity":"Clients57772",
|
||||
"password":"1234567890"
|
||||
}`),
|
||||
BeforeTestFunc: updateIdentityIndex("clients", map[string]string{"username": ""}),
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordAuthWithPasswordRequest": 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "with explicit identityField (case-insensitive)",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/clients/auth-with-password",
|
||||
Body: strings.NewReader(`{
|
||||
"identityField": "username",
|
||||
"identity":"Clients57772",
|
||||
"password":"1234567890"
|
||||
}`),
|
||||
BeforeTestFunc: updateIdentityIndex("clients", map[string]string{"username": "nocase"}),
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
`"email":"test@example.com"`,
|
||||
`"username":"clients57772"`,
|
||||
`"token":`,
|
||||
},
|
||||
NotExpectedContent: []string{
|
||||
// hidden fields
|
||||
`"tokenKey"`,
|
||||
`"password"`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordAuthWithPasswordRequest": 1,
|
||||
"OnRecordAuthRequest": 1,
|
||||
"OnRecordEnrich": 1,
|
||||
// authOrigin track
|
||||
"OnModelCreate": 1,
|
||||
"OnModelCreateExecute": 1,
|
||||
"OnModelAfterCreateSuccess": 1,
|
||||
"OnModelValidate": 1,
|
||||
"OnRecordCreate": 1,
|
||||
"OnRecordCreateExecute": 1,
|
||||
"OnRecordAfterCreateSuccess": 1,
|
||||
"OnRecordValidate": 1,
|
||||
"OnMailerSend": 1,
|
||||
"OnMailerRecordAuthAlertSend": 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "without explicit identityField and non-email field (case-insensitive)",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/clients/auth-with-password",
|
||||
Body: strings.NewReader(`{
|
||||
"identity":"Clients57772",
|
||||
"password":"1234567890"
|
||||
}`),
|
||||
BeforeTestFunc: updateIdentityIndex("clients", map[string]string{"username": "nocase"}),
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
`"email":"test@example.com"`,
|
||||
`"username":"clients57772"`,
|
||||
`"token":`,
|
||||
},
|
||||
NotExpectedContent: []string{
|
||||
// hidden fields
|
||||
`"tokenKey"`,
|
||||
`"password"`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordAuthWithPasswordRequest": 1,
|
||||
"OnRecordAuthRequest": 1,
|
||||
"OnRecordEnrich": 1,
|
||||
// authOrigin track
|
||||
"OnModelCreate": 1,
|
||||
"OnModelCreateExecute": 1,
|
||||
"OnModelAfterCreateSuccess": 1,
|
||||
"OnModelValidate": 1,
|
||||
"OnRecordCreate": 1,
|
||||
"OnRecordCreateExecute": 1,
|
||||
"OnRecordAfterCreateSuccess": 1,
|
||||
"OnRecordValidate": 1,
|
||||
"OnMailerSend": 1,
|
||||
"OnMailerRecordAuthAlertSend": 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "without explicit identityField and email field (case-insensitive)",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/clients/auth-with-password",
|
||||
Body: strings.NewReader(`{
|
||||
"identity":"tESt@example.com",
|
||||
"password":"1234567890"
|
||||
}`),
|
||||
BeforeTestFunc: updateIdentityIndex("clients", map[string]string{"email": "nocase"}),
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
`"email":"test@example.com"`,
|
||||
`"username":"clients57772"`,
|
||||
`"token":`,
|
||||
},
|
||||
NotExpectedContent: []string{
|
||||
// hidden fields
|
||||
`"tokenKey"`,
|
||||
`"password"`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordAuthWithPasswordRequest": 1,
|
||||
"OnRecordAuthRequest": 1,
|
||||
"OnRecordEnrich": 1,
|
||||
// authOrigin track
|
||||
"OnModelCreate": 1,
|
||||
"OnModelCreateExecute": 1,
|
||||
"OnModelAfterCreateSuccess": 1,
|
||||
"OnModelValidate": 1,
|
||||
"OnRecordCreate": 1,
|
||||
"OnRecordCreateExecute": 1,
|
||||
"OnRecordAfterCreateSuccess": 1,
|
||||
"OnRecordValidate": 1,
|
||||
"OnMailerSend": 1,
|
||||
"OnMailerRecordAuthAlertSend": 1,
|
||||
},
|
||||
},
|
||||
|
||||
// rate limit checks
|
||||
// -----------------------------------------------------------
|
||||
{
|
||||
Name: "RateLimit rule - users:authWithPassword",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/auth-with-password",
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
app.Settings().RateLimits.Enabled = true
|
||||
app.Settings().RateLimits.Rules = []core.RateLimitRule{
|
||||
{MaxRequests: 100, Label: "abc"},
|
||||
{MaxRequests: 100, Label: "*:authWithPassword"},
|
||||
{MaxRequests: 100, Label: "users:auth"},
|
||||
{MaxRequests: 0, Label: "users:authWithPassword"},
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 429,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "RateLimit rule - *:authWithPassword",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/auth-with-password",
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
app.Settings().RateLimits.Enabled = true
|
||||
app.Settings().RateLimits.Rules = []core.RateLimitRule{
|
||||
{MaxRequests: 100, Label: "abc"},
|
||||
{MaxRequests: 100, Label: "*:auth"},
|
||||
{MaxRequests: 0, Label: "*:authWithPassword"},
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 429,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "RateLimit rule - users:auth",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/auth-with-password",
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
app.Settings().RateLimits.Enabled = true
|
||||
app.Settings().RateLimits.Rules = []core.RateLimitRule{
|
||||
{MaxRequests: 100, Label: "abc"},
|
||||
{MaxRequests: 100, Label: "*:authWithPassword"},
|
||||
{MaxRequests: 0, Label: "users:auth"},
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 429,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "RateLimit rule - *:auth",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/auth-with-password",
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
app.Settings().RateLimits.Enabled = true
|
||||
app.Settings().RateLimits.Rules = []core.RateLimitRule{
|
||||
{MaxRequests: 100, Label: "abc"},
|
||||
{MaxRequests: 0, Label: "*:auth"},
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 429,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
742
apis/record_crud.go
Normal file
742
apis/record_crud.go
Normal file
|
@ -0,0 +1,742 @@
|
|||
package apis
|
||||
|
||||
import (
|
||||
cryptoRand "crypto/rand"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/pocketbase/dbx"
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/forms"
|
||||
"github.com/pocketbase/pocketbase/tools/filesystem"
|
||||
"github.com/pocketbase/pocketbase/tools/inflector"
|
||||
"github.com/pocketbase/pocketbase/tools/list"
|
||||
"github.com/pocketbase/pocketbase/tools/router"
|
||||
"github.com/pocketbase/pocketbase/tools/search"
|
||||
"github.com/pocketbase/pocketbase/tools/security"
|
||||
)
|
||||
|
||||
// bindRecordCrudApi registers the record crud api endpoints and
|
||||
// the corresponding handlers.
|
||||
//
|
||||
// note: the rate limiter is "inlined" because some of the crud actions are also used in the batch APIs
|
||||
func bindRecordCrudApi(app core.App, rg *router.RouterGroup[*core.RequestEvent]) {
|
||||
subGroup := rg.Group("/collections/{collection}/records").Unbind(DefaultRateLimitMiddlewareId)
|
||||
subGroup.GET("", recordsList)
|
||||
subGroup.GET("/{id}", recordView)
|
||||
subGroup.POST("", recordCreate(true, nil)).Bind(dynamicCollectionBodyLimit(""))
|
||||
subGroup.PATCH("/{id}", recordUpdate(true, nil)).Bind(dynamicCollectionBodyLimit(""))
|
||||
subGroup.DELETE("/{id}", recordDelete(true, nil))
|
||||
}
|
||||
|
||||
func recordsList(e *core.RequestEvent) error {
|
||||
collection, err := e.App.FindCachedCollectionByNameOrId(e.Request.PathValue("collection"))
|
||||
if err != nil || collection == nil {
|
||||
return e.NotFoundError("Missing collection context.", err)
|
||||
}
|
||||
|
||||
err = checkCollectionRateLimit(e, collection, "list")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
requestInfo, err := e.RequestInfo()
|
||||
if err != nil {
|
||||
return firstApiError(err, e.BadRequestError("", err))
|
||||
}
|
||||
|
||||
if collection.ListRule == nil && !requestInfo.HasSuperuserAuth() {
|
||||
return e.ForbiddenError("Only superusers can perform this action.", nil)
|
||||
}
|
||||
|
||||
// forbid users and guests to query special filter/sort fields
|
||||
err = checkForSuperuserOnlyRuleFields(requestInfo)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
query := e.App.RecordQuery(collection)
|
||||
|
||||
fieldsResolver := core.NewRecordFieldResolver(e.App, collection, requestInfo, true)
|
||||
|
||||
if !requestInfo.HasSuperuserAuth() && collection.ListRule != nil && *collection.ListRule != "" {
|
||||
expr, err := search.FilterData(*collection.ListRule).BuildExpr(fieldsResolver)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
query.AndWhere(expr)
|
||||
|
||||
// will be applied by the search provider right before executing the query
|
||||
// fieldsResolver.UpdateQuery(query)
|
||||
}
|
||||
|
||||
// hidden fields are searchable only by superusers
|
||||
fieldsResolver.SetAllowHiddenFields(requestInfo.HasSuperuserAuth())
|
||||
|
||||
searchProvider := search.NewProvider(fieldsResolver).Query(query)
|
||||
|
||||
// use rowid when available to minimize the need of a covering index with the "id" field
|
||||
if !collection.IsView() {
|
||||
searchProvider.CountCol("_rowid_")
|
||||
}
|
||||
|
||||
records := []*core.Record{}
|
||||
result, err := searchProvider.ParseAndExec(e.Request.URL.Query().Encode(), &records)
|
||||
if err != nil {
|
||||
return firstApiError(err, e.BadRequestError("", err))
|
||||
}
|
||||
|
||||
event := new(core.RecordsListRequestEvent)
|
||||
event.RequestEvent = e
|
||||
event.Collection = collection
|
||||
event.Records = records
|
||||
event.Result = result
|
||||
|
||||
return e.App.OnRecordsListRequest().Trigger(event, func(e *core.RecordsListRequestEvent) error {
|
||||
if err := EnrichRecords(e.RequestEvent, e.Records); err != nil {
|
||||
return firstApiError(err, e.InternalServerError("Failed to enrich records", err))
|
||||
}
|
||||
|
||||
// Add a randomized throttle in case of too many empty search filter attempts.
|
||||
//
|
||||
// This is just for extra precaution since security researches raised concern regarding the possibility of eventual
|
||||
// timing attacks because the List API rule acts also as filter and executes in a single run with the client-side filters.
|
||||
// This is by design and it is an accepted trade off between performance, usability and correctness.
|
||||
//
|
||||
// While technically the below doesn't fully guarantee protection against filter timing attacks, in practice combined with the network latency it makes them even less feasible.
|
||||
// A properly configured rate limiter or individual fields Hidden checks are better suited if you are really concerned about eventual information disclosure by side-channel attacks.
|
||||
//
|
||||
// In all cases it doesn't really matter that much because it doesn't affect the builtin PocketBase security sensitive fields (e.g. password and tokenKey) since they
|
||||
// are not client-side filterable and in the few places where they need to be compared against an external value, a constant time check is used.
|
||||
if !e.HasSuperuserAuth() &&
|
||||
(collection.ListRule != nil && *collection.ListRule != "") &&
|
||||
(requestInfo.Query["filter"] != "") &&
|
||||
len(e.Records) == 0 &&
|
||||
checkRateLimit(e.RequestEvent, "@pb_list_timing_check_"+collection.Id, listTimingRateLimitRule) != nil {
|
||||
e.App.Logger().Debug("Randomized throttle because of too many failed searches", "collectionId", collection.Id)
|
||||
randomizedThrottle(150)
|
||||
}
|
||||
|
||||
return execAfterSuccessTx(true, e.App, func() error {
|
||||
return e.JSON(http.StatusOK, e.Result)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
var listTimingRateLimitRule = core.RateLimitRule{MaxRequests: 3, Duration: 3}
|
||||
|
||||
func randomizedThrottle(softMax int64) {
|
||||
var timeout int64
|
||||
randRange, err := cryptoRand.Int(cryptoRand.Reader, big.NewInt(softMax))
|
||||
if err == nil {
|
||||
timeout = randRange.Int64()
|
||||
} else {
|
||||
timeout = softMax
|
||||
}
|
||||
|
||||
time.Sleep(time.Duration(timeout) * time.Millisecond)
|
||||
}
|
||||
|
||||
func recordView(e *core.RequestEvent) error {
|
||||
collection, err := e.App.FindCachedCollectionByNameOrId(e.Request.PathValue("collection"))
|
||||
if err != nil || collection == nil {
|
||||
return e.NotFoundError("Missing collection context.", err)
|
||||
}
|
||||
|
||||
err = checkCollectionRateLimit(e, collection, "view")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
recordId := e.Request.PathValue("id")
|
||||
if recordId == "" {
|
||||
return e.NotFoundError("", nil)
|
||||
}
|
||||
|
||||
requestInfo, err := e.RequestInfo()
|
||||
if err != nil {
|
||||
return firstApiError(err, e.BadRequestError("", err))
|
||||
}
|
||||
|
||||
if collection.ViewRule == nil && !requestInfo.HasSuperuserAuth() {
|
||||
return e.ForbiddenError("Only superusers can perform this action.", nil)
|
||||
}
|
||||
|
||||
ruleFunc := func(q *dbx.SelectQuery) error {
|
||||
if !requestInfo.HasSuperuserAuth() && collection.ViewRule != nil && *collection.ViewRule != "" {
|
||||
resolver := core.NewRecordFieldResolver(e.App, collection, requestInfo, true)
|
||||
expr, err := search.FilterData(*collection.ViewRule).BuildExpr(resolver)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
resolver.UpdateQuery(q)
|
||||
q.AndWhere(expr)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
record, fetchErr := e.App.FindRecordById(collection, recordId, ruleFunc)
|
||||
if fetchErr != nil || record == nil {
|
||||
return firstApiError(err, e.NotFoundError("", fetchErr))
|
||||
}
|
||||
|
||||
event := new(core.RecordRequestEvent)
|
||||
event.RequestEvent = e
|
||||
event.Collection = collection
|
||||
event.Record = record
|
||||
|
||||
return e.App.OnRecordViewRequest().Trigger(event, func(e *core.RecordRequestEvent) error {
|
||||
if err := EnrichRecord(e.RequestEvent, e.Record); err != nil {
|
||||
return firstApiError(err, e.InternalServerError("Failed to enrich record", err))
|
||||
}
|
||||
|
||||
return execAfterSuccessTx(true, e.App, func() error {
|
||||
return e.JSON(http.StatusOK, e.Record)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func recordCreate(responseWriteAfterTx bool, optFinalizer func(data any) error) func(e *core.RequestEvent) error {
|
||||
return func(e *core.RequestEvent) error {
|
||||
collection, err := e.App.FindCachedCollectionByNameOrId(e.Request.PathValue("collection"))
|
||||
if err != nil || collection == nil {
|
||||
return e.NotFoundError("Missing collection context.", err)
|
||||
}
|
||||
|
||||
if collection.IsView() {
|
||||
return e.BadRequestError("Unsupported collection type.", nil)
|
||||
}
|
||||
|
||||
err = checkCollectionRateLimit(e, collection, "create")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
requestInfo, err := e.RequestInfo()
|
||||
if err != nil {
|
||||
return firstApiError(err, e.BadRequestError("", err))
|
||||
}
|
||||
|
||||
hasSuperuserAuth := requestInfo.HasSuperuserAuth()
|
||||
if !hasSuperuserAuth && collection.CreateRule == nil {
|
||||
return e.ForbiddenError("Only superusers can perform this action.", nil)
|
||||
}
|
||||
|
||||
record := core.NewRecord(collection)
|
||||
|
||||
data, err := recordDataFromRequest(e, record)
|
||||
if err != nil {
|
||||
return firstApiError(err, e.BadRequestError("Failed to read the submitted data.", err))
|
||||
}
|
||||
|
||||
// set a random password for the OAuth2 ignoring its plain password validators
|
||||
var skipPlainPasswordRecordValidators bool
|
||||
if requestInfo.Context == core.RequestInfoContextOAuth2 {
|
||||
if _, ok := data[core.FieldNamePassword]; !ok {
|
||||
data[core.FieldNamePassword] = security.RandomString(30)
|
||||
data[core.FieldNamePassword+"Confirm"] = data[core.FieldNamePassword]
|
||||
skipPlainPasswordRecordValidators = true
|
||||
}
|
||||
}
|
||||
|
||||
// replace modifiers fields so that the resolved value is always
|
||||
// available when accessing requestInfo.Body
|
||||
requestInfo.Body = data
|
||||
|
||||
form := forms.NewRecordUpsert(e.App, record)
|
||||
if hasSuperuserAuth {
|
||||
form.GrantSuperuserAccess()
|
||||
}
|
||||
form.Load(data)
|
||||
|
||||
if skipPlainPasswordRecordValidators {
|
||||
// unset the plain value to skip the plain password field validators
|
||||
if raw, ok := record.GetRaw(core.FieldNamePassword).(*core.PasswordFieldValue); ok {
|
||||
raw.Plain = ""
|
||||
}
|
||||
}
|
||||
|
||||
// check the request and record data against the create and manage rules
|
||||
if !hasSuperuserAuth && collection.CreateRule != nil {
|
||||
dummyRecord := record.Clone()
|
||||
|
||||
dummyRandomPart := "__pb_create__" + security.PseudorandomString(6)
|
||||
|
||||
// set an id if it doesn't have already
|
||||
// (the value doesn't matter; it is used only to minimize the breaking changes with earlier versions)
|
||||
if dummyRecord.Id == "" {
|
||||
dummyRecord.Id = "__temp_id__" + dummyRandomPart
|
||||
}
|
||||
|
||||
// unset the verified field to prevent manage API rule misuse in case the rule relies on it
|
||||
dummyRecord.SetVerified(false)
|
||||
|
||||
// export the dummy record data into db params
|
||||
dummyExport, err := dummyRecord.DBExport(e.App)
|
||||
if err != nil {
|
||||
return e.BadRequestError("Failed to create record", fmt.Errorf("dummy DBExport error: %w", err))
|
||||
}
|
||||
|
||||
dummyParams := make(dbx.Params, len(dummyExport))
|
||||
selects := make([]string, 0, len(dummyExport))
|
||||
var param string
|
||||
for k, v := range dummyExport {
|
||||
k = inflector.Columnify(k) // columnify is just as extra measure in case of custom fields
|
||||
param = "__pb_create__" + k
|
||||
dummyParams[param] = v
|
||||
selects = append(selects, "{:"+param+"} AS [["+k+"]]")
|
||||
}
|
||||
|
||||
// shallow clone the current collection
|
||||
dummyCollection := *collection
|
||||
dummyCollection.Id += dummyRandomPart
|
||||
dummyCollection.Name += inflector.Columnify(dummyRandomPart)
|
||||
|
||||
withFrom := fmt.Sprintf("WITH {{%s}} as (SELECT %s)", dummyCollection.Name, strings.Join(selects, ","))
|
||||
|
||||
// check non-empty create rule
|
||||
if *dummyCollection.CreateRule != "" {
|
||||
ruleQuery := e.App.ConcurrentDB().Select("(1)").PreFragment(withFrom).From(dummyCollection.Name).AndBind(dummyParams)
|
||||
|
||||
resolver := core.NewRecordFieldResolver(e.App, &dummyCollection, requestInfo, true)
|
||||
|
||||
expr, err := search.FilterData(*dummyCollection.CreateRule).BuildExpr(resolver)
|
||||
if err != nil {
|
||||
return e.BadRequestError("Failed to create record", fmt.Errorf("create rule build expression failure: %w", err))
|
||||
}
|
||||
ruleQuery.AndWhere(expr)
|
||||
|
||||
resolver.UpdateQuery(ruleQuery)
|
||||
|
||||
var exists int
|
||||
err = ruleQuery.Limit(1).Row(&exists)
|
||||
if err != nil || exists == 0 {
|
||||
return e.BadRequestError("Failed to create record", fmt.Errorf("create rule failure: %w", err))
|
||||
}
|
||||
}
|
||||
|
||||
// check for manage rule access
|
||||
manageRuleQuery := e.App.ConcurrentDB().Select("(1)").PreFragment(withFrom).From(dummyCollection.Name).AndBind(dummyParams)
|
||||
if !form.HasManageAccess() &&
|
||||
hasAuthManageAccess(e.App, requestInfo, &dummyCollection, manageRuleQuery) {
|
||||
form.GrantManagerAccess()
|
||||
}
|
||||
}
|
||||
|
||||
var isOptFinalizerCalled bool
|
||||
|
||||
event := new(core.RecordRequestEvent)
|
||||
event.RequestEvent = e
|
||||
event.Collection = collection
|
||||
event.Record = record
|
||||
|
||||
hookErr := e.App.OnRecordCreateRequest().Trigger(event, func(e *core.RecordRequestEvent) error {
|
||||
form.SetApp(e.App)
|
||||
form.SetRecord(e.Record)
|
||||
|
||||
err := form.Submit()
|
||||
if err != nil {
|
||||
return firstApiError(err, e.BadRequestError("Failed to create record", err))
|
||||
}
|
||||
|
||||
err = EnrichRecord(e.RequestEvent, e.Record)
|
||||
if err != nil {
|
||||
return firstApiError(err, e.InternalServerError("Failed to enrich record", err))
|
||||
}
|
||||
|
||||
err = execAfterSuccessTx(responseWriteAfterTx, e.App, func() error {
|
||||
return e.JSON(http.StatusOK, e.Record)
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if optFinalizer != nil {
|
||||
isOptFinalizerCalled = true
|
||||
err = optFinalizer(e.Record)
|
||||
if err != nil {
|
||||
return firstApiError(err, e.InternalServerError("", err))
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
if hookErr != nil {
|
||||
return hookErr
|
||||
}
|
||||
|
||||
// e.g. in case the regular hook chain was stopped and the finalizer cannot be executed as part of the last e.Next() task
|
||||
if !isOptFinalizerCalled && optFinalizer != nil {
|
||||
if err := optFinalizer(event.Record); err != nil {
|
||||
return firstApiError(err, e.InternalServerError("", err))
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func recordUpdate(responseWriteAfterTx bool, optFinalizer func(data any) error) func(e *core.RequestEvent) error {
|
||||
return func(e *core.RequestEvent) error {
|
||||
collection, err := e.App.FindCachedCollectionByNameOrId(e.Request.PathValue("collection"))
|
||||
if err != nil || collection == nil {
|
||||
return e.NotFoundError("Missing collection context.", err)
|
||||
}
|
||||
|
||||
if collection.IsView() {
|
||||
return e.BadRequestError("Unsupported collection type.", nil)
|
||||
}
|
||||
|
||||
err = checkCollectionRateLimit(e, collection, "update")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
recordId := e.Request.PathValue("id")
|
||||
if recordId == "" {
|
||||
return e.NotFoundError("", nil)
|
||||
}
|
||||
|
||||
requestInfo, err := e.RequestInfo()
|
||||
if err != nil {
|
||||
return firstApiError(err, e.BadRequestError("", err))
|
||||
}
|
||||
|
||||
hasSuperuserAuth := requestInfo.HasSuperuserAuth()
|
||||
|
||||
if !hasSuperuserAuth && collection.UpdateRule == nil {
|
||||
return firstApiError(err, e.ForbiddenError("Only superusers can perform this action.", nil))
|
||||
}
|
||||
|
||||
// eager fetch the record so that the modifiers field values can be resolved
|
||||
record, err := e.App.FindRecordById(collection, recordId)
|
||||
if err != nil {
|
||||
return firstApiError(err, e.NotFoundError("", err))
|
||||
}
|
||||
|
||||
data, err := recordDataFromRequest(e, record)
|
||||
if err != nil {
|
||||
return firstApiError(err, e.BadRequestError("Failed to read the submitted data.", err))
|
||||
}
|
||||
|
||||
// replace modifiers fields so that the resolved value is always
|
||||
// available when accessing requestInfo.Body
|
||||
requestInfo.Body = data
|
||||
|
||||
ruleFunc := func(q *dbx.SelectQuery) error {
|
||||
if !hasSuperuserAuth && collection.UpdateRule != nil && *collection.UpdateRule != "" {
|
||||
resolver := core.NewRecordFieldResolver(e.App, collection, requestInfo, true)
|
||||
expr, err := search.FilterData(*collection.UpdateRule).BuildExpr(resolver)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
resolver.UpdateQuery(q)
|
||||
q.AndWhere(expr)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// refetch with access checks
|
||||
record, err = e.App.FindRecordById(collection, recordId, ruleFunc)
|
||||
if err != nil {
|
||||
return firstApiError(err, e.NotFoundError("", err))
|
||||
}
|
||||
|
||||
form := forms.NewRecordUpsert(e.App, record)
|
||||
if hasSuperuserAuth {
|
||||
form.GrantSuperuserAccess()
|
||||
}
|
||||
form.Load(data)
|
||||
|
||||
manageRuleQuery := e.App.ConcurrentDB().Select("(1)").From(collection.Name).AndWhere(dbx.HashExp{
|
||||
collection.Name + ".id": record.Id,
|
||||
})
|
||||
if !form.HasManageAccess() &&
|
||||
hasAuthManageAccess(e.App, requestInfo, collection, manageRuleQuery) {
|
||||
form.GrantManagerAccess()
|
||||
}
|
||||
|
||||
var isOptFinalizerCalled bool
|
||||
|
||||
event := new(core.RecordRequestEvent)
|
||||
event.RequestEvent = e
|
||||
event.Collection = collection
|
||||
event.Record = record
|
||||
|
||||
hookErr := e.App.OnRecordUpdateRequest().Trigger(event, func(e *core.RecordRequestEvent) error {
|
||||
form.SetApp(e.App)
|
||||
form.SetRecord(e.Record)
|
||||
|
||||
err := form.Submit()
|
||||
if err != nil {
|
||||
return firstApiError(err, e.BadRequestError("Failed to update record.", err))
|
||||
}
|
||||
|
||||
err = EnrichRecord(e.RequestEvent, e.Record)
|
||||
if err != nil {
|
||||
return firstApiError(err, e.InternalServerError("Failed to enrich record", err))
|
||||
}
|
||||
|
||||
err = execAfterSuccessTx(responseWriteAfterTx, e.App, func() error {
|
||||
return e.JSON(http.StatusOK, e.Record)
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if optFinalizer != nil {
|
||||
isOptFinalizerCalled = true
|
||||
err = optFinalizer(e.Record)
|
||||
if err != nil {
|
||||
return firstApiError(err, e.InternalServerError("", fmt.Errorf("update optFinalizer error: %w", err)))
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
if hookErr != nil {
|
||||
return hookErr
|
||||
}
|
||||
|
||||
// e.g. in case the regular hook chain was stopped and the finalizer cannot be executed as part of the last e.Next() task
|
||||
if !isOptFinalizerCalled && optFinalizer != nil {
|
||||
if err := optFinalizer(event.Record); err != nil {
|
||||
return firstApiError(err, e.InternalServerError("", fmt.Errorf("update optFinalizer error: %w", err)))
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func recordDelete(responseWriteAfterTx bool, optFinalizer func(data any) error) func(e *core.RequestEvent) error {
|
||||
return func(e *core.RequestEvent) error {
|
||||
collection, err := e.App.FindCachedCollectionByNameOrId(e.Request.PathValue("collection"))
|
||||
if err != nil || collection == nil {
|
||||
return e.NotFoundError("Missing collection context.", err)
|
||||
}
|
||||
|
||||
if collection.IsView() {
|
||||
return e.BadRequestError("Unsupported collection type.", nil)
|
||||
}
|
||||
|
||||
err = checkCollectionRateLimit(e, collection, "delete")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
recordId := e.Request.PathValue("id")
|
||||
if recordId == "" {
|
||||
return e.NotFoundError("", nil)
|
||||
}
|
||||
|
||||
requestInfo, err := e.RequestInfo()
|
||||
if err != nil {
|
||||
return firstApiError(err, e.BadRequestError("", err))
|
||||
}
|
||||
|
||||
if !requestInfo.HasSuperuserAuth() && collection.DeleteRule == nil {
|
||||
return e.ForbiddenError("Only superusers can perform this action.", nil)
|
||||
}
|
||||
|
||||
ruleFunc := func(q *dbx.SelectQuery) error {
|
||||
if !requestInfo.HasSuperuserAuth() && collection.DeleteRule != nil && *collection.DeleteRule != "" {
|
||||
resolver := core.NewRecordFieldResolver(e.App, collection, requestInfo, true)
|
||||
expr, err := search.FilterData(*collection.DeleteRule).BuildExpr(resolver)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
resolver.UpdateQuery(q)
|
||||
q.AndWhere(expr)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
record, err := e.App.FindRecordById(collection, recordId, ruleFunc)
|
||||
if err != nil || record == nil {
|
||||
return e.NotFoundError("", err)
|
||||
}
|
||||
|
||||
var isOptFinalizerCalled bool
|
||||
|
||||
event := new(core.RecordRequestEvent)
|
||||
event.RequestEvent = e
|
||||
event.Collection = collection
|
||||
event.Record = record
|
||||
|
||||
hookErr := e.App.OnRecordDeleteRequest().Trigger(event, func(e *core.RecordRequestEvent) error {
|
||||
if err := e.App.Delete(e.Record); err != nil {
|
||||
return firstApiError(err, e.BadRequestError("Failed to delete record. Make sure that the record is not part of a required relation reference.", err))
|
||||
}
|
||||
|
||||
err = execAfterSuccessTx(responseWriteAfterTx, e.App, func() error {
|
||||
return e.NoContent(http.StatusNoContent)
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if optFinalizer != nil {
|
||||
isOptFinalizerCalled = true
|
||||
err = optFinalizer(e.Record)
|
||||
if err != nil {
|
||||
return firstApiError(err, e.InternalServerError("", fmt.Errorf("delete optFinalizer error: %w", err)))
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
if hookErr != nil {
|
||||
return hookErr
|
||||
}
|
||||
|
||||
// e.g. in case the regular hook chain was stopped and the finalizer cannot be executed as part of the last e.Next() task
|
||||
if !isOptFinalizerCalled && optFinalizer != nil {
|
||||
if err := optFinalizer(event.Record); err != nil {
|
||||
return firstApiError(err, e.InternalServerError("", fmt.Errorf("delete optFinalizer error: %w", err)))
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
|
||||
func recordDataFromRequest(e *core.RequestEvent, record *core.Record) (map[string]any, error) {
|
||||
info, err := e.RequestInfo()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// resolve regular fields
|
||||
result := record.ReplaceModifiers(info.Body)
|
||||
|
||||
// resolve uploaded files
|
||||
uploadedFiles, err := extractUploadedFiles(e, record.Collection(), "")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(uploadedFiles) > 0 {
|
||||
for k, files := range uploadedFiles {
|
||||
uploaded := make([]any, 0, len(files))
|
||||
|
||||
// if not remove/prepend/append -> merge with the submitted
|
||||
// info.Body values to prevent accidental old files deletion
|
||||
if info.Body[k] != nil &&
|
||||
!strings.HasPrefix(k, "+") &&
|
||||
!strings.HasSuffix(k, "+") &&
|
||||
!strings.HasSuffix(k, "-") {
|
||||
existing := list.ToUniqueStringSlice(info.Body[k])
|
||||
for _, name := range existing {
|
||||
uploaded = append(uploaded, name)
|
||||
}
|
||||
}
|
||||
|
||||
for _, file := range files {
|
||||
uploaded = append(uploaded, file)
|
||||
}
|
||||
|
||||
result[k] = uploaded
|
||||
}
|
||||
|
||||
result = record.ReplaceModifiers(result)
|
||||
}
|
||||
|
||||
isAuth := record.Collection().IsAuth()
|
||||
|
||||
// unset hidden fields for non-superusers
|
||||
if !info.HasSuperuserAuth() {
|
||||
for _, f := range record.Collection().Fields {
|
||||
if f.GetHidden() {
|
||||
// exception for the auth collection "password" field
|
||||
if isAuth && f.GetName() == core.FieldNamePassword {
|
||||
continue
|
||||
}
|
||||
|
||||
delete(result, f.GetName())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func extractUploadedFiles(re *core.RequestEvent, collection *core.Collection, prefix string) (map[string][]*filesystem.File, error) {
|
||||
contentType := re.Request.Header.Get("content-type")
|
||||
if !strings.HasPrefix(contentType, "multipart/form-data") {
|
||||
return nil, nil // not multipart/form-data request
|
||||
}
|
||||
|
||||
result := map[string][]*filesystem.File{}
|
||||
|
||||
for _, field := range collection.Fields {
|
||||
if field.Type() != core.FieldTypeFile {
|
||||
continue
|
||||
}
|
||||
|
||||
baseKey := field.GetName()
|
||||
|
||||
keys := []string{
|
||||
baseKey,
|
||||
// prepend and append modifiers
|
||||
"+" + baseKey,
|
||||
baseKey + "+",
|
||||
}
|
||||
|
||||
for _, k := range keys {
|
||||
if prefix != "" {
|
||||
k = prefix + "." + k
|
||||
}
|
||||
files, err := re.FindUploadedFiles(k)
|
||||
if err != nil && !errors.Is(err, http.ErrMissingFile) {
|
||||
return nil, err
|
||||
}
|
||||
if len(files) > 0 {
|
||||
result[k] = files
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// hasAuthManageAccess checks whether the client is allowed to have
|
||||
// [forms.RecordUpsert] auth management permissions
|
||||
// (e.g. allowing to change system auth fields without oldPassword).
|
||||
func hasAuthManageAccess(app core.App, requestInfo *core.RequestInfo, collection *core.Collection, query *dbx.SelectQuery) bool {
|
||||
if !collection.IsAuth() {
|
||||
return false
|
||||
}
|
||||
|
||||
manageRule := collection.ManageRule
|
||||
|
||||
if manageRule == nil || *manageRule == "" {
|
||||
return false // only for superusers (manageRule can't be empty)
|
||||
}
|
||||
|
||||
if requestInfo == nil || requestInfo.Auth == nil {
|
||||
return false // no auth record
|
||||
}
|
||||
|
||||
resolver := core.NewRecordFieldResolver(app, collection, requestInfo, true)
|
||||
|
||||
expr, err := search.FilterData(*manageRule).BuildExpr(resolver)
|
||||
if err != nil {
|
||||
app.Logger().Error("Manage rule build expression error", "error", err, "collectionId", collection.Id)
|
||||
return false
|
||||
}
|
||||
query.AndWhere(expr)
|
||||
|
||||
resolver.UpdateQuery(query)
|
||||
|
||||
var exists int
|
||||
|
||||
err = query.Limit(1).Row(&exists)
|
||||
|
||||
return err == nil && exists > 0
|
||||
}
|
314
apis/record_crud_auth_origin_test.go
Normal file
314
apis/record_crud_auth_origin_test.go
Normal file
|
@ -0,0 +1,314 @@
|
|||
package apis_test
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tests"
|
||||
)
|
||||
|
||||
func TestRecordCrudAuthOriginList(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "guest",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/collections/" + core.CollectionNameAuthOrigins + "/records",
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
`"page":1`,
|
||||
`"perPage":30`,
|
||||
`"totalItems":0`,
|
||||
`"totalPages":0`,
|
||||
`"items":[]`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordsListRequest": 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "regular auth with authOrigins",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/collections/" + core.CollectionNameAuthOrigins + "/records",
|
||||
Headers: map[string]string{
|
||||
// clients, test@example.com
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6ImdrMzkwcWVnczR5NDd3biIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoidjg1MXE0cjc5MHJoa25sIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.0ONnm_BsvPRZyDNT31GN1CKUB6uQRxvVvQ-Wc9AZfG0",
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
`"page":1`,
|
||||
`"perPage":30`,
|
||||
`"totalItems":1`,
|
||||
`"totalPages":1`,
|
||||
`"id":"9r2j0m74260ur8i"`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordsListRequest": 1,
|
||||
"OnRecordEnrich": 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "regular auth without authOrigins",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/collections/" + core.CollectionNameAuthOrigins + "/records",
|
||||
Headers: map[string]string{
|
||||
// users, test@example.com
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
`"page":1`,
|
||||
`"perPage":30`,
|
||||
`"totalItems":0`,
|
||||
`"totalPages":0`,
|
||||
`"items":[]`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordsListRequest": 1,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecordCrudAuthOriginView(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "guest",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/collections/" + core.CollectionNameAuthOrigins + "/records/9r2j0m74260ur8i",
|
||||
ExpectedStatus: 404,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "non-owner",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/collections/" + core.CollectionNameAuthOrigins + "/records/9r2j0m74260ur8i",
|
||||
Headers: map[string]string{
|
||||
// users, test@example.com
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
ExpectedStatus: 404,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "owner",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/collections/" + core.CollectionNameAuthOrigins + "/records/9r2j0m74260ur8i",
|
||||
Headers: map[string]string{
|
||||
// clients, test@example.com
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6ImdrMzkwcWVnczR5NDd3biIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoidjg1MXE0cjc5MHJoa25sIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.0ONnm_BsvPRZyDNT31GN1CKUB6uQRxvVvQ-Wc9AZfG0",
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{`"id":"9r2j0m74260ur8i"`},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordViewRequest": 1,
|
||||
"OnRecordEnrich": 1,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecordCrudAuthOriginDelete(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "guest",
|
||||
Method: http.MethodDelete,
|
||||
URL: "/api/collections/" + core.CollectionNameAuthOrigins + "/records/9r2j0m74260ur8i",
|
||||
ExpectedStatus: 404,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "non-owner",
|
||||
Method: http.MethodDelete,
|
||||
URL: "/api/collections/" + core.CollectionNameAuthOrigins + "/records/9r2j0m74260ur8i",
|
||||
Headers: map[string]string{
|
||||
// users, test@example.com
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
ExpectedStatus: 404,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "owner",
|
||||
Method: http.MethodDelete,
|
||||
URL: "/api/collections/" + core.CollectionNameAuthOrigins + "/records/9r2j0m74260ur8i",
|
||||
Headers: map[string]string{
|
||||
// clients, test@example.com
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6ImdrMzkwcWVnczR5NDd3biIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoidjg1MXE0cjc5MHJoa25sIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.0ONnm_BsvPRZyDNT31GN1CKUB6uQRxvVvQ-Wc9AZfG0",
|
||||
},
|
||||
ExpectedStatus: 204,
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordDeleteRequest": 1,
|
||||
"OnModelDelete": 1,
|
||||
"OnModelDeleteExecute": 1,
|
||||
"OnModelAfterDeleteSuccess": 1,
|
||||
"OnRecordDelete": 1,
|
||||
"OnRecordDeleteExecute": 1,
|
||||
"OnRecordAfterDeleteSuccess": 1,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecordCrudAuthOriginCreate(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
body := func() *strings.Reader {
|
||||
return strings.NewReader(`{
|
||||
"recordRef": "4q1xlclmfloku33",
|
||||
"collectionRef": "_pb_users_auth_",
|
||||
"fingerprint": "abc"
|
||||
}`)
|
||||
}
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "guest",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/" + core.CollectionNameAuthOrigins + "/records",
|
||||
Body: body(),
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "owner regular auth",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/" + core.CollectionNameAuthOrigins + "/records",
|
||||
Headers: map[string]string{
|
||||
// users, test@example.com
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
Body: body(),
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "superusers auth",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/" + core.CollectionNameAuthOrigins + "/records",
|
||||
Headers: map[string]string{
|
||||
// superusers, test@example.com
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
|
||||
},
|
||||
Body: body(),
|
||||
ExpectedContent: []string{
|
||||
`"fingerprint":"abc"`,
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordCreateRequest": 1,
|
||||
"OnRecordEnrich": 1,
|
||||
"OnModelCreate": 1,
|
||||
"OnModelCreateExecute": 1,
|
||||
"OnModelAfterCreateSuccess": 1,
|
||||
"OnModelValidate": 1,
|
||||
"OnRecordCreate": 1,
|
||||
"OnRecordCreateExecute": 1,
|
||||
"OnRecordAfterCreateSuccess": 1,
|
||||
"OnRecordValidate": 1,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecordCrudAuthOriginUpdate(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
body := func() *strings.Reader {
|
||||
return strings.NewReader(`{
|
||||
"fingerprint":"abc"
|
||||
}`)
|
||||
}
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "guest",
|
||||
Method: http.MethodPatch,
|
||||
URL: "/api/collections/" + core.CollectionNameAuthOrigins + "/records/9r2j0m74260ur8i",
|
||||
Body: body(),
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "owner regular auth",
|
||||
Method: http.MethodPatch,
|
||||
URL: "/api/collections/" + core.CollectionNameAuthOrigins + "/records/9r2j0m74260ur8i",
|
||||
Headers: map[string]string{
|
||||
// clients, test@example.com
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6ImdrMzkwcWVnczR5NDd3biIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoidjg1MXE0cjc5MHJoa25sIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.0ONnm_BsvPRZyDNT31GN1CKUB6uQRxvVvQ-Wc9AZfG0",
|
||||
},
|
||||
Body: body(),
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "superusers auth",
|
||||
Method: http.MethodPatch,
|
||||
URL: "/api/collections/" + core.CollectionNameAuthOrigins + "/records/9r2j0m74260ur8i",
|
||||
Headers: map[string]string{
|
||||
// superusers, test@example.com
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
|
||||
},
|
||||
Body: body(),
|
||||
ExpectedContent: []string{
|
||||
`"id":"9r2j0m74260ur8i"`,
|
||||
`"fingerprint":"abc"`,
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordUpdateRequest": 1,
|
||||
"OnRecordEnrich": 1,
|
||||
"OnModelUpdate": 1,
|
||||
"OnModelUpdateExecute": 1,
|
||||
"OnModelAfterUpdateSuccess": 1,
|
||||
"OnModelValidate": 1,
|
||||
"OnRecordUpdate": 1,
|
||||
"OnRecordUpdateExecute": 1,
|
||||
"OnRecordAfterUpdateSuccess": 1,
|
||||
"OnRecordValidate": 1,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
316
apis/record_crud_external_auth_test.go
Normal file
316
apis/record_crud_external_auth_test.go
Normal file
|
@ -0,0 +1,316 @@
|
|||
package apis_test
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tests"
|
||||
)
|
||||
|
||||
func TestRecordCrudExternalAuthList(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "guest",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/collections/" + core.CollectionNameExternalAuths + "/records",
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
`"page":1`,
|
||||
`"perPage":30`,
|
||||
`"totalItems":0`,
|
||||
`"totalPages":0`,
|
||||
`"items":[]`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordsListRequest": 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "regular auth with externalAuths",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/collections/" + core.CollectionNameExternalAuths + "/records",
|
||||
Headers: map[string]string{
|
||||
// clients, test@example.com
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6ImdrMzkwcWVnczR5NDd3biIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoidjg1MXE0cjc5MHJoa25sIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.0ONnm_BsvPRZyDNT31GN1CKUB6uQRxvVvQ-Wc9AZfG0",
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
`"page":1`,
|
||||
`"perPage":30`,
|
||||
`"totalItems":1`,
|
||||
`"totalPages":1`,
|
||||
`"id":"f1z5b3843pzc964"`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordsListRequest": 1,
|
||||
"OnRecordEnrich": 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "regular auth without externalAuths",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/collections/" + core.CollectionNameExternalAuths + "/records",
|
||||
Headers: map[string]string{
|
||||
// users, test2@example.com
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6Im9hcDY0MGNvdDR5cnUycyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.GfJo6EHIobgas_AXt-M-tj5IoQendPnrkMSe9ExuSEY",
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
`"page":1`,
|
||||
`"perPage":30`,
|
||||
`"totalItems":0`,
|
||||
`"totalPages":0`,
|
||||
`"items":[]`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordsListRequest": 1,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecordCrudExternalAuthView(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "guest",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/collections/" + core.CollectionNameExternalAuths + "/records/dlmflokuq1xl342",
|
||||
ExpectedStatus: 404,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "non-owner",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/collections/" + core.CollectionNameExternalAuths + "/records/dlmflokuq1xl342",
|
||||
Headers: map[string]string{
|
||||
// clients, test@example.com
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6ImdrMzkwcWVnczR5NDd3biIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoidjg1MXE0cjc5MHJoa25sIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.0ONnm_BsvPRZyDNT31GN1CKUB6uQRxvVvQ-Wc9AZfG0",
|
||||
},
|
||||
ExpectedStatus: 404,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "owner",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/collections/" + core.CollectionNameExternalAuths + "/records/dlmflokuq1xl342",
|
||||
Headers: map[string]string{
|
||||
// users, test@example.com
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{`"id":"dlmflokuq1xl342"`},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordViewRequest": 1,
|
||||
"OnRecordEnrich": 1,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecordCrudExternalAuthDelete(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "guest",
|
||||
Method: http.MethodDelete,
|
||||
URL: "/api/collections/" + core.CollectionNameExternalAuths + "/records/dlmflokuq1xl342",
|
||||
ExpectedStatus: 404,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "non-owner",
|
||||
Method: http.MethodDelete,
|
||||
URL: "/api/collections/" + core.CollectionNameExternalAuths + "/records/dlmflokuq1xl342",
|
||||
Headers: map[string]string{
|
||||
// clients, test@example.com
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6ImdrMzkwcWVnczR5NDd3biIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoidjg1MXE0cjc5MHJoa25sIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.0ONnm_BsvPRZyDNT31GN1CKUB6uQRxvVvQ-Wc9AZfG0",
|
||||
},
|
||||
ExpectedStatus: 404,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "owner",
|
||||
Method: http.MethodDelete,
|
||||
URL: "/api/collections/" + core.CollectionNameExternalAuths + "/records/dlmflokuq1xl342",
|
||||
Headers: map[string]string{
|
||||
// users, test@example.com
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
ExpectedStatus: 204,
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordDeleteRequest": 1,
|
||||
"OnModelDelete": 1,
|
||||
"OnModelDeleteExecute": 1,
|
||||
"OnModelAfterDeleteSuccess": 1,
|
||||
"OnRecordDelete": 1,
|
||||
"OnRecordDeleteExecute": 1,
|
||||
"OnRecordAfterDeleteSuccess": 1,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecordCrudExternalAuthCreate(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
body := func() *strings.Reader {
|
||||
return strings.NewReader(`{
|
||||
"recordRef": "4q1xlclmfloku33",
|
||||
"collectionRef": "_pb_users_auth_",
|
||||
"provider": "github",
|
||||
"providerId": "abc"
|
||||
}`)
|
||||
}
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "guest",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/" + core.CollectionNameExternalAuths + "/records",
|
||||
Body: body(),
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "owner regular auth",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/" + core.CollectionNameExternalAuths + "/records",
|
||||
Headers: map[string]string{
|
||||
// users, test@example.com
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
Body: body(),
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "superusers auth",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/" + core.CollectionNameExternalAuths + "/records",
|
||||
Headers: map[string]string{
|
||||
// superusers, test@example.com
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
|
||||
},
|
||||
Body: body(),
|
||||
ExpectedContent: []string{
|
||||
`"recordRef":"4q1xlclmfloku33"`,
|
||||
`"providerId":"abc"`,
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordCreateRequest": 1,
|
||||
"OnRecordEnrich": 1,
|
||||
"OnModelCreate": 1,
|
||||
"OnModelCreateExecute": 1,
|
||||
"OnModelAfterCreateSuccess": 1,
|
||||
"OnModelValidate": 1,
|
||||
"OnRecordCreate": 1,
|
||||
"OnRecordCreateExecute": 1,
|
||||
"OnRecordAfterCreateSuccess": 1,
|
||||
"OnRecordValidate": 1,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecordCrudExternalAuthUpdate(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
body := func() *strings.Reader {
|
||||
return strings.NewReader(`{
|
||||
"providerId": "abc"
|
||||
}`)
|
||||
}
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "guest",
|
||||
Method: http.MethodPatch,
|
||||
URL: "/api/collections/" + core.CollectionNameExternalAuths + "/records/dlmflokuq1xl342",
|
||||
Body: body(),
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "owner regular auth",
|
||||
Method: http.MethodPatch,
|
||||
URL: "/api/collections/" + core.CollectionNameExternalAuths + "/records/dlmflokuq1xl342",
|
||||
Headers: map[string]string{
|
||||
// clients, test@example.com
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6ImdrMzkwcWVnczR5NDd3biIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoidjg1MXE0cjc5MHJoa25sIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.0ONnm_BsvPRZyDNT31GN1CKUB6uQRxvVvQ-Wc9AZfG0",
|
||||
},
|
||||
Body: body(),
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "superusers auth",
|
||||
Method: http.MethodPatch,
|
||||
URL: "/api/collections/" + core.CollectionNameExternalAuths + "/records/dlmflokuq1xl342",
|
||||
Headers: map[string]string{
|
||||
// superusers, test@example.com
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
|
||||
},
|
||||
Body: body(),
|
||||
ExpectedContent: []string{
|
||||
`"id":"dlmflokuq1xl342"`,
|
||||
`"providerId":"abc"`,
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordUpdateRequest": 1,
|
||||
"OnRecordEnrich": 1,
|
||||
"OnModelUpdate": 1,
|
||||
"OnModelUpdateExecute": 1,
|
||||
"OnModelAfterUpdateSuccess": 1,
|
||||
"OnModelValidate": 1,
|
||||
"OnRecordUpdate": 1,
|
||||
"OnRecordUpdateExecute": 1,
|
||||
"OnRecordAfterUpdateSuccess": 1,
|
||||
"OnRecordValidate": 1,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
405
apis/record_crud_mfa_test.go
Normal file
405
apis/record_crud_mfa_test.go
Normal file
|
@ -0,0 +1,405 @@
|
|||
package apis_test
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tests"
|
||||
)
|
||||
|
||||
func TestRecordCrudMFAList(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "guest",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/collections/" + core.CollectionNameMFAs + "/records",
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := tests.StubMFARecords(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
`"page":1`,
|
||||
`"perPage":30`,
|
||||
`"totalItems":0`,
|
||||
`"totalPages":0`,
|
||||
`"items":[]`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordsListRequest": 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "regular auth with mfas",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/collections/" + core.CollectionNameMFAs + "/records",
|
||||
Headers: map[string]string{
|
||||
// users, test@example.com
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := tests.StubMFARecords(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
`"page":1`,
|
||||
`"perPage":30`,
|
||||
`"totalItems":1`,
|
||||
`"totalPages":1`,
|
||||
`"id":"user1_0"`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordsListRequest": 1,
|
||||
"OnRecordEnrich": 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "regular auth without mfas",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/collections/" + core.CollectionNameMFAs + "/records",
|
||||
Headers: map[string]string{
|
||||
// clients, test@example.com
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6ImdrMzkwcWVnczR5NDd3biIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoidjg1MXE0cjc5MHJoa25sIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.0ONnm_BsvPRZyDNT31GN1CKUB6uQRxvVvQ-Wc9AZfG0",
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := tests.StubMFARecords(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
`"page":1`,
|
||||
`"perPage":30`,
|
||||
`"totalItems":0`,
|
||||
`"totalPages":0`,
|
||||
`"items":[]`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordsListRequest": 1,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecordCrudMFAView(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "guest",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/collections/" + core.CollectionNameMFAs + "/records/user1_0",
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := tests.StubMFARecords(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 404,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "non-owner",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/collections/" + core.CollectionNameMFAs + "/records/user1_0",
|
||||
Headers: map[string]string{
|
||||
// clients, test@example.com
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6ImdrMzkwcWVnczR5NDd3biIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoidjg1MXE0cjc5MHJoa25sIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.0ONnm_BsvPRZyDNT31GN1CKUB6uQRxvVvQ-Wc9AZfG0",
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := tests.StubMFARecords(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 404,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "owner",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/collections/" + core.CollectionNameMFAs + "/records/user1_0",
|
||||
Headers: map[string]string{
|
||||
// users, test@example.com
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := tests.StubMFARecords(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{`"id":"user1_0"`},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordViewRequest": 1,
|
||||
"OnRecordEnrich": 1,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecordCrudMFADelete(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "guest",
|
||||
Method: http.MethodDelete,
|
||||
URL: "/api/collections/" + core.CollectionNameMFAs + "/records/user1_0",
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := tests.StubMFARecords(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "non-owner",
|
||||
Method: http.MethodDelete,
|
||||
URL: "/api/collections/" + core.CollectionNameMFAs + "/records/user1_0",
|
||||
Headers: map[string]string{
|
||||
// clients, test@example.com
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6ImdrMzkwcWVnczR5NDd3biIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoidjg1MXE0cjc5MHJoa25sIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.0ONnm_BsvPRZyDNT31GN1CKUB6uQRxvVvQ-Wc9AZfG0",
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := tests.StubMFARecords(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "owner",
|
||||
Method: http.MethodDelete,
|
||||
URL: "/api/collections/" + core.CollectionNameMFAs + "/records/user1_0",
|
||||
Headers: map[string]string{
|
||||
// users, test@example.com
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := tests.StubMFARecords(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "superusers auth",
|
||||
Method: http.MethodDelete,
|
||||
URL: "/api/collections/" + core.CollectionNameMFAs + "/records/user1_0",
|
||||
Headers: map[string]string{
|
||||
// superusers, test@example.com
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := tests.StubMFARecords(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 204,
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordDeleteRequest": 1,
|
||||
"OnModelDelete": 1,
|
||||
"OnModelDeleteExecute": 1,
|
||||
"OnModelAfterDeleteSuccess": 1,
|
||||
"OnRecordDelete": 1,
|
||||
"OnRecordDeleteExecute": 1,
|
||||
"OnRecordAfterDeleteSuccess": 1,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecordCrudMFACreate(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
body := func() *strings.Reader {
|
||||
return strings.NewReader(`{
|
||||
"recordRef": "4q1xlclmfloku33",
|
||||
"collectionRef": "_pb_users_auth_",
|
||||
"method": "abc"
|
||||
}`)
|
||||
}
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "guest",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/" + core.CollectionNameMFAs + "/records",
|
||||
Body: body(),
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := tests.StubMFARecords(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "owner regular auth",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/" + core.CollectionNameMFAs + "/records",
|
||||
Headers: map[string]string{
|
||||
// users, test@example.com
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
Body: body(),
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := tests.StubMFARecords(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "superusers auth",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/" + core.CollectionNameMFAs + "/records",
|
||||
Headers: map[string]string{
|
||||
// superusers, test@example.com
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
|
||||
},
|
||||
Body: body(),
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := tests.StubMFARecords(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedContent: []string{
|
||||
`"recordRef":"4q1xlclmfloku33"`,
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordCreateRequest": 1,
|
||||
"OnRecordEnrich": 1,
|
||||
"OnModelCreate": 1,
|
||||
"OnModelCreateExecute": 1,
|
||||
"OnModelAfterCreateSuccess": 1,
|
||||
"OnModelValidate": 1,
|
||||
"OnRecordCreate": 1,
|
||||
"OnRecordCreateExecute": 1,
|
||||
"OnRecordAfterCreateSuccess": 1,
|
||||
"OnRecordValidate": 1,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecordCrudMFAUpdate(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
body := func() *strings.Reader {
|
||||
return strings.NewReader(`{
|
||||
"method":"abc"
|
||||
}`)
|
||||
}
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "guest",
|
||||
Method: http.MethodPatch,
|
||||
URL: "/api/collections/" + core.CollectionNameMFAs + "/records/user1_0",
|
||||
Body: body(),
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := tests.StubMFARecords(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "owner regular auth",
|
||||
Method: http.MethodPatch,
|
||||
URL: "/api/collections/" + core.CollectionNameMFAs + "/records/user1_0",
|
||||
Headers: map[string]string{
|
||||
// users, test@example.com
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
Body: body(),
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := tests.StubMFARecords(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "superusers auth",
|
||||
Method: http.MethodPatch,
|
||||
URL: "/api/collections/" + core.CollectionNameMFAs + "/records/user1_0",
|
||||
Headers: map[string]string{
|
||||
// superusers, test@example.com
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
|
||||
},
|
||||
Body: body(),
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := tests.StubMFARecords(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedContent: []string{
|
||||
`"id":"user1_0"`,
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordUpdateRequest": 1,
|
||||
"OnRecordEnrich": 1,
|
||||
"OnModelUpdate": 1,
|
||||
"OnModelUpdateExecute": 1,
|
||||
"OnModelAfterUpdateSuccess": 1,
|
||||
"OnModelValidate": 1,
|
||||
"OnRecordUpdate": 1,
|
||||
"OnRecordUpdateExecute": 1,
|
||||
"OnRecordAfterUpdateSuccess": 1,
|
||||
"OnRecordValidate": 1,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
405
apis/record_crud_otp_test.go
Normal file
405
apis/record_crud_otp_test.go
Normal file
|
@ -0,0 +1,405 @@
|
|||
package apis_test
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tests"
|
||||
)
|
||||
|
||||
func TestRecordCrudOTPList(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "guest",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/collections/" + core.CollectionNameOTPs + "/records",
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := tests.StubOTPRecords(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
`"page":1`,
|
||||
`"perPage":30`,
|
||||
`"totalItems":0`,
|
||||
`"totalPages":0`,
|
||||
`"items":[]`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordsListRequest": 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "regular auth with otps",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/collections/" + core.CollectionNameOTPs + "/records",
|
||||
Headers: map[string]string{
|
||||
// users, test@example.com
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := tests.StubOTPRecords(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
`"page":1`,
|
||||
`"perPage":30`,
|
||||
`"totalItems":1`,
|
||||
`"totalPages":1`,
|
||||
`"id":"user1_0"`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordsListRequest": 1,
|
||||
"OnRecordEnrich": 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "regular auth without otps",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/collections/" + core.CollectionNameOTPs + "/records",
|
||||
Headers: map[string]string{
|
||||
// clients, test@example.com
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6ImdrMzkwcWVnczR5NDd3biIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoidjg1MXE0cjc5MHJoa25sIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.0ONnm_BsvPRZyDNT31GN1CKUB6uQRxvVvQ-Wc9AZfG0",
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := tests.StubOTPRecords(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
`"page":1`,
|
||||
`"perPage":30`,
|
||||
`"totalItems":0`,
|
||||
`"totalPages":0`,
|
||||
`"items":[]`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordsListRequest": 1,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecordCrudOTPView(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "guest",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/collections/" + core.CollectionNameOTPs + "/records/user1_0",
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := tests.StubOTPRecords(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 404,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "non-owner",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/collections/" + core.CollectionNameOTPs + "/records/user1_0",
|
||||
Headers: map[string]string{
|
||||
// clients, test@example.com
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6ImdrMzkwcWVnczR5NDd3biIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoidjg1MXE0cjc5MHJoa25sIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.0ONnm_BsvPRZyDNT31GN1CKUB6uQRxvVvQ-Wc9AZfG0",
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := tests.StubOTPRecords(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 404,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "owner",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/collections/" + core.CollectionNameOTPs + "/records/user1_0",
|
||||
Headers: map[string]string{
|
||||
// users, test@example.com
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := tests.StubOTPRecords(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{`"id":"user1_0"`},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordViewRequest": 1,
|
||||
"OnRecordEnrich": 1,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecordCrudOTPDelete(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "guest",
|
||||
Method: http.MethodDelete,
|
||||
URL: "/api/collections/" + core.CollectionNameOTPs + "/records/user1_0",
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := tests.StubOTPRecords(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "non-owner auth",
|
||||
Method: http.MethodDelete,
|
||||
URL: "/api/collections/" + core.CollectionNameOTPs + "/records/user1_0",
|
||||
Headers: map[string]string{
|
||||
// clients, test@example.com
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6ImdrMzkwcWVnczR5NDd3biIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoidjg1MXE0cjc5MHJoa25sIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.0ONnm_BsvPRZyDNT31GN1CKUB6uQRxvVvQ-Wc9AZfG0",
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := tests.StubOTPRecords(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "owner regular auth",
|
||||
Method: http.MethodDelete,
|
||||
URL: "/api/collections/" + core.CollectionNameOTPs + "/records/user1_0",
|
||||
Headers: map[string]string{
|
||||
// users, test@example.com
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := tests.StubOTPRecords(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "superusers auth",
|
||||
Method: http.MethodDelete,
|
||||
URL: "/api/collections/" + core.CollectionNameOTPs + "/records/user1_0",
|
||||
Headers: map[string]string{
|
||||
// superusers, test@example.com
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := tests.StubOTPRecords(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 204,
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordDeleteRequest": 1,
|
||||
"OnModelDelete": 1,
|
||||
"OnModelDeleteExecute": 1,
|
||||
"OnModelAfterDeleteSuccess": 1,
|
||||
"OnRecordDelete": 1,
|
||||
"OnRecordDeleteExecute": 1,
|
||||
"OnRecordAfterDeleteSuccess": 1,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecordCrudOTPCreate(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
body := func() *strings.Reader {
|
||||
return strings.NewReader(`{
|
||||
"recordRef": "4q1xlclmfloku33",
|
||||
"collectionRef": "_pb_users_auth_",
|
||||
"password": "abc"
|
||||
}`)
|
||||
}
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "guest",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/" + core.CollectionNameOTPs + "/records",
|
||||
Body: body(),
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := tests.StubOTPRecords(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "owner regular auth",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/" + core.CollectionNameOTPs + "/records",
|
||||
Headers: map[string]string{
|
||||
// users, test@example.com
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
Body: body(),
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := tests.StubOTPRecords(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "superusers auth",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/" + core.CollectionNameOTPs + "/records",
|
||||
Headers: map[string]string{
|
||||
// superusers, test@example.com
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
|
||||
},
|
||||
Body: body(),
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := tests.StubOTPRecords(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedContent: []string{
|
||||
`"recordRef":"4q1xlclmfloku33"`,
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordCreateRequest": 1,
|
||||
"OnRecordEnrich": 1,
|
||||
"OnModelCreate": 1,
|
||||
"OnModelCreateExecute": 1,
|
||||
"OnModelAfterCreateSuccess": 1,
|
||||
"OnModelValidate": 1,
|
||||
"OnRecordCreate": 1,
|
||||
"OnRecordCreateExecute": 1,
|
||||
"OnRecordAfterCreateSuccess": 1,
|
||||
"OnRecordValidate": 1,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecordCrudOTPUpdate(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
body := func() *strings.Reader {
|
||||
return strings.NewReader(`{
|
||||
"password":"abc"
|
||||
}`)
|
||||
}
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "guest",
|
||||
Method: http.MethodPatch,
|
||||
URL: "/api/collections/" + core.CollectionNameOTPs + "/records/user1_0",
|
||||
Body: body(),
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := tests.StubOTPRecords(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "owner regular auth",
|
||||
Method: http.MethodPatch,
|
||||
URL: "/api/collections/" + core.CollectionNameOTPs + "/records/user1_0",
|
||||
Headers: map[string]string{
|
||||
// users, test@example.com
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
Body: body(),
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := tests.StubOTPRecords(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "superusers auth",
|
||||
Method: http.MethodPatch,
|
||||
URL: "/api/collections/" + core.CollectionNameOTPs + "/records/user1_0",
|
||||
Headers: map[string]string{
|
||||
// superusers, test@example.com
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
|
||||
},
|
||||
Body: body(),
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := tests.StubOTPRecords(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedContent: []string{
|
||||
`"id":"user1_0"`,
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordUpdateRequest": 1,
|
||||
"OnRecordEnrich": 1,
|
||||
"OnModelUpdate": 1,
|
||||
"OnModelUpdateExecute": 1,
|
||||
"OnModelAfterUpdateSuccess": 1,
|
||||
"OnModelValidate": 1,
|
||||
"OnRecordUpdate": 1,
|
||||
"OnRecordUpdateExecute": 1,
|
||||
"OnRecordAfterUpdateSuccess": 1,
|
||||
"OnRecordValidate": 1,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
336
apis/record_crud_superuser_test.go
Normal file
336
apis/record_crud_superuser_test.go
Normal file
|
@ -0,0 +1,336 @@
|
|||
package apis_test
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/pocketbase/dbx"
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tests"
|
||||
)
|
||||
|
||||
func TestRecordCrudSuperuserList(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "guest",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/collections/" + core.CollectionNameSuperusers + "/records",
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "non-superusers auth",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/collections/" + core.CollectionNameSuperusers + "/records",
|
||||
Headers: map[string]string{
|
||||
// users, test@example.com
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "superusers auth",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/collections/" + core.CollectionNameSuperusers + "/records",
|
||||
Headers: map[string]string{
|
||||
// _superusers, test@example.com
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
`"page":1`,
|
||||
`"perPage":30`,
|
||||
`"totalPages":1`,
|
||||
`"totalItems":4`,
|
||||
`"items":[{`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordsListRequest": 1,
|
||||
"OnRecordEnrich": 4,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecordCrudSuperuserView(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "guest",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/collections/" + core.CollectionNameSuperusers + "/records/sywbhecnh46rhm0",
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "non-superusers auth",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/collections/" + core.CollectionNameSuperusers + "/records/sywbhecnh46rhm0",
|
||||
Headers: map[string]string{
|
||||
// users, test@example.com
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "superusers auth",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/collections/" + core.CollectionNameSuperusers + "/records/sywbhecnh46rhm0",
|
||||
Headers: map[string]string{
|
||||
// _superusers, test@example.com
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
`"id":"sywbhecnh46rhm0"`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordViewRequest": 1,
|
||||
"OnRecordEnrich": 1,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecordCrudSuperuserDelete(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "guest",
|
||||
Method: http.MethodDelete,
|
||||
URL: "/api/collections/" + core.CollectionNameSuperusers + "/records/sbmbsdb40jyxf7h",
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "non-superusers auth",
|
||||
Method: http.MethodDelete,
|
||||
URL: "/api/collections/" + core.CollectionNameSuperusers + "/records/sbmbsdb40jyxf7h",
|
||||
Headers: map[string]string{
|
||||
// users, test@example.com
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "superusers auth",
|
||||
Method: http.MethodDelete,
|
||||
URL: "/api/collections/" + core.CollectionNameSuperusers + "/records/sbmbsdb40jyxf7h",
|
||||
Headers: map[string]string{
|
||||
// _superusers, test@example.com
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
|
||||
},
|
||||
ExpectedStatus: 204,
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordDeleteRequest": 1,
|
||||
"OnModelDelete": 4, // + 3 AuthOrigins
|
||||
"OnModelDeleteExecute": 4,
|
||||
"OnModelAfterDeleteSuccess": 4,
|
||||
"OnRecordDelete": 4,
|
||||
"OnRecordDeleteExecute": 4,
|
||||
"OnRecordAfterDeleteSuccess": 4,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "delete the last superuser",
|
||||
Method: http.MethodDelete,
|
||||
URL: "/api/collections/" + core.CollectionNameSuperusers + "/records/sywbhecnh46rhm0",
|
||||
Headers: map[string]string{
|
||||
// _superusers, test@example.com
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
// delete all other superusers
|
||||
superusers, err := app.FindAllRecords(core.CollectionNameSuperusers, dbx.Not(dbx.HashExp{"id": "sywbhecnh46rhm0"}))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
for _, superuser := range superusers {
|
||||
if err = app.Delete(superuser); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordDeleteRequest": 1,
|
||||
"OnModelDelete": 1,
|
||||
"OnModelAfterDeleteError": 1,
|
||||
"OnRecordDelete": 1,
|
||||
"OnRecordAfterDeleteError": 1,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecordCrudSuperuserCreate(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
body := func() *strings.Reader {
|
||||
return strings.NewReader(`{
|
||||
"email": "test_new@example.com",
|
||||
"password": "1234567890",
|
||||
"passwordConfirm": "1234567890",
|
||||
"verified": false
|
||||
}`)
|
||||
}
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "guest",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/" + core.CollectionNameSuperusers + "/records",
|
||||
Body: body(),
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "non-superusers auth",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/" + core.CollectionNameSuperusers + "/records",
|
||||
Headers: map[string]string{
|
||||
// users, test@example.com
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
Body: body(),
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "superusers auth",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/" + core.CollectionNameSuperusers + "/records",
|
||||
Headers: map[string]string{
|
||||
// _superusers, test@example.com
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
|
||||
},
|
||||
Body: body(),
|
||||
ExpectedContent: []string{
|
||||
`"collectionName":"_superusers"`,
|
||||
`"email":"test_new@example.com"`,
|
||||
`"verified":true`,
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordCreateRequest": 1,
|
||||
"OnRecordEnrich": 1,
|
||||
"OnModelCreate": 1,
|
||||
"OnModelCreateExecute": 1,
|
||||
"OnModelAfterCreateSuccess": 1,
|
||||
"OnModelValidate": 1,
|
||||
"OnRecordCreate": 1,
|
||||
"OnRecordCreateExecute": 1,
|
||||
"OnRecordAfterCreateSuccess": 1,
|
||||
"OnRecordValidate": 1,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecordCrudSuperuserUpdate(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
body := func() *strings.Reader {
|
||||
return strings.NewReader(`{
|
||||
"email": "test_new@example.com",
|
||||
"verified": true
|
||||
}`)
|
||||
}
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "guest",
|
||||
Method: http.MethodPatch,
|
||||
URL: "/api/collections/" + core.CollectionNameSuperusers + "/records/sywbhecnh46rhm0",
|
||||
Body: body(),
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "non-superusers auth",
|
||||
Method: http.MethodPatch,
|
||||
URL: "/api/collections/" + core.CollectionNameSuperusers + "/records/sywbhecnh46rhm0",
|
||||
Headers: map[string]string{
|
||||
// users, test@example.com
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
Body: body(),
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "superusers auth",
|
||||
Method: http.MethodPatch,
|
||||
URL: "/api/collections/" + core.CollectionNameSuperusers + "/records/sywbhecnh46rhm0",
|
||||
Headers: map[string]string{
|
||||
// _superusers, test@example.com
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
|
||||
},
|
||||
Body: body(),
|
||||
ExpectedContent: []string{
|
||||
`"collectionName":"_superusers"`,
|
||||
`"id":"sywbhecnh46rhm0"`,
|
||||
`"email":"test_new@example.com"`,
|
||||
`"verified":true`,
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordUpdateRequest": 1,
|
||||
"OnRecordEnrich": 1,
|
||||
"OnModelUpdate": 1,
|
||||
"OnModelUpdateExecute": 1,
|
||||
"OnModelAfterUpdateSuccess": 1,
|
||||
"OnModelValidate": 1,
|
||||
"OnRecordUpdate": 1,
|
||||
"OnRecordUpdateExecute": 1,
|
||||
"OnRecordAfterUpdateSuccess": 1,
|
||||
"OnRecordValidate": 1,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
3584
apis/record_crud_test.go
Normal file
3584
apis/record_crud_test.go
Normal file
File diff suppressed because it is too large
Load diff
636
apis/record_helpers.go
Normal file
636
apis/record_helpers.go
Normal file
|
@ -0,0 +1,636 @@
|
|||
package apis
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/pocketbase/dbx"
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/mails"
|
||||
"github.com/pocketbase/pocketbase/tools/router"
|
||||
"github.com/pocketbase/pocketbase/tools/routine"
|
||||
"github.com/pocketbase/pocketbase/tools/search"
|
||||
"github.com/pocketbase/pocketbase/tools/security"
|
||||
)
|
||||
|
||||
const (
|
||||
expandQueryParam = "expand"
|
||||
fieldsQueryParam = "fields"
|
||||
)
|
||||
|
||||
var ErrMFA = errors.New("mfa required")
|
||||
|
||||
// RecordAuthResponse writes standardized json record auth response
|
||||
// into the specified request context.
|
||||
//
|
||||
// The authMethod argument specify the name of the current authentication method (eg. password, oauth2, etc.)
|
||||
// that it is used primarily as an auth identifier during MFA and for login alerts.
|
||||
//
|
||||
// Set authMethod to empty string if you want to ignore the MFA checks and the login alerts
|
||||
// (can be also adjusted additionally via the OnRecordAuthRequest hook).
|
||||
func RecordAuthResponse(e *core.RequestEvent, authRecord *core.Record, authMethod string, meta any) error {
|
||||
token, tokenErr := authRecord.NewAuthToken()
|
||||
if tokenErr != nil {
|
||||
return e.InternalServerError("Failed to create auth token.", tokenErr)
|
||||
}
|
||||
|
||||
return recordAuthResponse(e, authRecord, token, authMethod, meta)
|
||||
}
|
||||
|
||||
func recordAuthResponse(e *core.RequestEvent, authRecord *core.Record, token string, authMethod string, meta any) error {
|
||||
originalRequestInfo, err := e.RequestInfo()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ok, err := e.App.CanAccessRecord(authRecord, originalRequestInfo, authRecord.Collection().AuthRule)
|
||||
if !ok {
|
||||
return firstApiError(err, e.ForbiddenError("The request doesn't satisfy the collection requirements to authenticate.", err))
|
||||
}
|
||||
|
||||
event := new(core.RecordAuthRequestEvent)
|
||||
event.RequestEvent = e
|
||||
event.Collection = authRecord.Collection()
|
||||
event.Record = authRecord
|
||||
event.Token = token
|
||||
event.Meta = meta
|
||||
event.AuthMethod = authMethod
|
||||
|
||||
return e.App.OnRecordAuthRequest().Trigger(event, func(e *core.RecordAuthRequestEvent) error {
|
||||
if e.Written() {
|
||||
return nil
|
||||
}
|
||||
|
||||
// MFA
|
||||
// ---
|
||||
mfaId, err := checkMFA(e.RequestEvent, e.Record, e.AuthMethod)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// require additional authentication
|
||||
if mfaId != "" {
|
||||
// eagerly write the mfa response and return an err so that
|
||||
// external middlewars are aware that the auth response requires an extra step
|
||||
e.JSON(http.StatusUnauthorized, map[string]string{
|
||||
"mfaId": mfaId,
|
||||
})
|
||||
return ErrMFA
|
||||
}
|
||||
// ---
|
||||
|
||||
// create a shallow copy of the cached request data and adjust it to the current auth record
|
||||
requestInfo := *originalRequestInfo
|
||||
requestInfo.Auth = e.Record
|
||||
|
||||
err = triggerRecordEnrichHooks(e.App, &requestInfo, []*core.Record{e.Record}, func() error {
|
||||
if e.Record.IsSuperuser() {
|
||||
e.Record.Unhide(e.Record.Collection().Fields.FieldNames()...)
|
||||
}
|
||||
|
||||
// allow always returning the email address of the authenticated model
|
||||
e.Record.IgnoreEmailVisibility(true)
|
||||
|
||||
// expand record relations
|
||||
expands := strings.Split(e.Request.URL.Query().Get(expandQueryParam), ",")
|
||||
if len(expands) > 0 {
|
||||
failed := e.App.ExpandRecord(e.Record, expands, expandFetch(e.App, &requestInfo))
|
||||
if len(failed) > 0 {
|
||||
e.App.Logger().Warn("[recordAuthResponse] Failed to expand relations", "error", failed)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if e.AuthMethod != "" && authRecord.Collection().AuthAlert.Enabled {
|
||||
if err = authAlert(e.RequestEvent, e.Record); err != nil {
|
||||
e.App.Logger().Warn("[recordAuthResponse] Failed to send login alert", "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
result := struct {
|
||||
Meta any `json:"meta,omitempty"`
|
||||
Record *core.Record `json:"record"`
|
||||
Token string `json:"token"`
|
||||
}{
|
||||
Token: e.Token,
|
||||
Record: e.Record,
|
||||
}
|
||||
|
||||
if e.Meta != nil {
|
||||
result.Meta = e.Meta
|
||||
}
|
||||
|
||||
return execAfterSuccessTx(true, e.App, func() error {
|
||||
return e.JSON(http.StatusOK, result)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// wantsMFA checks whether to enable MFA for the specified auth record based on its MFA rule
|
||||
// (note: returns true even in case of an error as a safer default).
|
||||
func wantsMFA(e *core.RequestEvent, record *core.Record) (bool, error) {
|
||||
rule := record.Collection().MFA.Rule
|
||||
if rule == "" {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
requestInfo, err := e.RequestInfo()
|
||||
if err != nil {
|
||||
return true, err
|
||||
}
|
||||
|
||||
var exists int
|
||||
|
||||
query := e.App.RecordQuery(record.Collection()).
|
||||
Select("(1)").
|
||||
AndWhere(dbx.HashExp{record.Collection().Name + ".id": record.Id})
|
||||
|
||||
// parse and apply the access rule filter
|
||||
resolver := core.NewRecordFieldResolver(e.App, record.Collection(), requestInfo, true)
|
||||
expr, err := search.FilterData(rule).BuildExpr(resolver)
|
||||
if err != nil {
|
||||
return true, err
|
||||
}
|
||||
resolver.UpdateQuery(query)
|
||||
|
||||
err = query.AndWhere(expr).Limit(1).Row(&exists)
|
||||
if err != nil && !errors.Is(err, sql.ErrNoRows) {
|
||||
return true, err
|
||||
}
|
||||
|
||||
return exists > 0, nil
|
||||
}
|
||||
|
||||
// checkMFA handles any MFA auth checks that needs to be performed for the specified request event.
|
||||
// Returns the mfaId that needs to be written as response to the user.
|
||||
//
|
||||
// (note: all auth methods are treated as equal and there is no requirement for "pairing").
|
||||
func checkMFA(e *core.RequestEvent, authRecord *core.Record, currentAuthMethod string) (string, error) {
|
||||
if !authRecord.Collection().MFA.Enabled || currentAuthMethod == "" {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
ok, err := wantsMFA(e, authRecord)
|
||||
if err != nil {
|
||||
return "", e.BadRequestError("Failed to authenticate.", fmt.Errorf("MFA rule failure: %w", err))
|
||||
}
|
||||
if !ok {
|
||||
return "", nil // no mfa needed for this auth record
|
||||
}
|
||||
|
||||
// read the mfaId either from the qyery params or request body
|
||||
mfaId := e.Request.URL.Query().Get("mfaId")
|
||||
if mfaId == "" {
|
||||
// check the body
|
||||
data := struct {
|
||||
MfaId string `form:"mfaId" json:"mfaId" xml:"mfaId"`
|
||||
}{}
|
||||
if err := e.BindBody(&data); err != nil {
|
||||
return "", firstApiError(err, e.BadRequestError("Failed to read MFA Id", err))
|
||||
}
|
||||
mfaId = data.MfaId
|
||||
}
|
||||
|
||||
// first-time auth
|
||||
// ---
|
||||
if mfaId == "" {
|
||||
mfa := core.NewMFA(e.App)
|
||||
mfa.SetCollectionRef(authRecord.Collection().Id)
|
||||
mfa.SetRecordRef(authRecord.Id)
|
||||
mfa.SetMethod(currentAuthMethod)
|
||||
if err := e.App.Save(mfa); err != nil {
|
||||
return "", firstApiError(err, e.InternalServerError("Failed to create MFA record", err))
|
||||
}
|
||||
|
||||
return mfa.Id, nil
|
||||
}
|
||||
|
||||
// second-time auth
|
||||
// ---
|
||||
mfa, err := e.App.FindMFAById(mfaId)
|
||||
deleteMFA := func() {
|
||||
// try to delete the expired mfa
|
||||
if mfa != nil {
|
||||
if deleteErr := e.App.Delete(mfa); deleteErr != nil {
|
||||
e.App.Logger().Warn("Failed to delete expired MFA record", "error", deleteErr, "mfaId", mfa.Id)
|
||||
}
|
||||
}
|
||||
}
|
||||
if err != nil || mfa.HasExpired(authRecord.Collection().MFA.DurationTime()) {
|
||||
deleteMFA()
|
||||
return "", e.BadRequestError("Invalid or expired MFA session.", err)
|
||||
}
|
||||
|
||||
if mfa.RecordRef() != authRecord.Id || mfa.CollectionRef() != authRecord.Collection().Id {
|
||||
return "", e.BadRequestError("Invalid MFA session.", nil)
|
||||
}
|
||||
|
||||
if mfa.Method() == currentAuthMethod {
|
||||
return "", e.BadRequestError("A different authentication method is required.", nil)
|
||||
}
|
||||
|
||||
deleteMFA()
|
||||
|
||||
return "", nil
|
||||
}
|
||||
|
||||
// EnrichRecord parses the request context and enrich the provided record:
|
||||
// - expands relations (if defaultExpands and/or ?expand query param is set)
|
||||
// - ensures that the emails of the auth record and its expanded auth relations
|
||||
// are visible only for the current logged superuser, record owner or record with manage access
|
||||
func EnrichRecord(e *core.RequestEvent, record *core.Record, defaultExpands ...string) error {
|
||||
return EnrichRecords(e, []*core.Record{record}, defaultExpands...)
|
||||
}
|
||||
|
||||
// EnrichRecords parses the request context and enriches the provided records:
|
||||
// - expands relations (if defaultExpands and/or ?expand query param is set)
|
||||
// - ensures that the emails of the auth records and their expanded auth relations
|
||||
// are visible only for the current logged superuser, record owner or record with manage access
|
||||
//
|
||||
// Note: Expects all records to be from the same collection!
|
||||
func EnrichRecords(e *core.RequestEvent, records []*core.Record, defaultExpands ...string) error {
|
||||
if len(records) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
info, err := e.RequestInfo()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return triggerRecordEnrichHooks(e.App, info, records, func() error {
|
||||
expands := defaultExpands
|
||||
if param := info.Query[expandQueryParam]; param != "" {
|
||||
expands = append(expands, strings.Split(param, ",")...)
|
||||
}
|
||||
|
||||
err := defaultEnrichRecords(e.App, info, records, expands...)
|
||||
if err != nil {
|
||||
// only log because it is not critical
|
||||
e.App.Logger().Warn("failed to apply default enriching", "error", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
type iterator[T any] struct {
|
||||
items []T
|
||||
index int
|
||||
}
|
||||
|
||||
func (ri *iterator[T]) next() T {
|
||||
var item T
|
||||
|
||||
if ri.index < len(ri.items) {
|
||||
item = ri.items[ri.index]
|
||||
ri.index++
|
||||
}
|
||||
|
||||
return item
|
||||
}
|
||||
|
||||
func triggerRecordEnrichHooks(app core.App, requestInfo *core.RequestInfo, records []*core.Record, finalizer func() error) error {
|
||||
it := iterator[*core.Record]{items: records}
|
||||
|
||||
enrichHook := app.OnRecordEnrich()
|
||||
|
||||
event := new(core.RecordEnrichEvent)
|
||||
event.App = app
|
||||
event.RequestInfo = requestInfo
|
||||
|
||||
var iterate func(record *core.Record) error
|
||||
iterate = func(record *core.Record) error {
|
||||
if record == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
event.Record = record
|
||||
|
||||
return enrichHook.Trigger(event, func(ee *core.RecordEnrichEvent) error {
|
||||
next := it.next()
|
||||
if next == nil {
|
||||
if finalizer != nil {
|
||||
return finalizer()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
event.App = ee.App // in case it was replaced with a transaction
|
||||
event.Record = next
|
||||
|
||||
err := iterate(next)
|
||||
|
||||
event.App = app
|
||||
event.Record = record
|
||||
|
||||
return err
|
||||
})
|
||||
}
|
||||
|
||||
return iterate(it.next())
|
||||
}
|
||||
|
||||
func defaultEnrichRecords(app core.App, requestInfo *core.RequestInfo, records []*core.Record, expands ...string) error {
|
||||
err := autoResolveRecordsFlags(app, records, requestInfo)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to resolve records flags: %w", err)
|
||||
}
|
||||
|
||||
if len(expands) > 0 {
|
||||
expandErrs := app.ExpandRecords(records, expands, expandFetch(app, requestInfo))
|
||||
if len(expandErrs) > 0 {
|
||||
errsSlice := make([]error, 0, len(expandErrs))
|
||||
for key, err := range expandErrs {
|
||||
errsSlice = append(errsSlice, fmt.Errorf("failed to expand %q: %w", key, err))
|
||||
}
|
||||
return fmt.Errorf("failed to expand records: %w", errors.Join(errsSlice...))
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// expandFetch is the records fetch function that is used to expand related records.
|
||||
func expandFetch(app core.App, originalRequestInfo *core.RequestInfo) core.ExpandFetchFunc {
|
||||
// shallow clone the provided request info to set an "expand" context
|
||||
requestInfoClone := *originalRequestInfo
|
||||
requestInfoPtr := &requestInfoClone
|
||||
requestInfoPtr.Context = core.RequestInfoContextExpand
|
||||
|
||||
return func(relCollection *core.Collection, relIds []string) ([]*core.Record, error) {
|
||||
records, findErr := app.FindRecordsByIds(relCollection.Id, relIds, func(q *dbx.SelectQuery) error {
|
||||
if requestInfoPtr.Auth != nil && requestInfoPtr.Auth.IsSuperuser() {
|
||||
return nil // superusers can access everything
|
||||
}
|
||||
|
||||
if relCollection.ViewRule == nil {
|
||||
return fmt.Errorf("only superusers can view collection %q records", relCollection.Name)
|
||||
}
|
||||
|
||||
if *relCollection.ViewRule != "" {
|
||||
resolver := core.NewRecordFieldResolver(app, relCollection, requestInfoPtr, true)
|
||||
expr, err := search.FilterData(*(relCollection.ViewRule)).BuildExpr(resolver)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
resolver.UpdateQuery(q)
|
||||
q.AndWhere(expr)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
if findErr != nil {
|
||||
return nil, findErr
|
||||
}
|
||||
|
||||
enrichErr := triggerRecordEnrichHooks(app, requestInfoPtr, records, func() error {
|
||||
if err := autoResolveRecordsFlags(app, records, requestInfoPtr); err != nil {
|
||||
// non-critical error
|
||||
app.Logger().Warn("Failed to apply autoResolveRecordsFlags for the expanded records", "error", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
if enrichErr != nil {
|
||||
return nil, enrichErr
|
||||
}
|
||||
|
||||
return records, nil
|
||||
}
|
||||
}
|
||||
|
||||
// autoResolveRecordsFlags resolves various visibility flags of the provided records.
|
||||
//
|
||||
// Currently it enables:
|
||||
// - export of hidden fields if the current auth model is a superuser
|
||||
// - email export ignoring the emailVisibity checks if the current auth model is superuser, owner or a "manager".
|
||||
//
|
||||
// Note: Expects all records to be from the same collection!
|
||||
func autoResolveRecordsFlags(app core.App, records []*core.Record, requestInfo *core.RequestInfo) error {
|
||||
if len(records) == 0 {
|
||||
return nil // nothing to resolve
|
||||
}
|
||||
|
||||
if requestInfo.HasSuperuserAuth() {
|
||||
hiddenFields := records[0].Collection().Fields.FieldNames()
|
||||
for _, rec := range records {
|
||||
rec.Unhide(hiddenFields...)
|
||||
rec.IgnoreEmailVisibility(true)
|
||||
}
|
||||
}
|
||||
|
||||
// additional emailVisibility checks
|
||||
// ---------------------------------------------------------------
|
||||
if !records[0].Collection().IsAuth() {
|
||||
return nil // not auth collection records
|
||||
}
|
||||
|
||||
collection := records[0].Collection()
|
||||
|
||||
mappedRecords := make(map[string]*core.Record, len(records))
|
||||
recordIds := make([]any, len(records))
|
||||
for i, rec := range records {
|
||||
mappedRecords[rec.Id] = rec
|
||||
recordIds[i] = rec.Id
|
||||
}
|
||||
|
||||
if requestInfo.Auth != nil && mappedRecords[requestInfo.Auth.Id] != nil {
|
||||
mappedRecords[requestInfo.Auth.Id].IgnoreEmailVisibility(true)
|
||||
}
|
||||
|
||||
if collection.ManageRule == nil || *collection.ManageRule == "" {
|
||||
return nil // no manage rule to check
|
||||
}
|
||||
|
||||
// fetch the ids of the managed records
|
||||
// ---
|
||||
managedIds := []string{}
|
||||
|
||||
query := app.RecordQuery(collection).
|
||||
Select(app.ConcurrentDB().QuoteSimpleColumnName(collection.Name) + ".id").
|
||||
AndWhere(dbx.In(app.ConcurrentDB().QuoteSimpleColumnName(collection.Name)+".id", recordIds...))
|
||||
|
||||
resolver := core.NewRecordFieldResolver(app, collection, requestInfo, true)
|
||||
expr, err := search.FilterData(*collection.ManageRule).BuildExpr(resolver)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
resolver.UpdateQuery(query)
|
||||
query.AndWhere(expr)
|
||||
|
||||
if err := query.Column(&managedIds); err != nil {
|
||||
return err
|
||||
}
|
||||
// ---
|
||||
|
||||
// ignore the email visibility check for the managed records
|
||||
for _, id := range managedIds {
|
||||
if rec, ok := mappedRecords[id]; ok {
|
||||
rec.IgnoreEmailVisibility(true)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
var ruleQueryParams = []string{search.FilterQueryParam, search.SortQueryParam}
|
||||
var superuserOnlyRuleFields = []string{"@collection.", "@request."}
|
||||
|
||||
// checkForSuperuserOnlyRuleFields loosely checks and returns an error if
|
||||
// the provided RequestInfo contains rule fields that only the superuser can use.
|
||||
func checkForSuperuserOnlyRuleFields(requestInfo *core.RequestInfo) error {
|
||||
if len(requestInfo.Query) == 0 || requestInfo.HasSuperuserAuth() {
|
||||
return nil // superuser or nothing to check
|
||||
}
|
||||
|
||||
for _, param := range ruleQueryParams {
|
||||
v := requestInfo.Query[param]
|
||||
if v == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, field := range superuserOnlyRuleFields {
|
||||
if strings.Contains(v, field) {
|
||||
return router.NewForbiddenError("Only superusers can filter by "+field, nil)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// firstApiError returns the first ApiError from the errors list
|
||||
// (this is used usually to prevent unnecessary wraping and to allow bubling ApiError from nested hooks)
|
||||
//
|
||||
// If no ApiError is found, returns a default "Internal server" error.
|
||||
func firstApiError(errs ...error) *router.ApiError {
|
||||
var apiErr *router.ApiError
|
||||
var ok bool
|
||||
|
||||
for _, err := range errs {
|
||||
if err == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// quick assert to avoid the reflection checks
|
||||
apiErr, ok = err.(*router.ApiError)
|
||||
if ok {
|
||||
return apiErr
|
||||
}
|
||||
|
||||
// nested/wrapped errors
|
||||
if errors.As(err, &apiErr) {
|
||||
return apiErr
|
||||
}
|
||||
}
|
||||
|
||||
return router.NewInternalServerError("", errors.Join(errs...))
|
||||
}
|
||||
|
||||
// execAfterSuccessTx ensures that fn is executed only after a succesul transaction.
|
||||
//
|
||||
// If the current app instance is not a transactional or checkTx is false,
|
||||
// then fn is directly executed.
|
||||
//
|
||||
// It could be usually used to allow propagating an error or writing
|
||||
// custom response from within the wrapped transaction block.
|
||||
func execAfterSuccessTx(checkTx bool, app core.App, fn func() error) error {
|
||||
if txInfo := app.TxInfo(); txInfo != nil && checkTx {
|
||||
txInfo.OnComplete(func(txErr error) error {
|
||||
if txErr == nil {
|
||||
return fn()
|
||||
}
|
||||
return nil
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
return fn()
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
|
||||
const maxAuthOrigins = 5
|
||||
|
||||
func authAlert(e *core.RequestEvent, authRecord *core.Record) error {
|
||||
// generating fingerprint
|
||||
// ---
|
||||
userAgent := e.Request.UserAgent()
|
||||
if len(userAgent) > 300 {
|
||||
userAgent = userAgent[:300]
|
||||
}
|
||||
fingerprint := security.MD5(e.RealIP() + userAgent)
|
||||
// ---
|
||||
|
||||
origins, err := e.App.FindAllAuthOriginsByRecord(authRecord)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
isFirstLogin := len(origins) == 0
|
||||
|
||||
var currentOrigin *core.AuthOrigin
|
||||
for _, origin := range origins {
|
||||
if origin.Fingerprint() == fingerprint {
|
||||
currentOrigin = origin
|
||||
break
|
||||
}
|
||||
}
|
||||
if currentOrigin == nil {
|
||||
currentOrigin = core.NewAuthOrigin(e.App)
|
||||
currentOrigin.SetCollectionRef(authRecord.Collection().Id)
|
||||
currentOrigin.SetRecordRef(authRecord.Id)
|
||||
currentOrigin.SetFingerprint(fingerprint)
|
||||
}
|
||||
|
||||
// send email alert for the new origin auth (skip first login)
|
||||
//
|
||||
// Note: The "fake" timeout is a temp solution to avoid blocking
|
||||
// for too long when the SMTP server is not accessible, due
|
||||
// to the lack of context cancellation support in the underlying
|
||||
// mailer and net/smtp package.
|
||||
// The goroutine technically "leaks" but we assume that the OS will
|
||||
// terminate the connection after some time (usually after 3-4 mins).
|
||||
if !isFirstLogin && currentOrigin.IsNew() && authRecord.Email() != "" {
|
||||
mailSent := make(chan error, 1)
|
||||
|
||||
timer := time.AfterFunc(15*time.Second, func() {
|
||||
mailSent <- errors.New("auth alert mail send wait timeout reached")
|
||||
})
|
||||
|
||||
routine.FireAndForget(func() {
|
||||
err := mails.SendRecordAuthAlert(e.App, authRecord)
|
||||
timer.Stop()
|
||||
mailSent <- err
|
||||
})
|
||||
|
||||
err = <-mailSent
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// try to keep only up to maxAuthOrigins
|
||||
// (pop the last used ones; it is not executed in a transaction to avoid unnecessary locks)
|
||||
if currentOrigin.IsNew() && len(origins) >= maxAuthOrigins {
|
||||
for i := len(origins) - 1; i >= maxAuthOrigins-1; i-- {
|
||||
if err := e.App.Delete(origins[i]); err != nil {
|
||||
// treat as non-critical error, just log for now
|
||||
e.App.Logger().Warn("Failed to delete old AuthOrigin record", "error", err, "authOriginId", origins[i].Id)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// create/update the origin fingerprint
|
||||
return e.App.Save(currentOrigin)
|
||||
}
|
761
apis/record_helpers_test.go
Normal file
761
apis/record_helpers_test.go
Normal file
|
@ -0,0 +1,761 @@
|
|||
package apis_test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/pocketbase/pocketbase/apis"
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tests"
|
||||
"github.com/pocketbase/pocketbase/tools/router"
|
||||
"github.com/pocketbase/pocketbase/tools/types"
|
||||
)
|
||||
|
||||
func TestEnrichRecords(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// mock test data
|
||||
// ---
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
freshRecords := func(records []*core.Record) []*core.Record {
|
||||
result := make([]*core.Record, len(records))
|
||||
for i, r := range records {
|
||||
result[i] = r.Fresh()
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
user, err := app.FindAuthRecordByEmail("users", "test@example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
superuser, err := app.FindAuthRecordByEmail(core.CollectionNameSuperusers, "test@example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
usersRecords, err := app.FindRecordsByIds("users", []string{"4q1xlclmfloku33", "bgs820n361vj1qd"})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
nologinRecords, err := app.FindRecordsByIds("nologin", []string{"dc49k6jgejn40h3", "oos036e9xvqeexy"})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
demo1Records, err := app.FindRecordsByIds("demo1", []string{"al1h9ijdeojtsjy", "84nmscqy84lsi1t"})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
demo5Records, err := app.FindRecordsByIds("demo5", []string{"la4y2w4o98acwuj", "qjeql998mtp1azp"})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// temp update the view rule to ensure that request context is set to "expand"
|
||||
demo4, err := app.FindCollectionByNameOrId("demo4")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
demo4.ViewRule = types.Pointer("@request.context = 'expand'")
|
||||
if err := app.Save(demo4); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
// ---
|
||||
|
||||
scenarios := []struct {
|
||||
name string
|
||||
auth *core.Record
|
||||
records []*core.Record
|
||||
queryExpand string
|
||||
defaultExpands []string
|
||||
expected []string
|
||||
notExpected []string
|
||||
}{
|
||||
// email visibility checks
|
||||
{
|
||||
name: "[emailVisibility] guest",
|
||||
auth: nil,
|
||||
records: freshRecords(usersRecords),
|
||||
queryExpand: "",
|
||||
defaultExpands: nil,
|
||||
expected: []string{
|
||||
`"customField":"123"`,
|
||||
`"test3@example.com"`, // emailVisibility=true
|
||||
},
|
||||
notExpected: []string{
|
||||
`"test@example.com"`,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "[emailVisibility] owner",
|
||||
auth: user,
|
||||
records: freshRecords(usersRecords),
|
||||
queryExpand: "",
|
||||
defaultExpands: nil,
|
||||
expected: []string{
|
||||
`"customField":"123"`,
|
||||
`"test3@example.com"`, // emailVisibility=true
|
||||
`"test@example.com"`, // owner
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "[emailVisibility] manager",
|
||||
auth: user,
|
||||
records: freshRecords(nologinRecords),
|
||||
queryExpand: "",
|
||||
defaultExpands: nil,
|
||||
expected: []string{
|
||||
`"customField":"123"`,
|
||||
`"test3@example.com"`,
|
||||
`"test@example.com"`,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "[emailVisibility] superuser",
|
||||
auth: superuser,
|
||||
records: freshRecords(nologinRecords),
|
||||
queryExpand: "",
|
||||
defaultExpands: nil,
|
||||
expected: []string{
|
||||
`"customField":"123"`,
|
||||
`"test3@example.com"`,
|
||||
`"test@example.com"`,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "[emailVisibility + expand] recursive auth rule checks (regular user)",
|
||||
auth: user,
|
||||
records: freshRecords(demo1Records),
|
||||
queryExpand: "",
|
||||
defaultExpands: []string{"rel_many"},
|
||||
expected: []string{
|
||||
`"customField":"123"`,
|
||||
`"expand":{"rel_many"`,
|
||||
`"expand":{}`,
|
||||
`"test@example.com"`,
|
||||
},
|
||||
notExpected: []string{
|
||||
`"id":"bgs820n361vj1qd"`,
|
||||
`"id":"oap640cot4yru2s"`,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "[emailVisibility + expand] recursive auth rule checks (superuser)",
|
||||
auth: superuser,
|
||||
records: freshRecords(demo1Records),
|
||||
queryExpand: "",
|
||||
defaultExpands: []string{"rel_many"},
|
||||
expected: []string{
|
||||
`"customField":"123"`,
|
||||
`"test@example.com"`,
|
||||
`"expand":{"rel_many"`,
|
||||
`"id":"bgs820n361vj1qd"`,
|
||||
`"id":"4q1xlclmfloku33"`,
|
||||
`"id":"oap640cot4yru2s"`,
|
||||
},
|
||||
notExpected: []string{
|
||||
`"expand":{}`,
|
||||
},
|
||||
},
|
||||
|
||||
// expand checks
|
||||
{
|
||||
name: "[expand] guest (query)",
|
||||
auth: nil,
|
||||
records: freshRecords(usersRecords),
|
||||
queryExpand: "rel",
|
||||
defaultExpands: nil,
|
||||
expected: []string{
|
||||
`"customField":"123"`,
|
||||
`"expand":{"rel"`,
|
||||
`"id":"llvuca81nly1qls"`,
|
||||
`"id":"0yxhwia2amd8gec"`,
|
||||
},
|
||||
notExpected: []string{
|
||||
`"expand":{}`,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "[expand] guest (default expands)",
|
||||
auth: nil,
|
||||
records: freshRecords(usersRecords),
|
||||
queryExpand: "",
|
||||
defaultExpands: []string{"rel"},
|
||||
expected: []string{
|
||||
`"customField":"123"`,
|
||||
`"expand":{"rel"`,
|
||||
`"id":"llvuca81nly1qls"`,
|
||||
`"id":"0yxhwia2amd8gec"`,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "[expand] @request.context=expand check",
|
||||
auth: nil,
|
||||
records: freshRecords(demo5Records),
|
||||
queryExpand: "rel_one",
|
||||
defaultExpands: []string{"rel_many"},
|
||||
expected: []string{
|
||||
`"customField":"123"`,
|
||||
`"expand":{}`,
|
||||
`"expand":{"`,
|
||||
`"rel_many":[{`,
|
||||
`"rel_one":{`,
|
||||
`"id":"i9naidtvr6qsgb4"`,
|
||||
`"id":"qzaqccwrmva4o1n"`,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
t.Run(s.name, func(t *testing.T) {
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
app.OnRecordEnrich().BindFunc(func(e *core.RecordEnrichEvent) error {
|
||||
e.Record.WithCustomData(true)
|
||||
e.Record.Set("customField", "123")
|
||||
return e.Next()
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/?expand="+s.queryExpand, nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
requestEvent := new(core.RequestEvent)
|
||||
requestEvent.App = app
|
||||
requestEvent.Request = req
|
||||
requestEvent.Response = rec
|
||||
requestEvent.Auth = s.auth
|
||||
|
||||
err := apis.EnrichRecords(requestEvent, s.records, s.defaultExpands...)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
raw, err := json.Marshal(s.records)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
rawStr := string(raw)
|
||||
|
||||
for _, str := range s.expected {
|
||||
if !strings.Contains(rawStr, str) {
|
||||
t.Fatalf("Expected\n%q\nin\n%v", str, rawStr)
|
||||
}
|
||||
}
|
||||
|
||||
for _, str := range s.notExpected {
|
||||
if strings.Contains(rawStr, str) {
|
||||
t.Fatalf("Didn't expected\n%q\nin\n%v", str, rawStr)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecordAuthResponseAuthRuleCheck(t *testing.T) {
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
event := new(core.RequestEvent)
|
||||
event.App = app
|
||||
event.Request = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
event.Response = httptest.NewRecorder()
|
||||
|
||||
user, err := app.FindAuthRecordByEmail("users", "test@example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
scenarios := []struct {
|
||||
name string
|
||||
rule *string
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
"admin only rule",
|
||||
nil,
|
||||
true,
|
||||
},
|
||||
{
|
||||
"empty rule",
|
||||
types.Pointer(""),
|
||||
false,
|
||||
},
|
||||
{
|
||||
"false rule",
|
||||
types.Pointer("1=2"),
|
||||
true,
|
||||
},
|
||||
{
|
||||
"true rule",
|
||||
types.Pointer("1=1"),
|
||||
false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
t.Run(s.name, func(t *testing.T) {
|
||||
user.Collection().AuthRule = s.rule
|
||||
|
||||
err := apis.RecordAuthResponse(event, user, "", nil)
|
||||
|
||||
hasErr := err != nil
|
||||
if s.expectError != hasErr {
|
||||
t.Fatalf("Expected hasErr %v, got %v (%v)", s.expectError, hasErr, err)
|
||||
}
|
||||
|
||||
// in all cases login alert shouldn't be send because of the empty auth method
|
||||
if app.TestMailer.TotalSend() != 0 {
|
||||
t.Fatalf("Expected no emails send, got %d:\n%v", app.TestMailer.TotalSend(), app.TestMailer.LastMessage().HTML)
|
||||
}
|
||||
|
||||
if !hasErr {
|
||||
return
|
||||
}
|
||||
|
||||
apiErr, ok := err.(*router.ApiError)
|
||||
|
||||
if !ok || apiErr == nil {
|
||||
t.Fatalf("Expected ApiError, got %v", apiErr)
|
||||
}
|
||||
|
||||
if apiErr.Status != http.StatusForbidden {
|
||||
t.Fatalf("Expected ApiError.Status %d, got %d", http.StatusForbidden, apiErr.Status)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecordAuthResponseAuthAlertCheck(t *testing.T) {
|
||||
const testFingerprint = "d0f88d6c87767262ba8e93d6acccd784"
|
||||
|
||||
scenarios := []struct {
|
||||
name string
|
||||
devices []string // mock existing device fingerprints
|
||||
expectDevices []string
|
||||
enabled bool
|
||||
expectEmail bool
|
||||
}{
|
||||
{
|
||||
name: "first login",
|
||||
devices: nil,
|
||||
expectDevices: []string{testFingerprint},
|
||||
enabled: true,
|
||||
expectEmail: false,
|
||||
},
|
||||
{
|
||||
name: "existing device",
|
||||
devices: []string{"1", testFingerprint},
|
||||
expectDevices: []string{"1", testFingerprint},
|
||||
enabled: true,
|
||||
expectEmail: false,
|
||||
},
|
||||
{
|
||||
name: "new device (< 5)",
|
||||
devices: []string{"1", "2"},
|
||||
expectDevices: []string{"1", "2", testFingerprint},
|
||||
enabled: true,
|
||||
expectEmail: true,
|
||||
},
|
||||
{
|
||||
name: "new device (>= 5)",
|
||||
devices: []string{"1", "2", "3", "4", "5"},
|
||||
expectDevices: []string{"2", "3", "4", "5", testFingerprint},
|
||||
enabled: true,
|
||||
expectEmail: true,
|
||||
},
|
||||
{
|
||||
name: "with disabled auth alert collection flag",
|
||||
devices: []string{"1", "2"},
|
||||
expectDevices: []string{"1", "2"},
|
||||
enabled: false,
|
||||
expectEmail: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
t.Run(s.name, func(t *testing.T) {
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
event := new(core.RequestEvent)
|
||||
event.App = app
|
||||
event.Request = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
event.Response = httptest.NewRecorder()
|
||||
|
||||
user, err := app.FindAuthRecordByEmail("users", "test@example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
user.Collection().MFA.Enabled = false
|
||||
user.Collection().AuthRule = types.Pointer("")
|
||||
user.Collection().AuthAlert.Enabled = s.enabled
|
||||
|
||||
// ensure that there are no other auth origins
|
||||
err = app.DeleteAllAuthOriginsByRecord(user)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
mockCreated := types.NowDateTime().Add(-time.Duration(len(s.devices)+1) * time.Second)
|
||||
// insert the mock devices
|
||||
for _, fingerprint := range s.devices {
|
||||
mockCreated = mockCreated.Add(1 * time.Second)
|
||||
d := core.NewAuthOrigin(app)
|
||||
d.SetCollectionRef(user.Collection().Id)
|
||||
d.SetRecordRef(user.Id)
|
||||
d.SetFingerprint(fingerprint)
|
||||
d.SetRaw("created", mockCreated)
|
||||
d.SetRaw("updated", mockCreated)
|
||||
if err = app.Save(d); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
err = apis.RecordAuthResponse(event, user, "example", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to resolve auth response: %v", err)
|
||||
}
|
||||
|
||||
var expectTotalSend int
|
||||
if s.expectEmail {
|
||||
expectTotalSend = 1
|
||||
}
|
||||
if total := app.TestMailer.TotalSend(); total != expectTotalSend {
|
||||
t.Fatalf("Expected %d sent emails, got %d", expectTotalSend, total)
|
||||
}
|
||||
|
||||
devices, err := app.FindAllAuthOriginsByRecord(user)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to retrieve auth origins: %v", err)
|
||||
}
|
||||
|
||||
if len(devices) != len(s.expectDevices) {
|
||||
t.Fatalf("Expected %d devices, got %d", len(s.expectDevices), len(devices))
|
||||
}
|
||||
|
||||
for _, fingerprint := range s.expectDevices {
|
||||
var exists bool
|
||||
fingerprints := make([]string, 0, len(devices))
|
||||
for _, d := range devices {
|
||||
if d.Fingerprint() == fingerprint {
|
||||
exists = true
|
||||
break
|
||||
}
|
||||
fingerprints = append(fingerprints, d.Fingerprint())
|
||||
}
|
||||
if !exists {
|
||||
t.Fatalf("Missing device with fingerprint %q:\n%v", fingerprint, fingerprints)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecordAuthResponseMFACheck(t *testing.T) {
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
user, err := app.FindAuthRecordByEmail("users", "test@example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
user2, err := app.FindAuthRecordByEmail("users", "test2@example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
event := new(core.RequestEvent)
|
||||
event.App = app
|
||||
event.Request = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
event.Response = rec
|
||||
|
||||
resetMFAs := func(authRecord *core.Record) {
|
||||
// ensure that mfa is enabled
|
||||
user.Collection().MFA.Enabled = true
|
||||
user.Collection().MFA.Duration = 5
|
||||
user.Collection().MFA.Rule = ""
|
||||
|
||||
mfas, err := app.FindAllMFAsByRecord(authRecord)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to retrieve mfas: %v", err)
|
||||
}
|
||||
for _, mfa := range mfas {
|
||||
if err := app.Delete(mfa); err != nil {
|
||||
t.Fatalf("Failed to delete mfa %q: %v", mfa.Id, err)
|
||||
}
|
||||
}
|
||||
|
||||
// reset response
|
||||
rec = httptest.NewRecorder()
|
||||
event.Response = rec
|
||||
}
|
||||
|
||||
totalMFAs := func(authRecord *core.Record) int {
|
||||
mfas, err := app.FindAllMFAsByRecord(authRecord)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to retrieve mfas: %v", err)
|
||||
}
|
||||
return len(mfas)
|
||||
}
|
||||
|
||||
t.Run("no collection MFA enabled", func(t *testing.T) {
|
||||
resetMFAs(user)
|
||||
|
||||
user.Collection().MFA.Enabled = false
|
||||
|
||||
err = apis.RecordAuthResponse(event, user, "example", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Expected nil, got error: %v", err)
|
||||
}
|
||||
|
||||
body := rec.Body.String()
|
||||
if strings.Contains(body, "mfaId") {
|
||||
t.Fatalf("Expected no mfaId in the response body, got\n%v", body)
|
||||
}
|
||||
if !strings.Contains(body, "token") {
|
||||
t.Fatalf("Expected auth token in the response body, got\n%v", body)
|
||||
}
|
||||
|
||||
if total := totalMFAs(user); total != 0 {
|
||||
t.Fatalf("Expected no mfa records to be created, got %d", total)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("no explicit auth method", func(t *testing.T) {
|
||||
resetMFAs(user)
|
||||
|
||||
err = apis.RecordAuthResponse(event, user, "", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Expected nil, got error: %v", err)
|
||||
}
|
||||
|
||||
body := rec.Body.String()
|
||||
if strings.Contains(body, "mfaId") {
|
||||
t.Fatalf("Expected no mfaId in the response body, got\n%v", body)
|
||||
}
|
||||
if !strings.Contains(body, "token") {
|
||||
t.Fatalf("Expected auth token in the response body, got\n%v", body)
|
||||
}
|
||||
|
||||
if total := totalMFAs(user); total != 0 {
|
||||
t.Fatalf("Expected no mfa records to be created, got %d", total)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("no mfa wanted (mfa rule check failure)", func(t *testing.T) {
|
||||
resetMFAs(user)
|
||||
user.Collection().MFA.Rule = "1=2"
|
||||
|
||||
err = apis.RecordAuthResponse(event, user, "example", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Expected nil, got error: %v", err)
|
||||
}
|
||||
|
||||
body := rec.Body.String()
|
||||
if strings.Contains(body, "mfaId") {
|
||||
t.Fatalf("Expected no mfaId in the response body, got\n%v", body)
|
||||
}
|
||||
if !strings.Contains(body, "token") {
|
||||
t.Fatalf("Expected auth token in the response body, got\n%v", body)
|
||||
}
|
||||
|
||||
if total := totalMFAs(user); total != 0 {
|
||||
t.Fatalf("Expected no mfa records to be created, got %d", total)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("mfa wanted (mfa rule check success)", func(t *testing.T) {
|
||||
resetMFAs(user)
|
||||
user.Collection().MFA.Rule = "1=1"
|
||||
|
||||
err = apis.RecordAuthResponse(event, user, "example", nil)
|
||||
if !errors.Is(err, apis.ErrMFA) {
|
||||
t.Fatalf("Expected ErrMFA, got: %v", err)
|
||||
}
|
||||
|
||||
body := rec.Body.String()
|
||||
if !strings.Contains(body, "mfaId") {
|
||||
t.Fatalf("Expected the created mfaId to be returned in the response body, got\n%v", body)
|
||||
}
|
||||
|
||||
if total := totalMFAs(user); total != 1 {
|
||||
t.Fatalf("Expected a single mfa record to be created, got %d", total)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("mfa first-time", func(t *testing.T) {
|
||||
resetMFAs(user)
|
||||
|
||||
err = apis.RecordAuthResponse(event, user, "example", nil)
|
||||
if !errors.Is(err, apis.ErrMFA) {
|
||||
t.Fatalf("Expected ErrMFA, got: %v", err)
|
||||
}
|
||||
|
||||
body := rec.Body.String()
|
||||
if !strings.Contains(body, "mfaId") {
|
||||
t.Fatalf("Expected the created mfaId to be returned in the response body, got\n%v", body)
|
||||
}
|
||||
|
||||
if total := totalMFAs(user); total != 1 {
|
||||
t.Fatalf("Expected a single mfa record to be created, got %d", total)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("mfa second-time with the same auth method", func(t *testing.T) {
|
||||
resetMFAs(user)
|
||||
|
||||
// create a dummy mfa record
|
||||
mfa := core.NewMFA(app)
|
||||
mfa.SetCollectionRef(user.Collection().Id)
|
||||
mfa.SetRecordRef(user.Id)
|
||||
mfa.SetMethod("example")
|
||||
if err = app.Save(mfa); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
event.Request = httptest.NewRequest(http.MethodGet, "/?mfaId="+mfa.Id, nil)
|
||||
|
||||
err = apis.RecordAuthResponse(event, user, "example", nil)
|
||||
if err == nil {
|
||||
t.Fatal("Expected error, got nil")
|
||||
}
|
||||
|
||||
if total := totalMFAs(user); total != 1 {
|
||||
t.Fatalf("Expected only 1 mfa record (the existing one), got %d", total)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("mfa second-time with the different auth method (query param)", func(t *testing.T) {
|
||||
resetMFAs(user)
|
||||
|
||||
// create a dummy mfa record
|
||||
mfa := core.NewMFA(app)
|
||||
mfa.SetCollectionRef(user.Collection().Id)
|
||||
mfa.SetRecordRef(user.Id)
|
||||
mfa.SetMethod("example1")
|
||||
if err = app.Save(mfa); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
event.Request = httptest.NewRequest(http.MethodGet, "/?mfaId="+mfa.Id, nil)
|
||||
|
||||
err = apis.RecordAuthResponse(event, user, "example2", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Expected nil, got error: %v", err)
|
||||
}
|
||||
|
||||
if total := totalMFAs(user); total != 0 {
|
||||
t.Fatalf("Expected the dummy mfa record to be deleted, found %d", total)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("mfa second-time with the different auth method (body param)", func(t *testing.T) {
|
||||
resetMFAs(user)
|
||||
|
||||
// create a dummy mfa record
|
||||
mfa := core.NewMFA(app)
|
||||
mfa.SetCollectionRef(user.Collection().Id)
|
||||
mfa.SetRecordRef(user.Id)
|
||||
mfa.SetMethod("example1")
|
||||
if err = app.Save(mfa); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
event.Request = httptest.NewRequest(http.MethodGet, "/", strings.NewReader(`{"mfaId":"`+mfa.Id+`"}`))
|
||||
event.Request.Header.Add("content-type", "application/json")
|
||||
|
||||
err = apis.RecordAuthResponse(event, user, "example2", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Expected nil, got error: %v", err)
|
||||
}
|
||||
|
||||
if total := totalMFAs(user); total != 0 {
|
||||
t.Fatalf("Expected the dummy mfa record to be deleted, found %d", total)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("missing mfa", func(t *testing.T) {
|
||||
resetMFAs(user)
|
||||
|
||||
event.Request = httptest.NewRequest(http.MethodGet, "/?mfaId=missing", nil)
|
||||
|
||||
err = apis.RecordAuthResponse(event, user, "example2", nil)
|
||||
if err == nil {
|
||||
t.Fatal("Expected error, got nil")
|
||||
}
|
||||
|
||||
if total := totalMFAs(user); total != 0 {
|
||||
t.Fatalf("Expected 0 mfa records, got %d", total)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("expired mfa", func(t *testing.T) {
|
||||
resetMFAs(user)
|
||||
|
||||
// create a dummy expired mfa record
|
||||
mfa := core.NewMFA(app)
|
||||
mfa.SetCollectionRef(user.Collection().Id)
|
||||
mfa.SetRecordRef(user.Id)
|
||||
mfa.SetMethod("example1")
|
||||
mfa.SetRaw("created", types.NowDateTime().Add(-1*time.Hour))
|
||||
mfa.SetRaw("updated", types.NowDateTime().Add(-1*time.Hour))
|
||||
if err = app.Save(mfa); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
event.Request = httptest.NewRequest(http.MethodGet, "/?mfaId="+mfa.Id, nil)
|
||||
|
||||
err = apis.RecordAuthResponse(event, user, "example2", nil)
|
||||
if err == nil {
|
||||
t.Fatal("Expected error, got nil")
|
||||
}
|
||||
|
||||
if totalMFAs(user) != 0 {
|
||||
t.Fatal("Expected the expired mfa record to be deleted")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("mfa for different auth record", func(t *testing.T) {
|
||||
resetMFAs(user)
|
||||
|
||||
// create a dummy expired mfa record
|
||||
mfa := core.NewMFA(app)
|
||||
mfa.SetCollectionRef(user2.Collection().Id)
|
||||
mfa.SetRecordRef(user2.Id)
|
||||
mfa.SetMethod("example1")
|
||||
if err = app.Save(mfa); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
event.Request = httptest.NewRequest(http.MethodGet, "/?mfaId="+mfa.Id, nil)
|
||||
|
||||
err = apis.RecordAuthResponse(event, user, "example2", nil)
|
||||
if err == nil {
|
||||
t.Fatal("Expected error, got nil")
|
||||
}
|
||||
|
||||
if total := totalMFAs(user); total != 0 {
|
||||
t.Fatalf("Expected no user mfas, got %d", total)
|
||||
}
|
||||
|
||||
if total := totalMFAs(user2); total != 1 {
|
||||
t.Fatalf("Expected only 1 user2 mfa, got %d", total)
|
||||
}
|
||||
})
|
||||
}
|
318
apis/serve.go
Normal file
318
apis/serve.go
Normal file
|
@ -0,0 +1,318 @@
|
|||
package apis
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/fatih/color"
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tools/hook"
|
||||
"github.com/pocketbase/pocketbase/tools/list"
|
||||
"github.com/pocketbase/pocketbase/tools/routine"
|
||||
"github.com/pocketbase/pocketbase/ui"
|
||||
"golang.org/x/crypto/acme"
|
||||
"golang.org/x/crypto/acme/autocert"
|
||||
)
|
||||
|
||||
// ServeConfig defines a configuration struct for apis.Serve().
|
||||
type ServeConfig struct {
|
||||
// ShowStartBanner indicates whether to show or hide the server start console message.
|
||||
ShowStartBanner bool
|
||||
|
||||
// HttpAddr is the TCP address to listen for the HTTP server (eg. "127.0.0.1:80").
|
||||
HttpAddr string
|
||||
|
||||
// HttpsAddr is the TCP address to listen for the HTTPS server (eg. "127.0.0.1:443").
|
||||
HttpsAddr string
|
||||
|
||||
// Optional domains list to use when issuing the TLS certificate.
|
||||
//
|
||||
// If not set, the host from the bound server address will be used.
|
||||
//
|
||||
// For convenience, for each "non-www" domain a "www" entry and
|
||||
// redirect will be automatically added.
|
||||
CertificateDomains []string
|
||||
|
||||
// AllowedOrigins is an optional list of CORS origins (default to "*").
|
||||
AllowedOrigins []string
|
||||
}
|
||||
|
||||
// Serve starts a new app web server.
|
||||
//
|
||||
// NB! The app should be bootstrapped before starting the web server.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// app.Bootstrap()
|
||||
// apis.Serve(app, apis.ServeConfig{
|
||||
// HttpAddr: "127.0.0.1:8080",
|
||||
// ShowStartBanner: false,
|
||||
// })
|
||||
func Serve(app core.App, config ServeConfig) error {
|
||||
if len(config.AllowedOrigins) == 0 {
|
||||
config.AllowedOrigins = []string{"*"}
|
||||
}
|
||||
|
||||
// ensure that the latest migrations are applied before starting the server
|
||||
err := app.RunAllMigrations()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
pbRouter, err := NewRouter(app)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
pbRouter.Bind(CORS(CORSConfig{
|
||||
AllowOrigins: config.AllowedOrigins,
|
||||
AllowMethods: []string{http.MethodGet, http.MethodHead, http.MethodPut, http.MethodPatch, http.MethodPost, http.MethodDelete},
|
||||
}))
|
||||
|
||||
pbRouter.GET("/_/{path...}", Static(ui.DistDirFS, false)).
|
||||
BindFunc(func(e *core.RequestEvent) error {
|
||||
// ignore root path
|
||||
if e.Request.PathValue(StaticWildcardParam) != "" {
|
||||
e.Response.Header().Set("Cache-Control", "max-age=1209600, stale-while-revalidate=86400")
|
||||
}
|
||||
|
||||
// add a default CSP
|
||||
if e.Response.Header().Get("Content-Security-Policy") == "" {
|
||||
e.Response.Header().Set("Content-Security-Policy", "default-src 'self'; style-src 'self' 'unsafe-inline'; img-src 'self' http://127.0.0.1:* https://tile.openstreetmap.org data: blob:; connect-src 'self' http://127.0.0.1:* https://nominatim.openstreetmap.org; script-src 'self' 'sha256-GRUzBA7PzKYug7pqxv5rJaec5bwDCw1Vo6/IXwvD3Tc='")
|
||||
}
|
||||
|
||||
return e.Next()
|
||||
}).
|
||||
Bind(Gzip())
|
||||
|
||||
// start http server
|
||||
// ---
|
||||
mainAddr := config.HttpAddr
|
||||
if config.HttpsAddr != "" {
|
||||
mainAddr = config.HttpsAddr
|
||||
}
|
||||
|
||||
var wwwRedirects []string
|
||||
|
||||
// extract the host names for the certificate host policy
|
||||
hostNames := config.CertificateDomains
|
||||
if len(hostNames) == 0 {
|
||||
host, _, _ := net.SplitHostPort(mainAddr)
|
||||
hostNames = append(hostNames, host)
|
||||
}
|
||||
for _, host := range hostNames {
|
||||
if strings.HasPrefix(host, "www.") {
|
||||
continue // explicitly set www host
|
||||
}
|
||||
|
||||
wwwHost := "www." + host
|
||||
if !list.ExistInSlice(wwwHost, hostNames) {
|
||||
hostNames = append(hostNames, wwwHost)
|
||||
wwwRedirects = append(wwwRedirects, wwwHost)
|
||||
}
|
||||
}
|
||||
|
||||
// implicit www->non-www redirect(s)
|
||||
if len(wwwRedirects) > 0 {
|
||||
pbRouter.Bind(wwwRedirect(wwwRedirects))
|
||||
}
|
||||
|
||||
certManager := &autocert.Manager{
|
||||
Prompt: autocert.AcceptTOS,
|
||||
Cache: autocert.DirCache(filepath.Join(app.DataDir(), core.LocalAutocertCacheDirName)),
|
||||
HostPolicy: autocert.HostWhitelist(hostNames...),
|
||||
}
|
||||
|
||||
// base request context used for cancelling long running requests
|
||||
// like the SSE connections
|
||||
baseCtx, cancelBaseCtx := context.WithCancel(context.Background())
|
||||
defer cancelBaseCtx()
|
||||
|
||||
server := &http.Server{
|
||||
TLSConfig: &tls.Config{
|
||||
MinVersion: tls.VersionTLS12,
|
||||
GetCertificate: certManager.GetCertificate,
|
||||
NextProtos: []string{acme.ALPNProto},
|
||||
},
|
||||
// higher defaults to accommodate large file uploads/downloads
|
||||
WriteTimeout: 5 * time.Minute,
|
||||
ReadTimeout: 5 * time.Minute,
|
||||
ReadHeaderTimeout: 1 * time.Minute,
|
||||
Addr: mainAddr,
|
||||
BaseContext: func(l net.Listener) context.Context {
|
||||
return baseCtx
|
||||
},
|
||||
ErrorLog: log.New(&serverErrorLogWriter{app: app}, "", 0),
|
||||
}
|
||||
|
||||
serveEvent := new(core.ServeEvent)
|
||||
serveEvent.App = app
|
||||
serveEvent.Router = pbRouter
|
||||
serveEvent.Server = server
|
||||
serveEvent.CertManager = certManager
|
||||
serveEvent.InstallerFunc = DefaultInstallerFunc
|
||||
|
||||
var listener net.Listener
|
||||
|
||||
// graceful shutdown
|
||||
// ---------------------------------------------------------------
|
||||
// WaitGroup to block until server.ShutDown() returns because Serve and similar methods exit immediately.
|
||||
// Note that the WaitGroup would do nothing if the app.OnTerminate() hook isn't triggered.
|
||||
var wg sync.WaitGroup
|
||||
|
||||
// try to gracefully shutdown the server on app termination
|
||||
app.OnTerminate().Bind(&hook.Handler[*core.TerminateEvent]{
|
||||
Id: "pbGracefulShutdown",
|
||||
Func: func(te *core.TerminateEvent) error {
|
||||
cancelBaseCtx()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
|
||||
defer cancel()
|
||||
|
||||
wg.Add(1)
|
||||
|
||||
_ = server.Shutdown(ctx)
|
||||
|
||||
if te.IsRestart {
|
||||
// wait for execve and other handlers up to 3 seconds before exit
|
||||
time.AfterFunc(3*time.Second, func() {
|
||||
wg.Done()
|
||||
})
|
||||
} else {
|
||||
wg.Done()
|
||||
}
|
||||
|
||||
return te.Next()
|
||||
},
|
||||
Priority: -9999,
|
||||
})
|
||||
|
||||
// wait for the graceful shutdown to complete before exit
|
||||
defer func() {
|
||||
wg.Wait()
|
||||
|
||||
if listener != nil {
|
||||
_ = listener.Close()
|
||||
}
|
||||
}()
|
||||
// ---------------------------------------------------------------
|
||||
|
||||
var baseURL string
|
||||
|
||||
// trigger the OnServe hook and start the tcp listener
|
||||
serveHookErr := app.OnServe().Trigger(serveEvent, func(e *core.ServeEvent) error {
|
||||
handler, err := e.Router.BuildMux()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
e.Server.Handler = handler
|
||||
|
||||
if config.HttpsAddr == "" {
|
||||
baseURL = "http://" + serverAddrToHost(serveEvent.Server.Addr)
|
||||
} else {
|
||||
baseURL = "https://"
|
||||
if len(config.CertificateDomains) > 0 {
|
||||
baseURL += config.CertificateDomains[0]
|
||||
} else {
|
||||
baseURL += serverAddrToHost(serveEvent.Server.Addr)
|
||||
}
|
||||
}
|
||||
|
||||
addr := e.Server.Addr
|
||||
if addr == "" {
|
||||
// fallback similar to the std Server.ListenAndServe/ListenAndServeTLS
|
||||
if config.HttpsAddr != "" {
|
||||
addr = ":https"
|
||||
} else {
|
||||
addr = ":http"
|
||||
}
|
||||
}
|
||||
|
||||
listener, err = net.Listen("tcp", addr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if e.InstallerFunc != nil {
|
||||
app := e.App
|
||||
installerFunc := e.InstallerFunc
|
||||
routine.FireAndForget(func() {
|
||||
if err := loadInstaller(app, baseURL, installerFunc); err != nil {
|
||||
app.Logger().Warn("Failed to initialize installer", "error", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
if serveHookErr != nil {
|
||||
return serveHookErr
|
||||
}
|
||||
|
||||
if listener == nil {
|
||||
//nolint:staticcheck
|
||||
return errors.New("The OnServe finalizer wasn't invoked. Did you forget to call the ServeEvent.Next() method?")
|
||||
}
|
||||
|
||||
if config.ShowStartBanner {
|
||||
date := new(strings.Builder)
|
||||
log.New(date, "", log.LstdFlags).Print()
|
||||
|
||||
bold := color.New(color.Bold).Add(color.FgGreen)
|
||||
bold.Printf(
|
||||
"%s Server started at %s\n",
|
||||
strings.TrimSpace(date.String()),
|
||||
color.CyanString("%s", baseURL),
|
||||
)
|
||||
|
||||
regular := color.New()
|
||||
regular.Printf("├─ REST API: %s\n", color.CyanString("%s/api/", baseURL))
|
||||
regular.Printf("└─ Dashboard: %s\n", color.CyanString("%s/_/", baseURL))
|
||||
}
|
||||
|
||||
var serveErr error
|
||||
if config.HttpsAddr != "" {
|
||||
if config.HttpAddr != "" {
|
||||
// start an additional HTTP server for redirecting the traffic to the HTTPS version
|
||||
go http.ListenAndServe(config.HttpAddr, certManager.HTTPHandler(nil))
|
||||
}
|
||||
|
||||
// start HTTPS server
|
||||
serveErr = serveEvent.Server.ServeTLS(listener, "", "")
|
||||
} else {
|
||||
// OR start HTTP server
|
||||
serveErr = serveEvent.Server.Serve(listener)
|
||||
}
|
||||
if serveErr != nil && !errors.Is(serveErr, http.ErrServerClosed) {
|
||||
return serveErr
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// serverAddrToHost loosely converts http.Server.Addr string into a host to print.
|
||||
func serverAddrToHost(addr string) string {
|
||||
if addr == "" || strings.HasSuffix(addr, ":http") || strings.HasSuffix(addr, ":https") {
|
||||
return "127.0.0.1"
|
||||
}
|
||||
return addr
|
||||
}
|
||||
|
||||
type serverErrorLogWriter struct {
|
||||
app core.App
|
||||
}
|
||||
|
||||
func (s *serverErrorLogWriter) Write(p []byte) (int, error) {
|
||||
s.app.Logger().Debug(strings.TrimSpace(string(p)))
|
||||
|
||||
return len(p), nil
|
||||
}
|
143
apis/settings.go
Normal file
143
apis/settings.go
Normal file
|
@ -0,0 +1,143 @@
|
|||
package apis
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
validation "github.com/go-ozzo/ozzo-validation/v4"
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/forms"
|
||||
"github.com/pocketbase/pocketbase/tools/router"
|
||||
)
|
||||
|
||||
// bindSettingsApi registers the settings api endpoints.
|
||||
func bindSettingsApi(app core.App, rg *router.RouterGroup[*core.RequestEvent]) {
|
||||
subGroup := rg.Group("/settings").Bind(RequireSuperuserAuth())
|
||||
subGroup.GET("", settingsList)
|
||||
subGroup.PATCH("", settingsSet)
|
||||
subGroup.POST("/test/s3", settingsTestS3)
|
||||
subGroup.POST("/test/email", settingsTestEmail)
|
||||
subGroup.POST("/apple/generate-client-secret", settingsGenerateAppleClientSecret)
|
||||
}
|
||||
|
||||
func settingsList(e *core.RequestEvent) error {
|
||||
clone, err := e.App.Settings().Clone()
|
||||
if err != nil {
|
||||
return e.InternalServerError("", err)
|
||||
}
|
||||
|
||||
event := new(core.SettingsListRequestEvent)
|
||||
event.RequestEvent = e
|
||||
event.Settings = clone
|
||||
|
||||
return e.App.OnSettingsListRequest().Trigger(event, func(e *core.SettingsListRequestEvent) error {
|
||||
return execAfterSuccessTx(true, e.App, func() error {
|
||||
return e.JSON(http.StatusOK, e.Settings)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func settingsSet(e *core.RequestEvent) error {
|
||||
event := new(core.SettingsUpdateRequestEvent)
|
||||
event.RequestEvent = e
|
||||
|
||||
if clone, err := e.App.Settings().Clone(); err == nil {
|
||||
event.OldSettings = clone
|
||||
} else {
|
||||
return e.BadRequestError("", err)
|
||||
}
|
||||
|
||||
if clone, err := e.App.Settings().Clone(); err == nil {
|
||||
event.NewSettings = clone
|
||||
} else {
|
||||
return e.BadRequestError("", err)
|
||||
}
|
||||
|
||||
if err := e.BindBody(&event.NewSettings); err != nil {
|
||||
return e.BadRequestError("An error occurred while loading the submitted data.", err)
|
||||
}
|
||||
|
||||
return e.App.OnSettingsUpdateRequest().Trigger(event, func(e *core.SettingsUpdateRequestEvent) error {
|
||||
err := e.App.Save(e.NewSettings)
|
||||
if err != nil {
|
||||
return e.BadRequestError("An error occurred while saving the new settings.", err)
|
||||
}
|
||||
|
||||
appSettings, err := e.App.Settings().Clone()
|
||||
if err != nil {
|
||||
return e.InternalServerError("Failed to clone app settings.", err)
|
||||
}
|
||||
|
||||
return execAfterSuccessTx(true, e.App, func() error {
|
||||
return e.JSON(http.StatusOK, appSettings)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func settingsTestS3(e *core.RequestEvent) error {
|
||||
form := forms.NewTestS3Filesystem(e.App)
|
||||
|
||||
// load request
|
||||
if err := e.BindBody(form); err != nil {
|
||||
return e.BadRequestError("An error occurred while loading the submitted data.", err)
|
||||
}
|
||||
|
||||
// send
|
||||
if err := form.Submit(); err != nil {
|
||||
// form error
|
||||
if fErr, ok := err.(validation.Errors); ok {
|
||||
return e.BadRequestError("Failed to test the S3 filesystem.", fErr)
|
||||
}
|
||||
|
||||
// mailer error
|
||||
return e.BadRequestError("Failed to test the S3 filesystem. Raw error: \n"+err.Error(), nil)
|
||||
}
|
||||
|
||||
return e.NoContent(http.StatusNoContent)
|
||||
}
|
||||
|
||||
func settingsTestEmail(e *core.RequestEvent) error {
|
||||
form := forms.NewTestEmailSend(e.App)
|
||||
|
||||
// load request
|
||||
if err := e.BindBody(form); err != nil {
|
||||
return e.BadRequestError("An error occurred while loading the submitted data.", err)
|
||||
}
|
||||
|
||||
// send
|
||||
if err := form.Submit(); err != nil {
|
||||
// form error
|
||||
if fErr, ok := err.(validation.Errors); ok {
|
||||
return e.BadRequestError("Failed to send the test email.", fErr)
|
||||
}
|
||||
|
||||
// mailer error
|
||||
return e.BadRequestError("Failed to send the test email. Raw error: \n"+err.Error(), nil)
|
||||
}
|
||||
|
||||
return e.NoContent(http.StatusNoContent)
|
||||
}
|
||||
|
||||
func settingsGenerateAppleClientSecret(e *core.RequestEvent) error {
|
||||
form := forms.NewAppleClientSecretCreate(e.App)
|
||||
|
||||
// load request
|
||||
if err := e.BindBody(form); err != nil {
|
||||
return e.BadRequestError("An error occurred while loading the submitted data.", err)
|
||||
}
|
||||
|
||||
// generate
|
||||
secret, err := form.Submit()
|
||||
if err != nil {
|
||||
// form error
|
||||
if fErr, ok := err.(validation.Errors); ok {
|
||||
return e.BadRequestError("Invalid client secret data.", fErr)
|
||||
}
|
||||
|
||||
// secret generation error
|
||||
return e.BadRequestError("Failed to generate client secret. Raw error: \n"+err.Error(), nil)
|
||||
}
|
||||
|
||||
return e.JSON(http.StatusOK, map[string]string{
|
||||
"secret": secret,
|
||||
})
|
||||
}
|
641
apis/settings_test.go
Normal file
641
apis/settings_test.go
Normal file
|
@ -0,0 +1,641 @@
|
|||
package apis_test
|
||||
|
||||
import (
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
"crypto/x509"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tests"
|
||||
)
|
||||
|
||||
func TestSettingsList(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "unauthorized",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/settings",
|
||||
ExpectedStatus: 401,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "authorized as regular user",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/settings",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "authorized as superuser",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/settings",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
`"meta":{`,
|
||||
`"logs":{`,
|
||||
`"smtp":{`,
|
||||
`"s3":{`,
|
||||
`"backups":{`,
|
||||
`"batch":{`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnSettingsListRequest": 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "OnSettingsListRequest tx body write check",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/settings",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
app.OnSettingsListRequest().BindFunc(func(e *core.SettingsListRequestEvent) error {
|
||||
original := e.App
|
||||
return e.App.RunInTransaction(func(txApp core.App) error {
|
||||
e.App = txApp
|
||||
defer func() { e.App = original }()
|
||||
|
||||
if err := e.Next(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return e.BadRequestError("TX_ERROR", nil)
|
||||
})
|
||||
})
|
||||
},
|
||||
ExpectedStatus: 400,
|
||||
ExpectedEvents: map[string]int{"OnSettingsListRequest": 1},
|
||||
ExpectedContent: []string{"TX_ERROR"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSettingsSet(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
validData := `{
|
||||
"meta":{"appName":"update_test"},
|
||||
"s3":{"secret": "s3_secret"},
|
||||
"backups":{"s3":{"secret":"backups_s3_secret"}}
|
||||
}`
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "unauthorized",
|
||||
Method: http.MethodPatch,
|
||||
URL: "/api/settings",
|
||||
Body: strings.NewReader(validData),
|
||||
ExpectedStatus: 401,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "authorized as regular user",
|
||||
Method: http.MethodPatch,
|
||||
URL: "/api/settings",
|
||||
Body: strings.NewReader(validData),
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "authorized as superuser submitting empty data",
|
||||
Method: http.MethodPatch,
|
||||
URL: "/api/settings",
|
||||
Body: strings.NewReader(``),
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
`"meta":{`,
|
||||
`"logs":{`,
|
||||
`"smtp":{`,
|
||||
`"s3":{`,
|
||||
`"backups":{`,
|
||||
`"batch":{`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnSettingsUpdateRequest": 1,
|
||||
"OnModelUpdate": 1,
|
||||
"OnModelUpdateExecute": 1,
|
||||
"OnModelAfterUpdateSuccess": 1,
|
||||
"OnModelValidate": 1,
|
||||
"OnSettingsReload": 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "authorized as superuser submitting invalid data",
|
||||
Method: http.MethodPatch,
|
||||
URL: "/api/settings",
|
||||
Body: strings.NewReader(`{"meta":{"appName":""}}`),
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
|
||||
},
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{
|
||||
`"data":{`,
|
||||
`"meta":{"appName":{"code":"validation_required"`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnModelUpdate": 1,
|
||||
"OnModelAfterUpdateError": 1,
|
||||
"OnModelValidate": 1,
|
||||
"OnSettingsUpdateRequest": 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "authorized as superuser submitting valid data",
|
||||
Method: http.MethodPatch,
|
||||
URL: "/api/settings",
|
||||
Body: strings.NewReader(validData),
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
`"meta":{`,
|
||||
`"logs":{`,
|
||||
`"smtp":{`,
|
||||
`"s3":{`,
|
||||
`"backups":{`,
|
||||
`"batch":{`,
|
||||
`"appName":"update_test"`,
|
||||
},
|
||||
NotExpectedContent: []string{
|
||||
"secret",
|
||||
"password",
|
||||
},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnSettingsUpdateRequest": 1,
|
||||
"OnModelUpdate": 1,
|
||||
"OnModelUpdateExecute": 1,
|
||||
"OnModelAfterUpdateSuccess": 1,
|
||||
"OnModelValidate": 1,
|
||||
"OnSettingsReload": 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "OnSettingsUpdateRequest tx body write check",
|
||||
Method: http.MethodPatch,
|
||||
URL: "/api/settings",
|
||||
Body: strings.NewReader(validData),
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
app.OnSettingsUpdateRequest().BindFunc(func(e *core.SettingsUpdateRequestEvent) error {
|
||||
original := e.App
|
||||
return e.App.RunInTransaction(func(txApp core.App) error {
|
||||
e.App = txApp
|
||||
defer func() { e.App = original }()
|
||||
|
||||
if err := e.Next(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return e.BadRequestError("TX_ERROR", nil)
|
||||
})
|
||||
})
|
||||
},
|
||||
ExpectedStatus: 400,
|
||||
ExpectedEvents: map[string]int{"OnSettingsUpdateRequest": 1},
|
||||
ExpectedContent: []string{"TX_ERROR"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSettingsTestS3(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "unauthorized",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/settings/test/s3",
|
||||
ExpectedStatus: 401,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "authorized as regular user",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/settings/test/s3",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "authorized as superuser (missing body + no s3)",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/settings/test/s3",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
|
||||
},
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{
|
||||
`"data":{`,
|
||||
`"filesystem":{`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "authorized as superuser (invalid filesystem)",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/settings/test/s3",
|
||||
Body: strings.NewReader(`{"filesystem":"invalid"}`),
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
|
||||
},
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{
|
||||
`"data":{`,
|
||||
`"filesystem":{`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "authorized as superuser (valid filesystem and no s3)",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/settings/test/s3",
|
||||
Body: strings.NewReader(`{"filesystem":"storage"}`),
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
|
||||
},
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{
|
||||
`"data":{}`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSettingsTestEmail(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "unauthorized",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/settings/test/email",
|
||||
Body: strings.NewReader(`{
|
||||
"template": "verification",
|
||||
"email": "test@example.com"
|
||||
}`),
|
||||
ExpectedStatus: 401,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "authorized as regular user",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/settings/test/email",
|
||||
Body: strings.NewReader(`{
|
||||
"template": "verification",
|
||||
"email": "test@example.com"
|
||||
}`),
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "authorized as superuser (invalid body)",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/settings/test/email",
|
||||
Body: strings.NewReader(`{`),
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
|
||||
},
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "authorized as superuser (empty json)",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/settings/test/email",
|
||||
Body: strings.NewReader(`{}`),
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
|
||||
},
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{
|
||||
`"email":{"code":"validation_required"`,
|
||||
`"template":{"code":"validation_required"`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "authorized as superuser (verifiation template)",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/settings/test/email",
|
||||
Body: strings.NewReader(`{
|
||||
"template": "verification",
|
||||
"email": "test@example.com"
|
||||
}`),
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
|
||||
},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
if app.TestMailer.TotalSend() != 1 {
|
||||
t.Fatalf("[verification] Expected 1 sent email, got %d", app.TestMailer.TotalSend())
|
||||
}
|
||||
|
||||
if len(app.TestMailer.LastMessage().To) != 1 {
|
||||
t.Fatalf("[verification] Expected 1 recipient, got %v", app.TestMailer.LastMessage().To)
|
||||
}
|
||||
|
||||
if app.TestMailer.LastMessage().To[0].Address != "test@example.com" {
|
||||
t.Fatalf("[verification] Expected the email to be sent to %s, got %s", "test@example.com", app.TestMailer.LastMessage().To[0].Address)
|
||||
}
|
||||
|
||||
if !strings.Contains(app.TestMailer.LastMessage().HTML, "Verify") {
|
||||
t.Fatalf("[verification] Expected to sent a verification email, got \n%v\n%v", app.TestMailer.LastMessage().Subject, app.TestMailer.LastMessage().HTML)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 204,
|
||||
ExpectedContent: []string{},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnMailerSend": 1,
|
||||
"OnMailerRecordVerificationSend": 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "authorized as superuser (password reset template)",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/settings/test/email",
|
||||
Body: strings.NewReader(`{
|
||||
"template": "password-reset",
|
||||
"email": "test@example.com"
|
||||
}`),
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
|
||||
},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
if app.TestMailer.TotalSend() != 1 {
|
||||
t.Fatalf("[password-reset] Expected 1 sent email, got %d", app.TestMailer.TotalSend())
|
||||
}
|
||||
|
||||
if len(app.TestMailer.LastMessage().To) != 1 {
|
||||
t.Fatalf("[password-reset] Expected 1 recipient, got %v", app.TestMailer.LastMessage().To)
|
||||
}
|
||||
|
||||
if app.TestMailer.LastMessage().To[0].Address != "test@example.com" {
|
||||
t.Fatalf("[password-reset] Expected the email to be sent to %s, got %s", "test@example.com", app.TestMailer.LastMessage().To[0].Address)
|
||||
}
|
||||
|
||||
if !strings.Contains(app.TestMailer.LastMessage().HTML, "Reset password") {
|
||||
t.Fatalf("[password-reset] Expected to sent a password-reset email, got \n%v\n%v", app.TestMailer.LastMessage().Subject, app.TestMailer.LastMessage().HTML)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 204,
|
||||
ExpectedContent: []string{},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnMailerSend": 1,
|
||||
"OnMailerRecordPasswordResetSend": 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "authorized as superuser (email change)",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/settings/test/email",
|
||||
Body: strings.NewReader(`{
|
||||
"template": "email-change",
|
||||
"email": "test@example.com"
|
||||
}`),
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
|
||||
},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
if app.TestMailer.TotalSend() != 1 {
|
||||
t.Fatalf("[email-change] Expected 1 sent email, got %d", app.TestMailer.TotalSend())
|
||||
}
|
||||
|
||||
if len(app.TestMailer.LastMessage().To) != 1 {
|
||||
t.Fatalf("[email-change] Expected 1 recipient, got %v", app.TestMailer.LastMessage().To)
|
||||
}
|
||||
|
||||
if app.TestMailer.LastMessage().To[0].Address != "test@example.com" {
|
||||
t.Fatalf("[email-change] Expected the email to be sent to %s, got %s", "test@example.com", app.TestMailer.LastMessage().To[0].Address)
|
||||
}
|
||||
|
||||
if !strings.Contains(app.TestMailer.LastMessage().HTML, "Confirm new email") {
|
||||
t.Fatalf("[email-change] Expected to sent a confirm new email email, got \n%v\n%v", app.TestMailer.LastMessage().Subject, app.TestMailer.LastMessage().HTML)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 204,
|
||||
ExpectedContent: []string{},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnMailerSend": 1,
|
||||
"OnMailerRecordEmailChangeSend": 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "authorized as superuser (otp)",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/settings/test/email",
|
||||
Body: strings.NewReader(`{
|
||||
"template": "otp",
|
||||
"email": "test@example.com"
|
||||
}`),
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
|
||||
},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
if app.TestMailer.TotalSend() != 1 {
|
||||
t.Fatalf("[otp] Expected 1 sent email, got %d", app.TestMailer.TotalSend())
|
||||
}
|
||||
|
||||
if len(app.TestMailer.LastMessage().To) != 1 {
|
||||
t.Fatalf("[otp] Expected 1 recipient, got %v", app.TestMailer.LastMessage().To)
|
||||
}
|
||||
|
||||
if app.TestMailer.LastMessage().To[0].Address != "test@example.com" {
|
||||
t.Fatalf("[otp] Expected the email to be sent to %s, got %s", "test@example.com", app.TestMailer.LastMessage().To[0].Address)
|
||||
}
|
||||
|
||||
if !strings.Contains(app.TestMailer.LastMessage().HTML, "one-time password") {
|
||||
t.Fatalf("[otp] Expected to sent OTP email, got \n%v\n%v", app.TestMailer.LastMessage().Subject, app.TestMailer.LastMessage().HTML)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 204,
|
||||
ExpectedContent: []string{},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnMailerSend": 1,
|
||||
"OnMailerRecordOTPSend": 1,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateAppleClientSecret(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
encodedKey, err := x509.MarshalPKCS8PrivateKey(key)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
privatePem := pem.EncodeToMemory(
|
||||
&pem.Block{
|
||||
Type: "PRIVATE KEY",
|
||||
Bytes: encodedKey,
|
||||
},
|
||||
)
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "unauthorized",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/settings/apple/generate-client-secret",
|
||||
ExpectedStatus: 401,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "authorized as regular user",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/settings/apple/generate-client-secret",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "authorized as superuser (invalid body)",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/settings/apple/generate-client-secret",
|
||||
Body: strings.NewReader(`{`),
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
|
||||
},
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "authorized as superuser (empty json)",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/settings/apple/generate-client-secret",
|
||||
Body: strings.NewReader(`{}`),
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
|
||||
},
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{
|
||||
`"clientId":{"code":"validation_required"`,
|
||||
`"teamId":{"code":"validation_required"`,
|
||||
`"keyId":{"code":"validation_required"`,
|
||||
`"privateKey":{"code":"validation_required"`,
|
||||
`"duration":{"code":"validation_required"`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "authorized as superuser (invalid data)",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/settings/apple/generate-client-secret",
|
||||
Body: strings.NewReader(`{
|
||||
"clientId": "",
|
||||
"teamId": "123456789",
|
||||
"keyId": "123456789",
|
||||
"privateKey": "invalid",
|
||||
"duration": -1
|
||||
}`),
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
|
||||
},
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{
|
||||
`"clientId":{"code":"validation_required"`,
|
||||
`"teamId":{"code":"validation_length_invalid"`,
|
||||
`"keyId":{"code":"validation_length_invalid"`,
|
||||
`"privateKey":{"code":"validation_match_invalid"`,
|
||||
`"duration":{"code":"validation_min_greater_equal_than_required"`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "authorized as superuser (valid data)",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/settings/apple/generate-client-secret",
|
||||
Body: strings.NewReader(fmt.Sprintf(`{
|
||||
"clientId": "123",
|
||||
"teamId": "1234567890",
|
||||
"keyId": "1234567891",
|
||||
"privateKey": %q,
|
||||
"duration": 1
|
||||
}`, privatePem)),
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
`"secret":"`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue