Adding upstream version 0.28.1.
Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
parent
88f1d47ab6
commit
e28c88ef14
933 changed files with 194711 additions and 0 deletions
5
.github/SECURITY.md
vendored
Normal file
5
.github/SECURITY.md
vendored
Normal file
|
@ -0,0 +1,5 @@
|
|||
# Security
|
||||
|
||||
If you discover a security vulnerability within PocketBase, please send an e-mail to **support at pocketbase.io**.
|
||||
|
||||
All reports will be promptly addressed and you'll be credited in the fix release notes.
|
56
.github/workflows/release.yaml
vendored
Normal file
56
.github/workflows/release.yaml
vendored
Normal file
|
@ -0,0 +1,56 @@
|
|||
name: basebuild
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
push:
|
||||
|
||||
jobs:
|
||||
goreleaser:
|
||||
runs-on: ubuntu-latest
|
||||
env:
|
||||
flags: ""
|
||||
steps:
|
||||
# re-enable auto-snapshot from goreleaser-action@v3
|
||||
# (https://github.com/goreleaser/goreleaser-action-v4-auto-snapshot-example)
|
||||
- if: ${{ !startsWith(github.ref, 'refs/tags/v') }}
|
||||
run: echo "flags=--snapshot" >> $GITHUB_ENV
|
||||
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Set up Node.js
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: 20.17.0
|
||||
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: '>=1.23.9'
|
||||
|
||||
# This step usually is not needed because the /ui/dist is pregenerated locally
|
||||
# but its here to ensure that each release embeds the latest admin ui artifacts.
|
||||
# If the artificats differs, a "dirty error" is thrown - https://goreleaser.com/errors/dirty/
|
||||
- name: Build Admin dashboard UI
|
||||
run: npm --prefix=./ui ci && npm --prefix=./ui run build
|
||||
|
||||
# Temporary disable as the types can have random generated identifiers making it non-deterministic.
|
||||
#
|
||||
# # Similar to the above, the jsvm types are pregenerated locally
|
||||
# # but its here to ensure that it wasn't forgotten to be executed.
|
||||
# - name: Generate jsvm types
|
||||
# run: go run ./plugins/jsvm/internal/types/types.go
|
||||
|
||||
- name: Run tests
|
||||
run: go test ./...
|
||||
|
||||
- name: Run GoReleaser
|
||||
uses: goreleaser/goreleaser-action@v6
|
||||
with:
|
||||
distribution: goreleaser
|
||||
version: '~> v2'
|
||||
args: release --clean ${{ env.flags }}
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
20
.gitignore
vendored
Normal file
20
.gitignore
vendored
Normal file
|
@ -0,0 +1,20 @@
|
|||
/.vscode/
|
||||
.idea
|
||||
|
||||
.DS_Store
|
||||
|
||||
# goreleaser builds folder
|
||||
/.builds/
|
||||
|
||||
# tests coverage
|
||||
coverage.out
|
||||
|
||||
# plaintask todo files
|
||||
*.todo
|
||||
|
||||
# generated markdown previews
|
||||
README.html
|
||||
CHANGELOG.html
|
||||
CHANGELOG_16_22.html
|
||||
CHANGELOG_8_15.html
|
||||
LICENSE.html
|
67
.goreleaser.yaml
Normal file
67
.goreleaser.yaml
Normal file
|
@ -0,0 +1,67 @@
|
|||
version: 2
|
||||
|
||||
project_name: pocketbase
|
||||
|
||||
dist: .builds
|
||||
|
||||
before:
|
||||
hooks:
|
||||
- go mod tidy
|
||||
|
||||
builds:
|
||||
- id: build_noncgo
|
||||
main: ./examples/base
|
||||
binary: pocketbase
|
||||
ldflags:
|
||||
- -s -w -X github.com/pocketbase/pocketbase.Version={{ .Version }}
|
||||
env:
|
||||
- CGO_ENABLED=0
|
||||
goos:
|
||||
- linux
|
||||
- windows
|
||||
- darwin
|
||||
goarch:
|
||||
- amd64
|
||||
- arm64
|
||||
- arm
|
||||
- s390x
|
||||
- ppc64le
|
||||
goarm:
|
||||
- 7
|
||||
ignore:
|
||||
- goos: windows
|
||||
goarch: arm
|
||||
- goos: windows
|
||||
goarch: s390x
|
||||
- goos: windows
|
||||
goarch: ppc64le
|
||||
- goos: darwin
|
||||
goarch: arm
|
||||
- goos: darwin
|
||||
goarch: s390x
|
||||
- goos: darwin
|
||||
goarch: ppc64le
|
||||
|
||||
release:
|
||||
draft: true
|
||||
|
||||
archives:
|
||||
- id: archive_noncgo
|
||||
builds: [build_noncgo]
|
||||
format: zip
|
||||
files:
|
||||
- LICENSE.md
|
||||
- CHANGELOG.md
|
||||
|
||||
checksum:
|
||||
name_template: 'checksums.txt'
|
||||
|
||||
snapshot:
|
||||
version_template: '{{ incpatch .Version }}-next'
|
||||
|
||||
changelog:
|
||||
sort: asc
|
||||
filters:
|
||||
exclude:
|
||||
- '^examples:'
|
||||
- '^ui:'
|
567
CHANGELOG.md
Normal file
567
CHANGELOG.md
Normal file
|
@ -0,0 +1,567 @@
|
|||
## v0.28.1
|
||||
|
||||
- Fixed `json_each`/`json_array_length` normalizations to properly check for array values ([#6835](https://github.com/pocketbase/pocketbase/issues/6835)).
|
||||
|
||||
|
||||
## v0.28.0
|
||||
|
||||
- Write the default response body of `*Request` hooks that are wrapped in a transaction after the related transaction completes to allow propagating the transaction error ([#6462](https://github.com/pocketbase/pocketbase/discussions/6462#discussioncomment-12207818)).
|
||||
|
||||
- Updated `app.DB()` to automatically routes raw write SQL statements to the nonconcurrent db pool ([#6689](https://github.com/pocketbase/pocketbase/discussions/6689)).
|
||||
_For the rare cases when it is needed users still have the option to explicitly target the specific pool they want using `app.ConcurrentDB()`/`app.NonconcurrentDB()`._
|
||||
|
||||
- ⚠️ Changed the default `json` field max size to 1MB.
|
||||
_Users still have the option to adjust the default limit from the collection field options but keep in mind that storing large strings/blobs in the database is known to cause performance issues and should be avoided when possible._
|
||||
|
||||
- ⚠️ Soft-deprecated and replaced `filesystem.System.GetFile(fileKey)` with `filesystem.System.GetReader(fileKey)` to avoid the confusion with `filesystem.File`.
|
||||
_The old method will still continue to work for at least until v0.29.0 but you'll get a console warning to replace it with `GetReader`._
|
||||
|
||||
- Added new `filesystem.System.GetReuploadableFile(fileKey, preserveName)` method to return an existing blob as a `*filesystem.File` value ([#6792](https://github.com/pocketbase/pocketbase/discussions/6792)).
|
||||
_This method could be useful in case you want to clone an existing Record file and assign it to a new Record (e.g. in a Record duplicate action)._
|
||||
|
||||
- Other minor improvements (updated the GitHub release min Go version to 1.23.9, updated npm and Go deps, etc.)
|
||||
|
||||
|
||||
## v0.27.2
|
||||
|
||||
- Added workers pool when cascade deleting record files to minimize _"thread exhaustion"_ errors ([#6780](https://github.com/pocketbase/pocketbase/discussions/6780)).
|
||||
|
||||
- Updated the `:excerpt` fields modifier to properly account for multibyte characters ([#6778](https://github.com/pocketbase/pocketbase/issues/6778)).
|
||||
|
||||
- Use `rowid` as count column for non-view collections to minimize the need of having the id field in a covering index ([#6739](https://github.com/pocketbase/pocketbase/discussions/6739))
|
||||
|
||||
|
||||
## v0.27.1
|
||||
|
||||
- Updated example `geoPoint` API preview body data.
|
||||
|
||||
- Added JSVM `new GeoPointField({ ... })` constructor.
|
||||
|
||||
- Added _partial_ WebP thumbs generation (_the thumbs will be stored as PNG_; [#6744](https://github.com/pocketbase/pocketbase/pull/6744)).
|
||||
|
||||
- Updated npm dev dependencies.
|
||||
|
||||
|
||||
## v0.27.0
|
||||
|
||||
- ⚠️ Moved the Create and Manage API rule checks out of the `OnRecordCreateRequest` hook finalizer, **aka. now all CRUD API rules are checked BEFORE triggering their corresponding `*Request` hook**.
|
||||
This was done to minimize the confusion regarding the firing order of the request operations, making it more predictable and consistent with the other record List/View/Update/Delete request actions.
|
||||
It could be a minor breaking change if you are relying on the old behavior and have a Go `tests.ApiScenario` that is testing a Create API rule failure and expect `OnRecordCreateRequest` to be fired. In that case for example you may have to update your test scenario like:
|
||||
```go
|
||||
tests.ApiScenario{
|
||||
Name: "Example test that checks a Create API rule failure"
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/example/records",
|
||||
...
|
||||
// old:
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordCreateRequest": 1,
|
||||
},
|
||||
// new:
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
}
|
||||
```
|
||||
If you are having difficulties adjusting your code, feel free to open a [Q&A discussion](https://github.com/pocketbase/pocketbase/discussions) with the failing/problematic code sample.
|
||||
|
||||
- Added [new `geoPoint` field](https://pocketbase.io/docs/collections/#geopoint) for storing `{"lon":x,"lat":y}` geographic coordinates.
|
||||
In addition, a new [`geoDistance(lonA, lotA, lonB, lotB)` function](htts://pocketbase.io/docs/api-rules-and-filters/#geodistancelona-lata-lonb-latb) was also implemented that could be used to apply an API rule or filter constraint based on the distance (in km) between 2 geo points.
|
||||
|
||||
- Updated the `select` field UI to accommodate better larger lists and RTL languages ([#4674](https://github.com/pocketbase/pocketbase/issues/4674)).
|
||||
|
||||
- Updated the mail attachments auto MIME type detection to use `gabriel-vasile/mimetype` for consistency and broader sniffing signatures support.
|
||||
|
||||
- Forced `text/javascript` Content-Type when serving `.js`/`.mjs` collection uploaded files with the `/api/files/...` endpoint ([#6597](https://github.com/pocketbase/pocketbase/issues/6597)).
|
||||
|
||||
- Added second optional JSVM `DateTime` constructor argument for specifying a default timezone as TZ identifier when parsing the date string as alternative to a fixed offset in order to better handle daylight saving time nuances ([#6688](https://github.com/pocketbase/pocketbase/discussions/6688)):
|
||||
```js
|
||||
// the same as with CET offset: new DateTime("2025-10-26 03:00:00 +01:00")
|
||||
new DateTime("2025-10-26 03:00:00", "Europe/Amsterdam") // 2025-10-26 02:00:00.000Z
|
||||
|
||||
// the same as with CEST offset: new DateTime("2025-10-26 01:00:00 +02:00")
|
||||
new DateTime("2025-10-26 01:00:00", "Europe/Amsterdam") // 2025-10-25 23:00:00.000Z
|
||||
```
|
||||
|
||||
- Soft-deprecated the `$http.send`'s `result.raw` field in favor of `result.body` that contains the response body as plain bytes slice to avoid the discrepancies between Go and the JSVM when casting binary data to string.
|
||||
|
||||
- Updated `modernc.org/sqlite` to 1.37.0.
|
||||
|
||||
- Other minor improvements (_removed the superuser fields from the auth record create/update body examples, allowed programmatically updating the auth record password from the create/update hooks, fixed collections import error response, etc._).
|
||||
|
||||
|
||||
## v0.26.6
|
||||
|
||||
- Allow OIDC `email_verified` to be int or boolean string since some OIDC providers like AWS Cognito has non-standard userinfo response ([#6657](https://github.com/pocketbase/pocketbase/pull/6657)).
|
||||
|
||||
- Updated `modernc.org/sqlite` to 1.36.3.
|
||||
|
||||
|
||||
## v0.26.5
|
||||
|
||||
- Fixed canonical URI parts escaping when generating the S3 request signature ([#6654](https://github.com/pocketbase/pocketbase/issues/6654)).
|
||||
|
||||
|
||||
## v0.26.4
|
||||
|
||||
- Fixed `RecordErrorEvent.Error` and `CollectionErrorEvent.Error` sync with `ModelErrorEvent.Error` ([#6639](https://github.com/pocketbase/pocketbase/issues/6639)).
|
||||
|
||||
- Fixed logs details copy to clipboard action.
|
||||
|
||||
- Updated `modernc.org/sqlite` to 1.36.2.
|
||||
|
||||
|
||||
## v0.26.3
|
||||
|
||||
- Fixed and normalized logs error serialization across common types for more consistent logs error output ([#6631](https://github.com/pocketbase/pocketbase/issues/6631)).
|
||||
|
||||
|
||||
## v0.26.2
|
||||
|
||||
- Updated `golang-jwt/jwt` dependency because it comes with a [minor security fix](https://github.com/golang-jwt/jwt/security/advisories/GHSA-mh63-6h87-95cp).
|
||||
|
||||
|
||||
## v0.26.1
|
||||
|
||||
- Removed the wrapping of `io.EOF` error when reading files since currently `io.ReadAll` doesn't check for wrapped errors ([#6600](https://github.com/pocketbase/pocketbase/issues/6600)).
|
||||
|
||||
|
||||
## v0.26.0
|
||||
|
||||
- ⚠️ Replaced `aws-sdk-go-v2` and `gocloud.dev/blob` with custom lighter implementation ([#6562](https://github.com/pocketbase/pocketbase/discussions/6562)).
|
||||
As a side-effect of the dependency removal, the binary size has been reduced with ~10MB and builds ~30% faster.
|
||||
_Although the change is expected to be backward-compatible, I'd recommend to test first locally the new version with your S3 provider (if you use S3 for files storage and backups)._
|
||||
|
||||
- ⚠️ Prioritized the user submitted non-empty `createData.email` (_it will be unverified_) when creating the PocketBase user during the first OAuth2 auth.
|
||||
|
||||
- Load the request info context during password/OAuth2/OTP authentication ([#6402](https://github.com/pocketbase/pocketbase/issues/6402)).
|
||||
This could be useful in case you want to target the auth method as part of the MFA and Auth API rules.
|
||||
For example, to disable MFA for the OAuth2 auth could be expressed as `@request.context != "oauth2"` MFA rule.
|
||||
|
||||
- Added `store.Store.SetFunc(key, func(old T) new T)` to set/update a store value with the return result of the callback in a concurrent safe manner.
|
||||
|
||||
- Added `subscription.Message.WriteSSE(w, id)` for writing an SSE formatted message into the provided writer interface (_used mostly to assist with the unit testing_).
|
||||
|
||||
- Added `$os.stat(file)` JSVM helper ([#6407](https://github.com/pocketbase/pocketbase/discussions/6407)).
|
||||
|
||||
- Added log warning for `async` marked JSVM handlers and resolve when possible the returned `Promise` as fallback ([#6476](https://github.com/pocketbase/pocketbase/issues/6476)).
|
||||
|
||||
- Allowed calling `cronAdd`, `cronRemove` from inside other JSVM handlers ([#6481](https://github.com/pocketbase/pocketbase/discussions/6481)).
|
||||
|
||||
- Bumped the default request read and write timeouts to 5mins (_old 3mins_) to accommodate slower internet connections and larger file uploads/downloads.
|
||||
_If you want to change them you can modify the `OnServe` hook's `ServeEvent.ReadTimeout/WriteTimeout` fields as shown in [#6550](https://github.com/pocketbase/pocketbase/discussions/6550#discussioncomment-12364515)._
|
||||
|
||||
- Normalized the `@request.auth.*` and `@request.body.*` back relations resolver to always return `null` when the relation field is pointing to a different collection ([#6590](https://github.com/pocketbase/pocketbase/discussions/6590#discussioncomment-12496581)).
|
||||
|
||||
- Other minor improvements (_fixed query dev log nested parameters output, reintroduced `DynamicModel` object/array props reflect types caching, updated Go and npm deps, etc._)
|
||||
|
||||
|
||||
## v0.25.9
|
||||
|
||||
- Fixed `DynamicModel` object/array props reflect type caching ([#6563](https://github.com/pocketbase/pocketbase/discussions/6563)).
|
||||
|
||||
|
||||
## v0.25.8
|
||||
|
||||
- Added a default leeway of 5 minutes for the Apple/OIDC `id_token` timestamp claims check to account for clock-skew ([#6529](https://github.com/pocketbase/pocketbase/issues/6529)).
|
||||
It can be further customized if needed with the `PB_ID_TOKEN_LEEWAY` env variable (_the value must be in seconds, e.g. "PB_ID_TOKEN_LEEWAY=60" for 1 minute_).
|
||||
|
||||
|
||||
## v0.25.7
|
||||
|
||||
- Fixed `@request.body.jsonObjOrArr.*` values extraction ([#6493](https://github.com/pocketbase/pocketbase/discussions/6493)).
|
||||
|
||||
|
||||
## v0.25.6
|
||||
|
||||
- Restore the missing `meta.isNew` field of the OAuth2 success response ([#6490](https://github.com/pocketbase/pocketbase/issues/6490)).
|
||||
|
||||
- Updated npm dependencies.
|
||||
|
||||
|
||||
## v0.25.5
|
||||
|
||||
- Set the current working directory as a default goja script path when executing inline JS strings to allow `require(m)` traversing parent `node_modules` directories.
|
||||
|
||||
- Updated `modernc.org/sqlite` and `modernc.org/libc` dependencies.
|
||||
|
||||
|
||||
## v0.25.4
|
||||
|
||||
- Downgraded `aws-sdk-go-v2` to the version before the default data integrity checks because there have been reports for non-AWS S3 providers in addition to Backblaze (IDrive, R2) that no longer or partially work with the latest AWS SDK changes.
|
||||
|
||||
While we try to enforce `when_required` by default, it is not enough to disable the new AWS SDK integrity checks entirely and some providers will require additional manual adjustments to make them compatible with the latest AWS SDK (e.g. removing the `x-aws-checksum-*` headers, unsetting the checksums calculation or reinstantiating the old MD5 checksums for some of the required operations, etc.) which as a result leads to a configuration mess that I'm not sure it would be a good idea to introduce.
|
||||
|
||||
This unfornuatelly is not a PocketBase or Go specific issue and the official AWS SDKs for other languages are in the same situation (even the latest aws-cli).
|
||||
|
||||
For those of you that extend PocketBase with Go: if your S3 vendor doesn't support the [AWS Data integrity checks](https://docs.aws.amazon.com/sdkref/latest/guide/feature-dataintegrity.html) and you are updating with `go get -u`, then make sure that the `aws-sdk-go-v2` dependencies in your `go.mod` are the same as in the repo:
|
||||
```
|
||||
// go.mod
|
||||
github.com/aws/aws-sdk-go-v2 v1.36.1
|
||||
github.com/aws/aws-sdk-go-v2/config v1.28.10
|
||||
github.com/aws/aws-sdk-go-v2/credentials v1.17.51
|
||||
github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.17.48
|
||||
github.com/aws/aws-sdk-go-v2/service/s3 v1.72.2
|
||||
|
||||
// after that run
|
||||
go clean -modcache && go mod tidy
|
||||
```
|
||||
_The versions pinning is temporary until the non-AWS S3 vendors patch their implementation or until I manage to find time to remove/replace the `aws-sdk-go-v2` dependency (I'll consider prioritizing it for the v0.26 or v0.27 release)._
|
||||
|
||||
|
||||
## v0.25.3
|
||||
|
||||
- Added a temporary exception for Backblaze S3 endpoints to exclude the new `aws-sdk-go-v2` checksum headers ([#6440](https://github.com/pocketbase/pocketbase/discussions/6440)).
|
||||
|
||||
|
||||
## v0.25.2
|
||||
|
||||
- Fixed realtime delete event not being fired for `RecordProxy`-ies and added basic realtime record resolve automated tests ([#6433](https://github.com/pocketbase/pocketbase/issues/6433)).
|
||||
|
||||
|
||||
## v0.25.1
|
||||
|
||||
- Fixed the batch API Preview success sample response.
|
||||
|
||||
- Bumped GitHub action min Go version to 1.23.6 as it comes with a [minor security fix](https://github.com/golang/go/issues?q=milestone%3AGo1.23.6+label%3ACherryPickApproved) for the ppc64le build.
|
||||
|
||||
|
||||
## v0.25.0
|
||||
|
||||
- ⚠️ Upgraded Google OAuth2 auth, token and userinfo endpoints to their latest versions.
|
||||
_For users that don't do anything custom with the Google OAuth2 data or the OAuth2 auth URL, this should be a non-breaking change. The exceptions that I could find are:_
|
||||
- `/v3/userinfo` auth response changes:
|
||||
```
|
||||
meta.rawUser.id => meta.rawUser.sub
|
||||
meta.rawUser.verified_email => meta.rawUser.email_verified
|
||||
```
|
||||
- `/v2/auth` query parameters changes:
|
||||
If you are specifying custom `approval_prompt=force` query parameter for the OAuth2 auth URL, you'll have to replace it with **`prompt=consent`**.
|
||||
|
||||
- Added Trakt OAuth2 provider ([#6338](https://github.com/pocketbase/pocketbase/pull/6338); thanks @aidan-)
|
||||
|
||||
- Added support for case-insensitive password auth based on the related UNIQUE index field collation ([#6337](https://github.com/pocketbase/pocketbase/discussions/6337)).
|
||||
|
||||
- Enforced `when_required` for the new AWS SDK request and response checksum validations to allow other non-AWS vendors to catch up with new AWS SDK changes (see [#6313](https://github.com/pocketbase/pocketbase/discussions/6313) and [aws/aws-sdk-go-v2#2960](https://github.com/aws/aws-sdk-go-v2/discussions/2960)).
|
||||
_You can set the environment variables `AWS_REQUEST_CHECKSUM_CALCULATION` and `AWS_RESPONSE_CHECKSUM_VALIDATION` to `when_supported` if your S3 vendor supports the [new default integrity protections](https://docs.aws.amazon.com/sdkref/latest/guide/feature-dataintegrity.html)._
|
||||
|
||||
- Soft-deprecated `Record.GetUploadedFiles` in favor of `Record.GetUnsavedFiles` to minimize the ambiguities what the method do ([#6269](https://github.com/pocketbase/pocketbase/discussions/6269)).
|
||||
|
||||
- Replaced archived `github.com/AlecAivazis/survey` dependency with a simpler `osutils.YesNoPrompt(message, fallback)` helper.
|
||||
|
||||
- Upgraded to `golang-jwt/jwt/v5`.
|
||||
|
||||
- Added JSVM `new Timezone(name)` binding for constructing `time.Location` value ([#6219](https://github.com/pocketbase/pocketbase/discussions/6219)).
|
||||
|
||||
- Added `inflector.Camelize(str)` and `inflector.Singularize(str)` helper methods.
|
||||
|
||||
- Use the non-transactional app instance during the realtime records delete access checks to ensure that cascade deleted records with API rules relying on the parent will be resolved.
|
||||
|
||||
- Other minor improvements (_replaced all `bool` exists db scans with `int` for broader drivers compatibility, updated API Preview sample error responses, updated UI dependencies, etc._)
|
||||
|
||||
|
||||
## v0.24.4
|
||||
|
||||
- Fixed fields extraction for view query with nested comments ([#6309](https://github.com/pocketbase/pocketbase/discussions/6309)).
|
||||
|
||||
- Bumped GitHub action min Go version to 1.23.5 as it comes with some [minor security fixes](https://github.com/golang/go/issues?q=milestone%3AGo1.23.5).
|
||||
|
||||
|
||||
## v0.24.3
|
||||
|
||||
- Fixed incorrectly reported unique validator error for fields starting with name of another field ([#6281](https://github.com/pocketbase/pocketbase/pull/6281); thanks @svobol13).
|
||||
|
||||
- Reload the created/edited records data in the RecordsPicker UI.
|
||||
|
||||
- Updated Go dependencies.
|
||||
|
||||
|
||||
## v0.24.2
|
||||
|
||||
- Fixed display fields extraction when there are multiple "Presentable" `relation` fields in a single related collection ([#6229](https://github.com/pocketbase/pocketbase/issues/6229)).
|
||||
|
||||
|
||||
## v0.24.1
|
||||
|
||||
- Added missing time macros in the UI autocomplete.
|
||||
|
||||
- Fixed JSVM types for structs and functions with multiple generic parameters.
|
||||
|
||||
|
||||
## v0.24.0
|
||||
|
||||
- ⚠️ Removed the "dry submit" when executing the collections Create API rule
|
||||
(you can find more details why this change was introduced and how it could affect your app in https://github.com/pocketbase/pocketbase/discussions/6073).
|
||||
For most users it should be non-breaking change, BUT if you have Create API rules that uses self-references or view counters you may have to adjust them manually.
|
||||
With this change the "multi-match" operators are also normalized in case the targeted collection doesn't have any records
|
||||
(_or in other words, `@collection.example.someField != "test"` will result to `true` if `example` collection has no records because it satisfies the condition that all available "example" records mustn't have `someField` equal to "test"_).
|
||||
As a side-effect of all of the above minor changes, the record create API performance has been also improved ~4x times in high concurrent scenarios (500 concurrent clients inserting total of 50k records - [old (58.409064001s)](https://github.com/pocketbase/benchmarks/blob/54140be5fb0102f90034e1370c7f168fbcf0ddf0/results/hetzner_cax41_cgo.md#creating-50000-posts100k-reqs50000-conc500-rulerequestauthid----requestdatapublicisset--true) vs [new (13.580098262s)](https://github.com/pocketbase/benchmarks/blob/7df0466ac9bd62fe0a1056270d20ef82012f0234/results/hetzner_cax41_cgo.md#creating-50000-posts100k-reqs50000-conc500-rulerequestauthid----requestbodypublicisset--true)).
|
||||
|
||||
- ⚠️ Changed the type definition of `store.Store[T any]` to `store.Store[K comparable, T any]` to allow support for custom store key types.
|
||||
For most users it should be non-breaking change, BUT if you are calling `store.New[any](nil)` instances you'll have to specify the store key type, aka. `store.New[string, any](nil)`.
|
||||
|
||||
- Added `@yesterday` and `@tomorrow` datetime filter macros.
|
||||
|
||||
- Added `:lower` filter modifier (e.g. `title:lower = "lorem"`).
|
||||
|
||||
- Added `mailer.Message.InlineAttachments` field for attaching inline files to an email (_aka. `cid` links_).
|
||||
|
||||
- Added cache for the JSVM `arrayOf(m)`, `DynamicModel`, etc. dynamic `reflect` created types.
|
||||
|
||||
- Added auth collection select for the settings "Send test email" popup ([#6166](https://github.com/pocketbase/pocketbase/issues/6166)).
|
||||
|
||||
- Added `record.SetRandomPassword()` to simplify random password generation usually used in the OAuth2 or OTP record creation flows.
|
||||
_The generated ~30 chars random password is assigned directly as bcrypt hash and ignores the `password` field plain value validators like min/max length or regex pattern._
|
||||
|
||||
- Added option to list and trigger the registered app level cron jobs via the Web API and UI.
|
||||
|
||||
- Added extra validators for the collection field `int64` options (e.g. `FileField.MaxSize`) restricting them to the max safe JSON number (2^53-1).
|
||||
|
||||
- Added option to unset/overwrite the default PocketBase superuser installer using `ServeEvent.InstallerFunc`.
|
||||
|
||||
- Added `app.FindCachedCollectionReferences(collection, excludeIds)` to speedup records cascade delete almost twice for projects with many collections.
|
||||
|
||||
- Added `tests.NewTestAppWithConfig(config)` helper if you need more control over the test configurations like `IsDev`, the number of allowed connections, etc.
|
||||
|
||||
- Invalidate all record tokens when the auth record email is changed programmatically or by a superuser ([#5964](https://github.com/pocketbase/pocketbase/issues/5964)).
|
||||
|
||||
- Eagerly interrupt waiting for the email alert send in case it takes longer than 15s.
|
||||
|
||||
- Normalized the hidden fields filter checks and allow targetting hidden fields in the List API rule.
|
||||
|
||||
- Fixed "Unique identify fields" input not refreshing on unique indexes change ([#6184](https://github.com/pocketbase/pocketbase/issues/6184)).
|
||||
|
||||
|
||||
## v0.23.12
|
||||
|
||||
- Added warning logs in case of mismatched `modernc.org/sqlite` and `modernc.org/libc` versions ([#6136](https://github.com/pocketbase/pocketbase/issues/6136#issuecomment-2556336962)).
|
||||
|
||||
- Skipped the default body size limit middleware for the backup upload endpoint ([#6152](https://github.com/pocketbase/pocketbase/issues/6152)).
|
||||
|
||||
|
||||
## v0.23.11
|
||||
|
||||
- Upgraded `golang.org/x/net` to 0.33.0 to fix [CVE-2024-45338](https://www.cve.org/CVERecord?id=CVE-2024-45338).
|
||||
_PocketBase uses the vulnerable functions primarily for the auto html->text mail generation, but most applications shouldn't be affected unless you are manually embedding unrestricted user provided value in your mail templates._
|
||||
|
||||
|
||||
## v0.23.10
|
||||
|
||||
- Renew the superuser file token cache when clicking on the thumb preview or download link ([#6137](https://github.com/pocketbase/pocketbase/discussions/6137)).
|
||||
|
||||
- Upgraded `modernc.org/sqlite` to 1.34.3 to fix "disk io" error on arm64 systems.
|
||||
_If you are extending PocketBase with Go and upgrading with `go get -u` make sure to manually set in your go.mod the `modernc.org/libc` indirect dependency to v1.55.3, aka. the exact same version the driver is using._
|
||||
|
||||
|
||||
## v0.23.9
|
||||
|
||||
- Replaced `strconv.Itoa` with `strconv.FormatInt` to avoid the int64->int conversion overflow on 32-bit platforms ([#6132](https://github.com/pocketbase/pocketbase/discussions/6132)).
|
||||
|
||||
|
||||
## v0.23.8
|
||||
|
||||
- Fixed Model->Record and Model->Collection hook events sync for nested and/or inner-hook transactions ([#6122](https://github.com/pocketbase/pocketbase/discussions/6122)).
|
||||
|
||||
- Other minor improvements (updated Go and npm deps, added extra escaping for the default mail record params in case the emails are stored as html files, fixed code comment typos, etc.).
|
||||
|
||||
|
||||
## v0.23.7
|
||||
|
||||
- Fixed JSVM exception -> Go error unwrapping when throwing errors from non-request hooks ([#6102](https://github.com/pocketbase/pocketbase/discussions/6102)).
|
||||
|
||||
|
||||
## v0.23.6
|
||||
|
||||
- Fixed `$filesystem.fileFromURL` documentation and generated type ([#6058](https://github.com/pocketbase/pocketbase/issues/6058)).
|
||||
|
||||
- Fixed `X-Forwarded-For` header typo in the suggested UI "Common trusted proxy" headers ([#6063](https://github.com/pocketbase/pocketbase/pull/6063)).
|
||||
|
||||
- Updated the `text` field max length validator error message to make it more clear ([#6066](https://github.com/pocketbase/pocketbase/issues/6066)).
|
||||
|
||||
- Other minor fixes (updated Go deps, skipped unnecessary validator check when the default primary key pattern is used, updated JSVM types, etc.).
|
||||
|
||||
|
||||
## v0.23.5
|
||||
|
||||
- Fixed UI logs search not properly accounting for the "Include requests by superusers" toggle when multiple search expressions are used.
|
||||
|
||||
- Fixed `text` field max validation error message ([#6053](https://github.com/pocketbase/pocketbase/issues/6053)).
|
||||
|
||||
- Other minor fixes (comment typos, JSVM types update).
|
||||
|
||||
- Updated Go deps and the min Go release GitHub action version to 1.23.4.
|
||||
|
||||
|
||||
## v0.23.4
|
||||
|
||||
- Fixed `autodate` fields not refreshing when calling `Save` multiple times on the same `Record` instance ([#6000](https://github.com/pocketbase/pocketbase/issues/6000)).
|
||||
|
||||
- Added more descriptive test OTP id and failure log message ([#5982](https://github.com/pocketbase/pocketbase/discussions/5982)).
|
||||
|
||||
- Moved the default UI CSP from meta tag to response header ([#5995](https://github.com/pocketbase/pocketbase/discussions/5995)).
|
||||
|
||||
- Updated Go and npm dependencies.
|
||||
|
||||
|
||||
## v0.23.3
|
||||
|
||||
- Fixed Gzip middleware not applying when serving static files.
|
||||
|
||||
- Fixed `Record.Fresh()`/`Record.Clone()` methods not properly cloning `autodate` fields ([#5973](https://github.com/pocketbase/pocketbase/discussions/5973)).
|
||||
|
||||
|
||||
## v0.23.2
|
||||
|
||||
- Fixed `RecordQuery()` custom struct scanning ([#5958](https://github.com/pocketbase/pocketbase/discussions/5958)).
|
||||
|
||||
- Fixed `--dev` log query print formatting.
|
||||
|
||||
- Added support for passing more than one id in the `Hook.Unbind` method for consistency with the router.
|
||||
|
||||
- Added collection rules change list in the confirmation popup
|
||||
(_to avoid getting anoying during development, the rules confirmation currently is enabled only when using https_).
|
||||
|
||||
|
||||
## v0.23.1
|
||||
|
||||
- Added `RequestEvent.Blob(status, contentType, bytes)` response write helper ([#5940](https://github.com/pocketbase/pocketbase/discussions/5940)).
|
||||
|
||||
- Added more descriptive error messages.
|
||||
|
||||
|
||||
## v0.23.0
|
||||
|
||||
> [!NOTE]
|
||||
> You don't have to upgrade to PocketBase v0.23.0 if you are not planning further developing
|
||||
> your existing app and/or are satisfied with the v0.22.x features set. There are no identified critical issues
|
||||
> with PocketBase v0.22.x yet and in the case of critical bugs and security vulnerabilities, the fixes
|
||||
> will be backported for at least until Q1 of 2025 (_if not longer_).
|
||||
>
|
||||
> **If you don't plan upgrading make sure to pin the SDKs version to their latest PocketBase v0.22.x compatible:**
|
||||
> - JS SDK: `<0.22.0`
|
||||
> - Dart SDK: `<0.19.0`
|
||||
|
||||
> [!CAUTION]
|
||||
> This release introduces many Go/JSVM and Web APIs breaking changes!
|
||||
>
|
||||
> Existing `pb_data` will be automatically upgraded with the start of the new executable,
|
||||
> but custom Go or JSVM (`pb_hooks`, `pb_migrations`) and JS/Dart SDK code will have to be migrated manually.
|
||||
> Please refer to the below upgrade guides:
|
||||
> - Go: https://pocketbase.io/v023upgrade/go/.
|
||||
> - JSVM: https://pocketbase.io/v023upgrade/jsvm/.
|
||||
>
|
||||
> If you had already switched to some of the earlier `<v0.23.0-rc14` versions and have generated a full collections snapshot migration (aka. `./pocketbase migrate collections`), then you may have to regenerate the migration file to ensure that it includes the latest changes.
|
||||
|
||||
PocketBase v0.23.0 is a major refactor of the internals with the overall goal of making PocketBase an easier to use Go framework.
|
||||
There are a lot of changes but to highlight some of the most notable ones:
|
||||
|
||||
- New and more [detailed documentation](https://pocketbase.io/docs/).
|
||||
_The old documentation could be accessed at [pocketbase.io/old](https://pocketbase.io/old/)._
|
||||
- Replaced `echo` with a new router built on top of the Go 1.22 `net/http` mux enhancements.
|
||||
- Merged `daos` packages in `core.App` to simplify the DB operations (_the `models` package structs are also migrated in `core`_).
|
||||
- Option to specify custom `DBConnect` function as part of the app configuration to allow different `database/sql` SQLite drivers (_turso/libsql, sqlcipher, etc._) and custom builds.
|
||||
_Note that we no longer loads the `mattn/go-sqlite3` driver by default when building with `CGO_ENABLED=1` to avoid `multiple definition` linker errors in case different CGO SQLite drivers or builds are used. You can find an example how to enable it back if you want to in the [new documentation](https://pocketbase.io/docs/go-overview/#github-commattngo-sqlite3)._
|
||||
- New hooks allowing better control over the execution chain and error handling (_including wrapping an entire hook chain in a single DB transaction_).
|
||||
- Various `Record` model improvements (_support for get/set modifiers, simplfied file upload by treating the file(s) as regular field value like `record.Set("document", file)`, etc._).
|
||||
- Dedicated fields structs with safer defaults to make it easier creating/updating collections programmatically.
|
||||
- Option to mark field as "Hidden", disallowing regular users to read or modify it (_there is also a dedicated Record hook to hide/unhide Record fields programmatically from a single place_).
|
||||
- Option to customize the default system collection fields (`id`, `email`, `password`, etc.).
|
||||
- Admins are now system `_superusers` auth records.
|
||||
- Builtin rate limiter (_supports tags, wildcards and exact routes matching_).
|
||||
- Batch/transactional Web API endpoint.
|
||||
- Impersonate Web API endpoint (_it could be also used for generating fixed/non-refreshable superuser tokens, aka. "API keys"_).
|
||||
- Support for custom user request activity log attributes.
|
||||
- One-Time Password (OTP) auth method (_via email code_).
|
||||
- Multi-Factor Authentication (MFA) support (_currently requires any 2 different auth methods to be used_).
|
||||
- Support for Record "proxy/projection" in preparation for the planned autogeneration of typed Go record models.
|
||||
- Linear OAuth2 provider ([#5909](https://github.com/pocketbase/pocketbase/pull/5909); thanks @chnfyi).
|
||||
- WakaTime OAuth2 provider ([#5829](https://github.com/pocketbase/pocketbase/pull/5829); thanks @tigawanna).
|
||||
- Notion OAuth2 provider ([#4999](https://github.com/pocketbase/pocketbase/pull/4999); thanks @s-li1).
|
||||
- monday.com OAuth2 provider ([#5346](https://github.com/pocketbase/pocketbase/pull/5346); thanks @Jaytpa01).
|
||||
- New Instagram provider compatible with the new Instagram Login APIs ([#5588](https://github.com/pocketbase/pocketbase/pull/5588); thanks @pnmcosta).
|
||||
_The provider key is `instagram2` to prevent conflicts with existing linked users._
|
||||
- Option to retrieve the OIDC OAuth2 user info from the `id_token` payload for the cases when the provider doesn't have a dedicated user info endpoint.
|
||||
- Various minor UI improvements (_recursive `Presentable` view, slightly different collection options organization, zoom/pan for the logs chart, etc._)
|
||||
- and many more...
|
||||
|
||||
#### Go/JSVM APIs changes
|
||||
|
||||
> - Go: https://pocketbase.io/v023upgrade/go/.
|
||||
> - JSVM: https://pocketbase.io/v023upgrade/jsvm/.
|
||||
|
||||
#### SDKs changes
|
||||
|
||||
- [JS SDK v0.22.0](https://github.com/pocketbase/js-sdk/blob/master/CHANGELOG.md)
|
||||
- [Dart SDK v0.19.0](https://github.com/pocketbase/dart-sdk/blob/master/CHANGELOG.md)
|
||||
|
||||
#### Web APIs changes
|
||||
|
||||
- New `POST /api/batch` endpoint.
|
||||
|
||||
- New `GET /api/collections/meta/scaffolds` endpoint.
|
||||
|
||||
- New `DELETE /api/collections/{collection}/truncate` endpoint.
|
||||
|
||||
- New `POST /api/collections/{collection}/request-otp` endpoint.
|
||||
|
||||
- New `POST /api/collections/{collection}/auth-with-otp` endpoint.
|
||||
|
||||
- New `POST /api/collections/{collection}/impersonate/{id}` endpoint.
|
||||
|
||||
- ⚠️ If you are constructing requests to `/api/*` routes manually remove the trailing slash (_there is no longer trailing slash removal middleware registered by default_).
|
||||
|
||||
- ⚠️ Removed `/api/admins/*` endpoints because admins are converted to `_superusers` auth collection records.
|
||||
|
||||
- ⚠️ Previously when uploading new files to a multiple `file` field, new files were automatically appended to the existing field values.
|
||||
This behaviour has changed with v0.23+ and for consistency with the other multi-valued fields when uploading new files they will replace the old ones. If you want to prepend or append new files to an existing multiple `file` field value you can use the `+` prefix or suffix:
|
||||
```js
|
||||
"documents": [file1, file2] // => [file1_name, file2_name]
|
||||
"+documents": [file1, file2] // => [file1_name, file2_name, old1_name, old2_name]
|
||||
"documents+": [file1, file2] // => [old1_name, old2_name, file1_name, file2_name]
|
||||
```
|
||||
|
||||
- ⚠️ Removed `GET /records/{id}/external-auths` and `DELETE /records/{id}/external-auths/{provider}` endpoints because this is now handled by sending list and delete requests to the `_externalAuths` collection.
|
||||
|
||||
- ⚠️ Changes to the app settings model fields and response (+new options such as `trustedProxy`, `rateLimits`, `batch`, etc.). The app settings Web APIs are mostly used by the Dashboard UI and rarely by the end users, but if you want to check all settings changes please refer to the [Settings Go struct](https://github.com/pocketbase/pocketbase/blob/develop/core/settings_model.go#L121).
|
||||
|
||||
- ⚠️ New flatten Collection model and fields structure. The Collection model Web APIs are mostly used by the Dashboard UI and rarely by the end users, but if you want to check all changes please refer to the [Collection Go struct](https://github.com/pocketbase/pocketbase/blob/develop/core/collection_model.go#L308).
|
||||
|
||||
- ⚠️ The top level error response `code` key was renamed to `status` for consistency with the Go APIs.
|
||||
The error field key remains `code`:
|
||||
```js
|
||||
{
|
||||
"status": 400, // <-- old: "code"
|
||||
"message": "Failed to create record.",
|
||||
"data": {
|
||||
"title": {
|
||||
"code": "validation_required",
|
||||
"message": "Missing required value."
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
- ⚠️ New fields in the `GET /api/collections/{collection}/auth-methods` response.
|
||||
_The old `authProviders`, `usernamePassword`, `emailPassword` fields are still returned in the response but are considered deprecated and will be removed in the future._
|
||||
```js
|
||||
{
|
||||
"mfa": {
|
||||
"duration": 100,
|
||||
"enabled": true
|
||||
},
|
||||
"otp": {
|
||||
"duration": 0,
|
||||
"enabled": false
|
||||
},
|
||||
"password": {
|
||||
"enabled": true,
|
||||
"identityFields": ["email", "username"]
|
||||
},
|
||||
"oauth2": {
|
||||
"enabled": true,
|
||||
"providers": [{"name": "gitlab", ...}, {"name": "google", ...}]
|
||||
},
|
||||
// old fields...
|
||||
}
|
||||
```
|
||||
|
||||
- ⚠️ Soft-deprecated the OAuth2 auth success `meta.avatarUrl` field in favour of `meta.avatarURL`.
|
1205
CHANGELOG_16_22.md
Normal file
1205
CHANGELOG_16_22.md
Normal file
File diff suppressed because it is too large
Load diff
1384
CHANGELOG_8_15.md
Normal file
1384
CHANGELOG_8_15.md
Normal file
File diff suppressed because it is too large
Load diff
82
CONTRIBUTING.md
Normal file
82
CONTRIBUTING.md
Normal file
|
@ -0,0 +1,82 @@
|
|||
# Contributing to PocketBase
|
||||
|
||||
Thanks for taking the time to improve PocketBase!
|
||||
|
||||
This document describes how to prepare a PR for a change in the main repository.
|
||||
|
||||
- [Prerequisites](#prerequisites)
|
||||
- [Making changes in the Go code](#making-changes-in-the-go-code)
|
||||
- [Making changes in the Admin UI](#making-changes-in-the-admin-ui)
|
||||
|
||||
## Prerequisites
|
||||
|
||||
- Go 1.23+ (for making changes in the Go code)
|
||||
- Node 18+ (for making changes in the Admin UI)
|
||||
|
||||
If you haven't already, you can fork the main repository and clone your fork so that you can work locally:
|
||||
|
||||
```
|
||||
git clone https://github.com/your_username/pocketbase.git
|
||||
```
|
||||
|
||||
> [!IMPORTANT]
|
||||
> It is recommended to create a new branch from master for each of your bugfixes and features.
|
||||
> This is required if you are planning to submit multiple PRs in order to keep the changes separate for review until they eventually get merged.
|
||||
|
||||
## Making changes in the Go code
|
||||
|
||||
PocketBase is distributed as a Go package, which means that in order to run the project you'll have to create a Go `main` program that imports the package.
|
||||
|
||||
The repository already includes such program, located in `examples/base`, that is also used for the prebuilt executables.
|
||||
|
||||
So, let's assume that you already done some changes in the PocketBase Go code and you want now to run them:
|
||||
|
||||
1. Navigate to `examples/base`
|
||||
2. Run `go run main.go serve`
|
||||
|
||||
This will start a web server on `http://localhost:8090` with the embedded prebuilt Admin UI from `ui/dist`. And that's it!
|
||||
|
||||
**Before making a PR to the main repository, it is a good idea to:**
|
||||
|
||||
- Add unit/integration tests for your changes (we are using the standard `testing` go package).
|
||||
To run the tests, you could execute (while in the root project directory):
|
||||
|
||||
```sh
|
||||
go test ./...
|
||||
|
||||
# or using the Makefile
|
||||
make test
|
||||
```
|
||||
|
||||
- Run the linter - **golangci** ([see how to install](https://golangci-lint.run/usage/install/#local-installation)):
|
||||
|
||||
```sh
|
||||
golangci-lint run -c ./golangci.yml ./...
|
||||
|
||||
# or using the Makefile
|
||||
make lint
|
||||
```
|
||||
|
||||
## Making changes in the Admin UI
|
||||
|
||||
PocketBase Admin UI is a single-page application (SPA) built with Svelte and Vite.
|
||||
|
||||
To start the Admin UI:
|
||||
|
||||
1. Navigate to the `ui` project directory
|
||||
2. Run `npm install` to install the node dependencies
|
||||
3. Start vite's dev server
|
||||
```sh
|
||||
npm run dev
|
||||
```
|
||||
|
||||
You could open the browser and access the running Admin UI at `http://localhost:3000`.
|
||||
|
||||
Since the Admin UI is just a client-side application, you need to have the PocketBase backend server also running in the background (either manually running the `examples/base/main.go` or download a prebuilt executable).
|
||||
|
||||
> [!NOTE]
|
||||
> By default, the Admin UI is expecting the backend server to be started at `http://localhost:8090`, but you could change that by creating a new `ui/.env.development.local` file with `PB_BACKEND_URL = YOUR_ADDRESS` variable inside it.
|
||||
|
||||
Every change you make in the Admin UI should be automatically reflected in the browser at `http://localhost:3000` without reloading the page.
|
||||
|
||||
Once you are done with your changes, you have to build the Admin UI with `npm run build`, so that it can be embedded in the go package. And that's it - you can make your PR to the main PocketBase repository.
|
17
LICENSE.md
Normal file
17
LICENSE.md
Normal file
|
@ -0,0 +1,17 @@
|
|||
The MIT License (MIT)
|
||||
Copyright (c) 2022 - present, Gani Georgiev
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy of this software
|
||||
and associated documentation files (the "Software"), to deal in the Software without restriction,
|
||||
including without limitation the rights to use, copy, modify, merge, publish, distribute,
|
||||
sublicense, and/or sell copies of the Software, and to permit persons to whom the Software
|
||||
is furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all copies or
|
||||
substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING
|
||||
BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
|
||||
NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
|
||||
DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
12
Makefile
Normal file
12
Makefile
Normal file
|
@ -0,0 +1,12 @@
|
|||
lint:
|
||||
golangci-lint run -c ./golangci.yml ./...
|
||||
|
||||
test:
|
||||
go test ./... -v --cover
|
||||
|
||||
jstypes:
|
||||
go run ./plugins/jsvm/internal/types/types.go
|
||||
|
||||
test-report:
|
||||
go test ./... -v --cover -coverprofile=coverage.out
|
||||
go tool cover -html=coverage.out
|
153
README.md
Normal file
153
README.md
Normal file
|
@ -0,0 +1,153 @@
|
|||
<p align="center">
|
||||
<a href="https://pocketbase.io" target="_blank" rel="noopener">
|
||||
<img src="https://i.imgur.com/5qimnm5.png" alt="PocketBase - open source backend in 1 file" />
|
||||
</a>
|
||||
</p>
|
||||
|
||||
<p align="center">
|
||||
<a href="https://github.com/pocketbase/pocketbase/actions/workflows/release.yaml" target="_blank" rel="noopener"><img src="https://github.com/pocketbase/pocketbase/actions/workflows/release.yaml/badge.svg" alt="build" /></a>
|
||||
<a href="https://github.com/pocketbase/pocketbase/releases" target="_blank" rel="noopener"><img src="https://img.shields.io/github/release/pocketbase/pocketbase.svg" alt="Latest releases" /></a>
|
||||
<a href="https://pkg.go.dev/github.com/pocketbase/pocketbase" target="_blank" rel="noopener"><img src="https://godoc.org/github.com/pocketbase/pocketbase?status.svg" alt="Go package documentation" /></a>
|
||||
</p>
|
||||
|
||||
[PocketBase](https://pocketbase.io) is an open source Go backend that includes:
|
||||
|
||||
- embedded database (_SQLite_) with **realtime subscriptions**
|
||||
- built-in **files and users management**
|
||||
- convenient **Admin dashboard UI**
|
||||
- and simple **REST-ish API**
|
||||
|
||||
**For documentation and examples, please visit https://pocketbase.io/docs.**
|
||||
|
||||
> [!WARNING]
|
||||
> Please keep in mind that PocketBase is still under active development
|
||||
> and therefore full backward compatibility is not guaranteed before reaching v1.0.0.
|
||||
|
||||
## API SDK clients
|
||||
|
||||
The easiest way to interact with the PocketBase Web APIs is to use one of the official SDK clients:
|
||||
|
||||
- **JavaScript - [pocketbase/js-sdk](https://github.com/pocketbase/js-sdk)** (_Browser, Node.js, React Native_)
|
||||
- **Dart - [pocketbase/dart-sdk](https://github.com/pocketbase/dart-sdk)** (_Web, Mobile, Desktop, CLI_)
|
||||
|
||||
You could also check the recommendations in https://pocketbase.io/docs/how-to-use/.
|
||||
|
||||
|
||||
## Overview
|
||||
|
||||
### Use as standalone app
|
||||
|
||||
You could download the prebuilt executable for your platform from the [Releases page](https://github.com/pocketbase/pocketbase/releases).
|
||||
Once downloaded, extract the archive and run `./pocketbase serve` in the extracted directory.
|
||||
|
||||
The prebuilt executables are based on the [`examples/base/main.go` file](https://github.com/pocketbase/pocketbase/blob/master/examples/base/main.go) and comes with the JS VM plugin enabled by default which allows to extend PocketBase with JavaScript (_for more details please refer to [Extend with JavaScript](https://pocketbase.io/docs/js-overview/)_).
|
||||
|
||||
### Use as a Go framework/toolkit
|
||||
|
||||
PocketBase is distributed as a regular Go library package which allows you to build
|
||||
your own custom app specific business logic and still have a single portable executable at the end.
|
||||
|
||||
Here is a minimal example:
|
||||
|
||||
0. [Install Go 1.23+](https://go.dev/doc/install) (_if you haven't already_)
|
||||
|
||||
1. Create a new project directory with the following `main.go` file inside it:
|
||||
```go
|
||||
package main
|
||||
|
||||
import (
|
||||
"log"
|
||||
|
||||
"github.com/pocketbase/pocketbase"
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
)
|
||||
|
||||
func main() {
|
||||
app := pocketbase.New()
|
||||
|
||||
app.OnServe().BindFunc(func(se *core.ServeEvent) error {
|
||||
// registers new "GET /hello" route
|
||||
se.Router.GET("/hello", func(re *core.RequestEvent) error {
|
||||
return re.String(200, "Hello world!")
|
||||
})
|
||||
|
||||
return se.Next()
|
||||
})
|
||||
|
||||
if err := app.Start(); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
2. To init the dependencies, run `go mod init myapp && go mod tidy`.
|
||||
|
||||
3. To start the application, run `go run main.go serve`.
|
||||
|
||||
4. To build a statically linked executable, you can run `CGO_ENABLED=0 go build` and then start the created executable with `./myapp serve`.
|
||||
|
||||
_For more details please refer to [Extend with Go](https://pocketbase.io/docs/go-overview/)._
|
||||
|
||||
### Building and running the repo main.go example
|
||||
|
||||
To build the minimal standalone executable, like the prebuilt ones in the releases page, you can simply run `go build` inside the `examples/base` directory:
|
||||
|
||||
0. [Install Go 1.23+](https://go.dev/doc/install) (_if you haven't already_)
|
||||
1. Clone/download the repo
|
||||
2. Navigate to `examples/base`
|
||||
3. Run `GOOS=linux GOARCH=amd64 CGO_ENABLED=0 go build`
|
||||
(_https://go.dev/doc/install/source#environment_)
|
||||
4. Start the created executable by running `./base serve`.
|
||||
|
||||
Note that the supported build targets by the pure Go SQLite driver at the moment are:
|
||||
|
||||
```
|
||||
darwin amd64
|
||||
darwin arm64
|
||||
freebsd amd64
|
||||
freebsd arm64
|
||||
linux 386
|
||||
linux amd64
|
||||
linux arm
|
||||
linux arm64
|
||||
linux ppc64le
|
||||
linux riscv64
|
||||
linux s390x
|
||||
windows amd64
|
||||
windows arm64
|
||||
```
|
||||
|
||||
### Testing
|
||||
|
||||
PocketBase comes with mixed bag of unit and integration tests.
|
||||
To run them, use the standard `go test` command:
|
||||
|
||||
```sh
|
||||
go test ./...
|
||||
```
|
||||
|
||||
Check also the [Testing guide](http://pocketbase.io/docs/testing) to learn how to write your own custom application tests.
|
||||
|
||||
## Security
|
||||
|
||||
If you discover a security vulnerability within PocketBase, please send an e-mail to **support at pocketbase.io**.
|
||||
|
||||
All reports will be promptly addressed and you'll be credited in the fix release notes.
|
||||
|
||||
## Contributing
|
||||
|
||||
PocketBase is free and open source project licensed under the [MIT License](LICENSE.md).
|
||||
You are free to do whatever you want with it, even offering it as a paid service.
|
||||
|
||||
You could help continuing its development by:
|
||||
|
||||
- [Contribute to the source code](CONTRIBUTING.md)
|
||||
- [Suggest new features and report issues](https://github.com/pocketbase/pocketbase/issues)
|
||||
|
||||
PRs for new OAuth2 providers, bug fixes, code optimizations and documentation improvements are more than welcome.
|
||||
|
||||
But please refrain creating PRs for _new features_ without previously discussing the implementation details.
|
||||
PocketBase has a [roadmap](https://github.com/orgs/pocketbase/projects/2) and I try to work on issues in specific order and such PRs often come in out of nowhere and skew all initial planning with tedious back-and-forth communication.
|
||||
|
||||
Don't get upset if I close your PR, even if it is well executed and tested. This doesn't mean that it will never be merged.
|
||||
Later we can always refer to it and/or take pieces of your implementation when the time comes to work on the issue (don't worry you'll be credited in the release notes).
|
47
apis/api_error_aliases.go
Normal file
47
apis/api_error_aliases.go
Normal file
|
@ -0,0 +1,47 @@
|
|||
package apis
|
||||
|
||||
import "github.com/pocketbase/pocketbase/tools/router"
|
||||
|
||||
// ApiError aliases to minimize the breaking changes with earlier versions
|
||||
// and for consistency with the JSVM binds.
|
||||
// -------------------------------------------------------------------
|
||||
|
||||
// ToApiError wraps err into ApiError instance (if not already).
|
||||
func ToApiError(err error) *router.ApiError {
|
||||
return router.ToApiError(err)
|
||||
}
|
||||
|
||||
// NewApiError is an alias for [router.NewApiError].
|
||||
func NewApiError(status int, message string, errData any) *router.ApiError {
|
||||
return router.NewApiError(status, message, errData)
|
||||
}
|
||||
|
||||
// NewBadRequestError is an alias for [router.NewBadRequestError].
|
||||
func NewBadRequestError(message string, errData any) *router.ApiError {
|
||||
return router.NewBadRequestError(message, errData)
|
||||
}
|
||||
|
||||
// NewNotFoundError is an alias for [router.NewNotFoundError].
|
||||
func NewNotFoundError(message string, errData any) *router.ApiError {
|
||||
return router.NewNotFoundError(message, errData)
|
||||
}
|
||||
|
||||
// NewForbiddenError is an alias for [router.NewForbiddenError].
|
||||
func NewForbiddenError(message string, errData any) *router.ApiError {
|
||||
return router.NewForbiddenError(message, errData)
|
||||
}
|
||||
|
||||
// NewUnauthorizedError is an alias for [router.NewUnauthorizedError].
|
||||
func NewUnauthorizedError(message string, errData any) *router.ApiError {
|
||||
return router.NewUnauthorizedError(message, errData)
|
||||
}
|
||||
|
||||
// NewTooManyRequestsError is an alias for [router.NewTooManyRequestsError].
|
||||
func NewTooManyRequestsError(message string, errData any) *router.ApiError {
|
||||
return router.NewTooManyRequestsError(message, errData)
|
||||
}
|
||||
|
||||
// NewInternalServerError is an alias for [router.NewInternalServerError].
|
||||
func NewInternalServerError(message string, errData any) *router.ApiError {
|
||||
return router.NewInternalServerError(message, errData)
|
||||
}
|
155
apis/backup.go
Normal file
155
apis/backup.go
Normal file
|
@ -0,0 +1,155 @@
|
|||
package apis
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tools/router"
|
||||
"github.com/pocketbase/pocketbase/tools/routine"
|
||||
"github.com/pocketbase/pocketbase/tools/types"
|
||||
"github.com/spf13/cast"
|
||||
)
|
||||
|
||||
// bindBackupApi registers the file api endpoints and the corresponding handlers.
|
||||
func bindBackupApi(app core.App, rg *router.RouterGroup[*core.RequestEvent]) {
|
||||
sub := rg.Group("/backups")
|
||||
sub.GET("", backupsList).Bind(RequireSuperuserAuth())
|
||||
sub.POST("", backupCreate).Bind(RequireSuperuserAuth())
|
||||
sub.POST("/upload", backupUpload).Bind(BodyLimit(0), RequireSuperuserAuth())
|
||||
sub.GET("/{key}", backupDownload) // relies on superuser file token
|
||||
sub.DELETE("/{key}", backupDelete).Bind(RequireSuperuserAuth())
|
||||
sub.POST("/{key}/restore", backupRestore).Bind(RequireSuperuserAuth())
|
||||
}
|
||||
|
||||
type backupFileInfo struct {
|
||||
Modified types.DateTime `json:"modified"`
|
||||
Key string `json:"key"`
|
||||
Size int64 `json:"size"`
|
||||
}
|
||||
|
||||
func backupsList(e *core.RequestEvent) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
fsys, err := e.App.NewBackupsFilesystem()
|
||||
if err != nil {
|
||||
return e.BadRequestError("Failed to load backups filesystem.", err)
|
||||
}
|
||||
defer fsys.Close()
|
||||
|
||||
fsys.SetContext(ctx)
|
||||
|
||||
backups, err := fsys.List("")
|
||||
if err != nil {
|
||||
return e.BadRequestError("Failed to retrieve backup items. Raw error: \n"+err.Error(), nil)
|
||||
}
|
||||
|
||||
result := make([]backupFileInfo, len(backups))
|
||||
|
||||
for i, obj := range backups {
|
||||
modified, _ := types.ParseDateTime(obj.ModTime)
|
||||
|
||||
result[i] = backupFileInfo{
|
||||
Key: obj.Key,
|
||||
Size: obj.Size,
|
||||
Modified: modified,
|
||||
}
|
||||
}
|
||||
|
||||
return e.JSON(http.StatusOK, result)
|
||||
}
|
||||
|
||||
func backupDownload(e *core.RequestEvent) error {
|
||||
fileToken := e.Request.URL.Query().Get("token")
|
||||
|
||||
authRecord, err := e.App.FindAuthRecordByToken(fileToken, core.TokenTypeFile)
|
||||
if err != nil || !authRecord.IsSuperuser() {
|
||||
return e.ForbiddenError("Insufficient permissions to access the resource.", err)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
fsys, err := e.App.NewBackupsFilesystem()
|
||||
if err != nil {
|
||||
return e.InternalServerError("Failed to load backups filesystem.", err)
|
||||
}
|
||||
defer fsys.Close()
|
||||
|
||||
fsys.SetContext(ctx)
|
||||
|
||||
key := e.Request.PathValue("key")
|
||||
|
||||
return fsys.Serve(
|
||||
e.Response,
|
||||
e.Request,
|
||||
key,
|
||||
filepath.Base(key), // without the path prefix (if any)
|
||||
)
|
||||
}
|
||||
|
||||
func backupDelete(e *core.RequestEvent) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
fsys, err := e.App.NewBackupsFilesystem()
|
||||
if err != nil {
|
||||
return e.InternalServerError("Failed to load backups filesystem.", err)
|
||||
}
|
||||
defer fsys.Close()
|
||||
|
||||
fsys.SetContext(ctx)
|
||||
|
||||
key := e.Request.PathValue("key")
|
||||
|
||||
if key != "" && cast.ToString(e.App.Store().Get(core.StoreKeyActiveBackup)) == key {
|
||||
return e.BadRequestError("The backup is currently being used and cannot be deleted.", nil)
|
||||
}
|
||||
|
||||
if err := fsys.Delete(key); err != nil {
|
||||
return e.BadRequestError("Invalid or already deleted backup file. Raw error: \n"+err.Error(), nil)
|
||||
}
|
||||
|
||||
return e.NoContent(http.StatusNoContent)
|
||||
}
|
||||
|
||||
func backupRestore(e *core.RequestEvent) error {
|
||||
if e.App.Store().Has(core.StoreKeyActiveBackup) {
|
||||
return e.BadRequestError("Try again later - another backup/restore process has already been started.", nil)
|
||||
}
|
||||
|
||||
key := e.Request.PathValue("key")
|
||||
|
||||
existsCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
fsys, err := e.App.NewBackupsFilesystem()
|
||||
if err != nil {
|
||||
return e.InternalServerError("Failed to load backups filesystem.", err)
|
||||
}
|
||||
defer fsys.Close()
|
||||
|
||||
fsys.SetContext(existsCtx)
|
||||
|
||||
if exists, err := fsys.Exists(key); !exists {
|
||||
return e.BadRequestError("Missing or invalid backup file.", err)
|
||||
}
|
||||
|
||||
routine.FireAndForget(func() {
|
||||
// give some optimistic time to write the response before restarting the app
|
||||
time.Sleep(1 * time.Second)
|
||||
|
||||
// wait max 10 minutes to fetch the backup
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
if err := e.App.RestoreBackup(ctx, key); err != nil {
|
||||
e.App.Logger().Error("Failed to restore backup", "key", key, "error", err.Error())
|
||||
}
|
||||
})
|
||||
|
||||
return e.NoContent(http.StatusNoContent)
|
||||
}
|
78
apis/backup_create.go
Normal file
78
apis/backup_create.go
Normal file
|
@ -0,0 +1,78 @@
|
|||
package apis
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"regexp"
|
||||
|
||||
validation "github.com/go-ozzo/ozzo-validation/v4"
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
)
|
||||
|
||||
func backupCreate(e *core.RequestEvent) error {
|
||||
if e.App.Store().Has(core.StoreKeyActiveBackup) {
|
||||
return e.BadRequestError("Try again later - another backup/restore process has already been started", nil)
|
||||
}
|
||||
|
||||
form := new(backupCreateForm)
|
||||
form.app = e.App
|
||||
|
||||
err := e.BindBody(form)
|
||||
if err != nil {
|
||||
return e.BadRequestError("An error occurred while loading the submitted data.", err)
|
||||
}
|
||||
|
||||
err = form.validate()
|
||||
if err != nil {
|
||||
return e.BadRequestError("An error occurred while validating the submitted data.", err)
|
||||
}
|
||||
|
||||
err = e.App.CreateBackup(context.Background(), form.Name)
|
||||
if err != nil {
|
||||
return e.BadRequestError("Failed to create backup.", err)
|
||||
}
|
||||
|
||||
// we don't retrieve the generated backup file because it may not be
|
||||
// available yet due to the eventually consistent nature of some S3 providers
|
||||
return e.NoContent(http.StatusNoContent)
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
|
||||
var backupNameRegex = regexp.MustCompile(`^[a-z0-9_-]+\.zip$`)
|
||||
|
||||
type backupCreateForm struct {
|
||||
app core.App
|
||||
|
||||
Name string `form:"name" json:"name"`
|
||||
}
|
||||
|
||||
func (form *backupCreateForm) validate() error {
|
||||
return validation.ValidateStruct(form,
|
||||
validation.Field(
|
||||
&form.Name,
|
||||
validation.Length(1, 150),
|
||||
validation.Match(backupNameRegex),
|
||||
validation.By(form.checkUniqueName),
|
||||
),
|
||||
)
|
||||
}
|
||||
|
||||
func (form *backupCreateForm) checkUniqueName(value any) error {
|
||||
v, _ := value.(string)
|
||||
if v == "" {
|
||||
return nil // nothing to check
|
||||
}
|
||||
|
||||
fsys, err := form.app.NewBackupsFilesystem()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer fsys.Close()
|
||||
|
||||
if exists, err := fsys.Exists(v); err != nil || exists {
|
||||
return validation.NewError("validation_backup_name_exists", "The backup file name is invalid or already exists.")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
823
apis/backup_test.go
Normal file
823
apis/backup_test.go
Normal file
|
@ -0,0 +1,823 @@
|
|||
package apis_test
|
||||
|
||||
import (
|
||||
"archive/zip"
|
||||
"bytes"
|
||||
"context"
|
||||
"io"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/pocketbase/pocketbase/apis"
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tests"
|
||||
"github.com/pocketbase/pocketbase/tools/filesystem/blob"
|
||||
)
|
||||
|
||||
func TestBackupsList(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "unauthorized",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/backups",
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := createTestBackups(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 401,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "authorized as regular user",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/backups",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := createTestBackups(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "authorized as superuser (empty list)",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/backups",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{`[]`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "authorized as superuser",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/backups",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := createTestBackups(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
`"test1.zip"`,
|
||||
`"test2.zip"`,
|
||||
`"test3.zip"`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBackupsCreate(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "unauthorized",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/backups",
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
ensureNoBackups(t, app)
|
||||
},
|
||||
ExpectedStatus: 401,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "authorized as regular user",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/backups",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
ensureNoBackups(t, app)
|
||||
},
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "authorized as superuser (pending backup)",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/backups",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
app.Store().Set(core.StoreKeyActiveBackup, "")
|
||||
},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
ensureNoBackups(t, app)
|
||||
},
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "authorized as superuser (autogenerated name)",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/backups",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
|
||||
},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
files, err := getBackupFiles(app)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if total := len(files); total != 1 {
|
||||
t.Fatalf("Expected 1 backup file, got %d", total)
|
||||
}
|
||||
|
||||
expected := "pb_backup_"
|
||||
if !strings.HasPrefix(files[0].Key, expected) {
|
||||
t.Fatalf("Expected backup file with prefix %q, got %q", expected, files[0].Key)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 204,
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnBackupCreate": 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "authorized as superuser (invalid name)",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/backups",
|
||||
Body: strings.NewReader(`{"name":"!test.zip"}`),
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
|
||||
},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
ensureNoBackups(t, app)
|
||||
},
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{
|
||||
`"data":{`,
|
||||
`"name":{"code":"validation_match_invalid"`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "authorized as superuser (valid name)",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/backups",
|
||||
Body: strings.NewReader(`{"name":"test.zip"}`),
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
|
||||
},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
files, err := getBackupFiles(app)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if total := len(files); total != 1 {
|
||||
t.Fatalf("Expected 1 backup file, got %d", total)
|
||||
}
|
||||
|
||||
expected := "test.zip"
|
||||
if files[0].Key != expected {
|
||||
t.Fatalf("Expected backup file %q, got %q", expected, files[0].Key)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 204,
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnBackupCreate": 1,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBackupUpload(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// create dummy form data bodies
|
||||
type body struct {
|
||||
buffer io.Reader
|
||||
contentType string
|
||||
}
|
||||
bodies := make([]body, 10)
|
||||
for i := 0; i < 10; i++ {
|
||||
func() {
|
||||
zb := new(bytes.Buffer)
|
||||
zw := zip.NewWriter(zb)
|
||||
if err := zw.Close(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
b := new(bytes.Buffer)
|
||||
mw := multipart.NewWriter(b)
|
||||
|
||||
mfw, err := mw.CreateFormFile("file", "test")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if _, err := io.Copy(mfw, zb); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
mw.Close()
|
||||
|
||||
bodies[i] = body{
|
||||
buffer: b,
|
||||
contentType: mw.FormDataContentType(),
|
||||
}
|
||||
}()
|
||||
}
|
||||
// ---
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "unauthorized",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/backups/upload",
|
||||
Body: bodies[0].buffer,
|
||||
Headers: map[string]string{
|
||||
"Content-Type": bodies[0].contentType,
|
||||
},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
ensureNoBackups(t, app)
|
||||
},
|
||||
ExpectedStatus: 401,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "authorized as regular user",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/backups/upload",
|
||||
Body: bodies[1].buffer,
|
||||
Headers: map[string]string{
|
||||
"Content-Type": bodies[1].contentType,
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
ensureNoBackups(t, app)
|
||||
},
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "authorized as superuser (missing file)",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/backups/upload",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
|
||||
},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
ensureNoBackups(t, app)
|
||||
},
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{`"data":{`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "authorized as superuser (existing backup name)",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/backups/upload",
|
||||
Body: bodies[3].buffer,
|
||||
Headers: map[string]string{
|
||||
"Content-Type": bodies[3].contentType,
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
fsys, err := app.NewBackupsFilesystem()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer fsys.Close()
|
||||
// create a dummy backup file to simulate existing backups
|
||||
if err := fsys.Upload([]byte("123"), "test"); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
files, _ := getBackupFiles(app)
|
||||
if total := len(files); total != 1 {
|
||||
t.Fatalf("Expected %d backup file, got %d", 1, total)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{`"data":{"file":{`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "authorized as superuser (valid file)",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/backups/upload",
|
||||
Body: bodies[4].buffer,
|
||||
Headers: map[string]string{
|
||||
"Content-Type": bodies[4].contentType,
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
|
||||
},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
files, _ := getBackupFiles(app)
|
||||
if total := len(files); total != 1 {
|
||||
t.Fatalf("Expected %d backup file, got %d", 1, total)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 204,
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "ensure that the default body limit is skipped",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/backups/upload",
|
||||
Body: bytes.NewBuffer(make([]byte, apis.DefaultMaxBodySize+100)),
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
|
||||
},
|
||||
ExpectedStatus: 400, // it doesn't matter as long as it is not 413
|
||||
ExpectedContent: []string{`"data":{`},
|
||||
NotExpectedContent: []string{"entity too large"},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBackupsDownload(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "unauthorized",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/backups/test1.zip",
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := createTestBackups(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "with record auth header",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/backups/test1.zip",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := createTestBackups(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "with superuser auth header",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/backups/test1.zip",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := createTestBackups(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "with empty or invalid token",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/backups/test1.zip?token=",
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := createTestBackups(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "with valid record auth token",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/backups/test1.zip?token=eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := createTestBackups(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "with valid record file token",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/backups/test1.zip?token=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImV4cCI6MjUyNDYwNDQ2MSwidHlwZSI6ImZpbGUiLCJjb2xsZWN0aW9uSWQiOiJfcGJfdXNlcnNfYXV0aF8ifQ.nSTLuCPcGpWn2K2l-BFkC3Vlzc-ZTDPByYq8dN1oPSo",
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := createTestBackups(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "with valid superuser auth token",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/backups/test1.zip?token=eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := createTestBackups(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "with expired superuser file token",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/backups/test1.zip?token=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsImV4cCI6MTY0MDk5MTY2MSwidHlwZSI6ImZpbGUiLCJjb2xsZWN0aW9uSWQiOiJwYmNfMzE0MjYzNTgyMyJ9.nqqtqpPhxU0045F4XP_ruAkzAidYBc5oPy9ErN3XBq0",
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := createTestBackups(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "with valid superuser file token but missing backup name",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/backups/missing?token=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsImV4cCI6MjUyNDYwNDQ2MSwidHlwZSI6ImZpbGUiLCJjb2xsZWN0aW9uSWQiOiJwYmNfMzE0MjYzNTgyMyJ9.Lupz541xRvrktwkrl55p5pPCF77T69ZRsohsIcb2dxc",
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := createTestBackups(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 404,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "with valid superuser file token",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/backups/test1.zip?token=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsImV4cCI6MjUyNDYwNDQ2MSwidHlwZSI6ImZpbGUiLCJjb2xsZWN0aW9uSWQiOiJwYmNfMzE0MjYzNTgyMyJ9.Lupz541xRvrktwkrl55p5pPCF77T69ZRsohsIcb2dxc",
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := createTestBackups(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
"storage/",
|
||||
"data.db",
|
||||
"auxiliary.db",
|
||||
},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "with valid superuser file token and backup name with escaped char",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/backups/%40test4.zip?token=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsImV4cCI6MjUyNDYwNDQ2MSwidHlwZSI6ImZpbGUiLCJjb2xsZWN0aW9uSWQiOiJwYmNfMzE0MjYzNTgyMyJ9.Lupz541xRvrktwkrl55p5pPCF77T69ZRsohsIcb2dxc",
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := createTestBackups(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
"storage/",
|
||||
"data.db",
|
||||
"auxiliary.db",
|
||||
},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBackupsDelete(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
noTestBackupFilesChanges := func(t testing.TB, app *tests.TestApp) {
|
||||
files, err := getBackupFiles(app)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
expected := 4
|
||||
if total := len(files); total != expected {
|
||||
t.Fatalf("Expected %d backup(s), got %d", expected, total)
|
||||
}
|
||||
}
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "unauthorized",
|
||||
Method: http.MethodDelete,
|
||||
URL: "/api/backups/test1.zip",
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := createTestBackups(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
noTestBackupFilesChanges(t, app)
|
||||
},
|
||||
ExpectedStatus: 401,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "authorized as regular user",
|
||||
Method: http.MethodDelete,
|
||||
URL: "/api/backups/test1.zip",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := createTestBackups(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
noTestBackupFilesChanges(t, app)
|
||||
},
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "authorized as superuser (missing file)",
|
||||
Method: http.MethodDelete,
|
||||
URL: "/api/backups/missing.zip",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := createTestBackups(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
noTestBackupFilesChanges(t, app)
|
||||
},
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "authorized as superuser (existing file with matching active backup)",
|
||||
Method: http.MethodDelete,
|
||||
URL: "/api/backups/test1.zip",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := createTestBackups(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// mock active backup with the same name to delete
|
||||
app.Store().Set(core.StoreKeyActiveBackup, "test1.zip")
|
||||
},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
noTestBackupFilesChanges(t, app)
|
||||
},
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "authorized as superuser (existing file and no matching active backup)",
|
||||
Method: http.MethodDelete,
|
||||
URL: "/api/backups/test1.zip",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := createTestBackups(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// mock active backup with different name
|
||||
app.Store().Set(core.StoreKeyActiveBackup, "new.zip")
|
||||
},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
files, err := getBackupFiles(app)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if total := len(files); total != 3 {
|
||||
t.Fatalf("Expected %d backup files, got %d", 3, total)
|
||||
}
|
||||
|
||||
deletedFile := "test1.zip"
|
||||
|
||||
for _, f := range files {
|
||||
if f.Key == deletedFile {
|
||||
t.Fatalf("Expected backup %q to be deleted", deletedFile)
|
||||
}
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 204,
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "authorized as superuser (backup with escaped character)",
|
||||
Method: http.MethodDelete,
|
||||
URL: "/api/backups/%40test4.zip",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := createTestBackups(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
files, err := getBackupFiles(app)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if total := len(files); total != 3 {
|
||||
t.Fatalf("Expected %d backup files, got %d", 3, total)
|
||||
}
|
||||
|
||||
deletedFile := "@test4.zip"
|
||||
|
||||
for _, f := range files {
|
||||
if f.Key == deletedFile {
|
||||
t.Fatalf("Expected backup %q to be deleted", deletedFile)
|
||||
}
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 204,
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBackupsRestore(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "unauthorized",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/backups/test1.zip/restore",
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := createTestBackups(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 401,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "authorized as regular user",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/backups/test1.zip/restore",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := createTestBackups(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "authorized as superuser (missing file)",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/backups/missing.zip/restore",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := createTestBackups(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "authorized as superuser (active backup process)",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/backups/test1.zip/restore",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := createTestBackups(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
app.Store().Set(core.StoreKeyActiveBackup, "")
|
||||
},
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
|
||||
func createTestBackups(app core.App) error {
|
||||
ctx := context.Background()
|
||||
|
||||
if err := app.CreateBackup(ctx, "test1.zip"); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := app.CreateBackup(ctx, "test2.zip"); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := app.CreateBackup(ctx, "test3.zip"); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := app.CreateBackup(ctx, "@test4.zip"); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func getBackupFiles(app core.App) ([]*blob.ListObject, error) {
|
||||
fsys, err := app.NewBackupsFilesystem()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer fsys.Close()
|
||||
|
||||
return fsys.List("")
|
||||
}
|
||||
|
||||
func ensureNoBackups(t testing.TB, app *tests.TestApp) {
|
||||
files, err := getBackupFiles(app)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if total := len(files); total != 0 {
|
||||
t.Fatalf("Expected 0 backup files, got %d", total)
|
||||
}
|
||||
}
|
72
apis/backup_upload.go
Normal file
72
apis/backup_upload.go
Normal file
|
@ -0,0 +1,72 @@
|
|||
package apis
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
validation "github.com/go-ozzo/ozzo-validation/v4"
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/core/validators"
|
||||
"github.com/pocketbase/pocketbase/tools/filesystem"
|
||||
)
|
||||
|
||||
func backupUpload(e *core.RequestEvent) error {
|
||||
fsys, err := e.App.NewBackupsFilesystem()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer fsys.Close()
|
||||
|
||||
form := new(backupUploadForm)
|
||||
form.fsys = fsys
|
||||
files, _ := e.FindUploadedFiles("file")
|
||||
if len(files) > 0 {
|
||||
form.File = files[0]
|
||||
}
|
||||
|
||||
err = form.validate()
|
||||
if err != nil {
|
||||
return e.BadRequestError("An error occurred while validating the submitted data.", err)
|
||||
}
|
||||
|
||||
err = fsys.UploadFile(form.File, form.File.OriginalName)
|
||||
if err != nil {
|
||||
return e.BadRequestError("Failed to upload backup.", err)
|
||||
}
|
||||
|
||||
// we don't retrieve the generated backup file because it may not be
|
||||
// available yet due to the eventually consistent nature of some S3 providers
|
||||
return e.NoContent(http.StatusNoContent)
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
|
||||
type backupUploadForm struct {
|
||||
fsys *filesystem.System
|
||||
|
||||
File *filesystem.File `json:"file"`
|
||||
}
|
||||
|
||||
func (form *backupUploadForm) validate() error {
|
||||
return validation.ValidateStruct(form,
|
||||
validation.Field(
|
||||
&form.File,
|
||||
validation.Required,
|
||||
validation.By(validators.UploadedFileMimeType([]string{"application/zip"})),
|
||||
validation.By(form.checkUniqueName),
|
||||
),
|
||||
)
|
||||
}
|
||||
|
||||
func (form *backupUploadForm) checkUniqueName(value any) error {
|
||||
v, _ := value.(*filesystem.File)
|
||||
if v == nil {
|
||||
return nil // nothing to check
|
||||
}
|
||||
|
||||
// note: we use the original name because that is what we upload
|
||||
if exists, err := form.fsys.Exists(v.OriginalName); err != nil || exists {
|
||||
return validation.NewError("validation_backup_name_exists", "Backup file with the specified name already exists.")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
174
apis/base.go
Normal file
174
apis/base.go
Normal file
|
@ -0,0 +1,174 @@
|
|||
package apis
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"net/http"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tools/router"
|
||||
)
|
||||
|
||||
// StaticWildcardParam is the name of Static handler wildcard parameter.
|
||||
const StaticWildcardParam = "path"
|
||||
|
||||
// NewRouter returns a new router instance loaded with the default app middlewares and api routes.
|
||||
func NewRouter(app core.App) (*router.Router[*core.RequestEvent], error) {
|
||||
pbRouter := router.NewRouter(func(w http.ResponseWriter, r *http.Request) (*core.RequestEvent, router.EventCleanupFunc) {
|
||||
event := new(core.RequestEvent)
|
||||
event.Response = w
|
||||
event.Request = r
|
||||
event.App = app
|
||||
|
||||
return event, nil
|
||||
})
|
||||
|
||||
// register default middlewares
|
||||
pbRouter.Bind(activityLogger())
|
||||
pbRouter.Bind(panicRecover())
|
||||
pbRouter.Bind(rateLimit())
|
||||
pbRouter.Bind(loadAuthToken())
|
||||
pbRouter.Bind(securityHeaders())
|
||||
pbRouter.Bind(BodyLimit(DefaultMaxBodySize))
|
||||
|
||||
apiGroup := pbRouter.Group("/api")
|
||||
bindSettingsApi(app, apiGroup)
|
||||
bindCollectionApi(app, apiGroup)
|
||||
bindRecordCrudApi(app, apiGroup)
|
||||
bindRecordAuthApi(app, apiGroup)
|
||||
bindLogsApi(app, apiGroup)
|
||||
bindBackupApi(app, apiGroup)
|
||||
bindCronApi(app, apiGroup)
|
||||
bindFileApi(app, apiGroup)
|
||||
bindBatchApi(app, apiGroup)
|
||||
bindRealtimeApi(app, apiGroup)
|
||||
bindHealthApi(app, apiGroup)
|
||||
|
||||
return pbRouter, nil
|
||||
}
|
||||
|
||||
// WrapStdHandler wraps Go [http.Handler] into a PocketBase handler func.
|
||||
func WrapStdHandler(h http.Handler) func(*core.RequestEvent) error {
|
||||
return func(e *core.RequestEvent) error {
|
||||
h.ServeHTTP(e.Response, e.Request)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// WrapStdMiddleware wraps Go [func(http.Handler) http.Handle] into a PocketBase middleware func.
|
||||
func WrapStdMiddleware(m func(http.Handler) http.Handler) func(*core.RequestEvent) error {
|
||||
return func(e *core.RequestEvent) (err error) {
|
||||
m(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
e.Response = w
|
||||
e.Request = r
|
||||
err = e.Next()
|
||||
})).ServeHTTP(e.Response, e.Request)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// MustSubFS returns an [fs.FS] corresponding to the subtree rooted at fsys's dir.
|
||||
//
|
||||
// This is similar to [fs.Sub] but panics on failure.
|
||||
func MustSubFS(fsys fs.FS, dir string) fs.FS {
|
||||
dir = filepath.ToSlash(filepath.Clean(dir)) // ToSlash in case of Windows path
|
||||
|
||||
sub, err := fs.Sub(fsys, dir)
|
||||
if err != nil {
|
||||
panic(fmt.Errorf("failed to create sub FS: %w", err))
|
||||
}
|
||||
|
||||
return sub
|
||||
}
|
||||
|
||||
// Static is a handler function to serve static directory content from fsys.
|
||||
//
|
||||
// If a file resource is missing and indexFallback is set, the request
|
||||
// will be forwarded to the base index.html (useful for SPA with pretty urls).
|
||||
//
|
||||
// NB! Expects the route to have a "{path...}" wildcard parameter.
|
||||
//
|
||||
// Special redirects:
|
||||
// - if "path" is a file that ends in index.html, it is redirected to its non-index.html version (eg. /test/index.html -> /test/)
|
||||
// - if "path" is a directory that has index.html, the index.html file is rendered,
|
||||
// otherwise if missing - returns 404 or fallback to the root index.html if indexFallback is set
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// fsys := os.DirFS("./pb_public")
|
||||
// router.GET("/files/{path...}", apis.Static(fsys, false))
|
||||
func Static(fsys fs.FS, indexFallback bool) func(*core.RequestEvent) error {
|
||||
if fsys == nil {
|
||||
panic("Static: the provided fs.FS argument is nil")
|
||||
}
|
||||
|
||||
return func(e *core.RequestEvent) error {
|
||||
// disable the activity logger to avoid flooding with messages
|
||||
//
|
||||
// note: errors are still logged
|
||||
if e.Get(requestEventKeySkipSuccessActivityLog) == nil {
|
||||
e.Set(requestEventKeySkipSuccessActivityLog, true)
|
||||
}
|
||||
|
||||
filename := e.Request.PathValue(StaticWildcardParam)
|
||||
filename = filepath.ToSlash(filepath.Clean(strings.TrimPrefix(filename, "/")))
|
||||
|
||||
// eagerly check for directory traversal
|
||||
//
|
||||
// note: this is just out of an abundance of caution because the fs.FS implementation could be non-std,
|
||||
// but usually shouldn't be necessary since os.DirFS.Open is expected to fail if the filename starts with dots
|
||||
if len(filename) > 2 && filename[0] == '.' && filename[1] == '.' && (filename[2] == '/' || filename[2] == '\\') {
|
||||
if indexFallback && filename != router.IndexPage {
|
||||
return e.FileFS(fsys, router.IndexPage)
|
||||
}
|
||||
return router.ErrFileNotFound
|
||||
}
|
||||
|
||||
fi, err := fs.Stat(fsys, filename)
|
||||
if err != nil {
|
||||
if indexFallback && filename != router.IndexPage {
|
||||
return e.FileFS(fsys, router.IndexPage)
|
||||
}
|
||||
return router.ErrFileNotFound
|
||||
}
|
||||
|
||||
if fi.IsDir() {
|
||||
// redirect to a canonical dir url, aka. with trailing slash
|
||||
if !strings.HasSuffix(e.Request.URL.Path, "/") {
|
||||
return e.Redirect(http.StatusMovedPermanently, safeRedirectPath(e.Request.URL.Path+"/"))
|
||||
}
|
||||
} else {
|
||||
urlPath := e.Request.URL.Path
|
||||
if strings.HasSuffix(urlPath, "/") {
|
||||
// redirect to a non-trailing slash file route
|
||||
urlPath = strings.TrimRight(urlPath, "/")
|
||||
if len(urlPath) > 0 {
|
||||
return e.Redirect(http.StatusMovedPermanently, safeRedirectPath(urlPath))
|
||||
}
|
||||
} else if stripped, ok := strings.CutSuffix(urlPath, router.IndexPage); ok {
|
||||
// redirect without the index.html
|
||||
return e.Redirect(http.StatusMovedPermanently, safeRedirectPath(stripped))
|
||||
}
|
||||
}
|
||||
|
||||
fileErr := e.FileFS(fsys, filename)
|
||||
|
||||
if fileErr != nil && indexFallback && filename != router.IndexPage && errors.Is(fileErr, router.ErrFileNotFound) {
|
||||
return e.FileFS(fsys, router.IndexPage)
|
||||
}
|
||||
|
||||
return fileErr
|
||||
}
|
||||
}
|
||||
|
||||
// safeRedirectPath normalizes the path string by replacing all beginning slashes
|
||||
// (`\\`, `//`, `\/`) with a single forward slash to prevent open redirect attacks
|
||||
func safeRedirectPath(path string) string {
|
||||
if len(path) > 1 && (path[0] == '\\' || path[0] == '/') && (path[1] == '\\' || path[1] == '/') {
|
||||
path = "/" + strings.TrimLeft(path, `/\`)
|
||||
}
|
||||
return path
|
||||
}
|
313
apis/base_test.go
Normal file
313
apis/base_test.go
Normal file
|
@ -0,0 +1,313 @@
|
|||
package apis_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/pocketbase/pocketbase/apis"
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tests"
|
||||
"github.com/pocketbase/pocketbase/tools/router"
|
||||
)
|
||||
|
||||
func TestWrapStdHandler(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
e := new(core.RequestEvent)
|
||||
e.App = app
|
||||
e.Request = req
|
||||
e.Response = rec
|
||||
|
||||
err := apis.WrapStdHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Write([]byte("test"))
|
||||
}))(e)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if body := rec.Body.String(); body != "test" {
|
||||
t.Fatalf("Expected body %q, got %q", "test", body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWrapStdMiddleware(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
e := new(core.RequestEvent)
|
||||
e.App = app
|
||||
e.Request = req
|
||||
e.Response = rec
|
||||
|
||||
err := apis.WrapStdMiddleware(func(h http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Write([]byte("test"))
|
||||
})
|
||||
})(e)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if body := rec.Body.String(); body != "test" {
|
||||
t.Fatalf("Expected body %q, got %q", "test", body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStatic(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
dir := createTestDir(t)
|
||||
defer os.RemoveAll(dir)
|
||||
|
||||
fsys := os.DirFS(filepath.Join(dir, "sub"))
|
||||
|
||||
type staticScenario struct {
|
||||
path string
|
||||
indexFallback bool
|
||||
expectedStatus int
|
||||
expectBody string
|
||||
expectError bool
|
||||
}
|
||||
|
||||
scenarios := []staticScenario{
|
||||
{
|
||||
path: "",
|
||||
indexFallback: false,
|
||||
expectedStatus: 200,
|
||||
expectBody: "sub index.html",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
path: "missing/a/b/c",
|
||||
indexFallback: false,
|
||||
expectedStatus: 404,
|
||||
expectBody: "",
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
path: "missing/a/b/c",
|
||||
indexFallback: true,
|
||||
expectedStatus: 200,
|
||||
expectBody: "sub index.html",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
path: "testroot", // parent directory file
|
||||
indexFallback: false,
|
||||
expectedStatus: 404,
|
||||
expectBody: "",
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
path: "test",
|
||||
indexFallback: false,
|
||||
expectedStatus: 200,
|
||||
expectBody: "sub test",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
path: "sub2",
|
||||
indexFallback: false,
|
||||
expectedStatus: 301,
|
||||
expectBody: "",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
path: "sub2/",
|
||||
indexFallback: false,
|
||||
expectedStatus: 200,
|
||||
expectBody: "sub2 index.html",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
path: "sub2/test",
|
||||
indexFallback: false,
|
||||
expectedStatus: 200,
|
||||
expectBody: "sub2 test",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
path: "sub2/test/",
|
||||
indexFallback: false,
|
||||
expectedStatus: 301,
|
||||
expectBody: "",
|
||||
expectError: false,
|
||||
},
|
||||
}
|
||||
|
||||
// extra directory traversal checks
|
||||
dtp := []string{
|
||||
"/../",
|
||||
"\\../",
|
||||
"../",
|
||||
"../../",
|
||||
"..\\",
|
||||
"..\\..\\",
|
||||
"../..\\",
|
||||
"..\\..//",
|
||||
`%2e%2e%2f`,
|
||||
`%2e%2e%2f%2e%2e%2f`,
|
||||
`%2e%2e/`,
|
||||
`%2e%2e/%2e%2e/`,
|
||||
`..%2f`,
|
||||
`..%2f..%2f`,
|
||||
`%2e%2e%5c`,
|
||||
`%2e%2e%5c%2e%2e%5c`,
|
||||
`%2e%2e\`,
|
||||
`%2e%2e\%2e%2e\`,
|
||||
`..%5c`,
|
||||
`..%5c..%5c`,
|
||||
`%252e%252e%255c`,
|
||||
`%252e%252e%255c%252e%252e%255c`,
|
||||
`..%255c`,
|
||||
`..%255c..%255c`,
|
||||
}
|
||||
for _, p := range dtp {
|
||||
scenarios = append(scenarios,
|
||||
staticScenario{
|
||||
path: p + "testroot",
|
||||
indexFallback: false,
|
||||
expectedStatus: 404,
|
||||
expectBody: "",
|
||||
expectError: true,
|
||||
},
|
||||
staticScenario{
|
||||
path: p + "testroot",
|
||||
indexFallback: true,
|
||||
expectedStatus: 200,
|
||||
expectBody: "sub index.html",
|
||||
expectError: false,
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
for i, s := range scenarios {
|
||||
t.Run(fmt.Sprintf("%d_%s_%v", i, s.path, s.indexFallback), func(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodGet, "/"+s.path, nil)
|
||||
req.SetPathValue(apis.StaticWildcardParam, s.path)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
e := new(core.RequestEvent)
|
||||
e.App = app
|
||||
e.Request = req
|
||||
e.Response = rec
|
||||
|
||||
err := apis.Static(fsys, s.indexFallback)(e)
|
||||
|
||||
hasErr := err != nil
|
||||
if hasErr != s.expectError {
|
||||
t.Fatalf("Expected hasErr %v, got %v (%v)", s.expectError, hasErr, err)
|
||||
}
|
||||
|
||||
body := rec.Body.String()
|
||||
if body != s.expectBody {
|
||||
t.Fatalf("Expected body %q, got %q", s.expectBody, body)
|
||||
}
|
||||
|
||||
if hasErr {
|
||||
apiErr := router.ToApiError(err)
|
||||
if apiErr.Status != s.expectedStatus {
|
||||
t.Fatalf("Expected status code %d, got %d", s.expectedStatus, apiErr.Status)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMustSubFS(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
dir := createTestDir(t)
|
||||
defer os.RemoveAll(dir)
|
||||
|
||||
// invalid path (no beginning and ending slashes)
|
||||
if !hasPanicked(func() {
|
||||
apis.MustSubFS(os.DirFS(dir), "/test/")
|
||||
}) {
|
||||
t.Fatalf("Expected to panic")
|
||||
}
|
||||
|
||||
// valid path
|
||||
if hasPanicked(func() {
|
||||
apis.MustSubFS(os.DirFS(dir), "./////a/b/c") // checks if ToSlash was called
|
||||
}) {
|
||||
t.Fatalf("Didn't expect to panic")
|
||||
}
|
||||
|
||||
// check sub content
|
||||
sub := apis.MustSubFS(os.DirFS(dir), "sub")
|
||||
|
||||
_, err := sub.Open("test")
|
||||
if err != nil {
|
||||
t.Fatalf("Missing expected file sub/test")
|
||||
}
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
|
||||
func hasPanicked(f func()) (didPanic bool) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
didPanic = true
|
||||
}
|
||||
}()
|
||||
f()
|
||||
return
|
||||
}
|
||||
|
||||
// note: make sure to call os.RemoveAll(dir) after you are done
|
||||
// working with the created test dir.
|
||||
func createTestDir(t *testing.T) string {
|
||||
dir, err := os.MkdirTemp(os.TempDir(), "test_dir")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := os.WriteFile(filepath.Join(dir, "index.html"), []byte("root index.html"), 0644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := os.WriteFile(filepath.Join(dir, "testroot"), []byte("root test"), 0644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := os.MkdirAll(filepath.Join(dir, "sub"), os.ModePerm); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := os.WriteFile(filepath.Join(dir, "sub/index.html"), []byte("sub index.html"), 0644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := os.WriteFile(filepath.Join(dir, "sub/test"), []byte("sub test"), 0644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := os.MkdirAll(filepath.Join(dir, "sub", "sub2"), os.ModePerm); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := os.WriteFile(filepath.Join(dir, "sub/sub2/index.html"), []byte("sub2 index.html"), 0644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := os.WriteFile(filepath.Join(dir, "sub/sub2/test"), []byte("sub2 test"), 0644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
return dir
|
||||
}
|
548
apis/batch.go
Normal file
548
apis/batch.go
Normal file
|
@ -0,0 +1,548 @@
|
|||
package apis
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"regexp"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
validation "github.com/go-ozzo/ozzo-validation/v4"
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tools/filesystem"
|
||||
"github.com/pocketbase/pocketbase/tools/router"
|
||||
"github.com/pocketbase/pocketbase/tools/types"
|
||||
"github.com/spf13/cast"
|
||||
)
|
||||
|
||||
func bindBatchApi(app core.App, rg *router.RouterGroup[*core.RequestEvent]) {
|
||||
sub := rg.Group("/batch")
|
||||
sub.POST("", batchTransaction).Unbind(DefaultBodyLimitMiddlewareId) // the body limit is inlined
|
||||
}
|
||||
|
||||
type HandleFunc func(e *core.RequestEvent) error
|
||||
|
||||
type BatchActionHandlerFunc func(app core.App, ir *core.InternalRequest, params map[string]string, next func(data any) error) HandleFunc
|
||||
|
||||
// ValidBatchActions defines a map with the supported batch InternalRequest actions.
|
||||
//
|
||||
// Note: when adding new routes make sure that their middlewares are inlined!
|
||||
var ValidBatchActions = map[*regexp.Regexp]BatchActionHandlerFunc{
|
||||
// "upsert" handler
|
||||
regexp.MustCompile(`^PUT /api/collections/(?P<collection>[^\/\?]+)/records(?P<query>\?.*)?$`): func(app core.App, ir *core.InternalRequest, params map[string]string, next func(any) error) HandleFunc {
|
||||
var id string
|
||||
if len(ir.Body) > 0 && ir.Body["id"] != "" {
|
||||
id = cast.ToString(ir.Body["id"])
|
||||
}
|
||||
if id != "" {
|
||||
_, err := app.FindRecordById(params["collection"], id)
|
||||
if err == nil {
|
||||
// update
|
||||
// ---
|
||||
params["id"] = id // required for the path value
|
||||
ir.Method = "PATCH"
|
||||
ir.URL = "/api/collections/" + params["collection"] + "/records/" + id + params["query"]
|
||||
return recordUpdate(false, next)
|
||||
}
|
||||
}
|
||||
|
||||
// create
|
||||
// ---
|
||||
ir.Method = "POST"
|
||||
ir.URL = "/api/collections/" + params["collection"] + "/records" + params["query"]
|
||||
return recordCreate(false, next)
|
||||
},
|
||||
regexp.MustCompile(`^POST /api/collections/(?P<collection>[^\/\?]+)/records(\?.*)?$`): func(app core.App, ir *core.InternalRequest, params map[string]string, next func(any) error) HandleFunc {
|
||||
return recordCreate(false, next)
|
||||
},
|
||||
regexp.MustCompile(`^PATCH /api/collections/(?P<collection>[^\/\?]+)/records/(?P<id>[^\/\?]+)(\?.*)?$`): func(app core.App, ir *core.InternalRequest, params map[string]string, next func(any) error) HandleFunc {
|
||||
return recordUpdate(false, next)
|
||||
},
|
||||
regexp.MustCompile(`^DELETE /api/collections/(?P<collection>[^\/\?]+)/records/(?P<id>[^\/\?]+)(\?.*)?$`): func(app core.App, ir *core.InternalRequest, params map[string]string, next func(any) error) HandleFunc {
|
||||
return recordDelete(false, next)
|
||||
},
|
||||
}
|
||||
|
||||
type BatchRequestResult struct {
|
||||
Body any `json:"body"`
|
||||
Status int `json:"status"`
|
||||
}
|
||||
|
||||
type batchRequestsForm struct {
|
||||
Requests []*core.InternalRequest `form:"requests" json:"requests"`
|
||||
|
||||
max int
|
||||
}
|
||||
|
||||
func (brs batchRequestsForm) validate() error {
|
||||
return validation.ValidateStruct(&brs,
|
||||
validation.Field(&brs.Requests, validation.Required, validation.Length(0, brs.max)),
|
||||
)
|
||||
}
|
||||
|
||||
// NB! When the request is submitted as multipart/form-data,
|
||||
// the regular fields data is expected to be submitted as serailized
|
||||
// json under the @jsonPayload field and file keys need to follow the
|
||||
// pattern "requests.N.fileField" or requests[N].fileField.
|
||||
func batchTransaction(e *core.RequestEvent) error {
|
||||
maxRequests := e.App.Settings().Batch.MaxRequests
|
||||
if !e.App.Settings().Batch.Enabled || maxRequests <= 0 {
|
||||
return e.ForbiddenError("Batch requests are not allowed.", nil)
|
||||
}
|
||||
|
||||
txTimeout := time.Duration(e.App.Settings().Batch.Timeout) * time.Second
|
||||
if txTimeout <= 0 {
|
||||
txTimeout = 3 * time.Second // for now always limit
|
||||
}
|
||||
|
||||
maxBodySize := e.App.Settings().Batch.MaxBodySize
|
||||
if maxBodySize <= 0 {
|
||||
maxBodySize = 128 << 20
|
||||
}
|
||||
|
||||
err := applyBodyLimit(e, maxBodySize)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
form := &batchRequestsForm{max: maxRequests}
|
||||
|
||||
// load base requests data
|
||||
err = e.BindBody(form)
|
||||
if err != nil {
|
||||
return e.BadRequestError("Failed to read the submitted batch data.", err)
|
||||
}
|
||||
|
||||
// load uploaded files into each request item
|
||||
// note: expects the files to be under "requests.N.fileField" or "requests[N].fileField" format
|
||||
// (the other regular fields must be put under `@jsonPayload` as serialized json)
|
||||
if strings.HasPrefix(e.Request.Header.Get("Content-Type"), "multipart/form-data") {
|
||||
for i, ir := range form.Requests {
|
||||
iStr := strconv.Itoa(i)
|
||||
|
||||
files, err := extractPrefixedFiles(e.Request, "requests."+iStr+".", "requests["+iStr+"].")
|
||||
if err != nil {
|
||||
return e.BadRequestError("Failed to read the submitted batch files data.", err)
|
||||
}
|
||||
|
||||
for key, files := range files {
|
||||
if ir.Body == nil {
|
||||
ir.Body = map[string]any{}
|
||||
}
|
||||
ir.Body[key] = files
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// validate batch request form
|
||||
err = form.validate()
|
||||
if err != nil {
|
||||
return e.BadRequestError("Invalid batch request data.", err)
|
||||
}
|
||||
|
||||
event := new(core.BatchRequestEvent)
|
||||
event.RequestEvent = e
|
||||
event.Batch = form.Requests
|
||||
|
||||
return e.App.OnBatchRequest().Trigger(event, func(e *core.BatchRequestEvent) error {
|
||||
bp := batchProcessor{
|
||||
app: e.App,
|
||||
baseEvent: e.RequestEvent,
|
||||
infoContext: core.RequestInfoContextBatch,
|
||||
}
|
||||
|
||||
if err := bp.Process(e.Batch, txTimeout); err != nil {
|
||||
return firstApiError(err, e.BadRequestError("Batch transaction failed.", err))
|
||||
}
|
||||
|
||||
return e.JSON(http.StatusOK, bp.results)
|
||||
})
|
||||
}
|
||||
|
||||
type batchProcessor struct {
|
||||
app core.App
|
||||
baseEvent *core.RequestEvent
|
||||
infoContext string
|
||||
results []*BatchRequestResult
|
||||
failedIndex int
|
||||
errCh chan error
|
||||
stopCh chan struct{}
|
||||
}
|
||||
|
||||
func (p *batchProcessor) Process(batch []*core.InternalRequest, timeout time.Duration) error {
|
||||
p.results = make([]*BatchRequestResult, 0, len(batch))
|
||||
|
||||
if p.stopCh != nil {
|
||||
close(p.stopCh)
|
||||
}
|
||||
p.stopCh = make(chan struct{}, 1)
|
||||
|
||||
if p.errCh != nil {
|
||||
close(p.errCh)
|
||||
}
|
||||
p.errCh = make(chan error, 1)
|
||||
|
||||
return p.app.RunInTransaction(func(txApp core.App) error {
|
||||
// used to interupts the recursive processing calls in case of a timeout or connection close
|
||||
defer func() {
|
||||
p.stopCh <- struct{}{}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
err := p.process(txApp, batch, 0)
|
||||
|
||||
if err != nil {
|
||||
err = validation.Errors{
|
||||
"requests": validation.Errors{
|
||||
strconv.Itoa(p.failedIndex): &BatchResponseError{
|
||||
code: "batch_request_failed",
|
||||
message: "Batch request failed.",
|
||||
err: router.ToApiError(err),
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// note: to avoid copying and due to the process recursion the final results order is reversed
|
||||
if err == nil {
|
||||
slices.Reverse(p.results)
|
||||
}
|
||||
|
||||
p.errCh <- err
|
||||
}()
|
||||
|
||||
select {
|
||||
case responseErr := <-p.errCh:
|
||||
return responseErr
|
||||
case <-time.After(timeout):
|
||||
// note: we don't return 408 Reques Timeout error because
|
||||
// some browsers perform automatic retry behind the scenes
|
||||
// which are hard to debug and unnecessary
|
||||
return errors.New("batch transaction timeout")
|
||||
case <-p.baseEvent.Request.Context().Done():
|
||||
return errors.New("batch request interrupted")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func (p *batchProcessor) process(activeApp core.App, batch []*core.InternalRequest, i int) error {
|
||||
select {
|
||||
case <-p.stopCh:
|
||||
return nil
|
||||
default:
|
||||
if len(batch) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
result, err := processInternalRequest(
|
||||
activeApp,
|
||||
p.baseEvent,
|
||||
batch[0],
|
||||
p.infoContext,
|
||||
func(_ any) error {
|
||||
if len(batch) == 1 {
|
||||
return nil
|
||||
}
|
||||
|
||||
err := p.process(activeApp, batch[1:], i+1)
|
||||
|
||||
// update the failed batch index (if not already)
|
||||
if err != nil && p.failedIndex == 0 {
|
||||
p.failedIndex = i + 1
|
||||
}
|
||||
|
||||
return err
|
||||
},
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
p.results = append(p.results, result)
|
||||
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func processInternalRequest(
|
||||
activeApp core.App,
|
||||
baseEvent *core.RequestEvent,
|
||||
ir *core.InternalRequest,
|
||||
infoContext string,
|
||||
optNext func(data any) error,
|
||||
) (*BatchRequestResult, error) {
|
||||
handle, params, ok := prepareInternalAction(activeApp, ir, optNext)
|
||||
if !ok {
|
||||
return nil, errors.New("unknown batch request action")
|
||||
}
|
||||
|
||||
// construct a new http.Request
|
||||
// ---------------------------------------------------------------
|
||||
buf, mw, err := multipartDataFromInternalRequest(ir)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
r, err := http.NewRequest(strings.ToUpper(ir.Method), ir.URL, buf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// cleanup multipart temp files
|
||||
defer func() {
|
||||
if r.MultipartForm != nil {
|
||||
if err := r.MultipartForm.RemoveAll(); err != nil {
|
||||
activeApp.Logger().Warn("failed to cleanup temp batch files", "error", err)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// load batch request path params
|
||||
// ---
|
||||
for k, v := range params {
|
||||
r.SetPathValue(k, v)
|
||||
}
|
||||
|
||||
// clone original request
|
||||
// ---
|
||||
r.RequestURI = r.URL.RequestURI()
|
||||
r.Proto = baseEvent.Request.Proto
|
||||
r.ProtoMajor = baseEvent.Request.ProtoMajor
|
||||
r.ProtoMinor = baseEvent.Request.ProtoMinor
|
||||
r.Host = baseEvent.Request.Host
|
||||
r.RemoteAddr = baseEvent.Request.RemoteAddr
|
||||
r.TLS = baseEvent.Request.TLS
|
||||
|
||||
if s := baseEvent.Request.TransferEncoding; s != nil {
|
||||
s2 := make([]string, len(s))
|
||||
copy(s2, s)
|
||||
r.TransferEncoding = s2
|
||||
}
|
||||
|
||||
if baseEvent.Request.Trailer != nil {
|
||||
r.Trailer = baseEvent.Request.Trailer.Clone()
|
||||
}
|
||||
|
||||
if baseEvent.Request.Header != nil {
|
||||
r.Header = baseEvent.Request.Header.Clone()
|
||||
}
|
||||
|
||||
// apply batch request specific headers
|
||||
// ---
|
||||
for k, v := range ir.Headers {
|
||||
// individual Authorization header keys don't have affect
|
||||
// because the auth state is populated from the base event
|
||||
if strings.EqualFold(k, "authorization") {
|
||||
continue
|
||||
}
|
||||
|
||||
r.Header.Set(k, v)
|
||||
}
|
||||
r.Header.Set("Content-Type", mw.FormDataContentType())
|
||||
|
||||
// construct a new RequestEvent
|
||||
// ---------------------------------------------------------------
|
||||
event := &core.RequestEvent{}
|
||||
event.App = activeApp
|
||||
event.Auth = baseEvent.Auth
|
||||
event.SetAll(baseEvent.GetAll())
|
||||
|
||||
// load RequestInfo context
|
||||
if infoContext == "" {
|
||||
infoContext = core.RequestInfoContextDefault
|
||||
}
|
||||
event.Set(core.RequestEventKeyInfoContext, infoContext)
|
||||
|
||||
// assign request
|
||||
event.Request = r
|
||||
event.Request.Body = &router.RereadableReadCloser{ReadCloser: r.Body} // enables multiple reads
|
||||
|
||||
// assign response
|
||||
rec := httptest.NewRecorder()
|
||||
event.Response = &router.ResponseWriter{ResponseWriter: rec} // enables status and write tracking
|
||||
|
||||
// execute
|
||||
// ---------------------------------------------------------------
|
||||
if err := handle(event); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result := rec.Result()
|
||||
defer result.Body.Close()
|
||||
|
||||
body, _ := types.ParseJSONRaw(rec.Body.Bytes())
|
||||
|
||||
return &BatchRequestResult{
|
||||
Status: result.StatusCode,
|
||||
Body: body,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func multipartDataFromInternalRequest(ir *core.InternalRequest) (*bytes.Buffer, *multipart.Writer, error) {
|
||||
buf := &bytes.Buffer{}
|
||||
|
||||
mw := multipart.NewWriter(buf)
|
||||
|
||||
regularFields := map[string]any{}
|
||||
fileFields := map[string][]*filesystem.File{}
|
||||
|
||||
// separate regular fields from files
|
||||
// ---
|
||||
for k, rawV := range ir.Body {
|
||||
switch v := rawV.(type) {
|
||||
case *filesystem.File:
|
||||
fileFields[k] = append(fileFields[k], v)
|
||||
case []*filesystem.File:
|
||||
fileFields[k] = append(fileFields[k], v...)
|
||||
default:
|
||||
regularFields[k] = v
|
||||
}
|
||||
}
|
||||
|
||||
// submit regularFields as @jsonPayload
|
||||
// ---
|
||||
rawBody, err := json.Marshal(regularFields)
|
||||
if err != nil {
|
||||
return nil, nil, errors.Join(err, mw.Close())
|
||||
}
|
||||
|
||||
jsonPayload, err := mw.CreateFormField("@jsonPayload")
|
||||
if err != nil {
|
||||
return nil, nil, errors.Join(err, mw.Close())
|
||||
}
|
||||
_, err = jsonPayload.Write(rawBody)
|
||||
if err != nil {
|
||||
return nil, nil, errors.Join(err, mw.Close())
|
||||
}
|
||||
|
||||
// submit fileFields as multipart files
|
||||
// ---
|
||||
for key, files := range fileFields {
|
||||
for _, file := range files {
|
||||
part, err := mw.CreateFormFile(key, file.Name)
|
||||
if err != nil {
|
||||
return nil, nil, errors.Join(err, mw.Close())
|
||||
}
|
||||
|
||||
fr, err := file.Reader.Open()
|
||||
if err != nil {
|
||||
return nil, nil, errors.Join(err, mw.Close())
|
||||
}
|
||||
|
||||
_, err = io.Copy(part, fr)
|
||||
if err != nil {
|
||||
return nil, nil, errors.Join(err, fr.Close(), mw.Close())
|
||||
}
|
||||
|
||||
err = fr.Close()
|
||||
if err != nil {
|
||||
return nil, nil, errors.Join(err, mw.Close())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return buf, mw, mw.Close()
|
||||
}
|
||||
|
||||
func extractPrefixedFiles(request *http.Request, prefixes ...string) (map[string][]*filesystem.File, error) {
|
||||
if request.MultipartForm == nil {
|
||||
if err := request.ParseMultipartForm(router.DefaultMaxMemory); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
result := make(map[string][]*filesystem.File)
|
||||
|
||||
for k, fhs := range request.MultipartForm.File {
|
||||
for _, p := range prefixes {
|
||||
if strings.HasPrefix(k, p) {
|
||||
resultKey := strings.TrimPrefix(k, p)
|
||||
|
||||
for _, fh := range fhs {
|
||||
file, err := filesystem.NewFileFromMultipart(fh)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result[resultKey] = append(result[resultKey], file)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func prepareInternalAction(activeApp core.App, ir *core.InternalRequest, optNext func(data any) error) (HandleFunc, map[string]string, bool) {
|
||||
full := strings.ToUpper(ir.Method) + " " + ir.URL
|
||||
|
||||
for re, actionFactory := range ValidBatchActions {
|
||||
params, ok := findNamedMatches(re, full)
|
||||
if ok {
|
||||
return actionFactory(activeApp, ir, params, optNext), params, true
|
||||
}
|
||||
}
|
||||
|
||||
return nil, nil, false
|
||||
}
|
||||
|
||||
func findNamedMatches(re *regexp.Regexp, str string) (map[string]string, bool) {
|
||||
match := re.FindStringSubmatch(str)
|
||||
if match == nil {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
result := map[string]string{}
|
||||
|
||||
names := re.SubexpNames()
|
||||
|
||||
for i, m := range match {
|
||||
if names[i] != "" {
|
||||
result[names[i]] = m
|
||||
}
|
||||
}
|
||||
|
||||
return result, true
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
|
||||
var (
|
||||
_ router.SafeErrorItem = (*BatchResponseError)(nil)
|
||||
_ router.SafeErrorResolver = (*BatchResponseError)(nil)
|
||||
)
|
||||
|
||||
type BatchResponseError struct {
|
||||
err *router.ApiError
|
||||
code string
|
||||
message string
|
||||
}
|
||||
|
||||
func (e *BatchResponseError) Error() string {
|
||||
return e.message
|
||||
}
|
||||
|
||||
func (e *BatchResponseError) Code() string {
|
||||
return e.code
|
||||
}
|
||||
|
||||
func (e *BatchResponseError) Resolve(errData map[string]any) any {
|
||||
errData["response"] = e.err
|
||||
return errData
|
||||
}
|
||||
|
||||
func (e BatchResponseError) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(map[string]any{
|
||||
"message": e.message,
|
||||
"code": e.code,
|
||||
"response": e.err,
|
||||
})
|
||||
}
|
691
apis/batch_test.go
Normal file
691
apis/batch_test.go
Normal file
|
@ -0,0 +1,691 @@
|
|||
package apis_test
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tests"
|
||||
"github.com/pocketbase/pocketbase/tools/router"
|
||||
)
|
||||
|
||||
func TestBatchRequest(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
formData, mp, err := tests.MockMultipartData(
|
||||
map[string]string{
|
||||
router.JSONPayloadKey: `{
|
||||
"requests":[
|
||||
{"method":"POST", "url":"/api/collections/demo3/records", "body": {"title": "batch1"}},
|
||||
{"method":"POST", "url":"/api/collections/demo3/records", "body": {"title": "batch2"}},
|
||||
{"method":"POST", "url":"/api/collections/demo3/records", "body": {"title": "batch3"}},
|
||||
{"method":"PATCH", "url":"/api/collections/demo3/records/lcl9d87w22ml6jy", "body": {"files-": "test_FLurQTgrY8.txt"}}
|
||||
]
|
||||
}`,
|
||||
},
|
||||
"requests.0.files",
|
||||
"requests.0.files",
|
||||
"requests.0.files",
|
||||
"requests[2].files",
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "disabled batch requets",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/batch",
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
app.Settings().Batch.Enabled = false
|
||||
},
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "max request limits reached",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/batch",
|
||||
Body: strings.NewReader(`{
|
||||
"requests": [
|
||||
{"method":"GET", "url":"/test1"},
|
||||
{"method":"GET", "url":"/test2"},
|
||||
{"method":"GET", "url":"/test3"}
|
||||
]
|
||||
}`),
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
app.Settings().Batch.Enabled = true
|
||||
app.Settings().Batch.MaxRequests = 2
|
||||
},
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{
|
||||
`"data":{`,
|
||||
`"requests":{"code":"validation_length_too_long"`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "trigger requests validations",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/batch",
|
||||
Body: strings.NewReader(`{
|
||||
"requests": [
|
||||
{},
|
||||
{"method":"GET", "url":"/valid"},
|
||||
{"method":"invalid", "url":"/valid"},
|
||||
{"method":"POST", "url":"` + strings.Repeat("a", 2001) + `"}
|
||||
]
|
||||
}`),
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
app.Settings().Batch.Enabled = true
|
||||
app.Settings().Batch.MaxRequests = 100
|
||||
},
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{
|
||||
`"data":{`,
|
||||
`"requests":{`,
|
||||
`"0":{"method":{"code":"validation_required"`,
|
||||
`"2":{"method":{"code":"validation_in_invalid"`,
|
||||
`"3":{"url":{"code":"validation_length_too_long"`,
|
||||
},
|
||||
NotExpectedContent: []string{
|
||||
`"1":`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "unknown batch request action",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/batch",
|
||||
Body: strings.NewReader(`{
|
||||
"requests": [
|
||||
{"method":"GET", "url":"/api/health"}
|
||||
]
|
||||
}`),
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{
|
||||
`"data":{`,
|
||||
`"requests":{`,
|
||||
`0":{"code":"batch_request_failed"`,
|
||||
`"response":{`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnBatchRequest": 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "base 2 successful and 1 failed (public collection)",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/batch",
|
||||
Body: strings.NewReader(`{
|
||||
"requests": [
|
||||
{"method":"POST", "url":"/api/collections/demo2/records", "body": {"title": "batch1"}},
|
||||
{"method":"POST", "url":"/api/collections/demo2/records", "body": {"title": "batch2"}},
|
||||
{"method":"POST", "url":"/api/collections/demo2/records", "body": {"title": ""}}
|
||||
]
|
||||
}`),
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{
|
||||
`"data":{`,
|
||||
`"response":{`,
|
||||
`"2":{"code":"batch_request_failed"`,
|
||||
`"response":{"data":{"title":{"code":"validation_required"`,
|
||||
},
|
||||
NotExpectedContent: []string{
|
||||
`"0":`,
|
||||
`"1":`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnBatchRequest": 1,
|
||||
"OnRecordCreateRequest": 3,
|
||||
"OnModelCreate": 3,
|
||||
"OnModelCreateExecute": 2,
|
||||
"OnModelAfterCreateError": 3,
|
||||
"OnModelValidate": 3,
|
||||
"OnRecordCreate": 3,
|
||||
"OnRecordCreateExecute": 2,
|
||||
"OnRecordAfterCreateError": 3,
|
||||
"OnRecordValidate": 3,
|
||||
"OnRecordEnrich": 2,
|
||||
},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
records, err := app.FindRecordsByFilter("demo2", `title~"batch"`, "", 0, 0)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if len(records) != 0 {
|
||||
t.Fatalf("Expected no batch records to be persisted, got %d", len(records))
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "base 4 successful (public collection)",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/batch",
|
||||
Body: strings.NewReader(`{
|
||||
"requests": [
|
||||
{"method":"POST", "url":"/api/collections/demo2/records", "body": {"title": "batch1"}},
|
||||
{"method":"POST", "url":"/api/collections/demo2/records", "body": {"title": "batch2"}},
|
||||
{"method":"PUT", "url":"/api/collections/demo2/records", "body": {"title": "batch3"}},
|
||||
{"method":"PUT", "url":"/api/collections/demo2/records?fields=*,id:excerpt(4,true)", "body": {"id":"achvryl401bhse3","title": "batch4"}}
|
||||
]
|
||||
}`),
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
`"title":"batch1"`,
|
||||
`"title":"batch2"`,
|
||||
`"title":"batch3"`,
|
||||
`"title":"batch4"`,
|
||||
`"id":"achv..."`,
|
||||
`"active":false`,
|
||||
`"active":true`,
|
||||
`"status":200`,
|
||||
`"body":{`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnBatchRequest": 1,
|
||||
"OnModelValidate": 4,
|
||||
"OnRecordValidate": 4,
|
||||
"OnRecordEnrich": 4,
|
||||
|
||||
"OnRecordCreateRequest": 3,
|
||||
"OnModelCreate": 3,
|
||||
"OnModelCreateExecute": 3,
|
||||
"OnModelAfterCreateSuccess": 3,
|
||||
"OnRecordCreate": 3,
|
||||
"OnRecordCreateExecute": 3,
|
||||
"OnRecordAfterCreateSuccess": 3,
|
||||
|
||||
"OnRecordUpdateRequest": 1,
|
||||
"OnModelUpdate": 1,
|
||||
"OnModelUpdateExecute": 1,
|
||||
"OnModelAfterUpdateSuccess": 1,
|
||||
"OnRecordUpdate": 1,
|
||||
"OnRecordUpdateExecute": 1,
|
||||
"OnRecordAfterUpdateSuccess": 1,
|
||||
},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
records, err := app.FindRecordsByFilter("demo2", `title~"batch"`, "", 0, 0)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if len(records) != 4 {
|
||||
t.Fatalf("Expected %d batch records to be persisted, got %d", 3, len(records))
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "mixed create/update/delete (rules failure)",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/batch",
|
||||
Body: strings.NewReader(`{
|
||||
"requests": [
|
||||
{"method":"POST", "url":"/api/collections/demo2/records", "body": {"title": "batch_create"}},
|
||||
{"method":"DELETE", "url":"/api/collections/demo2/records/achvryl401bhse3"},
|
||||
{"method":"PATCH", "url":"/api/collections/demo3/records/1tmknxy2868d869", "body": {"title": "batch_update"}}
|
||||
]
|
||||
}`),
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{
|
||||
`"data":{`,
|
||||
`"requests":{`,
|
||||
`"2":{"code":"batch_request_failed"`,
|
||||
`"response":{`,
|
||||
},
|
||||
NotExpectedContent: []string{
|
||||
// only demo3 requires authentication
|
||||
`"0":`,
|
||||
`"1":`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnBatchRequest": 1,
|
||||
"OnModelCreate": 1,
|
||||
"OnModelCreateExecute": 1,
|
||||
"OnModelAfterCreateError": 1,
|
||||
"OnModelDelete": 1,
|
||||
"OnModelDeleteExecute": 1,
|
||||
"OnModelAfterDeleteError": 1,
|
||||
"OnModelValidate": 1,
|
||||
"OnRecordCreateRequest": 1,
|
||||
"OnRecordCreate": 1,
|
||||
"OnRecordCreateExecute": 1,
|
||||
"OnRecordAfterCreateError": 1,
|
||||
"OnRecordDeleteRequest": 1,
|
||||
"OnRecordDelete": 1,
|
||||
"OnRecordDeleteExecute": 1,
|
||||
"OnRecordAfterDeleteError": 1,
|
||||
"OnRecordEnrich": 1,
|
||||
"OnRecordValidate": 1,
|
||||
},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
_, err := app.FindFirstRecordByFilter("demo2", `title="batch_create"`)
|
||||
if err == nil {
|
||||
t.Fatal("Expected record to not be created")
|
||||
}
|
||||
|
||||
_, err = app.FindFirstRecordByFilter("demo3", `title="batch_update"`)
|
||||
if err == nil {
|
||||
t.Fatal("Expected record to not be updated")
|
||||
}
|
||||
|
||||
_, err = app.FindRecordById("demo2", "achvryl401bhse3")
|
||||
if err != nil {
|
||||
t.Fatal("Expected record to not be deleted")
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "mixed create/update/delete (rules success)",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/batch",
|
||||
Headers: map[string]string{
|
||||
// test@example.com, clients
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6ImdrMzkwcWVnczR5NDd3biIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoidjg1MXE0cjc5MHJoa25sIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.0ONnm_BsvPRZyDNT31GN1CKUB6uQRxvVvQ-Wc9AZfG0",
|
||||
},
|
||||
Body: strings.NewReader(`{
|
||||
"requests": [
|
||||
{"method":"POST", "url":"/api/collections/demo2/records", "body": {"title": "batch_create"}, "headers": {"Authorization": "ignored"}},
|
||||
{"method":"DELETE", "url":"/api/collections/demo2/records/achvryl401bhse3", "headers": {"Authorization": "ignored"}},
|
||||
{"method":"PATCH", "url":"/api/collections/demo3/records/1tmknxy2868d869", "body": {"title": "batch_update"}, "headers": {"Authorization": "ignored"}}
|
||||
]
|
||||
}`),
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
`"title":"batch_create"`,
|
||||
`"title":"batch_update"`,
|
||||
`"status":200`,
|
||||
`"status":204`,
|
||||
`"body":{`,
|
||||
`"body":null`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnBatchRequest": 1,
|
||||
// ---
|
||||
"OnModelCreate": 1,
|
||||
"OnModelCreateExecute": 1,
|
||||
"OnModelAfterCreateSuccess": 1,
|
||||
"OnModelDelete": 1,
|
||||
"OnModelDeleteExecute": 1,
|
||||
"OnModelAfterDeleteSuccess": 1,
|
||||
"OnModelUpdate": 1,
|
||||
"OnModelUpdateExecute": 1,
|
||||
"OnModelAfterUpdateSuccess": 1,
|
||||
"OnModelValidate": 2,
|
||||
// ---
|
||||
"OnRecordCreateRequest": 1,
|
||||
"OnRecordCreate": 1,
|
||||
"OnRecordCreateExecute": 1,
|
||||
"OnRecordAfterCreateSuccess": 1,
|
||||
"OnRecordDeleteRequest": 1,
|
||||
"OnRecordDelete": 1,
|
||||
"OnRecordDeleteExecute": 1,
|
||||
"OnRecordAfterDeleteSuccess": 1,
|
||||
"OnRecordUpdateRequest": 1,
|
||||
"OnRecordUpdate": 1,
|
||||
"OnRecordUpdateExecute": 1,
|
||||
"OnRecordAfterUpdateSuccess": 1,
|
||||
"OnRecordValidate": 2,
|
||||
"OnRecordEnrich": 2,
|
||||
},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
_, err := app.FindFirstRecordByFilter("demo2", `title="batch_create"`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
_, err = app.FindFirstRecordByFilter("demo3", `title="batch_update"`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
_, err = app.FindRecordById("demo2", "achvryl401bhse3")
|
||||
if err == nil {
|
||||
t.Fatal("Expected record to be deleted")
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "mixed create/update/delete (superuser auth)",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/batch",
|
||||
Headers: map[string]string{
|
||||
// test@example.com, _superusers
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
|
||||
},
|
||||
Body: strings.NewReader(`{
|
||||
"requests": [
|
||||
{"method":"POST", "url":"/api/collections/demo2/records", "body": {"title": "batch_create"}},
|
||||
{"method":"DELETE", "url":"/api/collections/demo2/records/achvryl401bhse3"},
|
||||
{"method":"PATCH", "url":"/api/collections/demo3/records/1tmknxy2868d869", "body": {"title": "batch_update"}}
|
||||
]
|
||||
}`),
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
`"title":"batch_create"`,
|
||||
`"title":"batch_update"`,
|
||||
`"status":200`,
|
||||
`"status":204`,
|
||||
`"body":{`,
|
||||
`"body":null`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnBatchRequest": 1,
|
||||
// ---
|
||||
"OnModelCreate": 1,
|
||||
"OnModelCreateExecute": 1,
|
||||
"OnModelAfterCreateSuccess": 1,
|
||||
"OnModelDelete": 1,
|
||||
"OnModelDeleteExecute": 1,
|
||||
"OnModelAfterDeleteSuccess": 1,
|
||||
"OnModelUpdate": 1,
|
||||
"OnModelUpdateExecute": 1,
|
||||
"OnModelAfterUpdateSuccess": 1,
|
||||
"OnModelValidate": 2,
|
||||
// ---
|
||||
"OnRecordCreateRequest": 1,
|
||||
"OnRecordCreate": 1,
|
||||
"OnRecordCreateExecute": 1,
|
||||
"OnRecordAfterCreateSuccess": 1,
|
||||
"OnRecordDeleteRequest": 1,
|
||||
"OnRecordDelete": 1,
|
||||
"OnRecordDeleteExecute": 1,
|
||||
"OnRecordAfterDeleteSuccess": 1,
|
||||
"OnRecordUpdateRequest": 1,
|
||||
"OnRecordUpdate": 1,
|
||||
"OnRecordUpdateExecute": 1,
|
||||
"OnRecordAfterUpdateSuccess": 1,
|
||||
"OnRecordValidate": 2,
|
||||
"OnRecordEnrich": 2,
|
||||
},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
_, err := app.FindFirstRecordByFilter("demo2", `title="batch_create"`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
_, err = app.FindFirstRecordByFilter("demo3", `title="batch_update"`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
_, err = app.FindRecordById("demo2", "achvryl401bhse3")
|
||||
if err == nil {
|
||||
t.Fatal("Expected record to be deleted")
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "cascade delete/update",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/batch",
|
||||
Headers: map[string]string{
|
||||
// test@example.com, _superusers
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
|
||||
},
|
||||
Body: strings.NewReader(`{
|
||||
"requests": [
|
||||
{"method":"DELETE", "url":"/api/collections/demo3/records/1tmknxy2868d869"},
|
||||
{"method":"DELETE", "url":"/api/collections/demo3/records/mk5fmymtx4wsprk"}
|
||||
]
|
||||
}`),
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
`"status":204`,
|
||||
`"body":null`,
|
||||
},
|
||||
NotExpectedContent: []string{
|
||||
`"status":200`,
|
||||
`"body":{`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnBatchRequest": 1,
|
||||
// ---
|
||||
"OnModelDelete": 3, // 2 batch + 1 cascade delete
|
||||
"OnModelDeleteExecute": 3,
|
||||
"OnModelAfterDeleteSuccess": 3,
|
||||
"OnModelUpdate": 5, // 5 cascade update
|
||||
"OnModelUpdateExecute": 5,
|
||||
"OnModelAfterUpdateSuccess": 5,
|
||||
// ---
|
||||
"OnRecordDeleteRequest": 2,
|
||||
"OnRecordDelete": 3,
|
||||
"OnRecordDeleteExecute": 3,
|
||||
"OnRecordAfterDeleteSuccess": 3,
|
||||
"OnRecordUpdate": 5,
|
||||
"OnRecordUpdateExecute": 5,
|
||||
"OnRecordAfterUpdateSuccess": 5,
|
||||
},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
ids := []string{
|
||||
"1tmknxy2868d869",
|
||||
"mk5fmymtx4wsprk",
|
||||
"qzaqccwrmva4o1n",
|
||||
}
|
||||
|
||||
for _, id := range ids {
|
||||
_, err := app.FindRecordById("demo2", id)
|
||||
if err == nil {
|
||||
t.Fatalf("Expected record %q to be deleted", id)
|
||||
}
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "transaction timeout",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/batch",
|
||||
Body: strings.NewReader(`{
|
||||
"requests": [
|
||||
{"method":"POST", "url":"/api/collections/demo2/records", "body": {"title": "batch1"}},
|
||||
{"method":"POST", "url":"/api/collections/demo2/records", "body": {"title": "batch2"}}
|
||||
]
|
||||
}`),
|
||||
Headers: map[string]string{
|
||||
// test@example.com, _superusers
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
app.Settings().Batch.Timeout = 1
|
||||
app.OnRecordCreateRequest("demo2").BindFunc(func(e *core.RecordRequestEvent) error {
|
||||
time.Sleep(600 * time.Millisecond) // < 1s so that the first request can succeed
|
||||
return e.Next()
|
||||
})
|
||||
},
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{
|
||||
`"data":{}`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnBatchRequest": 1,
|
||||
"OnRecordCreateRequest": 2,
|
||||
"OnModelCreate": 1,
|
||||
"OnModelCreateExecute": 1,
|
||||
"OnModelAfterCreateError": 1,
|
||||
"OnModelValidate": 1,
|
||||
"OnRecordCreate": 1,
|
||||
"OnRecordCreateExecute": 1,
|
||||
"OnRecordAfterCreateError": 1,
|
||||
"OnRecordEnrich": 1,
|
||||
"OnRecordValidate": 1,
|
||||
},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
records, err := app.FindRecordsByFilter("demo2", `title~"batch"`, "", 0, 0)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if len(records) != 0 {
|
||||
t.Fatalf("Expected %d batch records to be persisted, got %d", 0, len(records))
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "multipart/form-data + file upload",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/batch",
|
||||
Body: formData,
|
||||
Headers: map[string]string{
|
||||
// test@example.com, clients
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6ImdrMzkwcWVnczR5NDd3biIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoidjg1MXE0cjc5MHJoa25sIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.0ONnm_BsvPRZyDNT31GN1CKUB6uQRxvVvQ-Wc9AZfG0",
|
||||
"Content-Type": mp.FormDataContentType(),
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
`"title":"batch1"`,
|
||||
`"title":"batch2"`,
|
||||
`"title":"batch3"`,
|
||||
`"id":"lcl9d87w22ml6jy"`,
|
||||
`"files":["300_UhLKX91HVb.png"]`,
|
||||
`"tmpfile_`,
|
||||
`"status":200`,
|
||||
`"body":{`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnBatchRequest": 1,
|
||||
// ---
|
||||
"OnModelCreate": 3,
|
||||
"OnModelCreateExecute": 3,
|
||||
"OnModelAfterCreateSuccess": 3,
|
||||
"OnModelUpdate": 1,
|
||||
"OnModelUpdateExecute": 1,
|
||||
"OnModelAfterUpdateSuccess": 1,
|
||||
"OnModelValidate": 4,
|
||||
// ---
|
||||
"OnRecordCreateRequest": 3,
|
||||
"OnRecordUpdateRequest": 1,
|
||||
"OnRecordCreate": 3,
|
||||
"OnRecordCreateExecute": 3,
|
||||
"OnRecordAfterCreateSuccess": 3,
|
||||
"OnRecordUpdate": 1,
|
||||
"OnRecordUpdateExecute": 1,
|
||||
"OnRecordAfterUpdateSuccess": 1,
|
||||
"OnRecordValidate": 4,
|
||||
"OnRecordEnrich": 4,
|
||||
},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
batch1, err := app.FindFirstRecordByFilter("demo3", `title="batch1"`)
|
||||
if err != nil {
|
||||
t.Fatalf("missing batch1: %v", err)
|
||||
}
|
||||
batch1Files := batch1.GetStringSlice("files")
|
||||
if len(batch1Files) != 3 {
|
||||
t.Fatalf("Expected %d batch1 file(s), got %d", 3, len(batch1Files))
|
||||
}
|
||||
|
||||
batch2, err := app.FindFirstRecordByFilter("demo3", `title="batch2"`)
|
||||
if err != nil {
|
||||
t.Fatalf("missing batch2: %v", err)
|
||||
}
|
||||
batch2Files := batch2.GetStringSlice("files")
|
||||
if len(batch2Files) != 0 {
|
||||
t.Fatalf("Expected %d batch2 file(s), got %d", 0, len(batch2Files))
|
||||
}
|
||||
|
||||
batch3, err := app.FindFirstRecordByFilter("demo3", `title="batch3"`)
|
||||
if err != nil {
|
||||
t.Fatalf("missing batch3: %v", err)
|
||||
}
|
||||
batch3Files := batch3.GetStringSlice("files")
|
||||
if len(batch3Files) != 1 {
|
||||
t.Fatalf("Expected %d batch3 file(s), got %d", 1, len(batch3Files))
|
||||
}
|
||||
|
||||
batch4, err := app.FindRecordById("demo3", "lcl9d87w22ml6jy")
|
||||
if err != nil {
|
||||
t.Fatalf("missing batch4: %v", err)
|
||||
}
|
||||
batch4Files := batch4.GetStringSlice("files")
|
||||
if len(batch4Files) != 1 {
|
||||
t.Fatalf("Expected %d batch4 file(s), got %d", 1, len(batch4Files))
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "create/update with expand query params",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/batch",
|
||||
Headers: map[string]string{
|
||||
// test@example.com, _superusers
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
|
||||
},
|
||||
Body: strings.NewReader(`{
|
||||
"requests": [
|
||||
{"method":"POST", "url":"/api/collections/demo5/records?expand=rel_one", "body": {"total": 9, "rel_one":"qzaqccwrmva4o1n"}},
|
||||
{"method":"PATCH", "url":"/api/collections/demo5/records/qjeql998mtp1azp?expand=rel_many", "body": {"total": 10}}
|
||||
]
|
||||
}`),
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
`"body":{`,
|
||||
`"id":"qjeql998mtp1azp"`,
|
||||
`"id":"qzaqccwrmva4o1n"`,
|
||||
`"id":"i9naidtvr6qsgb4"`,
|
||||
`"expand":{"rel_one"`,
|
||||
`"expand":{"rel_many"`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnBatchRequest": 1,
|
||||
// ---
|
||||
"OnModelCreate": 1,
|
||||
"OnModelCreateExecute": 1,
|
||||
"OnModelAfterCreateSuccess": 1,
|
||||
"OnModelUpdate": 1,
|
||||
"OnModelUpdateExecute": 1,
|
||||
"OnModelAfterUpdateSuccess": 1,
|
||||
"OnModelValidate": 2,
|
||||
// ---
|
||||
"OnRecordCreateRequest": 1,
|
||||
"OnRecordUpdateRequest": 1,
|
||||
"OnRecordCreate": 1,
|
||||
"OnRecordCreateExecute": 1,
|
||||
"OnRecordAfterCreateSuccess": 1,
|
||||
"OnRecordUpdate": 1,
|
||||
"OnRecordUpdateExecute": 1,
|
||||
"OnRecordAfterUpdateSuccess": 1,
|
||||
"OnRecordValidate": 2,
|
||||
"OnRecordEnrich": 5,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "check body limit middleware",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/batch",
|
||||
Headers: map[string]string{
|
||||
// test@example.com, _superusers
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
|
||||
},
|
||||
Body: strings.NewReader(`{
|
||||
"requests": [
|
||||
{"method":"POST", "url":"/api/collections/demo5/records?expand=rel_one", "body": {"total": 9, "rel_one":"qzaqccwrmva4o1n"}},
|
||||
{"method":"PATCH", "url":"/api/collections/demo5/records/qjeql998mtp1azp?expand=rel_many", "body": {"total": 10}}
|
||||
]
|
||||
}`),
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
app.Settings().Batch.MaxBodySize = 10
|
||||
},
|
||||
ExpectedStatus: 413,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
206
apis/collection.go
Normal file
206
apis/collection.go
Normal file
|
@ -0,0 +1,206 @@
|
|||
package apis
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
validation "github.com/go-ozzo/ozzo-validation/v4"
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tools/router"
|
||||
"github.com/pocketbase/pocketbase/tools/search"
|
||||
)
|
||||
|
||||
// bindCollectionApi registers the collection api endpoints and the corresponding handlers.
|
||||
func bindCollectionApi(app core.App, rg *router.RouterGroup[*core.RequestEvent]) {
|
||||
subGroup := rg.Group("/collections").Bind(RequireSuperuserAuth())
|
||||
subGroup.GET("", collectionsList)
|
||||
subGroup.POST("", collectionCreate)
|
||||
subGroup.GET("/{collection}", collectionView)
|
||||
subGroup.PATCH("/{collection}", collectionUpdate)
|
||||
subGroup.DELETE("/{collection}", collectionDelete)
|
||||
subGroup.DELETE("/{collection}/truncate", collectionTruncate)
|
||||
subGroup.PUT("/import", collectionsImport)
|
||||
subGroup.GET("/meta/scaffolds", collectionScaffolds)
|
||||
}
|
||||
|
||||
func collectionsList(e *core.RequestEvent) error {
|
||||
fieldResolver := search.NewSimpleFieldResolver(
|
||||
"id", "created", "updated", "name", "system", "type",
|
||||
)
|
||||
|
||||
collections := []*core.Collection{}
|
||||
|
||||
result, err := search.NewProvider(fieldResolver).
|
||||
Query(e.App.CollectionQuery()).
|
||||
ParseAndExec(e.Request.URL.Query().Encode(), &collections)
|
||||
|
||||
if err != nil {
|
||||
return e.BadRequestError("", err)
|
||||
}
|
||||
|
||||
event := new(core.CollectionsListRequestEvent)
|
||||
event.RequestEvent = e
|
||||
event.Collections = collections
|
||||
event.Result = result
|
||||
|
||||
return event.App.OnCollectionsListRequest().Trigger(event, func(e *core.CollectionsListRequestEvent) error {
|
||||
return execAfterSuccessTx(true, e.App, func() error {
|
||||
return e.JSON(http.StatusOK, e.Result)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func collectionView(e *core.RequestEvent) error {
|
||||
collection, err := e.App.FindCachedCollectionByNameOrId(e.Request.PathValue("collection"))
|
||||
if err != nil || collection == nil {
|
||||
return e.NotFoundError("", err)
|
||||
}
|
||||
|
||||
event := new(core.CollectionRequestEvent)
|
||||
event.RequestEvent = e
|
||||
event.Collection = collection
|
||||
|
||||
return e.App.OnCollectionViewRequest().Trigger(event, func(e *core.CollectionRequestEvent) error {
|
||||
return execAfterSuccessTx(true, e.App, func() error {
|
||||
return e.JSON(http.StatusOK, e.Collection)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func collectionCreate(e *core.RequestEvent) error {
|
||||
// populate the minimal required factory collection data (if any)
|
||||
factoryExtract := struct {
|
||||
Type string `form:"type" json:"type"`
|
||||
Name string `form:"name" json:"name"`
|
||||
}{}
|
||||
if err := e.BindBody(&factoryExtract); err != nil {
|
||||
return e.BadRequestError("Failed to load the collection type data due to invalid formatting.", err)
|
||||
}
|
||||
|
||||
// create scaffold
|
||||
collection := core.NewCollection(factoryExtract.Type, factoryExtract.Name)
|
||||
|
||||
// merge the scaffold with the submitted request data
|
||||
if err := e.BindBody(collection); err != nil {
|
||||
return e.BadRequestError("Failed to load the submitted data due to invalid formatting.", err)
|
||||
}
|
||||
|
||||
event := new(core.CollectionRequestEvent)
|
||||
event.RequestEvent = e
|
||||
event.Collection = collection
|
||||
|
||||
return e.App.OnCollectionCreateRequest().Trigger(event, func(e *core.CollectionRequestEvent) error {
|
||||
if err := e.App.Save(e.Collection); err != nil {
|
||||
// validation failure
|
||||
var validationErrors validation.Errors
|
||||
if errors.As(err, &validationErrors) {
|
||||
return e.BadRequestError("Failed to create collection.", validationErrors)
|
||||
}
|
||||
|
||||
// other generic db error
|
||||
return e.BadRequestError("Failed to create collection. Raw error: \n"+err.Error(), nil)
|
||||
}
|
||||
|
||||
return execAfterSuccessTx(true, e.App, func() error {
|
||||
return e.JSON(http.StatusOK, e.Collection)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func collectionUpdate(e *core.RequestEvent) error {
|
||||
collection, err := e.App.FindCollectionByNameOrId(e.Request.PathValue("collection"))
|
||||
if err != nil || collection == nil {
|
||||
return e.NotFoundError("", err)
|
||||
}
|
||||
|
||||
if err := e.BindBody(collection); err != nil {
|
||||
return e.BadRequestError("Failed to load the submitted data due to invalid formatting.", err)
|
||||
}
|
||||
|
||||
event := new(core.CollectionRequestEvent)
|
||||
event.RequestEvent = e
|
||||
event.Collection = collection
|
||||
|
||||
return event.App.OnCollectionUpdateRequest().Trigger(event, func(e *core.CollectionRequestEvent) error {
|
||||
if err := e.App.Save(e.Collection); err != nil {
|
||||
// validation failure
|
||||
var validationErrors validation.Errors
|
||||
if errors.As(err, &validationErrors) {
|
||||
return e.BadRequestError("Failed to update collection.", validationErrors)
|
||||
}
|
||||
|
||||
// other generic db error
|
||||
return e.BadRequestError("Failed to update collection. Raw error: \n"+err.Error(), nil)
|
||||
}
|
||||
|
||||
return execAfterSuccessTx(true, e.App, func() error {
|
||||
return e.JSON(http.StatusOK, e.Collection)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func collectionDelete(e *core.RequestEvent) error {
|
||||
collection, err := e.App.FindCachedCollectionByNameOrId(e.Request.PathValue("collection"))
|
||||
if err != nil || collection == nil {
|
||||
return e.NotFoundError("", err)
|
||||
}
|
||||
|
||||
event := new(core.CollectionRequestEvent)
|
||||
event.RequestEvent = e
|
||||
event.Collection = collection
|
||||
|
||||
return e.App.OnCollectionDeleteRequest().Trigger(event, func(e *core.CollectionRequestEvent) error {
|
||||
if err := e.App.Delete(e.Collection); err != nil {
|
||||
msg := "Failed to delete collection"
|
||||
|
||||
// check fo references
|
||||
refs, _ := e.App.FindCollectionReferences(e.Collection, e.Collection.Id)
|
||||
if len(refs) > 0 {
|
||||
names := make([]string, 0, len(refs))
|
||||
for ref := range refs {
|
||||
names = append(names, ref.Name)
|
||||
}
|
||||
msg += " probably due to existing reference in " + strings.Join(names, ", ")
|
||||
}
|
||||
|
||||
return e.BadRequestError(msg, err)
|
||||
}
|
||||
|
||||
return execAfterSuccessTx(true, e.App, func() error {
|
||||
return e.NoContent(http.StatusNoContent)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func collectionTruncate(e *core.RequestEvent) error {
|
||||
collection, err := e.App.FindCachedCollectionByNameOrId(e.Request.PathValue("collection"))
|
||||
if err != nil || collection == nil {
|
||||
return e.NotFoundError("", err)
|
||||
}
|
||||
|
||||
if collection.IsView() {
|
||||
return e.BadRequestError("View collections cannot be truncated since they don't store their own records.", nil)
|
||||
}
|
||||
|
||||
err = e.App.TruncateCollection(collection)
|
||||
if err != nil {
|
||||
return e.BadRequestError("Failed to truncate collection (most likely due to required cascade delete record references).", err)
|
||||
}
|
||||
|
||||
return e.NoContent(http.StatusNoContent)
|
||||
}
|
||||
|
||||
func collectionScaffolds(e *core.RequestEvent) error {
|
||||
collections := map[string]*core.Collection{
|
||||
core.CollectionTypeBase: core.NewBaseCollection(""),
|
||||
core.CollectionTypeAuth: core.NewAuthCollection(""),
|
||||
core.CollectionTypeView: core.NewViewCollection(""),
|
||||
}
|
||||
|
||||
for _, c := range collections {
|
||||
c.Id = "" // clear autogenerated id
|
||||
}
|
||||
|
||||
return e.JSON(http.StatusOK, collections)
|
||||
}
|
62
apis/collection_import.go
Normal file
62
apis/collection_import.go
Normal file
|
@ -0,0 +1,62 @@
|
|||
package apis
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
|
||||
validation "github.com/go-ozzo/ozzo-validation/v4"
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
)
|
||||
|
||||
func collectionsImport(e *core.RequestEvent) error {
|
||||
form := new(collectionsImportForm)
|
||||
|
||||
err := e.BindBody(form)
|
||||
if err != nil {
|
||||
return firstApiError(err, e.BadRequestError("An error occurred while loading the submitted data.", err))
|
||||
}
|
||||
|
||||
err = form.validate()
|
||||
if err != nil {
|
||||
return firstApiError(err, e.BadRequestError("An error occurred while validating the submitted data.", err))
|
||||
}
|
||||
|
||||
event := new(core.CollectionsImportRequestEvent)
|
||||
event.RequestEvent = e
|
||||
event.CollectionsData = form.Collections
|
||||
event.DeleteMissing = form.DeleteMissing
|
||||
|
||||
return event.App.OnCollectionsImportRequest().Trigger(event, func(e *core.CollectionsImportRequestEvent) error {
|
||||
importErr := e.App.ImportCollections(e.CollectionsData, form.DeleteMissing)
|
||||
if importErr == nil {
|
||||
return execAfterSuccessTx(true, e.App, func() error {
|
||||
return e.NoContent(http.StatusNoContent)
|
||||
})
|
||||
}
|
||||
|
||||
// validation failure
|
||||
var validationErrors validation.Errors
|
||||
if errors.As(importErr, &validationErrors) {
|
||||
return e.BadRequestError("Failed to import collections.", validationErrors)
|
||||
}
|
||||
|
||||
// generic/db failure
|
||||
return e.BadRequestError("Failed to import collections.", validation.Errors{"collections": validation.NewError(
|
||||
"validation_collections_import_failure",
|
||||
"Failed to import the collections configuration. Raw error:\n"+importErr.Error(),
|
||||
)})
|
||||
})
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
|
||||
type collectionsImportForm struct {
|
||||
Collections []map[string]any `form:"collections" json:"collections"`
|
||||
DeleteMissing bool `form:"deleteMissing" json:"deleteMissing"`
|
||||
}
|
||||
|
||||
func (form *collectionsImportForm) validate() error {
|
||||
return validation.ValidateStruct(form,
|
||||
validation.Field(&form.Collections, validation.Required),
|
||||
)
|
||||
}
|
369
apis/collection_import_test.go
Normal file
369
apis/collection_import_test.go
Normal file
|
@ -0,0 +1,369 @@
|
|||
package apis_test
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tests"
|
||||
)
|
||||
|
||||
func TestCollectionsImport(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
totalCollections := 16
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "unauthorized",
|
||||
Method: http.MethodPut,
|
||||
URL: "/api/collections/import",
|
||||
ExpectedStatus: 401,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "authorized as regular user",
|
||||
Method: http.MethodPut,
|
||||
URL: "/api/collections/import",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "authorized as superuser + empty collections",
|
||||
Method: http.MethodPut,
|
||||
URL: "/api/collections/import",
|
||||
Body: strings.NewReader(`{"collections":[]}`),
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
|
||||
},
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{
|
||||
`"data":{`,
|
||||
`"collections":{"code":"validation_required"`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
collections := []*core.Collection{}
|
||||
if err := app.CollectionQuery().All(&collections); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
expected := totalCollections
|
||||
if len(collections) != expected {
|
||||
t.Fatalf("Expected %d collections, got %d", expected, len(collections))
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "authorized as superuser + collections validator failure",
|
||||
Method: http.MethodPut,
|
||||
URL: "/api/collections/import",
|
||||
Body: strings.NewReader(`{
|
||||
"collections":[
|
||||
{"name": "import1"},
|
||||
{
|
||||
"name": "import2",
|
||||
"fields": [
|
||||
{
|
||||
"id": "koih1lqx",
|
||||
"name": "expand",
|
||||
"type": "text"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}`),
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
|
||||
},
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{
|
||||
`"data":{`,
|
||||
`"collections":{"code":"validation_collections_import_failure"`,
|
||||
`import2`,
|
||||
`fields`,
|
||||
},
|
||||
NotExpectedContent: []string{"Raw error:"},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnCollectionsImportRequest": 1,
|
||||
"OnCollectionCreate": 2,
|
||||
"OnCollectionCreateExecute": 2,
|
||||
"OnCollectionAfterCreateError": 2,
|
||||
"OnModelCreate": 2,
|
||||
"OnModelCreateExecute": 2,
|
||||
"OnModelAfterCreateError": 2,
|
||||
},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
collections := []*core.Collection{}
|
||||
if err := app.CollectionQuery().All(&collections); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
expected := totalCollections
|
||||
if len(collections) != expected {
|
||||
t.Fatalf("Expected %d collections, got %d", expected, len(collections))
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "authorized as superuser + non-validator failure",
|
||||
Method: http.MethodPut,
|
||||
URL: "/api/collections/import",
|
||||
Body: strings.NewReader(`{
|
||||
"collections":[
|
||||
{
|
||||
"name": "import1",
|
||||
"fields": [
|
||||
{
|
||||
"id": "koih1lqx",
|
||||
"name": "test",
|
||||
"type": "text"
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "import2",
|
||||
"fields": [
|
||||
{
|
||||
"id": "koih1lqx",
|
||||
"name": "test",
|
||||
"type": "text"
|
||||
}
|
||||
],
|
||||
"indexes": [
|
||||
"create index idx_test on import2 (test)"
|
||||
]
|
||||
}
|
||||
]
|
||||
}`),
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
|
||||
},
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{
|
||||
`"data":{`,
|
||||
`"collections":{"code":"validation_collections_import_failure"`,
|
||||
`Raw error:`,
|
||||
`custom_error`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnCollectionsImportRequest": 1,
|
||||
"OnCollectionCreate": 1,
|
||||
"OnCollectionAfterCreateError": 1,
|
||||
"OnModelCreate": 1,
|
||||
"OnModelAfterCreateError": 1,
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
app.OnCollectionCreate().BindFunc(func(e *core.CollectionEvent) error {
|
||||
return errors.New("custom_error")
|
||||
})
|
||||
},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
collections := []*core.Collection{}
|
||||
if err := app.CollectionQuery().All(&collections); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
expected := totalCollections
|
||||
if len(collections) != expected {
|
||||
t.Fatalf("Expected %d collections, got %d", expected, len(collections))
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "authorized as superuser + successful collections create",
|
||||
Method: http.MethodPut,
|
||||
URL: "/api/collections/import",
|
||||
Body: strings.NewReader(`{
|
||||
"collections":[
|
||||
{
|
||||
"name": "import1",
|
||||
"fields": [
|
||||
{
|
||||
"id": "koih1lqx",
|
||||
"name": "test",
|
||||
"type": "text"
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "import2",
|
||||
"fields": [
|
||||
{
|
||||
"id": "koih1lqx",
|
||||
"name": "test",
|
||||
"type": "text"
|
||||
}
|
||||
],
|
||||
"indexes": [
|
||||
"create index idx_test on import2 (test)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "auth_without_fields",
|
||||
"type": "auth"
|
||||
}
|
||||
]
|
||||
}`),
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
|
||||
},
|
||||
ExpectedStatus: 204,
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnCollectionsImportRequest": 1,
|
||||
"OnCollectionCreate": 3,
|
||||
"OnCollectionCreateExecute": 3,
|
||||
"OnCollectionAfterCreateSuccess": 3,
|
||||
"OnModelCreate": 3,
|
||||
"OnModelCreateExecute": 3,
|
||||
"OnModelAfterCreateSuccess": 3,
|
||||
},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
collections := []*core.Collection{}
|
||||
if err := app.CollectionQuery().All(&collections); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
expected := totalCollections + 3
|
||||
if len(collections) != expected {
|
||||
t.Fatalf("Expected %d collections, got %d", expected, len(collections))
|
||||
}
|
||||
|
||||
indexes, err := app.TableIndexes("import2")
|
||||
if err != nil || indexes["idx_test"] == "" {
|
||||
t.Fatalf("Missing index %s (%v)", "idx_test", err)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "authorized as superuser + create/update/delete",
|
||||
Method: http.MethodPut,
|
||||
URL: "/api/collections/import",
|
||||
Body: strings.NewReader(`{
|
||||
"deleteMissing": true,
|
||||
"collections":[
|
||||
{"name": "test123"},
|
||||
{
|
||||
"id":"wsmn24bux7wo113",
|
||||
"name":"demo1",
|
||||
"fields":[
|
||||
{
|
||||
"id":"_2hlxbmp",
|
||||
"name":"title",
|
||||
"type":"text",
|
||||
"required":true
|
||||
}
|
||||
],
|
||||
"indexes": []
|
||||
}
|
||||
]
|
||||
}`),
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
|
||||
},
|
||||
ExpectedStatus: 204,
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnCollectionsImportRequest": 1,
|
||||
// ---
|
||||
"OnModelCreate": 1,
|
||||
"OnModelCreateExecute": 1,
|
||||
"OnModelAfterCreateSuccess": 1,
|
||||
"OnCollectionCreate": 1,
|
||||
"OnCollectionCreateExecute": 1,
|
||||
"OnCollectionAfterCreateSuccess": 1,
|
||||
// ---
|
||||
"OnModelUpdate": 1,
|
||||
"OnModelUpdateExecute": 1,
|
||||
"OnModelAfterUpdateSuccess": 1,
|
||||
"OnCollectionUpdate": 1,
|
||||
"OnCollectionUpdateExecute": 1,
|
||||
"OnCollectionAfterUpdateSuccess": 1,
|
||||
// ---
|
||||
"OnModelDelete": 14,
|
||||
"OnModelAfterDeleteSuccess": 14,
|
||||
"OnModelDeleteExecute": 14,
|
||||
"OnCollectionDelete": 9,
|
||||
"OnCollectionDeleteExecute": 9,
|
||||
"OnCollectionAfterDeleteSuccess": 9,
|
||||
"OnRecordAfterDeleteSuccess": 5,
|
||||
"OnRecordDelete": 5,
|
||||
"OnRecordDeleteExecute": 5,
|
||||
},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
collections := []*core.Collection{}
|
||||
if err := app.CollectionQuery().All(&collections); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
systemCollections := 0
|
||||
for _, c := range collections {
|
||||
if c.System {
|
||||
systemCollections++
|
||||
}
|
||||
}
|
||||
|
||||
expected := systemCollections + 2
|
||||
if len(collections) != expected {
|
||||
t.Fatalf("Expected %d collections, got %d", expected, len(collections))
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "OnCollectionsImportRequest tx body write check",
|
||||
Method: http.MethodPut,
|
||||
URL: "/api/collections/import",
|
||||
Body: strings.NewReader(`{
|
||||
"deleteMissing": true,
|
||||
"collections":[
|
||||
{"name": "test123"},
|
||||
{
|
||||
"id":"wsmn24bux7wo113",
|
||||
"name":"demo1",
|
||||
"fields":[
|
||||
{
|
||||
"id":"_2hlxbmp",
|
||||
"name":"title",
|
||||
"type":"text",
|
||||
"required":true
|
||||
}
|
||||
],
|
||||
"indexes": []
|
||||
}
|
||||
]
|
||||
}`),
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
app.OnCollectionsImportRequest().BindFunc(func(e *core.CollectionsImportRequestEvent) error {
|
||||
original := e.App
|
||||
return e.App.RunInTransaction(func(txApp core.App) error {
|
||||
e.App = txApp
|
||||
defer func() { e.App = original }()
|
||||
|
||||
if err := e.Next(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return e.BadRequestError("TX_ERROR", nil)
|
||||
})
|
||||
})
|
||||
},
|
||||
ExpectedStatus: 400,
|
||||
ExpectedEvents: map[string]int{"OnCollectionsImportRequest": 1},
|
||||
ExpectedContent: []string{"TX_ERROR"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
1586
apis/collection_test.go
Normal file
1586
apis/collection_test.go
Normal file
File diff suppressed because it is too large
Load diff
59
apis/cron.go
Normal file
59
apis/cron.go
Normal file
|
@ -0,0 +1,59 @@
|
|||
package apis
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tools/cron"
|
||||
"github.com/pocketbase/pocketbase/tools/router"
|
||||
"github.com/pocketbase/pocketbase/tools/routine"
|
||||
)
|
||||
|
||||
// bindCronApi registers the crons api endpoint.
|
||||
func bindCronApi(app core.App, rg *router.RouterGroup[*core.RequestEvent]) {
|
||||
subGroup := rg.Group("/crons").Bind(RequireSuperuserAuth())
|
||||
subGroup.GET("", cronsList)
|
||||
subGroup.POST("/{id}", cronRun)
|
||||
}
|
||||
|
||||
func cronsList(e *core.RequestEvent) error {
|
||||
jobs := e.App.Cron().Jobs()
|
||||
|
||||
slices.SortStableFunc(jobs, func(a, b *cron.Job) int {
|
||||
if strings.HasPrefix(a.Id(), "__pb") {
|
||||
return 1
|
||||
}
|
||||
if strings.HasPrefix(b.Id(), "__pb") {
|
||||
return -1
|
||||
}
|
||||
return strings.Compare(a.Id(), b.Id())
|
||||
})
|
||||
|
||||
return e.JSON(http.StatusOK, jobs)
|
||||
}
|
||||
|
||||
func cronRun(e *core.RequestEvent) error {
|
||||
cronId := e.Request.PathValue("id")
|
||||
|
||||
var foundJob *cron.Job
|
||||
|
||||
jobs := e.App.Cron().Jobs()
|
||||
for _, j := range jobs {
|
||||
if j.Id() == cronId {
|
||||
foundJob = j
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if foundJob == nil {
|
||||
return e.NotFoundError("Missing or invalid cron job", nil)
|
||||
}
|
||||
|
||||
routine.FireAndForget(func() {
|
||||
foundJob.Run()
|
||||
})
|
||||
|
||||
return e.NoContent(http.StatusNoContent)
|
||||
}
|
149
apis/cron_test.go
Normal file
149
apis/cron_test.go
Normal file
|
@ -0,0 +1,149 @@
|
|||
package apis_test
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tests"
|
||||
"github.com/spf13/cast"
|
||||
)
|
||||
|
||||
func TestCronsList(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "unauthorized",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/crons",
|
||||
ExpectedStatus: 401,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "authorized as regular user",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/crons",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "authorized as superuser (empty list)",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/crons",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
app.Cron().RemoveAll()
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{`[]`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "authorized as superuser",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/crons",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
`{"id":"__pbLogsCleanup__","expression":"0 */6 * * *"}`,
|
||||
`{"id":"__pbDBOptimize__","expression":"0 0 * * *"}`,
|
||||
`{"id":"__pbMFACleanup__","expression":"0 * * * *"}`,
|
||||
`{"id":"__pbOTPCleanup__","expression":"0 * * * *"}`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCronsRun(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
beforeTestFunc := func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
app.Cron().Add("test", "* * * * *", func() {
|
||||
app.Store().Set("testJobCalls", cast.ToInt(app.Store().Get("testJobCalls"))+1)
|
||||
})
|
||||
}
|
||||
|
||||
expectedCalls := func(expected int) func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
return func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
total := cast.ToInt(app.Store().Get("testJobCalls"))
|
||||
if total != expected {
|
||||
t.Fatalf("Expected total testJobCalls %d, got %d", expected, total)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "unauthorized",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/crons/test",
|
||||
Delay: 50 * time.Millisecond,
|
||||
BeforeTestFunc: beforeTestFunc,
|
||||
AfterTestFunc: expectedCalls(0),
|
||||
ExpectedStatus: 401,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "authorized as regular user",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/crons/test",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
Delay: 50 * time.Millisecond,
|
||||
BeforeTestFunc: beforeTestFunc,
|
||||
AfterTestFunc: expectedCalls(0),
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "authorized as superuser (missing job)",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/crons/missing",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
|
||||
},
|
||||
Delay: 50 * time.Millisecond,
|
||||
BeforeTestFunc: beforeTestFunc,
|
||||
AfterTestFunc: expectedCalls(0),
|
||||
ExpectedStatus: 404,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "authorized as superuser (existing job)",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/crons/test",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
|
||||
},
|
||||
Delay: 50 * time.Millisecond,
|
||||
BeforeTestFunc: beforeTestFunc,
|
||||
AfterTestFunc: expectedCalls(1),
|
||||
ExpectedStatus: 204,
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
230
apis/file.go
Normal file
230
apis/file.go
Normal file
|
@ -0,0 +1,230 @@
|
|||
package apis
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"os"
|
||||
"runtime"
|
||||
"time"
|
||||
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tools/filesystem"
|
||||
"github.com/pocketbase/pocketbase/tools/list"
|
||||
"github.com/pocketbase/pocketbase/tools/router"
|
||||
"github.com/spf13/cast"
|
||||
"golang.org/x/sync/semaphore"
|
||||
"golang.org/x/sync/singleflight"
|
||||
)
|
||||
|
||||
var imageContentTypes = []string{"image/png", "image/jpg", "image/jpeg", "image/gif", "image/webp"}
|
||||
var defaultThumbSizes = []string{"100x100"}
|
||||
|
||||
// bindFileApi registers the file api endpoints and the corresponding handlers.
|
||||
func bindFileApi(app core.App, rg *router.RouterGroup[*core.RequestEvent]) {
|
||||
maxWorkers := cast.ToInt64(os.Getenv("PB_THUMBS_MAX_WORKERS"))
|
||||
if maxWorkers <= 0 {
|
||||
maxWorkers = int64(runtime.NumCPU() + 2) // the value is arbitrary chosen and may change in the future
|
||||
}
|
||||
|
||||
maxWait := cast.ToInt64(os.Getenv("PB_THUMBS_MAX_WAIT"))
|
||||
if maxWait <= 0 {
|
||||
maxWait = 60
|
||||
}
|
||||
|
||||
api := fileApi{
|
||||
thumbGenPending: new(singleflight.Group),
|
||||
thumbGenSem: semaphore.NewWeighted(maxWorkers),
|
||||
thumbGenMaxWait: time.Duration(maxWait) * time.Second,
|
||||
}
|
||||
|
||||
sub := rg.Group("/files")
|
||||
sub.POST("/token", api.fileToken).Bind(RequireAuth())
|
||||
sub.GET("/{collection}/{recordId}/{filename}", api.download).Bind(collectionPathRateLimit("", "file"))
|
||||
}
|
||||
|
||||
type fileApi struct {
|
||||
// thumbGenSem is a semaphore to prevent too much concurrent
|
||||
// requests generating new thumbs at the same time.
|
||||
thumbGenSem *semaphore.Weighted
|
||||
|
||||
// thumbGenPending represents a group of currently pending
|
||||
// thumb generation processes.
|
||||
thumbGenPending *singleflight.Group
|
||||
|
||||
// thumbGenMaxWait is the maximum waiting time for starting a new
|
||||
// thumb generation process.
|
||||
thumbGenMaxWait time.Duration
|
||||
}
|
||||
|
||||
func (api *fileApi) fileToken(e *core.RequestEvent) error {
|
||||
if e.Auth == nil {
|
||||
return e.UnauthorizedError("Missing auth context.", nil)
|
||||
}
|
||||
|
||||
token, err := e.Auth.NewFileToken()
|
||||
if err != nil {
|
||||
return e.InternalServerError("Failed to generate file token", err)
|
||||
}
|
||||
|
||||
event := new(core.FileTokenRequestEvent)
|
||||
event.RequestEvent = e
|
||||
event.Token = token
|
||||
event.Record = e.Auth
|
||||
|
||||
return e.App.OnFileTokenRequest().Trigger(event, func(e *core.FileTokenRequestEvent) error {
|
||||
return execAfterSuccessTx(true, e.App, func() error {
|
||||
return e.JSON(http.StatusOK, map[string]string{"token": e.Token})
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func (api *fileApi) download(e *core.RequestEvent) error {
|
||||
collection, err := e.App.FindCachedCollectionByNameOrId(e.Request.PathValue("collection"))
|
||||
if err != nil {
|
||||
return e.NotFoundError("", nil)
|
||||
}
|
||||
|
||||
recordId := e.Request.PathValue("recordId")
|
||||
if recordId == "" {
|
||||
return e.NotFoundError("", nil)
|
||||
}
|
||||
|
||||
record, err := e.App.FindRecordById(collection, recordId)
|
||||
if err != nil {
|
||||
return e.NotFoundError("", err)
|
||||
}
|
||||
|
||||
filename := e.Request.PathValue("filename")
|
||||
|
||||
fileField := record.FindFileFieldByFile(filename)
|
||||
if fileField == nil {
|
||||
return e.NotFoundError("", nil)
|
||||
}
|
||||
|
||||
// check whether the request is authorized to view the protected file
|
||||
if fileField.Protected {
|
||||
originalRequestInfo, err := e.RequestInfo()
|
||||
if err != nil {
|
||||
return e.InternalServerError("Failed to load request info", err)
|
||||
}
|
||||
|
||||
token := e.Request.URL.Query().Get("token")
|
||||
authRecord, _ := e.App.FindAuthRecordByToken(token, core.TokenTypeFile)
|
||||
|
||||
// create a shallow copy of the cached request data and adjust it to the current auth record (if any)
|
||||
requestInfo := *originalRequestInfo
|
||||
requestInfo.Context = core.RequestInfoContextProtectedFile
|
||||
requestInfo.Auth = authRecord
|
||||
|
||||
if ok, _ := e.App.CanAccessRecord(record, &requestInfo, record.Collection().ViewRule); !ok {
|
||||
return e.NotFoundError("", errors.New("insufficient permissions to access the file resource"))
|
||||
}
|
||||
}
|
||||
|
||||
baseFilesPath := record.BaseFilesPath()
|
||||
|
||||
// fetch the original view file field related record
|
||||
if collection.IsView() {
|
||||
fileRecord, err := e.App.FindRecordByViewFile(collection.Id, fileField.Name, filename)
|
||||
if err != nil {
|
||||
return e.NotFoundError("", fmt.Errorf("failed to fetch view file field record: %w", err))
|
||||
}
|
||||
baseFilesPath = fileRecord.BaseFilesPath()
|
||||
}
|
||||
|
||||
fsys, err := e.App.NewFilesystem()
|
||||
if err != nil {
|
||||
return e.InternalServerError("Filesystem initialization failure.", err)
|
||||
}
|
||||
defer fsys.Close()
|
||||
|
||||
originalPath := baseFilesPath + "/" + filename
|
||||
servedPath := originalPath
|
||||
servedName := filename
|
||||
|
||||
// check for valid thumb size param
|
||||
thumbSize := e.Request.URL.Query().Get("thumb")
|
||||
if thumbSize != "" && (list.ExistInSlice(thumbSize, defaultThumbSizes) || list.ExistInSlice(thumbSize, fileField.Thumbs)) {
|
||||
// extract the original file meta attributes and check it existence
|
||||
oAttrs, oAttrsErr := fsys.Attributes(originalPath)
|
||||
if oAttrsErr != nil {
|
||||
return e.NotFoundError("", err)
|
||||
}
|
||||
|
||||
// check if it is an image
|
||||
if list.ExistInSlice(oAttrs.ContentType, imageContentTypes) {
|
||||
// add thumb size as file suffix
|
||||
servedName = thumbSize + "_" + filename
|
||||
servedPath = baseFilesPath + "/thumbs_" + filename + "/" + servedName
|
||||
|
||||
// create a new thumb if it doesn't exist
|
||||
if exists, _ := fsys.Exists(servedPath); !exists {
|
||||
if err := api.createThumb(e, fsys, originalPath, servedPath, thumbSize); err != nil {
|
||||
e.App.Logger().Warn(
|
||||
"Fallback to original - failed to create thumb "+servedName,
|
||||
slog.Any("error", err),
|
||||
slog.String("original", originalPath),
|
||||
slog.String("thumb", servedPath),
|
||||
)
|
||||
|
||||
// fallback to the original
|
||||
servedName = filename
|
||||
servedPath = originalPath
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
event := new(core.FileDownloadRequestEvent)
|
||||
event.RequestEvent = e
|
||||
event.Collection = collection
|
||||
event.Record = record
|
||||
event.FileField = fileField
|
||||
event.ServedPath = servedPath
|
||||
event.ServedName = servedName
|
||||
|
||||
// clickjacking shouldn't be a concern when serving uploaded files,
|
||||
// so it safe to unset the global X-Frame-Options to allow files embedding
|
||||
// (note: it is out of the hook to allow users to customize the behavior)
|
||||
e.Response.Header().Del("X-Frame-Options")
|
||||
|
||||
return e.App.OnFileDownloadRequest().Trigger(event, func(e *core.FileDownloadRequestEvent) error {
|
||||
err = execAfterSuccessTx(true, e.App, func() error {
|
||||
return fsys.Serve(e.Response, e.Request, e.ServedPath, e.ServedName)
|
||||
})
|
||||
if err != nil {
|
||||
return e.NotFoundError("", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func (api *fileApi) createThumb(
|
||||
e *core.RequestEvent,
|
||||
fsys *filesystem.System,
|
||||
originalPath string,
|
||||
thumbPath string,
|
||||
thumbSize string,
|
||||
) error {
|
||||
ch := api.thumbGenPending.DoChan(thumbPath, func() (any, error) {
|
||||
ctx, cancel := context.WithTimeout(e.Request.Context(), api.thumbGenMaxWait)
|
||||
defer cancel()
|
||||
|
||||
if err := api.thumbGenSem.Acquire(ctx, 1); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer api.thumbGenSem.Release(1)
|
||||
|
||||
return nil, fsys.CreateThumb(originalPath, thumbPath, thumbSize)
|
||||
})
|
||||
|
||||
res := <-ch
|
||||
|
||||
api.thumbGenPending.Forget(thumbPath)
|
||||
|
||||
return res.Err
|
||||
}
|
504
apis/file_test.go
Normal file
504
apis/file_test.go
Normal file
|
@ -0,0 +1,504 @@
|
|||
package apis_test
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/pocketbase/pocketbase/apis"
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tests"
|
||||
"github.com/pocketbase/pocketbase/tools/types"
|
||||
)
|
||||
|
||||
func TestFileToken(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "unauthorized",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/files/token",
|
||||
ExpectedStatus: 401,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "regular user",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/files/token",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
`"token":"`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnFileTokenRequest": 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "superuser",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/files/token",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
`"token":"`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnFileTokenRequest": 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "hook token overwrite",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/files/token",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
app.OnFileTokenRequest().BindFunc(func(e *core.FileTokenRequestEvent) error {
|
||||
e.Token = "test"
|
||||
return e.Next()
|
||||
})
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
`"token":"test"`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnFileTokenRequest": 1,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileDownload(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
_, currentFile, _, _ := runtime.Caller(0)
|
||||
dataDirRelPath := "../tests/data/"
|
||||
|
||||
testFilePath := filepath.Join(path.Dir(currentFile), dataDirRelPath, "storage/_pb_users_auth_/oap640cot4yru2s/test_kfd2wYLxkz.txt")
|
||||
testImgPath := filepath.Join(path.Dir(currentFile), dataDirRelPath, "storage/_pb_users_auth_/4q1xlclmfloku33/300_1SEi6Q6U72.png")
|
||||
testThumbCropCenterPath := filepath.Join(path.Dir(currentFile), dataDirRelPath, "storage/_pb_users_auth_/4q1xlclmfloku33/thumbs_300_1SEi6Q6U72.png/70x50_300_1SEi6Q6U72.png")
|
||||
testThumbCropTopPath := filepath.Join(path.Dir(currentFile), dataDirRelPath, "storage/_pb_users_auth_/4q1xlclmfloku33/thumbs_300_1SEi6Q6U72.png/70x50t_300_1SEi6Q6U72.png")
|
||||
testThumbCropBottomPath := filepath.Join(path.Dir(currentFile), dataDirRelPath, "storage/_pb_users_auth_/4q1xlclmfloku33/thumbs_300_1SEi6Q6U72.png/70x50b_300_1SEi6Q6U72.png")
|
||||
testThumbFitPath := filepath.Join(path.Dir(currentFile), dataDirRelPath, "storage/_pb_users_auth_/4q1xlclmfloku33/thumbs_300_1SEi6Q6U72.png/70x50f_300_1SEi6Q6U72.png")
|
||||
testThumbZeroWidthPath := filepath.Join(path.Dir(currentFile), dataDirRelPath, "storage/_pb_users_auth_/4q1xlclmfloku33/thumbs_300_1SEi6Q6U72.png/0x50_300_1SEi6Q6U72.png")
|
||||
testThumbZeroHeightPath := filepath.Join(path.Dir(currentFile), dataDirRelPath, "storage/_pb_users_auth_/4q1xlclmfloku33/thumbs_300_1SEi6Q6U72.png/70x0_300_1SEi6Q6U72.png")
|
||||
|
||||
testFile, fileErr := os.ReadFile(testFilePath)
|
||||
if fileErr != nil {
|
||||
t.Fatal(fileErr)
|
||||
}
|
||||
|
||||
testImg, imgErr := os.ReadFile(testImgPath)
|
||||
if imgErr != nil {
|
||||
t.Fatal(imgErr)
|
||||
}
|
||||
|
||||
testThumbCropCenter, thumbErr := os.ReadFile(testThumbCropCenterPath)
|
||||
if thumbErr != nil {
|
||||
t.Fatal(thumbErr)
|
||||
}
|
||||
|
||||
testThumbCropTop, thumbErr := os.ReadFile(testThumbCropTopPath)
|
||||
if thumbErr != nil {
|
||||
t.Fatal(thumbErr)
|
||||
}
|
||||
|
||||
testThumbCropBottom, thumbErr := os.ReadFile(testThumbCropBottomPath)
|
||||
if thumbErr != nil {
|
||||
t.Fatal(thumbErr)
|
||||
}
|
||||
|
||||
testThumbFit, thumbErr := os.ReadFile(testThumbFitPath)
|
||||
if thumbErr != nil {
|
||||
t.Fatal(thumbErr)
|
||||
}
|
||||
|
||||
testThumbZeroWidth, thumbErr := os.ReadFile(testThumbZeroWidthPath)
|
||||
if thumbErr != nil {
|
||||
t.Fatal(thumbErr)
|
||||
}
|
||||
|
||||
testThumbZeroHeight, thumbErr := os.ReadFile(testThumbZeroHeightPath)
|
||||
if thumbErr != nil {
|
||||
t.Fatal(thumbErr)
|
||||
}
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "missing collection",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/files/missing/4q1xlclmfloku33/300_1SEi6Q6U72.png",
|
||||
ExpectedStatus: 404,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "missing record",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/files/_pb_users_auth_/missing/300_1SEi6Q6U72.png",
|
||||
ExpectedStatus: 404,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "missing file",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/files/_pb_users_auth_/4q1xlclmfloku33/missing.png",
|
||||
ExpectedStatus: 404,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "existing image",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/files/_pb_users_auth_/4q1xlclmfloku33/300_1SEi6Q6U72.png",
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{string(testImg)},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnFileDownloadRequest": 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "existing image - missing thumb (should fallback to the original)",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/files/_pb_users_auth_/4q1xlclmfloku33/300_1SEi6Q6U72.png?thumb=999x999",
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{string(testImg)},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnFileDownloadRequest": 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "existing image - existing thumb (crop center)",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/files/_pb_users_auth_/4q1xlclmfloku33/300_1SEi6Q6U72.png?thumb=70x50",
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{string(testThumbCropCenter)},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnFileDownloadRequest": 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "existing image - existing thumb (crop top)",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/files/_pb_users_auth_/4q1xlclmfloku33/300_1SEi6Q6U72.png?thumb=70x50t",
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{string(testThumbCropTop)},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnFileDownloadRequest": 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "existing image - existing thumb (crop bottom)",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/files/_pb_users_auth_/4q1xlclmfloku33/300_1SEi6Q6U72.png?thumb=70x50b",
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{string(testThumbCropBottom)},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnFileDownloadRequest": 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "existing image - existing thumb (fit)",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/files/_pb_users_auth_/4q1xlclmfloku33/300_1SEi6Q6U72.png?thumb=70x50f",
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{string(testThumbFit)},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnFileDownloadRequest": 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "existing image - existing thumb (zero width)",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/files/_pb_users_auth_/4q1xlclmfloku33/300_1SEi6Q6U72.png?thumb=0x50",
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{string(testThumbZeroWidth)},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnFileDownloadRequest": 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "existing image - existing thumb (zero height)",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/files/_pb_users_auth_/4q1xlclmfloku33/300_1SEi6Q6U72.png?thumb=70x0",
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{string(testThumbZeroHeight)},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnFileDownloadRequest": 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "existing non image file - thumb parameter should be ignored",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/files/_pb_users_auth_/oap640cot4yru2s/test_kfd2wYLxkz.txt?thumb=100x100",
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{string(testFile)},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnFileDownloadRequest": 1,
|
||||
},
|
||||
},
|
||||
|
||||
// protected file access checks
|
||||
{
|
||||
Name: "protected file - superuser with expired file token",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/files/demo1/al1h9ijdeojtsjy/300_Jsjq7RdBgA.png?token=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsImV4cCI6MTY0MDk5MTY2MSwidHlwZSI6ImZpbGUiLCJjb2xsZWN0aW9uSWQiOiJwYmNfMzE0MjYzNTgyMyJ9.nqqtqpPhxU0045F4XP_ruAkzAidYBc5oPy9ErN3XBq0",
|
||||
ExpectedStatus: 404,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "protected file - superuser with valid file token",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/files/demo1/al1h9ijdeojtsjy/300_Jsjq7RdBgA.png?token=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsImV4cCI6MjUyNDYwNDQ2MSwidHlwZSI6ImZpbGUiLCJjb2xsZWN0aW9uSWQiOiJwYmNfMzE0MjYzNTgyMyJ9.Lupz541xRvrktwkrl55p5pPCF77T69ZRsohsIcb2dxc",
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{"PNG"},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnFileDownloadRequest": 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "protected file - guest without view access",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/files/demo1/al1h9ijdeojtsjy/300_Jsjq7RdBgA.png",
|
||||
ExpectedStatus: 404,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "protected file - guest with view access",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/files/demo1/al1h9ijdeojtsjy/300_Jsjq7RdBgA.png",
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
// mock public view access
|
||||
c, err := app.FindCachedCollectionByNameOrId("demo1")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to fetch mock collection: %v", err)
|
||||
}
|
||||
c.ViewRule = types.Pointer("")
|
||||
if err := app.UnsafeWithoutHooks().Save(c); err != nil {
|
||||
t.Fatalf("Failed to update mock collection: %v", err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{"PNG"},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnFileDownloadRequest": 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "protected file - auth record without view access",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/files/demo1/al1h9ijdeojtsjy/300_Jsjq7RdBgA.png?token=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImV4cCI6MjUyNDYwNDQ2MSwidHlwZSI6ImZpbGUiLCJjb2xsZWN0aW9uSWQiOiJfcGJfdXNlcnNfYXV0aF8ifQ.nSTLuCPcGpWn2K2l-BFkC3Vlzc-ZTDPByYq8dN1oPSo",
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
// mock restricted user view access
|
||||
c, err := app.FindCachedCollectionByNameOrId("demo1")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to fetch mock collection: %v", err)
|
||||
}
|
||||
c.ViewRule = types.Pointer("@request.auth.verified = true")
|
||||
if err := app.UnsafeWithoutHooks().Save(c); err != nil {
|
||||
t.Fatalf("Failed to update mock collection: %v", err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 404,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "protected file - auth record with view access",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/files/demo1/al1h9ijdeojtsjy/300_Jsjq7RdBgA.png?token=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImV4cCI6MjUyNDYwNDQ2MSwidHlwZSI6ImZpbGUiLCJjb2xsZWN0aW9uSWQiOiJfcGJfdXNlcnNfYXV0aF8ifQ.nSTLuCPcGpWn2K2l-BFkC3Vlzc-ZTDPByYq8dN1oPSo",
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
// mock user view access
|
||||
c, err := app.FindCachedCollectionByNameOrId("demo1")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to fetch mock collection: %v", err)
|
||||
}
|
||||
c.ViewRule = types.Pointer("@request.auth.verified = false")
|
||||
if err := app.UnsafeWithoutHooks().Save(c); err != nil {
|
||||
t.Fatalf("Failed to update mock collection: %v", err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{"PNG"},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnFileDownloadRequest": 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "protected file in view (view's View API rule failure)",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/files/view1/al1h9ijdeojtsjy/300_Jsjq7RdBgA.png?token=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImV4cCI6MjUyNDYwNDQ2MSwidHlwZSI6ImZpbGUiLCJjb2xsZWN0aW9uSWQiOiJfcGJfdXNlcnNfYXV0aF8ifQ.nSTLuCPcGpWn2K2l-BFkC3Vlzc-ZTDPByYq8dN1oPSo",
|
||||
ExpectedStatus: 404,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "protected file in view (view's View API rule success)",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/files/view1/84nmscqy84lsi1t/test_d61b33QdDU.txt?token=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImV4cCI6MjUyNDYwNDQ2MSwidHlwZSI6ImZpbGUiLCJjb2xsZWN0aW9uSWQiOiJfcGJfdXNlcnNfYXV0aF8ifQ.nSTLuCPcGpWn2K2l-BFkC3Vlzc-ZTDPByYq8dN1oPSo",
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{"test"},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnFileDownloadRequest": 1,
|
||||
},
|
||||
},
|
||||
|
||||
// rate limit checks
|
||||
// -----------------------------------------------------------
|
||||
{
|
||||
Name: "RateLimit rule - users:file",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/files/_pb_users_auth_/4q1xlclmfloku33/300_1SEi6Q6U72.png",
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
app.Settings().RateLimits.Enabled = true
|
||||
app.Settings().RateLimits.Rules = []core.RateLimitRule{
|
||||
{MaxRequests: 100, Label: "abc"},
|
||||
{MaxRequests: 100, Label: "*:file"},
|
||||
{MaxRequests: 0, Label: "users:file"},
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 429,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "RateLimit rule - *:file",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/files/_pb_users_auth_/4q1xlclmfloku33/300_1SEi6Q6U72.png",
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
app.Settings().RateLimits.Enabled = true
|
||||
app.Settings().RateLimits.Rules = []core.RateLimitRule{
|
||||
{MaxRequests: 100, Label: "abc"},
|
||||
{MaxRequests: 0, Label: "*:file"},
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 429,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
// clone for the HEAD test (the same as the original scenario but without body)
|
||||
head := scenario
|
||||
head.Method = http.MethodHead
|
||||
head.Name = ("(HEAD) " + scenario.Name)
|
||||
head.ExpectedContent = nil
|
||||
head.Test(t)
|
||||
|
||||
// regular request test
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConcurrentThumbsGeneration(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
app, err := tests.NewTestApp()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer app.Cleanup()
|
||||
|
||||
fsys, err := app.NewFilesystem()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer fsys.Close()
|
||||
|
||||
// create a dummy file field collection
|
||||
demo1, err := app.FindCollectionByNameOrId("demo1")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
fileField := demo1.Fields.GetByName("file_one").(*core.FileField)
|
||||
fileField.Protected = false
|
||||
fileField.MaxSelect = 1
|
||||
fileField.MaxSize = 999999
|
||||
// new thumbs
|
||||
fileField.Thumbs = []string{"111x111", "111x222", "111x333"}
|
||||
demo1.Fields.Add(fileField)
|
||||
if err = app.Save(demo1); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
fileKey := "wsmn24bux7wo113/al1h9ijdeojtsjy/300_Jsjq7RdBgA.png"
|
||||
|
||||
urls := []string{
|
||||
"/api/files/" + fileKey + "?thumb=111x111",
|
||||
"/api/files/" + fileKey + "?thumb=111x111", // should still result in single thumb
|
||||
"/api/files/" + fileKey + "?thumb=111x222",
|
||||
"/api/files/" + fileKey + "?thumb=111x333",
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
|
||||
wg.Add(len(urls))
|
||||
|
||||
for _, url := range urls {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
|
||||
req := httptest.NewRequest("GET", url, nil)
|
||||
|
||||
pbRouter, _ := apis.NewRouter(app)
|
||||
mux, _ := pbRouter.BuildMux()
|
||||
if mux != nil {
|
||||
mux.ServeHTTP(recorder, req)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// ensure that all new requested thumbs were created
|
||||
thumbKeys := []string{
|
||||
"wsmn24bux7wo113/al1h9ijdeojtsjy/thumbs_300_Jsjq7RdBgA.png/111x111_" + filepath.Base(fileKey),
|
||||
"wsmn24bux7wo113/al1h9ijdeojtsjy/thumbs_300_Jsjq7RdBgA.png/111x222_" + filepath.Base(fileKey),
|
||||
"wsmn24bux7wo113/al1h9ijdeojtsjy/thumbs_300_Jsjq7RdBgA.png/111x333_" + filepath.Base(fileKey),
|
||||
}
|
||||
for _, k := range thumbKeys {
|
||||
if exists, _ := fsys.Exists(k); !exists {
|
||||
t.Fatalf("Missing thumb %q: %v", k, err)
|
||||
}
|
||||
}
|
||||
}
|
53
apis/health.go
Normal file
53
apis/health.go
Normal file
|
@ -0,0 +1,53 @@
|
|||
package apis
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"slices"
|
||||
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tools/router"
|
||||
)
|
||||
|
||||
// bindHealthApi registers the health api endpoint.
|
||||
func bindHealthApi(app core.App, rg *router.RouterGroup[*core.RequestEvent]) {
|
||||
subGroup := rg.Group("/health")
|
||||
subGroup.GET("", healthCheck)
|
||||
}
|
||||
|
||||
// healthCheck returns a 200 OK response if the server is healthy.
|
||||
func healthCheck(e *core.RequestEvent) error {
|
||||
resp := struct {
|
||||
Message string `json:"message"`
|
||||
Code int `json:"code"`
|
||||
Data map[string]any `json:"data"`
|
||||
}{
|
||||
Code: http.StatusOK,
|
||||
Message: "API is healthy.",
|
||||
}
|
||||
|
||||
if e.HasSuperuserAuth() {
|
||||
resp.Data = make(map[string]any, 3)
|
||||
resp.Data["canBackup"] = !e.App.Store().Has(core.StoreKeyActiveBackup)
|
||||
resp.Data["realIP"] = e.RealIP()
|
||||
|
||||
// loosely check if behind a reverse proxy
|
||||
// (usually used in the dashboard to remind superusers in case deployed behind reverse-proxy)
|
||||
possibleProxyHeader := ""
|
||||
headersToCheck := append(
|
||||
slices.Clone(e.App.Settings().TrustedProxy.Headers),
|
||||
// common proxy headers
|
||||
"CF-Connecting-IP", "Fly-Client-IP", "X‑Forwarded-For",
|
||||
)
|
||||
for _, header := range headersToCheck {
|
||||
if e.Request.Header.Get(header) != "" {
|
||||
possibleProxyHeader = header
|
||||
break
|
||||
}
|
||||
}
|
||||
resp.Data["possibleProxyHeader"] = possibleProxyHeader
|
||||
} else {
|
||||
resp.Data = map[string]any{} // ensure that it is returned as object
|
||||
}
|
||||
|
||||
return e.JSON(http.StatusOK, resp)
|
||||
}
|
71
apis/health_test.go
Normal file
71
apis/health_test.go
Normal file
|
@ -0,0 +1,71 @@
|
|||
package apis_test
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/pocketbase/pocketbase/tests"
|
||||
)
|
||||
|
||||
func TestHealthAPI(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "GET health status (guest)",
|
||||
Method: http.MethodGet, // automatically matches also HEAD as a side-effect of the Go std mux
|
||||
URL: "/api/health",
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
`"code":200`,
|
||||
`"data":{}`,
|
||||
},
|
||||
NotExpectedContent: []string{
|
||||
"canBackup",
|
||||
"realIP",
|
||||
"possibleProxyHeader",
|
||||
},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "GET health status (regular user)",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/health",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
`"code":200`,
|
||||
`"data":{}`,
|
||||
},
|
||||
NotExpectedContent: []string{
|
||||
"canBackup",
|
||||
"realIP",
|
||||
"possibleProxyHeader",
|
||||
},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "GET health status (superuser)",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/health",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
`"code":200`,
|
||||
`"data":{`,
|
||||
`"canBackup":true`,
|
||||
`"realIP"`,
|
||||
`"possibleProxyHeader"`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
88
apis/installer.go
Normal file
88
apis/installer.go
Normal file
|
@ -0,0 +1,88 @@
|
|||
package apis
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/fatih/color"
|
||||
"github.com/pocketbase/dbx"
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tools/osutils"
|
||||
)
|
||||
|
||||
// DefaultInstallerFunc is the default PocketBase installer function.
|
||||
//
|
||||
// It will attempt to open a link in the browser (with a short-lived auth
|
||||
// token for the systemSuperuser) to the installer UI so that users can
|
||||
// create their own custom superuser record.
|
||||
//
|
||||
// See https://github.com/pocketbase/pocketbase/discussions/5814.
|
||||
func DefaultInstallerFunc(app core.App, systemSuperuser *core.Record, baseURL string) error {
|
||||
token, err := systemSuperuser.NewStaticAuthToken(30 * time.Minute)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// launch url (ignore errors and always print a help text as fallback)
|
||||
url := fmt.Sprintf("%s/_/#/pbinstal/%s", strings.TrimRight(baseURL, "/"), token)
|
||||
_ = osutils.LaunchURL(url)
|
||||
color.Magenta("\n(!) Launch the URL below in the browser if it hasn't been open already to create your first superuser account:")
|
||||
color.New(color.Bold).Add(color.FgCyan).Println(url)
|
||||
color.New(color.FgHiBlack, color.Italic).Printf("(you can also create your first superuser by running: %s superuser upsert EMAIL PASS)\n\n", os.Args[0])
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func loadInstaller(
|
||||
app core.App,
|
||||
baseURL string,
|
||||
installerFunc func(app core.App, systemSuperuser *core.Record, baseURL string) error,
|
||||
) error {
|
||||
if installerFunc == nil || !needInstallerSuperuser(app) {
|
||||
return nil
|
||||
}
|
||||
|
||||
superuser, err := findOrCreateInstallerSuperuser(app)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return installerFunc(app, superuser, baseURL)
|
||||
}
|
||||
|
||||
func needInstallerSuperuser(app core.App) bool {
|
||||
total, err := app.CountRecords(core.CollectionNameSuperusers, dbx.Not(dbx.HashExp{
|
||||
"email": core.DefaultInstallerEmail,
|
||||
}))
|
||||
|
||||
return err == nil && total == 0
|
||||
}
|
||||
|
||||
func findOrCreateInstallerSuperuser(app core.App) (*core.Record, error) {
|
||||
col, err := app.FindCachedCollectionByNameOrId(core.CollectionNameSuperusers)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
record, err := app.FindAuthRecordByEmail(col, core.DefaultInstallerEmail)
|
||||
if err != nil {
|
||||
if !errors.Is(err, sql.ErrNoRows) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
record = core.NewRecord(col)
|
||||
record.SetEmail(core.DefaultInstallerEmail)
|
||||
record.SetRandomPassword()
|
||||
|
||||
err = app.Save(record)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return record, nil
|
||||
}
|
73
apis/logs.go
Normal file
73
apis/logs.go
Normal file
|
@ -0,0 +1,73 @@
|
|||
package apis
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/pocketbase/dbx"
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tools/router"
|
||||
"github.com/pocketbase/pocketbase/tools/search"
|
||||
)
|
||||
|
||||
// bindLogsApi registers the request logs api endpoints.
|
||||
func bindLogsApi(app core.App, rg *router.RouterGroup[*core.RequestEvent]) {
|
||||
sub := rg.Group("/logs").Bind(RequireSuperuserAuth(), SkipSuccessActivityLog())
|
||||
sub.GET("", logsList)
|
||||
sub.GET("/stats", logsStats)
|
||||
sub.GET("/{id}", logsView)
|
||||
}
|
||||
|
||||
var logFilterFields = []string{
|
||||
"id", "created", "level", "message", "data",
|
||||
`^data\.[\w\.\:]*\w+$`,
|
||||
}
|
||||
|
||||
func logsList(e *core.RequestEvent) error {
|
||||
fieldResolver := search.NewSimpleFieldResolver(logFilterFields...)
|
||||
|
||||
result, err := search.NewProvider(fieldResolver).
|
||||
Query(e.App.AuxModelQuery(&core.Log{})).
|
||||
ParseAndExec(e.Request.URL.Query().Encode(), &[]*core.Log{})
|
||||
|
||||
if err != nil {
|
||||
return e.BadRequestError("", err)
|
||||
}
|
||||
|
||||
return e.JSON(http.StatusOK, result)
|
||||
}
|
||||
|
||||
func logsStats(e *core.RequestEvent) error {
|
||||
fieldResolver := search.NewSimpleFieldResolver(logFilterFields...)
|
||||
|
||||
filter := e.Request.URL.Query().Get(search.FilterQueryParam)
|
||||
|
||||
var expr dbx.Expression
|
||||
if filter != "" {
|
||||
var err error
|
||||
expr, err = search.FilterData(filter).BuildExpr(fieldResolver)
|
||||
if err != nil {
|
||||
return e.BadRequestError("Invalid filter format.", err)
|
||||
}
|
||||
}
|
||||
|
||||
stats, err := e.App.LogsStats(expr)
|
||||
if err != nil {
|
||||
return e.BadRequestError("Failed to generate logs stats.", err)
|
||||
}
|
||||
|
||||
return e.JSON(http.StatusOK, stats)
|
||||
}
|
||||
|
||||
func logsView(e *core.RequestEvent) error {
|
||||
id := e.Request.PathValue("id")
|
||||
if id == "" {
|
||||
return e.NotFoundError("", nil)
|
||||
}
|
||||
|
||||
log, err := e.App.FindLogById(id)
|
||||
if err != nil || log == nil {
|
||||
return e.NotFoundError("", err)
|
||||
}
|
||||
|
||||
return e.JSON(http.StatusOK, log)
|
||||
}
|
212
apis/logs_test.go
Normal file
212
apis/logs_test.go
Normal file
|
@ -0,0 +1,212 @@
|
|||
package apis_test
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tests"
|
||||
)
|
||||
|
||||
func TestLogsList(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "unauthorized",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/logs",
|
||||
ExpectedStatus: 401,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "authorized as regular user",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/logs",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "authorized as superuser",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/logs",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := tests.StubLogsData(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
`"page":1`,
|
||||
`"perPage":30`,
|
||||
`"totalItems":2`,
|
||||
`"items":[{`,
|
||||
`"id":"873f2133-9f38-44fb-bf82-c8f53b310d91"`,
|
||||
`"id":"f2133873-44fb-9f38-bf82-c918f53b310d"`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "authorized as superuser + filter",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/logs?filter=data.status>200",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := tests.StubLogsData(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
`"page":1`,
|
||||
`"perPage":30`,
|
||||
`"totalItems":1`,
|
||||
`"items":[{`,
|
||||
`"id":"f2133873-44fb-9f38-bf82-c918f53b310d"`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLogView(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "unauthorized",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/logs/873f2133-9f38-44fb-bf82-c8f53b310d91",
|
||||
ExpectedStatus: 401,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "authorized as regular user",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/logs/873f2133-9f38-44fb-bf82-c8f53b310d91",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "authorized as superuser (nonexisting request log)",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/logs/missing1-9f38-44fb-bf82-c8f53b310d91",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := tests.StubLogsData(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 404,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "authorized as superuser (existing request log)",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/logs/873f2133-9f38-44fb-bf82-c8f53b310d91",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := tests.StubLogsData(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
`"id":"873f2133-9f38-44fb-bf82-c8f53b310d91"`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLogsStats(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "unauthorized",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/logs/stats",
|
||||
ExpectedStatus: 401,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "authorized as regular user",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/logs/stats",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "authorized as superuser",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/logs/stats",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := tests.StubLogsData(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
`[{"date":"2022-05-01 10:00:00.000Z","total":1},{"date":"2022-05-02 10:00:00.000Z","total":1}]`,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "authorized as superuser + filter",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/logs/stats?filter=data.status>200",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := tests.StubLogsData(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
`[{"date":"2022-05-02 10:00:00.000Z","total":1}]`,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
444
apis/middlewares.go
Normal file
444
apis/middlewares.go
Normal file
|
@ -0,0 +1,444 @@
|
|||
package apis
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"runtime"
|
||||
"slices"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tools/hook"
|
||||
"github.com/pocketbase/pocketbase/tools/list"
|
||||
"github.com/pocketbase/pocketbase/tools/router"
|
||||
"github.com/pocketbase/pocketbase/tools/routine"
|
||||
"github.com/spf13/cast"
|
||||
)
|
||||
|
||||
// Common request event store keys used by the middlewares and api handlers.
|
||||
const (
|
||||
RequestEventKeyLogMeta = "pbLogMeta" // extra data to store with the request activity log
|
||||
|
||||
requestEventKeyExecStart = "__execStart" // the value must be time.Time
|
||||
requestEventKeySkipSuccessActivityLog = "__skipSuccessActivityLogger" // the value must be bool
|
||||
)
|
||||
|
||||
const (
|
||||
DefaultWWWRedirectMiddlewarePriority = -99999
|
||||
DefaultWWWRedirectMiddlewareId = "pbWWWRedirect"
|
||||
|
||||
DefaultActivityLoggerMiddlewarePriority = DefaultRateLimitMiddlewarePriority - 40
|
||||
DefaultActivityLoggerMiddlewareId = "pbActivityLogger"
|
||||
DefaultSkipSuccessActivityLogMiddlewareId = "pbSkipSuccessActivityLog"
|
||||
DefaultEnableAuthIdActivityLog = "pbEnableAuthIdActivityLog"
|
||||
|
||||
DefaultPanicRecoverMiddlewarePriority = DefaultRateLimitMiddlewarePriority - 30
|
||||
DefaultPanicRecoverMiddlewareId = "pbPanicRecover"
|
||||
|
||||
DefaultLoadAuthTokenMiddlewarePriority = DefaultRateLimitMiddlewarePriority - 20
|
||||
DefaultLoadAuthTokenMiddlewareId = "pbLoadAuthToken"
|
||||
|
||||
DefaultSecurityHeadersMiddlewarePriority = DefaultRateLimitMiddlewarePriority - 10
|
||||
DefaultSecurityHeadersMiddlewareId = "pbSecurityHeaders"
|
||||
|
||||
DefaultRequireGuestOnlyMiddlewareId = "pbRequireGuestOnly"
|
||||
DefaultRequireAuthMiddlewareId = "pbRequireAuth"
|
||||
DefaultRequireSuperuserAuthMiddlewareId = "pbRequireSuperuserAuth"
|
||||
DefaultRequireSuperuserOrOwnerAuthMiddlewareId = "pbRequireSuperuserOrOwnerAuth"
|
||||
DefaultRequireSameCollectionContextAuthMiddlewareId = "pbRequireSameCollectionContextAuth"
|
||||
)
|
||||
|
||||
// RequireGuestOnly middleware requires a request to NOT have a valid
|
||||
// Authorization header.
|
||||
//
|
||||
// This middleware is the opposite of [apis.RequireAuth()].
|
||||
func RequireGuestOnly() *hook.Handler[*core.RequestEvent] {
|
||||
return &hook.Handler[*core.RequestEvent]{
|
||||
Id: DefaultRequireGuestOnlyMiddlewareId,
|
||||
Func: func(e *core.RequestEvent) error {
|
||||
if e.Auth != nil {
|
||||
return router.NewBadRequestError("The request can be accessed only by guests.", nil)
|
||||
}
|
||||
|
||||
return e.Next()
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// RequireAuth middleware requires a request to have a valid record Authorization header.
|
||||
//
|
||||
// The auth record could be from any collection.
|
||||
// You can further filter the allowed record auth collections by specifying their names.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// apis.RequireAuth() // any auth collection
|
||||
// apis.RequireAuth("_superusers", "users") // only the listed auth collections
|
||||
func RequireAuth(optCollectionNames ...string) *hook.Handler[*core.RequestEvent] {
|
||||
return &hook.Handler[*core.RequestEvent]{
|
||||
Id: DefaultRequireAuthMiddlewareId,
|
||||
Func: requireAuth(optCollectionNames...),
|
||||
}
|
||||
}
|
||||
|
||||
func requireAuth(optCollectionNames ...string) func(*core.RequestEvent) error {
|
||||
return func(e *core.RequestEvent) error {
|
||||
if e.Auth == nil {
|
||||
return e.UnauthorizedError("The request requires valid record authorization token.", nil)
|
||||
}
|
||||
|
||||
// check record collection name
|
||||
if len(optCollectionNames) > 0 && !slices.Contains(optCollectionNames, e.Auth.Collection().Name) {
|
||||
return e.ForbiddenError("The authorized record is not allowed to perform this action.", nil)
|
||||
}
|
||||
|
||||
return e.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// RequireSuperuserAuth middleware requires a request to have
|
||||
// a valid superuser Authorization header.
|
||||
func RequireSuperuserAuth() *hook.Handler[*core.RequestEvent] {
|
||||
return &hook.Handler[*core.RequestEvent]{
|
||||
Id: DefaultRequireSuperuserAuthMiddlewareId,
|
||||
Func: requireAuth(core.CollectionNameSuperusers),
|
||||
}
|
||||
}
|
||||
|
||||
// RequireSuperuserOrOwnerAuth middleware requires a request to have
|
||||
// a valid superuser or regular record owner Authorization header set.
|
||||
//
|
||||
// This middleware is similar to [apis.RequireAuth()] but
|
||||
// for the auth record token expects to have the same id as the path
|
||||
// parameter ownerIdPathParam (default to "id" if empty).
|
||||
func RequireSuperuserOrOwnerAuth(ownerIdPathParam string) *hook.Handler[*core.RequestEvent] {
|
||||
return &hook.Handler[*core.RequestEvent]{
|
||||
Id: DefaultRequireSuperuserOrOwnerAuthMiddlewareId,
|
||||
Func: func(e *core.RequestEvent) error {
|
||||
if e.Auth == nil {
|
||||
return e.UnauthorizedError("The request requires superuser or record authorization token.", nil)
|
||||
}
|
||||
|
||||
if e.Auth.IsSuperuser() {
|
||||
return e.Next()
|
||||
}
|
||||
|
||||
if ownerIdPathParam == "" {
|
||||
ownerIdPathParam = "id"
|
||||
}
|
||||
ownerId := e.Request.PathValue(ownerIdPathParam)
|
||||
|
||||
// note: it is considered "safe" to compare only the record id
|
||||
// since the auth record ids are treated as unique across all auth collections
|
||||
if e.Auth.Id != ownerId {
|
||||
return e.ForbiddenError("You are not allowed to perform this request.", nil)
|
||||
}
|
||||
|
||||
return e.Next()
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// RequireSameCollectionContextAuth middleware requires a request to have
|
||||
// a valid record Authorization header and the auth record's collection to
|
||||
// match the one from the route path parameter (default to "collection" if collectionParam is empty).
|
||||
func RequireSameCollectionContextAuth(collectionPathParam string) *hook.Handler[*core.RequestEvent] {
|
||||
return &hook.Handler[*core.RequestEvent]{
|
||||
Id: DefaultRequireSameCollectionContextAuthMiddlewareId,
|
||||
Func: func(e *core.RequestEvent) error {
|
||||
if e.Auth == nil {
|
||||
return e.UnauthorizedError("The request requires valid record authorization token.", nil)
|
||||
}
|
||||
|
||||
if collectionPathParam == "" {
|
||||
collectionPathParam = "collection"
|
||||
}
|
||||
|
||||
collection, _ := e.App.FindCachedCollectionByNameOrId(e.Request.PathValue(collectionPathParam))
|
||||
if collection == nil || e.Auth.Collection().Id != collection.Id {
|
||||
return e.ForbiddenError(fmt.Sprintf("The request requires auth record from %s collection.", e.Auth.Collection().Name), nil)
|
||||
}
|
||||
|
||||
return e.Next()
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// loadAuthToken attempts to load the auth context based on the "Authorization: TOKEN" header value.
|
||||
//
|
||||
// This middleware does nothing in case of:
|
||||
// - missing, invalid or expired token
|
||||
// - e.Auth is already loaded by another middleware
|
||||
//
|
||||
// This middleware is registered by default for all routes.
|
||||
//
|
||||
// Note: We don't throw an error on invalid or expired token to allow
|
||||
// users to extend with their own custom handling in external middleware(s).
|
||||
func loadAuthToken() *hook.Handler[*core.RequestEvent] {
|
||||
return &hook.Handler[*core.RequestEvent]{
|
||||
Id: DefaultLoadAuthTokenMiddlewareId,
|
||||
Priority: DefaultLoadAuthTokenMiddlewarePriority,
|
||||
Func: func(e *core.RequestEvent) error {
|
||||
// already loaded by another middleware
|
||||
if e.Auth != nil {
|
||||
return e.Next()
|
||||
}
|
||||
|
||||
token := getAuthTokenFromRequest(e)
|
||||
if token == "" {
|
||||
return e.Next()
|
||||
}
|
||||
|
||||
record, err := e.App.FindAuthRecordByToken(token, core.TokenTypeAuth)
|
||||
if err != nil {
|
||||
e.App.Logger().Debug("loadAuthToken failure", "error", err)
|
||||
} else if record != nil {
|
||||
e.Auth = record
|
||||
}
|
||||
|
||||
return e.Next()
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func getAuthTokenFromRequest(e *core.RequestEvent) string {
|
||||
token := e.Request.Header.Get("Authorization")
|
||||
if token != "" {
|
||||
// the schema prefix is not required and it is only for
|
||||
// compatibility with the defaults of some HTTP clients
|
||||
token = strings.TrimPrefix(token, "Bearer ")
|
||||
}
|
||||
return token
|
||||
}
|
||||
|
||||
// wwwRedirect performs www->non-www redirect(s) if the request host
|
||||
// matches with one of the values in redirectHosts.
|
||||
//
|
||||
// This middleware is registered by default on Serve for all routes.
|
||||
func wwwRedirect(redirectHosts []string) *hook.Handler[*core.RequestEvent] {
|
||||
return &hook.Handler[*core.RequestEvent]{
|
||||
Id: DefaultWWWRedirectMiddlewareId,
|
||||
Priority: DefaultWWWRedirectMiddlewarePriority,
|
||||
Func: func(e *core.RequestEvent) error {
|
||||
host := e.Request.Host
|
||||
|
||||
if strings.HasPrefix(host, "www.") && list.ExistInSlice(host, redirectHosts) {
|
||||
// note: e.Request.URL.Scheme would be empty
|
||||
schema := "http://"
|
||||
if e.IsTLS() {
|
||||
schema = "https://"
|
||||
}
|
||||
|
||||
return e.Redirect(
|
||||
http.StatusTemporaryRedirect,
|
||||
(schema + host[4:] + e.Request.RequestURI),
|
||||
)
|
||||
}
|
||||
|
||||
return e.Next()
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// panicRecover returns a default panic-recover handler.
|
||||
func panicRecover() *hook.Handler[*core.RequestEvent] {
|
||||
return &hook.Handler[*core.RequestEvent]{
|
||||
Id: DefaultPanicRecoverMiddlewareId,
|
||||
Priority: DefaultPanicRecoverMiddlewarePriority,
|
||||
Func: func(e *core.RequestEvent) (err error) {
|
||||
// panic-recover
|
||||
defer func() {
|
||||
recoverResult := recover()
|
||||
if recoverResult == nil {
|
||||
return
|
||||
}
|
||||
|
||||
recoverErr, ok := recoverResult.(error)
|
||||
if !ok {
|
||||
recoverErr = fmt.Errorf("%v", recoverResult)
|
||||
} else if errors.Is(recoverErr, http.ErrAbortHandler) {
|
||||
// don't recover ErrAbortHandler so the response to the client can be aborted
|
||||
panic(recoverResult)
|
||||
}
|
||||
|
||||
stack := make([]byte, 2<<10) // 2 KB
|
||||
length := runtime.Stack(stack, true)
|
||||
err = e.InternalServerError("", fmt.Errorf("[PANIC RECOVER] %w %s", recoverErr, stack[:length]))
|
||||
}()
|
||||
|
||||
err = e.Next()
|
||||
|
||||
return err
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// securityHeaders middleware adds common security headers to the response.
|
||||
//
|
||||
// This middleware is registered by default for all routes.
|
||||
func securityHeaders() *hook.Handler[*core.RequestEvent] {
|
||||
return &hook.Handler[*core.RequestEvent]{
|
||||
Id: DefaultSecurityHeadersMiddlewareId,
|
||||
Priority: DefaultSecurityHeadersMiddlewarePriority,
|
||||
Func: func(e *core.RequestEvent) error {
|
||||
e.Response.Header().Set("X-XSS-Protection", "1; mode=block")
|
||||
e.Response.Header().Set("X-Content-Type-Options", "nosniff")
|
||||
e.Response.Header().Set("X-Frame-Options", "SAMEORIGIN")
|
||||
|
||||
// @todo consider a default HSTS?
|
||||
// (see also https://webkit.org/blog/8146/protecting-against-hsts-abuse/)
|
||||
|
||||
return e.Next()
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// SkipSuccessActivityLog is a helper middleware that instructs the global
|
||||
// activity logger to log only requests that have failed/returned an error.
|
||||
func SkipSuccessActivityLog() *hook.Handler[*core.RequestEvent] {
|
||||
return &hook.Handler[*core.RequestEvent]{
|
||||
Id: DefaultSkipSuccessActivityLogMiddlewareId,
|
||||
Func: func(e *core.RequestEvent) error {
|
||||
e.Set(requestEventKeySkipSuccessActivityLog, true)
|
||||
return e.Next()
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// activityLogger middleware takes care to save the request information
|
||||
// into the logs database.
|
||||
//
|
||||
// This middleware is registered by default for all routes.
|
||||
//
|
||||
// The middleware does nothing if the app logs retention period is zero
|
||||
// (aka. app.Settings().Logs.MaxDays = 0).
|
||||
//
|
||||
// Users can attach the [apis.SkipSuccessActivityLog()] middleware if
|
||||
// you want to log only the failed requests.
|
||||
func activityLogger() *hook.Handler[*core.RequestEvent] {
|
||||
return &hook.Handler[*core.RequestEvent]{
|
||||
Id: DefaultActivityLoggerMiddlewareId,
|
||||
Priority: DefaultActivityLoggerMiddlewarePriority,
|
||||
Func: func(e *core.RequestEvent) error {
|
||||
e.Set(requestEventKeyExecStart, time.Now())
|
||||
|
||||
err := e.Next()
|
||||
|
||||
logRequest(e, err)
|
||||
|
||||
return err
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func logRequest(event *core.RequestEvent, err error) {
|
||||
// no logs retention
|
||||
if event.App.Settings().Logs.MaxDays == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
// the non-error route has explicitly disabled the activity logger
|
||||
if err == nil && event.Get(requestEventKeySkipSuccessActivityLog) != nil {
|
||||
return
|
||||
}
|
||||
|
||||
attrs := make([]any, 0, 15)
|
||||
|
||||
attrs = append(attrs, slog.String("type", "request"))
|
||||
|
||||
started := cast.ToTime(event.Get(requestEventKeyExecStart))
|
||||
if !started.IsZero() {
|
||||
attrs = append(attrs, slog.Float64("execTime", float64(time.Since(started))/float64(time.Millisecond)))
|
||||
}
|
||||
|
||||
if meta := event.Get(RequestEventKeyLogMeta); meta != nil {
|
||||
attrs = append(attrs, slog.Any("meta", meta))
|
||||
}
|
||||
|
||||
status := event.Status()
|
||||
method := cutStr(strings.ToUpper(event.Request.Method), 50)
|
||||
requestUri := cutStr(event.Request.URL.RequestURI(), 3000)
|
||||
|
||||
// parse the request error
|
||||
if err != nil {
|
||||
apiErr, isPlainApiError := err.(*router.ApiError)
|
||||
if isPlainApiError || errors.As(err, &apiErr) {
|
||||
// the status header wasn't written yet
|
||||
if status == 0 {
|
||||
status = apiErr.Status
|
||||
}
|
||||
|
||||
var errMsg string
|
||||
if isPlainApiError {
|
||||
errMsg = apiErr.Message
|
||||
} else {
|
||||
// wrapped ApiError -> add the full serialized version
|
||||
// of the original error since it could contain more information
|
||||
errMsg = err.Error()
|
||||
}
|
||||
|
||||
attrs = append(
|
||||
attrs,
|
||||
slog.String("error", errMsg),
|
||||
slog.Any("details", apiErr.RawData()),
|
||||
)
|
||||
} else {
|
||||
attrs = append(attrs, slog.String("error", err.Error()))
|
||||
}
|
||||
}
|
||||
|
||||
attrs = append(
|
||||
attrs,
|
||||
slog.String("url", requestUri),
|
||||
slog.String("method", method),
|
||||
slog.Int("status", status),
|
||||
slog.String("referer", cutStr(event.Request.Referer(), 2000)),
|
||||
slog.String("userAgent", cutStr(event.Request.UserAgent(), 2000)),
|
||||
)
|
||||
|
||||
if event.Auth != nil {
|
||||
attrs = append(attrs, slog.String("auth", event.Auth.Collection().Name))
|
||||
|
||||
if event.App.Settings().Logs.LogAuthId {
|
||||
attrs = append(attrs, slog.String("authId", event.Auth.Id))
|
||||
}
|
||||
} else {
|
||||
attrs = append(attrs, slog.String("auth", ""))
|
||||
}
|
||||
|
||||
if event.App.Settings().Logs.LogIP {
|
||||
attrs = append(
|
||||
attrs,
|
||||
slog.String("userIP", event.RealIP()),
|
||||
slog.String("remoteIP", event.RemoteIP()),
|
||||
)
|
||||
}
|
||||
|
||||
// don't block on logs write
|
||||
routine.FireAndForget(func() {
|
||||
message := method + " "
|
||||
|
||||
if escaped, unescapeErr := url.PathUnescape(requestUri); unescapeErr == nil {
|
||||
message += escaped
|
||||
} else {
|
||||
message += requestUri
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
event.App.Logger().Error(message, attrs...)
|
||||
} else {
|
||||
event.App.Logger().Info(message, attrs...)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func cutStr(str string, max int) string {
|
||||
if len(str) > max {
|
||||
return str[:max] + "..."
|
||||
}
|
||||
return str
|
||||
}
|
120
apis/middlewares_body_limit.go
Normal file
120
apis/middlewares_body_limit.go
Normal file
|
@ -0,0 +1,120 @@
|
|||
package apis
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net/http"
|
||||
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tools/hook"
|
||||
"github.com/pocketbase/pocketbase/tools/router"
|
||||
)
|
||||
|
||||
var ErrRequestEntityTooLarge = router.NewApiError(http.StatusRequestEntityTooLarge, "Request entity too large", nil)
|
||||
|
||||
const DefaultMaxBodySize int64 = 32 << 20
|
||||
|
||||
const (
|
||||
DefaultBodyLimitMiddlewareId = "pbBodyLimit"
|
||||
DefaultBodyLimitMiddlewarePriority = DefaultRateLimitMiddlewarePriority + 10
|
||||
)
|
||||
|
||||
// BodyLimit returns a middleware handler that changes the default request body size limit.
|
||||
//
|
||||
// If limitBytes <= 0, no limit is applied.
|
||||
//
|
||||
// Otherwise, if the request body size exceeds the configured limitBytes,
|
||||
// it sends 413 error response.
|
||||
func BodyLimit(limitBytes int64) *hook.Handler[*core.RequestEvent] {
|
||||
return &hook.Handler[*core.RequestEvent]{
|
||||
Id: DefaultBodyLimitMiddlewareId,
|
||||
Priority: DefaultBodyLimitMiddlewarePriority,
|
||||
Func: func(e *core.RequestEvent) error {
|
||||
err := applyBodyLimit(e, limitBytes)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return e.Next()
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func dynamicCollectionBodyLimit(collectionPathParam string) *hook.Handler[*core.RequestEvent] {
|
||||
if collectionPathParam == "" {
|
||||
collectionPathParam = "collection"
|
||||
}
|
||||
|
||||
return &hook.Handler[*core.RequestEvent]{
|
||||
Id: DefaultBodyLimitMiddlewareId,
|
||||
Priority: DefaultBodyLimitMiddlewarePriority,
|
||||
Func: func(e *core.RequestEvent) error {
|
||||
collection, err := e.App.FindCachedCollectionByNameOrId(e.Request.PathValue(collectionPathParam))
|
||||
if err != nil {
|
||||
return e.NotFoundError("Missing or invalid collection context.", err)
|
||||
}
|
||||
|
||||
limitBytes := DefaultMaxBodySize
|
||||
if !collection.IsView() {
|
||||
for _, f := range collection.Fields {
|
||||
if calc, ok := f.(core.MaxBodySizeCalculator); ok {
|
||||
limitBytes += calc.CalculateMaxBodySize()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
err = applyBodyLimit(e, limitBytes)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return e.Next()
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func applyBodyLimit(e *core.RequestEvent, limitBytes int64) error {
|
||||
// no limit
|
||||
if limitBytes <= 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// optimistically check the submitted request content length
|
||||
if e.Request.ContentLength > limitBytes {
|
||||
return ErrRequestEntityTooLarge
|
||||
}
|
||||
|
||||
// replace the request body
|
||||
//
|
||||
// note: we don't use sync.Pool since the size of the elements could vary too much
|
||||
// and it might not be efficient (see https://github.com/golang/go/issues/23199)
|
||||
e.Request.Body = &limitedReader{ReadCloser: e.Request.Body, limit: limitBytes}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type limitedReader struct {
|
||||
io.ReadCloser
|
||||
limit int64
|
||||
totalRead int64
|
||||
}
|
||||
|
||||
func (r *limitedReader) Read(b []byte) (int, error) {
|
||||
n, err := r.ReadCloser.Read(b)
|
||||
if err != nil {
|
||||
return n, err
|
||||
}
|
||||
|
||||
r.totalRead += int64(n)
|
||||
if r.totalRead > r.limit {
|
||||
return n, ErrRequestEntityTooLarge
|
||||
}
|
||||
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func (r *limitedReader) Reread() {
|
||||
rr, ok := r.ReadCloser.(router.Rereader)
|
||||
if ok {
|
||||
rr.Reread()
|
||||
}
|
||||
}
|
60
apis/middlewares_body_limit_test.go
Normal file
60
apis/middlewares_body_limit_test.go
Normal file
|
@ -0,0 +1,60 @@
|
|||
package apis_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/pocketbase/pocketbase/apis"
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tests"
|
||||
)
|
||||
|
||||
func TestBodyLimitMiddleware(t *testing.T) {
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
pbRouter, err := apis.NewRouter(app)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
pbRouter.POST("/a", func(e *core.RequestEvent) error {
|
||||
return e.String(200, "a")
|
||||
}) // default global BodyLimit check
|
||||
|
||||
pbRouter.POST("/b", func(e *core.RequestEvent) error {
|
||||
return e.String(200, "b")
|
||||
}).Bind(apis.BodyLimit(20))
|
||||
|
||||
mux, err := pbRouter.BuildMux()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
scenarios := []struct {
|
||||
url string
|
||||
size int64
|
||||
expectedStatus int
|
||||
}{
|
||||
{"/a", 21, 200},
|
||||
{"/a", apis.DefaultMaxBodySize + 1, 413},
|
||||
{"/b", 20, 200},
|
||||
{"/b", 21, 413},
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
t.Run(fmt.Sprintf("%s_%d", s.url, s.size), func(t *testing.T) {
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest("POST", s.url, bytes.NewReader(make([]byte, s.size)))
|
||||
mux.ServeHTTP(rec, req)
|
||||
|
||||
result := rec.Result()
|
||||
defer result.Body.Close()
|
||||
|
||||
if result.StatusCode != s.expectedStatus {
|
||||
t.Fatalf("Expected response status %d, got %d", s.expectedStatus, result.StatusCode)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
327
apis/middlewares_cors.go
Normal file
327
apis/middlewares_cors.go
Normal file
|
@ -0,0 +1,327 @@
|
|||
package apis
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
// This middleware is ported from echo/middleware to minimize the breaking
|
||||
// changes and differences in the API behavior from earlier PocketBase versions
|
||||
// (https://github.com/labstack/echo/blob/ec5b858dab6105ab4c3ed2627d1ebdfb6ae1ecb8/middleware/cors.go).
|
||||
//
|
||||
// I doubt that this would matter for most cases, but the only major difference
|
||||
// is that for non-supported routes this middleware doesn't return 405 and fallbacks
|
||||
// to the default catch-all PocketBase route (aka. returns 404) to avoid
|
||||
// the extra overhead of further hijacking and wrapping the Go default mux
|
||||
// (https://github.com/golang/go/issues/65648#issuecomment-1955328807).
|
||||
// -------------------------------------------------------------------
|
||||
|
||||
import (
|
||||
"log"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tools/hook"
|
||||
)
|
||||
|
||||
const (
|
||||
DefaultCorsMiddlewareId = "pbCors"
|
||||
DefaultCorsMiddlewarePriority = DefaultActivityLoggerMiddlewarePriority - 1 // before the activity logger and rate limit so that OPTIONS preflight requests are not counted
|
||||
)
|
||||
|
||||
// CORSConfig defines the config for CORS middleware.
|
||||
type CORSConfig struct {
|
||||
// AllowOrigins determines the value of the Access-Control-Allow-Origin
|
||||
// response header. This header defines a list of origins that may access the
|
||||
// resource. The wildcard characters '*' and '?' are supported and are
|
||||
// converted to regex fragments '.*' and '.' accordingly.
|
||||
//
|
||||
// Security: use extreme caution when handling the origin, and carefully
|
||||
// validate any logic. Remember that attackers may register hostile domain names.
|
||||
// See https://blog.portswigger.net/2016/10/exploiting-cors-misconfigurations-for.html
|
||||
//
|
||||
// Optional. Default value []string{"*"}.
|
||||
//
|
||||
// See also: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Origin
|
||||
AllowOrigins []string
|
||||
|
||||
// AllowOriginFunc is a custom function to validate the origin. It takes the
|
||||
// origin as an argument and returns true if allowed or false otherwise. If
|
||||
// an error is returned, it is returned by the handler. If this option is
|
||||
// set, AllowOrigins is ignored.
|
||||
//
|
||||
// Security: use extreme caution when handling the origin, and carefully
|
||||
// validate any logic. Remember that attackers may register hostile domain names.
|
||||
// See https://blog.portswigger.net/2016/10/exploiting-cors-misconfigurations-for.html
|
||||
//
|
||||
// Optional.
|
||||
AllowOriginFunc func(origin string) (bool, error)
|
||||
|
||||
// AllowMethods determines the value of the Access-Control-Allow-Methods
|
||||
// response header. This header specified the list of methods allowed when
|
||||
// accessing the resource. This is used in response to a preflight request.
|
||||
//
|
||||
// Optional. Default value DefaultCORSConfig.AllowMethods.
|
||||
//
|
||||
// See also: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Methods
|
||||
AllowMethods []string
|
||||
|
||||
// AllowHeaders determines the value of the Access-Control-Allow-Headers
|
||||
// response header. This header is used in response to a preflight request to
|
||||
// indicate which HTTP headers can be used when making the actual request.
|
||||
//
|
||||
// Optional. Default value []string{}.
|
||||
//
|
||||
// See also: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Headers
|
||||
AllowHeaders []string
|
||||
|
||||
// AllowCredentials determines the value of the
|
||||
// Access-Control-Allow-Credentials response header. This header indicates
|
||||
// whether or not the response to the request can be exposed when the
|
||||
// credentials mode (Request.credentials) is true. When used as part of a
|
||||
// response to a preflight request, this indicates whether or not the actual
|
||||
// request can be made using credentials. See also
|
||||
// [MDN: Access-Control-Allow-Credentials].
|
||||
//
|
||||
// Optional. Default value false, in which case the header is not set.
|
||||
//
|
||||
// Security: avoid using `AllowCredentials = true` with `AllowOrigins = *`.
|
||||
// See "Exploiting CORS misconfigurations for Bitcoins and bounties",
|
||||
// https://blog.portswigger.net/2016/10/exploiting-cors-misconfigurations-for.html
|
||||
//
|
||||
// See also: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Credentials
|
||||
AllowCredentials bool
|
||||
|
||||
// UnsafeWildcardOriginWithAllowCredentials UNSAFE/INSECURE: allows wildcard '*' origin to be used with AllowCredentials
|
||||
// flag. In that case we consider any origin allowed and send it back to the client with `Access-Control-Allow-Origin` header.
|
||||
//
|
||||
// This is INSECURE and potentially leads to [cross-origin](https://portswigger.net/research/exploiting-cors-misconfigurations-for-bitcoins-and-bounties)
|
||||
// attacks. See: https://github.com/labstack/echo/issues/2400 for discussion on the subject.
|
||||
//
|
||||
// Optional. Default value is false.
|
||||
UnsafeWildcardOriginWithAllowCredentials bool
|
||||
|
||||
// ExposeHeaders determines the value of Access-Control-Expose-Headers, which
|
||||
// defines a list of headers that clients are allowed to access.
|
||||
//
|
||||
// Optional. Default value []string{}, in which case the header is not set.
|
||||
//
|
||||
// See also: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Expose-Header
|
||||
ExposeHeaders []string
|
||||
|
||||
// MaxAge determines the value of the Access-Control-Max-Age response header.
|
||||
// This header indicates how long (in seconds) the results of a preflight
|
||||
// request can be cached.
|
||||
// The header is set only if MaxAge != 0, negative value sends "0" which instructs browsers not to cache that response.
|
||||
//
|
||||
// Optional. Default value 0 - meaning header is not sent.
|
||||
//
|
||||
// See also: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Max-Age
|
||||
MaxAge int
|
||||
}
|
||||
|
||||
// DefaultCORSConfig is the default CORS middleware config.
|
||||
var DefaultCORSConfig = CORSConfig{
|
||||
AllowOrigins: []string{"*"},
|
||||
AllowMethods: []string{http.MethodGet, http.MethodHead, http.MethodPut, http.MethodPatch, http.MethodPost, http.MethodDelete},
|
||||
}
|
||||
|
||||
// CORS returns a CORS middleware.
|
||||
func CORS(config CORSConfig) *hook.Handler[*core.RequestEvent] {
|
||||
// Defaults
|
||||
if len(config.AllowOrigins) == 0 {
|
||||
config.AllowOrigins = DefaultCORSConfig.AllowOrigins
|
||||
}
|
||||
if len(config.AllowMethods) == 0 {
|
||||
config.AllowMethods = DefaultCORSConfig.AllowMethods
|
||||
}
|
||||
|
||||
allowOriginPatterns := make([]*regexp.Regexp, 0, len(config.AllowOrigins))
|
||||
for _, origin := range config.AllowOrigins {
|
||||
if origin == "*" {
|
||||
continue // "*" is handled differently and does not need regexp
|
||||
}
|
||||
|
||||
pattern := regexp.QuoteMeta(origin)
|
||||
pattern = strings.ReplaceAll(pattern, "\\*", ".*")
|
||||
pattern = strings.ReplaceAll(pattern, "\\?", ".")
|
||||
pattern = "^" + pattern + "$"
|
||||
|
||||
re, err := regexp.Compile(pattern)
|
||||
if err != nil {
|
||||
// This is to preserve previous behaviour - invalid patterns were just ignored.
|
||||
// If we would turn this to panic, users with invalid patterns
|
||||
// would have applications crashing in production due unrecovered panic.
|
||||
log.Println("invalid AllowOrigins pattern", origin)
|
||||
continue
|
||||
}
|
||||
allowOriginPatterns = append(allowOriginPatterns, re)
|
||||
}
|
||||
|
||||
allowMethods := strings.Join(config.AllowMethods, ",")
|
||||
allowHeaders := strings.Join(config.AllowHeaders, ",")
|
||||
exposeHeaders := strings.Join(config.ExposeHeaders, ",")
|
||||
|
||||
maxAge := "0"
|
||||
if config.MaxAge > 0 {
|
||||
maxAge = strconv.Itoa(config.MaxAge)
|
||||
}
|
||||
|
||||
return &hook.Handler[*core.RequestEvent]{
|
||||
Id: DefaultCorsMiddlewareId,
|
||||
Priority: DefaultCorsMiddlewarePriority,
|
||||
Func: func(e *core.RequestEvent) error {
|
||||
req := e.Request
|
||||
res := e.Response
|
||||
origin := req.Header.Get("Origin")
|
||||
allowOrigin := ""
|
||||
|
||||
res.Header().Add("Vary", "Origin")
|
||||
|
||||
// Preflight request is an OPTIONS request, using three HTTP request headers: Access-Control-Request-Method,
|
||||
// Access-Control-Request-Headers, and the Origin header. See: https://developer.mozilla.org/en-US/docs/Glossary/Preflight_request
|
||||
// For simplicity we just consider method type and later `Origin` header.
|
||||
preflight := req.Method == http.MethodOptions
|
||||
|
||||
// No Origin provided. This is (probably) not request from actual browser - proceed executing middleware chain
|
||||
if origin == "" {
|
||||
if !preflight {
|
||||
return e.Next()
|
||||
}
|
||||
return e.NoContent(http.StatusNoContent)
|
||||
}
|
||||
|
||||
if config.AllowOriginFunc != nil {
|
||||
allowed, err := config.AllowOriginFunc(origin)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if allowed {
|
||||
allowOrigin = origin
|
||||
}
|
||||
} else {
|
||||
// Check allowed origins
|
||||
for _, o := range config.AllowOrigins {
|
||||
if o == "*" && config.AllowCredentials && config.UnsafeWildcardOriginWithAllowCredentials {
|
||||
allowOrigin = origin
|
||||
break
|
||||
}
|
||||
if o == "*" || o == origin {
|
||||
allowOrigin = o
|
||||
break
|
||||
}
|
||||
if matchSubdomain(origin, o) {
|
||||
allowOrigin = origin
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
checkPatterns := false
|
||||
if allowOrigin == "" {
|
||||
// to avoid regex cost by invalid (long) domains (253 is domain name max limit)
|
||||
if len(origin) <= (253+3+5) && strings.Contains(origin, "://") {
|
||||
checkPatterns = true
|
||||
}
|
||||
}
|
||||
if checkPatterns {
|
||||
for _, re := range allowOriginPatterns {
|
||||
if match := re.MatchString(origin); match {
|
||||
allowOrigin = origin
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Origin not allowed
|
||||
if allowOrigin == "" {
|
||||
if !preflight {
|
||||
return e.Next()
|
||||
}
|
||||
return e.NoContent(http.StatusNoContent)
|
||||
}
|
||||
|
||||
res.Header().Set("Access-Control-Allow-Origin", allowOrigin)
|
||||
if config.AllowCredentials {
|
||||
res.Header().Set("Access-Control-Allow-Credentials", "true")
|
||||
}
|
||||
|
||||
// Simple request
|
||||
if !preflight {
|
||||
if exposeHeaders != "" {
|
||||
res.Header().Set("Access-Control-Expose-Headers", exposeHeaders)
|
||||
}
|
||||
return e.Next()
|
||||
}
|
||||
|
||||
// Preflight request
|
||||
res.Header().Add("Vary", "Access-Control-Request-Method")
|
||||
res.Header().Add("Vary", "Access-Control-Request-Headers")
|
||||
res.Header().Set("Access-Control-Allow-Methods", allowMethods)
|
||||
|
||||
if allowHeaders != "" {
|
||||
res.Header().Set("Access-Control-Allow-Headers", allowHeaders)
|
||||
} else {
|
||||
h := req.Header.Get("Access-Control-Request-Headers")
|
||||
if h != "" {
|
||||
res.Header().Set("Access-Control-Allow-Headers", h)
|
||||
}
|
||||
}
|
||||
if config.MaxAge != 0 {
|
||||
res.Header().Set("Access-Control-Max-Age", maxAge)
|
||||
}
|
||||
|
||||
return e.NoContent(http.StatusNoContent)
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func matchScheme(domain, pattern string) bool {
|
||||
didx := strings.Index(domain, ":")
|
||||
pidx := strings.Index(pattern, ":")
|
||||
return didx != -1 && pidx != -1 && domain[:didx] == pattern[:pidx]
|
||||
}
|
||||
|
||||
// matchSubdomain compares authority with wildcard
|
||||
func matchSubdomain(domain, pattern string) bool {
|
||||
if !matchScheme(domain, pattern) {
|
||||
return false
|
||||
}
|
||||
|
||||
didx := strings.Index(domain, "://")
|
||||
pidx := strings.Index(pattern, "://")
|
||||
if didx == -1 || pidx == -1 {
|
||||
return false
|
||||
}
|
||||
domAuth := domain[didx+3:]
|
||||
// to avoid long loop by invalid long domain
|
||||
if len(domAuth) > 253 {
|
||||
return false
|
||||
}
|
||||
patAuth := pattern[pidx+3:]
|
||||
|
||||
domComp := strings.Split(domAuth, ".")
|
||||
patComp := strings.Split(patAuth, ".")
|
||||
for i := len(domComp)/2 - 1; i >= 0; i-- {
|
||||
opp := len(domComp) - 1 - i
|
||||
domComp[i], domComp[opp] = domComp[opp], domComp[i]
|
||||
}
|
||||
for i := len(patComp)/2 - 1; i >= 0; i-- {
|
||||
opp := len(patComp) - 1 - i
|
||||
patComp[i], patComp[opp] = patComp[opp], patComp[i]
|
||||
}
|
||||
|
||||
for i, v := range domComp {
|
||||
if len(patComp) <= i {
|
||||
return false
|
||||
}
|
||||
p := patComp[i]
|
||||
if p == "*" {
|
||||
return true
|
||||
}
|
||||
if p != v {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
247
apis/middlewares_gzip.go
Normal file
247
apis/middlewares_gzip.go
Normal file
|
@ -0,0 +1,247 @@
|
|||
package apis
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
// This middleware is ported from echo/middleware to minimize the breaking
|
||||
// changes and differences in the API behavior from earlier PocketBase versions
|
||||
// (https://github.com/labstack/echo/blob/ec5b858dab6105ab4c3ed2627d1ebdfb6ae1ecb8/middleware/compress.go).
|
||||
// -------------------------------------------------------------------
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"compress/gzip"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tools/hook"
|
||||
"github.com/pocketbase/pocketbase/tools/router"
|
||||
)
|
||||
|
||||
const (
|
||||
gzipScheme = "gzip"
|
||||
)
|
||||
|
||||
const (
|
||||
DefaultGzipMiddlewareId = "pbGzip"
|
||||
)
|
||||
|
||||
// GzipConfig defines the config for Gzip middleware.
|
||||
type GzipConfig struct {
|
||||
// Gzip compression level.
|
||||
// Optional. Default value -1.
|
||||
Level int
|
||||
|
||||
// Length threshold before gzip compression is applied.
|
||||
// Optional. Default value 0.
|
||||
//
|
||||
// Most of the time you will not need to change the default. Compressing
|
||||
// a short response might increase the transmitted data because of the
|
||||
// gzip format overhead. Compressing the response will also consume CPU
|
||||
// and time on the server and the client (for decompressing). Depending on
|
||||
// your use case such a threshold might be useful.
|
||||
//
|
||||
// See also:
|
||||
// https://webmasters.stackexchange.com/questions/31750/what-is-recommended-minimum-object-size-for-gzip-performance-benefits
|
||||
MinLength int
|
||||
}
|
||||
|
||||
// Gzip returns a middleware which compresses HTTP response using Gzip compression scheme.
|
||||
func Gzip() *hook.Handler[*core.RequestEvent] {
|
||||
return GzipWithConfig(GzipConfig{})
|
||||
}
|
||||
|
||||
// GzipWithConfig returns a middleware which compresses HTTP response using gzip compression scheme.
|
||||
func GzipWithConfig(config GzipConfig) *hook.Handler[*core.RequestEvent] {
|
||||
if config.Level < -2 || config.Level > 9 { // these are consts: gzip.HuffmanOnly and gzip.BestCompression
|
||||
panic(errors.New("invalid gzip level"))
|
||||
}
|
||||
if config.Level == 0 {
|
||||
config.Level = -1
|
||||
}
|
||||
if config.MinLength < 0 {
|
||||
config.MinLength = 0
|
||||
}
|
||||
|
||||
pool := sync.Pool{
|
||||
New: func() interface{} {
|
||||
w, err := gzip.NewWriterLevel(io.Discard, config.Level)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return w
|
||||
},
|
||||
}
|
||||
|
||||
bpool := sync.Pool{
|
||||
New: func() interface{} {
|
||||
b := &bytes.Buffer{}
|
||||
return b
|
||||
},
|
||||
}
|
||||
|
||||
return &hook.Handler[*core.RequestEvent]{
|
||||
Id: DefaultGzipMiddlewareId,
|
||||
Func: func(e *core.RequestEvent) error {
|
||||
e.Response.Header().Add("Vary", "Accept-Encoding")
|
||||
if strings.Contains(e.Request.Header.Get("Accept-Encoding"), gzipScheme) {
|
||||
w, ok := pool.Get().(*gzip.Writer)
|
||||
if !ok {
|
||||
return e.InternalServerError("", errors.New("failed to get gzip.Writer"))
|
||||
}
|
||||
|
||||
rw := e.Response
|
||||
w.Reset(rw)
|
||||
|
||||
buf := bpool.Get().(*bytes.Buffer)
|
||||
buf.Reset()
|
||||
|
||||
grw := &gzipResponseWriter{Writer: w, ResponseWriter: rw, minLength: config.MinLength, buffer: buf}
|
||||
defer func() {
|
||||
// There are different reasons for cases when we have not yet written response to the client and now need to do so.
|
||||
// a) handler response had only response code and no response body (ala 404 or redirects etc). Response code need to be written now.
|
||||
// b) body is shorter than our minimum length threshold and being buffered currently and needs to be written
|
||||
if !grw.wroteBody {
|
||||
if rw.Header().Get("Content-Encoding") == gzipScheme {
|
||||
rw.Header().Del("Content-Encoding")
|
||||
}
|
||||
if grw.wroteHeader {
|
||||
rw.WriteHeader(grw.code)
|
||||
}
|
||||
// We have to reset response to it's pristine state when
|
||||
// nothing is written to body or error is returned.
|
||||
// See issue echo#424, echo#407.
|
||||
e.Response = rw
|
||||
w.Reset(io.Discard)
|
||||
} else if !grw.minLengthExceeded {
|
||||
// Write uncompressed response
|
||||
e.Response = rw
|
||||
if grw.wroteHeader {
|
||||
rw.WriteHeader(grw.code)
|
||||
}
|
||||
grw.buffer.WriteTo(rw)
|
||||
w.Reset(io.Discard)
|
||||
}
|
||||
w.Close()
|
||||
bpool.Put(buf)
|
||||
pool.Put(w)
|
||||
}()
|
||||
e.Response = grw
|
||||
}
|
||||
|
||||
return e.Next()
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
type gzipResponseWriter struct {
|
||||
http.ResponseWriter
|
||||
io.Writer
|
||||
buffer *bytes.Buffer
|
||||
minLength int
|
||||
code int
|
||||
wroteHeader bool
|
||||
wroteBody bool
|
||||
minLengthExceeded bool
|
||||
}
|
||||
|
||||
func (w *gzipResponseWriter) WriteHeader(code int) {
|
||||
w.Header().Del("Content-Length") // Issue echo#444
|
||||
|
||||
w.wroteHeader = true
|
||||
|
||||
// Delay writing of the header until we know if we'll actually compress the response
|
||||
w.code = code
|
||||
}
|
||||
|
||||
func (w *gzipResponseWriter) Write(b []byte) (int, error) {
|
||||
if w.Header().Get("Content-Type") == "" {
|
||||
w.Header().Set("Content-Type", http.DetectContentType(b))
|
||||
}
|
||||
|
||||
w.wroteBody = true
|
||||
|
||||
if !w.minLengthExceeded {
|
||||
n, err := w.buffer.Write(b)
|
||||
|
||||
if w.buffer.Len() >= w.minLength {
|
||||
w.minLengthExceeded = true
|
||||
|
||||
// The minimum length is exceeded, add Content-Encoding header and write the header
|
||||
w.Header().Set("Content-Encoding", gzipScheme)
|
||||
if w.wroteHeader {
|
||||
w.ResponseWriter.WriteHeader(w.code)
|
||||
}
|
||||
|
||||
return w.Writer.Write(w.buffer.Bytes())
|
||||
}
|
||||
|
||||
return n, err
|
||||
}
|
||||
|
||||
return w.Writer.Write(b)
|
||||
}
|
||||
|
||||
func (w *gzipResponseWriter) Flush() {
|
||||
if !w.minLengthExceeded {
|
||||
// Enforce compression because we will not know how much more data will come
|
||||
w.minLengthExceeded = true
|
||||
w.Header().Set("Content-Encoding", gzipScheme)
|
||||
if w.wroteHeader {
|
||||
w.ResponseWriter.WriteHeader(w.code)
|
||||
}
|
||||
|
||||
_, _ = w.Writer.Write(w.buffer.Bytes())
|
||||
}
|
||||
|
||||
_ = w.Writer.(*gzip.Writer).Flush()
|
||||
|
||||
_ = http.NewResponseController(w.ResponseWriter).Flush()
|
||||
}
|
||||
|
||||
func (w *gzipResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||
return http.NewResponseController(w.ResponseWriter).Hijack()
|
||||
}
|
||||
|
||||
func (w *gzipResponseWriter) Push(target string, opts *http.PushOptions) error {
|
||||
rw := w.ResponseWriter
|
||||
for {
|
||||
switch p := rw.(type) {
|
||||
case http.Pusher:
|
||||
return p.Push(target, opts)
|
||||
case router.RWUnwrapper:
|
||||
rw = p.Unwrap()
|
||||
default:
|
||||
return http.ErrNotSupported
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Note: Disable the implementation for now because in case the platform
|
||||
// supports the sendfile fast-path it won't run gzipResponseWriter.Write,
|
||||
// preventing compression on the fly.
|
||||
//
|
||||
// func (w *gzipResponseWriter) ReadFrom(r io.Reader) (n int64, err error) {
|
||||
// if w.wroteHeader {
|
||||
// w.ResponseWriter.WriteHeader(w.code)
|
||||
// }
|
||||
// rw := w.ResponseWriter
|
||||
// for {
|
||||
// switch rf := rw.(type) {
|
||||
// case io.ReaderFrom:
|
||||
// return rf.ReadFrom(r)
|
||||
// case router.RWUnwrapper:
|
||||
// rw = rf.Unwrap()
|
||||
// default:
|
||||
// return io.Copy(w.ResponseWriter, r)
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
|
||||
func (w *gzipResponseWriter) Unwrap() http.ResponseWriter {
|
||||
return w.ResponseWriter
|
||||
}
|
356
apis/middlewares_rate_limit.go
Normal file
356
apis/middlewares_rate_limit.go
Normal file
|
@ -0,0 +1,356 @@
|
|||
package apis
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tools/hook"
|
||||
"github.com/pocketbase/pocketbase/tools/store"
|
||||
)
|
||||
|
||||
const (
|
||||
DefaultRateLimitMiddlewareId = "pbRateLimit"
|
||||
DefaultRateLimitMiddlewarePriority = -1000
|
||||
)
|
||||
|
||||
const (
|
||||
rateLimitersStoreKey = "__pbRateLimiters__"
|
||||
rateLimitersCronKey = "__pbRateLimitersCleanup__"
|
||||
rateLimitersSettingsHookId = "__pbRateLimitersSettingsHook__"
|
||||
)
|
||||
|
||||
// rateLimit defines the global rate limit middleware.
|
||||
//
|
||||
// This middleware is registered by default for all routes.
|
||||
func rateLimit() *hook.Handler[*core.RequestEvent] {
|
||||
return &hook.Handler[*core.RequestEvent]{
|
||||
Id: DefaultRateLimitMiddlewareId,
|
||||
Priority: DefaultRateLimitMiddlewarePriority,
|
||||
Func: func(e *core.RequestEvent) error {
|
||||
if skipRateLimit(e) {
|
||||
return e.Next()
|
||||
}
|
||||
|
||||
rule, ok := e.App.Settings().RateLimits.FindRateLimitRule(
|
||||
defaultRateLimitLabels(e),
|
||||
defaultRateLimitAudience(e)...,
|
||||
)
|
||||
if ok {
|
||||
err := checkRateLimit(e, rule.Label+rule.Audience, rule)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return e.Next()
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// collectionPathRateLimit defines a rate limit middleware for the internal collection handlers.
|
||||
func collectionPathRateLimit(collectionPathParam string, baseTags ...string) *hook.Handler[*core.RequestEvent] {
|
||||
if collectionPathParam == "" {
|
||||
collectionPathParam = "collection"
|
||||
}
|
||||
|
||||
return &hook.Handler[*core.RequestEvent]{
|
||||
Id: DefaultRateLimitMiddlewareId,
|
||||
Priority: DefaultRateLimitMiddlewarePriority,
|
||||
Func: func(e *core.RequestEvent) error {
|
||||
collection, err := e.App.FindCachedCollectionByNameOrId(e.Request.PathValue(collectionPathParam))
|
||||
if err != nil {
|
||||
return e.NotFoundError("Missing or invalid collection context.", err)
|
||||
}
|
||||
|
||||
if err := checkCollectionRateLimit(e, collection, baseTags...); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return e.Next()
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// checkCollectionRateLimit checks whether the current request satisfy the
|
||||
// rate limit configuration for the specific collection.
|
||||
//
|
||||
// Each baseTags entry will be prefixed with the collection name and its wildcard variant.
|
||||
func checkCollectionRateLimit(e *core.RequestEvent, collection *core.Collection, baseTags ...string) error {
|
||||
if skipRateLimit(e) {
|
||||
return nil
|
||||
}
|
||||
|
||||
labels := make([]string, 0, 2+len(baseTags)*2)
|
||||
|
||||
rtId := collection.Id + e.Request.Pattern
|
||||
|
||||
// add first the primary labels (aka. ["collectionName:action1", "collectionName:action2"])
|
||||
for _, baseTag := range baseTags {
|
||||
rtId += baseTag
|
||||
labels = append(labels, collection.Name+":"+baseTag)
|
||||
}
|
||||
|
||||
// add the wildcard labels (aka. [..., "*:action1","*:action2", "*"])
|
||||
for _, baseTag := range baseTags {
|
||||
labels = append(labels, "*:"+baseTag)
|
||||
}
|
||||
labels = append(labels, defaultRateLimitLabels(e)...)
|
||||
|
||||
rule, ok := e.App.Settings().RateLimits.FindRateLimitRule(labels, defaultRateLimitAudience(e)...)
|
||||
if ok {
|
||||
return checkRateLimit(e, rtId+rule.Audience, rule)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
|
||||
// @todo consider exporting as helper?
|
||||
//
|
||||
//nolint:unused
|
||||
func isClientRateLimited(e *core.RequestEvent, rtId string) bool {
|
||||
rateLimiters, ok := e.App.Store().Get(rateLimitersStoreKey).(*store.Store[string, *rateLimiter])
|
||||
if !ok || rateLimiters == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
rt, ok := rateLimiters.GetOk(rtId)
|
||||
if !ok || rt == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
client, ok := rt.getClient(e.RealIP())
|
||||
if !ok || client == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
return client.available <= 0 && time.Now().Unix()-client.lastConsume < client.interval
|
||||
}
|
||||
|
||||
// @todo consider exporting as helper?
|
||||
func checkRateLimit(e *core.RequestEvent, rtId string, rule core.RateLimitRule) error {
|
||||
switch rule.Audience {
|
||||
case core.RateLimitRuleAudienceAll:
|
||||
// valid for both guest and regular users
|
||||
case core.RateLimitRuleAudienceGuest:
|
||||
if e.Auth != nil {
|
||||
return nil
|
||||
}
|
||||
case core.RateLimitRuleAudienceAuth:
|
||||
if e.Auth == nil {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
rateLimiters := e.App.Store().GetOrSet(rateLimitersStoreKey, func() any {
|
||||
return initRateLimitersStore(e.App)
|
||||
}).(*store.Store[string, *rateLimiter])
|
||||
if rateLimiters == nil {
|
||||
e.App.Logger().Warn("Failed to retrieve app rate limiters store")
|
||||
return nil
|
||||
}
|
||||
|
||||
rt := rateLimiters.GetOrSet(rtId, func() *rateLimiter {
|
||||
return newRateLimiter(rule.MaxRequests, rule.Duration, rule.Duration+1800)
|
||||
})
|
||||
if rt == nil {
|
||||
e.App.Logger().Warn("Failed to retrieve app rate limiter", "id", rtId)
|
||||
return nil
|
||||
}
|
||||
|
||||
key := e.RealIP()
|
||||
if key == "" {
|
||||
e.App.Logger().Warn("Empty rate limit client key")
|
||||
return nil
|
||||
}
|
||||
|
||||
if !rt.isAllowed(key) {
|
||||
return e.TooManyRequestsError("", nil)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func skipRateLimit(e *core.RequestEvent) bool {
|
||||
return !e.App.Settings().RateLimits.Enabled || e.HasSuperuserAuth()
|
||||
}
|
||||
|
||||
var defaultAuthAudience = []string{core.RateLimitRuleAudienceAll, core.RateLimitRuleAudienceAuth}
|
||||
var defaultGuestAudience = []string{core.RateLimitRuleAudienceAll, core.RateLimitRuleAudienceGuest}
|
||||
|
||||
func defaultRateLimitAudience(e *core.RequestEvent) []string {
|
||||
if e.Auth != nil {
|
||||
return defaultAuthAudience
|
||||
}
|
||||
|
||||
return defaultGuestAudience
|
||||
}
|
||||
|
||||
func defaultRateLimitLabels(e *core.RequestEvent) []string {
|
||||
return []string{e.Request.Method + " " + e.Request.URL.Path, e.Request.URL.Path}
|
||||
}
|
||||
|
||||
func destroyRateLimitersStore(app core.App) {
|
||||
app.OnSettingsReload().Unbind(rateLimitersSettingsHookId)
|
||||
app.Cron().Remove(rateLimitersCronKey)
|
||||
app.Store().Remove(rateLimitersStoreKey)
|
||||
}
|
||||
|
||||
func initRateLimitersStore(app core.App) *store.Store[string, *rateLimiter] {
|
||||
app.Cron().Add(rateLimitersCronKey, "2 * * * *", func() { // offset a little since too many cleanup tasks execute at 00
|
||||
limitersStore, ok := app.Store().Get(rateLimitersStoreKey).(*store.Store[string, *rateLimiter])
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
limiters := limitersStore.GetAll()
|
||||
for _, limiter := range limiters {
|
||||
limiter.clean()
|
||||
}
|
||||
})
|
||||
|
||||
app.OnSettingsReload().Bind(&hook.Handler[*core.SettingsReloadEvent]{
|
||||
Id: rateLimitersSettingsHookId,
|
||||
Func: func(e *core.SettingsReloadEvent) error {
|
||||
err := e.Next()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// reset
|
||||
destroyRateLimitersStore(e.App)
|
||||
|
||||
return nil
|
||||
},
|
||||
})
|
||||
|
||||
return store.New[string, *rateLimiter](nil)
|
||||
}
|
||||
|
||||
func newRateLimiter(maxAllowed int, intervalInSec int64, minDeleteIntervalInSec int64) *rateLimiter {
|
||||
return &rateLimiter{
|
||||
maxAllowed: maxAllowed,
|
||||
interval: intervalInSec,
|
||||
minDeleteInterval: minDeleteIntervalInSec,
|
||||
clients: map[string]*fixedWindow{},
|
||||
}
|
||||
}
|
||||
|
||||
type rateLimiter struct {
|
||||
clients map[string]*fixedWindow
|
||||
|
||||
maxAllowed int
|
||||
interval int64
|
||||
minDeleteInterval int64
|
||||
totalDeleted int64
|
||||
|
||||
sync.RWMutex
|
||||
}
|
||||
|
||||
//nolint:unused
|
||||
func (rt *rateLimiter) getClient(key string) (*fixedWindow, bool) {
|
||||
rt.RLock()
|
||||
client, ok := rt.clients[key]
|
||||
rt.RUnlock()
|
||||
|
||||
return client, ok
|
||||
}
|
||||
|
||||
func (rt *rateLimiter) isAllowed(key string) bool {
|
||||
// lock only reads to minimize locks contention
|
||||
rt.RLock()
|
||||
client, ok := rt.clients[key]
|
||||
rt.RUnlock()
|
||||
|
||||
if !ok {
|
||||
rt.Lock()
|
||||
// check again in case the client was added by another request
|
||||
client, ok = rt.clients[key]
|
||||
if !ok {
|
||||
client = newFixedWindow(rt.maxAllowed, rt.interval)
|
||||
rt.clients[key] = client
|
||||
}
|
||||
rt.Unlock()
|
||||
}
|
||||
|
||||
return client.consume()
|
||||
}
|
||||
|
||||
func (rt *rateLimiter) clean() {
|
||||
rt.Lock()
|
||||
defer rt.Unlock()
|
||||
|
||||
nowUnix := time.Now().Unix()
|
||||
|
||||
for k, client := range rt.clients {
|
||||
if client.hasExpired(nowUnix, rt.minDeleteInterval) {
|
||||
delete(rt.clients, k)
|
||||
rt.totalDeleted++
|
||||
}
|
||||
}
|
||||
|
||||
// "shrink" the map if too may items were deleted
|
||||
//
|
||||
// @todo remove after https://github.com/golang/go/issues/20135
|
||||
if rt.totalDeleted >= 300 {
|
||||
shrunk := make(map[string]*fixedWindow, len(rt.clients))
|
||||
for k, v := range rt.clients {
|
||||
shrunk[k] = v
|
||||
}
|
||||
rt.clients = shrunk
|
||||
rt.totalDeleted = 0
|
||||
}
|
||||
}
|
||||
|
||||
func newFixedWindow(maxAllowed int, intervalInSec int64) *fixedWindow {
|
||||
return &fixedWindow{
|
||||
maxAllowed: maxAllowed,
|
||||
interval: intervalInSec,
|
||||
}
|
||||
}
|
||||
|
||||
type fixedWindow struct {
|
||||
// use plain Mutex instead of RWMutex since the operations are expected
|
||||
// to be mostly writes (e.g. consume()) and it should perform better
|
||||
sync.Mutex
|
||||
|
||||
maxAllowed int // the max allowed tokens per interval
|
||||
available int // the total available tokens
|
||||
interval int64 // in seconds
|
||||
lastConsume int64 // the time of the last consume
|
||||
}
|
||||
|
||||
// hasExpired checks whether it has been at least minElapsed seconds since the lastConsume time.
|
||||
// (usually used to perform periodic cleanup of staled instances).
|
||||
func (l *fixedWindow) hasExpired(relativeNow int64, minElapsed int64) bool {
|
||||
l.Lock()
|
||||
defer l.Unlock()
|
||||
|
||||
return relativeNow-l.lastConsume > minElapsed
|
||||
}
|
||||
|
||||
// consume decrease the current window allowance with 1 (if not exhausted already).
|
||||
//
|
||||
// It returns false if the allowance has been already exhausted and the user
|
||||
// has to wait until it resets back to its maxAllowed value.
|
||||
func (l *fixedWindow) consume() bool {
|
||||
l.Lock()
|
||||
defer l.Unlock()
|
||||
|
||||
nowUnix := time.Now().Unix()
|
||||
|
||||
// reset consumed counter
|
||||
if nowUnix-l.lastConsume >= l.interval {
|
||||
l.available = l.maxAllowed
|
||||
}
|
||||
|
||||
if l.available > 0 {
|
||||
l.available--
|
||||
l.lastConsume = nowUnix
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
159
apis/middlewares_rate_limit_test.go
Normal file
159
apis/middlewares_rate_limit_test.go
Normal file
|
@ -0,0 +1,159 @@
|
|||
package apis_test
|
||||
|
||||
import (
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/pocketbase/pocketbase/apis"
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tests"
|
||||
)
|
||||
|
||||
func TestDefaultRateLimitMiddleware(t *testing.T) {
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
app.Settings().RateLimits.Enabled = true
|
||||
app.Settings().RateLimits.Rules = []core.RateLimitRule{
|
||||
{
|
||||
Label: "/rate/",
|
||||
MaxRequests: 2,
|
||||
Duration: 1,
|
||||
},
|
||||
{
|
||||
Label: "/rate/b",
|
||||
MaxRequests: 3,
|
||||
Duration: 1,
|
||||
},
|
||||
{
|
||||
Label: "POST /rate/b",
|
||||
MaxRequests: 1,
|
||||
Duration: 1,
|
||||
},
|
||||
{
|
||||
Label: "/rate/guest",
|
||||
MaxRequests: 1,
|
||||
Duration: 1,
|
||||
Audience: core.RateLimitRuleAudienceGuest,
|
||||
},
|
||||
{
|
||||
Label: "/rate/auth",
|
||||
MaxRequests: 1,
|
||||
Duration: 1,
|
||||
Audience: core.RateLimitRuleAudienceAuth,
|
||||
},
|
||||
}
|
||||
|
||||
pbRouter, err := apis.NewRouter(app)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
pbRouter.GET("/norate", func(e *core.RequestEvent) error {
|
||||
return e.String(200, "norate")
|
||||
}).BindFunc(func(e *core.RequestEvent) error {
|
||||
return e.Next()
|
||||
})
|
||||
pbRouter.GET("/rate/a", func(e *core.RequestEvent) error {
|
||||
return e.String(200, "a")
|
||||
})
|
||||
pbRouter.GET("/rate/b", func(e *core.RequestEvent) error {
|
||||
return e.String(200, "b")
|
||||
})
|
||||
pbRouter.GET("/rate/guest", func(e *core.RequestEvent) error {
|
||||
return e.String(200, "guest")
|
||||
})
|
||||
pbRouter.GET("/rate/auth", func(e *core.RequestEvent) error {
|
||||
return e.String(200, "auth")
|
||||
})
|
||||
|
||||
mux, err := pbRouter.BuildMux()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
scenarios := []struct {
|
||||
url string
|
||||
wait float64
|
||||
authenticated bool
|
||||
expectedStatus int
|
||||
}{
|
||||
{"/norate", 0, false, 200},
|
||||
{"/norate", 0, false, 200},
|
||||
{"/norate", 0, false, 200},
|
||||
{"/norate", 0, false, 200},
|
||||
{"/norate", 0, false, 200},
|
||||
|
||||
{"/rate/a", 0, false, 200},
|
||||
{"/rate/a", 0, false, 200},
|
||||
{"/rate/a", 0, false, 429},
|
||||
{"/rate/a", 0, false, 429},
|
||||
{"/rate/a", 1.1, false, 200},
|
||||
{"/rate/a", 0, false, 200},
|
||||
{"/rate/a", 0, false, 429},
|
||||
|
||||
{"/rate/b", 0, false, 200},
|
||||
{"/rate/b", 0, false, 200},
|
||||
{"/rate/b", 0, false, 200},
|
||||
{"/rate/b", 0, false, 429},
|
||||
{"/rate/b", 1.1, false, 200},
|
||||
{"/rate/b", 0, false, 200},
|
||||
{"/rate/b", 0, false, 200},
|
||||
{"/rate/b", 0, false, 429},
|
||||
|
||||
// "auth" with guest (should fallback to the /rate/ rule)
|
||||
{"/rate/auth", 0, false, 200},
|
||||
{"/rate/auth", 0, false, 200},
|
||||
{"/rate/auth", 0, false, 429},
|
||||
{"/rate/auth", 0, false, 429},
|
||||
|
||||
// "auth" rule with regular user (should match the /rate/auth rule)
|
||||
{"/rate/auth", 0, true, 200},
|
||||
{"/rate/auth", 0, true, 429},
|
||||
{"/rate/auth", 0, true, 429},
|
||||
|
||||
// "guest" with guest (should match the /rate/guest rule)
|
||||
{"/rate/guest", 0, false, 200},
|
||||
{"/rate/guest", 0, false, 429},
|
||||
{"/rate/guest", 0, false, 429},
|
||||
|
||||
// "guest" rule with regular user (should fallback to the /rate/ rule)
|
||||
{"/rate/guest", 1, true, 200},
|
||||
{"/rate/guest", 0, true, 200},
|
||||
{"/rate/guest", 0, true, 429},
|
||||
{"/rate/guest", 0, true, 429},
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
t.Run(s.url, func(t *testing.T) {
|
||||
if s.wait > 0 {
|
||||
time.Sleep(time.Duration(s.wait) * time.Second)
|
||||
}
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest("GET", s.url, nil)
|
||||
|
||||
if s.authenticated {
|
||||
auth, err := app.FindAuthRecordByEmail("users", "test@example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
token, err := auth.NewAuthToken()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
req.Header.Add("Authorization", token)
|
||||
}
|
||||
|
||||
mux.ServeHTTP(rec, req)
|
||||
|
||||
result := rec.Result()
|
||||
|
||||
if result.StatusCode != s.expectedStatus {
|
||||
t.Fatalf("Expected response status %d, got %d", s.expectedStatus, result.StatusCode)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
539
apis/middlewares_test.go
Normal file
539
apis/middlewares_test.go
Normal file
|
@ -0,0 +1,539 @@
|
|||
package apis_test
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/pocketbase/pocketbase/apis"
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tests"
|
||||
)
|
||||
|
||||
func TestPanicRecover(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "panic from route",
|
||||
Method: http.MethodGet,
|
||||
URL: "/my/test",
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
e.Router.GET("/my/test", func(e *core.RequestEvent) error {
|
||||
panic("123")
|
||||
})
|
||||
},
|
||||
ExpectedStatus: 500,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "panic from middleware",
|
||||
Method: http.MethodGet,
|
||||
URL: "/my/test",
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
e.Router.GET("/my/test", func(e *core.RequestEvent) error {
|
||||
return e.String(http.StatusOK, "test")
|
||||
}).BindFunc(func(e *core.RequestEvent) error {
|
||||
panic(123)
|
||||
})
|
||||
},
|
||||
ExpectedStatus: 500,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequireGuestOnly(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
beforeTestFunc := func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
e.Router.GET("/my/test", func(e *core.RequestEvent) error {
|
||||
return e.String(200, "test123")
|
||||
}).Bind(apis.RequireGuestOnly())
|
||||
}
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "valid regular user token",
|
||||
Method: http.MethodGet,
|
||||
URL: "/my/test",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
BeforeTestFunc: beforeTestFunc,
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "valid superuser auth token",
|
||||
Method: http.MethodGet,
|
||||
URL: "/my/test",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
|
||||
},
|
||||
BeforeTestFunc: beforeTestFunc,
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "expired/invalid token",
|
||||
Method: http.MethodGet,
|
||||
URL: "/my/test",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoxNjQwOTkxNjYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.2D3tmqPn3vc5LoqqCz8V-iCDVXo9soYiH0d32G7FQT4",
|
||||
},
|
||||
BeforeTestFunc: beforeTestFunc,
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{"test123"},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "guest",
|
||||
Method: http.MethodGet,
|
||||
URL: "/my/test",
|
||||
BeforeTestFunc: beforeTestFunc,
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{"test123"},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequireAuth(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "guest",
|
||||
Method: http.MethodGet,
|
||||
URL: "/my/test",
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
e.Router.GET("/my/test", func(e *core.RequestEvent) error {
|
||||
return e.String(200, "test123")
|
||||
}).Bind(apis.RequireAuth())
|
||||
},
|
||||
ExpectedStatus: 401,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "expired token",
|
||||
Method: http.MethodGet,
|
||||
URL: "/my/test",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoxNjQwOTkxNjYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.2D3tmqPn3vc5LoqqCz8V-iCDVXo9soYiH0d32G7FQT4",
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
e.Router.GET("/my/test", func(e *core.RequestEvent) error {
|
||||
return e.String(200, "test123")
|
||||
}).Bind(apis.RequireAuth())
|
||||
},
|
||||
ExpectedStatus: 401,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "invalid token",
|
||||
Method: http.MethodGet,
|
||||
URL: "/my/test",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsImV4cCI6MjUyNDYwNDQ2MSwidHlwZSI6ImZpbGUiLCJjb2xsZWN0aW9uSWQiOiJwYmNfMzE0MjYzNTgyMyJ9.Lupz541xRvrktwkrl55p5pPCF77T69ZRsohsIcb2dxc",
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
e.Router.GET("/my/test", func(e *core.RequestEvent) error {
|
||||
return e.String(200, "test123")
|
||||
}).Bind(apis.RequireAuth())
|
||||
},
|
||||
ExpectedStatus: 401,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "valid record auth token with no collection restrictions",
|
||||
Method: http.MethodGet,
|
||||
URL: "/my/test",
|
||||
Headers: map[string]string{
|
||||
// regular user
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
e.Router.GET("/my/test", func(e *core.RequestEvent) error {
|
||||
return e.String(200, "test123")
|
||||
}).Bind(apis.RequireAuth())
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{"test123"},
|
||||
},
|
||||
{
|
||||
Name: "valid record static auth token",
|
||||
Method: http.MethodGet,
|
||||
URL: "/my/test",
|
||||
Headers: map[string]string{
|
||||
// regular user
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6ZmFsc2V9.4IsO6YMsR19crhwl_YWzvRH8pfq2Ri4Gv2dzGyneLak",
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
e.Router.GET("/my/test", func(e *core.RequestEvent) error {
|
||||
return e.String(200, "test123")
|
||||
}).Bind(apis.RequireAuth())
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{"test123"},
|
||||
},
|
||||
{
|
||||
Name: "valid record auth token with collection not in the restricted list",
|
||||
Method: http.MethodGet,
|
||||
URL: "/my/test",
|
||||
Headers: map[string]string{
|
||||
// superuser
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
e.Router.GET("/my/test", func(e *core.RequestEvent) error {
|
||||
return e.String(200, "test123")
|
||||
}).Bind(apis.RequireAuth("users", "demo1"))
|
||||
},
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "valid record auth token with collection in the restricted list",
|
||||
Method: http.MethodGet,
|
||||
URL: "/my/test",
|
||||
Headers: map[string]string{
|
||||
// superuser
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
e.Router.GET("/my/test", func(e *core.RequestEvent) error {
|
||||
return e.String(200, "test123")
|
||||
}).Bind(apis.RequireAuth("users", core.CollectionNameSuperusers))
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{"test123"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequireSuperuserAuth(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "guest",
|
||||
Method: http.MethodGet,
|
||||
URL: "/my/test",
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
e.Router.GET("/my/test", func(e *core.RequestEvent) error {
|
||||
return e.String(200, "test123")
|
||||
}).Bind(apis.RequireSuperuserAuth())
|
||||
},
|
||||
ExpectedStatus: 401,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "expired/invalid token",
|
||||
Method: http.MethodGet,
|
||||
URL: "/my/test",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjE2NDA5OTE2NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.0pDcBPGDpL2Khh76ivlRi7ugiLBSYvasct3qpHV3rfs",
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
e.Router.GET("/my/test", func(e *core.RequestEvent) error {
|
||||
return e.String(200, "test123")
|
||||
}).Bind(apis.RequireSuperuserAuth())
|
||||
},
|
||||
ExpectedStatus: 401,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "valid regular user auth token",
|
||||
Method: http.MethodGet,
|
||||
URL: "/my/test",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
e.Router.GET("/my/test", func(e *core.RequestEvent) error {
|
||||
return e.String(200, "test123")
|
||||
}).Bind(apis.RequireSuperuserAuth())
|
||||
},
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "valid superuser auth token",
|
||||
Method: http.MethodGet,
|
||||
URL: "/my/test",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
e.Router.GET("/my/test", func(e *core.RequestEvent) error {
|
||||
return e.String(200, "test123")
|
||||
}).Bind(apis.RequireSuperuserAuth())
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{"test123"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequireSuperuserOrOwnerAuth(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "guest",
|
||||
Method: http.MethodGet,
|
||||
URL: "/my/test/4q1xlclmfloku33",
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
e.Router.GET("/my/test/{id}", func(e *core.RequestEvent) error {
|
||||
return e.String(200, "test123")
|
||||
}).Bind(apis.RequireSuperuserOrOwnerAuth(""))
|
||||
},
|
||||
ExpectedStatus: 401,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "expired/invalid token",
|
||||
Method: http.MethodGet,
|
||||
URL: "/my/test/4q1xlclmfloku33",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjE2NDA5OTE2NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.0pDcBPGDpL2Khh76ivlRi7ugiLBSYvasct3qpHV3rfs",
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
e.Router.GET("/my/test/{id}", func(e *core.RequestEvent) error {
|
||||
return e.String(200, "test123")
|
||||
}).Bind(apis.RequireSuperuserOrOwnerAuth(""))
|
||||
},
|
||||
ExpectedStatus: 401,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "valid record auth token (different user)",
|
||||
Method: http.MethodGet,
|
||||
URL: "/my/test/oap640cot4yru2s",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
e.Router.GET("/my/test/{id}", func(e *core.RequestEvent) error {
|
||||
return e.String(200, "test123")
|
||||
}).Bind(apis.RequireSuperuserOrOwnerAuth(""))
|
||||
},
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "valid record auth token (owner)",
|
||||
Method: http.MethodGet,
|
||||
URL: "/my/test/4q1xlclmfloku33",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
e.Router.GET("/my/test/{id}", func(e *core.RequestEvent) error {
|
||||
return e.String(200, "test123")
|
||||
}).Bind(apis.RequireSuperuserOrOwnerAuth(""))
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{"test123"},
|
||||
},
|
||||
{
|
||||
Name: "valid record auth token (owner + non-matching custom owner param)",
|
||||
Method: http.MethodGet,
|
||||
URL: "/my/test/4q1xlclmfloku33",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
e.Router.GET("/my/test/{id}", func(e *core.RequestEvent) error {
|
||||
return e.String(200, "test123")
|
||||
}).Bind(apis.RequireSuperuserOrOwnerAuth("test"))
|
||||
},
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "valid record auth token (owner + matching custom owner param)",
|
||||
Method: http.MethodGet,
|
||||
URL: "/my/test/4q1xlclmfloku33",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
e.Router.GET("/my/test/{test}", func(e *core.RequestEvent) error {
|
||||
return e.String(200, "test123")
|
||||
}).Bind(apis.RequireSuperuserOrOwnerAuth("test"))
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{"test123"},
|
||||
},
|
||||
{
|
||||
Name: "valid superuser auth token",
|
||||
Method: http.MethodGet,
|
||||
URL: "/my/test/4q1xlclmfloku33",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
e.Router.GET("/my/test/{id}", func(e *core.RequestEvent) error {
|
||||
return e.String(200, "test123")
|
||||
}).Bind(apis.RequireSuperuserOrOwnerAuth(""))
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{"test123"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequireSameCollectionContextAuth(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "guest",
|
||||
Method: http.MethodGet,
|
||||
URL: "/my/test/_pb_users_auth_",
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
e.Router.GET("/my/test/{collection}", func(e *core.RequestEvent) error {
|
||||
return e.String(200, "test123")
|
||||
}).Bind(apis.RequireSameCollectionContextAuth(""))
|
||||
},
|
||||
ExpectedStatus: 401,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "expired/invalid token",
|
||||
Method: http.MethodGet,
|
||||
URL: "/my/test/_pb_users_auth_",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoxNjQwOTkxNjYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.2D3tmqPn3vc5LoqqCz8V-iCDVXo9soYiH0d32G7FQT4",
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
e.Router.GET("/my/test/{collection}", func(e *core.RequestEvent) error {
|
||||
return e.String(200, "test123")
|
||||
}).Bind(apis.RequireSameCollectionContextAuth(""))
|
||||
},
|
||||
ExpectedStatus: 401,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "valid record auth token (different collection)",
|
||||
Method: http.MethodGet,
|
||||
URL: "/my/test/clients",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
e.Router.GET("/my/test/{collection}", func(e *core.RequestEvent) error {
|
||||
return e.String(200, "test123")
|
||||
}).Bind(apis.RequireSameCollectionContextAuth(""))
|
||||
},
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "valid record auth token (same collection)",
|
||||
Method: http.MethodGet,
|
||||
URL: "/my/test/_pb_users_auth_",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
e.Router.GET("/my/test/{collection}", func(e *core.RequestEvent) error {
|
||||
return e.String(200, "test123")
|
||||
}).Bind(apis.RequireSameCollectionContextAuth(""))
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{"test123"},
|
||||
},
|
||||
{
|
||||
Name: "valid record auth token (non-matching/missing collection param)",
|
||||
Method: http.MethodGet,
|
||||
URL: "/my/test/_pb_users_auth_",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
e.Router.GET("/my/test/{id}", func(e *core.RequestEvent) error {
|
||||
return e.String(200, "test123")
|
||||
}).Bind(apis.RequireSuperuserOrOwnerAuth(""))
|
||||
},
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "valid record auth token (matching custom collection param)",
|
||||
Method: http.MethodGet,
|
||||
URL: "/my/test/_pb_users_auth_",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
e.Router.GET("/my/test/{test}", func(e *core.RequestEvent) error {
|
||||
return e.String(200, "test123")
|
||||
}).Bind(apis.RequireSuperuserOrOwnerAuth("test"))
|
||||
},
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "superuser no exception check",
|
||||
Method: http.MethodGet,
|
||||
URL: "/my/test/_pb_users_auth_",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
e.Router.GET("/my/test/{collection}", func(e *core.RequestEvent) error {
|
||||
return e.String(200, "test123")
|
||||
}).Bind(apis.RequireSameCollectionContextAuth(""))
|
||||
},
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
777
apis/realtime.go
Normal file
777
apis/realtime.go
Normal file
|
@ -0,0 +1,777 @@
|
|||
package apis
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
validation "github.com/go-ozzo/ozzo-validation/v4"
|
||||
"github.com/pocketbase/dbx"
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tools/hook"
|
||||
"github.com/pocketbase/pocketbase/tools/picker"
|
||||
"github.com/pocketbase/pocketbase/tools/router"
|
||||
"github.com/pocketbase/pocketbase/tools/routine"
|
||||
"github.com/pocketbase/pocketbase/tools/search"
|
||||
"github.com/pocketbase/pocketbase/tools/subscriptions"
|
||||
"golang.org/x/sync/errgroup"
|
||||
)
|
||||
|
||||
// note: the chunk size is arbitrary chosen and may change in the future
|
||||
const clientsChunkSize = 150
|
||||
|
||||
// RealtimeClientAuthKey is the name of the realtime client store key that holds its auth state.
|
||||
const RealtimeClientAuthKey = "auth"
|
||||
|
||||
// bindRealtimeApi registers the realtime api endpoints.
|
||||
func bindRealtimeApi(app core.App, rg *router.RouterGroup[*core.RequestEvent]) {
|
||||
sub := rg.Group("/realtime")
|
||||
sub.GET("", realtimeConnect).Bind(SkipSuccessActivityLog())
|
||||
sub.POST("", realtimeSetSubscriptions)
|
||||
|
||||
bindRealtimeEvents(app)
|
||||
}
|
||||
|
||||
func realtimeConnect(e *core.RequestEvent) error {
|
||||
// disable global write deadline for the SSE connection
|
||||
rc := http.NewResponseController(e.Response)
|
||||
writeDeadlineErr := rc.SetWriteDeadline(time.Time{})
|
||||
if writeDeadlineErr != nil {
|
||||
if !errors.Is(writeDeadlineErr, http.ErrNotSupported) {
|
||||
return e.InternalServerError("Failed to initialize SSE connection.", writeDeadlineErr)
|
||||
}
|
||||
|
||||
// only log since there are valid cases where it may not be implement (e.g. httptest.ResponseRecorder)
|
||||
e.App.Logger().Warn("SetWriteDeadline is not supported, fallback to the default server WriteTimeout")
|
||||
}
|
||||
|
||||
// create cancellable request
|
||||
cancelCtx, cancelRequest := context.WithCancel(e.Request.Context())
|
||||
defer cancelRequest()
|
||||
e.Request = e.Request.Clone(cancelCtx)
|
||||
|
||||
e.Response.Header().Set("Content-Type", "text/event-stream")
|
||||
e.Response.Header().Set("Cache-Control", "no-store")
|
||||
// https://github.com/pocketbase/pocketbase/discussions/480#discussioncomment-3657640
|
||||
// https://nginx.org/en/docs/http/ngx_http_proxy_module.html#proxy_buffering
|
||||
e.Response.Header().Set("X-Accel-Buffering", "no")
|
||||
|
||||
connectEvent := new(core.RealtimeConnectRequestEvent)
|
||||
connectEvent.RequestEvent = e
|
||||
connectEvent.Client = subscriptions.NewDefaultClient()
|
||||
connectEvent.IdleTimeout = 5 * time.Minute
|
||||
|
||||
return e.App.OnRealtimeConnectRequest().Trigger(connectEvent, func(ce *core.RealtimeConnectRequestEvent) error {
|
||||
// register new subscription client
|
||||
ce.App.SubscriptionsBroker().Register(ce.Client)
|
||||
defer func() {
|
||||
e.App.SubscriptionsBroker().Unregister(ce.Client.Id())
|
||||
}()
|
||||
|
||||
ce.App.Logger().Debug("Realtime connection established.", slog.String("clientId", ce.Client.Id()))
|
||||
|
||||
// signalize established connection (aka. fire "connect" message)
|
||||
connectMsgEvent := new(core.RealtimeMessageEvent)
|
||||
connectMsgEvent.RequestEvent = ce.RequestEvent
|
||||
connectMsgEvent.Client = ce.Client
|
||||
connectMsgEvent.Message = &subscriptions.Message{
|
||||
Name: "PB_CONNECT",
|
||||
Data: []byte(`{"clientId":"` + ce.Client.Id() + `"}`),
|
||||
}
|
||||
connectMsgErr := ce.App.OnRealtimeMessageSend().Trigger(connectMsgEvent, func(me *core.RealtimeMessageEvent) error {
|
||||
err := me.Message.WriteSSE(me.Response, me.Client.Id())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return me.Flush()
|
||||
})
|
||||
if connectMsgErr != nil {
|
||||
ce.App.Logger().Debug(
|
||||
"Realtime connection closed (failed to deliver PB_CONNECT)",
|
||||
slog.String("clientId", ce.Client.Id()),
|
||||
slog.String("error", connectMsgErr.Error()),
|
||||
)
|
||||
return nil
|
||||
}
|
||||
|
||||
// start an idle timer to keep track of inactive/forgotten connections
|
||||
idleTimer := time.NewTimer(ce.IdleTimeout)
|
||||
defer idleTimer.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-idleTimer.C:
|
||||
cancelRequest()
|
||||
case msg, ok := <-ce.Client.Channel():
|
||||
if !ok {
|
||||
// channel is closed
|
||||
ce.App.Logger().Debug(
|
||||
"Realtime connection closed (closed channel)",
|
||||
slog.String("clientId", ce.Client.Id()),
|
||||
)
|
||||
return nil
|
||||
}
|
||||
|
||||
msgEvent := new(core.RealtimeMessageEvent)
|
||||
msgEvent.RequestEvent = ce.RequestEvent
|
||||
msgEvent.Client = ce.Client
|
||||
msgEvent.Message = &msg
|
||||
msgErr := ce.App.OnRealtimeMessageSend().Trigger(msgEvent, func(me *core.RealtimeMessageEvent) error {
|
||||
err := me.Message.WriteSSE(me.Response, me.Client.Id())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return me.Flush()
|
||||
})
|
||||
if msgErr != nil {
|
||||
ce.App.Logger().Debug(
|
||||
"Realtime connection closed (failed to deliver message)",
|
||||
slog.String("clientId", ce.Client.Id()),
|
||||
slog.String("error", msgErr.Error()),
|
||||
)
|
||||
return nil
|
||||
}
|
||||
|
||||
idleTimer.Stop()
|
||||
idleTimer.Reset(ce.IdleTimeout)
|
||||
case <-ce.Request.Context().Done():
|
||||
// connection is closed
|
||||
ce.App.Logger().Debug(
|
||||
"Realtime connection closed (cancelled request)",
|
||||
slog.String("clientId", ce.Client.Id()),
|
||||
)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
type realtimeSubscribeForm struct {
|
||||
ClientId string `form:"clientId" json:"clientId"`
|
||||
Subscriptions []string `form:"subscriptions" json:"subscriptions"`
|
||||
}
|
||||
|
||||
func (form *realtimeSubscribeForm) validate() error {
|
||||
return validation.ValidateStruct(form,
|
||||
validation.Field(&form.ClientId, validation.Required, validation.Length(1, 255)),
|
||||
validation.Field(&form.Subscriptions,
|
||||
validation.Length(0, 1000),
|
||||
validation.Each(validation.Length(0, 2500)),
|
||||
),
|
||||
)
|
||||
}
|
||||
|
||||
// note: in case of reconnect, clients will have to resubmit all subscriptions again
|
||||
func realtimeSetSubscriptions(e *core.RequestEvent) error {
|
||||
form := new(realtimeSubscribeForm)
|
||||
|
||||
err := e.BindBody(form)
|
||||
if err != nil {
|
||||
return e.BadRequestError("", err)
|
||||
}
|
||||
|
||||
err = form.validate()
|
||||
if err != nil {
|
||||
return e.BadRequestError("", err)
|
||||
}
|
||||
|
||||
// find subscription client
|
||||
client, err := e.App.SubscriptionsBroker().ClientById(form.ClientId)
|
||||
if err != nil {
|
||||
return e.NotFoundError("Missing or invalid client id.", err)
|
||||
}
|
||||
|
||||
// for now allow only guest->auth upgrades and any other auth change is forbidden
|
||||
clientAuth, _ := client.Get(RealtimeClientAuthKey).(*core.Record)
|
||||
if clientAuth != nil && !isSameAuth(clientAuth, e.Auth) {
|
||||
return e.ForbiddenError("The current and the previous request authorization don't match.", nil)
|
||||
}
|
||||
|
||||
event := new(core.RealtimeSubscribeRequestEvent)
|
||||
event.RequestEvent = e
|
||||
event.Client = client
|
||||
event.Subscriptions = form.Subscriptions
|
||||
|
||||
return e.App.OnRealtimeSubscribeRequest().Trigger(event, func(e *core.RealtimeSubscribeRequestEvent) error {
|
||||
// update auth state
|
||||
e.Client.Set(RealtimeClientAuthKey, e.Auth)
|
||||
|
||||
// unsubscribe from any previous existing subscriptions
|
||||
e.Client.Unsubscribe()
|
||||
|
||||
// subscribe to the new subscriptions
|
||||
e.Client.Subscribe(e.Subscriptions...)
|
||||
|
||||
e.App.Logger().Debug(
|
||||
"Realtime subscriptions updated.",
|
||||
slog.String("clientId", e.Client.Id()),
|
||||
slog.Any("subscriptions", e.Subscriptions),
|
||||
)
|
||||
|
||||
return execAfterSuccessTx(true, e.App, func() error {
|
||||
return e.NoContent(http.StatusNoContent)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// updateClientsAuth updates the existing clients auth record with the new one (matched by ID).
|
||||
func realtimeUpdateClientsAuth(app core.App, newAuthRecord *core.Record) error {
|
||||
chunks := app.SubscriptionsBroker().ChunkedClients(clientsChunkSize)
|
||||
|
||||
group := new(errgroup.Group)
|
||||
|
||||
for _, chunk := range chunks {
|
||||
group.Go(func() error {
|
||||
for _, client := range chunk {
|
||||
clientAuth, _ := client.Get(RealtimeClientAuthKey).(*core.Record)
|
||||
if clientAuth != nil &&
|
||||
clientAuth.Id == newAuthRecord.Id &&
|
||||
clientAuth.Collection().Name == newAuthRecord.Collection().Name {
|
||||
client.Set(RealtimeClientAuthKey, newAuthRecord)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
return group.Wait()
|
||||
}
|
||||
|
||||
// realtimeUnsetClientsAuthState unsets the auth state of all clients that have the provided auth model.
|
||||
func realtimeUnsetClientsAuthState(app core.App, authModel core.Model) error {
|
||||
chunks := app.SubscriptionsBroker().ChunkedClients(clientsChunkSize)
|
||||
|
||||
group := new(errgroup.Group)
|
||||
|
||||
for _, chunk := range chunks {
|
||||
group.Go(func() error {
|
||||
for _, client := range chunk {
|
||||
clientAuth, _ := client.Get(RealtimeClientAuthKey).(*core.Record)
|
||||
if clientAuth != nil &&
|
||||
clientAuth.Id == authModel.PK() &&
|
||||
clientAuth.Collection().Name == authModel.TableName() {
|
||||
client.Unset(RealtimeClientAuthKey)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
return group.Wait()
|
||||
}
|
||||
|
||||
func bindRealtimeEvents(app core.App) {
|
||||
// update the clients that has auth record association
|
||||
app.OnModelAfterUpdateSuccess().Bind(&hook.Handler[*core.ModelEvent]{
|
||||
Func: func(e *core.ModelEvent) error {
|
||||
authRecord := realtimeResolveRecord(e.App, e.Model, core.CollectionTypeAuth)
|
||||
if authRecord != nil {
|
||||
if err := realtimeUpdateClientsAuth(e.App, authRecord); err != nil {
|
||||
app.Logger().Warn(
|
||||
"Failed to update client(s) associated to the updated auth record",
|
||||
slog.Any("id", authRecord.Id),
|
||||
slog.String("collectionName", authRecord.Collection().Name),
|
||||
slog.String("error", err.Error()),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
return e.Next()
|
||||
},
|
||||
Priority: -99,
|
||||
})
|
||||
|
||||
// remove the client(s) associated to the deleted auth model
|
||||
// (note: works also with custom model for backward compatibility)
|
||||
app.OnModelAfterDeleteSuccess().Bind(&hook.Handler[*core.ModelEvent]{
|
||||
Func: func(e *core.ModelEvent) error {
|
||||
collection := realtimeResolveRecordCollection(e.App, e.Model)
|
||||
if collection != nil && collection.IsAuth() {
|
||||
if err := realtimeUnsetClientsAuthState(e.App, e.Model); err != nil {
|
||||
app.Logger().Warn(
|
||||
"Failed to remove client(s) associated to the deleted auth model",
|
||||
slog.Any("id", e.Model.PK()),
|
||||
slog.String("collectionName", e.Model.TableName()),
|
||||
slog.String("error", err.Error()),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
return e.Next()
|
||||
},
|
||||
Priority: -99,
|
||||
})
|
||||
|
||||
app.OnModelAfterCreateSuccess().Bind(&hook.Handler[*core.ModelEvent]{
|
||||
Func: func(e *core.ModelEvent) error {
|
||||
record := realtimeResolveRecord(e.App, e.Model, "")
|
||||
if record != nil {
|
||||
err := realtimeBroadcastRecord(e.App, "create", record, false)
|
||||
if err != nil {
|
||||
app.Logger().Debug(
|
||||
"Failed to broadcast record create",
|
||||
slog.String("id", record.Id),
|
||||
slog.String("collectionName", record.Collection().Name),
|
||||
slog.String("error", err.Error()),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
return e.Next()
|
||||
},
|
||||
Priority: -99,
|
||||
})
|
||||
|
||||
app.OnModelAfterUpdateSuccess().Bind(&hook.Handler[*core.ModelEvent]{
|
||||
Func: func(e *core.ModelEvent) error {
|
||||
record := realtimeResolveRecord(e.App, e.Model, "")
|
||||
if record != nil {
|
||||
err := realtimeBroadcastRecord(e.App, "update", record, false)
|
||||
if err != nil {
|
||||
app.Logger().Debug(
|
||||
"Failed to broadcast record update",
|
||||
slog.String("id", record.Id),
|
||||
slog.String("collectionName", record.Collection().Name),
|
||||
slog.String("error", err.Error()),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
return e.Next()
|
||||
},
|
||||
Priority: -99,
|
||||
})
|
||||
|
||||
// delete: dry cache
|
||||
app.OnModelDelete().Bind(&hook.Handler[*core.ModelEvent]{
|
||||
Func: func(e *core.ModelEvent) error {
|
||||
record := realtimeResolveRecord(e.App, e.Model, "")
|
||||
if record != nil {
|
||||
// note: use the outside scoped app instance for the access checks so that the API rules
|
||||
// are performed out of the delete transaction ensuring that they would still work even if
|
||||
// a cascade-deleted record's API rule relies on an already deleted parent record
|
||||
err := realtimeBroadcastRecord(e.App, "delete", record, true, app)
|
||||
if err != nil {
|
||||
app.Logger().Debug(
|
||||
"Failed to dry cache record delete",
|
||||
slog.String("id", record.Id),
|
||||
slog.String("collectionName", record.Collection().Name),
|
||||
slog.String("error", err.Error()),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
return e.Next()
|
||||
},
|
||||
Priority: 99, // execute as later as possible
|
||||
})
|
||||
|
||||
// delete: broadcast
|
||||
app.OnModelAfterDeleteSuccess().Bind(&hook.Handler[*core.ModelEvent]{
|
||||
Func: func(e *core.ModelEvent) error {
|
||||
// note: only ensure that it is a collection record
|
||||
// and don't use realtimeResolveRecord because in case of a
|
||||
// custom model it'll fail to resolve since the record is already deleted
|
||||
collection := realtimeResolveRecordCollection(e.App, e.Model)
|
||||
if collection != nil {
|
||||
err := realtimeBroadcastDryCacheKey(e.App, getDryCacheKey("delete", e.Model))
|
||||
if err != nil {
|
||||
app.Logger().Debug(
|
||||
"Failed to broadcast record delete",
|
||||
slog.Any("id", e.Model.PK()),
|
||||
slog.String("collectionName", collection.Name),
|
||||
slog.String("error", err.Error()),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
return e.Next()
|
||||
},
|
||||
Priority: -99,
|
||||
})
|
||||
|
||||
// delete: failure
|
||||
app.OnModelAfterDeleteError().Bind(&hook.Handler[*core.ModelErrorEvent]{
|
||||
Func: func(e *core.ModelErrorEvent) error {
|
||||
record := realtimeResolveRecord(e.App, e.Model, "")
|
||||
if record != nil {
|
||||
err := realtimeUnsetDryCacheKey(e.App, getDryCacheKey("delete", record))
|
||||
if err != nil {
|
||||
app.Logger().Debug(
|
||||
"Failed to cleanup after broadcast record delete failure",
|
||||
slog.String("id", record.Id),
|
||||
slog.String("collectionName", record.Collection().Name),
|
||||
slog.String("error", err.Error()),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
return e.Next()
|
||||
},
|
||||
Priority: -99,
|
||||
})
|
||||
}
|
||||
|
||||
// resolveRecord converts *if possible* the provided model interface to a Record.
|
||||
// This is usually helpful if the provided model is a custom Record model struct.
|
||||
func realtimeResolveRecord(app core.App, model core.Model, optCollectionType string) *core.Record {
|
||||
var record *core.Record
|
||||
switch m := model.(type) {
|
||||
case *core.Record:
|
||||
record = m
|
||||
case core.RecordProxy:
|
||||
record = m.ProxyRecord()
|
||||
}
|
||||
|
||||
if record != nil {
|
||||
if optCollectionType == "" || record.Collection().Type == optCollectionType {
|
||||
return record
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
tblName := model.TableName()
|
||||
|
||||
// skip Log model checks
|
||||
if tblName == core.LogsTableName {
|
||||
return nil
|
||||
}
|
||||
|
||||
// check if it is custom Record model struct
|
||||
collection, _ := app.FindCachedCollectionByNameOrId(tblName)
|
||||
if collection != nil && (optCollectionType == "" || collection.Type == optCollectionType) {
|
||||
if id, ok := model.PK().(string); ok {
|
||||
record, _ = app.FindRecordById(collection, id)
|
||||
}
|
||||
}
|
||||
|
||||
return record
|
||||
}
|
||||
|
||||
// realtimeResolveRecordCollection extracts *if possible* the Collection model from the provided model interface.
|
||||
// This is usually helpful if the provided model is a custom Record model struct.
|
||||
func realtimeResolveRecordCollection(app core.App, model core.Model) (collection *core.Collection) {
|
||||
switch m := model.(type) {
|
||||
case *core.Record:
|
||||
return m.Collection()
|
||||
case core.RecordProxy:
|
||||
return m.ProxyRecord().Collection()
|
||||
default:
|
||||
// check if it is custom Record model struct
|
||||
collection, err := app.FindCachedCollectionByNameOrId(model.TableName())
|
||||
if err == nil {
|
||||
return collection
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// recordData represents the broadcasted record subscrition message data.
|
||||
type recordData struct {
|
||||
Record any `json:"record"` /* map or core.Record */
|
||||
Action string `json:"action"`
|
||||
}
|
||||
|
||||
// Note: the optAccessCheckApp is there in case you want the access check
|
||||
// to be performed against different db app context (e.g. out of a transaction).
|
||||
// If set, it is expected that optAccessCheckApp instance is used for read-only operations to avoid deadlocks.
|
||||
// If not set, it fallbacks to app.
|
||||
func realtimeBroadcastRecord(app core.App, action string, record *core.Record, dryCache bool, optAccessCheckApp ...core.App) error {
|
||||
collection := record.Collection()
|
||||
if collection == nil {
|
||||
return errors.New("[broadcastRecord] Record collection not set")
|
||||
}
|
||||
|
||||
chunks := app.SubscriptionsBroker().ChunkedClients(clientsChunkSize)
|
||||
if len(chunks) == 0 {
|
||||
return nil // no subscribers
|
||||
}
|
||||
|
||||
subscriptionRuleMap := map[string]*string{
|
||||
(collection.Name + "/" + record.Id + "?"): collection.ViewRule,
|
||||
(collection.Id + "/" + record.Id + "?"): collection.ViewRule,
|
||||
(collection.Name + "/*?"): collection.ListRule,
|
||||
(collection.Id + "/*?"): collection.ListRule,
|
||||
|
||||
// @deprecated: the same as the wildcard topic but kept for backward compatibility
|
||||
(collection.Name + "?"): collection.ListRule,
|
||||
(collection.Id + "?"): collection.ListRule,
|
||||
}
|
||||
|
||||
dryCacheKey := getDryCacheKey(action, record)
|
||||
|
||||
group := new(errgroup.Group)
|
||||
|
||||
accessCheckApp := app
|
||||
if len(optAccessCheckApp) > 0 {
|
||||
accessCheckApp = optAccessCheckApp[0]
|
||||
}
|
||||
|
||||
for _, chunk := range chunks {
|
||||
group.Go(func() error {
|
||||
var clientAuth *core.Record
|
||||
|
||||
for _, client := range chunk {
|
||||
// note: not executed concurrently to avoid races and to ensure
|
||||
// that the access checks are applied for the current record db state
|
||||
for prefix, rule := range subscriptionRuleMap {
|
||||
subs := client.Subscriptions(prefix)
|
||||
if len(subs) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
clientAuth, _ = client.Get(RealtimeClientAuthKey).(*core.Record)
|
||||
|
||||
for sub, options := range subs {
|
||||
// mock request data
|
||||
requestInfo := &core.RequestInfo{
|
||||
Context: core.RequestInfoContextRealtime,
|
||||
Method: "GET",
|
||||
Query: options.Query,
|
||||
Headers: options.Headers,
|
||||
Auth: clientAuth,
|
||||
}
|
||||
|
||||
if !realtimeCanAccessRecord(accessCheckApp, record, requestInfo, rule) {
|
||||
continue
|
||||
}
|
||||
|
||||
// create a clean record copy without expand and unknown fields because we don't know yet
|
||||
// which exact fields the client subscription requested or has permissions to access
|
||||
cleanRecord := record.Fresh()
|
||||
|
||||
// trigger the enrich hooks
|
||||
enrichErr := triggerRecordEnrichHooks(app, requestInfo, []*core.Record{cleanRecord}, func() error {
|
||||
// apply expand
|
||||
rawExpand := options.Query[expandQueryParam]
|
||||
if rawExpand != "" {
|
||||
expandErrs := app.ExpandRecord(cleanRecord, strings.Split(rawExpand, ","), expandFetch(app, requestInfo))
|
||||
if len(expandErrs) > 0 {
|
||||
app.Logger().Debug(
|
||||
"[broadcastRecord] expand errors",
|
||||
slog.String("id", cleanRecord.Id),
|
||||
slog.String("collectionName", cleanRecord.Collection().Name),
|
||||
slog.String("sub", sub),
|
||||
slog.String("expand", rawExpand),
|
||||
slog.Any("errors", expandErrs),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// ignore the auth record email visibility checks
|
||||
// for auth owner, superuser or manager
|
||||
if collection.IsAuth() {
|
||||
if isSameAuth(clientAuth, cleanRecord) ||
|
||||
realtimeCanAccessRecord(accessCheckApp, cleanRecord, requestInfo, collection.ManageRule) {
|
||||
cleanRecord.IgnoreEmailVisibility(true)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
if enrichErr != nil {
|
||||
app.Logger().Debug(
|
||||
"[broadcastRecord] record enrich error",
|
||||
slog.String("id", cleanRecord.Id),
|
||||
slog.String("collectionName", cleanRecord.Collection().Name),
|
||||
slog.String("sub", sub),
|
||||
slog.Any("error", enrichErr),
|
||||
)
|
||||
continue
|
||||
}
|
||||
|
||||
data := &recordData{
|
||||
Action: action,
|
||||
Record: cleanRecord,
|
||||
}
|
||||
|
||||
// check fields
|
||||
rawFields := options.Query[fieldsQueryParam]
|
||||
if rawFields != "" {
|
||||
decoded, err := picker.Pick(cleanRecord, rawFields)
|
||||
if err == nil {
|
||||
data.Record = decoded
|
||||
} else {
|
||||
app.Logger().Debug(
|
||||
"[broadcastRecord] pick fields error",
|
||||
slog.String("id", cleanRecord.Id),
|
||||
slog.String("collectionName", cleanRecord.Collection().Name),
|
||||
slog.String("sub", sub),
|
||||
slog.String("fields", rawFields),
|
||||
slog.String("error", err.Error()),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
dataBytes, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
app.Logger().Debug(
|
||||
"[broadcastRecord] data marshal error",
|
||||
slog.String("id", cleanRecord.Id),
|
||||
slog.String("collectionName", cleanRecord.Collection().Name),
|
||||
slog.String("error", err.Error()),
|
||||
)
|
||||
continue
|
||||
}
|
||||
|
||||
msg := subscriptions.Message{
|
||||
Name: sub,
|
||||
Data: dataBytes,
|
||||
}
|
||||
|
||||
if dryCache {
|
||||
messages, ok := client.Get(dryCacheKey).([]subscriptions.Message)
|
||||
if !ok {
|
||||
messages = []subscriptions.Message{msg}
|
||||
} else {
|
||||
messages = append(messages, msg)
|
||||
}
|
||||
client.Set(dryCacheKey, messages)
|
||||
} else {
|
||||
routine.FireAndForget(func() {
|
||||
client.Send(msg)
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
return group.Wait()
|
||||
}
|
||||
|
||||
// realtimeBroadcastDryCacheKey broadcasts the dry cached key related messages.
|
||||
func realtimeBroadcastDryCacheKey(app core.App, key string) error {
|
||||
chunks := app.SubscriptionsBroker().ChunkedClients(clientsChunkSize)
|
||||
if len(chunks) == 0 {
|
||||
return nil // no subscribers
|
||||
}
|
||||
|
||||
group := new(errgroup.Group)
|
||||
|
||||
for _, chunk := range chunks {
|
||||
group.Go(func() error {
|
||||
for _, client := range chunk {
|
||||
messages, ok := client.Get(key).([]subscriptions.Message)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
client.Unset(key)
|
||||
|
||||
client := client
|
||||
|
||||
routine.FireAndForget(func() {
|
||||
for _, msg := range messages {
|
||||
client.Send(msg)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
return group.Wait()
|
||||
}
|
||||
|
||||
// realtimeUnsetDryCacheKey removes the dry cached key related messages.
|
||||
func realtimeUnsetDryCacheKey(app core.App, key string) error {
|
||||
chunks := app.SubscriptionsBroker().ChunkedClients(clientsChunkSize)
|
||||
if len(chunks) == 0 {
|
||||
return nil // no subscribers
|
||||
}
|
||||
|
||||
group := new(errgroup.Group)
|
||||
|
||||
for _, chunk := range chunks {
|
||||
group.Go(func() error {
|
||||
for _, client := range chunk {
|
||||
if client.Get(key) != nil {
|
||||
client.Unset(key)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
return group.Wait()
|
||||
}
|
||||
|
||||
func getDryCacheKey(action string, model core.Model) string {
|
||||
pkStr, ok := model.PK().(string)
|
||||
if !ok {
|
||||
pkStr = fmt.Sprintf("%v", model.PK())
|
||||
}
|
||||
|
||||
return action + "/" + model.TableName() + "/" + pkStr
|
||||
}
|
||||
|
||||
func isSameAuth(authA, authB *core.Record) bool {
|
||||
if authA == nil {
|
||||
return authB == nil
|
||||
}
|
||||
|
||||
if authB == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
return authA.Id == authB.Id && authA.Collection().Id == authB.Collection().Id
|
||||
}
|
||||
|
||||
// realtimeCanAccessRecord checks if the subscription client has access to the specified record model.
|
||||
func realtimeCanAccessRecord(
|
||||
app core.App,
|
||||
record *core.Record,
|
||||
requestInfo *core.RequestInfo,
|
||||
accessRule *string,
|
||||
) bool {
|
||||
// check the access rule
|
||||
// ---
|
||||
if ok, _ := app.CanAccessRecord(record, requestInfo, accessRule); !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
// check the subscription client-side filter (if any)
|
||||
// ---
|
||||
filter := requestInfo.Query[search.FilterQueryParam]
|
||||
if filter == "" {
|
||||
return true // no further checks needed
|
||||
}
|
||||
|
||||
err := checkForSuperuserOnlyRuleFields(requestInfo)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
var exists int
|
||||
|
||||
q := app.ConcurrentDB().Select("(1)").
|
||||
From(record.Collection().Name).
|
||||
AndWhere(dbx.HashExp{record.Collection().Name + ".id": record.Id})
|
||||
|
||||
resolver := core.NewRecordFieldResolver(app, record.Collection(), requestInfo, false)
|
||||
expr, err := search.FilterData(filter).BuildExpr(resolver)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
q.AndWhere(expr)
|
||||
resolver.UpdateQuery(q)
|
||||
|
||||
err = q.Limit(1).Row(&exists)
|
||||
|
||||
return err == nil && exists > 0
|
||||
}
|
885
apis/realtime_test.go
Normal file
885
apis/realtime_test.go
Normal file
|
@ -0,0 +1,885 @@
|
|||
package apis_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/pocketbase/dbx"
|
||||
"github.com/pocketbase/pocketbase/apis"
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tests"
|
||||
"github.com/pocketbase/pocketbase/tools/subscriptions"
|
||||
"github.com/pocketbase/pocketbase/tools/types"
|
||||
)
|
||||
|
||||
func TestRealtimeConnect(t *testing.T) {
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/realtime",
|
||||
Timeout: 100 * time.Millisecond,
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
`id:`,
|
||||
`event:PB_CONNECT`,
|
||||
`data:{"clientId":`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRealtimeConnectRequest": 1,
|
||||
"OnRealtimeMessageSend": 1,
|
||||
},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
if len(app.SubscriptionsBroker().Clients()) != 0 {
|
||||
t.Errorf("Expected the subscribers to be removed after connection close, found %d", len(app.SubscriptionsBroker().Clients()))
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "PB_CONNECT interrupt",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/realtime",
|
||||
Timeout: 100 * time.Millisecond,
|
||||
ExpectedStatus: 200,
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRealtimeConnectRequest": 1,
|
||||
"OnRealtimeMessageSend": 1,
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
app.OnRealtimeMessageSend().BindFunc(func(e *core.RealtimeMessageEvent) error {
|
||||
if e.Message.Name == "PB_CONNECT" {
|
||||
return errors.New("PB_CONNECT error")
|
||||
}
|
||||
return e.Next()
|
||||
})
|
||||
},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
if len(app.SubscriptionsBroker().Clients()) != 0 {
|
||||
t.Errorf("Expected the subscribers to be removed after connection close, found %d", len(app.SubscriptionsBroker().Clients()))
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "Skipping/ignoring messages",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/realtime",
|
||||
Timeout: 100 * time.Millisecond,
|
||||
ExpectedStatus: 200,
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRealtimeConnectRequest": 1,
|
||||
"OnRealtimeMessageSend": 1,
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
app.OnRealtimeMessageSend().BindFunc(func(e *core.RealtimeMessageEvent) error {
|
||||
return nil
|
||||
})
|
||||
},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
if len(app.SubscriptionsBroker().Clients()) != 0 {
|
||||
t.Errorf("Expected the subscribers to be removed after connection close, found %d", len(app.SubscriptionsBroker().Clients()))
|
||||
}
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRealtimeSubscribe(t *testing.T) {
|
||||
client := subscriptions.NewDefaultClient()
|
||||
|
||||
resetClient := func() {
|
||||
client.Unsubscribe()
|
||||
client.Set(apis.RealtimeClientAuthKey, nil)
|
||||
}
|
||||
|
||||
validSubscriptionsLimit := make([]string, 1000)
|
||||
for i := 0; i < len(validSubscriptionsLimit); i++ {
|
||||
validSubscriptionsLimit[i] = fmt.Sprintf(`"%d"`, i)
|
||||
}
|
||||
invalidSubscriptionsLimit := make([]string, 1001)
|
||||
for i := 0; i < len(invalidSubscriptionsLimit); i++ {
|
||||
invalidSubscriptionsLimit[i] = fmt.Sprintf(`"%d"`, i)
|
||||
}
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "missing client",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/realtime",
|
||||
Body: strings.NewReader(`{"clientId":"missing","subscriptions":["test1", "test2"]}`),
|
||||
ExpectedStatus: 404,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "empty data",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/realtime",
|
||||
Body: strings.NewReader(`{}`),
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{
|
||||
`"data":{`,
|
||||
`"clientId":{"code":"validation_required`,
|
||||
},
|
||||
NotExpectedContent: []string{
|
||||
`"subscriptions"`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "existing client with invalid subscriptions limit",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/realtime",
|
||||
Body: strings.NewReader(`{
|
||||
"clientId": "` + client.Id() + `",
|
||||
"subscriptions": [` + strings.Join(invalidSubscriptionsLimit, ",") + `]
|
||||
}`),
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
app.SubscriptionsBroker().Register(client)
|
||||
},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
resetClient()
|
||||
},
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{
|
||||
`"data":{`,
|
||||
`"subscriptions":{"code":"validation_length_too_long"`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "existing client with valid subscriptions limit",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/realtime",
|
||||
Body: strings.NewReader(`{
|
||||
"clientId": "` + client.Id() + `",
|
||||
"subscriptions": [` + strings.Join(validSubscriptionsLimit, ",") + `]
|
||||
}`),
|
||||
ExpectedStatus: 204,
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRealtimeSubscribeRequest": 1,
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
client.Subscribe("test0") // should be replaced
|
||||
app.SubscriptionsBroker().Register(client)
|
||||
},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
if len(client.Subscriptions()) != len(validSubscriptionsLimit) {
|
||||
t.Errorf("Expected %d subscriptions, got %d", len(validSubscriptionsLimit), len(client.Subscriptions()))
|
||||
}
|
||||
if client.HasSubscription("test0") {
|
||||
t.Errorf("Expected old subscriptions to be replaced")
|
||||
}
|
||||
resetClient()
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "existing client with invalid topic length",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/realtime",
|
||||
Body: strings.NewReader(`{
|
||||
"clientId": "` + client.Id() + `",
|
||||
"subscriptions": ["abc", "` + strings.Repeat("a", 2501) + `"]
|
||||
}`),
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
app.SubscriptionsBroker().Register(client)
|
||||
},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
resetClient()
|
||||
},
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{
|
||||
`"data":{`,
|
||||
`"subscriptions":{"1":{"code":"validation_length_too_long"`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "existing client with valid topic length",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/realtime",
|
||||
Body: strings.NewReader(`{
|
||||
"clientId": "` + client.Id() + `",
|
||||
"subscriptions": ["abc", "` + strings.Repeat("a", 2500) + `"]
|
||||
}`),
|
||||
ExpectedStatus: 204,
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRealtimeSubscribeRequest": 1,
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
client.Subscribe("test0")
|
||||
app.SubscriptionsBroker().Register(client)
|
||||
},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
if len(client.Subscriptions()) != 2 {
|
||||
t.Errorf("Expected %d subscriptions, got %d", 2, len(client.Subscriptions()))
|
||||
}
|
||||
if client.HasSubscription("test0") {
|
||||
t.Errorf("Expected old subscriptions to be replaced")
|
||||
}
|
||||
resetClient()
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "existing client - empty subscriptions",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/realtime",
|
||||
Body: strings.NewReader(`{"clientId":"` + client.Id() + `","subscriptions":[]}`),
|
||||
ExpectedStatus: 204,
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRealtimeSubscribeRequest": 1,
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
client.Subscribe("test0")
|
||||
app.SubscriptionsBroker().Register(client)
|
||||
},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
if len(client.Subscriptions()) != 0 {
|
||||
t.Errorf("Expected no subscriptions, got %d", len(client.Subscriptions()))
|
||||
}
|
||||
resetClient()
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "existing client - 2 new subscriptions",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/realtime",
|
||||
Body: strings.NewReader(`{"clientId":"` + client.Id() + `","subscriptions":["test1", "test2"]}`),
|
||||
ExpectedStatus: 204,
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRealtimeSubscribeRequest": 1,
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
client.Subscribe("test0")
|
||||
app.SubscriptionsBroker().Register(client)
|
||||
},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
expectedSubs := []string{"test1", "test2"}
|
||||
if len(expectedSubs) != len(client.Subscriptions()) {
|
||||
t.Errorf("Expected subscriptions %v, got %v", expectedSubs, client.Subscriptions())
|
||||
}
|
||||
|
||||
for _, s := range expectedSubs {
|
||||
if !client.HasSubscription(s) {
|
||||
t.Errorf("Cannot find %q subscription in %v", s, client.Subscriptions())
|
||||
}
|
||||
}
|
||||
resetClient()
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "existing client - guest -> authorized superuser",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/realtime",
|
||||
Body: strings.NewReader(`{"clientId":"` + client.Id() + `","subscriptions":["test1", "test2"]}`),
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
|
||||
},
|
||||
ExpectedStatus: 204,
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRealtimeSubscribeRequest": 1,
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
app.SubscriptionsBroker().Register(client)
|
||||
},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
authRecord, _ := client.Get(apis.RealtimeClientAuthKey).(*core.Record)
|
||||
if authRecord == nil || !authRecord.IsSuperuser() {
|
||||
t.Errorf("Expected superuser auth record, got %v", authRecord)
|
||||
}
|
||||
resetClient()
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "existing client - guest -> authorized regular auth record",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/realtime",
|
||||
Body: strings.NewReader(`{"clientId":"` + client.Id() + `","subscriptions":["test1", "test2"]}`),
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
ExpectedStatus: 204,
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRealtimeSubscribeRequest": 1,
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
app.SubscriptionsBroker().Register(client)
|
||||
},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
authRecord, _ := client.Get(apis.RealtimeClientAuthKey).(*core.Record)
|
||||
if authRecord == nil {
|
||||
t.Errorf("Expected regular user auth record, got %v", authRecord)
|
||||
}
|
||||
resetClient()
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "existing client - same auth",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/realtime",
|
||||
Body: strings.NewReader(`{"clientId":"` + client.Id() + `","subscriptions":["test1", "test2"]}`),
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
ExpectedStatus: 204,
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRealtimeSubscribeRequest": 1,
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
// the same user as the auth token
|
||||
user, err := app.FindAuthRecordByEmail("users", "test@example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
client.Set(apis.RealtimeClientAuthKey, user)
|
||||
|
||||
app.SubscriptionsBroker().Register(client)
|
||||
},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
authRecord, _ := client.Get(apis.RealtimeClientAuthKey).(*core.Record)
|
||||
if authRecord == nil {
|
||||
t.Errorf("Expected auth record model, got nil")
|
||||
}
|
||||
resetClient()
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "existing client - mismatched auth",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/realtime",
|
||||
Body: strings.NewReader(`{"clientId":"` + client.Id() + `","subscriptions":["test1", "test2"]}`),
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
user, err := app.FindAuthRecordByEmail("users", "test2@example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
client.Set(apis.RealtimeClientAuthKey, user)
|
||||
|
||||
app.SubscriptionsBroker().Register(client)
|
||||
},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
authRecord, _ := client.Get(apis.RealtimeClientAuthKey).(*core.Record)
|
||||
if authRecord == nil {
|
||||
t.Errorf("Expected auth record model, got nil")
|
||||
}
|
||||
resetClient()
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "existing client - unauthorized client",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/realtime",
|
||||
Body: strings.NewReader(`{"clientId":"` + client.Id() + `","subscriptions":["test1", "test2"]}`),
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
user, err := app.FindAuthRecordByEmail("users", "test2@example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
client.Set(apis.RealtimeClientAuthKey, user)
|
||||
|
||||
app.SubscriptionsBroker().Register(client)
|
||||
},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
authRecord, _ := client.Get(apis.RealtimeClientAuthKey).(*core.Record)
|
||||
if authRecord == nil {
|
||||
t.Errorf("Expected auth record model, got nil")
|
||||
}
|
||||
resetClient()
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRealtimeAuthRecordDeleteEvent(t *testing.T) {
|
||||
testApp, _ := tests.NewTestApp()
|
||||
defer testApp.Cleanup()
|
||||
|
||||
// init realtime handlers
|
||||
apis.NewRouter(testApp)
|
||||
|
||||
authRecord1, err := testApp.FindAuthRecordByEmail("users", "test@example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
authRecord2, err := testApp.FindAuthRecordByEmail("users", "test2@example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
client1 := subscriptions.NewDefaultClient()
|
||||
client1.Set(apis.RealtimeClientAuthKey, authRecord1)
|
||||
testApp.SubscriptionsBroker().Register(client1)
|
||||
|
||||
client2 := subscriptions.NewDefaultClient()
|
||||
client2.Set(apis.RealtimeClientAuthKey, authRecord1)
|
||||
testApp.SubscriptionsBroker().Register(client2)
|
||||
|
||||
client3 := subscriptions.NewDefaultClient()
|
||||
client3.Set(apis.RealtimeClientAuthKey, authRecord2)
|
||||
testApp.SubscriptionsBroker().Register(client3)
|
||||
|
||||
// mock delete event
|
||||
e := new(core.ModelEvent)
|
||||
e.App = testApp
|
||||
e.Type = core.ModelEventTypeDelete
|
||||
e.Context = context.Background()
|
||||
e.Model = authRecord1
|
||||
|
||||
testApp.OnModelAfterDeleteSuccess().Trigger(e)
|
||||
|
||||
if total := len(testApp.SubscriptionsBroker().Clients()); total != 3 {
|
||||
t.Fatalf("Expected %d subscription clients, found %d", 3, total)
|
||||
}
|
||||
|
||||
if auth := client1.Get(apis.RealtimeClientAuthKey); auth != nil {
|
||||
t.Fatalf("[client1] Expected the auth state to be unset, found %#v", auth)
|
||||
}
|
||||
|
||||
if auth := client2.Get(apis.RealtimeClientAuthKey); auth != nil {
|
||||
t.Fatalf("[client2] Expected the auth state to be unset, found %#v", auth)
|
||||
}
|
||||
|
||||
if auth := client3.Get(apis.RealtimeClientAuthKey); auth == nil || auth.(*core.Record).Id != authRecord2.Id {
|
||||
t.Fatalf("[client3] Expected the auth state to be left unchanged, found %#v", auth)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRealtimeAuthRecordUpdateEvent(t *testing.T) {
|
||||
testApp, _ := tests.NewTestApp()
|
||||
defer testApp.Cleanup()
|
||||
|
||||
// init realtime handlers
|
||||
apis.NewRouter(testApp)
|
||||
|
||||
authRecord1, err := testApp.FindAuthRecordByEmail("users", "test@example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
client := subscriptions.NewDefaultClient()
|
||||
client.Set(apis.RealtimeClientAuthKey, authRecord1)
|
||||
testApp.SubscriptionsBroker().Register(client)
|
||||
|
||||
// refetch the authRecord and change its email
|
||||
authRecord2, err := testApp.FindAuthRecordByEmail("users", "test@example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
authRecord2.SetEmail("new@example.com")
|
||||
|
||||
// mock update event
|
||||
e := new(core.ModelEvent)
|
||||
e.App = testApp
|
||||
e.Type = core.ModelEventTypeUpdate
|
||||
e.Context = context.Background()
|
||||
e.Model = authRecord2
|
||||
|
||||
testApp.OnModelAfterUpdateSuccess().Trigger(e)
|
||||
|
||||
clientAuthRecord, _ := client.Get(apis.RealtimeClientAuthKey).(*core.Record)
|
||||
if clientAuthRecord.Email() != authRecord2.Email() {
|
||||
t.Fatalf("Expected authRecord with email %q, got %q", authRecord2.Email(), clientAuthRecord.Email())
|
||||
}
|
||||
}
|
||||
|
||||
// Custom auth record model struct
|
||||
// -------------------------------------------------------------------
|
||||
var _ core.Model = (*CustomUser)(nil)
|
||||
|
||||
type CustomUser struct {
|
||||
core.BaseModel
|
||||
|
||||
Email string `db:"email" json:"email"`
|
||||
}
|
||||
|
||||
func (m *CustomUser) TableName() string {
|
||||
return "users"
|
||||
}
|
||||
|
||||
func findCustomUserByEmail(app core.App, email string) (*CustomUser, error) {
|
||||
model := &CustomUser{}
|
||||
|
||||
err := app.ModelQuery(model).
|
||||
AndWhere(dbx.HashExp{"email": email}).
|
||||
Limit(1).
|
||||
One(model)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return model, nil
|
||||
}
|
||||
|
||||
func TestRealtimeCustomAuthModelDeleteEvent(t *testing.T) {
|
||||
testApp, _ := tests.NewTestApp()
|
||||
defer testApp.Cleanup()
|
||||
|
||||
// init realtime handlers
|
||||
apis.NewRouter(testApp)
|
||||
|
||||
authRecord1, err := testApp.FindAuthRecordByEmail("users", "test@example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
authRecord2, err := testApp.FindAuthRecordByEmail("users", "test2@example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
client1 := subscriptions.NewDefaultClient()
|
||||
client1.Set(apis.RealtimeClientAuthKey, authRecord1)
|
||||
testApp.SubscriptionsBroker().Register(client1)
|
||||
|
||||
client2 := subscriptions.NewDefaultClient()
|
||||
client2.Set(apis.RealtimeClientAuthKey, authRecord1)
|
||||
testApp.SubscriptionsBroker().Register(client2)
|
||||
|
||||
client3 := subscriptions.NewDefaultClient()
|
||||
client3.Set(apis.RealtimeClientAuthKey, authRecord2)
|
||||
testApp.SubscriptionsBroker().Register(client3)
|
||||
|
||||
// refetch the authRecord as CustomUser
|
||||
customUser, err := findCustomUserByEmail(testApp, authRecord1.Email())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// delete the custom user (should unset the client auth record)
|
||||
if err := testApp.Delete(customUser); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if total := len(testApp.SubscriptionsBroker().Clients()); total != 3 {
|
||||
t.Fatalf("Expected %d subscription clients, found %d", 3, total)
|
||||
}
|
||||
|
||||
if auth := client1.Get(apis.RealtimeClientAuthKey); auth != nil {
|
||||
t.Fatalf("[client1] Expected the auth state to be unset, found %#v", auth)
|
||||
}
|
||||
|
||||
if auth := client2.Get(apis.RealtimeClientAuthKey); auth != nil {
|
||||
t.Fatalf("[client2] Expected the auth state to be unset, found %#v", auth)
|
||||
}
|
||||
|
||||
if auth := client3.Get(apis.RealtimeClientAuthKey); auth == nil || auth.(*core.Record).Id != authRecord2.Id {
|
||||
t.Fatalf("[client3] Expected the auth state to be left unchanged, found %#v", auth)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRealtimeCustomAuthModelUpdateEvent(t *testing.T) {
|
||||
testApp, _ := tests.NewTestApp()
|
||||
defer testApp.Cleanup()
|
||||
|
||||
// init realtime handlers
|
||||
apis.NewRouter(testApp)
|
||||
|
||||
authRecord, err := testApp.FindAuthRecordByEmail("users", "test@example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
client := subscriptions.NewDefaultClient()
|
||||
client.Set(apis.RealtimeClientAuthKey, authRecord)
|
||||
testApp.SubscriptionsBroker().Register(client)
|
||||
|
||||
// refetch the authRecord as CustomUser
|
||||
customUser, err := findCustomUserByEmail(testApp, "test@example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// change its email
|
||||
customUser.Email = "new@example.com"
|
||||
if err := testApp.Save(customUser); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
clientAuthRecord, _ := client.Get(apis.RealtimeClientAuthKey).(*core.Record)
|
||||
if clientAuthRecord.Email() != customUser.Email {
|
||||
t.Fatalf("Expected authRecord with email %q, got %q", customUser.Email, clientAuthRecord.Email())
|
||||
}
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
|
||||
var _ core.Model = (*CustomModelResolve)(nil)
|
||||
|
||||
type CustomModelResolve struct {
|
||||
core.BaseModel
|
||||
tableName string
|
||||
|
||||
Created string `db:"created"`
|
||||
}
|
||||
|
||||
func (m *CustomModelResolve) TableName() string {
|
||||
return m.tableName
|
||||
}
|
||||
|
||||
func TestRealtimeRecordResolve(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
const testCollectionName = "realtime_test_collection"
|
||||
|
||||
testRecordId := core.GenerateDefaultRandomId()
|
||||
|
||||
client0 := subscriptions.NewDefaultClient()
|
||||
client0.Subscribe(testCollectionName + "/*")
|
||||
client0.Discard()
|
||||
// ---
|
||||
client1 := subscriptions.NewDefaultClient()
|
||||
client1.Subscribe(testCollectionName + "/*")
|
||||
// ---
|
||||
client2 := subscriptions.NewDefaultClient()
|
||||
client2.Subscribe(testCollectionName + "/" + testRecordId)
|
||||
// ---
|
||||
client3 := subscriptions.NewDefaultClient()
|
||||
client3.Subscribe("demo1/*")
|
||||
|
||||
scenarios := []struct {
|
||||
name string
|
||||
op func(testApp core.App) error
|
||||
expected map[string][]string // clientId -> [events]
|
||||
}{
|
||||
{
|
||||
"core.Record",
|
||||
func(testApp core.App) error {
|
||||
c, err := testApp.FindCollectionByNameOrId(testCollectionName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
r := core.NewRecord(c)
|
||||
r.Id = testRecordId
|
||||
|
||||
// create
|
||||
err = testApp.Save(r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// update
|
||||
err = testApp.Save(r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// delete
|
||||
err = testApp.Delete(r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
},
|
||||
map[string][]string{
|
||||
client1.Id(): {"create", "update", "delete"},
|
||||
client2.Id(): {"create", "update", "delete"},
|
||||
},
|
||||
},
|
||||
{
|
||||
"core.RecordProxy",
|
||||
func(testApp core.App) error {
|
||||
c, err := testApp.FindCollectionByNameOrId(testCollectionName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
r := core.NewRecord(c)
|
||||
|
||||
proxy := &struct {
|
||||
core.BaseRecordProxy
|
||||
}{}
|
||||
proxy.SetProxyRecord(r)
|
||||
proxy.Id = testRecordId
|
||||
|
||||
// create
|
||||
err = testApp.Save(proxy)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// update
|
||||
err = testApp.Save(proxy)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// delete
|
||||
err = testApp.Delete(proxy)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
},
|
||||
map[string][]string{
|
||||
client1.Id(): {"create", "update", "delete"},
|
||||
client2.Id(): {"create", "update", "delete"},
|
||||
},
|
||||
},
|
||||
{
|
||||
"custom model struct",
|
||||
func(testApp core.App) error {
|
||||
m := &CustomModelResolve{tableName: testCollectionName}
|
||||
m.Id = testRecordId
|
||||
|
||||
// create
|
||||
err := testApp.Save(m)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// update
|
||||
m.Created = "123"
|
||||
err = testApp.Save(m)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// delete
|
||||
err = testApp.Delete(m)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
},
|
||||
map[string][]string{
|
||||
client1.Id(): {"create", "update", "delete"},
|
||||
client2.Id(): {"create", "update", "delete"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
t.Run(s.name, func(t *testing.T) {
|
||||
testApp, _ := tests.NewTestApp()
|
||||
defer testApp.Cleanup()
|
||||
|
||||
// init realtime handlers
|
||||
apis.NewRouter(testApp)
|
||||
|
||||
// create new test collection with public read access
|
||||
testCollection := core.NewBaseCollection(testCollectionName)
|
||||
testCollection.Fields.Add(&core.AutodateField{Name: "created", OnCreate: true, OnUpdate: true})
|
||||
testCollection.ListRule = types.Pointer("")
|
||||
testCollection.ViewRule = types.Pointer("")
|
||||
err := testApp.Save(testCollection)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
testApp.SubscriptionsBroker().Register(client0)
|
||||
testApp.SubscriptionsBroker().Register(client1)
|
||||
testApp.SubscriptionsBroker().Register(client2)
|
||||
testApp.SubscriptionsBroker().Register(client3)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
|
||||
var notifications = map[string][]string{}
|
||||
|
||||
var mu sync.Mutex
|
||||
notify := func(clientId string, eventData []byte) {
|
||||
data := struct{ Action string }{}
|
||||
_ = json.Unmarshal(eventData, &data)
|
||||
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
|
||||
if notifications[clientId] == nil {
|
||||
notifications[clientId] = []string{}
|
||||
}
|
||||
notifications[clientId] = append(notifications[clientId], data.Action)
|
||||
}
|
||||
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
||||
timeout := time.After(250 * time.Millisecond)
|
||||
|
||||
for {
|
||||
select {
|
||||
case e, ok := <-client0.Channel():
|
||||
if ok {
|
||||
notify(client0.Id(), e.Data)
|
||||
}
|
||||
case e, ok := <-client1.Channel():
|
||||
if ok {
|
||||
notify(client1.Id(), e.Data)
|
||||
}
|
||||
case e, ok := <-client2.Channel():
|
||||
if ok {
|
||||
notify(client2.Id(), e.Data)
|
||||
}
|
||||
case e, ok := <-client3.Channel():
|
||||
if ok {
|
||||
notify(client3.Id(), e.Data)
|
||||
}
|
||||
case <-timeout:
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
err = s.op(testApp)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
if len(s.expected) != len(notifications) {
|
||||
t.Fatalf("Expected %d notified clients, got %d:\n%v", len(s.expected), len(notifications), notifications)
|
||||
}
|
||||
|
||||
for id, events := range s.expected {
|
||||
if len(events) != len(notifications[id]) {
|
||||
t.Fatalf("[%s] Expected %d events, got %d:\n%v\n%v", id, len(events), len(notifications[id]), s.expected, notifications)
|
||||
}
|
||||
for _, event := range events {
|
||||
if !slices.Contains(notifications[id], event) {
|
||||
t.Fatalf("[%s] Missing expected event %q in %v", id, event, notifications[id])
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
79
apis/record_auth.go
Normal file
79
apis/record_auth.go
Normal file
|
@ -0,0 +1,79 @@
|
|||
package apis
|
||||
|
||||
import (
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tools/router"
|
||||
)
|
||||
|
||||
// bindRecordAuthApi registers the auth record api endpoints and
|
||||
// the corresponding handlers.
|
||||
func bindRecordAuthApi(app core.App, rg *router.RouterGroup[*core.RequestEvent]) {
|
||||
// global oauth2 subscription redirect handler
|
||||
rg.GET("/oauth2-redirect", oauth2SubscriptionRedirect).Bind(
|
||||
SkipSuccessActivityLog(), // skip success log as it could contain sensitive information in the url
|
||||
)
|
||||
// add again as POST in case of response_mode=form_post
|
||||
rg.POST("/oauth2-redirect", oauth2SubscriptionRedirect).Bind(
|
||||
SkipSuccessActivityLog(), // skip success log as it could contain sensitive information in the url
|
||||
)
|
||||
|
||||
sub := rg.Group("/collections/{collection}")
|
||||
|
||||
sub.GET("/auth-methods", recordAuthMethods).Bind(
|
||||
collectionPathRateLimit("", "listAuthMethods"),
|
||||
)
|
||||
|
||||
sub.POST("/auth-refresh", recordAuthRefresh).Bind(
|
||||
collectionPathRateLimit("", "authRefresh"),
|
||||
RequireSameCollectionContextAuth(""),
|
||||
)
|
||||
|
||||
sub.POST("/auth-with-password", recordAuthWithPassword).Bind(
|
||||
collectionPathRateLimit("", "authWithPassword", "auth"),
|
||||
)
|
||||
|
||||
sub.POST("/auth-with-oauth2", recordAuthWithOAuth2).Bind(
|
||||
collectionPathRateLimit("", "authWithOAuth2", "auth"),
|
||||
)
|
||||
|
||||
sub.POST("/request-otp", recordRequestOTP).Bind(
|
||||
collectionPathRateLimit("", "requestOTP"),
|
||||
)
|
||||
sub.POST("/auth-with-otp", recordAuthWithOTP).Bind(
|
||||
collectionPathRateLimit("", "authWithOTP", "auth"),
|
||||
)
|
||||
|
||||
sub.POST("/request-password-reset", recordRequestPasswordReset).Bind(
|
||||
collectionPathRateLimit("", "requestPasswordReset"),
|
||||
)
|
||||
sub.POST("/confirm-password-reset", recordConfirmPasswordReset).Bind(
|
||||
collectionPathRateLimit("", "confirmPasswordReset"),
|
||||
)
|
||||
|
||||
sub.POST("/request-verification", recordRequestVerification).Bind(
|
||||
collectionPathRateLimit("", "requestVerification"),
|
||||
)
|
||||
sub.POST("/confirm-verification", recordConfirmVerification).Bind(
|
||||
collectionPathRateLimit("", "confirmVerification"),
|
||||
)
|
||||
|
||||
sub.POST("/request-email-change", recordRequestEmailChange).Bind(
|
||||
collectionPathRateLimit("", "requestEmailChange"),
|
||||
RequireSameCollectionContextAuth(""),
|
||||
)
|
||||
sub.POST("/confirm-email-change", recordConfirmEmailChange).Bind(
|
||||
collectionPathRateLimit("", "confirmEmailChange"),
|
||||
)
|
||||
|
||||
sub.POST("/impersonate/{id}", recordAuthImpersonate).Bind(RequireSuperuserAuth())
|
||||
}
|
||||
|
||||
func findAuthCollection(e *core.RequestEvent) (*core.Collection, error) {
|
||||
collection, err := e.App.FindCachedCollectionByNameOrId(e.Request.PathValue("collection"))
|
||||
|
||||
if err != nil || !collection.IsAuth() {
|
||||
return nil, e.NotFoundError("Missing or invalid auth collection context.", err)
|
||||
}
|
||||
|
||||
return collection, nil
|
||||
}
|
122
apis/record_auth_email_change_confirm.go
Normal file
122
apis/record_auth_email_change_confirm.go
Normal file
|
@ -0,0 +1,122 @@
|
|||
package apis
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
validation "github.com/go-ozzo/ozzo-validation/v4"
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tools/security"
|
||||
)
|
||||
|
||||
func recordConfirmEmailChange(e *core.RequestEvent) error {
|
||||
collection, err := findAuthCollection(e)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if collection.Name == core.CollectionNameSuperusers {
|
||||
return e.BadRequestError("All superusers can change their emails directly.", nil)
|
||||
}
|
||||
|
||||
form := newEmailChangeConfirmForm(e.App, collection)
|
||||
if err = e.BindBody(form); err != nil {
|
||||
return firstApiError(err, e.BadRequestError("An error occurred while loading the submitted data.", err))
|
||||
}
|
||||
if err = form.validate(); err != nil {
|
||||
return firstApiError(err, e.BadRequestError("An error occurred while validating the submitted data.", err))
|
||||
}
|
||||
|
||||
authRecord, newEmail, err := form.parseToken()
|
||||
if err != nil {
|
||||
return firstApiError(err, e.BadRequestError("Invalid or expired token.", err))
|
||||
}
|
||||
|
||||
event := new(core.RecordConfirmEmailChangeRequestEvent)
|
||||
event.RequestEvent = e
|
||||
event.Collection = collection
|
||||
event.Record = authRecord
|
||||
event.NewEmail = newEmail
|
||||
|
||||
return e.App.OnRecordConfirmEmailChangeRequest().Trigger(event, func(e *core.RecordConfirmEmailChangeRequestEvent) error {
|
||||
e.Record.SetEmail(e.NewEmail)
|
||||
e.Record.SetVerified(true)
|
||||
|
||||
if err := e.App.Save(e.Record); err != nil {
|
||||
return firstApiError(err, e.BadRequestError("Failed to confirm email change.", err))
|
||||
}
|
||||
|
||||
return execAfterSuccessTx(true, e.App, func() error {
|
||||
return e.NoContent(http.StatusNoContent)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
|
||||
func newEmailChangeConfirmForm(app core.App, collection *core.Collection) *EmailChangeConfirmForm {
|
||||
return &EmailChangeConfirmForm{
|
||||
app: app,
|
||||
collection: collection,
|
||||
}
|
||||
}
|
||||
|
||||
type EmailChangeConfirmForm struct {
|
||||
app core.App
|
||||
collection *core.Collection
|
||||
|
||||
Token string `form:"token" json:"token"`
|
||||
Password string `form:"password" json:"password"`
|
||||
}
|
||||
|
||||
func (form *EmailChangeConfirmForm) validate() error {
|
||||
return validation.ValidateStruct(form,
|
||||
validation.Field(&form.Token, validation.Required, validation.By(form.checkToken)),
|
||||
validation.Field(&form.Password, validation.Required, validation.Length(1, 100), validation.By(form.checkPassword)),
|
||||
)
|
||||
}
|
||||
|
||||
func (form *EmailChangeConfirmForm) checkToken(value any) error {
|
||||
_, _, err := form.parseToken()
|
||||
return err
|
||||
}
|
||||
|
||||
func (form *EmailChangeConfirmForm) checkPassword(value any) error {
|
||||
v, _ := value.(string)
|
||||
if v == "" {
|
||||
return nil // nothing to check
|
||||
}
|
||||
|
||||
authRecord, _, _ := form.parseToken()
|
||||
if authRecord == nil || !authRecord.ValidatePassword(v) {
|
||||
return validation.NewError("validation_invalid_password", "Missing or invalid auth record password.")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (form *EmailChangeConfirmForm) parseToken() (*core.Record, string, error) {
|
||||
// check token payload
|
||||
claims, _ := security.ParseUnverifiedJWT(form.Token)
|
||||
newEmail, _ := claims[core.TokenClaimNewEmail].(string)
|
||||
if newEmail == "" {
|
||||
return nil, "", validation.NewError("validation_invalid_token_payload", "Invalid token payload - newEmail must be set.")
|
||||
}
|
||||
|
||||
// ensure that there aren't other users with the new email
|
||||
_, err := form.app.FindAuthRecordByEmail(form.collection, newEmail)
|
||||
if err == nil {
|
||||
return nil, "", validation.NewError("validation_existing_token_email", "The new email address is already registered: "+newEmail)
|
||||
}
|
||||
|
||||
// verify that the token is not expired and its signature is valid
|
||||
authRecord, err := form.app.FindAuthRecordByToken(form.Token, core.TokenTypeEmailChange)
|
||||
if err != nil {
|
||||
return nil, "", validation.NewError("validation_invalid_token", "Invalid or expired token.")
|
||||
}
|
||||
|
||||
if authRecord.Collection().Id != form.collection.Id {
|
||||
return nil, "", validation.NewError("validation_token_collection_mismatch", "The provided token is for different auth collection.")
|
||||
}
|
||||
|
||||
return authRecord, newEmail, nil
|
||||
}
|
211
apis/record_auth_email_change_confirm_test.go
Normal file
211
apis/record_auth_email_change_confirm_test.go
Normal file
|
@ -0,0 +1,211 @@
|
|||
package apis_test
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tests"
|
||||
)
|
||||
|
||||
func TestRecordConfirmEmailChange(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "not an auth collection",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/demo1/confirm-email-change",
|
||||
ExpectedStatus: 404,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "empty data",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/confirm-email-change",
|
||||
Body: strings.NewReader(``),
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{
|
||||
`"data":`,
|
||||
`"token":{"code":"validation_required"`,
|
||||
`"password":{"code":"validation_required"`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "invalid data",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/confirm-email-change",
|
||||
Body: strings.NewReader(`{"token`),
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "expired token and correct password",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/confirm-email-change",
|
||||
Body: strings.NewReader(`{
|
||||
"token":"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImNvbGxlY3Rpb25JZCI6Il9wYl91c2Vyc19hdXRoXyIsInR5cGUiOiJlbWFpbENoYW5nZSIsImVtYWlsIjoidGVzdEBleGFtcGxlLmNvbSIsIm5ld0VtYWlsIjoiY2hhbmdlQGV4YW1wbGUuY29tIiwiZXhwIjoxNjQwOTkxNjYxfQ.dff842MO0mgRTHY8dktp0dqG9-7LGQOgRuiAbQpYBls",
|
||||
"password":"1234567890"
|
||||
}`),
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{
|
||||
`"data":{`,
|
||||
`"token":{`,
|
||||
`"code":"validation_invalid_token"`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "non-email change token",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/confirm-email-change",
|
||||
Body: strings.NewReader(`{
|
||||
"token":"eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
"password":"1234567890"
|
||||
}`),
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{
|
||||
`"data":{`,
|
||||
`"token":{`,
|
||||
`"code":"validation_invalid_token_payload"`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "valid token and incorrect password",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/confirm-email-change",
|
||||
Body: strings.NewReader(`{
|
||||
"token":"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImNvbGxlY3Rpb25JZCI6Il9wYl91c2Vyc19hdXRoXyIsInR5cGUiOiJlbWFpbENoYW5nZSIsImVtYWlsIjoidGVzdEBleGFtcGxlLmNvbSIsIm5ld0VtYWlsIjoiY2hhbmdlQGV4YW1wbGUuY29tIiwiZXhwIjoyNTI0NjA0NDYxfQ.Y7mVlaEPhJiNPoIvIqbIosZU4c4lEhwysOrRR8c95iU",
|
||||
"password":"1234567891"
|
||||
}`),
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{
|
||||
`"data":{`,
|
||||
`"password":{`,
|
||||
`"code":"validation_invalid_password"`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "valid token and correct password",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/confirm-email-change",
|
||||
Body: strings.NewReader(`{
|
||||
"token":"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImNvbGxlY3Rpb25JZCI6Il9wYl91c2Vyc19hdXRoXyIsInR5cGUiOiJlbWFpbENoYW5nZSIsImVtYWlsIjoidGVzdEBleGFtcGxlLmNvbSIsIm5ld0VtYWlsIjoiY2hhbmdlQGV4YW1wbGUuY29tIiwiZXhwIjoyNTI0NjA0NDYxfQ.Y7mVlaEPhJiNPoIvIqbIosZU4c4lEhwysOrRR8c95iU",
|
||||
"password":"1234567890"
|
||||
}`),
|
||||
ExpectedStatus: 204,
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordConfirmEmailChangeRequest": 1,
|
||||
"OnModelUpdate": 1,
|
||||
"OnModelUpdateExecute": 1,
|
||||
"OnModelAfterUpdateSuccess": 1,
|
||||
"OnModelValidate": 1,
|
||||
"OnRecordUpdate": 1,
|
||||
"OnRecordUpdateExecute": 1,
|
||||
"OnRecordAfterUpdateSuccess": 1,
|
||||
"OnRecordValidate": 1,
|
||||
},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
_, err := app.FindAuthRecordByEmail("users", "change@example.com")
|
||||
if err != nil {
|
||||
t.Fatalf("Expected to find user with email %q, got error: %v", "change@example.com", err)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "valid token in different auth collection",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/clients/confirm-email-change",
|
||||
Body: strings.NewReader(`{
|
||||
"token":"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImNvbGxlY3Rpb25JZCI6Il9wYl91c2Vyc19hdXRoXyIsInR5cGUiOiJlbWFpbENoYW5nZSIsImVtYWlsIjoidGVzdEBleGFtcGxlLmNvbSIsIm5ld0VtYWlsIjoiY2hhbmdlQGV4YW1wbGUuY29tIiwiZXhwIjoyNTI0NjA0NDYxfQ.Y7mVlaEPhJiNPoIvIqbIosZU4c4lEhwysOrRR8c95iU",
|
||||
"password":"1234567890"
|
||||
}`),
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{
|
||||
`"data":{`,
|
||||
`"token":{"code":"validation_token_collection_mismatch"`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "OnRecordConfirmEmailChangeRequest tx body write check",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/confirm-email-change",
|
||||
Body: strings.NewReader(`{
|
||||
"token":"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImNvbGxlY3Rpb25JZCI6Il9wYl91c2Vyc19hdXRoXyIsInR5cGUiOiJlbWFpbENoYW5nZSIsImVtYWlsIjoidGVzdEBleGFtcGxlLmNvbSIsIm5ld0VtYWlsIjoiY2hhbmdlQGV4YW1wbGUuY29tIiwiZXhwIjoyNTI0NjA0NDYxfQ.Y7mVlaEPhJiNPoIvIqbIosZU4c4lEhwysOrRR8c95iU",
|
||||
"password":"1234567890"
|
||||
}`),
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
app.OnRecordConfirmEmailChangeRequest().BindFunc(func(e *core.RecordConfirmEmailChangeRequestEvent) error {
|
||||
original := e.App
|
||||
return e.App.RunInTransaction(func(txApp core.App) error {
|
||||
e.App = txApp
|
||||
defer func() { e.App = original }()
|
||||
|
||||
if err := e.Next(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return e.BadRequestError("TX_ERROR", nil)
|
||||
})
|
||||
})
|
||||
},
|
||||
ExpectedStatus: 400,
|
||||
ExpectedEvents: map[string]int{"OnRecordConfirmEmailChangeRequest": 1},
|
||||
ExpectedContent: []string{"TX_ERROR"},
|
||||
},
|
||||
|
||||
// rate limit checks
|
||||
// -----------------------------------------------------------
|
||||
{
|
||||
Name: "RateLimit rule - users:confirmEmailChange",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/confirm-email-change",
|
||||
Body: strings.NewReader(`{
|
||||
"token":"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImNvbGxlY3Rpb25JZCI6Il9wYl91c2Vyc19hdXRoXyIsInR5cGUiOiJlbWFpbENoYW5nZSIsImVtYWlsIjoidGVzdEBleGFtcGxlLmNvbSIsIm5ld0VtYWlsIjoiY2hhbmdlQGV4YW1wbGUuY29tIiwiZXhwIjoyNTI0NjA0NDYxfQ.Y7mVlaEPhJiNPoIvIqbIosZU4c4lEhwysOrRR8c95iU",
|
||||
"password":"1234567890"
|
||||
}`),
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
app.Settings().RateLimits.Enabled = true
|
||||
app.Settings().RateLimits.Rules = []core.RateLimitRule{
|
||||
{MaxRequests: 100, Label: "abc"},
|
||||
{MaxRequests: 100, Label: "*:confirmEmailChange"},
|
||||
{MaxRequests: 0, Label: "users:confirmEmailChange"},
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 429,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "RateLimit rule - *:confirmEmailChange",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/confirm-email-change",
|
||||
Body: strings.NewReader(`{
|
||||
"token":"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImNvbGxlY3Rpb25JZCI6Il9wYl91c2Vyc19hdXRoXyIsInR5cGUiOiJlbWFpbENoYW5nZSIsImVtYWlsIjoidGVzdEBleGFtcGxlLmNvbSIsIm5ld0VtYWlsIjoiY2hhbmdlQGV4YW1wbGUuY29tIiwiZXhwIjoyNTI0NjA0NDYxfQ.Y7mVlaEPhJiNPoIvIqbIosZU4c4lEhwysOrRR8c95iU",
|
||||
"password":"1234567890"
|
||||
}`),
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
app.Settings().RateLimits.Enabled = true
|
||||
app.Settings().RateLimits.Rules = []core.RateLimitRule{
|
||||
{MaxRequests: 100, Label: "abc"},
|
||||
{MaxRequests: 0, Label: "*:confirmEmailChange"},
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 429,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
92
apis/record_auth_email_change_request.go
Normal file
92
apis/record_auth_email_change_request.go
Normal file
|
@ -0,0 +1,92 @@
|
|||
package apis
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
validation "github.com/go-ozzo/ozzo-validation/v4"
|
||||
"github.com/go-ozzo/ozzo-validation/v4/is"
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/mails"
|
||||
)
|
||||
|
||||
func recordRequestEmailChange(e *core.RequestEvent) error {
|
||||
collection, err := findAuthCollection(e)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if collection.Name == core.CollectionNameSuperusers {
|
||||
return e.BadRequestError("All superusers can change their emails directly.", nil)
|
||||
}
|
||||
|
||||
record := e.Auth
|
||||
if record == nil {
|
||||
return e.UnauthorizedError("The request requires valid auth record.", nil)
|
||||
}
|
||||
|
||||
form := newEmailChangeRequestForm(e.App, record)
|
||||
if err = e.BindBody(form); err != nil {
|
||||
return firstApiError(err, e.BadRequestError("An error occurred while loading the submitted data.", err))
|
||||
}
|
||||
if err = form.validate(); err != nil {
|
||||
return firstApiError(err, e.BadRequestError("An error occurred while validating the submitted data.", err))
|
||||
}
|
||||
|
||||
event := new(core.RecordRequestEmailChangeRequestEvent)
|
||||
event.RequestEvent = e
|
||||
event.Collection = collection
|
||||
event.Record = record
|
||||
event.NewEmail = form.NewEmail
|
||||
|
||||
return e.App.OnRecordRequestEmailChangeRequest().Trigger(event, func(e *core.RecordRequestEmailChangeRequestEvent) error {
|
||||
if err := mails.SendRecordChangeEmail(e.App, e.Record, e.NewEmail); err != nil {
|
||||
return firstApiError(err, e.BadRequestError("Failed to request email change.", err))
|
||||
}
|
||||
|
||||
return execAfterSuccessTx(true, e.App, func() error {
|
||||
return e.NoContent(http.StatusNoContent)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
|
||||
func newEmailChangeRequestForm(app core.App, record *core.Record) *emailChangeRequestForm {
|
||||
return &emailChangeRequestForm{
|
||||
app: app,
|
||||
record: record,
|
||||
}
|
||||
}
|
||||
|
||||
type emailChangeRequestForm struct {
|
||||
app core.App
|
||||
record *core.Record
|
||||
|
||||
NewEmail string `form:"newEmail" json:"newEmail"`
|
||||
}
|
||||
|
||||
func (form *emailChangeRequestForm) validate() error {
|
||||
return validation.ValidateStruct(form,
|
||||
validation.Field(&form.NewEmail,
|
||||
validation.Required,
|
||||
validation.Length(1, 255),
|
||||
is.EmailFormat,
|
||||
validation.NotIn(form.record.Email()),
|
||||
validation.By(form.checkUniqueEmail),
|
||||
),
|
||||
)
|
||||
}
|
||||
|
||||
func (form *emailChangeRequestForm) checkUniqueEmail(value any) error {
|
||||
v, _ := value.(string)
|
||||
if v == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
found, _ := form.app.FindAuthRecordByEmail(form.record.Collection(), v)
|
||||
if found != nil && found.Id != form.record.Id {
|
||||
return validation.NewError("validation_invalid_new_email", "Invalid new email address.")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
195
apis/record_auth_email_change_request_test.go
Normal file
195
apis/record_auth_email_change_request_test.go
Normal file
|
@ -0,0 +1,195 @@
|
|||
package apis_test
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tests"
|
||||
)
|
||||
|
||||
func TestRecordRequestEmailChange(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "unauthorized",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/request-email-change",
|
||||
Body: strings.NewReader(`{"newEmail":"change@example.com"}`),
|
||||
ExpectedStatus: 401,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "not an auth collection",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/demo1/request-email-change",
|
||||
Body: strings.NewReader(`{"newEmail":"change@example.com"}`),
|
||||
ExpectedStatus: 401,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "record authentication but from different auth collection",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/clients/request-email-change",
|
||||
Body: strings.NewReader(`{"newEmail":"change@example.com"}`),
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "superuser authentication",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/request-email-change",
|
||||
Body: strings.NewReader(`{"newEmail":"change@example.com"}`),
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
|
||||
},
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "invalid data",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/request-email-change",
|
||||
Body: strings.NewReader(`{"newEmail`),
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "empty data",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/request-email-change",
|
||||
Body: strings.NewReader(`{}`),
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{
|
||||
`"data":`,
|
||||
`"newEmail":{"code":"validation_required"`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "valid data (existing email)",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/request-email-change",
|
||||
Body: strings.NewReader(`{"newEmail":"test2@example.com"}`),
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{
|
||||
`"data":`,
|
||||
`"newEmail":{"code":"validation_invalid_new_email"`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "valid data (new email)",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/request-email-change",
|
||||
Body: strings.NewReader(`{"newEmail":"change@example.com"}`),
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
ExpectedStatus: 204,
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordRequestEmailChangeRequest": 1,
|
||||
"OnMailerSend": 1,
|
||||
"OnMailerRecordEmailChangeSend": 1,
|
||||
},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
if !strings.Contains(app.TestMailer.LastMessage().HTML, "/auth/confirm-email-change") {
|
||||
t.Fatalf("Expected email change email, got\n%v", app.TestMailer.LastMessage().HTML)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "OnRecordRequestEmailChangeRequest tx body write check",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/request-email-change",
|
||||
Body: strings.NewReader(`{"newEmail":"change@example.com"}`),
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
app.OnRecordRequestEmailChangeRequest().BindFunc(func(e *core.RecordRequestEmailChangeRequestEvent) error {
|
||||
original := e.App
|
||||
return e.App.RunInTransaction(func(txApp core.App) error {
|
||||
e.App = txApp
|
||||
defer func() { e.App = original }()
|
||||
|
||||
if err := e.Next(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return e.BadRequestError("TX_ERROR", nil)
|
||||
})
|
||||
})
|
||||
},
|
||||
ExpectedStatus: 400,
|
||||
ExpectedEvents: map[string]int{"OnRecordRequestEmailChangeRequest": 1},
|
||||
ExpectedContent: []string{"TX_ERROR"},
|
||||
},
|
||||
|
||||
// rate limit checks
|
||||
// -----------------------------------------------------------
|
||||
{
|
||||
Name: "RateLimit rule - users:requestEmailChange",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/request-email-change",
|
||||
Body: strings.NewReader(`{"newEmail":"change@example.com"}`),
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
app.Settings().RateLimits.Enabled = true
|
||||
app.Settings().RateLimits.Rules = []core.RateLimitRule{
|
||||
{MaxRequests: 100, Label: "abc"},
|
||||
{MaxRequests: 100, Label: "*:requestEmailChange"},
|
||||
{MaxRequests: 0, Label: "users:requestEmailChange"},
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 429,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "RateLimit rule - *:requestEmailChange",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/request-email-change",
|
||||
Body: strings.NewReader(`{"newEmail":"change@example.com"}`),
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
app.Settings().RateLimits.Enabled = true
|
||||
app.Settings().RateLimits.Rules = []core.RateLimitRule{
|
||||
{MaxRequests: 100, Label: "abc"},
|
||||
{MaxRequests: 0, Label: "*:requestEmailChange"},
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 429,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
54
apis/record_auth_impersonate.go
Normal file
54
apis/record_auth_impersonate.go
Normal file
|
@ -0,0 +1,54 @@
|
|||
package apis
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
validation "github.com/go-ozzo/ozzo-validation/v4"
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
)
|
||||
|
||||
// note: for now allow superusers but it may change in the future to allow access
|
||||
// also to users with "Manage API" rule access depending on the use cases that will arise
|
||||
func recordAuthImpersonate(e *core.RequestEvent) error {
|
||||
if !e.HasSuperuserAuth() {
|
||||
return e.ForbiddenError("", nil)
|
||||
}
|
||||
|
||||
collection, err := findAuthCollection(e)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
record, err := e.App.FindRecordById(collection, e.Request.PathValue("id"))
|
||||
if err != nil {
|
||||
return e.NotFoundError("", err)
|
||||
}
|
||||
|
||||
form := &impersonateForm{}
|
||||
if err = e.BindBody(form); err != nil {
|
||||
return e.BadRequestError("An error occurred while loading the submitted data.", err)
|
||||
}
|
||||
if err = form.validate(); err != nil {
|
||||
return e.BadRequestError("An error occurred while validating the submitted data.", err)
|
||||
}
|
||||
|
||||
token, err := record.NewStaticAuthToken(time.Duration(form.Duration) * time.Second)
|
||||
if err != nil {
|
||||
e.InternalServerError("Failed to generate static auth token", err)
|
||||
}
|
||||
|
||||
return recordAuthResponse(e, record, token, "", nil)
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
|
||||
type impersonateForm struct {
|
||||
// Duration is the optional custom token duration in seconds.
|
||||
Duration int64 `form:"duration" json:"duration"`
|
||||
}
|
||||
|
||||
func (form *impersonateForm) validate() error {
|
||||
return validation.ValidateStruct(form,
|
||||
validation.Field(&form.Duration, validation.Min(0)),
|
||||
)
|
||||
}
|
109
apis/record_auth_impersonate_test.go
Normal file
109
apis/record_auth_impersonate_test.go
Normal file
|
@ -0,0 +1,109 @@
|
|||
package apis_test
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/pocketbase/pocketbase/tests"
|
||||
)
|
||||
|
||||
func TestRecordAuthImpersonate(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "unauthorized",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/impersonate/4q1xlclmfloku33",
|
||||
ExpectedStatus: 401,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "authorized as different user",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/impersonate/4q1xlclmfloku33",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6Im9hcDY0MGNvdDR5cnUycyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.GfJo6EHIobgas_AXt-M-tj5IoQendPnrkMSe9ExuSEY",
|
||||
},
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "authorized as the same user",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/impersonate/4q1xlclmfloku33",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "authorized as superuser",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/impersonate/4q1xlclmfloku33",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
`"token":"`,
|
||||
`"id":"4q1xlclmfloku33"`,
|
||||
`"record":{`,
|
||||
},
|
||||
NotExpectedContent: []string{
|
||||
// hidden fields should remain hidden even though we are authenticated as superuser
|
||||
`"tokenKey"`,
|
||||
`"password"`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordAuthRequest": 1,
|
||||
"OnRecordEnrich": 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "authorized as superuser with custom invalid duration",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/impersonate/4q1xlclmfloku33",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
|
||||
},
|
||||
Body: strings.NewReader(`{"duration":-1}`),
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{
|
||||
`"data":{`,
|
||||
`"duration":{`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "authorized as superuser with custom valid duration",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/impersonate/4q1xlclmfloku33",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
|
||||
},
|
||||
Body: strings.NewReader(`{"duration":100}`),
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
`"token":"`,
|
||||
`"id":"4q1xlclmfloku33"`,
|
||||
`"record":{`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordAuthRequest": 1,
|
||||
"OnRecordEnrich": 1,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
170
apis/record_auth_methods.go
Normal file
170
apis/record_auth_methods.go
Normal file
|
@ -0,0 +1,170 @@
|
|||
package apis
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"slices"
|
||||
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tools/auth"
|
||||
"github.com/pocketbase/pocketbase/tools/security"
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
type otpResponse struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
Duration int64 `json:"duration"` // in seconds
|
||||
}
|
||||
|
||||
type mfaResponse struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
Duration int64 `json:"duration"` // in seconds
|
||||
}
|
||||
|
||||
type passwordResponse struct {
|
||||
IdentityFields []string `json:"identityFields"`
|
||||
Enabled bool `json:"enabled"`
|
||||
}
|
||||
|
||||
type oauth2Response struct {
|
||||
Providers []providerInfo `json:"providers"`
|
||||
Enabled bool `json:"enabled"`
|
||||
}
|
||||
|
||||
type providerInfo struct {
|
||||
Name string `json:"name"`
|
||||
DisplayName string `json:"displayName"`
|
||||
State string `json:"state"`
|
||||
AuthURL string `json:"authURL"`
|
||||
|
||||
// @todo
|
||||
// deprecated: use AuthURL instead
|
||||
// AuthUrl will be removed after dropping v0.22 support
|
||||
AuthUrl string `json:"authUrl"`
|
||||
|
||||
// technically could be omitted if the provider doesn't support PKCE,
|
||||
// but to avoid breaking existing typed clients we'll return them as empty string
|
||||
CodeVerifier string `json:"codeVerifier"`
|
||||
CodeChallenge string `json:"codeChallenge"`
|
||||
CodeChallengeMethod string `json:"codeChallengeMethod"`
|
||||
}
|
||||
|
||||
type authMethodsResponse struct {
|
||||
Password passwordResponse `json:"password"`
|
||||
OAuth2 oauth2Response `json:"oauth2"`
|
||||
MFA mfaResponse `json:"mfa"`
|
||||
OTP otpResponse `json:"otp"`
|
||||
|
||||
// legacy fields
|
||||
// @todo remove after dropping v0.22 support
|
||||
AuthProviders []providerInfo `json:"authProviders"`
|
||||
UsernamePassword bool `json:"usernamePassword"`
|
||||
EmailPassword bool `json:"emailPassword"`
|
||||
}
|
||||
|
||||
func (amr *authMethodsResponse) fillLegacyFields() {
|
||||
amr.EmailPassword = amr.Password.Enabled && slices.Contains(amr.Password.IdentityFields, "email")
|
||||
|
||||
amr.UsernamePassword = amr.Password.Enabled && slices.Contains(amr.Password.IdentityFields, "username")
|
||||
|
||||
if amr.OAuth2.Enabled {
|
||||
amr.AuthProviders = amr.OAuth2.Providers
|
||||
}
|
||||
}
|
||||
|
||||
func recordAuthMethods(e *core.RequestEvent) error {
|
||||
collection, err := findAuthCollection(e)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
result := authMethodsResponse{
|
||||
Password: passwordResponse{
|
||||
IdentityFields: make([]string, 0, len(collection.PasswordAuth.IdentityFields)),
|
||||
},
|
||||
OAuth2: oauth2Response{
|
||||
Providers: make([]providerInfo, 0, len(collection.OAuth2.Providers)),
|
||||
},
|
||||
OTP: otpResponse{
|
||||
Enabled: collection.OTP.Enabled,
|
||||
},
|
||||
MFA: mfaResponse{
|
||||
Enabled: collection.MFA.Enabled,
|
||||
},
|
||||
}
|
||||
|
||||
if collection.PasswordAuth.Enabled {
|
||||
result.Password.Enabled = true
|
||||
result.Password.IdentityFields = collection.PasswordAuth.IdentityFields
|
||||
}
|
||||
|
||||
if collection.OTP.Enabled {
|
||||
result.OTP.Duration = collection.OTP.Duration
|
||||
}
|
||||
|
||||
if collection.MFA.Enabled {
|
||||
result.MFA.Duration = collection.MFA.Duration
|
||||
}
|
||||
|
||||
if !collection.OAuth2.Enabled {
|
||||
result.fillLegacyFields()
|
||||
|
||||
return e.JSON(http.StatusOK, result)
|
||||
}
|
||||
|
||||
result.OAuth2.Enabled = true
|
||||
|
||||
for _, config := range collection.OAuth2.Providers {
|
||||
provider, err := config.InitProvider()
|
||||
if err != nil {
|
||||
e.App.Logger().Debug(
|
||||
"Failed to setup OAuth2 provider",
|
||||
slog.String("name", config.Name),
|
||||
slog.String("error", err.Error()),
|
||||
)
|
||||
continue // skip provider
|
||||
}
|
||||
|
||||
info := providerInfo{
|
||||
Name: config.Name,
|
||||
DisplayName: provider.DisplayName(),
|
||||
State: security.RandomString(30),
|
||||
}
|
||||
|
||||
if info.DisplayName == "" {
|
||||
info.DisplayName = config.Name
|
||||
}
|
||||
|
||||
urlOpts := []oauth2.AuthCodeOption{}
|
||||
|
||||
// custom providers url options
|
||||
switch config.Name {
|
||||
case auth.NameApple:
|
||||
// see https://developer.apple.com/documentation/sign_in_with_apple/sign_in_with_apple_js/incorporating_sign_in_with_apple_into_other_platforms#3332113
|
||||
urlOpts = append(urlOpts, oauth2.SetAuthURLParam("response_mode", "form_post"))
|
||||
}
|
||||
|
||||
if provider.PKCE() {
|
||||
info.CodeVerifier = security.RandomString(43)
|
||||
info.CodeChallenge = security.S256Challenge(info.CodeVerifier)
|
||||
info.CodeChallengeMethod = "S256"
|
||||
urlOpts = append(urlOpts,
|
||||
oauth2.SetAuthURLParam("code_challenge", info.CodeChallenge),
|
||||
oauth2.SetAuthURLParam("code_challenge_method", info.CodeChallengeMethod),
|
||||
)
|
||||
}
|
||||
|
||||
info.AuthURL = provider.BuildAuthURL(
|
||||
info.State,
|
||||
urlOpts...,
|
||||
) + "&redirect_uri=" // empty redirect_uri so that users can append their redirect url
|
||||
|
||||
info.AuthUrl = info.AuthURL
|
||||
|
||||
result.OAuth2.Providers = append(result.OAuth2.Providers, info)
|
||||
}
|
||||
|
||||
result.fillLegacyFields()
|
||||
|
||||
return e.JSON(http.StatusOK, result)
|
||||
}
|
106
apis/record_auth_methods_test.go
Normal file
106
apis/record_auth_methods_test.go
Normal file
|
@ -0,0 +1,106 @@
|
|||
package apis_test
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tests"
|
||||
)
|
||||
|
||||
func TestRecordAuthMethodsList(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "missing collection",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/collections/missing/auth-methods",
|
||||
ExpectedStatus: 404,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "non auth collection",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/collections/demo1/auth-methods",
|
||||
ExpectedStatus: 404,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "auth collection with none auth methods allowed",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/collections/nologin/auth-methods",
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
`"password":{"identityFields":[],"enabled":false}`,
|
||||
`"oauth2":{"providers":[],"enabled":false}`,
|
||||
`"mfa":{"enabled":false,"duration":0}`,
|
||||
`"otp":{"enabled":false,"duration":0}`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "auth collection with all auth methods allowed",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/collections/users/auth-methods",
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
`"password":{"identityFields":["email","username"],"enabled":true}`,
|
||||
`"mfa":{"enabled":true,"duration":1800}`,
|
||||
`"otp":{"enabled":true,"duration":300}`,
|
||||
`"oauth2":{`,
|
||||
`"providers":[{`,
|
||||
`"name":"google"`,
|
||||
`"name":"gitlab"`,
|
||||
`"state":`,
|
||||
`"displayName":`,
|
||||
`"codeVerifier":`,
|
||||
`"codeChallenge":`,
|
||||
`"codeChallengeMethod":`,
|
||||
`"authURL":`,
|
||||
`redirect_uri="`, // ensures that the redirect_uri is the last url param
|
||||
},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
|
||||
// rate limit checks
|
||||
// -----------------------------------------------------------
|
||||
{
|
||||
Name: "RateLimit rule - nologin:listAuthMethods",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/collections/nologin/auth-methods",
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
app.Settings().RateLimits.Enabled = true
|
||||
app.Settings().RateLimits.Rules = []core.RateLimitRule{
|
||||
{MaxRequests: 100, Label: "abc"},
|
||||
{MaxRequests: 100, Label: "*:listAuthMethods"},
|
||||
{MaxRequests: 0, Label: "nologin:listAuthMethods"},
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 429,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "RateLimit rule - *:listAuthMethods",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/collections/nologin/auth-methods",
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
app.Settings().RateLimits.Enabled = true
|
||||
app.Settings().RateLimits.Rules = []core.RateLimitRule{
|
||||
{MaxRequests: 100, Label: "abc"},
|
||||
{MaxRequests: 0, Label: "*:listAuthMethods"},
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 429,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
127
apis/record_auth_otp_request.go
Normal file
127
apis/record_auth_otp_request.go
Normal file
|
@ -0,0 +1,127 @@
|
|||
package apis
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
validation "github.com/go-ozzo/ozzo-validation/v4"
|
||||
"github.com/go-ozzo/ozzo-validation/v4/is"
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/mails"
|
||||
"github.com/pocketbase/pocketbase/tools/routine"
|
||||
"github.com/pocketbase/pocketbase/tools/security"
|
||||
)
|
||||
|
||||
func recordRequestOTP(e *core.RequestEvent) error {
|
||||
collection, err := findAuthCollection(e)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !collection.OTP.Enabled {
|
||||
return e.ForbiddenError("The collection is not configured to allow OTP authentication.", nil)
|
||||
}
|
||||
|
||||
form := &createOTPForm{}
|
||||
if err = e.BindBody(form); err != nil {
|
||||
return firstApiError(err, e.BadRequestError("An error occurred while loading the submitted data.", err))
|
||||
}
|
||||
if err = form.validate(); err != nil {
|
||||
return firstApiError(err, e.BadRequestError("An error occurred while validating the submitted data.", err))
|
||||
}
|
||||
|
||||
record, err := e.App.FindAuthRecordByEmail(collection, form.Email)
|
||||
|
||||
// ignore not found errors to allow custom record find implementations
|
||||
if err != nil && !errors.Is(err, sql.ErrNoRows) {
|
||||
return e.InternalServerError("", err)
|
||||
}
|
||||
|
||||
event := new(core.RecordCreateOTPRequestEvent)
|
||||
event.RequestEvent = e
|
||||
event.Password = security.RandomStringWithAlphabet(collection.OTP.Length, "1234567890")
|
||||
event.Collection = collection
|
||||
event.Record = record
|
||||
|
||||
originalApp := e.App
|
||||
|
||||
return e.App.OnRecordRequestOTPRequest().Trigger(event, func(e *core.RecordCreateOTPRequestEvent) error {
|
||||
if e.Record == nil {
|
||||
// write a dummy 200 response as a very rudimentary emails enumeration "protection"
|
||||
e.JSON(http.StatusOK, map[string]string{
|
||||
"otpId": core.GenerateDefaultRandomId(),
|
||||
})
|
||||
|
||||
return fmt.Errorf("missing or invalid %s OTP auth record with email %s", collection.Name, form.Email)
|
||||
}
|
||||
|
||||
var otp *core.OTP
|
||||
|
||||
// limit the new OTP creations for a single user
|
||||
if !e.App.IsDev() {
|
||||
otps, err := e.App.FindAllOTPsByRecord(e.Record)
|
||||
if err != nil {
|
||||
return firstApiError(err, e.InternalServerError("Failed to fetch previous record OTPs.", err))
|
||||
}
|
||||
|
||||
totalRecent := 0
|
||||
for _, existingOTP := range otps {
|
||||
if !existingOTP.HasExpired(collection.OTP.DurationTime()) {
|
||||
totalRecent++
|
||||
}
|
||||
// use the last issued one
|
||||
if totalRecent > 9 {
|
||||
otp = otps[0] // otps are DESC sorted
|
||||
e.App.Logger().Warn(
|
||||
"Too many OTP requests - reusing the last issued",
|
||||
"email", form.Email,
|
||||
"recordId", e.Record.Id,
|
||||
"otpId", existingOTP.Id,
|
||||
)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if otp == nil {
|
||||
// create new OTP
|
||||
// ---
|
||||
otp = core.NewOTP(e.App)
|
||||
otp.SetCollectionRef(e.Record.Collection().Id)
|
||||
otp.SetRecordRef(e.Record.Id)
|
||||
otp.SetPassword(e.Password)
|
||||
err = e.App.Save(otp)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// send OTP email
|
||||
// (in the background as a very basic timing attacks and emails enumeration protection)
|
||||
// ---
|
||||
routine.FireAndForget(func() {
|
||||
err = mails.SendRecordOTP(originalApp, e.Record, otp.Id, e.Password)
|
||||
if err != nil {
|
||||
originalApp.Logger().Error("Failed to send OTP email", "error", errors.Join(err, originalApp.Delete(otp)))
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
return execAfterSuccessTx(true, e.App, func() error {
|
||||
return e.JSON(http.StatusOK, map[string]string{"otpId": otp.Id})
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
|
||||
type createOTPForm struct {
|
||||
Email string `form:"email" json:"email"`
|
||||
}
|
||||
|
||||
func (form createOTPForm) validate() error {
|
||||
return validation.ValidateStruct(&form,
|
||||
validation.Field(&form.Email, validation.Required, validation.Length(1, 255), is.EmailFormat),
|
||||
)
|
||||
}
|
316
apis/record_auth_otp_request_test.go
Normal file
316
apis/record_auth_otp_request_test.go
Normal file
|
@ -0,0 +1,316 @@
|
|||
package apis_test
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tests"
|
||||
"github.com/pocketbase/pocketbase/tools/types"
|
||||
)
|
||||
|
||||
func TestRecordRequestOTP(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "not an auth collection",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/demo1/request-otp",
|
||||
Body: strings.NewReader(`{"email":"test@example.com"}`),
|
||||
ExpectedStatus: 404,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "auth collection with disabled otp",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/request-otp",
|
||||
Body: strings.NewReader(`{"email":"test@example.com"}`),
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
usersCol, err := app.FindCollectionByNameOrId("users")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
usersCol.OTP.Enabled = false
|
||||
|
||||
if err := app.Save(usersCol); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "empty body",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/request-otp",
|
||||
Body: strings.NewReader(``),
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{`"data":{"email":{"code":"validation_required","message":"Cannot be blank."}}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "invalid body",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/request-otp",
|
||||
Body: strings.NewReader(`{"email`),
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "invalid request data",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/request-otp",
|
||||
Body: strings.NewReader(`{"email":"invalid"}`),
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{
|
||||
`"data":{`,
|
||||
`"email":{"code":"validation_is_email`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "missing auth record",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/request-otp",
|
||||
Body: strings.NewReader(`{"email":"missing@example.com"}`),
|
||||
Delay: 100 * time.Millisecond,
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
`"otpId":"`, // some fake random generated string
|
||||
},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordRequestOTPRequest": 1,
|
||||
},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
if app.TestMailer.TotalSend() != 0 {
|
||||
t.Fatalf("Expected zero emails, got %d", app.TestMailer.TotalSend())
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "existing auth record (with < 9 non-expired)",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/request-otp",
|
||||
Body: strings.NewReader(`{"email":"test@example.com"}`),
|
||||
Delay: 100 * time.Millisecond,
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
user, err := app.FindAuthRecordByEmail("users", "test@example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// insert 8 non-expired and 2 expired
|
||||
for i := 0; i < 10; i++ {
|
||||
otp := core.NewOTP(app)
|
||||
otp.Id = "otp_" + strconv.Itoa(i)
|
||||
otp.SetCollectionRef(user.Collection().Id)
|
||||
otp.SetRecordRef(user.Id)
|
||||
otp.SetPassword("123456")
|
||||
if i >= 8 {
|
||||
expiredDate := types.NowDateTime().AddDate(-3, 0, 0)
|
||||
otp.SetRaw("created", expiredDate)
|
||||
otp.SetRaw("updated", expiredDate)
|
||||
}
|
||||
if err := app.SaveNoValidate(otp); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
`"otpId":"`,
|
||||
},
|
||||
NotExpectedContent: []string{
|
||||
`"otpId":"otp_`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordRequestOTPRequest": 1,
|
||||
"OnMailerSend": 1,
|
||||
"OnMailerRecordOTPSend": 1,
|
||||
"OnModelCreate": 1,
|
||||
"OnModelCreateExecute": 1,
|
||||
"OnModelAfterCreateSuccess": 1,
|
||||
"OnModelValidate": 2, // + 1 for the OTP update after the email send
|
||||
"OnRecordCreate": 1,
|
||||
"OnRecordCreateExecute": 1,
|
||||
"OnRecordAfterCreateSuccess": 1,
|
||||
"OnRecordValidate": 2,
|
||||
// OTP update
|
||||
"OnModelUpdate": 1,
|
||||
"OnModelUpdateExecute": 1,
|
||||
"OnModelAfterUpdateSuccess": 1,
|
||||
"OnRecordUpdate": 1,
|
||||
"OnRecordUpdateExecute": 1,
|
||||
"OnRecordAfterUpdateSuccess": 1,
|
||||
},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
if app.TestMailer.TotalSend() != 1 {
|
||||
t.Fatalf("Expected 1 email, got %d", app.TestMailer.TotalSend())
|
||||
}
|
||||
|
||||
// ensure that sentTo is set
|
||||
otps, err := app.FindRecordsByFilter(core.CollectionNameOTPs, "sentTo='test@example.com'", "", 0, 0)
|
||||
if err != nil || len(otps) != 1 {
|
||||
t.Fatalf("Expected to find 1 OTP with sentTo %q, found %d", "test@example.com", len(otps))
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "existing auth record with intercepted email (with < 9 non-expired)",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/request-otp",
|
||||
Body: strings.NewReader(`{"email":"test@example.com"}`),
|
||||
Delay: 100 * time.Millisecond,
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
// prevent email sent
|
||||
app.OnMailerRecordOTPSend("users").BindFunc(func(e *core.MailerRecordEvent) error {
|
||||
return nil
|
||||
})
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
`"otpId":"`,
|
||||
},
|
||||
NotExpectedContent: []string{
|
||||
`"otpId":"otp_`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordRequestOTPRequest": 1,
|
||||
"OnMailerRecordOTPSend": 1,
|
||||
"OnModelCreate": 1,
|
||||
"OnModelCreateExecute": 1,
|
||||
"OnModelAfterCreateSuccess": 1,
|
||||
"OnModelValidate": 1,
|
||||
"OnRecordCreate": 1,
|
||||
"OnRecordCreateExecute": 1,
|
||||
"OnRecordAfterCreateSuccess": 1,
|
||||
"OnRecordValidate": 1,
|
||||
},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
if app.TestMailer.TotalSend() != 0 {
|
||||
t.Fatalf("Expected 0 emails, got %d", app.TestMailer.TotalSend())
|
||||
}
|
||||
|
||||
// ensure that there is no OTP with user email as sentTo
|
||||
otps, err := app.FindRecordsByFilter(core.CollectionNameOTPs, "sentTo='test@example.com'", "", 0, 0)
|
||||
if err != nil || len(otps) != 0 {
|
||||
t.Fatalf("Expected to find 0 OTPs with sentTo %q, found %d", "test@example.com", len(otps))
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "existing auth record (with > 9 non-expired)",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/request-otp",
|
||||
Body: strings.NewReader(`{"email":"test@example.com"}`),
|
||||
Delay: 100 * time.Millisecond,
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
user, err := app.FindAuthRecordByEmail("users", "test@example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// insert 10 non-expired
|
||||
for i := 0; i < 10; i++ {
|
||||
otp := core.NewOTP(app)
|
||||
otp.Id = "otp_" + strconv.Itoa(i)
|
||||
otp.SetCollectionRef(user.Collection().Id)
|
||||
otp.SetRecordRef(user.Id)
|
||||
otp.SetPassword("123456")
|
||||
if err := app.SaveNoValidate(otp); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
`"otpId":"otp_9"`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordRequestOTPRequest": 1,
|
||||
},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
if app.TestMailer.TotalSend() != 0 {
|
||||
t.Fatalf("Expected 0 sent emails, got %d", app.TestMailer.TotalSend())
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "OnRecordRequestOTPRequest tx body write check",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/request-otp",
|
||||
Body: strings.NewReader(`{"email":"test@example.com"}`),
|
||||
Delay: 100 * time.Millisecond,
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
app.OnRecordRequestOTPRequest().BindFunc(func(e *core.RecordCreateOTPRequestEvent) error {
|
||||
original := e.App
|
||||
return e.App.RunInTransaction(func(txApp core.App) error {
|
||||
e.App = txApp
|
||||
defer func() { e.App = original }()
|
||||
|
||||
if err := e.Next(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return e.BadRequestError("TX_ERROR", nil)
|
||||
})
|
||||
})
|
||||
},
|
||||
ExpectedStatus: 400,
|
||||
ExpectedEvents: map[string]int{"OnRecordRequestOTPRequest": 1},
|
||||
ExpectedContent: []string{"TX_ERROR"},
|
||||
},
|
||||
|
||||
// rate limit checks
|
||||
// -----------------------------------------------------------
|
||||
{
|
||||
Name: "RateLimit rule - users:requestOTP",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/request-otp",
|
||||
Body: strings.NewReader(`{"email":"test@example.com"}`),
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
app.Settings().RateLimits.Enabled = true
|
||||
app.Settings().RateLimits.Rules = []core.RateLimitRule{
|
||||
{MaxRequests: 100, Label: "abc"},
|
||||
{MaxRequests: 100, Label: "*:requestOTP"},
|
||||
{MaxRequests: 0, Label: "users:requestOTP"},
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 429,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "RateLimit rule - *:requestOTP",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/request-otp",
|
||||
Body: strings.NewReader(`{"email":"test@example.com"}`),
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
app.Settings().RateLimits.Enabled = true
|
||||
app.Settings().RateLimits.Rules = []core.RateLimitRule{
|
||||
{MaxRequests: 100, Label: "abc"},
|
||||
{MaxRequests: 0, Label: "*:requestOTP"},
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 429,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
104
apis/record_auth_password_reset_confirm.go
Normal file
104
apis/record_auth_password_reset_confirm.go
Normal file
|
@ -0,0 +1,104 @@
|
|||
package apis
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
validation "github.com/go-ozzo/ozzo-validation/v4"
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/core/validators"
|
||||
"github.com/pocketbase/pocketbase/tools/security"
|
||||
"github.com/spf13/cast"
|
||||
)
|
||||
|
||||
func recordConfirmPasswordReset(e *core.RequestEvent) error {
|
||||
collection, err := findAuthCollection(e)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
form := new(recordConfirmPasswordResetForm)
|
||||
form.app = e.App
|
||||
form.collection = collection
|
||||
if err = e.BindBody(form); err != nil {
|
||||
return e.BadRequestError("An error occurred while loading the submitted data.", err)
|
||||
}
|
||||
if err = form.validate(); err != nil {
|
||||
return firstApiError(err, e.BadRequestError("An error occurred while validating the submitted data.", err))
|
||||
}
|
||||
|
||||
authRecord, err := e.App.FindAuthRecordByToken(form.Token, core.TokenTypePasswordReset)
|
||||
if err != nil {
|
||||
return firstApiError(err, e.BadRequestError("Invalid or expired password reset token.", err))
|
||||
}
|
||||
|
||||
event := new(core.RecordConfirmPasswordResetRequestEvent)
|
||||
event.RequestEvent = e
|
||||
event.Collection = collection
|
||||
event.Record = authRecord
|
||||
|
||||
return e.App.OnRecordConfirmPasswordResetRequest().Trigger(event, func(e *core.RecordConfirmPasswordResetRequestEvent) error {
|
||||
authRecord.SetPassword(form.Password)
|
||||
|
||||
if !authRecord.Verified() {
|
||||
payload, err := security.ParseUnverifiedJWT(form.Token)
|
||||
if err == nil && authRecord.Email() == cast.ToString(payload[core.TokenClaimEmail]) {
|
||||
// mark as verified if the email hasn't changed
|
||||
authRecord.SetVerified(true)
|
||||
}
|
||||
}
|
||||
|
||||
err = e.App.Save(authRecord)
|
||||
if err != nil {
|
||||
return firstApiError(err, e.BadRequestError("Failed to set new password.", err))
|
||||
}
|
||||
|
||||
e.App.Store().Remove(getPasswordResetResendKey(authRecord))
|
||||
|
||||
return execAfterSuccessTx(true, e.App, func() error {
|
||||
return e.NoContent(http.StatusNoContent)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
|
||||
type recordConfirmPasswordResetForm struct {
|
||||
app core.App
|
||||
collection *core.Collection
|
||||
|
||||
Token string `form:"token" json:"token"`
|
||||
Password string `form:"password" json:"password"`
|
||||
PasswordConfirm string `form:"passwordConfirm" json:"passwordConfirm"`
|
||||
}
|
||||
|
||||
func (form *recordConfirmPasswordResetForm) validate() error {
|
||||
min := 1
|
||||
passField, ok := form.collection.Fields.GetByName(core.FieldNamePassword).(*core.PasswordField)
|
||||
if ok && passField != nil && passField.Min > 0 {
|
||||
min = passField.Min
|
||||
}
|
||||
|
||||
return validation.ValidateStruct(form,
|
||||
validation.Field(&form.Token, validation.Required, validation.By(form.checkToken)),
|
||||
validation.Field(&form.Password, validation.Required, validation.Length(min, 255)), // the FieldPassword validator will check further the specicic length constraints
|
||||
validation.Field(&form.PasswordConfirm, validation.Required, validation.By(validators.Equal(form.Password))),
|
||||
)
|
||||
}
|
||||
|
||||
func (form *recordConfirmPasswordResetForm) checkToken(value any) error {
|
||||
v, _ := value.(string)
|
||||
if v == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
record, err := form.app.FindAuthRecordByToken(v, core.TokenTypePasswordReset)
|
||||
if err != nil || record == nil {
|
||||
return validation.NewError("validation_invalid_token", "Invalid or expired token.")
|
||||
}
|
||||
|
||||
if record.Collection().Id != form.collection.Id {
|
||||
return validation.NewError("validation_token_collection_mismatch", "The provided token is for different auth collection.")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
360
apis/record_auth_password_reset_confirm_test.go
Normal file
360
apis/record_auth_password_reset_confirm_test.go
Normal file
|
@ -0,0 +1,360 @@
|
|||
package apis_test
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tests"
|
||||
)
|
||||
|
||||
func TestRecordConfirmPasswordReset(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "empty data",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/confirm-password-reset",
|
||||
Body: strings.NewReader(``),
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{
|
||||
`"data":{`,
|
||||
`"password":{"code":"validation_required"`,
|
||||
`"passwordConfirm":{"code":"validation_required"`,
|
||||
`"token":{"code":"validation_required"`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "invalid data format",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/confirm-password-reset",
|
||||
Body: strings.NewReader(`{"password`),
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "expired token and invalid password",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/confirm-password-reset",
|
||||
Body: strings.NewReader(`{
|
||||
"token":"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImV4cCI6MTY0MDk5MTY2MSwidHlwZSI6InBhc3N3b3JkUmVzZXQiLCJjb2xsZWN0aW9uSWQiOiJfcGJfdXNlcnNfYXV0aF8iLCJlbWFpbCI6InRlc3RAZXhhbXBsZS5jb20ifQ.5Tm6_6amQqOlX3urAnXlEdmxwG5qQJfiTg6U0hHR1hk",
|
||||
"password":"1234567",
|
||||
"passwordConfirm":"7654321"
|
||||
}`),
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{
|
||||
`"data":{`,
|
||||
`"token":{"code":"validation_invalid_token"`,
|
||||
`"password":{"code":"validation_length_out_of_range"`,
|
||||
`"passwordConfirm":{"code":"validation_values_mismatch"`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "non-password reset token",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/confirm-password-reset",
|
||||
Body: strings.NewReader(`{
|
||||
"token":"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImV4cCI6MjUyNDYwNDQ2MSwidHlwZSI6InZlcmlmaWNhdGlvbiIsImNvbGxlY3Rpb25JZCI6Il9wYl91c2Vyc19hdXRoXyIsImVtYWlsIjoidGVzdEBleGFtcGxlLmNvbSJ9.SetHpu2H-x-q4TIUz-xiQjwi7MNwLCLvSs4O0hUSp0E",
|
||||
"password":"1234567!",
|
||||
"passwordConfirm":"1234567!"
|
||||
}`),
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{
|
||||
`"data":{`,
|
||||
`"token":{"code":"validation_invalid_token"`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "non auth collection",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/demo1/confirm-password-reset?expand=rel,missing",
|
||||
Body: strings.NewReader(`{
|
||||
"token":"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImV4cCI6MjUyNDYwNDQ2MSwidHlwZSI6InBhc3N3b3JkUmVzZXQiLCJjb2xsZWN0aW9uSWQiOiJfcGJfdXNlcnNfYXV0aF8iLCJlbWFpbCI6InRlc3RAZXhhbXBsZS5jb20ifQ.xR-xq1oHDy0D8Q4NDOAEyYKGHWd_swzoiSoL8FLFBHY",
|
||||
"password":"1234567!",
|
||||
"passwordConfirm":"1234567!"
|
||||
}`),
|
||||
ExpectedStatus: 404,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "different auth collection",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/clients/confirm-password-reset?expand=rel,missing",
|
||||
Body: strings.NewReader(`{
|
||||
"token":"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImV4cCI6MjUyNDYwNDQ2MSwidHlwZSI6InBhc3N3b3JkUmVzZXQiLCJjb2xsZWN0aW9uSWQiOiJfcGJfdXNlcnNfYXV0aF8iLCJlbWFpbCI6InRlc3RAZXhhbXBsZS5jb20ifQ.xR-xq1oHDy0D8Q4NDOAEyYKGHWd_swzoiSoL8FLFBHY",
|
||||
"password":"1234567!",
|
||||
"passwordConfirm":"1234567!"
|
||||
}`),
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{
|
||||
`"data":{"token":{"code":"validation_token_collection_mismatch"`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "valid token and data (unverified user)",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/confirm-password-reset",
|
||||
Body: strings.NewReader(`{
|
||||
"token":"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImV4cCI6MjUyNDYwNDQ2MSwidHlwZSI6InBhc3N3b3JkUmVzZXQiLCJjb2xsZWN0aW9uSWQiOiJfcGJfdXNlcnNfYXV0aF8iLCJlbWFpbCI6InRlc3RAZXhhbXBsZS5jb20ifQ.xR-xq1oHDy0D8Q4NDOAEyYKGHWd_swzoiSoL8FLFBHY",
|
||||
"password":"1234567!",
|
||||
"passwordConfirm":"1234567!"
|
||||
}`),
|
||||
ExpectedStatus: 204,
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordConfirmPasswordResetRequest": 1,
|
||||
"OnModelUpdate": 1,
|
||||
"OnModelUpdateExecute": 1,
|
||||
"OnModelAfterUpdateSuccess": 1,
|
||||
"OnModelValidate": 1,
|
||||
"OnRecordUpdate": 1,
|
||||
"OnRecordUpdateExecute": 1,
|
||||
"OnRecordAfterUpdateSuccess": 1,
|
||||
"OnRecordValidate": 1,
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
user, err := app.FindAuthRecordByEmail("users", "test@example.com")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to fetch confirm password user: %v", err)
|
||||
}
|
||||
|
||||
if user.Verified() {
|
||||
t.Fatal("Expected the user to be unverified")
|
||||
}
|
||||
},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
_, err := app.FindAuthRecordByToken(
|
||||
"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImV4cCI6MjUyNDYwNDQ2MSwidHlwZSI6InBhc3N3b3JkUmVzZXQiLCJjb2xsZWN0aW9uSWQiOiJfcGJfdXNlcnNfYXV0aF8iLCJlbWFpbCI6InRlc3RAZXhhbXBsZS5jb20ifQ.xR-xq1oHDy0D8Q4NDOAEyYKGHWd_swzoiSoL8FLFBHY",
|
||||
core.TokenTypePasswordReset,
|
||||
)
|
||||
if err == nil {
|
||||
t.Fatal("Expected the password reset token to be invalidated")
|
||||
}
|
||||
|
||||
user, err := app.FindAuthRecordByEmail("users", "test@example.com")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to fetch confirm password user: %v", err)
|
||||
}
|
||||
|
||||
if !user.Verified() {
|
||||
t.Fatal("Expected the user to be marked as verified")
|
||||
}
|
||||
|
||||
if !user.ValidatePassword("1234567!") {
|
||||
t.Fatal("Password wasn't changed")
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "valid token and data (unverified user with different email from the one in the token)",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/confirm-password-reset",
|
||||
Body: strings.NewReader(`{
|
||||
"token":"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImV4cCI6MjUyNDYwNDQ2MSwidHlwZSI6InBhc3N3b3JkUmVzZXQiLCJjb2xsZWN0aW9uSWQiOiJfcGJfdXNlcnNfYXV0aF8iLCJlbWFpbCI6InRlc3RAZXhhbXBsZS5jb20ifQ.xR-xq1oHDy0D8Q4NDOAEyYKGHWd_swzoiSoL8FLFBHY",
|
||||
"password":"1234567!",
|
||||
"passwordConfirm":"1234567!"
|
||||
}`),
|
||||
ExpectedStatus: 204,
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordConfirmPasswordResetRequest": 1,
|
||||
"OnModelUpdate": 1,
|
||||
"OnModelUpdateExecute": 1,
|
||||
"OnModelAfterUpdateSuccess": 1,
|
||||
"OnModelValidate": 1,
|
||||
"OnRecordUpdate": 1,
|
||||
"OnRecordUpdateExecute": 1,
|
||||
"OnRecordAfterUpdateSuccess": 1,
|
||||
"OnRecordValidate": 1,
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
user, err := app.FindAuthRecordByEmail("users", "test@example.com")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to fetch confirm password user: %v", err)
|
||||
}
|
||||
|
||||
if user.Verified() {
|
||||
t.Fatal("Expected the user to be unverified")
|
||||
}
|
||||
|
||||
oldTokenKey := user.TokenKey()
|
||||
|
||||
// manually change the email to check whether the verified state will be updated
|
||||
user.SetEmail("test_update@example.com")
|
||||
if err = app.Save(user); err != nil {
|
||||
t.Fatalf("Failed to update user test email: %v", err)
|
||||
}
|
||||
|
||||
// resave with the old token key since the email change above
|
||||
// would change it and will make the password token invalid
|
||||
user.SetTokenKey(oldTokenKey)
|
||||
if err = app.Save(user); err != nil {
|
||||
t.Fatalf("Failed to restore original user tokenKey: %v", err)
|
||||
}
|
||||
},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
_, err := app.FindAuthRecordByToken(
|
||||
"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImV4cCI6MjUyNDYwNDQ2MSwidHlwZSI6InBhc3N3b3JkUmVzZXQiLCJjb2xsZWN0aW9uSWQiOiJfcGJfdXNlcnNfYXV0aF8iLCJlbWFpbCI6InRlc3RAZXhhbXBsZS5jb20ifQ.xR-xq1oHDy0D8Q4NDOAEyYKGHWd_swzoiSoL8FLFBHY",
|
||||
core.TokenTypePasswordReset,
|
||||
)
|
||||
if err == nil {
|
||||
t.Fatalf("Expected the password reset token to be invalidated")
|
||||
}
|
||||
|
||||
user, err := app.FindAuthRecordByEmail("users", "test_update@example.com")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to fetch confirm password user: %v", err)
|
||||
}
|
||||
|
||||
if user.Verified() {
|
||||
t.Fatal("Expected the user to remain unverified")
|
||||
}
|
||||
|
||||
if !user.ValidatePassword("1234567!") {
|
||||
t.Fatal("Password wasn't changed")
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "valid token and data (verified user)",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/confirm-password-reset",
|
||||
Body: strings.NewReader(`{
|
||||
"token":"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImV4cCI6MjUyNDYwNDQ2MSwidHlwZSI6InBhc3N3b3JkUmVzZXQiLCJjb2xsZWN0aW9uSWQiOiJfcGJfdXNlcnNfYXV0aF8iLCJlbWFpbCI6InRlc3RAZXhhbXBsZS5jb20ifQ.xR-xq1oHDy0D8Q4NDOAEyYKGHWd_swzoiSoL8FLFBHY",
|
||||
"password":"1234567!",
|
||||
"passwordConfirm":"1234567!"
|
||||
}`),
|
||||
ExpectedStatus: 204,
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordConfirmPasswordResetRequest": 1,
|
||||
"OnModelUpdate": 1,
|
||||
"OnModelUpdateExecute": 1,
|
||||
"OnModelAfterUpdateSuccess": 1,
|
||||
"OnModelValidate": 1,
|
||||
"OnRecordUpdate": 1,
|
||||
"OnRecordUpdateExecute": 1,
|
||||
"OnRecordAfterUpdateSuccess": 1,
|
||||
"OnRecordValidate": 1,
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
user, err := app.FindAuthRecordByEmail("users", "test@example.com")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to fetch confirm password user: %v", err)
|
||||
}
|
||||
|
||||
// ensure that the user is already verified
|
||||
user.SetVerified(true)
|
||||
if err := app.Save(user); err != nil {
|
||||
t.Fatalf("Failed to update user verified state")
|
||||
}
|
||||
},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
_, err := app.FindAuthRecordByToken(
|
||||
"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImV4cCI6MjUyNDYwNDQ2MSwidHlwZSI6InBhc3N3b3JkUmVzZXQiLCJjb2xsZWN0aW9uSWQiOiJfcGJfdXNlcnNfYXV0aF8iLCJlbWFpbCI6InRlc3RAZXhhbXBsZS5jb20ifQ.xR-xq1oHDy0D8Q4NDOAEyYKGHWd_swzoiSoL8FLFBHY",
|
||||
core.TokenTypePasswordReset,
|
||||
)
|
||||
if err == nil {
|
||||
t.Fatal("Expected the password reset token to be invalidated")
|
||||
}
|
||||
|
||||
user, err := app.FindAuthRecordByEmail("users", "test@example.com")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to fetch confirm password user: %v", err)
|
||||
}
|
||||
|
||||
if !user.Verified() {
|
||||
t.Fatal("Expected the user to remain verified")
|
||||
}
|
||||
|
||||
if !user.ValidatePassword("1234567!") {
|
||||
t.Fatal("Password wasn't changed")
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "OnRecordConfirmPasswordResetRequest tx body write check",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/confirm-password-reset",
|
||||
Body: strings.NewReader(`{
|
||||
"token":"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImV4cCI6MjUyNDYwNDQ2MSwidHlwZSI6InBhc3N3b3JkUmVzZXQiLCJjb2xsZWN0aW9uSWQiOiJfcGJfdXNlcnNfYXV0aF8iLCJlbWFpbCI6InRlc3RAZXhhbXBsZS5jb20ifQ.xR-xq1oHDy0D8Q4NDOAEyYKGHWd_swzoiSoL8FLFBHY",
|
||||
"password":"1234567!",
|
||||
"passwordConfirm":"1234567!"
|
||||
}`),
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
app.OnRecordConfirmPasswordResetRequest().BindFunc(func(e *core.RecordConfirmPasswordResetRequestEvent) error {
|
||||
original := e.App
|
||||
return e.App.RunInTransaction(func(txApp core.App) error {
|
||||
e.App = txApp
|
||||
defer func() { e.App = original }()
|
||||
|
||||
if err := e.Next(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return e.BadRequestError("TX_ERROR", nil)
|
||||
})
|
||||
})
|
||||
},
|
||||
ExpectedStatus: 400,
|
||||
ExpectedEvents: map[string]int{"OnRecordConfirmPasswordResetRequest": 1},
|
||||
ExpectedContent: []string{"TX_ERROR"},
|
||||
},
|
||||
|
||||
// rate limit checks
|
||||
// -----------------------------------------------------------
|
||||
{
|
||||
Name: "RateLimit rule - users:confirmPasswordReset",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/confirm-password-reset",
|
||||
Body: strings.NewReader(`{
|
||||
"token":"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImV4cCI6MjUyNDYwNDQ2MSwidHlwZSI6InBhc3N3b3JkUmVzZXQiLCJjb2xsZWN0aW9uSWQiOiJfcGJfdXNlcnNfYXV0aF8iLCJlbWFpbCI6InRlc3RAZXhhbXBsZS5jb20ifQ.xR-xq1oHDy0D8Q4NDOAEyYKGHWd_swzoiSoL8FLFBHY",
|
||||
"password":"1234567!",
|
||||
"passwordConfirm":"1234567!"
|
||||
}`),
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
app.Settings().RateLimits.Enabled = true
|
||||
app.Settings().RateLimits.Rules = []core.RateLimitRule{
|
||||
{MaxRequests: 100, Label: "abc"},
|
||||
{MaxRequests: 100, Label: "*:confirmPasswordReset"},
|
||||
{MaxRequests: 0, Label: "users:confirmPasswordReset"},
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 429,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "RateLimit rule - *:confirmPasswordReset",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/confirm-password-reset",
|
||||
Body: strings.NewReader(`{
|
||||
"token":"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImV4cCI6MjUyNDYwNDQ2MSwidHlwZSI6InBhc3N3b3JkUmVzZXQiLCJjb2xsZWN0aW9uSWQiOiJfcGJfdXNlcnNfYXV0aF8iLCJlbWFpbCI6InRlc3RAZXhhbXBsZS5jb20ifQ.xR-xq1oHDy0D8Q4NDOAEyYKGHWd_swzoiSoL8FLFBHY",
|
||||
"password":"1234567!",
|
||||
"passwordConfirm":"1234567!"
|
||||
}`),
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
app.Settings().RateLimits.Enabled = true
|
||||
app.Settings().RateLimits.Rules = []core.RateLimitRule{
|
||||
{MaxRequests: 100, Label: "abc"},
|
||||
{MaxRequests: 0, Label: "*:confirmPasswordReset"},
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 429,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
88
apis/record_auth_password_reset_request.go
Normal file
88
apis/record_auth_password_reset_request.go
Normal file
|
@ -0,0 +1,88 @@
|
|||
package apis
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
validation "github.com/go-ozzo/ozzo-validation/v4"
|
||||
"github.com/go-ozzo/ozzo-validation/v4/is"
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/mails"
|
||||
"github.com/pocketbase/pocketbase/tools/routine"
|
||||
)
|
||||
|
||||
func recordRequestPasswordReset(e *core.RequestEvent) error {
|
||||
collection, err := findAuthCollection(e)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !collection.PasswordAuth.Enabled {
|
||||
return e.BadRequestError("The collection is not configured to allow password authentication.", nil)
|
||||
}
|
||||
|
||||
form := new(recordRequestPasswordResetForm)
|
||||
if err = e.BindBody(form); err != nil {
|
||||
return firstApiError(err, e.BadRequestError("An error occurred while loading the submitted data.", err))
|
||||
}
|
||||
if err = form.validate(); err != nil {
|
||||
return firstApiError(err, e.BadRequestError("An error occurred while validating the submitted data.", err))
|
||||
}
|
||||
|
||||
record, err := e.App.FindAuthRecordByEmail(collection, form.Email)
|
||||
if err != nil {
|
||||
// eagerly write 204 response as a very basic measure against emails enumeration
|
||||
e.NoContent(http.StatusNoContent)
|
||||
return fmt.Errorf("failed to fetch %s record with email %s: %w", collection.Name, form.Email, err)
|
||||
}
|
||||
|
||||
resendKey := getPasswordResetResendKey(record)
|
||||
if e.App.Store().Has(resendKey) {
|
||||
// eagerly write 204 response as a very basic measure against emails enumeration
|
||||
e.NoContent(http.StatusNoContent)
|
||||
return errors.New("try again later - you've already requested a password reset email")
|
||||
}
|
||||
|
||||
event := new(core.RecordRequestPasswordResetRequestEvent)
|
||||
event.RequestEvent = e
|
||||
event.Collection = collection
|
||||
event.Record = record
|
||||
|
||||
return e.App.OnRecordRequestPasswordResetRequest().Trigger(event, func(e *core.RecordRequestPasswordResetRequestEvent) error {
|
||||
// run in background because we don't need to show the result to the client
|
||||
app := e.App
|
||||
routine.FireAndForget(func() {
|
||||
if err := mails.SendRecordPasswordReset(app, e.Record); err != nil {
|
||||
app.Logger().Error("Failed to send password reset email", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
app.Store().Set(resendKey, struct{}{})
|
||||
time.AfterFunc(2*time.Minute, func() {
|
||||
app.Store().Remove(resendKey)
|
||||
})
|
||||
})
|
||||
|
||||
return execAfterSuccessTx(true, e.App, func() error {
|
||||
return e.NoContent(http.StatusNoContent)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
|
||||
type recordRequestPasswordResetForm struct {
|
||||
Email string `form:"email" json:"email"`
|
||||
}
|
||||
|
||||
func (form *recordRequestPasswordResetForm) validate() error {
|
||||
return validation.ValidateStruct(form,
|
||||
validation.Field(&form.Email, validation.Required, validation.Length(1, 255), is.EmailFormat),
|
||||
)
|
||||
}
|
||||
|
||||
func getPasswordResetResendKey(record *core.Record) string {
|
||||
return "@limitPasswordResetEmail_" + record.Collection().Id + record.Id
|
||||
}
|
169
apis/record_auth_password_reset_request_test.go
Normal file
169
apis/record_auth_password_reset_request_test.go
Normal file
|
@ -0,0 +1,169 @@
|
|||
package apis_test
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tests"
|
||||
)
|
||||
|
||||
func TestRecordRequestPasswordReset(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "not an auth collection",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/demo1/request-password-reset",
|
||||
Body: strings.NewReader(``),
|
||||
ExpectedStatus: 404,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "empty data",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/request-password-reset",
|
||||
Body: strings.NewReader(``),
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{`"data":{"email":{"code":"validation_required","message":"Cannot be blank."}}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "invalid data",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/request-password-reset",
|
||||
Body: strings.NewReader(`{"email`),
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "existing auth record in a collection with disabled password login",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/nologin/request-password-reset",
|
||||
Body: strings.NewReader(`{"email":"test@example.com"}`),
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "missing auth record",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/request-password-reset",
|
||||
Body: strings.NewReader(`{"email":"missing@example.com"}`),
|
||||
Delay: 100 * time.Millisecond,
|
||||
ExpectedStatus: 204,
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
if app.TestMailer.TotalSend() != 0 {
|
||||
t.Fatalf("Expected zero emails, got %d", app.TestMailer.TotalSend())
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "existing auth record",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/request-password-reset",
|
||||
Body: strings.NewReader(`{"email":"test@example.com"}`),
|
||||
Delay: 100 * time.Millisecond,
|
||||
ExpectedStatus: 204,
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordRequestPasswordResetRequest": 1,
|
||||
"OnMailerSend": 1,
|
||||
"OnMailerRecordPasswordResetSend": 1,
|
||||
},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
if !strings.Contains(app.TestMailer.LastMessage().HTML, "/auth/confirm-password-reset") {
|
||||
t.Fatalf("Expected password reset email, got\n%v", app.TestMailer.LastMessage().HTML)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "existing auth record (after already sent)",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/request-password-reset",
|
||||
Body: strings.NewReader(`{"email":"test@example.com"}`),
|
||||
Delay: 100 * time.Millisecond,
|
||||
ExpectedStatus: 204,
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
// simulate recent verification sent
|
||||
authRecord, err := app.FindAuthRecordByEmail("users", "test@example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
resendKey := "@limitPasswordResetEmail_" + authRecord.Collection().Id + authRecord.Id
|
||||
app.Store().Set(resendKey, struct{}{})
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "OnRecordRequestPasswordResetRequest tx body write check",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/request-password-reset",
|
||||
Body: strings.NewReader(`{"email":"test@example.com"}`),
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
app.OnRecordRequestPasswordResetRequest().BindFunc(func(e *core.RecordRequestPasswordResetRequestEvent) error {
|
||||
original := e.App
|
||||
return e.App.RunInTransaction(func(txApp core.App) error {
|
||||
e.App = txApp
|
||||
defer func() { e.App = original }()
|
||||
|
||||
if err := e.Next(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return e.BadRequestError("TX_ERROR", nil)
|
||||
})
|
||||
})
|
||||
},
|
||||
ExpectedStatus: 400,
|
||||
ExpectedEvents: map[string]int{"OnRecordRequestPasswordResetRequest": 1},
|
||||
ExpectedContent: []string{"TX_ERROR"},
|
||||
},
|
||||
|
||||
// rate limit checks
|
||||
// -----------------------------------------------------------
|
||||
{
|
||||
Name: "RateLimit rule - users:requestPasswordReset",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/request-password-reset",
|
||||
Body: strings.NewReader(`{"email":"missing@example.com"}`),
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
app.Settings().RateLimits.Enabled = true
|
||||
app.Settings().RateLimits.Rules = []core.RateLimitRule{
|
||||
{MaxRequests: 100, Label: "abc"},
|
||||
{MaxRequests: 100, Label: "*:requestPasswordReset"},
|
||||
{MaxRequests: 0, Label: "users:requestPasswordReset"},
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 429,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "RateLimit rule - *:requestPasswordReset",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/request-password-reset",
|
||||
Body: strings.NewReader(`{"email":"missing@example.com"}`),
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
app.Settings().RateLimits.Enabled = true
|
||||
app.Settings().RateLimits.Rules = []core.RateLimitRule{
|
||||
{MaxRequests: 100, Label: "abc"},
|
||||
{MaxRequests: 0, Label: "*:requestPasswordReset"},
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 429,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
29
apis/record_auth_refresh.go
Normal file
29
apis/record_auth_refresh.go
Normal file
|
@ -0,0 +1,29 @@
|
|||
package apis
|
||||
|
||||
import (
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tools/security"
|
||||
"github.com/spf13/cast"
|
||||
)
|
||||
|
||||
func recordAuthRefresh(e *core.RequestEvent) error {
|
||||
record := e.Auth
|
||||
if record == nil {
|
||||
return e.NotFoundError("Missing auth record context.", nil)
|
||||
}
|
||||
|
||||
currentToken := getAuthTokenFromRequest(e)
|
||||
claims, _ := security.ParseUnverifiedJWT(currentToken)
|
||||
if v, ok := claims[core.TokenClaimRefreshable]; !ok || !cast.ToBool(v) {
|
||||
return e.ForbiddenError("The current auth token is not refreshable.", nil)
|
||||
}
|
||||
|
||||
event := new(core.RecordAuthRefreshRequestEvent)
|
||||
event.RequestEvent = e
|
||||
event.Collection = record.Collection()
|
||||
event.Record = record
|
||||
|
||||
return e.App.OnRecordAuthRefreshRequest().Trigger(event, func(e *core.RecordAuthRefreshRequestEvent) error {
|
||||
return RecordAuthResponse(e.RequestEvent, e.Record, "", nil)
|
||||
})
|
||||
}
|
202
apis/record_auth_refresh_test.go
Normal file
202
apis/record_auth_refresh_test.go
Normal file
|
@ -0,0 +1,202 @@
|
|||
package apis_test
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tests"
|
||||
)
|
||||
|
||||
func TestRecordAuthRefresh(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "unauthorized",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/auth-refresh",
|
||||
ExpectedStatus: 401,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "superuser trying to refresh the auth of another auth collection",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/auth-refresh",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
|
||||
},
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "auth record + not an auth collection",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/demo1/auth-refresh",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "auth record + different auth collection",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/clients/auth-refresh?expand=rel,missing",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "auth record + same auth collection as the token",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/auth-refresh?expand=rel,missing",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
`"token":`,
|
||||
`"record":`,
|
||||
`"id":"4q1xlclmfloku33"`,
|
||||
`"emailVisibility":false`,
|
||||
`"email":"test@example.com"`, // the owner can always view their email address
|
||||
`"expand":`,
|
||||
`"rel":`,
|
||||
`"id":"llvuca81nly1qls"`,
|
||||
},
|
||||
NotExpectedContent: []string{
|
||||
`"missing":`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordAuthRefreshRequest": 1,
|
||||
"OnRecordAuthRequest": 1,
|
||||
"OnRecordEnrich": 2,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "auth record + same auth collection as the token but static/unrefreshable",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/auth-refresh",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6ZmFsc2V9.4IsO6YMsR19crhwl_YWzvRH8pfq2Ri4Gv2dzGyneLak",
|
||||
},
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "unverified auth record in onlyVerified collection",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/clients/auth-refresh",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6Im8xeTBkZDBzcGQ3ODZtZCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoidjg1MXE0cjc5MHJoa25sIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.Zi0yXE-CNmnbTdVaQEzYZVuECqRdn3LgEM6pmB3XWBE",
|
||||
},
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordAuthRefreshRequest": 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "verified auth record in onlyVerified collection",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/clients/auth-refresh",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6ImdrMzkwcWVnczR5NDd3biIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoidjg1MXE0cjc5MHJoa25sIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.0ONnm_BsvPRZyDNT31GN1CKUB6uQRxvVvQ-Wc9AZfG0",
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
`"token":`,
|
||||
`"record":`,
|
||||
`"id":"gk390qegs4y47wn"`,
|
||||
`"verified":true`,
|
||||
`"email":"test@example.com"`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordAuthRefreshRequest": 1,
|
||||
"OnRecordAuthRequest": 1,
|
||||
"OnRecordEnrich": 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "OnRecordAuthRefreshRequest tx body write check",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/auth-refresh",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
app.OnRecordAuthRefreshRequest().BindFunc(func(e *core.RecordAuthRefreshRequestEvent) error {
|
||||
original := e.App
|
||||
return e.App.RunInTransaction(func(txApp core.App) error {
|
||||
e.App = txApp
|
||||
defer func() { e.App = original }()
|
||||
|
||||
if err := e.Next(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return e.BadRequestError("TX_ERROR", nil)
|
||||
})
|
||||
})
|
||||
},
|
||||
ExpectedStatus: 400,
|
||||
ExpectedEvents: map[string]int{"OnRecordAuthRefreshRequest": 1},
|
||||
ExpectedContent: []string{"TX_ERROR"},
|
||||
},
|
||||
|
||||
// rate limit checks
|
||||
// -----------------------------------------------------------
|
||||
{
|
||||
Name: "RateLimit rule - users:authRefresh",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/auth-refresh",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
app.Settings().RateLimits.Enabled = true
|
||||
app.Settings().RateLimits.Rules = []core.RateLimitRule{
|
||||
{MaxRequests: 100, Label: "abc"},
|
||||
{MaxRequests: 100, Label: "*:authRefresh"},
|
||||
{MaxRequests: 0, Label: "users:authRefresh"},
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 429,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "RateLimit rule - *:authRefresh",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/auth-refresh",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
app.Settings().RateLimits.Enabled = true
|
||||
app.Settings().RateLimits.Rules = []core.RateLimitRule{
|
||||
{MaxRequests: 100, Label: "abc"},
|
||||
{MaxRequests: 0, Label: "*:authRefresh"},
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 429,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
102
apis/record_auth_verification_confirm.go
Normal file
102
apis/record_auth_verification_confirm.go
Normal file
|
@ -0,0 +1,102 @@
|
|||
package apis
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
validation "github.com/go-ozzo/ozzo-validation/v4"
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tools/security"
|
||||
"github.com/spf13/cast"
|
||||
)
|
||||
|
||||
func recordConfirmVerification(e *core.RequestEvent) error {
|
||||
collection, err := findAuthCollection(e)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if collection.Name == core.CollectionNameSuperusers {
|
||||
return e.BadRequestError("All superusers are verified by default.", nil)
|
||||
}
|
||||
|
||||
form := new(recordConfirmVerificationForm)
|
||||
form.app = e.App
|
||||
form.collection = collection
|
||||
if err = e.BindBody(form); err != nil {
|
||||
return firstApiError(err, e.BadRequestError("An error occurred while loading the submitted data.", err))
|
||||
}
|
||||
if err = form.validate(); err != nil {
|
||||
return firstApiError(err, e.BadRequestError("An error occurred while validating the submitted data.", err))
|
||||
}
|
||||
|
||||
record, err := form.app.FindAuthRecordByToken(form.Token, core.TokenTypeVerification)
|
||||
if err != nil {
|
||||
return e.BadRequestError("Invalid or expired verification token.", err)
|
||||
}
|
||||
|
||||
wasVerified := record.Verified()
|
||||
|
||||
event := new(core.RecordConfirmVerificationRequestEvent)
|
||||
event.RequestEvent = e
|
||||
event.Collection = collection
|
||||
event.Record = record
|
||||
|
||||
return e.App.OnRecordConfirmVerificationRequest().Trigger(event, func(e *core.RecordConfirmVerificationRequestEvent) error {
|
||||
if !wasVerified {
|
||||
e.Record.SetVerified(true)
|
||||
|
||||
if err := e.App.Save(e.Record); err != nil {
|
||||
return firstApiError(err, e.BadRequestError("An error occurred while saving the verified state.", err))
|
||||
}
|
||||
}
|
||||
|
||||
e.App.Store().Remove(getVerificationResendKey(e.Record))
|
||||
|
||||
return execAfterSuccessTx(true, e.App, func() error {
|
||||
return e.NoContent(http.StatusNoContent)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
|
||||
type recordConfirmVerificationForm struct {
|
||||
app core.App
|
||||
collection *core.Collection
|
||||
|
||||
Token string `form:"token" json:"token"`
|
||||
}
|
||||
|
||||
func (form *recordConfirmVerificationForm) validate() error {
|
||||
return validation.ValidateStruct(form,
|
||||
validation.Field(&form.Token, validation.Required, validation.By(form.checkToken)),
|
||||
)
|
||||
}
|
||||
|
||||
func (form *recordConfirmVerificationForm) checkToken(value any) error {
|
||||
v, _ := value.(string)
|
||||
if v == "" {
|
||||
return nil // nothing to check
|
||||
}
|
||||
|
||||
claims, _ := security.ParseUnverifiedJWT(v)
|
||||
email := cast.ToString(claims["email"])
|
||||
if email == "" {
|
||||
return validation.NewError("validation_invalid_token_claims", "Missing email token claim.")
|
||||
}
|
||||
|
||||
record, err := form.app.FindAuthRecordByToken(v, core.TokenTypeVerification)
|
||||
if err != nil || record == nil {
|
||||
return validation.NewError("validation_invalid_token", "Invalid or expired token.")
|
||||
}
|
||||
|
||||
if record.Collection().Id != form.collection.Id {
|
||||
return validation.NewError("validation_token_collection_mismatch", "The provided token is for different auth collection.")
|
||||
}
|
||||
|
||||
if record.Email() != email {
|
||||
return validation.NewError("validation_token_email_mismatch", "The record email doesn't match with the requested token claims.")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
216
apis/record_auth_verification_confirm_test.go
Normal file
216
apis/record_auth_verification_confirm_test.go
Normal file
|
@ -0,0 +1,216 @@
|
|||
package apis_test
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tests"
|
||||
)
|
||||
|
||||
func TestRecordConfirmVerification(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "empty data",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/confirm-verification",
|
||||
Body: strings.NewReader(``),
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{
|
||||
`"data":{`,
|
||||
`"token":{"code":"validation_required"`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "invalid data format",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/confirm-verification",
|
||||
Body: strings.NewReader(`{"password`),
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "expired token",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/confirm-verification",
|
||||
Body: strings.NewReader(`{
|
||||
"token":"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImV4cCI6MTY0MDk5MTY2MSwidHlwZSI6InZlcmlmaWNhdGlvbiIsImNvbGxlY3Rpb25JZCI6Il9wYl91c2Vyc19hdXRoXyIsImVtYWlsIjoidGVzdEBleGFtcGxlLmNvbSJ9.qqelNNL2Udl6K_TJ282sNHYCpASgA6SIuSVKGfBHMZU"
|
||||
}`),
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{
|
||||
`"data":{`,
|
||||
`"token":{"code":"validation_invalid_token"`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "non-verification token",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/confirm-verification",
|
||||
Body: strings.NewReader(`{
|
||||
"token":"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImV4cCI6MjUyNDYwNDQ2MSwidHlwZSI6InBhc3N3b3JkUmVzZXQiLCJjb2xsZWN0aW9uSWQiOiJfcGJfdXNlcnNfYXV0aF8iLCJlbWFpbCI6InRlc3RAZXhhbXBsZS5jb20ifQ.xR-xq1oHDy0D8Q4NDOAEyYKGHWd_swzoiSoL8FLFBHY"
|
||||
}`),
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{
|
||||
`"data":{`,
|
||||
`"token":{"code":"validation_invalid_token"`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "non auth collection",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/demo1/confirm-verification?expand=rel,missing",
|
||||
Body: strings.NewReader(`{
|
||||
"token":"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImV4cCI6MjUyNDYwNDQ2MSwidHlwZSI6InZlcmlmaWNhdGlvbiIsImNvbGxlY3Rpb25JZCI6Il9wYl91c2Vyc19hdXRoXyIsImVtYWlsIjoidGVzdEBleGFtcGxlLmNvbSJ9.SetHpu2H-x-q4TIUz-xiQjwi7MNwLCLvSs4O0hUSp0E"
|
||||
}`),
|
||||
ExpectedStatus: 404,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "different auth collection",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/clients/confirm-verification?expand=rel,missing",
|
||||
Body: strings.NewReader(`{
|
||||
"token":"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImV4cCI6MjUyNDYwNDQ2MSwidHlwZSI6InZlcmlmaWNhdGlvbiIsImNvbGxlY3Rpb25JZCI6Il9wYl91c2Vyc19hdXRoXyIsImVtYWlsIjoidGVzdEBleGFtcGxlLmNvbSJ9.SetHpu2H-x-q4TIUz-xiQjwi7MNwLCLvSs4O0hUSp0E"
|
||||
}`),
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{
|
||||
`"data":{"token":{"code":"validation_token_collection_mismatch"`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "valid token",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/confirm-verification",
|
||||
Body: strings.NewReader(`{
|
||||
"token":"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImV4cCI6MjUyNDYwNDQ2MSwidHlwZSI6InZlcmlmaWNhdGlvbiIsImNvbGxlY3Rpb25JZCI6Il9wYl91c2Vyc19hdXRoXyIsImVtYWlsIjoidGVzdEBleGFtcGxlLmNvbSJ9.SetHpu2H-x-q4TIUz-xiQjwi7MNwLCLvSs4O0hUSp0E"
|
||||
}`),
|
||||
ExpectedStatus: 204,
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordConfirmVerificationRequest": 1,
|
||||
"OnModelUpdate": 1,
|
||||
"OnModelValidate": 1,
|
||||
"OnModelUpdateExecute": 1,
|
||||
"OnModelAfterUpdateSuccess": 1,
|
||||
"OnRecordUpdate": 1,
|
||||
"OnRecordValidate": 1,
|
||||
"OnRecordUpdateExecute": 1,
|
||||
"OnRecordAfterUpdateSuccess": 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "valid token (already verified)",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/confirm-verification",
|
||||
Body: strings.NewReader(`{
|
||||
"token":"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6Im9hcDY0MGNvdDR5cnUycyIsImV4cCI6MjUyNDYwNDQ2MSwidHlwZSI6InZlcmlmaWNhdGlvbiIsImNvbGxlY3Rpb25JZCI6Il9wYl91c2Vyc19hdXRoXyIsImVtYWlsIjoidGVzdDJAZXhhbXBsZS5jb20ifQ.QQmM3odNFVk6u4J4-5H8IBM3dfk9YCD7mPW-8PhBAI8"
|
||||
}`),
|
||||
ExpectedStatus: 204,
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordConfirmVerificationRequest": 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "valid verification token from a collection without allowed login",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/nologin/confirm-verification",
|
||||
Body: strings.NewReader(`{
|
||||
"token":"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6ImRjNDlrNmpnZWpuNDBoMyIsImV4cCI6MjUyNDYwNDQ2MSwidHlwZSI6InZlcmlmaWNhdGlvbiIsImNvbGxlY3Rpb25JZCI6ImtwdjcwOXNrMmxxYnFrOCIsImVtYWlsIjoidGVzdEBleGFtcGxlLmNvbSJ9.5GmuZr4vmwk3Cb_3ZZWNxwbE75KZC-j71xxIPR9AsVw"
|
||||
}`),
|
||||
ExpectedStatus: 204,
|
||||
ExpectedContent: []string{},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordConfirmVerificationRequest": 1,
|
||||
"OnModelUpdate": 1,
|
||||
"OnModelValidate": 1,
|
||||
"OnModelUpdateExecute": 1,
|
||||
"OnModelAfterUpdateSuccess": 1,
|
||||
"OnRecordUpdate": 1,
|
||||
"OnRecordValidate": 1,
|
||||
"OnRecordUpdateExecute": 1,
|
||||
"OnRecordAfterUpdateSuccess": 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "OnRecordConfirmVerificationRequest tx body write check",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/confirm-verification",
|
||||
Body: strings.NewReader(`{
|
||||
"token":"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImV4cCI6MjUyNDYwNDQ2MSwidHlwZSI6InZlcmlmaWNhdGlvbiIsImNvbGxlY3Rpb25JZCI6Il9wYl91c2Vyc19hdXRoXyIsImVtYWlsIjoidGVzdEBleGFtcGxlLmNvbSJ9.SetHpu2H-x-q4TIUz-xiQjwi7MNwLCLvSs4O0hUSp0E"
|
||||
}`),
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
app.OnRecordConfirmVerificationRequest().BindFunc(func(e *core.RecordConfirmVerificationRequestEvent) error {
|
||||
original := e.App
|
||||
return e.App.RunInTransaction(func(txApp core.App) error {
|
||||
e.App = txApp
|
||||
defer func() { e.App = original }()
|
||||
|
||||
if err := e.Next(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return e.BadRequestError("TX_ERROR", nil)
|
||||
})
|
||||
})
|
||||
},
|
||||
ExpectedStatus: 400,
|
||||
ExpectedEvents: map[string]int{"OnRecordConfirmVerificationRequest": 1},
|
||||
ExpectedContent: []string{"TX_ERROR"},
|
||||
},
|
||||
|
||||
// rate limit checks
|
||||
// -----------------------------------------------------------
|
||||
{
|
||||
Name: "RateLimit rule - nologin:confirmVerification",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/nologin/confirm-verification",
|
||||
Body: strings.NewReader(`{
|
||||
"token":"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6ImRjNDlrNmpnZWpuNDBoMyIsImV4cCI6MjUyNDYwNDQ2MSwidHlwZSI6InZlcmlmaWNhdGlvbiIsImNvbGxlY3Rpb25JZCI6ImtwdjcwOXNrMmxxYnFrOCIsImVtYWlsIjoidGVzdEBleGFtcGxlLmNvbSJ9.5GmuZr4vmwk3Cb_3ZZWNxwbE75KZC-j71xxIPR9AsVw"
|
||||
}`),
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
app.Settings().RateLimits.Enabled = true
|
||||
app.Settings().RateLimits.Rules = []core.RateLimitRule{
|
||||
{MaxRequests: 100, Label: "abc"},
|
||||
{MaxRequests: 100, Label: "*:confirmVerification"},
|
||||
{MaxRequests: 0, Label: "nologin:confirmVerification"},
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 429,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "RateLimit rule - *:confirmVerification",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/nologin/confirm-verification",
|
||||
Body: strings.NewReader(`{
|
||||
"token":"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6ImRjNDlrNmpnZWpuNDBoMyIsImV4cCI6MjUyNDYwNDQ2MSwidHlwZSI6InZlcmlmaWNhdGlvbiIsImNvbGxlY3Rpb25JZCI6ImtwdjcwOXNrMmxxYnFrOCIsImVtYWlsIjoidGVzdEBleGFtcGxlLmNvbSJ9.5GmuZr4vmwk3Cb_3ZZWNxwbE75KZC-j71xxIPR9AsVw"
|
||||
}`),
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
app.Settings().RateLimits.Enabled = true
|
||||
app.Settings().RateLimits.Rules = []core.RateLimitRule{
|
||||
{MaxRequests: 100, Label: "abc"},
|
||||
{MaxRequests: 0, Label: "*:confirmVerification"},
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 429,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
91
apis/record_auth_verification_request.go
Normal file
91
apis/record_auth_verification_request.go
Normal file
|
@ -0,0 +1,91 @@
|
|||
package apis
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
validation "github.com/go-ozzo/ozzo-validation/v4"
|
||||
"github.com/go-ozzo/ozzo-validation/v4/is"
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/mails"
|
||||
"github.com/pocketbase/pocketbase/tools/routine"
|
||||
)
|
||||
|
||||
func recordRequestVerification(e *core.RequestEvent) error {
|
||||
collection, err := findAuthCollection(e)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if collection.Name == core.CollectionNameSuperusers {
|
||||
return e.BadRequestError("All superusers are verified by default.", nil)
|
||||
}
|
||||
|
||||
form := new(recordRequestVerificationForm)
|
||||
if err = e.BindBody(form); err != nil {
|
||||
return firstApiError(err, e.BadRequestError("An error occurred while loading the submitted data.", err))
|
||||
}
|
||||
if err = form.validate(); err != nil {
|
||||
return firstApiError(err, e.BadRequestError("An error occurred while validating the submitted data.", err))
|
||||
}
|
||||
|
||||
record, err := e.App.FindAuthRecordByEmail(collection, form.Email)
|
||||
if err != nil {
|
||||
// eagerly write 204 response as a very basic measure against emails enumeration
|
||||
e.NoContent(http.StatusNoContent)
|
||||
return fmt.Errorf("failed to fetch %s record with email %s: %w", collection.Name, form.Email, err)
|
||||
}
|
||||
|
||||
resendKey := getVerificationResendKey(record)
|
||||
if !record.Verified() && e.App.Store().Has(resendKey) {
|
||||
// eagerly write 204 response as a very basic measure against emails enumeration
|
||||
e.NoContent(http.StatusNoContent)
|
||||
return errors.New("try again later - you've already requested a verification email")
|
||||
}
|
||||
|
||||
event := new(core.RecordRequestVerificationRequestEvent)
|
||||
event.RequestEvent = e
|
||||
event.Collection = collection
|
||||
event.Record = record
|
||||
|
||||
return e.App.OnRecordRequestVerificationRequest().Trigger(event, func(e *core.RecordRequestVerificationRequestEvent) error {
|
||||
if e.Record.Verified() {
|
||||
return e.NoContent(http.StatusNoContent)
|
||||
}
|
||||
|
||||
// run in background because we don't need to show the result to the client
|
||||
app := e.App
|
||||
routine.FireAndForget(func() {
|
||||
if err := mails.SendRecordVerification(app, e.Record); err != nil {
|
||||
app.Logger().Error("Failed to send verification email", "error", err)
|
||||
}
|
||||
|
||||
app.Store().Set(resendKey, struct{}{})
|
||||
time.AfterFunc(2*time.Minute, func() {
|
||||
app.Store().Remove(resendKey)
|
||||
})
|
||||
})
|
||||
|
||||
return execAfterSuccessTx(true, e.App, func() error {
|
||||
return e.NoContent(http.StatusNoContent)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
|
||||
type recordRequestVerificationForm struct {
|
||||
Email string `form:"email" json:"email"`
|
||||
}
|
||||
|
||||
func (form *recordRequestVerificationForm) validate() error {
|
||||
return validation.ValidateStruct(form,
|
||||
validation.Field(&form.Email, validation.Required, validation.Length(1, 255), is.EmailFormat),
|
||||
)
|
||||
}
|
||||
|
||||
func getVerificationResendKey(record *core.Record) string {
|
||||
return "@limitVerificationEmail_" + record.Collection().Id + record.Id
|
||||
}
|
186
apis/record_auth_verification_request_test.go
Normal file
186
apis/record_auth_verification_request_test.go
Normal file
|
@ -0,0 +1,186 @@
|
|||
package apis_test
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tests"
|
||||
)
|
||||
|
||||
func TestRecordRequestVerification(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "not an auth collection",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/demo1/request-verification",
|
||||
Body: strings.NewReader(``),
|
||||
ExpectedStatus: 404,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "empty data",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/request-verification",
|
||||
Body: strings.NewReader(``),
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{`"data":{"email":{"code":"validation_required","message":"Cannot be blank."}}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "invalid data",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/request-verification",
|
||||
Body: strings.NewReader(`{"email`),
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "missing auth record",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/request-verification",
|
||||
Body: strings.NewReader(`{"email":"missing@example.com"}`),
|
||||
Delay: 100 * time.Millisecond,
|
||||
ExpectedStatus: 204,
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
if app.TestMailer.TotalSend() != 0 {
|
||||
t.Fatalf("Expected zero emails, got %d", app.TestMailer.TotalSend())
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "already verified auth record",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/request-verification",
|
||||
Body: strings.NewReader(`{"email":"test2@example.com"}`),
|
||||
Delay: 100 * time.Millisecond,
|
||||
ExpectedStatus: 204,
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordRequestVerificationRequest": 1,
|
||||
},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
if app.TestMailer.TotalSend() != 0 {
|
||||
t.Fatalf("Expected zero emails, got %d", app.TestMailer.TotalSend())
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "existing auth record",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/request-verification",
|
||||
Body: strings.NewReader(`{"email":"test@example.com"}`),
|
||||
Delay: 100 * time.Millisecond,
|
||||
ExpectedStatus: 204,
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordRequestVerificationRequest": 1,
|
||||
"OnMailerSend": 1,
|
||||
"OnMailerRecordVerificationSend": 1,
|
||||
},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
if !strings.Contains(app.TestMailer.LastMessage().HTML, "/auth/confirm-verification") {
|
||||
t.Fatalf("Expected verification email, got\n%v", app.TestMailer.LastMessage().HTML)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "existing auth record (after already sent)",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/request-verification",
|
||||
Body: strings.NewReader(`{"email":"test@example.com"}`),
|
||||
Delay: 100 * time.Millisecond,
|
||||
ExpectedStatus: 204,
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
// terminated before firing the event
|
||||
// "OnRecordRequestVerificationRequest": 1,
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
// simulate recent verification sent
|
||||
authRecord, err := app.FindFirstRecordByData("users", "email", "test@example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
resendKey := "@limitVerificationEmail_" + authRecord.Collection().Id + authRecord.Id
|
||||
app.Store().Set(resendKey, struct{}{})
|
||||
},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
if app.TestMailer.TotalSend() != 0 {
|
||||
t.Fatalf("Expected zero emails, got %d", app.TestMailer.TotalSend())
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "OnRecordRequestVerificationRequest tx body write check",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/request-verification",
|
||||
Body: strings.NewReader(`{"email":"test@example.com"}`),
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
app.OnRecordRequestVerificationRequest().BindFunc(func(e *core.RecordRequestVerificationRequestEvent) error {
|
||||
original := e.App
|
||||
return e.App.RunInTransaction(func(txApp core.App) error {
|
||||
e.App = txApp
|
||||
defer func() { e.App = original }()
|
||||
|
||||
if err := e.Next(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return e.BadRequestError("TX_ERROR", nil)
|
||||
})
|
||||
})
|
||||
},
|
||||
ExpectedStatus: 400,
|
||||
ExpectedEvents: map[string]int{"OnRecordRequestVerificationRequest": 1},
|
||||
ExpectedContent: []string{"TX_ERROR"},
|
||||
},
|
||||
|
||||
// rate limit checks
|
||||
// -----------------------------------------------------------
|
||||
{
|
||||
Name: "RateLimit rule - users:requestVerification",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/request-verification",
|
||||
Body: strings.NewReader(`{"email":"test@example.com"}`),
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
app.Settings().RateLimits.Enabled = true
|
||||
app.Settings().RateLimits.Rules = []core.RateLimitRule{
|
||||
{MaxRequests: 100, Label: "abc"},
|
||||
{MaxRequests: 100, Label: "*:requestVerification"},
|
||||
{MaxRequests: 0, Label: "users:requestVerification"},
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 429,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "RateLimit rule - *:requestVerification",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/request-verification",
|
||||
Body: strings.NewReader(`{"email":"test@example.com"}`),
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
app.Settings().RateLimits.Enabled = true
|
||||
app.Settings().RateLimits.Rules = []core.RateLimitRule{
|
||||
{MaxRequests: 100, Label: "abc"},
|
||||
{MaxRequests: 0, Label: "*:requestVerification"},
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 429,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
374
apis/record_auth_with_oauth2.go
Normal file
374
apis/record_auth_with_oauth2.go
Normal file
|
@ -0,0 +1,374 @@
|
|||
package apis
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"maps"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
validation "github.com/go-ozzo/ozzo-validation/v4"
|
||||
"github.com/pocketbase/dbx"
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tools/dbutils"
|
||||
"github.com/pocketbase/pocketbase/tools/filesystem"
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
func recordAuthWithOAuth2(e *core.RequestEvent) error {
|
||||
collection, err := findAuthCollection(e)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !collection.OAuth2.Enabled {
|
||||
return e.ForbiddenError("The collection is not configured to allow OAuth2 authentication.", nil)
|
||||
}
|
||||
|
||||
var fallbackAuthRecord *core.Record
|
||||
if e.Auth != nil && e.Auth.Collection().Id == collection.Id {
|
||||
fallbackAuthRecord = e.Auth
|
||||
}
|
||||
|
||||
e.Set(core.RequestEventKeyInfoContext, core.RequestInfoContextOAuth2)
|
||||
|
||||
form := new(recordOAuth2LoginForm)
|
||||
form.collection = collection
|
||||
if err = e.BindBody(form); err != nil {
|
||||
return firstApiError(err, e.BadRequestError("An error occurred while loading the submitted data.", err))
|
||||
}
|
||||
|
||||
if form.RedirectUrl != "" && form.RedirectURL == "" {
|
||||
e.App.Logger().Warn("[recordAuthWithOAuth2] redirectUrl body param is deprecated and will be removed in the future. Please replace it with redirectURL.")
|
||||
form.RedirectURL = form.RedirectUrl
|
||||
}
|
||||
|
||||
if err = form.validate(); err != nil {
|
||||
return firstApiError(err, e.BadRequestError("An error occurred while loading the submitted data.", err))
|
||||
}
|
||||
|
||||
// exchange token for OAuth2 user info and locate existing ExternalAuth rel
|
||||
// ---------------------------------------------------------------
|
||||
|
||||
// load provider configuration
|
||||
providerConfig, ok := collection.OAuth2.GetProviderConfig(form.Provider)
|
||||
if !ok {
|
||||
return e.InternalServerError("Missing or invalid provider config.", nil)
|
||||
}
|
||||
|
||||
provider, err := providerConfig.InitProvider()
|
||||
if err != nil {
|
||||
return firstApiError(err, e.InternalServerError("Failed to init provider "+form.Provider, err))
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(e.Request.Context(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
provider.SetContext(ctx)
|
||||
provider.SetRedirectURL(form.RedirectURL)
|
||||
|
||||
var opts []oauth2.AuthCodeOption
|
||||
|
||||
if provider.PKCE() {
|
||||
opts = append(opts, oauth2.SetAuthURLParam("code_verifier", form.CodeVerifier))
|
||||
}
|
||||
|
||||
// fetch token
|
||||
token, err := provider.FetchToken(form.Code, opts...)
|
||||
if err != nil {
|
||||
return firstApiError(err, e.BadRequestError("Failed to fetch OAuth2 token.", err))
|
||||
}
|
||||
|
||||
// fetch external auth user
|
||||
authUser, err := provider.FetchAuthUser(token)
|
||||
if err != nil {
|
||||
return firstApiError(err, e.BadRequestError("Failed to fetch OAuth2 user.", err))
|
||||
}
|
||||
|
||||
var authRecord *core.Record
|
||||
|
||||
// check for existing relation with the auth collection
|
||||
externalAuthRel, err := e.App.FindFirstExternalAuthByExpr(dbx.HashExp{
|
||||
"collectionRef": form.collection.Id,
|
||||
"provider": form.Provider,
|
||||
"providerId": authUser.Id,
|
||||
})
|
||||
if err != nil && !errors.Is(err, sql.ErrNoRows) {
|
||||
return e.InternalServerError("Failed OAuth2 relation check.", err)
|
||||
}
|
||||
|
||||
switch {
|
||||
case err == nil && externalAuthRel != nil:
|
||||
authRecord, err = e.App.FindRecordById(form.collection, externalAuthRel.RecordRef())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
case fallbackAuthRecord != nil && fallbackAuthRecord.Collection().Id == form.collection.Id:
|
||||
// fallback to the logged auth record (if any)
|
||||
authRecord = fallbackAuthRecord
|
||||
case authUser.Email != "":
|
||||
// look for an existing auth record by the external auth record's email
|
||||
authRecord, err = e.App.FindAuthRecordByEmail(form.collection.Id, authUser.Email)
|
||||
if err != nil && !errors.Is(err, sql.ErrNoRows) {
|
||||
return e.InternalServerError("Failed OAuth2 auth record check.", err)
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------
|
||||
|
||||
event := new(core.RecordAuthWithOAuth2RequestEvent)
|
||||
event.RequestEvent = e
|
||||
event.Collection = collection
|
||||
event.ProviderName = form.Provider
|
||||
event.ProviderClient = provider
|
||||
event.OAuth2User = authUser
|
||||
event.CreateData = form.CreateData
|
||||
event.Record = authRecord
|
||||
event.IsNewRecord = authRecord == nil
|
||||
|
||||
return e.App.OnRecordAuthWithOAuth2Request().Trigger(event, func(e *core.RecordAuthWithOAuth2RequestEvent) error {
|
||||
if err := oauth2Submit(e, externalAuthRel); err != nil {
|
||||
return firstApiError(err, e.BadRequestError("Failed to authenticate.", err))
|
||||
}
|
||||
|
||||
// @todo revert back to struct after removing the custom auth.AuthUser marshalization
|
||||
meta := map[string]any{}
|
||||
rawOAuth2User, err := json.Marshal(e.OAuth2User)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = json.Unmarshal(rawOAuth2User, &meta)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
meta["isNew"] = e.IsNewRecord
|
||||
|
||||
return RecordAuthResponse(e.RequestEvent, e.Record, core.MFAMethodOAuth2, meta)
|
||||
})
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
|
||||
type recordOAuth2LoginForm struct {
|
||||
collection *core.Collection
|
||||
|
||||
// Additional data that will be used for creating a new auth record
|
||||
// if an existing OAuth2 account doesn't exist.
|
||||
CreateData map[string]any `form:"createData" json:"createData"`
|
||||
|
||||
// The name of the OAuth2 client provider (eg. "google")
|
||||
Provider string `form:"provider" json:"provider"`
|
||||
|
||||
// The authorization code returned from the initial request.
|
||||
Code string `form:"code" json:"code"`
|
||||
|
||||
// The optional PKCE code verifier as part of the code_challenge sent with the initial request.
|
||||
CodeVerifier string `form:"codeVerifier" json:"codeVerifier"`
|
||||
|
||||
// The redirect url sent with the initial request.
|
||||
RedirectURL string `form:"redirectURL" json:"redirectURL"`
|
||||
|
||||
// @todo
|
||||
// deprecated: use RedirectURL instead
|
||||
// RedirectUrl will be removed after dropping v0.22 support
|
||||
RedirectUrl string `form:"redirectUrl" json:"redirectUrl"`
|
||||
}
|
||||
|
||||
func (form *recordOAuth2LoginForm) validate() error {
|
||||
return validation.ValidateStruct(form,
|
||||
validation.Field(&form.Provider, validation.Required, validation.Length(0, 100), validation.By(form.checkProviderName)),
|
||||
validation.Field(&form.Code, validation.Required),
|
||||
validation.Field(&form.RedirectURL, validation.Required),
|
||||
)
|
||||
}
|
||||
|
||||
func (form *recordOAuth2LoginForm) checkProviderName(value any) error {
|
||||
name, _ := value.(string)
|
||||
|
||||
_, ok := form.collection.OAuth2.GetProviderConfig(name)
|
||||
if !ok {
|
||||
return validation.NewError("validation_invalid_provider", "Provider with name {{.name}} is missing or is not enabled.").
|
||||
SetParams(map[string]any{"name": name})
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func oldCanAssignUsername(txApp core.App, collection *core.Collection, username string) bool {
|
||||
// ensure that username is unique
|
||||
index, hasUniqueue := dbutils.FindSingleColumnUniqueIndex(collection.Indexes, collection.OAuth2.MappedFields.Username)
|
||||
if hasUniqueue {
|
||||
var expr dbx.Expression
|
||||
if strings.EqualFold(index.Columns[0].Collate, "nocase") {
|
||||
// case-insensitive search
|
||||
expr = dbx.NewExp("username = {:username} COLLATE NOCASE", dbx.Params{"username": username})
|
||||
} else {
|
||||
expr = dbx.HashExp{"username": username}
|
||||
}
|
||||
|
||||
var exists int
|
||||
_ = txApp.RecordQuery(collection).Select("(1)").AndWhere(expr).Limit(1).Row(&exists)
|
||||
if exists > 0 {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// ensure that the value matches the pattern of the username field (if text)
|
||||
txtField, _ := collection.Fields.GetByName(collection.OAuth2.MappedFields.Username).(*core.TextField)
|
||||
|
||||
return txtField != nil && txtField.ValidatePlainValue(username) == nil
|
||||
}
|
||||
|
||||
func oauth2Submit(e *core.RecordAuthWithOAuth2RequestEvent, optExternalAuth *core.ExternalAuth) error {
|
||||
return e.App.RunInTransaction(func(txApp core.App) error {
|
||||
if e.Record == nil {
|
||||
// extra check to prevent creating a superuser record via
|
||||
// OAuth2 in case the method is used by another action
|
||||
if e.Collection.Name == core.CollectionNameSuperusers {
|
||||
return errors.New("superusers are not allowed to sign-up with OAuth2")
|
||||
}
|
||||
|
||||
payload := maps.Clone(e.CreateData)
|
||||
if payload == nil {
|
||||
payload = map[string]any{}
|
||||
}
|
||||
|
||||
// assign the OAuth2 user email only if the user hasn't submitted one
|
||||
// (ignore empty/invalid values for consistency with the OAuth2->existing user update flow)
|
||||
if v, _ := payload[core.FieldNameEmail].(string); v == "" {
|
||||
payload[core.FieldNameEmail] = e.OAuth2User.Email
|
||||
}
|
||||
|
||||
// map known fields (unless the field was explicitly submitted as part of CreateData)
|
||||
if _, ok := payload[e.Collection.OAuth2.MappedFields.Id]; !ok && e.Collection.OAuth2.MappedFields.Id != "" {
|
||||
payload[e.Collection.OAuth2.MappedFields.Id] = e.OAuth2User.Id
|
||||
}
|
||||
if _, ok := payload[e.Collection.OAuth2.MappedFields.Name]; !ok && e.Collection.OAuth2.MappedFields.Name != "" {
|
||||
payload[e.Collection.OAuth2.MappedFields.Name] = e.OAuth2User.Name
|
||||
}
|
||||
if _, ok := payload[e.Collection.OAuth2.MappedFields.Username]; !ok &&
|
||||
// no explicit username payload value and existing OAuth2 mapping
|
||||
e.Collection.OAuth2.MappedFields.Username != "" &&
|
||||
// extra checks for backward compatibility with earlier versions
|
||||
oldCanAssignUsername(txApp, e.Collection, e.OAuth2User.Username) {
|
||||
payload[e.Collection.OAuth2.MappedFields.Username] = e.OAuth2User.Username
|
||||
}
|
||||
if _, ok := payload[e.Collection.OAuth2.MappedFields.AvatarURL]; !ok &&
|
||||
// no explicit avatar payload value and existing OAuth2 mapping
|
||||
e.Collection.OAuth2.MappedFields.AvatarURL != "" &&
|
||||
// non-empty OAuth2 avatar url
|
||||
e.OAuth2User.AvatarURL != "" {
|
||||
mappedField := e.Collection.Fields.GetByName(e.Collection.OAuth2.MappedFields.AvatarURL)
|
||||
if mappedField != nil && mappedField.Type() == core.FieldTypeFile {
|
||||
// download the avatar if the mapped field is a file
|
||||
avatarFile, err := func() (*filesystem.File, error) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
|
||||
defer cancel()
|
||||
return filesystem.NewFileFromURL(ctx, e.OAuth2User.AvatarURL)
|
||||
}()
|
||||
if err != nil {
|
||||
txApp.Logger().Warn("Failed to retrieve OAuth2 avatar", slog.String("error", err.Error()))
|
||||
} else {
|
||||
payload[e.Collection.OAuth2.MappedFields.AvatarURL] = avatarFile
|
||||
}
|
||||
} else {
|
||||
// otherwise - assign the url string
|
||||
payload[e.Collection.OAuth2.MappedFields.AvatarURL] = e.OAuth2User.AvatarURL
|
||||
}
|
||||
}
|
||||
|
||||
createdRecord, err := sendOAuth2RecordCreateRequest(txApp, e, payload)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
e.Record = createdRecord
|
||||
|
||||
if e.Record.Email() == e.OAuth2User.Email && !e.Record.Verified() {
|
||||
// mark as verified as long as it matches the OAuth2 data (even if the email is empty)
|
||||
e.Record.SetVerified(true)
|
||||
if err := txApp.Save(e.Record); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
} else {
|
||||
var needUpdate bool
|
||||
|
||||
isLoggedAuthRecord := e.Auth != nil &&
|
||||
e.Auth.Id == e.Record.Id &&
|
||||
e.Auth.Collection().Id == e.Record.Collection().Id
|
||||
|
||||
// set random password for users with unverified email
|
||||
// (this is in case a malicious actor has registered previously with the user email)
|
||||
if !isLoggedAuthRecord && e.Record.Email() != "" && !e.Record.Verified() {
|
||||
e.Record.SetRandomPassword()
|
||||
needUpdate = true
|
||||
}
|
||||
|
||||
// update the existing auth record empty email if the data.OAuth2User has one
|
||||
// (this is in case previously the auth record was created
|
||||
// with an OAuth2 provider that didn't return an email address)
|
||||
if e.Record.Email() == "" && e.OAuth2User.Email != "" {
|
||||
e.Record.SetEmail(e.OAuth2User.Email)
|
||||
needUpdate = true
|
||||
}
|
||||
|
||||
// update the existing auth record verified state
|
||||
// (only if the auth record doesn't have an email or the auth record email match with the one in data.OAuth2User)
|
||||
if !e.Record.Verified() && (e.Record.Email() == "" || e.Record.Email() == e.OAuth2User.Email) {
|
||||
e.Record.SetVerified(true)
|
||||
needUpdate = true
|
||||
}
|
||||
|
||||
if needUpdate {
|
||||
if err := txApp.Save(e.Record); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// create ExternalAuth relation if missing
|
||||
if optExternalAuth == nil {
|
||||
optExternalAuth = core.NewExternalAuth(txApp)
|
||||
optExternalAuth.SetCollectionRef(e.Record.Collection().Id)
|
||||
optExternalAuth.SetRecordRef(e.Record.Id)
|
||||
optExternalAuth.SetProvider(e.ProviderName)
|
||||
optExternalAuth.SetProviderId(e.OAuth2User.Id)
|
||||
|
||||
if err := txApp.Save(optExternalAuth); err != nil {
|
||||
return fmt.Errorf("failed to save linked rel: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func sendOAuth2RecordCreateRequest(txApp core.App, e *core.RecordAuthWithOAuth2RequestEvent, payload map[string]any) (*core.Record, error) {
|
||||
ir := &core.InternalRequest{
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/" + e.Collection.Name + "/records",
|
||||
Body: payload,
|
||||
}
|
||||
|
||||
var createdRecord *core.Record
|
||||
response, err := processInternalRequest(txApp, e.RequestEvent, ir, core.RequestInfoContextOAuth2, func(data any) error {
|
||||
createdRecord, _ = data.(*core.Record)
|
||||
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if response.Status != http.StatusOK || createdRecord == nil {
|
||||
return nil, errors.New("failed to create OAuth2 auth record")
|
||||
}
|
||||
|
||||
return createdRecord, nil
|
||||
}
|
74
apis/record_auth_with_oauth2_redirect.go
Normal file
74
apis/record_auth_with_oauth2_redirect.go
Normal file
|
@ -0,0 +1,74 @@
|
|||
package apis
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tools/subscriptions"
|
||||
)
|
||||
|
||||
const (
|
||||
oauth2SubscriptionTopic string = "@oauth2"
|
||||
oauth2RedirectFailurePath string = "../_/#/auth/oauth2-redirect-failure"
|
||||
oauth2RedirectSuccessPath string = "../_/#/auth/oauth2-redirect-success"
|
||||
)
|
||||
|
||||
type oauth2RedirectData struct {
|
||||
State string `form:"state" json:"state"`
|
||||
Code string `form:"code" json:"code"`
|
||||
Error string `form:"error" json:"error,omitempty"`
|
||||
}
|
||||
|
||||
func oauth2SubscriptionRedirect(e *core.RequestEvent) error {
|
||||
redirectStatusCode := http.StatusTemporaryRedirect
|
||||
if e.Request.Method != http.MethodGet {
|
||||
redirectStatusCode = http.StatusSeeOther
|
||||
}
|
||||
|
||||
data := oauth2RedirectData{}
|
||||
|
||||
if e.Request.Method == http.MethodPost {
|
||||
if err := e.BindBody(&data); err != nil {
|
||||
e.App.Logger().Debug("Failed to read OAuth2 redirect data", "error", err)
|
||||
return e.Redirect(redirectStatusCode, oauth2RedirectFailurePath)
|
||||
}
|
||||
} else {
|
||||
query := e.Request.URL.Query()
|
||||
data.State = query.Get("state")
|
||||
data.Code = query.Get("code")
|
||||
data.Error = query.Get("error")
|
||||
}
|
||||
|
||||
if data.State == "" {
|
||||
e.App.Logger().Debug("Missing OAuth2 state parameter")
|
||||
return e.Redirect(redirectStatusCode, oauth2RedirectFailurePath)
|
||||
}
|
||||
|
||||
client, err := e.App.SubscriptionsBroker().ClientById(data.State)
|
||||
if err != nil || client.IsDiscarded() || !client.HasSubscription(oauth2SubscriptionTopic) {
|
||||
e.App.Logger().Debug("Missing or invalid OAuth2 subscription client", "error", err, "clientId", data.State)
|
||||
return e.Redirect(redirectStatusCode, oauth2RedirectFailurePath)
|
||||
}
|
||||
defer client.Unsubscribe(oauth2SubscriptionTopic)
|
||||
|
||||
encodedData, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
e.App.Logger().Debug("Failed to marshalize OAuth2 redirect data", "error", err)
|
||||
return e.Redirect(redirectStatusCode, oauth2RedirectFailurePath)
|
||||
}
|
||||
|
||||
msg := subscriptions.Message{
|
||||
Name: oauth2SubscriptionTopic,
|
||||
Data: encodedData,
|
||||
}
|
||||
|
||||
client.Send(msg)
|
||||
|
||||
if data.Error != "" || data.Code == "" {
|
||||
e.App.Logger().Debug("Failed OAuth2 redirect due to an error or missing code parameter", "error", data.Error, "clientId", data.State)
|
||||
return e.Redirect(redirectStatusCode, oauth2RedirectFailurePath)
|
||||
}
|
||||
|
||||
return e.Redirect(redirectStatusCode, oauth2RedirectSuccessPath)
|
||||
}
|
274
apis/record_auth_with_oauth2_redirect_test.go
Normal file
274
apis/record_auth_with_oauth2_redirect_test.go
Normal file
|
@ -0,0 +1,274 @@
|
|||
package apis_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tests"
|
||||
"github.com/pocketbase/pocketbase/tools/subscriptions"
|
||||
)
|
||||
|
||||
func TestRecordAuthWithOAuth2Redirect(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
clientStubs := make([]map[string]subscriptions.Client, 0, 10)
|
||||
|
||||
for i := 0; i < 10; i++ {
|
||||
c1 := subscriptions.NewDefaultClient()
|
||||
|
||||
c2 := subscriptions.NewDefaultClient()
|
||||
c2.Subscribe("@oauth2")
|
||||
|
||||
c3 := subscriptions.NewDefaultClient()
|
||||
c3.Subscribe("test1", "@oauth2")
|
||||
|
||||
c4 := subscriptions.NewDefaultClient()
|
||||
c4.Subscribe("test1", "test2")
|
||||
|
||||
c5 := subscriptions.NewDefaultClient()
|
||||
c5.Subscribe("@oauth2")
|
||||
c5.Discard()
|
||||
|
||||
clientStubs = append(clientStubs, map[string]subscriptions.Client{
|
||||
"c1": c1,
|
||||
"c2": c2,
|
||||
"c3": c3,
|
||||
"c4": c4,
|
||||
"c5": c5,
|
||||
})
|
||||
}
|
||||
|
||||
checkFailureRedirect := func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
loc := res.Header.Get("Location")
|
||||
if !strings.Contains(loc, "/oauth2-redirect-failure") {
|
||||
t.Fatalf("Expected failure redirect, got %q", loc)
|
||||
}
|
||||
}
|
||||
|
||||
checkSuccessRedirect := func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
loc := res.Header.Get("Location")
|
||||
if !strings.Contains(loc, "/oauth2-redirect-success") {
|
||||
t.Fatalf("Expected success redirect, got %q", loc)
|
||||
}
|
||||
}
|
||||
|
||||
// note: don't exit because it is usually called as part of a separate goroutine
|
||||
checkClientMessages := func(t testing.TB, clientId string, msg subscriptions.Message, expectedMessages map[string][]string) {
|
||||
if len(expectedMessages[clientId]) == 0 {
|
||||
t.Errorf("Unexpected client %q message, got %q:\n%q", clientId, msg.Name, msg.Data)
|
||||
return
|
||||
}
|
||||
|
||||
if msg.Name != "@oauth2" {
|
||||
t.Errorf("Expected @oauth2 msg.Name, got %q", msg.Name)
|
||||
return
|
||||
}
|
||||
|
||||
for _, txt := range expectedMessages[clientId] {
|
||||
if !strings.Contains(string(msg.Data), txt) {
|
||||
t.Errorf("Failed to find %q in \n%s", txt, msg.Data)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
beforeTestFunc := func(
|
||||
clients map[string]subscriptions.Client,
|
||||
expectedMessages map[string][]string,
|
||||
) func(testing.TB, *tests.TestApp, *core.ServeEvent) {
|
||||
return func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
for _, client := range clients {
|
||||
app.SubscriptionsBroker().Register(client)
|
||||
}
|
||||
|
||||
ctx, cancelFunc := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
|
||||
// add to the app store so that it can be cancelled manually after test completion
|
||||
app.Store().Set("cancelFunc", cancelFunc)
|
||||
|
||||
go func() {
|
||||
defer cancelFunc()
|
||||
|
||||
for {
|
||||
select {
|
||||
case msg, ok := <-clients["c1"].Channel():
|
||||
if ok {
|
||||
checkClientMessages(t, "c1", msg, expectedMessages)
|
||||
} else {
|
||||
t.Errorf("Unexpected c1 closed channel")
|
||||
}
|
||||
case msg, ok := <-clients["c2"].Channel():
|
||||
if ok {
|
||||
checkClientMessages(t, "c2", msg, expectedMessages)
|
||||
} else {
|
||||
t.Errorf("Unexpected c2 closed channel")
|
||||
}
|
||||
case msg, ok := <-clients["c3"].Channel():
|
||||
if ok {
|
||||
checkClientMessages(t, "c3", msg, expectedMessages)
|
||||
} else {
|
||||
t.Errorf("Unexpected c3 closed channel")
|
||||
}
|
||||
case msg, ok := <-clients["c4"].Channel():
|
||||
if ok {
|
||||
checkClientMessages(t, "c4", msg, expectedMessages)
|
||||
} else {
|
||||
t.Errorf("Unexpected c4 closed channel")
|
||||
}
|
||||
case _, ok := <-clients["c5"].Channel():
|
||||
if ok {
|
||||
t.Errorf("Expected c5 channel to be closed")
|
||||
}
|
||||
case <-ctx.Done():
|
||||
for _, c := range clients {
|
||||
c.Discard()
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "no state query param",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/oauth2-redirect?code=123",
|
||||
BeforeTestFunc: beforeTestFunc(clientStubs[0], nil),
|
||||
ExpectedStatus: http.StatusTemporaryRedirect,
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
app.Store().Get("cancelFunc").(context.CancelFunc)()
|
||||
|
||||
checkFailureRedirect(t, app, res)
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "invalid or missing client",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/oauth2-redirect?code=123&state=missing",
|
||||
BeforeTestFunc: beforeTestFunc(clientStubs[1], nil),
|
||||
ExpectedStatus: http.StatusTemporaryRedirect,
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
app.Store().Get("cancelFunc").(context.CancelFunc)()
|
||||
|
||||
checkFailureRedirect(t, app, res)
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "no code query param",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/oauth2-redirect?state=" + clientStubs[2]["c3"].Id(),
|
||||
BeforeTestFunc: beforeTestFunc(clientStubs[2], map[string][]string{
|
||||
"c3": {`"state":"` + clientStubs[2]["c3"].Id(), `"code":""`},
|
||||
}),
|
||||
ExpectedStatus: http.StatusTemporaryRedirect,
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
app.Store().Get("cancelFunc").(context.CancelFunc)()
|
||||
|
||||
checkFailureRedirect(t, app, res)
|
||||
|
||||
if clientStubs[2]["c3"].HasSubscription("@oauth2") {
|
||||
t.Fatalf("Expected oauth2 subscription to be removed")
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "error query param",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/oauth2-redirect?error=example&code=123&state=" + clientStubs[3]["c3"].Id(),
|
||||
BeforeTestFunc: beforeTestFunc(clientStubs[3], map[string][]string{
|
||||
"c3": {`"state":"` + clientStubs[3]["c3"].Id(), `"code":"123"`, `"error":"example"`},
|
||||
}),
|
||||
ExpectedStatus: http.StatusTemporaryRedirect,
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
app.Store().Get("cancelFunc").(context.CancelFunc)()
|
||||
|
||||
checkFailureRedirect(t, app, res)
|
||||
|
||||
if clientStubs[3]["c3"].HasSubscription("@oauth2") {
|
||||
t.Fatalf("Expected oauth2 subscription to be removed")
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "discarded client with @oauth2 subscription",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/oauth2-redirect?code=123&state=" + clientStubs[4]["c5"].Id(),
|
||||
BeforeTestFunc: beforeTestFunc(clientStubs[4], nil),
|
||||
ExpectedStatus: http.StatusTemporaryRedirect,
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
app.Store().Get("cancelFunc").(context.CancelFunc)()
|
||||
|
||||
checkFailureRedirect(t, app, res)
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "client without @oauth2 subscription",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/oauth2-redirect?code=123&state=" + clientStubs[4]["c4"].Id(),
|
||||
BeforeTestFunc: beforeTestFunc(clientStubs[5], nil),
|
||||
ExpectedStatus: http.StatusTemporaryRedirect,
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
app.Store().Get("cancelFunc").(context.CancelFunc)()
|
||||
|
||||
checkFailureRedirect(t, app, res)
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "client with @oauth2 subscription",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/oauth2-redirect?code=123&state=" + clientStubs[6]["c3"].Id(),
|
||||
BeforeTestFunc: beforeTestFunc(clientStubs[6], map[string][]string{
|
||||
"c3": {`"state":"` + clientStubs[6]["c3"].Id(), `"code":"123"`},
|
||||
}),
|
||||
ExpectedStatus: http.StatusTemporaryRedirect,
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
app.Store().Get("cancelFunc").(context.CancelFunc)()
|
||||
|
||||
checkSuccessRedirect(t, app, res)
|
||||
|
||||
if clientStubs[6]["c3"].HasSubscription("@oauth2") {
|
||||
t.Fatalf("Expected oauth2 subscription to be removed")
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "(POST) client with @oauth2 subscription",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/oauth2-redirect",
|
||||
Body: strings.NewReader("code=123&state=" + clientStubs[7]["c3"].Id()),
|
||||
Headers: map[string]string{
|
||||
"content-type": "application/x-www-form-urlencoded",
|
||||
},
|
||||
BeforeTestFunc: beforeTestFunc(clientStubs[7], map[string][]string{
|
||||
"c3": {`"state":"` + clientStubs[7]["c3"].Id(), `"code":"123"`},
|
||||
}),
|
||||
ExpectedStatus: http.StatusSeeOther,
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
app.Store().Get("cancelFunc").(context.CancelFunc)()
|
||||
|
||||
checkSuccessRedirect(t, app, res)
|
||||
|
||||
if clientStubs[7]["c3"].HasSubscription("@oauth2") {
|
||||
t.Fatalf("Expected oauth2 subscription to be removed")
|
||||
}
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
1715
apis/record_auth_with_oauth2_test.go
Normal file
1715
apis/record_auth_with_oauth2_test.go
Normal file
File diff suppressed because it is too large
Load diff
106
apis/record_auth_with_otp.go
Normal file
106
apis/record_auth_with_otp.go
Normal file
|
@ -0,0 +1,106 @@
|
|||
package apis
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
validation "github.com/go-ozzo/ozzo-validation/v4"
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
)
|
||||
|
||||
func recordAuthWithOTP(e *core.RequestEvent) error {
|
||||
collection, err := findAuthCollection(e)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !collection.OTP.Enabled {
|
||||
return e.ForbiddenError("The collection is not configured to allow OTP authentication.", nil)
|
||||
}
|
||||
|
||||
form := &authWithOTPForm{}
|
||||
if err = e.BindBody(form); err != nil {
|
||||
return firstApiError(err, e.BadRequestError("An error occurred while loading the submitted data.", err))
|
||||
}
|
||||
if err = form.validate(); err != nil {
|
||||
return firstApiError(err, e.BadRequestError("An error occurred while validating the submitted data.", err))
|
||||
}
|
||||
|
||||
e.Set(core.RequestEventKeyInfoContext, core.RequestInfoContextOTP)
|
||||
|
||||
event := new(core.RecordAuthWithOTPRequestEvent)
|
||||
event.RequestEvent = e
|
||||
event.Collection = collection
|
||||
|
||||
// extra validations
|
||||
// (note: returns a generic 400 as a very basic OTPs enumeration protection)
|
||||
// ---
|
||||
event.OTP, err = e.App.FindOTPById(form.OTPId)
|
||||
if err != nil {
|
||||
return e.BadRequestError("Invalid or expired OTP", err)
|
||||
}
|
||||
|
||||
if event.OTP.CollectionRef() != collection.Id {
|
||||
return e.BadRequestError("Invalid or expired OTP", errors.New("the OTP is for a different collection"))
|
||||
}
|
||||
|
||||
if event.OTP.HasExpired(collection.OTP.DurationTime()) {
|
||||
return e.BadRequestError("Invalid or expired OTP", errors.New("the OTP is expired"))
|
||||
}
|
||||
|
||||
event.Record, err = e.App.FindRecordById(event.OTP.CollectionRef(), event.OTP.RecordRef())
|
||||
if err != nil {
|
||||
return e.BadRequestError("Invalid or expired OTP", fmt.Errorf("missing auth record: %w", err))
|
||||
}
|
||||
|
||||
// since otps are usually simple digit numbers, enforce an extra rate limit rule as basic enumaration protection
|
||||
err = checkRateLimit(e, "@pb_otp_"+event.Record.Id, core.RateLimitRule{MaxRequests: 5, Duration: 180})
|
||||
if err != nil {
|
||||
return e.TooManyRequestsError("Too many attempts, please try again later with a new OTP.", nil)
|
||||
}
|
||||
|
||||
if !event.OTP.ValidatePassword(form.Password) {
|
||||
return e.BadRequestError("Invalid or expired OTP", errors.New("incorrect password"))
|
||||
}
|
||||
// ---
|
||||
|
||||
return e.App.OnRecordAuthWithOTPRequest().Trigger(event, func(e *core.RecordAuthWithOTPRequestEvent) error {
|
||||
// update the user email verified state in case the OTP originate from an email address matching the current record one
|
||||
//
|
||||
// note: don't wait for success auth response (it could fail because of MFA) and because we already validated the OTP above
|
||||
otpSentTo := e.OTP.SentTo()
|
||||
if !e.Record.Verified() && otpSentTo != "" && e.Record.Email() == otpSentTo {
|
||||
e.Record.SetVerified(true)
|
||||
err = e.App.Save(e.Record)
|
||||
if err != nil {
|
||||
e.App.Logger().Error("Failed to update record verified state after successful OTP validation",
|
||||
"error", err,
|
||||
"otpId", e.OTP.Id,
|
||||
"recordId", e.Record.Id,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// try to delete the used otp
|
||||
err = e.App.Delete(e.OTP)
|
||||
if err != nil {
|
||||
e.App.Logger().Error("Failed to delete used OTP", "error", err, "otpId", e.OTP.Id)
|
||||
}
|
||||
|
||||
return RecordAuthResponse(e.RequestEvent, e.Record, core.MFAMethodOTP, nil)
|
||||
})
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
|
||||
type authWithOTPForm struct {
|
||||
OTPId string `form:"otpId" json:"otpId"`
|
||||
Password string `form:"password" json:"password"`
|
||||
}
|
||||
|
||||
func (form *authWithOTPForm) validate() error {
|
||||
return validation.ValidateStruct(form,
|
||||
validation.Field(&form.OTPId, validation.Required, validation.Length(1, 255)),
|
||||
validation.Field(&form.Password, validation.Required, validation.Length(1, 71)),
|
||||
)
|
||||
}
|
608
apis/record_auth_with_otp_test.go
Normal file
608
apis/record_auth_with_otp_test.go
Normal file
|
@ -0,0 +1,608 @@
|
|||
package apis_test
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tests"
|
||||
"github.com/pocketbase/pocketbase/tools/types"
|
||||
)
|
||||
|
||||
func TestRecordAuthWithOTP(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "not an auth collection",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/demo1/auth-with-otp",
|
||||
Body: strings.NewReader(`{"otpId":"test","password":"123456"}`),
|
||||
ExpectedStatus: 404,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "auth collection with disabled otp",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/auth-with-otp",
|
||||
Body: strings.NewReader(`{"otpId":"test","password":"123456"}`),
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
usersCol, err := app.FindCollectionByNameOrId("users")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
usersCol.OTP.Enabled = false
|
||||
|
||||
if err := app.Save(usersCol); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "invalid body",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/auth-with-otp",
|
||||
Body: strings.NewReader(`{"email`),
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "empty body",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/auth-with-otp",
|
||||
Body: strings.NewReader(``),
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{
|
||||
`"data":{`,
|
||||
`"otpId":{"code":"validation_required"`,
|
||||
`"password":{"code":"validation_required"`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "invalid request data",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/auth-with-otp",
|
||||
Body: strings.NewReader(`{
|
||||
"otpId":"` + strings.Repeat("a", 256) + `",
|
||||
"password":"` + strings.Repeat("a", 72) + `"
|
||||
}`),
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{
|
||||
`"data":{`,
|
||||
`"otpId":{"code":"validation_length_out_of_range"`,
|
||||
`"password":{"code":"validation_length_out_of_range"`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "missing otp",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/auth-with-otp",
|
||||
Body: strings.NewReader(`{
|
||||
"otpId":"missing",
|
||||
"password":"123456"
|
||||
}`),
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
user, err := app.FindAuthRecordByEmail("users", "test@example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
otp := core.NewOTP(app)
|
||||
otp.Id = strings.Repeat("a", 15)
|
||||
otp.SetCollectionRef(user.Collection().Id)
|
||||
otp.SetRecordRef(user.Id)
|
||||
otp.SetPassword("123456")
|
||||
if err := app.Save(otp); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "otp for different collection",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/auth-with-otp",
|
||||
Body: strings.NewReader(`{
|
||||
"otpId":"` + strings.Repeat("a", 15) + `",
|
||||
"password":"123456"
|
||||
}`),
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
client, err := app.FindAuthRecordByEmail("clients", "test@example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
otp := core.NewOTP(app)
|
||||
otp.Id = strings.Repeat("a", 15)
|
||||
otp.SetCollectionRef(client.Collection().Id)
|
||||
otp.SetRecordRef(client.Id)
|
||||
otp.SetPassword("123456")
|
||||
if err := app.Save(otp); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "otp with wrong password",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/auth-with-otp",
|
||||
Body: strings.NewReader(`{
|
||||
"otpId":"` + strings.Repeat("a", 15) + `",
|
||||
"password":"123456"
|
||||
}`),
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
user, err := app.FindAuthRecordByEmail("users", "test@example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
otp := core.NewOTP(app)
|
||||
otp.Id = strings.Repeat("a", 15)
|
||||
otp.SetCollectionRef(user.Collection().Id)
|
||||
otp.SetRecordRef(user.Id)
|
||||
otp.SetPassword("1234567890")
|
||||
if err := app.Save(otp); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "expired otp with valid password",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/auth-with-otp",
|
||||
Body: strings.NewReader(`{
|
||||
"otpId":"` + strings.Repeat("a", 15) + `",
|
||||
"password":"123456"
|
||||
}`),
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
user, err := app.FindAuthRecordByEmail("users", "test@example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
otp := core.NewOTP(app)
|
||||
otp.Id = strings.Repeat("a", 15)
|
||||
otp.SetCollectionRef(user.Collection().Id)
|
||||
otp.SetRecordRef(user.Id)
|
||||
otp.SetPassword("123456")
|
||||
expiredDate := types.NowDateTime().AddDate(-3, 0, 0)
|
||||
otp.SetRaw("created", expiredDate)
|
||||
otp.SetRaw("updated", expiredDate)
|
||||
if err := app.Save(otp); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "valid otp with valid password (enabled MFA)",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/auth-with-otp",
|
||||
Body: strings.NewReader(`{
|
||||
"otpId":"` + strings.Repeat("a", 15) + `",
|
||||
"password":"123456"
|
||||
}`),
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
user, err := app.FindAuthRecordByEmail("users", "test@example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
otp := core.NewOTP(app)
|
||||
otp.Id = strings.Repeat("a", 15)
|
||||
otp.SetCollectionRef(user.Collection().Id)
|
||||
otp.SetRecordRef(user.Id)
|
||||
otp.SetPassword("123456")
|
||||
if err := app.Save(otp); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 401,
|
||||
ExpectedContent: []string{`"mfaId":"`},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordAuthWithOTPRequest": 1,
|
||||
"OnRecordAuthRequest": 1,
|
||||
// ---
|
||||
"OnModelValidate": 1,
|
||||
"OnModelCreate": 1, // mfa record
|
||||
"OnModelCreateExecute": 1,
|
||||
"OnModelAfterCreateSuccess": 1,
|
||||
"OnModelDelete": 1, // otp delete
|
||||
"OnModelDeleteExecute": 1,
|
||||
"OnModelAfterDeleteSuccess": 1,
|
||||
// ---
|
||||
"OnRecordValidate": 1,
|
||||
"OnRecordCreate": 1,
|
||||
"OnRecordCreateExecute": 1,
|
||||
"OnRecordAfterCreateSuccess": 1,
|
||||
"OnRecordDelete": 1,
|
||||
"OnRecordDeleteExecute": 1,
|
||||
"OnRecordAfterDeleteSuccess": 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "valid otp with valid password and empty sentTo (disabled MFA)",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/auth-with-otp",
|
||||
Body: strings.NewReader(`{
|
||||
"otpId":"` + strings.Repeat("a", 15) + `",
|
||||
"password":"123456"
|
||||
}`),
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
user, err := app.FindAuthRecordByEmail("users", "test@example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// ensure that the user is unverified
|
||||
user.SetVerified(false)
|
||||
if err = app.Save(user); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// disable MFA
|
||||
user.Collection().MFA.Enabled = false
|
||||
if err = app.Save(user.Collection()); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
otp := core.NewOTP(app)
|
||||
otp.Id = strings.Repeat("a", 15)
|
||||
otp.SetCollectionRef(user.Collection().Id)
|
||||
otp.SetRecordRef(user.Id)
|
||||
otp.SetPassword("123456")
|
||||
if err := app.Save(otp); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// test at least once that the correct request info context is properly loaded
|
||||
app.OnRecordAuthRequest().BindFunc(func(e *core.RecordAuthRequestEvent) error {
|
||||
info, err := e.RequestInfo()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if info.Context != core.RequestInfoContextOTP {
|
||||
t.Fatalf("Expected request context %q, got %q", core.RequestInfoContextOTP, info.Context)
|
||||
}
|
||||
|
||||
return e.Next()
|
||||
})
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
`"token":"`,
|
||||
`"record":{`,
|
||||
`"email":"test@example.com"`,
|
||||
},
|
||||
NotExpectedContent: []string{
|
||||
`"meta":`,
|
||||
// hidden fields
|
||||
`"tokenKey"`,
|
||||
`"password"`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordAuthWithOTPRequest": 1,
|
||||
"OnRecordAuthRequest": 1,
|
||||
"OnRecordEnrich": 1,
|
||||
// ---
|
||||
"OnModelValidate": 1,
|
||||
"OnModelCreate": 1, // authOrigin
|
||||
"OnModelCreateExecute": 1,
|
||||
"OnModelAfterCreateSuccess": 1,
|
||||
"OnModelDelete": 1, // otp delete
|
||||
"OnModelDeleteExecute": 1,
|
||||
"OnModelAfterDeleteSuccess": 1,
|
||||
// ---
|
||||
"OnRecordValidate": 1,
|
||||
"OnRecordCreate": 1,
|
||||
"OnRecordCreateExecute": 1,
|
||||
"OnRecordAfterCreateSuccess": 1,
|
||||
"OnRecordDelete": 1,
|
||||
"OnRecordDeleteExecute": 1,
|
||||
"OnRecordAfterDeleteSuccess": 1,
|
||||
},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
user, err := app.FindAuthRecordByEmail("users", "test@example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if user.Verified() {
|
||||
t.Fatal("Expected the user to remain unverified because sentTo != email")
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "valid otp with valid password and nonempty sentTo=email (disabled MFA)",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/auth-with-otp",
|
||||
Body: strings.NewReader(`{
|
||||
"otpId":"` + strings.Repeat("a", 15) + `",
|
||||
"password":"123456"
|
||||
}`),
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
user, err := app.FindAuthRecordByEmail("users", "test@example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// ensure that the user is unverified
|
||||
user.SetVerified(false)
|
||||
if err = app.Save(user); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// disable MFA
|
||||
user.Collection().MFA.Enabled = false
|
||||
if err = app.Save(user.Collection()); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
otp := core.NewOTP(app)
|
||||
otp.Id = strings.Repeat("a", 15)
|
||||
otp.SetCollectionRef(user.Collection().Id)
|
||||
otp.SetRecordRef(user.Id)
|
||||
otp.SetPassword("123456")
|
||||
otp.SetSentTo(user.Email())
|
||||
if err := app.Save(otp); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
`"token":"`,
|
||||
`"record":{`,
|
||||
`"email":"test@example.com"`,
|
||||
},
|
||||
NotExpectedContent: []string{
|
||||
`"meta":`,
|
||||
// hidden fields
|
||||
`"tokenKey"`,
|
||||
`"password"`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordAuthWithOTPRequest": 1,
|
||||
"OnRecordAuthRequest": 1,
|
||||
"OnRecordEnrich": 1,
|
||||
// ---
|
||||
"OnModelValidate": 2, // +1 because of the verified user update
|
||||
// authOrigin create
|
||||
"OnModelCreate": 1,
|
||||
"OnModelCreateExecute": 1,
|
||||
"OnModelAfterCreateSuccess": 1,
|
||||
// OTP delete
|
||||
"OnModelDelete": 1,
|
||||
"OnModelDeleteExecute": 1,
|
||||
"OnModelAfterDeleteSuccess": 1,
|
||||
// user verified update
|
||||
"OnModelUpdate": 1,
|
||||
"OnModelUpdateExecute": 1,
|
||||
"OnModelAfterUpdateSuccess": 1,
|
||||
// ---
|
||||
"OnRecordValidate": 2,
|
||||
"OnRecordCreate": 1,
|
||||
"OnRecordCreateExecute": 1,
|
||||
"OnRecordAfterCreateSuccess": 1,
|
||||
"OnRecordDelete": 1,
|
||||
"OnRecordDeleteExecute": 1,
|
||||
"OnRecordAfterDeleteSuccess": 1,
|
||||
"OnRecordUpdate": 1,
|
||||
"OnRecordUpdateExecute": 1,
|
||||
"OnRecordAfterUpdateSuccess": 1,
|
||||
},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
user, err := app.FindAuthRecordByEmail("users", "test@example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if !user.Verified() {
|
||||
t.Fatal("Expected the user to be marked as verified")
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "OnRecordAuthWithOTPRequest tx body write check",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/auth-with-otp",
|
||||
Body: strings.NewReader(`{
|
||||
"otpId":"` + strings.Repeat("a", 15) + `",
|
||||
"password":"123456"
|
||||
}`),
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
user, err := app.FindAuthRecordByEmail("users", "test@example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// disable MFA
|
||||
user.Collection().MFA.Enabled = false
|
||||
if err = app.Save(user.Collection()); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
otp := core.NewOTP(app)
|
||||
otp.Id = strings.Repeat("a", 15)
|
||||
otp.SetCollectionRef(user.Collection().Id)
|
||||
otp.SetRecordRef(user.Id)
|
||||
otp.SetPassword("123456")
|
||||
if err := app.Save(otp); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
app.OnRecordAuthWithOTPRequest().BindFunc(func(e *core.RecordAuthWithOTPRequestEvent) error {
|
||||
original := e.App
|
||||
return e.App.RunInTransaction(func(txApp core.App) error {
|
||||
e.App = txApp
|
||||
defer func() { e.App = original }()
|
||||
|
||||
if err := e.Next(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return e.BadRequestError("TX_ERROR", nil)
|
||||
})
|
||||
})
|
||||
},
|
||||
ExpectedStatus: 400,
|
||||
ExpectedEvents: map[string]int{"OnRecordAuthWithOTPRequest": 1},
|
||||
ExpectedContent: []string{"TX_ERROR"},
|
||||
},
|
||||
|
||||
// rate limit checks
|
||||
// -----------------------------------------------------------
|
||||
{
|
||||
Name: "RateLimit rule - users:authWithOTP",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/auth-with-otp",
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
app.Settings().RateLimits.Enabled = true
|
||||
app.Settings().RateLimits.Rules = []core.RateLimitRule{
|
||||
{MaxRequests: 100, Label: "abc"},
|
||||
{MaxRequests: 100, Label: "*:authWithOTP"},
|
||||
{MaxRequests: 100, Label: "users:auth"},
|
||||
{MaxRequests: 0, Label: "users:authWithOTP"},
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 429,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "RateLimit rule - *:authWithOTP",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/auth-with-otp",
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
app.Settings().RateLimits.Enabled = true
|
||||
app.Settings().RateLimits.Rules = []core.RateLimitRule{
|
||||
{MaxRequests: 100, Label: "abc"},
|
||||
{MaxRequests: 100, Label: "*:auth"},
|
||||
{MaxRequests: 0, Label: "*:authWithOTP"},
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 429,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "RateLimit rule - users:auth",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/auth-with-otp",
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
app.Settings().RateLimits.Enabled = true
|
||||
app.Settings().RateLimits.Rules = []core.RateLimitRule{
|
||||
{MaxRequests: 100, Label: "abc"},
|
||||
{MaxRequests: 100, Label: "*:authWithOTP"},
|
||||
{MaxRequests: 0, Label: "users:auth"},
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 429,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "RateLimit rule - *:auth",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/auth-with-otp",
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
app.Settings().RateLimits.Enabled = true
|
||||
app.Settings().RateLimits.Rules = []core.RateLimitRule{
|
||||
{MaxRequests: 100, Label: "abc"},
|
||||
{MaxRequests: 0, Label: "*:auth"},
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 429,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecordAuthWithOTPManualRateLimiterCheck(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var storeCache map[string]any
|
||||
|
||||
otpAId := strings.Repeat("a", 15)
|
||||
otpBId := strings.Repeat("b", 15)
|
||||
|
||||
scenarios := []struct {
|
||||
otpId string
|
||||
password string
|
||||
expectedStatus int
|
||||
}{
|
||||
{otpAId, "12345", 400},
|
||||
{otpAId, "12345", 400},
|
||||
{otpBId, "12345", 400},
|
||||
{otpBId, "12345", 400},
|
||||
{otpBId, "12345", 400},
|
||||
{otpAId, "12345", 429},
|
||||
{otpAId, "123456", 429}, // reject even if it is correct
|
||||
{otpAId, "123456", 429},
|
||||
{otpBId, "123456", 429},
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
(&tests.ApiScenario{
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/auth-with-otp",
|
||||
Body: strings.NewReader(`{
|
||||
"otpId":"` + s.otpId + `",
|
||||
"password":"` + s.password + `"
|
||||
}`),
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
for k, v := range storeCache {
|
||||
app.Store().Set(k, v)
|
||||
}
|
||||
|
||||
user, err := app.FindAuthRecordByEmail("users", "test@example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
user.Collection().MFA.Enabled = false
|
||||
if err := app.Save(user.Collection()); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
for _, id := range []string{otpAId, otpBId} {
|
||||
otp := core.NewOTP(app)
|
||||
otp.Id = id
|
||||
otp.SetCollectionRef(user.Collection().Id)
|
||||
otp.SetRecordRef(user.Id)
|
||||
otp.SetPassword("123456")
|
||||
if err := app.Save(otp); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
},
|
||||
ExpectedStatus: s.expectedStatus,
|
||||
ExpectedContent: []string{`"`}, // it doesn't matter anything non-empty
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
storeCache = app.Store().GetAll()
|
||||
},
|
||||
}).Test(t)
|
||||
}
|
||||
}
|
135
apis/record_auth_with_password.go
Normal file
135
apis/record_auth_with_password.go
Normal file
|
@ -0,0 +1,135 @@
|
|||
package apis
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
validation "github.com/go-ozzo/ozzo-validation/v4"
|
||||
"github.com/go-ozzo/ozzo-validation/v4/is"
|
||||
"github.com/pocketbase/dbx"
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tools/dbutils"
|
||||
"github.com/pocketbase/pocketbase/tools/list"
|
||||
)
|
||||
|
||||
func recordAuthWithPassword(e *core.RequestEvent) error {
|
||||
collection, err := findAuthCollection(e)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !collection.PasswordAuth.Enabled {
|
||||
return e.ForbiddenError("The collection is not configured to allow password authentication.", nil)
|
||||
}
|
||||
|
||||
form := &authWithPasswordForm{}
|
||||
if err = e.BindBody(form); err != nil {
|
||||
return firstApiError(err, e.BadRequestError("An error occurred while loading the submitted data.", err))
|
||||
}
|
||||
if err = form.validate(collection); err != nil {
|
||||
return firstApiError(err, e.BadRequestError("An error occurred while validating the submitted data.", err))
|
||||
}
|
||||
|
||||
e.Set(core.RequestEventKeyInfoContext, core.RequestInfoContextPasswordAuth)
|
||||
|
||||
var foundRecord *core.Record
|
||||
var foundErr error
|
||||
|
||||
if form.IdentityField != "" {
|
||||
foundRecord, foundErr = findRecordByIdentityField(e.App, collection, form.IdentityField, form.Identity)
|
||||
} else {
|
||||
// prioritize email lookup
|
||||
isEmail := is.EmailFormat.Validate(form.Identity) == nil
|
||||
if isEmail && list.ExistInSlice(core.FieldNameEmail, collection.PasswordAuth.IdentityFields) {
|
||||
foundRecord, foundErr = findRecordByIdentityField(e.App, collection, core.FieldNameEmail, form.Identity)
|
||||
}
|
||||
|
||||
// search by the other identity fields
|
||||
if !isEmail || foundErr != nil {
|
||||
for _, name := range collection.PasswordAuth.IdentityFields {
|
||||
if !isEmail && name == core.FieldNameEmail {
|
||||
continue // no need to search by the email field if it is not an email
|
||||
}
|
||||
|
||||
foundRecord, foundErr = findRecordByIdentityField(e.App, collection, name, form.Identity)
|
||||
if foundErr == nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ignore not found errors to allow custom record find implementations
|
||||
if foundErr != nil && !errors.Is(foundErr, sql.ErrNoRows) {
|
||||
return e.InternalServerError("", foundErr)
|
||||
}
|
||||
|
||||
event := new(core.RecordAuthWithPasswordRequestEvent)
|
||||
event.RequestEvent = e
|
||||
event.Collection = collection
|
||||
event.Record = foundRecord
|
||||
event.Identity = form.Identity
|
||||
event.Password = form.Password
|
||||
event.IdentityField = form.IdentityField
|
||||
|
||||
return e.App.OnRecordAuthWithPasswordRequest().Trigger(event, func(e *core.RecordAuthWithPasswordRequestEvent) error {
|
||||
if e.Record == nil || !e.Record.ValidatePassword(e.Password) {
|
||||
return e.BadRequestError("Failed to authenticate.", errors.New("invalid login credentials"))
|
||||
}
|
||||
|
||||
return RecordAuthResponse(e.RequestEvent, e.Record, core.MFAMethodPassword, nil)
|
||||
})
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
|
||||
type authWithPasswordForm struct {
|
||||
Identity string `form:"identity" json:"identity"`
|
||||
Password string `form:"password" json:"password"`
|
||||
|
||||
// IdentityField specifies the field to use to search for the identity
|
||||
// (leave it empty for "auto" detection).
|
||||
IdentityField string `form:"identityField" json:"identityField"`
|
||||
}
|
||||
|
||||
func (form *authWithPasswordForm) validate(collection *core.Collection) error {
|
||||
return validation.ValidateStruct(form,
|
||||
validation.Field(&form.Identity, validation.Required, validation.Length(1, 255)),
|
||||
validation.Field(&form.Password, validation.Required, validation.Length(1, 255)),
|
||||
validation.Field(
|
||||
&form.IdentityField,
|
||||
validation.Length(1, 255),
|
||||
validation.In(list.ToInterfaceSlice(collection.PasswordAuth.IdentityFields)...),
|
||||
),
|
||||
)
|
||||
}
|
||||
|
||||
func findRecordByIdentityField(app core.App, collection *core.Collection, field string, value any) (*core.Record, error) {
|
||||
if !slices.Contains(collection.PasswordAuth.IdentityFields, field) {
|
||||
return nil, errors.New("invalid identity field " + field)
|
||||
}
|
||||
|
||||
index, ok := dbutils.FindSingleColumnUniqueIndex(collection.Indexes, field)
|
||||
if !ok {
|
||||
return nil, errors.New("missing " + field + " unique index constraint")
|
||||
}
|
||||
|
||||
var expr dbx.Expression
|
||||
if strings.EqualFold(index.Columns[0].Collate, "nocase") {
|
||||
// case-insensitive search
|
||||
expr = dbx.NewExp("[["+field+"]] = {:identity} COLLATE NOCASE", dbx.Params{"identity": value})
|
||||
} else {
|
||||
expr = dbx.HashExp{field: value}
|
||||
}
|
||||
|
||||
record := &core.Record{}
|
||||
|
||||
err := app.RecordQuery(collection).AndWhere(expr).Limit(1).One(record)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return record, nil
|
||||
}
|
713
apis/record_auth_with_password_test.go
Normal file
713
apis/record_auth_with_password_test.go
Normal file
|
@ -0,0 +1,713 @@
|
|||
package apis_test
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tests"
|
||||
"github.com/pocketbase/pocketbase/tools/dbutils"
|
||||
)
|
||||
|
||||
func TestRecordAuthWithPassword(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
updateIdentityIndex := func(collectionIdOrName string, fieldCollateMap map[string]string) func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
return func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
collection, err := app.FindCollectionByNameOrId("clients")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
for column, collate := range fieldCollateMap {
|
||||
index, ok := dbutils.FindSingleColumnUniqueIndex(collection.Indexes, column)
|
||||
if !ok {
|
||||
t.Fatalf("Missing unique identityField index for column %q", column)
|
||||
}
|
||||
|
||||
index.Columns[0].Collate = collate
|
||||
|
||||
collection.RemoveIndex(index.IndexName)
|
||||
collection.Indexes = append(collection.Indexes, index.Build())
|
||||
}
|
||||
|
||||
err = app.Save(collection)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to update identityField index: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "disabled password auth",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/nologin/auth-with-password",
|
||||
Body: strings.NewReader(`{"identity":"test@example.com","password":"1234567890"}`),
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "non-auth collection",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/demo1/auth-with-password",
|
||||
Body: strings.NewReader(`{"identity":"test@example.com","password":"1234567890"}`),
|
||||
ExpectedStatus: 404,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "invalid body format",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/clients/auth-with-password",
|
||||
Body: strings.NewReader(`{"identity`),
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "empty body params",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/clients/auth-with-password",
|
||||
Body: strings.NewReader(`{"identity":"","password":""}`),
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{
|
||||
`"data":{`,
|
||||
`"identity":{`,
|
||||
`"password":{`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "OnRecordAuthWithPasswordRequest tx body write check",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/clients/auth-with-password",
|
||||
Body: strings.NewReader(`{
|
||||
"identity":"test@example.com",
|
||||
"password":"1234567890"
|
||||
}`),
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
app.OnRecordAuthWithPasswordRequest().BindFunc(func(e *core.RecordAuthWithPasswordRequestEvent) error {
|
||||
original := e.App
|
||||
return e.App.RunInTransaction(func(txApp core.App) error {
|
||||
e.App = txApp
|
||||
defer func() { e.App = original }()
|
||||
|
||||
if err := e.Next(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return e.BadRequestError("TX_ERROR", nil)
|
||||
})
|
||||
})
|
||||
},
|
||||
ExpectedStatus: 400,
|
||||
ExpectedEvents: map[string]int{"OnRecordAuthWithPasswordRequest": 1},
|
||||
ExpectedContent: []string{"TX_ERROR"},
|
||||
},
|
||||
{
|
||||
Name: "valid identity field and invalid password",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/clients/auth-with-password",
|
||||
Body: strings.NewReader(`{
|
||||
"identity":"test@example.com",
|
||||
"password":"invalid"
|
||||
}`),
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{
|
||||
`"data":{}`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordAuthWithPasswordRequest": 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "valid identity field (email) and valid password",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/clients/auth-with-password",
|
||||
Body: strings.NewReader(`{
|
||||
"identity":"test@example.com",
|
||||
"password":"1234567890"
|
||||
}`),
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
// test at least once that the correct request info context is properly loaded
|
||||
app.OnRecordAuthRequest().BindFunc(func(e *core.RecordAuthRequestEvent) error {
|
||||
info, err := e.RequestInfo()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if info.Context != core.RequestInfoContextPasswordAuth {
|
||||
t.Fatalf("Expected request context %q, got %q", core.RequestInfoContextPasswordAuth, info.Context)
|
||||
}
|
||||
|
||||
return e.Next()
|
||||
})
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
`"email":"test@example.com"`,
|
||||
`"token":`,
|
||||
},
|
||||
NotExpectedContent: []string{
|
||||
// hidden fields
|
||||
`"tokenKey"`,
|
||||
`"password"`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordAuthWithPasswordRequest": 1,
|
||||
"OnRecordAuthRequest": 1,
|
||||
"OnRecordEnrich": 1,
|
||||
// authOrigin track
|
||||
"OnModelCreate": 1,
|
||||
"OnModelCreateExecute": 1,
|
||||
"OnModelAfterCreateSuccess": 1,
|
||||
"OnModelValidate": 1,
|
||||
"OnRecordCreate": 1,
|
||||
"OnRecordCreateExecute": 1,
|
||||
"OnRecordAfterCreateSuccess": 1,
|
||||
"OnRecordValidate": 1,
|
||||
"OnMailerSend": 1,
|
||||
"OnMailerRecordAuthAlertSend": 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "valid identity field (username) and valid password",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/clients/auth-with-password",
|
||||
Body: strings.NewReader(`{
|
||||
"identity":"clients57772",
|
||||
"password":"1234567890"
|
||||
}`),
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
`"email":"test@example.com"`,
|
||||
`"username":"clients57772"`,
|
||||
`"token":`,
|
||||
},
|
||||
NotExpectedContent: []string{
|
||||
// hidden fields
|
||||
`"tokenKey"`,
|
||||
`"password"`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordAuthWithPasswordRequest": 1,
|
||||
"OnRecordAuthRequest": 1,
|
||||
"OnRecordEnrich": 1,
|
||||
// authOrigin track
|
||||
"OnModelCreate": 1,
|
||||
"OnModelCreateExecute": 1,
|
||||
"OnModelAfterCreateSuccess": 1,
|
||||
"OnModelValidate": 1,
|
||||
"OnRecordCreate": 1,
|
||||
"OnRecordCreateExecute": 1,
|
||||
"OnRecordAfterCreateSuccess": 1,
|
||||
"OnRecordValidate": 1,
|
||||
"OnMailerSend": 1,
|
||||
"OnMailerRecordAuthAlertSend": 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "unknown explicit identityField",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/clients/auth-with-password",
|
||||
Body: strings.NewReader(`{
|
||||
"identityField": "created",
|
||||
"identity":"test@example.com",
|
||||
"password":"1234567890"
|
||||
}`),
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{
|
||||
`"data":{`,
|
||||
`"identityField":{"code":"validation_in_invalid"`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "valid identity field and valid password with mismatched explicit identityField",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/clients/auth-with-password",
|
||||
Body: strings.NewReader(`{
|
||||
"identityField": "username",
|
||||
"identity":"test@example.com",
|
||||
"password":"1234567890"
|
||||
}`),
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordAuthWithPasswordRequest": 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "valid identity field and valid password with matched explicit identityField",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/clients/auth-with-password",
|
||||
Body: strings.NewReader(`{
|
||||
"identityField": "username",
|
||||
"identity":"clients57772",
|
||||
"password":"1234567890"
|
||||
}`),
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
`"email":"test@example.com"`,
|
||||
`"username":"clients57772"`,
|
||||
`"token":`,
|
||||
},
|
||||
NotExpectedContent: []string{
|
||||
// hidden fields
|
||||
`"tokenKey"`,
|
||||
`"password"`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordAuthWithPasswordRequest": 1,
|
||||
"OnRecordAuthRequest": 1,
|
||||
"OnRecordEnrich": 1,
|
||||
// authOrigin track
|
||||
"OnModelCreate": 1,
|
||||
"OnModelCreateExecute": 1,
|
||||
"OnModelAfterCreateSuccess": 1,
|
||||
"OnModelValidate": 1,
|
||||
"OnRecordCreate": 1,
|
||||
"OnRecordCreateExecute": 1,
|
||||
"OnRecordAfterCreateSuccess": 1,
|
||||
"OnRecordValidate": 1,
|
||||
"OnMailerSend": 1,
|
||||
"OnMailerRecordAuthAlertSend": 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "valid identity (unverified) and valid password in onlyVerified collection",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/clients/auth-with-password",
|
||||
Body: strings.NewReader(`{
|
||||
"identity":"test2@example.com",
|
||||
"password":"1234567890"
|
||||
}`),
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordAuthWithPasswordRequest": 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "already authenticated record",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/clients/auth-with-password",
|
||||
Body: strings.NewReader(`{
|
||||
"identity":"test@example.com",
|
||||
"password":"1234567890"
|
||||
}`),
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
`"id":"gk390qegs4y47wn"`,
|
||||
`"email":"test@example.com"`,
|
||||
`"token":`,
|
||||
},
|
||||
NotExpectedContent: []string{
|
||||
// hidden fields
|
||||
`"tokenKey"`,
|
||||
`"password"`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordAuthWithPasswordRequest": 1,
|
||||
"OnRecordAuthRequest": 1,
|
||||
"OnRecordEnrich": 1,
|
||||
// authOrigin track
|
||||
"OnModelCreate": 1,
|
||||
"OnModelCreateExecute": 1,
|
||||
"OnModelAfterCreateSuccess": 1,
|
||||
"OnModelValidate": 1,
|
||||
"OnRecordCreate": 1,
|
||||
"OnRecordCreateExecute": 1,
|
||||
"OnRecordAfterCreateSuccess": 1,
|
||||
"OnRecordValidate": 1,
|
||||
"OnMailerSend": 1,
|
||||
"OnMailerRecordAuthAlertSend": 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "with mfa first auth check",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/auth-with-password",
|
||||
Body: strings.NewReader(`{
|
||||
"identity":"test@example.com",
|
||||
"password":"1234567890"
|
||||
}`),
|
||||
ExpectedStatus: 401,
|
||||
ExpectedContent: []string{
|
||||
`"mfaId":"`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordAuthWithPasswordRequest": 1,
|
||||
"OnRecordAuthRequest": 1,
|
||||
// mfa create
|
||||
"OnModelCreate": 1,
|
||||
"OnModelCreateExecute": 1,
|
||||
"OnModelAfterCreateSuccess": 1,
|
||||
"OnModelValidate": 1,
|
||||
"OnRecordCreate": 1,
|
||||
"OnRecordCreateExecute": 1,
|
||||
"OnRecordAfterCreateSuccess": 1,
|
||||
"OnRecordValidate": 1,
|
||||
},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
user, err := app.FindAuthRecordByEmail("users", "test@example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
mfas, err := app.FindAllMFAsByRecord(user)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if v := len(mfas); v != 1 {
|
||||
t.Fatalf("Expected 1 mfa record to be created, got %d", v)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "with mfa second auth check",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/auth-with-password",
|
||||
Body: strings.NewReader(`{
|
||||
"mfaId": "` + strings.Repeat("a", 15) + `",
|
||||
"identity":"test@example.com",
|
||||
"password":"1234567890"
|
||||
}`),
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
user, err := app.FindAuthRecordByEmail("users", "test@example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// insert a dummy mfa record
|
||||
mfa := core.NewMFA(app)
|
||||
mfa.Id = strings.Repeat("a", 15)
|
||||
mfa.SetCollectionRef(user.Collection().Id)
|
||||
mfa.SetRecordRef(user.Id)
|
||||
mfa.SetMethod("test")
|
||||
if err := app.Save(mfa); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
`"email":"test@example.com"`,
|
||||
`"token":`,
|
||||
},
|
||||
NotExpectedContent: []string{
|
||||
// hidden fields
|
||||
`"tokenKey"`,
|
||||
`"password"`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordAuthWithPasswordRequest": 1,
|
||||
"OnRecordAuthRequest": 1,
|
||||
"OnRecordEnrich": 1,
|
||||
// authOrigin track
|
||||
"OnModelCreate": 1,
|
||||
"OnModelCreateExecute": 1,
|
||||
"OnModelAfterCreateSuccess": 1,
|
||||
"OnModelValidate": 1,
|
||||
"OnRecordCreate": 1,
|
||||
"OnRecordCreateExecute": 1,
|
||||
"OnRecordAfterCreateSuccess": 1,
|
||||
"OnRecordValidate": 1,
|
||||
"OnMailerSend": 0, // disabled auth email alerts
|
||||
"OnMailerRecordAuthAlertSend": 0,
|
||||
// mfa delete
|
||||
"OnModelDelete": 1,
|
||||
"OnModelDeleteExecute": 1,
|
||||
"OnModelAfterDeleteSuccess": 1,
|
||||
"OnRecordDelete": 1,
|
||||
"OnRecordDeleteExecute": 1,
|
||||
"OnRecordAfterDeleteSuccess": 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "with enabled mfa but unsatisfied mfa rule (aka. skip the mfa check)",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/auth-with-password",
|
||||
Body: strings.NewReader(`{
|
||||
"identity":"test@example.com",
|
||||
"password":"1234567890"
|
||||
}`),
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
users, err := app.FindCollectionByNameOrId("users")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
users.MFA.Enabled = true
|
||||
users.MFA.Rule = "1=2"
|
||||
|
||||
if err := app.Save(users); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
`"email":"test@example.com"`,
|
||||
`"token":`,
|
||||
},
|
||||
NotExpectedContent: []string{
|
||||
// hidden fields
|
||||
`"tokenKey"`,
|
||||
`"password"`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordAuthWithPasswordRequest": 1,
|
||||
"OnRecordAuthRequest": 1,
|
||||
"OnRecordEnrich": 1,
|
||||
// authOrigin track
|
||||
"OnModelCreate": 1,
|
||||
"OnModelCreateExecute": 1,
|
||||
"OnModelAfterCreateSuccess": 1,
|
||||
"OnModelValidate": 1,
|
||||
"OnRecordCreate": 1,
|
||||
"OnRecordCreateExecute": 1,
|
||||
"OnRecordAfterCreateSuccess": 1,
|
||||
"OnRecordValidate": 1,
|
||||
"OnMailerSend": 0, // disabled auth email alerts
|
||||
"OnMailerRecordAuthAlertSend": 0,
|
||||
},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
user, err := app.FindAuthRecordByEmail("users", "test@example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
mfas, err := app.FindAllMFAsByRecord(user)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if v := len(mfas); v != 0 {
|
||||
t.Fatalf("Expected no mfa records to be created, got %d", v)
|
||||
}
|
||||
},
|
||||
},
|
||||
|
||||
// case sensitivity checks
|
||||
// -----------------------------------------------------------
|
||||
{
|
||||
Name: "with explicit identityField (case-sensitive)",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/clients/auth-with-password",
|
||||
Body: strings.NewReader(`{
|
||||
"identityField": "username",
|
||||
"identity":"Clients57772",
|
||||
"password":"1234567890"
|
||||
}`),
|
||||
BeforeTestFunc: updateIdentityIndex("clients", map[string]string{"username": ""}),
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordAuthWithPasswordRequest": 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "with explicit identityField (case-insensitive)",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/clients/auth-with-password",
|
||||
Body: strings.NewReader(`{
|
||||
"identityField": "username",
|
||||
"identity":"Clients57772",
|
||||
"password":"1234567890"
|
||||
}`),
|
||||
BeforeTestFunc: updateIdentityIndex("clients", map[string]string{"username": "nocase"}),
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
`"email":"test@example.com"`,
|
||||
`"username":"clients57772"`,
|
||||
`"token":`,
|
||||
},
|
||||
NotExpectedContent: []string{
|
||||
// hidden fields
|
||||
`"tokenKey"`,
|
||||
`"password"`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordAuthWithPasswordRequest": 1,
|
||||
"OnRecordAuthRequest": 1,
|
||||
"OnRecordEnrich": 1,
|
||||
// authOrigin track
|
||||
"OnModelCreate": 1,
|
||||
"OnModelCreateExecute": 1,
|
||||
"OnModelAfterCreateSuccess": 1,
|
||||
"OnModelValidate": 1,
|
||||
"OnRecordCreate": 1,
|
||||
"OnRecordCreateExecute": 1,
|
||||
"OnRecordAfterCreateSuccess": 1,
|
||||
"OnRecordValidate": 1,
|
||||
"OnMailerSend": 1,
|
||||
"OnMailerRecordAuthAlertSend": 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "without explicit identityField and non-email field (case-insensitive)",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/clients/auth-with-password",
|
||||
Body: strings.NewReader(`{
|
||||
"identity":"Clients57772",
|
||||
"password":"1234567890"
|
||||
}`),
|
||||
BeforeTestFunc: updateIdentityIndex("clients", map[string]string{"username": "nocase"}),
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
`"email":"test@example.com"`,
|
||||
`"username":"clients57772"`,
|
||||
`"token":`,
|
||||
},
|
||||
NotExpectedContent: []string{
|
||||
// hidden fields
|
||||
`"tokenKey"`,
|
||||
`"password"`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordAuthWithPasswordRequest": 1,
|
||||
"OnRecordAuthRequest": 1,
|
||||
"OnRecordEnrich": 1,
|
||||
// authOrigin track
|
||||
"OnModelCreate": 1,
|
||||
"OnModelCreateExecute": 1,
|
||||
"OnModelAfterCreateSuccess": 1,
|
||||
"OnModelValidate": 1,
|
||||
"OnRecordCreate": 1,
|
||||
"OnRecordCreateExecute": 1,
|
||||
"OnRecordAfterCreateSuccess": 1,
|
||||
"OnRecordValidate": 1,
|
||||
"OnMailerSend": 1,
|
||||
"OnMailerRecordAuthAlertSend": 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "without explicit identityField and email field (case-insensitive)",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/clients/auth-with-password",
|
||||
Body: strings.NewReader(`{
|
||||
"identity":"tESt@example.com",
|
||||
"password":"1234567890"
|
||||
}`),
|
||||
BeforeTestFunc: updateIdentityIndex("clients", map[string]string{"email": "nocase"}),
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
`"email":"test@example.com"`,
|
||||
`"username":"clients57772"`,
|
||||
`"token":`,
|
||||
},
|
||||
NotExpectedContent: []string{
|
||||
// hidden fields
|
||||
`"tokenKey"`,
|
||||
`"password"`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordAuthWithPasswordRequest": 1,
|
||||
"OnRecordAuthRequest": 1,
|
||||
"OnRecordEnrich": 1,
|
||||
// authOrigin track
|
||||
"OnModelCreate": 1,
|
||||
"OnModelCreateExecute": 1,
|
||||
"OnModelAfterCreateSuccess": 1,
|
||||
"OnModelValidate": 1,
|
||||
"OnRecordCreate": 1,
|
||||
"OnRecordCreateExecute": 1,
|
||||
"OnRecordAfterCreateSuccess": 1,
|
||||
"OnRecordValidate": 1,
|
||||
"OnMailerSend": 1,
|
||||
"OnMailerRecordAuthAlertSend": 1,
|
||||
},
|
||||
},
|
||||
|
||||
// rate limit checks
|
||||
// -----------------------------------------------------------
|
||||
{
|
||||
Name: "RateLimit rule - users:authWithPassword",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/auth-with-password",
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
app.Settings().RateLimits.Enabled = true
|
||||
app.Settings().RateLimits.Rules = []core.RateLimitRule{
|
||||
{MaxRequests: 100, Label: "abc"},
|
||||
{MaxRequests: 100, Label: "*:authWithPassword"},
|
||||
{MaxRequests: 100, Label: "users:auth"},
|
||||
{MaxRequests: 0, Label: "users:authWithPassword"},
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 429,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "RateLimit rule - *:authWithPassword",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/auth-with-password",
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
app.Settings().RateLimits.Enabled = true
|
||||
app.Settings().RateLimits.Rules = []core.RateLimitRule{
|
||||
{MaxRequests: 100, Label: "abc"},
|
||||
{MaxRequests: 100, Label: "*:auth"},
|
||||
{MaxRequests: 0, Label: "*:authWithPassword"},
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 429,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "RateLimit rule - users:auth",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/auth-with-password",
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
app.Settings().RateLimits.Enabled = true
|
||||
app.Settings().RateLimits.Rules = []core.RateLimitRule{
|
||||
{MaxRequests: 100, Label: "abc"},
|
||||
{MaxRequests: 100, Label: "*:authWithPassword"},
|
||||
{MaxRequests: 0, Label: "users:auth"},
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 429,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "RateLimit rule - *:auth",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/users/auth-with-password",
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
app.Settings().RateLimits.Enabled = true
|
||||
app.Settings().RateLimits.Rules = []core.RateLimitRule{
|
||||
{MaxRequests: 100, Label: "abc"},
|
||||
{MaxRequests: 0, Label: "*:auth"},
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 429,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
742
apis/record_crud.go
Normal file
742
apis/record_crud.go
Normal file
|
@ -0,0 +1,742 @@
|
|||
package apis
|
||||
|
||||
import (
|
||||
cryptoRand "crypto/rand"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/pocketbase/dbx"
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/forms"
|
||||
"github.com/pocketbase/pocketbase/tools/filesystem"
|
||||
"github.com/pocketbase/pocketbase/tools/inflector"
|
||||
"github.com/pocketbase/pocketbase/tools/list"
|
||||
"github.com/pocketbase/pocketbase/tools/router"
|
||||
"github.com/pocketbase/pocketbase/tools/search"
|
||||
"github.com/pocketbase/pocketbase/tools/security"
|
||||
)
|
||||
|
||||
// bindRecordCrudApi registers the record crud api endpoints and
|
||||
// the corresponding handlers.
|
||||
//
|
||||
// note: the rate limiter is "inlined" because some of the crud actions are also used in the batch APIs
|
||||
func bindRecordCrudApi(app core.App, rg *router.RouterGroup[*core.RequestEvent]) {
|
||||
subGroup := rg.Group("/collections/{collection}/records").Unbind(DefaultRateLimitMiddlewareId)
|
||||
subGroup.GET("", recordsList)
|
||||
subGroup.GET("/{id}", recordView)
|
||||
subGroup.POST("", recordCreate(true, nil)).Bind(dynamicCollectionBodyLimit(""))
|
||||
subGroup.PATCH("/{id}", recordUpdate(true, nil)).Bind(dynamicCollectionBodyLimit(""))
|
||||
subGroup.DELETE("/{id}", recordDelete(true, nil))
|
||||
}
|
||||
|
||||
func recordsList(e *core.RequestEvent) error {
|
||||
collection, err := e.App.FindCachedCollectionByNameOrId(e.Request.PathValue("collection"))
|
||||
if err != nil || collection == nil {
|
||||
return e.NotFoundError("Missing collection context.", err)
|
||||
}
|
||||
|
||||
err = checkCollectionRateLimit(e, collection, "list")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
requestInfo, err := e.RequestInfo()
|
||||
if err != nil {
|
||||
return firstApiError(err, e.BadRequestError("", err))
|
||||
}
|
||||
|
||||
if collection.ListRule == nil && !requestInfo.HasSuperuserAuth() {
|
||||
return e.ForbiddenError("Only superusers can perform this action.", nil)
|
||||
}
|
||||
|
||||
// forbid users and guests to query special filter/sort fields
|
||||
err = checkForSuperuserOnlyRuleFields(requestInfo)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
query := e.App.RecordQuery(collection)
|
||||
|
||||
fieldsResolver := core.NewRecordFieldResolver(e.App, collection, requestInfo, true)
|
||||
|
||||
if !requestInfo.HasSuperuserAuth() && collection.ListRule != nil && *collection.ListRule != "" {
|
||||
expr, err := search.FilterData(*collection.ListRule).BuildExpr(fieldsResolver)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
query.AndWhere(expr)
|
||||
|
||||
// will be applied by the search provider right before executing the query
|
||||
// fieldsResolver.UpdateQuery(query)
|
||||
}
|
||||
|
||||
// hidden fields are searchable only by superusers
|
||||
fieldsResolver.SetAllowHiddenFields(requestInfo.HasSuperuserAuth())
|
||||
|
||||
searchProvider := search.NewProvider(fieldsResolver).Query(query)
|
||||
|
||||
// use rowid when available to minimize the need of a covering index with the "id" field
|
||||
if !collection.IsView() {
|
||||
searchProvider.CountCol("_rowid_")
|
||||
}
|
||||
|
||||
records := []*core.Record{}
|
||||
result, err := searchProvider.ParseAndExec(e.Request.URL.Query().Encode(), &records)
|
||||
if err != nil {
|
||||
return firstApiError(err, e.BadRequestError("", err))
|
||||
}
|
||||
|
||||
event := new(core.RecordsListRequestEvent)
|
||||
event.RequestEvent = e
|
||||
event.Collection = collection
|
||||
event.Records = records
|
||||
event.Result = result
|
||||
|
||||
return e.App.OnRecordsListRequest().Trigger(event, func(e *core.RecordsListRequestEvent) error {
|
||||
if err := EnrichRecords(e.RequestEvent, e.Records); err != nil {
|
||||
return firstApiError(err, e.InternalServerError("Failed to enrich records", err))
|
||||
}
|
||||
|
||||
// Add a randomized throttle in case of too many empty search filter attempts.
|
||||
//
|
||||
// This is just for extra precaution since security researches raised concern regarding the possibility of eventual
|
||||
// timing attacks because the List API rule acts also as filter and executes in a single run with the client-side filters.
|
||||
// This is by design and it is an accepted trade off between performance, usability and correctness.
|
||||
//
|
||||
// While technically the below doesn't fully guarantee protection against filter timing attacks, in practice combined with the network latency it makes them even less feasible.
|
||||
// A properly configured rate limiter or individual fields Hidden checks are better suited if you are really concerned about eventual information disclosure by side-channel attacks.
|
||||
//
|
||||
// In all cases it doesn't really matter that much because it doesn't affect the builtin PocketBase security sensitive fields (e.g. password and tokenKey) since they
|
||||
// are not client-side filterable and in the few places where they need to be compared against an external value, a constant time check is used.
|
||||
if !e.HasSuperuserAuth() &&
|
||||
(collection.ListRule != nil && *collection.ListRule != "") &&
|
||||
(requestInfo.Query["filter"] != "") &&
|
||||
len(e.Records) == 0 &&
|
||||
checkRateLimit(e.RequestEvent, "@pb_list_timing_check_"+collection.Id, listTimingRateLimitRule) != nil {
|
||||
e.App.Logger().Debug("Randomized throttle because of too many failed searches", "collectionId", collection.Id)
|
||||
randomizedThrottle(150)
|
||||
}
|
||||
|
||||
return execAfterSuccessTx(true, e.App, func() error {
|
||||
return e.JSON(http.StatusOK, e.Result)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
var listTimingRateLimitRule = core.RateLimitRule{MaxRequests: 3, Duration: 3}
|
||||
|
||||
func randomizedThrottle(softMax int64) {
|
||||
var timeout int64
|
||||
randRange, err := cryptoRand.Int(cryptoRand.Reader, big.NewInt(softMax))
|
||||
if err == nil {
|
||||
timeout = randRange.Int64()
|
||||
} else {
|
||||
timeout = softMax
|
||||
}
|
||||
|
||||
time.Sleep(time.Duration(timeout) * time.Millisecond)
|
||||
}
|
||||
|
||||
func recordView(e *core.RequestEvent) error {
|
||||
collection, err := e.App.FindCachedCollectionByNameOrId(e.Request.PathValue("collection"))
|
||||
if err != nil || collection == nil {
|
||||
return e.NotFoundError("Missing collection context.", err)
|
||||
}
|
||||
|
||||
err = checkCollectionRateLimit(e, collection, "view")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
recordId := e.Request.PathValue("id")
|
||||
if recordId == "" {
|
||||
return e.NotFoundError("", nil)
|
||||
}
|
||||
|
||||
requestInfo, err := e.RequestInfo()
|
||||
if err != nil {
|
||||
return firstApiError(err, e.BadRequestError("", err))
|
||||
}
|
||||
|
||||
if collection.ViewRule == nil && !requestInfo.HasSuperuserAuth() {
|
||||
return e.ForbiddenError("Only superusers can perform this action.", nil)
|
||||
}
|
||||
|
||||
ruleFunc := func(q *dbx.SelectQuery) error {
|
||||
if !requestInfo.HasSuperuserAuth() && collection.ViewRule != nil && *collection.ViewRule != "" {
|
||||
resolver := core.NewRecordFieldResolver(e.App, collection, requestInfo, true)
|
||||
expr, err := search.FilterData(*collection.ViewRule).BuildExpr(resolver)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
resolver.UpdateQuery(q)
|
||||
q.AndWhere(expr)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
record, fetchErr := e.App.FindRecordById(collection, recordId, ruleFunc)
|
||||
if fetchErr != nil || record == nil {
|
||||
return firstApiError(err, e.NotFoundError("", fetchErr))
|
||||
}
|
||||
|
||||
event := new(core.RecordRequestEvent)
|
||||
event.RequestEvent = e
|
||||
event.Collection = collection
|
||||
event.Record = record
|
||||
|
||||
return e.App.OnRecordViewRequest().Trigger(event, func(e *core.RecordRequestEvent) error {
|
||||
if err := EnrichRecord(e.RequestEvent, e.Record); err != nil {
|
||||
return firstApiError(err, e.InternalServerError("Failed to enrich record", err))
|
||||
}
|
||||
|
||||
return execAfterSuccessTx(true, e.App, func() error {
|
||||
return e.JSON(http.StatusOK, e.Record)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func recordCreate(responseWriteAfterTx bool, optFinalizer func(data any) error) func(e *core.RequestEvent) error {
|
||||
return func(e *core.RequestEvent) error {
|
||||
collection, err := e.App.FindCachedCollectionByNameOrId(e.Request.PathValue("collection"))
|
||||
if err != nil || collection == nil {
|
||||
return e.NotFoundError("Missing collection context.", err)
|
||||
}
|
||||
|
||||
if collection.IsView() {
|
||||
return e.BadRequestError("Unsupported collection type.", nil)
|
||||
}
|
||||
|
||||
err = checkCollectionRateLimit(e, collection, "create")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
requestInfo, err := e.RequestInfo()
|
||||
if err != nil {
|
||||
return firstApiError(err, e.BadRequestError("", err))
|
||||
}
|
||||
|
||||
hasSuperuserAuth := requestInfo.HasSuperuserAuth()
|
||||
if !hasSuperuserAuth && collection.CreateRule == nil {
|
||||
return e.ForbiddenError("Only superusers can perform this action.", nil)
|
||||
}
|
||||
|
||||
record := core.NewRecord(collection)
|
||||
|
||||
data, err := recordDataFromRequest(e, record)
|
||||
if err != nil {
|
||||
return firstApiError(err, e.BadRequestError("Failed to read the submitted data.", err))
|
||||
}
|
||||
|
||||
// set a random password for the OAuth2 ignoring its plain password validators
|
||||
var skipPlainPasswordRecordValidators bool
|
||||
if requestInfo.Context == core.RequestInfoContextOAuth2 {
|
||||
if _, ok := data[core.FieldNamePassword]; !ok {
|
||||
data[core.FieldNamePassword] = security.RandomString(30)
|
||||
data[core.FieldNamePassword+"Confirm"] = data[core.FieldNamePassword]
|
||||
skipPlainPasswordRecordValidators = true
|
||||
}
|
||||
}
|
||||
|
||||
// replace modifiers fields so that the resolved value is always
|
||||
// available when accessing requestInfo.Body
|
||||
requestInfo.Body = data
|
||||
|
||||
form := forms.NewRecordUpsert(e.App, record)
|
||||
if hasSuperuserAuth {
|
||||
form.GrantSuperuserAccess()
|
||||
}
|
||||
form.Load(data)
|
||||
|
||||
if skipPlainPasswordRecordValidators {
|
||||
// unset the plain value to skip the plain password field validators
|
||||
if raw, ok := record.GetRaw(core.FieldNamePassword).(*core.PasswordFieldValue); ok {
|
||||
raw.Plain = ""
|
||||
}
|
||||
}
|
||||
|
||||
// check the request and record data against the create and manage rules
|
||||
if !hasSuperuserAuth && collection.CreateRule != nil {
|
||||
dummyRecord := record.Clone()
|
||||
|
||||
dummyRandomPart := "__pb_create__" + security.PseudorandomString(6)
|
||||
|
||||
// set an id if it doesn't have already
|
||||
// (the value doesn't matter; it is used only to minimize the breaking changes with earlier versions)
|
||||
if dummyRecord.Id == "" {
|
||||
dummyRecord.Id = "__temp_id__" + dummyRandomPart
|
||||
}
|
||||
|
||||
// unset the verified field to prevent manage API rule misuse in case the rule relies on it
|
||||
dummyRecord.SetVerified(false)
|
||||
|
||||
// export the dummy record data into db params
|
||||
dummyExport, err := dummyRecord.DBExport(e.App)
|
||||
if err != nil {
|
||||
return e.BadRequestError("Failed to create record", fmt.Errorf("dummy DBExport error: %w", err))
|
||||
}
|
||||
|
||||
dummyParams := make(dbx.Params, len(dummyExport))
|
||||
selects := make([]string, 0, len(dummyExport))
|
||||
var param string
|
||||
for k, v := range dummyExport {
|
||||
k = inflector.Columnify(k) // columnify is just as extra measure in case of custom fields
|
||||
param = "__pb_create__" + k
|
||||
dummyParams[param] = v
|
||||
selects = append(selects, "{:"+param+"} AS [["+k+"]]")
|
||||
}
|
||||
|
||||
// shallow clone the current collection
|
||||
dummyCollection := *collection
|
||||
dummyCollection.Id += dummyRandomPart
|
||||
dummyCollection.Name += inflector.Columnify(dummyRandomPart)
|
||||
|
||||
withFrom := fmt.Sprintf("WITH {{%s}} as (SELECT %s)", dummyCollection.Name, strings.Join(selects, ","))
|
||||
|
||||
// check non-empty create rule
|
||||
if *dummyCollection.CreateRule != "" {
|
||||
ruleQuery := e.App.ConcurrentDB().Select("(1)").PreFragment(withFrom).From(dummyCollection.Name).AndBind(dummyParams)
|
||||
|
||||
resolver := core.NewRecordFieldResolver(e.App, &dummyCollection, requestInfo, true)
|
||||
|
||||
expr, err := search.FilterData(*dummyCollection.CreateRule).BuildExpr(resolver)
|
||||
if err != nil {
|
||||
return e.BadRequestError("Failed to create record", fmt.Errorf("create rule build expression failure: %w", err))
|
||||
}
|
||||
ruleQuery.AndWhere(expr)
|
||||
|
||||
resolver.UpdateQuery(ruleQuery)
|
||||
|
||||
var exists int
|
||||
err = ruleQuery.Limit(1).Row(&exists)
|
||||
if err != nil || exists == 0 {
|
||||
return e.BadRequestError("Failed to create record", fmt.Errorf("create rule failure: %w", err))
|
||||
}
|
||||
}
|
||||
|
||||
// check for manage rule access
|
||||
manageRuleQuery := e.App.ConcurrentDB().Select("(1)").PreFragment(withFrom).From(dummyCollection.Name).AndBind(dummyParams)
|
||||
if !form.HasManageAccess() &&
|
||||
hasAuthManageAccess(e.App, requestInfo, &dummyCollection, manageRuleQuery) {
|
||||
form.GrantManagerAccess()
|
||||
}
|
||||
}
|
||||
|
||||
var isOptFinalizerCalled bool
|
||||
|
||||
event := new(core.RecordRequestEvent)
|
||||
event.RequestEvent = e
|
||||
event.Collection = collection
|
||||
event.Record = record
|
||||
|
||||
hookErr := e.App.OnRecordCreateRequest().Trigger(event, func(e *core.RecordRequestEvent) error {
|
||||
form.SetApp(e.App)
|
||||
form.SetRecord(e.Record)
|
||||
|
||||
err := form.Submit()
|
||||
if err != nil {
|
||||
return firstApiError(err, e.BadRequestError("Failed to create record", err))
|
||||
}
|
||||
|
||||
err = EnrichRecord(e.RequestEvent, e.Record)
|
||||
if err != nil {
|
||||
return firstApiError(err, e.InternalServerError("Failed to enrich record", err))
|
||||
}
|
||||
|
||||
err = execAfterSuccessTx(responseWriteAfterTx, e.App, func() error {
|
||||
return e.JSON(http.StatusOK, e.Record)
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if optFinalizer != nil {
|
||||
isOptFinalizerCalled = true
|
||||
err = optFinalizer(e.Record)
|
||||
if err != nil {
|
||||
return firstApiError(err, e.InternalServerError("", err))
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
if hookErr != nil {
|
||||
return hookErr
|
||||
}
|
||||
|
||||
// e.g. in case the regular hook chain was stopped and the finalizer cannot be executed as part of the last e.Next() task
|
||||
if !isOptFinalizerCalled && optFinalizer != nil {
|
||||
if err := optFinalizer(event.Record); err != nil {
|
||||
return firstApiError(err, e.InternalServerError("", err))
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func recordUpdate(responseWriteAfterTx bool, optFinalizer func(data any) error) func(e *core.RequestEvent) error {
|
||||
return func(e *core.RequestEvent) error {
|
||||
collection, err := e.App.FindCachedCollectionByNameOrId(e.Request.PathValue("collection"))
|
||||
if err != nil || collection == nil {
|
||||
return e.NotFoundError("Missing collection context.", err)
|
||||
}
|
||||
|
||||
if collection.IsView() {
|
||||
return e.BadRequestError("Unsupported collection type.", nil)
|
||||
}
|
||||
|
||||
err = checkCollectionRateLimit(e, collection, "update")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
recordId := e.Request.PathValue("id")
|
||||
if recordId == "" {
|
||||
return e.NotFoundError("", nil)
|
||||
}
|
||||
|
||||
requestInfo, err := e.RequestInfo()
|
||||
if err != nil {
|
||||
return firstApiError(err, e.BadRequestError("", err))
|
||||
}
|
||||
|
||||
hasSuperuserAuth := requestInfo.HasSuperuserAuth()
|
||||
|
||||
if !hasSuperuserAuth && collection.UpdateRule == nil {
|
||||
return firstApiError(err, e.ForbiddenError("Only superusers can perform this action.", nil))
|
||||
}
|
||||
|
||||
// eager fetch the record so that the modifiers field values can be resolved
|
||||
record, err := e.App.FindRecordById(collection, recordId)
|
||||
if err != nil {
|
||||
return firstApiError(err, e.NotFoundError("", err))
|
||||
}
|
||||
|
||||
data, err := recordDataFromRequest(e, record)
|
||||
if err != nil {
|
||||
return firstApiError(err, e.BadRequestError("Failed to read the submitted data.", err))
|
||||
}
|
||||
|
||||
// replace modifiers fields so that the resolved value is always
|
||||
// available when accessing requestInfo.Body
|
||||
requestInfo.Body = data
|
||||
|
||||
ruleFunc := func(q *dbx.SelectQuery) error {
|
||||
if !hasSuperuserAuth && collection.UpdateRule != nil && *collection.UpdateRule != "" {
|
||||
resolver := core.NewRecordFieldResolver(e.App, collection, requestInfo, true)
|
||||
expr, err := search.FilterData(*collection.UpdateRule).BuildExpr(resolver)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
resolver.UpdateQuery(q)
|
||||
q.AndWhere(expr)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// refetch with access checks
|
||||
record, err = e.App.FindRecordById(collection, recordId, ruleFunc)
|
||||
if err != nil {
|
||||
return firstApiError(err, e.NotFoundError("", err))
|
||||
}
|
||||
|
||||
form := forms.NewRecordUpsert(e.App, record)
|
||||
if hasSuperuserAuth {
|
||||
form.GrantSuperuserAccess()
|
||||
}
|
||||
form.Load(data)
|
||||
|
||||
manageRuleQuery := e.App.ConcurrentDB().Select("(1)").From(collection.Name).AndWhere(dbx.HashExp{
|
||||
collection.Name + ".id": record.Id,
|
||||
})
|
||||
if !form.HasManageAccess() &&
|
||||
hasAuthManageAccess(e.App, requestInfo, collection, manageRuleQuery) {
|
||||
form.GrantManagerAccess()
|
||||
}
|
||||
|
||||
var isOptFinalizerCalled bool
|
||||
|
||||
event := new(core.RecordRequestEvent)
|
||||
event.RequestEvent = e
|
||||
event.Collection = collection
|
||||
event.Record = record
|
||||
|
||||
hookErr := e.App.OnRecordUpdateRequest().Trigger(event, func(e *core.RecordRequestEvent) error {
|
||||
form.SetApp(e.App)
|
||||
form.SetRecord(e.Record)
|
||||
|
||||
err := form.Submit()
|
||||
if err != nil {
|
||||
return firstApiError(err, e.BadRequestError("Failed to update record.", err))
|
||||
}
|
||||
|
||||
err = EnrichRecord(e.RequestEvent, e.Record)
|
||||
if err != nil {
|
||||
return firstApiError(err, e.InternalServerError("Failed to enrich record", err))
|
||||
}
|
||||
|
||||
err = execAfterSuccessTx(responseWriteAfterTx, e.App, func() error {
|
||||
return e.JSON(http.StatusOK, e.Record)
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if optFinalizer != nil {
|
||||
isOptFinalizerCalled = true
|
||||
err = optFinalizer(e.Record)
|
||||
if err != nil {
|
||||
return firstApiError(err, e.InternalServerError("", fmt.Errorf("update optFinalizer error: %w", err)))
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
if hookErr != nil {
|
||||
return hookErr
|
||||
}
|
||||
|
||||
// e.g. in case the regular hook chain was stopped and the finalizer cannot be executed as part of the last e.Next() task
|
||||
if !isOptFinalizerCalled && optFinalizer != nil {
|
||||
if err := optFinalizer(event.Record); err != nil {
|
||||
return firstApiError(err, e.InternalServerError("", fmt.Errorf("update optFinalizer error: %w", err)))
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func recordDelete(responseWriteAfterTx bool, optFinalizer func(data any) error) func(e *core.RequestEvent) error {
|
||||
return func(e *core.RequestEvent) error {
|
||||
collection, err := e.App.FindCachedCollectionByNameOrId(e.Request.PathValue("collection"))
|
||||
if err != nil || collection == nil {
|
||||
return e.NotFoundError("Missing collection context.", err)
|
||||
}
|
||||
|
||||
if collection.IsView() {
|
||||
return e.BadRequestError("Unsupported collection type.", nil)
|
||||
}
|
||||
|
||||
err = checkCollectionRateLimit(e, collection, "delete")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
recordId := e.Request.PathValue("id")
|
||||
if recordId == "" {
|
||||
return e.NotFoundError("", nil)
|
||||
}
|
||||
|
||||
requestInfo, err := e.RequestInfo()
|
||||
if err != nil {
|
||||
return firstApiError(err, e.BadRequestError("", err))
|
||||
}
|
||||
|
||||
if !requestInfo.HasSuperuserAuth() && collection.DeleteRule == nil {
|
||||
return e.ForbiddenError("Only superusers can perform this action.", nil)
|
||||
}
|
||||
|
||||
ruleFunc := func(q *dbx.SelectQuery) error {
|
||||
if !requestInfo.HasSuperuserAuth() && collection.DeleteRule != nil && *collection.DeleteRule != "" {
|
||||
resolver := core.NewRecordFieldResolver(e.App, collection, requestInfo, true)
|
||||
expr, err := search.FilterData(*collection.DeleteRule).BuildExpr(resolver)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
resolver.UpdateQuery(q)
|
||||
q.AndWhere(expr)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
record, err := e.App.FindRecordById(collection, recordId, ruleFunc)
|
||||
if err != nil || record == nil {
|
||||
return e.NotFoundError("", err)
|
||||
}
|
||||
|
||||
var isOptFinalizerCalled bool
|
||||
|
||||
event := new(core.RecordRequestEvent)
|
||||
event.RequestEvent = e
|
||||
event.Collection = collection
|
||||
event.Record = record
|
||||
|
||||
hookErr := e.App.OnRecordDeleteRequest().Trigger(event, func(e *core.RecordRequestEvent) error {
|
||||
if err := e.App.Delete(e.Record); err != nil {
|
||||
return firstApiError(err, e.BadRequestError("Failed to delete record. Make sure that the record is not part of a required relation reference.", err))
|
||||
}
|
||||
|
||||
err = execAfterSuccessTx(responseWriteAfterTx, e.App, func() error {
|
||||
return e.NoContent(http.StatusNoContent)
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if optFinalizer != nil {
|
||||
isOptFinalizerCalled = true
|
||||
err = optFinalizer(e.Record)
|
||||
if err != nil {
|
||||
return firstApiError(err, e.InternalServerError("", fmt.Errorf("delete optFinalizer error: %w", err)))
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
if hookErr != nil {
|
||||
return hookErr
|
||||
}
|
||||
|
||||
// e.g. in case the regular hook chain was stopped and the finalizer cannot be executed as part of the last e.Next() task
|
||||
if !isOptFinalizerCalled && optFinalizer != nil {
|
||||
if err := optFinalizer(event.Record); err != nil {
|
||||
return firstApiError(err, e.InternalServerError("", fmt.Errorf("delete optFinalizer error: %w", err)))
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
|
||||
func recordDataFromRequest(e *core.RequestEvent, record *core.Record) (map[string]any, error) {
|
||||
info, err := e.RequestInfo()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// resolve regular fields
|
||||
result := record.ReplaceModifiers(info.Body)
|
||||
|
||||
// resolve uploaded files
|
||||
uploadedFiles, err := extractUploadedFiles(e, record.Collection(), "")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(uploadedFiles) > 0 {
|
||||
for k, files := range uploadedFiles {
|
||||
uploaded := make([]any, 0, len(files))
|
||||
|
||||
// if not remove/prepend/append -> merge with the submitted
|
||||
// info.Body values to prevent accidental old files deletion
|
||||
if info.Body[k] != nil &&
|
||||
!strings.HasPrefix(k, "+") &&
|
||||
!strings.HasSuffix(k, "+") &&
|
||||
!strings.HasSuffix(k, "-") {
|
||||
existing := list.ToUniqueStringSlice(info.Body[k])
|
||||
for _, name := range existing {
|
||||
uploaded = append(uploaded, name)
|
||||
}
|
||||
}
|
||||
|
||||
for _, file := range files {
|
||||
uploaded = append(uploaded, file)
|
||||
}
|
||||
|
||||
result[k] = uploaded
|
||||
}
|
||||
|
||||
result = record.ReplaceModifiers(result)
|
||||
}
|
||||
|
||||
isAuth := record.Collection().IsAuth()
|
||||
|
||||
// unset hidden fields for non-superusers
|
||||
if !info.HasSuperuserAuth() {
|
||||
for _, f := range record.Collection().Fields {
|
||||
if f.GetHidden() {
|
||||
// exception for the auth collection "password" field
|
||||
if isAuth && f.GetName() == core.FieldNamePassword {
|
||||
continue
|
||||
}
|
||||
|
||||
delete(result, f.GetName())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func extractUploadedFiles(re *core.RequestEvent, collection *core.Collection, prefix string) (map[string][]*filesystem.File, error) {
|
||||
contentType := re.Request.Header.Get("content-type")
|
||||
if !strings.HasPrefix(contentType, "multipart/form-data") {
|
||||
return nil, nil // not multipart/form-data request
|
||||
}
|
||||
|
||||
result := map[string][]*filesystem.File{}
|
||||
|
||||
for _, field := range collection.Fields {
|
||||
if field.Type() != core.FieldTypeFile {
|
||||
continue
|
||||
}
|
||||
|
||||
baseKey := field.GetName()
|
||||
|
||||
keys := []string{
|
||||
baseKey,
|
||||
// prepend and append modifiers
|
||||
"+" + baseKey,
|
||||
baseKey + "+",
|
||||
}
|
||||
|
||||
for _, k := range keys {
|
||||
if prefix != "" {
|
||||
k = prefix + "." + k
|
||||
}
|
||||
files, err := re.FindUploadedFiles(k)
|
||||
if err != nil && !errors.Is(err, http.ErrMissingFile) {
|
||||
return nil, err
|
||||
}
|
||||
if len(files) > 0 {
|
||||
result[k] = files
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// hasAuthManageAccess checks whether the client is allowed to have
|
||||
// [forms.RecordUpsert] auth management permissions
|
||||
// (e.g. allowing to change system auth fields without oldPassword).
|
||||
func hasAuthManageAccess(app core.App, requestInfo *core.RequestInfo, collection *core.Collection, query *dbx.SelectQuery) bool {
|
||||
if !collection.IsAuth() {
|
||||
return false
|
||||
}
|
||||
|
||||
manageRule := collection.ManageRule
|
||||
|
||||
if manageRule == nil || *manageRule == "" {
|
||||
return false // only for superusers (manageRule can't be empty)
|
||||
}
|
||||
|
||||
if requestInfo == nil || requestInfo.Auth == nil {
|
||||
return false // no auth record
|
||||
}
|
||||
|
||||
resolver := core.NewRecordFieldResolver(app, collection, requestInfo, true)
|
||||
|
||||
expr, err := search.FilterData(*manageRule).BuildExpr(resolver)
|
||||
if err != nil {
|
||||
app.Logger().Error("Manage rule build expression error", "error", err, "collectionId", collection.Id)
|
||||
return false
|
||||
}
|
||||
query.AndWhere(expr)
|
||||
|
||||
resolver.UpdateQuery(query)
|
||||
|
||||
var exists int
|
||||
|
||||
err = query.Limit(1).Row(&exists)
|
||||
|
||||
return err == nil && exists > 0
|
||||
}
|
314
apis/record_crud_auth_origin_test.go
Normal file
314
apis/record_crud_auth_origin_test.go
Normal file
|
@ -0,0 +1,314 @@
|
|||
package apis_test
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tests"
|
||||
)
|
||||
|
||||
func TestRecordCrudAuthOriginList(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "guest",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/collections/" + core.CollectionNameAuthOrigins + "/records",
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
`"page":1`,
|
||||
`"perPage":30`,
|
||||
`"totalItems":0`,
|
||||
`"totalPages":0`,
|
||||
`"items":[]`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordsListRequest": 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "regular auth with authOrigins",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/collections/" + core.CollectionNameAuthOrigins + "/records",
|
||||
Headers: map[string]string{
|
||||
// clients, test@example.com
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6ImdrMzkwcWVnczR5NDd3biIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoidjg1MXE0cjc5MHJoa25sIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.0ONnm_BsvPRZyDNT31GN1CKUB6uQRxvVvQ-Wc9AZfG0",
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
`"page":1`,
|
||||
`"perPage":30`,
|
||||
`"totalItems":1`,
|
||||
`"totalPages":1`,
|
||||
`"id":"9r2j0m74260ur8i"`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordsListRequest": 1,
|
||||
"OnRecordEnrich": 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "regular auth without authOrigins",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/collections/" + core.CollectionNameAuthOrigins + "/records",
|
||||
Headers: map[string]string{
|
||||
// users, test@example.com
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
`"page":1`,
|
||||
`"perPage":30`,
|
||||
`"totalItems":0`,
|
||||
`"totalPages":0`,
|
||||
`"items":[]`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordsListRequest": 1,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecordCrudAuthOriginView(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "guest",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/collections/" + core.CollectionNameAuthOrigins + "/records/9r2j0m74260ur8i",
|
||||
ExpectedStatus: 404,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "non-owner",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/collections/" + core.CollectionNameAuthOrigins + "/records/9r2j0m74260ur8i",
|
||||
Headers: map[string]string{
|
||||
// users, test@example.com
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
ExpectedStatus: 404,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "owner",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/collections/" + core.CollectionNameAuthOrigins + "/records/9r2j0m74260ur8i",
|
||||
Headers: map[string]string{
|
||||
// clients, test@example.com
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6ImdrMzkwcWVnczR5NDd3biIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoidjg1MXE0cjc5MHJoa25sIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.0ONnm_BsvPRZyDNT31GN1CKUB6uQRxvVvQ-Wc9AZfG0",
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{`"id":"9r2j0m74260ur8i"`},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordViewRequest": 1,
|
||||
"OnRecordEnrich": 1,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecordCrudAuthOriginDelete(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "guest",
|
||||
Method: http.MethodDelete,
|
||||
URL: "/api/collections/" + core.CollectionNameAuthOrigins + "/records/9r2j0m74260ur8i",
|
||||
ExpectedStatus: 404,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "non-owner",
|
||||
Method: http.MethodDelete,
|
||||
URL: "/api/collections/" + core.CollectionNameAuthOrigins + "/records/9r2j0m74260ur8i",
|
||||
Headers: map[string]string{
|
||||
// users, test@example.com
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
ExpectedStatus: 404,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "owner",
|
||||
Method: http.MethodDelete,
|
||||
URL: "/api/collections/" + core.CollectionNameAuthOrigins + "/records/9r2j0m74260ur8i",
|
||||
Headers: map[string]string{
|
||||
// clients, test@example.com
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6ImdrMzkwcWVnczR5NDd3biIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoidjg1MXE0cjc5MHJoa25sIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.0ONnm_BsvPRZyDNT31GN1CKUB6uQRxvVvQ-Wc9AZfG0",
|
||||
},
|
||||
ExpectedStatus: 204,
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordDeleteRequest": 1,
|
||||
"OnModelDelete": 1,
|
||||
"OnModelDeleteExecute": 1,
|
||||
"OnModelAfterDeleteSuccess": 1,
|
||||
"OnRecordDelete": 1,
|
||||
"OnRecordDeleteExecute": 1,
|
||||
"OnRecordAfterDeleteSuccess": 1,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecordCrudAuthOriginCreate(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
body := func() *strings.Reader {
|
||||
return strings.NewReader(`{
|
||||
"recordRef": "4q1xlclmfloku33",
|
||||
"collectionRef": "_pb_users_auth_",
|
||||
"fingerprint": "abc"
|
||||
}`)
|
||||
}
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "guest",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/" + core.CollectionNameAuthOrigins + "/records",
|
||||
Body: body(),
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "owner regular auth",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/" + core.CollectionNameAuthOrigins + "/records",
|
||||
Headers: map[string]string{
|
||||
// users, test@example.com
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
Body: body(),
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "superusers auth",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/" + core.CollectionNameAuthOrigins + "/records",
|
||||
Headers: map[string]string{
|
||||
// superusers, test@example.com
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
|
||||
},
|
||||
Body: body(),
|
||||
ExpectedContent: []string{
|
||||
`"fingerprint":"abc"`,
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordCreateRequest": 1,
|
||||
"OnRecordEnrich": 1,
|
||||
"OnModelCreate": 1,
|
||||
"OnModelCreateExecute": 1,
|
||||
"OnModelAfterCreateSuccess": 1,
|
||||
"OnModelValidate": 1,
|
||||
"OnRecordCreate": 1,
|
||||
"OnRecordCreateExecute": 1,
|
||||
"OnRecordAfterCreateSuccess": 1,
|
||||
"OnRecordValidate": 1,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecordCrudAuthOriginUpdate(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
body := func() *strings.Reader {
|
||||
return strings.NewReader(`{
|
||||
"fingerprint":"abc"
|
||||
}`)
|
||||
}
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "guest",
|
||||
Method: http.MethodPatch,
|
||||
URL: "/api/collections/" + core.CollectionNameAuthOrigins + "/records/9r2j0m74260ur8i",
|
||||
Body: body(),
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "owner regular auth",
|
||||
Method: http.MethodPatch,
|
||||
URL: "/api/collections/" + core.CollectionNameAuthOrigins + "/records/9r2j0m74260ur8i",
|
||||
Headers: map[string]string{
|
||||
// clients, test@example.com
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6ImdrMzkwcWVnczR5NDd3biIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoidjg1MXE0cjc5MHJoa25sIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.0ONnm_BsvPRZyDNT31GN1CKUB6uQRxvVvQ-Wc9AZfG0",
|
||||
},
|
||||
Body: body(),
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "superusers auth",
|
||||
Method: http.MethodPatch,
|
||||
URL: "/api/collections/" + core.CollectionNameAuthOrigins + "/records/9r2j0m74260ur8i",
|
||||
Headers: map[string]string{
|
||||
// superusers, test@example.com
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
|
||||
},
|
||||
Body: body(),
|
||||
ExpectedContent: []string{
|
||||
`"id":"9r2j0m74260ur8i"`,
|
||||
`"fingerprint":"abc"`,
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordUpdateRequest": 1,
|
||||
"OnRecordEnrich": 1,
|
||||
"OnModelUpdate": 1,
|
||||
"OnModelUpdateExecute": 1,
|
||||
"OnModelAfterUpdateSuccess": 1,
|
||||
"OnModelValidate": 1,
|
||||
"OnRecordUpdate": 1,
|
||||
"OnRecordUpdateExecute": 1,
|
||||
"OnRecordAfterUpdateSuccess": 1,
|
||||
"OnRecordValidate": 1,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
316
apis/record_crud_external_auth_test.go
Normal file
316
apis/record_crud_external_auth_test.go
Normal file
|
@ -0,0 +1,316 @@
|
|||
package apis_test
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tests"
|
||||
)
|
||||
|
||||
func TestRecordCrudExternalAuthList(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "guest",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/collections/" + core.CollectionNameExternalAuths + "/records",
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
`"page":1`,
|
||||
`"perPage":30`,
|
||||
`"totalItems":0`,
|
||||
`"totalPages":0`,
|
||||
`"items":[]`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordsListRequest": 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "regular auth with externalAuths",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/collections/" + core.CollectionNameExternalAuths + "/records",
|
||||
Headers: map[string]string{
|
||||
// clients, test@example.com
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6ImdrMzkwcWVnczR5NDd3biIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoidjg1MXE0cjc5MHJoa25sIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.0ONnm_BsvPRZyDNT31GN1CKUB6uQRxvVvQ-Wc9AZfG0",
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
`"page":1`,
|
||||
`"perPage":30`,
|
||||
`"totalItems":1`,
|
||||
`"totalPages":1`,
|
||||
`"id":"f1z5b3843pzc964"`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordsListRequest": 1,
|
||||
"OnRecordEnrich": 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "regular auth without externalAuths",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/collections/" + core.CollectionNameExternalAuths + "/records",
|
||||
Headers: map[string]string{
|
||||
// users, test2@example.com
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6Im9hcDY0MGNvdDR5cnUycyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.GfJo6EHIobgas_AXt-M-tj5IoQendPnrkMSe9ExuSEY",
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
`"page":1`,
|
||||
`"perPage":30`,
|
||||
`"totalItems":0`,
|
||||
`"totalPages":0`,
|
||||
`"items":[]`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordsListRequest": 1,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecordCrudExternalAuthView(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "guest",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/collections/" + core.CollectionNameExternalAuths + "/records/dlmflokuq1xl342",
|
||||
ExpectedStatus: 404,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "non-owner",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/collections/" + core.CollectionNameExternalAuths + "/records/dlmflokuq1xl342",
|
||||
Headers: map[string]string{
|
||||
// clients, test@example.com
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6ImdrMzkwcWVnczR5NDd3biIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoidjg1MXE0cjc5MHJoa25sIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.0ONnm_BsvPRZyDNT31GN1CKUB6uQRxvVvQ-Wc9AZfG0",
|
||||
},
|
||||
ExpectedStatus: 404,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "owner",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/collections/" + core.CollectionNameExternalAuths + "/records/dlmflokuq1xl342",
|
||||
Headers: map[string]string{
|
||||
// users, test@example.com
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{`"id":"dlmflokuq1xl342"`},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordViewRequest": 1,
|
||||
"OnRecordEnrich": 1,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecordCrudExternalAuthDelete(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "guest",
|
||||
Method: http.MethodDelete,
|
||||
URL: "/api/collections/" + core.CollectionNameExternalAuths + "/records/dlmflokuq1xl342",
|
||||
ExpectedStatus: 404,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "non-owner",
|
||||
Method: http.MethodDelete,
|
||||
URL: "/api/collections/" + core.CollectionNameExternalAuths + "/records/dlmflokuq1xl342",
|
||||
Headers: map[string]string{
|
||||
// clients, test@example.com
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6ImdrMzkwcWVnczR5NDd3biIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoidjg1MXE0cjc5MHJoa25sIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.0ONnm_BsvPRZyDNT31GN1CKUB6uQRxvVvQ-Wc9AZfG0",
|
||||
},
|
||||
ExpectedStatus: 404,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "owner",
|
||||
Method: http.MethodDelete,
|
||||
URL: "/api/collections/" + core.CollectionNameExternalAuths + "/records/dlmflokuq1xl342",
|
||||
Headers: map[string]string{
|
||||
// users, test@example.com
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
ExpectedStatus: 204,
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordDeleteRequest": 1,
|
||||
"OnModelDelete": 1,
|
||||
"OnModelDeleteExecute": 1,
|
||||
"OnModelAfterDeleteSuccess": 1,
|
||||
"OnRecordDelete": 1,
|
||||
"OnRecordDeleteExecute": 1,
|
||||
"OnRecordAfterDeleteSuccess": 1,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecordCrudExternalAuthCreate(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
body := func() *strings.Reader {
|
||||
return strings.NewReader(`{
|
||||
"recordRef": "4q1xlclmfloku33",
|
||||
"collectionRef": "_pb_users_auth_",
|
||||
"provider": "github",
|
||||
"providerId": "abc"
|
||||
}`)
|
||||
}
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "guest",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/" + core.CollectionNameExternalAuths + "/records",
|
||||
Body: body(),
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "owner regular auth",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/" + core.CollectionNameExternalAuths + "/records",
|
||||
Headers: map[string]string{
|
||||
// users, test@example.com
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
Body: body(),
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "superusers auth",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/" + core.CollectionNameExternalAuths + "/records",
|
||||
Headers: map[string]string{
|
||||
// superusers, test@example.com
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
|
||||
},
|
||||
Body: body(),
|
||||
ExpectedContent: []string{
|
||||
`"recordRef":"4q1xlclmfloku33"`,
|
||||
`"providerId":"abc"`,
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordCreateRequest": 1,
|
||||
"OnRecordEnrich": 1,
|
||||
"OnModelCreate": 1,
|
||||
"OnModelCreateExecute": 1,
|
||||
"OnModelAfterCreateSuccess": 1,
|
||||
"OnModelValidate": 1,
|
||||
"OnRecordCreate": 1,
|
||||
"OnRecordCreateExecute": 1,
|
||||
"OnRecordAfterCreateSuccess": 1,
|
||||
"OnRecordValidate": 1,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecordCrudExternalAuthUpdate(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
body := func() *strings.Reader {
|
||||
return strings.NewReader(`{
|
||||
"providerId": "abc"
|
||||
}`)
|
||||
}
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "guest",
|
||||
Method: http.MethodPatch,
|
||||
URL: "/api/collections/" + core.CollectionNameExternalAuths + "/records/dlmflokuq1xl342",
|
||||
Body: body(),
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "owner regular auth",
|
||||
Method: http.MethodPatch,
|
||||
URL: "/api/collections/" + core.CollectionNameExternalAuths + "/records/dlmflokuq1xl342",
|
||||
Headers: map[string]string{
|
||||
// clients, test@example.com
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6ImdrMzkwcWVnczR5NDd3biIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoidjg1MXE0cjc5MHJoa25sIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.0ONnm_BsvPRZyDNT31GN1CKUB6uQRxvVvQ-Wc9AZfG0",
|
||||
},
|
||||
Body: body(),
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "superusers auth",
|
||||
Method: http.MethodPatch,
|
||||
URL: "/api/collections/" + core.CollectionNameExternalAuths + "/records/dlmflokuq1xl342",
|
||||
Headers: map[string]string{
|
||||
// superusers, test@example.com
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
|
||||
},
|
||||
Body: body(),
|
||||
ExpectedContent: []string{
|
||||
`"id":"dlmflokuq1xl342"`,
|
||||
`"providerId":"abc"`,
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordUpdateRequest": 1,
|
||||
"OnRecordEnrich": 1,
|
||||
"OnModelUpdate": 1,
|
||||
"OnModelUpdateExecute": 1,
|
||||
"OnModelAfterUpdateSuccess": 1,
|
||||
"OnModelValidate": 1,
|
||||
"OnRecordUpdate": 1,
|
||||
"OnRecordUpdateExecute": 1,
|
||||
"OnRecordAfterUpdateSuccess": 1,
|
||||
"OnRecordValidate": 1,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
405
apis/record_crud_mfa_test.go
Normal file
405
apis/record_crud_mfa_test.go
Normal file
|
@ -0,0 +1,405 @@
|
|||
package apis_test
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tests"
|
||||
)
|
||||
|
||||
func TestRecordCrudMFAList(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "guest",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/collections/" + core.CollectionNameMFAs + "/records",
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := tests.StubMFARecords(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
`"page":1`,
|
||||
`"perPage":30`,
|
||||
`"totalItems":0`,
|
||||
`"totalPages":0`,
|
||||
`"items":[]`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordsListRequest": 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "regular auth with mfas",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/collections/" + core.CollectionNameMFAs + "/records",
|
||||
Headers: map[string]string{
|
||||
// users, test@example.com
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := tests.StubMFARecords(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
`"page":1`,
|
||||
`"perPage":30`,
|
||||
`"totalItems":1`,
|
||||
`"totalPages":1`,
|
||||
`"id":"user1_0"`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordsListRequest": 1,
|
||||
"OnRecordEnrich": 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "regular auth without mfas",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/collections/" + core.CollectionNameMFAs + "/records",
|
||||
Headers: map[string]string{
|
||||
// clients, test@example.com
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6ImdrMzkwcWVnczR5NDd3biIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoidjg1MXE0cjc5MHJoa25sIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.0ONnm_BsvPRZyDNT31GN1CKUB6uQRxvVvQ-Wc9AZfG0",
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := tests.StubMFARecords(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
`"page":1`,
|
||||
`"perPage":30`,
|
||||
`"totalItems":0`,
|
||||
`"totalPages":0`,
|
||||
`"items":[]`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordsListRequest": 1,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecordCrudMFAView(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "guest",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/collections/" + core.CollectionNameMFAs + "/records/user1_0",
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := tests.StubMFARecords(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 404,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "non-owner",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/collections/" + core.CollectionNameMFAs + "/records/user1_0",
|
||||
Headers: map[string]string{
|
||||
// clients, test@example.com
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6ImdrMzkwcWVnczR5NDd3biIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoidjg1MXE0cjc5MHJoa25sIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.0ONnm_BsvPRZyDNT31GN1CKUB6uQRxvVvQ-Wc9AZfG0",
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := tests.StubMFARecords(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 404,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "owner",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/collections/" + core.CollectionNameMFAs + "/records/user1_0",
|
||||
Headers: map[string]string{
|
||||
// users, test@example.com
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := tests.StubMFARecords(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{`"id":"user1_0"`},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordViewRequest": 1,
|
||||
"OnRecordEnrich": 1,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecordCrudMFADelete(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "guest",
|
||||
Method: http.MethodDelete,
|
||||
URL: "/api/collections/" + core.CollectionNameMFAs + "/records/user1_0",
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := tests.StubMFARecords(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "non-owner",
|
||||
Method: http.MethodDelete,
|
||||
URL: "/api/collections/" + core.CollectionNameMFAs + "/records/user1_0",
|
||||
Headers: map[string]string{
|
||||
// clients, test@example.com
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6ImdrMzkwcWVnczR5NDd3biIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoidjg1MXE0cjc5MHJoa25sIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.0ONnm_BsvPRZyDNT31GN1CKUB6uQRxvVvQ-Wc9AZfG0",
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := tests.StubMFARecords(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "owner",
|
||||
Method: http.MethodDelete,
|
||||
URL: "/api/collections/" + core.CollectionNameMFAs + "/records/user1_0",
|
||||
Headers: map[string]string{
|
||||
// users, test@example.com
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := tests.StubMFARecords(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "superusers auth",
|
||||
Method: http.MethodDelete,
|
||||
URL: "/api/collections/" + core.CollectionNameMFAs + "/records/user1_0",
|
||||
Headers: map[string]string{
|
||||
// superusers, test@example.com
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := tests.StubMFARecords(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 204,
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordDeleteRequest": 1,
|
||||
"OnModelDelete": 1,
|
||||
"OnModelDeleteExecute": 1,
|
||||
"OnModelAfterDeleteSuccess": 1,
|
||||
"OnRecordDelete": 1,
|
||||
"OnRecordDeleteExecute": 1,
|
||||
"OnRecordAfterDeleteSuccess": 1,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecordCrudMFACreate(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
body := func() *strings.Reader {
|
||||
return strings.NewReader(`{
|
||||
"recordRef": "4q1xlclmfloku33",
|
||||
"collectionRef": "_pb_users_auth_",
|
||||
"method": "abc"
|
||||
}`)
|
||||
}
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "guest",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/" + core.CollectionNameMFAs + "/records",
|
||||
Body: body(),
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := tests.StubMFARecords(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "owner regular auth",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/" + core.CollectionNameMFAs + "/records",
|
||||
Headers: map[string]string{
|
||||
// users, test@example.com
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
Body: body(),
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := tests.StubMFARecords(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "superusers auth",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/" + core.CollectionNameMFAs + "/records",
|
||||
Headers: map[string]string{
|
||||
// superusers, test@example.com
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
|
||||
},
|
||||
Body: body(),
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := tests.StubMFARecords(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedContent: []string{
|
||||
`"recordRef":"4q1xlclmfloku33"`,
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordCreateRequest": 1,
|
||||
"OnRecordEnrich": 1,
|
||||
"OnModelCreate": 1,
|
||||
"OnModelCreateExecute": 1,
|
||||
"OnModelAfterCreateSuccess": 1,
|
||||
"OnModelValidate": 1,
|
||||
"OnRecordCreate": 1,
|
||||
"OnRecordCreateExecute": 1,
|
||||
"OnRecordAfterCreateSuccess": 1,
|
||||
"OnRecordValidate": 1,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecordCrudMFAUpdate(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
body := func() *strings.Reader {
|
||||
return strings.NewReader(`{
|
||||
"method":"abc"
|
||||
}`)
|
||||
}
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "guest",
|
||||
Method: http.MethodPatch,
|
||||
URL: "/api/collections/" + core.CollectionNameMFAs + "/records/user1_0",
|
||||
Body: body(),
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := tests.StubMFARecords(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "owner regular auth",
|
||||
Method: http.MethodPatch,
|
||||
URL: "/api/collections/" + core.CollectionNameMFAs + "/records/user1_0",
|
||||
Headers: map[string]string{
|
||||
// users, test@example.com
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
Body: body(),
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := tests.StubMFARecords(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "superusers auth",
|
||||
Method: http.MethodPatch,
|
||||
URL: "/api/collections/" + core.CollectionNameMFAs + "/records/user1_0",
|
||||
Headers: map[string]string{
|
||||
// superusers, test@example.com
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
|
||||
},
|
||||
Body: body(),
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := tests.StubMFARecords(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedContent: []string{
|
||||
`"id":"user1_0"`,
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordUpdateRequest": 1,
|
||||
"OnRecordEnrich": 1,
|
||||
"OnModelUpdate": 1,
|
||||
"OnModelUpdateExecute": 1,
|
||||
"OnModelAfterUpdateSuccess": 1,
|
||||
"OnModelValidate": 1,
|
||||
"OnRecordUpdate": 1,
|
||||
"OnRecordUpdateExecute": 1,
|
||||
"OnRecordAfterUpdateSuccess": 1,
|
||||
"OnRecordValidate": 1,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
405
apis/record_crud_otp_test.go
Normal file
405
apis/record_crud_otp_test.go
Normal file
|
@ -0,0 +1,405 @@
|
|||
package apis_test
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tests"
|
||||
)
|
||||
|
||||
func TestRecordCrudOTPList(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "guest",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/collections/" + core.CollectionNameOTPs + "/records",
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := tests.StubOTPRecords(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
`"page":1`,
|
||||
`"perPage":30`,
|
||||
`"totalItems":0`,
|
||||
`"totalPages":0`,
|
||||
`"items":[]`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordsListRequest": 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "regular auth with otps",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/collections/" + core.CollectionNameOTPs + "/records",
|
||||
Headers: map[string]string{
|
||||
// users, test@example.com
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := tests.StubOTPRecords(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
`"page":1`,
|
||||
`"perPage":30`,
|
||||
`"totalItems":1`,
|
||||
`"totalPages":1`,
|
||||
`"id":"user1_0"`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordsListRequest": 1,
|
||||
"OnRecordEnrich": 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "regular auth without otps",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/collections/" + core.CollectionNameOTPs + "/records",
|
||||
Headers: map[string]string{
|
||||
// clients, test@example.com
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6ImdrMzkwcWVnczR5NDd3biIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoidjg1MXE0cjc5MHJoa25sIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.0ONnm_BsvPRZyDNT31GN1CKUB6uQRxvVvQ-Wc9AZfG0",
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := tests.StubOTPRecords(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
`"page":1`,
|
||||
`"perPage":30`,
|
||||
`"totalItems":0`,
|
||||
`"totalPages":0`,
|
||||
`"items":[]`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordsListRequest": 1,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecordCrudOTPView(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "guest",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/collections/" + core.CollectionNameOTPs + "/records/user1_0",
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := tests.StubOTPRecords(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 404,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "non-owner",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/collections/" + core.CollectionNameOTPs + "/records/user1_0",
|
||||
Headers: map[string]string{
|
||||
// clients, test@example.com
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6ImdrMzkwcWVnczR5NDd3biIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoidjg1MXE0cjc5MHJoa25sIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.0ONnm_BsvPRZyDNT31GN1CKUB6uQRxvVvQ-Wc9AZfG0",
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := tests.StubOTPRecords(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 404,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "owner",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/collections/" + core.CollectionNameOTPs + "/records/user1_0",
|
||||
Headers: map[string]string{
|
||||
// users, test@example.com
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := tests.StubOTPRecords(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{`"id":"user1_0"`},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordViewRequest": 1,
|
||||
"OnRecordEnrich": 1,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecordCrudOTPDelete(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "guest",
|
||||
Method: http.MethodDelete,
|
||||
URL: "/api/collections/" + core.CollectionNameOTPs + "/records/user1_0",
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := tests.StubOTPRecords(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "non-owner auth",
|
||||
Method: http.MethodDelete,
|
||||
URL: "/api/collections/" + core.CollectionNameOTPs + "/records/user1_0",
|
||||
Headers: map[string]string{
|
||||
// clients, test@example.com
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6ImdrMzkwcWVnczR5NDd3biIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoidjg1MXE0cjc5MHJoa25sIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.0ONnm_BsvPRZyDNT31GN1CKUB6uQRxvVvQ-Wc9AZfG0",
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := tests.StubOTPRecords(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "owner regular auth",
|
||||
Method: http.MethodDelete,
|
||||
URL: "/api/collections/" + core.CollectionNameOTPs + "/records/user1_0",
|
||||
Headers: map[string]string{
|
||||
// users, test@example.com
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := tests.StubOTPRecords(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "superusers auth",
|
||||
Method: http.MethodDelete,
|
||||
URL: "/api/collections/" + core.CollectionNameOTPs + "/records/user1_0",
|
||||
Headers: map[string]string{
|
||||
// superusers, test@example.com
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := tests.StubOTPRecords(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 204,
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordDeleteRequest": 1,
|
||||
"OnModelDelete": 1,
|
||||
"OnModelDeleteExecute": 1,
|
||||
"OnModelAfterDeleteSuccess": 1,
|
||||
"OnRecordDelete": 1,
|
||||
"OnRecordDeleteExecute": 1,
|
||||
"OnRecordAfterDeleteSuccess": 1,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecordCrudOTPCreate(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
body := func() *strings.Reader {
|
||||
return strings.NewReader(`{
|
||||
"recordRef": "4q1xlclmfloku33",
|
||||
"collectionRef": "_pb_users_auth_",
|
||||
"password": "abc"
|
||||
}`)
|
||||
}
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "guest",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/" + core.CollectionNameOTPs + "/records",
|
||||
Body: body(),
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := tests.StubOTPRecords(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "owner regular auth",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/" + core.CollectionNameOTPs + "/records",
|
||||
Headers: map[string]string{
|
||||
// users, test@example.com
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
Body: body(),
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := tests.StubOTPRecords(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "superusers auth",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/" + core.CollectionNameOTPs + "/records",
|
||||
Headers: map[string]string{
|
||||
// superusers, test@example.com
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
|
||||
},
|
||||
Body: body(),
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := tests.StubOTPRecords(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedContent: []string{
|
||||
`"recordRef":"4q1xlclmfloku33"`,
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordCreateRequest": 1,
|
||||
"OnRecordEnrich": 1,
|
||||
"OnModelCreate": 1,
|
||||
"OnModelCreateExecute": 1,
|
||||
"OnModelAfterCreateSuccess": 1,
|
||||
"OnModelValidate": 1,
|
||||
"OnRecordCreate": 1,
|
||||
"OnRecordCreateExecute": 1,
|
||||
"OnRecordAfterCreateSuccess": 1,
|
||||
"OnRecordValidate": 1,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecordCrudOTPUpdate(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
body := func() *strings.Reader {
|
||||
return strings.NewReader(`{
|
||||
"password":"abc"
|
||||
}`)
|
||||
}
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "guest",
|
||||
Method: http.MethodPatch,
|
||||
URL: "/api/collections/" + core.CollectionNameOTPs + "/records/user1_0",
|
||||
Body: body(),
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := tests.StubOTPRecords(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "owner regular auth",
|
||||
Method: http.MethodPatch,
|
||||
URL: "/api/collections/" + core.CollectionNameOTPs + "/records/user1_0",
|
||||
Headers: map[string]string{
|
||||
// users, test@example.com
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
Body: body(),
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := tests.StubOTPRecords(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "superusers auth",
|
||||
Method: http.MethodPatch,
|
||||
URL: "/api/collections/" + core.CollectionNameOTPs + "/records/user1_0",
|
||||
Headers: map[string]string{
|
||||
// superusers, test@example.com
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
|
||||
},
|
||||
Body: body(),
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
if err := tests.StubOTPRecords(app); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
ExpectedContent: []string{
|
||||
`"id":"user1_0"`,
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordUpdateRequest": 1,
|
||||
"OnRecordEnrich": 1,
|
||||
"OnModelUpdate": 1,
|
||||
"OnModelUpdateExecute": 1,
|
||||
"OnModelAfterUpdateSuccess": 1,
|
||||
"OnModelValidate": 1,
|
||||
"OnRecordUpdate": 1,
|
||||
"OnRecordUpdateExecute": 1,
|
||||
"OnRecordAfterUpdateSuccess": 1,
|
||||
"OnRecordValidate": 1,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
336
apis/record_crud_superuser_test.go
Normal file
336
apis/record_crud_superuser_test.go
Normal file
|
@ -0,0 +1,336 @@
|
|||
package apis_test
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/pocketbase/dbx"
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tests"
|
||||
)
|
||||
|
||||
func TestRecordCrudSuperuserList(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "guest",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/collections/" + core.CollectionNameSuperusers + "/records",
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "non-superusers auth",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/collections/" + core.CollectionNameSuperusers + "/records",
|
||||
Headers: map[string]string{
|
||||
// users, test@example.com
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "superusers auth",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/collections/" + core.CollectionNameSuperusers + "/records",
|
||||
Headers: map[string]string{
|
||||
// _superusers, test@example.com
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
`"page":1`,
|
||||
`"perPage":30`,
|
||||
`"totalPages":1`,
|
||||
`"totalItems":4`,
|
||||
`"items":[{`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordsListRequest": 1,
|
||||
"OnRecordEnrich": 4,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecordCrudSuperuserView(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "guest",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/collections/" + core.CollectionNameSuperusers + "/records/sywbhecnh46rhm0",
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "non-superusers auth",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/collections/" + core.CollectionNameSuperusers + "/records/sywbhecnh46rhm0",
|
||||
Headers: map[string]string{
|
||||
// users, test@example.com
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "superusers auth",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/collections/" + core.CollectionNameSuperusers + "/records/sywbhecnh46rhm0",
|
||||
Headers: map[string]string{
|
||||
// _superusers, test@example.com
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
`"id":"sywbhecnh46rhm0"`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordViewRequest": 1,
|
||||
"OnRecordEnrich": 1,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecordCrudSuperuserDelete(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "guest",
|
||||
Method: http.MethodDelete,
|
||||
URL: "/api/collections/" + core.CollectionNameSuperusers + "/records/sbmbsdb40jyxf7h",
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "non-superusers auth",
|
||||
Method: http.MethodDelete,
|
||||
URL: "/api/collections/" + core.CollectionNameSuperusers + "/records/sbmbsdb40jyxf7h",
|
||||
Headers: map[string]string{
|
||||
// users, test@example.com
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "superusers auth",
|
||||
Method: http.MethodDelete,
|
||||
URL: "/api/collections/" + core.CollectionNameSuperusers + "/records/sbmbsdb40jyxf7h",
|
||||
Headers: map[string]string{
|
||||
// _superusers, test@example.com
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
|
||||
},
|
||||
ExpectedStatus: 204,
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordDeleteRequest": 1,
|
||||
"OnModelDelete": 4, // + 3 AuthOrigins
|
||||
"OnModelDeleteExecute": 4,
|
||||
"OnModelAfterDeleteSuccess": 4,
|
||||
"OnRecordDelete": 4,
|
||||
"OnRecordDeleteExecute": 4,
|
||||
"OnRecordAfterDeleteSuccess": 4,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "delete the last superuser",
|
||||
Method: http.MethodDelete,
|
||||
URL: "/api/collections/" + core.CollectionNameSuperusers + "/records/sywbhecnh46rhm0",
|
||||
Headers: map[string]string{
|
||||
// _superusers, test@example.com
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
// delete all other superusers
|
||||
superusers, err := app.FindAllRecords(core.CollectionNameSuperusers, dbx.Not(dbx.HashExp{"id": "sywbhecnh46rhm0"}))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
for _, superuser := range superusers {
|
||||
if err = app.Delete(superuser); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordDeleteRequest": 1,
|
||||
"OnModelDelete": 1,
|
||||
"OnModelAfterDeleteError": 1,
|
||||
"OnRecordDelete": 1,
|
||||
"OnRecordAfterDeleteError": 1,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecordCrudSuperuserCreate(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
body := func() *strings.Reader {
|
||||
return strings.NewReader(`{
|
||||
"email": "test_new@example.com",
|
||||
"password": "1234567890",
|
||||
"passwordConfirm": "1234567890",
|
||||
"verified": false
|
||||
}`)
|
||||
}
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "guest",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/" + core.CollectionNameSuperusers + "/records",
|
||||
Body: body(),
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "non-superusers auth",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/" + core.CollectionNameSuperusers + "/records",
|
||||
Headers: map[string]string{
|
||||
// users, test@example.com
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
Body: body(),
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "superusers auth",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/collections/" + core.CollectionNameSuperusers + "/records",
|
||||
Headers: map[string]string{
|
||||
// _superusers, test@example.com
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
|
||||
},
|
||||
Body: body(),
|
||||
ExpectedContent: []string{
|
||||
`"collectionName":"_superusers"`,
|
||||
`"email":"test_new@example.com"`,
|
||||
`"verified":true`,
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordCreateRequest": 1,
|
||||
"OnRecordEnrich": 1,
|
||||
"OnModelCreate": 1,
|
||||
"OnModelCreateExecute": 1,
|
||||
"OnModelAfterCreateSuccess": 1,
|
||||
"OnModelValidate": 1,
|
||||
"OnRecordCreate": 1,
|
||||
"OnRecordCreateExecute": 1,
|
||||
"OnRecordAfterCreateSuccess": 1,
|
||||
"OnRecordValidate": 1,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecordCrudSuperuserUpdate(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
body := func() *strings.Reader {
|
||||
return strings.NewReader(`{
|
||||
"email": "test_new@example.com",
|
||||
"verified": true
|
||||
}`)
|
||||
}
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "guest",
|
||||
Method: http.MethodPatch,
|
||||
URL: "/api/collections/" + core.CollectionNameSuperusers + "/records/sywbhecnh46rhm0",
|
||||
Body: body(),
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "non-superusers auth",
|
||||
Method: http.MethodPatch,
|
||||
URL: "/api/collections/" + core.CollectionNameSuperusers + "/records/sywbhecnh46rhm0",
|
||||
Headers: map[string]string{
|
||||
// users, test@example.com
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
Body: body(),
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "superusers auth",
|
||||
Method: http.MethodPatch,
|
||||
URL: "/api/collections/" + core.CollectionNameSuperusers + "/records/sywbhecnh46rhm0",
|
||||
Headers: map[string]string{
|
||||
// _superusers, test@example.com
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
|
||||
},
|
||||
Body: body(),
|
||||
ExpectedContent: []string{
|
||||
`"collectionName":"_superusers"`,
|
||||
`"id":"sywbhecnh46rhm0"`,
|
||||
`"email":"test_new@example.com"`,
|
||||
`"verified":true`,
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnRecordUpdateRequest": 1,
|
||||
"OnRecordEnrich": 1,
|
||||
"OnModelUpdate": 1,
|
||||
"OnModelUpdateExecute": 1,
|
||||
"OnModelAfterUpdateSuccess": 1,
|
||||
"OnModelValidate": 1,
|
||||
"OnRecordUpdate": 1,
|
||||
"OnRecordUpdateExecute": 1,
|
||||
"OnRecordAfterUpdateSuccess": 1,
|
||||
"OnRecordValidate": 1,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
3584
apis/record_crud_test.go
Normal file
3584
apis/record_crud_test.go
Normal file
File diff suppressed because it is too large
Load diff
636
apis/record_helpers.go
Normal file
636
apis/record_helpers.go
Normal file
|
@ -0,0 +1,636 @@
|
|||
package apis
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/pocketbase/dbx"
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/mails"
|
||||
"github.com/pocketbase/pocketbase/tools/router"
|
||||
"github.com/pocketbase/pocketbase/tools/routine"
|
||||
"github.com/pocketbase/pocketbase/tools/search"
|
||||
"github.com/pocketbase/pocketbase/tools/security"
|
||||
)
|
||||
|
||||
const (
|
||||
expandQueryParam = "expand"
|
||||
fieldsQueryParam = "fields"
|
||||
)
|
||||
|
||||
var ErrMFA = errors.New("mfa required")
|
||||
|
||||
// RecordAuthResponse writes standardized json record auth response
|
||||
// into the specified request context.
|
||||
//
|
||||
// The authMethod argument specify the name of the current authentication method (eg. password, oauth2, etc.)
|
||||
// that it is used primarily as an auth identifier during MFA and for login alerts.
|
||||
//
|
||||
// Set authMethod to empty string if you want to ignore the MFA checks and the login alerts
|
||||
// (can be also adjusted additionally via the OnRecordAuthRequest hook).
|
||||
func RecordAuthResponse(e *core.RequestEvent, authRecord *core.Record, authMethod string, meta any) error {
|
||||
token, tokenErr := authRecord.NewAuthToken()
|
||||
if tokenErr != nil {
|
||||
return e.InternalServerError("Failed to create auth token.", tokenErr)
|
||||
}
|
||||
|
||||
return recordAuthResponse(e, authRecord, token, authMethod, meta)
|
||||
}
|
||||
|
||||
func recordAuthResponse(e *core.RequestEvent, authRecord *core.Record, token string, authMethod string, meta any) error {
|
||||
originalRequestInfo, err := e.RequestInfo()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ok, err := e.App.CanAccessRecord(authRecord, originalRequestInfo, authRecord.Collection().AuthRule)
|
||||
if !ok {
|
||||
return firstApiError(err, e.ForbiddenError("The request doesn't satisfy the collection requirements to authenticate.", err))
|
||||
}
|
||||
|
||||
event := new(core.RecordAuthRequestEvent)
|
||||
event.RequestEvent = e
|
||||
event.Collection = authRecord.Collection()
|
||||
event.Record = authRecord
|
||||
event.Token = token
|
||||
event.Meta = meta
|
||||
event.AuthMethod = authMethod
|
||||
|
||||
return e.App.OnRecordAuthRequest().Trigger(event, func(e *core.RecordAuthRequestEvent) error {
|
||||
if e.Written() {
|
||||
return nil
|
||||
}
|
||||
|
||||
// MFA
|
||||
// ---
|
||||
mfaId, err := checkMFA(e.RequestEvent, e.Record, e.AuthMethod)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// require additional authentication
|
||||
if mfaId != "" {
|
||||
// eagerly write the mfa response and return an err so that
|
||||
// external middlewars are aware that the auth response requires an extra step
|
||||
e.JSON(http.StatusUnauthorized, map[string]string{
|
||||
"mfaId": mfaId,
|
||||
})
|
||||
return ErrMFA
|
||||
}
|
||||
// ---
|
||||
|
||||
// create a shallow copy of the cached request data and adjust it to the current auth record
|
||||
requestInfo := *originalRequestInfo
|
||||
requestInfo.Auth = e.Record
|
||||
|
||||
err = triggerRecordEnrichHooks(e.App, &requestInfo, []*core.Record{e.Record}, func() error {
|
||||
if e.Record.IsSuperuser() {
|
||||
e.Record.Unhide(e.Record.Collection().Fields.FieldNames()...)
|
||||
}
|
||||
|
||||
// allow always returning the email address of the authenticated model
|
||||
e.Record.IgnoreEmailVisibility(true)
|
||||
|
||||
// expand record relations
|
||||
expands := strings.Split(e.Request.URL.Query().Get(expandQueryParam), ",")
|
||||
if len(expands) > 0 {
|
||||
failed := e.App.ExpandRecord(e.Record, expands, expandFetch(e.App, &requestInfo))
|
||||
if len(failed) > 0 {
|
||||
e.App.Logger().Warn("[recordAuthResponse] Failed to expand relations", "error", failed)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if e.AuthMethod != "" && authRecord.Collection().AuthAlert.Enabled {
|
||||
if err = authAlert(e.RequestEvent, e.Record); err != nil {
|
||||
e.App.Logger().Warn("[recordAuthResponse] Failed to send login alert", "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
result := struct {
|
||||
Meta any `json:"meta,omitempty"`
|
||||
Record *core.Record `json:"record"`
|
||||
Token string `json:"token"`
|
||||
}{
|
||||
Token: e.Token,
|
||||
Record: e.Record,
|
||||
}
|
||||
|
||||
if e.Meta != nil {
|
||||
result.Meta = e.Meta
|
||||
}
|
||||
|
||||
return execAfterSuccessTx(true, e.App, func() error {
|
||||
return e.JSON(http.StatusOK, result)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// wantsMFA checks whether to enable MFA for the specified auth record based on its MFA rule
|
||||
// (note: returns true even in case of an error as a safer default).
|
||||
func wantsMFA(e *core.RequestEvent, record *core.Record) (bool, error) {
|
||||
rule := record.Collection().MFA.Rule
|
||||
if rule == "" {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
requestInfo, err := e.RequestInfo()
|
||||
if err != nil {
|
||||
return true, err
|
||||
}
|
||||
|
||||
var exists int
|
||||
|
||||
query := e.App.RecordQuery(record.Collection()).
|
||||
Select("(1)").
|
||||
AndWhere(dbx.HashExp{record.Collection().Name + ".id": record.Id})
|
||||
|
||||
// parse and apply the access rule filter
|
||||
resolver := core.NewRecordFieldResolver(e.App, record.Collection(), requestInfo, true)
|
||||
expr, err := search.FilterData(rule).BuildExpr(resolver)
|
||||
if err != nil {
|
||||
return true, err
|
||||
}
|
||||
resolver.UpdateQuery(query)
|
||||
|
||||
err = query.AndWhere(expr).Limit(1).Row(&exists)
|
||||
if err != nil && !errors.Is(err, sql.ErrNoRows) {
|
||||
return true, err
|
||||
}
|
||||
|
||||
return exists > 0, nil
|
||||
}
|
||||
|
||||
// checkMFA handles any MFA auth checks that needs to be performed for the specified request event.
|
||||
// Returns the mfaId that needs to be written as response to the user.
|
||||
//
|
||||
// (note: all auth methods are treated as equal and there is no requirement for "pairing").
|
||||
func checkMFA(e *core.RequestEvent, authRecord *core.Record, currentAuthMethod string) (string, error) {
|
||||
if !authRecord.Collection().MFA.Enabled || currentAuthMethod == "" {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
ok, err := wantsMFA(e, authRecord)
|
||||
if err != nil {
|
||||
return "", e.BadRequestError("Failed to authenticate.", fmt.Errorf("MFA rule failure: %w", err))
|
||||
}
|
||||
if !ok {
|
||||
return "", nil // no mfa needed for this auth record
|
||||
}
|
||||
|
||||
// read the mfaId either from the qyery params or request body
|
||||
mfaId := e.Request.URL.Query().Get("mfaId")
|
||||
if mfaId == "" {
|
||||
// check the body
|
||||
data := struct {
|
||||
MfaId string `form:"mfaId" json:"mfaId" xml:"mfaId"`
|
||||
}{}
|
||||
if err := e.BindBody(&data); err != nil {
|
||||
return "", firstApiError(err, e.BadRequestError("Failed to read MFA Id", err))
|
||||
}
|
||||
mfaId = data.MfaId
|
||||
}
|
||||
|
||||
// first-time auth
|
||||
// ---
|
||||
if mfaId == "" {
|
||||
mfa := core.NewMFA(e.App)
|
||||
mfa.SetCollectionRef(authRecord.Collection().Id)
|
||||
mfa.SetRecordRef(authRecord.Id)
|
||||
mfa.SetMethod(currentAuthMethod)
|
||||
if err := e.App.Save(mfa); err != nil {
|
||||
return "", firstApiError(err, e.InternalServerError("Failed to create MFA record", err))
|
||||
}
|
||||
|
||||
return mfa.Id, nil
|
||||
}
|
||||
|
||||
// second-time auth
|
||||
// ---
|
||||
mfa, err := e.App.FindMFAById(mfaId)
|
||||
deleteMFA := func() {
|
||||
// try to delete the expired mfa
|
||||
if mfa != nil {
|
||||
if deleteErr := e.App.Delete(mfa); deleteErr != nil {
|
||||
e.App.Logger().Warn("Failed to delete expired MFA record", "error", deleteErr, "mfaId", mfa.Id)
|
||||
}
|
||||
}
|
||||
}
|
||||
if err != nil || mfa.HasExpired(authRecord.Collection().MFA.DurationTime()) {
|
||||
deleteMFA()
|
||||
return "", e.BadRequestError("Invalid or expired MFA session.", err)
|
||||
}
|
||||
|
||||
if mfa.RecordRef() != authRecord.Id || mfa.CollectionRef() != authRecord.Collection().Id {
|
||||
return "", e.BadRequestError("Invalid MFA session.", nil)
|
||||
}
|
||||
|
||||
if mfa.Method() == currentAuthMethod {
|
||||
return "", e.BadRequestError("A different authentication method is required.", nil)
|
||||
}
|
||||
|
||||
deleteMFA()
|
||||
|
||||
return "", nil
|
||||
}
|
||||
|
||||
// EnrichRecord parses the request context and enrich the provided record:
|
||||
// - expands relations (if defaultExpands and/or ?expand query param is set)
|
||||
// - ensures that the emails of the auth record and its expanded auth relations
|
||||
// are visible only for the current logged superuser, record owner or record with manage access
|
||||
func EnrichRecord(e *core.RequestEvent, record *core.Record, defaultExpands ...string) error {
|
||||
return EnrichRecords(e, []*core.Record{record}, defaultExpands...)
|
||||
}
|
||||
|
||||
// EnrichRecords parses the request context and enriches the provided records:
|
||||
// - expands relations (if defaultExpands and/or ?expand query param is set)
|
||||
// - ensures that the emails of the auth records and their expanded auth relations
|
||||
// are visible only for the current logged superuser, record owner or record with manage access
|
||||
//
|
||||
// Note: Expects all records to be from the same collection!
|
||||
func EnrichRecords(e *core.RequestEvent, records []*core.Record, defaultExpands ...string) error {
|
||||
if len(records) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
info, err := e.RequestInfo()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return triggerRecordEnrichHooks(e.App, info, records, func() error {
|
||||
expands := defaultExpands
|
||||
if param := info.Query[expandQueryParam]; param != "" {
|
||||
expands = append(expands, strings.Split(param, ",")...)
|
||||
}
|
||||
|
||||
err := defaultEnrichRecords(e.App, info, records, expands...)
|
||||
if err != nil {
|
||||
// only log because it is not critical
|
||||
e.App.Logger().Warn("failed to apply default enriching", "error", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
type iterator[T any] struct {
|
||||
items []T
|
||||
index int
|
||||
}
|
||||
|
||||
func (ri *iterator[T]) next() T {
|
||||
var item T
|
||||
|
||||
if ri.index < len(ri.items) {
|
||||
item = ri.items[ri.index]
|
||||
ri.index++
|
||||
}
|
||||
|
||||
return item
|
||||
}
|
||||
|
||||
func triggerRecordEnrichHooks(app core.App, requestInfo *core.RequestInfo, records []*core.Record, finalizer func() error) error {
|
||||
it := iterator[*core.Record]{items: records}
|
||||
|
||||
enrichHook := app.OnRecordEnrich()
|
||||
|
||||
event := new(core.RecordEnrichEvent)
|
||||
event.App = app
|
||||
event.RequestInfo = requestInfo
|
||||
|
||||
var iterate func(record *core.Record) error
|
||||
iterate = func(record *core.Record) error {
|
||||
if record == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
event.Record = record
|
||||
|
||||
return enrichHook.Trigger(event, func(ee *core.RecordEnrichEvent) error {
|
||||
next := it.next()
|
||||
if next == nil {
|
||||
if finalizer != nil {
|
||||
return finalizer()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
event.App = ee.App // in case it was replaced with a transaction
|
||||
event.Record = next
|
||||
|
||||
err := iterate(next)
|
||||
|
||||
event.App = app
|
||||
event.Record = record
|
||||
|
||||
return err
|
||||
})
|
||||
}
|
||||
|
||||
return iterate(it.next())
|
||||
}
|
||||
|
||||
func defaultEnrichRecords(app core.App, requestInfo *core.RequestInfo, records []*core.Record, expands ...string) error {
|
||||
err := autoResolveRecordsFlags(app, records, requestInfo)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to resolve records flags: %w", err)
|
||||
}
|
||||
|
||||
if len(expands) > 0 {
|
||||
expandErrs := app.ExpandRecords(records, expands, expandFetch(app, requestInfo))
|
||||
if len(expandErrs) > 0 {
|
||||
errsSlice := make([]error, 0, len(expandErrs))
|
||||
for key, err := range expandErrs {
|
||||
errsSlice = append(errsSlice, fmt.Errorf("failed to expand %q: %w", key, err))
|
||||
}
|
||||
return fmt.Errorf("failed to expand records: %w", errors.Join(errsSlice...))
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// expandFetch is the records fetch function that is used to expand related records.
|
||||
func expandFetch(app core.App, originalRequestInfo *core.RequestInfo) core.ExpandFetchFunc {
|
||||
// shallow clone the provided request info to set an "expand" context
|
||||
requestInfoClone := *originalRequestInfo
|
||||
requestInfoPtr := &requestInfoClone
|
||||
requestInfoPtr.Context = core.RequestInfoContextExpand
|
||||
|
||||
return func(relCollection *core.Collection, relIds []string) ([]*core.Record, error) {
|
||||
records, findErr := app.FindRecordsByIds(relCollection.Id, relIds, func(q *dbx.SelectQuery) error {
|
||||
if requestInfoPtr.Auth != nil && requestInfoPtr.Auth.IsSuperuser() {
|
||||
return nil // superusers can access everything
|
||||
}
|
||||
|
||||
if relCollection.ViewRule == nil {
|
||||
return fmt.Errorf("only superusers can view collection %q records", relCollection.Name)
|
||||
}
|
||||
|
||||
if *relCollection.ViewRule != "" {
|
||||
resolver := core.NewRecordFieldResolver(app, relCollection, requestInfoPtr, true)
|
||||
expr, err := search.FilterData(*(relCollection.ViewRule)).BuildExpr(resolver)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
resolver.UpdateQuery(q)
|
||||
q.AndWhere(expr)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
if findErr != nil {
|
||||
return nil, findErr
|
||||
}
|
||||
|
||||
enrichErr := triggerRecordEnrichHooks(app, requestInfoPtr, records, func() error {
|
||||
if err := autoResolveRecordsFlags(app, records, requestInfoPtr); err != nil {
|
||||
// non-critical error
|
||||
app.Logger().Warn("Failed to apply autoResolveRecordsFlags for the expanded records", "error", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
if enrichErr != nil {
|
||||
return nil, enrichErr
|
||||
}
|
||||
|
||||
return records, nil
|
||||
}
|
||||
}
|
||||
|
||||
// autoResolveRecordsFlags resolves various visibility flags of the provided records.
|
||||
//
|
||||
// Currently it enables:
|
||||
// - export of hidden fields if the current auth model is a superuser
|
||||
// - email export ignoring the emailVisibity checks if the current auth model is superuser, owner or a "manager".
|
||||
//
|
||||
// Note: Expects all records to be from the same collection!
|
||||
func autoResolveRecordsFlags(app core.App, records []*core.Record, requestInfo *core.RequestInfo) error {
|
||||
if len(records) == 0 {
|
||||
return nil // nothing to resolve
|
||||
}
|
||||
|
||||
if requestInfo.HasSuperuserAuth() {
|
||||
hiddenFields := records[0].Collection().Fields.FieldNames()
|
||||
for _, rec := range records {
|
||||
rec.Unhide(hiddenFields...)
|
||||
rec.IgnoreEmailVisibility(true)
|
||||
}
|
||||
}
|
||||
|
||||
// additional emailVisibility checks
|
||||
// ---------------------------------------------------------------
|
||||
if !records[0].Collection().IsAuth() {
|
||||
return nil // not auth collection records
|
||||
}
|
||||
|
||||
collection := records[0].Collection()
|
||||
|
||||
mappedRecords := make(map[string]*core.Record, len(records))
|
||||
recordIds := make([]any, len(records))
|
||||
for i, rec := range records {
|
||||
mappedRecords[rec.Id] = rec
|
||||
recordIds[i] = rec.Id
|
||||
}
|
||||
|
||||
if requestInfo.Auth != nil && mappedRecords[requestInfo.Auth.Id] != nil {
|
||||
mappedRecords[requestInfo.Auth.Id].IgnoreEmailVisibility(true)
|
||||
}
|
||||
|
||||
if collection.ManageRule == nil || *collection.ManageRule == "" {
|
||||
return nil // no manage rule to check
|
||||
}
|
||||
|
||||
// fetch the ids of the managed records
|
||||
// ---
|
||||
managedIds := []string{}
|
||||
|
||||
query := app.RecordQuery(collection).
|
||||
Select(app.ConcurrentDB().QuoteSimpleColumnName(collection.Name) + ".id").
|
||||
AndWhere(dbx.In(app.ConcurrentDB().QuoteSimpleColumnName(collection.Name)+".id", recordIds...))
|
||||
|
||||
resolver := core.NewRecordFieldResolver(app, collection, requestInfo, true)
|
||||
expr, err := search.FilterData(*collection.ManageRule).BuildExpr(resolver)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
resolver.UpdateQuery(query)
|
||||
query.AndWhere(expr)
|
||||
|
||||
if err := query.Column(&managedIds); err != nil {
|
||||
return err
|
||||
}
|
||||
// ---
|
||||
|
||||
// ignore the email visibility check for the managed records
|
||||
for _, id := range managedIds {
|
||||
if rec, ok := mappedRecords[id]; ok {
|
||||
rec.IgnoreEmailVisibility(true)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
var ruleQueryParams = []string{search.FilterQueryParam, search.SortQueryParam}
|
||||
var superuserOnlyRuleFields = []string{"@collection.", "@request."}
|
||||
|
||||
// checkForSuperuserOnlyRuleFields loosely checks and returns an error if
|
||||
// the provided RequestInfo contains rule fields that only the superuser can use.
|
||||
func checkForSuperuserOnlyRuleFields(requestInfo *core.RequestInfo) error {
|
||||
if len(requestInfo.Query) == 0 || requestInfo.HasSuperuserAuth() {
|
||||
return nil // superuser or nothing to check
|
||||
}
|
||||
|
||||
for _, param := range ruleQueryParams {
|
||||
v := requestInfo.Query[param]
|
||||
if v == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, field := range superuserOnlyRuleFields {
|
||||
if strings.Contains(v, field) {
|
||||
return router.NewForbiddenError("Only superusers can filter by "+field, nil)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// firstApiError returns the first ApiError from the errors list
|
||||
// (this is used usually to prevent unnecessary wraping and to allow bubling ApiError from nested hooks)
|
||||
//
|
||||
// If no ApiError is found, returns a default "Internal server" error.
|
||||
func firstApiError(errs ...error) *router.ApiError {
|
||||
var apiErr *router.ApiError
|
||||
var ok bool
|
||||
|
||||
for _, err := range errs {
|
||||
if err == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// quick assert to avoid the reflection checks
|
||||
apiErr, ok = err.(*router.ApiError)
|
||||
if ok {
|
||||
return apiErr
|
||||
}
|
||||
|
||||
// nested/wrapped errors
|
||||
if errors.As(err, &apiErr) {
|
||||
return apiErr
|
||||
}
|
||||
}
|
||||
|
||||
return router.NewInternalServerError("", errors.Join(errs...))
|
||||
}
|
||||
|
||||
// execAfterSuccessTx ensures that fn is executed only after a succesul transaction.
|
||||
//
|
||||
// If the current app instance is not a transactional or checkTx is false,
|
||||
// then fn is directly executed.
|
||||
//
|
||||
// It could be usually used to allow propagating an error or writing
|
||||
// custom response from within the wrapped transaction block.
|
||||
func execAfterSuccessTx(checkTx bool, app core.App, fn func() error) error {
|
||||
if txInfo := app.TxInfo(); txInfo != nil && checkTx {
|
||||
txInfo.OnComplete(func(txErr error) error {
|
||||
if txErr == nil {
|
||||
return fn()
|
||||
}
|
||||
return nil
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
return fn()
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
|
||||
const maxAuthOrigins = 5
|
||||
|
||||
func authAlert(e *core.RequestEvent, authRecord *core.Record) error {
|
||||
// generating fingerprint
|
||||
// ---
|
||||
userAgent := e.Request.UserAgent()
|
||||
if len(userAgent) > 300 {
|
||||
userAgent = userAgent[:300]
|
||||
}
|
||||
fingerprint := security.MD5(e.RealIP() + userAgent)
|
||||
// ---
|
||||
|
||||
origins, err := e.App.FindAllAuthOriginsByRecord(authRecord)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
isFirstLogin := len(origins) == 0
|
||||
|
||||
var currentOrigin *core.AuthOrigin
|
||||
for _, origin := range origins {
|
||||
if origin.Fingerprint() == fingerprint {
|
||||
currentOrigin = origin
|
||||
break
|
||||
}
|
||||
}
|
||||
if currentOrigin == nil {
|
||||
currentOrigin = core.NewAuthOrigin(e.App)
|
||||
currentOrigin.SetCollectionRef(authRecord.Collection().Id)
|
||||
currentOrigin.SetRecordRef(authRecord.Id)
|
||||
currentOrigin.SetFingerprint(fingerprint)
|
||||
}
|
||||
|
||||
// send email alert for the new origin auth (skip first login)
|
||||
//
|
||||
// Note: The "fake" timeout is a temp solution to avoid blocking
|
||||
// for too long when the SMTP server is not accessible, due
|
||||
// to the lack of context cancellation support in the underlying
|
||||
// mailer and net/smtp package.
|
||||
// The goroutine technically "leaks" but we assume that the OS will
|
||||
// terminate the connection after some time (usually after 3-4 mins).
|
||||
if !isFirstLogin && currentOrigin.IsNew() && authRecord.Email() != "" {
|
||||
mailSent := make(chan error, 1)
|
||||
|
||||
timer := time.AfterFunc(15*time.Second, func() {
|
||||
mailSent <- errors.New("auth alert mail send wait timeout reached")
|
||||
})
|
||||
|
||||
routine.FireAndForget(func() {
|
||||
err := mails.SendRecordAuthAlert(e.App, authRecord)
|
||||
timer.Stop()
|
||||
mailSent <- err
|
||||
})
|
||||
|
||||
err = <-mailSent
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// try to keep only up to maxAuthOrigins
|
||||
// (pop the last used ones; it is not executed in a transaction to avoid unnecessary locks)
|
||||
if currentOrigin.IsNew() && len(origins) >= maxAuthOrigins {
|
||||
for i := len(origins) - 1; i >= maxAuthOrigins-1; i-- {
|
||||
if err := e.App.Delete(origins[i]); err != nil {
|
||||
// treat as non-critical error, just log for now
|
||||
e.App.Logger().Warn("Failed to delete old AuthOrigin record", "error", err, "authOriginId", origins[i].Id)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// create/update the origin fingerprint
|
||||
return e.App.Save(currentOrigin)
|
||||
}
|
761
apis/record_helpers_test.go
Normal file
761
apis/record_helpers_test.go
Normal file
|
@ -0,0 +1,761 @@
|
|||
package apis_test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/pocketbase/pocketbase/apis"
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tests"
|
||||
"github.com/pocketbase/pocketbase/tools/router"
|
||||
"github.com/pocketbase/pocketbase/tools/types"
|
||||
)
|
||||
|
||||
func TestEnrichRecords(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// mock test data
|
||||
// ---
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
freshRecords := func(records []*core.Record) []*core.Record {
|
||||
result := make([]*core.Record, len(records))
|
||||
for i, r := range records {
|
||||
result[i] = r.Fresh()
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
user, err := app.FindAuthRecordByEmail("users", "test@example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
superuser, err := app.FindAuthRecordByEmail(core.CollectionNameSuperusers, "test@example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
usersRecords, err := app.FindRecordsByIds("users", []string{"4q1xlclmfloku33", "bgs820n361vj1qd"})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
nologinRecords, err := app.FindRecordsByIds("nologin", []string{"dc49k6jgejn40h3", "oos036e9xvqeexy"})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
demo1Records, err := app.FindRecordsByIds("demo1", []string{"al1h9ijdeojtsjy", "84nmscqy84lsi1t"})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
demo5Records, err := app.FindRecordsByIds("demo5", []string{"la4y2w4o98acwuj", "qjeql998mtp1azp"})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// temp update the view rule to ensure that request context is set to "expand"
|
||||
demo4, err := app.FindCollectionByNameOrId("demo4")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
demo4.ViewRule = types.Pointer("@request.context = 'expand'")
|
||||
if err := app.Save(demo4); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
// ---
|
||||
|
||||
scenarios := []struct {
|
||||
name string
|
||||
auth *core.Record
|
||||
records []*core.Record
|
||||
queryExpand string
|
||||
defaultExpands []string
|
||||
expected []string
|
||||
notExpected []string
|
||||
}{
|
||||
// email visibility checks
|
||||
{
|
||||
name: "[emailVisibility] guest",
|
||||
auth: nil,
|
||||
records: freshRecords(usersRecords),
|
||||
queryExpand: "",
|
||||
defaultExpands: nil,
|
||||
expected: []string{
|
||||
`"customField":"123"`,
|
||||
`"test3@example.com"`, // emailVisibility=true
|
||||
},
|
||||
notExpected: []string{
|
||||
`"test@example.com"`,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "[emailVisibility] owner",
|
||||
auth: user,
|
||||
records: freshRecords(usersRecords),
|
||||
queryExpand: "",
|
||||
defaultExpands: nil,
|
||||
expected: []string{
|
||||
`"customField":"123"`,
|
||||
`"test3@example.com"`, // emailVisibility=true
|
||||
`"test@example.com"`, // owner
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "[emailVisibility] manager",
|
||||
auth: user,
|
||||
records: freshRecords(nologinRecords),
|
||||
queryExpand: "",
|
||||
defaultExpands: nil,
|
||||
expected: []string{
|
||||
`"customField":"123"`,
|
||||
`"test3@example.com"`,
|
||||
`"test@example.com"`,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "[emailVisibility] superuser",
|
||||
auth: superuser,
|
||||
records: freshRecords(nologinRecords),
|
||||
queryExpand: "",
|
||||
defaultExpands: nil,
|
||||
expected: []string{
|
||||
`"customField":"123"`,
|
||||
`"test3@example.com"`,
|
||||
`"test@example.com"`,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "[emailVisibility + expand] recursive auth rule checks (regular user)",
|
||||
auth: user,
|
||||
records: freshRecords(demo1Records),
|
||||
queryExpand: "",
|
||||
defaultExpands: []string{"rel_many"},
|
||||
expected: []string{
|
||||
`"customField":"123"`,
|
||||
`"expand":{"rel_many"`,
|
||||
`"expand":{}`,
|
||||
`"test@example.com"`,
|
||||
},
|
||||
notExpected: []string{
|
||||
`"id":"bgs820n361vj1qd"`,
|
||||
`"id":"oap640cot4yru2s"`,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "[emailVisibility + expand] recursive auth rule checks (superuser)",
|
||||
auth: superuser,
|
||||
records: freshRecords(demo1Records),
|
||||
queryExpand: "",
|
||||
defaultExpands: []string{"rel_many"},
|
||||
expected: []string{
|
||||
`"customField":"123"`,
|
||||
`"test@example.com"`,
|
||||
`"expand":{"rel_many"`,
|
||||
`"id":"bgs820n361vj1qd"`,
|
||||
`"id":"4q1xlclmfloku33"`,
|
||||
`"id":"oap640cot4yru2s"`,
|
||||
},
|
||||
notExpected: []string{
|
||||
`"expand":{}`,
|
||||
},
|
||||
},
|
||||
|
||||
// expand checks
|
||||
{
|
||||
name: "[expand] guest (query)",
|
||||
auth: nil,
|
||||
records: freshRecords(usersRecords),
|
||||
queryExpand: "rel",
|
||||
defaultExpands: nil,
|
||||
expected: []string{
|
||||
`"customField":"123"`,
|
||||
`"expand":{"rel"`,
|
||||
`"id":"llvuca81nly1qls"`,
|
||||
`"id":"0yxhwia2amd8gec"`,
|
||||
},
|
||||
notExpected: []string{
|
||||
`"expand":{}`,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "[expand] guest (default expands)",
|
||||
auth: nil,
|
||||
records: freshRecords(usersRecords),
|
||||
queryExpand: "",
|
||||
defaultExpands: []string{"rel"},
|
||||
expected: []string{
|
||||
`"customField":"123"`,
|
||||
`"expand":{"rel"`,
|
||||
`"id":"llvuca81nly1qls"`,
|
||||
`"id":"0yxhwia2amd8gec"`,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "[expand] @request.context=expand check",
|
||||
auth: nil,
|
||||
records: freshRecords(demo5Records),
|
||||
queryExpand: "rel_one",
|
||||
defaultExpands: []string{"rel_many"},
|
||||
expected: []string{
|
||||
`"customField":"123"`,
|
||||
`"expand":{}`,
|
||||
`"expand":{"`,
|
||||
`"rel_many":[{`,
|
||||
`"rel_one":{`,
|
||||
`"id":"i9naidtvr6qsgb4"`,
|
||||
`"id":"qzaqccwrmva4o1n"`,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
t.Run(s.name, func(t *testing.T) {
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
app.OnRecordEnrich().BindFunc(func(e *core.RecordEnrichEvent) error {
|
||||
e.Record.WithCustomData(true)
|
||||
e.Record.Set("customField", "123")
|
||||
return e.Next()
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/?expand="+s.queryExpand, nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
requestEvent := new(core.RequestEvent)
|
||||
requestEvent.App = app
|
||||
requestEvent.Request = req
|
||||
requestEvent.Response = rec
|
||||
requestEvent.Auth = s.auth
|
||||
|
||||
err := apis.EnrichRecords(requestEvent, s.records, s.defaultExpands...)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
raw, err := json.Marshal(s.records)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
rawStr := string(raw)
|
||||
|
||||
for _, str := range s.expected {
|
||||
if !strings.Contains(rawStr, str) {
|
||||
t.Fatalf("Expected\n%q\nin\n%v", str, rawStr)
|
||||
}
|
||||
}
|
||||
|
||||
for _, str := range s.notExpected {
|
||||
if strings.Contains(rawStr, str) {
|
||||
t.Fatalf("Didn't expected\n%q\nin\n%v", str, rawStr)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecordAuthResponseAuthRuleCheck(t *testing.T) {
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
event := new(core.RequestEvent)
|
||||
event.App = app
|
||||
event.Request = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
event.Response = httptest.NewRecorder()
|
||||
|
||||
user, err := app.FindAuthRecordByEmail("users", "test@example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
scenarios := []struct {
|
||||
name string
|
||||
rule *string
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
"admin only rule",
|
||||
nil,
|
||||
true,
|
||||
},
|
||||
{
|
||||
"empty rule",
|
||||
types.Pointer(""),
|
||||
false,
|
||||
},
|
||||
{
|
||||
"false rule",
|
||||
types.Pointer("1=2"),
|
||||
true,
|
||||
},
|
||||
{
|
||||
"true rule",
|
||||
types.Pointer("1=1"),
|
||||
false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
t.Run(s.name, func(t *testing.T) {
|
||||
user.Collection().AuthRule = s.rule
|
||||
|
||||
err := apis.RecordAuthResponse(event, user, "", nil)
|
||||
|
||||
hasErr := err != nil
|
||||
if s.expectError != hasErr {
|
||||
t.Fatalf("Expected hasErr %v, got %v (%v)", s.expectError, hasErr, err)
|
||||
}
|
||||
|
||||
// in all cases login alert shouldn't be send because of the empty auth method
|
||||
if app.TestMailer.TotalSend() != 0 {
|
||||
t.Fatalf("Expected no emails send, got %d:\n%v", app.TestMailer.TotalSend(), app.TestMailer.LastMessage().HTML)
|
||||
}
|
||||
|
||||
if !hasErr {
|
||||
return
|
||||
}
|
||||
|
||||
apiErr, ok := err.(*router.ApiError)
|
||||
|
||||
if !ok || apiErr == nil {
|
||||
t.Fatalf("Expected ApiError, got %v", apiErr)
|
||||
}
|
||||
|
||||
if apiErr.Status != http.StatusForbidden {
|
||||
t.Fatalf("Expected ApiError.Status %d, got %d", http.StatusForbidden, apiErr.Status)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecordAuthResponseAuthAlertCheck(t *testing.T) {
|
||||
const testFingerprint = "d0f88d6c87767262ba8e93d6acccd784"
|
||||
|
||||
scenarios := []struct {
|
||||
name string
|
||||
devices []string // mock existing device fingerprints
|
||||
expectDevices []string
|
||||
enabled bool
|
||||
expectEmail bool
|
||||
}{
|
||||
{
|
||||
name: "first login",
|
||||
devices: nil,
|
||||
expectDevices: []string{testFingerprint},
|
||||
enabled: true,
|
||||
expectEmail: false,
|
||||
},
|
||||
{
|
||||
name: "existing device",
|
||||
devices: []string{"1", testFingerprint},
|
||||
expectDevices: []string{"1", testFingerprint},
|
||||
enabled: true,
|
||||
expectEmail: false,
|
||||
},
|
||||
{
|
||||
name: "new device (< 5)",
|
||||
devices: []string{"1", "2"},
|
||||
expectDevices: []string{"1", "2", testFingerprint},
|
||||
enabled: true,
|
||||
expectEmail: true,
|
||||
},
|
||||
{
|
||||
name: "new device (>= 5)",
|
||||
devices: []string{"1", "2", "3", "4", "5"},
|
||||
expectDevices: []string{"2", "3", "4", "5", testFingerprint},
|
||||
enabled: true,
|
||||
expectEmail: true,
|
||||
},
|
||||
{
|
||||
name: "with disabled auth alert collection flag",
|
||||
devices: []string{"1", "2"},
|
||||
expectDevices: []string{"1", "2"},
|
||||
enabled: false,
|
||||
expectEmail: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
t.Run(s.name, func(t *testing.T) {
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
event := new(core.RequestEvent)
|
||||
event.App = app
|
||||
event.Request = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
event.Response = httptest.NewRecorder()
|
||||
|
||||
user, err := app.FindAuthRecordByEmail("users", "test@example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
user.Collection().MFA.Enabled = false
|
||||
user.Collection().AuthRule = types.Pointer("")
|
||||
user.Collection().AuthAlert.Enabled = s.enabled
|
||||
|
||||
// ensure that there are no other auth origins
|
||||
err = app.DeleteAllAuthOriginsByRecord(user)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
mockCreated := types.NowDateTime().Add(-time.Duration(len(s.devices)+1) * time.Second)
|
||||
// insert the mock devices
|
||||
for _, fingerprint := range s.devices {
|
||||
mockCreated = mockCreated.Add(1 * time.Second)
|
||||
d := core.NewAuthOrigin(app)
|
||||
d.SetCollectionRef(user.Collection().Id)
|
||||
d.SetRecordRef(user.Id)
|
||||
d.SetFingerprint(fingerprint)
|
||||
d.SetRaw("created", mockCreated)
|
||||
d.SetRaw("updated", mockCreated)
|
||||
if err = app.Save(d); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
err = apis.RecordAuthResponse(event, user, "example", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to resolve auth response: %v", err)
|
||||
}
|
||||
|
||||
var expectTotalSend int
|
||||
if s.expectEmail {
|
||||
expectTotalSend = 1
|
||||
}
|
||||
if total := app.TestMailer.TotalSend(); total != expectTotalSend {
|
||||
t.Fatalf("Expected %d sent emails, got %d", expectTotalSend, total)
|
||||
}
|
||||
|
||||
devices, err := app.FindAllAuthOriginsByRecord(user)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to retrieve auth origins: %v", err)
|
||||
}
|
||||
|
||||
if len(devices) != len(s.expectDevices) {
|
||||
t.Fatalf("Expected %d devices, got %d", len(s.expectDevices), len(devices))
|
||||
}
|
||||
|
||||
for _, fingerprint := range s.expectDevices {
|
||||
var exists bool
|
||||
fingerprints := make([]string, 0, len(devices))
|
||||
for _, d := range devices {
|
||||
if d.Fingerprint() == fingerprint {
|
||||
exists = true
|
||||
break
|
||||
}
|
||||
fingerprints = append(fingerprints, d.Fingerprint())
|
||||
}
|
||||
if !exists {
|
||||
t.Fatalf("Missing device with fingerprint %q:\n%v", fingerprint, fingerprints)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecordAuthResponseMFACheck(t *testing.T) {
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
user, err := app.FindAuthRecordByEmail("users", "test@example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
user2, err := app.FindAuthRecordByEmail("users", "test2@example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
event := new(core.RequestEvent)
|
||||
event.App = app
|
||||
event.Request = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
event.Response = rec
|
||||
|
||||
resetMFAs := func(authRecord *core.Record) {
|
||||
// ensure that mfa is enabled
|
||||
user.Collection().MFA.Enabled = true
|
||||
user.Collection().MFA.Duration = 5
|
||||
user.Collection().MFA.Rule = ""
|
||||
|
||||
mfas, err := app.FindAllMFAsByRecord(authRecord)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to retrieve mfas: %v", err)
|
||||
}
|
||||
for _, mfa := range mfas {
|
||||
if err := app.Delete(mfa); err != nil {
|
||||
t.Fatalf("Failed to delete mfa %q: %v", mfa.Id, err)
|
||||
}
|
||||
}
|
||||
|
||||
// reset response
|
||||
rec = httptest.NewRecorder()
|
||||
event.Response = rec
|
||||
}
|
||||
|
||||
totalMFAs := func(authRecord *core.Record) int {
|
||||
mfas, err := app.FindAllMFAsByRecord(authRecord)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to retrieve mfas: %v", err)
|
||||
}
|
||||
return len(mfas)
|
||||
}
|
||||
|
||||
t.Run("no collection MFA enabled", func(t *testing.T) {
|
||||
resetMFAs(user)
|
||||
|
||||
user.Collection().MFA.Enabled = false
|
||||
|
||||
err = apis.RecordAuthResponse(event, user, "example", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Expected nil, got error: %v", err)
|
||||
}
|
||||
|
||||
body := rec.Body.String()
|
||||
if strings.Contains(body, "mfaId") {
|
||||
t.Fatalf("Expected no mfaId in the response body, got\n%v", body)
|
||||
}
|
||||
if !strings.Contains(body, "token") {
|
||||
t.Fatalf("Expected auth token in the response body, got\n%v", body)
|
||||
}
|
||||
|
||||
if total := totalMFAs(user); total != 0 {
|
||||
t.Fatalf("Expected no mfa records to be created, got %d", total)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("no explicit auth method", func(t *testing.T) {
|
||||
resetMFAs(user)
|
||||
|
||||
err = apis.RecordAuthResponse(event, user, "", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Expected nil, got error: %v", err)
|
||||
}
|
||||
|
||||
body := rec.Body.String()
|
||||
if strings.Contains(body, "mfaId") {
|
||||
t.Fatalf("Expected no mfaId in the response body, got\n%v", body)
|
||||
}
|
||||
if !strings.Contains(body, "token") {
|
||||
t.Fatalf("Expected auth token in the response body, got\n%v", body)
|
||||
}
|
||||
|
||||
if total := totalMFAs(user); total != 0 {
|
||||
t.Fatalf("Expected no mfa records to be created, got %d", total)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("no mfa wanted (mfa rule check failure)", func(t *testing.T) {
|
||||
resetMFAs(user)
|
||||
user.Collection().MFA.Rule = "1=2"
|
||||
|
||||
err = apis.RecordAuthResponse(event, user, "example", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Expected nil, got error: %v", err)
|
||||
}
|
||||
|
||||
body := rec.Body.String()
|
||||
if strings.Contains(body, "mfaId") {
|
||||
t.Fatalf("Expected no mfaId in the response body, got\n%v", body)
|
||||
}
|
||||
if !strings.Contains(body, "token") {
|
||||
t.Fatalf("Expected auth token in the response body, got\n%v", body)
|
||||
}
|
||||
|
||||
if total := totalMFAs(user); total != 0 {
|
||||
t.Fatalf("Expected no mfa records to be created, got %d", total)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("mfa wanted (mfa rule check success)", func(t *testing.T) {
|
||||
resetMFAs(user)
|
||||
user.Collection().MFA.Rule = "1=1"
|
||||
|
||||
err = apis.RecordAuthResponse(event, user, "example", nil)
|
||||
if !errors.Is(err, apis.ErrMFA) {
|
||||
t.Fatalf("Expected ErrMFA, got: %v", err)
|
||||
}
|
||||
|
||||
body := rec.Body.String()
|
||||
if !strings.Contains(body, "mfaId") {
|
||||
t.Fatalf("Expected the created mfaId to be returned in the response body, got\n%v", body)
|
||||
}
|
||||
|
||||
if total := totalMFAs(user); total != 1 {
|
||||
t.Fatalf("Expected a single mfa record to be created, got %d", total)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("mfa first-time", func(t *testing.T) {
|
||||
resetMFAs(user)
|
||||
|
||||
err = apis.RecordAuthResponse(event, user, "example", nil)
|
||||
if !errors.Is(err, apis.ErrMFA) {
|
||||
t.Fatalf("Expected ErrMFA, got: %v", err)
|
||||
}
|
||||
|
||||
body := rec.Body.String()
|
||||
if !strings.Contains(body, "mfaId") {
|
||||
t.Fatalf("Expected the created mfaId to be returned in the response body, got\n%v", body)
|
||||
}
|
||||
|
||||
if total := totalMFAs(user); total != 1 {
|
||||
t.Fatalf("Expected a single mfa record to be created, got %d", total)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("mfa second-time with the same auth method", func(t *testing.T) {
|
||||
resetMFAs(user)
|
||||
|
||||
// create a dummy mfa record
|
||||
mfa := core.NewMFA(app)
|
||||
mfa.SetCollectionRef(user.Collection().Id)
|
||||
mfa.SetRecordRef(user.Id)
|
||||
mfa.SetMethod("example")
|
||||
if err = app.Save(mfa); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
event.Request = httptest.NewRequest(http.MethodGet, "/?mfaId="+mfa.Id, nil)
|
||||
|
||||
err = apis.RecordAuthResponse(event, user, "example", nil)
|
||||
if err == nil {
|
||||
t.Fatal("Expected error, got nil")
|
||||
}
|
||||
|
||||
if total := totalMFAs(user); total != 1 {
|
||||
t.Fatalf("Expected only 1 mfa record (the existing one), got %d", total)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("mfa second-time with the different auth method (query param)", func(t *testing.T) {
|
||||
resetMFAs(user)
|
||||
|
||||
// create a dummy mfa record
|
||||
mfa := core.NewMFA(app)
|
||||
mfa.SetCollectionRef(user.Collection().Id)
|
||||
mfa.SetRecordRef(user.Id)
|
||||
mfa.SetMethod("example1")
|
||||
if err = app.Save(mfa); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
event.Request = httptest.NewRequest(http.MethodGet, "/?mfaId="+mfa.Id, nil)
|
||||
|
||||
err = apis.RecordAuthResponse(event, user, "example2", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Expected nil, got error: %v", err)
|
||||
}
|
||||
|
||||
if total := totalMFAs(user); total != 0 {
|
||||
t.Fatalf("Expected the dummy mfa record to be deleted, found %d", total)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("mfa second-time with the different auth method (body param)", func(t *testing.T) {
|
||||
resetMFAs(user)
|
||||
|
||||
// create a dummy mfa record
|
||||
mfa := core.NewMFA(app)
|
||||
mfa.SetCollectionRef(user.Collection().Id)
|
||||
mfa.SetRecordRef(user.Id)
|
||||
mfa.SetMethod("example1")
|
||||
if err = app.Save(mfa); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
event.Request = httptest.NewRequest(http.MethodGet, "/", strings.NewReader(`{"mfaId":"`+mfa.Id+`"}`))
|
||||
event.Request.Header.Add("content-type", "application/json")
|
||||
|
||||
err = apis.RecordAuthResponse(event, user, "example2", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Expected nil, got error: %v", err)
|
||||
}
|
||||
|
||||
if total := totalMFAs(user); total != 0 {
|
||||
t.Fatalf("Expected the dummy mfa record to be deleted, found %d", total)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("missing mfa", func(t *testing.T) {
|
||||
resetMFAs(user)
|
||||
|
||||
event.Request = httptest.NewRequest(http.MethodGet, "/?mfaId=missing", nil)
|
||||
|
||||
err = apis.RecordAuthResponse(event, user, "example2", nil)
|
||||
if err == nil {
|
||||
t.Fatal("Expected error, got nil")
|
||||
}
|
||||
|
||||
if total := totalMFAs(user); total != 0 {
|
||||
t.Fatalf("Expected 0 mfa records, got %d", total)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("expired mfa", func(t *testing.T) {
|
||||
resetMFAs(user)
|
||||
|
||||
// create a dummy expired mfa record
|
||||
mfa := core.NewMFA(app)
|
||||
mfa.SetCollectionRef(user.Collection().Id)
|
||||
mfa.SetRecordRef(user.Id)
|
||||
mfa.SetMethod("example1")
|
||||
mfa.SetRaw("created", types.NowDateTime().Add(-1*time.Hour))
|
||||
mfa.SetRaw("updated", types.NowDateTime().Add(-1*time.Hour))
|
||||
if err = app.Save(mfa); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
event.Request = httptest.NewRequest(http.MethodGet, "/?mfaId="+mfa.Id, nil)
|
||||
|
||||
err = apis.RecordAuthResponse(event, user, "example2", nil)
|
||||
if err == nil {
|
||||
t.Fatal("Expected error, got nil")
|
||||
}
|
||||
|
||||
if totalMFAs(user) != 0 {
|
||||
t.Fatal("Expected the expired mfa record to be deleted")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("mfa for different auth record", func(t *testing.T) {
|
||||
resetMFAs(user)
|
||||
|
||||
// create a dummy expired mfa record
|
||||
mfa := core.NewMFA(app)
|
||||
mfa.SetCollectionRef(user2.Collection().Id)
|
||||
mfa.SetRecordRef(user2.Id)
|
||||
mfa.SetMethod("example1")
|
||||
if err = app.Save(mfa); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
event.Request = httptest.NewRequest(http.MethodGet, "/?mfaId="+mfa.Id, nil)
|
||||
|
||||
err = apis.RecordAuthResponse(event, user, "example2", nil)
|
||||
if err == nil {
|
||||
t.Fatal("Expected error, got nil")
|
||||
}
|
||||
|
||||
if total := totalMFAs(user); total != 0 {
|
||||
t.Fatalf("Expected no user mfas, got %d", total)
|
||||
}
|
||||
|
||||
if total := totalMFAs(user2); total != 1 {
|
||||
t.Fatalf("Expected only 1 user2 mfa, got %d", total)
|
||||
}
|
||||
})
|
||||
}
|
318
apis/serve.go
Normal file
318
apis/serve.go
Normal file
|
@ -0,0 +1,318 @@
|
|||
package apis
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/fatih/color"
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tools/hook"
|
||||
"github.com/pocketbase/pocketbase/tools/list"
|
||||
"github.com/pocketbase/pocketbase/tools/routine"
|
||||
"github.com/pocketbase/pocketbase/ui"
|
||||
"golang.org/x/crypto/acme"
|
||||
"golang.org/x/crypto/acme/autocert"
|
||||
)
|
||||
|
||||
// ServeConfig defines a configuration struct for apis.Serve().
|
||||
type ServeConfig struct {
|
||||
// ShowStartBanner indicates whether to show or hide the server start console message.
|
||||
ShowStartBanner bool
|
||||
|
||||
// HttpAddr is the TCP address to listen for the HTTP server (eg. "127.0.0.1:80").
|
||||
HttpAddr string
|
||||
|
||||
// HttpsAddr is the TCP address to listen for the HTTPS server (eg. "127.0.0.1:443").
|
||||
HttpsAddr string
|
||||
|
||||
// Optional domains list to use when issuing the TLS certificate.
|
||||
//
|
||||
// If not set, the host from the bound server address will be used.
|
||||
//
|
||||
// For convenience, for each "non-www" domain a "www" entry and
|
||||
// redirect will be automatically added.
|
||||
CertificateDomains []string
|
||||
|
||||
// AllowedOrigins is an optional list of CORS origins (default to "*").
|
||||
AllowedOrigins []string
|
||||
}
|
||||
|
||||
// Serve starts a new app web server.
|
||||
//
|
||||
// NB! The app should be bootstrapped before starting the web server.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// app.Bootstrap()
|
||||
// apis.Serve(app, apis.ServeConfig{
|
||||
// HttpAddr: "127.0.0.1:8080",
|
||||
// ShowStartBanner: false,
|
||||
// })
|
||||
func Serve(app core.App, config ServeConfig) error {
|
||||
if len(config.AllowedOrigins) == 0 {
|
||||
config.AllowedOrigins = []string{"*"}
|
||||
}
|
||||
|
||||
// ensure that the latest migrations are applied before starting the server
|
||||
err := app.RunAllMigrations()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
pbRouter, err := NewRouter(app)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
pbRouter.Bind(CORS(CORSConfig{
|
||||
AllowOrigins: config.AllowedOrigins,
|
||||
AllowMethods: []string{http.MethodGet, http.MethodHead, http.MethodPut, http.MethodPatch, http.MethodPost, http.MethodDelete},
|
||||
}))
|
||||
|
||||
pbRouter.GET("/_/{path...}", Static(ui.DistDirFS, false)).
|
||||
BindFunc(func(e *core.RequestEvent) error {
|
||||
// ignore root path
|
||||
if e.Request.PathValue(StaticWildcardParam) != "" {
|
||||
e.Response.Header().Set("Cache-Control", "max-age=1209600, stale-while-revalidate=86400")
|
||||
}
|
||||
|
||||
// add a default CSP
|
||||
if e.Response.Header().Get("Content-Security-Policy") == "" {
|
||||
e.Response.Header().Set("Content-Security-Policy", "default-src 'self'; style-src 'self' 'unsafe-inline'; img-src 'self' http://127.0.0.1:* https://tile.openstreetmap.org data: blob:; connect-src 'self' http://127.0.0.1:* https://nominatim.openstreetmap.org; script-src 'self' 'sha256-GRUzBA7PzKYug7pqxv5rJaec5bwDCw1Vo6/IXwvD3Tc='")
|
||||
}
|
||||
|
||||
return e.Next()
|
||||
}).
|
||||
Bind(Gzip())
|
||||
|
||||
// start http server
|
||||
// ---
|
||||
mainAddr := config.HttpAddr
|
||||
if config.HttpsAddr != "" {
|
||||
mainAddr = config.HttpsAddr
|
||||
}
|
||||
|
||||
var wwwRedirects []string
|
||||
|
||||
// extract the host names for the certificate host policy
|
||||
hostNames := config.CertificateDomains
|
||||
if len(hostNames) == 0 {
|
||||
host, _, _ := net.SplitHostPort(mainAddr)
|
||||
hostNames = append(hostNames, host)
|
||||
}
|
||||
for _, host := range hostNames {
|
||||
if strings.HasPrefix(host, "www.") {
|
||||
continue // explicitly set www host
|
||||
}
|
||||
|
||||
wwwHost := "www." + host
|
||||
if !list.ExistInSlice(wwwHost, hostNames) {
|
||||
hostNames = append(hostNames, wwwHost)
|
||||
wwwRedirects = append(wwwRedirects, wwwHost)
|
||||
}
|
||||
}
|
||||
|
||||
// implicit www->non-www redirect(s)
|
||||
if len(wwwRedirects) > 0 {
|
||||
pbRouter.Bind(wwwRedirect(wwwRedirects))
|
||||
}
|
||||
|
||||
certManager := &autocert.Manager{
|
||||
Prompt: autocert.AcceptTOS,
|
||||
Cache: autocert.DirCache(filepath.Join(app.DataDir(), core.LocalAutocertCacheDirName)),
|
||||
HostPolicy: autocert.HostWhitelist(hostNames...),
|
||||
}
|
||||
|
||||
// base request context used for cancelling long running requests
|
||||
// like the SSE connections
|
||||
baseCtx, cancelBaseCtx := context.WithCancel(context.Background())
|
||||
defer cancelBaseCtx()
|
||||
|
||||
server := &http.Server{
|
||||
TLSConfig: &tls.Config{
|
||||
MinVersion: tls.VersionTLS12,
|
||||
GetCertificate: certManager.GetCertificate,
|
||||
NextProtos: []string{acme.ALPNProto},
|
||||
},
|
||||
// higher defaults to accommodate large file uploads/downloads
|
||||
WriteTimeout: 5 * time.Minute,
|
||||
ReadTimeout: 5 * time.Minute,
|
||||
ReadHeaderTimeout: 1 * time.Minute,
|
||||
Addr: mainAddr,
|
||||
BaseContext: func(l net.Listener) context.Context {
|
||||
return baseCtx
|
||||
},
|
||||
ErrorLog: log.New(&serverErrorLogWriter{app: app}, "", 0),
|
||||
}
|
||||
|
||||
serveEvent := new(core.ServeEvent)
|
||||
serveEvent.App = app
|
||||
serveEvent.Router = pbRouter
|
||||
serveEvent.Server = server
|
||||
serveEvent.CertManager = certManager
|
||||
serveEvent.InstallerFunc = DefaultInstallerFunc
|
||||
|
||||
var listener net.Listener
|
||||
|
||||
// graceful shutdown
|
||||
// ---------------------------------------------------------------
|
||||
// WaitGroup to block until server.ShutDown() returns because Serve and similar methods exit immediately.
|
||||
// Note that the WaitGroup would do nothing if the app.OnTerminate() hook isn't triggered.
|
||||
var wg sync.WaitGroup
|
||||
|
||||
// try to gracefully shutdown the server on app termination
|
||||
app.OnTerminate().Bind(&hook.Handler[*core.TerminateEvent]{
|
||||
Id: "pbGracefulShutdown",
|
||||
Func: func(te *core.TerminateEvent) error {
|
||||
cancelBaseCtx()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
|
||||
defer cancel()
|
||||
|
||||
wg.Add(1)
|
||||
|
||||
_ = server.Shutdown(ctx)
|
||||
|
||||
if te.IsRestart {
|
||||
// wait for execve and other handlers up to 3 seconds before exit
|
||||
time.AfterFunc(3*time.Second, func() {
|
||||
wg.Done()
|
||||
})
|
||||
} else {
|
||||
wg.Done()
|
||||
}
|
||||
|
||||
return te.Next()
|
||||
},
|
||||
Priority: -9999,
|
||||
})
|
||||
|
||||
// wait for the graceful shutdown to complete before exit
|
||||
defer func() {
|
||||
wg.Wait()
|
||||
|
||||
if listener != nil {
|
||||
_ = listener.Close()
|
||||
}
|
||||
}()
|
||||
// ---------------------------------------------------------------
|
||||
|
||||
var baseURL string
|
||||
|
||||
// trigger the OnServe hook and start the tcp listener
|
||||
serveHookErr := app.OnServe().Trigger(serveEvent, func(e *core.ServeEvent) error {
|
||||
handler, err := e.Router.BuildMux()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
e.Server.Handler = handler
|
||||
|
||||
if config.HttpsAddr == "" {
|
||||
baseURL = "http://" + serverAddrToHost(serveEvent.Server.Addr)
|
||||
} else {
|
||||
baseURL = "https://"
|
||||
if len(config.CertificateDomains) > 0 {
|
||||
baseURL += config.CertificateDomains[0]
|
||||
} else {
|
||||
baseURL += serverAddrToHost(serveEvent.Server.Addr)
|
||||
}
|
||||
}
|
||||
|
||||
addr := e.Server.Addr
|
||||
if addr == "" {
|
||||
// fallback similar to the std Server.ListenAndServe/ListenAndServeTLS
|
||||
if config.HttpsAddr != "" {
|
||||
addr = ":https"
|
||||
} else {
|
||||
addr = ":http"
|
||||
}
|
||||
}
|
||||
|
||||
listener, err = net.Listen("tcp", addr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if e.InstallerFunc != nil {
|
||||
app := e.App
|
||||
installerFunc := e.InstallerFunc
|
||||
routine.FireAndForget(func() {
|
||||
if err := loadInstaller(app, baseURL, installerFunc); err != nil {
|
||||
app.Logger().Warn("Failed to initialize installer", "error", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
if serveHookErr != nil {
|
||||
return serveHookErr
|
||||
}
|
||||
|
||||
if listener == nil {
|
||||
//nolint:staticcheck
|
||||
return errors.New("The OnServe finalizer wasn't invoked. Did you forget to call the ServeEvent.Next() method?")
|
||||
}
|
||||
|
||||
if config.ShowStartBanner {
|
||||
date := new(strings.Builder)
|
||||
log.New(date, "", log.LstdFlags).Print()
|
||||
|
||||
bold := color.New(color.Bold).Add(color.FgGreen)
|
||||
bold.Printf(
|
||||
"%s Server started at %s\n",
|
||||
strings.TrimSpace(date.String()),
|
||||
color.CyanString("%s", baseURL),
|
||||
)
|
||||
|
||||
regular := color.New()
|
||||
regular.Printf("├─ REST API: %s\n", color.CyanString("%s/api/", baseURL))
|
||||
regular.Printf("└─ Dashboard: %s\n", color.CyanString("%s/_/", baseURL))
|
||||
}
|
||||
|
||||
var serveErr error
|
||||
if config.HttpsAddr != "" {
|
||||
if config.HttpAddr != "" {
|
||||
// start an additional HTTP server for redirecting the traffic to the HTTPS version
|
||||
go http.ListenAndServe(config.HttpAddr, certManager.HTTPHandler(nil))
|
||||
}
|
||||
|
||||
// start HTTPS server
|
||||
serveErr = serveEvent.Server.ServeTLS(listener, "", "")
|
||||
} else {
|
||||
// OR start HTTP server
|
||||
serveErr = serveEvent.Server.Serve(listener)
|
||||
}
|
||||
if serveErr != nil && !errors.Is(serveErr, http.ErrServerClosed) {
|
||||
return serveErr
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// serverAddrToHost loosely converts http.Server.Addr string into a host to print.
|
||||
func serverAddrToHost(addr string) string {
|
||||
if addr == "" || strings.HasSuffix(addr, ":http") || strings.HasSuffix(addr, ":https") {
|
||||
return "127.0.0.1"
|
||||
}
|
||||
return addr
|
||||
}
|
||||
|
||||
type serverErrorLogWriter struct {
|
||||
app core.App
|
||||
}
|
||||
|
||||
func (s *serverErrorLogWriter) Write(p []byte) (int, error) {
|
||||
s.app.Logger().Debug(strings.TrimSpace(string(p)))
|
||||
|
||||
return len(p), nil
|
||||
}
|
143
apis/settings.go
Normal file
143
apis/settings.go
Normal file
|
@ -0,0 +1,143 @@
|
|||
package apis
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
validation "github.com/go-ozzo/ozzo-validation/v4"
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/forms"
|
||||
"github.com/pocketbase/pocketbase/tools/router"
|
||||
)
|
||||
|
||||
// bindSettingsApi registers the settings api endpoints.
|
||||
func bindSettingsApi(app core.App, rg *router.RouterGroup[*core.RequestEvent]) {
|
||||
subGroup := rg.Group("/settings").Bind(RequireSuperuserAuth())
|
||||
subGroup.GET("", settingsList)
|
||||
subGroup.PATCH("", settingsSet)
|
||||
subGroup.POST("/test/s3", settingsTestS3)
|
||||
subGroup.POST("/test/email", settingsTestEmail)
|
||||
subGroup.POST("/apple/generate-client-secret", settingsGenerateAppleClientSecret)
|
||||
}
|
||||
|
||||
func settingsList(e *core.RequestEvent) error {
|
||||
clone, err := e.App.Settings().Clone()
|
||||
if err != nil {
|
||||
return e.InternalServerError("", err)
|
||||
}
|
||||
|
||||
event := new(core.SettingsListRequestEvent)
|
||||
event.RequestEvent = e
|
||||
event.Settings = clone
|
||||
|
||||
return e.App.OnSettingsListRequest().Trigger(event, func(e *core.SettingsListRequestEvent) error {
|
||||
return execAfterSuccessTx(true, e.App, func() error {
|
||||
return e.JSON(http.StatusOK, e.Settings)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func settingsSet(e *core.RequestEvent) error {
|
||||
event := new(core.SettingsUpdateRequestEvent)
|
||||
event.RequestEvent = e
|
||||
|
||||
if clone, err := e.App.Settings().Clone(); err == nil {
|
||||
event.OldSettings = clone
|
||||
} else {
|
||||
return e.BadRequestError("", err)
|
||||
}
|
||||
|
||||
if clone, err := e.App.Settings().Clone(); err == nil {
|
||||
event.NewSettings = clone
|
||||
} else {
|
||||
return e.BadRequestError("", err)
|
||||
}
|
||||
|
||||
if err := e.BindBody(&event.NewSettings); err != nil {
|
||||
return e.BadRequestError("An error occurred while loading the submitted data.", err)
|
||||
}
|
||||
|
||||
return e.App.OnSettingsUpdateRequest().Trigger(event, func(e *core.SettingsUpdateRequestEvent) error {
|
||||
err := e.App.Save(e.NewSettings)
|
||||
if err != nil {
|
||||
return e.BadRequestError("An error occurred while saving the new settings.", err)
|
||||
}
|
||||
|
||||
appSettings, err := e.App.Settings().Clone()
|
||||
if err != nil {
|
||||
return e.InternalServerError("Failed to clone app settings.", err)
|
||||
}
|
||||
|
||||
return execAfterSuccessTx(true, e.App, func() error {
|
||||
return e.JSON(http.StatusOK, appSettings)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func settingsTestS3(e *core.RequestEvent) error {
|
||||
form := forms.NewTestS3Filesystem(e.App)
|
||||
|
||||
// load request
|
||||
if err := e.BindBody(form); err != nil {
|
||||
return e.BadRequestError("An error occurred while loading the submitted data.", err)
|
||||
}
|
||||
|
||||
// send
|
||||
if err := form.Submit(); err != nil {
|
||||
// form error
|
||||
if fErr, ok := err.(validation.Errors); ok {
|
||||
return e.BadRequestError("Failed to test the S3 filesystem.", fErr)
|
||||
}
|
||||
|
||||
// mailer error
|
||||
return e.BadRequestError("Failed to test the S3 filesystem. Raw error: \n"+err.Error(), nil)
|
||||
}
|
||||
|
||||
return e.NoContent(http.StatusNoContent)
|
||||
}
|
||||
|
||||
func settingsTestEmail(e *core.RequestEvent) error {
|
||||
form := forms.NewTestEmailSend(e.App)
|
||||
|
||||
// load request
|
||||
if err := e.BindBody(form); err != nil {
|
||||
return e.BadRequestError("An error occurred while loading the submitted data.", err)
|
||||
}
|
||||
|
||||
// send
|
||||
if err := form.Submit(); err != nil {
|
||||
// form error
|
||||
if fErr, ok := err.(validation.Errors); ok {
|
||||
return e.BadRequestError("Failed to send the test email.", fErr)
|
||||
}
|
||||
|
||||
// mailer error
|
||||
return e.BadRequestError("Failed to send the test email. Raw error: \n"+err.Error(), nil)
|
||||
}
|
||||
|
||||
return e.NoContent(http.StatusNoContent)
|
||||
}
|
||||
|
||||
func settingsGenerateAppleClientSecret(e *core.RequestEvent) error {
|
||||
form := forms.NewAppleClientSecretCreate(e.App)
|
||||
|
||||
// load request
|
||||
if err := e.BindBody(form); err != nil {
|
||||
return e.BadRequestError("An error occurred while loading the submitted data.", err)
|
||||
}
|
||||
|
||||
// generate
|
||||
secret, err := form.Submit()
|
||||
if err != nil {
|
||||
// form error
|
||||
if fErr, ok := err.(validation.Errors); ok {
|
||||
return e.BadRequestError("Invalid client secret data.", fErr)
|
||||
}
|
||||
|
||||
// secret generation error
|
||||
return e.BadRequestError("Failed to generate client secret. Raw error: \n"+err.Error(), nil)
|
||||
}
|
||||
|
||||
return e.JSON(http.StatusOK, map[string]string{
|
||||
"secret": secret,
|
||||
})
|
||||
}
|
641
apis/settings_test.go
Normal file
641
apis/settings_test.go
Normal file
|
@ -0,0 +1,641 @@
|
|||
package apis_test
|
||||
|
||||
import (
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
"crypto/x509"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tests"
|
||||
)
|
||||
|
||||
func TestSettingsList(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "unauthorized",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/settings",
|
||||
ExpectedStatus: 401,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "authorized as regular user",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/settings",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "authorized as superuser",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/settings",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
`"meta":{`,
|
||||
`"logs":{`,
|
||||
`"smtp":{`,
|
||||
`"s3":{`,
|
||||
`"backups":{`,
|
||||
`"batch":{`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnSettingsListRequest": 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "OnSettingsListRequest tx body write check",
|
||||
Method: http.MethodGet,
|
||||
URL: "/api/settings",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
app.OnSettingsListRequest().BindFunc(func(e *core.SettingsListRequestEvent) error {
|
||||
original := e.App
|
||||
return e.App.RunInTransaction(func(txApp core.App) error {
|
||||
e.App = txApp
|
||||
defer func() { e.App = original }()
|
||||
|
||||
if err := e.Next(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return e.BadRequestError("TX_ERROR", nil)
|
||||
})
|
||||
})
|
||||
},
|
||||
ExpectedStatus: 400,
|
||||
ExpectedEvents: map[string]int{"OnSettingsListRequest": 1},
|
||||
ExpectedContent: []string{"TX_ERROR"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSettingsSet(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
validData := `{
|
||||
"meta":{"appName":"update_test"},
|
||||
"s3":{"secret": "s3_secret"},
|
||||
"backups":{"s3":{"secret":"backups_s3_secret"}}
|
||||
}`
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "unauthorized",
|
||||
Method: http.MethodPatch,
|
||||
URL: "/api/settings",
|
||||
Body: strings.NewReader(validData),
|
||||
ExpectedStatus: 401,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "authorized as regular user",
|
||||
Method: http.MethodPatch,
|
||||
URL: "/api/settings",
|
||||
Body: strings.NewReader(validData),
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "authorized as superuser submitting empty data",
|
||||
Method: http.MethodPatch,
|
||||
URL: "/api/settings",
|
||||
Body: strings.NewReader(``),
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
`"meta":{`,
|
||||
`"logs":{`,
|
||||
`"smtp":{`,
|
||||
`"s3":{`,
|
||||
`"backups":{`,
|
||||
`"batch":{`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnSettingsUpdateRequest": 1,
|
||||
"OnModelUpdate": 1,
|
||||
"OnModelUpdateExecute": 1,
|
||||
"OnModelAfterUpdateSuccess": 1,
|
||||
"OnModelValidate": 1,
|
||||
"OnSettingsReload": 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "authorized as superuser submitting invalid data",
|
||||
Method: http.MethodPatch,
|
||||
URL: "/api/settings",
|
||||
Body: strings.NewReader(`{"meta":{"appName":""}}`),
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
|
||||
},
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{
|
||||
`"data":{`,
|
||||
`"meta":{"appName":{"code":"validation_required"`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnModelUpdate": 1,
|
||||
"OnModelAfterUpdateError": 1,
|
||||
"OnModelValidate": 1,
|
||||
"OnSettingsUpdateRequest": 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "authorized as superuser submitting valid data",
|
||||
Method: http.MethodPatch,
|
||||
URL: "/api/settings",
|
||||
Body: strings.NewReader(validData),
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
`"meta":{`,
|
||||
`"logs":{`,
|
||||
`"smtp":{`,
|
||||
`"s3":{`,
|
||||
`"backups":{`,
|
||||
`"batch":{`,
|
||||
`"appName":"update_test"`,
|
||||
},
|
||||
NotExpectedContent: []string{
|
||||
"secret",
|
||||
"password",
|
||||
},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnSettingsUpdateRequest": 1,
|
||||
"OnModelUpdate": 1,
|
||||
"OnModelUpdateExecute": 1,
|
||||
"OnModelAfterUpdateSuccess": 1,
|
||||
"OnModelValidate": 1,
|
||||
"OnSettingsReload": 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "OnSettingsUpdateRequest tx body write check",
|
||||
Method: http.MethodPatch,
|
||||
URL: "/api/settings",
|
||||
Body: strings.NewReader(validData),
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
|
||||
},
|
||||
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
|
||||
app.OnSettingsUpdateRequest().BindFunc(func(e *core.SettingsUpdateRequestEvent) error {
|
||||
original := e.App
|
||||
return e.App.RunInTransaction(func(txApp core.App) error {
|
||||
e.App = txApp
|
||||
defer func() { e.App = original }()
|
||||
|
||||
if err := e.Next(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return e.BadRequestError("TX_ERROR", nil)
|
||||
})
|
||||
})
|
||||
},
|
||||
ExpectedStatus: 400,
|
||||
ExpectedEvents: map[string]int{"OnSettingsUpdateRequest": 1},
|
||||
ExpectedContent: []string{"TX_ERROR"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSettingsTestS3(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "unauthorized",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/settings/test/s3",
|
||||
ExpectedStatus: 401,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "authorized as regular user",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/settings/test/s3",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "authorized as superuser (missing body + no s3)",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/settings/test/s3",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
|
||||
},
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{
|
||||
`"data":{`,
|
||||
`"filesystem":{`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "authorized as superuser (invalid filesystem)",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/settings/test/s3",
|
||||
Body: strings.NewReader(`{"filesystem":"invalid"}`),
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
|
||||
},
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{
|
||||
`"data":{`,
|
||||
`"filesystem":{`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "authorized as superuser (valid filesystem and no s3)",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/settings/test/s3",
|
||||
Body: strings.NewReader(`{"filesystem":"storage"}`),
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
|
||||
},
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{
|
||||
`"data":{}`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSettingsTestEmail(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "unauthorized",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/settings/test/email",
|
||||
Body: strings.NewReader(`{
|
||||
"template": "verification",
|
||||
"email": "test@example.com"
|
||||
}`),
|
||||
ExpectedStatus: 401,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "authorized as regular user",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/settings/test/email",
|
||||
Body: strings.NewReader(`{
|
||||
"template": "verification",
|
||||
"email": "test@example.com"
|
||||
}`),
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "authorized as superuser (invalid body)",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/settings/test/email",
|
||||
Body: strings.NewReader(`{`),
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
|
||||
},
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "authorized as superuser (empty json)",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/settings/test/email",
|
||||
Body: strings.NewReader(`{}`),
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
|
||||
},
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{
|
||||
`"email":{"code":"validation_required"`,
|
||||
`"template":{"code":"validation_required"`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "authorized as superuser (verifiation template)",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/settings/test/email",
|
||||
Body: strings.NewReader(`{
|
||||
"template": "verification",
|
||||
"email": "test@example.com"
|
||||
}`),
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
|
||||
},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
if app.TestMailer.TotalSend() != 1 {
|
||||
t.Fatalf("[verification] Expected 1 sent email, got %d", app.TestMailer.TotalSend())
|
||||
}
|
||||
|
||||
if len(app.TestMailer.LastMessage().To) != 1 {
|
||||
t.Fatalf("[verification] Expected 1 recipient, got %v", app.TestMailer.LastMessage().To)
|
||||
}
|
||||
|
||||
if app.TestMailer.LastMessage().To[0].Address != "test@example.com" {
|
||||
t.Fatalf("[verification] Expected the email to be sent to %s, got %s", "test@example.com", app.TestMailer.LastMessage().To[0].Address)
|
||||
}
|
||||
|
||||
if !strings.Contains(app.TestMailer.LastMessage().HTML, "Verify") {
|
||||
t.Fatalf("[verification] Expected to sent a verification email, got \n%v\n%v", app.TestMailer.LastMessage().Subject, app.TestMailer.LastMessage().HTML)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 204,
|
||||
ExpectedContent: []string{},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnMailerSend": 1,
|
||||
"OnMailerRecordVerificationSend": 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "authorized as superuser (password reset template)",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/settings/test/email",
|
||||
Body: strings.NewReader(`{
|
||||
"template": "password-reset",
|
||||
"email": "test@example.com"
|
||||
}`),
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
|
||||
},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
if app.TestMailer.TotalSend() != 1 {
|
||||
t.Fatalf("[password-reset] Expected 1 sent email, got %d", app.TestMailer.TotalSend())
|
||||
}
|
||||
|
||||
if len(app.TestMailer.LastMessage().To) != 1 {
|
||||
t.Fatalf("[password-reset] Expected 1 recipient, got %v", app.TestMailer.LastMessage().To)
|
||||
}
|
||||
|
||||
if app.TestMailer.LastMessage().To[0].Address != "test@example.com" {
|
||||
t.Fatalf("[password-reset] Expected the email to be sent to %s, got %s", "test@example.com", app.TestMailer.LastMessage().To[0].Address)
|
||||
}
|
||||
|
||||
if !strings.Contains(app.TestMailer.LastMessage().HTML, "Reset password") {
|
||||
t.Fatalf("[password-reset] Expected to sent a password-reset email, got \n%v\n%v", app.TestMailer.LastMessage().Subject, app.TestMailer.LastMessage().HTML)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 204,
|
||||
ExpectedContent: []string{},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnMailerSend": 1,
|
||||
"OnMailerRecordPasswordResetSend": 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "authorized as superuser (email change)",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/settings/test/email",
|
||||
Body: strings.NewReader(`{
|
||||
"template": "email-change",
|
||||
"email": "test@example.com"
|
||||
}`),
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
|
||||
},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
if app.TestMailer.TotalSend() != 1 {
|
||||
t.Fatalf("[email-change] Expected 1 sent email, got %d", app.TestMailer.TotalSend())
|
||||
}
|
||||
|
||||
if len(app.TestMailer.LastMessage().To) != 1 {
|
||||
t.Fatalf("[email-change] Expected 1 recipient, got %v", app.TestMailer.LastMessage().To)
|
||||
}
|
||||
|
||||
if app.TestMailer.LastMessage().To[0].Address != "test@example.com" {
|
||||
t.Fatalf("[email-change] Expected the email to be sent to %s, got %s", "test@example.com", app.TestMailer.LastMessage().To[0].Address)
|
||||
}
|
||||
|
||||
if !strings.Contains(app.TestMailer.LastMessage().HTML, "Confirm new email") {
|
||||
t.Fatalf("[email-change] Expected to sent a confirm new email email, got \n%v\n%v", app.TestMailer.LastMessage().Subject, app.TestMailer.LastMessage().HTML)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 204,
|
||||
ExpectedContent: []string{},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnMailerSend": 1,
|
||||
"OnMailerRecordEmailChangeSend": 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "authorized as superuser (otp)",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/settings/test/email",
|
||||
Body: strings.NewReader(`{
|
||||
"template": "otp",
|
||||
"email": "test@example.com"
|
||||
}`),
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
|
||||
},
|
||||
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
|
||||
if app.TestMailer.TotalSend() != 1 {
|
||||
t.Fatalf("[otp] Expected 1 sent email, got %d", app.TestMailer.TotalSend())
|
||||
}
|
||||
|
||||
if len(app.TestMailer.LastMessage().To) != 1 {
|
||||
t.Fatalf("[otp] Expected 1 recipient, got %v", app.TestMailer.LastMessage().To)
|
||||
}
|
||||
|
||||
if app.TestMailer.LastMessage().To[0].Address != "test@example.com" {
|
||||
t.Fatalf("[otp] Expected the email to be sent to %s, got %s", "test@example.com", app.TestMailer.LastMessage().To[0].Address)
|
||||
}
|
||||
|
||||
if !strings.Contains(app.TestMailer.LastMessage().HTML, "one-time password") {
|
||||
t.Fatalf("[otp] Expected to sent OTP email, got \n%v\n%v", app.TestMailer.LastMessage().Subject, app.TestMailer.LastMessage().HTML)
|
||||
}
|
||||
},
|
||||
ExpectedStatus: 204,
|
||||
ExpectedContent: []string{},
|
||||
ExpectedEvents: map[string]int{
|
||||
"*": 0,
|
||||
"OnMailerSend": 1,
|
||||
"OnMailerRecordOTPSend": 1,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateAppleClientSecret(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
encodedKey, err := x509.MarshalPKCS8PrivateKey(key)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
privatePem := pem.EncodeToMemory(
|
||||
&pem.Block{
|
||||
Type: "PRIVATE KEY",
|
||||
Bytes: encodedKey,
|
||||
},
|
||||
)
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "unauthorized",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/settings/apple/generate-client-secret",
|
||||
ExpectedStatus: 401,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "authorized as regular user",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/settings/apple/generate-client-secret",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
|
||||
},
|
||||
ExpectedStatus: 403,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "authorized as superuser (invalid body)",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/settings/apple/generate-client-secret",
|
||||
Body: strings.NewReader(`{`),
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
|
||||
},
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "authorized as superuser (empty json)",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/settings/apple/generate-client-secret",
|
||||
Body: strings.NewReader(`{}`),
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
|
||||
},
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{
|
||||
`"clientId":{"code":"validation_required"`,
|
||||
`"teamId":{"code":"validation_required"`,
|
||||
`"keyId":{"code":"validation_required"`,
|
||||
`"privateKey":{"code":"validation_required"`,
|
||||
`"duration":{"code":"validation_required"`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "authorized as superuser (invalid data)",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/settings/apple/generate-client-secret",
|
||||
Body: strings.NewReader(`{
|
||||
"clientId": "",
|
||||
"teamId": "123456789",
|
||||
"keyId": "123456789",
|
||||
"privateKey": "invalid",
|
||||
"duration": -1
|
||||
}`),
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
|
||||
},
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{
|
||||
`"clientId":{"code":"validation_required"`,
|
||||
`"teamId":{"code":"validation_length_invalid"`,
|
||||
`"keyId":{"code":"validation_length_invalid"`,
|
||||
`"privateKey":{"code":"validation_match_invalid"`,
|
||||
`"duration":{"code":"validation_min_greater_equal_than_required"`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
{
|
||||
Name: "authorized as superuser (valid data)",
|
||||
Method: http.MethodPost,
|
||||
URL: "/api/settings/apple/generate-client-secret",
|
||||
Body: strings.NewReader(fmt.Sprintf(`{
|
||||
"clientId": "123",
|
||||
"teamId": "1234567890",
|
||||
"keyId": "1234567891",
|
||||
"privateKey": %q,
|
||||
"duration": 1
|
||||
}`, privatePem)),
|
||||
Headers: map[string]string{
|
||||
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoicGJjXzMxNDI2MzU4MjMiLCJleHAiOjI1MjQ2MDQ0NjEsInJlZnJlc2hhYmxlIjp0cnVlfQ.UXgO3j-0BumcugrFjbd7j0M4MQvbrLggLlcu_YNGjoY",
|
||||
},
|
||||
ExpectedStatus: 200,
|
||||
ExpectedContent: []string{
|
||||
`"secret":"`,
|
||||
},
|
||||
ExpectedEvents: map[string]int{"*": 0},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
77
cmd/serve.go
Normal file
77
cmd/serve.go
Normal file
|
@ -0,0 +1,77 @@
|
|||
package cmd
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
|
||||
"github.com/pocketbase/pocketbase/apis"
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
// NewServeCommand creates and returns new command responsible for
|
||||
// starting the default PocketBase web server.
|
||||
func NewServeCommand(app core.App, showStartBanner bool) *cobra.Command {
|
||||
var allowedOrigins []string
|
||||
var httpAddr string
|
||||
var httpsAddr string
|
||||
|
||||
command := &cobra.Command{
|
||||
Use: "serve [domain(s)]",
|
||||
Args: cobra.ArbitraryArgs,
|
||||
Short: "Starts the web server (default to 127.0.0.1:8090 if no domain is specified)",
|
||||
SilenceUsage: true,
|
||||
RunE: func(command *cobra.Command, args []string) error {
|
||||
// set default listener addresses if at least one domain is specified
|
||||
if len(args) > 0 {
|
||||
if httpAddr == "" {
|
||||
httpAddr = "0.0.0.0:80"
|
||||
}
|
||||
if httpsAddr == "" {
|
||||
httpsAddr = "0.0.0.0:443"
|
||||
}
|
||||
} else {
|
||||
if httpAddr == "" {
|
||||
httpAddr = "127.0.0.1:8090"
|
||||
}
|
||||
}
|
||||
|
||||
err := apis.Serve(app, apis.ServeConfig{
|
||||
HttpAddr: httpAddr,
|
||||
HttpsAddr: httpsAddr,
|
||||
ShowStartBanner: showStartBanner,
|
||||
AllowedOrigins: allowedOrigins,
|
||||
CertificateDomains: args,
|
||||
})
|
||||
|
||||
if errors.Is(err, http.ErrServerClosed) {
|
||||
return nil
|
||||
}
|
||||
|
||||
return err
|
||||
},
|
||||
}
|
||||
|
||||
command.PersistentFlags().StringSliceVar(
|
||||
&allowedOrigins,
|
||||
"origins",
|
||||
[]string{"*"},
|
||||
"CORS allowed domain origins list",
|
||||
)
|
||||
|
||||
command.PersistentFlags().StringVar(
|
||||
&httpAddr,
|
||||
"http",
|
||||
"",
|
||||
"TCP address to listen for the HTTP server\n(if domain args are specified - default to 0.0.0.0:80, otherwise - default to 127.0.0.1:8090)",
|
||||
)
|
||||
|
||||
command.PersistentFlags().StringVar(
|
||||
&httpsAddr,
|
||||
"https",
|
||||
"",
|
||||
"TCP address to listen for the HTTPS server\n(if domain args are specified - default to 0.0.0.0:443, otherwise - default to empty string, aka. no TLS)\nThe incoming HTTP traffic also will be auto redirected to the HTTPS version",
|
||||
)
|
||||
|
||||
return command
|
||||
}
|
211
cmd/superuser.go
Normal file
211
cmd/superuser.go
Normal file
|
@ -0,0 +1,211 @@
|
|||
package cmd
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/fatih/color"
|
||||
"github.com/go-ozzo/ozzo-validation/v4/is"
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tools/security"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
// NewSuperuserCommand creates and returns new command for managing
|
||||
// superuser accounts (create, update, upsert, delete).
|
||||
func NewSuperuserCommand(app core.App) *cobra.Command {
|
||||
command := &cobra.Command{
|
||||
Use: "superuser",
|
||||
Short: "Manage superusers",
|
||||
}
|
||||
|
||||
command.AddCommand(superuserUpsertCommand(app))
|
||||
command.AddCommand(superuserCreateCommand(app))
|
||||
command.AddCommand(superuserUpdateCommand(app))
|
||||
command.AddCommand(superuserDeleteCommand(app))
|
||||
command.AddCommand(superuserOTPCommand(app))
|
||||
|
||||
return command
|
||||
}
|
||||
|
||||
func superuserUpsertCommand(app core.App) *cobra.Command {
|
||||
command := &cobra.Command{
|
||||
Use: "upsert",
|
||||
Example: "superuser upsert test@example.com 1234567890",
|
||||
Short: "Creates, or updates if email exists, a single superuser",
|
||||
SilenceUsage: true,
|
||||
RunE: func(command *cobra.Command, args []string) error {
|
||||
if len(args) != 2 {
|
||||
return errors.New("missing email and password arguments")
|
||||
}
|
||||
|
||||
if args[0] == "" || is.EmailFormat.Validate(args[0]) != nil {
|
||||
return errors.New("missing or invalid email address")
|
||||
}
|
||||
|
||||
superusersCol, err := app.FindCachedCollectionByNameOrId(core.CollectionNameSuperusers)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to fetch %q collection: %w", core.CollectionNameSuperusers, err)
|
||||
}
|
||||
|
||||
superuser, err := app.FindAuthRecordByEmail(superusersCol, args[0])
|
||||
if err != nil {
|
||||
superuser = core.NewRecord(superusersCol)
|
||||
}
|
||||
|
||||
superuser.SetEmail(args[0])
|
||||
superuser.SetPassword(args[1])
|
||||
|
||||
if err := app.Save(superuser); err != nil {
|
||||
return fmt.Errorf("failed to upsert superuser account: %w", err)
|
||||
}
|
||||
|
||||
color.Green("Successfully saved superuser %q!", superuser.Email())
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
return command
|
||||
}
|
||||
|
||||
func superuserCreateCommand(app core.App) *cobra.Command {
|
||||
command := &cobra.Command{
|
||||
Use: "create",
|
||||
Example: "superuser create test@example.com 1234567890",
|
||||
Short: "Creates a new superuser",
|
||||
SilenceUsage: true,
|
||||
RunE: func(command *cobra.Command, args []string) error {
|
||||
if len(args) != 2 {
|
||||
return errors.New("missing email and password arguments")
|
||||
}
|
||||
|
||||
if args[0] == "" || is.EmailFormat.Validate(args[0]) != nil {
|
||||
return errors.New("missing or invalid email address")
|
||||
}
|
||||
|
||||
superusersCol, err := app.FindCachedCollectionByNameOrId(core.CollectionNameSuperusers)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to fetch %q collection: %w", core.CollectionNameSuperusers, err)
|
||||
}
|
||||
|
||||
superuser := core.NewRecord(superusersCol)
|
||||
superuser.SetEmail(args[0])
|
||||
superuser.SetPassword(args[1])
|
||||
|
||||
if err := app.Save(superuser); err != nil {
|
||||
return fmt.Errorf("failed to create new superuser account: %w", err)
|
||||
}
|
||||
|
||||
color.Green("Successfully created new superuser %q!", superuser.Email())
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
return command
|
||||
}
|
||||
|
||||
func superuserUpdateCommand(app core.App) *cobra.Command {
|
||||
command := &cobra.Command{
|
||||
Use: "update",
|
||||
Example: "superuser update test@example.com 1234567890",
|
||||
Short: "Changes the password of a single superuser",
|
||||
SilenceUsage: true,
|
||||
RunE: func(command *cobra.Command, args []string) error {
|
||||
if len(args) != 2 {
|
||||
return errors.New("missing email and password arguments")
|
||||
}
|
||||
|
||||
if args[0] == "" || is.EmailFormat.Validate(args[0]) != nil {
|
||||
return errors.New("missing or invalid email address")
|
||||
}
|
||||
|
||||
superuser, err := app.FindAuthRecordByEmail(core.CollectionNameSuperusers, args[0])
|
||||
if err != nil {
|
||||
return fmt.Errorf("superuser with email %q doesn't exist", args[0])
|
||||
}
|
||||
|
||||
superuser.SetPassword(args[1])
|
||||
|
||||
if err := app.Save(superuser); err != nil {
|
||||
return fmt.Errorf("failed to change superuser %q password: %w", superuser.Email(), err)
|
||||
}
|
||||
|
||||
color.Green("Successfully changed superuser %q password!", superuser.Email())
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
return command
|
||||
}
|
||||
|
||||
func superuserDeleteCommand(app core.App) *cobra.Command {
|
||||
command := &cobra.Command{
|
||||
Use: "delete",
|
||||
Example: "superuser delete test@example.com",
|
||||
Short: "Deletes an existing superuser",
|
||||
SilenceUsage: true,
|
||||
RunE: func(command *cobra.Command, args []string) error {
|
||||
if len(args) == 0 || args[0] == "" || is.EmailFormat.Validate(args[0]) != nil {
|
||||
return errors.New("invalid or missing email address")
|
||||
}
|
||||
|
||||
superuser, err := app.FindAuthRecordByEmail(core.CollectionNameSuperusers, args[0])
|
||||
if err != nil {
|
||||
color.Yellow("superuser %q is missing or already deleted", args[0])
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := app.Delete(superuser); err != nil {
|
||||
return fmt.Errorf("failed to delete superuser %q: %w", superuser.Email(), err)
|
||||
}
|
||||
|
||||
color.Green("Successfully deleted superuser %q!", superuser.Email())
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
return command
|
||||
}
|
||||
|
||||
func superuserOTPCommand(app core.App) *cobra.Command {
|
||||
command := &cobra.Command{
|
||||
Use: "otp",
|
||||
Example: "superuser otp test@example.com",
|
||||
Short: "Creates a new OTP for the specified superuser",
|
||||
SilenceUsage: true,
|
||||
RunE: func(command *cobra.Command, args []string) error {
|
||||
if len(args) == 0 || args[0] == "" || is.EmailFormat.Validate(args[0]) != nil {
|
||||
return errors.New("invalid or missing email address")
|
||||
}
|
||||
|
||||
superuser, err := app.FindAuthRecordByEmail(core.CollectionNameSuperusers, args[0])
|
||||
if err != nil {
|
||||
return fmt.Errorf("superuser with email %q doesn't exist", args[0])
|
||||
}
|
||||
|
||||
if !superuser.Collection().OTP.Enabled {
|
||||
return errors.New("OTP auth is not enabled for the _superusers collection")
|
||||
}
|
||||
|
||||
pass := security.RandomStringWithAlphabet(superuser.Collection().OTP.Length, "1234567890")
|
||||
|
||||
otp := core.NewOTP(app)
|
||||
otp.SetCollectionRef(superuser.Collection().Id)
|
||||
otp.SetRecordRef(superuser.Id)
|
||||
otp.SetPassword(pass)
|
||||
|
||||
err = app.Save(otp)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create OTP: %w", err)
|
||||
}
|
||||
|
||||
color.New(color.BgGreen, color.FgBlack).Printf("Successfully created OTP for superuser %q:", superuser.Email())
|
||||
color.Green("\n├─ Id: %s", otp.Id)
|
||||
color.Green("├─ Pass: %s", pass)
|
||||
color.Green("└─ Valid: %ds\n\n", superuser.Collection().OTP.Duration)
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
return command
|
||||
}
|
403
cmd/superuser_test.go
Normal file
403
cmd/superuser_test.go
Normal file
|
@ -0,0 +1,403 @@
|
|||
package cmd_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/pocketbase/pocketbase/cmd"
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tests"
|
||||
)
|
||||
|
||||
func TestSuperuserUpsertCommand(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
scenarios := []struct {
|
||||
name string
|
||||
email string
|
||||
password string
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
"empty email and password",
|
||||
"",
|
||||
"",
|
||||
true,
|
||||
},
|
||||
{
|
||||
"empty email",
|
||||
"",
|
||||
"1234567890",
|
||||
true,
|
||||
},
|
||||
{
|
||||
"invalid email",
|
||||
"invalid",
|
||||
"1234567890",
|
||||
true,
|
||||
},
|
||||
{
|
||||
"empty password",
|
||||
"test@example.com",
|
||||
"",
|
||||
true,
|
||||
},
|
||||
{
|
||||
"short password",
|
||||
"test_new@example.com",
|
||||
"1234567",
|
||||
true,
|
||||
},
|
||||
{
|
||||
"existing user",
|
||||
"test@example.com",
|
||||
"1234567890!",
|
||||
false,
|
||||
},
|
||||
{
|
||||
"new user",
|
||||
"test_new@example.com",
|
||||
"1234567890!",
|
||||
false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
t.Run(s.name, func(t *testing.T) {
|
||||
command := cmd.NewSuperuserCommand(app)
|
||||
command.SetArgs([]string{"upsert", s.email, s.password})
|
||||
|
||||
err := command.Execute()
|
||||
|
||||
hasErr := err != nil
|
||||
if s.expectError != hasErr {
|
||||
t.Fatalf("Expected hasErr %v, got %v (%v)", s.expectError, hasErr, err)
|
||||
}
|
||||
|
||||
if hasErr {
|
||||
return
|
||||
}
|
||||
|
||||
// check whether the superuser account was actually upserted
|
||||
superuser, err := app.FindAuthRecordByEmail(core.CollectionNameSuperusers, s.email)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to fetch superuser %s: %v", s.email, err)
|
||||
} else if !superuser.ValidatePassword(s.password) {
|
||||
t.Fatal("Expected the superuser password to match")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSuperuserCreateCommand(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
scenarios := []struct {
|
||||
name string
|
||||
email string
|
||||
password string
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
"empty email and password",
|
||||
"",
|
||||
"",
|
||||
true,
|
||||
},
|
||||
{
|
||||
"empty email",
|
||||
"",
|
||||
"1234567890",
|
||||
true,
|
||||
},
|
||||
{
|
||||
"invalid email",
|
||||
"invalid",
|
||||
"1234567890",
|
||||
true,
|
||||
},
|
||||
{
|
||||
"duplicated email",
|
||||
"test@example.com",
|
||||
"1234567890",
|
||||
true,
|
||||
},
|
||||
{
|
||||
"empty password",
|
||||
"test@example.com",
|
||||
"",
|
||||
true,
|
||||
},
|
||||
{
|
||||
"short password",
|
||||
"test_new@example.com",
|
||||
"1234567",
|
||||
true,
|
||||
},
|
||||
{
|
||||
"valid email and password",
|
||||
"test_new@example.com",
|
||||
"12345678",
|
||||
false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
t.Run(s.name, func(t *testing.T) {
|
||||
command := cmd.NewSuperuserCommand(app)
|
||||
command.SetArgs([]string{"create", s.email, s.password})
|
||||
|
||||
err := command.Execute()
|
||||
|
||||
hasErr := err != nil
|
||||
if s.expectError != hasErr {
|
||||
t.Fatalf("Expected hasErr %v, got %v (%v)", s.expectError, hasErr, err)
|
||||
}
|
||||
|
||||
if hasErr {
|
||||
return
|
||||
}
|
||||
|
||||
// check whether the superuser account was actually created
|
||||
superuser, err := app.FindAuthRecordByEmail(core.CollectionNameSuperusers, s.email)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to fetch created superuser %s: %v", s.email, err)
|
||||
} else if !superuser.ValidatePassword(s.password) {
|
||||
t.Fatal("Expected the superuser password to match")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSuperuserUpdateCommand(t *testing.T) {
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
scenarios := []struct {
|
||||
name string
|
||||
email string
|
||||
password string
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
"empty email and password",
|
||||
"",
|
||||
"",
|
||||
true,
|
||||
},
|
||||
{
|
||||
"empty email",
|
||||
"",
|
||||
"1234567890",
|
||||
true,
|
||||
},
|
||||
{
|
||||
"invalid email",
|
||||
"invalid",
|
||||
"1234567890",
|
||||
true,
|
||||
},
|
||||
{
|
||||
"nonexisting superuser",
|
||||
"test_missing@example.com",
|
||||
"1234567890",
|
||||
true,
|
||||
},
|
||||
{
|
||||
"empty password",
|
||||
"test@example.com",
|
||||
"",
|
||||
true,
|
||||
},
|
||||
{
|
||||
"short password",
|
||||
"test_new@example.com",
|
||||
"1234567",
|
||||
true,
|
||||
},
|
||||
{
|
||||
"valid email and password",
|
||||
"test@example.com",
|
||||
"12345678",
|
||||
false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
t.Run(s.name, func(t *testing.T) {
|
||||
command := cmd.NewSuperuserCommand(app)
|
||||
command.SetArgs([]string{"update", s.email, s.password})
|
||||
|
||||
err := command.Execute()
|
||||
|
||||
hasErr := err != nil
|
||||
if s.expectError != hasErr {
|
||||
t.Fatalf("Expected hasErr %v, got %v (%v)", s.expectError, hasErr, err)
|
||||
}
|
||||
|
||||
if hasErr {
|
||||
return
|
||||
}
|
||||
|
||||
// check whether the superuser password was actually changed
|
||||
superuser, err := app.FindAuthRecordByEmail(core.CollectionNameSuperusers, s.email)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to fetch superuser %s: %v", s.email, err)
|
||||
} else if !superuser.ValidatePassword(s.password) {
|
||||
t.Fatal("Expected the superuser password to match")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSuperuserDeleteCommand(t *testing.T) {
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
scenarios := []struct {
|
||||
name string
|
||||
email string
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
"empty email",
|
||||
"",
|
||||
true,
|
||||
},
|
||||
{
|
||||
"invalid email",
|
||||
"invalid",
|
||||
true,
|
||||
},
|
||||
{
|
||||
"nonexisting superuser",
|
||||
"test_missing@example.com",
|
||||
false,
|
||||
},
|
||||
{
|
||||
"existing superuser",
|
||||
"test@example.com",
|
||||
false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
t.Run(s.name, func(t *testing.T) {
|
||||
command := cmd.NewSuperuserCommand(app)
|
||||
command.SetArgs([]string{"delete", s.email})
|
||||
|
||||
err := command.Execute()
|
||||
|
||||
hasErr := err != nil
|
||||
if s.expectError != hasErr {
|
||||
t.Fatalf("Expected hasErr %v, got %v (%v)", s.expectError, hasErr, err)
|
||||
}
|
||||
|
||||
if hasErr {
|
||||
return
|
||||
}
|
||||
|
||||
if _, err := app.FindAuthRecordByEmail(core.CollectionNameSuperusers, s.email); err == nil {
|
||||
t.Fatal("Expected the superuser account to be deleted")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSuperuserOTPCommand(t *testing.T) {
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
superusersCollection, err := app.FindCollectionByNameOrId(core.CollectionNameSuperusers)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// remove all existing otps
|
||||
otps, err := app.FindAllOTPsByCollection(superusersCollection)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
for _, otp := range otps {
|
||||
err = app.Delete(otp)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
scenarios := []struct {
|
||||
name string
|
||||
email string
|
||||
enabled bool
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
"empty email",
|
||||
"",
|
||||
true,
|
||||
true,
|
||||
},
|
||||
{
|
||||
"invalid email",
|
||||
"invalid",
|
||||
true,
|
||||
true,
|
||||
},
|
||||
{
|
||||
"nonexisting superuser",
|
||||
"test_missing@example.com",
|
||||
true,
|
||||
true,
|
||||
},
|
||||
{
|
||||
"existing superuser",
|
||||
"test@example.com",
|
||||
true,
|
||||
false,
|
||||
},
|
||||
{
|
||||
"existing superuser with disabled OTP",
|
||||
"test@example.com",
|
||||
false,
|
||||
true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
t.Run(s.name, func(t *testing.T) {
|
||||
command := cmd.NewSuperuserCommand(app)
|
||||
command.SetArgs([]string{"otp", s.email})
|
||||
|
||||
superusersCollection.OTP.Enabled = s.enabled
|
||||
if err = app.SaveNoValidate(superusersCollection); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err := command.Execute()
|
||||
|
||||
hasErr := err != nil
|
||||
if s.expectError != hasErr {
|
||||
t.Fatalf("Expected hasErr %v, got %v (%v)", s.expectError, hasErr, err)
|
||||
}
|
||||
|
||||
if hasErr {
|
||||
return
|
||||
}
|
||||
|
||||
superuser, err := app.FindAuthRecordByEmail(superusersCollection, s.email)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
otps, _ := app.FindAllOTPsByRecord(superuser)
|
||||
if total := len(otps); total != 1 {
|
||||
t.Fatalf("Expected 1 OTP, got %d", total)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
1538
core/app.go
Normal file
1538
core/app.go
Normal file
File diff suppressed because it is too large
Load diff
239
core/auth_origin_model.go
Normal file
239
core/auth_origin_model.go
Normal file
|
@ -0,0 +1,239 @@
|
|||
package core
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"slices"
|
||||
|
||||
validation "github.com/go-ozzo/ozzo-validation/v4"
|
||||
"github.com/pocketbase/dbx"
|
||||
"github.com/pocketbase/pocketbase/tools/hook"
|
||||
"github.com/pocketbase/pocketbase/tools/types"
|
||||
)
|
||||
|
||||
const CollectionNameAuthOrigins = "_authOrigins"
|
||||
|
||||
var (
|
||||
_ Model = (*AuthOrigin)(nil)
|
||||
_ PreValidator = (*AuthOrigin)(nil)
|
||||
_ RecordProxy = (*AuthOrigin)(nil)
|
||||
)
|
||||
|
||||
// AuthOrigin defines a Record proxy for working with the authOrigins collection.
|
||||
type AuthOrigin struct {
|
||||
*Record
|
||||
}
|
||||
|
||||
// NewAuthOrigin instantiates and returns a new blank *AuthOrigin model.
|
||||
//
|
||||
// Example usage:
|
||||
//
|
||||
// origin := core.NewOrigin(app)
|
||||
// origin.SetRecordRef(user.Id)
|
||||
// origin.SetCollectionRef(user.Collection().Id)
|
||||
// origin.SetFingerprint("...")
|
||||
// app.Save(origin)
|
||||
func NewAuthOrigin(app App) *AuthOrigin {
|
||||
m := &AuthOrigin{}
|
||||
|
||||
c, err := app.FindCachedCollectionByNameOrId(CollectionNameAuthOrigins)
|
||||
if err != nil {
|
||||
// this is just to make tests easier since authOrigins is a system collection and it is expected to be always accessible
|
||||
// (note: the loaded record is further checked on AuthOrigin.PreValidate())
|
||||
c = NewBaseCollection("@___invalid___")
|
||||
}
|
||||
|
||||
m.Record = NewRecord(c)
|
||||
|
||||
return m
|
||||
}
|
||||
|
||||
// PreValidate implements the [PreValidator] interface and checks
|
||||
// whether the proxy is properly loaded.
|
||||
func (m *AuthOrigin) PreValidate(ctx context.Context, app App) error {
|
||||
if m.Record == nil || m.Record.Collection().Name != CollectionNameAuthOrigins {
|
||||
return errors.New("missing or invalid AuthOrigin ProxyRecord")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ProxyRecord returns the proxied Record model.
|
||||
func (m *AuthOrigin) ProxyRecord() *Record {
|
||||
return m.Record
|
||||
}
|
||||
|
||||
// SetProxyRecord loads the specified record model into the current proxy.
|
||||
func (m *AuthOrigin) SetProxyRecord(record *Record) {
|
||||
m.Record = record
|
||||
}
|
||||
|
||||
// CollectionRef returns the "collectionRef" field value.
|
||||
func (m *AuthOrigin) CollectionRef() string {
|
||||
return m.GetString("collectionRef")
|
||||
}
|
||||
|
||||
// SetCollectionRef updates the "collectionRef" record field value.
|
||||
func (m *AuthOrigin) SetCollectionRef(collectionId string) {
|
||||
m.Set("collectionRef", collectionId)
|
||||
}
|
||||
|
||||
// RecordRef returns the "recordRef" record field value.
|
||||
func (m *AuthOrigin) RecordRef() string {
|
||||
return m.GetString("recordRef")
|
||||
}
|
||||
|
||||
// SetRecordRef updates the "recordRef" record field value.
|
||||
func (m *AuthOrigin) SetRecordRef(recordId string) {
|
||||
m.Set("recordRef", recordId)
|
||||
}
|
||||
|
||||
// Fingerprint returns the "fingerprint" record field value.
|
||||
func (m *AuthOrigin) Fingerprint() string {
|
||||
return m.GetString("fingerprint")
|
||||
}
|
||||
|
||||
// SetFingerprint updates the "fingerprint" record field value.
|
||||
func (m *AuthOrigin) SetFingerprint(fingerprint string) {
|
||||
m.Set("fingerprint", fingerprint)
|
||||
}
|
||||
|
||||
// Created returns the "created" record field value.
|
||||
func (m *AuthOrigin) Created() types.DateTime {
|
||||
return m.GetDateTime("created")
|
||||
}
|
||||
|
||||
// Updated returns the "updated" record field value.
|
||||
func (m *AuthOrigin) Updated() types.DateTime {
|
||||
return m.GetDateTime("updated")
|
||||
}
|
||||
|
||||
func (app *BaseApp) registerAuthOriginHooks() {
|
||||
recordRefHooks[*AuthOrigin](app, CollectionNameAuthOrigins, CollectionTypeAuth)
|
||||
|
||||
// delete existing auth origins on password change
|
||||
app.OnRecordUpdate().Bind(&hook.Handler[*RecordEvent]{
|
||||
Func: func(e *RecordEvent) error {
|
||||
err := e.Next()
|
||||
if err != nil || !e.Record.Collection().IsAuth() {
|
||||
return err
|
||||
}
|
||||
|
||||
old := e.Record.Original().GetString(FieldNamePassword + ":hash")
|
||||
new := e.Record.GetString(FieldNamePassword + ":hash")
|
||||
if old != new {
|
||||
err = e.App.DeleteAllAuthOriginsByRecord(e.Record)
|
||||
if err != nil {
|
||||
e.App.Logger().Warn(
|
||||
"Failed to delete all previous auth origin fingerprints",
|
||||
"error", err,
|
||||
"recordId", e.Record.Id,
|
||||
"collectionId", e.Record.Collection().Id,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
},
|
||||
Priority: 99,
|
||||
})
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
|
||||
// recordRefHooks registers common hooks that are usually used with record proxies
|
||||
// that have polymorphic record relations (aka. "collectionRef" and "recordRef" fields).
|
||||
func recordRefHooks[T RecordProxy](app App, collectionName string, optCollectionTypes ...string) {
|
||||
app.OnRecordValidate(collectionName).Bind(&hook.Handler[*RecordEvent]{
|
||||
Func: func(e *RecordEvent) error {
|
||||
collectionId := e.Record.GetString("collectionRef")
|
||||
err := validation.Validate(collectionId, validation.Required, validation.By(validateCollectionId(e.App, optCollectionTypes...)))
|
||||
if err != nil {
|
||||
return validation.Errors{"collectionRef": err}
|
||||
}
|
||||
|
||||
recordId := e.Record.GetString("recordRef")
|
||||
err = validation.Validate(recordId, validation.Required, validation.By(validateRecordId(e.App, collectionId)))
|
||||
if err != nil {
|
||||
return validation.Errors{"recordRef": err}
|
||||
}
|
||||
|
||||
return e.Next()
|
||||
},
|
||||
Priority: 99,
|
||||
})
|
||||
|
||||
// delete on collection ref delete
|
||||
app.OnCollectionDeleteExecute().Bind(&hook.Handler[*CollectionEvent]{
|
||||
Func: func(e *CollectionEvent) error {
|
||||
if e.Collection.Name == collectionName || (len(optCollectionTypes) > 0 && !slices.Contains(optCollectionTypes, e.Collection.Type)) {
|
||||
return e.Next()
|
||||
}
|
||||
|
||||
originalApp := e.App
|
||||
txErr := e.App.RunInTransaction(func(txApp App) error {
|
||||
e.App = txApp
|
||||
|
||||
if err := e.Next(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
rels, err := txApp.FindAllRecords(collectionName, dbx.HashExp{"collectionRef": e.Collection.Id})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, mfa := range rels {
|
||||
if err := txApp.Delete(mfa); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
e.App = originalApp
|
||||
|
||||
return txErr
|
||||
},
|
||||
Priority: 99,
|
||||
})
|
||||
|
||||
// delete on record ref delete
|
||||
app.OnRecordDeleteExecute().Bind(&hook.Handler[*RecordEvent]{
|
||||
Func: func(e *RecordEvent) error {
|
||||
if e.Record.Collection().Name == collectionName ||
|
||||
(len(optCollectionTypes) > 0 && !slices.Contains(optCollectionTypes, e.Record.Collection().Type)) {
|
||||
return e.Next()
|
||||
}
|
||||
|
||||
originalApp := e.App
|
||||
txErr := e.App.RunInTransaction(func(txApp App) error {
|
||||
e.App = txApp
|
||||
|
||||
if err := e.Next(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
rels, err := txApp.FindAllRecords(collectionName, dbx.HashExp{
|
||||
"collectionRef": e.Record.Collection().Id,
|
||||
"recordRef": e.Record.Id,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, rel := range rels {
|
||||
if err := txApp.Delete(rel); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
e.App = originalApp
|
||||
|
||||
return txErr
|
||||
},
|
||||
Priority: 99,
|
||||
})
|
||||
}
|
332
core/auth_origin_model_test.go
Normal file
332
core/auth_origin_model_test.go
Normal file
|
@ -0,0 +1,332 @@
|
|||
package core_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"slices"
|
||||
"testing"
|
||||
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tests"
|
||||
"github.com/pocketbase/pocketbase/tools/types"
|
||||
)
|
||||
|
||||
func TestNewAuthOrigin(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
origin := core.NewAuthOrigin(app)
|
||||
|
||||
if origin.Collection().Name != core.CollectionNameAuthOrigins {
|
||||
t.Fatalf("Expected record with %q collection, got %q", core.CollectionNameAuthOrigins, origin.Collection().Name)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthOriginProxyRecord(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
record := core.NewRecord(core.NewBaseCollection("test"))
|
||||
record.Id = "test_id"
|
||||
|
||||
origin := core.AuthOrigin{}
|
||||
origin.SetProxyRecord(record)
|
||||
|
||||
if origin.ProxyRecord() == nil || origin.ProxyRecord().Id != record.Id {
|
||||
t.Fatalf("Expected proxy record with id %q, got %v", record.Id, origin.ProxyRecord())
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthOriginRecordRef(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
origin := core.NewAuthOrigin(app)
|
||||
|
||||
testValues := []string{"test_1", "test2", ""}
|
||||
for i, testValue := range testValues {
|
||||
t.Run(fmt.Sprintf("%d_%q", i, testValue), func(t *testing.T) {
|
||||
origin.SetRecordRef(testValue)
|
||||
|
||||
if v := origin.RecordRef(); v != testValue {
|
||||
t.Fatalf("Expected getter %q, got %q", testValue, v)
|
||||
}
|
||||
|
||||
if v := origin.GetString("recordRef"); v != testValue {
|
||||
t.Fatalf("Expected field value %q, got %q", testValue, v)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthOriginCollectionRef(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
origin := core.NewAuthOrigin(app)
|
||||
|
||||
testValues := []string{"test_1", "test2", ""}
|
||||
for i, testValue := range testValues {
|
||||
t.Run(fmt.Sprintf("%d_%q", i, testValue), func(t *testing.T) {
|
||||
origin.SetCollectionRef(testValue)
|
||||
|
||||
if v := origin.CollectionRef(); v != testValue {
|
||||
t.Fatalf("Expected getter %q, got %q", testValue, v)
|
||||
}
|
||||
|
||||
if v := origin.GetString("collectionRef"); v != testValue {
|
||||
t.Fatalf("Expected field value %q, got %q", testValue, v)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthOriginFingerprint(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
origin := core.NewAuthOrigin(app)
|
||||
|
||||
testValues := []string{"test_1", "test2", ""}
|
||||
for i, testValue := range testValues {
|
||||
t.Run(fmt.Sprintf("%d_%q", i, testValue), func(t *testing.T) {
|
||||
origin.SetFingerprint(testValue)
|
||||
|
||||
if v := origin.Fingerprint(); v != testValue {
|
||||
t.Fatalf("Expected getter %q, got %q", testValue, v)
|
||||
}
|
||||
|
||||
if v := origin.GetString("fingerprint"); v != testValue {
|
||||
t.Fatalf("Expected field value %q, got %q", testValue, v)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthOriginCreated(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
origin := core.NewAuthOrigin(app)
|
||||
|
||||
if v := origin.Created().String(); v != "" {
|
||||
t.Fatalf("Expected empty created, got %q", v)
|
||||
}
|
||||
|
||||
now := types.NowDateTime()
|
||||
origin.SetRaw("created", now)
|
||||
|
||||
if v := origin.Created().String(); v != now.String() {
|
||||
t.Fatalf("Expected %q created, got %q", now.String(), v)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthOriginUpdated(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
origin := core.NewAuthOrigin(app)
|
||||
|
||||
if v := origin.Updated().String(); v != "" {
|
||||
t.Fatalf("Expected empty updated, got %q", v)
|
||||
}
|
||||
|
||||
now := types.NowDateTime()
|
||||
origin.SetRaw("updated", now)
|
||||
|
||||
if v := origin.Updated().String(); v != now.String() {
|
||||
t.Fatalf("Expected %q updated, got %q", now.String(), v)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthOriginPreValidate(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
originsCol, err := app.FindCollectionByNameOrId(core.CollectionNameAuthOrigins)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
user, err := app.FindAuthRecordByEmail("users", "test@example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
t.Run("no proxy record", func(t *testing.T) {
|
||||
origin := &core.AuthOrigin{}
|
||||
|
||||
if err := app.Validate(origin); err == nil {
|
||||
t.Fatal("Expected collection validation error")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("non-AuthOrigin collection", func(t *testing.T) {
|
||||
origin := &core.AuthOrigin{}
|
||||
origin.SetProxyRecord(core.NewRecord(core.NewBaseCollection("invalid")))
|
||||
origin.SetRecordRef(user.Id)
|
||||
origin.SetCollectionRef(user.Collection().Id)
|
||||
origin.SetFingerprint("abc")
|
||||
|
||||
if err := app.Validate(origin); err == nil {
|
||||
t.Fatal("Expected collection validation error")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("AuthOrigin collection", func(t *testing.T) {
|
||||
origin := &core.AuthOrigin{}
|
||||
origin.SetProxyRecord(core.NewRecord(originsCol))
|
||||
origin.SetRecordRef(user.Id)
|
||||
origin.SetCollectionRef(user.Collection().Id)
|
||||
origin.SetFingerprint("abc")
|
||||
|
||||
if err := app.Validate(origin); err != nil {
|
||||
t.Fatalf("Expected nil validation error, got %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestAuthOriginValidateHook(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
user, err := app.FindAuthRecordByEmail("users", "test@example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
demo1, err := app.FindRecordById("demo1", "84nmscqy84lsi1t")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
scenarios := []struct {
|
||||
name string
|
||||
origin func() *core.AuthOrigin
|
||||
expectErrors []string
|
||||
}{
|
||||
{
|
||||
"empty",
|
||||
func() *core.AuthOrigin {
|
||||
return core.NewAuthOrigin(app)
|
||||
},
|
||||
[]string{"collectionRef", "recordRef", "fingerprint"},
|
||||
},
|
||||
{
|
||||
"non-auth collection",
|
||||
func() *core.AuthOrigin {
|
||||
origin := core.NewAuthOrigin(app)
|
||||
origin.SetCollectionRef(demo1.Collection().Id)
|
||||
origin.SetRecordRef(demo1.Id)
|
||||
origin.SetFingerprint("abc")
|
||||
return origin
|
||||
},
|
||||
[]string{"collectionRef"},
|
||||
},
|
||||
{
|
||||
"missing record id",
|
||||
func() *core.AuthOrigin {
|
||||
origin := core.NewAuthOrigin(app)
|
||||
origin.SetCollectionRef(user.Collection().Id)
|
||||
origin.SetRecordRef("missing")
|
||||
origin.SetFingerprint("abc")
|
||||
return origin
|
||||
},
|
||||
[]string{"recordRef"},
|
||||
},
|
||||
{
|
||||
"valid ref",
|
||||
func() *core.AuthOrigin {
|
||||
origin := core.NewAuthOrigin(app)
|
||||
origin.SetCollectionRef(user.Collection().Id)
|
||||
origin.SetRecordRef(user.Id)
|
||||
origin.SetFingerprint("abc")
|
||||
return origin
|
||||
},
|
||||
[]string{},
|
||||
},
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
t.Run(s.name, func(t *testing.T) {
|
||||
errs := app.Validate(s.origin())
|
||||
tests.TestValidationErrors(t, errs, s.expectErrors)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthOriginPasswordChangeDeletion(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testApp, _ := tests.NewTestApp()
|
||||
defer testApp.Cleanup()
|
||||
|
||||
// no auth origin associated with it
|
||||
user1, err := testApp.FindAuthRecordByEmail("users", "test@example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
superuser2, err := testApp.FindAuthRecordByEmail(core.CollectionNameSuperusers, "test2@example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
client1, err := testApp.FindAuthRecordByEmail("clients", "test@example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
scenarios := []struct {
|
||||
record *core.Record
|
||||
deletedIds []string
|
||||
}{
|
||||
{user1, nil},
|
||||
{superuser2, []string{"5798yh833k6w6w0", "ic55o70g4f8pcl4", "dmy260k6ksjr4ib"}},
|
||||
{client1, []string{"9r2j0m74260ur8i"}},
|
||||
}
|
||||
|
||||
for i, s := range scenarios {
|
||||
t.Run(fmt.Sprintf("%d_%s_%s", i, s.record.Collection().Name, s.record.Id), func(t *testing.T) {
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
deletedIds := []string{}
|
||||
app.OnRecordDelete().BindFunc(func(e *core.RecordEvent) error {
|
||||
deletedIds = append(deletedIds, e.Record.Id)
|
||||
return e.Next()
|
||||
})
|
||||
|
||||
s.record.SetPassword("new_password")
|
||||
|
||||
err := app.Save(s.record)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if len(deletedIds) != len(s.deletedIds) {
|
||||
t.Fatalf("Expected deleted ids\n%v\ngot\n%v", s.deletedIds, deletedIds)
|
||||
}
|
||||
|
||||
for _, id := range s.deletedIds {
|
||||
if !slices.Contains(deletedIds, id) {
|
||||
t.Errorf("Expected to find deleted id %q in %v", id, deletedIds)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
101
core/auth_origin_query.go
Normal file
101
core/auth_origin_query.go
Normal file
|
@ -0,0 +1,101 @@
|
|||
package core
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"github.com/pocketbase/dbx"
|
||||
)
|
||||
|
||||
// FindAllAuthOriginsByRecord returns all AuthOrigin models linked to the provided auth record (in DESC order).
|
||||
func (app *BaseApp) FindAllAuthOriginsByRecord(authRecord *Record) ([]*AuthOrigin, error) {
|
||||
result := []*AuthOrigin{}
|
||||
|
||||
err := app.RecordQuery(CollectionNameAuthOrigins).
|
||||
AndWhere(dbx.HashExp{
|
||||
"collectionRef": authRecord.Collection().Id,
|
||||
"recordRef": authRecord.Id,
|
||||
}).
|
||||
OrderBy("created DESC").
|
||||
All(&result)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// FindAllAuthOriginsByCollection returns all AuthOrigin models linked to the provided collection (in DESC order).
|
||||
func (app *BaseApp) FindAllAuthOriginsByCollection(collection *Collection) ([]*AuthOrigin, error) {
|
||||
result := []*AuthOrigin{}
|
||||
|
||||
err := app.RecordQuery(CollectionNameAuthOrigins).
|
||||
AndWhere(dbx.HashExp{"collectionRef": collection.Id}).
|
||||
OrderBy("created DESC").
|
||||
All(&result)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// FindAuthOriginById returns a single AuthOrigin model by its id.
|
||||
func (app *BaseApp) FindAuthOriginById(id string) (*AuthOrigin, error) {
|
||||
result := &AuthOrigin{}
|
||||
|
||||
err := app.RecordQuery(CollectionNameAuthOrigins).
|
||||
AndWhere(dbx.HashExp{"id": id}).
|
||||
Limit(1).
|
||||
One(result)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// FindAuthOriginByRecordAndFingerprint returns a single AuthOrigin model
|
||||
// by its authRecord relation and fingerprint.
|
||||
func (app *BaseApp) FindAuthOriginByRecordAndFingerprint(authRecord *Record, fingerprint string) (*AuthOrigin, error) {
|
||||
result := &AuthOrigin{}
|
||||
|
||||
err := app.RecordQuery(CollectionNameAuthOrigins).
|
||||
AndWhere(dbx.HashExp{
|
||||
"collectionRef": authRecord.Collection().Id,
|
||||
"recordRef": authRecord.Id,
|
||||
"fingerprint": fingerprint,
|
||||
}).
|
||||
Limit(1).
|
||||
One(result)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// DeleteAllAuthOriginsByRecord deletes all AuthOrigin models associated with the provided record.
|
||||
//
|
||||
// Returns a combined error with the failed deletes.
|
||||
func (app *BaseApp) DeleteAllAuthOriginsByRecord(authRecord *Record) error {
|
||||
models, err := app.FindAllAuthOriginsByRecord(authRecord)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var errs []error
|
||||
for _, m := range models {
|
||||
if err := app.Delete(m); err != nil {
|
||||
errs = append(errs, err)
|
||||
}
|
||||
}
|
||||
if len(errs) > 0 {
|
||||
return errors.Join(errs...)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
268
core/auth_origin_query_test.go
Normal file
268
core/auth_origin_query_test.go
Normal file
|
@ -0,0 +1,268 @@
|
|||
package core_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"slices"
|
||||
"testing"
|
||||
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tests"
|
||||
)
|
||||
|
||||
func TestFindAllAuthOriginsByRecord(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
demo1, err := app.FindRecordById("demo1", "84nmscqy84lsi1t")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
superuser2, err := app.FindAuthRecordByEmail(core.CollectionNameSuperusers, "test2@example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
superuser4, err := app.FindAuthRecordByEmail(core.CollectionNameSuperusers, "test4@example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
client1, err := app.FindAuthRecordByEmail("clients", "test@example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
scenarios := []struct {
|
||||
record *core.Record
|
||||
expected []string
|
||||
}{
|
||||
{demo1, nil},
|
||||
{superuser2, []string{"5798yh833k6w6w0", "ic55o70g4f8pcl4", "dmy260k6ksjr4ib"}},
|
||||
{superuser4, nil},
|
||||
{client1, []string{"9r2j0m74260ur8i"}},
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
t.Run(s.record.Collection().Name+"_"+s.record.Id, func(t *testing.T) {
|
||||
result, err := app.FindAllAuthOriginsByRecord(s.record)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if len(result) != len(s.expected) {
|
||||
t.Fatalf("Expected total origins %d, got %d", len(s.expected), len(result))
|
||||
}
|
||||
|
||||
for i, id := range s.expected {
|
||||
if result[i].Id != id {
|
||||
t.Errorf("[%d] Expected id %q, got %q", i, id, result[i].Id)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFindAllAuthOriginsByCollection(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
demo1, err := app.FindCollectionByNameOrId("demo1")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
superusers, err := app.FindCollectionByNameOrId(core.CollectionNameSuperusers)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
clients, err := app.FindCollectionByNameOrId("clients")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
scenarios := []struct {
|
||||
collection *core.Collection
|
||||
expected []string
|
||||
}{
|
||||
{demo1, nil},
|
||||
{superusers, []string{"5798yh833k6w6w0", "ic55o70g4f8pcl4", "dmy260k6ksjr4ib", "5f29jy38bf5zm3f"}},
|
||||
{clients, []string{"9r2j0m74260ur8i"}},
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
t.Run(s.collection.Name, func(t *testing.T) {
|
||||
result, err := app.FindAllAuthOriginsByCollection(s.collection)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if len(result) != len(s.expected) {
|
||||
t.Fatalf("Expected total origins %d, got %d", len(s.expected), len(result))
|
||||
}
|
||||
|
||||
for i, id := range s.expected {
|
||||
if result[i].Id != id {
|
||||
t.Errorf("[%d] Expected id %q, got %q", i, id, result[i].Id)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFindAuthOriginById(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
scenarios := []struct {
|
||||
id string
|
||||
expectError bool
|
||||
}{
|
||||
{"", true},
|
||||
{"84nmscqy84lsi1t", true}, // non-origin id
|
||||
{"9r2j0m74260ur8i", false},
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
t.Run(s.id, func(t *testing.T) {
|
||||
result, err := app.FindAuthOriginById(s.id)
|
||||
|
||||
hasErr := err != nil
|
||||
if hasErr != s.expectError {
|
||||
t.Fatalf("Expected hasErr %v, got %v (%v)", s.expectError, hasErr, err)
|
||||
}
|
||||
|
||||
if hasErr {
|
||||
return
|
||||
}
|
||||
|
||||
if result.Id != s.id {
|
||||
t.Fatalf("Expected record with id %q, got %q", s.id, result.Id)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFindAuthOriginByRecordAndFingerprint(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
demo1, err := app.FindRecordById("demo1", "84nmscqy84lsi1t")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
superuser2, err := app.FindAuthRecordByEmail(core.CollectionNameSuperusers, "test2@example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
scenarios := []struct {
|
||||
record *core.Record
|
||||
fingerprint string
|
||||
expectError bool
|
||||
}{
|
||||
{demo1, "6afbfe481c31c08c55a746cccb88ece0", true},
|
||||
{superuser2, "", true},
|
||||
{superuser2, "abc", true},
|
||||
{superuser2, "22bbbcbed36e25321f384ccf99f60057", false}, // fingerprint from different origin
|
||||
{superuser2, "6afbfe481c31c08c55a746cccb88ece0", false},
|
||||
}
|
||||
|
||||
for i, s := range scenarios {
|
||||
t.Run(fmt.Sprintf("%d_%s_%s", i, s.record.Id, s.fingerprint), func(t *testing.T) {
|
||||
result, err := app.FindAuthOriginByRecordAndFingerprint(s.record, s.fingerprint)
|
||||
|
||||
hasErr := err != nil
|
||||
if hasErr != s.expectError {
|
||||
t.Fatalf("Expected hasErr %v, got %v (%v)", s.expectError, hasErr, err)
|
||||
}
|
||||
|
||||
if hasErr {
|
||||
return
|
||||
}
|
||||
|
||||
if result.Fingerprint() != s.fingerprint {
|
||||
t.Fatalf("Expected origin with fingerprint %q, got %q", s.fingerprint, result.Fingerprint())
|
||||
}
|
||||
|
||||
if result.RecordRef() != s.record.Id || result.CollectionRef() != s.record.Collection().Id {
|
||||
t.Fatalf("Expected record %q (%q), got %q (%q)", s.record.Id, s.record.Collection().Id, result.RecordRef(), result.CollectionRef())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteAllAuthOriginsByRecord(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testApp, _ := tests.NewTestApp()
|
||||
defer testApp.Cleanup()
|
||||
|
||||
demo1, err := testApp.FindRecordById("demo1", "84nmscqy84lsi1t")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
superuser2, err := testApp.FindAuthRecordByEmail(core.CollectionNameSuperusers, "test2@example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
superuser4, err := testApp.FindAuthRecordByEmail(core.CollectionNameSuperusers, "test4@example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
client1, err := testApp.FindAuthRecordByEmail("clients", "test@example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
scenarios := []struct {
|
||||
record *core.Record
|
||||
deletedIds []string
|
||||
}{
|
||||
{demo1, nil}, // non-auth record
|
||||
{superuser2, []string{"5798yh833k6w6w0", "ic55o70g4f8pcl4", "dmy260k6ksjr4ib"}},
|
||||
{superuser4, nil},
|
||||
{client1, []string{"9r2j0m74260ur8i"}},
|
||||
}
|
||||
|
||||
for i, s := range scenarios {
|
||||
t.Run(fmt.Sprintf("%d_%s_%s", i, s.record.Collection().Name, s.record.Id), func(t *testing.T) {
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
deletedIds := []string{}
|
||||
app.OnRecordDelete().BindFunc(func(e *core.RecordEvent) error {
|
||||
deletedIds = append(deletedIds, e.Record.Id)
|
||||
return e.Next()
|
||||
})
|
||||
|
||||
err := app.DeleteAllAuthOriginsByRecord(s.record)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if len(deletedIds) != len(s.deletedIds) {
|
||||
t.Fatalf("Expected deleted ids\n%v\ngot\n%v", s.deletedIds, deletedIds)
|
||||
}
|
||||
|
||||
for _, id := range s.deletedIds {
|
||||
if !slices.Contains(deletedIds, id) {
|
||||
t.Errorf("Expected to find deleted id %q in %v", id, deletedIds)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
1535
core/base.go
Normal file
1535
core/base.go
Normal file
File diff suppressed because it is too large
Load diff
389
core/base_backup.go
Normal file
389
core/base_backup.go
Normal file
|
@ -0,0 +1,389 @@
|
|||
package core
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"sort"
|
||||
"time"
|
||||
|
||||
"github.com/pocketbase/pocketbase/tools/archive"
|
||||
"github.com/pocketbase/pocketbase/tools/filesystem"
|
||||
"github.com/pocketbase/pocketbase/tools/inflector"
|
||||
"github.com/pocketbase/pocketbase/tools/osutils"
|
||||
"github.com/pocketbase/pocketbase/tools/security"
|
||||
)
|
||||
|
||||
const (
|
||||
StoreKeyActiveBackup = "@activeBackup"
|
||||
)
|
||||
|
||||
// CreateBackup creates a new backup of the current app pb_data directory.
|
||||
//
|
||||
// If name is empty, it will be autogenerated.
|
||||
// If backup with the same name exists, the new backup file will replace it.
|
||||
//
|
||||
// The backup is executed within a transaction, meaning that new writes
|
||||
// will be temporary "blocked" until the backup file is generated.
|
||||
//
|
||||
// To safely perform the backup, it is recommended to have free disk space
|
||||
// for at least 2x the size of the pb_data directory.
|
||||
//
|
||||
// By default backups are stored in pb_data/backups
|
||||
// (the backups directory itself is excluded from the generated backup).
|
||||
//
|
||||
// When using S3 storage for the uploaded collection files, you have to
|
||||
// take care manually to backup those since they are not part of the pb_data.
|
||||
//
|
||||
// Backups can be stored on S3 if it is configured in app.Settings().Backups.
|
||||
func (app *BaseApp) CreateBackup(ctx context.Context, name string) error {
|
||||
if app.Store().Has(StoreKeyActiveBackup) {
|
||||
return errors.New("try again later - another backup/restore operation has already been started")
|
||||
}
|
||||
|
||||
app.Store().Set(StoreKeyActiveBackup, name)
|
||||
defer app.Store().Remove(StoreKeyActiveBackup)
|
||||
|
||||
event := new(BackupEvent)
|
||||
event.App = app
|
||||
event.Context = ctx
|
||||
event.Name = name
|
||||
// default root dir entries to exclude from the backup generation
|
||||
event.Exclude = []string{LocalBackupsDirName, LocalTempDirName, LocalAutocertCacheDirName}
|
||||
|
||||
return app.OnBackupCreate().Trigger(event, func(e *BackupEvent) error {
|
||||
// generate a default name if missing
|
||||
if e.Name == "" {
|
||||
e.Name = generateBackupName(e.App, "pb_backup_")
|
||||
}
|
||||
|
||||
// make sure that the special temp directory exists
|
||||
// note: it needs to be inside the current pb_data to avoid "cross-device link" errors
|
||||
localTempDir := filepath.Join(e.App.DataDir(), LocalTempDirName)
|
||||
if err := os.MkdirAll(localTempDir, os.ModePerm); err != nil {
|
||||
return fmt.Errorf("failed to create a temp dir: %w", err)
|
||||
}
|
||||
|
||||
// archive pb_data in a temp directory, exluding the "backups" and the temp dirs
|
||||
//
|
||||
// Run in transaction to temporary block other writes (transactions uses the NonconcurrentDB connection).
|
||||
// ---
|
||||
tempPath := filepath.Join(localTempDir, "pb_backup_"+security.PseudorandomString(6))
|
||||
createErr := e.App.RunInTransaction(func(txApp App) error {
|
||||
return txApp.AuxRunInTransaction(func(txApp App) error {
|
||||
// run manual checkpoint and truncate the WAL files
|
||||
// (errors are ignored because it is not that important and the PRAGMA may not be supported by the used driver)
|
||||
txApp.DB().NewQuery("PRAGMA wal_checkpoint(TRUNCATE)").Execute()
|
||||
txApp.AuxDB().NewQuery("PRAGMA wal_checkpoint(TRUNCATE)").Execute()
|
||||
|
||||
return archive.Create(txApp.DataDir(), tempPath, e.Exclude...)
|
||||
})
|
||||
})
|
||||
if createErr != nil {
|
||||
return createErr
|
||||
}
|
||||
defer os.Remove(tempPath)
|
||||
|
||||
// persist the backup in the backups filesystem
|
||||
// ---
|
||||
fsys, err := e.App.NewBackupsFilesystem()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer fsys.Close()
|
||||
|
||||
fsys.SetContext(e.Context)
|
||||
|
||||
file, err := filesystem.NewFileFromPath(tempPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
file.OriginalName = e.Name
|
||||
file.Name = file.OriginalName
|
||||
|
||||
if err := fsys.UploadFile(file, file.Name); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// RestoreBackup restores the backup with the specified name and restarts
|
||||
// the current running application process.
|
||||
//
|
||||
// NB! This feature is experimental and currently is expected to work only on UNIX based systems.
|
||||
//
|
||||
// To safely perform the restore it is recommended to have free disk space
|
||||
// for at least 2x the size of the restored pb_data backup.
|
||||
//
|
||||
// The performed steps are:
|
||||
//
|
||||
// 1. Download the backup with the specified name in a temp location
|
||||
// (this is in case of S3; otherwise it creates a temp copy of the zip)
|
||||
//
|
||||
// 2. Extract the backup in a temp directory inside the app "pb_data"
|
||||
// (eg. "pb_data/.pb_temp_to_delete/pb_restore").
|
||||
//
|
||||
// 3. Move the current app "pb_data" content (excluding the local backups and the special temp dir)
|
||||
// under another temp sub dir that will be deleted on the next app start up
|
||||
// (eg. "pb_data/.pb_temp_to_delete/old_pb_data").
|
||||
// This is because on some environments it may not be allowed
|
||||
// to delete the currently open "pb_data" files.
|
||||
//
|
||||
// 4. Move the extracted dir content to the app "pb_data".
|
||||
//
|
||||
// 5. Restart the app (on successful app bootstap it will also remove the old pb_data).
|
||||
//
|
||||
// If a failure occure during the restore process the dir changes are reverted.
|
||||
// If for whatever reason the revert is not possible, it panics.
|
||||
//
|
||||
// Note that if your pb_data has custom network mounts as subdirectories, then
|
||||
// it is possible the restore to fail during the `os.Rename` operations
|
||||
// (see https://github.com/pocketbase/pocketbase/issues/4647).
|
||||
func (app *BaseApp) RestoreBackup(ctx context.Context, name string) error {
|
||||
if app.Store().Has(StoreKeyActiveBackup) {
|
||||
return errors.New("try again later - another backup/restore operation has already been started")
|
||||
}
|
||||
|
||||
app.Store().Set(StoreKeyActiveBackup, name)
|
||||
defer app.Store().Remove(StoreKeyActiveBackup)
|
||||
|
||||
event := new(BackupEvent)
|
||||
event.App = app
|
||||
event.Context = ctx
|
||||
event.Name = name
|
||||
// default root dir entries to exclude from the backup restore
|
||||
event.Exclude = []string{LocalBackupsDirName, LocalTempDirName, LocalAutocertCacheDirName}
|
||||
|
||||
return app.OnBackupRestore().Trigger(event, func(e *BackupEvent) error {
|
||||
if runtime.GOOS == "windows" {
|
||||
return errors.New("restore is not supported on Windows")
|
||||
}
|
||||
|
||||
// make sure that the special temp directory exists
|
||||
// note: it needs to be inside the current pb_data to avoid "cross-device link" errors
|
||||
localTempDir := filepath.Join(e.App.DataDir(), LocalTempDirName)
|
||||
if err := os.MkdirAll(localTempDir, os.ModePerm); err != nil {
|
||||
return fmt.Errorf("failed to create a temp dir: %w", err)
|
||||
}
|
||||
|
||||
fsys, err := e.App.NewBackupsFilesystem()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer fsys.Close()
|
||||
|
||||
fsys.SetContext(e.Context)
|
||||
|
||||
if ok, _ := fsys.Exists(name); !ok {
|
||||
return fmt.Errorf("missing or invalid backup file %q to restore", name)
|
||||
}
|
||||
|
||||
extractedDataDir := filepath.Join(localTempDir, "pb_restore_"+security.PseudorandomString(8))
|
||||
defer os.RemoveAll(extractedDataDir)
|
||||
|
||||
// extract the zip
|
||||
if e.App.Settings().Backups.S3.Enabled {
|
||||
br, err := fsys.GetReader(name)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer br.Close()
|
||||
|
||||
// create a temp zip file from the blob.Reader and try to extract it
|
||||
tempZip, err := os.CreateTemp(localTempDir, "pb_restore_zip")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer os.Remove(tempZip.Name())
|
||||
defer tempZip.Close() // note: this technically shouldn't be necessary but it is here to workaround platforms discrepancies
|
||||
|
||||
_, err = io.Copy(tempZip, br)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = archive.Extract(tempZip.Name(), extractedDataDir)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// remove the temp zip file since we no longer need it
|
||||
// (this is in case the app restarts and the defer calls are not called)
|
||||
_ = tempZip.Close()
|
||||
err = os.Remove(tempZip.Name())
|
||||
if err != nil {
|
||||
e.App.Logger().Warn(
|
||||
"[RestoreBackup] Failed to remove the temp zip backup file",
|
||||
slog.String("file", tempZip.Name()),
|
||||
slog.String("error", err.Error()),
|
||||
)
|
||||
}
|
||||
} else {
|
||||
// manually construct the local path to avoid creating a copy of the zip file
|
||||
// since the blob reader currently doesn't implement ReaderAt
|
||||
zipPath := filepath.Join(app.DataDir(), LocalBackupsDirName, filepath.Base(name))
|
||||
|
||||
err = archive.Extract(zipPath, extractedDataDir)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// ensure that at least a database file exists
|
||||
extractedDB := filepath.Join(extractedDataDir, "data.db")
|
||||
if _, err := os.Stat(extractedDB); err != nil {
|
||||
return fmt.Errorf("data.db file is missing or invalid: %w", err)
|
||||
}
|
||||
|
||||
// move the current pb_data content to a special temp location
|
||||
// that will hold the old data between dirs replace
|
||||
// (the temp dir will be automatically removed on the next app start)
|
||||
oldTempDataDir := filepath.Join(localTempDir, "old_pb_data_"+security.PseudorandomString(8))
|
||||
if err := osutils.MoveDirContent(e.App.DataDir(), oldTempDataDir, e.Exclude...); err != nil {
|
||||
return fmt.Errorf("failed to move the current pb_data content to a temp location: %w", err)
|
||||
}
|
||||
|
||||
// move the extracted archive content to the app's pb_data
|
||||
if err := osutils.MoveDirContent(extractedDataDir, e.App.DataDir(), e.Exclude...); err != nil {
|
||||
return fmt.Errorf("failed to move the extracted archive content to pb_data: %w", err)
|
||||
}
|
||||
|
||||
revertDataDirChanges := func() error {
|
||||
if err := osutils.MoveDirContent(e.App.DataDir(), extractedDataDir, e.Exclude...); err != nil {
|
||||
return fmt.Errorf("failed to revert the extracted dir change: %w", err)
|
||||
}
|
||||
|
||||
if err := osutils.MoveDirContent(oldTempDataDir, e.App.DataDir(), e.Exclude...); err != nil {
|
||||
return fmt.Errorf("failed to revert old pb_data dir change: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// restart the app
|
||||
if err := e.App.Restart(); err != nil {
|
||||
if revertErr := revertDataDirChanges(); revertErr != nil {
|
||||
panic(revertErr)
|
||||
}
|
||||
|
||||
return fmt.Errorf("failed to restart the app process: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// registerAutobackupHooks registers the autobackup app serve hooks.
|
||||
func (app *BaseApp) registerAutobackupHooks() {
|
||||
const jobId = "__pbAutoBackup__"
|
||||
|
||||
loadJob := func() {
|
||||
rawSchedule := app.Settings().Backups.Cron
|
||||
if rawSchedule == "" {
|
||||
app.Cron().Remove(jobId)
|
||||
return
|
||||
}
|
||||
|
||||
app.Cron().Add(jobId, rawSchedule, func() {
|
||||
const autoPrefix = "@auto_pb_backup_"
|
||||
|
||||
name := generateBackupName(app, autoPrefix)
|
||||
|
||||
if err := app.CreateBackup(context.Background(), name); err != nil {
|
||||
app.Logger().Error(
|
||||
"[Backup cron] Failed to create backup",
|
||||
slog.String("name", name),
|
||||
slog.String("error", err.Error()),
|
||||
)
|
||||
}
|
||||
|
||||
maxKeep := app.Settings().Backups.CronMaxKeep
|
||||
|
||||
if maxKeep == 0 {
|
||||
return // no explicit limit
|
||||
}
|
||||
|
||||
fsys, err := app.NewBackupsFilesystem()
|
||||
if err != nil {
|
||||
app.Logger().Error(
|
||||
"[Backup cron] Failed to initialize the backup filesystem",
|
||||
slog.String("error", err.Error()),
|
||||
)
|
||||
return
|
||||
}
|
||||
defer fsys.Close()
|
||||
|
||||
files, err := fsys.List(autoPrefix)
|
||||
if err != nil {
|
||||
app.Logger().Error(
|
||||
"[Backup cron] Failed to list autogenerated backups",
|
||||
slog.String("error", err.Error()),
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
if maxKeep >= len(files) {
|
||||
return // nothing to remove
|
||||
}
|
||||
|
||||
// sort desc
|
||||
sort.Slice(files, func(i, j int) bool {
|
||||
return files[i].ModTime.After(files[j].ModTime)
|
||||
})
|
||||
|
||||
// keep only the most recent n auto backup files
|
||||
toRemove := files[maxKeep:]
|
||||
|
||||
for _, f := range toRemove {
|
||||
if err := fsys.Delete(f.Key); err != nil {
|
||||
app.Logger().Error(
|
||||
"[Backup cron] Failed to remove old autogenerated backup",
|
||||
slog.String("key", f.Key),
|
||||
slog.String("error", err.Error()),
|
||||
)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
app.OnBootstrap().BindFunc(func(e *BootstrapEvent) error {
|
||||
if err := e.Next(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
loadJob()
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
app.OnSettingsReload().BindFunc(func(e *SettingsReloadEvent) error {
|
||||
if err := e.Next(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
loadJob()
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func generateBackupName(app App, prefix string) string {
|
||||
appName := inflector.Snakecase(app.Settings().Meta.AppName)
|
||||
if len(appName) > 50 {
|
||||
appName = appName[:50]
|
||||
}
|
||||
|
||||
return fmt.Sprintf(
|
||||
"%s%s_%s.zip",
|
||||
prefix,
|
||||
appName,
|
||||
time.Now().UTC().Format("20060102150405"),
|
||||
)
|
||||
}
|
164
core/base_backup_test.go
Normal file
164
core/base_backup_test.go
Normal file
|
@ -0,0 +1,164 @@
|
|||
package core_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tests"
|
||||
"github.com/pocketbase/pocketbase/tools/archive"
|
||||
"github.com/pocketbase/pocketbase/tools/list"
|
||||
)
|
||||
|
||||
func TestCreateBackup(t *testing.T) {
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
// set some long app name with spaces and special characters
|
||||
app.Settings().Meta.AppName = "test @! " + strings.Repeat("a", 100)
|
||||
|
||||
expectedAppNamePrefix := "test_" + strings.Repeat("a", 45)
|
||||
|
||||
// test pending error
|
||||
app.Store().Set(core.StoreKeyActiveBackup, "")
|
||||
if err := app.CreateBackup(context.Background(), "test.zip"); err == nil {
|
||||
t.Fatal("Expected pending error, got nil")
|
||||
}
|
||||
app.Store().Remove(core.StoreKeyActiveBackup)
|
||||
|
||||
// create with auto generated name
|
||||
if err := app.CreateBackup(context.Background(), ""); err != nil {
|
||||
t.Fatal("Failed to create a backup with autogenerated name")
|
||||
}
|
||||
|
||||
// create with custom name
|
||||
if err := app.CreateBackup(context.Background(), "custom"); err != nil {
|
||||
t.Fatal("Failed to create a backup with custom name")
|
||||
}
|
||||
|
||||
// create new with the same name (aka. replace)
|
||||
if err := app.CreateBackup(context.Background(), "custom"); err != nil {
|
||||
t.Fatal("Failed to create and replace a backup with the same name")
|
||||
}
|
||||
|
||||
backupsDir := filepath.Join(app.DataDir(), core.LocalBackupsDirName)
|
||||
|
||||
entries, err := os.ReadDir(backupsDir)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
expectedFiles := []string{
|
||||
`^pb_backup_` + expectedAppNamePrefix + `_\w+\.zip$`,
|
||||
`^pb_backup_` + expectedAppNamePrefix + `_\w+\.zip.attrs$`,
|
||||
"custom",
|
||||
"custom.attrs",
|
||||
}
|
||||
|
||||
if len(entries) != len(expectedFiles) {
|
||||
names := getEntryNames(entries)
|
||||
t.Fatalf("Expected %d backup files, got %d: \n%v", len(expectedFiles), len(entries), names)
|
||||
}
|
||||
|
||||
for i, entry := range entries {
|
||||
if !list.ExistInSliceWithRegex(entry.Name(), expectedFiles) {
|
||||
t.Fatalf("[%d] Missing backup file %q", i, entry.Name())
|
||||
}
|
||||
|
||||
if strings.HasSuffix(entry.Name(), ".attrs") {
|
||||
continue
|
||||
}
|
||||
|
||||
path := filepath.Join(backupsDir, entry.Name())
|
||||
|
||||
if err := verifyBackupContent(app, path); err != nil {
|
||||
t.Fatalf("[%d] Failed to verify backup content: %v", i, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRestoreBackup(t *testing.T) {
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
// create a initial test backup to ensure that there are at least 1
|
||||
// backup file and that the generated zip doesn't contain the backups dir
|
||||
if err := app.CreateBackup(context.Background(), "initial"); err != nil {
|
||||
t.Fatal("Failed to create test initial backup")
|
||||
}
|
||||
|
||||
// create test backup
|
||||
if err := app.CreateBackup(context.Background(), "test"); err != nil {
|
||||
t.Fatal("Failed to create test backup")
|
||||
}
|
||||
|
||||
// test pending error
|
||||
app.Store().Set(core.StoreKeyActiveBackup, "")
|
||||
if err := app.RestoreBackup(context.Background(), "test"); err == nil {
|
||||
t.Fatal("Expected pending error, got nil")
|
||||
}
|
||||
app.Store().Remove(core.StoreKeyActiveBackup)
|
||||
|
||||
// missing backup
|
||||
if err := app.RestoreBackup(context.Background(), "missing"); err == nil {
|
||||
t.Fatal("Expected missing error, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
|
||||
func verifyBackupContent(app core.App, path string) error {
|
||||
dir, err := os.MkdirTemp("", "backup_test")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer os.RemoveAll(dir)
|
||||
|
||||
if err := archive.Extract(path, dir); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
expectedRootEntries := []string{
|
||||
"storage",
|
||||
"data.db",
|
||||
"data.db-shm",
|
||||
"data.db-wal",
|
||||
"auxiliary.db",
|
||||
"auxiliary.db-shm",
|
||||
"auxiliary.db-wal",
|
||||
".gitignore",
|
||||
}
|
||||
|
||||
entries, err := os.ReadDir(dir)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(entries) != len(expectedRootEntries) {
|
||||
names := getEntryNames(entries)
|
||||
return fmt.Errorf("Expected %d backup files, got %d: \n%v", len(expectedRootEntries), len(entries), names)
|
||||
}
|
||||
|
||||
for _, entry := range entries {
|
||||
if !list.ExistInSliceWithRegex(entry.Name(), expectedRootEntries) {
|
||||
return fmt.Errorf("Didn't expect %q entry", entry.Name())
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func getEntryNames(entries []fs.DirEntry) []string {
|
||||
names := make([]string, len(entries))
|
||||
|
||||
for i, entry := range entries {
|
||||
names[i] = entry.Name()
|
||||
}
|
||||
|
||||
return names
|
||||
}
|
554
core/base_test.go
Normal file
554
core/base_test.go
Normal file
|
@ -0,0 +1,554 @@
|
|||
package core_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"log/slog"
|
||||
"os"
|
||||
"slices"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
_ "unsafe"
|
||||
|
||||
"github.com/pocketbase/dbx"
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tests"
|
||||
"github.com/pocketbase/pocketbase/tools/logger"
|
||||
"github.com/pocketbase/pocketbase/tools/mailer"
|
||||
)
|
||||
|
||||
func TestNewBaseApp(t *testing.T) {
|
||||
const testDataDir = "./pb_base_app_test_data_dir/"
|
||||
defer os.RemoveAll(testDataDir)
|
||||
|
||||
app := core.NewBaseApp(core.BaseAppConfig{
|
||||
DataDir: testDataDir,
|
||||
EncryptionEnv: "test_env",
|
||||
IsDev: true,
|
||||
})
|
||||
|
||||
if app.DataDir() != testDataDir {
|
||||
t.Fatalf("expected DataDir %q, got %q", testDataDir, app.DataDir())
|
||||
}
|
||||
|
||||
if app.EncryptionEnv() != "test_env" {
|
||||
t.Fatalf("expected EncryptionEnv test_env, got %q", app.EncryptionEnv())
|
||||
}
|
||||
|
||||
if !app.IsDev() {
|
||||
t.Fatalf("expected IsDev true, got %v", app.IsDev())
|
||||
}
|
||||
|
||||
if app.Store() == nil {
|
||||
t.Fatal("expected Store to be set, got nil")
|
||||
}
|
||||
|
||||
if app.Settings() == nil {
|
||||
t.Fatal("expected Settings to be set, got nil")
|
||||
}
|
||||
|
||||
if app.SubscriptionsBroker() == nil {
|
||||
t.Fatal("expected SubscriptionsBroker to be set, got nil")
|
||||
}
|
||||
|
||||
if app.Cron() == nil {
|
||||
t.Fatal("expected Cron to be set, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBaseAppBootstrap(t *testing.T) {
|
||||
const testDataDir = "./pb_base_app_test_data_dir/"
|
||||
defer os.RemoveAll(testDataDir)
|
||||
|
||||
app := core.NewBaseApp(core.BaseAppConfig{
|
||||
DataDir: testDataDir,
|
||||
})
|
||||
defer app.ResetBootstrapState()
|
||||
|
||||
if app.IsBootstrapped() {
|
||||
t.Fatal("Didn't expect the application to be bootstrapped.")
|
||||
}
|
||||
|
||||
if err := app.Bootstrap(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if !app.IsBootstrapped() {
|
||||
t.Fatal("Expected the application to be bootstrapped.")
|
||||
}
|
||||
|
||||
if stat, err := os.Stat(testDataDir); err != nil || !stat.IsDir() {
|
||||
t.Fatal("Expected test data directory to be created.")
|
||||
}
|
||||
|
||||
type nilCheck struct {
|
||||
name string
|
||||
value any
|
||||
expectNil bool
|
||||
}
|
||||
|
||||
runNilChecks := func(checks []nilCheck) {
|
||||
for _, check := range checks {
|
||||
t.Run(check.name, func(t *testing.T) {
|
||||
isNil := check.value == nil
|
||||
if isNil != check.expectNil {
|
||||
t.Fatalf("Expected isNil %v, got %v", check.expectNil, isNil)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
nilChecksBeforeReset := []nilCheck{
|
||||
{"[before] db", app.DB(), false},
|
||||
{"[before] concurrentDB", app.ConcurrentDB(), false},
|
||||
{"[before] nonconcurrentDB", app.NonconcurrentDB(), false},
|
||||
{"[before] auxDB", app.AuxDB(), false},
|
||||
{"[before] auxConcurrentDB", app.AuxConcurrentDB(), false},
|
||||
{"[before] auxNonconcurrentDB", app.AuxNonconcurrentDB(), false},
|
||||
{"[before] settings", app.Settings(), false},
|
||||
{"[before] logger", app.Logger(), false},
|
||||
{"[before] cached collections", app.Store().Get(core.StoreKeyCachedCollections), false},
|
||||
}
|
||||
|
||||
runNilChecks(nilChecksBeforeReset)
|
||||
|
||||
// reset
|
||||
if err := app.ResetBootstrapState(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
nilChecksAfterReset := []nilCheck{
|
||||
{"[after] db", app.DB(), true},
|
||||
{"[after] concurrentDB", app.ConcurrentDB(), true},
|
||||
{"[after] nonconcurrentDB", app.NonconcurrentDB(), true},
|
||||
{"[after] auxDB", app.AuxDB(), true},
|
||||
{"[after] auxConcurrentDB", app.AuxConcurrentDB(), true},
|
||||
{"[after] auxNonconcurrentDB", app.AuxNonconcurrentDB(), true},
|
||||
{"[after] settings", app.Settings(), false},
|
||||
{"[after] logger", app.Logger(), false},
|
||||
{"[after] cached collections", app.Store().Get(core.StoreKeyCachedCollections), false},
|
||||
}
|
||||
|
||||
runNilChecks(nilChecksAfterReset)
|
||||
}
|
||||
|
||||
func TestNewBaseAppTx(t *testing.T) {
|
||||
const testDataDir = "./pb_base_app_test_data_dir/"
|
||||
defer os.RemoveAll(testDataDir)
|
||||
|
||||
app := core.NewBaseApp(core.BaseAppConfig{
|
||||
DataDir: testDataDir,
|
||||
})
|
||||
defer app.ResetBootstrapState()
|
||||
|
||||
if err := app.Bootstrap(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
mustNotHaveTx := func(app core.App) {
|
||||
if app.IsTransactional() {
|
||||
t.Fatalf("Didn't expect the app to be transactional")
|
||||
}
|
||||
|
||||
if app.TxInfo() != nil {
|
||||
t.Fatalf("Didn't expect the app.txInfo to be loaded")
|
||||
}
|
||||
}
|
||||
|
||||
mustHaveTx := func(app core.App) {
|
||||
if !app.IsTransactional() {
|
||||
t.Fatalf("Expected the app to be transactional")
|
||||
}
|
||||
|
||||
if app.TxInfo() == nil {
|
||||
t.Fatalf("Expected the app.txInfo to be loaded")
|
||||
}
|
||||
}
|
||||
|
||||
mustNotHaveTx(app)
|
||||
|
||||
app.RunInTransaction(func(txApp core.App) error {
|
||||
mustHaveTx(txApp)
|
||||
return nil
|
||||
})
|
||||
|
||||
mustNotHaveTx(app)
|
||||
}
|
||||
|
||||
func TestBaseAppNewMailClient(t *testing.T) {
|
||||
const testDataDir = "./pb_base_app_test_data_dir/"
|
||||
defer os.RemoveAll(testDataDir)
|
||||
|
||||
app := core.NewBaseApp(core.BaseAppConfig{
|
||||
DataDir: testDataDir,
|
||||
EncryptionEnv: "pb_test_env",
|
||||
})
|
||||
defer app.ResetBootstrapState()
|
||||
|
||||
client1 := app.NewMailClient()
|
||||
m1, ok := client1.(*mailer.Sendmail)
|
||||
if !ok {
|
||||
t.Fatalf("Expected mailer.Sendmail instance, got %v", m1)
|
||||
}
|
||||
if m1.OnSend() == nil || m1.OnSend().Length() == 0 {
|
||||
t.Fatal("Expected OnSend hook to be registered")
|
||||
}
|
||||
|
||||
app.Settings().SMTP.Enabled = true
|
||||
|
||||
client2 := app.NewMailClient()
|
||||
m2, ok := client2.(*mailer.SMTPClient)
|
||||
if !ok {
|
||||
t.Fatalf("Expected mailer.SMTPClient instance, got %v", m2)
|
||||
}
|
||||
if m2.OnSend() == nil || m2.OnSend().Length() == 0 {
|
||||
t.Fatal("Expected OnSend hook to be registered")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBaseAppNewFilesystem(t *testing.T) {
|
||||
const testDataDir = "./pb_base_app_test_data_dir/"
|
||||
defer os.RemoveAll(testDataDir)
|
||||
|
||||
app := core.NewBaseApp(core.BaseAppConfig{
|
||||
DataDir: testDataDir,
|
||||
})
|
||||
defer app.ResetBootstrapState()
|
||||
|
||||
// local
|
||||
local, localErr := app.NewFilesystem()
|
||||
if localErr != nil {
|
||||
t.Fatal(localErr)
|
||||
}
|
||||
if local == nil {
|
||||
t.Fatal("Expected local filesystem instance, got nil")
|
||||
}
|
||||
|
||||
// misconfigured s3
|
||||
app.Settings().S3.Enabled = true
|
||||
s3, s3Err := app.NewFilesystem()
|
||||
if s3Err == nil {
|
||||
t.Fatal("Expected S3 error, got nil")
|
||||
}
|
||||
if s3 != nil {
|
||||
t.Fatalf("Expected nil s3 filesystem, got %v", s3)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBaseAppNewBackupsFilesystem(t *testing.T) {
|
||||
const testDataDir = "./pb_base_app_test_data_dir/"
|
||||
defer os.RemoveAll(testDataDir)
|
||||
|
||||
app := core.NewBaseApp(core.BaseAppConfig{
|
||||
DataDir: testDataDir,
|
||||
})
|
||||
defer app.ResetBootstrapState()
|
||||
|
||||
// local
|
||||
local, localErr := app.NewBackupsFilesystem()
|
||||
if localErr != nil {
|
||||
t.Fatal(localErr)
|
||||
}
|
||||
if local == nil {
|
||||
t.Fatal("Expected local backups filesystem instance, got nil")
|
||||
}
|
||||
|
||||
// misconfigured s3
|
||||
app.Settings().Backups.S3.Enabled = true
|
||||
s3, s3Err := app.NewBackupsFilesystem()
|
||||
if s3Err == nil {
|
||||
t.Fatal("Expected S3 error, got nil")
|
||||
}
|
||||
if s3 != nil {
|
||||
t.Fatalf("Expected nil s3 backups filesystem, got %v", s3)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBaseAppLoggerWrites(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
// reset
|
||||
if err := app.DeleteOldLogs(time.Now()); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
const logsThreshold = 200
|
||||
|
||||
totalLogs := func(app core.App, t *testing.T) int {
|
||||
var total int
|
||||
|
||||
err := app.LogQuery().Select("count(*)").Row(&total)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to fetch total logs: %v", err)
|
||||
}
|
||||
|
||||
return total
|
||||
}
|
||||
|
||||
t.Run("disabled logs retention", func(t *testing.T) {
|
||||
app.Settings().Logs.MaxDays = 0
|
||||
|
||||
for i := 0; i < logsThreshold+1; i++ {
|
||||
app.Logger().Error("test")
|
||||
}
|
||||
|
||||
if total := totalLogs(app, t); total != 0 {
|
||||
t.Fatalf("Expected no logs, got %d", total)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("test batch logs writes", func(t *testing.T) {
|
||||
app.Settings().Logs.MaxDays = 1
|
||||
|
||||
for i := 0; i < logsThreshold-1; i++ {
|
||||
app.Logger().Error("test")
|
||||
}
|
||||
|
||||
if total := totalLogs(app, t); total != 0 {
|
||||
t.Fatalf("Expected no logs, got %d", total)
|
||||
}
|
||||
|
||||
// should trigger batch write
|
||||
app.Logger().Error("test")
|
||||
|
||||
// should be added for the next batch write
|
||||
app.Logger().Error("test")
|
||||
|
||||
if total := totalLogs(app, t); total != logsThreshold {
|
||||
t.Fatalf("Expected %d logs, got %d", logsThreshold, total)
|
||||
}
|
||||
|
||||
// wait for ~3 secs to check the timer trigger
|
||||
time.Sleep(3200 * time.Millisecond)
|
||||
if total := totalLogs(app, t); total != logsThreshold+1 {
|
||||
t.Fatalf("Expected %d logs, got %d", logsThreshold+1, total)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestBaseAppRefreshSettingsLoggerMinLevelEnabled(t *testing.T) {
|
||||
scenarios := []struct {
|
||||
name string
|
||||
isDev bool
|
||||
level int
|
||||
// level->enabled map
|
||||
expectations map[int]bool
|
||||
}{
|
||||
{
|
||||
"dev mode",
|
||||
true,
|
||||
4,
|
||||
map[int]bool{
|
||||
3: true,
|
||||
4: true,
|
||||
5: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
"nondev mode",
|
||||
false,
|
||||
4,
|
||||
map[int]bool{
|
||||
3: false,
|
||||
4: true,
|
||||
5: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
t.Run(s.name, func(t *testing.T) {
|
||||
const testDataDir = "./pb_base_app_test_data_dir/"
|
||||
defer os.RemoveAll(testDataDir)
|
||||
|
||||
app := core.NewBaseApp(core.BaseAppConfig{
|
||||
DataDir: testDataDir,
|
||||
IsDev: s.isDev,
|
||||
})
|
||||
defer app.ResetBootstrapState()
|
||||
|
||||
if err := app.Bootstrap(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// silence query logs
|
||||
app.ConcurrentDB().(*dbx.DB).ExecLogFunc = func(ctx context.Context, t time.Duration, sql string, result sql.Result, err error) {}
|
||||
app.ConcurrentDB().(*dbx.DB).QueryLogFunc = func(ctx context.Context, t time.Duration, sql string, rows *sql.Rows, err error) {}
|
||||
app.NonconcurrentDB().(*dbx.DB).ExecLogFunc = func(ctx context.Context, t time.Duration, sql string, result sql.Result, err error) {}
|
||||
app.NonconcurrentDB().(*dbx.DB).QueryLogFunc = func(ctx context.Context, t time.Duration, sql string, rows *sql.Rows, err error) {}
|
||||
|
||||
handler, ok := app.Logger().Handler().(*logger.BatchHandler)
|
||||
if !ok {
|
||||
t.Fatalf("Expected BatchHandler, got %v", app.Logger().Handler())
|
||||
}
|
||||
|
||||
app.Settings().Logs.MinLevel = s.level
|
||||
|
||||
if err := app.Save(app.Settings()); err != nil {
|
||||
t.Fatalf("Failed to save settings: %v", err)
|
||||
}
|
||||
|
||||
for level, enabled := range s.expectations {
|
||||
if v := handler.Enabled(context.Background(), slog.Level(level)); v != enabled {
|
||||
t.Fatalf("Expected level %d Enabled() to be %v, got %v", level, enabled, v)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBaseAppDBDualBuilder(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
concurrentQueries := []string{}
|
||||
nonconcurrentQueries := []string{}
|
||||
app.ConcurrentDB().(*dbx.DB).QueryLogFunc = func(ctx context.Context, t time.Duration, sql string, rows *sql.Rows, err error) {
|
||||
concurrentQueries = append(concurrentQueries, sql)
|
||||
}
|
||||
app.ConcurrentDB().(*dbx.DB).ExecLogFunc = func(ctx context.Context, t time.Duration, sql string, result sql.Result, err error) {
|
||||
concurrentQueries = append(concurrentQueries, sql)
|
||||
}
|
||||
app.NonconcurrentDB().(*dbx.DB).QueryLogFunc = func(ctx context.Context, t time.Duration, sql string, rows *sql.Rows, err error) {
|
||||
nonconcurrentQueries = append(nonconcurrentQueries, sql)
|
||||
}
|
||||
app.NonconcurrentDB().(*dbx.DB).ExecLogFunc = func(ctx context.Context, t time.Duration, sql string, result sql.Result, err error) {
|
||||
nonconcurrentQueries = append(nonconcurrentQueries, sql)
|
||||
}
|
||||
|
||||
type testQuery struct {
|
||||
query string
|
||||
isConcurrent bool
|
||||
}
|
||||
|
||||
regularTests := []testQuery{
|
||||
{" \n sEleCt 1", true},
|
||||
{"With abc(x) AS (select 2) SELECT x FROM abc", true},
|
||||
{"create table t1(x int)", false},
|
||||
{"insert into t1(x) values(1)", false},
|
||||
{"update t1 set x = 2", false},
|
||||
{"delete from t1", false},
|
||||
}
|
||||
|
||||
txTests := []testQuery{
|
||||
{"select 3", false},
|
||||
{" \n WITH abc(x) AS (select 4) SELECT x FROM abc", false},
|
||||
{"create table t2(x int)", false},
|
||||
{"insert into t2(x) values(1)", false},
|
||||
{"update t2 set x = 2", false},
|
||||
{"delete from t2", false},
|
||||
}
|
||||
|
||||
for _, item := range regularTests {
|
||||
_, err := app.DB().NewQuery(item.query).Execute()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to execute query %q error: %v", item.query, err)
|
||||
}
|
||||
}
|
||||
|
||||
app.RunInTransaction(func(txApp core.App) error {
|
||||
for _, item := range txTests {
|
||||
_, err := txApp.DB().NewQuery(item.query).Execute()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to execute query %q error: %v", item.query, err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
allTests := append(regularTests, txTests...)
|
||||
for _, item := range allTests {
|
||||
if item.isConcurrent {
|
||||
if !slices.Contains(concurrentQueries, item.query) {
|
||||
t.Fatalf("Expected concurrent query\n%q\ngot\nconcurrent:%v\nnonconcurrent:%v", item.query, concurrentQueries, nonconcurrentQueries)
|
||||
}
|
||||
} else {
|
||||
if !slices.Contains(nonconcurrentQueries, item.query) {
|
||||
t.Fatalf("Expected nonconcurrent query\n%q\ngot\nconcurrent:%v\nnonconcurrent:%v", item.query, concurrentQueries, nonconcurrentQueries)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestBaseAppAuxDBDualBuilder(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
concurrentQueries := []string{}
|
||||
nonconcurrentQueries := []string{}
|
||||
app.AuxConcurrentDB().(*dbx.DB).QueryLogFunc = func(ctx context.Context, t time.Duration, sql string, rows *sql.Rows, err error) {
|
||||
concurrentQueries = append(concurrentQueries, sql)
|
||||
}
|
||||
app.AuxConcurrentDB().(*dbx.DB).ExecLogFunc = func(ctx context.Context, t time.Duration, sql string, result sql.Result, err error) {
|
||||
concurrentQueries = append(concurrentQueries, sql)
|
||||
}
|
||||
app.AuxNonconcurrentDB().(*dbx.DB).QueryLogFunc = func(ctx context.Context, t time.Duration, sql string, rows *sql.Rows, err error) {
|
||||
nonconcurrentQueries = append(nonconcurrentQueries, sql)
|
||||
}
|
||||
app.AuxNonconcurrentDB().(*dbx.DB).ExecLogFunc = func(ctx context.Context, t time.Duration, sql string, result sql.Result, err error) {
|
||||
nonconcurrentQueries = append(nonconcurrentQueries, sql)
|
||||
}
|
||||
|
||||
type testQuery struct {
|
||||
query string
|
||||
isConcurrent bool
|
||||
}
|
||||
|
||||
regularTests := []testQuery{
|
||||
{" \n sEleCt 1", true},
|
||||
{"With abc(x) AS (select 2) SELECT x FROM abc", true},
|
||||
{"create table t1(x int)", false},
|
||||
{"insert into t1(x) values(1)", false},
|
||||
{"update t1 set x = 2", false},
|
||||
{"delete from t1", false},
|
||||
}
|
||||
|
||||
txTests := []testQuery{
|
||||
{"select 3", false},
|
||||
{" \n WITH abc(x) AS (select 4) SELECT x FROM abc", false},
|
||||
{"create table t2(x int)", false},
|
||||
{"insert into t2(x) values(1)", false},
|
||||
{"update t2 set x = 2", false},
|
||||
{"delete from t2", false},
|
||||
}
|
||||
|
||||
for _, item := range regularTests {
|
||||
_, err := app.AuxDB().NewQuery(item.query).Execute()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to execute query %q error: %v", item.query, err)
|
||||
}
|
||||
}
|
||||
|
||||
app.AuxRunInTransaction(func(txApp core.App) error {
|
||||
for _, item := range txTests {
|
||||
_, err := txApp.AuxDB().NewQuery(item.query).Execute()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to execute query %q error: %v", item.query, err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
allTests := append(regularTests, txTests...)
|
||||
for _, item := range allTests {
|
||||
if item.isConcurrent {
|
||||
if !slices.Contains(concurrentQueries, item.query) {
|
||||
t.Fatalf("Expected concurrent query\n%q\ngot\nconcurrent:%v\nnonconcurrent:%v", item.query, concurrentQueries, nonconcurrentQueries)
|
||||
}
|
||||
} else {
|
||||
if !slices.Contains(nonconcurrentQueries, item.query) {
|
||||
t.Fatalf("Expected nonconcurrent query\n%q\ngot\nconcurrent:%v\nnonconcurrent:%v", item.query, concurrentQueries, nonconcurrentQueries)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
200
core/collection_import.go
Normal file
200
core/collection_import.go
Normal file
|
@ -0,0 +1,200 @@
|
|||
package core
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"slices"
|
||||
|
||||
validation "github.com/go-ozzo/ozzo-validation/v4"
|
||||
"github.com/spf13/cast"
|
||||
)
|
||||
|
||||
// ImportCollectionsByMarshaledJSON is the same as [ImportCollections]
|
||||
// but accept marshaled json array as import data (usually used for the autogenerated snapshots).
|
||||
func (app *BaseApp) ImportCollectionsByMarshaledJSON(rawSliceOfMaps []byte, deleteMissing bool) error {
|
||||
data := []map[string]any{}
|
||||
|
||||
err := json.Unmarshal(rawSliceOfMaps, &data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return app.ImportCollections(data, deleteMissing)
|
||||
}
|
||||
|
||||
// ImportCollections imports the provided collections data in a single transaction.
|
||||
//
|
||||
// For existing matching collections, the imported data is unmarshaled on top of the existing model.
|
||||
//
|
||||
// NB! If deleteMissing is true, ALL NON-SYSTEM COLLECTIONS AND SCHEMA FIELDS,
|
||||
// that are not present in the imported configuration, WILL BE DELETED
|
||||
// (this includes their related records data).
|
||||
func (app *BaseApp) ImportCollections(toImport []map[string]any, deleteMissing bool) error {
|
||||
if len(toImport) == 0 {
|
||||
// prevent accidentally deleting all collections
|
||||
return errors.New("no collections to import")
|
||||
}
|
||||
|
||||
importedCollections := make([]*Collection, len(toImport))
|
||||
mappedImported := make(map[string]*Collection, len(toImport))
|
||||
|
||||
// normalize imported collections data to ensure that all
|
||||
// collection fields are present and properly initialized
|
||||
for i, data := range toImport {
|
||||
var imported *Collection
|
||||
|
||||
identifier := cast.ToString(data["id"])
|
||||
if identifier == "" {
|
||||
identifier = cast.ToString(data["name"])
|
||||
}
|
||||
|
||||
existing, err := app.FindCollectionByNameOrId(identifier)
|
||||
if err != nil && !errors.Is(err, sql.ErrNoRows) {
|
||||
return err
|
||||
}
|
||||
|
||||
if existing != nil {
|
||||
// refetch for deep copy
|
||||
imported, err = app.FindCollectionByNameOrId(existing.Id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// ensure that the fields will be cleared
|
||||
if data["fields"] == nil && deleteMissing {
|
||||
data["fields"] = []map[string]any{}
|
||||
}
|
||||
|
||||
rawData, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// load the imported data
|
||||
err = json.Unmarshal(rawData, imported)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// extend with the existing fields if necessary
|
||||
for _, f := range existing.Fields {
|
||||
if !f.GetSystem() && deleteMissing {
|
||||
continue
|
||||
}
|
||||
if imported.Fields.GetById(f.GetId()) == nil {
|
||||
// replace with the existing id to prevent accidental column deletion
|
||||
// since otherwise the imported field will be treated as a new one
|
||||
found := imported.Fields.GetByName(f.GetName())
|
||||
if found != nil && found.Type() == f.Type() {
|
||||
found.SetId(f.GetId())
|
||||
}
|
||||
imported.Fields.Add(f)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
imported = &Collection{}
|
||||
|
||||
rawData, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// load the imported data
|
||||
err = json.Unmarshal(rawData, imported)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
imported.IntegrityChecks(false)
|
||||
|
||||
importedCollections[i] = imported
|
||||
mappedImported[imported.Id] = imported
|
||||
}
|
||||
|
||||
// reorder views last since the view query could depend on some of the other collections
|
||||
slices.SortStableFunc(importedCollections, func(a, b *Collection) int {
|
||||
cmpA := -1
|
||||
if a.IsView() {
|
||||
cmpA = 1
|
||||
}
|
||||
|
||||
cmpB := -1
|
||||
if b.IsView() {
|
||||
cmpB = 1
|
||||
}
|
||||
|
||||
res := cmp.Compare(cmpA, cmpB)
|
||||
if res == 0 {
|
||||
res = a.Created.Compare(b.Created)
|
||||
if res == 0 {
|
||||
res = a.Updated.Compare(b.Updated)
|
||||
}
|
||||
}
|
||||
return res
|
||||
})
|
||||
|
||||
return app.RunInTransaction(func(txApp App) error {
|
||||
existingCollections := []*Collection{}
|
||||
if err := txApp.CollectionQuery().OrderBy("updated ASC").All(&existingCollections); err != nil {
|
||||
return err
|
||||
}
|
||||
mappedExisting := make(map[string]*Collection, len(existingCollections))
|
||||
for _, existing := range existingCollections {
|
||||
existing.IntegrityChecks(false)
|
||||
mappedExisting[existing.Id] = existing
|
||||
}
|
||||
|
||||
// delete old collections not available in the new configuration
|
||||
// (before saving the imports in case a deleted collection name is being reused)
|
||||
if deleteMissing {
|
||||
for _, existing := range existingCollections {
|
||||
if mappedImported[existing.Id] != nil || existing.System {
|
||||
continue // exist or system
|
||||
}
|
||||
|
||||
// delete collection
|
||||
if err := txApp.Delete(existing); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// upsert imported collections
|
||||
for _, imported := range importedCollections {
|
||||
if err := txApp.SaveNoValidate(imported); err != nil {
|
||||
return fmt.Errorf("failed to save collection %q: %w", imported.Name, err)
|
||||
}
|
||||
}
|
||||
|
||||
// run validations
|
||||
for _, imported := range importedCollections {
|
||||
original := mappedExisting[imported.Id]
|
||||
if original == nil {
|
||||
original = imported
|
||||
}
|
||||
|
||||
validator := newCollectionValidator(
|
||||
context.Background(),
|
||||
txApp,
|
||||
imported,
|
||||
original,
|
||||
)
|
||||
if err := validator.run(); err != nil {
|
||||
// serialize the validation error(s)
|
||||
serializedErr, _ := json.MarshalIndent(err, "", " ")
|
||||
|
||||
return validation.Errors{"collections": validation.NewError(
|
||||
"validation_collections_import_failure",
|
||||
fmt.Sprintf("Data validations failed for collection %q (%s):\n%s", imported.Name, imported.Id, serializedErr),
|
||||
)}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
476
core/collection_import_test.go
Normal file
476
core/collection_import_test.go
Normal file
|
@ -0,0 +1,476 @@
|
|||
package core_test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/pocketbase/dbx"
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tests"
|
||||
)
|
||||
|
||||
func TestImportCollections(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testApp, _ := tests.NewTestApp()
|
||||
defer testApp.Cleanup()
|
||||
|
||||
var regularCollections []*core.Collection
|
||||
err := testApp.CollectionQuery().AndWhere(dbx.HashExp{"system": false}).All(®ularCollections)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
var systemCollections []*core.Collection
|
||||
err = testApp.CollectionQuery().AndWhere(dbx.HashExp{"system": true}).All(&systemCollections)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
totalRegularCollections := len(regularCollections)
|
||||
totalSystemCollections := len(systemCollections)
|
||||
totalCollections := totalRegularCollections + totalSystemCollections
|
||||
|
||||
scenarios := []struct {
|
||||
name string
|
||||
data []map[string]any
|
||||
deleteMissing bool
|
||||
expectError bool
|
||||
expectCollectionsCount int
|
||||
afterTestFunc func(testApp *tests.TestApp, resultCollections []*core.Collection)
|
||||
}{
|
||||
{
|
||||
name: "empty collections",
|
||||
data: []map[string]any{},
|
||||
expectError: true,
|
||||
expectCollectionsCount: totalCollections,
|
||||
},
|
||||
{
|
||||
name: "minimal collection import (with missing system fields)",
|
||||
data: []map[string]any{
|
||||
{"name": "import_test1", "type": "auth"},
|
||||
{
|
||||
"name": "import_test2", "fields": []map[string]any{
|
||||
{"name": "test", "type": "text"},
|
||||
},
|
||||
},
|
||||
},
|
||||
deleteMissing: false,
|
||||
expectError: false,
|
||||
expectCollectionsCount: totalCollections + 2,
|
||||
},
|
||||
{
|
||||
name: "minimal collection import (trigger collection model validations)",
|
||||
data: []map[string]any{
|
||||
{"name": ""},
|
||||
{
|
||||
"name": "import_test2", "fields": []map[string]any{
|
||||
{"name": "test", "type": "text"},
|
||||
},
|
||||
},
|
||||
},
|
||||
deleteMissing: false,
|
||||
expectError: true,
|
||||
expectCollectionsCount: totalCollections,
|
||||
},
|
||||
{
|
||||
name: "minimal collection import (trigger field settings validation)",
|
||||
data: []map[string]any{
|
||||
{"name": "import_test", "fields": []map[string]any{{"name": "test", "type": "text", "min": -1}}},
|
||||
},
|
||||
deleteMissing: false,
|
||||
expectError: true,
|
||||
expectCollectionsCount: totalCollections,
|
||||
},
|
||||
{
|
||||
name: "new + update + delete (system collections delete should be ignored)",
|
||||
data: []map[string]any{
|
||||
{
|
||||
"id": "wsmn24bux7wo113",
|
||||
"name": "demo",
|
||||
"fields": []map[string]any{
|
||||
{
|
||||
"id": "_2hlxbmp",
|
||||
"name": "title",
|
||||
"type": "text",
|
||||
"system": false,
|
||||
"required": true,
|
||||
"min": 3,
|
||||
"max": nil,
|
||||
"pattern": "",
|
||||
},
|
||||
},
|
||||
"indexes": []string{},
|
||||
},
|
||||
{
|
||||
"name": "import1",
|
||||
"fields": []map[string]any{
|
||||
{
|
||||
"name": "active",
|
||||
"type": "bool",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
deleteMissing: true,
|
||||
expectError: false,
|
||||
expectCollectionsCount: totalSystemCollections + 2,
|
||||
},
|
||||
{
|
||||
name: "test with deleteMissing: false",
|
||||
data: []map[string]any{
|
||||
{
|
||||
// "id": "wsmn24bux7wo113", // test update with only name as identifier
|
||||
"name": "demo1",
|
||||
"fields": []map[string]any{
|
||||
{
|
||||
"id": "_2hlxbmp",
|
||||
"name": "title",
|
||||
"type": "text",
|
||||
"system": false,
|
||||
"required": true,
|
||||
"min": 3,
|
||||
"max": nil,
|
||||
"pattern": "",
|
||||
},
|
||||
{
|
||||
"id": "_2hlxbmp",
|
||||
"name": "field_with_duplicate_id",
|
||||
"type": "text",
|
||||
"system": false,
|
||||
"required": true,
|
||||
"unique": false,
|
||||
"min": 4,
|
||||
"max": nil,
|
||||
"pattern": "",
|
||||
},
|
||||
{
|
||||
"id": "abcd_import",
|
||||
"name": "new_field",
|
||||
"type": "text",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "new_import",
|
||||
"fields": []map[string]any{
|
||||
{
|
||||
"id": "abcd_import",
|
||||
"name": "active",
|
||||
"type": "bool",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
deleteMissing: false,
|
||||
expectError: false,
|
||||
expectCollectionsCount: totalCollections + 1,
|
||||
afterTestFunc: func(testApp *tests.TestApp, resultCollections []*core.Collection) {
|
||||
expectedCollectionFields := map[string]int{
|
||||
core.CollectionNameAuthOrigins: 6,
|
||||
"nologin": 10,
|
||||
"demo1": 19,
|
||||
"demo2": 5,
|
||||
"demo3": 5,
|
||||
"demo4": 16,
|
||||
"demo5": 9,
|
||||
"new_import": 2,
|
||||
}
|
||||
for name, expectedCount := range expectedCollectionFields {
|
||||
collection, err := testApp.FindCollectionByNameOrId(name)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if totalFields := len(collection.Fields); totalFields != expectedCount {
|
||||
t.Errorf("Expected %d %q fields, got %d", expectedCount, collection.Name, totalFields)
|
||||
}
|
||||
}
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
t.Run(s.name, func(t *testing.T) {
|
||||
testApp, _ := tests.NewTestApp()
|
||||
defer testApp.Cleanup()
|
||||
|
||||
err := testApp.ImportCollections(s.data, s.deleteMissing)
|
||||
|
||||
hasErr := err != nil
|
||||
if hasErr != s.expectError {
|
||||
t.Fatalf("Expected hasErr to be %v, got %v (%v)", s.expectError, hasErr, err)
|
||||
}
|
||||
|
||||
// check collections count
|
||||
collections := []*core.Collection{}
|
||||
if err := testApp.CollectionQuery().All(&collections); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(collections) != s.expectCollectionsCount {
|
||||
t.Fatalf("Expected %d collections, got %d", s.expectCollectionsCount, len(collections))
|
||||
}
|
||||
|
||||
if s.afterTestFunc != nil {
|
||||
s.afterTestFunc(testApp, collections)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestImportCollectionsByMarshaledJSON(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testApp, _ := tests.NewTestApp()
|
||||
defer testApp.Cleanup()
|
||||
|
||||
var regularCollections []*core.Collection
|
||||
err := testApp.CollectionQuery().AndWhere(dbx.HashExp{"system": false}).All(®ularCollections)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
var systemCollections []*core.Collection
|
||||
err = testApp.CollectionQuery().AndWhere(dbx.HashExp{"system": true}).All(&systemCollections)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
totalRegularCollections := len(regularCollections)
|
||||
totalSystemCollections := len(systemCollections)
|
||||
totalCollections := totalRegularCollections + totalSystemCollections
|
||||
|
||||
scenarios := []struct {
|
||||
name string
|
||||
data string
|
||||
deleteMissing bool
|
||||
expectError bool
|
||||
expectCollectionsCount int
|
||||
afterTestFunc func(testApp *tests.TestApp, resultCollections []*core.Collection)
|
||||
}{
|
||||
{
|
||||
name: "invalid json array",
|
||||
data: `{"test":123}`,
|
||||
expectError: true,
|
||||
expectCollectionsCount: totalCollections,
|
||||
},
|
||||
{
|
||||
name: "new + update + delete (system collections delete should be ignored)",
|
||||
data: `[
|
||||
{
|
||||
"id": "wsmn24bux7wo113",
|
||||
"name": "demo",
|
||||
"fields": [
|
||||
{
|
||||
"id": "_2hlxbmp",
|
||||
"name": "title",
|
||||
"type": "text",
|
||||
"system": false,
|
||||
"required": true,
|
||||
"min": 3,
|
||||
"max": null,
|
||||
"pattern": ""
|
||||
}
|
||||
],
|
||||
"indexes": []
|
||||
},
|
||||
{
|
||||
"name": "import1",
|
||||
"fields": [
|
||||
{
|
||||
"name": "active",
|
||||
"type": "bool"
|
||||
}
|
||||
]
|
||||
}
|
||||
]`,
|
||||
deleteMissing: true,
|
||||
expectError: false,
|
||||
expectCollectionsCount: totalSystemCollections + 2,
|
||||
},
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
t.Run(s.name, func(t *testing.T) {
|
||||
testApp, _ := tests.NewTestApp()
|
||||
defer testApp.Cleanup()
|
||||
|
||||
err := testApp.ImportCollectionsByMarshaledJSON([]byte(s.data), s.deleteMissing)
|
||||
|
||||
hasErr := err != nil
|
||||
if hasErr != s.expectError {
|
||||
t.Fatalf("Expected hasErr to be %v, got %v (%v)", s.expectError, hasErr, err)
|
||||
}
|
||||
|
||||
// check collections count
|
||||
collections := []*core.Collection{}
|
||||
if err := testApp.CollectionQuery().All(&collections); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(collections) != s.expectCollectionsCount {
|
||||
t.Fatalf("Expected %d collections, got %d", s.expectCollectionsCount, len(collections))
|
||||
}
|
||||
|
||||
if s.afterTestFunc != nil {
|
||||
s.afterTestFunc(testApp, collections)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestImportCollectionsUpdateRules(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
scenarios := []struct {
|
||||
name string
|
||||
data map[string]any
|
||||
deleteMissing bool
|
||||
}{
|
||||
{
|
||||
"extend existing by name (without deleteMissing)",
|
||||
map[string]any{"name": "clients", "authToken": map[string]any{"duration": 100}, "fields": []map[string]any{{"name": "test", "type": "text"}}},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"extend existing by id (without deleteMissing)",
|
||||
map[string]any{"id": "v851q4r790rhknl", "authToken": map[string]any{"duration": 100}, "fields": []map[string]any{{"name": "test", "type": "text"}}},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"extend with delete missing",
|
||||
map[string]any{
|
||||
"id": "v851q4r790rhknl",
|
||||
"authToken": map[string]any{"duration": 100},
|
||||
"fields": []map[string]any{{"name": "test", "type": "text"}},
|
||||
"passwordAuth": map[string]any{"identityFields": []string{"email"}},
|
||||
"indexes": []string{
|
||||
// min required system fields indexes
|
||||
"CREATE UNIQUE INDEX `_v851q4r790rhknl_email_idx` ON `clients` (email) WHERE email != ''",
|
||||
"CREATE UNIQUE INDEX `_v851q4r790rhknl_tokenKey_idx` ON `clients` (tokenKey)",
|
||||
},
|
||||
},
|
||||
true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
t.Run(s.name, func(t *testing.T) {
|
||||
testApp, _ := tests.NewTestApp()
|
||||
defer testApp.Cleanup()
|
||||
|
||||
beforeCollection, err := testApp.FindCollectionByNameOrId("clients")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = testApp.ImportCollections([]map[string]any{s.data}, s.deleteMissing)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
afterCollection, err := testApp.FindCollectionByNameOrId("clients")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if afterCollection.AuthToken.Duration != 100 {
|
||||
t.Fatalf("Expected AuthToken duration to be %d, got %d", 100, afterCollection.AuthToken.Duration)
|
||||
}
|
||||
if beforeCollection.AuthToken.Secret != afterCollection.AuthToken.Secret {
|
||||
t.Fatalf("Expected AuthToken secrets to remain the same, got\n%q\nVS\n%q", beforeCollection.AuthToken.Secret, afterCollection.AuthToken.Secret)
|
||||
}
|
||||
if beforeCollection.Name != afterCollection.Name {
|
||||
t.Fatalf("Expected Name to remain the same, got\n%q\nVS\n%q", beforeCollection.Name, afterCollection.Name)
|
||||
}
|
||||
if beforeCollection.Id != afterCollection.Id {
|
||||
t.Fatalf("Expected Id to remain the same, got\n%q\nVS\n%q", beforeCollection.Id, afterCollection.Id)
|
||||
}
|
||||
|
||||
if !s.deleteMissing {
|
||||
totalExpectedFields := len(beforeCollection.Fields) + 1
|
||||
if v := len(afterCollection.Fields); v != totalExpectedFields {
|
||||
t.Fatalf("Expected %d total fields, got %d", totalExpectedFields, v)
|
||||
}
|
||||
|
||||
if afterCollection.Fields.GetByName("test") == nil {
|
||||
t.Fatalf("Missing new field %q", "test")
|
||||
}
|
||||
|
||||
// ensure that the old fields still exist
|
||||
oldFields := beforeCollection.Fields.FieldNames()
|
||||
for _, name := range oldFields {
|
||||
if afterCollection.Fields.GetByName(name) == nil {
|
||||
t.Fatalf("Missing expected old field %q", name)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
totalExpectedFields := 1
|
||||
for _, f := range beforeCollection.Fields {
|
||||
if f.GetSystem() {
|
||||
totalExpectedFields++
|
||||
}
|
||||
}
|
||||
|
||||
if v := len(afterCollection.Fields); v != totalExpectedFields {
|
||||
t.Fatalf("Expected %d total fields, got %d", totalExpectedFields, v)
|
||||
}
|
||||
|
||||
if afterCollection.Fields.GetByName("test") == nil {
|
||||
t.Fatalf("Missing new field %q", "test")
|
||||
}
|
||||
|
||||
// ensure that the old system fields still exist
|
||||
for _, f := range beforeCollection.Fields {
|
||||
if f.GetSystem() && afterCollection.Fields.GetByName(f.GetName()) == nil {
|
||||
t.Fatalf("Missing expected old field %q", f.GetName())
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestImportCollectionsCreateRules(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testApp, _ := tests.NewTestApp()
|
||||
defer testApp.Cleanup()
|
||||
|
||||
err := testApp.ImportCollections([]map[string]any{
|
||||
{"name": "new_test", "type": "auth", "authToken": map[string]any{"duration": 123}, "fields": []map[string]any{{"name": "test", "type": "text"}}},
|
||||
}, false)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
collection, err := testApp.FindCollectionByNameOrId("new_test")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
raw, err := json.Marshal(collection)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
rawStr := string(raw)
|
||||
|
||||
expectedParts := []string{
|
||||
`"name":"new_test"`,
|
||||
`"fields":[`,
|
||||
`"name":"id"`,
|
||||
`"name":"email"`,
|
||||
`"name":"tokenKey"`,
|
||||
`"name":"password"`,
|
||||
`"name":"test"`,
|
||||
`"indexes":[`,
|
||||
`CREATE UNIQUE INDEX`,
|
||||
`"duration":123`,
|
||||
}
|
||||
|
||||
for _, part := range expectedParts {
|
||||
if !strings.Contains(rawStr, part) {
|
||||
t.Errorf("Missing %q in\n%s", part, rawStr)
|
||||
}
|
||||
}
|
||||
}
|
1073
core/collection_model.go
Normal file
1073
core/collection_model.go
Normal file
File diff suppressed because it is too large
Load diff
543
core/collection_model_auth_options.go
Normal file
543
core/collection_model_auth_options.go
Normal file
|
@ -0,0 +1,543 @@
|
|||
package core
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
validation "github.com/go-ozzo/ozzo-validation/v4"
|
||||
"github.com/go-ozzo/ozzo-validation/v4/is"
|
||||
"github.com/pocketbase/pocketbase/tools/auth"
|
||||
"github.com/pocketbase/pocketbase/tools/list"
|
||||
"github.com/pocketbase/pocketbase/tools/security"
|
||||
"github.com/pocketbase/pocketbase/tools/types"
|
||||
"github.com/spf13/cast"
|
||||
)
|
||||
|
||||
func (m *Collection) unsetMissingOAuth2MappedFields() {
|
||||
if !m.IsAuth() {
|
||||
return
|
||||
}
|
||||
|
||||
if m.OAuth2.MappedFields.Id != "" {
|
||||
if m.Fields.GetByName(m.OAuth2.MappedFields.Id) == nil {
|
||||
m.OAuth2.MappedFields.Id = ""
|
||||
}
|
||||
}
|
||||
|
||||
if m.OAuth2.MappedFields.Name != "" {
|
||||
if m.Fields.GetByName(m.OAuth2.MappedFields.Name) == nil {
|
||||
m.OAuth2.MappedFields.Name = ""
|
||||
}
|
||||
}
|
||||
|
||||
if m.OAuth2.MappedFields.Username != "" {
|
||||
if m.Fields.GetByName(m.OAuth2.MappedFields.Username) == nil {
|
||||
m.OAuth2.MappedFields.Username = ""
|
||||
}
|
||||
}
|
||||
|
||||
if m.OAuth2.MappedFields.AvatarURL != "" {
|
||||
if m.Fields.GetByName(m.OAuth2.MappedFields.AvatarURL) == nil {
|
||||
m.OAuth2.MappedFields.AvatarURL = ""
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Collection) setDefaultAuthOptions() {
|
||||
m.collectionAuthOptions = collectionAuthOptions{
|
||||
VerificationTemplate: defaultVerificationTemplate,
|
||||
ResetPasswordTemplate: defaultResetPasswordTemplate,
|
||||
ConfirmEmailChangeTemplate: defaultConfirmEmailChangeTemplate,
|
||||
AuthRule: types.Pointer(""),
|
||||
AuthAlert: AuthAlertConfig{
|
||||
Enabled: true,
|
||||
EmailTemplate: defaultAuthAlertTemplate,
|
||||
},
|
||||
PasswordAuth: PasswordAuthConfig{
|
||||
Enabled: true,
|
||||
IdentityFields: []string{FieldNameEmail},
|
||||
},
|
||||
MFA: MFAConfig{
|
||||
Enabled: false,
|
||||
Duration: 1800, // 30min
|
||||
},
|
||||
OTP: OTPConfig{
|
||||
Enabled: false,
|
||||
Duration: 180, // 3min
|
||||
Length: 8,
|
||||
EmailTemplate: defaultOTPTemplate,
|
||||
},
|
||||
AuthToken: TokenConfig{
|
||||
Secret: security.RandomString(50),
|
||||
Duration: 604800, // 7 days
|
||||
},
|
||||
PasswordResetToken: TokenConfig{
|
||||
Secret: security.RandomString(50),
|
||||
Duration: 1800, // 30min
|
||||
},
|
||||
EmailChangeToken: TokenConfig{
|
||||
Secret: security.RandomString(50),
|
||||
Duration: 1800, // 30min
|
||||
},
|
||||
VerificationToken: TokenConfig{
|
||||
Secret: security.RandomString(50),
|
||||
Duration: 259200, // 3days
|
||||
},
|
||||
FileToken: TokenConfig{
|
||||
Secret: security.RandomString(50),
|
||||
Duration: 180, // 3min
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
var _ optionsValidator = (*collectionAuthOptions)(nil)
|
||||
|
||||
// collectionAuthOptions defines the options for the "auth" type collection.
|
||||
type collectionAuthOptions struct {
|
||||
// AuthRule could be used to specify additional record constraints
|
||||
// applied after record authentication and right before returning the
|
||||
// auth token response to the client.
|
||||
//
|
||||
// For example, to allow only verified users you could set it to
|
||||
// "verified = true".
|
||||
//
|
||||
// Set it to empty string to allow any Auth collection record to authenticate.
|
||||
//
|
||||
// Set it to nil to disallow authentication altogether for the collection
|
||||
// (that includes password, OAuth2, etc.).
|
||||
AuthRule *string `form:"authRule" json:"authRule"`
|
||||
|
||||
// ManageRule gives admin-like permissions to allow fully managing
|
||||
// the auth record(s), eg. changing the password without requiring
|
||||
// to enter the old one, directly updating the verified state and email, etc.
|
||||
//
|
||||
// This rule is executed in addition to the Create and Update API rules.
|
||||
ManageRule *string `form:"manageRule" json:"manageRule"`
|
||||
|
||||
// AuthAlert defines options related to the auth alerts on new device login.
|
||||
AuthAlert AuthAlertConfig `form:"authAlert" json:"authAlert"`
|
||||
|
||||
// OAuth2 specifies whether OAuth2 auth is enabled for the collection
|
||||
// and which OAuth2 providers are allowed.
|
||||
OAuth2 OAuth2Config `form:"oauth2" json:"oauth2"`
|
||||
|
||||
// PasswordAuth defines options related to the collection password authentication.
|
||||
PasswordAuth PasswordAuthConfig `form:"passwordAuth" json:"passwordAuth"`
|
||||
|
||||
// MFA defines options related to the Multi-factor authentication (MFA).
|
||||
MFA MFAConfig `form:"mfa" json:"mfa"`
|
||||
|
||||
// OTP defines options related to the One-time password authentication (OTP).
|
||||
OTP OTPConfig `form:"otp" json:"otp"`
|
||||
|
||||
// Various token configurations
|
||||
// ---
|
||||
AuthToken TokenConfig `form:"authToken" json:"authToken"`
|
||||
PasswordResetToken TokenConfig `form:"passwordResetToken" json:"passwordResetToken"`
|
||||
EmailChangeToken TokenConfig `form:"emailChangeToken" json:"emailChangeToken"`
|
||||
VerificationToken TokenConfig `form:"verificationToken" json:"verificationToken"`
|
||||
FileToken TokenConfig `form:"fileToken" json:"fileToken"`
|
||||
|
||||
// Default email templates
|
||||
// ---
|
||||
VerificationTemplate EmailTemplate `form:"verificationTemplate" json:"verificationTemplate"`
|
||||
ResetPasswordTemplate EmailTemplate `form:"resetPasswordTemplate" json:"resetPasswordTemplate"`
|
||||
ConfirmEmailChangeTemplate EmailTemplate `form:"confirmEmailChangeTemplate" json:"confirmEmailChangeTemplate"`
|
||||
}
|
||||
|
||||
func (o *collectionAuthOptions) validate(cv *collectionValidator) error {
|
||||
err := validation.ValidateStruct(o,
|
||||
validation.Field(
|
||||
&o.AuthRule,
|
||||
validation.By(cv.checkRule),
|
||||
validation.By(cv.ensureNoSystemRuleChange(cv.original.AuthRule)),
|
||||
),
|
||||
validation.Field(
|
||||
&o.ManageRule,
|
||||
validation.NilOrNotEmpty,
|
||||
validation.By(cv.checkRule),
|
||||
validation.By(cv.ensureNoSystemRuleChange(cv.original.ManageRule)),
|
||||
),
|
||||
validation.Field(&o.AuthAlert),
|
||||
validation.Field(&o.PasswordAuth),
|
||||
validation.Field(&o.OAuth2),
|
||||
validation.Field(&o.OTP),
|
||||
validation.Field(&o.MFA),
|
||||
validation.Field(&o.AuthToken),
|
||||
validation.Field(&o.PasswordResetToken),
|
||||
validation.Field(&o.EmailChangeToken),
|
||||
validation.Field(&o.VerificationToken),
|
||||
validation.Field(&o.FileToken),
|
||||
validation.Field(&o.VerificationTemplate, validation.Required),
|
||||
validation.Field(&o.ResetPasswordTemplate, validation.Required),
|
||||
validation.Field(&o.ConfirmEmailChangeTemplate, validation.Required),
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if o.MFA.Enabled {
|
||||
// if MFA is enabled require at least 2 auth methods
|
||||
//
|
||||
// @todo maybe consider disabling the check because if custom auth methods
|
||||
// are registered it may fail since we don't have mechanism to detect them at the moment
|
||||
authsEnabled := 0
|
||||
if o.PasswordAuth.Enabled {
|
||||
authsEnabled++
|
||||
}
|
||||
if o.OAuth2.Enabled {
|
||||
authsEnabled++
|
||||
}
|
||||
if o.OTP.Enabled {
|
||||
authsEnabled++
|
||||
}
|
||||
if authsEnabled < 2 {
|
||||
return validation.Errors{
|
||||
"mfa": validation.Errors{
|
||||
"enabled": validation.NewError("validation_mfa_not_enough_auths", "MFA requires at least 2 auth methods to be enabled."),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
if o.MFA.Rule != "" {
|
||||
mfaRuleValidators := []validation.RuleFunc{
|
||||
cv.checkRule,
|
||||
cv.ensureNoSystemRuleChange(&cv.original.MFA.Rule),
|
||||
}
|
||||
|
||||
for _, validator := range mfaRuleValidators {
|
||||
err := validator(&o.MFA.Rule)
|
||||
if err != nil {
|
||||
return validation.Errors{
|
||||
"mfa": validation.Errors{
|
||||
"rule": err,
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// extra check to ensure that only unique identity fields are used
|
||||
if o.PasswordAuth.Enabled {
|
||||
err = validation.Validate(o.PasswordAuth.IdentityFields, validation.By(cv.checkFieldsForUniqueIndex))
|
||||
if err != nil {
|
||||
return validation.Errors{
|
||||
"passwordAuth": validation.Errors{
|
||||
"identityFields": err,
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
|
||||
type EmailTemplate struct {
|
||||
Subject string `form:"subject" json:"subject"`
|
||||
Body string `form:"body" json:"body"`
|
||||
}
|
||||
|
||||
// Validate makes EmailTemplate validatable by implementing [validation.Validatable] interface.
|
||||
func (t EmailTemplate) Validate() error {
|
||||
return validation.ValidateStruct(&t,
|
||||
validation.Field(&t.Subject, validation.Required),
|
||||
validation.Field(&t.Body, validation.Required),
|
||||
)
|
||||
}
|
||||
|
||||
// Resolve replaces the placeholder parameters in the current email
|
||||
// template and returns its components as ready-to-use strings.
|
||||
func (t EmailTemplate) Resolve(placeholders map[string]any) (subject, body string) {
|
||||
body = t.Body
|
||||
subject = t.Subject
|
||||
|
||||
for k, v := range placeholders {
|
||||
vStr := cast.ToString(v)
|
||||
|
||||
// replace subject placeholder params (if any)
|
||||
subject = strings.ReplaceAll(subject, k, vStr)
|
||||
|
||||
// replace body placeholder params (if any)
|
||||
body = strings.ReplaceAll(body, k, vStr)
|
||||
}
|
||||
|
||||
return subject, body
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
|
||||
type AuthAlertConfig struct {
|
||||
Enabled bool `form:"enabled" json:"enabled"`
|
||||
EmailTemplate EmailTemplate `form:"emailTemplate" json:"emailTemplate"`
|
||||
}
|
||||
|
||||
// Validate makes AuthAlertConfig validatable by implementing [validation.Validatable] interface.
|
||||
func (c AuthAlertConfig) Validate() error {
|
||||
return validation.ValidateStruct(&c,
|
||||
// note: for now always run the email template validations even
|
||||
// if not enabled since it could be used separately
|
||||
validation.Field(&c.EmailTemplate),
|
||||
)
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
|
||||
type TokenConfig struct {
|
||||
Secret string `form:"secret" json:"secret,omitempty"`
|
||||
|
||||
// Duration specifies how long an issued token to be valid (in seconds)
|
||||
Duration int64 `form:"duration" json:"duration"`
|
||||
}
|
||||
|
||||
// Validate makes TokenConfig validatable by implementing [validation.Validatable] interface.
|
||||
func (c TokenConfig) Validate() error {
|
||||
return validation.ValidateStruct(&c,
|
||||
validation.Field(&c.Secret, validation.Required, validation.Length(30, 255)),
|
||||
validation.Field(&c.Duration, validation.Required, validation.Min(10), validation.Max(94670856)), // ~3y max
|
||||
)
|
||||
}
|
||||
|
||||
// DurationTime returns the current Duration as [time.Duration].
|
||||
func (c TokenConfig) DurationTime() time.Duration {
|
||||
return time.Duration(c.Duration) * time.Second
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
|
||||
type OTPConfig struct {
|
||||
Enabled bool `form:"enabled" json:"enabled"`
|
||||
|
||||
// Duration specifies how long the OTP to be valid (in seconds)
|
||||
Duration int64 `form:"duration" json:"duration"`
|
||||
|
||||
// Length specifies the auto generated password length.
|
||||
Length int `form:"length" json:"length"`
|
||||
|
||||
// EmailTemplate is the default OTP email template that will be send to the auth record.
|
||||
//
|
||||
// In addition to the system placeholders you can also make use of
|
||||
// [core.EmailPlaceholderOTPId] and [core.EmailPlaceholderOTP].
|
||||
EmailTemplate EmailTemplate `form:"emailTemplate" json:"emailTemplate"`
|
||||
}
|
||||
|
||||
// Validate makes OTPConfig validatable by implementing [validation.Validatable] interface.
|
||||
func (c OTPConfig) Validate() error {
|
||||
return validation.ValidateStruct(&c,
|
||||
validation.Field(&c.Duration, validation.When(c.Enabled, validation.Required, validation.Min(10), validation.Max(86400))),
|
||||
validation.Field(&c.Length, validation.When(c.Enabled, validation.Required, validation.Min(4))),
|
||||
// note: for now always run the email template validations even
|
||||
// if not enabled since it could be used separately
|
||||
validation.Field(&c.EmailTemplate),
|
||||
)
|
||||
}
|
||||
|
||||
// DurationTime returns the current Duration as [time.Duration].
|
||||
func (c OTPConfig) DurationTime() time.Duration {
|
||||
return time.Duration(c.Duration) * time.Second
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
|
||||
type MFAConfig struct {
|
||||
Enabled bool `form:"enabled" json:"enabled"`
|
||||
|
||||
// Duration specifies how long an issued MFA to be valid (in seconds)
|
||||
Duration int64 `form:"duration" json:"duration"`
|
||||
|
||||
// Rule is an optional field to restrict MFA only for the records that satisfy the rule.
|
||||
//
|
||||
// Leave it empty to enable MFA for everyone.
|
||||
Rule string `form:"rule" json:"rule"`
|
||||
}
|
||||
|
||||
// Validate makes MFAConfig validatable by implementing [validation.Validatable] interface.
|
||||
func (c MFAConfig) Validate() error {
|
||||
return validation.ValidateStruct(&c,
|
||||
validation.Field(&c.Duration, validation.When(c.Enabled, validation.Required, validation.Min(10), validation.Max(86400))),
|
||||
)
|
||||
}
|
||||
|
||||
// DurationTime returns the current Duration as [time.Duration].
|
||||
func (c MFAConfig) DurationTime() time.Duration {
|
||||
return time.Duration(c.Duration) * time.Second
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
|
||||
type PasswordAuthConfig struct {
|
||||
Enabled bool `form:"enabled" json:"enabled"`
|
||||
|
||||
// IdentityFields is a list of field names that could be used as
|
||||
// identity during password authentication.
|
||||
//
|
||||
// Usually only fields that has single column UNIQUE index are accepted as values.
|
||||
IdentityFields []string `form:"identityFields" json:"identityFields"`
|
||||
}
|
||||
|
||||
// Validate makes PasswordAuthConfig validatable by implementing [validation.Validatable] interface.
|
||||
func (c PasswordAuthConfig) Validate() error {
|
||||
// strip duplicated values
|
||||
c.IdentityFields = list.ToUniqueStringSlice(c.IdentityFields)
|
||||
|
||||
if !c.Enabled {
|
||||
return nil // no need to validate
|
||||
}
|
||||
|
||||
return validation.ValidateStruct(&c,
|
||||
validation.Field(&c.IdentityFields, validation.Required),
|
||||
)
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
|
||||
type OAuth2KnownFields struct {
|
||||
Id string `form:"id" json:"id"`
|
||||
Name string `form:"name" json:"name"`
|
||||
Username string `form:"username" json:"username"`
|
||||
AvatarURL string `form:"avatarURL" json:"avatarURL"`
|
||||
}
|
||||
|
||||
type OAuth2Config struct {
|
||||
Providers []OAuth2ProviderConfig `form:"providers" json:"providers"`
|
||||
|
||||
MappedFields OAuth2KnownFields `form:"mappedFields" json:"mappedFields"`
|
||||
|
||||
Enabled bool `form:"enabled" json:"enabled"`
|
||||
}
|
||||
|
||||
// GetProviderConfig returns the first OAuth2ProviderConfig that matches the specified name.
|
||||
//
|
||||
// Returns false and zero config if no such provider is available in c.Providers.
|
||||
func (c OAuth2Config) GetProviderConfig(name string) (config OAuth2ProviderConfig, exists bool) {
|
||||
for _, p := range c.Providers {
|
||||
if p.Name == name {
|
||||
return p, true
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Validate makes OAuth2Config validatable by implementing [validation.Validatable] interface.
|
||||
func (c OAuth2Config) Validate() error {
|
||||
if !c.Enabled {
|
||||
return nil // no need to validate
|
||||
}
|
||||
|
||||
return validation.ValidateStruct(&c,
|
||||
// note: don't require providers for now as they could be externally registered/removed
|
||||
validation.Field(&c.Providers, validation.By(checkForDuplicatedProviders)),
|
||||
)
|
||||
}
|
||||
|
||||
func checkForDuplicatedProviders(value any) error {
|
||||
configs, _ := value.([]OAuth2ProviderConfig)
|
||||
|
||||
existing := map[string]struct{}{}
|
||||
|
||||
for i, c := range configs {
|
||||
if c.Name == "" {
|
||||
continue // the name nonempty state is validated separately
|
||||
}
|
||||
if _, ok := existing[c.Name]; ok {
|
||||
return validation.Errors{
|
||||
strconv.Itoa(i): validation.Errors{
|
||||
"name": validation.NewError("validation_duplicated_provider", "The provider {{.name}} is already registered.").
|
||||
SetParams(map[string]any{"name": c.Name}),
|
||||
},
|
||||
}
|
||||
}
|
||||
existing[c.Name] = struct{}{}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type OAuth2ProviderConfig struct {
|
||||
// PKCE overwrites the default provider PKCE config option.
|
||||
//
|
||||
// This usually shouldn't be needed but some OAuth2 vendors, like the LinkedIn OIDC,
|
||||
// may require manual adjustment due to returning error if extra parameters are added to the request
|
||||
// (https://github.com/pocketbase/pocketbase/discussions/3799#discussioncomment-7640312)
|
||||
PKCE *bool `form:"pkce" json:"pkce"`
|
||||
|
||||
Name string `form:"name" json:"name"`
|
||||
ClientId string `form:"clientId" json:"clientId"`
|
||||
ClientSecret string `form:"clientSecret" json:"clientSecret,omitempty"`
|
||||
AuthURL string `form:"authURL" json:"authURL"`
|
||||
TokenURL string `form:"tokenURL" json:"tokenURL"`
|
||||
UserInfoURL string `form:"userInfoURL" json:"userInfoURL"`
|
||||
DisplayName string `form:"displayName" json:"displayName"`
|
||||
Extra map[string]any `form:"extra" json:"extra"`
|
||||
}
|
||||
|
||||
// Validate makes OAuth2ProviderConfig validatable by implementing [validation.Validatable] interface.
|
||||
func (c OAuth2ProviderConfig) Validate() error {
|
||||
return validation.ValidateStruct(&c,
|
||||
validation.Field(&c.Name, validation.Required, validation.By(checkProviderName)),
|
||||
validation.Field(&c.ClientId, validation.Required),
|
||||
validation.Field(&c.ClientSecret, validation.Required),
|
||||
validation.Field(&c.AuthURL, is.URL),
|
||||
validation.Field(&c.TokenURL, is.URL),
|
||||
validation.Field(&c.UserInfoURL, is.URL),
|
||||
)
|
||||
}
|
||||
|
||||
func checkProviderName(value any) error {
|
||||
name, _ := value.(string)
|
||||
if name == "" {
|
||||
return nil // nothing to check
|
||||
}
|
||||
|
||||
if _, err := auth.NewProviderByName(name); err != nil {
|
||||
return validation.NewError("validation_missing_provider", "Invalid or missing provider with name {{.name}}.").
|
||||
SetParams(map[string]any{"name": name})
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// InitProvider returns a new auth.Provider instance loaded with the current OAuth2ProviderConfig options.
|
||||
func (c OAuth2ProviderConfig) InitProvider() (auth.Provider, error) {
|
||||
provider, err := auth.NewProviderByName(c.Name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if c.ClientId != "" {
|
||||
provider.SetClientId(c.ClientId)
|
||||
}
|
||||
|
||||
if c.ClientSecret != "" {
|
||||
provider.SetClientSecret(c.ClientSecret)
|
||||
}
|
||||
|
||||
if c.AuthURL != "" {
|
||||
provider.SetAuthURL(c.AuthURL)
|
||||
}
|
||||
|
||||
if c.UserInfoURL != "" {
|
||||
provider.SetUserInfoURL(c.UserInfoURL)
|
||||
}
|
||||
|
||||
if c.TokenURL != "" {
|
||||
provider.SetTokenURL(c.TokenURL)
|
||||
}
|
||||
|
||||
if c.DisplayName != "" {
|
||||
provider.SetDisplayName(c.DisplayName)
|
||||
}
|
||||
|
||||
if c.PKCE != nil {
|
||||
provider.SetPKCE(*c.PKCE)
|
||||
}
|
||||
|
||||
if c.Extra != nil {
|
||||
provider.SetExtra(c.Extra)
|
||||
}
|
||||
|
||||
return provider, nil
|
||||
}
|
Some files were not shown because too many files have changed in this diff Show more
Loading…
Add table
Add a link
Reference in a new issue