Adding upstream version 0.28.1.
Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
parent
88f1d47ab6
commit
e28c88ef14
933 changed files with 194711 additions and 0 deletions
1538
core/app.go
Normal file
1538
core/app.go
Normal file
File diff suppressed because it is too large
Load diff
239
core/auth_origin_model.go
Normal file
239
core/auth_origin_model.go
Normal file
|
@ -0,0 +1,239 @@
|
|||
package core
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"slices"
|
||||
|
||||
validation "github.com/go-ozzo/ozzo-validation/v4"
|
||||
"github.com/pocketbase/dbx"
|
||||
"github.com/pocketbase/pocketbase/tools/hook"
|
||||
"github.com/pocketbase/pocketbase/tools/types"
|
||||
)
|
||||
|
||||
const CollectionNameAuthOrigins = "_authOrigins"
|
||||
|
||||
var (
|
||||
_ Model = (*AuthOrigin)(nil)
|
||||
_ PreValidator = (*AuthOrigin)(nil)
|
||||
_ RecordProxy = (*AuthOrigin)(nil)
|
||||
)
|
||||
|
||||
// AuthOrigin defines a Record proxy for working with the authOrigins collection.
|
||||
type AuthOrigin struct {
|
||||
*Record
|
||||
}
|
||||
|
||||
// NewAuthOrigin instantiates and returns a new blank *AuthOrigin model.
|
||||
//
|
||||
// Example usage:
|
||||
//
|
||||
// origin := core.NewOrigin(app)
|
||||
// origin.SetRecordRef(user.Id)
|
||||
// origin.SetCollectionRef(user.Collection().Id)
|
||||
// origin.SetFingerprint("...")
|
||||
// app.Save(origin)
|
||||
func NewAuthOrigin(app App) *AuthOrigin {
|
||||
m := &AuthOrigin{}
|
||||
|
||||
c, err := app.FindCachedCollectionByNameOrId(CollectionNameAuthOrigins)
|
||||
if err != nil {
|
||||
// this is just to make tests easier since authOrigins is a system collection and it is expected to be always accessible
|
||||
// (note: the loaded record is further checked on AuthOrigin.PreValidate())
|
||||
c = NewBaseCollection("@___invalid___")
|
||||
}
|
||||
|
||||
m.Record = NewRecord(c)
|
||||
|
||||
return m
|
||||
}
|
||||
|
||||
// PreValidate implements the [PreValidator] interface and checks
|
||||
// whether the proxy is properly loaded.
|
||||
func (m *AuthOrigin) PreValidate(ctx context.Context, app App) error {
|
||||
if m.Record == nil || m.Record.Collection().Name != CollectionNameAuthOrigins {
|
||||
return errors.New("missing or invalid AuthOrigin ProxyRecord")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ProxyRecord returns the proxied Record model.
|
||||
func (m *AuthOrigin) ProxyRecord() *Record {
|
||||
return m.Record
|
||||
}
|
||||
|
||||
// SetProxyRecord loads the specified record model into the current proxy.
|
||||
func (m *AuthOrigin) SetProxyRecord(record *Record) {
|
||||
m.Record = record
|
||||
}
|
||||
|
||||
// CollectionRef returns the "collectionRef" field value.
|
||||
func (m *AuthOrigin) CollectionRef() string {
|
||||
return m.GetString("collectionRef")
|
||||
}
|
||||
|
||||
// SetCollectionRef updates the "collectionRef" record field value.
|
||||
func (m *AuthOrigin) SetCollectionRef(collectionId string) {
|
||||
m.Set("collectionRef", collectionId)
|
||||
}
|
||||
|
||||
// RecordRef returns the "recordRef" record field value.
|
||||
func (m *AuthOrigin) RecordRef() string {
|
||||
return m.GetString("recordRef")
|
||||
}
|
||||
|
||||
// SetRecordRef updates the "recordRef" record field value.
|
||||
func (m *AuthOrigin) SetRecordRef(recordId string) {
|
||||
m.Set("recordRef", recordId)
|
||||
}
|
||||
|
||||
// Fingerprint returns the "fingerprint" record field value.
|
||||
func (m *AuthOrigin) Fingerprint() string {
|
||||
return m.GetString("fingerprint")
|
||||
}
|
||||
|
||||
// SetFingerprint updates the "fingerprint" record field value.
|
||||
func (m *AuthOrigin) SetFingerprint(fingerprint string) {
|
||||
m.Set("fingerprint", fingerprint)
|
||||
}
|
||||
|
||||
// Created returns the "created" record field value.
|
||||
func (m *AuthOrigin) Created() types.DateTime {
|
||||
return m.GetDateTime("created")
|
||||
}
|
||||
|
||||
// Updated returns the "updated" record field value.
|
||||
func (m *AuthOrigin) Updated() types.DateTime {
|
||||
return m.GetDateTime("updated")
|
||||
}
|
||||
|
||||
func (app *BaseApp) registerAuthOriginHooks() {
|
||||
recordRefHooks[*AuthOrigin](app, CollectionNameAuthOrigins, CollectionTypeAuth)
|
||||
|
||||
// delete existing auth origins on password change
|
||||
app.OnRecordUpdate().Bind(&hook.Handler[*RecordEvent]{
|
||||
Func: func(e *RecordEvent) error {
|
||||
err := e.Next()
|
||||
if err != nil || !e.Record.Collection().IsAuth() {
|
||||
return err
|
||||
}
|
||||
|
||||
old := e.Record.Original().GetString(FieldNamePassword + ":hash")
|
||||
new := e.Record.GetString(FieldNamePassword + ":hash")
|
||||
if old != new {
|
||||
err = e.App.DeleteAllAuthOriginsByRecord(e.Record)
|
||||
if err != nil {
|
||||
e.App.Logger().Warn(
|
||||
"Failed to delete all previous auth origin fingerprints",
|
||||
"error", err,
|
||||
"recordId", e.Record.Id,
|
||||
"collectionId", e.Record.Collection().Id,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
},
|
||||
Priority: 99,
|
||||
})
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
|
||||
// recordRefHooks registers common hooks that are usually used with record proxies
|
||||
// that have polymorphic record relations (aka. "collectionRef" and "recordRef" fields).
|
||||
func recordRefHooks[T RecordProxy](app App, collectionName string, optCollectionTypes ...string) {
|
||||
app.OnRecordValidate(collectionName).Bind(&hook.Handler[*RecordEvent]{
|
||||
Func: func(e *RecordEvent) error {
|
||||
collectionId := e.Record.GetString("collectionRef")
|
||||
err := validation.Validate(collectionId, validation.Required, validation.By(validateCollectionId(e.App, optCollectionTypes...)))
|
||||
if err != nil {
|
||||
return validation.Errors{"collectionRef": err}
|
||||
}
|
||||
|
||||
recordId := e.Record.GetString("recordRef")
|
||||
err = validation.Validate(recordId, validation.Required, validation.By(validateRecordId(e.App, collectionId)))
|
||||
if err != nil {
|
||||
return validation.Errors{"recordRef": err}
|
||||
}
|
||||
|
||||
return e.Next()
|
||||
},
|
||||
Priority: 99,
|
||||
})
|
||||
|
||||
// delete on collection ref delete
|
||||
app.OnCollectionDeleteExecute().Bind(&hook.Handler[*CollectionEvent]{
|
||||
Func: func(e *CollectionEvent) error {
|
||||
if e.Collection.Name == collectionName || (len(optCollectionTypes) > 0 && !slices.Contains(optCollectionTypes, e.Collection.Type)) {
|
||||
return e.Next()
|
||||
}
|
||||
|
||||
originalApp := e.App
|
||||
txErr := e.App.RunInTransaction(func(txApp App) error {
|
||||
e.App = txApp
|
||||
|
||||
if err := e.Next(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
rels, err := txApp.FindAllRecords(collectionName, dbx.HashExp{"collectionRef": e.Collection.Id})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, mfa := range rels {
|
||||
if err := txApp.Delete(mfa); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
e.App = originalApp
|
||||
|
||||
return txErr
|
||||
},
|
||||
Priority: 99,
|
||||
})
|
||||
|
||||
// delete on record ref delete
|
||||
app.OnRecordDeleteExecute().Bind(&hook.Handler[*RecordEvent]{
|
||||
Func: func(e *RecordEvent) error {
|
||||
if e.Record.Collection().Name == collectionName ||
|
||||
(len(optCollectionTypes) > 0 && !slices.Contains(optCollectionTypes, e.Record.Collection().Type)) {
|
||||
return e.Next()
|
||||
}
|
||||
|
||||
originalApp := e.App
|
||||
txErr := e.App.RunInTransaction(func(txApp App) error {
|
||||
e.App = txApp
|
||||
|
||||
if err := e.Next(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
rels, err := txApp.FindAllRecords(collectionName, dbx.HashExp{
|
||||
"collectionRef": e.Record.Collection().Id,
|
||||
"recordRef": e.Record.Id,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, rel := range rels {
|
||||
if err := txApp.Delete(rel); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
e.App = originalApp
|
||||
|
||||
return txErr
|
||||
},
|
||||
Priority: 99,
|
||||
})
|
||||
}
|
332
core/auth_origin_model_test.go
Normal file
332
core/auth_origin_model_test.go
Normal file
|
@ -0,0 +1,332 @@
|
|||
package core_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"slices"
|
||||
"testing"
|
||||
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tests"
|
||||
"github.com/pocketbase/pocketbase/tools/types"
|
||||
)
|
||||
|
||||
func TestNewAuthOrigin(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
origin := core.NewAuthOrigin(app)
|
||||
|
||||
if origin.Collection().Name != core.CollectionNameAuthOrigins {
|
||||
t.Fatalf("Expected record with %q collection, got %q", core.CollectionNameAuthOrigins, origin.Collection().Name)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthOriginProxyRecord(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
record := core.NewRecord(core.NewBaseCollection("test"))
|
||||
record.Id = "test_id"
|
||||
|
||||
origin := core.AuthOrigin{}
|
||||
origin.SetProxyRecord(record)
|
||||
|
||||
if origin.ProxyRecord() == nil || origin.ProxyRecord().Id != record.Id {
|
||||
t.Fatalf("Expected proxy record with id %q, got %v", record.Id, origin.ProxyRecord())
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthOriginRecordRef(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
origin := core.NewAuthOrigin(app)
|
||||
|
||||
testValues := []string{"test_1", "test2", ""}
|
||||
for i, testValue := range testValues {
|
||||
t.Run(fmt.Sprintf("%d_%q", i, testValue), func(t *testing.T) {
|
||||
origin.SetRecordRef(testValue)
|
||||
|
||||
if v := origin.RecordRef(); v != testValue {
|
||||
t.Fatalf("Expected getter %q, got %q", testValue, v)
|
||||
}
|
||||
|
||||
if v := origin.GetString("recordRef"); v != testValue {
|
||||
t.Fatalf("Expected field value %q, got %q", testValue, v)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthOriginCollectionRef(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
origin := core.NewAuthOrigin(app)
|
||||
|
||||
testValues := []string{"test_1", "test2", ""}
|
||||
for i, testValue := range testValues {
|
||||
t.Run(fmt.Sprintf("%d_%q", i, testValue), func(t *testing.T) {
|
||||
origin.SetCollectionRef(testValue)
|
||||
|
||||
if v := origin.CollectionRef(); v != testValue {
|
||||
t.Fatalf("Expected getter %q, got %q", testValue, v)
|
||||
}
|
||||
|
||||
if v := origin.GetString("collectionRef"); v != testValue {
|
||||
t.Fatalf("Expected field value %q, got %q", testValue, v)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthOriginFingerprint(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
origin := core.NewAuthOrigin(app)
|
||||
|
||||
testValues := []string{"test_1", "test2", ""}
|
||||
for i, testValue := range testValues {
|
||||
t.Run(fmt.Sprintf("%d_%q", i, testValue), func(t *testing.T) {
|
||||
origin.SetFingerprint(testValue)
|
||||
|
||||
if v := origin.Fingerprint(); v != testValue {
|
||||
t.Fatalf("Expected getter %q, got %q", testValue, v)
|
||||
}
|
||||
|
||||
if v := origin.GetString("fingerprint"); v != testValue {
|
||||
t.Fatalf("Expected field value %q, got %q", testValue, v)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthOriginCreated(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
origin := core.NewAuthOrigin(app)
|
||||
|
||||
if v := origin.Created().String(); v != "" {
|
||||
t.Fatalf("Expected empty created, got %q", v)
|
||||
}
|
||||
|
||||
now := types.NowDateTime()
|
||||
origin.SetRaw("created", now)
|
||||
|
||||
if v := origin.Created().String(); v != now.String() {
|
||||
t.Fatalf("Expected %q created, got %q", now.String(), v)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthOriginUpdated(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
origin := core.NewAuthOrigin(app)
|
||||
|
||||
if v := origin.Updated().String(); v != "" {
|
||||
t.Fatalf("Expected empty updated, got %q", v)
|
||||
}
|
||||
|
||||
now := types.NowDateTime()
|
||||
origin.SetRaw("updated", now)
|
||||
|
||||
if v := origin.Updated().String(); v != now.String() {
|
||||
t.Fatalf("Expected %q updated, got %q", now.String(), v)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthOriginPreValidate(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
originsCol, err := app.FindCollectionByNameOrId(core.CollectionNameAuthOrigins)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
user, err := app.FindAuthRecordByEmail("users", "test@example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
t.Run("no proxy record", func(t *testing.T) {
|
||||
origin := &core.AuthOrigin{}
|
||||
|
||||
if err := app.Validate(origin); err == nil {
|
||||
t.Fatal("Expected collection validation error")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("non-AuthOrigin collection", func(t *testing.T) {
|
||||
origin := &core.AuthOrigin{}
|
||||
origin.SetProxyRecord(core.NewRecord(core.NewBaseCollection("invalid")))
|
||||
origin.SetRecordRef(user.Id)
|
||||
origin.SetCollectionRef(user.Collection().Id)
|
||||
origin.SetFingerprint("abc")
|
||||
|
||||
if err := app.Validate(origin); err == nil {
|
||||
t.Fatal("Expected collection validation error")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("AuthOrigin collection", func(t *testing.T) {
|
||||
origin := &core.AuthOrigin{}
|
||||
origin.SetProxyRecord(core.NewRecord(originsCol))
|
||||
origin.SetRecordRef(user.Id)
|
||||
origin.SetCollectionRef(user.Collection().Id)
|
||||
origin.SetFingerprint("abc")
|
||||
|
||||
if err := app.Validate(origin); err != nil {
|
||||
t.Fatalf("Expected nil validation error, got %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestAuthOriginValidateHook(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
user, err := app.FindAuthRecordByEmail("users", "test@example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
demo1, err := app.FindRecordById("demo1", "84nmscqy84lsi1t")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
scenarios := []struct {
|
||||
name string
|
||||
origin func() *core.AuthOrigin
|
||||
expectErrors []string
|
||||
}{
|
||||
{
|
||||
"empty",
|
||||
func() *core.AuthOrigin {
|
||||
return core.NewAuthOrigin(app)
|
||||
},
|
||||
[]string{"collectionRef", "recordRef", "fingerprint"},
|
||||
},
|
||||
{
|
||||
"non-auth collection",
|
||||
func() *core.AuthOrigin {
|
||||
origin := core.NewAuthOrigin(app)
|
||||
origin.SetCollectionRef(demo1.Collection().Id)
|
||||
origin.SetRecordRef(demo1.Id)
|
||||
origin.SetFingerprint("abc")
|
||||
return origin
|
||||
},
|
||||
[]string{"collectionRef"},
|
||||
},
|
||||
{
|
||||
"missing record id",
|
||||
func() *core.AuthOrigin {
|
||||
origin := core.NewAuthOrigin(app)
|
||||
origin.SetCollectionRef(user.Collection().Id)
|
||||
origin.SetRecordRef("missing")
|
||||
origin.SetFingerprint("abc")
|
||||
return origin
|
||||
},
|
||||
[]string{"recordRef"},
|
||||
},
|
||||
{
|
||||
"valid ref",
|
||||
func() *core.AuthOrigin {
|
||||
origin := core.NewAuthOrigin(app)
|
||||
origin.SetCollectionRef(user.Collection().Id)
|
||||
origin.SetRecordRef(user.Id)
|
||||
origin.SetFingerprint("abc")
|
||||
return origin
|
||||
},
|
||||
[]string{},
|
||||
},
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
t.Run(s.name, func(t *testing.T) {
|
||||
errs := app.Validate(s.origin())
|
||||
tests.TestValidationErrors(t, errs, s.expectErrors)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthOriginPasswordChangeDeletion(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testApp, _ := tests.NewTestApp()
|
||||
defer testApp.Cleanup()
|
||||
|
||||
// no auth origin associated with it
|
||||
user1, err := testApp.FindAuthRecordByEmail("users", "test@example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
superuser2, err := testApp.FindAuthRecordByEmail(core.CollectionNameSuperusers, "test2@example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
client1, err := testApp.FindAuthRecordByEmail("clients", "test@example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
scenarios := []struct {
|
||||
record *core.Record
|
||||
deletedIds []string
|
||||
}{
|
||||
{user1, nil},
|
||||
{superuser2, []string{"5798yh833k6w6w0", "ic55o70g4f8pcl4", "dmy260k6ksjr4ib"}},
|
||||
{client1, []string{"9r2j0m74260ur8i"}},
|
||||
}
|
||||
|
||||
for i, s := range scenarios {
|
||||
t.Run(fmt.Sprintf("%d_%s_%s", i, s.record.Collection().Name, s.record.Id), func(t *testing.T) {
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
deletedIds := []string{}
|
||||
app.OnRecordDelete().BindFunc(func(e *core.RecordEvent) error {
|
||||
deletedIds = append(deletedIds, e.Record.Id)
|
||||
return e.Next()
|
||||
})
|
||||
|
||||
s.record.SetPassword("new_password")
|
||||
|
||||
err := app.Save(s.record)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if len(deletedIds) != len(s.deletedIds) {
|
||||
t.Fatalf("Expected deleted ids\n%v\ngot\n%v", s.deletedIds, deletedIds)
|
||||
}
|
||||
|
||||
for _, id := range s.deletedIds {
|
||||
if !slices.Contains(deletedIds, id) {
|
||||
t.Errorf("Expected to find deleted id %q in %v", id, deletedIds)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
101
core/auth_origin_query.go
Normal file
101
core/auth_origin_query.go
Normal file
|
@ -0,0 +1,101 @@
|
|||
package core
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"github.com/pocketbase/dbx"
|
||||
)
|
||||
|
||||
// FindAllAuthOriginsByRecord returns all AuthOrigin models linked to the provided auth record (in DESC order).
|
||||
func (app *BaseApp) FindAllAuthOriginsByRecord(authRecord *Record) ([]*AuthOrigin, error) {
|
||||
result := []*AuthOrigin{}
|
||||
|
||||
err := app.RecordQuery(CollectionNameAuthOrigins).
|
||||
AndWhere(dbx.HashExp{
|
||||
"collectionRef": authRecord.Collection().Id,
|
||||
"recordRef": authRecord.Id,
|
||||
}).
|
||||
OrderBy("created DESC").
|
||||
All(&result)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// FindAllAuthOriginsByCollection returns all AuthOrigin models linked to the provided collection (in DESC order).
|
||||
func (app *BaseApp) FindAllAuthOriginsByCollection(collection *Collection) ([]*AuthOrigin, error) {
|
||||
result := []*AuthOrigin{}
|
||||
|
||||
err := app.RecordQuery(CollectionNameAuthOrigins).
|
||||
AndWhere(dbx.HashExp{"collectionRef": collection.Id}).
|
||||
OrderBy("created DESC").
|
||||
All(&result)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// FindAuthOriginById returns a single AuthOrigin model by its id.
|
||||
func (app *BaseApp) FindAuthOriginById(id string) (*AuthOrigin, error) {
|
||||
result := &AuthOrigin{}
|
||||
|
||||
err := app.RecordQuery(CollectionNameAuthOrigins).
|
||||
AndWhere(dbx.HashExp{"id": id}).
|
||||
Limit(1).
|
||||
One(result)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// FindAuthOriginByRecordAndFingerprint returns a single AuthOrigin model
|
||||
// by its authRecord relation and fingerprint.
|
||||
func (app *BaseApp) FindAuthOriginByRecordAndFingerprint(authRecord *Record, fingerprint string) (*AuthOrigin, error) {
|
||||
result := &AuthOrigin{}
|
||||
|
||||
err := app.RecordQuery(CollectionNameAuthOrigins).
|
||||
AndWhere(dbx.HashExp{
|
||||
"collectionRef": authRecord.Collection().Id,
|
||||
"recordRef": authRecord.Id,
|
||||
"fingerprint": fingerprint,
|
||||
}).
|
||||
Limit(1).
|
||||
One(result)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// DeleteAllAuthOriginsByRecord deletes all AuthOrigin models associated with the provided record.
|
||||
//
|
||||
// Returns a combined error with the failed deletes.
|
||||
func (app *BaseApp) DeleteAllAuthOriginsByRecord(authRecord *Record) error {
|
||||
models, err := app.FindAllAuthOriginsByRecord(authRecord)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var errs []error
|
||||
for _, m := range models {
|
||||
if err := app.Delete(m); err != nil {
|
||||
errs = append(errs, err)
|
||||
}
|
||||
}
|
||||
if len(errs) > 0 {
|
||||
return errors.Join(errs...)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
268
core/auth_origin_query_test.go
Normal file
268
core/auth_origin_query_test.go
Normal file
|
@ -0,0 +1,268 @@
|
|||
package core_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"slices"
|
||||
"testing"
|
||||
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tests"
|
||||
)
|
||||
|
||||
func TestFindAllAuthOriginsByRecord(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
demo1, err := app.FindRecordById("demo1", "84nmscqy84lsi1t")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
superuser2, err := app.FindAuthRecordByEmail(core.CollectionNameSuperusers, "test2@example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
superuser4, err := app.FindAuthRecordByEmail(core.CollectionNameSuperusers, "test4@example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
client1, err := app.FindAuthRecordByEmail("clients", "test@example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
scenarios := []struct {
|
||||
record *core.Record
|
||||
expected []string
|
||||
}{
|
||||
{demo1, nil},
|
||||
{superuser2, []string{"5798yh833k6w6w0", "ic55o70g4f8pcl4", "dmy260k6ksjr4ib"}},
|
||||
{superuser4, nil},
|
||||
{client1, []string{"9r2j0m74260ur8i"}},
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
t.Run(s.record.Collection().Name+"_"+s.record.Id, func(t *testing.T) {
|
||||
result, err := app.FindAllAuthOriginsByRecord(s.record)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if len(result) != len(s.expected) {
|
||||
t.Fatalf("Expected total origins %d, got %d", len(s.expected), len(result))
|
||||
}
|
||||
|
||||
for i, id := range s.expected {
|
||||
if result[i].Id != id {
|
||||
t.Errorf("[%d] Expected id %q, got %q", i, id, result[i].Id)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFindAllAuthOriginsByCollection(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
demo1, err := app.FindCollectionByNameOrId("demo1")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
superusers, err := app.FindCollectionByNameOrId(core.CollectionNameSuperusers)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
clients, err := app.FindCollectionByNameOrId("clients")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
scenarios := []struct {
|
||||
collection *core.Collection
|
||||
expected []string
|
||||
}{
|
||||
{demo1, nil},
|
||||
{superusers, []string{"5798yh833k6w6w0", "ic55o70g4f8pcl4", "dmy260k6ksjr4ib", "5f29jy38bf5zm3f"}},
|
||||
{clients, []string{"9r2j0m74260ur8i"}},
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
t.Run(s.collection.Name, func(t *testing.T) {
|
||||
result, err := app.FindAllAuthOriginsByCollection(s.collection)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if len(result) != len(s.expected) {
|
||||
t.Fatalf("Expected total origins %d, got %d", len(s.expected), len(result))
|
||||
}
|
||||
|
||||
for i, id := range s.expected {
|
||||
if result[i].Id != id {
|
||||
t.Errorf("[%d] Expected id %q, got %q", i, id, result[i].Id)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFindAuthOriginById(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
scenarios := []struct {
|
||||
id string
|
||||
expectError bool
|
||||
}{
|
||||
{"", true},
|
||||
{"84nmscqy84lsi1t", true}, // non-origin id
|
||||
{"9r2j0m74260ur8i", false},
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
t.Run(s.id, func(t *testing.T) {
|
||||
result, err := app.FindAuthOriginById(s.id)
|
||||
|
||||
hasErr := err != nil
|
||||
if hasErr != s.expectError {
|
||||
t.Fatalf("Expected hasErr %v, got %v (%v)", s.expectError, hasErr, err)
|
||||
}
|
||||
|
||||
if hasErr {
|
||||
return
|
||||
}
|
||||
|
||||
if result.Id != s.id {
|
||||
t.Fatalf("Expected record with id %q, got %q", s.id, result.Id)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFindAuthOriginByRecordAndFingerprint(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
demo1, err := app.FindRecordById("demo1", "84nmscqy84lsi1t")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
superuser2, err := app.FindAuthRecordByEmail(core.CollectionNameSuperusers, "test2@example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
scenarios := []struct {
|
||||
record *core.Record
|
||||
fingerprint string
|
||||
expectError bool
|
||||
}{
|
||||
{demo1, "6afbfe481c31c08c55a746cccb88ece0", true},
|
||||
{superuser2, "", true},
|
||||
{superuser2, "abc", true},
|
||||
{superuser2, "22bbbcbed36e25321f384ccf99f60057", false}, // fingerprint from different origin
|
||||
{superuser2, "6afbfe481c31c08c55a746cccb88ece0", false},
|
||||
}
|
||||
|
||||
for i, s := range scenarios {
|
||||
t.Run(fmt.Sprintf("%d_%s_%s", i, s.record.Id, s.fingerprint), func(t *testing.T) {
|
||||
result, err := app.FindAuthOriginByRecordAndFingerprint(s.record, s.fingerprint)
|
||||
|
||||
hasErr := err != nil
|
||||
if hasErr != s.expectError {
|
||||
t.Fatalf("Expected hasErr %v, got %v (%v)", s.expectError, hasErr, err)
|
||||
}
|
||||
|
||||
if hasErr {
|
||||
return
|
||||
}
|
||||
|
||||
if result.Fingerprint() != s.fingerprint {
|
||||
t.Fatalf("Expected origin with fingerprint %q, got %q", s.fingerprint, result.Fingerprint())
|
||||
}
|
||||
|
||||
if result.RecordRef() != s.record.Id || result.CollectionRef() != s.record.Collection().Id {
|
||||
t.Fatalf("Expected record %q (%q), got %q (%q)", s.record.Id, s.record.Collection().Id, result.RecordRef(), result.CollectionRef())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteAllAuthOriginsByRecord(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testApp, _ := tests.NewTestApp()
|
||||
defer testApp.Cleanup()
|
||||
|
||||
demo1, err := testApp.FindRecordById("demo1", "84nmscqy84lsi1t")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
superuser2, err := testApp.FindAuthRecordByEmail(core.CollectionNameSuperusers, "test2@example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
superuser4, err := testApp.FindAuthRecordByEmail(core.CollectionNameSuperusers, "test4@example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
client1, err := testApp.FindAuthRecordByEmail("clients", "test@example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
scenarios := []struct {
|
||||
record *core.Record
|
||||
deletedIds []string
|
||||
}{
|
||||
{demo1, nil}, // non-auth record
|
||||
{superuser2, []string{"5798yh833k6w6w0", "ic55o70g4f8pcl4", "dmy260k6ksjr4ib"}},
|
||||
{superuser4, nil},
|
||||
{client1, []string{"9r2j0m74260ur8i"}},
|
||||
}
|
||||
|
||||
for i, s := range scenarios {
|
||||
t.Run(fmt.Sprintf("%d_%s_%s", i, s.record.Collection().Name, s.record.Id), func(t *testing.T) {
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
deletedIds := []string{}
|
||||
app.OnRecordDelete().BindFunc(func(e *core.RecordEvent) error {
|
||||
deletedIds = append(deletedIds, e.Record.Id)
|
||||
return e.Next()
|
||||
})
|
||||
|
||||
err := app.DeleteAllAuthOriginsByRecord(s.record)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if len(deletedIds) != len(s.deletedIds) {
|
||||
t.Fatalf("Expected deleted ids\n%v\ngot\n%v", s.deletedIds, deletedIds)
|
||||
}
|
||||
|
||||
for _, id := range s.deletedIds {
|
||||
if !slices.Contains(deletedIds, id) {
|
||||
t.Errorf("Expected to find deleted id %q in %v", id, deletedIds)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
1535
core/base.go
Normal file
1535
core/base.go
Normal file
File diff suppressed because it is too large
Load diff
389
core/base_backup.go
Normal file
389
core/base_backup.go
Normal file
|
@ -0,0 +1,389 @@
|
|||
package core
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"sort"
|
||||
"time"
|
||||
|
||||
"github.com/pocketbase/pocketbase/tools/archive"
|
||||
"github.com/pocketbase/pocketbase/tools/filesystem"
|
||||
"github.com/pocketbase/pocketbase/tools/inflector"
|
||||
"github.com/pocketbase/pocketbase/tools/osutils"
|
||||
"github.com/pocketbase/pocketbase/tools/security"
|
||||
)
|
||||
|
||||
const (
|
||||
StoreKeyActiveBackup = "@activeBackup"
|
||||
)
|
||||
|
||||
// CreateBackup creates a new backup of the current app pb_data directory.
|
||||
//
|
||||
// If name is empty, it will be autogenerated.
|
||||
// If backup with the same name exists, the new backup file will replace it.
|
||||
//
|
||||
// The backup is executed within a transaction, meaning that new writes
|
||||
// will be temporary "blocked" until the backup file is generated.
|
||||
//
|
||||
// To safely perform the backup, it is recommended to have free disk space
|
||||
// for at least 2x the size of the pb_data directory.
|
||||
//
|
||||
// By default backups are stored in pb_data/backups
|
||||
// (the backups directory itself is excluded from the generated backup).
|
||||
//
|
||||
// When using S3 storage for the uploaded collection files, you have to
|
||||
// take care manually to backup those since they are not part of the pb_data.
|
||||
//
|
||||
// Backups can be stored on S3 if it is configured in app.Settings().Backups.
|
||||
func (app *BaseApp) CreateBackup(ctx context.Context, name string) error {
|
||||
if app.Store().Has(StoreKeyActiveBackup) {
|
||||
return errors.New("try again later - another backup/restore operation has already been started")
|
||||
}
|
||||
|
||||
app.Store().Set(StoreKeyActiveBackup, name)
|
||||
defer app.Store().Remove(StoreKeyActiveBackup)
|
||||
|
||||
event := new(BackupEvent)
|
||||
event.App = app
|
||||
event.Context = ctx
|
||||
event.Name = name
|
||||
// default root dir entries to exclude from the backup generation
|
||||
event.Exclude = []string{LocalBackupsDirName, LocalTempDirName, LocalAutocertCacheDirName}
|
||||
|
||||
return app.OnBackupCreate().Trigger(event, func(e *BackupEvent) error {
|
||||
// generate a default name if missing
|
||||
if e.Name == "" {
|
||||
e.Name = generateBackupName(e.App, "pb_backup_")
|
||||
}
|
||||
|
||||
// make sure that the special temp directory exists
|
||||
// note: it needs to be inside the current pb_data to avoid "cross-device link" errors
|
||||
localTempDir := filepath.Join(e.App.DataDir(), LocalTempDirName)
|
||||
if err := os.MkdirAll(localTempDir, os.ModePerm); err != nil {
|
||||
return fmt.Errorf("failed to create a temp dir: %w", err)
|
||||
}
|
||||
|
||||
// archive pb_data in a temp directory, exluding the "backups" and the temp dirs
|
||||
//
|
||||
// Run in transaction to temporary block other writes (transactions uses the NonconcurrentDB connection).
|
||||
// ---
|
||||
tempPath := filepath.Join(localTempDir, "pb_backup_"+security.PseudorandomString(6))
|
||||
createErr := e.App.RunInTransaction(func(txApp App) error {
|
||||
return txApp.AuxRunInTransaction(func(txApp App) error {
|
||||
// run manual checkpoint and truncate the WAL files
|
||||
// (errors are ignored because it is not that important and the PRAGMA may not be supported by the used driver)
|
||||
txApp.DB().NewQuery("PRAGMA wal_checkpoint(TRUNCATE)").Execute()
|
||||
txApp.AuxDB().NewQuery("PRAGMA wal_checkpoint(TRUNCATE)").Execute()
|
||||
|
||||
return archive.Create(txApp.DataDir(), tempPath, e.Exclude...)
|
||||
})
|
||||
})
|
||||
if createErr != nil {
|
||||
return createErr
|
||||
}
|
||||
defer os.Remove(tempPath)
|
||||
|
||||
// persist the backup in the backups filesystem
|
||||
// ---
|
||||
fsys, err := e.App.NewBackupsFilesystem()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer fsys.Close()
|
||||
|
||||
fsys.SetContext(e.Context)
|
||||
|
||||
file, err := filesystem.NewFileFromPath(tempPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
file.OriginalName = e.Name
|
||||
file.Name = file.OriginalName
|
||||
|
||||
if err := fsys.UploadFile(file, file.Name); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// RestoreBackup restores the backup with the specified name and restarts
|
||||
// the current running application process.
|
||||
//
|
||||
// NB! This feature is experimental and currently is expected to work only on UNIX based systems.
|
||||
//
|
||||
// To safely perform the restore it is recommended to have free disk space
|
||||
// for at least 2x the size of the restored pb_data backup.
|
||||
//
|
||||
// The performed steps are:
|
||||
//
|
||||
// 1. Download the backup with the specified name in a temp location
|
||||
// (this is in case of S3; otherwise it creates a temp copy of the zip)
|
||||
//
|
||||
// 2. Extract the backup in a temp directory inside the app "pb_data"
|
||||
// (eg. "pb_data/.pb_temp_to_delete/pb_restore").
|
||||
//
|
||||
// 3. Move the current app "pb_data" content (excluding the local backups and the special temp dir)
|
||||
// under another temp sub dir that will be deleted on the next app start up
|
||||
// (eg. "pb_data/.pb_temp_to_delete/old_pb_data").
|
||||
// This is because on some environments it may not be allowed
|
||||
// to delete the currently open "pb_data" files.
|
||||
//
|
||||
// 4. Move the extracted dir content to the app "pb_data".
|
||||
//
|
||||
// 5. Restart the app (on successful app bootstap it will also remove the old pb_data).
|
||||
//
|
||||
// If a failure occure during the restore process the dir changes are reverted.
|
||||
// If for whatever reason the revert is not possible, it panics.
|
||||
//
|
||||
// Note that if your pb_data has custom network mounts as subdirectories, then
|
||||
// it is possible the restore to fail during the `os.Rename` operations
|
||||
// (see https://github.com/pocketbase/pocketbase/issues/4647).
|
||||
func (app *BaseApp) RestoreBackup(ctx context.Context, name string) error {
|
||||
if app.Store().Has(StoreKeyActiveBackup) {
|
||||
return errors.New("try again later - another backup/restore operation has already been started")
|
||||
}
|
||||
|
||||
app.Store().Set(StoreKeyActiveBackup, name)
|
||||
defer app.Store().Remove(StoreKeyActiveBackup)
|
||||
|
||||
event := new(BackupEvent)
|
||||
event.App = app
|
||||
event.Context = ctx
|
||||
event.Name = name
|
||||
// default root dir entries to exclude from the backup restore
|
||||
event.Exclude = []string{LocalBackupsDirName, LocalTempDirName, LocalAutocertCacheDirName}
|
||||
|
||||
return app.OnBackupRestore().Trigger(event, func(e *BackupEvent) error {
|
||||
if runtime.GOOS == "windows" {
|
||||
return errors.New("restore is not supported on Windows")
|
||||
}
|
||||
|
||||
// make sure that the special temp directory exists
|
||||
// note: it needs to be inside the current pb_data to avoid "cross-device link" errors
|
||||
localTempDir := filepath.Join(e.App.DataDir(), LocalTempDirName)
|
||||
if err := os.MkdirAll(localTempDir, os.ModePerm); err != nil {
|
||||
return fmt.Errorf("failed to create a temp dir: %w", err)
|
||||
}
|
||||
|
||||
fsys, err := e.App.NewBackupsFilesystem()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer fsys.Close()
|
||||
|
||||
fsys.SetContext(e.Context)
|
||||
|
||||
if ok, _ := fsys.Exists(name); !ok {
|
||||
return fmt.Errorf("missing or invalid backup file %q to restore", name)
|
||||
}
|
||||
|
||||
extractedDataDir := filepath.Join(localTempDir, "pb_restore_"+security.PseudorandomString(8))
|
||||
defer os.RemoveAll(extractedDataDir)
|
||||
|
||||
// extract the zip
|
||||
if e.App.Settings().Backups.S3.Enabled {
|
||||
br, err := fsys.GetReader(name)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer br.Close()
|
||||
|
||||
// create a temp zip file from the blob.Reader and try to extract it
|
||||
tempZip, err := os.CreateTemp(localTempDir, "pb_restore_zip")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer os.Remove(tempZip.Name())
|
||||
defer tempZip.Close() // note: this technically shouldn't be necessary but it is here to workaround platforms discrepancies
|
||||
|
||||
_, err = io.Copy(tempZip, br)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = archive.Extract(tempZip.Name(), extractedDataDir)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// remove the temp zip file since we no longer need it
|
||||
// (this is in case the app restarts and the defer calls are not called)
|
||||
_ = tempZip.Close()
|
||||
err = os.Remove(tempZip.Name())
|
||||
if err != nil {
|
||||
e.App.Logger().Warn(
|
||||
"[RestoreBackup] Failed to remove the temp zip backup file",
|
||||
slog.String("file", tempZip.Name()),
|
||||
slog.String("error", err.Error()),
|
||||
)
|
||||
}
|
||||
} else {
|
||||
// manually construct the local path to avoid creating a copy of the zip file
|
||||
// since the blob reader currently doesn't implement ReaderAt
|
||||
zipPath := filepath.Join(app.DataDir(), LocalBackupsDirName, filepath.Base(name))
|
||||
|
||||
err = archive.Extract(zipPath, extractedDataDir)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// ensure that at least a database file exists
|
||||
extractedDB := filepath.Join(extractedDataDir, "data.db")
|
||||
if _, err := os.Stat(extractedDB); err != nil {
|
||||
return fmt.Errorf("data.db file is missing or invalid: %w", err)
|
||||
}
|
||||
|
||||
// move the current pb_data content to a special temp location
|
||||
// that will hold the old data between dirs replace
|
||||
// (the temp dir will be automatically removed on the next app start)
|
||||
oldTempDataDir := filepath.Join(localTempDir, "old_pb_data_"+security.PseudorandomString(8))
|
||||
if err := osutils.MoveDirContent(e.App.DataDir(), oldTempDataDir, e.Exclude...); err != nil {
|
||||
return fmt.Errorf("failed to move the current pb_data content to a temp location: %w", err)
|
||||
}
|
||||
|
||||
// move the extracted archive content to the app's pb_data
|
||||
if err := osutils.MoveDirContent(extractedDataDir, e.App.DataDir(), e.Exclude...); err != nil {
|
||||
return fmt.Errorf("failed to move the extracted archive content to pb_data: %w", err)
|
||||
}
|
||||
|
||||
revertDataDirChanges := func() error {
|
||||
if err := osutils.MoveDirContent(e.App.DataDir(), extractedDataDir, e.Exclude...); err != nil {
|
||||
return fmt.Errorf("failed to revert the extracted dir change: %w", err)
|
||||
}
|
||||
|
||||
if err := osutils.MoveDirContent(oldTempDataDir, e.App.DataDir(), e.Exclude...); err != nil {
|
||||
return fmt.Errorf("failed to revert old pb_data dir change: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// restart the app
|
||||
if err := e.App.Restart(); err != nil {
|
||||
if revertErr := revertDataDirChanges(); revertErr != nil {
|
||||
panic(revertErr)
|
||||
}
|
||||
|
||||
return fmt.Errorf("failed to restart the app process: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// registerAutobackupHooks registers the autobackup app serve hooks.
|
||||
func (app *BaseApp) registerAutobackupHooks() {
|
||||
const jobId = "__pbAutoBackup__"
|
||||
|
||||
loadJob := func() {
|
||||
rawSchedule := app.Settings().Backups.Cron
|
||||
if rawSchedule == "" {
|
||||
app.Cron().Remove(jobId)
|
||||
return
|
||||
}
|
||||
|
||||
app.Cron().Add(jobId, rawSchedule, func() {
|
||||
const autoPrefix = "@auto_pb_backup_"
|
||||
|
||||
name := generateBackupName(app, autoPrefix)
|
||||
|
||||
if err := app.CreateBackup(context.Background(), name); err != nil {
|
||||
app.Logger().Error(
|
||||
"[Backup cron] Failed to create backup",
|
||||
slog.String("name", name),
|
||||
slog.String("error", err.Error()),
|
||||
)
|
||||
}
|
||||
|
||||
maxKeep := app.Settings().Backups.CronMaxKeep
|
||||
|
||||
if maxKeep == 0 {
|
||||
return // no explicit limit
|
||||
}
|
||||
|
||||
fsys, err := app.NewBackupsFilesystem()
|
||||
if err != nil {
|
||||
app.Logger().Error(
|
||||
"[Backup cron] Failed to initialize the backup filesystem",
|
||||
slog.String("error", err.Error()),
|
||||
)
|
||||
return
|
||||
}
|
||||
defer fsys.Close()
|
||||
|
||||
files, err := fsys.List(autoPrefix)
|
||||
if err != nil {
|
||||
app.Logger().Error(
|
||||
"[Backup cron] Failed to list autogenerated backups",
|
||||
slog.String("error", err.Error()),
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
if maxKeep >= len(files) {
|
||||
return // nothing to remove
|
||||
}
|
||||
|
||||
// sort desc
|
||||
sort.Slice(files, func(i, j int) bool {
|
||||
return files[i].ModTime.After(files[j].ModTime)
|
||||
})
|
||||
|
||||
// keep only the most recent n auto backup files
|
||||
toRemove := files[maxKeep:]
|
||||
|
||||
for _, f := range toRemove {
|
||||
if err := fsys.Delete(f.Key); err != nil {
|
||||
app.Logger().Error(
|
||||
"[Backup cron] Failed to remove old autogenerated backup",
|
||||
slog.String("key", f.Key),
|
||||
slog.String("error", err.Error()),
|
||||
)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
app.OnBootstrap().BindFunc(func(e *BootstrapEvent) error {
|
||||
if err := e.Next(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
loadJob()
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
app.OnSettingsReload().BindFunc(func(e *SettingsReloadEvent) error {
|
||||
if err := e.Next(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
loadJob()
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func generateBackupName(app App, prefix string) string {
|
||||
appName := inflector.Snakecase(app.Settings().Meta.AppName)
|
||||
if len(appName) > 50 {
|
||||
appName = appName[:50]
|
||||
}
|
||||
|
||||
return fmt.Sprintf(
|
||||
"%s%s_%s.zip",
|
||||
prefix,
|
||||
appName,
|
||||
time.Now().UTC().Format("20060102150405"),
|
||||
)
|
||||
}
|
164
core/base_backup_test.go
Normal file
164
core/base_backup_test.go
Normal file
|
@ -0,0 +1,164 @@
|
|||
package core_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tests"
|
||||
"github.com/pocketbase/pocketbase/tools/archive"
|
||||
"github.com/pocketbase/pocketbase/tools/list"
|
||||
)
|
||||
|
||||
func TestCreateBackup(t *testing.T) {
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
// set some long app name with spaces and special characters
|
||||
app.Settings().Meta.AppName = "test @! " + strings.Repeat("a", 100)
|
||||
|
||||
expectedAppNamePrefix := "test_" + strings.Repeat("a", 45)
|
||||
|
||||
// test pending error
|
||||
app.Store().Set(core.StoreKeyActiveBackup, "")
|
||||
if err := app.CreateBackup(context.Background(), "test.zip"); err == nil {
|
||||
t.Fatal("Expected pending error, got nil")
|
||||
}
|
||||
app.Store().Remove(core.StoreKeyActiveBackup)
|
||||
|
||||
// create with auto generated name
|
||||
if err := app.CreateBackup(context.Background(), ""); err != nil {
|
||||
t.Fatal("Failed to create a backup with autogenerated name")
|
||||
}
|
||||
|
||||
// create with custom name
|
||||
if err := app.CreateBackup(context.Background(), "custom"); err != nil {
|
||||
t.Fatal("Failed to create a backup with custom name")
|
||||
}
|
||||
|
||||
// create new with the same name (aka. replace)
|
||||
if err := app.CreateBackup(context.Background(), "custom"); err != nil {
|
||||
t.Fatal("Failed to create and replace a backup with the same name")
|
||||
}
|
||||
|
||||
backupsDir := filepath.Join(app.DataDir(), core.LocalBackupsDirName)
|
||||
|
||||
entries, err := os.ReadDir(backupsDir)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
expectedFiles := []string{
|
||||
`^pb_backup_` + expectedAppNamePrefix + `_\w+\.zip$`,
|
||||
`^pb_backup_` + expectedAppNamePrefix + `_\w+\.zip.attrs$`,
|
||||
"custom",
|
||||
"custom.attrs",
|
||||
}
|
||||
|
||||
if len(entries) != len(expectedFiles) {
|
||||
names := getEntryNames(entries)
|
||||
t.Fatalf("Expected %d backup files, got %d: \n%v", len(expectedFiles), len(entries), names)
|
||||
}
|
||||
|
||||
for i, entry := range entries {
|
||||
if !list.ExistInSliceWithRegex(entry.Name(), expectedFiles) {
|
||||
t.Fatalf("[%d] Missing backup file %q", i, entry.Name())
|
||||
}
|
||||
|
||||
if strings.HasSuffix(entry.Name(), ".attrs") {
|
||||
continue
|
||||
}
|
||||
|
||||
path := filepath.Join(backupsDir, entry.Name())
|
||||
|
||||
if err := verifyBackupContent(app, path); err != nil {
|
||||
t.Fatalf("[%d] Failed to verify backup content: %v", i, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRestoreBackup(t *testing.T) {
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
// create a initial test backup to ensure that there are at least 1
|
||||
// backup file and that the generated zip doesn't contain the backups dir
|
||||
if err := app.CreateBackup(context.Background(), "initial"); err != nil {
|
||||
t.Fatal("Failed to create test initial backup")
|
||||
}
|
||||
|
||||
// create test backup
|
||||
if err := app.CreateBackup(context.Background(), "test"); err != nil {
|
||||
t.Fatal("Failed to create test backup")
|
||||
}
|
||||
|
||||
// test pending error
|
||||
app.Store().Set(core.StoreKeyActiveBackup, "")
|
||||
if err := app.RestoreBackup(context.Background(), "test"); err == nil {
|
||||
t.Fatal("Expected pending error, got nil")
|
||||
}
|
||||
app.Store().Remove(core.StoreKeyActiveBackup)
|
||||
|
||||
// missing backup
|
||||
if err := app.RestoreBackup(context.Background(), "missing"); err == nil {
|
||||
t.Fatal("Expected missing error, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
|
||||
func verifyBackupContent(app core.App, path string) error {
|
||||
dir, err := os.MkdirTemp("", "backup_test")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer os.RemoveAll(dir)
|
||||
|
||||
if err := archive.Extract(path, dir); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
expectedRootEntries := []string{
|
||||
"storage",
|
||||
"data.db",
|
||||
"data.db-shm",
|
||||
"data.db-wal",
|
||||
"auxiliary.db",
|
||||
"auxiliary.db-shm",
|
||||
"auxiliary.db-wal",
|
||||
".gitignore",
|
||||
}
|
||||
|
||||
entries, err := os.ReadDir(dir)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(entries) != len(expectedRootEntries) {
|
||||
names := getEntryNames(entries)
|
||||
return fmt.Errorf("Expected %d backup files, got %d: \n%v", len(expectedRootEntries), len(entries), names)
|
||||
}
|
||||
|
||||
for _, entry := range entries {
|
||||
if !list.ExistInSliceWithRegex(entry.Name(), expectedRootEntries) {
|
||||
return fmt.Errorf("Didn't expect %q entry", entry.Name())
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func getEntryNames(entries []fs.DirEntry) []string {
|
||||
names := make([]string, len(entries))
|
||||
|
||||
for i, entry := range entries {
|
||||
names[i] = entry.Name()
|
||||
}
|
||||
|
||||
return names
|
||||
}
|
554
core/base_test.go
Normal file
554
core/base_test.go
Normal file
|
@ -0,0 +1,554 @@
|
|||
package core_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"log/slog"
|
||||
"os"
|
||||
"slices"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
_ "unsafe"
|
||||
|
||||
"github.com/pocketbase/dbx"
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tests"
|
||||
"github.com/pocketbase/pocketbase/tools/logger"
|
||||
"github.com/pocketbase/pocketbase/tools/mailer"
|
||||
)
|
||||
|
||||
func TestNewBaseApp(t *testing.T) {
|
||||
const testDataDir = "./pb_base_app_test_data_dir/"
|
||||
defer os.RemoveAll(testDataDir)
|
||||
|
||||
app := core.NewBaseApp(core.BaseAppConfig{
|
||||
DataDir: testDataDir,
|
||||
EncryptionEnv: "test_env",
|
||||
IsDev: true,
|
||||
})
|
||||
|
||||
if app.DataDir() != testDataDir {
|
||||
t.Fatalf("expected DataDir %q, got %q", testDataDir, app.DataDir())
|
||||
}
|
||||
|
||||
if app.EncryptionEnv() != "test_env" {
|
||||
t.Fatalf("expected EncryptionEnv test_env, got %q", app.EncryptionEnv())
|
||||
}
|
||||
|
||||
if !app.IsDev() {
|
||||
t.Fatalf("expected IsDev true, got %v", app.IsDev())
|
||||
}
|
||||
|
||||
if app.Store() == nil {
|
||||
t.Fatal("expected Store to be set, got nil")
|
||||
}
|
||||
|
||||
if app.Settings() == nil {
|
||||
t.Fatal("expected Settings to be set, got nil")
|
||||
}
|
||||
|
||||
if app.SubscriptionsBroker() == nil {
|
||||
t.Fatal("expected SubscriptionsBroker to be set, got nil")
|
||||
}
|
||||
|
||||
if app.Cron() == nil {
|
||||
t.Fatal("expected Cron to be set, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBaseAppBootstrap(t *testing.T) {
|
||||
const testDataDir = "./pb_base_app_test_data_dir/"
|
||||
defer os.RemoveAll(testDataDir)
|
||||
|
||||
app := core.NewBaseApp(core.BaseAppConfig{
|
||||
DataDir: testDataDir,
|
||||
})
|
||||
defer app.ResetBootstrapState()
|
||||
|
||||
if app.IsBootstrapped() {
|
||||
t.Fatal("Didn't expect the application to be bootstrapped.")
|
||||
}
|
||||
|
||||
if err := app.Bootstrap(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if !app.IsBootstrapped() {
|
||||
t.Fatal("Expected the application to be bootstrapped.")
|
||||
}
|
||||
|
||||
if stat, err := os.Stat(testDataDir); err != nil || !stat.IsDir() {
|
||||
t.Fatal("Expected test data directory to be created.")
|
||||
}
|
||||
|
||||
type nilCheck struct {
|
||||
name string
|
||||
value any
|
||||
expectNil bool
|
||||
}
|
||||
|
||||
runNilChecks := func(checks []nilCheck) {
|
||||
for _, check := range checks {
|
||||
t.Run(check.name, func(t *testing.T) {
|
||||
isNil := check.value == nil
|
||||
if isNil != check.expectNil {
|
||||
t.Fatalf("Expected isNil %v, got %v", check.expectNil, isNil)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
nilChecksBeforeReset := []nilCheck{
|
||||
{"[before] db", app.DB(), false},
|
||||
{"[before] concurrentDB", app.ConcurrentDB(), false},
|
||||
{"[before] nonconcurrentDB", app.NonconcurrentDB(), false},
|
||||
{"[before] auxDB", app.AuxDB(), false},
|
||||
{"[before] auxConcurrentDB", app.AuxConcurrentDB(), false},
|
||||
{"[before] auxNonconcurrentDB", app.AuxNonconcurrentDB(), false},
|
||||
{"[before] settings", app.Settings(), false},
|
||||
{"[before] logger", app.Logger(), false},
|
||||
{"[before] cached collections", app.Store().Get(core.StoreKeyCachedCollections), false},
|
||||
}
|
||||
|
||||
runNilChecks(nilChecksBeforeReset)
|
||||
|
||||
// reset
|
||||
if err := app.ResetBootstrapState(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
nilChecksAfterReset := []nilCheck{
|
||||
{"[after] db", app.DB(), true},
|
||||
{"[after] concurrentDB", app.ConcurrentDB(), true},
|
||||
{"[after] nonconcurrentDB", app.NonconcurrentDB(), true},
|
||||
{"[after] auxDB", app.AuxDB(), true},
|
||||
{"[after] auxConcurrentDB", app.AuxConcurrentDB(), true},
|
||||
{"[after] auxNonconcurrentDB", app.AuxNonconcurrentDB(), true},
|
||||
{"[after] settings", app.Settings(), false},
|
||||
{"[after] logger", app.Logger(), false},
|
||||
{"[after] cached collections", app.Store().Get(core.StoreKeyCachedCollections), false},
|
||||
}
|
||||
|
||||
runNilChecks(nilChecksAfterReset)
|
||||
}
|
||||
|
||||
func TestNewBaseAppTx(t *testing.T) {
|
||||
const testDataDir = "./pb_base_app_test_data_dir/"
|
||||
defer os.RemoveAll(testDataDir)
|
||||
|
||||
app := core.NewBaseApp(core.BaseAppConfig{
|
||||
DataDir: testDataDir,
|
||||
})
|
||||
defer app.ResetBootstrapState()
|
||||
|
||||
if err := app.Bootstrap(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
mustNotHaveTx := func(app core.App) {
|
||||
if app.IsTransactional() {
|
||||
t.Fatalf("Didn't expect the app to be transactional")
|
||||
}
|
||||
|
||||
if app.TxInfo() != nil {
|
||||
t.Fatalf("Didn't expect the app.txInfo to be loaded")
|
||||
}
|
||||
}
|
||||
|
||||
mustHaveTx := func(app core.App) {
|
||||
if !app.IsTransactional() {
|
||||
t.Fatalf("Expected the app to be transactional")
|
||||
}
|
||||
|
||||
if app.TxInfo() == nil {
|
||||
t.Fatalf("Expected the app.txInfo to be loaded")
|
||||
}
|
||||
}
|
||||
|
||||
mustNotHaveTx(app)
|
||||
|
||||
app.RunInTransaction(func(txApp core.App) error {
|
||||
mustHaveTx(txApp)
|
||||
return nil
|
||||
})
|
||||
|
||||
mustNotHaveTx(app)
|
||||
}
|
||||
|
||||
func TestBaseAppNewMailClient(t *testing.T) {
|
||||
const testDataDir = "./pb_base_app_test_data_dir/"
|
||||
defer os.RemoveAll(testDataDir)
|
||||
|
||||
app := core.NewBaseApp(core.BaseAppConfig{
|
||||
DataDir: testDataDir,
|
||||
EncryptionEnv: "pb_test_env",
|
||||
})
|
||||
defer app.ResetBootstrapState()
|
||||
|
||||
client1 := app.NewMailClient()
|
||||
m1, ok := client1.(*mailer.Sendmail)
|
||||
if !ok {
|
||||
t.Fatalf("Expected mailer.Sendmail instance, got %v", m1)
|
||||
}
|
||||
if m1.OnSend() == nil || m1.OnSend().Length() == 0 {
|
||||
t.Fatal("Expected OnSend hook to be registered")
|
||||
}
|
||||
|
||||
app.Settings().SMTP.Enabled = true
|
||||
|
||||
client2 := app.NewMailClient()
|
||||
m2, ok := client2.(*mailer.SMTPClient)
|
||||
if !ok {
|
||||
t.Fatalf("Expected mailer.SMTPClient instance, got %v", m2)
|
||||
}
|
||||
if m2.OnSend() == nil || m2.OnSend().Length() == 0 {
|
||||
t.Fatal("Expected OnSend hook to be registered")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBaseAppNewFilesystem(t *testing.T) {
|
||||
const testDataDir = "./pb_base_app_test_data_dir/"
|
||||
defer os.RemoveAll(testDataDir)
|
||||
|
||||
app := core.NewBaseApp(core.BaseAppConfig{
|
||||
DataDir: testDataDir,
|
||||
})
|
||||
defer app.ResetBootstrapState()
|
||||
|
||||
// local
|
||||
local, localErr := app.NewFilesystem()
|
||||
if localErr != nil {
|
||||
t.Fatal(localErr)
|
||||
}
|
||||
if local == nil {
|
||||
t.Fatal("Expected local filesystem instance, got nil")
|
||||
}
|
||||
|
||||
// misconfigured s3
|
||||
app.Settings().S3.Enabled = true
|
||||
s3, s3Err := app.NewFilesystem()
|
||||
if s3Err == nil {
|
||||
t.Fatal("Expected S3 error, got nil")
|
||||
}
|
||||
if s3 != nil {
|
||||
t.Fatalf("Expected nil s3 filesystem, got %v", s3)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBaseAppNewBackupsFilesystem(t *testing.T) {
|
||||
const testDataDir = "./pb_base_app_test_data_dir/"
|
||||
defer os.RemoveAll(testDataDir)
|
||||
|
||||
app := core.NewBaseApp(core.BaseAppConfig{
|
||||
DataDir: testDataDir,
|
||||
})
|
||||
defer app.ResetBootstrapState()
|
||||
|
||||
// local
|
||||
local, localErr := app.NewBackupsFilesystem()
|
||||
if localErr != nil {
|
||||
t.Fatal(localErr)
|
||||
}
|
||||
if local == nil {
|
||||
t.Fatal("Expected local backups filesystem instance, got nil")
|
||||
}
|
||||
|
||||
// misconfigured s3
|
||||
app.Settings().Backups.S3.Enabled = true
|
||||
s3, s3Err := app.NewBackupsFilesystem()
|
||||
if s3Err == nil {
|
||||
t.Fatal("Expected S3 error, got nil")
|
||||
}
|
||||
if s3 != nil {
|
||||
t.Fatalf("Expected nil s3 backups filesystem, got %v", s3)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBaseAppLoggerWrites(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
// reset
|
||||
if err := app.DeleteOldLogs(time.Now()); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
const logsThreshold = 200
|
||||
|
||||
totalLogs := func(app core.App, t *testing.T) int {
|
||||
var total int
|
||||
|
||||
err := app.LogQuery().Select("count(*)").Row(&total)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to fetch total logs: %v", err)
|
||||
}
|
||||
|
||||
return total
|
||||
}
|
||||
|
||||
t.Run("disabled logs retention", func(t *testing.T) {
|
||||
app.Settings().Logs.MaxDays = 0
|
||||
|
||||
for i := 0; i < logsThreshold+1; i++ {
|
||||
app.Logger().Error("test")
|
||||
}
|
||||
|
||||
if total := totalLogs(app, t); total != 0 {
|
||||
t.Fatalf("Expected no logs, got %d", total)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("test batch logs writes", func(t *testing.T) {
|
||||
app.Settings().Logs.MaxDays = 1
|
||||
|
||||
for i := 0; i < logsThreshold-1; i++ {
|
||||
app.Logger().Error("test")
|
||||
}
|
||||
|
||||
if total := totalLogs(app, t); total != 0 {
|
||||
t.Fatalf("Expected no logs, got %d", total)
|
||||
}
|
||||
|
||||
// should trigger batch write
|
||||
app.Logger().Error("test")
|
||||
|
||||
// should be added for the next batch write
|
||||
app.Logger().Error("test")
|
||||
|
||||
if total := totalLogs(app, t); total != logsThreshold {
|
||||
t.Fatalf("Expected %d logs, got %d", logsThreshold, total)
|
||||
}
|
||||
|
||||
// wait for ~3 secs to check the timer trigger
|
||||
time.Sleep(3200 * time.Millisecond)
|
||||
if total := totalLogs(app, t); total != logsThreshold+1 {
|
||||
t.Fatalf("Expected %d logs, got %d", logsThreshold+1, total)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestBaseAppRefreshSettingsLoggerMinLevelEnabled(t *testing.T) {
|
||||
scenarios := []struct {
|
||||
name string
|
||||
isDev bool
|
||||
level int
|
||||
// level->enabled map
|
||||
expectations map[int]bool
|
||||
}{
|
||||
{
|
||||
"dev mode",
|
||||
true,
|
||||
4,
|
||||
map[int]bool{
|
||||
3: true,
|
||||
4: true,
|
||||
5: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
"nondev mode",
|
||||
false,
|
||||
4,
|
||||
map[int]bool{
|
||||
3: false,
|
||||
4: true,
|
||||
5: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
t.Run(s.name, func(t *testing.T) {
|
||||
const testDataDir = "./pb_base_app_test_data_dir/"
|
||||
defer os.RemoveAll(testDataDir)
|
||||
|
||||
app := core.NewBaseApp(core.BaseAppConfig{
|
||||
DataDir: testDataDir,
|
||||
IsDev: s.isDev,
|
||||
})
|
||||
defer app.ResetBootstrapState()
|
||||
|
||||
if err := app.Bootstrap(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// silence query logs
|
||||
app.ConcurrentDB().(*dbx.DB).ExecLogFunc = func(ctx context.Context, t time.Duration, sql string, result sql.Result, err error) {}
|
||||
app.ConcurrentDB().(*dbx.DB).QueryLogFunc = func(ctx context.Context, t time.Duration, sql string, rows *sql.Rows, err error) {}
|
||||
app.NonconcurrentDB().(*dbx.DB).ExecLogFunc = func(ctx context.Context, t time.Duration, sql string, result sql.Result, err error) {}
|
||||
app.NonconcurrentDB().(*dbx.DB).QueryLogFunc = func(ctx context.Context, t time.Duration, sql string, rows *sql.Rows, err error) {}
|
||||
|
||||
handler, ok := app.Logger().Handler().(*logger.BatchHandler)
|
||||
if !ok {
|
||||
t.Fatalf("Expected BatchHandler, got %v", app.Logger().Handler())
|
||||
}
|
||||
|
||||
app.Settings().Logs.MinLevel = s.level
|
||||
|
||||
if err := app.Save(app.Settings()); err != nil {
|
||||
t.Fatalf("Failed to save settings: %v", err)
|
||||
}
|
||||
|
||||
for level, enabled := range s.expectations {
|
||||
if v := handler.Enabled(context.Background(), slog.Level(level)); v != enabled {
|
||||
t.Fatalf("Expected level %d Enabled() to be %v, got %v", level, enabled, v)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBaseAppDBDualBuilder(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
concurrentQueries := []string{}
|
||||
nonconcurrentQueries := []string{}
|
||||
app.ConcurrentDB().(*dbx.DB).QueryLogFunc = func(ctx context.Context, t time.Duration, sql string, rows *sql.Rows, err error) {
|
||||
concurrentQueries = append(concurrentQueries, sql)
|
||||
}
|
||||
app.ConcurrentDB().(*dbx.DB).ExecLogFunc = func(ctx context.Context, t time.Duration, sql string, result sql.Result, err error) {
|
||||
concurrentQueries = append(concurrentQueries, sql)
|
||||
}
|
||||
app.NonconcurrentDB().(*dbx.DB).QueryLogFunc = func(ctx context.Context, t time.Duration, sql string, rows *sql.Rows, err error) {
|
||||
nonconcurrentQueries = append(nonconcurrentQueries, sql)
|
||||
}
|
||||
app.NonconcurrentDB().(*dbx.DB).ExecLogFunc = func(ctx context.Context, t time.Duration, sql string, result sql.Result, err error) {
|
||||
nonconcurrentQueries = append(nonconcurrentQueries, sql)
|
||||
}
|
||||
|
||||
type testQuery struct {
|
||||
query string
|
||||
isConcurrent bool
|
||||
}
|
||||
|
||||
regularTests := []testQuery{
|
||||
{" \n sEleCt 1", true},
|
||||
{"With abc(x) AS (select 2) SELECT x FROM abc", true},
|
||||
{"create table t1(x int)", false},
|
||||
{"insert into t1(x) values(1)", false},
|
||||
{"update t1 set x = 2", false},
|
||||
{"delete from t1", false},
|
||||
}
|
||||
|
||||
txTests := []testQuery{
|
||||
{"select 3", false},
|
||||
{" \n WITH abc(x) AS (select 4) SELECT x FROM abc", false},
|
||||
{"create table t2(x int)", false},
|
||||
{"insert into t2(x) values(1)", false},
|
||||
{"update t2 set x = 2", false},
|
||||
{"delete from t2", false},
|
||||
}
|
||||
|
||||
for _, item := range regularTests {
|
||||
_, err := app.DB().NewQuery(item.query).Execute()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to execute query %q error: %v", item.query, err)
|
||||
}
|
||||
}
|
||||
|
||||
app.RunInTransaction(func(txApp core.App) error {
|
||||
for _, item := range txTests {
|
||||
_, err := txApp.DB().NewQuery(item.query).Execute()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to execute query %q error: %v", item.query, err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
allTests := append(regularTests, txTests...)
|
||||
for _, item := range allTests {
|
||||
if item.isConcurrent {
|
||||
if !slices.Contains(concurrentQueries, item.query) {
|
||||
t.Fatalf("Expected concurrent query\n%q\ngot\nconcurrent:%v\nnonconcurrent:%v", item.query, concurrentQueries, nonconcurrentQueries)
|
||||
}
|
||||
} else {
|
||||
if !slices.Contains(nonconcurrentQueries, item.query) {
|
||||
t.Fatalf("Expected nonconcurrent query\n%q\ngot\nconcurrent:%v\nnonconcurrent:%v", item.query, concurrentQueries, nonconcurrentQueries)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestBaseAppAuxDBDualBuilder(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
concurrentQueries := []string{}
|
||||
nonconcurrentQueries := []string{}
|
||||
app.AuxConcurrentDB().(*dbx.DB).QueryLogFunc = func(ctx context.Context, t time.Duration, sql string, rows *sql.Rows, err error) {
|
||||
concurrentQueries = append(concurrentQueries, sql)
|
||||
}
|
||||
app.AuxConcurrentDB().(*dbx.DB).ExecLogFunc = func(ctx context.Context, t time.Duration, sql string, result sql.Result, err error) {
|
||||
concurrentQueries = append(concurrentQueries, sql)
|
||||
}
|
||||
app.AuxNonconcurrentDB().(*dbx.DB).QueryLogFunc = func(ctx context.Context, t time.Duration, sql string, rows *sql.Rows, err error) {
|
||||
nonconcurrentQueries = append(nonconcurrentQueries, sql)
|
||||
}
|
||||
app.AuxNonconcurrentDB().(*dbx.DB).ExecLogFunc = func(ctx context.Context, t time.Duration, sql string, result sql.Result, err error) {
|
||||
nonconcurrentQueries = append(nonconcurrentQueries, sql)
|
||||
}
|
||||
|
||||
type testQuery struct {
|
||||
query string
|
||||
isConcurrent bool
|
||||
}
|
||||
|
||||
regularTests := []testQuery{
|
||||
{" \n sEleCt 1", true},
|
||||
{"With abc(x) AS (select 2) SELECT x FROM abc", true},
|
||||
{"create table t1(x int)", false},
|
||||
{"insert into t1(x) values(1)", false},
|
||||
{"update t1 set x = 2", false},
|
||||
{"delete from t1", false},
|
||||
}
|
||||
|
||||
txTests := []testQuery{
|
||||
{"select 3", false},
|
||||
{" \n WITH abc(x) AS (select 4) SELECT x FROM abc", false},
|
||||
{"create table t2(x int)", false},
|
||||
{"insert into t2(x) values(1)", false},
|
||||
{"update t2 set x = 2", false},
|
||||
{"delete from t2", false},
|
||||
}
|
||||
|
||||
for _, item := range regularTests {
|
||||
_, err := app.AuxDB().NewQuery(item.query).Execute()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to execute query %q error: %v", item.query, err)
|
||||
}
|
||||
}
|
||||
|
||||
app.AuxRunInTransaction(func(txApp core.App) error {
|
||||
for _, item := range txTests {
|
||||
_, err := txApp.AuxDB().NewQuery(item.query).Execute()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to execute query %q error: %v", item.query, err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
allTests := append(regularTests, txTests...)
|
||||
for _, item := range allTests {
|
||||
if item.isConcurrent {
|
||||
if !slices.Contains(concurrentQueries, item.query) {
|
||||
t.Fatalf("Expected concurrent query\n%q\ngot\nconcurrent:%v\nnonconcurrent:%v", item.query, concurrentQueries, nonconcurrentQueries)
|
||||
}
|
||||
} else {
|
||||
if !slices.Contains(nonconcurrentQueries, item.query) {
|
||||
t.Fatalf("Expected nonconcurrent query\n%q\ngot\nconcurrent:%v\nnonconcurrent:%v", item.query, concurrentQueries, nonconcurrentQueries)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
200
core/collection_import.go
Normal file
200
core/collection_import.go
Normal file
|
@ -0,0 +1,200 @@
|
|||
package core
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"slices"
|
||||
|
||||
validation "github.com/go-ozzo/ozzo-validation/v4"
|
||||
"github.com/spf13/cast"
|
||||
)
|
||||
|
||||
// ImportCollectionsByMarshaledJSON is the same as [ImportCollections]
|
||||
// but accept marshaled json array as import data (usually used for the autogenerated snapshots).
|
||||
func (app *BaseApp) ImportCollectionsByMarshaledJSON(rawSliceOfMaps []byte, deleteMissing bool) error {
|
||||
data := []map[string]any{}
|
||||
|
||||
err := json.Unmarshal(rawSliceOfMaps, &data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return app.ImportCollections(data, deleteMissing)
|
||||
}
|
||||
|
||||
// ImportCollections imports the provided collections data in a single transaction.
|
||||
//
|
||||
// For existing matching collections, the imported data is unmarshaled on top of the existing model.
|
||||
//
|
||||
// NB! If deleteMissing is true, ALL NON-SYSTEM COLLECTIONS AND SCHEMA FIELDS,
|
||||
// that are not present in the imported configuration, WILL BE DELETED
|
||||
// (this includes their related records data).
|
||||
func (app *BaseApp) ImportCollections(toImport []map[string]any, deleteMissing bool) error {
|
||||
if len(toImport) == 0 {
|
||||
// prevent accidentally deleting all collections
|
||||
return errors.New("no collections to import")
|
||||
}
|
||||
|
||||
importedCollections := make([]*Collection, len(toImport))
|
||||
mappedImported := make(map[string]*Collection, len(toImport))
|
||||
|
||||
// normalize imported collections data to ensure that all
|
||||
// collection fields are present and properly initialized
|
||||
for i, data := range toImport {
|
||||
var imported *Collection
|
||||
|
||||
identifier := cast.ToString(data["id"])
|
||||
if identifier == "" {
|
||||
identifier = cast.ToString(data["name"])
|
||||
}
|
||||
|
||||
existing, err := app.FindCollectionByNameOrId(identifier)
|
||||
if err != nil && !errors.Is(err, sql.ErrNoRows) {
|
||||
return err
|
||||
}
|
||||
|
||||
if existing != nil {
|
||||
// refetch for deep copy
|
||||
imported, err = app.FindCollectionByNameOrId(existing.Id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// ensure that the fields will be cleared
|
||||
if data["fields"] == nil && deleteMissing {
|
||||
data["fields"] = []map[string]any{}
|
||||
}
|
||||
|
||||
rawData, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// load the imported data
|
||||
err = json.Unmarshal(rawData, imported)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// extend with the existing fields if necessary
|
||||
for _, f := range existing.Fields {
|
||||
if !f.GetSystem() && deleteMissing {
|
||||
continue
|
||||
}
|
||||
if imported.Fields.GetById(f.GetId()) == nil {
|
||||
// replace with the existing id to prevent accidental column deletion
|
||||
// since otherwise the imported field will be treated as a new one
|
||||
found := imported.Fields.GetByName(f.GetName())
|
||||
if found != nil && found.Type() == f.Type() {
|
||||
found.SetId(f.GetId())
|
||||
}
|
||||
imported.Fields.Add(f)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
imported = &Collection{}
|
||||
|
||||
rawData, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// load the imported data
|
||||
err = json.Unmarshal(rawData, imported)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
imported.IntegrityChecks(false)
|
||||
|
||||
importedCollections[i] = imported
|
||||
mappedImported[imported.Id] = imported
|
||||
}
|
||||
|
||||
// reorder views last since the view query could depend on some of the other collections
|
||||
slices.SortStableFunc(importedCollections, func(a, b *Collection) int {
|
||||
cmpA := -1
|
||||
if a.IsView() {
|
||||
cmpA = 1
|
||||
}
|
||||
|
||||
cmpB := -1
|
||||
if b.IsView() {
|
||||
cmpB = 1
|
||||
}
|
||||
|
||||
res := cmp.Compare(cmpA, cmpB)
|
||||
if res == 0 {
|
||||
res = a.Created.Compare(b.Created)
|
||||
if res == 0 {
|
||||
res = a.Updated.Compare(b.Updated)
|
||||
}
|
||||
}
|
||||
return res
|
||||
})
|
||||
|
||||
return app.RunInTransaction(func(txApp App) error {
|
||||
existingCollections := []*Collection{}
|
||||
if err := txApp.CollectionQuery().OrderBy("updated ASC").All(&existingCollections); err != nil {
|
||||
return err
|
||||
}
|
||||
mappedExisting := make(map[string]*Collection, len(existingCollections))
|
||||
for _, existing := range existingCollections {
|
||||
existing.IntegrityChecks(false)
|
||||
mappedExisting[existing.Id] = existing
|
||||
}
|
||||
|
||||
// delete old collections not available in the new configuration
|
||||
// (before saving the imports in case a deleted collection name is being reused)
|
||||
if deleteMissing {
|
||||
for _, existing := range existingCollections {
|
||||
if mappedImported[existing.Id] != nil || existing.System {
|
||||
continue // exist or system
|
||||
}
|
||||
|
||||
// delete collection
|
||||
if err := txApp.Delete(existing); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// upsert imported collections
|
||||
for _, imported := range importedCollections {
|
||||
if err := txApp.SaveNoValidate(imported); err != nil {
|
||||
return fmt.Errorf("failed to save collection %q: %w", imported.Name, err)
|
||||
}
|
||||
}
|
||||
|
||||
// run validations
|
||||
for _, imported := range importedCollections {
|
||||
original := mappedExisting[imported.Id]
|
||||
if original == nil {
|
||||
original = imported
|
||||
}
|
||||
|
||||
validator := newCollectionValidator(
|
||||
context.Background(),
|
||||
txApp,
|
||||
imported,
|
||||
original,
|
||||
)
|
||||
if err := validator.run(); err != nil {
|
||||
// serialize the validation error(s)
|
||||
serializedErr, _ := json.MarshalIndent(err, "", " ")
|
||||
|
||||
return validation.Errors{"collections": validation.NewError(
|
||||
"validation_collections_import_failure",
|
||||
fmt.Sprintf("Data validations failed for collection %q (%s):\n%s", imported.Name, imported.Id, serializedErr),
|
||||
)}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
476
core/collection_import_test.go
Normal file
476
core/collection_import_test.go
Normal file
|
@ -0,0 +1,476 @@
|
|||
package core_test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/pocketbase/dbx"
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tests"
|
||||
)
|
||||
|
||||
func TestImportCollections(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testApp, _ := tests.NewTestApp()
|
||||
defer testApp.Cleanup()
|
||||
|
||||
var regularCollections []*core.Collection
|
||||
err := testApp.CollectionQuery().AndWhere(dbx.HashExp{"system": false}).All(®ularCollections)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
var systemCollections []*core.Collection
|
||||
err = testApp.CollectionQuery().AndWhere(dbx.HashExp{"system": true}).All(&systemCollections)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
totalRegularCollections := len(regularCollections)
|
||||
totalSystemCollections := len(systemCollections)
|
||||
totalCollections := totalRegularCollections + totalSystemCollections
|
||||
|
||||
scenarios := []struct {
|
||||
name string
|
||||
data []map[string]any
|
||||
deleteMissing bool
|
||||
expectError bool
|
||||
expectCollectionsCount int
|
||||
afterTestFunc func(testApp *tests.TestApp, resultCollections []*core.Collection)
|
||||
}{
|
||||
{
|
||||
name: "empty collections",
|
||||
data: []map[string]any{},
|
||||
expectError: true,
|
||||
expectCollectionsCount: totalCollections,
|
||||
},
|
||||
{
|
||||
name: "minimal collection import (with missing system fields)",
|
||||
data: []map[string]any{
|
||||
{"name": "import_test1", "type": "auth"},
|
||||
{
|
||||
"name": "import_test2", "fields": []map[string]any{
|
||||
{"name": "test", "type": "text"},
|
||||
},
|
||||
},
|
||||
},
|
||||
deleteMissing: false,
|
||||
expectError: false,
|
||||
expectCollectionsCount: totalCollections + 2,
|
||||
},
|
||||
{
|
||||
name: "minimal collection import (trigger collection model validations)",
|
||||
data: []map[string]any{
|
||||
{"name": ""},
|
||||
{
|
||||
"name": "import_test2", "fields": []map[string]any{
|
||||
{"name": "test", "type": "text"},
|
||||
},
|
||||
},
|
||||
},
|
||||
deleteMissing: false,
|
||||
expectError: true,
|
||||
expectCollectionsCount: totalCollections,
|
||||
},
|
||||
{
|
||||
name: "minimal collection import (trigger field settings validation)",
|
||||
data: []map[string]any{
|
||||
{"name": "import_test", "fields": []map[string]any{{"name": "test", "type": "text", "min": -1}}},
|
||||
},
|
||||
deleteMissing: false,
|
||||
expectError: true,
|
||||
expectCollectionsCount: totalCollections,
|
||||
},
|
||||
{
|
||||
name: "new + update + delete (system collections delete should be ignored)",
|
||||
data: []map[string]any{
|
||||
{
|
||||
"id": "wsmn24bux7wo113",
|
||||
"name": "demo",
|
||||
"fields": []map[string]any{
|
||||
{
|
||||
"id": "_2hlxbmp",
|
||||
"name": "title",
|
||||
"type": "text",
|
||||
"system": false,
|
||||
"required": true,
|
||||
"min": 3,
|
||||
"max": nil,
|
||||
"pattern": "",
|
||||
},
|
||||
},
|
||||
"indexes": []string{},
|
||||
},
|
||||
{
|
||||
"name": "import1",
|
||||
"fields": []map[string]any{
|
||||
{
|
||||
"name": "active",
|
||||
"type": "bool",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
deleteMissing: true,
|
||||
expectError: false,
|
||||
expectCollectionsCount: totalSystemCollections + 2,
|
||||
},
|
||||
{
|
||||
name: "test with deleteMissing: false",
|
||||
data: []map[string]any{
|
||||
{
|
||||
// "id": "wsmn24bux7wo113", // test update with only name as identifier
|
||||
"name": "demo1",
|
||||
"fields": []map[string]any{
|
||||
{
|
||||
"id": "_2hlxbmp",
|
||||
"name": "title",
|
||||
"type": "text",
|
||||
"system": false,
|
||||
"required": true,
|
||||
"min": 3,
|
||||
"max": nil,
|
||||
"pattern": "",
|
||||
},
|
||||
{
|
||||
"id": "_2hlxbmp",
|
||||
"name": "field_with_duplicate_id",
|
||||
"type": "text",
|
||||
"system": false,
|
||||
"required": true,
|
||||
"unique": false,
|
||||
"min": 4,
|
||||
"max": nil,
|
||||
"pattern": "",
|
||||
},
|
||||
{
|
||||
"id": "abcd_import",
|
||||
"name": "new_field",
|
||||
"type": "text",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "new_import",
|
||||
"fields": []map[string]any{
|
||||
{
|
||||
"id": "abcd_import",
|
||||
"name": "active",
|
||||
"type": "bool",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
deleteMissing: false,
|
||||
expectError: false,
|
||||
expectCollectionsCount: totalCollections + 1,
|
||||
afterTestFunc: func(testApp *tests.TestApp, resultCollections []*core.Collection) {
|
||||
expectedCollectionFields := map[string]int{
|
||||
core.CollectionNameAuthOrigins: 6,
|
||||
"nologin": 10,
|
||||
"demo1": 19,
|
||||
"demo2": 5,
|
||||
"demo3": 5,
|
||||
"demo4": 16,
|
||||
"demo5": 9,
|
||||
"new_import": 2,
|
||||
}
|
||||
for name, expectedCount := range expectedCollectionFields {
|
||||
collection, err := testApp.FindCollectionByNameOrId(name)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if totalFields := len(collection.Fields); totalFields != expectedCount {
|
||||
t.Errorf("Expected %d %q fields, got %d", expectedCount, collection.Name, totalFields)
|
||||
}
|
||||
}
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
t.Run(s.name, func(t *testing.T) {
|
||||
testApp, _ := tests.NewTestApp()
|
||||
defer testApp.Cleanup()
|
||||
|
||||
err := testApp.ImportCollections(s.data, s.deleteMissing)
|
||||
|
||||
hasErr := err != nil
|
||||
if hasErr != s.expectError {
|
||||
t.Fatalf("Expected hasErr to be %v, got %v (%v)", s.expectError, hasErr, err)
|
||||
}
|
||||
|
||||
// check collections count
|
||||
collections := []*core.Collection{}
|
||||
if err := testApp.CollectionQuery().All(&collections); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(collections) != s.expectCollectionsCount {
|
||||
t.Fatalf("Expected %d collections, got %d", s.expectCollectionsCount, len(collections))
|
||||
}
|
||||
|
||||
if s.afterTestFunc != nil {
|
||||
s.afterTestFunc(testApp, collections)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestImportCollectionsByMarshaledJSON(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testApp, _ := tests.NewTestApp()
|
||||
defer testApp.Cleanup()
|
||||
|
||||
var regularCollections []*core.Collection
|
||||
err := testApp.CollectionQuery().AndWhere(dbx.HashExp{"system": false}).All(®ularCollections)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
var systemCollections []*core.Collection
|
||||
err = testApp.CollectionQuery().AndWhere(dbx.HashExp{"system": true}).All(&systemCollections)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
totalRegularCollections := len(regularCollections)
|
||||
totalSystemCollections := len(systemCollections)
|
||||
totalCollections := totalRegularCollections + totalSystemCollections
|
||||
|
||||
scenarios := []struct {
|
||||
name string
|
||||
data string
|
||||
deleteMissing bool
|
||||
expectError bool
|
||||
expectCollectionsCount int
|
||||
afterTestFunc func(testApp *tests.TestApp, resultCollections []*core.Collection)
|
||||
}{
|
||||
{
|
||||
name: "invalid json array",
|
||||
data: `{"test":123}`,
|
||||
expectError: true,
|
||||
expectCollectionsCount: totalCollections,
|
||||
},
|
||||
{
|
||||
name: "new + update + delete (system collections delete should be ignored)",
|
||||
data: `[
|
||||
{
|
||||
"id": "wsmn24bux7wo113",
|
||||
"name": "demo",
|
||||
"fields": [
|
||||
{
|
||||
"id": "_2hlxbmp",
|
||||
"name": "title",
|
||||
"type": "text",
|
||||
"system": false,
|
||||
"required": true,
|
||||
"min": 3,
|
||||
"max": null,
|
||||
"pattern": ""
|
||||
}
|
||||
],
|
||||
"indexes": []
|
||||
},
|
||||
{
|
||||
"name": "import1",
|
||||
"fields": [
|
||||
{
|
||||
"name": "active",
|
||||
"type": "bool"
|
||||
}
|
||||
]
|
||||
}
|
||||
]`,
|
||||
deleteMissing: true,
|
||||
expectError: false,
|
||||
expectCollectionsCount: totalSystemCollections + 2,
|
||||
},
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
t.Run(s.name, func(t *testing.T) {
|
||||
testApp, _ := tests.NewTestApp()
|
||||
defer testApp.Cleanup()
|
||||
|
||||
err := testApp.ImportCollectionsByMarshaledJSON([]byte(s.data), s.deleteMissing)
|
||||
|
||||
hasErr := err != nil
|
||||
if hasErr != s.expectError {
|
||||
t.Fatalf("Expected hasErr to be %v, got %v (%v)", s.expectError, hasErr, err)
|
||||
}
|
||||
|
||||
// check collections count
|
||||
collections := []*core.Collection{}
|
||||
if err := testApp.CollectionQuery().All(&collections); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(collections) != s.expectCollectionsCount {
|
||||
t.Fatalf("Expected %d collections, got %d", s.expectCollectionsCount, len(collections))
|
||||
}
|
||||
|
||||
if s.afterTestFunc != nil {
|
||||
s.afterTestFunc(testApp, collections)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestImportCollectionsUpdateRules(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
scenarios := []struct {
|
||||
name string
|
||||
data map[string]any
|
||||
deleteMissing bool
|
||||
}{
|
||||
{
|
||||
"extend existing by name (without deleteMissing)",
|
||||
map[string]any{"name": "clients", "authToken": map[string]any{"duration": 100}, "fields": []map[string]any{{"name": "test", "type": "text"}}},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"extend existing by id (without deleteMissing)",
|
||||
map[string]any{"id": "v851q4r790rhknl", "authToken": map[string]any{"duration": 100}, "fields": []map[string]any{{"name": "test", "type": "text"}}},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"extend with delete missing",
|
||||
map[string]any{
|
||||
"id": "v851q4r790rhknl",
|
||||
"authToken": map[string]any{"duration": 100},
|
||||
"fields": []map[string]any{{"name": "test", "type": "text"}},
|
||||
"passwordAuth": map[string]any{"identityFields": []string{"email"}},
|
||||
"indexes": []string{
|
||||
// min required system fields indexes
|
||||
"CREATE UNIQUE INDEX `_v851q4r790rhknl_email_idx` ON `clients` (email) WHERE email != ''",
|
||||
"CREATE UNIQUE INDEX `_v851q4r790rhknl_tokenKey_idx` ON `clients` (tokenKey)",
|
||||
},
|
||||
},
|
||||
true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
t.Run(s.name, func(t *testing.T) {
|
||||
testApp, _ := tests.NewTestApp()
|
||||
defer testApp.Cleanup()
|
||||
|
||||
beforeCollection, err := testApp.FindCollectionByNameOrId("clients")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = testApp.ImportCollections([]map[string]any{s.data}, s.deleteMissing)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
afterCollection, err := testApp.FindCollectionByNameOrId("clients")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if afterCollection.AuthToken.Duration != 100 {
|
||||
t.Fatalf("Expected AuthToken duration to be %d, got %d", 100, afterCollection.AuthToken.Duration)
|
||||
}
|
||||
if beforeCollection.AuthToken.Secret != afterCollection.AuthToken.Secret {
|
||||
t.Fatalf("Expected AuthToken secrets to remain the same, got\n%q\nVS\n%q", beforeCollection.AuthToken.Secret, afterCollection.AuthToken.Secret)
|
||||
}
|
||||
if beforeCollection.Name != afterCollection.Name {
|
||||
t.Fatalf("Expected Name to remain the same, got\n%q\nVS\n%q", beforeCollection.Name, afterCollection.Name)
|
||||
}
|
||||
if beforeCollection.Id != afterCollection.Id {
|
||||
t.Fatalf("Expected Id to remain the same, got\n%q\nVS\n%q", beforeCollection.Id, afterCollection.Id)
|
||||
}
|
||||
|
||||
if !s.deleteMissing {
|
||||
totalExpectedFields := len(beforeCollection.Fields) + 1
|
||||
if v := len(afterCollection.Fields); v != totalExpectedFields {
|
||||
t.Fatalf("Expected %d total fields, got %d", totalExpectedFields, v)
|
||||
}
|
||||
|
||||
if afterCollection.Fields.GetByName("test") == nil {
|
||||
t.Fatalf("Missing new field %q", "test")
|
||||
}
|
||||
|
||||
// ensure that the old fields still exist
|
||||
oldFields := beforeCollection.Fields.FieldNames()
|
||||
for _, name := range oldFields {
|
||||
if afterCollection.Fields.GetByName(name) == nil {
|
||||
t.Fatalf("Missing expected old field %q", name)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
totalExpectedFields := 1
|
||||
for _, f := range beforeCollection.Fields {
|
||||
if f.GetSystem() {
|
||||
totalExpectedFields++
|
||||
}
|
||||
}
|
||||
|
||||
if v := len(afterCollection.Fields); v != totalExpectedFields {
|
||||
t.Fatalf("Expected %d total fields, got %d", totalExpectedFields, v)
|
||||
}
|
||||
|
||||
if afterCollection.Fields.GetByName("test") == nil {
|
||||
t.Fatalf("Missing new field %q", "test")
|
||||
}
|
||||
|
||||
// ensure that the old system fields still exist
|
||||
for _, f := range beforeCollection.Fields {
|
||||
if f.GetSystem() && afterCollection.Fields.GetByName(f.GetName()) == nil {
|
||||
t.Fatalf("Missing expected old field %q", f.GetName())
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestImportCollectionsCreateRules(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testApp, _ := tests.NewTestApp()
|
||||
defer testApp.Cleanup()
|
||||
|
||||
err := testApp.ImportCollections([]map[string]any{
|
||||
{"name": "new_test", "type": "auth", "authToken": map[string]any{"duration": 123}, "fields": []map[string]any{{"name": "test", "type": "text"}}},
|
||||
}, false)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
collection, err := testApp.FindCollectionByNameOrId("new_test")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
raw, err := json.Marshal(collection)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
rawStr := string(raw)
|
||||
|
||||
expectedParts := []string{
|
||||
`"name":"new_test"`,
|
||||
`"fields":[`,
|
||||
`"name":"id"`,
|
||||
`"name":"email"`,
|
||||
`"name":"tokenKey"`,
|
||||
`"name":"password"`,
|
||||
`"name":"test"`,
|
||||
`"indexes":[`,
|
||||
`CREATE UNIQUE INDEX`,
|
||||
`"duration":123`,
|
||||
}
|
||||
|
||||
for _, part := range expectedParts {
|
||||
if !strings.Contains(rawStr, part) {
|
||||
t.Errorf("Missing %q in\n%s", part, rawStr)
|
||||
}
|
||||
}
|
||||
}
|
1073
core/collection_model.go
Normal file
1073
core/collection_model.go
Normal file
File diff suppressed because it is too large
Load diff
543
core/collection_model_auth_options.go
Normal file
543
core/collection_model_auth_options.go
Normal file
|
@ -0,0 +1,543 @@
|
|||
package core
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
validation "github.com/go-ozzo/ozzo-validation/v4"
|
||||
"github.com/go-ozzo/ozzo-validation/v4/is"
|
||||
"github.com/pocketbase/pocketbase/tools/auth"
|
||||
"github.com/pocketbase/pocketbase/tools/list"
|
||||
"github.com/pocketbase/pocketbase/tools/security"
|
||||
"github.com/pocketbase/pocketbase/tools/types"
|
||||
"github.com/spf13/cast"
|
||||
)
|
||||
|
||||
func (m *Collection) unsetMissingOAuth2MappedFields() {
|
||||
if !m.IsAuth() {
|
||||
return
|
||||
}
|
||||
|
||||
if m.OAuth2.MappedFields.Id != "" {
|
||||
if m.Fields.GetByName(m.OAuth2.MappedFields.Id) == nil {
|
||||
m.OAuth2.MappedFields.Id = ""
|
||||
}
|
||||
}
|
||||
|
||||
if m.OAuth2.MappedFields.Name != "" {
|
||||
if m.Fields.GetByName(m.OAuth2.MappedFields.Name) == nil {
|
||||
m.OAuth2.MappedFields.Name = ""
|
||||
}
|
||||
}
|
||||
|
||||
if m.OAuth2.MappedFields.Username != "" {
|
||||
if m.Fields.GetByName(m.OAuth2.MappedFields.Username) == nil {
|
||||
m.OAuth2.MappedFields.Username = ""
|
||||
}
|
||||
}
|
||||
|
||||
if m.OAuth2.MappedFields.AvatarURL != "" {
|
||||
if m.Fields.GetByName(m.OAuth2.MappedFields.AvatarURL) == nil {
|
||||
m.OAuth2.MappedFields.AvatarURL = ""
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Collection) setDefaultAuthOptions() {
|
||||
m.collectionAuthOptions = collectionAuthOptions{
|
||||
VerificationTemplate: defaultVerificationTemplate,
|
||||
ResetPasswordTemplate: defaultResetPasswordTemplate,
|
||||
ConfirmEmailChangeTemplate: defaultConfirmEmailChangeTemplate,
|
||||
AuthRule: types.Pointer(""),
|
||||
AuthAlert: AuthAlertConfig{
|
||||
Enabled: true,
|
||||
EmailTemplate: defaultAuthAlertTemplate,
|
||||
},
|
||||
PasswordAuth: PasswordAuthConfig{
|
||||
Enabled: true,
|
||||
IdentityFields: []string{FieldNameEmail},
|
||||
},
|
||||
MFA: MFAConfig{
|
||||
Enabled: false,
|
||||
Duration: 1800, // 30min
|
||||
},
|
||||
OTP: OTPConfig{
|
||||
Enabled: false,
|
||||
Duration: 180, // 3min
|
||||
Length: 8,
|
||||
EmailTemplate: defaultOTPTemplate,
|
||||
},
|
||||
AuthToken: TokenConfig{
|
||||
Secret: security.RandomString(50),
|
||||
Duration: 604800, // 7 days
|
||||
},
|
||||
PasswordResetToken: TokenConfig{
|
||||
Secret: security.RandomString(50),
|
||||
Duration: 1800, // 30min
|
||||
},
|
||||
EmailChangeToken: TokenConfig{
|
||||
Secret: security.RandomString(50),
|
||||
Duration: 1800, // 30min
|
||||
},
|
||||
VerificationToken: TokenConfig{
|
||||
Secret: security.RandomString(50),
|
||||
Duration: 259200, // 3days
|
||||
},
|
||||
FileToken: TokenConfig{
|
||||
Secret: security.RandomString(50),
|
||||
Duration: 180, // 3min
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
var _ optionsValidator = (*collectionAuthOptions)(nil)
|
||||
|
||||
// collectionAuthOptions defines the options for the "auth" type collection.
|
||||
type collectionAuthOptions struct {
|
||||
// AuthRule could be used to specify additional record constraints
|
||||
// applied after record authentication and right before returning the
|
||||
// auth token response to the client.
|
||||
//
|
||||
// For example, to allow only verified users you could set it to
|
||||
// "verified = true".
|
||||
//
|
||||
// Set it to empty string to allow any Auth collection record to authenticate.
|
||||
//
|
||||
// Set it to nil to disallow authentication altogether for the collection
|
||||
// (that includes password, OAuth2, etc.).
|
||||
AuthRule *string `form:"authRule" json:"authRule"`
|
||||
|
||||
// ManageRule gives admin-like permissions to allow fully managing
|
||||
// the auth record(s), eg. changing the password without requiring
|
||||
// to enter the old one, directly updating the verified state and email, etc.
|
||||
//
|
||||
// This rule is executed in addition to the Create and Update API rules.
|
||||
ManageRule *string `form:"manageRule" json:"manageRule"`
|
||||
|
||||
// AuthAlert defines options related to the auth alerts on new device login.
|
||||
AuthAlert AuthAlertConfig `form:"authAlert" json:"authAlert"`
|
||||
|
||||
// OAuth2 specifies whether OAuth2 auth is enabled for the collection
|
||||
// and which OAuth2 providers are allowed.
|
||||
OAuth2 OAuth2Config `form:"oauth2" json:"oauth2"`
|
||||
|
||||
// PasswordAuth defines options related to the collection password authentication.
|
||||
PasswordAuth PasswordAuthConfig `form:"passwordAuth" json:"passwordAuth"`
|
||||
|
||||
// MFA defines options related to the Multi-factor authentication (MFA).
|
||||
MFA MFAConfig `form:"mfa" json:"mfa"`
|
||||
|
||||
// OTP defines options related to the One-time password authentication (OTP).
|
||||
OTP OTPConfig `form:"otp" json:"otp"`
|
||||
|
||||
// Various token configurations
|
||||
// ---
|
||||
AuthToken TokenConfig `form:"authToken" json:"authToken"`
|
||||
PasswordResetToken TokenConfig `form:"passwordResetToken" json:"passwordResetToken"`
|
||||
EmailChangeToken TokenConfig `form:"emailChangeToken" json:"emailChangeToken"`
|
||||
VerificationToken TokenConfig `form:"verificationToken" json:"verificationToken"`
|
||||
FileToken TokenConfig `form:"fileToken" json:"fileToken"`
|
||||
|
||||
// Default email templates
|
||||
// ---
|
||||
VerificationTemplate EmailTemplate `form:"verificationTemplate" json:"verificationTemplate"`
|
||||
ResetPasswordTemplate EmailTemplate `form:"resetPasswordTemplate" json:"resetPasswordTemplate"`
|
||||
ConfirmEmailChangeTemplate EmailTemplate `form:"confirmEmailChangeTemplate" json:"confirmEmailChangeTemplate"`
|
||||
}
|
||||
|
||||
func (o *collectionAuthOptions) validate(cv *collectionValidator) error {
|
||||
err := validation.ValidateStruct(o,
|
||||
validation.Field(
|
||||
&o.AuthRule,
|
||||
validation.By(cv.checkRule),
|
||||
validation.By(cv.ensureNoSystemRuleChange(cv.original.AuthRule)),
|
||||
),
|
||||
validation.Field(
|
||||
&o.ManageRule,
|
||||
validation.NilOrNotEmpty,
|
||||
validation.By(cv.checkRule),
|
||||
validation.By(cv.ensureNoSystemRuleChange(cv.original.ManageRule)),
|
||||
),
|
||||
validation.Field(&o.AuthAlert),
|
||||
validation.Field(&o.PasswordAuth),
|
||||
validation.Field(&o.OAuth2),
|
||||
validation.Field(&o.OTP),
|
||||
validation.Field(&o.MFA),
|
||||
validation.Field(&o.AuthToken),
|
||||
validation.Field(&o.PasswordResetToken),
|
||||
validation.Field(&o.EmailChangeToken),
|
||||
validation.Field(&o.VerificationToken),
|
||||
validation.Field(&o.FileToken),
|
||||
validation.Field(&o.VerificationTemplate, validation.Required),
|
||||
validation.Field(&o.ResetPasswordTemplate, validation.Required),
|
||||
validation.Field(&o.ConfirmEmailChangeTemplate, validation.Required),
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if o.MFA.Enabled {
|
||||
// if MFA is enabled require at least 2 auth methods
|
||||
//
|
||||
// @todo maybe consider disabling the check because if custom auth methods
|
||||
// are registered it may fail since we don't have mechanism to detect them at the moment
|
||||
authsEnabled := 0
|
||||
if o.PasswordAuth.Enabled {
|
||||
authsEnabled++
|
||||
}
|
||||
if o.OAuth2.Enabled {
|
||||
authsEnabled++
|
||||
}
|
||||
if o.OTP.Enabled {
|
||||
authsEnabled++
|
||||
}
|
||||
if authsEnabled < 2 {
|
||||
return validation.Errors{
|
||||
"mfa": validation.Errors{
|
||||
"enabled": validation.NewError("validation_mfa_not_enough_auths", "MFA requires at least 2 auth methods to be enabled."),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
if o.MFA.Rule != "" {
|
||||
mfaRuleValidators := []validation.RuleFunc{
|
||||
cv.checkRule,
|
||||
cv.ensureNoSystemRuleChange(&cv.original.MFA.Rule),
|
||||
}
|
||||
|
||||
for _, validator := range mfaRuleValidators {
|
||||
err := validator(&o.MFA.Rule)
|
||||
if err != nil {
|
||||
return validation.Errors{
|
||||
"mfa": validation.Errors{
|
||||
"rule": err,
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// extra check to ensure that only unique identity fields are used
|
||||
if o.PasswordAuth.Enabled {
|
||||
err = validation.Validate(o.PasswordAuth.IdentityFields, validation.By(cv.checkFieldsForUniqueIndex))
|
||||
if err != nil {
|
||||
return validation.Errors{
|
||||
"passwordAuth": validation.Errors{
|
||||
"identityFields": err,
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
|
||||
type EmailTemplate struct {
|
||||
Subject string `form:"subject" json:"subject"`
|
||||
Body string `form:"body" json:"body"`
|
||||
}
|
||||
|
||||
// Validate makes EmailTemplate validatable by implementing [validation.Validatable] interface.
|
||||
func (t EmailTemplate) Validate() error {
|
||||
return validation.ValidateStruct(&t,
|
||||
validation.Field(&t.Subject, validation.Required),
|
||||
validation.Field(&t.Body, validation.Required),
|
||||
)
|
||||
}
|
||||
|
||||
// Resolve replaces the placeholder parameters in the current email
|
||||
// template and returns its components as ready-to-use strings.
|
||||
func (t EmailTemplate) Resolve(placeholders map[string]any) (subject, body string) {
|
||||
body = t.Body
|
||||
subject = t.Subject
|
||||
|
||||
for k, v := range placeholders {
|
||||
vStr := cast.ToString(v)
|
||||
|
||||
// replace subject placeholder params (if any)
|
||||
subject = strings.ReplaceAll(subject, k, vStr)
|
||||
|
||||
// replace body placeholder params (if any)
|
||||
body = strings.ReplaceAll(body, k, vStr)
|
||||
}
|
||||
|
||||
return subject, body
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
|
||||
type AuthAlertConfig struct {
|
||||
Enabled bool `form:"enabled" json:"enabled"`
|
||||
EmailTemplate EmailTemplate `form:"emailTemplate" json:"emailTemplate"`
|
||||
}
|
||||
|
||||
// Validate makes AuthAlertConfig validatable by implementing [validation.Validatable] interface.
|
||||
func (c AuthAlertConfig) Validate() error {
|
||||
return validation.ValidateStruct(&c,
|
||||
// note: for now always run the email template validations even
|
||||
// if not enabled since it could be used separately
|
||||
validation.Field(&c.EmailTemplate),
|
||||
)
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
|
||||
type TokenConfig struct {
|
||||
Secret string `form:"secret" json:"secret,omitempty"`
|
||||
|
||||
// Duration specifies how long an issued token to be valid (in seconds)
|
||||
Duration int64 `form:"duration" json:"duration"`
|
||||
}
|
||||
|
||||
// Validate makes TokenConfig validatable by implementing [validation.Validatable] interface.
|
||||
func (c TokenConfig) Validate() error {
|
||||
return validation.ValidateStruct(&c,
|
||||
validation.Field(&c.Secret, validation.Required, validation.Length(30, 255)),
|
||||
validation.Field(&c.Duration, validation.Required, validation.Min(10), validation.Max(94670856)), // ~3y max
|
||||
)
|
||||
}
|
||||
|
||||
// DurationTime returns the current Duration as [time.Duration].
|
||||
func (c TokenConfig) DurationTime() time.Duration {
|
||||
return time.Duration(c.Duration) * time.Second
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
|
||||
type OTPConfig struct {
|
||||
Enabled bool `form:"enabled" json:"enabled"`
|
||||
|
||||
// Duration specifies how long the OTP to be valid (in seconds)
|
||||
Duration int64 `form:"duration" json:"duration"`
|
||||
|
||||
// Length specifies the auto generated password length.
|
||||
Length int `form:"length" json:"length"`
|
||||
|
||||
// EmailTemplate is the default OTP email template that will be send to the auth record.
|
||||
//
|
||||
// In addition to the system placeholders you can also make use of
|
||||
// [core.EmailPlaceholderOTPId] and [core.EmailPlaceholderOTP].
|
||||
EmailTemplate EmailTemplate `form:"emailTemplate" json:"emailTemplate"`
|
||||
}
|
||||
|
||||
// Validate makes OTPConfig validatable by implementing [validation.Validatable] interface.
|
||||
func (c OTPConfig) Validate() error {
|
||||
return validation.ValidateStruct(&c,
|
||||
validation.Field(&c.Duration, validation.When(c.Enabled, validation.Required, validation.Min(10), validation.Max(86400))),
|
||||
validation.Field(&c.Length, validation.When(c.Enabled, validation.Required, validation.Min(4))),
|
||||
// note: for now always run the email template validations even
|
||||
// if not enabled since it could be used separately
|
||||
validation.Field(&c.EmailTemplate),
|
||||
)
|
||||
}
|
||||
|
||||
// DurationTime returns the current Duration as [time.Duration].
|
||||
func (c OTPConfig) DurationTime() time.Duration {
|
||||
return time.Duration(c.Duration) * time.Second
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
|
||||
type MFAConfig struct {
|
||||
Enabled bool `form:"enabled" json:"enabled"`
|
||||
|
||||
// Duration specifies how long an issued MFA to be valid (in seconds)
|
||||
Duration int64 `form:"duration" json:"duration"`
|
||||
|
||||
// Rule is an optional field to restrict MFA only for the records that satisfy the rule.
|
||||
//
|
||||
// Leave it empty to enable MFA for everyone.
|
||||
Rule string `form:"rule" json:"rule"`
|
||||
}
|
||||
|
||||
// Validate makes MFAConfig validatable by implementing [validation.Validatable] interface.
|
||||
func (c MFAConfig) Validate() error {
|
||||
return validation.ValidateStruct(&c,
|
||||
validation.Field(&c.Duration, validation.When(c.Enabled, validation.Required, validation.Min(10), validation.Max(86400))),
|
||||
)
|
||||
}
|
||||
|
||||
// DurationTime returns the current Duration as [time.Duration].
|
||||
func (c MFAConfig) DurationTime() time.Duration {
|
||||
return time.Duration(c.Duration) * time.Second
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
|
||||
type PasswordAuthConfig struct {
|
||||
Enabled bool `form:"enabled" json:"enabled"`
|
||||
|
||||
// IdentityFields is a list of field names that could be used as
|
||||
// identity during password authentication.
|
||||
//
|
||||
// Usually only fields that has single column UNIQUE index are accepted as values.
|
||||
IdentityFields []string `form:"identityFields" json:"identityFields"`
|
||||
}
|
||||
|
||||
// Validate makes PasswordAuthConfig validatable by implementing [validation.Validatable] interface.
|
||||
func (c PasswordAuthConfig) Validate() error {
|
||||
// strip duplicated values
|
||||
c.IdentityFields = list.ToUniqueStringSlice(c.IdentityFields)
|
||||
|
||||
if !c.Enabled {
|
||||
return nil // no need to validate
|
||||
}
|
||||
|
||||
return validation.ValidateStruct(&c,
|
||||
validation.Field(&c.IdentityFields, validation.Required),
|
||||
)
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
|
||||
type OAuth2KnownFields struct {
|
||||
Id string `form:"id" json:"id"`
|
||||
Name string `form:"name" json:"name"`
|
||||
Username string `form:"username" json:"username"`
|
||||
AvatarURL string `form:"avatarURL" json:"avatarURL"`
|
||||
}
|
||||
|
||||
type OAuth2Config struct {
|
||||
Providers []OAuth2ProviderConfig `form:"providers" json:"providers"`
|
||||
|
||||
MappedFields OAuth2KnownFields `form:"mappedFields" json:"mappedFields"`
|
||||
|
||||
Enabled bool `form:"enabled" json:"enabled"`
|
||||
}
|
||||
|
||||
// GetProviderConfig returns the first OAuth2ProviderConfig that matches the specified name.
|
||||
//
|
||||
// Returns false and zero config if no such provider is available in c.Providers.
|
||||
func (c OAuth2Config) GetProviderConfig(name string) (config OAuth2ProviderConfig, exists bool) {
|
||||
for _, p := range c.Providers {
|
||||
if p.Name == name {
|
||||
return p, true
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Validate makes OAuth2Config validatable by implementing [validation.Validatable] interface.
|
||||
func (c OAuth2Config) Validate() error {
|
||||
if !c.Enabled {
|
||||
return nil // no need to validate
|
||||
}
|
||||
|
||||
return validation.ValidateStruct(&c,
|
||||
// note: don't require providers for now as they could be externally registered/removed
|
||||
validation.Field(&c.Providers, validation.By(checkForDuplicatedProviders)),
|
||||
)
|
||||
}
|
||||
|
||||
func checkForDuplicatedProviders(value any) error {
|
||||
configs, _ := value.([]OAuth2ProviderConfig)
|
||||
|
||||
existing := map[string]struct{}{}
|
||||
|
||||
for i, c := range configs {
|
||||
if c.Name == "" {
|
||||
continue // the name nonempty state is validated separately
|
||||
}
|
||||
if _, ok := existing[c.Name]; ok {
|
||||
return validation.Errors{
|
||||
strconv.Itoa(i): validation.Errors{
|
||||
"name": validation.NewError("validation_duplicated_provider", "The provider {{.name}} is already registered.").
|
||||
SetParams(map[string]any{"name": c.Name}),
|
||||
},
|
||||
}
|
||||
}
|
||||
existing[c.Name] = struct{}{}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type OAuth2ProviderConfig struct {
|
||||
// PKCE overwrites the default provider PKCE config option.
|
||||
//
|
||||
// This usually shouldn't be needed but some OAuth2 vendors, like the LinkedIn OIDC,
|
||||
// may require manual adjustment due to returning error if extra parameters are added to the request
|
||||
// (https://github.com/pocketbase/pocketbase/discussions/3799#discussioncomment-7640312)
|
||||
PKCE *bool `form:"pkce" json:"pkce"`
|
||||
|
||||
Name string `form:"name" json:"name"`
|
||||
ClientId string `form:"clientId" json:"clientId"`
|
||||
ClientSecret string `form:"clientSecret" json:"clientSecret,omitempty"`
|
||||
AuthURL string `form:"authURL" json:"authURL"`
|
||||
TokenURL string `form:"tokenURL" json:"tokenURL"`
|
||||
UserInfoURL string `form:"userInfoURL" json:"userInfoURL"`
|
||||
DisplayName string `form:"displayName" json:"displayName"`
|
||||
Extra map[string]any `form:"extra" json:"extra"`
|
||||
}
|
||||
|
||||
// Validate makes OAuth2ProviderConfig validatable by implementing [validation.Validatable] interface.
|
||||
func (c OAuth2ProviderConfig) Validate() error {
|
||||
return validation.ValidateStruct(&c,
|
||||
validation.Field(&c.Name, validation.Required, validation.By(checkProviderName)),
|
||||
validation.Field(&c.ClientId, validation.Required),
|
||||
validation.Field(&c.ClientSecret, validation.Required),
|
||||
validation.Field(&c.AuthURL, is.URL),
|
||||
validation.Field(&c.TokenURL, is.URL),
|
||||
validation.Field(&c.UserInfoURL, is.URL),
|
||||
)
|
||||
}
|
||||
|
||||
func checkProviderName(value any) error {
|
||||
name, _ := value.(string)
|
||||
if name == "" {
|
||||
return nil // nothing to check
|
||||
}
|
||||
|
||||
if _, err := auth.NewProviderByName(name); err != nil {
|
||||
return validation.NewError("validation_missing_provider", "Invalid or missing provider with name {{.name}}.").
|
||||
SetParams(map[string]any{"name": name})
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// InitProvider returns a new auth.Provider instance loaded with the current OAuth2ProviderConfig options.
|
||||
func (c OAuth2ProviderConfig) InitProvider() (auth.Provider, error) {
|
||||
provider, err := auth.NewProviderByName(c.Name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if c.ClientId != "" {
|
||||
provider.SetClientId(c.ClientId)
|
||||
}
|
||||
|
||||
if c.ClientSecret != "" {
|
||||
provider.SetClientSecret(c.ClientSecret)
|
||||
}
|
||||
|
||||
if c.AuthURL != "" {
|
||||
provider.SetAuthURL(c.AuthURL)
|
||||
}
|
||||
|
||||
if c.UserInfoURL != "" {
|
||||
provider.SetUserInfoURL(c.UserInfoURL)
|
||||
}
|
||||
|
||||
if c.TokenURL != "" {
|
||||
provider.SetTokenURL(c.TokenURL)
|
||||
}
|
||||
|
||||
if c.DisplayName != "" {
|
||||
provider.SetDisplayName(c.DisplayName)
|
||||
}
|
||||
|
||||
if c.PKCE != nil {
|
||||
provider.SetPKCE(*c.PKCE)
|
||||
}
|
||||
|
||||
if c.Extra != nil {
|
||||
provider.SetExtra(c.Extra)
|
||||
}
|
||||
|
||||
return provider, nil
|
||||
}
|
1026
core/collection_model_auth_options_test.go
Normal file
1026
core/collection_model_auth_options_test.go
Normal file
File diff suppressed because it is too large
Load diff
75
core/collection_model_auth_templates.go
Normal file
75
core/collection_model_auth_templates.go
Normal file
|
@ -0,0 +1,75 @@
|
|||
package core
|
||||
|
||||
// Common settings placeholder tokens
|
||||
const (
|
||||
EmailPlaceholderAppName string = "{APP_NAME}"
|
||||
EmailPlaceholderAppURL string = "{APP_URL}"
|
||||
EmailPlaceholderToken string = "{TOKEN}"
|
||||
EmailPlaceholderOTP string = "{OTP}"
|
||||
EmailPlaceholderOTPId string = "{OTP_ID}"
|
||||
)
|
||||
|
||||
var defaultVerificationTemplate = EmailTemplate{
|
||||
Subject: "Verify your " + EmailPlaceholderAppName + " email",
|
||||
Body: `<p>Hello,</p>
|
||||
<p>Thank you for joining us at ` + EmailPlaceholderAppName + `.</p>
|
||||
<p>Click on the button below to verify your email address.</p>
|
||||
<p>
|
||||
<a class="btn" href="` + EmailPlaceholderAppURL + "/_/#/auth/confirm-verification/" + EmailPlaceholderToken + `" target="_blank" rel="noopener">Verify</a>
|
||||
</p>
|
||||
<p>
|
||||
Thanks,<br/>
|
||||
` + EmailPlaceholderAppName + ` team
|
||||
</p>`,
|
||||
}
|
||||
|
||||
var defaultResetPasswordTemplate = EmailTemplate{
|
||||
Subject: "Reset your " + EmailPlaceholderAppName + " password",
|
||||
Body: `<p>Hello,</p>
|
||||
<p>Click on the button below to reset your password.</p>
|
||||
<p>
|
||||
<a class="btn" href="` + EmailPlaceholderAppURL + "/_/#/auth/confirm-password-reset/" + EmailPlaceholderToken + `" target="_blank" rel="noopener">Reset password</a>
|
||||
</p>
|
||||
<p><i>If you didn't ask to reset your password, you can ignore this email.</i></p>
|
||||
<p>
|
||||
Thanks,<br/>
|
||||
` + EmailPlaceholderAppName + ` team
|
||||
</p>`,
|
||||
}
|
||||
|
||||
var defaultConfirmEmailChangeTemplate = EmailTemplate{
|
||||
Subject: "Confirm your " + EmailPlaceholderAppName + " new email address",
|
||||
Body: `<p>Hello,</p>
|
||||
<p>Click on the button below to confirm your new email address.</p>
|
||||
<p>
|
||||
<a class="btn" href="` + EmailPlaceholderAppURL + "/_/#/auth/confirm-email-change/" + EmailPlaceholderToken + `" target="_blank" rel="noopener">Confirm new email</a>
|
||||
</p>
|
||||
<p><i>If you didn't ask to change your email address, you can ignore this email.</i></p>
|
||||
<p>
|
||||
Thanks,<br/>
|
||||
` + EmailPlaceholderAppName + ` team
|
||||
</p>`,
|
||||
}
|
||||
|
||||
var defaultOTPTemplate = EmailTemplate{
|
||||
Subject: "OTP for " + EmailPlaceholderAppName,
|
||||
Body: `<p>Hello,</p>
|
||||
<p>Your one-time password is: <strong>` + EmailPlaceholderOTP + `</strong></p>
|
||||
<p><i>If you didn't ask for the one-time password, you can ignore this email.</i></p>
|
||||
<p>
|
||||
Thanks,<br/>
|
||||
` + EmailPlaceholderAppName + ` team
|
||||
</p>`,
|
||||
}
|
||||
|
||||
var defaultAuthAlertTemplate = EmailTemplate{
|
||||
Subject: "Login from a new location",
|
||||
Body: `<p>Hello,</p>
|
||||
<p>We noticed a login to your ` + EmailPlaceholderAppName + ` account from a new location.</p>
|
||||
<p>If this was you, you may disregard this email.</p>
|
||||
<p><strong>If this wasn't you, you should immediately change your ` + EmailPlaceholderAppName + ` account password to revoke access from all other locations.</strong></p>
|
||||
<p>
|
||||
Thanks,<br/>
|
||||
` + EmailPlaceholderAppName + ` team
|
||||
</p>`,
|
||||
}
|
11
core/collection_model_base_options.go
Normal file
11
core/collection_model_base_options.go
Normal file
|
@ -0,0 +1,11 @@
|
|||
package core
|
||||
|
||||
var _ optionsValidator = (*collectionBaseOptions)(nil)
|
||||
|
||||
// collectionBaseOptions defines the options for the "base" type collection.
|
||||
type collectionBaseOptions struct {
|
||||
}
|
||||
|
||||
func (o *collectionBaseOptions) validate(cv *collectionValidator) error {
|
||||
return nil
|
||||
}
|
1638
core/collection_model_test.go
Normal file
1638
core/collection_model_test.go
Normal file
File diff suppressed because it is too large
Load diff
18
core/collection_model_view_options.go
Normal file
18
core/collection_model_view_options.go
Normal file
|
@ -0,0 +1,18 @@
|
|||
package core
|
||||
|
||||
import (
|
||||
validation "github.com/go-ozzo/ozzo-validation/v4"
|
||||
)
|
||||
|
||||
var _ optionsValidator = (*collectionViewOptions)(nil)
|
||||
|
||||
// collectionViewOptions defines the options for the "view" type collection.
|
||||
type collectionViewOptions struct {
|
||||
ViewQuery string `form:"viewQuery" json:"viewQuery"`
|
||||
}
|
||||
|
||||
func (o *collectionViewOptions) validate(cv *collectionValidator) error {
|
||||
return validation.ValidateStruct(o,
|
||||
validation.Field(&o.ViewQuery, validation.Required, validation.By(cv.checkViewQuery)),
|
||||
)
|
||||
}
|
79
core/collection_model_view_options_test.go
Normal file
79
core/collection_model_view_options_test.go
Normal file
|
@ -0,0 +1,79 @@
|
|||
package core_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tests"
|
||||
)
|
||||
|
||||
func TestCollectionViewOptionsValidate(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
scenarios := []struct {
|
||||
name string
|
||||
collection func(app core.App) (*core.Collection, error)
|
||||
expectedErrors []string
|
||||
}{
|
||||
{
|
||||
name: "view with empty query",
|
||||
collection: func(app core.App) (*core.Collection, error) {
|
||||
c := core.NewViewCollection("new_auth")
|
||||
return c, nil
|
||||
},
|
||||
expectedErrors: []string{"fields", "viewQuery"},
|
||||
},
|
||||
{
|
||||
name: "view with invalid query",
|
||||
collection: func(app core.App) (*core.Collection, error) {
|
||||
c := core.NewViewCollection("new_auth")
|
||||
c.ViewQuery = "invalid"
|
||||
return c, nil
|
||||
},
|
||||
expectedErrors: []string{"fields", "viewQuery"},
|
||||
},
|
||||
{
|
||||
name: "view with valid query but missing id",
|
||||
collection: func(app core.App) (*core.Collection, error) {
|
||||
c := core.NewViewCollection("new_auth")
|
||||
c.ViewQuery = "select 1"
|
||||
return c, nil
|
||||
},
|
||||
expectedErrors: []string{"fields", "viewQuery"},
|
||||
},
|
||||
{
|
||||
name: "view with valid query",
|
||||
collection: func(app core.App) (*core.Collection, error) {
|
||||
c := core.NewViewCollection("new_auth")
|
||||
c.ViewQuery = "select demo1.id, text as example from demo1"
|
||||
return c, nil
|
||||
},
|
||||
expectedErrors: []string{},
|
||||
},
|
||||
{
|
||||
name: "update view query ",
|
||||
collection: func(app core.App) (*core.Collection, error) {
|
||||
c, _ := app.FindCollectionByNameOrId("view2")
|
||||
c.ViewQuery = "select demo1.id, text as example from demo1"
|
||||
return c, nil
|
||||
},
|
||||
expectedErrors: []string{},
|
||||
},
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
t.Run(s.name, func(t *testing.T) {
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
collection, err := s.collection(app)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to retrieve test collection: %v", err)
|
||||
}
|
||||
|
||||
result := app.Validate(collection)
|
||||
|
||||
tests.TestValidationErrors(t, result, s.expectedErrors)
|
||||
})
|
||||
}
|
||||
}
|
391
core/collection_query.go
Normal file
391
core/collection_query.go
Normal file
|
@ -0,0 +1,391 @@
|
|||
package core
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/pocketbase/dbx"
|
||||
"github.com/pocketbase/pocketbase/tools/list"
|
||||
)
|
||||
|
||||
const StoreKeyCachedCollections = "pbAppCachedCollections"
|
||||
|
||||
// CollectionQuery returns a new Collection select query.
|
||||
func (app *BaseApp) CollectionQuery() *dbx.SelectQuery {
|
||||
return app.ModelQuery(&Collection{})
|
||||
}
|
||||
|
||||
// FindCollections finds all collections by the given type(s).
|
||||
//
|
||||
// If collectionTypes is not set, it returns all collections.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// app.FindAllCollections() // all collections
|
||||
// app.FindAllCollections("auth", "view") // only auth and view collections
|
||||
func (app *BaseApp) FindAllCollections(collectionTypes ...string) ([]*Collection, error) {
|
||||
collections := []*Collection{}
|
||||
|
||||
q := app.CollectionQuery()
|
||||
|
||||
types := list.NonzeroUniques(collectionTypes)
|
||||
if len(types) > 0 {
|
||||
q.AndWhere(dbx.In("type", list.ToInterfaceSlice(types)...))
|
||||
}
|
||||
|
||||
err := q.OrderBy("rowid ASC").All(&collections)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return collections, nil
|
||||
}
|
||||
|
||||
// ReloadCachedCollections fetches all collections and caches them into the app store.
|
||||
func (app *BaseApp) ReloadCachedCollections() error {
|
||||
collections, err := app.FindAllCollections()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
app.Store().Set(StoreKeyCachedCollections, collections)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// FindCollectionByNameOrId finds a single collection by its name (case insensitive) or id.
|
||||
func (app *BaseApp) FindCollectionByNameOrId(nameOrId string) (*Collection, error) {
|
||||
m := &Collection{}
|
||||
|
||||
err := app.CollectionQuery().
|
||||
AndWhere(dbx.NewExp("[[id]]={:id} OR LOWER([[name]])={:name}", dbx.Params{
|
||||
"id": nameOrId,
|
||||
"name": strings.ToLower(nameOrId),
|
||||
})).
|
||||
Limit(1).
|
||||
One(m)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// FindCachedCollectionByNameOrId is similar to [BaseApp.FindCollectionByNameOrId]
|
||||
// but retrieves the Collection from the app cache instead of making a db call.
|
||||
//
|
||||
// NB! This method is suitable for read-only Collection operations.
|
||||
//
|
||||
// Returns [sql.ErrNoRows] if no Collection is found for consistency
|
||||
// with the [BaseApp.FindCollectionByNameOrId] method.
|
||||
//
|
||||
// If you plan making changes to the returned Collection model,
|
||||
// use [BaseApp.FindCollectionByNameOrId] instead.
|
||||
//
|
||||
// Caveats:
|
||||
//
|
||||
// - The returned Collection should be used only for read-only operations.
|
||||
// Avoid directly modifying the returned cached Collection as it will affect
|
||||
// the global cached value even if you don't persist the changes in the database!
|
||||
// - If you are updating a Collection in a transaction and then call this method before commit,
|
||||
// it'll return the cached Collection state and not the one from the uncommitted transaction.
|
||||
// - The cache is automatically updated on collections db change (create/update/delete).
|
||||
// To manually reload the cache you can call [BaseApp.ReloadCachedCollections].
|
||||
func (app *BaseApp) FindCachedCollectionByNameOrId(nameOrId string) (*Collection, error) {
|
||||
collections, _ := app.Store().Get(StoreKeyCachedCollections).([]*Collection)
|
||||
if collections == nil {
|
||||
// cache is not initialized yet (eg. run in a system migration)
|
||||
return app.FindCollectionByNameOrId(nameOrId)
|
||||
}
|
||||
|
||||
for _, c := range collections {
|
||||
if strings.EqualFold(c.Name, nameOrId) || c.Id == nameOrId {
|
||||
return c, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, sql.ErrNoRows
|
||||
}
|
||||
|
||||
// FindCollectionReferences returns information for all relation fields
|
||||
// referencing the provided collection.
|
||||
//
|
||||
// If the provided collection has reference to itself then it will be
|
||||
// also included in the result. To exclude it, pass the collection id
|
||||
// as the excludeIds argument.
|
||||
func (app *BaseApp) FindCollectionReferences(collection *Collection, excludeIds ...string) (map[*Collection][]Field, error) {
|
||||
collections := []*Collection{}
|
||||
|
||||
query := app.CollectionQuery()
|
||||
|
||||
if uniqueExcludeIds := list.NonzeroUniques(excludeIds); len(uniqueExcludeIds) > 0 {
|
||||
query.AndWhere(dbx.NotIn("id", list.ToInterfaceSlice(uniqueExcludeIds)...))
|
||||
}
|
||||
|
||||
if err := query.All(&collections); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result := map[*Collection][]Field{}
|
||||
|
||||
for _, c := range collections {
|
||||
for _, rawField := range c.Fields {
|
||||
f, ok := rawField.(*RelationField)
|
||||
if ok && f.CollectionId == collection.Id {
|
||||
result[c] = append(result[c], f)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// FindCachedCollectionReferences is similar to [BaseApp.FindCollectionReferences]
|
||||
// but retrieves the Collection from the app cache instead of making a db call.
|
||||
//
|
||||
// NB! This method is suitable for read-only Collection operations.
|
||||
//
|
||||
// If you plan making changes to the returned Collection model,
|
||||
// use [BaseApp.FindCollectionReferences] instead.
|
||||
//
|
||||
// Caveats:
|
||||
//
|
||||
// - The returned Collection should be used only for read-only operations.
|
||||
// Avoid directly modifying the returned cached Collection as it will affect
|
||||
// the global cached value even if you don't persist the changes in the database!
|
||||
// - If you are updating a Collection in a transaction and then call this method before commit,
|
||||
// it'll return the cached Collection state and not the one from the uncommitted transaction.
|
||||
// - The cache is automatically updated on collections db change (create/update/delete).
|
||||
// To manually reload the cache you can call [BaseApp.ReloadCachedCollections].
|
||||
func (app *BaseApp) FindCachedCollectionReferences(collection *Collection, excludeIds ...string) (map[*Collection][]Field, error) {
|
||||
collections, _ := app.Store().Get(StoreKeyCachedCollections).([]*Collection)
|
||||
if collections == nil {
|
||||
// cache is not initialized yet (eg. run in a system migration)
|
||||
return app.FindCollectionReferences(collection, excludeIds...)
|
||||
}
|
||||
|
||||
result := map[*Collection][]Field{}
|
||||
|
||||
for _, c := range collections {
|
||||
if slices.Contains(excludeIds, c.Id) {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, rawField := range c.Fields {
|
||||
f, ok := rawField.(*RelationField)
|
||||
if ok && f.CollectionId == collection.Id {
|
||||
result[c] = append(result[c], f)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// IsCollectionNameUnique checks that there is no existing collection
|
||||
// with the provided name (case insensitive!).
|
||||
//
|
||||
// Note: case insensitive check because the name is used also as
|
||||
// table name for the records.
|
||||
func (app *BaseApp) IsCollectionNameUnique(name string, excludeIds ...string) bool {
|
||||
if name == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
query := app.CollectionQuery().
|
||||
Select("count(*)").
|
||||
AndWhere(dbx.NewExp("LOWER([[name]])={:name}", dbx.Params{"name": strings.ToLower(name)})).
|
||||
Limit(1)
|
||||
|
||||
if uniqueExcludeIds := list.NonzeroUniques(excludeIds); len(uniqueExcludeIds) > 0 {
|
||||
query.AndWhere(dbx.NotIn("id", list.ToInterfaceSlice(uniqueExcludeIds)...))
|
||||
}
|
||||
|
||||
var total int
|
||||
|
||||
return query.Row(&total) == nil && total == 0
|
||||
}
|
||||
|
||||
// TruncateCollection deletes all records associated with the provided collection.
|
||||
//
|
||||
// The truncate operation is executed in a single transaction,
|
||||
// aka. either everything is deleted or none.
|
||||
//
|
||||
// Note that this method will also trigger the records related
|
||||
// cascade and file delete actions.
|
||||
func (app *BaseApp) TruncateCollection(collection *Collection) error {
|
||||
if collection.IsView() {
|
||||
return errors.New("view collections cannot be truncated since they don't store their own records")
|
||||
}
|
||||
|
||||
return app.RunInTransaction(func(txApp App) error {
|
||||
records := make([]*Record, 0, 500)
|
||||
|
||||
for {
|
||||
err := txApp.RecordQuery(collection).Limit(500).All(&records)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(records) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
for _, record := range records {
|
||||
err = txApp.Delete(record)
|
||||
if err != nil && !errors.Is(err, sql.ErrNoRows) {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
records = records[:0]
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
|
||||
// saveViewCollection persists the provided View collection changes:
|
||||
// - deletes the old related SQL view (if any)
|
||||
// - creates a new SQL view with the latest newCollection.Options.Query
|
||||
// - generates new feilds list based on newCollection.Options.Query
|
||||
// - updates newCollection.Fields based on the generated view table info and query
|
||||
// - saves the newCollection
|
||||
//
|
||||
// This method returns an error if newCollection is not a "view".
|
||||
func saveViewCollection(app App, newCollection, oldCollection *Collection) error {
|
||||
if !newCollection.IsView() {
|
||||
return errors.New("not a view collection")
|
||||
}
|
||||
|
||||
return app.RunInTransaction(func(txApp App) error {
|
||||
query := newCollection.ViewQuery
|
||||
|
||||
// generate collection fields from the query
|
||||
viewFields, err := txApp.CreateViewFields(query)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// delete old renamed view
|
||||
if oldCollection != nil {
|
||||
if err := txApp.DeleteView(oldCollection.Name); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// wrap view query if necessary
|
||||
query, err = normalizeViewQueryId(txApp, query)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to normalize view query id: %w", err)
|
||||
}
|
||||
|
||||
// (re)create the view
|
||||
if err := txApp.SaveView(newCollection.Name, query); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
newCollection.Fields = viewFields
|
||||
|
||||
return txApp.Save(newCollection)
|
||||
})
|
||||
}
|
||||
|
||||
// normalizeViewQueryId wraps (if necessary) the provided view query
|
||||
// with a subselect to ensure that the id column is a text since
|
||||
// currently we don't support non-string model ids
|
||||
// (see https://github.com/pocketbase/pocketbase/issues/3110).
|
||||
func normalizeViewQueryId(app App, query string) (string, error) {
|
||||
query = strings.Trim(strings.TrimSpace(query), ";")
|
||||
|
||||
info, err := getQueryTableInfo(app, query)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
for _, row := range info {
|
||||
if strings.EqualFold(row.Name, FieldNameId) && strings.EqualFold(row.Type, "TEXT") {
|
||||
return query, nil // no wrapping needed
|
||||
}
|
||||
}
|
||||
|
||||
// raw parse to preserve the columns order
|
||||
rawParsed := new(identifiersParser)
|
||||
if err := rawParsed.parse(query); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
columns := make([]string, 0, len(rawParsed.columns))
|
||||
for _, col := range rawParsed.columns {
|
||||
if col.alias == FieldNameId {
|
||||
columns = append(columns, fmt.Sprintf("CAST([[%s]] as TEXT) [[%s]]", col.alias, col.alias))
|
||||
} else {
|
||||
columns = append(columns, "[["+col.alias+"]]")
|
||||
}
|
||||
}
|
||||
|
||||
query = fmt.Sprintf("SELECT %s FROM (%s)", strings.Join(columns, ","), query)
|
||||
|
||||
return query, nil
|
||||
}
|
||||
|
||||
// resaveViewsWithChangedFields updates all view collections with changed fields.
|
||||
func resaveViewsWithChangedFields(app App, excludeIds ...string) error {
|
||||
collections, err := app.FindAllCollections(CollectionTypeView)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return app.RunInTransaction(func(txApp App) error {
|
||||
for _, collection := range collections {
|
||||
if len(excludeIds) > 0 && list.ExistInSlice(collection.Id, excludeIds) {
|
||||
continue
|
||||
}
|
||||
|
||||
// clone the existing fields for temp modifications
|
||||
oldFields, err := collection.Fields.Clone()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// generate new fields from the query
|
||||
newFields, err := txApp.CreateViewFields(collection.ViewQuery)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// unset the fields' ids to exclude from the comparison
|
||||
for _, f := range oldFields {
|
||||
f.SetId("")
|
||||
}
|
||||
for _, f := range newFields {
|
||||
f.SetId("")
|
||||
}
|
||||
|
||||
encodedNewFields, err := json.Marshal(newFields)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
encodedOldFields, err := json.Marshal(oldFields)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if bytes.EqualFold(encodedNewFields, encodedOldFields) {
|
||||
continue // no changes
|
||||
}
|
||||
|
||||
if err := saveViewCollection(txApp, collection, nil); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
474
core/collection_query_test.go
Normal file
474
core/collection_query_test.go
Normal file
|
@ -0,0 +1,474 @@
|
|||
package core_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"slices"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/pocketbase/dbx"
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tests"
|
||||
"github.com/pocketbase/pocketbase/tools/list"
|
||||
)
|
||||
|
||||
func TestCollectionQuery(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
expected := "SELECT {{_collections}}.* FROM `_collections`"
|
||||
|
||||
sql := app.CollectionQuery().Build().SQL()
|
||||
if sql != expected {
|
||||
t.Errorf("Expected sql %s, got %s", expected, sql)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReloadCachedCollections(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
err := app.ReloadCachedCollections()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
cached := app.Store().Get(core.StoreKeyCachedCollections)
|
||||
|
||||
cachedCollections, ok := cached.([]*core.Collection)
|
||||
if !ok {
|
||||
t.Fatalf("Expected []*core.Collection, got %T", cached)
|
||||
}
|
||||
|
||||
collections, err := app.FindAllCollections()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to retrieve all collections: %v", err)
|
||||
}
|
||||
|
||||
if len(cachedCollections) != len(collections) {
|
||||
t.Fatalf("Expected %d collections, got %d", len(collections), len(cachedCollections))
|
||||
}
|
||||
|
||||
for _, c := range collections {
|
||||
var exists bool
|
||||
for _, cc := range cachedCollections {
|
||||
if cc.Id == c.Id {
|
||||
exists = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !exists {
|
||||
t.Fatalf("The collections cache is missing collection %q", c.Name)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestFindAllCollections(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
scenarios := []struct {
|
||||
collectionTypes []string
|
||||
expectTotal int
|
||||
}{
|
||||
{nil, 16},
|
||||
{[]string{}, 16},
|
||||
{[]string{""}, 16},
|
||||
{[]string{"unknown"}, 0},
|
||||
{[]string{"unknown", core.CollectionTypeAuth}, 4},
|
||||
{[]string{core.CollectionTypeAuth, core.CollectionTypeView}, 7},
|
||||
}
|
||||
|
||||
for i, s := range scenarios {
|
||||
t.Run(fmt.Sprintf("%d_%s", i, strings.Join(s.collectionTypes, "_")), func(t *testing.T) {
|
||||
collections, err := app.FindAllCollections(s.collectionTypes...)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if len(collections) != s.expectTotal {
|
||||
t.Fatalf("Expected %d collections, got %d", s.expectTotal, len(collections))
|
||||
}
|
||||
|
||||
expectedTypes := list.NonzeroUniques(s.collectionTypes)
|
||||
if len(expectedTypes) > 0 {
|
||||
for _, c := range collections {
|
||||
if !slices.Contains(expectedTypes, c.Type) {
|
||||
t.Fatalf("Unexpected collection type %s\n%v", c.Type, c)
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFindCollectionByNameOrId(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
scenarios := []struct {
|
||||
nameOrId string
|
||||
expectError bool
|
||||
}{
|
||||
{"", true},
|
||||
{"missing", true},
|
||||
{"wsmn24bux7wo113", false},
|
||||
{"demo1", false},
|
||||
{"DEMO1", false}, // case insensitive
|
||||
}
|
||||
|
||||
for i, s := range scenarios {
|
||||
t.Run(fmt.Sprintf("%d_%s", i, s.nameOrId), func(t *testing.T) {
|
||||
model, err := app.FindCollectionByNameOrId(s.nameOrId)
|
||||
|
||||
hasErr := err != nil
|
||||
if hasErr != s.expectError {
|
||||
t.Fatalf("Expected hasErr to be %v, got %v (%v)", s.expectError, hasErr, err)
|
||||
}
|
||||
|
||||
if model != nil && model.Id != s.nameOrId && !strings.EqualFold(model.Name, s.nameOrId) {
|
||||
t.Fatalf("Expected model with identifier %s, got %v", s.nameOrId, model)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFindCachedCollectionByNameOrId(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
totalQueries := 0
|
||||
app.ConcurrentDB().(*dbx.DB).QueryLogFunc = func(ctx context.Context, t time.Duration, sql string, rows *sql.Rows, err error) {
|
||||
totalQueries++
|
||||
}
|
||||
|
||||
run := func(withCache bool) {
|
||||
scenarios := []struct {
|
||||
nameOrId string
|
||||
expectError bool
|
||||
}{
|
||||
{"", true},
|
||||
{"missing", true},
|
||||
{"wsmn24bux7wo113", false},
|
||||
{"demo1", false},
|
||||
{"DEMO1", false}, // case insensitive
|
||||
}
|
||||
|
||||
var expectedTotalQueries int
|
||||
|
||||
if withCache {
|
||||
err := app.ReloadCachedCollections()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
} else {
|
||||
app.Store().Reset(nil)
|
||||
expectedTotalQueries = len(scenarios)
|
||||
}
|
||||
|
||||
totalQueries = 0
|
||||
|
||||
for i, s := range scenarios {
|
||||
t.Run(fmt.Sprintf("%d_%s", i, s.nameOrId), func(t *testing.T) {
|
||||
model, err := app.FindCachedCollectionByNameOrId(s.nameOrId)
|
||||
|
||||
hasErr := err != nil
|
||||
if hasErr != s.expectError {
|
||||
t.Fatalf("Expected hasErr to be %v, got %v (%v)", s.expectError, hasErr, err)
|
||||
}
|
||||
|
||||
if model != nil && model.Id != s.nameOrId && !strings.EqualFold(model.Name, s.nameOrId) {
|
||||
t.Fatalf("Expected model with identifier %s, got %v", s.nameOrId, model)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
if totalQueries != expectedTotalQueries {
|
||||
t.Fatalf("Expected %d totalQueries, got %d", expectedTotalQueries, totalQueries)
|
||||
}
|
||||
}
|
||||
|
||||
run(true)
|
||||
|
||||
run(false)
|
||||
}
|
||||
|
||||
func TestFindCollectionReferences(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
collection, err := app.FindCollectionByNameOrId("demo3")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
result, err := app.FindCollectionReferences(
|
||||
collection,
|
||||
collection.Id,
|
||||
// test whether "nonempty" exclude ids condition will be skipped
|
||||
"",
|
||||
"",
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if len(result) != 1 {
|
||||
t.Fatalf("Expected 1 collection, got %d: %v", len(result), result)
|
||||
}
|
||||
|
||||
expectedFields := []string{
|
||||
"rel_one_no_cascade",
|
||||
"rel_one_no_cascade_required",
|
||||
"rel_one_cascade",
|
||||
"rel_one_unique",
|
||||
"rel_many_no_cascade",
|
||||
"rel_many_no_cascade_required",
|
||||
"rel_many_cascade",
|
||||
"rel_many_unique",
|
||||
}
|
||||
|
||||
for col, fields := range result {
|
||||
if col.Name != "demo4" {
|
||||
t.Fatalf("Expected collection demo4, got %s", col.Name)
|
||||
}
|
||||
if len(fields) != len(expectedFields) {
|
||||
t.Fatalf("Expected fields %v, got %v", expectedFields, fields)
|
||||
}
|
||||
for i, f := range fields {
|
||||
if !slices.Contains(expectedFields, f.GetName()) {
|
||||
t.Fatalf("[%d] Didn't expect field %v", i, f)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestFindCachedCollectionReferences(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
collection, err := app.FindCollectionByNameOrId("demo3")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
totalQueries := 0
|
||||
app.ConcurrentDB().(*dbx.DB).QueryLogFunc = func(ctx context.Context, t time.Duration, sql string, rows *sql.Rows, err error) {
|
||||
totalQueries++
|
||||
}
|
||||
|
||||
run := func(withCache bool) {
|
||||
var expectedTotalQueries int
|
||||
|
||||
if withCache {
|
||||
err := app.ReloadCachedCollections()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
} else {
|
||||
app.Store().Reset(nil)
|
||||
expectedTotalQueries = 1
|
||||
}
|
||||
|
||||
totalQueries = 0
|
||||
|
||||
result, err := app.FindCachedCollectionReferences(
|
||||
collection,
|
||||
collection.Id,
|
||||
// test whether "nonempty" exclude ids condition will be skipped
|
||||
"",
|
||||
"",
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if len(result) != 1 {
|
||||
t.Fatalf("Expected 1 collection, got %d: %v", len(result), result)
|
||||
}
|
||||
|
||||
expectedFields := []string{
|
||||
"rel_one_no_cascade",
|
||||
"rel_one_no_cascade_required",
|
||||
"rel_one_cascade",
|
||||
"rel_one_unique",
|
||||
"rel_many_no_cascade",
|
||||
"rel_many_no_cascade_required",
|
||||
"rel_many_cascade",
|
||||
"rel_many_unique",
|
||||
}
|
||||
|
||||
for col, fields := range result {
|
||||
if col.Name != "demo4" {
|
||||
t.Fatalf("Expected collection demo4, got %s", col.Name)
|
||||
}
|
||||
if len(fields) != len(expectedFields) {
|
||||
t.Fatalf("Expected fields %v, got %v", expectedFields, fields)
|
||||
}
|
||||
for i, f := range fields {
|
||||
if !slices.Contains(expectedFields, f.GetName()) {
|
||||
t.Fatalf("[%d] Didn't expect field %v", i, f)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if totalQueries != expectedTotalQueries {
|
||||
t.Fatalf("Expected %d totalQueries, got %d", expectedTotalQueries, totalQueries)
|
||||
}
|
||||
}
|
||||
|
||||
run(true)
|
||||
|
||||
run(false)
|
||||
}
|
||||
|
||||
func TestIsCollectionNameUnique(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
scenarios := []struct {
|
||||
name string
|
||||
excludeId string
|
||||
expected bool
|
||||
}{
|
||||
{"", "", false},
|
||||
{"demo1", "", false},
|
||||
{"Demo1", "", false},
|
||||
{"new", "", true},
|
||||
{"demo1", "wsmn24bux7wo113", true},
|
||||
}
|
||||
|
||||
for i, s := range scenarios {
|
||||
t.Run(fmt.Sprintf("%d_%s", i, s.name), func(t *testing.T) {
|
||||
result := app.IsCollectionNameUnique(s.name, s.excludeId)
|
||||
if result != s.expected {
|
||||
t.Errorf("Expected %v, got %v", s.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFindCollectionTruncate(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
countFiles := func(collectionId string) (int, error) {
|
||||
entries, err := os.ReadDir(filepath.Join(app.DataDir(), "storage", collectionId))
|
||||
return len(entries), err
|
||||
}
|
||||
|
||||
t.Run("truncate view", func(t *testing.T) {
|
||||
view2, err := app.FindCollectionByNameOrId("view2")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = app.TruncateCollection(view2)
|
||||
if err == nil {
|
||||
t.Fatalf("Expected truncate to fail because view collections can't be truncated")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("truncate failure", func(t *testing.T) {
|
||||
demo3, err := app.FindCollectionByNameOrId("demo3")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
originalTotalRecords, err := app.CountRecords(demo3)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
originalTotalFiles, err := countFiles(demo3.Id)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = app.TruncateCollection(demo3)
|
||||
if err == nil {
|
||||
t.Fatalf("Expected truncate to fail due to cascade delete failed required constraint")
|
||||
}
|
||||
|
||||
// short delay to ensure that the file delete goroutine has been executed
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
totalRecords, err := app.CountRecords(demo3)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if totalRecords != originalTotalRecords {
|
||||
t.Fatalf("Expected %d records, got %d", originalTotalRecords, totalRecords)
|
||||
}
|
||||
|
||||
totalFiles, err := countFiles(demo3.Id)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if totalFiles != originalTotalFiles {
|
||||
t.Fatalf("Expected %d files, got %d", originalTotalFiles, totalFiles)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("truncate success", func(t *testing.T) {
|
||||
demo5, err := app.FindCollectionByNameOrId("demo5")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = app.TruncateCollection(demo5)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// short delay to ensure that the file delete goroutine has been executed
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
total, err := app.CountRecords(demo5)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if total != 0 {
|
||||
t.Fatalf("Expected all records to be deleted, got %v", total)
|
||||
}
|
||||
|
||||
totalFiles, err := countFiles(demo5.Id)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if totalFiles != 0 {
|
||||
t.Fatalf("Expected truncated record files to be deleted, got %d", totalFiles)
|
||||
}
|
||||
|
||||
// try to truncate again (shouldn't return an error)
|
||||
err = app.TruncateCollection(demo5)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
})
|
||||
}
|
364
core/collection_record_table_sync.go
Normal file
364
core/collection_record_table_sync.go
Normal file
|
@ -0,0 +1,364 @@
|
|||
package core
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
validation "github.com/go-ozzo/ozzo-validation/v4"
|
||||
"github.com/pocketbase/dbx"
|
||||
"github.com/pocketbase/pocketbase/tools/dbutils"
|
||||
"github.com/pocketbase/pocketbase/tools/security"
|
||||
)
|
||||
|
||||
// SyncRecordTableSchema compares the two provided collections
|
||||
// and applies the necessary related record table changes.
|
||||
//
|
||||
// If oldCollection is null, then only newCollection is used to create the record table.
|
||||
//
|
||||
// This method is automatically invoked as part of a collection create/update/delete operation.
|
||||
func (app *BaseApp) SyncRecordTableSchema(newCollection *Collection, oldCollection *Collection) error {
|
||||
if newCollection.IsView() {
|
||||
return nil // nothing to sync since views don't have records table
|
||||
}
|
||||
|
||||
txErr := app.RunInTransaction(func(txApp App) error {
|
||||
// create
|
||||
// -----------------------------------------------------------
|
||||
if oldCollection == nil || !app.HasTable(oldCollection.Name) {
|
||||
tableName := newCollection.Name
|
||||
|
||||
fields := newCollection.Fields
|
||||
|
||||
cols := make(map[string]string, len(fields))
|
||||
|
||||
// add fields definition
|
||||
for _, field := range fields {
|
||||
cols[field.GetName()] = field.ColumnType(app)
|
||||
}
|
||||
|
||||
// create table
|
||||
if _, err := txApp.DB().CreateTable(tableName, cols).Execute(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return createCollectionIndexes(txApp, newCollection)
|
||||
}
|
||||
|
||||
// update
|
||||
// -----------------------------------------------------------
|
||||
oldTableName := oldCollection.Name
|
||||
newTableName := newCollection.Name
|
||||
oldFields := oldCollection.Fields
|
||||
newFields := newCollection.Fields
|
||||
|
||||
needTableRename := !strings.EqualFold(oldTableName, newTableName)
|
||||
|
||||
var needIndexesUpdate bool
|
||||
if needTableRename ||
|
||||
oldFields.String() != newFields.String() ||
|
||||
oldCollection.Indexes.String() != newCollection.Indexes.String() {
|
||||
needIndexesUpdate = true
|
||||
}
|
||||
|
||||
if needIndexesUpdate {
|
||||
// drop old indexes (if any)
|
||||
if err := dropCollectionIndexes(txApp, oldCollection); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// check for renamed table
|
||||
if needTableRename {
|
||||
_, err := txApp.DB().RenameTable("{{"+oldTableName+"}}", "{{"+newTableName+"}}").Execute()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// check for deleted columns
|
||||
for _, oldField := range oldFields {
|
||||
if f := newFields.GetById(oldField.GetId()); f != nil {
|
||||
continue // exist
|
||||
}
|
||||
|
||||
_, err := txApp.DB().DropColumn(newTableName, oldField.GetName()).Execute()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to drop column %s - %w", oldField.GetName(), err)
|
||||
}
|
||||
}
|
||||
|
||||
// check for new or renamed columns
|
||||
toRename := map[string]string{}
|
||||
for _, field := range newFields {
|
||||
oldField := oldFields.GetById(field.GetId())
|
||||
// Note:
|
||||
// We are using a temporary column name when adding or renaming columns
|
||||
// to ensure that there are no name collisions in case there is
|
||||
// names switch/reuse of existing columns (eg. name, title -> title, name).
|
||||
// This way we are always doing 1 more rename operation but it provides better less ambiguous experience.
|
||||
|
||||
if oldField == nil {
|
||||
tempName := field.GetName() + security.PseudorandomString(5)
|
||||
toRename[tempName] = field.GetName()
|
||||
|
||||
// add
|
||||
_, err := txApp.DB().AddColumn(newTableName, tempName, field.ColumnType(txApp)).Execute()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to add column %s - %w", field.GetName(), err)
|
||||
}
|
||||
} else if oldField.GetName() != field.GetName() {
|
||||
tempName := field.GetName() + security.PseudorandomString(5)
|
||||
toRename[tempName] = field.GetName()
|
||||
|
||||
// rename
|
||||
_, err := txApp.DB().RenameColumn(newTableName, oldField.GetName(), tempName).Execute()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to rename column %s - %w", oldField.GetName(), err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// set the actual columns name
|
||||
for tempName, actualName := range toRename {
|
||||
_, err := txApp.DB().RenameColumn(newTableName, tempName, actualName).Execute()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if err := normalizeSingleVsMultipleFieldChanges(txApp, newCollection, oldCollection); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if needIndexesUpdate {
|
||||
return createCollectionIndexes(txApp, newCollection)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
if txErr != nil {
|
||||
return txErr
|
||||
}
|
||||
|
||||
// run optimize per the SQLite recommendations
|
||||
// (https://www.sqlite.org/pragma.html#pragma_optimize)
|
||||
_, optimizeErr := app.ConcurrentDB().NewQuery("PRAGMA optimize").Execute()
|
||||
if optimizeErr != nil {
|
||||
app.Logger().Warn("Failed to run PRAGMA optimize after record table sync", slog.String("error", optimizeErr.Error()))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func normalizeSingleVsMultipleFieldChanges(app App, newCollection *Collection, oldCollection *Collection) error {
|
||||
if newCollection.IsView() || oldCollection == nil {
|
||||
return nil // view or not an update
|
||||
}
|
||||
|
||||
return app.RunInTransaction(func(txApp App) error {
|
||||
for _, newField := range newCollection.Fields {
|
||||
// allow to continue even if there is no old field for the cases
|
||||
// when a new field is added and there are already inserted data
|
||||
var isOldMultiple bool
|
||||
if oldField := oldCollection.Fields.GetById(newField.GetId()); oldField != nil {
|
||||
if mv, ok := oldField.(MultiValuer); ok {
|
||||
isOldMultiple = mv.IsMultiple()
|
||||
}
|
||||
}
|
||||
|
||||
var isNewMultiple bool
|
||||
if mv, ok := newField.(MultiValuer); ok {
|
||||
isNewMultiple = mv.IsMultiple()
|
||||
}
|
||||
|
||||
if isOldMultiple == isNewMultiple {
|
||||
continue // no change
|
||||
}
|
||||
|
||||
// -------------------------------------------------------
|
||||
// update the field column definition
|
||||
// -------------------------------------------------------
|
||||
|
||||
// temporary drop all views to prevent reference errors during the columns renaming
|
||||
// (this is used as an "alternative" to the writable_schema PRAGMA)
|
||||
views := []struct {
|
||||
Name string `db:"name"`
|
||||
SQL string `db:"sql"`
|
||||
}{}
|
||||
err := txApp.DB().Select("name", "sql").
|
||||
From("sqlite_master").
|
||||
AndWhere(dbx.NewExp("sql is not null")).
|
||||
AndWhere(dbx.HashExp{"type": "view"}).
|
||||
All(&views)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, view := range views {
|
||||
err = txApp.DeleteView(view.Name)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
originalName := newField.GetName()
|
||||
oldTempName := "_" + newField.GetName() + security.PseudorandomString(5)
|
||||
|
||||
// rename temporary the original column to something else to allow inserting a new one in its place
|
||||
_, err = txApp.DB().RenameColumn(newCollection.Name, originalName, oldTempName).Execute()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// reinsert the field column with the new type
|
||||
_, err = txApp.DB().AddColumn(newCollection.Name, originalName, newField.ColumnType(txApp)).Execute()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var copyQuery *dbx.Query
|
||||
|
||||
if !isOldMultiple && isNewMultiple {
|
||||
// single -> multiple (convert to array)
|
||||
copyQuery = txApp.DB().NewQuery(fmt.Sprintf(
|
||||
`UPDATE {{%s}} set [[%s]] = (
|
||||
CASE
|
||||
WHEN COALESCE([[%s]], '') = ''
|
||||
THEN '[]'
|
||||
ELSE (
|
||||
CASE
|
||||
WHEN json_valid([[%s]]) AND json_type([[%s]]) == 'array'
|
||||
THEN [[%s]]
|
||||
ELSE json_array([[%s]])
|
||||
END
|
||||
)
|
||||
END
|
||||
)`,
|
||||
newCollection.Name,
|
||||
originalName,
|
||||
oldTempName,
|
||||
oldTempName,
|
||||
oldTempName,
|
||||
oldTempName,
|
||||
oldTempName,
|
||||
))
|
||||
} else {
|
||||
// multiple -> single (keep only the last element)
|
||||
//
|
||||
// note: for file fields the actual file objects are not
|
||||
// deleted allowing additional custom handling via migration
|
||||
copyQuery = txApp.DB().NewQuery(fmt.Sprintf(
|
||||
`UPDATE {{%s}} set [[%s]] = (
|
||||
CASE
|
||||
WHEN COALESCE([[%s]], '[]') = '[]'
|
||||
THEN ''
|
||||
ELSE (
|
||||
CASE
|
||||
WHEN json_valid([[%s]]) AND json_type([[%s]]) == 'array'
|
||||
THEN COALESCE(json_extract([[%s]], '$[#-1]'), '')
|
||||
ELSE [[%s]]
|
||||
END
|
||||
)
|
||||
END
|
||||
)`,
|
||||
newCollection.Name,
|
||||
originalName,
|
||||
oldTempName,
|
||||
oldTempName,
|
||||
oldTempName,
|
||||
oldTempName,
|
||||
oldTempName,
|
||||
))
|
||||
}
|
||||
|
||||
// copy the normalized values
|
||||
_, err = copyQuery.Execute()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// drop the original column
|
||||
_, err = txApp.DB().DropColumn(newCollection.Name, oldTempName).Execute()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// restore views
|
||||
for _, view := range views {
|
||||
_, err = txApp.DB().NewQuery(view.SQL).Execute()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func dropCollectionIndexes(app App, collection *Collection) error {
|
||||
if collection.IsView() {
|
||||
return nil // views don't have indexes
|
||||
}
|
||||
|
||||
return app.RunInTransaction(func(txApp App) error {
|
||||
for _, raw := range collection.Indexes {
|
||||
parsed := dbutils.ParseIndex(raw)
|
||||
|
||||
if !parsed.IsValid() {
|
||||
continue
|
||||
}
|
||||
|
||||
_, err := txApp.DB().NewQuery(fmt.Sprintf("DROP INDEX IF EXISTS [[%s]]", parsed.IndexName)).Execute()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func createCollectionIndexes(app App, collection *Collection) error {
|
||||
if collection.IsView() {
|
||||
return nil // views don't have indexes
|
||||
}
|
||||
|
||||
return app.RunInTransaction(func(txApp App) error {
|
||||
// upsert new indexes
|
||||
//
|
||||
// note: we are returning validation errors because the indexes cannot be
|
||||
// easily validated in a form, aka. before persisting the related
|
||||
// collection record table changes
|
||||
errs := validation.Errors{}
|
||||
for i, idx := range collection.Indexes {
|
||||
parsed := dbutils.ParseIndex(idx)
|
||||
|
||||
// ensure that the index is always for the current collection
|
||||
parsed.TableName = collection.Name
|
||||
|
||||
if !parsed.IsValid() {
|
||||
errs[strconv.Itoa(i)] = validation.NewError(
|
||||
"validation_invalid_index_expression",
|
||||
"Invalid CREATE INDEX expression.",
|
||||
)
|
||||
continue
|
||||
}
|
||||
|
||||
if _, err := txApp.DB().NewQuery(parsed.Build()).Execute(); err != nil {
|
||||
errs[strconv.Itoa(i)] = validation.NewError(
|
||||
"validation_invalid_index_expression",
|
||||
fmt.Sprintf("Failed to create index %s - %v.", parsed.IndexName, err.Error()),
|
||||
)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
if len(errs) > 0 {
|
||||
return validation.Errors{"indexes": errs}
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
296
core/collection_record_table_sync_test.go
Normal file
296
core/collection_record_table_sync_test.go
Normal file
|
@ -0,0 +1,296 @@
|
|||
package core_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/pocketbase/dbx"
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tests"
|
||||
"github.com/pocketbase/pocketbase/tools/list"
|
||||
"github.com/pocketbase/pocketbase/tools/types"
|
||||
)
|
||||
|
||||
func TestSyncRecordTableSchema(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
oldCollection, err := app.FindCollectionByNameOrId("demo2")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
updatedCollection, err := app.FindCollectionByNameOrId("demo2")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
updatedCollection.Name = "demo_renamed"
|
||||
updatedCollection.Fields.RemoveByName("active")
|
||||
updatedCollection.Fields.Add(&core.EmailField{
|
||||
Name: "new_field",
|
||||
})
|
||||
updatedCollection.Fields.Add(&core.EmailField{
|
||||
Id: updatedCollection.Fields.GetByName("title").GetId(),
|
||||
Name: "title_renamed",
|
||||
})
|
||||
updatedCollection.Indexes = types.JSONArray[string]{"create index idx_title_renamed on anything (title_renamed)"}
|
||||
|
||||
baseCol := core.NewBaseCollection("new_base")
|
||||
baseCol.Fields.Add(&core.TextField{Name: "test"})
|
||||
|
||||
authCol := core.NewAuthCollection("new_auth")
|
||||
authCol.Fields.Add(&core.TextField{Name: "test"})
|
||||
authCol.AddIndex("idx_auth_test", false, "email, id", "")
|
||||
|
||||
scenarios := []struct {
|
||||
name string
|
||||
newCollection *core.Collection
|
||||
oldCollection *core.Collection
|
||||
expectedColumns []string
|
||||
expectedIndexesCount int
|
||||
}{
|
||||
{
|
||||
"new base collection",
|
||||
baseCol,
|
||||
nil,
|
||||
[]string{"id", "test"},
|
||||
0,
|
||||
},
|
||||
{
|
||||
"new auth collection",
|
||||
authCol,
|
||||
nil,
|
||||
[]string{
|
||||
"id", "test", "email", "verified",
|
||||
"emailVisibility", "tokenKey", "password",
|
||||
},
|
||||
3,
|
||||
},
|
||||
{
|
||||
"no changes",
|
||||
oldCollection,
|
||||
oldCollection,
|
||||
[]string{"id", "created", "updated", "title", "active"},
|
||||
3,
|
||||
},
|
||||
{
|
||||
"renamed table, deleted column, renamed columnd and new column",
|
||||
updatedCollection,
|
||||
oldCollection,
|
||||
[]string{"id", "created", "updated", "title_renamed", "new_field"},
|
||||
1,
|
||||
},
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
t.Run(s.name, func(t *testing.T) {
|
||||
err := app.SyncRecordTableSchema(s.newCollection, s.oldCollection)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if !app.HasTable(s.newCollection.Name) {
|
||||
t.Fatalf("Expected table %s to exist", s.newCollection.Name)
|
||||
}
|
||||
|
||||
cols, _ := app.TableColumns(s.newCollection.Name)
|
||||
if len(cols) != len(s.expectedColumns) {
|
||||
t.Fatalf("Expected columns %v, got %v", s.expectedColumns, cols)
|
||||
}
|
||||
|
||||
for _, col := range cols {
|
||||
if !list.ExistInSlice(col, s.expectedColumns) {
|
||||
t.Fatalf("Couldn't find column %s in %v", col, s.expectedColumns)
|
||||
}
|
||||
}
|
||||
|
||||
indexes, _ := app.TableIndexes(s.newCollection.Name)
|
||||
|
||||
if totalIndexes := len(indexes); totalIndexes != s.expectedIndexesCount {
|
||||
t.Fatalf("Expected %d indexes, got %d:\n%v", s.expectedIndexesCount, totalIndexes, indexes)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func getTotalViews(app core.App) (int, error) {
|
||||
var total int
|
||||
|
||||
err := app.DB().Select("count(*)").
|
||||
From("sqlite_master").
|
||||
AndWhere(dbx.NewExp("sql is not null")).
|
||||
AndWhere(dbx.HashExp{"type": "view"}).
|
||||
Row(&total)
|
||||
|
||||
return total, err
|
||||
}
|
||||
|
||||
func TestSingleVsMultipleValuesNormalization(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
collection, err := app.FindCollectionByNameOrId("demo1")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
beforeTotalViews, err := getTotalViews(app)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// mock field changes
|
||||
collection.Fields.GetByName("select_one").(*core.SelectField).MaxSelect = 2
|
||||
collection.Fields.GetByName("select_many").(*core.SelectField).MaxSelect = 1
|
||||
collection.Fields.GetByName("file_one").(*core.FileField).MaxSelect = 2
|
||||
collection.Fields.GetByName("file_many").(*core.FileField).MaxSelect = 1
|
||||
collection.Fields.GetByName("rel_one").(*core.RelationField).MaxSelect = 2
|
||||
collection.Fields.GetByName("rel_many").(*core.RelationField).MaxSelect = 1
|
||||
|
||||
// new multivaluer field to check whether the array normalization
|
||||
// will be applied for already inserted data
|
||||
collection.Fields.Add(&core.SelectField{
|
||||
Name: "new_multiple",
|
||||
Values: []string{"a", "b", "c"},
|
||||
MaxSelect: 3,
|
||||
})
|
||||
|
||||
if err := app.Save(collection); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// ensure that the views were reinserted
|
||||
afterTotalViews, err := getTotalViews(app)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if afterTotalViews != beforeTotalViews {
|
||||
t.Fatalf("Expected total views %d, got %d", beforeTotalViews, afterTotalViews)
|
||||
}
|
||||
|
||||
// check whether the columns DEFAULT definition was updated
|
||||
// ---------------------------------------------------------------
|
||||
tableInfo, err := app.TableInfo(collection.Name)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
tableInfoExpectations := map[string]string{
|
||||
"select_one": `'[]'`,
|
||||
"select_many": `''`,
|
||||
"file_one": `'[]'`,
|
||||
"file_many": `''`,
|
||||
"rel_one": `'[]'`,
|
||||
"rel_many": `''`,
|
||||
"new_multiple": `'[]'`,
|
||||
}
|
||||
for col, dflt := range tableInfoExpectations {
|
||||
t.Run("check default for "+col, func(t *testing.T) {
|
||||
var row *core.TableInfoRow
|
||||
for _, r := range tableInfo {
|
||||
if r.Name == col {
|
||||
row = r
|
||||
break
|
||||
}
|
||||
}
|
||||
if row == nil {
|
||||
t.Fatalf("Missing info for column %q", col)
|
||||
}
|
||||
|
||||
if v := row.DefaultValue.String; v != dflt {
|
||||
t.Fatalf("Expected default value %q, got %q", dflt, v)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// check whether the values were normalized
|
||||
// ---------------------------------------------------------------
|
||||
type fieldsExpectation struct {
|
||||
SelectOne string `db:"select_one"`
|
||||
SelectMany string `db:"select_many"`
|
||||
FileOne string `db:"file_one"`
|
||||
FileMany string `db:"file_many"`
|
||||
RelOne string `db:"rel_one"`
|
||||
RelMany string `db:"rel_many"`
|
||||
NewMultiple string `db:"new_multiple"`
|
||||
}
|
||||
|
||||
fieldsScenarios := []struct {
|
||||
recordId string
|
||||
expected fieldsExpectation
|
||||
}{
|
||||
{
|
||||
"imy661ixudk5izi",
|
||||
fieldsExpectation{
|
||||
SelectOne: `[]`,
|
||||
SelectMany: ``,
|
||||
FileOne: `[]`,
|
||||
FileMany: ``,
|
||||
RelOne: `[]`,
|
||||
RelMany: ``,
|
||||
NewMultiple: `[]`,
|
||||
},
|
||||
},
|
||||
{
|
||||
"al1h9ijdeojtsjy",
|
||||
fieldsExpectation{
|
||||
SelectOne: `["optionB"]`,
|
||||
SelectMany: `optionB`,
|
||||
FileOne: `["300_Jsjq7RdBgA.png"]`,
|
||||
FileMany: ``,
|
||||
RelOne: `["84nmscqy84lsi1t"]`,
|
||||
RelMany: `oap640cot4yru2s`,
|
||||
NewMultiple: `[]`,
|
||||
},
|
||||
},
|
||||
{
|
||||
"84nmscqy84lsi1t",
|
||||
fieldsExpectation{
|
||||
SelectOne: `["optionB"]`,
|
||||
SelectMany: `optionC`,
|
||||
FileOne: `["test_d61b33QdDU.txt"]`,
|
||||
FileMany: `test_tC1Yc87DfC.txt`,
|
||||
RelOne: `[]`,
|
||||
RelMany: `oap640cot4yru2s`,
|
||||
NewMultiple: `[]`,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, s := range fieldsScenarios {
|
||||
t.Run("check fields for record "+s.recordId, func(t *testing.T) {
|
||||
result := new(fieldsExpectation)
|
||||
|
||||
err := app.DB().Select(
|
||||
"select_one",
|
||||
"select_many",
|
||||
"file_one",
|
||||
"file_many",
|
||||
"rel_one",
|
||||
"rel_many",
|
||||
"new_multiple",
|
||||
).From(collection.Name).Where(dbx.HashExp{"id": s.recordId}).One(result)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to load record: %v", err)
|
||||
}
|
||||
|
||||
encodedResult, err := json.Marshal(result)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to encode result: %v", err)
|
||||
}
|
||||
|
||||
encodedExpectation, err := json.Marshal(s.expected)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to encode expectation: %v", err)
|
||||
}
|
||||
|
||||
if !bytes.EqualFold(encodedExpectation, encodedResult) {
|
||||
t.Fatalf("Expected \n%s, \ngot \n%s", encodedExpectation, encodedResult)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
694
core/collection_validate.go
Normal file
694
core/collection_validate.go
Normal file
|
@ -0,0 +1,694 @@
|
|||
package core
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
validation "github.com/go-ozzo/ozzo-validation/v4"
|
||||
"github.com/pocketbase/dbx"
|
||||
"github.com/pocketbase/pocketbase/core/validators"
|
||||
"github.com/pocketbase/pocketbase/tools/dbutils"
|
||||
"github.com/pocketbase/pocketbase/tools/list"
|
||||
"github.com/pocketbase/pocketbase/tools/search"
|
||||
"github.com/pocketbase/pocketbase/tools/types"
|
||||
)
|
||||
|
||||
var collectionNameRegex = regexp.MustCompile(`^\w+$`)
|
||||
|
||||
func onCollectionValidate(e *CollectionEvent) error {
|
||||
var original *Collection
|
||||
if !e.Collection.IsNew() {
|
||||
original = &Collection{}
|
||||
if err := e.App.ModelQuery(original).Model(e.Collection.LastSavedPK(), original); err != nil {
|
||||
return fmt.Errorf("failed to fetch old collection state: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
validator := newCollectionValidator(
|
||||
e.Context,
|
||||
e.App,
|
||||
e.Collection,
|
||||
original,
|
||||
)
|
||||
|
||||
if err := validator.run(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return e.Next()
|
||||
}
|
||||
|
||||
func newCollectionValidator(ctx context.Context, app App, new, original *Collection) *collectionValidator {
|
||||
validator := &collectionValidator{
|
||||
ctx: ctx,
|
||||
app: app,
|
||||
new: new,
|
||||
original: original,
|
||||
}
|
||||
|
||||
// load old/original collection
|
||||
if validator.original == nil {
|
||||
validator.original = NewCollection(validator.new.Type, "")
|
||||
}
|
||||
|
||||
return validator
|
||||
}
|
||||
|
||||
type collectionValidator struct {
|
||||
original *Collection
|
||||
new *Collection
|
||||
app App
|
||||
ctx context.Context
|
||||
}
|
||||
|
||||
type optionsValidator interface {
|
||||
validate(cv *collectionValidator) error
|
||||
}
|
||||
|
||||
func (validator *collectionValidator) run() error {
|
||||
if validator.original.IsNew() {
|
||||
validator.new.updateGeneratedIdIfExists(validator.app)
|
||||
}
|
||||
|
||||
// generate fields from the query (overwriting any explicit user defined fields)
|
||||
if validator.new.IsView() {
|
||||
validator.new.Fields, _ = validator.app.CreateViewFields(validator.new.ViewQuery)
|
||||
}
|
||||
|
||||
// validate base fields
|
||||
baseErr := validation.ValidateStruct(validator.new,
|
||||
validation.Field(
|
||||
&validator.new.Id,
|
||||
validation.Required,
|
||||
validation.When(
|
||||
validator.original.IsNew(),
|
||||
validation.Length(1, 100),
|
||||
validation.Match(DefaultIdRegex),
|
||||
validation.By(validators.UniqueId(validator.app.ConcurrentDB(), validator.new.TableName())),
|
||||
).Else(
|
||||
validation.By(validators.Equal(validator.original.Id)),
|
||||
),
|
||||
),
|
||||
validation.Field(
|
||||
&validator.new.System,
|
||||
validation.By(validator.ensureNoSystemFlagChange),
|
||||
),
|
||||
validation.Field(
|
||||
&validator.new.Type,
|
||||
validation.Required,
|
||||
validation.In(
|
||||
CollectionTypeBase,
|
||||
CollectionTypeAuth,
|
||||
CollectionTypeView,
|
||||
),
|
||||
validation.By(validator.ensureNoTypeChange),
|
||||
),
|
||||
validation.Field(
|
||||
&validator.new.Name,
|
||||
validation.Required,
|
||||
validation.Length(1, 255),
|
||||
validation.By(checkForVia),
|
||||
validation.Match(collectionNameRegex),
|
||||
validation.By(validator.ensureNoSystemNameChange),
|
||||
validation.By(validator.checkUniqueName),
|
||||
),
|
||||
validation.Field(
|
||||
&validator.new.Fields,
|
||||
validation.By(validator.checkFieldDuplicates),
|
||||
validation.By(validator.checkMinFields),
|
||||
validation.When(
|
||||
!validator.new.IsView(),
|
||||
validation.By(validator.ensureNoSystemFieldsChange),
|
||||
validation.By(validator.ensureNoFieldsTypeChange),
|
||||
),
|
||||
validation.When(validator.new.IsAuth(), validation.By(validator.checkReservedAuthKeys)),
|
||||
validation.By(validator.checkFieldValidators),
|
||||
),
|
||||
validation.Field(
|
||||
&validator.new.ListRule,
|
||||
validation.By(validator.checkRule),
|
||||
validation.By(validator.ensureNoSystemRuleChange(validator.original.ListRule)),
|
||||
),
|
||||
validation.Field(
|
||||
&validator.new.ViewRule,
|
||||
validation.By(validator.checkRule),
|
||||
validation.By(validator.ensureNoSystemRuleChange(validator.original.ViewRule)),
|
||||
),
|
||||
validation.Field(
|
||||
&validator.new.CreateRule,
|
||||
validation.When(validator.new.IsView(), validation.Nil),
|
||||
validation.By(validator.checkRule),
|
||||
validation.By(validator.ensureNoSystemRuleChange(validator.original.CreateRule)),
|
||||
),
|
||||
validation.Field(
|
||||
&validator.new.UpdateRule,
|
||||
validation.When(validator.new.IsView(), validation.Nil),
|
||||
validation.By(validator.checkRule),
|
||||
validation.By(validator.ensureNoSystemRuleChange(validator.original.UpdateRule)),
|
||||
),
|
||||
validation.Field(
|
||||
&validator.new.DeleteRule,
|
||||
validation.When(validator.new.IsView(), validation.Nil),
|
||||
validation.By(validator.checkRule),
|
||||
validation.By(validator.ensureNoSystemRuleChange(validator.original.DeleteRule)),
|
||||
),
|
||||
validation.Field(&validator.new.Indexes, validation.By(validator.checkIndexes)),
|
||||
)
|
||||
|
||||
optionsErr := validator.validateOptions()
|
||||
|
||||
return validators.JoinValidationErrors(baseErr, optionsErr)
|
||||
}
|
||||
|
||||
func (validator *collectionValidator) checkUniqueName(value any) error {
|
||||
v, _ := value.(string)
|
||||
|
||||
// ensure unique collection name
|
||||
if !validator.app.IsCollectionNameUnique(v, validator.original.Id) {
|
||||
return validation.NewError("validation_collection_name_exists", "Collection name must be unique (case insensitive).")
|
||||
}
|
||||
|
||||
// ensure that the collection name doesn't collide with the id of any collection
|
||||
dummyCollection := &Collection{}
|
||||
if validator.app.ModelQuery(dummyCollection).Model(v, dummyCollection) == nil {
|
||||
return validation.NewError("validation_collection_name_id_duplicate", "The name must not match an existing collection id.")
|
||||
}
|
||||
|
||||
// ensure that there is no existing internal table with the provided name
|
||||
if validator.original.Name != v && // has changed
|
||||
validator.app.IsCollectionNameUnique(v) && // is not a collection (in case it was presaved)
|
||||
validator.app.HasTable(v) {
|
||||
return validation.NewError("validation_collection_name_invalid", "The name shouldn't match with an existing internal table.")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (validator *collectionValidator) ensureNoSystemNameChange(value any) error {
|
||||
v, _ := value.(string)
|
||||
|
||||
if !validator.original.IsNew() && validator.original.System && v != validator.original.Name {
|
||||
return validation.NewError("validation_collection_system_name_change", "System collection name cannot be changed.")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (validator *collectionValidator) ensureNoSystemFlagChange(value any) error {
|
||||
v, _ := value.(bool)
|
||||
|
||||
if !validator.original.IsNew() && v != validator.original.System {
|
||||
return validation.NewError("validation_collection_system_flag_change", "System collection state cannot be changed.")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (validator *collectionValidator) ensureNoTypeChange(value any) error {
|
||||
v, _ := value.(string)
|
||||
|
||||
if !validator.original.IsNew() && v != validator.original.Type {
|
||||
return validation.NewError("validation_collection_type_change", "Collection type cannot be changed.")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (validator *collectionValidator) ensureNoFieldsTypeChange(value any) error {
|
||||
v, ok := value.(FieldsList)
|
||||
if !ok {
|
||||
return validators.ErrUnsupportedValueType
|
||||
}
|
||||
|
||||
errs := validation.Errors{}
|
||||
|
||||
for i, field := range v {
|
||||
oldField := validator.original.Fields.GetById(field.GetId())
|
||||
|
||||
if oldField != nil && oldField.Type() != field.Type() {
|
||||
errs[strconv.Itoa(i)] = validation.NewError(
|
||||
"validation_field_type_change",
|
||||
"Field type cannot be changed.",
|
||||
)
|
||||
}
|
||||
}
|
||||
if len(errs) > 0 {
|
||||
return errs
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (validator *collectionValidator) checkFieldDuplicates(value any) error {
|
||||
fields, ok := value.(FieldsList)
|
||||
if !ok {
|
||||
return validators.ErrUnsupportedValueType
|
||||
}
|
||||
|
||||
totalFields := len(fields)
|
||||
ids := make([]string, 0, totalFields)
|
||||
names := make([]string, 0, totalFields)
|
||||
|
||||
for i, field := range fields {
|
||||
if list.ExistInSlice(field.GetId(), ids) {
|
||||
return validation.Errors{
|
||||
strconv.Itoa(i): validation.Errors{
|
||||
"id": validation.NewError(
|
||||
"validation_duplicated_field_id",
|
||||
fmt.Sprintf("Duplicated or invalid field id %q", field.GetId()),
|
||||
),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// field names are used as db columns and should be case insensitive
|
||||
nameLower := strings.ToLower(field.GetName())
|
||||
|
||||
if list.ExistInSlice(nameLower, names) {
|
||||
return validation.Errors{
|
||||
strconv.Itoa(i): validation.Errors{
|
||||
"name": validation.NewError(
|
||||
"validation_duplicated_field_name",
|
||||
"Duplicated or invalid field name {{.fieldName}}",
|
||||
).SetParams(map[string]any{
|
||||
"fieldName": field.GetName(),
|
||||
}),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
ids = append(ids, field.GetId())
|
||||
names = append(names, nameLower)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (validator *collectionValidator) checkFieldValidators(value any) error {
|
||||
fields, ok := value.(FieldsList)
|
||||
if !ok {
|
||||
return validators.ErrUnsupportedValueType
|
||||
}
|
||||
|
||||
errs := validation.Errors{}
|
||||
|
||||
for i, field := range fields {
|
||||
if err := field.ValidateSettings(validator.ctx, validator.app, validator.new); err != nil {
|
||||
errs[strconv.Itoa(i)] = err
|
||||
}
|
||||
}
|
||||
|
||||
if len(errs) > 0 {
|
||||
return errs
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (cv *collectionValidator) checkViewQuery(value any) error {
|
||||
v, _ := value.(string)
|
||||
if v == "" {
|
||||
return nil // nothing to check
|
||||
}
|
||||
|
||||
if _, err := cv.app.CreateViewFields(v); err != nil {
|
||||
return validation.NewError(
|
||||
"validation_invalid_view_query",
|
||||
fmt.Sprintf("Invalid query - %s", err.Error()),
|
||||
)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
var reservedAuthKeys = []string{"passwordConfirm", "oldPassword"}
|
||||
|
||||
func (cv *collectionValidator) checkReservedAuthKeys(value any) error {
|
||||
fields, ok := value.(FieldsList)
|
||||
if !ok {
|
||||
return validators.ErrUnsupportedValueType
|
||||
}
|
||||
|
||||
if !cv.new.IsAuth() {
|
||||
return nil // not an auth collection
|
||||
}
|
||||
|
||||
errs := validation.Errors{}
|
||||
for i, field := range fields {
|
||||
if list.ExistInSlice(field.GetName(), reservedAuthKeys) {
|
||||
errs[strconv.Itoa(i)] = validation.Errors{
|
||||
"name": validation.NewError(
|
||||
"validation_reserved_field_name",
|
||||
"The field name is reserved and cannot be used.",
|
||||
),
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(errs) > 0 {
|
||||
return errs
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (cv *collectionValidator) checkMinFields(value any) error {
|
||||
fields, ok := value.(FieldsList)
|
||||
if !ok {
|
||||
return validators.ErrUnsupportedValueType
|
||||
}
|
||||
|
||||
if len(fields) == 0 {
|
||||
return validation.ErrRequired
|
||||
}
|
||||
|
||||
// all collections must have an "id" PK field
|
||||
idField, _ := fields.GetByName(FieldNameId).(*TextField)
|
||||
if idField == nil || !idField.PrimaryKey {
|
||||
return validation.NewError("validation_missing_primary_key", `Missing or invalid "id" PK field.`)
|
||||
}
|
||||
|
||||
switch cv.new.Type {
|
||||
case CollectionTypeAuth:
|
||||
passwordField, _ := fields.GetByName(FieldNamePassword).(*PasswordField)
|
||||
if passwordField == nil {
|
||||
return validation.NewError("validation_missing_password_field", `System "password" field is required.`)
|
||||
}
|
||||
if !passwordField.Hidden || !passwordField.System {
|
||||
return validation.Errors{FieldNamePassword: ErrMustBeSystemAndHidden}
|
||||
}
|
||||
|
||||
tokenKeyField, _ := fields.GetByName(FieldNameTokenKey).(*TextField)
|
||||
if tokenKeyField == nil {
|
||||
return validation.NewError("validation_missing_tokenKey_field", `System "tokenKey" field is required.`)
|
||||
}
|
||||
if !tokenKeyField.Hidden || !tokenKeyField.System {
|
||||
return validation.Errors{FieldNameTokenKey: ErrMustBeSystemAndHidden}
|
||||
}
|
||||
|
||||
emailField, _ := fields.GetByName(FieldNameEmail).(*EmailField)
|
||||
if emailField == nil {
|
||||
return validation.NewError("validation_missing_email_field", `System "email" field is required.`)
|
||||
}
|
||||
if !emailField.System {
|
||||
return validation.Errors{FieldNameEmail: ErrMustBeSystem}
|
||||
}
|
||||
|
||||
emailVisibilityField, _ := fields.GetByName(FieldNameEmailVisibility).(*BoolField)
|
||||
if emailVisibilityField == nil {
|
||||
return validation.NewError("validation_missing_emailVisibility_field", `System "emailVisibility" field is required.`)
|
||||
}
|
||||
if !emailVisibilityField.System {
|
||||
return validation.Errors{FieldNameEmailVisibility: ErrMustBeSystem}
|
||||
}
|
||||
|
||||
verifiedField, _ := fields.GetByName(FieldNameVerified).(*BoolField)
|
||||
if verifiedField == nil {
|
||||
return validation.NewError("validation_missing_verified_field", `System "verified" field is required.`)
|
||||
}
|
||||
if !verifiedField.System {
|
||||
return validation.Errors{FieldNameVerified: ErrMustBeSystem}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (validator *collectionValidator) ensureNoSystemFieldsChange(value any) error {
|
||||
fields, ok := value.(FieldsList)
|
||||
if !ok {
|
||||
return validators.ErrUnsupportedValueType
|
||||
}
|
||||
|
||||
if validator.original.IsNew() {
|
||||
return nil // not an update
|
||||
}
|
||||
|
||||
for _, oldField := range validator.original.Fields {
|
||||
if !oldField.GetSystem() {
|
||||
continue
|
||||
}
|
||||
|
||||
newField := fields.GetById(oldField.GetId())
|
||||
|
||||
if newField == nil || oldField.GetName() != newField.GetName() {
|
||||
return validation.NewError("validation_system_field_change", "System fields cannot be deleted or renamed.")
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (cv *collectionValidator) checkFieldsForUniqueIndex(value any) error {
|
||||
names, ok := value.([]string)
|
||||
if !ok {
|
||||
return validators.ErrUnsupportedValueType
|
||||
}
|
||||
|
||||
if len(names) == 0 {
|
||||
return nil // nothing to check
|
||||
}
|
||||
|
||||
for _, name := range names {
|
||||
field := cv.new.Fields.GetByName(name)
|
||||
if field == nil {
|
||||
return validation.NewError("validation_missing_field", "Invalid or missing field {{.fieldName}}").
|
||||
SetParams(map[string]any{"fieldName": name})
|
||||
}
|
||||
|
||||
if _, ok := dbutils.FindSingleColumnUniqueIndex(cv.new.Indexes, name); !ok {
|
||||
return validation.NewError("validation_missing_unique_constraint", "The field {{.fieldName}} doesn't have a UNIQUE constraint.").
|
||||
SetParams(map[string]any{"fieldName": name})
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// note: value could be either *string or string
|
||||
func (validator *collectionValidator) checkRule(value any) error {
|
||||
var vStr string
|
||||
|
||||
v, ok := value.(*string)
|
||||
if ok {
|
||||
if v != nil {
|
||||
vStr = *v
|
||||
}
|
||||
} else {
|
||||
vStr, ok = value.(string)
|
||||
}
|
||||
if !ok {
|
||||
return validators.ErrUnsupportedValueType
|
||||
}
|
||||
|
||||
if vStr == "" {
|
||||
return nil // nothing to check
|
||||
}
|
||||
|
||||
r := NewRecordFieldResolver(validator.app, validator.new, nil, true)
|
||||
_, err := search.FilterData(vStr).BuildExpr(r)
|
||||
if err != nil {
|
||||
return validation.NewError("validation_invalid_rule", "Invalid rule. Raw error: "+err.Error())
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (validator *collectionValidator) ensureNoSystemRuleChange(oldRule *string) validation.RuleFunc {
|
||||
return func(value any) error {
|
||||
if validator.original.IsNew() || !validator.original.System {
|
||||
return nil // not an update of a system collection
|
||||
}
|
||||
|
||||
rule, ok := value.(*string)
|
||||
if !ok {
|
||||
return validators.ErrUnsupportedValueType
|
||||
}
|
||||
|
||||
if (rule == nil && oldRule == nil) ||
|
||||
(rule != nil && oldRule != nil && *rule == *oldRule) {
|
||||
return nil
|
||||
}
|
||||
|
||||
return validation.NewError("validation_collection_system_rule_change", "System collection API rule cannot be changed.")
|
||||
}
|
||||
}
|
||||
|
||||
func (cv *collectionValidator) checkIndexes(value any) error {
|
||||
indexes, _ := value.(types.JSONArray[string])
|
||||
|
||||
if cv.new.IsView() && len(indexes) > 0 {
|
||||
return validation.NewError(
|
||||
"validation_indexes_not_supported",
|
||||
"View collections don't support indexes.",
|
||||
)
|
||||
}
|
||||
|
||||
duplicatedNames := make(map[string]struct{}, len(indexes))
|
||||
duplicatedDefinitions := make(map[string]struct{}, len(indexes))
|
||||
|
||||
for i, rawIndex := range indexes {
|
||||
parsed := dbutils.ParseIndex(rawIndex)
|
||||
|
||||
// always set a table name because it is ignored anyway in order to keep it in sync with the collection name
|
||||
parsed.TableName = "validator"
|
||||
|
||||
if !parsed.IsValid() {
|
||||
return validation.Errors{
|
||||
strconv.Itoa(i): validation.NewError(
|
||||
"validation_invalid_index_expression",
|
||||
"Invalid CREATE INDEX expression.",
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
if _, isDuplicated := duplicatedNames[strings.ToLower(parsed.IndexName)]; isDuplicated {
|
||||
return validation.Errors{
|
||||
strconv.Itoa(i): validation.NewError(
|
||||
"validation_duplicated_index_name",
|
||||
"The index name already exists.",
|
||||
),
|
||||
}
|
||||
}
|
||||
duplicatedNames[strings.ToLower(parsed.IndexName)] = struct{}{}
|
||||
|
||||
// ensure that the index name is not used in another collection
|
||||
var usedTblName string
|
||||
_ = cv.app.ConcurrentDB().Select("tbl_name").
|
||||
From("sqlite_master").
|
||||
AndWhere(dbx.HashExp{"type": "index"}).
|
||||
AndWhere(dbx.NewExp("LOWER([[tbl_name]])!=LOWER({:oldName})", dbx.Params{"oldName": cv.original.Name})).
|
||||
AndWhere(dbx.NewExp("LOWER([[tbl_name]])!=LOWER({:newName})", dbx.Params{"newName": cv.new.Name})).
|
||||
AndWhere(dbx.NewExp("LOWER([[name]])=LOWER({:indexName})", dbx.Params{"indexName": parsed.IndexName})).
|
||||
Limit(1).
|
||||
Row(&usedTblName)
|
||||
if usedTblName != "" {
|
||||
return validation.Errors{
|
||||
strconv.Itoa(i): validation.NewError(
|
||||
"validation_existing_index_name",
|
||||
"The index name is already used in {{.usedTableName}} collection.",
|
||||
).SetParams(map[string]any{"usedTableName": usedTblName}),
|
||||
}
|
||||
}
|
||||
|
||||
// reset non-important identifiers
|
||||
parsed.SchemaName = "validator"
|
||||
parsed.IndexName = "validator"
|
||||
parsedDef := parsed.Build()
|
||||
|
||||
if _, isDuplicated := duplicatedDefinitions[parsedDef]; isDuplicated {
|
||||
return validation.Errors{
|
||||
strconv.Itoa(i): validation.NewError(
|
||||
"validation_duplicated_index_definition",
|
||||
"The index definition already exists.",
|
||||
),
|
||||
}
|
||||
}
|
||||
duplicatedDefinitions[parsedDef] = struct{}{}
|
||||
|
||||
// note: we don't check the index table name because it is always
|
||||
// overwritten by the SyncRecordTableSchema to allow
|
||||
// easier partial modifications (eg. changing only the collection name).
|
||||
// if !strings.EqualFold(parsed.TableName, form.Name) {
|
||||
// return validation.Errors{
|
||||
// strconv.Itoa(i): validation.NewError(
|
||||
// "validation_invalid_index_table",
|
||||
// fmt.Sprintf("The index table must be the same as the collection name."),
|
||||
// ),
|
||||
// }
|
||||
// }
|
||||
}
|
||||
|
||||
// ensure that unique indexes on system fields are not changed or removed
|
||||
if !cv.original.IsNew() {
|
||||
OLD_INDEXES_LOOP:
|
||||
for _, oldIndex := range cv.original.Indexes {
|
||||
oldParsed := dbutils.ParseIndex(oldIndex)
|
||||
if !oldParsed.Unique {
|
||||
continue
|
||||
}
|
||||
|
||||
// reset collate and sort since they are not important for the unique constraint
|
||||
for i := range oldParsed.Columns {
|
||||
oldParsed.Columns[i].Collate = ""
|
||||
oldParsed.Columns[i].Sort = ""
|
||||
}
|
||||
|
||||
oldParsedStr := oldParsed.Build()
|
||||
|
||||
for _, column := range oldParsed.Columns {
|
||||
for _, f := range cv.original.Fields {
|
||||
if !f.GetSystem() || !strings.EqualFold(column.Name, f.GetName()) {
|
||||
continue
|
||||
}
|
||||
|
||||
var hasMatch bool
|
||||
for _, newIndex := range cv.new.Indexes {
|
||||
newParsed := dbutils.ParseIndex(newIndex)
|
||||
|
||||
// exclude the non-important identifiers from the check
|
||||
newParsed.SchemaName = oldParsed.SchemaName
|
||||
newParsed.IndexName = oldParsed.IndexName
|
||||
newParsed.TableName = oldParsed.TableName
|
||||
|
||||
// exclude partial constraints
|
||||
newParsed.Where = oldParsed.Where
|
||||
|
||||
// reset collate and sort
|
||||
for i := range newParsed.Columns {
|
||||
newParsed.Columns[i].Collate = ""
|
||||
newParsed.Columns[i].Sort = ""
|
||||
}
|
||||
|
||||
if oldParsedStr == newParsed.Build() {
|
||||
hasMatch = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !hasMatch {
|
||||
return validation.NewError(
|
||||
"validation_invalid_unique_system_field_index",
|
||||
"Unique index definition on system fields ({{.fieldName}}) is invalid or missing.",
|
||||
).SetParams(map[string]any{"fieldName": f.GetName()})
|
||||
}
|
||||
|
||||
continue OLD_INDEXES_LOOP
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// check for required indexes
|
||||
//
|
||||
// note: this is in case the indexes were removed manually when creating/importing new auth collections
|
||||
// and technically it is not necessary because on app.Save() the missing indexes will be reinserted by the system collection hook
|
||||
if cv.new.IsAuth() {
|
||||
requiredNames := []string{FieldNameTokenKey, FieldNameEmail}
|
||||
for _, name := range requiredNames {
|
||||
if _, ok := dbutils.FindSingleColumnUniqueIndex(indexes, name); !ok {
|
||||
return validation.NewError(
|
||||
"validation_missing_required_unique_index",
|
||||
`Missing required unique index for field "{{.fieldName}}".`,
|
||||
).SetParams(map[string]any{"fieldName": name})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (validator *collectionValidator) validateOptions() error {
|
||||
switch validator.new.Type {
|
||||
case CollectionTypeAuth:
|
||||
return validator.new.collectionAuthOptions.validate(validator)
|
||||
case CollectionTypeView:
|
||||
return validator.new.collectionViewOptions.validate(validator)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
909
core/collection_validate_test.go
Normal file
909
core/collection_validate_test.go
Normal file
|
@ -0,0 +1,909 @@
|
|||
package core_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tests"
|
||||
"github.com/pocketbase/pocketbase/tools/types"
|
||||
)
|
||||
|
||||
func TestCollectionValidate(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
scenarios := []struct {
|
||||
name string
|
||||
collection func(app core.App) (*core.Collection, error)
|
||||
expectedErrors []string
|
||||
}{
|
||||
{
|
||||
name: "empty collection",
|
||||
collection: func(app core.App) (*core.Collection, error) {
|
||||
return &core.Collection{}, nil
|
||||
},
|
||||
expectedErrors: []string{
|
||||
"id", "name", "type", "fields", // no default fields because the type is unknown
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "unknown type with all invalid fields",
|
||||
collection: func(app core.App) (*core.Collection, error) {
|
||||
c := &core.Collection{}
|
||||
c.Id = "invalid_id ?!@#$"
|
||||
c.Name = "invalid_name ?!@#$"
|
||||
c.Type = "invalid_type"
|
||||
c.ListRule = types.Pointer("missing = '123'")
|
||||
c.ViewRule = types.Pointer("missing = '123'")
|
||||
c.CreateRule = types.Pointer("missing = '123'")
|
||||
c.UpdateRule = types.Pointer("missing = '123'")
|
||||
c.DeleteRule = types.Pointer("missing = '123'")
|
||||
c.Indexes = []string{"create index '' on '' ()"}
|
||||
|
||||
// type specific fields
|
||||
c.ViewQuery = "invalid" // should be ignored
|
||||
c.AuthRule = types.Pointer("missing = '123'") // should be ignored
|
||||
|
||||
return c, nil
|
||||
},
|
||||
expectedErrors: []string{
|
||||
"id", "name", "type", "indexes",
|
||||
"listRule", "viewRule", "createRule", "updateRule", "deleteRule",
|
||||
"fields", // no default fields because the type is unknown
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "base with invalid fields",
|
||||
collection: func(app core.App) (*core.Collection, error) {
|
||||
c := core.NewBaseCollection("invalid_name ?!@#$")
|
||||
c.Indexes = []string{"create index '' on '' ()"}
|
||||
|
||||
// type specific fields
|
||||
c.ViewQuery = "invalid" // should be ignored
|
||||
c.AuthRule = types.Pointer("missing = '123'") // should be ignored
|
||||
|
||||
return c, nil
|
||||
},
|
||||
expectedErrors: []string{"name", "indexes"},
|
||||
},
|
||||
{
|
||||
name: "view with invalid fields",
|
||||
collection: func(app core.App) (*core.Collection, error) {
|
||||
c := core.NewViewCollection("invalid_name ?!@#$")
|
||||
c.Indexes = []string{"create index '' on '' ()"}
|
||||
|
||||
// type specific fields
|
||||
c.ViewQuery = "invalid"
|
||||
c.AuthRule = types.Pointer("missing = '123'") // should be ignored
|
||||
|
||||
return c, nil
|
||||
},
|
||||
expectedErrors: []string{"indexes", "name", "fields", "viewQuery"},
|
||||
},
|
||||
{
|
||||
name: "auth with invalid fields",
|
||||
collection: func(app core.App) (*core.Collection, error) {
|
||||
c := core.NewAuthCollection("invalid_name ?!@#$")
|
||||
c.Indexes = []string{"create index '' on '' ()"}
|
||||
|
||||
// type specific fields
|
||||
c.ViewQuery = "invalid" // should be ignored
|
||||
c.AuthRule = types.Pointer("missing = '123'")
|
||||
|
||||
return c, nil
|
||||
},
|
||||
expectedErrors: []string{"indexes", "name", "authRule"},
|
||||
},
|
||||
|
||||
// type checks
|
||||
{
|
||||
name: "empty type",
|
||||
collection: func(app core.App) (*core.Collection, error) {
|
||||
c := core.NewBaseCollection("test")
|
||||
c.Type = ""
|
||||
return c, nil
|
||||
},
|
||||
expectedErrors: []string{"type"},
|
||||
},
|
||||
{
|
||||
name: "unknown type",
|
||||
collection: func(app core.App) (*core.Collection, error) {
|
||||
c := core.NewBaseCollection("test")
|
||||
c.Type = "unknown"
|
||||
return c, nil
|
||||
},
|
||||
expectedErrors: []string{"type"},
|
||||
},
|
||||
{
|
||||
name: "base type",
|
||||
collection: func(app core.App) (*core.Collection, error) {
|
||||
c := core.NewBaseCollection("test")
|
||||
return c, nil
|
||||
},
|
||||
expectedErrors: []string{},
|
||||
},
|
||||
{
|
||||
name: "view type",
|
||||
collection: func(app core.App) (*core.Collection, error) {
|
||||
c := core.NewViewCollection("test")
|
||||
c.ViewQuery = "select 1 as id"
|
||||
return c, nil
|
||||
},
|
||||
expectedErrors: []string{},
|
||||
},
|
||||
{
|
||||
name: "auth type",
|
||||
collection: func(app core.App) (*core.Collection, error) {
|
||||
c := core.NewAuthCollection("test")
|
||||
return c, nil
|
||||
},
|
||||
expectedErrors: []string{},
|
||||
},
|
||||
{
|
||||
name: "changing type",
|
||||
collection: func(app core.App) (*core.Collection, error) {
|
||||
c, _ := app.FindCollectionByNameOrId("users")
|
||||
c.Type = core.CollectionTypeBase
|
||||
return c, nil
|
||||
},
|
||||
expectedErrors: []string{"type"},
|
||||
},
|
||||
|
||||
// system checks
|
||||
{
|
||||
name: "change from system to regular",
|
||||
collection: func(app core.App) (*core.Collection, error) {
|
||||
c, _ := app.FindCollectionByNameOrId(core.CollectionNameSuperusers)
|
||||
c.System = false
|
||||
return c, nil
|
||||
},
|
||||
expectedErrors: []string{"system"},
|
||||
},
|
||||
{
|
||||
name: "change from regular to system",
|
||||
collection: func(app core.App) (*core.Collection, error) {
|
||||
c, _ := app.FindCollectionByNameOrId("demo1")
|
||||
c.System = true
|
||||
return c, nil
|
||||
},
|
||||
expectedErrors: []string{"system"},
|
||||
},
|
||||
{
|
||||
name: "create system",
|
||||
collection: func(app core.App) (*core.Collection, error) {
|
||||
c := core.NewBaseCollection("new_system")
|
||||
c.System = true
|
||||
return c, nil
|
||||
},
|
||||
expectedErrors: []string{},
|
||||
},
|
||||
|
||||
// id checks
|
||||
{
|
||||
name: "empty id",
|
||||
collection: func(app core.App) (*core.Collection, error) {
|
||||
c := core.NewBaseCollection("test")
|
||||
c.Id = ""
|
||||
return c, nil
|
||||
},
|
||||
expectedErrors: []string{"id"},
|
||||
},
|
||||
{
|
||||
name: "invalid id",
|
||||
collection: func(app core.App) (*core.Collection, error) {
|
||||
c := core.NewBaseCollection("test")
|
||||
c.Id = "!invalid"
|
||||
return c, nil
|
||||
},
|
||||
expectedErrors: []string{"id"},
|
||||
},
|
||||
{
|
||||
name: "existing id",
|
||||
collection: func(app core.App) (*core.Collection, error) {
|
||||
c := core.NewBaseCollection("test")
|
||||
c.Id = "_pb_users_auth_"
|
||||
return c, nil
|
||||
},
|
||||
expectedErrors: []string{"id"},
|
||||
},
|
||||
{
|
||||
name: "changing id",
|
||||
collection: func(app core.App) (*core.Collection, error) {
|
||||
c, _ := app.FindCollectionByNameOrId("demo3")
|
||||
c.Id = "anything"
|
||||
return c, nil
|
||||
},
|
||||
expectedErrors: []string{"id"},
|
||||
},
|
||||
{
|
||||
name: "valid id",
|
||||
collection: func(app core.App) (*core.Collection, error) {
|
||||
c := core.NewBaseCollection("test")
|
||||
c.Id = "anything"
|
||||
return c, nil
|
||||
},
|
||||
expectedErrors: []string{},
|
||||
},
|
||||
|
||||
// name checks
|
||||
{
|
||||
name: "empty name",
|
||||
collection: func(app core.App) (*core.Collection, error) {
|
||||
c := core.NewBaseCollection("")
|
||||
c.Id = "test"
|
||||
return c, nil
|
||||
},
|
||||
expectedErrors: []string{"name"},
|
||||
},
|
||||
{
|
||||
name: "invalid name",
|
||||
collection: func(app core.App) (*core.Collection, error) {
|
||||
c := core.NewBaseCollection("!invalid")
|
||||
return c, nil
|
||||
},
|
||||
expectedErrors: []string{"name"},
|
||||
},
|
||||
{
|
||||
name: "name with _via_",
|
||||
collection: func(app core.App) (*core.Collection, error) {
|
||||
c := core.NewBaseCollection("a_via_b")
|
||||
return c, nil
|
||||
},
|
||||
expectedErrors: []string{"name"},
|
||||
},
|
||||
{
|
||||
name: "create with existing collection name",
|
||||
collection: func(app core.App) (*core.Collection, error) {
|
||||
c := core.NewBaseCollection("demo1")
|
||||
return c, nil
|
||||
},
|
||||
expectedErrors: []string{"name"},
|
||||
},
|
||||
{
|
||||
name: "create with existing internal table name",
|
||||
collection: func(app core.App) (*core.Collection, error) {
|
||||
c := core.NewBaseCollection("_collections")
|
||||
return c, nil
|
||||
},
|
||||
expectedErrors: []string{"name"},
|
||||
},
|
||||
{
|
||||
name: "update with existing collection name",
|
||||
collection: func(app core.App) (*core.Collection, error) {
|
||||
c, _ := app.FindCollectionByNameOrId("users")
|
||||
c.Name = "demo1"
|
||||
return c, nil
|
||||
},
|
||||
expectedErrors: []string{"name"},
|
||||
},
|
||||
{
|
||||
name: "update with existing internal table name",
|
||||
collection: func(app core.App) (*core.Collection, error) {
|
||||
c, _ := app.FindCollectionByNameOrId("users")
|
||||
c.Name = "_collections"
|
||||
return c, nil
|
||||
},
|
||||
expectedErrors: []string{"name"},
|
||||
},
|
||||
{
|
||||
name: "system collection name change",
|
||||
collection: func(app core.App) (*core.Collection, error) {
|
||||
c, _ := app.FindCollectionByNameOrId(core.CollectionNameSuperusers)
|
||||
c.Name = "superusers_new"
|
||||
return c, nil
|
||||
},
|
||||
expectedErrors: []string{"name"},
|
||||
},
|
||||
{
|
||||
name: "create with valid name",
|
||||
collection: func(app core.App) (*core.Collection, error) {
|
||||
c := core.NewBaseCollection("new_col")
|
||||
return c, nil
|
||||
},
|
||||
expectedErrors: []string{},
|
||||
},
|
||||
{
|
||||
name: "update with valid name",
|
||||
collection: func(app core.App) (*core.Collection, error) {
|
||||
c, _ := app.FindCollectionByNameOrId("demo1")
|
||||
c.Name = "demo1_new"
|
||||
return c, nil
|
||||
},
|
||||
expectedErrors: []string{},
|
||||
},
|
||||
|
||||
// rule checks
|
||||
{
|
||||
name: "invalid base rules",
|
||||
collection: func(app core.App) (*core.Collection, error) {
|
||||
c := core.NewBaseCollection("new")
|
||||
c.ListRule = types.Pointer("!invalid")
|
||||
c.ViewRule = types.Pointer("missing = 123")
|
||||
c.CreateRule = types.Pointer("id = 123 && missing = 456")
|
||||
c.UpdateRule = types.Pointer("(id = 123")
|
||||
c.DeleteRule = types.Pointer("missing = 123")
|
||||
return c, nil
|
||||
},
|
||||
expectedErrors: []string{"listRule", "viewRule", "createRule", "updateRule", "deleteRule"},
|
||||
},
|
||||
{
|
||||
name: "valid base rules",
|
||||
collection: func(app core.App) (*core.Collection, error) {
|
||||
c := core.NewBaseCollection("new")
|
||||
c.Fields.Add(&core.TextField{Name: "f1"}) // dummy field to ensure that new fields can be referenced
|
||||
c.ListRule = types.Pointer("")
|
||||
c.ViewRule = types.Pointer("f1 = 123")
|
||||
c.CreateRule = types.Pointer("id = 123 && f1 = 456")
|
||||
c.UpdateRule = types.Pointer("(id = 123)")
|
||||
c.DeleteRule = types.Pointer("f1 = 123")
|
||||
return c, nil
|
||||
},
|
||||
expectedErrors: []string{},
|
||||
},
|
||||
{
|
||||
name: "view with non-nil create/update/delete rules",
|
||||
collection: func(app core.App) (*core.Collection, error) {
|
||||
c := core.NewViewCollection("new")
|
||||
c.ViewQuery = "select 1 as id, 'text' as f1"
|
||||
c.ListRule = types.Pointer("id = 123")
|
||||
c.ViewRule = types.Pointer("f1 = 456")
|
||||
c.CreateRule = types.Pointer("")
|
||||
c.UpdateRule = types.Pointer("")
|
||||
c.DeleteRule = types.Pointer("")
|
||||
return c, nil
|
||||
},
|
||||
expectedErrors: []string{"createRule", "updateRule", "deleteRule"},
|
||||
},
|
||||
{
|
||||
name: "view with nil create/update/delete rules",
|
||||
collection: func(app core.App) (*core.Collection, error) {
|
||||
c := core.NewViewCollection("new")
|
||||
c.ViewQuery = "select 1 as id, 'text' as f1"
|
||||
c.ListRule = types.Pointer("id = 1")
|
||||
c.ViewRule = types.Pointer("f1 = 456")
|
||||
return c, nil
|
||||
},
|
||||
expectedErrors: []string{},
|
||||
},
|
||||
{
|
||||
name: "changing api rules",
|
||||
collection: func(app core.App) (*core.Collection, error) {
|
||||
c, _ := app.FindCollectionByNameOrId("users")
|
||||
c.Fields.Add(&core.TextField{Name: "f1"}) // dummy field to ensure that new fields can be referenced
|
||||
c.ListRule = types.Pointer("id = 1")
|
||||
c.ViewRule = types.Pointer("f1 = 456")
|
||||
c.CreateRule = types.Pointer("id = 123 && f1 = 456")
|
||||
c.UpdateRule = types.Pointer("(id = 123)")
|
||||
c.DeleteRule = types.Pointer("f1 = 123")
|
||||
return c, nil
|
||||
},
|
||||
expectedErrors: []string{},
|
||||
},
|
||||
{
|
||||
name: "changing system collection api rules",
|
||||
collection: func(app core.App) (*core.Collection, error) {
|
||||
c, _ := app.FindCollectionByNameOrId(core.CollectionNameSuperusers)
|
||||
c.ListRule = types.Pointer("1 = 1")
|
||||
c.ViewRule = types.Pointer("1 = 1")
|
||||
c.CreateRule = types.Pointer("1 = 1")
|
||||
c.UpdateRule = types.Pointer("1 = 1")
|
||||
c.DeleteRule = types.Pointer("1 = 1")
|
||||
c.ManageRule = types.Pointer("1 = 1")
|
||||
c.AuthRule = types.Pointer("1 = 1")
|
||||
return c, nil
|
||||
},
|
||||
expectedErrors: []string{
|
||||
"listRule", "viewRule", "createRule", "updateRule",
|
||||
"deleteRule", "manageRule", "authRule",
|
||||
},
|
||||
},
|
||||
|
||||
// indexes checks
|
||||
{
|
||||
name: "invalid index expression",
|
||||
collection: func(app core.App) (*core.Collection, error) {
|
||||
c, _ := app.FindCollectionByNameOrId("demo1")
|
||||
c.Indexes = []string{
|
||||
"create index invalid",
|
||||
"create index idx_test_demo2 on anything (text)", // the name of table shouldn't matter
|
||||
}
|
||||
return c, nil
|
||||
},
|
||||
expectedErrors: []string{"indexes"},
|
||||
},
|
||||
{
|
||||
name: "index name used in other table",
|
||||
collection: func(app core.App) (*core.Collection, error) {
|
||||
c, _ := app.FindCollectionByNameOrId("demo1")
|
||||
c.Indexes = []string{
|
||||
"create index `idx_test_demo1` on demo1 (id)",
|
||||
"create index `__pb_USERS_auth__username_idx` on anything (text)", // should be case-insensitive
|
||||
}
|
||||
return c, nil
|
||||
},
|
||||
expectedErrors: []string{"indexes"},
|
||||
},
|
||||
{
|
||||
name: "duplicated index names",
|
||||
collection: func(app core.App) (*core.Collection, error) {
|
||||
c, _ := app.FindCollectionByNameOrId("demo1")
|
||||
c.Indexes = []string{
|
||||
"create index idx_test_demo1 on demo1 (id)",
|
||||
"create index idx_test_demo1 on anything (text)",
|
||||
}
|
||||
return c, nil
|
||||
},
|
||||
expectedErrors: []string{"indexes"},
|
||||
},
|
||||
{
|
||||
name: "duplicated index definitions",
|
||||
collection: func(app core.App) (*core.Collection, error) {
|
||||
c, _ := app.FindCollectionByNameOrId("demo1")
|
||||
c.Indexes = []string{
|
||||
"create index idx_test_demo1 on demo1 (id)",
|
||||
"create index idx_test_demo2 on demo1 (id)",
|
||||
}
|
||||
return c, nil
|
||||
},
|
||||
expectedErrors: []string{"indexes"},
|
||||
},
|
||||
{
|
||||
name: "try to add index to a view collection",
|
||||
collection: func(app core.App) (*core.Collection, error) {
|
||||
c, _ := app.FindCollectionByNameOrId("view1")
|
||||
c.Indexes = []string{"create index idx_test_view1 on view1 (id)"}
|
||||
return c, nil
|
||||
},
|
||||
expectedErrors: []string{"indexes"},
|
||||
},
|
||||
{
|
||||
name: "replace old with new indexes",
|
||||
collection: func(app core.App) (*core.Collection, error) {
|
||||
c, _ := app.FindCollectionByNameOrId("demo1")
|
||||
c.Indexes = []string{
|
||||
"create index idx_test_demo1 on demo1 (id)",
|
||||
"create index idx_test_demo2 on anything (text)", // the name of table shouldn't matter
|
||||
}
|
||||
return c, nil
|
||||
},
|
||||
expectedErrors: []string{},
|
||||
},
|
||||
{
|
||||
name: "old + new indexes",
|
||||
collection: func(app core.App) (*core.Collection, error) {
|
||||
c, _ := app.FindCollectionByNameOrId("demo1")
|
||||
c.Indexes = []string{
|
||||
"CREATE INDEX `_wsmn24bux7wo113_created_idx` ON `demo1` (`created`)",
|
||||
"create index idx_test_demo1 on anything (id)",
|
||||
}
|
||||
return c, nil
|
||||
},
|
||||
expectedErrors: []string{},
|
||||
},
|
||||
{
|
||||
name: "index for missing field",
|
||||
collection: func(app core.App) (*core.Collection, error) {
|
||||
c, _ := app.FindCollectionByNameOrId("demo1")
|
||||
c.Indexes = []string{
|
||||
"create index idx_test_demo1 on anything (missing)", // still valid because it is checked on db persist
|
||||
}
|
||||
return c, nil
|
||||
},
|
||||
expectedErrors: []string{},
|
||||
},
|
||||
{
|
||||
name: "auth collection with missing required unique indexes",
|
||||
collection: func(app core.App) (*core.Collection, error) {
|
||||
c := core.NewAuthCollection("new_auth")
|
||||
c.Indexes = []string{}
|
||||
return c, nil
|
||||
},
|
||||
expectedErrors: []string{"indexes", "passwordAuth"},
|
||||
},
|
||||
{
|
||||
name: "auth collection with non-unique required indexes",
|
||||
collection: func(app core.App) (*core.Collection, error) {
|
||||
c := core.NewAuthCollection("new_auth")
|
||||
c.Indexes = []string{
|
||||
"create index test_idx1 on new_auth (tokenKey)",
|
||||
"create index test_idx2 on new_auth (email)",
|
||||
}
|
||||
return c, nil
|
||||
},
|
||||
expectedErrors: []string{"indexes", "passwordAuth"},
|
||||
},
|
||||
{
|
||||
name: "auth collection with unique required indexes",
|
||||
collection: func(app core.App) (*core.Collection, error) {
|
||||
c := core.NewAuthCollection("new_auth")
|
||||
c.Indexes = []string{
|
||||
"create unique index test_idx1 on new_auth (tokenKey)",
|
||||
"create unique index test_idx2 on new_auth (email)",
|
||||
}
|
||||
return c, nil
|
||||
},
|
||||
expectedErrors: []string{},
|
||||
},
|
||||
{
|
||||
name: "removing index on system field",
|
||||
collection: func(app core.App) (*core.Collection, error) {
|
||||
demo2, err := app.FindCollectionByNameOrId("demo2")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// mark the title field as system
|
||||
demo2.Fields.GetByName("title").SetSystem(true)
|
||||
if err = app.Save(demo2); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// refresh
|
||||
demo2, err = app.FindCollectionByNameOrId("demo2")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
demo2.RemoveIndex("idx_unique_demo2_title")
|
||||
|
||||
return demo2, nil
|
||||
},
|
||||
expectedErrors: []string{"indexes"},
|
||||
},
|
||||
{
|
||||
name: "changing partial constraint of existing index on system field",
|
||||
collection: func(app core.App) (*core.Collection, error) {
|
||||
demo2, err := app.FindCollectionByNameOrId("demo2")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// mark the title field as system
|
||||
demo2.Fields.GetByName("title").SetSystem(true)
|
||||
if err = app.Save(demo2); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// refresh
|
||||
demo2, err = app.FindCollectionByNameOrId("demo2")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// replace the index with a partial one
|
||||
demo2.RemoveIndex("idx_unique_demo2_title")
|
||||
demo2.AddIndex("idx_new_demo2_title", true, "title", "1 = 1")
|
||||
|
||||
return demo2, nil
|
||||
},
|
||||
expectedErrors: []string{},
|
||||
},
|
||||
{
|
||||
name: "changing column sort and collate of existing index on system field",
|
||||
collection: func(app core.App) (*core.Collection, error) {
|
||||
demo2, err := app.FindCollectionByNameOrId("demo2")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// mark the title field as system
|
||||
demo2.Fields.GetByName("title").SetSystem(true)
|
||||
if err = app.Save(demo2); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// refresh
|
||||
demo2, err = app.FindCollectionByNameOrId("demo2")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// replace the index with a new one for the same column but with collate and sort
|
||||
demo2.RemoveIndex("idx_unique_demo2_title")
|
||||
demo2.AddIndex("idx_new_demo2_title", true, "title COLLATE test ASC", "")
|
||||
|
||||
return demo2, nil
|
||||
},
|
||||
expectedErrors: []string{},
|
||||
},
|
||||
{
|
||||
name: "adding new column to index on system field",
|
||||
collection: func(app core.App) (*core.Collection, error) {
|
||||
demo2, err := app.FindCollectionByNameOrId("demo2")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// mark the title field as system
|
||||
demo2.Fields.GetByName("title").SetSystem(true)
|
||||
if err = app.Save(demo2); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// refresh
|
||||
demo2, err = app.FindCollectionByNameOrId("demo2")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// replace the index with a non-unique one
|
||||
demo2.RemoveIndex("idx_unique_demo2_title")
|
||||
demo2.AddIndex("idx_new_title", false, "title, id", "")
|
||||
|
||||
return demo2, nil
|
||||
},
|
||||
expectedErrors: []string{"indexes"},
|
||||
},
|
||||
{
|
||||
name: "changing index type on system field",
|
||||
collection: func(app core.App) (*core.Collection, error) {
|
||||
demo2, err := app.FindCollectionByNameOrId("demo2")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// mark the title field as system
|
||||
demo2.Fields.GetByName("title").SetSystem(true)
|
||||
if err = app.Save(demo2); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// refresh
|
||||
demo2, err = app.FindCollectionByNameOrId("demo2")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// replace the index with a non-unique one (partial constraints are ignored)
|
||||
demo2.RemoveIndex("idx_unique_demo2_title")
|
||||
demo2.AddIndex("idx_new_title", false, "title", "1=1")
|
||||
|
||||
return demo2, nil
|
||||
},
|
||||
expectedErrors: []string{"indexes"},
|
||||
},
|
||||
{
|
||||
name: "changing index on non-system field",
|
||||
collection: func(app core.App) (*core.Collection, error) {
|
||||
demo2, err := app.FindCollectionByNameOrId("demo2")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// replace the index with a partial one
|
||||
demo2.RemoveIndex("idx_demo2_active")
|
||||
demo2.AddIndex("idx_demo2_active", true, "active", "1 = 1")
|
||||
|
||||
return demo2, nil
|
||||
},
|
||||
expectedErrors: []string{},
|
||||
},
|
||||
|
||||
// fields list checks
|
||||
{
|
||||
name: "empty fields",
|
||||
collection: func(app core.App) (*core.Collection, error) {
|
||||
c := core.NewBaseCollection("new_auth")
|
||||
c.Fields = nil // the minimum fields should auto added
|
||||
return c, nil
|
||||
},
|
||||
expectedErrors: []string{"fields"},
|
||||
},
|
||||
{
|
||||
name: "no id primay key field",
|
||||
collection: func(app core.App) (*core.Collection, error) {
|
||||
c := core.NewBaseCollection("new_auth")
|
||||
c.Fields = core.NewFieldsList(
|
||||
&core.TextField{Name: "id"},
|
||||
)
|
||||
return c, nil
|
||||
},
|
||||
expectedErrors: []string{"fields"},
|
||||
},
|
||||
{
|
||||
name: "with id primay key field",
|
||||
collection: func(app core.App) (*core.Collection, error) {
|
||||
c := core.NewBaseCollection("new_auth")
|
||||
c.Fields = core.NewFieldsList(
|
||||
&core.TextField{Name: "id", PrimaryKey: true, Required: true, Pattern: `\w+`},
|
||||
)
|
||||
return c, nil
|
||||
},
|
||||
expectedErrors: []string{},
|
||||
},
|
||||
{
|
||||
name: "duplicated field names",
|
||||
collection: func(app core.App) (*core.Collection, error) {
|
||||
c := core.NewBaseCollection("new_auth")
|
||||
c.Fields = core.NewFieldsList(
|
||||
&core.TextField{Name: "id", PrimaryKey: true, Required: true, Pattern: `\w+`},
|
||||
&core.TextField{Id: "f1", Name: "Test"}, // case-insensitive
|
||||
&core.BoolField{Id: "f2", Name: "test"},
|
||||
)
|
||||
return c, nil
|
||||
},
|
||||
expectedErrors: []string{"fields"},
|
||||
},
|
||||
{
|
||||
name: "changing field type",
|
||||
collection: func(app core.App) (*core.Collection, error) {
|
||||
c, _ := app.FindCollectionByNameOrId("demo1")
|
||||
f := c.Fields.GetByName("text")
|
||||
c.Fields.Add(&core.BoolField{Id: f.GetId(), Name: f.GetName()})
|
||||
return c, nil
|
||||
},
|
||||
expectedErrors: []string{"fields"},
|
||||
},
|
||||
{
|
||||
name: "renaming system field",
|
||||
collection: func(app core.App) (*core.Collection, error) {
|
||||
c, _ := app.FindCollectionByNameOrId(core.CollectionNameAuthOrigins)
|
||||
f := c.Fields.GetByName("fingerprint")
|
||||
f.SetName("fingerprint_new")
|
||||
return c, nil
|
||||
},
|
||||
expectedErrors: []string{"fields"},
|
||||
},
|
||||
{
|
||||
name: "deleting system field",
|
||||
collection: func(app core.App) (*core.Collection, error) {
|
||||
c, _ := app.FindCollectionByNameOrId(core.CollectionNameAuthOrigins)
|
||||
c.Fields.RemoveByName("fingerprint")
|
||||
return c, nil
|
||||
},
|
||||
expectedErrors: []string{"fields"},
|
||||
},
|
||||
{
|
||||
name: "invalid field setting",
|
||||
collection: func(app core.App) (*core.Collection, error) {
|
||||
c := core.NewBaseCollection("test_new")
|
||||
c.Fields.Add(&core.TextField{Name: "f1", Min: -10})
|
||||
return c, nil
|
||||
},
|
||||
expectedErrors: []string{"fields"},
|
||||
},
|
||||
{
|
||||
name: "valid field setting",
|
||||
collection: func(app core.App) (*core.Collection, error) {
|
||||
c := core.NewBaseCollection("test_new")
|
||||
c.Fields.Add(&core.TextField{Name: "f1", Min: 10})
|
||||
return c, nil
|
||||
},
|
||||
expectedErrors: []string{},
|
||||
},
|
||||
{
|
||||
name: "fields view changes should be ignored",
|
||||
collection: func(app core.App) (*core.Collection, error) {
|
||||
c, _ := app.FindCollectionByNameOrId("view1")
|
||||
c.Fields = nil
|
||||
return c, nil
|
||||
},
|
||||
expectedErrors: []string{},
|
||||
},
|
||||
{
|
||||
name: "with reserved auth only field name (passwordConfirm)",
|
||||
collection: func(app core.App) (*core.Collection, error) {
|
||||
c := core.NewAuthCollection("new_auth")
|
||||
c.Fields.Add(
|
||||
&core.TextField{Name: "passwordConfirm"},
|
||||
)
|
||||
return c, nil
|
||||
},
|
||||
expectedErrors: []string{"fields"},
|
||||
},
|
||||
{
|
||||
name: "with reserved auth only field name (oldPassword)",
|
||||
collection: func(app core.App) (*core.Collection, error) {
|
||||
c := core.NewAuthCollection("new_auth")
|
||||
c.Fields.Add(
|
||||
&core.TextField{Name: "oldPassword"},
|
||||
)
|
||||
return c, nil
|
||||
},
|
||||
expectedErrors: []string{"fields"},
|
||||
},
|
||||
{
|
||||
name: "with invalid password auth field options (1)",
|
||||
collection: func(app core.App) (*core.Collection, error) {
|
||||
c := core.NewAuthCollection("new_auth")
|
||||
c.Fields.Add(
|
||||
&core.TextField{Name: "password", System: true, Hidden: true}, // should be PasswordField
|
||||
)
|
||||
return c, nil
|
||||
},
|
||||
expectedErrors: []string{"fields"},
|
||||
},
|
||||
{
|
||||
name: "with valid password auth field options (2)",
|
||||
collection: func(app core.App) (*core.Collection, error) {
|
||||
c := core.NewAuthCollection("new_auth")
|
||||
c.Fields.Add(
|
||||
&core.PasswordField{Name: "password", System: true, Hidden: true},
|
||||
)
|
||||
return c, nil
|
||||
},
|
||||
expectedErrors: []string{},
|
||||
},
|
||||
{
|
||||
name: "with invalid tokenKey auth field options (1)",
|
||||
collection: func(app core.App) (*core.Collection, error) {
|
||||
c := core.NewAuthCollection("new_auth")
|
||||
c.Fields.Add(
|
||||
&core.TextField{Name: "tokenKey", System: true}, // should be also hidden
|
||||
)
|
||||
return c, nil
|
||||
},
|
||||
expectedErrors: []string{"fields"},
|
||||
},
|
||||
{
|
||||
name: "with valid tokenKey auth field options (2)",
|
||||
collection: func(app core.App) (*core.Collection, error) {
|
||||
c := core.NewAuthCollection("new_auth")
|
||||
c.Fields.Add(
|
||||
&core.TextField{Name: "tokenKey", System: true, Hidden: true},
|
||||
)
|
||||
return c, nil
|
||||
},
|
||||
expectedErrors: []string{},
|
||||
},
|
||||
{
|
||||
name: "with invalid email auth field options (1)",
|
||||
collection: func(app core.App) (*core.Collection, error) {
|
||||
c := core.NewAuthCollection("new_auth")
|
||||
c.Fields.Add(
|
||||
&core.TextField{Name: "email", System: true}, // should be EmailField
|
||||
)
|
||||
return c, nil
|
||||
},
|
||||
expectedErrors: []string{"fields"},
|
||||
},
|
||||
{
|
||||
name: "with valid email auth field options (2)",
|
||||
collection: func(app core.App) (*core.Collection, error) {
|
||||
c := core.NewAuthCollection("new_auth")
|
||||
c.Fields.Add(
|
||||
&core.EmailField{Name: "email", System: true},
|
||||
)
|
||||
return c, nil
|
||||
},
|
||||
expectedErrors: []string{},
|
||||
},
|
||||
{
|
||||
name: "with invalid verified auth field options (1)",
|
||||
collection: func(app core.App) (*core.Collection, error) {
|
||||
c := core.NewAuthCollection("new_auth")
|
||||
c.Fields.Add(
|
||||
&core.TextField{Name: "verified", System: true}, // should be BoolField
|
||||
)
|
||||
return c, nil
|
||||
},
|
||||
expectedErrors: []string{"fields"},
|
||||
},
|
||||
{
|
||||
name: "with valid verified auth field options (2)",
|
||||
collection: func(app core.App) (*core.Collection, error) {
|
||||
c := core.NewAuthCollection("new_auth")
|
||||
c.Fields.Add(
|
||||
&core.BoolField{Name: "verified", System: true},
|
||||
)
|
||||
return c, nil
|
||||
},
|
||||
expectedErrors: []string{},
|
||||
},
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
t.Run(s.name, func(t *testing.T) {
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
collection, err := s.collection(app)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to retrieve test collection: %v", err)
|
||||
}
|
||||
|
||||
result := app.Validate(collection)
|
||||
|
||||
tests.TestValidationErrors(t, result, s.expectedErrors)
|
||||
})
|
||||
}
|
||||
}
|
499
core/db.go
Normal file
499
core/db.go
Normal file
|
@ -0,0 +1,499 @@
|
|||
package core
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"hash/crc32"
|
||||
"regexp"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
validation "github.com/go-ozzo/ozzo-validation/v4"
|
||||
"github.com/pocketbase/dbx"
|
||||
"github.com/pocketbase/pocketbase/tools/security"
|
||||
"github.com/spf13/cast"
|
||||
)
|
||||
|
||||
const (
|
||||
idColumn string = "id"
|
||||
|
||||
// DefaultIdLength is the default length of the generated model id.
|
||||
DefaultIdLength int = 15
|
||||
|
||||
// DefaultIdAlphabet is the default characters set used for generating the model id.
|
||||
DefaultIdAlphabet string = "abcdefghijklmnopqrstuvwxyz0123456789"
|
||||
)
|
||||
|
||||
// DefaultIdRegex specifies the default regex pattern for an id value.
|
||||
var DefaultIdRegex = regexp.MustCompile(`^\w+$`)
|
||||
|
||||
// DBExporter defines an interface for custom DB data export.
|
||||
// Usually used as part of [App.Save].
|
||||
type DBExporter interface {
|
||||
// DBExport returns a key-value map with the data to be used when saving the struct in the database.
|
||||
DBExport(app App) (map[string]any, error)
|
||||
}
|
||||
|
||||
// PreValidator defines an optional model interface for registering a
|
||||
// function that will run BEFORE firing the validation hooks (see [App.ValidateWithContext]).
|
||||
type PreValidator interface {
|
||||
// PreValidate defines a function that runs BEFORE the validation hooks.
|
||||
PreValidate(ctx context.Context, app App) error
|
||||
}
|
||||
|
||||
// PostValidator defines an optional model interface for registering a
|
||||
// function that will run AFTER executing the validation hooks (see [App.ValidateWithContext]).
|
||||
type PostValidator interface {
|
||||
// PostValidate defines a function that runs AFTER the successful
|
||||
// execution of the validation hooks.
|
||||
PostValidate(ctx context.Context, app App) error
|
||||
}
|
||||
|
||||
// GenerateDefaultRandomId generates a default random id string
|
||||
// (note: the generated random string is not intended for security purposes).
|
||||
func GenerateDefaultRandomId() string {
|
||||
return security.PseudorandomStringWithAlphabet(DefaultIdLength, DefaultIdAlphabet)
|
||||
}
|
||||
|
||||
// crc32Checksum generates a stringified crc32 checksum from the provided plain string.
|
||||
func crc32Checksum(str string) string {
|
||||
return strconv.FormatInt(int64(crc32.ChecksumIEEE([]byte(str))), 10)
|
||||
}
|
||||
|
||||
// ModelQuery creates a new preconfigured select data.db query with preset
|
||||
// SELECT, FROM and other common fields based on the provided model.
|
||||
func (app *BaseApp) ModelQuery(m Model) *dbx.SelectQuery {
|
||||
return app.modelQuery(app.ConcurrentDB(), m)
|
||||
}
|
||||
|
||||
// AuxModelQuery creates a new preconfigured select auxiliary.db query with preset
|
||||
// SELECT, FROM and other common fields based on the provided model.
|
||||
func (app *BaseApp) AuxModelQuery(m Model) *dbx.SelectQuery {
|
||||
return app.modelQuery(app.AuxConcurrentDB(), m)
|
||||
}
|
||||
|
||||
func (app *BaseApp) modelQuery(db dbx.Builder, m Model) *dbx.SelectQuery {
|
||||
tableName := m.TableName()
|
||||
|
||||
return db.
|
||||
Select("{{" + tableName + "}}.*").
|
||||
From(tableName).
|
||||
WithBuildHook(func(query *dbx.Query) {
|
||||
query.WithExecHook(execLockRetry(app.config.QueryTimeout, defaultMaxLockRetries))
|
||||
})
|
||||
}
|
||||
|
||||
// Delete deletes the specified model from the regular app database.
|
||||
func (app *BaseApp) Delete(model Model) error {
|
||||
return app.DeleteWithContext(context.Background(), model)
|
||||
}
|
||||
|
||||
// Delete deletes the specified model from the regular app database
|
||||
// (the context could be used to limit the query execution).
|
||||
func (app *BaseApp) DeleteWithContext(ctx context.Context, model Model) error {
|
||||
return app.delete(ctx, model, false)
|
||||
}
|
||||
|
||||
// AuxDelete deletes the specified model from the auxiliary database.
|
||||
func (app *BaseApp) AuxDelete(model Model) error {
|
||||
return app.AuxDeleteWithContext(context.Background(), model)
|
||||
}
|
||||
|
||||
// AuxDeleteWithContext deletes the specified model from the auxiliary database
|
||||
// (the context could be used to limit the query execution).
|
||||
func (app *BaseApp) AuxDeleteWithContext(ctx context.Context, model Model) error {
|
||||
return app.delete(ctx, model, true)
|
||||
}
|
||||
|
||||
func (app *BaseApp) delete(ctx context.Context, model Model, isForAuxDB bool) error {
|
||||
event := new(ModelEvent)
|
||||
event.App = app
|
||||
event.Type = ModelEventTypeDelete
|
||||
event.Context = ctx
|
||||
event.Model = model
|
||||
|
||||
deleteErr := app.OnModelDelete().Trigger(event, func(e *ModelEvent) error {
|
||||
pk := cast.ToString(e.Model.LastSavedPK())
|
||||
if pk == "" {
|
||||
return errors.New("the model can be deleted only if it is existing and has a non-empty primary key")
|
||||
}
|
||||
|
||||
// db write
|
||||
return e.App.OnModelDeleteExecute().Trigger(event, func(e *ModelEvent) error {
|
||||
var db dbx.Builder
|
||||
if isForAuxDB {
|
||||
db = e.App.AuxNonconcurrentDB()
|
||||
} else {
|
||||
db = e.App.NonconcurrentDB()
|
||||
}
|
||||
|
||||
return baseLockRetry(func(attempt int) error {
|
||||
_, err := db.Delete(e.Model.TableName(), dbx.HashExp{
|
||||
idColumn: pk,
|
||||
}).WithContext(e.Context).Execute()
|
||||
|
||||
return err
|
||||
}, defaultMaxLockRetries)
|
||||
})
|
||||
})
|
||||
if deleteErr != nil {
|
||||
errEvent := &ModelErrorEvent{ModelEvent: *event, Error: deleteErr}
|
||||
errEvent.App = app // replace with the initial app in case it was changed by the hook
|
||||
hookErr := app.OnModelAfterDeleteError().Trigger(errEvent)
|
||||
if hookErr != nil {
|
||||
return errors.Join(deleteErr, hookErr)
|
||||
}
|
||||
|
||||
return deleteErr
|
||||
}
|
||||
|
||||
if app.txInfo != nil {
|
||||
// execute later after the transaction has completed
|
||||
app.txInfo.OnComplete(func(txErr error) error {
|
||||
if app.txInfo != nil && app.txInfo.parent != nil {
|
||||
event.App = app.txInfo.parent
|
||||
}
|
||||
|
||||
if txErr != nil {
|
||||
return app.OnModelAfterDeleteError().Trigger(&ModelErrorEvent{
|
||||
ModelEvent: *event,
|
||||
Error: txErr,
|
||||
})
|
||||
}
|
||||
|
||||
return app.OnModelAfterDeleteSuccess().Trigger(event)
|
||||
})
|
||||
} else if err := event.App.OnModelAfterDeleteSuccess().Trigger(event); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Save validates and saves the specified model into the regular app database.
|
||||
//
|
||||
// If you don't want to run validations, use [App.SaveNoValidate()].
|
||||
func (app *BaseApp) Save(model Model) error {
|
||||
return app.SaveWithContext(context.Background(), model)
|
||||
}
|
||||
|
||||
// SaveWithContext is the same as [App.Save()] but allows specifying a context to limit the db execution.
|
||||
//
|
||||
// If you don't want to run validations, use [App.SaveNoValidateWithContext()].
|
||||
func (app *BaseApp) SaveWithContext(ctx context.Context, model Model) error {
|
||||
return app.save(ctx, model, true, false)
|
||||
}
|
||||
|
||||
// SaveNoValidate saves the specified model into the regular app database without performing validations.
|
||||
//
|
||||
// If you want to also run validations before persisting, use [App.Save()].
|
||||
func (app *BaseApp) SaveNoValidate(model Model) error {
|
||||
return app.SaveNoValidateWithContext(context.Background(), model)
|
||||
}
|
||||
|
||||
// SaveNoValidateWithContext is the same as [App.SaveNoValidate()]
|
||||
// but allows specifying a context to limit the db execution.
|
||||
//
|
||||
// If you want to also run validations before persisting, use [App.SaveWithContext()].
|
||||
func (app *BaseApp) SaveNoValidateWithContext(ctx context.Context, model Model) error {
|
||||
return app.save(ctx, model, false, false)
|
||||
}
|
||||
|
||||
// AuxSave validates and saves the specified model into the auxiliary app database.
|
||||
//
|
||||
// If you don't want to run validations, use [App.AuxSaveNoValidate()].
|
||||
func (app *BaseApp) AuxSave(model Model) error {
|
||||
return app.AuxSaveWithContext(context.Background(), model)
|
||||
}
|
||||
|
||||
// AuxSaveWithContext is the same as [App.AuxSave()] but allows specifying a context to limit the db execution.
|
||||
//
|
||||
// If you don't want to run validations, use [App.AuxSaveNoValidateWithContext()].
|
||||
func (app *BaseApp) AuxSaveWithContext(ctx context.Context, model Model) error {
|
||||
return app.save(ctx, model, true, true)
|
||||
}
|
||||
|
||||
// AuxSaveNoValidate saves the specified model into the auxiliary app database without performing validations.
|
||||
//
|
||||
// If you want to also run validations before persisting, use [App.AuxSave()].
|
||||
func (app *BaseApp) AuxSaveNoValidate(model Model) error {
|
||||
return app.AuxSaveNoValidateWithContext(context.Background(), model)
|
||||
}
|
||||
|
||||
// AuxSaveNoValidateWithContext is the same as [App.AuxSaveNoValidate()]
|
||||
// but allows specifying a context to limit the db execution.
|
||||
//
|
||||
// If you want to also run validations before persisting, use [App.AuxSaveWithContext()].
|
||||
func (app *BaseApp) AuxSaveNoValidateWithContext(ctx context.Context, model Model) error {
|
||||
return app.save(ctx, model, false, true)
|
||||
}
|
||||
|
||||
// Validate triggers the OnModelValidate hook for the specified model.
|
||||
func (app *BaseApp) Validate(model Model) error {
|
||||
return app.ValidateWithContext(context.Background(), model)
|
||||
}
|
||||
|
||||
// ValidateWithContext is the same as Validate but allows specifying the ModelEvent context.
|
||||
func (app *BaseApp) ValidateWithContext(ctx context.Context, model Model) error {
|
||||
if m, ok := model.(PreValidator); ok {
|
||||
if err := m.PreValidate(ctx, app); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
event := new(ModelEvent)
|
||||
event.App = app
|
||||
event.Context = ctx
|
||||
event.Type = ModelEventTypeValidate
|
||||
event.Model = model
|
||||
|
||||
return event.App.OnModelValidate().Trigger(event, func(e *ModelEvent) error {
|
||||
if m, ok := e.Model.(PostValidator); ok {
|
||||
if err := m.PostValidate(ctx, e.App); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return e.Next()
|
||||
})
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
|
||||
func (app *BaseApp) save(ctx context.Context, model Model, withValidations bool, isForAuxDB bool) error {
|
||||
if model.IsNew() {
|
||||
return app.create(ctx, model, withValidations, isForAuxDB)
|
||||
}
|
||||
|
||||
return app.update(ctx, model, withValidations, isForAuxDB)
|
||||
}
|
||||
|
||||
func (app *BaseApp) create(ctx context.Context, model Model, withValidations bool, isForAuxDB bool) error {
|
||||
event := new(ModelEvent)
|
||||
event.App = app
|
||||
event.Context = ctx
|
||||
event.Type = ModelEventTypeCreate
|
||||
event.Model = model
|
||||
|
||||
saveErr := app.OnModelCreate().Trigger(event, func(e *ModelEvent) error {
|
||||
// run validations (if any)
|
||||
if withValidations {
|
||||
validateErr := e.App.ValidateWithContext(e.Context, e.Model)
|
||||
if validateErr != nil {
|
||||
return validateErr
|
||||
}
|
||||
}
|
||||
|
||||
// db write
|
||||
return e.App.OnModelCreateExecute().Trigger(event, func(e *ModelEvent) error {
|
||||
var db dbx.Builder
|
||||
if isForAuxDB {
|
||||
db = e.App.AuxNonconcurrentDB()
|
||||
} else {
|
||||
db = e.App.NonconcurrentDB()
|
||||
}
|
||||
|
||||
dbErr := baseLockRetry(func(attempt int) error {
|
||||
if m, ok := e.Model.(DBExporter); ok {
|
||||
data, err := m.DBExport(e.App)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// manually add the id to the data if missing
|
||||
if _, ok := data[idColumn]; !ok {
|
||||
data[idColumn] = e.Model.PK()
|
||||
}
|
||||
|
||||
if cast.ToString(data[idColumn]) == "" {
|
||||
return errors.New("empty primary key is not allowed when using the DBExporter interface")
|
||||
}
|
||||
|
||||
_, err = db.Insert(e.Model.TableName(), data).WithContext(e.Context).Execute()
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
return db.Model(e.Model).WithContext(e.Context).Insert()
|
||||
}, defaultMaxLockRetries)
|
||||
if dbErr != nil {
|
||||
return dbErr
|
||||
}
|
||||
|
||||
e.Model.MarkAsNotNew()
|
||||
|
||||
return nil
|
||||
})
|
||||
})
|
||||
if saveErr != nil {
|
||||
event.Model.MarkAsNew() // reset "new" state
|
||||
|
||||
errEvent := &ModelErrorEvent{ModelEvent: *event, Error: saveErr}
|
||||
errEvent.App = app // replace with the initial app in case it was changed by the hook
|
||||
hookErr := app.OnModelAfterCreateError().Trigger(errEvent)
|
||||
if hookErr != nil {
|
||||
return errors.Join(saveErr, hookErr)
|
||||
}
|
||||
|
||||
return saveErr
|
||||
}
|
||||
|
||||
if app.txInfo != nil {
|
||||
// execute later after the transaction has completed
|
||||
app.txInfo.OnComplete(func(txErr error) error {
|
||||
if app.txInfo != nil && app.txInfo.parent != nil {
|
||||
event.App = app.txInfo.parent
|
||||
}
|
||||
|
||||
if txErr != nil {
|
||||
event.Model.MarkAsNew() // reset "new" state
|
||||
|
||||
return app.OnModelAfterCreateError().Trigger(&ModelErrorEvent{
|
||||
ModelEvent: *event,
|
||||
Error: txErr,
|
||||
})
|
||||
}
|
||||
|
||||
return app.OnModelAfterCreateSuccess().Trigger(event)
|
||||
})
|
||||
} else if err := event.App.OnModelAfterCreateSuccess().Trigger(event); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (app *BaseApp) update(ctx context.Context, model Model, withValidations bool, isForAuxDB bool) error {
|
||||
event := new(ModelEvent)
|
||||
event.App = app
|
||||
event.Context = ctx
|
||||
event.Type = ModelEventTypeUpdate
|
||||
event.Model = model
|
||||
|
||||
saveErr := app.OnModelUpdate().Trigger(event, func(e *ModelEvent) error {
|
||||
// run validations (if any)
|
||||
if withValidations {
|
||||
validateErr := e.App.ValidateWithContext(e.Context, e.Model)
|
||||
if validateErr != nil {
|
||||
return validateErr
|
||||
}
|
||||
}
|
||||
|
||||
// db write
|
||||
return e.App.OnModelUpdateExecute().Trigger(event, func(e *ModelEvent) error {
|
||||
var db dbx.Builder
|
||||
if isForAuxDB {
|
||||
db = e.App.AuxNonconcurrentDB()
|
||||
} else {
|
||||
db = e.App.NonconcurrentDB()
|
||||
}
|
||||
|
||||
return baseLockRetry(func(attempt int) error {
|
||||
if m, ok := e.Model.(DBExporter); ok {
|
||||
data, err := m.DBExport(e.App)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// note: for now disallow primary key change for consistency with dbx.ModelQuery.Update()
|
||||
if data[idColumn] != e.Model.LastSavedPK() {
|
||||
return errors.New("primary key change is not allowed")
|
||||
}
|
||||
|
||||
_, err = db.Update(e.Model.TableName(), data, dbx.HashExp{
|
||||
idColumn: e.Model.LastSavedPK(),
|
||||
}).WithContext(e.Context).Execute()
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
return db.Model(e.Model).WithContext(e.Context).Update()
|
||||
}, defaultMaxLockRetries)
|
||||
})
|
||||
})
|
||||
if saveErr != nil {
|
||||
errEvent := &ModelErrorEvent{ModelEvent: *event, Error: saveErr}
|
||||
errEvent.App = app // replace with the initial app in case it was changed by the hook
|
||||
hookErr := app.OnModelAfterUpdateError().Trigger(errEvent)
|
||||
if hookErr != nil {
|
||||
return errors.Join(saveErr, hookErr)
|
||||
}
|
||||
|
||||
return saveErr
|
||||
}
|
||||
|
||||
if app.txInfo != nil {
|
||||
// execute later after the transaction has completed
|
||||
app.txInfo.OnComplete(func(txErr error) error {
|
||||
if app.txInfo != nil && app.txInfo.parent != nil {
|
||||
event.App = app.txInfo.parent
|
||||
}
|
||||
|
||||
if txErr != nil {
|
||||
return app.OnModelAfterUpdateError().Trigger(&ModelErrorEvent{
|
||||
ModelEvent: *event,
|
||||
Error: txErr,
|
||||
})
|
||||
}
|
||||
|
||||
return app.OnModelAfterUpdateSuccess().Trigger(event)
|
||||
})
|
||||
} else if err := event.App.OnModelAfterUpdateSuccess().Trigger(event); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateCollectionId(app App, optTypes ...string) validation.RuleFunc {
|
||||
return func(value any) error {
|
||||
id, _ := value.(string)
|
||||
if id == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
collection := &Collection{}
|
||||
if err := app.ModelQuery(collection).Model(id, collection); err != nil {
|
||||
return validation.NewError("validation_invalid_collection_id", "Missing or invalid collection.")
|
||||
}
|
||||
|
||||
if len(optTypes) > 0 && !slices.Contains(optTypes, collection.Type) {
|
||||
return validation.NewError(
|
||||
"validation_invalid_collection_type",
|
||||
fmt.Sprintf("Invalid collection type - must be %s.", strings.Join(optTypes, ", ")),
|
||||
)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func validateRecordId(app App, collectionNameOrId string) validation.RuleFunc {
|
||||
return func(value any) error {
|
||||
id, _ := value.(string)
|
||||
if id == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
collection, err := app.FindCachedCollectionByNameOrId(collectionNameOrId)
|
||||
if err != nil {
|
||||
return validation.NewError("validation_invalid_collection", "Missing or invalid collection.")
|
||||
}
|
||||
|
||||
var exists int
|
||||
|
||||
rowErr := app.ConcurrentDB().Select("(1)").
|
||||
From(collection.Name).
|
||||
AndWhere(dbx.HashExp{"id": id}).
|
||||
Limit(1).
|
||||
Row(&exists)
|
||||
|
||||
if rowErr != nil || exists == 0 {
|
||||
return validation.NewError("validation_invalid_record", "Missing or invalid record.")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
}
|
187
core/db_builder.go
Normal file
187
core/db_builder.go
Normal file
|
@ -0,0 +1,187 @@
|
|||
package core
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"unicode"
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/pocketbase/dbx"
|
||||
)
|
||||
|
||||
var _ dbx.Builder = (*dualDBBuilder)(nil)
|
||||
|
||||
// note: expects both builder to use the same driver
|
||||
type dualDBBuilder struct {
|
||||
concurrentDB dbx.Builder
|
||||
nonconcurrentDB dbx.Builder
|
||||
}
|
||||
|
||||
// Select implements the [dbx.Builder.Select] interface method.
|
||||
func (b *dualDBBuilder) Select(cols ...string) *dbx.SelectQuery {
|
||||
return b.concurrentDB.Select(cols...)
|
||||
}
|
||||
|
||||
// Model implements the [dbx.Builder.Model] interface method.
|
||||
func (b *dualDBBuilder) Model(data interface{}) *dbx.ModelQuery {
|
||||
return b.nonconcurrentDB.Model(data)
|
||||
}
|
||||
|
||||
// GeneratePlaceholder implements the [dbx.Builder.GeneratePlaceholder] interface method.
|
||||
func (b *dualDBBuilder) GeneratePlaceholder(i int) string {
|
||||
return b.concurrentDB.GeneratePlaceholder(i)
|
||||
}
|
||||
|
||||
// Quote implements the [dbx.Builder.Quote] interface method.
|
||||
func (b *dualDBBuilder) Quote(str string) string {
|
||||
return b.concurrentDB.Quote(str)
|
||||
}
|
||||
|
||||
// QuoteSimpleTableName implements the [dbx.Builder.QuoteSimpleTableName] interface method.
|
||||
func (b *dualDBBuilder) QuoteSimpleTableName(table string) string {
|
||||
return b.concurrentDB.QuoteSimpleTableName(table)
|
||||
}
|
||||
|
||||
// QuoteSimpleColumnName implements the [dbx.Builder.QuoteSimpleColumnName] interface method.
|
||||
func (b *dualDBBuilder) QuoteSimpleColumnName(col string) string {
|
||||
return b.concurrentDB.QuoteSimpleColumnName(col)
|
||||
}
|
||||
|
||||
// QueryBuilder implements the [dbx.Builder.QueryBuilder] interface method.
|
||||
func (b *dualDBBuilder) QueryBuilder() dbx.QueryBuilder {
|
||||
return b.concurrentDB.QueryBuilder()
|
||||
}
|
||||
|
||||
// Insert implements the [dbx.Builder.Insert] interface method.
|
||||
func (b *dualDBBuilder) Insert(table string, cols dbx.Params) *dbx.Query {
|
||||
return b.nonconcurrentDB.Insert(table, cols)
|
||||
}
|
||||
|
||||
// Upsert implements the [dbx.Builder.Upsert] interface method.
|
||||
func (b *dualDBBuilder) Upsert(table string, cols dbx.Params, constraints ...string) *dbx.Query {
|
||||
return b.nonconcurrentDB.Upsert(table, cols, constraints...)
|
||||
}
|
||||
|
||||
// Update implements the [dbx.Builder.Update] interface method.
|
||||
func (b *dualDBBuilder) Update(table string, cols dbx.Params, where dbx.Expression) *dbx.Query {
|
||||
return b.nonconcurrentDB.Update(table, cols, where)
|
||||
}
|
||||
|
||||
// Delete implements the [dbx.Builder.Delete] interface method.
|
||||
func (b *dualDBBuilder) Delete(table string, where dbx.Expression) *dbx.Query {
|
||||
return b.nonconcurrentDB.Delete(table, where)
|
||||
}
|
||||
|
||||
// CreateTable implements the [dbx.Builder.CreateTable] interface method.
|
||||
func (b *dualDBBuilder) CreateTable(table string, cols map[string]string, options ...string) *dbx.Query {
|
||||
return b.nonconcurrentDB.CreateTable(table, cols, options...)
|
||||
}
|
||||
|
||||
// RenameTable implements the [dbx.Builder.RenameTable] interface method.
|
||||
func (b *dualDBBuilder) RenameTable(oldName, newName string) *dbx.Query {
|
||||
return b.nonconcurrentDB.RenameTable(oldName, newName)
|
||||
}
|
||||
|
||||
// DropTable implements the [dbx.Builder.DropTable] interface method.
|
||||
func (b *dualDBBuilder) DropTable(table string) *dbx.Query {
|
||||
return b.nonconcurrentDB.DropTable(table)
|
||||
}
|
||||
|
||||
// TruncateTable implements the [dbx.Builder.TruncateTable] interface method.
|
||||
func (b *dualDBBuilder) TruncateTable(table string) *dbx.Query {
|
||||
return b.nonconcurrentDB.TruncateTable(table)
|
||||
}
|
||||
|
||||
// AddColumn implements the [dbx.Builder.AddColumn] interface method.
|
||||
func (b *dualDBBuilder) AddColumn(table, col, typ string) *dbx.Query {
|
||||
return b.nonconcurrentDB.AddColumn(table, col, typ)
|
||||
}
|
||||
|
||||
// DropColumn implements the [dbx.Builder.DropColumn] interface method.
|
||||
func (b *dualDBBuilder) DropColumn(table, col string) *dbx.Query {
|
||||
return b.nonconcurrentDB.DropColumn(table, col)
|
||||
}
|
||||
|
||||
// RenameColumn implements the [dbx.Builder.RenameColumn] interface method.
|
||||
func (b *dualDBBuilder) RenameColumn(table, oldName, newName string) *dbx.Query {
|
||||
return b.nonconcurrentDB.RenameColumn(table, oldName, newName)
|
||||
}
|
||||
|
||||
// AlterColumn implements the [dbx.Builder.AlterColumn] interface method.
|
||||
func (b *dualDBBuilder) AlterColumn(table, col, typ string) *dbx.Query {
|
||||
return b.nonconcurrentDB.AlterColumn(table, col, typ)
|
||||
}
|
||||
|
||||
// AddPrimaryKey implements the [dbx.Builder.AddPrimaryKey] interface method.
|
||||
func (b *dualDBBuilder) AddPrimaryKey(table, name string, cols ...string) *dbx.Query {
|
||||
return b.nonconcurrentDB.AddPrimaryKey(table, name, cols...)
|
||||
}
|
||||
|
||||
// DropPrimaryKey implements the [dbx.Builder.DropPrimaryKey] interface method.
|
||||
func (b *dualDBBuilder) DropPrimaryKey(table, name string) *dbx.Query {
|
||||
return b.nonconcurrentDB.DropPrimaryKey(table, name)
|
||||
}
|
||||
|
||||
// AddForeignKey implements the [dbx.Builder.AddForeignKey] interface method.
|
||||
func (b *dualDBBuilder) AddForeignKey(table, name string, cols, refCols []string, refTable string, options ...string) *dbx.Query {
|
||||
return b.nonconcurrentDB.AddForeignKey(table, name, cols, refCols, refTable, options...)
|
||||
}
|
||||
|
||||
// DropForeignKey implements the [dbx.Builder.DropForeignKey] interface method.
|
||||
func (b *dualDBBuilder) DropForeignKey(table, name string) *dbx.Query {
|
||||
return b.nonconcurrentDB.DropForeignKey(table, name)
|
||||
}
|
||||
|
||||
// CreateIndex implements the [dbx.Builder.CreateIndex] interface method.
|
||||
func (b *dualDBBuilder) CreateIndex(table, name string, cols ...string) *dbx.Query {
|
||||
return b.nonconcurrentDB.CreateIndex(table, name, cols...)
|
||||
}
|
||||
|
||||
// CreateUniqueIndex implements the [dbx.Builder.CreateUniqueIndex] interface method.
|
||||
func (b *dualDBBuilder) CreateUniqueIndex(table, name string, cols ...string) *dbx.Query {
|
||||
return b.nonconcurrentDB.CreateUniqueIndex(table, name, cols...)
|
||||
}
|
||||
|
||||
// DropIndex implements the [dbx.Builder.DropIndex] interface method.
|
||||
func (b *dualDBBuilder) DropIndex(table, name string) *dbx.Query {
|
||||
return b.nonconcurrentDB.DropIndex(table, name)
|
||||
}
|
||||
|
||||
// NewQuery implements the [dbx.Builder.NewQuery] interface method by
|
||||
// routing the SELECT queries to the concurrent builder instance.
|
||||
func (b *dualDBBuilder) NewQuery(str string) *dbx.Query {
|
||||
// note: technically INSERT/UPDATE/DELETE could also have CTE but since
|
||||
// it is rare for now this scase is ignored to avoid unnecessary complicating the checks
|
||||
trimmed := trimLeftSpaces(str)
|
||||
if hasPrefixFold(trimmed, "SELECT") || hasPrefixFold(trimmed, "WITH") {
|
||||
return b.concurrentDB.NewQuery(str)
|
||||
}
|
||||
|
||||
return b.nonconcurrentDB.NewQuery(str)
|
||||
}
|
||||
|
||||
var asciiSpace = [256]uint8{'\t': 1, '\n': 1, '\v': 1, '\f': 1, '\r': 1, ' ': 1}
|
||||
|
||||
// note: similar to strings.Space() but without the right trim because it is not needed in our case
|
||||
func trimLeftSpaces(str string) string {
|
||||
start := 0
|
||||
for ; start < len(str); start++ {
|
||||
c := str[start]
|
||||
if c >= utf8.RuneSelf {
|
||||
return strings.TrimLeftFunc(str[start:], unicode.IsSpace)
|
||||
}
|
||||
if asciiSpace[c] == 0 {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return str[start:]
|
||||
}
|
||||
|
||||
// note: the prefix is expected to be ASCII
|
||||
func hasPrefixFold(str, prefix string) bool {
|
||||
if len(str) < len(prefix) {
|
||||
return false
|
||||
}
|
||||
|
||||
return strings.EqualFold(str[:len(prefix)], prefix)
|
||||
}
|
22
core/db_connect.go
Normal file
22
core/db_connect.go
Normal file
|
@ -0,0 +1,22 @@
|
|||
//go:build !no_default_driver
|
||||
|
||||
package core
|
||||
|
||||
import (
|
||||
"github.com/pocketbase/dbx"
|
||||
_ "modernc.org/sqlite"
|
||||
)
|
||||
|
||||
func DefaultDBConnect(dbPath string) (*dbx.DB, error) {
|
||||
// Note: the busy_timeout pragma must be first because
|
||||
// the connection needs to be set to block on busy before WAL mode
|
||||
// is set in case it hasn't been already set by another connection.
|
||||
pragmas := "?_pragma=busy_timeout(10000)&_pragma=journal_mode(WAL)&_pragma=journal_size_limit(200000000)&_pragma=synchronous(NORMAL)&_pragma=foreign_keys(ON)&_pragma=temp_store(MEMORY)&_pragma=cache_size(-16000)"
|
||||
|
||||
db, err := dbx.Open("sqlite", dbPath+pragmas)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return db, nil
|
||||
}
|
9
core/db_connect_nodefaultdriver.go
Normal file
9
core/db_connect_nodefaultdriver.go
Normal file
|
@ -0,0 +1,9 @@
|
|||
//go:build no_default_driver
|
||||
|
||||
package core
|
||||
|
||||
import "github.com/pocketbase/dbx"
|
||||
|
||||
func DefaultDBConnect(dbPath string) (*dbx.DB, error) {
|
||||
panic("DBConnect config option must be set when the no_default_driver tag is used!")
|
||||
}
|
59
core/db_model.go
Normal file
59
core/db_model.go
Normal file
|
@ -0,0 +1,59 @@
|
|||
package core
|
||||
|
||||
// Model defines an interface with common methods that all db models should have.
|
||||
//
|
||||
// Note: for simplicity composite pk are not supported.
|
||||
type Model interface {
|
||||
TableName() string
|
||||
PK() any
|
||||
LastSavedPK() any
|
||||
IsNew() bool
|
||||
MarkAsNew()
|
||||
MarkAsNotNew()
|
||||
}
|
||||
|
||||
// BaseModel defines a base struct that is intended to be embedded into other custom models.
|
||||
type BaseModel struct {
|
||||
lastSavedPK string
|
||||
|
||||
// Id is the primary key of the model.
|
||||
// It is usually autogenerated by the parent model implementation.
|
||||
Id string `db:"id" json:"id" form:"id" xml:"id"`
|
||||
}
|
||||
|
||||
// LastSavedPK returns the last saved primary key of the model.
|
||||
//
|
||||
// Its value is updated to the latest PK value after MarkAsNotNew() or PostScan() calls.
|
||||
func (m *BaseModel) LastSavedPK() any {
|
||||
return m.lastSavedPK
|
||||
}
|
||||
|
||||
func (m *BaseModel) PK() any {
|
||||
return m.Id
|
||||
}
|
||||
|
||||
// IsNew indicates what type of db query (insert or update)
|
||||
// should be used with the model instance.
|
||||
func (m *BaseModel) IsNew() bool {
|
||||
return m.lastSavedPK == ""
|
||||
}
|
||||
|
||||
// MarkAsNew clears the pk field and marks the current model as "new"
|
||||
// (aka. forces m.IsNew() to be true).
|
||||
func (m *BaseModel) MarkAsNew() {
|
||||
m.lastSavedPK = ""
|
||||
}
|
||||
|
||||
// MarkAsNew set the pk field to the Id value and marks the current model
|
||||
// as NOT "new" (aka. forces m.IsNew() to be false).
|
||||
func (m *BaseModel) MarkAsNotNew() {
|
||||
m.lastSavedPK = m.Id
|
||||
}
|
||||
|
||||
// PostScan implements the [dbx.PostScanner] interface.
|
||||
//
|
||||
// It is usually executed right after the model is populated with the db row values.
|
||||
func (m *BaseModel) PostScan() error {
|
||||
m.MarkAsNotNew()
|
||||
return nil
|
||||
}
|
70
core/db_model_test.go
Normal file
70
core/db_model_test.go
Normal file
|
@ -0,0 +1,70 @@
|
|||
package core_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
)
|
||||
|
||||
func TestBaseModel(t *testing.T) {
|
||||
id := "test_id"
|
||||
|
||||
m := core.BaseModel{Id: id}
|
||||
|
||||
if m.PK() != id {
|
||||
t.Fatalf("[before PostScan] Expected PK %q, got %q", "", m.PK())
|
||||
}
|
||||
|
||||
if m.LastSavedPK() != "" {
|
||||
t.Fatalf("[before PostScan] Expected LastSavedPK %q, got %q", "", m.LastSavedPK())
|
||||
}
|
||||
|
||||
if !m.IsNew() {
|
||||
t.Fatalf("[before PostScan] Expected IsNew %v, got %v", true, m.IsNew())
|
||||
}
|
||||
|
||||
if err := m.PostScan(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if m.PK() != id {
|
||||
t.Fatalf("[after PostScan] Expected PK %q, got %q", "", m.PK())
|
||||
}
|
||||
|
||||
if m.LastSavedPK() != id {
|
||||
t.Fatalf("[after PostScan] Expected LastSavedPK %q, got %q", id, m.LastSavedPK())
|
||||
}
|
||||
|
||||
if m.IsNew() {
|
||||
t.Fatalf("[after PostScan] Expected IsNew %v, got %v", false, m.IsNew())
|
||||
}
|
||||
|
||||
m.MarkAsNew()
|
||||
|
||||
if m.PK() != id {
|
||||
t.Fatalf("[after MarkAsNew] Expected PK %q, got %q", id, m.PK())
|
||||
}
|
||||
|
||||
if m.LastSavedPK() != "" {
|
||||
t.Fatalf("[after MarkAsNew] Expected LastSavedPK %q, got %q", "", m.LastSavedPK())
|
||||
}
|
||||
|
||||
if !m.IsNew() {
|
||||
t.Fatalf("[after MarkAsNew] Expected IsNew %v, got %v", true, m.IsNew())
|
||||
}
|
||||
|
||||
// mark as not new without id
|
||||
m.MarkAsNotNew()
|
||||
|
||||
if m.PK() != id {
|
||||
t.Fatalf("[after MarkAsNotNew] Expected PK %q, got %q", id, m.PK())
|
||||
}
|
||||
|
||||
if m.LastSavedPK() != id {
|
||||
t.Fatalf("[after MarkAsNotNew] Expected LastSavedPK %q, got %q", id, m.LastSavedPK())
|
||||
}
|
||||
|
||||
if m.IsNew() {
|
||||
t.Fatalf("[after MarkAsNotNew] Expected IsNew %v, got %v", false, m.IsNew())
|
||||
}
|
||||
}
|
70
core/db_retry.go
Normal file
70
core/db_retry.go
Normal file
|
@ -0,0 +1,70 @@
|
|||
package core
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/pocketbase/dbx"
|
||||
)
|
||||
|
||||
// default retries intervals (in ms)
|
||||
var defaultRetryIntervals = []int{50, 100, 150, 200, 300, 400, 500, 700, 1000}
|
||||
|
||||
// default max retry attempts
|
||||
const defaultMaxLockRetries = 12
|
||||
|
||||
func execLockRetry(timeout time.Duration, maxRetries int) dbx.ExecHookFunc {
|
||||
return func(q *dbx.Query, op func() error) error {
|
||||
if q.Context() == nil {
|
||||
cancelCtx, cancel := context.WithTimeout(context.Background(), timeout)
|
||||
defer func() {
|
||||
cancel()
|
||||
//nolint:staticcheck
|
||||
q.WithContext(nil) // reset
|
||||
}()
|
||||
q.WithContext(cancelCtx)
|
||||
}
|
||||
|
||||
execErr := baseLockRetry(func(attempt int) error {
|
||||
return op()
|
||||
}, maxRetries)
|
||||
if execErr != nil && !errors.Is(execErr, sql.ErrNoRows) {
|
||||
execErr = fmt.Errorf("%w; failed query: %s", execErr, q.SQL())
|
||||
}
|
||||
|
||||
return execErr
|
||||
}
|
||||
}
|
||||
|
||||
func baseLockRetry(op func(attempt int) error, maxRetries int) error {
|
||||
attempt := 1
|
||||
|
||||
Retry:
|
||||
err := op(attempt)
|
||||
|
||||
if err != nil && attempt <= maxRetries {
|
||||
errStr := err.Error()
|
||||
// we are checking the error against the plain error texts since the codes could vary between drivers
|
||||
if strings.Contains(errStr, "database is locked") ||
|
||||
strings.Contains(errStr, "table is locked") {
|
||||
// wait and retry
|
||||
time.Sleep(getDefaultRetryInterval(attempt))
|
||||
attempt++
|
||||
goto Retry
|
||||
}
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func getDefaultRetryInterval(attempt int) time.Duration {
|
||||
if attempt < 0 || attempt > len(defaultRetryIntervals)-1 {
|
||||
return time.Duration(defaultRetryIntervals[len(defaultRetryIntervals)-1]) * time.Millisecond
|
||||
}
|
||||
|
||||
return time.Duration(defaultRetryIntervals[attempt]) * time.Millisecond
|
||||
}
|
66
core/db_retry_test.go
Normal file
66
core/db_retry_test.go
Normal file
|
@ -0,0 +1,66 @@
|
|||
package core
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestGetDefaultRetryInterval(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
if i := getDefaultRetryInterval(-1); i.Milliseconds() != 1000 {
|
||||
t.Fatalf("Expected 1000ms, got %v", i)
|
||||
}
|
||||
|
||||
if i := getDefaultRetryInterval(999); i.Milliseconds() != 1000 {
|
||||
t.Fatalf("Expected 1000ms, got %v", i)
|
||||
}
|
||||
|
||||
if i := getDefaultRetryInterval(3); i.Milliseconds() != 200 {
|
||||
t.Fatalf("Expected 500ms, got %v", i)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBaseLockRetry(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
scenarios := []struct {
|
||||
err error
|
||||
failUntilAttempt int
|
||||
expectedAttempts int
|
||||
}{
|
||||
{nil, 3, 1},
|
||||
{errors.New("test"), 3, 1},
|
||||
{errors.New("database is locked"), 3, 3},
|
||||
{errors.New("table is locked"), 3, 3},
|
||||
}
|
||||
|
||||
for i, s := range scenarios {
|
||||
t.Run(fmt.Sprintf("%d_%#v", i, s.err), func(t *testing.T) {
|
||||
lastAttempt := 0
|
||||
|
||||
err := baseLockRetry(func(attempt int) error {
|
||||
lastAttempt = attempt
|
||||
|
||||
if attempt < s.failUntilAttempt {
|
||||
return s.err
|
||||
}
|
||||
|
||||
return nil
|
||||
}, s.failUntilAttempt+2)
|
||||
|
||||
if lastAttempt != s.expectedAttempts {
|
||||
t.Errorf("Expected lastAttempt to be %d, got %d", s.expectedAttempts, lastAttempt)
|
||||
}
|
||||
|
||||
if s.failUntilAttempt == s.expectedAttempts && err != nil {
|
||||
t.Fatalf("Expected nil, got err %v", err)
|
||||
}
|
||||
|
||||
if s.failUntilAttempt != s.expectedAttempts && s.err != nil && err == nil {
|
||||
t.Fatalf("Expected error %q, got nil", s.err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
137
core/db_table.go
Normal file
137
core/db_table.go
Normal file
|
@ -0,0 +1,137 @@
|
|||
package core
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
|
||||
"github.com/pocketbase/dbx"
|
||||
)
|
||||
|
||||
// TableColumns returns all column names of a single table by its name.
|
||||
func (app *BaseApp) TableColumns(tableName string) ([]string, error) {
|
||||
columns := []string{}
|
||||
|
||||
err := app.ConcurrentDB().NewQuery("SELECT name FROM PRAGMA_TABLE_INFO({:tableName})").
|
||||
Bind(dbx.Params{"tableName": tableName}).
|
||||
Column(&columns)
|
||||
|
||||
return columns, err
|
||||
}
|
||||
|
||||
type TableInfoRow struct {
|
||||
// the `db:"pk"` tag has special semantic so we cannot rename
|
||||
// the original field without specifying a custom mapper
|
||||
PK int
|
||||
|
||||
Index int `db:"cid"`
|
||||
Name string `db:"name"`
|
||||
Type string `db:"type"`
|
||||
NotNull bool `db:"notnull"`
|
||||
DefaultValue sql.NullString `db:"dflt_value"`
|
||||
}
|
||||
|
||||
// TableInfo returns the "table_info" pragma result for the specified table.
|
||||
func (app *BaseApp) TableInfo(tableName string) ([]*TableInfoRow, error) {
|
||||
info := []*TableInfoRow{}
|
||||
|
||||
err := app.ConcurrentDB().NewQuery("SELECT * FROM PRAGMA_TABLE_INFO({:tableName})").
|
||||
Bind(dbx.Params{"tableName": tableName}).
|
||||
All(&info)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// mattn/go-sqlite3 doesn't throw an error on invalid or missing table
|
||||
// so we additionally have to check whether the loaded info result is nonempty
|
||||
if len(info) == 0 {
|
||||
return nil, fmt.Errorf("empty table info probably due to invalid or missing table %s", tableName)
|
||||
}
|
||||
|
||||
return info, nil
|
||||
}
|
||||
|
||||
// TableIndexes returns a name grouped map with all non empty index of the specified table.
|
||||
//
|
||||
// Note: This method doesn't return an error on nonexisting table.
|
||||
func (app *BaseApp) TableIndexes(tableName string) (map[string]string, error) {
|
||||
indexes := []struct {
|
||||
Name string
|
||||
Sql string
|
||||
}{}
|
||||
|
||||
err := app.ConcurrentDB().Select("name", "sql").
|
||||
From("sqlite_master").
|
||||
AndWhere(dbx.NewExp("sql is not null")).
|
||||
AndWhere(dbx.HashExp{
|
||||
"type": "index",
|
||||
"tbl_name": tableName,
|
||||
}).
|
||||
All(&indexes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result := make(map[string]string, len(indexes))
|
||||
|
||||
for _, idx := range indexes {
|
||||
result[idx.Name] = idx.Sql
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// DeleteTable drops the specified table.
|
||||
//
|
||||
// This method is a no-op if a table with the provided name doesn't exist.
|
||||
//
|
||||
// NB! Be aware that this method is vulnerable to SQL injection and the
|
||||
// "tableName" argument must come only from trusted input!
|
||||
func (app *BaseApp) DeleteTable(tableName string) error {
|
||||
_, err := app.NonconcurrentDB().NewQuery(fmt.Sprintf(
|
||||
"DROP TABLE IF EXISTS {{%s}}",
|
||||
tableName,
|
||||
)).Execute()
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// HasTable checks if a table (or view) with the provided name exists (case insensitive).
|
||||
// in the data.db.
|
||||
func (app *BaseApp) HasTable(tableName string) bool {
|
||||
return app.hasTable(app.ConcurrentDB(), tableName)
|
||||
}
|
||||
|
||||
// AuxHasTable checks if a table (or view) with the provided name exists (case insensitive)
|
||||
// in the auixiliary.db.
|
||||
func (app *BaseApp) AuxHasTable(tableName string) bool {
|
||||
return app.hasTable(app.AuxConcurrentDB(), tableName)
|
||||
}
|
||||
|
||||
func (app *BaseApp) hasTable(db dbx.Builder, tableName string) bool {
|
||||
var exists int
|
||||
|
||||
err := db.Select("(1)").
|
||||
From("sqlite_schema").
|
||||
AndWhere(dbx.HashExp{"type": []any{"table", "view"}}).
|
||||
AndWhere(dbx.NewExp("LOWER([[name]])=LOWER({:tableName})", dbx.Params{"tableName": tableName})).
|
||||
Limit(1).
|
||||
Row(&exists)
|
||||
|
||||
return err == nil && exists > 0
|
||||
}
|
||||
|
||||
// Vacuum executes VACUUM on the data.db in order to reclaim unused data db disk space.
|
||||
func (app *BaseApp) Vacuum() error {
|
||||
return app.vacuum(app.NonconcurrentDB())
|
||||
}
|
||||
|
||||
// AuxVacuum executes VACUUM on the auxiliary.db in order to reclaim unused auxiliary db disk space.
|
||||
func (app *BaseApp) AuxVacuum() error {
|
||||
return app.vacuum(app.AuxNonconcurrentDB())
|
||||
}
|
||||
|
||||
func (app *BaseApp) vacuum(db dbx.Builder) error {
|
||||
_, err := db.NewQuery("VACUUM").Execute()
|
||||
|
||||
return err
|
||||
}
|
250
core/db_table_test.go
Normal file
250
core/db_table_test.go
Normal file
|
@ -0,0 +1,250 @@
|
|||
package core_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"slices"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/pocketbase/dbx"
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tests"
|
||||
)
|
||||
|
||||
func TestHasTable(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
scenarios := []struct {
|
||||
tableName string
|
||||
expected bool
|
||||
}{
|
||||
{"", false},
|
||||
{"test", false},
|
||||
{core.CollectionNameSuperusers, true},
|
||||
{"demo3", true},
|
||||
{"DEMO3", true}, // table names are case insensitives by default
|
||||
{"view1", true}, // view
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
t.Run(s.tableName, func(t *testing.T) {
|
||||
result := app.HasTable(s.tableName)
|
||||
if result != s.expected {
|
||||
t.Fatalf("Expected %v, got %v", s.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuxHasTable(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
scenarios := []struct {
|
||||
tableName string
|
||||
expected bool
|
||||
}{
|
||||
{"", false},
|
||||
{"test", false},
|
||||
{"_lOGS", true}, // table names are case insensitives by default
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
t.Run(s.tableName, func(t *testing.T) {
|
||||
result := app.AuxHasTable(s.tableName)
|
||||
if result != s.expected {
|
||||
t.Fatalf("Expected %v, got %v", s.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestTableColumns(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
scenarios := []struct {
|
||||
tableName string
|
||||
expected []string
|
||||
}{
|
||||
{"", nil},
|
||||
{"_params", []string{"id", "value", "created", "updated"}},
|
||||
}
|
||||
|
||||
for i, s := range scenarios {
|
||||
t.Run(fmt.Sprintf("%d_%s", i, s.tableName), func(t *testing.T) {
|
||||
columns, _ := app.TableColumns(s.tableName)
|
||||
|
||||
if len(columns) != len(s.expected) {
|
||||
t.Fatalf("Expected columns %v, got %v", s.expected, columns)
|
||||
}
|
||||
|
||||
for _, c := range columns {
|
||||
if !slices.Contains(s.expected, c) {
|
||||
t.Errorf("Didn't expect column %s", c)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestTableInfo(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
scenarios := []struct {
|
||||
tableName string
|
||||
expected string
|
||||
}{
|
||||
{"", "null"},
|
||||
{"missing", "null"},
|
||||
{
|
||||
"_params",
|
||||
`[{"PK":0,"Index":0,"Name":"created","Type":"TEXT","NotNull":true,"DefaultValue":{"String":"''","Valid":true}},{"PK":1,"Index":1,"Name":"id","Type":"TEXT","NotNull":true,"DefaultValue":{"String":"'r'||lower(hex(randomblob(7)))","Valid":true}},{"PK":0,"Index":2,"Name":"updated","Type":"TEXT","NotNull":true,"DefaultValue":{"String":"''","Valid":true}},{"PK":0,"Index":3,"Name":"value","Type":"JSON","NotNull":false,"DefaultValue":{"String":"NULL","Valid":true}}]`,
|
||||
},
|
||||
}
|
||||
|
||||
for i, s := range scenarios {
|
||||
t.Run(fmt.Sprintf("%d_%s", i, s.tableName), func(t *testing.T) {
|
||||
rows, _ := app.TableInfo(s.tableName)
|
||||
|
||||
raw, err := json.Marshal(rows)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if str := string(raw); str != s.expected {
|
||||
t.Fatalf("Expected\n%s\ngot\n%s", s.expected, str)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestTableIndexes(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
scenarios := []struct {
|
||||
tableName string
|
||||
expected []string
|
||||
}{
|
||||
{"", nil},
|
||||
{"missing", nil},
|
||||
{
|
||||
core.CollectionNameSuperusers,
|
||||
[]string{"idx_email__pbc_3323866339", "idx_tokenKey__pbc_3323866339"},
|
||||
},
|
||||
}
|
||||
|
||||
for i, s := range scenarios {
|
||||
t.Run(fmt.Sprintf("%d_%s", i, s.tableName), func(t *testing.T) {
|
||||
indexes, _ := app.TableIndexes(s.tableName)
|
||||
|
||||
if len(indexes) != len(s.expected) {
|
||||
t.Fatalf("Expected %d indexes, got %d\n%v", len(s.expected), len(indexes), indexes)
|
||||
}
|
||||
|
||||
for _, name := range s.expected {
|
||||
if v, ok := indexes[name]; !ok || v == "" {
|
||||
t.Fatalf("Expected non-empty index %q in \n%v", name, indexes)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteTable(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
scenarios := []struct {
|
||||
tableName string
|
||||
expectError bool
|
||||
}{
|
||||
{"", true},
|
||||
{"test", false}, // missing tables are ignored
|
||||
{"_admins", false},
|
||||
{"demo3", false},
|
||||
}
|
||||
|
||||
for i, s := range scenarios {
|
||||
t.Run(fmt.Sprintf("%d_%s", i, s.tableName), func(t *testing.T) {
|
||||
err := app.DeleteTable(s.tableName)
|
||||
|
||||
hasErr := err != nil
|
||||
if hasErr != s.expectError {
|
||||
t.Fatalf("Expected hasErr %v, got %v", s.expectError, hasErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestVacuum(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
calledQueries := []string{}
|
||||
app.NonconcurrentDB().(*dbx.DB).QueryLogFunc = func(ctx context.Context, t time.Duration, sql string, rows *sql.Rows, err error) {
|
||||
calledQueries = append(calledQueries, sql)
|
||||
}
|
||||
app.NonconcurrentDB().(*dbx.DB).ExecLogFunc = func(ctx context.Context, t time.Duration, sql string, result sql.Result, err error) {
|
||||
calledQueries = append(calledQueries, sql)
|
||||
}
|
||||
|
||||
if err := app.Vacuum(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if total := len(calledQueries); total != 1 {
|
||||
t.Fatalf("Expected 1 query, got %d", total)
|
||||
}
|
||||
|
||||
if calledQueries[0] != "VACUUM" {
|
||||
t.Fatalf("Expected VACUUM query, got %s", calledQueries[0])
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuxVacuum(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
calledQueries := []string{}
|
||||
app.AuxNonconcurrentDB().(*dbx.DB).QueryLogFunc = func(ctx context.Context, t time.Duration, sql string, rows *sql.Rows, err error) {
|
||||
calledQueries = append(calledQueries, sql)
|
||||
}
|
||||
app.AuxNonconcurrentDB().(*dbx.DB).ExecLogFunc = func(ctx context.Context, t time.Duration, sql string, result sql.Result, err error) {
|
||||
calledQueries = append(calledQueries, sql)
|
||||
}
|
||||
|
||||
if err := app.AuxVacuum(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if total := len(calledQueries); total != 1 {
|
||||
t.Fatalf("Expected 1 query, got %d", total)
|
||||
}
|
||||
|
||||
if calledQueries[0] != "VACUUM" {
|
||||
t.Fatalf("Expected VACUUM query, got %s", calledQueries[0])
|
||||
}
|
||||
}
|
113
core/db_test.go
Normal file
113
core/db_test.go
Normal file
|
@ -0,0 +1,113 @@
|
|||
package core_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/pocketbase/dbx"
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tests"
|
||||
)
|
||||
|
||||
func TestGenerateDefaultRandomId(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
id1 := core.GenerateDefaultRandomId()
|
||||
id2 := core.GenerateDefaultRandomId()
|
||||
|
||||
if id1 == id2 {
|
||||
t.Fatalf("Expected id1 and id2 to differ, got %q", id1)
|
||||
}
|
||||
|
||||
if l := len(id1); l != 15 {
|
||||
t.Fatalf("Expected id1 length %d, got %d", 15, l)
|
||||
}
|
||||
|
||||
if l := len(id2); l != 15 {
|
||||
t.Fatalf("Expected id2 length %d, got %d", 15, l)
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelQuery(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
modelsQuery := app.ModelQuery(&core.Collection{})
|
||||
logsModelQuery := app.AuxModelQuery(&core.Collection{})
|
||||
|
||||
if app.ConcurrentDB() == modelsQuery.Info().Builder {
|
||||
t.Fatalf("ModelQuery() is not using app.ConcurrentDB()")
|
||||
}
|
||||
|
||||
if app.AuxConcurrentDB() == logsModelQuery.Info().Builder {
|
||||
t.Fatalf("AuxModelQuery() is not using app.AuxConcurrentDB()")
|
||||
}
|
||||
|
||||
expectedSQL := "SELECT {{_collections}}.* FROM `_collections`"
|
||||
for i, q := range []*dbx.SelectQuery{modelsQuery, logsModelQuery} {
|
||||
sql := q.Build().SQL()
|
||||
if sql != expectedSQL {
|
||||
t.Fatalf("[%d] Expected select\n%s\ngot\n%s", i, expectedSQL, sql)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidate(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
u := &mockSuperusers{}
|
||||
|
||||
testErr := errors.New("test")
|
||||
|
||||
app.OnModelValidate().BindFunc(func(e *core.ModelEvent) error {
|
||||
return testErr
|
||||
})
|
||||
|
||||
err := app.Validate(u)
|
||||
if err != testErr {
|
||||
t.Fatalf("Expected error %v, got %v", testErr, err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateWithContext(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
u := &mockSuperusers{}
|
||||
|
||||
testErr := errors.New("test")
|
||||
|
||||
app.OnModelValidate().BindFunc(func(e *core.ModelEvent) error {
|
||||
if v := e.Context.Value("test"); v != 123 {
|
||||
t.Fatalf("Expected 'test' context value %#v, got %#v", 123, v)
|
||||
}
|
||||
return testErr
|
||||
})
|
||||
|
||||
//nolint:staticcheck
|
||||
ctx := context.WithValue(context.Background(), "test", 123)
|
||||
|
||||
err := app.ValidateWithContext(ctx, u)
|
||||
if err != testErr {
|
||||
t.Fatalf("Expected error %v, got %v", testErr, err)
|
||||
}
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
|
||||
type mockSuperusers struct {
|
||||
core.BaseModel
|
||||
Email string `db:"email"`
|
||||
}
|
||||
|
||||
func (m *mockSuperusers) TableName() string {
|
||||
return core.CollectionNameSuperusers
|
||||
}
|
112
core/db_tx.go
Normal file
112
core/db_tx.go
Normal file
|
@ -0,0 +1,112 @@
|
|||
package core
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
"github.com/pocketbase/dbx"
|
||||
)
|
||||
|
||||
// RunInTransaction wraps fn into a transaction for the regular app database.
|
||||
//
|
||||
// It is safe to nest RunInTransaction calls as long as you use the callback's txApp.
|
||||
func (app *BaseApp) RunInTransaction(fn func(txApp App) error) error {
|
||||
return app.runInTransaction(app.NonconcurrentDB(), fn, false)
|
||||
}
|
||||
|
||||
// AuxRunInTransaction wraps fn into a transaction for the auxiliary app database.
|
||||
//
|
||||
// It is safe to nest RunInTransaction calls as long as you use the callback's txApp.
|
||||
func (app *BaseApp) AuxRunInTransaction(fn func(txApp App) error) error {
|
||||
return app.runInTransaction(app.AuxNonconcurrentDB(), fn, true)
|
||||
}
|
||||
|
||||
func (app *BaseApp) runInTransaction(db dbx.Builder, fn func(txApp App) error, isForAuxDB bool) error {
|
||||
switch txOrDB := db.(type) {
|
||||
case *dbx.Tx:
|
||||
// run as part of the already existing transaction
|
||||
return fn(app)
|
||||
case *dbx.DB:
|
||||
var txApp *BaseApp
|
||||
txErr := txOrDB.Transactional(func(tx *dbx.Tx) error {
|
||||
txApp = app.createTxApp(tx, isForAuxDB)
|
||||
return fn(txApp)
|
||||
})
|
||||
|
||||
// execute all after event calls on transaction complete
|
||||
if txApp != nil && txApp.txInfo != nil {
|
||||
afterFuncErr := txApp.txInfo.runAfterFuncs(txErr)
|
||||
if afterFuncErr != nil {
|
||||
return errors.Join(txErr, afterFuncErr)
|
||||
}
|
||||
}
|
||||
|
||||
return txErr
|
||||
default:
|
||||
return errors.New("failed to start transaction (unknown db type)")
|
||||
}
|
||||
}
|
||||
|
||||
// createTxApp shallow clones the current app and assigns a new tx state.
|
||||
func (app *BaseApp) createTxApp(tx *dbx.Tx, isForAuxDB bool) *BaseApp {
|
||||
clone := *app
|
||||
|
||||
if isForAuxDB {
|
||||
clone.auxConcurrentDB = tx
|
||||
clone.auxNonconcurrentDB = tx
|
||||
} else {
|
||||
clone.concurrentDB = tx
|
||||
clone.nonconcurrentDB = tx
|
||||
}
|
||||
|
||||
clone.txInfo = &TxAppInfo{
|
||||
parent: app,
|
||||
isForAuxDB: isForAuxDB,
|
||||
}
|
||||
|
||||
return &clone
|
||||
}
|
||||
|
||||
// TxAppInfo represents an active transaction context associated to an existing app instance.
|
||||
type TxAppInfo struct {
|
||||
parent *BaseApp
|
||||
afterFuncs []func(txErr error) error
|
||||
mu sync.Mutex
|
||||
isForAuxDB bool
|
||||
}
|
||||
|
||||
// OnComplete registers the provided callback that will be invoked
|
||||
// once the related transaction ends (either completes successfully or rollbacked with an error).
|
||||
//
|
||||
// The callback receives the transaction error (if any) as its argument.
|
||||
// Any additional errors returned by the OnComplete callbacks will be
|
||||
// joined together with txErr when returning the final transaction result.
|
||||
func (a *TxAppInfo) OnComplete(fn func(txErr error) error) {
|
||||
a.mu.Lock()
|
||||
defer a.mu.Unlock()
|
||||
|
||||
a.afterFuncs = append(a.afterFuncs, fn)
|
||||
}
|
||||
|
||||
// note: can be called only once because TxAppInfo is cleared
|
||||
func (a *TxAppInfo) runAfterFuncs(txErr error) error {
|
||||
a.mu.Lock()
|
||||
defer a.mu.Unlock()
|
||||
|
||||
var errs []error
|
||||
|
||||
for _, call := range a.afterFuncs {
|
||||
if err := call(txErr); err != nil {
|
||||
errs = append(errs, err)
|
||||
}
|
||||
}
|
||||
|
||||
a.afterFuncs = nil
|
||||
|
||||
if len(errs) > 0 {
|
||||
return fmt.Errorf("transaction afterFunc errors: %w", errors.Join(errs...))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
413
core/db_tx_test.go
Normal file
413
core/db_tx_test.go
Normal file
|
@ -0,0 +1,413 @@
|
|||
package core_test
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tests"
|
||||
)
|
||||
|
||||
func TestRunInTransaction(t *testing.T) {
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
t.Run("failed nested transaction", func(t *testing.T) {
|
||||
app.RunInTransaction(func(txApp core.App) error {
|
||||
superuser, _ := txApp.FindAuthRecordByEmail(core.CollectionNameSuperusers, "test@example.com")
|
||||
|
||||
return txApp.RunInTransaction(func(tx2Dao core.App) error {
|
||||
if err := tx2Dao.Delete(superuser); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return errors.New("test error")
|
||||
})
|
||||
})
|
||||
|
||||
// superuser should still exist
|
||||
superuser, _ := app.FindAuthRecordByEmail(core.CollectionNameSuperusers, "test@example.com")
|
||||
if superuser == nil {
|
||||
t.Fatal("Expected superuser test@example.com to not be deleted")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("successful nested transaction", func(t *testing.T) {
|
||||
app.RunInTransaction(func(txApp core.App) error {
|
||||
superuser, _ := txApp.FindAuthRecordByEmail(core.CollectionNameSuperusers, "test@example.com")
|
||||
|
||||
return txApp.RunInTransaction(func(tx2Dao core.App) error {
|
||||
return tx2Dao.Delete(superuser)
|
||||
})
|
||||
})
|
||||
|
||||
// superuser should have been deleted
|
||||
superuser, _ := app.FindAuthRecordByEmail(core.CollectionNameSuperusers, "test@example.com")
|
||||
if superuser != nil {
|
||||
t.Fatalf("Expected superuser test@example.com to be deleted, found %v", superuser)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestTransactionHooksCallsOnFailure(t *testing.T) {
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
createHookCalls := 0
|
||||
updateHookCalls := 0
|
||||
deleteHookCalls := 0
|
||||
afterCreateHookCalls := 0
|
||||
afterUpdateHookCalls := 0
|
||||
afterDeleteHookCalls := 0
|
||||
|
||||
app.OnModelCreate().BindFunc(func(e *core.ModelEvent) error {
|
||||
createHookCalls++
|
||||
return e.Next()
|
||||
})
|
||||
|
||||
app.OnModelUpdate().BindFunc(func(e *core.ModelEvent) error {
|
||||
updateHookCalls++
|
||||
return e.Next()
|
||||
})
|
||||
|
||||
app.OnModelDelete().BindFunc(func(e *core.ModelEvent) error {
|
||||
deleteHookCalls++
|
||||
return e.Next()
|
||||
})
|
||||
|
||||
app.OnModelAfterCreateSuccess().BindFunc(func(e *core.ModelEvent) error {
|
||||
afterCreateHookCalls++
|
||||
return e.Next()
|
||||
})
|
||||
|
||||
app.OnModelAfterUpdateSuccess().BindFunc(func(e *core.ModelEvent) error {
|
||||
afterUpdateHookCalls++
|
||||
return e.Next()
|
||||
})
|
||||
|
||||
app.OnModelAfterDeleteSuccess().BindFunc(func(e *core.ModelEvent) error {
|
||||
afterDeleteHookCalls++
|
||||
return e.Next()
|
||||
})
|
||||
|
||||
existingModel, _ := app.FindAuthRecordByEmail(core.CollectionNameSuperusers, "test@example.com")
|
||||
|
||||
app.RunInTransaction(func(txApp1 core.App) error {
|
||||
return txApp1.RunInTransaction(func(txApp2 core.App) error {
|
||||
// test create
|
||||
// ---
|
||||
newModel := core.NewRecord(existingModel.Collection())
|
||||
newModel.SetEmail("test_new1@example.com")
|
||||
newModel.SetPassword("1234567890")
|
||||
if err := txApp2.Save(newModel); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// test update (twice)
|
||||
// ---
|
||||
if err := txApp2.Save(existingModel); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := txApp2.Save(existingModel); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// test delete
|
||||
// ---
|
||||
if err := txApp2.Delete(newModel); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
return errors.New("test_tx_error")
|
||||
})
|
||||
})
|
||||
|
||||
if createHookCalls != 1 {
|
||||
t.Errorf("Expected createHookCalls to be called 1 time, got %d", createHookCalls)
|
||||
}
|
||||
if updateHookCalls != 2 {
|
||||
t.Errorf("Expected updateHookCalls to be called 2 times, got %d", updateHookCalls)
|
||||
}
|
||||
if deleteHookCalls != 1 {
|
||||
t.Errorf("Expected deleteHookCalls to be called 1 time, got %d", deleteHookCalls)
|
||||
}
|
||||
if afterCreateHookCalls != 0 {
|
||||
t.Errorf("Expected afterCreateHookCalls to be called 0 times, got %d", afterCreateHookCalls)
|
||||
}
|
||||
if afterUpdateHookCalls != 0 {
|
||||
t.Errorf("Expected afterUpdateHookCalls to be called 0 times, got %d", afterUpdateHookCalls)
|
||||
}
|
||||
if afterDeleteHookCalls != 0 {
|
||||
t.Errorf("Expected afterDeleteHookCalls to be called 0 times, got %d", afterDeleteHookCalls)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTransactionHooksCallsOnSuccess(t *testing.T) {
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
createHookCalls := 0
|
||||
updateHookCalls := 0
|
||||
deleteHookCalls := 0
|
||||
afterCreateHookCalls := 0
|
||||
afterUpdateHookCalls := 0
|
||||
afterDeleteHookCalls := 0
|
||||
|
||||
app.OnModelCreate().BindFunc(func(e *core.ModelEvent) error {
|
||||
createHookCalls++
|
||||
return e.Next()
|
||||
})
|
||||
|
||||
app.OnModelUpdate().BindFunc(func(e *core.ModelEvent) error {
|
||||
updateHookCalls++
|
||||
return e.Next()
|
||||
})
|
||||
|
||||
app.OnModelDelete().BindFunc(func(e *core.ModelEvent) error {
|
||||
deleteHookCalls++
|
||||
return e.Next()
|
||||
})
|
||||
|
||||
app.OnModelAfterCreateSuccess().BindFunc(func(e *core.ModelEvent) error {
|
||||
if e.App.IsTransactional() {
|
||||
t.Fatal("Expected e.App to be non-transactional")
|
||||
}
|
||||
|
||||
afterCreateHookCalls++
|
||||
return e.Next()
|
||||
})
|
||||
|
||||
app.OnModelAfterUpdateSuccess().BindFunc(func(e *core.ModelEvent) error {
|
||||
if e.App.IsTransactional() {
|
||||
t.Fatal("Expected e.App to be non-transactional")
|
||||
}
|
||||
|
||||
afterUpdateHookCalls++
|
||||
return e.Next()
|
||||
})
|
||||
|
||||
app.OnModelAfterDeleteSuccess().BindFunc(func(e *core.ModelEvent) error {
|
||||
if e.App.IsTransactional() {
|
||||
t.Fatal("Expected e.App to be non-transactional")
|
||||
}
|
||||
|
||||
afterDeleteHookCalls++
|
||||
return e.Next()
|
||||
})
|
||||
|
||||
existingModel, _ := app.FindAuthRecordByEmail(core.CollectionNameSuperusers, "test@example.com")
|
||||
|
||||
app.RunInTransaction(func(txApp1 core.App) error {
|
||||
return txApp1.RunInTransaction(func(txApp2 core.App) error {
|
||||
// test create
|
||||
// ---
|
||||
newModel := core.NewRecord(existingModel.Collection())
|
||||
newModel.SetEmail("test_new1@example.com")
|
||||
newModel.SetPassword("1234567890")
|
||||
if err := txApp2.Save(newModel); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// test update (twice)
|
||||
// ---
|
||||
if err := txApp2.Save(existingModel); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := txApp2.Save(existingModel); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// test delete
|
||||
// ---
|
||||
if err := txApp2.Delete(newModel); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
})
|
||||
|
||||
if createHookCalls != 1 {
|
||||
t.Errorf("Expected createHookCalls to be called 1 time, got %d", createHookCalls)
|
||||
}
|
||||
if updateHookCalls != 2 {
|
||||
t.Errorf("Expected updateHookCalls to be called 2 times, got %d", updateHookCalls)
|
||||
}
|
||||
if deleteHookCalls != 1 {
|
||||
t.Errorf("Expected deleteHookCalls to be called 1 time, got %d", deleteHookCalls)
|
||||
}
|
||||
if afterCreateHookCalls != 1 {
|
||||
t.Errorf("Expected afterCreateHookCalls to be called 1 time, got %d", afterCreateHookCalls)
|
||||
}
|
||||
if afterUpdateHookCalls != 2 {
|
||||
t.Errorf("Expected afterUpdateHookCalls to be called 2 times, got %d", afterUpdateHookCalls)
|
||||
}
|
||||
if afterDeleteHookCalls != 1 {
|
||||
t.Errorf("Expected afterDeleteHookCalls to be called 1 time, got %d", afterDeleteHookCalls)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTransactionFromInnerCreateHook(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
app.OnRecordCreateExecute("demo2").BindFunc(func(e *core.RecordEvent) error {
|
||||
originalApp := e.App
|
||||
return e.App.RunInTransaction(func(txApp core.App) error {
|
||||
e.App = txApp
|
||||
defer func() {
|
||||
e.App = originalApp
|
||||
}()
|
||||
|
||||
nextErr := e.Next()
|
||||
|
||||
return nextErr
|
||||
})
|
||||
})
|
||||
|
||||
app.OnRecordAfterCreateSuccess("demo2").BindFunc(func(e *core.RecordEvent) error {
|
||||
if e.App.IsTransactional() {
|
||||
t.Fatal("Expected e.App to be non-transactional")
|
||||
}
|
||||
|
||||
// perform a db query with the app instance to ensure that it is still valid
|
||||
_, err := e.App.FindFirstRecordByFilter("demo2", "1=1")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to perform a db query after tx success: %v", err)
|
||||
}
|
||||
|
||||
return e.Next()
|
||||
})
|
||||
|
||||
collection, err := app.FindCollectionByNameOrId("demo2")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
record := core.NewRecord(collection)
|
||||
|
||||
record.Set("title", "test_inner_tx")
|
||||
|
||||
if err = app.Save(record); err != nil {
|
||||
t.Fatalf("Create failed: %v", err)
|
||||
}
|
||||
|
||||
expectedHookCalls := map[string]int{
|
||||
"OnRecordCreateExecute": 1,
|
||||
"OnRecordAfterCreateSuccess": 1,
|
||||
}
|
||||
for k, total := range expectedHookCalls {
|
||||
if found, ok := app.EventCalls[k]; !ok || total != found {
|
||||
t.Fatalf("Expected %q %d calls, got %d", k, total, found)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestTransactionFromInnerUpdateHook(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
app.OnRecordUpdateExecute("demo2").BindFunc(func(e *core.RecordEvent) error {
|
||||
originalApp := e.App
|
||||
return e.App.RunInTransaction(func(txApp core.App) error {
|
||||
e.App = txApp
|
||||
defer func() {
|
||||
e.App = originalApp
|
||||
}()
|
||||
|
||||
nextErr := e.Next()
|
||||
|
||||
return nextErr
|
||||
})
|
||||
})
|
||||
|
||||
app.OnRecordAfterUpdateSuccess("demo2").BindFunc(func(e *core.RecordEvent) error {
|
||||
if e.App.IsTransactional() {
|
||||
t.Fatal("Expected e.App to be non-transactional")
|
||||
}
|
||||
|
||||
// perform a db query with the app instance to ensure that it is still valid
|
||||
_, err := e.App.FindFirstRecordByFilter("demo2", "1=1")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to perform a db query after tx success: %v", err)
|
||||
}
|
||||
|
||||
return e.Next()
|
||||
})
|
||||
|
||||
existingModel, err := app.FindFirstRecordByFilter("demo2", "1=1")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err = app.Save(existingModel); err != nil {
|
||||
t.Fatalf("Update failed: %v", err)
|
||||
}
|
||||
|
||||
expectedHookCalls := map[string]int{
|
||||
"OnRecordUpdateExecute": 1,
|
||||
"OnRecordAfterUpdateSuccess": 1,
|
||||
}
|
||||
for k, total := range expectedHookCalls {
|
||||
if found, ok := app.EventCalls[k]; !ok || total != found {
|
||||
t.Fatalf("Expected %q %d calls, got %d", k, total, found)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestTransactionFromInnerDeleteHook(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
app.OnRecordDeleteExecute("demo2").BindFunc(func(e *core.RecordEvent) error {
|
||||
originalApp := e.App
|
||||
return e.App.RunInTransaction(func(txApp core.App) error {
|
||||
e.App = txApp
|
||||
defer func() {
|
||||
e.App = originalApp
|
||||
}()
|
||||
|
||||
nextErr := e.Next()
|
||||
|
||||
return nextErr
|
||||
})
|
||||
})
|
||||
|
||||
app.OnRecordAfterDeleteSuccess("demo2").BindFunc(func(e *core.RecordEvent) error {
|
||||
if e.App.IsTransactional() {
|
||||
t.Fatal("Expected e.App to be non-transactional")
|
||||
}
|
||||
|
||||
// perform a db query with the app instance to ensure that it is still valid
|
||||
_, err := e.App.FindFirstRecordByFilter("demo2", "1=1")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to perform a db query after tx success: %v", err)
|
||||
}
|
||||
|
||||
return e.Next()
|
||||
})
|
||||
|
||||
existingModel, err := app.FindFirstRecordByFilter("demo2", "1=1")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err = app.Delete(existingModel); err != nil {
|
||||
t.Fatalf("Delete failed: %v", err)
|
||||
}
|
||||
|
||||
expectedHookCalls := map[string]int{
|
||||
"OnRecordDeleteExecute": 1,
|
||||
"OnRecordAfterDeleteSuccess": 1,
|
||||
}
|
||||
for k, total := range expectedHookCalls {
|
||||
if found, ok := app.EventCalls[k]; !ok || total != found {
|
||||
t.Fatalf("Expected %q %d calls, got %d", k, total, found)
|
||||
}
|
||||
}
|
||||
}
|
197
core/event_request.go
Normal file
197
core/event_request.go
Normal file
|
@ -0,0 +1,197 @@
|
|||
package core
|
||||
|
||||
import (
|
||||
"maps"
|
||||
"net/netip"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/pocketbase/pocketbase/tools/inflector"
|
||||
"github.com/pocketbase/pocketbase/tools/router"
|
||||
)
|
||||
|
||||
// Common request store keys used by the middlewares and api handlers.
|
||||
const (
|
||||
RequestEventKeyInfoContext = "infoContext"
|
||||
)
|
||||
|
||||
// RequestEvent defines the PocketBase router handler event.
|
||||
type RequestEvent struct {
|
||||
App App
|
||||
|
||||
cachedRequestInfo *RequestInfo
|
||||
|
||||
Auth *Record
|
||||
|
||||
router.Event
|
||||
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
// RealIP returns the "real" IP address from the configured trusted proxy headers.
|
||||
//
|
||||
// If Settings.TrustedProxy is not configured or the found IP is empty,
|
||||
// it fallbacks to e.RemoteIP().
|
||||
//
|
||||
// NB!
|
||||
// Be careful when used in a security critical context as it relies on
|
||||
// the trusted proxy to be properly configured and your app to be accessible only through it.
|
||||
// If you are not sure, use e.RemoteIP().
|
||||
func (e *RequestEvent) RealIP() string {
|
||||
settings := e.App.Settings()
|
||||
|
||||
for _, h := range settings.TrustedProxy.Headers {
|
||||
headerValues := e.Request.Header.Values(h)
|
||||
if len(headerValues) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
// extract the last header value as it is expected to be the one controlled by the proxy
|
||||
ipsList := headerValues[len(headerValues)-1]
|
||||
if ipsList == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
ips := strings.Split(ipsList, ",")
|
||||
|
||||
if settings.TrustedProxy.UseLeftmostIP {
|
||||
for _, ip := range ips {
|
||||
parsed, err := netip.ParseAddr(strings.TrimSpace(ip))
|
||||
if err == nil {
|
||||
return parsed.StringExpanded()
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for i := len(ips) - 1; i >= 0; i-- {
|
||||
parsed, err := netip.ParseAddr(strings.TrimSpace(ips[i]))
|
||||
if err == nil {
|
||||
return parsed.StringExpanded()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return e.RemoteIP()
|
||||
}
|
||||
|
||||
// HasSuperuserAuth checks whether the current RequestEvent has superuser authentication loaded.
|
||||
func (e *RequestEvent) HasSuperuserAuth() bool {
|
||||
return e.Auth != nil && e.Auth.IsSuperuser()
|
||||
}
|
||||
|
||||
// RequestInfo parses the current request into RequestInfo instance.
|
||||
//
|
||||
// Note that the returned result is cached to avoid copying the request data multiple times
|
||||
// but the auth state and other common store items are always refreshed in case they were changed by another handler.
|
||||
func (e *RequestEvent) RequestInfo() (*RequestInfo, error) {
|
||||
e.mu.Lock()
|
||||
defer e.mu.Unlock()
|
||||
|
||||
if e.cachedRequestInfo != nil {
|
||||
e.cachedRequestInfo.Auth = e.Auth
|
||||
|
||||
infoCtx, _ := e.Get(RequestEventKeyInfoContext).(string)
|
||||
if infoCtx != "" {
|
||||
e.cachedRequestInfo.Context = infoCtx
|
||||
} else {
|
||||
e.cachedRequestInfo.Context = RequestInfoContextDefault
|
||||
}
|
||||
} else {
|
||||
// (re)init e.cachedRequestInfo based on the current request event
|
||||
if err := e.initRequestInfo(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return e.cachedRequestInfo, nil
|
||||
}
|
||||
|
||||
func (e *RequestEvent) initRequestInfo() error {
|
||||
infoCtx, _ := e.Get(RequestEventKeyInfoContext).(string)
|
||||
if infoCtx == "" {
|
||||
infoCtx = RequestInfoContextDefault
|
||||
}
|
||||
|
||||
info := &RequestInfo{
|
||||
Context: infoCtx,
|
||||
Method: e.Request.Method,
|
||||
Query: map[string]string{},
|
||||
Headers: map[string]string{},
|
||||
Body: map[string]any{},
|
||||
}
|
||||
|
||||
if err := e.BindBody(&info.Body); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// extract the first value of all query params
|
||||
query := e.Request.URL.Query()
|
||||
for k, v := range query {
|
||||
if len(v) > 0 {
|
||||
info.Query[k] = v[0]
|
||||
}
|
||||
}
|
||||
|
||||
// extract the first value of all headers and normalizes the keys
|
||||
// ("X-Token" is converted to "x_token")
|
||||
for k, v := range e.Request.Header {
|
||||
if len(v) > 0 {
|
||||
info.Headers[inflector.Snakecase(k)] = v[0]
|
||||
}
|
||||
}
|
||||
|
||||
info.Auth = e.Auth
|
||||
|
||||
e.cachedRequestInfo = info
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
|
||||
const (
|
||||
RequestInfoContextDefault = "default"
|
||||
RequestInfoContextExpand = "expand"
|
||||
RequestInfoContextRealtime = "realtime"
|
||||
RequestInfoContextProtectedFile = "protectedFile"
|
||||
RequestInfoContextBatch = "batch"
|
||||
RequestInfoContextOAuth2 = "oauth2"
|
||||
RequestInfoContextOTP = "otp"
|
||||
RequestInfoContextPasswordAuth = "password"
|
||||
)
|
||||
|
||||
// RequestInfo defines a HTTP request data struct, usually used
|
||||
// as part of the `@request.*` filter resolver.
|
||||
//
|
||||
// The Query and Headers fields contains only the first value for each found entry.
|
||||
type RequestInfo struct {
|
||||
Query map[string]string `json:"query"`
|
||||
Headers map[string]string `json:"headers"`
|
||||
Body map[string]any `json:"body"`
|
||||
Auth *Record `json:"auth"`
|
||||
Method string `json:"method"`
|
||||
Context string `json:"context"`
|
||||
}
|
||||
|
||||
// HasSuperuserAuth checks whether the current RequestInfo instance
|
||||
// has superuser authentication loaded.
|
||||
func (info *RequestInfo) HasSuperuserAuth() bool {
|
||||
return info.Auth != nil && info.Auth.IsSuperuser()
|
||||
}
|
||||
|
||||
// Clone creates a new shallow copy of the current RequestInfo and its Auth record (if any).
|
||||
func (info *RequestInfo) Clone() *RequestInfo {
|
||||
clone := &RequestInfo{
|
||||
Method: info.Method,
|
||||
Context: info.Context,
|
||||
Query: maps.Clone(info.Query),
|
||||
Body: maps.Clone(info.Body),
|
||||
Headers: maps.Clone(info.Headers),
|
||||
}
|
||||
|
||||
if info.Auth != nil {
|
||||
clone.Auth = info.Auth.Fresh()
|
||||
}
|
||||
|
||||
return clone
|
||||
}
|
33
core/event_request_batch.go
Normal file
33
core/event_request_batch.go
Normal file
|
@ -0,0 +1,33 @@
|
|||
package core
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
validation "github.com/go-ozzo/ozzo-validation/v4"
|
||||
"github.com/pocketbase/pocketbase/tools/hook"
|
||||
)
|
||||
|
||||
type BatchRequestEvent struct {
|
||||
hook.Event
|
||||
*RequestEvent
|
||||
|
||||
Batch []*InternalRequest
|
||||
}
|
||||
|
||||
type InternalRequest struct {
|
||||
// note: for uploading files the value must be either *filesystem.File or []*filesystem.File
|
||||
Body map[string]any `form:"body" json:"body"`
|
||||
|
||||
Headers map[string]string `form:"headers" json:"headers"`
|
||||
|
||||
Method string `form:"method" json:"method"`
|
||||
|
||||
URL string `form:"url" json:"url"`
|
||||
}
|
||||
|
||||
func (br InternalRequest) Validate() error {
|
||||
return validation.ValidateStruct(&br,
|
||||
validation.Field(&br.Method, validation.Required, validation.In(http.MethodGet, http.MethodPost, http.MethodPut, http.MethodPatch, http.MethodDelete)),
|
||||
validation.Field(&br.URL, validation.Required, validation.Length(0, 2000)),
|
||||
)
|
||||
}
|
74
core/event_request_batch_test.go
Normal file
74
core/event_request_batch_test.go
Normal file
|
@ -0,0 +1,74 @@
|
|||
package core_test
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tests"
|
||||
)
|
||||
|
||||
func TestInternalRequestValidate(t *testing.T) {
|
||||
scenarios := []struct {
|
||||
name string
|
||||
request core.InternalRequest
|
||||
expectedErrors []string
|
||||
}{
|
||||
{
|
||||
"empty struct",
|
||||
core.InternalRequest{},
|
||||
[]string{"method", "url"},
|
||||
},
|
||||
|
||||
// method
|
||||
{
|
||||
"GET method",
|
||||
core.InternalRequest{URL: "test", Method: http.MethodGet},
|
||||
[]string{},
|
||||
},
|
||||
{
|
||||
"POST method",
|
||||
core.InternalRequest{URL: "test", Method: http.MethodPost},
|
||||
[]string{},
|
||||
},
|
||||
{
|
||||
"PUT method",
|
||||
core.InternalRequest{URL: "test", Method: http.MethodPut},
|
||||
[]string{},
|
||||
},
|
||||
{
|
||||
"PATCH method",
|
||||
core.InternalRequest{URL: "test", Method: http.MethodPatch},
|
||||
[]string{},
|
||||
},
|
||||
{
|
||||
"DELETE method",
|
||||
core.InternalRequest{URL: "test", Method: http.MethodDelete},
|
||||
[]string{},
|
||||
},
|
||||
{
|
||||
"unknown method",
|
||||
core.InternalRequest{URL: "test", Method: "unknown"},
|
||||
[]string{"method"},
|
||||
},
|
||||
|
||||
// url
|
||||
{
|
||||
"url <= 2000",
|
||||
core.InternalRequest{URL: strings.Repeat("a", 2000), Method: http.MethodGet},
|
||||
[]string{},
|
||||
},
|
||||
{
|
||||
"url > 2000",
|
||||
core.InternalRequest{URL: strings.Repeat("a", 2001), Method: http.MethodGet},
|
||||
[]string{"url"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
t.Run(s.name, func(t *testing.T) {
|
||||
tests.TestValidationErrors(t, s.request.Validate(), s.expectedErrors)
|
||||
})
|
||||
}
|
||||
}
|
334
core/event_request_test.go
Normal file
334
core/event_request_test.go
Normal file
|
@ -0,0 +1,334 @@
|
|||
package core_test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tests"
|
||||
)
|
||||
|
||||
func TestEventRequestRealIP(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
headers := map[string][]string{
|
||||
"CF-Connecting-IP": {"1.2.3.4", "1.1.1.1"},
|
||||
"Fly-Client-IP": {"1.2.3.4", "1.1.1.2"},
|
||||
"X-Real-IP": {"1.2.3.4", "1.1.1.3,1.1.1.4"},
|
||||
"X-Forwarded-For": {"1.2.3.4", "invalid,1.1.1.5,1.1.1.6,invalid"},
|
||||
}
|
||||
|
||||
scenarios := []struct {
|
||||
name string
|
||||
headers map[string][]string
|
||||
trustedHeaders []string
|
||||
useLeftmostIP bool
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
"no trusted headers",
|
||||
headers,
|
||||
nil,
|
||||
false,
|
||||
"127.0.0.1",
|
||||
},
|
||||
{
|
||||
"non-matching trusted header",
|
||||
headers,
|
||||
[]string{"header1", "header2"},
|
||||
false,
|
||||
"127.0.0.1",
|
||||
},
|
||||
{
|
||||
"trusted X-Real-IP (rightmost)",
|
||||
headers,
|
||||
[]string{"header1", "x-real-ip", "x-forwarded-for"},
|
||||
false,
|
||||
"1.1.1.4",
|
||||
},
|
||||
{
|
||||
"trusted X-Real-IP (leftmost)",
|
||||
headers,
|
||||
[]string{"header1", "x-real-ip", "x-forwarded-for"},
|
||||
true,
|
||||
"1.1.1.3",
|
||||
},
|
||||
{
|
||||
"trusted X-Forwarded-For (rightmost)",
|
||||
headers,
|
||||
[]string{"header1", "x-forwarded-for"},
|
||||
false,
|
||||
"1.1.1.6",
|
||||
},
|
||||
{
|
||||
"trusted X-Forwarded-For (leftmost)",
|
||||
headers,
|
||||
[]string{"header1", "x-forwarded-for"},
|
||||
true,
|
||||
"1.1.1.5",
|
||||
},
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
t.Run(s.name, func(t *testing.T) {
|
||||
app, err := tests.NewTestApp()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer app.Cleanup()
|
||||
|
||||
app.Settings().TrustedProxy.Headers = s.trustedHeaders
|
||||
app.Settings().TrustedProxy.UseLeftmostIP = s.useLeftmostIP
|
||||
|
||||
event := core.RequestEvent{}
|
||||
event.App = app
|
||||
|
||||
event.Request, err = http.NewRequest(http.MethodGet, "/", nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
event.Request.RemoteAddr = "127.0.0.1:80" // fallback
|
||||
|
||||
for k, values := range s.headers {
|
||||
for _, v := range values {
|
||||
event.Request.Header.Add(k, v)
|
||||
}
|
||||
}
|
||||
|
||||
result := event.RealIP()
|
||||
|
||||
if result != s.expected {
|
||||
t.Fatalf("Expected ip %q, got %q", s.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEventRequestHasSuperUserAuth(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
user, err := app.FindAuthRecordByEmail("users", "test@example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
superuser, err := app.FindAuthRecordByEmail(core.CollectionNameSuperusers, "test@example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
scenarios := []struct {
|
||||
name string
|
||||
record *core.Record
|
||||
expected bool
|
||||
}{
|
||||
{"nil record", nil, false},
|
||||
{"regular user record", user, false},
|
||||
{"superuser record", superuser, true},
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
t.Run(s.name, func(t *testing.T) {
|
||||
e := core.RequestEvent{}
|
||||
e.Auth = s.record
|
||||
|
||||
result := e.HasSuperuserAuth()
|
||||
|
||||
if result != s.expected {
|
||||
t.Fatalf("Expected %v, got %v", s.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequestEventRequestInfo(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
userCol, err := app.FindCollectionByNameOrId("users")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
user1 := core.NewRecord(userCol)
|
||||
user1.Id = "user1"
|
||||
user1.SetEmail("test1@example.com")
|
||||
|
||||
user2 := core.NewRecord(userCol)
|
||||
user2.Id = "user2"
|
||||
user2.SetEmail("test2@example.com")
|
||||
|
||||
testBody := `{"a":123,"b":"test"}`
|
||||
|
||||
event := core.RequestEvent{}
|
||||
event.Request, err = http.NewRequest("POST", "/test?q1=123&q2=456", strings.NewReader(testBody))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
event.Request.Header.Add("content-type", "application/json")
|
||||
event.Request.Header.Add("x-test", "test")
|
||||
event.Set(core.RequestEventKeyInfoContext, "test")
|
||||
event.Auth = user1
|
||||
|
||||
t.Run("init", func(t *testing.T) {
|
||||
info, err := event.RequestInfo()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to resolve request info: %v", err)
|
||||
}
|
||||
|
||||
raw, err := json.Marshal(info)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to serialize request info: %v", err)
|
||||
}
|
||||
rawStr := string(raw)
|
||||
|
||||
expected := `{"query":{"q1":"123","q2":"456"},"headers":{"content_type":"application/json","x_test":"test"},"body":{"a":123,"b":"test"},"auth":{"avatar":"","collectionId":"_pb_users_auth_","collectionName":"users","created":"","emailVisibility":false,"file":[],"id":"user1","name":"","rel":"","updated":"","username":"","verified":false},"method":"POST","context":"test"}`
|
||||
|
||||
if expected != rawStr {
|
||||
t.Fatalf("Expected\n%v\ngot\n%v", expected, rawStr)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("change user and context", func(t *testing.T) {
|
||||
event.Set(core.RequestEventKeyInfoContext, "test2")
|
||||
event.Auth = user2
|
||||
|
||||
info, err := event.RequestInfo()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to resolve request info: %v", err)
|
||||
}
|
||||
|
||||
raw, err := json.Marshal(info)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to serialize request info: %v", err)
|
||||
}
|
||||
rawStr := string(raw)
|
||||
|
||||
expected := `{"query":{"q1":"123","q2":"456"},"headers":{"content_type":"application/json","x_test":"test"},"body":{"a":123,"b":"test"},"auth":{"avatar":"","collectionId":"_pb_users_auth_","collectionName":"users","created":"","emailVisibility":false,"file":[],"id":"user2","name":"","rel":"","updated":"","username":"","verified":false},"method":"POST","context":"test2"}`
|
||||
|
||||
if expected != rawStr {
|
||||
t.Fatalf("Expected\n%v\ngot\n%v", expected, rawStr)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestRequestInfoHasSuperuserAuth(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
user, err := app.FindAuthRecordByEmail("users", "test@example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
superuser, err := app.FindAuthRecordByEmail(core.CollectionNameSuperusers, "test@example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
event := core.RequestEvent{}
|
||||
event.Request, err = http.NewRequest("POST", "/test?q1=123&q2=456", strings.NewReader(`{"a":123,"b":"test"}`))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
event.Request.Header.Add("content-type", "application/json")
|
||||
|
||||
scenarios := []struct {
|
||||
name string
|
||||
record *core.Record
|
||||
expected bool
|
||||
}{
|
||||
{"nil record", nil, false},
|
||||
{"regular user record", user, false},
|
||||
{"superuser record", superuser, true},
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
t.Run(s.name, func(t *testing.T) {
|
||||
event.Auth = s.record
|
||||
|
||||
info, err := event.RequestInfo()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to resolve request info: %v", err)
|
||||
}
|
||||
|
||||
result := info.HasSuperuserAuth()
|
||||
|
||||
if result != s.expected {
|
||||
t.Fatalf("Expected %v, got %v", s.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequestInfoClone(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
userCol, err := app.FindCollectionByNameOrId("users")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
user := core.NewRecord(userCol)
|
||||
user.Id = "user1"
|
||||
user.SetEmail("test1@example.com")
|
||||
|
||||
event := core.RequestEvent{}
|
||||
event.Request, err = http.NewRequest("POST", "/test?q1=123&q2=456", strings.NewReader(`{"a":123,"b":"test"}`))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
event.Request.Header.Add("content-type", "application/json")
|
||||
event.Auth = user
|
||||
|
||||
info, err := event.RequestInfo()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to resolve request info: %v", err)
|
||||
}
|
||||
|
||||
clone := info.Clone()
|
||||
|
||||
// modify the clone fields to ensure that it is a shallow copy
|
||||
clone.Headers["new_header"] = "test"
|
||||
clone.Query["new_query"] = "test"
|
||||
clone.Body["new_body"] = "test"
|
||||
clone.Auth.Id = "user2" // should be a Fresh copy of the record
|
||||
|
||||
// check the original data
|
||||
// ---
|
||||
originalRaw, err := json.Marshal(info)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to serialize original request info: %v", err)
|
||||
}
|
||||
originalRawStr := string(originalRaw)
|
||||
|
||||
expectedRawStr := `{"query":{"q1":"123","q2":"456"},"headers":{"content_type":"application/json"},"body":{"a":123,"b":"test"},"auth":{"avatar":"","collectionId":"_pb_users_auth_","collectionName":"users","created":"","emailVisibility":false,"file":[],"id":"user1","name":"","rel":"","updated":"","username":"","verified":false},"method":"POST","context":"default"}`
|
||||
if expectedRawStr != originalRawStr {
|
||||
t.Fatalf("Expected original info\n%v\ngot\n%v", expectedRawStr, originalRawStr)
|
||||
}
|
||||
|
||||
// check the clone data
|
||||
// ---
|
||||
cloneRaw, err := json.Marshal(clone)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to serialize clone request info: %v", err)
|
||||
}
|
||||
cloneRawStr := string(cloneRaw)
|
||||
|
||||
expectedCloneStr := `{"query":{"new_query":"test","q1":"123","q2":"456"},"headers":{"content_type":"application/json","new_header":"test"},"body":{"a":123,"b":"test","new_body":"test"},"auth":{"avatar":"","collectionId":"_pb_users_auth_","collectionName":"users","created":"","emailVisibility":false,"file":[],"id":"user2","name":"","rel":"","updated":"","username":"","verified":false},"method":"POST","context":"default"}`
|
||||
if expectedCloneStr != cloneRawStr {
|
||||
t.Fatalf("Expected clone info\n%v\ngot\n%v", expectedCloneStr, cloneRawStr)
|
||||
}
|
||||
}
|
582
core/events.go
Normal file
582
core/events.go
Normal file
|
@ -0,0 +1,582 @@
|
|||
package core
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/pocketbase/pocketbase/tools/auth"
|
||||
"github.com/pocketbase/pocketbase/tools/hook"
|
||||
"github.com/pocketbase/pocketbase/tools/mailer"
|
||||
"github.com/pocketbase/pocketbase/tools/router"
|
||||
"github.com/pocketbase/pocketbase/tools/search"
|
||||
"github.com/pocketbase/pocketbase/tools/subscriptions"
|
||||
"golang.org/x/crypto/acme/autocert"
|
||||
)
|
||||
|
||||
type HookTagger interface {
|
||||
HookTags() []string
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
|
||||
type baseModelEventData struct {
|
||||
Model Model
|
||||
}
|
||||
|
||||
func (e *baseModelEventData) Tags() []string {
|
||||
if e.Model == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if ht, ok := e.Model.(HookTagger); ok {
|
||||
return ht.HookTags()
|
||||
}
|
||||
|
||||
return []string{e.Model.TableName()}
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
|
||||
type baseRecordEventData struct {
|
||||
Record *Record
|
||||
}
|
||||
|
||||
func (e *baseRecordEventData) Tags() []string {
|
||||
if e.Record == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return e.Record.HookTags()
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
|
||||
type baseCollectionEventData struct {
|
||||
Collection *Collection
|
||||
}
|
||||
|
||||
func (e *baseCollectionEventData) Tags() []string {
|
||||
if e.Collection == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
tags := make([]string, 0, 2)
|
||||
|
||||
if e.Collection.Id != "" {
|
||||
tags = append(tags, e.Collection.Id)
|
||||
}
|
||||
|
||||
if e.Collection.Name != "" {
|
||||
tags = append(tags, e.Collection.Name)
|
||||
}
|
||||
|
||||
return tags
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
// App events data
|
||||
// -------------------------------------------------------------------
|
||||
|
||||
type BootstrapEvent struct {
|
||||
hook.Event
|
||||
App App
|
||||
}
|
||||
|
||||
type TerminateEvent struct {
|
||||
hook.Event
|
||||
App App
|
||||
IsRestart bool
|
||||
}
|
||||
|
||||
type BackupEvent struct {
|
||||
hook.Event
|
||||
App App
|
||||
Context context.Context
|
||||
Name string // the name of the backup to create/restore.
|
||||
Exclude []string // list of dir entries to exclude from the backup create/restore.
|
||||
}
|
||||
|
||||
type ServeEvent struct {
|
||||
hook.Event
|
||||
App App
|
||||
Router *router.Router[*RequestEvent]
|
||||
Server *http.Server
|
||||
CertManager *autocert.Manager
|
||||
|
||||
// InstallerFunc is the "installer" function that is called after
|
||||
// successful server tcp bind but only if there is no explicit
|
||||
// superuser record created yet.
|
||||
//
|
||||
// It runs in a separate goroutine and its default value is [apis.DefaultInstallerFunc].
|
||||
//
|
||||
// It receives a system superuser record as argument that you can use to generate
|
||||
// a short-lived auth token (e.g. systemSuperuser.NewStaticAuthToken(30 * time.Minute))
|
||||
// and concatenate it as query param for your installer page
|
||||
// (if you are using the client-side SDKs, you can then load the
|
||||
// token with pb.authStore.save(token) and perform any Web API request
|
||||
// e.g. creating a new superuser).
|
||||
//
|
||||
// Set it to nil if you want to skip the installer.
|
||||
InstallerFunc func(app App, systemSuperuser *Record, baseURL string) error
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
// Settings events data
|
||||
// -------------------------------------------------------------------
|
||||
|
||||
type SettingsListRequestEvent struct {
|
||||
hook.Event
|
||||
*RequestEvent
|
||||
|
||||
Settings *Settings
|
||||
}
|
||||
|
||||
type SettingsUpdateRequestEvent struct {
|
||||
hook.Event
|
||||
*RequestEvent
|
||||
|
||||
OldSettings *Settings
|
||||
NewSettings *Settings
|
||||
}
|
||||
|
||||
type SettingsReloadEvent struct {
|
||||
hook.Event
|
||||
App App
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
// Mailer events data
|
||||
// -------------------------------------------------------------------
|
||||
|
||||
type MailerEvent struct {
|
||||
hook.Event
|
||||
App App
|
||||
|
||||
Mailer mailer.Mailer
|
||||
Message *mailer.Message
|
||||
}
|
||||
|
||||
type MailerRecordEvent struct {
|
||||
MailerEvent
|
||||
baseRecordEventData
|
||||
Meta map[string]any
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
// Model events data
|
||||
// -------------------------------------------------------------------
|
||||
|
||||
const (
|
||||
ModelEventTypeCreate = "create"
|
||||
ModelEventTypeUpdate = "update"
|
||||
ModelEventTypeDelete = "delete"
|
||||
ModelEventTypeValidate = "validate"
|
||||
)
|
||||
|
||||
type ModelEvent struct {
|
||||
hook.Event
|
||||
App App
|
||||
baseModelEventData
|
||||
Context context.Context
|
||||
|
||||
// Could be any of the ModelEventType* constants, like:
|
||||
// - create
|
||||
// - update
|
||||
// - delete
|
||||
// - validate
|
||||
Type string
|
||||
}
|
||||
|
||||
type ModelErrorEvent struct {
|
||||
Error error
|
||||
ModelEvent
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
// Record events data
|
||||
// -------------------------------------------------------------------
|
||||
|
||||
type RecordEvent struct {
|
||||
hook.Event
|
||||
App App
|
||||
baseRecordEventData
|
||||
Context context.Context
|
||||
|
||||
// Could be any of the ModelEventType* constants, like:
|
||||
// - create
|
||||
// - update
|
||||
// - delete
|
||||
// - validate
|
||||
Type string
|
||||
}
|
||||
|
||||
type RecordErrorEvent struct {
|
||||
Error error
|
||||
RecordEvent
|
||||
}
|
||||
|
||||
func syncModelEventWithRecordEvent(me *ModelEvent, re *RecordEvent) {
|
||||
me.App = re.App
|
||||
me.Context = re.Context
|
||||
me.Type = re.Type
|
||||
|
||||
// @todo enable if after profiling doesn't have significant impact
|
||||
// skip for now to avoid excessive checks and assume that the
|
||||
// Model and the Record fields still points to the same instance
|
||||
//
|
||||
// if _, ok := me.Model.(*Record); ok {
|
||||
// me.Model = re.Record
|
||||
// } else if proxy, ok := me.Model.(RecordProxy); ok {
|
||||
// proxy.SetProxyRecord(re.Record)
|
||||
// }
|
||||
}
|
||||
|
||||
func syncRecordEventWithModelEvent(re *RecordEvent, me *ModelEvent) {
|
||||
re.App = me.App
|
||||
re.Context = me.Context
|
||||
re.Type = me.Type
|
||||
}
|
||||
|
||||
func newRecordEventFromModelEvent(me *ModelEvent) (*RecordEvent, bool) {
|
||||
record, ok := me.Model.(*Record)
|
||||
if !ok {
|
||||
proxy, ok := me.Model.(RecordProxy)
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
record = proxy.ProxyRecord()
|
||||
}
|
||||
|
||||
re := new(RecordEvent)
|
||||
re.App = me.App
|
||||
re.Context = me.Context
|
||||
re.Type = me.Type
|
||||
re.Record = record
|
||||
|
||||
return re, true
|
||||
}
|
||||
|
||||
func newRecordErrorEventFromModelErrorEvent(me *ModelErrorEvent) (*RecordErrorEvent, bool) {
|
||||
recordEvent, ok := newRecordEventFromModelEvent(&me.ModelEvent)
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
re := new(RecordErrorEvent)
|
||||
re.RecordEvent = *recordEvent
|
||||
re.Error = me.Error
|
||||
|
||||
return re, true
|
||||
}
|
||||
|
||||
func syncModelErrorEventWithRecordErrorEvent(me *ModelErrorEvent, re *RecordErrorEvent) {
|
||||
syncModelEventWithRecordEvent(&me.ModelEvent, &re.RecordEvent)
|
||||
me.Error = re.Error
|
||||
}
|
||||
|
||||
func syncRecordErrorEventWithModelErrorEvent(re *RecordErrorEvent, me *ModelErrorEvent) {
|
||||
syncRecordEventWithModelEvent(&re.RecordEvent, &me.ModelEvent)
|
||||
re.Error = me.Error
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
// Collection events data
|
||||
// -------------------------------------------------------------------
|
||||
|
||||
type CollectionEvent struct {
|
||||
hook.Event
|
||||
App App
|
||||
baseCollectionEventData
|
||||
Context context.Context
|
||||
|
||||
// Could be any of the ModelEventType* constants, like:
|
||||
// - create
|
||||
// - update
|
||||
// - delete
|
||||
// - validate
|
||||
Type string
|
||||
}
|
||||
|
||||
type CollectionErrorEvent struct {
|
||||
Error error
|
||||
CollectionEvent
|
||||
}
|
||||
|
||||
func syncModelEventWithCollectionEvent(me *ModelEvent, ce *CollectionEvent) {
|
||||
me.App = ce.App
|
||||
me.Context = ce.Context
|
||||
me.Type = ce.Type
|
||||
me.Model = ce.Collection
|
||||
}
|
||||
|
||||
func syncCollectionEventWithModelEvent(ce *CollectionEvent, me *ModelEvent) {
|
||||
ce.App = me.App
|
||||
ce.Context = me.Context
|
||||
ce.Type = me.Type
|
||||
if c, ok := me.Model.(*Collection); ok {
|
||||
ce.Collection = c
|
||||
}
|
||||
}
|
||||
|
||||
func newCollectionEventFromModelEvent(me *ModelEvent) (*CollectionEvent, bool) {
|
||||
record, ok := me.Model.(*Collection)
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
ce := new(CollectionEvent)
|
||||
ce.App = me.App
|
||||
ce.Context = me.Context
|
||||
ce.Type = me.Type
|
||||
ce.Collection = record
|
||||
|
||||
return ce, true
|
||||
}
|
||||
|
||||
func newCollectionErrorEventFromModelErrorEvent(me *ModelErrorEvent) (*CollectionErrorEvent, bool) {
|
||||
collectionevent, ok := newCollectionEventFromModelEvent(&me.ModelEvent)
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
ce := new(CollectionErrorEvent)
|
||||
ce.CollectionEvent = *collectionevent
|
||||
ce.Error = me.Error
|
||||
|
||||
return ce, true
|
||||
}
|
||||
|
||||
func syncModelErrorEventWithCollectionErrorEvent(me *ModelErrorEvent, ce *CollectionErrorEvent) {
|
||||
syncModelEventWithCollectionEvent(&me.ModelEvent, &ce.CollectionEvent)
|
||||
me.Error = ce.Error
|
||||
}
|
||||
|
||||
func syncCollectionErrorEventWithModelErrorEvent(ce *CollectionErrorEvent, me *ModelErrorEvent) {
|
||||
syncCollectionEventWithModelEvent(&ce.CollectionEvent, &me.ModelEvent)
|
||||
ce.Error = me.Error
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
// File API events data
|
||||
// -------------------------------------------------------------------
|
||||
|
||||
type FileTokenRequestEvent struct {
|
||||
hook.Event
|
||||
*RequestEvent
|
||||
baseRecordEventData
|
||||
|
||||
Token string
|
||||
}
|
||||
|
||||
type FileDownloadRequestEvent struct {
|
||||
hook.Event
|
||||
*RequestEvent
|
||||
baseCollectionEventData
|
||||
|
||||
Record *Record
|
||||
FileField *FileField
|
||||
ServedPath string
|
||||
ServedName string
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
// Collection API events data
|
||||
// -------------------------------------------------------------------
|
||||
|
||||
type CollectionsListRequestEvent struct {
|
||||
hook.Event
|
||||
*RequestEvent
|
||||
|
||||
Collections []*Collection
|
||||
Result *search.Result
|
||||
}
|
||||
|
||||
type CollectionsImportRequestEvent struct {
|
||||
hook.Event
|
||||
*RequestEvent
|
||||
|
||||
CollectionsData []map[string]any
|
||||
DeleteMissing bool
|
||||
}
|
||||
|
||||
type CollectionRequestEvent struct {
|
||||
hook.Event
|
||||
*RequestEvent
|
||||
baseCollectionEventData
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
// Realtime API events data
|
||||
// -------------------------------------------------------------------
|
||||
|
||||
type RealtimeConnectRequestEvent struct {
|
||||
hook.Event
|
||||
*RequestEvent
|
||||
|
||||
Client subscriptions.Client
|
||||
|
||||
// note: modifying it after the connect has no effect
|
||||
IdleTimeout time.Duration
|
||||
}
|
||||
|
||||
type RealtimeMessageEvent struct {
|
||||
hook.Event
|
||||
*RequestEvent
|
||||
|
||||
Client subscriptions.Client
|
||||
Message *subscriptions.Message
|
||||
}
|
||||
|
||||
type RealtimeSubscribeRequestEvent struct {
|
||||
hook.Event
|
||||
*RequestEvent
|
||||
|
||||
Client subscriptions.Client
|
||||
Subscriptions []string
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
// Record CRUD API events data
|
||||
// -------------------------------------------------------------------
|
||||
|
||||
type RecordsListRequestEvent struct {
|
||||
hook.Event
|
||||
*RequestEvent
|
||||
baseCollectionEventData
|
||||
|
||||
// @todo consider removing and maybe add as generic to the search.Result?
|
||||
Records []*Record
|
||||
Result *search.Result
|
||||
}
|
||||
|
||||
type RecordRequestEvent struct {
|
||||
hook.Event
|
||||
*RequestEvent
|
||||
baseCollectionEventData
|
||||
|
||||
Record *Record
|
||||
}
|
||||
|
||||
type RecordEnrichEvent struct {
|
||||
hook.Event
|
||||
App App
|
||||
baseRecordEventData
|
||||
|
||||
RequestInfo *RequestInfo
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
// Auth Record API events data
|
||||
// -------------------------------------------------------------------
|
||||
|
||||
type RecordCreateOTPRequestEvent struct {
|
||||
hook.Event
|
||||
*RequestEvent
|
||||
baseCollectionEventData
|
||||
|
||||
Record *Record
|
||||
Password string
|
||||
}
|
||||
|
||||
type RecordAuthWithOTPRequestEvent struct {
|
||||
hook.Event
|
||||
*RequestEvent
|
||||
baseCollectionEventData
|
||||
|
||||
Record *Record
|
||||
OTP *OTP
|
||||
}
|
||||
|
||||
type RecordAuthRequestEvent struct {
|
||||
hook.Event
|
||||
*RequestEvent
|
||||
baseCollectionEventData
|
||||
|
||||
Record *Record
|
||||
Token string
|
||||
Meta any
|
||||
AuthMethod string
|
||||
}
|
||||
|
||||
type RecordAuthWithPasswordRequestEvent struct {
|
||||
hook.Event
|
||||
*RequestEvent
|
||||
baseCollectionEventData
|
||||
|
||||
Record *Record
|
||||
Identity string
|
||||
IdentityField string
|
||||
Password string
|
||||
}
|
||||
|
||||
type RecordAuthWithOAuth2RequestEvent struct {
|
||||
hook.Event
|
||||
*RequestEvent
|
||||
baseCollectionEventData
|
||||
|
||||
ProviderName string
|
||||
ProviderClient auth.Provider
|
||||
Record *Record
|
||||
OAuth2User *auth.AuthUser
|
||||
CreateData map[string]any
|
||||
IsNewRecord bool
|
||||
}
|
||||
|
||||
type RecordAuthRefreshRequestEvent struct {
|
||||
hook.Event
|
||||
*RequestEvent
|
||||
baseCollectionEventData
|
||||
|
||||
Record *Record
|
||||
}
|
||||
|
||||
type RecordRequestPasswordResetRequestEvent struct {
|
||||
hook.Event
|
||||
*RequestEvent
|
||||
baseCollectionEventData
|
||||
|
||||
Record *Record
|
||||
}
|
||||
|
||||
type RecordConfirmPasswordResetRequestEvent struct {
|
||||
hook.Event
|
||||
*RequestEvent
|
||||
baseCollectionEventData
|
||||
|
||||
Record *Record
|
||||
}
|
||||
|
||||
type RecordRequestVerificationRequestEvent struct {
|
||||
hook.Event
|
||||
*RequestEvent
|
||||
baseCollectionEventData
|
||||
|
||||
Record *Record
|
||||
}
|
||||
|
||||
type RecordConfirmVerificationRequestEvent struct {
|
||||
hook.Event
|
||||
*RequestEvent
|
||||
baseCollectionEventData
|
||||
|
||||
Record *Record
|
||||
}
|
||||
|
||||
type RecordRequestEmailChangeRequestEvent struct {
|
||||
hook.Event
|
||||
*RequestEvent
|
||||
baseCollectionEventData
|
||||
|
||||
Record *Record
|
||||
NewEmail string
|
||||
}
|
||||
|
||||
type RecordConfirmEmailChangeRequestEvent struct {
|
||||
hook.Event
|
||||
*RequestEvent
|
||||
baseCollectionEventData
|
||||
|
||||
Record *Record
|
||||
NewEmail string
|
||||
}
|
140
core/external_auth_model.go
Normal file
140
core/external_auth_model.go
Normal file
|
@ -0,0 +1,140 @@
|
|||
package core
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
validation "github.com/go-ozzo/ozzo-validation/v4"
|
||||
"github.com/pocketbase/pocketbase/tools/auth"
|
||||
"github.com/pocketbase/pocketbase/tools/hook"
|
||||
"github.com/pocketbase/pocketbase/tools/types"
|
||||
)
|
||||
|
||||
var (
|
||||
_ Model = (*ExternalAuth)(nil)
|
||||
_ PreValidator = (*ExternalAuth)(nil)
|
||||
_ RecordProxy = (*ExternalAuth)(nil)
|
||||
)
|
||||
|
||||
const CollectionNameExternalAuths = "_externalAuths"
|
||||
|
||||
// ExternalAuth defines a Record proxy for working with the externalAuths collection.
|
||||
type ExternalAuth struct {
|
||||
*Record
|
||||
}
|
||||
|
||||
// NewExternalAuth instantiates and returns a new blank *ExternalAuth model.
|
||||
//
|
||||
// Example usage:
|
||||
//
|
||||
// ea := core.NewExternalAuth(app)
|
||||
// ea.SetRecordRef(user.Id)
|
||||
// ea.SetCollectionRef(user.Collection().Id)
|
||||
// ea.SetProvider("google")
|
||||
// ea.SetProviderId("...")
|
||||
// app.Save(ea)
|
||||
func NewExternalAuth(app App) *ExternalAuth {
|
||||
m := &ExternalAuth{}
|
||||
|
||||
c, err := app.FindCachedCollectionByNameOrId(CollectionNameExternalAuths)
|
||||
if err != nil {
|
||||
// this is just to make tests easier since it is a system collection and it is expected to be always accessible
|
||||
// (note: the loaded record is further checked on ExternalAuth.PreValidate())
|
||||
c = NewBaseCollection("@__invalid__")
|
||||
}
|
||||
|
||||
m.Record = NewRecord(c)
|
||||
|
||||
return m
|
||||
}
|
||||
|
||||
// PreValidate implements the [PreValidator] interface and checks
|
||||
// whether the proxy is properly loaded.
|
||||
func (m *ExternalAuth) PreValidate(ctx context.Context, app App) error {
|
||||
if m.Record == nil || m.Record.Collection().Name != CollectionNameExternalAuths {
|
||||
return errors.New("missing or invalid ExternalAuth ProxyRecord")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ProxyRecord returns the proxied Record model.
|
||||
func (m *ExternalAuth) ProxyRecord() *Record {
|
||||
return m.Record
|
||||
}
|
||||
|
||||
// SetProxyRecord loads the specified record model into the current proxy.
|
||||
func (m *ExternalAuth) SetProxyRecord(record *Record) {
|
||||
m.Record = record
|
||||
}
|
||||
|
||||
// CollectionRef returns the "collectionRef" field value.
|
||||
func (m *ExternalAuth) CollectionRef() string {
|
||||
return m.GetString("collectionRef")
|
||||
}
|
||||
|
||||
// SetCollectionRef updates the "collectionRef" record field value.
|
||||
func (m *ExternalAuth) SetCollectionRef(collectionId string) {
|
||||
m.Set("collectionRef", collectionId)
|
||||
}
|
||||
|
||||
// RecordRef returns the "recordRef" record field value.
|
||||
func (m *ExternalAuth) RecordRef() string {
|
||||
return m.GetString("recordRef")
|
||||
}
|
||||
|
||||
// SetRecordRef updates the "recordRef" record field value.
|
||||
func (m *ExternalAuth) SetRecordRef(recordId string) {
|
||||
m.Set("recordRef", recordId)
|
||||
}
|
||||
|
||||
// Provider returns the "provider" record field value.
|
||||
func (m *ExternalAuth) Provider() string {
|
||||
return m.GetString("provider")
|
||||
}
|
||||
|
||||
// SetProvider updates the "provider" record field value.
|
||||
func (m *ExternalAuth) SetProvider(provider string) {
|
||||
m.Set("provider", provider)
|
||||
}
|
||||
|
||||
// Provider returns the "providerId" record field value.
|
||||
func (m *ExternalAuth) ProviderId() string {
|
||||
return m.GetString("providerId")
|
||||
}
|
||||
|
||||
// SetProvider updates the "providerId" record field value.
|
||||
func (m *ExternalAuth) SetProviderId(providerId string) {
|
||||
m.Set("providerId", providerId)
|
||||
}
|
||||
|
||||
// Created returns the "created" record field value.
|
||||
func (m *ExternalAuth) Created() types.DateTime {
|
||||
return m.GetDateTime("created")
|
||||
}
|
||||
|
||||
// Updated returns the "updated" record field value.
|
||||
func (m *ExternalAuth) Updated() types.DateTime {
|
||||
return m.GetDateTime("updated")
|
||||
}
|
||||
|
||||
func (app *BaseApp) registerExternalAuthHooks() {
|
||||
recordRefHooks[*ExternalAuth](app, CollectionNameExternalAuths, CollectionTypeAuth)
|
||||
|
||||
app.OnRecordValidate(CollectionNameExternalAuths).Bind(&hook.Handler[*RecordEvent]{
|
||||
Func: func(e *RecordEvent) error {
|
||||
providerNames := make([]any, 0, len(auth.Providers))
|
||||
for name := range auth.Providers {
|
||||
providerNames = append(providerNames, name)
|
||||
}
|
||||
|
||||
provider := e.Record.GetString("provider")
|
||||
if err := validation.Validate(provider, validation.Required, validation.In(providerNames...)); err != nil {
|
||||
return validation.Errors{"provider": err}
|
||||
}
|
||||
|
||||
return e.Next()
|
||||
},
|
||||
Priority: 99,
|
||||
})
|
||||
}
|
310
core/external_auth_model_test.go
Normal file
310
core/external_auth_model_test.go
Normal file
|
@ -0,0 +1,310 @@
|
|||
package core_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tests"
|
||||
"github.com/pocketbase/pocketbase/tools/types"
|
||||
)
|
||||
|
||||
func TestNewExternalAuth(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
ea := core.NewExternalAuth(app)
|
||||
|
||||
if ea.Collection().Name != core.CollectionNameExternalAuths {
|
||||
t.Fatalf("Expected record with %q collection, got %q", core.CollectionNameExternalAuths, ea.Collection().Name)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExternalAuthProxyRecord(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
record := core.NewRecord(core.NewBaseCollection("test"))
|
||||
record.Id = "test_id"
|
||||
|
||||
ea := core.ExternalAuth{}
|
||||
ea.SetProxyRecord(record)
|
||||
|
||||
if ea.ProxyRecord() == nil || ea.ProxyRecord().Id != record.Id {
|
||||
t.Fatalf("Expected proxy record with id %q, got %v", record.Id, ea.ProxyRecord())
|
||||
}
|
||||
}
|
||||
|
||||
func TestExternalAuthRecordRef(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
ea := core.NewExternalAuth(app)
|
||||
|
||||
testValues := []string{"test_1", "test2", ""}
|
||||
for i, testValue := range testValues {
|
||||
t.Run(fmt.Sprintf("%d_%q", i, testValue), func(t *testing.T) {
|
||||
ea.SetRecordRef(testValue)
|
||||
|
||||
if v := ea.RecordRef(); v != testValue {
|
||||
t.Fatalf("Expected getter %q, got %q", testValue, v)
|
||||
}
|
||||
|
||||
if v := ea.GetString("recordRef"); v != testValue {
|
||||
t.Fatalf("Expected field value %q, got %q", testValue, v)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExternalAuthCollectionRef(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
ea := core.NewExternalAuth(app)
|
||||
|
||||
testValues := []string{"test_1", "test2", ""}
|
||||
for i, testValue := range testValues {
|
||||
t.Run(fmt.Sprintf("%d_%q", i, testValue), func(t *testing.T) {
|
||||
ea.SetCollectionRef(testValue)
|
||||
|
||||
if v := ea.CollectionRef(); v != testValue {
|
||||
t.Fatalf("Expected getter %q, got %q", testValue, v)
|
||||
}
|
||||
|
||||
if v := ea.GetString("collectionRef"); v != testValue {
|
||||
t.Fatalf("Expected field value %q, got %q", testValue, v)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExternalAuthProvider(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
ea := core.NewExternalAuth(app)
|
||||
|
||||
testValues := []string{"test_1", "test2", ""}
|
||||
for i, testValue := range testValues {
|
||||
t.Run(fmt.Sprintf("%d_%q", i, testValue), func(t *testing.T) {
|
||||
ea.SetProvider(testValue)
|
||||
|
||||
if v := ea.Provider(); v != testValue {
|
||||
t.Fatalf("Expected getter %q, got %q", testValue, v)
|
||||
}
|
||||
|
||||
if v := ea.GetString("provider"); v != testValue {
|
||||
t.Fatalf("Expected field value %q, got %q", testValue, v)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExternalAuthProviderId(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
ea := core.NewExternalAuth(app)
|
||||
|
||||
testValues := []string{"test_1", "test2", ""}
|
||||
for i, testValue := range testValues {
|
||||
t.Run(fmt.Sprintf("%d_%q", i, testValue), func(t *testing.T) {
|
||||
ea.SetProviderId(testValue)
|
||||
|
||||
if v := ea.ProviderId(); v != testValue {
|
||||
t.Fatalf("Expected getter %q, got %q", testValue, v)
|
||||
}
|
||||
|
||||
if v := ea.GetString("providerId"); v != testValue {
|
||||
t.Fatalf("Expected field value %q, got %q", testValue, v)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExternalAuthCreated(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
ea := core.NewExternalAuth(app)
|
||||
|
||||
if v := ea.Created().String(); v != "" {
|
||||
t.Fatalf("Expected empty created, got %q", v)
|
||||
}
|
||||
|
||||
now := types.NowDateTime()
|
||||
ea.SetRaw("created", now)
|
||||
|
||||
if v := ea.Created().String(); v != now.String() {
|
||||
t.Fatalf("Expected %q created, got %q", now.String(), v)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExternalAuthUpdated(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
ea := core.NewExternalAuth(app)
|
||||
|
||||
if v := ea.Updated().String(); v != "" {
|
||||
t.Fatalf("Expected empty updated, got %q", v)
|
||||
}
|
||||
|
||||
now := types.NowDateTime()
|
||||
ea.SetRaw("updated", now)
|
||||
|
||||
if v := ea.Updated().String(); v != now.String() {
|
||||
t.Fatalf("Expected %q updated, got %q", now.String(), v)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExternalAuthPreValidate(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
externalAuthsCol, err := app.FindCollectionByNameOrId(core.CollectionNameExternalAuths)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
user, err := app.FindAuthRecordByEmail("users", "test@example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
t.Run("no proxy record", func(t *testing.T) {
|
||||
externalAuth := &core.ExternalAuth{}
|
||||
|
||||
if err := app.Validate(externalAuth); err == nil {
|
||||
t.Fatal("Expected collection validation error")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("non-ExternalAuth collection", func(t *testing.T) {
|
||||
externalAuth := &core.ExternalAuth{}
|
||||
externalAuth.SetProxyRecord(core.NewRecord(core.NewBaseCollection("invalid")))
|
||||
externalAuth.SetRecordRef(user.Id)
|
||||
externalAuth.SetCollectionRef(user.Collection().Id)
|
||||
externalAuth.SetProvider("gitlab")
|
||||
externalAuth.SetProviderId("test123")
|
||||
|
||||
if err := app.Validate(externalAuth); err == nil {
|
||||
t.Fatal("Expected collection validation error")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("ExternalAuth collection", func(t *testing.T) {
|
||||
externalAuth := &core.ExternalAuth{}
|
||||
externalAuth.SetProxyRecord(core.NewRecord(externalAuthsCol))
|
||||
externalAuth.SetRecordRef(user.Id)
|
||||
externalAuth.SetCollectionRef(user.Collection().Id)
|
||||
externalAuth.SetProvider("gitlab")
|
||||
externalAuth.SetProviderId("test123")
|
||||
|
||||
if err := app.Validate(externalAuth); err != nil {
|
||||
t.Fatalf("Expected nil validation error, got %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestExternalAuthValidateHook(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
user, err := app.FindAuthRecordByEmail("users", "test@example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
demo1, err := app.FindRecordById("demo1", "84nmscqy84lsi1t")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
scenarios := []struct {
|
||||
name string
|
||||
externalAuth func() *core.ExternalAuth
|
||||
expectErrors []string
|
||||
}{
|
||||
{
|
||||
"empty",
|
||||
func() *core.ExternalAuth {
|
||||
return core.NewExternalAuth(app)
|
||||
},
|
||||
[]string{"collectionRef", "recordRef", "provider", "providerId"},
|
||||
},
|
||||
{
|
||||
"non-auth collection",
|
||||
func() *core.ExternalAuth {
|
||||
ea := core.NewExternalAuth(app)
|
||||
ea.SetCollectionRef(demo1.Collection().Id)
|
||||
ea.SetRecordRef(demo1.Id)
|
||||
ea.SetProvider("gitlab")
|
||||
ea.SetProviderId("test123")
|
||||
return ea
|
||||
},
|
||||
[]string{"collectionRef"},
|
||||
},
|
||||
{
|
||||
"disabled provider",
|
||||
func() *core.ExternalAuth {
|
||||
ea := core.NewExternalAuth(app)
|
||||
ea.SetCollectionRef(user.Collection().Id)
|
||||
ea.SetRecordRef("missing")
|
||||
ea.SetProvider("apple")
|
||||
ea.SetProviderId("test123")
|
||||
return ea
|
||||
},
|
||||
[]string{"recordRef"},
|
||||
},
|
||||
{
|
||||
"missing record id",
|
||||
func() *core.ExternalAuth {
|
||||
ea := core.NewExternalAuth(app)
|
||||
ea.SetCollectionRef(user.Collection().Id)
|
||||
ea.SetRecordRef("missing")
|
||||
ea.SetProvider("gitlab")
|
||||
ea.SetProviderId("test123")
|
||||
return ea
|
||||
},
|
||||
[]string{"recordRef"},
|
||||
},
|
||||
{
|
||||
"valid ref",
|
||||
func() *core.ExternalAuth {
|
||||
ea := core.NewExternalAuth(app)
|
||||
ea.SetCollectionRef(user.Collection().Id)
|
||||
ea.SetRecordRef(user.Id)
|
||||
ea.SetProvider("gitlab")
|
||||
ea.SetProviderId("test123")
|
||||
return ea
|
||||
},
|
||||
[]string{},
|
||||
},
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
t.Run(s.name, func(t *testing.T) {
|
||||
errs := app.Validate(s.externalAuth())
|
||||
tests.TestValidationErrors(t, errs, s.expectErrors)
|
||||
})
|
||||
}
|
||||
}
|
61
core/external_auth_query.go
Normal file
61
core/external_auth_query.go
Normal file
|
@ -0,0 +1,61 @@
|
|||
package core
|
||||
|
||||
import (
|
||||
"github.com/pocketbase/dbx"
|
||||
)
|
||||
|
||||
// FindAllExternalAuthsByRecord returns all ExternalAuth models
|
||||
// linked to the provided auth record.
|
||||
func (app *BaseApp) FindAllExternalAuthsByRecord(authRecord *Record) ([]*ExternalAuth, error) {
|
||||
auths := []*ExternalAuth{}
|
||||
|
||||
err := app.RecordQuery(CollectionNameExternalAuths).
|
||||
AndWhere(dbx.HashExp{
|
||||
"collectionRef": authRecord.Collection().Id,
|
||||
"recordRef": authRecord.Id,
|
||||
}).
|
||||
OrderBy("created DESC").
|
||||
All(&auths)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return auths, nil
|
||||
}
|
||||
|
||||
// FindAllExternalAuthsByCollection returns all ExternalAuth models
|
||||
// linked to the provided auth collection.
|
||||
func (app *BaseApp) FindAllExternalAuthsByCollection(collection *Collection) ([]*ExternalAuth, error) {
|
||||
auths := []*ExternalAuth{}
|
||||
|
||||
err := app.RecordQuery(CollectionNameExternalAuths).
|
||||
AndWhere(dbx.HashExp{"collectionRef": collection.Id}).
|
||||
OrderBy("created DESC").
|
||||
All(&auths)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return auths, nil
|
||||
}
|
||||
|
||||
// FindFirstExternalAuthByExpr returns the first available (the most recent created)
|
||||
// ExternalAuth model that satisfies the non-nil expression.
|
||||
func (app *BaseApp) FindFirstExternalAuthByExpr(expr dbx.Expression) (*ExternalAuth, error) {
|
||||
model := &ExternalAuth{}
|
||||
|
||||
err := app.RecordQuery(CollectionNameExternalAuths).
|
||||
AndWhere(dbx.Not(dbx.HashExp{"providerId": ""})). // exclude empty providerIds
|
||||
AndWhere(expr).
|
||||
OrderBy("created DESC").
|
||||
Limit(1).
|
||||
One(model)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return model, nil
|
||||
}
|
176
core/external_auth_query_test.go
Normal file
176
core/external_auth_query_test.go
Normal file
|
@ -0,0 +1,176 @@
|
|||
package core_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/pocketbase/dbx"
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tests"
|
||||
)
|
||||
|
||||
func TestFindAllExternalAuthsByRecord(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
demo1, err := app.FindRecordById("demo1", "84nmscqy84lsi1t")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
superuser1, err := app.FindAuthRecordByEmail(core.CollectionNameSuperusers, "test@example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
user1, err := app.FindAuthRecordByEmail("users", "test@example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
user2, err := app.FindAuthRecordByEmail("users", "test2@example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
user3, err := app.FindAuthRecordByEmail("users", "test3@example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
client1, err := app.FindAuthRecordByEmail("clients", "test@example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
scenarios := []struct {
|
||||
record *core.Record
|
||||
expected []string
|
||||
}{
|
||||
{demo1, nil},
|
||||
{superuser1, nil},
|
||||
{client1, []string{"f1z5b3843pzc964"}},
|
||||
{user1, []string{"clmflokuq1xl341", "dlmflokuq1xl342"}},
|
||||
{user2, nil},
|
||||
{user3, []string{"5eto7nmys833164"}},
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
t.Run(s.record.Collection().Name+"_"+s.record.Id, func(t *testing.T) {
|
||||
result, err := app.FindAllExternalAuthsByRecord(s.record)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if len(result) != len(s.expected) {
|
||||
t.Fatalf("Expected total models %d, got %d", len(s.expected), len(result))
|
||||
}
|
||||
|
||||
for i, id := range s.expected {
|
||||
if result[i].Id != id {
|
||||
t.Errorf("[%d] Expected id %q, got %q", i, id, result[i].Id)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFindAllExternalAuthsByCollection(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
demo1, err := app.FindCollectionByNameOrId("demo1")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
superusers, err := app.FindCollectionByNameOrId(core.CollectionNameSuperusers)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
clients, err := app.FindCollectionByNameOrId("clients")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
users, err := app.FindCollectionByNameOrId("users")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
scenarios := []struct {
|
||||
collection *core.Collection
|
||||
expected []string
|
||||
}{
|
||||
{demo1, nil},
|
||||
{superusers, nil},
|
||||
{clients, []string{
|
||||
"f1z5b3843pzc964",
|
||||
}},
|
||||
{users, []string{
|
||||
"5eto7nmys833164",
|
||||
"clmflokuq1xl341",
|
||||
"dlmflokuq1xl342",
|
||||
}},
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
t.Run(s.collection.Name, func(t *testing.T) {
|
||||
result, err := app.FindAllExternalAuthsByCollection(s.collection)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if len(result) != len(s.expected) {
|
||||
t.Fatalf("Expected total models %d, got %d", len(s.expected), len(result))
|
||||
}
|
||||
|
||||
for i, id := range s.expected {
|
||||
if result[i].Id != id {
|
||||
t.Errorf("[%d] Expected id %q, got %q", i, id, result[i].Id)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFindFirstExternalAuthByExpr(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
scenarios := []struct {
|
||||
expr dbx.Expression
|
||||
expectedId string
|
||||
}{
|
||||
{dbx.HashExp{"collectionRef": "invalid"}, ""},
|
||||
{dbx.HashExp{"collectionRef": "_pb_users_auth_"}, "5eto7nmys833164"},
|
||||
{dbx.HashExp{"collectionRef": "_pb_users_auth_", "provider": "gitlab"}, "dlmflokuq1xl342"},
|
||||
}
|
||||
|
||||
for i, s := range scenarios {
|
||||
t.Run(fmt.Sprintf("%d_%v", i, s.expr.Build(app.ConcurrentDB().(*dbx.DB), dbx.Params{})), func(t *testing.T) {
|
||||
result, err := app.FindFirstExternalAuthByExpr(s.expr)
|
||||
|
||||
hasErr := err != nil
|
||||
expectErr := s.expectedId == ""
|
||||
if hasErr != expectErr {
|
||||
t.Fatalf("Expected hasErr %v, got %v", expectErr, hasErr)
|
||||
}
|
||||
|
||||
if hasErr {
|
||||
return
|
||||
}
|
||||
|
||||
if result.Id != s.expectedId {
|
||||
t.Errorf("Expected id %q, got %q", s.expectedId, result.Id)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
252
core/field.go
Normal file
252
core/field.go
Normal file
|
@ -0,0 +1,252 @@
|
|||
package core
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql/driver"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
validation "github.com/go-ozzo/ozzo-validation/v4"
|
||||
"github.com/pocketbase/pocketbase/core/validators"
|
||||
"github.com/pocketbase/pocketbase/tools/list"
|
||||
)
|
||||
|
||||
var fieldNameRegex = regexp.MustCompile(`^\w+$`)
|
||||
|
||||
const maxSafeJSONInt int64 = 1<<53 - 1
|
||||
|
||||
// Commonly used field names.
|
||||
const (
|
||||
FieldNameId = "id"
|
||||
FieldNameCollectionId = "collectionId"
|
||||
FieldNameCollectionName = "collectionName"
|
||||
FieldNameExpand = "expand"
|
||||
FieldNameEmail = "email"
|
||||
FieldNameEmailVisibility = "emailVisibility"
|
||||
FieldNameVerified = "verified"
|
||||
FieldNameTokenKey = "tokenKey"
|
||||
FieldNamePassword = "password"
|
||||
)
|
||||
|
||||
// SystemFields returns special internal field names that are usually readonly.
|
||||
var SystemDynamicFieldNames = []string{
|
||||
FieldNameCollectionId,
|
||||
FieldNameCollectionName,
|
||||
FieldNameExpand,
|
||||
}
|
||||
|
||||
// Common RecordInterceptor action names.
|
||||
const (
|
||||
InterceptorActionValidate = "validate"
|
||||
InterceptorActionDelete = "delete"
|
||||
InterceptorActionDeleteExecute = "deleteExecute"
|
||||
InterceptorActionAfterDelete = "afterDelete"
|
||||
InterceptorActionAfterDeleteError = "afterDeleteError"
|
||||
InterceptorActionCreate = "create"
|
||||
InterceptorActionCreateExecute = "createExecute"
|
||||
InterceptorActionAfterCreate = "afterCreate"
|
||||
InterceptorActionAfterCreateError = "afterCreateFailure"
|
||||
InterceptorActionUpdate = "update"
|
||||
InterceptorActionUpdateExecute = "updateExecute"
|
||||
InterceptorActionAfterUpdate = "afterUpdate"
|
||||
InterceptorActionAfterUpdateError = "afterUpdateError"
|
||||
)
|
||||
|
||||
// Common field errors.
|
||||
var (
|
||||
ErrUnknownField = validation.NewError("validation_unknown_field", "Unknown or invalid field.")
|
||||
ErrInvalidFieldValue = validation.NewError("validation_invalid_field_value", "Invalid field value.")
|
||||
ErrMustBeSystemAndHidden = validation.NewError("validation_must_be_system_and_hidden", `The field must be marked as "System" and "Hidden".`)
|
||||
ErrMustBeSystem = validation.NewError("validation_must_be_system", `The field must be marked as "System".`)
|
||||
)
|
||||
|
||||
// FieldFactoryFunc defines a simple function to construct a specific Field instance.
|
||||
type FieldFactoryFunc func() Field
|
||||
|
||||
// Fields holds all available collection fields.
|
||||
var Fields = map[string]FieldFactoryFunc{}
|
||||
|
||||
// Field defines a common interface that all Collection fields should implement.
|
||||
type Field interface {
|
||||
// note: the getters has an explicit "Get" prefix to avoid conflicts with their related field members
|
||||
|
||||
// GetId returns the field id.
|
||||
GetId() string
|
||||
|
||||
// SetId changes the field id.
|
||||
SetId(id string)
|
||||
|
||||
// GetName returns the field name.
|
||||
GetName() string
|
||||
|
||||
// SetName changes the field name.
|
||||
SetName(name string)
|
||||
|
||||
// GetSystem returns the field system flag state.
|
||||
GetSystem() bool
|
||||
|
||||
// SetSystem changes the field system flag state.
|
||||
SetSystem(system bool)
|
||||
|
||||
// GetHidden returns the field hidden flag state.
|
||||
GetHidden() bool
|
||||
|
||||
// SetHidden changes the field hidden flag state.
|
||||
SetHidden(hidden bool)
|
||||
|
||||
// Type returns the unique type of the field.
|
||||
Type() string
|
||||
|
||||
// ColumnType returns the DB column definition of the field.
|
||||
ColumnType(app App) string
|
||||
|
||||
// PrepareValue returns a properly formatted field value based on the provided raw one.
|
||||
//
|
||||
// This method is also called on record construction to initialize its default field value.
|
||||
PrepareValue(record *Record, raw any) (any, error)
|
||||
|
||||
// ValidateSettings validates the current field value associated with the provided record.
|
||||
ValidateValue(ctx context.Context, app App, record *Record) error
|
||||
|
||||
// ValidateSettings validates the current field settings.
|
||||
ValidateSettings(ctx context.Context, app App, collection *Collection) error
|
||||
}
|
||||
|
||||
// MaxBodySizeCalculator defines an optional field interface for
|
||||
// specifying the max size of a field value.
|
||||
type MaxBodySizeCalculator interface {
|
||||
// CalculateMaxBodySize returns the approximate max body size of a field value.
|
||||
CalculateMaxBodySize() int64
|
||||
}
|
||||
|
||||
type (
|
||||
SetterFunc func(record *Record, raw any)
|
||||
|
||||
// SetterFinder defines a field interface for registering custom field value setters.
|
||||
SetterFinder interface {
|
||||
// FindSetter returns a single field value setter function
|
||||
// by performing pattern-like field matching using the specified key.
|
||||
//
|
||||
// The key is usually just the field name but it could also
|
||||
// contains "modifier" characters based on which you can perform custom set operations
|
||||
// (ex. "users+" could be mapped to a function that will append new user to the existing field value).
|
||||
//
|
||||
// Return nil if you want to fallback to the default field value setter.
|
||||
FindSetter(key string) SetterFunc
|
||||
}
|
||||
)
|
||||
|
||||
type (
|
||||
GetterFunc func(record *Record) any
|
||||
|
||||
// GetterFinder defines a field interface for registering custom field value getters.
|
||||
GetterFinder interface {
|
||||
// FindGetter returns a single field value getter function
|
||||
// by performing pattern-like field matching using the specified key.
|
||||
//
|
||||
// The key is usually just the field name but it could also
|
||||
// contains "modifier" characters based on which you can perform custom get operations
|
||||
// (ex. "description:excerpt" could be mapped to a function that will return an excerpt of the current field value).
|
||||
//
|
||||
// Return nil if you want to fallback to the default field value setter.
|
||||
FindGetter(key string) GetterFunc
|
||||
}
|
||||
)
|
||||
|
||||
// DriverValuer defines a Field interface for exporting and formatting
|
||||
// a field value for the database.
|
||||
type DriverValuer interface {
|
||||
// DriverValue exports a single field value for persistence in the database.
|
||||
DriverValue(record *Record) (driver.Value, error)
|
||||
}
|
||||
|
||||
// MultiValuer defines a field interface that every multi-valued (eg. with MaxSelect) field has.
|
||||
type MultiValuer interface {
|
||||
// IsMultiple checks whether the field is configured to support multiple or single values.
|
||||
IsMultiple() bool
|
||||
}
|
||||
|
||||
// RecordInterceptor defines a field interface for reacting to various
|
||||
// Record related operations (create, delete, validate, etc.).
|
||||
type RecordInterceptor interface {
|
||||
// Interceptor is invoked when a specific record action occurs
|
||||
// allowing you to perform extra validations and normalization
|
||||
// (ex. uploading or deleting files).
|
||||
//
|
||||
// Note that users must call actionFunc() manually if they want to
|
||||
// execute the specific record action.
|
||||
Intercept(
|
||||
ctx context.Context,
|
||||
app App,
|
||||
record *Record,
|
||||
actionName string,
|
||||
actionFunc func() error,
|
||||
) error
|
||||
}
|
||||
|
||||
// DefaultFieldIdValidationRule performs base validation on a field id value.
|
||||
func DefaultFieldIdValidationRule(value any) error {
|
||||
v, ok := value.(string)
|
||||
if !ok {
|
||||
return validators.ErrUnsupportedValueType
|
||||
}
|
||||
|
||||
rules := []validation.Rule{
|
||||
validation.Required,
|
||||
validation.Length(1, 100),
|
||||
}
|
||||
|
||||
for _, r := range rules {
|
||||
if err := r.Validate(v); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// exclude special filter and system literals
|
||||
var excludeNames = append([]any{
|
||||
"null", "true", "false", "_rowid_",
|
||||
}, list.ToInterfaceSlice(SystemDynamicFieldNames)...)
|
||||
|
||||
// DefaultFieldIdValidationRule performs base validation on a field name value.
|
||||
func DefaultFieldNameValidationRule(value any) error {
|
||||
v, ok := value.(string)
|
||||
if !ok {
|
||||
return validators.ErrUnsupportedValueType
|
||||
}
|
||||
|
||||
rules := []validation.Rule{
|
||||
validation.Required,
|
||||
validation.Length(1, 100),
|
||||
validation.Match(fieldNameRegex),
|
||||
validation.NotIn(excludeNames...),
|
||||
validation.By(checkForVia),
|
||||
}
|
||||
|
||||
for _, r := range rules {
|
||||
if err := r.Validate(v); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func checkForVia(value any) error {
|
||||
v, _ := value.(string)
|
||||
if v == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
if strings.Contains(strings.ToLower(v), "_via_") {
|
||||
return validation.NewError("validation_found_via", `The value cannot contain "_via_".`)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func noopSetter(record *Record, raw any) {
|
||||
// do nothing
|
||||
}
|
215
core/field_autodate.go
Normal file
215
core/field_autodate.go
Normal file
|
@ -0,0 +1,215 @@
|
|||
package core
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
validation "github.com/go-ozzo/ozzo-validation/v4"
|
||||
"github.com/pocketbase/pocketbase/core/validators"
|
||||
"github.com/pocketbase/pocketbase/tools/types"
|
||||
)
|
||||
|
||||
func init() {
|
||||
Fields[FieldTypeAutodate] = func() Field {
|
||||
return &AutodateField{}
|
||||
}
|
||||
}
|
||||
|
||||
const FieldTypeAutodate = "autodate"
|
||||
|
||||
// used to keep track of the last set autodate value
|
||||
const autodateLastKnownPrefix = internalCustomFieldKeyPrefix + "_last_autodate_"
|
||||
|
||||
var (
|
||||
_ Field = (*AutodateField)(nil)
|
||||
_ SetterFinder = (*AutodateField)(nil)
|
||||
_ RecordInterceptor = (*AutodateField)(nil)
|
||||
)
|
||||
|
||||
// AutodateField defines an "autodate" type field, aka.
|
||||
// field which datetime value could be auto set on record create/update.
|
||||
//
|
||||
// This field is usually used for defining timestamp fields like "created" and "updated".
|
||||
//
|
||||
// Requires either both or at least one of the OnCreate or OnUpdate options to be set.
|
||||
type AutodateField struct {
|
||||
// Name (required) is the unique name of the field.
|
||||
Name string `form:"name" json:"name"`
|
||||
|
||||
// Id is the unique stable field identifier.
|
||||
//
|
||||
// It is automatically generated from the name when adding to a collection FieldsList.
|
||||
Id string `form:"id" json:"id"`
|
||||
|
||||
// System prevents the renaming and removal of the field.
|
||||
System bool `form:"system" json:"system"`
|
||||
|
||||
// Hidden hides the field from the API response.
|
||||
Hidden bool `form:"hidden" json:"hidden"`
|
||||
|
||||
// Presentable hints the Dashboard UI to use the underlying
|
||||
// field record value in the relation preview label.
|
||||
Presentable bool `form:"presentable" json:"presentable"`
|
||||
|
||||
// ---
|
||||
|
||||
// OnCreate auto sets the current datetime as field value on record create.
|
||||
OnCreate bool `form:"onCreate" json:"onCreate"`
|
||||
|
||||
// OnUpdate auto sets the current datetime as field value on record update.
|
||||
OnUpdate bool `form:"onUpdate" json:"onUpdate"`
|
||||
}
|
||||
|
||||
// Type implements [Field.Type] interface method.
|
||||
func (f *AutodateField) Type() string {
|
||||
return FieldTypeAutodate
|
||||
}
|
||||
|
||||
// GetId implements [Field.GetId] interface method.
|
||||
func (f *AutodateField) GetId() string {
|
||||
return f.Id
|
||||
}
|
||||
|
||||
// SetId implements [Field.SetId] interface method.
|
||||
func (f *AutodateField) SetId(id string) {
|
||||
f.Id = id
|
||||
}
|
||||
|
||||
// GetName implements [Field.GetName] interface method.
|
||||
func (f *AutodateField) GetName() string {
|
||||
return f.Name
|
||||
}
|
||||
|
||||
// SetName implements [Field.SetName] interface method.
|
||||
func (f *AutodateField) SetName(name string) {
|
||||
f.Name = name
|
||||
}
|
||||
|
||||
// GetSystem implements [Field.GetSystem] interface method.
|
||||
func (f *AutodateField) GetSystem() bool {
|
||||
return f.System
|
||||
}
|
||||
|
||||
// SetSystem implements [Field.SetSystem] interface method.
|
||||
func (f *AutodateField) SetSystem(system bool) {
|
||||
f.System = system
|
||||
}
|
||||
|
||||
// GetHidden implements [Field.GetHidden] interface method.
|
||||
func (f *AutodateField) GetHidden() bool {
|
||||
return f.Hidden
|
||||
}
|
||||
|
||||
// SetHidden implements [Field.SetHidden] interface method.
|
||||
func (f *AutodateField) SetHidden(hidden bool) {
|
||||
f.Hidden = hidden
|
||||
}
|
||||
|
||||
// ColumnType implements [Field.ColumnType] interface method.
|
||||
func (f *AutodateField) ColumnType(app App) string {
|
||||
return "TEXT DEFAULT '' NOT NULL" // note: sqlite doesn't allow adding new columns with non-constant defaults
|
||||
}
|
||||
|
||||
// PrepareValue implements [Field.PrepareValue] interface method.
|
||||
func (f *AutodateField) PrepareValue(record *Record, raw any) (any, error) {
|
||||
val, _ := types.ParseDateTime(raw)
|
||||
return val, nil
|
||||
}
|
||||
|
||||
// ValidateValue implements [Field.ValidateValue] interface method.
|
||||
func (f *AutodateField) ValidateValue(ctx context.Context, app App, record *Record) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidateSettings implements [Field.ValidateSettings] interface method.
|
||||
func (f *AutodateField) ValidateSettings(ctx context.Context, app App, collection *Collection) error {
|
||||
oldOnCreate := f.OnCreate
|
||||
oldOnUpdate := f.OnUpdate
|
||||
|
||||
oldCollection, _ := app.FindCollectionByNameOrId(collection.Id)
|
||||
if oldCollection != nil {
|
||||
oldField, ok := oldCollection.Fields.GetById(f.Id).(*AutodateField)
|
||||
if ok && oldField != nil {
|
||||
oldOnCreate = oldField.OnCreate
|
||||
oldOnUpdate = oldField.OnUpdate
|
||||
}
|
||||
}
|
||||
|
||||
return validation.ValidateStruct(f,
|
||||
validation.Field(&f.Id, validation.By(DefaultFieldIdValidationRule)),
|
||||
validation.Field(&f.Name, validation.By(DefaultFieldNameValidationRule)),
|
||||
validation.Field(
|
||||
&f.OnCreate,
|
||||
validation.When(f.System, validation.By(validators.Equal(oldOnCreate))),
|
||||
validation.Required.Error("either onCreate or onUpdate must be enabled").When(!f.OnUpdate),
|
||||
),
|
||||
validation.Field(
|
||||
&f.OnUpdate,
|
||||
validation.When(f.System, validation.By(validators.Equal(oldOnUpdate))),
|
||||
validation.Required.Error("either onCreate or onUpdate must be enabled").When(!f.OnCreate),
|
||||
),
|
||||
)
|
||||
}
|
||||
|
||||
// FindSetter implements the [SetterFinder] interface.
|
||||
func (f *AutodateField) FindSetter(key string) SetterFunc {
|
||||
switch key {
|
||||
case f.Name:
|
||||
// return noopSetter to disallow updating the value with record.Set()
|
||||
return noopSetter
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// Intercept implements the [RecordInterceptor] interface.
|
||||
func (f *AutodateField) Intercept(
|
||||
ctx context.Context,
|
||||
app App,
|
||||
record *Record,
|
||||
actionName string,
|
||||
actionFunc func() error,
|
||||
) error {
|
||||
switch actionName {
|
||||
case InterceptorActionCreateExecute:
|
||||
// ignore if a date different from the old one was manually set with SetRaw
|
||||
if f.OnCreate && record.GetDateTime(f.Name).Equal(f.getLastKnownValue(record)) {
|
||||
now := types.NowDateTime()
|
||||
record.SetRaw(f.Name, now)
|
||||
record.SetRaw(autodateLastKnownPrefix+f.Name, now) // eagerly set so that it can be renewed on resave after failure
|
||||
}
|
||||
|
||||
if err := actionFunc(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
record.SetRaw(autodateLastKnownPrefix+f.Name, record.GetRaw(f.Name))
|
||||
|
||||
return nil
|
||||
case InterceptorActionUpdateExecute:
|
||||
// ignore if a date different from the old one was manually set with SetRaw
|
||||
if f.OnUpdate && record.GetDateTime(f.Name).Equal(f.getLastKnownValue(record)) {
|
||||
now := types.NowDateTime()
|
||||
record.SetRaw(f.Name, now)
|
||||
record.SetRaw(autodateLastKnownPrefix+f.Name, now) // eagerly set so that it can be renewed on resave after failure
|
||||
}
|
||||
|
||||
if err := actionFunc(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
record.SetRaw(autodateLastKnownPrefix+f.Name, record.GetRaw(f.Name))
|
||||
|
||||
return nil
|
||||
default:
|
||||
return actionFunc()
|
||||
}
|
||||
}
|
||||
|
||||
func (f *AutodateField) getLastKnownValue(record *Record) types.DateTime {
|
||||
v := record.GetDateTime(autodateLastKnownPrefix + f.Name)
|
||||
if !v.IsZero() {
|
||||
return v
|
||||
}
|
||||
|
||||
return record.Original().GetDateTime(f.Name)
|
||||
}
|
441
core/field_autodate_test.go
Normal file
441
core/field_autodate_test.go
Normal file
|
@ -0,0 +1,441 @@
|
|||
package core_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tests"
|
||||
"github.com/pocketbase/pocketbase/tools/hook"
|
||||
"github.com/pocketbase/pocketbase/tools/types"
|
||||
)
|
||||
|
||||
func TestAutodateFieldBaseMethods(t *testing.T) {
|
||||
testFieldBaseMethods(t, core.FieldTypeAutodate)
|
||||
}
|
||||
|
||||
func TestAutodateFieldColumnType(t *testing.T) {
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
f := &core.AutodateField{}
|
||||
|
||||
expected := "TEXT DEFAULT '' NOT NULL"
|
||||
|
||||
if v := f.ColumnType(app); v != expected {
|
||||
t.Fatalf("Expected\n%q\ngot\n%q", expected, v)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAutodateFieldPrepareValue(t *testing.T) {
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
f := &core.AutodateField{}
|
||||
record := core.NewRecord(core.NewBaseCollection("test"))
|
||||
|
||||
scenarios := []struct {
|
||||
raw any
|
||||
expected string
|
||||
}{
|
||||
{"", ""},
|
||||
{"invalid", ""},
|
||||
{"2024-01-01 00:11:22.345Z", "2024-01-01 00:11:22.345Z"},
|
||||
{time.Date(2024, 1, 2, 3, 4, 5, 0, time.UTC), "2024-01-02 03:04:05.000Z"},
|
||||
}
|
||||
|
||||
for i, s := range scenarios {
|
||||
t.Run(fmt.Sprintf("%d_%#v", i, s.raw), func(t *testing.T) {
|
||||
v, err := f.PrepareValue(record, s.raw)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
vDate, ok := v.(types.DateTime)
|
||||
if !ok {
|
||||
t.Fatalf("Expected types.DateTime instance, got %T", v)
|
||||
}
|
||||
|
||||
if vDate.String() != s.expected {
|
||||
t.Fatalf("Expected %v, got %v", s.expected, v)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAutodateFieldValidateValue(t *testing.T) {
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
collection := core.NewBaseCollection("test_collection")
|
||||
|
||||
scenarios := []struct {
|
||||
name string
|
||||
field *core.AutodateField
|
||||
record func() *core.Record
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
"invalid raw value",
|
||||
&core.AutodateField{Name: "test"},
|
||||
func() *core.Record {
|
||||
record := core.NewRecord(collection)
|
||||
record.SetRaw("test", 123)
|
||||
return record
|
||||
},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"missing field value",
|
||||
&core.AutodateField{Name: "test"},
|
||||
func() *core.Record {
|
||||
record := core.NewRecord(collection)
|
||||
record.SetRaw("abc", true)
|
||||
return record
|
||||
},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"existing field value",
|
||||
&core.AutodateField{Name: "test"},
|
||||
func() *core.Record {
|
||||
record := core.NewRecord(collection)
|
||||
record.SetRaw("test", types.NowDateTime())
|
||||
return record
|
||||
},
|
||||
false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
t.Run(s.name, func(t *testing.T) {
|
||||
err := s.field.ValidateValue(context.Background(), app, s.record())
|
||||
|
||||
hasErr := err != nil
|
||||
if hasErr != s.expectError {
|
||||
t.Fatalf("Expected hasErr %v, got %v (%v)", s.expectError, hasErr, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAutodateFieldValidateSettings(t *testing.T) {
|
||||
testDefaultFieldIdValidation(t, core.FieldTypeAutodate)
|
||||
testDefaultFieldNameValidation(t, core.FieldTypeAutodate)
|
||||
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
superusers, err := app.FindCollectionByNameOrId(core.CollectionNameSuperusers)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
scenarios := []struct {
|
||||
name string
|
||||
field func() *core.AutodateField
|
||||
expectErrors []string
|
||||
}{
|
||||
{
|
||||
"empty onCreate and onUpdate",
|
||||
func() *core.AutodateField {
|
||||
return &core.AutodateField{
|
||||
Id: "test",
|
||||
Name: "test",
|
||||
}
|
||||
},
|
||||
[]string{"onCreate", "onUpdate"},
|
||||
},
|
||||
{
|
||||
"with onCreate",
|
||||
func() *core.AutodateField {
|
||||
return &core.AutodateField{
|
||||
Id: "test",
|
||||
Name: "test",
|
||||
OnCreate: true,
|
||||
}
|
||||
},
|
||||
[]string{},
|
||||
},
|
||||
{
|
||||
"with onUpdate",
|
||||
func() *core.AutodateField {
|
||||
return &core.AutodateField{
|
||||
Id: "test",
|
||||
Name: "test",
|
||||
OnUpdate: true,
|
||||
}
|
||||
},
|
||||
[]string{},
|
||||
},
|
||||
{
|
||||
"change of a system autodate field",
|
||||
func() *core.AutodateField {
|
||||
created := superusers.Fields.GetByName("created").(*core.AutodateField)
|
||||
created.OnCreate = !created.OnCreate
|
||||
created.OnUpdate = !created.OnUpdate
|
||||
return created
|
||||
},
|
||||
[]string{"onCreate", "onUpdate"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
t.Run(s.name, func(t *testing.T) {
|
||||
errs := s.field().ValidateSettings(context.Background(), app, superusers)
|
||||
|
||||
tests.TestValidationErrors(t, errs, s.expectErrors)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAutodateFieldFindSetter(t *testing.T) {
|
||||
field := &core.AutodateField{Name: "test"}
|
||||
|
||||
collection := core.NewBaseCollection("test_collection")
|
||||
collection.Fields.Add(field)
|
||||
|
||||
initialDate, err := types.ParseDateTime("2024-01-02 03:04:05.789Z")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
record := core.NewRecord(collection)
|
||||
record.SetRaw("test", initialDate)
|
||||
|
||||
t.Run("no matching setter", func(t *testing.T) {
|
||||
f := field.FindSetter("abc")
|
||||
if f != nil {
|
||||
t.Fatal("Expected nil setter")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("matching setter", func(t *testing.T) {
|
||||
f := field.FindSetter("test")
|
||||
if f == nil {
|
||||
t.Fatal("Expected non-nil setter")
|
||||
}
|
||||
|
||||
f(record, types.NowDateTime()) // should be ignored
|
||||
|
||||
if v := record.GetString("test"); v != "2024-01-02 03:04:05.789Z" {
|
||||
t.Fatalf("Expected no value change, got %q", v)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func cutMilliseconds(datetime string) string {
|
||||
if len(datetime) > 19 {
|
||||
return datetime[:19]
|
||||
}
|
||||
return datetime
|
||||
}
|
||||
|
||||
func TestAutodateFieldIntercept(t *testing.T) {
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
initialDate, err := types.ParseDateTime("2024-01-02 03:04:05.789Z")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
collection := core.NewBaseCollection("test_collection")
|
||||
|
||||
scenarios := []struct {
|
||||
name string
|
||||
actionName string
|
||||
field *core.AutodateField
|
||||
record func() *core.Record
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
"non-matching action",
|
||||
"test",
|
||||
&core.AutodateField{Name: "test", OnCreate: true, OnUpdate: true},
|
||||
func() *core.Record {
|
||||
return core.NewRecord(collection)
|
||||
},
|
||||
"",
|
||||
},
|
||||
{
|
||||
"create with zero value (disabled onCreate)",
|
||||
core.InterceptorActionCreateExecute,
|
||||
&core.AutodateField{Name: "test", OnCreate: false, OnUpdate: true},
|
||||
func() *core.Record {
|
||||
return core.NewRecord(collection)
|
||||
},
|
||||
"",
|
||||
},
|
||||
{
|
||||
"create with zero value",
|
||||
core.InterceptorActionCreateExecute,
|
||||
&core.AutodateField{Name: "test", OnCreate: true, OnUpdate: true},
|
||||
func() *core.Record {
|
||||
return core.NewRecord(collection)
|
||||
},
|
||||
"{NOW}",
|
||||
},
|
||||
{
|
||||
"create with non-zero value",
|
||||
core.InterceptorActionCreateExecute,
|
||||
&core.AutodateField{Name: "test", OnCreate: true, OnUpdate: true},
|
||||
func() *core.Record {
|
||||
record := core.NewRecord(collection)
|
||||
record.SetRaw("test", initialDate)
|
||||
return record
|
||||
},
|
||||
initialDate.String(),
|
||||
},
|
||||
{
|
||||
"update with zero value (disabled onUpdate)",
|
||||
core.InterceptorActionUpdateExecute,
|
||||
&core.AutodateField{Name: "test", OnCreate: true, OnUpdate: false},
|
||||
func() *core.Record {
|
||||
return core.NewRecord(collection)
|
||||
},
|
||||
"",
|
||||
},
|
||||
{
|
||||
"update with zero value",
|
||||
core.InterceptorActionUpdateExecute,
|
||||
&core.AutodateField{Name: "test", OnCreate: true, OnUpdate: true},
|
||||
func() *core.Record {
|
||||
return core.NewRecord(collection)
|
||||
},
|
||||
"{NOW}",
|
||||
},
|
||||
{
|
||||
"update with non-zero value",
|
||||
core.InterceptorActionUpdateExecute,
|
||||
&core.AutodateField{Name: "test", OnCreate: true, OnUpdate: true},
|
||||
func() *core.Record {
|
||||
record := core.NewRecord(collection)
|
||||
record.SetRaw("test", initialDate)
|
||||
return record
|
||||
},
|
||||
initialDate.String(),
|
||||
},
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
t.Run(s.name, func(t *testing.T) {
|
||||
actionCalls := 0
|
||||
record := s.record()
|
||||
|
||||
now := types.NowDateTime().String()
|
||||
err := s.field.Intercept(context.Background(), app, record, s.actionName, func() error {
|
||||
actionCalls++
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if actionCalls != 1 {
|
||||
t.Fatalf("Expected actionCalls %d, got %d", 1, actionCalls)
|
||||
}
|
||||
|
||||
expected := cutMilliseconds(strings.ReplaceAll(s.expected, "{NOW}", now))
|
||||
|
||||
v := cutMilliseconds(record.GetString(s.field.GetName()))
|
||||
if v != expected {
|
||||
t.Fatalf("Expected value %q, got %q", expected, v)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAutodateRecordResave(t *testing.T) {
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
collection, err := app.FindCollectionByNameOrId("demo2")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
record, err := app.FindRecordById(collection, "llvuca81nly1qls")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
lastUpdated := record.GetDateTime("updated")
|
||||
|
||||
// save with autogenerated date
|
||||
err = app.Save(record)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
newUpdated := record.GetDateTime("updated")
|
||||
if newUpdated.Equal(lastUpdated) {
|
||||
t.Fatalf("[0] Expected updated to change, got %v", newUpdated)
|
||||
}
|
||||
lastUpdated = newUpdated
|
||||
|
||||
// save with custom date
|
||||
manualUpdated := lastUpdated.Add(-1 * time.Minute)
|
||||
record.SetRaw("updated", manualUpdated)
|
||||
err = app.Save(record)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
newUpdated = record.GetDateTime("updated")
|
||||
if !newUpdated.Equal(manualUpdated) {
|
||||
t.Fatalf("[1] Expected updated to be the manual set date %v, got %v", manualUpdated, newUpdated)
|
||||
}
|
||||
lastUpdated = newUpdated
|
||||
|
||||
// save again with autogenerated date
|
||||
err = app.Save(record)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
newUpdated = record.GetDateTime("updated")
|
||||
if newUpdated.Equal(lastUpdated) {
|
||||
t.Fatalf("[2] Expected updated to change, got %v", newUpdated)
|
||||
}
|
||||
lastUpdated = newUpdated
|
||||
|
||||
// simulate save failure
|
||||
app.OnRecordUpdateExecute(collection.Id).Bind(&hook.Handler[*core.RecordEvent]{
|
||||
Id: "test_failure",
|
||||
Func: func(*core.RecordEvent) error {
|
||||
return errors.New("test")
|
||||
},
|
||||
Priority: 9999999999, // as latest as possible
|
||||
})
|
||||
|
||||
// save again with autogenerated date (should fail)
|
||||
err = app.Save(record)
|
||||
if err == nil {
|
||||
t.Fatal("Expected save failure")
|
||||
}
|
||||
|
||||
// updated should still be set even after save failure
|
||||
newUpdated = record.GetDateTime("updated")
|
||||
if newUpdated.Equal(lastUpdated) {
|
||||
t.Fatalf("[3] Expected updated to change, got %v", newUpdated)
|
||||
}
|
||||
lastUpdated = newUpdated
|
||||
|
||||
// cleanup the error and resave again
|
||||
app.OnRecordUpdateExecute(collection.Id).Unbind("test_failure")
|
||||
|
||||
err = app.Save(record)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
newUpdated = record.GetDateTime("updated")
|
||||
if newUpdated.Equal(lastUpdated) {
|
||||
t.Fatalf("[4] Expected updated to change, got %v", newUpdated)
|
||||
}
|
||||
}
|
124
core/field_bool.go
Normal file
124
core/field_bool.go
Normal file
|
@ -0,0 +1,124 @@
|
|||
package core
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
validation "github.com/go-ozzo/ozzo-validation/v4"
|
||||
"github.com/pocketbase/pocketbase/core/validators"
|
||||
"github.com/spf13/cast"
|
||||
)
|
||||
|
||||
func init() {
|
||||
Fields[FieldTypeBool] = func() Field {
|
||||
return &BoolField{}
|
||||
}
|
||||
}
|
||||
|
||||
const FieldTypeBool = "bool"
|
||||
|
||||
var _ Field = (*BoolField)(nil)
|
||||
|
||||
// BoolField defines "bool" type field to store a single true/false value.
|
||||
//
|
||||
// The respective zero record field value is false.
|
||||
type BoolField struct {
|
||||
// Name (required) is the unique name of the field.
|
||||
Name string `form:"name" json:"name"`
|
||||
|
||||
// Id is the unique stable field identifier.
|
||||
//
|
||||
// It is automatically generated from the name when adding to a collection FieldsList.
|
||||
Id string `form:"id" json:"id"`
|
||||
|
||||
// System prevents the renaming and removal of the field.
|
||||
System bool `form:"system" json:"system"`
|
||||
|
||||
// Hidden hides the field from the API response.
|
||||
Hidden bool `form:"hidden" json:"hidden"`
|
||||
|
||||
// Presentable hints the Dashboard UI to use the underlying
|
||||
// field record value in the relation preview label.
|
||||
Presentable bool `form:"presentable" json:"presentable"`
|
||||
|
||||
// ---
|
||||
|
||||
// Required will require the field value to be always "true".
|
||||
Required bool `form:"required" json:"required"`
|
||||
}
|
||||
|
||||
// Type implements [Field.Type] interface method.
|
||||
func (f *BoolField) Type() string {
|
||||
return FieldTypeBool
|
||||
}
|
||||
|
||||
// GetId implements [Field.GetId] interface method.
|
||||
func (f *BoolField) GetId() string {
|
||||
return f.Id
|
||||
}
|
||||
|
||||
// SetId implements [Field.SetId] interface method.
|
||||
func (f *BoolField) SetId(id string) {
|
||||
f.Id = id
|
||||
}
|
||||
|
||||
// GetName implements [Field.GetName] interface method.
|
||||
func (f *BoolField) GetName() string {
|
||||
return f.Name
|
||||
}
|
||||
|
||||
// SetName implements [Field.SetName] interface method.
|
||||
func (f *BoolField) SetName(name string) {
|
||||
f.Name = name
|
||||
}
|
||||
|
||||
// GetSystem implements [Field.GetSystem] interface method.
|
||||
func (f *BoolField) GetSystem() bool {
|
||||
return f.System
|
||||
}
|
||||
|
||||
// SetSystem implements [Field.SetSystem] interface method.
|
||||
func (f *BoolField) SetSystem(system bool) {
|
||||
f.System = system
|
||||
}
|
||||
|
||||
// GetHidden implements [Field.GetHidden] interface method.
|
||||
func (f *BoolField) GetHidden() bool {
|
||||
return f.Hidden
|
||||
}
|
||||
|
||||
// SetHidden implements [Field.SetHidden] interface method.
|
||||
func (f *BoolField) SetHidden(hidden bool) {
|
||||
f.Hidden = hidden
|
||||
}
|
||||
|
||||
// ColumnType implements [Field.ColumnType] interface method.
|
||||
func (f *BoolField) ColumnType(app App) string {
|
||||
return "BOOLEAN DEFAULT FALSE NOT NULL"
|
||||
}
|
||||
|
||||
// PrepareValue implements [Field.PrepareValue] interface method.
|
||||
func (f *BoolField) PrepareValue(record *Record, raw any) (any, error) {
|
||||
return cast.ToBool(raw), nil
|
||||
}
|
||||
|
||||
// ValidateValue implements [Field.ValidateValue] interface method.
|
||||
func (f *BoolField) ValidateValue(ctx context.Context, app App, record *Record) error {
|
||||
v, ok := record.GetRaw(f.Name).(bool)
|
||||
if !ok {
|
||||
return validators.ErrUnsupportedValueType
|
||||
}
|
||||
|
||||
if f.Required {
|
||||
return validation.Required.Validate(v)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidateSettings implements [Field.ValidateSettings] interface method.
|
||||
func (f *BoolField) ValidateSettings(ctx context.Context, app App, collection *Collection) error {
|
||||
return validation.ValidateStruct(f,
|
||||
validation.Field(&f.Id, validation.By(DefaultFieldIdValidationRule)),
|
||||
validation.Field(&f.Name, validation.By(DefaultFieldNameValidationRule)),
|
||||
)
|
||||
}
|
150
core/field_bool_test.go
Normal file
150
core/field_bool_test.go
Normal file
|
@ -0,0 +1,150 @@
|
|||
package core_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tests"
|
||||
)
|
||||
|
||||
func TestBoolFieldBaseMethods(t *testing.T) {
|
||||
testFieldBaseMethods(t, core.FieldTypeBool)
|
||||
}
|
||||
|
||||
func TestBoolFieldColumnType(t *testing.T) {
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
f := &core.BoolField{}
|
||||
|
||||
expected := "BOOLEAN DEFAULT FALSE NOT NULL"
|
||||
|
||||
if v := f.ColumnType(app); v != expected {
|
||||
t.Fatalf("Expected\n%q\ngot\n%q", expected, v)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBoolFieldPrepareValue(t *testing.T) {
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
f := &core.BoolField{}
|
||||
record := core.NewRecord(core.NewBaseCollection("test"))
|
||||
|
||||
scenarios := []struct {
|
||||
raw any
|
||||
expected bool
|
||||
}{
|
||||
{"", false},
|
||||
{"f", false},
|
||||
{"t", true},
|
||||
{1, true},
|
||||
{0, false},
|
||||
}
|
||||
|
||||
for i, s := range scenarios {
|
||||
t.Run(fmt.Sprintf("%d_%#v", i, s.raw), func(t *testing.T) {
|
||||
v, err := f.PrepareValue(record, s.raw)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if v != s.expected {
|
||||
t.Fatalf("Expected %v, got %v", s.expected, v)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBoolFieldValidateValue(t *testing.T) {
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
collection := core.NewBaseCollection("test_collection")
|
||||
|
||||
scenarios := []struct {
|
||||
name string
|
||||
field *core.BoolField
|
||||
record func() *core.Record
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
"invalid raw value",
|
||||
&core.BoolField{Name: "test"},
|
||||
func() *core.Record {
|
||||
record := core.NewRecord(collection)
|
||||
record.SetRaw("test", 123)
|
||||
return record
|
||||
},
|
||||
true,
|
||||
},
|
||||
{
|
||||
"missing field value (non-required)",
|
||||
&core.BoolField{Name: "test"},
|
||||
func() *core.Record {
|
||||
record := core.NewRecord(collection)
|
||||
record.SetRaw("abc", true)
|
||||
return record
|
||||
},
|
||||
true, // because of failed nil.(bool) cast
|
||||
},
|
||||
{
|
||||
"missing field value (required)",
|
||||
&core.BoolField{Name: "test", Required: true},
|
||||
func() *core.Record {
|
||||
record := core.NewRecord(collection)
|
||||
record.SetRaw("abc", true)
|
||||
return record
|
||||
},
|
||||
true,
|
||||
},
|
||||
{
|
||||
"false field value (non-required)",
|
||||
&core.BoolField{Name: "test"},
|
||||
func() *core.Record {
|
||||
record := core.NewRecord(collection)
|
||||
record.SetRaw("test", false)
|
||||
return record
|
||||
},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"false field value (required)",
|
||||
&core.BoolField{Name: "test", Required: true},
|
||||
func() *core.Record {
|
||||
record := core.NewRecord(collection)
|
||||
record.SetRaw("test", false)
|
||||
return record
|
||||
},
|
||||
true,
|
||||
},
|
||||
{
|
||||
"true field value (required)",
|
||||
&core.BoolField{Name: "test", Required: true},
|
||||
func() *core.Record {
|
||||
record := core.NewRecord(collection)
|
||||
record.SetRaw("test", true)
|
||||
return record
|
||||
},
|
||||
false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
t.Run(s.name, func(t *testing.T) {
|
||||
err := s.field.ValidateValue(context.Background(), app, s.record())
|
||||
|
||||
hasErr := err != nil
|
||||
if hasErr != s.expectError {
|
||||
t.Fatalf("Expected hasErr %v, got %v (%v)", s.expectError, hasErr, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBoolFieldValidateSettings(t *testing.T) {
|
||||
testDefaultFieldIdValidation(t, core.FieldTypeBool)
|
||||
testDefaultFieldNameValidation(t, core.FieldTypeBool)
|
||||
}
|
174
core/field_date.go
Normal file
174
core/field_date.go
Normal file
|
@ -0,0 +1,174 @@
|
|||
package core
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
validation "github.com/go-ozzo/ozzo-validation/v4"
|
||||
"github.com/pocketbase/pocketbase/core/validators"
|
||||
"github.com/pocketbase/pocketbase/tools/types"
|
||||
)
|
||||
|
||||
func init() {
|
||||
Fields[FieldTypeDate] = func() Field {
|
||||
return &DateField{}
|
||||
}
|
||||
}
|
||||
|
||||
const FieldTypeDate = "date"
|
||||
|
||||
var _ Field = (*DateField)(nil)
|
||||
|
||||
// DateField defines "date" type field to store a single [types.DateTime] value.
|
||||
//
|
||||
// The respective zero record field value is the zero [types.DateTime].
|
||||
type DateField struct {
|
||||
// Name (required) is the unique name of the field.
|
||||
Name string `form:"name" json:"name"`
|
||||
|
||||
// Id is the unique stable field identifier.
|
||||
//
|
||||
// It is automatically generated from the name when adding to a collection FieldsList.
|
||||
Id string `form:"id" json:"id"`
|
||||
|
||||
// System prevents the renaming and removal of the field.
|
||||
System bool `form:"system" json:"system"`
|
||||
|
||||
// Hidden hides the field from the API response.
|
||||
Hidden bool `form:"hidden" json:"hidden"`
|
||||
|
||||
// Presentable hints the Dashboard UI to use the underlying
|
||||
// field record value in the relation preview label.
|
||||
Presentable bool `form:"presentable" json:"presentable"`
|
||||
|
||||
// ---
|
||||
|
||||
// Min specifies the min allowed field value.
|
||||
//
|
||||
// Leave it empty to skip the validator.
|
||||
Min types.DateTime `form:"min" json:"min"`
|
||||
|
||||
// Max specifies the max allowed field value.
|
||||
//
|
||||
// Leave it empty to skip the validator.
|
||||
Max types.DateTime `form:"max" json:"max"`
|
||||
|
||||
// Required will require the field value to be non-zero [types.DateTime].
|
||||
Required bool `form:"required" json:"required"`
|
||||
}
|
||||
|
||||
// Type implements [Field.Type] interface method.
|
||||
func (f *DateField) Type() string {
|
||||
return FieldTypeDate
|
||||
}
|
||||
|
||||
// GetId implements [Field.GetId] interface method.
|
||||
func (f *DateField) GetId() string {
|
||||
return f.Id
|
||||
}
|
||||
|
||||
// SetId implements [Field.SetId] interface method.
|
||||
func (f *DateField) SetId(id string) {
|
||||
f.Id = id
|
||||
}
|
||||
|
||||
// GetName implements [Field.GetName] interface method.
|
||||
func (f *DateField) GetName() string {
|
||||
return f.Name
|
||||
}
|
||||
|
||||
// SetName implements [Field.SetName] interface method.
|
||||
func (f *DateField) SetName(name string) {
|
||||
f.Name = name
|
||||
}
|
||||
|
||||
// GetSystem implements [Field.GetSystem] interface method.
|
||||
func (f *DateField) GetSystem() bool {
|
||||
return f.System
|
||||
}
|
||||
|
||||
// SetSystem implements [Field.SetSystem] interface method.
|
||||
func (f *DateField) SetSystem(system bool) {
|
||||
f.System = system
|
||||
}
|
||||
|
||||
// GetHidden implements [Field.GetHidden] interface method.
|
||||
func (f *DateField) GetHidden() bool {
|
||||
return f.Hidden
|
||||
}
|
||||
|
||||
// SetHidden implements [Field.SetHidden] interface method.
|
||||
func (f *DateField) SetHidden(hidden bool) {
|
||||
f.Hidden = hidden
|
||||
}
|
||||
|
||||
// ColumnType implements [Field.ColumnType] interface method.
|
||||
func (f *DateField) ColumnType(app App) string {
|
||||
return "TEXT DEFAULT '' NOT NULL"
|
||||
}
|
||||
|
||||
// PrepareValue implements [Field.PrepareValue] interface method.
|
||||
func (f *DateField) PrepareValue(record *Record, raw any) (any, error) {
|
||||
// ignore scan errors since the format may change between versions
|
||||
// and to allow running db adjusting migrations
|
||||
val, _ := types.ParseDateTime(raw)
|
||||
return val, nil
|
||||
}
|
||||
|
||||
// ValidateValue implements [Field.ValidateValue] interface method.
|
||||
func (f *DateField) ValidateValue(ctx context.Context, app App, record *Record) error {
|
||||
val, ok := record.GetRaw(f.Name).(types.DateTime)
|
||||
if !ok {
|
||||
return validators.ErrUnsupportedValueType
|
||||
}
|
||||
|
||||
if val.IsZero() {
|
||||
if f.Required {
|
||||
return validation.ErrRequired
|
||||
}
|
||||
return nil // nothing to check
|
||||
}
|
||||
|
||||
if !f.Min.IsZero() {
|
||||
if err := validation.Min(f.Min.Time()).Validate(val.Time()); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if !f.Max.IsZero() {
|
||||
if err := validation.Max(f.Max.Time()).Validate(val.Time()); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidateSettings implements [Field.ValidateSettings] interface method.
|
||||
func (f *DateField) ValidateSettings(ctx context.Context, app App, collection *Collection) error {
|
||||
return validation.ValidateStruct(f,
|
||||
validation.Field(&f.Id, validation.By(DefaultFieldIdValidationRule)),
|
||||
validation.Field(&f.Name, validation.By(DefaultFieldNameValidationRule)),
|
||||
validation.Field(&f.Max, validation.By(f.checkRange(f.Min, f.Max))),
|
||||
)
|
||||
}
|
||||
|
||||
func (f *DateField) checkRange(min types.DateTime, max types.DateTime) validation.RuleFunc {
|
||||
return func(value any) error {
|
||||
v, _ := value.(types.DateTime)
|
||||
if v.IsZero() {
|
||||
return nil // nothing to check
|
||||
}
|
||||
|
||||
dr := validation.Date(types.DefaultDateLayout)
|
||||
|
||||
if !min.IsZero() {
|
||||
dr.Min(min.Time())
|
||||
}
|
||||
|
||||
if !max.IsZero() {
|
||||
dr.Max(max.Time())
|
||||
}
|
||||
|
||||
return dr.Validate(v.String())
|
||||
}
|
||||
}
|
229
core/field_date_test.go
Normal file
229
core/field_date_test.go
Normal file
|
@ -0,0 +1,229 @@
|
|||
package core_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tests"
|
||||
"github.com/pocketbase/pocketbase/tools/types"
|
||||
)
|
||||
|
||||
func TestDateFieldBaseMethods(t *testing.T) {
|
||||
testFieldBaseMethods(t, core.FieldTypeDate)
|
||||
}
|
||||
|
||||
func TestDateFieldColumnType(t *testing.T) {
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
f := &core.DateField{}
|
||||
|
||||
expected := "TEXT DEFAULT '' NOT NULL"
|
||||
|
||||
if v := f.ColumnType(app); v != expected {
|
||||
t.Fatalf("Expected\n%q\ngot\n%q", expected, v)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDateFieldPrepareValue(t *testing.T) {
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
f := &core.DateField{}
|
||||
record := core.NewRecord(core.NewBaseCollection("test"))
|
||||
|
||||
scenarios := []struct {
|
||||
raw any
|
||||
expected string
|
||||
}{
|
||||
{"", ""},
|
||||
{"invalid", ""},
|
||||
{"2024-01-01 00:11:22.345Z", "2024-01-01 00:11:22.345Z"},
|
||||
{time.Date(2024, 1, 2, 3, 4, 5, 0, time.UTC), "2024-01-02 03:04:05.000Z"},
|
||||
}
|
||||
|
||||
for i, s := range scenarios {
|
||||
t.Run(fmt.Sprintf("%d_%#v", i, s.raw), func(t *testing.T) {
|
||||
v, err := f.PrepareValue(record, s.raw)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
vDate, ok := v.(types.DateTime)
|
||||
if !ok {
|
||||
t.Fatalf("Expected types.DateTime instance, got %T", v)
|
||||
}
|
||||
|
||||
if vDate.String() != s.expected {
|
||||
t.Fatalf("Expected %v, got %v", s.expected, v)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDateFieldValidateValue(t *testing.T) {
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
collection := core.NewBaseCollection("test_collection")
|
||||
|
||||
scenarios := []struct {
|
||||
name string
|
||||
field *core.DateField
|
||||
record func() *core.Record
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
"invalid raw value",
|
||||
&core.DateField{Name: "test"},
|
||||
func() *core.Record {
|
||||
record := core.NewRecord(collection)
|
||||
record.SetRaw("test", 123)
|
||||
return record
|
||||
},
|
||||
true,
|
||||
},
|
||||
{
|
||||
"zero field value (not required)",
|
||||
&core.DateField{Name: "test"},
|
||||
func() *core.Record {
|
||||
record := core.NewRecord(collection)
|
||||
record.SetRaw("test", types.DateTime{})
|
||||
return record
|
||||
},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"zero field value (required)",
|
||||
&core.DateField{Name: "test", Required: true},
|
||||
func() *core.Record {
|
||||
record := core.NewRecord(collection)
|
||||
record.SetRaw("test", types.DateTime{})
|
||||
return record
|
||||
},
|
||||
true,
|
||||
},
|
||||
{
|
||||
"non-zero field value (required)",
|
||||
&core.DateField{Name: "test", Required: true},
|
||||
func() *core.Record {
|
||||
record := core.NewRecord(collection)
|
||||
record.SetRaw("test", types.NowDateTime())
|
||||
return record
|
||||
},
|
||||
false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
t.Run(s.name, func(t *testing.T) {
|
||||
err := s.field.ValidateValue(context.Background(), app, s.record())
|
||||
|
||||
hasErr := err != nil
|
||||
if hasErr != s.expectError {
|
||||
t.Fatalf("Expected hasErr %v, got %v (%v)", s.expectError, hasErr, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDateFieldValidateSettings(t *testing.T) {
|
||||
testDefaultFieldIdValidation(t, core.FieldTypeDate)
|
||||
testDefaultFieldNameValidation(t, core.FieldTypeDate)
|
||||
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
collection := core.NewBaseCollection("test_collection")
|
||||
|
||||
scenarios := []struct {
|
||||
name string
|
||||
field func() *core.DateField
|
||||
expectErrors []string
|
||||
}{
|
||||
{
|
||||
"zero Min/Max",
|
||||
func() *core.DateField {
|
||||
return &core.DateField{
|
||||
Id: "test",
|
||||
Name: "test",
|
||||
}
|
||||
},
|
||||
[]string{},
|
||||
},
|
||||
{
|
||||
"non-empty Min with empty Max",
|
||||
func() *core.DateField {
|
||||
return &core.DateField{
|
||||
Id: "test",
|
||||
Name: "test",
|
||||
Min: types.NowDateTime(),
|
||||
}
|
||||
},
|
||||
[]string{},
|
||||
},
|
||||
{
|
||||
"empty Min non-empty Max",
|
||||
func() *core.DateField {
|
||||
return &core.DateField{
|
||||
Id: "test",
|
||||
Name: "test",
|
||||
Max: types.NowDateTime(),
|
||||
}
|
||||
},
|
||||
[]string{},
|
||||
},
|
||||
{
|
||||
"Min = Max",
|
||||
func() *core.DateField {
|
||||
date := types.NowDateTime()
|
||||
return &core.DateField{
|
||||
Id: "test",
|
||||
Name: "test",
|
||||
Min: date,
|
||||
Max: date,
|
||||
}
|
||||
},
|
||||
[]string{},
|
||||
},
|
||||
{
|
||||
"Min > Max",
|
||||
func() *core.DateField {
|
||||
min := types.NowDateTime()
|
||||
max := min.Add(-5 * time.Second)
|
||||
return &core.DateField{
|
||||
Id: "test",
|
||||
Name: "test",
|
||||
Min: min,
|
||||
Max: max,
|
||||
}
|
||||
},
|
||||
[]string{},
|
||||
},
|
||||
{
|
||||
"Min < Max",
|
||||
func() *core.DateField {
|
||||
max := types.NowDateTime()
|
||||
min := max.Add(-5 * time.Second)
|
||||
return &core.DateField{
|
||||
Id: "test",
|
||||
Name: "test",
|
||||
Min: min,
|
||||
Max: max,
|
||||
}
|
||||
},
|
||||
[]string{},
|
||||
},
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
t.Run(s.name, func(t *testing.T) {
|
||||
errs := s.field().ValidateSettings(context.Background(), app, collection)
|
||||
|
||||
tests.TestValidationErrors(t, errs, s.expectErrors)
|
||||
})
|
||||
}
|
||||
}
|
162
core/field_editor.go
Normal file
162
core/field_editor.go
Normal file
|
@ -0,0 +1,162 @@
|
|||
package core
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
validation "github.com/go-ozzo/ozzo-validation/v4"
|
||||
"github.com/pocketbase/pocketbase/core/validators"
|
||||
"github.com/spf13/cast"
|
||||
)
|
||||
|
||||
func init() {
|
||||
Fields[FieldTypeEditor] = func() Field {
|
||||
return &EditorField{}
|
||||
}
|
||||
}
|
||||
|
||||
const FieldTypeEditor = "editor"
|
||||
|
||||
const DefaultEditorFieldMaxSize int64 = 5 << 20
|
||||
|
||||
var (
|
||||
_ Field = (*EditorField)(nil)
|
||||
_ MaxBodySizeCalculator = (*EditorField)(nil)
|
||||
)
|
||||
|
||||
// EditorField defines "editor" type field to store HTML formatted text.
|
||||
//
|
||||
// The respective zero record field value is empty string.
|
||||
type EditorField struct {
|
||||
// Name (required) is the unique name of the field.
|
||||
Name string `form:"name" json:"name"`
|
||||
|
||||
// Id is the unique stable field identifier.
|
||||
//
|
||||
// It is automatically generated from the name when adding to a collection FieldsList.
|
||||
Id string `form:"id" json:"id"`
|
||||
|
||||
// System prevents the renaming and removal of the field.
|
||||
System bool `form:"system" json:"system"`
|
||||
|
||||
// Hidden hides the field from the API response.
|
||||
Hidden bool `form:"hidden" json:"hidden"`
|
||||
|
||||
// Presentable hints the Dashboard UI to use the underlying
|
||||
// field record value in the relation preview label.
|
||||
Presentable bool `form:"presentable" json:"presentable"`
|
||||
|
||||
// ---
|
||||
|
||||
// MaxSize specifies the maximum size of the allowed field value (in bytes and up to 2^53-1).
|
||||
//
|
||||
// If zero, a default limit of ~5MB is applied.
|
||||
MaxSize int64 `form:"maxSize" json:"maxSize"`
|
||||
|
||||
// ConvertURLs is usually used to instruct the editor whether to
|
||||
// apply url conversion (eg. stripping the domain name in case the
|
||||
// urls are using the same domain as the one where the editor is loaded).
|
||||
//
|
||||
// (see also https://www.tiny.cloud/docs/tinymce/6/url-handling/#convert_urls)
|
||||
ConvertURLs bool `form:"convertURLs" json:"convertURLs"`
|
||||
|
||||
// Required will require the field value to be non-empty string.
|
||||
Required bool `form:"required" json:"required"`
|
||||
}
|
||||
|
||||
// Type implements [Field.Type] interface method.
|
||||
func (f *EditorField) Type() string {
|
||||
return FieldTypeEditor
|
||||
}
|
||||
|
||||
// GetId implements [Field.GetId] interface method.
|
||||
func (f *EditorField) GetId() string {
|
||||
return f.Id
|
||||
}
|
||||
|
||||
// SetId implements [Field.SetId] interface method.
|
||||
func (f *EditorField) SetId(id string) {
|
||||
f.Id = id
|
||||
}
|
||||
|
||||
// GetName implements [Field.GetName] interface method.
|
||||
func (f *EditorField) GetName() string {
|
||||
return f.Name
|
||||
}
|
||||
|
||||
// SetName implements [Field.SetName] interface method.
|
||||
func (f *EditorField) SetName(name string) {
|
||||
f.Name = name
|
||||
}
|
||||
|
||||
// GetSystem implements [Field.GetSystem] interface method.
|
||||
func (f *EditorField) GetSystem() bool {
|
||||
return f.System
|
||||
}
|
||||
|
||||
// SetSystem implements [Field.SetSystem] interface method.
|
||||
func (f *EditorField) SetSystem(system bool) {
|
||||
f.System = system
|
||||
}
|
||||
|
||||
// GetHidden implements [Field.GetHidden] interface method.
|
||||
func (f *EditorField) GetHidden() bool {
|
||||
return f.Hidden
|
||||
}
|
||||
|
||||
// SetHidden implements [Field.SetHidden] interface method.
|
||||
func (f *EditorField) SetHidden(hidden bool) {
|
||||
f.Hidden = hidden
|
||||
}
|
||||
|
||||
// ColumnType implements [Field.ColumnType] interface method.
|
||||
func (f *EditorField) ColumnType(app App) string {
|
||||
return "TEXT DEFAULT '' NOT NULL"
|
||||
}
|
||||
|
||||
// PrepareValue implements [Field.PrepareValue] interface method.
|
||||
func (f *EditorField) PrepareValue(record *Record, raw any) (any, error) {
|
||||
return cast.ToString(raw), nil
|
||||
}
|
||||
|
||||
// ValidateValue implements [Field.ValidateValue] interface method.
|
||||
func (f *EditorField) ValidateValue(ctx context.Context, app App, record *Record) error {
|
||||
val, ok := record.GetRaw(f.Name).(string)
|
||||
if !ok {
|
||||
return validators.ErrUnsupportedValueType
|
||||
}
|
||||
|
||||
if f.Required {
|
||||
if err := validation.Required.Validate(val); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
maxSize := f.CalculateMaxBodySize()
|
||||
|
||||
if int64(len(val)) > maxSize {
|
||||
return validation.NewError(
|
||||
"validation_content_size_limit",
|
||||
"The maximum allowed content size is {{.maxSize}} bytes",
|
||||
).SetParams(map[string]any{"maxSize": maxSize})
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidateSettings implements [Field.ValidateSettings] interface method.
|
||||
func (f *EditorField) ValidateSettings(ctx context.Context, app App, collection *Collection) error {
|
||||
return validation.ValidateStruct(f,
|
||||
validation.Field(&f.Id, validation.By(DefaultFieldIdValidationRule)),
|
||||
validation.Field(&f.Name, validation.By(DefaultFieldNameValidationRule)),
|
||||
validation.Field(&f.MaxSize, validation.Min(0), validation.Max(maxSafeJSONInt)),
|
||||
)
|
||||
}
|
||||
|
||||
// CalculateMaxBodySize implements the [MaxBodySizeCalculator] interface.
|
||||
func (f *EditorField) CalculateMaxBodySize() int64 {
|
||||
if f.MaxSize <= 0 {
|
||||
return DefaultEditorFieldMaxSize
|
||||
}
|
||||
|
||||
return f.MaxSize
|
||||
}
|
252
core/field_editor_test.go
Normal file
252
core/field_editor_test.go
Normal file
|
@ -0,0 +1,252 @@
|
|||
package core_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tests"
|
||||
)
|
||||
|
||||
func TestEditorFieldBaseMethods(t *testing.T) {
|
||||
testFieldBaseMethods(t, core.FieldTypeEditor)
|
||||
}
|
||||
|
||||
func TestEditorFieldColumnType(t *testing.T) {
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
f := &core.EditorField{}
|
||||
|
||||
expected := "TEXT DEFAULT '' NOT NULL"
|
||||
|
||||
if v := f.ColumnType(app); v != expected {
|
||||
t.Fatalf("Expected\n%q\ngot\n%q", expected, v)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEditorFieldPrepareValue(t *testing.T) {
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
f := &core.EditorField{}
|
||||
record := core.NewRecord(core.NewBaseCollection("test"))
|
||||
|
||||
scenarios := []struct {
|
||||
raw any
|
||||
expected string
|
||||
}{
|
||||
{"", ""},
|
||||
{"test", "test"},
|
||||
{false, "false"},
|
||||
{true, "true"},
|
||||
{123.456, "123.456"},
|
||||
}
|
||||
|
||||
for i, s := range scenarios {
|
||||
t.Run(fmt.Sprintf("%d_%#v", i, s.raw), func(t *testing.T) {
|
||||
v, err := f.PrepareValue(record, s.raw)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
vStr, ok := v.(string)
|
||||
if !ok {
|
||||
t.Fatalf("Expected string instance, got %T", v)
|
||||
}
|
||||
|
||||
if vStr != s.expected {
|
||||
t.Fatalf("Expected %q, got %q", s.expected, v)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEditorFieldValidateValue(t *testing.T) {
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
collection := core.NewBaseCollection("test_collection")
|
||||
|
||||
scenarios := []struct {
|
||||
name string
|
||||
field *core.EditorField
|
||||
record func() *core.Record
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
"invalid raw value",
|
||||
&core.EditorField{Name: "test"},
|
||||
func() *core.Record {
|
||||
record := core.NewRecord(collection)
|
||||
record.SetRaw("test", 123)
|
||||
return record
|
||||
},
|
||||
true,
|
||||
},
|
||||
{
|
||||
"zero field value (not required)",
|
||||
&core.EditorField{Name: "test"},
|
||||
func() *core.Record {
|
||||
record := core.NewRecord(collection)
|
||||
record.SetRaw("test", "")
|
||||
return record
|
||||
},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"zero field value (required)",
|
||||
&core.EditorField{Name: "test", Required: true},
|
||||
func() *core.Record {
|
||||
record := core.NewRecord(collection)
|
||||
record.SetRaw("test", "")
|
||||
return record
|
||||
},
|
||||
true,
|
||||
},
|
||||
{
|
||||
"non-zero field value (required)",
|
||||
&core.EditorField{Name: "test", Required: true},
|
||||
func() *core.Record {
|
||||
record := core.NewRecord(collection)
|
||||
record.SetRaw("test", "abc")
|
||||
return record
|
||||
},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"> default MaxSize",
|
||||
&core.EditorField{Name: "test", Required: true},
|
||||
func() *core.Record {
|
||||
record := core.NewRecord(collection)
|
||||
record.SetRaw("test", strings.Repeat("a", 1+(5<<20)))
|
||||
return record
|
||||
},
|
||||
true,
|
||||
},
|
||||
{
|
||||
"> MaxSize",
|
||||
&core.EditorField{Name: "test", Required: true, MaxSize: 5},
|
||||
func() *core.Record {
|
||||
record := core.NewRecord(collection)
|
||||
record.SetRaw("test", "abcdef")
|
||||
return record
|
||||
},
|
||||
true,
|
||||
},
|
||||
{
|
||||
"<= MaxSize",
|
||||
&core.EditorField{Name: "test", Required: true, MaxSize: 5},
|
||||
func() *core.Record {
|
||||
record := core.NewRecord(collection)
|
||||
record.SetRaw("test", "abcde")
|
||||
return record
|
||||
},
|
||||
false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
t.Run(s.name, func(t *testing.T) {
|
||||
err := s.field.ValidateValue(context.Background(), app, s.record())
|
||||
|
||||
hasErr := err != nil
|
||||
if hasErr != s.expectError {
|
||||
t.Fatalf("Expected hasErr %v, got %v (%v)", s.expectError, hasErr, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEditorFieldValidateSettings(t *testing.T) {
|
||||
testDefaultFieldIdValidation(t, core.FieldTypeEditor)
|
||||
testDefaultFieldNameValidation(t, core.FieldTypeEditor)
|
||||
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
collection := core.NewBaseCollection("test_collection")
|
||||
|
||||
scenarios := []struct {
|
||||
name string
|
||||
field func() *core.EditorField
|
||||
expectErrors []string
|
||||
}{
|
||||
{
|
||||
"< 0 MaxSize",
|
||||
func() *core.EditorField {
|
||||
return &core.EditorField{
|
||||
Id: "test",
|
||||
Name: "test",
|
||||
MaxSize: -1,
|
||||
}
|
||||
},
|
||||
[]string{"maxSize"},
|
||||
},
|
||||
{
|
||||
"= 0 MaxSize",
|
||||
func() *core.EditorField {
|
||||
return &core.EditorField{
|
||||
Id: "test",
|
||||
Name: "test",
|
||||
}
|
||||
},
|
||||
[]string{},
|
||||
},
|
||||
{
|
||||
"> 0 MaxSize",
|
||||
func() *core.EditorField {
|
||||
return &core.EditorField{
|
||||
Id: "test",
|
||||
Name: "test",
|
||||
MaxSize: 1,
|
||||
}
|
||||
},
|
||||
[]string{},
|
||||
},
|
||||
{
|
||||
"MaxSize > safe json int",
|
||||
func() *core.EditorField {
|
||||
return &core.EditorField{
|
||||
Id: "test",
|
||||
Name: "test",
|
||||
MaxSize: 1 << 53,
|
||||
}
|
||||
},
|
||||
[]string{"maxSize"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
t.Run(s.name, func(t *testing.T) {
|
||||
errs := s.field().ValidateSettings(context.Background(), app, collection)
|
||||
|
||||
tests.TestValidationErrors(t, errs, s.expectErrors)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEditorFieldCalculateMaxBodySize(t *testing.T) {
|
||||
testApp, _ := tests.NewTestApp()
|
||||
defer testApp.Cleanup()
|
||||
|
||||
scenarios := []struct {
|
||||
field *core.EditorField
|
||||
expected int64
|
||||
}{
|
||||
{&core.EditorField{}, core.DefaultEditorFieldMaxSize},
|
||||
{&core.EditorField{MaxSize: 10}, 10},
|
||||
}
|
||||
|
||||
for i, s := range scenarios {
|
||||
t.Run(fmt.Sprintf("%d_%d", i, s.field.MaxSize), func(t *testing.T) {
|
||||
result := s.field.CalculateMaxBodySize()
|
||||
|
||||
if result != s.expected {
|
||||
t.Fatalf("Expected %d, got %d", s.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
167
core/field_email.go
Normal file
167
core/field_email.go
Normal file
|
@ -0,0 +1,167 @@
|
|||
package core
|
||||
|
||||
import (
|
||||
"context"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
validation "github.com/go-ozzo/ozzo-validation/v4"
|
||||
"github.com/go-ozzo/ozzo-validation/v4/is"
|
||||
"github.com/pocketbase/pocketbase/core/validators"
|
||||
"github.com/spf13/cast"
|
||||
)
|
||||
|
||||
func init() {
|
||||
Fields[FieldTypeEmail] = func() Field {
|
||||
return &EmailField{}
|
||||
}
|
||||
}
|
||||
|
||||
const FieldTypeEmail = "email"
|
||||
|
||||
var _ Field = (*EmailField)(nil)
|
||||
|
||||
// EmailField defines "email" type field for storing a single email string address.
|
||||
//
|
||||
// The respective zero record field value is empty string.
|
||||
type EmailField struct {
|
||||
// Name (required) is the unique name of the field.
|
||||
Name string `form:"name" json:"name"`
|
||||
|
||||
// Id is the unique stable field identifier.
|
||||
//
|
||||
// It is automatically generated from the name when adding to a collection FieldsList.
|
||||
Id string `form:"id" json:"id"`
|
||||
|
||||
// System prevents the renaming and removal of the field.
|
||||
System bool `form:"system" json:"system"`
|
||||
|
||||
// Hidden hides the field from the API response.
|
||||
Hidden bool `form:"hidden" json:"hidden"`
|
||||
|
||||
// Presentable hints the Dashboard UI to use the underlying
|
||||
// field record value in the relation preview label.
|
||||
Presentable bool `form:"presentable" json:"presentable"`
|
||||
|
||||
// ---
|
||||
|
||||
// ExceptDomains will require the email domain to NOT be included in the listed ones.
|
||||
//
|
||||
// This validator can be set only if OnlyDomains is empty.
|
||||
ExceptDomains []string `form:"exceptDomains" json:"exceptDomains"`
|
||||
|
||||
// OnlyDomains will require the email domain to be included in the listed ones.
|
||||
//
|
||||
// This validator can be set only if ExceptDomains is empty.
|
||||
OnlyDomains []string `form:"onlyDomains" json:"onlyDomains"`
|
||||
|
||||
// Required will require the field value to be non-empty email string.
|
||||
Required bool `form:"required" json:"required"`
|
||||
}
|
||||
|
||||
// Type implements [Field.Type] interface method.
|
||||
func (f *EmailField) Type() string {
|
||||
return FieldTypeEmail
|
||||
}
|
||||
|
||||
// GetId implements [Field.GetId] interface method.
|
||||
func (f *EmailField) GetId() string {
|
||||
return f.Id
|
||||
}
|
||||
|
||||
// SetId implements [Field.SetId] interface method.
|
||||
func (f *EmailField) SetId(id string) {
|
||||
f.Id = id
|
||||
}
|
||||
|
||||
// GetName implements [Field.GetName] interface method.
|
||||
func (f *EmailField) GetName() string {
|
||||
return f.Name
|
||||
}
|
||||
|
||||
// SetName implements [Field.SetName] interface method.
|
||||
func (f *EmailField) SetName(name string) {
|
||||
f.Name = name
|
||||
}
|
||||
|
||||
// GetSystem implements [Field.GetSystem] interface method.
|
||||
func (f *EmailField) GetSystem() bool {
|
||||
return f.System
|
||||
}
|
||||
|
||||
// SetSystem implements [Field.SetSystem] interface method.
|
||||
func (f *EmailField) SetSystem(system bool) {
|
||||
f.System = system
|
||||
}
|
||||
|
||||
// GetHidden implements [Field.GetHidden] interface method.
|
||||
func (f *EmailField) GetHidden() bool {
|
||||
return f.Hidden
|
||||
}
|
||||
|
||||
// SetHidden implements [Field.SetHidden] interface method.
|
||||
func (f *EmailField) SetHidden(hidden bool) {
|
||||
f.Hidden = hidden
|
||||
}
|
||||
|
||||
// ColumnType implements [Field.ColumnType] interface method.
|
||||
func (f *EmailField) ColumnType(app App) string {
|
||||
return "TEXT DEFAULT '' NOT NULL"
|
||||
}
|
||||
|
||||
// PrepareValue implements [Field.PrepareValue] interface method.
|
||||
func (f *EmailField) PrepareValue(record *Record, raw any) (any, error) {
|
||||
return cast.ToString(raw), nil
|
||||
}
|
||||
|
||||
// ValidateValue implements [Field.ValidateValue] interface method.
|
||||
func (f *EmailField) ValidateValue(ctx context.Context, app App, record *Record) error {
|
||||
val, ok := record.GetRaw(f.Name).(string)
|
||||
if !ok {
|
||||
return validators.ErrUnsupportedValueType
|
||||
}
|
||||
|
||||
if f.Required {
|
||||
if err := validation.Required.Validate(val); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if val == "" {
|
||||
return nil // nothing to check
|
||||
}
|
||||
|
||||
if err := is.EmailFormat.Validate(val); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
domain := val[strings.LastIndex(val, "@")+1:]
|
||||
|
||||
// only domains check
|
||||
if len(f.OnlyDomains) > 0 && !slices.Contains(f.OnlyDomains, domain) {
|
||||
return validation.NewError("validation_email_domain_not_allowed", "Email domain is not allowed")
|
||||
}
|
||||
|
||||
// except domains check
|
||||
if len(f.ExceptDomains) > 0 && slices.Contains(f.ExceptDomains, domain) {
|
||||
return validation.NewError("validation_email_domain_not_allowed", "Email domain is not allowed")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidateSettings implements [Field.ValidateSettings] interface method.
|
||||
func (f *EmailField) ValidateSettings(ctx context.Context, app App, collection *Collection) error {
|
||||
return validation.ValidateStruct(f,
|
||||
validation.Field(&f.Id, validation.By(DefaultFieldIdValidationRule)),
|
||||
validation.Field(&f.Name, validation.By(DefaultFieldNameValidationRule)),
|
||||
validation.Field(
|
||||
&f.ExceptDomains,
|
||||
validation.When(len(f.OnlyDomains) > 0, validation.Empty).Else(validation.Each(is.Domain)),
|
||||
),
|
||||
validation.Field(
|
||||
&f.OnlyDomains,
|
||||
validation.When(len(f.ExceptDomains) > 0, validation.Empty).Else(validation.Each(is.Domain)),
|
||||
),
|
||||
)
|
||||
}
|
271
core/field_email_test.go
Normal file
271
core/field_email_test.go
Normal file
|
@ -0,0 +1,271 @@
|
|||
package core_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tests"
|
||||
)
|
||||
|
||||
func TestEmailFieldBaseMethods(t *testing.T) {
|
||||
testFieldBaseMethods(t, core.FieldTypeEmail)
|
||||
}
|
||||
|
||||
func TestEmailFieldColumnType(t *testing.T) {
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
f := &core.EmailField{}
|
||||
|
||||
expected := "TEXT DEFAULT '' NOT NULL"
|
||||
|
||||
if v := f.ColumnType(app); v != expected {
|
||||
t.Fatalf("Expected\n%q\ngot\n%q", expected, v)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEmailFieldPrepareValue(t *testing.T) {
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
f := &core.EmailField{}
|
||||
record := core.NewRecord(core.NewBaseCollection("test"))
|
||||
|
||||
scenarios := []struct {
|
||||
raw any
|
||||
expected string
|
||||
}{
|
||||
{"", ""},
|
||||
{"test", "test"},
|
||||
{false, "false"},
|
||||
{true, "true"},
|
||||
{123.456, "123.456"},
|
||||
}
|
||||
|
||||
for i, s := range scenarios {
|
||||
t.Run(fmt.Sprintf("%d_%#v", i, s.raw), func(t *testing.T) {
|
||||
v, err := f.PrepareValue(record, s.raw)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
vStr, ok := v.(string)
|
||||
if !ok {
|
||||
t.Fatalf("Expected string instance, got %T", v)
|
||||
}
|
||||
|
||||
if vStr != s.expected {
|
||||
t.Fatalf("Expected %q, got %q", s.expected, v)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEmailFieldValidateValue(t *testing.T) {
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
collection := core.NewBaseCollection("test_collection")
|
||||
|
||||
scenarios := []struct {
|
||||
name string
|
||||
field *core.EmailField
|
||||
record func() *core.Record
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
"invalid raw value",
|
||||
&core.EmailField{Name: "test"},
|
||||
func() *core.Record {
|
||||
record := core.NewRecord(collection)
|
||||
record.SetRaw("test", 123)
|
||||
return record
|
||||
},
|
||||
true,
|
||||
},
|
||||
{
|
||||
"zero field value (not required)",
|
||||
&core.EmailField{Name: "test"},
|
||||
func() *core.Record {
|
||||
record := core.NewRecord(collection)
|
||||
record.SetRaw("test", "")
|
||||
return record
|
||||
},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"zero field value (required)",
|
||||
&core.EmailField{Name: "test", Required: true},
|
||||
func() *core.Record {
|
||||
record := core.NewRecord(collection)
|
||||
record.SetRaw("test", "")
|
||||
return record
|
||||
},
|
||||
true,
|
||||
},
|
||||
{
|
||||
"non-zero field value (required)",
|
||||
&core.EmailField{Name: "test", Required: true},
|
||||
func() *core.Record {
|
||||
record := core.NewRecord(collection)
|
||||
record.SetRaw("test", "test@example.com")
|
||||
return record
|
||||
},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"invalid email",
|
||||
&core.EmailField{Name: "test"},
|
||||
func() *core.Record {
|
||||
record := core.NewRecord(collection)
|
||||
record.SetRaw("test", "invalid")
|
||||
return record
|
||||
},
|
||||
true,
|
||||
},
|
||||
{
|
||||
"failed onlyDomains",
|
||||
&core.EmailField{Name: "test", OnlyDomains: []string{"example.org", "example.net"}},
|
||||
func() *core.Record {
|
||||
record := core.NewRecord(collection)
|
||||
record.SetRaw("test", "test@example.com")
|
||||
return record
|
||||
},
|
||||
true,
|
||||
},
|
||||
{
|
||||
"success onlyDomains",
|
||||
&core.EmailField{Name: "test", OnlyDomains: []string{"example.org", "example.com"}},
|
||||
func() *core.Record {
|
||||
record := core.NewRecord(collection)
|
||||
record.SetRaw("test", "test@example.com")
|
||||
return record
|
||||
},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"failed exceptDomains",
|
||||
&core.EmailField{Name: "test", ExceptDomains: []string{"example.org", "example.com"}},
|
||||
func() *core.Record {
|
||||
record := core.NewRecord(collection)
|
||||
record.SetRaw("test", "test@example.com")
|
||||
return record
|
||||
},
|
||||
true,
|
||||
},
|
||||
{
|
||||
"success exceptDomains",
|
||||
&core.EmailField{Name: "test", ExceptDomains: []string{"example.org", "example.net"}},
|
||||
func() *core.Record {
|
||||
record := core.NewRecord(collection)
|
||||
record.SetRaw("test", "test@example.com")
|
||||
return record
|
||||
},
|
||||
false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
t.Run(s.name, func(t *testing.T) {
|
||||
err := s.field.ValidateValue(context.Background(), app, s.record())
|
||||
|
||||
hasErr := err != nil
|
||||
if hasErr != s.expectError {
|
||||
t.Fatalf("Expected hasErr %v, got %v (%v)", s.expectError, hasErr, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEmailFieldValidateSettings(t *testing.T) {
|
||||
testDefaultFieldIdValidation(t, core.FieldTypeEmail)
|
||||
testDefaultFieldNameValidation(t, core.FieldTypeEmail)
|
||||
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
collection := core.NewBaseCollection("test_collection")
|
||||
|
||||
scenarios := []struct {
|
||||
name string
|
||||
field func() *core.EmailField
|
||||
expectErrors []string
|
||||
}{
|
||||
{
|
||||
"zero minimal",
|
||||
func() *core.EmailField {
|
||||
return &core.EmailField{
|
||||
Id: "test",
|
||||
Name: "test",
|
||||
}
|
||||
},
|
||||
[]string{},
|
||||
},
|
||||
{
|
||||
"both onlyDomains and exceptDomains",
|
||||
func() *core.EmailField {
|
||||
return &core.EmailField{
|
||||
Id: "test",
|
||||
Name: "test",
|
||||
OnlyDomains: []string{"example.com"},
|
||||
ExceptDomains: []string{"example.org"},
|
||||
}
|
||||
},
|
||||
[]string{"onlyDomains", "exceptDomains"},
|
||||
},
|
||||
{
|
||||
"invalid onlyDomains",
|
||||
func() *core.EmailField {
|
||||
return &core.EmailField{
|
||||
Id: "test",
|
||||
Name: "test",
|
||||
OnlyDomains: []string{"example.com", "invalid"},
|
||||
}
|
||||
},
|
||||
[]string{"onlyDomains"},
|
||||
},
|
||||
{
|
||||
"valid onlyDomains",
|
||||
func() *core.EmailField {
|
||||
return &core.EmailField{
|
||||
Id: "test",
|
||||
Name: "test",
|
||||
OnlyDomains: []string{"example.com", "example.org"},
|
||||
}
|
||||
},
|
||||
[]string{},
|
||||
},
|
||||
{
|
||||
"invalid exceptDomains",
|
||||
func() *core.EmailField {
|
||||
return &core.EmailField{
|
||||
Id: "test",
|
||||
Name: "test",
|
||||
ExceptDomains: []string{"example.com", "invalid"},
|
||||
}
|
||||
},
|
||||
[]string{"exceptDomains"},
|
||||
},
|
||||
{
|
||||
"valid exceptDomains",
|
||||
func() *core.EmailField {
|
||||
return &core.EmailField{
|
||||
Id: "test",
|
||||
Name: "test",
|
||||
ExceptDomains: []string{"example.com", "example.org"},
|
||||
}
|
||||
},
|
||||
[]string{},
|
||||
},
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
t.Run(s.name, func(t *testing.T) {
|
||||
errs := s.field().ValidateSettings(context.Background(), app, collection)
|
||||
|
||||
tests.TestValidationErrors(t, errs, s.expectErrors)
|
||||
})
|
||||
}
|
||||
}
|
820
core/field_file.go
Normal file
820
core/field_file.go
Normal file
|
@ -0,0 +1,820 @@
|
|||
package core
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql/driver"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
validation "github.com/go-ozzo/ozzo-validation/v4"
|
||||
"github.com/pocketbase/pocketbase/core/validators"
|
||||
"github.com/pocketbase/pocketbase/tools/filesystem"
|
||||
"github.com/pocketbase/pocketbase/tools/list"
|
||||
"github.com/pocketbase/pocketbase/tools/types"
|
||||
"github.com/spf13/cast"
|
||||
)
|
||||
|
||||
func init() {
|
||||
Fields[FieldTypeFile] = func() Field {
|
||||
return &FileField{}
|
||||
}
|
||||
}
|
||||
|
||||
const FieldTypeFile = "file"
|
||||
|
||||
const DefaultFileFieldMaxSize int64 = 5 << 20
|
||||
|
||||
var looseFilenameRegex = regexp.MustCompile(`^[^\./\\][^/\\]+$`)
|
||||
|
||||
const (
|
||||
deletedFilesPrefix = internalCustomFieldKeyPrefix + "_deletedFilesPrefix_"
|
||||
uploadedFilesPrefix = internalCustomFieldKeyPrefix + "_uploadedFilesPrefix_"
|
||||
)
|
||||
|
||||
var (
|
||||
_ Field = (*FileField)(nil)
|
||||
_ MultiValuer = (*FileField)(nil)
|
||||
_ DriverValuer = (*FileField)(nil)
|
||||
_ GetterFinder = (*FileField)(nil)
|
||||
_ SetterFinder = (*FileField)(nil)
|
||||
_ RecordInterceptor = (*FileField)(nil)
|
||||
_ MaxBodySizeCalculator = (*FileField)(nil)
|
||||
)
|
||||
|
||||
// FileField defines "file" type field for managing record file(s).
|
||||
//
|
||||
// Only the file name is stored as part of the record value.
|
||||
// New files (aka. files to upload) are expected to be of *filesytem.File.
|
||||
//
|
||||
// If MaxSelect is not set or <= 1, then the field value is expected to be a single record id.
|
||||
//
|
||||
// If MaxSelect is > 1, then the field value is expected to be a slice of record ids.
|
||||
//
|
||||
// The respective zero record field value is either empty string (single) or empty string slice (multiple).
|
||||
//
|
||||
// ---
|
||||
//
|
||||
// The following additional setter keys are available:
|
||||
//
|
||||
// - "fieldName+" - append one or more files to the existing record one. For example:
|
||||
//
|
||||
// // []string{"old1.txt", "old2.txt", "new1_ajkvass.txt", "new2_klhfnwd.txt"}
|
||||
// record.Set("documents+", []*filesystem.File{new1, new2})
|
||||
//
|
||||
// - "+fieldName" - prepend one or more files to the existing record one. For example:
|
||||
//
|
||||
// // []string{"new1_ajkvass.txt", "new2_klhfnwd.txt", "old1.txt", "old2.txt",}
|
||||
// record.Set("+documents", []*filesystem.File{new1, new2})
|
||||
//
|
||||
// - "fieldName-" - subtract/delete one or more files from the existing record one. For example:
|
||||
//
|
||||
// // []string{"old2.txt",}
|
||||
// record.Set("documents-", "old1.txt")
|
||||
type FileField struct {
|
||||
// Name (required) is the unique name of the field.
|
||||
Name string `form:"name" json:"name"`
|
||||
|
||||
// Id is the unique stable field identifier.
|
||||
//
|
||||
// It is automatically generated from the name when adding to a collection FieldsList.
|
||||
Id string `form:"id" json:"id"`
|
||||
|
||||
// System prevents the renaming and removal of the field.
|
||||
System bool `form:"system" json:"system"`
|
||||
|
||||
// Hidden hides the field from the API response.
|
||||
Hidden bool `form:"hidden" json:"hidden"`
|
||||
|
||||
// Presentable hints the Dashboard UI to use the underlying
|
||||
// field record value in the relation preview label.
|
||||
Presentable bool `form:"presentable" json:"presentable"`
|
||||
|
||||
// ---
|
||||
|
||||
// MaxSize specifies the maximum size of a single uploaded file (in bytes and up to 2^53-1).
|
||||
//
|
||||
// If zero, a default limit of 5MB is applied.
|
||||
MaxSize int64 `form:"maxSize" json:"maxSize"`
|
||||
|
||||
// MaxSelect specifies the max allowed files.
|
||||
//
|
||||
// For multiple files the value must be > 1, otherwise fallbacks to single (default).
|
||||
MaxSelect int `form:"maxSelect" json:"maxSelect"`
|
||||
|
||||
// MimeTypes specifies an optional list of the allowed file mime types.
|
||||
//
|
||||
// Leave it empty to disable the validator.
|
||||
MimeTypes []string `form:"mimeTypes" json:"mimeTypes"`
|
||||
|
||||
// Thumbs specifies an optional list of the supported thumbs for image based files.
|
||||
//
|
||||
// Each entry must be in one of the following formats:
|
||||
//
|
||||
// - WxH (eg. 100x300) - crop to WxH viewbox (from center)
|
||||
// - WxHt (eg. 100x300t) - crop to WxH viewbox (from top)
|
||||
// - WxHb (eg. 100x300b) - crop to WxH viewbox (from bottom)
|
||||
// - WxHf (eg. 100x300f) - fit inside a WxH viewbox (without cropping)
|
||||
// - 0xH (eg. 0x300) - resize to H height preserving the aspect ratio
|
||||
// - Wx0 (eg. 100x0) - resize to W width preserving the aspect ratio
|
||||
Thumbs []string `form:"thumbs" json:"thumbs"`
|
||||
|
||||
// Protected will require the users to provide a special file token to access the file.
|
||||
//
|
||||
// Note that by default all files are publicly accessible.
|
||||
//
|
||||
// For the majority of the cases this is fine because by default
|
||||
// all file names have random part appended to their name which
|
||||
// need to be known by the user before accessing the file.
|
||||
Protected bool `form:"protected" json:"protected"`
|
||||
|
||||
// Required will require the field value to have at least one file.
|
||||
Required bool `form:"required" json:"required"`
|
||||
}
|
||||
|
||||
// Type implements [Field.Type] interface method.
|
||||
func (f *FileField) Type() string {
|
||||
return FieldTypeFile
|
||||
}
|
||||
|
||||
// GetId implements [Field.GetId] interface method.
|
||||
func (f *FileField) GetId() string {
|
||||
return f.Id
|
||||
}
|
||||
|
||||
// SetId implements [Field.SetId] interface method.
|
||||
func (f *FileField) SetId(id string) {
|
||||
f.Id = id
|
||||
}
|
||||
|
||||
// GetName implements [Field.GetName] interface method.
|
||||
func (f *FileField) GetName() string {
|
||||
return f.Name
|
||||
}
|
||||
|
||||
// SetName implements [Field.SetName] interface method.
|
||||
func (f *FileField) SetName(name string) {
|
||||
f.Name = name
|
||||
}
|
||||
|
||||
// GetSystem implements [Field.GetSystem] interface method.
|
||||
func (f *FileField) GetSystem() bool {
|
||||
return f.System
|
||||
}
|
||||
|
||||
// SetSystem implements [Field.SetSystem] interface method.
|
||||
func (f *FileField) SetSystem(system bool) {
|
||||
f.System = system
|
||||
}
|
||||
|
||||
// GetHidden implements [Field.GetHidden] interface method.
|
||||
func (f *FileField) GetHidden() bool {
|
||||
return f.Hidden
|
||||
}
|
||||
|
||||
// SetHidden implements [Field.SetHidden] interface method.
|
||||
func (f *FileField) SetHidden(hidden bool) {
|
||||
f.Hidden = hidden
|
||||
}
|
||||
|
||||
// IsMultiple implements MultiValuer interface and checks whether the
|
||||
// current field options support multiple values.
|
||||
func (f *FileField) IsMultiple() bool {
|
||||
return f.MaxSelect > 1
|
||||
}
|
||||
|
||||
// ColumnType implements [Field.ColumnType] interface method.
|
||||
func (f *FileField) ColumnType(app App) string {
|
||||
if f.IsMultiple() {
|
||||
return "JSON DEFAULT '[]' NOT NULL"
|
||||
}
|
||||
|
||||
return "TEXT DEFAULT '' NOT NULL"
|
||||
}
|
||||
|
||||
// PrepareValue implements [Field.PrepareValue] interface method.
|
||||
func (f *FileField) PrepareValue(record *Record, raw any) (any, error) {
|
||||
return f.normalizeValue(raw), nil
|
||||
}
|
||||
|
||||
// DriverValue implements the [DriverValuer] interface.
|
||||
func (f *FileField) DriverValue(record *Record) (driver.Value, error) {
|
||||
files := f.toSliceValue(record.GetRaw(f.Name))
|
||||
|
||||
if f.IsMultiple() {
|
||||
ja := make(types.JSONArray[string], len(files))
|
||||
for i, v := range files {
|
||||
ja[i] = f.getFileName(v)
|
||||
}
|
||||
return ja, nil
|
||||
}
|
||||
|
||||
if len(files) == 0 {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
return f.getFileName(files[len(files)-1]), nil
|
||||
}
|
||||
|
||||
// ValidateSettings implements [Field.ValidateSettings] interface method.
|
||||
func (f *FileField) ValidateSettings(ctx context.Context, app App, collection *Collection) error {
|
||||
return validation.ValidateStruct(f,
|
||||
validation.Field(&f.Id, validation.By(DefaultFieldIdValidationRule)),
|
||||
validation.Field(&f.Name, validation.By(DefaultFieldNameValidationRule)),
|
||||
validation.Field(&f.MaxSelect, validation.Min(0), validation.Max(maxSafeJSONInt)),
|
||||
validation.Field(&f.MaxSize, validation.Min(0), validation.Max(maxSafeJSONInt)),
|
||||
validation.Field(&f.Thumbs, validation.Each(
|
||||
validation.NotIn("0x0", "0x0t", "0x0b", "0x0f"),
|
||||
validation.Match(filesystem.ThumbSizeRegex),
|
||||
)),
|
||||
)
|
||||
}
|
||||
|
||||
// ValidateValue implements [Field.ValidateValue] interface method.
|
||||
func (f *FileField) ValidateValue(ctx context.Context, app App, record *Record) error {
|
||||
files := f.toSliceValue(record.GetRaw(f.Name))
|
||||
if len(files) == 0 {
|
||||
if f.Required {
|
||||
return validation.ErrRequired
|
||||
}
|
||||
return nil // nothing to check
|
||||
}
|
||||
|
||||
// validate existing and disallow new plain string filenames submission
|
||||
// (new files must be *filesystem.File)
|
||||
// ---
|
||||
oldExistingStrings := f.toSliceValue(f.getLatestOldValue(app, record))
|
||||
existingStrings := list.ToInterfaceSlice(f.extractPlainStrings(files))
|
||||
addedStrings := f.excludeFiles(existingStrings, oldExistingStrings)
|
||||
|
||||
if len(addedStrings) > 0 {
|
||||
invalidFiles := make([]string, len(addedStrings))
|
||||
for i, invalid := range addedStrings {
|
||||
invalidStr := cast.ToString(invalid)
|
||||
if len(invalidStr) > 250 {
|
||||
invalidStr = invalidStr[:250]
|
||||
}
|
||||
invalidFiles[i] = invalidStr
|
||||
}
|
||||
|
||||
return validation.NewError("validation_invalid_file", "Invalid new files: {{.invalidFiles}}.").
|
||||
SetParams(map[string]any{"invalidFiles": invalidFiles})
|
||||
}
|
||||
|
||||
maxSelect := f.maxSelect()
|
||||
if len(files) > maxSelect {
|
||||
return validation.NewError("validation_too_many_files", "The maximum allowed files is {{.maxSelect}}").
|
||||
SetParams(map[string]any{"maxSelect": maxSelect})
|
||||
}
|
||||
|
||||
// validate uploaded
|
||||
// ---
|
||||
uploads := f.extractUploadableFiles(files)
|
||||
for _, upload := range uploads {
|
||||
// loosely check the filename just in case it was manually changed after the normalization
|
||||
err := validation.Length(1, 150).Validate(upload.Name)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = validation.Match(looseFilenameRegex).Validate(upload.Name)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// check size
|
||||
err = validators.UploadedFileSize(f.maxSize())(upload)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// check type
|
||||
if len(f.MimeTypes) > 0 {
|
||||
err = validators.UploadedFileMimeType(f.MimeTypes)(upload)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *FileField) maxSize() int64 {
|
||||
if f.MaxSize <= 0 {
|
||||
return DefaultFileFieldMaxSize
|
||||
}
|
||||
|
||||
return f.MaxSize
|
||||
}
|
||||
|
||||
func (f *FileField) maxSelect() int {
|
||||
if f.MaxSelect <= 1 {
|
||||
return 1
|
||||
}
|
||||
|
||||
return f.MaxSelect
|
||||
}
|
||||
|
||||
// CalculateMaxBodySize implements the [MaxBodySizeCalculator] interface.
|
||||
func (f *FileField) CalculateMaxBodySize() int64 {
|
||||
return f.maxSize() * int64(f.maxSelect())
|
||||
}
|
||||
|
||||
// Interceptors
|
||||
// -------------------------------------------------------------------
|
||||
|
||||
// Intercept implements the [RecordInterceptor] interface.
|
||||
//
|
||||
// note: files delete after records deletion is handled globally by the app FileManager hook
|
||||
func (f *FileField) Intercept(
|
||||
ctx context.Context,
|
||||
app App,
|
||||
record *Record,
|
||||
actionName string,
|
||||
actionFunc func() error,
|
||||
) error {
|
||||
switch actionName {
|
||||
case InterceptorActionCreateExecute, InterceptorActionUpdateExecute:
|
||||
oldValue := f.getLatestOldValue(app, record)
|
||||
|
||||
err := f.processFilesToUpload(ctx, app, record)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = actionFunc()
|
||||
if err != nil {
|
||||
return errors.Join(err, f.afterRecordExecuteFailure(newContextIfInvalid(ctx), app, record))
|
||||
}
|
||||
|
||||
f.rememberFilesToDelete(app, record, oldValue)
|
||||
|
||||
f.afterRecordExecuteSuccess(newContextIfInvalid(ctx), app, record)
|
||||
|
||||
return nil
|
||||
case InterceptorActionAfterCreateError, InterceptorActionAfterUpdateError:
|
||||
// when in transaction we assume that the error was handled by afterRecordExecuteFailure
|
||||
if app.IsTransactional() {
|
||||
return actionFunc()
|
||||
}
|
||||
|
||||
failedToDelete, deleteErr := f.deleteNewlyUploadedFiles(newContextIfInvalid(ctx), app, record)
|
||||
if deleteErr != nil {
|
||||
app.Logger().Warn(
|
||||
"Failed to cleanup all new files after record commit failure",
|
||||
"error", deleteErr,
|
||||
"failedToDelete", failedToDelete,
|
||||
)
|
||||
}
|
||||
|
||||
record.SetRaw(deletedFilesPrefix+f.Name, nil)
|
||||
|
||||
if record.IsNew() {
|
||||
// try to delete the record directory if there are no other files
|
||||
//
|
||||
// note: executed only on create failure to avoid accidentally
|
||||
// deleting a concurrently updating directory due to the
|
||||
// eventual consistent nature of some storage providers
|
||||
err := f.deleteEmptyRecordDir(newContextIfInvalid(ctx), app, record)
|
||||
if err != nil {
|
||||
app.Logger().Warn("Failed to delete empty dir after new record commit failure", "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
return actionFunc()
|
||||
case InterceptorActionAfterCreate, InterceptorActionAfterUpdate:
|
||||
record.SetRaw(uploadedFilesPrefix+f.Name, nil)
|
||||
|
||||
err := f.processFilesToDelete(ctx, app, record)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return actionFunc()
|
||||
default:
|
||||
return actionFunc()
|
||||
}
|
||||
}
|
||||
func (f *FileField) getLatestOldValue(app App, record *Record) any {
|
||||
if !record.IsNew() {
|
||||
latestOriginal, err := app.FindRecordById(record.Collection(), cast.ToString(record.LastSavedPK()))
|
||||
if err == nil {
|
||||
return latestOriginal.GetRaw(f.Name)
|
||||
}
|
||||
}
|
||||
|
||||
return record.Original().GetRaw(f.Name)
|
||||
}
|
||||
|
||||
func (f *FileField) afterRecordExecuteSuccess(ctx context.Context, app App, record *Record) {
|
||||
uploaded, _ := record.GetRaw(uploadedFilesPrefix + f.Name).([]*filesystem.File)
|
||||
|
||||
// replace the uploaded file objects with their plain string names
|
||||
newValue := f.toSliceValue(record.GetRaw(f.Name))
|
||||
for i, v := range newValue {
|
||||
if file, ok := v.(*filesystem.File); ok {
|
||||
uploaded = append(uploaded, file)
|
||||
newValue[i] = file.Name
|
||||
}
|
||||
}
|
||||
f.setValue(record, newValue)
|
||||
|
||||
record.SetRaw(uploadedFilesPrefix+f.Name, uploaded)
|
||||
}
|
||||
|
||||
func (f *FileField) afterRecordExecuteFailure(ctx context.Context, app App, record *Record) error {
|
||||
uploaded := f.extractUploadableFiles(f.toSliceValue(record.GetRaw(f.Name)))
|
||||
|
||||
toDelete := make([]string, len(uploaded))
|
||||
for i, file := range uploaded {
|
||||
toDelete[i] = file.Name
|
||||
}
|
||||
|
||||
// delete previously uploaded files
|
||||
failedToDelete, deleteErr := f.deleteFilesByNamesList(ctx, app, record, list.ToUniqueStringSlice(toDelete))
|
||||
|
||||
if len(failedToDelete) > 0 {
|
||||
app.Logger().Warn(
|
||||
"Failed to cleanup the new uploaded file after record db write failure",
|
||||
"error", deleteErr,
|
||||
"failedToDelete", failedToDelete,
|
||||
)
|
||||
}
|
||||
|
||||
return deleteErr
|
||||
}
|
||||
|
||||
func (f *FileField) deleteEmptyRecordDir(ctx context.Context, app App, record *Record) error {
|
||||
fsys, err := app.NewFilesystem()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer fsys.Close()
|
||||
fsys.SetContext(newContextIfInvalid(ctx))
|
||||
|
||||
dir := record.BaseFilesPath()
|
||||
|
||||
if !fsys.IsEmptyDir(dir) {
|
||||
return nil // no-op
|
||||
}
|
||||
|
||||
err = fsys.Delete(dir)
|
||||
if err != nil && !errors.Is(err, filesystem.ErrNotFound) {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *FileField) processFilesToDelete(ctx context.Context, app App, record *Record) error {
|
||||
markedForDelete, _ := record.GetRaw(deletedFilesPrefix + f.Name).([]string)
|
||||
if len(markedForDelete) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
old := list.ToInterfaceSlice(markedForDelete)
|
||||
new := list.ToInterfaceSlice(f.extractPlainStrings(f.toSliceValue(record.GetRaw(f.Name))))
|
||||
diff := f.excludeFiles(old, new)
|
||||
|
||||
toDelete := make([]string, len(diff))
|
||||
for i, del := range diff {
|
||||
toDelete[i] = f.getFileName(del)
|
||||
}
|
||||
|
||||
failedToDelete, err := f.deleteFilesByNamesList(ctx, app, record, list.ToUniqueStringSlice(toDelete))
|
||||
|
||||
record.SetRaw(deletedFilesPrefix+f.Name, failedToDelete)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (f *FileField) rememberFilesToDelete(app App, record *Record, oldValue any) {
|
||||
old := list.ToInterfaceSlice(f.extractPlainStrings(f.toSliceValue(oldValue)))
|
||||
new := list.ToInterfaceSlice(f.extractPlainStrings(f.toSliceValue(record.GetRaw(f.Name))))
|
||||
diff := f.excludeFiles(old, new)
|
||||
|
||||
toDelete, _ := record.GetRaw(deletedFilesPrefix + f.Name).([]string)
|
||||
|
||||
for _, del := range diff {
|
||||
toDelete = append(toDelete, f.getFileName(del))
|
||||
}
|
||||
|
||||
record.SetRaw(deletedFilesPrefix+f.Name, toDelete)
|
||||
}
|
||||
|
||||
func (f *FileField) processFilesToUpload(ctx context.Context, app App, record *Record) error {
|
||||
uploads := f.extractUploadableFiles(f.toSliceValue(record.GetRaw(f.Name)))
|
||||
if len(uploads) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
if record.Id == "" {
|
||||
return errors.New("uploading files requires the record to have a valid nonempty id")
|
||||
}
|
||||
|
||||
fsys, err := app.NewFilesystem()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer fsys.Close()
|
||||
fsys.SetContext(ctx)
|
||||
|
||||
var failed []error // list of upload errors
|
||||
var succeeded []string // list of uploaded file names
|
||||
|
||||
for _, upload := range uploads {
|
||||
path := record.BaseFilesPath() + "/" + upload.Name
|
||||
if err := fsys.UploadFile(upload, path); err == nil {
|
||||
succeeded = append(succeeded, upload.Name)
|
||||
} else {
|
||||
failed = append(failed, fmt.Errorf("%q: %w", upload.Name, err))
|
||||
break // for now stop on the first error since we currently don't allow partial uploads
|
||||
}
|
||||
}
|
||||
|
||||
if len(failed) > 0 {
|
||||
// cleanup - try to delete the successfully uploaded files (if any)
|
||||
_, cleanupErr := f.deleteFilesByNamesList(newContextIfInvalid(ctx), app, record, succeeded)
|
||||
|
||||
failed = append(failed, cleanupErr)
|
||||
|
||||
return fmt.Errorf("failed to upload all files: %w", errors.Join(failed...))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *FileField) deleteNewlyUploadedFiles(ctx context.Context, app App, record *Record) ([]string, error) {
|
||||
uploaded, _ := record.GetRaw(uploadedFilesPrefix + f.Name).([]*filesystem.File)
|
||||
if len(uploaded) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
names := make([]string, len(uploaded))
|
||||
for i, file := range uploaded {
|
||||
names[i] = file.Name
|
||||
}
|
||||
|
||||
failed, err := f.deleteFilesByNamesList(ctx, app, record, list.ToUniqueStringSlice(names))
|
||||
if err != nil {
|
||||
return failed, err
|
||||
}
|
||||
|
||||
record.SetRaw(uploadedFilesPrefix+f.Name, nil)
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// deleteFiles deletes a list of record files by their names.
|
||||
// Returns the failed/remaining files.
|
||||
func (f *FileField) deleteFilesByNamesList(ctx context.Context, app App, record *Record, filenames []string) ([]string, error) {
|
||||
if len(filenames) == 0 {
|
||||
return nil, nil // nothing to delete
|
||||
}
|
||||
|
||||
if record.Id == "" {
|
||||
return filenames, errors.New("the record doesn't have an id")
|
||||
}
|
||||
|
||||
fsys, err := app.NewFilesystem()
|
||||
if err != nil {
|
||||
return filenames, err
|
||||
}
|
||||
defer fsys.Close()
|
||||
fsys.SetContext(ctx)
|
||||
|
||||
var failures []error
|
||||
|
||||
for i := len(filenames) - 1; i >= 0; i-- {
|
||||
filename := filenames[i]
|
||||
if filename == "" || strings.ContainsAny(filename, "/\\") {
|
||||
continue // empty or not a plain filename
|
||||
}
|
||||
|
||||
path := record.BaseFilesPath() + "/" + filename
|
||||
|
||||
err := fsys.Delete(path)
|
||||
if err != nil && !errors.Is(err, filesystem.ErrNotFound) {
|
||||
// store the delete error
|
||||
failures = append(failures, fmt.Errorf("file %d (%q): %w", i, filename, err))
|
||||
} else {
|
||||
// remove the deleted file from the list
|
||||
filenames = append(filenames[:i], filenames[i+1:]...)
|
||||
|
||||
// try to delete the related file thumbs (if any)
|
||||
thumbsErr := fsys.DeletePrefix(record.BaseFilesPath() + "/thumbs_" + filename + "/")
|
||||
if len(thumbsErr) > 0 {
|
||||
app.Logger().Warn("Failed to delete file thumbs", "error", errors.Join(thumbsErr...))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(failures) > 0 {
|
||||
return filenames, fmt.Errorf("failed to delete all files: %w", errors.Join(failures...))
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// newContextIfInvalid returns a new Background context if the provided one was cancelled.
|
||||
func newContextIfInvalid(ctx context.Context) context.Context {
|
||||
if ctx.Err() == nil {
|
||||
return ctx
|
||||
}
|
||||
|
||||
return context.Background()
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
|
||||
// FindGetter implements the [GetterFinder] interface.
|
||||
func (f *FileField) FindGetter(key string) GetterFunc {
|
||||
switch key {
|
||||
case f.Name:
|
||||
return func(record *Record) any {
|
||||
return record.GetRaw(f.Name)
|
||||
}
|
||||
case f.Name + ":unsaved":
|
||||
return func(record *Record) any {
|
||||
return f.extractUploadableFiles(f.toSliceValue(record.GetRaw(f.Name)))
|
||||
}
|
||||
case f.Name + ":uploaded":
|
||||
// deprecated
|
||||
log.Println("[file field getter] please replace :uploaded with :unsaved")
|
||||
return func(record *Record) any {
|
||||
return f.extractUploadableFiles(f.toSliceValue(record.GetRaw(f.Name)))
|
||||
}
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
|
||||
// FindSetter implements the [SetterFinder] interface.
|
||||
func (f *FileField) FindSetter(key string) SetterFunc {
|
||||
switch key {
|
||||
case f.Name:
|
||||
return f.setValue
|
||||
case "+" + f.Name:
|
||||
return f.prependValue
|
||||
case f.Name + "+":
|
||||
return f.appendValue
|
||||
case f.Name + "-":
|
||||
return f.subtractValue
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (f *FileField) setValue(record *Record, raw any) {
|
||||
val := f.normalizeValue(raw)
|
||||
|
||||
record.SetRaw(f.Name, val)
|
||||
}
|
||||
|
||||
func (f *FileField) prependValue(record *Record, toPrepend any) {
|
||||
files := f.toSliceValue(record.GetRaw(f.Name))
|
||||
prepends := f.toSliceValue(toPrepend)
|
||||
|
||||
if len(prepends) > 0 {
|
||||
files = append(prepends, files...)
|
||||
}
|
||||
|
||||
f.setValue(record, files)
|
||||
}
|
||||
|
||||
func (f *FileField) appendValue(record *Record, toAppend any) {
|
||||
files := f.toSliceValue(record.GetRaw(f.Name))
|
||||
appends := f.toSliceValue(toAppend)
|
||||
|
||||
if len(appends) > 0 {
|
||||
files = append(files, appends...)
|
||||
}
|
||||
|
||||
f.setValue(record, files)
|
||||
}
|
||||
|
||||
func (f *FileField) subtractValue(record *Record, toRemove any) {
|
||||
files := f.excludeFiles(
|
||||
f.toSliceValue(record.GetRaw(f.Name)),
|
||||
f.toSliceValue(toRemove),
|
||||
)
|
||||
|
||||
f.setValue(record, files)
|
||||
}
|
||||
|
||||
func (f *FileField) normalizeValue(raw any) any {
|
||||
files := f.toSliceValue(raw)
|
||||
|
||||
if f.IsMultiple() {
|
||||
return files
|
||||
}
|
||||
|
||||
if len(files) > 0 {
|
||||
return files[len(files)-1] // the last selected
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
func (f *FileField) toSliceValue(raw any) []any {
|
||||
var result []any
|
||||
|
||||
switch value := raw.(type) {
|
||||
case nil:
|
||||
// nothing to cast
|
||||
case *filesystem.File:
|
||||
result = append(result, value)
|
||||
case filesystem.File:
|
||||
result = append(result, &value)
|
||||
case []*filesystem.File:
|
||||
for _, v := range value {
|
||||
result = append(result, v)
|
||||
}
|
||||
case []filesystem.File:
|
||||
for _, v := range value {
|
||||
result = append(result, &v)
|
||||
}
|
||||
case []any:
|
||||
for _, v := range value {
|
||||
casted := f.toSliceValue(v)
|
||||
if len(casted) == 1 {
|
||||
result = append(result, casted[0])
|
||||
}
|
||||
}
|
||||
default:
|
||||
result = list.ToInterfaceSlice(list.ToUniqueStringSlice(value))
|
||||
}
|
||||
|
||||
return f.uniqueFiles(result)
|
||||
}
|
||||
|
||||
func (f *FileField) uniqueFiles(files []any) []any {
|
||||
found := make(map[string]struct{}, len(files))
|
||||
result := make([]any, 0, len(files))
|
||||
|
||||
for _, fv := range files {
|
||||
name := f.getFileName(fv)
|
||||
if _, ok := found[name]; !ok {
|
||||
result = append(result, fv)
|
||||
found[name] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
func (f *FileField) extractPlainStrings(files []any) []string {
|
||||
result := []string{}
|
||||
|
||||
for _, raw := range files {
|
||||
if f, ok := raw.(string); ok {
|
||||
result = append(result, f)
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
func (f *FileField) extractUploadableFiles(files []any) []*filesystem.File {
|
||||
result := []*filesystem.File{}
|
||||
|
||||
for _, raw := range files {
|
||||
if upload, ok := raw.(*filesystem.File); ok {
|
||||
result = append(result, upload)
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
func (f *FileField) excludeFiles(base []any, toExclude []any) []any {
|
||||
result := make([]any, 0, len(base))
|
||||
|
||||
SUBTRACT_LOOP:
|
||||
for _, fv := range base {
|
||||
for _, exclude := range toExclude {
|
||||
if f.getFileName(exclude) == f.getFileName(fv) {
|
||||
continue SUBTRACT_LOOP // skip
|
||||
}
|
||||
}
|
||||
|
||||
result = append(result, fv)
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
func (f *FileField) getFileName(file any) string {
|
||||
switch v := file.(type) {
|
||||
case string:
|
||||
return v
|
||||
case *filesystem.File:
|
||||
return v.Name
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
1152
core/field_file_test.go
Normal file
1152
core/field_file_test.go
Normal file
File diff suppressed because it is too large
Load diff
148
core/field_geo_point.go
Normal file
148
core/field_geo_point.go
Normal file
|
@ -0,0 +1,148 @@
|
|||
package core
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
validation "github.com/go-ozzo/ozzo-validation/v4"
|
||||
"github.com/pocketbase/pocketbase/core/validators"
|
||||
"github.com/pocketbase/pocketbase/tools/types"
|
||||
)
|
||||
|
||||
func init() {
|
||||
Fields[FieldTypeGeoPoint] = func() Field {
|
||||
return &GeoPointField{}
|
||||
}
|
||||
}
|
||||
|
||||
const FieldTypeGeoPoint = "geoPoint"
|
||||
|
||||
var (
|
||||
_ Field = (*GeoPointField)(nil)
|
||||
)
|
||||
|
||||
// GeoPointField defines "geoPoint" type field for storing latitude and longitude GPS coordinates.
|
||||
//
|
||||
// You can set the record field value as [types.GeoPoint], map or serialized json object with lat-lon props.
|
||||
// The stored value is always converted to [types.GeoPoint].
|
||||
// Nil, empty map, empty bytes slice, etc. results in zero [types.GeoPoint].
|
||||
//
|
||||
// Examples of updating a record's GeoPointField value programmatically:
|
||||
//
|
||||
// record.Set("location", types.GeoPoint{Lat: 123, Lon: 456})
|
||||
// record.Set("location", map[string]any{"lat":123, "lon":456})
|
||||
// record.Set("location", []byte(`{"lat":123, "lon":456}`)
|
||||
type GeoPointField struct {
|
||||
// Name (required) is the unique name of the field.
|
||||
Name string `form:"name" json:"name"`
|
||||
|
||||
// Id is the unique stable field identifier.
|
||||
//
|
||||
// It is automatically generated from the name when adding to a collection FieldsList.
|
||||
Id string `form:"id" json:"id"`
|
||||
|
||||
// System prevents the renaming and removal of the field.
|
||||
System bool `form:"system" json:"system"`
|
||||
|
||||
// Hidden hides the field from the API response.
|
||||
Hidden bool `form:"hidden" json:"hidden"`
|
||||
|
||||
// Presentable hints the Dashboard UI to use the underlying
|
||||
// field record value in the relation preview label.
|
||||
Presentable bool `form:"presentable" json:"presentable"`
|
||||
|
||||
// ---
|
||||
|
||||
// Required will require the field coordinates to be non-zero (aka. not "Null Island").
|
||||
Required bool `form:"required" json:"required"`
|
||||
}
|
||||
|
||||
// Type implements [Field.Type] interface method.
|
||||
func (f *GeoPointField) Type() string {
|
||||
return FieldTypeGeoPoint
|
||||
}
|
||||
|
||||
// GetId implements [Field.GetId] interface method.
|
||||
func (f *GeoPointField) GetId() string {
|
||||
return f.Id
|
||||
}
|
||||
|
||||
// SetId implements [Field.SetId] interface method.
|
||||
func (f *GeoPointField) SetId(id string) {
|
||||
f.Id = id
|
||||
}
|
||||
|
||||
// GetName implements [Field.GetName] interface method.
|
||||
func (f *GeoPointField) GetName() string {
|
||||
return f.Name
|
||||
}
|
||||
|
||||
// SetName implements [Field.SetName] interface method.
|
||||
func (f *GeoPointField) SetName(name string) {
|
||||
f.Name = name
|
||||
}
|
||||
|
||||
// GetSystem implements [Field.GetSystem] interface method.
|
||||
func (f *GeoPointField) GetSystem() bool {
|
||||
return f.System
|
||||
}
|
||||
|
||||
// SetSystem implements [Field.SetSystem] interface method.
|
||||
func (f *GeoPointField) SetSystem(system bool) {
|
||||
f.System = system
|
||||
}
|
||||
|
||||
// GetHidden implements [Field.GetHidden] interface method.
|
||||
func (f *GeoPointField) GetHidden() bool {
|
||||
return f.Hidden
|
||||
}
|
||||
|
||||
// SetHidden implements [Field.SetHidden] interface method.
|
||||
func (f *GeoPointField) SetHidden(hidden bool) {
|
||||
f.Hidden = hidden
|
||||
}
|
||||
|
||||
// ColumnType implements [Field.ColumnType] interface method.
|
||||
func (f *GeoPointField) ColumnType(app App) string {
|
||||
return `JSON DEFAULT '{"lon":0,"lat":0}' NOT NULL`
|
||||
}
|
||||
|
||||
// PrepareValue implements [Field.PrepareValue] interface method.
|
||||
func (f *GeoPointField) PrepareValue(record *Record, raw any) (any, error) {
|
||||
point := types.GeoPoint{}
|
||||
err := point.Scan(raw)
|
||||
return point, err
|
||||
}
|
||||
|
||||
// ValidateValue implements [Field.ValidateValue] interface method.
|
||||
func (f *GeoPointField) ValidateValue(ctx context.Context, app App, record *Record) error {
|
||||
val, ok := record.GetRaw(f.Name).(types.GeoPoint)
|
||||
if !ok {
|
||||
return validators.ErrUnsupportedValueType
|
||||
}
|
||||
|
||||
// zero value
|
||||
if val.Lat == 0 && val.Lon == 0 {
|
||||
if f.Required {
|
||||
return validation.ErrRequired
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
if val.Lat < -90 || val.Lat > 90 {
|
||||
return validation.NewError("validation_invalid_latitude", "Latitude must be between -90 and 90 degrees.")
|
||||
}
|
||||
|
||||
if val.Lon < -180 || val.Lon > 180 {
|
||||
return validation.NewError("validation_invalid_longitude", "Longitude must be between -180 and 180 degrees.")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidateSettings implements [Field.ValidateSettings] interface method.
|
||||
func (f *GeoPointField) ValidateSettings(ctx context.Context, app App, collection *Collection) error {
|
||||
return validation.ValidateStruct(f,
|
||||
validation.Field(&f.Id, validation.By(DefaultFieldIdValidationRule)),
|
||||
validation.Field(&f.Name, validation.By(DefaultFieldNameValidationRule)),
|
||||
)
|
||||
}
|
202
core/field_geo_point_test.go
Normal file
202
core/field_geo_point_test.go
Normal file
|
@ -0,0 +1,202 @@
|
|||
package core_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tests"
|
||||
"github.com/pocketbase/pocketbase/tools/types"
|
||||
)
|
||||
|
||||
func TestGeoPointFieldBaseMethods(t *testing.T) {
|
||||
testFieldBaseMethods(t, core.FieldTypeGeoPoint)
|
||||
}
|
||||
|
||||
func TestGeoPointFieldColumnType(t *testing.T) {
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
f := &core.GeoPointField{}
|
||||
|
||||
expected := `JSON DEFAULT '{"lon":0,"lat":0}' NOT NULL`
|
||||
|
||||
if v := f.ColumnType(app); v != expected {
|
||||
t.Fatalf("Expected\n%q\ngot\n%q", expected, v)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGeoPointFieldPrepareValue(t *testing.T) {
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
f := &core.GeoPointField{}
|
||||
record := core.NewRecord(core.NewBaseCollection("test"))
|
||||
|
||||
scenarios := []struct {
|
||||
raw any
|
||||
expected string
|
||||
}{
|
||||
{nil, `{"lon":0,"lat":0}`},
|
||||
{"", `{"lon":0,"lat":0}`},
|
||||
{[]byte{}, `{"lon":0,"lat":0}`},
|
||||
{map[string]any{}, `{"lon":0,"lat":0}`},
|
||||
{types.GeoPoint{Lon: 10, Lat: 20}, `{"lon":10,"lat":20}`},
|
||||
{&types.GeoPoint{Lon: 10, Lat: 20}, `{"lon":10,"lat":20}`},
|
||||
{[]byte(`{"lon": 10, "lat": 20}`), `{"lon":10,"lat":20}`},
|
||||
{map[string]any{"lon": 10, "lat": 20}, `{"lon":10,"lat":20}`},
|
||||
{map[string]float64{"lon": 10, "lat": 20}, `{"lon":10,"lat":20}`},
|
||||
}
|
||||
|
||||
for i, s := range scenarios {
|
||||
t.Run(fmt.Sprintf("%d_%#v", i, s.raw), func(t *testing.T) {
|
||||
v, err := f.PrepareValue(record, s.raw)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
raw, err := json.Marshal(v)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
rawStr := string(raw)
|
||||
|
||||
if rawStr != s.expected {
|
||||
t.Fatalf("Expected\n%s\ngot\n%s", s.expected, rawStr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGeoPointFieldValidateValue(t *testing.T) {
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
collection := core.NewBaseCollection("test_collection")
|
||||
|
||||
scenarios := []struct {
|
||||
name string
|
||||
field *core.GeoPointField
|
||||
record func() *core.Record
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
"invalid raw value",
|
||||
&core.GeoPointField{Name: "test"},
|
||||
func() *core.Record {
|
||||
record := core.NewRecord(collection)
|
||||
record.SetRaw("test", 123)
|
||||
return record
|
||||
},
|
||||
true,
|
||||
},
|
||||
{
|
||||
"zero field value (non-required)",
|
||||
&core.GeoPointField{Name: "test"},
|
||||
func() *core.Record {
|
||||
record := core.NewRecord(collection)
|
||||
record.SetRaw("test", types.GeoPoint{})
|
||||
return record
|
||||
},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"zero field value (required)",
|
||||
&core.GeoPointField{Name: "test", Required: true},
|
||||
func() *core.Record {
|
||||
record := core.NewRecord(collection)
|
||||
record.SetRaw("test", types.GeoPoint{})
|
||||
return record
|
||||
},
|
||||
true,
|
||||
},
|
||||
{
|
||||
"non-zero Lat field value (required)",
|
||||
&core.GeoPointField{Name: "test", Required: true},
|
||||
func() *core.Record {
|
||||
record := core.NewRecord(collection)
|
||||
record.SetRaw("test", types.GeoPoint{Lat: 1})
|
||||
return record
|
||||
},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"non-zero Lon field value (required)",
|
||||
&core.GeoPointField{Name: "test", Required: true},
|
||||
func() *core.Record {
|
||||
record := core.NewRecord(collection)
|
||||
record.SetRaw("test", types.GeoPoint{Lon: 1})
|
||||
return record
|
||||
},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"non-zero Lat-Lon field value (required)",
|
||||
&core.GeoPointField{Name: "test", Required: true},
|
||||
func() *core.Record {
|
||||
record := core.NewRecord(collection)
|
||||
record.SetRaw("test", types.GeoPoint{Lon: -1, Lat: -2})
|
||||
return record
|
||||
},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"lat < -90",
|
||||
&core.GeoPointField{Name: "test"},
|
||||
func() *core.Record {
|
||||
record := core.NewRecord(collection)
|
||||
record.SetRaw("test", types.GeoPoint{Lat: -90.1})
|
||||
return record
|
||||
},
|
||||
true,
|
||||
},
|
||||
{
|
||||
"lat > 90",
|
||||
&core.GeoPointField{Name: "test"},
|
||||
func() *core.Record {
|
||||
record := core.NewRecord(collection)
|
||||
record.SetRaw("test", types.GeoPoint{Lat: 90.1})
|
||||
return record
|
||||
},
|
||||
true,
|
||||
},
|
||||
{
|
||||
"lon < -180",
|
||||
&core.GeoPointField{Name: "test"},
|
||||
func() *core.Record {
|
||||
record := core.NewRecord(collection)
|
||||
record.SetRaw("test", types.GeoPoint{Lon: -180.1})
|
||||
return record
|
||||
},
|
||||
true,
|
||||
},
|
||||
{
|
||||
"lon > 180",
|
||||
&core.GeoPointField{Name: "test"},
|
||||
func() *core.Record {
|
||||
record := core.NewRecord(collection)
|
||||
record.SetRaw("test", types.GeoPoint{Lon: 180.1})
|
||||
return record
|
||||
},
|
||||
true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
t.Run(s.name, func(t *testing.T) {
|
||||
err := s.field.ValidateValue(context.Background(), app, s.record())
|
||||
|
||||
hasErr := err != nil
|
||||
if hasErr != s.expectError {
|
||||
t.Fatalf("Expected hasErr %v, got %v (%v)", s.expectError, hasErr, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGeoPointFieldValidateSettings(t *testing.T) {
|
||||
testDefaultFieldIdValidation(t, core.FieldTypeGeoPoint)
|
||||
testDefaultFieldNameValidation(t, core.FieldTypeGeoPoint)
|
||||
}
|
195
core/field_json.go
Normal file
195
core/field_json.go
Normal file
|
@ -0,0 +1,195 @@
|
|||
package core
|
||||
|
||||
import (
|
||||
"context"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
validation "github.com/go-ozzo/ozzo-validation/v4"
|
||||
"github.com/go-ozzo/ozzo-validation/v4/is"
|
||||
"github.com/pocketbase/pocketbase/core/validators"
|
||||
"github.com/pocketbase/pocketbase/tools/types"
|
||||
)
|
||||
|
||||
func init() {
|
||||
Fields[FieldTypeJSON] = func() Field {
|
||||
return &JSONField{}
|
||||
}
|
||||
}
|
||||
|
||||
const FieldTypeJSON = "json"
|
||||
|
||||
const DefaultJSONFieldMaxSize int64 = 1 << 20
|
||||
|
||||
var (
|
||||
_ Field = (*JSONField)(nil)
|
||||
_ MaxBodySizeCalculator = (*JSONField)(nil)
|
||||
)
|
||||
|
||||
// JSONField defines "json" type field for storing any serialized JSON value.
|
||||
//
|
||||
// The respective zero record field value is the zero [types.JSONRaw].
|
||||
type JSONField struct {
|
||||
// Name (required) is the unique name of the field.
|
||||
Name string `form:"name" json:"name"`
|
||||
|
||||
// Id is the unique stable field identifier.
|
||||
//
|
||||
// It is automatically generated from the name when adding to a collection FieldsList.
|
||||
Id string `form:"id" json:"id"`
|
||||
|
||||
// System prevents the renaming and removal of the field.
|
||||
System bool `form:"system" json:"system"`
|
||||
|
||||
// Hidden hides the field from the API response.
|
||||
Hidden bool `form:"hidden" json:"hidden"`
|
||||
|
||||
// Presentable hints the Dashboard UI to use the underlying
|
||||
// field record value in the relation preview label.
|
||||
Presentable bool `form:"presentable" json:"presentable"`
|
||||
|
||||
// ---
|
||||
|
||||
// MaxSize specifies the maximum size of the allowed field value (in bytes and up to 2^53-1).
|
||||
//
|
||||
// If zero, a default limit of 1MB is applied.
|
||||
MaxSize int64 `form:"maxSize" json:"maxSize"`
|
||||
|
||||
// Required will require the field value to be non-empty JSON value
|
||||
// (aka. not "null", `""`, "[]", "{}").
|
||||
Required bool `form:"required" json:"required"`
|
||||
}
|
||||
|
||||
// Type implements [Field.Type] interface method.
|
||||
func (f *JSONField) Type() string {
|
||||
return FieldTypeJSON
|
||||
}
|
||||
|
||||
// GetId implements [Field.GetId] interface method.
|
||||
func (f *JSONField) GetId() string {
|
||||
return f.Id
|
||||
}
|
||||
|
||||
// SetId implements [Field.SetId] interface method.
|
||||
func (f *JSONField) SetId(id string) {
|
||||
f.Id = id
|
||||
}
|
||||
|
||||
// GetName implements [Field.GetName] interface method.
|
||||
func (f *JSONField) GetName() string {
|
||||
return f.Name
|
||||
}
|
||||
|
||||
// SetName implements [Field.SetName] interface method.
|
||||
func (f *JSONField) SetName(name string) {
|
||||
f.Name = name
|
||||
}
|
||||
|
||||
// GetSystem implements [Field.GetSystem] interface method.
|
||||
func (f *JSONField) GetSystem() bool {
|
||||
return f.System
|
||||
}
|
||||
|
||||
// SetSystem implements [Field.SetSystem] interface method.
|
||||
func (f *JSONField) SetSystem(system bool) {
|
||||
f.System = system
|
||||
}
|
||||
|
||||
// GetHidden implements [Field.GetHidden] interface method.
|
||||
func (f *JSONField) GetHidden() bool {
|
||||
return f.Hidden
|
||||
}
|
||||
|
||||
// SetHidden implements [Field.SetHidden] interface method.
|
||||
func (f *JSONField) SetHidden(hidden bool) {
|
||||
f.Hidden = hidden
|
||||
}
|
||||
|
||||
// ColumnType implements [Field.ColumnType] interface method.
|
||||
func (f *JSONField) ColumnType(app App) string {
|
||||
return "JSON DEFAULT NULL"
|
||||
}
|
||||
|
||||
// PrepareValue implements [Field.PrepareValue] interface method.
|
||||
func (f *JSONField) PrepareValue(record *Record, raw any) (any, error) {
|
||||
if str, ok := raw.(string); ok {
|
||||
// in order to support seamlessly both json and multipart/form-data requests,
|
||||
// the following normalization rules are applied for plain string values:
|
||||
// - "true" is converted to the json `true`
|
||||
// - "false" is converted to the json `false`
|
||||
// - "null" is converted to the json `null`
|
||||
// - "[1,2,3]" is converted to the json `[1,2,3]`
|
||||
// - "{\"a\":1,\"b\":2}" is converted to the json `{"a":1,"b":2}`
|
||||
// - numeric strings are converted to json number
|
||||
// - double quoted strings are left as they are (aka. without normalizations)
|
||||
// - any other string (empty string too) is double quoted
|
||||
if str == "" {
|
||||
raw = strconv.Quote(str)
|
||||
} else if str == "null" || str == "true" || str == "false" {
|
||||
raw = str
|
||||
} else if ((str[0] >= '0' && str[0] <= '9') ||
|
||||
str[0] == '-' ||
|
||||
str[0] == '"' ||
|
||||
str[0] == '[' ||
|
||||
str[0] == '{') &&
|
||||
is.JSON.Validate(str) == nil {
|
||||
raw = str
|
||||
} else {
|
||||
raw = strconv.Quote(str)
|
||||
}
|
||||
}
|
||||
|
||||
return types.ParseJSONRaw(raw)
|
||||
}
|
||||
|
||||
var emptyJSONValues = []string{
|
||||
"null", `""`, "[]", "{}", "",
|
||||
}
|
||||
|
||||
// ValidateValue implements [Field.ValidateValue] interface method.
|
||||
func (f *JSONField) ValidateValue(ctx context.Context, app App, record *Record) error {
|
||||
raw, ok := record.GetRaw(f.Name).(types.JSONRaw)
|
||||
if !ok {
|
||||
return validators.ErrUnsupportedValueType
|
||||
}
|
||||
|
||||
maxSize := f.CalculateMaxBodySize()
|
||||
|
||||
if int64(len(raw)) > maxSize {
|
||||
return validation.NewError(
|
||||
"validation_json_size_limit",
|
||||
"The maximum allowed JSON size is {{.maxSize}} bytes",
|
||||
).SetParams(map[string]any{"maxSize": maxSize})
|
||||
}
|
||||
|
||||
if is.JSON.Validate(raw) != nil {
|
||||
return validation.NewError("validation_invalid_json", "Must be a valid json value")
|
||||
}
|
||||
|
||||
rawStr := strings.TrimSpace(raw.String())
|
||||
|
||||
if f.Required && slices.Contains(emptyJSONValues, rawStr) {
|
||||
return validation.ErrRequired
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidateSettings implements [Field.ValidateSettings] interface method.
|
||||
func (f *JSONField) ValidateSettings(ctx context.Context, app App, collection *Collection) error {
|
||||
return validation.ValidateStruct(f,
|
||||
validation.Field(&f.Id, validation.By(DefaultFieldIdValidationRule)),
|
||||
validation.Field(&f.Name, validation.By(DefaultFieldNameValidationRule)),
|
||||
validation.Field(&f.MaxSize, validation.Min(0), validation.Max(maxSafeJSONInt)),
|
||||
)
|
||||
}
|
||||
|
||||
// CalculateMaxBodySize implements the [MaxBodySizeCalculator] interface.
|
||||
func (f *JSONField) CalculateMaxBodySize() int64 {
|
||||
if f.MaxSize <= 0 {
|
||||
return DefaultJSONFieldMaxSize
|
||||
}
|
||||
|
||||
return f.MaxSize
|
||||
}
|
277
core/field_json_test.go
Normal file
277
core/field_json_test.go
Normal file
|
@ -0,0 +1,277 @@
|
|||
package core_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tests"
|
||||
"github.com/pocketbase/pocketbase/tools/types"
|
||||
)
|
||||
|
||||
func TestJSONFieldBaseMethods(t *testing.T) {
|
||||
testFieldBaseMethods(t, core.FieldTypeJSON)
|
||||
}
|
||||
|
||||
func TestJSONFieldColumnType(t *testing.T) {
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
f := &core.JSONField{}
|
||||
|
||||
expected := "JSON DEFAULT NULL"
|
||||
|
||||
if v := f.ColumnType(app); v != expected {
|
||||
t.Fatalf("Expected\n%q\ngot\n%q", expected, v)
|
||||
}
|
||||
}
|
||||
|
||||
func TestJSONFieldPrepareValue(t *testing.T) {
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
f := &core.JSONField{}
|
||||
record := core.NewRecord(core.NewBaseCollection("test"))
|
||||
|
||||
scenarios := []struct {
|
||||
raw any
|
||||
expected string
|
||||
}{
|
||||
{"null", `null`},
|
||||
{"", `""`},
|
||||
{"true", `true`},
|
||||
{"false", `false`},
|
||||
{"test", `"test"`},
|
||||
{"123", `123`},
|
||||
{"-456", `-456`},
|
||||
{"[1,2,3]", `[1,2,3]`},
|
||||
{"[1,2,3", `"[1,2,3"`},
|
||||
{`{"a":1,"b":2}`, `{"a":1,"b":2}`},
|
||||
{`{"a":1,"b":2`, `"{\"a\":1,\"b\":2"`},
|
||||
{[]int{1, 2, 3}, `[1,2,3]`},
|
||||
{map[string]int{"a": 1, "b": 2}, `{"a":1,"b":2}`},
|
||||
{nil, `null`},
|
||||
{false, `false`},
|
||||
{true, `true`},
|
||||
{-78, `-78`},
|
||||
{123.456, `123.456`},
|
||||
}
|
||||
|
||||
for i, s := range scenarios {
|
||||
t.Run(fmt.Sprintf("%d_%#v", i, s.raw), func(t *testing.T) {
|
||||
v, err := f.PrepareValue(record, s.raw)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
raw, ok := v.(types.JSONRaw)
|
||||
if !ok {
|
||||
t.Fatalf("Expected string instance, got %T", v)
|
||||
}
|
||||
rawStr := raw.String()
|
||||
|
||||
if rawStr != s.expected {
|
||||
t.Fatalf("Expected\n%#v\ngot\n%#v", s.expected, rawStr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestJSONFieldValidateValue(t *testing.T) {
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
collection := core.NewBaseCollection("test_collection")
|
||||
|
||||
scenarios := []struct {
|
||||
name string
|
||||
field *core.JSONField
|
||||
record func() *core.Record
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
"invalid raw value",
|
||||
&core.JSONField{Name: "test"},
|
||||
func() *core.Record {
|
||||
record := core.NewRecord(collection)
|
||||
record.SetRaw("test", 123)
|
||||
return record
|
||||
},
|
||||
true,
|
||||
},
|
||||
{
|
||||
"zero field value (not required)",
|
||||
&core.JSONField{Name: "test"},
|
||||
func() *core.Record {
|
||||
record := core.NewRecord(collection)
|
||||
record.SetRaw("test", types.JSONRaw{})
|
||||
return record
|
||||
},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"zero field value (required)",
|
||||
&core.JSONField{Name: "test", Required: true},
|
||||
func() *core.Record {
|
||||
record := core.NewRecord(collection)
|
||||
record.SetRaw("test", types.JSONRaw{})
|
||||
return record
|
||||
},
|
||||
true,
|
||||
},
|
||||
{
|
||||
"non-zero field value (required)",
|
||||
&core.JSONField{Name: "test", Required: true},
|
||||
func() *core.Record {
|
||||
record := core.NewRecord(collection)
|
||||
record.SetRaw("test", types.JSONRaw("[1,2,3]"))
|
||||
return record
|
||||
},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"non-zero field value (required)",
|
||||
&core.JSONField{Name: "test", Required: true},
|
||||
func() *core.Record {
|
||||
record := core.NewRecord(collection)
|
||||
record.SetRaw("test", types.JSONRaw(`"aaa"`))
|
||||
return record
|
||||
},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"> default MaxSize",
|
||||
&core.JSONField{Name: "test"},
|
||||
func() *core.Record {
|
||||
record := core.NewRecord(collection)
|
||||
record.SetRaw("test", types.JSONRaw(`"`+strings.Repeat("a", (1<<20))+`"`))
|
||||
return record
|
||||
},
|
||||
true,
|
||||
},
|
||||
{
|
||||
"> MaxSize",
|
||||
&core.JSONField{Name: "test", MaxSize: 5},
|
||||
func() *core.Record {
|
||||
record := core.NewRecord(collection)
|
||||
record.SetRaw("test", types.JSONRaw(`"aaaa"`))
|
||||
return record
|
||||
},
|
||||
true,
|
||||
},
|
||||
{
|
||||
"<= MaxSize",
|
||||
&core.JSONField{Name: "test", MaxSize: 5},
|
||||
func() *core.Record {
|
||||
record := core.NewRecord(collection)
|
||||
record.SetRaw("test", types.JSONRaw(`"aaa"`))
|
||||
return record
|
||||
},
|
||||
false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
t.Run(s.name, func(t *testing.T) {
|
||||
err := s.field.ValidateValue(context.Background(), app, s.record())
|
||||
|
||||
hasErr := err != nil
|
||||
if hasErr != s.expectError {
|
||||
t.Fatalf("Expected hasErr %v, got %v (%v)", s.expectError, hasErr, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestJSONFieldValidateSettings(t *testing.T) {
|
||||
testDefaultFieldIdValidation(t, core.FieldTypeJSON)
|
||||
testDefaultFieldNameValidation(t, core.FieldTypeJSON)
|
||||
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
collection := core.NewBaseCollection("test_collection")
|
||||
|
||||
scenarios := []struct {
|
||||
name string
|
||||
field func() *core.JSONField
|
||||
expectErrors []string
|
||||
}{
|
||||
{
|
||||
"MaxSize < 0",
|
||||
func() *core.JSONField {
|
||||
return &core.JSONField{
|
||||
Id: "test",
|
||||
Name: "test",
|
||||
MaxSize: -1,
|
||||
}
|
||||
},
|
||||
[]string{"maxSize"},
|
||||
},
|
||||
{
|
||||
"MaxSize = 0",
|
||||
func() *core.JSONField {
|
||||
return &core.JSONField{
|
||||
Id: "test",
|
||||
Name: "test",
|
||||
}
|
||||
},
|
||||
[]string{},
|
||||
},
|
||||
{
|
||||
"MaxSize > 0",
|
||||
func() *core.JSONField {
|
||||
return &core.JSONField{
|
||||
Id: "test",
|
||||
Name: "test",
|
||||
MaxSize: 1,
|
||||
}
|
||||
},
|
||||
[]string{},
|
||||
},
|
||||
{
|
||||
"MaxSize > safe json int",
|
||||
func() *core.JSONField {
|
||||
return &core.JSONField{
|
||||
Id: "test",
|
||||
Name: "test",
|
||||
MaxSize: 1 << 53,
|
||||
}
|
||||
},
|
||||
[]string{"maxSize"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
t.Run(s.name, func(t *testing.T) {
|
||||
errs := s.field().ValidateSettings(context.Background(), app, collection)
|
||||
|
||||
tests.TestValidationErrors(t, errs, s.expectErrors)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestJSONFieldCalculateMaxBodySize(t *testing.T) {
|
||||
testApp, _ := tests.NewTestApp()
|
||||
defer testApp.Cleanup()
|
||||
|
||||
scenarios := []struct {
|
||||
field *core.JSONField
|
||||
expected int64
|
||||
}{
|
||||
{&core.JSONField{}, core.DefaultJSONFieldMaxSize},
|
||||
{&core.JSONField{MaxSize: 10}, 10},
|
||||
}
|
||||
|
||||
for i, s := range scenarios {
|
||||
t.Run(fmt.Sprintf("%d_%d", i, s.field.MaxSize), func(t *testing.T) {
|
||||
result := s.field.CalculateMaxBodySize()
|
||||
|
||||
if result != s.expected {
|
||||
t.Fatalf("Expected %d, got %d", s.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
222
core/field_number.go
Normal file
222
core/field_number.go
Normal file
|
@ -0,0 +1,222 @@
|
|||
package core
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"math"
|
||||
|
||||
validation "github.com/go-ozzo/ozzo-validation/v4"
|
||||
"github.com/pocketbase/pocketbase/core/validators"
|
||||
"github.com/spf13/cast"
|
||||
)
|
||||
|
||||
func init() {
|
||||
Fields[FieldTypeNumber] = func() Field {
|
||||
return &NumberField{}
|
||||
}
|
||||
}
|
||||
|
||||
const FieldTypeNumber = "number"
|
||||
|
||||
var (
|
||||
_ Field = (*NumberField)(nil)
|
||||
_ SetterFinder = (*NumberField)(nil)
|
||||
)
|
||||
|
||||
// NumberField defines "number" type field for storing numeric (float64) value.
|
||||
//
|
||||
// The respective zero record field value is 0.
|
||||
//
|
||||
// The following additional setter keys are available:
|
||||
//
|
||||
// - "fieldName+" - appends to the existing record value. For example:
|
||||
// record.Set("total+", 5)
|
||||
// - "fieldName-" - subtracts from the existing record value. For example:
|
||||
// record.Set("total-", 5)
|
||||
type NumberField struct {
|
||||
// Name (required) is the unique name of the field.
|
||||
Name string `form:"name" json:"name"`
|
||||
|
||||
// Id is the unique stable field identifier.
|
||||
//
|
||||
// It is automatically generated from the name when adding to a collection FieldsList.
|
||||
Id string `form:"id" json:"id"`
|
||||
|
||||
// System prevents the renaming and removal of the field.
|
||||
System bool `form:"system" json:"system"`
|
||||
|
||||
// Hidden hides the field from the API response.
|
||||
Hidden bool `form:"hidden" json:"hidden"`
|
||||
|
||||
// Presentable hints the Dashboard UI to use the underlying
|
||||
// field record value in the relation preview label.
|
||||
Presentable bool `form:"presentable" json:"presentable"`
|
||||
|
||||
// ---
|
||||
|
||||
// Min specifies the min allowed field value.
|
||||
//
|
||||
// Leave it nil to skip the validator.
|
||||
Min *float64 `form:"min" json:"min"`
|
||||
|
||||
// Max specifies the max allowed field value.
|
||||
//
|
||||
// Leave it nil to skip the validator.
|
||||
Max *float64 `form:"max" json:"max"`
|
||||
|
||||
// OnlyInt will require the field value to be integer.
|
||||
OnlyInt bool `form:"onlyInt" json:"onlyInt"`
|
||||
|
||||
// Required will require the field value to be non-zero.
|
||||
Required bool `form:"required" json:"required"`
|
||||
}
|
||||
|
||||
// Type implements [Field.Type] interface method.
|
||||
func (f *NumberField) Type() string {
|
||||
return FieldTypeNumber
|
||||
}
|
||||
|
||||
// GetId implements [Field.GetId] interface method.
|
||||
func (f *NumberField) GetId() string {
|
||||
return f.Id
|
||||
}
|
||||
|
||||
// SetId implements [Field.SetId] interface method.
|
||||
func (f *NumberField) SetId(id string) {
|
||||
f.Id = id
|
||||
}
|
||||
|
||||
// GetName implements [Field.GetName] interface method.
|
||||
func (f *NumberField) GetName() string {
|
||||
return f.Name
|
||||
}
|
||||
|
||||
// SetName implements [Field.SetName] interface method.
|
||||
func (f *NumberField) SetName(name string) {
|
||||
f.Name = name
|
||||
}
|
||||
|
||||
// GetSystem implements [Field.GetSystem] interface method.
|
||||
func (f *NumberField) GetSystem() bool {
|
||||
return f.System
|
||||
}
|
||||
|
||||
// SetSystem implements [Field.SetSystem] interface method.
|
||||
func (f *NumberField) SetSystem(system bool) {
|
||||
f.System = system
|
||||
}
|
||||
|
||||
// GetHidden implements [Field.GetHidden] interface method.
|
||||
func (f *NumberField) GetHidden() bool {
|
||||
return f.Hidden
|
||||
}
|
||||
|
||||
// SetHidden implements [Field.SetHidden] interface method.
|
||||
func (f *NumberField) SetHidden(hidden bool) {
|
||||
f.Hidden = hidden
|
||||
}
|
||||
|
||||
// ColumnType implements [Field.ColumnType] interface method.
|
||||
func (f *NumberField) ColumnType(app App) string {
|
||||
return "NUMERIC DEFAULT 0 NOT NULL"
|
||||
}
|
||||
|
||||
// PrepareValue implements [Field.PrepareValue] interface method.
|
||||
func (f *NumberField) PrepareValue(record *Record, raw any) (any, error) {
|
||||
return cast.ToFloat64(raw), nil
|
||||
}
|
||||
|
||||
// ValidateValue implements [Field.ValidateValue] interface method.
|
||||
func (f *NumberField) ValidateValue(ctx context.Context, app App, record *Record) error {
|
||||
val, ok := record.GetRaw(f.Name).(float64)
|
||||
if !ok {
|
||||
return validators.ErrUnsupportedValueType
|
||||
}
|
||||
|
||||
if math.IsInf(val, 0) || math.IsNaN(val) {
|
||||
return validation.NewError("validation_not_a_number", "The submitted number is not properly formatted")
|
||||
}
|
||||
|
||||
if val == 0 {
|
||||
if f.Required {
|
||||
if err := validation.Required.Validate(val); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
if f.OnlyInt && val != float64(int64(val)) {
|
||||
return validation.NewError("validation_only_int_constraint", "Decimal numbers are not allowed")
|
||||
}
|
||||
|
||||
if f.Min != nil && val < *f.Min {
|
||||
return validation.NewError("validation_min_number_constraint", fmt.Sprintf("Must be larger than %f", *f.Min))
|
||||
}
|
||||
|
||||
if f.Max != nil && val > *f.Max {
|
||||
return validation.NewError("validation_max_number_constraint", fmt.Sprintf("Must be less than %f", *f.Max))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidateSettings implements [Field.ValidateSettings] interface method.
|
||||
func (f *NumberField) ValidateSettings(ctx context.Context, app App, collection *Collection) error {
|
||||
maxRules := []validation.Rule{
|
||||
validation.By(f.checkOnlyInt),
|
||||
}
|
||||
if f.Min != nil && f.Max != nil {
|
||||
maxRules = append(maxRules, validation.Min(*f.Min))
|
||||
}
|
||||
|
||||
return validation.ValidateStruct(f,
|
||||
validation.Field(&f.Id, validation.By(DefaultFieldIdValidationRule)),
|
||||
validation.Field(&f.Name, validation.By(DefaultFieldNameValidationRule)),
|
||||
validation.Field(&f.Min, validation.By(f.checkOnlyInt)),
|
||||
validation.Field(&f.Max, maxRules...),
|
||||
)
|
||||
}
|
||||
|
||||
func (f *NumberField) checkOnlyInt(value any) error {
|
||||
v, _ := value.(*float64)
|
||||
if v == nil || !f.OnlyInt {
|
||||
return nil // nothing to check
|
||||
}
|
||||
|
||||
if *v != float64(int64(*v)) {
|
||||
return validation.NewError("validation_only_int_constraint", "Decimal numbers are not allowed.")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// FindSetter implements the [SetterFinder] interface.
|
||||
func (f *NumberField) FindSetter(key string) SetterFunc {
|
||||
switch key {
|
||||
case f.Name:
|
||||
return f.setValue
|
||||
case f.Name + "+":
|
||||
return f.addValue
|
||||
case f.Name + "-":
|
||||
return f.subtractValue
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (f *NumberField) setValue(record *Record, raw any) {
|
||||
record.SetRaw(f.Name, cast.ToFloat64(raw))
|
||||
}
|
||||
|
||||
func (f *NumberField) addValue(record *Record, raw any) {
|
||||
val := cast.ToFloat64(record.GetRaw(f.Name))
|
||||
|
||||
record.SetRaw(f.Name, val+cast.ToFloat64(raw))
|
||||
}
|
||||
|
||||
func (f *NumberField) subtractValue(record *Record, raw any) {
|
||||
val := cast.ToFloat64(record.GetRaw(f.Name))
|
||||
|
||||
record.SetRaw(f.Name, val-cast.ToFloat64(raw))
|
||||
}
|
403
core/field_number_test.go
Normal file
403
core/field_number_test.go
Normal file
|
@ -0,0 +1,403 @@
|
|||
package core_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tests"
|
||||
"github.com/pocketbase/pocketbase/tools/types"
|
||||
)
|
||||
|
||||
func TestNumberFieldBaseMethods(t *testing.T) {
|
||||
testFieldBaseMethods(t, core.FieldTypeNumber)
|
||||
}
|
||||
|
||||
func TestNumberFieldColumnType(t *testing.T) {
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
f := &core.NumberField{}
|
||||
|
||||
expected := "NUMERIC DEFAULT 0 NOT NULL"
|
||||
|
||||
if v := f.ColumnType(app); v != expected {
|
||||
t.Fatalf("Expected\n%q\ngot\n%q", expected, v)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNumberFieldPrepareValue(t *testing.T) {
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
f := &core.NumberField{}
|
||||
record := core.NewRecord(core.NewBaseCollection("test"))
|
||||
|
||||
scenarios := []struct {
|
||||
raw any
|
||||
expected float64
|
||||
}{
|
||||
{"", 0},
|
||||
{"test", 0},
|
||||
{false, 0},
|
||||
{true, 1},
|
||||
{-2, -2},
|
||||
{123.456, 123.456},
|
||||
}
|
||||
|
||||
for i, s := range scenarios {
|
||||
t.Run(fmt.Sprintf("%d_%#v", i, s.raw), func(t *testing.T) {
|
||||
vRaw, err := f.PrepareValue(record, s.raw)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
v, ok := vRaw.(float64)
|
||||
if !ok {
|
||||
t.Fatalf("Expected float64 instance, got %T", v)
|
||||
}
|
||||
|
||||
if v != s.expected {
|
||||
t.Fatalf("Expected %f, got %f", s.expected, v)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNumberFieldValidateValue(t *testing.T) {
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
collection := core.NewBaseCollection("test_collection")
|
||||
|
||||
scenarios := []struct {
|
||||
name string
|
||||
field *core.NumberField
|
||||
record func() *core.Record
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
"invalid raw value",
|
||||
&core.NumberField{Name: "test"},
|
||||
func() *core.Record {
|
||||
record := core.NewRecord(collection)
|
||||
record.SetRaw("test", "123")
|
||||
return record
|
||||
},
|
||||
true,
|
||||
},
|
||||
{
|
||||
"zero field value (not required)",
|
||||
&core.NumberField{Name: "test"},
|
||||
func() *core.Record {
|
||||
record := core.NewRecord(collection)
|
||||
record.SetRaw("test", 0.0)
|
||||
return record
|
||||
},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"zero field value (required)",
|
||||
&core.NumberField{Name: "test", Required: true},
|
||||
func() *core.Record {
|
||||
record := core.NewRecord(collection)
|
||||
record.SetRaw("test", 0.0)
|
||||
return record
|
||||
},
|
||||
true,
|
||||
},
|
||||
{
|
||||
"non-zero field value (required)",
|
||||
&core.NumberField{Name: "test", Required: true},
|
||||
func() *core.Record {
|
||||
record := core.NewRecord(collection)
|
||||
record.SetRaw("test", 123.0)
|
||||
return record
|
||||
},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"decimal with onlyInt",
|
||||
&core.NumberField{Name: "test", OnlyInt: true},
|
||||
func() *core.Record {
|
||||
record := core.NewRecord(collection)
|
||||
record.SetRaw("test", 123.456)
|
||||
return record
|
||||
},
|
||||
true,
|
||||
},
|
||||
{
|
||||
"int with onlyInt",
|
||||
&core.NumberField{Name: "test", OnlyInt: true},
|
||||
func() *core.Record {
|
||||
record := core.NewRecord(collection)
|
||||
record.SetRaw("test", 123.0)
|
||||
return record
|
||||
},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"< min",
|
||||
&core.NumberField{Name: "test", Min: types.Pointer(2.0)},
|
||||
func() *core.Record {
|
||||
record := core.NewRecord(collection)
|
||||
record.SetRaw("test", 1.0)
|
||||
return record
|
||||
},
|
||||
true,
|
||||
},
|
||||
{
|
||||
">= min",
|
||||
&core.NumberField{Name: "test", Min: types.Pointer(2.0)},
|
||||
func() *core.Record {
|
||||
record := core.NewRecord(collection)
|
||||
record.SetRaw("test", 2.0)
|
||||
return record
|
||||
},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"> max",
|
||||
&core.NumberField{Name: "test", Max: types.Pointer(2.0)},
|
||||
func() *core.Record {
|
||||
record := core.NewRecord(collection)
|
||||
record.SetRaw("test", 3.0)
|
||||
return record
|
||||
},
|
||||
true,
|
||||
},
|
||||
{
|
||||
"<= max",
|
||||
&core.NumberField{Name: "test", Max: types.Pointer(2.0)},
|
||||
func() *core.Record {
|
||||
record := core.NewRecord(collection)
|
||||
record.SetRaw("test", 2.0)
|
||||
return record
|
||||
},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"infinity",
|
||||
&core.NumberField{Name: "test"},
|
||||
func() *core.Record {
|
||||
record := core.NewRecord(collection)
|
||||
record.Set("test", "Inf")
|
||||
return record
|
||||
},
|
||||
true,
|
||||
},
|
||||
{
|
||||
"NaN",
|
||||
&core.NumberField{Name: "test"},
|
||||
func() *core.Record {
|
||||
record := core.NewRecord(collection)
|
||||
record.Set("test", "NaN")
|
||||
return record
|
||||
},
|
||||
true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
t.Run(s.name, func(t *testing.T) {
|
||||
err := s.field.ValidateValue(context.Background(), app, s.record())
|
||||
|
||||
hasErr := err != nil
|
||||
if hasErr != s.expectError {
|
||||
t.Fatalf("Expected hasErr %v, got %v (%v)", s.expectError, hasErr, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNumberFieldValidateSettings(t *testing.T) {
|
||||
testDefaultFieldIdValidation(t, core.FieldTypeNumber)
|
||||
testDefaultFieldNameValidation(t, core.FieldTypeNumber)
|
||||
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
collection := core.NewBaseCollection("test_collection")
|
||||
|
||||
scenarios := []struct {
|
||||
name string
|
||||
field func() *core.NumberField
|
||||
expectErrors []string
|
||||
}{
|
||||
{
|
||||
"zero",
|
||||
func() *core.NumberField {
|
||||
return &core.NumberField{
|
||||
Id: "test",
|
||||
Name: "test",
|
||||
}
|
||||
},
|
||||
[]string{},
|
||||
},
|
||||
{
|
||||
"decumal min",
|
||||
func() *core.NumberField {
|
||||
return &core.NumberField{
|
||||
Id: "test",
|
||||
Name: "test",
|
||||
Min: types.Pointer(1.2),
|
||||
}
|
||||
},
|
||||
[]string{},
|
||||
},
|
||||
{
|
||||
"decumal min (onlyInt)",
|
||||
func() *core.NumberField {
|
||||
return &core.NumberField{
|
||||
Id: "test",
|
||||
Name: "test",
|
||||
OnlyInt: true,
|
||||
Min: types.Pointer(1.2),
|
||||
}
|
||||
},
|
||||
[]string{"min"},
|
||||
},
|
||||
{
|
||||
"int min (onlyInt)",
|
||||
func() *core.NumberField {
|
||||
return &core.NumberField{
|
||||
Id: "test",
|
||||
Name: "test",
|
||||
OnlyInt: true,
|
||||
Min: types.Pointer(1.0),
|
||||
}
|
||||
},
|
||||
[]string{},
|
||||
},
|
||||
{
|
||||
"decumal max",
|
||||
func() *core.NumberField {
|
||||
return &core.NumberField{
|
||||
Id: "test",
|
||||
Name: "test",
|
||||
Max: types.Pointer(1.2),
|
||||
}
|
||||
},
|
||||
[]string{},
|
||||
},
|
||||
{
|
||||
"decumal max (onlyInt)",
|
||||
func() *core.NumberField {
|
||||
return &core.NumberField{
|
||||
Id: "test",
|
||||
Name: "test",
|
||||
OnlyInt: true,
|
||||
Max: types.Pointer(1.2),
|
||||
}
|
||||
},
|
||||
[]string{"max"},
|
||||
},
|
||||
{
|
||||
"int max (onlyInt)",
|
||||
func() *core.NumberField {
|
||||
return &core.NumberField{
|
||||
Id: "test",
|
||||
Name: "test",
|
||||
OnlyInt: true,
|
||||
Max: types.Pointer(1.0),
|
||||
}
|
||||
},
|
||||
[]string{},
|
||||
},
|
||||
{
|
||||
"min > max",
|
||||
func() *core.NumberField {
|
||||
return &core.NumberField{
|
||||
Id: "test",
|
||||
Name: "test",
|
||||
Min: types.Pointer(2.0),
|
||||
Max: types.Pointer(1.0),
|
||||
}
|
||||
},
|
||||
[]string{"max"},
|
||||
},
|
||||
{
|
||||
"min <= max",
|
||||
func() *core.NumberField {
|
||||
return &core.NumberField{
|
||||
Id: "test",
|
||||
Name: "test",
|
||||
Min: types.Pointer(2.0),
|
||||
Max: types.Pointer(2.0),
|
||||
}
|
||||
},
|
||||
[]string{},
|
||||
},
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
t.Run(s.name, func(t *testing.T) {
|
||||
errs := s.field().ValidateSettings(context.Background(), app, collection)
|
||||
|
||||
tests.TestValidationErrors(t, errs, s.expectErrors)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNumberFieldFindSetter(t *testing.T) {
|
||||
field := &core.NumberField{Name: "test"}
|
||||
|
||||
collection := core.NewBaseCollection("test_collection")
|
||||
collection.Fields.Add(field)
|
||||
|
||||
t.Run("no match", func(t *testing.T) {
|
||||
f := field.FindSetter("abc")
|
||||
if f != nil {
|
||||
t.Fatal("Expected nil setter")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("direct name match", func(t *testing.T) {
|
||||
f := field.FindSetter("test")
|
||||
if f == nil {
|
||||
t.Fatal("Expected non-nil setter")
|
||||
}
|
||||
|
||||
record := core.NewRecord(collection)
|
||||
record.SetRaw("test", 2.0)
|
||||
|
||||
f(record, "123.456") // should be casted
|
||||
|
||||
if v := record.Get("test"); v != 123.456 {
|
||||
t.Fatalf("Expected %f, got %f", 123.456, v)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("name+ match", func(t *testing.T) {
|
||||
f := field.FindSetter("test+")
|
||||
if f == nil {
|
||||
t.Fatal("Expected non-nil setter")
|
||||
}
|
||||
|
||||
record := core.NewRecord(collection)
|
||||
record.SetRaw("test", 2.0)
|
||||
|
||||
f(record, "1.5") // should be casted and appended to the existing value
|
||||
|
||||
if v := record.Get("test"); v != 3.5 {
|
||||
t.Fatalf("Expected %f, got %f", 3.5, v)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("name- match", func(t *testing.T) {
|
||||
f := field.FindSetter("test-")
|
||||
if f == nil {
|
||||
t.Fatal("Expected non-nil setter")
|
||||
}
|
||||
|
||||
record := core.NewRecord(collection)
|
||||
record.SetRaw("test", 2.0)
|
||||
|
||||
f(record, "1.5") // should be casted and subtracted from the existing value
|
||||
|
||||
if v := record.Get("test"); v != 0.5 {
|
||||
t.Fatalf("Expected %f, got %f", 0.5, v)
|
||||
}
|
||||
})
|
||||
}
|
318
core/field_password.go
Normal file
318
core/field_password.go
Normal file
|
@ -0,0 +1,318 @@
|
|||
package core
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql/driver"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
validation "github.com/go-ozzo/ozzo-validation/v4"
|
||||
"github.com/pocketbase/pocketbase/core/validators"
|
||||
"github.com/spf13/cast"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
func init() {
|
||||
Fields[FieldTypePassword] = func() Field {
|
||||
return &PasswordField{}
|
||||
}
|
||||
}
|
||||
|
||||
const FieldTypePassword = "password"
|
||||
|
||||
var (
|
||||
_ Field = (*PasswordField)(nil)
|
||||
_ GetterFinder = (*PasswordField)(nil)
|
||||
_ SetterFinder = (*PasswordField)(nil)
|
||||
_ DriverValuer = (*PasswordField)(nil)
|
||||
_ RecordInterceptor = (*PasswordField)(nil)
|
||||
)
|
||||
|
||||
// PasswordField defines "password" type field for storing bcrypt hashed strings
|
||||
// (usually used only internally for the "password" auth collection system field).
|
||||
//
|
||||
// If you want to set a direct bcrypt hash as record field value you can use the SetRaw method, for example:
|
||||
//
|
||||
// // generates a bcrypt hash of "123456" and set it as field value
|
||||
// // (record.GetString("password") returns the plain password until persisted, otherwise empty string)
|
||||
// record.Set("password", "123456")
|
||||
//
|
||||
// // set directly a bcrypt hash of "123456" as field value
|
||||
// // (record.GetString("password") returns empty string)
|
||||
// record.SetRaw("password", "$2a$10$.5Elh8fgxypNUWhpUUr/xOa2sZm0VIaE0qWuGGl9otUfobb46T1Pq")
|
||||
//
|
||||
// The following additional getter keys are available:
|
||||
//
|
||||
// - "fieldName:hash" - returns the bcrypt hash string of the record field value (if any). For example:
|
||||
// record.GetString("password:hash")
|
||||
type PasswordField struct {
|
||||
// Name (required) is the unique name of the field.
|
||||
Name string `form:"name" json:"name"`
|
||||
|
||||
// Id is the unique stable field identifier.
|
||||
//
|
||||
// It is automatically generated from the name when adding to a collection FieldsList.
|
||||
Id string `form:"id" json:"id"`
|
||||
|
||||
// System prevents the renaming and removal of the field.
|
||||
System bool `form:"system" json:"system"`
|
||||
|
||||
// Hidden hides the field from the API response.
|
||||
Hidden bool `form:"hidden" json:"hidden"`
|
||||
|
||||
// Presentable hints the Dashboard UI to use the underlying
|
||||
// field record value in the relation preview label.
|
||||
Presentable bool `form:"presentable" json:"presentable"`
|
||||
|
||||
// ---
|
||||
|
||||
// Pattern specifies an optional regex pattern to match against the field value.
|
||||
//
|
||||
// Leave it empty to skip the pattern check.
|
||||
Pattern string `form:"pattern" json:"pattern"`
|
||||
|
||||
// Min specifies an optional required field string length.
|
||||
Min int `form:"min" json:"min"`
|
||||
|
||||
// Max specifies an optional required field string length.
|
||||
//
|
||||
// If zero, fallback to max 71 bytes.
|
||||
Max int `form:"max" json:"max"`
|
||||
|
||||
// Cost specifies the cost/weight/iteration/etc. bcrypt factor.
|
||||
//
|
||||
// If zero, fallback to [bcrypt.DefaultCost].
|
||||
//
|
||||
// If explicitly set, must be between [bcrypt.MinCost] and [bcrypt.MaxCost].
|
||||
Cost int `form:"cost" json:"cost"`
|
||||
|
||||
// Required will require the field value to be non-empty string.
|
||||
Required bool `form:"required" json:"required"`
|
||||
}
|
||||
|
||||
// Type implements [Field.Type] interface method.
|
||||
func (f *PasswordField) Type() string {
|
||||
return FieldTypePassword
|
||||
}
|
||||
|
||||
// GetId implements [Field.GetId] interface method.
|
||||
func (f *PasswordField) GetId() string {
|
||||
return f.Id
|
||||
}
|
||||
|
||||
// SetId implements [Field.SetId] interface method.
|
||||
func (f *PasswordField) SetId(id string) {
|
||||
f.Id = id
|
||||
}
|
||||
|
||||
// GetName implements [Field.GetName] interface method.
|
||||
func (f *PasswordField) GetName() string {
|
||||
return f.Name
|
||||
}
|
||||
|
||||
// SetName implements [Field.SetName] interface method.
|
||||
func (f *PasswordField) SetName(name string) {
|
||||
f.Name = name
|
||||
}
|
||||
|
||||
// GetSystem implements [Field.GetSystem] interface method.
|
||||
func (f *PasswordField) GetSystem() bool {
|
||||
return f.System
|
||||
}
|
||||
|
||||
// SetSystem implements [Field.SetSystem] interface method.
|
||||
func (f *PasswordField) SetSystem(system bool) {
|
||||
f.System = system
|
||||
}
|
||||
|
||||
// GetHidden implements [Field.GetHidden] interface method.
|
||||
func (f *PasswordField) GetHidden() bool {
|
||||
return f.Hidden
|
||||
}
|
||||
|
||||
// SetHidden implements [Field.SetHidden] interface method.
|
||||
func (f *PasswordField) SetHidden(hidden bool) {
|
||||
f.Hidden = hidden
|
||||
}
|
||||
|
||||
// ColumnType implements [Field.ColumnType] interface method.
|
||||
func (f *PasswordField) ColumnType(app App) string {
|
||||
return "TEXT DEFAULT '' NOT NULL"
|
||||
}
|
||||
|
||||
// DriverValue implements the [DriverValuer] interface.
|
||||
func (f *PasswordField) DriverValue(record *Record) (driver.Value, error) {
|
||||
fp := f.getPasswordValue(record)
|
||||
return fp.Hash, fp.LastError
|
||||
}
|
||||
|
||||
// PrepareValue implements [Field.PrepareValue] interface method.
|
||||
func (f *PasswordField) PrepareValue(record *Record, raw any) (any, error) {
|
||||
return &PasswordFieldValue{
|
||||
Hash: cast.ToString(raw),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ValidateValue implements [Field.ValidateValue] interface method.
|
||||
func (f *PasswordField) ValidateValue(ctx context.Context, app App, record *Record) error {
|
||||
fp, ok := record.GetRaw(f.Name).(*PasswordFieldValue)
|
||||
if !ok {
|
||||
return validators.ErrUnsupportedValueType
|
||||
}
|
||||
|
||||
if fp.LastError != nil {
|
||||
return fp.LastError
|
||||
}
|
||||
|
||||
if f.Required {
|
||||
if err := validation.Required.Validate(fp.Hash); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if fp.Plain == "" {
|
||||
return nil // nothing to check
|
||||
}
|
||||
|
||||
// note: casted to []rune to count multi-byte chars as one for the
|
||||
// sake of more intuitive UX and clearer user error messages
|
||||
//
|
||||
// note2: technically multi-byte strings could produce bigger length than the bcrypt limit
|
||||
// but it should be fine as it will be just truncated (even if it cuts a byte sequence in the middle)
|
||||
length := len([]rune(fp.Plain))
|
||||
|
||||
if length < f.Min {
|
||||
return validation.NewError("validation_min_text_constraint", fmt.Sprintf("Must be at least %d character(s)", f.Min))
|
||||
}
|
||||
|
||||
maxLength := f.Max
|
||||
if maxLength <= 0 {
|
||||
maxLength = 71
|
||||
}
|
||||
if length > maxLength {
|
||||
return validation.NewError("validation_max_text_constraint", fmt.Sprintf("Must be less than %d character(s)", maxLength))
|
||||
}
|
||||
|
||||
if f.Pattern != "" {
|
||||
match, _ := regexp.MatchString(f.Pattern, fp.Plain)
|
||||
if !match {
|
||||
return validation.NewError("validation_invalid_format", "Invalid value format")
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidateSettings implements [Field.ValidateSettings] interface method.
|
||||
func (f *PasswordField) ValidateSettings(ctx context.Context, app App, collection *Collection) error {
|
||||
return validation.ValidateStruct(f,
|
||||
validation.Field(&f.Id, validation.By(DefaultFieldIdValidationRule)),
|
||||
validation.Field(&f.Name, validation.By(DefaultFieldNameValidationRule)),
|
||||
validation.Field(&f.Min, validation.Min(1), validation.Max(71)),
|
||||
validation.Field(&f.Max, validation.Min(f.Min), validation.Max(71)),
|
||||
validation.Field(&f.Cost, validation.Min(bcrypt.MinCost), validation.Max(bcrypt.MaxCost)),
|
||||
validation.Field(&f.Pattern, validation.By(validators.IsRegex)),
|
||||
)
|
||||
}
|
||||
|
||||
func (f *PasswordField) getPasswordValue(record *Record) *PasswordFieldValue {
|
||||
raw := record.GetRaw(f.Name)
|
||||
|
||||
switch v := raw.(type) {
|
||||
case *PasswordFieldValue:
|
||||
return v
|
||||
case string:
|
||||
// we assume that any raw string starting with $2 is bcrypt hash
|
||||
if strings.HasPrefix(v, "$2") {
|
||||
return &PasswordFieldValue{Hash: v}
|
||||
}
|
||||
}
|
||||
|
||||
return &PasswordFieldValue{}
|
||||
}
|
||||
|
||||
// Intercept implements the [RecordInterceptor] interface.
|
||||
func (f *PasswordField) Intercept(
|
||||
ctx context.Context,
|
||||
app App,
|
||||
record *Record,
|
||||
actionName string,
|
||||
actionFunc func() error,
|
||||
) error {
|
||||
switch actionName {
|
||||
case InterceptorActionAfterCreate, InterceptorActionAfterUpdate:
|
||||
// unset the plain field value after successful create/update
|
||||
fp := f.getPasswordValue(record)
|
||||
fp.Plain = ""
|
||||
}
|
||||
|
||||
return actionFunc()
|
||||
}
|
||||
|
||||
// FindGetter implements the [GetterFinder] interface.
|
||||
func (f *PasswordField) FindGetter(key string) GetterFunc {
|
||||
switch key {
|
||||
case f.Name:
|
||||
return func(record *Record) any {
|
||||
return f.getPasswordValue(record).Plain
|
||||
}
|
||||
case f.Name + ":hash":
|
||||
return func(record *Record) any {
|
||||
return f.getPasswordValue(record).Hash
|
||||
}
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// FindSetter implements the [SetterFinder] interface.
|
||||
func (f *PasswordField) FindSetter(key string) SetterFunc {
|
||||
switch key {
|
||||
case f.Name:
|
||||
return f.setValue
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (f *PasswordField) setValue(record *Record, raw any) {
|
||||
fv := &PasswordFieldValue{
|
||||
Plain: cast.ToString(raw),
|
||||
}
|
||||
|
||||
// hash the password
|
||||
if fv.Plain != "" {
|
||||
cost := f.Cost
|
||||
if cost <= 0 {
|
||||
cost = bcrypt.DefaultCost
|
||||
}
|
||||
|
||||
hash, err := bcrypt.GenerateFromPassword([]byte(fv.Plain), cost)
|
||||
if err != nil {
|
||||
fv.LastError = err
|
||||
}
|
||||
|
||||
fv.Hash = string(hash)
|
||||
}
|
||||
|
||||
record.SetRaw(f.Name, fv)
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
|
||||
type PasswordFieldValue struct {
|
||||
LastError error
|
||||
Hash string
|
||||
Plain string
|
||||
}
|
||||
|
||||
func (pv PasswordFieldValue) Validate(pass string) bool {
|
||||
if pv.Hash == "" || pv.LastError != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
err := bcrypt.CompareHashAndPassword([]byte(pv.Hash), []byte(pass))
|
||||
|
||||
return err == nil
|
||||
}
|
568
core/field_password_test.go
Normal file
568
core/field_password_test.go
Normal file
|
@ -0,0 +1,568 @@
|
|||
package core_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tests"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
func TestPasswordFieldBaseMethods(t *testing.T) {
|
||||
testFieldBaseMethods(t, core.FieldTypePassword)
|
||||
}
|
||||
|
||||
func TestPasswordFieldColumnType(t *testing.T) {
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
f := &core.PasswordField{}
|
||||
|
||||
expected := "TEXT DEFAULT '' NOT NULL"
|
||||
|
||||
if v := f.ColumnType(app); v != expected {
|
||||
t.Fatalf("Expected\n%q\ngot\n%q", expected, v)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPasswordFieldPrepareValue(t *testing.T) {
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
f := &core.PasswordField{}
|
||||
record := core.NewRecord(core.NewBaseCollection("test"))
|
||||
|
||||
scenarios := []struct {
|
||||
raw any
|
||||
expected string
|
||||
}{
|
||||
{"", ""},
|
||||
{"test", "test"},
|
||||
{false, "false"},
|
||||
{true, "true"},
|
||||
{123.456, "123.456"},
|
||||
}
|
||||
|
||||
for i, s := range scenarios {
|
||||
t.Run(fmt.Sprintf("%d_%#v", i, s.raw), func(t *testing.T) {
|
||||
v, err := f.PrepareValue(record, s.raw)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
pv, ok := v.(*core.PasswordFieldValue)
|
||||
if !ok {
|
||||
t.Fatalf("Expected PasswordFieldValue instance, got %T", v)
|
||||
}
|
||||
|
||||
if pv.Hash != s.expected {
|
||||
t.Fatalf("Expected %q, got %q", s.expected, v)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPasswordFieldDriverValue(t *testing.T) {
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
f := &core.PasswordField{Name: "test"}
|
||||
|
||||
err := errors.New("example_err")
|
||||
|
||||
scenarios := []struct {
|
||||
raw any
|
||||
expected *core.PasswordFieldValue
|
||||
}{
|
||||
{123, &core.PasswordFieldValue{}},
|
||||
{"abc", &core.PasswordFieldValue{}},
|
||||
{"$2abc", &core.PasswordFieldValue{Hash: "$2abc"}},
|
||||
{&core.PasswordFieldValue{Hash: "test", LastError: err}, &core.PasswordFieldValue{Hash: "test", LastError: err}},
|
||||
}
|
||||
|
||||
for i, s := range scenarios {
|
||||
t.Run(fmt.Sprintf("%d_%v", i, s.raw), func(t *testing.T) {
|
||||
record := core.NewRecord(core.NewBaseCollection("test"))
|
||||
record.SetRaw(f.GetName(), s.raw)
|
||||
|
||||
v, err := f.DriverValue(record)
|
||||
|
||||
vStr, ok := v.(string)
|
||||
if !ok {
|
||||
t.Fatalf("Expected string instance, got %T", v)
|
||||
}
|
||||
|
||||
var errStr string
|
||||
if err != nil {
|
||||
errStr = err.Error()
|
||||
}
|
||||
|
||||
var expectedErrStr string
|
||||
if s.expected.LastError != nil {
|
||||
expectedErrStr = s.expected.LastError.Error()
|
||||
}
|
||||
|
||||
if errStr != expectedErrStr {
|
||||
t.Fatalf("Expected error %q, got %q", expectedErrStr, errStr)
|
||||
}
|
||||
|
||||
if vStr != s.expected.Hash {
|
||||
t.Fatalf("Expected hash %q, got %q", s.expected.Hash, vStr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPasswordFieldValidateValue(t *testing.T) {
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
collection := core.NewBaseCollection("test_collection")
|
||||
|
||||
scenarios := []struct {
|
||||
name string
|
||||
field *core.PasswordField
|
||||
record func() *core.Record
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
"invalid raw value",
|
||||
&core.PasswordField{Name: "test"},
|
||||
func() *core.Record {
|
||||
record := core.NewRecord(collection)
|
||||
record.SetRaw("test", "123")
|
||||
return record
|
||||
},
|
||||
true,
|
||||
},
|
||||
{
|
||||
"zero field value (not required)",
|
||||
&core.PasswordField{Name: "test"},
|
||||
func() *core.Record {
|
||||
record := core.NewRecord(collection)
|
||||
record.SetRaw("test", &core.PasswordFieldValue{})
|
||||
return record
|
||||
},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"zero field value (required)",
|
||||
&core.PasswordField{Name: "test", Required: true},
|
||||
func() *core.Record {
|
||||
record := core.NewRecord(collection)
|
||||
record.SetRaw("test", &core.PasswordFieldValue{})
|
||||
return record
|
||||
},
|
||||
true,
|
||||
},
|
||||
{
|
||||
"empty hash but non-empty plain password (required)",
|
||||
&core.PasswordField{Name: "test", Required: true},
|
||||
func() *core.Record {
|
||||
record := core.NewRecord(collection)
|
||||
record.SetRaw("test", &core.PasswordFieldValue{Plain: "test"})
|
||||
return record
|
||||
},
|
||||
true,
|
||||
},
|
||||
{
|
||||
"non-empty hash (required)",
|
||||
&core.PasswordField{Name: "test", Required: true},
|
||||
func() *core.Record {
|
||||
record := core.NewRecord(collection)
|
||||
record.SetRaw("test", &core.PasswordFieldValue{Hash: "test"})
|
||||
return record
|
||||
},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"with LastError",
|
||||
&core.PasswordField{Name: "test"},
|
||||
func() *core.Record {
|
||||
record := core.NewRecord(collection)
|
||||
record.SetRaw("test", &core.PasswordFieldValue{LastError: errors.New("test")})
|
||||
return record
|
||||
},
|
||||
true,
|
||||
},
|
||||
{
|
||||
"< Min",
|
||||
&core.PasswordField{Name: "test", Min: 3},
|
||||
func() *core.Record {
|
||||
record := core.NewRecord(collection)
|
||||
record.SetRaw("test", &core.PasswordFieldValue{Plain: "аб"}) // multi-byte chars test
|
||||
return record
|
||||
},
|
||||
true,
|
||||
},
|
||||
{
|
||||
">= Min",
|
||||
&core.PasswordField{Name: "test", Min: 3},
|
||||
func() *core.Record {
|
||||
record := core.NewRecord(collection)
|
||||
record.SetRaw("test", &core.PasswordFieldValue{Plain: "абв"}) // multi-byte chars test
|
||||
return record
|
||||
},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"> default Max",
|
||||
&core.PasswordField{Name: "test"},
|
||||
func() *core.Record {
|
||||
record := core.NewRecord(collection)
|
||||
record.SetRaw("test", &core.PasswordFieldValue{Plain: strings.Repeat("a", 72)})
|
||||
return record
|
||||
},
|
||||
true,
|
||||
},
|
||||
{
|
||||
"<= default Max",
|
||||
&core.PasswordField{Name: "test"},
|
||||
func() *core.Record {
|
||||
record := core.NewRecord(collection)
|
||||
record.SetRaw("test", &core.PasswordFieldValue{Plain: strings.Repeat("a", 71)})
|
||||
return record
|
||||
},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"> Max",
|
||||
&core.PasswordField{Name: "test", Max: 2},
|
||||
func() *core.Record {
|
||||
record := core.NewRecord(collection)
|
||||
record.SetRaw("test", &core.PasswordFieldValue{Plain: "абв"}) // multi-byte chars test
|
||||
return record
|
||||
},
|
||||
true,
|
||||
},
|
||||
{
|
||||
"<= Max",
|
||||
&core.PasswordField{Name: "test", Max: 2},
|
||||
func() *core.Record {
|
||||
record := core.NewRecord(collection)
|
||||
record.SetRaw("test", &core.PasswordFieldValue{Plain: "аб"}) // multi-byte chars test
|
||||
return record
|
||||
},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"non-matching pattern",
|
||||
&core.PasswordField{Name: "test", Pattern: `\d+`},
|
||||
func() *core.Record {
|
||||
record := core.NewRecord(collection)
|
||||
record.SetRaw("test", &core.PasswordFieldValue{Plain: "abc"})
|
||||
return record
|
||||
},
|
||||
true,
|
||||
},
|
||||
{
|
||||
"matching pattern",
|
||||
&core.PasswordField{Name: "test", Pattern: `\d+`},
|
||||
func() *core.Record {
|
||||
record := core.NewRecord(collection)
|
||||
record.SetRaw("test", &core.PasswordFieldValue{Plain: "123"})
|
||||
return record
|
||||
},
|
||||
false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
t.Run(s.name, func(t *testing.T) {
|
||||
err := s.field.ValidateValue(context.Background(), app, s.record())
|
||||
|
||||
hasErr := err != nil
|
||||
if hasErr != s.expectError {
|
||||
t.Fatalf("Expected hasErr %v, got %v (%v)", s.expectError, hasErr, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPasswordFieldValidateSettings(t *testing.T) {
|
||||
testDefaultFieldIdValidation(t, core.FieldTypePassword)
|
||||
testDefaultFieldNameValidation(t, core.FieldTypePassword)
|
||||
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
scenarios := []struct {
|
||||
name string
|
||||
field func(col *core.Collection) *core.PasswordField
|
||||
expectErrors []string
|
||||
}{
|
||||
{
|
||||
"zero minimal",
|
||||
func(col *core.Collection) *core.PasswordField {
|
||||
return &core.PasswordField{
|
||||
Id: "test",
|
||||
Name: "test",
|
||||
}
|
||||
},
|
||||
[]string{},
|
||||
},
|
||||
{
|
||||
"invalid pattern",
|
||||
func(col *core.Collection) *core.PasswordField {
|
||||
return &core.PasswordField{
|
||||
Id: "test",
|
||||
Name: "test",
|
||||
Pattern: "(invalid",
|
||||
}
|
||||
},
|
||||
[]string{"pattern"},
|
||||
},
|
||||
{
|
||||
"valid pattern",
|
||||
func(col *core.Collection) *core.PasswordField {
|
||||
return &core.PasswordField{
|
||||
Id: "test",
|
||||
Name: "test",
|
||||
Pattern: `\d+`,
|
||||
}
|
||||
},
|
||||
[]string{},
|
||||
},
|
||||
{
|
||||
"Min < 0",
|
||||
func(col *core.Collection) *core.PasswordField {
|
||||
return &core.PasswordField{
|
||||
Id: "test",
|
||||
Name: "test",
|
||||
Min: -1,
|
||||
}
|
||||
},
|
||||
[]string{"min"},
|
||||
},
|
||||
{
|
||||
"Min > 71",
|
||||
func(col *core.Collection) *core.PasswordField {
|
||||
return &core.PasswordField{
|
||||
Id: "test",
|
||||
Name: "test",
|
||||
Min: 72,
|
||||
}
|
||||
},
|
||||
[]string{"min"},
|
||||
},
|
||||
{
|
||||
"valid Min",
|
||||
func(col *core.Collection) *core.PasswordField {
|
||||
return &core.PasswordField{
|
||||
Id: "test",
|
||||
Name: "test",
|
||||
Min: 5,
|
||||
}
|
||||
},
|
||||
[]string{},
|
||||
},
|
||||
{
|
||||
"Max < Min",
|
||||
func(col *core.Collection) *core.PasswordField {
|
||||
return &core.PasswordField{
|
||||
Id: "test",
|
||||
Name: "test",
|
||||
Min: 2,
|
||||
Max: 1,
|
||||
}
|
||||
},
|
||||
[]string{"max"},
|
||||
},
|
||||
{
|
||||
"Min > Min",
|
||||
func(col *core.Collection) *core.PasswordField {
|
||||
return &core.PasswordField{
|
||||
Id: "test",
|
||||
Name: "test",
|
||||
Min: 2,
|
||||
Max: 3,
|
||||
}
|
||||
},
|
||||
[]string{},
|
||||
},
|
||||
{
|
||||
"Max > 71",
|
||||
func(col *core.Collection) *core.PasswordField {
|
||||
return &core.PasswordField{
|
||||
Id: "test",
|
||||
Name: "test",
|
||||
Max: 72,
|
||||
}
|
||||
},
|
||||
[]string{"max"},
|
||||
},
|
||||
{
|
||||
"cost < bcrypt.MinCost",
|
||||
func(col *core.Collection) *core.PasswordField {
|
||||
return &core.PasswordField{
|
||||
Id: "test",
|
||||
Name: "test",
|
||||
Cost: bcrypt.MinCost - 1,
|
||||
}
|
||||
},
|
||||
[]string{"cost"},
|
||||
},
|
||||
{
|
||||
"cost > bcrypt.MaxCost",
|
||||
func(col *core.Collection) *core.PasswordField {
|
||||
return &core.PasswordField{
|
||||
Id: "test",
|
||||
Name: "test",
|
||||
Cost: bcrypt.MaxCost + 1,
|
||||
}
|
||||
},
|
||||
[]string{"cost"},
|
||||
},
|
||||
{
|
||||
"valid cost",
|
||||
func(col *core.Collection) *core.PasswordField {
|
||||
return &core.PasswordField{
|
||||
Id: "test",
|
||||
Name: "test",
|
||||
Cost: 12,
|
||||
}
|
||||
},
|
||||
[]string{},
|
||||
},
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
t.Run(s.name, func(t *testing.T) {
|
||||
collection := core.NewBaseCollection("test_collection")
|
||||
collection.Fields.GetByName("id").SetId("test") // set a dummy known id so that it can be replaced
|
||||
|
||||
field := s.field(collection)
|
||||
|
||||
collection.Fields.Add(field)
|
||||
|
||||
errs := field.ValidateSettings(context.Background(), app, collection)
|
||||
|
||||
tests.TestValidationErrors(t, errs, s.expectErrors)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPasswordFieldFindSetter(t *testing.T) {
|
||||
scenarios := []struct {
|
||||
name string
|
||||
key string
|
||||
value any
|
||||
field *core.PasswordField
|
||||
hasSetter bool
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
"no match",
|
||||
"example",
|
||||
"abc",
|
||||
&core.PasswordField{Name: "test"},
|
||||
false,
|
||||
"",
|
||||
},
|
||||
{
|
||||
"exact match",
|
||||
"test",
|
||||
"abc",
|
||||
&core.PasswordField{Name: "test"},
|
||||
true,
|
||||
`"abc"`,
|
||||
},
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
t.Run(s.name, func(t *testing.T) {
|
||||
collection := core.NewBaseCollection("test_collection")
|
||||
collection.Fields.Add(s.field)
|
||||
|
||||
setter := s.field.FindSetter(s.key)
|
||||
|
||||
hasSetter := setter != nil
|
||||
if hasSetter != s.hasSetter {
|
||||
t.Fatalf("Expected hasSetter %v, got %v", s.hasSetter, hasSetter)
|
||||
}
|
||||
|
||||
if !hasSetter {
|
||||
return
|
||||
}
|
||||
|
||||
record := core.NewRecord(collection)
|
||||
record.SetRaw(s.field.GetName(), []string{"c", "d"})
|
||||
|
||||
setter(record, s.value)
|
||||
|
||||
raw, err := json.Marshal(record.Get(s.field.GetName()))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
rawStr := string(raw)
|
||||
|
||||
if rawStr != s.expected {
|
||||
t.Fatalf("Expected %q, got %q", s.expected, rawStr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPasswordFieldFindGetter(t *testing.T) {
|
||||
scenarios := []struct {
|
||||
name string
|
||||
key string
|
||||
field *core.PasswordField
|
||||
hasGetter bool
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
"no match",
|
||||
"example",
|
||||
&core.PasswordField{Name: "test"},
|
||||
false,
|
||||
"",
|
||||
},
|
||||
{
|
||||
"field name match",
|
||||
"test",
|
||||
&core.PasswordField{Name: "test"},
|
||||
true,
|
||||
"test_plain",
|
||||
},
|
||||
{
|
||||
"field name hash modifier",
|
||||
"test:hash",
|
||||
&core.PasswordField{Name: "test"},
|
||||
true,
|
||||
"test_hash",
|
||||
},
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
t.Run(s.name, func(t *testing.T) {
|
||||
collection := core.NewBaseCollection("test_collection")
|
||||
collection.Fields.Add(s.field)
|
||||
|
||||
getter := s.field.FindGetter(s.key)
|
||||
|
||||
hasGetter := getter != nil
|
||||
if hasGetter != s.hasGetter {
|
||||
t.Fatalf("Expected hasGetter %v, got %v", s.hasGetter, hasGetter)
|
||||
}
|
||||
|
||||
if !hasGetter {
|
||||
return
|
||||
}
|
||||
|
||||
record := core.NewRecord(collection)
|
||||
record.SetRaw(s.field.GetName(), &core.PasswordFieldValue{Hash: "test_hash", Plain: "test_plain"})
|
||||
|
||||
result := getter(record)
|
||||
|
||||
if result != s.expected {
|
||||
t.Fatalf("Expected %q, got %#v", s.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
350
core/field_relation.go
Normal file
350
core/field_relation.go
Normal file
|
@ -0,0 +1,350 @@
|
|||
package core
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql/driver"
|
||||
|
||||
validation "github.com/go-ozzo/ozzo-validation/v4"
|
||||
"github.com/pocketbase/dbx"
|
||||
"github.com/pocketbase/pocketbase/tools/list"
|
||||
"github.com/pocketbase/pocketbase/tools/types"
|
||||
)
|
||||
|
||||
func init() {
|
||||
Fields[FieldTypeRelation] = func() Field {
|
||||
return &RelationField{}
|
||||
}
|
||||
}
|
||||
|
||||
const FieldTypeRelation = "relation"
|
||||
|
||||
var (
|
||||
_ Field = (*RelationField)(nil)
|
||||
_ MultiValuer = (*RelationField)(nil)
|
||||
_ DriverValuer = (*RelationField)(nil)
|
||||
_ SetterFinder = (*RelationField)(nil)
|
||||
)
|
||||
|
||||
// RelationField defines "relation" type field for storing single or
|
||||
// multiple collection record references.
|
||||
//
|
||||
// Requires the CollectionId option to be set.
|
||||
//
|
||||
// If MaxSelect is not set or <= 1, then the field value is expected to be a single record id.
|
||||
//
|
||||
// If MaxSelect is > 1, then the field value is expected to be a slice of record ids.
|
||||
//
|
||||
// The respective zero record field value is either empty string (single) or empty string slice (multiple).
|
||||
//
|
||||
// ---
|
||||
//
|
||||
// The following additional setter keys are available:
|
||||
//
|
||||
// - "fieldName+" - append one or more values to the existing record one. For example:
|
||||
//
|
||||
// record.Set("categories+", []string{"new1", "new2"}) // []string{"old1", "old2", "new1", "new2"}
|
||||
//
|
||||
// - "+fieldName" - prepend one or more values to the existing record one. For example:
|
||||
//
|
||||
// record.Set("+categories", []string{"new1", "new2"}) // []string{"new1", "new2", "old1", "old2"}
|
||||
//
|
||||
// - "fieldName-" - subtract one or more values from the existing record one. For example:
|
||||
//
|
||||
// record.Set("categories-", "old1") // []string{"old2"}
|
||||
type RelationField struct {
|
||||
// Name (required) is the unique name of the field.
|
||||
Name string `form:"name" json:"name"`
|
||||
|
||||
// Id is the unique stable field identifier.
|
||||
//
|
||||
// It is automatically generated from the name when adding to a collection FieldsList.
|
||||
Id string `form:"id" json:"id"`
|
||||
|
||||
// System prevents the renaming and removal of the field.
|
||||
System bool `form:"system" json:"system"`
|
||||
|
||||
// Hidden hides the field from the API response.
|
||||
Hidden bool `form:"hidden" json:"hidden"`
|
||||
|
||||
// Presentable hints the Dashboard UI to use the underlying
|
||||
// field record value in the relation preview label.
|
||||
Presentable bool `form:"presentable" json:"presentable"`
|
||||
|
||||
// ---
|
||||
|
||||
// CollectionId is the id of the related collection.
|
||||
CollectionId string `form:"collectionId" json:"collectionId"`
|
||||
|
||||
// CascadeDelete indicates whether the root model should be deleted
|
||||
// in case of delete of all linked relations.
|
||||
CascadeDelete bool `form:"cascadeDelete" json:"cascadeDelete"`
|
||||
|
||||
// MinSelect indicates the min number of allowed relation records
|
||||
// that could be linked to the main model.
|
||||
//
|
||||
// No min limit is applied if it is zero or negative value.
|
||||
MinSelect int `form:"minSelect" json:"minSelect"`
|
||||
|
||||
// MaxSelect indicates the max number of allowed relation records
|
||||
// that could be linked to the main model.
|
||||
//
|
||||
// For multiple select the value must be > 1, otherwise fallbacks to single (default).
|
||||
//
|
||||
// If MinSelect is set, MaxSelect must be at least >= MinSelect.
|
||||
MaxSelect int `form:"maxSelect" json:"maxSelect"`
|
||||
|
||||
// Required will require the field value to be non-empty.
|
||||
Required bool `form:"required" json:"required"`
|
||||
}
|
||||
|
||||
// Type implements [Field.Type] interface method.
|
||||
func (f *RelationField) Type() string {
|
||||
return FieldTypeRelation
|
||||
}
|
||||
|
||||
// GetId implements [Field.GetId] interface method.
|
||||
func (f *RelationField) GetId() string {
|
||||
return f.Id
|
||||
}
|
||||
|
||||
// SetId implements [Field.SetId] interface method.
|
||||
func (f *RelationField) SetId(id string) {
|
||||
f.Id = id
|
||||
}
|
||||
|
||||
// GetName implements [Field.GetName] interface method.
|
||||
func (f *RelationField) GetName() string {
|
||||
return f.Name
|
||||
}
|
||||
|
||||
// SetName implements [Field.SetName] interface method.
|
||||
func (f *RelationField) SetName(name string) {
|
||||
f.Name = name
|
||||
}
|
||||
|
||||
// GetSystem implements [Field.GetSystem] interface method.
|
||||
func (f *RelationField) GetSystem() bool {
|
||||
return f.System
|
||||
}
|
||||
|
||||
// SetSystem implements [Field.SetSystem] interface method.
|
||||
func (f *RelationField) SetSystem(system bool) {
|
||||
f.System = system
|
||||
}
|
||||
|
||||
// GetHidden implements [Field.GetHidden] interface method.
|
||||
func (f *RelationField) GetHidden() bool {
|
||||
return f.Hidden
|
||||
}
|
||||
|
||||
// SetHidden implements [Field.SetHidden] interface method.
|
||||
func (f *RelationField) SetHidden(hidden bool) {
|
||||
f.Hidden = hidden
|
||||
}
|
||||
|
||||
// IsMultiple implements [MultiValuer] interface and checks whether the
|
||||
// current field options support multiple values.
|
||||
func (f *RelationField) IsMultiple() bool {
|
||||
return f.MaxSelect > 1
|
||||
}
|
||||
|
||||
// ColumnType implements [Field.ColumnType] interface method.
|
||||
func (f *RelationField) ColumnType(app App) string {
|
||||
if f.IsMultiple() {
|
||||
return "JSON DEFAULT '[]' NOT NULL"
|
||||
}
|
||||
|
||||
return "TEXT DEFAULT '' NOT NULL"
|
||||
}
|
||||
|
||||
// PrepareValue implements [Field.PrepareValue] interface method.
|
||||
func (f *RelationField) PrepareValue(record *Record, raw any) (any, error) {
|
||||
return f.normalizeValue(raw), nil
|
||||
}
|
||||
|
||||
func (f *RelationField) normalizeValue(raw any) any {
|
||||
val := list.ToUniqueStringSlice(raw)
|
||||
|
||||
if !f.IsMultiple() {
|
||||
if len(val) > 0 {
|
||||
return val[len(val)-1] // the last selected
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
return val
|
||||
}
|
||||
|
||||
// DriverValue implements the [DriverValuer] interface.
|
||||
func (f *RelationField) DriverValue(record *Record) (driver.Value, error) {
|
||||
val := list.ToUniqueStringSlice(record.GetRaw(f.Name))
|
||||
|
||||
if !f.IsMultiple() {
|
||||
if len(val) > 0 {
|
||||
return val[len(val)-1], nil // the last selected
|
||||
}
|
||||
return "", nil
|
||||
}
|
||||
|
||||
// serialize as json string array
|
||||
return append(types.JSONArray[string]{}, val...), nil
|
||||
}
|
||||
|
||||
// ValidateValue implements [Field.ValidateValue] interface method.
|
||||
func (f *RelationField) ValidateValue(ctx context.Context, app App, record *Record) error {
|
||||
ids := list.ToUniqueStringSlice(record.GetRaw(f.Name))
|
||||
if len(ids) == 0 {
|
||||
if f.Required {
|
||||
return validation.ErrRequired
|
||||
}
|
||||
return nil // nothing to check
|
||||
}
|
||||
|
||||
if f.MinSelect > 0 && len(ids) < f.MinSelect {
|
||||
return validation.NewError("validation_not_enough_values", "Select at least {{.minSelect}}").
|
||||
SetParams(map[string]any{"minSelect": f.MinSelect})
|
||||
}
|
||||
|
||||
maxSelect := max(f.MaxSelect, 1)
|
||||
if len(ids) > maxSelect {
|
||||
return validation.NewError("validation_too_many_values", "Select no more than {{.maxSelect}}").
|
||||
SetParams(map[string]any{"maxSelect": maxSelect})
|
||||
}
|
||||
|
||||
// check if the related records exist
|
||||
// ---
|
||||
relCollection, err := app.FindCachedCollectionByNameOrId(f.CollectionId)
|
||||
if err != nil {
|
||||
return validation.NewError("validation_missing_rel_collection", "Relation connection is missing or cannot be accessed")
|
||||
}
|
||||
|
||||
var total int
|
||||
_ = app.ConcurrentDB().
|
||||
Select("count(*)").
|
||||
From(relCollection.Name).
|
||||
AndWhere(dbx.In("id", list.ToInterfaceSlice(ids)...)).
|
||||
Row(&total)
|
||||
if total != len(ids) {
|
||||
return validation.NewError("validation_missing_rel_records", "Failed to find all relation records with the provided ids")
|
||||
}
|
||||
// ---
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidateSettings implements [Field.ValidateSettings] interface method.
|
||||
func (f *RelationField) ValidateSettings(ctx context.Context, app App, collection *Collection) error {
|
||||
return validation.ValidateStruct(f,
|
||||
validation.Field(&f.Id, validation.By(DefaultFieldIdValidationRule)),
|
||||
validation.Field(&f.Name, validation.By(DefaultFieldNameValidationRule)),
|
||||
validation.Field(&f.CollectionId, validation.Required, validation.By(f.checkCollectionId(app, collection))),
|
||||
validation.Field(&f.MinSelect, validation.Min(0)),
|
||||
validation.Field(&f.MaxSelect, validation.When(f.MinSelect > 0, validation.Required), validation.Min(f.MinSelect)),
|
||||
)
|
||||
}
|
||||
|
||||
func (f *RelationField) checkCollectionId(app App, collection *Collection) validation.RuleFunc {
|
||||
return func(value any) error {
|
||||
v, _ := value.(string)
|
||||
if v == "" {
|
||||
return nil // nothing to check
|
||||
}
|
||||
|
||||
var oldCollection *Collection
|
||||
|
||||
if !collection.IsNew() {
|
||||
var err error
|
||||
oldCollection, err = app.FindCachedCollectionByNameOrId(collection.Id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// prevent collectionId change
|
||||
if oldCollection != nil {
|
||||
oldField, _ := oldCollection.Fields.GetById(f.Id).(*RelationField)
|
||||
if oldField != nil && oldField.CollectionId != v {
|
||||
return validation.NewError(
|
||||
"validation_field_relation_change",
|
||||
"The relation collection cannot be changed.",
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
relCollection, _ := app.FindCachedCollectionByNameOrId(v)
|
||||
|
||||
// validate collectionId
|
||||
if relCollection == nil || relCollection.Id != v {
|
||||
return validation.NewError(
|
||||
"validation_field_relation_missing_collection",
|
||||
"The relation collection doesn't exist.",
|
||||
)
|
||||
}
|
||||
|
||||
// allow only views to have relations to other views
|
||||
// (see https://github.com/pocketbase/pocketbase/issues/3000)
|
||||
if !collection.IsView() && relCollection.IsView() {
|
||||
return validation.NewError(
|
||||
"validation_relation_field_non_view_base_collection",
|
||||
"Only view collections are allowed to have relations to other views.",
|
||||
)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// ---
|
||||
|
||||
// FindSetter implements [SetterFinder] interface method.
|
||||
func (f *RelationField) FindSetter(key string) SetterFunc {
|
||||
switch key {
|
||||
case f.Name:
|
||||
return f.setValue
|
||||
case "+" + f.Name:
|
||||
return f.prependValue
|
||||
case f.Name + "+":
|
||||
return f.appendValue
|
||||
case f.Name + "-":
|
||||
return f.subtractValue
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (f *RelationField) setValue(record *Record, raw any) {
|
||||
record.SetRaw(f.Name, f.normalizeValue(raw))
|
||||
}
|
||||
|
||||
func (f *RelationField) appendValue(record *Record, modifierValue any) {
|
||||
val := record.GetRaw(f.Name)
|
||||
|
||||
val = append(
|
||||
list.ToUniqueStringSlice(val),
|
||||
list.ToUniqueStringSlice(modifierValue)...,
|
||||
)
|
||||
|
||||
f.setValue(record, val)
|
||||
}
|
||||
|
||||
func (f *RelationField) prependValue(record *Record, modifierValue any) {
|
||||
val := record.GetRaw(f.Name)
|
||||
|
||||
val = append(
|
||||
list.ToUniqueStringSlice(modifierValue),
|
||||
list.ToUniqueStringSlice(val)...,
|
||||
)
|
||||
|
||||
f.setValue(record, val)
|
||||
}
|
||||
|
||||
func (f *RelationField) subtractValue(record *Record, modifierValue any) {
|
||||
val := record.GetRaw(f.Name)
|
||||
|
||||
val = list.SubtractSlice(
|
||||
list.ToUniqueStringSlice(val),
|
||||
list.ToUniqueStringSlice(modifierValue),
|
||||
)
|
||||
|
||||
f.setValue(record, val)
|
||||
}
|
603
core/field_relation_test.go
Normal file
603
core/field_relation_test.go
Normal file
|
@ -0,0 +1,603 @@
|
|||
package core_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tests"
|
||||
"github.com/pocketbase/pocketbase/tools/types"
|
||||
)
|
||||
|
||||
func TestRelationFieldBaseMethods(t *testing.T) {
|
||||
testFieldBaseMethods(t, core.FieldTypeRelation)
|
||||
}
|
||||
|
||||
func TestRelationFieldColumnType(t *testing.T) {
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
scenarios := []struct {
|
||||
name string
|
||||
field *core.RelationField
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
"single (zero)",
|
||||
&core.RelationField{},
|
||||
"TEXT DEFAULT '' NOT NULL",
|
||||
},
|
||||
{
|
||||
"single",
|
||||
&core.RelationField{MaxSelect: 1},
|
||||
"TEXT DEFAULT '' NOT NULL",
|
||||
},
|
||||
{
|
||||
"multiple",
|
||||
&core.RelationField{MaxSelect: 2},
|
||||
"JSON DEFAULT '[]' NOT NULL",
|
||||
},
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
t.Run(s.name, func(t *testing.T) {
|
||||
if v := s.field.ColumnType(app); v != s.expected {
|
||||
t.Fatalf("Expected\n%q\ngot\n%q", s.expected, v)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRelationFieldIsMultiple(t *testing.T) {
|
||||
scenarios := []struct {
|
||||
name string
|
||||
field *core.RelationField
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
"zero",
|
||||
&core.RelationField{},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"single",
|
||||
&core.RelationField{MaxSelect: 1},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"multiple",
|
||||
&core.RelationField{MaxSelect: 2},
|
||||
true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
t.Run(s.name, func(t *testing.T) {
|
||||
if v := s.field.IsMultiple(); v != s.expected {
|
||||
t.Fatalf("Expected %v, got %v", s.expected, v)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRelationFieldPrepareValue(t *testing.T) {
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
record := core.NewRecord(core.NewBaseCollection("test"))
|
||||
|
||||
scenarios := []struct {
|
||||
raw any
|
||||
field *core.RelationField
|
||||
expected string
|
||||
}{
|
||||
// single
|
||||
{nil, &core.RelationField{MaxSelect: 1}, `""`},
|
||||
{"", &core.RelationField{MaxSelect: 1}, `""`},
|
||||
{123, &core.RelationField{MaxSelect: 1}, `"123"`},
|
||||
{"a", &core.RelationField{MaxSelect: 1}, `"a"`},
|
||||
{`["a"]`, &core.RelationField{MaxSelect: 1}, `"a"`},
|
||||
{[]string{}, &core.RelationField{MaxSelect: 1}, `""`},
|
||||
{[]string{"a", "b"}, &core.RelationField{MaxSelect: 1}, `"b"`},
|
||||
|
||||
// multiple
|
||||
{nil, &core.RelationField{MaxSelect: 2}, `[]`},
|
||||
{"", &core.RelationField{MaxSelect: 2}, `[]`},
|
||||
{123, &core.RelationField{MaxSelect: 2}, `["123"]`},
|
||||
{"a", &core.RelationField{MaxSelect: 2}, `["a"]`},
|
||||
{`["a"]`, &core.RelationField{MaxSelect: 2}, `["a"]`},
|
||||
{[]string{}, &core.RelationField{MaxSelect: 2}, `[]`},
|
||||
{[]string{"a", "b", "c"}, &core.RelationField{MaxSelect: 2}, `["a","b","c"]`},
|
||||
}
|
||||
|
||||
for i, s := range scenarios {
|
||||
t.Run(fmt.Sprintf("%d_%#v_%v", i, s.raw, s.field.IsMultiple()), func(t *testing.T) {
|
||||
v, err := s.field.PrepareValue(record, s.raw)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
vRaw, err := json.Marshal(v)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if string(vRaw) != s.expected {
|
||||
t.Fatalf("Expected %q, got %q", s.expected, vRaw)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRelationFieldDriverValue(t *testing.T) {
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
scenarios := []struct {
|
||||
raw any
|
||||
field *core.RelationField
|
||||
expected string
|
||||
}{
|
||||
// single
|
||||
{nil, &core.RelationField{MaxSelect: 1}, `""`},
|
||||
{"", &core.RelationField{MaxSelect: 1}, `""`},
|
||||
{123, &core.RelationField{MaxSelect: 1}, `"123"`},
|
||||
{"a", &core.RelationField{MaxSelect: 1}, `"a"`},
|
||||
{`["a"]`, &core.RelationField{MaxSelect: 1}, `"a"`},
|
||||
{[]string{}, &core.RelationField{MaxSelect: 1}, `""`},
|
||||
{[]string{"a", "b"}, &core.RelationField{MaxSelect: 1}, `"b"`},
|
||||
|
||||
// multiple
|
||||
{nil, &core.RelationField{MaxSelect: 2}, `[]`},
|
||||
{"", &core.RelationField{MaxSelect: 2}, `[]`},
|
||||
{123, &core.RelationField{MaxSelect: 2}, `["123"]`},
|
||||
{"a", &core.RelationField{MaxSelect: 2}, `["a"]`},
|
||||
{`["a"]`, &core.RelationField{MaxSelect: 2}, `["a"]`},
|
||||
{[]string{}, &core.RelationField{MaxSelect: 2}, `[]`},
|
||||
{[]string{"a", "b", "c"}, &core.RelationField{MaxSelect: 2}, `["a","b","c"]`},
|
||||
}
|
||||
|
||||
for i, s := range scenarios {
|
||||
t.Run(fmt.Sprintf("%d_%#v_%v", i, s.raw, s.field.IsMultiple()), func(t *testing.T) {
|
||||
record := core.NewRecord(core.NewBaseCollection("test"))
|
||||
record.SetRaw(s.field.GetName(), s.raw)
|
||||
|
||||
v, err := s.field.DriverValue(record)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if s.field.IsMultiple() {
|
||||
_, ok := v.(types.JSONArray[string])
|
||||
if !ok {
|
||||
t.Fatalf("Expected types.JSONArray value, got %T", v)
|
||||
}
|
||||
} else {
|
||||
_, ok := v.(string)
|
||||
if !ok {
|
||||
t.Fatalf("Expected string value, got %T", v)
|
||||
}
|
||||
}
|
||||
|
||||
vRaw, err := json.Marshal(v)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if string(vRaw) != s.expected {
|
||||
t.Fatalf("Expected %q, got %q", s.expected, vRaw)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRelationFieldValidateValue(t *testing.T) {
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
demo1, err := app.FindCollectionByNameOrId("demo1")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
scenarios := []struct {
|
||||
name string
|
||||
field *core.RelationField
|
||||
record func() *core.Record
|
||||
expectError bool
|
||||
}{
|
||||
// single
|
||||
{
|
||||
"[single] zero field value (not required)",
|
||||
&core.RelationField{Name: "test", MaxSelect: 1, CollectionId: demo1.Id},
|
||||
func() *core.Record {
|
||||
record := core.NewRecord(core.NewBaseCollection("test_collection"))
|
||||
record.SetRaw("test", "")
|
||||
return record
|
||||
},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"[single] zero field value (required)",
|
||||
&core.RelationField{Name: "test", MaxSelect: 1, CollectionId: demo1.Id, Required: true},
|
||||
func() *core.Record {
|
||||
record := core.NewRecord(core.NewBaseCollection("test_collection"))
|
||||
record.SetRaw("test", "")
|
||||
return record
|
||||
},
|
||||
true,
|
||||
},
|
||||
{
|
||||
"[single] id from other collection",
|
||||
&core.RelationField{Name: "test", MaxSelect: 1, CollectionId: demo1.Id},
|
||||
func() *core.Record {
|
||||
record := core.NewRecord(core.NewBaseCollection("test_collection"))
|
||||
record.SetRaw("test", "achvryl401bhse3")
|
||||
return record
|
||||
},
|
||||
true,
|
||||
},
|
||||
{
|
||||
"[single] valid id",
|
||||
&core.RelationField{Name: "test", MaxSelect: 1, CollectionId: demo1.Id},
|
||||
func() *core.Record {
|
||||
record := core.NewRecord(core.NewBaseCollection("test_collection"))
|
||||
record.SetRaw("test", "84nmscqy84lsi1t")
|
||||
return record
|
||||
},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"[single] > MaxSelect",
|
||||
&core.RelationField{Name: "test", MaxSelect: 1, CollectionId: demo1.Id},
|
||||
func() *core.Record {
|
||||
record := core.NewRecord(core.NewBaseCollection("test_collection"))
|
||||
record.SetRaw("test", []string{"84nmscqy84lsi1t", "al1h9ijdeojtsjy"})
|
||||
return record
|
||||
},
|
||||
true,
|
||||
},
|
||||
|
||||
// multiple
|
||||
{
|
||||
"[multiple] zero field value (not required)",
|
||||
&core.RelationField{Name: "test", MaxSelect: 2, CollectionId: demo1.Id},
|
||||
func() *core.Record {
|
||||
record := core.NewRecord(core.NewBaseCollection("test_collection"))
|
||||
record.SetRaw("test", []string{})
|
||||
return record
|
||||
},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"[multiple] zero field value (required)",
|
||||
&core.RelationField{Name: "test", MaxSelect: 2, CollectionId: demo1.Id, Required: true},
|
||||
func() *core.Record {
|
||||
record := core.NewRecord(core.NewBaseCollection("test_collection"))
|
||||
record.SetRaw("test", []string{})
|
||||
return record
|
||||
},
|
||||
true,
|
||||
},
|
||||
{
|
||||
"[multiple] id from other collection",
|
||||
&core.RelationField{Name: "test", MaxSelect: 2, CollectionId: demo1.Id},
|
||||
func() *core.Record {
|
||||
record := core.NewRecord(core.NewBaseCollection("test_collection"))
|
||||
record.SetRaw("test", []string{"84nmscqy84lsi1t", "achvryl401bhse3"})
|
||||
return record
|
||||
},
|
||||
true,
|
||||
},
|
||||
{
|
||||
"[multiple] valid id",
|
||||
&core.RelationField{Name: "test", MaxSelect: 2, CollectionId: demo1.Id},
|
||||
func() *core.Record {
|
||||
record := core.NewRecord(core.NewBaseCollection("test_collection"))
|
||||
record.SetRaw("test", []string{"84nmscqy84lsi1t", "al1h9ijdeojtsjy"})
|
||||
return record
|
||||
},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"[multiple] > MaxSelect",
|
||||
&core.RelationField{Name: "test", MaxSelect: 2, CollectionId: demo1.Id},
|
||||
func() *core.Record {
|
||||
record := core.NewRecord(core.NewBaseCollection("test_collection"))
|
||||
record.SetRaw("test", []string{"84nmscqy84lsi1t", "al1h9ijdeojtsjy", "imy661ixudk5izi"})
|
||||
return record
|
||||
},
|
||||
true,
|
||||
},
|
||||
{
|
||||
"[multiple] < MinSelect",
|
||||
&core.RelationField{Name: "test", MinSelect: 2, MaxSelect: 99, CollectionId: demo1.Id},
|
||||
func() *core.Record {
|
||||
record := core.NewRecord(core.NewBaseCollection("test_collection"))
|
||||
record.SetRaw("test", []string{"84nmscqy84lsi1t"})
|
||||
return record
|
||||
},
|
||||
true,
|
||||
},
|
||||
{
|
||||
"[multiple] >= MinSelect",
|
||||
&core.RelationField{Name: "test", MinSelect: 2, MaxSelect: 99, CollectionId: demo1.Id},
|
||||
func() *core.Record {
|
||||
record := core.NewRecord(core.NewBaseCollection("test_collection"))
|
||||
record.SetRaw("test", []string{"84nmscqy84lsi1t", "al1h9ijdeojtsjy", "imy661ixudk5izi"})
|
||||
return record
|
||||
},
|
||||
false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
t.Run(s.name, func(t *testing.T) {
|
||||
err := s.field.ValidateValue(context.Background(), app, s.record())
|
||||
|
||||
hasErr := err != nil
|
||||
if hasErr != s.expectError {
|
||||
t.Fatalf("Expected hasErr %v, got %v (%v)", s.expectError, hasErr, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRelationFieldValidateSettings(t *testing.T) {
|
||||
testDefaultFieldIdValidation(t, core.FieldTypeRelation)
|
||||
testDefaultFieldNameValidation(t, core.FieldTypeRelation)
|
||||
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
demo1, err := app.FindCollectionByNameOrId("demo1")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
scenarios := []struct {
|
||||
name string
|
||||
field func(col *core.Collection) *core.RelationField
|
||||
expectErrors []string
|
||||
}{
|
||||
{
|
||||
"zero minimal",
|
||||
func(col *core.Collection) *core.RelationField {
|
||||
return &core.RelationField{
|
||||
Id: "test",
|
||||
Name: "test",
|
||||
}
|
||||
},
|
||||
[]string{"collectionId"},
|
||||
},
|
||||
{
|
||||
"invalid collectionId",
|
||||
func(col *core.Collection) *core.RelationField {
|
||||
return &core.RelationField{
|
||||
Id: "test",
|
||||
Name: "test",
|
||||
CollectionId: demo1.Name,
|
||||
}
|
||||
},
|
||||
[]string{"collectionId"},
|
||||
},
|
||||
{
|
||||
"valid collectionId",
|
||||
func(col *core.Collection) *core.RelationField {
|
||||
return &core.RelationField{
|
||||
Id: "test",
|
||||
Name: "test",
|
||||
CollectionId: demo1.Id,
|
||||
}
|
||||
},
|
||||
[]string{},
|
||||
},
|
||||
{
|
||||
"base->view",
|
||||
func(col *core.Collection) *core.RelationField {
|
||||
return &core.RelationField{
|
||||
Id: "test",
|
||||
Name: "test",
|
||||
CollectionId: "v9gwnfh02gjq1q0",
|
||||
}
|
||||
},
|
||||
[]string{"collectionId"},
|
||||
},
|
||||
{
|
||||
"view->view",
|
||||
func(col *core.Collection) *core.RelationField {
|
||||
col.Type = core.CollectionTypeView
|
||||
return &core.RelationField{
|
||||
Id: "test",
|
||||
Name: "test",
|
||||
CollectionId: "v9gwnfh02gjq1q0",
|
||||
}
|
||||
},
|
||||
[]string{},
|
||||
},
|
||||
{
|
||||
"MinSelect < 0",
|
||||
func(col *core.Collection) *core.RelationField {
|
||||
return &core.RelationField{
|
||||
Id: "test",
|
||||
Name: "test",
|
||||
CollectionId: demo1.Id,
|
||||
MinSelect: -1,
|
||||
}
|
||||
},
|
||||
[]string{"minSelect"},
|
||||
},
|
||||
{
|
||||
"MinSelect > 0",
|
||||
func(col *core.Collection) *core.RelationField {
|
||||
return &core.RelationField{
|
||||
Id: "test",
|
||||
Name: "test",
|
||||
CollectionId: demo1.Id,
|
||||
MinSelect: 1,
|
||||
}
|
||||
},
|
||||
[]string{"maxSelect"},
|
||||
},
|
||||
{
|
||||
"MaxSelect < MinSelect",
|
||||
func(col *core.Collection) *core.RelationField {
|
||||
return &core.RelationField{
|
||||
Id: "test",
|
||||
Name: "test",
|
||||
CollectionId: demo1.Id,
|
||||
MinSelect: 2,
|
||||
MaxSelect: 1,
|
||||
}
|
||||
},
|
||||
[]string{"maxSelect"},
|
||||
},
|
||||
{
|
||||
"MaxSelect >= MinSelect",
|
||||
func(col *core.Collection) *core.RelationField {
|
||||
return &core.RelationField{
|
||||
Id: "test",
|
||||
Name: "test",
|
||||
CollectionId: demo1.Id,
|
||||
MinSelect: 2,
|
||||
MaxSelect: 2,
|
||||
}
|
||||
},
|
||||
[]string{},
|
||||
},
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
t.Run(s.name, func(t *testing.T) {
|
||||
collection := core.NewBaseCollection("test_collection")
|
||||
collection.Fields.GetByName("id").SetId("test") // set a dummy known id so that it can be replaced
|
||||
|
||||
field := s.field(collection)
|
||||
|
||||
collection.Fields.Add(field)
|
||||
|
||||
errs := field.ValidateSettings(context.Background(), app, collection)
|
||||
|
||||
tests.TestValidationErrors(t, errs, s.expectErrors)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRelationFieldFindSetter(t *testing.T) {
|
||||
scenarios := []struct {
|
||||
name string
|
||||
key string
|
||||
value any
|
||||
field *core.RelationField
|
||||
hasSetter bool
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
"no match",
|
||||
"example",
|
||||
"b",
|
||||
&core.RelationField{Name: "test", MaxSelect: 1},
|
||||
false,
|
||||
"",
|
||||
},
|
||||
{
|
||||
"exact match (single)",
|
||||
"test",
|
||||
"b",
|
||||
&core.RelationField{Name: "test", MaxSelect: 1},
|
||||
true,
|
||||
`"b"`,
|
||||
},
|
||||
{
|
||||
"exact match (multiple)",
|
||||
"test",
|
||||
[]string{"a", "b"},
|
||||
&core.RelationField{Name: "test", MaxSelect: 2},
|
||||
true,
|
||||
`["a","b"]`,
|
||||
},
|
||||
{
|
||||
"append (single)",
|
||||
"test+",
|
||||
"b",
|
||||
&core.RelationField{Name: "test", MaxSelect: 1},
|
||||
true,
|
||||
`"b"`,
|
||||
},
|
||||
{
|
||||
"append (multiple)",
|
||||
"test+",
|
||||
[]string{"a"},
|
||||
&core.RelationField{Name: "test", MaxSelect: 2},
|
||||
true,
|
||||
`["c","d","a"]`,
|
||||
},
|
||||
{
|
||||
"prepend (single)",
|
||||
"+test",
|
||||
"b",
|
||||
&core.RelationField{Name: "test", MaxSelect: 1},
|
||||
true,
|
||||
`"d"`, // the last of the existing values
|
||||
},
|
||||
{
|
||||
"prepend (multiple)",
|
||||
"+test",
|
||||
[]string{"a"},
|
||||
&core.RelationField{Name: "test", MaxSelect: 2},
|
||||
true,
|
||||
`["a","c","d"]`,
|
||||
},
|
||||
{
|
||||
"subtract (single)",
|
||||
"test-",
|
||||
"d",
|
||||
&core.RelationField{Name: "test", MaxSelect: 1},
|
||||
true,
|
||||
`"c"`,
|
||||
},
|
||||
{
|
||||
"subtract (multiple)",
|
||||
"test-",
|
||||
[]string{"unknown", "c"},
|
||||
&core.RelationField{Name: "test", MaxSelect: 2},
|
||||
true,
|
||||
`["d"]`,
|
||||
},
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
t.Run(s.name, func(t *testing.T) {
|
||||
collection := core.NewBaseCollection("test_collection")
|
||||
collection.Fields.Add(s.field)
|
||||
|
||||
setter := s.field.FindSetter(s.key)
|
||||
|
||||
hasSetter := setter != nil
|
||||
if hasSetter != s.hasSetter {
|
||||
t.Fatalf("Expected hasSetter %v, got %v", s.hasSetter, hasSetter)
|
||||
}
|
||||
|
||||
if !hasSetter {
|
||||
return
|
||||
}
|
||||
|
||||
record := core.NewRecord(collection)
|
||||
record.SetRaw(s.field.GetName(), []string{"c", "d"})
|
||||
|
||||
setter(record, s.value)
|
||||
|
||||
raw, err := json.Marshal(record.Get(s.field.GetName()))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
rawStr := string(raw)
|
||||
|
||||
if rawStr != s.expected {
|
||||
t.Fatalf("Expected %q, got %q", s.expected, rawStr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
275
core/field_select.go
Normal file
275
core/field_select.go
Normal file
|
@ -0,0 +1,275 @@
|
|||
package core
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql/driver"
|
||||
"slices"
|
||||
|
||||
validation "github.com/go-ozzo/ozzo-validation/v4"
|
||||
"github.com/pocketbase/pocketbase/tools/list"
|
||||
"github.com/pocketbase/pocketbase/tools/types"
|
||||
)
|
||||
|
||||
func init() {
|
||||
Fields[FieldTypeSelect] = func() Field {
|
||||
return &SelectField{}
|
||||
}
|
||||
}
|
||||
|
||||
const FieldTypeSelect = "select"
|
||||
|
||||
var (
|
||||
_ Field = (*SelectField)(nil)
|
||||
_ MultiValuer = (*SelectField)(nil)
|
||||
_ DriverValuer = (*SelectField)(nil)
|
||||
_ SetterFinder = (*SelectField)(nil)
|
||||
)
|
||||
|
||||
// SelectField defines "select" type field for storing single or
|
||||
// multiple string values from a predefined list.
|
||||
//
|
||||
// Requires the Values option to be set.
|
||||
//
|
||||
// If MaxSelect is not set or <= 1, then the field value is expected to be a single Values element.
|
||||
//
|
||||
// If MaxSelect is > 1, then the field value is expected to be a subset of Values slice.
|
||||
//
|
||||
// The respective zero record field value is either empty string (single) or empty string slice (multiple).
|
||||
//
|
||||
// ---
|
||||
//
|
||||
// The following additional setter keys are available:
|
||||
//
|
||||
// - "fieldName+" - append one or more values to the existing record one. For example:
|
||||
//
|
||||
// record.Set("roles+", []string{"new1", "new2"}) // []string{"old1", "old2", "new1", "new2"}
|
||||
//
|
||||
// - "+fieldName" - prepend one or more values to the existing record one. For example:
|
||||
//
|
||||
// record.Set("+roles", []string{"new1", "new2"}) // []string{"new1", "new2", "old1", "old2"}
|
||||
//
|
||||
// - "fieldName-" - subtract one or more values from the existing record one. For example:
|
||||
//
|
||||
// record.Set("roles-", "old1") // []string{"old2"}
|
||||
type SelectField struct {
|
||||
// Name (required) is the unique name of the field.
|
||||
Name string `form:"name" json:"name"`
|
||||
|
||||
// Id is the unique stable field identifier.
|
||||
//
|
||||
// It is automatically generated from the name when adding to a collection FieldsList.
|
||||
Id string `form:"id" json:"id"`
|
||||
|
||||
// System prevents the renaming and removal of the field.
|
||||
System bool `form:"system" json:"system"`
|
||||
|
||||
// Hidden hides the field from the API response.
|
||||
Hidden bool `form:"hidden" json:"hidden"`
|
||||
|
||||
// Presentable hints the Dashboard UI to use the underlying
|
||||
// field record value in the relation preview label.
|
||||
Presentable bool `form:"presentable" json:"presentable"`
|
||||
|
||||
// ---
|
||||
|
||||
// Values specifies the list of accepted values.
|
||||
Values []string `form:"values" json:"values"`
|
||||
|
||||
// MaxSelect specifies the max allowed selected values.
|
||||
//
|
||||
// For multiple select the value must be > 1, otherwise fallbacks to single (default).
|
||||
MaxSelect int `form:"maxSelect" json:"maxSelect"`
|
||||
|
||||
// Required will require the field value to be non-empty.
|
||||
Required bool `form:"required" json:"required"`
|
||||
}
|
||||
|
||||
// Type implements [Field.Type] interface method.
|
||||
func (f *SelectField) Type() string {
|
||||
return FieldTypeSelect
|
||||
}
|
||||
|
||||
// GetId implements [Field.GetId] interface method.
|
||||
func (f *SelectField) GetId() string {
|
||||
return f.Id
|
||||
}
|
||||
|
||||
// SetId implements [Field.SetId] interface method.
|
||||
func (f *SelectField) SetId(id string) {
|
||||
f.Id = id
|
||||
}
|
||||
|
||||
// GetName implements [Field.GetName] interface method.
|
||||
func (f *SelectField) GetName() string {
|
||||
return f.Name
|
||||
}
|
||||
|
||||
// SetName implements [Field.SetName] interface method.
|
||||
func (f *SelectField) SetName(name string) {
|
||||
f.Name = name
|
||||
}
|
||||
|
||||
// GetSystem implements [Field.GetSystem] interface method.
|
||||
func (f *SelectField) GetSystem() bool {
|
||||
return f.System
|
||||
}
|
||||
|
||||
// SetSystem implements [Field.SetSystem] interface method.
|
||||
func (f *SelectField) SetSystem(system bool) {
|
||||
f.System = system
|
||||
}
|
||||
|
||||
// GetHidden implements [Field.GetHidden] interface method.
|
||||
func (f *SelectField) GetHidden() bool {
|
||||
return f.Hidden
|
||||
}
|
||||
|
||||
// SetHidden implements [Field.SetHidden] interface method.
|
||||
func (f *SelectField) SetHidden(hidden bool) {
|
||||
f.Hidden = hidden
|
||||
}
|
||||
|
||||
// IsMultiple implements [MultiValuer] interface and checks whether the
|
||||
// current field options support multiple values.
|
||||
func (f *SelectField) IsMultiple() bool {
|
||||
return f.MaxSelect > 1
|
||||
}
|
||||
|
||||
// ColumnType implements [Field.ColumnType] interface method.
|
||||
func (f *SelectField) ColumnType(app App) string {
|
||||
if f.IsMultiple() {
|
||||
return "JSON DEFAULT '[]' NOT NULL"
|
||||
}
|
||||
|
||||
return "TEXT DEFAULT '' NOT NULL"
|
||||
}
|
||||
|
||||
// PrepareValue implements [Field.PrepareValue] interface method.
|
||||
func (f *SelectField) PrepareValue(record *Record, raw any) (any, error) {
|
||||
return f.normalizeValue(raw), nil
|
||||
}
|
||||
|
||||
func (f *SelectField) normalizeValue(raw any) any {
|
||||
val := list.ToUniqueStringSlice(raw)
|
||||
|
||||
if !f.IsMultiple() {
|
||||
if len(val) > 0 {
|
||||
return val[len(val)-1] // the last selected
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
return val
|
||||
}
|
||||
|
||||
// DriverValue implements the [DriverValuer] interface.
|
||||
func (f *SelectField) DriverValue(record *Record) (driver.Value, error) {
|
||||
val := list.ToUniqueStringSlice(record.GetRaw(f.Name))
|
||||
|
||||
if !f.IsMultiple() {
|
||||
if len(val) > 0 {
|
||||
return val[len(val)-1], nil // the last selected
|
||||
}
|
||||
return "", nil
|
||||
}
|
||||
|
||||
// serialize as json string array
|
||||
return append(types.JSONArray[string]{}, val...), nil
|
||||
}
|
||||
|
||||
// ValidateValue implements [Field.ValidateValue] interface method.
|
||||
func (f *SelectField) ValidateValue(ctx context.Context, app App, record *Record) error {
|
||||
normalizedVal := list.ToUniqueStringSlice(record.GetRaw(f.Name))
|
||||
if len(normalizedVal) == 0 {
|
||||
if f.Required {
|
||||
return validation.ErrRequired
|
||||
}
|
||||
return nil // nothing to check
|
||||
}
|
||||
|
||||
maxSelect := max(f.MaxSelect, 1)
|
||||
|
||||
// check max selected items
|
||||
if len(normalizedVal) > maxSelect {
|
||||
return validation.NewError("validation_too_many_values", "Select no more than {{.maxSelect}}").
|
||||
SetParams(map[string]any{"maxSelect": maxSelect})
|
||||
}
|
||||
|
||||
// check against the allowed values
|
||||
for _, val := range normalizedVal {
|
||||
if !slices.Contains(f.Values, val) {
|
||||
return validation.NewError("validation_invalid_value", "Invalid value {{.value}}").
|
||||
SetParams(map[string]any{"value": val})
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidateSettings implements [Field.ValidateSettings] interface method.
|
||||
func (f *SelectField) ValidateSettings(ctx context.Context, app App, collection *Collection) error {
|
||||
max := len(f.Values)
|
||||
if max == 0 {
|
||||
max = 1
|
||||
}
|
||||
|
||||
return validation.ValidateStruct(f,
|
||||
validation.Field(&f.Id, validation.By(DefaultFieldIdValidationRule)),
|
||||
validation.Field(&f.Name, validation.By(DefaultFieldNameValidationRule)),
|
||||
validation.Field(&f.Values, validation.Required),
|
||||
validation.Field(&f.MaxSelect, validation.Min(0), validation.Max(max)),
|
||||
)
|
||||
}
|
||||
|
||||
// FindSetter implements the [SetterFinder] interface.
|
||||
func (f *SelectField) FindSetter(key string) SetterFunc {
|
||||
switch key {
|
||||
case f.Name:
|
||||
return f.setValue
|
||||
case "+" + f.Name:
|
||||
return f.prependValue
|
||||
case f.Name + "+":
|
||||
return f.appendValue
|
||||
case f.Name + "-":
|
||||
return f.subtractValue
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (f *SelectField) setValue(record *Record, raw any) {
|
||||
record.SetRaw(f.Name, f.normalizeValue(raw))
|
||||
}
|
||||
|
||||
func (f *SelectField) appendValue(record *Record, modifierValue any) {
|
||||
val := record.GetRaw(f.Name)
|
||||
|
||||
val = append(
|
||||
list.ToUniqueStringSlice(val),
|
||||
list.ToUniqueStringSlice(modifierValue)...,
|
||||
)
|
||||
|
||||
f.setValue(record, val)
|
||||
}
|
||||
|
||||
func (f *SelectField) prependValue(record *Record, modifierValue any) {
|
||||
val := record.GetRaw(f.Name)
|
||||
|
||||
val = append(
|
||||
list.ToUniqueStringSlice(modifierValue),
|
||||
list.ToUniqueStringSlice(val)...,
|
||||
)
|
||||
|
||||
f.setValue(record, val)
|
||||
}
|
||||
|
||||
func (f *SelectField) subtractValue(record *Record, modifierValue any) {
|
||||
val := record.GetRaw(f.Name)
|
||||
|
||||
val = list.SubtractSlice(
|
||||
list.ToUniqueStringSlice(val),
|
||||
list.ToUniqueStringSlice(modifierValue),
|
||||
)
|
||||
|
||||
f.setValue(record, val)
|
||||
}
|
516
core/field_select_test.go
Normal file
516
core/field_select_test.go
Normal file
|
@ -0,0 +1,516 @@
|
|||
package core_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tests"
|
||||
"github.com/pocketbase/pocketbase/tools/types"
|
||||
)
|
||||
|
||||
func TestSelectFieldBaseMethods(t *testing.T) {
|
||||
testFieldBaseMethods(t, core.FieldTypeSelect)
|
||||
}
|
||||
|
||||
func TestSelectFieldColumnType(t *testing.T) {
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
scenarios := []struct {
|
||||
name string
|
||||
field *core.SelectField
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
"single (zero)",
|
||||
&core.SelectField{},
|
||||
"TEXT DEFAULT '' NOT NULL",
|
||||
},
|
||||
{
|
||||
"single",
|
||||
&core.SelectField{MaxSelect: 1},
|
||||
"TEXT DEFAULT '' NOT NULL",
|
||||
},
|
||||
{
|
||||
"multiple",
|
||||
&core.SelectField{MaxSelect: 2},
|
||||
"JSON DEFAULT '[]' NOT NULL",
|
||||
},
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
t.Run(s.name, func(t *testing.T) {
|
||||
if v := s.field.ColumnType(app); v != s.expected {
|
||||
t.Fatalf("Expected\n%q\ngot\n%q", s.expected, v)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSelectFieldIsMultiple(t *testing.T) {
|
||||
scenarios := []struct {
|
||||
name string
|
||||
field *core.SelectField
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
"single (zero)",
|
||||
&core.SelectField{},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"single",
|
||||
&core.SelectField{MaxSelect: 1},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"multiple (>1)",
|
||||
&core.SelectField{MaxSelect: 2},
|
||||
true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
t.Run(s.name, func(t *testing.T) {
|
||||
if v := s.field.IsMultiple(); v != s.expected {
|
||||
t.Fatalf("Expected %v, got %v", s.expected, v)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSelectFieldPrepareValue(t *testing.T) {
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
record := core.NewRecord(core.NewBaseCollection("test"))
|
||||
|
||||
scenarios := []struct {
|
||||
raw any
|
||||
field *core.SelectField
|
||||
expected string
|
||||
}{
|
||||
// single
|
||||
{nil, &core.SelectField{}, `""`},
|
||||
{"", &core.SelectField{}, `""`},
|
||||
{123, &core.SelectField{}, `"123"`},
|
||||
{"a", &core.SelectField{}, `"a"`},
|
||||
{`["a"]`, &core.SelectField{}, `"a"`},
|
||||
{[]string{}, &core.SelectField{}, `""`},
|
||||
{[]string{"a", "b"}, &core.SelectField{}, `"b"`},
|
||||
|
||||
// multiple
|
||||
{nil, &core.SelectField{MaxSelect: 2}, `[]`},
|
||||
{"", &core.SelectField{MaxSelect: 2}, `[]`},
|
||||
{123, &core.SelectField{MaxSelect: 2}, `["123"]`},
|
||||
{"a", &core.SelectField{MaxSelect: 2}, `["a"]`},
|
||||
{`["a"]`, &core.SelectField{MaxSelect: 2}, `["a"]`},
|
||||
{[]string{}, &core.SelectField{MaxSelect: 2}, `[]`},
|
||||
{[]string{"a", "b", "c"}, &core.SelectField{MaxSelect: 2}, `["a","b","c"]`},
|
||||
}
|
||||
|
||||
for i, s := range scenarios {
|
||||
t.Run(fmt.Sprintf("%d_%#v_%v", i, s.raw, s.field.IsMultiple()), func(t *testing.T) {
|
||||
v, err := s.field.PrepareValue(record, s.raw)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
vRaw, err := json.Marshal(v)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if string(vRaw) != s.expected {
|
||||
t.Fatalf("Expected %q, got %q", s.expected, vRaw)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSelectFieldDriverValue(t *testing.T) {
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
scenarios := []struct {
|
||||
raw any
|
||||
field *core.SelectField
|
||||
expected string
|
||||
}{
|
||||
// single
|
||||
{nil, &core.SelectField{}, `""`},
|
||||
{"", &core.SelectField{}, `""`},
|
||||
{123, &core.SelectField{}, `"123"`},
|
||||
{"a", &core.SelectField{}, `"a"`},
|
||||
{`["a"]`, &core.SelectField{}, `"a"`},
|
||||
{[]string{}, &core.SelectField{}, `""`},
|
||||
{[]string{"a", "b"}, &core.SelectField{}, `"b"`},
|
||||
|
||||
// multiple
|
||||
{nil, &core.SelectField{MaxSelect: 2}, `[]`},
|
||||
{"", &core.SelectField{MaxSelect: 2}, `[]`},
|
||||
{123, &core.SelectField{MaxSelect: 2}, `["123"]`},
|
||||
{"a", &core.SelectField{MaxSelect: 2}, `["a"]`},
|
||||
{`["a"]`, &core.SelectField{MaxSelect: 2}, `["a"]`},
|
||||
{[]string{}, &core.SelectField{MaxSelect: 2}, `[]`},
|
||||
{[]string{"a", "b", "c"}, &core.SelectField{MaxSelect: 2}, `["a","b","c"]`},
|
||||
}
|
||||
|
||||
for i, s := range scenarios {
|
||||
t.Run(fmt.Sprintf("%d_%#v_%v", i, s.raw, s.field.IsMultiple()), func(t *testing.T) {
|
||||
record := core.NewRecord(core.NewBaseCollection("test"))
|
||||
record.SetRaw(s.field.GetName(), s.raw)
|
||||
|
||||
v, err := s.field.DriverValue(record)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if s.field.IsMultiple() {
|
||||
_, ok := v.(types.JSONArray[string])
|
||||
if !ok {
|
||||
t.Fatalf("Expected types.JSONArray value, got %T", v)
|
||||
}
|
||||
} else {
|
||||
_, ok := v.(string)
|
||||
if !ok {
|
||||
t.Fatalf("Expected string value, got %T", v)
|
||||
}
|
||||
}
|
||||
|
||||
vRaw, err := json.Marshal(v)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if string(vRaw) != s.expected {
|
||||
t.Fatalf("Expected %q, got %q", s.expected, vRaw)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSelectFieldValidateValue(t *testing.T) {
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
collection := core.NewBaseCollection("test_collection")
|
||||
|
||||
values := []string{"a", "b", "c"}
|
||||
|
||||
scenarios := []struct {
|
||||
name string
|
||||
field *core.SelectField
|
||||
record func() *core.Record
|
||||
expectError bool
|
||||
}{
|
||||
// single
|
||||
{
|
||||
"[single] zero field value (not required)",
|
||||
&core.SelectField{Name: "test", Values: values, MaxSelect: 1},
|
||||
func() *core.Record {
|
||||
record := core.NewRecord(collection)
|
||||
record.SetRaw("test", "")
|
||||
return record
|
||||
},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"[single] zero field value (required)",
|
||||
&core.SelectField{Name: "test", Values: values, MaxSelect: 1, Required: true},
|
||||
func() *core.Record {
|
||||
record := core.NewRecord(collection)
|
||||
record.SetRaw("test", "")
|
||||
return record
|
||||
},
|
||||
true,
|
||||
},
|
||||
{
|
||||
"[single] unknown value",
|
||||
&core.SelectField{Name: "test", Values: values, MaxSelect: 1},
|
||||
func() *core.Record {
|
||||
record := core.NewRecord(collection)
|
||||
record.SetRaw("test", "unknown")
|
||||
return record
|
||||
},
|
||||
true,
|
||||
},
|
||||
{
|
||||
"[single] known value",
|
||||
&core.SelectField{Name: "test", Values: values, MaxSelect: 1},
|
||||
func() *core.Record {
|
||||
record := core.NewRecord(collection)
|
||||
record.SetRaw("test", "a")
|
||||
return record
|
||||
},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"[single] > MaxSelect",
|
||||
&core.SelectField{Name: "test", Values: values, MaxSelect: 1},
|
||||
func() *core.Record {
|
||||
record := core.NewRecord(collection)
|
||||
record.SetRaw("test", []string{"a", "b"})
|
||||
return record
|
||||
},
|
||||
true,
|
||||
},
|
||||
|
||||
// multiple
|
||||
{
|
||||
"[multiple] zero field value (not required)",
|
||||
&core.SelectField{Name: "test", Values: values, MaxSelect: 2},
|
||||
func() *core.Record {
|
||||
record := core.NewRecord(collection)
|
||||
record.SetRaw("test", []string{})
|
||||
return record
|
||||
},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"[multiple] zero field value (required)",
|
||||
&core.SelectField{Name: "test", Values: values, MaxSelect: 2, Required: true},
|
||||
func() *core.Record {
|
||||
record := core.NewRecord(collection)
|
||||
record.SetRaw("test", []string{})
|
||||
return record
|
||||
},
|
||||
true,
|
||||
},
|
||||
{
|
||||
"[multiple] unknown value",
|
||||
&core.SelectField{Name: "test", Values: values, MaxSelect: 2},
|
||||
func() *core.Record {
|
||||
record := core.NewRecord(collection)
|
||||
record.SetRaw("test", []string{"a", "unknown"})
|
||||
return record
|
||||
},
|
||||
true,
|
||||
},
|
||||
{
|
||||
"[multiple] known value",
|
||||
&core.SelectField{Name: "test", Values: values, MaxSelect: 2},
|
||||
func() *core.Record {
|
||||
record := core.NewRecord(collection)
|
||||
record.SetRaw("test", []string{"a", "b"})
|
||||
return record
|
||||
},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"[multiple] > MaxSelect",
|
||||
&core.SelectField{Name: "test", Values: values, MaxSelect: 2},
|
||||
func() *core.Record {
|
||||
record := core.NewRecord(collection)
|
||||
record.SetRaw("test", []string{"a", "b", "c"})
|
||||
return record
|
||||
},
|
||||
true,
|
||||
},
|
||||
{
|
||||
"[multiple] > MaxSelect (duplicated values)",
|
||||
&core.SelectField{Name: "test", Values: values, MaxSelect: 2},
|
||||
func() *core.Record {
|
||||
record := core.NewRecord(collection)
|
||||
record.SetRaw("test", []string{"a", "b", "b", "a"})
|
||||
return record
|
||||
},
|
||||
false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
t.Run(s.name, func(t *testing.T) {
|
||||
err := s.field.ValidateValue(context.Background(), app, s.record())
|
||||
|
||||
hasErr := err != nil
|
||||
if hasErr != s.expectError {
|
||||
t.Fatalf("Expected hasErr %v, got %v (%v)", s.expectError, hasErr, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSelectFieldValidateSettings(t *testing.T) {
|
||||
testDefaultFieldIdValidation(t, core.FieldTypeSelect)
|
||||
testDefaultFieldNameValidation(t, core.FieldTypeSelect)
|
||||
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
scenarios := []struct {
|
||||
name string
|
||||
field func() *core.SelectField
|
||||
expectErrors []string
|
||||
}{
|
||||
{
|
||||
"zero minimal",
|
||||
func() *core.SelectField {
|
||||
return &core.SelectField{
|
||||
Id: "test",
|
||||
Name: "test",
|
||||
}
|
||||
},
|
||||
[]string{"values"},
|
||||
},
|
||||
{
|
||||
"MaxSelect > Values length",
|
||||
func() *core.SelectField {
|
||||
return &core.SelectField{
|
||||
Id: "test",
|
||||
Name: "test",
|
||||
Values: []string{"a", "b"},
|
||||
MaxSelect: 3,
|
||||
}
|
||||
},
|
||||
[]string{"maxSelect"},
|
||||
},
|
||||
{
|
||||
"MaxSelect <= Values length",
|
||||
func() *core.SelectField {
|
||||
return &core.SelectField{
|
||||
Id: "test",
|
||||
Name: "test",
|
||||
Values: []string{"a", "b"},
|
||||
MaxSelect: 2,
|
||||
}
|
||||
},
|
||||
[]string{},
|
||||
},
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
t.Run(s.name, func(t *testing.T) {
|
||||
field := s.field()
|
||||
|
||||
collection := core.NewBaseCollection("test_collection")
|
||||
collection.Fields.Add(field)
|
||||
|
||||
errs := field.ValidateSettings(context.Background(), app, collection)
|
||||
|
||||
tests.TestValidationErrors(t, errs, s.expectErrors)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSelectFieldFindSetter(t *testing.T) {
|
||||
values := []string{"a", "b", "c", "d"}
|
||||
|
||||
scenarios := []struct {
|
||||
name string
|
||||
key string
|
||||
value any
|
||||
field *core.SelectField
|
||||
hasSetter bool
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
"no match",
|
||||
"example",
|
||||
"b",
|
||||
&core.SelectField{Name: "test", MaxSelect: 1, Values: values},
|
||||
false,
|
||||
"",
|
||||
},
|
||||
{
|
||||
"exact match (single)",
|
||||
"test",
|
||||
"b",
|
||||
&core.SelectField{Name: "test", MaxSelect: 1, Values: values},
|
||||
true,
|
||||
`"b"`,
|
||||
},
|
||||
{
|
||||
"exact match (multiple)",
|
||||
"test",
|
||||
[]string{"a", "b"},
|
||||
&core.SelectField{Name: "test", MaxSelect: 2, Values: values},
|
||||
true,
|
||||
`["a","b"]`,
|
||||
},
|
||||
{
|
||||
"append (single)",
|
||||
"test+",
|
||||
"b",
|
||||
&core.SelectField{Name: "test", MaxSelect: 1, Values: values},
|
||||
true,
|
||||
`"b"`,
|
||||
},
|
||||
{
|
||||
"append (multiple)",
|
||||
"test+",
|
||||
[]string{"a"},
|
||||
&core.SelectField{Name: "test", MaxSelect: 2, Values: values},
|
||||
true,
|
||||
`["c","d","a"]`,
|
||||
},
|
||||
{
|
||||
"prepend (single)",
|
||||
"+test",
|
||||
"b",
|
||||
&core.SelectField{Name: "test", MaxSelect: 1, Values: values},
|
||||
true,
|
||||
`"d"`, // the last of the existing values
|
||||
},
|
||||
{
|
||||
"prepend (multiple)",
|
||||
"+test",
|
||||
[]string{"a"},
|
||||
&core.SelectField{Name: "test", MaxSelect: 2, Values: values},
|
||||
true,
|
||||
`["a","c","d"]`,
|
||||
},
|
||||
{
|
||||
"subtract (single)",
|
||||
"test-",
|
||||
"d",
|
||||
&core.SelectField{Name: "test", MaxSelect: 1, Values: values},
|
||||
true,
|
||||
`"c"`,
|
||||
},
|
||||
{
|
||||
"subtract (multiple)",
|
||||
"test-",
|
||||
[]string{"unknown", "c"},
|
||||
&core.SelectField{Name: "test", MaxSelect: 2, Values: values},
|
||||
true,
|
||||
`["d"]`,
|
||||
},
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
t.Run(s.name, func(t *testing.T) {
|
||||
collection := core.NewBaseCollection("test_collection")
|
||||
collection.Fields.Add(s.field)
|
||||
|
||||
setter := s.field.FindSetter(s.key)
|
||||
|
||||
hasSetter := setter != nil
|
||||
if hasSetter != s.hasSetter {
|
||||
t.Fatalf("Expected hasSetter %v, got %v", s.hasSetter, hasSetter)
|
||||
}
|
||||
|
||||
if !hasSetter {
|
||||
return
|
||||
}
|
||||
|
||||
record := core.NewRecord(collection)
|
||||
record.SetRaw(s.field.GetName(), []string{"c", "d"})
|
||||
|
||||
setter(record, s.value)
|
||||
|
||||
raw, err := json.Marshal(record.Get(s.field.GetName()))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
rawStr := string(raw)
|
||||
|
||||
if rawStr != s.expected {
|
||||
t.Fatalf("Expected %q, got %q", s.expected, rawStr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
261
core/field_test.go
Normal file
261
core/field_test.go
Normal file
|
@ -0,0 +1,261 @@
|
|||
package core_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
validation "github.com/go-ozzo/ozzo-validation/v4"
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tests"
|
||||
)
|
||||
|
||||
func testFieldBaseMethods(t *testing.T, fieldType string) {
|
||||
factory, ok := core.Fields[fieldType]
|
||||
if !ok {
|
||||
t.Fatalf("Missing %q field factory", fieldType)
|
||||
}
|
||||
|
||||
f := factory()
|
||||
if f == nil {
|
||||
t.Fatal("Expected non-nil Field instance")
|
||||
}
|
||||
|
||||
t.Run("type", func(t *testing.T) {
|
||||
if v := f.Type(); v != fieldType {
|
||||
t.Fatalf("Expected type %q, got %q", fieldType, v)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("id", func(t *testing.T) {
|
||||
testValues := []string{"new_id", ""}
|
||||
for _, expected := range testValues {
|
||||
f.SetId(expected)
|
||||
if v := f.GetId(); v != expected {
|
||||
t.Fatalf("Expected id %q, got %q", expected, v)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("name", func(t *testing.T) {
|
||||
testValues := []string{"new_name", ""}
|
||||
for _, expected := range testValues {
|
||||
f.SetName(expected)
|
||||
if v := f.GetName(); v != expected {
|
||||
t.Fatalf("Expected name %q, got %q", expected, v)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("system", func(t *testing.T) {
|
||||
testValues := []bool{false, true}
|
||||
for _, expected := range testValues {
|
||||
f.SetSystem(expected)
|
||||
if v := f.GetSystem(); v != expected {
|
||||
t.Fatalf("Expected system %v, got %v", expected, v)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("hidden", func(t *testing.T) {
|
||||
testValues := []bool{false, true}
|
||||
for _, expected := range testValues {
|
||||
f.SetHidden(expected)
|
||||
if v := f.GetHidden(); v != expected {
|
||||
t.Fatalf("Expected hidden %v, got %v", expected, v)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func testDefaultFieldIdValidation(t *testing.T, fieldType string) {
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
collection := core.NewBaseCollection("test_collection")
|
||||
|
||||
scenarios := []struct {
|
||||
name string
|
||||
field func() core.Field
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
"empty value",
|
||||
func() core.Field {
|
||||
f := core.Fields[fieldType]()
|
||||
return f
|
||||
},
|
||||
true,
|
||||
},
|
||||
{
|
||||
"invalid length",
|
||||
func() core.Field {
|
||||
f := core.Fields[fieldType]()
|
||||
f.SetId(strings.Repeat("a", 101))
|
||||
return f
|
||||
},
|
||||
true,
|
||||
},
|
||||
{
|
||||
"valid length",
|
||||
func() core.Field {
|
||||
f := core.Fields[fieldType]()
|
||||
f.SetId(strings.Repeat("a", 100))
|
||||
return f
|
||||
},
|
||||
false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
t.Run("[id] "+s.name, func(t *testing.T) {
|
||||
errs, _ := s.field().ValidateSettings(context.Background(), app, collection).(validation.Errors)
|
||||
|
||||
hasErr := errs["id"] != nil
|
||||
if hasErr != s.expectError {
|
||||
t.Fatalf("Expected hasErr %v, got %v", s.expectError, hasErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func testDefaultFieldNameValidation(t *testing.T, fieldType string) {
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
collection := core.NewBaseCollection("test_collection")
|
||||
|
||||
scenarios := []struct {
|
||||
name string
|
||||
field func() core.Field
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
"empty value",
|
||||
func() core.Field {
|
||||
f := core.Fields[fieldType]()
|
||||
return f
|
||||
},
|
||||
true,
|
||||
},
|
||||
{
|
||||
"invalid length",
|
||||
func() core.Field {
|
||||
f := core.Fields[fieldType]()
|
||||
f.SetName(strings.Repeat("a", 101))
|
||||
return f
|
||||
},
|
||||
true,
|
||||
},
|
||||
{
|
||||
"valid length",
|
||||
func() core.Field {
|
||||
f := core.Fields[fieldType]()
|
||||
f.SetName(strings.Repeat("a", 100))
|
||||
return f
|
||||
},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"invalid regex",
|
||||
func() core.Field {
|
||||
f := core.Fields[fieldType]()
|
||||
f.SetName("test(")
|
||||
return f
|
||||
},
|
||||
true,
|
||||
},
|
||||
{
|
||||
"valid regex",
|
||||
func() core.Field {
|
||||
f := core.Fields[fieldType]()
|
||||
f.SetName("test_123")
|
||||
return f
|
||||
},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"_via_",
|
||||
func() core.Field {
|
||||
f := core.Fields[fieldType]()
|
||||
f.SetName("a_via_b")
|
||||
return f
|
||||
},
|
||||
true,
|
||||
},
|
||||
{
|
||||
"system reserved - null",
|
||||
func() core.Field {
|
||||
f := core.Fields[fieldType]()
|
||||
f.SetName("null")
|
||||
return f
|
||||
},
|
||||
true,
|
||||
},
|
||||
{
|
||||
"system reserved - false",
|
||||
func() core.Field {
|
||||
f := core.Fields[fieldType]()
|
||||
f.SetName("false")
|
||||
return f
|
||||
},
|
||||
true,
|
||||
},
|
||||
{
|
||||
"system reserved - true",
|
||||
func() core.Field {
|
||||
f := core.Fields[fieldType]()
|
||||
f.SetName("true")
|
||||
return f
|
||||
},
|
||||
true,
|
||||
},
|
||||
{
|
||||
"system reserved - _rowid_",
|
||||
func() core.Field {
|
||||
f := core.Fields[fieldType]()
|
||||
f.SetName("_rowid_")
|
||||
return f
|
||||
},
|
||||
true,
|
||||
},
|
||||
{
|
||||
"system reserved - expand",
|
||||
func() core.Field {
|
||||
f := core.Fields[fieldType]()
|
||||
f.SetName("expand")
|
||||
return f
|
||||
},
|
||||
true,
|
||||
},
|
||||
{
|
||||
"system reserved - collectionId",
|
||||
func() core.Field {
|
||||
f := core.Fields[fieldType]()
|
||||
f.SetName("collectionId")
|
||||
return f
|
||||
},
|
||||
true,
|
||||
},
|
||||
{
|
||||
"system reserved - collectionName",
|
||||
func() core.Field {
|
||||
f := core.Fields[fieldType]()
|
||||
f.SetName("collectionName")
|
||||
return f
|
||||
},
|
||||
true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
t.Run("[name] "+s.name, func(t *testing.T) {
|
||||
errs, _ := s.field().ValidateSettings(context.Background(), app, collection).(validation.Errors)
|
||||
|
||||
hasErr := errs["name"] != nil
|
||||
if hasErr != s.expectError {
|
||||
t.Fatalf("Expected hasErr %v, got %v", s.expectError, hasErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
367
core/field_text.go
Normal file
367
core/field_text.go
Normal file
|
@ -0,0 +1,367 @@
|
|||
package core
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
validation "github.com/go-ozzo/ozzo-validation/v4"
|
||||
"github.com/pocketbase/dbx"
|
||||
"github.com/pocketbase/pocketbase/core/validators"
|
||||
"github.com/pocketbase/pocketbase/tools/security"
|
||||
"github.com/spf13/cast"
|
||||
)
|
||||
|
||||
func init() {
|
||||
Fields[FieldTypeText] = func() Field {
|
||||
return &TextField{}
|
||||
}
|
||||
}
|
||||
|
||||
const FieldTypeText = "text"
|
||||
|
||||
const autogenerateModifier = ":autogenerate"
|
||||
|
||||
var (
|
||||
_ Field = (*TextField)(nil)
|
||||
_ SetterFinder = (*TextField)(nil)
|
||||
_ RecordInterceptor = (*TextField)(nil)
|
||||
)
|
||||
|
||||
// TextField defines "text" type field for storing any string value.
|
||||
//
|
||||
// The respective zero record field value is empty string.
|
||||
//
|
||||
// The following additional setter keys are available:
|
||||
//
|
||||
// - "fieldName:autogenerate" - autogenerate field value if AutogeneratePattern is set. For example:
|
||||
//
|
||||
// record.Set("slug:autogenerate", "") // [random value]
|
||||
// record.Set("slug:autogenerate", "abc-") // abc-[random value]
|
||||
type TextField struct {
|
||||
// Name (required) is the unique name of the field.
|
||||
Name string `form:"name" json:"name"`
|
||||
|
||||
// Id is the unique stable field identifier.
|
||||
//
|
||||
// It is automatically generated from the name when adding to a collection FieldsList.
|
||||
Id string `form:"id" json:"id"`
|
||||
|
||||
// System prevents the renaming and removal of the field.
|
||||
System bool `form:"system" json:"system"`
|
||||
|
||||
// Hidden hides the field from the API response.
|
||||
Hidden bool `form:"hidden" json:"hidden"`
|
||||
|
||||
// Presentable hints the Dashboard UI to use the underlying
|
||||
// field record value in the relation preview label.
|
||||
Presentable bool `form:"presentable" json:"presentable"`
|
||||
|
||||
// ---
|
||||
|
||||
// Min specifies the minimum required string characters.
|
||||
//
|
||||
// if zero value, no min limit is applied.
|
||||
Min int `form:"min" json:"min"`
|
||||
|
||||
// Max specifies the maximum allowed string characters.
|
||||
//
|
||||
// If zero, a default limit of 5000 is applied.
|
||||
Max int `form:"max" json:"max"`
|
||||
|
||||
// Pattern specifies an optional regex pattern to match against the field value.
|
||||
//
|
||||
// Leave it empty to skip the pattern check.
|
||||
Pattern string `form:"pattern" json:"pattern"`
|
||||
|
||||
// AutogeneratePattern specifies an optional regex pattern that could
|
||||
// be used to generate random string from it and set it automatically
|
||||
// on record create if no explicit value is set or when the `:autogenerate` modifier is used.
|
||||
//
|
||||
// Note: the generated value still needs to satisfy min, max, pattern (if set)
|
||||
AutogeneratePattern string `form:"autogeneratePattern" json:"autogeneratePattern"`
|
||||
|
||||
// Required will require the field value to be non-empty string.
|
||||
Required bool `form:"required" json:"required"`
|
||||
|
||||
// PrimaryKey will mark the field as primary key.
|
||||
//
|
||||
// A single collection can have only 1 field marked as primary key.
|
||||
PrimaryKey bool `form:"primaryKey" json:"primaryKey"`
|
||||
}
|
||||
|
||||
// Type implements [Field.Type] interface method.
|
||||
func (f *TextField) Type() string {
|
||||
return FieldTypeText
|
||||
}
|
||||
|
||||
// GetId implements [Field.GetId] interface method.
|
||||
func (f *TextField) GetId() string {
|
||||
return f.Id
|
||||
}
|
||||
|
||||
// SetId implements [Field.SetId] interface method.
|
||||
func (f *TextField) SetId(id string) {
|
||||
f.Id = id
|
||||
}
|
||||
|
||||
// GetName implements [Field.GetName] interface method.
|
||||
func (f *TextField) GetName() string {
|
||||
return f.Name
|
||||
}
|
||||
|
||||
// SetName implements [Field.SetName] interface method.
|
||||
func (f *TextField) SetName(name string) {
|
||||
f.Name = name
|
||||
}
|
||||
|
||||
// GetSystem implements [Field.GetSystem] interface method.
|
||||
func (f *TextField) GetSystem() bool {
|
||||
return f.System
|
||||
}
|
||||
|
||||
// SetSystem implements [Field.SetSystem] interface method.
|
||||
func (f *TextField) SetSystem(system bool) {
|
||||
f.System = system
|
||||
}
|
||||
|
||||
// GetHidden implements [Field.GetHidden] interface method.
|
||||
func (f *TextField) GetHidden() bool {
|
||||
return f.Hidden
|
||||
}
|
||||
|
||||
// SetHidden implements [Field.SetHidden] interface method.
|
||||
func (f *TextField) SetHidden(hidden bool) {
|
||||
f.Hidden = hidden
|
||||
}
|
||||
|
||||
// ColumnType implements [Field.ColumnType] interface method.
|
||||
func (f *TextField) ColumnType(app App) string {
|
||||
if f.PrimaryKey {
|
||||
// note: the default is just a last resort fallback to avoid empty
|
||||
// string values in case the record was inserted with raw sql and
|
||||
// it is not actually used when operating with the db abstraction
|
||||
return "TEXT PRIMARY KEY DEFAULT ('r'||lower(hex(randomblob(7)))) NOT NULL"
|
||||
}
|
||||
|
||||
return "TEXT DEFAULT '' NOT NULL"
|
||||
}
|
||||
|
||||
// PrepareValue implements [Field.PrepareValue] interface method.
|
||||
func (f *TextField) PrepareValue(record *Record, raw any) (any, error) {
|
||||
return cast.ToString(raw), nil
|
||||
}
|
||||
|
||||
var forbiddenPKChars = []string{"/", "\\"}
|
||||
|
||||
// ValidateValue implements [Field.ValidateValue] interface method.
|
||||
func (f *TextField) ValidateValue(ctx context.Context, app App, record *Record) error {
|
||||
newVal, ok := record.GetRaw(f.Name).(string)
|
||||
if !ok {
|
||||
return validators.ErrUnsupportedValueType
|
||||
}
|
||||
|
||||
if f.PrimaryKey {
|
||||
// disallow PK change
|
||||
if !record.IsNew() {
|
||||
oldVal := record.LastSavedPK()
|
||||
if oldVal != newVal {
|
||||
return validation.NewError("validation_pk_change", "The record primary key cannot be changed.")
|
||||
}
|
||||
if oldVal != "" {
|
||||
// no need to further validate because the id can't be updated
|
||||
// and because the id could have been inserted manually by migration from another system
|
||||
// that may not comply with the user defined PocketBase validations
|
||||
return nil
|
||||
}
|
||||
} else {
|
||||
// disallow PK special characters no matter of the Pattern validator to minimize
|
||||
// side-effects when the primary key is used for example in a directory path
|
||||
for _, c := range forbiddenPKChars {
|
||||
if strings.Contains(newVal, c) {
|
||||
return validation.NewError("validation_pk_forbidden", "The record primary key contains forbidden characters.").
|
||||
SetParams(map[string]any{"forbidden": c})
|
||||
}
|
||||
}
|
||||
|
||||
// this technically shouldn't be necessarily but again to
|
||||
// minimize misuse of the Pattern validator that could cause
|
||||
// side-effects on some platforms check for duplicates in a case-insensitive manner
|
||||
//
|
||||
// (@todo eventually may get replaced in the future with a system unique constraint to avoid races or wrapping the request in a transaction)
|
||||
if f.Pattern != defaultLowercaseRecordIdPattern {
|
||||
var exists int
|
||||
err := app.ConcurrentDB().
|
||||
Select("(1)").
|
||||
From(record.TableName()).
|
||||
Where(dbx.NewExp("id = {:id} COLLATE NOCASE", dbx.Params{"id": newVal})).
|
||||
Limit(1).
|
||||
Row(&exists)
|
||||
if exists > 0 || (err != nil && !errors.Is(err, sql.ErrNoRows)) {
|
||||
return validation.NewError("validation_pk_invalid", "The record primary key is invalid or already exists.")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return f.ValidatePlainValue(newVal)
|
||||
}
|
||||
|
||||
// ValidatePlainValue validates the provided string against the field options.
|
||||
func (f *TextField) ValidatePlainValue(value string) error {
|
||||
if f.Required || f.PrimaryKey {
|
||||
if err := validation.Required.Validate(value); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if value == "" {
|
||||
return nil // nothing to check
|
||||
}
|
||||
|
||||
// note: casted to []rune to count multi-byte chars as one
|
||||
length := len([]rune(value))
|
||||
|
||||
if f.Min > 0 && length < f.Min {
|
||||
return validation.NewError("validation_min_text_constraint", "Must be at least {{.min}} character(s)").
|
||||
SetParams(map[string]any{"min": f.Min})
|
||||
}
|
||||
|
||||
max := f.Max
|
||||
if max == 0 {
|
||||
max = 5000
|
||||
}
|
||||
|
||||
if max > 0 && length > max {
|
||||
return validation.NewError("validation_max_text_constraint", "Must be no more than {{.max}} character(s)").
|
||||
SetParams(map[string]any{"max": max})
|
||||
}
|
||||
|
||||
if f.Pattern != "" {
|
||||
match, _ := regexp.MatchString(f.Pattern, value)
|
||||
if !match {
|
||||
return validation.NewError("validation_invalid_format", "Invalid value format")
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidateSettings implements [Field.ValidateSettings] interface method.
|
||||
func (f *TextField) ValidateSettings(ctx context.Context, app App, collection *Collection) error {
|
||||
return validation.ValidateStruct(f,
|
||||
validation.Field(&f.Id, validation.By(DefaultFieldIdValidationRule)),
|
||||
validation.Field(&f.Name,
|
||||
validation.By(DefaultFieldNameValidationRule),
|
||||
validation.When(f.PrimaryKey, validation.In(idColumn).Error(`The primary key must be named "id".`)),
|
||||
),
|
||||
validation.Field(&f.PrimaryKey, validation.By(f.checkOtherFieldsForPK(collection))),
|
||||
validation.Field(&f.Min, validation.Min(0), validation.Max(maxSafeJSONInt)),
|
||||
validation.Field(&f.Max, validation.Min(f.Min), validation.Max(maxSafeJSONInt)),
|
||||
validation.Field(&f.Pattern, validation.When(f.PrimaryKey, validation.Required), validation.By(validators.IsRegex)),
|
||||
validation.Field(&f.Hidden, validation.When(f.PrimaryKey, validation.Empty)),
|
||||
validation.Field(&f.Required, validation.When(f.PrimaryKey, validation.Required)),
|
||||
validation.Field(&f.AutogeneratePattern, validation.By(validators.IsRegex), validation.By(f.checkAutogeneratePattern)),
|
||||
)
|
||||
}
|
||||
|
||||
func (f *TextField) checkOtherFieldsForPK(collection *Collection) validation.RuleFunc {
|
||||
return func(value any) error {
|
||||
v, _ := value.(bool)
|
||||
if !v {
|
||||
return nil // not a pk
|
||||
}
|
||||
|
||||
totalPrimaryKeys := 0
|
||||
for _, field := range collection.Fields {
|
||||
if text, ok := field.(*TextField); ok && text.PrimaryKey {
|
||||
totalPrimaryKeys++
|
||||
}
|
||||
|
||||
if totalPrimaryKeys > 1 {
|
||||
return validation.NewError("validation_unsupported_composite_pk", "Composite PKs are not supported and the collection must have only 1 PK.")
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (f *TextField) checkAutogeneratePattern(value any) error {
|
||||
v, _ := value.(string)
|
||||
if v == "" {
|
||||
return nil // nothing to check
|
||||
}
|
||||
|
||||
// run 10 tests to check for conflicts with the other field validators
|
||||
for i := 0; i < 10; i++ {
|
||||
generated, err := security.RandomStringByRegex(v)
|
||||
if err != nil {
|
||||
return validation.NewError("validation_invalid_autogenerate_pattern", err.Error())
|
||||
}
|
||||
|
||||
// (loosely) check whether the generated pattern satisfies the current field settings
|
||||
if err := f.ValidatePlainValue(generated); err != nil {
|
||||
return validation.NewError(
|
||||
"validation_invalid_autogenerate_pattern_value",
|
||||
fmt.Sprintf("The provided autogenerate pattern could produce invalid field values, ex.: %q", generated),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Intercept implements the [RecordInterceptor] interface.
|
||||
func (f *TextField) Intercept(
|
||||
ctx context.Context,
|
||||
app App,
|
||||
record *Record,
|
||||
actionName string,
|
||||
actionFunc func() error,
|
||||
) error {
|
||||
// set autogenerated value if missing for new records
|
||||
switch actionName {
|
||||
case InterceptorActionValidate, InterceptorActionCreate:
|
||||
if f.AutogeneratePattern != "" && f.hasZeroValue(record) && record.IsNew() {
|
||||
v, err := security.RandomStringByRegex(f.AutogeneratePattern)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to autogenerate %q value: %w", f.Name, err)
|
||||
}
|
||||
record.SetRaw(f.Name, v)
|
||||
}
|
||||
}
|
||||
|
||||
return actionFunc()
|
||||
}
|
||||
|
||||
func (f *TextField) hasZeroValue(record *Record) bool {
|
||||
v, _ := record.GetRaw(f.Name).(string)
|
||||
return v == ""
|
||||
}
|
||||
|
||||
// FindSetter implements the [SetterFinder] interface.
|
||||
func (f *TextField) FindSetter(key string) SetterFunc {
|
||||
switch key {
|
||||
case f.Name:
|
||||
return func(record *Record, raw any) {
|
||||
record.SetRaw(f.Name, cast.ToString(raw))
|
||||
}
|
||||
case f.Name + autogenerateModifier:
|
||||
return func(record *Record, raw any) {
|
||||
v := cast.ToString(raw)
|
||||
|
||||
if f.AutogeneratePattern != "" {
|
||||
generated, _ := security.RandomStringByRegex(f.AutogeneratePattern)
|
||||
v += generated
|
||||
}
|
||||
|
||||
record.SetRaw(f.Name, v)
|
||||
}
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
654
core/field_text_test.go
Normal file
654
core/field_text_test.go
Normal file
|
@ -0,0 +1,654 @@
|
|||
package core_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tests"
|
||||
)
|
||||
|
||||
func TestTextFieldBaseMethods(t *testing.T) {
|
||||
testFieldBaseMethods(t, core.FieldTypeText)
|
||||
}
|
||||
|
||||
func TestTextFieldColumnType(t *testing.T) {
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
f := &core.TextField{}
|
||||
|
||||
expected := "TEXT DEFAULT '' NOT NULL"
|
||||
|
||||
if v := f.ColumnType(app); v != expected {
|
||||
t.Fatalf("Expected\n%q\ngot\n%q", expected, v)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTextFieldPrepareValue(t *testing.T) {
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
f := &core.TextField{}
|
||||
record := core.NewRecord(core.NewBaseCollection("test"))
|
||||
|
||||
scenarios := []struct {
|
||||
raw any
|
||||
expected string
|
||||
}{
|
||||
{"", ""},
|
||||
{"test", "test"},
|
||||
{false, "false"},
|
||||
{true, "true"},
|
||||
{123.456, "123.456"},
|
||||
}
|
||||
|
||||
for i, s := range scenarios {
|
||||
t.Run(fmt.Sprintf("%d_%#v", i, s.raw), func(t *testing.T) {
|
||||
v, err := f.PrepareValue(record, s.raw)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
vStr, ok := v.(string)
|
||||
if !ok {
|
||||
t.Fatalf("Expected string instance, got %T", v)
|
||||
}
|
||||
|
||||
if vStr != s.expected {
|
||||
t.Fatalf("Expected %q, got %q", s.expected, v)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestTextFieldValidateValue(t *testing.T) {
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
collection, err := app.FindCollectionByNameOrId("demo1")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
existingRecord, err := app.FindFirstRecordByFilter(collection, "id != ''")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
scenarios := []struct {
|
||||
name string
|
||||
field *core.TextField
|
||||
record func() *core.Record
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
"invalid raw value",
|
||||
&core.TextField{Name: "test"},
|
||||
func() *core.Record {
|
||||
record := core.NewRecord(collection)
|
||||
record.SetRaw("test", 123)
|
||||
return record
|
||||
},
|
||||
true,
|
||||
},
|
||||
{
|
||||
"zero field value (not required)",
|
||||
&core.TextField{Name: "test", Pattern: `\d+`, Min: 10, Max: 100}, // other fields validators should be ignored
|
||||
func() *core.Record {
|
||||
record := core.NewRecord(collection)
|
||||
record.SetRaw("test", "")
|
||||
return record
|
||||
},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"zero field value (required)",
|
||||
&core.TextField{Name: "test", Required: true},
|
||||
func() *core.Record {
|
||||
record := core.NewRecord(collection)
|
||||
record.SetRaw("test", "")
|
||||
return record
|
||||
},
|
||||
true,
|
||||
},
|
||||
{
|
||||
"non-zero field value (required)",
|
||||
&core.TextField{Name: "test", Required: true},
|
||||
func() *core.Record {
|
||||
record := core.NewRecord(collection)
|
||||
record.SetRaw("test", "abc")
|
||||
return record
|
||||
},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"special forbidden character / (non-primaryKey)",
|
||||
&core.TextField{Name: "test", PrimaryKey: false},
|
||||
func() *core.Record {
|
||||
record := core.NewRecord(collection)
|
||||
record.SetRaw("test", "/")
|
||||
return record
|
||||
},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"special forbidden character \\ (non-primaryKey)",
|
||||
&core.TextField{Name: "test", PrimaryKey: false},
|
||||
func() *core.Record {
|
||||
record := core.NewRecord(collection)
|
||||
record.SetRaw("test", "\\")
|
||||
return record
|
||||
},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"special forbidden character / (primaryKey)",
|
||||
&core.TextField{Name: "test", PrimaryKey: true},
|
||||
func() *core.Record {
|
||||
record := core.NewRecord(collection)
|
||||
record.SetRaw("test", "/")
|
||||
return record
|
||||
},
|
||||
true,
|
||||
},
|
||||
{
|
||||
"special forbidden character \\ (primaryKey)",
|
||||
&core.TextField{Name: "test", PrimaryKey: true},
|
||||
func() *core.Record {
|
||||
record := core.NewRecord(collection)
|
||||
record.SetRaw("test", "\\")
|
||||
return record
|
||||
},
|
||||
true,
|
||||
},
|
||||
{
|
||||
"zero field value (primaryKey)",
|
||||
&core.TextField{Name: "test", PrimaryKey: true},
|
||||
func() *core.Record {
|
||||
record := core.NewRecord(collection)
|
||||
record.SetRaw("test", "")
|
||||
return record
|
||||
},
|
||||
true,
|
||||
},
|
||||
{
|
||||
"non-zero field value (primaryKey)",
|
||||
&core.TextField{Name: "test", PrimaryKey: true},
|
||||
func() *core.Record {
|
||||
record := core.NewRecord(collection)
|
||||
record.SetRaw("test", "abcd")
|
||||
return record
|
||||
},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"case-insensitive duplicated primary key check",
|
||||
&core.TextField{Name: "test", PrimaryKey: true},
|
||||
func() *core.Record {
|
||||
record := core.NewRecord(collection)
|
||||
record.SetRaw("test", strings.ToUpper(existingRecord.Id))
|
||||
return record
|
||||
},
|
||||
true,
|
||||
},
|
||||
{
|
||||
"< min",
|
||||
&core.TextField{Name: "test", Min: 4},
|
||||
func() *core.Record {
|
||||
record := core.NewRecord(collection)
|
||||
record.SetRaw("test", "абв") // multi-byte
|
||||
return record
|
||||
},
|
||||
true,
|
||||
},
|
||||
{
|
||||
">= min",
|
||||
&core.TextField{Name: "test", Min: 3},
|
||||
func() *core.Record {
|
||||
record := core.NewRecord(collection)
|
||||
record.SetRaw("test", "абв") // multi-byte
|
||||
return record
|
||||
},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"> default max",
|
||||
&core.TextField{Name: "test"},
|
||||
func() *core.Record {
|
||||
record := core.NewRecord(collection)
|
||||
record.SetRaw("test", strings.Repeat("a", 5001))
|
||||
return record
|
||||
},
|
||||
true,
|
||||
},
|
||||
{
|
||||
"<= default max",
|
||||
&core.TextField{Name: "test"},
|
||||
func() *core.Record {
|
||||
record := core.NewRecord(collection)
|
||||
record.SetRaw("test", strings.Repeat("a", 500))
|
||||
return record
|
||||
},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"> max",
|
||||
&core.TextField{Name: "test", Max: 2},
|
||||
func() *core.Record {
|
||||
record := core.NewRecord(collection)
|
||||
record.SetRaw("test", "абв") // multi-byte
|
||||
return record
|
||||
},
|
||||
true,
|
||||
},
|
||||
{
|
||||
"<= max",
|
||||
&core.TextField{Name: "test", Min: 3},
|
||||
func() *core.Record {
|
||||
record := core.NewRecord(collection)
|
||||
record.SetRaw("test", "абв") // multi-byte
|
||||
return record
|
||||
},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"mismatched pattern",
|
||||
&core.TextField{Name: "test", Pattern: `\d+`},
|
||||
func() *core.Record {
|
||||
record := core.NewRecord(collection)
|
||||
record.SetRaw("test", "abc")
|
||||
return record
|
||||
},
|
||||
true,
|
||||
},
|
||||
{
|
||||
"matched pattern",
|
||||
&core.TextField{Name: "test", Pattern: `\d+`},
|
||||
func() *core.Record {
|
||||
record := core.NewRecord(collection)
|
||||
record.SetRaw("test", "123")
|
||||
return record
|
||||
},
|
||||
false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
t.Run(s.name, func(t *testing.T) {
|
||||
err := s.field.ValidateValue(context.Background(), app, s.record())
|
||||
|
||||
hasErr := err != nil
|
||||
if hasErr != s.expectError {
|
||||
t.Fatalf("Expected hasErr %v, got %v (%v)", s.expectError, hasErr, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestTextFieldValidateSettings(t *testing.T) {
|
||||
testDefaultFieldIdValidation(t, core.FieldTypeText)
|
||||
testDefaultFieldNameValidation(t, core.FieldTypeText)
|
||||
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
scenarios := []struct {
|
||||
name string
|
||||
field func() *core.TextField
|
||||
expectErrors []string
|
||||
}{
|
||||
{
|
||||
"zero minimal",
|
||||
func() *core.TextField {
|
||||
return &core.TextField{
|
||||
Id: "test",
|
||||
Name: "test",
|
||||
}
|
||||
},
|
||||
[]string{},
|
||||
},
|
||||
{
|
||||
"primaryKey without required",
|
||||
func() *core.TextField {
|
||||
return &core.TextField{
|
||||
Id: "test",
|
||||
Name: "id",
|
||||
PrimaryKey: true,
|
||||
Pattern: `\d+`,
|
||||
}
|
||||
},
|
||||
[]string{"required"},
|
||||
},
|
||||
{
|
||||
"primaryKey without pattern",
|
||||
func() *core.TextField {
|
||||
return &core.TextField{
|
||||
Id: "test",
|
||||
Name: "id",
|
||||
PrimaryKey: true,
|
||||
Required: true,
|
||||
}
|
||||
},
|
||||
[]string{"pattern"},
|
||||
},
|
||||
{
|
||||
"primaryKey with hidden",
|
||||
func() *core.TextField {
|
||||
return &core.TextField{
|
||||
Id: "test",
|
||||
Name: "id",
|
||||
Required: true,
|
||||
PrimaryKey: true,
|
||||
Hidden: true,
|
||||
Pattern: `\d+`,
|
||||
}
|
||||
},
|
||||
[]string{"hidden"},
|
||||
},
|
||||
{
|
||||
"primaryKey with name != id",
|
||||
func() *core.TextField {
|
||||
return &core.TextField{
|
||||
Id: "test",
|
||||
Name: "test",
|
||||
PrimaryKey: true,
|
||||
Required: true,
|
||||
Pattern: `\d+`,
|
||||
}
|
||||
},
|
||||
[]string{"name"},
|
||||
},
|
||||
{
|
||||
"multiple primaryKey fields",
|
||||
func() *core.TextField {
|
||||
return &core.TextField{
|
||||
Id: "test2",
|
||||
Name: "id",
|
||||
PrimaryKey: true,
|
||||
Pattern: `\d+`,
|
||||
Required: true,
|
||||
}
|
||||
},
|
||||
[]string{"primaryKey"},
|
||||
},
|
||||
{
|
||||
"invalid pattern",
|
||||
func() *core.TextField {
|
||||
return &core.TextField{
|
||||
Id: "test2",
|
||||
Name: "id",
|
||||
Pattern: `(invalid`,
|
||||
}
|
||||
},
|
||||
[]string{"pattern"},
|
||||
},
|
||||
{
|
||||
"valid pattern",
|
||||
func() *core.TextField {
|
||||
return &core.TextField{
|
||||
Id: "test2",
|
||||
Name: "id",
|
||||
Pattern: `\d+`,
|
||||
}
|
||||
},
|
||||
[]string{},
|
||||
},
|
||||
{
|
||||
"invalid autogeneratePattern",
|
||||
func() *core.TextField {
|
||||
return &core.TextField{
|
||||
Id: "test2",
|
||||
Name: "id",
|
||||
AutogeneratePattern: `(invalid`,
|
||||
}
|
||||
},
|
||||
[]string{"autogeneratePattern"},
|
||||
},
|
||||
{
|
||||
"valid autogeneratePattern",
|
||||
func() *core.TextField {
|
||||
return &core.TextField{
|
||||
Id: "test2",
|
||||
Name: "id",
|
||||
AutogeneratePattern: `[a-z]+`,
|
||||
}
|
||||
},
|
||||
[]string{},
|
||||
},
|
||||
{
|
||||
"conflicting pattern and autogeneratePattern",
|
||||
func() *core.TextField {
|
||||
return &core.TextField{
|
||||
Id: "test2",
|
||||
Name: "id",
|
||||
Pattern: `\d+`,
|
||||
AutogeneratePattern: `[a-z]+`,
|
||||
}
|
||||
},
|
||||
[]string{"autogeneratePattern"},
|
||||
},
|
||||
{
|
||||
"Max > safe json int",
|
||||
func() *core.TextField {
|
||||
return &core.TextField{
|
||||
Id: "test",
|
||||
Name: "test",
|
||||
Max: 1 << 53,
|
||||
}
|
||||
},
|
||||
[]string{"max"},
|
||||
},
|
||||
{
|
||||
"Max < 0",
|
||||
func() *core.TextField {
|
||||
return &core.TextField{
|
||||
Id: "test",
|
||||
Name: "test",
|
||||
Max: -1,
|
||||
}
|
||||
},
|
||||
[]string{"max"},
|
||||
},
|
||||
{
|
||||
"Min > safe json int",
|
||||
func() *core.TextField {
|
||||
return &core.TextField{
|
||||
Id: "test",
|
||||
Name: "test",
|
||||
Min: 1 << 53,
|
||||
}
|
||||
},
|
||||
[]string{"min"},
|
||||
},
|
||||
{
|
||||
"Min < 0",
|
||||
func() *core.TextField {
|
||||
return &core.TextField{
|
||||
Id: "test",
|
||||
Name: "test",
|
||||
Min: -1,
|
||||
}
|
||||
},
|
||||
[]string{"min"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
t.Run(s.name, func(t *testing.T) {
|
||||
field := s.field()
|
||||
|
||||
collection := core.NewBaseCollection("test_collection")
|
||||
collection.Fields.GetByName("id").SetId("test") // set a dummy known id so that it can be replaced
|
||||
collection.Fields.Add(field)
|
||||
|
||||
errs := field.ValidateSettings(context.Background(), app, collection)
|
||||
|
||||
tests.TestValidationErrors(t, errs, s.expectErrors)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestTextFieldAutogenerate(t *testing.T) {
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
collection := core.NewBaseCollection("test_collection")
|
||||
|
||||
scenarios := []struct {
|
||||
name string
|
||||
actionName string
|
||||
field *core.TextField
|
||||
record func() *core.Record
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
"non-matching action",
|
||||
core.InterceptorActionUpdate,
|
||||
&core.TextField{Name: "test", AutogeneratePattern: "abc"},
|
||||
func() *core.Record {
|
||||
return core.NewRecord(collection)
|
||||
},
|
||||
"",
|
||||
},
|
||||
{
|
||||
"matching action (create)",
|
||||
core.InterceptorActionCreate,
|
||||
&core.TextField{Name: "test", AutogeneratePattern: "abc"},
|
||||
func() *core.Record {
|
||||
return core.NewRecord(collection)
|
||||
},
|
||||
"abc",
|
||||
},
|
||||
{
|
||||
"matching action (validate)",
|
||||
core.InterceptorActionValidate,
|
||||
&core.TextField{Name: "test", AutogeneratePattern: "abc"},
|
||||
func() *core.Record {
|
||||
return core.NewRecord(collection)
|
||||
},
|
||||
"abc",
|
||||
},
|
||||
{
|
||||
"existing non-zero value",
|
||||
core.InterceptorActionCreate,
|
||||
&core.TextField{Name: "test", AutogeneratePattern: "abc"},
|
||||
func() *core.Record {
|
||||
record := core.NewRecord(collection)
|
||||
record.SetRaw("test", "123")
|
||||
return record
|
||||
},
|
||||
"123",
|
||||
},
|
||||
{
|
||||
"non-new record",
|
||||
core.InterceptorActionValidate,
|
||||
&core.TextField{Name: "test", AutogeneratePattern: "abc"},
|
||||
func() *core.Record {
|
||||
record := core.NewRecord(collection)
|
||||
record.Id = "test"
|
||||
record.PostScan()
|
||||
return record
|
||||
},
|
||||
"",
|
||||
},
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
t.Run(s.name, func(t *testing.T) {
|
||||
actionCalls := 0
|
||||
record := s.record()
|
||||
|
||||
err := s.field.Intercept(context.Background(), app, record, s.actionName, func() error {
|
||||
actionCalls++
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if actionCalls != 1 {
|
||||
t.Fatalf("Expected actionCalls %d, got %d", 1, actionCalls)
|
||||
}
|
||||
|
||||
v := record.GetString(s.field.GetName())
|
||||
if v != s.expected {
|
||||
t.Fatalf("Expected value %q, got %q", s.expected, v)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestTextFieldFindSetter(t *testing.T) {
|
||||
scenarios := []struct {
|
||||
name string
|
||||
key string
|
||||
value any
|
||||
field *core.TextField
|
||||
hasSetter bool
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
"no match",
|
||||
"example",
|
||||
"abc",
|
||||
&core.TextField{Name: "test", AutogeneratePattern: "test"},
|
||||
false,
|
||||
"",
|
||||
},
|
||||
{
|
||||
"exact match",
|
||||
"test",
|
||||
"abc",
|
||||
&core.TextField{Name: "test", AutogeneratePattern: "test"},
|
||||
true,
|
||||
"abc",
|
||||
},
|
||||
{
|
||||
"autogenerate modifier",
|
||||
"test:autogenerate",
|
||||
"abc",
|
||||
&core.TextField{Name: "test", AutogeneratePattern: "test"},
|
||||
true,
|
||||
"abctest",
|
||||
},
|
||||
{
|
||||
"autogenerate modifier without AutogeneratePattern option",
|
||||
"test:autogenerate",
|
||||
"abc",
|
||||
&core.TextField{Name: "test"},
|
||||
true,
|
||||
"abc",
|
||||
},
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
t.Run(s.name, func(t *testing.T) {
|
||||
collection := core.NewBaseCollection("test_collection")
|
||||
collection.Fields.Add(s.field)
|
||||
|
||||
setter := s.field.FindSetter(s.key)
|
||||
|
||||
hasSetter := setter != nil
|
||||
if hasSetter != s.hasSetter {
|
||||
t.Fatalf("Expected hasSetter %v, got %v", s.hasSetter, hasSetter)
|
||||
}
|
||||
|
||||
if !hasSetter {
|
||||
return
|
||||
}
|
||||
|
||||
record := core.NewRecord(collection)
|
||||
|
||||
setter(record, s.value)
|
||||
|
||||
result := record.GetString(s.field.Name)
|
||||
|
||||
if result != s.expected {
|
||||
t.Fatalf("Expected %q, got %q", s.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
168
core/field_url.go
Normal file
168
core/field_url.go
Normal file
|
@ -0,0 +1,168 @@
|
|||
package core
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/url"
|
||||
"slices"
|
||||
|
||||
validation "github.com/go-ozzo/ozzo-validation/v4"
|
||||
"github.com/go-ozzo/ozzo-validation/v4/is"
|
||||
"github.com/pocketbase/pocketbase/core/validators"
|
||||
"github.com/spf13/cast"
|
||||
)
|
||||
|
||||
func init() {
|
||||
Fields[FieldTypeURL] = func() Field {
|
||||
return &URLField{}
|
||||
}
|
||||
}
|
||||
|
||||
const FieldTypeURL = "url"
|
||||
|
||||
var _ Field = (*URLField)(nil)
|
||||
|
||||
// URLField defines "url" type field for storing a single URL string value.
|
||||
//
|
||||
// The respective zero record field value is empty string.
|
||||
type URLField struct {
|
||||
// Name (required) is the unique name of the field.
|
||||
Name string `form:"name" json:"name"`
|
||||
|
||||
// Id is the unique stable field identifier.
|
||||
//
|
||||
// It is automatically generated from the name when adding to a collection FieldsList.
|
||||
Id string `form:"id" json:"id"`
|
||||
|
||||
// System prevents the renaming and removal of the field.
|
||||
System bool `form:"system" json:"system"`
|
||||
|
||||
// Hidden hides the field from the API response.
|
||||
Hidden bool `form:"hidden" json:"hidden"`
|
||||
|
||||
// Presentable hints the Dashboard UI to use the underlying
|
||||
// field record value in the relation preview label.
|
||||
Presentable bool `form:"presentable" json:"presentable"`
|
||||
|
||||
// ---
|
||||
|
||||
// ExceptDomains will require the URL domain to NOT be included in the listed ones.
|
||||
//
|
||||
// This validator can be set only if OnlyDomains is empty.
|
||||
ExceptDomains []string `form:"exceptDomains" json:"exceptDomains"`
|
||||
|
||||
// OnlyDomains will require the URL domain to be included in the listed ones.
|
||||
//
|
||||
// This validator can be set only if ExceptDomains is empty.
|
||||
OnlyDomains []string `form:"onlyDomains" json:"onlyDomains"`
|
||||
|
||||
// Required will require the field value to be non-empty URL string.
|
||||
Required bool `form:"required" json:"required"`
|
||||
}
|
||||
|
||||
// Type implements [Field.Type] interface method.
|
||||
func (f *URLField) Type() string {
|
||||
return FieldTypeURL
|
||||
}
|
||||
|
||||
// GetId implements [Field.GetId] interface method.
|
||||
func (f *URLField) GetId() string {
|
||||
return f.Id
|
||||
}
|
||||
|
||||
// SetId implements [Field.SetId] interface method.
|
||||
func (f *URLField) SetId(id string) {
|
||||
f.Id = id
|
||||
}
|
||||
|
||||
// GetName implements [Field.GetName] interface method.
|
||||
func (f *URLField) GetName() string {
|
||||
return f.Name
|
||||
}
|
||||
|
||||
// SetName implements [Field.SetName] interface method.
|
||||
func (f *URLField) SetName(name string) {
|
||||
f.Name = name
|
||||
}
|
||||
|
||||
// GetSystem implements [Field.GetSystem] interface method.
|
||||
func (f *URLField) GetSystem() bool {
|
||||
return f.System
|
||||
}
|
||||
|
||||
// SetSystem implements [Field.SetSystem] interface method.
|
||||
func (f *URLField) SetSystem(system bool) {
|
||||
f.System = system
|
||||
}
|
||||
|
||||
// GetHidden implements [Field.GetHidden] interface method.
|
||||
func (f *URLField) GetHidden() bool {
|
||||
return f.Hidden
|
||||
}
|
||||
|
||||
// SetHidden implements [Field.SetHidden] interface method.
|
||||
func (f *URLField) SetHidden(hidden bool) {
|
||||
f.Hidden = hidden
|
||||
}
|
||||
|
||||
// ColumnType implements [Field.ColumnType] interface method.
|
||||
func (f *URLField) ColumnType(app App) string {
|
||||
return "TEXT DEFAULT '' NOT NULL"
|
||||
}
|
||||
|
||||
// PrepareValue implements [Field.PrepareValue] interface method.
|
||||
func (f *URLField) PrepareValue(record *Record, raw any) (any, error) {
|
||||
return cast.ToString(raw), nil
|
||||
}
|
||||
|
||||
// ValidateValue implements [Field.ValidateValue] interface method.
|
||||
func (f *URLField) ValidateValue(ctx context.Context, app App, record *Record) error {
|
||||
val, ok := record.GetRaw(f.Name).(string)
|
||||
if !ok {
|
||||
return validators.ErrUnsupportedValueType
|
||||
}
|
||||
|
||||
if f.Required {
|
||||
if err := validation.Required.Validate(val); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if val == "" {
|
||||
return nil // nothing to check
|
||||
}
|
||||
|
||||
if is.URL.Validate(val) != nil {
|
||||
return validation.NewError("validation_invalid_url", "Must be a valid url")
|
||||
}
|
||||
|
||||
// extract host/domain
|
||||
u, _ := url.Parse(val)
|
||||
|
||||
// only domains check
|
||||
if len(f.OnlyDomains) > 0 && !slices.Contains(f.OnlyDomains, u.Host) {
|
||||
return validation.NewError("validation_url_domain_not_allowed", "Url domain is not allowed")
|
||||
}
|
||||
|
||||
// except domains check
|
||||
if len(f.ExceptDomains) > 0 && slices.Contains(f.ExceptDomains, u.Host) {
|
||||
return validation.NewError("validation_url_domain_not_allowed", "Url domain is not allowed")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidateSettings implements [Field.ValidateSettings] interface method.
|
||||
func (f *URLField) ValidateSettings(ctx context.Context, app App, collection *Collection) error {
|
||||
return validation.ValidateStruct(f,
|
||||
validation.Field(&f.Id, validation.By(DefaultFieldIdValidationRule)),
|
||||
validation.Field(&f.Name, validation.By(DefaultFieldNameValidationRule)),
|
||||
validation.Field(
|
||||
&f.ExceptDomains,
|
||||
validation.When(len(f.OnlyDomains) > 0, validation.Empty).Else(validation.Each(is.Domain)),
|
||||
),
|
||||
validation.Field(
|
||||
&f.OnlyDomains,
|
||||
validation.When(len(f.ExceptDomains) > 0, validation.Empty).Else(validation.Each(is.Domain)),
|
||||
),
|
||||
)
|
||||
}
|
271
core/field_url_test.go
Normal file
271
core/field_url_test.go
Normal file
|
@ -0,0 +1,271 @@
|
|||
package core_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tests"
|
||||
)
|
||||
|
||||
func TestURLFieldBaseMethods(t *testing.T) {
|
||||
testFieldBaseMethods(t, core.FieldTypeURL)
|
||||
}
|
||||
|
||||
func TestURLFieldColumnType(t *testing.T) {
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
f := &core.URLField{}
|
||||
|
||||
expected := "TEXT DEFAULT '' NOT NULL"
|
||||
|
||||
if v := f.ColumnType(app); v != expected {
|
||||
t.Fatalf("Expected\n%q\ngot\n%q", expected, v)
|
||||
}
|
||||
}
|
||||
|
||||
func TestURLFieldPrepareValue(t *testing.T) {
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
f := &core.URLField{}
|
||||
record := core.NewRecord(core.NewBaseCollection("test"))
|
||||
|
||||
scenarios := []struct {
|
||||
raw any
|
||||
expected string
|
||||
}{
|
||||
{"", ""},
|
||||
{"test", "test"},
|
||||
{false, "false"},
|
||||
{true, "true"},
|
||||
{123.456, "123.456"},
|
||||
}
|
||||
|
||||
for i, s := range scenarios {
|
||||
t.Run(fmt.Sprintf("%d_%#v", i, s.raw), func(t *testing.T) {
|
||||
v, err := f.PrepareValue(record, s.raw)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
vStr, ok := v.(string)
|
||||
if !ok {
|
||||
t.Fatalf("Expected string instance, got %T", v)
|
||||
}
|
||||
|
||||
if vStr != s.expected {
|
||||
t.Fatalf("Expected %q, got %q", s.expected, v)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestURLFieldValidateValue(t *testing.T) {
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
collection := core.NewBaseCollection("test_collection")
|
||||
|
||||
scenarios := []struct {
|
||||
name string
|
||||
field *core.URLField
|
||||
record func() *core.Record
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
"invalid raw value",
|
||||
&core.URLField{Name: "test"},
|
||||
func() *core.Record {
|
||||
record := core.NewRecord(collection)
|
||||
record.SetRaw("test", 123)
|
||||
return record
|
||||
},
|
||||
true,
|
||||
},
|
||||
{
|
||||
"zero field value (not required)",
|
||||
&core.URLField{Name: "test"},
|
||||
func() *core.Record {
|
||||
record := core.NewRecord(collection)
|
||||
record.SetRaw("test", "")
|
||||
return record
|
||||
},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"zero field value (required)",
|
||||
&core.URLField{Name: "test", Required: true},
|
||||
func() *core.Record {
|
||||
record := core.NewRecord(collection)
|
||||
record.SetRaw("test", "")
|
||||
return record
|
||||
},
|
||||
true,
|
||||
},
|
||||
{
|
||||
"non-zero field value (required)",
|
||||
&core.URLField{Name: "test", Required: true},
|
||||
func() *core.Record {
|
||||
record := core.NewRecord(collection)
|
||||
record.SetRaw("test", "https://example.com")
|
||||
return record
|
||||
},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"invalid url",
|
||||
&core.URLField{Name: "test"},
|
||||
func() *core.Record {
|
||||
record := core.NewRecord(collection)
|
||||
record.SetRaw("test", "invalid")
|
||||
return record
|
||||
},
|
||||
true,
|
||||
},
|
||||
{
|
||||
"failed onlyDomains",
|
||||
&core.URLField{Name: "test", OnlyDomains: []string{"example.org", "example.net"}},
|
||||
func() *core.Record {
|
||||
record := core.NewRecord(collection)
|
||||
record.SetRaw("test", "https://example.com")
|
||||
return record
|
||||
},
|
||||
true,
|
||||
},
|
||||
{
|
||||
"success onlyDomains",
|
||||
&core.URLField{Name: "test", OnlyDomains: []string{"example.org", "example.com"}},
|
||||
func() *core.Record {
|
||||
record := core.NewRecord(collection)
|
||||
record.SetRaw("test", "https://example.com")
|
||||
return record
|
||||
},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"failed exceptDomains",
|
||||
&core.URLField{Name: "test", ExceptDomains: []string{"example.org", "example.com"}},
|
||||
func() *core.Record {
|
||||
record := core.NewRecord(collection)
|
||||
record.SetRaw("test", "https://example.com")
|
||||
return record
|
||||
},
|
||||
true,
|
||||
},
|
||||
{
|
||||
"success exceptDomains",
|
||||
&core.URLField{Name: "test", ExceptDomains: []string{"example.org", "example.net"}},
|
||||
func() *core.Record {
|
||||
record := core.NewRecord(collection)
|
||||
record.SetRaw("test", "https://example.com")
|
||||
return record
|
||||
},
|
||||
false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
t.Run(s.name, func(t *testing.T) {
|
||||
err := s.field.ValidateValue(context.Background(), app, s.record())
|
||||
|
||||
hasErr := err != nil
|
||||
if hasErr != s.expectError {
|
||||
t.Fatalf("Expected hasErr %v, got %v (%v)", s.expectError, hasErr, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestURLFieldValidateSettings(t *testing.T) {
|
||||
testDefaultFieldIdValidation(t, core.FieldTypeURL)
|
||||
testDefaultFieldNameValidation(t, core.FieldTypeURL)
|
||||
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
collection := core.NewBaseCollection("test_collection")
|
||||
|
||||
scenarios := []struct {
|
||||
name string
|
||||
field func() *core.URLField
|
||||
expectErrors []string
|
||||
}{
|
||||
{
|
||||
"zero minimal",
|
||||
func() *core.URLField {
|
||||
return &core.URLField{
|
||||
Id: "test",
|
||||
Name: "test",
|
||||
}
|
||||
},
|
||||
[]string{},
|
||||
},
|
||||
{
|
||||
"both onlyDomains and exceptDomains",
|
||||
func() *core.URLField {
|
||||
return &core.URLField{
|
||||
Id: "test",
|
||||
Name: "test",
|
||||
OnlyDomains: []string{"example.com"},
|
||||
ExceptDomains: []string{"example.org"},
|
||||
}
|
||||
},
|
||||
[]string{"onlyDomains", "exceptDomains"},
|
||||
},
|
||||
{
|
||||
"invalid onlyDomains",
|
||||
func() *core.URLField {
|
||||
return &core.URLField{
|
||||
Id: "test",
|
||||
Name: "test",
|
||||
OnlyDomains: []string{"example.com", "invalid"},
|
||||
}
|
||||
},
|
||||
[]string{"onlyDomains"},
|
||||
},
|
||||
{
|
||||
"valid onlyDomains",
|
||||
func() *core.URLField {
|
||||
return &core.URLField{
|
||||
Id: "test",
|
||||
Name: "test",
|
||||
OnlyDomains: []string{"example.com", "example.org"},
|
||||
}
|
||||
},
|
||||
[]string{},
|
||||
},
|
||||
{
|
||||
"invalid exceptDomains",
|
||||
func() *core.URLField {
|
||||
return &core.URLField{
|
||||
Id: "test",
|
||||
Name: "test",
|
||||
ExceptDomains: []string{"example.com", "invalid"},
|
||||
}
|
||||
},
|
||||
[]string{"exceptDomains"},
|
||||
},
|
||||
{
|
||||
"valid exceptDomains",
|
||||
func() *core.URLField {
|
||||
return &core.URLField{
|
||||
Id: "test",
|
||||
Name: "test",
|
||||
ExceptDomains: []string{"example.com", "example.org"},
|
||||
}
|
||||
},
|
||||
[]string{},
|
||||
},
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
t.Run(s.name, func(t *testing.T) {
|
||||
errs := s.field().ValidateSettings(context.Background(), app, collection)
|
||||
|
||||
tests.TestValidationErrors(t, errs, s.expectErrors)
|
||||
})
|
||||
}
|
||||
}
|
388
core/fields_list.go
Normal file
388
core/fields_list.go
Normal file
|
@ -0,0 +1,388 @@
|
|||
package core
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"slices"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
// NewFieldsList creates a new FieldsList instance with the provided fields.
|
||||
func NewFieldsList(fields ...Field) FieldsList {
|
||||
l := make(FieldsList, 0, len(fields))
|
||||
|
||||
for _, f := range fields {
|
||||
l.add(-1, f)
|
||||
}
|
||||
|
||||
return l
|
||||
}
|
||||
|
||||
// FieldsList defines a Collection slice of fields.
|
||||
type FieldsList []Field
|
||||
|
||||
// Clone creates a deep clone of the current list.
|
||||
func (l FieldsList) Clone() (FieldsList, error) {
|
||||
copyRaw, err := json.Marshal(l)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result := FieldsList{}
|
||||
if err := json.Unmarshal(copyRaw, &result); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// FieldNames returns a slice with the name of all list fields.
|
||||
func (l FieldsList) FieldNames() []string {
|
||||
result := make([]string, len(l))
|
||||
|
||||
for i, field := range l {
|
||||
result[i] = field.GetName()
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// AsMap returns a map with all registered list field.
|
||||
// The returned map is indexed with each field name.
|
||||
func (l FieldsList) AsMap() map[string]Field {
|
||||
result := make(map[string]Field, len(l))
|
||||
|
||||
for _, field := range l {
|
||||
result[field.GetName()] = field
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// GetById returns a single field by its id.
|
||||
func (l FieldsList) GetById(fieldId string) Field {
|
||||
for _, field := range l {
|
||||
if field.GetId() == fieldId {
|
||||
return field
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetByName returns a single field by its name.
|
||||
func (l FieldsList) GetByName(fieldName string) Field {
|
||||
for _, field := range l {
|
||||
if field.GetName() == fieldName {
|
||||
return field
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveById removes a single field by its id.
|
||||
//
|
||||
// This method does nothing if field with the specified id doesn't exist.
|
||||
func (l *FieldsList) RemoveById(fieldId string) {
|
||||
fields := *l
|
||||
for i, field := range fields {
|
||||
if field.GetId() == fieldId {
|
||||
*l = append(fields[:i], fields[i+1:]...)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// RemoveByName removes a single field by its name.
|
||||
//
|
||||
// This method does nothing if field with the specified name doesn't exist.
|
||||
func (l *FieldsList) RemoveByName(fieldName string) {
|
||||
fields := *l
|
||||
for i, field := range fields {
|
||||
if field.GetName() == fieldName {
|
||||
*l = append(fields[:i], fields[i+1:]...)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Add adds one or more fields to the current list.
|
||||
//
|
||||
// By default this method will try to REPLACE existing fields with
|
||||
// the new ones by their id or by their name if the new field doesn't have an explicit id.
|
||||
//
|
||||
// If no matching existing field is found, it will APPEND the field to the end of the list.
|
||||
//
|
||||
// In all cases, if any of the new fields don't have an explicit id it will auto generate a default one for them
|
||||
// (the id value doesn't really matter and it is mostly used as a stable identifier in case of a field rename).
|
||||
func (l *FieldsList) Add(fields ...Field) {
|
||||
for _, f := range fields {
|
||||
l.add(-1, f)
|
||||
}
|
||||
}
|
||||
|
||||
// AddAt is the same as Add but insert/move the fields at the specific position.
|
||||
//
|
||||
// If pos < 0, then this method acts the same as calling Add.
|
||||
//
|
||||
// If pos > FieldsList total items, then the specified fields are inserted/moved at the end of the list.
|
||||
func (l *FieldsList) AddAt(pos int, fields ...Field) {
|
||||
total := len(*l)
|
||||
|
||||
for i, f := range fields {
|
||||
if pos < 0 {
|
||||
l.add(-1, f)
|
||||
} else if pos > total {
|
||||
l.add(total+i, f)
|
||||
} else {
|
||||
l.add(pos+i, f)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// AddMarshaledJSON parses the provided raw json data and adds the
|
||||
// found fields into the current list (following the same rule as the Add method).
|
||||
//
|
||||
// The rawJSON argument could be one of:
|
||||
// - serialized array of field objects
|
||||
// - single field object.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// l.AddMarshaledJSON([]byte{`{"type":"text", name: "test"}`})
|
||||
// l.AddMarshaledJSON([]byte{`[{"type":"text", name: "test1"}, {"type":"text", name: "test2"}]`})
|
||||
func (l *FieldsList) AddMarshaledJSON(rawJSON []byte) error {
|
||||
extractedFields, err := marshaledJSONtoFieldsList(rawJSON)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
l.Add(extractedFields...)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddMarshaledJSONAt is the same as AddMarshaledJSON but insert/move the fields at the specific position.
|
||||
//
|
||||
// If pos < 0, then this method acts the same as calling AddMarshaledJSON.
|
||||
//
|
||||
// If pos > FieldsList total items, then the specified fields are inserted/moved at the end of the list.
|
||||
func (l *FieldsList) AddMarshaledJSONAt(pos int, rawJSON []byte) error {
|
||||
extractedFields, err := marshaledJSONtoFieldsList(rawJSON)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
l.AddAt(pos, extractedFields...)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func marshaledJSONtoFieldsList(rawJSON []byte) (FieldsList, error) {
|
||||
extractedFields := FieldsList{}
|
||||
|
||||
// nothing to add
|
||||
if len(rawJSON) == 0 {
|
||||
return extractedFields, nil
|
||||
}
|
||||
|
||||
// try to unmarshal first into a new fieds list
|
||||
// (assuming that rawJSON is array of objects)
|
||||
err := json.Unmarshal(rawJSON, &extractedFields)
|
||||
if err != nil {
|
||||
// try again but wrap the rawJSON in []
|
||||
// (assuming that rawJSON is a single object)
|
||||
wrapped := make([]byte, 0, len(rawJSON)+2)
|
||||
wrapped = append(wrapped, '[')
|
||||
wrapped = append(wrapped, rawJSON...)
|
||||
wrapped = append(wrapped, ']')
|
||||
err = json.Unmarshal(wrapped, &extractedFields)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal the provided JSON - expects array of objects or just single object: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return extractedFields, nil
|
||||
}
|
||||
|
||||
func (l *FieldsList) add(pos int, newField Field) {
|
||||
fields := *l
|
||||
|
||||
var replaceByName bool
|
||||
var replaceInPlace bool
|
||||
|
||||
if pos < 0 {
|
||||
replaceInPlace = true
|
||||
pos = len(fields)
|
||||
} else if pos > len(fields) {
|
||||
pos = len(fields)
|
||||
}
|
||||
|
||||
newFieldId := newField.GetId()
|
||||
|
||||
// set default id
|
||||
if newFieldId == "" {
|
||||
replaceByName = true
|
||||
|
||||
baseId := newField.Type() + crc32Checksum(newField.GetName())
|
||||
newFieldId = baseId
|
||||
for i := 2; i < 1000; i++ {
|
||||
if l.GetById(newFieldId) == nil {
|
||||
break // already unique
|
||||
}
|
||||
newFieldId = baseId + strconv.Itoa(i)
|
||||
}
|
||||
newField.SetId(newFieldId)
|
||||
}
|
||||
|
||||
// try to replace existing
|
||||
for i, field := range fields {
|
||||
if replaceByName {
|
||||
if name := newField.GetName(); name != "" && field.GetName() == name {
|
||||
// reuse the original id
|
||||
newField.SetId(field.GetId())
|
||||
|
||||
if replaceInPlace {
|
||||
(*l)[i] = newField
|
||||
return
|
||||
} else {
|
||||
// remove the current field and insert it later at the specific position
|
||||
*l = slices.Delete(*l, i, i+1)
|
||||
if total := len(*l); pos > total {
|
||||
pos = total
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if field.GetId() == newFieldId {
|
||||
if replaceInPlace {
|
||||
(*l)[i] = newField
|
||||
return
|
||||
} else {
|
||||
// remove the current field and insert it later at the specific position
|
||||
*l = slices.Delete(*l, i, i+1)
|
||||
if total := len(*l); pos > total {
|
||||
pos = total
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// insert the new field
|
||||
*l = slices.Insert(*l, pos, newField)
|
||||
}
|
||||
|
||||
// String returns the string representation of the current list.
|
||||
func (l FieldsList) String() string {
|
||||
v, _ := json.Marshal(l)
|
||||
return string(v)
|
||||
}
|
||||
|
||||
type onlyFieldType struct {
|
||||
Type string `json:"type"`
|
||||
}
|
||||
|
||||
type fieldWithType struct {
|
||||
Field
|
||||
Type string `json:"type"`
|
||||
}
|
||||
|
||||
func (fwt *fieldWithType) UnmarshalJSON(data []byte) error {
|
||||
// extract the field type to init a blank factory
|
||||
t := &onlyFieldType{}
|
||||
if err := json.Unmarshal(data, t); err != nil {
|
||||
return fmt.Errorf("failed to unmarshal field type: %w", err)
|
||||
}
|
||||
|
||||
factory, ok := Fields[t.Type]
|
||||
if !ok {
|
||||
return fmt.Errorf("missing or unknown field type in %s", data)
|
||||
}
|
||||
|
||||
fwt.Type = t.Type
|
||||
fwt.Field = factory()
|
||||
|
||||
// unmarshal the rest of the data into the created field
|
||||
if err := json.Unmarshal(data, fwt.Field); err != nil {
|
||||
return fmt.Errorf("failed to unmarshal field: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// UnmarshalJSON implements [json.Unmarshaler] and
|
||||
// loads the provided json data into the current FieldsList.
|
||||
func (l *FieldsList) UnmarshalJSON(data []byte) error {
|
||||
fwts := []fieldWithType{}
|
||||
|
||||
if err := json.Unmarshal(data, &fwts); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
*l = []Field{} // reset
|
||||
|
||||
for _, fwt := range fwts {
|
||||
l.add(-1, fwt.Field)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// MarshalJSON implements the [json.Marshaler] interface.
|
||||
func (l FieldsList) MarshalJSON() ([]byte, error) {
|
||||
if l == nil {
|
||||
l = []Field{} // always init to ensure that it is serialized as empty array
|
||||
}
|
||||
|
||||
wrapper := make([]map[string]any, 0, len(l))
|
||||
|
||||
for _, f := range l {
|
||||
// precompute the json into a map so that we can append the type to a flatten object
|
||||
raw, err := json.Marshal(f)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
data := map[string]any{}
|
||||
if err := json.Unmarshal(raw, &data); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
data["type"] = f.Type()
|
||||
|
||||
wrapper = append(wrapper, data)
|
||||
}
|
||||
|
||||
return json.Marshal(wrapper)
|
||||
}
|
||||
|
||||
// Value implements the [driver.Valuer] interface.
|
||||
func (l FieldsList) Value() (driver.Value, error) {
|
||||
data, err := json.Marshal(l)
|
||||
|
||||
return string(data), err
|
||||
}
|
||||
|
||||
// Scan implements [sql.Scanner] interface to scan the provided value
|
||||
// into the current FieldsList instance.
|
||||
func (l *FieldsList) Scan(value any) error {
|
||||
var data []byte
|
||||
switch v := value.(type) {
|
||||
case nil:
|
||||
// no cast needed
|
||||
case []byte:
|
||||
data = v
|
||||
case string:
|
||||
data = []byte(v)
|
||||
default:
|
||||
return fmt.Errorf("failed to unmarshal FieldsList value %q", value)
|
||||
}
|
||||
|
||||
if len(data) == 0 {
|
||||
data = []byte("[]")
|
||||
}
|
||||
|
||||
return l.UnmarshalJSON(data)
|
||||
}
|
558
core/fields_list_test.go
Normal file
558
core/fields_list_test.go
Normal file
|
@ -0,0 +1,558 @@
|
|||
package core_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
)
|
||||
|
||||
func TestNewFieldsList(t *testing.T) {
|
||||
fields := core.NewFieldsList(
|
||||
&core.TextField{Id: "id1", Name: "test1"},
|
||||
&core.TextField{Name: "test2"},
|
||||
&core.TextField{Id: "id1", Name: "test1_new"}, // should replace the original id1 field
|
||||
)
|
||||
|
||||
if len(fields) != 2 {
|
||||
t.Fatalf("Expected 2 fields, got %d (%v)", len(fields), fields)
|
||||
}
|
||||
|
||||
for _, f := range fields {
|
||||
if f.GetId() == "" {
|
||||
t.Fatalf("Expected field id to be set, found empty id for field %v", f)
|
||||
}
|
||||
}
|
||||
|
||||
if fields[0].GetName() != "test1_new" {
|
||||
t.Fatalf("Expected field with name test1_new, got %s", fields[0].GetName())
|
||||
}
|
||||
|
||||
if fields[1].GetName() != "test2" {
|
||||
t.Fatalf("Expected field with name test2, got %s", fields[1].GetName())
|
||||
}
|
||||
}
|
||||
|
||||
func TestFieldsListClone(t *testing.T) {
|
||||
f1 := &core.TextField{Name: "test1"}
|
||||
f2 := &core.EmailField{Name: "test2"}
|
||||
s1 := core.NewFieldsList(f1, f2)
|
||||
|
||||
s2, err := s1.Clone()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
s1Str := s1.String()
|
||||
s2Str := s2.String()
|
||||
|
||||
if s1Str != s2Str {
|
||||
t.Fatalf("Expected the cloned list to be equal, got \n%v\nVS\n%v", s1, s2)
|
||||
}
|
||||
|
||||
// change in one list shouldn't result to change in the other
|
||||
// (aka. check if it is a deep clone)
|
||||
s1[0].SetName("test1_update")
|
||||
if s2[0].GetName() != "test1" {
|
||||
t.Fatalf("Expected s2 field name to not change, got %q", s2[0].GetName())
|
||||
}
|
||||
}
|
||||
|
||||
func TestFieldsListFieldNames(t *testing.T) {
|
||||
f1 := &core.TextField{Name: "test1"}
|
||||
f2 := &core.EmailField{Name: "test2"}
|
||||
testFieldsList := core.NewFieldsList(f1, f2)
|
||||
|
||||
result := testFieldsList.FieldNames()
|
||||
|
||||
expected := []string{f1.Name, f2.Name}
|
||||
|
||||
if len(result) != len(expected) {
|
||||
t.Fatalf("Expected %d slice elements, got %d\n%v", len(expected), len(result), result)
|
||||
}
|
||||
|
||||
for _, name := range expected {
|
||||
if !slices.Contains(result, name) {
|
||||
t.Fatalf("Missing name %q in %v", name, result)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestFieldsListAsMap(t *testing.T) {
|
||||
f1 := &core.TextField{Name: "test1"}
|
||||
f2 := &core.EmailField{Name: "test2"}
|
||||
testFieldsList := core.NewFieldsList(f1, f2)
|
||||
|
||||
result := testFieldsList.AsMap()
|
||||
|
||||
expectedIndexes := []string{f1.Name, f2.Name}
|
||||
|
||||
if len(result) != len(expectedIndexes) {
|
||||
t.Fatalf("Expected %d map elements, got %d\n%v", len(expectedIndexes), len(result), result)
|
||||
}
|
||||
|
||||
for _, index := range expectedIndexes {
|
||||
if _, ok := result[index]; !ok {
|
||||
t.Fatalf("Missing index %q", index)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestFieldsListGetById(t *testing.T) {
|
||||
f1 := &core.TextField{Id: "id1", Name: "test1"}
|
||||
f2 := &core.EmailField{Id: "id2", Name: "test2"}
|
||||
testFieldsList := core.NewFieldsList(f1, f2)
|
||||
|
||||
// missing field id
|
||||
result1 := testFieldsList.GetById("test1")
|
||||
if result1 != nil {
|
||||
t.Fatalf("Found unexpected field %v", result1)
|
||||
}
|
||||
|
||||
// existing field id
|
||||
result2 := testFieldsList.GetById("id2")
|
||||
if result2 == nil || result2.GetId() != "id2" {
|
||||
t.Fatalf("Cannot find field with id %q, got %v ", "id2", result2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFieldsListGetByName(t *testing.T) {
|
||||
f1 := &core.TextField{Id: "id1", Name: "test1"}
|
||||
f2 := &core.EmailField{Id: "id2", Name: "test2"}
|
||||
testFieldsList := core.NewFieldsList(f1, f2)
|
||||
|
||||
// missing field name
|
||||
result1 := testFieldsList.GetByName("id1")
|
||||
if result1 != nil {
|
||||
t.Fatalf("Found unexpected field %v", result1)
|
||||
}
|
||||
|
||||
// existing field name
|
||||
result2 := testFieldsList.GetByName("test2")
|
||||
if result2 == nil || result2.GetName() != "test2" {
|
||||
t.Fatalf("Cannot find field with name %q, got %v ", "test2", result2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFieldsListRemove(t *testing.T) {
|
||||
testFieldsList := core.NewFieldsList(
|
||||
&core.TextField{Id: "id1", Name: "test1"},
|
||||
&core.TextField{Id: "id2", Name: "test2"},
|
||||
&core.TextField{Id: "id3", Name: "test3"},
|
||||
&core.TextField{Id: "id4", Name: "test4"},
|
||||
&core.TextField{Id: "id5", Name: "test5"},
|
||||
&core.TextField{Id: "id6", Name: "test6"},
|
||||
)
|
||||
|
||||
// remove by id
|
||||
testFieldsList.RemoveById("id2")
|
||||
testFieldsList.RemoveById("test3") // should do nothing
|
||||
|
||||
// remove by name
|
||||
testFieldsList.RemoveByName("test5")
|
||||
testFieldsList.RemoveByName("id6") // should do nothing
|
||||
|
||||
expected := []string{"test1", "test3", "test4", "test6"}
|
||||
|
||||
if len(testFieldsList) != len(expected) {
|
||||
t.Fatalf("Expected %d, got %d\n%v", len(expected), len(testFieldsList), testFieldsList)
|
||||
}
|
||||
|
||||
for _, name := range expected {
|
||||
if f := testFieldsList.GetByName(name); f == nil {
|
||||
t.Fatalf("Missing field %q", name)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestFieldsListAdd(t *testing.T) {
|
||||
f0 := &core.TextField{}
|
||||
f1 := &core.TextField{Name: "test1"}
|
||||
f2 := &core.TextField{Id: "f2Id", Name: "test2"}
|
||||
f3 := &core.TextField{Id: "f3Id", Name: "test3"}
|
||||
testFieldsList := core.NewFieldsList(f0, f1, f2, f3)
|
||||
|
||||
f2New := &core.EmailField{Id: "f2Id", Name: "test2_new"}
|
||||
f4 := &core.URLField{Name: "test4"}
|
||||
|
||||
testFieldsList.Add(f2New)
|
||||
testFieldsList.Add(f4)
|
||||
|
||||
if len(testFieldsList) != 5 {
|
||||
t.Fatalf("Expected %d, got %d\n%v", 5, len(testFieldsList), testFieldsList)
|
||||
}
|
||||
|
||||
// check if each field has id
|
||||
for _, f := range testFieldsList {
|
||||
if f.GetId() == "" {
|
||||
t.Fatalf("Expected field id to be set, found empty id for field %v", f)
|
||||
}
|
||||
}
|
||||
|
||||
// check if f2 field was replaced
|
||||
if f := testFieldsList.GetById("f2Id"); f == nil || f.Type() != core.FieldTypeEmail {
|
||||
t.Fatalf("Expected f2 field to be replaced, found %v", f)
|
||||
}
|
||||
|
||||
// check if f4 was added
|
||||
if f := testFieldsList.GetByName("test4"); f == nil || f.GetName() != "test4" {
|
||||
t.Fatalf("Expected f4 field to be added, found %v", f)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFieldsListAddMarshaledJSON(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
scenarios := []struct {
|
||||
name string
|
||||
raw []byte
|
||||
expectError bool
|
||||
expectedFields map[string]string
|
||||
}{
|
||||
{
|
||||
"nil",
|
||||
nil,
|
||||
false,
|
||||
map[string]string{"abc": core.FieldTypeNumber},
|
||||
},
|
||||
{
|
||||
"empty array",
|
||||
[]byte(`[]`),
|
||||
false,
|
||||
map[string]string{"abc": core.FieldTypeNumber},
|
||||
},
|
||||
{
|
||||
"empty object",
|
||||
[]byte(`{}`),
|
||||
true,
|
||||
map[string]string{"abc": core.FieldTypeNumber},
|
||||
},
|
||||
{
|
||||
"array with empty object",
|
||||
[]byte(`[{}]`),
|
||||
true,
|
||||
map[string]string{"abc": core.FieldTypeNumber},
|
||||
},
|
||||
{
|
||||
"single object with invalid type",
|
||||
[]byte(`{"type":"missing","name":"test"}`),
|
||||
true,
|
||||
map[string]string{"abc": core.FieldTypeNumber},
|
||||
},
|
||||
{
|
||||
"single object with valid type",
|
||||
[]byte(`{"type":"text","name":"test"}`),
|
||||
false,
|
||||
map[string]string{
|
||||
"abc": core.FieldTypeNumber,
|
||||
"test": core.FieldTypeText,
|
||||
},
|
||||
},
|
||||
{
|
||||
"array of object with valid types",
|
||||
[]byte(`[{"type":"text","name":"test1"},{"type":"url","name":"test2"}]`),
|
||||
false,
|
||||
map[string]string{
|
||||
"abc": core.FieldTypeNumber,
|
||||
"test1": core.FieldTypeText,
|
||||
"test2": core.FieldTypeURL,
|
||||
},
|
||||
},
|
||||
{
|
||||
"fields with duplicated ids should replace existing fields",
|
||||
[]byte(`[{"type":"text","name":"test1"},{"type":"url","name":"test2"},{"type":"text","name":"abc2", "id":"abc_id"}]`),
|
||||
false,
|
||||
map[string]string{
|
||||
"abc2": core.FieldTypeText,
|
||||
"test1": core.FieldTypeText,
|
||||
"test2": core.FieldTypeURL,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
t.Run(s.name, func(t *testing.T) {
|
||||
testList := core.NewFieldsList(&core.NumberField{Name: "abc", Id: "abc_id"})
|
||||
err := testList.AddMarshaledJSON(s.raw)
|
||||
|
||||
hasErr := err != nil
|
||||
if hasErr != s.expectError {
|
||||
t.Fatalf("Expected hasErr %v, got %v", s.expectError, hasErr)
|
||||
}
|
||||
|
||||
if len(s.expectedFields) != len(testList) {
|
||||
t.Fatalf("Expected %d fields, got %d", len(s.expectedFields), len(testList))
|
||||
}
|
||||
|
||||
for fieldName, typ := range s.expectedFields {
|
||||
f := testList.GetByName(fieldName)
|
||||
|
||||
if f == nil {
|
||||
t.Errorf("Missing expected field %q", fieldName)
|
||||
continue
|
||||
}
|
||||
|
||||
if f.Type() != typ {
|
||||
t.Errorf("Expect field %q to has type %q, got %q", fieldName, typ, f.Type())
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFieldsListAddAt(t *testing.T) {
|
||||
scenarios := []struct {
|
||||
position int
|
||||
expected []string
|
||||
}{
|
||||
{-2, []string{"test1", "test2_new", "test3", "test4"}},
|
||||
{-1, []string{"test1", "test2_new", "test3", "test4"}},
|
||||
{0, []string{"test2_new", "test4", "test1", "test3"}},
|
||||
{1, []string{"test1", "test2_new", "test4", "test3"}},
|
||||
{2, []string{"test1", "test3", "test2_new", "test4"}},
|
||||
{3, []string{"test1", "test3", "test2_new", "test4"}},
|
||||
{4, []string{"test1", "test3", "test2_new", "test4"}},
|
||||
{5, []string{"test1", "test3", "test2_new", "test4"}},
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
t.Run(strconv.Itoa(s.position), func(t *testing.T) {
|
||||
f1 := &core.TextField{Id: "f1Id", Name: "test1"}
|
||||
f2 := &core.TextField{Id: "f2Id", Name: "test2"}
|
||||
f3 := &core.TextField{Id: "f3Id", Name: "test3"}
|
||||
testFieldsList := core.NewFieldsList(f1, f2, f3)
|
||||
|
||||
f2New := &core.EmailField{Id: "f2Id", Name: "test2_new"}
|
||||
f4 := &core.URLField{Name: "test4"}
|
||||
testFieldsList.AddAt(s.position, f2New, f4)
|
||||
|
||||
rawNames, err := json.Marshal(testFieldsList.FieldNames())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
rawExpected, err := json.Marshal(s.expected)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if !bytes.Equal(rawNames, rawExpected) {
|
||||
t.Fatalf("Expected fields\n%s\ngot\n%s", rawExpected, rawNames)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFieldsListAddMarshaledJSONAt(t *testing.T) {
|
||||
scenarios := []struct {
|
||||
position int
|
||||
expected []string
|
||||
}{
|
||||
{-2, []string{"test1", "test2_new", "test3", "test4"}},
|
||||
{-1, []string{"test1", "test2_new", "test3", "test4"}},
|
||||
{0, []string{"test2_new", "test4", "test1", "test3"}},
|
||||
{1, []string{"test1", "test2_new", "test4", "test3"}},
|
||||
{2, []string{"test1", "test3", "test2_new", "test4"}},
|
||||
{3, []string{"test1", "test3", "test2_new", "test4"}},
|
||||
{4, []string{"test1", "test3", "test2_new", "test4"}},
|
||||
{5, []string{"test1", "test3", "test2_new", "test4"}},
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
t.Run(strconv.Itoa(s.position), func(t *testing.T) {
|
||||
f1 := &core.TextField{Id: "f1Id", Name: "test1"}
|
||||
f2 := &core.TextField{Id: "f2Id", Name: "test2"}
|
||||
f3 := &core.TextField{Id: "f3Id", Name: "test3"}
|
||||
testFieldsList := core.NewFieldsList(f1, f2, f3)
|
||||
|
||||
err := testFieldsList.AddMarshaledJSONAt(s.position, []byte(`[
|
||||
{"id":"f2Id", "name":"test2_new", "type": "text"},
|
||||
{"name": "test4", "type": "text"}
|
||||
]`))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
rawNames, err := json.Marshal(testFieldsList.FieldNames())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
rawExpected, err := json.Marshal(s.expected)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if !bytes.Equal(rawNames, rawExpected) {
|
||||
t.Fatalf("Expected fields\n%s\ngot\n%s", rawExpected, rawNames)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFieldsListStringAndValue(t *testing.T) {
|
||||
t.Run("empty list", func(t *testing.T) {
|
||||
testFieldsList := core.NewFieldsList()
|
||||
|
||||
str := testFieldsList.String()
|
||||
if str != "[]" {
|
||||
t.Fatalf("Expected empty slice, got\n%q", str)
|
||||
}
|
||||
|
||||
v, err := testFieldsList.Value()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if v != str {
|
||||
t.Fatalf("Expected String and Value to match")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("list with fields", func(t *testing.T) {
|
||||
testFieldsList := core.NewFieldsList(
|
||||
&core.TextField{Id: "f1id", Name: "test1"},
|
||||
&core.BoolField{Id: "f2id", Name: "test2"},
|
||||
&core.URLField{Id: "f3id", Name: "test3"},
|
||||
)
|
||||
|
||||
str := testFieldsList.String()
|
||||
|
||||
v, err := testFieldsList.Value()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if v != str {
|
||||
t.Fatalf("Expected String and Value to match")
|
||||
}
|
||||
|
||||
expectedParts := []string{
|
||||
`"type":"bool"`,
|
||||
`"type":"url"`,
|
||||
`"type":"text"`,
|
||||
`"id":"f1id"`,
|
||||
`"id":"f2id"`,
|
||||
`"id":"f3id"`,
|
||||
`"name":"test1"`,
|
||||
`"name":"test2"`,
|
||||
`"name":"test3"`,
|
||||
}
|
||||
|
||||
for _, part := range expectedParts {
|
||||
if !strings.Contains(str, part) {
|
||||
t.Fatalf("Missing %q in\nn%v", part, str)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestFieldsListScan(t *testing.T) {
|
||||
scenarios := []struct {
|
||||
name string
|
||||
data any
|
||||
expectError bool
|
||||
expectJSON string
|
||||
}{
|
||||
{"nil", nil, false, "[]"},
|
||||
{"empty string", "", false, "[]"},
|
||||
{"empty byte", []byte{}, false, "[]"},
|
||||
{"empty string array", "[]", false, "[]"},
|
||||
{"invalid string", "invalid", true, "[]"},
|
||||
{"non-string", 123, true, "[]"},
|
||||
{"item with no field type", `[{}]`, true, "[]"},
|
||||
{
|
||||
"unknown field type",
|
||||
`[{"id":"123","name":"test1","type":"unknown"},{"id":"456","name":"test2","type":"bool"}]`,
|
||||
true,
|
||||
`[]`,
|
||||
},
|
||||
{
|
||||
"only the minimum field options",
|
||||
`[{"id":"123","name":"test1","type":"text","required":true},{"id":"456","name":"test2","type":"bool"}]`,
|
||||
false,
|
||||
`[{"autogeneratePattern":"","hidden":false,"id":"123","max":0,"min":0,"name":"test1","pattern":"","presentable":false,"primaryKey":false,"required":true,"system":false,"type":"text"},{"hidden":false,"id":"456","name":"test2","presentable":false,"required":false,"system":false,"type":"bool"}]`,
|
||||
},
|
||||
{
|
||||
"all field options",
|
||||
`[{"autogeneratePattern":"","hidden":true,"id":"123","max":12,"min":0,"name":"test1","pattern":"","presentable":true,"primaryKey":false,"required":true,"system":false,"type":"text"},{"hidden":false,"id":"456","name":"test2","presentable":false,"required":false,"system":true,"type":"bool"}]`,
|
||||
false,
|
||||
`[{"autogeneratePattern":"","hidden":true,"id":"123","max":12,"min":0,"name":"test1","pattern":"","presentable":true,"primaryKey":false,"required":true,"system":false,"type":"text"},{"hidden":false,"id":"456","name":"test2","presentable":false,"required":false,"system":true,"type":"bool"}]`,
|
||||
},
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
t.Run(s.name, func(t *testing.T) {
|
||||
testFieldsList := core.FieldsList{}
|
||||
|
||||
err := testFieldsList.Scan(s.data)
|
||||
|
||||
hasErr := err != nil
|
||||
if hasErr != s.expectError {
|
||||
t.Fatalf("Expected hasErr %v, got %v (%v)", s.expectError, hasErr, err)
|
||||
}
|
||||
|
||||
str := testFieldsList.String()
|
||||
if str != s.expectJSON {
|
||||
t.Fatalf("Expected\n%v\ngot\n%v", s.expectJSON, str)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFieldsListJSON(t *testing.T) {
|
||||
scenarios := []struct {
|
||||
name string
|
||||
data string
|
||||
expectError bool
|
||||
expectJSON string
|
||||
}{
|
||||
{"empty string", "", true, "[]"},
|
||||
{"invalid string", "invalid", true, "[]"},
|
||||
{"empty string array", "[]", false, "[]"},
|
||||
{"item with no field type", `[{}]`, true, "[]"},
|
||||
{
|
||||
"unknown field type",
|
||||
`[{"id":"123","name":"test1","type":"unknown"},{"id":"456","name":"test2","type":"bool"}]`,
|
||||
true,
|
||||
`[]`,
|
||||
},
|
||||
{
|
||||
"only the minimum field options",
|
||||
`[{"id":"123","name":"test1","type":"text","required":true},{"id":"456","name":"test2","type":"bool"}]`,
|
||||
false,
|
||||
`[{"autogeneratePattern":"","hidden":false,"id":"123","max":0,"min":0,"name":"test1","pattern":"","presentable":false,"primaryKey":false,"required":true,"system":false,"type":"text"},{"hidden":false,"id":"456","name":"test2","presentable":false,"required":false,"system":false,"type":"bool"}]`,
|
||||
},
|
||||
{
|
||||
"all field options",
|
||||
`[{"autogeneratePattern":"","hidden":true,"id":"123","max":12,"min":0,"name":"test1","pattern":"","presentable":true,"primaryKey":false,"required":true,"system":false,"type":"text"},{"hidden":false,"id":"456","name":"test2","presentable":false,"required":false,"system":true,"type":"bool"}]`,
|
||||
false,
|
||||
`[{"autogeneratePattern":"","hidden":true,"id":"123","max":12,"min":0,"name":"test1","pattern":"","presentable":true,"primaryKey":false,"required":true,"system":false,"type":"text"},{"hidden":false,"id":"456","name":"test2","presentable":false,"required":false,"system":true,"type":"bool"}]`,
|
||||
},
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
t.Run(s.name, func(t *testing.T) {
|
||||
testFieldsList := core.FieldsList{}
|
||||
|
||||
err := testFieldsList.UnmarshalJSON([]byte(s.data))
|
||||
|
||||
hasErr := err != nil
|
||||
if hasErr != s.expectError {
|
||||
t.Fatalf("Expected hasErr %v, got %v (%v)", s.expectError, hasErr, err)
|
||||
}
|
||||
|
||||
raw, err := testFieldsList.MarshalJSON()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
str := string(raw)
|
||||
if str != s.expectJSON {
|
||||
t.Fatalf("Expected\n%v\ngot\n%v", s.expectJSON, str)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
22
core/log_model.go
Normal file
22
core/log_model.go
Normal file
|
@ -0,0 +1,22 @@
|
|||
package core
|
||||
|
||||
import "github.com/pocketbase/pocketbase/tools/types"
|
||||
|
||||
var (
|
||||
_ Model = (*Log)(nil)
|
||||
)
|
||||
|
||||
const LogsTableName = "_logs"
|
||||
|
||||
type Log struct {
|
||||
BaseModel
|
||||
|
||||
Created types.DateTime `db:"created" json:"created"`
|
||||
Data types.JSONMap[any] `db:"data" json:"data"`
|
||||
Message string `db:"message" json:"message"`
|
||||
Level int `db:"level" json:"level"`
|
||||
}
|
||||
|
||||
func (m *Log) TableName() string {
|
||||
return LogsTableName
|
||||
}
|
67
core/log_printer.go
Normal file
67
core/log_printer.go
Normal file
|
@ -0,0 +1,67 @@
|
|||
package core
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"strings"
|
||||
|
||||
"github.com/fatih/color"
|
||||
"github.com/pocketbase/pocketbase/tools/logger"
|
||||
"github.com/pocketbase/pocketbase/tools/store"
|
||||
"github.com/spf13/cast"
|
||||
)
|
||||
|
||||
var cachedColors = store.New[string, *color.Color](nil)
|
||||
|
||||
// getColor returns [color.Color] object and cache it (if not already).
|
||||
func getColor(attrs ...color.Attribute) (c *color.Color) {
|
||||
cacheKey := fmt.Sprint(attrs)
|
||||
if c = cachedColors.Get(cacheKey); c == nil {
|
||||
c = color.New(attrs...)
|
||||
cachedColors.Set(cacheKey, c)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// printLog prints the provided log to the stderr.
|
||||
// (note: defined as variable to overwriting in the tests)
|
||||
var printLog = func(log *logger.Log) {
|
||||
var str strings.Builder
|
||||
|
||||
switch log.Level {
|
||||
case slog.LevelDebug:
|
||||
str.WriteString(getColor(color.Bold, color.FgHiBlack).Sprint("DEBUG "))
|
||||
str.WriteString(getColor(color.FgWhite).Sprint(log.Message))
|
||||
case slog.LevelInfo:
|
||||
str.WriteString(getColor(color.Bold, color.FgWhite).Sprint("INFO "))
|
||||
str.WriteString(getColor(color.FgWhite).Sprint(log.Message))
|
||||
case slog.LevelWarn:
|
||||
str.WriteString(getColor(color.Bold, color.FgYellow).Sprint("WARN "))
|
||||
str.WriteString(getColor(color.FgYellow).Sprint(log.Message))
|
||||
case slog.LevelError:
|
||||
str.WriteString(getColor(color.Bold, color.FgRed).Sprint("ERROR "))
|
||||
str.WriteString(getColor(color.FgRed).Sprint(log.Message))
|
||||
default:
|
||||
str.WriteString(getColor(color.Bold, color.FgCyan).Sprintf("[%d] ", log.Level))
|
||||
str.WriteString(getColor(color.FgCyan).Sprint(log.Message))
|
||||
}
|
||||
|
||||
str.WriteString("\n")
|
||||
|
||||
if v, ok := log.Data["type"]; ok && cast.ToString(v) == "request" {
|
||||
padding := 0
|
||||
keys := []string{"error", "details"}
|
||||
for _, k := range keys {
|
||||
if v := log.Data[k]; v != nil {
|
||||
str.WriteString(getColor(color.FgHiRed).Sprintf("%s└─ %v", strings.Repeat(" ", padding), v))
|
||||
str.WriteString("\n")
|
||||
padding += 3
|
||||
}
|
||||
}
|
||||
} else if len(log.Data) > 0 {
|
||||
str.WriteString(getColor(color.FgHiBlack).Sprintf("└─ %v", log.Data))
|
||||
str.WriteString("\n")
|
||||
}
|
||||
|
||||
fmt.Print(str.String())
|
||||
}
|
124
core/log_printer_test.go
Normal file
124
core/log_printer_test.go
Normal file
|
@ -0,0 +1,124 @@
|
|||
package core
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"log/slog"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/pocketbase/dbx"
|
||||
"github.com/pocketbase/pocketbase/tools/list"
|
||||
"github.com/pocketbase/pocketbase/tools/logger"
|
||||
)
|
||||
|
||||
func TestBaseAppLoggerLevelDevPrint(t *testing.T) {
|
||||
testLogLevel := 4
|
||||
|
||||
scenarios := []struct {
|
||||
name string
|
||||
isDev bool
|
||||
levels []int
|
||||
printedLevels []int
|
||||
persistedLevels []int
|
||||
}{
|
||||
{
|
||||
"dev mode",
|
||||
true,
|
||||
[]int{testLogLevel - 1, testLogLevel, testLogLevel + 1},
|
||||
[]int{testLogLevel - 1, testLogLevel, testLogLevel + 1},
|
||||
[]int{testLogLevel, testLogLevel + 1},
|
||||
},
|
||||
{
|
||||
"nondev mode",
|
||||
false,
|
||||
[]int{testLogLevel - 1, testLogLevel, testLogLevel + 1},
|
||||
[]int{},
|
||||
[]int{testLogLevel, testLogLevel + 1},
|
||||
},
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
t.Run(s.name, func(t *testing.T) {
|
||||
const testDataDir = "./pb_base_app_test_data_dir/"
|
||||
defer os.RemoveAll(testDataDir)
|
||||
|
||||
app := NewBaseApp(BaseAppConfig{
|
||||
DataDir: testDataDir,
|
||||
IsDev: s.isDev,
|
||||
})
|
||||
defer app.ResetBootstrapState()
|
||||
|
||||
if err := app.Bootstrap(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// silence query logs
|
||||
app.concurrentDB.(*dbx.DB).ExecLogFunc = func(ctx context.Context, t time.Duration, sql string, result sql.Result, err error) {}
|
||||
app.concurrentDB.(*dbx.DB).QueryLogFunc = func(ctx context.Context, t time.Duration, sql string, rows *sql.Rows, err error) {}
|
||||
app.nonconcurrentDB.(*dbx.DB).ExecLogFunc = func(ctx context.Context, t time.Duration, sql string, result sql.Result, err error) {}
|
||||
app.nonconcurrentDB.(*dbx.DB).QueryLogFunc = func(ctx context.Context, t time.Duration, sql string, rows *sql.Rows, err error) {}
|
||||
|
||||
app.Settings().Logs.MinLevel = testLogLevel
|
||||
if err := app.Save(app.Settings()); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
var printedLevels []int
|
||||
var persistedLevels []int
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// track printed logs
|
||||
originalPrintLog := printLog
|
||||
defer func() {
|
||||
printLog = originalPrintLog
|
||||
}()
|
||||
printLog = func(log *logger.Log) {
|
||||
printedLevels = append(printedLevels, int(log.Level))
|
||||
}
|
||||
|
||||
// track persisted logs
|
||||
app.OnModelAfterCreateSuccess("_logs").BindFunc(func(e *ModelEvent) error {
|
||||
l, ok := e.Model.(*Log)
|
||||
if ok {
|
||||
persistedLevels = append(persistedLevels, l.Level)
|
||||
}
|
||||
return e.Next()
|
||||
})
|
||||
|
||||
// write and persist logs
|
||||
for _, l := range s.levels {
|
||||
app.Logger().Log(ctx, slog.Level(l), "test")
|
||||
}
|
||||
handler, ok := app.Logger().Handler().(*logger.BatchHandler)
|
||||
if !ok {
|
||||
t.Fatalf("Expected BatchHandler, got %v", app.Logger().Handler())
|
||||
}
|
||||
if err := handler.WriteAll(ctx); err != nil {
|
||||
t.Fatalf("Failed to write all logs: %v", err)
|
||||
}
|
||||
|
||||
// check persisted log levels
|
||||
if len(s.persistedLevels) != len(persistedLevels) {
|
||||
t.Fatalf("Expected persisted levels \n%v\ngot\n%v", s.persistedLevels, persistedLevels)
|
||||
}
|
||||
for _, l := range persistedLevels {
|
||||
if !list.ExistInSlice(l, s.persistedLevels) {
|
||||
t.Fatalf("Missing expected persisted level %v in %v", l, persistedLevels)
|
||||
}
|
||||
}
|
||||
|
||||
// check printed log levels
|
||||
if len(s.printedLevels) != len(printedLevels) {
|
||||
t.Fatalf("Expected printed levels \n%v\ngot\n%v", s.printedLevels, printedLevels)
|
||||
}
|
||||
for _, l := range printedLevels {
|
||||
if !list.ExistInSlice(l, s.printedLevels) {
|
||||
t.Fatalf("Missing expected printed level %v in %v", l, printedLevels)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
65
core/log_query.go
Normal file
65
core/log_query.go
Normal file
|
@ -0,0 +1,65 @@
|
|||
package core
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/pocketbase/dbx"
|
||||
"github.com/pocketbase/pocketbase/tools/types"
|
||||
)
|
||||
|
||||
// LogQuery returns a new Log select query.
|
||||
func (app *BaseApp) LogQuery() *dbx.SelectQuery {
|
||||
return app.AuxModelQuery(&Log{})
|
||||
}
|
||||
|
||||
// FindLogById finds a single Log entry by its id.
|
||||
func (app *BaseApp) FindLogById(id string) (*Log, error) {
|
||||
model := &Log{}
|
||||
|
||||
err := app.LogQuery().
|
||||
AndWhere(dbx.HashExp{"id": id}).
|
||||
Limit(1).
|
||||
One(model)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return model, nil
|
||||
}
|
||||
|
||||
// LogsStatsItem defines the total number of logs for a specific time period.
|
||||
type LogsStatsItem struct {
|
||||
Date types.DateTime `db:"date" json:"date"`
|
||||
Total int `db:"total" json:"total"`
|
||||
}
|
||||
|
||||
// LogsStats returns hourly grouped logs statistics.
|
||||
func (app *BaseApp) LogsStats(expr dbx.Expression) ([]*LogsStatsItem, error) {
|
||||
result := []*LogsStatsItem{}
|
||||
|
||||
query := app.LogQuery().
|
||||
Select("count(id) as total", "strftime('%Y-%m-%d %H:00:00', created) as date").
|
||||
GroupBy("date")
|
||||
|
||||
if expr != nil {
|
||||
query.AndWhere(expr)
|
||||
}
|
||||
|
||||
err := query.All(&result)
|
||||
|
||||
return result, err
|
||||
}
|
||||
|
||||
// DeleteOldLogs delete all logs that are created before createdBefore.
|
||||
//
|
||||
// For better performance the logs delete is executed as plain SQL statement,
|
||||
// aka. no delete model hook events will be fired.
|
||||
func (app *BaseApp) DeleteOldLogs(createdBefore time.Time) error {
|
||||
formattedDate := createdBefore.UTC().Format(types.DefaultDateLayout)
|
||||
expr := dbx.NewExp("[[created]] <= {:date}", dbx.Params{"date": formattedDate})
|
||||
|
||||
_, err := app.auxNonconcurrentDB.Delete((&Log{}).TableName(), expr).Execute()
|
||||
|
||||
return err
|
||||
}
|
114
core/log_query_test.go
Normal file
114
core/log_query_test.go
Normal file
|
@ -0,0 +1,114 @@
|
|||
package core_test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/pocketbase/dbx"
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tests"
|
||||
"github.com/pocketbase/pocketbase/tools/types"
|
||||
)
|
||||
|
||||
func TestFindLogById(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
tests.StubLogsData(app)
|
||||
|
||||
scenarios := []struct {
|
||||
id string
|
||||
expectError bool
|
||||
}{
|
||||
{"", true},
|
||||
{"invalid", true},
|
||||
{"00000000-9f38-44fb-bf82-c8f53b310d91", true},
|
||||
{"873f2133-9f38-44fb-bf82-c8f53b310d91", false},
|
||||
}
|
||||
|
||||
for i, s := range scenarios {
|
||||
t.Run(fmt.Sprintf("%d_%s", i, s.id), func(t *testing.T) {
|
||||
log, err := app.FindLogById(s.id)
|
||||
|
||||
hasErr := err != nil
|
||||
if hasErr != s.expectError {
|
||||
t.Fatalf("Expected hasErr to be %v, got %v (%v)", s.expectError, hasErr, err)
|
||||
}
|
||||
|
||||
if log != nil && log.Id != s.id {
|
||||
t.Fatalf("Expected log with id %q, got %q", s.id, log.Id)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestLogsStats(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
tests.StubLogsData(app)
|
||||
|
||||
expected := `[{"date":"2022-05-01 10:00:00.000Z","total":1},{"date":"2022-05-02 10:00:00.000Z","total":1}]`
|
||||
|
||||
now := time.Now().UTC().Format(types.DefaultDateLayout)
|
||||
exp := dbx.NewExp("[[created]] <= {:date}", dbx.Params{"date": now})
|
||||
result, err := app.LogsStats(exp)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
encoded, _ := json.Marshal(result)
|
||||
if string(encoded) != expected {
|
||||
t.Fatalf("Expected\n%q\ngot\n%q", expected, string(encoded))
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteOldLogs(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
tests.StubLogsData(app)
|
||||
|
||||
scenarios := []struct {
|
||||
date string
|
||||
expectedTotal int
|
||||
}{
|
||||
{"2022-01-01 10:00:00.000Z", 2}, // no logs to delete before that time
|
||||
{"2022-05-01 11:00:00.000Z", 1}, // only 1 log should have left
|
||||
{"2022-05-03 11:00:00.000Z", 0}, // no more logs should have left
|
||||
{"2022-05-04 11:00:00.000Z", 0}, // no more logs should have left
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
t.Run(s.date, func(t *testing.T) {
|
||||
date, dateErr := time.Parse(types.DefaultDateLayout, s.date)
|
||||
if dateErr != nil {
|
||||
t.Fatalf("Date error %v", dateErr)
|
||||
}
|
||||
|
||||
deleteErr := app.DeleteOldLogs(date)
|
||||
if deleteErr != nil {
|
||||
t.Fatalf("Delete error %v", deleteErr)
|
||||
}
|
||||
|
||||
// check total remaining logs
|
||||
var total int
|
||||
countErr := app.AuxModelQuery(&core.Log{}).Select("count(*)").Row(&total)
|
||||
if countErr != nil {
|
||||
t.Errorf("Count error %v", countErr)
|
||||
}
|
||||
|
||||
if total != s.expectedTotal {
|
||||
t.Errorf("Expected %d remaining logs, got %d", s.expectedTotal, total)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
157
core/mfa_model.go
Normal file
157
core/mfa_model.go
Normal file
|
@ -0,0 +1,157 @@
|
|||
package core
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"github.com/pocketbase/pocketbase/tools/hook"
|
||||
"github.com/pocketbase/pocketbase/tools/types"
|
||||
)
|
||||
|
||||
const (
|
||||
MFAMethodPassword = "password"
|
||||
MFAMethodOAuth2 = "oauth2"
|
||||
MFAMethodOTP = "otp"
|
||||
)
|
||||
|
||||
const CollectionNameMFAs = "_mfas"
|
||||
|
||||
var (
|
||||
_ Model = (*MFA)(nil)
|
||||
_ PreValidator = (*MFA)(nil)
|
||||
_ RecordProxy = (*MFA)(nil)
|
||||
)
|
||||
|
||||
// MFA defines a Record proxy for working with the mfas collection.
|
||||
type MFA struct {
|
||||
*Record
|
||||
}
|
||||
|
||||
// NewMFA instantiates and returns a new blank *MFA model.
|
||||
//
|
||||
// Example usage:
|
||||
//
|
||||
// mfa := core.NewMFA(app)
|
||||
// mfa.SetRecordRef(user.Id)
|
||||
// mfa.SetCollectionRef(user.Collection().Id)
|
||||
// mfa.SetMethod(core.MFAMethodPassword)
|
||||
// app.Save(mfa)
|
||||
func NewMFA(app App) *MFA {
|
||||
m := &MFA{}
|
||||
|
||||
c, err := app.FindCachedCollectionByNameOrId(CollectionNameMFAs)
|
||||
if err != nil {
|
||||
// this is just to make tests easier since mfa is a system collection and it is expected to be always accessible
|
||||
// (note: the loaded record is further checked on MFA.PreValidate())
|
||||
c = NewBaseCollection("@__invalid__")
|
||||
}
|
||||
|
||||
m.Record = NewRecord(c)
|
||||
|
||||
return m
|
||||
}
|
||||
|
||||
// PreValidate implements the [PreValidator] interface and checks
|
||||
// whether the proxy is properly loaded.
|
||||
func (m *MFA) PreValidate(ctx context.Context, app App) error {
|
||||
if m.Record == nil || m.Record.Collection().Name != CollectionNameMFAs {
|
||||
return errors.New("missing or invalid mfa ProxyRecord")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ProxyRecord returns the proxied Record model.
|
||||
func (m *MFA) ProxyRecord() *Record {
|
||||
return m.Record
|
||||
}
|
||||
|
||||
// SetProxyRecord loads the specified record model into the current proxy.
|
||||
func (m *MFA) SetProxyRecord(record *Record) {
|
||||
m.Record = record
|
||||
}
|
||||
|
||||
// CollectionRef returns the "collectionRef" field value.
|
||||
func (m *MFA) CollectionRef() string {
|
||||
return m.GetString("collectionRef")
|
||||
}
|
||||
|
||||
// SetCollectionRef updates the "collectionRef" record field value.
|
||||
func (m *MFA) SetCollectionRef(collectionId string) {
|
||||
m.Set("collectionRef", collectionId)
|
||||
}
|
||||
|
||||
// RecordRef returns the "recordRef" record field value.
|
||||
func (m *MFA) RecordRef() string {
|
||||
return m.GetString("recordRef")
|
||||
}
|
||||
|
||||
// SetRecordRef updates the "recordRef" record field value.
|
||||
func (m *MFA) SetRecordRef(recordId string) {
|
||||
m.Set("recordRef", recordId)
|
||||
}
|
||||
|
||||
// Method returns the "method" record field value.
|
||||
func (m *MFA) Method() string {
|
||||
return m.GetString("method")
|
||||
}
|
||||
|
||||
// SetMethod updates the "method" record field value.
|
||||
func (m *MFA) SetMethod(method string) {
|
||||
m.Set("method", method)
|
||||
}
|
||||
|
||||
// Created returns the "created" record field value.
|
||||
func (m *MFA) Created() types.DateTime {
|
||||
return m.GetDateTime("created")
|
||||
}
|
||||
|
||||
// Updated returns the "updated" record field value.
|
||||
func (m *MFA) Updated() types.DateTime {
|
||||
return m.GetDateTime("updated")
|
||||
}
|
||||
|
||||
// HasExpired checks if the mfa is expired, aka. whether it has been
|
||||
// more than maxElapsed time since its creation.
|
||||
func (m *MFA) HasExpired(maxElapsed time.Duration) bool {
|
||||
return time.Since(m.Created().Time()) > maxElapsed
|
||||
}
|
||||
|
||||
func (app *BaseApp) registerMFAHooks() {
|
||||
recordRefHooks[*MFA](app, CollectionNameMFAs, CollectionTypeAuth)
|
||||
|
||||
// run on every hour to cleanup expired mfa sessions
|
||||
app.Cron().Add("__pbMFACleanup__", "0 * * * *", func() {
|
||||
if err := app.DeleteExpiredMFAs(); err != nil {
|
||||
app.Logger().Warn("Failed to delete expired MFA sessions", "error", err)
|
||||
}
|
||||
})
|
||||
|
||||
// delete existing mfas on password change
|
||||
app.OnRecordUpdate().Bind(&hook.Handler[*RecordEvent]{
|
||||
Func: func(e *RecordEvent) error {
|
||||
err := e.Next()
|
||||
if err != nil || !e.Record.Collection().IsAuth() {
|
||||
return err
|
||||
}
|
||||
|
||||
old := e.Record.Original().GetString(FieldNamePassword + ":hash")
|
||||
new := e.Record.GetString(FieldNamePassword + ":hash")
|
||||
if old != new {
|
||||
err = e.App.DeleteAllMFAsByRecord(e.Record)
|
||||
if err != nil {
|
||||
e.App.Logger().Warn(
|
||||
"Failed to delete all previous mfas",
|
||||
"error", err,
|
||||
"recordId", e.Record.Id,
|
||||
"collectionId", e.Record.Collection().Id,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
},
|
||||
Priority: 99,
|
||||
})
|
||||
}
|
302
core/mfa_model_test.go
Normal file
302
core/mfa_model_test.go
Normal file
|
@ -0,0 +1,302 @@
|
|||
package core_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tests"
|
||||
"github.com/pocketbase/pocketbase/tools/types"
|
||||
)
|
||||
|
||||
func TestNewMFA(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
mfa := core.NewMFA(app)
|
||||
|
||||
if mfa.Collection().Name != core.CollectionNameMFAs {
|
||||
t.Fatalf("Expected record with %q collection, got %q", core.CollectionNameMFAs, mfa.Collection().Name)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMFAProxyRecord(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
record := core.NewRecord(core.NewBaseCollection("test"))
|
||||
record.Id = "test_id"
|
||||
|
||||
mfa := core.MFA{}
|
||||
mfa.SetProxyRecord(record)
|
||||
|
||||
if mfa.ProxyRecord() == nil || mfa.ProxyRecord().Id != record.Id {
|
||||
t.Fatalf("Expected proxy record with id %q, got %v", record.Id, mfa.ProxyRecord())
|
||||
}
|
||||
}
|
||||
|
||||
func TestMFARecordRef(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
mfa := core.NewMFA(app)
|
||||
|
||||
testValues := []string{"test_1", "test2", ""}
|
||||
for i, testValue := range testValues {
|
||||
t.Run(fmt.Sprintf("%d_%q", i, testValue), func(t *testing.T) {
|
||||
mfa.SetRecordRef(testValue)
|
||||
|
||||
if v := mfa.RecordRef(); v != testValue {
|
||||
t.Fatalf("Expected getter %q, got %q", testValue, v)
|
||||
}
|
||||
|
||||
if v := mfa.GetString("recordRef"); v != testValue {
|
||||
t.Fatalf("Expected field value %q, got %q", testValue, v)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMFACollectionRef(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
mfa := core.NewMFA(app)
|
||||
|
||||
testValues := []string{"test_1", "test2", ""}
|
||||
for i, testValue := range testValues {
|
||||
t.Run(fmt.Sprintf("%d_%q", i, testValue), func(t *testing.T) {
|
||||
mfa.SetCollectionRef(testValue)
|
||||
|
||||
if v := mfa.CollectionRef(); v != testValue {
|
||||
t.Fatalf("Expected getter %q, got %q", testValue, v)
|
||||
}
|
||||
|
||||
if v := mfa.GetString("collectionRef"); v != testValue {
|
||||
t.Fatalf("Expected field value %q, got %q", testValue, v)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMFAMethod(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
mfa := core.NewMFA(app)
|
||||
|
||||
testValues := []string{"test_1", "test2", ""}
|
||||
for i, testValue := range testValues {
|
||||
t.Run(fmt.Sprintf("%d_%q", i, testValue), func(t *testing.T) {
|
||||
mfa.SetMethod(testValue)
|
||||
|
||||
if v := mfa.Method(); v != testValue {
|
||||
t.Fatalf("Expected getter %q, got %q", testValue, v)
|
||||
}
|
||||
|
||||
if v := mfa.GetString("method"); v != testValue {
|
||||
t.Fatalf("Expected field value %q, got %q", testValue, v)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMFACreated(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
mfa := core.NewMFA(app)
|
||||
|
||||
if v := mfa.Created().String(); v != "" {
|
||||
t.Fatalf("Expected empty created, got %q", v)
|
||||
}
|
||||
|
||||
now := types.NowDateTime()
|
||||
mfa.SetRaw("created", now)
|
||||
|
||||
if v := mfa.Created().String(); v != now.String() {
|
||||
t.Fatalf("Expected %q created, got %q", now.String(), v)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMFAUpdated(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
mfa := core.NewMFA(app)
|
||||
|
||||
if v := mfa.Updated().String(); v != "" {
|
||||
t.Fatalf("Expected empty updated, got %q", v)
|
||||
}
|
||||
|
||||
now := types.NowDateTime()
|
||||
mfa.SetRaw("updated", now)
|
||||
|
||||
if v := mfa.Updated().String(); v != now.String() {
|
||||
t.Fatalf("Expected %q updated, got %q", now.String(), v)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMFAHasExpired(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
now := types.NowDateTime()
|
||||
|
||||
mfa := core.NewMFA(app)
|
||||
mfa.SetRaw("created", now.Add(-5*time.Minute))
|
||||
|
||||
scenarios := []struct {
|
||||
maxElapsed time.Duration
|
||||
expected bool
|
||||
}{
|
||||
{0 * time.Minute, true},
|
||||
{3 * time.Minute, true},
|
||||
{5 * time.Minute, true},
|
||||
{6 * time.Minute, false},
|
||||
}
|
||||
|
||||
for i, s := range scenarios {
|
||||
t.Run(fmt.Sprintf("%d_%s", i, s.maxElapsed.String()), func(t *testing.T) {
|
||||
result := mfa.HasExpired(s.maxElapsed)
|
||||
|
||||
if result != s.expected {
|
||||
t.Fatalf("Expected %v, got %v", s.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMFAPreValidate(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
mfasCol, err := app.FindCollectionByNameOrId(core.CollectionNameMFAs)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
user, err := app.FindAuthRecordByEmail("users", "test@example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
t.Run("no proxy record", func(t *testing.T) {
|
||||
mfa := &core.MFA{}
|
||||
|
||||
if err := app.Validate(mfa); err == nil {
|
||||
t.Fatal("Expected collection validation error")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("non-MFA collection", func(t *testing.T) {
|
||||
mfa := &core.MFA{}
|
||||
mfa.SetProxyRecord(core.NewRecord(core.NewBaseCollection("invalid")))
|
||||
mfa.SetRecordRef(user.Id)
|
||||
mfa.SetCollectionRef(user.Collection().Id)
|
||||
mfa.SetMethod("test123")
|
||||
|
||||
if err := app.Validate(mfa); err == nil {
|
||||
t.Fatal("Expected collection validation error")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("MFA collection", func(t *testing.T) {
|
||||
mfa := &core.MFA{}
|
||||
mfa.SetProxyRecord(core.NewRecord(mfasCol))
|
||||
mfa.SetRecordRef(user.Id)
|
||||
mfa.SetCollectionRef(user.Collection().Id)
|
||||
mfa.SetMethod("test123")
|
||||
|
||||
if err := app.Validate(mfa); err != nil {
|
||||
t.Fatalf("Expected nil validation error, got %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestMFAValidateHook(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
user, err := app.FindAuthRecordByEmail("users", "test@example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
demo1, err := app.FindRecordById("demo1", "84nmscqy84lsi1t")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
scenarios := []struct {
|
||||
name string
|
||||
mfa func() *core.MFA
|
||||
expectErrors []string
|
||||
}{
|
||||
{
|
||||
"empty",
|
||||
func() *core.MFA {
|
||||
return core.NewMFA(app)
|
||||
},
|
||||
[]string{"collectionRef", "recordRef", "method"},
|
||||
},
|
||||
{
|
||||
"non-auth collection",
|
||||
func() *core.MFA {
|
||||
mfa := core.NewMFA(app)
|
||||
mfa.SetCollectionRef(demo1.Collection().Id)
|
||||
mfa.SetRecordRef(demo1.Id)
|
||||
mfa.SetMethod("test123")
|
||||
return mfa
|
||||
},
|
||||
[]string{"collectionRef"},
|
||||
},
|
||||
{
|
||||
"missing record id",
|
||||
func() *core.MFA {
|
||||
mfa := core.NewMFA(app)
|
||||
mfa.SetCollectionRef(user.Collection().Id)
|
||||
mfa.SetRecordRef("missing")
|
||||
mfa.SetMethod("test123")
|
||||
return mfa
|
||||
},
|
||||
[]string{"recordRef"},
|
||||
},
|
||||
{
|
||||
"valid ref",
|
||||
func() *core.MFA {
|
||||
mfa := core.NewMFA(app)
|
||||
mfa.SetCollectionRef(user.Collection().Id)
|
||||
mfa.SetRecordRef(user.Id)
|
||||
mfa.SetMethod("test123")
|
||||
return mfa
|
||||
},
|
||||
[]string{},
|
||||
},
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
t.Run(s.name, func(t *testing.T) {
|
||||
errs := app.Validate(s.mfa())
|
||||
tests.TestValidationErrors(t, errs, s.expectErrors)
|
||||
})
|
||||
}
|
||||
}
|
117
core/mfa_query.go
Normal file
117
core/mfa_query.go
Normal file
|
@ -0,0 +1,117 @@
|
|||
package core
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"github.com/pocketbase/dbx"
|
||||
"github.com/pocketbase/pocketbase/tools/types"
|
||||
)
|
||||
|
||||
// FindAllMFAsByRecord returns all MFA models linked to the provided auth record.
|
||||
func (app *BaseApp) FindAllMFAsByRecord(authRecord *Record) ([]*MFA, error) {
|
||||
result := []*MFA{}
|
||||
|
||||
err := app.RecordQuery(CollectionNameMFAs).
|
||||
AndWhere(dbx.HashExp{
|
||||
"collectionRef": authRecord.Collection().Id,
|
||||
"recordRef": authRecord.Id,
|
||||
}).
|
||||
OrderBy("created DESC").
|
||||
All(&result)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// FindAllMFAsByCollection returns all MFA models linked to the provided collection.
|
||||
func (app *BaseApp) FindAllMFAsByCollection(collection *Collection) ([]*MFA, error) {
|
||||
result := []*MFA{}
|
||||
|
||||
err := app.RecordQuery(CollectionNameMFAs).
|
||||
AndWhere(dbx.HashExp{"collectionRef": collection.Id}).
|
||||
OrderBy("created DESC").
|
||||
All(&result)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// FindMFAById returns a single MFA model by its id.
|
||||
func (app *BaseApp) FindMFAById(id string) (*MFA, error) {
|
||||
result := &MFA{}
|
||||
|
||||
err := app.RecordQuery(CollectionNameMFAs).
|
||||
AndWhere(dbx.HashExp{"id": id}).
|
||||
Limit(1).
|
||||
One(result)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// DeleteAllMFAsByRecord deletes all MFA models associated with the provided record.
|
||||
//
|
||||
// Returns a combined error with the failed deletes.
|
||||
func (app *BaseApp) DeleteAllMFAsByRecord(authRecord *Record) error {
|
||||
models, err := app.FindAllMFAsByRecord(authRecord)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var errs []error
|
||||
for _, m := range models {
|
||||
if err := app.Delete(m); err != nil {
|
||||
errs = append(errs, err)
|
||||
}
|
||||
}
|
||||
if len(errs) > 0 {
|
||||
return errors.Join(errs...)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteExpiredMFAs deletes the expired MFAs for all auth collections.
|
||||
func (app *BaseApp) DeleteExpiredMFAs() error {
|
||||
authCollections, err := app.FindAllCollections(CollectionTypeAuth)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// note: perform even if MFA is disabled to ensure that there are no dangling old records
|
||||
for _, collection := range authCollections {
|
||||
minValidDate, err := types.ParseDateTime(time.Now().Add(-1 * collection.MFA.DurationTime()))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
items := []*Record{}
|
||||
|
||||
err = app.RecordQuery(CollectionNameMFAs).
|
||||
AndWhere(dbx.HashExp{"collectionRef": collection.Id}).
|
||||
AndWhere(dbx.NewExp("[[created]] < {:date}", dbx.Params{"date": minValidDate})).
|
||||
All(&items)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, item := range items {
|
||||
err = app.Delete(item)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
311
core/mfa_query_test.go
Normal file
311
core/mfa_query_test.go
Normal file
|
@ -0,0 +1,311 @@
|
|||
package core_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"slices"
|
||||
"testing"
|
||||
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tests"
|
||||
)
|
||||
|
||||
func TestFindAllMFAsByRecord(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
if err := tests.StubMFARecords(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
demo1, err := app.FindRecordById("demo1", "84nmscqy84lsi1t")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
superuser2, err := app.FindAuthRecordByEmail(core.CollectionNameSuperusers, "test2@example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
superuser4, err := app.FindAuthRecordByEmail(core.CollectionNameSuperusers, "test4@example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
user1, err := app.FindAuthRecordByEmail("users", "test@example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
scenarios := []struct {
|
||||
record *core.Record
|
||||
expected []string
|
||||
}{
|
||||
{demo1, nil},
|
||||
{superuser2, []string{"superuser2_0", "superuser2_3", "superuser2_2", "superuser2_1", "superuser2_4"}},
|
||||
{superuser4, nil},
|
||||
{user1, []string{"user1_0"}},
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
t.Run(s.record.Collection().Name+"_"+s.record.Id, func(t *testing.T) {
|
||||
result, err := app.FindAllMFAsByRecord(s.record)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if len(result) != len(s.expected) {
|
||||
t.Fatalf("Expected total mfas %d, got %d", len(s.expected), len(result))
|
||||
}
|
||||
|
||||
for i, id := range s.expected {
|
||||
if result[i].Id != id {
|
||||
t.Errorf("[%d] Expected id %q, got %q", i, id, result[i].Id)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFindAllMFAsByCollection(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
if err := tests.StubMFARecords(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
demo1, err := app.FindCollectionByNameOrId("demo1")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
superusers, err := app.FindCollectionByNameOrId(core.CollectionNameSuperusers)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
clients, err := app.FindCollectionByNameOrId("clients")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
users, err := app.FindCollectionByNameOrId("users")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
scenarios := []struct {
|
||||
collection *core.Collection
|
||||
expected []string
|
||||
}{
|
||||
{demo1, nil},
|
||||
{superusers, []string{
|
||||
"superuser2_0",
|
||||
"superuser2_3",
|
||||
"superuser3_0",
|
||||
"superuser2_2",
|
||||
"superuser3_1",
|
||||
"superuser2_1",
|
||||
"superuser2_4",
|
||||
}},
|
||||
{clients, nil},
|
||||
{users, []string{"user1_0"}},
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
t.Run(s.collection.Name, func(t *testing.T) {
|
||||
result, err := app.FindAllMFAsByCollection(s.collection)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if len(result) != len(s.expected) {
|
||||
t.Fatalf("Expected total mfas %d, got %d", len(s.expected), len(result))
|
||||
}
|
||||
|
||||
for i, id := range s.expected {
|
||||
if result[i].Id != id {
|
||||
t.Errorf("[%d] Expected id %q, got %q", i, id, result[i].Id)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFindMFAById(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
if err := tests.StubMFARecords(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
scenarios := []struct {
|
||||
id string
|
||||
expectError bool
|
||||
}{
|
||||
{"", true},
|
||||
{"84nmscqy84lsi1t", true}, // non-mfa id
|
||||
{"superuser2_0", false},
|
||||
{"superuser2_4", false}, // expired
|
||||
{"user1_0", false},
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
t.Run(s.id, func(t *testing.T) {
|
||||
result, err := app.FindMFAById(s.id)
|
||||
|
||||
hasErr := err != nil
|
||||
if hasErr != s.expectError {
|
||||
t.Fatalf("Expected hasErr %v, got %v (%v)", s.expectError, hasErr, err)
|
||||
}
|
||||
|
||||
if hasErr {
|
||||
return
|
||||
}
|
||||
|
||||
if result.Id != s.id {
|
||||
t.Fatalf("Expected record with id %q, got %q", s.id, result.Id)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteAllMFAsByRecord(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testApp, _ := tests.NewTestApp()
|
||||
defer testApp.Cleanup()
|
||||
|
||||
demo1, err := testApp.FindRecordById("demo1", "84nmscqy84lsi1t")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
superuser2, err := testApp.FindAuthRecordByEmail(core.CollectionNameSuperusers, "test2@example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
superuser4, err := testApp.FindAuthRecordByEmail(core.CollectionNameSuperusers, "test4@example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
user1, err := testApp.FindAuthRecordByEmail("users", "test@example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
scenarios := []struct {
|
||||
record *core.Record
|
||||
deletedIds []string
|
||||
}{
|
||||
{demo1, nil}, // non-auth record
|
||||
{superuser2, []string{"superuser2_0", "superuser2_1", "superuser2_3", "superuser2_2", "superuser2_4"}},
|
||||
{superuser4, nil},
|
||||
{user1, []string{"user1_0"}},
|
||||
}
|
||||
|
||||
for i, s := range scenarios {
|
||||
t.Run(fmt.Sprintf("%d_%s_%s", i, s.record.Collection().Name, s.record.Id), func(t *testing.T) {
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
if err := tests.StubMFARecords(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
deletedIds := []string{}
|
||||
app.OnRecordAfterDeleteSuccess().BindFunc(func(e *core.RecordEvent) error {
|
||||
deletedIds = append(deletedIds, e.Record.Id)
|
||||
return e.Next()
|
||||
})
|
||||
|
||||
err := app.DeleteAllMFAsByRecord(s.record)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if len(deletedIds) != len(s.deletedIds) {
|
||||
t.Fatalf("Expected deleted ids\n%v\ngot\n%v", s.deletedIds, deletedIds)
|
||||
}
|
||||
|
||||
for _, id := range s.deletedIds {
|
||||
if !slices.Contains(deletedIds, id) {
|
||||
t.Errorf("Expected to find deleted id %q in %v", id, deletedIds)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteExpiredMFAs(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
checkDeletedIds := func(app core.App, t *testing.T, expectedDeletedIds []string) {
|
||||
if err := tests.StubMFARecords(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
deletedIds := []string{}
|
||||
app.OnRecordDelete().BindFunc(func(e *core.RecordEvent) error {
|
||||
deletedIds = append(deletedIds, e.Record.Id)
|
||||
return e.Next()
|
||||
})
|
||||
|
||||
if err := app.DeleteExpiredMFAs(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if len(deletedIds) != len(expectedDeletedIds) {
|
||||
t.Fatalf("Expected deleted ids\n%v\ngot\n%v", expectedDeletedIds, deletedIds)
|
||||
}
|
||||
|
||||
for _, id := range expectedDeletedIds {
|
||||
if !slices.Contains(deletedIds, id) {
|
||||
t.Errorf("Expected to find deleted id %q in %v", id, deletedIds)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
t.Run("default test collections", func(t *testing.T) {
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
checkDeletedIds(app, t, []string{
|
||||
"user1_0",
|
||||
"superuser2_1",
|
||||
"superuser2_4",
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("mfa collection duration mock", func(t *testing.T) {
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
superusers, err := app.FindCollectionByNameOrId(core.CollectionNameSuperusers)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
superusers.MFA.Duration = 60
|
||||
if err := app.Save(superusers); err != nil {
|
||||
t.Fatalf("Failed to mock superusers mfa duration: %v", err)
|
||||
}
|
||||
|
||||
checkDeletedIds(app, t, []string{
|
||||
"user1_0",
|
||||
"superuser2_1",
|
||||
"superuser2_2",
|
||||
"superuser2_4",
|
||||
"superuser3_1",
|
||||
})
|
||||
})
|
||||
}
|
83
core/migrations_list.go
Normal file
83
core/migrations_list.go
Normal file
|
@ -0,0 +1,83 @@
|
|||
package core
|
||||
|
||||
import (
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"sort"
|
||||
)
|
||||
|
||||
type Migration struct {
|
||||
Up func(txApp App) error
|
||||
Down func(txApp App) error
|
||||
File string
|
||||
ReapplyCondition func(txApp App, runner *MigrationsRunner, fileName string) (bool, error)
|
||||
}
|
||||
|
||||
// MigrationsList defines a list with migration definitions
|
||||
type MigrationsList struct {
|
||||
list []*Migration
|
||||
}
|
||||
|
||||
// Item returns a single migration from the list by its index.
|
||||
func (l *MigrationsList) Item(index int) *Migration {
|
||||
return l.list[index]
|
||||
}
|
||||
|
||||
// Items returns the internal migrations list slice.
|
||||
func (l *MigrationsList) Items() []*Migration {
|
||||
return l.list
|
||||
}
|
||||
|
||||
// Copy copies all provided list migrations into the current one.
|
||||
func (l *MigrationsList) Copy(list MigrationsList) {
|
||||
for _, item := range list.Items() {
|
||||
l.Register(item.Up, item.Down, item.File)
|
||||
}
|
||||
}
|
||||
|
||||
// Add adds adds an existing migration definition to the list.
|
||||
//
|
||||
// If m.File is not provided, it will try to get the name from its .go file.
|
||||
//
|
||||
// The list will be sorted automatically based on the migrations file name.
|
||||
func (l *MigrationsList) Add(m *Migration) {
|
||||
if m.File == "" {
|
||||
_, path, _, _ := runtime.Caller(1)
|
||||
m.File = filepath.Base(path)
|
||||
}
|
||||
|
||||
l.list = append(l.list, m)
|
||||
|
||||
sort.SliceStable(l.list, func(i int, j int) bool {
|
||||
return l.list[i].File < l.list[j].File
|
||||
})
|
||||
}
|
||||
|
||||
// Register adds new migration definition to the list.
|
||||
//
|
||||
// If optFilename is not provided, it will try to get the name from its .go file.
|
||||
//
|
||||
// The list will be sorted automatically based on the migrations file name.
|
||||
func (l *MigrationsList) Register(
|
||||
up func(txApp App) error,
|
||||
down func(txApp App) error,
|
||||
optFilename ...string,
|
||||
) {
|
||||
var file string
|
||||
if len(optFilename) > 0 {
|
||||
file = optFilename[0]
|
||||
} else {
|
||||
_, path, _, _ := runtime.Caller(1)
|
||||
file = filepath.Base(path)
|
||||
}
|
||||
|
||||
l.list = append(l.list, &Migration{
|
||||
File: file,
|
||||
Up: up,
|
||||
Down: down,
|
||||
})
|
||||
|
||||
sort.SliceStable(l.list, func(i int, j int) bool {
|
||||
return l.list[i].File < l.list[j].File
|
||||
})
|
||||
}
|
48
core/migrations_list_test.go
Normal file
48
core/migrations_list_test.go
Normal file
|
@ -0,0 +1,48 @@
|
|||
package core_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
)
|
||||
|
||||
func TestMigrationsList(t *testing.T) {
|
||||
l1 := core.MigrationsList{}
|
||||
l1.Add(&core.Migration{File: "5_test.go"})
|
||||
l1.Add(&core.Migration{ /* auto detect file name */ })
|
||||
l1.Register(nil, nil, "3_test.go")
|
||||
l1.Register(nil, nil, "1_test.go")
|
||||
l1.Register(nil, nil, "2_test.go")
|
||||
l1.Register(nil, nil /* auto detect file name */)
|
||||
|
||||
l2 := core.MigrationsList{}
|
||||
l2.Register(nil, nil, "4_test.go")
|
||||
l2.Copy(l1)
|
||||
|
||||
expected := []string{
|
||||
"1_test.go",
|
||||
"2_test.go",
|
||||
"3_test.go",
|
||||
"4_test.go",
|
||||
"5_test.go",
|
||||
// twice because there 2 test migrations with auto filename
|
||||
"migrations_list_test.go",
|
||||
"migrations_list_test.go",
|
||||
}
|
||||
|
||||
items := l2.Items()
|
||||
if len(items) != len(expected) {
|
||||
names := make([]string, len(items))
|
||||
for i, item := range items {
|
||||
names[i] = item.File
|
||||
}
|
||||
t.Fatalf("Expected %d items, got %d:\n%v", len(expected), len(names), names)
|
||||
}
|
||||
|
||||
for i, name := range expected {
|
||||
item := l2.Item(i)
|
||||
if item.File != name {
|
||||
t.Fatalf("Expected name %s for index %d, got %s", name, i, item.File)
|
||||
}
|
||||
}
|
||||
}
|
317
core/migrations_runner.go
Normal file
317
core/migrations_runner.go
Normal file
|
@ -0,0 +1,317 @@
|
|||
package core
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/fatih/color"
|
||||
"github.com/pocketbase/dbx"
|
||||
"github.com/pocketbase/pocketbase/tools/osutils"
|
||||
"github.com/spf13/cast"
|
||||
)
|
||||
|
||||
var AppMigrations MigrationsList
|
||||
var SystemMigrations MigrationsList
|
||||
|
||||
const DefaultMigrationsTable = "_migrations"
|
||||
|
||||
// MigrationsRunner defines a simple struct for managing the execution of db migrations.
|
||||
type MigrationsRunner struct {
|
||||
app App
|
||||
tableName string
|
||||
migrationsList MigrationsList
|
||||
inited bool
|
||||
}
|
||||
|
||||
// NewMigrationsRunner creates and initializes a new db migrations MigrationsRunner instance.
|
||||
func NewMigrationsRunner(app App, migrationsList MigrationsList) *MigrationsRunner {
|
||||
return &MigrationsRunner{
|
||||
app: app,
|
||||
migrationsList: migrationsList,
|
||||
tableName: DefaultMigrationsTable,
|
||||
}
|
||||
}
|
||||
|
||||
// Run interactively executes the current runner with the provided args.
|
||||
//
|
||||
// The following commands are supported:
|
||||
// - up - applies all migrations
|
||||
// - down [n] - reverts the last n (default 1) applied migrations
|
||||
// - history-sync - syncs the migrations table with the runner's migrations list
|
||||
func (r *MigrationsRunner) Run(args ...string) error {
|
||||
if err := r.initMigrationsTable(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cmd := "up"
|
||||
if len(args) > 0 {
|
||||
cmd = args[0]
|
||||
}
|
||||
|
||||
switch cmd {
|
||||
case "up":
|
||||
applied, err := r.Up()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(applied) == 0 {
|
||||
color.Green("No new migrations to apply.")
|
||||
} else {
|
||||
for _, file := range applied {
|
||||
color.Green("Applied %s", file)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
case "down":
|
||||
toRevertCount := 1
|
||||
if len(args) > 1 {
|
||||
toRevertCount = cast.ToInt(args[1])
|
||||
if toRevertCount < 0 {
|
||||
// revert all applied migrations
|
||||
toRevertCount = len(r.migrationsList.Items())
|
||||
}
|
||||
}
|
||||
|
||||
names, err := r.lastAppliedMigrations(toRevertCount)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
confirm := osutils.YesNoPrompt(fmt.Sprintf(
|
||||
"\n%v\nDo you really want to revert the last %d applied migration(s)?",
|
||||
strings.Join(names, "\n"),
|
||||
toRevertCount,
|
||||
), false)
|
||||
if !confirm {
|
||||
fmt.Println("The command has been cancelled")
|
||||
return nil
|
||||
}
|
||||
|
||||
reverted, err := r.Down(toRevertCount)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(reverted) == 0 {
|
||||
color.Green("No migrations to revert.")
|
||||
} else {
|
||||
for _, file := range reverted {
|
||||
color.Green("Reverted %s", file)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
case "history-sync":
|
||||
if err := r.RemoveMissingAppliedMigrations(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
color.Green("The %s table was synced with the available migrations.", r.tableName)
|
||||
return nil
|
||||
default:
|
||||
return fmt.Errorf("unsupported command: %q", cmd)
|
||||
}
|
||||
}
|
||||
|
||||
// Up executes all unapplied migrations for the provided runner.
|
||||
//
|
||||
// On success returns list with the applied migrations file names.
|
||||
func (r *MigrationsRunner) Up() ([]string, error) {
|
||||
if err := r.initMigrationsTable(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
applied := []string{}
|
||||
|
||||
err := r.app.AuxRunInTransaction(func(txApp App) error {
|
||||
return txApp.RunInTransaction(func(txApp App) error {
|
||||
for _, m := range r.migrationsList.Items() {
|
||||
// applied migrations check
|
||||
if r.isMigrationApplied(txApp, m.File) {
|
||||
if m.ReapplyCondition == nil {
|
||||
continue // no need to reapply
|
||||
}
|
||||
|
||||
shouldReapply, err := m.ReapplyCondition(txApp, r, m.File)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !shouldReapply {
|
||||
continue
|
||||
}
|
||||
|
||||
// clear previous history stored entry
|
||||
// (it will be recreated after successful execution)
|
||||
r.saveRevertedMigration(txApp, m.File)
|
||||
}
|
||||
|
||||
// ignore empty Up action
|
||||
if m.Up != nil {
|
||||
if err := m.Up(txApp); err != nil {
|
||||
return fmt.Errorf("failed to apply migration %s: %w", m.File, err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := r.saveAppliedMigration(txApp, m.File); err != nil {
|
||||
return fmt.Errorf("failed to save applied migration info for %s: %w", m.File, err)
|
||||
}
|
||||
|
||||
applied = append(applied, m.File)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return applied, nil
|
||||
}
|
||||
|
||||
// Down reverts the last `toRevertCount` applied migrations
|
||||
// (in the order they were applied).
|
||||
//
|
||||
// On success returns list with the reverted migrations file names.
|
||||
func (r *MigrationsRunner) Down(toRevertCount int) ([]string, error) {
|
||||
if err := r.initMigrationsTable(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
reverted := make([]string, 0, toRevertCount)
|
||||
|
||||
names, appliedErr := r.lastAppliedMigrations(toRevertCount)
|
||||
if appliedErr != nil {
|
||||
return nil, appliedErr
|
||||
}
|
||||
|
||||
err := r.app.AuxRunInTransaction(func(txApp App) error {
|
||||
return txApp.RunInTransaction(func(txApp App) error {
|
||||
for _, name := range names {
|
||||
for _, m := range r.migrationsList.Items() {
|
||||
if m.File != name {
|
||||
continue
|
||||
}
|
||||
|
||||
// revert limit reached
|
||||
if toRevertCount-len(reverted) <= 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// ignore empty Down action
|
||||
if m.Down != nil {
|
||||
if err := m.Down(txApp); err != nil {
|
||||
return fmt.Errorf("failed to revert migration %s: %w", m.File, err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := r.saveRevertedMigration(txApp, m.File); err != nil {
|
||||
return fmt.Errorf("failed to save reverted migration info for %s: %w", m.File, err)
|
||||
}
|
||||
|
||||
reverted = append(reverted, m.File)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return reverted, nil
|
||||
}
|
||||
|
||||
// RemoveMissingAppliedMigrations removes the db entries of all applied migrations
|
||||
// that are not listed in the runner's migrations list.
|
||||
func (r *MigrationsRunner) RemoveMissingAppliedMigrations() error {
|
||||
loadedMigrations := r.migrationsList.Items()
|
||||
|
||||
names := make([]any, len(loadedMigrations))
|
||||
for i, migration := range loadedMigrations {
|
||||
names[i] = migration.File
|
||||
}
|
||||
|
||||
_, err := r.app.DB().Delete(r.tableName, dbx.Not(dbx.HashExp{
|
||||
"file": names,
|
||||
})).Execute()
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *MigrationsRunner) initMigrationsTable() error {
|
||||
if r.inited {
|
||||
return nil // already inited
|
||||
}
|
||||
|
||||
rawQuery := fmt.Sprintf(
|
||||
"CREATE TABLE IF NOT EXISTS {{%s}} (file VARCHAR(255) PRIMARY KEY NOT NULL, applied INTEGER NOT NULL)",
|
||||
r.tableName,
|
||||
)
|
||||
|
||||
_, err := r.app.DB().NewQuery(rawQuery).Execute()
|
||||
|
||||
if err == nil {
|
||||
r.inited = true
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *MigrationsRunner) isMigrationApplied(txApp App, file string) bool {
|
||||
var exists int
|
||||
|
||||
err := txApp.DB().Select("(1)").
|
||||
From(r.tableName).
|
||||
Where(dbx.HashExp{"file": file}).
|
||||
Limit(1).
|
||||
Row(&exists)
|
||||
|
||||
return err == nil && exists > 0
|
||||
}
|
||||
|
||||
func (r *MigrationsRunner) saveAppliedMigration(txApp App, file string) error {
|
||||
_, err := txApp.DB().Insert(r.tableName, dbx.Params{
|
||||
"file": file,
|
||||
"applied": time.Now().UnixMicro(),
|
||||
}).Execute()
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *MigrationsRunner) saveRevertedMigration(txApp App, file string) error {
|
||||
_, err := txApp.DB().Delete(r.tableName, dbx.HashExp{"file": file}).Execute()
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *MigrationsRunner) lastAppliedMigrations(limit int) ([]string, error) {
|
||||
var files = make([]string, 0, limit)
|
||||
|
||||
loadedMigrations := r.migrationsList.Items()
|
||||
|
||||
names := make([]any, len(loadedMigrations))
|
||||
for i, migration := range loadedMigrations {
|
||||
names[i] = migration.File
|
||||
}
|
||||
|
||||
err := r.app.DB().Select("file").
|
||||
From(r.tableName).
|
||||
Where(dbx.Not(dbx.HashExp{"applied": nil})).
|
||||
AndWhere(dbx.HashExp{"file": names}).
|
||||
// unify microseconds and seconds applied time for backward compatibility
|
||||
OrderBy("substr(applied||'0000000000000000', 0, 17) DESC").
|
||||
AndOrderBy("file DESC").
|
||||
Limit(int64(limit)).
|
||||
Column(&files)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return files, nil
|
||||
}
|
212
core/migrations_runner_test.go
Normal file
212
core/migrations_runner_test.go
Normal file
|
@ -0,0 +1,212 @@
|
|||
package core_test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/pocketbase/dbx"
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tests"
|
||||
)
|
||||
|
||||
func TestMigrationsRunnerUpAndDown(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
callsOrder := []string{}
|
||||
|
||||
l := core.MigrationsList{}
|
||||
l.Register(func(app core.App) error {
|
||||
callsOrder = append(callsOrder, "up2")
|
||||
return nil
|
||||
}, func(app core.App) error {
|
||||
callsOrder = append(callsOrder, "down2")
|
||||
return nil
|
||||
}, "2_test")
|
||||
l.Register(func(app core.App) error {
|
||||
callsOrder = append(callsOrder, "up3")
|
||||
return nil
|
||||
}, func(app core.App) error {
|
||||
callsOrder = append(callsOrder, "down3")
|
||||
return nil
|
||||
}, "3_test")
|
||||
l.Register(func(app core.App) error {
|
||||
callsOrder = append(callsOrder, "up1")
|
||||
return nil
|
||||
}, func(app core.App) error {
|
||||
callsOrder = append(callsOrder, "down1")
|
||||
return nil
|
||||
}, "1_test")
|
||||
l.Register(func(app core.App) error {
|
||||
callsOrder = append(callsOrder, "up4")
|
||||
return nil
|
||||
}, func(app core.App) error {
|
||||
callsOrder = append(callsOrder, "down4")
|
||||
return nil
|
||||
}, "4_test")
|
||||
l.Add(&core.Migration{
|
||||
Up: func(app core.App) error {
|
||||
callsOrder = append(callsOrder, "up5")
|
||||
return nil
|
||||
},
|
||||
Down: func(app core.App) error {
|
||||
callsOrder = append(callsOrder, "down5")
|
||||
return nil
|
||||
},
|
||||
File: "5_test",
|
||||
ReapplyCondition: func(txApp core.App, runner *core.MigrationsRunner, fileName string) (bool, error) {
|
||||
return true, nil
|
||||
},
|
||||
})
|
||||
|
||||
runner := core.NewMigrationsRunner(app, l)
|
||||
|
||||
// ---------------------------------------------------------------
|
||||
// simulate partially out-of-order applied migration
|
||||
// ---------------------------------------------------------------
|
||||
|
||||
_, err := app.DB().Insert(core.DefaultMigrationsTable, dbx.Params{
|
||||
"file": "4_test",
|
||||
"applied": time.Now().UnixMicro() - 2,
|
||||
}).Execute()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to insert 5_test migration: %v", err)
|
||||
}
|
||||
|
||||
_, err = app.DB().Insert(core.DefaultMigrationsTable, dbx.Params{
|
||||
"file": "5_test",
|
||||
"applied": time.Now().UnixMicro() - 1,
|
||||
}).Execute()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to insert 5_test migration: %v", err)
|
||||
}
|
||||
|
||||
_, err = app.DB().Insert(core.DefaultMigrationsTable, dbx.Params{
|
||||
"file": "2_test",
|
||||
"applied": time.Now().UnixMicro(),
|
||||
}).Execute()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to insert 2_test migration: %v", err)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------
|
||||
// Up()
|
||||
// ---------------------------------------------------------------
|
||||
|
||||
if _, err := runner.Up(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
expectedUpCallsOrder := `["up1","up3","up5"]` // skip up2 and up4 since they were applied already (up5 has extra reapply condition)
|
||||
|
||||
upCallsOrder, err := json.Marshal(callsOrder)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if v := string(upCallsOrder); v != expectedUpCallsOrder {
|
||||
t.Fatalf("Expected Up() calls order %s, got %s", expectedUpCallsOrder, upCallsOrder)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------
|
||||
|
||||
// reset callsOrder
|
||||
callsOrder = []string{}
|
||||
|
||||
// simulate unrun migration
|
||||
l.Register(nil, func(app core.App) error {
|
||||
callsOrder = append(callsOrder, "down6")
|
||||
return nil
|
||||
}, "6_test")
|
||||
|
||||
// simulate applied migrations from different migrations list
|
||||
_, err = app.DB().Insert(core.DefaultMigrationsTable, dbx.Params{
|
||||
"file": "from_different_list",
|
||||
"applied": time.Now().UnixMicro(),
|
||||
}).Execute()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to insert from_different_list migration: %v", err)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------
|
||||
|
||||
// ---------------------------------------------------------------
|
||||
// Down()
|
||||
// ---------------------------------------------------------------
|
||||
|
||||
if _, err := runner.Down(2); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
expectedDownCallsOrder := `["down5","down3"]` // revert in the applied order
|
||||
|
||||
downCallsOrder, err := json.Marshal(callsOrder)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if v := string(downCallsOrder); v != expectedDownCallsOrder {
|
||||
t.Fatalf("Expected Down() calls order %s, got %s", expectedDownCallsOrder, downCallsOrder)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMigrationsRunnerRemoveMissingAppliedMigrations(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
// mock migrations history
|
||||
for i := 1; i <= 3; i++ {
|
||||
_, err := app.DB().Insert(core.DefaultMigrationsTable, dbx.Params{
|
||||
"file": fmt.Sprintf("%d_test", i),
|
||||
"applied": time.Now().UnixMicro(),
|
||||
}).Execute()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
if !isMigrationApplied(app, "2_test") {
|
||||
t.Fatalf("Expected 2_test migration to be applied")
|
||||
}
|
||||
|
||||
// create a runner without 2_test to mock deleted migration
|
||||
l := core.MigrationsList{}
|
||||
l.Register(func(app core.App) error {
|
||||
return nil
|
||||
}, func(app core.App) error {
|
||||
return nil
|
||||
}, "1_test")
|
||||
l.Register(func(app core.App) error {
|
||||
return nil
|
||||
}, func(app core.App) error {
|
||||
return nil
|
||||
}, "3_test")
|
||||
|
||||
r := core.NewMigrationsRunner(app, l)
|
||||
|
||||
if err := r.RemoveMissingAppliedMigrations(); err != nil {
|
||||
t.Fatalf("Failed to remove missing applied migrations: %v", err)
|
||||
}
|
||||
|
||||
if isMigrationApplied(app, "2_test") {
|
||||
t.Fatalf("Expected 2_test migration to NOT be applied")
|
||||
}
|
||||
}
|
||||
|
||||
func isMigrationApplied(app core.App, file string) bool {
|
||||
var exists int
|
||||
|
||||
err := app.DB().Select("(1)").
|
||||
From(core.DefaultMigrationsTable).
|
||||
Where(dbx.HashExp{"file": file}).
|
||||
Limit(1).
|
||||
Row(&exists)
|
||||
|
||||
return err == nil && exists > 0
|
||||
}
|
127
core/otp_model.go
Normal file
127
core/otp_model.go
Normal file
|
@ -0,0 +1,127 @@
|
|||
package core
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"github.com/pocketbase/pocketbase/tools/types"
|
||||
)
|
||||
|
||||
const CollectionNameOTPs = "_otps"
|
||||
|
||||
var (
|
||||
_ Model = (*OTP)(nil)
|
||||
_ PreValidator = (*OTP)(nil)
|
||||
_ RecordProxy = (*OTP)(nil)
|
||||
)
|
||||
|
||||
// OTP defines a Record proxy for working with the otps collection.
|
||||
type OTP struct {
|
||||
*Record
|
||||
}
|
||||
|
||||
// NewOTP instantiates and returns a new blank *OTP model.
|
||||
//
|
||||
// Example usage:
|
||||
//
|
||||
// otp := core.NewOTP(app)
|
||||
// otp.SetRecordRef(user.Id)
|
||||
// otp.SetCollectionRef(user.Collection().Id)
|
||||
// otp.SetPassword(security.RandomStringWithAlphabet(6, "1234567890"))
|
||||
// app.Save(otp)
|
||||
func NewOTP(app App) *OTP {
|
||||
m := &OTP{}
|
||||
|
||||
c, err := app.FindCachedCollectionByNameOrId(CollectionNameOTPs)
|
||||
if err != nil {
|
||||
// this is just to make tests easier since otp is a system collection and it is expected to be always accessible
|
||||
// (note: the loaded record is further checked on OTP.PreValidate())
|
||||
c = NewBaseCollection("__invalid__")
|
||||
}
|
||||
|
||||
m.Record = NewRecord(c)
|
||||
|
||||
return m
|
||||
}
|
||||
|
||||
// PreValidate implements the [PreValidator] interface and checks
|
||||
// whether the proxy is properly loaded.
|
||||
func (m *OTP) PreValidate(ctx context.Context, app App) error {
|
||||
if m.Record == nil || m.Record.Collection().Name != CollectionNameOTPs {
|
||||
return errors.New("missing or invalid otp ProxyRecord")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ProxyRecord returns the proxied Record model.
|
||||
func (m *OTP) ProxyRecord() *Record {
|
||||
return m.Record
|
||||
}
|
||||
|
||||
// SetProxyRecord loads the specified record model into the current proxy.
|
||||
func (m *OTP) SetProxyRecord(record *Record) {
|
||||
m.Record = record
|
||||
}
|
||||
|
||||
// CollectionRef returns the "collectionRef" field value.
|
||||
func (m *OTP) CollectionRef() string {
|
||||
return m.GetString("collectionRef")
|
||||
}
|
||||
|
||||
// SetCollectionRef updates the "collectionRef" record field value.
|
||||
func (m *OTP) SetCollectionRef(collectionId string) {
|
||||
m.Set("collectionRef", collectionId)
|
||||
}
|
||||
|
||||
// RecordRef returns the "recordRef" record field value.
|
||||
func (m *OTP) RecordRef() string {
|
||||
return m.GetString("recordRef")
|
||||
}
|
||||
|
||||
// SetRecordRef updates the "recordRef" record field value.
|
||||
func (m *OTP) SetRecordRef(recordId string) {
|
||||
m.Set("recordRef", recordId)
|
||||
}
|
||||
|
||||
// SentTo returns the "sentTo" record field value.
|
||||
//
|
||||
// It could be any string value (email, phone, message app id, etc.)
|
||||
// and usually is used as part of the auth flow to update the verified
|
||||
// user state in case for example the sentTo value matches with the user record email.
|
||||
func (m *OTP) SentTo() string {
|
||||
return m.GetString("sentTo")
|
||||
}
|
||||
|
||||
// SetSentTo updates the "sentTo" record field value.
|
||||
func (m *OTP) SetSentTo(val string) {
|
||||
m.Set("sentTo", val)
|
||||
}
|
||||
|
||||
// Created returns the "created" record field value.
|
||||
func (m *OTP) Created() types.DateTime {
|
||||
return m.GetDateTime("created")
|
||||
}
|
||||
|
||||
// Updated returns the "updated" record field value.
|
||||
func (m *OTP) Updated() types.DateTime {
|
||||
return m.GetDateTime("updated")
|
||||
}
|
||||
|
||||
// HasExpired checks if the otp is expired, aka. whether it has been
|
||||
// more than maxElapsed time since its creation.
|
||||
func (m *OTP) HasExpired(maxElapsed time.Duration) bool {
|
||||
return time.Since(m.Created().Time()) > maxElapsed
|
||||
}
|
||||
|
||||
func (app *BaseApp) registerOTPHooks() {
|
||||
recordRefHooks[*OTP](app, CollectionNameOTPs, CollectionTypeAuth)
|
||||
|
||||
// run on every hour to cleanup expired otp sessions
|
||||
app.Cron().Add("__pbOTPCleanup__", "0 * * * *", func() {
|
||||
if err := app.DeleteExpiredOTPs(); err != nil {
|
||||
app.Logger().Warn("Failed to delete expired OTP sessions", "error", err)
|
||||
}
|
||||
})
|
||||
}
|
302
core/otp_model_test.go
Normal file
302
core/otp_model_test.go
Normal file
|
@ -0,0 +1,302 @@
|
|||
package core_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tests"
|
||||
"github.com/pocketbase/pocketbase/tools/types"
|
||||
)
|
||||
|
||||
func TestNewOTP(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
otp := core.NewOTP(app)
|
||||
|
||||
if otp.Collection().Name != core.CollectionNameOTPs {
|
||||
t.Fatalf("Expected record with %q collection, got %q", core.CollectionNameOTPs, otp.Collection().Name)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOTPProxyRecord(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
record := core.NewRecord(core.NewBaseCollection("test"))
|
||||
record.Id = "test_id"
|
||||
|
||||
otp := core.OTP{}
|
||||
otp.SetProxyRecord(record)
|
||||
|
||||
if otp.ProxyRecord() == nil || otp.ProxyRecord().Id != record.Id {
|
||||
t.Fatalf("Expected proxy record with id %q, got %v", record.Id, otp.ProxyRecord())
|
||||
}
|
||||
}
|
||||
|
||||
func TestOTPRecordRef(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
otp := core.NewOTP(app)
|
||||
|
||||
testValues := []string{"test_1", "test2", ""}
|
||||
for i, testValue := range testValues {
|
||||
t.Run(fmt.Sprintf("%d_%q", i, testValue), func(t *testing.T) {
|
||||
otp.SetRecordRef(testValue)
|
||||
|
||||
if v := otp.RecordRef(); v != testValue {
|
||||
t.Fatalf("Expected getter %q, got %q", testValue, v)
|
||||
}
|
||||
|
||||
if v := otp.GetString("recordRef"); v != testValue {
|
||||
t.Fatalf("Expected field value %q, got %q", testValue, v)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestOTPCollectionRef(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
otp := core.NewOTP(app)
|
||||
|
||||
testValues := []string{"test_1", "test2", ""}
|
||||
for i, testValue := range testValues {
|
||||
t.Run(fmt.Sprintf("%d_%q", i, testValue), func(t *testing.T) {
|
||||
otp.SetCollectionRef(testValue)
|
||||
|
||||
if v := otp.CollectionRef(); v != testValue {
|
||||
t.Fatalf("Expected getter %q, got %q", testValue, v)
|
||||
}
|
||||
|
||||
if v := otp.GetString("collectionRef"); v != testValue {
|
||||
t.Fatalf("Expected field value %q, got %q", testValue, v)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestOTPSentTo(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
otp := core.NewOTP(app)
|
||||
|
||||
testValues := []string{"test_1", "test2", ""}
|
||||
for i, testValue := range testValues {
|
||||
t.Run(fmt.Sprintf("%d_%q", i, testValue), func(t *testing.T) {
|
||||
otp.SetSentTo(testValue)
|
||||
|
||||
if v := otp.SentTo(); v != testValue {
|
||||
t.Fatalf("Expected getter %q, got %q", testValue, v)
|
||||
}
|
||||
|
||||
if v := otp.GetString("sentTo"); v != testValue {
|
||||
t.Fatalf("Expected field value %q, got %q", testValue, v)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestOTPCreated(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
otp := core.NewOTP(app)
|
||||
|
||||
if v := otp.Created().String(); v != "" {
|
||||
t.Fatalf("Expected empty created, got %q", v)
|
||||
}
|
||||
|
||||
now := types.NowDateTime()
|
||||
otp.SetRaw("created", now)
|
||||
|
||||
if v := otp.Created().String(); v != now.String() {
|
||||
t.Fatalf("Expected %q created, got %q", now.String(), v)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOTPUpdated(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
otp := core.NewOTP(app)
|
||||
|
||||
if v := otp.Updated().String(); v != "" {
|
||||
t.Fatalf("Expected empty updated, got %q", v)
|
||||
}
|
||||
|
||||
now := types.NowDateTime()
|
||||
otp.SetRaw("updated", now)
|
||||
|
||||
if v := otp.Updated().String(); v != now.String() {
|
||||
t.Fatalf("Expected %q updated, got %q", now.String(), v)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOTPHasExpired(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
now := types.NowDateTime()
|
||||
|
||||
otp := core.NewOTP(app)
|
||||
otp.SetRaw("created", now.Add(-5*time.Minute))
|
||||
|
||||
scenarios := []struct {
|
||||
maxElapsed time.Duration
|
||||
expected bool
|
||||
}{
|
||||
{0 * time.Minute, true},
|
||||
{3 * time.Minute, true},
|
||||
{5 * time.Minute, true},
|
||||
{6 * time.Minute, false},
|
||||
}
|
||||
|
||||
for i, s := range scenarios {
|
||||
t.Run(fmt.Sprintf("%d_%s", i, s.maxElapsed.String()), func(t *testing.T) {
|
||||
result := otp.HasExpired(s.maxElapsed)
|
||||
|
||||
if result != s.expected {
|
||||
t.Fatalf("Expected %v, got %v", s.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestOTPPreValidate(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
otpsCol, err := app.FindCollectionByNameOrId(core.CollectionNameOTPs)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
user, err := app.FindAuthRecordByEmail("users", "test@example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
t.Run("no proxy record", func(t *testing.T) {
|
||||
otp := &core.OTP{}
|
||||
|
||||
if err := app.Validate(otp); err == nil {
|
||||
t.Fatal("Expected collection validation error")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("non-OTP collection", func(t *testing.T) {
|
||||
otp := &core.OTP{}
|
||||
otp.SetProxyRecord(core.NewRecord(core.NewBaseCollection("invalid")))
|
||||
otp.SetRecordRef(user.Id)
|
||||
otp.SetCollectionRef(user.Collection().Id)
|
||||
otp.SetPassword("test123")
|
||||
|
||||
if err := app.Validate(otp); err == nil {
|
||||
t.Fatal("Expected collection validation error")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("OTP collection", func(t *testing.T) {
|
||||
otp := &core.OTP{}
|
||||
otp.SetProxyRecord(core.NewRecord(otpsCol))
|
||||
otp.SetRecordRef(user.Id)
|
||||
otp.SetCollectionRef(user.Collection().Id)
|
||||
otp.SetPassword("test123")
|
||||
|
||||
if err := app.Validate(otp); err != nil {
|
||||
t.Fatalf("Expected nil validation error, got %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestOTPValidateHook(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
user, err := app.FindAuthRecordByEmail("users", "test@example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
demo1, err := app.FindRecordById("demo1", "84nmscqy84lsi1t")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
scenarios := []struct {
|
||||
name string
|
||||
otp func() *core.OTP
|
||||
expectErrors []string
|
||||
}{
|
||||
{
|
||||
"empty",
|
||||
func() *core.OTP {
|
||||
return core.NewOTP(app)
|
||||
},
|
||||
[]string{"collectionRef", "recordRef", "password"},
|
||||
},
|
||||
{
|
||||
"non-auth collection",
|
||||
func() *core.OTP {
|
||||
otp := core.NewOTP(app)
|
||||
otp.SetCollectionRef(demo1.Collection().Id)
|
||||
otp.SetRecordRef(demo1.Id)
|
||||
otp.SetPassword("test123")
|
||||
return otp
|
||||
},
|
||||
[]string{"collectionRef"},
|
||||
},
|
||||
{
|
||||
"missing record id",
|
||||
func() *core.OTP {
|
||||
otp := core.NewOTP(app)
|
||||
otp.SetCollectionRef(user.Collection().Id)
|
||||
otp.SetRecordRef("missing")
|
||||
otp.SetPassword("test123")
|
||||
return otp
|
||||
},
|
||||
[]string{"recordRef"},
|
||||
},
|
||||
{
|
||||
"valid ref",
|
||||
func() *core.OTP {
|
||||
otp := core.NewOTP(app)
|
||||
otp.SetCollectionRef(user.Collection().Id)
|
||||
otp.SetRecordRef(user.Id)
|
||||
otp.SetPassword("test123")
|
||||
return otp
|
||||
},
|
||||
[]string{},
|
||||
},
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
t.Run(s.name, func(t *testing.T) {
|
||||
errs := app.Validate(s.otp())
|
||||
tests.TestValidationErrors(t, errs, s.expectErrors)
|
||||
})
|
||||
}
|
||||
}
|
117
core/otp_query.go
Normal file
117
core/otp_query.go
Normal file
|
@ -0,0 +1,117 @@
|
|||
package core
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"github.com/pocketbase/dbx"
|
||||
"github.com/pocketbase/pocketbase/tools/types"
|
||||
)
|
||||
|
||||
// FindAllOTPsByRecord returns all OTP models linked to the provided auth record.
|
||||
func (app *BaseApp) FindAllOTPsByRecord(authRecord *Record) ([]*OTP, error) {
|
||||
result := []*OTP{}
|
||||
|
||||
err := app.RecordQuery(CollectionNameOTPs).
|
||||
AndWhere(dbx.HashExp{
|
||||
"collectionRef": authRecord.Collection().Id,
|
||||
"recordRef": authRecord.Id,
|
||||
}).
|
||||
OrderBy("created DESC").
|
||||
All(&result)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// FindAllOTPsByCollection returns all OTP models linked to the provided collection.
|
||||
func (app *BaseApp) FindAllOTPsByCollection(collection *Collection) ([]*OTP, error) {
|
||||
result := []*OTP{}
|
||||
|
||||
err := app.RecordQuery(CollectionNameOTPs).
|
||||
AndWhere(dbx.HashExp{"collectionRef": collection.Id}).
|
||||
OrderBy("created DESC").
|
||||
All(&result)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// FindOTPById returns a single OTP model by its id.
|
||||
func (app *BaseApp) FindOTPById(id string) (*OTP, error) {
|
||||
result := &OTP{}
|
||||
|
||||
err := app.RecordQuery(CollectionNameOTPs).
|
||||
AndWhere(dbx.HashExp{"id": id}).
|
||||
Limit(1).
|
||||
One(result)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// DeleteAllOTPsByRecord deletes all OTP models associated with the provided record.
|
||||
//
|
||||
// Returns a combined error with the failed deletes.
|
||||
func (app *BaseApp) DeleteAllOTPsByRecord(authRecord *Record) error {
|
||||
models, err := app.FindAllOTPsByRecord(authRecord)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var errs []error
|
||||
for _, m := range models {
|
||||
if err := app.Delete(m); err != nil {
|
||||
errs = append(errs, err)
|
||||
}
|
||||
}
|
||||
if len(errs) > 0 {
|
||||
return errors.Join(errs...)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteExpiredOTPs deletes the expired OTPs for all auth collections.
|
||||
func (app *BaseApp) DeleteExpiredOTPs() error {
|
||||
authCollections, err := app.FindAllCollections(CollectionTypeAuth)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// note: perform even if OTP is disabled to ensure that there are no dangling old records
|
||||
for _, collection := range authCollections {
|
||||
minValidDate, err := types.ParseDateTime(time.Now().Add(-1 * collection.OTP.DurationTime()))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
items := []*Record{}
|
||||
|
||||
err = app.RecordQuery(CollectionNameOTPs).
|
||||
AndWhere(dbx.HashExp{"collectionRef": collection.Id}).
|
||||
AndWhere(dbx.NewExp("[[created]] < {:date}", dbx.Params{"date": minValidDate})).
|
||||
All(&items)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, item := range items {
|
||||
err = app.Delete(item)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
310
core/otp_query_test.go
Normal file
310
core/otp_query_test.go
Normal file
|
@ -0,0 +1,310 @@
|
|||
package core_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"slices"
|
||||
"testing"
|
||||
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tests"
|
||||
)
|
||||
|
||||
func TestFindAllOTPsByRecord(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
if err := tests.StubOTPRecords(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
demo1, err := app.FindRecordById("demo1", "84nmscqy84lsi1t")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
superuser2, err := app.FindAuthRecordByEmail(core.CollectionNameSuperusers, "test2@example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
superuser4, err := app.FindAuthRecordByEmail(core.CollectionNameSuperusers, "test4@example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
user1, err := app.FindAuthRecordByEmail("users", "test@example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
scenarios := []struct {
|
||||
record *core.Record
|
||||
expected []string
|
||||
}{
|
||||
{demo1, nil},
|
||||
{superuser2, []string{"superuser2_0", "superuser2_1", "superuser2_3", "superuser2_2", "superuser2_4"}},
|
||||
{superuser4, nil},
|
||||
{user1, []string{"user1_0"}},
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
t.Run(s.record.Collection().Name+"_"+s.record.Id, func(t *testing.T) {
|
||||
result, err := app.FindAllOTPsByRecord(s.record)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if len(result) != len(s.expected) {
|
||||
t.Fatalf("Expected total otps %d, got %d", len(s.expected), len(result))
|
||||
}
|
||||
|
||||
for i, id := range s.expected {
|
||||
if result[i].Id != id {
|
||||
t.Errorf("[%d] Expected id %q, got %q", i, id, result[i].Id)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFindAllOTPsByCollection(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
if err := tests.StubOTPRecords(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
demo1, err := app.FindCollectionByNameOrId("demo1")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
superusers, err := app.FindCollectionByNameOrId(core.CollectionNameSuperusers)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
clients, err := app.FindCollectionByNameOrId("clients")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
users, err := app.FindCollectionByNameOrId("users")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
scenarios := []struct {
|
||||
collection *core.Collection
|
||||
expected []string
|
||||
}{
|
||||
{demo1, nil},
|
||||
{superusers, []string{
|
||||
"superuser2_0",
|
||||
"superuser2_1",
|
||||
"superuser2_3",
|
||||
"superuser3_0",
|
||||
"superuser3_1",
|
||||
"superuser2_2",
|
||||
"superuser2_4",
|
||||
}},
|
||||
{clients, nil},
|
||||
{users, []string{"user1_0"}},
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
t.Run(s.collection.Name, func(t *testing.T) {
|
||||
result, err := app.FindAllOTPsByCollection(s.collection)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if len(result) != len(s.expected) {
|
||||
t.Fatalf("Expected total otps %d, got %d", len(s.expected), len(result))
|
||||
}
|
||||
|
||||
for i, id := range s.expected {
|
||||
if result[i].Id != id {
|
||||
t.Errorf("[%d] Expected id %q, got %q", i, id, result[i].Id)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFindOTPById(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
if err := tests.StubOTPRecords(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
scenarios := []struct {
|
||||
id string
|
||||
expectError bool
|
||||
}{
|
||||
{"", true},
|
||||
{"84nmscqy84lsi1t", true}, // non-otp id
|
||||
{"superuser2_0", false},
|
||||
{"superuser2_4", false}, // expired
|
||||
{"user1_0", false},
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
t.Run(s.id, func(t *testing.T) {
|
||||
result, err := app.FindOTPById(s.id)
|
||||
|
||||
hasErr := err != nil
|
||||
if hasErr != s.expectError {
|
||||
t.Fatalf("Expected hasErr %v, got %v (%v)", s.expectError, hasErr, err)
|
||||
}
|
||||
|
||||
if hasErr {
|
||||
return
|
||||
}
|
||||
|
||||
if result.Id != s.id {
|
||||
t.Fatalf("Expected record with id %q, got %q", s.id, result.Id)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteAllOTPsByRecord(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testApp, _ := tests.NewTestApp()
|
||||
defer testApp.Cleanup()
|
||||
|
||||
demo1, err := testApp.FindRecordById("demo1", "84nmscqy84lsi1t")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
superuser2, err := testApp.FindAuthRecordByEmail(core.CollectionNameSuperusers, "test2@example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
superuser4, err := testApp.FindAuthRecordByEmail(core.CollectionNameSuperusers, "test4@example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
user1, err := testApp.FindAuthRecordByEmail("users", "test@example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
scenarios := []struct {
|
||||
record *core.Record
|
||||
deletedIds []string
|
||||
}{
|
||||
{demo1, nil}, // non-auth record
|
||||
{superuser2, []string{"superuser2_0", "superuser2_1", "superuser2_3", "superuser2_2", "superuser2_4"}},
|
||||
{superuser4, nil},
|
||||
{user1, []string{"user1_0"}},
|
||||
}
|
||||
|
||||
for i, s := range scenarios {
|
||||
t.Run(fmt.Sprintf("%d_%s_%s", i, s.record.Collection().Name, s.record.Id), func(t *testing.T) {
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
if err := tests.StubOTPRecords(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
deletedIds := []string{}
|
||||
app.OnRecordAfterDeleteSuccess().BindFunc(func(e *core.RecordEvent) error {
|
||||
deletedIds = append(deletedIds, e.Record.Id)
|
||||
return e.Next()
|
||||
})
|
||||
|
||||
err := app.DeleteAllOTPsByRecord(s.record)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if len(deletedIds) != len(s.deletedIds) {
|
||||
t.Fatalf("Expected deleted ids\n%v\ngot\n%v", s.deletedIds, deletedIds)
|
||||
}
|
||||
|
||||
for _, id := range s.deletedIds {
|
||||
if !slices.Contains(deletedIds, id) {
|
||||
t.Errorf("Expected to find deleted id %q in %v", id, deletedIds)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteExpiredOTPs(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
checkDeletedIds := func(app core.App, t *testing.T, expectedDeletedIds []string) {
|
||||
if err := tests.StubOTPRecords(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
deletedIds := []string{}
|
||||
app.OnRecordAfterDeleteSuccess().BindFunc(func(e *core.RecordEvent) error {
|
||||
deletedIds = append(deletedIds, e.Record.Id)
|
||||
return e.Next()
|
||||
})
|
||||
|
||||
if err := app.DeleteExpiredOTPs(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if len(deletedIds) != len(expectedDeletedIds) {
|
||||
t.Fatalf("Expected deleted ids\n%v\ngot\n%v", expectedDeletedIds, deletedIds)
|
||||
}
|
||||
|
||||
for _, id := range expectedDeletedIds {
|
||||
if !slices.Contains(deletedIds, id) {
|
||||
t.Errorf("Expected to find deleted id %q in %v", id, deletedIds)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
t.Run("default test collections", func(t *testing.T) {
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
checkDeletedIds(app, t, []string{
|
||||
"user1_0",
|
||||
"superuser2_2",
|
||||
"superuser2_4",
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("otp collection duration mock", func(t *testing.T) {
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
superusers, err := app.FindCollectionByNameOrId(core.CollectionNameSuperusers)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
superusers.OTP.Duration = 60
|
||||
if err := app.Save(superusers); err != nil {
|
||||
t.Fatalf("Failed to mock superusers otp duration: %v", err)
|
||||
}
|
||||
|
||||
checkDeletedIds(app, t, []string{
|
||||
"user1_0",
|
||||
"superuser2_2",
|
||||
"superuser2_4",
|
||||
"superuser3_1",
|
||||
})
|
||||
})
|
||||
}
|
403
core/record_field_resolver.go
Normal file
403
core/record_field_resolver.go
Normal file
|
@ -0,0 +1,403 @@
|
|||
package core
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/pocketbase/dbx"
|
||||
"github.com/pocketbase/pocketbase/tools/search"
|
||||
"github.com/pocketbase/pocketbase/tools/security"
|
||||
"github.com/pocketbase/pocketbase/tools/types"
|
||||
"github.com/spf13/cast"
|
||||
)
|
||||
|
||||
// filter modifiers
|
||||
const (
|
||||
eachModifier string = "each"
|
||||
issetModifier string = "isset"
|
||||
lengthModifier string = "length"
|
||||
lowerModifier string = "lower"
|
||||
)
|
||||
|
||||
// ensure that `search.FieldResolver` interface is implemented
|
||||
var _ search.FieldResolver = (*RecordFieldResolver)(nil)
|
||||
|
||||
// RecordFieldResolver defines a custom search resolver struct for
|
||||
// managing Record model search fields.
|
||||
//
|
||||
// Usually used together with `search.Provider`.
|
||||
// Example:
|
||||
//
|
||||
// resolver := resolvers.NewRecordFieldResolver(
|
||||
// app,
|
||||
// myCollection,
|
||||
// &models.RequestInfo{...},
|
||||
// true,
|
||||
// )
|
||||
// provider := search.NewProvider(resolver)
|
||||
// ...
|
||||
type RecordFieldResolver struct {
|
||||
app App
|
||||
baseCollection *Collection
|
||||
requestInfo *RequestInfo
|
||||
staticRequestInfo map[string]any
|
||||
allowedFields []string
|
||||
joins []*join
|
||||
allowHiddenFields bool
|
||||
}
|
||||
|
||||
// AllowedFields returns a copy of the resolver's allowed fields.
|
||||
func (r *RecordFieldResolver) AllowedFields() []string {
|
||||
return slices.Clone(r.allowedFields)
|
||||
}
|
||||
|
||||
// SetAllowedFields replaces the resolver's allowed fields with the new ones.
|
||||
func (r *RecordFieldResolver) SetAllowedFields(newAllowedFields []string) {
|
||||
r.allowedFields = slices.Clone(newAllowedFields)
|
||||
}
|
||||
|
||||
// AllowHiddenFields returns whether the current resolver allows filtering hidden fields.
|
||||
func (r *RecordFieldResolver) AllowHiddenFields() bool {
|
||||
return r.allowHiddenFields
|
||||
}
|
||||
|
||||
// SetAllowHiddenFields enables or disables hidden fields filtering.
|
||||
func (r *RecordFieldResolver) SetAllowHiddenFields(allowHiddenFields bool) {
|
||||
r.allowHiddenFields = allowHiddenFields
|
||||
}
|
||||
|
||||
// NewRecordFieldResolver creates and initializes a new `RecordFieldResolver`.
|
||||
func NewRecordFieldResolver(
|
||||
app App,
|
||||
baseCollection *Collection,
|
||||
requestInfo *RequestInfo,
|
||||
allowHiddenFields bool,
|
||||
) *RecordFieldResolver {
|
||||
r := &RecordFieldResolver{
|
||||
app: app,
|
||||
baseCollection: baseCollection,
|
||||
requestInfo: requestInfo,
|
||||
allowHiddenFields: allowHiddenFields, // note: it is not based only on the requestInfo.auth since it could be used by a non-request internal method
|
||||
joins: []*join{},
|
||||
allowedFields: []string{
|
||||
`^\w+[\w\.\:]*$`,
|
||||
`^\@request\.context$`,
|
||||
`^\@request\.method$`,
|
||||
`^\@request\.auth\.[\w\.\:]*\w+$`,
|
||||
`^\@request\.body\.[\w\.\:]*\w+$`,
|
||||
`^\@request\.query\.[\w\.\:]*\w+$`,
|
||||
`^\@request\.headers\.[\w\.\:]*\w+$`,
|
||||
`^\@collection\.\w+(\:\w+)?\.[\w\.\:]*\w+$`,
|
||||
},
|
||||
}
|
||||
|
||||
r.staticRequestInfo = map[string]any{}
|
||||
if r.requestInfo != nil {
|
||||
r.staticRequestInfo["context"] = r.requestInfo.Context
|
||||
r.staticRequestInfo["method"] = r.requestInfo.Method
|
||||
r.staticRequestInfo["query"] = r.requestInfo.Query
|
||||
r.staticRequestInfo["headers"] = r.requestInfo.Headers
|
||||
r.staticRequestInfo["body"] = r.requestInfo.Body
|
||||
r.staticRequestInfo["auth"] = nil
|
||||
if r.requestInfo.Auth != nil {
|
||||
authClone := r.requestInfo.Auth.Clone()
|
||||
r.staticRequestInfo["auth"] = authClone.
|
||||
Unhide(authClone.Collection().Fields.FieldNames()...).
|
||||
IgnoreEmailVisibility(true).
|
||||
PublicExport()
|
||||
}
|
||||
}
|
||||
|
||||
return r
|
||||
}
|
||||
|
||||
// UpdateQuery implements `search.FieldResolver` interface.
|
||||
//
|
||||
// Conditionally updates the provided search query based on the
|
||||
// resolved fields (eg. dynamically joining relations).
|
||||
func (r *RecordFieldResolver) UpdateQuery(query *dbx.SelectQuery) error {
|
||||
if len(r.joins) > 0 {
|
||||
query.Distinct(true)
|
||||
|
||||
for _, join := range r.joins {
|
||||
query.LeftJoin(
|
||||
(join.tableName + " " + join.tableAlias),
|
||||
join.on,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Resolve implements `search.FieldResolver` interface.
|
||||
//
|
||||
// Example of some resolvable fieldName formats:
|
||||
//
|
||||
// id
|
||||
// someSelect.each
|
||||
// project.screen.status
|
||||
// screen.project_via_prototype.name
|
||||
// @request.context
|
||||
// @request.method
|
||||
// @request.query.filter
|
||||
// @request.headers.x_token
|
||||
// @request.auth.someRelation.name
|
||||
// @request.body.someRelation.name
|
||||
// @request.body.someField
|
||||
// @request.body.someSelect:each
|
||||
// @request.body.someField:isset
|
||||
// @collection.product.name
|
||||
func (r *RecordFieldResolver) Resolve(fieldName string) (*search.ResolverResult, error) {
|
||||
return parseAndRun(fieldName, r)
|
||||
}
|
||||
|
||||
func (r *RecordFieldResolver) resolveStaticRequestField(path ...string) (*search.ResolverResult, error) {
|
||||
if len(path) == 0 {
|
||||
return nil, errors.New("at least one path key should be provided")
|
||||
}
|
||||
|
||||
lastProp, modifier, err := splitModifier(path[len(path)-1])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
path[len(path)-1] = lastProp
|
||||
|
||||
// extract value
|
||||
resultVal, err := extractNestedVal(r.staticRequestInfo, path...)
|
||||
if err != nil {
|
||||
r.app.Logger().Debug("resolveStaticRequestField graceful fallback", "error", err.Error())
|
||||
}
|
||||
|
||||
if modifier == issetModifier {
|
||||
if err != nil {
|
||||
return &search.ResolverResult{Identifier: "FALSE"}, nil
|
||||
}
|
||||
return &search.ResolverResult{Identifier: "TRUE"}, nil
|
||||
}
|
||||
|
||||
// note: we are ignoring the error because requestInfo is dynamic
|
||||
// and some of the lookup keys may not be defined for the request
|
||||
|
||||
switch v := resultVal.(type) {
|
||||
case nil:
|
||||
return &search.ResolverResult{Identifier: "NULL"}, nil
|
||||
case string:
|
||||
// check if it is a number field and explicitly try to cast to
|
||||
// float in case of a numeric string value was used
|
||||
// (this usually the case when the data is from a multipart/form-data request)
|
||||
field := r.baseCollection.Fields.GetByName(path[len(path)-1])
|
||||
if field != nil && field.Type() == FieldTypeNumber {
|
||||
if nv, err := strconv.ParseFloat(v, 64); err == nil {
|
||||
resultVal = nv
|
||||
}
|
||||
}
|
||||
// otherwise - no further processing is needed...
|
||||
case bool, int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, float32, float64:
|
||||
// no further processing is needed...
|
||||
default:
|
||||
// non-plain value
|
||||
// try casting to string (in case for exampe fmt.Stringer is implemented)
|
||||
val, castErr := cast.ToStringE(v)
|
||||
|
||||
// if that doesn't work, try encoding it
|
||||
if castErr != nil {
|
||||
encoded, jsonErr := json.Marshal(v)
|
||||
if jsonErr == nil {
|
||||
val = string(encoded)
|
||||
}
|
||||
}
|
||||
|
||||
resultVal = val
|
||||
}
|
||||
|
||||
placeholder := "f" + security.PseudorandomString(8)
|
||||
|
||||
if modifier == lowerModifier {
|
||||
return &search.ResolverResult{
|
||||
Identifier: "LOWER({:" + placeholder + "})",
|
||||
Params: dbx.Params{placeholder: resultVal},
|
||||
}, nil
|
||||
}
|
||||
|
||||
return &search.ResolverResult{
|
||||
Identifier: "{:" + placeholder + "}",
|
||||
Params: dbx.Params{placeholder: resultVal},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (r *RecordFieldResolver) loadCollection(collectionNameOrId string) (*Collection, error) {
|
||||
if collectionNameOrId == r.baseCollection.Name || collectionNameOrId == r.baseCollection.Id {
|
||||
return r.baseCollection, nil
|
||||
}
|
||||
|
||||
return getCollectionByModelOrIdentifier(r.app, collectionNameOrId)
|
||||
}
|
||||
|
||||
func (r *RecordFieldResolver) registerJoin(tableName string, tableAlias string, on dbx.Expression) {
|
||||
join := &join{
|
||||
tableName: tableName,
|
||||
tableAlias: tableAlias,
|
||||
on: on,
|
||||
}
|
||||
|
||||
// replace existing join
|
||||
for i, j := range r.joins {
|
||||
if j.tableAlias == join.tableAlias {
|
||||
r.joins[i] = join
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// register new join
|
||||
r.joins = append(r.joins, join)
|
||||
}
|
||||
|
||||
type mapExtractor interface {
|
||||
AsMap() map[string]any
|
||||
}
|
||||
|
||||
func extractNestedVal(rawData any, keys ...string) (any, error) {
|
||||
if len(keys) == 0 {
|
||||
return nil, errors.New("at least one key should be provided")
|
||||
}
|
||||
|
||||
switch m := rawData.(type) {
|
||||
// maps
|
||||
case map[string]any:
|
||||
return mapVal(m, keys...)
|
||||
case map[string]string:
|
||||
return mapVal(m, keys...)
|
||||
case map[string]bool:
|
||||
return mapVal(m, keys...)
|
||||
case map[string]float32:
|
||||
return mapVal(m, keys...)
|
||||
case map[string]float64:
|
||||
return mapVal(m, keys...)
|
||||
case map[string]int:
|
||||
return mapVal(m, keys...)
|
||||
case map[string]int8:
|
||||
return mapVal(m, keys...)
|
||||
case map[string]int16:
|
||||
return mapVal(m, keys...)
|
||||
case map[string]int32:
|
||||
return mapVal(m, keys...)
|
||||
case map[string]int64:
|
||||
return mapVal(m, keys...)
|
||||
case map[string]uint:
|
||||
return mapVal(m, keys...)
|
||||
case map[string]uint8:
|
||||
return mapVal(m, keys...)
|
||||
case map[string]uint16:
|
||||
return mapVal(m, keys...)
|
||||
case map[string]uint32:
|
||||
return mapVal(m, keys...)
|
||||
case map[string]uint64:
|
||||
return mapVal(m, keys...)
|
||||
case mapExtractor:
|
||||
return mapVal(m.AsMap(), keys...)
|
||||
case types.JSONRaw:
|
||||
var raw any
|
||||
err := json.Unmarshal(m, &raw)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal raw JSON in order extract nested value from: %w", err)
|
||||
}
|
||||
return extractNestedVal(raw, keys...)
|
||||
|
||||
// slices
|
||||
case []string:
|
||||
return arrVal(m, keys...)
|
||||
case []bool:
|
||||
return arrVal(m, keys...)
|
||||
case []float32:
|
||||
return arrVal(m, keys...)
|
||||
case []float64:
|
||||
return arrVal(m, keys...)
|
||||
case []int:
|
||||
return arrVal(m, keys...)
|
||||
case []int8:
|
||||
return arrVal(m, keys...)
|
||||
case []int16:
|
||||
return arrVal(m, keys...)
|
||||
case []int32:
|
||||
return arrVal(m, keys...)
|
||||
case []int64:
|
||||
return arrVal(m, keys...)
|
||||
case []uint:
|
||||
return arrVal(m, keys...)
|
||||
case []uint8:
|
||||
return arrVal(m, keys...)
|
||||
case []uint16:
|
||||
return arrVal(m, keys...)
|
||||
case []uint32:
|
||||
return arrVal(m, keys...)
|
||||
case []uint64:
|
||||
return arrVal(m, keys...)
|
||||
case []mapExtractor:
|
||||
extracted := make([]any, len(m))
|
||||
for i, v := range m {
|
||||
extracted[i] = v.AsMap()
|
||||
}
|
||||
return arrVal(extracted, keys...)
|
||||
case []any:
|
||||
return arrVal(m, keys...)
|
||||
case []types.JSONRaw:
|
||||
return arrVal(m, keys...)
|
||||
default:
|
||||
return nil, fmt.Errorf("expected map or array, got %#v", rawData)
|
||||
}
|
||||
}
|
||||
|
||||
func mapVal[T any](m map[string]T, keys ...string) (any, error) {
|
||||
result, ok := m[keys[0]]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid key path - missing key %q", keys[0])
|
||||
}
|
||||
|
||||
// end key reached
|
||||
if len(keys) == 1 {
|
||||
return result, nil
|
||||
}
|
||||
|
||||
return extractNestedVal(result, keys[1:]...)
|
||||
}
|
||||
|
||||
func arrVal[T any](m []T, keys ...string) (any, error) {
|
||||
idx, err := strconv.Atoi(keys[0])
|
||||
if err != nil || idx < 0 || idx >= len(m) {
|
||||
return nil, fmt.Errorf("invalid key path - invalid or missing array index %q", keys[0])
|
||||
}
|
||||
|
||||
result := m[idx]
|
||||
|
||||
// end key reached
|
||||
if len(keys) == 1 {
|
||||
return result, nil
|
||||
}
|
||||
|
||||
return extractNestedVal(result, keys[1:]...)
|
||||
}
|
||||
|
||||
func splitModifier(combined string) (string, string, error) {
|
||||
parts := strings.Split(combined, ":")
|
||||
|
||||
if len(parts) != 2 {
|
||||
return combined, "", nil
|
||||
}
|
||||
|
||||
// validate modifier
|
||||
switch parts[1] {
|
||||
case issetModifier,
|
||||
eachModifier,
|
||||
lengthModifier,
|
||||
lowerModifier:
|
||||
return parts[0], parts[1], nil
|
||||
}
|
||||
|
||||
return "", "", fmt.Errorf("unknown modifier in %q", combined)
|
||||
}
|
70
core/record_field_resolver_multi_match.go
Normal file
70
core/record_field_resolver_multi_match.go
Normal file
|
@ -0,0 +1,70 @@
|
|||
package core
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/pocketbase/dbx"
|
||||
)
|
||||
|
||||
var _ dbx.Expression = (*multiMatchSubquery)(nil)
|
||||
|
||||
// join defines the specification for a single SQL JOIN clause.
|
||||
type join struct {
|
||||
tableName string
|
||||
tableAlias string
|
||||
on dbx.Expression
|
||||
}
|
||||
|
||||
// multiMatchSubquery defines a record multi-match subquery expression.
|
||||
type multiMatchSubquery struct {
|
||||
baseTableAlias string
|
||||
fromTableName string
|
||||
fromTableAlias string
|
||||
valueIdentifier string
|
||||
joins []*join
|
||||
params dbx.Params
|
||||
}
|
||||
|
||||
// Build converts the expression into a SQL fragment.
|
||||
//
|
||||
// Implements [dbx.Expression] interface.
|
||||
func (m *multiMatchSubquery) Build(db *dbx.DB, params dbx.Params) string {
|
||||
if m.baseTableAlias == "" || m.fromTableName == "" || m.fromTableAlias == "" {
|
||||
return "0=1"
|
||||
}
|
||||
|
||||
if params == nil {
|
||||
params = m.params
|
||||
} else {
|
||||
// merge by updating the parent params
|
||||
for k, v := range m.params {
|
||||
params[k] = v
|
||||
}
|
||||
}
|
||||
|
||||
var mergedJoins strings.Builder
|
||||
for i, j := range m.joins {
|
||||
if i > 0 {
|
||||
mergedJoins.WriteString(" ")
|
||||
}
|
||||
mergedJoins.WriteString("LEFT JOIN ")
|
||||
mergedJoins.WriteString(db.QuoteTableName(j.tableName))
|
||||
mergedJoins.WriteString(" ")
|
||||
mergedJoins.WriteString(db.QuoteTableName(j.tableAlias))
|
||||
if j.on != nil {
|
||||
mergedJoins.WriteString(" ON ")
|
||||
mergedJoins.WriteString(j.on.Build(db, params))
|
||||
}
|
||||
}
|
||||
|
||||
return fmt.Sprintf(
|
||||
`SELECT %s as [[multiMatchValue]] FROM %s %s %s WHERE %s = %s`,
|
||||
db.QuoteColumnName(m.valueIdentifier),
|
||||
db.QuoteTableName(m.fromTableName),
|
||||
db.QuoteTableName(m.fromTableAlias),
|
||||
mergedJoins.String(),
|
||||
db.QuoteColumnName(m.fromTableAlias+".id"),
|
||||
db.QuoteColumnName(m.baseTableAlias+".id"),
|
||||
)
|
||||
}
|
792
core/record_field_resolver_runner.go
Normal file
792
core/record_field_resolver_runner.go
Normal file
|
@ -0,0 +1,792 @@
|
|||
package core
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/pocketbase/dbx"
|
||||
"github.com/pocketbase/pocketbase/tools/dbutils"
|
||||
"github.com/pocketbase/pocketbase/tools/inflector"
|
||||
"github.com/pocketbase/pocketbase/tools/list"
|
||||
"github.com/pocketbase/pocketbase/tools/search"
|
||||
"github.com/pocketbase/pocketbase/tools/security"
|
||||
"github.com/spf13/cast"
|
||||
)
|
||||
|
||||
// maxNestedRels defines the max allowed nested relations depth.
|
||||
const maxNestedRels = 6
|
||||
|
||||
// list of auth filter fields that don't require join with the auth
|
||||
// collection or any other extra checks to be resolved.
|
||||
var plainRequestAuthFields = map[string]struct{}{
|
||||
"@request.auth." + FieldNameId: {},
|
||||
"@request.auth." + FieldNameCollectionId: {},
|
||||
"@request.auth." + FieldNameCollectionName: {},
|
||||
"@request.auth." + FieldNameEmail: {},
|
||||
"@request.auth." + FieldNameEmailVisibility: {},
|
||||
"@request.auth." + FieldNameVerified: {},
|
||||
}
|
||||
|
||||
// parseAndRun starts a new one-off RecordFieldResolver.Resolve execution.
|
||||
func parseAndRun(fieldName string, resolver *RecordFieldResolver) (*search.ResolverResult, error) {
|
||||
r := &runner{
|
||||
fieldName: fieldName,
|
||||
resolver: resolver,
|
||||
}
|
||||
|
||||
return r.run()
|
||||
}
|
||||
|
||||
type runner struct {
|
||||
used bool // indicates whether the runner was already executed
|
||||
resolver *RecordFieldResolver // resolver is the shared expression fields resolver
|
||||
fieldName string // the name of the single field expression the runner is responsible for
|
||||
|
||||
// shared processing state
|
||||
// ---------------------------------------------------------------
|
||||
activeProps []string // holds the active props that remains to be processed
|
||||
activeCollectionName string // the last used collection name
|
||||
activeTableAlias string // the last used table alias
|
||||
allowHiddenFields bool // indicates whether hidden fields (eg. email) should be allowed without extra conditions
|
||||
nullifyMisingField bool // indicating whether to return null on missing field or return an error
|
||||
withMultiMatch bool // indicates whether to attach a multiMatchSubquery condition to the ResolverResult
|
||||
multiMatchActiveTableAlias string // the last used multi-match table alias
|
||||
multiMatch *multiMatchSubquery // the multi-match subquery expression generated from the fieldName
|
||||
}
|
||||
|
||||
func (r *runner) run() (*search.ResolverResult, error) {
|
||||
if r.used {
|
||||
return nil, errors.New("the runner was already used")
|
||||
}
|
||||
|
||||
if len(r.resolver.allowedFields) > 0 && !list.ExistInSliceWithRegex(r.fieldName, r.resolver.allowedFields) {
|
||||
return nil, fmt.Errorf("failed to resolve field %q", r.fieldName)
|
||||
}
|
||||
|
||||
defer func() {
|
||||
r.used = true
|
||||
}()
|
||||
|
||||
r.prepare()
|
||||
|
||||
// check for @collection field (aka. non-relational join)
|
||||
// must be in the format "@collection.COLLECTION_NAME.FIELD[.FIELD2....]"
|
||||
if r.activeProps[0] == "@collection" {
|
||||
return r.processCollectionField()
|
||||
}
|
||||
|
||||
if r.activeProps[0] == "@request" {
|
||||
if r.resolver.requestInfo == nil {
|
||||
return &search.ResolverResult{Identifier: "NULL"}, nil
|
||||
}
|
||||
|
||||
if strings.HasPrefix(r.fieldName, "@request.auth.") {
|
||||
return r.processRequestAuthField()
|
||||
}
|
||||
|
||||
if strings.HasPrefix(r.fieldName, "@request.body.") && len(r.activeProps) > 2 {
|
||||
name, modifier, err := splitModifier(r.activeProps[2])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
bodyField := r.resolver.baseCollection.Fields.GetByName(name)
|
||||
if bodyField == nil {
|
||||
return r.resolver.resolveStaticRequestField(r.activeProps[1:]...)
|
||||
}
|
||||
|
||||
// check for body relation field
|
||||
if bodyField.Type() == FieldTypeRelation && len(r.activeProps) > 3 {
|
||||
return r.processRequestInfoRelationField(bodyField)
|
||||
}
|
||||
|
||||
// check for body arrayble fields ":each" modifier
|
||||
if modifier == eachModifier && len(r.activeProps) == 3 {
|
||||
return r.processRequestInfoEachModifier(bodyField)
|
||||
}
|
||||
|
||||
// check for body arrayble fields ":length" modifier
|
||||
if modifier == lengthModifier && len(r.activeProps) == 3 {
|
||||
return r.processRequestInfoLengthModifier(bodyField)
|
||||
}
|
||||
|
||||
// check for body arrayble fields ":lower" modifier
|
||||
if modifier == lowerModifier && len(r.activeProps) == 3 {
|
||||
return r.processRequestInfoLowerModifier(bodyField)
|
||||
}
|
||||
}
|
||||
|
||||
// some other @request.* static field
|
||||
return r.resolver.resolveStaticRequestField(r.activeProps[1:]...)
|
||||
}
|
||||
|
||||
// regular field
|
||||
return r.processActiveProps()
|
||||
}
|
||||
|
||||
func (r *runner) prepare() {
|
||||
r.activeProps = strings.Split(r.fieldName, ".")
|
||||
|
||||
r.activeCollectionName = r.resolver.baseCollection.Name
|
||||
r.activeTableAlias = inflector.Columnify(r.activeCollectionName)
|
||||
|
||||
r.allowHiddenFields = r.resolver.allowHiddenFields
|
||||
// always allow hidden fields since the @.* filter is a system one
|
||||
if r.activeProps[0] == "@collection" || r.activeProps[0] == "@request" {
|
||||
r.allowHiddenFields = true
|
||||
}
|
||||
|
||||
// enable the ignore flag for missing @request.* fields for backward
|
||||
// compatibility and consistency with all @request.* filter fields and types
|
||||
r.nullifyMisingField = r.activeProps[0] == "@request"
|
||||
|
||||
// prepare a multi-match subquery
|
||||
r.multiMatch = &multiMatchSubquery{
|
||||
baseTableAlias: r.activeTableAlias,
|
||||
params: dbx.Params{},
|
||||
}
|
||||
r.multiMatch.fromTableName = inflector.Columnify(r.activeCollectionName)
|
||||
r.multiMatch.fromTableAlias = "__mm_" + r.activeTableAlias
|
||||
r.multiMatchActiveTableAlias = r.multiMatch.fromTableAlias
|
||||
r.withMultiMatch = false
|
||||
}
|
||||
|
||||
func (r *runner) processCollectionField() (*search.ResolverResult, error) {
|
||||
if len(r.activeProps) < 3 {
|
||||
return nil, fmt.Errorf("invalid @collection field path in %q", r.fieldName)
|
||||
}
|
||||
|
||||
// nameOrId or nameOrId:alias
|
||||
collectionParts := strings.SplitN(r.activeProps[1], ":", 2)
|
||||
|
||||
collection, err := r.resolver.loadCollection(collectionParts[0])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to load collection %q from field path %q", r.activeProps[1], r.fieldName)
|
||||
}
|
||||
|
||||
r.activeCollectionName = collection.Name
|
||||
|
||||
if len(collectionParts) == 2 && collectionParts[1] != "" {
|
||||
r.activeTableAlias = inflector.Columnify("__collection_alias_" + collectionParts[1])
|
||||
} else {
|
||||
r.activeTableAlias = inflector.Columnify("__collection_" + r.activeCollectionName)
|
||||
}
|
||||
|
||||
r.withMultiMatch = true
|
||||
|
||||
// join the collection to the main query
|
||||
r.resolver.registerJoin(inflector.Columnify(collection.Name), r.activeTableAlias, nil)
|
||||
|
||||
// join the collection to the multi-match subquery
|
||||
r.multiMatchActiveTableAlias = "__mm" + r.activeTableAlias
|
||||
r.multiMatch.joins = append(r.multiMatch.joins, &join{
|
||||
tableName: inflector.Columnify(collection.Name),
|
||||
tableAlias: r.multiMatchActiveTableAlias,
|
||||
})
|
||||
|
||||
// leave only the collection fields
|
||||
// aka. @collection.someCollection.fieldA.fieldB -> fieldA.fieldB
|
||||
r.activeProps = r.activeProps[2:]
|
||||
|
||||
return r.processActiveProps()
|
||||
}
|
||||
|
||||
func (r *runner) processRequestAuthField() (*search.ResolverResult, error) {
|
||||
if r.resolver.requestInfo == nil || r.resolver.requestInfo.Auth == nil || r.resolver.requestInfo.Auth.Collection() == nil {
|
||||
return &search.ResolverResult{Identifier: "NULL"}, nil
|
||||
}
|
||||
|
||||
// plain auth field
|
||||
// ---
|
||||
if _, ok := plainRequestAuthFields[r.fieldName]; ok {
|
||||
return r.resolver.resolveStaticRequestField(r.activeProps[1:]...)
|
||||
}
|
||||
|
||||
// resolve the auth collection field
|
||||
// ---
|
||||
collection := r.resolver.requestInfo.Auth.Collection()
|
||||
|
||||
r.activeCollectionName = collection.Name
|
||||
r.activeTableAlias = "__auth_" + inflector.Columnify(r.activeCollectionName)
|
||||
|
||||
// join the auth collection to the main query
|
||||
r.resolver.registerJoin(
|
||||
inflector.Columnify(r.activeCollectionName),
|
||||
r.activeTableAlias,
|
||||
dbx.HashExp{
|
||||
// aka. __auth_users.id = :userId
|
||||
(r.activeTableAlias + ".id"): r.resolver.requestInfo.Auth.Id,
|
||||
},
|
||||
)
|
||||
|
||||
// join the auth collection to the multi-match subquery
|
||||
r.multiMatchActiveTableAlias = "__mm_" + r.activeTableAlias
|
||||
r.multiMatch.joins = append(
|
||||
r.multiMatch.joins,
|
||||
&join{
|
||||
tableName: inflector.Columnify(r.activeCollectionName),
|
||||
tableAlias: r.multiMatchActiveTableAlias,
|
||||
on: dbx.HashExp{
|
||||
(r.multiMatchActiveTableAlias + ".id"): r.resolver.requestInfo.Auth.Id,
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
// leave only the auth relation fields
|
||||
// aka. @request.auth.fieldA.fieldB -> fieldA.fieldB
|
||||
r.activeProps = r.activeProps[2:]
|
||||
|
||||
return r.processActiveProps()
|
||||
}
|
||||
|
||||
// note: nil value is returned as empty slice
|
||||
func toSlice(value any) []any {
|
||||
if value == nil {
|
||||
return []any{}
|
||||
}
|
||||
|
||||
rv := reflect.ValueOf(value)
|
||||
|
||||
kind := rv.Kind()
|
||||
if kind != reflect.Slice && kind != reflect.Array {
|
||||
return []any{value}
|
||||
}
|
||||
|
||||
rvLen := rv.Len()
|
||||
|
||||
result := make([]interface{}, rvLen)
|
||||
|
||||
for i := 0; i < rvLen; i++ {
|
||||
result[i] = rv.Index(i).Interface()
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
func (r *runner) processRequestInfoLowerModifier(bodyField Field) (*search.ResolverResult, error) {
|
||||
rawValue := cast.ToString(r.resolver.requestInfo.Body[bodyField.GetName()])
|
||||
|
||||
placeholder := "infoLower" + bodyField.GetName() + security.PseudorandomString(6)
|
||||
|
||||
result := &search.ResolverResult{
|
||||
Identifier: "LOWER({:" + placeholder + "})",
|
||||
Params: dbx.Params{placeholder: rawValue},
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (r *runner) processRequestInfoLengthModifier(bodyField Field) (*search.ResolverResult, error) {
|
||||
if _, ok := bodyField.(MultiValuer); !ok {
|
||||
return nil, fmt.Errorf("field %q doesn't support multivalue operations", bodyField.GetName())
|
||||
}
|
||||
|
||||
bodyItems := toSlice(r.resolver.requestInfo.Body[bodyField.GetName()])
|
||||
|
||||
result := &search.ResolverResult{
|
||||
Identifier: strconv.Itoa(len(bodyItems)),
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (r *runner) processRequestInfoEachModifier(bodyField Field) (*search.ResolverResult, error) {
|
||||
multiValuer, ok := bodyField.(MultiValuer)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("field %q doesn't support multivalue operations", bodyField.GetName())
|
||||
}
|
||||
|
||||
bodyItems := toSlice(r.resolver.requestInfo.Body[bodyField.GetName()])
|
||||
bodyItemsRaw, err := json.Marshal(bodyItems)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("cannot serialize the data for field %q", r.activeProps[2])
|
||||
}
|
||||
|
||||
placeholder := "dataEach" + security.PseudorandomString(6)
|
||||
cleanFieldName := inflector.Columnify(bodyField.GetName())
|
||||
jeTable := fmt.Sprintf("json_each({:%s})", placeholder)
|
||||
jeAlias := "__dataEach_" + cleanFieldName + "_je"
|
||||
r.resolver.registerJoin(jeTable, jeAlias, nil)
|
||||
|
||||
result := &search.ResolverResult{
|
||||
Identifier: fmt.Sprintf("[[%s.value]]", jeAlias),
|
||||
Params: dbx.Params{placeholder: bodyItemsRaw},
|
||||
}
|
||||
|
||||
if multiValuer.IsMultiple() {
|
||||
r.withMultiMatch = true
|
||||
}
|
||||
|
||||
if r.withMultiMatch {
|
||||
placeholder2 := "mm" + placeholder
|
||||
jeTable2 := fmt.Sprintf("json_each({:%s})", placeholder2)
|
||||
jeAlias2 := "__mm" + jeAlias
|
||||
|
||||
r.multiMatch.joins = append(r.multiMatch.joins, &join{
|
||||
tableName: jeTable2,
|
||||
tableAlias: jeAlias2,
|
||||
})
|
||||
r.multiMatch.params[placeholder2] = bodyItemsRaw
|
||||
r.multiMatch.valueIdentifier = fmt.Sprintf("[[%s.value]]", jeAlias2)
|
||||
|
||||
result.MultiMatchSubQuery = r.multiMatch
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (r *runner) processRequestInfoRelationField(bodyField Field) (*search.ResolverResult, error) {
|
||||
relField, ok := bodyField.(*RelationField)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("failed to initialize data relation field %q", bodyField.GetName())
|
||||
}
|
||||
|
||||
dataRelCollection, err := r.resolver.loadCollection(relField.CollectionId)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to load collection %q from data field %q", relField.CollectionId, relField.Name)
|
||||
}
|
||||
|
||||
var dataRelIds []string
|
||||
if r.resolver.requestInfo != nil && len(r.resolver.requestInfo.Body) != 0 {
|
||||
dataRelIds = list.ToUniqueStringSlice(r.resolver.requestInfo.Body[relField.Name])
|
||||
}
|
||||
if len(dataRelIds) == 0 {
|
||||
return &search.ResolverResult{Identifier: "NULL"}, nil
|
||||
}
|
||||
|
||||
r.activeCollectionName = dataRelCollection.Name
|
||||
r.activeTableAlias = inflector.Columnify("__data_" + dataRelCollection.Name + "_" + relField.Name)
|
||||
|
||||
// join the data rel collection to the main collection
|
||||
r.resolver.registerJoin(
|
||||
r.activeCollectionName,
|
||||
r.activeTableAlias,
|
||||
dbx.In(
|
||||
fmt.Sprintf("[[%s.id]]", r.activeTableAlias),
|
||||
list.ToInterfaceSlice(dataRelIds)...,
|
||||
),
|
||||
)
|
||||
|
||||
if relField.IsMultiple() {
|
||||
r.withMultiMatch = true
|
||||
}
|
||||
|
||||
// join the data rel collection to the multi-match subquery
|
||||
r.multiMatchActiveTableAlias = inflector.Columnify("__data_mm_" + dataRelCollection.Name + "_" + relField.Name)
|
||||
r.multiMatch.joins = append(
|
||||
r.multiMatch.joins,
|
||||
&join{
|
||||
tableName: r.activeCollectionName,
|
||||
tableAlias: r.multiMatchActiveTableAlias,
|
||||
on: dbx.In(
|
||||
fmt.Sprintf("[[%s.id]]", r.multiMatchActiveTableAlias),
|
||||
list.ToInterfaceSlice(dataRelIds)...,
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
// leave only the data relation fields
|
||||
// aka. @request.body.someRel.fieldA.fieldB -> fieldA.fieldB
|
||||
r.activeProps = r.activeProps[3:]
|
||||
|
||||
return r.processActiveProps()
|
||||
}
|
||||
|
||||
var viaRegex = regexp.MustCompile(`^(\w+)_via_(\w+)$`)
|
||||
|
||||
func (r *runner) processActiveProps() (*search.ResolverResult, error) {
|
||||
totalProps := len(r.activeProps)
|
||||
|
||||
for i, prop := range r.activeProps {
|
||||
collection, err := r.resolver.loadCollection(r.activeCollectionName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to resolve field %q", prop)
|
||||
}
|
||||
|
||||
// last prop
|
||||
if i == totalProps-1 {
|
||||
return r.processLastProp(collection, prop)
|
||||
}
|
||||
|
||||
field := collection.Fields.GetByName(prop)
|
||||
|
||||
if field != nil && field.GetHidden() && !r.allowHiddenFields {
|
||||
return nil, fmt.Errorf("non-filterable field %q", prop)
|
||||
}
|
||||
|
||||
// json or geoPoint field -> treat the rest of the props as json path
|
||||
// @todo consider converting to "JSONExtractable" interface with optional extra validation for the remaining props?
|
||||
if field != nil && (field.Type() == FieldTypeJSON || field.Type() == FieldTypeGeoPoint) {
|
||||
var jsonPath strings.Builder
|
||||
for j, p := range r.activeProps[i+1:] {
|
||||
if _, err := strconv.Atoi(p); err == nil {
|
||||
jsonPath.WriteString("[")
|
||||
jsonPath.WriteString(inflector.Columnify(p))
|
||||
jsonPath.WriteString("]")
|
||||
} else {
|
||||
if j > 0 {
|
||||
jsonPath.WriteString(".")
|
||||
}
|
||||
jsonPath.WriteString(inflector.Columnify(p))
|
||||
}
|
||||
}
|
||||
jsonPathStr := jsonPath.String()
|
||||
|
||||
result := &search.ResolverResult{
|
||||
NoCoalesce: true,
|
||||
Identifier: dbutils.JSONExtract(r.activeTableAlias+"."+inflector.Columnify(prop), jsonPathStr),
|
||||
}
|
||||
|
||||
if r.withMultiMatch {
|
||||
r.multiMatch.valueIdentifier = dbutils.JSONExtract(r.multiMatchActiveTableAlias+"."+inflector.Columnify(prop), jsonPathStr)
|
||||
result.MultiMatchSubQuery = r.multiMatch
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
if i >= maxNestedRels {
|
||||
return nil, fmt.Errorf("max nested relations reached for field %q", prop)
|
||||
}
|
||||
|
||||
// check for back relation (eg. yourCollection_via_yourRelField)
|
||||
// -----------------------------------------------------------
|
||||
if field == nil {
|
||||
parts := viaRegex.FindStringSubmatch(prop)
|
||||
if len(parts) != 3 {
|
||||
if r.nullifyMisingField {
|
||||
return &search.ResolverResult{Identifier: "NULL"}, nil
|
||||
}
|
||||
return nil, fmt.Errorf("failed to resolve field %q", prop)
|
||||
}
|
||||
|
||||
backCollection, err := r.resolver.loadCollection(parts[1])
|
||||
if err != nil {
|
||||
if r.nullifyMisingField {
|
||||
return &search.ResolverResult{Identifier: "NULL"}, nil
|
||||
}
|
||||
return nil, fmt.Errorf("failed to load back relation field %q collection", prop)
|
||||
}
|
||||
|
||||
backField := backCollection.Fields.GetByName(parts[2])
|
||||
if backField == nil {
|
||||
if r.nullifyMisingField {
|
||||
return &search.ResolverResult{Identifier: "NULL"}, nil
|
||||
}
|
||||
return nil, fmt.Errorf("missing back relation field %q", parts[2])
|
||||
}
|
||||
|
||||
if backField.Type() != FieldTypeRelation {
|
||||
if r.nullifyMisingField {
|
||||
return &search.ResolverResult{Identifier: "NULL"}, nil
|
||||
}
|
||||
return nil, fmt.Errorf("invalid back relation field %q", parts[2])
|
||||
}
|
||||
|
||||
if backField.GetHidden() && !r.allowHiddenFields {
|
||||
return nil, fmt.Errorf("non-filterable back relation field %q", backField.GetName())
|
||||
}
|
||||
|
||||
backRelField, ok := backField.(*RelationField)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("failed to initialize back relation field %q", backField.GetName())
|
||||
}
|
||||
if backRelField.CollectionId != collection.Id {
|
||||
// https://github.com/pocketbase/pocketbase/discussions/6590#discussioncomment-12496581
|
||||
if r.nullifyMisingField {
|
||||
return &search.ResolverResult{Identifier: "NULL"}, nil
|
||||
}
|
||||
return nil, fmt.Errorf("invalid collection reference of a back relation field %q", backField.GetName())
|
||||
}
|
||||
|
||||
// join the back relation to the main query
|
||||
// ---
|
||||
cleanProp := inflector.Columnify(prop)
|
||||
cleanBackFieldName := inflector.Columnify(backRelField.Name)
|
||||
newTableAlias := r.activeTableAlias + "_" + cleanProp
|
||||
newCollectionName := inflector.Columnify(backCollection.Name)
|
||||
|
||||
isBackRelMultiple := backRelField.IsMultiple()
|
||||
if !isBackRelMultiple {
|
||||
// additionally check if the rel field has a single column unique index
|
||||
_, hasUniqueIndex := dbutils.FindSingleColumnUniqueIndex(backCollection.Indexes, backRelField.Name)
|
||||
isBackRelMultiple = !hasUniqueIndex
|
||||
}
|
||||
|
||||
if !isBackRelMultiple {
|
||||
r.resolver.registerJoin(
|
||||
newCollectionName,
|
||||
newTableAlias,
|
||||
dbx.NewExp(fmt.Sprintf("[[%s.%s]] = [[%s.id]]", newTableAlias, cleanBackFieldName, r.activeTableAlias)),
|
||||
)
|
||||
} else {
|
||||
jeAlias := r.activeTableAlias + "_" + cleanProp + "_je"
|
||||
r.resolver.registerJoin(
|
||||
newCollectionName,
|
||||
newTableAlias,
|
||||
dbx.NewExp(fmt.Sprintf(
|
||||
"[[%s.id]] IN (SELECT [[%s.value]] FROM %s {{%s}})",
|
||||
r.activeTableAlias,
|
||||
jeAlias,
|
||||
dbutils.JSONEach(newTableAlias+"."+cleanBackFieldName),
|
||||
jeAlias,
|
||||
)),
|
||||
)
|
||||
}
|
||||
|
||||
r.activeCollectionName = newCollectionName
|
||||
r.activeTableAlias = newTableAlias
|
||||
// ---
|
||||
|
||||
// join the back relation to the multi-match subquery
|
||||
// ---
|
||||
if isBackRelMultiple {
|
||||
r.withMultiMatch = true // enable multimatch if not already
|
||||
}
|
||||
|
||||
newTableAlias2 := r.multiMatchActiveTableAlias + "_" + cleanProp
|
||||
|
||||
if !isBackRelMultiple {
|
||||
r.multiMatch.joins = append(
|
||||
r.multiMatch.joins,
|
||||
&join{
|
||||
tableName: newCollectionName,
|
||||
tableAlias: newTableAlias2,
|
||||
on: dbx.NewExp(fmt.Sprintf("[[%s.%s]] = [[%s.id]]", newTableAlias2, cleanBackFieldName, r.multiMatchActiveTableAlias)),
|
||||
},
|
||||
)
|
||||
} else {
|
||||
jeAlias2 := r.multiMatchActiveTableAlias + "_" + cleanProp + "_je"
|
||||
r.multiMatch.joins = append(
|
||||
r.multiMatch.joins,
|
||||
&join{
|
||||
tableName: newCollectionName,
|
||||
tableAlias: newTableAlias2,
|
||||
on: dbx.NewExp(fmt.Sprintf(
|
||||
"[[%s.id]] IN (SELECT [[%s.value]] FROM %s {{%s}})",
|
||||
r.multiMatchActiveTableAlias,
|
||||
jeAlias2,
|
||||
dbutils.JSONEach(newTableAlias2+"."+cleanBackFieldName),
|
||||
jeAlias2,
|
||||
)),
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
r.multiMatchActiveTableAlias = newTableAlias2
|
||||
// ---
|
||||
|
||||
continue
|
||||
}
|
||||
// -----------------------------------------------------------
|
||||
|
||||
// check for direct relation
|
||||
if field.Type() != FieldTypeRelation {
|
||||
return nil, fmt.Errorf("field %q is not a valid relation", prop)
|
||||
}
|
||||
|
||||
// join the relation to the main query
|
||||
// ---
|
||||
relField, ok := field.(*RelationField)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("failed to initialize relation field %q", prop)
|
||||
}
|
||||
|
||||
relCollection, relErr := r.resolver.loadCollection(relField.CollectionId)
|
||||
if relErr != nil {
|
||||
return nil, fmt.Errorf("failed to load field %q collection", prop)
|
||||
}
|
||||
|
||||
// "id" lookups optimization for single relations to avoid unnecessary joins,
|
||||
// aka. "user.id" and "user" should produce the same query identifier
|
||||
if !relField.IsMultiple() &&
|
||||
// the penultimate prop is "id"
|
||||
i == totalProps-2 && r.activeProps[i+1] == FieldNameId {
|
||||
return r.processLastProp(collection, relField.Name)
|
||||
}
|
||||
|
||||
cleanFieldName := inflector.Columnify(relField.Name)
|
||||
prefixedFieldName := r.activeTableAlias + "." + cleanFieldName
|
||||
newTableAlias := r.activeTableAlias + "_" + cleanFieldName
|
||||
newCollectionName := relCollection.Name
|
||||
|
||||
if !relField.IsMultiple() {
|
||||
r.resolver.registerJoin(
|
||||
inflector.Columnify(newCollectionName),
|
||||
newTableAlias,
|
||||
dbx.NewExp(fmt.Sprintf("[[%s.id]] = [[%s]]", newTableAlias, prefixedFieldName)),
|
||||
)
|
||||
} else {
|
||||
jeAlias := r.activeTableAlias + "_" + cleanFieldName + "_je"
|
||||
r.resolver.registerJoin(dbutils.JSONEach(prefixedFieldName), jeAlias, nil)
|
||||
r.resolver.registerJoin(
|
||||
inflector.Columnify(newCollectionName),
|
||||
newTableAlias,
|
||||
dbx.NewExp(fmt.Sprintf("[[%s.id]] = [[%s.value]]", newTableAlias, jeAlias)),
|
||||
)
|
||||
}
|
||||
|
||||
r.activeCollectionName = newCollectionName
|
||||
r.activeTableAlias = newTableAlias
|
||||
// ---
|
||||
|
||||
// join the relation to the multi-match subquery
|
||||
// ---
|
||||
if relField.IsMultiple() {
|
||||
r.withMultiMatch = true // enable multimatch if not already
|
||||
}
|
||||
|
||||
newTableAlias2 := r.multiMatchActiveTableAlias + "_" + cleanFieldName
|
||||
prefixedFieldName2 := r.multiMatchActiveTableAlias + "." + cleanFieldName
|
||||
|
||||
if !relField.IsMultiple() {
|
||||
r.multiMatch.joins = append(
|
||||
r.multiMatch.joins,
|
||||
&join{
|
||||
tableName: inflector.Columnify(newCollectionName),
|
||||
tableAlias: newTableAlias2,
|
||||
on: dbx.NewExp(fmt.Sprintf("[[%s.id]] = [[%s]]", newTableAlias2, prefixedFieldName2)),
|
||||
},
|
||||
)
|
||||
} else {
|
||||
jeAlias2 := r.multiMatchActiveTableAlias + "_" + cleanFieldName + "_je"
|
||||
r.multiMatch.joins = append(
|
||||
r.multiMatch.joins,
|
||||
&join{
|
||||
tableName: dbutils.JSONEach(prefixedFieldName2),
|
||||
tableAlias: jeAlias2,
|
||||
},
|
||||
&join{
|
||||
tableName: inflector.Columnify(newCollectionName),
|
||||
tableAlias: newTableAlias2,
|
||||
on: dbx.NewExp(fmt.Sprintf("[[%s.id]] = [[%s.value]]", newTableAlias2, jeAlias2)),
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
r.multiMatchActiveTableAlias = newTableAlias2
|
||||
// ---
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("failed to resolve field %q", r.fieldName)
|
||||
}
|
||||
|
||||
func (r *runner) processLastProp(collection *Collection, prop string) (*search.ResolverResult, error) {
|
||||
name, modifier, err := splitModifier(prop)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
field := collection.Fields.GetByName(name)
|
||||
if field == nil {
|
||||
if r.nullifyMisingField {
|
||||
return &search.ResolverResult{Identifier: "NULL"}, nil
|
||||
}
|
||||
return nil, fmt.Errorf("unknown field %q", name)
|
||||
}
|
||||
|
||||
if field.GetHidden() && !r.allowHiddenFields {
|
||||
return nil, fmt.Errorf("non-filterable field %q", name)
|
||||
}
|
||||
|
||||
multvaluer, isMultivaluer := field.(MultiValuer)
|
||||
|
||||
cleanFieldName := inflector.Columnify(field.GetName())
|
||||
|
||||
// arrayable fields with ":length" modifier
|
||||
// -------------------------------------------------------
|
||||
if modifier == lengthModifier && isMultivaluer {
|
||||
jePair := r.activeTableAlias + "." + cleanFieldName
|
||||
|
||||
result := &search.ResolverResult{
|
||||
Identifier: dbutils.JSONArrayLength(jePair),
|
||||
}
|
||||
|
||||
if r.withMultiMatch {
|
||||
jePair2 := r.multiMatchActiveTableAlias + "." + cleanFieldName
|
||||
r.multiMatch.valueIdentifier = dbutils.JSONArrayLength(jePair2)
|
||||
result.MultiMatchSubQuery = r.multiMatch
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// arrayable fields with ":each" modifier
|
||||
// -------------------------------------------------------
|
||||
if modifier == eachModifier && isMultivaluer {
|
||||
jePair := r.activeTableAlias + "." + cleanFieldName
|
||||
jeAlias := r.activeTableAlias + "_" + cleanFieldName + "_je"
|
||||
r.resolver.registerJoin(dbutils.JSONEach(jePair), jeAlias, nil)
|
||||
|
||||
result := &search.ResolverResult{
|
||||
Identifier: fmt.Sprintf("[[%s.value]]", jeAlias),
|
||||
}
|
||||
|
||||
if multvaluer.IsMultiple() {
|
||||
r.withMultiMatch = true
|
||||
}
|
||||
|
||||
if r.withMultiMatch {
|
||||
jePair2 := r.multiMatchActiveTableAlias + "." + cleanFieldName
|
||||
jeAlias2 := r.multiMatchActiveTableAlias + "_" + cleanFieldName + "_je"
|
||||
|
||||
r.multiMatch.joins = append(r.multiMatch.joins, &join{
|
||||
tableName: dbutils.JSONEach(jePair2),
|
||||
tableAlias: jeAlias2,
|
||||
})
|
||||
r.multiMatch.valueIdentifier = fmt.Sprintf("[[%s.value]]", jeAlias2)
|
||||
|
||||
result.MultiMatchSubQuery = r.multiMatch
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// default
|
||||
// -------------------------------------------------------
|
||||
result := &search.ResolverResult{
|
||||
Identifier: "[[" + r.activeTableAlias + "." + cleanFieldName + "]]",
|
||||
}
|
||||
|
||||
if r.withMultiMatch {
|
||||
r.multiMatch.valueIdentifier = "[[" + r.multiMatchActiveTableAlias + "." + cleanFieldName + "]]"
|
||||
result.MultiMatchSubQuery = r.multiMatch
|
||||
}
|
||||
|
||||
// allow querying only auth records with emails marked as public
|
||||
if field.GetName() == FieldNameEmail && !r.allowHiddenFields && collection.IsAuth() {
|
||||
result.AfterBuild = func(expr dbx.Expression) dbx.Expression {
|
||||
return dbx.Enclose(dbx.And(expr, dbx.NewExp(fmt.Sprintf(
|
||||
"[[%s.%s]] = TRUE",
|
||||
r.activeTableAlias,
|
||||
FieldNameEmailVisibility,
|
||||
))))
|
||||
}
|
||||
}
|
||||
|
||||
// wrap in json_extract to ensure that top-level primitives
|
||||
// stored as json work correctly when compared to their SQL equivalent
|
||||
// (https://github.com/pocketbase/pocketbase/issues/4068)
|
||||
if field.Type() == FieldTypeJSON {
|
||||
result.NoCoalesce = true
|
||||
result.Identifier = dbutils.JSONExtract(r.activeTableAlias+"."+cleanFieldName, "")
|
||||
if r.withMultiMatch {
|
||||
r.multiMatch.valueIdentifier = dbutils.JSONExtract(r.multiMatchActiveTableAlias+"."+cleanFieldName, "")
|
||||
}
|
||||
}
|
||||
|
||||
// account for the ":lower" modifier
|
||||
if modifier == lowerModifier {
|
||||
result.Identifier = "LOWER(" + result.Identifier + ")"
|
||||
if r.withMultiMatch {
|
||||
r.multiMatch.valueIdentifier = "LOWER(" + r.multiMatch.valueIdentifier + ")"
|
||||
}
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
808
core/record_field_resolver_test.go
Normal file
808
core/record_field_resolver_test.go
Normal file
File diff suppressed because one or more lines are too long
Some files were not shown because too many files have changed in this diff Show more
Loading…
Add table
Add a link
Reference in a new issue