1
0
Fork 0

Adding upstream version 0.28.1.

Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
Daniel Baumann 2025-05-22 10:57:38 +02:00
parent 88f1d47ab6
commit e28c88ef14
Signed by: daniel
GPG key ID: FBB4F0E80A80222F
933 changed files with 194711 additions and 0 deletions

1538
core/app.go Normal file

File diff suppressed because it is too large Load diff

239
core/auth_origin_model.go Normal file
View file

@ -0,0 +1,239 @@
package core
import (
"context"
"errors"
"slices"
validation "github.com/go-ozzo/ozzo-validation/v4"
"github.com/pocketbase/dbx"
"github.com/pocketbase/pocketbase/tools/hook"
"github.com/pocketbase/pocketbase/tools/types"
)
const CollectionNameAuthOrigins = "_authOrigins"
var (
_ Model = (*AuthOrigin)(nil)
_ PreValidator = (*AuthOrigin)(nil)
_ RecordProxy = (*AuthOrigin)(nil)
)
// AuthOrigin defines a Record proxy for working with the authOrigins collection.
type AuthOrigin struct {
*Record
}
// NewAuthOrigin instantiates and returns a new blank *AuthOrigin model.
//
// Example usage:
//
// origin := core.NewOrigin(app)
// origin.SetRecordRef(user.Id)
// origin.SetCollectionRef(user.Collection().Id)
// origin.SetFingerprint("...")
// app.Save(origin)
func NewAuthOrigin(app App) *AuthOrigin {
m := &AuthOrigin{}
c, err := app.FindCachedCollectionByNameOrId(CollectionNameAuthOrigins)
if err != nil {
// this is just to make tests easier since authOrigins is a system collection and it is expected to be always accessible
// (note: the loaded record is further checked on AuthOrigin.PreValidate())
c = NewBaseCollection("@___invalid___")
}
m.Record = NewRecord(c)
return m
}
// PreValidate implements the [PreValidator] interface and checks
// whether the proxy is properly loaded.
func (m *AuthOrigin) PreValidate(ctx context.Context, app App) error {
if m.Record == nil || m.Record.Collection().Name != CollectionNameAuthOrigins {
return errors.New("missing or invalid AuthOrigin ProxyRecord")
}
return nil
}
// ProxyRecord returns the proxied Record model.
func (m *AuthOrigin) ProxyRecord() *Record {
return m.Record
}
// SetProxyRecord loads the specified record model into the current proxy.
func (m *AuthOrigin) SetProxyRecord(record *Record) {
m.Record = record
}
// CollectionRef returns the "collectionRef" field value.
func (m *AuthOrigin) CollectionRef() string {
return m.GetString("collectionRef")
}
// SetCollectionRef updates the "collectionRef" record field value.
func (m *AuthOrigin) SetCollectionRef(collectionId string) {
m.Set("collectionRef", collectionId)
}
// RecordRef returns the "recordRef" record field value.
func (m *AuthOrigin) RecordRef() string {
return m.GetString("recordRef")
}
// SetRecordRef updates the "recordRef" record field value.
func (m *AuthOrigin) SetRecordRef(recordId string) {
m.Set("recordRef", recordId)
}
// Fingerprint returns the "fingerprint" record field value.
func (m *AuthOrigin) Fingerprint() string {
return m.GetString("fingerprint")
}
// SetFingerprint updates the "fingerprint" record field value.
func (m *AuthOrigin) SetFingerprint(fingerprint string) {
m.Set("fingerprint", fingerprint)
}
// Created returns the "created" record field value.
func (m *AuthOrigin) Created() types.DateTime {
return m.GetDateTime("created")
}
// Updated returns the "updated" record field value.
func (m *AuthOrigin) Updated() types.DateTime {
return m.GetDateTime("updated")
}
func (app *BaseApp) registerAuthOriginHooks() {
recordRefHooks[*AuthOrigin](app, CollectionNameAuthOrigins, CollectionTypeAuth)
// delete existing auth origins on password change
app.OnRecordUpdate().Bind(&hook.Handler[*RecordEvent]{
Func: func(e *RecordEvent) error {
err := e.Next()
if err != nil || !e.Record.Collection().IsAuth() {
return err
}
old := e.Record.Original().GetString(FieldNamePassword + ":hash")
new := e.Record.GetString(FieldNamePassword + ":hash")
if old != new {
err = e.App.DeleteAllAuthOriginsByRecord(e.Record)
if err != nil {
e.App.Logger().Warn(
"Failed to delete all previous auth origin fingerprints",
"error", err,
"recordId", e.Record.Id,
"collectionId", e.Record.Collection().Id,
)
}
}
return nil
},
Priority: 99,
})
}
// -------------------------------------------------------------------
// recordRefHooks registers common hooks that are usually used with record proxies
// that have polymorphic record relations (aka. "collectionRef" and "recordRef" fields).
func recordRefHooks[T RecordProxy](app App, collectionName string, optCollectionTypes ...string) {
app.OnRecordValidate(collectionName).Bind(&hook.Handler[*RecordEvent]{
Func: func(e *RecordEvent) error {
collectionId := e.Record.GetString("collectionRef")
err := validation.Validate(collectionId, validation.Required, validation.By(validateCollectionId(e.App, optCollectionTypes...)))
if err != nil {
return validation.Errors{"collectionRef": err}
}
recordId := e.Record.GetString("recordRef")
err = validation.Validate(recordId, validation.Required, validation.By(validateRecordId(e.App, collectionId)))
if err != nil {
return validation.Errors{"recordRef": err}
}
return e.Next()
},
Priority: 99,
})
// delete on collection ref delete
app.OnCollectionDeleteExecute().Bind(&hook.Handler[*CollectionEvent]{
Func: func(e *CollectionEvent) error {
if e.Collection.Name == collectionName || (len(optCollectionTypes) > 0 && !slices.Contains(optCollectionTypes, e.Collection.Type)) {
return e.Next()
}
originalApp := e.App
txErr := e.App.RunInTransaction(func(txApp App) error {
e.App = txApp
if err := e.Next(); err != nil {
return err
}
rels, err := txApp.FindAllRecords(collectionName, dbx.HashExp{"collectionRef": e.Collection.Id})
if err != nil {
return err
}
for _, mfa := range rels {
if err := txApp.Delete(mfa); err != nil {
return err
}
}
return nil
})
e.App = originalApp
return txErr
},
Priority: 99,
})
// delete on record ref delete
app.OnRecordDeleteExecute().Bind(&hook.Handler[*RecordEvent]{
Func: func(e *RecordEvent) error {
if e.Record.Collection().Name == collectionName ||
(len(optCollectionTypes) > 0 && !slices.Contains(optCollectionTypes, e.Record.Collection().Type)) {
return e.Next()
}
originalApp := e.App
txErr := e.App.RunInTransaction(func(txApp App) error {
e.App = txApp
if err := e.Next(); err != nil {
return err
}
rels, err := txApp.FindAllRecords(collectionName, dbx.HashExp{
"collectionRef": e.Record.Collection().Id,
"recordRef": e.Record.Id,
})
if err != nil {
return err
}
for _, rel := range rels {
if err := txApp.Delete(rel); err != nil {
return err
}
}
return nil
})
e.App = originalApp
return txErr
},
Priority: 99,
})
}

View file

@ -0,0 +1,332 @@
package core_test
import (
"fmt"
"slices"
"testing"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tests"
"github.com/pocketbase/pocketbase/tools/types"
)
func TestNewAuthOrigin(t *testing.T) {
t.Parallel()
app, _ := tests.NewTestApp()
defer app.Cleanup()
origin := core.NewAuthOrigin(app)
if origin.Collection().Name != core.CollectionNameAuthOrigins {
t.Fatalf("Expected record with %q collection, got %q", core.CollectionNameAuthOrigins, origin.Collection().Name)
}
}
func TestAuthOriginProxyRecord(t *testing.T) {
t.Parallel()
record := core.NewRecord(core.NewBaseCollection("test"))
record.Id = "test_id"
origin := core.AuthOrigin{}
origin.SetProxyRecord(record)
if origin.ProxyRecord() == nil || origin.ProxyRecord().Id != record.Id {
t.Fatalf("Expected proxy record with id %q, got %v", record.Id, origin.ProxyRecord())
}
}
func TestAuthOriginRecordRef(t *testing.T) {
t.Parallel()
app, _ := tests.NewTestApp()
defer app.Cleanup()
origin := core.NewAuthOrigin(app)
testValues := []string{"test_1", "test2", ""}
for i, testValue := range testValues {
t.Run(fmt.Sprintf("%d_%q", i, testValue), func(t *testing.T) {
origin.SetRecordRef(testValue)
if v := origin.RecordRef(); v != testValue {
t.Fatalf("Expected getter %q, got %q", testValue, v)
}
if v := origin.GetString("recordRef"); v != testValue {
t.Fatalf("Expected field value %q, got %q", testValue, v)
}
})
}
}
func TestAuthOriginCollectionRef(t *testing.T) {
t.Parallel()
app, _ := tests.NewTestApp()
defer app.Cleanup()
origin := core.NewAuthOrigin(app)
testValues := []string{"test_1", "test2", ""}
for i, testValue := range testValues {
t.Run(fmt.Sprintf("%d_%q", i, testValue), func(t *testing.T) {
origin.SetCollectionRef(testValue)
if v := origin.CollectionRef(); v != testValue {
t.Fatalf("Expected getter %q, got %q", testValue, v)
}
if v := origin.GetString("collectionRef"); v != testValue {
t.Fatalf("Expected field value %q, got %q", testValue, v)
}
})
}
}
func TestAuthOriginFingerprint(t *testing.T) {
t.Parallel()
app, _ := tests.NewTestApp()
defer app.Cleanup()
origin := core.NewAuthOrigin(app)
testValues := []string{"test_1", "test2", ""}
for i, testValue := range testValues {
t.Run(fmt.Sprintf("%d_%q", i, testValue), func(t *testing.T) {
origin.SetFingerprint(testValue)
if v := origin.Fingerprint(); v != testValue {
t.Fatalf("Expected getter %q, got %q", testValue, v)
}
if v := origin.GetString("fingerprint"); v != testValue {
t.Fatalf("Expected field value %q, got %q", testValue, v)
}
})
}
}
func TestAuthOriginCreated(t *testing.T) {
t.Parallel()
app, _ := tests.NewTestApp()
defer app.Cleanup()
origin := core.NewAuthOrigin(app)
if v := origin.Created().String(); v != "" {
t.Fatalf("Expected empty created, got %q", v)
}
now := types.NowDateTime()
origin.SetRaw("created", now)
if v := origin.Created().String(); v != now.String() {
t.Fatalf("Expected %q created, got %q", now.String(), v)
}
}
func TestAuthOriginUpdated(t *testing.T) {
t.Parallel()
app, _ := tests.NewTestApp()
defer app.Cleanup()
origin := core.NewAuthOrigin(app)
if v := origin.Updated().String(); v != "" {
t.Fatalf("Expected empty updated, got %q", v)
}
now := types.NowDateTime()
origin.SetRaw("updated", now)
if v := origin.Updated().String(); v != now.String() {
t.Fatalf("Expected %q updated, got %q", now.String(), v)
}
}
func TestAuthOriginPreValidate(t *testing.T) {
t.Parallel()
app, _ := tests.NewTestApp()
defer app.Cleanup()
originsCol, err := app.FindCollectionByNameOrId(core.CollectionNameAuthOrigins)
if err != nil {
t.Fatal(err)
}
user, err := app.FindAuthRecordByEmail("users", "test@example.com")
if err != nil {
t.Fatal(err)
}
t.Run("no proxy record", func(t *testing.T) {
origin := &core.AuthOrigin{}
if err := app.Validate(origin); err == nil {
t.Fatal("Expected collection validation error")
}
})
t.Run("non-AuthOrigin collection", func(t *testing.T) {
origin := &core.AuthOrigin{}
origin.SetProxyRecord(core.NewRecord(core.NewBaseCollection("invalid")))
origin.SetRecordRef(user.Id)
origin.SetCollectionRef(user.Collection().Id)
origin.SetFingerprint("abc")
if err := app.Validate(origin); err == nil {
t.Fatal("Expected collection validation error")
}
})
t.Run("AuthOrigin collection", func(t *testing.T) {
origin := &core.AuthOrigin{}
origin.SetProxyRecord(core.NewRecord(originsCol))
origin.SetRecordRef(user.Id)
origin.SetCollectionRef(user.Collection().Id)
origin.SetFingerprint("abc")
if err := app.Validate(origin); err != nil {
t.Fatalf("Expected nil validation error, got %v", err)
}
})
}
func TestAuthOriginValidateHook(t *testing.T) {
t.Parallel()
app, _ := tests.NewTestApp()
defer app.Cleanup()
user, err := app.FindAuthRecordByEmail("users", "test@example.com")
if err != nil {
t.Fatal(err)
}
demo1, err := app.FindRecordById("demo1", "84nmscqy84lsi1t")
if err != nil {
t.Fatal(err)
}
scenarios := []struct {
name string
origin func() *core.AuthOrigin
expectErrors []string
}{
{
"empty",
func() *core.AuthOrigin {
return core.NewAuthOrigin(app)
},
[]string{"collectionRef", "recordRef", "fingerprint"},
},
{
"non-auth collection",
func() *core.AuthOrigin {
origin := core.NewAuthOrigin(app)
origin.SetCollectionRef(demo1.Collection().Id)
origin.SetRecordRef(demo1.Id)
origin.SetFingerprint("abc")
return origin
},
[]string{"collectionRef"},
},
{
"missing record id",
func() *core.AuthOrigin {
origin := core.NewAuthOrigin(app)
origin.SetCollectionRef(user.Collection().Id)
origin.SetRecordRef("missing")
origin.SetFingerprint("abc")
return origin
},
[]string{"recordRef"},
},
{
"valid ref",
func() *core.AuthOrigin {
origin := core.NewAuthOrigin(app)
origin.SetCollectionRef(user.Collection().Id)
origin.SetRecordRef(user.Id)
origin.SetFingerprint("abc")
return origin
},
[]string{},
},
}
for _, s := range scenarios {
t.Run(s.name, func(t *testing.T) {
errs := app.Validate(s.origin())
tests.TestValidationErrors(t, errs, s.expectErrors)
})
}
}
func TestAuthOriginPasswordChangeDeletion(t *testing.T) {
t.Parallel()
testApp, _ := tests.NewTestApp()
defer testApp.Cleanup()
// no auth origin associated with it
user1, err := testApp.FindAuthRecordByEmail("users", "test@example.com")
if err != nil {
t.Fatal(err)
}
superuser2, err := testApp.FindAuthRecordByEmail(core.CollectionNameSuperusers, "test2@example.com")
if err != nil {
t.Fatal(err)
}
client1, err := testApp.FindAuthRecordByEmail("clients", "test@example.com")
if err != nil {
t.Fatal(err)
}
scenarios := []struct {
record *core.Record
deletedIds []string
}{
{user1, nil},
{superuser2, []string{"5798yh833k6w6w0", "ic55o70g4f8pcl4", "dmy260k6ksjr4ib"}},
{client1, []string{"9r2j0m74260ur8i"}},
}
for i, s := range scenarios {
t.Run(fmt.Sprintf("%d_%s_%s", i, s.record.Collection().Name, s.record.Id), func(t *testing.T) {
app, _ := tests.NewTestApp()
defer app.Cleanup()
deletedIds := []string{}
app.OnRecordDelete().BindFunc(func(e *core.RecordEvent) error {
deletedIds = append(deletedIds, e.Record.Id)
return e.Next()
})
s.record.SetPassword("new_password")
err := app.Save(s.record)
if err != nil {
t.Fatal(err)
}
if len(deletedIds) != len(s.deletedIds) {
t.Fatalf("Expected deleted ids\n%v\ngot\n%v", s.deletedIds, deletedIds)
}
for _, id := range s.deletedIds {
if !slices.Contains(deletedIds, id) {
t.Errorf("Expected to find deleted id %q in %v", id, deletedIds)
}
}
})
}
}

101
core/auth_origin_query.go Normal file
View file

@ -0,0 +1,101 @@
package core
import (
"errors"
"github.com/pocketbase/dbx"
)
// FindAllAuthOriginsByRecord returns all AuthOrigin models linked to the provided auth record (in DESC order).
func (app *BaseApp) FindAllAuthOriginsByRecord(authRecord *Record) ([]*AuthOrigin, error) {
result := []*AuthOrigin{}
err := app.RecordQuery(CollectionNameAuthOrigins).
AndWhere(dbx.HashExp{
"collectionRef": authRecord.Collection().Id,
"recordRef": authRecord.Id,
}).
OrderBy("created DESC").
All(&result)
if err != nil {
return nil, err
}
return result, nil
}
// FindAllAuthOriginsByCollection returns all AuthOrigin models linked to the provided collection (in DESC order).
func (app *BaseApp) FindAllAuthOriginsByCollection(collection *Collection) ([]*AuthOrigin, error) {
result := []*AuthOrigin{}
err := app.RecordQuery(CollectionNameAuthOrigins).
AndWhere(dbx.HashExp{"collectionRef": collection.Id}).
OrderBy("created DESC").
All(&result)
if err != nil {
return nil, err
}
return result, nil
}
// FindAuthOriginById returns a single AuthOrigin model by its id.
func (app *BaseApp) FindAuthOriginById(id string) (*AuthOrigin, error) {
result := &AuthOrigin{}
err := app.RecordQuery(CollectionNameAuthOrigins).
AndWhere(dbx.HashExp{"id": id}).
Limit(1).
One(result)
if err != nil {
return nil, err
}
return result, nil
}
// FindAuthOriginByRecordAndFingerprint returns a single AuthOrigin model
// by its authRecord relation and fingerprint.
func (app *BaseApp) FindAuthOriginByRecordAndFingerprint(authRecord *Record, fingerprint string) (*AuthOrigin, error) {
result := &AuthOrigin{}
err := app.RecordQuery(CollectionNameAuthOrigins).
AndWhere(dbx.HashExp{
"collectionRef": authRecord.Collection().Id,
"recordRef": authRecord.Id,
"fingerprint": fingerprint,
}).
Limit(1).
One(result)
if err != nil {
return nil, err
}
return result, nil
}
// DeleteAllAuthOriginsByRecord deletes all AuthOrigin models associated with the provided record.
//
// Returns a combined error with the failed deletes.
func (app *BaseApp) DeleteAllAuthOriginsByRecord(authRecord *Record) error {
models, err := app.FindAllAuthOriginsByRecord(authRecord)
if err != nil {
return err
}
var errs []error
for _, m := range models {
if err := app.Delete(m); err != nil {
errs = append(errs, err)
}
}
if len(errs) > 0 {
return errors.Join(errs...)
}
return nil
}

View file

@ -0,0 +1,268 @@
package core_test
import (
"fmt"
"slices"
"testing"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tests"
)
func TestFindAllAuthOriginsByRecord(t *testing.T) {
t.Parallel()
app, _ := tests.NewTestApp()
defer app.Cleanup()
demo1, err := app.FindRecordById("demo1", "84nmscqy84lsi1t")
if err != nil {
t.Fatal(err)
}
superuser2, err := app.FindAuthRecordByEmail(core.CollectionNameSuperusers, "test2@example.com")
if err != nil {
t.Fatal(err)
}
superuser4, err := app.FindAuthRecordByEmail(core.CollectionNameSuperusers, "test4@example.com")
if err != nil {
t.Fatal(err)
}
client1, err := app.FindAuthRecordByEmail("clients", "test@example.com")
if err != nil {
t.Fatal(err)
}
scenarios := []struct {
record *core.Record
expected []string
}{
{demo1, nil},
{superuser2, []string{"5798yh833k6w6w0", "ic55o70g4f8pcl4", "dmy260k6ksjr4ib"}},
{superuser4, nil},
{client1, []string{"9r2j0m74260ur8i"}},
}
for _, s := range scenarios {
t.Run(s.record.Collection().Name+"_"+s.record.Id, func(t *testing.T) {
result, err := app.FindAllAuthOriginsByRecord(s.record)
if err != nil {
t.Fatal(err)
}
if len(result) != len(s.expected) {
t.Fatalf("Expected total origins %d, got %d", len(s.expected), len(result))
}
for i, id := range s.expected {
if result[i].Id != id {
t.Errorf("[%d] Expected id %q, got %q", i, id, result[i].Id)
}
}
})
}
}
func TestFindAllAuthOriginsByCollection(t *testing.T) {
t.Parallel()
app, _ := tests.NewTestApp()
defer app.Cleanup()
demo1, err := app.FindCollectionByNameOrId("demo1")
if err != nil {
t.Fatal(err)
}
superusers, err := app.FindCollectionByNameOrId(core.CollectionNameSuperusers)
if err != nil {
t.Fatal(err)
}
clients, err := app.FindCollectionByNameOrId("clients")
if err != nil {
t.Fatal(err)
}
scenarios := []struct {
collection *core.Collection
expected []string
}{
{demo1, nil},
{superusers, []string{"5798yh833k6w6w0", "ic55o70g4f8pcl4", "dmy260k6ksjr4ib", "5f29jy38bf5zm3f"}},
{clients, []string{"9r2j0m74260ur8i"}},
}
for _, s := range scenarios {
t.Run(s.collection.Name, func(t *testing.T) {
result, err := app.FindAllAuthOriginsByCollection(s.collection)
if err != nil {
t.Fatal(err)
}
if len(result) != len(s.expected) {
t.Fatalf("Expected total origins %d, got %d", len(s.expected), len(result))
}
for i, id := range s.expected {
if result[i].Id != id {
t.Errorf("[%d] Expected id %q, got %q", i, id, result[i].Id)
}
}
})
}
}
func TestFindAuthOriginById(t *testing.T) {
t.Parallel()
app, _ := tests.NewTestApp()
defer app.Cleanup()
scenarios := []struct {
id string
expectError bool
}{
{"", true},
{"84nmscqy84lsi1t", true}, // non-origin id
{"9r2j0m74260ur8i", false},
}
for _, s := range scenarios {
t.Run(s.id, func(t *testing.T) {
result, err := app.FindAuthOriginById(s.id)
hasErr := err != nil
if hasErr != s.expectError {
t.Fatalf("Expected hasErr %v, got %v (%v)", s.expectError, hasErr, err)
}
if hasErr {
return
}
if result.Id != s.id {
t.Fatalf("Expected record with id %q, got %q", s.id, result.Id)
}
})
}
}
func TestFindAuthOriginByRecordAndFingerprint(t *testing.T) {
t.Parallel()
app, _ := tests.NewTestApp()
defer app.Cleanup()
demo1, err := app.FindRecordById("demo1", "84nmscqy84lsi1t")
if err != nil {
t.Fatal(err)
}
superuser2, err := app.FindAuthRecordByEmail(core.CollectionNameSuperusers, "test2@example.com")
if err != nil {
t.Fatal(err)
}
scenarios := []struct {
record *core.Record
fingerprint string
expectError bool
}{
{demo1, "6afbfe481c31c08c55a746cccb88ece0", true},
{superuser2, "", true},
{superuser2, "abc", true},
{superuser2, "22bbbcbed36e25321f384ccf99f60057", false}, // fingerprint from different origin
{superuser2, "6afbfe481c31c08c55a746cccb88ece0", false},
}
for i, s := range scenarios {
t.Run(fmt.Sprintf("%d_%s_%s", i, s.record.Id, s.fingerprint), func(t *testing.T) {
result, err := app.FindAuthOriginByRecordAndFingerprint(s.record, s.fingerprint)
hasErr := err != nil
if hasErr != s.expectError {
t.Fatalf("Expected hasErr %v, got %v (%v)", s.expectError, hasErr, err)
}
if hasErr {
return
}
if result.Fingerprint() != s.fingerprint {
t.Fatalf("Expected origin with fingerprint %q, got %q", s.fingerprint, result.Fingerprint())
}
if result.RecordRef() != s.record.Id || result.CollectionRef() != s.record.Collection().Id {
t.Fatalf("Expected record %q (%q), got %q (%q)", s.record.Id, s.record.Collection().Id, result.RecordRef(), result.CollectionRef())
}
})
}
}
func TestDeleteAllAuthOriginsByRecord(t *testing.T) {
t.Parallel()
testApp, _ := tests.NewTestApp()
defer testApp.Cleanup()
demo1, err := testApp.FindRecordById("demo1", "84nmscqy84lsi1t")
if err != nil {
t.Fatal(err)
}
superuser2, err := testApp.FindAuthRecordByEmail(core.CollectionNameSuperusers, "test2@example.com")
if err != nil {
t.Fatal(err)
}
superuser4, err := testApp.FindAuthRecordByEmail(core.CollectionNameSuperusers, "test4@example.com")
if err != nil {
t.Fatal(err)
}
client1, err := testApp.FindAuthRecordByEmail("clients", "test@example.com")
if err != nil {
t.Fatal(err)
}
scenarios := []struct {
record *core.Record
deletedIds []string
}{
{demo1, nil}, // non-auth record
{superuser2, []string{"5798yh833k6w6w0", "ic55o70g4f8pcl4", "dmy260k6ksjr4ib"}},
{superuser4, nil},
{client1, []string{"9r2j0m74260ur8i"}},
}
for i, s := range scenarios {
t.Run(fmt.Sprintf("%d_%s_%s", i, s.record.Collection().Name, s.record.Id), func(t *testing.T) {
app, _ := tests.NewTestApp()
defer app.Cleanup()
deletedIds := []string{}
app.OnRecordDelete().BindFunc(func(e *core.RecordEvent) error {
deletedIds = append(deletedIds, e.Record.Id)
return e.Next()
})
err := app.DeleteAllAuthOriginsByRecord(s.record)
if err != nil {
t.Fatal(err)
}
if len(deletedIds) != len(s.deletedIds) {
t.Fatalf("Expected deleted ids\n%v\ngot\n%v", s.deletedIds, deletedIds)
}
for _, id := range s.deletedIds {
if !slices.Contains(deletedIds, id) {
t.Errorf("Expected to find deleted id %q in %v", id, deletedIds)
}
}
})
}
}

1535
core/base.go Normal file

File diff suppressed because it is too large Load diff

389
core/base_backup.go Normal file
View file

@ -0,0 +1,389 @@
package core
import (
"context"
"errors"
"fmt"
"io"
"log/slog"
"os"
"path/filepath"
"runtime"
"sort"
"time"
"github.com/pocketbase/pocketbase/tools/archive"
"github.com/pocketbase/pocketbase/tools/filesystem"
"github.com/pocketbase/pocketbase/tools/inflector"
"github.com/pocketbase/pocketbase/tools/osutils"
"github.com/pocketbase/pocketbase/tools/security"
)
const (
StoreKeyActiveBackup = "@activeBackup"
)
// CreateBackup creates a new backup of the current app pb_data directory.
//
// If name is empty, it will be autogenerated.
// If backup with the same name exists, the new backup file will replace it.
//
// The backup is executed within a transaction, meaning that new writes
// will be temporary "blocked" until the backup file is generated.
//
// To safely perform the backup, it is recommended to have free disk space
// for at least 2x the size of the pb_data directory.
//
// By default backups are stored in pb_data/backups
// (the backups directory itself is excluded from the generated backup).
//
// When using S3 storage for the uploaded collection files, you have to
// take care manually to backup those since they are not part of the pb_data.
//
// Backups can be stored on S3 if it is configured in app.Settings().Backups.
func (app *BaseApp) CreateBackup(ctx context.Context, name string) error {
if app.Store().Has(StoreKeyActiveBackup) {
return errors.New("try again later - another backup/restore operation has already been started")
}
app.Store().Set(StoreKeyActiveBackup, name)
defer app.Store().Remove(StoreKeyActiveBackup)
event := new(BackupEvent)
event.App = app
event.Context = ctx
event.Name = name
// default root dir entries to exclude from the backup generation
event.Exclude = []string{LocalBackupsDirName, LocalTempDirName, LocalAutocertCacheDirName}
return app.OnBackupCreate().Trigger(event, func(e *BackupEvent) error {
// generate a default name if missing
if e.Name == "" {
e.Name = generateBackupName(e.App, "pb_backup_")
}
// make sure that the special temp directory exists
// note: it needs to be inside the current pb_data to avoid "cross-device link" errors
localTempDir := filepath.Join(e.App.DataDir(), LocalTempDirName)
if err := os.MkdirAll(localTempDir, os.ModePerm); err != nil {
return fmt.Errorf("failed to create a temp dir: %w", err)
}
// archive pb_data in a temp directory, exluding the "backups" and the temp dirs
//
// Run in transaction to temporary block other writes (transactions uses the NonconcurrentDB connection).
// ---
tempPath := filepath.Join(localTempDir, "pb_backup_"+security.PseudorandomString(6))
createErr := e.App.RunInTransaction(func(txApp App) error {
return txApp.AuxRunInTransaction(func(txApp App) error {
// run manual checkpoint and truncate the WAL files
// (errors are ignored because it is not that important and the PRAGMA may not be supported by the used driver)
txApp.DB().NewQuery("PRAGMA wal_checkpoint(TRUNCATE)").Execute()
txApp.AuxDB().NewQuery("PRAGMA wal_checkpoint(TRUNCATE)").Execute()
return archive.Create(txApp.DataDir(), tempPath, e.Exclude...)
})
})
if createErr != nil {
return createErr
}
defer os.Remove(tempPath)
// persist the backup in the backups filesystem
// ---
fsys, err := e.App.NewBackupsFilesystem()
if err != nil {
return err
}
defer fsys.Close()
fsys.SetContext(e.Context)
file, err := filesystem.NewFileFromPath(tempPath)
if err != nil {
return err
}
file.OriginalName = e.Name
file.Name = file.OriginalName
if err := fsys.UploadFile(file, file.Name); err != nil {
return err
}
return nil
})
}
// RestoreBackup restores the backup with the specified name and restarts
// the current running application process.
//
// NB! This feature is experimental and currently is expected to work only on UNIX based systems.
//
// To safely perform the restore it is recommended to have free disk space
// for at least 2x the size of the restored pb_data backup.
//
// The performed steps are:
//
// 1. Download the backup with the specified name in a temp location
// (this is in case of S3; otherwise it creates a temp copy of the zip)
//
// 2. Extract the backup in a temp directory inside the app "pb_data"
// (eg. "pb_data/.pb_temp_to_delete/pb_restore").
//
// 3. Move the current app "pb_data" content (excluding the local backups and the special temp dir)
// under another temp sub dir that will be deleted on the next app start up
// (eg. "pb_data/.pb_temp_to_delete/old_pb_data").
// This is because on some environments it may not be allowed
// to delete the currently open "pb_data" files.
//
// 4. Move the extracted dir content to the app "pb_data".
//
// 5. Restart the app (on successful app bootstap it will also remove the old pb_data).
//
// If a failure occure during the restore process the dir changes are reverted.
// If for whatever reason the revert is not possible, it panics.
//
// Note that if your pb_data has custom network mounts as subdirectories, then
// it is possible the restore to fail during the `os.Rename` operations
// (see https://github.com/pocketbase/pocketbase/issues/4647).
func (app *BaseApp) RestoreBackup(ctx context.Context, name string) error {
if app.Store().Has(StoreKeyActiveBackup) {
return errors.New("try again later - another backup/restore operation has already been started")
}
app.Store().Set(StoreKeyActiveBackup, name)
defer app.Store().Remove(StoreKeyActiveBackup)
event := new(BackupEvent)
event.App = app
event.Context = ctx
event.Name = name
// default root dir entries to exclude from the backup restore
event.Exclude = []string{LocalBackupsDirName, LocalTempDirName, LocalAutocertCacheDirName}
return app.OnBackupRestore().Trigger(event, func(e *BackupEvent) error {
if runtime.GOOS == "windows" {
return errors.New("restore is not supported on Windows")
}
// make sure that the special temp directory exists
// note: it needs to be inside the current pb_data to avoid "cross-device link" errors
localTempDir := filepath.Join(e.App.DataDir(), LocalTempDirName)
if err := os.MkdirAll(localTempDir, os.ModePerm); err != nil {
return fmt.Errorf("failed to create a temp dir: %w", err)
}
fsys, err := e.App.NewBackupsFilesystem()
if err != nil {
return err
}
defer fsys.Close()
fsys.SetContext(e.Context)
if ok, _ := fsys.Exists(name); !ok {
return fmt.Errorf("missing or invalid backup file %q to restore", name)
}
extractedDataDir := filepath.Join(localTempDir, "pb_restore_"+security.PseudorandomString(8))
defer os.RemoveAll(extractedDataDir)
// extract the zip
if e.App.Settings().Backups.S3.Enabled {
br, err := fsys.GetReader(name)
if err != nil {
return err
}
defer br.Close()
// create a temp zip file from the blob.Reader and try to extract it
tempZip, err := os.CreateTemp(localTempDir, "pb_restore_zip")
if err != nil {
return err
}
defer os.Remove(tempZip.Name())
defer tempZip.Close() // note: this technically shouldn't be necessary but it is here to workaround platforms discrepancies
_, err = io.Copy(tempZip, br)
if err != nil {
return err
}
err = archive.Extract(tempZip.Name(), extractedDataDir)
if err != nil {
return err
}
// remove the temp zip file since we no longer need it
// (this is in case the app restarts and the defer calls are not called)
_ = tempZip.Close()
err = os.Remove(tempZip.Name())
if err != nil {
e.App.Logger().Warn(
"[RestoreBackup] Failed to remove the temp zip backup file",
slog.String("file", tempZip.Name()),
slog.String("error", err.Error()),
)
}
} else {
// manually construct the local path to avoid creating a copy of the zip file
// since the blob reader currently doesn't implement ReaderAt
zipPath := filepath.Join(app.DataDir(), LocalBackupsDirName, filepath.Base(name))
err = archive.Extract(zipPath, extractedDataDir)
if err != nil {
return err
}
}
// ensure that at least a database file exists
extractedDB := filepath.Join(extractedDataDir, "data.db")
if _, err := os.Stat(extractedDB); err != nil {
return fmt.Errorf("data.db file is missing or invalid: %w", err)
}
// move the current pb_data content to a special temp location
// that will hold the old data between dirs replace
// (the temp dir will be automatically removed on the next app start)
oldTempDataDir := filepath.Join(localTempDir, "old_pb_data_"+security.PseudorandomString(8))
if err := osutils.MoveDirContent(e.App.DataDir(), oldTempDataDir, e.Exclude...); err != nil {
return fmt.Errorf("failed to move the current pb_data content to a temp location: %w", err)
}
// move the extracted archive content to the app's pb_data
if err := osutils.MoveDirContent(extractedDataDir, e.App.DataDir(), e.Exclude...); err != nil {
return fmt.Errorf("failed to move the extracted archive content to pb_data: %w", err)
}
revertDataDirChanges := func() error {
if err := osutils.MoveDirContent(e.App.DataDir(), extractedDataDir, e.Exclude...); err != nil {
return fmt.Errorf("failed to revert the extracted dir change: %w", err)
}
if err := osutils.MoveDirContent(oldTempDataDir, e.App.DataDir(), e.Exclude...); err != nil {
return fmt.Errorf("failed to revert old pb_data dir change: %w", err)
}
return nil
}
// restart the app
if err := e.App.Restart(); err != nil {
if revertErr := revertDataDirChanges(); revertErr != nil {
panic(revertErr)
}
return fmt.Errorf("failed to restart the app process: %w", err)
}
return nil
})
}
// registerAutobackupHooks registers the autobackup app serve hooks.
func (app *BaseApp) registerAutobackupHooks() {
const jobId = "__pbAutoBackup__"
loadJob := func() {
rawSchedule := app.Settings().Backups.Cron
if rawSchedule == "" {
app.Cron().Remove(jobId)
return
}
app.Cron().Add(jobId, rawSchedule, func() {
const autoPrefix = "@auto_pb_backup_"
name := generateBackupName(app, autoPrefix)
if err := app.CreateBackup(context.Background(), name); err != nil {
app.Logger().Error(
"[Backup cron] Failed to create backup",
slog.String("name", name),
slog.String("error", err.Error()),
)
}
maxKeep := app.Settings().Backups.CronMaxKeep
if maxKeep == 0 {
return // no explicit limit
}
fsys, err := app.NewBackupsFilesystem()
if err != nil {
app.Logger().Error(
"[Backup cron] Failed to initialize the backup filesystem",
slog.String("error", err.Error()),
)
return
}
defer fsys.Close()
files, err := fsys.List(autoPrefix)
if err != nil {
app.Logger().Error(
"[Backup cron] Failed to list autogenerated backups",
slog.String("error", err.Error()),
)
return
}
if maxKeep >= len(files) {
return // nothing to remove
}
// sort desc
sort.Slice(files, func(i, j int) bool {
return files[i].ModTime.After(files[j].ModTime)
})
// keep only the most recent n auto backup files
toRemove := files[maxKeep:]
for _, f := range toRemove {
if err := fsys.Delete(f.Key); err != nil {
app.Logger().Error(
"[Backup cron] Failed to remove old autogenerated backup",
slog.String("key", f.Key),
slog.String("error", err.Error()),
)
}
}
})
}
app.OnBootstrap().BindFunc(func(e *BootstrapEvent) error {
if err := e.Next(); err != nil {
return err
}
loadJob()
return nil
})
app.OnSettingsReload().BindFunc(func(e *SettingsReloadEvent) error {
if err := e.Next(); err != nil {
return err
}
loadJob()
return nil
})
}
func generateBackupName(app App, prefix string) string {
appName := inflector.Snakecase(app.Settings().Meta.AppName)
if len(appName) > 50 {
appName = appName[:50]
}
return fmt.Sprintf(
"%s%s_%s.zip",
prefix,
appName,
time.Now().UTC().Format("20060102150405"),
)
}

164
core/base_backup_test.go Normal file
View file

@ -0,0 +1,164 @@
package core_test
import (
"context"
"fmt"
"io/fs"
"os"
"path/filepath"
"strings"
"testing"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tests"
"github.com/pocketbase/pocketbase/tools/archive"
"github.com/pocketbase/pocketbase/tools/list"
)
func TestCreateBackup(t *testing.T) {
app, _ := tests.NewTestApp()
defer app.Cleanup()
// set some long app name with spaces and special characters
app.Settings().Meta.AppName = "test @! " + strings.Repeat("a", 100)
expectedAppNamePrefix := "test_" + strings.Repeat("a", 45)
// test pending error
app.Store().Set(core.StoreKeyActiveBackup, "")
if err := app.CreateBackup(context.Background(), "test.zip"); err == nil {
t.Fatal("Expected pending error, got nil")
}
app.Store().Remove(core.StoreKeyActiveBackup)
// create with auto generated name
if err := app.CreateBackup(context.Background(), ""); err != nil {
t.Fatal("Failed to create a backup with autogenerated name")
}
// create with custom name
if err := app.CreateBackup(context.Background(), "custom"); err != nil {
t.Fatal("Failed to create a backup with custom name")
}
// create new with the same name (aka. replace)
if err := app.CreateBackup(context.Background(), "custom"); err != nil {
t.Fatal("Failed to create and replace a backup with the same name")
}
backupsDir := filepath.Join(app.DataDir(), core.LocalBackupsDirName)
entries, err := os.ReadDir(backupsDir)
if err != nil {
t.Fatal(err)
}
expectedFiles := []string{
`^pb_backup_` + expectedAppNamePrefix + `_\w+\.zip$`,
`^pb_backup_` + expectedAppNamePrefix + `_\w+\.zip.attrs$`,
"custom",
"custom.attrs",
}
if len(entries) != len(expectedFiles) {
names := getEntryNames(entries)
t.Fatalf("Expected %d backup files, got %d: \n%v", len(expectedFiles), len(entries), names)
}
for i, entry := range entries {
if !list.ExistInSliceWithRegex(entry.Name(), expectedFiles) {
t.Fatalf("[%d] Missing backup file %q", i, entry.Name())
}
if strings.HasSuffix(entry.Name(), ".attrs") {
continue
}
path := filepath.Join(backupsDir, entry.Name())
if err := verifyBackupContent(app, path); err != nil {
t.Fatalf("[%d] Failed to verify backup content: %v", i, err)
}
}
}
func TestRestoreBackup(t *testing.T) {
app, _ := tests.NewTestApp()
defer app.Cleanup()
// create a initial test backup to ensure that there are at least 1
// backup file and that the generated zip doesn't contain the backups dir
if err := app.CreateBackup(context.Background(), "initial"); err != nil {
t.Fatal("Failed to create test initial backup")
}
// create test backup
if err := app.CreateBackup(context.Background(), "test"); err != nil {
t.Fatal("Failed to create test backup")
}
// test pending error
app.Store().Set(core.StoreKeyActiveBackup, "")
if err := app.RestoreBackup(context.Background(), "test"); err == nil {
t.Fatal("Expected pending error, got nil")
}
app.Store().Remove(core.StoreKeyActiveBackup)
// missing backup
if err := app.RestoreBackup(context.Background(), "missing"); err == nil {
t.Fatal("Expected missing error, got nil")
}
}
// -------------------------------------------------------------------
func verifyBackupContent(app core.App, path string) error {
dir, err := os.MkdirTemp("", "backup_test")
if err != nil {
return err
}
defer os.RemoveAll(dir)
if err := archive.Extract(path, dir); err != nil {
return err
}
expectedRootEntries := []string{
"storage",
"data.db",
"data.db-shm",
"data.db-wal",
"auxiliary.db",
"auxiliary.db-shm",
"auxiliary.db-wal",
".gitignore",
}
entries, err := os.ReadDir(dir)
if err != nil {
return err
}
if len(entries) != len(expectedRootEntries) {
names := getEntryNames(entries)
return fmt.Errorf("Expected %d backup files, got %d: \n%v", len(expectedRootEntries), len(entries), names)
}
for _, entry := range entries {
if !list.ExistInSliceWithRegex(entry.Name(), expectedRootEntries) {
return fmt.Errorf("Didn't expect %q entry", entry.Name())
}
}
return nil
}
func getEntryNames(entries []fs.DirEntry) []string {
names := make([]string, len(entries))
for i, entry := range entries {
names[i] = entry.Name()
}
return names
}

554
core/base_test.go Normal file
View file

@ -0,0 +1,554 @@
package core_test
import (
"context"
"database/sql"
"log/slog"
"os"
"slices"
"testing"
"time"
_ "unsafe"
"github.com/pocketbase/dbx"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tests"
"github.com/pocketbase/pocketbase/tools/logger"
"github.com/pocketbase/pocketbase/tools/mailer"
)
func TestNewBaseApp(t *testing.T) {
const testDataDir = "./pb_base_app_test_data_dir/"
defer os.RemoveAll(testDataDir)
app := core.NewBaseApp(core.BaseAppConfig{
DataDir: testDataDir,
EncryptionEnv: "test_env",
IsDev: true,
})
if app.DataDir() != testDataDir {
t.Fatalf("expected DataDir %q, got %q", testDataDir, app.DataDir())
}
if app.EncryptionEnv() != "test_env" {
t.Fatalf("expected EncryptionEnv test_env, got %q", app.EncryptionEnv())
}
if !app.IsDev() {
t.Fatalf("expected IsDev true, got %v", app.IsDev())
}
if app.Store() == nil {
t.Fatal("expected Store to be set, got nil")
}
if app.Settings() == nil {
t.Fatal("expected Settings to be set, got nil")
}
if app.SubscriptionsBroker() == nil {
t.Fatal("expected SubscriptionsBroker to be set, got nil")
}
if app.Cron() == nil {
t.Fatal("expected Cron to be set, got nil")
}
}
func TestBaseAppBootstrap(t *testing.T) {
const testDataDir = "./pb_base_app_test_data_dir/"
defer os.RemoveAll(testDataDir)
app := core.NewBaseApp(core.BaseAppConfig{
DataDir: testDataDir,
})
defer app.ResetBootstrapState()
if app.IsBootstrapped() {
t.Fatal("Didn't expect the application to be bootstrapped.")
}
if err := app.Bootstrap(); err != nil {
t.Fatal(err)
}
if !app.IsBootstrapped() {
t.Fatal("Expected the application to be bootstrapped.")
}
if stat, err := os.Stat(testDataDir); err != nil || !stat.IsDir() {
t.Fatal("Expected test data directory to be created.")
}
type nilCheck struct {
name string
value any
expectNil bool
}
runNilChecks := func(checks []nilCheck) {
for _, check := range checks {
t.Run(check.name, func(t *testing.T) {
isNil := check.value == nil
if isNil != check.expectNil {
t.Fatalf("Expected isNil %v, got %v", check.expectNil, isNil)
}
})
}
}
nilChecksBeforeReset := []nilCheck{
{"[before] db", app.DB(), false},
{"[before] concurrentDB", app.ConcurrentDB(), false},
{"[before] nonconcurrentDB", app.NonconcurrentDB(), false},
{"[before] auxDB", app.AuxDB(), false},
{"[before] auxConcurrentDB", app.AuxConcurrentDB(), false},
{"[before] auxNonconcurrentDB", app.AuxNonconcurrentDB(), false},
{"[before] settings", app.Settings(), false},
{"[before] logger", app.Logger(), false},
{"[before] cached collections", app.Store().Get(core.StoreKeyCachedCollections), false},
}
runNilChecks(nilChecksBeforeReset)
// reset
if err := app.ResetBootstrapState(); err != nil {
t.Fatal(err)
}
nilChecksAfterReset := []nilCheck{
{"[after] db", app.DB(), true},
{"[after] concurrentDB", app.ConcurrentDB(), true},
{"[after] nonconcurrentDB", app.NonconcurrentDB(), true},
{"[after] auxDB", app.AuxDB(), true},
{"[after] auxConcurrentDB", app.AuxConcurrentDB(), true},
{"[after] auxNonconcurrentDB", app.AuxNonconcurrentDB(), true},
{"[after] settings", app.Settings(), false},
{"[after] logger", app.Logger(), false},
{"[after] cached collections", app.Store().Get(core.StoreKeyCachedCollections), false},
}
runNilChecks(nilChecksAfterReset)
}
func TestNewBaseAppTx(t *testing.T) {
const testDataDir = "./pb_base_app_test_data_dir/"
defer os.RemoveAll(testDataDir)
app := core.NewBaseApp(core.BaseAppConfig{
DataDir: testDataDir,
})
defer app.ResetBootstrapState()
if err := app.Bootstrap(); err != nil {
t.Fatal(err)
}
mustNotHaveTx := func(app core.App) {
if app.IsTransactional() {
t.Fatalf("Didn't expect the app to be transactional")
}
if app.TxInfo() != nil {
t.Fatalf("Didn't expect the app.txInfo to be loaded")
}
}
mustHaveTx := func(app core.App) {
if !app.IsTransactional() {
t.Fatalf("Expected the app to be transactional")
}
if app.TxInfo() == nil {
t.Fatalf("Expected the app.txInfo to be loaded")
}
}
mustNotHaveTx(app)
app.RunInTransaction(func(txApp core.App) error {
mustHaveTx(txApp)
return nil
})
mustNotHaveTx(app)
}
func TestBaseAppNewMailClient(t *testing.T) {
const testDataDir = "./pb_base_app_test_data_dir/"
defer os.RemoveAll(testDataDir)
app := core.NewBaseApp(core.BaseAppConfig{
DataDir: testDataDir,
EncryptionEnv: "pb_test_env",
})
defer app.ResetBootstrapState()
client1 := app.NewMailClient()
m1, ok := client1.(*mailer.Sendmail)
if !ok {
t.Fatalf("Expected mailer.Sendmail instance, got %v", m1)
}
if m1.OnSend() == nil || m1.OnSend().Length() == 0 {
t.Fatal("Expected OnSend hook to be registered")
}
app.Settings().SMTP.Enabled = true
client2 := app.NewMailClient()
m2, ok := client2.(*mailer.SMTPClient)
if !ok {
t.Fatalf("Expected mailer.SMTPClient instance, got %v", m2)
}
if m2.OnSend() == nil || m2.OnSend().Length() == 0 {
t.Fatal("Expected OnSend hook to be registered")
}
}
func TestBaseAppNewFilesystem(t *testing.T) {
const testDataDir = "./pb_base_app_test_data_dir/"
defer os.RemoveAll(testDataDir)
app := core.NewBaseApp(core.BaseAppConfig{
DataDir: testDataDir,
})
defer app.ResetBootstrapState()
// local
local, localErr := app.NewFilesystem()
if localErr != nil {
t.Fatal(localErr)
}
if local == nil {
t.Fatal("Expected local filesystem instance, got nil")
}
// misconfigured s3
app.Settings().S3.Enabled = true
s3, s3Err := app.NewFilesystem()
if s3Err == nil {
t.Fatal("Expected S3 error, got nil")
}
if s3 != nil {
t.Fatalf("Expected nil s3 filesystem, got %v", s3)
}
}
func TestBaseAppNewBackupsFilesystem(t *testing.T) {
const testDataDir = "./pb_base_app_test_data_dir/"
defer os.RemoveAll(testDataDir)
app := core.NewBaseApp(core.BaseAppConfig{
DataDir: testDataDir,
})
defer app.ResetBootstrapState()
// local
local, localErr := app.NewBackupsFilesystem()
if localErr != nil {
t.Fatal(localErr)
}
if local == nil {
t.Fatal("Expected local backups filesystem instance, got nil")
}
// misconfigured s3
app.Settings().Backups.S3.Enabled = true
s3, s3Err := app.NewBackupsFilesystem()
if s3Err == nil {
t.Fatal("Expected S3 error, got nil")
}
if s3 != nil {
t.Fatalf("Expected nil s3 backups filesystem, got %v", s3)
}
}
func TestBaseAppLoggerWrites(t *testing.T) {
t.Parallel()
app, _ := tests.NewTestApp()
defer app.Cleanup()
// reset
if err := app.DeleteOldLogs(time.Now()); err != nil {
t.Fatal(err)
}
const logsThreshold = 200
totalLogs := func(app core.App, t *testing.T) int {
var total int
err := app.LogQuery().Select("count(*)").Row(&total)
if err != nil {
t.Fatalf("Failed to fetch total logs: %v", err)
}
return total
}
t.Run("disabled logs retention", func(t *testing.T) {
app.Settings().Logs.MaxDays = 0
for i := 0; i < logsThreshold+1; i++ {
app.Logger().Error("test")
}
if total := totalLogs(app, t); total != 0 {
t.Fatalf("Expected no logs, got %d", total)
}
})
t.Run("test batch logs writes", func(t *testing.T) {
app.Settings().Logs.MaxDays = 1
for i := 0; i < logsThreshold-1; i++ {
app.Logger().Error("test")
}
if total := totalLogs(app, t); total != 0 {
t.Fatalf("Expected no logs, got %d", total)
}
// should trigger batch write
app.Logger().Error("test")
// should be added for the next batch write
app.Logger().Error("test")
if total := totalLogs(app, t); total != logsThreshold {
t.Fatalf("Expected %d logs, got %d", logsThreshold, total)
}
// wait for ~3 secs to check the timer trigger
time.Sleep(3200 * time.Millisecond)
if total := totalLogs(app, t); total != logsThreshold+1 {
t.Fatalf("Expected %d logs, got %d", logsThreshold+1, total)
}
})
}
func TestBaseAppRefreshSettingsLoggerMinLevelEnabled(t *testing.T) {
scenarios := []struct {
name string
isDev bool
level int
// level->enabled map
expectations map[int]bool
}{
{
"dev mode",
true,
4,
map[int]bool{
3: true,
4: true,
5: true,
},
},
{
"nondev mode",
false,
4,
map[int]bool{
3: false,
4: true,
5: true,
},
},
}
for _, s := range scenarios {
t.Run(s.name, func(t *testing.T) {
const testDataDir = "./pb_base_app_test_data_dir/"
defer os.RemoveAll(testDataDir)
app := core.NewBaseApp(core.BaseAppConfig{
DataDir: testDataDir,
IsDev: s.isDev,
})
defer app.ResetBootstrapState()
if err := app.Bootstrap(); err != nil {
t.Fatal(err)
}
// silence query logs
app.ConcurrentDB().(*dbx.DB).ExecLogFunc = func(ctx context.Context, t time.Duration, sql string, result sql.Result, err error) {}
app.ConcurrentDB().(*dbx.DB).QueryLogFunc = func(ctx context.Context, t time.Duration, sql string, rows *sql.Rows, err error) {}
app.NonconcurrentDB().(*dbx.DB).ExecLogFunc = func(ctx context.Context, t time.Duration, sql string, result sql.Result, err error) {}
app.NonconcurrentDB().(*dbx.DB).QueryLogFunc = func(ctx context.Context, t time.Duration, sql string, rows *sql.Rows, err error) {}
handler, ok := app.Logger().Handler().(*logger.BatchHandler)
if !ok {
t.Fatalf("Expected BatchHandler, got %v", app.Logger().Handler())
}
app.Settings().Logs.MinLevel = s.level
if err := app.Save(app.Settings()); err != nil {
t.Fatalf("Failed to save settings: %v", err)
}
for level, enabled := range s.expectations {
if v := handler.Enabled(context.Background(), slog.Level(level)); v != enabled {
t.Fatalf("Expected level %d Enabled() to be %v, got %v", level, enabled, v)
}
}
})
}
}
func TestBaseAppDBDualBuilder(t *testing.T) {
t.Parallel()
app, _ := tests.NewTestApp()
defer app.Cleanup()
concurrentQueries := []string{}
nonconcurrentQueries := []string{}
app.ConcurrentDB().(*dbx.DB).QueryLogFunc = func(ctx context.Context, t time.Duration, sql string, rows *sql.Rows, err error) {
concurrentQueries = append(concurrentQueries, sql)
}
app.ConcurrentDB().(*dbx.DB).ExecLogFunc = func(ctx context.Context, t time.Duration, sql string, result sql.Result, err error) {
concurrentQueries = append(concurrentQueries, sql)
}
app.NonconcurrentDB().(*dbx.DB).QueryLogFunc = func(ctx context.Context, t time.Duration, sql string, rows *sql.Rows, err error) {
nonconcurrentQueries = append(nonconcurrentQueries, sql)
}
app.NonconcurrentDB().(*dbx.DB).ExecLogFunc = func(ctx context.Context, t time.Duration, sql string, result sql.Result, err error) {
nonconcurrentQueries = append(nonconcurrentQueries, sql)
}
type testQuery struct {
query string
isConcurrent bool
}
regularTests := []testQuery{
{" \n sEleCt 1", true},
{"With abc(x) AS (select 2) SELECT x FROM abc", true},
{"create table t1(x int)", false},
{"insert into t1(x) values(1)", false},
{"update t1 set x = 2", false},
{"delete from t1", false},
}
txTests := []testQuery{
{"select 3", false},
{" \n WITH abc(x) AS (select 4) SELECT x FROM abc", false},
{"create table t2(x int)", false},
{"insert into t2(x) values(1)", false},
{"update t2 set x = 2", false},
{"delete from t2", false},
}
for _, item := range regularTests {
_, err := app.DB().NewQuery(item.query).Execute()
if err != nil {
t.Fatalf("Failed to execute query %q error: %v", item.query, err)
}
}
app.RunInTransaction(func(txApp core.App) error {
for _, item := range txTests {
_, err := txApp.DB().NewQuery(item.query).Execute()
if err != nil {
t.Fatalf("Failed to execute query %q error: %v", item.query, err)
}
}
return nil
})
allTests := append(regularTests, txTests...)
for _, item := range allTests {
if item.isConcurrent {
if !slices.Contains(concurrentQueries, item.query) {
t.Fatalf("Expected concurrent query\n%q\ngot\nconcurrent:%v\nnonconcurrent:%v", item.query, concurrentQueries, nonconcurrentQueries)
}
} else {
if !slices.Contains(nonconcurrentQueries, item.query) {
t.Fatalf("Expected nonconcurrent query\n%q\ngot\nconcurrent:%v\nnonconcurrent:%v", item.query, concurrentQueries, nonconcurrentQueries)
}
}
}
}
func TestBaseAppAuxDBDualBuilder(t *testing.T) {
t.Parallel()
app, _ := tests.NewTestApp()
defer app.Cleanup()
concurrentQueries := []string{}
nonconcurrentQueries := []string{}
app.AuxConcurrentDB().(*dbx.DB).QueryLogFunc = func(ctx context.Context, t time.Duration, sql string, rows *sql.Rows, err error) {
concurrentQueries = append(concurrentQueries, sql)
}
app.AuxConcurrentDB().(*dbx.DB).ExecLogFunc = func(ctx context.Context, t time.Duration, sql string, result sql.Result, err error) {
concurrentQueries = append(concurrentQueries, sql)
}
app.AuxNonconcurrentDB().(*dbx.DB).QueryLogFunc = func(ctx context.Context, t time.Duration, sql string, rows *sql.Rows, err error) {
nonconcurrentQueries = append(nonconcurrentQueries, sql)
}
app.AuxNonconcurrentDB().(*dbx.DB).ExecLogFunc = func(ctx context.Context, t time.Duration, sql string, result sql.Result, err error) {
nonconcurrentQueries = append(nonconcurrentQueries, sql)
}
type testQuery struct {
query string
isConcurrent bool
}
regularTests := []testQuery{
{" \n sEleCt 1", true},
{"With abc(x) AS (select 2) SELECT x FROM abc", true},
{"create table t1(x int)", false},
{"insert into t1(x) values(1)", false},
{"update t1 set x = 2", false},
{"delete from t1", false},
}
txTests := []testQuery{
{"select 3", false},
{" \n WITH abc(x) AS (select 4) SELECT x FROM abc", false},
{"create table t2(x int)", false},
{"insert into t2(x) values(1)", false},
{"update t2 set x = 2", false},
{"delete from t2", false},
}
for _, item := range regularTests {
_, err := app.AuxDB().NewQuery(item.query).Execute()
if err != nil {
t.Fatalf("Failed to execute query %q error: %v", item.query, err)
}
}
app.AuxRunInTransaction(func(txApp core.App) error {
for _, item := range txTests {
_, err := txApp.AuxDB().NewQuery(item.query).Execute()
if err != nil {
t.Fatalf("Failed to execute query %q error: %v", item.query, err)
}
}
return nil
})
allTests := append(regularTests, txTests...)
for _, item := range allTests {
if item.isConcurrent {
if !slices.Contains(concurrentQueries, item.query) {
t.Fatalf("Expected concurrent query\n%q\ngot\nconcurrent:%v\nnonconcurrent:%v", item.query, concurrentQueries, nonconcurrentQueries)
}
} else {
if !slices.Contains(nonconcurrentQueries, item.query) {
t.Fatalf("Expected nonconcurrent query\n%q\ngot\nconcurrent:%v\nnonconcurrent:%v", item.query, concurrentQueries, nonconcurrentQueries)
}
}
}
}

200
core/collection_import.go Normal file
View file

@ -0,0 +1,200 @@
package core
import (
"cmp"
"context"
"database/sql"
"encoding/json"
"errors"
"fmt"
"slices"
validation "github.com/go-ozzo/ozzo-validation/v4"
"github.com/spf13/cast"
)
// ImportCollectionsByMarshaledJSON is the same as [ImportCollections]
// but accept marshaled json array as import data (usually used for the autogenerated snapshots).
func (app *BaseApp) ImportCollectionsByMarshaledJSON(rawSliceOfMaps []byte, deleteMissing bool) error {
data := []map[string]any{}
err := json.Unmarshal(rawSliceOfMaps, &data)
if err != nil {
return err
}
return app.ImportCollections(data, deleteMissing)
}
// ImportCollections imports the provided collections data in a single transaction.
//
// For existing matching collections, the imported data is unmarshaled on top of the existing model.
//
// NB! If deleteMissing is true, ALL NON-SYSTEM COLLECTIONS AND SCHEMA FIELDS,
// that are not present in the imported configuration, WILL BE DELETED
// (this includes their related records data).
func (app *BaseApp) ImportCollections(toImport []map[string]any, deleteMissing bool) error {
if len(toImport) == 0 {
// prevent accidentally deleting all collections
return errors.New("no collections to import")
}
importedCollections := make([]*Collection, len(toImport))
mappedImported := make(map[string]*Collection, len(toImport))
// normalize imported collections data to ensure that all
// collection fields are present and properly initialized
for i, data := range toImport {
var imported *Collection
identifier := cast.ToString(data["id"])
if identifier == "" {
identifier = cast.ToString(data["name"])
}
existing, err := app.FindCollectionByNameOrId(identifier)
if err != nil && !errors.Is(err, sql.ErrNoRows) {
return err
}
if existing != nil {
// refetch for deep copy
imported, err = app.FindCollectionByNameOrId(existing.Id)
if err != nil {
return err
}
// ensure that the fields will be cleared
if data["fields"] == nil && deleteMissing {
data["fields"] = []map[string]any{}
}
rawData, err := json.Marshal(data)
if err != nil {
return err
}
// load the imported data
err = json.Unmarshal(rawData, imported)
if err != nil {
return err
}
// extend with the existing fields if necessary
for _, f := range existing.Fields {
if !f.GetSystem() && deleteMissing {
continue
}
if imported.Fields.GetById(f.GetId()) == nil {
// replace with the existing id to prevent accidental column deletion
// since otherwise the imported field will be treated as a new one
found := imported.Fields.GetByName(f.GetName())
if found != nil && found.Type() == f.Type() {
found.SetId(f.GetId())
}
imported.Fields.Add(f)
}
}
} else {
imported = &Collection{}
rawData, err := json.Marshal(data)
if err != nil {
return err
}
// load the imported data
err = json.Unmarshal(rawData, imported)
if err != nil {
return err
}
}
imported.IntegrityChecks(false)
importedCollections[i] = imported
mappedImported[imported.Id] = imported
}
// reorder views last since the view query could depend on some of the other collections
slices.SortStableFunc(importedCollections, func(a, b *Collection) int {
cmpA := -1
if a.IsView() {
cmpA = 1
}
cmpB := -1
if b.IsView() {
cmpB = 1
}
res := cmp.Compare(cmpA, cmpB)
if res == 0 {
res = a.Created.Compare(b.Created)
if res == 0 {
res = a.Updated.Compare(b.Updated)
}
}
return res
})
return app.RunInTransaction(func(txApp App) error {
existingCollections := []*Collection{}
if err := txApp.CollectionQuery().OrderBy("updated ASC").All(&existingCollections); err != nil {
return err
}
mappedExisting := make(map[string]*Collection, len(existingCollections))
for _, existing := range existingCollections {
existing.IntegrityChecks(false)
mappedExisting[existing.Id] = existing
}
// delete old collections not available in the new configuration
// (before saving the imports in case a deleted collection name is being reused)
if deleteMissing {
for _, existing := range existingCollections {
if mappedImported[existing.Id] != nil || existing.System {
continue // exist or system
}
// delete collection
if err := txApp.Delete(existing); err != nil {
return err
}
}
}
// upsert imported collections
for _, imported := range importedCollections {
if err := txApp.SaveNoValidate(imported); err != nil {
return fmt.Errorf("failed to save collection %q: %w", imported.Name, err)
}
}
// run validations
for _, imported := range importedCollections {
original := mappedExisting[imported.Id]
if original == nil {
original = imported
}
validator := newCollectionValidator(
context.Background(),
txApp,
imported,
original,
)
if err := validator.run(); err != nil {
// serialize the validation error(s)
serializedErr, _ := json.MarshalIndent(err, "", " ")
return validation.Errors{"collections": validation.NewError(
"validation_collections_import_failure",
fmt.Sprintf("Data validations failed for collection %q (%s):\n%s", imported.Name, imported.Id, serializedErr),
)}
}
}
return nil
})
}

View file

@ -0,0 +1,476 @@
package core_test
import (
"encoding/json"
"strings"
"testing"
"github.com/pocketbase/dbx"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tests"
)
func TestImportCollections(t *testing.T) {
t.Parallel()
testApp, _ := tests.NewTestApp()
defer testApp.Cleanup()
var regularCollections []*core.Collection
err := testApp.CollectionQuery().AndWhere(dbx.HashExp{"system": false}).All(&regularCollections)
if err != nil {
t.Fatal(err)
}
var systemCollections []*core.Collection
err = testApp.CollectionQuery().AndWhere(dbx.HashExp{"system": true}).All(&systemCollections)
if err != nil {
t.Fatal(err)
}
totalRegularCollections := len(regularCollections)
totalSystemCollections := len(systemCollections)
totalCollections := totalRegularCollections + totalSystemCollections
scenarios := []struct {
name string
data []map[string]any
deleteMissing bool
expectError bool
expectCollectionsCount int
afterTestFunc func(testApp *tests.TestApp, resultCollections []*core.Collection)
}{
{
name: "empty collections",
data: []map[string]any{},
expectError: true,
expectCollectionsCount: totalCollections,
},
{
name: "minimal collection import (with missing system fields)",
data: []map[string]any{
{"name": "import_test1", "type": "auth"},
{
"name": "import_test2", "fields": []map[string]any{
{"name": "test", "type": "text"},
},
},
},
deleteMissing: false,
expectError: false,
expectCollectionsCount: totalCollections + 2,
},
{
name: "minimal collection import (trigger collection model validations)",
data: []map[string]any{
{"name": ""},
{
"name": "import_test2", "fields": []map[string]any{
{"name": "test", "type": "text"},
},
},
},
deleteMissing: false,
expectError: true,
expectCollectionsCount: totalCollections,
},
{
name: "minimal collection import (trigger field settings validation)",
data: []map[string]any{
{"name": "import_test", "fields": []map[string]any{{"name": "test", "type": "text", "min": -1}}},
},
deleteMissing: false,
expectError: true,
expectCollectionsCount: totalCollections,
},
{
name: "new + update + delete (system collections delete should be ignored)",
data: []map[string]any{
{
"id": "wsmn24bux7wo113",
"name": "demo",
"fields": []map[string]any{
{
"id": "_2hlxbmp",
"name": "title",
"type": "text",
"system": false,
"required": true,
"min": 3,
"max": nil,
"pattern": "",
},
},
"indexes": []string{},
},
{
"name": "import1",
"fields": []map[string]any{
{
"name": "active",
"type": "bool",
},
},
},
},
deleteMissing: true,
expectError: false,
expectCollectionsCount: totalSystemCollections + 2,
},
{
name: "test with deleteMissing: false",
data: []map[string]any{
{
// "id": "wsmn24bux7wo113", // test update with only name as identifier
"name": "demo1",
"fields": []map[string]any{
{
"id": "_2hlxbmp",
"name": "title",
"type": "text",
"system": false,
"required": true,
"min": 3,
"max": nil,
"pattern": "",
},
{
"id": "_2hlxbmp",
"name": "field_with_duplicate_id",
"type": "text",
"system": false,
"required": true,
"unique": false,
"min": 4,
"max": nil,
"pattern": "",
},
{
"id": "abcd_import",
"name": "new_field",
"type": "text",
},
},
},
{
"name": "new_import",
"fields": []map[string]any{
{
"id": "abcd_import",
"name": "active",
"type": "bool",
},
},
},
},
deleteMissing: false,
expectError: false,
expectCollectionsCount: totalCollections + 1,
afterTestFunc: func(testApp *tests.TestApp, resultCollections []*core.Collection) {
expectedCollectionFields := map[string]int{
core.CollectionNameAuthOrigins: 6,
"nologin": 10,
"demo1": 19,
"demo2": 5,
"demo3": 5,
"demo4": 16,
"demo5": 9,
"new_import": 2,
}
for name, expectedCount := range expectedCollectionFields {
collection, err := testApp.FindCollectionByNameOrId(name)
if err != nil {
t.Fatal(err)
}
if totalFields := len(collection.Fields); totalFields != expectedCount {
t.Errorf("Expected %d %q fields, got %d", expectedCount, collection.Name, totalFields)
}
}
},
},
}
for _, s := range scenarios {
t.Run(s.name, func(t *testing.T) {
testApp, _ := tests.NewTestApp()
defer testApp.Cleanup()
err := testApp.ImportCollections(s.data, s.deleteMissing)
hasErr := err != nil
if hasErr != s.expectError {
t.Fatalf("Expected hasErr to be %v, got %v (%v)", s.expectError, hasErr, err)
}
// check collections count
collections := []*core.Collection{}
if err := testApp.CollectionQuery().All(&collections); err != nil {
t.Fatal(err)
}
if len(collections) != s.expectCollectionsCount {
t.Fatalf("Expected %d collections, got %d", s.expectCollectionsCount, len(collections))
}
if s.afterTestFunc != nil {
s.afterTestFunc(testApp, collections)
}
})
}
}
func TestImportCollectionsByMarshaledJSON(t *testing.T) {
t.Parallel()
testApp, _ := tests.NewTestApp()
defer testApp.Cleanup()
var regularCollections []*core.Collection
err := testApp.CollectionQuery().AndWhere(dbx.HashExp{"system": false}).All(&regularCollections)
if err != nil {
t.Fatal(err)
}
var systemCollections []*core.Collection
err = testApp.CollectionQuery().AndWhere(dbx.HashExp{"system": true}).All(&systemCollections)
if err != nil {
t.Fatal(err)
}
totalRegularCollections := len(regularCollections)
totalSystemCollections := len(systemCollections)
totalCollections := totalRegularCollections + totalSystemCollections
scenarios := []struct {
name string
data string
deleteMissing bool
expectError bool
expectCollectionsCount int
afterTestFunc func(testApp *tests.TestApp, resultCollections []*core.Collection)
}{
{
name: "invalid json array",
data: `{"test":123}`,
expectError: true,
expectCollectionsCount: totalCollections,
},
{
name: "new + update + delete (system collections delete should be ignored)",
data: `[
{
"id": "wsmn24bux7wo113",
"name": "demo",
"fields": [
{
"id": "_2hlxbmp",
"name": "title",
"type": "text",
"system": false,
"required": true,
"min": 3,
"max": null,
"pattern": ""
}
],
"indexes": []
},
{
"name": "import1",
"fields": [
{
"name": "active",
"type": "bool"
}
]
}
]`,
deleteMissing: true,
expectError: false,
expectCollectionsCount: totalSystemCollections + 2,
},
}
for _, s := range scenarios {
t.Run(s.name, func(t *testing.T) {
testApp, _ := tests.NewTestApp()
defer testApp.Cleanup()
err := testApp.ImportCollectionsByMarshaledJSON([]byte(s.data), s.deleteMissing)
hasErr := err != nil
if hasErr != s.expectError {
t.Fatalf("Expected hasErr to be %v, got %v (%v)", s.expectError, hasErr, err)
}
// check collections count
collections := []*core.Collection{}
if err := testApp.CollectionQuery().All(&collections); err != nil {
t.Fatal(err)
}
if len(collections) != s.expectCollectionsCount {
t.Fatalf("Expected %d collections, got %d", s.expectCollectionsCount, len(collections))
}
if s.afterTestFunc != nil {
s.afterTestFunc(testApp, collections)
}
})
}
}
func TestImportCollectionsUpdateRules(t *testing.T) {
t.Parallel()
scenarios := []struct {
name string
data map[string]any
deleteMissing bool
}{
{
"extend existing by name (without deleteMissing)",
map[string]any{"name": "clients", "authToken": map[string]any{"duration": 100}, "fields": []map[string]any{{"name": "test", "type": "text"}}},
false,
},
{
"extend existing by id (without deleteMissing)",
map[string]any{"id": "v851q4r790rhknl", "authToken": map[string]any{"duration": 100}, "fields": []map[string]any{{"name": "test", "type": "text"}}},
false,
},
{
"extend with delete missing",
map[string]any{
"id": "v851q4r790rhknl",
"authToken": map[string]any{"duration": 100},
"fields": []map[string]any{{"name": "test", "type": "text"}},
"passwordAuth": map[string]any{"identityFields": []string{"email"}},
"indexes": []string{
// min required system fields indexes
"CREATE UNIQUE INDEX `_v851q4r790rhknl_email_idx` ON `clients` (email) WHERE email != ''",
"CREATE UNIQUE INDEX `_v851q4r790rhknl_tokenKey_idx` ON `clients` (tokenKey)",
},
},
true,
},
}
for _, s := range scenarios {
t.Run(s.name, func(t *testing.T) {
testApp, _ := tests.NewTestApp()
defer testApp.Cleanup()
beforeCollection, err := testApp.FindCollectionByNameOrId("clients")
if err != nil {
t.Fatal(err)
}
err = testApp.ImportCollections([]map[string]any{s.data}, s.deleteMissing)
if err != nil {
t.Fatal(err)
}
afterCollection, err := testApp.FindCollectionByNameOrId("clients")
if err != nil {
t.Fatal(err)
}
if afterCollection.AuthToken.Duration != 100 {
t.Fatalf("Expected AuthToken duration to be %d, got %d", 100, afterCollection.AuthToken.Duration)
}
if beforeCollection.AuthToken.Secret != afterCollection.AuthToken.Secret {
t.Fatalf("Expected AuthToken secrets to remain the same, got\n%q\nVS\n%q", beforeCollection.AuthToken.Secret, afterCollection.AuthToken.Secret)
}
if beforeCollection.Name != afterCollection.Name {
t.Fatalf("Expected Name to remain the same, got\n%q\nVS\n%q", beforeCollection.Name, afterCollection.Name)
}
if beforeCollection.Id != afterCollection.Id {
t.Fatalf("Expected Id to remain the same, got\n%q\nVS\n%q", beforeCollection.Id, afterCollection.Id)
}
if !s.deleteMissing {
totalExpectedFields := len(beforeCollection.Fields) + 1
if v := len(afterCollection.Fields); v != totalExpectedFields {
t.Fatalf("Expected %d total fields, got %d", totalExpectedFields, v)
}
if afterCollection.Fields.GetByName("test") == nil {
t.Fatalf("Missing new field %q", "test")
}
// ensure that the old fields still exist
oldFields := beforeCollection.Fields.FieldNames()
for _, name := range oldFields {
if afterCollection.Fields.GetByName(name) == nil {
t.Fatalf("Missing expected old field %q", name)
}
}
} else {
totalExpectedFields := 1
for _, f := range beforeCollection.Fields {
if f.GetSystem() {
totalExpectedFields++
}
}
if v := len(afterCollection.Fields); v != totalExpectedFields {
t.Fatalf("Expected %d total fields, got %d", totalExpectedFields, v)
}
if afterCollection.Fields.GetByName("test") == nil {
t.Fatalf("Missing new field %q", "test")
}
// ensure that the old system fields still exist
for _, f := range beforeCollection.Fields {
if f.GetSystem() && afterCollection.Fields.GetByName(f.GetName()) == nil {
t.Fatalf("Missing expected old field %q", f.GetName())
}
}
}
})
}
}
func TestImportCollectionsCreateRules(t *testing.T) {
t.Parallel()
testApp, _ := tests.NewTestApp()
defer testApp.Cleanup()
err := testApp.ImportCollections([]map[string]any{
{"name": "new_test", "type": "auth", "authToken": map[string]any{"duration": 123}, "fields": []map[string]any{{"name": "test", "type": "text"}}},
}, false)
if err != nil {
t.Fatal(err)
}
collection, err := testApp.FindCollectionByNameOrId("new_test")
if err != nil {
t.Fatal(err)
}
raw, err := json.Marshal(collection)
if err != nil {
t.Fatal(err)
}
rawStr := string(raw)
expectedParts := []string{
`"name":"new_test"`,
`"fields":[`,
`"name":"id"`,
`"name":"email"`,
`"name":"tokenKey"`,
`"name":"password"`,
`"name":"test"`,
`"indexes":[`,
`CREATE UNIQUE INDEX`,
`"duration":123`,
}
for _, part := range expectedParts {
if !strings.Contains(rawStr, part) {
t.Errorf("Missing %q in\n%s", part, rawStr)
}
}
}

1073
core/collection_model.go Normal file

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,543 @@
package core
import (
"strconv"
"strings"
"time"
validation "github.com/go-ozzo/ozzo-validation/v4"
"github.com/go-ozzo/ozzo-validation/v4/is"
"github.com/pocketbase/pocketbase/tools/auth"
"github.com/pocketbase/pocketbase/tools/list"
"github.com/pocketbase/pocketbase/tools/security"
"github.com/pocketbase/pocketbase/tools/types"
"github.com/spf13/cast"
)
func (m *Collection) unsetMissingOAuth2MappedFields() {
if !m.IsAuth() {
return
}
if m.OAuth2.MappedFields.Id != "" {
if m.Fields.GetByName(m.OAuth2.MappedFields.Id) == nil {
m.OAuth2.MappedFields.Id = ""
}
}
if m.OAuth2.MappedFields.Name != "" {
if m.Fields.GetByName(m.OAuth2.MappedFields.Name) == nil {
m.OAuth2.MappedFields.Name = ""
}
}
if m.OAuth2.MappedFields.Username != "" {
if m.Fields.GetByName(m.OAuth2.MappedFields.Username) == nil {
m.OAuth2.MappedFields.Username = ""
}
}
if m.OAuth2.MappedFields.AvatarURL != "" {
if m.Fields.GetByName(m.OAuth2.MappedFields.AvatarURL) == nil {
m.OAuth2.MappedFields.AvatarURL = ""
}
}
}
func (m *Collection) setDefaultAuthOptions() {
m.collectionAuthOptions = collectionAuthOptions{
VerificationTemplate: defaultVerificationTemplate,
ResetPasswordTemplate: defaultResetPasswordTemplate,
ConfirmEmailChangeTemplate: defaultConfirmEmailChangeTemplate,
AuthRule: types.Pointer(""),
AuthAlert: AuthAlertConfig{
Enabled: true,
EmailTemplate: defaultAuthAlertTemplate,
},
PasswordAuth: PasswordAuthConfig{
Enabled: true,
IdentityFields: []string{FieldNameEmail},
},
MFA: MFAConfig{
Enabled: false,
Duration: 1800, // 30min
},
OTP: OTPConfig{
Enabled: false,
Duration: 180, // 3min
Length: 8,
EmailTemplate: defaultOTPTemplate,
},
AuthToken: TokenConfig{
Secret: security.RandomString(50),
Duration: 604800, // 7 days
},
PasswordResetToken: TokenConfig{
Secret: security.RandomString(50),
Duration: 1800, // 30min
},
EmailChangeToken: TokenConfig{
Secret: security.RandomString(50),
Duration: 1800, // 30min
},
VerificationToken: TokenConfig{
Secret: security.RandomString(50),
Duration: 259200, // 3days
},
FileToken: TokenConfig{
Secret: security.RandomString(50),
Duration: 180, // 3min
},
}
}
var _ optionsValidator = (*collectionAuthOptions)(nil)
// collectionAuthOptions defines the options for the "auth" type collection.
type collectionAuthOptions struct {
// AuthRule could be used to specify additional record constraints
// applied after record authentication and right before returning the
// auth token response to the client.
//
// For example, to allow only verified users you could set it to
// "verified = true".
//
// Set it to empty string to allow any Auth collection record to authenticate.
//
// Set it to nil to disallow authentication altogether for the collection
// (that includes password, OAuth2, etc.).
AuthRule *string `form:"authRule" json:"authRule"`
// ManageRule gives admin-like permissions to allow fully managing
// the auth record(s), eg. changing the password without requiring
// to enter the old one, directly updating the verified state and email, etc.
//
// This rule is executed in addition to the Create and Update API rules.
ManageRule *string `form:"manageRule" json:"manageRule"`
// AuthAlert defines options related to the auth alerts on new device login.
AuthAlert AuthAlertConfig `form:"authAlert" json:"authAlert"`
// OAuth2 specifies whether OAuth2 auth is enabled for the collection
// and which OAuth2 providers are allowed.
OAuth2 OAuth2Config `form:"oauth2" json:"oauth2"`
// PasswordAuth defines options related to the collection password authentication.
PasswordAuth PasswordAuthConfig `form:"passwordAuth" json:"passwordAuth"`
// MFA defines options related to the Multi-factor authentication (MFA).
MFA MFAConfig `form:"mfa" json:"mfa"`
// OTP defines options related to the One-time password authentication (OTP).
OTP OTPConfig `form:"otp" json:"otp"`
// Various token configurations
// ---
AuthToken TokenConfig `form:"authToken" json:"authToken"`
PasswordResetToken TokenConfig `form:"passwordResetToken" json:"passwordResetToken"`
EmailChangeToken TokenConfig `form:"emailChangeToken" json:"emailChangeToken"`
VerificationToken TokenConfig `form:"verificationToken" json:"verificationToken"`
FileToken TokenConfig `form:"fileToken" json:"fileToken"`
// Default email templates
// ---
VerificationTemplate EmailTemplate `form:"verificationTemplate" json:"verificationTemplate"`
ResetPasswordTemplate EmailTemplate `form:"resetPasswordTemplate" json:"resetPasswordTemplate"`
ConfirmEmailChangeTemplate EmailTemplate `form:"confirmEmailChangeTemplate" json:"confirmEmailChangeTemplate"`
}
func (o *collectionAuthOptions) validate(cv *collectionValidator) error {
err := validation.ValidateStruct(o,
validation.Field(
&o.AuthRule,
validation.By(cv.checkRule),
validation.By(cv.ensureNoSystemRuleChange(cv.original.AuthRule)),
),
validation.Field(
&o.ManageRule,
validation.NilOrNotEmpty,
validation.By(cv.checkRule),
validation.By(cv.ensureNoSystemRuleChange(cv.original.ManageRule)),
),
validation.Field(&o.AuthAlert),
validation.Field(&o.PasswordAuth),
validation.Field(&o.OAuth2),
validation.Field(&o.OTP),
validation.Field(&o.MFA),
validation.Field(&o.AuthToken),
validation.Field(&o.PasswordResetToken),
validation.Field(&o.EmailChangeToken),
validation.Field(&o.VerificationToken),
validation.Field(&o.FileToken),
validation.Field(&o.VerificationTemplate, validation.Required),
validation.Field(&o.ResetPasswordTemplate, validation.Required),
validation.Field(&o.ConfirmEmailChangeTemplate, validation.Required),
)
if err != nil {
return err
}
if o.MFA.Enabled {
// if MFA is enabled require at least 2 auth methods
//
// @todo maybe consider disabling the check because if custom auth methods
// are registered it may fail since we don't have mechanism to detect them at the moment
authsEnabled := 0
if o.PasswordAuth.Enabled {
authsEnabled++
}
if o.OAuth2.Enabled {
authsEnabled++
}
if o.OTP.Enabled {
authsEnabled++
}
if authsEnabled < 2 {
return validation.Errors{
"mfa": validation.Errors{
"enabled": validation.NewError("validation_mfa_not_enough_auths", "MFA requires at least 2 auth methods to be enabled."),
},
}
}
if o.MFA.Rule != "" {
mfaRuleValidators := []validation.RuleFunc{
cv.checkRule,
cv.ensureNoSystemRuleChange(&cv.original.MFA.Rule),
}
for _, validator := range mfaRuleValidators {
err := validator(&o.MFA.Rule)
if err != nil {
return validation.Errors{
"mfa": validation.Errors{
"rule": err,
},
}
}
}
}
}
// extra check to ensure that only unique identity fields are used
if o.PasswordAuth.Enabled {
err = validation.Validate(o.PasswordAuth.IdentityFields, validation.By(cv.checkFieldsForUniqueIndex))
if err != nil {
return validation.Errors{
"passwordAuth": validation.Errors{
"identityFields": err,
},
}
}
}
return nil
}
// -------------------------------------------------------------------
type EmailTemplate struct {
Subject string `form:"subject" json:"subject"`
Body string `form:"body" json:"body"`
}
// Validate makes EmailTemplate validatable by implementing [validation.Validatable] interface.
func (t EmailTemplate) Validate() error {
return validation.ValidateStruct(&t,
validation.Field(&t.Subject, validation.Required),
validation.Field(&t.Body, validation.Required),
)
}
// Resolve replaces the placeholder parameters in the current email
// template and returns its components as ready-to-use strings.
func (t EmailTemplate) Resolve(placeholders map[string]any) (subject, body string) {
body = t.Body
subject = t.Subject
for k, v := range placeholders {
vStr := cast.ToString(v)
// replace subject placeholder params (if any)
subject = strings.ReplaceAll(subject, k, vStr)
// replace body placeholder params (if any)
body = strings.ReplaceAll(body, k, vStr)
}
return subject, body
}
// -------------------------------------------------------------------
type AuthAlertConfig struct {
Enabled bool `form:"enabled" json:"enabled"`
EmailTemplate EmailTemplate `form:"emailTemplate" json:"emailTemplate"`
}
// Validate makes AuthAlertConfig validatable by implementing [validation.Validatable] interface.
func (c AuthAlertConfig) Validate() error {
return validation.ValidateStruct(&c,
// note: for now always run the email template validations even
// if not enabled since it could be used separately
validation.Field(&c.EmailTemplate),
)
}
// -------------------------------------------------------------------
type TokenConfig struct {
Secret string `form:"secret" json:"secret,omitempty"`
// Duration specifies how long an issued token to be valid (in seconds)
Duration int64 `form:"duration" json:"duration"`
}
// Validate makes TokenConfig validatable by implementing [validation.Validatable] interface.
func (c TokenConfig) Validate() error {
return validation.ValidateStruct(&c,
validation.Field(&c.Secret, validation.Required, validation.Length(30, 255)),
validation.Field(&c.Duration, validation.Required, validation.Min(10), validation.Max(94670856)), // ~3y max
)
}
// DurationTime returns the current Duration as [time.Duration].
func (c TokenConfig) DurationTime() time.Duration {
return time.Duration(c.Duration) * time.Second
}
// -------------------------------------------------------------------
type OTPConfig struct {
Enabled bool `form:"enabled" json:"enabled"`
// Duration specifies how long the OTP to be valid (in seconds)
Duration int64 `form:"duration" json:"duration"`
// Length specifies the auto generated password length.
Length int `form:"length" json:"length"`
// EmailTemplate is the default OTP email template that will be send to the auth record.
//
// In addition to the system placeholders you can also make use of
// [core.EmailPlaceholderOTPId] and [core.EmailPlaceholderOTP].
EmailTemplate EmailTemplate `form:"emailTemplate" json:"emailTemplate"`
}
// Validate makes OTPConfig validatable by implementing [validation.Validatable] interface.
func (c OTPConfig) Validate() error {
return validation.ValidateStruct(&c,
validation.Field(&c.Duration, validation.When(c.Enabled, validation.Required, validation.Min(10), validation.Max(86400))),
validation.Field(&c.Length, validation.When(c.Enabled, validation.Required, validation.Min(4))),
// note: for now always run the email template validations even
// if not enabled since it could be used separately
validation.Field(&c.EmailTemplate),
)
}
// DurationTime returns the current Duration as [time.Duration].
func (c OTPConfig) DurationTime() time.Duration {
return time.Duration(c.Duration) * time.Second
}
// -------------------------------------------------------------------
type MFAConfig struct {
Enabled bool `form:"enabled" json:"enabled"`
// Duration specifies how long an issued MFA to be valid (in seconds)
Duration int64 `form:"duration" json:"duration"`
// Rule is an optional field to restrict MFA only for the records that satisfy the rule.
//
// Leave it empty to enable MFA for everyone.
Rule string `form:"rule" json:"rule"`
}
// Validate makes MFAConfig validatable by implementing [validation.Validatable] interface.
func (c MFAConfig) Validate() error {
return validation.ValidateStruct(&c,
validation.Field(&c.Duration, validation.When(c.Enabled, validation.Required, validation.Min(10), validation.Max(86400))),
)
}
// DurationTime returns the current Duration as [time.Duration].
func (c MFAConfig) DurationTime() time.Duration {
return time.Duration(c.Duration) * time.Second
}
// -------------------------------------------------------------------
type PasswordAuthConfig struct {
Enabled bool `form:"enabled" json:"enabled"`
// IdentityFields is a list of field names that could be used as
// identity during password authentication.
//
// Usually only fields that has single column UNIQUE index are accepted as values.
IdentityFields []string `form:"identityFields" json:"identityFields"`
}
// Validate makes PasswordAuthConfig validatable by implementing [validation.Validatable] interface.
func (c PasswordAuthConfig) Validate() error {
// strip duplicated values
c.IdentityFields = list.ToUniqueStringSlice(c.IdentityFields)
if !c.Enabled {
return nil // no need to validate
}
return validation.ValidateStruct(&c,
validation.Field(&c.IdentityFields, validation.Required),
)
}
// -------------------------------------------------------------------
type OAuth2KnownFields struct {
Id string `form:"id" json:"id"`
Name string `form:"name" json:"name"`
Username string `form:"username" json:"username"`
AvatarURL string `form:"avatarURL" json:"avatarURL"`
}
type OAuth2Config struct {
Providers []OAuth2ProviderConfig `form:"providers" json:"providers"`
MappedFields OAuth2KnownFields `form:"mappedFields" json:"mappedFields"`
Enabled bool `form:"enabled" json:"enabled"`
}
// GetProviderConfig returns the first OAuth2ProviderConfig that matches the specified name.
//
// Returns false and zero config if no such provider is available in c.Providers.
func (c OAuth2Config) GetProviderConfig(name string) (config OAuth2ProviderConfig, exists bool) {
for _, p := range c.Providers {
if p.Name == name {
return p, true
}
}
return
}
// Validate makes OAuth2Config validatable by implementing [validation.Validatable] interface.
func (c OAuth2Config) Validate() error {
if !c.Enabled {
return nil // no need to validate
}
return validation.ValidateStruct(&c,
// note: don't require providers for now as they could be externally registered/removed
validation.Field(&c.Providers, validation.By(checkForDuplicatedProviders)),
)
}
func checkForDuplicatedProviders(value any) error {
configs, _ := value.([]OAuth2ProviderConfig)
existing := map[string]struct{}{}
for i, c := range configs {
if c.Name == "" {
continue // the name nonempty state is validated separately
}
if _, ok := existing[c.Name]; ok {
return validation.Errors{
strconv.Itoa(i): validation.Errors{
"name": validation.NewError("validation_duplicated_provider", "The provider {{.name}} is already registered.").
SetParams(map[string]any{"name": c.Name}),
},
}
}
existing[c.Name] = struct{}{}
}
return nil
}
type OAuth2ProviderConfig struct {
// PKCE overwrites the default provider PKCE config option.
//
// This usually shouldn't be needed but some OAuth2 vendors, like the LinkedIn OIDC,
// may require manual adjustment due to returning error if extra parameters are added to the request
// (https://github.com/pocketbase/pocketbase/discussions/3799#discussioncomment-7640312)
PKCE *bool `form:"pkce" json:"pkce"`
Name string `form:"name" json:"name"`
ClientId string `form:"clientId" json:"clientId"`
ClientSecret string `form:"clientSecret" json:"clientSecret,omitempty"`
AuthURL string `form:"authURL" json:"authURL"`
TokenURL string `form:"tokenURL" json:"tokenURL"`
UserInfoURL string `form:"userInfoURL" json:"userInfoURL"`
DisplayName string `form:"displayName" json:"displayName"`
Extra map[string]any `form:"extra" json:"extra"`
}
// Validate makes OAuth2ProviderConfig validatable by implementing [validation.Validatable] interface.
func (c OAuth2ProviderConfig) Validate() error {
return validation.ValidateStruct(&c,
validation.Field(&c.Name, validation.Required, validation.By(checkProviderName)),
validation.Field(&c.ClientId, validation.Required),
validation.Field(&c.ClientSecret, validation.Required),
validation.Field(&c.AuthURL, is.URL),
validation.Field(&c.TokenURL, is.URL),
validation.Field(&c.UserInfoURL, is.URL),
)
}
func checkProviderName(value any) error {
name, _ := value.(string)
if name == "" {
return nil // nothing to check
}
if _, err := auth.NewProviderByName(name); err != nil {
return validation.NewError("validation_missing_provider", "Invalid or missing provider with name {{.name}}.").
SetParams(map[string]any{"name": name})
}
return nil
}
// InitProvider returns a new auth.Provider instance loaded with the current OAuth2ProviderConfig options.
func (c OAuth2ProviderConfig) InitProvider() (auth.Provider, error) {
provider, err := auth.NewProviderByName(c.Name)
if err != nil {
return nil, err
}
if c.ClientId != "" {
provider.SetClientId(c.ClientId)
}
if c.ClientSecret != "" {
provider.SetClientSecret(c.ClientSecret)
}
if c.AuthURL != "" {
provider.SetAuthURL(c.AuthURL)
}
if c.UserInfoURL != "" {
provider.SetUserInfoURL(c.UserInfoURL)
}
if c.TokenURL != "" {
provider.SetTokenURL(c.TokenURL)
}
if c.DisplayName != "" {
provider.SetDisplayName(c.DisplayName)
}
if c.PKCE != nil {
provider.SetPKCE(*c.PKCE)
}
if c.Extra != nil {
provider.SetExtra(c.Extra)
}
return provider, nil
}

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,75 @@
package core
// Common settings placeholder tokens
const (
EmailPlaceholderAppName string = "{APP_NAME}"
EmailPlaceholderAppURL string = "{APP_URL}"
EmailPlaceholderToken string = "{TOKEN}"
EmailPlaceholderOTP string = "{OTP}"
EmailPlaceholderOTPId string = "{OTP_ID}"
)
var defaultVerificationTemplate = EmailTemplate{
Subject: "Verify your " + EmailPlaceholderAppName + " email",
Body: `<p>Hello,</p>
<p>Thank you for joining us at ` + EmailPlaceholderAppName + `.</p>
<p>Click on the button below to verify your email address.</p>
<p>
<a class="btn" href="` + EmailPlaceholderAppURL + "/_/#/auth/confirm-verification/" + EmailPlaceholderToken + `" target="_blank" rel="noopener">Verify</a>
</p>
<p>
Thanks,<br/>
` + EmailPlaceholderAppName + ` team
</p>`,
}
var defaultResetPasswordTemplate = EmailTemplate{
Subject: "Reset your " + EmailPlaceholderAppName + " password",
Body: `<p>Hello,</p>
<p>Click on the button below to reset your password.</p>
<p>
<a class="btn" href="` + EmailPlaceholderAppURL + "/_/#/auth/confirm-password-reset/" + EmailPlaceholderToken + `" target="_blank" rel="noopener">Reset password</a>
</p>
<p><i>If you didn't ask to reset your password, you can ignore this email.</i></p>
<p>
Thanks,<br/>
` + EmailPlaceholderAppName + ` team
</p>`,
}
var defaultConfirmEmailChangeTemplate = EmailTemplate{
Subject: "Confirm your " + EmailPlaceholderAppName + " new email address",
Body: `<p>Hello,</p>
<p>Click on the button below to confirm your new email address.</p>
<p>
<a class="btn" href="` + EmailPlaceholderAppURL + "/_/#/auth/confirm-email-change/" + EmailPlaceholderToken + `" target="_blank" rel="noopener">Confirm new email</a>
</p>
<p><i>If you didn't ask to change your email address, you can ignore this email.</i></p>
<p>
Thanks,<br/>
` + EmailPlaceholderAppName + ` team
</p>`,
}
var defaultOTPTemplate = EmailTemplate{
Subject: "OTP for " + EmailPlaceholderAppName,
Body: `<p>Hello,</p>
<p>Your one-time password is: <strong>` + EmailPlaceholderOTP + `</strong></p>
<p><i>If you didn't ask for the one-time password, you can ignore this email.</i></p>
<p>
Thanks,<br/>
` + EmailPlaceholderAppName + ` team
</p>`,
}
var defaultAuthAlertTemplate = EmailTemplate{
Subject: "Login from a new location",
Body: `<p>Hello,</p>
<p>We noticed a login to your ` + EmailPlaceholderAppName + ` account from a new location.</p>
<p>If this was you, you may disregard this email.</p>
<p><strong>If this wasn't you, you should immediately change your ` + EmailPlaceholderAppName + ` account password to revoke access from all other locations.</strong></p>
<p>
Thanks,<br/>
` + EmailPlaceholderAppName + ` team
</p>`,
}

View file

@ -0,0 +1,11 @@
package core
var _ optionsValidator = (*collectionBaseOptions)(nil)
// collectionBaseOptions defines the options for the "base" type collection.
type collectionBaseOptions struct {
}
func (o *collectionBaseOptions) validate(cv *collectionValidator) error {
return nil
}

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,18 @@
package core
import (
validation "github.com/go-ozzo/ozzo-validation/v4"
)
var _ optionsValidator = (*collectionViewOptions)(nil)
// collectionViewOptions defines the options for the "view" type collection.
type collectionViewOptions struct {
ViewQuery string `form:"viewQuery" json:"viewQuery"`
}
func (o *collectionViewOptions) validate(cv *collectionValidator) error {
return validation.ValidateStruct(o,
validation.Field(&o.ViewQuery, validation.Required, validation.By(cv.checkViewQuery)),
)
}

View file

@ -0,0 +1,79 @@
package core_test
import (
"testing"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tests"
)
func TestCollectionViewOptionsValidate(t *testing.T) {
t.Parallel()
scenarios := []struct {
name string
collection func(app core.App) (*core.Collection, error)
expectedErrors []string
}{
{
name: "view with empty query",
collection: func(app core.App) (*core.Collection, error) {
c := core.NewViewCollection("new_auth")
return c, nil
},
expectedErrors: []string{"fields", "viewQuery"},
},
{
name: "view with invalid query",
collection: func(app core.App) (*core.Collection, error) {
c := core.NewViewCollection("new_auth")
c.ViewQuery = "invalid"
return c, nil
},
expectedErrors: []string{"fields", "viewQuery"},
},
{
name: "view with valid query but missing id",
collection: func(app core.App) (*core.Collection, error) {
c := core.NewViewCollection("new_auth")
c.ViewQuery = "select 1"
return c, nil
},
expectedErrors: []string{"fields", "viewQuery"},
},
{
name: "view with valid query",
collection: func(app core.App) (*core.Collection, error) {
c := core.NewViewCollection("new_auth")
c.ViewQuery = "select demo1.id, text as example from demo1"
return c, nil
},
expectedErrors: []string{},
},
{
name: "update view query ",
collection: func(app core.App) (*core.Collection, error) {
c, _ := app.FindCollectionByNameOrId("view2")
c.ViewQuery = "select demo1.id, text as example from demo1"
return c, nil
},
expectedErrors: []string{},
},
}
for _, s := range scenarios {
t.Run(s.name, func(t *testing.T) {
app, _ := tests.NewTestApp()
defer app.Cleanup()
collection, err := s.collection(app)
if err != nil {
t.Fatalf("Failed to retrieve test collection: %v", err)
}
result := app.Validate(collection)
tests.TestValidationErrors(t, result, s.expectedErrors)
})
}
}

391
core/collection_query.go Normal file
View file

@ -0,0 +1,391 @@
package core
import (
"bytes"
"database/sql"
"encoding/json"
"errors"
"fmt"
"slices"
"strings"
"github.com/pocketbase/dbx"
"github.com/pocketbase/pocketbase/tools/list"
)
const StoreKeyCachedCollections = "pbAppCachedCollections"
// CollectionQuery returns a new Collection select query.
func (app *BaseApp) CollectionQuery() *dbx.SelectQuery {
return app.ModelQuery(&Collection{})
}
// FindCollections finds all collections by the given type(s).
//
// If collectionTypes is not set, it returns all collections.
//
// Example:
//
// app.FindAllCollections() // all collections
// app.FindAllCollections("auth", "view") // only auth and view collections
func (app *BaseApp) FindAllCollections(collectionTypes ...string) ([]*Collection, error) {
collections := []*Collection{}
q := app.CollectionQuery()
types := list.NonzeroUniques(collectionTypes)
if len(types) > 0 {
q.AndWhere(dbx.In("type", list.ToInterfaceSlice(types)...))
}
err := q.OrderBy("rowid ASC").All(&collections)
if err != nil {
return nil, err
}
return collections, nil
}
// ReloadCachedCollections fetches all collections and caches them into the app store.
func (app *BaseApp) ReloadCachedCollections() error {
collections, err := app.FindAllCollections()
if err != nil {
return err
}
app.Store().Set(StoreKeyCachedCollections, collections)
return nil
}
// FindCollectionByNameOrId finds a single collection by its name (case insensitive) or id.
func (app *BaseApp) FindCollectionByNameOrId(nameOrId string) (*Collection, error) {
m := &Collection{}
err := app.CollectionQuery().
AndWhere(dbx.NewExp("[[id]]={:id} OR LOWER([[name]])={:name}", dbx.Params{
"id": nameOrId,
"name": strings.ToLower(nameOrId),
})).
Limit(1).
One(m)
if err != nil {
return nil, err
}
return m, nil
}
// FindCachedCollectionByNameOrId is similar to [BaseApp.FindCollectionByNameOrId]
// but retrieves the Collection from the app cache instead of making a db call.
//
// NB! This method is suitable for read-only Collection operations.
//
// Returns [sql.ErrNoRows] if no Collection is found for consistency
// with the [BaseApp.FindCollectionByNameOrId] method.
//
// If you plan making changes to the returned Collection model,
// use [BaseApp.FindCollectionByNameOrId] instead.
//
// Caveats:
//
// - The returned Collection should be used only for read-only operations.
// Avoid directly modifying the returned cached Collection as it will affect
// the global cached value even if you don't persist the changes in the database!
// - If you are updating a Collection in a transaction and then call this method before commit,
// it'll return the cached Collection state and not the one from the uncommitted transaction.
// - The cache is automatically updated on collections db change (create/update/delete).
// To manually reload the cache you can call [BaseApp.ReloadCachedCollections].
func (app *BaseApp) FindCachedCollectionByNameOrId(nameOrId string) (*Collection, error) {
collections, _ := app.Store().Get(StoreKeyCachedCollections).([]*Collection)
if collections == nil {
// cache is not initialized yet (eg. run in a system migration)
return app.FindCollectionByNameOrId(nameOrId)
}
for _, c := range collections {
if strings.EqualFold(c.Name, nameOrId) || c.Id == nameOrId {
return c, nil
}
}
return nil, sql.ErrNoRows
}
// FindCollectionReferences returns information for all relation fields
// referencing the provided collection.
//
// If the provided collection has reference to itself then it will be
// also included in the result. To exclude it, pass the collection id
// as the excludeIds argument.
func (app *BaseApp) FindCollectionReferences(collection *Collection, excludeIds ...string) (map[*Collection][]Field, error) {
collections := []*Collection{}
query := app.CollectionQuery()
if uniqueExcludeIds := list.NonzeroUniques(excludeIds); len(uniqueExcludeIds) > 0 {
query.AndWhere(dbx.NotIn("id", list.ToInterfaceSlice(uniqueExcludeIds)...))
}
if err := query.All(&collections); err != nil {
return nil, err
}
result := map[*Collection][]Field{}
for _, c := range collections {
for _, rawField := range c.Fields {
f, ok := rawField.(*RelationField)
if ok && f.CollectionId == collection.Id {
result[c] = append(result[c], f)
}
}
}
return result, nil
}
// FindCachedCollectionReferences is similar to [BaseApp.FindCollectionReferences]
// but retrieves the Collection from the app cache instead of making a db call.
//
// NB! This method is suitable for read-only Collection operations.
//
// If you plan making changes to the returned Collection model,
// use [BaseApp.FindCollectionReferences] instead.
//
// Caveats:
//
// - The returned Collection should be used only for read-only operations.
// Avoid directly modifying the returned cached Collection as it will affect
// the global cached value even if you don't persist the changes in the database!
// - If you are updating a Collection in a transaction and then call this method before commit,
// it'll return the cached Collection state and not the one from the uncommitted transaction.
// - The cache is automatically updated on collections db change (create/update/delete).
// To manually reload the cache you can call [BaseApp.ReloadCachedCollections].
func (app *BaseApp) FindCachedCollectionReferences(collection *Collection, excludeIds ...string) (map[*Collection][]Field, error) {
collections, _ := app.Store().Get(StoreKeyCachedCollections).([]*Collection)
if collections == nil {
// cache is not initialized yet (eg. run in a system migration)
return app.FindCollectionReferences(collection, excludeIds...)
}
result := map[*Collection][]Field{}
for _, c := range collections {
if slices.Contains(excludeIds, c.Id) {
continue
}
for _, rawField := range c.Fields {
f, ok := rawField.(*RelationField)
if ok && f.CollectionId == collection.Id {
result[c] = append(result[c], f)
}
}
}
return result, nil
}
// IsCollectionNameUnique checks that there is no existing collection
// with the provided name (case insensitive!).
//
// Note: case insensitive check because the name is used also as
// table name for the records.
func (app *BaseApp) IsCollectionNameUnique(name string, excludeIds ...string) bool {
if name == "" {
return false
}
query := app.CollectionQuery().
Select("count(*)").
AndWhere(dbx.NewExp("LOWER([[name]])={:name}", dbx.Params{"name": strings.ToLower(name)})).
Limit(1)
if uniqueExcludeIds := list.NonzeroUniques(excludeIds); len(uniqueExcludeIds) > 0 {
query.AndWhere(dbx.NotIn("id", list.ToInterfaceSlice(uniqueExcludeIds)...))
}
var total int
return query.Row(&total) == nil && total == 0
}
// TruncateCollection deletes all records associated with the provided collection.
//
// The truncate operation is executed in a single transaction,
// aka. either everything is deleted or none.
//
// Note that this method will also trigger the records related
// cascade and file delete actions.
func (app *BaseApp) TruncateCollection(collection *Collection) error {
if collection.IsView() {
return errors.New("view collections cannot be truncated since they don't store their own records")
}
return app.RunInTransaction(func(txApp App) error {
records := make([]*Record, 0, 500)
for {
err := txApp.RecordQuery(collection).Limit(500).All(&records)
if err != nil {
return err
}
if len(records) == 0 {
return nil
}
for _, record := range records {
err = txApp.Delete(record)
if err != nil && !errors.Is(err, sql.ErrNoRows) {
return err
}
}
records = records[:0]
}
})
}
// -------------------------------------------------------------------
// saveViewCollection persists the provided View collection changes:
// - deletes the old related SQL view (if any)
// - creates a new SQL view with the latest newCollection.Options.Query
// - generates new feilds list based on newCollection.Options.Query
// - updates newCollection.Fields based on the generated view table info and query
// - saves the newCollection
//
// This method returns an error if newCollection is not a "view".
func saveViewCollection(app App, newCollection, oldCollection *Collection) error {
if !newCollection.IsView() {
return errors.New("not a view collection")
}
return app.RunInTransaction(func(txApp App) error {
query := newCollection.ViewQuery
// generate collection fields from the query
viewFields, err := txApp.CreateViewFields(query)
if err != nil {
return err
}
// delete old renamed view
if oldCollection != nil {
if err := txApp.DeleteView(oldCollection.Name); err != nil {
return err
}
}
// wrap view query if necessary
query, err = normalizeViewQueryId(txApp, query)
if err != nil {
return fmt.Errorf("failed to normalize view query id: %w", err)
}
// (re)create the view
if err := txApp.SaveView(newCollection.Name, query); err != nil {
return err
}
newCollection.Fields = viewFields
return txApp.Save(newCollection)
})
}
// normalizeViewQueryId wraps (if necessary) the provided view query
// with a subselect to ensure that the id column is a text since
// currently we don't support non-string model ids
// (see https://github.com/pocketbase/pocketbase/issues/3110).
func normalizeViewQueryId(app App, query string) (string, error) {
query = strings.Trim(strings.TrimSpace(query), ";")
info, err := getQueryTableInfo(app, query)
if err != nil {
return "", err
}
for _, row := range info {
if strings.EqualFold(row.Name, FieldNameId) && strings.EqualFold(row.Type, "TEXT") {
return query, nil // no wrapping needed
}
}
// raw parse to preserve the columns order
rawParsed := new(identifiersParser)
if err := rawParsed.parse(query); err != nil {
return "", err
}
columns := make([]string, 0, len(rawParsed.columns))
for _, col := range rawParsed.columns {
if col.alias == FieldNameId {
columns = append(columns, fmt.Sprintf("CAST([[%s]] as TEXT) [[%s]]", col.alias, col.alias))
} else {
columns = append(columns, "[["+col.alias+"]]")
}
}
query = fmt.Sprintf("SELECT %s FROM (%s)", strings.Join(columns, ","), query)
return query, nil
}
// resaveViewsWithChangedFields updates all view collections with changed fields.
func resaveViewsWithChangedFields(app App, excludeIds ...string) error {
collections, err := app.FindAllCollections(CollectionTypeView)
if err != nil {
return err
}
return app.RunInTransaction(func(txApp App) error {
for _, collection := range collections {
if len(excludeIds) > 0 && list.ExistInSlice(collection.Id, excludeIds) {
continue
}
// clone the existing fields for temp modifications
oldFields, err := collection.Fields.Clone()
if err != nil {
return err
}
// generate new fields from the query
newFields, err := txApp.CreateViewFields(collection.ViewQuery)
if err != nil {
return err
}
// unset the fields' ids to exclude from the comparison
for _, f := range oldFields {
f.SetId("")
}
for _, f := range newFields {
f.SetId("")
}
encodedNewFields, err := json.Marshal(newFields)
if err != nil {
return err
}
encodedOldFields, err := json.Marshal(oldFields)
if err != nil {
return err
}
if bytes.EqualFold(encodedNewFields, encodedOldFields) {
continue // no changes
}
if err := saveViewCollection(txApp, collection, nil); err != nil {
return err
}
}
return nil
})
}

View file

@ -0,0 +1,474 @@
package core_test
import (
"context"
"database/sql"
"fmt"
"os"
"path/filepath"
"slices"
"strings"
"testing"
"time"
"github.com/pocketbase/dbx"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tests"
"github.com/pocketbase/pocketbase/tools/list"
)
func TestCollectionQuery(t *testing.T) {
t.Parallel()
app, _ := tests.NewTestApp()
defer app.Cleanup()
expected := "SELECT {{_collections}}.* FROM `_collections`"
sql := app.CollectionQuery().Build().SQL()
if sql != expected {
t.Errorf("Expected sql %s, got %s", expected, sql)
}
}
func TestReloadCachedCollections(t *testing.T) {
t.Parallel()
app, _ := tests.NewTestApp()
defer app.Cleanup()
err := app.ReloadCachedCollections()
if err != nil {
t.Fatal(err)
}
cached := app.Store().Get(core.StoreKeyCachedCollections)
cachedCollections, ok := cached.([]*core.Collection)
if !ok {
t.Fatalf("Expected []*core.Collection, got %T", cached)
}
collections, err := app.FindAllCollections()
if err != nil {
t.Fatalf("Failed to retrieve all collections: %v", err)
}
if len(cachedCollections) != len(collections) {
t.Fatalf("Expected %d collections, got %d", len(collections), len(cachedCollections))
}
for _, c := range collections {
var exists bool
for _, cc := range cachedCollections {
if cc.Id == c.Id {
exists = true
break
}
}
if !exists {
t.Fatalf("The collections cache is missing collection %q", c.Name)
}
}
}
func TestFindAllCollections(t *testing.T) {
t.Parallel()
app, _ := tests.NewTestApp()
defer app.Cleanup()
scenarios := []struct {
collectionTypes []string
expectTotal int
}{
{nil, 16},
{[]string{}, 16},
{[]string{""}, 16},
{[]string{"unknown"}, 0},
{[]string{"unknown", core.CollectionTypeAuth}, 4},
{[]string{core.CollectionTypeAuth, core.CollectionTypeView}, 7},
}
for i, s := range scenarios {
t.Run(fmt.Sprintf("%d_%s", i, strings.Join(s.collectionTypes, "_")), func(t *testing.T) {
collections, err := app.FindAllCollections(s.collectionTypes...)
if err != nil {
t.Fatal(err)
}
if len(collections) != s.expectTotal {
t.Fatalf("Expected %d collections, got %d", s.expectTotal, len(collections))
}
expectedTypes := list.NonzeroUniques(s.collectionTypes)
if len(expectedTypes) > 0 {
for _, c := range collections {
if !slices.Contains(expectedTypes, c.Type) {
t.Fatalf("Unexpected collection type %s\n%v", c.Type, c)
}
}
}
})
}
}
func TestFindCollectionByNameOrId(t *testing.T) {
t.Parallel()
app, _ := tests.NewTestApp()
defer app.Cleanup()
scenarios := []struct {
nameOrId string
expectError bool
}{
{"", true},
{"missing", true},
{"wsmn24bux7wo113", false},
{"demo1", false},
{"DEMO1", false}, // case insensitive
}
for i, s := range scenarios {
t.Run(fmt.Sprintf("%d_%s", i, s.nameOrId), func(t *testing.T) {
model, err := app.FindCollectionByNameOrId(s.nameOrId)
hasErr := err != nil
if hasErr != s.expectError {
t.Fatalf("Expected hasErr to be %v, got %v (%v)", s.expectError, hasErr, err)
}
if model != nil && model.Id != s.nameOrId && !strings.EqualFold(model.Name, s.nameOrId) {
t.Fatalf("Expected model with identifier %s, got %v", s.nameOrId, model)
}
})
}
}
func TestFindCachedCollectionByNameOrId(t *testing.T) {
t.Parallel()
app, _ := tests.NewTestApp()
defer app.Cleanup()
totalQueries := 0
app.ConcurrentDB().(*dbx.DB).QueryLogFunc = func(ctx context.Context, t time.Duration, sql string, rows *sql.Rows, err error) {
totalQueries++
}
run := func(withCache bool) {
scenarios := []struct {
nameOrId string
expectError bool
}{
{"", true},
{"missing", true},
{"wsmn24bux7wo113", false},
{"demo1", false},
{"DEMO1", false}, // case insensitive
}
var expectedTotalQueries int
if withCache {
err := app.ReloadCachedCollections()
if err != nil {
t.Fatal(err)
}
} else {
app.Store().Reset(nil)
expectedTotalQueries = len(scenarios)
}
totalQueries = 0
for i, s := range scenarios {
t.Run(fmt.Sprintf("%d_%s", i, s.nameOrId), func(t *testing.T) {
model, err := app.FindCachedCollectionByNameOrId(s.nameOrId)
hasErr := err != nil
if hasErr != s.expectError {
t.Fatalf("Expected hasErr to be %v, got %v (%v)", s.expectError, hasErr, err)
}
if model != nil && model.Id != s.nameOrId && !strings.EqualFold(model.Name, s.nameOrId) {
t.Fatalf("Expected model with identifier %s, got %v", s.nameOrId, model)
}
})
}
if totalQueries != expectedTotalQueries {
t.Fatalf("Expected %d totalQueries, got %d", expectedTotalQueries, totalQueries)
}
}
run(true)
run(false)
}
func TestFindCollectionReferences(t *testing.T) {
t.Parallel()
app, _ := tests.NewTestApp()
defer app.Cleanup()
collection, err := app.FindCollectionByNameOrId("demo3")
if err != nil {
t.Fatal(err)
}
result, err := app.FindCollectionReferences(
collection,
collection.Id,
// test whether "nonempty" exclude ids condition will be skipped
"",
"",
)
if err != nil {
t.Fatal(err)
}
if len(result) != 1 {
t.Fatalf("Expected 1 collection, got %d: %v", len(result), result)
}
expectedFields := []string{
"rel_one_no_cascade",
"rel_one_no_cascade_required",
"rel_one_cascade",
"rel_one_unique",
"rel_many_no_cascade",
"rel_many_no_cascade_required",
"rel_many_cascade",
"rel_many_unique",
}
for col, fields := range result {
if col.Name != "demo4" {
t.Fatalf("Expected collection demo4, got %s", col.Name)
}
if len(fields) != len(expectedFields) {
t.Fatalf("Expected fields %v, got %v", expectedFields, fields)
}
for i, f := range fields {
if !slices.Contains(expectedFields, f.GetName()) {
t.Fatalf("[%d] Didn't expect field %v", i, f)
}
}
}
}
func TestFindCachedCollectionReferences(t *testing.T) {
t.Parallel()
app, _ := tests.NewTestApp()
defer app.Cleanup()
collection, err := app.FindCollectionByNameOrId("demo3")
if err != nil {
t.Fatal(err)
}
totalQueries := 0
app.ConcurrentDB().(*dbx.DB).QueryLogFunc = func(ctx context.Context, t time.Duration, sql string, rows *sql.Rows, err error) {
totalQueries++
}
run := func(withCache bool) {
var expectedTotalQueries int
if withCache {
err := app.ReloadCachedCollections()
if err != nil {
t.Fatal(err)
}
} else {
app.Store().Reset(nil)
expectedTotalQueries = 1
}
totalQueries = 0
result, err := app.FindCachedCollectionReferences(
collection,
collection.Id,
// test whether "nonempty" exclude ids condition will be skipped
"",
"",
)
if err != nil {
t.Fatal(err)
}
if len(result) != 1 {
t.Fatalf("Expected 1 collection, got %d: %v", len(result), result)
}
expectedFields := []string{
"rel_one_no_cascade",
"rel_one_no_cascade_required",
"rel_one_cascade",
"rel_one_unique",
"rel_many_no_cascade",
"rel_many_no_cascade_required",
"rel_many_cascade",
"rel_many_unique",
}
for col, fields := range result {
if col.Name != "demo4" {
t.Fatalf("Expected collection demo4, got %s", col.Name)
}
if len(fields) != len(expectedFields) {
t.Fatalf("Expected fields %v, got %v", expectedFields, fields)
}
for i, f := range fields {
if !slices.Contains(expectedFields, f.GetName()) {
t.Fatalf("[%d] Didn't expect field %v", i, f)
}
}
}
if totalQueries != expectedTotalQueries {
t.Fatalf("Expected %d totalQueries, got %d", expectedTotalQueries, totalQueries)
}
}
run(true)
run(false)
}
func TestIsCollectionNameUnique(t *testing.T) {
t.Parallel()
app, _ := tests.NewTestApp()
defer app.Cleanup()
scenarios := []struct {
name string
excludeId string
expected bool
}{
{"", "", false},
{"demo1", "", false},
{"Demo1", "", false},
{"new", "", true},
{"demo1", "wsmn24bux7wo113", true},
}
for i, s := range scenarios {
t.Run(fmt.Sprintf("%d_%s", i, s.name), func(t *testing.T) {
result := app.IsCollectionNameUnique(s.name, s.excludeId)
if result != s.expected {
t.Errorf("Expected %v, got %v", s.expected, result)
}
})
}
}
func TestFindCollectionTruncate(t *testing.T) {
t.Parallel()
app, _ := tests.NewTestApp()
defer app.Cleanup()
countFiles := func(collectionId string) (int, error) {
entries, err := os.ReadDir(filepath.Join(app.DataDir(), "storage", collectionId))
return len(entries), err
}
t.Run("truncate view", func(t *testing.T) {
view2, err := app.FindCollectionByNameOrId("view2")
if err != nil {
t.Fatal(err)
}
err = app.TruncateCollection(view2)
if err == nil {
t.Fatalf("Expected truncate to fail because view collections can't be truncated")
}
})
t.Run("truncate failure", func(t *testing.T) {
demo3, err := app.FindCollectionByNameOrId("demo3")
if err != nil {
t.Fatal(err)
}
originalTotalRecords, err := app.CountRecords(demo3)
if err != nil {
t.Fatal(err)
}
originalTotalFiles, err := countFiles(demo3.Id)
if err != nil {
t.Fatal(err)
}
err = app.TruncateCollection(demo3)
if err == nil {
t.Fatalf("Expected truncate to fail due to cascade delete failed required constraint")
}
// short delay to ensure that the file delete goroutine has been executed
time.Sleep(100 * time.Millisecond)
totalRecords, err := app.CountRecords(demo3)
if err != nil {
t.Fatal(err)
}
if totalRecords != originalTotalRecords {
t.Fatalf("Expected %d records, got %d", originalTotalRecords, totalRecords)
}
totalFiles, err := countFiles(demo3.Id)
if err != nil {
t.Fatal(err)
}
if totalFiles != originalTotalFiles {
t.Fatalf("Expected %d files, got %d", originalTotalFiles, totalFiles)
}
})
t.Run("truncate success", func(t *testing.T) {
demo5, err := app.FindCollectionByNameOrId("demo5")
if err != nil {
t.Fatal(err)
}
err = app.TruncateCollection(demo5)
if err != nil {
t.Fatal(err)
}
// short delay to ensure that the file delete goroutine has been executed
time.Sleep(100 * time.Millisecond)
total, err := app.CountRecords(demo5)
if err != nil {
t.Fatal(err)
}
if total != 0 {
t.Fatalf("Expected all records to be deleted, got %v", total)
}
totalFiles, err := countFiles(demo5.Id)
if err != nil {
t.Fatal(err)
}
if totalFiles != 0 {
t.Fatalf("Expected truncated record files to be deleted, got %d", totalFiles)
}
// try to truncate again (shouldn't return an error)
err = app.TruncateCollection(demo5)
if err != nil {
t.Fatal(err)
}
})
}

View file

@ -0,0 +1,364 @@
package core
import (
"fmt"
"log/slog"
"strconv"
"strings"
validation "github.com/go-ozzo/ozzo-validation/v4"
"github.com/pocketbase/dbx"
"github.com/pocketbase/pocketbase/tools/dbutils"
"github.com/pocketbase/pocketbase/tools/security"
)
// SyncRecordTableSchema compares the two provided collections
// and applies the necessary related record table changes.
//
// If oldCollection is null, then only newCollection is used to create the record table.
//
// This method is automatically invoked as part of a collection create/update/delete operation.
func (app *BaseApp) SyncRecordTableSchema(newCollection *Collection, oldCollection *Collection) error {
if newCollection.IsView() {
return nil // nothing to sync since views don't have records table
}
txErr := app.RunInTransaction(func(txApp App) error {
// create
// -----------------------------------------------------------
if oldCollection == nil || !app.HasTable(oldCollection.Name) {
tableName := newCollection.Name
fields := newCollection.Fields
cols := make(map[string]string, len(fields))
// add fields definition
for _, field := range fields {
cols[field.GetName()] = field.ColumnType(app)
}
// create table
if _, err := txApp.DB().CreateTable(tableName, cols).Execute(); err != nil {
return err
}
return createCollectionIndexes(txApp, newCollection)
}
// update
// -----------------------------------------------------------
oldTableName := oldCollection.Name
newTableName := newCollection.Name
oldFields := oldCollection.Fields
newFields := newCollection.Fields
needTableRename := !strings.EqualFold(oldTableName, newTableName)
var needIndexesUpdate bool
if needTableRename ||
oldFields.String() != newFields.String() ||
oldCollection.Indexes.String() != newCollection.Indexes.String() {
needIndexesUpdate = true
}
if needIndexesUpdate {
// drop old indexes (if any)
if err := dropCollectionIndexes(txApp, oldCollection); err != nil {
return err
}
}
// check for renamed table
if needTableRename {
_, err := txApp.DB().RenameTable("{{"+oldTableName+"}}", "{{"+newTableName+"}}").Execute()
if err != nil {
return err
}
}
// check for deleted columns
for _, oldField := range oldFields {
if f := newFields.GetById(oldField.GetId()); f != nil {
continue // exist
}
_, err := txApp.DB().DropColumn(newTableName, oldField.GetName()).Execute()
if err != nil {
return fmt.Errorf("failed to drop column %s - %w", oldField.GetName(), err)
}
}
// check for new or renamed columns
toRename := map[string]string{}
for _, field := range newFields {
oldField := oldFields.GetById(field.GetId())
// Note:
// We are using a temporary column name when adding or renaming columns
// to ensure that there are no name collisions in case there is
// names switch/reuse of existing columns (eg. name, title -> title, name).
// This way we are always doing 1 more rename operation but it provides better less ambiguous experience.
if oldField == nil {
tempName := field.GetName() + security.PseudorandomString(5)
toRename[tempName] = field.GetName()
// add
_, err := txApp.DB().AddColumn(newTableName, tempName, field.ColumnType(txApp)).Execute()
if err != nil {
return fmt.Errorf("failed to add column %s - %w", field.GetName(), err)
}
} else if oldField.GetName() != field.GetName() {
tempName := field.GetName() + security.PseudorandomString(5)
toRename[tempName] = field.GetName()
// rename
_, err := txApp.DB().RenameColumn(newTableName, oldField.GetName(), tempName).Execute()
if err != nil {
return fmt.Errorf("failed to rename column %s - %w", oldField.GetName(), err)
}
}
}
// set the actual columns name
for tempName, actualName := range toRename {
_, err := txApp.DB().RenameColumn(newTableName, tempName, actualName).Execute()
if err != nil {
return err
}
}
if err := normalizeSingleVsMultipleFieldChanges(txApp, newCollection, oldCollection); err != nil {
return err
}
if needIndexesUpdate {
return createCollectionIndexes(txApp, newCollection)
}
return nil
})
if txErr != nil {
return txErr
}
// run optimize per the SQLite recommendations
// (https://www.sqlite.org/pragma.html#pragma_optimize)
_, optimizeErr := app.ConcurrentDB().NewQuery("PRAGMA optimize").Execute()
if optimizeErr != nil {
app.Logger().Warn("Failed to run PRAGMA optimize after record table sync", slog.String("error", optimizeErr.Error()))
}
return nil
}
func normalizeSingleVsMultipleFieldChanges(app App, newCollection *Collection, oldCollection *Collection) error {
if newCollection.IsView() || oldCollection == nil {
return nil // view or not an update
}
return app.RunInTransaction(func(txApp App) error {
for _, newField := range newCollection.Fields {
// allow to continue even if there is no old field for the cases
// when a new field is added and there are already inserted data
var isOldMultiple bool
if oldField := oldCollection.Fields.GetById(newField.GetId()); oldField != nil {
if mv, ok := oldField.(MultiValuer); ok {
isOldMultiple = mv.IsMultiple()
}
}
var isNewMultiple bool
if mv, ok := newField.(MultiValuer); ok {
isNewMultiple = mv.IsMultiple()
}
if isOldMultiple == isNewMultiple {
continue // no change
}
// -------------------------------------------------------
// update the field column definition
// -------------------------------------------------------
// temporary drop all views to prevent reference errors during the columns renaming
// (this is used as an "alternative" to the writable_schema PRAGMA)
views := []struct {
Name string `db:"name"`
SQL string `db:"sql"`
}{}
err := txApp.DB().Select("name", "sql").
From("sqlite_master").
AndWhere(dbx.NewExp("sql is not null")).
AndWhere(dbx.HashExp{"type": "view"}).
All(&views)
if err != nil {
return err
}
for _, view := range views {
err = txApp.DeleteView(view.Name)
if err != nil {
return err
}
}
originalName := newField.GetName()
oldTempName := "_" + newField.GetName() + security.PseudorandomString(5)
// rename temporary the original column to something else to allow inserting a new one in its place
_, err = txApp.DB().RenameColumn(newCollection.Name, originalName, oldTempName).Execute()
if err != nil {
return err
}
// reinsert the field column with the new type
_, err = txApp.DB().AddColumn(newCollection.Name, originalName, newField.ColumnType(txApp)).Execute()
if err != nil {
return err
}
var copyQuery *dbx.Query
if !isOldMultiple && isNewMultiple {
// single -> multiple (convert to array)
copyQuery = txApp.DB().NewQuery(fmt.Sprintf(
`UPDATE {{%s}} set [[%s]] = (
CASE
WHEN COALESCE([[%s]], '') = ''
THEN '[]'
ELSE (
CASE
WHEN json_valid([[%s]]) AND json_type([[%s]]) == 'array'
THEN [[%s]]
ELSE json_array([[%s]])
END
)
END
)`,
newCollection.Name,
originalName,
oldTempName,
oldTempName,
oldTempName,
oldTempName,
oldTempName,
))
} else {
// multiple -> single (keep only the last element)
//
// note: for file fields the actual file objects are not
// deleted allowing additional custom handling via migration
copyQuery = txApp.DB().NewQuery(fmt.Sprintf(
`UPDATE {{%s}} set [[%s]] = (
CASE
WHEN COALESCE([[%s]], '[]') = '[]'
THEN ''
ELSE (
CASE
WHEN json_valid([[%s]]) AND json_type([[%s]]) == 'array'
THEN COALESCE(json_extract([[%s]], '$[#-1]'), '')
ELSE [[%s]]
END
)
END
)`,
newCollection.Name,
originalName,
oldTempName,
oldTempName,
oldTempName,
oldTempName,
oldTempName,
))
}
// copy the normalized values
_, err = copyQuery.Execute()
if err != nil {
return err
}
// drop the original column
_, err = txApp.DB().DropColumn(newCollection.Name, oldTempName).Execute()
if err != nil {
return err
}
// restore views
for _, view := range views {
_, err = txApp.DB().NewQuery(view.SQL).Execute()
if err != nil {
return err
}
}
}
return nil
})
}
func dropCollectionIndexes(app App, collection *Collection) error {
if collection.IsView() {
return nil // views don't have indexes
}
return app.RunInTransaction(func(txApp App) error {
for _, raw := range collection.Indexes {
parsed := dbutils.ParseIndex(raw)
if !parsed.IsValid() {
continue
}
_, err := txApp.DB().NewQuery(fmt.Sprintf("DROP INDEX IF EXISTS [[%s]]", parsed.IndexName)).Execute()
if err != nil {
return err
}
}
return nil
})
}
func createCollectionIndexes(app App, collection *Collection) error {
if collection.IsView() {
return nil // views don't have indexes
}
return app.RunInTransaction(func(txApp App) error {
// upsert new indexes
//
// note: we are returning validation errors because the indexes cannot be
// easily validated in a form, aka. before persisting the related
// collection record table changes
errs := validation.Errors{}
for i, idx := range collection.Indexes {
parsed := dbutils.ParseIndex(idx)
// ensure that the index is always for the current collection
parsed.TableName = collection.Name
if !parsed.IsValid() {
errs[strconv.Itoa(i)] = validation.NewError(
"validation_invalid_index_expression",
"Invalid CREATE INDEX expression.",
)
continue
}
if _, err := txApp.DB().NewQuery(parsed.Build()).Execute(); err != nil {
errs[strconv.Itoa(i)] = validation.NewError(
"validation_invalid_index_expression",
fmt.Sprintf("Failed to create index %s - %v.", parsed.IndexName, err.Error()),
)
continue
}
}
if len(errs) > 0 {
return validation.Errors{"indexes": errs}
}
return nil
})
}

View file

@ -0,0 +1,296 @@
package core_test
import (
"bytes"
"encoding/json"
"testing"
"github.com/pocketbase/dbx"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tests"
"github.com/pocketbase/pocketbase/tools/list"
"github.com/pocketbase/pocketbase/tools/types"
)
func TestSyncRecordTableSchema(t *testing.T) {
t.Parallel()
app, _ := tests.NewTestApp()
defer app.Cleanup()
oldCollection, err := app.FindCollectionByNameOrId("demo2")
if err != nil {
t.Fatal(err)
}
updatedCollection, err := app.FindCollectionByNameOrId("demo2")
if err != nil {
t.Fatal(err)
}
updatedCollection.Name = "demo_renamed"
updatedCollection.Fields.RemoveByName("active")
updatedCollection.Fields.Add(&core.EmailField{
Name: "new_field",
})
updatedCollection.Fields.Add(&core.EmailField{
Id: updatedCollection.Fields.GetByName("title").GetId(),
Name: "title_renamed",
})
updatedCollection.Indexes = types.JSONArray[string]{"create index idx_title_renamed on anything (title_renamed)"}
baseCol := core.NewBaseCollection("new_base")
baseCol.Fields.Add(&core.TextField{Name: "test"})
authCol := core.NewAuthCollection("new_auth")
authCol.Fields.Add(&core.TextField{Name: "test"})
authCol.AddIndex("idx_auth_test", false, "email, id", "")
scenarios := []struct {
name string
newCollection *core.Collection
oldCollection *core.Collection
expectedColumns []string
expectedIndexesCount int
}{
{
"new base collection",
baseCol,
nil,
[]string{"id", "test"},
0,
},
{
"new auth collection",
authCol,
nil,
[]string{
"id", "test", "email", "verified",
"emailVisibility", "tokenKey", "password",
},
3,
},
{
"no changes",
oldCollection,
oldCollection,
[]string{"id", "created", "updated", "title", "active"},
3,
},
{
"renamed table, deleted column, renamed columnd and new column",
updatedCollection,
oldCollection,
[]string{"id", "created", "updated", "title_renamed", "new_field"},
1,
},
}
for _, s := range scenarios {
t.Run(s.name, func(t *testing.T) {
err := app.SyncRecordTableSchema(s.newCollection, s.oldCollection)
if err != nil {
t.Fatal(err)
}
if !app.HasTable(s.newCollection.Name) {
t.Fatalf("Expected table %s to exist", s.newCollection.Name)
}
cols, _ := app.TableColumns(s.newCollection.Name)
if len(cols) != len(s.expectedColumns) {
t.Fatalf("Expected columns %v, got %v", s.expectedColumns, cols)
}
for _, col := range cols {
if !list.ExistInSlice(col, s.expectedColumns) {
t.Fatalf("Couldn't find column %s in %v", col, s.expectedColumns)
}
}
indexes, _ := app.TableIndexes(s.newCollection.Name)
if totalIndexes := len(indexes); totalIndexes != s.expectedIndexesCount {
t.Fatalf("Expected %d indexes, got %d:\n%v", s.expectedIndexesCount, totalIndexes, indexes)
}
})
}
}
func getTotalViews(app core.App) (int, error) {
var total int
err := app.DB().Select("count(*)").
From("sqlite_master").
AndWhere(dbx.NewExp("sql is not null")).
AndWhere(dbx.HashExp{"type": "view"}).
Row(&total)
return total, err
}
func TestSingleVsMultipleValuesNormalization(t *testing.T) {
t.Parallel()
app, _ := tests.NewTestApp()
defer app.Cleanup()
collection, err := app.FindCollectionByNameOrId("demo1")
if err != nil {
t.Fatal(err)
}
beforeTotalViews, err := getTotalViews(app)
if err != nil {
t.Fatal(err)
}
// mock field changes
collection.Fields.GetByName("select_one").(*core.SelectField).MaxSelect = 2
collection.Fields.GetByName("select_many").(*core.SelectField).MaxSelect = 1
collection.Fields.GetByName("file_one").(*core.FileField).MaxSelect = 2
collection.Fields.GetByName("file_many").(*core.FileField).MaxSelect = 1
collection.Fields.GetByName("rel_one").(*core.RelationField).MaxSelect = 2
collection.Fields.GetByName("rel_many").(*core.RelationField).MaxSelect = 1
// new multivaluer field to check whether the array normalization
// will be applied for already inserted data
collection.Fields.Add(&core.SelectField{
Name: "new_multiple",
Values: []string{"a", "b", "c"},
MaxSelect: 3,
})
if err := app.Save(collection); err != nil {
t.Fatal(err)
}
// ensure that the views were reinserted
afterTotalViews, err := getTotalViews(app)
if err != nil {
t.Fatal(err)
}
if afterTotalViews != beforeTotalViews {
t.Fatalf("Expected total views %d, got %d", beforeTotalViews, afterTotalViews)
}
// check whether the columns DEFAULT definition was updated
// ---------------------------------------------------------------
tableInfo, err := app.TableInfo(collection.Name)
if err != nil {
t.Fatal(err)
}
tableInfoExpectations := map[string]string{
"select_one": `'[]'`,
"select_many": `''`,
"file_one": `'[]'`,
"file_many": `''`,
"rel_one": `'[]'`,
"rel_many": `''`,
"new_multiple": `'[]'`,
}
for col, dflt := range tableInfoExpectations {
t.Run("check default for "+col, func(t *testing.T) {
var row *core.TableInfoRow
for _, r := range tableInfo {
if r.Name == col {
row = r
break
}
}
if row == nil {
t.Fatalf("Missing info for column %q", col)
}
if v := row.DefaultValue.String; v != dflt {
t.Fatalf("Expected default value %q, got %q", dflt, v)
}
})
}
// check whether the values were normalized
// ---------------------------------------------------------------
type fieldsExpectation struct {
SelectOne string `db:"select_one"`
SelectMany string `db:"select_many"`
FileOne string `db:"file_one"`
FileMany string `db:"file_many"`
RelOne string `db:"rel_one"`
RelMany string `db:"rel_many"`
NewMultiple string `db:"new_multiple"`
}
fieldsScenarios := []struct {
recordId string
expected fieldsExpectation
}{
{
"imy661ixudk5izi",
fieldsExpectation{
SelectOne: `[]`,
SelectMany: ``,
FileOne: `[]`,
FileMany: ``,
RelOne: `[]`,
RelMany: ``,
NewMultiple: `[]`,
},
},
{
"al1h9ijdeojtsjy",
fieldsExpectation{
SelectOne: `["optionB"]`,
SelectMany: `optionB`,
FileOne: `["300_Jsjq7RdBgA.png"]`,
FileMany: ``,
RelOne: `["84nmscqy84lsi1t"]`,
RelMany: `oap640cot4yru2s`,
NewMultiple: `[]`,
},
},
{
"84nmscqy84lsi1t",
fieldsExpectation{
SelectOne: `["optionB"]`,
SelectMany: `optionC`,
FileOne: `["test_d61b33QdDU.txt"]`,
FileMany: `test_tC1Yc87DfC.txt`,
RelOne: `[]`,
RelMany: `oap640cot4yru2s`,
NewMultiple: `[]`,
},
},
}
for _, s := range fieldsScenarios {
t.Run("check fields for record "+s.recordId, func(t *testing.T) {
result := new(fieldsExpectation)
err := app.DB().Select(
"select_one",
"select_many",
"file_one",
"file_many",
"rel_one",
"rel_many",
"new_multiple",
).From(collection.Name).Where(dbx.HashExp{"id": s.recordId}).One(result)
if err != nil {
t.Fatalf("Failed to load record: %v", err)
}
encodedResult, err := json.Marshal(result)
if err != nil {
t.Fatalf("Failed to encode result: %v", err)
}
encodedExpectation, err := json.Marshal(s.expected)
if err != nil {
t.Fatalf("Failed to encode expectation: %v", err)
}
if !bytes.EqualFold(encodedExpectation, encodedResult) {
t.Fatalf("Expected \n%s, \ngot \n%s", encodedExpectation, encodedResult)
}
})
}
}

694
core/collection_validate.go Normal file
View file

@ -0,0 +1,694 @@
package core
import (
"context"
"fmt"
"regexp"
"strconv"
"strings"
validation "github.com/go-ozzo/ozzo-validation/v4"
"github.com/pocketbase/dbx"
"github.com/pocketbase/pocketbase/core/validators"
"github.com/pocketbase/pocketbase/tools/dbutils"
"github.com/pocketbase/pocketbase/tools/list"
"github.com/pocketbase/pocketbase/tools/search"
"github.com/pocketbase/pocketbase/tools/types"
)
var collectionNameRegex = regexp.MustCompile(`^\w+$`)
func onCollectionValidate(e *CollectionEvent) error {
var original *Collection
if !e.Collection.IsNew() {
original = &Collection{}
if err := e.App.ModelQuery(original).Model(e.Collection.LastSavedPK(), original); err != nil {
return fmt.Errorf("failed to fetch old collection state: %w", err)
}
}
validator := newCollectionValidator(
e.Context,
e.App,
e.Collection,
original,
)
if err := validator.run(); err != nil {
return err
}
return e.Next()
}
func newCollectionValidator(ctx context.Context, app App, new, original *Collection) *collectionValidator {
validator := &collectionValidator{
ctx: ctx,
app: app,
new: new,
original: original,
}
// load old/original collection
if validator.original == nil {
validator.original = NewCollection(validator.new.Type, "")
}
return validator
}
type collectionValidator struct {
original *Collection
new *Collection
app App
ctx context.Context
}
type optionsValidator interface {
validate(cv *collectionValidator) error
}
func (validator *collectionValidator) run() error {
if validator.original.IsNew() {
validator.new.updateGeneratedIdIfExists(validator.app)
}
// generate fields from the query (overwriting any explicit user defined fields)
if validator.new.IsView() {
validator.new.Fields, _ = validator.app.CreateViewFields(validator.new.ViewQuery)
}
// validate base fields
baseErr := validation.ValidateStruct(validator.new,
validation.Field(
&validator.new.Id,
validation.Required,
validation.When(
validator.original.IsNew(),
validation.Length(1, 100),
validation.Match(DefaultIdRegex),
validation.By(validators.UniqueId(validator.app.ConcurrentDB(), validator.new.TableName())),
).Else(
validation.By(validators.Equal(validator.original.Id)),
),
),
validation.Field(
&validator.new.System,
validation.By(validator.ensureNoSystemFlagChange),
),
validation.Field(
&validator.new.Type,
validation.Required,
validation.In(
CollectionTypeBase,
CollectionTypeAuth,
CollectionTypeView,
),
validation.By(validator.ensureNoTypeChange),
),
validation.Field(
&validator.new.Name,
validation.Required,
validation.Length(1, 255),
validation.By(checkForVia),
validation.Match(collectionNameRegex),
validation.By(validator.ensureNoSystemNameChange),
validation.By(validator.checkUniqueName),
),
validation.Field(
&validator.new.Fields,
validation.By(validator.checkFieldDuplicates),
validation.By(validator.checkMinFields),
validation.When(
!validator.new.IsView(),
validation.By(validator.ensureNoSystemFieldsChange),
validation.By(validator.ensureNoFieldsTypeChange),
),
validation.When(validator.new.IsAuth(), validation.By(validator.checkReservedAuthKeys)),
validation.By(validator.checkFieldValidators),
),
validation.Field(
&validator.new.ListRule,
validation.By(validator.checkRule),
validation.By(validator.ensureNoSystemRuleChange(validator.original.ListRule)),
),
validation.Field(
&validator.new.ViewRule,
validation.By(validator.checkRule),
validation.By(validator.ensureNoSystemRuleChange(validator.original.ViewRule)),
),
validation.Field(
&validator.new.CreateRule,
validation.When(validator.new.IsView(), validation.Nil),
validation.By(validator.checkRule),
validation.By(validator.ensureNoSystemRuleChange(validator.original.CreateRule)),
),
validation.Field(
&validator.new.UpdateRule,
validation.When(validator.new.IsView(), validation.Nil),
validation.By(validator.checkRule),
validation.By(validator.ensureNoSystemRuleChange(validator.original.UpdateRule)),
),
validation.Field(
&validator.new.DeleteRule,
validation.When(validator.new.IsView(), validation.Nil),
validation.By(validator.checkRule),
validation.By(validator.ensureNoSystemRuleChange(validator.original.DeleteRule)),
),
validation.Field(&validator.new.Indexes, validation.By(validator.checkIndexes)),
)
optionsErr := validator.validateOptions()
return validators.JoinValidationErrors(baseErr, optionsErr)
}
func (validator *collectionValidator) checkUniqueName(value any) error {
v, _ := value.(string)
// ensure unique collection name
if !validator.app.IsCollectionNameUnique(v, validator.original.Id) {
return validation.NewError("validation_collection_name_exists", "Collection name must be unique (case insensitive).")
}
// ensure that the collection name doesn't collide with the id of any collection
dummyCollection := &Collection{}
if validator.app.ModelQuery(dummyCollection).Model(v, dummyCollection) == nil {
return validation.NewError("validation_collection_name_id_duplicate", "The name must not match an existing collection id.")
}
// ensure that there is no existing internal table with the provided name
if validator.original.Name != v && // has changed
validator.app.IsCollectionNameUnique(v) && // is not a collection (in case it was presaved)
validator.app.HasTable(v) {
return validation.NewError("validation_collection_name_invalid", "The name shouldn't match with an existing internal table.")
}
return nil
}
func (validator *collectionValidator) ensureNoSystemNameChange(value any) error {
v, _ := value.(string)
if !validator.original.IsNew() && validator.original.System && v != validator.original.Name {
return validation.NewError("validation_collection_system_name_change", "System collection name cannot be changed.")
}
return nil
}
func (validator *collectionValidator) ensureNoSystemFlagChange(value any) error {
v, _ := value.(bool)
if !validator.original.IsNew() && v != validator.original.System {
return validation.NewError("validation_collection_system_flag_change", "System collection state cannot be changed.")
}
return nil
}
func (validator *collectionValidator) ensureNoTypeChange(value any) error {
v, _ := value.(string)
if !validator.original.IsNew() && v != validator.original.Type {
return validation.NewError("validation_collection_type_change", "Collection type cannot be changed.")
}
return nil
}
func (validator *collectionValidator) ensureNoFieldsTypeChange(value any) error {
v, ok := value.(FieldsList)
if !ok {
return validators.ErrUnsupportedValueType
}
errs := validation.Errors{}
for i, field := range v {
oldField := validator.original.Fields.GetById(field.GetId())
if oldField != nil && oldField.Type() != field.Type() {
errs[strconv.Itoa(i)] = validation.NewError(
"validation_field_type_change",
"Field type cannot be changed.",
)
}
}
if len(errs) > 0 {
return errs
}
return nil
}
func (validator *collectionValidator) checkFieldDuplicates(value any) error {
fields, ok := value.(FieldsList)
if !ok {
return validators.ErrUnsupportedValueType
}
totalFields := len(fields)
ids := make([]string, 0, totalFields)
names := make([]string, 0, totalFields)
for i, field := range fields {
if list.ExistInSlice(field.GetId(), ids) {
return validation.Errors{
strconv.Itoa(i): validation.Errors{
"id": validation.NewError(
"validation_duplicated_field_id",
fmt.Sprintf("Duplicated or invalid field id %q", field.GetId()),
),
},
}
}
// field names are used as db columns and should be case insensitive
nameLower := strings.ToLower(field.GetName())
if list.ExistInSlice(nameLower, names) {
return validation.Errors{
strconv.Itoa(i): validation.Errors{
"name": validation.NewError(
"validation_duplicated_field_name",
"Duplicated or invalid field name {{.fieldName}}",
).SetParams(map[string]any{
"fieldName": field.GetName(),
}),
},
}
}
ids = append(ids, field.GetId())
names = append(names, nameLower)
}
return nil
}
func (validator *collectionValidator) checkFieldValidators(value any) error {
fields, ok := value.(FieldsList)
if !ok {
return validators.ErrUnsupportedValueType
}
errs := validation.Errors{}
for i, field := range fields {
if err := field.ValidateSettings(validator.ctx, validator.app, validator.new); err != nil {
errs[strconv.Itoa(i)] = err
}
}
if len(errs) > 0 {
return errs
}
return nil
}
func (cv *collectionValidator) checkViewQuery(value any) error {
v, _ := value.(string)
if v == "" {
return nil // nothing to check
}
if _, err := cv.app.CreateViewFields(v); err != nil {
return validation.NewError(
"validation_invalid_view_query",
fmt.Sprintf("Invalid query - %s", err.Error()),
)
}
return nil
}
var reservedAuthKeys = []string{"passwordConfirm", "oldPassword"}
func (cv *collectionValidator) checkReservedAuthKeys(value any) error {
fields, ok := value.(FieldsList)
if !ok {
return validators.ErrUnsupportedValueType
}
if !cv.new.IsAuth() {
return nil // not an auth collection
}
errs := validation.Errors{}
for i, field := range fields {
if list.ExistInSlice(field.GetName(), reservedAuthKeys) {
errs[strconv.Itoa(i)] = validation.Errors{
"name": validation.NewError(
"validation_reserved_field_name",
"The field name is reserved and cannot be used.",
),
}
}
}
if len(errs) > 0 {
return errs
}
return nil
}
func (cv *collectionValidator) checkMinFields(value any) error {
fields, ok := value.(FieldsList)
if !ok {
return validators.ErrUnsupportedValueType
}
if len(fields) == 0 {
return validation.ErrRequired
}
// all collections must have an "id" PK field
idField, _ := fields.GetByName(FieldNameId).(*TextField)
if idField == nil || !idField.PrimaryKey {
return validation.NewError("validation_missing_primary_key", `Missing or invalid "id" PK field.`)
}
switch cv.new.Type {
case CollectionTypeAuth:
passwordField, _ := fields.GetByName(FieldNamePassword).(*PasswordField)
if passwordField == nil {
return validation.NewError("validation_missing_password_field", `System "password" field is required.`)
}
if !passwordField.Hidden || !passwordField.System {
return validation.Errors{FieldNamePassword: ErrMustBeSystemAndHidden}
}
tokenKeyField, _ := fields.GetByName(FieldNameTokenKey).(*TextField)
if tokenKeyField == nil {
return validation.NewError("validation_missing_tokenKey_field", `System "tokenKey" field is required.`)
}
if !tokenKeyField.Hidden || !tokenKeyField.System {
return validation.Errors{FieldNameTokenKey: ErrMustBeSystemAndHidden}
}
emailField, _ := fields.GetByName(FieldNameEmail).(*EmailField)
if emailField == nil {
return validation.NewError("validation_missing_email_field", `System "email" field is required.`)
}
if !emailField.System {
return validation.Errors{FieldNameEmail: ErrMustBeSystem}
}
emailVisibilityField, _ := fields.GetByName(FieldNameEmailVisibility).(*BoolField)
if emailVisibilityField == nil {
return validation.NewError("validation_missing_emailVisibility_field", `System "emailVisibility" field is required.`)
}
if !emailVisibilityField.System {
return validation.Errors{FieldNameEmailVisibility: ErrMustBeSystem}
}
verifiedField, _ := fields.GetByName(FieldNameVerified).(*BoolField)
if verifiedField == nil {
return validation.NewError("validation_missing_verified_field", `System "verified" field is required.`)
}
if !verifiedField.System {
return validation.Errors{FieldNameVerified: ErrMustBeSystem}
}
return nil
}
return nil
}
func (validator *collectionValidator) ensureNoSystemFieldsChange(value any) error {
fields, ok := value.(FieldsList)
if !ok {
return validators.ErrUnsupportedValueType
}
if validator.original.IsNew() {
return nil // not an update
}
for _, oldField := range validator.original.Fields {
if !oldField.GetSystem() {
continue
}
newField := fields.GetById(oldField.GetId())
if newField == nil || oldField.GetName() != newField.GetName() {
return validation.NewError("validation_system_field_change", "System fields cannot be deleted or renamed.")
}
}
return nil
}
func (cv *collectionValidator) checkFieldsForUniqueIndex(value any) error {
names, ok := value.([]string)
if !ok {
return validators.ErrUnsupportedValueType
}
if len(names) == 0 {
return nil // nothing to check
}
for _, name := range names {
field := cv.new.Fields.GetByName(name)
if field == nil {
return validation.NewError("validation_missing_field", "Invalid or missing field {{.fieldName}}").
SetParams(map[string]any{"fieldName": name})
}
if _, ok := dbutils.FindSingleColumnUniqueIndex(cv.new.Indexes, name); !ok {
return validation.NewError("validation_missing_unique_constraint", "The field {{.fieldName}} doesn't have a UNIQUE constraint.").
SetParams(map[string]any{"fieldName": name})
}
}
return nil
}
// note: value could be either *string or string
func (validator *collectionValidator) checkRule(value any) error {
var vStr string
v, ok := value.(*string)
if ok {
if v != nil {
vStr = *v
}
} else {
vStr, ok = value.(string)
}
if !ok {
return validators.ErrUnsupportedValueType
}
if vStr == "" {
return nil // nothing to check
}
r := NewRecordFieldResolver(validator.app, validator.new, nil, true)
_, err := search.FilterData(vStr).BuildExpr(r)
if err != nil {
return validation.NewError("validation_invalid_rule", "Invalid rule. Raw error: "+err.Error())
}
return nil
}
func (validator *collectionValidator) ensureNoSystemRuleChange(oldRule *string) validation.RuleFunc {
return func(value any) error {
if validator.original.IsNew() || !validator.original.System {
return nil // not an update of a system collection
}
rule, ok := value.(*string)
if !ok {
return validators.ErrUnsupportedValueType
}
if (rule == nil && oldRule == nil) ||
(rule != nil && oldRule != nil && *rule == *oldRule) {
return nil
}
return validation.NewError("validation_collection_system_rule_change", "System collection API rule cannot be changed.")
}
}
func (cv *collectionValidator) checkIndexes(value any) error {
indexes, _ := value.(types.JSONArray[string])
if cv.new.IsView() && len(indexes) > 0 {
return validation.NewError(
"validation_indexes_not_supported",
"View collections don't support indexes.",
)
}
duplicatedNames := make(map[string]struct{}, len(indexes))
duplicatedDefinitions := make(map[string]struct{}, len(indexes))
for i, rawIndex := range indexes {
parsed := dbutils.ParseIndex(rawIndex)
// always set a table name because it is ignored anyway in order to keep it in sync with the collection name
parsed.TableName = "validator"
if !parsed.IsValid() {
return validation.Errors{
strconv.Itoa(i): validation.NewError(
"validation_invalid_index_expression",
"Invalid CREATE INDEX expression.",
),
}
}
if _, isDuplicated := duplicatedNames[strings.ToLower(parsed.IndexName)]; isDuplicated {
return validation.Errors{
strconv.Itoa(i): validation.NewError(
"validation_duplicated_index_name",
"The index name already exists.",
),
}
}
duplicatedNames[strings.ToLower(parsed.IndexName)] = struct{}{}
// ensure that the index name is not used in another collection
var usedTblName string
_ = cv.app.ConcurrentDB().Select("tbl_name").
From("sqlite_master").
AndWhere(dbx.HashExp{"type": "index"}).
AndWhere(dbx.NewExp("LOWER([[tbl_name]])!=LOWER({:oldName})", dbx.Params{"oldName": cv.original.Name})).
AndWhere(dbx.NewExp("LOWER([[tbl_name]])!=LOWER({:newName})", dbx.Params{"newName": cv.new.Name})).
AndWhere(dbx.NewExp("LOWER([[name]])=LOWER({:indexName})", dbx.Params{"indexName": parsed.IndexName})).
Limit(1).
Row(&usedTblName)
if usedTblName != "" {
return validation.Errors{
strconv.Itoa(i): validation.NewError(
"validation_existing_index_name",
"The index name is already used in {{.usedTableName}} collection.",
).SetParams(map[string]any{"usedTableName": usedTblName}),
}
}
// reset non-important identifiers
parsed.SchemaName = "validator"
parsed.IndexName = "validator"
parsedDef := parsed.Build()
if _, isDuplicated := duplicatedDefinitions[parsedDef]; isDuplicated {
return validation.Errors{
strconv.Itoa(i): validation.NewError(
"validation_duplicated_index_definition",
"The index definition already exists.",
),
}
}
duplicatedDefinitions[parsedDef] = struct{}{}
// note: we don't check the index table name because it is always
// overwritten by the SyncRecordTableSchema to allow
// easier partial modifications (eg. changing only the collection name).
// if !strings.EqualFold(parsed.TableName, form.Name) {
// return validation.Errors{
// strconv.Itoa(i): validation.NewError(
// "validation_invalid_index_table",
// fmt.Sprintf("The index table must be the same as the collection name."),
// ),
// }
// }
}
// ensure that unique indexes on system fields are not changed or removed
if !cv.original.IsNew() {
OLD_INDEXES_LOOP:
for _, oldIndex := range cv.original.Indexes {
oldParsed := dbutils.ParseIndex(oldIndex)
if !oldParsed.Unique {
continue
}
// reset collate and sort since they are not important for the unique constraint
for i := range oldParsed.Columns {
oldParsed.Columns[i].Collate = ""
oldParsed.Columns[i].Sort = ""
}
oldParsedStr := oldParsed.Build()
for _, column := range oldParsed.Columns {
for _, f := range cv.original.Fields {
if !f.GetSystem() || !strings.EqualFold(column.Name, f.GetName()) {
continue
}
var hasMatch bool
for _, newIndex := range cv.new.Indexes {
newParsed := dbutils.ParseIndex(newIndex)
// exclude the non-important identifiers from the check
newParsed.SchemaName = oldParsed.SchemaName
newParsed.IndexName = oldParsed.IndexName
newParsed.TableName = oldParsed.TableName
// exclude partial constraints
newParsed.Where = oldParsed.Where
// reset collate and sort
for i := range newParsed.Columns {
newParsed.Columns[i].Collate = ""
newParsed.Columns[i].Sort = ""
}
if oldParsedStr == newParsed.Build() {
hasMatch = true
break
}
}
if !hasMatch {
return validation.NewError(
"validation_invalid_unique_system_field_index",
"Unique index definition on system fields ({{.fieldName}}) is invalid or missing.",
).SetParams(map[string]any{"fieldName": f.GetName()})
}
continue OLD_INDEXES_LOOP
}
}
}
}
// check for required indexes
//
// note: this is in case the indexes were removed manually when creating/importing new auth collections
// and technically it is not necessary because on app.Save() the missing indexes will be reinserted by the system collection hook
if cv.new.IsAuth() {
requiredNames := []string{FieldNameTokenKey, FieldNameEmail}
for _, name := range requiredNames {
if _, ok := dbutils.FindSingleColumnUniqueIndex(indexes, name); !ok {
return validation.NewError(
"validation_missing_required_unique_index",
`Missing required unique index for field "{{.fieldName}}".`,
).SetParams(map[string]any{"fieldName": name})
}
}
}
return nil
}
func (validator *collectionValidator) validateOptions() error {
switch validator.new.Type {
case CollectionTypeAuth:
return validator.new.collectionAuthOptions.validate(validator)
case CollectionTypeView:
return validator.new.collectionViewOptions.validate(validator)
}
return nil
}

View file

@ -0,0 +1,909 @@
package core_test
import (
"testing"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tests"
"github.com/pocketbase/pocketbase/tools/types"
)
func TestCollectionValidate(t *testing.T) {
t.Parallel()
scenarios := []struct {
name string
collection func(app core.App) (*core.Collection, error)
expectedErrors []string
}{
{
name: "empty collection",
collection: func(app core.App) (*core.Collection, error) {
return &core.Collection{}, nil
},
expectedErrors: []string{
"id", "name", "type", "fields", // no default fields because the type is unknown
},
},
{
name: "unknown type with all invalid fields",
collection: func(app core.App) (*core.Collection, error) {
c := &core.Collection{}
c.Id = "invalid_id ?!@#$"
c.Name = "invalid_name ?!@#$"
c.Type = "invalid_type"
c.ListRule = types.Pointer("missing = '123'")
c.ViewRule = types.Pointer("missing = '123'")
c.CreateRule = types.Pointer("missing = '123'")
c.UpdateRule = types.Pointer("missing = '123'")
c.DeleteRule = types.Pointer("missing = '123'")
c.Indexes = []string{"create index '' on '' ()"}
// type specific fields
c.ViewQuery = "invalid" // should be ignored
c.AuthRule = types.Pointer("missing = '123'") // should be ignored
return c, nil
},
expectedErrors: []string{
"id", "name", "type", "indexes",
"listRule", "viewRule", "createRule", "updateRule", "deleteRule",
"fields", // no default fields because the type is unknown
},
},
{
name: "base with invalid fields",
collection: func(app core.App) (*core.Collection, error) {
c := core.NewBaseCollection("invalid_name ?!@#$")
c.Indexes = []string{"create index '' on '' ()"}
// type specific fields
c.ViewQuery = "invalid" // should be ignored
c.AuthRule = types.Pointer("missing = '123'") // should be ignored
return c, nil
},
expectedErrors: []string{"name", "indexes"},
},
{
name: "view with invalid fields",
collection: func(app core.App) (*core.Collection, error) {
c := core.NewViewCollection("invalid_name ?!@#$")
c.Indexes = []string{"create index '' on '' ()"}
// type specific fields
c.ViewQuery = "invalid"
c.AuthRule = types.Pointer("missing = '123'") // should be ignored
return c, nil
},
expectedErrors: []string{"indexes", "name", "fields", "viewQuery"},
},
{
name: "auth with invalid fields",
collection: func(app core.App) (*core.Collection, error) {
c := core.NewAuthCollection("invalid_name ?!@#$")
c.Indexes = []string{"create index '' on '' ()"}
// type specific fields
c.ViewQuery = "invalid" // should be ignored
c.AuthRule = types.Pointer("missing = '123'")
return c, nil
},
expectedErrors: []string{"indexes", "name", "authRule"},
},
// type checks
{
name: "empty type",
collection: func(app core.App) (*core.Collection, error) {
c := core.NewBaseCollection("test")
c.Type = ""
return c, nil
},
expectedErrors: []string{"type"},
},
{
name: "unknown type",
collection: func(app core.App) (*core.Collection, error) {
c := core.NewBaseCollection("test")
c.Type = "unknown"
return c, nil
},
expectedErrors: []string{"type"},
},
{
name: "base type",
collection: func(app core.App) (*core.Collection, error) {
c := core.NewBaseCollection("test")
return c, nil
},
expectedErrors: []string{},
},
{
name: "view type",
collection: func(app core.App) (*core.Collection, error) {
c := core.NewViewCollection("test")
c.ViewQuery = "select 1 as id"
return c, nil
},
expectedErrors: []string{},
},
{
name: "auth type",
collection: func(app core.App) (*core.Collection, error) {
c := core.NewAuthCollection("test")
return c, nil
},
expectedErrors: []string{},
},
{
name: "changing type",
collection: func(app core.App) (*core.Collection, error) {
c, _ := app.FindCollectionByNameOrId("users")
c.Type = core.CollectionTypeBase
return c, nil
},
expectedErrors: []string{"type"},
},
// system checks
{
name: "change from system to regular",
collection: func(app core.App) (*core.Collection, error) {
c, _ := app.FindCollectionByNameOrId(core.CollectionNameSuperusers)
c.System = false
return c, nil
},
expectedErrors: []string{"system"},
},
{
name: "change from regular to system",
collection: func(app core.App) (*core.Collection, error) {
c, _ := app.FindCollectionByNameOrId("demo1")
c.System = true
return c, nil
},
expectedErrors: []string{"system"},
},
{
name: "create system",
collection: func(app core.App) (*core.Collection, error) {
c := core.NewBaseCollection("new_system")
c.System = true
return c, nil
},
expectedErrors: []string{},
},
// id checks
{
name: "empty id",
collection: func(app core.App) (*core.Collection, error) {
c := core.NewBaseCollection("test")
c.Id = ""
return c, nil
},
expectedErrors: []string{"id"},
},
{
name: "invalid id",
collection: func(app core.App) (*core.Collection, error) {
c := core.NewBaseCollection("test")
c.Id = "!invalid"
return c, nil
},
expectedErrors: []string{"id"},
},
{
name: "existing id",
collection: func(app core.App) (*core.Collection, error) {
c := core.NewBaseCollection("test")
c.Id = "_pb_users_auth_"
return c, nil
},
expectedErrors: []string{"id"},
},
{
name: "changing id",
collection: func(app core.App) (*core.Collection, error) {
c, _ := app.FindCollectionByNameOrId("demo3")
c.Id = "anything"
return c, nil
},
expectedErrors: []string{"id"},
},
{
name: "valid id",
collection: func(app core.App) (*core.Collection, error) {
c := core.NewBaseCollection("test")
c.Id = "anything"
return c, nil
},
expectedErrors: []string{},
},
// name checks
{
name: "empty name",
collection: func(app core.App) (*core.Collection, error) {
c := core.NewBaseCollection("")
c.Id = "test"
return c, nil
},
expectedErrors: []string{"name"},
},
{
name: "invalid name",
collection: func(app core.App) (*core.Collection, error) {
c := core.NewBaseCollection("!invalid")
return c, nil
},
expectedErrors: []string{"name"},
},
{
name: "name with _via_",
collection: func(app core.App) (*core.Collection, error) {
c := core.NewBaseCollection("a_via_b")
return c, nil
},
expectedErrors: []string{"name"},
},
{
name: "create with existing collection name",
collection: func(app core.App) (*core.Collection, error) {
c := core.NewBaseCollection("demo1")
return c, nil
},
expectedErrors: []string{"name"},
},
{
name: "create with existing internal table name",
collection: func(app core.App) (*core.Collection, error) {
c := core.NewBaseCollection("_collections")
return c, nil
},
expectedErrors: []string{"name"},
},
{
name: "update with existing collection name",
collection: func(app core.App) (*core.Collection, error) {
c, _ := app.FindCollectionByNameOrId("users")
c.Name = "demo1"
return c, nil
},
expectedErrors: []string{"name"},
},
{
name: "update with existing internal table name",
collection: func(app core.App) (*core.Collection, error) {
c, _ := app.FindCollectionByNameOrId("users")
c.Name = "_collections"
return c, nil
},
expectedErrors: []string{"name"},
},
{
name: "system collection name change",
collection: func(app core.App) (*core.Collection, error) {
c, _ := app.FindCollectionByNameOrId(core.CollectionNameSuperusers)
c.Name = "superusers_new"
return c, nil
},
expectedErrors: []string{"name"},
},
{
name: "create with valid name",
collection: func(app core.App) (*core.Collection, error) {
c := core.NewBaseCollection("new_col")
return c, nil
},
expectedErrors: []string{},
},
{
name: "update with valid name",
collection: func(app core.App) (*core.Collection, error) {
c, _ := app.FindCollectionByNameOrId("demo1")
c.Name = "demo1_new"
return c, nil
},
expectedErrors: []string{},
},
// rule checks
{
name: "invalid base rules",
collection: func(app core.App) (*core.Collection, error) {
c := core.NewBaseCollection("new")
c.ListRule = types.Pointer("!invalid")
c.ViewRule = types.Pointer("missing = 123")
c.CreateRule = types.Pointer("id = 123 && missing = 456")
c.UpdateRule = types.Pointer("(id = 123")
c.DeleteRule = types.Pointer("missing = 123")
return c, nil
},
expectedErrors: []string{"listRule", "viewRule", "createRule", "updateRule", "deleteRule"},
},
{
name: "valid base rules",
collection: func(app core.App) (*core.Collection, error) {
c := core.NewBaseCollection("new")
c.Fields.Add(&core.TextField{Name: "f1"}) // dummy field to ensure that new fields can be referenced
c.ListRule = types.Pointer("")
c.ViewRule = types.Pointer("f1 = 123")
c.CreateRule = types.Pointer("id = 123 && f1 = 456")
c.UpdateRule = types.Pointer("(id = 123)")
c.DeleteRule = types.Pointer("f1 = 123")
return c, nil
},
expectedErrors: []string{},
},
{
name: "view with non-nil create/update/delete rules",
collection: func(app core.App) (*core.Collection, error) {
c := core.NewViewCollection("new")
c.ViewQuery = "select 1 as id, 'text' as f1"
c.ListRule = types.Pointer("id = 123")
c.ViewRule = types.Pointer("f1 = 456")
c.CreateRule = types.Pointer("")
c.UpdateRule = types.Pointer("")
c.DeleteRule = types.Pointer("")
return c, nil
},
expectedErrors: []string{"createRule", "updateRule", "deleteRule"},
},
{
name: "view with nil create/update/delete rules",
collection: func(app core.App) (*core.Collection, error) {
c := core.NewViewCollection("new")
c.ViewQuery = "select 1 as id, 'text' as f1"
c.ListRule = types.Pointer("id = 1")
c.ViewRule = types.Pointer("f1 = 456")
return c, nil
},
expectedErrors: []string{},
},
{
name: "changing api rules",
collection: func(app core.App) (*core.Collection, error) {
c, _ := app.FindCollectionByNameOrId("users")
c.Fields.Add(&core.TextField{Name: "f1"}) // dummy field to ensure that new fields can be referenced
c.ListRule = types.Pointer("id = 1")
c.ViewRule = types.Pointer("f1 = 456")
c.CreateRule = types.Pointer("id = 123 && f1 = 456")
c.UpdateRule = types.Pointer("(id = 123)")
c.DeleteRule = types.Pointer("f1 = 123")
return c, nil
},
expectedErrors: []string{},
},
{
name: "changing system collection api rules",
collection: func(app core.App) (*core.Collection, error) {
c, _ := app.FindCollectionByNameOrId(core.CollectionNameSuperusers)
c.ListRule = types.Pointer("1 = 1")
c.ViewRule = types.Pointer("1 = 1")
c.CreateRule = types.Pointer("1 = 1")
c.UpdateRule = types.Pointer("1 = 1")
c.DeleteRule = types.Pointer("1 = 1")
c.ManageRule = types.Pointer("1 = 1")
c.AuthRule = types.Pointer("1 = 1")
return c, nil
},
expectedErrors: []string{
"listRule", "viewRule", "createRule", "updateRule",
"deleteRule", "manageRule", "authRule",
},
},
// indexes checks
{
name: "invalid index expression",
collection: func(app core.App) (*core.Collection, error) {
c, _ := app.FindCollectionByNameOrId("demo1")
c.Indexes = []string{
"create index invalid",
"create index idx_test_demo2 on anything (text)", // the name of table shouldn't matter
}
return c, nil
},
expectedErrors: []string{"indexes"},
},
{
name: "index name used in other table",
collection: func(app core.App) (*core.Collection, error) {
c, _ := app.FindCollectionByNameOrId("demo1")
c.Indexes = []string{
"create index `idx_test_demo1` on demo1 (id)",
"create index `__pb_USERS_auth__username_idx` on anything (text)", // should be case-insensitive
}
return c, nil
},
expectedErrors: []string{"indexes"},
},
{
name: "duplicated index names",
collection: func(app core.App) (*core.Collection, error) {
c, _ := app.FindCollectionByNameOrId("demo1")
c.Indexes = []string{
"create index idx_test_demo1 on demo1 (id)",
"create index idx_test_demo1 on anything (text)",
}
return c, nil
},
expectedErrors: []string{"indexes"},
},
{
name: "duplicated index definitions",
collection: func(app core.App) (*core.Collection, error) {
c, _ := app.FindCollectionByNameOrId("demo1")
c.Indexes = []string{
"create index idx_test_demo1 on demo1 (id)",
"create index idx_test_demo2 on demo1 (id)",
}
return c, nil
},
expectedErrors: []string{"indexes"},
},
{
name: "try to add index to a view collection",
collection: func(app core.App) (*core.Collection, error) {
c, _ := app.FindCollectionByNameOrId("view1")
c.Indexes = []string{"create index idx_test_view1 on view1 (id)"}
return c, nil
},
expectedErrors: []string{"indexes"},
},
{
name: "replace old with new indexes",
collection: func(app core.App) (*core.Collection, error) {
c, _ := app.FindCollectionByNameOrId("demo1")
c.Indexes = []string{
"create index idx_test_demo1 on demo1 (id)",
"create index idx_test_demo2 on anything (text)", // the name of table shouldn't matter
}
return c, nil
},
expectedErrors: []string{},
},
{
name: "old + new indexes",
collection: func(app core.App) (*core.Collection, error) {
c, _ := app.FindCollectionByNameOrId("demo1")
c.Indexes = []string{
"CREATE INDEX `_wsmn24bux7wo113_created_idx` ON `demo1` (`created`)",
"create index idx_test_demo1 on anything (id)",
}
return c, nil
},
expectedErrors: []string{},
},
{
name: "index for missing field",
collection: func(app core.App) (*core.Collection, error) {
c, _ := app.FindCollectionByNameOrId("demo1")
c.Indexes = []string{
"create index idx_test_demo1 on anything (missing)", // still valid because it is checked on db persist
}
return c, nil
},
expectedErrors: []string{},
},
{
name: "auth collection with missing required unique indexes",
collection: func(app core.App) (*core.Collection, error) {
c := core.NewAuthCollection("new_auth")
c.Indexes = []string{}
return c, nil
},
expectedErrors: []string{"indexes", "passwordAuth"},
},
{
name: "auth collection with non-unique required indexes",
collection: func(app core.App) (*core.Collection, error) {
c := core.NewAuthCollection("new_auth")
c.Indexes = []string{
"create index test_idx1 on new_auth (tokenKey)",
"create index test_idx2 on new_auth (email)",
}
return c, nil
},
expectedErrors: []string{"indexes", "passwordAuth"},
},
{
name: "auth collection with unique required indexes",
collection: func(app core.App) (*core.Collection, error) {
c := core.NewAuthCollection("new_auth")
c.Indexes = []string{
"create unique index test_idx1 on new_auth (tokenKey)",
"create unique index test_idx2 on new_auth (email)",
}
return c, nil
},
expectedErrors: []string{},
},
{
name: "removing index on system field",
collection: func(app core.App) (*core.Collection, error) {
demo2, err := app.FindCollectionByNameOrId("demo2")
if err != nil {
return nil, err
}
// mark the title field as system
demo2.Fields.GetByName("title").SetSystem(true)
if err = app.Save(demo2); err != nil {
return nil, err
}
// refresh
demo2, err = app.FindCollectionByNameOrId("demo2")
if err != nil {
return nil, err
}
demo2.RemoveIndex("idx_unique_demo2_title")
return demo2, nil
},
expectedErrors: []string{"indexes"},
},
{
name: "changing partial constraint of existing index on system field",
collection: func(app core.App) (*core.Collection, error) {
demo2, err := app.FindCollectionByNameOrId("demo2")
if err != nil {
return nil, err
}
// mark the title field as system
demo2.Fields.GetByName("title").SetSystem(true)
if err = app.Save(demo2); err != nil {
return nil, err
}
// refresh
demo2, err = app.FindCollectionByNameOrId("demo2")
if err != nil {
return nil, err
}
// replace the index with a partial one
demo2.RemoveIndex("idx_unique_demo2_title")
demo2.AddIndex("idx_new_demo2_title", true, "title", "1 = 1")
return demo2, nil
},
expectedErrors: []string{},
},
{
name: "changing column sort and collate of existing index on system field",
collection: func(app core.App) (*core.Collection, error) {
demo2, err := app.FindCollectionByNameOrId("demo2")
if err != nil {
return nil, err
}
// mark the title field as system
demo2.Fields.GetByName("title").SetSystem(true)
if err = app.Save(demo2); err != nil {
return nil, err
}
// refresh
demo2, err = app.FindCollectionByNameOrId("demo2")
if err != nil {
return nil, err
}
// replace the index with a new one for the same column but with collate and sort
demo2.RemoveIndex("idx_unique_demo2_title")
demo2.AddIndex("idx_new_demo2_title", true, "title COLLATE test ASC", "")
return demo2, nil
},
expectedErrors: []string{},
},
{
name: "adding new column to index on system field",
collection: func(app core.App) (*core.Collection, error) {
demo2, err := app.FindCollectionByNameOrId("demo2")
if err != nil {
return nil, err
}
// mark the title field as system
demo2.Fields.GetByName("title").SetSystem(true)
if err = app.Save(demo2); err != nil {
return nil, err
}
// refresh
demo2, err = app.FindCollectionByNameOrId("demo2")
if err != nil {
return nil, err
}
// replace the index with a non-unique one
demo2.RemoveIndex("idx_unique_demo2_title")
demo2.AddIndex("idx_new_title", false, "title, id", "")
return demo2, nil
},
expectedErrors: []string{"indexes"},
},
{
name: "changing index type on system field",
collection: func(app core.App) (*core.Collection, error) {
demo2, err := app.FindCollectionByNameOrId("demo2")
if err != nil {
return nil, err
}
// mark the title field as system
demo2.Fields.GetByName("title").SetSystem(true)
if err = app.Save(demo2); err != nil {
return nil, err
}
// refresh
demo2, err = app.FindCollectionByNameOrId("demo2")
if err != nil {
return nil, err
}
// replace the index with a non-unique one (partial constraints are ignored)
demo2.RemoveIndex("idx_unique_demo2_title")
demo2.AddIndex("idx_new_title", false, "title", "1=1")
return demo2, nil
},
expectedErrors: []string{"indexes"},
},
{
name: "changing index on non-system field",
collection: func(app core.App) (*core.Collection, error) {
demo2, err := app.FindCollectionByNameOrId("demo2")
if err != nil {
return nil, err
}
// replace the index with a partial one
demo2.RemoveIndex("idx_demo2_active")
demo2.AddIndex("idx_demo2_active", true, "active", "1 = 1")
return demo2, nil
},
expectedErrors: []string{},
},
// fields list checks
{
name: "empty fields",
collection: func(app core.App) (*core.Collection, error) {
c := core.NewBaseCollection("new_auth")
c.Fields = nil // the minimum fields should auto added
return c, nil
},
expectedErrors: []string{"fields"},
},
{
name: "no id primay key field",
collection: func(app core.App) (*core.Collection, error) {
c := core.NewBaseCollection("new_auth")
c.Fields = core.NewFieldsList(
&core.TextField{Name: "id"},
)
return c, nil
},
expectedErrors: []string{"fields"},
},
{
name: "with id primay key field",
collection: func(app core.App) (*core.Collection, error) {
c := core.NewBaseCollection("new_auth")
c.Fields = core.NewFieldsList(
&core.TextField{Name: "id", PrimaryKey: true, Required: true, Pattern: `\w+`},
)
return c, nil
},
expectedErrors: []string{},
},
{
name: "duplicated field names",
collection: func(app core.App) (*core.Collection, error) {
c := core.NewBaseCollection("new_auth")
c.Fields = core.NewFieldsList(
&core.TextField{Name: "id", PrimaryKey: true, Required: true, Pattern: `\w+`},
&core.TextField{Id: "f1", Name: "Test"}, // case-insensitive
&core.BoolField{Id: "f2", Name: "test"},
)
return c, nil
},
expectedErrors: []string{"fields"},
},
{
name: "changing field type",
collection: func(app core.App) (*core.Collection, error) {
c, _ := app.FindCollectionByNameOrId("demo1")
f := c.Fields.GetByName("text")
c.Fields.Add(&core.BoolField{Id: f.GetId(), Name: f.GetName()})
return c, nil
},
expectedErrors: []string{"fields"},
},
{
name: "renaming system field",
collection: func(app core.App) (*core.Collection, error) {
c, _ := app.FindCollectionByNameOrId(core.CollectionNameAuthOrigins)
f := c.Fields.GetByName("fingerprint")
f.SetName("fingerprint_new")
return c, nil
},
expectedErrors: []string{"fields"},
},
{
name: "deleting system field",
collection: func(app core.App) (*core.Collection, error) {
c, _ := app.FindCollectionByNameOrId(core.CollectionNameAuthOrigins)
c.Fields.RemoveByName("fingerprint")
return c, nil
},
expectedErrors: []string{"fields"},
},
{
name: "invalid field setting",
collection: func(app core.App) (*core.Collection, error) {
c := core.NewBaseCollection("test_new")
c.Fields.Add(&core.TextField{Name: "f1", Min: -10})
return c, nil
},
expectedErrors: []string{"fields"},
},
{
name: "valid field setting",
collection: func(app core.App) (*core.Collection, error) {
c := core.NewBaseCollection("test_new")
c.Fields.Add(&core.TextField{Name: "f1", Min: 10})
return c, nil
},
expectedErrors: []string{},
},
{
name: "fields view changes should be ignored",
collection: func(app core.App) (*core.Collection, error) {
c, _ := app.FindCollectionByNameOrId("view1")
c.Fields = nil
return c, nil
},
expectedErrors: []string{},
},
{
name: "with reserved auth only field name (passwordConfirm)",
collection: func(app core.App) (*core.Collection, error) {
c := core.NewAuthCollection("new_auth")
c.Fields.Add(
&core.TextField{Name: "passwordConfirm"},
)
return c, nil
},
expectedErrors: []string{"fields"},
},
{
name: "with reserved auth only field name (oldPassword)",
collection: func(app core.App) (*core.Collection, error) {
c := core.NewAuthCollection("new_auth")
c.Fields.Add(
&core.TextField{Name: "oldPassword"},
)
return c, nil
},
expectedErrors: []string{"fields"},
},
{
name: "with invalid password auth field options (1)",
collection: func(app core.App) (*core.Collection, error) {
c := core.NewAuthCollection("new_auth")
c.Fields.Add(
&core.TextField{Name: "password", System: true, Hidden: true}, // should be PasswordField
)
return c, nil
},
expectedErrors: []string{"fields"},
},
{
name: "with valid password auth field options (2)",
collection: func(app core.App) (*core.Collection, error) {
c := core.NewAuthCollection("new_auth")
c.Fields.Add(
&core.PasswordField{Name: "password", System: true, Hidden: true},
)
return c, nil
},
expectedErrors: []string{},
},
{
name: "with invalid tokenKey auth field options (1)",
collection: func(app core.App) (*core.Collection, error) {
c := core.NewAuthCollection("new_auth")
c.Fields.Add(
&core.TextField{Name: "tokenKey", System: true}, // should be also hidden
)
return c, nil
},
expectedErrors: []string{"fields"},
},
{
name: "with valid tokenKey auth field options (2)",
collection: func(app core.App) (*core.Collection, error) {
c := core.NewAuthCollection("new_auth")
c.Fields.Add(
&core.TextField{Name: "tokenKey", System: true, Hidden: true},
)
return c, nil
},
expectedErrors: []string{},
},
{
name: "with invalid email auth field options (1)",
collection: func(app core.App) (*core.Collection, error) {
c := core.NewAuthCollection("new_auth")
c.Fields.Add(
&core.TextField{Name: "email", System: true}, // should be EmailField
)
return c, nil
},
expectedErrors: []string{"fields"},
},
{
name: "with valid email auth field options (2)",
collection: func(app core.App) (*core.Collection, error) {
c := core.NewAuthCollection("new_auth")
c.Fields.Add(
&core.EmailField{Name: "email", System: true},
)
return c, nil
},
expectedErrors: []string{},
},
{
name: "with invalid verified auth field options (1)",
collection: func(app core.App) (*core.Collection, error) {
c := core.NewAuthCollection("new_auth")
c.Fields.Add(
&core.TextField{Name: "verified", System: true}, // should be BoolField
)
return c, nil
},
expectedErrors: []string{"fields"},
},
{
name: "with valid verified auth field options (2)",
collection: func(app core.App) (*core.Collection, error) {
c := core.NewAuthCollection("new_auth")
c.Fields.Add(
&core.BoolField{Name: "verified", System: true},
)
return c, nil
},
expectedErrors: []string{},
},
}
for _, s := range scenarios {
t.Run(s.name, func(t *testing.T) {
app, _ := tests.NewTestApp()
defer app.Cleanup()
collection, err := s.collection(app)
if err != nil {
t.Fatalf("Failed to retrieve test collection: %v", err)
}
result := app.Validate(collection)
tests.TestValidationErrors(t, result, s.expectedErrors)
})
}
}

499
core/db.go Normal file
View file

@ -0,0 +1,499 @@
package core
import (
"context"
"errors"
"fmt"
"hash/crc32"
"regexp"
"slices"
"strconv"
"strings"
validation "github.com/go-ozzo/ozzo-validation/v4"
"github.com/pocketbase/dbx"
"github.com/pocketbase/pocketbase/tools/security"
"github.com/spf13/cast"
)
const (
idColumn string = "id"
// DefaultIdLength is the default length of the generated model id.
DefaultIdLength int = 15
// DefaultIdAlphabet is the default characters set used for generating the model id.
DefaultIdAlphabet string = "abcdefghijklmnopqrstuvwxyz0123456789"
)
// DefaultIdRegex specifies the default regex pattern for an id value.
var DefaultIdRegex = regexp.MustCompile(`^\w+$`)
// DBExporter defines an interface for custom DB data export.
// Usually used as part of [App.Save].
type DBExporter interface {
// DBExport returns a key-value map with the data to be used when saving the struct in the database.
DBExport(app App) (map[string]any, error)
}
// PreValidator defines an optional model interface for registering a
// function that will run BEFORE firing the validation hooks (see [App.ValidateWithContext]).
type PreValidator interface {
// PreValidate defines a function that runs BEFORE the validation hooks.
PreValidate(ctx context.Context, app App) error
}
// PostValidator defines an optional model interface for registering a
// function that will run AFTER executing the validation hooks (see [App.ValidateWithContext]).
type PostValidator interface {
// PostValidate defines a function that runs AFTER the successful
// execution of the validation hooks.
PostValidate(ctx context.Context, app App) error
}
// GenerateDefaultRandomId generates a default random id string
// (note: the generated random string is not intended for security purposes).
func GenerateDefaultRandomId() string {
return security.PseudorandomStringWithAlphabet(DefaultIdLength, DefaultIdAlphabet)
}
// crc32Checksum generates a stringified crc32 checksum from the provided plain string.
func crc32Checksum(str string) string {
return strconv.FormatInt(int64(crc32.ChecksumIEEE([]byte(str))), 10)
}
// ModelQuery creates a new preconfigured select data.db query with preset
// SELECT, FROM and other common fields based on the provided model.
func (app *BaseApp) ModelQuery(m Model) *dbx.SelectQuery {
return app.modelQuery(app.ConcurrentDB(), m)
}
// AuxModelQuery creates a new preconfigured select auxiliary.db query with preset
// SELECT, FROM and other common fields based on the provided model.
func (app *BaseApp) AuxModelQuery(m Model) *dbx.SelectQuery {
return app.modelQuery(app.AuxConcurrentDB(), m)
}
func (app *BaseApp) modelQuery(db dbx.Builder, m Model) *dbx.SelectQuery {
tableName := m.TableName()
return db.
Select("{{" + tableName + "}}.*").
From(tableName).
WithBuildHook(func(query *dbx.Query) {
query.WithExecHook(execLockRetry(app.config.QueryTimeout, defaultMaxLockRetries))
})
}
// Delete deletes the specified model from the regular app database.
func (app *BaseApp) Delete(model Model) error {
return app.DeleteWithContext(context.Background(), model)
}
// Delete deletes the specified model from the regular app database
// (the context could be used to limit the query execution).
func (app *BaseApp) DeleteWithContext(ctx context.Context, model Model) error {
return app.delete(ctx, model, false)
}
// AuxDelete deletes the specified model from the auxiliary database.
func (app *BaseApp) AuxDelete(model Model) error {
return app.AuxDeleteWithContext(context.Background(), model)
}
// AuxDeleteWithContext deletes the specified model from the auxiliary database
// (the context could be used to limit the query execution).
func (app *BaseApp) AuxDeleteWithContext(ctx context.Context, model Model) error {
return app.delete(ctx, model, true)
}
func (app *BaseApp) delete(ctx context.Context, model Model, isForAuxDB bool) error {
event := new(ModelEvent)
event.App = app
event.Type = ModelEventTypeDelete
event.Context = ctx
event.Model = model
deleteErr := app.OnModelDelete().Trigger(event, func(e *ModelEvent) error {
pk := cast.ToString(e.Model.LastSavedPK())
if pk == "" {
return errors.New("the model can be deleted only if it is existing and has a non-empty primary key")
}
// db write
return e.App.OnModelDeleteExecute().Trigger(event, func(e *ModelEvent) error {
var db dbx.Builder
if isForAuxDB {
db = e.App.AuxNonconcurrentDB()
} else {
db = e.App.NonconcurrentDB()
}
return baseLockRetry(func(attempt int) error {
_, err := db.Delete(e.Model.TableName(), dbx.HashExp{
idColumn: pk,
}).WithContext(e.Context).Execute()
return err
}, defaultMaxLockRetries)
})
})
if deleteErr != nil {
errEvent := &ModelErrorEvent{ModelEvent: *event, Error: deleteErr}
errEvent.App = app // replace with the initial app in case it was changed by the hook
hookErr := app.OnModelAfterDeleteError().Trigger(errEvent)
if hookErr != nil {
return errors.Join(deleteErr, hookErr)
}
return deleteErr
}
if app.txInfo != nil {
// execute later after the transaction has completed
app.txInfo.OnComplete(func(txErr error) error {
if app.txInfo != nil && app.txInfo.parent != nil {
event.App = app.txInfo.parent
}
if txErr != nil {
return app.OnModelAfterDeleteError().Trigger(&ModelErrorEvent{
ModelEvent: *event,
Error: txErr,
})
}
return app.OnModelAfterDeleteSuccess().Trigger(event)
})
} else if err := event.App.OnModelAfterDeleteSuccess().Trigger(event); err != nil {
return err
}
return nil
}
// Save validates and saves the specified model into the regular app database.
//
// If you don't want to run validations, use [App.SaveNoValidate()].
func (app *BaseApp) Save(model Model) error {
return app.SaveWithContext(context.Background(), model)
}
// SaveWithContext is the same as [App.Save()] but allows specifying a context to limit the db execution.
//
// If you don't want to run validations, use [App.SaveNoValidateWithContext()].
func (app *BaseApp) SaveWithContext(ctx context.Context, model Model) error {
return app.save(ctx, model, true, false)
}
// SaveNoValidate saves the specified model into the regular app database without performing validations.
//
// If you want to also run validations before persisting, use [App.Save()].
func (app *BaseApp) SaveNoValidate(model Model) error {
return app.SaveNoValidateWithContext(context.Background(), model)
}
// SaveNoValidateWithContext is the same as [App.SaveNoValidate()]
// but allows specifying a context to limit the db execution.
//
// If you want to also run validations before persisting, use [App.SaveWithContext()].
func (app *BaseApp) SaveNoValidateWithContext(ctx context.Context, model Model) error {
return app.save(ctx, model, false, false)
}
// AuxSave validates and saves the specified model into the auxiliary app database.
//
// If you don't want to run validations, use [App.AuxSaveNoValidate()].
func (app *BaseApp) AuxSave(model Model) error {
return app.AuxSaveWithContext(context.Background(), model)
}
// AuxSaveWithContext is the same as [App.AuxSave()] but allows specifying a context to limit the db execution.
//
// If you don't want to run validations, use [App.AuxSaveNoValidateWithContext()].
func (app *BaseApp) AuxSaveWithContext(ctx context.Context, model Model) error {
return app.save(ctx, model, true, true)
}
// AuxSaveNoValidate saves the specified model into the auxiliary app database without performing validations.
//
// If you want to also run validations before persisting, use [App.AuxSave()].
func (app *BaseApp) AuxSaveNoValidate(model Model) error {
return app.AuxSaveNoValidateWithContext(context.Background(), model)
}
// AuxSaveNoValidateWithContext is the same as [App.AuxSaveNoValidate()]
// but allows specifying a context to limit the db execution.
//
// If you want to also run validations before persisting, use [App.AuxSaveWithContext()].
func (app *BaseApp) AuxSaveNoValidateWithContext(ctx context.Context, model Model) error {
return app.save(ctx, model, false, true)
}
// Validate triggers the OnModelValidate hook for the specified model.
func (app *BaseApp) Validate(model Model) error {
return app.ValidateWithContext(context.Background(), model)
}
// ValidateWithContext is the same as Validate but allows specifying the ModelEvent context.
func (app *BaseApp) ValidateWithContext(ctx context.Context, model Model) error {
if m, ok := model.(PreValidator); ok {
if err := m.PreValidate(ctx, app); err != nil {
return err
}
}
event := new(ModelEvent)
event.App = app
event.Context = ctx
event.Type = ModelEventTypeValidate
event.Model = model
return event.App.OnModelValidate().Trigger(event, func(e *ModelEvent) error {
if m, ok := e.Model.(PostValidator); ok {
if err := m.PostValidate(ctx, e.App); err != nil {
return err
}
}
return e.Next()
})
}
// -------------------------------------------------------------------
func (app *BaseApp) save(ctx context.Context, model Model, withValidations bool, isForAuxDB bool) error {
if model.IsNew() {
return app.create(ctx, model, withValidations, isForAuxDB)
}
return app.update(ctx, model, withValidations, isForAuxDB)
}
func (app *BaseApp) create(ctx context.Context, model Model, withValidations bool, isForAuxDB bool) error {
event := new(ModelEvent)
event.App = app
event.Context = ctx
event.Type = ModelEventTypeCreate
event.Model = model
saveErr := app.OnModelCreate().Trigger(event, func(e *ModelEvent) error {
// run validations (if any)
if withValidations {
validateErr := e.App.ValidateWithContext(e.Context, e.Model)
if validateErr != nil {
return validateErr
}
}
// db write
return e.App.OnModelCreateExecute().Trigger(event, func(e *ModelEvent) error {
var db dbx.Builder
if isForAuxDB {
db = e.App.AuxNonconcurrentDB()
} else {
db = e.App.NonconcurrentDB()
}
dbErr := baseLockRetry(func(attempt int) error {
if m, ok := e.Model.(DBExporter); ok {
data, err := m.DBExport(e.App)
if err != nil {
return err
}
// manually add the id to the data if missing
if _, ok := data[idColumn]; !ok {
data[idColumn] = e.Model.PK()
}
if cast.ToString(data[idColumn]) == "" {
return errors.New("empty primary key is not allowed when using the DBExporter interface")
}
_, err = db.Insert(e.Model.TableName(), data).WithContext(e.Context).Execute()
return err
}
return db.Model(e.Model).WithContext(e.Context).Insert()
}, defaultMaxLockRetries)
if dbErr != nil {
return dbErr
}
e.Model.MarkAsNotNew()
return nil
})
})
if saveErr != nil {
event.Model.MarkAsNew() // reset "new" state
errEvent := &ModelErrorEvent{ModelEvent: *event, Error: saveErr}
errEvent.App = app // replace with the initial app in case it was changed by the hook
hookErr := app.OnModelAfterCreateError().Trigger(errEvent)
if hookErr != nil {
return errors.Join(saveErr, hookErr)
}
return saveErr
}
if app.txInfo != nil {
// execute later after the transaction has completed
app.txInfo.OnComplete(func(txErr error) error {
if app.txInfo != nil && app.txInfo.parent != nil {
event.App = app.txInfo.parent
}
if txErr != nil {
event.Model.MarkAsNew() // reset "new" state
return app.OnModelAfterCreateError().Trigger(&ModelErrorEvent{
ModelEvent: *event,
Error: txErr,
})
}
return app.OnModelAfterCreateSuccess().Trigger(event)
})
} else if err := event.App.OnModelAfterCreateSuccess().Trigger(event); err != nil {
return err
}
return nil
}
func (app *BaseApp) update(ctx context.Context, model Model, withValidations bool, isForAuxDB bool) error {
event := new(ModelEvent)
event.App = app
event.Context = ctx
event.Type = ModelEventTypeUpdate
event.Model = model
saveErr := app.OnModelUpdate().Trigger(event, func(e *ModelEvent) error {
// run validations (if any)
if withValidations {
validateErr := e.App.ValidateWithContext(e.Context, e.Model)
if validateErr != nil {
return validateErr
}
}
// db write
return e.App.OnModelUpdateExecute().Trigger(event, func(e *ModelEvent) error {
var db dbx.Builder
if isForAuxDB {
db = e.App.AuxNonconcurrentDB()
} else {
db = e.App.NonconcurrentDB()
}
return baseLockRetry(func(attempt int) error {
if m, ok := e.Model.(DBExporter); ok {
data, err := m.DBExport(e.App)
if err != nil {
return err
}
// note: for now disallow primary key change for consistency with dbx.ModelQuery.Update()
if data[idColumn] != e.Model.LastSavedPK() {
return errors.New("primary key change is not allowed")
}
_, err = db.Update(e.Model.TableName(), data, dbx.HashExp{
idColumn: e.Model.LastSavedPK(),
}).WithContext(e.Context).Execute()
return err
}
return db.Model(e.Model).WithContext(e.Context).Update()
}, defaultMaxLockRetries)
})
})
if saveErr != nil {
errEvent := &ModelErrorEvent{ModelEvent: *event, Error: saveErr}
errEvent.App = app // replace with the initial app in case it was changed by the hook
hookErr := app.OnModelAfterUpdateError().Trigger(errEvent)
if hookErr != nil {
return errors.Join(saveErr, hookErr)
}
return saveErr
}
if app.txInfo != nil {
// execute later after the transaction has completed
app.txInfo.OnComplete(func(txErr error) error {
if app.txInfo != nil && app.txInfo.parent != nil {
event.App = app.txInfo.parent
}
if txErr != nil {
return app.OnModelAfterUpdateError().Trigger(&ModelErrorEvent{
ModelEvent: *event,
Error: txErr,
})
}
return app.OnModelAfterUpdateSuccess().Trigger(event)
})
} else if err := event.App.OnModelAfterUpdateSuccess().Trigger(event); err != nil {
return err
}
return nil
}
func validateCollectionId(app App, optTypes ...string) validation.RuleFunc {
return func(value any) error {
id, _ := value.(string)
if id == "" {
return nil
}
collection := &Collection{}
if err := app.ModelQuery(collection).Model(id, collection); err != nil {
return validation.NewError("validation_invalid_collection_id", "Missing or invalid collection.")
}
if len(optTypes) > 0 && !slices.Contains(optTypes, collection.Type) {
return validation.NewError(
"validation_invalid_collection_type",
fmt.Sprintf("Invalid collection type - must be %s.", strings.Join(optTypes, ", ")),
)
}
return nil
}
}
func validateRecordId(app App, collectionNameOrId string) validation.RuleFunc {
return func(value any) error {
id, _ := value.(string)
if id == "" {
return nil
}
collection, err := app.FindCachedCollectionByNameOrId(collectionNameOrId)
if err != nil {
return validation.NewError("validation_invalid_collection", "Missing or invalid collection.")
}
var exists int
rowErr := app.ConcurrentDB().Select("(1)").
From(collection.Name).
AndWhere(dbx.HashExp{"id": id}).
Limit(1).
Row(&exists)
if rowErr != nil || exists == 0 {
return validation.NewError("validation_invalid_record", "Missing or invalid record.")
}
return nil
}
}

187
core/db_builder.go Normal file
View file

@ -0,0 +1,187 @@
package core
import (
"strings"
"unicode"
"unicode/utf8"
"github.com/pocketbase/dbx"
)
var _ dbx.Builder = (*dualDBBuilder)(nil)
// note: expects both builder to use the same driver
type dualDBBuilder struct {
concurrentDB dbx.Builder
nonconcurrentDB dbx.Builder
}
// Select implements the [dbx.Builder.Select] interface method.
func (b *dualDBBuilder) Select(cols ...string) *dbx.SelectQuery {
return b.concurrentDB.Select(cols...)
}
// Model implements the [dbx.Builder.Model] interface method.
func (b *dualDBBuilder) Model(data interface{}) *dbx.ModelQuery {
return b.nonconcurrentDB.Model(data)
}
// GeneratePlaceholder implements the [dbx.Builder.GeneratePlaceholder] interface method.
func (b *dualDBBuilder) GeneratePlaceholder(i int) string {
return b.concurrentDB.GeneratePlaceholder(i)
}
// Quote implements the [dbx.Builder.Quote] interface method.
func (b *dualDBBuilder) Quote(str string) string {
return b.concurrentDB.Quote(str)
}
// QuoteSimpleTableName implements the [dbx.Builder.QuoteSimpleTableName] interface method.
func (b *dualDBBuilder) QuoteSimpleTableName(table string) string {
return b.concurrentDB.QuoteSimpleTableName(table)
}
// QuoteSimpleColumnName implements the [dbx.Builder.QuoteSimpleColumnName] interface method.
func (b *dualDBBuilder) QuoteSimpleColumnName(col string) string {
return b.concurrentDB.QuoteSimpleColumnName(col)
}
// QueryBuilder implements the [dbx.Builder.QueryBuilder] interface method.
func (b *dualDBBuilder) QueryBuilder() dbx.QueryBuilder {
return b.concurrentDB.QueryBuilder()
}
// Insert implements the [dbx.Builder.Insert] interface method.
func (b *dualDBBuilder) Insert(table string, cols dbx.Params) *dbx.Query {
return b.nonconcurrentDB.Insert(table, cols)
}
// Upsert implements the [dbx.Builder.Upsert] interface method.
func (b *dualDBBuilder) Upsert(table string, cols dbx.Params, constraints ...string) *dbx.Query {
return b.nonconcurrentDB.Upsert(table, cols, constraints...)
}
// Update implements the [dbx.Builder.Update] interface method.
func (b *dualDBBuilder) Update(table string, cols dbx.Params, where dbx.Expression) *dbx.Query {
return b.nonconcurrentDB.Update(table, cols, where)
}
// Delete implements the [dbx.Builder.Delete] interface method.
func (b *dualDBBuilder) Delete(table string, where dbx.Expression) *dbx.Query {
return b.nonconcurrentDB.Delete(table, where)
}
// CreateTable implements the [dbx.Builder.CreateTable] interface method.
func (b *dualDBBuilder) CreateTable(table string, cols map[string]string, options ...string) *dbx.Query {
return b.nonconcurrentDB.CreateTable(table, cols, options...)
}
// RenameTable implements the [dbx.Builder.RenameTable] interface method.
func (b *dualDBBuilder) RenameTable(oldName, newName string) *dbx.Query {
return b.nonconcurrentDB.RenameTable(oldName, newName)
}
// DropTable implements the [dbx.Builder.DropTable] interface method.
func (b *dualDBBuilder) DropTable(table string) *dbx.Query {
return b.nonconcurrentDB.DropTable(table)
}
// TruncateTable implements the [dbx.Builder.TruncateTable] interface method.
func (b *dualDBBuilder) TruncateTable(table string) *dbx.Query {
return b.nonconcurrentDB.TruncateTable(table)
}
// AddColumn implements the [dbx.Builder.AddColumn] interface method.
func (b *dualDBBuilder) AddColumn(table, col, typ string) *dbx.Query {
return b.nonconcurrentDB.AddColumn(table, col, typ)
}
// DropColumn implements the [dbx.Builder.DropColumn] interface method.
func (b *dualDBBuilder) DropColumn(table, col string) *dbx.Query {
return b.nonconcurrentDB.DropColumn(table, col)
}
// RenameColumn implements the [dbx.Builder.RenameColumn] interface method.
func (b *dualDBBuilder) RenameColumn(table, oldName, newName string) *dbx.Query {
return b.nonconcurrentDB.RenameColumn(table, oldName, newName)
}
// AlterColumn implements the [dbx.Builder.AlterColumn] interface method.
func (b *dualDBBuilder) AlterColumn(table, col, typ string) *dbx.Query {
return b.nonconcurrentDB.AlterColumn(table, col, typ)
}
// AddPrimaryKey implements the [dbx.Builder.AddPrimaryKey] interface method.
func (b *dualDBBuilder) AddPrimaryKey(table, name string, cols ...string) *dbx.Query {
return b.nonconcurrentDB.AddPrimaryKey(table, name, cols...)
}
// DropPrimaryKey implements the [dbx.Builder.DropPrimaryKey] interface method.
func (b *dualDBBuilder) DropPrimaryKey(table, name string) *dbx.Query {
return b.nonconcurrentDB.DropPrimaryKey(table, name)
}
// AddForeignKey implements the [dbx.Builder.AddForeignKey] interface method.
func (b *dualDBBuilder) AddForeignKey(table, name string, cols, refCols []string, refTable string, options ...string) *dbx.Query {
return b.nonconcurrentDB.AddForeignKey(table, name, cols, refCols, refTable, options...)
}
// DropForeignKey implements the [dbx.Builder.DropForeignKey] interface method.
func (b *dualDBBuilder) DropForeignKey(table, name string) *dbx.Query {
return b.nonconcurrentDB.DropForeignKey(table, name)
}
// CreateIndex implements the [dbx.Builder.CreateIndex] interface method.
func (b *dualDBBuilder) CreateIndex(table, name string, cols ...string) *dbx.Query {
return b.nonconcurrentDB.CreateIndex(table, name, cols...)
}
// CreateUniqueIndex implements the [dbx.Builder.CreateUniqueIndex] interface method.
func (b *dualDBBuilder) CreateUniqueIndex(table, name string, cols ...string) *dbx.Query {
return b.nonconcurrentDB.CreateUniqueIndex(table, name, cols...)
}
// DropIndex implements the [dbx.Builder.DropIndex] interface method.
func (b *dualDBBuilder) DropIndex(table, name string) *dbx.Query {
return b.nonconcurrentDB.DropIndex(table, name)
}
// NewQuery implements the [dbx.Builder.NewQuery] interface method by
// routing the SELECT queries to the concurrent builder instance.
func (b *dualDBBuilder) NewQuery(str string) *dbx.Query {
// note: technically INSERT/UPDATE/DELETE could also have CTE but since
// it is rare for now this scase is ignored to avoid unnecessary complicating the checks
trimmed := trimLeftSpaces(str)
if hasPrefixFold(trimmed, "SELECT") || hasPrefixFold(trimmed, "WITH") {
return b.concurrentDB.NewQuery(str)
}
return b.nonconcurrentDB.NewQuery(str)
}
var asciiSpace = [256]uint8{'\t': 1, '\n': 1, '\v': 1, '\f': 1, '\r': 1, ' ': 1}
// note: similar to strings.Space() but without the right trim because it is not needed in our case
func trimLeftSpaces(str string) string {
start := 0
for ; start < len(str); start++ {
c := str[start]
if c >= utf8.RuneSelf {
return strings.TrimLeftFunc(str[start:], unicode.IsSpace)
}
if asciiSpace[c] == 0 {
break
}
}
return str[start:]
}
// note: the prefix is expected to be ASCII
func hasPrefixFold(str, prefix string) bool {
if len(str) < len(prefix) {
return false
}
return strings.EqualFold(str[:len(prefix)], prefix)
}

22
core/db_connect.go Normal file
View file

@ -0,0 +1,22 @@
//go:build !no_default_driver
package core
import (
"github.com/pocketbase/dbx"
_ "modernc.org/sqlite"
)
func DefaultDBConnect(dbPath string) (*dbx.DB, error) {
// Note: the busy_timeout pragma must be first because
// the connection needs to be set to block on busy before WAL mode
// is set in case it hasn't been already set by another connection.
pragmas := "?_pragma=busy_timeout(10000)&_pragma=journal_mode(WAL)&_pragma=journal_size_limit(200000000)&_pragma=synchronous(NORMAL)&_pragma=foreign_keys(ON)&_pragma=temp_store(MEMORY)&_pragma=cache_size(-16000)"
db, err := dbx.Open("sqlite", dbPath+pragmas)
if err != nil {
return nil, err
}
return db, nil
}

View file

@ -0,0 +1,9 @@
//go:build no_default_driver
package core
import "github.com/pocketbase/dbx"
func DefaultDBConnect(dbPath string) (*dbx.DB, error) {
panic("DBConnect config option must be set when the no_default_driver tag is used!")
}

59
core/db_model.go Normal file
View file

@ -0,0 +1,59 @@
package core
// Model defines an interface with common methods that all db models should have.
//
// Note: for simplicity composite pk are not supported.
type Model interface {
TableName() string
PK() any
LastSavedPK() any
IsNew() bool
MarkAsNew()
MarkAsNotNew()
}
// BaseModel defines a base struct that is intended to be embedded into other custom models.
type BaseModel struct {
lastSavedPK string
// Id is the primary key of the model.
// It is usually autogenerated by the parent model implementation.
Id string `db:"id" json:"id" form:"id" xml:"id"`
}
// LastSavedPK returns the last saved primary key of the model.
//
// Its value is updated to the latest PK value after MarkAsNotNew() or PostScan() calls.
func (m *BaseModel) LastSavedPK() any {
return m.lastSavedPK
}
func (m *BaseModel) PK() any {
return m.Id
}
// IsNew indicates what type of db query (insert or update)
// should be used with the model instance.
func (m *BaseModel) IsNew() bool {
return m.lastSavedPK == ""
}
// MarkAsNew clears the pk field and marks the current model as "new"
// (aka. forces m.IsNew() to be true).
func (m *BaseModel) MarkAsNew() {
m.lastSavedPK = ""
}
// MarkAsNew set the pk field to the Id value and marks the current model
// as NOT "new" (aka. forces m.IsNew() to be false).
func (m *BaseModel) MarkAsNotNew() {
m.lastSavedPK = m.Id
}
// PostScan implements the [dbx.PostScanner] interface.
//
// It is usually executed right after the model is populated with the db row values.
func (m *BaseModel) PostScan() error {
m.MarkAsNotNew()
return nil
}

70
core/db_model_test.go Normal file
View file

@ -0,0 +1,70 @@
package core_test
import (
"testing"
"github.com/pocketbase/pocketbase/core"
)
func TestBaseModel(t *testing.T) {
id := "test_id"
m := core.BaseModel{Id: id}
if m.PK() != id {
t.Fatalf("[before PostScan] Expected PK %q, got %q", "", m.PK())
}
if m.LastSavedPK() != "" {
t.Fatalf("[before PostScan] Expected LastSavedPK %q, got %q", "", m.LastSavedPK())
}
if !m.IsNew() {
t.Fatalf("[before PostScan] Expected IsNew %v, got %v", true, m.IsNew())
}
if err := m.PostScan(); err != nil {
t.Fatal(err)
}
if m.PK() != id {
t.Fatalf("[after PostScan] Expected PK %q, got %q", "", m.PK())
}
if m.LastSavedPK() != id {
t.Fatalf("[after PostScan] Expected LastSavedPK %q, got %q", id, m.LastSavedPK())
}
if m.IsNew() {
t.Fatalf("[after PostScan] Expected IsNew %v, got %v", false, m.IsNew())
}
m.MarkAsNew()
if m.PK() != id {
t.Fatalf("[after MarkAsNew] Expected PK %q, got %q", id, m.PK())
}
if m.LastSavedPK() != "" {
t.Fatalf("[after MarkAsNew] Expected LastSavedPK %q, got %q", "", m.LastSavedPK())
}
if !m.IsNew() {
t.Fatalf("[after MarkAsNew] Expected IsNew %v, got %v", true, m.IsNew())
}
// mark as not new without id
m.MarkAsNotNew()
if m.PK() != id {
t.Fatalf("[after MarkAsNotNew] Expected PK %q, got %q", id, m.PK())
}
if m.LastSavedPK() != id {
t.Fatalf("[after MarkAsNotNew] Expected LastSavedPK %q, got %q", id, m.LastSavedPK())
}
if m.IsNew() {
t.Fatalf("[after MarkAsNotNew] Expected IsNew %v, got %v", false, m.IsNew())
}
}

70
core/db_retry.go Normal file
View file

@ -0,0 +1,70 @@
package core
import (
"context"
"database/sql"
"errors"
"fmt"
"strings"
"time"
"github.com/pocketbase/dbx"
)
// default retries intervals (in ms)
var defaultRetryIntervals = []int{50, 100, 150, 200, 300, 400, 500, 700, 1000}
// default max retry attempts
const defaultMaxLockRetries = 12
func execLockRetry(timeout time.Duration, maxRetries int) dbx.ExecHookFunc {
return func(q *dbx.Query, op func() error) error {
if q.Context() == nil {
cancelCtx, cancel := context.WithTimeout(context.Background(), timeout)
defer func() {
cancel()
//nolint:staticcheck
q.WithContext(nil) // reset
}()
q.WithContext(cancelCtx)
}
execErr := baseLockRetry(func(attempt int) error {
return op()
}, maxRetries)
if execErr != nil && !errors.Is(execErr, sql.ErrNoRows) {
execErr = fmt.Errorf("%w; failed query: %s", execErr, q.SQL())
}
return execErr
}
}
func baseLockRetry(op func(attempt int) error, maxRetries int) error {
attempt := 1
Retry:
err := op(attempt)
if err != nil && attempt <= maxRetries {
errStr := err.Error()
// we are checking the error against the plain error texts since the codes could vary between drivers
if strings.Contains(errStr, "database is locked") ||
strings.Contains(errStr, "table is locked") {
// wait and retry
time.Sleep(getDefaultRetryInterval(attempt))
attempt++
goto Retry
}
}
return err
}
func getDefaultRetryInterval(attempt int) time.Duration {
if attempt < 0 || attempt > len(defaultRetryIntervals)-1 {
return time.Duration(defaultRetryIntervals[len(defaultRetryIntervals)-1]) * time.Millisecond
}
return time.Duration(defaultRetryIntervals[attempt]) * time.Millisecond
}

66
core/db_retry_test.go Normal file
View file

@ -0,0 +1,66 @@
package core
import (
"errors"
"fmt"
"testing"
)
func TestGetDefaultRetryInterval(t *testing.T) {
t.Parallel()
if i := getDefaultRetryInterval(-1); i.Milliseconds() != 1000 {
t.Fatalf("Expected 1000ms, got %v", i)
}
if i := getDefaultRetryInterval(999); i.Milliseconds() != 1000 {
t.Fatalf("Expected 1000ms, got %v", i)
}
if i := getDefaultRetryInterval(3); i.Milliseconds() != 200 {
t.Fatalf("Expected 500ms, got %v", i)
}
}
func TestBaseLockRetry(t *testing.T) {
t.Parallel()
scenarios := []struct {
err error
failUntilAttempt int
expectedAttempts int
}{
{nil, 3, 1},
{errors.New("test"), 3, 1},
{errors.New("database is locked"), 3, 3},
{errors.New("table is locked"), 3, 3},
}
for i, s := range scenarios {
t.Run(fmt.Sprintf("%d_%#v", i, s.err), func(t *testing.T) {
lastAttempt := 0
err := baseLockRetry(func(attempt int) error {
lastAttempt = attempt
if attempt < s.failUntilAttempt {
return s.err
}
return nil
}, s.failUntilAttempt+2)
if lastAttempt != s.expectedAttempts {
t.Errorf("Expected lastAttempt to be %d, got %d", s.expectedAttempts, lastAttempt)
}
if s.failUntilAttempt == s.expectedAttempts && err != nil {
t.Fatalf("Expected nil, got err %v", err)
}
if s.failUntilAttempt != s.expectedAttempts && s.err != nil && err == nil {
t.Fatalf("Expected error %q, got nil", s.err)
}
})
}
}

137
core/db_table.go Normal file
View file

@ -0,0 +1,137 @@
package core
import (
"database/sql"
"fmt"
"github.com/pocketbase/dbx"
)
// TableColumns returns all column names of a single table by its name.
func (app *BaseApp) TableColumns(tableName string) ([]string, error) {
columns := []string{}
err := app.ConcurrentDB().NewQuery("SELECT name FROM PRAGMA_TABLE_INFO({:tableName})").
Bind(dbx.Params{"tableName": tableName}).
Column(&columns)
return columns, err
}
type TableInfoRow struct {
// the `db:"pk"` tag has special semantic so we cannot rename
// the original field without specifying a custom mapper
PK int
Index int `db:"cid"`
Name string `db:"name"`
Type string `db:"type"`
NotNull bool `db:"notnull"`
DefaultValue sql.NullString `db:"dflt_value"`
}
// TableInfo returns the "table_info" pragma result for the specified table.
func (app *BaseApp) TableInfo(tableName string) ([]*TableInfoRow, error) {
info := []*TableInfoRow{}
err := app.ConcurrentDB().NewQuery("SELECT * FROM PRAGMA_TABLE_INFO({:tableName})").
Bind(dbx.Params{"tableName": tableName}).
All(&info)
if err != nil {
return nil, err
}
// mattn/go-sqlite3 doesn't throw an error on invalid or missing table
// so we additionally have to check whether the loaded info result is nonempty
if len(info) == 0 {
return nil, fmt.Errorf("empty table info probably due to invalid or missing table %s", tableName)
}
return info, nil
}
// TableIndexes returns a name grouped map with all non empty index of the specified table.
//
// Note: This method doesn't return an error on nonexisting table.
func (app *BaseApp) TableIndexes(tableName string) (map[string]string, error) {
indexes := []struct {
Name string
Sql string
}{}
err := app.ConcurrentDB().Select("name", "sql").
From("sqlite_master").
AndWhere(dbx.NewExp("sql is not null")).
AndWhere(dbx.HashExp{
"type": "index",
"tbl_name": tableName,
}).
All(&indexes)
if err != nil {
return nil, err
}
result := make(map[string]string, len(indexes))
for _, idx := range indexes {
result[idx.Name] = idx.Sql
}
return result, nil
}
// DeleteTable drops the specified table.
//
// This method is a no-op if a table with the provided name doesn't exist.
//
// NB! Be aware that this method is vulnerable to SQL injection and the
// "tableName" argument must come only from trusted input!
func (app *BaseApp) DeleteTable(tableName string) error {
_, err := app.NonconcurrentDB().NewQuery(fmt.Sprintf(
"DROP TABLE IF EXISTS {{%s}}",
tableName,
)).Execute()
return err
}
// HasTable checks if a table (or view) with the provided name exists (case insensitive).
// in the data.db.
func (app *BaseApp) HasTable(tableName string) bool {
return app.hasTable(app.ConcurrentDB(), tableName)
}
// AuxHasTable checks if a table (or view) with the provided name exists (case insensitive)
// in the auixiliary.db.
func (app *BaseApp) AuxHasTable(tableName string) bool {
return app.hasTable(app.AuxConcurrentDB(), tableName)
}
func (app *BaseApp) hasTable(db dbx.Builder, tableName string) bool {
var exists int
err := db.Select("(1)").
From("sqlite_schema").
AndWhere(dbx.HashExp{"type": []any{"table", "view"}}).
AndWhere(dbx.NewExp("LOWER([[name]])=LOWER({:tableName})", dbx.Params{"tableName": tableName})).
Limit(1).
Row(&exists)
return err == nil && exists > 0
}
// Vacuum executes VACUUM on the data.db in order to reclaim unused data db disk space.
func (app *BaseApp) Vacuum() error {
return app.vacuum(app.NonconcurrentDB())
}
// AuxVacuum executes VACUUM on the auxiliary.db in order to reclaim unused auxiliary db disk space.
func (app *BaseApp) AuxVacuum() error {
return app.vacuum(app.AuxNonconcurrentDB())
}
func (app *BaseApp) vacuum(db dbx.Builder) error {
_, err := db.NewQuery("VACUUM").Execute()
return err
}

250
core/db_table_test.go Normal file
View file

@ -0,0 +1,250 @@
package core_test
import (
"context"
"database/sql"
"encoding/json"
"fmt"
"slices"
"testing"
"time"
"github.com/pocketbase/dbx"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tests"
)
func TestHasTable(t *testing.T) {
t.Parallel()
app, _ := tests.NewTestApp()
defer app.Cleanup()
scenarios := []struct {
tableName string
expected bool
}{
{"", false},
{"test", false},
{core.CollectionNameSuperusers, true},
{"demo3", true},
{"DEMO3", true}, // table names are case insensitives by default
{"view1", true}, // view
}
for _, s := range scenarios {
t.Run(s.tableName, func(t *testing.T) {
result := app.HasTable(s.tableName)
if result != s.expected {
t.Fatalf("Expected %v, got %v", s.expected, result)
}
})
}
}
func TestAuxHasTable(t *testing.T) {
t.Parallel()
app, _ := tests.NewTestApp()
defer app.Cleanup()
scenarios := []struct {
tableName string
expected bool
}{
{"", false},
{"test", false},
{"_lOGS", true}, // table names are case insensitives by default
}
for _, s := range scenarios {
t.Run(s.tableName, func(t *testing.T) {
result := app.AuxHasTable(s.tableName)
if result != s.expected {
t.Fatalf("Expected %v, got %v", s.expected, result)
}
})
}
}
func TestTableColumns(t *testing.T) {
t.Parallel()
app, _ := tests.NewTestApp()
defer app.Cleanup()
scenarios := []struct {
tableName string
expected []string
}{
{"", nil},
{"_params", []string{"id", "value", "created", "updated"}},
}
for i, s := range scenarios {
t.Run(fmt.Sprintf("%d_%s", i, s.tableName), func(t *testing.T) {
columns, _ := app.TableColumns(s.tableName)
if len(columns) != len(s.expected) {
t.Fatalf("Expected columns %v, got %v", s.expected, columns)
}
for _, c := range columns {
if !slices.Contains(s.expected, c) {
t.Errorf("Didn't expect column %s", c)
}
}
})
}
}
func TestTableInfo(t *testing.T) {
t.Parallel()
app, _ := tests.NewTestApp()
defer app.Cleanup()
scenarios := []struct {
tableName string
expected string
}{
{"", "null"},
{"missing", "null"},
{
"_params",
`[{"PK":0,"Index":0,"Name":"created","Type":"TEXT","NotNull":true,"DefaultValue":{"String":"''","Valid":true}},{"PK":1,"Index":1,"Name":"id","Type":"TEXT","NotNull":true,"DefaultValue":{"String":"'r'||lower(hex(randomblob(7)))","Valid":true}},{"PK":0,"Index":2,"Name":"updated","Type":"TEXT","NotNull":true,"DefaultValue":{"String":"''","Valid":true}},{"PK":0,"Index":3,"Name":"value","Type":"JSON","NotNull":false,"DefaultValue":{"String":"NULL","Valid":true}}]`,
},
}
for i, s := range scenarios {
t.Run(fmt.Sprintf("%d_%s", i, s.tableName), func(t *testing.T) {
rows, _ := app.TableInfo(s.tableName)
raw, err := json.Marshal(rows)
if err != nil {
t.Fatal(err)
}
if str := string(raw); str != s.expected {
t.Fatalf("Expected\n%s\ngot\n%s", s.expected, str)
}
})
}
}
func TestTableIndexes(t *testing.T) {
t.Parallel()
app, _ := tests.NewTestApp()
defer app.Cleanup()
scenarios := []struct {
tableName string
expected []string
}{
{"", nil},
{"missing", nil},
{
core.CollectionNameSuperusers,
[]string{"idx_email__pbc_3323866339", "idx_tokenKey__pbc_3323866339"},
},
}
for i, s := range scenarios {
t.Run(fmt.Sprintf("%d_%s", i, s.tableName), func(t *testing.T) {
indexes, _ := app.TableIndexes(s.tableName)
if len(indexes) != len(s.expected) {
t.Fatalf("Expected %d indexes, got %d\n%v", len(s.expected), len(indexes), indexes)
}
for _, name := range s.expected {
if v, ok := indexes[name]; !ok || v == "" {
t.Fatalf("Expected non-empty index %q in \n%v", name, indexes)
}
}
})
}
}
func TestDeleteTable(t *testing.T) {
t.Parallel()
app, _ := tests.NewTestApp()
defer app.Cleanup()
scenarios := []struct {
tableName string
expectError bool
}{
{"", true},
{"test", false}, // missing tables are ignored
{"_admins", false},
{"demo3", false},
}
for i, s := range scenarios {
t.Run(fmt.Sprintf("%d_%s", i, s.tableName), func(t *testing.T) {
err := app.DeleteTable(s.tableName)
hasErr := err != nil
if hasErr != s.expectError {
t.Fatalf("Expected hasErr %v, got %v", s.expectError, hasErr)
}
})
}
}
func TestVacuum(t *testing.T) {
t.Parallel()
app, _ := tests.NewTestApp()
defer app.Cleanup()
calledQueries := []string{}
app.NonconcurrentDB().(*dbx.DB).QueryLogFunc = func(ctx context.Context, t time.Duration, sql string, rows *sql.Rows, err error) {
calledQueries = append(calledQueries, sql)
}
app.NonconcurrentDB().(*dbx.DB).ExecLogFunc = func(ctx context.Context, t time.Duration, sql string, result sql.Result, err error) {
calledQueries = append(calledQueries, sql)
}
if err := app.Vacuum(); err != nil {
t.Fatal(err)
}
if total := len(calledQueries); total != 1 {
t.Fatalf("Expected 1 query, got %d", total)
}
if calledQueries[0] != "VACUUM" {
t.Fatalf("Expected VACUUM query, got %s", calledQueries[0])
}
}
func TestAuxVacuum(t *testing.T) {
t.Parallel()
app, _ := tests.NewTestApp()
defer app.Cleanup()
calledQueries := []string{}
app.AuxNonconcurrentDB().(*dbx.DB).QueryLogFunc = func(ctx context.Context, t time.Duration, sql string, rows *sql.Rows, err error) {
calledQueries = append(calledQueries, sql)
}
app.AuxNonconcurrentDB().(*dbx.DB).ExecLogFunc = func(ctx context.Context, t time.Duration, sql string, result sql.Result, err error) {
calledQueries = append(calledQueries, sql)
}
if err := app.AuxVacuum(); err != nil {
t.Fatal(err)
}
if total := len(calledQueries); total != 1 {
t.Fatalf("Expected 1 query, got %d", total)
}
if calledQueries[0] != "VACUUM" {
t.Fatalf("Expected VACUUM query, got %s", calledQueries[0])
}
}

113
core/db_test.go Normal file
View file

@ -0,0 +1,113 @@
package core_test
import (
"context"
"errors"
"testing"
"github.com/pocketbase/dbx"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tests"
)
func TestGenerateDefaultRandomId(t *testing.T) {
t.Parallel()
id1 := core.GenerateDefaultRandomId()
id2 := core.GenerateDefaultRandomId()
if id1 == id2 {
t.Fatalf("Expected id1 and id2 to differ, got %q", id1)
}
if l := len(id1); l != 15 {
t.Fatalf("Expected id1 length %d, got %d", 15, l)
}
if l := len(id2); l != 15 {
t.Fatalf("Expected id2 length %d, got %d", 15, l)
}
}
func TestModelQuery(t *testing.T) {
t.Parallel()
app, _ := tests.NewTestApp()
defer app.Cleanup()
modelsQuery := app.ModelQuery(&core.Collection{})
logsModelQuery := app.AuxModelQuery(&core.Collection{})
if app.ConcurrentDB() == modelsQuery.Info().Builder {
t.Fatalf("ModelQuery() is not using app.ConcurrentDB()")
}
if app.AuxConcurrentDB() == logsModelQuery.Info().Builder {
t.Fatalf("AuxModelQuery() is not using app.AuxConcurrentDB()")
}
expectedSQL := "SELECT {{_collections}}.* FROM `_collections`"
for i, q := range []*dbx.SelectQuery{modelsQuery, logsModelQuery} {
sql := q.Build().SQL()
if sql != expectedSQL {
t.Fatalf("[%d] Expected select\n%s\ngot\n%s", i, expectedSQL, sql)
}
}
}
func TestValidate(t *testing.T) {
t.Parallel()
app, _ := tests.NewTestApp()
defer app.Cleanup()
u := &mockSuperusers{}
testErr := errors.New("test")
app.OnModelValidate().BindFunc(func(e *core.ModelEvent) error {
return testErr
})
err := app.Validate(u)
if err != testErr {
t.Fatalf("Expected error %v, got %v", testErr, err)
}
}
func TestValidateWithContext(t *testing.T) {
t.Parallel()
app, _ := tests.NewTestApp()
defer app.Cleanup()
u := &mockSuperusers{}
testErr := errors.New("test")
app.OnModelValidate().BindFunc(func(e *core.ModelEvent) error {
if v := e.Context.Value("test"); v != 123 {
t.Fatalf("Expected 'test' context value %#v, got %#v", 123, v)
}
return testErr
})
//nolint:staticcheck
ctx := context.WithValue(context.Background(), "test", 123)
err := app.ValidateWithContext(ctx, u)
if err != testErr {
t.Fatalf("Expected error %v, got %v", testErr, err)
}
}
// -------------------------------------------------------------------
type mockSuperusers struct {
core.BaseModel
Email string `db:"email"`
}
func (m *mockSuperusers) TableName() string {
return core.CollectionNameSuperusers
}

112
core/db_tx.go Normal file
View file

@ -0,0 +1,112 @@
package core
import (
"errors"
"fmt"
"sync"
"github.com/pocketbase/dbx"
)
// RunInTransaction wraps fn into a transaction for the regular app database.
//
// It is safe to nest RunInTransaction calls as long as you use the callback's txApp.
func (app *BaseApp) RunInTransaction(fn func(txApp App) error) error {
return app.runInTransaction(app.NonconcurrentDB(), fn, false)
}
// AuxRunInTransaction wraps fn into a transaction for the auxiliary app database.
//
// It is safe to nest RunInTransaction calls as long as you use the callback's txApp.
func (app *BaseApp) AuxRunInTransaction(fn func(txApp App) error) error {
return app.runInTransaction(app.AuxNonconcurrentDB(), fn, true)
}
func (app *BaseApp) runInTransaction(db dbx.Builder, fn func(txApp App) error, isForAuxDB bool) error {
switch txOrDB := db.(type) {
case *dbx.Tx:
// run as part of the already existing transaction
return fn(app)
case *dbx.DB:
var txApp *BaseApp
txErr := txOrDB.Transactional(func(tx *dbx.Tx) error {
txApp = app.createTxApp(tx, isForAuxDB)
return fn(txApp)
})
// execute all after event calls on transaction complete
if txApp != nil && txApp.txInfo != nil {
afterFuncErr := txApp.txInfo.runAfterFuncs(txErr)
if afterFuncErr != nil {
return errors.Join(txErr, afterFuncErr)
}
}
return txErr
default:
return errors.New("failed to start transaction (unknown db type)")
}
}
// createTxApp shallow clones the current app and assigns a new tx state.
func (app *BaseApp) createTxApp(tx *dbx.Tx, isForAuxDB bool) *BaseApp {
clone := *app
if isForAuxDB {
clone.auxConcurrentDB = tx
clone.auxNonconcurrentDB = tx
} else {
clone.concurrentDB = tx
clone.nonconcurrentDB = tx
}
clone.txInfo = &TxAppInfo{
parent: app,
isForAuxDB: isForAuxDB,
}
return &clone
}
// TxAppInfo represents an active transaction context associated to an existing app instance.
type TxAppInfo struct {
parent *BaseApp
afterFuncs []func(txErr error) error
mu sync.Mutex
isForAuxDB bool
}
// OnComplete registers the provided callback that will be invoked
// once the related transaction ends (either completes successfully or rollbacked with an error).
//
// The callback receives the transaction error (if any) as its argument.
// Any additional errors returned by the OnComplete callbacks will be
// joined together with txErr when returning the final transaction result.
func (a *TxAppInfo) OnComplete(fn func(txErr error) error) {
a.mu.Lock()
defer a.mu.Unlock()
a.afterFuncs = append(a.afterFuncs, fn)
}
// note: can be called only once because TxAppInfo is cleared
func (a *TxAppInfo) runAfterFuncs(txErr error) error {
a.mu.Lock()
defer a.mu.Unlock()
var errs []error
for _, call := range a.afterFuncs {
if err := call(txErr); err != nil {
errs = append(errs, err)
}
}
a.afterFuncs = nil
if len(errs) > 0 {
return fmt.Errorf("transaction afterFunc errors: %w", errors.Join(errs...))
}
return nil
}

413
core/db_tx_test.go Normal file
View file

@ -0,0 +1,413 @@
package core_test
import (
"errors"
"testing"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tests"
)
func TestRunInTransaction(t *testing.T) {
app, _ := tests.NewTestApp()
defer app.Cleanup()
t.Run("failed nested transaction", func(t *testing.T) {
app.RunInTransaction(func(txApp core.App) error {
superuser, _ := txApp.FindAuthRecordByEmail(core.CollectionNameSuperusers, "test@example.com")
return txApp.RunInTransaction(func(tx2Dao core.App) error {
if err := tx2Dao.Delete(superuser); err != nil {
t.Fatal(err)
}
return errors.New("test error")
})
})
// superuser should still exist
superuser, _ := app.FindAuthRecordByEmail(core.CollectionNameSuperusers, "test@example.com")
if superuser == nil {
t.Fatal("Expected superuser test@example.com to not be deleted")
}
})
t.Run("successful nested transaction", func(t *testing.T) {
app.RunInTransaction(func(txApp core.App) error {
superuser, _ := txApp.FindAuthRecordByEmail(core.CollectionNameSuperusers, "test@example.com")
return txApp.RunInTransaction(func(tx2Dao core.App) error {
return tx2Dao.Delete(superuser)
})
})
// superuser should have been deleted
superuser, _ := app.FindAuthRecordByEmail(core.CollectionNameSuperusers, "test@example.com")
if superuser != nil {
t.Fatalf("Expected superuser test@example.com to be deleted, found %v", superuser)
}
})
}
func TestTransactionHooksCallsOnFailure(t *testing.T) {
app, _ := tests.NewTestApp()
defer app.Cleanup()
createHookCalls := 0
updateHookCalls := 0
deleteHookCalls := 0
afterCreateHookCalls := 0
afterUpdateHookCalls := 0
afterDeleteHookCalls := 0
app.OnModelCreate().BindFunc(func(e *core.ModelEvent) error {
createHookCalls++
return e.Next()
})
app.OnModelUpdate().BindFunc(func(e *core.ModelEvent) error {
updateHookCalls++
return e.Next()
})
app.OnModelDelete().BindFunc(func(e *core.ModelEvent) error {
deleteHookCalls++
return e.Next()
})
app.OnModelAfterCreateSuccess().BindFunc(func(e *core.ModelEvent) error {
afterCreateHookCalls++
return e.Next()
})
app.OnModelAfterUpdateSuccess().BindFunc(func(e *core.ModelEvent) error {
afterUpdateHookCalls++
return e.Next()
})
app.OnModelAfterDeleteSuccess().BindFunc(func(e *core.ModelEvent) error {
afterDeleteHookCalls++
return e.Next()
})
existingModel, _ := app.FindAuthRecordByEmail(core.CollectionNameSuperusers, "test@example.com")
app.RunInTransaction(func(txApp1 core.App) error {
return txApp1.RunInTransaction(func(txApp2 core.App) error {
// test create
// ---
newModel := core.NewRecord(existingModel.Collection())
newModel.SetEmail("test_new1@example.com")
newModel.SetPassword("1234567890")
if err := txApp2.Save(newModel); err != nil {
t.Fatal(err)
}
// test update (twice)
// ---
if err := txApp2.Save(existingModel); err != nil {
t.Fatal(err)
}
if err := txApp2.Save(existingModel); err != nil {
t.Fatal(err)
}
// test delete
// ---
if err := txApp2.Delete(newModel); err != nil {
t.Fatal(err)
}
return errors.New("test_tx_error")
})
})
if createHookCalls != 1 {
t.Errorf("Expected createHookCalls to be called 1 time, got %d", createHookCalls)
}
if updateHookCalls != 2 {
t.Errorf("Expected updateHookCalls to be called 2 times, got %d", updateHookCalls)
}
if deleteHookCalls != 1 {
t.Errorf("Expected deleteHookCalls to be called 1 time, got %d", deleteHookCalls)
}
if afterCreateHookCalls != 0 {
t.Errorf("Expected afterCreateHookCalls to be called 0 times, got %d", afterCreateHookCalls)
}
if afterUpdateHookCalls != 0 {
t.Errorf("Expected afterUpdateHookCalls to be called 0 times, got %d", afterUpdateHookCalls)
}
if afterDeleteHookCalls != 0 {
t.Errorf("Expected afterDeleteHookCalls to be called 0 times, got %d", afterDeleteHookCalls)
}
}
func TestTransactionHooksCallsOnSuccess(t *testing.T) {
app, _ := tests.NewTestApp()
defer app.Cleanup()
createHookCalls := 0
updateHookCalls := 0
deleteHookCalls := 0
afterCreateHookCalls := 0
afterUpdateHookCalls := 0
afterDeleteHookCalls := 0
app.OnModelCreate().BindFunc(func(e *core.ModelEvent) error {
createHookCalls++
return e.Next()
})
app.OnModelUpdate().BindFunc(func(e *core.ModelEvent) error {
updateHookCalls++
return e.Next()
})
app.OnModelDelete().BindFunc(func(e *core.ModelEvent) error {
deleteHookCalls++
return e.Next()
})
app.OnModelAfterCreateSuccess().BindFunc(func(e *core.ModelEvent) error {
if e.App.IsTransactional() {
t.Fatal("Expected e.App to be non-transactional")
}
afterCreateHookCalls++
return e.Next()
})
app.OnModelAfterUpdateSuccess().BindFunc(func(e *core.ModelEvent) error {
if e.App.IsTransactional() {
t.Fatal("Expected e.App to be non-transactional")
}
afterUpdateHookCalls++
return e.Next()
})
app.OnModelAfterDeleteSuccess().BindFunc(func(e *core.ModelEvent) error {
if e.App.IsTransactional() {
t.Fatal("Expected e.App to be non-transactional")
}
afterDeleteHookCalls++
return e.Next()
})
existingModel, _ := app.FindAuthRecordByEmail(core.CollectionNameSuperusers, "test@example.com")
app.RunInTransaction(func(txApp1 core.App) error {
return txApp1.RunInTransaction(func(txApp2 core.App) error {
// test create
// ---
newModel := core.NewRecord(existingModel.Collection())
newModel.SetEmail("test_new1@example.com")
newModel.SetPassword("1234567890")
if err := txApp2.Save(newModel); err != nil {
t.Fatal(err)
}
// test update (twice)
// ---
if err := txApp2.Save(existingModel); err != nil {
t.Fatal(err)
}
if err := txApp2.Save(existingModel); err != nil {
t.Fatal(err)
}
// test delete
// ---
if err := txApp2.Delete(newModel); err != nil {
t.Fatal(err)
}
return nil
})
})
if createHookCalls != 1 {
t.Errorf("Expected createHookCalls to be called 1 time, got %d", createHookCalls)
}
if updateHookCalls != 2 {
t.Errorf("Expected updateHookCalls to be called 2 times, got %d", updateHookCalls)
}
if deleteHookCalls != 1 {
t.Errorf("Expected deleteHookCalls to be called 1 time, got %d", deleteHookCalls)
}
if afterCreateHookCalls != 1 {
t.Errorf("Expected afterCreateHookCalls to be called 1 time, got %d", afterCreateHookCalls)
}
if afterUpdateHookCalls != 2 {
t.Errorf("Expected afterUpdateHookCalls to be called 2 times, got %d", afterUpdateHookCalls)
}
if afterDeleteHookCalls != 1 {
t.Errorf("Expected afterDeleteHookCalls to be called 1 time, got %d", afterDeleteHookCalls)
}
}
func TestTransactionFromInnerCreateHook(t *testing.T) {
t.Parallel()
app, _ := tests.NewTestApp()
defer app.Cleanup()
app.OnRecordCreateExecute("demo2").BindFunc(func(e *core.RecordEvent) error {
originalApp := e.App
return e.App.RunInTransaction(func(txApp core.App) error {
e.App = txApp
defer func() {
e.App = originalApp
}()
nextErr := e.Next()
return nextErr
})
})
app.OnRecordAfterCreateSuccess("demo2").BindFunc(func(e *core.RecordEvent) error {
if e.App.IsTransactional() {
t.Fatal("Expected e.App to be non-transactional")
}
// perform a db query with the app instance to ensure that it is still valid
_, err := e.App.FindFirstRecordByFilter("demo2", "1=1")
if err != nil {
t.Fatalf("Failed to perform a db query after tx success: %v", err)
}
return e.Next()
})
collection, err := app.FindCollectionByNameOrId("demo2")
if err != nil {
t.Fatal(err)
}
record := core.NewRecord(collection)
record.Set("title", "test_inner_tx")
if err = app.Save(record); err != nil {
t.Fatalf("Create failed: %v", err)
}
expectedHookCalls := map[string]int{
"OnRecordCreateExecute": 1,
"OnRecordAfterCreateSuccess": 1,
}
for k, total := range expectedHookCalls {
if found, ok := app.EventCalls[k]; !ok || total != found {
t.Fatalf("Expected %q %d calls, got %d", k, total, found)
}
}
}
func TestTransactionFromInnerUpdateHook(t *testing.T) {
t.Parallel()
app, _ := tests.NewTestApp()
defer app.Cleanup()
app.OnRecordUpdateExecute("demo2").BindFunc(func(e *core.RecordEvent) error {
originalApp := e.App
return e.App.RunInTransaction(func(txApp core.App) error {
e.App = txApp
defer func() {
e.App = originalApp
}()
nextErr := e.Next()
return nextErr
})
})
app.OnRecordAfterUpdateSuccess("demo2").BindFunc(func(e *core.RecordEvent) error {
if e.App.IsTransactional() {
t.Fatal("Expected e.App to be non-transactional")
}
// perform a db query with the app instance to ensure that it is still valid
_, err := e.App.FindFirstRecordByFilter("demo2", "1=1")
if err != nil {
t.Fatalf("Failed to perform a db query after tx success: %v", err)
}
return e.Next()
})
existingModel, err := app.FindFirstRecordByFilter("demo2", "1=1")
if err != nil {
t.Fatal(err)
}
if err = app.Save(existingModel); err != nil {
t.Fatalf("Update failed: %v", err)
}
expectedHookCalls := map[string]int{
"OnRecordUpdateExecute": 1,
"OnRecordAfterUpdateSuccess": 1,
}
for k, total := range expectedHookCalls {
if found, ok := app.EventCalls[k]; !ok || total != found {
t.Fatalf("Expected %q %d calls, got %d", k, total, found)
}
}
}
func TestTransactionFromInnerDeleteHook(t *testing.T) {
t.Parallel()
app, _ := tests.NewTestApp()
defer app.Cleanup()
app.OnRecordDeleteExecute("demo2").BindFunc(func(e *core.RecordEvent) error {
originalApp := e.App
return e.App.RunInTransaction(func(txApp core.App) error {
e.App = txApp
defer func() {
e.App = originalApp
}()
nextErr := e.Next()
return nextErr
})
})
app.OnRecordAfterDeleteSuccess("demo2").BindFunc(func(e *core.RecordEvent) error {
if e.App.IsTransactional() {
t.Fatal("Expected e.App to be non-transactional")
}
// perform a db query with the app instance to ensure that it is still valid
_, err := e.App.FindFirstRecordByFilter("demo2", "1=1")
if err != nil {
t.Fatalf("Failed to perform a db query after tx success: %v", err)
}
return e.Next()
})
existingModel, err := app.FindFirstRecordByFilter("demo2", "1=1")
if err != nil {
t.Fatal(err)
}
if err = app.Delete(existingModel); err != nil {
t.Fatalf("Delete failed: %v", err)
}
expectedHookCalls := map[string]int{
"OnRecordDeleteExecute": 1,
"OnRecordAfterDeleteSuccess": 1,
}
for k, total := range expectedHookCalls {
if found, ok := app.EventCalls[k]; !ok || total != found {
t.Fatalf("Expected %q %d calls, got %d", k, total, found)
}
}
}

197
core/event_request.go Normal file
View file

@ -0,0 +1,197 @@
package core
import (
"maps"
"net/netip"
"strings"
"sync"
"github.com/pocketbase/pocketbase/tools/inflector"
"github.com/pocketbase/pocketbase/tools/router"
)
// Common request store keys used by the middlewares and api handlers.
const (
RequestEventKeyInfoContext = "infoContext"
)
// RequestEvent defines the PocketBase router handler event.
type RequestEvent struct {
App App
cachedRequestInfo *RequestInfo
Auth *Record
router.Event
mu sync.Mutex
}
// RealIP returns the "real" IP address from the configured trusted proxy headers.
//
// If Settings.TrustedProxy is not configured or the found IP is empty,
// it fallbacks to e.RemoteIP().
//
// NB!
// Be careful when used in a security critical context as it relies on
// the trusted proxy to be properly configured and your app to be accessible only through it.
// If you are not sure, use e.RemoteIP().
func (e *RequestEvent) RealIP() string {
settings := e.App.Settings()
for _, h := range settings.TrustedProxy.Headers {
headerValues := e.Request.Header.Values(h)
if len(headerValues) == 0 {
continue
}
// extract the last header value as it is expected to be the one controlled by the proxy
ipsList := headerValues[len(headerValues)-1]
if ipsList == "" {
continue
}
ips := strings.Split(ipsList, ",")
if settings.TrustedProxy.UseLeftmostIP {
for _, ip := range ips {
parsed, err := netip.ParseAddr(strings.TrimSpace(ip))
if err == nil {
return parsed.StringExpanded()
}
}
} else {
for i := len(ips) - 1; i >= 0; i-- {
parsed, err := netip.ParseAddr(strings.TrimSpace(ips[i]))
if err == nil {
return parsed.StringExpanded()
}
}
}
}
return e.RemoteIP()
}
// HasSuperuserAuth checks whether the current RequestEvent has superuser authentication loaded.
func (e *RequestEvent) HasSuperuserAuth() bool {
return e.Auth != nil && e.Auth.IsSuperuser()
}
// RequestInfo parses the current request into RequestInfo instance.
//
// Note that the returned result is cached to avoid copying the request data multiple times
// but the auth state and other common store items are always refreshed in case they were changed by another handler.
func (e *RequestEvent) RequestInfo() (*RequestInfo, error) {
e.mu.Lock()
defer e.mu.Unlock()
if e.cachedRequestInfo != nil {
e.cachedRequestInfo.Auth = e.Auth
infoCtx, _ := e.Get(RequestEventKeyInfoContext).(string)
if infoCtx != "" {
e.cachedRequestInfo.Context = infoCtx
} else {
e.cachedRequestInfo.Context = RequestInfoContextDefault
}
} else {
// (re)init e.cachedRequestInfo based on the current request event
if err := e.initRequestInfo(); err != nil {
return nil, err
}
}
return e.cachedRequestInfo, nil
}
func (e *RequestEvent) initRequestInfo() error {
infoCtx, _ := e.Get(RequestEventKeyInfoContext).(string)
if infoCtx == "" {
infoCtx = RequestInfoContextDefault
}
info := &RequestInfo{
Context: infoCtx,
Method: e.Request.Method,
Query: map[string]string{},
Headers: map[string]string{},
Body: map[string]any{},
}
if err := e.BindBody(&info.Body); err != nil {
return err
}
// extract the first value of all query params
query := e.Request.URL.Query()
for k, v := range query {
if len(v) > 0 {
info.Query[k] = v[0]
}
}
// extract the first value of all headers and normalizes the keys
// ("X-Token" is converted to "x_token")
for k, v := range e.Request.Header {
if len(v) > 0 {
info.Headers[inflector.Snakecase(k)] = v[0]
}
}
info.Auth = e.Auth
e.cachedRequestInfo = info
return nil
}
// -------------------------------------------------------------------
const (
RequestInfoContextDefault = "default"
RequestInfoContextExpand = "expand"
RequestInfoContextRealtime = "realtime"
RequestInfoContextProtectedFile = "protectedFile"
RequestInfoContextBatch = "batch"
RequestInfoContextOAuth2 = "oauth2"
RequestInfoContextOTP = "otp"
RequestInfoContextPasswordAuth = "password"
)
// RequestInfo defines a HTTP request data struct, usually used
// as part of the `@request.*` filter resolver.
//
// The Query and Headers fields contains only the first value for each found entry.
type RequestInfo struct {
Query map[string]string `json:"query"`
Headers map[string]string `json:"headers"`
Body map[string]any `json:"body"`
Auth *Record `json:"auth"`
Method string `json:"method"`
Context string `json:"context"`
}
// HasSuperuserAuth checks whether the current RequestInfo instance
// has superuser authentication loaded.
func (info *RequestInfo) HasSuperuserAuth() bool {
return info.Auth != nil && info.Auth.IsSuperuser()
}
// Clone creates a new shallow copy of the current RequestInfo and its Auth record (if any).
func (info *RequestInfo) Clone() *RequestInfo {
clone := &RequestInfo{
Method: info.Method,
Context: info.Context,
Query: maps.Clone(info.Query),
Body: maps.Clone(info.Body),
Headers: maps.Clone(info.Headers),
}
if info.Auth != nil {
clone.Auth = info.Auth.Fresh()
}
return clone
}

View file

@ -0,0 +1,33 @@
package core
import (
"net/http"
validation "github.com/go-ozzo/ozzo-validation/v4"
"github.com/pocketbase/pocketbase/tools/hook"
)
type BatchRequestEvent struct {
hook.Event
*RequestEvent
Batch []*InternalRequest
}
type InternalRequest struct {
// note: for uploading files the value must be either *filesystem.File or []*filesystem.File
Body map[string]any `form:"body" json:"body"`
Headers map[string]string `form:"headers" json:"headers"`
Method string `form:"method" json:"method"`
URL string `form:"url" json:"url"`
}
func (br InternalRequest) Validate() error {
return validation.ValidateStruct(&br,
validation.Field(&br.Method, validation.Required, validation.In(http.MethodGet, http.MethodPost, http.MethodPut, http.MethodPatch, http.MethodDelete)),
validation.Field(&br.URL, validation.Required, validation.Length(0, 2000)),
)
}

View file

@ -0,0 +1,74 @@
package core_test
import (
"net/http"
"strings"
"testing"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tests"
)
func TestInternalRequestValidate(t *testing.T) {
scenarios := []struct {
name string
request core.InternalRequest
expectedErrors []string
}{
{
"empty struct",
core.InternalRequest{},
[]string{"method", "url"},
},
// method
{
"GET method",
core.InternalRequest{URL: "test", Method: http.MethodGet},
[]string{},
},
{
"POST method",
core.InternalRequest{URL: "test", Method: http.MethodPost},
[]string{},
},
{
"PUT method",
core.InternalRequest{URL: "test", Method: http.MethodPut},
[]string{},
},
{
"PATCH method",
core.InternalRequest{URL: "test", Method: http.MethodPatch},
[]string{},
},
{
"DELETE method",
core.InternalRequest{URL: "test", Method: http.MethodDelete},
[]string{},
},
{
"unknown method",
core.InternalRequest{URL: "test", Method: "unknown"},
[]string{"method"},
},
// url
{
"url <= 2000",
core.InternalRequest{URL: strings.Repeat("a", 2000), Method: http.MethodGet},
[]string{},
},
{
"url > 2000",
core.InternalRequest{URL: strings.Repeat("a", 2001), Method: http.MethodGet},
[]string{"url"},
},
}
for _, s := range scenarios {
t.Run(s.name, func(t *testing.T) {
tests.TestValidationErrors(t, s.request.Validate(), s.expectedErrors)
})
}
}

334
core/event_request_test.go Normal file
View file

@ -0,0 +1,334 @@
package core_test
import (
"encoding/json"
"net/http"
"strings"
"testing"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tests"
)
func TestEventRequestRealIP(t *testing.T) {
t.Parallel()
headers := map[string][]string{
"CF-Connecting-IP": {"1.2.3.4", "1.1.1.1"},
"Fly-Client-IP": {"1.2.3.4", "1.1.1.2"},
"X-Real-IP": {"1.2.3.4", "1.1.1.3,1.1.1.4"},
"X-Forwarded-For": {"1.2.3.4", "invalid,1.1.1.5,1.1.1.6,invalid"},
}
scenarios := []struct {
name string
headers map[string][]string
trustedHeaders []string
useLeftmostIP bool
expected string
}{
{
"no trusted headers",
headers,
nil,
false,
"127.0.0.1",
},
{
"non-matching trusted header",
headers,
[]string{"header1", "header2"},
false,
"127.0.0.1",
},
{
"trusted X-Real-IP (rightmost)",
headers,
[]string{"header1", "x-real-ip", "x-forwarded-for"},
false,
"1.1.1.4",
},
{
"trusted X-Real-IP (leftmost)",
headers,
[]string{"header1", "x-real-ip", "x-forwarded-for"},
true,
"1.1.1.3",
},
{
"trusted X-Forwarded-For (rightmost)",
headers,
[]string{"header1", "x-forwarded-for"},
false,
"1.1.1.6",
},
{
"trusted X-Forwarded-For (leftmost)",
headers,
[]string{"header1", "x-forwarded-for"},
true,
"1.1.1.5",
},
}
for _, s := range scenarios {
t.Run(s.name, func(t *testing.T) {
app, err := tests.NewTestApp()
if err != nil {
t.Fatal(err)
}
defer app.Cleanup()
app.Settings().TrustedProxy.Headers = s.trustedHeaders
app.Settings().TrustedProxy.UseLeftmostIP = s.useLeftmostIP
event := core.RequestEvent{}
event.App = app
event.Request, err = http.NewRequest(http.MethodGet, "/", nil)
if err != nil {
t.Fatal(err)
}
event.Request.RemoteAddr = "127.0.0.1:80" // fallback
for k, values := range s.headers {
for _, v := range values {
event.Request.Header.Add(k, v)
}
}
result := event.RealIP()
if result != s.expected {
t.Fatalf("Expected ip %q, got %q", s.expected, result)
}
})
}
}
func TestEventRequestHasSuperUserAuth(t *testing.T) {
t.Parallel()
app, _ := tests.NewTestApp()
defer app.Cleanup()
user, err := app.FindAuthRecordByEmail("users", "test@example.com")
if err != nil {
t.Fatal(err)
}
superuser, err := app.FindAuthRecordByEmail(core.CollectionNameSuperusers, "test@example.com")
if err != nil {
t.Fatal(err)
}
scenarios := []struct {
name string
record *core.Record
expected bool
}{
{"nil record", nil, false},
{"regular user record", user, false},
{"superuser record", superuser, true},
}
for _, s := range scenarios {
t.Run(s.name, func(t *testing.T) {
e := core.RequestEvent{}
e.Auth = s.record
result := e.HasSuperuserAuth()
if result != s.expected {
t.Fatalf("Expected %v, got %v", s.expected, result)
}
})
}
}
func TestRequestEventRequestInfo(t *testing.T) {
t.Parallel()
app, _ := tests.NewTestApp()
defer app.Cleanup()
userCol, err := app.FindCollectionByNameOrId("users")
if err != nil {
t.Fatal(err)
}
user1 := core.NewRecord(userCol)
user1.Id = "user1"
user1.SetEmail("test1@example.com")
user2 := core.NewRecord(userCol)
user2.Id = "user2"
user2.SetEmail("test2@example.com")
testBody := `{"a":123,"b":"test"}`
event := core.RequestEvent{}
event.Request, err = http.NewRequest("POST", "/test?q1=123&q2=456", strings.NewReader(testBody))
if err != nil {
t.Fatal(err)
}
event.Request.Header.Add("content-type", "application/json")
event.Request.Header.Add("x-test", "test")
event.Set(core.RequestEventKeyInfoContext, "test")
event.Auth = user1
t.Run("init", func(t *testing.T) {
info, err := event.RequestInfo()
if err != nil {
t.Fatalf("Failed to resolve request info: %v", err)
}
raw, err := json.Marshal(info)
if err != nil {
t.Fatalf("Failed to serialize request info: %v", err)
}
rawStr := string(raw)
expected := `{"query":{"q1":"123","q2":"456"},"headers":{"content_type":"application/json","x_test":"test"},"body":{"a":123,"b":"test"},"auth":{"avatar":"","collectionId":"_pb_users_auth_","collectionName":"users","created":"","emailVisibility":false,"file":[],"id":"user1","name":"","rel":"","updated":"","username":"","verified":false},"method":"POST","context":"test"}`
if expected != rawStr {
t.Fatalf("Expected\n%v\ngot\n%v", expected, rawStr)
}
})
t.Run("change user and context", func(t *testing.T) {
event.Set(core.RequestEventKeyInfoContext, "test2")
event.Auth = user2
info, err := event.RequestInfo()
if err != nil {
t.Fatalf("Failed to resolve request info: %v", err)
}
raw, err := json.Marshal(info)
if err != nil {
t.Fatalf("Failed to serialize request info: %v", err)
}
rawStr := string(raw)
expected := `{"query":{"q1":"123","q2":"456"},"headers":{"content_type":"application/json","x_test":"test"},"body":{"a":123,"b":"test"},"auth":{"avatar":"","collectionId":"_pb_users_auth_","collectionName":"users","created":"","emailVisibility":false,"file":[],"id":"user2","name":"","rel":"","updated":"","username":"","verified":false},"method":"POST","context":"test2"}`
if expected != rawStr {
t.Fatalf("Expected\n%v\ngot\n%v", expected, rawStr)
}
})
}
func TestRequestInfoHasSuperuserAuth(t *testing.T) {
t.Parallel()
app, _ := tests.NewTestApp()
defer app.Cleanup()
user, err := app.FindAuthRecordByEmail("users", "test@example.com")
if err != nil {
t.Fatal(err)
}
superuser, err := app.FindAuthRecordByEmail(core.CollectionNameSuperusers, "test@example.com")
if err != nil {
t.Fatal(err)
}
event := core.RequestEvent{}
event.Request, err = http.NewRequest("POST", "/test?q1=123&q2=456", strings.NewReader(`{"a":123,"b":"test"}`))
if err != nil {
t.Fatal(err)
}
event.Request.Header.Add("content-type", "application/json")
scenarios := []struct {
name string
record *core.Record
expected bool
}{
{"nil record", nil, false},
{"regular user record", user, false},
{"superuser record", superuser, true},
}
for _, s := range scenarios {
t.Run(s.name, func(t *testing.T) {
event.Auth = s.record
info, err := event.RequestInfo()
if err != nil {
t.Fatalf("Failed to resolve request info: %v", err)
}
result := info.HasSuperuserAuth()
if result != s.expected {
t.Fatalf("Expected %v, got %v", s.expected, result)
}
})
}
}
func TestRequestInfoClone(t *testing.T) {
t.Parallel()
app, _ := tests.NewTestApp()
defer app.Cleanup()
userCol, err := app.FindCollectionByNameOrId("users")
if err != nil {
t.Fatal(err)
}
user := core.NewRecord(userCol)
user.Id = "user1"
user.SetEmail("test1@example.com")
event := core.RequestEvent{}
event.Request, err = http.NewRequest("POST", "/test?q1=123&q2=456", strings.NewReader(`{"a":123,"b":"test"}`))
if err != nil {
t.Fatal(err)
}
event.Request.Header.Add("content-type", "application/json")
event.Auth = user
info, err := event.RequestInfo()
if err != nil {
t.Fatalf("Failed to resolve request info: %v", err)
}
clone := info.Clone()
// modify the clone fields to ensure that it is a shallow copy
clone.Headers["new_header"] = "test"
clone.Query["new_query"] = "test"
clone.Body["new_body"] = "test"
clone.Auth.Id = "user2" // should be a Fresh copy of the record
// check the original data
// ---
originalRaw, err := json.Marshal(info)
if err != nil {
t.Fatalf("Failed to serialize original request info: %v", err)
}
originalRawStr := string(originalRaw)
expectedRawStr := `{"query":{"q1":"123","q2":"456"},"headers":{"content_type":"application/json"},"body":{"a":123,"b":"test"},"auth":{"avatar":"","collectionId":"_pb_users_auth_","collectionName":"users","created":"","emailVisibility":false,"file":[],"id":"user1","name":"","rel":"","updated":"","username":"","verified":false},"method":"POST","context":"default"}`
if expectedRawStr != originalRawStr {
t.Fatalf("Expected original info\n%v\ngot\n%v", expectedRawStr, originalRawStr)
}
// check the clone data
// ---
cloneRaw, err := json.Marshal(clone)
if err != nil {
t.Fatalf("Failed to serialize clone request info: %v", err)
}
cloneRawStr := string(cloneRaw)
expectedCloneStr := `{"query":{"new_query":"test","q1":"123","q2":"456"},"headers":{"content_type":"application/json","new_header":"test"},"body":{"a":123,"b":"test","new_body":"test"},"auth":{"avatar":"","collectionId":"_pb_users_auth_","collectionName":"users","created":"","emailVisibility":false,"file":[],"id":"user2","name":"","rel":"","updated":"","username":"","verified":false},"method":"POST","context":"default"}`
if expectedCloneStr != cloneRawStr {
t.Fatalf("Expected clone info\n%v\ngot\n%v", expectedCloneStr, cloneRawStr)
}
}

582
core/events.go Normal file
View file

@ -0,0 +1,582 @@
package core
import (
"context"
"net/http"
"time"
"github.com/pocketbase/pocketbase/tools/auth"
"github.com/pocketbase/pocketbase/tools/hook"
"github.com/pocketbase/pocketbase/tools/mailer"
"github.com/pocketbase/pocketbase/tools/router"
"github.com/pocketbase/pocketbase/tools/search"
"github.com/pocketbase/pocketbase/tools/subscriptions"
"golang.org/x/crypto/acme/autocert"
)
type HookTagger interface {
HookTags() []string
}
// -------------------------------------------------------------------
type baseModelEventData struct {
Model Model
}
func (e *baseModelEventData) Tags() []string {
if e.Model == nil {
return nil
}
if ht, ok := e.Model.(HookTagger); ok {
return ht.HookTags()
}
return []string{e.Model.TableName()}
}
// -------------------------------------------------------------------
type baseRecordEventData struct {
Record *Record
}
func (e *baseRecordEventData) Tags() []string {
if e.Record == nil {
return nil
}
return e.Record.HookTags()
}
// -------------------------------------------------------------------
type baseCollectionEventData struct {
Collection *Collection
}
func (e *baseCollectionEventData) Tags() []string {
if e.Collection == nil {
return nil
}
tags := make([]string, 0, 2)
if e.Collection.Id != "" {
tags = append(tags, e.Collection.Id)
}
if e.Collection.Name != "" {
tags = append(tags, e.Collection.Name)
}
return tags
}
// -------------------------------------------------------------------
// App events data
// -------------------------------------------------------------------
type BootstrapEvent struct {
hook.Event
App App
}
type TerminateEvent struct {
hook.Event
App App
IsRestart bool
}
type BackupEvent struct {
hook.Event
App App
Context context.Context
Name string // the name of the backup to create/restore.
Exclude []string // list of dir entries to exclude from the backup create/restore.
}
type ServeEvent struct {
hook.Event
App App
Router *router.Router[*RequestEvent]
Server *http.Server
CertManager *autocert.Manager
// InstallerFunc is the "installer" function that is called after
// successful server tcp bind but only if there is no explicit
// superuser record created yet.
//
// It runs in a separate goroutine and its default value is [apis.DefaultInstallerFunc].
//
// It receives a system superuser record as argument that you can use to generate
// a short-lived auth token (e.g. systemSuperuser.NewStaticAuthToken(30 * time.Minute))
// and concatenate it as query param for your installer page
// (if you are using the client-side SDKs, you can then load the
// token with pb.authStore.save(token) and perform any Web API request
// e.g. creating a new superuser).
//
// Set it to nil if you want to skip the installer.
InstallerFunc func(app App, systemSuperuser *Record, baseURL string) error
}
// -------------------------------------------------------------------
// Settings events data
// -------------------------------------------------------------------
type SettingsListRequestEvent struct {
hook.Event
*RequestEvent
Settings *Settings
}
type SettingsUpdateRequestEvent struct {
hook.Event
*RequestEvent
OldSettings *Settings
NewSettings *Settings
}
type SettingsReloadEvent struct {
hook.Event
App App
}
// -------------------------------------------------------------------
// Mailer events data
// -------------------------------------------------------------------
type MailerEvent struct {
hook.Event
App App
Mailer mailer.Mailer
Message *mailer.Message
}
type MailerRecordEvent struct {
MailerEvent
baseRecordEventData
Meta map[string]any
}
// -------------------------------------------------------------------
// Model events data
// -------------------------------------------------------------------
const (
ModelEventTypeCreate = "create"
ModelEventTypeUpdate = "update"
ModelEventTypeDelete = "delete"
ModelEventTypeValidate = "validate"
)
type ModelEvent struct {
hook.Event
App App
baseModelEventData
Context context.Context
// Could be any of the ModelEventType* constants, like:
// - create
// - update
// - delete
// - validate
Type string
}
type ModelErrorEvent struct {
Error error
ModelEvent
}
// -------------------------------------------------------------------
// Record events data
// -------------------------------------------------------------------
type RecordEvent struct {
hook.Event
App App
baseRecordEventData
Context context.Context
// Could be any of the ModelEventType* constants, like:
// - create
// - update
// - delete
// - validate
Type string
}
type RecordErrorEvent struct {
Error error
RecordEvent
}
func syncModelEventWithRecordEvent(me *ModelEvent, re *RecordEvent) {
me.App = re.App
me.Context = re.Context
me.Type = re.Type
// @todo enable if after profiling doesn't have significant impact
// skip for now to avoid excessive checks and assume that the
// Model and the Record fields still points to the same instance
//
// if _, ok := me.Model.(*Record); ok {
// me.Model = re.Record
// } else if proxy, ok := me.Model.(RecordProxy); ok {
// proxy.SetProxyRecord(re.Record)
// }
}
func syncRecordEventWithModelEvent(re *RecordEvent, me *ModelEvent) {
re.App = me.App
re.Context = me.Context
re.Type = me.Type
}
func newRecordEventFromModelEvent(me *ModelEvent) (*RecordEvent, bool) {
record, ok := me.Model.(*Record)
if !ok {
proxy, ok := me.Model.(RecordProxy)
if !ok {
return nil, false
}
record = proxy.ProxyRecord()
}
re := new(RecordEvent)
re.App = me.App
re.Context = me.Context
re.Type = me.Type
re.Record = record
return re, true
}
func newRecordErrorEventFromModelErrorEvent(me *ModelErrorEvent) (*RecordErrorEvent, bool) {
recordEvent, ok := newRecordEventFromModelEvent(&me.ModelEvent)
if !ok {
return nil, false
}
re := new(RecordErrorEvent)
re.RecordEvent = *recordEvent
re.Error = me.Error
return re, true
}
func syncModelErrorEventWithRecordErrorEvent(me *ModelErrorEvent, re *RecordErrorEvent) {
syncModelEventWithRecordEvent(&me.ModelEvent, &re.RecordEvent)
me.Error = re.Error
}
func syncRecordErrorEventWithModelErrorEvent(re *RecordErrorEvent, me *ModelErrorEvent) {
syncRecordEventWithModelEvent(&re.RecordEvent, &me.ModelEvent)
re.Error = me.Error
}
// -------------------------------------------------------------------
// Collection events data
// -------------------------------------------------------------------
type CollectionEvent struct {
hook.Event
App App
baseCollectionEventData
Context context.Context
// Could be any of the ModelEventType* constants, like:
// - create
// - update
// - delete
// - validate
Type string
}
type CollectionErrorEvent struct {
Error error
CollectionEvent
}
func syncModelEventWithCollectionEvent(me *ModelEvent, ce *CollectionEvent) {
me.App = ce.App
me.Context = ce.Context
me.Type = ce.Type
me.Model = ce.Collection
}
func syncCollectionEventWithModelEvent(ce *CollectionEvent, me *ModelEvent) {
ce.App = me.App
ce.Context = me.Context
ce.Type = me.Type
if c, ok := me.Model.(*Collection); ok {
ce.Collection = c
}
}
func newCollectionEventFromModelEvent(me *ModelEvent) (*CollectionEvent, bool) {
record, ok := me.Model.(*Collection)
if !ok {
return nil, false
}
ce := new(CollectionEvent)
ce.App = me.App
ce.Context = me.Context
ce.Type = me.Type
ce.Collection = record
return ce, true
}
func newCollectionErrorEventFromModelErrorEvent(me *ModelErrorEvent) (*CollectionErrorEvent, bool) {
collectionevent, ok := newCollectionEventFromModelEvent(&me.ModelEvent)
if !ok {
return nil, false
}
ce := new(CollectionErrorEvent)
ce.CollectionEvent = *collectionevent
ce.Error = me.Error
return ce, true
}
func syncModelErrorEventWithCollectionErrorEvent(me *ModelErrorEvent, ce *CollectionErrorEvent) {
syncModelEventWithCollectionEvent(&me.ModelEvent, &ce.CollectionEvent)
me.Error = ce.Error
}
func syncCollectionErrorEventWithModelErrorEvent(ce *CollectionErrorEvent, me *ModelErrorEvent) {
syncCollectionEventWithModelEvent(&ce.CollectionEvent, &me.ModelEvent)
ce.Error = me.Error
}
// -------------------------------------------------------------------
// File API events data
// -------------------------------------------------------------------
type FileTokenRequestEvent struct {
hook.Event
*RequestEvent
baseRecordEventData
Token string
}
type FileDownloadRequestEvent struct {
hook.Event
*RequestEvent
baseCollectionEventData
Record *Record
FileField *FileField
ServedPath string
ServedName string
}
// -------------------------------------------------------------------
// Collection API events data
// -------------------------------------------------------------------
type CollectionsListRequestEvent struct {
hook.Event
*RequestEvent
Collections []*Collection
Result *search.Result
}
type CollectionsImportRequestEvent struct {
hook.Event
*RequestEvent
CollectionsData []map[string]any
DeleteMissing bool
}
type CollectionRequestEvent struct {
hook.Event
*RequestEvent
baseCollectionEventData
}
// -------------------------------------------------------------------
// Realtime API events data
// -------------------------------------------------------------------
type RealtimeConnectRequestEvent struct {
hook.Event
*RequestEvent
Client subscriptions.Client
// note: modifying it after the connect has no effect
IdleTimeout time.Duration
}
type RealtimeMessageEvent struct {
hook.Event
*RequestEvent
Client subscriptions.Client
Message *subscriptions.Message
}
type RealtimeSubscribeRequestEvent struct {
hook.Event
*RequestEvent
Client subscriptions.Client
Subscriptions []string
}
// -------------------------------------------------------------------
// Record CRUD API events data
// -------------------------------------------------------------------
type RecordsListRequestEvent struct {
hook.Event
*RequestEvent
baseCollectionEventData
// @todo consider removing and maybe add as generic to the search.Result?
Records []*Record
Result *search.Result
}
type RecordRequestEvent struct {
hook.Event
*RequestEvent
baseCollectionEventData
Record *Record
}
type RecordEnrichEvent struct {
hook.Event
App App
baseRecordEventData
RequestInfo *RequestInfo
}
// -------------------------------------------------------------------
// Auth Record API events data
// -------------------------------------------------------------------
type RecordCreateOTPRequestEvent struct {
hook.Event
*RequestEvent
baseCollectionEventData
Record *Record
Password string
}
type RecordAuthWithOTPRequestEvent struct {
hook.Event
*RequestEvent
baseCollectionEventData
Record *Record
OTP *OTP
}
type RecordAuthRequestEvent struct {
hook.Event
*RequestEvent
baseCollectionEventData
Record *Record
Token string
Meta any
AuthMethod string
}
type RecordAuthWithPasswordRequestEvent struct {
hook.Event
*RequestEvent
baseCollectionEventData
Record *Record
Identity string
IdentityField string
Password string
}
type RecordAuthWithOAuth2RequestEvent struct {
hook.Event
*RequestEvent
baseCollectionEventData
ProviderName string
ProviderClient auth.Provider
Record *Record
OAuth2User *auth.AuthUser
CreateData map[string]any
IsNewRecord bool
}
type RecordAuthRefreshRequestEvent struct {
hook.Event
*RequestEvent
baseCollectionEventData
Record *Record
}
type RecordRequestPasswordResetRequestEvent struct {
hook.Event
*RequestEvent
baseCollectionEventData
Record *Record
}
type RecordConfirmPasswordResetRequestEvent struct {
hook.Event
*RequestEvent
baseCollectionEventData
Record *Record
}
type RecordRequestVerificationRequestEvent struct {
hook.Event
*RequestEvent
baseCollectionEventData
Record *Record
}
type RecordConfirmVerificationRequestEvent struct {
hook.Event
*RequestEvent
baseCollectionEventData
Record *Record
}
type RecordRequestEmailChangeRequestEvent struct {
hook.Event
*RequestEvent
baseCollectionEventData
Record *Record
NewEmail string
}
type RecordConfirmEmailChangeRequestEvent struct {
hook.Event
*RequestEvent
baseCollectionEventData
Record *Record
NewEmail string
}

140
core/external_auth_model.go Normal file
View file

@ -0,0 +1,140 @@
package core
import (
"context"
"errors"
validation "github.com/go-ozzo/ozzo-validation/v4"
"github.com/pocketbase/pocketbase/tools/auth"
"github.com/pocketbase/pocketbase/tools/hook"
"github.com/pocketbase/pocketbase/tools/types"
)
var (
_ Model = (*ExternalAuth)(nil)
_ PreValidator = (*ExternalAuth)(nil)
_ RecordProxy = (*ExternalAuth)(nil)
)
const CollectionNameExternalAuths = "_externalAuths"
// ExternalAuth defines a Record proxy for working with the externalAuths collection.
type ExternalAuth struct {
*Record
}
// NewExternalAuth instantiates and returns a new blank *ExternalAuth model.
//
// Example usage:
//
// ea := core.NewExternalAuth(app)
// ea.SetRecordRef(user.Id)
// ea.SetCollectionRef(user.Collection().Id)
// ea.SetProvider("google")
// ea.SetProviderId("...")
// app.Save(ea)
func NewExternalAuth(app App) *ExternalAuth {
m := &ExternalAuth{}
c, err := app.FindCachedCollectionByNameOrId(CollectionNameExternalAuths)
if err != nil {
// this is just to make tests easier since it is a system collection and it is expected to be always accessible
// (note: the loaded record is further checked on ExternalAuth.PreValidate())
c = NewBaseCollection("@__invalid__")
}
m.Record = NewRecord(c)
return m
}
// PreValidate implements the [PreValidator] interface and checks
// whether the proxy is properly loaded.
func (m *ExternalAuth) PreValidate(ctx context.Context, app App) error {
if m.Record == nil || m.Record.Collection().Name != CollectionNameExternalAuths {
return errors.New("missing or invalid ExternalAuth ProxyRecord")
}
return nil
}
// ProxyRecord returns the proxied Record model.
func (m *ExternalAuth) ProxyRecord() *Record {
return m.Record
}
// SetProxyRecord loads the specified record model into the current proxy.
func (m *ExternalAuth) SetProxyRecord(record *Record) {
m.Record = record
}
// CollectionRef returns the "collectionRef" field value.
func (m *ExternalAuth) CollectionRef() string {
return m.GetString("collectionRef")
}
// SetCollectionRef updates the "collectionRef" record field value.
func (m *ExternalAuth) SetCollectionRef(collectionId string) {
m.Set("collectionRef", collectionId)
}
// RecordRef returns the "recordRef" record field value.
func (m *ExternalAuth) RecordRef() string {
return m.GetString("recordRef")
}
// SetRecordRef updates the "recordRef" record field value.
func (m *ExternalAuth) SetRecordRef(recordId string) {
m.Set("recordRef", recordId)
}
// Provider returns the "provider" record field value.
func (m *ExternalAuth) Provider() string {
return m.GetString("provider")
}
// SetProvider updates the "provider" record field value.
func (m *ExternalAuth) SetProvider(provider string) {
m.Set("provider", provider)
}
// Provider returns the "providerId" record field value.
func (m *ExternalAuth) ProviderId() string {
return m.GetString("providerId")
}
// SetProvider updates the "providerId" record field value.
func (m *ExternalAuth) SetProviderId(providerId string) {
m.Set("providerId", providerId)
}
// Created returns the "created" record field value.
func (m *ExternalAuth) Created() types.DateTime {
return m.GetDateTime("created")
}
// Updated returns the "updated" record field value.
func (m *ExternalAuth) Updated() types.DateTime {
return m.GetDateTime("updated")
}
func (app *BaseApp) registerExternalAuthHooks() {
recordRefHooks[*ExternalAuth](app, CollectionNameExternalAuths, CollectionTypeAuth)
app.OnRecordValidate(CollectionNameExternalAuths).Bind(&hook.Handler[*RecordEvent]{
Func: func(e *RecordEvent) error {
providerNames := make([]any, 0, len(auth.Providers))
for name := range auth.Providers {
providerNames = append(providerNames, name)
}
provider := e.Record.GetString("provider")
if err := validation.Validate(provider, validation.Required, validation.In(providerNames...)); err != nil {
return validation.Errors{"provider": err}
}
return e.Next()
},
Priority: 99,
})
}

View file

@ -0,0 +1,310 @@
package core_test
import (
"fmt"
"testing"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tests"
"github.com/pocketbase/pocketbase/tools/types"
)
func TestNewExternalAuth(t *testing.T) {
t.Parallel()
app, _ := tests.NewTestApp()
defer app.Cleanup()
ea := core.NewExternalAuth(app)
if ea.Collection().Name != core.CollectionNameExternalAuths {
t.Fatalf("Expected record with %q collection, got %q", core.CollectionNameExternalAuths, ea.Collection().Name)
}
}
func TestExternalAuthProxyRecord(t *testing.T) {
t.Parallel()
record := core.NewRecord(core.NewBaseCollection("test"))
record.Id = "test_id"
ea := core.ExternalAuth{}
ea.SetProxyRecord(record)
if ea.ProxyRecord() == nil || ea.ProxyRecord().Id != record.Id {
t.Fatalf("Expected proxy record with id %q, got %v", record.Id, ea.ProxyRecord())
}
}
func TestExternalAuthRecordRef(t *testing.T) {
t.Parallel()
app, _ := tests.NewTestApp()
defer app.Cleanup()
ea := core.NewExternalAuth(app)
testValues := []string{"test_1", "test2", ""}
for i, testValue := range testValues {
t.Run(fmt.Sprintf("%d_%q", i, testValue), func(t *testing.T) {
ea.SetRecordRef(testValue)
if v := ea.RecordRef(); v != testValue {
t.Fatalf("Expected getter %q, got %q", testValue, v)
}
if v := ea.GetString("recordRef"); v != testValue {
t.Fatalf("Expected field value %q, got %q", testValue, v)
}
})
}
}
func TestExternalAuthCollectionRef(t *testing.T) {
t.Parallel()
app, _ := tests.NewTestApp()
defer app.Cleanup()
ea := core.NewExternalAuth(app)
testValues := []string{"test_1", "test2", ""}
for i, testValue := range testValues {
t.Run(fmt.Sprintf("%d_%q", i, testValue), func(t *testing.T) {
ea.SetCollectionRef(testValue)
if v := ea.CollectionRef(); v != testValue {
t.Fatalf("Expected getter %q, got %q", testValue, v)
}
if v := ea.GetString("collectionRef"); v != testValue {
t.Fatalf("Expected field value %q, got %q", testValue, v)
}
})
}
}
func TestExternalAuthProvider(t *testing.T) {
t.Parallel()
app, _ := tests.NewTestApp()
defer app.Cleanup()
ea := core.NewExternalAuth(app)
testValues := []string{"test_1", "test2", ""}
for i, testValue := range testValues {
t.Run(fmt.Sprintf("%d_%q", i, testValue), func(t *testing.T) {
ea.SetProvider(testValue)
if v := ea.Provider(); v != testValue {
t.Fatalf("Expected getter %q, got %q", testValue, v)
}
if v := ea.GetString("provider"); v != testValue {
t.Fatalf("Expected field value %q, got %q", testValue, v)
}
})
}
}
func TestExternalAuthProviderId(t *testing.T) {
t.Parallel()
app, _ := tests.NewTestApp()
defer app.Cleanup()
ea := core.NewExternalAuth(app)
testValues := []string{"test_1", "test2", ""}
for i, testValue := range testValues {
t.Run(fmt.Sprintf("%d_%q", i, testValue), func(t *testing.T) {
ea.SetProviderId(testValue)
if v := ea.ProviderId(); v != testValue {
t.Fatalf("Expected getter %q, got %q", testValue, v)
}
if v := ea.GetString("providerId"); v != testValue {
t.Fatalf("Expected field value %q, got %q", testValue, v)
}
})
}
}
func TestExternalAuthCreated(t *testing.T) {
t.Parallel()
app, _ := tests.NewTestApp()
defer app.Cleanup()
ea := core.NewExternalAuth(app)
if v := ea.Created().String(); v != "" {
t.Fatalf("Expected empty created, got %q", v)
}
now := types.NowDateTime()
ea.SetRaw("created", now)
if v := ea.Created().String(); v != now.String() {
t.Fatalf("Expected %q created, got %q", now.String(), v)
}
}
func TestExternalAuthUpdated(t *testing.T) {
t.Parallel()
app, _ := tests.NewTestApp()
defer app.Cleanup()
ea := core.NewExternalAuth(app)
if v := ea.Updated().String(); v != "" {
t.Fatalf("Expected empty updated, got %q", v)
}
now := types.NowDateTime()
ea.SetRaw("updated", now)
if v := ea.Updated().String(); v != now.String() {
t.Fatalf("Expected %q updated, got %q", now.String(), v)
}
}
func TestExternalAuthPreValidate(t *testing.T) {
t.Parallel()
app, _ := tests.NewTestApp()
defer app.Cleanup()
externalAuthsCol, err := app.FindCollectionByNameOrId(core.CollectionNameExternalAuths)
if err != nil {
t.Fatal(err)
}
user, err := app.FindAuthRecordByEmail("users", "test@example.com")
if err != nil {
t.Fatal(err)
}
t.Run("no proxy record", func(t *testing.T) {
externalAuth := &core.ExternalAuth{}
if err := app.Validate(externalAuth); err == nil {
t.Fatal("Expected collection validation error")
}
})
t.Run("non-ExternalAuth collection", func(t *testing.T) {
externalAuth := &core.ExternalAuth{}
externalAuth.SetProxyRecord(core.NewRecord(core.NewBaseCollection("invalid")))
externalAuth.SetRecordRef(user.Id)
externalAuth.SetCollectionRef(user.Collection().Id)
externalAuth.SetProvider("gitlab")
externalAuth.SetProviderId("test123")
if err := app.Validate(externalAuth); err == nil {
t.Fatal("Expected collection validation error")
}
})
t.Run("ExternalAuth collection", func(t *testing.T) {
externalAuth := &core.ExternalAuth{}
externalAuth.SetProxyRecord(core.NewRecord(externalAuthsCol))
externalAuth.SetRecordRef(user.Id)
externalAuth.SetCollectionRef(user.Collection().Id)
externalAuth.SetProvider("gitlab")
externalAuth.SetProviderId("test123")
if err := app.Validate(externalAuth); err != nil {
t.Fatalf("Expected nil validation error, got %v", err)
}
})
}
func TestExternalAuthValidateHook(t *testing.T) {
t.Parallel()
app, _ := tests.NewTestApp()
defer app.Cleanup()
user, err := app.FindAuthRecordByEmail("users", "test@example.com")
if err != nil {
t.Fatal(err)
}
demo1, err := app.FindRecordById("demo1", "84nmscqy84lsi1t")
if err != nil {
t.Fatal(err)
}
scenarios := []struct {
name string
externalAuth func() *core.ExternalAuth
expectErrors []string
}{
{
"empty",
func() *core.ExternalAuth {
return core.NewExternalAuth(app)
},
[]string{"collectionRef", "recordRef", "provider", "providerId"},
},
{
"non-auth collection",
func() *core.ExternalAuth {
ea := core.NewExternalAuth(app)
ea.SetCollectionRef(demo1.Collection().Id)
ea.SetRecordRef(demo1.Id)
ea.SetProvider("gitlab")
ea.SetProviderId("test123")
return ea
},
[]string{"collectionRef"},
},
{
"disabled provider",
func() *core.ExternalAuth {
ea := core.NewExternalAuth(app)
ea.SetCollectionRef(user.Collection().Id)
ea.SetRecordRef("missing")
ea.SetProvider("apple")
ea.SetProviderId("test123")
return ea
},
[]string{"recordRef"},
},
{
"missing record id",
func() *core.ExternalAuth {
ea := core.NewExternalAuth(app)
ea.SetCollectionRef(user.Collection().Id)
ea.SetRecordRef("missing")
ea.SetProvider("gitlab")
ea.SetProviderId("test123")
return ea
},
[]string{"recordRef"},
},
{
"valid ref",
func() *core.ExternalAuth {
ea := core.NewExternalAuth(app)
ea.SetCollectionRef(user.Collection().Id)
ea.SetRecordRef(user.Id)
ea.SetProvider("gitlab")
ea.SetProviderId("test123")
return ea
},
[]string{},
},
}
for _, s := range scenarios {
t.Run(s.name, func(t *testing.T) {
errs := app.Validate(s.externalAuth())
tests.TestValidationErrors(t, errs, s.expectErrors)
})
}
}

View file

@ -0,0 +1,61 @@
package core
import (
"github.com/pocketbase/dbx"
)
// FindAllExternalAuthsByRecord returns all ExternalAuth models
// linked to the provided auth record.
func (app *BaseApp) FindAllExternalAuthsByRecord(authRecord *Record) ([]*ExternalAuth, error) {
auths := []*ExternalAuth{}
err := app.RecordQuery(CollectionNameExternalAuths).
AndWhere(dbx.HashExp{
"collectionRef": authRecord.Collection().Id,
"recordRef": authRecord.Id,
}).
OrderBy("created DESC").
All(&auths)
if err != nil {
return nil, err
}
return auths, nil
}
// FindAllExternalAuthsByCollection returns all ExternalAuth models
// linked to the provided auth collection.
func (app *BaseApp) FindAllExternalAuthsByCollection(collection *Collection) ([]*ExternalAuth, error) {
auths := []*ExternalAuth{}
err := app.RecordQuery(CollectionNameExternalAuths).
AndWhere(dbx.HashExp{"collectionRef": collection.Id}).
OrderBy("created DESC").
All(&auths)
if err != nil {
return nil, err
}
return auths, nil
}
// FindFirstExternalAuthByExpr returns the first available (the most recent created)
// ExternalAuth model that satisfies the non-nil expression.
func (app *BaseApp) FindFirstExternalAuthByExpr(expr dbx.Expression) (*ExternalAuth, error) {
model := &ExternalAuth{}
err := app.RecordQuery(CollectionNameExternalAuths).
AndWhere(dbx.Not(dbx.HashExp{"providerId": ""})). // exclude empty providerIds
AndWhere(expr).
OrderBy("created DESC").
Limit(1).
One(model)
if err != nil {
return nil, err
}
return model, nil
}

View file

@ -0,0 +1,176 @@
package core_test
import (
"fmt"
"testing"
"github.com/pocketbase/dbx"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tests"
)
func TestFindAllExternalAuthsByRecord(t *testing.T) {
t.Parallel()
app, _ := tests.NewTestApp()
defer app.Cleanup()
demo1, err := app.FindRecordById("demo1", "84nmscqy84lsi1t")
if err != nil {
t.Fatal(err)
}
superuser1, err := app.FindAuthRecordByEmail(core.CollectionNameSuperusers, "test@example.com")
if err != nil {
t.Fatal(err)
}
user1, err := app.FindAuthRecordByEmail("users", "test@example.com")
if err != nil {
t.Fatal(err)
}
user2, err := app.FindAuthRecordByEmail("users", "test2@example.com")
if err != nil {
t.Fatal(err)
}
user3, err := app.FindAuthRecordByEmail("users", "test3@example.com")
if err != nil {
t.Fatal(err)
}
client1, err := app.FindAuthRecordByEmail("clients", "test@example.com")
if err != nil {
t.Fatal(err)
}
scenarios := []struct {
record *core.Record
expected []string
}{
{demo1, nil},
{superuser1, nil},
{client1, []string{"f1z5b3843pzc964"}},
{user1, []string{"clmflokuq1xl341", "dlmflokuq1xl342"}},
{user2, nil},
{user3, []string{"5eto7nmys833164"}},
}
for _, s := range scenarios {
t.Run(s.record.Collection().Name+"_"+s.record.Id, func(t *testing.T) {
result, err := app.FindAllExternalAuthsByRecord(s.record)
if err != nil {
t.Fatal(err)
}
if len(result) != len(s.expected) {
t.Fatalf("Expected total models %d, got %d", len(s.expected), len(result))
}
for i, id := range s.expected {
if result[i].Id != id {
t.Errorf("[%d] Expected id %q, got %q", i, id, result[i].Id)
}
}
})
}
}
func TestFindAllExternalAuthsByCollection(t *testing.T) {
t.Parallel()
app, _ := tests.NewTestApp()
defer app.Cleanup()
demo1, err := app.FindCollectionByNameOrId("demo1")
if err != nil {
t.Fatal(err)
}
superusers, err := app.FindCollectionByNameOrId(core.CollectionNameSuperusers)
if err != nil {
t.Fatal(err)
}
clients, err := app.FindCollectionByNameOrId("clients")
if err != nil {
t.Fatal(err)
}
users, err := app.FindCollectionByNameOrId("users")
if err != nil {
t.Fatal(err)
}
scenarios := []struct {
collection *core.Collection
expected []string
}{
{demo1, nil},
{superusers, nil},
{clients, []string{
"f1z5b3843pzc964",
}},
{users, []string{
"5eto7nmys833164",
"clmflokuq1xl341",
"dlmflokuq1xl342",
}},
}
for _, s := range scenarios {
t.Run(s.collection.Name, func(t *testing.T) {
result, err := app.FindAllExternalAuthsByCollection(s.collection)
if err != nil {
t.Fatal(err)
}
if len(result) != len(s.expected) {
t.Fatalf("Expected total models %d, got %d", len(s.expected), len(result))
}
for i, id := range s.expected {
if result[i].Id != id {
t.Errorf("[%d] Expected id %q, got %q", i, id, result[i].Id)
}
}
})
}
}
func TestFindFirstExternalAuthByExpr(t *testing.T) {
t.Parallel()
app, _ := tests.NewTestApp()
defer app.Cleanup()
scenarios := []struct {
expr dbx.Expression
expectedId string
}{
{dbx.HashExp{"collectionRef": "invalid"}, ""},
{dbx.HashExp{"collectionRef": "_pb_users_auth_"}, "5eto7nmys833164"},
{dbx.HashExp{"collectionRef": "_pb_users_auth_", "provider": "gitlab"}, "dlmflokuq1xl342"},
}
for i, s := range scenarios {
t.Run(fmt.Sprintf("%d_%v", i, s.expr.Build(app.ConcurrentDB().(*dbx.DB), dbx.Params{})), func(t *testing.T) {
result, err := app.FindFirstExternalAuthByExpr(s.expr)
hasErr := err != nil
expectErr := s.expectedId == ""
if hasErr != expectErr {
t.Fatalf("Expected hasErr %v, got %v", expectErr, hasErr)
}
if hasErr {
return
}
if result.Id != s.expectedId {
t.Errorf("Expected id %q, got %q", s.expectedId, result.Id)
}
})
}
}

252
core/field.go Normal file
View file

@ -0,0 +1,252 @@
package core
import (
"context"
"database/sql/driver"
"regexp"
"strings"
validation "github.com/go-ozzo/ozzo-validation/v4"
"github.com/pocketbase/pocketbase/core/validators"
"github.com/pocketbase/pocketbase/tools/list"
)
var fieldNameRegex = regexp.MustCompile(`^\w+$`)
const maxSafeJSONInt int64 = 1<<53 - 1
// Commonly used field names.
const (
FieldNameId = "id"
FieldNameCollectionId = "collectionId"
FieldNameCollectionName = "collectionName"
FieldNameExpand = "expand"
FieldNameEmail = "email"
FieldNameEmailVisibility = "emailVisibility"
FieldNameVerified = "verified"
FieldNameTokenKey = "tokenKey"
FieldNamePassword = "password"
)
// SystemFields returns special internal field names that are usually readonly.
var SystemDynamicFieldNames = []string{
FieldNameCollectionId,
FieldNameCollectionName,
FieldNameExpand,
}
// Common RecordInterceptor action names.
const (
InterceptorActionValidate = "validate"
InterceptorActionDelete = "delete"
InterceptorActionDeleteExecute = "deleteExecute"
InterceptorActionAfterDelete = "afterDelete"
InterceptorActionAfterDeleteError = "afterDeleteError"
InterceptorActionCreate = "create"
InterceptorActionCreateExecute = "createExecute"
InterceptorActionAfterCreate = "afterCreate"
InterceptorActionAfterCreateError = "afterCreateFailure"
InterceptorActionUpdate = "update"
InterceptorActionUpdateExecute = "updateExecute"
InterceptorActionAfterUpdate = "afterUpdate"
InterceptorActionAfterUpdateError = "afterUpdateError"
)
// Common field errors.
var (
ErrUnknownField = validation.NewError("validation_unknown_field", "Unknown or invalid field.")
ErrInvalidFieldValue = validation.NewError("validation_invalid_field_value", "Invalid field value.")
ErrMustBeSystemAndHidden = validation.NewError("validation_must_be_system_and_hidden", `The field must be marked as "System" and "Hidden".`)
ErrMustBeSystem = validation.NewError("validation_must_be_system", `The field must be marked as "System".`)
)
// FieldFactoryFunc defines a simple function to construct a specific Field instance.
type FieldFactoryFunc func() Field
// Fields holds all available collection fields.
var Fields = map[string]FieldFactoryFunc{}
// Field defines a common interface that all Collection fields should implement.
type Field interface {
// note: the getters has an explicit "Get" prefix to avoid conflicts with their related field members
// GetId returns the field id.
GetId() string
// SetId changes the field id.
SetId(id string)
// GetName returns the field name.
GetName() string
// SetName changes the field name.
SetName(name string)
// GetSystem returns the field system flag state.
GetSystem() bool
// SetSystem changes the field system flag state.
SetSystem(system bool)
// GetHidden returns the field hidden flag state.
GetHidden() bool
// SetHidden changes the field hidden flag state.
SetHidden(hidden bool)
// Type returns the unique type of the field.
Type() string
// ColumnType returns the DB column definition of the field.
ColumnType(app App) string
// PrepareValue returns a properly formatted field value based on the provided raw one.
//
// This method is also called on record construction to initialize its default field value.
PrepareValue(record *Record, raw any) (any, error)
// ValidateSettings validates the current field value associated with the provided record.
ValidateValue(ctx context.Context, app App, record *Record) error
// ValidateSettings validates the current field settings.
ValidateSettings(ctx context.Context, app App, collection *Collection) error
}
// MaxBodySizeCalculator defines an optional field interface for
// specifying the max size of a field value.
type MaxBodySizeCalculator interface {
// CalculateMaxBodySize returns the approximate max body size of a field value.
CalculateMaxBodySize() int64
}
type (
SetterFunc func(record *Record, raw any)
// SetterFinder defines a field interface for registering custom field value setters.
SetterFinder interface {
// FindSetter returns a single field value setter function
// by performing pattern-like field matching using the specified key.
//
// The key is usually just the field name but it could also
// contains "modifier" characters based on which you can perform custom set operations
// (ex. "users+" could be mapped to a function that will append new user to the existing field value).
//
// Return nil if you want to fallback to the default field value setter.
FindSetter(key string) SetterFunc
}
)
type (
GetterFunc func(record *Record) any
// GetterFinder defines a field interface for registering custom field value getters.
GetterFinder interface {
// FindGetter returns a single field value getter function
// by performing pattern-like field matching using the specified key.
//
// The key is usually just the field name but it could also
// contains "modifier" characters based on which you can perform custom get operations
// (ex. "description:excerpt" could be mapped to a function that will return an excerpt of the current field value).
//
// Return nil if you want to fallback to the default field value setter.
FindGetter(key string) GetterFunc
}
)
// DriverValuer defines a Field interface for exporting and formatting
// a field value for the database.
type DriverValuer interface {
// DriverValue exports a single field value for persistence in the database.
DriverValue(record *Record) (driver.Value, error)
}
// MultiValuer defines a field interface that every multi-valued (eg. with MaxSelect) field has.
type MultiValuer interface {
// IsMultiple checks whether the field is configured to support multiple or single values.
IsMultiple() bool
}
// RecordInterceptor defines a field interface for reacting to various
// Record related operations (create, delete, validate, etc.).
type RecordInterceptor interface {
// Interceptor is invoked when a specific record action occurs
// allowing you to perform extra validations and normalization
// (ex. uploading or deleting files).
//
// Note that users must call actionFunc() manually if they want to
// execute the specific record action.
Intercept(
ctx context.Context,
app App,
record *Record,
actionName string,
actionFunc func() error,
) error
}
// DefaultFieldIdValidationRule performs base validation on a field id value.
func DefaultFieldIdValidationRule(value any) error {
v, ok := value.(string)
if !ok {
return validators.ErrUnsupportedValueType
}
rules := []validation.Rule{
validation.Required,
validation.Length(1, 100),
}
for _, r := range rules {
if err := r.Validate(v); err != nil {
return err
}
}
return nil
}
// exclude special filter and system literals
var excludeNames = append([]any{
"null", "true", "false", "_rowid_",
}, list.ToInterfaceSlice(SystemDynamicFieldNames)...)
// DefaultFieldIdValidationRule performs base validation on a field name value.
func DefaultFieldNameValidationRule(value any) error {
v, ok := value.(string)
if !ok {
return validators.ErrUnsupportedValueType
}
rules := []validation.Rule{
validation.Required,
validation.Length(1, 100),
validation.Match(fieldNameRegex),
validation.NotIn(excludeNames...),
validation.By(checkForVia),
}
for _, r := range rules {
if err := r.Validate(v); err != nil {
return err
}
}
return nil
}
func checkForVia(value any) error {
v, _ := value.(string)
if v == "" {
return nil
}
if strings.Contains(strings.ToLower(v), "_via_") {
return validation.NewError("validation_found_via", `The value cannot contain "_via_".`)
}
return nil
}
func noopSetter(record *Record, raw any) {
// do nothing
}

215
core/field_autodate.go Normal file
View file

@ -0,0 +1,215 @@
package core
import (
"context"
validation "github.com/go-ozzo/ozzo-validation/v4"
"github.com/pocketbase/pocketbase/core/validators"
"github.com/pocketbase/pocketbase/tools/types"
)
func init() {
Fields[FieldTypeAutodate] = func() Field {
return &AutodateField{}
}
}
const FieldTypeAutodate = "autodate"
// used to keep track of the last set autodate value
const autodateLastKnownPrefix = internalCustomFieldKeyPrefix + "_last_autodate_"
var (
_ Field = (*AutodateField)(nil)
_ SetterFinder = (*AutodateField)(nil)
_ RecordInterceptor = (*AutodateField)(nil)
)
// AutodateField defines an "autodate" type field, aka.
// field which datetime value could be auto set on record create/update.
//
// This field is usually used for defining timestamp fields like "created" and "updated".
//
// Requires either both or at least one of the OnCreate or OnUpdate options to be set.
type AutodateField struct {
// Name (required) is the unique name of the field.
Name string `form:"name" json:"name"`
// Id is the unique stable field identifier.
//
// It is automatically generated from the name when adding to a collection FieldsList.
Id string `form:"id" json:"id"`
// System prevents the renaming and removal of the field.
System bool `form:"system" json:"system"`
// Hidden hides the field from the API response.
Hidden bool `form:"hidden" json:"hidden"`
// Presentable hints the Dashboard UI to use the underlying
// field record value in the relation preview label.
Presentable bool `form:"presentable" json:"presentable"`
// ---
// OnCreate auto sets the current datetime as field value on record create.
OnCreate bool `form:"onCreate" json:"onCreate"`
// OnUpdate auto sets the current datetime as field value on record update.
OnUpdate bool `form:"onUpdate" json:"onUpdate"`
}
// Type implements [Field.Type] interface method.
func (f *AutodateField) Type() string {
return FieldTypeAutodate
}
// GetId implements [Field.GetId] interface method.
func (f *AutodateField) GetId() string {
return f.Id
}
// SetId implements [Field.SetId] interface method.
func (f *AutodateField) SetId(id string) {
f.Id = id
}
// GetName implements [Field.GetName] interface method.
func (f *AutodateField) GetName() string {
return f.Name
}
// SetName implements [Field.SetName] interface method.
func (f *AutodateField) SetName(name string) {
f.Name = name
}
// GetSystem implements [Field.GetSystem] interface method.
func (f *AutodateField) GetSystem() bool {
return f.System
}
// SetSystem implements [Field.SetSystem] interface method.
func (f *AutodateField) SetSystem(system bool) {
f.System = system
}
// GetHidden implements [Field.GetHidden] interface method.
func (f *AutodateField) GetHidden() bool {
return f.Hidden
}
// SetHidden implements [Field.SetHidden] interface method.
func (f *AutodateField) SetHidden(hidden bool) {
f.Hidden = hidden
}
// ColumnType implements [Field.ColumnType] interface method.
func (f *AutodateField) ColumnType(app App) string {
return "TEXT DEFAULT '' NOT NULL" // note: sqlite doesn't allow adding new columns with non-constant defaults
}
// PrepareValue implements [Field.PrepareValue] interface method.
func (f *AutodateField) PrepareValue(record *Record, raw any) (any, error) {
val, _ := types.ParseDateTime(raw)
return val, nil
}
// ValidateValue implements [Field.ValidateValue] interface method.
func (f *AutodateField) ValidateValue(ctx context.Context, app App, record *Record) error {
return nil
}
// ValidateSettings implements [Field.ValidateSettings] interface method.
func (f *AutodateField) ValidateSettings(ctx context.Context, app App, collection *Collection) error {
oldOnCreate := f.OnCreate
oldOnUpdate := f.OnUpdate
oldCollection, _ := app.FindCollectionByNameOrId(collection.Id)
if oldCollection != nil {
oldField, ok := oldCollection.Fields.GetById(f.Id).(*AutodateField)
if ok && oldField != nil {
oldOnCreate = oldField.OnCreate
oldOnUpdate = oldField.OnUpdate
}
}
return validation.ValidateStruct(f,
validation.Field(&f.Id, validation.By(DefaultFieldIdValidationRule)),
validation.Field(&f.Name, validation.By(DefaultFieldNameValidationRule)),
validation.Field(
&f.OnCreate,
validation.When(f.System, validation.By(validators.Equal(oldOnCreate))),
validation.Required.Error("either onCreate or onUpdate must be enabled").When(!f.OnUpdate),
),
validation.Field(
&f.OnUpdate,
validation.When(f.System, validation.By(validators.Equal(oldOnUpdate))),
validation.Required.Error("either onCreate or onUpdate must be enabled").When(!f.OnCreate),
),
)
}
// FindSetter implements the [SetterFinder] interface.
func (f *AutodateField) FindSetter(key string) SetterFunc {
switch key {
case f.Name:
// return noopSetter to disallow updating the value with record.Set()
return noopSetter
default:
return nil
}
}
// Intercept implements the [RecordInterceptor] interface.
func (f *AutodateField) Intercept(
ctx context.Context,
app App,
record *Record,
actionName string,
actionFunc func() error,
) error {
switch actionName {
case InterceptorActionCreateExecute:
// ignore if a date different from the old one was manually set with SetRaw
if f.OnCreate && record.GetDateTime(f.Name).Equal(f.getLastKnownValue(record)) {
now := types.NowDateTime()
record.SetRaw(f.Name, now)
record.SetRaw(autodateLastKnownPrefix+f.Name, now) // eagerly set so that it can be renewed on resave after failure
}
if err := actionFunc(); err != nil {
return err
}
record.SetRaw(autodateLastKnownPrefix+f.Name, record.GetRaw(f.Name))
return nil
case InterceptorActionUpdateExecute:
// ignore if a date different from the old one was manually set with SetRaw
if f.OnUpdate && record.GetDateTime(f.Name).Equal(f.getLastKnownValue(record)) {
now := types.NowDateTime()
record.SetRaw(f.Name, now)
record.SetRaw(autodateLastKnownPrefix+f.Name, now) // eagerly set so that it can be renewed on resave after failure
}
if err := actionFunc(); err != nil {
return err
}
record.SetRaw(autodateLastKnownPrefix+f.Name, record.GetRaw(f.Name))
return nil
default:
return actionFunc()
}
}
func (f *AutodateField) getLastKnownValue(record *Record) types.DateTime {
v := record.GetDateTime(autodateLastKnownPrefix + f.Name)
if !v.IsZero() {
return v
}
return record.Original().GetDateTime(f.Name)
}

441
core/field_autodate_test.go Normal file
View file

@ -0,0 +1,441 @@
package core_test
import (
"context"
"errors"
"fmt"
"strings"
"testing"
"time"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tests"
"github.com/pocketbase/pocketbase/tools/hook"
"github.com/pocketbase/pocketbase/tools/types"
)
func TestAutodateFieldBaseMethods(t *testing.T) {
testFieldBaseMethods(t, core.FieldTypeAutodate)
}
func TestAutodateFieldColumnType(t *testing.T) {
app, _ := tests.NewTestApp()
defer app.Cleanup()
f := &core.AutodateField{}
expected := "TEXT DEFAULT '' NOT NULL"
if v := f.ColumnType(app); v != expected {
t.Fatalf("Expected\n%q\ngot\n%q", expected, v)
}
}
func TestAutodateFieldPrepareValue(t *testing.T) {
app, _ := tests.NewTestApp()
defer app.Cleanup()
f := &core.AutodateField{}
record := core.NewRecord(core.NewBaseCollection("test"))
scenarios := []struct {
raw any
expected string
}{
{"", ""},
{"invalid", ""},
{"2024-01-01 00:11:22.345Z", "2024-01-01 00:11:22.345Z"},
{time.Date(2024, 1, 2, 3, 4, 5, 0, time.UTC), "2024-01-02 03:04:05.000Z"},
}
for i, s := range scenarios {
t.Run(fmt.Sprintf("%d_%#v", i, s.raw), func(t *testing.T) {
v, err := f.PrepareValue(record, s.raw)
if err != nil {
t.Fatal(err)
}
vDate, ok := v.(types.DateTime)
if !ok {
t.Fatalf("Expected types.DateTime instance, got %T", v)
}
if vDate.String() != s.expected {
t.Fatalf("Expected %v, got %v", s.expected, v)
}
})
}
}
func TestAutodateFieldValidateValue(t *testing.T) {
app, _ := tests.NewTestApp()
defer app.Cleanup()
collection := core.NewBaseCollection("test_collection")
scenarios := []struct {
name string
field *core.AutodateField
record func() *core.Record
expectError bool
}{
{
"invalid raw value",
&core.AutodateField{Name: "test"},
func() *core.Record {
record := core.NewRecord(collection)
record.SetRaw("test", 123)
return record
},
false,
},
{
"missing field value",
&core.AutodateField{Name: "test"},
func() *core.Record {
record := core.NewRecord(collection)
record.SetRaw("abc", true)
return record
},
false,
},
{
"existing field value",
&core.AutodateField{Name: "test"},
func() *core.Record {
record := core.NewRecord(collection)
record.SetRaw("test", types.NowDateTime())
return record
},
false,
},
}
for _, s := range scenarios {
t.Run(s.name, func(t *testing.T) {
err := s.field.ValidateValue(context.Background(), app, s.record())
hasErr := err != nil
if hasErr != s.expectError {
t.Fatalf("Expected hasErr %v, got %v (%v)", s.expectError, hasErr, err)
}
})
}
}
func TestAutodateFieldValidateSettings(t *testing.T) {
testDefaultFieldIdValidation(t, core.FieldTypeAutodate)
testDefaultFieldNameValidation(t, core.FieldTypeAutodate)
app, _ := tests.NewTestApp()
defer app.Cleanup()
superusers, err := app.FindCollectionByNameOrId(core.CollectionNameSuperusers)
if err != nil {
t.Fatal(err)
}
scenarios := []struct {
name string
field func() *core.AutodateField
expectErrors []string
}{
{
"empty onCreate and onUpdate",
func() *core.AutodateField {
return &core.AutodateField{
Id: "test",
Name: "test",
}
},
[]string{"onCreate", "onUpdate"},
},
{
"with onCreate",
func() *core.AutodateField {
return &core.AutodateField{
Id: "test",
Name: "test",
OnCreate: true,
}
},
[]string{},
},
{
"with onUpdate",
func() *core.AutodateField {
return &core.AutodateField{
Id: "test",
Name: "test",
OnUpdate: true,
}
},
[]string{},
},
{
"change of a system autodate field",
func() *core.AutodateField {
created := superusers.Fields.GetByName("created").(*core.AutodateField)
created.OnCreate = !created.OnCreate
created.OnUpdate = !created.OnUpdate
return created
},
[]string{"onCreate", "onUpdate"},
},
}
for _, s := range scenarios {
t.Run(s.name, func(t *testing.T) {
errs := s.field().ValidateSettings(context.Background(), app, superusers)
tests.TestValidationErrors(t, errs, s.expectErrors)
})
}
}
func TestAutodateFieldFindSetter(t *testing.T) {
field := &core.AutodateField{Name: "test"}
collection := core.NewBaseCollection("test_collection")
collection.Fields.Add(field)
initialDate, err := types.ParseDateTime("2024-01-02 03:04:05.789Z")
if err != nil {
t.Fatal(err)
}
record := core.NewRecord(collection)
record.SetRaw("test", initialDate)
t.Run("no matching setter", func(t *testing.T) {
f := field.FindSetter("abc")
if f != nil {
t.Fatal("Expected nil setter")
}
})
t.Run("matching setter", func(t *testing.T) {
f := field.FindSetter("test")
if f == nil {
t.Fatal("Expected non-nil setter")
}
f(record, types.NowDateTime()) // should be ignored
if v := record.GetString("test"); v != "2024-01-02 03:04:05.789Z" {
t.Fatalf("Expected no value change, got %q", v)
}
})
}
func cutMilliseconds(datetime string) string {
if len(datetime) > 19 {
return datetime[:19]
}
return datetime
}
func TestAutodateFieldIntercept(t *testing.T) {
app, _ := tests.NewTestApp()
defer app.Cleanup()
initialDate, err := types.ParseDateTime("2024-01-02 03:04:05.789Z")
if err != nil {
t.Fatal(err)
}
collection := core.NewBaseCollection("test_collection")
scenarios := []struct {
name string
actionName string
field *core.AutodateField
record func() *core.Record
expected string
}{
{
"non-matching action",
"test",
&core.AutodateField{Name: "test", OnCreate: true, OnUpdate: true},
func() *core.Record {
return core.NewRecord(collection)
},
"",
},
{
"create with zero value (disabled onCreate)",
core.InterceptorActionCreateExecute,
&core.AutodateField{Name: "test", OnCreate: false, OnUpdate: true},
func() *core.Record {
return core.NewRecord(collection)
},
"",
},
{
"create with zero value",
core.InterceptorActionCreateExecute,
&core.AutodateField{Name: "test", OnCreate: true, OnUpdate: true},
func() *core.Record {
return core.NewRecord(collection)
},
"{NOW}",
},
{
"create with non-zero value",
core.InterceptorActionCreateExecute,
&core.AutodateField{Name: "test", OnCreate: true, OnUpdate: true},
func() *core.Record {
record := core.NewRecord(collection)
record.SetRaw("test", initialDate)
return record
},
initialDate.String(),
},
{
"update with zero value (disabled onUpdate)",
core.InterceptorActionUpdateExecute,
&core.AutodateField{Name: "test", OnCreate: true, OnUpdate: false},
func() *core.Record {
return core.NewRecord(collection)
},
"",
},
{
"update with zero value",
core.InterceptorActionUpdateExecute,
&core.AutodateField{Name: "test", OnCreate: true, OnUpdate: true},
func() *core.Record {
return core.NewRecord(collection)
},
"{NOW}",
},
{
"update with non-zero value",
core.InterceptorActionUpdateExecute,
&core.AutodateField{Name: "test", OnCreate: true, OnUpdate: true},
func() *core.Record {
record := core.NewRecord(collection)
record.SetRaw("test", initialDate)
return record
},
initialDate.String(),
},
}
for _, s := range scenarios {
t.Run(s.name, func(t *testing.T) {
actionCalls := 0
record := s.record()
now := types.NowDateTime().String()
err := s.field.Intercept(context.Background(), app, record, s.actionName, func() error {
actionCalls++
return nil
})
if err != nil {
t.Fatal(err)
}
if actionCalls != 1 {
t.Fatalf("Expected actionCalls %d, got %d", 1, actionCalls)
}
expected := cutMilliseconds(strings.ReplaceAll(s.expected, "{NOW}", now))
v := cutMilliseconds(record.GetString(s.field.GetName()))
if v != expected {
t.Fatalf("Expected value %q, got %q", expected, v)
}
})
}
}
func TestAutodateRecordResave(t *testing.T) {
app, _ := tests.NewTestApp()
defer app.Cleanup()
collection, err := app.FindCollectionByNameOrId("demo2")
if err != nil {
t.Fatal(err)
}
record, err := app.FindRecordById(collection, "llvuca81nly1qls")
if err != nil {
t.Fatal(err)
}
lastUpdated := record.GetDateTime("updated")
// save with autogenerated date
err = app.Save(record)
if err != nil {
t.Fatal(err)
}
newUpdated := record.GetDateTime("updated")
if newUpdated.Equal(lastUpdated) {
t.Fatalf("[0] Expected updated to change, got %v", newUpdated)
}
lastUpdated = newUpdated
// save with custom date
manualUpdated := lastUpdated.Add(-1 * time.Minute)
record.SetRaw("updated", manualUpdated)
err = app.Save(record)
if err != nil {
t.Fatal(err)
}
newUpdated = record.GetDateTime("updated")
if !newUpdated.Equal(manualUpdated) {
t.Fatalf("[1] Expected updated to be the manual set date %v, got %v", manualUpdated, newUpdated)
}
lastUpdated = newUpdated
// save again with autogenerated date
err = app.Save(record)
if err != nil {
t.Fatal(err)
}
newUpdated = record.GetDateTime("updated")
if newUpdated.Equal(lastUpdated) {
t.Fatalf("[2] Expected updated to change, got %v", newUpdated)
}
lastUpdated = newUpdated
// simulate save failure
app.OnRecordUpdateExecute(collection.Id).Bind(&hook.Handler[*core.RecordEvent]{
Id: "test_failure",
Func: func(*core.RecordEvent) error {
return errors.New("test")
},
Priority: 9999999999, // as latest as possible
})
// save again with autogenerated date (should fail)
err = app.Save(record)
if err == nil {
t.Fatal("Expected save failure")
}
// updated should still be set even after save failure
newUpdated = record.GetDateTime("updated")
if newUpdated.Equal(lastUpdated) {
t.Fatalf("[3] Expected updated to change, got %v", newUpdated)
}
lastUpdated = newUpdated
// cleanup the error and resave again
app.OnRecordUpdateExecute(collection.Id).Unbind("test_failure")
err = app.Save(record)
if err != nil {
t.Fatal(err)
}
newUpdated = record.GetDateTime("updated")
if newUpdated.Equal(lastUpdated) {
t.Fatalf("[4] Expected updated to change, got %v", newUpdated)
}
}

124
core/field_bool.go Normal file
View file

@ -0,0 +1,124 @@
package core
import (
"context"
validation "github.com/go-ozzo/ozzo-validation/v4"
"github.com/pocketbase/pocketbase/core/validators"
"github.com/spf13/cast"
)
func init() {
Fields[FieldTypeBool] = func() Field {
return &BoolField{}
}
}
const FieldTypeBool = "bool"
var _ Field = (*BoolField)(nil)
// BoolField defines "bool" type field to store a single true/false value.
//
// The respective zero record field value is false.
type BoolField struct {
// Name (required) is the unique name of the field.
Name string `form:"name" json:"name"`
// Id is the unique stable field identifier.
//
// It is automatically generated from the name when adding to a collection FieldsList.
Id string `form:"id" json:"id"`
// System prevents the renaming and removal of the field.
System bool `form:"system" json:"system"`
// Hidden hides the field from the API response.
Hidden bool `form:"hidden" json:"hidden"`
// Presentable hints the Dashboard UI to use the underlying
// field record value in the relation preview label.
Presentable bool `form:"presentable" json:"presentable"`
// ---
// Required will require the field value to be always "true".
Required bool `form:"required" json:"required"`
}
// Type implements [Field.Type] interface method.
func (f *BoolField) Type() string {
return FieldTypeBool
}
// GetId implements [Field.GetId] interface method.
func (f *BoolField) GetId() string {
return f.Id
}
// SetId implements [Field.SetId] interface method.
func (f *BoolField) SetId(id string) {
f.Id = id
}
// GetName implements [Field.GetName] interface method.
func (f *BoolField) GetName() string {
return f.Name
}
// SetName implements [Field.SetName] interface method.
func (f *BoolField) SetName(name string) {
f.Name = name
}
// GetSystem implements [Field.GetSystem] interface method.
func (f *BoolField) GetSystem() bool {
return f.System
}
// SetSystem implements [Field.SetSystem] interface method.
func (f *BoolField) SetSystem(system bool) {
f.System = system
}
// GetHidden implements [Field.GetHidden] interface method.
func (f *BoolField) GetHidden() bool {
return f.Hidden
}
// SetHidden implements [Field.SetHidden] interface method.
func (f *BoolField) SetHidden(hidden bool) {
f.Hidden = hidden
}
// ColumnType implements [Field.ColumnType] interface method.
func (f *BoolField) ColumnType(app App) string {
return "BOOLEAN DEFAULT FALSE NOT NULL"
}
// PrepareValue implements [Field.PrepareValue] interface method.
func (f *BoolField) PrepareValue(record *Record, raw any) (any, error) {
return cast.ToBool(raw), nil
}
// ValidateValue implements [Field.ValidateValue] interface method.
func (f *BoolField) ValidateValue(ctx context.Context, app App, record *Record) error {
v, ok := record.GetRaw(f.Name).(bool)
if !ok {
return validators.ErrUnsupportedValueType
}
if f.Required {
return validation.Required.Validate(v)
}
return nil
}
// ValidateSettings implements [Field.ValidateSettings] interface method.
func (f *BoolField) ValidateSettings(ctx context.Context, app App, collection *Collection) error {
return validation.ValidateStruct(f,
validation.Field(&f.Id, validation.By(DefaultFieldIdValidationRule)),
validation.Field(&f.Name, validation.By(DefaultFieldNameValidationRule)),
)
}

150
core/field_bool_test.go Normal file
View file

@ -0,0 +1,150 @@
package core_test
import (
"context"
"fmt"
"testing"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tests"
)
func TestBoolFieldBaseMethods(t *testing.T) {
testFieldBaseMethods(t, core.FieldTypeBool)
}
func TestBoolFieldColumnType(t *testing.T) {
app, _ := tests.NewTestApp()
defer app.Cleanup()
f := &core.BoolField{}
expected := "BOOLEAN DEFAULT FALSE NOT NULL"
if v := f.ColumnType(app); v != expected {
t.Fatalf("Expected\n%q\ngot\n%q", expected, v)
}
}
func TestBoolFieldPrepareValue(t *testing.T) {
app, _ := tests.NewTestApp()
defer app.Cleanup()
f := &core.BoolField{}
record := core.NewRecord(core.NewBaseCollection("test"))
scenarios := []struct {
raw any
expected bool
}{
{"", false},
{"f", false},
{"t", true},
{1, true},
{0, false},
}
for i, s := range scenarios {
t.Run(fmt.Sprintf("%d_%#v", i, s.raw), func(t *testing.T) {
v, err := f.PrepareValue(record, s.raw)
if err != nil {
t.Fatal(err)
}
if v != s.expected {
t.Fatalf("Expected %v, got %v", s.expected, v)
}
})
}
}
func TestBoolFieldValidateValue(t *testing.T) {
app, _ := tests.NewTestApp()
defer app.Cleanup()
collection := core.NewBaseCollection("test_collection")
scenarios := []struct {
name string
field *core.BoolField
record func() *core.Record
expectError bool
}{
{
"invalid raw value",
&core.BoolField{Name: "test"},
func() *core.Record {
record := core.NewRecord(collection)
record.SetRaw("test", 123)
return record
},
true,
},
{
"missing field value (non-required)",
&core.BoolField{Name: "test"},
func() *core.Record {
record := core.NewRecord(collection)
record.SetRaw("abc", true)
return record
},
true, // because of failed nil.(bool) cast
},
{
"missing field value (required)",
&core.BoolField{Name: "test", Required: true},
func() *core.Record {
record := core.NewRecord(collection)
record.SetRaw("abc", true)
return record
},
true,
},
{
"false field value (non-required)",
&core.BoolField{Name: "test"},
func() *core.Record {
record := core.NewRecord(collection)
record.SetRaw("test", false)
return record
},
false,
},
{
"false field value (required)",
&core.BoolField{Name: "test", Required: true},
func() *core.Record {
record := core.NewRecord(collection)
record.SetRaw("test", false)
return record
},
true,
},
{
"true field value (required)",
&core.BoolField{Name: "test", Required: true},
func() *core.Record {
record := core.NewRecord(collection)
record.SetRaw("test", true)
return record
},
false,
},
}
for _, s := range scenarios {
t.Run(s.name, func(t *testing.T) {
err := s.field.ValidateValue(context.Background(), app, s.record())
hasErr := err != nil
if hasErr != s.expectError {
t.Fatalf("Expected hasErr %v, got %v (%v)", s.expectError, hasErr, err)
}
})
}
}
func TestBoolFieldValidateSettings(t *testing.T) {
testDefaultFieldIdValidation(t, core.FieldTypeBool)
testDefaultFieldNameValidation(t, core.FieldTypeBool)
}

174
core/field_date.go Normal file
View file

@ -0,0 +1,174 @@
package core
import (
"context"
validation "github.com/go-ozzo/ozzo-validation/v4"
"github.com/pocketbase/pocketbase/core/validators"
"github.com/pocketbase/pocketbase/tools/types"
)
func init() {
Fields[FieldTypeDate] = func() Field {
return &DateField{}
}
}
const FieldTypeDate = "date"
var _ Field = (*DateField)(nil)
// DateField defines "date" type field to store a single [types.DateTime] value.
//
// The respective zero record field value is the zero [types.DateTime].
type DateField struct {
// Name (required) is the unique name of the field.
Name string `form:"name" json:"name"`
// Id is the unique stable field identifier.
//
// It is automatically generated from the name when adding to a collection FieldsList.
Id string `form:"id" json:"id"`
// System prevents the renaming and removal of the field.
System bool `form:"system" json:"system"`
// Hidden hides the field from the API response.
Hidden bool `form:"hidden" json:"hidden"`
// Presentable hints the Dashboard UI to use the underlying
// field record value in the relation preview label.
Presentable bool `form:"presentable" json:"presentable"`
// ---
// Min specifies the min allowed field value.
//
// Leave it empty to skip the validator.
Min types.DateTime `form:"min" json:"min"`
// Max specifies the max allowed field value.
//
// Leave it empty to skip the validator.
Max types.DateTime `form:"max" json:"max"`
// Required will require the field value to be non-zero [types.DateTime].
Required bool `form:"required" json:"required"`
}
// Type implements [Field.Type] interface method.
func (f *DateField) Type() string {
return FieldTypeDate
}
// GetId implements [Field.GetId] interface method.
func (f *DateField) GetId() string {
return f.Id
}
// SetId implements [Field.SetId] interface method.
func (f *DateField) SetId(id string) {
f.Id = id
}
// GetName implements [Field.GetName] interface method.
func (f *DateField) GetName() string {
return f.Name
}
// SetName implements [Field.SetName] interface method.
func (f *DateField) SetName(name string) {
f.Name = name
}
// GetSystem implements [Field.GetSystem] interface method.
func (f *DateField) GetSystem() bool {
return f.System
}
// SetSystem implements [Field.SetSystem] interface method.
func (f *DateField) SetSystem(system bool) {
f.System = system
}
// GetHidden implements [Field.GetHidden] interface method.
func (f *DateField) GetHidden() bool {
return f.Hidden
}
// SetHidden implements [Field.SetHidden] interface method.
func (f *DateField) SetHidden(hidden bool) {
f.Hidden = hidden
}
// ColumnType implements [Field.ColumnType] interface method.
func (f *DateField) ColumnType(app App) string {
return "TEXT DEFAULT '' NOT NULL"
}
// PrepareValue implements [Field.PrepareValue] interface method.
func (f *DateField) PrepareValue(record *Record, raw any) (any, error) {
// ignore scan errors since the format may change between versions
// and to allow running db adjusting migrations
val, _ := types.ParseDateTime(raw)
return val, nil
}
// ValidateValue implements [Field.ValidateValue] interface method.
func (f *DateField) ValidateValue(ctx context.Context, app App, record *Record) error {
val, ok := record.GetRaw(f.Name).(types.DateTime)
if !ok {
return validators.ErrUnsupportedValueType
}
if val.IsZero() {
if f.Required {
return validation.ErrRequired
}
return nil // nothing to check
}
if !f.Min.IsZero() {
if err := validation.Min(f.Min.Time()).Validate(val.Time()); err != nil {
return err
}
}
if !f.Max.IsZero() {
if err := validation.Max(f.Max.Time()).Validate(val.Time()); err != nil {
return err
}
}
return nil
}
// ValidateSettings implements [Field.ValidateSettings] interface method.
func (f *DateField) ValidateSettings(ctx context.Context, app App, collection *Collection) error {
return validation.ValidateStruct(f,
validation.Field(&f.Id, validation.By(DefaultFieldIdValidationRule)),
validation.Field(&f.Name, validation.By(DefaultFieldNameValidationRule)),
validation.Field(&f.Max, validation.By(f.checkRange(f.Min, f.Max))),
)
}
func (f *DateField) checkRange(min types.DateTime, max types.DateTime) validation.RuleFunc {
return func(value any) error {
v, _ := value.(types.DateTime)
if v.IsZero() {
return nil // nothing to check
}
dr := validation.Date(types.DefaultDateLayout)
if !min.IsZero() {
dr.Min(min.Time())
}
if !max.IsZero() {
dr.Max(max.Time())
}
return dr.Validate(v.String())
}
}

229
core/field_date_test.go Normal file
View file

@ -0,0 +1,229 @@
package core_test
import (
"context"
"fmt"
"testing"
"time"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tests"
"github.com/pocketbase/pocketbase/tools/types"
)
func TestDateFieldBaseMethods(t *testing.T) {
testFieldBaseMethods(t, core.FieldTypeDate)
}
func TestDateFieldColumnType(t *testing.T) {
app, _ := tests.NewTestApp()
defer app.Cleanup()
f := &core.DateField{}
expected := "TEXT DEFAULT '' NOT NULL"
if v := f.ColumnType(app); v != expected {
t.Fatalf("Expected\n%q\ngot\n%q", expected, v)
}
}
func TestDateFieldPrepareValue(t *testing.T) {
app, _ := tests.NewTestApp()
defer app.Cleanup()
f := &core.DateField{}
record := core.NewRecord(core.NewBaseCollection("test"))
scenarios := []struct {
raw any
expected string
}{
{"", ""},
{"invalid", ""},
{"2024-01-01 00:11:22.345Z", "2024-01-01 00:11:22.345Z"},
{time.Date(2024, 1, 2, 3, 4, 5, 0, time.UTC), "2024-01-02 03:04:05.000Z"},
}
for i, s := range scenarios {
t.Run(fmt.Sprintf("%d_%#v", i, s.raw), func(t *testing.T) {
v, err := f.PrepareValue(record, s.raw)
if err != nil {
t.Fatal(err)
}
vDate, ok := v.(types.DateTime)
if !ok {
t.Fatalf("Expected types.DateTime instance, got %T", v)
}
if vDate.String() != s.expected {
t.Fatalf("Expected %v, got %v", s.expected, v)
}
})
}
}
func TestDateFieldValidateValue(t *testing.T) {
app, _ := tests.NewTestApp()
defer app.Cleanup()
collection := core.NewBaseCollection("test_collection")
scenarios := []struct {
name string
field *core.DateField
record func() *core.Record
expectError bool
}{
{
"invalid raw value",
&core.DateField{Name: "test"},
func() *core.Record {
record := core.NewRecord(collection)
record.SetRaw("test", 123)
return record
},
true,
},
{
"zero field value (not required)",
&core.DateField{Name: "test"},
func() *core.Record {
record := core.NewRecord(collection)
record.SetRaw("test", types.DateTime{})
return record
},
false,
},
{
"zero field value (required)",
&core.DateField{Name: "test", Required: true},
func() *core.Record {
record := core.NewRecord(collection)
record.SetRaw("test", types.DateTime{})
return record
},
true,
},
{
"non-zero field value (required)",
&core.DateField{Name: "test", Required: true},
func() *core.Record {
record := core.NewRecord(collection)
record.SetRaw("test", types.NowDateTime())
return record
},
false,
},
}
for _, s := range scenarios {
t.Run(s.name, func(t *testing.T) {
err := s.field.ValidateValue(context.Background(), app, s.record())
hasErr := err != nil
if hasErr != s.expectError {
t.Fatalf("Expected hasErr %v, got %v (%v)", s.expectError, hasErr, err)
}
})
}
}
func TestDateFieldValidateSettings(t *testing.T) {
testDefaultFieldIdValidation(t, core.FieldTypeDate)
testDefaultFieldNameValidation(t, core.FieldTypeDate)
app, _ := tests.NewTestApp()
defer app.Cleanup()
collection := core.NewBaseCollection("test_collection")
scenarios := []struct {
name string
field func() *core.DateField
expectErrors []string
}{
{
"zero Min/Max",
func() *core.DateField {
return &core.DateField{
Id: "test",
Name: "test",
}
},
[]string{},
},
{
"non-empty Min with empty Max",
func() *core.DateField {
return &core.DateField{
Id: "test",
Name: "test",
Min: types.NowDateTime(),
}
},
[]string{},
},
{
"empty Min non-empty Max",
func() *core.DateField {
return &core.DateField{
Id: "test",
Name: "test",
Max: types.NowDateTime(),
}
},
[]string{},
},
{
"Min = Max",
func() *core.DateField {
date := types.NowDateTime()
return &core.DateField{
Id: "test",
Name: "test",
Min: date,
Max: date,
}
},
[]string{},
},
{
"Min > Max",
func() *core.DateField {
min := types.NowDateTime()
max := min.Add(-5 * time.Second)
return &core.DateField{
Id: "test",
Name: "test",
Min: min,
Max: max,
}
},
[]string{},
},
{
"Min < Max",
func() *core.DateField {
max := types.NowDateTime()
min := max.Add(-5 * time.Second)
return &core.DateField{
Id: "test",
Name: "test",
Min: min,
Max: max,
}
},
[]string{},
},
}
for _, s := range scenarios {
t.Run(s.name, func(t *testing.T) {
errs := s.field().ValidateSettings(context.Background(), app, collection)
tests.TestValidationErrors(t, errs, s.expectErrors)
})
}
}

162
core/field_editor.go Normal file
View file

@ -0,0 +1,162 @@
package core
import (
"context"
validation "github.com/go-ozzo/ozzo-validation/v4"
"github.com/pocketbase/pocketbase/core/validators"
"github.com/spf13/cast"
)
func init() {
Fields[FieldTypeEditor] = func() Field {
return &EditorField{}
}
}
const FieldTypeEditor = "editor"
const DefaultEditorFieldMaxSize int64 = 5 << 20
var (
_ Field = (*EditorField)(nil)
_ MaxBodySizeCalculator = (*EditorField)(nil)
)
// EditorField defines "editor" type field to store HTML formatted text.
//
// The respective zero record field value is empty string.
type EditorField struct {
// Name (required) is the unique name of the field.
Name string `form:"name" json:"name"`
// Id is the unique stable field identifier.
//
// It is automatically generated from the name when adding to a collection FieldsList.
Id string `form:"id" json:"id"`
// System prevents the renaming and removal of the field.
System bool `form:"system" json:"system"`
// Hidden hides the field from the API response.
Hidden bool `form:"hidden" json:"hidden"`
// Presentable hints the Dashboard UI to use the underlying
// field record value in the relation preview label.
Presentable bool `form:"presentable" json:"presentable"`
// ---
// MaxSize specifies the maximum size of the allowed field value (in bytes and up to 2^53-1).
//
// If zero, a default limit of ~5MB is applied.
MaxSize int64 `form:"maxSize" json:"maxSize"`
// ConvertURLs is usually used to instruct the editor whether to
// apply url conversion (eg. stripping the domain name in case the
// urls are using the same domain as the one where the editor is loaded).
//
// (see also https://www.tiny.cloud/docs/tinymce/6/url-handling/#convert_urls)
ConvertURLs bool `form:"convertURLs" json:"convertURLs"`
// Required will require the field value to be non-empty string.
Required bool `form:"required" json:"required"`
}
// Type implements [Field.Type] interface method.
func (f *EditorField) Type() string {
return FieldTypeEditor
}
// GetId implements [Field.GetId] interface method.
func (f *EditorField) GetId() string {
return f.Id
}
// SetId implements [Field.SetId] interface method.
func (f *EditorField) SetId(id string) {
f.Id = id
}
// GetName implements [Field.GetName] interface method.
func (f *EditorField) GetName() string {
return f.Name
}
// SetName implements [Field.SetName] interface method.
func (f *EditorField) SetName(name string) {
f.Name = name
}
// GetSystem implements [Field.GetSystem] interface method.
func (f *EditorField) GetSystem() bool {
return f.System
}
// SetSystem implements [Field.SetSystem] interface method.
func (f *EditorField) SetSystem(system bool) {
f.System = system
}
// GetHidden implements [Field.GetHidden] interface method.
func (f *EditorField) GetHidden() bool {
return f.Hidden
}
// SetHidden implements [Field.SetHidden] interface method.
func (f *EditorField) SetHidden(hidden bool) {
f.Hidden = hidden
}
// ColumnType implements [Field.ColumnType] interface method.
func (f *EditorField) ColumnType(app App) string {
return "TEXT DEFAULT '' NOT NULL"
}
// PrepareValue implements [Field.PrepareValue] interface method.
func (f *EditorField) PrepareValue(record *Record, raw any) (any, error) {
return cast.ToString(raw), nil
}
// ValidateValue implements [Field.ValidateValue] interface method.
func (f *EditorField) ValidateValue(ctx context.Context, app App, record *Record) error {
val, ok := record.GetRaw(f.Name).(string)
if !ok {
return validators.ErrUnsupportedValueType
}
if f.Required {
if err := validation.Required.Validate(val); err != nil {
return err
}
}
maxSize := f.CalculateMaxBodySize()
if int64(len(val)) > maxSize {
return validation.NewError(
"validation_content_size_limit",
"The maximum allowed content size is {{.maxSize}} bytes",
).SetParams(map[string]any{"maxSize": maxSize})
}
return nil
}
// ValidateSettings implements [Field.ValidateSettings] interface method.
func (f *EditorField) ValidateSettings(ctx context.Context, app App, collection *Collection) error {
return validation.ValidateStruct(f,
validation.Field(&f.Id, validation.By(DefaultFieldIdValidationRule)),
validation.Field(&f.Name, validation.By(DefaultFieldNameValidationRule)),
validation.Field(&f.MaxSize, validation.Min(0), validation.Max(maxSafeJSONInt)),
)
}
// CalculateMaxBodySize implements the [MaxBodySizeCalculator] interface.
func (f *EditorField) CalculateMaxBodySize() int64 {
if f.MaxSize <= 0 {
return DefaultEditorFieldMaxSize
}
return f.MaxSize
}

252
core/field_editor_test.go Normal file
View file

@ -0,0 +1,252 @@
package core_test
import (
"context"
"fmt"
"strings"
"testing"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tests"
)
func TestEditorFieldBaseMethods(t *testing.T) {
testFieldBaseMethods(t, core.FieldTypeEditor)
}
func TestEditorFieldColumnType(t *testing.T) {
app, _ := tests.NewTestApp()
defer app.Cleanup()
f := &core.EditorField{}
expected := "TEXT DEFAULT '' NOT NULL"
if v := f.ColumnType(app); v != expected {
t.Fatalf("Expected\n%q\ngot\n%q", expected, v)
}
}
func TestEditorFieldPrepareValue(t *testing.T) {
app, _ := tests.NewTestApp()
defer app.Cleanup()
f := &core.EditorField{}
record := core.NewRecord(core.NewBaseCollection("test"))
scenarios := []struct {
raw any
expected string
}{
{"", ""},
{"test", "test"},
{false, "false"},
{true, "true"},
{123.456, "123.456"},
}
for i, s := range scenarios {
t.Run(fmt.Sprintf("%d_%#v", i, s.raw), func(t *testing.T) {
v, err := f.PrepareValue(record, s.raw)
if err != nil {
t.Fatal(err)
}
vStr, ok := v.(string)
if !ok {
t.Fatalf("Expected string instance, got %T", v)
}
if vStr != s.expected {
t.Fatalf("Expected %q, got %q", s.expected, v)
}
})
}
}
func TestEditorFieldValidateValue(t *testing.T) {
app, _ := tests.NewTestApp()
defer app.Cleanup()
collection := core.NewBaseCollection("test_collection")
scenarios := []struct {
name string
field *core.EditorField
record func() *core.Record
expectError bool
}{
{
"invalid raw value",
&core.EditorField{Name: "test"},
func() *core.Record {
record := core.NewRecord(collection)
record.SetRaw("test", 123)
return record
},
true,
},
{
"zero field value (not required)",
&core.EditorField{Name: "test"},
func() *core.Record {
record := core.NewRecord(collection)
record.SetRaw("test", "")
return record
},
false,
},
{
"zero field value (required)",
&core.EditorField{Name: "test", Required: true},
func() *core.Record {
record := core.NewRecord(collection)
record.SetRaw("test", "")
return record
},
true,
},
{
"non-zero field value (required)",
&core.EditorField{Name: "test", Required: true},
func() *core.Record {
record := core.NewRecord(collection)
record.SetRaw("test", "abc")
return record
},
false,
},
{
"> default MaxSize",
&core.EditorField{Name: "test", Required: true},
func() *core.Record {
record := core.NewRecord(collection)
record.SetRaw("test", strings.Repeat("a", 1+(5<<20)))
return record
},
true,
},
{
"> MaxSize",
&core.EditorField{Name: "test", Required: true, MaxSize: 5},
func() *core.Record {
record := core.NewRecord(collection)
record.SetRaw("test", "abcdef")
return record
},
true,
},
{
"<= MaxSize",
&core.EditorField{Name: "test", Required: true, MaxSize: 5},
func() *core.Record {
record := core.NewRecord(collection)
record.SetRaw("test", "abcde")
return record
},
false,
},
}
for _, s := range scenarios {
t.Run(s.name, func(t *testing.T) {
err := s.field.ValidateValue(context.Background(), app, s.record())
hasErr := err != nil
if hasErr != s.expectError {
t.Fatalf("Expected hasErr %v, got %v (%v)", s.expectError, hasErr, err)
}
})
}
}
func TestEditorFieldValidateSettings(t *testing.T) {
testDefaultFieldIdValidation(t, core.FieldTypeEditor)
testDefaultFieldNameValidation(t, core.FieldTypeEditor)
app, _ := tests.NewTestApp()
defer app.Cleanup()
collection := core.NewBaseCollection("test_collection")
scenarios := []struct {
name string
field func() *core.EditorField
expectErrors []string
}{
{
"< 0 MaxSize",
func() *core.EditorField {
return &core.EditorField{
Id: "test",
Name: "test",
MaxSize: -1,
}
},
[]string{"maxSize"},
},
{
"= 0 MaxSize",
func() *core.EditorField {
return &core.EditorField{
Id: "test",
Name: "test",
}
},
[]string{},
},
{
"> 0 MaxSize",
func() *core.EditorField {
return &core.EditorField{
Id: "test",
Name: "test",
MaxSize: 1,
}
},
[]string{},
},
{
"MaxSize > safe json int",
func() *core.EditorField {
return &core.EditorField{
Id: "test",
Name: "test",
MaxSize: 1 << 53,
}
},
[]string{"maxSize"},
},
}
for _, s := range scenarios {
t.Run(s.name, func(t *testing.T) {
errs := s.field().ValidateSettings(context.Background(), app, collection)
tests.TestValidationErrors(t, errs, s.expectErrors)
})
}
}
func TestEditorFieldCalculateMaxBodySize(t *testing.T) {
testApp, _ := tests.NewTestApp()
defer testApp.Cleanup()
scenarios := []struct {
field *core.EditorField
expected int64
}{
{&core.EditorField{}, core.DefaultEditorFieldMaxSize},
{&core.EditorField{MaxSize: 10}, 10},
}
for i, s := range scenarios {
t.Run(fmt.Sprintf("%d_%d", i, s.field.MaxSize), func(t *testing.T) {
result := s.field.CalculateMaxBodySize()
if result != s.expected {
t.Fatalf("Expected %d, got %d", s.expected, result)
}
})
}
}

167
core/field_email.go Normal file
View file

@ -0,0 +1,167 @@
package core
import (
"context"
"slices"
"strings"
validation "github.com/go-ozzo/ozzo-validation/v4"
"github.com/go-ozzo/ozzo-validation/v4/is"
"github.com/pocketbase/pocketbase/core/validators"
"github.com/spf13/cast"
)
func init() {
Fields[FieldTypeEmail] = func() Field {
return &EmailField{}
}
}
const FieldTypeEmail = "email"
var _ Field = (*EmailField)(nil)
// EmailField defines "email" type field for storing a single email string address.
//
// The respective zero record field value is empty string.
type EmailField struct {
// Name (required) is the unique name of the field.
Name string `form:"name" json:"name"`
// Id is the unique stable field identifier.
//
// It is automatically generated from the name when adding to a collection FieldsList.
Id string `form:"id" json:"id"`
// System prevents the renaming and removal of the field.
System bool `form:"system" json:"system"`
// Hidden hides the field from the API response.
Hidden bool `form:"hidden" json:"hidden"`
// Presentable hints the Dashboard UI to use the underlying
// field record value in the relation preview label.
Presentable bool `form:"presentable" json:"presentable"`
// ---
// ExceptDomains will require the email domain to NOT be included in the listed ones.
//
// This validator can be set only if OnlyDomains is empty.
ExceptDomains []string `form:"exceptDomains" json:"exceptDomains"`
// OnlyDomains will require the email domain to be included in the listed ones.
//
// This validator can be set only if ExceptDomains is empty.
OnlyDomains []string `form:"onlyDomains" json:"onlyDomains"`
// Required will require the field value to be non-empty email string.
Required bool `form:"required" json:"required"`
}
// Type implements [Field.Type] interface method.
func (f *EmailField) Type() string {
return FieldTypeEmail
}
// GetId implements [Field.GetId] interface method.
func (f *EmailField) GetId() string {
return f.Id
}
// SetId implements [Field.SetId] interface method.
func (f *EmailField) SetId(id string) {
f.Id = id
}
// GetName implements [Field.GetName] interface method.
func (f *EmailField) GetName() string {
return f.Name
}
// SetName implements [Field.SetName] interface method.
func (f *EmailField) SetName(name string) {
f.Name = name
}
// GetSystem implements [Field.GetSystem] interface method.
func (f *EmailField) GetSystem() bool {
return f.System
}
// SetSystem implements [Field.SetSystem] interface method.
func (f *EmailField) SetSystem(system bool) {
f.System = system
}
// GetHidden implements [Field.GetHidden] interface method.
func (f *EmailField) GetHidden() bool {
return f.Hidden
}
// SetHidden implements [Field.SetHidden] interface method.
func (f *EmailField) SetHidden(hidden bool) {
f.Hidden = hidden
}
// ColumnType implements [Field.ColumnType] interface method.
func (f *EmailField) ColumnType(app App) string {
return "TEXT DEFAULT '' NOT NULL"
}
// PrepareValue implements [Field.PrepareValue] interface method.
func (f *EmailField) PrepareValue(record *Record, raw any) (any, error) {
return cast.ToString(raw), nil
}
// ValidateValue implements [Field.ValidateValue] interface method.
func (f *EmailField) ValidateValue(ctx context.Context, app App, record *Record) error {
val, ok := record.GetRaw(f.Name).(string)
if !ok {
return validators.ErrUnsupportedValueType
}
if f.Required {
if err := validation.Required.Validate(val); err != nil {
return err
}
}
if val == "" {
return nil // nothing to check
}
if err := is.EmailFormat.Validate(val); err != nil {
return err
}
domain := val[strings.LastIndex(val, "@")+1:]
// only domains check
if len(f.OnlyDomains) > 0 && !slices.Contains(f.OnlyDomains, domain) {
return validation.NewError("validation_email_domain_not_allowed", "Email domain is not allowed")
}
// except domains check
if len(f.ExceptDomains) > 0 && slices.Contains(f.ExceptDomains, domain) {
return validation.NewError("validation_email_domain_not_allowed", "Email domain is not allowed")
}
return nil
}
// ValidateSettings implements [Field.ValidateSettings] interface method.
func (f *EmailField) ValidateSettings(ctx context.Context, app App, collection *Collection) error {
return validation.ValidateStruct(f,
validation.Field(&f.Id, validation.By(DefaultFieldIdValidationRule)),
validation.Field(&f.Name, validation.By(DefaultFieldNameValidationRule)),
validation.Field(
&f.ExceptDomains,
validation.When(len(f.OnlyDomains) > 0, validation.Empty).Else(validation.Each(is.Domain)),
),
validation.Field(
&f.OnlyDomains,
validation.When(len(f.ExceptDomains) > 0, validation.Empty).Else(validation.Each(is.Domain)),
),
)
}

271
core/field_email_test.go Normal file
View file

@ -0,0 +1,271 @@
package core_test
import (
"context"
"fmt"
"testing"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tests"
)
func TestEmailFieldBaseMethods(t *testing.T) {
testFieldBaseMethods(t, core.FieldTypeEmail)
}
func TestEmailFieldColumnType(t *testing.T) {
app, _ := tests.NewTestApp()
defer app.Cleanup()
f := &core.EmailField{}
expected := "TEXT DEFAULT '' NOT NULL"
if v := f.ColumnType(app); v != expected {
t.Fatalf("Expected\n%q\ngot\n%q", expected, v)
}
}
func TestEmailFieldPrepareValue(t *testing.T) {
app, _ := tests.NewTestApp()
defer app.Cleanup()
f := &core.EmailField{}
record := core.NewRecord(core.NewBaseCollection("test"))
scenarios := []struct {
raw any
expected string
}{
{"", ""},
{"test", "test"},
{false, "false"},
{true, "true"},
{123.456, "123.456"},
}
for i, s := range scenarios {
t.Run(fmt.Sprintf("%d_%#v", i, s.raw), func(t *testing.T) {
v, err := f.PrepareValue(record, s.raw)
if err != nil {
t.Fatal(err)
}
vStr, ok := v.(string)
if !ok {
t.Fatalf("Expected string instance, got %T", v)
}
if vStr != s.expected {
t.Fatalf("Expected %q, got %q", s.expected, v)
}
})
}
}
func TestEmailFieldValidateValue(t *testing.T) {
app, _ := tests.NewTestApp()
defer app.Cleanup()
collection := core.NewBaseCollection("test_collection")
scenarios := []struct {
name string
field *core.EmailField
record func() *core.Record
expectError bool
}{
{
"invalid raw value",
&core.EmailField{Name: "test"},
func() *core.Record {
record := core.NewRecord(collection)
record.SetRaw("test", 123)
return record
},
true,
},
{
"zero field value (not required)",
&core.EmailField{Name: "test"},
func() *core.Record {
record := core.NewRecord(collection)
record.SetRaw("test", "")
return record
},
false,
},
{
"zero field value (required)",
&core.EmailField{Name: "test", Required: true},
func() *core.Record {
record := core.NewRecord(collection)
record.SetRaw("test", "")
return record
},
true,
},
{
"non-zero field value (required)",
&core.EmailField{Name: "test", Required: true},
func() *core.Record {
record := core.NewRecord(collection)
record.SetRaw("test", "test@example.com")
return record
},
false,
},
{
"invalid email",
&core.EmailField{Name: "test"},
func() *core.Record {
record := core.NewRecord(collection)
record.SetRaw("test", "invalid")
return record
},
true,
},
{
"failed onlyDomains",
&core.EmailField{Name: "test", OnlyDomains: []string{"example.org", "example.net"}},
func() *core.Record {
record := core.NewRecord(collection)
record.SetRaw("test", "test@example.com")
return record
},
true,
},
{
"success onlyDomains",
&core.EmailField{Name: "test", OnlyDomains: []string{"example.org", "example.com"}},
func() *core.Record {
record := core.NewRecord(collection)
record.SetRaw("test", "test@example.com")
return record
},
false,
},
{
"failed exceptDomains",
&core.EmailField{Name: "test", ExceptDomains: []string{"example.org", "example.com"}},
func() *core.Record {
record := core.NewRecord(collection)
record.SetRaw("test", "test@example.com")
return record
},
true,
},
{
"success exceptDomains",
&core.EmailField{Name: "test", ExceptDomains: []string{"example.org", "example.net"}},
func() *core.Record {
record := core.NewRecord(collection)
record.SetRaw("test", "test@example.com")
return record
},
false,
},
}
for _, s := range scenarios {
t.Run(s.name, func(t *testing.T) {
err := s.field.ValidateValue(context.Background(), app, s.record())
hasErr := err != nil
if hasErr != s.expectError {
t.Fatalf("Expected hasErr %v, got %v (%v)", s.expectError, hasErr, err)
}
})
}
}
func TestEmailFieldValidateSettings(t *testing.T) {
testDefaultFieldIdValidation(t, core.FieldTypeEmail)
testDefaultFieldNameValidation(t, core.FieldTypeEmail)
app, _ := tests.NewTestApp()
defer app.Cleanup()
collection := core.NewBaseCollection("test_collection")
scenarios := []struct {
name string
field func() *core.EmailField
expectErrors []string
}{
{
"zero minimal",
func() *core.EmailField {
return &core.EmailField{
Id: "test",
Name: "test",
}
},
[]string{},
},
{
"both onlyDomains and exceptDomains",
func() *core.EmailField {
return &core.EmailField{
Id: "test",
Name: "test",
OnlyDomains: []string{"example.com"},
ExceptDomains: []string{"example.org"},
}
},
[]string{"onlyDomains", "exceptDomains"},
},
{
"invalid onlyDomains",
func() *core.EmailField {
return &core.EmailField{
Id: "test",
Name: "test",
OnlyDomains: []string{"example.com", "invalid"},
}
},
[]string{"onlyDomains"},
},
{
"valid onlyDomains",
func() *core.EmailField {
return &core.EmailField{
Id: "test",
Name: "test",
OnlyDomains: []string{"example.com", "example.org"},
}
},
[]string{},
},
{
"invalid exceptDomains",
func() *core.EmailField {
return &core.EmailField{
Id: "test",
Name: "test",
ExceptDomains: []string{"example.com", "invalid"},
}
},
[]string{"exceptDomains"},
},
{
"valid exceptDomains",
func() *core.EmailField {
return &core.EmailField{
Id: "test",
Name: "test",
ExceptDomains: []string{"example.com", "example.org"},
}
},
[]string{},
},
}
for _, s := range scenarios {
t.Run(s.name, func(t *testing.T) {
errs := s.field().ValidateSettings(context.Background(), app, collection)
tests.TestValidationErrors(t, errs, s.expectErrors)
})
}
}

820
core/field_file.go Normal file
View file

@ -0,0 +1,820 @@
package core
import (
"context"
"database/sql/driver"
"errors"
"fmt"
"log"
"regexp"
"strings"
validation "github.com/go-ozzo/ozzo-validation/v4"
"github.com/pocketbase/pocketbase/core/validators"
"github.com/pocketbase/pocketbase/tools/filesystem"
"github.com/pocketbase/pocketbase/tools/list"
"github.com/pocketbase/pocketbase/tools/types"
"github.com/spf13/cast"
)
func init() {
Fields[FieldTypeFile] = func() Field {
return &FileField{}
}
}
const FieldTypeFile = "file"
const DefaultFileFieldMaxSize int64 = 5 << 20
var looseFilenameRegex = regexp.MustCompile(`^[^\./\\][^/\\]+$`)
const (
deletedFilesPrefix = internalCustomFieldKeyPrefix + "_deletedFilesPrefix_"
uploadedFilesPrefix = internalCustomFieldKeyPrefix + "_uploadedFilesPrefix_"
)
var (
_ Field = (*FileField)(nil)
_ MultiValuer = (*FileField)(nil)
_ DriverValuer = (*FileField)(nil)
_ GetterFinder = (*FileField)(nil)
_ SetterFinder = (*FileField)(nil)
_ RecordInterceptor = (*FileField)(nil)
_ MaxBodySizeCalculator = (*FileField)(nil)
)
// FileField defines "file" type field for managing record file(s).
//
// Only the file name is stored as part of the record value.
// New files (aka. files to upload) are expected to be of *filesytem.File.
//
// If MaxSelect is not set or <= 1, then the field value is expected to be a single record id.
//
// If MaxSelect is > 1, then the field value is expected to be a slice of record ids.
//
// The respective zero record field value is either empty string (single) or empty string slice (multiple).
//
// ---
//
// The following additional setter keys are available:
//
// - "fieldName+" - append one or more files to the existing record one. For example:
//
// // []string{"old1.txt", "old2.txt", "new1_ajkvass.txt", "new2_klhfnwd.txt"}
// record.Set("documents+", []*filesystem.File{new1, new2})
//
// - "+fieldName" - prepend one or more files to the existing record one. For example:
//
// // []string{"new1_ajkvass.txt", "new2_klhfnwd.txt", "old1.txt", "old2.txt",}
// record.Set("+documents", []*filesystem.File{new1, new2})
//
// - "fieldName-" - subtract/delete one or more files from the existing record one. For example:
//
// // []string{"old2.txt",}
// record.Set("documents-", "old1.txt")
type FileField struct {
// Name (required) is the unique name of the field.
Name string `form:"name" json:"name"`
// Id is the unique stable field identifier.
//
// It is automatically generated from the name when adding to a collection FieldsList.
Id string `form:"id" json:"id"`
// System prevents the renaming and removal of the field.
System bool `form:"system" json:"system"`
// Hidden hides the field from the API response.
Hidden bool `form:"hidden" json:"hidden"`
// Presentable hints the Dashboard UI to use the underlying
// field record value in the relation preview label.
Presentable bool `form:"presentable" json:"presentable"`
// ---
// MaxSize specifies the maximum size of a single uploaded file (in bytes and up to 2^53-1).
//
// If zero, a default limit of 5MB is applied.
MaxSize int64 `form:"maxSize" json:"maxSize"`
// MaxSelect specifies the max allowed files.
//
// For multiple files the value must be > 1, otherwise fallbacks to single (default).
MaxSelect int `form:"maxSelect" json:"maxSelect"`
// MimeTypes specifies an optional list of the allowed file mime types.
//
// Leave it empty to disable the validator.
MimeTypes []string `form:"mimeTypes" json:"mimeTypes"`
// Thumbs specifies an optional list of the supported thumbs for image based files.
//
// Each entry must be in one of the following formats:
//
// - WxH (eg. 100x300) - crop to WxH viewbox (from center)
// - WxHt (eg. 100x300t) - crop to WxH viewbox (from top)
// - WxHb (eg. 100x300b) - crop to WxH viewbox (from bottom)
// - WxHf (eg. 100x300f) - fit inside a WxH viewbox (without cropping)
// - 0xH (eg. 0x300) - resize to H height preserving the aspect ratio
// - Wx0 (eg. 100x0) - resize to W width preserving the aspect ratio
Thumbs []string `form:"thumbs" json:"thumbs"`
// Protected will require the users to provide a special file token to access the file.
//
// Note that by default all files are publicly accessible.
//
// For the majority of the cases this is fine because by default
// all file names have random part appended to their name which
// need to be known by the user before accessing the file.
Protected bool `form:"protected" json:"protected"`
// Required will require the field value to have at least one file.
Required bool `form:"required" json:"required"`
}
// Type implements [Field.Type] interface method.
func (f *FileField) Type() string {
return FieldTypeFile
}
// GetId implements [Field.GetId] interface method.
func (f *FileField) GetId() string {
return f.Id
}
// SetId implements [Field.SetId] interface method.
func (f *FileField) SetId(id string) {
f.Id = id
}
// GetName implements [Field.GetName] interface method.
func (f *FileField) GetName() string {
return f.Name
}
// SetName implements [Field.SetName] interface method.
func (f *FileField) SetName(name string) {
f.Name = name
}
// GetSystem implements [Field.GetSystem] interface method.
func (f *FileField) GetSystem() bool {
return f.System
}
// SetSystem implements [Field.SetSystem] interface method.
func (f *FileField) SetSystem(system bool) {
f.System = system
}
// GetHidden implements [Field.GetHidden] interface method.
func (f *FileField) GetHidden() bool {
return f.Hidden
}
// SetHidden implements [Field.SetHidden] interface method.
func (f *FileField) SetHidden(hidden bool) {
f.Hidden = hidden
}
// IsMultiple implements MultiValuer interface and checks whether the
// current field options support multiple values.
func (f *FileField) IsMultiple() bool {
return f.MaxSelect > 1
}
// ColumnType implements [Field.ColumnType] interface method.
func (f *FileField) ColumnType(app App) string {
if f.IsMultiple() {
return "JSON DEFAULT '[]' NOT NULL"
}
return "TEXT DEFAULT '' NOT NULL"
}
// PrepareValue implements [Field.PrepareValue] interface method.
func (f *FileField) PrepareValue(record *Record, raw any) (any, error) {
return f.normalizeValue(raw), nil
}
// DriverValue implements the [DriverValuer] interface.
func (f *FileField) DriverValue(record *Record) (driver.Value, error) {
files := f.toSliceValue(record.GetRaw(f.Name))
if f.IsMultiple() {
ja := make(types.JSONArray[string], len(files))
for i, v := range files {
ja[i] = f.getFileName(v)
}
return ja, nil
}
if len(files) == 0 {
return "", nil
}
return f.getFileName(files[len(files)-1]), nil
}
// ValidateSettings implements [Field.ValidateSettings] interface method.
func (f *FileField) ValidateSettings(ctx context.Context, app App, collection *Collection) error {
return validation.ValidateStruct(f,
validation.Field(&f.Id, validation.By(DefaultFieldIdValidationRule)),
validation.Field(&f.Name, validation.By(DefaultFieldNameValidationRule)),
validation.Field(&f.MaxSelect, validation.Min(0), validation.Max(maxSafeJSONInt)),
validation.Field(&f.MaxSize, validation.Min(0), validation.Max(maxSafeJSONInt)),
validation.Field(&f.Thumbs, validation.Each(
validation.NotIn("0x0", "0x0t", "0x0b", "0x0f"),
validation.Match(filesystem.ThumbSizeRegex),
)),
)
}
// ValidateValue implements [Field.ValidateValue] interface method.
func (f *FileField) ValidateValue(ctx context.Context, app App, record *Record) error {
files := f.toSliceValue(record.GetRaw(f.Name))
if len(files) == 0 {
if f.Required {
return validation.ErrRequired
}
return nil // nothing to check
}
// validate existing and disallow new plain string filenames submission
// (new files must be *filesystem.File)
// ---
oldExistingStrings := f.toSliceValue(f.getLatestOldValue(app, record))
existingStrings := list.ToInterfaceSlice(f.extractPlainStrings(files))
addedStrings := f.excludeFiles(existingStrings, oldExistingStrings)
if len(addedStrings) > 0 {
invalidFiles := make([]string, len(addedStrings))
for i, invalid := range addedStrings {
invalidStr := cast.ToString(invalid)
if len(invalidStr) > 250 {
invalidStr = invalidStr[:250]
}
invalidFiles[i] = invalidStr
}
return validation.NewError("validation_invalid_file", "Invalid new files: {{.invalidFiles}}.").
SetParams(map[string]any{"invalidFiles": invalidFiles})
}
maxSelect := f.maxSelect()
if len(files) > maxSelect {
return validation.NewError("validation_too_many_files", "The maximum allowed files is {{.maxSelect}}").
SetParams(map[string]any{"maxSelect": maxSelect})
}
// validate uploaded
// ---
uploads := f.extractUploadableFiles(files)
for _, upload := range uploads {
// loosely check the filename just in case it was manually changed after the normalization
err := validation.Length(1, 150).Validate(upload.Name)
if err != nil {
return err
}
err = validation.Match(looseFilenameRegex).Validate(upload.Name)
if err != nil {
return err
}
// check size
err = validators.UploadedFileSize(f.maxSize())(upload)
if err != nil {
return err
}
// check type
if len(f.MimeTypes) > 0 {
err = validators.UploadedFileMimeType(f.MimeTypes)(upload)
if err != nil {
return err
}
}
}
return nil
}
func (f *FileField) maxSize() int64 {
if f.MaxSize <= 0 {
return DefaultFileFieldMaxSize
}
return f.MaxSize
}
func (f *FileField) maxSelect() int {
if f.MaxSelect <= 1 {
return 1
}
return f.MaxSelect
}
// CalculateMaxBodySize implements the [MaxBodySizeCalculator] interface.
func (f *FileField) CalculateMaxBodySize() int64 {
return f.maxSize() * int64(f.maxSelect())
}
// Interceptors
// -------------------------------------------------------------------
// Intercept implements the [RecordInterceptor] interface.
//
// note: files delete after records deletion is handled globally by the app FileManager hook
func (f *FileField) Intercept(
ctx context.Context,
app App,
record *Record,
actionName string,
actionFunc func() error,
) error {
switch actionName {
case InterceptorActionCreateExecute, InterceptorActionUpdateExecute:
oldValue := f.getLatestOldValue(app, record)
err := f.processFilesToUpload(ctx, app, record)
if err != nil {
return err
}
err = actionFunc()
if err != nil {
return errors.Join(err, f.afterRecordExecuteFailure(newContextIfInvalid(ctx), app, record))
}
f.rememberFilesToDelete(app, record, oldValue)
f.afterRecordExecuteSuccess(newContextIfInvalid(ctx), app, record)
return nil
case InterceptorActionAfterCreateError, InterceptorActionAfterUpdateError:
// when in transaction we assume that the error was handled by afterRecordExecuteFailure
if app.IsTransactional() {
return actionFunc()
}
failedToDelete, deleteErr := f.deleteNewlyUploadedFiles(newContextIfInvalid(ctx), app, record)
if deleteErr != nil {
app.Logger().Warn(
"Failed to cleanup all new files after record commit failure",
"error", deleteErr,
"failedToDelete", failedToDelete,
)
}
record.SetRaw(deletedFilesPrefix+f.Name, nil)
if record.IsNew() {
// try to delete the record directory if there are no other files
//
// note: executed only on create failure to avoid accidentally
// deleting a concurrently updating directory due to the
// eventual consistent nature of some storage providers
err := f.deleteEmptyRecordDir(newContextIfInvalid(ctx), app, record)
if err != nil {
app.Logger().Warn("Failed to delete empty dir after new record commit failure", "error", err)
}
}
return actionFunc()
case InterceptorActionAfterCreate, InterceptorActionAfterUpdate:
record.SetRaw(uploadedFilesPrefix+f.Name, nil)
err := f.processFilesToDelete(ctx, app, record)
if err != nil {
return err
}
return actionFunc()
default:
return actionFunc()
}
}
func (f *FileField) getLatestOldValue(app App, record *Record) any {
if !record.IsNew() {
latestOriginal, err := app.FindRecordById(record.Collection(), cast.ToString(record.LastSavedPK()))
if err == nil {
return latestOriginal.GetRaw(f.Name)
}
}
return record.Original().GetRaw(f.Name)
}
func (f *FileField) afterRecordExecuteSuccess(ctx context.Context, app App, record *Record) {
uploaded, _ := record.GetRaw(uploadedFilesPrefix + f.Name).([]*filesystem.File)
// replace the uploaded file objects with their plain string names
newValue := f.toSliceValue(record.GetRaw(f.Name))
for i, v := range newValue {
if file, ok := v.(*filesystem.File); ok {
uploaded = append(uploaded, file)
newValue[i] = file.Name
}
}
f.setValue(record, newValue)
record.SetRaw(uploadedFilesPrefix+f.Name, uploaded)
}
func (f *FileField) afterRecordExecuteFailure(ctx context.Context, app App, record *Record) error {
uploaded := f.extractUploadableFiles(f.toSliceValue(record.GetRaw(f.Name)))
toDelete := make([]string, len(uploaded))
for i, file := range uploaded {
toDelete[i] = file.Name
}
// delete previously uploaded files
failedToDelete, deleteErr := f.deleteFilesByNamesList(ctx, app, record, list.ToUniqueStringSlice(toDelete))
if len(failedToDelete) > 0 {
app.Logger().Warn(
"Failed to cleanup the new uploaded file after record db write failure",
"error", deleteErr,
"failedToDelete", failedToDelete,
)
}
return deleteErr
}
func (f *FileField) deleteEmptyRecordDir(ctx context.Context, app App, record *Record) error {
fsys, err := app.NewFilesystem()
if err != nil {
return err
}
defer fsys.Close()
fsys.SetContext(newContextIfInvalid(ctx))
dir := record.BaseFilesPath()
if !fsys.IsEmptyDir(dir) {
return nil // no-op
}
err = fsys.Delete(dir)
if err != nil && !errors.Is(err, filesystem.ErrNotFound) {
return err
}
return nil
}
func (f *FileField) processFilesToDelete(ctx context.Context, app App, record *Record) error {
markedForDelete, _ := record.GetRaw(deletedFilesPrefix + f.Name).([]string)
if len(markedForDelete) == 0 {
return nil
}
old := list.ToInterfaceSlice(markedForDelete)
new := list.ToInterfaceSlice(f.extractPlainStrings(f.toSliceValue(record.GetRaw(f.Name))))
diff := f.excludeFiles(old, new)
toDelete := make([]string, len(diff))
for i, del := range diff {
toDelete[i] = f.getFileName(del)
}
failedToDelete, err := f.deleteFilesByNamesList(ctx, app, record, list.ToUniqueStringSlice(toDelete))
record.SetRaw(deletedFilesPrefix+f.Name, failedToDelete)
return err
}
func (f *FileField) rememberFilesToDelete(app App, record *Record, oldValue any) {
old := list.ToInterfaceSlice(f.extractPlainStrings(f.toSliceValue(oldValue)))
new := list.ToInterfaceSlice(f.extractPlainStrings(f.toSliceValue(record.GetRaw(f.Name))))
diff := f.excludeFiles(old, new)
toDelete, _ := record.GetRaw(deletedFilesPrefix + f.Name).([]string)
for _, del := range diff {
toDelete = append(toDelete, f.getFileName(del))
}
record.SetRaw(deletedFilesPrefix+f.Name, toDelete)
}
func (f *FileField) processFilesToUpload(ctx context.Context, app App, record *Record) error {
uploads := f.extractUploadableFiles(f.toSliceValue(record.GetRaw(f.Name)))
if len(uploads) == 0 {
return nil
}
if record.Id == "" {
return errors.New("uploading files requires the record to have a valid nonempty id")
}
fsys, err := app.NewFilesystem()
if err != nil {
return err
}
defer fsys.Close()
fsys.SetContext(ctx)
var failed []error // list of upload errors
var succeeded []string // list of uploaded file names
for _, upload := range uploads {
path := record.BaseFilesPath() + "/" + upload.Name
if err := fsys.UploadFile(upload, path); err == nil {
succeeded = append(succeeded, upload.Name)
} else {
failed = append(failed, fmt.Errorf("%q: %w", upload.Name, err))
break // for now stop on the first error since we currently don't allow partial uploads
}
}
if len(failed) > 0 {
// cleanup - try to delete the successfully uploaded files (if any)
_, cleanupErr := f.deleteFilesByNamesList(newContextIfInvalid(ctx), app, record, succeeded)
failed = append(failed, cleanupErr)
return fmt.Errorf("failed to upload all files: %w", errors.Join(failed...))
}
return nil
}
func (f *FileField) deleteNewlyUploadedFiles(ctx context.Context, app App, record *Record) ([]string, error) {
uploaded, _ := record.GetRaw(uploadedFilesPrefix + f.Name).([]*filesystem.File)
if len(uploaded) == 0 {
return nil, nil
}
names := make([]string, len(uploaded))
for i, file := range uploaded {
names[i] = file.Name
}
failed, err := f.deleteFilesByNamesList(ctx, app, record, list.ToUniqueStringSlice(names))
if err != nil {
return failed, err
}
record.SetRaw(uploadedFilesPrefix+f.Name, nil)
return nil, nil
}
// deleteFiles deletes a list of record files by their names.
// Returns the failed/remaining files.
func (f *FileField) deleteFilesByNamesList(ctx context.Context, app App, record *Record, filenames []string) ([]string, error) {
if len(filenames) == 0 {
return nil, nil // nothing to delete
}
if record.Id == "" {
return filenames, errors.New("the record doesn't have an id")
}
fsys, err := app.NewFilesystem()
if err != nil {
return filenames, err
}
defer fsys.Close()
fsys.SetContext(ctx)
var failures []error
for i := len(filenames) - 1; i >= 0; i-- {
filename := filenames[i]
if filename == "" || strings.ContainsAny(filename, "/\\") {
continue // empty or not a plain filename
}
path := record.BaseFilesPath() + "/" + filename
err := fsys.Delete(path)
if err != nil && !errors.Is(err, filesystem.ErrNotFound) {
// store the delete error
failures = append(failures, fmt.Errorf("file %d (%q): %w", i, filename, err))
} else {
// remove the deleted file from the list
filenames = append(filenames[:i], filenames[i+1:]...)
// try to delete the related file thumbs (if any)
thumbsErr := fsys.DeletePrefix(record.BaseFilesPath() + "/thumbs_" + filename + "/")
if len(thumbsErr) > 0 {
app.Logger().Warn("Failed to delete file thumbs", "error", errors.Join(thumbsErr...))
}
}
}
if len(failures) > 0 {
return filenames, fmt.Errorf("failed to delete all files: %w", errors.Join(failures...))
}
return nil, nil
}
// newContextIfInvalid returns a new Background context if the provided one was cancelled.
func newContextIfInvalid(ctx context.Context) context.Context {
if ctx.Err() == nil {
return ctx
}
return context.Background()
}
// -------------------------------------------------------------------
// FindGetter implements the [GetterFinder] interface.
func (f *FileField) FindGetter(key string) GetterFunc {
switch key {
case f.Name:
return func(record *Record) any {
return record.GetRaw(f.Name)
}
case f.Name + ":unsaved":
return func(record *Record) any {
return f.extractUploadableFiles(f.toSliceValue(record.GetRaw(f.Name)))
}
case f.Name + ":uploaded":
// deprecated
log.Println("[file field getter] please replace :uploaded with :unsaved")
return func(record *Record) any {
return f.extractUploadableFiles(f.toSliceValue(record.GetRaw(f.Name)))
}
default:
return nil
}
}
// -------------------------------------------------------------------
// FindSetter implements the [SetterFinder] interface.
func (f *FileField) FindSetter(key string) SetterFunc {
switch key {
case f.Name:
return f.setValue
case "+" + f.Name:
return f.prependValue
case f.Name + "+":
return f.appendValue
case f.Name + "-":
return f.subtractValue
default:
return nil
}
}
func (f *FileField) setValue(record *Record, raw any) {
val := f.normalizeValue(raw)
record.SetRaw(f.Name, val)
}
func (f *FileField) prependValue(record *Record, toPrepend any) {
files := f.toSliceValue(record.GetRaw(f.Name))
prepends := f.toSliceValue(toPrepend)
if len(prepends) > 0 {
files = append(prepends, files...)
}
f.setValue(record, files)
}
func (f *FileField) appendValue(record *Record, toAppend any) {
files := f.toSliceValue(record.GetRaw(f.Name))
appends := f.toSliceValue(toAppend)
if len(appends) > 0 {
files = append(files, appends...)
}
f.setValue(record, files)
}
func (f *FileField) subtractValue(record *Record, toRemove any) {
files := f.excludeFiles(
f.toSliceValue(record.GetRaw(f.Name)),
f.toSliceValue(toRemove),
)
f.setValue(record, files)
}
func (f *FileField) normalizeValue(raw any) any {
files := f.toSliceValue(raw)
if f.IsMultiple() {
return files
}
if len(files) > 0 {
return files[len(files)-1] // the last selected
}
return ""
}
func (f *FileField) toSliceValue(raw any) []any {
var result []any
switch value := raw.(type) {
case nil:
// nothing to cast
case *filesystem.File:
result = append(result, value)
case filesystem.File:
result = append(result, &value)
case []*filesystem.File:
for _, v := range value {
result = append(result, v)
}
case []filesystem.File:
for _, v := range value {
result = append(result, &v)
}
case []any:
for _, v := range value {
casted := f.toSliceValue(v)
if len(casted) == 1 {
result = append(result, casted[0])
}
}
default:
result = list.ToInterfaceSlice(list.ToUniqueStringSlice(value))
}
return f.uniqueFiles(result)
}
func (f *FileField) uniqueFiles(files []any) []any {
found := make(map[string]struct{}, len(files))
result := make([]any, 0, len(files))
for _, fv := range files {
name := f.getFileName(fv)
if _, ok := found[name]; !ok {
result = append(result, fv)
found[name] = struct{}{}
}
}
return result
}
func (f *FileField) extractPlainStrings(files []any) []string {
result := []string{}
for _, raw := range files {
if f, ok := raw.(string); ok {
result = append(result, f)
}
}
return result
}
func (f *FileField) extractUploadableFiles(files []any) []*filesystem.File {
result := []*filesystem.File{}
for _, raw := range files {
if upload, ok := raw.(*filesystem.File); ok {
result = append(result, upload)
}
}
return result
}
func (f *FileField) excludeFiles(base []any, toExclude []any) []any {
result := make([]any, 0, len(base))
SUBTRACT_LOOP:
for _, fv := range base {
for _, exclude := range toExclude {
if f.getFileName(exclude) == f.getFileName(fv) {
continue SUBTRACT_LOOP // skip
}
}
result = append(result, fv)
}
return result
}
func (f *FileField) getFileName(file any) string {
switch v := file.(type) {
case string:
return v
case *filesystem.File:
return v.Name
default:
return ""
}
}

1152
core/field_file_test.go Normal file

File diff suppressed because it is too large Load diff

148
core/field_geo_point.go Normal file
View file

@ -0,0 +1,148 @@
package core
import (
"context"
validation "github.com/go-ozzo/ozzo-validation/v4"
"github.com/pocketbase/pocketbase/core/validators"
"github.com/pocketbase/pocketbase/tools/types"
)
func init() {
Fields[FieldTypeGeoPoint] = func() Field {
return &GeoPointField{}
}
}
const FieldTypeGeoPoint = "geoPoint"
var (
_ Field = (*GeoPointField)(nil)
)
// GeoPointField defines "geoPoint" type field for storing latitude and longitude GPS coordinates.
//
// You can set the record field value as [types.GeoPoint], map or serialized json object with lat-lon props.
// The stored value is always converted to [types.GeoPoint].
// Nil, empty map, empty bytes slice, etc. results in zero [types.GeoPoint].
//
// Examples of updating a record's GeoPointField value programmatically:
//
// record.Set("location", types.GeoPoint{Lat: 123, Lon: 456})
// record.Set("location", map[string]any{"lat":123, "lon":456})
// record.Set("location", []byte(`{"lat":123, "lon":456}`)
type GeoPointField struct {
// Name (required) is the unique name of the field.
Name string `form:"name" json:"name"`
// Id is the unique stable field identifier.
//
// It is automatically generated from the name when adding to a collection FieldsList.
Id string `form:"id" json:"id"`
// System prevents the renaming and removal of the field.
System bool `form:"system" json:"system"`
// Hidden hides the field from the API response.
Hidden bool `form:"hidden" json:"hidden"`
// Presentable hints the Dashboard UI to use the underlying
// field record value in the relation preview label.
Presentable bool `form:"presentable" json:"presentable"`
// ---
// Required will require the field coordinates to be non-zero (aka. not "Null Island").
Required bool `form:"required" json:"required"`
}
// Type implements [Field.Type] interface method.
func (f *GeoPointField) Type() string {
return FieldTypeGeoPoint
}
// GetId implements [Field.GetId] interface method.
func (f *GeoPointField) GetId() string {
return f.Id
}
// SetId implements [Field.SetId] interface method.
func (f *GeoPointField) SetId(id string) {
f.Id = id
}
// GetName implements [Field.GetName] interface method.
func (f *GeoPointField) GetName() string {
return f.Name
}
// SetName implements [Field.SetName] interface method.
func (f *GeoPointField) SetName(name string) {
f.Name = name
}
// GetSystem implements [Field.GetSystem] interface method.
func (f *GeoPointField) GetSystem() bool {
return f.System
}
// SetSystem implements [Field.SetSystem] interface method.
func (f *GeoPointField) SetSystem(system bool) {
f.System = system
}
// GetHidden implements [Field.GetHidden] interface method.
func (f *GeoPointField) GetHidden() bool {
return f.Hidden
}
// SetHidden implements [Field.SetHidden] interface method.
func (f *GeoPointField) SetHidden(hidden bool) {
f.Hidden = hidden
}
// ColumnType implements [Field.ColumnType] interface method.
func (f *GeoPointField) ColumnType(app App) string {
return `JSON DEFAULT '{"lon":0,"lat":0}' NOT NULL`
}
// PrepareValue implements [Field.PrepareValue] interface method.
func (f *GeoPointField) PrepareValue(record *Record, raw any) (any, error) {
point := types.GeoPoint{}
err := point.Scan(raw)
return point, err
}
// ValidateValue implements [Field.ValidateValue] interface method.
func (f *GeoPointField) ValidateValue(ctx context.Context, app App, record *Record) error {
val, ok := record.GetRaw(f.Name).(types.GeoPoint)
if !ok {
return validators.ErrUnsupportedValueType
}
// zero value
if val.Lat == 0 && val.Lon == 0 {
if f.Required {
return validation.ErrRequired
}
return nil
}
if val.Lat < -90 || val.Lat > 90 {
return validation.NewError("validation_invalid_latitude", "Latitude must be between -90 and 90 degrees.")
}
if val.Lon < -180 || val.Lon > 180 {
return validation.NewError("validation_invalid_longitude", "Longitude must be between -180 and 180 degrees.")
}
return nil
}
// ValidateSettings implements [Field.ValidateSettings] interface method.
func (f *GeoPointField) ValidateSettings(ctx context.Context, app App, collection *Collection) error {
return validation.ValidateStruct(f,
validation.Field(&f.Id, validation.By(DefaultFieldIdValidationRule)),
validation.Field(&f.Name, validation.By(DefaultFieldNameValidationRule)),
)
}

View file

@ -0,0 +1,202 @@
package core_test
import (
"context"
"encoding/json"
"fmt"
"testing"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tests"
"github.com/pocketbase/pocketbase/tools/types"
)
func TestGeoPointFieldBaseMethods(t *testing.T) {
testFieldBaseMethods(t, core.FieldTypeGeoPoint)
}
func TestGeoPointFieldColumnType(t *testing.T) {
app, _ := tests.NewTestApp()
defer app.Cleanup()
f := &core.GeoPointField{}
expected := `JSON DEFAULT '{"lon":0,"lat":0}' NOT NULL`
if v := f.ColumnType(app); v != expected {
t.Fatalf("Expected\n%q\ngot\n%q", expected, v)
}
}
func TestGeoPointFieldPrepareValue(t *testing.T) {
app, _ := tests.NewTestApp()
defer app.Cleanup()
f := &core.GeoPointField{}
record := core.NewRecord(core.NewBaseCollection("test"))
scenarios := []struct {
raw any
expected string
}{
{nil, `{"lon":0,"lat":0}`},
{"", `{"lon":0,"lat":0}`},
{[]byte{}, `{"lon":0,"lat":0}`},
{map[string]any{}, `{"lon":0,"lat":0}`},
{types.GeoPoint{Lon: 10, Lat: 20}, `{"lon":10,"lat":20}`},
{&types.GeoPoint{Lon: 10, Lat: 20}, `{"lon":10,"lat":20}`},
{[]byte(`{"lon": 10, "lat": 20}`), `{"lon":10,"lat":20}`},
{map[string]any{"lon": 10, "lat": 20}, `{"lon":10,"lat":20}`},
{map[string]float64{"lon": 10, "lat": 20}, `{"lon":10,"lat":20}`},
}
for i, s := range scenarios {
t.Run(fmt.Sprintf("%d_%#v", i, s.raw), func(t *testing.T) {
v, err := f.PrepareValue(record, s.raw)
if err != nil {
t.Fatal(err)
}
raw, err := json.Marshal(v)
if err != nil {
t.Fatal(err)
}
rawStr := string(raw)
if rawStr != s.expected {
t.Fatalf("Expected\n%s\ngot\n%s", s.expected, rawStr)
}
})
}
}
func TestGeoPointFieldValidateValue(t *testing.T) {
app, _ := tests.NewTestApp()
defer app.Cleanup()
collection := core.NewBaseCollection("test_collection")
scenarios := []struct {
name string
field *core.GeoPointField
record func() *core.Record
expectError bool
}{
{
"invalid raw value",
&core.GeoPointField{Name: "test"},
func() *core.Record {
record := core.NewRecord(collection)
record.SetRaw("test", 123)
return record
},
true,
},
{
"zero field value (non-required)",
&core.GeoPointField{Name: "test"},
func() *core.Record {
record := core.NewRecord(collection)
record.SetRaw("test", types.GeoPoint{})
return record
},
false,
},
{
"zero field value (required)",
&core.GeoPointField{Name: "test", Required: true},
func() *core.Record {
record := core.NewRecord(collection)
record.SetRaw("test", types.GeoPoint{})
return record
},
true,
},
{
"non-zero Lat field value (required)",
&core.GeoPointField{Name: "test", Required: true},
func() *core.Record {
record := core.NewRecord(collection)
record.SetRaw("test", types.GeoPoint{Lat: 1})
return record
},
false,
},
{
"non-zero Lon field value (required)",
&core.GeoPointField{Name: "test", Required: true},
func() *core.Record {
record := core.NewRecord(collection)
record.SetRaw("test", types.GeoPoint{Lon: 1})
return record
},
false,
},
{
"non-zero Lat-Lon field value (required)",
&core.GeoPointField{Name: "test", Required: true},
func() *core.Record {
record := core.NewRecord(collection)
record.SetRaw("test", types.GeoPoint{Lon: -1, Lat: -2})
return record
},
false,
},
{
"lat < -90",
&core.GeoPointField{Name: "test"},
func() *core.Record {
record := core.NewRecord(collection)
record.SetRaw("test", types.GeoPoint{Lat: -90.1})
return record
},
true,
},
{
"lat > 90",
&core.GeoPointField{Name: "test"},
func() *core.Record {
record := core.NewRecord(collection)
record.SetRaw("test", types.GeoPoint{Lat: 90.1})
return record
},
true,
},
{
"lon < -180",
&core.GeoPointField{Name: "test"},
func() *core.Record {
record := core.NewRecord(collection)
record.SetRaw("test", types.GeoPoint{Lon: -180.1})
return record
},
true,
},
{
"lon > 180",
&core.GeoPointField{Name: "test"},
func() *core.Record {
record := core.NewRecord(collection)
record.SetRaw("test", types.GeoPoint{Lon: 180.1})
return record
},
true,
},
}
for _, s := range scenarios {
t.Run(s.name, func(t *testing.T) {
err := s.field.ValidateValue(context.Background(), app, s.record())
hasErr := err != nil
if hasErr != s.expectError {
t.Fatalf("Expected hasErr %v, got %v (%v)", s.expectError, hasErr, err)
}
})
}
}
func TestGeoPointFieldValidateSettings(t *testing.T) {
testDefaultFieldIdValidation(t, core.FieldTypeGeoPoint)
testDefaultFieldNameValidation(t, core.FieldTypeGeoPoint)
}

195
core/field_json.go Normal file
View file

@ -0,0 +1,195 @@
package core
import (
"context"
"slices"
"strconv"
"strings"
validation "github.com/go-ozzo/ozzo-validation/v4"
"github.com/go-ozzo/ozzo-validation/v4/is"
"github.com/pocketbase/pocketbase/core/validators"
"github.com/pocketbase/pocketbase/tools/types"
)
func init() {
Fields[FieldTypeJSON] = func() Field {
return &JSONField{}
}
}
const FieldTypeJSON = "json"
const DefaultJSONFieldMaxSize int64 = 1 << 20
var (
_ Field = (*JSONField)(nil)
_ MaxBodySizeCalculator = (*JSONField)(nil)
)
// JSONField defines "json" type field for storing any serialized JSON value.
//
// The respective zero record field value is the zero [types.JSONRaw].
type JSONField struct {
// Name (required) is the unique name of the field.
Name string `form:"name" json:"name"`
// Id is the unique stable field identifier.
//
// It is automatically generated from the name when adding to a collection FieldsList.
Id string `form:"id" json:"id"`
// System prevents the renaming and removal of the field.
System bool `form:"system" json:"system"`
// Hidden hides the field from the API response.
Hidden bool `form:"hidden" json:"hidden"`
// Presentable hints the Dashboard UI to use the underlying
// field record value in the relation preview label.
Presentable bool `form:"presentable" json:"presentable"`
// ---
// MaxSize specifies the maximum size of the allowed field value (in bytes and up to 2^53-1).
//
// If zero, a default limit of 1MB is applied.
MaxSize int64 `form:"maxSize" json:"maxSize"`
// Required will require the field value to be non-empty JSON value
// (aka. not "null", `""`, "[]", "{}").
Required bool `form:"required" json:"required"`
}
// Type implements [Field.Type] interface method.
func (f *JSONField) Type() string {
return FieldTypeJSON
}
// GetId implements [Field.GetId] interface method.
func (f *JSONField) GetId() string {
return f.Id
}
// SetId implements [Field.SetId] interface method.
func (f *JSONField) SetId(id string) {
f.Id = id
}
// GetName implements [Field.GetName] interface method.
func (f *JSONField) GetName() string {
return f.Name
}
// SetName implements [Field.SetName] interface method.
func (f *JSONField) SetName(name string) {
f.Name = name
}
// GetSystem implements [Field.GetSystem] interface method.
func (f *JSONField) GetSystem() bool {
return f.System
}
// SetSystem implements [Field.SetSystem] interface method.
func (f *JSONField) SetSystem(system bool) {
f.System = system
}
// GetHidden implements [Field.GetHidden] interface method.
func (f *JSONField) GetHidden() bool {
return f.Hidden
}
// SetHidden implements [Field.SetHidden] interface method.
func (f *JSONField) SetHidden(hidden bool) {
f.Hidden = hidden
}
// ColumnType implements [Field.ColumnType] interface method.
func (f *JSONField) ColumnType(app App) string {
return "JSON DEFAULT NULL"
}
// PrepareValue implements [Field.PrepareValue] interface method.
func (f *JSONField) PrepareValue(record *Record, raw any) (any, error) {
if str, ok := raw.(string); ok {
// in order to support seamlessly both json and multipart/form-data requests,
// the following normalization rules are applied for plain string values:
// - "true" is converted to the json `true`
// - "false" is converted to the json `false`
// - "null" is converted to the json `null`
// - "[1,2,3]" is converted to the json `[1,2,3]`
// - "{\"a\":1,\"b\":2}" is converted to the json `{"a":1,"b":2}`
// - numeric strings are converted to json number
// - double quoted strings are left as they are (aka. without normalizations)
// - any other string (empty string too) is double quoted
if str == "" {
raw = strconv.Quote(str)
} else if str == "null" || str == "true" || str == "false" {
raw = str
} else if ((str[0] >= '0' && str[0] <= '9') ||
str[0] == '-' ||
str[0] == '"' ||
str[0] == '[' ||
str[0] == '{') &&
is.JSON.Validate(str) == nil {
raw = str
} else {
raw = strconv.Quote(str)
}
}
return types.ParseJSONRaw(raw)
}
var emptyJSONValues = []string{
"null", `""`, "[]", "{}", "",
}
// ValidateValue implements [Field.ValidateValue] interface method.
func (f *JSONField) ValidateValue(ctx context.Context, app App, record *Record) error {
raw, ok := record.GetRaw(f.Name).(types.JSONRaw)
if !ok {
return validators.ErrUnsupportedValueType
}
maxSize := f.CalculateMaxBodySize()
if int64(len(raw)) > maxSize {
return validation.NewError(
"validation_json_size_limit",
"The maximum allowed JSON size is {{.maxSize}} bytes",
).SetParams(map[string]any{"maxSize": maxSize})
}
if is.JSON.Validate(raw) != nil {
return validation.NewError("validation_invalid_json", "Must be a valid json value")
}
rawStr := strings.TrimSpace(raw.String())
if f.Required && slices.Contains(emptyJSONValues, rawStr) {
return validation.ErrRequired
}
return nil
}
// ValidateSettings implements [Field.ValidateSettings] interface method.
func (f *JSONField) ValidateSettings(ctx context.Context, app App, collection *Collection) error {
return validation.ValidateStruct(f,
validation.Field(&f.Id, validation.By(DefaultFieldIdValidationRule)),
validation.Field(&f.Name, validation.By(DefaultFieldNameValidationRule)),
validation.Field(&f.MaxSize, validation.Min(0), validation.Max(maxSafeJSONInt)),
)
}
// CalculateMaxBodySize implements the [MaxBodySizeCalculator] interface.
func (f *JSONField) CalculateMaxBodySize() int64 {
if f.MaxSize <= 0 {
return DefaultJSONFieldMaxSize
}
return f.MaxSize
}

277
core/field_json_test.go Normal file
View file

@ -0,0 +1,277 @@
package core_test
import (
"context"
"fmt"
"strings"
"testing"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tests"
"github.com/pocketbase/pocketbase/tools/types"
)
func TestJSONFieldBaseMethods(t *testing.T) {
testFieldBaseMethods(t, core.FieldTypeJSON)
}
func TestJSONFieldColumnType(t *testing.T) {
app, _ := tests.NewTestApp()
defer app.Cleanup()
f := &core.JSONField{}
expected := "JSON DEFAULT NULL"
if v := f.ColumnType(app); v != expected {
t.Fatalf("Expected\n%q\ngot\n%q", expected, v)
}
}
func TestJSONFieldPrepareValue(t *testing.T) {
app, _ := tests.NewTestApp()
defer app.Cleanup()
f := &core.JSONField{}
record := core.NewRecord(core.NewBaseCollection("test"))
scenarios := []struct {
raw any
expected string
}{
{"null", `null`},
{"", `""`},
{"true", `true`},
{"false", `false`},
{"test", `"test"`},
{"123", `123`},
{"-456", `-456`},
{"[1,2,3]", `[1,2,3]`},
{"[1,2,3", `"[1,2,3"`},
{`{"a":1,"b":2}`, `{"a":1,"b":2}`},
{`{"a":1,"b":2`, `"{\"a\":1,\"b\":2"`},
{[]int{1, 2, 3}, `[1,2,3]`},
{map[string]int{"a": 1, "b": 2}, `{"a":1,"b":2}`},
{nil, `null`},
{false, `false`},
{true, `true`},
{-78, `-78`},
{123.456, `123.456`},
}
for i, s := range scenarios {
t.Run(fmt.Sprintf("%d_%#v", i, s.raw), func(t *testing.T) {
v, err := f.PrepareValue(record, s.raw)
if err != nil {
t.Fatal(err)
}
raw, ok := v.(types.JSONRaw)
if !ok {
t.Fatalf("Expected string instance, got %T", v)
}
rawStr := raw.String()
if rawStr != s.expected {
t.Fatalf("Expected\n%#v\ngot\n%#v", s.expected, rawStr)
}
})
}
}
func TestJSONFieldValidateValue(t *testing.T) {
app, _ := tests.NewTestApp()
defer app.Cleanup()
collection := core.NewBaseCollection("test_collection")
scenarios := []struct {
name string
field *core.JSONField
record func() *core.Record
expectError bool
}{
{
"invalid raw value",
&core.JSONField{Name: "test"},
func() *core.Record {
record := core.NewRecord(collection)
record.SetRaw("test", 123)
return record
},
true,
},
{
"zero field value (not required)",
&core.JSONField{Name: "test"},
func() *core.Record {
record := core.NewRecord(collection)
record.SetRaw("test", types.JSONRaw{})
return record
},
false,
},
{
"zero field value (required)",
&core.JSONField{Name: "test", Required: true},
func() *core.Record {
record := core.NewRecord(collection)
record.SetRaw("test", types.JSONRaw{})
return record
},
true,
},
{
"non-zero field value (required)",
&core.JSONField{Name: "test", Required: true},
func() *core.Record {
record := core.NewRecord(collection)
record.SetRaw("test", types.JSONRaw("[1,2,3]"))
return record
},
false,
},
{
"non-zero field value (required)",
&core.JSONField{Name: "test", Required: true},
func() *core.Record {
record := core.NewRecord(collection)
record.SetRaw("test", types.JSONRaw(`"aaa"`))
return record
},
false,
},
{
"> default MaxSize",
&core.JSONField{Name: "test"},
func() *core.Record {
record := core.NewRecord(collection)
record.SetRaw("test", types.JSONRaw(`"`+strings.Repeat("a", (1<<20))+`"`))
return record
},
true,
},
{
"> MaxSize",
&core.JSONField{Name: "test", MaxSize: 5},
func() *core.Record {
record := core.NewRecord(collection)
record.SetRaw("test", types.JSONRaw(`"aaaa"`))
return record
},
true,
},
{
"<= MaxSize",
&core.JSONField{Name: "test", MaxSize: 5},
func() *core.Record {
record := core.NewRecord(collection)
record.SetRaw("test", types.JSONRaw(`"aaa"`))
return record
},
false,
},
}
for _, s := range scenarios {
t.Run(s.name, func(t *testing.T) {
err := s.field.ValidateValue(context.Background(), app, s.record())
hasErr := err != nil
if hasErr != s.expectError {
t.Fatalf("Expected hasErr %v, got %v (%v)", s.expectError, hasErr, err)
}
})
}
}
func TestJSONFieldValidateSettings(t *testing.T) {
testDefaultFieldIdValidation(t, core.FieldTypeJSON)
testDefaultFieldNameValidation(t, core.FieldTypeJSON)
app, _ := tests.NewTestApp()
defer app.Cleanup()
collection := core.NewBaseCollection("test_collection")
scenarios := []struct {
name string
field func() *core.JSONField
expectErrors []string
}{
{
"MaxSize < 0",
func() *core.JSONField {
return &core.JSONField{
Id: "test",
Name: "test",
MaxSize: -1,
}
},
[]string{"maxSize"},
},
{
"MaxSize = 0",
func() *core.JSONField {
return &core.JSONField{
Id: "test",
Name: "test",
}
},
[]string{},
},
{
"MaxSize > 0",
func() *core.JSONField {
return &core.JSONField{
Id: "test",
Name: "test",
MaxSize: 1,
}
},
[]string{},
},
{
"MaxSize > safe json int",
func() *core.JSONField {
return &core.JSONField{
Id: "test",
Name: "test",
MaxSize: 1 << 53,
}
},
[]string{"maxSize"},
},
}
for _, s := range scenarios {
t.Run(s.name, func(t *testing.T) {
errs := s.field().ValidateSettings(context.Background(), app, collection)
tests.TestValidationErrors(t, errs, s.expectErrors)
})
}
}
func TestJSONFieldCalculateMaxBodySize(t *testing.T) {
testApp, _ := tests.NewTestApp()
defer testApp.Cleanup()
scenarios := []struct {
field *core.JSONField
expected int64
}{
{&core.JSONField{}, core.DefaultJSONFieldMaxSize},
{&core.JSONField{MaxSize: 10}, 10},
}
for i, s := range scenarios {
t.Run(fmt.Sprintf("%d_%d", i, s.field.MaxSize), func(t *testing.T) {
result := s.field.CalculateMaxBodySize()
if result != s.expected {
t.Fatalf("Expected %d, got %d", s.expected, result)
}
})
}
}

222
core/field_number.go Normal file
View file

@ -0,0 +1,222 @@
package core
import (
"context"
"fmt"
"math"
validation "github.com/go-ozzo/ozzo-validation/v4"
"github.com/pocketbase/pocketbase/core/validators"
"github.com/spf13/cast"
)
func init() {
Fields[FieldTypeNumber] = func() Field {
return &NumberField{}
}
}
const FieldTypeNumber = "number"
var (
_ Field = (*NumberField)(nil)
_ SetterFinder = (*NumberField)(nil)
)
// NumberField defines "number" type field for storing numeric (float64) value.
//
// The respective zero record field value is 0.
//
// The following additional setter keys are available:
//
// - "fieldName+" - appends to the existing record value. For example:
// record.Set("total+", 5)
// - "fieldName-" - subtracts from the existing record value. For example:
// record.Set("total-", 5)
type NumberField struct {
// Name (required) is the unique name of the field.
Name string `form:"name" json:"name"`
// Id is the unique stable field identifier.
//
// It is automatically generated from the name when adding to a collection FieldsList.
Id string `form:"id" json:"id"`
// System prevents the renaming and removal of the field.
System bool `form:"system" json:"system"`
// Hidden hides the field from the API response.
Hidden bool `form:"hidden" json:"hidden"`
// Presentable hints the Dashboard UI to use the underlying
// field record value in the relation preview label.
Presentable bool `form:"presentable" json:"presentable"`
// ---
// Min specifies the min allowed field value.
//
// Leave it nil to skip the validator.
Min *float64 `form:"min" json:"min"`
// Max specifies the max allowed field value.
//
// Leave it nil to skip the validator.
Max *float64 `form:"max" json:"max"`
// OnlyInt will require the field value to be integer.
OnlyInt bool `form:"onlyInt" json:"onlyInt"`
// Required will require the field value to be non-zero.
Required bool `form:"required" json:"required"`
}
// Type implements [Field.Type] interface method.
func (f *NumberField) Type() string {
return FieldTypeNumber
}
// GetId implements [Field.GetId] interface method.
func (f *NumberField) GetId() string {
return f.Id
}
// SetId implements [Field.SetId] interface method.
func (f *NumberField) SetId(id string) {
f.Id = id
}
// GetName implements [Field.GetName] interface method.
func (f *NumberField) GetName() string {
return f.Name
}
// SetName implements [Field.SetName] interface method.
func (f *NumberField) SetName(name string) {
f.Name = name
}
// GetSystem implements [Field.GetSystem] interface method.
func (f *NumberField) GetSystem() bool {
return f.System
}
// SetSystem implements [Field.SetSystem] interface method.
func (f *NumberField) SetSystem(system bool) {
f.System = system
}
// GetHidden implements [Field.GetHidden] interface method.
func (f *NumberField) GetHidden() bool {
return f.Hidden
}
// SetHidden implements [Field.SetHidden] interface method.
func (f *NumberField) SetHidden(hidden bool) {
f.Hidden = hidden
}
// ColumnType implements [Field.ColumnType] interface method.
func (f *NumberField) ColumnType(app App) string {
return "NUMERIC DEFAULT 0 NOT NULL"
}
// PrepareValue implements [Field.PrepareValue] interface method.
func (f *NumberField) PrepareValue(record *Record, raw any) (any, error) {
return cast.ToFloat64(raw), nil
}
// ValidateValue implements [Field.ValidateValue] interface method.
func (f *NumberField) ValidateValue(ctx context.Context, app App, record *Record) error {
val, ok := record.GetRaw(f.Name).(float64)
if !ok {
return validators.ErrUnsupportedValueType
}
if math.IsInf(val, 0) || math.IsNaN(val) {
return validation.NewError("validation_not_a_number", "The submitted number is not properly formatted")
}
if val == 0 {
if f.Required {
if err := validation.Required.Validate(val); err != nil {
return err
}
}
return nil
}
if f.OnlyInt && val != float64(int64(val)) {
return validation.NewError("validation_only_int_constraint", "Decimal numbers are not allowed")
}
if f.Min != nil && val < *f.Min {
return validation.NewError("validation_min_number_constraint", fmt.Sprintf("Must be larger than %f", *f.Min))
}
if f.Max != nil && val > *f.Max {
return validation.NewError("validation_max_number_constraint", fmt.Sprintf("Must be less than %f", *f.Max))
}
return nil
}
// ValidateSettings implements [Field.ValidateSettings] interface method.
func (f *NumberField) ValidateSettings(ctx context.Context, app App, collection *Collection) error {
maxRules := []validation.Rule{
validation.By(f.checkOnlyInt),
}
if f.Min != nil && f.Max != nil {
maxRules = append(maxRules, validation.Min(*f.Min))
}
return validation.ValidateStruct(f,
validation.Field(&f.Id, validation.By(DefaultFieldIdValidationRule)),
validation.Field(&f.Name, validation.By(DefaultFieldNameValidationRule)),
validation.Field(&f.Min, validation.By(f.checkOnlyInt)),
validation.Field(&f.Max, maxRules...),
)
}
func (f *NumberField) checkOnlyInt(value any) error {
v, _ := value.(*float64)
if v == nil || !f.OnlyInt {
return nil // nothing to check
}
if *v != float64(int64(*v)) {
return validation.NewError("validation_only_int_constraint", "Decimal numbers are not allowed.")
}
return nil
}
// FindSetter implements the [SetterFinder] interface.
func (f *NumberField) FindSetter(key string) SetterFunc {
switch key {
case f.Name:
return f.setValue
case f.Name + "+":
return f.addValue
case f.Name + "-":
return f.subtractValue
default:
return nil
}
}
func (f *NumberField) setValue(record *Record, raw any) {
record.SetRaw(f.Name, cast.ToFloat64(raw))
}
func (f *NumberField) addValue(record *Record, raw any) {
val := cast.ToFloat64(record.GetRaw(f.Name))
record.SetRaw(f.Name, val+cast.ToFloat64(raw))
}
func (f *NumberField) subtractValue(record *Record, raw any) {
val := cast.ToFloat64(record.GetRaw(f.Name))
record.SetRaw(f.Name, val-cast.ToFloat64(raw))
}

403
core/field_number_test.go Normal file
View file

@ -0,0 +1,403 @@
package core_test
import (
"context"
"fmt"
"testing"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tests"
"github.com/pocketbase/pocketbase/tools/types"
)
func TestNumberFieldBaseMethods(t *testing.T) {
testFieldBaseMethods(t, core.FieldTypeNumber)
}
func TestNumberFieldColumnType(t *testing.T) {
app, _ := tests.NewTestApp()
defer app.Cleanup()
f := &core.NumberField{}
expected := "NUMERIC DEFAULT 0 NOT NULL"
if v := f.ColumnType(app); v != expected {
t.Fatalf("Expected\n%q\ngot\n%q", expected, v)
}
}
func TestNumberFieldPrepareValue(t *testing.T) {
app, _ := tests.NewTestApp()
defer app.Cleanup()
f := &core.NumberField{}
record := core.NewRecord(core.NewBaseCollection("test"))
scenarios := []struct {
raw any
expected float64
}{
{"", 0},
{"test", 0},
{false, 0},
{true, 1},
{-2, -2},
{123.456, 123.456},
}
for i, s := range scenarios {
t.Run(fmt.Sprintf("%d_%#v", i, s.raw), func(t *testing.T) {
vRaw, err := f.PrepareValue(record, s.raw)
if err != nil {
t.Fatal(err)
}
v, ok := vRaw.(float64)
if !ok {
t.Fatalf("Expected float64 instance, got %T", v)
}
if v != s.expected {
t.Fatalf("Expected %f, got %f", s.expected, v)
}
})
}
}
func TestNumberFieldValidateValue(t *testing.T) {
app, _ := tests.NewTestApp()
defer app.Cleanup()
collection := core.NewBaseCollection("test_collection")
scenarios := []struct {
name string
field *core.NumberField
record func() *core.Record
expectError bool
}{
{
"invalid raw value",
&core.NumberField{Name: "test"},
func() *core.Record {
record := core.NewRecord(collection)
record.SetRaw("test", "123")
return record
},
true,
},
{
"zero field value (not required)",
&core.NumberField{Name: "test"},
func() *core.Record {
record := core.NewRecord(collection)
record.SetRaw("test", 0.0)
return record
},
false,
},
{
"zero field value (required)",
&core.NumberField{Name: "test", Required: true},
func() *core.Record {
record := core.NewRecord(collection)
record.SetRaw("test", 0.0)
return record
},
true,
},
{
"non-zero field value (required)",
&core.NumberField{Name: "test", Required: true},
func() *core.Record {
record := core.NewRecord(collection)
record.SetRaw("test", 123.0)
return record
},
false,
},
{
"decimal with onlyInt",
&core.NumberField{Name: "test", OnlyInt: true},
func() *core.Record {
record := core.NewRecord(collection)
record.SetRaw("test", 123.456)
return record
},
true,
},
{
"int with onlyInt",
&core.NumberField{Name: "test", OnlyInt: true},
func() *core.Record {
record := core.NewRecord(collection)
record.SetRaw("test", 123.0)
return record
},
false,
},
{
"< min",
&core.NumberField{Name: "test", Min: types.Pointer(2.0)},
func() *core.Record {
record := core.NewRecord(collection)
record.SetRaw("test", 1.0)
return record
},
true,
},
{
">= min",
&core.NumberField{Name: "test", Min: types.Pointer(2.0)},
func() *core.Record {
record := core.NewRecord(collection)
record.SetRaw("test", 2.0)
return record
},
false,
},
{
"> max",
&core.NumberField{Name: "test", Max: types.Pointer(2.0)},
func() *core.Record {
record := core.NewRecord(collection)
record.SetRaw("test", 3.0)
return record
},
true,
},
{
"<= max",
&core.NumberField{Name: "test", Max: types.Pointer(2.0)},
func() *core.Record {
record := core.NewRecord(collection)
record.SetRaw("test", 2.0)
return record
},
false,
},
{
"infinity",
&core.NumberField{Name: "test"},
func() *core.Record {
record := core.NewRecord(collection)
record.Set("test", "Inf")
return record
},
true,
},
{
"NaN",
&core.NumberField{Name: "test"},
func() *core.Record {
record := core.NewRecord(collection)
record.Set("test", "NaN")
return record
},
true,
},
}
for _, s := range scenarios {
t.Run(s.name, func(t *testing.T) {
err := s.field.ValidateValue(context.Background(), app, s.record())
hasErr := err != nil
if hasErr != s.expectError {
t.Fatalf("Expected hasErr %v, got %v (%v)", s.expectError, hasErr, err)
}
})
}
}
func TestNumberFieldValidateSettings(t *testing.T) {
testDefaultFieldIdValidation(t, core.FieldTypeNumber)
testDefaultFieldNameValidation(t, core.FieldTypeNumber)
app, _ := tests.NewTestApp()
defer app.Cleanup()
collection := core.NewBaseCollection("test_collection")
scenarios := []struct {
name string
field func() *core.NumberField
expectErrors []string
}{
{
"zero",
func() *core.NumberField {
return &core.NumberField{
Id: "test",
Name: "test",
}
},
[]string{},
},
{
"decumal min",
func() *core.NumberField {
return &core.NumberField{
Id: "test",
Name: "test",
Min: types.Pointer(1.2),
}
},
[]string{},
},
{
"decumal min (onlyInt)",
func() *core.NumberField {
return &core.NumberField{
Id: "test",
Name: "test",
OnlyInt: true,
Min: types.Pointer(1.2),
}
},
[]string{"min"},
},
{
"int min (onlyInt)",
func() *core.NumberField {
return &core.NumberField{
Id: "test",
Name: "test",
OnlyInt: true,
Min: types.Pointer(1.0),
}
},
[]string{},
},
{
"decumal max",
func() *core.NumberField {
return &core.NumberField{
Id: "test",
Name: "test",
Max: types.Pointer(1.2),
}
},
[]string{},
},
{
"decumal max (onlyInt)",
func() *core.NumberField {
return &core.NumberField{
Id: "test",
Name: "test",
OnlyInt: true,
Max: types.Pointer(1.2),
}
},
[]string{"max"},
},
{
"int max (onlyInt)",
func() *core.NumberField {
return &core.NumberField{
Id: "test",
Name: "test",
OnlyInt: true,
Max: types.Pointer(1.0),
}
},
[]string{},
},
{
"min > max",
func() *core.NumberField {
return &core.NumberField{
Id: "test",
Name: "test",
Min: types.Pointer(2.0),
Max: types.Pointer(1.0),
}
},
[]string{"max"},
},
{
"min <= max",
func() *core.NumberField {
return &core.NumberField{
Id: "test",
Name: "test",
Min: types.Pointer(2.0),
Max: types.Pointer(2.0),
}
},
[]string{},
},
}
for _, s := range scenarios {
t.Run(s.name, func(t *testing.T) {
errs := s.field().ValidateSettings(context.Background(), app, collection)
tests.TestValidationErrors(t, errs, s.expectErrors)
})
}
}
func TestNumberFieldFindSetter(t *testing.T) {
field := &core.NumberField{Name: "test"}
collection := core.NewBaseCollection("test_collection")
collection.Fields.Add(field)
t.Run("no match", func(t *testing.T) {
f := field.FindSetter("abc")
if f != nil {
t.Fatal("Expected nil setter")
}
})
t.Run("direct name match", func(t *testing.T) {
f := field.FindSetter("test")
if f == nil {
t.Fatal("Expected non-nil setter")
}
record := core.NewRecord(collection)
record.SetRaw("test", 2.0)
f(record, "123.456") // should be casted
if v := record.Get("test"); v != 123.456 {
t.Fatalf("Expected %f, got %f", 123.456, v)
}
})
t.Run("name+ match", func(t *testing.T) {
f := field.FindSetter("test+")
if f == nil {
t.Fatal("Expected non-nil setter")
}
record := core.NewRecord(collection)
record.SetRaw("test", 2.0)
f(record, "1.5") // should be casted and appended to the existing value
if v := record.Get("test"); v != 3.5 {
t.Fatalf("Expected %f, got %f", 3.5, v)
}
})
t.Run("name- match", func(t *testing.T) {
f := field.FindSetter("test-")
if f == nil {
t.Fatal("Expected non-nil setter")
}
record := core.NewRecord(collection)
record.SetRaw("test", 2.0)
f(record, "1.5") // should be casted and subtracted from the existing value
if v := record.Get("test"); v != 0.5 {
t.Fatalf("Expected %f, got %f", 0.5, v)
}
})
}

318
core/field_password.go Normal file
View file

@ -0,0 +1,318 @@
package core
import (
"context"
"database/sql/driver"
"fmt"
"regexp"
"strings"
validation "github.com/go-ozzo/ozzo-validation/v4"
"github.com/pocketbase/pocketbase/core/validators"
"github.com/spf13/cast"
"golang.org/x/crypto/bcrypt"
)
func init() {
Fields[FieldTypePassword] = func() Field {
return &PasswordField{}
}
}
const FieldTypePassword = "password"
var (
_ Field = (*PasswordField)(nil)
_ GetterFinder = (*PasswordField)(nil)
_ SetterFinder = (*PasswordField)(nil)
_ DriverValuer = (*PasswordField)(nil)
_ RecordInterceptor = (*PasswordField)(nil)
)
// PasswordField defines "password" type field for storing bcrypt hashed strings
// (usually used only internally for the "password" auth collection system field).
//
// If you want to set a direct bcrypt hash as record field value you can use the SetRaw method, for example:
//
// // generates a bcrypt hash of "123456" and set it as field value
// // (record.GetString("password") returns the plain password until persisted, otherwise empty string)
// record.Set("password", "123456")
//
// // set directly a bcrypt hash of "123456" as field value
// // (record.GetString("password") returns empty string)
// record.SetRaw("password", "$2a$10$.5Elh8fgxypNUWhpUUr/xOa2sZm0VIaE0qWuGGl9otUfobb46T1Pq")
//
// The following additional getter keys are available:
//
// - "fieldName:hash" - returns the bcrypt hash string of the record field value (if any). For example:
// record.GetString("password:hash")
type PasswordField struct {
// Name (required) is the unique name of the field.
Name string `form:"name" json:"name"`
// Id is the unique stable field identifier.
//
// It is automatically generated from the name when adding to a collection FieldsList.
Id string `form:"id" json:"id"`
// System prevents the renaming and removal of the field.
System bool `form:"system" json:"system"`
// Hidden hides the field from the API response.
Hidden bool `form:"hidden" json:"hidden"`
// Presentable hints the Dashboard UI to use the underlying
// field record value in the relation preview label.
Presentable bool `form:"presentable" json:"presentable"`
// ---
// Pattern specifies an optional regex pattern to match against the field value.
//
// Leave it empty to skip the pattern check.
Pattern string `form:"pattern" json:"pattern"`
// Min specifies an optional required field string length.
Min int `form:"min" json:"min"`
// Max specifies an optional required field string length.
//
// If zero, fallback to max 71 bytes.
Max int `form:"max" json:"max"`
// Cost specifies the cost/weight/iteration/etc. bcrypt factor.
//
// If zero, fallback to [bcrypt.DefaultCost].
//
// If explicitly set, must be between [bcrypt.MinCost] and [bcrypt.MaxCost].
Cost int `form:"cost" json:"cost"`
// Required will require the field value to be non-empty string.
Required bool `form:"required" json:"required"`
}
// Type implements [Field.Type] interface method.
func (f *PasswordField) Type() string {
return FieldTypePassword
}
// GetId implements [Field.GetId] interface method.
func (f *PasswordField) GetId() string {
return f.Id
}
// SetId implements [Field.SetId] interface method.
func (f *PasswordField) SetId(id string) {
f.Id = id
}
// GetName implements [Field.GetName] interface method.
func (f *PasswordField) GetName() string {
return f.Name
}
// SetName implements [Field.SetName] interface method.
func (f *PasswordField) SetName(name string) {
f.Name = name
}
// GetSystem implements [Field.GetSystem] interface method.
func (f *PasswordField) GetSystem() bool {
return f.System
}
// SetSystem implements [Field.SetSystem] interface method.
func (f *PasswordField) SetSystem(system bool) {
f.System = system
}
// GetHidden implements [Field.GetHidden] interface method.
func (f *PasswordField) GetHidden() bool {
return f.Hidden
}
// SetHidden implements [Field.SetHidden] interface method.
func (f *PasswordField) SetHidden(hidden bool) {
f.Hidden = hidden
}
// ColumnType implements [Field.ColumnType] interface method.
func (f *PasswordField) ColumnType(app App) string {
return "TEXT DEFAULT '' NOT NULL"
}
// DriverValue implements the [DriverValuer] interface.
func (f *PasswordField) DriverValue(record *Record) (driver.Value, error) {
fp := f.getPasswordValue(record)
return fp.Hash, fp.LastError
}
// PrepareValue implements [Field.PrepareValue] interface method.
func (f *PasswordField) PrepareValue(record *Record, raw any) (any, error) {
return &PasswordFieldValue{
Hash: cast.ToString(raw),
}, nil
}
// ValidateValue implements [Field.ValidateValue] interface method.
func (f *PasswordField) ValidateValue(ctx context.Context, app App, record *Record) error {
fp, ok := record.GetRaw(f.Name).(*PasswordFieldValue)
if !ok {
return validators.ErrUnsupportedValueType
}
if fp.LastError != nil {
return fp.LastError
}
if f.Required {
if err := validation.Required.Validate(fp.Hash); err != nil {
return err
}
}
if fp.Plain == "" {
return nil // nothing to check
}
// note: casted to []rune to count multi-byte chars as one for the
// sake of more intuitive UX and clearer user error messages
//
// note2: technically multi-byte strings could produce bigger length than the bcrypt limit
// but it should be fine as it will be just truncated (even if it cuts a byte sequence in the middle)
length := len([]rune(fp.Plain))
if length < f.Min {
return validation.NewError("validation_min_text_constraint", fmt.Sprintf("Must be at least %d character(s)", f.Min))
}
maxLength := f.Max
if maxLength <= 0 {
maxLength = 71
}
if length > maxLength {
return validation.NewError("validation_max_text_constraint", fmt.Sprintf("Must be less than %d character(s)", maxLength))
}
if f.Pattern != "" {
match, _ := regexp.MatchString(f.Pattern, fp.Plain)
if !match {
return validation.NewError("validation_invalid_format", "Invalid value format")
}
}
return nil
}
// ValidateSettings implements [Field.ValidateSettings] interface method.
func (f *PasswordField) ValidateSettings(ctx context.Context, app App, collection *Collection) error {
return validation.ValidateStruct(f,
validation.Field(&f.Id, validation.By(DefaultFieldIdValidationRule)),
validation.Field(&f.Name, validation.By(DefaultFieldNameValidationRule)),
validation.Field(&f.Min, validation.Min(1), validation.Max(71)),
validation.Field(&f.Max, validation.Min(f.Min), validation.Max(71)),
validation.Field(&f.Cost, validation.Min(bcrypt.MinCost), validation.Max(bcrypt.MaxCost)),
validation.Field(&f.Pattern, validation.By(validators.IsRegex)),
)
}
func (f *PasswordField) getPasswordValue(record *Record) *PasswordFieldValue {
raw := record.GetRaw(f.Name)
switch v := raw.(type) {
case *PasswordFieldValue:
return v
case string:
// we assume that any raw string starting with $2 is bcrypt hash
if strings.HasPrefix(v, "$2") {
return &PasswordFieldValue{Hash: v}
}
}
return &PasswordFieldValue{}
}
// Intercept implements the [RecordInterceptor] interface.
func (f *PasswordField) Intercept(
ctx context.Context,
app App,
record *Record,
actionName string,
actionFunc func() error,
) error {
switch actionName {
case InterceptorActionAfterCreate, InterceptorActionAfterUpdate:
// unset the plain field value after successful create/update
fp := f.getPasswordValue(record)
fp.Plain = ""
}
return actionFunc()
}
// FindGetter implements the [GetterFinder] interface.
func (f *PasswordField) FindGetter(key string) GetterFunc {
switch key {
case f.Name:
return func(record *Record) any {
return f.getPasswordValue(record).Plain
}
case f.Name + ":hash":
return func(record *Record) any {
return f.getPasswordValue(record).Hash
}
default:
return nil
}
}
// FindSetter implements the [SetterFinder] interface.
func (f *PasswordField) FindSetter(key string) SetterFunc {
switch key {
case f.Name:
return f.setValue
default:
return nil
}
}
func (f *PasswordField) setValue(record *Record, raw any) {
fv := &PasswordFieldValue{
Plain: cast.ToString(raw),
}
// hash the password
if fv.Plain != "" {
cost := f.Cost
if cost <= 0 {
cost = bcrypt.DefaultCost
}
hash, err := bcrypt.GenerateFromPassword([]byte(fv.Plain), cost)
if err != nil {
fv.LastError = err
}
fv.Hash = string(hash)
}
record.SetRaw(f.Name, fv)
}
// -------------------------------------------------------------------
type PasswordFieldValue struct {
LastError error
Hash string
Plain string
}
func (pv PasswordFieldValue) Validate(pass string) bool {
if pv.Hash == "" || pv.LastError != nil {
return false
}
err := bcrypt.CompareHashAndPassword([]byte(pv.Hash), []byte(pass))
return err == nil
}

568
core/field_password_test.go Normal file
View file

@ -0,0 +1,568 @@
package core_test
import (
"context"
"encoding/json"
"errors"
"fmt"
"strings"
"testing"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tests"
"golang.org/x/crypto/bcrypt"
)
func TestPasswordFieldBaseMethods(t *testing.T) {
testFieldBaseMethods(t, core.FieldTypePassword)
}
func TestPasswordFieldColumnType(t *testing.T) {
app, _ := tests.NewTestApp()
defer app.Cleanup()
f := &core.PasswordField{}
expected := "TEXT DEFAULT '' NOT NULL"
if v := f.ColumnType(app); v != expected {
t.Fatalf("Expected\n%q\ngot\n%q", expected, v)
}
}
func TestPasswordFieldPrepareValue(t *testing.T) {
app, _ := tests.NewTestApp()
defer app.Cleanup()
f := &core.PasswordField{}
record := core.NewRecord(core.NewBaseCollection("test"))
scenarios := []struct {
raw any
expected string
}{
{"", ""},
{"test", "test"},
{false, "false"},
{true, "true"},
{123.456, "123.456"},
}
for i, s := range scenarios {
t.Run(fmt.Sprintf("%d_%#v", i, s.raw), func(t *testing.T) {
v, err := f.PrepareValue(record, s.raw)
if err != nil {
t.Fatal(err)
}
pv, ok := v.(*core.PasswordFieldValue)
if !ok {
t.Fatalf("Expected PasswordFieldValue instance, got %T", v)
}
if pv.Hash != s.expected {
t.Fatalf("Expected %q, got %q", s.expected, v)
}
})
}
}
func TestPasswordFieldDriverValue(t *testing.T) {
app, _ := tests.NewTestApp()
defer app.Cleanup()
f := &core.PasswordField{Name: "test"}
err := errors.New("example_err")
scenarios := []struct {
raw any
expected *core.PasswordFieldValue
}{
{123, &core.PasswordFieldValue{}},
{"abc", &core.PasswordFieldValue{}},
{"$2abc", &core.PasswordFieldValue{Hash: "$2abc"}},
{&core.PasswordFieldValue{Hash: "test", LastError: err}, &core.PasswordFieldValue{Hash: "test", LastError: err}},
}
for i, s := range scenarios {
t.Run(fmt.Sprintf("%d_%v", i, s.raw), func(t *testing.T) {
record := core.NewRecord(core.NewBaseCollection("test"))
record.SetRaw(f.GetName(), s.raw)
v, err := f.DriverValue(record)
vStr, ok := v.(string)
if !ok {
t.Fatalf("Expected string instance, got %T", v)
}
var errStr string
if err != nil {
errStr = err.Error()
}
var expectedErrStr string
if s.expected.LastError != nil {
expectedErrStr = s.expected.LastError.Error()
}
if errStr != expectedErrStr {
t.Fatalf("Expected error %q, got %q", expectedErrStr, errStr)
}
if vStr != s.expected.Hash {
t.Fatalf("Expected hash %q, got %q", s.expected.Hash, vStr)
}
})
}
}
func TestPasswordFieldValidateValue(t *testing.T) {
app, _ := tests.NewTestApp()
defer app.Cleanup()
collection := core.NewBaseCollection("test_collection")
scenarios := []struct {
name string
field *core.PasswordField
record func() *core.Record
expectError bool
}{
{
"invalid raw value",
&core.PasswordField{Name: "test"},
func() *core.Record {
record := core.NewRecord(collection)
record.SetRaw("test", "123")
return record
},
true,
},
{
"zero field value (not required)",
&core.PasswordField{Name: "test"},
func() *core.Record {
record := core.NewRecord(collection)
record.SetRaw("test", &core.PasswordFieldValue{})
return record
},
false,
},
{
"zero field value (required)",
&core.PasswordField{Name: "test", Required: true},
func() *core.Record {
record := core.NewRecord(collection)
record.SetRaw("test", &core.PasswordFieldValue{})
return record
},
true,
},
{
"empty hash but non-empty plain password (required)",
&core.PasswordField{Name: "test", Required: true},
func() *core.Record {
record := core.NewRecord(collection)
record.SetRaw("test", &core.PasswordFieldValue{Plain: "test"})
return record
},
true,
},
{
"non-empty hash (required)",
&core.PasswordField{Name: "test", Required: true},
func() *core.Record {
record := core.NewRecord(collection)
record.SetRaw("test", &core.PasswordFieldValue{Hash: "test"})
return record
},
false,
},
{
"with LastError",
&core.PasswordField{Name: "test"},
func() *core.Record {
record := core.NewRecord(collection)
record.SetRaw("test", &core.PasswordFieldValue{LastError: errors.New("test")})
return record
},
true,
},
{
"< Min",
&core.PasswordField{Name: "test", Min: 3},
func() *core.Record {
record := core.NewRecord(collection)
record.SetRaw("test", &core.PasswordFieldValue{Plain: "аб"}) // multi-byte chars test
return record
},
true,
},
{
">= Min",
&core.PasswordField{Name: "test", Min: 3},
func() *core.Record {
record := core.NewRecord(collection)
record.SetRaw("test", &core.PasswordFieldValue{Plain: "абв"}) // multi-byte chars test
return record
},
false,
},
{
"> default Max",
&core.PasswordField{Name: "test"},
func() *core.Record {
record := core.NewRecord(collection)
record.SetRaw("test", &core.PasswordFieldValue{Plain: strings.Repeat("a", 72)})
return record
},
true,
},
{
"<= default Max",
&core.PasswordField{Name: "test"},
func() *core.Record {
record := core.NewRecord(collection)
record.SetRaw("test", &core.PasswordFieldValue{Plain: strings.Repeat("a", 71)})
return record
},
false,
},
{
"> Max",
&core.PasswordField{Name: "test", Max: 2},
func() *core.Record {
record := core.NewRecord(collection)
record.SetRaw("test", &core.PasswordFieldValue{Plain: "абв"}) // multi-byte chars test
return record
},
true,
},
{
"<= Max",
&core.PasswordField{Name: "test", Max: 2},
func() *core.Record {
record := core.NewRecord(collection)
record.SetRaw("test", &core.PasswordFieldValue{Plain: "аб"}) // multi-byte chars test
return record
},
false,
},
{
"non-matching pattern",
&core.PasswordField{Name: "test", Pattern: `\d+`},
func() *core.Record {
record := core.NewRecord(collection)
record.SetRaw("test", &core.PasswordFieldValue{Plain: "abc"})
return record
},
true,
},
{
"matching pattern",
&core.PasswordField{Name: "test", Pattern: `\d+`},
func() *core.Record {
record := core.NewRecord(collection)
record.SetRaw("test", &core.PasswordFieldValue{Plain: "123"})
return record
},
false,
},
}
for _, s := range scenarios {
t.Run(s.name, func(t *testing.T) {
err := s.field.ValidateValue(context.Background(), app, s.record())
hasErr := err != nil
if hasErr != s.expectError {
t.Fatalf("Expected hasErr %v, got %v (%v)", s.expectError, hasErr, err)
}
})
}
}
func TestPasswordFieldValidateSettings(t *testing.T) {
testDefaultFieldIdValidation(t, core.FieldTypePassword)
testDefaultFieldNameValidation(t, core.FieldTypePassword)
app, _ := tests.NewTestApp()
defer app.Cleanup()
scenarios := []struct {
name string
field func(col *core.Collection) *core.PasswordField
expectErrors []string
}{
{
"zero minimal",
func(col *core.Collection) *core.PasswordField {
return &core.PasswordField{
Id: "test",
Name: "test",
}
},
[]string{},
},
{
"invalid pattern",
func(col *core.Collection) *core.PasswordField {
return &core.PasswordField{
Id: "test",
Name: "test",
Pattern: "(invalid",
}
},
[]string{"pattern"},
},
{
"valid pattern",
func(col *core.Collection) *core.PasswordField {
return &core.PasswordField{
Id: "test",
Name: "test",
Pattern: `\d+`,
}
},
[]string{},
},
{
"Min < 0",
func(col *core.Collection) *core.PasswordField {
return &core.PasswordField{
Id: "test",
Name: "test",
Min: -1,
}
},
[]string{"min"},
},
{
"Min > 71",
func(col *core.Collection) *core.PasswordField {
return &core.PasswordField{
Id: "test",
Name: "test",
Min: 72,
}
},
[]string{"min"},
},
{
"valid Min",
func(col *core.Collection) *core.PasswordField {
return &core.PasswordField{
Id: "test",
Name: "test",
Min: 5,
}
},
[]string{},
},
{
"Max < Min",
func(col *core.Collection) *core.PasswordField {
return &core.PasswordField{
Id: "test",
Name: "test",
Min: 2,
Max: 1,
}
},
[]string{"max"},
},
{
"Min > Min",
func(col *core.Collection) *core.PasswordField {
return &core.PasswordField{
Id: "test",
Name: "test",
Min: 2,
Max: 3,
}
},
[]string{},
},
{
"Max > 71",
func(col *core.Collection) *core.PasswordField {
return &core.PasswordField{
Id: "test",
Name: "test",
Max: 72,
}
},
[]string{"max"},
},
{
"cost < bcrypt.MinCost",
func(col *core.Collection) *core.PasswordField {
return &core.PasswordField{
Id: "test",
Name: "test",
Cost: bcrypt.MinCost - 1,
}
},
[]string{"cost"},
},
{
"cost > bcrypt.MaxCost",
func(col *core.Collection) *core.PasswordField {
return &core.PasswordField{
Id: "test",
Name: "test",
Cost: bcrypt.MaxCost + 1,
}
},
[]string{"cost"},
},
{
"valid cost",
func(col *core.Collection) *core.PasswordField {
return &core.PasswordField{
Id: "test",
Name: "test",
Cost: 12,
}
},
[]string{},
},
}
for _, s := range scenarios {
t.Run(s.name, func(t *testing.T) {
collection := core.NewBaseCollection("test_collection")
collection.Fields.GetByName("id").SetId("test") // set a dummy known id so that it can be replaced
field := s.field(collection)
collection.Fields.Add(field)
errs := field.ValidateSettings(context.Background(), app, collection)
tests.TestValidationErrors(t, errs, s.expectErrors)
})
}
}
func TestPasswordFieldFindSetter(t *testing.T) {
scenarios := []struct {
name string
key string
value any
field *core.PasswordField
hasSetter bool
expected string
}{
{
"no match",
"example",
"abc",
&core.PasswordField{Name: "test"},
false,
"",
},
{
"exact match",
"test",
"abc",
&core.PasswordField{Name: "test"},
true,
`"abc"`,
},
}
for _, s := range scenarios {
t.Run(s.name, func(t *testing.T) {
collection := core.NewBaseCollection("test_collection")
collection.Fields.Add(s.field)
setter := s.field.FindSetter(s.key)
hasSetter := setter != nil
if hasSetter != s.hasSetter {
t.Fatalf("Expected hasSetter %v, got %v", s.hasSetter, hasSetter)
}
if !hasSetter {
return
}
record := core.NewRecord(collection)
record.SetRaw(s.field.GetName(), []string{"c", "d"})
setter(record, s.value)
raw, err := json.Marshal(record.Get(s.field.GetName()))
if err != nil {
t.Fatal(err)
}
rawStr := string(raw)
if rawStr != s.expected {
t.Fatalf("Expected %q, got %q", s.expected, rawStr)
}
})
}
}
func TestPasswordFieldFindGetter(t *testing.T) {
scenarios := []struct {
name string
key string
field *core.PasswordField
hasGetter bool
expected string
}{
{
"no match",
"example",
&core.PasswordField{Name: "test"},
false,
"",
},
{
"field name match",
"test",
&core.PasswordField{Name: "test"},
true,
"test_plain",
},
{
"field name hash modifier",
"test:hash",
&core.PasswordField{Name: "test"},
true,
"test_hash",
},
}
for _, s := range scenarios {
t.Run(s.name, func(t *testing.T) {
collection := core.NewBaseCollection("test_collection")
collection.Fields.Add(s.field)
getter := s.field.FindGetter(s.key)
hasGetter := getter != nil
if hasGetter != s.hasGetter {
t.Fatalf("Expected hasGetter %v, got %v", s.hasGetter, hasGetter)
}
if !hasGetter {
return
}
record := core.NewRecord(collection)
record.SetRaw(s.field.GetName(), &core.PasswordFieldValue{Hash: "test_hash", Plain: "test_plain"})
result := getter(record)
if result != s.expected {
t.Fatalf("Expected %q, got %#v", s.expected, result)
}
})
}
}

350
core/field_relation.go Normal file
View file

@ -0,0 +1,350 @@
package core
import (
"context"
"database/sql/driver"
validation "github.com/go-ozzo/ozzo-validation/v4"
"github.com/pocketbase/dbx"
"github.com/pocketbase/pocketbase/tools/list"
"github.com/pocketbase/pocketbase/tools/types"
)
func init() {
Fields[FieldTypeRelation] = func() Field {
return &RelationField{}
}
}
const FieldTypeRelation = "relation"
var (
_ Field = (*RelationField)(nil)
_ MultiValuer = (*RelationField)(nil)
_ DriverValuer = (*RelationField)(nil)
_ SetterFinder = (*RelationField)(nil)
)
// RelationField defines "relation" type field for storing single or
// multiple collection record references.
//
// Requires the CollectionId option to be set.
//
// If MaxSelect is not set or <= 1, then the field value is expected to be a single record id.
//
// If MaxSelect is > 1, then the field value is expected to be a slice of record ids.
//
// The respective zero record field value is either empty string (single) or empty string slice (multiple).
//
// ---
//
// The following additional setter keys are available:
//
// - "fieldName+" - append one or more values to the existing record one. For example:
//
// record.Set("categories+", []string{"new1", "new2"}) // []string{"old1", "old2", "new1", "new2"}
//
// - "+fieldName" - prepend one or more values to the existing record one. For example:
//
// record.Set("+categories", []string{"new1", "new2"}) // []string{"new1", "new2", "old1", "old2"}
//
// - "fieldName-" - subtract one or more values from the existing record one. For example:
//
// record.Set("categories-", "old1") // []string{"old2"}
type RelationField struct {
// Name (required) is the unique name of the field.
Name string `form:"name" json:"name"`
// Id is the unique stable field identifier.
//
// It is automatically generated from the name when adding to a collection FieldsList.
Id string `form:"id" json:"id"`
// System prevents the renaming and removal of the field.
System bool `form:"system" json:"system"`
// Hidden hides the field from the API response.
Hidden bool `form:"hidden" json:"hidden"`
// Presentable hints the Dashboard UI to use the underlying
// field record value in the relation preview label.
Presentable bool `form:"presentable" json:"presentable"`
// ---
// CollectionId is the id of the related collection.
CollectionId string `form:"collectionId" json:"collectionId"`
// CascadeDelete indicates whether the root model should be deleted
// in case of delete of all linked relations.
CascadeDelete bool `form:"cascadeDelete" json:"cascadeDelete"`
// MinSelect indicates the min number of allowed relation records
// that could be linked to the main model.
//
// No min limit is applied if it is zero or negative value.
MinSelect int `form:"minSelect" json:"minSelect"`
// MaxSelect indicates the max number of allowed relation records
// that could be linked to the main model.
//
// For multiple select the value must be > 1, otherwise fallbacks to single (default).
//
// If MinSelect is set, MaxSelect must be at least >= MinSelect.
MaxSelect int `form:"maxSelect" json:"maxSelect"`
// Required will require the field value to be non-empty.
Required bool `form:"required" json:"required"`
}
// Type implements [Field.Type] interface method.
func (f *RelationField) Type() string {
return FieldTypeRelation
}
// GetId implements [Field.GetId] interface method.
func (f *RelationField) GetId() string {
return f.Id
}
// SetId implements [Field.SetId] interface method.
func (f *RelationField) SetId(id string) {
f.Id = id
}
// GetName implements [Field.GetName] interface method.
func (f *RelationField) GetName() string {
return f.Name
}
// SetName implements [Field.SetName] interface method.
func (f *RelationField) SetName(name string) {
f.Name = name
}
// GetSystem implements [Field.GetSystem] interface method.
func (f *RelationField) GetSystem() bool {
return f.System
}
// SetSystem implements [Field.SetSystem] interface method.
func (f *RelationField) SetSystem(system bool) {
f.System = system
}
// GetHidden implements [Field.GetHidden] interface method.
func (f *RelationField) GetHidden() bool {
return f.Hidden
}
// SetHidden implements [Field.SetHidden] interface method.
func (f *RelationField) SetHidden(hidden bool) {
f.Hidden = hidden
}
// IsMultiple implements [MultiValuer] interface and checks whether the
// current field options support multiple values.
func (f *RelationField) IsMultiple() bool {
return f.MaxSelect > 1
}
// ColumnType implements [Field.ColumnType] interface method.
func (f *RelationField) ColumnType(app App) string {
if f.IsMultiple() {
return "JSON DEFAULT '[]' NOT NULL"
}
return "TEXT DEFAULT '' NOT NULL"
}
// PrepareValue implements [Field.PrepareValue] interface method.
func (f *RelationField) PrepareValue(record *Record, raw any) (any, error) {
return f.normalizeValue(raw), nil
}
func (f *RelationField) normalizeValue(raw any) any {
val := list.ToUniqueStringSlice(raw)
if !f.IsMultiple() {
if len(val) > 0 {
return val[len(val)-1] // the last selected
}
return ""
}
return val
}
// DriverValue implements the [DriverValuer] interface.
func (f *RelationField) DriverValue(record *Record) (driver.Value, error) {
val := list.ToUniqueStringSlice(record.GetRaw(f.Name))
if !f.IsMultiple() {
if len(val) > 0 {
return val[len(val)-1], nil // the last selected
}
return "", nil
}
// serialize as json string array
return append(types.JSONArray[string]{}, val...), nil
}
// ValidateValue implements [Field.ValidateValue] interface method.
func (f *RelationField) ValidateValue(ctx context.Context, app App, record *Record) error {
ids := list.ToUniqueStringSlice(record.GetRaw(f.Name))
if len(ids) == 0 {
if f.Required {
return validation.ErrRequired
}
return nil // nothing to check
}
if f.MinSelect > 0 && len(ids) < f.MinSelect {
return validation.NewError("validation_not_enough_values", "Select at least {{.minSelect}}").
SetParams(map[string]any{"minSelect": f.MinSelect})
}
maxSelect := max(f.MaxSelect, 1)
if len(ids) > maxSelect {
return validation.NewError("validation_too_many_values", "Select no more than {{.maxSelect}}").
SetParams(map[string]any{"maxSelect": maxSelect})
}
// check if the related records exist
// ---
relCollection, err := app.FindCachedCollectionByNameOrId(f.CollectionId)
if err != nil {
return validation.NewError("validation_missing_rel_collection", "Relation connection is missing or cannot be accessed")
}
var total int
_ = app.ConcurrentDB().
Select("count(*)").
From(relCollection.Name).
AndWhere(dbx.In("id", list.ToInterfaceSlice(ids)...)).
Row(&total)
if total != len(ids) {
return validation.NewError("validation_missing_rel_records", "Failed to find all relation records with the provided ids")
}
// ---
return nil
}
// ValidateSettings implements [Field.ValidateSettings] interface method.
func (f *RelationField) ValidateSettings(ctx context.Context, app App, collection *Collection) error {
return validation.ValidateStruct(f,
validation.Field(&f.Id, validation.By(DefaultFieldIdValidationRule)),
validation.Field(&f.Name, validation.By(DefaultFieldNameValidationRule)),
validation.Field(&f.CollectionId, validation.Required, validation.By(f.checkCollectionId(app, collection))),
validation.Field(&f.MinSelect, validation.Min(0)),
validation.Field(&f.MaxSelect, validation.When(f.MinSelect > 0, validation.Required), validation.Min(f.MinSelect)),
)
}
func (f *RelationField) checkCollectionId(app App, collection *Collection) validation.RuleFunc {
return func(value any) error {
v, _ := value.(string)
if v == "" {
return nil // nothing to check
}
var oldCollection *Collection
if !collection.IsNew() {
var err error
oldCollection, err = app.FindCachedCollectionByNameOrId(collection.Id)
if err != nil {
return err
}
}
// prevent collectionId change
if oldCollection != nil {
oldField, _ := oldCollection.Fields.GetById(f.Id).(*RelationField)
if oldField != nil && oldField.CollectionId != v {
return validation.NewError(
"validation_field_relation_change",
"The relation collection cannot be changed.",
)
}
}
relCollection, _ := app.FindCachedCollectionByNameOrId(v)
// validate collectionId
if relCollection == nil || relCollection.Id != v {
return validation.NewError(
"validation_field_relation_missing_collection",
"The relation collection doesn't exist.",
)
}
// allow only views to have relations to other views
// (see https://github.com/pocketbase/pocketbase/issues/3000)
if !collection.IsView() && relCollection.IsView() {
return validation.NewError(
"validation_relation_field_non_view_base_collection",
"Only view collections are allowed to have relations to other views.",
)
}
return nil
}
}
// ---
// FindSetter implements [SetterFinder] interface method.
func (f *RelationField) FindSetter(key string) SetterFunc {
switch key {
case f.Name:
return f.setValue
case "+" + f.Name:
return f.prependValue
case f.Name + "+":
return f.appendValue
case f.Name + "-":
return f.subtractValue
default:
return nil
}
}
func (f *RelationField) setValue(record *Record, raw any) {
record.SetRaw(f.Name, f.normalizeValue(raw))
}
func (f *RelationField) appendValue(record *Record, modifierValue any) {
val := record.GetRaw(f.Name)
val = append(
list.ToUniqueStringSlice(val),
list.ToUniqueStringSlice(modifierValue)...,
)
f.setValue(record, val)
}
func (f *RelationField) prependValue(record *Record, modifierValue any) {
val := record.GetRaw(f.Name)
val = append(
list.ToUniqueStringSlice(modifierValue),
list.ToUniqueStringSlice(val)...,
)
f.setValue(record, val)
}
func (f *RelationField) subtractValue(record *Record, modifierValue any) {
val := record.GetRaw(f.Name)
val = list.SubtractSlice(
list.ToUniqueStringSlice(val),
list.ToUniqueStringSlice(modifierValue),
)
f.setValue(record, val)
}

603
core/field_relation_test.go Normal file
View file

@ -0,0 +1,603 @@
package core_test
import (
"context"
"encoding/json"
"fmt"
"testing"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tests"
"github.com/pocketbase/pocketbase/tools/types"
)
func TestRelationFieldBaseMethods(t *testing.T) {
testFieldBaseMethods(t, core.FieldTypeRelation)
}
func TestRelationFieldColumnType(t *testing.T) {
app, _ := tests.NewTestApp()
defer app.Cleanup()
scenarios := []struct {
name string
field *core.RelationField
expected string
}{
{
"single (zero)",
&core.RelationField{},
"TEXT DEFAULT '' NOT NULL",
},
{
"single",
&core.RelationField{MaxSelect: 1},
"TEXT DEFAULT '' NOT NULL",
},
{
"multiple",
&core.RelationField{MaxSelect: 2},
"JSON DEFAULT '[]' NOT NULL",
},
}
for _, s := range scenarios {
t.Run(s.name, func(t *testing.T) {
if v := s.field.ColumnType(app); v != s.expected {
t.Fatalf("Expected\n%q\ngot\n%q", s.expected, v)
}
})
}
}
func TestRelationFieldIsMultiple(t *testing.T) {
scenarios := []struct {
name string
field *core.RelationField
expected bool
}{
{
"zero",
&core.RelationField{},
false,
},
{
"single",
&core.RelationField{MaxSelect: 1},
false,
},
{
"multiple",
&core.RelationField{MaxSelect: 2},
true,
},
}
for _, s := range scenarios {
t.Run(s.name, func(t *testing.T) {
if v := s.field.IsMultiple(); v != s.expected {
t.Fatalf("Expected %v, got %v", s.expected, v)
}
})
}
}
func TestRelationFieldPrepareValue(t *testing.T) {
app, _ := tests.NewTestApp()
defer app.Cleanup()
record := core.NewRecord(core.NewBaseCollection("test"))
scenarios := []struct {
raw any
field *core.RelationField
expected string
}{
// single
{nil, &core.RelationField{MaxSelect: 1}, `""`},
{"", &core.RelationField{MaxSelect: 1}, `""`},
{123, &core.RelationField{MaxSelect: 1}, `"123"`},
{"a", &core.RelationField{MaxSelect: 1}, `"a"`},
{`["a"]`, &core.RelationField{MaxSelect: 1}, `"a"`},
{[]string{}, &core.RelationField{MaxSelect: 1}, `""`},
{[]string{"a", "b"}, &core.RelationField{MaxSelect: 1}, `"b"`},
// multiple
{nil, &core.RelationField{MaxSelect: 2}, `[]`},
{"", &core.RelationField{MaxSelect: 2}, `[]`},
{123, &core.RelationField{MaxSelect: 2}, `["123"]`},
{"a", &core.RelationField{MaxSelect: 2}, `["a"]`},
{`["a"]`, &core.RelationField{MaxSelect: 2}, `["a"]`},
{[]string{}, &core.RelationField{MaxSelect: 2}, `[]`},
{[]string{"a", "b", "c"}, &core.RelationField{MaxSelect: 2}, `["a","b","c"]`},
}
for i, s := range scenarios {
t.Run(fmt.Sprintf("%d_%#v_%v", i, s.raw, s.field.IsMultiple()), func(t *testing.T) {
v, err := s.field.PrepareValue(record, s.raw)
if err != nil {
t.Fatal(err)
}
vRaw, err := json.Marshal(v)
if err != nil {
t.Fatal(err)
}
if string(vRaw) != s.expected {
t.Fatalf("Expected %q, got %q", s.expected, vRaw)
}
})
}
}
func TestRelationFieldDriverValue(t *testing.T) {
app, _ := tests.NewTestApp()
defer app.Cleanup()
scenarios := []struct {
raw any
field *core.RelationField
expected string
}{
// single
{nil, &core.RelationField{MaxSelect: 1}, `""`},
{"", &core.RelationField{MaxSelect: 1}, `""`},
{123, &core.RelationField{MaxSelect: 1}, `"123"`},
{"a", &core.RelationField{MaxSelect: 1}, `"a"`},
{`["a"]`, &core.RelationField{MaxSelect: 1}, `"a"`},
{[]string{}, &core.RelationField{MaxSelect: 1}, `""`},
{[]string{"a", "b"}, &core.RelationField{MaxSelect: 1}, `"b"`},
// multiple
{nil, &core.RelationField{MaxSelect: 2}, `[]`},
{"", &core.RelationField{MaxSelect: 2}, `[]`},
{123, &core.RelationField{MaxSelect: 2}, `["123"]`},
{"a", &core.RelationField{MaxSelect: 2}, `["a"]`},
{`["a"]`, &core.RelationField{MaxSelect: 2}, `["a"]`},
{[]string{}, &core.RelationField{MaxSelect: 2}, `[]`},
{[]string{"a", "b", "c"}, &core.RelationField{MaxSelect: 2}, `["a","b","c"]`},
}
for i, s := range scenarios {
t.Run(fmt.Sprintf("%d_%#v_%v", i, s.raw, s.field.IsMultiple()), func(t *testing.T) {
record := core.NewRecord(core.NewBaseCollection("test"))
record.SetRaw(s.field.GetName(), s.raw)
v, err := s.field.DriverValue(record)
if err != nil {
t.Fatal(err)
}
if s.field.IsMultiple() {
_, ok := v.(types.JSONArray[string])
if !ok {
t.Fatalf("Expected types.JSONArray value, got %T", v)
}
} else {
_, ok := v.(string)
if !ok {
t.Fatalf("Expected string value, got %T", v)
}
}
vRaw, err := json.Marshal(v)
if err != nil {
t.Fatal(err)
}
if string(vRaw) != s.expected {
t.Fatalf("Expected %q, got %q", s.expected, vRaw)
}
})
}
}
func TestRelationFieldValidateValue(t *testing.T) {
app, _ := tests.NewTestApp()
defer app.Cleanup()
demo1, err := app.FindCollectionByNameOrId("demo1")
if err != nil {
t.Fatal(err)
}
scenarios := []struct {
name string
field *core.RelationField
record func() *core.Record
expectError bool
}{
// single
{
"[single] zero field value (not required)",
&core.RelationField{Name: "test", MaxSelect: 1, CollectionId: demo1.Id},
func() *core.Record {
record := core.NewRecord(core.NewBaseCollection("test_collection"))
record.SetRaw("test", "")
return record
},
false,
},
{
"[single] zero field value (required)",
&core.RelationField{Name: "test", MaxSelect: 1, CollectionId: demo1.Id, Required: true},
func() *core.Record {
record := core.NewRecord(core.NewBaseCollection("test_collection"))
record.SetRaw("test", "")
return record
},
true,
},
{
"[single] id from other collection",
&core.RelationField{Name: "test", MaxSelect: 1, CollectionId: demo1.Id},
func() *core.Record {
record := core.NewRecord(core.NewBaseCollection("test_collection"))
record.SetRaw("test", "achvryl401bhse3")
return record
},
true,
},
{
"[single] valid id",
&core.RelationField{Name: "test", MaxSelect: 1, CollectionId: demo1.Id},
func() *core.Record {
record := core.NewRecord(core.NewBaseCollection("test_collection"))
record.SetRaw("test", "84nmscqy84lsi1t")
return record
},
false,
},
{
"[single] > MaxSelect",
&core.RelationField{Name: "test", MaxSelect: 1, CollectionId: demo1.Id},
func() *core.Record {
record := core.NewRecord(core.NewBaseCollection("test_collection"))
record.SetRaw("test", []string{"84nmscqy84lsi1t", "al1h9ijdeojtsjy"})
return record
},
true,
},
// multiple
{
"[multiple] zero field value (not required)",
&core.RelationField{Name: "test", MaxSelect: 2, CollectionId: demo1.Id},
func() *core.Record {
record := core.NewRecord(core.NewBaseCollection("test_collection"))
record.SetRaw("test", []string{})
return record
},
false,
},
{
"[multiple] zero field value (required)",
&core.RelationField{Name: "test", MaxSelect: 2, CollectionId: demo1.Id, Required: true},
func() *core.Record {
record := core.NewRecord(core.NewBaseCollection("test_collection"))
record.SetRaw("test", []string{})
return record
},
true,
},
{
"[multiple] id from other collection",
&core.RelationField{Name: "test", MaxSelect: 2, CollectionId: demo1.Id},
func() *core.Record {
record := core.NewRecord(core.NewBaseCollection("test_collection"))
record.SetRaw("test", []string{"84nmscqy84lsi1t", "achvryl401bhse3"})
return record
},
true,
},
{
"[multiple] valid id",
&core.RelationField{Name: "test", MaxSelect: 2, CollectionId: demo1.Id},
func() *core.Record {
record := core.NewRecord(core.NewBaseCollection("test_collection"))
record.SetRaw("test", []string{"84nmscqy84lsi1t", "al1h9ijdeojtsjy"})
return record
},
false,
},
{
"[multiple] > MaxSelect",
&core.RelationField{Name: "test", MaxSelect: 2, CollectionId: demo1.Id},
func() *core.Record {
record := core.NewRecord(core.NewBaseCollection("test_collection"))
record.SetRaw("test", []string{"84nmscqy84lsi1t", "al1h9ijdeojtsjy", "imy661ixudk5izi"})
return record
},
true,
},
{
"[multiple] < MinSelect",
&core.RelationField{Name: "test", MinSelect: 2, MaxSelect: 99, CollectionId: demo1.Id},
func() *core.Record {
record := core.NewRecord(core.NewBaseCollection("test_collection"))
record.SetRaw("test", []string{"84nmscqy84lsi1t"})
return record
},
true,
},
{
"[multiple] >= MinSelect",
&core.RelationField{Name: "test", MinSelect: 2, MaxSelect: 99, CollectionId: demo1.Id},
func() *core.Record {
record := core.NewRecord(core.NewBaseCollection("test_collection"))
record.SetRaw("test", []string{"84nmscqy84lsi1t", "al1h9ijdeojtsjy", "imy661ixudk5izi"})
return record
},
false,
},
}
for _, s := range scenarios {
t.Run(s.name, func(t *testing.T) {
err := s.field.ValidateValue(context.Background(), app, s.record())
hasErr := err != nil
if hasErr != s.expectError {
t.Fatalf("Expected hasErr %v, got %v (%v)", s.expectError, hasErr, err)
}
})
}
}
func TestRelationFieldValidateSettings(t *testing.T) {
testDefaultFieldIdValidation(t, core.FieldTypeRelation)
testDefaultFieldNameValidation(t, core.FieldTypeRelation)
app, _ := tests.NewTestApp()
defer app.Cleanup()
demo1, err := app.FindCollectionByNameOrId("demo1")
if err != nil {
t.Fatal(err)
}
scenarios := []struct {
name string
field func(col *core.Collection) *core.RelationField
expectErrors []string
}{
{
"zero minimal",
func(col *core.Collection) *core.RelationField {
return &core.RelationField{
Id: "test",
Name: "test",
}
},
[]string{"collectionId"},
},
{
"invalid collectionId",
func(col *core.Collection) *core.RelationField {
return &core.RelationField{
Id: "test",
Name: "test",
CollectionId: demo1.Name,
}
},
[]string{"collectionId"},
},
{
"valid collectionId",
func(col *core.Collection) *core.RelationField {
return &core.RelationField{
Id: "test",
Name: "test",
CollectionId: demo1.Id,
}
},
[]string{},
},
{
"base->view",
func(col *core.Collection) *core.RelationField {
return &core.RelationField{
Id: "test",
Name: "test",
CollectionId: "v9gwnfh02gjq1q0",
}
},
[]string{"collectionId"},
},
{
"view->view",
func(col *core.Collection) *core.RelationField {
col.Type = core.CollectionTypeView
return &core.RelationField{
Id: "test",
Name: "test",
CollectionId: "v9gwnfh02gjq1q0",
}
},
[]string{},
},
{
"MinSelect < 0",
func(col *core.Collection) *core.RelationField {
return &core.RelationField{
Id: "test",
Name: "test",
CollectionId: demo1.Id,
MinSelect: -1,
}
},
[]string{"minSelect"},
},
{
"MinSelect > 0",
func(col *core.Collection) *core.RelationField {
return &core.RelationField{
Id: "test",
Name: "test",
CollectionId: demo1.Id,
MinSelect: 1,
}
},
[]string{"maxSelect"},
},
{
"MaxSelect < MinSelect",
func(col *core.Collection) *core.RelationField {
return &core.RelationField{
Id: "test",
Name: "test",
CollectionId: demo1.Id,
MinSelect: 2,
MaxSelect: 1,
}
},
[]string{"maxSelect"},
},
{
"MaxSelect >= MinSelect",
func(col *core.Collection) *core.RelationField {
return &core.RelationField{
Id: "test",
Name: "test",
CollectionId: demo1.Id,
MinSelect: 2,
MaxSelect: 2,
}
},
[]string{},
},
}
for _, s := range scenarios {
t.Run(s.name, func(t *testing.T) {
collection := core.NewBaseCollection("test_collection")
collection.Fields.GetByName("id").SetId("test") // set a dummy known id so that it can be replaced
field := s.field(collection)
collection.Fields.Add(field)
errs := field.ValidateSettings(context.Background(), app, collection)
tests.TestValidationErrors(t, errs, s.expectErrors)
})
}
}
func TestRelationFieldFindSetter(t *testing.T) {
scenarios := []struct {
name string
key string
value any
field *core.RelationField
hasSetter bool
expected string
}{
{
"no match",
"example",
"b",
&core.RelationField{Name: "test", MaxSelect: 1},
false,
"",
},
{
"exact match (single)",
"test",
"b",
&core.RelationField{Name: "test", MaxSelect: 1},
true,
`"b"`,
},
{
"exact match (multiple)",
"test",
[]string{"a", "b"},
&core.RelationField{Name: "test", MaxSelect: 2},
true,
`["a","b"]`,
},
{
"append (single)",
"test+",
"b",
&core.RelationField{Name: "test", MaxSelect: 1},
true,
`"b"`,
},
{
"append (multiple)",
"test+",
[]string{"a"},
&core.RelationField{Name: "test", MaxSelect: 2},
true,
`["c","d","a"]`,
},
{
"prepend (single)",
"+test",
"b",
&core.RelationField{Name: "test", MaxSelect: 1},
true,
`"d"`, // the last of the existing values
},
{
"prepend (multiple)",
"+test",
[]string{"a"},
&core.RelationField{Name: "test", MaxSelect: 2},
true,
`["a","c","d"]`,
},
{
"subtract (single)",
"test-",
"d",
&core.RelationField{Name: "test", MaxSelect: 1},
true,
`"c"`,
},
{
"subtract (multiple)",
"test-",
[]string{"unknown", "c"},
&core.RelationField{Name: "test", MaxSelect: 2},
true,
`["d"]`,
},
}
for _, s := range scenarios {
t.Run(s.name, func(t *testing.T) {
collection := core.NewBaseCollection("test_collection")
collection.Fields.Add(s.field)
setter := s.field.FindSetter(s.key)
hasSetter := setter != nil
if hasSetter != s.hasSetter {
t.Fatalf("Expected hasSetter %v, got %v", s.hasSetter, hasSetter)
}
if !hasSetter {
return
}
record := core.NewRecord(collection)
record.SetRaw(s.field.GetName(), []string{"c", "d"})
setter(record, s.value)
raw, err := json.Marshal(record.Get(s.field.GetName()))
if err != nil {
t.Fatal(err)
}
rawStr := string(raw)
if rawStr != s.expected {
t.Fatalf("Expected %q, got %q", s.expected, rawStr)
}
})
}
}

275
core/field_select.go Normal file
View file

@ -0,0 +1,275 @@
package core
import (
"context"
"database/sql/driver"
"slices"
validation "github.com/go-ozzo/ozzo-validation/v4"
"github.com/pocketbase/pocketbase/tools/list"
"github.com/pocketbase/pocketbase/tools/types"
)
func init() {
Fields[FieldTypeSelect] = func() Field {
return &SelectField{}
}
}
const FieldTypeSelect = "select"
var (
_ Field = (*SelectField)(nil)
_ MultiValuer = (*SelectField)(nil)
_ DriverValuer = (*SelectField)(nil)
_ SetterFinder = (*SelectField)(nil)
)
// SelectField defines "select" type field for storing single or
// multiple string values from a predefined list.
//
// Requires the Values option to be set.
//
// If MaxSelect is not set or <= 1, then the field value is expected to be a single Values element.
//
// If MaxSelect is > 1, then the field value is expected to be a subset of Values slice.
//
// The respective zero record field value is either empty string (single) or empty string slice (multiple).
//
// ---
//
// The following additional setter keys are available:
//
// - "fieldName+" - append one or more values to the existing record one. For example:
//
// record.Set("roles+", []string{"new1", "new2"}) // []string{"old1", "old2", "new1", "new2"}
//
// - "+fieldName" - prepend one or more values to the existing record one. For example:
//
// record.Set("+roles", []string{"new1", "new2"}) // []string{"new1", "new2", "old1", "old2"}
//
// - "fieldName-" - subtract one or more values from the existing record one. For example:
//
// record.Set("roles-", "old1") // []string{"old2"}
type SelectField struct {
// Name (required) is the unique name of the field.
Name string `form:"name" json:"name"`
// Id is the unique stable field identifier.
//
// It is automatically generated from the name when adding to a collection FieldsList.
Id string `form:"id" json:"id"`
// System prevents the renaming and removal of the field.
System bool `form:"system" json:"system"`
// Hidden hides the field from the API response.
Hidden bool `form:"hidden" json:"hidden"`
// Presentable hints the Dashboard UI to use the underlying
// field record value in the relation preview label.
Presentable bool `form:"presentable" json:"presentable"`
// ---
// Values specifies the list of accepted values.
Values []string `form:"values" json:"values"`
// MaxSelect specifies the max allowed selected values.
//
// For multiple select the value must be > 1, otherwise fallbacks to single (default).
MaxSelect int `form:"maxSelect" json:"maxSelect"`
// Required will require the field value to be non-empty.
Required bool `form:"required" json:"required"`
}
// Type implements [Field.Type] interface method.
func (f *SelectField) Type() string {
return FieldTypeSelect
}
// GetId implements [Field.GetId] interface method.
func (f *SelectField) GetId() string {
return f.Id
}
// SetId implements [Field.SetId] interface method.
func (f *SelectField) SetId(id string) {
f.Id = id
}
// GetName implements [Field.GetName] interface method.
func (f *SelectField) GetName() string {
return f.Name
}
// SetName implements [Field.SetName] interface method.
func (f *SelectField) SetName(name string) {
f.Name = name
}
// GetSystem implements [Field.GetSystem] interface method.
func (f *SelectField) GetSystem() bool {
return f.System
}
// SetSystem implements [Field.SetSystem] interface method.
func (f *SelectField) SetSystem(system bool) {
f.System = system
}
// GetHidden implements [Field.GetHidden] interface method.
func (f *SelectField) GetHidden() bool {
return f.Hidden
}
// SetHidden implements [Field.SetHidden] interface method.
func (f *SelectField) SetHidden(hidden bool) {
f.Hidden = hidden
}
// IsMultiple implements [MultiValuer] interface and checks whether the
// current field options support multiple values.
func (f *SelectField) IsMultiple() bool {
return f.MaxSelect > 1
}
// ColumnType implements [Field.ColumnType] interface method.
func (f *SelectField) ColumnType(app App) string {
if f.IsMultiple() {
return "JSON DEFAULT '[]' NOT NULL"
}
return "TEXT DEFAULT '' NOT NULL"
}
// PrepareValue implements [Field.PrepareValue] interface method.
func (f *SelectField) PrepareValue(record *Record, raw any) (any, error) {
return f.normalizeValue(raw), nil
}
func (f *SelectField) normalizeValue(raw any) any {
val := list.ToUniqueStringSlice(raw)
if !f.IsMultiple() {
if len(val) > 0 {
return val[len(val)-1] // the last selected
}
return ""
}
return val
}
// DriverValue implements the [DriverValuer] interface.
func (f *SelectField) DriverValue(record *Record) (driver.Value, error) {
val := list.ToUniqueStringSlice(record.GetRaw(f.Name))
if !f.IsMultiple() {
if len(val) > 0 {
return val[len(val)-1], nil // the last selected
}
return "", nil
}
// serialize as json string array
return append(types.JSONArray[string]{}, val...), nil
}
// ValidateValue implements [Field.ValidateValue] interface method.
func (f *SelectField) ValidateValue(ctx context.Context, app App, record *Record) error {
normalizedVal := list.ToUniqueStringSlice(record.GetRaw(f.Name))
if len(normalizedVal) == 0 {
if f.Required {
return validation.ErrRequired
}
return nil // nothing to check
}
maxSelect := max(f.MaxSelect, 1)
// check max selected items
if len(normalizedVal) > maxSelect {
return validation.NewError("validation_too_many_values", "Select no more than {{.maxSelect}}").
SetParams(map[string]any{"maxSelect": maxSelect})
}
// check against the allowed values
for _, val := range normalizedVal {
if !slices.Contains(f.Values, val) {
return validation.NewError("validation_invalid_value", "Invalid value {{.value}}").
SetParams(map[string]any{"value": val})
}
}
return nil
}
// ValidateSettings implements [Field.ValidateSettings] interface method.
func (f *SelectField) ValidateSettings(ctx context.Context, app App, collection *Collection) error {
max := len(f.Values)
if max == 0 {
max = 1
}
return validation.ValidateStruct(f,
validation.Field(&f.Id, validation.By(DefaultFieldIdValidationRule)),
validation.Field(&f.Name, validation.By(DefaultFieldNameValidationRule)),
validation.Field(&f.Values, validation.Required),
validation.Field(&f.MaxSelect, validation.Min(0), validation.Max(max)),
)
}
// FindSetter implements the [SetterFinder] interface.
func (f *SelectField) FindSetter(key string) SetterFunc {
switch key {
case f.Name:
return f.setValue
case "+" + f.Name:
return f.prependValue
case f.Name + "+":
return f.appendValue
case f.Name + "-":
return f.subtractValue
default:
return nil
}
}
func (f *SelectField) setValue(record *Record, raw any) {
record.SetRaw(f.Name, f.normalizeValue(raw))
}
func (f *SelectField) appendValue(record *Record, modifierValue any) {
val := record.GetRaw(f.Name)
val = append(
list.ToUniqueStringSlice(val),
list.ToUniqueStringSlice(modifierValue)...,
)
f.setValue(record, val)
}
func (f *SelectField) prependValue(record *Record, modifierValue any) {
val := record.GetRaw(f.Name)
val = append(
list.ToUniqueStringSlice(modifierValue),
list.ToUniqueStringSlice(val)...,
)
f.setValue(record, val)
}
func (f *SelectField) subtractValue(record *Record, modifierValue any) {
val := record.GetRaw(f.Name)
val = list.SubtractSlice(
list.ToUniqueStringSlice(val),
list.ToUniqueStringSlice(modifierValue),
)
f.setValue(record, val)
}

516
core/field_select_test.go Normal file
View file

@ -0,0 +1,516 @@
package core_test
import (
"context"
"encoding/json"
"fmt"
"testing"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tests"
"github.com/pocketbase/pocketbase/tools/types"
)
func TestSelectFieldBaseMethods(t *testing.T) {
testFieldBaseMethods(t, core.FieldTypeSelect)
}
func TestSelectFieldColumnType(t *testing.T) {
app, _ := tests.NewTestApp()
defer app.Cleanup()
scenarios := []struct {
name string
field *core.SelectField
expected string
}{
{
"single (zero)",
&core.SelectField{},
"TEXT DEFAULT '' NOT NULL",
},
{
"single",
&core.SelectField{MaxSelect: 1},
"TEXT DEFAULT '' NOT NULL",
},
{
"multiple",
&core.SelectField{MaxSelect: 2},
"JSON DEFAULT '[]' NOT NULL",
},
}
for _, s := range scenarios {
t.Run(s.name, func(t *testing.T) {
if v := s.field.ColumnType(app); v != s.expected {
t.Fatalf("Expected\n%q\ngot\n%q", s.expected, v)
}
})
}
}
func TestSelectFieldIsMultiple(t *testing.T) {
scenarios := []struct {
name string
field *core.SelectField
expected bool
}{
{
"single (zero)",
&core.SelectField{},
false,
},
{
"single",
&core.SelectField{MaxSelect: 1},
false,
},
{
"multiple (>1)",
&core.SelectField{MaxSelect: 2},
true,
},
}
for _, s := range scenarios {
t.Run(s.name, func(t *testing.T) {
if v := s.field.IsMultiple(); v != s.expected {
t.Fatalf("Expected %v, got %v", s.expected, v)
}
})
}
}
func TestSelectFieldPrepareValue(t *testing.T) {
app, _ := tests.NewTestApp()
defer app.Cleanup()
record := core.NewRecord(core.NewBaseCollection("test"))
scenarios := []struct {
raw any
field *core.SelectField
expected string
}{
// single
{nil, &core.SelectField{}, `""`},
{"", &core.SelectField{}, `""`},
{123, &core.SelectField{}, `"123"`},
{"a", &core.SelectField{}, `"a"`},
{`["a"]`, &core.SelectField{}, `"a"`},
{[]string{}, &core.SelectField{}, `""`},
{[]string{"a", "b"}, &core.SelectField{}, `"b"`},
// multiple
{nil, &core.SelectField{MaxSelect: 2}, `[]`},
{"", &core.SelectField{MaxSelect: 2}, `[]`},
{123, &core.SelectField{MaxSelect: 2}, `["123"]`},
{"a", &core.SelectField{MaxSelect: 2}, `["a"]`},
{`["a"]`, &core.SelectField{MaxSelect: 2}, `["a"]`},
{[]string{}, &core.SelectField{MaxSelect: 2}, `[]`},
{[]string{"a", "b", "c"}, &core.SelectField{MaxSelect: 2}, `["a","b","c"]`},
}
for i, s := range scenarios {
t.Run(fmt.Sprintf("%d_%#v_%v", i, s.raw, s.field.IsMultiple()), func(t *testing.T) {
v, err := s.field.PrepareValue(record, s.raw)
if err != nil {
t.Fatal(err)
}
vRaw, err := json.Marshal(v)
if err != nil {
t.Fatal(err)
}
if string(vRaw) != s.expected {
t.Fatalf("Expected %q, got %q", s.expected, vRaw)
}
})
}
}
func TestSelectFieldDriverValue(t *testing.T) {
app, _ := tests.NewTestApp()
defer app.Cleanup()
scenarios := []struct {
raw any
field *core.SelectField
expected string
}{
// single
{nil, &core.SelectField{}, `""`},
{"", &core.SelectField{}, `""`},
{123, &core.SelectField{}, `"123"`},
{"a", &core.SelectField{}, `"a"`},
{`["a"]`, &core.SelectField{}, `"a"`},
{[]string{}, &core.SelectField{}, `""`},
{[]string{"a", "b"}, &core.SelectField{}, `"b"`},
// multiple
{nil, &core.SelectField{MaxSelect: 2}, `[]`},
{"", &core.SelectField{MaxSelect: 2}, `[]`},
{123, &core.SelectField{MaxSelect: 2}, `["123"]`},
{"a", &core.SelectField{MaxSelect: 2}, `["a"]`},
{`["a"]`, &core.SelectField{MaxSelect: 2}, `["a"]`},
{[]string{}, &core.SelectField{MaxSelect: 2}, `[]`},
{[]string{"a", "b", "c"}, &core.SelectField{MaxSelect: 2}, `["a","b","c"]`},
}
for i, s := range scenarios {
t.Run(fmt.Sprintf("%d_%#v_%v", i, s.raw, s.field.IsMultiple()), func(t *testing.T) {
record := core.NewRecord(core.NewBaseCollection("test"))
record.SetRaw(s.field.GetName(), s.raw)
v, err := s.field.DriverValue(record)
if err != nil {
t.Fatal(err)
}
if s.field.IsMultiple() {
_, ok := v.(types.JSONArray[string])
if !ok {
t.Fatalf("Expected types.JSONArray value, got %T", v)
}
} else {
_, ok := v.(string)
if !ok {
t.Fatalf("Expected string value, got %T", v)
}
}
vRaw, err := json.Marshal(v)
if err != nil {
t.Fatal(err)
}
if string(vRaw) != s.expected {
t.Fatalf("Expected %q, got %q", s.expected, vRaw)
}
})
}
}
func TestSelectFieldValidateValue(t *testing.T) {
app, _ := tests.NewTestApp()
defer app.Cleanup()
collection := core.NewBaseCollection("test_collection")
values := []string{"a", "b", "c"}
scenarios := []struct {
name string
field *core.SelectField
record func() *core.Record
expectError bool
}{
// single
{
"[single] zero field value (not required)",
&core.SelectField{Name: "test", Values: values, MaxSelect: 1},
func() *core.Record {
record := core.NewRecord(collection)
record.SetRaw("test", "")
return record
},
false,
},
{
"[single] zero field value (required)",
&core.SelectField{Name: "test", Values: values, MaxSelect: 1, Required: true},
func() *core.Record {
record := core.NewRecord(collection)
record.SetRaw("test", "")
return record
},
true,
},
{
"[single] unknown value",
&core.SelectField{Name: "test", Values: values, MaxSelect: 1},
func() *core.Record {
record := core.NewRecord(collection)
record.SetRaw("test", "unknown")
return record
},
true,
},
{
"[single] known value",
&core.SelectField{Name: "test", Values: values, MaxSelect: 1},
func() *core.Record {
record := core.NewRecord(collection)
record.SetRaw("test", "a")
return record
},
false,
},
{
"[single] > MaxSelect",
&core.SelectField{Name: "test", Values: values, MaxSelect: 1},
func() *core.Record {
record := core.NewRecord(collection)
record.SetRaw("test", []string{"a", "b"})
return record
},
true,
},
// multiple
{
"[multiple] zero field value (not required)",
&core.SelectField{Name: "test", Values: values, MaxSelect: 2},
func() *core.Record {
record := core.NewRecord(collection)
record.SetRaw("test", []string{})
return record
},
false,
},
{
"[multiple] zero field value (required)",
&core.SelectField{Name: "test", Values: values, MaxSelect: 2, Required: true},
func() *core.Record {
record := core.NewRecord(collection)
record.SetRaw("test", []string{})
return record
},
true,
},
{
"[multiple] unknown value",
&core.SelectField{Name: "test", Values: values, MaxSelect: 2},
func() *core.Record {
record := core.NewRecord(collection)
record.SetRaw("test", []string{"a", "unknown"})
return record
},
true,
},
{
"[multiple] known value",
&core.SelectField{Name: "test", Values: values, MaxSelect: 2},
func() *core.Record {
record := core.NewRecord(collection)
record.SetRaw("test", []string{"a", "b"})
return record
},
false,
},
{
"[multiple] > MaxSelect",
&core.SelectField{Name: "test", Values: values, MaxSelect: 2},
func() *core.Record {
record := core.NewRecord(collection)
record.SetRaw("test", []string{"a", "b", "c"})
return record
},
true,
},
{
"[multiple] > MaxSelect (duplicated values)",
&core.SelectField{Name: "test", Values: values, MaxSelect: 2},
func() *core.Record {
record := core.NewRecord(collection)
record.SetRaw("test", []string{"a", "b", "b", "a"})
return record
},
false,
},
}
for _, s := range scenarios {
t.Run(s.name, func(t *testing.T) {
err := s.field.ValidateValue(context.Background(), app, s.record())
hasErr := err != nil
if hasErr != s.expectError {
t.Fatalf("Expected hasErr %v, got %v (%v)", s.expectError, hasErr, err)
}
})
}
}
func TestSelectFieldValidateSettings(t *testing.T) {
testDefaultFieldIdValidation(t, core.FieldTypeSelect)
testDefaultFieldNameValidation(t, core.FieldTypeSelect)
app, _ := tests.NewTestApp()
defer app.Cleanup()
scenarios := []struct {
name string
field func() *core.SelectField
expectErrors []string
}{
{
"zero minimal",
func() *core.SelectField {
return &core.SelectField{
Id: "test",
Name: "test",
}
},
[]string{"values"},
},
{
"MaxSelect > Values length",
func() *core.SelectField {
return &core.SelectField{
Id: "test",
Name: "test",
Values: []string{"a", "b"},
MaxSelect: 3,
}
},
[]string{"maxSelect"},
},
{
"MaxSelect <= Values length",
func() *core.SelectField {
return &core.SelectField{
Id: "test",
Name: "test",
Values: []string{"a", "b"},
MaxSelect: 2,
}
},
[]string{},
},
}
for _, s := range scenarios {
t.Run(s.name, func(t *testing.T) {
field := s.field()
collection := core.NewBaseCollection("test_collection")
collection.Fields.Add(field)
errs := field.ValidateSettings(context.Background(), app, collection)
tests.TestValidationErrors(t, errs, s.expectErrors)
})
}
}
func TestSelectFieldFindSetter(t *testing.T) {
values := []string{"a", "b", "c", "d"}
scenarios := []struct {
name string
key string
value any
field *core.SelectField
hasSetter bool
expected string
}{
{
"no match",
"example",
"b",
&core.SelectField{Name: "test", MaxSelect: 1, Values: values},
false,
"",
},
{
"exact match (single)",
"test",
"b",
&core.SelectField{Name: "test", MaxSelect: 1, Values: values},
true,
`"b"`,
},
{
"exact match (multiple)",
"test",
[]string{"a", "b"},
&core.SelectField{Name: "test", MaxSelect: 2, Values: values},
true,
`["a","b"]`,
},
{
"append (single)",
"test+",
"b",
&core.SelectField{Name: "test", MaxSelect: 1, Values: values},
true,
`"b"`,
},
{
"append (multiple)",
"test+",
[]string{"a"},
&core.SelectField{Name: "test", MaxSelect: 2, Values: values},
true,
`["c","d","a"]`,
},
{
"prepend (single)",
"+test",
"b",
&core.SelectField{Name: "test", MaxSelect: 1, Values: values},
true,
`"d"`, // the last of the existing values
},
{
"prepend (multiple)",
"+test",
[]string{"a"},
&core.SelectField{Name: "test", MaxSelect: 2, Values: values},
true,
`["a","c","d"]`,
},
{
"subtract (single)",
"test-",
"d",
&core.SelectField{Name: "test", MaxSelect: 1, Values: values},
true,
`"c"`,
},
{
"subtract (multiple)",
"test-",
[]string{"unknown", "c"},
&core.SelectField{Name: "test", MaxSelect: 2, Values: values},
true,
`["d"]`,
},
}
for _, s := range scenarios {
t.Run(s.name, func(t *testing.T) {
collection := core.NewBaseCollection("test_collection")
collection.Fields.Add(s.field)
setter := s.field.FindSetter(s.key)
hasSetter := setter != nil
if hasSetter != s.hasSetter {
t.Fatalf("Expected hasSetter %v, got %v", s.hasSetter, hasSetter)
}
if !hasSetter {
return
}
record := core.NewRecord(collection)
record.SetRaw(s.field.GetName(), []string{"c", "d"})
setter(record, s.value)
raw, err := json.Marshal(record.Get(s.field.GetName()))
if err != nil {
t.Fatal(err)
}
rawStr := string(raw)
if rawStr != s.expected {
t.Fatalf("Expected %q, got %q", s.expected, rawStr)
}
})
}
}

261
core/field_test.go Normal file
View file

@ -0,0 +1,261 @@
package core_test
import (
"context"
"strings"
"testing"
validation "github.com/go-ozzo/ozzo-validation/v4"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tests"
)
func testFieldBaseMethods(t *testing.T, fieldType string) {
factory, ok := core.Fields[fieldType]
if !ok {
t.Fatalf("Missing %q field factory", fieldType)
}
f := factory()
if f == nil {
t.Fatal("Expected non-nil Field instance")
}
t.Run("type", func(t *testing.T) {
if v := f.Type(); v != fieldType {
t.Fatalf("Expected type %q, got %q", fieldType, v)
}
})
t.Run("id", func(t *testing.T) {
testValues := []string{"new_id", ""}
for _, expected := range testValues {
f.SetId(expected)
if v := f.GetId(); v != expected {
t.Fatalf("Expected id %q, got %q", expected, v)
}
}
})
t.Run("name", func(t *testing.T) {
testValues := []string{"new_name", ""}
for _, expected := range testValues {
f.SetName(expected)
if v := f.GetName(); v != expected {
t.Fatalf("Expected name %q, got %q", expected, v)
}
}
})
t.Run("system", func(t *testing.T) {
testValues := []bool{false, true}
for _, expected := range testValues {
f.SetSystem(expected)
if v := f.GetSystem(); v != expected {
t.Fatalf("Expected system %v, got %v", expected, v)
}
}
})
t.Run("hidden", func(t *testing.T) {
testValues := []bool{false, true}
for _, expected := range testValues {
f.SetHidden(expected)
if v := f.GetHidden(); v != expected {
t.Fatalf("Expected hidden %v, got %v", expected, v)
}
}
})
}
func testDefaultFieldIdValidation(t *testing.T, fieldType string) {
app, _ := tests.NewTestApp()
defer app.Cleanup()
collection := core.NewBaseCollection("test_collection")
scenarios := []struct {
name string
field func() core.Field
expectError bool
}{
{
"empty value",
func() core.Field {
f := core.Fields[fieldType]()
return f
},
true,
},
{
"invalid length",
func() core.Field {
f := core.Fields[fieldType]()
f.SetId(strings.Repeat("a", 101))
return f
},
true,
},
{
"valid length",
func() core.Field {
f := core.Fields[fieldType]()
f.SetId(strings.Repeat("a", 100))
return f
},
false,
},
}
for _, s := range scenarios {
t.Run("[id] "+s.name, func(t *testing.T) {
errs, _ := s.field().ValidateSettings(context.Background(), app, collection).(validation.Errors)
hasErr := errs["id"] != nil
if hasErr != s.expectError {
t.Fatalf("Expected hasErr %v, got %v", s.expectError, hasErr)
}
})
}
}
func testDefaultFieldNameValidation(t *testing.T, fieldType string) {
app, _ := tests.NewTestApp()
defer app.Cleanup()
collection := core.NewBaseCollection("test_collection")
scenarios := []struct {
name string
field func() core.Field
expectError bool
}{
{
"empty value",
func() core.Field {
f := core.Fields[fieldType]()
return f
},
true,
},
{
"invalid length",
func() core.Field {
f := core.Fields[fieldType]()
f.SetName(strings.Repeat("a", 101))
return f
},
true,
},
{
"valid length",
func() core.Field {
f := core.Fields[fieldType]()
f.SetName(strings.Repeat("a", 100))
return f
},
false,
},
{
"invalid regex",
func() core.Field {
f := core.Fields[fieldType]()
f.SetName("test(")
return f
},
true,
},
{
"valid regex",
func() core.Field {
f := core.Fields[fieldType]()
f.SetName("test_123")
return f
},
false,
},
{
"_via_",
func() core.Field {
f := core.Fields[fieldType]()
f.SetName("a_via_b")
return f
},
true,
},
{
"system reserved - null",
func() core.Field {
f := core.Fields[fieldType]()
f.SetName("null")
return f
},
true,
},
{
"system reserved - false",
func() core.Field {
f := core.Fields[fieldType]()
f.SetName("false")
return f
},
true,
},
{
"system reserved - true",
func() core.Field {
f := core.Fields[fieldType]()
f.SetName("true")
return f
},
true,
},
{
"system reserved - _rowid_",
func() core.Field {
f := core.Fields[fieldType]()
f.SetName("_rowid_")
return f
},
true,
},
{
"system reserved - expand",
func() core.Field {
f := core.Fields[fieldType]()
f.SetName("expand")
return f
},
true,
},
{
"system reserved - collectionId",
func() core.Field {
f := core.Fields[fieldType]()
f.SetName("collectionId")
return f
},
true,
},
{
"system reserved - collectionName",
func() core.Field {
f := core.Fields[fieldType]()
f.SetName("collectionName")
return f
},
true,
},
}
for _, s := range scenarios {
t.Run("[name] "+s.name, func(t *testing.T) {
errs, _ := s.field().ValidateSettings(context.Background(), app, collection).(validation.Errors)
hasErr := errs["name"] != nil
if hasErr != s.expectError {
t.Fatalf("Expected hasErr %v, got %v", s.expectError, hasErr)
}
})
}
}

367
core/field_text.go Normal file
View file

@ -0,0 +1,367 @@
package core
import (
"context"
"database/sql"
"errors"
"fmt"
"regexp"
"strings"
validation "github.com/go-ozzo/ozzo-validation/v4"
"github.com/pocketbase/dbx"
"github.com/pocketbase/pocketbase/core/validators"
"github.com/pocketbase/pocketbase/tools/security"
"github.com/spf13/cast"
)
func init() {
Fields[FieldTypeText] = func() Field {
return &TextField{}
}
}
const FieldTypeText = "text"
const autogenerateModifier = ":autogenerate"
var (
_ Field = (*TextField)(nil)
_ SetterFinder = (*TextField)(nil)
_ RecordInterceptor = (*TextField)(nil)
)
// TextField defines "text" type field for storing any string value.
//
// The respective zero record field value is empty string.
//
// The following additional setter keys are available:
//
// - "fieldName:autogenerate" - autogenerate field value if AutogeneratePattern is set. For example:
//
// record.Set("slug:autogenerate", "") // [random value]
// record.Set("slug:autogenerate", "abc-") // abc-[random value]
type TextField struct {
// Name (required) is the unique name of the field.
Name string `form:"name" json:"name"`
// Id is the unique stable field identifier.
//
// It is automatically generated from the name when adding to a collection FieldsList.
Id string `form:"id" json:"id"`
// System prevents the renaming and removal of the field.
System bool `form:"system" json:"system"`
// Hidden hides the field from the API response.
Hidden bool `form:"hidden" json:"hidden"`
// Presentable hints the Dashboard UI to use the underlying
// field record value in the relation preview label.
Presentable bool `form:"presentable" json:"presentable"`
// ---
// Min specifies the minimum required string characters.
//
// if zero value, no min limit is applied.
Min int `form:"min" json:"min"`
// Max specifies the maximum allowed string characters.
//
// If zero, a default limit of 5000 is applied.
Max int `form:"max" json:"max"`
// Pattern specifies an optional regex pattern to match against the field value.
//
// Leave it empty to skip the pattern check.
Pattern string `form:"pattern" json:"pattern"`
// AutogeneratePattern specifies an optional regex pattern that could
// be used to generate random string from it and set it automatically
// on record create if no explicit value is set or when the `:autogenerate` modifier is used.
//
// Note: the generated value still needs to satisfy min, max, pattern (if set)
AutogeneratePattern string `form:"autogeneratePattern" json:"autogeneratePattern"`
// Required will require the field value to be non-empty string.
Required bool `form:"required" json:"required"`
// PrimaryKey will mark the field as primary key.
//
// A single collection can have only 1 field marked as primary key.
PrimaryKey bool `form:"primaryKey" json:"primaryKey"`
}
// Type implements [Field.Type] interface method.
func (f *TextField) Type() string {
return FieldTypeText
}
// GetId implements [Field.GetId] interface method.
func (f *TextField) GetId() string {
return f.Id
}
// SetId implements [Field.SetId] interface method.
func (f *TextField) SetId(id string) {
f.Id = id
}
// GetName implements [Field.GetName] interface method.
func (f *TextField) GetName() string {
return f.Name
}
// SetName implements [Field.SetName] interface method.
func (f *TextField) SetName(name string) {
f.Name = name
}
// GetSystem implements [Field.GetSystem] interface method.
func (f *TextField) GetSystem() bool {
return f.System
}
// SetSystem implements [Field.SetSystem] interface method.
func (f *TextField) SetSystem(system bool) {
f.System = system
}
// GetHidden implements [Field.GetHidden] interface method.
func (f *TextField) GetHidden() bool {
return f.Hidden
}
// SetHidden implements [Field.SetHidden] interface method.
func (f *TextField) SetHidden(hidden bool) {
f.Hidden = hidden
}
// ColumnType implements [Field.ColumnType] interface method.
func (f *TextField) ColumnType(app App) string {
if f.PrimaryKey {
// note: the default is just a last resort fallback to avoid empty
// string values in case the record was inserted with raw sql and
// it is not actually used when operating with the db abstraction
return "TEXT PRIMARY KEY DEFAULT ('r'||lower(hex(randomblob(7)))) NOT NULL"
}
return "TEXT DEFAULT '' NOT NULL"
}
// PrepareValue implements [Field.PrepareValue] interface method.
func (f *TextField) PrepareValue(record *Record, raw any) (any, error) {
return cast.ToString(raw), nil
}
var forbiddenPKChars = []string{"/", "\\"}
// ValidateValue implements [Field.ValidateValue] interface method.
func (f *TextField) ValidateValue(ctx context.Context, app App, record *Record) error {
newVal, ok := record.GetRaw(f.Name).(string)
if !ok {
return validators.ErrUnsupportedValueType
}
if f.PrimaryKey {
// disallow PK change
if !record.IsNew() {
oldVal := record.LastSavedPK()
if oldVal != newVal {
return validation.NewError("validation_pk_change", "The record primary key cannot be changed.")
}
if oldVal != "" {
// no need to further validate because the id can't be updated
// and because the id could have been inserted manually by migration from another system
// that may not comply with the user defined PocketBase validations
return nil
}
} else {
// disallow PK special characters no matter of the Pattern validator to minimize
// side-effects when the primary key is used for example in a directory path
for _, c := range forbiddenPKChars {
if strings.Contains(newVal, c) {
return validation.NewError("validation_pk_forbidden", "The record primary key contains forbidden characters.").
SetParams(map[string]any{"forbidden": c})
}
}
// this technically shouldn't be necessarily but again to
// minimize misuse of the Pattern validator that could cause
// side-effects on some platforms check for duplicates in a case-insensitive manner
//
// (@todo eventually may get replaced in the future with a system unique constraint to avoid races or wrapping the request in a transaction)
if f.Pattern != defaultLowercaseRecordIdPattern {
var exists int
err := app.ConcurrentDB().
Select("(1)").
From(record.TableName()).
Where(dbx.NewExp("id = {:id} COLLATE NOCASE", dbx.Params{"id": newVal})).
Limit(1).
Row(&exists)
if exists > 0 || (err != nil && !errors.Is(err, sql.ErrNoRows)) {
return validation.NewError("validation_pk_invalid", "The record primary key is invalid or already exists.")
}
}
}
}
return f.ValidatePlainValue(newVal)
}
// ValidatePlainValue validates the provided string against the field options.
func (f *TextField) ValidatePlainValue(value string) error {
if f.Required || f.PrimaryKey {
if err := validation.Required.Validate(value); err != nil {
return err
}
}
if value == "" {
return nil // nothing to check
}
// note: casted to []rune to count multi-byte chars as one
length := len([]rune(value))
if f.Min > 0 && length < f.Min {
return validation.NewError("validation_min_text_constraint", "Must be at least {{.min}} character(s)").
SetParams(map[string]any{"min": f.Min})
}
max := f.Max
if max == 0 {
max = 5000
}
if max > 0 && length > max {
return validation.NewError("validation_max_text_constraint", "Must be no more than {{.max}} character(s)").
SetParams(map[string]any{"max": max})
}
if f.Pattern != "" {
match, _ := regexp.MatchString(f.Pattern, value)
if !match {
return validation.NewError("validation_invalid_format", "Invalid value format")
}
}
return nil
}
// ValidateSettings implements [Field.ValidateSettings] interface method.
func (f *TextField) ValidateSettings(ctx context.Context, app App, collection *Collection) error {
return validation.ValidateStruct(f,
validation.Field(&f.Id, validation.By(DefaultFieldIdValidationRule)),
validation.Field(&f.Name,
validation.By(DefaultFieldNameValidationRule),
validation.When(f.PrimaryKey, validation.In(idColumn).Error(`The primary key must be named "id".`)),
),
validation.Field(&f.PrimaryKey, validation.By(f.checkOtherFieldsForPK(collection))),
validation.Field(&f.Min, validation.Min(0), validation.Max(maxSafeJSONInt)),
validation.Field(&f.Max, validation.Min(f.Min), validation.Max(maxSafeJSONInt)),
validation.Field(&f.Pattern, validation.When(f.PrimaryKey, validation.Required), validation.By(validators.IsRegex)),
validation.Field(&f.Hidden, validation.When(f.PrimaryKey, validation.Empty)),
validation.Field(&f.Required, validation.When(f.PrimaryKey, validation.Required)),
validation.Field(&f.AutogeneratePattern, validation.By(validators.IsRegex), validation.By(f.checkAutogeneratePattern)),
)
}
func (f *TextField) checkOtherFieldsForPK(collection *Collection) validation.RuleFunc {
return func(value any) error {
v, _ := value.(bool)
if !v {
return nil // not a pk
}
totalPrimaryKeys := 0
for _, field := range collection.Fields {
if text, ok := field.(*TextField); ok && text.PrimaryKey {
totalPrimaryKeys++
}
if totalPrimaryKeys > 1 {
return validation.NewError("validation_unsupported_composite_pk", "Composite PKs are not supported and the collection must have only 1 PK.")
}
}
return nil
}
}
func (f *TextField) checkAutogeneratePattern(value any) error {
v, _ := value.(string)
if v == "" {
return nil // nothing to check
}
// run 10 tests to check for conflicts with the other field validators
for i := 0; i < 10; i++ {
generated, err := security.RandomStringByRegex(v)
if err != nil {
return validation.NewError("validation_invalid_autogenerate_pattern", err.Error())
}
// (loosely) check whether the generated pattern satisfies the current field settings
if err := f.ValidatePlainValue(generated); err != nil {
return validation.NewError(
"validation_invalid_autogenerate_pattern_value",
fmt.Sprintf("The provided autogenerate pattern could produce invalid field values, ex.: %q", generated),
)
}
}
return nil
}
// Intercept implements the [RecordInterceptor] interface.
func (f *TextField) Intercept(
ctx context.Context,
app App,
record *Record,
actionName string,
actionFunc func() error,
) error {
// set autogenerated value if missing for new records
switch actionName {
case InterceptorActionValidate, InterceptorActionCreate:
if f.AutogeneratePattern != "" && f.hasZeroValue(record) && record.IsNew() {
v, err := security.RandomStringByRegex(f.AutogeneratePattern)
if err != nil {
return fmt.Errorf("failed to autogenerate %q value: %w", f.Name, err)
}
record.SetRaw(f.Name, v)
}
}
return actionFunc()
}
func (f *TextField) hasZeroValue(record *Record) bool {
v, _ := record.GetRaw(f.Name).(string)
return v == ""
}
// FindSetter implements the [SetterFinder] interface.
func (f *TextField) FindSetter(key string) SetterFunc {
switch key {
case f.Name:
return func(record *Record, raw any) {
record.SetRaw(f.Name, cast.ToString(raw))
}
case f.Name + autogenerateModifier:
return func(record *Record, raw any) {
v := cast.ToString(raw)
if f.AutogeneratePattern != "" {
generated, _ := security.RandomStringByRegex(f.AutogeneratePattern)
v += generated
}
record.SetRaw(f.Name, v)
}
default:
return nil
}
}

654
core/field_text_test.go Normal file
View file

@ -0,0 +1,654 @@
package core_test
import (
"context"
"fmt"
"strings"
"testing"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tests"
)
func TestTextFieldBaseMethods(t *testing.T) {
testFieldBaseMethods(t, core.FieldTypeText)
}
func TestTextFieldColumnType(t *testing.T) {
app, _ := tests.NewTestApp()
defer app.Cleanup()
f := &core.TextField{}
expected := "TEXT DEFAULT '' NOT NULL"
if v := f.ColumnType(app); v != expected {
t.Fatalf("Expected\n%q\ngot\n%q", expected, v)
}
}
func TestTextFieldPrepareValue(t *testing.T) {
app, _ := tests.NewTestApp()
defer app.Cleanup()
f := &core.TextField{}
record := core.NewRecord(core.NewBaseCollection("test"))
scenarios := []struct {
raw any
expected string
}{
{"", ""},
{"test", "test"},
{false, "false"},
{true, "true"},
{123.456, "123.456"},
}
for i, s := range scenarios {
t.Run(fmt.Sprintf("%d_%#v", i, s.raw), func(t *testing.T) {
v, err := f.PrepareValue(record, s.raw)
if err != nil {
t.Fatal(err)
}
vStr, ok := v.(string)
if !ok {
t.Fatalf("Expected string instance, got %T", v)
}
if vStr != s.expected {
t.Fatalf("Expected %q, got %q", s.expected, v)
}
})
}
}
func TestTextFieldValidateValue(t *testing.T) {
app, _ := tests.NewTestApp()
defer app.Cleanup()
collection, err := app.FindCollectionByNameOrId("demo1")
if err != nil {
t.Fatal(err)
}
existingRecord, err := app.FindFirstRecordByFilter(collection, "id != ''")
if err != nil {
t.Fatal(err)
}
scenarios := []struct {
name string
field *core.TextField
record func() *core.Record
expectError bool
}{
{
"invalid raw value",
&core.TextField{Name: "test"},
func() *core.Record {
record := core.NewRecord(collection)
record.SetRaw("test", 123)
return record
},
true,
},
{
"zero field value (not required)",
&core.TextField{Name: "test", Pattern: `\d+`, Min: 10, Max: 100}, // other fields validators should be ignored
func() *core.Record {
record := core.NewRecord(collection)
record.SetRaw("test", "")
return record
},
false,
},
{
"zero field value (required)",
&core.TextField{Name: "test", Required: true},
func() *core.Record {
record := core.NewRecord(collection)
record.SetRaw("test", "")
return record
},
true,
},
{
"non-zero field value (required)",
&core.TextField{Name: "test", Required: true},
func() *core.Record {
record := core.NewRecord(collection)
record.SetRaw("test", "abc")
return record
},
false,
},
{
"special forbidden character / (non-primaryKey)",
&core.TextField{Name: "test", PrimaryKey: false},
func() *core.Record {
record := core.NewRecord(collection)
record.SetRaw("test", "/")
return record
},
false,
},
{
"special forbidden character \\ (non-primaryKey)",
&core.TextField{Name: "test", PrimaryKey: false},
func() *core.Record {
record := core.NewRecord(collection)
record.SetRaw("test", "\\")
return record
},
false,
},
{
"special forbidden character / (primaryKey)",
&core.TextField{Name: "test", PrimaryKey: true},
func() *core.Record {
record := core.NewRecord(collection)
record.SetRaw("test", "/")
return record
},
true,
},
{
"special forbidden character \\ (primaryKey)",
&core.TextField{Name: "test", PrimaryKey: true},
func() *core.Record {
record := core.NewRecord(collection)
record.SetRaw("test", "\\")
return record
},
true,
},
{
"zero field value (primaryKey)",
&core.TextField{Name: "test", PrimaryKey: true},
func() *core.Record {
record := core.NewRecord(collection)
record.SetRaw("test", "")
return record
},
true,
},
{
"non-zero field value (primaryKey)",
&core.TextField{Name: "test", PrimaryKey: true},
func() *core.Record {
record := core.NewRecord(collection)
record.SetRaw("test", "abcd")
return record
},
false,
},
{
"case-insensitive duplicated primary key check",
&core.TextField{Name: "test", PrimaryKey: true},
func() *core.Record {
record := core.NewRecord(collection)
record.SetRaw("test", strings.ToUpper(existingRecord.Id))
return record
},
true,
},
{
"< min",
&core.TextField{Name: "test", Min: 4},
func() *core.Record {
record := core.NewRecord(collection)
record.SetRaw("test", "абв") // multi-byte
return record
},
true,
},
{
">= min",
&core.TextField{Name: "test", Min: 3},
func() *core.Record {
record := core.NewRecord(collection)
record.SetRaw("test", "абв") // multi-byte
return record
},
false,
},
{
"> default max",
&core.TextField{Name: "test"},
func() *core.Record {
record := core.NewRecord(collection)
record.SetRaw("test", strings.Repeat("a", 5001))
return record
},
true,
},
{
"<= default max",
&core.TextField{Name: "test"},
func() *core.Record {
record := core.NewRecord(collection)
record.SetRaw("test", strings.Repeat("a", 500))
return record
},
false,
},
{
"> max",
&core.TextField{Name: "test", Max: 2},
func() *core.Record {
record := core.NewRecord(collection)
record.SetRaw("test", "абв") // multi-byte
return record
},
true,
},
{
"<= max",
&core.TextField{Name: "test", Min: 3},
func() *core.Record {
record := core.NewRecord(collection)
record.SetRaw("test", "абв") // multi-byte
return record
},
false,
},
{
"mismatched pattern",
&core.TextField{Name: "test", Pattern: `\d+`},
func() *core.Record {
record := core.NewRecord(collection)
record.SetRaw("test", "abc")
return record
},
true,
},
{
"matched pattern",
&core.TextField{Name: "test", Pattern: `\d+`},
func() *core.Record {
record := core.NewRecord(collection)
record.SetRaw("test", "123")
return record
},
false,
},
}
for _, s := range scenarios {
t.Run(s.name, func(t *testing.T) {
err := s.field.ValidateValue(context.Background(), app, s.record())
hasErr := err != nil
if hasErr != s.expectError {
t.Fatalf("Expected hasErr %v, got %v (%v)", s.expectError, hasErr, err)
}
})
}
}
func TestTextFieldValidateSettings(t *testing.T) {
testDefaultFieldIdValidation(t, core.FieldTypeText)
testDefaultFieldNameValidation(t, core.FieldTypeText)
app, _ := tests.NewTestApp()
defer app.Cleanup()
scenarios := []struct {
name string
field func() *core.TextField
expectErrors []string
}{
{
"zero minimal",
func() *core.TextField {
return &core.TextField{
Id: "test",
Name: "test",
}
},
[]string{},
},
{
"primaryKey without required",
func() *core.TextField {
return &core.TextField{
Id: "test",
Name: "id",
PrimaryKey: true,
Pattern: `\d+`,
}
},
[]string{"required"},
},
{
"primaryKey without pattern",
func() *core.TextField {
return &core.TextField{
Id: "test",
Name: "id",
PrimaryKey: true,
Required: true,
}
},
[]string{"pattern"},
},
{
"primaryKey with hidden",
func() *core.TextField {
return &core.TextField{
Id: "test",
Name: "id",
Required: true,
PrimaryKey: true,
Hidden: true,
Pattern: `\d+`,
}
},
[]string{"hidden"},
},
{
"primaryKey with name != id",
func() *core.TextField {
return &core.TextField{
Id: "test",
Name: "test",
PrimaryKey: true,
Required: true,
Pattern: `\d+`,
}
},
[]string{"name"},
},
{
"multiple primaryKey fields",
func() *core.TextField {
return &core.TextField{
Id: "test2",
Name: "id",
PrimaryKey: true,
Pattern: `\d+`,
Required: true,
}
},
[]string{"primaryKey"},
},
{
"invalid pattern",
func() *core.TextField {
return &core.TextField{
Id: "test2",
Name: "id",
Pattern: `(invalid`,
}
},
[]string{"pattern"},
},
{
"valid pattern",
func() *core.TextField {
return &core.TextField{
Id: "test2",
Name: "id",
Pattern: `\d+`,
}
},
[]string{},
},
{
"invalid autogeneratePattern",
func() *core.TextField {
return &core.TextField{
Id: "test2",
Name: "id",
AutogeneratePattern: `(invalid`,
}
},
[]string{"autogeneratePattern"},
},
{
"valid autogeneratePattern",
func() *core.TextField {
return &core.TextField{
Id: "test2",
Name: "id",
AutogeneratePattern: `[a-z]+`,
}
},
[]string{},
},
{
"conflicting pattern and autogeneratePattern",
func() *core.TextField {
return &core.TextField{
Id: "test2",
Name: "id",
Pattern: `\d+`,
AutogeneratePattern: `[a-z]+`,
}
},
[]string{"autogeneratePattern"},
},
{
"Max > safe json int",
func() *core.TextField {
return &core.TextField{
Id: "test",
Name: "test",
Max: 1 << 53,
}
},
[]string{"max"},
},
{
"Max < 0",
func() *core.TextField {
return &core.TextField{
Id: "test",
Name: "test",
Max: -1,
}
},
[]string{"max"},
},
{
"Min > safe json int",
func() *core.TextField {
return &core.TextField{
Id: "test",
Name: "test",
Min: 1 << 53,
}
},
[]string{"min"},
},
{
"Min < 0",
func() *core.TextField {
return &core.TextField{
Id: "test",
Name: "test",
Min: -1,
}
},
[]string{"min"},
},
}
for _, s := range scenarios {
t.Run(s.name, func(t *testing.T) {
field := s.field()
collection := core.NewBaseCollection("test_collection")
collection.Fields.GetByName("id").SetId("test") // set a dummy known id so that it can be replaced
collection.Fields.Add(field)
errs := field.ValidateSettings(context.Background(), app, collection)
tests.TestValidationErrors(t, errs, s.expectErrors)
})
}
}
func TestTextFieldAutogenerate(t *testing.T) {
app, _ := tests.NewTestApp()
defer app.Cleanup()
collection := core.NewBaseCollection("test_collection")
scenarios := []struct {
name string
actionName string
field *core.TextField
record func() *core.Record
expected string
}{
{
"non-matching action",
core.InterceptorActionUpdate,
&core.TextField{Name: "test", AutogeneratePattern: "abc"},
func() *core.Record {
return core.NewRecord(collection)
},
"",
},
{
"matching action (create)",
core.InterceptorActionCreate,
&core.TextField{Name: "test", AutogeneratePattern: "abc"},
func() *core.Record {
return core.NewRecord(collection)
},
"abc",
},
{
"matching action (validate)",
core.InterceptorActionValidate,
&core.TextField{Name: "test", AutogeneratePattern: "abc"},
func() *core.Record {
return core.NewRecord(collection)
},
"abc",
},
{
"existing non-zero value",
core.InterceptorActionCreate,
&core.TextField{Name: "test", AutogeneratePattern: "abc"},
func() *core.Record {
record := core.NewRecord(collection)
record.SetRaw("test", "123")
return record
},
"123",
},
{
"non-new record",
core.InterceptorActionValidate,
&core.TextField{Name: "test", AutogeneratePattern: "abc"},
func() *core.Record {
record := core.NewRecord(collection)
record.Id = "test"
record.PostScan()
return record
},
"",
},
}
for _, s := range scenarios {
t.Run(s.name, func(t *testing.T) {
actionCalls := 0
record := s.record()
err := s.field.Intercept(context.Background(), app, record, s.actionName, func() error {
actionCalls++
return nil
})
if err != nil {
t.Fatal(err)
}
if actionCalls != 1 {
t.Fatalf("Expected actionCalls %d, got %d", 1, actionCalls)
}
v := record.GetString(s.field.GetName())
if v != s.expected {
t.Fatalf("Expected value %q, got %q", s.expected, v)
}
})
}
}
func TestTextFieldFindSetter(t *testing.T) {
scenarios := []struct {
name string
key string
value any
field *core.TextField
hasSetter bool
expected string
}{
{
"no match",
"example",
"abc",
&core.TextField{Name: "test", AutogeneratePattern: "test"},
false,
"",
},
{
"exact match",
"test",
"abc",
&core.TextField{Name: "test", AutogeneratePattern: "test"},
true,
"abc",
},
{
"autogenerate modifier",
"test:autogenerate",
"abc",
&core.TextField{Name: "test", AutogeneratePattern: "test"},
true,
"abctest",
},
{
"autogenerate modifier without AutogeneratePattern option",
"test:autogenerate",
"abc",
&core.TextField{Name: "test"},
true,
"abc",
},
}
for _, s := range scenarios {
t.Run(s.name, func(t *testing.T) {
collection := core.NewBaseCollection("test_collection")
collection.Fields.Add(s.field)
setter := s.field.FindSetter(s.key)
hasSetter := setter != nil
if hasSetter != s.hasSetter {
t.Fatalf("Expected hasSetter %v, got %v", s.hasSetter, hasSetter)
}
if !hasSetter {
return
}
record := core.NewRecord(collection)
setter(record, s.value)
result := record.GetString(s.field.Name)
if result != s.expected {
t.Fatalf("Expected %q, got %q", s.expected, result)
}
})
}
}

168
core/field_url.go Normal file
View file

@ -0,0 +1,168 @@
package core
import (
"context"
"net/url"
"slices"
validation "github.com/go-ozzo/ozzo-validation/v4"
"github.com/go-ozzo/ozzo-validation/v4/is"
"github.com/pocketbase/pocketbase/core/validators"
"github.com/spf13/cast"
)
func init() {
Fields[FieldTypeURL] = func() Field {
return &URLField{}
}
}
const FieldTypeURL = "url"
var _ Field = (*URLField)(nil)
// URLField defines "url" type field for storing a single URL string value.
//
// The respective zero record field value is empty string.
type URLField struct {
// Name (required) is the unique name of the field.
Name string `form:"name" json:"name"`
// Id is the unique stable field identifier.
//
// It is automatically generated from the name when adding to a collection FieldsList.
Id string `form:"id" json:"id"`
// System prevents the renaming and removal of the field.
System bool `form:"system" json:"system"`
// Hidden hides the field from the API response.
Hidden bool `form:"hidden" json:"hidden"`
// Presentable hints the Dashboard UI to use the underlying
// field record value in the relation preview label.
Presentable bool `form:"presentable" json:"presentable"`
// ---
// ExceptDomains will require the URL domain to NOT be included in the listed ones.
//
// This validator can be set only if OnlyDomains is empty.
ExceptDomains []string `form:"exceptDomains" json:"exceptDomains"`
// OnlyDomains will require the URL domain to be included in the listed ones.
//
// This validator can be set only if ExceptDomains is empty.
OnlyDomains []string `form:"onlyDomains" json:"onlyDomains"`
// Required will require the field value to be non-empty URL string.
Required bool `form:"required" json:"required"`
}
// Type implements [Field.Type] interface method.
func (f *URLField) Type() string {
return FieldTypeURL
}
// GetId implements [Field.GetId] interface method.
func (f *URLField) GetId() string {
return f.Id
}
// SetId implements [Field.SetId] interface method.
func (f *URLField) SetId(id string) {
f.Id = id
}
// GetName implements [Field.GetName] interface method.
func (f *URLField) GetName() string {
return f.Name
}
// SetName implements [Field.SetName] interface method.
func (f *URLField) SetName(name string) {
f.Name = name
}
// GetSystem implements [Field.GetSystem] interface method.
func (f *URLField) GetSystem() bool {
return f.System
}
// SetSystem implements [Field.SetSystem] interface method.
func (f *URLField) SetSystem(system bool) {
f.System = system
}
// GetHidden implements [Field.GetHidden] interface method.
func (f *URLField) GetHidden() bool {
return f.Hidden
}
// SetHidden implements [Field.SetHidden] interface method.
func (f *URLField) SetHidden(hidden bool) {
f.Hidden = hidden
}
// ColumnType implements [Field.ColumnType] interface method.
func (f *URLField) ColumnType(app App) string {
return "TEXT DEFAULT '' NOT NULL"
}
// PrepareValue implements [Field.PrepareValue] interface method.
func (f *URLField) PrepareValue(record *Record, raw any) (any, error) {
return cast.ToString(raw), nil
}
// ValidateValue implements [Field.ValidateValue] interface method.
func (f *URLField) ValidateValue(ctx context.Context, app App, record *Record) error {
val, ok := record.GetRaw(f.Name).(string)
if !ok {
return validators.ErrUnsupportedValueType
}
if f.Required {
if err := validation.Required.Validate(val); err != nil {
return err
}
}
if val == "" {
return nil // nothing to check
}
if is.URL.Validate(val) != nil {
return validation.NewError("validation_invalid_url", "Must be a valid url")
}
// extract host/domain
u, _ := url.Parse(val)
// only domains check
if len(f.OnlyDomains) > 0 && !slices.Contains(f.OnlyDomains, u.Host) {
return validation.NewError("validation_url_domain_not_allowed", "Url domain is not allowed")
}
// except domains check
if len(f.ExceptDomains) > 0 && slices.Contains(f.ExceptDomains, u.Host) {
return validation.NewError("validation_url_domain_not_allowed", "Url domain is not allowed")
}
return nil
}
// ValidateSettings implements [Field.ValidateSettings] interface method.
func (f *URLField) ValidateSettings(ctx context.Context, app App, collection *Collection) error {
return validation.ValidateStruct(f,
validation.Field(&f.Id, validation.By(DefaultFieldIdValidationRule)),
validation.Field(&f.Name, validation.By(DefaultFieldNameValidationRule)),
validation.Field(
&f.ExceptDomains,
validation.When(len(f.OnlyDomains) > 0, validation.Empty).Else(validation.Each(is.Domain)),
),
validation.Field(
&f.OnlyDomains,
validation.When(len(f.ExceptDomains) > 0, validation.Empty).Else(validation.Each(is.Domain)),
),
)
}

271
core/field_url_test.go Normal file
View file

@ -0,0 +1,271 @@
package core_test
import (
"context"
"fmt"
"testing"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tests"
)
func TestURLFieldBaseMethods(t *testing.T) {
testFieldBaseMethods(t, core.FieldTypeURL)
}
func TestURLFieldColumnType(t *testing.T) {
app, _ := tests.NewTestApp()
defer app.Cleanup()
f := &core.URLField{}
expected := "TEXT DEFAULT '' NOT NULL"
if v := f.ColumnType(app); v != expected {
t.Fatalf("Expected\n%q\ngot\n%q", expected, v)
}
}
func TestURLFieldPrepareValue(t *testing.T) {
app, _ := tests.NewTestApp()
defer app.Cleanup()
f := &core.URLField{}
record := core.NewRecord(core.NewBaseCollection("test"))
scenarios := []struct {
raw any
expected string
}{
{"", ""},
{"test", "test"},
{false, "false"},
{true, "true"},
{123.456, "123.456"},
}
for i, s := range scenarios {
t.Run(fmt.Sprintf("%d_%#v", i, s.raw), func(t *testing.T) {
v, err := f.PrepareValue(record, s.raw)
if err != nil {
t.Fatal(err)
}
vStr, ok := v.(string)
if !ok {
t.Fatalf("Expected string instance, got %T", v)
}
if vStr != s.expected {
t.Fatalf("Expected %q, got %q", s.expected, v)
}
})
}
}
func TestURLFieldValidateValue(t *testing.T) {
app, _ := tests.NewTestApp()
defer app.Cleanup()
collection := core.NewBaseCollection("test_collection")
scenarios := []struct {
name string
field *core.URLField
record func() *core.Record
expectError bool
}{
{
"invalid raw value",
&core.URLField{Name: "test"},
func() *core.Record {
record := core.NewRecord(collection)
record.SetRaw("test", 123)
return record
},
true,
},
{
"zero field value (not required)",
&core.URLField{Name: "test"},
func() *core.Record {
record := core.NewRecord(collection)
record.SetRaw("test", "")
return record
},
false,
},
{
"zero field value (required)",
&core.URLField{Name: "test", Required: true},
func() *core.Record {
record := core.NewRecord(collection)
record.SetRaw("test", "")
return record
},
true,
},
{
"non-zero field value (required)",
&core.URLField{Name: "test", Required: true},
func() *core.Record {
record := core.NewRecord(collection)
record.SetRaw("test", "https://example.com")
return record
},
false,
},
{
"invalid url",
&core.URLField{Name: "test"},
func() *core.Record {
record := core.NewRecord(collection)
record.SetRaw("test", "invalid")
return record
},
true,
},
{
"failed onlyDomains",
&core.URLField{Name: "test", OnlyDomains: []string{"example.org", "example.net"}},
func() *core.Record {
record := core.NewRecord(collection)
record.SetRaw("test", "https://example.com")
return record
},
true,
},
{
"success onlyDomains",
&core.URLField{Name: "test", OnlyDomains: []string{"example.org", "example.com"}},
func() *core.Record {
record := core.NewRecord(collection)
record.SetRaw("test", "https://example.com")
return record
},
false,
},
{
"failed exceptDomains",
&core.URLField{Name: "test", ExceptDomains: []string{"example.org", "example.com"}},
func() *core.Record {
record := core.NewRecord(collection)
record.SetRaw("test", "https://example.com")
return record
},
true,
},
{
"success exceptDomains",
&core.URLField{Name: "test", ExceptDomains: []string{"example.org", "example.net"}},
func() *core.Record {
record := core.NewRecord(collection)
record.SetRaw("test", "https://example.com")
return record
},
false,
},
}
for _, s := range scenarios {
t.Run(s.name, func(t *testing.T) {
err := s.field.ValidateValue(context.Background(), app, s.record())
hasErr := err != nil
if hasErr != s.expectError {
t.Fatalf("Expected hasErr %v, got %v (%v)", s.expectError, hasErr, err)
}
})
}
}
func TestURLFieldValidateSettings(t *testing.T) {
testDefaultFieldIdValidation(t, core.FieldTypeURL)
testDefaultFieldNameValidation(t, core.FieldTypeURL)
app, _ := tests.NewTestApp()
defer app.Cleanup()
collection := core.NewBaseCollection("test_collection")
scenarios := []struct {
name string
field func() *core.URLField
expectErrors []string
}{
{
"zero minimal",
func() *core.URLField {
return &core.URLField{
Id: "test",
Name: "test",
}
},
[]string{},
},
{
"both onlyDomains and exceptDomains",
func() *core.URLField {
return &core.URLField{
Id: "test",
Name: "test",
OnlyDomains: []string{"example.com"},
ExceptDomains: []string{"example.org"},
}
},
[]string{"onlyDomains", "exceptDomains"},
},
{
"invalid onlyDomains",
func() *core.URLField {
return &core.URLField{
Id: "test",
Name: "test",
OnlyDomains: []string{"example.com", "invalid"},
}
},
[]string{"onlyDomains"},
},
{
"valid onlyDomains",
func() *core.URLField {
return &core.URLField{
Id: "test",
Name: "test",
OnlyDomains: []string{"example.com", "example.org"},
}
},
[]string{},
},
{
"invalid exceptDomains",
func() *core.URLField {
return &core.URLField{
Id: "test",
Name: "test",
ExceptDomains: []string{"example.com", "invalid"},
}
},
[]string{"exceptDomains"},
},
{
"valid exceptDomains",
func() *core.URLField {
return &core.URLField{
Id: "test",
Name: "test",
ExceptDomains: []string{"example.com", "example.org"},
}
},
[]string{},
},
}
for _, s := range scenarios {
t.Run(s.name, func(t *testing.T) {
errs := s.field().ValidateSettings(context.Background(), app, collection)
tests.TestValidationErrors(t, errs, s.expectErrors)
})
}
}

388
core/fields_list.go Normal file
View file

@ -0,0 +1,388 @@
package core
import (
"database/sql/driver"
"encoding/json"
"fmt"
"slices"
"strconv"
)
// NewFieldsList creates a new FieldsList instance with the provided fields.
func NewFieldsList(fields ...Field) FieldsList {
l := make(FieldsList, 0, len(fields))
for _, f := range fields {
l.add(-1, f)
}
return l
}
// FieldsList defines a Collection slice of fields.
type FieldsList []Field
// Clone creates a deep clone of the current list.
func (l FieldsList) Clone() (FieldsList, error) {
copyRaw, err := json.Marshal(l)
if err != nil {
return nil, err
}
result := FieldsList{}
if err := json.Unmarshal(copyRaw, &result); err != nil {
return nil, err
}
return result, nil
}
// FieldNames returns a slice with the name of all list fields.
func (l FieldsList) FieldNames() []string {
result := make([]string, len(l))
for i, field := range l {
result[i] = field.GetName()
}
return result
}
// AsMap returns a map with all registered list field.
// The returned map is indexed with each field name.
func (l FieldsList) AsMap() map[string]Field {
result := make(map[string]Field, len(l))
for _, field := range l {
result[field.GetName()] = field
}
return result
}
// GetById returns a single field by its id.
func (l FieldsList) GetById(fieldId string) Field {
for _, field := range l {
if field.GetId() == fieldId {
return field
}
}
return nil
}
// GetByName returns a single field by its name.
func (l FieldsList) GetByName(fieldName string) Field {
for _, field := range l {
if field.GetName() == fieldName {
return field
}
}
return nil
}
// RemoveById removes a single field by its id.
//
// This method does nothing if field with the specified id doesn't exist.
func (l *FieldsList) RemoveById(fieldId string) {
fields := *l
for i, field := range fields {
if field.GetId() == fieldId {
*l = append(fields[:i], fields[i+1:]...)
return
}
}
}
// RemoveByName removes a single field by its name.
//
// This method does nothing if field with the specified name doesn't exist.
func (l *FieldsList) RemoveByName(fieldName string) {
fields := *l
for i, field := range fields {
if field.GetName() == fieldName {
*l = append(fields[:i], fields[i+1:]...)
return
}
}
}
// Add adds one or more fields to the current list.
//
// By default this method will try to REPLACE existing fields with
// the new ones by their id or by their name if the new field doesn't have an explicit id.
//
// If no matching existing field is found, it will APPEND the field to the end of the list.
//
// In all cases, if any of the new fields don't have an explicit id it will auto generate a default one for them
// (the id value doesn't really matter and it is mostly used as a stable identifier in case of a field rename).
func (l *FieldsList) Add(fields ...Field) {
for _, f := range fields {
l.add(-1, f)
}
}
// AddAt is the same as Add but insert/move the fields at the specific position.
//
// If pos < 0, then this method acts the same as calling Add.
//
// If pos > FieldsList total items, then the specified fields are inserted/moved at the end of the list.
func (l *FieldsList) AddAt(pos int, fields ...Field) {
total := len(*l)
for i, f := range fields {
if pos < 0 {
l.add(-1, f)
} else if pos > total {
l.add(total+i, f)
} else {
l.add(pos+i, f)
}
}
}
// AddMarshaledJSON parses the provided raw json data and adds the
// found fields into the current list (following the same rule as the Add method).
//
// The rawJSON argument could be one of:
// - serialized array of field objects
// - single field object.
//
// Example:
//
// l.AddMarshaledJSON([]byte{`{"type":"text", name: "test"}`})
// l.AddMarshaledJSON([]byte{`[{"type":"text", name: "test1"}, {"type":"text", name: "test2"}]`})
func (l *FieldsList) AddMarshaledJSON(rawJSON []byte) error {
extractedFields, err := marshaledJSONtoFieldsList(rawJSON)
if err != nil {
return err
}
l.Add(extractedFields...)
return nil
}
// AddMarshaledJSONAt is the same as AddMarshaledJSON but insert/move the fields at the specific position.
//
// If pos < 0, then this method acts the same as calling AddMarshaledJSON.
//
// If pos > FieldsList total items, then the specified fields are inserted/moved at the end of the list.
func (l *FieldsList) AddMarshaledJSONAt(pos int, rawJSON []byte) error {
extractedFields, err := marshaledJSONtoFieldsList(rawJSON)
if err != nil {
return err
}
l.AddAt(pos, extractedFields...)
return nil
}
func marshaledJSONtoFieldsList(rawJSON []byte) (FieldsList, error) {
extractedFields := FieldsList{}
// nothing to add
if len(rawJSON) == 0 {
return extractedFields, nil
}
// try to unmarshal first into a new fieds list
// (assuming that rawJSON is array of objects)
err := json.Unmarshal(rawJSON, &extractedFields)
if err != nil {
// try again but wrap the rawJSON in []
// (assuming that rawJSON is a single object)
wrapped := make([]byte, 0, len(rawJSON)+2)
wrapped = append(wrapped, '[')
wrapped = append(wrapped, rawJSON...)
wrapped = append(wrapped, ']')
err = json.Unmarshal(wrapped, &extractedFields)
if err != nil {
return nil, fmt.Errorf("failed to unmarshal the provided JSON - expects array of objects or just single object: %w", err)
}
}
return extractedFields, nil
}
func (l *FieldsList) add(pos int, newField Field) {
fields := *l
var replaceByName bool
var replaceInPlace bool
if pos < 0 {
replaceInPlace = true
pos = len(fields)
} else if pos > len(fields) {
pos = len(fields)
}
newFieldId := newField.GetId()
// set default id
if newFieldId == "" {
replaceByName = true
baseId := newField.Type() + crc32Checksum(newField.GetName())
newFieldId = baseId
for i := 2; i < 1000; i++ {
if l.GetById(newFieldId) == nil {
break // already unique
}
newFieldId = baseId + strconv.Itoa(i)
}
newField.SetId(newFieldId)
}
// try to replace existing
for i, field := range fields {
if replaceByName {
if name := newField.GetName(); name != "" && field.GetName() == name {
// reuse the original id
newField.SetId(field.GetId())
if replaceInPlace {
(*l)[i] = newField
return
} else {
// remove the current field and insert it later at the specific position
*l = slices.Delete(*l, i, i+1)
if total := len(*l); pos > total {
pos = total
}
break
}
}
} else {
if field.GetId() == newFieldId {
if replaceInPlace {
(*l)[i] = newField
return
} else {
// remove the current field and insert it later at the specific position
*l = slices.Delete(*l, i, i+1)
if total := len(*l); pos > total {
pos = total
}
break
}
}
}
}
// insert the new field
*l = slices.Insert(*l, pos, newField)
}
// String returns the string representation of the current list.
func (l FieldsList) String() string {
v, _ := json.Marshal(l)
return string(v)
}
type onlyFieldType struct {
Type string `json:"type"`
}
type fieldWithType struct {
Field
Type string `json:"type"`
}
func (fwt *fieldWithType) UnmarshalJSON(data []byte) error {
// extract the field type to init a blank factory
t := &onlyFieldType{}
if err := json.Unmarshal(data, t); err != nil {
return fmt.Errorf("failed to unmarshal field type: %w", err)
}
factory, ok := Fields[t.Type]
if !ok {
return fmt.Errorf("missing or unknown field type in %s", data)
}
fwt.Type = t.Type
fwt.Field = factory()
// unmarshal the rest of the data into the created field
if err := json.Unmarshal(data, fwt.Field); err != nil {
return fmt.Errorf("failed to unmarshal field: %w", err)
}
return nil
}
// UnmarshalJSON implements [json.Unmarshaler] and
// loads the provided json data into the current FieldsList.
func (l *FieldsList) UnmarshalJSON(data []byte) error {
fwts := []fieldWithType{}
if err := json.Unmarshal(data, &fwts); err != nil {
return err
}
*l = []Field{} // reset
for _, fwt := range fwts {
l.add(-1, fwt.Field)
}
return nil
}
// MarshalJSON implements the [json.Marshaler] interface.
func (l FieldsList) MarshalJSON() ([]byte, error) {
if l == nil {
l = []Field{} // always init to ensure that it is serialized as empty array
}
wrapper := make([]map[string]any, 0, len(l))
for _, f := range l {
// precompute the json into a map so that we can append the type to a flatten object
raw, err := json.Marshal(f)
if err != nil {
return nil, err
}
data := map[string]any{}
if err := json.Unmarshal(raw, &data); err != nil {
return nil, err
}
data["type"] = f.Type()
wrapper = append(wrapper, data)
}
return json.Marshal(wrapper)
}
// Value implements the [driver.Valuer] interface.
func (l FieldsList) Value() (driver.Value, error) {
data, err := json.Marshal(l)
return string(data), err
}
// Scan implements [sql.Scanner] interface to scan the provided value
// into the current FieldsList instance.
func (l *FieldsList) Scan(value any) error {
var data []byte
switch v := value.(type) {
case nil:
// no cast needed
case []byte:
data = v
case string:
data = []byte(v)
default:
return fmt.Errorf("failed to unmarshal FieldsList value %q", value)
}
if len(data) == 0 {
data = []byte("[]")
}
return l.UnmarshalJSON(data)
}

558
core/fields_list_test.go Normal file
View file

@ -0,0 +1,558 @@
package core_test
import (
"bytes"
"encoding/json"
"slices"
"strconv"
"strings"
"testing"
"github.com/pocketbase/pocketbase/core"
)
func TestNewFieldsList(t *testing.T) {
fields := core.NewFieldsList(
&core.TextField{Id: "id1", Name: "test1"},
&core.TextField{Name: "test2"},
&core.TextField{Id: "id1", Name: "test1_new"}, // should replace the original id1 field
)
if len(fields) != 2 {
t.Fatalf("Expected 2 fields, got %d (%v)", len(fields), fields)
}
for _, f := range fields {
if f.GetId() == "" {
t.Fatalf("Expected field id to be set, found empty id for field %v", f)
}
}
if fields[0].GetName() != "test1_new" {
t.Fatalf("Expected field with name test1_new, got %s", fields[0].GetName())
}
if fields[1].GetName() != "test2" {
t.Fatalf("Expected field with name test2, got %s", fields[1].GetName())
}
}
func TestFieldsListClone(t *testing.T) {
f1 := &core.TextField{Name: "test1"}
f2 := &core.EmailField{Name: "test2"}
s1 := core.NewFieldsList(f1, f2)
s2, err := s1.Clone()
if err != nil {
t.Fatal(err)
}
s1Str := s1.String()
s2Str := s2.String()
if s1Str != s2Str {
t.Fatalf("Expected the cloned list to be equal, got \n%v\nVS\n%v", s1, s2)
}
// change in one list shouldn't result to change in the other
// (aka. check if it is a deep clone)
s1[0].SetName("test1_update")
if s2[0].GetName() != "test1" {
t.Fatalf("Expected s2 field name to not change, got %q", s2[0].GetName())
}
}
func TestFieldsListFieldNames(t *testing.T) {
f1 := &core.TextField{Name: "test1"}
f2 := &core.EmailField{Name: "test2"}
testFieldsList := core.NewFieldsList(f1, f2)
result := testFieldsList.FieldNames()
expected := []string{f1.Name, f2.Name}
if len(result) != len(expected) {
t.Fatalf("Expected %d slice elements, got %d\n%v", len(expected), len(result), result)
}
for _, name := range expected {
if !slices.Contains(result, name) {
t.Fatalf("Missing name %q in %v", name, result)
}
}
}
func TestFieldsListAsMap(t *testing.T) {
f1 := &core.TextField{Name: "test1"}
f2 := &core.EmailField{Name: "test2"}
testFieldsList := core.NewFieldsList(f1, f2)
result := testFieldsList.AsMap()
expectedIndexes := []string{f1.Name, f2.Name}
if len(result) != len(expectedIndexes) {
t.Fatalf("Expected %d map elements, got %d\n%v", len(expectedIndexes), len(result), result)
}
for _, index := range expectedIndexes {
if _, ok := result[index]; !ok {
t.Fatalf("Missing index %q", index)
}
}
}
func TestFieldsListGetById(t *testing.T) {
f1 := &core.TextField{Id: "id1", Name: "test1"}
f2 := &core.EmailField{Id: "id2", Name: "test2"}
testFieldsList := core.NewFieldsList(f1, f2)
// missing field id
result1 := testFieldsList.GetById("test1")
if result1 != nil {
t.Fatalf("Found unexpected field %v", result1)
}
// existing field id
result2 := testFieldsList.GetById("id2")
if result2 == nil || result2.GetId() != "id2" {
t.Fatalf("Cannot find field with id %q, got %v ", "id2", result2)
}
}
func TestFieldsListGetByName(t *testing.T) {
f1 := &core.TextField{Id: "id1", Name: "test1"}
f2 := &core.EmailField{Id: "id2", Name: "test2"}
testFieldsList := core.NewFieldsList(f1, f2)
// missing field name
result1 := testFieldsList.GetByName("id1")
if result1 != nil {
t.Fatalf("Found unexpected field %v", result1)
}
// existing field name
result2 := testFieldsList.GetByName("test2")
if result2 == nil || result2.GetName() != "test2" {
t.Fatalf("Cannot find field with name %q, got %v ", "test2", result2)
}
}
func TestFieldsListRemove(t *testing.T) {
testFieldsList := core.NewFieldsList(
&core.TextField{Id: "id1", Name: "test1"},
&core.TextField{Id: "id2", Name: "test2"},
&core.TextField{Id: "id3", Name: "test3"},
&core.TextField{Id: "id4", Name: "test4"},
&core.TextField{Id: "id5", Name: "test5"},
&core.TextField{Id: "id6", Name: "test6"},
)
// remove by id
testFieldsList.RemoveById("id2")
testFieldsList.RemoveById("test3") // should do nothing
// remove by name
testFieldsList.RemoveByName("test5")
testFieldsList.RemoveByName("id6") // should do nothing
expected := []string{"test1", "test3", "test4", "test6"}
if len(testFieldsList) != len(expected) {
t.Fatalf("Expected %d, got %d\n%v", len(expected), len(testFieldsList), testFieldsList)
}
for _, name := range expected {
if f := testFieldsList.GetByName(name); f == nil {
t.Fatalf("Missing field %q", name)
}
}
}
func TestFieldsListAdd(t *testing.T) {
f0 := &core.TextField{}
f1 := &core.TextField{Name: "test1"}
f2 := &core.TextField{Id: "f2Id", Name: "test2"}
f3 := &core.TextField{Id: "f3Id", Name: "test3"}
testFieldsList := core.NewFieldsList(f0, f1, f2, f3)
f2New := &core.EmailField{Id: "f2Id", Name: "test2_new"}
f4 := &core.URLField{Name: "test4"}
testFieldsList.Add(f2New)
testFieldsList.Add(f4)
if len(testFieldsList) != 5 {
t.Fatalf("Expected %d, got %d\n%v", 5, len(testFieldsList), testFieldsList)
}
// check if each field has id
for _, f := range testFieldsList {
if f.GetId() == "" {
t.Fatalf("Expected field id to be set, found empty id for field %v", f)
}
}
// check if f2 field was replaced
if f := testFieldsList.GetById("f2Id"); f == nil || f.Type() != core.FieldTypeEmail {
t.Fatalf("Expected f2 field to be replaced, found %v", f)
}
// check if f4 was added
if f := testFieldsList.GetByName("test4"); f == nil || f.GetName() != "test4" {
t.Fatalf("Expected f4 field to be added, found %v", f)
}
}
func TestFieldsListAddMarshaledJSON(t *testing.T) {
t.Parallel()
scenarios := []struct {
name string
raw []byte
expectError bool
expectedFields map[string]string
}{
{
"nil",
nil,
false,
map[string]string{"abc": core.FieldTypeNumber},
},
{
"empty array",
[]byte(`[]`),
false,
map[string]string{"abc": core.FieldTypeNumber},
},
{
"empty object",
[]byte(`{}`),
true,
map[string]string{"abc": core.FieldTypeNumber},
},
{
"array with empty object",
[]byte(`[{}]`),
true,
map[string]string{"abc": core.FieldTypeNumber},
},
{
"single object with invalid type",
[]byte(`{"type":"missing","name":"test"}`),
true,
map[string]string{"abc": core.FieldTypeNumber},
},
{
"single object with valid type",
[]byte(`{"type":"text","name":"test"}`),
false,
map[string]string{
"abc": core.FieldTypeNumber,
"test": core.FieldTypeText,
},
},
{
"array of object with valid types",
[]byte(`[{"type":"text","name":"test1"},{"type":"url","name":"test2"}]`),
false,
map[string]string{
"abc": core.FieldTypeNumber,
"test1": core.FieldTypeText,
"test2": core.FieldTypeURL,
},
},
{
"fields with duplicated ids should replace existing fields",
[]byte(`[{"type":"text","name":"test1"},{"type":"url","name":"test2"},{"type":"text","name":"abc2", "id":"abc_id"}]`),
false,
map[string]string{
"abc2": core.FieldTypeText,
"test1": core.FieldTypeText,
"test2": core.FieldTypeURL,
},
},
}
for _, s := range scenarios {
t.Run(s.name, func(t *testing.T) {
testList := core.NewFieldsList(&core.NumberField{Name: "abc", Id: "abc_id"})
err := testList.AddMarshaledJSON(s.raw)
hasErr := err != nil
if hasErr != s.expectError {
t.Fatalf("Expected hasErr %v, got %v", s.expectError, hasErr)
}
if len(s.expectedFields) != len(testList) {
t.Fatalf("Expected %d fields, got %d", len(s.expectedFields), len(testList))
}
for fieldName, typ := range s.expectedFields {
f := testList.GetByName(fieldName)
if f == nil {
t.Errorf("Missing expected field %q", fieldName)
continue
}
if f.Type() != typ {
t.Errorf("Expect field %q to has type %q, got %q", fieldName, typ, f.Type())
}
}
})
}
}
func TestFieldsListAddAt(t *testing.T) {
scenarios := []struct {
position int
expected []string
}{
{-2, []string{"test1", "test2_new", "test3", "test4"}},
{-1, []string{"test1", "test2_new", "test3", "test4"}},
{0, []string{"test2_new", "test4", "test1", "test3"}},
{1, []string{"test1", "test2_new", "test4", "test3"}},
{2, []string{"test1", "test3", "test2_new", "test4"}},
{3, []string{"test1", "test3", "test2_new", "test4"}},
{4, []string{"test1", "test3", "test2_new", "test4"}},
{5, []string{"test1", "test3", "test2_new", "test4"}},
}
for _, s := range scenarios {
t.Run(strconv.Itoa(s.position), func(t *testing.T) {
f1 := &core.TextField{Id: "f1Id", Name: "test1"}
f2 := &core.TextField{Id: "f2Id", Name: "test2"}
f3 := &core.TextField{Id: "f3Id", Name: "test3"}
testFieldsList := core.NewFieldsList(f1, f2, f3)
f2New := &core.EmailField{Id: "f2Id", Name: "test2_new"}
f4 := &core.URLField{Name: "test4"}
testFieldsList.AddAt(s.position, f2New, f4)
rawNames, err := json.Marshal(testFieldsList.FieldNames())
if err != nil {
t.Fatal(err)
}
rawExpected, err := json.Marshal(s.expected)
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(rawNames, rawExpected) {
t.Fatalf("Expected fields\n%s\ngot\n%s", rawExpected, rawNames)
}
})
}
}
func TestFieldsListAddMarshaledJSONAt(t *testing.T) {
scenarios := []struct {
position int
expected []string
}{
{-2, []string{"test1", "test2_new", "test3", "test4"}},
{-1, []string{"test1", "test2_new", "test3", "test4"}},
{0, []string{"test2_new", "test4", "test1", "test3"}},
{1, []string{"test1", "test2_new", "test4", "test3"}},
{2, []string{"test1", "test3", "test2_new", "test4"}},
{3, []string{"test1", "test3", "test2_new", "test4"}},
{4, []string{"test1", "test3", "test2_new", "test4"}},
{5, []string{"test1", "test3", "test2_new", "test4"}},
}
for _, s := range scenarios {
t.Run(strconv.Itoa(s.position), func(t *testing.T) {
f1 := &core.TextField{Id: "f1Id", Name: "test1"}
f2 := &core.TextField{Id: "f2Id", Name: "test2"}
f3 := &core.TextField{Id: "f3Id", Name: "test3"}
testFieldsList := core.NewFieldsList(f1, f2, f3)
err := testFieldsList.AddMarshaledJSONAt(s.position, []byte(`[
{"id":"f2Id", "name":"test2_new", "type": "text"},
{"name": "test4", "type": "text"}
]`))
if err != nil {
t.Fatal(err)
}
rawNames, err := json.Marshal(testFieldsList.FieldNames())
if err != nil {
t.Fatal(err)
}
rawExpected, err := json.Marshal(s.expected)
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(rawNames, rawExpected) {
t.Fatalf("Expected fields\n%s\ngot\n%s", rawExpected, rawNames)
}
})
}
}
func TestFieldsListStringAndValue(t *testing.T) {
t.Run("empty list", func(t *testing.T) {
testFieldsList := core.NewFieldsList()
str := testFieldsList.String()
if str != "[]" {
t.Fatalf("Expected empty slice, got\n%q", str)
}
v, err := testFieldsList.Value()
if err != nil {
t.Fatal(err)
}
if v != str {
t.Fatalf("Expected String and Value to match")
}
})
t.Run("list with fields", func(t *testing.T) {
testFieldsList := core.NewFieldsList(
&core.TextField{Id: "f1id", Name: "test1"},
&core.BoolField{Id: "f2id", Name: "test2"},
&core.URLField{Id: "f3id", Name: "test3"},
)
str := testFieldsList.String()
v, err := testFieldsList.Value()
if err != nil {
t.Fatal(err)
}
if v != str {
t.Fatalf("Expected String and Value to match")
}
expectedParts := []string{
`"type":"bool"`,
`"type":"url"`,
`"type":"text"`,
`"id":"f1id"`,
`"id":"f2id"`,
`"id":"f3id"`,
`"name":"test1"`,
`"name":"test2"`,
`"name":"test3"`,
}
for _, part := range expectedParts {
if !strings.Contains(str, part) {
t.Fatalf("Missing %q in\nn%v", part, str)
}
}
})
}
func TestFieldsListScan(t *testing.T) {
scenarios := []struct {
name string
data any
expectError bool
expectJSON string
}{
{"nil", nil, false, "[]"},
{"empty string", "", false, "[]"},
{"empty byte", []byte{}, false, "[]"},
{"empty string array", "[]", false, "[]"},
{"invalid string", "invalid", true, "[]"},
{"non-string", 123, true, "[]"},
{"item with no field type", `[{}]`, true, "[]"},
{
"unknown field type",
`[{"id":"123","name":"test1","type":"unknown"},{"id":"456","name":"test2","type":"bool"}]`,
true,
`[]`,
},
{
"only the minimum field options",
`[{"id":"123","name":"test1","type":"text","required":true},{"id":"456","name":"test2","type":"bool"}]`,
false,
`[{"autogeneratePattern":"","hidden":false,"id":"123","max":0,"min":0,"name":"test1","pattern":"","presentable":false,"primaryKey":false,"required":true,"system":false,"type":"text"},{"hidden":false,"id":"456","name":"test2","presentable":false,"required":false,"system":false,"type":"bool"}]`,
},
{
"all field options",
`[{"autogeneratePattern":"","hidden":true,"id":"123","max":12,"min":0,"name":"test1","pattern":"","presentable":true,"primaryKey":false,"required":true,"system":false,"type":"text"},{"hidden":false,"id":"456","name":"test2","presentable":false,"required":false,"system":true,"type":"bool"}]`,
false,
`[{"autogeneratePattern":"","hidden":true,"id":"123","max":12,"min":0,"name":"test1","pattern":"","presentable":true,"primaryKey":false,"required":true,"system":false,"type":"text"},{"hidden":false,"id":"456","name":"test2","presentable":false,"required":false,"system":true,"type":"bool"}]`,
},
}
for _, s := range scenarios {
t.Run(s.name, func(t *testing.T) {
testFieldsList := core.FieldsList{}
err := testFieldsList.Scan(s.data)
hasErr := err != nil
if hasErr != s.expectError {
t.Fatalf("Expected hasErr %v, got %v (%v)", s.expectError, hasErr, err)
}
str := testFieldsList.String()
if str != s.expectJSON {
t.Fatalf("Expected\n%v\ngot\n%v", s.expectJSON, str)
}
})
}
}
func TestFieldsListJSON(t *testing.T) {
scenarios := []struct {
name string
data string
expectError bool
expectJSON string
}{
{"empty string", "", true, "[]"},
{"invalid string", "invalid", true, "[]"},
{"empty string array", "[]", false, "[]"},
{"item with no field type", `[{}]`, true, "[]"},
{
"unknown field type",
`[{"id":"123","name":"test1","type":"unknown"},{"id":"456","name":"test2","type":"bool"}]`,
true,
`[]`,
},
{
"only the minimum field options",
`[{"id":"123","name":"test1","type":"text","required":true},{"id":"456","name":"test2","type":"bool"}]`,
false,
`[{"autogeneratePattern":"","hidden":false,"id":"123","max":0,"min":0,"name":"test1","pattern":"","presentable":false,"primaryKey":false,"required":true,"system":false,"type":"text"},{"hidden":false,"id":"456","name":"test2","presentable":false,"required":false,"system":false,"type":"bool"}]`,
},
{
"all field options",
`[{"autogeneratePattern":"","hidden":true,"id":"123","max":12,"min":0,"name":"test1","pattern":"","presentable":true,"primaryKey":false,"required":true,"system":false,"type":"text"},{"hidden":false,"id":"456","name":"test2","presentable":false,"required":false,"system":true,"type":"bool"}]`,
false,
`[{"autogeneratePattern":"","hidden":true,"id":"123","max":12,"min":0,"name":"test1","pattern":"","presentable":true,"primaryKey":false,"required":true,"system":false,"type":"text"},{"hidden":false,"id":"456","name":"test2","presentable":false,"required":false,"system":true,"type":"bool"}]`,
},
}
for _, s := range scenarios {
t.Run(s.name, func(t *testing.T) {
testFieldsList := core.FieldsList{}
err := testFieldsList.UnmarshalJSON([]byte(s.data))
hasErr := err != nil
if hasErr != s.expectError {
t.Fatalf("Expected hasErr %v, got %v (%v)", s.expectError, hasErr, err)
}
raw, err := testFieldsList.MarshalJSON()
if err != nil {
t.Fatal(err)
}
str := string(raw)
if str != s.expectJSON {
t.Fatalf("Expected\n%v\ngot\n%v", s.expectJSON, str)
}
})
}
}

22
core/log_model.go Normal file
View file

@ -0,0 +1,22 @@
package core
import "github.com/pocketbase/pocketbase/tools/types"
var (
_ Model = (*Log)(nil)
)
const LogsTableName = "_logs"
type Log struct {
BaseModel
Created types.DateTime `db:"created" json:"created"`
Data types.JSONMap[any] `db:"data" json:"data"`
Message string `db:"message" json:"message"`
Level int `db:"level" json:"level"`
}
func (m *Log) TableName() string {
return LogsTableName
}

67
core/log_printer.go Normal file
View file

@ -0,0 +1,67 @@
package core
import (
"fmt"
"log/slog"
"strings"
"github.com/fatih/color"
"github.com/pocketbase/pocketbase/tools/logger"
"github.com/pocketbase/pocketbase/tools/store"
"github.com/spf13/cast"
)
var cachedColors = store.New[string, *color.Color](nil)
// getColor returns [color.Color] object and cache it (if not already).
func getColor(attrs ...color.Attribute) (c *color.Color) {
cacheKey := fmt.Sprint(attrs)
if c = cachedColors.Get(cacheKey); c == nil {
c = color.New(attrs...)
cachedColors.Set(cacheKey, c)
}
return
}
// printLog prints the provided log to the stderr.
// (note: defined as variable to overwriting in the tests)
var printLog = func(log *logger.Log) {
var str strings.Builder
switch log.Level {
case slog.LevelDebug:
str.WriteString(getColor(color.Bold, color.FgHiBlack).Sprint("DEBUG "))
str.WriteString(getColor(color.FgWhite).Sprint(log.Message))
case slog.LevelInfo:
str.WriteString(getColor(color.Bold, color.FgWhite).Sprint("INFO "))
str.WriteString(getColor(color.FgWhite).Sprint(log.Message))
case slog.LevelWarn:
str.WriteString(getColor(color.Bold, color.FgYellow).Sprint("WARN "))
str.WriteString(getColor(color.FgYellow).Sprint(log.Message))
case slog.LevelError:
str.WriteString(getColor(color.Bold, color.FgRed).Sprint("ERROR "))
str.WriteString(getColor(color.FgRed).Sprint(log.Message))
default:
str.WriteString(getColor(color.Bold, color.FgCyan).Sprintf("[%d] ", log.Level))
str.WriteString(getColor(color.FgCyan).Sprint(log.Message))
}
str.WriteString("\n")
if v, ok := log.Data["type"]; ok && cast.ToString(v) == "request" {
padding := 0
keys := []string{"error", "details"}
for _, k := range keys {
if v := log.Data[k]; v != nil {
str.WriteString(getColor(color.FgHiRed).Sprintf("%s└─ %v", strings.Repeat(" ", padding), v))
str.WriteString("\n")
padding += 3
}
}
} else if len(log.Data) > 0 {
str.WriteString(getColor(color.FgHiBlack).Sprintf("└─ %v", log.Data))
str.WriteString("\n")
}
fmt.Print(str.String())
}

124
core/log_printer_test.go Normal file
View file

@ -0,0 +1,124 @@
package core
import (
"context"
"database/sql"
"log/slog"
"os"
"testing"
"time"
"github.com/pocketbase/dbx"
"github.com/pocketbase/pocketbase/tools/list"
"github.com/pocketbase/pocketbase/tools/logger"
)
func TestBaseAppLoggerLevelDevPrint(t *testing.T) {
testLogLevel := 4
scenarios := []struct {
name string
isDev bool
levels []int
printedLevels []int
persistedLevels []int
}{
{
"dev mode",
true,
[]int{testLogLevel - 1, testLogLevel, testLogLevel + 1},
[]int{testLogLevel - 1, testLogLevel, testLogLevel + 1},
[]int{testLogLevel, testLogLevel + 1},
},
{
"nondev mode",
false,
[]int{testLogLevel - 1, testLogLevel, testLogLevel + 1},
[]int{},
[]int{testLogLevel, testLogLevel + 1},
},
}
for _, s := range scenarios {
t.Run(s.name, func(t *testing.T) {
const testDataDir = "./pb_base_app_test_data_dir/"
defer os.RemoveAll(testDataDir)
app := NewBaseApp(BaseAppConfig{
DataDir: testDataDir,
IsDev: s.isDev,
})
defer app.ResetBootstrapState()
if err := app.Bootstrap(); err != nil {
t.Fatal(err)
}
// silence query logs
app.concurrentDB.(*dbx.DB).ExecLogFunc = func(ctx context.Context, t time.Duration, sql string, result sql.Result, err error) {}
app.concurrentDB.(*dbx.DB).QueryLogFunc = func(ctx context.Context, t time.Duration, sql string, rows *sql.Rows, err error) {}
app.nonconcurrentDB.(*dbx.DB).ExecLogFunc = func(ctx context.Context, t time.Duration, sql string, result sql.Result, err error) {}
app.nonconcurrentDB.(*dbx.DB).QueryLogFunc = func(ctx context.Context, t time.Duration, sql string, rows *sql.Rows, err error) {}
app.Settings().Logs.MinLevel = testLogLevel
if err := app.Save(app.Settings()); err != nil {
t.Fatal(err)
}
var printedLevels []int
var persistedLevels []int
ctx := context.Background()
// track printed logs
originalPrintLog := printLog
defer func() {
printLog = originalPrintLog
}()
printLog = func(log *logger.Log) {
printedLevels = append(printedLevels, int(log.Level))
}
// track persisted logs
app.OnModelAfterCreateSuccess("_logs").BindFunc(func(e *ModelEvent) error {
l, ok := e.Model.(*Log)
if ok {
persistedLevels = append(persistedLevels, l.Level)
}
return e.Next()
})
// write and persist logs
for _, l := range s.levels {
app.Logger().Log(ctx, slog.Level(l), "test")
}
handler, ok := app.Logger().Handler().(*logger.BatchHandler)
if !ok {
t.Fatalf("Expected BatchHandler, got %v", app.Logger().Handler())
}
if err := handler.WriteAll(ctx); err != nil {
t.Fatalf("Failed to write all logs: %v", err)
}
// check persisted log levels
if len(s.persistedLevels) != len(persistedLevels) {
t.Fatalf("Expected persisted levels \n%v\ngot\n%v", s.persistedLevels, persistedLevels)
}
for _, l := range persistedLevels {
if !list.ExistInSlice(l, s.persistedLevels) {
t.Fatalf("Missing expected persisted level %v in %v", l, persistedLevels)
}
}
// check printed log levels
if len(s.printedLevels) != len(printedLevels) {
t.Fatalf("Expected printed levels \n%v\ngot\n%v", s.printedLevels, printedLevels)
}
for _, l := range printedLevels {
if !list.ExistInSlice(l, s.printedLevels) {
t.Fatalf("Missing expected printed level %v in %v", l, printedLevels)
}
}
})
}
}

65
core/log_query.go Normal file
View file

@ -0,0 +1,65 @@
package core
import (
"time"
"github.com/pocketbase/dbx"
"github.com/pocketbase/pocketbase/tools/types"
)
// LogQuery returns a new Log select query.
func (app *BaseApp) LogQuery() *dbx.SelectQuery {
return app.AuxModelQuery(&Log{})
}
// FindLogById finds a single Log entry by its id.
func (app *BaseApp) FindLogById(id string) (*Log, error) {
model := &Log{}
err := app.LogQuery().
AndWhere(dbx.HashExp{"id": id}).
Limit(1).
One(model)
if err != nil {
return nil, err
}
return model, nil
}
// LogsStatsItem defines the total number of logs for a specific time period.
type LogsStatsItem struct {
Date types.DateTime `db:"date" json:"date"`
Total int `db:"total" json:"total"`
}
// LogsStats returns hourly grouped logs statistics.
func (app *BaseApp) LogsStats(expr dbx.Expression) ([]*LogsStatsItem, error) {
result := []*LogsStatsItem{}
query := app.LogQuery().
Select("count(id) as total", "strftime('%Y-%m-%d %H:00:00', created) as date").
GroupBy("date")
if expr != nil {
query.AndWhere(expr)
}
err := query.All(&result)
return result, err
}
// DeleteOldLogs delete all logs that are created before createdBefore.
//
// For better performance the logs delete is executed as plain SQL statement,
// aka. no delete model hook events will be fired.
func (app *BaseApp) DeleteOldLogs(createdBefore time.Time) error {
formattedDate := createdBefore.UTC().Format(types.DefaultDateLayout)
expr := dbx.NewExp("[[created]] <= {:date}", dbx.Params{"date": formattedDate})
_, err := app.auxNonconcurrentDB.Delete((&Log{}).TableName(), expr).Execute()
return err
}

114
core/log_query_test.go Normal file
View file

@ -0,0 +1,114 @@
package core_test
import (
"encoding/json"
"fmt"
"testing"
"time"
"github.com/pocketbase/dbx"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tests"
"github.com/pocketbase/pocketbase/tools/types"
)
func TestFindLogById(t *testing.T) {
t.Parallel()
app, _ := tests.NewTestApp()
defer app.Cleanup()
tests.StubLogsData(app)
scenarios := []struct {
id string
expectError bool
}{
{"", true},
{"invalid", true},
{"00000000-9f38-44fb-bf82-c8f53b310d91", true},
{"873f2133-9f38-44fb-bf82-c8f53b310d91", false},
}
for i, s := range scenarios {
t.Run(fmt.Sprintf("%d_%s", i, s.id), func(t *testing.T) {
log, err := app.FindLogById(s.id)
hasErr := err != nil
if hasErr != s.expectError {
t.Fatalf("Expected hasErr to be %v, got %v (%v)", s.expectError, hasErr, err)
}
if log != nil && log.Id != s.id {
t.Fatalf("Expected log with id %q, got %q", s.id, log.Id)
}
})
}
}
func TestLogsStats(t *testing.T) {
t.Parallel()
app, _ := tests.NewTestApp()
defer app.Cleanup()
tests.StubLogsData(app)
expected := `[{"date":"2022-05-01 10:00:00.000Z","total":1},{"date":"2022-05-02 10:00:00.000Z","total":1}]`
now := time.Now().UTC().Format(types.DefaultDateLayout)
exp := dbx.NewExp("[[created]] <= {:date}", dbx.Params{"date": now})
result, err := app.LogsStats(exp)
if err != nil {
t.Fatal(err)
}
encoded, _ := json.Marshal(result)
if string(encoded) != expected {
t.Fatalf("Expected\n%q\ngot\n%q", expected, string(encoded))
}
}
func TestDeleteOldLogs(t *testing.T) {
t.Parallel()
app, _ := tests.NewTestApp()
defer app.Cleanup()
tests.StubLogsData(app)
scenarios := []struct {
date string
expectedTotal int
}{
{"2022-01-01 10:00:00.000Z", 2}, // no logs to delete before that time
{"2022-05-01 11:00:00.000Z", 1}, // only 1 log should have left
{"2022-05-03 11:00:00.000Z", 0}, // no more logs should have left
{"2022-05-04 11:00:00.000Z", 0}, // no more logs should have left
}
for _, s := range scenarios {
t.Run(s.date, func(t *testing.T) {
date, dateErr := time.Parse(types.DefaultDateLayout, s.date)
if dateErr != nil {
t.Fatalf("Date error %v", dateErr)
}
deleteErr := app.DeleteOldLogs(date)
if deleteErr != nil {
t.Fatalf("Delete error %v", deleteErr)
}
// check total remaining logs
var total int
countErr := app.AuxModelQuery(&core.Log{}).Select("count(*)").Row(&total)
if countErr != nil {
t.Errorf("Count error %v", countErr)
}
if total != s.expectedTotal {
t.Errorf("Expected %d remaining logs, got %d", s.expectedTotal, total)
}
})
}
}

157
core/mfa_model.go Normal file
View file

@ -0,0 +1,157 @@
package core
import (
"context"
"errors"
"time"
"github.com/pocketbase/pocketbase/tools/hook"
"github.com/pocketbase/pocketbase/tools/types"
)
const (
MFAMethodPassword = "password"
MFAMethodOAuth2 = "oauth2"
MFAMethodOTP = "otp"
)
const CollectionNameMFAs = "_mfas"
var (
_ Model = (*MFA)(nil)
_ PreValidator = (*MFA)(nil)
_ RecordProxy = (*MFA)(nil)
)
// MFA defines a Record proxy for working with the mfas collection.
type MFA struct {
*Record
}
// NewMFA instantiates and returns a new blank *MFA model.
//
// Example usage:
//
// mfa := core.NewMFA(app)
// mfa.SetRecordRef(user.Id)
// mfa.SetCollectionRef(user.Collection().Id)
// mfa.SetMethod(core.MFAMethodPassword)
// app.Save(mfa)
func NewMFA(app App) *MFA {
m := &MFA{}
c, err := app.FindCachedCollectionByNameOrId(CollectionNameMFAs)
if err != nil {
// this is just to make tests easier since mfa is a system collection and it is expected to be always accessible
// (note: the loaded record is further checked on MFA.PreValidate())
c = NewBaseCollection("@__invalid__")
}
m.Record = NewRecord(c)
return m
}
// PreValidate implements the [PreValidator] interface and checks
// whether the proxy is properly loaded.
func (m *MFA) PreValidate(ctx context.Context, app App) error {
if m.Record == nil || m.Record.Collection().Name != CollectionNameMFAs {
return errors.New("missing or invalid mfa ProxyRecord")
}
return nil
}
// ProxyRecord returns the proxied Record model.
func (m *MFA) ProxyRecord() *Record {
return m.Record
}
// SetProxyRecord loads the specified record model into the current proxy.
func (m *MFA) SetProxyRecord(record *Record) {
m.Record = record
}
// CollectionRef returns the "collectionRef" field value.
func (m *MFA) CollectionRef() string {
return m.GetString("collectionRef")
}
// SetCollectionRef updates the "collectionRef" record field value.
func (m *MFA) SetCollectionRef(collectionId string) {
m.Set("collectionRef", collectionId)
}
// RecordRef returns the "recordRef" record field value.
func (m *MFA) RecordRef() string {
return m.GetString("recordRef")
}
// SetRecordRef updates the "recordRef" record field value.
func (m *MFA) SetRecordRef(recordId string) {
m.Set("recordRef", recordId)
}
// Method returns the "method" record field value.
func (m *MFA) Method() string {
return m.GetString("method")
}
// SetMethod updates the "method" record field value.
func (m *MFA) SetMethod(method string) {
m.Set("method", method)
}
// Created returns the "created" record field value.
func (m *MFA) Created() types.DateTime {
return m.GetDateTime("created")
}
// Updated returns the "updated" record field value.
func (m *MFA) Updated() types.DateTime {
return m.GetDateTime("updated")
}
// HasExpired checks if the mfa is expired, aka. whether it has been
// more than maxElapsed time since its creation.
func (m *MFA) HasExpired(maxElapsed time.Duration) bool {
return time.Since(m.Created().Time()) > maxElapsed
}
func (app *BaseApp) registerMFAHooks() {
recordRefHooks[*MFA](app, CollectionNameMFAs, CollectionTypeAuth)
// run on every hour to cleanup expired mfa sessions
app.Cron().Add("__pbMFACleanup__", "0 * * * *", func() {
if err := app.DeleteExpiredMFAs(); err != nil {
app.Logger().Warn("Failed to delete expired MFA sessions", "error", err)
}
})
// delete existing mfas on password change
app.OnRecordUpdate().Bind(&hook.Handler[*RecordEvent]{
Func: func(e *RecordEvent) error {
err := e.Next()
if err != nil || !e.Record.Collection().IsAuth() {
return err
}
old := e.Record.Original().GetString(FieldNamePassword + ":hash")
new := e.Record.GetString(FieldNamePassword + ":hash")
if old != new {
err = e.App.DeleteAllMFAsByRecord(e.Record)
if err != nil {
e.App.Logger().Warn(
"Failed to delete all previous mfas",
"error", err,
"recordId", e.Record.Id,
"collectionId", e.Record.Collection().Id,
)
}
}
return nil
},
Priority: 99,
})
}

302
core/mfa_model_test.go Normal file
View file

@ -0,0 +1,302 @@
package core_test
import (
"fmt"
"testing"
"time"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tests"
"github.com/pocketbase/pocketbase/tools/types"
)
func TestNewMFA(t *testing.T) {
t.Parallel()
app, _ := tests.NewTestApp()
defer app.Cleanup()
mfa := core.NewMFA(app)
if mfa.Collection().Name != core.CollectionNameMFAs {
t.Fatalf("Expected record with %q collection, got %q", core.CollectionNameMFAs, mfa.Collection().Name)
}
}
func TestMFAProxyRecord(t *testing.T) {
t.Parallel()
record := core.NewRecord(core.NewBaseCollection("test"))
record.Id = "test_id"
mfa := core.MFA{}
mfa.SetProxyRecord(record)
if mfa.ProxyRecord() == nil || mfa.ProxyRecord().Id != record.Id {
t.Fatalf("Expected proxy record with id %q, got %v", record.Id, mfa.ProxyRecord())
}
}
func TestMFARecordRef(t *testing.T) {
t.Parallel()
app, _ := tests.NewTestApp()
defer app.Cleanup()
mfa := core.NewMFA(app)
testValues := []string{"test_1", "test2", ""}
for i, testValue := range testValues {
t.Run(fmt.Sprintf("%d_%q", i, testValue), func(t *testing.T) {
mfa.SetRecordRef(testValue)
if v := mfa.RecordRef(); v != testValue {
t.Fatalf("Expected getter %q, got %q", testValue, v)
}
if v := mfa.GetString("recordRef"); v != testValue {
t.Fatalf("Expected field value %q, got %q", testValue, v)
}
})
}
}
func TestMFACollectionRef(t *testing.T) {
t.Parallel()
app, _ := tests.NewTestApp()
defer app.Cleanup()
mfa := core.NewMFA(app)
testValues := []string{"test_1", "test2", ""}
for i, testValue := range testValues {
t.Run(fmt.Sprintf("%d_%q", i, testValue), func(t *testing.T) {
mfa.SetCollectionRef(testValue)
if v := mfa.CollectionRef(); v != testValue {
t.Fatalf("Expected getter %q, got %q", testValue, v)
}
if v := mfa.GetString("collectionRef"); v != testValue {
t.Fatalf("Expected field value %q, got %q", testValue, v)
}
})
}
}
func TestMFAMethod(t *testing.T) {
t.Parallel()
app, _ := tests.NewTestApp()
defer app.Cleanup()
mfa := core.NewMFA(app)
testValues := []string{"test_1", "test2", ""}
for i, testValue := range testValues {
t.Run(fmt.Sprintf("%d_%q", i, testValue), func(t *testing.T) {
mfa.SetMethod(testValue)
if v := mfa.Method(); v != testValue {
t.Fatalf("Expected getter %q, got %q", testValue, v)
}
if v := mfa.GetString("method"); v != testValue {
t.Fatalf("Expected field value %q, got %q", testValue, v)
}
})
}
}
func TestMFACreated(t *testing.T) {
t.Parallel()
app, _ := tests.NewTestApp()
defer app.Cleanup()
mfa := core.NewMFA(app)
if v := mfa.Created().String(); v != "" {
t.Fatalf("Expected empty created, got %q", v)
}
now := types.NowDateTime()
mfa.SetRaw("created", now)
if v := mfa.Created().String(); v != now.String() {
t.Fatalf("Expected %q created, got %q", now.String(), v)
}
}
func TestMFAUpdated(t *testing.T) {
t.Parallel()
app, _ := tests.NewTestApp()
defer app.Cleanup()
mfa := core.NewMFA(app)
if v := mfa.Updated().String(); v != "" {
t.Fatalf("Expected empty updated, got %q", v)
}
now := types.NowDateTime()
mfa.SetRaw("updated", now)
if v := mfa.Updated().String(); v != now.String() {
t.Fatalf("Expected %q updated, got %q", now.String(), v)
}
}
func TestMFAHasExpired(t *testing.T) {
t.Parallel()
app, _ := tests.NewTestApp()
defer app.Cleanup()
now := types.NowDateTime()
mfa := core.NewMFA(app)
mfa.SetRaw("created", now.Add(-5*time.Minute))
scenarios := []struct {
maxElapsed time.Duration
expected bool
}{
{0 * time.Minute, true},
{3 * time.Minute, true},
{5 * time.Minute, true},
{6 * time.Minute, false},
}
for i, s := range scenarios {
t.Run(fmt.Sprintf("%d_%s", i, s.maxElapsed.String()), func(t *testing.T) {
result := mfa.HasExpired(s.maxElapsed)
if result != s.expected {
t.Fatalf("Expected %v, got %v", s.expected, result)
}
})
}
}
func TestMFAPreValidate(t *testing.T) {
t.Parallel()
app, _ := tests.NewTestApp()
defer app.Cleanup()
mfasCol, err := app.FindCollectionByNameOrId(core.CollectionNameMFAs)
if err != nil {
t.Fatal(err)
}
user, err := app.FindAuthRecordByEmail("users", "test@example.com")
if err != nil {
t.Fatal(err)
}
t.Run("no proxy record", func(t *testing.T) {
mfa := &core.MFA{}
if err := app.Validate(mfa); err == nil {
t.Fatal("Expected collection validation error")
}
})
t.Run("non-MFA collection", func(t *testing.T) {
mfa := &core.MFA{}
mfa.SetProxyRecord(core.NewRecord(core.NewBaseCollection("invalid")))
mfa.SetRecordRef(user.Id)
mfa.SetCollectionRef(user.Collection().Id)
mfa.SetMethod("test123")
if err := app.Validate(mfa); err == nil {
t.Fatal("Expected collection validation error")
}
})
t.Run("MFA collection", func(t *testing.T) {
mfa := &core.MFA{}
mfa.SetProxyRecord(core.NewRecord(mfasCol))
mfa.SetRecordRef(user.Id)
mfa.SetCollectionRef(user.Collection().Id)
mfa.SetMethod("test123")
if err := app.Validate(mfa); err != nil {
t.Fatalf("Expected nil validation error, got %v", err)
}
})
}
func TestMFAValidateHook(t *testing.T) {
t.Parallel()
app, _ := tests.NewTestApp()
defer app.Cleanup()
user, err := app.FindAuthRecordByEmail("users", "test@example.com")
if err != nil {
t.Fatal(err)
}
demo1, err := app.FindRecordById("demo1", "84nmscqy84lsi1t")
if err != nil {
t.Fatal(err)
}
scenarios := []struct {
name string
mfa func() *core.MFA
expectErrors []string
}{
{
"empty",
func() *core.MFA {
return core.NewMFA(app)
},
[]string{"collectionRef", "recordRef", "method"},
},
{
"non-auth collection",
func() *core.MFA {
mfa := core.NewMFA(app)
mfa.SetCollectionRef(demo1.Collection().Id)
mfa.SetRecordRef(demo1.Id)
mfa.SetMethod("test123")
return mfa
},
[]string{"collectionRef"},
},
{
"missing record id",
func() *core.MFA {
mfa := core.NewMFA(app)
mfa.SetCollectionRef(user.Collection().Id)
mfa.SetRecordRef("missing")
mfa.SetMethod("test123")
return mfa
},
[]string{"recordRef"},
},
{
"valid ref",
func() *core.MFA {
mfa := core.NewMFA(app)
mfa.SetCollectionRef(user.Collection().Id)
mfa.SetRecordRef(user.Id)
mfa.SetMethod("test123")
return mfa
},
[]string{},
},
}
for _, s := range scenarios {
t.Run(s.name, func(t *testing.T) {
errs := app.Validate(s.mfa())
tests.TestValidationErrors(t, errs, s.expectErrors)
})
}
}

117
core/mfa_query.go Normal file
View file

@ -0,0 +1,117 @@
package core
import (
"errors"
"time"
"github.com/pocketbase/dbx"
"github.com/pocketbase/pocketbase/tools/types"
)
// FindAllMFAsByRecord returns all MFA models linked to the provided auth record.
func (app *BaseApp) FindAllMFAsByRecord(authRecord *Record) ([]*MFA, error) {
result := []*MFA{}
err := app.RecordQuery(CollectionNameMFAs).
AndWhere(dbx.HashExp{
"collectionRef": authRecord.Collection().Id,
"recordRef": authRecord.Id,
}).
OrderBy("created DESC").
All(&result)
if err != nil {
return nil, err
}
return result, nil
}
// FindAllMFAsByCollection returns all MFA models linked to the provided collection.
func (app *BaseApp) FindAllMFAsByCollection(collection *Collection) ([]*MFA, error) {
result := []*MFA{}
err := app.RecordQuery(CollectionNameMFAs).
AndWhere(dbx.HashExp{"collectionRef": collection.Id}).
OrderBy("created DESC").
All(&result)
if err != nil {
return nil, err
}
return result, nil
}
// FindMFAById returns a single MFA model by its id.
func (app *BaseApp) FindMFAById(id string) (*MFA, error) {
result := &MFA{}
err := app.RecordQuery(CollectionNameMFAs).
AndWhere(dbx.HashExp{"id": id}).
Limit(1).
One(result)
if err != nil {
return nil, err
}
return result, nil
}
// DeleteAllMFAsByRecord deletes all MFA models associated with the provided record.
//
// Returns a combined error with the failed deletes.
func (app *BaseApp) DeleteAllMFAsByRecord(authRecord *Record) error {
models, err := app.FindAllMFAsByRecord(authRecord)
if err != nil {
return err
}
var errs []error
for _, m := range models {
if err := app.Delete(m); err != nil {
errs = append(errs, err)
}
}
if len(errs) > 0 {
return errors.Join(errs...)
}
return nil
}
// DeleteExpiredMFAs deletes the expired MFAs for all auth collections.
func (app *BaseApp) DeleteExpiredMFAs() error {
authCollections, err := app.FindAllCollections(CollectionTypeAuth)
if err != nil {
return err
}
// note: perform even if MFA is disabled to ensure that there are no dangling old records
for _, collection := range authCollections {
minValidDate, err := types.ParseDateTime(time.Now().Add(-1 * collection.MFA.DurationTime()))
if err != nil {
return err
}
items := []*Record{}
err = app.RecordQuery(CollectionNameMFAs).
AndWhere(dbx.HashExp{"collectionRef": collection.Id}).
AndWhere(dbx.NewExp("[[created]] < {:date}", dbx.Params{"date": minValidDate})).
All(&items)
if err != nil {
return err
}
for _, item := range items {
err = app.Delete(item)
if err != nil {
return err
}
}
}
return nil
}

311
core/mfa_query_test.go Normal file
View file

@ -0,0 +1,311 @@
package core_test
import (
"fmt"
"slices"
"testing"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tests"
)
func TestFindAllMFAsByRecord(t *testing.T) {
t.Parallel()
app, _ := tests.NewTestApp()
defer app.Cleanup()
if err := tests.StubMFARecords(app); err != nil {
t.Fatal(err)
}
demo1, err := app.FindRecordById("demo1", "84nmscqy84lsi1t")
if err != nil {
t.Fatal(err)
}
superuser2, err := app.FindAuthRecordByEmail(core.CollectionNameSuperusers, "test2@example.com")
if err != nil {
t.Fatal(err)
}
superuser4, err := app.FindAuthRecordByEmail(core.CollectionNameSuperusers, "test4@example.com")
if err != nil {
t.Fatal(err)
}
user1, err := app.FindAuthRecordByEmail("users", "test@example.com")
if err != nil {
t.Fatal(err)
}
scenarios := []struct {
record *core.Record
expected []string
}{
{demo1, nil},
{superuser2, []string{"superuser2_0", "superuser2_3", "superuser2_2", "superuser2_1", "superuser2_4"}},
{superuser4, nil},
{user1, []string{"user1_0"}},
}
for _, s := range scenarios {
t.Run(s.record.Collection().Name+"_"+s.record.Id, func(t *testing.T) {
result, err := app.FindAllMFAsByRecord(s.record)
if err != nil {
t.Fatal(err)
}
if len(result) != len(s.expected) {
t.Fatalf("Expected total mfas %d, got %d", len(s.expected), len(result))
}
for i, id := range s.expected {
if result[i].Id != id {
t.Errorf("[%d] Expected id %q, got %q", i, id, result[i].Id)
}
}
})
}
}
func TestFindAllMFAsByCollection(t *testing.T) {
t.Parallel()
app, _ := tests.NewTestApp()
defer app.Cleanup()
if err := tests.StubMFARecords(app); err != nil {
t.Fatal(err)
}
demo1, err := app.FindCollectionByNameOrId("demo1")
if err != nil {
t.Fatal(err)
}
superusers, err := app.FindCollectionByNameOrId(core.CollectionNameSuperusers)
if err != nil {
t.Fatal(err)
}
clients, err := app.FindCollectionByNameOrId("clients")
if err != nil {
t.Fatal(err)
}
users, err := app.FindCollectionByNameOrId("users")
if err != nil {
t.Fatal(err)
}
scenarios := []struct {
collection *core.Collection
expected []string
}{
{demo1, nil},
{superusers, []string{
"superuser2_0",
"superuser2_3",
"superuser3_0",
"superuser2_2",
"superuser3_1",
"superuser2_1",
"superuser2_4",
}},
{clients, nil},
{users, []string{"user1_0"}},
}
for _, s := range scenarios {
t.Run(s.collection.Name, func(t *testing.T) {
result, err := app.FindAllMFAsByCollection(s.collection)
if err != nil {
t.Fatal(err)
}
if len(result) != len(s.expected) {
t.Fatalf("Expected total mfas %d, got %d", len(s.expected), len(result))
}
for i, id := range s.expected {
if result[i].Id != id {
t.Errorf("[%d] Expected id %q, got %q", i, id, result[i].Id)
}
}
})
}
}
func TestFindMFAById(t *testing.T) {
t.Parallel()
app, _ := tests.NewTestApp()
defer app.Cleanup()
if err := tests.StubMFARecords(app); err != nil {
t.Fatal(err)
}
scenarios := []struct {
id string
expectError bool
}{
{"", true},
{"84nmscqy84lsi1t", true}, // non-mfa id
{"superuser2_0", false},
{"superuser2_4", false}, // expired
{"user1_0", false},
}
for _, s := range scenarios {
t.Run(s.id, func(t *testing.T) {
result, err := app.FindMFAById(s.id)
hasErr := err != nil
if hasErr != s.expectError {
t.Fatalf("Expected hasErr %v, got %v (%v)", s.expectError, hasErr, err)
}
if hasErr {
return
}
if result.Id != s.id {
t.Fatalf("Expected record with id %q, got %q", s.id, result.Id)
}
})
}
}
func TestDeleteAllMFAsByRecord(t *testing.T) {
t.Parallel()
testApp, _ := tests.NewTestApp()
defer testApp.Cleanup()
demo1, err := testApp.FindRecordById("demo1", "84nmscqy84lsi1t")
if err != nil {
t.Fatal(err)
}
superuser2, err := testApp.FindAuthRecordByEmail(core.CollectionNameSuperusers, "test2@example.com")
if err != nil {
t.Fatal(err)
}
superuser4, err := testApp.FindAuthRecordByEmail(core.CollectionNameSuperusers, "test4@example.com")
if err != nil {
t.Fatal(err)
}
user1, err := testApp.FindAuthRecordByEmail("users", "test@example.com")
if err != nil {
t.Fatal(err)
}
scenarios := []struct {
record *core.Record
deletedIds []string
}{
{demo1, nil}, // non-auth record
{superuser2, []string{"superuser2_0", "superuser2_1", "superuser2_3", "superuser2_2", "superuser2_4"}},
{superuser4, nil},
{user1, []string{"user1_0"}},
}
for i, s := range scenarios {
t.Run(fmt.Sprintf("%d_%s_%s", i, s.record.Collection().Name, s.record.Id), func(t *testing.T) {
app, _ := tests.NewTestApp()
defer app.Cleanup()
if err := tests.StubMFARecords(app); err != nil {
t.Fatal(err)
}
deletedIds := []string{}
app.OnRecordAfterDeleteSuccess().BindFunc(func(e *core.RecordEvent) error {
deletedIds = append(deletedIds, e.Record.Id)
return e.Next()
})
err := app.DeleteAllMFAsByRecord(s.record)
if err != nil {
t.Fatal(err)
}
if len(deletedIds) != len(s.deletedIds) {
t.Fatalf("Expected deleted ids\n%v\ngot\n%v", s.deletedIds, deletedIds)
}
for _, id := range s.deletedIds {
if !slices.Contains(deletedIds, id) {
t.Errorf("Expected to find deleted id %q in %v", id, deletedIds)
}
}
})
}
}
func TestDeleteExpiredMFAs(t *testing.T) {
t.Parallel()
checkDeletedIds := func(app core.App, t *testing.T, expectedDeletedIds []string) {
if err := tests.StubMFARecords(app); err != nil {
t.Fatal(err)
}
deletedIds := []string{}
app.OnRecordDelete().BindFunc(func(e *core.RecordEvent) error {
deletedIds = append(deletedIds, e.Record.Id)
return e.Next()
})
if err := app.DeleteExpiredMFAs(); err != nil {
t.Fatal(err)
}
if len(deletedIds) != len(expectedDeletedIds) {
t.Fatalf("Expected deleted ids\n%v\ngot\n%v", expectedDeletedIds, deletedIds)
}
for _, id := range expectedDeletedIds {
if !slices.Contains(deletedIds, id) {
t.Errorf("Expected to find deleted id %q in %v", id, deletedIds)
}
}
}
t.Run("default test collections", func(t *testing.T) {
app, _ := tests.NewTestApp()
defer app.Cleanup()
checkDeletedIds(app, t, []string{
"user1_0",
"superuser2_1",
"superuser2_4",
})
})
t.Run("mfa collection duration mock", func(t *testing.T) {
app, _ := tests.NewTestApp()
defer app.Cleanup()
superusers, err := app.FindCollectionByNameOrId(core.CollectionNameSuperusers)
if err != nil {
t.Fatal(err)
}
superusers.MFA.Duration = 60
if err := app.Save(superusers); err != nil {
t.Fatalf("Failed to mock superusers mfa duration: %v", err)
}
checkDeletedIds(app, t, []string{
"user1_0",
"superuser2_1",
"superuser2_2",
"superuser2_4",
"superuser3_1",
})
})
}

83
core/migrations_list.go Normal file
View file

@ -0,0 +1,83 @@
package core
import (
"path/filepath"
"runtime"
"sort"
)
type Migration struct {
Up func(txApp App) error
Down func(txApp App) error
File string
ReapplyCondition func(txApp App, runner *MigrationsRunner, fileName string) (bool, error)
}
// MigrationsList defines a list with migration definitions
type MigrationsList struct {
list []*Migration
}
// Item returns a single migration from the list by its index.
func (l *MigrationsList) Item(index int) *Migration {
return l.list[index]
}
// Items returns the internal migrations list slice.
func (l *MigrationsList) Items() []*Migration {
return l.list
}
// Copy copies all provided list migrations into the current one.
func (l *MigrationsList) Copy(list MigrationsList) {
for _, item := range list.Items() {
l.Register(item.Up, item.Down, item.File)
}
}
// Add adds adds an existing migration definition to the list.
//
// If m.File is not provided, it will try to get the name from its .go file.
//
// The list will be sorted automatically based on the migrations file name.
func (l *MigrationsList) Add(m *Migration) {
if m.File == "" {
_, path, _, _ := runtime.Caller(1)
m.File = filepath.Base(path)
}
l.list = append(l.list, m)
sort.SliceStable(l.list, func(i int, j int) bool {
return l.list[i].File < l.list[j].File
})
}
// Register adds new migration definition to the list.
//
// If optFilename is not provided, it will try to get the name from its .go file.
//
// The list will be sorted automatically based on the migrations file name.
func (l *MigrationsList) Register(
up func(txApp App) error,
down func(txApp App) error,
optFilename ...string,
) {
var file string
if len(optFilename) > 0 {
file = optFilename[0]
} else {
_, path, _, _ := runtime.Caller(1)
file = filepath.Base(path)
}
l.list = append(l.list, &Migration{
File: file,
Up: up,
Down: down,
})
sort.SliceStable(l.list, func(i int, j int) bool {
return l.list[i].File < l.list[j].File
})
}

View file

@ -0,0 +1,48 @@
package core_test
import (
"testing"
"github.com/pocketbase/pocketbase/core"
)
func TestMigrationsList(t *testing.T) {
l1 := core.MigrationsList{}
l1.Add(&core.Migration{File: "5_test.go"})
l1.Add(&core.Migration{ /* auto detect file name */ })
l1.Register(nil, nil, "3_test.go")
l1.Register(nil, nil, "1_test.go")
l1.Register(nil, nil, "2_test.go")
l1.Register(nil, nil /* auto detect file name */)
l2 := core.MigrationsList{}
l2.Register(nil, nil, "4_test.go")
l2.Copy(l1)
expected := []string{
"1_test.go",
"2_test.go",
"3_test.go",
"4_test.go",
"5_test.go",
// twice because there 2 test migrations with auto filename
"migrations_list_test.go",
"migrations_list_test.go",
}
items := l2.Items()
if len(items) != len(expected) {
names := make([]string, len(items))
for i, item := range items {
names[i] = item.File
}
t.Fatalf("Expected %d items, got %d:\n%v", len(expected), len(names), names)
}
for i, name := range expected {
item := l2.Item(i)
if item.File != name {
t.Fatalf("Expected name %s for index %d, got %s", name, i, item.File)
}
}
}

317
core/migrations_runner.go Normal file
View file

@ -0,0 +1,317 @@
package core
import (
"fmt"
"strings"
"time"
"github.com/fatih/color"
"github.com/pocketbase/dbx"
"github.com/pocketbase/pocketbase/tools/osutils"
"github.com/spf13/cast"
)
var AppMigrations MigrationsList
var SystemMigrations MigrationsList
const DefaultMigrationsTable = "_migrations"
// MigrationsRunner defines a simple struct for managing the execution of db migrations.
type MigrationsRunner struct {
app App
tableName string
migrationsList MigrationsList
inited bool
}
// NewMigrationsRunner creates and initializes a new db migrations MigrationsRunner instance.
func NewMigrationsRunner(app App, migrationsList MigrationsList) *MigrationsRunner {
return &MigrationsRunner{
app: app,
migrationsList: migrationsList,
tableName: DefaultMigrationsTable,
}
}
// Run interactively executes the current runner with the provided args.
//
// The following commands are supported:
// - up - applies all migrations
// - down [n] - reverts the last n (default 1) applied migrations
// - history-sync - syncs the migrations table with the runner's migrations list
func (r *MigrationsRunner) Run(args ...string) error {
if err := r.initMigrationsTable(); err != nil {
return err
}
cmd := "up"
if len(args) > 0 {
cmd = args[0]
}
switch cmd {
case "up":
applied, err := r.Up()
if err != nil {
return err
}
if len(applied) == 0 {
color.Green("No new migrations to apply.")
} else {
for _, file := range applied {
color.Green("Applied %s", file)
}
}
return nil
case "down":
toRevertCount := 1
if len(args) > 1 {
toRevertCount = cast.ToInt(args[1])
if toRevertCount < 0 {
// revert all applied migrations
toRevertCount = len(r.migrationsList.Items())
}
}
names, err := r.lastAppliedMigrations(toRevertCount)
if err != nil {
return err
}
confirm := osutils.YesNoPrompt(fmt.Sprintf(
"\n%v\nDo you really want to revert the last %d applied migration(s)?",
strings.Join(names, "\n"),
toRevertCount,
), false)
if !confirm {
fmt.Println("The command has been cancelled")
return nil
}
reverted, err := r.Down(toRevertCount)
if err != nil {
return err
}
if len(reverted) == 0 {
color.Green("No migrations to revert.")
} else {
for _, file := range reverted {
color.Green("Reverted %s", file)
}
}
return nil
case "history-sync":
if err := r.RemoveMissingAppliedMigrations(); err != nil {
return err
}
color.Green("The %s table was synced with the available migrations.", r.tableName)
return nil
default:
return fmt.Errorf("unsupported command: %q", cmd)
}
}
// Up executes all unapplied migrations for the provided runner.
//
// On success returns list with the applied migrations file names.
func (r *MigrationsRunner) Up() ([]string, error) {
if err := r.initMigrationsTable(); err != nil {
return nil, err
}
applied := []string{}
err := r.app.AuxRunInTransaction(func(txApp App) error {
return txApp.RunInTransaction(func(txApp App) error {
for _, m := range r.migrationsList.Items() {
// applied migrations check
if r.isMigrationApplied(txApp, m.File) {
if m.ReapplyCondition == nil {
continue // no need to reapply
}
shouldReapply, err := m.ReapplyCondition(txApp, r, m.File)
if err != nil {
return err
}
if !shouldReapply {
continue
}
// clear previous history stored entry
// (it will be recreated after successful execution)
r.saveRevertedMigration(txApp, m.File)
}
// ignore empty Up action
if m.Up != nil {
if err := m.Up(txApp); err != nil {
return fmt.Errorf("failed to apply migration %s: %w", m.File, err)
}
}
if err := r.saveAppliedMigration(txApp, m.File); err != nil {
return fmt.Errorf("failed to save applied migration info for %s: %w", m.File, err)
}
applied = append(applied, m.File)
}
return nil
})
})
if err != nil {
return nil, err
}
return applied, nil
}
// Down reverts the last `toRevertCount` applied migrations
// (in the order they were applied).
//
// On success returns list with the reverted migrations file names.
func (r *MigrationsRunner) Down(toRevertCount int) ([]string, error) {
if err := r.initMigrationsTable(); err != nil {
return nil, err
}
reverted := make([]string, 0, toRevertCount)
names, appliedErr := r.lastAppliedMigrations(toRevertCount)
if appliedErr != nil {
return nil, appliedErr
}
err := r.app.AuxRunInTransaction(func(txApp App) error {
return txApp.RunInTransaction(func(txApp App) error {
for _, name := range names {
for _, m := range r.migrationsList.Items() {
if m.File != name {
continue
}
// revert limit reached
if toRevertCount-len(reverted) <= 0 {
return nil
}
// ignore empty Down action
if m.Down != nil {
if err := m.Down(txApp); err != nil {
return fmt.Errorf("failed to revert migration %s: %w", m.File, err)
}
}
if err := r.saveRevertedMigration(txApp, m.File); err != nil {
return fmt.Errorf("failed to save reverted migration info for %s: %w", m.File, err)
}
reverted = append(reverted, m.File)
}
}
return nil
})
})
if err != nil {
return nil, err
}
return reverted, nil
}
// RemoveMissingAppliedMigrations removes the db entries of all applied migrations
// that are not listed in the runner's migrations list.
func (r *MigrationsRunner) RemoveMissingAppliedMigrations() error {
loadedMigrations := r.migrationsList.Items()
names := make([]any, len(loadedMigrations))
for i, migration := range loadedMigrations {
names[i] = migration.File
}
_, err := r.app.DB().Delete(r.tableName, dbx.Not(dbx.HashExp{
"file": names,
})).Execute()
return err
}
func (r *MigrationsRunner) initMigrationsTable() error {
if r.inited {
return nil // already inited
}
rawQuery := fmt.Sprintf(
"CREATE TABLE IF NOT EXISTS {{%s}} (file VARCHAR(255) PRIMARY KEY NOT NULL, applied INTEGER NOT NULL)",
r.tableName,
)
_, err := r.app.DB().NewQuery(rawQuery).Execute()
if err == nil {
r.inited = true
}
return err
}
func (r *MigrationsRunner) isMigrationApplied(txApp App, file string) bool {
var exists int
err := txApp.DB().Select("(1)").
From(r.tableName).
Where(dbx.HashExp{"file": file}).
Limit(1).
Row(&exists)
return err == nil && exists > 0
}
func (r *MigrationsRunner) saveAppliedMigration(txApp App, file string) error {
_, err := txApp.DB().Insert(r.tableName, dbx.Params{
"file": file,
"applied": time.Now().UnixMicro(),
}).Execute()
return err
}
func (r *MigrationsRunner) saveRevertedMigration(txApp App, file string) error {
_, err := txApp.DB().Delete(r.tableName, dbx.HashExp{"file": file}).Execute()
return err
}
func (r *MigrationsRunner) lastAppliedMigrations(limit int) ([]string, error) {
var files = make([]string, 0, limit)
loadedMigrations := r.migrationsList.Items()
names := make([]any, len(loadedMigrations))
for i, migration := range loadedMigrations {
names[i] = migration.File
}
err := r.app.DB().Select("file").
From(r.tableName).
Where(dbx.Not(dbx.HashExp{"applied": nil})).
AndWhere(dbx.HashExp{"file": names}).
// unify microseconds and seconds applied time for backward compatibility
OrderBy("substr(applied||'0000000000000000', 0, 17) DESC").
AndOrderBy("file DESC").
Limit(int64(limit)).
Column(&files)
if err != nil {
return nil, err
}
return files, nil
}

View file

@ -0,0 +1,212 @@
package core_test
import (
"encoding/json"
"fmt"
"testing"
"time"
"github.com/pocketbase/dbx"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tests"
)
func TestMigrationsRunnerUpAndDown(t *testing.T) {
t.Parallel()
app, _ := tests.NewTestApp()
defer app.Cleanup()
callsOrder := []string{}
l := core.MigrationsList{}
l.Register(func(app core.App) error {
callsOrder = append(callsOrder, "up2")
return nil
}, func(app core.App) error {
callsOrder = append(callsOrder, "down2")
return nil
}, "2_test")
l.Register(func(app core.App) error {
callsOrder = append(callsOrder, "up3")
return nil
}, func(app core.App) error {
callsOrder = append(callsOrder, "down3")
return nil
}, "3_test")
l.Register(func(app core.App) error {
callsOrder = append(callsOrder, "up1")
return nil
}, func(app core.App) error {
callsOrder = append(callsOrder, "down1")
return nil
}, "1_test")
l.Register(func(app core.App) error {
callsOrder = append(callsOrder, "up4")
return nil
}, func(app core.App) error {
callsOrder = append(callsOrder, "down4")
return nil
}, "4_test")
l.Add(&core.Migration{
Up: func(app core.App) error {
callsOrder = append(callsOrder, "up5")
return nil
},
Down: func(app core.App) error {
callsOrder = append(callsOrder, "down5")
return nil
},
File: "5_test",
ReapplyCondition: func(txApp core.App, runner *core.MigrationsRunner, fileName string) (bool, error) {
return true, nil
},
})
runner := core.NewMigrationsRunner(app, l)
// ---------------------------------------------------------------
// simulate partially out-of-order applied migration
// ---------------------------------------------------------------
_, err := app.DB().Insert(core.DefaultMigrationsTable, dbx.Params{
"file": "4_test",
"applied": time.Now().UnixMicro() - 2,
}).Execute()
if err != nil {
t.Fatalf("Failed to insert 5_test migration: %v", err)
}
_, err = app.DB().Insert(core.DefaultMigrationsTable, dbx.Params{
"file": "5_test",
"applied": time.Now().UnixMicro() - 1,
}).Execute()
if err != nil {
t.Fatalf("Failed to insert 5_test migration: %v", err)
}
_, err = app.DB().Insert(core.DefaultMigrationsTable, dbx.Params{
"file": "2_test",
"applied": time.Now().UnixMicro(),
}).Execute()
if err != nil {
t.Fatalf("Failed to insert 2_test migration: %v", err)
}
// ---------------------------------------------------------------
// Up()
// ---------------------------------------------------------------
if _, err := runner.Up(); err != nil {
t.Fatal(err)
}
expectedUpCallsOrder := `["up1","up3","up5"]` // skip up2 and up4 since they were applied already (up5 has extra reapply condition)
upCallsOrder, err := json.Marshal(callsOrder)
if err != nil {
t.Fatal(err)
}
if v := string(upCallsOrder); v != expectedUpCallsOrder {
t.Fatalf("Expected Up() calls order %s, got %s", expectedUpCallsOrder, upCallsOrder)
}
// ---------------------------------------------------------------
// reset callsOrder
callsOrder = []string{}
// simulate unrun migration
l.Register(nil, func(app core.App) error {
callsOrder = append(callsOrder, "down6")
return nil
}, "6_test")
// simulate applied migrations from different migrations list
_, err = app.DB().Insert(core.DefaultMigrationsTable, dbx.Params{
"file": "from_different_list",
"applied": time.Now().UnixMicro(),
}).Execute()
if err != nil {
t.Fatalf("Failed to insert from_different_list migration: %v", err)
}
// ---------------------------------------------------------------
// ---------------------------------------------------------------
// Down()
// ---------------------------------------------------------------
if _, err := runner.Down(2); err != nil {
t.Fatal(err)
}
expectedDownCallsOrder := `["down5","down3"]` // revert in the applied order
downCallsOrder, err := json.Marshal(callsOrder)
if err != nil {
t.Fatal(err)
}
if v := string(downCallsOrder); v != expectedDownCallsOrder {
t.Fatalf("Expected Down() calls order %s, got %s", expectedDownCallsOrder, downCallsOrder)
}
}
func TestMigrationsRunnerRemoveMissingAppliedMigrations(t *testing.T) {
t.Parallel()
app, _ := tests.NewTestApp()
defer app.Cleanup()
// mock migrations history
for i := 1; i <= 3; i++ {
_, err := app.DB().Insert(core.DefaultMigrationsTable, dbx.Params{
"file": fmt.Sprintf("%d_test", i),
"applied": time.Now().UnixMicro(),
}).Execute()
if err != nil {
t.Fatal(err)
}
}
if !isMigrationApplied(app, "2_test") {
t.Fatalf("Expected 2_test migration to be applied")
}
// create a runner without 2_test to mock deleted migration
l := core.MigrationsList{}
l.Register(func(app core.App) error {
return nil
}, func(app core.App) error {
return nil
}, "1_test")
l.Register(func(app core.App) error {
return nil
}, func(app core.App) error {
return nil
}, "3_test")
r := core.NewMigrationsRunner(app, l)
if err := r.RemoveMissingAppliedMigrations(); err != nil {
t.Fatalf("Failed to remove missing applied migrations: %v", err)
}
if isMigrationApplied(app, "2_test") {
t.Fatalf("Expected 2_test migration to NOT be applied")
}
}
func isMigrationApplied(app core.App, file string) bool {
var exists int
err := app.DB().Select("(1)").
From(core.DefaultMigrationsTable).
Where(dbx.HashExp{"file": file}).
Limit(1).
Row(&exists)
return err == nil && exists > 0
}

127
core/otp_model.go Normal file
View file

@ -0,0 +1,127 @@
package core
import (
"context"
"errors"
"time"
"github.com/pocketbase/pocketbase/tools/types"
)
const CollectionNameOTPs = "_otps"
var (
_ Model = (*OTP)(nil)
_ PreValidator = (*OTP)(nil)
_ RecordProxy = (*OTP)(nil)
)
// OTP defines a Record proxy for working with the otps collection.
type OTP struct {
*Record
}
// NewOTP instantiates and returns a new blank *OTP model.
//
// Example usage:
//
// otp := core.NewOTP(app)
// otp.SetRecordRef(user.Id)
// otp.SetCollectionRef(user.Collection().Id)
// otp.SetPassword(security.RandomStringWithAlphabet(6, "1234567890"))
// app.Save(otp)
func NewOTP(app App) *OTP {
m := &OTP{}
c, err := app.FindCachedCollectionByNameOrId(CollectionNameOTPs)
if err != nil {
// this is just to make tests easier since otp is a system collection and it is expected to be always accessible
// (note: the loaded record is further checked on OTP.PreValidate())
c = NewBaseCollection("__invalid__")
}
m.Record = NewRecord(c)
return m
}
// PreValidate implements the [PreValidator] interface and checks
// whether the proxy is properly loaded.
func (m *OTP) PreValidate(ctx context.Context, app App) error {
if m.Record == nil || m.Record.Collection().Name != CollectionNameOTPs {
return errors.New("missing or invalid otp ProxyRecord")
}
return nil
}
// ProxyRecord returns the proxied Record model.
func (m *OTP) ProxyRecord() *Record {
return m.Record
}
// SetProxyRecord loads the specified record model into the current proxy.
func (m *OTP) SetProxyRecord(record *Record) {
m.Record = record
}
// CollectionRef returns the "collectionRef" field value.
func (m *OTP) CollectionRef() string {
return m.GetString("collectionRef")
}
// SetCollectionRef updates the "collectionRef" record field value.
func (m *OTP) SetCollectionRef(collectionId string) {
m.Set("collectionRef", collectionId)
}
// RecordRef returns the "recordRef" record field value.
func (m *OTP) RecordRef() string {
return m.GetString("recordRef")
}
// SetRecordRef updates the "recordRef" record field value.
func (m *OTP) SetRecordRef(recordId string) {
m.Set("recordRef", recordId)
}
// SentTo returns the "sentTo" record field value.
//
// It could be any string value (email, phone, message app id, etc.)
// and usually is used as part of the auth flow to update the verified
// user state in case for example the sentTo value matches with the user record email.
func (m *OTP) SentTo() string {
return m.GetString("sentTo")
}
// SetSentTo updates the "sentTo" record field value.
func (m *OTP) SetSentTo(val string) {
m.Set("sentTo", val)
}
// Created returns the "created" record field value.
func (m *OTP) Created() types.DateTime {
return m.GetDateTime("created")
}
// Updated returns the "updated" record field value.
func (m *OTP) Updated() types.DateTime {
return m.GetDateTime("updated")
}
// HasExpired checks if the otp is expired, aka. whether it has been
// more than maxElapsed time since its creation.
func (m *OTP) HasExpired(maxElapsed time.Duration) bool {
return time.Since(m.Created().Time()) > maxElapsed
}
func (app *BaseApp) registerOTPHooks() {
recordRefHooks[*OTP](app, CollectionNameOTPs, CollectionTypeAuth)
// run on every hour to cleanup expired otp sessions
app.Cron().Add("__pbOTPCleanup__", "0 * * * *", func() {
if err := app.DeleteExpiredOTPs(); err != nil {
app.Logger().Warn("Failed to delete expired OTP sessions", "error", err)
}
})
}

302
core/otp_model_test.go Normal file
View file

@ -0,0 +1,302 @@
package core_test
import (
"fmt"
"testing"
"time"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tests"
"github.com/pocketbase/pocketbase/tools/types"
)
func TestNewOTP(t *testing.T) {
t.Parallel()
app, _ := tests.NewTestApp()
defer app.Cleanup()
otp := core.NewOTP(app)
if otp.Collection().Name != core.CollectionNameOTPs {
t.Fatalf("Expected record with %q collection, got %q", core.CollectionNameOTPs, otp.Collection().Name)
}
}
func TestOTPProxyRecord(t *testing.T) {
t.Parallel()
record := core.NewRecord(core.NewBaseCollection("test"))
record.Id = "test_id"
otp := core.OTP{}
otp.SetProxyRecord(record)
if otp.ProxyRecord() == nil || otp.ProxyRecord().Id != record.Id {
t.Fatalf("Expected proxy record with id %q, got %v", record.Id, otp.ProxyRecord())
}
}
func TestOTPRecordRef(t *testing.T) {
t.Parallel()
app, _ := tests.NewTestApp()
defer app.Cleanup()
otp := core.NewOTP(app)
testValues := []string{"test_1", "test2", ""}
for i, testValue := range testValues {
t.Run(fmt.Sprintf("%d_%q", i, testValue), func(t *testing.T) {
otp.SetRecordRef(testValue)
if v := otp.RecordRef(); v != testValue {
t.Fatalf("Expected getter %q, got %q", testValue, v)
}
if v := otp.GetString("recordRef"); v != testValue {
t.Fatalf("Expected field value %q, got %q", testValue, v)
}
})
}
}
func TestOTPCollectionRef(t *testing.T) {
t.Parallel()
app, _ := tests.NewTestApp()
defer app.Cleanup()
otp := core.NewOTP(app)
testValues := []string{"test_1", "test2", ""}
for i, testValue := range testValues {
t.Run(fmt.Sprintf("%d_%q", i, testValue), func(t *testing.T) {
otp.SetCollectionRef(testValue)
if v := otp.CollectionRef(); v != testValue {
t.Fatalf("Expected getter %q, got %q", testValue, v)
}
if v := otp.GetString("collectionRef"); v != testValue {
t.Fatalf("Expected field value %q, got %q", testValue, v)
}
})
}
}
func TestOTPSentTo(t *testing.T) {
t.Parallel()
app, _ := tests.NewTestApp()
defer app.Cleanup()
otp := core.NewOTP(app)
testValues := []string{"test_1", "test2", ""}
for i, testValue := range testValues {
t.Run(fmt.Sprintf("%d_%q", i, testValue), func(t *testing.T) {
otp.SetSentTo(testValue)
if v := otp.SentTo(); v != testValue {
t.Fatalf("Expected getter %q, got %q", testValue, v)
}
if v := otp.GetString("sentTo"); v != testValue {
t.Fatalf("Expected field value %q, got %q", testValue, v)
}
})
}
}
func TestOTPCreated(t *testing.T) {
t.Parallel()
app, _ := tests.NewTestApp()
defer app.Cleanup()
otp := core.NewOTP(app)
if v := otp.Created().String(); v != "" {
t.Fatalf("Expected empty created, got %q", v)
}
now := types.NowDateTime()
otp.SetRaw("created", now)
if v := otp.Created().String(); v != now.String() {
t.Fatalf("Expected %q created, got %q", now.String(), v)
}
}
func TestOTPUpdated(t *testing.T) {
t.Parallel()
app, _ := tests.NewTestApp()
defer app.Cleanup()
otp := core.NewOTP(app)
if v := otp.Updated().String(); v != "" {
t.Fatalf("Expected empty updated, got %q", v)
}
now := types.NowDateTime()
otp.SetRaw("updated", now)
if v := otp.Updated().String(); v != now.String() {
t.Fatalf("Expected %q updated, got %q", now.String(), v)
}
}
func TestOTPHasExpired(t *testing.T) {
t.Parallel()
app, _ := tests.NewTestApp()
defer app.Cleanup()
now := types.NowDateTime()
otp := core.NewOTP(app)
otp.SetRaw("created", now.Add(-5*time.Minute))
scenarios := []struct {
maxElapsed time.Duration
expected bool
}{
{0 * time.Minute, true},
{3 * time.Minute, true},
{5 * time.Minute, true},
{6 * time.Minute, false},
}
for i, s := range scenarios {
t.Run(fmt.Sprintf("%d_%s", i, s.maxElapsed.String()), func(t *testing.T) {
result := otp.HasExpired(s.maxElapsed)
if result != s.expected {
t.Fatalf("Expected %v, got %v", s.expected, result)
}
})
}
}
func TestOTPPreValidate(t *testing.T) {
t.Parallel()
app, _ := tests.NewTestApp()
defer app.Cleanup()
otpsCol, err := app.FindCollectionByNameOrId(core.CollectionNameOTPs)
if err != nil {
t.Fatal(err)
}
user, err := app.FindAuthRecordByEmail("users", "test@example.com")
if err != nil {
t.Fatal(err)
}
t.Run("no proxy record", func(t *testing.T) {
otp := &core.OTP{}
if err := app.Validate(otp); err == nil {
t.Fatal("Expected collection validation error")
}
})
t.Run("non-OTP collection", func(t *testing.T) {
otp := &core.OTP{}
otp.SetProxyRecord(core.NewRecord(core.NewBaseCollection("invalid")))
otp.SetRecordRef(user.Id)
otp.SetCollectionRef(user.Collection().Id)
otp.SetPassword("test123")
if err := app.Validate(otp); err == nil {
t.Fatal("Expected collection validation error")
}
})
t.Run("OTP collection", func(t *testing.T) {
otp := &core.OTP{}
otp.SetProxyRecord(core.NewRecord(otpsCol))
otp.SetRecordRef(user.Id)
otp.SetCollectionRef(user.Collection().Id)
otp.SetPassword("test123")
if err := app.Validate(otp); err != nil {
t.Fatalf("Expected nil validation error, got %v", err)
}
})
}
func TestOTPValidateHook(t *testing.T) {
t.Parallel()
app, _ := tests.NewTestApp()
defer app.Cleanup()
user, err := app.FindAuthRecordByEmail("users", "test@example.com")
if err != nil {
t.Fatal(err)
}
demo1, err := app.FindRecordById("demo1", "84nmscqy84lsi1t")
if err != nil {
t.Fatal(err)
}
scenarios := []struct {
name string
otp func() *core.OTP
expectErrors []string
}{
{
"empty",
func() *core.OTP {
return core.NewOTP(app)
},
[]string{"collectionRef", "recordRef", "password"},
},
{
"non-auth collection",
func() *core.OTP {
otp := core.NewOTP(app)
otp.SetCollectionRef(demo1.Collection().Id)
otp.SetRecordRef(demo1.Id)
otp.SetPassword("test123")
return otp
},
[]string{"collectionRef"},
},
{
"missing record id",
func() *core.OTP {
otp := core.NewOTP(app)
otp.SetCollectionRef(user.Collection().Id)
otp.SetRecordRef("missing")
otp.SetPassword("test123")
return otp
},
[]string{"recordRef"},
},
{
"valid ref",
func() *core.OTP {
otp := core.NewOTP(app)
otp.SetCollectionRef(user.Collection().Id)
otp.SetRecordRef(user.Id)
otp.SetPassword("test123")
return otp
},
[]string{},
},
}
for _, s := range scenarios {
t.Run(s.name, func(t *testing.T) {
errs := app.Validate(s.otp())
tests.TestValidationErrors(t, errs, s.expectErrors)
})
}
}

117
core/otp_query.go Normal file
View file

@ -0,0 +1,117 @@
package core
import (
"errors"
"time"
"github.com/pocketbase/dbx"
"github.com/pocketbase/pocketbase/tools/types"
)
// FindAllOTPsByRecord returns all OTP models linked to the provided auth record.
func (app *BaseApp) FindAllOTPsByRecord(authRecord *Record) ([]*OTP, error) {
result := []*OTP{}
err := app.RecordQuery(CollectionNameOTPs).
AndWhere(dbx.HashExp{
"collectionRef": authRecord.Collection().Id,
"recordRef": authRecord.Id,
}).
OrderBy("created DESC").
All(&result)
if err != nil {
return nil, err
}
return result, nil
}
// FindAllOTPsByCollection returns all OTP models linked to the provided collection.
func (app *BaseApp) FindAllOTPsByCollection(collection *Collection) ([]*OTP, error) {
result := []*OTP{}
err := app.RecordQuery(CollectionNameOTPs).
AndWhere(dbx.HashExp{"collectionRef": collection.Id}).
OrderBy("created DESC").
All(&result)
if err != nil {
return nil, err
}
return result, nil
}
// FindOTPById returns a single OTP model by its id.
func (app *BaseApp) FindOTPById(id string) (*OTP, error) {
result := &OTP{}
err := app.RecordQuery(CollectionNameOTPs).
AndWhere(dbx.HashExp{"id": id}).
Limit(1).
One(result)
if err != nil {
return nil, err
}
return result, nil
}
// DeleteAllOTPsByRecord deletes all OTP models associated with the provided record.
//
// Returns a combined error with the failed deletes.
func (app *BaseApp) DeleteAllOTPsByRecord(authRecord *Record) error {
models, err := app.FindAllOTPsByRecord(authRecord)
if err != nil {
return err
}
var errs []error
for _, m := range models {
if err := app.Delete(m); err != nil {
errs = append(errs, err)
}
}
if len(errs) > 0 {
return errors.Join(errs...)
}
return nil
}
// DeleteExpiredOTPs deletes the expired OTPs for all auth collections.
func (app *BaseApp) DeleteExpiredOTPs() error {
authCollections, err := app.FindAllCollections(CollectionTypeAuth)
if err != nil {
return err
}
// note: perform even if OTP is disabled to ensure that there are no dangling old records
for _, collection := range authCollections {
minValidDate, err := types.ParseDateTime(time.Now().Add(-1 * collection.OTP.DurationTime()))
if err != nil {
return err
}
items := []*Record{}
err = app.RecordQuery(CollectionNameOTPs).
AndWhere(dbx.HashExp{"collectionRef": collection.Id}).
AndWhere(dbx.NewExp("[[created]] < {:date}", dbx.Params{"date": minValidDate})).
All(&items)
if err != nil {
return err
}
for _, item := range items {
err = app.Delete(item)
if err != nil {
return err
}
}
}
return nil
}

310
core/otp_query_test.go Normal file
View file

@ -0,0 +1,310 @@
package core_test
import (
"fmt"
"slices"
"testing"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tests"
)
func TestFindAllOTPsByRecord(t *testing.T) {
t.Parallel()
app, _ := tests.NewTestApp()
defer app.Cleanup()
if err := tests.StubOTPRecords(app); err != nil {
t.Fatal(err)
}
demo1, err := app.FindRecordById("demo1", "84nmscqy84lsi1t")
if err != nil {
t.Fatal(err)
}
superuser2, err := app.FindAuthRecordByEmail(core.CollectionNameSuperusers, "test2@example.com")
if err != nil {
t.Fatal(err)
}
superuser4, err := app.FindAuthRecordByEmail(core.CollectionNameSuperusers, "test4@example.com")
if err != nil {
t.Fatal(err)
}
user1, err := app.FindAuthRecordByEmail("users", "test@example.com")
if err != nil {
t.Fatal(err)
}
scenarios := []struct {
record *core.Record
expected []string
}{
{demo1, nil},
{superuser2, []string{"superuser2_0", "superuser2_1", "superuser2_3", "superuser2_2", "superuser2_4"}},
{superuser4, nil},
{user1, []string{"user1_0"}},
}
for _, s := range scenarios {
t.Run(s.record.Collection().Name+"_"+s.record.Id, func(t *testing.T) {
result, err := app.FindAllOTPsByRecord(s.record)
if err != nil {
t.Fatal(err)
}
if len(result) != len(s.expected) {
t.Fatalf("Expected total otps %d, got %d", len(s.expected), len(result))
}
for i, id := range s.expected {
if result[i].Id != id {
t.Errorf("[%d] Expected id %q, got %q", i, id, result[i].Id)
}
}
})
}
}
func TestFindAllOTPsByCollection(t *testing.T) {
t.Parallel()
app, _ := tests.NewTestApp()
defer app.Cleanup()
if err := tests.StubOTPRecords(app); err != nil {
t.Fatal(err)
}
demo1, err := app.FindCollectionByNameOrId("demo1")
if err != nil {
t.Fatal(err)
}
superusers, err := app.FindCollectionByNameOrId(core.CollectionNameSuperusers)
if err != nil {
t.Fatal(err)
}
clients, err := app.FindCollectionByNameOrId("clients")
if err != nil {
t.Fatal(err)
}
users, err := app.FindCollectionByNameOrId("users")
if err != nil {
t.Fatal(err)
}
scenarios := []struct {
collection *core.Collection
expected []string
}{
{demo1, nil},
{superusers, []string{
"superuser2_0",
"superuser2_1",
"superuser2_3",
"superuser3_0",
"superuser3_1",
"superuser2_2",
"superuser2_4",
}},
{clients, nil},
{users, []string{"user1_0"}},
}
for _, s := range scenarios {
t.Run(s.collection.Name, func(t *testing.T) {
result, err := app.FindAllOTPsByCollection(s.collection)
if err != nil {
t.Fatal(err)
}
if len(result) != len(s.expected) {
t.Fatalf("Expected total otps %d, got %d", len(s.expected), len(result))
}
for i, id := range s.expected {
if result[i].Id != id {
t.Errorf("[%d] Expected id %q, got %q", i, id, result[i].Id)
}
}
})
}
}
func TestFindOTPById(t *testing.T) {
t.Parallel()
app, _ := tests.NewTestApp()
defer app.Cleanup()
if err := tests.StubOTPRecords(app); err != nil {
t.Fatal(err)
}
scenarios := []struct {
id string
expectError bool
}{
{"", true},
{"84nmscqy84lsi1t", true}, // non-otp id
{"superuser2_0", false},
{"superuser2_4", false}, // expired
{"user1_0", false},
}
for _, s := range scenarios {
t.Run(s.id, func(t *testing.T) {
result, err := app.FindOTPById(s.id)
hasErr := err != nil
if hasErr != s.expectError {
t.Fatalf("Expected hasErr %v, got %v (%v)", s.expectError, hasErr, err)
}
if hasErr {
return
}
if result.Id != s.id {
t.Fatalf("Expected record with id %q, got %q", s.id, result.Id)
}
})
}
}
func TestDeleteAllOTPsByRecord(t *testing.T) {
t.Parallel()
testApp, _ := tests.NewTestApp()
defer testApp.Cleanup()
demo1, err := testApp.FindRecordById("demo1", "84nmscqy84lsi1t")
if err != nil {
t.Fatal(err)
}
superuser2, err := testApp.FindAuthRecordByEmail(core.CollectionNameSuperusers, "test2@example.com")
if err != nil {
t.Fatal(err)
}
superuser4, err := testApp.FindAuthRecordByEmail(core.CollectionNameSuperusers, "test4@example.com")
if err != nil {
t.Fatal(err)
}
user1, err := testApp.FindAuthRecordByEmail("users", "test@example.com")
if err != nil {
t.Fatal(err)
}
scenarios := []struct {
record *core.Record
deletedIds []string
}{
{demo1, nil}, // non-auth record
{superuser2, []string{"superuser2_0", "superuser2_1", "superuser2_3", "superuser2_2", "superuser2_4"}},
{superuser4, nil},
{user1, []string{"user1_0"}},
}
for i, s := range scenarios {
t.Run(fmt.Sprintf("%d_%s_%s", i, s.record.Collection().Name, s.record.Id), func(t *testing.T) {
app, _ := tests.NewTestApp()
defer app.Cleanup()
if err := tests.StubOTPRecords(app); err != nil {
t.Fatal(err)
}
deletedIds := []string{}
app.OnRecordAfterDeleteSuccess().BindFunc(func(e *core.RecordEvent) error {
deletedIds = append(deletedIds, e.Record.Id)
return e.Next()
})
err := app.DeleteAllOTPsByRecord(s.record)
if err != nil {
t.Fatal(err)
}
if len(deletedIds) != len(s.deletedIds) {
t.Fatalf("Expected deleted ids\n%v\ngot\n%v", s.deletedIds, deletedIds)
}
for _, id := range s.deletedIds {
if !slices.Contains(deletedIds, id) {
t.Errorf("Expected to find deleted id %q in %v", id, deletedIds)
}
}
})
}
}
func TestDeleteExpiredOTPs(t *testing.T) {
t.Parallel()
checkDeletedIds := func(app core.App, t *testing.T, expectedDeletedIds []string) {
if err := tests.StubOTPRecords(app); err != nil {
t.Fatal(err)
}
deletedIds := []string{}
app.OnRecordAfterDeleteSuccess().BindFunc(func(e *core.RecordEvent) error {
deletedIds = append(deletedIds, e.Record.Id)
return e.Next()
})
if err := app.DeleteExpiredOTPs(); err != nil {
t.Fatal(err)
}
if len(deletedIds) != len(expectedDeletedIds) {
t.Fatalf("Expected deleted ids\n%v\ngot\n%v", expectedDeletedIds, deletedIds)
}
for _, id := range expectedDeletedIds {
if !slices.Contains(deletedIds, id) {
t.Errorf("Expected to find deleted id %q in %v", id, deletedIds)
}
}
}
t.Run("default test collections", func(t *testing.T) {
app, _ := tests.NewTestApp()
defer app.Cleanup()
checkDeletedIds(app, t, []string{
"user1_0",
"superuser2_2",
"superuser2_4",
})
})
t.Run("otp collection duration mock", func(t *testing.T) {
app, _ := tests.NewTestApp()
defer app.Cleanup()
superusers, err := app.FindCollectionByNameOrId(core.CollectionNameSuperusers)
if err != nil {
t.Fatal(err)
}
superusers.OTP.Duration = 60
if err := app.Save(superusers); err != nil {
t.Fatalf("Failed to mock superusers otp duration: %v", err)
}
checkDeletedIds(app, t, []string{
"user1_0",
"superuser2_2",
"superuser2_4",
"superuser3_1",
})
})
}

View file

@ -0,0 +1,403 @@
package core
import (
"encoding/json"
"errors"
"fmt"
"slices"
"strconv"
"strings"
"github.com/pocketbase/dbx"
"github.com/pocketbase/pocketbase/tools/search"
"github.com/pocketbase/pocketbase/tools/security"
"github.com/pocketbase/pocketbase/tools/types"
"github.com/spf13/cast"
)
// filter modifiers
const (
eachModifier string = "each"
issetModifier string = "isset"
lengthModifier string = "length"
lowerModifier string = "lower"
)
// ensure that `search.FieldResolver` interface is implemented
var _ search.FieldResolver = (*RecordFieldResolver)(nil)
// RecordFieldResolver defines a custom search resolver struct for
// managing Record model search fields.
//
// Usually used together with `search.Provider`.
// Example:
//
// resolver := resolvers.NewRecordFieldResolver(
// app,
// myCollection,
// &models.RequestInfo{...},
// true,
// )
// provider := search.NewProvider(resolver)
// ...
type RecordFieldResolver struct {
app App
baseCollection *Collection
requestInfo *RequestInfo
staticRequestInfo map[string]any
allowedFields []string
joins []*join
allowHiddenFields bool
}
// AllowedFields returns a copy of the resolver's allowed fields.
func (r *RecordFieldResolver) AllowedFields() []string {
return slices.Clone(r.allowedFields)
}
// SetAllowedFields replaces the resolver's allowed fields with the new ones.
func (r *RecordFieldResolver) SetAllowedFields(newAllowedFields []string) {
r.allowedFields = slices.Clone(newAllowedFields)
}
// AllowHiddenFields returns whether the current resolver allows filtering hidden fields.
func (r *RecordFieldResolver) AllowHiddenFields() bool {
return r.allowHiddenFields
}
// SetAllowHiddenFields enables or disables hidden fields filtering.
func (r *RecordFieldResolver) SetAllowHiddenFields(allowHiddenFields bool) {
r.allowHiddenFields = allowHiddenFields
}
// NewRecordFieldResolver creates and initializes a new `RecordFieldResolver`.
func NewRecordFieldResolver(
app App,
baseCollection *Collection,
requestInfo *RequestInfo,
allowHiddenFields bool,
) *RecordFieldResolver {
r := &RecordFieldResolver{
app: app,
baseCollection: baseCollection,
requestInfo: requestInfo,
allowHiddenFields: allowHiddenFields, // note: it is not based only on the requestInfo.auth since it could be used by a non-request internal method
joins: []*join{},
allowedFields: []string{
`^\w+[\w\.\:]*$`,
`^\@request\.context$`,
`^\@request\.method$`,
`^\@request\.auth\.[\w\.\:]*\w+$`,
`^\@request\.body\.[\w\.\:]*\w+$`,
`^\@request\.query\.[\w\.\:]*\w+$`,
`^\@request\.headers\.[\w\.\:]*\w+$`,
`^\@collection\.\w+(\:\w+)?\.[\w\.\:]*\w+$`,
},
}
r.staticRequestInfo = map[string]any{}
if r.requestInfo != nil {
r.staticRequestInfo["context"] = r.requestInfo.Context
r.staticRequestInfo["method"] = r.requestInfo.Method
r.staticRequestInfo["query"] = r.requestInfo.Query
r.staticRequestInfo["headers"] = r.requestInfo.Headers
r.staticRequestInfo["body"] = r.requestInfo.Body
r.staticRequestInfo["auth"] = nil
if r.requestInfo.Auth != nil {
authClone := r.requestInfo.Auth.Clone()
r.staticRequestInfo["auth"] = authClone.
Unhide(authClone.Collection().Fields.FieldNames()...).
IgnoreEmailVisibility(true).
PublicExport()
}
}
return r
}
// UpdateQuery implements `search.FieldResolver` interface.
//
// Conditionally updates the provided search query based on the
// resolved fields (eg. dynamically joining relations).
func (r *RecordFieldResolver) UpdateQuery(query *dbx.SelectQuery) error {
if len(r.joins) > 0 {
query.Distinct(true)
for _, join := range r.joins {
query.LeftJoin(
(join.tableName + " " + join.tableAlias),
join.on,
)
}
}
return nil
}
// Resolve implements `search.FieldResolver` interface.
//
// Example of some resolvable fieldName formats:
//
// id
// someSelect.each
// project.screen.status
// screen.project_via_prototype.name
// @request.context
// @request.method
// @request.query.filter
// @request.headers.x_token
// @request.auth.someRelation.name
// @request.body.someRelation.name
// @request.body.someField
// @request.body.someSelect:each
// @request.body.someField:isset
// @collection.product.name
func (r *RecordFieldResolver) Resolve(fieldName string) (*search.ResolverResult, error) {
return parseAndRun(fieldName, r)
}
func (r *RecordFieldResolver) resolveStaticRequestField(path ...string) (*search.ResolverResult, error) {
if len(path) == 0 {
return nil, errors.New("at least one path key should be provided")
}
lastProp, modifier, err := splitModifier(path[len(path)-1])
if err != nil {
return nil, err
}
path[len(path)-1] = lastProp
// extract value
resultVal, err := extractNestedVal(r.staticRequestInfo, path...)
if err != nil {
r.app.Logger().Debug("resolveStaticRequestField graceful fallback", "error", err.Error())
}
if modifier == issetModifier {
if err != nil {
return &search.ResolverResult{Identifier: "FALSE"}, nil
}
return &search.ResolverResult{Identifier: "TRUE"}, nil
}
// note: we are ignoring the error because requestInfo is dynamic
// and some of the lookup keys may not be defined for the request
switch v := resultVal.(type) {
case nil:
return &search.ResolverResult{Identifier: "NULL"}, nil
case string:
// check if it is a number field and explicitly try to cast to
// float in case of a numeric string value was used
// (this usually the case when the data is from a multipart/form-data request)
field := r.baseCollection.Fields.GetByName(path[len(path)-1])
if field != nil && field.Type() == FieldTypeNumber {
if nv, err := strconv.ParseFloat(v, 64); err == nil {
resultVal = nv
}
}
// otherwise - no further processing is needed...
case bool, int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, float32, float64:
// no further processing is needed...
default:
// non-plain value
// try casting to string (in case for exampe fmt.Stringer is implemented)
val, castErr := cast.ToStringE(v)
// if that doesn't work, try encoding it
if castErr != nil {
encoded, jsonErr := json.Marshal(v)
if jsonErr == nil {
val = string(encoded)
}
}
resultVal = val
}
placeholder := "f" + security.PseudorandomString(8)
if modifier == lowerModifier {
return &search.ResolverResult{
Identifier: "LOWER({:" + placeholder + "})",
Params: dbx.Params{placeholder: resultVal},
}, nil
}
return &search.ResolverResult{
Identifier: "{:" + placeholder + "}",
Params: dbx.Params{placeholder: resultVal},
}, nil
}
func (r *RecordFieldResolver) loadCollection(collectionNameOrId string) (*Collection, error) {
if collectionNameOrId == r.baseCollection.Name || collectionNameOrId == r.baseCollection.Id {
return r.baseCollection, nil
}
return getCollectionByModelOrIdentifier(r.app, collectionNameOrId)
}
func (r *RecordFieldResolver) registerJoin(tableName string, tableAlias string, on dbx.Expression) {
join := &join{
tableName: tableName,
tableAlias: tableAlias,
on: on,
}
// replace existing join
for i, j := range r.joins {
if j.tableAlias == join.tableAlias {
r.joins[i] = join
return
}
}
// register new join
r.joins = append(r.joins, join)
}
type mapExtractor interface {
AsMap() map[string]any
}
func extractNestedVal(rawData any, keys ...string) (any, error) {
if len(keys) == 0 {
return nil, errors.New("at least one key should be provided")
}
switch m := rawData.(type) {
// maps
case map[string]any:
return mapVal(m, keys...)
case map[string]string:
return mapVal(m, keys...)
case map[string]bool:
return mapVal(m, keys...)
case map[string]float32:
return mapVal(m, keys...)
case map[string]float64:
return mapVal(m, keys...)
case map[string]int:
return mapVal(m, keys...)
case map[string]int8:
return mapVal(m, keys...)
case map[string]int16:
return mapVal(m, keys...)
case map[string]int32:
return mapVal(m, keys...)
case map[string]int64:
return mapVal(m, keys...)
case map[string]uint:
return mapVal(m, keys...)
case map[string]uint8:
return mapVal(m, keys...)
case map[string]uint16:
return mapVal(m, keys...)
case map[string]uint32:
return mapVal(m, keys...)
case map[string]uint64:
return mapVal(m, keys...)
case mapExtractor:
return mapVal(m.AsMap(), keys...)
case types.JSONRaw:
var raw any
err := json.Unmarshal(m, &raw)
if err != nil {
return nil, fmt.Errorf("failed to unmarshal raw JSON in order extract nested value from: %w", err)
}
return extractNestedVal(raw, keys...)
// slices
case []string:
return arrVal(m, keys...)
case []bool:
return arrVal(m, keys...)
case []float32:
return arrVal(m, keys...)
case []float64:
return arrVal(m, keys...)
case []int:
return arrVal(m, keys...)
case []int8:
return arrVal(m, keys...)
case []int16:
return arrVal(m, keys...)
case []int32:
return arrVal(m, keys...)
case []int64:
return arrVal(m, keys...)
case []uint:
return arrVal(m, keys...)
case []uint8:
return arrVal(m, keys...)
case []uint16:
return arrVal(m, keys...)
case []uint32:
return arrVal(m, keys...)
case []uint64:
return arrVal(m, keys...)
case []mapExtractor:
extracted := make([]any, len(m))
for i, v := range m {
extracted[i] = v.AsMap()
}
return arrVal(extracted, keys...)
case []any:
return arrVal(m, keys...)
case []types.JSONRaw:
return arrVal(m, keys...)
default:
return nil, fmt.Errorf("expected map or array, got %#v", rawData)
}
}
func mapVal[T any](m map[string]T, keys ...string) (any, error) {
result, ok := m[keys[0]]
if !ok {
return nil, fmt.Errorf("invalid key path - missing key %q", keys[0])
}
// end key reached
if len(keys) == 1 {
return result, nil
}
return extractNestedVal(result, keys[1:]...)
}
func arrVal[T any](m []T, keys ...string) (any, error) {
idx, err := strconv.Atoi(keys[0])
if err != nil || idx < 0 || idx >= len(m) {
return nil, fmt.Errorf("invalid key path - invalid or missing array index %q", keys[0])
}
result := m[idx]
// end key reached
if len(keys) == 1 {
return result, nil
}
return extractNestedVal(result, keys[1:]...)
}
func splitModifier(combined string) (string, string, error) {
parts := strings.Split(combined, ":")
if len(parts) != 2 {
return combined, "", nil
}
// validate modifier
switch parts[1] {
case issetModifier,
eachModifier,
lengthModifier,
lowerModifier:
return parts[0], parts[1], nil
}
return "", "", fmt.Errorf("unknown modifier in %q", combined)
}

View file

@ -0,0 +1,70 @@
package core
import (
"fmt"
"strings"
"github.com/pocketbase/dbx"
)
var _ dbx.Expression = (*multiMatchSubquery)(nil)
// join defines the specification for a single SQL JOIN clause.
type join struct {
tableName string
tableAlias string
on dbx.Expression
}
// multiMatchSubquery defines a record multi-match subquery expression.
type multiMatchSubquery struct {
baseTableAlias string
fromTableName string
fromTableAlias string
valueIdentifier string
joins []*join
params dbx.Params
}
// Build converts the expression into a SQL fragment.
//
// Implements [dbx.Expression] interface.
func (m *multiMatchSubquery) Build(db *dbx.DB, params dbx.Params) string {
if m.baseTableAlias == "" || m.fromTableName == "" || m.fromTableAlias == "" {
return "0=1"
}
if params == nil {
params = m.params
} else {
// merge by updating the parent params
for k, v := range m.params {
params[k] = v
}
}
var mergedJoins strings.Builder
for i, j := range m.joins {
if i > 0 {
mergedJoins.WriteString(" ")
}
mergedJoins.WriteString("LEFT JOIN ")
mergedJoins.WriteString(db.QuoteTableName(j.tableName))
mergedJoins.WriteString(" ")
mergedJoins.WriteString(db.QuoteTableName(j.tableAlias))
if j.on != nil {
mergedJoins.WriteString(" ON ")
mergedJoins.WriteString(j.on.Build(db, params))
}
}
return fmt.Sprintf(
`SELECT %s as [[multiMatchValue]] FROM %s %s %s WHERE %s = %s`,
db.QuoteColumnName(m.valueIdentifier),
db.QuoteTableName(m.fromTableName),
db.QuoteTableName(m.fromTableAlias),
mergedJoins.String(),
db.QuoteColumnName(m.fromTableAlias+".id"),
db.QuoteColumnName(m.baseTableAlias+".id"),
)
}

View file

@ -0,0 +1,792 @@
package core
import (
"encoding/json"
"errors"
"fmt"
"reflect"
"regexp"
"strconv"
"strings"
"github.com/pocketbase/dbx"
"github.com/pocketbase/pocketbase/tools/dbutils"
"github.com/pocketbase/pocketbase/tools/inflector"
"github.com/pocketbase/pocketbase/tools/list"
"github.com/pocketbase/pocketbase/tools/search"
"github.com/pocketbase/pocketbase/tools/security"
"github.com/spf13/cast"
)
// maxNestedRels defines the max allowed nested relations depth.
const maxNestedRels = 6
// list of auth filter fields that don't require join with the auth
// collection or any other extra checks to be resolved.
var plainRequestAuthFields = map[string]struct{}{
"@request.auth." + FieldNameId: {},
"@request.auth." + FieldNameCollectionId: {},
"@request.auth." + FieldNameCollectionName: {},
"@request.auth." + FieldNameEmail: {},
"@request.auth." + FieldNameEmailVisibility: {},
"@request.auth." + FieldNameVerified: {},
}
// parseAndRun starts a new one-off RecordFieldResolver.Resolve execution.
func parseAndRun(fieldName string, resolver *RecordFieldResolver) (*search.ResolverResult, error) {
r := &runner{
fieldName: fieldName,
resolver: resolver,
}
return r.run()
}
type runner struct {
used bool // indicates whether the runner was already executed
resolver *RecordFieldResolver // resolver is the shared expression fields resolver
fieldName string // the name of the single field expression the runner is responsible for
// shared processing state
// ---------------------------------------------------------------
activeProps []string // holds the active props that remains to be processed
activeCollectionName string // the last used collection name
activeTableAlias string // the last used table alias
allowHiddenFields bool // indicates whether hidden fields (eg. email) should be allowed without extra conditions
nullifyMisingField bool // indicating whether to return null on missing field or return an error
withMultiMatch bool // indicates whether to attach a multiMatchSubquery condition to the ResolverResult
multiMatchActiveTableAlias string // the last used multi-match table alias
multiMatch *multiMatchSubquery // the multi-match subquery expression generated from the fieldName
}
func (r *runner) run() (*search.ResolverResult, error) {
if r.used {
return nil, errors.New("the runner was already used")
}
if len(r.resolver.allowedFields) > 0 && !list.ExistInSliceWithRegex(r.fieldName, r.resolver.allowedFields) {
return nil, fmt.Errorf("failed to resolve field %q", r.fieldName)
}
defer func() {
r.used = true
}()
r.prepare()
// check for @collection field (aka. non-relational join)
// must be in the format "@collection.COLLECTION_NAME.FIELD[.FIELD2....]"
if r.activeProps[0] == "@collection" {
return r.processCollectionField()
}
if r.activeProps[0] == "@request" {
if r.resolver.requestInfo == nil {
return &search.ResolverResult{Identifier: "NULL"}, nil
}
if strings.HasPrefix(r.fieldName, "@request.auth.") {
return r.processRequestAuthField()
}
if strings.HasPrefix(r.fieldName, "@request.body.") && len(r.activeProps) > 2 {
name, modifier, err := splitModifier(r.activeProps[2])
if err != nil {
return nil, err
}
bodyField := r.resolver.baseCollection.Fields.GetByName(name)
if bodyField == nil {
return r.resolver.resolveStaticRequestField(r.activeProps[1:]...)
}
// check for body relation field
if bodyField.Type() == FieldTypeRelation && len(r.activeProps) > 3 {
return r.processRequestInfoRelationField(bodyField)
}
// check for body arrayble fields ":each" modifier
if modifier == eachModifier && len(r.activeProps) == 3 {
return r.processRequestInfoEachModifier(bodyField)
}
// check for body arrayble fields ":length" modifier
if modifier == lengthModifier && len(r.activeProps) == 3 {
return r.processRequestInfoLengthModifier(bodyField)
}
// check for body arrayble fields ":lower" modifier
if modifier == lowerModifier && len(r.activeProps) == 3 {
return r.processRequestInfoLowerModifier(bodyField)
}
}
// some other @request.* static field
return r.resolver.resolveStaticRequestField(r.activeProps[1:]...)
}
// regular field
return r.processActiveProps()
}
func (r *runner) prepare() {
r.activeProps = strings.Split(r.fieldName, ".")
r.activeCollectionName = r.resolver.baseCollection.Name
r.activeTableAlias = inflector.Columnify(r.activeCollectionName)
r.allowHiddenFields = r.resolver.allowHiddenFields
// always allow hidden fields since the @.* filter is a system one
if r.activeProps[0] == "@collection" || r.activeProps[0] == "@request" {
r.allowHiddenFields = true
}
// enable the ignore flag for missing @request.* fields for backward
// compatibility and consistency with all @request.* filter fields and types
r.nullifyMisingField = r.activeProps[0] == "@request"
// prepare a multi-match subquery
r.multiMatch = &multiMatchSubquery{
baseTableAlias: r.activeTableAlias,
params: dbx.Params{},
}
r.multiMatch.fromTableName = inflector.Columnify(r.activeCollectionName)
r.multiMatch.fromTableAlias = "__mm_" + r.activeTableAlias
r.multiMatchActiveTableAlias = r.multiMatch.fromTableAlias
r.withMultiMatch = false
}
func (r *runner) processCollectionField() (*search.ResolverResult, error) {
if len(r.activeProps) < 3 {
return nil, fmt.Errorf("invalid @collection field path in %q", r.fieldName)
}
// nameOrId or nameOrId:alias
collectionParts := strings.SplitN(r.activeProps[1], ":", 2)
collection, err := r.resolver.loadCollection(collectionParts[0])
if err != nil {
return nil, fmt.Errorf("failed to load collection %q from field path %q", r.activeProps[1], r.fieldName)
}
r.activeCollectionName = collection.Name
if len(collectionParts) == 2 && collectionParts[1] != "" {
r.activeTableAlias = inflector.Columnify("__collection_alias_" + collectionParts[1])
} else {
r.activeTableAlias = inflector.Columnify("__collection_" + r.activeCollectionName)
}
r.withMultiMatch = true
// join the collection to the main query
r.resolver.registerJoin(inflector.Columnify(collection.Name), r.activeTableAlias, nil)
// join the collection to the multi-match subquery
r.multiMatchActiveTableAlias = "__mm" + r.activeTableAlias
r.multiMatch.joins = append(r.multiMatch.joins, &join{
tableName: inflector.Columnify(collection.Name),
tableAlias: r.multiMatchActiveTableAlias,
})
// leave only the collection fields
// aka. @collection.someCollection.fieldA.fieldB -> fieldA.fieldB
r.activeProps = r.activeProps[2:]
return r.processActiveProps()
}
func (r *runner) processRequestAuthField() (*search.ResolverResult, error) {
if r.resolver.requestInfo == nil || r.resolver.requestInfo.Auth == nil || r.resolver.requestInfo.Auth.Collection() == nil {
return &search.ResolverResult{Identifier: "NULL"}, nil
}
// plain auth field
// ---
if _, ok := plainRequestAuthFields[r.fieldName]; ok {
return r.resolver.resolveStaticRequestField(r.activeProps[1:]...)
}
// resolve the auth collection field
// ---
collection := r.resolver.requestInfo.Auth.Collection()
r.activeCollectionName = collection.Name
r.activeTableAlias = "__auth_" + inflector.Columnify(r.activeCollectionName)
// join the auth collection to the main query
r.resolver.registerJoin(
inflector.Columnify(r.activeCollectionName),
r.activeTableAlias,
dbx.HashExp{
// aka. __auth_users.id = :userId
(r.activeTableAlias + ".id"): r.resolver.requestInfo.Auth.Id,
},
)
// join the auth collection to the multi-match subquery
r.multiMatchActiveTableAlias = "__mm_" + r.activeTableAlias
r.multiMatch.joins = append(
r.multiMatch.joins,
&join{
tableName: inflector.Columnify(r.activeCollectionName),
tableAlias: r.multiMatchActiveTableAlias,
on: dbx.HashExp{
(r.multiMatchActiveTableAlias + ".id"): r.resolver.requestInfo.Auth.Id,
},
},
)
// leave only the auth relation fields
// aka. @request.auth.fieldA.fieldB -> fieldA.fieldB
r.activeProps = r.activeProps[2:]
return r.processActiveProps()
}
// note: nil value is returned as empty slice
func toSlice(value any) []any {
if value == nil {
return []any{}
}
rv := reflect.ValueOf(value)
kind := rv.Kind()
if kind != reflect.Slice && kind != reflect.Array {
return []any{value}
}
rvLen := rv.Len()
result := make([]interface{}, rvLen)
for i := 0; i < rvLen; i++ {
result[i] = rv.Index(i).Interface()
}
return result
}
func (r *runner) processRequestInfoLowerModifier(bodyField Field) (*search.ResolverResult, error) {
rawValue := cast.ToString(r.resolver.requestInfo.Body[bodyField.GetName()])
placeholder := "infoLower" + bodyField.GetName() + security.PseudorandomString(6)
result := &search.ResolverResult{
Identifier: "LOWER({:" + placeholder + "})",
Params: dbx.Params{placeholder: rawValue},
}
return result, nil
}
func (r *runner) processRequestInfoLengthModifier(bodyField Field) (*search.ResolverResult, error) {
if _, ok := bodyField.(MultiValuer); !ok {
return nil, fmt.Errorf("field %q doesn't support multivalue operations", bodyField.GetName())
}
bodyItems := toSlice(r.resolver.requestInfo.Body[bodyField.GetName()])
result := &search.ResolverResult{
Identifier: strconv.Itoa(len(bodyItems)),
}
return result, nil
}
func (r *runner) processRequestInfoEachModifier(bodyField Field) (*search.ResolverResult, error) {
multiValuer, ok := bodyField.(MultiValuer)
if !ok {
return nil, fmt.Errorf("field %q doesn't support multivalue operations", bodyField.GetName())
}
bodyItems := toSlice(r.resolver.requestInfo.Body[bodyField.GetName()])
bodyItemsRaw, err := json.Marshal(bodyItems)
if err != nil {
return nil, fmt.Errorf("cannot serialize the data for field %q", r.activeProps[2])
}
placeholder := "dataEach" + security.PseudorandomString(6)
cleanFieldName := inflector.Columnify(bodyField.GetName())
jeTable := fmt.Sprintf("json_each({:%s})", placeholder)
jeAlias := "__dataEach_" + cleanFieldName + "_je"
r.resolver.registerJoin(jeTable, jeAlias, nil)
result := &search.ResolverResult{
Identifier: fmt.Sprintf("[[%s.value]]", jeAlias),
Params: dbx.Params{placeholder: bodyItemsRaw},
}
if multiValuer.IsMultiple() {
r.withMultiMatch = true
}
if r.withMultiMatch {
placeholder2 := "mm" + placeholder
jeTable2 := fmt.Sprintf("json_each({:%s})", placeholder2)
jeAlias2 := "__mm" + jeAlias
r.multiMatch.joins = append(r.multiMatch.joins, &join{
tableName: jeTable2,
tableAlias: jeAlias2,
})
r.multiMatch.params[placeholder2] = bodyItemsRaw
r.multiMatch.valueIdentifier = fmt.Sprintf("[[%s.value]]", jeAlias2)
result.MultiMatchSubQuery = r.multiMatch
}
return result, nil
}
func (r *runner) processRequestInfoRelationField(bodyField Field) (*search.ResolverResult, error) {
relField, ok := bodyField.(*RelationField)
if !ok {
return nil, fmt.Errorf("failed to initialize data relation field %q", bodyField.GetName())
}
dataRelCollection, err := r.resolver.loadCollection(relField.CollectionId)
if err != nil {
return nil, fmt.Errorf("failed to load collection %q from data field %q", relField.CollectionId, relField.Name)
}
var dataRelIds []string
if r.resolver.requestInfo != nil && len(r.resolver.requestInfo.Body) != 0 {
dataRelIds = list.ToUniqueStringSlice(r.resolver.requestInfo.Body[relField.Name])
}
if len(dataRelIds) == 0 {
return &search.ResolverResult{Identifier: "NULL"}, nil
}
r.activeCollectionName = dataRelCollection.Name
r.activeTableAlias = inflector.Columnify("__data_" + dataRelCollection.Name + "_" + relField.Name)
// join the data rel collection to the main collection
r.resolver.registerJoin(
r.activeCollectionName,
r.activeTableAlias,
dbx.In(
fmt.Sprintf("[[%s.id]]", r.activeTableAlias),
list.ToInterfaceSlice(dataRelIds)...,
),
)
if relField.IsMultiple() {
r.withMultiMatch = true
}
// join the data rel collection to the multi-match subquery
r.multiMatchActiveTableAlias = inflector.Columnify("__data_mm_" + dataRelCollection.Name + "_" + relField.Name)
r.multiMatch.joins = append(
r.multiMatch.joins,
&join{
tableName: r.activeCollectionName,
tableAlias: r.multiMatchActiveTableAlias,
on: dbx.In(
fmt.Sprintf("[[%s.id]]", r.multiMatchActiveTableAlias),
list.ToInterfaceSlice(dataRelIds)...,
),
},
)
// leave only the data relation fields
// aka. @request.body.someRel.fieldA.fieldB -> fieldA.fieldB
r.activeProps = r.activeProps[3:]
return r.processActiveProps()
}
var viaRegex = regexp.MustCompile(`^(\w+)_via_(\w+)$`)
func (r *runner) processActiveProps() (*search.ResolverResult, error) {
totalProps := len(r.activeProps)
for i, prop := range r.activeProps {
collection, err := r.resolver.loadCollection(r.activeCollectionName)
if err != nil {
return nil, fmt.Errorf("failed to resolve field %q", prop)
}
// last prop
if i == totalProps-1 {
return r.processLastProp(collection, prop)
}
field := collection.Fields.GetByName(prop)
if field != nil && field.GetHidden() && !r.allowHiddenFields {
return nil, fmt.Errorf("non-filterable field %q", prop)
}
// json or geoPoint field -> treat the rest of the props as json path
// @todo consider converting to "JSONExtractable" interface with optional extra validation for the remaining props?
if field != nil && (field.Type() == FieldTypeJSON || field.Type() == FieldTypeGeoPoint) {
var jsonPath strings.Builder
for j, p := range r.activeProps[i+1:] {
if _, err := strconv.Atoi(p); err == nil {
jsonPath.WriteString("[")
jsonPath.WriteString(inflector.Columnify(p))
jsonPath.WriteString("]")
} else {
if j > 0 {
jsonPath.WriteString(".")
}
jsonPath.WriteString(inflector.Columnify(p))
}
}
jsonPathStr := jsonPath.String()
result := &search.ResolverResult{
NoCoalesce: true,
Identifier: dbutils.JSONExtract(r.activeTableAlias+"."+inflector.Columnify(prop), jsonPathStr),
}
if r.withMultiMatch {
r.multiMatch.valueIdentifier = dbutils.JSONExtract(r.multiMatchActiveTableAlias+"."+inflector.Columnify(prop), jsonPathStr)
result.MultiMatchSubQuery = r.multiMatch
}
return result, nil
}
if i >= maxNestedRels {
return nil, fmt.Errorf("max nested relations reached for field %q", prop)
}
// check for back relation (eg. yourCollection_via_yourRelField)
// -----------------------------------------------------------
if field == nil {
parts := viaRegex.FindStringSubmatch(prop)
if len(parts) != 3 {
if r.nullifyMisingField {
return &search.ResolverResult{Identifier: "NULL"}, nil
}
return nil, fmt.Errorf("failed to resolve field %q", prop)
}
backCollection, err := r.resolver.loadCollection(parts[1])
if err != nil {
if r.nullifyMisingField {
return &search.ResolverResult{Identifier: "NULL"}, nil
}
return nil, fmt.Errorf("failed to load back relation field %q collection", prop)
}
backField := backCollection.Fields.GetByName(parts[2])
if backField == nil {
if r.nullifyMisingField {
return &search.ResolverResult{Identifier: "NULL"}, nil
}
return nil, fmt.Errorf("missing back relation field %q", parts[2])
}
if backField.Type() != FieldTypeRelation {
if r.nullifyMisingField {
return &search.ResolverResult{Identifier: "NULL"}, nil
}
return nil, fmt.Errorf("invalid back relation field %q", parts[2])
}
if backField.GetHidden() && !r.allowHiddenFields {
return nil, fmt.Errorf("non-filterable back relation field %q", backField.GetName())
}
backRelField, ok := backField.(*RelationField)
if !ok {
return nil, fmt.Errorf("failed to initialize back relation field %q", backField.GetName())
}
if backRelField.CollectionId != collection.Id {
// https://github.com/pocketbase/pocketbase/discussions/6590#discussioncomment-12496581
if r.nullifyMisingField {
return &search.ResolverResult{Identifier: "NULL"}, nil
}
return nil, fmt.Errorf("invalid collection reference of a back relation field %q", backField.GetName())
}
// join the back relation to the main query
// ---
cleanProp := inflector.Columnify(prop)
cleanBackFieldName := inflector.Columnify(backRelField.Name)
newTableAlias := r.activeTableAlias + "_" + cleanProp
newCollectionName := inflector.Columnify(backCollection.Name)
isBackRelMultiple := backRelField.IsMultiple()
if !isBackRelMultiple {
// additionally check if the rel field has a single column unique index
_, hasUniqueIndex := dbutils.FindSingleColumnUniqueIndex(backCollection.Indexes, backRelField.Name)
isBackRelMultiple = !hasUniqueIndex
}
if !isBackRelMultiple {
r.resolver.registerJoin(
newCollectionName,
newTableAlias,
dbx.NewExp(fmt.Sprintf("[[%s.%s]] = [[%s.id]]", newTableAlias, cleanBackFieldName, r.activeTableAlias)),
)
} else {
jeAlias := r.activeTableAlias + "_" + cleanProp + "_je"
r.resolver.registerJoin(
newCollectionName,
newTableAlias,
dbx.NewExp(fmt.Sprintf(
"[[%s.id]] IN (SELECT [[%s.value]] FROM %s {{%s}})",
r.activeTableAlias,
jeAlias,
dbutils.JSONEach(newTableAlias+"."+cleanBackFieldName),
jeAlias,
)),
)
}
r.activeCollectionName = newCollectionName
r.activeTableAlias = newTableAlias
// ---
// join the back relation to the multi-match subquery
// ---
if isBackRelMultiple {
r.withMultiMatch = true // enable multimatch if not already
}
newTableAlias2 := r.multiMatchActiveTableAlias + "_" + cleanProp
if !isBackRelMultiple {
r.multiMatch.joins = append(
r.multiMatch.joins,
&join{
tableName: newCollectionName,
tableAlias: newTableAlias2,
on: dbx.NewExp(fmt.Sprintf("[[%s.%s]] = [[%s.id]]", newTableAlias2, cleanBackFieldName, r.multiMatchActiveTableAlias)),
},
)
} else {
jeAlias2 := r.multiMatchActiveTableAlias + "_" + cleanProp + "_je"
r.multiMatch.joins = append(
r.multiMatch.joins,
&join{
tableName: newCollectionName,
tableAlias: newTableAlias2,
on: dbx.NewExp(fmt.Sprintf(
"[[%s.id]] IN (SELECT [[%s.value]] FROM %s {{%s}})",
r.multiMatchActiveTableAlias,
jeAlias2,
dbutils.JSONEach(newTableAlias2+"."+cleanBackFieldName),
jeAlias2,
)),
},
)
}
r.multiMatchActiveTableAlias = newTableAlias2
// ---
continue
}
// -----------------------------------------------------------
// check for direct relation
if field.Type() != FieldTypeRelation {
return nil, fmt.Errorf("field %q is not a valid relation", prop)
}
// join the relation to the main query
// ---
relField, ok := field.(*RelationField)
if !ok {
return nil, fmt.Errorf("failed to initialize relation field %q", prop)
}
relCollection, relErr := r.resolver.loadCollection(relField.CollectionId)
if relErr != nil {
return nil, fmt.Errorf("failed to load field %q collection", prop)
}
// "id" lookups optimization for single relations to avoid unnecessary joins,
// aka. "user.id" and "user" should produce the same query identifier
if !relField.IsMultiple() &&
// the penultimate prop is "id"
i == totalProps-2 && r.activeProps[i+1] == FieldNameId {
return r.processLastProp(collection, relField.Name)
}
cleanFieldName := inflector.Columnify(relField.Name)
prefixedFieldName := r.activeTableAlias + "." + cleanFieldName
newTableAlias := r.activeTableAlias + "_" + cleanFieldName
newCollectionName := relCollection.Name
if !relField.IsMultiple() {
r.resolver.registerJoin(
inflector.Columnify(newCollectionName),
newTableAlias,
dbx.NewExp(fmt.Sprintf("[[%s.id]] = [[%s]]", newTableAlias, prefixedFieldName)),
)
} else {
jeAlias := r.activeTableAlias + "_" + cleanFieldName + "_je"
r.resolver.registerJoin(dbutils.JSONEach(prefixedFieldName), jeAlias, nil)
r.resolver.registerJoin(
inflector.Columnify(newCollectionName),
newTableAlias,
dbx.NewExp(fmt.Sprintf("[[%s.id]] = [[%s.value]]", newTableAlias, jeAlias)),
)
}
r.activeCollectionName = newCollectionName
r.activeTableAlias = newTableAlias
// ---
// join the relation to the multi-match subquery
// ---
if relField.IsMultiple() {
r.withMultiMatch = true // enable multimatch if not already
}
newTableAlias2 := r.multiMatchActiveTableAlias + "_" + cleanFieldName
prefixedFieldName2 := r.multiMatchActiveTableAlias + "." + cleanFieldName
if !relField.IsMultiple() {
r.multiMatch.joins = append(
r.multiMatch.joins,
&join{
tableName: inflector.Columnify(newCollectionName),
tableAlias: newTableAlias2,
on: dbx.NewExp(fmt.Sprintf("[[%s.id]] = [[%s]]", newTableAlias2, prefixedFieldName2)),
},
)
} else {
jeAlias2 := r.multiMatchActiveTableAlias + "_" + cleanFieldName + "_je"
r.multiMatch.joins = append(
r.multiMatch.joins,
&join{
tableName: dbutils.JSONEach(prefixedFieldName2),
tableAlias: jeAlias2,
},
&join{
tableName: inflector.Columnify(newCollectionName),
tableAlias: newTableAlias2,
on: dbx.NewExp(fmt.Sprintf("[[%s.id]] = [[%s.value]]", newTableAlias2, jeAlias2)),
},
)
}
r.multiMatchActiveTableAlias = newTableAlias2
// ---
}
return nil, fmt.Errorf("failed to resolve field %q", r.fieldName)
}
func (r *runner) processLastProp(collection *Collection, prop string) (*search.ResolverResult, error) {
name, modifier, err := splitModifier(prop)
if err != nil {
return nil, err
}
field := collection.Fields.GetByName(name)
if field == nil {
if r.nullifyMisingField {
return &search.ResolverResult{Identifier: "NULL"}, nil
}
return nil, fmt.Errorf("unknown field %q", name)
}
if field.GetHidden() && !r.allowHiddenFields {
return nil, fmt.Errorf("non-filterable field %q", name)
}
multvaluer, isMultivaluer := field.(MultiValuer)
cleanFieldName := inflector.Columnify(field.GetName())
// arrayable fields with ":length" modifier
// -------------------------------------------------------
if modifier == lengthModifier && isMultivaluer {
jePair := r.activeTableAlias + "." + cleanFieldName
result := &search.ResolverResult{
Identifier: dbutils.JSONArrayLength(jePair),
}
if r.withMultiMatch {
jePair2 := r.multiMatchActiveTableAlias + "." + cleanFieldName
r.multiMatch.valueIdentifier = dbutils.JSONArrayLength(jePair2)
result.MultiMatchSubQuery = r.multiMatch
}
return result, nil
}
// arrayable fields with ":each" modifier
// -------------------------------------------------------
if modifier == eachModifier && isMultivaluer {
jePair := r.activeTableAlias + "." + cleanFieldName
jeAlias := r.activeTableAlias + "_" + cleanFieldName + "_je"
r.resolver.registerJoin(dbutils.JSONEach(jePair), jeAlias, nil)
result := &search.ResolverResult{
Identifier: fmt.Sprintf("[[%s.value]]", jeAlias),
}
if multvaluer.IsMultiple() {
r.withMultiMatch = true
}
if r.withMultiMatch {
jePair2 := r.multiMatchActiveTableAlias + "." + cleanFieldName
jeAlias2 := r.multiMatchActiveTableAlias + "_" + cleanFieldName + "_je"
r.multiMatch.joins = append(r.multiMatch.joins, &join{
tableName: dbutils.JSONEach(jePair2),
tableAlias: jeAlias2,
})
r.multiMatch.valueIdentifier = fmt.Sprintf("[[%s.value]]", jeAlias2)
result.MultiMatchSubQuery = r.multiMatch
}
return result, nil
}
// default
// -------------------------------------------------------
result := &search.ResolverResult{
Identifier: "[[" + r.activeTableAlias + "." + cleanFieldName + "]]",
}
if r.withMultiMatch {
r.multiMatch.valueIdentifier = "[[" + r.multiMatchActiveTableAlias + "." + cleanFieldName + "]]"
result.MultiMatchSubQuery = r.multiMatch
}
// allow querying only auth records with emails marked as public
if field.GetName() == FieldNameEmail && !r.allowHiddenFields && collection.IsAuth() {
result.AfterBuild = func(expr dbx.Expression) dbx.Expression {
return dbx.Enclose(dbx.And(expr, dbx.NewExp(fmt.Sprintf(
"[[%s.%s]] = TRUE",
r.activeTableAlias,
FieldNameEmailVisibility,
))))
}
}
// wrap in json_extract to ensure that top-level primitives
// stored as json work correctly when compared to their SQL equivalent
// (https://github.com/pocketbase/pocketbase/issues/4068)
if field.Type() == FieldTypeJSON {
result.NoCoalesce = true
result.Identifier = dbutils.JSONExtract(r.activeTableAlias+"."+cleanFieldName, "")
if r.withMultiMatch {
r.multiMatch.valueIdentifier = dbutils.JSONExtract(r.multiMatchActiveTableAlias+"."+cleanFieldName, "")
}
}
// account for the ":lower" modifier
if modifier == lowerModifier {
result.Identifier = "LOWER(" + result.Identifier + ")"
if r.withMultiMatch {
r.multiMatch.valueIdentifier = "LOWER(" + r.multiMatch.valueIdentifier + ")"
}
}
return result, nil
}

File diff suppressed because one or more lines are too long

Some files were not shown because too many files have changed in this diff Show more