Adding upstream version 0.28.1.
Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
parent
88f1d47ab6
commit
e28c88ef14
933 changed files with 194711 additions and 0 deletions
622
core/record_query.go
Normal file
622
core/record_query.go
Normal file
|
@ -0,0 +1,622 @@
|
|||
package core
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
|
||||
"github.com/pocketbase/dbx"
|
||||
"github.com/pocketbase/pocketbase/tools/dbutils"
|
||||
"github.com/pocketbase/pocketbase/tools/inflector"
|
||||
"github.com/pocketbase/pocketbase/tools/list"
|
||||
"github.com/pocketbase/pocketbase/tools/search"
|
||||
"github.com/pocketbase/pocketbase/tools/security"
|
||||
)
|
||||
|
||||
var recordProxyType = reflect.TypeOf((*RecordProxy)(nil)).Elem()
|
||||
|
||||
// RecordQuery returns a new Record select query from a collection model, id or name.
|
||||
//
|
||||
// In case a collection id or name is provided and that collection doesn't
|
||||
// actually exists, the generated query will be created with a cancelled context
|
||||
// and will fail once an executor (Row(), One(), All(), etc.) is called.
|
||||
func (app *BaseApp) RecordQuery(collectionModelOrIdentifier any) *dbx.SelectQuery {
|
||||
var tableName string
|
||||
|
||||
collection, collectionErr := getCollectionByModelOrIdentifier(app, collectionModelOrIdentifier)
|
||||
if collection != nil {
|
||||
tableName = collection.Name
|
||||
}
|
||||
if tableName == "" {
|
||||
// update with some fake table name for easier debugging
|
||||
tableName = "@@__invalidCollectionModelOrIdentifier"
|
||||
}
|
||||
|
||||
query := app.ConcurrentDB().Select(app.ConcurrentDB().QuoteSimpleColumnName(tableName) + ".*").From(tableName)
|
||||
|
||||
// in case of an error attach a new context and cancel it immediately with the error
|
||||
if collectionErr != nil {
|
||||
ctx, cancelFunc := context.WithCancelCause(context.Background())
|
||||
query.WithContext(ctx)
|
||||
cancelFunc(collectionErr)
|
||||
}
|
||||
|
||||
return query.WithBuildHook(func(q *dbx.Query) {
|
||||
q.WithExecHook(execLockRetry(app.config.QueryTimeout, defaultMaxLockRetries)).
|
||||
WithOneHook(func(q *dbx.Query, a any, op func(b any) error) error {
|
||||
if a == nil {
|
||||
return op(a)
|
||||
}
|
||||
|
||||
switch v := a.(type) {
|
||||
case *Record:
|
||||
record, err := resolveRecordOneHook(collection, op)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
*v = *record
|
||||
|
||||
return nil
|
||||
case RecordProxy:
|
||||
record, err := resolveRecordOneHook(collection, op)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
v.SetProxyRecord(record)
|
||||
return nil
|
||||
default:
|
||||
return op(a)
|
||||
}
|
||||
}).
|
||||
WithAllHook(func(q *dbx.Query, sliceA any, op func(sliceB any) error) error {
|
||||
if sliceA == nil {
|
||||
return op(sliceA)
|
||||
}
|
||||
|
||||
switch v := sliceA.(type) {
|
||||
case *[]*Record:
|
||||
records, err := resolveRecordAllHook(collection, op)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
*v = records
|
||||
|
||||
return nil
|
||||
case *[]Record:
|
||||
records, err := resolveRecordAllHook(collection, op)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
nonPointers := make([]Record, len(records))
|
||||
for i, r := range records {
|
||||
nonPointers[i] = *r
|
||||
}
|
||||
|
||||
*v = nonPointers
|
||||
|
||||
return nil
|
||||
default: // expects []RecordProxy slice
|
||||
rv := reflect.ValueOf(v)
|
||||
if rv.Kind() != reflect.Ptr || rv.IsNil() {
|
||||
return errors.New("must be a pointer")
|
||||
}
|
||||
|
||||
rv = dereference(rv)
|
||||
|
||||
if rv.Kind() != reflect.Slice {
|
||||
return errors.New("must be a slice of RecordSetters")
|
||||
}
|
||||
|
||||
et := rv.Type().Elem()
|
||||
|
||||
var isSliceOfPointers bool
|
||||
if et.Kind() == reflect.Ptr {
|
||||
isSliceOfPointers = true
|
||||
et = et.Elem()
|
||||
}
|
||||
|
||||
if !reflect.PointerTo(et).Implements(recordProxyType) {
|
||||
return op(sliceA)
|
||||
}
|
||||
|
||||
records, err := resolveRecordAllHook(collection, op)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// create an empty slice
|
||||
if rv.IsNil() {
|
||||
rv.Set(reflect.MakeSlice(rv.Type(), 0, len(records)))
|
||||
}
|
||||
|
||||
for _, record := range records {
|
||||
ev := reflect.New(et)
|
||||
|
||||
if !ev.CanInterface() {
|
||||
continue
|
||||
}
|
||||
|
||||
ps, ok := ev.Interface().(RecordProxy)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
ps.SetProxyRecord(record)
|
||||
|
||||
ev = ev.Elem()
|
||||
if isSliceOfPointers {
|
||||
ev = ev.Addr()
|
||||
}
|
||||
|
||||
rv.Set(reflect.Append(rv, ev))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func resolveRecordOneHook(collection *Collection, op func(dst any) error) (*Record, error) {
|
||||
data := dbx.NullStringMap{}
|
||||
if err := op(&data); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return newRecordFromNullStringMap(collection, data)
|
||||
}
|
||||
|
||||
func resolveRecordAllHook(collection *Collection, op func(dst any) error) ([]*Record, error) {
|
||||
data := []dbx.NullStringMap{}
|
||||
if err := op(&data); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return newRecordsFromNullStringMaps(collection, data)
|
||||
}
|
||||
|
||||
// dereference returns the underlying value v points to.
|
||||
func dereference(v reflect.Value) reflect.Value {
|
||||
for v.Kind() == reflect.Ptr {
|
||||
if v.IsNil() {
|
||||
// initialize with a new value and continue searching
|
||||
v.Set(reflect.New(v.Type().Elem()))
|
||||
}
|
||||
v = v.Elem()
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
func getCollectionByModelOrIdentifier(app App, collectionModelOrIdentifier any) (*Collection, error) {
|
||||
switch c := collectionModelOrIdentifier.(type) {
|
||||
case *Collection:
|
||||
return c, nil
|
||||
case Collection:
|
||||
return &c, nil
|
||||
case string:
|
||||
return app.FindCachedCollectionByNameOrId(c)
|
||||
default:
|
||||
return nil, errors.New("unknown collection identifier - must be collection model, id or name")
|
||||
}
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
|
||||
// FindRecordById finds the Record model by its id.
|
||||
func (app *BaseApp) FindRecordById(
|
||||
collectionModelOrIdentifier any,
|
||||
recordId string,
|
||||
optFilters ...func(q *dbx.SelectQuery) error,
|
||||
) (*Record, error) {
|
||||
collection, err := getCollectionByModelOrIdentifier(app, collectionModelOrIdentifier)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
record := &Record{}
|
||||
|
||||
query := app.RecordQuery(collection).
|
||||
AndWhere(dbx.HashExp{collection.Name + ".id": recordId})
|
||||
|
||||
// apply filter funcs (if any)
|
||||
for _, filter := range optFilters {
|
||||
if filter == nil {
|
||||
continue
|
||||
}
|
||||
if err = filter(query); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
err = query.Limit(1).One(record)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return record, nil
|
||||
}
|
||||
|
||||
// FindRecordsByIds finds all records by the specified ids.
|
||||
// If no records are found, returns an empty slice.
|
||||
func (app *BaseApp) FindRecordsByIds(
|
||||
collectionModelOrIdentifier any,
|
||||
recordIds []string,
|
||||
optFilters ...func(q *dbx.SelectQuery) error,
|
||||
) ([]*Record, error) {
|
||||
collection, err := getCollectionByModelOrIdentifier(app, collectionModelOrIdentifier)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
query := app.RecordQuery(collection).
|
||||
AndWhere(dbx.In(
|
||||
collection.Name+".id",
|
||||
list.ToInterfaceSlice(recordIds)...,
|
||||
))
|
||||
|
||||
for _, filter := range optFilters {
|
||||
if filter == nil {
|
||||
continue
|
||||
}
|
||||
if err = filter(query); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
records := make([]*Record, 0, len(recordIds))
|
||||
|
||||
err = query.All(&records)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return records, nil
|
||||
}
|
||||
|
||||
// FindAllRecords finds all records matching specified db expressions.
|
||||
//
|
||||
// Returns all collection records if no expression is provided.
|
||||
//
|
||||
// Returns an empty slice if no records are found.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// // no extra expressions
|
||||
// app.FindAllRecords("example")
|
||||
//
|
||||
// // with extra expressions
|
||||
// expr1 := dbx.HashExp{"email": "test@example.com"}
|
||||
// expr2 := dbx.NewExp("LOWER(username) = {:username}", dbx.Params{"username": "test"})
|
||||
// app.FindAllRecords("example", expr1, expr2)
|
||||
func (app *BaseApp) FindAllRecords(collectionModelOrIdentifier any, exprs ...dbx.Expression) ([]*Record, error) {
|
||||
query := app.RecordQuery(collectionModelOrIdentifier)
|
||||
|
||||
for _, expr := range exprs {
|
||||
if expr != nil { // add only the non-nil expressions
|
||||
query.AndWhere(expr)
|
||||
}
|
||||
}
|
||||
|
||||
var records []*Record
|
||||
|
||||
if err := query.All(&records); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return records, nil
|
||||
}
|
||||
|
||||
// FindFirstRecordByData returns the first found record matching
|
||||
// the provided key-value pair.
|
||||
func (app *BaseApp) FindFirstRecordByData(collectionModelOrIdentifier any, key string, value any) (*Record, error) {
|
||||
record := &Record{}
|
||||
|
||||
err := app.RecordQuery(collectionModelOrIdentifier).
|
||||
AndWhere(dbx.HashExp{inflector.Columnify(key): value}).
|
||||
Limit(1).
|
||||
One(record)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return record, nil
|
||||
}
|
||||
|
||||
// FindRecordsByFilter returns limit number of records matching the
|
||||
// provided string filter.
|
||||
//
|
||||
// NB! Use the last "params" argument to bind untrusted user variables!
|
||||
//
|
||||
// The filter argument is optional and can be empty string to target
|
||||
// all available records.
|
||||
//
|
||||
// The sort argument is optional and can be empty string OR the same format
|
||||
// used in the web APIs, ex. "-created,title".
|
||||
//
|
||||
// If the limit argument is <= 0, no limit is applied to the query and
|
||||
// all matching records are returned.
|
||||
//
|
||||
// Returns an empty slice if no records are found.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// app.FindRecordsByFilter(
|
||||
// "posts",
|
||||
// "title ~ {:title} && visible = {:visible}",
|
||||
// "-created",
|
||||
// 10,
|
||||
// 0,
|
||||
// dbx.Params{"title": "lorem ipsum", "visible": true}
|
||||
// )
|
||||
func (app *BaseApp) FindRecordsByFilter(
|
||||
collectionModelOrIdentifier any,
|
||||
filter string,
|
||||
sort string,
|
||||
limit int,
|
||||
offset int,
|
||||
params ...dbx.Params,
|
||||
) ([]*Record, error) {
|
||||
collection, err := getCollectionByModelOrIdentifier(app, collectionModelOrIdentifier)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
q := app.RecordQuery(collection)
|
||||
|
||||
// build a fields resolver and attach the generated conditions to the query
|
||||
// ---
|
||||
resolver := NewRecordFieldResolver(
|
||||
app,
|
||||
collection, // the base collection
|
||||
nil, // no request data
|
||||
true, // allow searching hidden/protected fields like "email"
|
||||
)
|
||||
|
||||
if filter != "" {
|
||||
expr, err := search.FilterData(filter).BuildExpr(resolver, params...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid filter expression: %w", err)
|
||||
}
|
||||
q.AndWhere(expr)
|
||||
}
|
||||
|
||||
if sort != "" {
|
||||
for _, sortField := range search.ParseSortFromString(sort) {
|
||||
expr, err := sortField.BuildExpr(resolver)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if expr != "" {
|
||||
q.AndOrderBy(expr)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
resolver.UpdateQuery(q) // attaches any adhoc joins and aliases
|
||||
// ---
|
||||
|
||||
if offset > 0 {
|
||||
q.Offset(int64(offset))
|
||||
}
|
||||
|
||||
if limit > 0 {
|
||||
q.Limit(int64(limit))
|
||||
}
|
||||
|
||||
records := []*Record{}
|
||||
|
||||
if err := q.All(&records); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return records, nil
|
||||
}
|
||||
|
||||
// FindFirstRecordByFilter returns the first available record matching the provided filter (if any).
|
||||
//
|
||||
// NB! Use the last params argument to bind untrusted user variables!
|
||||
//
|
||||
// Returns sql.ErrNoRows if no record is found.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// app.FindFirstRecordByFilter("posts", "")
|
||||
// app.FindFirstRecordByFilter("posts", "slug={:slug} && status='public'", dbx.Params{"slug": "test"})
|
||||
func (app *BaseApp) FindFirstRecordByFilter(
|
||||
collectionModelOrIdentifier any,
|
||||
filter string,
|
||||
params ...dbx.Params,
|
||||
) (*Record, error) {
|
||||
result, err := app.FindRecordsByFilter(collectionModelOrIdentifier, filter, "", 1, 0, params...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(result) == 0 {
|
||||
return nil, sql.ErrNoRows
|
||||
}
|
||||
|
||||
return result[0], nil
|
||||
}
|
||||
|
||||
// CountRecords returns the total number of records in a collection.
|
||||
func (app *BaseApp) CountRecords(collectionModelOrIdentifier any, exprs ...dbx.Expression) (int64, error) {
|
||||
var total int64
|
||||
|
||||
q := app.RecordQuery(collectionModelOrIdentifier).Select("count(*)")
|
||||
|
||||
for _, expr := range exprs {
|
||||
if expr != nil { // add only the non-nil expressions
|
||||
q.AndWhere(expr)
|
||||
}
|
||||
}
|
||||
|
||||
err := q.Row(&total)
|
||||
|
||||
return total, err
|
||||
}
|
||||
|
||||
// FindAuthRecordByToken finds the auth record associated with the provided JWT
|
||||
// (auth, file, verifyEmail, changeEmail, passwordReset types).
|
||||
//
|
||||
// Optionally specify a list of validTypes to check tokens only from those types.
|
||||
//
|
||||
// Returns an error if the JWT is invalid, expired or not associated to an auth collection record.
|
||||
func (app *BaseApp) FindAuthRecordByToken(token string, validTypes ...string) (*Record, error) {
|
||||
if token == "" {
|
||||
return nil, errors.New("missing token")
|
||||
}
|
||||
|
||||
unverifiedClaims, err := security.ParseUnverifiedJWT(token)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// check required claims
|
||||
id, _ := unverifiedClaims[TokenClaimId].(string)
|
||||
collectionId, _ := unverifiedClaims[TokenClaimCollectionId].(string)
|
||||
tokenType, _ := unverifiedClaims[TokenClaimType].(string)
|
||||
if id == "" || collectionId == "" || tokenType == "" {
|
||||
return nil, errors.New("missing or invalid token claims")
|
||||
}
|
||||
|
||||
// check types (if explicitly set)
|
||||
if len(validTypes) > 0 && !list.ExistInSlice(tokenType, validTypes) {
|
||||
return nil, fmt.Errorf("invalid token type %q, expects %q", tokenType, strings.Join(validTypes, ","))
|
||||
}
|
||||
|
||||
record, err := app.FindRecordById(collectionId, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !record.Collection().IsAuth() {
|
||||
return nil, errors.New("the token is not associated to an auth collection record")
|
||||
}
|
||||
|
||||
var baseTokenKey string
|
||||
switch tokenType {
|
||||
case TokenTypeAuth:
|
||||
baseTokenKey = record.Collection().AuthToken.Secret
|
||||
case TokenTypeFile:
|
||||
baseTokenKey = record.Collection().FileToken.Secret
|
||||
case TokenTypeVerification:
|
||||
baseTokenKey = record.Collection().VerificationToken.Secret
|
||||
case TokenTypePasswordReset:
|
||||
baseTokenKey = record.Collection().PasswordResetToken.Secret
|
||||
case TokenTypeEmailChange:
|
||||
baseTokenKey = record.Collection().EmailChangeToken.Secret
|
||||
default:
|
||||
return nil, errors.New("unknown token type " + tokenType)
|
||||
}
|
||||
|
||||
secret := record.TokenKey() + baseTokenKey
|
||||
|
||||
// verify token signature
|
||||
_, err = security.ParseJWT(token, secret)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return record, nil
|
||||
}
|
||||
|
||||
// FindAuthRecordByEmail finds the auth record associated with the provided email.
|
||||
//
|
||||
// The email check would be case-insensitive if the related collection
|
||||
// email unique index has COLLATE NOCASE specified for the email column.
|
||||
//
|
||||
// Returns an error if it is not an auth collection or the record is not found.
|
||||
func (app *BaseApp) FindAuthRecordByEmail(collectionModelOrIdentifier any, email string) (*Record, error) {
|
||||
collection, err := getCollectionByModelOrIdentifier(app, collectionModelOrIdentifier)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to fetch auth collection: %w", err)
|
||||
}
|
||||
|
||||
if !collection.IsAuth() {
|
||||
return nil, fmt.Errorf("%q is not an auth collection", collection.Name)
|
||||
}
|
||||
|
||||
record := &Record{}
|
||||
|
||||
var expr dbx.Expression
|
||||
|
||||
index, ok := dbutils.FindSingleColumnUniqueIndex(collection.Indexes, FieldNameEmail)
|
||||
if ok && strings.EqualFold(index.Columns[0].Collate, "nocase") {
|
||||
// case-insensitive search
|
||||
expr = dbx.NewExp("[["+FieldNameEmail+"]] = {:email} COLLATE NOCASE", dbx.Params{"email": email})
|
||||
} else {
|
||||
expr = dbx.HashExp{FieldNameEmail: email}
|
||||
}
|
||||
|
||||
err = app.RecordQuery(collection).
|
||||
AndWhere(expr).
|
||||
Limit(1).
|
||||
One(record)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return record, nil
|
||||
}
|
||||
|
||||
// CanAccessRecord checks if a record is allowed to be accessed by the
|
||||
// specified requestInfo and accessRule.
|
||||
//
|
||||
// Rule and db checks are ignored in case requestInfo.Auth is a superuser.
|
||||
//
|
||||
// The returned error indicate that something unexpected happened during
|
||||
// the check (eg. invalid rule or db query error).
|
||||
//
|
||||
// The method always return false on invalid rule or db query error.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// requestInfo, _ := e.RequestInfo()
|
||||
// record, _ := app.FindRecordById("example", "RECORD_ID")
|
||||
// rule := types.Pointer("@request.auth.id != '' || status = 'public'")
|
||||
// // ... or use one of the record collection's rule, eg. record.Collection().ViewRule
|
||||
//
|
||||
// if ok, _ := app.CanAccessRecord(record, requestInfo, rule); ok { ... }
|
||||
func (app *BaseApp) CanAccessRecord(record *Record, requestInfo *RequestInfo, accessRule *string) (bool, error) {
|
||||
// superusers can access everything
|
||||
if requestInfo.HasSuperuserAuth() {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// only superusers can access this record
|
||||
if accessRule == nil {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// empty public rule, aka. everyone can access
|
||||
if *accessRule == "" {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
var exists int
|
||||
|
||||
query := app.RecordQuery(record.Collection()).
|
||||
Select("(1)").
|
||||
AndWhere(dbx.HashExp{record.Collection().Name + ".id": record.Id})
|
||||
|
||||
// parse and apply the access rule filter
|
||||
resolver := NewRecordFieldResolver(app, record.Collection(), requestInfo, true)
|
||||
expr, err := search.FilterData(*accessRule).BuildExpr(resolver)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
resolver.UpdateQuery(query)
|
||||
|
||||
err = query.AndWhere(expr).Limit(1).Row(&exists)
|
||||
if err != nil && !errors.Is(err, sql.ErrNoRows) {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return exists > 0, nil
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue