1
0
Fork 0

Adding upstream version 0.28.1.

Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
Daniel Baumann 2025-05-22 10:57:38 +02:00
parent 88f1d47ab6
commit e28c88ef14
Signed by: daniel
GPG key ID: FBB4F0E80A80222F
933 changed files with 194711 additions and 0 deletions

91
tools/archive/create.go Normal file
View file

@ -0,0 +1,91 @@
package archive
import (
"archive/zip"
"compress/flate"
"errors"
"io"
"io/fs"
"os"
"path/filepath"
"strings"
)
// Create creates a new zip archive from src dir content and saves it in dest path.
//
// You can specify skipPaths to skip/ignore certain directories and files (relative to src)
// preventing adding them in the final archive.
func Create(src string, dest string, skipPaths ...string) error {
if err := os.MkdirAll(filepath.Dir(dest), os.ModePerm); err != nil {
return err
}
zf, err := os.Create(dest)
if err != nil {
return err
}
zw := zip.NewWriter(zf)
// register a custom Deflate compressor
zw.RegisterCompressor(zip.Deflate, func(out io.Writer) (io.WriteCloser, error) {
return flate.NewWriter(out, flate.BestSpeed)
})
err = zipAddFS(zw, os.DirFS(src), skipPaths...)
if err != nil {
// try to cleanup at least the created zip file
return errors.Join(err, zw.Close(), zf.Close(), os.Remove(dest))
}
return errors.Join(zw.Close(), zf.Close())
}
// note remove after similar method is added in the std lib (https://github.com/golang/go/issues/54898)
func zipAddFS(w *zip.Writer, fsys fs.FS, skipPaths ...string) error {
return fs.WalkDir(fsys, ".", func(name string, d fs.DirEntry, err error) error {
if err != nil {
return err
}
if d.IsDir() {
return nil
}
// skip
for _, ignore := range skipPaths {
if ignore == name ||
strings.HasPrefix(filepath.Clean(name)+string(os.PathSeparator), filepath.Clean(ignore)+string(os.PathSeparator)) {
return nil
}
}
info, err := d.Info()
if err != nil {
return err
}
h, err := zip.FileInfoHeader(info)
if err != nil {
return err
}
h.Name = name
h.Method = zip.Deflate
fw, err := w.CreateHeader(h)
if err != nil {
return err
}
f, err := fsys.Open(name)
if err != nil {
return err
}
defer f.Close()
_, err = io.Copy(fw, f)
return err
})
}

View file

@ -0,0 +1,125 @@
package archive_test
import (
"os"
"path/filepath"
"testing"
"github.com/pocketbase/pocketbase/tools/archive"
)
func TestCreateFailure(t *testing.T) {
testDir := createTestDir(t)
defer os.RemoveAll(testDir)
zipPath := filepath.Join(os.TempDir(), "pb_test.zip")
defer os.RemoveAll(zipPath)
missingDir := filepath.Join(os.TempDir(), "missing")
if err := archive.Create(missingDir, zipPath); err == nil {
t.Fatal("Expected to fail due to missing directory or file")
}
if _, err := os.Stat(zipPath); err == nil {
t.Fatalf("Expected the zip file not to be created")
}
}
func TestCreateSuccess(t *testing.T) {
testDir := createTestDir(t)
defer os.RemoveAll(testDir)
zipName := "pb_test.zip"
zipPath := filepath.Join(os.TempDir(), zipName)
defer os.RemoveAll(zipPath)
// zip testDir content (excluding test and a/b/c dir)
if err := archive.Create(testDir, zipPath, "a/b/c", "test"); err != nil {
t.Fatalf("Failed to create archive: %v", err)
}
info, err := os.Stat(zipPath)
if err != nil {
t.Fatalf("Failed to retrieve the generated zip file: %v", err)
}
if name := info.Name(); name != zipName {
t.Fatalf("Expected zip with name %q, got %q", zipName, name)
}
expectedSize := int64(544)
if size := info.Size(); size != expectedSize {
t.Fatalf("Expected zip with size %d, got %d", expectedSize, size)
}
}
// -------------------------------------------------------------------
// note: make sure to call os.RemoveAll(dir) after you are done
// working with the created test dir.
func createTestDir(t *testing.T) string {
dir, err := os.MkdirTemp(os.TempDir(), "pb_zip_test")
if err != nil {
t.Fatal(err)
}
if err := os.MkdirAll(filepath.Join(dir, "a/b/c"), os.ModePerm); err != nil {
t.Fatal(err)
}
{
f, err := os.OpenFile(filepath.Join(dir, "test"), os.O_WRONLY|os.O_CREATE, 0644)
if err != nil {
t.Fatal(err)
}
f.Close()
}
{
f, err := os.OpenFile(filepath.Join(dir, "test2"), os.O_WRONLY|os.O_CREATE, 0644)
if err != nil {
t.Fatal(err)
}
f.Close()
}
{
f, err := os.OpenFile(filepath.Join(dir, "a/test"), os.O_WRONLY|os.O_CREATE, 0644)
if err != nil {
t.Fatal(err)
}
f.Close()
}
{
f, err := os.OpenFile(filepath.Join(dir, "a/b/sub1"), os.O_WRONLY|os.O_CREATE, 0644)
if err != nil {
t.Fatal(err)
}
f.Close()
}
{
f, err := os.OpenFile(filepath.Join(dir, "a/b/c/sub2"), os.O_WRONLY|os.O_CREATE, 0644)
if err != nil {
t.Fatal(err)
}
f.Close()
}
{
f, err := os.OpenFile(filepath.Join(dir, "a/b/c/sub3"), os.O_WRONLY|os.O_CREATE, 0644)
if err != nil {
t.Fatal(err)
}
f.Close()
}
// symbolic link
if err := os.Symlink(filepath.Join(dir, "test"), filepath.Join(dir, "test_symlink")); err != nil {
t.Fatal(err)
}
return dir
}

77
tools/archive/extract.go Normal file
View file

@ -0,0 +1,77 @@
package archive
import (
"archive/zip"
"fmt"
"io"
"os"
"path/filepath"
"strings"
)
// Extract extracts the zip archive at "src" to "dest".
//
// Note that only dirs and regular files will be extracted.
// Symbolic links, named pipes, sockets, or any other irregular files
// are skipped because they come with too many edge cases and ambiguities.
func Extract(src, dest string) error {
zr, err := zip.OpenReader(src)
if err != nil {
return err
}
defer zr.Close()
// normalize dest path to check later for Zip Slip
dest = filepath.Clean(dest) + string(os.PathSeparator)
for _, f := range zr.File {
err := extractFile(f, dest)
if err != nil {
return err
}
}
return nil
}
// extractFile extracts the provided zipFile into "basePath/zipFileName" path,
// creating all the necessary path directories.
func extractFile(zipFile *zip.File, basePath string) error {
path := filepath.Join(basePath, zipFile.Name)
// check for Zip Slip
if !strings.HasPrefix(path, basePath) {
return fmt.Errorf("invalid file path: %s", path)
}
r, err := zipFile.Open()
if err != nil {
return err
}
defer r.Close()
// allow only dirs or regular files
if zipFile.FileInfo().IsDir() {
if err := os.MkdirAll(path, os.ModePerm); err != nil {
return err
}
} else if zipFile.FileInfo().Mode().IsRegular() {
// ensure that the file path directories are created
if err := os.MkdirAll(filepath.Dir(path), os.ModePerm); err != nil {
return err
}
f, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, zipFile.Mode())
if err != nil {
return err
}
defer f.Close()
_, err = io.Copy(f, r)
if err != nil {
return err
}
}
return nil
}

View file

@ -0,0 +1,88 @@
package archive_test
import (
"io/fs"
"os"
"path/filepath"
"testing"
"github.com/pocketbase/pocketbase/tools/archive"
)
func TestExtractFailure(t *testing.T) {
testDir := createTestDir(t)
defer os.RemoveAll(testDir)
missingZipPath := filepath.Join(os.TempDir(), "pb_missing_test.zip")
extractedPath := filepath.Join(os.TempDir(), "pb_zip_extract")
defer os.RemoveAll(extractedPath)
if err := archive.Extract(missingZipPath, extractedPath); err == nil {
t.Fatal("Expected Extract to fail due to missing zipPath")
}
if _, err := os.Stat(extractedPath); err == nil {
t.Fatalf("Expected %q to not be created", extractedPath)
}
}
func TestExtractSuccess(t *testing.T) {
testDir := createTestDir(t)
defer os.RemoveAll(testDir)
zipPath := filepath.Join(os.TempDir(), "pb_test.zip")
defer os.RemoveAll(zipPath)
extractedPath := filepath.Join(os.TempDir(), "pb_zip_extract")
defer os.RemoveAll(extractedPath)
// zip testDir content (with exclude)
if err := archive.Create(testDir, zipPath, "a/b/c", "test2", "sub2"); err != nil {
t.Fatalf("Failed to create archive: %v", err)
}
if err := archive.Extract(zipPath, extractedPath); err != nil {
t.Fatalf("Failed to extract %q in %q", zipPath, extractedPath)
}
availableFiles := []string{}
walkErr := filepath.WalkDir(extractedPath, func(path string, d fs.DirEntry, err error) error {
if err != nil {
return err
}
if d.IsDir() {
return nil
}
availableFiles = append(availableFiles, path)
return nil
})
if walkErr != nil {
t.Fatalf("Failed to read the extracted dir: %v", walkErr)
}
// (note: symbolic links and other regular files should be missing)
expectedFiles := []string{
filepath.Join(extractedPath, "test"),
filepath.Join(extractedPath, "a/test"),
filepath.Join(extractedPath, "a/b/sub1"),
}
if len(availableFiles) != len(expectedFiles) {
t.Fatalf("Expected \n%v, \ngot \n%v", expectedFiles, availableFiles)
}
ExpectedLoop:
for _, expected := range expectedFiles {
for _, available := range availableFiles {
if available == expected {
continue ExpectedLoop
}
}
t.Fatalf("Missing file %q in \n%v", expected, availableFiles)
}
}

166
tools/auth/apple.go Normal file
View file

@ -0,0 +1,166 @@
package auth
import (
"context"
"encoding/json"
"errors"
"strings"
"github.com/golang-jwt/jwt/v5"
"github.com/pocketbase/pocketbase/tools/types"
"github.com/spf13/cast"
"golang.org/x/oauth2"
)
func init() {
Providers[NameApple] = wrapFactory(NewAppleProvider)
}
var _ Provider = (*Apple)(nil)
// NameApple is the unique name of the Apple provider.
const NameApple string = "apple"
// Apple allows authentication via Apple OAuth2.
//
// [OIDC differences]: https://bitbucket.org/openid/connect/src/master/How-Sign-in-with-Apple-differs-from-OpenID-Connect.md
type Apple struct {
BaseProvider
jwksURL string
}
// NewAppleProvider creates a new Apple provider instance with some defaults.
func NewAppleProvider() *Apple {
return &Apple{
BaseProvider: BaseProvider{
ctx: context.Background(),
displayName: "Apple",
pkce: true,
scopes: []string{"name", "email"},
authURL: "https://appleid.apple.com/auth/authorize",
tokenURL: "https://appleid.apple.com/auth/token",
},
jwksURL: "https://appleid.apple.com/auth/keys",
}
}
// FetchAuthUser returns an AuthUser instance based on the provided token.
//
// API reference: https://developer.apple.com/documentation/sign_in_with_apple/tokenresponse.
func (p *Apple) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
data, err := p.FetchRawUserInfo(token)
if err != nil {
return nil, err
}
rawUser := map[string]any{}
if err := json.Unmarshal(data, &rawUser); err != nil {
return nil, err
}
extracted := struct {
Id string `json:"sub"`
Name string `json:"name"`
Email string `json:"email"`
EmailVerified any `json:"email_verified"` // could be string or bool
User struct {
Name struct {
FirstName string `json:"firstName"`
LastName string `json:"lastName"`
} `json:"name"`
} `json:"user"`
}{}
if err := json.Unmarshal(data, &extracted); err != nil {
return nil, err
}
user := &AuthUser{
Id: extracted.Id,
Name: extracted.Name,
RawUser: rawUser,
AccessToken: token.AccessToken,
RefreshToken: token.RefreshToken,
}
user.Expiry, _ = types.ParseDateTime(token.Expiry)
if cast.ToBool(extracted.EmailVerified) {
user.Email = extracted.Email
}
if user.Name == "" {
user.Name = strings.TrimSpace(extracted.User.Name.FirstName + " " + extracted.User.Name.LastName)
}
return user, nil
}
// FetchRawUserInfo implements Provider.FetchRawUserInfo interface.
//
// Apple doesn't have a UserInfo endpoint and claims about users
// are instead included in the "id_token" (https://openid.net/specs/openid-connect-core-1_0.html#id_tokenExample)
func (p *Apple) FetchRawUserInfo(token *oauth2.Token) ([]byte, error) {
idToken, _ := token.Extra("id_token").(string)
claims, err := p.parseAndVerifyIdToken(idToken)
if err != nil {
return nil, err
}
// Apple only returns the user object the first time the user authorizes the app
// https://developer.apple.com/documentation/sign_in_with_apple/sign_in_with_apple_js/configuring_your_webpage_for_sign_in_with_apple#3331292
rawUser, _ := token.Extra("user").(string)
if rawUser != "" {
user := map[string]any{}
err = json.Unmarshal([]byte(rawUser), &user)
if err != nil {
return nil, err
}
claims["user"] = user
}
return json.Marshal(claims)
}
func (p *Apple) parseAndVerifyIdToken(idToken string) (jwt.MapClaims, error) {
if idToken == "" {
return nil, errors.New("empty id_token")
}
// extract the token header params and claims
// ---
claims := jwt.MapClaims{}
t, _, err := jwt.NewParser().ParseUnverified(idToken, claims)
if err != nil {
return nil, err
}
// validate common claims per https://developer.apple.com/documentation/sign_in_with_apple/sign_in_with_apple_rest_api/verifying_a_user#3383769
// ---
jwtValidator := jwt.NewValidator(
jwt.WithExpirationRequired(),
jwt.WithIssuedAt(),
jwt.WithLeeway(idTokenLeeway),
jwt.WithIssuer("https://appleid.apple.com"),
jwt.WithAudience(p.clientId),
)
err = jwtValidator.Validate(claims)
if err != nil {
return nil, err
}
// validate id_token signature
//
// note: this step could be technically considered optional because we trust
// the token which is a result of direct TLS communication with the provider
// (see also https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation)
// ---
kid, _ := t.Header["kid"].(string)
err = validateIdTokenSignature(p.ctx, idToken, p.jwksURL, kid)
if err != nil {
return nil, err
}
return claims, nil
}

157
tools/auth/auth.go Normal file
View file

@ -0,0 +1,157 @@
package auth
import (
"context"
"encoding/json"
"errors"
"net/http"
"github.com/pocketbase/pocketbase/tools/types"
"golang.org/x/oauth2"
)
// ProviderFactoryFunc defines a function for initializing a new OAuth2 provider.
type ProviderFactoryFunc func() Provider
// Providers defines a map with all of the available OAuth2 providers.
//
// To register a new provider append a new entry in the map.
var Providers = map[string]ProviderFactoryFunc{}
// NewProviderByName returns a new preconfigured provider instance by its name identifier.
func NewProviderByName(name string) (Provider, error) {
factory, ok := Providers[name]
if !ok {
return nil, errors.New("missing provider " + name)
}
return factory(), nil
}
// Provider defines a common interface for an OAuth2 client.
type Provider interface {
// Context returns the context associated with the provider (if any).
Context() context.Context
// SetContext assigns the specified context to the current provider.
SetContext(ctx context.Context)
// PKCE indicates whether the provider can use the PKCE flow.
PKCE() bool
// SetPKCE toggles the state whether the provider can use the PKCE flow or not.
SetPKCE(enable bool)
// DisplayName usually returns provider name as it is officially written
// and it could be used directly in the UI.
DisplayName() string
// SetDisplayName sets the provider's display name.
SetDisplayName(displayName string)
// Scopes returns the provider access permissions that will be requested.
Scopes() []string
// SetScopes sets the provider access permissions that will be requested later.
SetScopes(scopes []string)
// ClientId returns the provider client's app ID.
ClientId() string
// SetClientId sets the provider client's ID.
SetClientId(clientId string)
// ClientSecret returns the provider client's app secret.
ClientSecret() string
// SetClientSecret sets the provider client's app secret.
SetClientSecret(secret string)
// RedirectURL returns the end address to redirect the user
// going through the OAuth flow.
RedirectURL() string
// SetRedirectURL sets the provider's RedirectURL.
SetRedirectURL(url string)
// AuthURL returns the provider's authorization service url.
AuthURL() string
// SetAuthURL sets the provider's AuthURL.
SetAuthURL(url string)
// TokenURL returns the provider's token exchange service url.
TokenURL() string
// SetTokenURL sets the provider's TokenURL.
SetTokenURL(url string)
// UserInfoURL returns the provider's user info api url.
UserInfoURL() string
// SetUserInfoURL sets the provider's UserInfoURL.
SetUserInfoURL(url string)
// Extra returns a shallow copy of any custom config data
// that the provider may be need.
Extra() map[string]any
// SetExtra updates the provider's custom config data.
SetExtra(data map[string]any)
// Client returns an http client using the provided token.
Client(token *oauth2.Token) *http.Client
// BuildAuthURL returns a URL to the provider's consent page
// that asks for permissions for the required scopes explicitly.
BuildAuthURL(state string, opts ...oauth2.AuthCodeOption) string
// FetchToken converts an authorization code to token.
FetchToken(code string, opts ...oauth2.AuthCodeOption) (*oauth2.Token, error)
// FetchRawUserInfo requests and marshalizes into `result` the
// the OAuth user api response.
FetchRawUserInfo(token *oauth2.Token) ([]byte, error)
// FetchAuthUser is similar to FetchRawUserInfo, but normalizes and
// marshalizes the user api response into a standardized AuthUser struct.
FetchAuthUser(token *oauth2.Token) (user *AuthUser, err error)
}
// wrapFactory is a helper that wraps a Provider specific factory
// function and returns its result as Provider interface.
func wrapFactory[T Provider](factory func() T) ProviderFactoryFunc {
return func() Provider {
return factory()
}
}
// AuthUser defines a standardized OAuth2 user data structure.
type AuthUser struct {
Expiry types.DateTime `json:"expiry"`
RawUser map[string]any `json:"rawUser"`
Id string `json:"id"`
Name string `json:"name"`
Username string `json:"username"`
Email string `json:"email"`
AvatarURL string `json:"avatarURL"`
AccessToken string `json:"accessToken"`
RefreshToken string `json:"refreshToken"`
// @todo
// deprecated: use AvatarURL instead
// AvatarUrl will be removed after dropping v0.22 support
AvatarUrl string `json:"avatarUrl"`
}
// MarshalJSON implements the [json.Marshaler] interface.
//
// @todo remove after dropping v0.22 support
func (au AuthUser) MarshalJSON() ([]byte, error) {
type alias AuthUser // prevent recursion
au2 := alias(au)
au2.AvatarUrl = au.AvatarURL // ensure that the legacy field is populated
return json.Marshal(au2)
}

299
tools/auth/auth_test.go Normal file
View file

@ -0,0 +1,299 @@
package auth_test
import (
"testing"
"github.com/pocketbase/pocketbase/tools/auth"
)
func TestProvidersCount(t *testing.T) {
expected := 30
if total := len(auth.Providers); total != expected {
t.Fatalf("Expected %d providers, got %d", expected, total)
}
}
func TestNewProviderByName(t *testing.T) {
var err error
var p auth.Provider
// invalid
p, err = auth.NewProviderByName("invalid")
if err == nil {
t.Error("Expected error, got nil")
}
if p != nil {
t.Errorf("Expected provider to be nil, got %v", p)
}
// google
p, err = auth.NewProviderByName(auth.NameGoogle)
if err != nil {
t.Errorf("Expected nil, got error %v", err)
}
if _, ok := p.(*auth.Google); !ok {
t.Error("Expected to be instance of *auth.Google")
}
// facebook
p, err = auth.NewProviderByName(auth.NameFacebook)
if err != nil {
t.Errorf("Expected nil, got error %v", err)
}
if _, ok := p.(*auth.Facebook); !ok {
t.Error("Expected to be instance of *auth.Facebook")
}
// github
p, err = auth.NewProviderByName(auth.NameGithub)
if err != nil {
t.Errorf("Expected nil, got error %v", err)
}
if _, ok := p.(*auth.Github); !ok {
t.Error("Expected to be instance of *auth.Github")
}
// gitlab
p, err = auth.NewProviderByName(auth.NameGitlab)
if err != nil {
t.Errorf("Expected nil, got error %v", err)
}
if _, ok := p.(*auth.Gitlab); !ok {
t.Error("Expected to be instance of *auth.Gitlab")
}
// twitter
p, err = auth.NewProviderByName(auth.NameTwitter)
if err != nil {
t.Errorf("Expected nil, got error %v", err)
}
if _, ok := p.(*auth.Twitter); !ok {
t.Error("Expected to be instance of *auth.Twitter")
}
// discord
p, err = auth.NewProviderByName(auth.NameDiscord)
if err != nil {
t.Errorf("Expected nil, got error %v", err)
}
if _, ok := p.(*auth.Discord); !ok {
t.Error("Expected to be instance of *auth.Discord")
}
// microsoft
p, err = auth.NewProviderByName(auth.NameMicrosoft)
if err != nil {
t.Errorf("Expected nil, got error %v", err)
}
if _, ok := p.(*auth.Microsoft); !ok {
t.Error("Expected to be instance of *auth.Microsoft")
}
// spotify
p, err = auth.NewProviderByName(auth.NameSpotify)
if err != nil {
t.Errorf("Expected nil, got error %v", err)
}
if _, ok := p.(*auth.Spotify); !ok {
t.Error("Expected to be instance of *auth.Spotify")
}
// kakao
p, err = auth.NewProviderByName(auth.NameKakao)
if err != nil {
t.Errorf("Expected nil, got error %v", err)
}
if _, ok := p.(*auth.Kakao); !ok {
t.Error("Expected to be instance of *auth.Kakao")
}
// twitch
p, err = auth.NewProviderByName(auth.NameTwitch)
if err != nil {
t.Errorf("Expected nil, got error %v", err)
}
if _, ok := p.(*auth.Twitch); !ok {
t.Error("Expected to be instance of *auth.Twitch")
}
// strava
p, err = auth.NewProviderByName(auth.NameStrava)
if err != nil {
t.Errorf("Expected nil, got error %v", err)
}
if _, ok := p.(*auth.Strava); !ok {
t.Error("Expected to be instance of *auth.Strava")
}
// gitee
p, err = auth.NewProviderByName(auth.NameGitee)
if err != nil {
t.Errorf("Expected nil, got error %v", err)
}
if _, ok := p.(*auth.Gitee); !ok {
t.Error("Expected to be instance of *auth.Gitee")
}
// livechat
p, err = auth.NewProviderByName(auth.NameLivechat)
if err != nil {
t.Errorf("Expected nil, got error %v", err)
}
if _, ok := p.(*auth.Livechat); !ok {
t.Error("Expected to be instance of *auth.Livechat")
}
// gitea
p, err = auth.NewProviderByName(auth.NameGitea)
if err != nil {
t.Errorf("Expected nil, got error %v", err)
}
if _, ok := p.(*auth.Gitea); !ok {
t.Error("Expected to be instance of *auth.Gitea")
}
// oidc
p, err = auth.NewProviderByName(auth.NameOIDC)
if err != nil {
t.Errorf("Expected nil, got error %v", err)
}
if _, ok := p.(*auth.OIDC); !ok {
t.Error("Expected to be instance of *auth.OIDC")
}
// oidc2
p, err = auth.NewProviderByName(auth.NameOIDC + "2")
if err != nil {
t.Errorf("Expected nil, got error %v", err)
}
if _, ok := p.(*auth.OIDC); !ok {
t.Error("Expected to be instance of *auth.OIDC")
}
// oidc3
p, err = auth.NewProviderByName(auth.NameOIDC + "3")
if err != nil {
t.Errorf("Expected nil, got error %v", err)
}
if _, ok := p.(*auth.OIDC); !ok {
t.Error("Expected to be instance of *auth.OIDC")
}
// apple
p, err = auth.NewProviderByName(auth.NameApple)
if err != nil {
t.Errorf("Expected nil, got error %v", err)
}
if _, ok := p.(*auth.Apple); !ok {
t.Error("Expected to be instance of *auth.Apple")
}
// instagram
p, err = auth.NewProviderByName(auth.NameInstagram)
if err != nil {
t.Errorf("Expected nil, got error %v", err)
}
if _, ok := p.(*auth.Instagram); !ok {
t.Error("Expected to be instance of *auth.Instagram")
}
// vk
p, err = auth.NewProviderByName(auth.NameVK)
if err != nil {
t.Errorf("Expected nil, got error %v", err)
}
if _, ok := p.(*auth.VK); !ok {
t.Error("Expected to be instance of *auth.VK")
}
// yandex
p, err = auth.NewProviderByName(auth.NameYandex)
if err != nil {
t.Errorf("Expected nil, got error %v", err)
}
if _, ok := p.(*auth.Yandex); !ok {
t.Error("Expected to be instance of *auth.Yandex")
}
// patreon
p, err = auth.NewProviderByName(auth.NamePatreon)
if err != nil {
t.Errorf("Expected nil, got error %v", err)
}
if _, ok := p.(*auth.Patreon); !ok {
t.Error("Expected to be instance of *auth.Patreon")
}
// mailcow
p, err = auth.NewProviderByName(auth.NameMailcow)
if err != nil {
t.Errorf("Expected nil, got error %v", err)
}
if _, ok := p.(*auth.Mailcow); !ok {
t.Error("Expected to be instance of *auth.Mailcow")
}
// bitbucket
p, err = auth.NewProviderByName(auth.NameBitbucket)
if err != nil {
t.Errorf("Expected nil, got error %v", err)
}
if _, ok := p.(*auth.Bitbucket); !ok {
t.Error("Expected to be instance of *auth.Bitbucket")
}
// planningcenter
p, err = auth.NewProviderByName(auth.NamePlanningcenter)
if err != nil {
t.Errorf("Expected nil, got error %v", err)
}
if _, ok := p.(*auth.Planningcenter); !ok {
t.Error("Expected to be instance of *auth.Planningcenter")
}
// notion
p, err = auth.NewProviderByName(auth.NameNotion)
if err != nil {
t.Errorf("Expected nil, got error %v", err)
}
if _, ok := p.(*auth.Notion); !ok {
t.Error("Expected to be instance of *auth.Notion")
}
// monday
p, err = auth.NewProviderByName(auth.NameMonday)
if err != nil {
t.Errorf("Expected nil, got error %v", err)
}
if _, ok := p.(*auth.Monday); !ok {
t.Error("Expected to be instance of *auth.Monday")
}
// wakatime
p, err = auth.NewProviderByName(auth.NameWakatime)
if err != nil {
t.Errorf("Expected nil, got error %v", err)
}
if _, ok := p.(*auth.Wakatime); !ok {
t.Error("Expected to be instance of *auth.Wakatime")
}
// linear
p, err = auth.NewProviderByName(auth.NameLinear)
if err != nil {
t.Errorf("Expected nil, got error %v", err)
}
if _, ok := p.(*auth.Linear); !ok {
t.Error("Expected to be instance of *auth.Linear")
}
// trakt
p, err = auth.NewProviderByName(auth.NameTrakt)
if err != nil {
t.Errorf("Expected nil, got error %v", err)
}
if _, ok := p.(*auth.Trakt); !ok {
t.Error("Expected to be instance of *auth.Trakt")
}
}

203
tools/auth/base_provider.go Normal file
View file

@ -0,0 +1,203 @@
package auth
import (
"context"
"fmt"
"io"
"maps"
"net/http"
"golang.org/x/oauth2"
)
// BaseProvider defines common fields and methods used by OAuth2 client providers.
type BaseProvider struct {
ctx context.Context
clientId string
clientSecret string
displayName string
redirectURL string
authURL string
tokenURL string
userInfoURL string
scopes []string
pkce bool
extra map[string]any
}
// Context implements Provider.Context() interface method.
func (p *BaseProvider) Context() context.Context {
return p.ctx
}
// SetContext implements Provider.SetContext() interface method.
func (p *BaseProvider) SetContext(ctx context.Context) {
p.ctx = ctx
}
// PKCE implements Provider.PKCE() interface method.
func (p *BaseProvider) PKCE() bool {
return p.pkce
}
// SetPKCE implements Provider.SetPKCE() interface method.
func (p *BaseProvider) SetPKCE(enable bool) {
p.pkce = enable
}
// DisplayName implements Provider.DisplayName() interface method.
func (p *BaseProvider) DisplayName() string {
return p.displayName
}
// SetDisplayName implements Provider.SetDisplayName() interface method.
func (p *BaseProvider) SetDisplayName(displayName string) {
p.displayName = displayName
}
// Scopes implements Provider.Scopes() interface method.
func (p *BaseProvider) Scopes() []string {
return p.scopes
}
// SetScopes implements Provider.SetScopes() interface method.
func (p *BaseProvider) SetScopes(scopes []string) {
p.scopes = scopes
}
// ClientId implements Provider.ClientId() interface method.
func (p *BaseProvider) ClientId() string {
return p.clientId
}
// SetClientId implements Provider.SetClientId() interface method.
func (p *BaseProvider) SetClientId(clientId string) {
p.clientId = clientId
}
// ClientSecret implements Provider.ClientSecret() interface method.
func (p *BaseProvider) ClientSecret() string {
return p.clientSecret
}
// SetClientSecret implements Provider.SetClientSecret() interface method.
func (p *BaseProvider) SetClientSecret(secret string) {
p.clientSecret = secret
}
// RedirectURL implements Provider.RedirectURL() interface method.
func (p *BaseProvider) RedirectURL() string {
return p.redirectURL
}
// SetRedirectURL implements Provider.SetRedirectURL() interface method.
func (p *BaseProvider) SetRedirectURL(url string) {
p.redirectURL = url
}
// AuthURL implements Provider.AuthURL() interface method.
func (p *BaseProvider) AuthURL() string {
return p.authURL
}
// SetAuthURL implements Provider.SetAuthURL() interface method.
func (p *BaseProvider) SetAuthURL(url string) {
p.authURL = url
}
// TokenURL implements Provider.TokenURL() interface method.
func (p *BaseProvider) TokenURL() string {
return p.tokenURL
}
// SetTokenURL implements Provider.SetTokenURL() interface method.
func (p *BaseProvider) SetTokenURL(url string) {
p.tokenURL = url
}
// UserInfoURL implements Provider.UserInfoURL() interface method.
func (p *BaseProvider) UserInfoURL() string {
return p.userInfoURL
}
// SetUserInfoURL implements Provider.SetUserInfoURL() interface method.
func (p *BaseProvider) SetUserInfoURL(url string) {
p.userInfoURL = url
}
// Extra implements Provider.Extra() interface method.
func (p *BaseProvider) Extra() map[string]any {
return maps.Clone(p.extra)
}
// SetExtra implements Provider.SetExtra() interface method.
func (p *BaseProvider) SetExtra(data map[string]any) {
p.extra = data
}
// BuildAuthURL implements Provider.BuildAuthURL() interface method.
func (p *BaseProvider) BuildAuthURL(state string, opts ...oauth2.AuthCodeOption) string {
return p.oauth2Config().AuthCodeURL(state, opts...)
}
// FetchToken implements Provider.FetchToken() interface method.
func (p *BaseProvider) FetchToken(code string, opts ...oauth2.AuthCodeOption) (*oauth2.Token, error) {
return p.oauth2Config().Exchange(p.ctx, code, opts...)
}
// Client implements Provider.Client() interface method.
func (p *BaseProvider) Client(token *oauth2.Token) *http.Client {
return p.oauth2Config().Client(p.ctx, token)
}
// FetchRawUserInfo implements Provider.FetchRawUserInfo() interface method.
func (p *BaseProvider) FetchRawUserInfo(token *oauth2.Token) ([]byte, error) {
req, err := http.NewRequestWithContext(p.ctx, "GET", p.userInfoURL, nil)
if err != nil {
return nil, err
}
return p.sendRawUserInfoRequest(req, token)
}
// sendRawUserInfoRequest sends the specified user info request and return its raw response body.
func (p *BaseProvider) sendRawUserInfoRequest(req *http.Request, token *oauth2.Token) ([]byte, error) {
client := p.Client(token)
res, err := client.Do(req)
if err != nil {
return nil, err
}
defer res.Body.Close()
result, err := io.ReadAll(res.Body)
if err != nil {
return nil, err
}
// http.Client.Get doesn't treat non 2xx responses as error
if res.StatusCode >= 400 {
return nil, fmt.Errorf(
"failed to fetch OAuth2 user profile via %s (%d):\n%s",
p.userInfoURL,
res.StatusCode,
string(result),
)
}
return result, nil
}
// oauth2Config constructs a oauth2.Config instance based on the provider settings.
func (p *BaseProvider) oauth2Config() *oauth2.Config {
return &oauth2.Config{
RedirectURL: p.redirectURL,
ClientID: p.clientId,
ClientSecret: p.clientSecret,
Scopes: p.scopes,
Endpoint: oauth2.Endpoint{
AuthURL: p.authURL,
TokenURL: p.tokenURL,
},
}
}

View file

@ -0,0 +1,269 @@
package auth
import (
"bytes"
"context"
"encoding/json"
"testing"
"golang.org/x/oauth2"
)
func TestContext(t *testing.T) {
b := BaseProvider{}
before := b.Scopes()
if before != nil {
t.Errorf("Expected nil context, got %v", before)
}
b.SetContext(context.Background())
after := b.Scopes()
if after != nil {
t.Error("Expected non-nil context")
}
}
func TestDisplayName(t *testing.T) {
b := BaseProvider{}
before := b.DisplayName()
if before != "" {
t.Fatalf("Expected displayName to be empty, got %v", before)
}
b.SetDisplayName("test")
after := b.DisplayName()
if after != "test" {
t.Fatalf("Expected displayName to be 'test', got %v", after)
}
}
func TestPKCE(t *testing.T) {
b := BaseProvider{}
before := b.PKCE()
if before != false {
t.Fatalf("Expected pkce to be %v, got %v", false, before)
}
b.SetPKCE(true)
after := b.PKCE()
if after != true {
t.Fatalf("Expected pkce to be %v, got %v", true, after)
}
}
func TestScopes(t *testing.T) {
b := BaseProvider{}
before := b.Scopes()
if len(before) != 0 {
t.Fatalf("Expected 0 scopes, got %v", before)
}
b.SetScopes([]string{"test1", "test2"})
after := b.Scopes()
if len(after) != 2 {
t.Fatalf("Expected 2 scopes, got %v", after)
}
}
func TestClientId(t *testing.T) {
b := BaseProvider{}
before := b.ClientId()
if before != "" {
t.Fatalf("Expected clientId to be empty, got %v", before)
}
b.SetClientId("test")
after := b.ClientId()
if after != "test" {
t.Fatalf("Expected clientId to be 'test', got %v", after)
}
}
func TestClientSecret(t *testing.T) {
b := BaseProvider{}
before := b.ClientSecret()
if before != "" {
t.Fatalf("Expected clientSecret to be empty, got %v", before)
}
b.SetClientSecret("test")
after := b.ClientSecret()
if after != "test" {
t.Fatalf("Expected clientSecret to be 'test', got %v", after)
}
}
func TestRedirectURL(t *testing.T) {
b := BaseProvider{}
before := b.RedirectURL()
if before != "" {
t.Fatalf("Expected RedirectURL to be empty, got %v", before)
}
b.SetRedirectURL("test")
after := b.RedirectURL()
if after != "test" {
t.Fatalf("Expected RedirectURL to be 'test', got %v", after)
}
}
func TestAuthURL(t *testing.T) {
b := BaseProvider{}
before := b.AuthURL()
if before != "" {
t.Fatalf("Expected authURL to be empty, got %v", before)
}
b.SetAuthURL("test")
after := b.AuthURL()
if after != "test" {
t.Fatalf("Expected authURL to be 'test', got %v", after)
}
}
func TestTokenURL(t *testing.T) {
b := BaseProvider{}
before := b.TokenURL()
if before != "" {
t.Fatalf("Expected tokenURL to be empty, got %v", before)
}
b.SetTokenURL("test")
after := b.TokenURL()
if after != "test" {
t.Fatalf("Expected tokenURL to be 'test', got %v", after)
}
}
func TestUserInfoURL(t *testing.T) {
b := BaseProvider{}
before := b.UserInfoURL()
if before != "" {
t.Fatalf("Expected userInfoURL to be empty, got %v", before)
}
b.SetUserInfoURL("test")
after := b.UserInfoURL()
if after != "test" {
t.Fatalf("Expected userInfoURL to be 'test', got %v", after)
}
}
func TestExtra(t *testing.T) {
b := BaseProvider{}
before := b.Extra()
if before != nil {
t.Fatalf("Expected extra to be empty, got %v", before)
}
extra := map[string]any{"a": 1, "b": 2}
b.SetExtra(extra)
after := b.Extra()
rawExtra, err := json.Marshal(extra)
if err != nil {
t.Fatal(err)
}
rawAfter, err := json.Marshal(after)
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(rawExtra, rawAfter) {
t.Fatalf("Expected extra to be\n%s\ngot\n%s", rawExtra, rawAfter)
}
// ensure that it was shallow copied
after["b"] = 3
if d := b.Extra(); d["b"] != 2 {
t.Fatalf("Expected extra to remain unchanged, got\n%v", d)
}
}
func TestBuildAuthURL(t *testing.T) {
b := BaseProvider{
authURL: "authURL_test",
tokenURL: "tokenURL_test",
redirectURL: "redirectURL_test",
clientId: "clientId_test",
clientSecret: "clientSecret_test",
scopes: []string{"test_scope"},
}
expected := "authURL_test?access_type=offline&client_id=clientId_test&prompt=consent&redirect_uri=redirectURL_test&response_type=code&scope=test_scope&state=state_test"
result := b.BuildAuthURL("state_test", oauth2.AccessTypeOffline, oauth2.ApprovalForce)
if result != expected {
t.Errorf("Expected auth url %q, got %q", expected, result)
}
}
func TestClient(t *testing.T) {
b := BaseProvider{}
result := b.Client(&oauth2.Token{})
if result == nil {
t.Error("Expected *http.Client instance, got nil")
}
}
func TestOauth2Config(t *testing.T) {
b := BaseProvider{
authURL: "authURL_test",
tokenURL: "tokenURL_test",
redirectURL: "redirectURL_test",
clientId: "clientId_test",
clientSecret: "clientSecret_test",
scopes: []string{"test"},
}
result := b.oauth2Config()
if result.RedirectURL != b.RedirectURL() {
t.Errorf("Expected redirectURL %s, got %s", b.RedirectURL(), result.RedirectURL)
}
if result.ClientID != b.ClientId() {
t.Errorf("Expected clientId %s, got %s", b.ClientId(), result.ClientID)
}
if result.ClientSecret != b.ClientSecret() {
t.Errorf("Expected clientSecret %s, got %s", b.ClientSecret(), result.ClientSecret)
}
if result.Endpoint.AuthURL != b.AuthURL() {
t.Errorf("Expected authURL %s, got %s", b.AuthURL(), result.Endpoint.AuthURL)
}
if result.Endpoint.TokenURL != b.TokenURL() {
t.Errorf("Expected authURL %s, got %s", b.TokenURL(), result.Endpoint.TokenURL)
}
if len(result.Scopes) != len(b.Scopes()) || result.Scopes[0] != b.Scopes()[0] {
t.Errorf("Expected scopes %s, got %s", b.Scopes(), result.Scopes)
}
}

136
tools/auth/bitbucket.go Normal file
View file

@ -0,0 +1,136 @@
package auth
import (
"context"
"encoding/json"
"errors"
"io"
"github.com/pocketbase/pocketbase/tools/types"
"golang.org/x/oauth2"
)
func init() {
Providers[NameBitbucket] = wrapFactory(NewBitbucketProvider)
}
var _ Provider = (*Bitbucket)(nil)
// NameBitbucket is the unique name of the Bitbucket provider.
const NameBitbucket = "bitbucket"
// Bitbucket is an auth provider for Bitbucket.
type Bitbucket struct {
BaseProvider
}
// NewBitbucketProvider creates a new Bitbucket provider instance with some defaults.
func NewBitbucketProvider() *Bitbucket {
return &Bitbucket{BaseProvider{
ctx: context.Background(),
displayName: "Bitbucket",
pkce: false,
scopes: []string{"account"},
authURL: "https://bitbucket.org/site/oauth2/authorize",
tokenURL: "https://bitbucket.org/site/oauth2/access_token",
userInfoURL: "https://api.bitbucket.org/2.0/user",
}}
}
// FetchAuthUser returns an AuthUser instance based on the Bitbucket's user API.
//
// API reference: https://developer.atlassian.com/cloud/bitbucket/rest/api-group-users/#api-user-get
func (p *Bitbucket) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
data, err := p.FetchRawUserInfo(token)
if err != nil {
return nil, err
}
rawUser := map[string]any{}
if err := json.Unmarshal(data, &rawUser); err != nil {
return nil, err
}
extracted := struct {
UUID string `json:"uuid"`
Username string `json:"username"`
DisplayName string `json:"display_name"`
AccountStatus string `json:"account_status"`
Links struct {
Avatar struct {
Href string `json:"href"`
} `json:"avatar"`
} `json:"links"`
}{}
if err := json.Unmarshal(data, &extracted); err != nil {
return nil, err
}
if extracted.AccountStatus != "active" {
return nil, errors.New("the Bitbucket user is not active")
}
email, err := p.fetchPrimaryEmail(token)
if err != nil {
return nil, err
}
user := &AuthUser{
Id: extracted.UUID,
Name: extracted.DisplayName,
Username: extracted.Username,
Email: email,
AvatarURL: extracted.Links.Avatar.Href,
RawUser: rawUser,
AccessToken: token.AccessToken,
RefreshToken: token.RefreshToken,
}
user.Expiry, _ = types.ParseDateTime(token.Expiry)
return user, nil
}
// fetchPrimaryEmail sends an API request to retrieve the first
// verified primary email.
//
// NB! This method can succeed and still return an empty email.
// Error responses that are result of insufficient scopes permissions are ignored.
//
// API reference: https://developer.atlassian.com/cloud/bitbucket/rest/api-group-users/#api-user-emails-get
func (p *Bitbucket) fetchPrimaryEmail(token *oauth2.Token) (string, error) {
response, err := p.Client(token).Get(p.userInfoURL + "/emails")
if err != nil {
return "", err
}
defer response.Body.Close()
// ignore common http errors caused by insufficient scope permissions
// (the email field is optional, aka. return the auth user without it)
if response.StatusCode >= 400 {
return "", nil
}
data, err := io.ReadAll(response.Body)
if err != nil {
return "", err
}
expected := struct {
Values []struct {
Email string `json:"email"`
IsPrimary bool `json:"is_primary"`
} `json:"values"`
}{}
if err := json.Unmarshal(data, &expected); err != nil {
return "", err
}
for _, v := range expected.Values {
if v.IsPrimary {
return v.Email, nil
}
}
return "", nil
}

91
tools/auth/discord.go Normal file
View file

@ -0,0 +1,91 @@
package auth
import (
"context"
"encoding/json"
"fmt"
"github.com/pocketbase/pocketbase/tools/types"
"golang.org/x/oauth2"
)
func init() {
Providers[NameDiscord] = wrapFactory(NewDiscordProvider)
}
var _ Provider = (*Discord)(nil)
// NameDiscord is the unique name of the Discord provider.
const NameDiscord string = "discord"
// Discord allows authentication via Discord OAuth2.
type Discord struct {
BaseProvider
}
// NewDiscordProvider creates a new Discord provider instance with some defaults.
func NewDiscordProvider() *Discord {
// https://discord.com/developers/docs/topics/oauth2
// https://discord.com/developers/docs/resources/user#get-current-user
return &Discord{BaseProvider{
ctx: context.Background(),
displayName: "Discord",
pkce: true,
scopes: []string{"identify", "email"},
authURL: "https://discord.com/api/oauth2/authorize",
tokenURL: "https://discord.com/api/oauth2/token",
userInfoURL: "https://discord.com/api/users/@me",
}}
}
// FetchAuthUser returns an AuthUser instance from Discord's user api.
//
// API reference: https://discord.com/developers/docs/resources/user#user-object
func (p *Discord) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
data, err := p.FetchRawUserInfo(token)
if err != nil {
return nil, err
}
rawUser := map[string]any{}
if err := json.Unmarshal(data, &rawUser); err != nil {
return nil, err
}
extracted := struct {
Id string `json:"id"`
Username string `json:"username"`
Discriminator string `json:"discriminator"`
Avatar string `json:"avatar"`
Email string `json:"email"`
Verified bool `json:"verified"`
}{}
if err := json.Unmarshal(data, &extracted); err != nil {
return nil, err
}
// Build a full avatar URL using the avatar hash provided in the API response
// https://discord.com/developers/docs/reference#image-formatting
avatarURL := fmt.Sprintf("https://cdn.discordapp.com/avatars/%s/%s.png", extracted.Id, extracted.Avatar)
// Concatenate the user's username and discriminator into a single username string
username := fmt.Sprintf("%s#%s", extracted.Username, extracted.Discriminator)
user := &AuthUser{
Id: extracted.Id,
Name: username,
Username: extracted.Username,
AvatarURL: avatarURL,
RawUser: rawUser,
AccessToken: token.AccessToken,
RefreshToken: token.RefreshToken,
}
user.Expiry, _ = types.ParseDateTime(token.Expiry)
if extracted.Verified {
user.Email = extracted.Email
}
return user, nil
}

78
tools/auth/facebook.go Normal file
View file

@ -0,0 +1,78 @@
package auth
import (
"context"
"encoding/json"
"github.com/pocketbase/pocketbase/tools/types"
"golang.org/x/oauth2"
"golang.org/x/oauth2/facebook"
)
func init() {
Providers[NameFacebook] = wrapFactory(NewFacebookProvider)
}
var _ Provider = (*Facebook)(nil)
// NameFacebook is the unique name of the Facebook provider.
const NameFacebook string = "facebook"
// Facebook allows authentication via Facebook OAuth2.
type Facebook struct {
BaseProvider
}
// NewFacebookProvider creates new Facebook provider instance with some defaults.
func NewFacebookProvider() *Facebook {
return &Facebook{BaseProvider{
ctx: context.Background(),
displayName: "Facebook",
pkce: true,
scopes: []string{"email"},
authURL: facebook.Endpoint.AuthURL,
tokenURL: facebook.Endpoint.TokenURL,
userInfoURL: "https://graph.facebook.com/me?fields=name,email,picture.type(large)",
}}
}
// FetchAuthUser returns an AuthUser instance based on the Facebook's user api.
//
// API reference: https://developers.facebook.com/docs/graph-api/reference/user/
func (p *Facebook) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
data, err := p.FetchRawUserInfo(token)
if err != nil {
return nil, err
}
rawUser := map[string]any{}
if err := json.Unmarshal(data, &rawUser); err != nil {
return nil, err
}
extracted := struct {
Id string
Name string
Email string
Picture struct {
Data struct{ Url string }
}
}{}
if err := json.Unmarshal(data, &extracted); err != nil {
return nil, err
}
user := &AuthUser{
Id: extracted.Id,
Name: extracted.Name,
Email: extracted.Email,
AvatarURL: extracted.Picture.Data.Url,
RawUser: rawUser,
AccessToken: token.AccessToken,
RefreshToken: token.RefreshToken,
}
user.Expiry, _ = types.ParseDateTime(token.Expiry)
return user, nil
}

78
tools/auth/gitea.go Normal file
View file

@ -0,0 +1,78 @@
package auth
import (
"context"
"encoding/json"
"strconv"
"github.com/pocketbase/pocketbase/tools/types"
"golang.org/x/oauth2"
)
func init() {
Providers[NameGitea] = wrapFactory(NewGiteaProvider)
}
var _ Provider = (*Gitea)(nil)
// NameGitea is the unique name of the Gitea provider.
const NameGitea string = "gitea"
// Gitea allows authentication via Gitea OAuth2.
type Gitea struct {
BaseProvider
}
// NewGiteaProvider creates new Gitea provider instance with some defaults.
func NewGiteaProvider() *Gitea {
return &Gitea{BaseProvider{
ctx: context.Background(),
displayName: "Gitea",
pkce: true,
scopes: []string{"read:user", "user:email"},
authURL: "https://gitea.com/login/oauth/authorize",
tokenURL: "https://gitea.com/login/oauth/access_token",
userInfoURL: "https://gitea.com/api/v1/user",
}}
}
// FetchAuthUser returns an AuthUser instance based on Gitea's user api.
//
// API reference: https://try.gitea.io/api/swagger#/user/userGetCurrent
func (p *Gitea) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
data, err := p.FetchRawUserInfo(token)
if err != nil {
return nil, err
}
rawUser := map[string]any{}
if err := json.Unmarshal(data, &rawUser); err != nil {
return nil, err
}
extracted := struct {
Name string `json:"full_name"`
Username string `json:"login"`
Email string `json:"email"`
AvatarURL string `json:"avatar_url"`
Id int64 `json:"id"`
}{}
if err := json.Unmarshal(data, &extracted); err != nil {
return nil, err
}
user := &AuthUser{
Id: strconv.FormatInt(extracted.Id, 10),
Name: extracted.Name,
Username: extracted.Username,
Email: extracted.Email,
AvatarURL: extracted.AvatarURL,
RawUser: rawUser,
AccessToken: token.AccessToken,
RefreshToken: token.RefreshToken,
}
user.Expiry, _ = types.ParseDateTime(token.Expiry)
return user, nil
}

142
tools/auth/gitee.go Normal file
View file

@ -0,0 +1,142 @@
package auth
import (
"context"
"encoding/json"
"io"
"strconv"
"github.com/go-ozzo/ozzo-validation/v4/is"
"github.com/pocketbase/pocketbase/tools/types"
"golang.org/x/oauth2"
)
func init() {
Providers[NameGitee] = wrapFactory(NewGiteeProvider)
}
var _ Provider = (*Gitee)(nil)
// NameGitee is the unique name of the Gitee provider.
const NameGitee string = "gitee"
// Gitee allows authentication via Gitee OAuth2.
type Gitee struct {
BaseProvider
}
// NewGiteeProvider creates new Gitee provider instance with some defaults.
func NewGiteeProvider() *Gitee {
return &Gitee{BaseProvider{
ctx: context.Background(),
displayName: "Gitee",
pkce: true,
scopes: []string{"user_info", "emails"},
authURL: "https://gitee.com/oauth/authorize",
tokenURL: "https://gitee.com/oauth/token",
userInfoURL: "https://gitee.com/api/v5/user",
}}
}
// FetchAuthUser returns an AuthUser instance based the Gitee's user api.
//
// API reference: https://gitee.com/api/v5/swagger#/getV5User
func (p *Gitee) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
data, err := p.FetchRawUserInfo(token)
if err != nil {
return nil, err
}
rawUser := map[string]any{}
if err := json.Unmarshal(data, &rawUser); err != nil {
return nil, err
}
extracted := struct {
Login string `json:"login"`
Name string `json:"name"`
Email string `json:"email"`
AvatarURL string `json:"avatar_url"`
Id int64 `json:"id"`
}{}
if err := json.Unmarshal(data, &extracted); err != nil {
return nil, err
}
user := &AuthUser{
Id: strconv.FormatInt(extracted.Id, 10),
Name: extracted.Name,
Username: extracted.Login,
AvatarURL: extracted.AvatarURL,
RawUser: rawUser,
AccessToken: token.AccessToken,
RefreshToken: token.RefreshToken,
}
user.Expiry, _ = types.ParseDateTime(token.Expiry)
if extracted.Email != "" && is.EmailFormat.Validate(extracted.Email) == nil {
// valid public primary email
user.Email = extracted.Email
} else {
// send an additional optional request to retrieve the email
email, err := p.fetchPrimaryEmail(token)
if err != nil {
return nil, err
}
user.Email = email
}
return user, nil
}
// fetchPrimaryEmail sends an API request to retrieve the verified primary email,
// in case the user hasn't set "Public email address" or has unchecked
// the "Access your emails data" permission during authentication.
//
// NB! This method can succeed and still return an empty email.
// Error responses that are result of insufficient scopes permissions are ignored.
//
// API reference: https://gitee.com/api/v5/swagger#/getV5Emails
func (p *Gitee) fetchPrimaryEmail(token *oauth2.Token) (string, error) {
client := p.Client(token)
response, err := client.Get("https://gitee.com/api/v5/emails")
if err != nil {
return "", err
}
defer response.Body.Close()
// ignore common http errors caused by insufficient scope permissions
if response.StatusCode == 401 || response.StatusCode == 403 || response.StatusCode == 404 {
return "", nil
}
content, err := io.ReadAll(response.Body)
if err != nil {
return "", err
}
emails := []struct {
Email string
State string
Scope []string
}{}
if err := json.Unmarshal(content, &emails); err != nil {
// ignore unmarshal error in case "Keep my email address private"
// was set because response.Body will be something like:
// {"email":"12285415+test@user.noreply.gitee.com"}
return "", nil
}
// extract the first verified primary email
for _, email := range emails {
for _, scope := range email.Scope {
if email.State == "confirmed" && scope == "primary" && is.EmailFormat.Validate(email.Email) == nil {
return email.Email, nil
}
}
}
return "", nil
}

136
tools/auth/github.go Normal file
View file

@ -0,0 +1,136 @@
package auth
import (
"context"
"encoding/json"
"io"
"strconv"
"github.com/pocketbase/pocketbase/tools/types"
"golang.org/x/oauth2"
"golang.org/x/oauth2/github"
)
func init() {
Providers[NameGithub] = wrapFactory(NewGithubProvider)
}
var _ Provider = (*Github)(nil)
// NameGithub is the unique name of the Github provider.
const NameGithub string = "github"
// Github allows authentication via Github OAuth2.
type Github struct {
BaseProvider
}
// NewGithubProvider creates new Github provider instance with some defaults.
func NewGithubProvider() *Github {
return &Github{BaseProvider{
ctx: context.Background(),
displayName: "GitHub",
pkce: true, // technically is not supported yet but it is safe as the PKCE params are just ignored
scopes: []string{"read:user", "user:email"},
authURL: github.Endpoint.AuthURL,
tokenURL: github.Endpoint.TokenURL,
userInfoURL: "https://api.github.com/user",
}}
}
// FetchAuthUser returns an AuthUser instance based the Github's user api.
//
// API reference: https://docs.github.com/en/rest/reference/users#get-the-authenticated-user
func (p *Github) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
data, err := p.FetchRawUserInfo(token)
if err != nil {
return nil, err
}
rawUser := map[string]any{}
if err := json.Unmarshal(data, &rawUser); err != nil {
return nil, err
}
extracted := struct {
Login string `json:"login"`
Name string `json:"name"`
Email string `json:"email"`
AvatarURL string `json:"avatar_url"`
Id int64 `json:"id"`
}{}
if err := json.Unmarshal(data, &extracted); err != nil {
return nil, err
}
user := &AuthUser{
Id: strconv.FormatInt(extracted.Id, 10),
Name: extracted.Name,
Username: extracted.Login,
Email: extracted.Email,
AvatarURL: extracted.AvatarURL,
RawUser: rawUser,
AccessToken: token.AccessToken,
RefreshToken: token.RefreshToken,
}
user.Expiry, _ = types.ParseDateTime(token.Expiry)
// in case user has set "Keep my email address private", send an
// **optional** API request to retrieve the verified primary email
if user.Email == "" {
email, err := p.fetchPrimaryEmail(token)
if err != nil {
return nil, err
}
user.Email = email
}
return user, nil
}
// fetchPrimaryEmail sends an API request to retrieve the verified
// primary email, in case "Keep my email address private" was set.
//
// NB! This method can succeed and still return an empty email.
// Error responses that are result of insufficient scopes permissions are ignored.
//
// API reference: https://docs.github.com/en/rest/users/emails?apiVersion=2022-11-28
func (p *Github) fetchPrimaryEmail(token *oauth2.Token) (string, error) {
client := p.Client(token)
response, err := client.Get(p.userInfoURL + "/emails")
if err != nil {
return "", err
}
defer response.Body.Close()
// ignore common http errors caused by insufficient scope permissions
// (the email field is optional, aka. return the auth user without it)
if response.StatusCode == 401 || response.StatusCode == 403 || response.StatusCode == 404 {
return "", nil
}
content, err := io.ReadAll(response.Body)
if err != nil {
return "", err
}
emails := []struct {
Email string
Verified bool
Primary bool
}{}
if err := json.Unmarshal(content, &emails); err != nil {
return "", err
}
// extract the verified primary email
for _, email := range emails {
if email.Verified && email.Primary {
return email.Email, nil
}
}
return "", nil
}

78
tools/auth/gitlab.go Normal file
View file

@ -0,0 +1,78 @@
package auth
import (
"context"
"encoding/json"
"strconv"
"github.com/pocketbase/pocketbase/tools/types"
"golang.org/x/oauth2"
)
func init() {
Providers[NameGitlab] = wrapFactory(NewGitlabProvider)
}
var _ Provider = (*Gitlab)(nil)
// NameGitlab is the unique name of the Gitlab provider.
const NameGitlab string = "gitlab"
// Gitlab allows authentication via Gitlab OAuth2.
type Gitlab struct {
BaseProvider
}
// NewGitlabProvider creates new Gitlab provider instance with some defaults.
func NewGitlabProvider() *Gitlab {
return &Gitlab{BaseProvider{
ctx: context.Background(),
displayName: "GitLab",
pkce: true,
scopes: []string{"read_user"},
authURL: "https://gitlab.com/oauth/authorize",
tokenURL: "https://gitlab.com/oauth/token",
userInfoURL: "https://gitlab.com/api/v4/user",
}}
}
// FetchAuthUser returns an AuthUser instance based the Gitlab's user api.
//
// API reference: https://docs.gitlab.com/ee/api/users.html#for-admin
func (p *Gitlab) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
data, err := p.FetchRawUserInfo(token)
if err != nil {
return nil, err
}
rawUser := map[string]any{}
if err := json.Unmarshal(data, &rawUser); err != nil {
return nil, err
}
extracted := struct {
Name string `json:"name"`
Username string `json:"username"`
Email string `json:"email"`
AvatarURL string `json:"avatar_url"`
Id int64 `json:"id"`
}{}
if err := json.Unmarshal(data, &extracted); err != nil {
return nil, err
}
user := &AuthUser{
Id: strconv.FormatInt(extracted.Id, 10),
Name: extracted.Name,
Username: extracted.Username,
Email: extracted.Email,
AvatarURL: extracted.AvatarURL,
RawUser: rawUser,
AccessToken: token.AccessToken,
RefreshToken: token.RefreshToken,
}
user.Expiry, _ = types.ParseDateTime(token.Expiry)
return user, nil
}

80
tools/auth/google.go Normal file
View file

@ -0,0 +1,80 @@
package auth
import (
"context"
"encoding/json"
"github.com/pocketbase/pocketbase/tools/types"
"golang.org/x/oauth2"
)
func init() {
Providers[NameGoogle] = wrapFactory(NewGoogleProvider)
}
var _ Provider = (*Google)(nil)
// NameGoogle is the unique name of the Google provider.
const NameGoogle string = "google"
// Google allows authentication via Google OAuth2.
type Google struct {
BaseProvider
}
// NewGoogleProvider creates new Google provider instance with some defaults.
func NewGoogleProvider() *Google {
return &Google{BaseProvider{
ctx: context.Background(),
displayName: "Google",
pkce: true,
scopes: []string{
"https://www.googleapis.com/auth/userinfo.profile",
"https://www.googleapis.com/auth/userinfo.email",
},
authURL: "https://accounts.google.com/o/oauth2/v2/auth",
tokenURL: "https://oauth2.googleapis.com/token",
userInfoURL: "https://www.googleapis.com/oauth2/v3/userinfo",
}}
}
// FetchAuthUser returns an AuthUser instance based the Google's user api.
func (p *Google) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
data, err := p.FetchRawUserInfo(token)
if err != nil {
return nil, err
}
rawUser := map[string]any{}
if err := json.Unmarshal(data, &rawUser); err != nil {
return nil, err
}
extracted := struct {
Id string `json:"sub"`
Name string `json:"name"`
Picture string `json:"picture"`
Email string `json:"email"`
EmailVerified bool `json:"email_verified"`
}{}
if err := json.Unmarshal(data, &extracted); err != nil {
return nil, err
}
user := &AuthUser{
Id: extracted.Id,
Name: extracted.Name,
AvatarURL: extracted.Picture,
RawUser: rawUser,
AccessToken: token.AccessToken,
RefreshToken: token.RefreshToken,
}
user.Expiry, _ = types.ParseDateTime(token.Expiry)
if extracted.EmailVerified {
user.Email = extracted.Email
}
return user, nil
}

82
tools/auth/instagram.go Normal file
View file

@ -0,0 +1,82 @@
package auth
import (
"context"
"encoding/json"
"github.com/pocketbase/pocketbase/tools/types"
"golang.org/x/oauth2"
)
func init() {
Providers[NameInstagram] = wrapFactory(NewInstagramProvider)
}
var _ Provider = (*Instagram)(nil)
// NameInstagram is the unique name of the Instagram provider.
const NameInstagram string = "instagram2" // "2" suffix to avoid conflicts with the old deprecated version
// Instagram allows authentication via Instagram Login OAuth2.
type Instagram struct {
BaseProvider
}
// NewInstagramProvider creates new Instagram provider instance with some defaults.
func NewInstagramProvider() *Instagram {
return &Instagram{BaseProvider{
ctx: context.Background(),
displayName: "Instagram",
pkce: true,
scopes: []string{"instagram_business_basic"},
authURL: "https://www.instagram.com/oauth/authorize",
tokenURL: "https://api.instagram.com/oauth/access_token",
userInfoURL: "https://graph.instagram.com/me?fields=id,username,account_type,user_id,name,profile_picture_url,followers_count,follows_count,media_count",
}}
}
// FetchAuthUser returns an AuthUser instance based on the Instagram Login user api response.
//
// API reference: https://developers.facebook.com/docs/instagram-platform/instagram-api-with-instagram-login/get-started#fields
func (p *Instagram) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
data, err := p.FetchRawUserInfo(token)
if err != nil {
return nil, err
}
rawUser := map[string]any{}
if err := json.Unmarshal(data, &rawUser); err != nil {
return nil, err
}
// include list of granted permissions to RawUser's payload
if _, ok := rawUser["permissions"]; !ok {
if permissions := token.Extra("permissions"); permissions != nil {
rawUser["permissions"] = permissions
}
}
extracted := struct {
Id string `json:"user_id"`
Name string `json:"name"`
Username string `json:"username"`
AvatarURL string `json:"profile_picture_url"`
}{}
if err := json.Unmarshal(data, &extracted); err != nil {
return nil, err
}
user := &AuthUser{
Id: extracted.Id,
Username: extracted.Username,
Name: extracted.Name,
AvatarURL: extracted.AvatarURL,
RawUser: rawUser,
AccessToken: token.AccessToken,
RefreshToken: token.RefreshToken,
}
user.Expiry, _ = types.ParseDateTime(token.Expiry)
return user, nil
}

86
tools/auth/kakao.go Normal file
View file

@ -0,0 +1,86 @@
package auth
import (
"context"
"encoding/json"
"strconv"
"github.com/pocketbase/pocketbase/tools/types"
"golang.org/x/oauth2"
"golang.org/x/oauth2/kakao"
)
func init() {
Providers[NameKakao] = wrapFactory(NewKakaoProvider)
}
var _ Provider = (*Kakao)(nil)
// NameKakao is the unique name of the Kakao provider.
const NameKakao string = "kakao"
// Kakao allows authentication via Kakao OAuth2.
type Kakao struct {
BaseProvider
}
// NewKakaoProvider creates a new Kakao provider instance with some defaults.
func NewKakaoProvider() *Kakao {
return &Kakao{BaseProvider{
ctx: context.Background(),
displayName: "Kakao",
pkce: true,
scopes: []string{"account_email", "profile_nickname", "profile_image"},
authURL: kakao.Endpoint.AuthURL,
tokenURL: kakao.Endpoint.TokenURL,
userInfoURL: "https://kapi.kakao.com/v2/user/me",
}}
}
// FetchAuthUser returns an AuthUser instance based on the Kakao's user api.
//
// API reference: https://developers.kakao.com/docs/latest/en/kakaologin/rest-api#req-user-info-response
func (p *Kakao) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
data, err := p.FetchRawUserInfo(token)
if err != nil {
return nil, err
}
rawUser := map[string]any{}
if err := json.Unmarshal(data, &rawUser); err != nil {
return nil, err
}
extracted := struct {
Profile struct {
Nickname string `json:"nickname"`
ImageURL string `json:"profile_image"`
} `json:"properties"`
KakaoAccount struct {
Email string `json:"email"`
IsEmailVerified bool `json:"is_email_verified"`
IsEmailValid bool `json:"is_email_valid"`
} `json:"kakao_account"`
Id int64 `json:"id"`
}{}
if err := json.Unmarshal(data, &extracted); err != nil {
return nil, err
}
user := &AuthUser{
Id: strconv.FormatInt(extracted.Id, 10),
Username: extracted.Profile.Nickname,
AvatarURL: extracted.Profile.ImageURL,
RawUser: rawUser,
AccessToken: token.AccessToken,
RefreshToken: token.RefreshToken,
}
user.Expiry, _ = types.ParseDateTime(token.Expiry)
if extracted.KakaoAccount.IsEmailValid && extracted.KakaoAccount.IsEmailVerified {
user.Email = extracted.KakaoAccount.Email
}
return user, nil
}

109
tools/auth/linear.go Normal file
View file

@ -0,0 +1,109 @@
package auth
import (
"bytes"
"context"
"encoding/json"
"errors"
"net/http"
"github.com/pocketbase/pocketbase/tools/types"
"golang.org/x/oauth2"
)
func init() {
Providers[NameLinear] = wrapFactory(NewLinearProvider)
}
var _ Provider = (*Linear)(nil)
// NameLinear is the unique name of the Linear provider.
const NameLinear string = "linear"
// Linear allows authentication via Linear OAuth2.
type Linear struct {
BaseProvider
}
// NewLinearProvider creates new Linear provider instance with some defaults.
//
// API reference: https://developers.linear.app/docs/oauth/authentication
func NewLinearProvider() *Linear {
return &Linear{BaseProvider{
ctx: context.Background(),
displayName: "Linear",
pkce: false, // Linear doesn't support PKCE at the moment and returns an error if enabled
scopes: []string{"read"},
authURL: "https://linear.app/oauth/authorize",
tokenURL: "https://api.linear.app/oauth/token",
userInfoURL: "https://api.linear.app/graphql",
}}
}
// FetchAuthUser returns an AuthUser instance based on the Linear's user api.
//
// API reference: https://developers.linear.app/docs/graphql/working-with-the-graphql-api#authentication
func (p *Linear) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
data, err := p.FetchRawUserInfo(token)
if err != nil {
return nil, err
}
rawUser := map[string]any{}
if err := json.Unmarshal(data, &rawUser); err != nil {
return nil, err
}
extracted := struct {
Data struct {
Viewer struct {
Id string `json:"id"`
DisplayName string `json:"displayName"`
Name string `json:"name"`
Email string `json:"email"`
AvatarURL string `json:"avatarUrl"`
Active bool `json:"active"`
} `json:"viewer"`
} `json:"data"`
}{}
if err := json.Unmarshal(data, &extracted); err != nil {
return nil, err
}
if !extracted.Data.Viewer.Active {
return nil, errors.New("the Linear user account is not active")
}
user := &AuthUser{
Id: extracted.Data.Viewer.Id,
Name: extracted.Data.Viewer.Name,
Username: extracted.Data.Viewer.DisplayName,
Email: extracted.Data.Viewer.Email,
AvatarURL: extracted.Data.Viewer.AvatarURL,
RawUser: rawUser,
AccessToken: token.AccessToken,
RefreshToken: token.RefreshToken,
}
user.Expiry, _ = types.ParseDateTime(token.Expiry)
return user, nil
}
// FetchRawUserInfo implements Provider.FetchRawUserInfo interface method.
//
// Linear doesn't have a UserInfo endpoint and information on the user
// is retrieved using their GraphQL API (https://developers.linear.app/docs/graphql/working-with-the-graphql-api#queries-and-mutations)
func (p *Linear) FetchRawUserInfo(token *oauth2.Token) ([]byte, error) {
query := []byte(`{"query": "query Me { viewer { id displayName name email avatarUrl active } }"}`)
bodyReader := bytes.NewReader(query)
req, err := http.NewRequestWithContext(p.ctx, "POST", p.userInfoURL, bodyReader)
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/json")
return p.sendRawUserInfoRequest(req, token)
}

79
tools/auth/livechat.go Normal file
View file

@ -0,0 +1,79 @@
package auth
import (
"context"
"encoding/json"
"github.com/pocketbase/pocketbase/tools/types"
"golang.org/x/oauth2"
)
func init() {
Providers[NameLivechat] = wrapFactory(NewLivechatProvider)
}
var _ Provider = (*Livechat)(nil)
// NameLivechat is the unique name of the Livechat provider.
const NameLivechat = "livechat"
// Livechat allows authentication via Livechat OAuth2.
type Livechat struct {
BaseProvider
}
// NewLivechatProvider creates new Livechat provider instance with some defaults.
func NewLivechatProvider() *Livechat {
return &Livechat{BaseProvider{
ctx: context.Background(),
displayName: "LiveChat",
pkce: true,
scopes: []string{}, // default scopes are specified from the provider dashboard
authURL: "https://accounts.livechat.com/",
tokenURL: "https://accounts.livechat.com/token",
userInfoURL: "https://accounts.livechat.com/v2/accounts/me",
}}
}
// FetchAuthUser returns an AuthUser based on the Livechat accounts API.
//
// API reference: https://developers.livechat.com/docs/authorization
func (p *Livechat) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
data, err := p.FetchRawUserInfo(token)
if err != nil {
return nil, err
}
rawUser := map[string]any{}
if err := json.Unmarshal(data, &rawUser); err != nil {
return nil, err
}
extracted := struct {
Id string `json:"account_id"`
Name string `json:"name"`
Email string `json:"email"`
EmailVerified bool `json:"email_verified"`
AvatarURL string `json:"avatar_url"`
}{}
if err := json.Unmarshal(data, &extracted); err != nil {
return nil, err
}
user := &AuthUser{
Id: extracted.Id,
Name: extracted.Name,
AvatarURL: extracted.AvatarURL,
RawUser: rawUser,
AccessToken: token.AccessToken,
RefreshToken: token.RefreshToken,
}
user.Expiry, _ = types.ParseDateTime(token.Expiry)
if extracted.EmailVerified {
user.Email = extracted.Email
}
return user, nil
}

84
tools/auth/mailcow.go Normal file
View file

@ -0,0 +1,84 @@
package auth
import (
"context"
"encoding/json"
"errors"
"strings"
"github.com/pocketbase/pocketbase/tools/types"
"golang.org/x/oauth2"
)
func init() {
Providers[NameMailcow] = wrapFactory(NewMailcowProvider)
}
var _ Provider = (*Mailcow)(nil)
// NameMailcow is the unique name of the mailcow provider.
const NameMailcow string = "mailcow"
// Mailcow allows authentication via mailcow OAuth2.
type Mailcow struct {
BaseProvider
}
// NewMailcowProvider creates a new mailcow provider instance with some defaults.
func NewMailcowProvider() *Mailcow {
return &Mailcow{BaseProvider{
ctx: context.Background(),
displayName: "mailcow",
pkce: true,
scopes: []string{"profile"},
}}
}
// FetchAuthUser returns an AuthUser instance based on mailcow's user api.
//
// API reference: https://github.com/mailcow/mailcow-dockerized/blob/master/data/web/oauth/profile.php
func (p *Mailcow) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
data, err := p.FetchRawUserInfo(token)
if err != nil {
return nil, err
}
rawUser := map[string]any{}
if err := json.Unmarshal(data, &rawUser); err != nil {
return nil, err
}
extracted := struct {
Id string `json:"id"`
Username string `json:"username"`
Email string `json:"email"`
FullName string `json:"full_name"`
Active int `json:"active"`
}{}
if err := json.Unmarshal(data, &extracted); err != nil {
return nil, err
}
if extracted.Active != 1 {
return nil, errors.New("the mailcow user is not active")
}
user := &AuthUser{
Id: extracted.Id,
Name: extracted.FullName,
Username: extracted.Username,
Email: extracted.Email,
RawUser: rawUser,
AccessToken: token.AccessToken,
RefreshToken: token.RefreshToken,
}
user.Expiry, _ = types.ParseDateTime(token.Expiry)
// mailcow usernames are usually just the email adresses, so we just take the part in front of the @
if strings.Contains(user.Username, "@") {
user.Username = strings.Split(user.Username, "@")[0]
}
return user, nil
}

76
tools/auth/microsoft.go Normal file
View file

@ -0,0 +1,76 @@
package auth
import (
"context"
"encoding/json"
"github.com/pocketbase/pocketbase/tools/types"
"golang.org/x/oauth2"
"golang.org/x/oauth2/microsoft"
)
func init() {
Providers[NameMicrosoft] = wrapFactory(NewMicrosoftProvider)
}
var _ Provider = (*Microsoft)(nil)
// NameMicrosoft is the unique name of the Microsoft provider.
const NameMicrosoft string = "microsoft"
// Microsoft allows authentication via AzureADEndpoint OAuth2.
type Microsoft struct {
BaseProvider
}
// NewMicrosoftProvider creates new Microsoft AD provider instance with some defaults.
func NewMicrosoftProvider() *Microsoft {
endpoints := microsoft.AzureADEndpoint("")
return &Microsoft{BaseProvider{
ctx: context.Background(),
displayName: "Microsoft",
pkce: true,
scopes: []string{"User.Read"},
authURL: endpoints.AuthURL,
tokenURL: endpoints.TokenURL,
userInfoURL: "https://graph.microsoft.com/v1.0/me",
}}
}
// FetchAuthUser returns an AuthUser instance based on the Microsoft's user api.
//
// API reference: https://learn.microsoft.com/en-us/azure/active-directory/develop/userinfo
// Graph explorer: https://developer.microsoft.com/en-us/graph/graph-explorer
func (p *Microsoft) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
data, err := p.FetchRawUserInfo(token)
if err != nil {
return nil, err
}
rawUser := map[string]any{}
if err := json.Unmarshal(data, &rawUser); err != nil {
return nil, err
}
extracted := struct {
Id string `json:"id"`
Name string `json:"displayName"`
Email string `json:"mail"`
}{}
if err := json.Unmarshal(data, &extracted); err != nil {
return nil, err
}
user := &AuthUser{
Id: extracted.Id,
Name: extracted.Name,
Email: extracted.Email,
RawUser: rawUser,
AccessToken: token.AccessToken,
RefreshToken: token.RefreshToken,
}
user.Expiry, _ = types.ParseDateTime(token.Expiry)
return user, nil
}

108
tools/auth/monday.go Normal file
View file

@ -0,0 +1,108 @@
package auth
import (
"bytes"
"context"
"encoding/json"
"errors"
"net/http"
"github.com/pocketbase/pocketbase/tools/types"
"golang.org/x/oauth2"
)
func init() {
Providers[NameMonday] = wrapFactory(NewMondayProvider)
}
var _ Provider = (*Monday)(nil)
// NameMonday is the unique name of the Monday provider.
const NameMonday = "monday"
// Monday is an auth provider for monday.com.
type Monday struct {
BaseProvider
}
// NewMondayProvider creates a new Monday provider instance with some defaults.
func NewMondayProvider() *Monday {
return &Monday{BaseProvider{
ctx: context.Background(),
displayName: "monday.com",
pkce: true,
scopes: []string{"me:read"},
authURL: "https://auth.monday.com/oauth2/authorize",
tokenURL: "https://auth.monday.com/oauth2/token",
userInfoURL: "https://api.monday.com/v2",
}}
}
// FetchAuthUser returns an AuthUser instance based on the Monday's user api.
//
// API reference: https://developer.monday.com/api-reference/reference/me
func (p *Monday) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
data, err := p.FetchRawUserInfo(token)
if err != nil {
return nil, err
}
rawUser := map[string]any{}
if err := json.Unmarshal(data, &rawUser); err != nil {
return nil, err
}
extracted := struct {
Data struct {
Me struct {
Id string `json:"id"`
Enabled bool `json:"enabled"`
Name string `json:"name"`
Email string `json:"email"`
IsVerified bool `json:"is_verified"`
Avatar string `json:"photo_small"`
} `json:"me"`
} `json:"data"`
}{}
if err := json.Unmarshal(data, &extracted); err != nil {
return nil, err
}
if !extracted.Data.Me.Enabled {
return nil, errors.New("the monday.com user account is not enabled")
}
user := &AuthUser{
Id: extracted.Data.Me.Id,
Name: extracted.Data.Me.Name,
AvatarURL: extracted.Data.Me.Avatar,
RawUser: rawUser,
AccessToken: token.AccessToken,
RefreshToken: token.RefreshToken,
}
if extracted.Data.Me.IsVerified {
user.Email = extracted.Data.Me.Email
}
user.Expiry, _ = types.ParseDateTime(token.Expiry)
return user, nil
}
// FetchRawUserInfo implements Provider.FetchRawUserInfo interface.
//
// monday.com doesn't have a UserInfo endpoint and information on the user
// is retrieved using their GraphQL API (https://developer.monday.com/api-reference/reference/me#queries)
func (p *Monday) FetchRawUserInfo(token *oauth2.Token) ([]byte, error) {
query := []byte(`{"query": "query { me { id enabled name email is_verified photo_small }}"}`)
bodyReader := bytes.NewReader(query)
req, err := http.NewRequestWithContext(p.ctx, "POST", p.userInfoURL, bodyReader)
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/json")
return p.sendRawUserInfoRequest(req, token)
}

106
tools/auth/notion.go Normal file
View file

@ -0,0 +1,106 @@
package auth
import (
"context"
"encoding/json"
"net/http"
"github.com/pocketbase/pocketbase/tools/types"
"golang.org/x/oauth2"
)
func init() {
Providers[NameNotion] = wrapFactory(NewNotionProvider)
}
var _ Provider = (*Notion)(nil)
// NameNotion is the unique name of the Notion provider.
const NameNotion string = "notion"
// Notion allows authentication via Notion OAuth2.
type Notion struct {
BaseProvider
}
// NewNotionProvider creates new Notion provider instance with some defaults.
func NewNotionProvider() *Notion {
return &Notion{BaseProvider{
ctx: context.Background(),
displayName: "Notion",
pkce: true,
authURL: "https://api.notion.com/v1/oauth/authorize",
tokenURL: "https://api.notion.com/v1/oauth/token",
userInfoURL: "https://api.notion.com/v1/users/me",
}}
}
// FetchAuthUser returns an AuthUser instance based on the Notion's User api.
// API reference: https://developers.notion.com/reference/get-self
func (p *Notion) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
data, err := p.FetchRawUserInfo(token)
if err != nil {
return nil, err
}
rawUser := map[string]any{}
if err := json.Unmarshal(data, &rawUser); err != nil {
return nil, err
}
extracted := struct {
AvatarURL string `json:"avatar_url"`
Bot struct {
Owner struct {
Type string `json:"type"`
User struct {
AvatarURL string `json:"avatar_url"`
Id string `json:"id"`
Name string `json:"name"`
Person struct {
Email string `json:"email"`
} `json:"person"`
} `json:"user"`
} `json:"owner"`
WorkspaceName string `json:"workspace_name"`
} `json:"bot"`
Id string `json:"id"`
Name string `json:"name"`
Object string `json:"object"`
RequestId string `json:"request_id"`
Type string `json:"type"`
}{}
if err := json.Unmarshal(data, &extracted); err != nil {
return nil, err
}
user := &AuthUser{
Id: extracted.Bot.Owner.User.Id,
Name: extracted.Bot.Owner.User.Name,
Email: extracted.Bot.Owner.User.Person.Email,
AvatarURL: extracted.Bot.Owner.User.AvatarURL,
RawUser: rawUser,
AccessToken: token.AccessToken,
RefreshToken: token.RefreshToken,
}
user.Expiry, _ = types.ParseDateTime(token.Expiry)
return user, nil
}
// FetchRawUserInfo implements Provider.FetchRawUserInfo interface method.
//
// This differ from BaseProvider because Notion requires a version header for all requests
// (https://developers.notion.com/reference/versioning).
func (p *Notion) FetchRawUserInfo(token *oauth2.Token) ([]byte, error) {
req, err := http.NewRequestWithContext(p.ctx, "GET", p.userInfoURL, nil)
if err != nil {
return nil, err
}
req.Header.Set("Notion-Version", "2022-06-28")
return p.sendRawUserInfoRequest(req, token)
}

292
tools/auth/oidc.go Normal file
View file

@ -0,0 +1,292 @@
package auth
import (
"context"
"crypto/rsa"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"io"
"math/big"
"net/http"
"os"
"strconv"
"strings"
"time"
"github.com/golang-jwt/jwt/v5"
"github.com/pocketbase/pocketbase/tools/security"
"github.com/pocketbase/pocketbase/tools/types"
"github.com/spf13/cast"
"golang.org/x/oauth2"
)
// idTokenLeeway is the optional leeway for the id_token timestamp fields validation.
//
// It can be changed externally using the PB_ID_TOKEN_LEEWAY env variable
// (the value must be in seconds, e.g. "PB_ID_TOKEN_LEEWAY=60" for 1 minute).
var idTokenLeeway time.Duration = 5 * time.Minute
func init() {
Providers[NameOIDC] = wrapFactory(NewOIDCProvider)
Providers[NameOIDC+"2"] = wrapFactory(NewOIDCProvider)
Providers[NameOIDC+"3"] = wrapFactory(NewOIDCProvider)
if leewayStr := os.Getenv("PB_ID_TOKEN_LEEWAY"); leewayStr != "" {
leeway, err := strconv.Atoi(leewayStr)
if err == nil {
idTokenLeeway = time.Duration(leeway) * time.Second
}
}
}
var _ Provider = (*OIDC)(nil)
// NameOIDC is the unique name of the OpenID Connect (OIDC) provider.
const NameOIDC string = "oidc"
// OIDC allows authentication via OpenID Connect (OIDC) OAuth2 provider.
//
// If specified the user data is fetched from the userInfoURL.
// Otherwise - from the id_token payload.
//
// The provider support the following Extra config options:
// - "jwksURL" - url to the keys to validate the id_token signature (optional and used only when reading the user data from the id_token)
// - "issuers" - list of valid issuers for the iss id_token claim (optioanl and used only when reading the user data from the id_token)
type OIDC struct {
BaseProvider
}
// NewOIDCProvider creates new OpenID Connect (OIDC) provider instance with some defaults.
func NewOIDCProvider() *OIDC {
return &OIDC{BaseProvider{
ctx: context.Background(),
displayName: "OIDC",
pkce: true,
scopes: []string{
"openid", // minimal requirement to return the id
"email",
"profile",
},
}}
}
// FetchAuthUser returns an AuthUser instance based the provider's user api.
//
// API reference: https://openid.net/specs/openid-connect-core-1_0.html#StandardClaims
func (p *OIDC) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
data, err := p.FetchRawUserInfo(token)
if err != nil {
return nil, err
}
rawUser := map[string]any{}
if err := json.Unmarshal(data, &rawUser); err != nil {
return nil, err
}
extracted := struct {
Id string `json:"sub"`
Name string `json:"name"`
Username string `json:"preferred_username"`
Picture string `json:"picture"`
Email string `json:"email"`
EmailVerified any `json:"email_verified"` // see #6657
}{}
if err := json.Unmarshal(data, &extracted); err != nil {
return nil, err
}
user := &AuthUser{
Id: extracted.Id,
Name: extracted.Name,
Username: extracted.Username,
AvatarURL: extracted.Picture,
RawUser: rawUser,
AccessToken: token.AccessToken,
RefreshToken: token.RefreshToken,
}
user.Expiry, _ = types.ParseDateTime(token.Expiry)
if cast.ToBool(extracted.EmailVerified) {
user.Email = extracted.Email
}
return user, nil
}
// FetchRawUserInfo implements Provider.FetchRawUserInfo interface method.
//
// It either fetch the data from p.userInfoURL, or if not set - returns the id_token claims.
func (p *OIDC) FetchRawUserInfo(token *oauth2.Token) ([]byte, error) {
if p.userInfoURL != "" {
return p.BaseProvider.FetchRawUserInfo(token)
}
claims, err := p.parseIdToken(token)
if err != nil {
return nil, err
}
return json.Marshal(claims)
}
func (p *OIDC) parseIdToken(token *oauth2.Token) (jwt.MapClaims, error) {
idToken := token.Extra("id_token").(string)
if idToken == "" {
return nil, errors.New("empty id_token")
}
claims := jwt.MapClaims{}
t, _, err := jwt.NewParser().ParseUnverified(idToken, claims)
if err != nil {
return nil, err
}
// validate common claims
jwtValidator := jwt.NewValidator(
jwt.WithIssuedAt(),
jwt.WithLeeway(idTokenLeeway),
jwt.WithAudience(p.clientId),
)
err = jwtValidator.Validate(claims)
if err != nil {
return nil, err
}
// validate iss (if "issuers" extra config is set)
issuers := cast.ToStringSlice(p.Extra()["issuers"])
if len(issuers) > 0 {
var isIssValid bool
claimIssuer, _ := claims.GetIssuer()
for _, issuer := range issuers {
if security.Equal(claimIssuer, issuer) {
isIssValid = true
break
}
}
if !isIssValid {
return nil, fmt.Errorf("iss must be one of %v, got %#v", issuers, claims["iss"])
}
}
// validate signature (if "jwksURL" extra config is set)
//
// note: this step could be technically considered optional because we trust
// the token which is a result of direct TLS communication with the provider
// (see also https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation)
jwksURL := cast.ToString(p.Extra()["jwksURL"])
if jwksURL != "" {
kid, _ := t.Header["kid"].(string)
err = validateIdTokenSignature(p.ctx, idToken, jwksURL, kid)
if err != nil {
return nil, err
}
}
return claims, nil
}
func validateIdTokenSignature(ctx context.Context, idToken string, jwksURL string, kid string) error {
// fetch the public key set
// ---
if kid == "" {
return errors.New("missing kid header value")
}
key, err := fetchJWK(ctx, jwksURL, kid)
if err != nil {
return err
}
// decode the key params per RFC 7518 (https://tools.ietf.org/html/rfc7518#section-6.3)
// and construct a valid publicKey from them
// ---
exponent, err := base64.RawURLEncoding.DecodeString(strings.TrimRight(key.E, "="))
if err != nil {
return err
}
modulus, err := base64.RawURLEncoding.DecodeString(strings.TrimRight(key.N, "="))
if err != nil {
return err
}
publicKey := &rsa.PublicKey{
// https://tools.ietf.org/html/rfc7517#appendix-A.1
E: int(big.NewInt(0).SetBytes(exponent).Uint64()),
N: big.NewInt(0).SetBytes(modulus),
}
// verify the signiture
// ---
parser := jwt.NewParser(jwt.WithValidMethods([]string{key.Alg}))
parsedToken, err := parser.Parse(idToken, func(t *jwt.Token) (any, error) {
return publicKey, nil
})
if err != nil {
return err
}
if !parsedToken.Valid {
return errors.New("the parsed id_token is invalid")
}
return nil
}
type jwk struct {
Kty string
Kid string
Use string
Alg string
N string
E string
}
func fetchJWK(ctx context.Context, jwksURL string, kid string) (*jwk, error) {
req, err := http.NewRequestWithContext(ctx, "GET", jwksURL, nil)
if err != nil {
return nil, err
}
res, err := http.DefaultClient.Do(req)
if err != nil {
return nil, err
}
defer res.Body.Close()
rawBody, err := io.ReadAll(res.Body)
if err != nil {
return nil, err
}
// http.Client.Get doesn't treat non 2xx responses as error
if res.StatusCode >= 400 {
return nil, fmt.Errorf(
"failed to verify the provided id_token (%d):\n%s",
res.StatusCode,
string(rawBody),
)
}
jwks := struct {
Keys []*jwk
}{}
if err := json.Unmarshal(rawBody, &jwks); err != nil {
return nil, err
}
for _, key := range jwks.Keys {
if key.Kid == kid {
return key, nil
}
}
return nil, fmt.Errorf("jwk with kid %q was not found", kid)
}

88
tools/auth/patreon.go Normal file
View file

@ -0,0 +1,88 @@
package auth
import (
"context"
"encoding/json"
"github.com/pocketbase/pocketbase/tools/types"
"golang.org/x/oauth2"
"golang.org/x/oauth2/endpoints"
)
func init() {
Providers[NamePatreon] = wrapFactory(NewPatreonProvider)
}
var _ Provider = (*Patreon)(nil)
// NamePatreon is the unique name of the Patreon provider.
const NamePatreon string = "patreon"
// Patreon allows authentication via Patreon OAuth2.
type Patreon struct {
BaseProvider
}
// NewPatreonProvider creates new Patreon provider instance with some defaults.
func NewPatreonProvider() *Patreon {
return &Patreon{BaseProvider{
ctx: context.Background(),
displayName: "Patreon",
pkce: true,
scopes: []string{"identity", "identity[email]"},
authURL: endpoints.Patreon.AuthURL,
tokenURL: endpoints.Patreon.TokenURL,
userInfoURL: "https://www.patreon.com/api/oauth2/v2/identity?fields%5Buser%5D=full_name,email,vanity,image_url,is_email_verified",
}}
}
// FetchAuthUser returns an AuthUser instance based on the Patreons's identity api.
//
// API reference:
// https://docs.patreon.com/#get-api-oauth2-v2-identity
// https://docs.patreon.com/#user-v2
func (p *Patreon) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
data, err := p.FetchRawUserInfo(token)
if err != nil {
return nil, err
}
rawUser := map[string]any{}
if err := json.Unmarshal(data, &rawUser); err != nil {
return nil, err
}
extracted := struct {
Data struct {
Id string `json:"id"`
Attributes struct {
Email string `json:"email"`
Name string `json:"full_name"`
Username string `json:"vanity"`
AvatarURL string `json:"image_url"`
IsEmailVerified bool `json:"is_email_verified"`
} `json:"attributes"`
} `json:"data"`
}{}
if err := json.Unmarshal(data, &extracted); err != nil {
return nil, err
}
user := &AuthUser{
Id: extracted.Data.Id,
Username: extracted.Data.Attributes.Username,
Name: extracted.Data.Attributes.Name,
AvatarURL: extracted.Data.Attributes.AvatarURL,
RawUser: rawUser,
AccessToken: token.AccessToken,
RefreshToken: token.RefreshToken,
}
user.Expiry, _ = types.ParseDateTime(token.Expiry)
if extracted.Data.Attributes.IsEmailVerified {
user.Email = extracted.Data.Attributes.Email
}
return user, nil
}

View file

@ -0,0 +1,85 @@
package auth
import (
"context"
"encoding/json"
"errors"
"github.com/pocketbase/pocketbase/tools/types"
"golang.org/x/oauth2"
)
func init() {
Providers[NamePlanningcenter] = wrapFactory(NewPlanningcenterProvider)
}
var _ Provider = (*Planningcenter)(nil)
// NamePlanningcenter is the unique name of the Planningcenter provider.
const NamePlanningcenter string = "planningcenter"
// Planningcenter allows authentication via Planningcenter OAuth2.
type Planningcenter struct {
BaseProvider
}
// NewPlanningcenterProvider creates a new Planningcenter provider instance with some defaults.
func NewPlanningcenterProvider() *Planningcenter {
return &Planningcenter{BaseProvider{
ctx: context.Background(),
displayName: "Planning Center",
pkce: true,
scopes: []string{"people"},
authURL: "https://api.planningcenteronline.com/oauth/authorize",
tokenURL: "https://api.planningcenteronline.com/oauth/token",
userInfoURL: "https://api.planningcenteronline.com/people/v2/me",
}}
}
// FetchAuthUser returns an AuthUser instance based on the Planningcenter's user api.
//
// API reference: https://developer.planning.center/docs/#/overview/authentication
func (p *Planningcenter) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
data, err := p.FetchRawUserInfo(token)
if err != nil {
return nil, err
}
rawUser := map[string]any{}
if err := json.Unmarshal(data, &rawUser); err != nil {
return nil, err
}
extracted := struct {
Data struct {
Id string `json:"id"`
Attributes struct {
Status string `json:"status"`
Name string `json:"name"`
AvatarURL string `json:"avatar"`
// don't map the email because users can have multiple assigned
// and it's not clear if they are verified
}
}
}{}
if err := json.Unmarshal(data, &extracted); err != nil {
return nil, err
}
if extracted.Data.Attributes.Status != "active" {
return nil, errors.New("the user is not active")
}
user := &AuthUser{
Id: extracted.Data.Id,
Name: extracted.Data.Attributes.Name,
AvatarURL: extracted.Data.Attributes.AvatarURL,
RawUser: rawUser,
AccessToken: token.AccessToken,
RefreshToken: token.RefreshToken,
}
user.Expiry, _ = types.ParseDateTime(token.Expiry)
return user, nil
}

87
tools/auth/spotify.go Normal file
View file

@ -0,0 +1,87 @@
package auth
import (
"context"
"encoding/json"
"github.com/pocketbase/pocketbase/tools/types"
"golang.org/x/oauth2"
"golang.org/x/oauth2/spotify"
)
func init() {
Providers[NameSpotify] = wrapFactory(NewSpotifyProvider)
}
var _ Provider = (*Spotify)(nil)
// NameSpotify is the unique name of the Spotify provider.
const NameSpotify string = "spotify"
// Spotify allows authentication via Spotify OAuth2.
type Spotify struct {
BaseProvider
}
// NewSpotifyProvider creates a new Spotify provider instance with some defaults.
func NewSpotifyProvider() *Spotify {
return &Spotify{BaseProvider{
ctx: context.Background(),
displayName: "Spotify",
pkce: true,
scopes: []string{
"user-read-private",
// currently Spotify doesn't return information whether the email is verified or not
// "user-read-email",
},
authURL: spotify.Endpoint.AuthURL,
tokenURL: spotify.Endpoint.TokenURL,
userInfoURL: "https://api.spotify.com/v1/me",
}}
}
// FetchAuthUser returns an AuthUser instance based on the Spotify's user api.
//
// API reference: https://developer.spotify.com/documentation/web-api/reference/#/operations/get-current-users-profile
func (p *Spotify) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
data, err := p.FetchRawUserInfo(token)
if err != nil {
return nil, err
}
rawUser := map[string]any{}
if err := json.Unmarshal(data, &rawUser); err != nil {
return nil, err
}
extracted := struct {
Id string `json:"id"`
Name string `json:"display_name"`
Images []struct {
URL string `json:"url"`
} `json:"images"`
// don't map the email because per the official docs
// the email field is "unverified" and there is no proof
// that it actually belongs to the user
// Email string `json:"email"`
}{}
if err := json.Unmarshal(data, &extracted); err != nil {
return nil, err
}
user := &AuthUser{
Id: extracted.Id,
Name: extracted.Name,
RawUser: rawUser,
AccessToken: token.AccessToken,
RefreshToken: token.RefreshToken,
}
user.Expiry, _ = types.ParseDateTime(token.Expiry)
if len(extracted.Images) > 0 {
user.AvatarURL = extracted.Images[0].URL
}
return user, nil
}

85
tools/auth/strava.go Normal file
View file

@ -0,0 +1,85 @@
package auth
import (
"context"
"encoding/json"
"strconv"
"github.com/pocketbase/pocketbase/tools/types"
"golang.org/x/oauth2"
)
func init() {
Providers[NameStrava] = wrapFactory(NewStravaProvider)
}
var _ Provider = (*Strava)(nil)
// NameStrava is the unique name of the Strava provider.
const NameStrava string = "strava"
// Strava allows authentication via Strava OAuth2.
type Strava struct {
BaseProvider
}
// NewStravaProvider creates new Strava provider instance with some defaults.
func NewStravaProvider() *Strava {
return &Strava{BaseProvider{
ctx: context.Background(),
displayName: "Strava",
pkce: true,
scopes: []string{
"profile:read_all",
},
authURL: "https://www.strava.com/oauth/authorize",
tokenURL: "https://www.strava.com/api/v3/oauth/token",
userInfoURL: "https://www.strava.com/api/v3/athlete",
}}
}
// FetchAuthUser returns an AuthUser instance based on the Strava's user api.
//
// API reference: https://developers.strava.com/docs/authentication/
func (p *Strava) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
data, err := p.FetchRawUserInfo(token)
if err != nil {
return nil, err
}
rawUser := map[string]any{}
if err := json.Unmarshal(data, &rawUser); err != nil {
return nil, err
}
extracted := struct {
Id int64 `json:"id"`
FirstName string `json:"firstname"`
LastName string `json:"lastname"`
Username string `json:"username"`
ProfileImageURL string `json:"profile"`
// At the time of writing, Strava OAuth2 doesn't support returning the user email address
// Email string `json:"email"`
}{}
if err := json.Unmarshal(data, &extracted); err != nil {
return nil, err
}
user := &AuthUser{
Name: extracted.FirstName + " " + extracted.LastName,
Username: extracted.Username,
AvatarURL: extracted.ProfileImageURL,
RawUser: rawUser,
AccessToken: token.AccessToken,
RefreshToken: token.RefreshToken,
}
user.Expiry, _ = types.ParseDateTime(token.Expiry)
if extracted.Id != 0 {
user.Id = strconv.FormatInt(extracted.Id, 10)
}
return user, nil
}

102
tools/auth/trakt.go Normal file
View file

@ -0,0 +1,102 @@
package auth
import (
"context"
"encoding/json"
"net/http"
"github.com/pocketbase/pocketbase/tools/types"
"golang.org/x/oauth2"
)
func init() {
Providers[NameTrakt] = wrapFactory(NewTraktProvider)
}
var _ Provider = (*Trakt)(nil)
// NameTrakt is the unique name of the Trakt provider.
const NameTrakt string = "trakt"
// Trakt allows authentication via Trakt OAuth2.
type Trakt struct {
BaseProvider
}
// NewTraktProvider creates new Trakt provider instance with some defaults.
func NewTraktProvider() *Trakt {
return &Trakt{BaseProvider{
ctx: context.Background(),
displayName: "Trakt",
pkce: true,
authURL: "https://trakt.tv/oauth/authorize",
tokenURL: "https://api.trakt.tv/oauth/token",
userInfoURL: "https://api.trakt.tv/users/settings",
}}
}
// FetchAuthUser returns an AuthUser instance based on Trakt's user settings API.
// API reference: https://trakt.docs.apiary.io/#reference/users/settings/retrieve-settings
func (p *Trakt) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
data, err := p.FetchRawUserInfo(token)
if err != nil {
return nil, err
}
rawUser := map[string]any{}
if err := json.Unmarshal(data, &rawUser); err != nil {
return nil, err
}
extracted := struct {
User struct {
Username string `json:"username"`
Name string `json:"name"`
Ids struct {
Slug string `json:"slug"`
UUID string `json:"uuid"`
} `json:"ids"`
Images struct {
Avatar struct {
Full string `json:"full"`
} `json:"avatar"`
} `json:"images"`
} `json:"user"`
}{}
if err := json.Unmarshal(data, &extracted); err != nil {
return nil, err
}
user := &AuthUser{
Id: extracted.User.Ids.UUID,
Username: extracted.User.Username,
Name: extracted.User.Name,
AvatarURL: extracted.User.Images.Avatar.Full,
RawUser: rawUser,
AccessToken: token.AccessToken,
RefreshToken: token.RefreshToken,
}
user.Expiry, _ = types.ParseDateTime(token.Expiry)
return user, nil
}
// FetchRawUserInfo implements Provider.FetchRawUserInfo interface method.
//
// This differ from BaseProvider because Trakt requires a number of
// mandatory headers for all requests
// (https://trakt.docs.apiary.io/#introduction/required-headers).
func (p *Trakt) FetchRawUserInfo(token *oauth2.Token) ([]byte, error) {
req, err := http.NewRequestWithContext(p.ctx, "GET", p.userInfoURL, nil)
if err != nil {
return nil, err
}
req.Header.Set("Content-type", "application/json")
req.Header.Set("trakt-api-key", p.clientId)
req.Header.Set("trakt-api-version", "2")
return p.sendRawUserInfoRequest(req, token)
}

100
tools/auth/twitch.go Normal file
View file

@ -0,0 +1,100 @@
package auth
import (
"context"
"encoding/json"
"errors"
"net/http"
"github.com/pocketbase/pocketbase/tools/types"
"golang.org/x/oauth2"
"golang.org/x/oauth2/twitch"
)
func init() {
Providers[NameTwitch] = wrapFactory(NewTwitchProvider)
}
var _ Provider = (*Twitch)(nil)
// NameTwitch is the unique name of the Twitch provider.
const NameTwitch string = "twitch"
// Twitch allows authentication via Twitch OAuth2.
type Twitch struct {
BaseProvider
}
// NewTwitchProvider creates new Twitch provider instance with some defaults.
func NewTwitchProvider() *Twitch {
return &Twitch{BaseProvider{
ctx: context.Background(),
displayName: "Twitch",
pkce: true,
scopes: []string{"user:read:email"},
authURL: twitch.Endpoint.AuthURL,
tokenURL: twitch.Endpoint.TokenURL,
userInfoURL: "https://api.twitch.tv/helix/users",
}}
}
// FetchAuthUser returns an AuthUser instance based the Twitch's user api.
//
// API reference: https://dev.twitch.tv/docs/api/reference#get-users
func (p *Twitch) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
data, err := p.FetchRawUserInfo(token)
if err != nil {
return nil, err
}
rawUser := map[string]any{}
if err := json.Unmarshal(data, &rawUser); err != nil {
return nil, err
}
extracted := struct {
Data []struct {
Id string `json:"id"`
Login string `json:"login"`
DisplayName string `json:"display_name"`
Email string `json:"email"`
ProfileImageURL string `json:"profile_image_url"`
} `json:"data"`
}{}
if err := json.Unmarshal(data, &extracted); err != nil {
return nil, err
}
if len(extracted.Data) == 0 {
return nil, errors.New("failed to fetch AuthUser data")
}
user := &AuthUser{
Id: extracted.Data[0].Id,
Name: extracted.Data[0].DisplayName,
Username: extracted.Data[0].Login,
Email: extracted.Data[0].Email,
AvatarURL: extracted.Data[0].ProfileImageURL,
RawUser: rawUser,
AccessToken: token.AccessToken,
RefreshToken: token.RefreshToken,
}
user.Expiry, _ = types.ParseDateTime(token.Expiry)
return user, nil
}
// FetchRawUserInfo implements Provider.FetchRawUserInfo interface method.
//
// This differ from BaseProvider because Twitch requires the Client-Id header.
func (p *Twitch) FetchRawUserInfo(token *oauth2.Token) ([]byte, error) {
req, err := http.NewRequestWithContext(p.ctx, "GET", p.userInfoURL, nil)
if err != nil {
return nil, err
}
req.Header.Set("Client-Id", p.clientId)
return p.sendRawUserInfoRequest(req, token)
}

87
tools/auth/twitter.go Normal file
View file

@ -0,0 +1,87 @@
package auth
import (
"context"
"encoding/json"
"github.com/pocketbase/pocketbase/tools/types"
"golang.org/x/oauth2"
)
func init() {
Providers[NameTwitter] = wrapFactory(NewTwitterProvider)
}
var _ Provider = (*Twitter)(nil)
// NameTwitter is the unique name of the Twitter provider.
const NameTwitter string = "twitter"
// Twitter allows authentication via Twitter OAuth2.
type Twitter struct {
BaseProvider
}
// NewTwitterProvider creates new Twitter provider instance with some defaults.
func NewTwitterProvider() *Twitter {
return &Twitter{BaseProvider{
ctx: context.Background(),
displayName: "Twitter",
pkce: true,
scopes: []string{
"users.read",
// we don't actually use this scope, but for some reason it is required by the `/2/users/me` endpoint
// (see https://developer.twitter.com/en/docs/twitter-api/users/lookup/api-reference/get-users-me)
"tweet.read",
},
authURL: "https://twitter.com/i/oauth2/authorize",
tokenURL: "https://api.twitter.com/2/oauth2/token",
userInfoURL: "https://api.twitter.com/2/users/me?user.fields=id,name,username,profile_image_url",
}}
}
// FetchAuthUser returns an AuthUser instance based on the Twitter's user api.
//
// API reference: https://developer.twitter.com/en/docs/twitter-api/users/lookup/api-reference/get-users-me
func (p *Twitter) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
data, err := p.FetchRawUserInfo(token)
if err != nil {
return nil, err
}
rawUser := map[string]any{}
if err := json.Unmarshal(data, &rawUser); err != nil {
return nil, err
}
extracted := struct {
Data struct {
Id string `json:"id"`
Name string `json:"name"`
Username string `json:"username"`
ProfileImageURL string `json:"profile_image_url"`
// NB! At the time of writing, Twitter OAuth2 doesn't support returning the user email address
// (see https://twittercommunity.com/t/which-api-to-get-user-after-oauth2-authorization/162417/33)
// Email string `json:"email"`
} `json:"data"`
}{}
if err := json.Unmarshal(data, &extracted); err != nil {
return nil, err
}
user := &AuthUser{
Id: extracted.Data.Id,
Name: extracted.Data.Name,
Username: extracted.Data.Username,
AvatarURL: extracted.Data.ProfileImageURL,
RawUser: rawUser,
AccessToken: token.AccessToken,
RefreshToken: token.RefreshToken,
}
user.Expiry, _ = types.ParseDateTime(token.Expiry)
return user, nil
}

94
tools/auth/vk.go Normal file
View file

@ -0,0 +1,94 @@
package auth
import (
"context"
"encoding/json"
"errors"
"fmt"
"strconv"
"strings"
"github.com/pocketbase/pocketbase/tools/types"
"golang.org/x/oauth2"
"golang.org/x/oauth2/vk"
)
func init() {
Providers[NameVK] = wrapFactory(NewVKProvider)
}
var _ Provider = (*VK)(nil)
// NameVK is the unique name of the VK provider.
const NameVK string = "vk"
// VK allows authentication via VK OAuth2.
type VK struct {
BaseProvider
}
// NewVKProvider creates new VK provider instance with some defaults.
//
// Docs: https://dev.vk.com/api/oauth-parameters
func NewVKProvider() *VK {
return &VK{BaseProvider{
ctx: context.Background(),
displayName: "ВКонтакте",
pkce: false, // VK currently doesn't support PKCE and throws an error if PKCE params are send
scopes: []string{"email"},
authURL: vk.Endpoint.AuthURL,
tokenURL: vk.Endpoint.TokenURL,
userInfoURL: "https://api.vk.com/method/users.get?fields=photo_max,screen_name&v=5.131",
}}
}
// FetchAuthUser returns an AuthUser instance based on VK's user api.
//
// API reference: https://dev.vk.com/method/users.get
func (p *VK) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
data, err := p.FetchRawUserInfo(token)
if err != nil {
return nil, err
}
rawUser := map[string]any{}
if err := json.Unmarshal(data, &rawUser); err != nil {
return nil, err
}
extracted := struct {
Response []struct {
Id int64 `json:"id"`
FirstName string `json:"first_name"`
LastName string `json:"last_name"`
Username string `json:"screen_name"`
AvatarURL string `json:"photo_max"`
} `json:"response"`
}{}
if err := json.Unmarshal(data, &extracted); err != nil {
return nil, err
}
if len(extracted.Response) == 0 {
return nil, errors.New("missing response entry")
}
user := &AuthUser{
Id: strconv.FormatInt(extracted.Response[0].Id, 10),
Name: strings.TrimSpace(extracted.Response[0].FirstName + " " + extracted.Response[0].LastName),
Username: extracted.Response[0].Username,
AvatarURL: extracted.Response[0].AvatarURL,
RawUser: rawUser,
AccessToken: token.AccessToken,
RefreshToken: token.RefreshToken,
}
user.Expiry, _ = types.ParseDateTime(token.Expiry)
if email := token.Extra("email"); email != nil {
user.Email = fmt.Sprint(email)
}
return user, nil
}

90
tools/auth/wakatime.go Normal file
View file

@ -0,0 +1,90 @@
package auth
import (
"context"
"encoding/json"
"github.com/pocketbase/pocketbase/tools/types"
"golang.org/x/oauth2"
)
func init() {
Providers[NameWakatime] = wrapFactory(NewWakatimeProvider)
}
var _ Provider = (*Wakatime)(nil)
// NameWakatime is the unique name of the Wakatime provider.
const NameWakatime = "wakatime"
// Wakatime is an auth provider for Wakatime.
type Wakatime struct {
BaseProvider
}
// NewWakatimeProvider creates a new Wakatime provider instance with some defaults.
func NewWakatimeProvider() *Wakatime {
return &Wakatime{BaseProvider{
ctx: context.Background(),
displayName: "WakaTime",
pkce: true,
scopes: []string{"email"},
authURL: "https://wakatime.com/oauth/authorize",
tokenURL: "https://wakatime.com/oauth/token",
userInfoURL: "https://wakatime.com/api/v1/users/current",
}}
}
// FetchAuthUser returns an AuthUser instance based on the Wakatime's user API.
//
// API reference: https://wakatime.com/developers#users
func (p *Wakatime) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
data, err := p.FetchRawUserInfo(token)
if err != nil {
return nil, err
}
rawUser := map[string]any{}
if err := json.Unmarshal(data, &rawUser); err != nil {
return nil, err
}
extracted := struct {
Data struct {
Id string `json:"id"`
DisplayName string `json:"display_name"`
Username string `json:"username"`
Email string `json:"email"`
Photo string `json:"photo"`
IsPhotoPublic bool `json:"photo_public"`
IsEmailConfirmed bool `json:"is_email_confirmed"`
} `json:"data"`
}{}
if err := json.Unmarshal(data, &extracted); err != nil {
return nil, err
}
user := &AuthUser{
Id: extracted.Data.Id,
Name: extracted.Data.DisplayName,
Username: extracted.Data.Username,
RawUser: rawUser,
AccessToken: token.AccessToken,
RefreshToken: token.RefreshToken,
}
// note: we don't check for is_email_public field because PocketBase
// has its own emailVisibility flag which is false by default
if extracted.Data.IsEmailConfirmed {
user.Email = extracted.Data.Email
}
if extracted.Data.IsPhotoPublic {
user.AvatarURL = extracted.Data.Photo
}
user.Expiry, _ = types.ParseDateTime(token.Expiry)
return user, nil
}

84
tools/auth/yandex.go Normal file
View file

@ -0,0 +1,84 @@
package auth
import (
"context"
"encoding/json"
"github.com/pocketbase/pocketbase/tools/types"
"golang.org/x/oauth2"
"golang.org/x/oauth2/yandex"
)
func init() {
Providers[NameYandex] = wrapFactory(NewYandexProvider)
}
var _ Provider = (*Yandex)(nil)
// NameYandex is the unique name of the Yandex provider.
const NameYandex string = "yandex"
// Yandex allows authentication via Yandex OAuth2.
type Yandex struct {
BaseProvider
}
// NewYandexProvider creates new Yandex provider instance with some defaults.
//
// Docs: https://yandex.ru/dev/id/doc/en/
func NewYandexProvider() *Yandex {
return &Yandex{BaseProvider{
ctx: context.Background(),
displayName: "Yandex",
pkce: true,
scopes: []string{"login:email", "login:avatar", "login:info"},
authURL: yandex.Endpoint.AuthURL,
tokenURL: yandex.Endpoint.TokenURL,
userInfoURL: "https://login.yandex.ru/info",
}}
}
// FetchAuthUser returns an AuthUser instance based on Yandex's user api.
//
// API reference: https://yandex.ru/dev/id/doc/en/user-information#response-format
func (p *Yandex) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
data, err := p.FetchRawUserInfo(token)
if err != nil {
return nil, err
}
rawUser := map[string]any{}
if err := json.Unmarshal(data, &rawUser); err != nil {
return nil, err
}
extracted := struct {
Id string `json:"id"`
Name string `json:"real_name"`
Username string `json:"login"`
Email string `json:"default_email"`
IsAvatarEmpty bool `json:"is_avatar_empty"`
AvatarId string `json:"default_avatar_id"`
}{}
if err := json.Unmarshal(data, &extracted); err != nil {
return nil, err
}
user := &AuthUser{
Id: extracted.Id,
Name: extracted.Name,
Username: extracted.Username,
Email: extracted.Email,
RawUser: rawUser,
AccessToken: token.AccessToken,
RefreshToken: token.RefreshToken,
}
user.Expiry, _ = types.ParseDateTime(token.Expiry)
if !extracted.IsAvatarEmpty {
user.AvatarURL = "https://avatars.yandex.net/get-yapic/" + extracted.AvatarId + "/islands-200"
}
return user, nil
}

228
tools/cron/cron.go Normal file
View file

@ -0,0 +1,228 @@
// Package cron implements a crontab-like service to execute and schedule
// repeative tasks/jobs.
//
// Example:
//
// c := cron.New()
// c.MustAdd("dailyReport", "0 0 * * *", func() { ... })
// c.Start()
package cron
import (
"errors"
"fmt"
"slices"
"sync"
"time"
)
// Cron is a crontab-like struct for tasks/jobs scheduling.
type Cron struct {
timezone *time.Location
ticker *time.Ticker
startTimer *time.Timer
tickerDone chan bool
jobs []*Job
interval time.Duration
mux sync.RWMutex
}
// New create a new Cron struct with default tick interval of 1 minute
// and timezone in UTC.
//
// You can change the default tick interval with Cron.SetInterval().
// You can change the default timezone with Cron.SetTimezone().
func New() *Cron {
return &Cron{
interval: 1 * time.Minute,
timezone: time.UTC,
jobs: []*Job{},
tickerDone: make(chan bool),
}
}
// SetInterval changes the current cron tick interval
// (it usually should be >= 1 minute).
func (c *Cron) SetInterval(d time.Duration) {
// update interval
c.mux.Lock()
wasStarted := c.ticker != nil
c.interval = d
c.mux.Unlock()
// restart the ticker
if wasStarted {
c.Start()
}
}
// SetTimezone changes the current cron tick timezone.
func (c *Cron) SetTimezone(l *time.Location) {
c.mux.Lock()
defer c.mux.Unlock()
c.timezone = l
}
// MustAdd is similar to Add() but panic on failure.
func (c *Cron) MustAdd(jobId string, cronExpr string, run func()) {
if err := c.Add(jobId, cronExpr, run); err != nil {
panic(err)
}
}
// Add registers a single cron job.
//
// If there is already a job with the provided id, then the old job
// will be replaced with the new one.
//
// cronExpr is a regular cron expression, eg. "0 */3 * * *" (aka. at minute 0 past every 3rd hour).
// Check cron.NewSchedule() for the supported tokens.
func (c *Cron) Add(jobId string, cronExpr string, fn func()) error {
if fn == nil {
return errors.New("failed to add new cron job: fn must be non-nil function")
}
schedule, err := NewSchedule(cronExpr)
if err != nil {
return fmt.Errorf("failed to add new cron job: %w", err)
}
c.mux.Lock()
defer c.mux.Unlock()
// remove previous (if any)
c.jobs = slices.DeleteFunc(c.jobs, func(j *Job) bool {
return j.Id() == jobId
})
// add new
c.jobs = append(c.jobs, &Job{
id: jobId,
fn: fn,
schedule: schedule,
})
return nil
}
// Remove removes a single cron job by its id.
func (c *Cron) Remove(jobId string) {
c.mux.Lock()
defer c.mux.Unlock()
if c.jobs == nil {
return // nothing to remove
}
c.jobs = slices.DeleteFunc(c.jobs, func(j *Job) bool {
return j.Id() == jobId
})
}
// RemoveAll removes all registered cron jobs.
func (c *Cron) RemoveAll() {
c.mux.Lock()
defer c.mux.Unlock()
c.jobs = []*Job{}
}
// Total returns the current total number of registered cron jobs.
func (c *Cron) Total() int {
c.mux.RLock()
defer c.mux.RUnlock()
return len(c.jobs)
}
// Jobs returns a shallow copy of the currently registered cron jobs.
func (c *Cron) Jobs() []*Job {
c.mux.RLock()
defer c.mux.RUnlock()
copy := make([]*Job, len(c.jobs))
for i, j := range c.jobs {
copy[i] = j
}
return copy
}
// Stop stops the current cron ticker (if not already).
//
// You can resume the ticker by calling Start().
func (c *Cron) Stop() {
c.mux.Lock()
defer c.mux.Unlock()
if c.startTimer != nil {
c.startTimer.Stop()
c.startTimer = nil
}
if c.ticker == nil {
return // already stopped
}
c.tickerDone <- true
c.ticker.Stop()
c.ticker = nil
}
// Start starts the cron ticker.
//
// Calling Start() on already started cron will restart the ticker.
func (c *Cron) Start() {
c.Stop()
// delay the ticker to start at 00 of 1 c.interval duration
now := time.Now()
next := now.Add(c.interval).Truncate(c.interval)
delay := next.Sub(now)
c.mux.Lock()
c.startTimer = time.AfterFunc(delay, func() {
c.mux.Lock()
c.ticker = time.NewTicker(c.interval)
c.mux.Unlock()
// run immediately at 00
c.runDue(time.Now())
// run after each tick
go func() {
for {
select {
case <-c.tickerDone:
return
case t := <-c.ticker.C:
c.runDue(t)
}
}
}()
})
c.mux.Unlock()
}
// HasStarted checks whether the current Cron ticker has been started.
func (c *Cron) HasStarted() bool {
c.mux.RLock()
defer c.mux.RUnlock()
return c.ticker != nil
}
// runDue runs all registered jobs that are scheduled for the provided time.
func (c *Cron) runDue(t time.Time) {
c.mux.RLock()
defer c.mux.RUnlock()
moment := NewMoment(t.In(c.timezone))
for _, j := range c.jobs {
if j.schedule.IsDue(moment) {
go j.Run()
}
}
}

305
tools/cron/cron_test.go Normal file
View file

@ -0,0 +1,305 @@
package cron
import (
"encoding/json"
"slices"
"testing"
"time"
)
func TestCronNew(t *testing.T) {
t.Parallel()
c := New()
expectedInterval := 1 * time.Minute
if c.interval != expectedInterval {
t.Fatalf("Expected default interval %v, got %v", expectedInterval, c.interval)
}
expectedTimezone := time.UTC
if c.timezone.String() != expectedTimezone.String() {
t.Fatalf("Expected default timezone %v, got %v", expectedTimezone, c.timezone)
}
if len(c.jobs) != 0 {
t.Fatalf("Expected no jobs by default, got \n%v", c.jobs)
}
if c.ticker != nil {
t.Fatal("Expected the ticker NOT to be initialized")
}
}
func TestCronSetInterval(t *testing.T) {
t.Parallel()
c := New()
interval := 2 * time.Minute
c.SetInterval(interval)
if c.interval != interval {
t.Fatalf("Expected interval %v, got %v", interval, c.interval)
}
}
func TestCronSetTimezone(t *testing.T) {
t.Parallel()
c := New()
timezone, _ := time.LoadLocation("Asia/Tokyo")
c.SetTimezone(timezone)
if c.timezone.String() != timezone.String() {
t.Fatalf("Expected timezone %v, got %v", timezone, c.timezone)
}
}
func TestCronAddAndRemove(t *testing.T) {
t.Parallel()
c := New()
if err := c.Add("test0", "* * * * *", nil); err == nil {
t.Fatal("Expected nil function error")
}
if err := c.Add("test1", "invalid", func() {}); err == nil {
t.Fatal("Expected invalid cron expression error")
}
if err := c.Add("test2", "* * * * *", func() {}); err != nil {
t.Fatal(err)
}
if err := c.Add("test3", "* * * * *", func() {}); err != nil {
t.Fatal(err)
}
if err := c.Add("test4", "* * * * *", func() {}); err != nil {
t.Fatal(err)
}
// overwrite test2
if err := c.Add("test2", "1 2 3 4 5", func() {}); err != nil {
t.Fatal(err)
}
if err := c.Add("test5", "1 2 3 4 5", func() {}); err != nil {
t.Fatal(err)
}
// mock job deletion
c.Remove("test4")
// try to remove non-existing (should be no-op)
c.Remove("missing")
indexedJobs := make(map[string]*Job, len(c.jobs))
for _, j := range c.jobs {
indexedJobs[j.Id()] = j
}
// check job keys
{
expectedKeys := []string{"test3", "test2", "test5"}
if v := len(c.jobs); v != len(expectedKeys) {
t.Fatalf("Expected %d jobs, got %d", len(expectedKeys), v)
}
for _, k := range expectedKeys {
if indexedJobs[k] == nil {
t.Fatalf("Expected job with key %s, got nil", k)
}
}
}
// check the jobs schedule
{
expectedSchedules := map[string]string{
"test2": `{"minutes":{"1":{}},"hours":{"2":{}},"days":{"3":{}},"months":{"4":{}},"daysOfWeek":{"5":{}}}`,
"test3": `{"minutes":{"0":{},"1":{},"10":{},"11":{},"12":{},"13":{},"14":{},"15":{},"16":{},"17":{},"18":{},"19":{},"2":{},"20":{},"21":{},"22":{},"23":{},"24":{},"25":{},"26":{},"27":{},"28":{},"29":{},"3":{},"30":{},"31":{},"32":{},"33":{},"34":{},"35":{},"36":{},"37":{},"38":{},"39":{},"4":{},"40":{},"41":{},"42":{},"43":{},"44":{},"45":{},"46":{},"47":{},"48":{},"49":{},"5":{},"50":{},"51":{},"52":{},"53":{},"54":{},"55":{},"56":{},"57":{},"58":{},"59":{},"6":{},"7":{},"8":{},"9":{}},"hours":{"0":{},"1":{},"10":{},"11":{},"12":{},"13":{},"14":{},"15":{},"16":{},"17":{},"18":{},"19":{},"2":{},"20":{},"21":{},"22":{},"23":{},"3":{},"4":{},"5":{},"6":{},"7":{},"8":{},"9":{}},"days":{"1":{},"10":{},"11":{},"12":{},"13":{},"14":{},"15":{},"16":{},"17":{},"18":{},"19":{},"2":{},"20":{},"21":{},"22":{},"23":{},"24":{},"25":{},"26":{},"27":{},"28":{},"29":{},"3":{},"30":{},"31":{},"4":{},"5":{},"6":{},"7":{},"8":{},"9":{}},"months":{"1":{},"10":{},"11":{},"12":{},"2":{},"3":{},"4":{},"5":{},"6":{},"7":{},"8":{},"9":{}},"daysOfWeek":{"0":{},"1":{},"2":{},"3":{},"4":{},"5":{},"6":{}}}`,
"test5": `{"minutes":{"1":{}},"hours":{"2":{}},"days":{"3":{}},"months":{"4":{}},"daysOfWeek":{"5":{}}}`,
}
for k, v := range expectedSchedules {
raw, err := json.Marshal(indexedJobs[k].schedule)
if err != nil {
t.Fatal(err)
}
if string(raw) != v {
t.Fatalf("Expected %q schedule \n%s, \ngot \n%s", k, v, raw)
}
}
}
}
func TestCronMustAdd(t *testing.T) {
t.Parallel()
c := New()
defer func() {
if r := recover(); r == nil {
t.Errorf("test1 didn't panic")
}
}()
c.MustAdd("test1", "* * * * *", nil)
c.MustAdd("test2", "* * * * *", func() {})
if !slices.ContainsFunc(c.jobs, func(j *Job) bool { return j.Id() == "test2" }) {
t.Fatal("Couldn't find job test2")
}
}
func TestCronRemoveAll(t *testing.T) {
t.Parallel()
c := New()
if err := c.Add("test1", "* * * * *", func() {}); err != nil {
t.Fatal(err)
}
if err := c.Add("test2", "* * * * *", func() {}); err != nil {
t.Fatal(err)
}
if err := c.Add("test3", "* * * * *", func() {}); err != nil {
t.Fatal(err)
}
if v := len(c.jobs); v != 3 {
t.Fatalf("Expected %d jobs, got %d", 3, v)
}
c.RemoveAll()
if v := len(c.jobs); v != 0 {
t.Fatalf("Expected %d jobs, got %d", 0, v)
}
}
func TestCronTotal(t *testing.T) {
t.Parallel()
c := New()
if v := c.Total(); v != 0 {
t.Fatalf("Expected 0 jobs, got %v", v)
}
if err := c.Add("test1", "* * * * *", func() {}); err != nil {
t.Fatal(err)
}
if err := c.Add("test2", "* * * * *", func() {}); err != nil {
t.Fatal(err)
}
// overwrite
if err := c.Add("test1", "* * * * *", func() {}); err != nil {
t.Fatal(err)
}
if v := c.Total(); v != 2 {
t.Fatalf("Expected 2 jobs, got %v", v)
}
}
func TestCronJobs(t *testing.T) {
t.Parallel()
c := New()
calls := ""
if err := c.Add("a", "1 * * * *", func() { calls += "a" }); err != nil {
t.Fatal(err)
}
if err := c.Add("b", "2 * * * *", func() { calls += "b" }); err != nil {
t.Fatal(err)
}
// overwrite
if err := c.Add("b", "3 * * * *", func() { calls += "b" }); err != nil {
t.Fatal(err)
}
jobs := c.Jobs()
if len(jobs) != 2 {
t.Fatalf("Expected 2 jobs, got %v", len(jobs))
}
for _, j := range jobs {
j.Run()
}
expectedCalls := "ab"
if calls != expectedCalls {
t.Fatalf("Expected %q calls, got %q", expectedCalls, calls)
}
}
func TestCronStartStop(t *testing.T) {
t.Parallel()
test1 := 0
test2 := 0
c := New()
c.SetInterval(500 * time.Millisecond)
c.Add("test1", "* * * * *", func() {
test1++
})
c.Add("test2", "* * * * *", func() {
test2++
})
expectedCalls := 2
// call twice Start to check if the previous ticker will be reseted
c.Start()
c.Start()
time.Sleep(1 * time.Second)
// call twice Stop to ensure that the second stop is no-op
c.Stop()
c.Stop()
if test1 != expectedCalls {
t.Fatalf("Expected %d test1, got %d", expectedCalls, test1)
}
if test2 != expectedCalls {
t.Fatalf("Expected %d test2, got %d", expectedCalls, test2)
}
// resume for 2 seconds
c.Start()
time.Sleep(2 * time.Second)
c.Stop()
expectedCalls += 4
if test1 != expectedCalls {
t.Fatalf("Expected %d test1, got %d", expectedCalls, test1)
}
if test2 != expectedCalls {
t.Fatalf("Expected %d test2, got %d", expectedCalls, test2)
}
}

41
tools/cron/job.go Normal file
View file

@ -0,0 +1,41 @@
package cron
import "encoding/json"
// Job defines a single registered cron job.
type Job struct {
fn func()
schedule *Schedule
id string
}
// Id returns the cron job id.
func (j *Job) Id() string {
return j.id
}
// Expression returns the plain cron job schedule expression.
func (j *Job) Expression() string {
return j.schedule.rawExpr
}
// Run runs the cron job function.
func (j *Job) Run() {
if j.fn != nil {
j.fn()
}
}
// MarshalJSON implements [json.Marshaler] and export the current
// jobs data into valid JSON.
func (j Job) MarshalJSON() ([]byte, error) {
plain := struct {
Id string `json:"id"`
Expression string `json:"expression"`
}{
Id: j.Id(),
Expression: j.Expression(),
}
return json.Marshal(plain)
}

71
tools/cron/job_test.go Normal file
View file

@ -0,0 +1,71 @@
package cron
import (
"encoding/json"
"testing"
)
func TestJobId(t *testing.T) {
expected := "test"
j := Job{id: expected}
if j.Id() != expected {
t.Fatalf("Expected job with id %q, got %q", expected, j.Id())
}
}
func TestJobExpr(t *testing.T) {
expected := "1 2 3 4 5"
s, err := NewSchedule(expected)
if err != nil {
t.Fatal(err)
}
j := Job{schedule: s}
if j.Expression() != expected {
t.Fatalf("Expected job with cron expression %q, got %q", expected, j.Expression())
}
}
func TestJobRun(t *testing.T) {
defer func() {
if r := recover(); r != nil {
t.Errorf("Shouldn't panic: %v", r)
}
}()
calls := ""
j1 := Job{}
j2 := Job{fn: func() { calls += "2" }}
j1.Run()
j2.Run()
expected := "2"
if calls != expected {
t.Fatalf("Expected calls %q, got %q", expected, calls)
}
}
func TestJobMarshalJSON(t *testing.T) {
s, err := NewSchedule("1 2 3 4 5")
if err != nil {
t.Fatal(err)
}
j := Job{id: "test_id", schedule: s}
raw, err := json.Marshal(j)
if err != nil {
t.Fatal(err)
}
expected := `{"id":"test_id","expression":"1 2 3 4 5"}`
if str := string(raw); str != expected {
t.Fatalf("Expected\n%s\ngot\n%s", expected, str)
}
}

218
tools/cron/schedule.go Normal file
View file

@ -0,0 +1,218 @@
package cron
import (
"errors"
"fmt"
"strconv"
"strings"
"time"
)
// Moment represents a parsed single time moment.
type Moment struct {
Minute int `json:"minute"`
Hour int `json:"hour"`
Day int `json:"day"`
Month int `json:"month"`
DayOfWeek int `json:"dayOfWeek"`
}
// NewMoment creates a new Moment from the specified time.
func NewMoment(t time.Time) *Moment {
return &Moment{
Minute: t.Minute(),
Hour: t.Hour(),
Day: t.Day(),
Month: int(t.Month()),
DayOfWeek: int(t.Weekday()),
}
}
// Schedule stores parsed information for each time component when a cron job should run.
type Schedule struct {
Minutes map[int]struct{} `json:"minutes"`
Hours map[int]struct{} `json:"hours"`
Days map[int]struct{} `json:"days"`
Months map[int]struct{} `json:"months"`
DaysOfWeek map[int]struct{} `json:"daysOfWeek"`
rawExpr string
}
// IsDue checks whether the provided Moment satisfies the current Schedule.
func (s *Schedule) IsDue(m *Moment) bool {
if _, ok := s.Minutes[m.Minute]; !ok {
return false
}
if _, ok := s.Hours[m.Hour]; !ok {
return false
}
if _, ok := s.Days[m.Day]; !ok {
return false
}
if _, ok := s.DaysOfWeek[m.DayOfWeek]; !ok {
return false
}
if _, ok := s.Months[m.Month]; !ok {
return false
}
return true
}
var macros = map[string]string{
"@yearly": "0 0 1 1 *",
"@annually": "0 0 1 1 *",
"@monthly": "0 0 1 * *",
"@weekly": "0 0 * * 0",
"@daily": "0 0 * * *",
"@midnight": "0 0 * * *",
"@hourly": "0 * * * *",
}
// NewSchedule creates a new Schedule from a cron expression.
//
// A cron expression could be a macro OR 5 segments separated by space,
// representing: minute, hour, day of the month, month and day of the week.
//
// The following segment formats are supported:
// - wildcard: *
// - range: 1-30
// - step: */n or 1-30/n
// - list: 1,2,3,10-20/n
//
// The following macros are supported:
// - @yearly (or @annually)
// - @monthly
// - @weekly
// - @daily (or @midnight)
// - @hourly
func NewSchedule(cronExpr string) (*Schedule, error) {
if v, ok := macros[cronExpr]; ok {
cronExpr = v
}
segments := strings.Split(cronExpr, " ")
if len(segments) != 5 {
return nil, errors.New("invalid cron expression - must be a valid macro or to have exactly 5 space separated segments")
}
minutes, err := parseCronSegment(segments[0], 0, 59)
if err != nil {
return nil, err
}
hours, err := parseCronSegment(segments[1], 0, 23)
if err != nil {
return nil, err
}
days, err := parseCronSegment(segments[2], 1, 31)
if err != nil {
return nil, err
}
months, err := parseCronSegment(segments[3], 1, 12)
if err != nil {
return nil, err
}
daysOfWeek, err := parseCronSegment(segments[4], 0, 6)
if err != nil {
return nil, err
}
return &Schedule{
Minutes: minutes,
Hours: hours,
Days: days,
Months: months,
DaysOfWeek: daysOfWeek,
rawExpr: cronExpr,
}, nil
}
// parseCronSegment parses a single cron expression segment and
// returns its time schedule slots.
func parseCronSegment(segment string, min int, max int) (map[int]struct{}, error) {
slots := map[int]struct{}{}
list := strings.Split(segment, ",")
for _, p := range list {
stepParts := strings.Split(p, "/")
// step (*/n, 1-30/n)
var step int
switch len(stepParts) {
case 1:
step = 1
case 2:
parsedStep, err := strconv.Atoi(stepParts[1])
if err != nil {
return nil, err
}
if parsedStep < 1 || parsedStep > max {
return nil, fmt.Errorf("invalid segment step boundary - the step must be between 1 and the %d", max)
}
step = parsedStep
default:
return nil, errors.New("invalid segment step format - must be in the format */n or 1-30/n")
}
// find the min and max range of the segment part
var rangeMin, rangeMax int
if stepParts[0] == "*" {
rangeMin = min
rangeMax = max
} else {
// single digit (1) or range (1-30)
rangeParts := strings.Split(stepParts[0], "-")
switch len(rangeParts) {
case 1:
if step != 1 {
return nil, errors.New("invalid segement step - step > 1 could be used only with the wildcard or range format")
}
parsed, err := strconv.Atoi(rangeParts[0])
if err != nil {
return nil, err
}
if parsed < min || parsed > max {
return nil, errors.New("invalid segment value - must be between the min and max of the segment")
}
rangeMin = parsed
rangeMax = rangeMin
case 2:
parsedMin, err := strconv.Atoi(rangeParts[0])
if err != nil {
return nil, err
}
if parsedMin < min || parsedMin > max {
return nil, fmt.Errorf("invalid segment range minimum - must be between %d and %d", min, max)
}
rangeMin = parsedMin
parsedMax, err := strconv.Atoi(rangeParts[1])
if err != nil {
return nil, err
}
if parsedMax < parsedMin || parsedMax > max {
return nil, fmt.Errorf("invalid segment range maximum - must be between %d and %d", rangeMin, max)
}
rangeMax = parsedMax
default:
return nil, errors.New("invalid segment range format - the range must have 1 or 2 parts")
}
}
// fill the slots
for i := rangeMin; i <= rangeMax; i += step {
slots[i] = struct{}{}
}
}
return slots, nil
}

409
tools/cron/schedule_test.go Normal file
View file

@ -0,0 +1,409 @@
package cron_test
import (
"encoding/json"
"fmt"
"testing"
"time"
"github.com/pocketbase/pocketbase/tools/cron"
)
func TestNewMoment(t *testing.T) {
t.Parallel()
date, err := time.Parse("2006-01-02 15:04", "2023-05-09 15:20")
if err != nil {
t.Fatal(err)
}
m := cron.NewMoment(date)
if m.Minute != 20 {
t.Fatalf("Expected m.Minute %d, got %d", 20, m.Minute)
}
if m.Hour != 15 {
t.Fatalf("Expected m.Hour %d, got %d", 15, m.Hour)
}
if m.Day != 9 {
t.Fatalf("Expected m.Day %d, got %d", 9, m.Day)
}
if m.Month != 5 {
t.Fatalf("Expected m.Month %d, got %d", 5, m.Month)
}
if m.DayOfWeek != 2 {
t.Fatalf("Expected m.DayOfWeek %d, got %d", 2, m.DayOfWeek)
}
}
func TestNewSchedule(t *testing.T) {
t.Parallel()
scenarios := []struct {
cronExpr string
expectError bool
expectSchedule string
}{
{
"invalid",
true,
"",
},
{
"* * * *",
true,
"",
},
{
"* * * * * *",
true,
"",
},
{
"2/3 * * * *",
true,
"",
},
{
"* * * * *",
false,
`{"minutes":{"0":{},"1":{},"10":{},"11":{},"12":{},"13":{},"14":{},"15":{},"16":{},"17":{},"18":{},"19":{},"2":{},"20":{},"21":{},"22":{},"23":{},"24":{},"25":{},"26":{},"27":{},"28":{},"29":{},"3":{},"30":{},"31":{},"32":{},"33":{},"34":{},"35":{},"36":{},"37":{},"38":{},"39":{},"4":{},"40":{},"41":{},"42":{},"43":{},"44":{},"45":{},"46":{},"47":{},"48":{},"49":{},"5":{},"50":{},"51":{},"52":{},"53":{},"54":{},"55":{},"56":{},"57":{},"58":{},"59":{},"6":{},"7":{},"8":{},"9":{}},"hours":{"0":{},"1":{},"10":{},"11":{},"12":{},"13":{},"14":{},"15":{},"16":{},"17":{},"18":{},"19":{},"2":{},"20":{},"21":{},"22":{},"23":{},"3":{},"4":{},"5":{},"6":{},"7":{},"8":{},"9":{}},"days":{"1":{},"10":{},"11":{},"12":{},"13":{},"14":{},"15":{},"16":{},"17":{},"18":{},"19":{},"2":{},"20":{},"21":{},"22":{},"23":{},"24":{},"25":{},"26":{},"27":{},"28":{},"29":{},"3":{},"30":{},"31":{},"4":{},"5":{},"6":{},"7":{},"8":{},"9":{}},"months":{"1":{},"10":{},"11":{},"12":{},"2":{},"3":{},"4":{},"5":{},"6":{},"7":{},"8":{},"9":{}},"daysOfWeek":{"0":{},"1":{},"2":{},"3":{},"4":{},"5":{},"6":{}}}`,
},
{
"*/2 */3 */5 */4 */2",
false,
`{"minutes":{"0":{},"10":{},"12":{},"14":{},"16":{},"18":{},"2":{},"20":{},"22":{},"24":{},"26":{},"28":{},"30":{},"32":{},"34":{},"36":{},"38":{},"4":{},"40":{},"42":{},"44":{},"46":{},"48":{},"50":{},"52":{},"54":{},"56":{},"58":{},"6":{},"8":{}},"hours":{"0":{},"12":{},"15":{},"18":{},"21":{},"3":{},"6":{},"9":{}},"days":{"1":{},"11":{},"16":{},"21":{},"26":{},"31":{},"6":{}},"months":{"1":{},"5":{},"9":{}},"daysOfWeek":{"0":{},"2":{},"4":{},"6":{}}}`,
},
// minute segment
{
"-1 * * * *",
true,
"",
},
{
"60 * * * *",
true,
"",
},
{
"0 * * * *",
false,
`{"minutes":{"0":{}},"hours":{"0":{},"1":{},"10":{},"11":{},"12":{},"13":{},"14":{},"15":{},"16":{},"17":{},"18":{},"19":{},"2":{},"20":{},"21":{},"22":{},"23":{},"3":{},"4":{},"5":{},"6":{},"7":{},"8":{},"9":{}},"days":{"1":{},"10":{},"11":{},"12":{},"13":{},"14":{},"15":{},"16":{},"17":{},"18":{},"19":{},"2":{},"20":{},"21":{},"22":{},"23":{},"24":{},"25":{},"26":{},"27":{},"28":{},"29":{},"3":{},"30":{},"31":{},"4":{},"5":{},"6":{},"7":{},"8":{},"9":{}},"months":{"1":{},"10":{},"11":{},"12":{},"2":{},"3":{},"4":{},"5":{},"6":{},"7":{},"8":{},"9":{}},"daysOfWeek":{"0":{},"1":{},"2":{},"3":{},"4":{},"5":{},"6":{}}}`,
},
{
"59 * * * *",
false,
`{"minutes":{"59":{}},"hours":{"0":{},"1":{},"10":{},"11":{},"12":{},"13":{},"14":{},"15":{},"16":{},"17":{},"18":{},"19":{},"2":{},"20":{},"21":{},"22":{},"23":{},"3":{},"4":{},"5":{},"6":{},"7":{},"8":{},"9":{}},"days":{"1":{},"10":{},"11":{},"12":{},"13":{},"14":{},"15":{},"16":{},"17":{},"18":{},"19":{},"2":{},"20":{},"21":{},"22":{},"23":{},"24":{},"25":{},"26":{},"27":{},"28":{},"29":{},"3":{},"30":{},"31":{},"4":{},"5":{},"6":{},"7":{},"8":{},"9":{}},"months":{"1":{},"10":{},"11":{},"12":{},"2":{},"3":{},"4":{},"5":{},"6":{},"7":{},"8":{},"9":{}},"daysOfWeek":{"0":{},"1":{},"2":{},"3":{},"4":{},"5":{},"6":{}}}`,
},
{
"1,2,5,7,40-50/2 * * * *",
false,
`{"minutes":{"1":{},"2":{},"40":{},"42":{},"44":{},"46":{},"48":{},"5":{},"50":{},"7":{}},"hours":{"0":{},"1":{},"10":{},"11":{},"12":{},"13":{},"14":{},"15":{},"16":{},"17":{},"18":{},"19":{},"2":{},"20":{},"21":{},"22":{},"23":{},"3":{},"4":{},"5":{},"6":{},"7":{},"8":{},"9":{}},"days":{"1":{},"10":{},"11":{},"12":{},"13":{},"14":{},"15":{},"16":{},"17":{},"18":{},"19":{},"2":{},"20":{},"21":{},"22":{},"23":{},"24":{},"25":{},"26":{},"27":{},"28":{},"29":{},"3":{},"30":{},"31":{},"4":{},"5":{},"6":{},"7":{},"8":{},"9":{}},"months":{"1":{},"10":{},"11":{},"12":{},"2":{},"3":{},"4":{},"5":{},"6":{},"7":{},"8":{},"9":{}},"daysOfWeek":{"0":{},"1":{},"2":{},"3":{},"4":{},"5":{},"6":{}}}`,
},
// hour segment
{
"* -1 * * *",
true,
"",
},
{
"* 24 * * *",
true,
"",
},
{
"* 0 * * *",
false,
`{"minutes":{"0":{},"1":{},"10":{},"11":{},"12":{},"13":{},"14":{},"15":{},"16":{},"17":{},"18":{},"19":{},"2":{},"20":{},"21":{},"22":{},"23":{},"24":{},"25":{},"26":{},"27":{},"28":{},"29":{},"3":{},"30":{},"31":{},"32":{},"33":{},"34":{},"35":{},"36":{},"37":{},"38":{},"39":{},"4":{},"40":{},"41":{},"42":{},"43":{},"44":{},"45":{},"46":{},"47":{},"48":{},"49":{},"5":{},"50":{},"51":{},"52":{},"53":{},"54":{},"55":{},"56":{},"57":{},"58":{},"59":{},"6":{},"7":{},"8":{},"9":{}},"hours":{"0":{}},"days":{"1":{},"10":{},"11":{},"12":{},"13":{},"14":{},"15":{},"16":{},"17":{},"18":{},"19":{},"2":{},"20":{},"21":{},"22":{},"23":{},"24":{},"25":{},"26":{},"27":{},"28":{},"29":{},"3":{},"30":{},"31":{},"4":{},"5":{},"6":{},"7":{},"8":{},"9":{}},"months":{"1":{},"10":{},"11":{},"12":{},"2":{},"3":{},"4":{},"5":{},"6":{},"7":{},"8":{},"9":{}},"daysOfWeek":{"0":{},"1":{},"2":{},"3":{},"4":{},"5":{},"6":{}}}`,
},
{
"* 23 * * *",
false,
`{"minutes":{"0":{},"1":{},"10":{},"11":{},"12":{},"13":{},"14":{},"15":{},"16":{},"17":{},"18":{},"19":{},"2":{},"20":{},"21":{},"22":{},"23":{},"24":{},"25":{},"26":{},"27":{},"28":{},"29":{},"3":{},"30":{},"31":{},"32":{},"33":{},"34":{},"35":{},"36":{},"37":{},"38":{},"39":{},"4":{},"40":{},"41":{},"42":{},"43":{},"44":{},"45":{},"46":{},"47":{},"48":{},"49":{},"5":{},"50":{},"51":{},"52":{},"53":{},"54":{},"55":{},"56":{},"57":{},"58":{},"59":{},"6":{},"7":{},"8":{},"9":{}},"hours":{"23":{}},"days":{"1":{},"10":{},"11":{},"12":{},"13":{},"14":{},"15":{},"16":{},"17":{},"18":{},"19":{},"2":{},"20":{},"21":{},"22":{},"23":{},"24":{},"25":{},"26":{},"27":{},"28":{},"29":{},"3":{},"30":{},"31":{},"4":{},"5":{},"6":{},"7":{},"8":{},"9":{}},"months":{"1":{},"10":{},"11":{},"12":{},"2":{},"3":{},"4":{},"5":{},"6":{},"7":{},"8":{},"9":{}},"daysOfWeek":{"0":{},"1":{},"2":{},"3":{},"4":{},"5":{},"6":{}}}`,
},
{
"* 3,4,8-16/3,7 * * *",
false,
`{"minutes":{"0":{},"1":{},"10":{},"11":{},"12":{},"13":{},"14":{},"15":{},"16":{},"17":{},"18":{},"19":{},"2":{},"20":{},"21":{},"22":{},"23":{},"24":{},"25":{},"26":{},"27":{},"28":{},"29":{},"3":{},"30":{},"31":{},"32":{},"33":{},"34":{},"35":{},"36":{},"37":{},"38":{},"39":{},"4":{},"40":{},"41":{},"42":{},"43":{},"44":{},"45":{},"46":{},"47":{},"48":{},"49":{},"5":{},"50":{},"51":{},"52":{},"53":{},"54":{},"55":{},"56":{},"57":{},"58":{},"59":{},"6":{},"7":{},"8":{},"9":{}},"hours":{"11":{},"14":{},"3":{},"4":{},"7":{},"8":{}},"days":{"1":{},"10":{},"11":{},"12":{},"13":{},"14":{},"15":{},"16":{},"17":{},"18":{},"19":{},"2":{},"20":{},"21":{},"22":{},"23":{},"24":{},"25":{},"26":{},"27":{},"28":{},"29":{},"3":{},"30":{},"31":{},"4":{},"5":{},"6":{},"7":{},"8":{},"9":{}},"months":{"1":{},"10":{},"11":{},"12":{},"2":{},"3":{},"4":{},"5":{},"6":{},"7":{},"8":{},"9":{}},"daysOfWeek":{"0":{},"1":{},"2":{},"3":{},"4":{},"5":{},"6":{}}}`,
},
// day segment
{
"* * 0 * *",
true,
"",
},
{
"* * 32 * *",
true,
"",
},
{
"* * 1 * *",
false,
`{"minutes":{"0":{},"1":{},"10":{},"11":{},"12":{},"13":{},"14":{},"15":{},"16":{},"17":{},"18":{},"19":{},"2":{},"20":{},"21":{},"22":{},"23":{},"24":{},"25":{},"26":{},"27":{},"28":{},"29":{},"3":{},"30":{},"31":{},"32":{},"33":{},"34":{},"35":{},"36":{},"37":{},"38":{},"39":{},"4":{},"40":{},"41":{},"42":{},"43":{},"44":{},"45":{},"46":{},"47":{},"48":{},"49":{},"5":{},"50":{},"51":{},"52":{},"53":{},"54":{},"55":{},"56":{},"57":{},"58":{},"59":{},"6":{},"7":{},"8":{},"9":{}},"hours":{"0":{},"1":{},"10":{},"11":{},"12":{},"13":{},"14":{},"15":{},"16":{},"17":{},"18":{},"19":{},"2":{},"20":{},"21":{},"22":{},"23":{},"3":{},"4":{},"5":{},"6":{},"7":{},"8":{},"9":{}},"days":{"1":{}},"months":{"1":{},"10":{},"11":{},"12":{},"2":{},"3":{},"4":{},"5":{},"6":{},"7":{},"8":{},"9":{}},"daysOfWeek":{"0":{},"1":{},"2":{},"3":{},"4":{},"5":{},"6":{}}}`,
},
{
"* * 31 * *",
false,
`{"minutes":{"0":{},"1":{},"10":{},"11":{},"12":{},"13":{},"14":{},"15":{},"16":{},"17":{},"18":{},"19":{},"2":{},"20":{},"21":{},"22":{},"23":{},"24":{},"25":{},"26":{},"27":{},"28":{},"29":{},"3":{},"30":{},"31":{},"32":{},"33":{},"34":{},"35":{},"36":{},"37":{},"38":{},"39":{},"4":{},"40":{},"41":{},"42":{},"43":{},"44":{},"45":{},"46":{},"47":{},"48":{},"49":{},"5":{},"50":{},"51":{},"52":{},"53":{},"54":{},"55":{},"56":{},"57":{},"58":{},"59":{},"6":{},"7":{},"8":{},"9":{}},"hours":{"0":{},"1":{},"10":{},"11":{},"12":{},"13":{},"14":{},"15":{},"16":{},"17":{},"18":{},"19":{},"2":{},"20":{},"21":{},"22":{},"23":{},"3":{},"4":{},"5":{},"6":{},"7":{},"8":{},"9":{}},"days":{"31":{}},"months":{"1":{},"10":{},"11":{},"12":{},"2":{},"3":{},"4":{},"5":{},"6":{},"7":{},"8":{},"9":{}},"daysOfWeek":{"0":{},"1":{},"2":{},"3":{},"4":{},"5":{},"6":{}}}`,
},
{
"* * 5,6,20-30/3,1 * *",
false,
`{"minutes":{"0":{},"1":{},"10":{},"11":{},"12":{},"13":{},"14":{},"15":{},"16":{},"17":{},"18":{},"19":{},"2":{},"20":{},"21":{},"22":{},"23":{},"24":{},"25":{},"26":{},"27":{},"28":{},"29":{},"3":{},"30":{},"31":{},"32":{},"33":{},"34":{},"35":{},"36":{},"37":{},"38":{},"39":{},"4":{},"40":{},"41":{},"42":{},"43":{},"44":{},"45":{},"46":{},"47":{},"48":{},"49":{},"5":{},"50":{},"51":{},"52":{},"53":{},"54":{},"55":{},"56":{},"57":{},"58":{},"59":{},"6":{},"7":{},"8":{},"9":{}},"hours":{"0":{},"1":{},"10":{},"11":{},"12":{},"13":{},"14":{},"15":{},"16":{},"17":{},"18":{},"19":{},"2":{},"20":{},"21":{},"22":{},"23":{},"3":{},"4":{},"5":{},"6":{},"7":{},"8":{},"9":{}},"days":{"1":{},"20":{},"23":{},"26":{},"29":{},"5":{},"6":{}},"months":{"1":{},"10":{},"11":{},"12":{},"2":{},"3":{},"4":{},"5":{},"6":{},"7":{},"8":{},"9":{}},"daysOfWeek":{"0":{},"1":{},"2":{},"3":{},"4":{},"5":{},"6":{}}}`,
},
// month segment
{
"* * * 0 *",
true,
"",
},
{
"* * * 13 *",
true,
"",
},
{
"* * * 1 *",
false,
`{"minutes":{"0":{},"1":{},"10":{},"11":{},"12":{},"13":{},"14":{},"15":{},"16":{},"17":{},"18":{},"19":{},"2":{},"20":{},"21":{},"22":{},"23":{},"24":{},"25":{},"26":{},"27":{},"28":{},"29":{},"3":{},"30":{},"31":{},"32":{},"33":{},"34":{},"35":{},"36":{},"37":{},"38":{},"39":{},"4":{},"40":{},"41":{},"42":{},"43":{},"44":{},"45":{},"46":{},"47":{},"48":{},"49":{},"5":{},"50":{},"51":{},"52":{},"53":{},"54":{},"55":{},"56":{},"57":{},"58":{},"59":{},"6":{},"7":{},"8":{},"9":{}},"hours":{"0":{},"1":{},"10":{},"11":{},"12":{},"13":{},"14":{},"15":{},"16":{},"17":{},"18":{},"19":{},"2":{},"20":{},"21":{},"22":{},"23":{},"3":{},"4":{},"5":{},"6":{},"7":{},"8":{},"9":{}},"days":{"1":{},"10":{},"11":{},"12":{},"13":{},"14":{},"15":{},"16":{},"17":{},"18":{},"19":{},"2":{},"20":{},"21":{},"22":{},"23":{},"24":{},"25":{},"26":{},"27":{},"28":{},"29":{},"3":{},"30":{},"31":{},"4":{},"5":{},"6":{},"7":{},"8":{},"9":{}},"months":{"1":{}},"daysOfWeek":{"0":{},"1":{},"2":{},"3":{},"4":{},"5":{},"6":{}}}`,
},
{
"* * * 12 *",
false,
`{"minutes":{"0":{},"1":{},"10":{},"11":{},"12":{},"13":{},"14":{},"15":{},"16":{},"17":{},"18":{},"19":{},"2":{},"20":{},"21":{},"22":{},"23":{},"24":{},"25":{},"26":{},"27":{},"28":{},"29":{},"3":{},"30":{},"31":{},"32":{},"33":{},"34":{},"35":{},"36":{},"37":{},"38":{},"39":{},"4":{},"40":{},"41":{},"42":{},"43":{},"44":{},"45":{},"46":{},"47":{},"48":{},"49":{},"5":{},"50":{},"51":{},"52":{},"53":{},"54":{},"55":{},"56":{},"57":{},"58":{},"59":{},"6":{},"7":{},"8":{},"9":{}},"hours":{"0":{},"1":{},"10":{},"11":{},"12":{},"13":{},"14":{},"15":{},"16":{},"17":{},"18":{},"19":{},"2":{},"20":{},"21":{},"22":{},"23":{},"3":{},"4":{},"5":{},"6":{},"7":{},"8":{},"9":{}},"days":{"1":{},"10":{},"11":{},"12":{},"13":{},"14":{},"15":{},"16":{},"17":{},"18":{},"19":{},"2":{},"20":{},"21":{},"22":{},"23":{},"24":{},"25":{},"26":{},"27":{},"28":{},"29":{},"3":{},"30":{},"31":{},"4":{},"5":{},"6":{},"7":{},"8":{},"9":{}},"months":{"12":{}},"daysOfWeek":{"0":{},"1":{},"2":{},"3":{},"4":{},"5":{},"6":{}}}`,
},
{
"* * * 1,4,5-10/2 *",
false,
`{"minutes":{"0":{},"1":{},"10":{},"11":{},"12":{},"13":{},"14":{},"15":{},"16":{},"17":{},"18":{},"19":{},"2":{},"20":{},"21":{},"22":{},"23":{},"24":{},"25":{},"26":{},"27":{},"28":{},"29":{},"3":{},"30":{},"31":{},"32":{},"33":{},"34":{},"35":{},"36":{},"37":{},"38":{},"39":{},"4":{},"40":{},"41":{},"42":{},"43":{},"44":{},"45":{},"46":{},"47":{},"48":{},"49":{},"5":{},"50":{},"51":{},"52":{},"53":{},"54":{},"55":{},"56":{},"57":{},"58":{},"59":{},"6":{},"7":{},"8":{},"9":{}},"hours":{"0":{},"1":{},"10":{},"11":{},"12":{},"13":{},"14":{},"15":{},"16":{},"17":{},"18":{},"19":{},"2":{},"20":{},"21":{},"22":{},"23":{},"3":{},"4":{},"5":{},"6":{},"7":{},"8":{},"9":{}},"days":{"1":{},"10":{},"11":{},"12":{},"13":{},"14":{},"15":{},"16":{},"17":{},"18":{},"19":{},"2":{},"20":{},"21":{},"22":{},"23":{},"24":{},"25":{},"26":{},"27":{},"28":{},"29":{},"3":{},"30":{},"31":{},"4":{},"5":{},"6":{},"7":{},"8":{},"9":{}},"months":{"1":{},"4":{},"5":{},"7":{},"9":{}},"daysOfWeek":{"0":{},"1":{},"2":{},"3":{},"4":{},"5":{},"6":{}}}`,
},
// day of week segment
{
"* * * * -1",
true,
"",
},
{
"* * * * 7",
true,
"",
},
{
"* * * * 0",
false,
`{"minutes":{"0":{},"1":{},"10":{},"11":{},"12":{},"13":{},"14":{},"15":{},"16":{},"17":{},"18":{},"19":{},"2":{},"20":{},"21":{},"22":{},"23":{},"24":{},"25":{},"26":{},"27":{},"28":{},"29":{},"3":{},"30":{},"31":{},"32":{},"33":{},"34":{},"35":{},"36":{},"37":{},"38":{},"39":{},"4":{},"40":{},"41":{},"42":{},"43":{},"44":{},"45":{},"46":{},"47":{},"48":{},"49":{},"5":{},"50":{},"51":{},"52":{},"53":{},"54":{},"55":{},"56":{},"57":{},"58":{},"59":{},"6":{},"7":{},"8":{},"9":{}},"hours":{"0":{},"1":{},"10":{},"11":{},"12":{},"13":{},"14":{},"15":{},"16":{},"17":{},"18":{},"19":{},"2":{},"20":{},"21":{},"22":{},"23":{},"3":{},"4":{},"5":{},"6":{},"7":{},"8":{},"9":{}},"days":{"1":{},"10":{},"11":{},"12":{},"13":{},"14":{},"15":{},"16":{},"17":{},"18":{},"19":{},"2":{},"20":{},"21":{},"22":{},"23":{},"24":{},"25":{},"26":{},"27":{},"28":{},"29":{},"3":{},"30":{},"31":{},"4":{},"5":{},"6":{},"7":{},"8":{},"9":{}},"months":{"1":{},"10":{},"11":{},"12":{},"2":{},"3":{},"4":{},"5":{},"6":{},"7":{},"8":{},"9":{}},"daysOfWeek":{"0":{}}}`,
},
{
"* * * * 6",
false,
`{"minutes":{"0":{},"1":{},"10":{},"11":{},"12":{},"13":{},"14":{},"15":{},"16":{},"17":{},"18":{},"19":{},"2":{},"20":{},"21":{},"22":{},"23":{},"24":{},"25":{},"26":{},"27":{},"28":{},"29":{},"3":{},"30":{},"31":{},"32":{},"33":{},"34":{},"35":{},"36":{},"37":{},"38":{},"39":{},"4":{},"40":{},"41":{},"42":{},"43":{},"44":{},"45":{},"46":{},"47":{},"48":{},"49":{},"5":{},"50":{},"51":{},"52":{},"53":{},"54":{},"55":{},"56":{},"57":{},"58":{},"59":{},"6":{},"7":{},"8":{},"9":{}},"hours":{"0":{},"1":{},"10":{},"11":{},"12":{},"13":{},"14":{},"15":{},"16":{},"17":{},"18":{},"19":{},"2":{},"20":{},"21":{},"22":{},"23":{},"3":{},"4":{},"5":{},"6":{},"7":{},"8":{},"9":{}},"days":{"1":{},"10":{},"11":{},"12":{},"13":{},"14":{},"15":{},"16":{},"17":{},"18":{},"19":{},"2":{},"20":{},"21":{},"22":{},"23":{},"24":{},"25":{},"26":{},"27":{},"28":{},"29":{},"3":{},"30":{},"31":{},"4":{},"5":{},"6":{},"7":{},"8":{},"9":{}},"months":{"1":{},"10":{},"11":{},"12":{},"2":{},"3":{},"4":{},"5":{},"6":{},"7":{},"8":{},"9":{}},"daysOfWeek":{"6":{}}}`,
},
{
"* * * * 1,2-5/2",
false,
`{"minutes":{"0":{},"1":{},"10":{},"11":{},"12":{},"13":{},"14":{},"15":{},"16":{},"17":{},"18":{},"19":{},"2":{},"20":{},"21":{},"22":{},"23":{},"24":{},"25":{},"26":{},"27":{},"28":{},"29":{},"3":{},"30":{},"31":{},"32":{},"33":{},"34":{},"35":{},"36":{},"37":{},"38":{},"39":{},"4":{},"40":{},"41":{},"42":{},"43":{},"44":{},"45":{},"46":{},"47":{},"48":{},"49":{},"5":{},"50":{},"51":{},"52":{},"53":{},"54":{},"55":{},"56":{},"57":{},"58":{},"59":{},"6":{},"7":{},"8":{},"9":{}},"hours":{"0":{},"1":{},"10":{},"11":{},"12":{},"13":{},"14":{},"15":{},"16":{},"17":{},"18":{},"19":{},"2":{},"20":{},"21":{},"22":{},"23":{},"3":{},"4":{},"5":{},"6":{},"7":{},"8":{},"9":{}},"days":{"1":{},"10":{},"11":{},"12":{},"13":{},"14":{},"15":{},"16":{},"17":{},"18":{},"19":{},"2":{},"20":{},"21":{},"22":{},"23":{},"24":{},"25":{},"26":{},"27":{},"28":{},"29":{},"3":{},"30":{},"31":{},"4":{},"5":{},"6":{},"7":{},"8":{},"9":{}},"months":{"1":{},"10":{},"11":{},"12":{},"2":{},"3":{},"4":{},"5":{},"6":{},"7":{},"8":{},"9":{}},"daysOfWeek":{"1":{},"2":{},"4":{}}}`,
},
// macros
{
"@yearly",
false,
`{"minutes":{"0":{}},"hours":{"0":{}},"days":{"1":{}},"months":{"1":{}},"daysOfWeek":{"0":{},"1":{},"2":{},"3":{},"4":{},"5":{},"6":{}}}`,
},
{
"@annually",
false,
`{"minutes":{"0":{}},"hours":{"0":{}},"days":{"1":{}},"months":{"1":{}},"daysOfWeek":{"0":{},"1":{},"2":{},"3":{},"4":{},"5":{},"6":{}}}`,
},
{
"@monthly",
false,
`{"minutes":{"0":{}},"hours":{"0":{}},"days":{"1":{}},"months":{"1":{},"10":{},"11":{},"12":{},"2":{},"3":{},"4":{},"5":{},"6":{},"7":{},"8":{},"9":{}},"daysOfWeek":{"0":{},"1":{},"2":{},"3":{},"4":{},"5":{},"6":{}}}`,
},
{
"@weekly",
false,
`{"minutes":{"0":{}},"hours":{"0":{}},"days":{"1":{},"10":{},"11":{},"12":{},"13":{},"14":{},"15":{},"16":{},"17":{},"18":{},"19":{},"2":{},"20":{},"21":{},"22":{},"23":{},"24":{},"25":{},"26":{},"27":{},"28":{},"29":{},"3":{},"30":{},"31":{},"4":{},"5":{},"6":{},"7":{},"8":{},"9":{}},"months":{"1":{},"10":{},"11":{},"12":{},"2":{},"3":{},"4":{},"5":{},"6":{},"7":{},"8":{},"9":{}},"daysOfWeek":{"0":{}}}`,
},
{
"@daily",
false,
`{"minutes":{"0":{}},"hours":{"0":{}},"days":{"1":{},"10":{},"11":{},"12":{},"13":{},"14":{},"15":{},"16":{},"17":{},"18":{},"19":{},"2":{},"20":{},"21":{},"22":{},"23":{},"24":{},"25":{},"26":{},"27":{},"28":{},"29":{},"3":{},"30":{},"31":{},"4":{},"5":{},"6":{},"7":{},"8":{},"9":{}},"months":{"1":{},"10":{},"11":{},"12":{},"2":{},"3":{},"4":{},"5":{},"6":{},"7":{},"8":{},"9":{}},"daysOfWeek":{"0":{},"1":{},"2":{},"3":{},"4":{},"5":{},"6":{}}}`,
},
{
"@midnight",
false,
`{"minutes":{"0":{}},"hours":{"0":{}},"days":{"1":{},"10":{},"11":{},"12":{},"13":{},"14":{},"15":{},"16":{},"17":{},"18":{},"19":{},"2":{},"20":{},"21":{},"22":{},"23":{},"24":{},"25":{},"26":{},"27":{},"28":{},"29":{},"3":{},"30":{},"31":{},"4":{},"5":{},"6":{},"7":{},"8":{},"9":{}},"months":{"1":{},"10":{},"11":{},"12":{},"2":{},"3":{},"4":{},"5":{},"6":{},"7":{},"8":{},"9":{}},"daysOfWeek":{"0":{},"1":{},"2":{},"3":{},"4":{},"5":{},"6":{}}}`,
},
{
"@hourly",
false,
`{"minutes":{"0":{}},"hours":{"0":{},"1":{},"10":{},"11":{},"12":{},"13":{},"14":{},"15":{},"16":{},"17":{},"18":{},"19":{},"2":{},"20":{},"21":{},"22":{},"23":{},"3":{},"4":{},"5":{},"6":{},"7":{},"8":{},"9":{}},"days":{"1":{},"10":{},"11":{},"12":{},"13":{},"14":{},"15":{},"16":{},"17":{},"18":{},"19":{},"2":{},"20":{},"21":{},"22":{},"23":{},"24":{},"25":{},"26":{},"27":{},"28":{},"29":{},"3":{},"30":{},"31":{},"4":{},"5":{},"6":{},"7":{},"8":{},"9":{}},"months":{"1":{},"10":{},"11":{},"12":{},"2":{},"3":{},"4":{},"5":{},"6":{},"7":{},"8":{},"9":{}},"daysOfWeek":{"0":{},"1":{},"2":{},"3":{},"4":{},"5":{},"6":{}}}`,
},
}
for _, s := range scenarios {
t.Run(s.cronExpr, func(t *testing.T) {
schedule, err := cron.NewSchedule(s.cronExpr)
hasErr := err != nil
if hasErr != s.expectError {
t.Fatalf("Expected hasErr to be %v, got %v (%v)", s.expectError, hasErr, err)
}
if hasErr {
return
}
encoded, err := json.Marshal(schedule)
if err != nil {
t.Fatalf("Failed to marshalize the result schedule: %v", err)
}
encodedStr := string(encoded)
if encodedStr != s.expectSchedule {
t.Fatalf("Expected \n%s, \ngot \n%s", s.expectSchedule, encodedStr)
}
})
}
}
func TestScheduleIsDue(t *testing.T) {
t.Parallel()
scenarios := []struct {
cronExpr string
moment *cron.Moment
expected bool
}{
{
"* * * * *",
&cron.Moment{},
false,
},
{
"* * * * *",
&cron.Moment{
Minute: 1,
Hour: 1,
Day: 1,
Month: 1,
DayOfWeek: 1,
},
true,
},
{
"5 * * * *",
&cron.Moment{
Minute: 1,
Hour: 1,
Day: 1,
Month: 1,
DayOfWeek: 1,
},
false,
},
{
"5 * * * *",
&cron.Moment{
Minute: 5,
Hour: 1,
Day: 1,
Month: 1,
DayOfWeek: 1,
},
true,
},
{
"* 2-6 * * 2,3",
&cron.Moment{
Minute: 1,
Hour: 2,
Day: 1,
Month: 1,
DayOfWeek: 1,
},
false,
},
{
"* 2-6 * * 2,3",
&cron.Moment{
Minute: 1,
Hour: 2,
Day: 1,
Month: 1,
DayOfWeek: 3,
},
true,
},
{
"* * 1,2,5,15-18 * *",
&cron.Moment{
Minute: 1,
Hour: 1,
Day: 6,
Month: 1,
DayOfWeek: 1,
},
false,
},
{
"* * 1,2,5,15-18/2 * *",
&cron.Moment{
Minute: 1,
Hour: 1,
Day: 2,
Month: 1,
DayOfWeek: 1,
},
true,
},
{
"* * 1,2,5,15-18/2 * *",
&cron.Moment{
Minute: 1,
Hour: 1,
Day: 18,
Month: 1,
DayOfWeek: 1,
},
false,
},
{
"* * 1,2,5,15-18/2 * *",
&cron.Moment{
Minute: 1,
Hour: 1,
Day: 17,
Month: 1,
DayOfWeek: 1,
},
true,
},
}
for i, s := range scenarios {
t.Run(fmt.Sprintf("%d-%s", i, s.cronExpr), func(t *testing.T) {
schedule, err := cron.NewSchedule(s.cronExpr)
if err != nil {
t.Fatalf("Unexpected cron error: %v", err)
}
result := schedule.IsDue(s.moment)
if result != s.expected {
t.Fatalf("Expected %v, got %v", s.expected, result)
}
})
}
}

217
tools/dbutils/index.go Normal file
View file

@ -0,0 +1,217 @@
package dbutils
import (
"regexp"
"strings"
"github.com/pocketbase/pocketbase/tools/tokenizer"
)
var (
indexRegex = regexp.MustCompile(`(?im)create\s+(unique\s+)?\s*index\s*(if\s+not\s+exists\s+)?(\S*)\s+on\s+(\S*)\s*\(([\s\S]*)\)(?:\s*where\s+([\s\S]*))?`)
indexColumnRegex = regexp.MustCompile(`(?im)^([\s\S]+?)(?:\s+collate\s+([\w]+))?(?:\s+(asc|desc))?$`)
)
// IndexColumn represents a single parsed SQL index column.
type IndexColumn struct {
Name string `json:"name"` // identifier or expression
Collate string `json:"collate"`
Sort string `json:"sort"`
}
// Index represents a single parsed SQL CREATE INDEX expression.
type Index struct {
SchemaName string `json:"schemaName"`
IndexName string `json:"indexName"`
TableName string `json:"tableName"`
Where string `json:"where"`
Columns []IndexColumn `json:"columns"`
Unique bool `json:"unique"`
Optional bool `json:"optional"`
}
// IsValid checks if the current Index contains the minimum required fields to be considered valid.
func (idx Index) IsValid() bool {
return idx.IndexName != "" && idx.TableName != "" && len(idx.Columns) > 0
}
// Build returns a "CREATE INDEX" SQL string from the current index parts.
//
// Returns empty string if idx.IsValid() is false.
func (idx Index) Build() string {
if !idx.IsValid() {
return ""
}
var str strings.Builder
str.WriteString("CREATE ")
if idx.Unique {
str.WriteString("UNIQUE ")
}
str.WriteString("INDEX ")
if idx.Optional {
str.WriteString("IF NOT EXISTS ")
}
if idx.SchemaName != "" {
str.WriteString("`")
str.WriteString(idx.SchemaName)
str.WriteString("`.")
}
str.WriteString("`")
str.WriteString(idx.IndexName)
str.WriteString("` ")
str.WriteString("ON `")
str.WriteString(idx.TableName)
str.WriteString("` (")
if len(idx.Columns) > 1 {
str.WriteString("\n ")
}
var hasCol bool
for _, col := range idx.Columns {
trimmedColName := strings.TrimSpace(col.Name)
if trimmedColName == "" {
continue
}
if hasCol {
str.WriteString(",\n ")
}
if strings.Contains(col.Name, "(") || strings.Contains(col.Name, " ") {
// most likely an expression
str.WriteString(trimmedColName)
} else {
// regular identifier
str.WriteString("`")
str.WriteString(trimmedColName)
str.WriteString("`")
}
if col.Collate != "" {
str.WriteString(" COLLATE ")
str.WriteString(col.Collate)
}
if col.Sort != "" {
str.WriteString(" ")
str.WriteString(strings.ToUpper(col.Sort))
}
hasCol = true
}
if hasCol && len(idx.Columns) > 1 {
str.WriteString("\n")
}
str.WriteString(")")
if idx.Where != "" {
str.WriteString(" WHERE ")
str.WriteString(idx.Where)
}
return str.String()
}
// ParseIndex parses the provided "CREATE INDEX" SQL string into Index struct.
func ParseIndex(createIndexExpr string) Index {
result := Index{}
matches := indexRegex.FindStringSubmatch(createIndexExpr)
if len(matches) != 7 {
return result
}
trimChars := "`\"'[]\r\n\t\f\v "
// Unique
// ---
result.Unique = strings.TrimSpace(matches[1]) != ""
// Optional (aka. "IF NOT EXISTS")
// ---
result.Optional = strings.TrimSpace(matches[2]) != ""
// SchemaName and IndexName
// ---
nameTk := tokenizer.NewFromString(matches[3])
nameTk.Separators('.')
nameParts, _ := nameTk.ScanAll()
if len(nameParts) == 2 {
result.SchemaName = strings.Trim(nameParts[0], trimChars)
result.IndexName = strings.Trim(nameParts[1], trimChars)
} else {
result.IndexName = strings.Trim(nameParts[0], trimChars)
}
// TableName
// ---
result.TableName = strings.Trim(matches[4], trimChars)
// Columns
// ---
columnsTk := tokenizer.NewFromString(matches[5])
columnsTk.Separators(',')
rawColumns, _ := columnsTk.ScanAll()
result.Columns = make([]IndexColumn, 0, len(rawColumns))
for _, col := range rawColumns {
colMatches := indexColumnRegex.FindStringSubmatch(col)
if len(colMatches) != 4 {
continue
}
trimmedName := strings.Trim(colMatches[1], trimChars)
if trimmedName == "" {
continue
}
result.Columns = append(result.Columns, IndexColumn{
Name: trimmedName,
Collate: strings.TrimSpace(colMatches[2]),
Sort: strings.ToUpper(colMatches[3]),
})
}
// WHERE expression
// ---
result.Where = strings.TrimSpace(matches[6])
return result
}
// FindSingleColumnUniqueIndex returns the first matching single column unique index.
func FindSingleColumnUniqueIndex(indexes []string, column string) (Index, bool) {
var index Index
for _, idx := range indexes {
index := ParseIndex(idx)
if index.Unique && len(index.Columns) == 1 && strings.EqualFold(index.Columns[0].Name, column) {
return index, true
}
}
return index, false
}
// Deprecated: Use `_, ok := FindSingleColumnUniqueIndex(indexes, column)` instead.
//
// HasColumnUniqueIndex loosely checks whether the specified column has
// a single column unique index (WHERE statements are ignored).
func HasSingleColumnUniqueIndex(column string, indexes []string) bool {
_, ok := FindSingleColumnUniqueIndex(indexes, column)
return ok
}

405
tools/dbutils/index_test.go Normal file
View file

@ -0,0 +1,405 @@
package dbutils_test
import (
"bytes"
"encoding/json"
"fmt"
"strings"
"testing"
"github.com/pocketbase/pocketbase/tools/dbutils"
)
func TestParseIndex(t *testing.T) {
scenarios := []struct {
index string
expected dbutils.Index
}{
// invalid
{
`invalid`,
dbutils.Index{},
},
// simple (multiple spaces between the table and columns list)
{
`create index indexname on tablename (col1)`,
dbutils.Index{
IndexName: "indexname",
TableName: "tablename",
Columns: []dbutils.IndexColumn{
{Name: "col1"},
},
},
},
// simple (no space between the table and the columns list)
{
`create index indexname on tablename(col1)`,
dbutils.Index{
IndexName: "indexname",
TableName: "tablename",
Columns: []dbutils.IndexColumn{
{Name: "col1"},
},
},
},
// all fields
{
`CREATE UNIQUE INDEX IF NOT EXISTS "schemaname".[indexname] on 'tablename' (
col0,
` + "`" + `col1` + "`" + `,
json_extract("col2", "$.a") asc,
"col3" collate NOCASE,
"col4" collate RTRIM desc
) where test = 1`,
dbutils.Index{
Unique: true,
Optional: true,
SchemaName: "schemaname",
IndexName: "indexname",
TableName: "tablename",
Columns: []dbutils.IndexColumn{
{Name: "col0"},
{Name: "col1"},
{Name: `json_extract("col2", "$.a")`, Sort: "ASC"},
{Name: `col3`, Collate: "NOCASE"},
{Name: `col4`, Collate: "RTRIM", Sort: "DESC"},
},
Where: "test = 1",
},
},
}
for i, s := range scenarios {
t.Run(fmt.Sprintf("scenario_%d", i), func(t *testing.T) {
result := dbutils.ParseIndex(s.index)
resultRaw, err := json.Marshal(result)
if err != nil {
t.Fatalf("Faild to marshalize parse result: %v", err)
}
expectedRaw, err := json.Marshal(s.expected)
if err != nil {
t.Fatalf("Failed to marshalize expected index: %v", err)
}
if !bytes.Equal(resultRaw, expectedRaw) {
t.Errorf("Expected \n%s \ngot \n%s", expectedRaw, resultRaw)
}
})
}
}
func TestIndexIsValid(t *testing.T) {
scenarios := []struct {
name string
index dbutils.Index
expected bool
}{
{
"empty",
dbutils.Index{},
false,
},
{
"no index name",
dbutils.Index{
TableName: "table",
Columns: []dbutils.IndexColumn{{Name: "col"}},
},
false,
},
{
"no table name",
dbutils.Index{
IndexName: "index",
Columns: []dbutils.IndexColumn{{Name: "col"}},
},
false,
},
{
"no columns",
dbutils.Index{
IndexName: "index",
TableName: "table",
},
false,
},
{
"min valid",
dbutils.Index{
IndexName: "index",
TableName: "table",
Columns: []dbutils.IndexColumn{{Name: "col"}},
},
true,
},
{
"all fields",
dbutils.Index{
Optional: true,
Unique: true,
SchemaName: "schema",
IndexName: "index",
TableName: "table",
Columns: []dbutils.IndexColumn{{Name: "col"}},
Where: "test = 1 OR test = 2",
},
true,
},
}
for _, s := range scenarios {
t.Run(s.name, func(t *testing.T) {
result := s.index.IsValid()
if result != s.expected {
t.Fatalf("Expected %v, got %v", s.expected, result)
}
})
}
}
func TestIndexBuild(t *testing.T) {
scenarios := []struct {
name string
index dbutils.Index
expected string
}{
{
"empty",
dbutils.Index{},
"",
},
{
"no index name",
dbutils.Index{
TableName: "table",
Columns: []dbutils.IndexColumn{{Name: "col"}},
},
"",
},
{
"no table name",
dbutils.Index{
IndexName: "index",
Columns: []dbutils.IndexColumn{{Name: "col"}},
},
"",
},
{
"no columns",
dbutils.Index{
IndexName: "index",
TableName: "table",
},
"",
},
{
"min valid",
dbutils.Index{
IndexName: "index",
TableName: "table",
Columns: []dbutils.IndexColumn{{Name: "col"}},
},
"CREATE INDEX `index` ON `table` (`col`)",
},
{
"all fields",
dbutils.Index{
Optional: true,
Unique: true,
SchemaName: "schema",
IndexName: "index",
TableName: "table",
Columns: []dbutils.IndexColumn{
{Name: "col1", Collate: "NOCASE", Sort: "asc"},
{Name: "col2", Sort: "desc"},
{Name: `json_extract("col3", "$.a")`, Collate: "NOCASE"},
},
Where: "test = 1 OR test = 2",
},
"CREATE UNIQUE INDEX IF NOT EXISTS `schema`.`index` ON `table` (\n `col1` COLLATE NOCASE ASC,\n `col2` DESC,\n " + `json_extract("col3", "$.a")` + " COLLATE NOCASE\n) WHERE test = 1 OR test = 2",
},
}
for _, s := range scenarios {
t.Run(s.name, func(t *testing.T) {
result := s.index.Build()
if result != s.expected {
t.Fatalf("Expected \n%v \ngot \n%v", s.expected, result)
}
})
}
}
func TestHasSingleColumnUniqueIndex(t *testing.T) {
scenarios := []struct {
name string
column string
indexes []string
expected bool
}{
{
"empty indexes",
"test",
nil,
false,
},
{
"empty column",
"",
[]string{
"CREATE UNIQUE INDEX `index1` ON `example` (`test`)",
},
false,
},
{
"mismatched column",
"test",
[]string{
"CREATE UNIQUE INDEX `index1` ON `example` (`test2`)",
},
false,
},
{
"non unique index",
"test",
[]string{
"CREATE INDEX `index1` ON `example` (`test`)",
},
false,
},
{
"matching columnd and unique index",
"test",
[]string{
"CREATE UNIQUE INDEX `index1` ON `example` (`test`)",
},
true,
},
{
"multiple columns",
"test",
[]string{
"CREATE UNIQUE INDEX `index1` ON `example` (`test`, `test2`)",
},
false,
},
{
"multiple indexes",
"test",
[]string{
"CREATE UNIQUE INDEX `index1` ON `example` (`test`, `test2`)",
"CREATE UNIQUE INDEX `index2` ON `example` (`test`)",
},
true,
},
{
"partial unique index",
"test",
[]string{
"CREATE UNIQUE INDEX `index` ON `example` (`test`) where test != ''",
},
true,
},
}
for _, s := range scenarios {
t.Run(s.name, func(t *testing.T) {
result := dbutils.HasSingleColumnUniqueIndex(s.column, s.indexes)
if result != s.expected {
t.Fatalf("Expected %v got %v", s.expected, result)
}
})
}
}
func TestFindSingleColumnUniqueIndex(t *testing.T) {
scenarios := []struct {
name string
column string
indexes []string
expected bool
}{
{
"empty indexes",
"test",
nil,
false,
},
{
"empty column",
"",
[]string{
"CREATE UNIQUE INDEX `index1` ON `example` (`test`)",
},
false,
},
{
"mismatched column",
"test",
[]string{
"CREATE UNIQUE INDEX `index1` ON `example` (`test2`)",
},
false,
},
{
"non unique index",
"test",
[]string{
"CREATE INDEX `index1` ON `example` (`test`)",
},
false,
},
{
"matching columnd and unique index",
"test",
[]string{
"CREATE UNIQUE INDEX `index1` ON `example` (`test`)",
},
true,
},
{
"multiple columns",
"test",
[]string{
"CREATE UNIQUE INDEX `index1` ON `example` (`test`, `test2`)",
},
false,
},
{
"multiple indexes",
"test",
[]string{
"CREATE UNIQUE INDEX `index1` ON `example` (`test`, `test2`)",
"CREATE UNIQUE INDEX `index2` ON `example` (`test`)",
},
true,
},
{
"partial unique index",
"test",
[]string{
"CREATE UNIQUE INDEX `index` ON `example` (`test`) where test != ''",
},
true,
},
}
for _, s := range scenarios {
t.Run(s.name, func(t *testing.T) {
index, exists := dbutils.FindSingleColumnUniqueIndex(s.indexes, s.column)
if exists != s.expected {
t.Fatalf("Expected exists %v got %v", s.expected, exists)
}
if !exists && len(index.Columns) > 0 {
t.Fatal("Expected index.Columns to be empty")
}
if exists && !strings.EqualFold(index.Columns[0].Name, s.column) {
t.Fatalf("Expected to find column %q in %v", s.column, index)
}
})
}
}

51
tools/dbutils/json.go Normal file
View file

@ -0,0 +1,51 @@
package dbutils
import (
"fmt"
"strings"
)
// JSONEach returns JSON_EACH SQLite string expression with
// some normalizations for non-json columns.
func JSONEach(column string) string {
// note: we are not using the new and shorter "if(x,y)" syntax for
// compatability with custom drivers that use older SQLite version
return fmt.Sprintf(
`json_each(CASE WHEN iif(json_valid([[%s]]), json_type([[%s]])='array', FALSE) THEN [[%s]] ELSE json_array([[%s]]) END)`,
column, column, column, column,
)
}
// JSONArrayLength returns JSON_ARRAY_LENGTH SQLite string expression
// with some normalizations for non-json columns.
//
// It works with both json and non-json column values.
//
// Returns 0 for empty string or NULL column values.
func JSONArrayLength(column string) string {
// note: we are not using the new and shorter "if(x,y)" syntax for
// compatability with custom drivers that use older SQLite version
return fmt.Sprintf(
`json_array_length(CASE WHEN iif(json_valid([[%s]]), json_type([[%s]])='array', FALSE) THEN [[%s]] ELSE (CASE WHEN [[%s]] = '' OR [[%s]] IS NULL THEN json_array() ELSE json_array([[%s]]) END) END)`,
column, column, column, column, column, column,
)
}
// JSONExtract returns a JSON_EXTRACT SQLite string expression with
// some normalizations for non-json columns.
func JSONExtract(column string, path string) string {
// prefix the path with dot if it is not starting with array notation
if path != "" && !strings.HasPrefix(path, "[") {
path = "." + path
}
return fmt.Sprintf(
// note: the extra object wrapping is needed to workaround the cases where a json_extract is used with non-json columns.
"(CASE WHEN json_valid([[%s]]) THEN JSON_EXTRACT([[%s]], '$%s') ELSE JSON_EXTRACT(json_object('pb', [[%s]]), '$.pb%s') END)",
column,
column,
path,
column,
path,
)
}

View file

@ -0,0 +1,65 @@
package dbutils_test
import (
"testing"
"github.com/pocketbase/pocketbase/tools/dbutils"
)
func TestJSONEach(t *testing.T) {
result := dbutils.JSONEach("a.b")
expected := "json_each(CASE WHEN iif(json_valid([[a.b]]), json_type([[a.b]])='array', FALSE) THEN [[a.b]] ELSE json_array([[a.b]]) END)"
if result != expected {
t.Fatalf("Expected\n%v\ngot\n%v", expected, result)
}
}
func TestJSONArrayLength(t *testing.T) {
result := dbutils.JSONArrayLength("a.b")
expected := "json_array_length(CASE WHEN iif(json_valid([[a.b]]), json_type([[a.b]])='array', FALSE) THEN [[a.b]] ELSE (CASE WHEN [[a.b]] = '' OR [[a.b]] IS NULL THEN json_array() ELSE json_array([[a.b]]) END) END)"
if result != expected {
t.Fatalf("Expected\n%v\ngot\n%v", expected, result)
}
}
func TestJSONExtract(t *testing.T) {
scenarios := []struct {
name string
column string
path string
expected string
}{
{
"empty path",
"a.b",
"",
"(CASE WHEN json_valid([[a.b]]) THEN JSON_EXTRACT([[a.b]], '$') ELSE JSON_EXTRACT(json_object('pb', [[a.b]]), '$.pb') END)",
},
{
"starting with array index",
"a.b",
"[1].a[2]",
"(CASE WHEN json_valid([[a.b]]) THEN JSON_EXTRACT([[a.b]], '$[1].a[2]') ELSE JSON_EXTRACT(json_object('pb', [[a.b]]), '$.pb[1].a[2]') END)",
},
{
"starting with key",
"a.b",
"a.b[2].c",
"(CASE WHEN json_valid([[a.b]]) THEN JSON_EXTRACT([[a.b]], '$.a.b[2].c') ELSE JSON_EXTRACT(json_object('pb', [[a.b]]), '$.pb.a.b[2].c') END)",
},
}
for _, s := range scenarios {
t.Run(s.name, func(t *testing.T) {
result := dbutils.JSONExtract(s.column, s.path)
if result != s.expected {
t.Fatalf("Expected\n%v\ngot\n%v", s.expected, result)
}
})
}
}

View file

@ -0,0 +1,736 @@
// Package blob defines a lightweight abstration for interacting with
// various storage services (local filesystem, S3, etc.).
//
// NB!
// For compatibility with earlier PocketBase versions and to prevent
// unnecessary breaking changes, this package is based and implemented
// as a minimal, stripped down version of the previously used gocloud.dev/blob.
// While there is no promise that it won't diverge in the future to accommodate
// better some PocketBase specific use cases, currently it copies and
// tries to follow as close as possible the same implementations,
// conventions and rules for the key escaping/unescaping, blob read/write
// interfaces and struct options as gocloud.dev/blob, therefore the
// credits goes to the original Go Cloud Development Kit Authors.
package blob
import (
"bytes"
"context"
"crypto/md5"
"errors"
"fmt"
"io"
"log"
"mime"
"runtime"
"strings"
"sync"
"time"
"unicode/utf8"
)
var (
ErrNotFound = errors.New("resource not found")
ErrClosed = errors.New("bucket or blob is closed")
)
// Bucket provides an easy and portable way to interact with blobs
// within a "bucket", including read, write, and list operations.
// To create a Bucket, use constructors found in driver subpackages.
type Bucket struct {
drv Driver
// mu protects the closed variable.
// Read locks are kept to allow holding a read lock for long-running calls,
// and thereby prevent closing until a call finishes.
mu sync.RWMutex
closed bool
}
// NewBucket creates a new *Bucket based on a specific driver implementation.
func NewBucket(drv Driver) *Bucket {
return &Bucket{drv: drv}
}
// ListOptions sets options for listing blobs via Bucket.List.
type ListOptions struct {
// Prefix indicates that only blobs with a key starting with this prefix
// should be returned.
Prefix string
// Delimiter sets the delimiter used to define a hierarchical namespace,
// like a filesystem with "directories". It is highly recommended that you
// use "" or "/" as the Delimiter. Other values should work through this API,
// but service UIs generally assume "/".
//
// An empty delimiter means that the bucket is treated as a single flat
// namespace.
//
// A non-empty delimiter means that any result with the delimiter in its key
// after Prefix is stripped will be returned with ListObject.IsDir = true,
// ListObject.Key truncated after the delimiter, and zero values for other
// ListObject fields. These results represent "directories". Multiple results
// in a "directory" are returned as a single result.
Delimiter string
// PageSize sets the maximum number of objects to be returned.
// 0 means no maximum; driver implementations should choose a reasonable
// max. It is guaranteed to be >= 0.
PageSize int
// PageToken may be filled in with the NextPageToken from a previous
// ListPaged call.
PageToken []byte
}
// ListPage represents a page of results return from ListPaged.
type ListPage struct {
// Objects is the slice of objects found. If ListOptions.PageSize > 0,
// it should have at most ListOptions.PageSize entries.
//
// Objects should be returned in lexicographical order of UTF-8 encoded keys,
// including across pages. I.e., all objects returned from a ListPage request
// made using a PageToken from a previous ListPage request's NextPageToken
// should have Key >= the Key for all objects from the previous request.
Objects []*ListObject `json:"objects"`
// NextPageToken should be left empty unless there are more objects
// to return. The value may be returned as ListOptions.PageToken on a
// subsequent ListPaged call, to fetch the next page of results.
// It can be an arbitrary []byte; it need not be a valid key.
NextPageToken []byte `json:"nextPageToken"`
}
// ListIterator iterates over List results.
type ListIterator struct {
b *Bucket
opts *ListOptions
page *ListPage
nextIdx int
}
// Next returns a *ListObject for the next blob.
// It returns (nil, io.EOF) if there are no more.
func (i *ListIterator) Next(ctx context.Context) (*ListObject, error) {
if i.page != nil {
// We've already got a page of results.
if i.nextIdx < len(i.page.Objects) {
// Next object is in the page; return it.
dobj := i.page.Objects[i.nextIdx]
i.nextIdx++
return &ListObject{
Key: dobj.Key,
ModTime: dobj.ModTime,
Size: dobj.Size,
MD5: dobj.MD5,
IsDir: dobj.IsDir,
}, nil
}
if len(i.page.NextPageToken) == 0 {
// Done with current page, and there are no more; return io.EOF.
return nil, io.EOF
}
// We need to load the next page.
i.opts.PageToken = i.page.NextPageToken
}
i.b.mu.RLock()
defer i.b.mu.RUnlock()
if i.b.closed {
return nil, ErrClosed
}
// Loading a new page.
p, err := i.b.drv.ListPaged(ctx, i.opts)
if err != nil {
return nil, wrapError(i.b.drv, err, "")
}
i.page = p
i.nextIdx = 0
return i.Next(ctx)
}
// ListObject represents a single blob returned from List.
type ListObject struct {
// Key is the key for this blob.
Key string `json:"key"`
// ModTime is the time the blob was last modified.
ModTime time.Time `json:"modTime"`
// Size is the size of the blob's content in bytes.
Size int64 `json:"size"`
// MD5 is an MD5 hash of the blob contents or nil if not available.
MD5 []byte `json:"md5"`
// IsDir indicates that this result represents a "directory" in the
// hierarchical namespace, ending in ListOptions.Delimiter. Key can be
// passed as ListOptions.Prefix to list items in the "directory".
// Fields other than Key and IsDir will not be set if IsDir is true.
IsDir bool `json:"isDir"`
}
// List returns a ListIterator that can be used to iterate over blobs in a
// bucket, in lexicographical order of UTF-8 encoded keys. The underlying
// implementation fetches results in pages.
//
// A nil ListOptions is treated the same as the zero value.
//
// List is not guaranteed to include all recently-written blobs;
// some services are only eventually consistent.
func (b *Bucket) List(opts *ListOptions) *ListIterator {
if opts == nil {
opts = &ListOptions{}
}
dopts := &ListOptions{
Prefix: opts.Prefix,
Delimiter: opts.Delimiter,
}
return &ListIterator{b: b, opts: dopts}
}
// FirstPageToken is the pageToken to pass to ListPage to retrieve the first page of results.
var FirstPageToken = []byte("first page")
// ListPage returns a page of ListObject results for blobs in a bucket, in lexicographical
// order of UTF-8 encoded keys.
//
// To fetch the first page, pass FirstPageToken as the pageToken. For subsequent pages, pass
// the pageToken returned from a previous call to ListPage.
// It is not possible to "skip ahead" pages.
//
// Each call will return pageSize results, unless there are not enough blobs to fill the
// page, in which case it will return fewer results (possibly 0).
//
// If there are no more blobs available, ListPage will return an empty pageToken. Note that
// this may happen regardless of the number of returned results -- the last page might have
// 0 results (i.e., if the last item was deleted), pageSize results, or anything in between.
//
// Calling ListPage with an empty pageToken will immediately return io.EOF. When looping
// over pages, callers can either check for an empty pageToken, or they can make one more
// call and check for io.EOF.
//
// The underlying implementation fetches results in pages, but one call to ListPage may
// require multiple page fetches (and therefore, multiple calls to the BeforeList callback).
//
// A nil ListOptions is treated the same as the zero value.
//
// ListPage is not guaranteed to include all recently-written blobs;
// some services are only eventually consistent.
func (b *Bucket) ListPage(ctx context.Context, pageToken []byte, pageSize int, opts *ListOptions) (retval []*ListObject, nextPageToken []byte, err error) {
if opts == nil {
opts = &ListOptions{}
}
if pageSize <= 0 {
return nil, nil, fmt.Errorf("pageSize must be > 0 (%d)", pageSize)
}
// Nil pageToken means no more results.
if len(pageToken) == 0 {
return nil, nil, io.EOF
}
// FirstPageToken fetches the first page. Drivers use nil.
// The public API doesn't use nil for the first page because it would be too easy to
// keep fetching forever (since the last page return nil for the next pageToken).
if bytes.Equal(pageToken, FirstPageToken) {
pageToken = nil
}
b.mu.RLock()
defer b.mu.RUnlock()
if b.closed {
return nil, nil, ErrClosed
}
dopts := &ListOptions{
Prefix: opts.Prefix,
Delimiter: opts.Delimiter,
PageToken: pageToken,
PageSize: pageSize,
}
retval = make([]*ListObject, 0, pageSize)
for len(retval) < pageSize {
p, err := b.drv.ListPaged(ctx, dopts)
if err != nil {
return nil, nil, wrapError(b.drv, err, "")
}
for _, dobj := range p.Objects {
retval = append(retval, &ListObject{
Key: dobj.Key,
ModTime: dobj.ModTime,
Size: dobj.Size,
MD5: dobj.MD5,
IsDir: dobj.IsDir,
})
}
// ListPaged may return fewer results than pageSize. If there are more results
// available, signalled by non-empty p.NextPageToken, try to fetch the remainder
// of the page.
// It does not work to ask for more results than we need, because then we'd have
// a NextPageToken on a non-page boundary.
dopts.PageSize = pageSize - len(retval)
dopts.PageToken = p.NextPageToken
if len(dopts.PageToken) == 0 {
dopts.PageToken = nil
break
}
}
return retval, dopts.PageToken, nil
}
// Attributes contains attributes about a blob.
type Attributes struct {
// CacheControl specifies caching attributes that services may use
// when serving the blob.
// https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Cache-Control
CacheControl string `json:"cacheControl"`
// ContentDisposition specifies whether the blob content is expected to be
// displayed inline or as an attachment.
// https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Content-Disposition
ContentDisposition string `json:"contentDisposition"`
// ContentEncoding specifies the encoding used for the blob's content, if any.
// https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Content-Encoding
ContentEncoding string `json:"contentEncoding"`
// ContentLanguage specifies the language used in the blob's content, if any.
// https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Content-Language
ContentLanguage string `json:"contentLanguage"`
// ContentType is the MIME type of the blob. It will not be empty.
// https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Content-Type
ContentType string `json:"contentType"`
// Metadata holds key/value pairs associated with the blob.
// Keys are guaranteed to be in lowercase, even if the backend service
// has case-sensitive keys (although note that Metadata written via
// this package will always be lowercased). If there are duplicate
// case-insensitive keys (e.g., "foo" and "FOO"), only one value
// will be kept, and it is undefined which one.
Metadata map[string]string `json:"metadata"`
// CreateTime is the time the blob was created, if available. If not available,
// CreateTime will be the zero time.
CreateTime time.Time `json:"createTime"`
// ModTime is the time the blob was last modified.
ModTime time.Time `json:"modTime"`
// Size is the size of the blob's content in bytes.
Size int64 `json:"size"`
// MD5 is an MD5 hash of the blob contents or nil if not available.
MD5 []byte `json:"md5"`
// ETag for the blob; see https://en.wikipedia.org/wiki/HTTP_ETag.
ETag string `json:"etag"`
}
// Attributes returns attributes for the blob stored at key.
//
// If the blob does not exist, Attributes returns an error for which
// gcerrors.Code will return gcerrors.NotFound.
func (b *Bucket) Attributes(ctx context.Context, key string) (_ *Attributes, err error) {
if !utf8.ValidString(key) {
return nil, fmt.Errorf("Attributes key must be a valid UTF-8 string: %q", key)
}
b.mu.RLock()
defer b.mu.RUnlock()
if b.closed {
return nil, ErrClosed
}
a, err := b.drv.Attributes(ctx, key)
if err != nil {
return nil, wrapError(b.drv, err, key)
}
var md map[string]string
if len(a.Metadata) > 0 {
// Services are inconsistent, but at least some treat keys
// as case-insensitive. To make the behavior consistent, we
// force-lowercase them when writing and reading.
md = make(map[string]string, len(a.Metadata))
for k, v := range a.Metadata {
md[strings.ToLower(k)] = v
}
}
return &Attributes{
CacheControl: a.CacheControl,
ContentDisposition: a.ContentDisposition,
ContentEncoding: a.ContentEncoding,
ContentLanguage: a.ContentLanguage,
ContentType: a.ContentType,
Metadata: md,
CreateTime: a.CreateTime,
ModTime: a.ModTime,
Size: a.Size,
MD5: a.MD5,
ETag: a.ETag,
}, nil
}
// Exists returns true if a blob exists at key, false if it does not exist, or
// an error.
//
// It is a shortcut for calling Attributes and checking if it returns an error
// with code ErrNotFound.
func (b *Bucket) Exists(ctx context.Context, key string) (bool, error) {
_, err := b.Attributes(ctx, key)
if err == nil {
return true, nil
}
if errors.Is(err, ErrNotFound) {
return false, nil
}
return false, err
}
// NewReader is a shortcut for NewRangeReader with offset=0 and length=-1.
func (b *Bucket) NewReader(ctx context.Context, key string) (*Reader, error) {
return b.newRangeReader(ctx, key, 0, -1)
}
// NewRangeReader returns a Reader to read content from the blob stored at key.
// It reads at most length bytes starting at offset (>= 0).
// If length is negative, it will read till the end of the blob.
//
// For the purposes of Seek, the returned Reader will start at offset and
// end at the minimum of the actual end of the blob or (if length > 0) offset + length.
//
// Note that ctx is used for all reads performed during the lifetime of the reader.
//
// If the blob does not exist, NewRangeReader returns an error for which
// gcerrors.Code will return gcerrors.NotFound. Exists is a lighter-weight way
// to check for existence.
//
// A nil ReaderOptions is treated the same as the zero value.
//
// The caller must call Close on the returned Reader when done reading.
func (b *Bucket) NewRangeReader(ctx context.Context, key string, offset, length int64) (_ *Reader, err error) {
return b.newRangeReader(ctx, key, offset, length)
}
func (b *Bucket) newRangeReader(ctx context.Context, key string, offset, length int64) (_ *Reader, err error) {
b.mu.RLock()
defer b.mu.RUnlock()
if b.closed {
return nil, ErrClosed
}
if offset < 0 {
return nil, fmt.Errorf("NewRangeReader offset must be non-negative (%d)", offset)
}
if !utf8.ValidString(key) {
return nil, fmt.Errorf("NewRangeReader key must be a valid UTF-8 string: %q", key)
}
var dr DriverReader
dr, err = b.drv.NewRangeReader(ctx, key, offset, length)
if err != nil {
return nil, wrapError(b.drv, err, key)
}
r := &Reader{
drv: b.drv,
r: dr,
key: key,
ctx: ctx,
baseOffset: offset,
baseLength: length,
savedOffset: -1,
}
_, file, lineno, ok := runtime.Caller(2)
runtime.SetFinalizer(r, func(r *Reader) {
if !r.closed {
var caller string
if ok {
caller = fmt.Sprintf(" (%s:%d)", file, lineno)
}
log.Printf("A blob.Reader reading from %q was never closed%s", key, caller)
}
})
return r, nil
}
// WriterOptions sets options for NewWriter.
type WriterOptions struct {
// BufferSize changes the default size in bytes of the chunks that
// Writer will upload in a single request; larger blobs will be split into
// multiple requests.
//
// This option may be ignored by some drivers.
//
// If 0, the driver will choose a reasonable default.
//
// If the Writer is used to do many small writes concurrently, using a
// smaller BufferSize may reduce memory usage.
BufferSize int
// MaxConcurrency changes the default concurrency for parts of an upload.
//
// This option may be ignored by some drivers.
//
// If 0, the driver will choose a reasonable default.
MaxConcurrency int
// CacheControl specifies caching attributes that services may use
// when serving the blob.
// https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Cache-Control
CacheControl string
// ContentDisposition specifies whether the blob content is expected to be
// displayed inline or as an attachment.
// https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Content-Disposition
ContentDisposition string
// ContentEncoding specifies the encoding used for the blob's content, if any.
// https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Content-Encoding
ContentEncoding string
// ContentLanguage specifies the language used in the blob's content, if any.
// https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Content-Language
ContentLanguage string
// ContentType specifies the MIME type of the blob being written. If not set,
// it will be inferred from the content using the algorithm described at
// http://mimesniff.spec.whatwg.org/.
// Set DisableContentTypeDetection to true to disable the above and force
// the ContentType to stay empty.
// https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Content-Type
ContentType string
// When true, if ContentType is the empty string, it will stay the empty
// string rather than being inferred from the content.
// Note that while the blob will be written with an empty string ContentType,
// most providers will fill one in during reads, so don't expect an empty
// ContentType if you read the blob back.
DisableContentTypeDetection bool
// ContentMD5 is used as a message integrity check.
// If len(ContentMD5) > 0, the MD5 hash of the bytes written must match
// ContentMD5, or Close will return an error without completing the write.
// https://tools.ietf.org/html/rfc1864
ContentMD5 []byte
// Metadata holds key/value strings to be associated with the blob, or nil.
// Keys may not be empty, and are lowercased before being written.
// Duplicate case-insensitive keys (e.g., "foo" and "FOO") will result in
// an error.
Metadata map[string]string
}
// NewWriter returns a Writer that writes to the blob stored at key.
// A nil WriterOptions is treated the same as the zero value.
//
// If a blob with this key already exists, it will be replaced.
// The blob being written is not guaranteed to be readable until Close
// has been called; until then, any previous blob will still be readable.
// Even after Close is called, newly written blobs are not guaranteed to be
// returned from List; some services are only eventually consistent.
//
// The returned Writer will store ctx for later use in Write and/or Close.
// To abort a write, cancel ctx; otherwise, it must remain open until
// Close is called.
//
// The caller must call Close on the returned Writer, even if the write is
// aborted.
func (b *Bucket) NewWriter(ctx context.Context, key string, opts *WriterOptions) (_ *Writer, err error) {
if !utf8.ValidString(key) {
return nil, fmt.Errorf("NewWriter key must be a valid UTF-8 string: %q", key)
}
if opts == nil {
opts = &WriterOptions{}
}
dopts := &WriterOptions{
CacheControl: opts.CacheControl,
ContentDisposition: opts.ContentDisposition,
ContentEncoding: opts.ContentEncoding,
ContentLanguage: opts.ContentLanguage,
ContentMD5: opts.ContentMD5,
BufferSize: opts.BufferSize,
MaxConcurrency: opts.MaxConcurrency,
DisableContentTypeDetection: opts.DisableContentTypeDetection,
}
if len(opts.Metadata) > 0 {
// Services are inconsistent, but at least some treat keys
// as case-insensitive. To make the behavior consistent, we
// force-lowercase them when writing and reading.
md := make(map[string]string, len(opts.Metadata))
for k, v := range opts.Metadata {
if k == "" {
return nil, errors.New("WriterOptions.Metadata keys may not be empty strings")
}
if !utf8.ValidString(k) {
return nil, fmt.Errorf("WriterOptions.Metadata keys must be valid UTF-8 strings: %q", k)
}
if !utf8.ValidString(v) {
return nil, fmt.Errorf("WriterOptions.Metadata values must be valid UTF-8 strings: %q", v)
}
lowerK := strings.ToLower(k)
if _, found := md[lowerK]; found {
return nil, fmt.Errorf("WriterOptions.Metadata has a duplicate case-insensitive metadata key: %q", lowerK)
}
md[lowerK] = v
}
dopts.Metadata = md
}
b.mu.RLock()
defer b.mu.RUnlock()
if b.closed {
return nil, ErrClosed
}
ctx, cancel := context.WithCancel(ctx)
w := &Writer{
drv: b.drv,
cancel: cancel,
key: key,
contentMD5: opts.ContentMD5,
md5hash: md5.New(),
}
if opts.ContentType != "" || opts.DisableContentTypeDetection {
var ct string
if opts.ContentType != "" {
t, p, err := mime.ParseMediaType(opts.ContentType)
if err != nil {
cancel()
return nil, err
}
ct = mime.FormatMediaType(t, p)
}
dw, err := b.drv.NewTypedWriter(ctx, key, ct, dopts)
if err != nil {
cancel()
return nil, wrapError(b.drv, err, key)
}
w.w = dw
} else {
// Save the fields needed to called NewTypedWriter later, once we've gotten
// sniffLen bytes; see the comment on Writer.
w.ctx = ctx
w.opts = dopts
w.buf = bytes.NewBuffer([]byte{})
}
_, file, lineno, ok := runtime.Caller(1)
runtime.SetFinalizer(w, func(w *Writer) {
if !w.closed {
var caller string
if ok {
caller = fmt.Sprintf(" (%s:%d)", file, lineno)
}
log.Printf("A blob.Writer writing to %q was never closed%s", key, caller)
}
})
return w, nil
}
// Copy the blob stored at srcKey to dstKey.
// A nil CopyOptions is treated the same as the zero value.
//
// If the source blob does not exist, Copy returns an error for which
// gcerrors.Code will return gcerrors.NotFound.
//
// If the destination blob already exists, it is overwritten.
func (b *Bucket) Copy(ctx context.Context, dstKey, srcKey string) (err error) {
if !utf8.ValidString(srcKey) {
return fmt.Errorf("Copy srcKey must be a valid UTF-8 string: %q", srcKey)
}
if !utf8.ValidString(dstKey) {
return fmt.Errorf("Copy dstKey must be a valid UTF-8 string: %q", dstKey)
}
b.mu.RLock()
defer b.mu.RUnlock()
if b.closed {
return ErrClosed
}
return wrapError(b.drv, b.drv.Copy(ctx, dstKey, srcKey), fmt.Sprintf("%s -> %s", srcKey, dstKey))
}
// Delete deletes the blob stored at key.
//
// If the blob does not exist, Delete returns an error for which
// gcerrors.Code will return gcerrors.NotFound.
func (b *Bucket) Delete(ctx context.Context, key string) (err error) {
if !utf8.ValidString(key) {
return fmt.Errorf("Delete key must be a valid UTF-8 string: %q", key)
}
b.mu.RLock()
defer b.mu.RUnlock()
if b.closed {
return ErrClosed
}
return wrapError(b.drv, b.drv.Delete(ctx, key), key)
}
// Close releases any resources used for the bucket.
//
// @todo Consider removing it.
func (b *Bucket) Close() error {
b.mu.Lock()
prev := b.closed
b.closed = true
b.mu.Unlock()
if prev {
return ErrClosed
}
return wrapError(b.drv, b.drv.Close(), "")
}
func wrapError(b Driver, err error, key string) error {
if err == nil {
return nil
}
// don't wrap or normalize EOF errors since there are many places
// in the standard library (e.g. io.ReadAll) that rely on checks
// such as "err == io.EOF" and they will fail
if errors.Is(err, io.EOF) {
return err
}
err = b.NormalizeError(err)
if key != "" {
err = fmt.Errorf("[key: %s] %w", key, err)
}
return err
}

View file

@ -0,0 +1,107 @@
package blob
import (
"context"
"io"
"time"
)
// ReaderAttributes contains a subset of attributes about a blob that are
// accessible from Reader.
type ReaderAttributes struct {
// ContentType is the MIME type of the blob object. It must not be empty.
// https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Content-Type
ContentType string `json:"contentType"`
// ModTime is the time the blob object was last modified.
ModTime time.Time `json:"modTime"`
// Size is the size of the object in bytes.
Size int64 `json:"size"`
}
// DriverReader reads an object from the blob.
type DriverReader interface {
io.ReadCloser
// Attributes returns a subset of attributes about the blob.
// The portable type will not modify the returned ReaderAttributes.
Attributes() *ReaderAttributes
}
// DriverWriter writes an object to the blob.
type DriverWriter interface {
io.WriteCloser
}
// Driver provides read, write and delete operations on objects within it on the
// blob service.
type Driver interface {
NormalizeError(err error) error
// Attributes returns attributes for the blob. If the specified object does
// not exist, Attributes must return an error for which ErrorCode returns ErrNotFound.
// The portable type will not modify the returned Attributes.
Attributes(ctx context.Context, key string) (*Attributes, error)
// ListPaged lists objects in the bucket, in lexicographical order by
// UTF-8-encoded key, returning pages of objects at a time.
// Services are only required to be eventually consistent with respect
// to recently written or deleted objects. That is to say, there is no
// guarantee that an object that's been written will immediately be returned
// from ListPaged.
// opts is guaranteed to be non-nil.
ListPaged(ctx context.Context, opts *ListOptions) (*ListPage, error)
// NewRangeReader returns a Reader that reads part of an object, reading at
// most length bytes starting at the given offset. If length is negative, it
// will read until the end of the object. If the specified object does not
// exist, NewRangeReader must return an error for which ErrorCode returns ErrNotFound.
// opts is guaranteed to be non-nil.
//
// The returned Reader *may* also implement Downloader if the underlying
// implementation can take advantage of that. The Download call is guaranteed
// to be the only call to the Reader. For such readers, offset will always
// be 0 and length will always be -1.
NewRangeReader(ctx context.Context, key string, offset, length int64) (DriverReader, error)
// NewTypedWriter returns Writer that writes to an object associated with key.
//
// A new object will be created unless an object with this key already exists.
// Otherwise any previous object with the same key will be replaced.
// The object may not be available (and any previous object will remain)
// until Close has been called.
//
// contentType sets the MIME type of the object to be written.
// opts is guaranteed to be non-nil.
//
// The caller must call Close on the returned Writer when done writing.
//
// Implementations should abort an ongoing write if ctx is later canceled,
// and do any necessary cleanup in Close. Close should then return ctx.Err().
//
// The returned Writer *may* also implement Uploader if the underlying
// implementation can take advantage of that. The Upload call is guaranteed
// to be the only non-Close call to the Writer..
NewTypedWriter(ctx context.Context, key, contentType string, opts *WriterOptions) (DriverWriter, error)
// Copy copies the object associated with srcKey to dstKey.
//
// If the source object does not exist, Copy must return an error for which
// ErrorCode returns ErrNotFound.
//
// If the destination object already exists, it should be overwritten.
//
// opts is guaranteed to be non-nil.
Copy(ctx context.Context, dstKey, srcKey string) error
// Delete deletes the object associated with key. If the specified object does
// not exist, Delete must return an error for which ErrorCode returns ErrNotFound.
Delete(ctx context.Context, key string) error
// Close cleans up any resources used by the Bucket. Once Close is called,
// there will be no method calls to the Bucket other than As, ErrorAs, and
// ErrorCode. There may be open readers or writers that will receive calls.
// It is up to the driver as to how these will be handled.
Close() error
}

View file

@ -0,0 +1,153 @@
package blob
import (
"fmt"
"strconv"
)
// Copied from gocloud.dev/blob to avoid nuances around the specific
// HEX escaping/unescaping rules.
//
// -------------------------------------------------------------------
// Copyright 2019 The Go Cloud Development Kit Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// -------------------------------------------------------------------
// HexEscape returns s, with all runes for which shouldEscape returns true
// escaped to "__0xXXX__", where XXX is the hex representation of the rune
// value. For example, " " would escape to "__0x20__".
//
// Non-UTF-8 strings will have their non-UTF-8 characters escaped to
// unicode.ReplacementChar; the original value is lost. Please file an
// issue if you need non-UTF8 support.
//
// Note: shouldEscape takes the whole string as a slice of runes and an
// index. Passing it a single byte or a single rune doesn't provide
// enough context for some escape decisions; for example, the caller might
// want to escape the second "/" in "//" but not the first one.
// We pass a slice of runes instead of the string or a slice of bytes
// because some decisions will be made on a rune basis (e.g., encode
// all non-ASCII runes).
func HexEscape(s string, shouldEscape func(s []rune, i int) bool) string {
// Do a first pass to see which runes (if any) need escaping.
runes := []rune(s)
var toEscape []int
for i := range runes {
if shouldEscape(runes, i) {
toEscape = append(toEscape, i)
}
}
if len(toEscape) == 0 {
return s
}
// Each escaped rune turns into at most 14 runes ("__0x7fffffff__"),
// so allocate an extra 13 for each. We'll reslice at the end
// if we didn't end up using them.
escaped := make([]rune, len(runes)+13*len(toEscape))
n := 0 // current index into toEscape
j := 0 // current index into escaped
for i, r := range runes {
if n < len(toEscape) && i == toEscape[n] {
// We were asked to escape this rune.
for _, x := range fmt.Sprintf("__%#x__", r) {
escaped[j] = x
j++
}
n++
} else {
escaped[j] = r
j++
}
}
return string(escaped[0:j])
}
// unescape tries to unescape starting at r[i].
// It returns a boolean indicating whether the unescaping was successful,
// and (if true) the unescaped rune and the last index of r that was used
// during unescaping.
func unescape(r []rune, i int) (bool, rune, int) {
// Look for "__0x".
if r[i] != '_' {
return false, 0, 0
}
i++
if i >= len(r) || r[i] != '_' {
return false, 0, 0
}
i++
if i >= len(r) || r[i] != '0' {
return false, 0, 0
}
i++
if i >= len(r) || r[i] != 'x' {
return false, 0, 0
}
i++
// Capture the digits until the next "_" (if any).
var hexdigits []rune
for ; i < len(r) && r[i] != '_'; i++ {
hexdigits = append(hexdigits, r[i])
}
// Look for the trailing "__".
if i >= len(r) || r[i] != '_' {
return false, 0, 0
}
i++
if i >= len(r) || r[i] != '_' {
return false, 0, 0
}
// Parse the hex digits into an int32.
retval, err := strconv.ParseInt(string(hexdigits), 16, 32)
if err != nil {
return false, 0, 0
}
return true, rune(retval), i
}
// HexUnescape reverses HexEscape.
func HexUnescape(s string) string {
var unescaped []rune
runes := []rune(s)
for i := 0; i < len(runes); i++ {
if ok, newR, newI := unescape(runes, i); ok {
// We unescaped some runes starting at i, resulting in the
// unescaped rune newR. The last rune used was newI.
if unescaped == nil {
// This is the first rune we've encountered that
// needed unescaping. Allocate a buffer and copy any
// previous runes.
unescaped = make([]rune, i)
copy(unescaped, runes)
}
unescaped = append(unescaped, newR)
i = newI
} else if unescaped != nil {
unescaped = append(unescaped, runes[i])
}
}
if unescaped == nil {
return s
}
return string(unescaped)
}

View file

@ -0,0 +1,196 @@
package blob
import (
"context"
"fmt"
"io"
"log"
"time"
)
// Largely copied from gocloud.dev/blob.Reader to minimize breaking changes.
//
// -------------------------------------------------------------------
// Copyright 2019 The Go Cloud Development Kit Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// -------------------------------------------------------------------
var _ io.ReadSeekCloser = (*Reader)(nil)
// Reader reads bytes from a blob.
// It implements io.ReadSeekCloser, and must be closed after reads are finished.
type Reader struct {
ctx context.Context // Used to recreate r after Seeks
r DriverReader
drv Driver
key string
baseOffset int64 // The base offset provided to NewRangeReader.
baseLength int64 // The length provided to NewRangeReader (may be negative).
relativeOffset int64 // Current offset (relative to baseOffset).
savedOffset int64 // Last relativeOffset for r, saved after relativeOffset is changed in Seek, or -1 if no Seek.
closed bool
}
// Read implements io.Reader (https://golang.org/pkg/io/#Reader).
func (r *Reader) Read(p []byte) (int, error) {
if r.savedOffset != -1 {
// We've done one or more Seeks since the last read. We may have
// to recreate the Reader.
//
// Note that remembering the savedOffset and lazily resetting the
// reader like this allows the caller to Seek, then Seek again back,
// to the original offset, without having to recreate the reader.
// We only have to recreate the reader if we actually read after a Seek.
// This is an important optimization because it's common to Seek
// to (SeekEnd, 0) and use the return value to determine the size
// of the data, then Seek back to (SeekStart, 0).
saved := r.savedOffset
if r.relativeOffset == saved {
// Nope! We're at the same place we left off.
r.savedOffset = -1
} else {
// Yep! We've changed the offset. Recreate the reader.
length := r.baseLength
if length >= 0 {
length -= r.relativeOffset
if length < 0 {
// Shouldn't happen based on checks in Seek.
return 0, fmt.Errorf("invalid Seek (base length %d, relative offset %d)", r.baseLength, r.relativeOffset)
}
}
newR, err := r.drv.NewRangeReader(r.ctx, r.key, r.baseOffset+r.relativeOffset, length)
if err != nil {
return 0, wrapError(r.drv, err, r.key)
}
_ = r.r.Close()
r.savedOffset = -1
r.r = newR
}
}
n, err := r.r.Read(p)
r.relativeOffset += int64(n)
return n, wrapError(r.drv, err, r.key)
}
// Seek implements io.Seeker (https://golang.org/pkg/io/#Seeker).
func (r *Reader) Seek(offset int64, whence int) (int64, error) {
if r.savedOffset == -1 {
// Save the current offset for our reader. If the Seek changes the
// offset, and then we try to read, we'll need to recreate the reader.
// See comment above in Read for why we do it lazily.
r.savedOffset = r.relativeOffset
}
// The maximum relative offset is the minimum of:
// 1. The actual size of the blob, minus our initial baseOffset.
// 2. The length provided to NewRangeReader (if it was non-negative).
maxRelativeOffset := r.Size() - r.baseOffset
if r.baseLength >= 0 && r.baseLength < maxRelativeOffset {
maxRelativeOffset = r.baseLength
}
switch whence {
case io.SeekStart:
r.relativeOffset = offset
case io.SeekCurrent:
r.relativeOffset += offset
case io.SeekEnd:
r.relativeOffset = maxRelativeOffset + offset
}
if r.relativeOffset < 0 {
// "Seeking to an offset before the start of the file is an error."
invalidOffset := r.relativeOffset
r.relativeOffset = 0
return 0, fmt.Errorf("Seek resulted in invalid offset %d, using 0", invalidOffset)
}
if r.relativeOffset > maxRelativeOffset {
// "Seeking to any positive offset is legal, but the behavior of subsequent
// I/O operations on the underlying object is implementation-dependent."
// We'll choose to set the offset to the EOF.
log.Printf("blob.Reader.Seek set an offset after EOF (base offset/length from NewRangeReader %d, %d; actual blob size %d; relative offset %d -> absolute offset %d).", r.baseOffset, r.baseLength, r.Size(), r.relativeOffset, r.baseOffset+r.relativeOffset)
r.relativeOffset = maxRelativeOffset
}
return r.relativeOffset, nil
}
// Close implements io.Closer (https://golang.org/pkg/io/#Closer).
func (r *Reader) Close() error {
r.closed = true
err := wrapError(r.drv, r.r.Close(), r.key)
return err
}
// ContentType returns the MIME type of the blob.
func (r *Reader) ContentType() string {
return r.r.Attributes().ContentType
}
// ModTime returns the time the blob was last modified.
func (r *Reader) ModTime() time.Time {
return r.r.Attributes().ModTime
}
// Size returns the size of the blob content in bytes.
func (r *Reader) Size() int64 {
return r.r.Attributes().Size
}
// WriteTo reads from r and writes to w until there's no more data or
// an error occurs.
// The return value is the number of bytes written to w.
//
// It implements the io.WriterTo interface.
func (r *Reader) WriteTo(w io.Writer) (int64, error) {
// If the writer has a ReaderFrom method, use it to do the copy.
// Don't do this for our own *Writer to avoid infinite recursion.
// Avoids an allocation and a copy.
switch w.(type) {
case *Writer:
default:
if rf, ok := w.(io.ReaderFrom); ok {
n, err := rf.ReadFrom(r)
return n, err
}
}
_, nw, err := readFromWriteTo(r, w)
return nw, err
}
// readFromWriteTo is a helper for ReadFrom and WriteTo.
// It reads data from r and writes to w, until EOF or a read/write error.
// It returns the number of bytes read from r and the number of bytes
// written to w.
func readFromWriteTo(r io.Reader, w io.Writer) (int64, int64, error) {
// Note: can't use io.Copy because it will try to use r.WriteTo
// or w.WriteTo, which is recursive in this context.
buf := make([]byte, 1024)
var totalRead, totalWritten int64
for {
numRead, rerr := r.Read(buf)
if numRead > 0 {
totalRead += int64(numRead)
numWritten, werr := w.Write(buf[0:numRead])
totalWritten += int64(numWritten)
if werr != nil {
return totalRead, totalWritten, werr
}
}
if rerr == io.EOF {
// Done!
return totalRead, totalWritten, nil
}
if rerr != nil {
return totalRead, totalWritten, rerr
}
}
}

View file

@ -0,0 +1,184 @@
package blob
import (
"bytes"
"context"
"fmt"
"hash"
"io"
"net/http"
)
// Largely copied from gocloud.dev/blob.Writer to minimize breaking changes.
//
// -------------------------------------------------------------------
// Copyright 2019 The Go Cloud Development Kit Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// -------------------------------------------------------------------
var _ io.WriteCloser = (*Writer)(nil)
// Writer writes bytes to a blob.
//
// It implements io.WriteCloser (https://golang.org/pkg/io/#Closer), and must be
// closed after all writes are done.
type Writer struct {
drv Driver
w DriverWriter
key string
cancel func() // cancels the ctx provided to NewTypedWriter if contentMD5 verification fails
contentMD5 []byte
md5hash hash.Hash
bytesWritten int
closed bool
// These fields are non-zero values only when w is nil (not yet created).
//
// A ctx is stored in the Writer since we need to pass it into NewTypedWriter
// when we finish detecting the content type of the blob and create the
// underlying driver.Writer. This step happens inside Write or Close and
// neither of them take a context.Context as an argument.
//
// All 3 fields are only initialized when we create the Writer without
// setting the w field, and are reset to zero values after w is created.
ctx context.Context
opts *WriterOptions
buf *bytes.Buffer
}
// sniffLen is the byte size of Writer.buf used to detect content-type.
const sniffLen = 512
// Write implements the io.Writer interface (https://golang.org/pkg/io/#Writer).
//
// Writes may happen asynchronously, so the returned error can be nil
// even if the actual write eventually fails. The write is only guaranteed to
// have succeeded if Close returns no error.
func (w *Writer) Write(p []byte) (int, error) {
if len(w.contentMD5) > 0 {
if _, err := w.md5hash.Write(p); err != nil {
return 0, err
}
}
if w.w != nil {
return w.write(p)
}
// If w is not yet created due to no content-type being passed in, try to sniff
// the MIME type based on at most 512 bytes of the blob content of p.
// Detect the content-type directly if the first chunk is at least 512 bytes.
if w.buf.Len() == 0 && len(p) >= sniffLen {
return w.open(p)
}
// Store p in w.buf and detect the content-type when the size of content in
// w.buf is at least 512 bytes.
n, err := w.buf.Write(p)
if err != nil {
return 0, err
}
if w.buf.Len() >= sniffLen {
// Note that w.open will return the full length of the buffer; we don't want
// to return that as the length of this write since some of them were written in
// previous writes. Instead, we return the n from this write, above.
_, err := w.open(w.buf.Bytes())
return n, err
}
return n, nil
}
// Close closes the blob writer. The write operation is not guaranteed
// to have succeeded until Close returns with no error.
//
// Close may return an error if the context provided to create the
// Writer is canceled or reaches its deadline.
func (w *Writer) Close() (err error) {
w.closed = true
// Verify the MD5 hash of what was written matches the ContentMD5 provided by the user.
if len(w.contentMD5) > 0 {
md5sum := w.md5hash.Sum(nil)
if !bytes.Equal(md5sum, w.contentMD5) {
// No match! Return an error, but first cancel the context and call the
// driver's Close function to ensure the write is aborted.
w.cancel()
if w.w != nil {
_ = w.w.Close()
}
return fmt.Errorf("the WriterOptions.ContentMD5 you specified (%X) did not match what was written (%X)", w.contentMD5, md5sum)
}
}
defer w.cancel()
if w.w != nil {
return wrapError(w.drv, w.w.Close(), w.key)
}
if _, err := w.open(w.buf.Bytes()); err != nil {
return err
}
return wrapError(w.drv, w.w.Close(), w.key)
}
// open tries to detect the MIME type of p and write it to the blob.
// The error it returns is wrapped.
func (w *Writer) open(p []byte) (int, error) {
ct := http.DetectContentType(p)
var err error
w.w, err = w.drv.NewTypedWriter(w.ctx, w.key, ct, w.opts)
if err != nil {
return 0, wrapError(w.drv, err, w.key)
}
// Set the 3 fields needed for lazy NewTypedWriter back to zero values
// (see the comment on Writer).
w.buf = nil
w.ctx = nil
w.opts = nil
return w.write(p)
}
func (w *Writer) write(p []byte) (int, error) {
n, err := w.w.Write(p)
w.bytesWritten += n
return n, wrapError(w.drv, err, w.key)
}
// ReadFrom reads from r and writes to w until EOF or error.
// The return value is the number of bytes read from r.
//
// It implements the io.ReaderFrom interface.
func (w *Writer) ReadFrom(r io.Reader) (int64, error) {
// If the reader has a WriteTo method, use it to do the copy.
// Don't do this for our own *Reader to avoid infinite recursion.
// Avoids an allocation and a copy.
switch r.(type) {
case *Reader:
default:
if wt, ok := r.(io.WriterTo); ok {
return wt.WriteTo(w)
}
}
nr, _, err := readFromWriteTo(r, w)
return nr, err
}

268
tools/filesystem/file.go Normal file
View file

@ -0,0 +1,268 @@
package filesystem
import (
"bytes"
"context"
"errors"
"fmt"
"io"
"mime/multipart"
"net/http"
"os"
"path"
"regexp"
"strings"
"github.com/gabriel-vasile/mimetype"
"github.com/pocketbase/pocketbase/tools/inflector"
"github.com/pocketbase/pocketbase/tools/security"
)
// FileReader defines an interface for a file resource reader.
type FileReader interface {
Open() (io.ReadSeekCloser, error)
}
// File defines a single file [io.ReadSeekCloser] resource.
//
// The file could be from a local path, multipart/form-data header, etc.
type File struct {
Reader FileReader `form:"-" json:"-" xml:"-"`
Name string `form:"name" json:"name" xml:"name"`
OriginalName string `form:"originalName" json:"originalName" xml:"originalName"`
Size int64 `form:"size" json:"size" xml:"size"`
}
// AsMap implements [core.mapExtractor] and returns a value suitable
// to be used in an API rule expression.
func (f *File) AsMap() map[string]any {
return map[string]any{
"name": f.Name,
"originalName": f.OriginalName,
"size": f.Size,
}
}
// NewFileFromPath creates a new File instance from the provided local file path.
func NewFileFromPath(path string) (*File, error) {
f := &File{}
info, err := os.Stat(path)
if err != nil {
return nil, err
}
f.Reader = &PathReader{Path: path}
f.Size = info.Size()
f.OriginalName = info.Name()
f.Name = normalizeName(f.Reader, f.OriginalName)
return f, nil
}
// NewFileFromBytes creates a new File instance from the provided byte slice.
func NewFileFromBytes(b []byte, name string) (*File, error) {
size := len(b)
if size == 0 {
return nil, errors.New("cannot create an empty file")
}
f := &File{}
f.Reader = &BytesReader{b}
f.Size = int64(size)
f.OriginalName = name
f.Name = normalizeName(f.Reader, f.OriginalName)
return f, nil
}
// NewFileFromMultipart creates a new File from the provided multipart header.
func NewFileFromMultipart(mh *multipart.FileHeader) (*File, error) {
f := &File{}
f.Reader = &MultipartReader{Header: mh}
f.Size = mh.Size
f.OriginalName = mh.Filename
f.Name = normalizeName(f.Reader, f.OriginalName)
return f, nil
}
// NewFileFromURL creates a new File from the provided url by
// downloading the resource and load it as BytesReader.
//
// Example
//
// ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
// defer cancel()
//
// file, err := filesystem.NewFileFromURL(ctx, "https://example.com/image.png")
func NewFileFromURL(ctx context.Context, url string) (*File, error) {
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
if err != nil {
return nil, err
}
res, err := http.DefaultClient.Do(req)
if err != nil {
return nil, err
}
defer res.Body.Close()
if res.StatusCode < 200 || res.StatusCode > 399 {
return nil, fmt.Errorf("failed to download url %s (%d)", url, res.StatusCode)
}
var buf bytes.Buffer
if _, err = io.Copy(&buf, res.Body); err != nil {
return nil, err
}
return NewFileFromBytes(buf.Bytes(), path.Base(url))
}
// -------------------------------------------------------------------
var _ FileReader = (*MultipartReader)(nil)
// MultipartReader defines a FileReader from [multipart.FileHeader].
type MultipartReader struct {
Header *multipart.FileHeader
}
// Open implements the [filesystem.FileReader] interface.
func (r *MultipartReader) Open() (io.ReadSeekCloser, error) {
return r.Header.Open()
}
// -------------------------------------------------------------------
var _ FileReader = (*PathReader)(nil)
// PathReader defines a FileReader from a local file path.
type PathReader struct {
Path string
}
// Open implements the [filesystem.FileReader] interface.
func (r *PathReader) Open() (io.ReadSeekCloser, error) {
return os.Open(r.Path)
}
// -------------------------------------------------------------------
var _ FileReader = (*BytesReader)(nil)
// BytesReader defines a FileReader from bytes content.
type BytesReader struct {
Bytes []byte
}
// Open implements the [filesystem.FileReader] interface.
func (r *BytesReader) Open() (io.ReadSeekCloser, error) {
return &bytesReadSeekCloser{bytes.NewReader(r.Bytes)}, nil
}
type bytesReadSeekCloser struct {
*bytes.Reader
}
// Close implements the [io.ReadSeekCloser] interface.
func (r *bytesReadSeekCloser) Close() error {
return nil
}
// -------------------------------------------------------------------
var _ FileReader = (openFuncAsReader)(nil)
// openFuncAsReader defines a FileReader from a bare Open function.
type openFuncAsReader func() (io.ReadSeekCloser, error)
// Open implements the [filesystem.FileReader] interface.
func (r openFuncAsReader) Open() (io.ReadSeekCloser, error) {
return r()
}
// -------------------------------------------------------------------
var extInvalidCharsRegex = regexp.MustCompile(`[^\w\.\*\-\+\=\#]+`)
const randomAlphabet = "abcdefghijklmnopqrstuvwxyz0123456789"
func normalizeName(fr FileReader, name string) string {
// extension
// ---
originalExt := extractExtension(name)
cleanExt := extInvalidCharsRegex.ReplaceAllString(originalExt, "")
if cleanExt == "" {
// try to detect the extension from the file content
cleanExt, _ = detectExtension(fr)
}
if extLength := len(cleanExt); extLength > 20 {
// keep only the last 20 characters (it is multibyte safe after the regex replace)
cleanExt = "." + cleanExt[extLength-20:]
}
// name
// ---
cleanName := inflector.Snakecase(strings.TrimSuffix(name, originalExt))
if length := len(cleanName); length < 3 {
// the name is too short so we concatenate an additional random part
cleanName += security.RandomStringWithAlphabet(10, randomAlphabet)
} else if length > 100 {
// keep only the first 100 characters (it is multibyte safe after Snakecase)
cleanName = cleanName[:100]
}
return fmt.Sprintf(
"%s_%s%s",
cleanName,
security.RandomStringWithAlphabet(10, randomAlphabet), // ensure that there is always a random part
cleanExt,
)
}
// extractExtension extracts the extension (with leading dot) from name.
//
// This differ from filepath.Ext() by supporting double extensions (eg. ".tar.gz").
//
// Returns an empty string if no match is found.
//
// Example:
// extractExtension("test.txt") // .txt
// extractExtension("test.tar.gz") // .tar.gz
// extractExtension("test.a.tar.gz") // .tar.gz
func extractExtension(name string) string {
primaryDot := strings.LastIndex(name, ".")
if primaryDot == -1 {
return ""
}
// look for secondary extension
secondaryDot := strings.LastIndex(name[:primaryDot], ".")
if secondaryDot >= 0 {
return name[secondaryDot:]
}
return name[primaryDot:]
}
// detectExtension tries to detect the extension from file mime type.
func detectExtension(fr FileReader) (string, error) {
r, err := fr.Open()
if err != nil {
return "", err
}
defer r.Close()
mt, err := mimetype.DetectReader(r)
if err != nil {
return "", err
}
return mt.Extension(), nil
}

View file

@ -0,0 +1,231 @@
package filesystem_test
import (
"context"
"fmt"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"regexp"
"strconv"
"strings"
"testing"
"github.com/pocketbase/pocketbase/tests"
"github.com/pocketbase/pocketbase/tools/filesystem"
)
func TestFileAsMap(t *testing.T) {
file, err := filesystem.NewFileFromBytes([]byte("test"), "test123.txt")
if err != nil {
t.Fatal(err)
}
result := file.AsMap()
if len(result) != 3 {
t.Fatalf("Expected map with %d keys, got\n%v", 3, result)
}
if result["size"] != int64(4) {
t.Fatalf("Expected size %d, got %#v", 4, result["size"])
}
if str, ok := result["name"].(string); !ok || !strings.HasPrefix(str, "test123") {
t.Fatalf("Expected name to have prefix %q, got %#v", "test123", result["name"])
}
if result["originalName"] != "test123.txt" {
t.Fatalf("Expected originalName %q, got %#v", "test123.txt", result["originalName"])
}
}
func TestNewFileFromPath(t *testing.T) {
testDir := createTestDir(t)
defer os.RemoveAll(testDir)
// missing file
_, err := filesystem.NewFileFromPath("missing")
if err == nil {
t.Fatal("Expected error, got nil")
}
// existing file
originalName := "image_! noext"
normalizedNamePattern := regexp.QuoteMeta("image_noext_") + `\w{10}` + regexp.QuoteMeta(".png")
f, err := filesystem.NewFileFromPath(filepath.Join(testDir, originalName))
if err != nil {
t.Fatalf("Expected nil error, got %v", err)
}
if f.OriginalName != originalName {
t.Fatalf("Expected OriginalName %q, got %q", originalName, f.OriginalName)
}
if match, err := regexp.Match(normalizedNamePattern, []byte(f.Name)); !match {
t.Fatalf("Expected Name to match %v, got %q (%v)", normalizedNamePattern, f.Name, err)
}
if f.Size != 73 {
t.Fatalf("Expected Size %v, got %v", 73, f.Size)
}
if _, ok := f.Reader.(*filesystem.PathReader); !ok {
t.Fatalf("Expected Reader to be PathReader, got %v", f.Reader)
}
}
func TestNewFileFromBytes(t *testing.T) {
// nil bytes
if _, err := filesystem.NewFileFromBytes(nil, "photo.jpg"); err == nil {
t.Fatal("Expected error, got nil")
}
// zero bytes
if _, err := filesystem.NewFileFromBytes([]byte{}, "photo.jpg"); err == nil {
t.Fatal("Expected error, got nil")
}
originalName := "image_! noext"
normalizedNamePattern := regexp.QuoteMeta("image_noext_") + `\w{10}` + regexp.QuoteMeta(".txt")
f, err := filesystem.NewFileFromBytes([]byte("text\n"), originalName)
if err != nil {
t.Fatal(err)
}
if f.Size != 5 {
t.Fatalf("Expected Size %v, got %v", 5, f.Size)
}
if f.OriginalName != originalName {
t.Fatalf("Expected OriginalName %q, got %q", originalName, f.OriginalName)
}
if match, err := regexp.Match(normalizedNamePattern, []byte(f.Name)); !match {
t.Fatalf("Expected Name to match %v, got %q (%v)", normalizedNamePattern, f.Name, err)
}
}
func TestNewFileFromMultipart(t *testing.T) {
formData, mp, err := tests.MockMultipartData(nil, "test")
if err != nil {
t.Fatal(err)
}
req := httptest.NewRequest("", "/", formData)
req.Header.Set("Content-Type", mp.FormDataContentType())
req.ParseMultipartForm(32 << 20)
_, mh, err := req.FormFile("test")
if err != nil {
t.Fatal(err)
}
f, err := filesystem.NewFileFromMultipart(mh)
if err != nil {
t.Fatal(err)
}
originalNamePattern := regexp.QuoteMeta("tmpfile-") + `\w+` + regexp.QuoteMeta(".txt")
if match, err := regexp.Match(originalNamePattern, []byte(f.OriginalName)); !match {
t.Fatalf("Expected OriginalName to match %v, got %q (%v)", originalNamePattern, f.OriginalName, err)
}
normalizedNamePattern := regexp.QuoteMeta("tmpfile_") + `\w+\_\w{10}` + regexp.QuoteMeta(".txt")
if match, err := regexp.Match(normalizedNamePattern, []byte(f.Name)); !match {
t.Fatalf("Expected Name to match %v, got %q (%v)", normalizedNamePattern, f.Name, err)
}
if f.Size != 4 {
t.Fatalf("Expected Size %v, got %v", 4, f.Size)
}
if _, ok := f.Reader.(*filesystem.MultipartReader); !ok {
t.Fatalf("Expected Reader to be MultipartReader, got %v", f.Reader)
}
}
func TestNewFileFromURLTimeout(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/error" {
w.WriteHeader(http.StatusInternalServerError)
}
fmt.Fprintf(w, "test")
}))
defer srv.Close()
// cancelled context
{
ctx, cancel := context.WithCancel(context.Background())
cancel()
f, err := filesystem.NewFileFromURL(ctx, srv.URL+"/cancel")
if err == nil {
t.Fatal("[ctx_cancel] Expected error, got nil")
}
if f != nil {
t.Fatalf("[ctx_cancel] Expected file to be nil, got %v", f)
}
}
// error response
{
f, err := filesystem.NewFileFromURL(context.Background(), srv.URL+"/error")
if err == nil {
t.Fatal("[error_status] Expected error, got nil")
}
if f != nil {
t.Fatalf("[error_status] Expected file to be nil, got %v", f)
}
}
// valid response
{
originalName := "image_! noext"
normalizedNamePattern := regexp.QuoteMeta("image_noext_") + `\w{10}` + regexp.QuoteMeta(".txt")
f, err := filesystem.NewFileFromURL(context.Background(), srv.URL+"/"+originalName)
if err != nil {
t.Fatalf("[valid] Unexpected error %v", err)
}
if f == nil {
t.Fatal("[valid] Expected non-nil file")
}
// check the created file fields
if f.OriginalName != originalName {
t.Fatalf("Expected OriginalName %q, got %q", originalName, f.OriginalName)
}
if match, err := regexp.Match(normalizedNamePattern, []byte(f.Name)); !match {
t.Fatalf("Expected Name to match %v, got %q (%v)", normalizedNamePattern, f.Name, err)
}
if f.Size != 4 {
t.Fatalf("Expected Size %v, got %v", 4, f.Size)
}
if _, ok := f.Reader.(*filesystem.BytesReader); !ok {
t.Fatalf("Expected Reader to be BytesReader, got %v", f.Reader)
}
}
}
func TestFileNameNormalizations(t *testing.T) {
scenarios := []struct {
name string
pattern string
}{
{"", `^\w{10}_\w{10}\.txt$`},
{".png", `^\w{10}_\w{10}\.png$`},
{".tar.gz", `^\w{10}_\w{10}\.tar\.gz$`},
{"a.tar.gz", `^a\w{10}_\w{10}\.tar\.gz$`},
{"a.b.c.d.tar.gz", `^a_b_c_d_\w{10}\.tar\.gz$`},
{"abcd", `^abcd_\w{10}\.txt$`},
{"a b! c d . 456", `^a_b_c_d_\w{10}\.456$`}, // normalize spaces
{strings.Repeat("a", 101) + "." + strings.Repeat("b", 21), `^a{100}_\w{10}\.b{20}$`}, // name and extension length trim
}
for i, s := range scenarios {
t.Run(strconv.Itoa(i)+"_"+s.name, func(t *testing.T) {
f, err := filesystem.NewFileFromBytes([]byte("abc"), s.name)
if err != nil {
t.Fatal(err)
}
if match, err := regexp.Match(s.pattern, []byte(f.Name)); !match {
t.Fatalf("Expected Name to match %v, got %q (%v)", s.pattern, f.Name, err)
}
})
}
}

View file

@ -0,0 +1,564 @@
package filesystem
import (
"context"
"errors"
"image"
"io"
"mime/multipart"
"net/http"
"os"
"path"
"path/filepath"
"regexp"
"sort"
"strconv"
"strings"
"github.com/disintegration/imaging"
"github.com/fatih/color"
"github.com/gabriel-vasile/mimetype"
"github.com/pocketbase/pocketbase/tools/filesystem/blob"
"github.com/pocketbase/pocketbase/tools/filesystem/internal/fileblob"
"github.com/pocketbase/pocketbase/tools/filesystem/internal/s3blob"
"github.com/pocketbase/pocketbase/tools/filesystem/internal/s3blob/s3"
"github.com/pocketbase/pocketbase/tools/list"
// explicit webp decoder because disintegration/imaging does not support webp
_ "golang.org/x/image/webp"
)
// note: the same as blob.ErrNotFound for backward compatibility with earlier versions
var ErrNotFound = blob.ErrNotFound
const metadataOriginalName = "original-filename"
type System struct {
ctx context.Context
bucket *blob.Bucket
}
// NewS3 initializes an S3 filesystem instance.
//
// NB! Make sure to call `Close()` after you are done working with it.
func NewS3(
bucketName string,
region string,
endpoint string,
accessKey string,
secretKey string,
s3ForcePathStyle bool,
) (*System, error) {
ctx := context.Background() // default context
client := &s3.S3{
Bucket: bucketName,
Region: region,
Endpoint: endpoint,
AccessKey: accessKey,
SecretKey: secretKey,
UsePathStyle: s3ForcePathStyle,
}
drv, err := s3blob.New(client)
if err != nil {
return nil, err
}
return &System{ctx: ctx, bucket: blob.NewBucket(drv)}, nil
}
// NewLocal initializes a new local filesystem instance.
//
// NB! Make sure to call `Close()` after you are done working with it.
func NewLocal(dirPath string) (*System, error) {
ctx := context.Background() // default context
// makes sure that the directory exist
if err := os.MkdirAll(dirPath, os.ModePerm); err != nil {
return nil, err
}
drv, err := fileblob.New(dirPath, &fileblob.Options{
NoTempDir: true,
})
if err != nil {
return nil, err
}
return &System{ctx: ctx, bucket: blob.NewBucket(drv)}, nil
}
// SetContext assigns the specified context to the current filesystem.
func (s *System) SetContext(ctx context.Context) {
s.ctx = ctx
}
// Close releases any resources used for the related filesystem.
func (s *System) Close() error {
return s.bucket.Close()
}
// Exists checks if file with fileKey path exists or not.
func (s *System) Exists(fileKey string) (bool, error) {
return s.bucket.Exists(s.ctx, fileKey)
}
// Attributes returns the attributes for the file with fileKey path.
//
// If the file doesn't exist it returns ErrNotFound.
func (s *System) Attributes(fileKey string) (*blob.Attributes, error) {
return s.bucket.Attributes(s.ctx, fileKey)
}
// GetReader returns a file content reader for the given fileKey.
//
// NB! Make sure to call Close() on the file after you are done working with it.
//
// If the file doesn't exist returns ErrNotFound.
func (s *System) GetReader(fileKey string) (*blob.Reader, error) {
return s.bucket.NewReader(s.ctx, fileKey)
}
// Deprecated: Please use GetReader(fileKey) instead.
func (s *System) GetFile(fileKey string) (*blob.Reader, error) {
color.Yellow("Deprecated: Please replace GetFile with GetReader.")
return s.GetReader(fileKey)
}
// GetReuploadableFile constructs a new reuploadable File value
// from the associated fileKey blob.Reader.
//
// If preserveName is false then the returned File.Name will have
// a new randomly generated suffix, otherwise it will reuse the original one.
//
// This method could be useful in case you want to clone an existing
// Record file and assign it to a new Record (e.g. in a Record duplicate action).
//
// If you simply want to copy an existing file to a new location you
// could check the Copy(srcKey, dstKey) method.
func (s *System) GetReuploadableFile(fileKey string, preserveName bool) (*File, error) {
attrs, err := s.Attributes(fileKey)
if err != nil {
return nil, err
}
name := path.Base(fileKey)
originalName := attrs.Metadata[metadataOriginalName]
if originalName == "" {
originalName = name
}
file := &File{}
file.Size = attrs.Size
file.OriginalName = originalName
file.Reader = openFuncAsReader(func() (io.ReadSeekCloser, error) {
return s.GetReader(fileKey)
})
if preserveName {
file.Name = name
} else {
file.Name = normalizeName(file.Reader, originalName)
}
return file, nil
}
// Copy copies the file stored at srcKey to dstKey.
//
// If srcKey file doesn't exist, it returns ErrNotFound.
//
// If dstKey file already exists, it is overwritten.
func (s *System) Copy(srcKey, dstKey string) error {
return s.bucket.Copy(s.ctx, dstKey, srcKey)
}
// List returns a flat list with info for all files under the specified prefix.
func (s *System) List(prefix string) ([]*blob.ListObject, error) {
files := []*blob.ListObject{}
iter := s.bucket.List(&blob.ListOptions{
Prefix: prefix,
})
for {
obj, err := iter.Next(s.ctx)
if err != nil {
if !errors.Is(err, io.EOF) {
return nil, err
}
break
}
files = append(files, obj)
}
return files, nil
}
// Upload writes content into the fileKey location.
func (s *System) Upload(content []byte, fileKey string) error {
opts := &blob.WriterOptions{
ContentType: mimetype.Detect(content).String(),
}
w, writerErr := s.bucket.NewWriter(s.ctx, fileKey, opts)
if writerErr != nil {
return writerErr
}
if _, err := w.Write(content); err != nil {
return errors.Join(err, w.Close())
}
return w.Close()
}
// UploadFile uploads the provided File to the fileKey location.
func (s *System) UploadFile(file *File, fileKey string) error {
f, err := file.Reader.Open()
if err != nil {
return err
}
defer f.Close()
mt, err := mimetype.DetectReader(f)
if err != nil {
return err
}
// rewind
f.Seek(0, io.SeekStart)
originalName := file.OriginalName
if len(originalName) > 255 {
// keep only the first 255 chars as a very rudimentary measure
// to prevent the metadata to grow too big in size
originalName = originalName[:255]
}
opts := &blob.WriterOptions{
ContentType: mt.String(),
Metadata: map[string]string{
metadataOriginalName: originalName,
},
}
w, err := s.bucket.NewWriter(s.ctx, fileKey, opts)
if err != nil {
return err
}
if _, err := w.ReadFrom(f); err != nil {
w.Close()
return err
}
return w.Close()
}
// UploadMultipart uploads the provided multipart file to the fileKey location.
func (s *System) UploadMultipart(fh *multipart.FileHeader, fileKey string) error {
f, err := fh.Open()
if err != nil {
return err
}
defer f.Close()
mt, err := mimetype.DetectReader(f)
if err != nil {
return err
}
// rewind
f.Seek(0, io.SeekStart)
originalName := fh.Filename
if len(originalName) > 255 {
// keep only the first 255 chars as a very rudimentary measure
// to prevent the metadata to grow too big in size
originalName = originalName[:255]
}
opts := &blob.WriterOptions{
ContentType: mt.String(),
Metadata: map[string]string{
metadataOriginalName: originalName,
},
}
w, err := s.bucket.NewWriter(s.ctx, fileKey, opts)
if err != nil {
return err
}
_, err = w.ReadFrom(f)
if err != nil {
w.Close()
return err
}
return w.Close()
}
// Delete deletes stored file at fileKey location.
//
// If the file doesn't exist returns ErrNotFound.
func (s *System) Delete(fileKey string) error {
return s.bucket.Delete(s.ctx, fileKey)
}
// DeletePrefix deletes everything starting with the specified prefix.
//
// The prefix could be subpath (ex. "/a/b/") or filename prefix (ex. "/a/b/file_").
func (s *System) DeletePrefix(prefix string) []error {
failed := []error{}
if prefix == "" {
failed = append(failed, errors.New("prefix mustn't be empty"))
return failed
}
dirsMap := map[string]struct{}{}
var isPrefixDir bool
// treat the prefix as directory only if it ends with trailing slash
if strings.HasSuffix(prefix, "/") {
isPrefixDir = true
dirsMap[strings.TrimRight(prefix, "/")] = struct{}{}
}
// delete all files with the prefix
// ---
iter := s.bucket.List(&blob.ListOptions{
Prefix: prefix,
})
for {
obj, err := iter.Next(s.ctx)
if err != nil {
if !errors.Is(err, io.EOF) {
failed = append(failed, err)
}
break
}
if err := s.Delete(obj.Key); err != nil {
failed = append(failed, err)
} else if isPrefixDir {
slashIdx := strings.LastIndex(obj.Key, "/")
if slashIdx > -1 {
dirsMap[obj.Key[:slashIdx]] = struct{}{}
}
}
}
// ---
// try to delete the empty remaining dir objects
// (this operation usually is optional and there is no need to strictly check the result)
// ---
// fill dirs slice
dirs := make([]string, 0, len(dirsMap))
for d := range dirsMap {
dirs = append(dirs, d)
}
// sort the child dirs first, aka. ["a/b/c", "a/b", "a"]
sort.SliceStable(dirs, func(i, j int) bool {
return len(strings.Split(dirs[i], "/")) > len(strings.Split(dirs[j], "/"))
})
// delete dirs
for _, d := range dirs {
if d != "" {
s.Delete(d)
}
}
// ---
return failed
}
// Checks if the provided dir prefix doesn't have any files.
//
// A trailing slash will be appended to a non-empty dir string argument
// to ensure that the checked prefix is a "directory".
//
// Returns "false" in case the has at least one file, otherwise - "true".
func (s *System) IsEmptyDir(dir string) bool {
if dir != "" && !strings.HasSuffix(dir, "/") {
dir += "/"
}
iter := s.bucket.List(&blob.ListOptions{
Prefix: dir,
})
_, err := iter.Next(s.ctx)
return err != nil && errors.Is(err, io.EOF)
}
var inlineServeContentTypes = []string{
// image
"image/png", "image/jpg", "image/jpeg", "image/gif", "image/webp", "image/x-icon", "image/bmp",
// video
"video/webm", "video/mp4", "video/3gpp", "video/quicktime", "video/x-ms-wmv",
// audio
"audio/basic", "audio/aiff", "audio/mpeg", "audio/midi", "audio/mp3", "audio/wave",
"audio/wav", "audio/x-wav", "audio/x-mpeg", "audio/x-m4a", "audio/aac",
// document
"application/pdf", "application/x-pdf",
}
// manualExtensionContentTypes is a map of file extensions to content types.
var manualExtensionContentTypes = map[string]string{
".svg": "image/svg+xml", // (see https://github.com/whatwg/mimesniff/issues/7)
".css": "text/css", // (see https://github.com/gabriel-vasile/mimetype/pull/113)
".js": "text/javascript", // (see https://github.com/pocketbase/pocketbase/issues/6597)
".mjs": "text/javascript",
}
// forceAttachmentParam is the name of the request query parameter to
// force "Content-Disposition: attachment" header.
const forceAttachmentParam = "download"
// Serve serves the file at fileKey location to an HTTP response.
//
// If the `download` query parameter is used the file will be always served for
// download no matter of its type (aka. with "Content-Disposition: attachment").
//
// Internally this method uses [http.ServeContent] so Range requests,
// If-Match, If-Unmodified-Since, etc. headers are handled transparently.
func (s *System) Serve(res http.ResponseWriter, req *http.Request, fileKey string, name string) error {
br, readErr := s.GetReader(fileKey)
if readErr != nil {
return readErr
}
defer br.Close()
var forceAttachment bool
if raw := req.URL.Query().Get(forceAttachmentParam); raw != "" {
forceAttachment, _ = strconv.ParseBool(raw)
}
disposition := "attachment"
realContentType := br.ContentType()
if !forceAttachment && list.ExistInSlice(realContentType, inlineServeContentTypes) {
disposition = "inline"
}
// make an exception for specific content types and force a custom
// content type to send in the response so that it can be loaded properly
extContentType := realContentType
if ct, found := manualExtensionContentTypes[filepath.Ext(name)]; found {
extContentType = ct
}
setHeaderIfMissing(res, "Content-Disposition", disposition+"; filename="+name)
setHeaderIfMissing(res, "Content-Type", extContentType)
setHeaderIfMissing(res, "Content-Security-Policy", "default-src 'none'; media-src 'self'; style-src 'unsafe-inline'; sandbox")
// set a default cache-control header
// (valid for 30 days but the cache is allowed to reuse the file for any requests
// that are made in the last day while revalidating the res in the background)
setHeaderIfMissing(res, "Cache-Control", "max-age=2592000, stale-while-revalidate=86400")
http.ServeContent(res, req, name, br.ModTime(), br)
return nil
}
// note: expects key to be in a canonical form (eg. "accept-encoding" should be "Accept-Encoding").
func setHeaderIfMissing(res http.ResponseWriter, key string, value string) {
if _, ok := res.Header()[key]; !ok {
res.Header().Set(key, value)
}
}
var ThumbSizeRegex = regexp.MustCompile(`^(\d+)x(\d+)(t|b|f)?$`)
// CreateThumb creates a new thumb image for the file at originalKey location.
// The new thumb file is stored at thumbKey location.
//
// thumbSize is in the format:
// - 0xH (eg. 0x100) - resize to H height preserving the aspect ratio
// - Wx0 (eg. 300x0) - resize to W width preserving the aspect ratio
// - WxH (eg. 300x100) - resize and crop to WxH viewbox (from center)
// - WxHt (eg. 300x100t) - resize and crop to WxH viewbox (from top)
// - WxHb (eg. 300x100b) - resize and crop to WxH viewbox (from bottom)
// - WxHf (eg. 300x100f) - fit inside a WxH viewbox (without cropping)
func (s *System) CreateThumb(originalKey string, thumbKey, thumbSize string) error {
sizeParts := ThumbSizeRegex.FindStringSubmatch(thumbSize)
if len(sizeParts) != 4 {
return errors.New("thumb size must be in WxH, WxHt, WxHb or WxHf format")
}
width, _ := strconv.Atoi(sizeParts[1])
height, _ := strconv.Atoi(sizeParts[2])
resizeType := sizeParts[3]
if width == 0 && height == 0 {
return errors.New("thumb width and height cannot be zero at the same time")
}
// fetch the original
r, readErr := s.GetReader(originalKey)
if readErr != nil {
return readErr
}
defer r.Close()
// create imaging object from the original reader
// (note: only the first frame for animated image formats)
img, decodeErr := imaging.Decode(r, imaging.AutoOrientation(true))
if decodeErr != nil {
return decodeErr
}
var thumbImg *image.NRGBA
if width == 0 || height == 0 {
// force resize preserving aspect ratio
thumbImg = imaging.Resize(img, width, height, imaging.Linear)
} else {
switch resizeType {
case "f":
// fit
thumbImg = imaging.Fit(img, width, height, imaging.Linear)
case "t":
// fill and crop from top
thumbImg = imaging.Fill(img, width, height, imaging.Top, imaging.Linear)
case "b":
// fill and crop from bottom
thumbImg = imaging.Fill(img, width, height, imaging.Bottom, imaging.Linear)
default:
// fill and crop from center
thumbImg = imaging.Fill(img, width, height, imaging.Center, imaging.Linear)
}
}
opts := &blob.WriterOptions{
ContentType: r.ContentType(),
}
// open a thumb storage writer (aka. prepare for upload)
w, writerErr := s.bucket.NewWriter(s.ctx, thumbKey, opts)
if writerErr != nil {
return writerErr
}
// try to detect the thumb format based on the original file name
// (fallbacks to png on error)
format, err := imaging.FormatFromFilename(thumbKey)
if err != nil {
format = imaging.PNG
}
// thumb encode (aka. upload)
if err := imaging.Encode(w, thumbImg, format); err != nil {
w.Close()
return err
}
// check for close errors to ensure that the thumb was really saved
return w.Close()
}

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,84 @@
package fileblob
import (
"encoding/json"
"fmt"
"os"
)
// Largely copied from gocloud.dev/blob/fileblob to apply the same
// retrieve and write side-car .attrs rules.
//
// -------------------------------------------------------------------
// Copyright 2018 The Go Cloud Development Kit Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// -------------------------------------------------------------------
const attrsExt = ".attrs"
var errAttrsExt = fmt.Errorf("file extension %q is reserved", attrsExt)
// xattrs stores extended attributes for an object. The format is like
// filesystem extended attributes, see
// https://www.freedesktop.org/wiki/CommonExtendedAttributes.
type xattrs struct {
CacheControl string `json:"user.cache_control"`
ContentDisposition string `json:"user.content_disposition"`
ContentEncoding string `json:"user.content_encoding"`
ContentLanguage string `json:"user.content_language"`
ContentType string `json:"user.content_type"`
Metadata map[string]string `json:"user.metadata"`
MD5 []byte `json:"md5"`
}
// setAttrs creates a "path.attrs" file along with blob to store the attributes,
// it uses JSON format.
func setAttrs(path string, xa xattrs) error {
f, err := os.Create(path + attrsExt)
if err != nil {
return err
}
if err := json.NewEncoder(f).Encode(xa); err != nil {
f.Close()
os.Remove(f.Name())
return err
}
return f.Close()
}
// getAttrs looks at the "path.attrs" file to retrieve the attributes and
// decodes them into a xattrs struct. It doesn't return error when there is no
// such .attrs file.
func getAttrs(path string) (xattrs, error) {
f, err := os.Open(path + attrsExt)
if err != nil {
if os.IsNotExist(err) {
// Handle gracefully for non-existent .attr files.
return xattrs{
ContentType: "application/octet-stream",
}, nil
}
return xattrs{}, err
}
xa := new(xattrs)
if err := json.NewDecoder(f).Decode(xa); err != nil {
f.Close()
return xattrs{}, err
}
return *xa, f.Close()
}

View file

@ -0,0 +1,713 @@
// Package fileblob provides a blob.Bucket driver implementation.
//
// NB! To minimize breaking changes with older PocketBase releases,
// the driver is a stripped down and adapted version of the previously
// used gocloud.dev/blob/fileblob, hence many of the below doc comments,
// struct options and interface implementations are the same.
//
// To avoid partial writes, fileblob writes to a temporary file and then renames
// the temporary file to the final path on Close. By default, it creates these
// temporary files in `os.TempDir`. If `os.TempDir` is on a different mount than
// your base bucket path, the `os.Rename` will fail with `invalid cross-device link`.
// To avoid this, either configure the temp dir to use by setting the environment
// variable `TMPDIR`, or set `Options.NoTempDir` to `true` (fileblob will create
// the temporary files next to the actual files instead of in a temporary directory).
//
// By default fileblob stores blob metadata in "sidecar" files under the original
// filename with an additional ".attrs" suffix.
// This behaviour can be changed via `Options.Metadata`;
// writing of those metadata files can be suppressed by setting it to
// `MetadataDontWrite` or its equivalent "metadata=skip" in the URL for the opener.
// In either case, absent any stored metadata many `blob.Attributes` fields
// will be set to default values.
//
// The blob abstraction supports all UTF-8 strings; to make this work with services lacking
// full UTF-8 support, strings must be escaped (during writes) and unescaped
// (during reads). The following escapes are performed for fileblob:
// - Blob keys: ASCII characters 0-31 are escaped to "__0x<hex>__".
// If os.PathSeparator != "/", it is also escaped.
// Additionally, the "/" in "../", the trailing "/" in "//", and a trailing
// "/" is key names are escaped in the same way.
// On Windows, the characters "<>:"|?*" are also escaped.
//
// Example:
//
// drv, _ := fileblob.New("/path/to/dir", nil)
// bucket := blob.NewBucket(drv)
package fileblob
import (
"context"
"crypto/md5"
"errors"
"fmt"
"hash"
"io"
"io/fs"
"os"
"path/filepath"
"strconv"
"strings"
"time"
"github.com/pocketbase/pocketbase/tools/filesystem/blob"
)
const defaultPageSize = 1000
type metadataOption string // Not exported as subject to change.
// Settings for Options.Metadata.
const (
// Metadata gets written to a separate file.
MetadataInSidecar metadataOption = ""
// Writes won't carry metadata, as per the package docstring.
MetadataDontWrite metadataOption = "skip"
)
// Options sets options for constructing a *blob.Bucket backed by fileblob.
type Options struct {
// Refers to the strategy for how to deal with metadata (such as blob.Attributes).
// For supported values please see the Metadata* constants.
// If left unchanged, 'MetadataInSidecar' will be used.
Metadata metadataOption
// The FileMode to use when creating directories for the top-level directory
// backing the bucket (when CreateDir is true), and for subdirectories for keys.
// Defaults to 0777.
DirFileMode os.FileMode
// If true, create the directory backing the Bucket if it does not exist
// (using os.MkdirAll).
CreateDir bool
// If true, don't use os.TempDir for temporary files, but instead place them
// next to the actual files. This may result in "stranded" temporary files
// (e.g., if the application is killed before the file cleanup runs).
//
// If your bucket directory is on a different mount than os.TempDir, you will
// need to set this to true, as os.Rename will fail across mount points.
NoTempDir bool
}
// New creates a new instance of the fileblob driver backed by the
// filesystem and rooted at dir, which must exist.
func New(dir string, opts *Options) (blob.Driver, error) {
if opts == nil {
opts = &Options{}
}
if opts.DirFileMode == 0 {
opts.DirFileMode = os.FileMode(0o777)
}
absdir, err := filepath.Abs(dir)
if err != nil {
return nil, fmt.Errorf("failed to convert %s into an absolute path: %v", dir, err)
}
// Optionally, create the directory if it does not already exist.
info, err := os.Stat(absdir)
if err != nil && opts.CreateDir && os.IsNotExist(err) {
err = os.MkdirAll(absdir, opts.DirFileMode)
if err != nil {
return nil, fmt.Errorf("tried to create directory but failed: %v", err)
}
info, err = os.Stat(absdir)
}
if err != nil {
return nil, err
}
if !info.IsDir() {
return nil, fmt.Errorf("%s is not a directory", absdir)
}
return &driver{dir: absdir, opts: opts}, nil
}
type driver struct {
opts *Options
dir string
}
// Close implements [blob/Driver.Close].
func (drv *driver) Close() error {
return nil
}
// NormalizeError implements [blob/Driver.NormalizeError].
func (drv *driver) NormalizeError(err error) error {
if os.IsNotExist(err) {
return errors.Join(err, blob.ErrNotFound)
}
return err
}
// path returns the full path for a key.
func (drv *driver) path(key string) (string, error) {
path := filepath.Join(drv.dir, escapeKey(key))
if strings.HasSuffix(path, attrsExt) {
return "", errAttrsExt
}
return path, nil
}
// forKey returns the full path, os.FileInfo, and attributes for key.
func (drv *driver) forKey(key string) (string, os.FileInfo, *xattrs, error) {
path, err := drv.path(key)
if err != nil {
return "", nil, nil, err
}
info, err := os.Stat(path)
if err != nil {
return "", nil, nil, err
}
if info.IsDir() {
return "", nil, nil, os.ErrNotExist
}
xa, err := getAttrs(path)
if err != nil {
return "", nil, nil, err
}
return path, info, &xa, nil
}
// ListPaged implements [blob/Driver.ListPaged].
func (drv *driver) ListPaged(ctx context.Context, opts *blob.ListOptions) (*blob.ListPage, error) {
var pageToken string
if len(opts.PageToken) > 0 {
pageToken = string(opts.PageToken)
}
pageSize := opts.PageSize
if pageSize == 0 {
pageSize = defaultPageSize
}
// If opts.Delimiter != "", lastPrefix contains the last "directory" key we
// added. It is used to avoid adding it again; all files in this "directory"
// are collapsed to the single directory entry.
var lastPrefix string
var lastKeyAdded string
// If the Prefix contains a "/", we can set the root of the Walk
// to the path specified by the Prefix as any files below the path will not
// match the Prefix.
// Note that we use "/" explicitly and not os.PathSeparator, as the opts.Prefix
// is in the unescaped form.
root := drv.dir
if i := strings.LastIndex(opts.Prefix, "/"); i > -1 {
root = filepath.Join(root, opts.Prefix[:i])
}
var result blob.ListPage
// Do a full recursive scan of the root directory.
err := filepath.WalkDir(root, func(path string, info fs.DirEntry, err error) error {
if err != nil {
// Couldn't read this file/directory for some reason; just skip it.
return nil
}
// Skip the self-generated attribute files.
if strings.HasSuffix(path, attrsExt) {
return nil
}
// os.Walk returns the root directory; skip it.
if path == drv.dir {
return nil
}
// Strip the <drv.dir> prefix from path.
prefixLen := len(drv.dir)
// Include the separator for non-root.
if drv.dir != "/" {
prefixLen++
}
path = path[prefixLen:]
// Unescape the path to get the key.
key := unescapeKey(path)
// Skip all directories. If opts.Delimiter is set, we'll create
// pseudo-directories later.
// Note that returning nil means that we'll still recurse into it;
// we're just not adding a result for the directory itself.
if info.IsDir() {
key += "/"
// Avoid recursing into subdirectories if the directory name already
// doesn't match the prefix; any files in it are guaranteed not to match.
if len(key) > len(opts.Prefix) && !strings.HasPrefix(key, opts.Prefix) {
return filepath.SkipDir
}
// Similarly, avoid recursing into subdirectories if we're making
// "directories" and all of the files in this subdirectory are guaranteed
// to collapse to a "directory" that we've already added.
if lastPrefix != "" && strings.HasPrefix(key, lastPrefix) {
return filepath.SkipDir
}
return nil
}
// Skip files/directories that don't match the Prefix.
if !strings.HasPrefix(key, opts.Prefix) {
return nil
}
var md5 []byte
if xa, err := getAttrs(path); err == nil {
// Note: we only have the MD5 hash for blobs that we wrote.
// For other blobs, md5 will remain nil.
md5 = xa.MD5
}
fi, err := info.Info()
if err != nil {
return err
}
obj := &blob.ListObject{
Key: key,
ModTime: fi.ModTime(),
Size: fi.Size(),
MD5: md5,
}
// If using Delimiter, collapse "directories".
if opts.Delimiter != "" {
// Strip the prefix, which may contain Delimiter.
keyWithoutPrefix := key[len(opts.Prefix):]
// See if the key still contains Delimiter.
// If no, it's a file and we just include it.
// If yes, it's a file in a "sub-directory" and we want to collapse
// all files in that "sub-directory" into a single "directory" result.
if idx := strings.Index(keyWithoutPrefix, opts.Delimiter); idx != -1 {
prefix := opts.Prefix + keyWithoutPrefix[0:idx+len(opts.Delimiter)]
// We've already included this "directory"; don't add it.
if prefix == lastPrefix {
return nil
}
// Update the object to be a "directory".
obj = &blob.ListObject{
Key: prefix,
IsDir: true,
}
lastPrefix = prefix
}
}
// If there's a pageToken, skip anything before it.
if pageToken != "" && obj.Key <= pageToken {
return nil
}
// If we've already got a full page of results, set NextPageToken and stop.
// Unless the current object is a directory, in which case there may
// still be objects coming that are alphabetically before it (since
// we appended the delimiter). In that case, keep going; we'll trim the
// extra entries (if any) before returning.
if len(result.Objects) == pageSize && !obj.IsDir {
result.NextPageToken = []byte(result.Objects[pageSize-1].Key)
return io.EOF
}
result.Objects = append(result.Objects, obj)
// Normally, objects are added in the correct order (by Key).
// However, sometimes adding the file delimiter messes that up
// (e.g., if the file delimiter is later in the alphabet than the last character of a key).
// Detect if this happens and swap if needed.
if len(result.Objects) > 1 && obj.Key < lastKeyAdded {
i := len(result.Objects) - 1
result.Objects[i-1], result.Objects[i] = result.Objects[i], result.Objects[i-1]
lastKeyAdded = result.Objects[i].Key
} else {
lastKeyAdded = obj.Key
}
return nil
})
if err != nil && err != io.EOF {
return nil, err
}
if len(result.Objects) > pageSize {
result.Objects = result.Objects[0:pageSize]
result.NextPageToken = []byte(result.Objects[pageSize-1].Key)
}
return &result, nil
}
// Attributes implements [blob/Driver.Attributes].
func (drv *driver) Attributes(ctx context.Context, key string) (*blob.Attributes, error) {
_, info, xa, err := drv.forKey(key)
if err != nil {
return nil, err
}
return &blob.Attributes{
CacheControl: xa.CacheControl,
ContentDisposition: xa.ContentDisposition,
ContentEncoding: xa.ContentEncoding,
ContentLanguage: xa.ContentLanguage,
ContentType: xa.ContentType,
Metadata: xa.Metadata,
// CreateTime left as the zero time.
ModTime: info.ModTime(),
Size: info.Size(),
MD5: xa.MD5,
ETag: fmt.Sprintf("\"%x-%x\"", info.ModTime().UnixNano(), info.Size()),
}, nil
}
// NewRangeReader implements [blob/Driver.NewRangeReader].
func (drv *driver) NewRangeReader(ctx context.Context, key string, offset, length int64) (blob.DriverReader, error) {
path, info, xa, err := drv.forKey(key)
if err != nil {
return nil, err
}
f, err := os.Open(path)
if err != nil {
return nil, err
}
if offset > 0 {
if _, err := f.Seek(offset, io.SeekStart); err != nil {
return nil, err
}
}
r := io.Reader(f)
if length >= 0 {
r = io.LimitReader(r, length)
}
return &reader{
r: r,
c: f,
attrs: &blob.ReaderAttributes{
ContentType: xa.ContentType,
ModTime: info.ModTime(),
Size: info.Size(),
},
}, nil
}
func createTemp(path string, noTempDir bool) (*os.File, error) {
// Use a custom createTemp function rather than os.CreateTemp() as
// os.CreateTemp() sets the permissions of the tempfile to 0600, rather than
// 0666, making it inconsistent with the directories and attribute files.
try := 0
for {
// Append the current time with nanosecond precision and .tmp to the
// base path. If the file already exists try again. Nanosecond changes enough
// between each iteration to make a conflict unlikely. Using the full
// time lowers the chance of a collision with a file using a similar
// pattern, but has undefined behavior after the year 2262.
var name string
if noTempDir {
name = path
} else {
name = filepath.Join(os.TempDir(), filepath.Base(path))
}
name += "." + strconv.FormatInt(time.Now().UnixNano(), 16) + ".tmp"
f, err := os.OpenFile(name, os.O_RDWR|os.O_CREATE|os.O_EXCL, 0o666)
if os.IsExist(err) {
if try++; try < 10000 {
continue
}
return nil, &os.PathError{Op: "createtemp", Path: path + ".*.tmp", Err: os.ErrExist}
}
return f, err
}
}
// NewTypedWriter implements [blob/Driver.NewTypedWriter].
func (drv *driver) NewTypedWriter(ctx context.Context, key, contentType string, opts *blob.WriterOptions) (blob.DriverWriter, error) {
path, err := drv.path(key)
if err != nil {
return nil, err
}
err = os.MkdirAll(filepath.Dir(path), drv.opts.DirFileMode)
if err != nil {
return nil, err
}
f, err := createTemp(path, drv.opts.NoTempDir)
if err != nil {
return nil, err
}
if drv.opts.Metadata == MetadataDontWrite {
w := &writer{
ctx: ctx,
File: f,
path: path,
}
return w, nil
}
var metadata map[string]string
if len(opts.Metadata) > 0 {
metadata = opts.Metadata
}
return &writerWithSidecar{
ctx: ctx,
f: f,
path: path,
contentMD5: opts.ContentMD5,
md5hash: md5.New(),
attrs: xattrs{
CacheControl: opts.CacheControl,
ContentDisposition: opts.ContentDisposition,
ContentEncoding: opts.ContentEncoding,
ContentLanguage: opts.ContentLanguage,
ContentType: contentType,
Metadata: metadata,
},
}, nil
}
// Copy implements [blob/Driver.Copy].
func (drv *driver) Copy(ctx context.Context, dstKey, srcKey string) error {
// Note: we could use NewRangeReader here, but since we need to copy all of
// the metadata (from xa), it's more efficient to do it directly.
srcPath, _, xa, err := drv.forKey(srcKey)
if err != nil {
return err
}
f, err := os.Open(srcPath)
if err != nil {
return err
}
defer f.Close()
// We'll write the copy using Writer, to avoid re-implementing making of a
// temp file, cleaning up after partial failures, etc.
wopts := blob.WriterOptions{
CacheControl: xa.CacheControl,
ContentDisposition: xa.ContentDisposition,
ContentEncoding: xa.ContentEncoding,
ContentLanguage: xa.ContentLanguage,
Metadata: xa.Metadata,
}
// Create a cancelable context so we can cancel the write if there are problems.
writeCtx, cancel := context.WithCancel(ctx)
defer cancel()
w, err := drv.NewTypedWriter(writeCtx, dstKey, xa.ContentType, &wopts)
if err != nil {
return err
}
_, err = io.Copy(w, f)
if err != nil {
cancel() // cancel before Close cancels the write
w.Close()
return err
}
return w.Close()
}
// Delete implements [blob/Driver.Delete].
func (b *driver) Delete(ctx context.Context, key string) error {
path, err := b.path(key)
if err != nil {
return err
}
err = os.Remove(path)
if err != nil {
return err
}
err = os.Remove(path + attrsExt)
if err != nil && !os.IsNotExist(err) {
return err
}
return nil
}
// -------------------------------------------------------------------
type reader struct {
r io.Reader
c io.Closer
attrs *blob.ReaderAttributes
}
func (r *reader) Read(p []byte) (int, error) {
if r.r == nil {
return 0, io.EOF
}
return r.r.Read(p)
}
func (r *reader) Close() error {
if r.c == nil {
return nil
}
return r.c.Close()
}
// Attributes implements [blob/DriverReader.Attributes].
func (r *reader) Attributes() *blob.ReaderAttributes {
return r.attrs
}
// -------------------------------------------------------------------
// writerWithSidecar implements the strategy of storing metadata in a distinct file.
type writerWithSidecar struct {
ctx context.Context
md5hash hash.Hash
f *os.File
path string
attrs xattrs
contentMD5 []byte
}
func (w *writerWithSidecar) Write(p []byte) (n int, err error) {
n, err = w.f.Write(p)
if err != nil {
// Don't hash the unwritten tail twice when writing is resumed.
w.md5hash.Write(p[:n])
return n, err
}
if _, err := w.md5hash.Write(p); err != nil {
return n, err
}
return n, nil
}
func (w *writerWithSidecar) Close() error {
err := w.f.Close()
if err != nil {
return err
}
// Always delete the temp file. On success, it will have been
// renamed so the Remove will fail.
defer func() {
_ = os.Remove(w.f.Name())
}()
// Check if the write was cancelled.
if err := w.ctx.Err(); err != nil {
return err
}
md5sum := w.md5hash.Sum(nil)
w.attrs.MD5 = md5sum
// Write the attributes file.
if err := setAttrs(w.path, w.attrs); err != nil {
return err
}
// Rename the temp file to path.
if err := os.Rename(w.f.Name(), w.path); err != nil {
_ = os.Remove(w.path + attrsExt)
return err
}
return nil
}
// writer is a file with a temporary name until closed.
//
// Embedding os.File allows the likes of io.Copy to use optimizations,
// which is why it is not folded into writerWithSidecar.
type writer struct {
*os.File
ctx context.Context
path string
}
func (w *writer) Close() error {
err := w.File.Close()
if err != nil {
return err
}
// Always delete the temp file. On success, it will have been renamed so
// the Remove will fail.
tempname := w.Name()
defer os.Remove(tempname)
// Check if the write was cancelled.
if err := w.ctx.Err(); err != nil {
return err
}
// Rename the temp file to path.
return os.Rename(tempname, w.path)
}
// -------------------------------------------------------------------
// escapeKey does all required escaping for UTF-8 strings to work the filesystem.
func escapeKey(s string) string {
s = blob.HexEscape(s, func(r []rune, i int) bool {
c := r[i]
switch {
case c < 32:
return true
// We're going to replace '/' with os.PathSeparator below. In order for this
// to be reversible, we need to escape raw os.PathSeparators.
case os.PathSeparator != '/' && c == os.PathSeparator:
return true
// For "../", escape the trailing slash.
case i > 1 && c == '/' && r[i-1] == '.' && r[i-2] == '.':
return true
// For "//", escape the trailing slash.
case i > 0 && c == '/' && r[i-1] == '/':
return true
// Escape the trailing slash in a key.
case c == '/' && i == len(r)-1:
return true
// https://docs.microsoft.com/en-us/windows/desktop/fileio/naming-a-file
case os.PathSeparator == '\\' && (c == '>' || c == '<' || c == ':' || c == '"' || c == '|' || c == '?' || c == '*'):
return true
}
return false
})
// Replace "/" with os.PathSeparator if needed, so that the local filesystem
// can use subdirectories.
if os.PathSeparator != '/' {
s = strings.ReplaceAll(s, "/", string(os.PathSeparator))
}
return s
}
// unescapeKey reverses escapeKey.
func unescapeKey(s string) string {
if os.PathSeparator != '/' {
s = strings.ReplaceAll(s, string(os.PathSeparator), "/")
}
return blob.HexUnescape(s)
}

View file

@ -0,0 +1,59 @@
package s3
import (
"context"
"encoding/xml"
"net/http"
"net/url"
"strings"
"time"
)
// https://docs.aws.amazon.com/AmazonS3/latest/API/API_CopyObject.html#API_CopyObject_ResponseSyntax
type CopyObjectResponse struct {
CopyObjectResult xml.Name `json:"copyObjectResult" xml:"CopyObjectResult"`
ETag string `json:"etag" xml:"ETag"`
LastModified time.Time `json:"lastModified" xml:"LastModified"`
ChecksumType string `json:"checksumType" xml:"ChecksumType"`
ChecksumCRC32 string `json:"checksumCRC32" xml:"ChecksumCRC32"`
ChecksumCRC32C string `json:"checksumCRC32C" xml:"ChecksumCRC32C"`
ChecksumCRC64NVME string `json:"checksumCRC64NVME" xml:"ChecksumCRC64NVME"`
ChecksumSHA1 string `json:"checksumSHA1" xml:"ChecksumSHA1"`
ChecksumSHA256 string `json:"checksumSHA256" xml:"ChecksumSHA256"`
}
// CopyObject copies a single object from srcKey to dstKey destination.
// (both keys are expected to be operating within the same bucket).
//
// https://docs.aws.amazon.com/AmazonS3/latest/API/API_CopyObject.html
func (s3 *S3) CopyObject(ctx context.Context, srcKey string, dstKey string, optReqFuncs ...func(*http.Request)) (*CopyObjectResponse, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodPut, s3.URL(dstKey), nil)
if err != nil {
return nil, err
}
// per the doc the header value must be URL-encoded
req.Header.Set("x-amz-copy-source", url.PathEscape(s3.Bucket+"/"+strings.TrimLeft(srcKey, "/")))
// apply optional request funcs
for _, fn := range optReqFuncs {
if fn != nil {
fn(req)
}
}
resp, err := s3.SignAndSend(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
result := &CopyObjectResponse{}
err = xml.NewDecoder(resp.Body).Decode(result)
if err != nil {
return nil, err
}
return result, nil
}

View file

@ -0,0 +1,67 @@
package s3_test
import (
"context"
"io"
"net/http"
"strings"
"testing"
"github.com/pocketbase/pocketbase/tools/filesystem/internal/s3blob/s3"
"github.com/pocketbase/pocketbase/tools/filesystem/internal/s3blob/s3/tests"
)
func TestS3CopyObject(t *testing.T) {
t.Parallel()
httpClient := tests.NewClient(
&tests.RequestStub{
Method: http.MethodPut,
URL: "http://test_bucket.example.com/@dst_test",
Match: func(req *http.Request) bool {
return tests.ExpectHeaders(req.Header, map[string]string{
"test_header": "test",
"x-amz-copy-source": "test_bucket%2F@src_test",
"Authorization": "^.+Credential=123/.+$",
})
},
Response: &http.Response{
Body: io.NopCloser(strings.NewReader(`
<CopyObjectResult>
<LastModified>2025-01-01T01:02:03.456Z</LastModified>
<ETag>test_etag</ETag>
</CopyObjectResult>
`)),
},
},
)
s3Client := &s3.S3{
Client: httpClient,
Region: "test_region",
Bucket: "test_bucket",
Endpoint: "http://example.com",
AccessKey: "123",
SecretKey: "abc",
}
copyResp, err := s3Client.CopyObject(context.Background(), "@src_test", "@dst_test", func(r *http.Request) {
r.Header.Set("test_header", "test")
})
if err != nil {
t.Fatal(err)
}
err = httpClient.AssertNoRemaining()
if err != nil {
t.Fatal(err)
}
if copyResp.ETag != "test_etag" {
t.Fatalf("Expected ETag %q, got %q", "test_etag", copyResp.ETag)
}
if date := copyResp.LastModified.Format("2006-01-02T15:04:05.000Z"); date != "2025-01-01T01:02:03.456Z" {
t.Fatalf("Expected LastModified %q, got %q", "2025-01-01T01:02:03.456Z", date)
}
}

View file

@ -0,0 +1,31 @@
package s3
import (
"context"
"net/http"
)
// DeleteObject deletes a single object by its key.
//
// https://docs.aws.amazon.com/AmazonS3/latest/API/API_DeleteObject.html
func (s3 *S3) DeleteObject(ctx context.Context, key string, optFuncs ...func(*http.Request)) error {
req, err := http.NewRequestWithContext(ctx, http.MethodDelete, s3.URL(key), nil)
if err != nil {
return err
}
// apply optional request funcs
for _, fn := range optFuncs {
if fn != nil {
fn(req)
}
}
resp, err := s3.SignAndSend(req)
if err != nil {
return err
}
defer resp.Body.Close()
return nil
}

View file

@ -0,0 +1,48 @@
package s3_test
import (
"context"
"net/http"
"testing"
"github.com/pocketbase/pocketbase/tools/filesystem/internal/s3blob/s3"
"github.com/pocketbase/pocketbase/tools/filesystem/internal/s3blob/s3/tests"
)
func TestS3DeleteObject(t *testing.T) {
t.Parallel()
httpClient := tests.NewClient(
&tests.RequestStub{
Method: http.MethodDelete,
URL: "http://test_bucket.example.com/test_key",
Match: func(req *http.Request) bool {
return tests.ExpectHeaders(req.Header, map[string]string{
"test_header": "test",
"Authorization": "^.+Credential=123/.+$",
})
},
},
)
s3Client := &s3.S3{
Client: httpClient,
Region: "test_region",
Bucket: "test_bucket",
Endpoint: "http://example.com",
AccessKey: "123",
SecretKey: "abc",
}
err := s3Client.DeleteObject(context.Background(), "test_key", func(r *http.Request) {
r.Header.Set("test_header", "test")
})
if err != nil {
t.Fatal(err)
}
err = httpClient.AssertNoRemaining()
if err != nil {
t.Fatal(err)
}
}

View file

@ -0,0 +1,49 @@
package s3
import (
"encoding/xml"
"strconv"
"strings"
)
var _ error = (*ResponseError)(nil)
// ResponseError defines a general S3 response error.
//
// https://docs.aws.amazon.com/AmazonS3/latest/API/ErrorResponses.html
type ResponseError struct {
XMLName xml.Name `json:"-" xml:"Error"`
Code string `json:"code" xml:"Code"`
Message string `json:"message" xml:"Message"`
RequestId string `json:"requestId" xml:"RequestId"`
Resource string `json:"resource" xml:"Resource"`
Raw []byte `json:"-" xml:"-"`
Status int `json:"status" xml:"Status"`
}
// Error implements the std error interface.
func (err *ResponseError) Error() string {
var strBuilder strings.Builder
strBuilder.WriteString(strconv.Itoa(err.Status))
strBuilder.WriteString(" ")
if err.Code != "" {
strBuilder.WriteString(err.Code)
} else {
strBuilder.WriteString("S3ResponseError")
}
if err.Message != "" {
strBuilder.WriteString(": ")
strBuilder.WriteString(err.Message)
}
if len(err.Raw) > 0 {
strBuilder.WriteString("\n(RAW: ")
strBuilder.Write(err.Raw)
strBuilder.WriteString(")")
}
return strBuilder.String()
}

View file

@ -0,0 +1,86 @@
package s3_test
import (
"encoding/json"
"encoding/xml"
"testing"
"github.com/pocketbase/pocketbase/tools/filesystem/internal/s3blob/s3"
)
func TestResponseErrorSerialization(t *testing.T) {
raw := `
<?xml version="1.0" encoding="UTF-8"?>
<Error>
<Code>test_code</Code>
<Message>test_message</Message>
<RequestId>test_request_id</RequestId>
<Resource>test_resource</Resource>
</Error>
`
respErr := &s3.ResponseError{
Status: 123,
Raw: []byte("test"),
}
err := xml.Unmarshal([]byte(raw), &respErr)
if err != nil {
t.Fatal(err)
}
jsonRaw, err := json.Marshal(respErr)
if err != nil {
t.Fatal(err)
}
jsonStr := string(jsonRaw)
expected := `{"code":"test_code","message":"test_message","requestId":"test_request_id","resource":"test_resource","status":123}`
if expected != jsonStr {
t.Fatalf("Expected JSON\n%s\ngot\n%s", expected, jsonStr)
}
}
func TestResponseErrorErrorInterface(t *testing.T) {
scenarios := []struct {
name string
err *s3.ResponseError
expected string
}{
{
"empty",
&s3.ResponseError{},
"0 S3ResponseError",
},
{
"with code and message (nil raw)",
&s3.ResponseError{
Status: 123,
Code: "test_code",
Message: "test_message",
},
"123 test_code: test_message",
},
{
"with code and message (non-nil raw)",
&s3.ResponseError{
Status: 123,
Code: "test_code",
Message: "test_message",
Raw: []byte("test_raw"),
},
"123 test_code: test_message\n(RAW: test_raw)",
},
}
for _, s := range scenarios {
t.Run(s.name, func(t *testing.T) {
result := s.err.Error()
if result != s.expected {
t.Fatalf("Expected\n%s\ngot\n%s", s.expected, result)
}
})
}
}

View file

@ -0,0 +1,43 @@
package s3
import (
"context"
"io"
"net/http"
)
// https://docs.aws.amazon.com/AmazonS3/latest/API/API_GetObject.html#API_GetObject_ResponseElements
type GetObjectResponse struct {
Body io.ReadCloser `json:"-" xml:"-"`
HeadObjectResponse
}
// GetObject retrieves a single object by its key.
//
// NB! Make sure to call GetObjectResponse.Body.Close() after done working with the result.
//
// https://docs.aws.amazon.com/AmazonS3/latest/API/API_GetObject.html
func (s3 *S3) GetObject(ctx context.Context, key string, optFuncs ...func(*http.Request)) (*GetObjectResponse, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, s3.URL(key), nil)
if err != nil {
return nil, err
}
// apply optional request funcs
for _, fn := range optFuncs {
if fn != nil {
fn(req)
}
}
resp, err := s3.SignAndSend(req)
if err != nil {
return nil, err
}
result := &GetObjectResponse{Body: resp.Body}
result.load(resp.Header)
return result, nil
}

View file

@ -0,0 +1,92 @@
package s3_test
import (
"context"
"encoding/json"
"io"
"net/http"
"strings"
"testing"
"github.com/pocketbase/pocketbase/tools/filesystem/internal/s3blob/s3"
"github.com/pocketbase/pocketbase/tools/filesystem/internal/s3blob/s3/tests"
)
func TestS3GetObject(t *testing.T) {
t.Parallel()
httpClient := tests.NewClient(
&tests.RequestStub{
Method: http.MethodGet,
URL: "http://test_bucket.example.com/test_key",
Match: func(req *http.Request) bool {
return tests.ExpectHeaders(req.Header, map[string]string{
"test_header": "test",
"Authorization": "^.+Credential=123/.+$",
})
},
Response: &http.Response{
Header: http.Header{
"Last-Modified": []string{"Mon, 01 Feb 2025 03:04:05 GMT"},
"Cache-Control": []string{"test_cache"},
"Content-Disposition": []string{"test_disposition"},
"Content-Encoding": []string{"test_encoding"},
"Content-Language": []string{"test_language"},
"Content-Type": []string{"test_type"},
"Content-Range": []string{"test_range"},
"Etag": []string{"test_etag"},
"Content-Length": []string{"100"},
"x-amz-meta-AbC": []string{"test_meta_a"},
"x-amz-meta-Def": []string{"test_meta_b"},
},
Body: io.NopCloser(strings.NewReader("test")),
},
},
)
s3Client := &s3.S3{
Client: httpClient,
Region: "test_region",
Bucket: "test_bucket",
Endpoint: "http://example.com",
AccessKey: "123",
SecretKey: "abc",
}
resp, err := s3Client.GetObject(context.Background(), "test_key", func(r *http.Request) {
r.Header.Set("test_header", "test")
})
if err != nil {
t.Fatal(err)
}
defer resp.Body.Close()
err = httpClient.AssertNoRemaining()
if err != nil {
t.Fatal(err)
}
// check body
body, err := io.ReadAll(resp.Body)
if err != nil {
t.Fatal(err)
}
bodyStr := string(body)
if bodyStr != "test" {
t.Fatalf("Expected body\n%q\ngot\n%q", "test", bodyStr)
}
// check serialized attributes
raw, err := json.Marshal(resp)
if err != nil {
t.Fatal(err)
}
rawStr := string(raw)
expected := `{"metadata":{"abc":"test_meta_a","def":"test_meta_b"},"lastModified":"2025-02-01T03:04:05Z","cacheControl":"test_cache","contentDisposition":"test_disposition","contentEncoding":"test_encoding","contentLanguage":"test_language","contentType":"test_type","contentRange":"test_range","etag":"test_etag","contentLength":100}`
if rawStr != expected {
t.Fatalf("Expected attributes\n%s\ngot\n%s", expected, rawStr)
}
}

View file

@ -0,0 +1,89 @@
package s3
import (
"context"
"net/http"
"strconv"
"time"
)
// https://docs.aws.amazon.com/AmazonS3/latest/API/API_HeadObject.html#API_HeadObject_ResponseElements
type HeadObjectResponse struct {
// Metadata is the extra data that is stored with the S3 object (aka. the "x-amz-meta-*" header values).
//
// The map keys are normalized to lower-case.
Metadata map[string]string `json:"metadata"`
// LastModified date and time when the object was last modified.
LastModified time.Time `json:"lastModified"`
// CacheControl specifies caching behavior along the request/reply chain.
CacheControl string `json:"cacheControl"`
// ContentDisposition specifies presentational information for the object.
ContentDisposition string `json:"contentDisposition"`
// ContentEncoding indicates what content encodings have been applied to the object
// and thus what decoding mechanisms must be applied to obtain the
// media-type referenced by the Content-Type header field.
ContentEncoding string `json:"contentEncoding"`
// ContentLanguage indicates the language the content is in.
ContentLanguage string `json:"contentLanguage"`
// ContentType is a standard MIME type describing the format of the object data.
ContentType string `json:"contentType"`
// ContentRange is the portion of the object usually returned in the response for a GET request.
ContentRange string `json:"contentRange"`
// ETag is an opaque identifier assigned by a web
// server to a specific version of a resource found at a URL.
ETag string `json:"etag"`
// ContentLength is size of the body in bytes.
ContentLength int64 `json:"contentLength"`
}
// load parses and load the header values into the current HeadObjectResponse fields.
func (o *HeadObjectResponse) load(headers http.Header) {
o.LastModified, _ = time.Parse(time.RFC1123, headers.Get("Last-Modified"))
o.CacheControl = headers.Get("Cache-Control")
o.ContentDisposition = headers.Get("Content-Disposition")
o.ContentEncoding = headers.Get("Content-Encoding")
o.ContentLanguage = headers.Get("Content-Language")
o.ContentType = headers.Get("Content-Type")
o.ContentRange = headers.Get("Content-Range")
o.ETag = headers.Get("ETag")
o.ContentLength, _ = strconv.ParseInt(headers.Get("Content-Length"), 10, 0)
o.Metadata = extractMetadata(headers)
}
// HeadObject sends a HEAD request for a single object to check its
// existence and to retrieve its metadata.
//
// https://docs.aws.amazon.com/AmazonS3/latest/API/API_HeadObject.html
func (s3 *S3) HeadObject(ctx context.Context, key string, optFuncs ...func(*http.Request)) (*HeadObjectResponse, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodHead, s3.URL(key), nil)
if err != nil {
return nil, err
}
// apply optional request funcs
for _, fn := range optFuncs {
if fn != nil {
fn(req)
}
}
resp, err := s3.SignAndSend(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
result := &HeadObjectResponse{}
result.load(resp.Header)
return result, nil
}

View file

@ -0,0 +1,77 @@
package s3_test
import (
"context"
"encoding/json"
"net/http"
"testing"
"github.com/pocketbase/pocketbase/tools/filesystem/internal/s3blob/s3"
"github.com/pocketbase/pocketbase/tools/filesystem/internal/s3blob/s3/tests"
)
func TestS3HeadObject(t *testing.T) {
t.Parallel()
httpClient := tests.NewClient(
&tests.RequestStub{
Method: http.MethodHead,
URL: "http://test_bucket.example.com/test_key",
Match: func(req *http.Request) bool {
return tests.ExpectHeaders(req.Header, map[string]string{
"test_header": "test",
"Authorization": "^.+Credential=123/.+$",
})
},
Response: &http.Response{
Header: http.Header{
"Last-Modified": []string{"Mon, 01 Feb 2025 03:04:05 GMT"},
"Cache-Control": []string{"test_cache"},
"Content-Disposition": []string{"test_disposition"},
"Content-Encoding": []string{"test_encoding"},
"Content-Language": []string{"test_language"},
"Content-Type": []string{"test_type"},
"Content-Range": []string{"test_range"},
"Etag": []string{"test_etag"},
"Content-Length": []string{"100"},
"x-amz-meta-AbC": []string{"test_meta_a"},
"x-amz-meta-Def": []string{"test_meta_b"},
},
Body: http.NoBody,
},
},
)
s3Client := &s3.S3{
Client: httpClient,
Region: "test_region",
Bucket: "test_bucket",
Endpoint: "http://example.com",
AccessKey: "123",
SecretKey: "abc",
}
resp, err := s3Client.HeadObject(context.Background(), "test_key", func(r *http.Request) {
r.Header.Set("test_header", "test")
})
if err != nil {
t.Fatal(err)
}
err = httpClient.AssertNoRemaining()
if err != nil {
t.Fatal(err)
}
raw, err := json.Marshal(resp)
if err != nil {
t.Fatal(err)
}
rawStr := string(raw)
expected := `{"metadata":{"abc":"test_meta_a","def":"test_meta_b"},"lastModified":"2025-02-01T03:04:05Z","cacheControl":"test_cache","contentDisposition":"test_disposition","contentEncoding":"test_encoding","contentLanguage":"test_language","contentType":"test_type","contentRange":"test_range","etag":"test_etag","contentLength":100}`
if rawStr != expected {
t.Fatalf("Expected response\n%s\ngot\n%s", expected, rawStr)
}
}

View file

@ -0,0 +1,165 @@
package s3
import (
"context"
"encoding/xml"
"net/http"
"net/url"
"strconv"
"time"
)
// ListParams defines optional parameters for the ListObject request.
type ListParams struct {
// ContinuationToken indicates that the list is being continued on this bucket with a token.
// ContinuationToken is obfuscated and is not a real key.
// You can use this ContinuationToken for pagination of the list results.
ContinuationToken string `json:"continuationToken"`
// Delimiter is a character that you use to group keys.
//
// For directory buckets, "/" is the only supported delimiter.
Delimiter string `json:"delimiter"`
// Prefix limits the response to keys that begin with the specified prefix.
Prefix string `json:"prefix"`
// Encoding type is used to encode the object keys in the response.
// Responses are encoded only in UTF-8.
// An object key can contain any Unicode character.
// However, the XML 1.0 parser can't parse certain characters,
// such as characters with an ASCII value from 0 to 10.
// For characters that aren't supported in XML 1.0, you can add
// this parameter to request that S3 encode the keys in the response.
//
// Valid Values: url
EncodingType string `json:"encodingType"`
// StartAfter is where you want S3 to start listing from.
// S3 starts listing after this specified key.
// StartAfter can be any key in the bucket.
//
// This functionality is not supported for directory buckets.
StartAfter string `json:"startAfter"`
// MaxKeys Sets the maximum number of keys returned in the response.
// By default, the action returns up to 1,000 key names.
// The response might contain fewer keys but will never contain more.
MaxKeys int `json:"maxKeys"`
// FetchOwner returns the owner field with each key in the result.
FetchOwner bool `json:"fetchOwner"`
}
// Encode encodes the parameters in a properly formatted query string.
func (l *ListParams) Encode() string {
query := url.Values{}
query.Add("list-type", "2")
if l.ContinuationToken != "" {
query.Add("continuation-token", l.ContinuationToken)
}
if l.Delimiter != "" {
query.Add("delimiter", l.Delimiter)
}
if l.Prefix != "" {
query.Add("prefix", l.Prefix)
}
if l.EncodingType != "" {
query.Add("encoding-type", l.EncodingType)
}
if l.FetchOwner {
query.Add("fetch-owner", "true")
}
if l.MaxKeys > 0 {
query.Add("max-keys", strconv.Itoa(l.MaxKeys))
}
if l.StartAfter != "" {
query.Add("start-after", l.StartAfter)
}
return query.Encode()
}
// ListObjects retrieves paginated objects list.
//
// https://docs.aws.amazon.com/AmazonS3/latest/API/API_ListObjectsV2.html
func (s3 *S3) ListObjects(ctx context.Context, params ListParams, optReqFuncs ...func(*http.Request)) (*ListObjectsResponse, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, s3.URL("?"+params.Encode()), nil)
if err != nil {
return nil, err
}
// apply optional request funcs
for _, fn := range optReqFuncs {
if fn != nil {
fn(req)
}
}
resp, err := s3.SignAndSend(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
result := &ListObjectsResponse{}
err = xml.NewDecoder(resp.Body).Decode(result)
if err != nil {
return nil, err
}
return result, nil
}
// https://docs.aws.amazon.com/AmazonS3/latest/API/API_ListObjectsV2.html#API_ListObjectsV2_ResponseSyntax
type ListObjectsResponse struct {
XMLName xml.Name `json:"-" xml:"ListBucketResult"`
EncodingType string `json:"encodingType" xml:"EncodingType"`
Name string `json:"name" xml:"Name"`
Prefix string `json:"prefix" xml:"Prefix"`
Delimiter string `json:"delimiter" xml:"Delimiter"`
ContinuationToken string `json:"continuationToken" xml:"ContinuationToken"`
NextContinuationToken string `json:"nextContinuationToken" xml:"NextContinuationToken"`
StartAfter string `json:"startAfter" xml:"StartAfter"`
CommonPrefixes []*ListObjectCommonPrefix `json:"commonPrefixes" xml:"CommonPrefixes"`
Contents []*ListObjectContent `json:"contents" xml:"Contents"`
KeyCount int `json:"keyCount" xml:"KeyCount"`
MaxKeys int `json:"maxKeys" xml:"MaxKeys"`
IsTruncated bool `json:"isTruncated" xml:"IsTruncated"`
}
type ListObjectCommonPrefix struct {
Prefix string `json:"prefix" xml:"Prefix"`
}
type ListObjectContent struct {
Owner struct {
DisplayName string `json:"displayName" xml:"DisplayName"`
ID string `json:"id" xml:"ID"`
} `json:"owner" xml:"Owner"`
ChecksumAlgorithm string `json:"checksumAlgorithm" xml:"ChecksumAlgorithm"`
ETag string `json:"etag" xml:"ETag"`
Key string `json:"key" xml:"Key"`
StorageClass string `json:"storageClass" xml:"StorageClass"`
LastModified time.Time `json:"lastModified" xml:"LastModified"`
RestoreStatus struct {
RestoreExpiryDate time.Time `json:"restoreExpiryDate" xml:"RestoreExpiryDate"`
IsRestoreInProgress bool `json:"isRestoreInProgress" xml:"IsRestoreInProgress"`
} `json:"restoreStatus" xml:"RestoreStatus"`
Size int64 `json:"size" xml:"Size"`
}

View file

@ -0,0 +1,157 @@
package s3_test
import (
"context"
"encoding/json"
"io"
"net/http"
"strings"
"testing"
"github.com/pocketbase/pocketbase/tools/filesystem/internal/s3blob/s3"
"github.com/pocketbase/pocketbase/tools/filesystem/internal/s3blob/s3/tests"
)
func TestS3ListParamsEncode(t *testing.T) {
t.Parallel()
scenarios := []struct {
name string
params s3.ListParams
expected string
}{
{
"blank",
s3.ListParams{},
"list-type=2",
},
{
"filled",
s3.ListParams{
ContinuationToken: "test_ct",
Delimiter: "test_delimiter",
Prefix: "test_prefix",
EncodingType: "test_et",
StartAfter: "test_sa",
MaxKeys: 1,
FetchOwner: true,
},
"continuation-token=test_ct&delimiter=test_delimiter&encoding-type=test_et&fetch-owner=true&list-type=2&max-keys=1&prefix=test_prefix&start-after=test_sa",
},
}
for _, s := range scenarios {
t.Run(s.name, func(t *testing.T) {
result := s.params.Encode()
if result != s.expected {
t.Fatalf("Expected\n%s\ngot\n%s", s.expected, result)
}
})
}
}
func TestS3ListObjects(t *testing.T) {
t.Parallel()
listParams := s3.ListParams{
ContinuationToken: "test_ct",
Delimiter: "test_delimiter",
Prefix: "test_prefix",
EncodingType: "test_et",
StartAfter: "test_sa",
MaxKeys: 10,
FetchOwner: true,
}
httpClient := tests.NewClient(
&tests.RequestStub{
Method: http.MethodGet,
URL: "http://test_bucket.example.com/?" + listParams.Encode(),
Match: func(req *http.Request) bool {
return tests.ExpectHeaders(req.Header, map[string]string{
"test_header": "test",
"Authorization": "^.+Credential=123/.+$",
})
},
Response: &http.Response{
Body: io.NopCloser(strings.NewReader(`
<?xml version="1.0" encoding="UTF-8"?>
<ListBucketResult xmlns="http://s3.amazonaws.com/doc/2006-03-01/">
<Name>example</Name>
<EncodingType>test_encoding</EncodingType>
<Prefix>a/</Prefix>
<Delimiter>/</Delimiter>
<ContinuationToken>ct</ContinuationToken>
<NextContinuationToken>nct</NextContinuationToken>
<StartAfter>example0.txt</StartAfter>
<KeyCount>1</KeyCount>
<MaxKeys>3</MaxKeys>
<IsTruncated>true</IsTruncated>
<Contents>
<Key>example1.txt</Key>
<LastModified>2025-01-01T01:02:03.123Z</LastModified>
<ChecksumAlgorithm>test_ca</ChecksumAlgorithm>
<ETag>test_etag1</ETag>
<Size>123</Size>
<StorageClass>STANDARD</StorageClass>
<Owner>
<DisplayName>owner_dn</DisplayName>
<ID>owner_id</ID>
</Owner>
<RestoreStatus>
<RestoreExpiryDate>2025-01-02T01:02:03.123Z</RestoreExpiryDate>
<IsRestoreInProgress>true</IsRestoreInProgress>
</RestoreStatus>
</Contents>
<Contents>
<Key>example2.txt</Key>
<LastModified>2025-01-02T01:02:03.123Z</LastModified>
<ETag>test_etag2</ETag>
<Size>456</Size>
<StorageClass>STANDARD</StorageClass>
</Contents>
<CommonPrefixes>
<Prefix>a/b/</Prefix>
</CommonPrefixes>
<CommonPrefixes>
<Prefix>a/c/</Prefix>
</CommonPrefixes>
</ListBucketResult>
`)),
},
},
)
s3Client := &s3.S3{
Client: httpClient,
Region: "test_region",
Bucket: "test_bucket",
Endpoint: "http://example.com",
AccessKey: "123",
SecretKey: "abc",
}
resp, err := s3Client.ListObjects(context.Background(), listParams, func(r *http.Request) {
r.Header.Set("test_header", "test")
})
if err != nil {
t.Fatal(err)
}
err = httpClient.AssertNoRemaining()
if err != nil {
t.Fatal(err)
}
raw, err := json.Marshal(resp)
if err != nil {
t.Fatal(err)
}
rawStr := string(raw)
expected := `{"encodingType":"test_encoding","name":"example","prefix":"a/","delimiter":"/","continuationToken":"ct","nextContinuationToken":"nct","startAfter":"example0.txt","commonPrefixes":[{"prefix":"a/b/"},{"prefix":"a/c/"}],"contents":[{"owner":{"displayName":"owner_dn","id":"owner_id"},"checksumAlgorithm":"test_ca","etag":"test_etag1","key":"example1.txt","storageClass":"STANDARD","lastModified":"2025-01-01T01:02:03.123Z","restoreStatus":{"restoreExpiryDate":"2025-01-02T01:02:03.123Z","isRestoreInProgress":true},"size":123},{"owner":{"displayName":"","id":""},"checksumAlgorithm":"","etag":"test_etag2","key":"example2.txt","storageClass":"STANDARD","lastModified":"2025-01-02T01:02:03.123Z","restoreStatus":{"restoreExpiryDate":"0001-01-01T00:00:00Z","isRestoreInProgress":false},"size":456}],"keyCount":1,"maxKeys":3,"isTruncated":true}`
if rawStr != expected {
t.Fatalf("Expected response\n%s\ngot\n%s", expected, rawStr)
}
}

View file

@ -0,0 +1,370 @@
// Package s3 implements a lightweight client for interacting with the
// REST APIs of any S3 compatible service.
//
// It implements only the minimal functionality required by PocketBase
// such as objects list, get, copy, delete and upload.
//
// For more details why we don't use the official aws-sdk-go-v2, you could check
// https://github.com/pocketbase/pocketbase/discussions/6562.
//
// Example:
//
// client := &s3.S3{
// Endpoint: "example.com",
// Region: "us-east-1",
// Bucket: "test",
// AccessKey: "...",
// SecretKey: "...",
// UsePathStyle: true,
// }
// resp, err := client.GetObject(context.Background(), "abc.txt")
package s3
import (
"crypto/hmac"
"crypto/sha256"
"encoding/hex"
"encoding/xml"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"slices"
"strings"
"time"
)
const (
awsS3ServiceCode = "s3"
awsSignAlgorithm = "AWS4-HMAC-SHA256"
awsTerminationString = "aws4_request"
metadataPrefix = "x-amz-meta-"
dateTimeFormat = "20060102T150405Z"
)
type HTTPClient interface {
Do(req *http.Request) (*http.Response, error)
}
type S3 struct {
// Client specifies a custom HTTP client to send the request with.
//
// If not explicitly set, fallbacks to http.DefaultClient.
Client HTTPClient
Bucket string
Region string
Endpoint string // can be with or without the schema
AccessKey string
SecretKey string
UsePathStyle bool
}
// URL constructs an S3 request URL based on the current configuration.
func (s3 *S3) URL(path string) string {
scheme := "https"
endpoint := strings.TrimRight(s3.Endpoint, "/")
if after, ok := strings.CutPrefix(endpoint, "https://"); ok {
endpoint = after
} else if after, ok := strings.CutPrefix(endpoint, "http://"); ok {
endpoint = after
scheme = "http"
}
path = strings.TrimLeft(path, "/")
if s3.UsePathStyle {
return fmt.Sprintf("%s://%s/%s/%s", scheme, endpoint, s3.Bucket, path)
}
return fmt.Sprintf("%s://%s.%s/%s", scheme, s3.Bucket, endpoint, path)
}
// SignAndSend signs the provided request per AWS Signature v4 and sends it.
//
// It automatically normalizes all 40x/50x responses to ResponseError.
//
// Note: Don't forget to call resp.Body.Close() after done with the result.
func (s3 *S3) SignAndSend(req *http.Request) (*http.Response, error) {
s3.sign(req)
client := s3.Client
if client == nil {
client = http.DefaultClient
}
resp, err := client.Do(req)
if err != nil {
return nil, err
}
if resp.StatusCode >= 400 {
defer resp.Body.Close()
respErr := &ResponseError{
Status: resp.StatusCode,
}
respErr.Raw, err = io.ReadAll(resp.Body)
if err != nil && !errors.Is(err, io.EOF) {
return nil, errors.Join(err, respErr)
}
if len(respErr.Raw) > 0 {
err = xml.Unmarshal(respErr.Raw, respErr)
if err != nil {
return nil, errors.Join(err, respErr)
}
}
return nil, respErr
}
return resp, nil
}
// https://docs.aws.amazon.com/IAM/latest/UserGuide/reference_sigv-create-signed-request.html#create-signed-request-steps
func (s3 *S3) sign(req *http.Request) {
// fallback to the Unsigned payload option
// (data integrity checks could be still applied via the content-md5 or x-amz-checksum-* headers)
if req.Header.Get("x-amz-content-sha256") == "" {
req.Header.Set("x-amz-content-sha256", "UNSIGNED-PAYLOAD")
}
reqDateTime, _ := time.Parse(dateTimeFormat, req.Header.Get("x-amz-date"))
if reqDateTime.IsZero() {
reqDateTime = time.Now().UTC()
req.Header.Set("x-amz-date", reqDateTime.Format(dateTimeFormat))
}
req.Header.Set("host", req.URL.Host)
date := reqDateTime.Format("20060102")
dateTime := reqDateTime.Format(dateTimeFormat)
// 1. Create canonical request
// https://docs.aws.amazon.com/IAM/latest/UserGuide/reference_sigv-create-signed-request.html#create-canonical-request
// ---------------------------------------------------------------
canonicalHeaders, signedHeaders := canonicalAndSignedHeaders(req)
canonicalParts := []string{
req.Method,
escapePath(req.URL.Path),
escapeQuery(req.URL.Query()),
canonicalHeaders,
signedHeaders,
req.Header.Get("x-amz-content-sha256"),
}
// 2. Create a hash of the canonical request
// https://docs.aws.amazon.com/IAM/latest/UserGuide/reference_sigv-create-signed-request.html#create-canonical-request-hash
// ---------------------------------------------------------------
hashedCanonicalRequest := sha256Hex([]byte(strings.Join(canonicalParts, "\n")))
// 3. Create a string to sign
// https://docs.aws.amazon.com/IAM/latest/UserGuide/reference_sigv-create-signed-request.html#create-string-to-sign
// ---------------------------------------------------------------
scope := strings.Join([]string{
date,
s3.Region,
awsS3ServiceCode,
awsTerminationString,
}, "/")
stringToSign := strings.Join([]string{
awsSignAlgorithm,
dateTime,
scope,
hashedCanonicalRequest,
}, "\n")
// 4. Derive a signing key for SigV4
// https://docs.aws.amazon.com/IAM/latest/UserGuide/reference_sigv-create-signed-request.html#derive-signing-key
// ---------------------------------------------------------------
dateKey := hmacSHA256([]byte("AWS4"+s3.SecretKey), date)
dateRegionKey := hmacSHA256(dateKey, s3.Region)
dateRegionServiceKey := hmacSHA256(dateRegionKey, awsS3ServiceCode)
signingKey := hmacSHA256(dateRegionServiceKey, awsTerminationString)
signature := hex.EncodeToString(hmacSHA256(signingKey, stringToSign))
// 5. Add the signature to the request
// https://docs.aws.amazon.com/IAM/latest/UserGuide/reference_sigv-create-signed-request.html#add-signature-to-request
authorization := fmt.Sprintf(
"%s Credential=%s/%s, SignedHeaders=%s, Signature=%s",
awsSignAlgorithm,
s3.AccessKey,
scope,
signedHeaders,
signature,
)
req.Header.Set("authorization", authorization)
}
func sha256Hex(content []byte) string {
h := sha256.New()
h.Write(content)
return hex.EncodeToString(h.Sum(nil))
}
func hmacSHA256(key []byte, content string) []byte {
mac := hmac.New(sha256.New, key)
mac.Write([]byte(content))
return mac.Sum(nil)
}
func canonicalAndSignedHeaders(req *http.Request) (string, string) {
signed := []string{}
canonical := map[string]string{}
for key, values := range req.Header {
normalizedKey := strings.ToLower(key)
if normalizedKey != "host" &&
normalizedKey != "content-type" &&
!strings.HasPrefix(normalizedKey, "x-amz-") {
continue
}
signed = append(signed, normalizedKey)
// for each value:
// trim any leading or trailing spaces
// convert sequential spaces to a single space
normalizedValues := make([]string, len(values))
for i, v := range values {
normalizedValues[i] = strings.ReplaceAll(strings.TrimSpace(v), " ", " ")
}
canonical[normalizedKey] = strings.Join(normalizedValues, ",")
}
slices.Sort(signed)
var sortedCanonical strings.Builder
for _, key := range signed {
sortedCanonical.WriteString(key)
sortedCanonical.WriteString(":")
sortedCanonical.WriteString(canonical[key])
sortedCanonical.WriteString("\n")
}
return sortedCanonical.String(), strings.Join(signed, ";")
}
// extractMetadata parses and extracts and the metadata from the specified request headers.
//
// The metadata keys are all lowercased and without the "x-amz-meta-" prefix.
func extractMetadata(headers http.Header) map[string]string {
result := map[string]string{}
for k, v := range headers {
if len(v) == 0 {
continue
}
metadataKey, ok := strings.CutPrefix(strings.ToLower(k), metadataPrefix)
if !ok {
continue
}
result[metadataKey] = v[0]
}
return result
}
// escapeQuery returns the URI encoded request query parameters according to the AWS S3 spec requirements
// (it is similar to url.Values.Encode but instead of url.QueryEscape uses our own escape method).
func escapeQuery(values url.Values) string {
if len(values) == 0 {
return ""
}
var buf strings.Builder
keys := make([]string, 0, len(values))
for k := range values {
keys = append(keys, k)
}
slices.Sort(keys)
for _, k := range keys {
vs := values[k]
keyEscaped := escape(k)
for _, values := range vs {
if buf.Len() > 0 {
buf.WriteByte('&')
}
buf.WriteString(keyEscaped)
buf.WriteByte('=')
buf.WriteString(escape(values))
}
}
return buf.String()
}
// escapePath returns the URI encoded request path according to the AWS S3 spec requirements.
func escapePath(path string) string {
parts := strings.Split(path, "/")
for i, part := range parts {
parts[i] = escape(part)
}
return strings.Join(parts, "/")
}
const upperhex = "0123456789ABCDEF"
// escape is similar to the std url.escape but implements the AWS [UriEncode requirements]:
// - URI encode every byte except the unreserved characters: 'A'-'Z', 'a'-'z', '0'-'9', '-', '.', '_', and '~'.
// - The space character is a reserved character and must be encoded as "%20" (and not as "+").
// - Each URI encoded byte is formed by a '%' and the two-digit hexadecimal value of the byte.
// - Letters in the hexadecimal value must be uppercase, for example "%1A".
//
// [UriEncode requirements]: https://docs.aws.amazon.com/IAM/latest/UserGuide/reference_sigv-create-signed-request.html
func escape(s string) string {
hexCount := 0
for i := 0; i < len(s); i++ {
c := s[i]
if shouldEscape(c) {
hexCount++
}
}
if hexCount == 0 {
return s
}
result := make([]byte, len(s)+2*hexCount)
j := 0
for i := 0; i < len(s); i++ {
c := s[i]
if shouldEscape(c) {
result[j] = '%'
result[j+1] = upperhex[c>>4]
result[j+2] = upperhex[c&15]
j += 3
} else {
result[j] = c
j++
}
}
return string(result)
}
// > "URI encode every byte except the unreserved characters: 'A'-'Z', 'a'-'z', '0'-'9', '-', '.', '_', and '~'."
func shouldEscape(c byte) bool {
isUnreserved := (c >= 'A' && c <= 'Z') ||
(c >= 'a' && c <= 'z') ||
(c >= '0' && c <= '9') ||
c == '-' || c == '.' || c == '_' || c == '~'
return !isUnreserved
}

View file

@ -0,0 +1,35 @@
package s3
import (
"net/url"
"testing"
)
func TestEscapePath(t *testing.T) {
t.Parallel()
escaped := escapePath("/ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_.~ !@#$%^&*()+={}[]?><\\|,`'\"/@sub1/@sub2/a/b/c/1/2/3")
expected := "/ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_.~%20%21%40%23%24%25%5E%26%2A%28%29%2B%3D%7B%7D%5B%5D%3F%3E%3C%5C%7C%2C%60%27%22/%40sub1/%40sub2/a/b/c/1/2/3"
if escaped != expected {
t.Fatalf("Expected\n%s\ngot\n%s", expected, escaped)
}
}
func TestEscapeQuery(t *testing.T) {
t.Parallel()
escaped := escapeQuery(url.Values{
"abc": []string{"123"},
"/ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_.~ !@#$%^&*()+={}[]?><\\|,`'\"": []string{
"/ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_.~ !@#$%^&*()+={}[]?><\\|,`'\"",
},
})
expected := "%2FABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_.~%20%21%40%23%24%25%5E%26%2A%28%29%2B%3D%7B%7D%5B%5D%3F%3E%3C%5C%7C%2C%60%27%22=%2FABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_.~%20%21%40%23%24%25%5E%26%2A%28%29%2B%3D%7B%7D%5B%5D%3F%3E%3C%5C%7C%2C%60%27%22&abc=123"
if escaped != expected {
t.Fatalf("Expected\n%s\ngot\n%s", expected, escaped)
}
}

View file

@ -0,0 +1,256 @@
package s3_test
import (
"io"
"net/http"
"strings"
"testing"
"github.com/pocketbase/pocketbase/tools/filesystem/internal/s3blob/s3"
"github.com/pocketbase/pocketbase/tools/filesystem/internal/s3blob/s3/tests"
)
func TestS3URL(t *testing.T) {
t.Parallel()
scenarios := []struct {
name string
s3Client *s3.S3
expected string
}{
{
"no schema",
&s3.S3{
Region: "test_region",
Bucket: "test_bucket",
Endpoint: "example.com/",
AccessKey: "123",
SecretKey: "abc",
},
"https://test_bucket.example.com/test_key/a/b/c?q=1",
},
{
"with https schema",
&s3.S3{
Region: "test_region",
Bucket: "test_bucket",
Endpoint: "https://example.com/",
AccessKey: "123",
SecretKey: "abc",
},
"https://test_bucket.example.com/test_key/a/b/c?q=1",
},
{
"with http schema",
&s3.S3{
Region: "test_region",
Bucket: "test_bucket",
Endpoint: "http://example.com/",
AccessKey: "123",
SecretKey: "abc",
},
"http://test_bucket.example.com/test_key/a/b/c?q=1",
},
{
"path style addressing (non-explicit schema)",
&s3.S3{
Region: "test_region",
Bucket: "test_bucket",
Endpoint: "example.com/",
AccessKey: "123",
SecretKey: "abc",
UsePathStyle: true,
},
"https://example.com/test_bucket/test_key/a/b/c?q=1",
},
{
"path style addressing (explicit schema)",
&s3.S3{
Region: "test_region",
Bucket: "test_bucket",
Endpoint: "http://example.com/",
AccessKey: "123",
SecretKey: "abc",
UsePathStyle: true,
},
"http://example.com/test_bucket/test_key/a/b/c?q=1",
},
}
for _, s := range scenarios {
t.Run(s.name, func(t *testing.T) {
result := s.s3Client.URL("/test_key/a/b/c?q=1")
if result != s.expected {
t.Fatalf("Expected URL\n%s\ngot\n%s", s.expected, result)
}
})
}
}
func TestS3SignAndSend(t *testing.T) {
t.Parallel()
testResponse := func() *http.Response {
return &http.Response{
Body: io.NopCloser(strings.NewReader("test_response")),
}
}
scenarios := []struct {
name string
path string
reqFunc func(req *http.Request)
s3Client *s3.S3
}{
{
"minimal",
"/test",
func(req *http.Request) {
req.Header.Set("x-amz-date", "20250102T150405Z")
},
&s3.S3{
Region: "test_region",
Bucket: "test_bucket",
Endpoint: "https://example.com/",
AccessKey: "123",
SecretKey: "abc",
Client: tests.NewClient(&tests.RequestStub{
Method: http.MethodGet,
URL: "https://test_bucket.example.com/test",
Response: testResponse(),
Match: func(req *http.Request) bool {
return tests.ExpectHeaders(req.Header, map[string]string{
"Authorization": "AWS4-HMAC-SHA256 Credential=123/20250102/test_region/s3/aws4_request, SignedHeaders=host;x-amz-content-sha256;x-amz-date, Signature=ea093662bc1deef08dfb4ac35453dfaad5ea89edf102e9dd3b7156c9a27e4c1f",
"Host": "test_bucket.example.com",
"X-Amz-Content-Sha256": "UNSIGNED-PAYLOAD",
"X-Amz-Date": "20250102T150405Z",
})
},
}),
},
},
{
"minimal with different access and secret keys",
"/test",
func(req *http.Request) {
req.Header.Set("x-amz-date", "20250102T150405Z")
},
&s3.S3{
Region: "test_region",
Bucket: "test_bucket",
Endpoint: "https://example.com/",
AccessKey: "456",
SecretKey: "def",
Client: tests.NewClient(&tests.RequestStub{
Method: http.MethodGet,
URL: "https://test_bucket.example.com/test",
Response: testResponse(),
Match: func(req *http.Request) bool {
return tests.ExpectHeaders(req.Header, map[string]string{
"Authorization": "AWS4-HMAC-SHA256 Credential=456/20250102/test_region/s3/aws4_request, SignedHeaders=host;x-amz-content-sha256;x-amz-date, Signature=17510fa1f724403dd0a563b61c9b31d1d718f877fcbd75455620d17a8afce5fb",
"Host": "test_bucket.example.com",
"X-Amz-Content-Sha256": "UNSIGNED-PAYLOAD",
"X-Amz-Date": "20250102T150405Z",
})
},
}),
},
},
{
"minimal with special characters",
"/ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_.~!@#$^&*()=/@sub?a=1&@b=@2",
func(req *http.Request) {
req.Header.Set("x-amz-date", "20250102T150405Z")
},
&s3.S3{
Region: "test_region",
Bucket: "test_bucket",
Endpoint: "https://example.com/",
AccessKey: "456",
SecretKey: "def",
Client: tests.NewClient(&tests.RequestStub{
Method: http.MethodGet,
URL: "https://test_bucket.example.com/ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_.~!@#$%5E&*()=/@sub?a=1&@b=@2",
Response: testResponse(),
Match: func(req *http.Request) bool {
return tests.ExpectHeaders(req.Header, map[string]string{
"Authorization": "AWS4-HMAC-SHA256 Credential=456/20250102/test_region/s3/aws4_request, SignedHeaders=host;x-amz-content-sha256;x-amz-date, Signature=e0001982deef1652704f74503203e77d83d4c88369421f9fca644d96f2a62a3c",
"Host": "test_bucket.example.com",
"X-Amz-Content-Sha256": "UNSIGNED-PAYLOAD",
"X-Amz-Date": "20250102T150405Z",
})
},
}),
},
},
{
"with extra headers",
"/test",
func(req *http.Request) {
req.Header.Set("x-amz-date", "20250102T150405Z")
req.Header.Set("x-amz-content-sha256", "test_sha256")
req.Header.Set("x-amz-example", "123")
req.Header.Set("x-amz-meta-a", "456")
req.Header.Set("content-type", "image/png")
req.Header.Set("x-test", "789") // shouldn't be included in the signing headers
},
&s3.S3{
Region: "test_region",
Bucket: "test_bucket",
Endpoint: "https://example.com/",
AccessKey: "123",
SecretKey: "abc",
Client: tests.NewClient(&tests.RequestStub{
Method: http.MethodGet,
URL: "https://test_bucket.example.com/test",
Response: testResponse(),
Match: func(req *http.Request) bool {
return tests.ExpectHeaders(req.Header, map[string]string{
"authorization": "AWS4-HMAC-SHA256 Credential=123/20250102/test_region/s3/aws4_request, SignedHeaders=content-type;host;x-amz-content-sha256;x-amz-date;x-amz-example;x-amz-meta-a, Signature=86dccbcd012c33073dc99e9d0a9e0b717a4d8c11c37848cfa9a4a02716bc0db3",
"host": "test_bucket.example.com",
"x-amz-date": "20250102T150405Z",
"x-amz-content-sha256": "test_sha256",
"x-amz-example": "123",
"x-amz-meta-a": "456",
"x-test": "789",
})
},
}),
},
},
}
for _, s := range scenarios {
t.Run(s.name, func(t *testing.T) {
req, err := http.NewRequest(http.MethodGet, s.s3Client.URL(s.path), strings.NewReader("test_request"))
if err != nil {
t.Fatal(err)
}
if s.reqFunc != nil {
s.reqFunc(req)
}
resp, err := s.s3Client.SignAndSend(req)
if err != nil {
t.Fatal(err)
}
defer resp.Body.Close()
err = s.s3Client.Client.(*tests.Client).AssertNoRemaining()
if err != nil {
t.Fatal(err)
}
expectedBody := "test_response"
body, err := io.ReadAll(resp.Body)
if err != nil {
t.Fatal(err)
}
if str := string(body); str != expectedBody {
t.Fatalf("Expected body %q, got %q", expectedBody, str)
}
})
}
}

View file

@ -0,0 +1,111 @@
// Package tests contains various tests helpers and utilities to assist
// with the S3 client testing.
package tests
import (
"errors"
"fmt"
"io"
"net/http"
"regexp"
"slices"
"strings"
"sync"
)
// NewClient creates a new test Client loaded with the specified RequestStubs.
func NewClient(stubs ...*RequestStub) *Client {
return &Client{stubs: stubs}
}
type RequestStub struct {
Method string
URL string // plain string or regex pattern wrapped in "^pattern$"
Match func(req *http.Request) bool
Response *http.Response
}
type Client struct {
stubs []*RequestStub
mu sync.Mutex
}
// AssertNoRemaining asserts that current client has no unprocessed requests remaining.
func (c *Client) AssertNoRemaining() error {
c.mu.Lock()
defer c.mu.Unlock()
if len(c.stubs) == 0 {
return nil
}
msgParts := make([]string, 0, len(c.stubs)+1)
msgParts = append(msgParts, "not all stub requests were processed:")
for _, stub := range c.stubs {
msgParts = append(msgParts, "- "+stub.Method+" "+stub.URL)
}
return errors.New(strings.Join(msgParts, "\n"))
}
// Do implements the [s3.HTTPClient] interface.
func (c *Client) Do(req *http.Request) (*http.Response, error) {
c.mu.Lock()
defer c.mu.Unlock()
for i, stub := range c.stubs {
if req.Method != stub.Method {
continue
}
urlPattern := stub.URL
if !strings.HasPrefix(urlPattern, "^") && !strings.HasSuffix(urlPattern, "$") {
urlPattern = "^" + regexp.QuoteMeta(urlPattern) + "$"
}
urlRegex, err := regexp.Compile(urlPattern)
if err != nil {
return nil, err
}
if !urlRegex.MatchString(req.URL.String()) {
continue
}
if stub.Match != nil && !stub.Match(req) {
continue
}
// remove from the remaining stubs
c.stubs = slices.Delete(c.stubs, i, i+1)
response := stub.Response
if response == nil {
response = &http.Response{}
}
if response.Header == nil {
response.Header = http.Header{}
}
if response.Body == nil {
response.Body = http.NoBody
}
response.Request = req
return response, nil
}
var body []byte
if req.Body != nil {
defer req.Body.Close()
body, _ = io.ReadAll(req.Body)
}
return nil, fmt.Errorf(
"the below request doesn't have a corresponding stub:\n%s %s\nHeaders: %v\nBody: %q",
req.Method,
req.URL.String(),
req.Header,
body,
)
}

View file

@ -0,0 +1,33 @@
package tests
import (
"net/http"
"regexp"
"strings"
)
// ExpectHeaders checks whether specified headers match the expectations.
// The expectations map entry key is the header name.
// The expectations map entry value is the first header value. If wrapped with `^...$`
// it is compared as regular expression.
func ExpectHeaders(headers http.Header, expectations map[string]string) bool {
for h, expected := range expectations {
v := headers.Get(h)
pattern := expected
if !strings.HasPrefix(pattern, "^") && !strings.HasSuffix(pattern, "$") {
pattern = "^" + regexp.QuoteMeta(pattern) + "$"
}
expectedRegex, err := regexp.Compile(pattern)
if err != nil {
return false
}
if !expectedRegex.MatchString(v) {
return false
}
}
return true
}

View file

@ -0,0 +1,414 @@
package s3
import (
"bytes"
"context"
"encoding/xml"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"slices"
"strconv"
"strings"
"sync"
"golang.org/x/sync/errgroup"
)
var ErrUsedUploader = errors.New("the Uploader has been already used")
const (
defaultMaxConcurrency int = 5
defaultMinPartSize int = 6 << 20
)
// Uploader handles the upload of a single S3 object.
//
// If the Payload size is less than the configured MinPartSize it sends
// a single (PutObject) request, otherwise performs chunked/multipart upload.
type Uploader struct {
// S3 is the S3 client instance performing the upload object request (required).
S3 *S3
// Payload is the object content to upload (required).
Payload io.Reader
// Key is the destination key of the uploaded object (required).
Key string
// Metadata specifies the optional metadata to write with the object upload.
Metadata map[string]string
// MaxConcurrency specifies the max number of workers to use when
// performing chunked/multipart upload.
//
// If zero or negative, defaults to 5.
//
// This option is used only when the Payload size is > MinPartSize.
MaxConcurrency int
// MinPartSize specifies the min Payload size required to perform
// chunked/multipart upload.
//
// If zero or negative, defaults to ~6MB.
MinPartSize int
uploadId string
uploadedParts []*mpPart
lastPartNumber int
mu sync.Mutex // guards lastPartNumber and the uploadedParts slice
used bool
}
// Upload processes the current Uploader instance.
//
// Users can specify an optional optReqFuncs that will be passed down to all Upload internal requests
// (single upload, multipart init, multipart parts upload, multipart complete, multipart abort).
//
// Note that after this call the Uploader should be discarded (aka. no longer can be used).
func (u *Uploader) Upload(ctx context.Context, optReqFuncs ...func(*http.Request)) error {
if u.used {
return ErrUsedUploader
}
err := u.validateAndNormalize()
if err != nil {
return err
}
initPart, _, err := u.nextPart()
if err != nil && !errors.Is(err, io.EOF) {
return err
}
if len(initPart) < u.MinPartSize {
return u.singleUpload(ctx, initPart, optReqFuncs...)
}
err = u.multipartInit(ctx, optReqFuncs...)
if err != nil {
return fmt.Errorf("multipart init error: %w", err)
}
err = u.multipartUpload(ctx, initPart, optReqFuncs...)
if err != nil {
return errors.Join(
u.multipartAbort(ctx, optReqFuncs...),
fmt.Errorf("multipart upload error: %w", err),
)
}
err = u.multipartComplete(ctx, optReqFuncs...)
if err != nil {
return errors.Join(
u.multipartAbort(ctx, optReqFuncs...),
fmt.Errorf("multipart complete error: %w", err),
)
}
return nil
}
// -------------------------------------------------------------------
func (u *Uploader) validateAndNormalize() error {
if u.S3 == nil {
return errors.New("Uploader.S3 must be a non-empty and properly initialized S3 client instance")
}
if u.Key == "" {
return errors.New("Uploader.Key is required")
}
if u.Payload == nil {
return errors.New("Uploader.Payload must be non-nill")
}
if u.MaxConcurrency <= 0 {
u.MaxConcurrency = defaultMaxConcurrency
}
if u.MinPartSize <= 0 {
u.MinPartSize = defaultMinPartSize
}
return nil
}
func (u *Uploader) singleUpload(ctx context.Context, part []byte, optReqFuncs ...func(*http.Request)) error {
if u.used {
return ErrUsedUploader
}
req, err := http.NewRequestWithContext(ctx, http.MethodPut, u.S3.URL(u.Key), bytes.NewReader(part))
if err != nil {
return err
}
req.Header.Set("Content-Length", strconv.Itoa(len(part)))
for k, v := range u.Metadata {
req.Header.Set(metadataPrefix+k, v)
}
// apply optional request funcs
for _, fn := range optReqFuncs {
if fn != nil {
fn(req)
}
}
resp, err := u.S3.SignAndSend(req)
if err != nil {
return err
}
defer resp.Body.Close()
return nil
}
// -------------------------------------------------------------------
type mpPart struct {
XMLName xml.Name `xml:"Part"`
ETag string `xml:"ETag"`
PartNumber int `xml:"PartNumber"`
}
// https://docs.aws.amazon.com/AmazonS3/latest/API/API_CreateMultipartUpload.html
func (u *Uploader) multipartInit(ctx context.Context, optReqFuncs ...func(*http.Request)) error {
if u.used {
return ErrUsedUploader
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, u.S3.URL(u.Key+"?uploads"), nil)
if err != nil {
return err
}
for k, v := range u.Metadata {
req.Header.Set(metadataPrefix+k, v)
}
// apply optional request funcs
for _, fn := range optReqFuncs {
if fn != nil {
fn(req)
}
}
resp, err := u.S3.SignAndSend(req)
if err != nil {
return err
}
defer resp.Body.Close()
body := &struct {
XMLName xml.Name `xml:"InitiateMultipartUploadResult"`
UploadId string `xml:"UploadId"`
}{}
err = xml.NewDecoder(resp.Body).Decode(body)
if err != nil {
return err
}
u.uploadId = body.UploadId
return nil
}
// https://docs.aws.amazon.com/AmazonS3/latest/API/API_AbortMultipartUpload.html
func (u *Uploader) multipartAbort(ctx context.Context, optReqFuncs ...func(*http.Request)) error {
u.mu.Lock()
defer u.mu.Unlock()
u.used = true
// ensure that the specified abort context is always valid to allow cleanup
var abortCtx = ctx
if abortCtx.Err() != nil {
abortCtx = context.Background()
}
query := url.Values{"uploadId": []string{u.uploadId}}
req, err := http.NewRequestWithContext(abortCtx, http.MethodDelete, u.S3.URL(u.Key+"?"+query.Encode()), nil)
if err != nil {
return err
}
// apply optional request funcs
for _, fn := range optReqFuncs {
if fn != nil {
fn(req)
}
}
resp, err := u.S3.SignAndSend(req)
if err != nil {
return err
}
defer resp.Body.Close()
return nil
}
// https://docs.aws.amazon.com/AmazonS3/latest/API/API_CompleteMultipartUpload.html
func (u *Uploader) multipartComplete(ctx context.Context, optReqFuncs ...func(*http.Request)) error {
u.mu.Lock()
defer u.mu.Unlock()
u.used = true
// the list of parts must be sorted in ascending order
slices.SortFunc(u.uploadedParts, func(a, b *mpPart) int {
if a.PartNumber < b.PartNumber {
return -1
}
if a.PartNumber > b.PartNumber {
return 1
}
return 0
})
// build a request payload with the uploaded parts
xmlParts := &struct {
XMLName xml.Name `xml:"CompleteMultipartUpload"`
Parts []*mpPart
}{
Parts: u.uploadedParts,
}
rawXMLParts, err := xml.Marshal(xmlParts)
if err != nil {
return err
}
reqPayload := strings.NewReader(xml.Header + string(rawXMLParts))
query := url.Values{"uploadId": []string{u.uploadId}}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, u.S3.URL(u.Key+"?"+query.Encode()), reqPayload)
if err != nil {
return err
}
// apply optional request funcs
for _, fn := range optReqFuncs {
if fn != nil {
fn(req)
}
}
resp, err := u.S3.SignAndSend(req)
if err != nil {
return err
}
defer resp.Body.Close()
return nil
}
func (u *Uploader) nextPart() ([]byte, int, error) {
u.mu.Lock()
defer u.mu.Unlock()
part := make([]byte, u.MinPartSize)
n, err := io.ReadFull(u.Payload, part)
// normalize io.EOF errors and ensure that io.EOF is returned only when there were no read bytes
if err != nil && (errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF)) {
if n == 0 {
err = io.EOF
} else {
err = nil
}
}
u.lastPartNumber++
return part[0:n], u.lastPartNumber, err
}
func (u *Uploader) multipartUpload(ctx context.Context, initPart []byte, optReqFuncs ...func(*http.Request)) error {
var g errgroup.Group
g.SetLimit(u.MaxConcurrency)
totalParallel := u.MaxConcurrency
if len(initPart) != 0 {
totalParallel--
initPartNumber := u.lastPartNumber
g.Go(func() error {
mp, err := u.uploadPart(ctx, initPartNumber, initPart, optReqFuncs...)
if err != nil {
return err
}
u.mu.Lock()
u.uploadedParts = append(u.uploadedParts, mp)
u.mu.Unlock()
return nil
})
}
for i := 0; i < totalParallel; i++ {
g.Go(func() error {
for {
part, num, err := u.nextPart()
if err != nil {
if errors.Is(err, io.EOF) {
break
}
return err
}
mp, err := u.uploadPart(ctx, num, part, optReqFuncs...)
if err != nil {
return err
}
u.mu.Lock()
u.uploadedParts = append(u.uploadedParts, mp)
u.mu.Unlock()
}
return nil
})
}
return g.Wait()
}
// https://docs.aws.amazon.com/AmazonS3/latest/API/API_UploadPart.html
func (u *Uploader) uploadPart(ctx context.Context, partNumber int, partData []byte, optReqFuncs ...func(*http.Request)) (*mpPart, error) {
query := url.Values{}
query.Set("uploadId", u.uploadId)
query.Set("partNumber", strconv.Itoa(partNumber))
req, err := http.NewRequestWithContext(ctx, http.MethodPut, u.S3.URL(u.Key+"?"+query.Encode()), bytes.NewReader(partData))
if err != nil {
return nil, err
}
req.Header.Set("Content-Length", strconv.Itoa(len(partData)))
// apply optional request funcs
for _, fn := range optReqFuncs {
if fn != nil {
fn(req)
}
}
resp, err := u.S3.SignAndSend(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
return &mpPart{
PartNumber: partNumber,
ETag: resp.Header.Get("ETag"),
}, nil
}

View file

@ -0,0 +1,463 @@
package s3_test
import (
"context"
"io"
"net/http"
"strings"
"testing"
"github.com/pocketbase/pocketbase/tools/filesystem/internal/s3blob/s3"
"github.com/pocketbase/pocketbase/tools/filesystem/internal/s3blob/s3/tests"
)
func TestUploaderRequiredFields(t *testing.T) {
t.Parallel()
s3Client := &s3.S3{
Client: tests.NewClient(&tests.RequestStub{Method: "PUT", URL: `^.+$`}), // match every upload
Region: "test_region",
Bucket: "test_bucket",
Endpoint: "http://example.com",
AccessKey: "123",
SecretKey: "abc",
}
payload := strings.NewReader("test")
scenarios := []struct {
name string
uploader *s3.Uploader
expectedError bool
}{
{
"blank",
&s3.Uploader{},
true,
},
{
"no Key",
&s3.Uploader{S3: s3Client, Payload: payload},
true,
},
{
"no S3",
&s3.Uploader{Key: "abc", Payload: payload},
true,
},
{
"no Payload",
&s3.Uploader{S3: s3Client, Key: "abc"},
true,
},
{
"with S3, Key and Payload",
&s3.Uploader{S3: s3Client, Key: "abc", Payload: payload},
false,
},
}
for _, s := range scenarios {
t.Run(s.name, func(t *testing.T) {
err := s.uploader.Upload(context.Background())
hasErr := err != nil
if hasErr != s.expectedError {
t.Fatalf("Expected hasErr %v, got %v", s.expectedError, hasErr)
}
})
}
}
func TestUploaderSingleUpload(t *testing.T) {
t.Parallel()
httpClient := tests.NewClient(
&tests.RequestStub{
Method: http.MethodPut,
URL: "http://test_bucket.example.com/test_key",
Match: func(req *http.Request) bool {
body, err := io.ReadAll(req.Body)
if err != nil {
return false
}
return string(body) == "abcdefg" && tests.ExpectHeaders(req.Header, map[string]string{
"Content-Length": "7",
"x-amz-meta-a": "123",
"x-amz-meta-b": "456",
"test_header": "test",
"Authorization": "^.+Credential=123/.+$",
})
},
},
)
uploader := &s3.Uploader{
S3: &s3.S3{
Client: httpClient,
Region: "test_region",
Bucket: "test_bucket",
Endpoint: "http://example.com",
AccessKey: "123",
SecretKey: "abc",
},
Key: "test_key",
Payload: strings.NewReader("abcdefg"),
Metadata: map[string]string{"a": "123", "b": "456"},
MinPartSize: 8,
}
err := uploader.Upload(context.Background(), func(r *http.Request) {
r.Header.Set("test_header", "test")
})
if err != nil {
t.Fatal(err)
}
err = httpClient.AssertNoRemaining()
if err != nil {
t.Fatal(err)
}
}
func TestUploaderMultipartUploadSuccess(t *testing.T) {
t.Parallel()
httpClient := tests.NewClient(
&tests.RequestStub{
Method: http.MethodPost,
URL: "http://test_bucket.example.com/test_key?uploads",
Match: func(req *http.Request) bool {
return tests.ExpectHeaders(req.Header, map[string]string{
"x-amz-meta-a": "123",
"x-amz-meta-b": "456",
"test_header": "test",
"Authorization": "^.+Credential=123/.+$",
})
},
Response: &http.Response{
Body: io.NopCloser(strings.NewReader(`
<?xml version="1.0" encoding="UTF-8"?>
<InitiateMultipartUploadResult>
<Bucket>test_bucket</Bucket>
<Key>test_key</Key>
<UploadId>test_id</UploadId>
</InitiateMultipartUploadResult>
`)),
},
},
&tests.RequestStub{
Method: http.MethodPut,
URL: "http://test_bucket.example.com/test_key?partNumber=1&uploadId=test_id",
Match: func(req *http.Request) bool {
body, err := io.ReadAll(req.Body)
if err != nil {
return false
}
return string(body) == "abc" && tests.ExpectHeaders(req.Header, map[string]string{
"Content-Length": "3",
"test_header": "test",
"Authorization": "^.+Credential=123/.+$",
})
},
Response: &http.Response{
Header: http.Header{"Etag": []string{"etag1"}},
},
},
&tests.RequestStub{
Method: http.MethodPut,
URL: "http://test_bucket.example.com/test_key?partNumber=2&uploadId=test_id",
Match: func(req *http.Request) bool {
body, err := io.ReadAll(req.Body)
if err != nil {
return false
}
return string(body) == "def" && tests.ExpectHeaders(req.Header, map[string]string{
"Content-Length": "3",
"test_header": "test",
"Authorization": "^.+Credential=123/.+$",
})
},
Response: &http.Response{
Header: http.Header{"Etag": []string{"etag2"}},
},
},
&tests.RequestStub{
Method: http.MethodPut,
URL: "http://test_bucket.example.com/test_key?partNumber=3&uploadId=test_id",
Match: func(req *http.Request) bool {
body, err := io.ReadAll(req.Body)
if err != nil {
return false
}
return string(body) == "g" && tests.ExpectHeaders(req.Header, map[string]string{
"Content-Length": "1",
"test_header": "test",
"Authorization": "^.+Credential=123/.+$",
})
},
Response: &http.Response{
Header: http.Header{"Etag": []string{"etag3"}},
},
},
&tests.RequestStub{
Method: http.MethodPost,
URL: "http://test_bucket.example.com/test_key?uploadId=test_id",
Match: func(req *http.Request) bool {
body, err := io.ReadAll(req.Body)
if err != nil {
return false
}
expected := `<CompleteMultipartUpload><Part><ETag>etag1</ETag><PartNumber>1</PartNumber></Part><Part><ETag>etag2</ETag><PartNumber>2</PartNumber></Part><Part><ETag>etag3</ETag><PartNumber>3</PartNumber></Part></CompleteMultipartUpload>`
return strings.Contains(string(body), expected) && tests.ExpectHeaders(req.Header, map[string]string{
"test_header": "test",
"Authorization": "^.+Credential=123/.+$",
})
},
},
)
uploader := &s3.Uploader{
S3: &s3.S3{
Client: httpClient,
Region: "test_region",
Bucket: "test_bucket",
Endpoint: "http://example.com",
AccessKey: "123",
SecretKey: "abc",
},
Key: "test_key",
Payload: strings.NewReader("abcdefg"),
Metadata: map[string]string{"a": "123", "b": "456"},
MinPartSize: 3,
}
err := uploader.Upload(context.Background(), func(r *http.Request) {
r.Header.Set("test_header", "test")
})
if err != nil {
t.Fatal(err)
}
err = httpClient.AssertNoRemaining()
if err != nil {
t.Fatal(err)
}
}
func TestUploaderMultipartUploadPartFailure(t *testing.T) {
t.Parallel()
httpClient := tests.NewClient(
&tests.RequestStub{
Method: http.MethodPost,
URL: "http://test_bucket.example.com/test_key?uploads",
Match: func(req *http.Request) bool {
return tests.ExpectHeaders(req.Header, map[string]string{
"x-amz-meta-a": "123",
"x-amz-meta-b": "456",
"test_header": "test",
"Authorization": "^.+Credential=123/.+$",
})
},
Response: &http.Response{
Body: io.NopCloser(strings.NewReader(`
<?xml version="1.0" encoding="UTF-8"?>
<InitiateMultipartUploadResult>
<Bucket>test_bucket</Bucket>
<Key>test_key</Key>
<UploadId>test_id</UploadId>
</InitiateMultipartUploadResult>
`)),
},
},
&tests.RequestStub{
Method: http.MethodPut,
URL: "http://test_bucket.example.com/test_key?partNumber=1&uploadId=test_id",
Match: func(req *http.Request) bool {
body, err := io.ReadAll(req.Body)
if err != nil {
return false
}
return string(body) == "abc" && tests.ExpectHeaders(req.Header, map[string]string{
"Content-Length": "3",
"test_header": "test",
"Authorization": "^.+Credential=123/.+$",
})
},
Response: &http.Response{
Header: http.Header{"Etag": []string{"etag1"}},
},
},
&tests.RequestStub{
Method: http.MethodPut,
URL: "http://test_bucket.example.com/test_key?partNumber=2&uploadId=test_id",
Match: func(req *http.Request) bool {
return tests.ExpectHeaders(req.Header, map[string]string{
"test_header": "test",
"Authorization": "^.+Credential=123/.+$",
})
},
Response: &http.Response{
StatusCode: 400,
},
},
&tests.RequestStub{
Method: http.MethodDelete,
URL: "http://test_bucket.example.com/test_key?uploadId=test_id",
Match: func(req *http.Request) bool {
return tests.ExpectHeaders(req.Header, map[string]string{
"test_header": "test",
"Authorization": "^.+Credential=123/.+$",
})
},
},
)
uploader := &s3.Uploader{
S3: &s3.S3{
Client: httpClient,
Region: "test_region",
Bucket: "test_bucket",
Endpoint: "http://example.com",
AccessKey: "123",
SecretKey: "abc",
},
Key: "test_key",
Payload: strings.NewReader("abcdefg"),
Metadata: map[string]string{"a": "123", "b": "456"},
MinPartSize: 3,
}
err := uploader.Upload(context.Background(), func(r *http.Request) {
r.Header.Set("test_header", "test")
})
if err == nil {
t.Fatal("Expected non-nil error")
}
err = httpClient.AssertNoRemaining()
if err != nil {
t.Fatal(err)
}
}
func TestUploaderMultipartUploadCompleteFailure(t *testing.T) {
t.Parallel()
httpClient := tests.NewClient(
&tests.RequestStub{
Method: http.MethodPost,
URL: "http://test_bucket.example.com/test_key?uploads",
Match: func(req *http.Request) bool {
return tests.ExpectHeaders(req.Header, map[string]string{
"x-amz-meta-a": "123",
"x-amz-meta-b": "456",
"test_header": "test",
"Authorization": "^.+Credential=123/.+$",
})
},
Response: &http.Response{
Body: io.NopCloser(strings.NewReader(`
<?xml version="1.0" encoding="UTF-8"?>
<InitiateMultipartUploadResult>
<Bucket>test_bucket</Bucket>
<Key>test_key</Key>
<UploadId>test_id</UploadId>
</InitiateMultipartUploadResult>
`)),
},
},
&tests.RequestStub{
Method: http.MethodPut,
URL: "http://test_bucket.example.com/test_key?partNumber=1&uploadId=test_id",
Match: func(req *http.Request) bool {
body, err := io.ReadAll(req.Body)
if err != nil {
return false
}
return string(body) == "abc" && tests.ExpectHeaders(req.Header, map[string]string{
"Content-Length": "3",
"test_header": "test",
"Authorization": "^.+Credential=123/.+$",
})
},
Response: &http.Response{
Header: http.Header{"Etag": []string{"etag1"}},
},
},
&tests.RequestStub{
Method: http.MethodPut,
URL: "http://test_bucket.example.com/test_key?partNumber=2&uploadId=test_id",
Match: func(req *http.Request) bool {
body, err := io.ReadAll(req.Body)
if err != nil {
return false
}
return string(body) == "def" && tests.ExpectHeaders(req.Header, map[string]string{
"Content-Length": "3",
"test_header": "test",
"Authorization": "^.+Credential=123/.+$",
})
},
Response: &http.Response{
Header: http.Header{"Etag": []string{"etag2"}},
},
},
&tests.RequestStub{
Method: http.MethodPost,
URL: "http://test_bucket.example.com/test_key?uploadId=test_id",
Match: func(req *http.Request) bool {
return tests.ExpectHeaders(req.Header, map[string]string{
"test_header": "test",
"Authorization": "^.+Credential=123/.+$",
})
},
Response: &http.Response{
StatusCode: 400,
},
},
&tests.RequestStub{
Method: http.MethodDelete,
URL: "http://test_bucket.example.com/test_key?uploadId=test_id",
Match: func(req *http.Request) bool {
return tests.ExpectHeaders(req.Header, map[string]string{
"test_header": "test",
"Authorization": "^.+Credential=123/.+$",
})
},
},
)
uploader := &s3.Uploader{
S3: &s3.S3{
Client: httpClient,
Region: "test_region",
Bucket: "test_bucket",
Endpoint: "http://example.com",
AccessKey: "123",
SecretKey: "abc",
},
Key: "test_key",
Payload: strings.NewReader("abcdef"),
Metadata: map[string]string{"a": "123", "b": "456"},
MinPartSize: 3,
}
err := uploader.Upload(context.Background(), func(r *http.Request) {
r.Header.Set("test_header", "test")
})
if err == nil {
t.Fatal("Expected non-nil error")
}
err = httpClient.AssertNoRemaining()
if err != nil {
t.Fatal(err)
}
}

View file

@ -0,0 +1,485 @@
// Package s3blob provides a blob.Bucket S3 driver implementation.
//
// NB! To minimize breaking changes with older PocketBase releases,
// the driver is based of the previously used gocloud.dev/blob/s3blob,
// hence many of the below doc comments, struct options and interface
// implementations are the same.
//
// The blob abstraction supports all UTF-8 strings; to make this work with services lacking
// full UTF-8 support, strings must be escaped (during writes) and unescaped
// (during reads). The following escapes are performed for s3blob:
// - Blob keys: ASCII characters 0-31 are escaped to "__0x<hex>__".
// Additionally, the "/" in "../" is escaped in the same way.
// - Metadata keys: Escaped using URL encoding, then additionally "@:=" are
// escaped using "__0x<hex>__". These characters were determined by
// experimentation.
// - Metadata values: Escaped using URL encoding.
//
// Example:
//
// drv, _ := s3blob.New(&s3.S3{
// Bucket: "bucketName",
// Region: "region",
// Endpoint: "endpoint",
// AccessKey: "accessKey",
// SecretKey: "secretKey",
// })
// bucket := blob.NewBucket(drv)
package s3blob
import (
"context"
"encoding/base64"
"encoding/hex"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"sort"
"strconv"
"strings"
"github.com/pocketbase/pocketbase/tools/filesystem/blob"
"github.com/pocketbase/pocketbase/tools/filesystem/internal/s3blob/s3"
)
const defaultPageSize = 1000
// New creates a new instance of the S3 driver backed by the the internal S3 client.
func New(s3Client *s3.S3) (blob.Driver, error) {
if s3Client.Bucket == "" {
return nil, errors.New("s3blob.New: missing bucket name")
}
if s3Client.Endpoint == "" {
return nil, errors.New("s3blob.New: missing endpoint")
}
if s3Client.Region == "" {
return nil, errors.New("s3blob.New: missing region")
}
return &driver{s3: s3Client}, nil
}
type driver struct {
s3 *s3.S3
}
// Close implements [blob/Driver.Close].
func (drv *driver) Close() error {
return nil // nothing to close
}
// NormalizeError implements [blob/Driver.NormalizeError].
func (drv *driver) NormalizeError(err error) error {
// already normalized
if errors.Is(err, blob.ErrNotFound) {
return err
}
// normalize base on its S3 error status or code
var ae *s3.ResponseError
if errors.As(err, &ae) {
if ae.Status == 404 {
return errors.Join(err, blob.ErrNotFound)
}
switch ae.Code {
case "NoSuchBucket", "NoSuchKey", "NotFound":
return errors.Join(err, blob.ErrNotFound)
}
}
return err
}
// ListPaged implements [blob/Driver.ListPaged].
func (drv *driver) ListPaged(ctx context.Context, opts *blob.ListOptions) (*blob.ListPage, error) {
pageSize := opts.PageSize
if pageSize == 0 {
pageSize = defaultPageSize
}
listParams := s3.ListParams{
MaxKeys: pageSize,
}
if len(opts.PageToken) > 0 {
listParams.ContinuationToken = string(opts.PageToken)
}
if opts.Prefix != "" {
listParams.Prefix = escapeKey(opts.Prefix)
}
if opts.Delimiter != "" {
listParams.Delimiter = escapeKey(opts.Delimiter)
}
resp, err := drv.s3.ListObjects(ctx, listParams)
if err != nil {
return nil, err
}
page := blob.ListPage{}
if resp.NextContinuationToken != "" {
page.NextPageToken = []byte(resp.NextContinuationToken)
}
if n := len(resp.Contents) + len(resp.CommonPrefixes); n > 0 {
page.Objects = make([]*blob.ListObject, n)
for i, obj := range resp.Contents {
page.Objects[i] = &blob.ListObject{
Key: unescapeKey(obj.Key),
ModTime: obj.LastModified,
Size: obj.Size,
MD5: eTagToMD5(obj.ETag),
}
}
for i, prefix := range resp.CommonPrefixes {
page.Objects[i+len(resp.Contents)] = &blob.ListObject{
Key: unescapeKey(prefix.Prefix),
IsDir: true,
}
}
if len(resp.Contents) > 0 && len(resp.CommonPrefixes) > 0 {
// S3 gives us blobs and "directories" in separate lists; sort them.
sort.Slice(page.Objects, func(i, j int) bool {
return page.Objects[i].Key < page.Objects[j].Key
})
}
}
return &page, nil
}
// Attributes implements [blob/Driver.Attributes].
func (drv *driver) Attributes(ctx context.Context, key string) (*blob.Attributes, error) {
key = escapeKey(key)
resp, err := drv.s3.HeadObject(ctx, key)
if err != nil {
return nil, err
}
md := make(map[string]string, len(resp.Metadata))
for k, v := range resp.Metadata {
// See the package comments for more details on escaping of metadata keys & values.
md[blob.HexUnescape(urlUnescape(k))] = urlUnescape(v)
}
return &blob.Attributes{
CacheControl: resp.CacheControl,
ContentDisposition: resp.ContentDisposition,
ContentEncoding: resp.ContentEncoding,
ContentLanguage: resp.ContentLanguage,
ContentType: resp.ContentType,
Metadata: md,
// CreateTime not supported; left as the zero time.
ModTime: resp.LastModified,
Size: resp.ContentLength,
MD5: eTagToMD5(resp.ETag),
ETag: resp.ETag,
}, nil
}
// NewRangeReader implements [blob/Driver.NewRangeReader].
func (drv *driver) NewRangeReader(ctx context.Context, key string, offset, length int64) (blob.DriverReader, error) {
key = escapeKey(key)
var byteRange string
if offset > 0 && length < 0 {
byteRange = fmt.Sprintf("bytes=%d-", offset)
} else if length == 0 {
// AWS doesn't support a zero-length read; we'll read 1 byte and then
// ignore it in favor of http.NoBody below.
byteRange = fmt.Sprintf("bytes=%d-%d", offset, offset)
} else if length >= 0 {
byteRange = fmt.Sprintf("bytes=%d-%d", offset, offset+length-1)
}
reqOpt := func(req *http.Request) {
req.Header.Set("Range", byteRange)
}
resp, err := drv.s3.GetObject(ctx, key, reqOpt)
if err != nil {
return nil, err
}
body := resp.Body
if length == 0 {
body = http.NoBody
}
return &reader{
body: body,
attrs: &blob.ReaderAttributes{
ContentType: resp.ContentType,
ModTime: resp.LastModified,
Size: getSize(resp.ContentLength, resp.ContentRange),
},
}, nil
}
// NewTypedWriter implements [blob/Driver.NewTypedWriter].
func (drv *driver) NewTypedWriter(ctx context.Context, key string, contentType string, opts *blob.WriterOptions) (blob.DriverWriter, error) {
key = escapeKey(key)
u := &s3.Uploader{
S3: drv.s3,
Key: key,
}
if opts.BufferSize != 0 {
u.MinPartSize = opts.BufferSize
}
if opts.MaxConcurrency != 0 {
u.MaxConcurrency = opts.MaxConcurrency
}
md := make(map[string]string, len(opts.Metadata))
for k, v := range opts.Metadata {
// See the package comments for more details on escaping of metadata keys & values.
k = blob.HexEscape(url.PathEscape(k), func(runes []rune, i int) bool {
c := runes[i]
return c == '@' || c == ':' || c == '='
})
md[k] = url.PathEscape(v)
}
u.Metadata = md
var reqOptions []func(*http.Request)
reqOptions = append(reqOptions, func(r *http.Request) {
r.Header.Set("Content-Type", contentType)
if opts.CacheControl != "" {
r.Header.Set("Cache-Control", opts.CacheControl)
}
if opts.ContentDisposition != "" {
r.Header.Set("Content-Disposition", opts.ContentDisposition)
}
if opts.ContentEncoding != "" {
r.Header.Set("Content-Encoding", opts.ContentEncoding)
}
if opts.ContentLanguage != "" {
r.Header.Set("Content-Language", opts.ContentLanguage)
}
if len(opts.ContentMD5) > 0 {
r.Header.Set("Content-MD5", base64.StdEncoding.EncodeToString(opts.ContentMD5))
}
})
return &writer{
ctx: ctx,
uploader: u,
donec: make(chan struct{}),
reqOptions: reqOptions,
}, nil
}
// Copy implements [blob/Driver.Copy].
func (drv *driver) Copy(ctx context.Context, dstKey, srcKey string) error {
dstKey = escapeKey(dstKey)
srcKey = escapeKey(srcKey)
_, err := drv.s3.CopyObject(ctx, srcKey, dstKey)
return err
}
// Delete implements [blob/Driver.Delete].
func (drv *driver) Delete(ctx context.Context, key string) error {
key = escapeKey(key)
return drv.s3.DeleteObject(ctx, key)
}
// -------------------------------------------------------------------
// reader reads an S3 object. It implements io.ReadCloser.
type reader struct {
attrs *blob.ReaderAttributes
body io.ReadCloser
}
// Read implements [io/ReadCloser.Read].
func (r *reader) Read(p []byte) (int, error) {
return r.body.Read(p)
}
// Close closes the reader itself. It must be called when done reading.
func (r *reader) Close() error {
return r.body.Close()
}
// Attributes implements [blob/DriverReader.Attributes].
func (r *reader) Attributes() *blob.ReaderAttributes {
return r.attrs
}
// -------------------------------------------------------------------
// writer writes an S3 object, it implements io.WriteCloser.
type writer struct {
ctx context.Context
err error // written before donec closes
uploader *s3.Uploader
// Ends of an io.Pipe, created when the first byte is written.
pw *io.PipeWriter
pr *io.PipeReader
donec chan struct{} // closed when done writing
reqOptions []func(*http.Request)
}
// Write appends p to w.pw. User must call Close to close the w after done writing.
func (w *writer) Write(p []byte) (int, error) {
// Avoid opening the pipe for a zero-length write;
// the concrete can do these for empty blobs.
if len(p) == 0 {
return 0, nil
}
if w.pw == nil {
// We'll write into pw and use pr as an io.Reader for the
// Upload call to S3.
w.pr, w.pw = io.Pipe()
w.open(w.pr, true)
}
return w.pw.Write(p)
}
// r may be nil if we're Closing and no data was written.
// If closePipeOnError is true, w.pr will be closed if there's an
// error uploading to S3.
func (w *writer) open(r io.Reader, closePipeOnError bool) {
// This goroutine will keep running until Close, unless there's an error.
go func() {
defer func() {
close(w.donec)
}()
if r == nil {
// AWS doesn't like a nil Body.
r = http.NoBody
}
w.uploader.Payload = r
err := w.uploader.Upload(w.ctx, w.reqOptions...)
if err != nil {
if closePipeOnError {
w.pr.CloseWithError(err)
}
w.err = err
}
}()
}
// Close completes the writer and closes it. Any error occurring during write
// will be returned. If a writer is closed before any Write is called, Close
// will create an empty file at the given key.
func (w *writer) Close() error {
if w.pr != nil {
defer w.pr.Close()
}
if w.pw == nil {
// We never got any bytes written. We'll write an http.NoBody.
w.open(nil, false)
} else if err := w.pw.Close(); err != nil {
return err
}
<-w.donec
return w.err
}
// -------------------------------------------------------------------
// etagToMD5 processes an ETag header and returns an MD5 hash if possible.
// S3's ETag header is sometimes a quoted hexstring of the MD5. Other times,
// notably when the object was uploaded in multiple parts, it is not.
// We do the best we can.
// Some links about ETag:
// https://docs.aws.amazon.com/AmazonS3/latest/API/RESTCommonResponseHeaders.html
// https://github.com/aws/aws-sdk-net/issues/815
// https://teppen.io/2018/06/23/aws_s3_etags/
func eTagToMD5(etag string) []byte {
// No header at all.
if etag == "" {
return nil
}
// Strip the expected leading and trailing quotes.
if len(etag) < 2 || etag[0] != '"' || etag[len(etag)-1] != '"' {
return nil
}
unquoted := etag[1 : len(etag)-1]
// Un-hex; we return nil on error. In particular, we'll get an error here
// for multi-part uploaded blobs, whose ETag will contain a "-" and so will
// never be a legal hex encoding.
md5, err := hex.DecodeString(unquoted)
if err != nil {
return nil
}
return md5
}
func getSize(contentLength int64, contentRange string) int64 {
// Default size to ContentLength, but that's incorrect for partial-length reads,
// where ContentLength refers to the size of the returned Body, not the entire
// size of the blob. ContentRange has the full size.
size := contentLength
if contentRange != "" {
// Sample: bytes 10-14/27 (where 27 is the full size).
parts := strings.Split(contentRange, "/")
if len(parts) == 2 {
if i, err := strconv.ParseInt(parts[1], 10, 64); err == nil {
size = i
}
}
}
return size
}
// escapeKey does all required escaping for UTF-8 strings to work with S3.
func escapeKey(key string) string {
return blob.HexEscape(key, func(r []rune, i int) bool {
c := r[i]
// S3 doesn't handle these characters (determined via experimentation).
if c < 32 {
return true
}
// For "../", escape the trailing slash.
if i > 1 && c == '/' && r[i-1] == '.' && r[i-2] == '.' {
return true
}
return false
})
}
// unescapeKey reverses escapeKey.
func unescapeKey(key string) string {
return blob.HexUnescape(key)
}
// urlUnescape reverses URLEscape using url.PathUnescape. If the unescape
// returns an error, it returns s.
func urlUnescape(s string) string {
if u, err := url.PathUnescape(s); err == nil {
return u
}
return s
}

View file

@ -0,0 +1,605 @@
package s3blob_test
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"strings"
"testing"
"github.com/pocketbase/pocketbase/tools/filesystem/blob"
"github.com/pocketbase/pocketbase/tools/filesystem/internal/s3blob"
"github.com/pocketbase/pocketbase/tools/filesystem/internal/s3blob/s3"
"github.com/pocketbase/pocketbase/tools/filesystem/internal/s3blob/s3/tests"
)
func TestNew(t *testing.T) {
t.Parallel()
scenarios := []struct {
name string
s3Client *s3.S3
expectError bool
}{
{
"blank",
&s3.S3{},
true,
},
{
"no bucket",
&s3.S3{Region: "b", Endpoint: "c"},
true,
},
{
"no endpoint",
&s3.S3{Bucket: "a", Region: "b"},
true,
},
{
"no region",
&s3.S3{Bucket: "a", Endpoint: "c"},
true,
},
{
"with bucket, endpoint and region",
&s3.S3{Bucket: "a", Region: "b", Endpoint: "c"},
false,
},
}
for _, s := range scenarios {
t.Run(s.name, func(t *testing.T) {
drv, err := s3blob.New(s.s3Client)
hasErr := err != nil
if hasErr != s.expectError {
t.Fatalf("Expected hasErr %v, got %v (%v)", s.expectError, hasErr, err)
}
if err == nil && drv == nil {
t.Fatal("Expected non-nil driver instance")
}
})
}
}
func TestDriverClose(t *testing.T) {
t.Parallel()
drv, err := s3blob.New(&s3.S3{Bucket: "a", Region: "b", Endpoint: "c"})
if err != nil {
t.Fatal(err)
}
err = drv.Close()
if err != nil {
t.Fatalf("Expected nil, got error %v", err)
}
}
func TestDriverNormilizeError(t *testing.T) {
t.Parallel()
drv, err := s3blob.New(&s3.S3{Bucket: "a", Region: "b", Endpoint: "c"})
if err != nil {
t.Fatal(err)
}
scenarios := []struct {
name string
err error
expectErrNotFound bool
}{
{
"plain error",
errors.New("test"),
false,
},
{
"response error with only status (non-404)",
&s3.ResponseError{Status: 123},
false,
},
{
"response error with only status (404)",
&s3.ResponseError{Status: 404},
true,
},
{
"response error with custom code",
&s3.ResponseError{Code: "test"},
false,
},
{
"response error with NoSuchBucket code",
&s3.ResponseError{Code: "NoSuchBucket"},
true,
},
{
"response error with NoSuchKey code",
&s3.ResponseError{Code: "NoSuchKey"},
true,
},
{
"response error with NotFound code",
&s3.ResponseError{Code: "NotFound"},
true,
},
{
"wrapped response error with NotFound code", // ensures that the entire error's tree is checked
fmt.Errorf("test: %w", &s3.ResponseError{Code: "NotFound"}),
true,
},
{
"already normalized error",
fmt.Errorf("test: %w", blob.ErrNotFound),
true,
},
}
for _, s := range scenarios {
t.Run(s.name, func(t *testing.T) {
err := drv.NormalizeError(s.err)
if err == nil {
t.Fatal("Expected non-nil error")
}
isErrNotFound := errors.Is(err, blob.ErrNotFound)
if isErrNotFound != s.expectErrNotFound {
t.Fatalf("Expected isErrNotFound %v, got %v (%v)", s.expectErrNotFound, isErrNotFound, err)
}
})
}
}
func TestDriverDeleteEscaping(t *testing.T) {
t.Parallel()
httpClient := tests.NewClient(&tests.RequestStub{
Method: http.MethodDelete,
URL: "https://test_bucket.example.com/..__0x2f__abc/test/",
})
drv, err := s3blob.New(&s3.S3{
Bucket: "test_bucket",
Region: "test_region",
Endpoint: "https://example.com",
Client: httpClient,
})
if err != nil {
t.Fatal(err)
}
err = drv.Delete(context.Background(), "../abc/test/")
if err != nil {
t.Fatal(err)
}
err = httpClient.AssertNoRemaining()
if err != nil {
t.Fatal(err)
}
}
func TestDriverCopyEscaping(t *testing.T) {
t.Parallel()
httpClient := tests.NewClient(&tests.RequestStub{
Method: http.MethodPut,
URL: "https://test_bucket.example.com/..__0x2f__a/",
Match: func(req *http.Request) bool {
return tests.ExpectHeaders(req.Header, map[string]string{
"x-amz-copy-source": "test_bucket%2F..__0x2f__b%2F",
})
},
Response: &http.Response{
Body: io.NopCloser(strings.NewReader(`<CopyObjectResult></CopyObjectResult>`)),
},
})
drv, err := s3blob.New(&s3.S3{
Bucket: "test_bucket",
Region: "test_region",
Endpoint: "https://example.com",
Client: httpClient,
})
if err != nil {
t.Fatal(err)
}
err = drv.Copy(context.Background(), "../a/", "../b/")
if err != nil {
t.Fatal(err)
}
err = httpClient.AssertNoRemaining()
if err != nil {
t.Fatal(err)
}
}
func TestDriverAttributes(t *testing.T) {
t.Parallel()
httpClient := tests.NewClient(&tests.RequestStub{
Method: http.MethodHead,
URL: "https://test_bucket.example.com/..__0x2f__a/",
Response: &http.Response{
Header: http.Header{
"Last-Modified": []string{"Mon, 01 Feb 2025 03:04:05 GMT"},
"Cache-Control": []string{"test_cache"},
"Content-Disposition": []string{"test_disposition"},
"Content-Encoding": []string{"test_encoding"},
"Content-Language": []string{"test_language"},
"Content-Type": []string{"test_type"},
"Content-Range": []string{"test_range"},
"Etag": []string{`"ce5be8b6f53645c596306c4572ece521"`},
"Content-Length": []string{"100"},
"x-amz-meta-AbC%40": []string{"%40test_meta_a"},
"x-amz-meta-Def": []string{"test_meta_b"},
},
Body: http.NoBody,
},
})
drv, err := s3blob.New(&s3.S3{
Bucket: "test_bucket",
Region: "test_region",
Endpoint: "https://example.com",
Client: httpClient,
})
if err != nil {
t.Fatal(err)
}
attrs, err := drv.Attributes(context.Background(), "../a/")
if err != nil {
t.Fatal(err)
}
raw, err := json.Marshal(attrs)
if err != nil {
t.Fatal(err)
}
expected := `{"cacheControl":"test_cache","contentDisposition":"test_disposition","contentEncoding":"test_encoding","contentLanguage":"test_language","contentType":"test_type","metadata":{"abc@":"@test_meta_a","def":"test_meta_b"},"createTime":"0001-01-01T00:00:00Z","modTime":"2025-02-01T03:04:05Z","size":100,"md5":"zlvotvU2RcWWMGxFcuzlIQ==","etag":"\"ce5be8b6f53645c596306c4572ece521\""}`
if str := string(raw); str != expected {
t.Fatalf("Expected attributes\n%s\ngot\n%s", expected, str)
}
err = httpClient.AssertNoRemaining()
if err != nil {
t.Fatal(err)
}
}
func TestDriverListPaged(t *testing.T) {
t.Parallel()
listResponse := func() *http.Response {
return &http.Response{
Body: io.NopCloser(strings.NewReader(`
<?xml version="1.0" encoding="UTF-8"?>
<ListBucketResult xmlns="http://s3.amazonaws.com/doc/2006-03-01/">
<Name>example</Name>
<ContinuationToken>ct</ContinuationToken>
<NextContinuationToken>test_next</NextContinuationToken>
<StartAfter>example0.txt</StartAfter>
<KeyCount>1</KeyCount>
<MaxKeys>3</MaxKeys>
<Contents>
<Key>..__0x2f__prefixB/test/example.txt</Key>
<LastModified>2025-01-01T01:02:03.123Z</LastModified>
<ETag>"ce5be8b6f53645c596306c4572ece521"</ETag>
<Size>123</Size>
</Contents>
<Contents>
<Key>prefixA/..__0x2f__escape.txt</Key>
<LastModified>2025-01-02T01:02:03.123Z</LastModified>
<Size>456</Size>
</Contents>
<CommonPrefixes>
<Prefix>prefixA</Prefix>
</CommonPrefixes>
<CommonPrefixes>
<Prefix>..__0x2f__prefixB</Prefix>
</CommonPrefixes>
</ListBucketResult>
`)),
}
}
expectedPage := `{"objects":[{"key":"../prefixB","modTime":"0001-01-01T00:00:00Z","size":0,"md5":null,"isDir":true},{"key":"../prefixB/test/example.txt","modTime":"2025-01-01T01:02:03.123Z","size":123,"md5":"zlvotvU2RcWWMGxFcuzlIQ==","isDir":false},{"key":"prefixA","modTime":"0001-01-01T00:00:00Z","size":0,"md5":null,"isDir":true},{"key":"prefixA/../escape.txt","modTime":"2025-01-02T01:02:03.123Z","size":456,"md5":null,"isDir":false}],"nextPageToken":"dGVzdF9uZXh0"}`
httpClient := tests.NewClient(
&tests.RequestStub{
Method: http.MethodGet,
URL: "https://test_bucket.example.com/?list-type=2&max-keys=1000",
Response: listResponse(),
},
&tests.RequestStub{
Method: http.MethodGet,
URL: "https://test_bucket.example.com/?continuation-token=test_token&delimiter=test_delimiter&list-type=2&max-keys=123&prefix=test_prefix",
Response: listResponse(),
},
)
drv, err := s3blob.New(&s3.S3{
Bucket: "test_bucket",
Region: "test_region",
Endpoint: "https://example.com",
Client: httpClient,
})
if err != nil {
t.Fatal(err)
}
scenarios := []struct {
name string
opts *blob.ListOptions
expected string
}{
{
"empty options",
&blob.ListOptions{},
expectedPage,
},
{
"filled options",
&blob.ListOptions{Prefix: "test_prefix", Delimiter: "test_delimiter", PageSize: 123, PageToken: []byte("test_token")},
expectedPage,
},
}
for _, s := range scenarios {
t.Run(s.name, func(t *testing.T) {
page, err := drv.ListPaged(context.Background(), s.opts)
if err != nil {
t.Fatal(err)
}
raw, err := json.Marshal(page)
if err != nil {
t.Fatal(err)
}
if str := string(raw); s.expected != str {
t.Fatalf("Expected page result\n%s\ngot\n%s", s.expected, str)
}
})
}
err = httpClient.AssertNoRemaining()
if err != nil {
t.Fatal(err)
}
}
func TestDriverNewRangeReader(t *testing.T) {
t.Parallel()
scenarios := []struct {
offset int64
length int64
httpClient *tests.Client
expectedAttrs string
}{
{
0,
0,
tests.NewClient(&tests.RequestStub{
Method: http.MethodGet,
URL: "https://test_bucket.example.com/..__0x2f__abc/test.txt",
Match: func(req *http.Request) bool {
return tests.ExpectHeaders(req.Header, map[string]string{
"Range": "bytes=0-0",
})
},
Response: &http.Response{
Header: http.Header{
"Last-Modified": []string{"Mon, 01 Feb 2025 03:04:05 GMT"},
"Content-Type": []string{"test_ct"},
"Content-Length": []string{"123"},
},
Body: io.NopCloser(strings.NewReader("test")),
},
}),
`{"contentType":"test_ct","modTime":"2025-02-01T03:04:05Z","size":123}`,
},
{
10,
-1,
tests.NewClient(&tests.RequestStub{
Method: http.MethodGet,
URL: "https://test_bucket.example.com/..__0x2f__abc/test.txt",
Match: func(req *http.Request) bool {
return tests.ExpectHeaders(req.Header, map[string]string{
"Range": "bytes=10-",
})
},
Response: &http.Response{
Header: http.Header{
"Last-Modified": []string{"Mon, 01 Feb 2025 03:04:05 GMT"},
"Content-Type": []string{"test_ct"},
"Content-Range": []string{"bytes 1-1/456"}, // should take precedence over content-length
"Content-Length": []string{"123"},
},
Body: io.NopCloser(strings.NewReader("test")),
},
}),
`{"contentType":"test_ct","modTime":"2025-02-01T03:04:05Z","size":456}`,
},
{
10,
0,
tests.NewClient(&tests.RequestStub{
Method: http.MethodGet,
URL: "https://test_bucket.example.com/..__0x2f__abc/test.txt",
Match: func(req *http.Request) bool {
return tests.ExpectHeaders(req.Header, map[string]string{
"Range": "bytes=10-10",
})
},
Response: &http.Response{
Header: http.Header{
"Last-Modified": []string{"Mon, 01 Feb 2025 03:04:05 GMT"},
"Content-Type": []string{"test_ct"},
// no range and length headers
// "Content-Range": []string{"bytes 1-1/456"},
// "Content-Length": []string{"123"},
},
Body: io.NopCloser(strings.NewReader("test")),
},
}),
`{"contentType":"test_ct","modTime":"2025-02-01T03:04:05Z","size":0}`,
},
{
10,
20,
tests.NewClient(&tests.RequestStub{
Method: http.MethodGet,
URL: "https://test_bucket.example.com/..__0x2f__abc/test.txt",
Match: func(req *http.Request) bool {
return tests.ExpectHeaders(req.Header, map[string]string{
"Range": "bytes=10-29",
})
},
Response: &http.Response{
Header: http.Header{
"Last-Modified": []string{"Mon, 01 Feb 2025 03:04:05 GMT"},
"Content-Type": []string{"test_ct"},
// with range header but invalid format -> content-length takes precedence
"Content-Range": []string{"bytes invalid-456"},
"Content-Length": []string{"123"},
},
Body: io.NopCloser(strings.NewReader("test")),
},
}),
`{"contentType":"test_ct","modTime":"2025-02-01T03:04:05Z","size":123}`,
},
}
for _, s := range scenarios {
t.Run(fmt.Sprintf("offset_%d_length_%d", s.offset, s.length), func(t *testing.T) {
drv, err := s3blob.New(&s3.S3{
Bucket: "test_bucket",
Region: "tesst_region",
Endpoint: "https://example.com",
Client: s.httpClient,
})
if err != nil {
t.Fatal(err)
}
r, err := drv.NewRangeReader(context.Background(), "../abc/test.txt", s.offset, s.length)
if err != nil {
t.Fatal(err)
}
defer r.Close()
// the response body should be always replaced with http.NoBody
if s.length == 0 {
body := make([]byte, 1)
n, err := r.Read(body)
if n != 0 || !errors.Is(err, io.EOF) {
t.Fatalf("Expected body to be http.NoBody, got %v (%v)", body, err)
}
}
rawAttrs, err := json.Marshal(r.Attributes())
if err != nil {
t.Fatal(err)
}
if str := string(rawAttrs); str != s.expectedAttrs {
t.Fatalf("Expected attributes\n%s\ngot\n%s", s.expectedAttrs, str)
}
err = s.httpClient.AssertNoRemaining()
if err != nil {
t.Fatal(err)
}
})
}
}
func TestDriverNewTypedWriter(t *testing.T) {
t.Parallel()
httpClient := tests.NewClient(
&tests.RequestStub{
Method: http.MethodPut,
URL: "https://test_bucket.example.com/..__0x2f__abc/test/",
Match: func(req *http.Request) bool {
body, err := io.ReadAll(req.Body)
if err != nil {
return false
}
return string(body) == "test" && tests.ExpectHeaders(req.Header, map[string]string{
"cache-control": "test_cache_control",
"content-disposition": "test_content_disposition",
"content-encoding": "test_content_encoding",
"content-language": "test_content_language",
"content-type": "test_ct",
"content-md5": "dGVzdA==",
})
},
},
)
drv, err := s3blob.New(&s3.S3{
Bucket: "test_bucket",
Region: "test_region",
Endpoint: "https://example.com",
Client: httpClient,
})
if err != nil {
t.Fatal(err)
}
options := &blob.WriterOptions{
CacheControl: "test_cache_control",
ContentDisposition: "test_content_disposition",
ContentEncoding: "test_content_encoding",
ContentLanguage: "test_content_language",
ContentType: "test_content_type", // should be ignored
ContentMD5: []byte("test"),
Metadata: map[string]string{"@test_meta_a": "@test"},
}
w, err := drv.NewTypedWriter(context.Background(), "../abc/test/", "test_ct", options)
if err != nil {
t.Fatal(err)
}
n, err := w.Write(nil)
if err != nil {
t.Fatal(err)
}
if n != 0 {
t.Fatalf("Expected nil write to result in %d written bytes, got %d", 0, n)
}
n, err = w.Write([]byte("test"))
if err != nil {
t.Fatal(err)
}
if n != 4 {
t.Fatalf("Expected nil write to result in %d written bytes, got %d", 4, n)
}
err = w.Close()
if err != nil {
t.Fatal(err)
}
err = httpClient.AssertNoRemaining()
if err != nil {
t.Fatal(err)
}
}

45
tools/hook/event.go Normal file
View file

@ -0,0 +1,45 @@
package hook
// Resolver defines a common interface for a Hook event (see [Event]).
type Resolver interface {
// Next triggers the next handler in the hook's chain (if any).
Next() error
// note: kept only for the generic interface; may get removed in the future
nextFunc() func() error
setNextFunc(f func() error)
}
var _ Resolver = (*Event)(nil)
// Event implements [Resolver] and it is intended to be used as a base
// Hook event that you can embed in your custom typed event structs.
//
// Example:
//
// type CustomEvent struct {
// hook.Event
//
// SomeField int
// }
type Event struct {
next func() error
}
// Next calls the next hook handler.
func (e *Event) Next() error {
if e.next != nil {
return e.next()
}
return nil
}
// nextFunc returns the function that Next calls.
func (e *Event) nextFunc() func() error {
return e.next
}
// setNextFunc sets the function that Next calls.
func (e *Event) setNextFunc(f func() error) {
e.next = f
}

29
tools/hook/event_test.go Normal file
View file

@ -0,0 +1,29 @@
package hook
import "testing"
func TestEventNext(t *testing.T) {
calls := 0
e := Event{}
if e.nextFunc() != nil {
t.Fatalf("Expected nextFunc to be nil")
}
e.setNextFunc(func() error {
calls++
return nil
})
if e.nextFunc() == nil {
t.Fatalf("Expected nextFunc to be non-nil")
}
e.Next()
e.Next()
if calls != 2 {
t.Fatalf("Expected %d calls, got %d", 2, calls)
}
}

178
tools/hook/hook.go Normal file
View file

@ -0,0 +1,178 @@
package hook
import (
"sort"
"sync"
"github.com/pocketbase/pocketbase/tools/security"
)
// Handler defines a single Hook handler.
// Multiple handlers can share the same id.
// If Id is not explicitly set it will be autogenerated by Hook.Add and Hook.AddHandler.
type Handler[T Resolver] struct {
// Func defines the handler function to execute.
//
// Note that users need to call e.Next() in order to proceed with
// the execution of the hook chain.
Func func(T) error
// Id is the unique identifier of the handler.
//
// It could be used later to remove the handler from a hook via [Hook.Remove].
//
// If missing, an autogenerated value will be assigned when adding
// the handler to a hook.
Id string
// Priority allows changing the default exec priority of the handler within a hook.
//
// If 0, the handler will be executed in the same order it was registered.
Priority int
}
// Hook defines a generic concurrent safe structure for managing event hooks.
//
// When using custom event it must embed the base [hook.Event].
//
// Example:
//
// type CustomEvent struct {
// hook.Event
// SomeField int
// }
//
// h := Hook[*CustomEvent]{}
//
// h.BindFunc(func(e *CustomEvent) error {
// println(e.SomeField)
//
// return e.Next()
// })
//
// h.Trigger(&CustomEvent{ SomeField: 123 })
type Hook[T Resolver] struct {
handlers []*Handler[T]
mu sync.RWMutex
}
// Bind registers the provided handler to the current hooks queue.
//
// If handler.Id is empty it is updated with autogenerated value.
//
// If a handler from the current hook list has Id matching handler.Id
// then the old handler is replaced with the new one.
func (h *Hook[T]) Bind(handler *Handler[T]) string {
h.mu.Lock()
defer h.mu.Unlock()
var exists bool
if handler.Id == "" {
handler.Id = generateHookId()
// ensure that it doesn't exist
DUPLICATE_CHECK:
for _, existing := range h.handlers {
if existing.Id == handler.Id {
handler.Id = generateHookId()
goto DUPLICATE_CHECK
}
}
} else {
// replace existing
for i, existing := range h.handlers {
if existing.Id == handler.Id {
h.handlers[i] = handler
exists = true
break
}
}
}
// append new
if !exists {
h.handlers = append(h.handlers, handler)
}
// sort handlers by Priority, preserving the original order of equal items
sort.SliceStable(h.handlers, func(i, j int) bool {
return h.handlers[i].Priority < h.handlers[j].Priority
})
return handler.Id
}
// BindFunc is similar to Bind but registers a new handler from just the provided function.
//
// The registered handler is added with a default 0 priority and the id will be autogenerated.
//
// If you want to register a handler with custom priority or id use the [Hook.Bind] method.
func (h *Hook[T]) BindFunc(fn func(e T) error) string {
return h.Bind(&Handler[T]{Func: fn})
}
// Unbind removes one or many hook handler by their id.
func (h *Hook[T]) Unbind(idsToRemove ...string) {
h.mu.Lock()
defer h.mu.Unlock()
for _, id := range idsToRemove {
for i := len(h.handlers) - 1; i >= 0; i-- {
if h.handlers[i].Id == id {
h.handlers = append(h.handlers[:i], h.handlers[i+1:]...)
break // for now stop on the first occurrence since we don't allow handlers with duplicated ids
}
}
}
}
// UnbindAll removes all registered handlers.
func (h *Hook[T]) UnbindAll() {
h.mu.Lock()
defer h.mu.Unlock()
h.handlers = nil
}
// Length returns to total number of registered hook handlers.
func (h *Hook[T]) Length() int {
h.mu.RLock()
defer h.mu.RUnlock()
return len(h.handlers)
}
// Trigger executes all registered hook handlers one by one
// with the specified event as an argument.
//
// Optionally, this method allows also to register additional one off
// handler funcs that will be temporary appended to the handlers queue.
//
// NB! Each hook handler must call event.Next() in order the hook chain to proceed.
func (h *Hook[T]) Trigger(event T, oneOffHandlerFuncs ...func(T) error) error {
h.mu.RLock()
handlers := make([]func(T) error, 0, len(h.handlers)+len(oneOffHandlerFuncs))
for _, handler := range h.handlers {
handlers = append(handlers, handler.Func)
}
handlers = append(handlers, oneOffHandlerFuncs...)
h.mu.RUnlock()
event.setNextFunc(nil) // reset in case the event is being reused
for i := len(handlers) - 1; i >= 0; i-- {
i := i
old := event.nextFunc()
event.setNextFunc(func() error {
event.setNextFunc(old)
return handlers[i](event)
})
}
return event.Next()
}
func generateHookId() string {
return security.PseudorandomString(20)
}

162
tools/hook/hook_test.go Normal file
View file

@ -0,0 +1,162 @@
package hook
import (
"errors"
"testing"
)
func TestHookAddHandlerAndAdd(t *testing.T) {
calls := ""
h := Hook[*Event]{}
h.BindFunc(func(e *Event) error { calls += "1"; return e.Next() })
h.BindFunc(func(e *Event) error { calls += "2"; return e.Next() })
h3Id := h.BindFunc(func(e *Event) error { calls += "3"; return e.Next() })
h.Bind(&Handler[*Event]{
Id: h3Id, // should replace 3
Func: func(e *Event) error { calls += "3'"; return e.Next() },
})
h.Bind(&Handler[*Event]{
Func: func(e *Event) error { calls += "4"; return e.Next() },
Priority: -2,
})
h.Bind(&Handler[*Event]{
Func: func(e *Event) error { calls += "5"; return e.Next() },
Priority: -1,
})
h.Bind(&Handler[*Event]{
Func: func(e *Event) error { calls += "6"; return e.Next() },
})
h.Bind(&Handler[*Event]{
Func: func(e *Event) error { calls += "7"; e.Next(); return errors.New("test") }, // error shouldn't stop the chain
})
h.Trigger(
&Event{},
func(e *Event) error { calls += "8"; return e.Next() },
func(e *Event) error { calls += "9"; return nil }, // skip next
func(e *Event) error { calls += "10"; return e.Next() },
)
if total := len(h.handlers); total != 7 {
t.Fatalf("Expected %d handlers, found %d", 7, total)
}
expectedCalls := "45123'6789"
if calls != expectedCalls {
t.Fatalf("Expected calls sequence %q, got %q", expectedCalls, calls)
}
}
func TestHookLength(t *testing.T) {
h := Hook[*Event]{}
if l := h.Length(); l != 0 {
t.Fatalf("Expected 0 hook handlers, got %d", l)
}
h.BindFunc(func(e *Event) error { return e.Next() })
h.BindFunc(func(e *Event) error { return e.Next() })
if l := h.Length(); l != 2 {
t.Fatalf("Expected 2 hook handlers, got %d", l)
}
}
func TestHookUnbind(t *testing.T) {
h := Hook[*Event]{}
calls := ""
id0 := h.BindFunc(func(e *Event) error { calls += "0"; return e.Next() })
id1 := h.BindFunc(func(e *Event) error { calls += "1"; return e.Next() })
h.BindFunc(func(e *Event) error { calls += "2"; return e.Next() })
h.Bind(&Handler[*Event]{
Func: func(e *Event) error { calls += "3"; return e.Next() },
})
h.Unbind("missing") // should do nothing and not panic
if total := len(h.handlers); total != 4 {
t.Fatalf("Expected %d handlers, got %d", 4, total)
}
h.Unbind(id1, id0)
if total := len(h.handlers); total != 2 {
t.Fatalf("Expected %d handlers, got %d", 2, total)
}
err := h.Trigger(&Event{}, func(e *Event) error { calls += "4"; return e.Next() })
if err != nil {
t.Fatal(err)
}
expectedCalls := "234"
if calls != expectedCalls {
t.Fatalf("Expected calls sequence %q, got %q", expectedCalls, calls)
}
}
func TestHookUnbindAll(t *testing.T) {
h := Hook[*Event]{}
h.UnbindAll() // should do nothing and not panic
h.BindFunc(func(e *Event) error { return nil })
h.BindFunc(func(e *Event) error { return nil })
if total := len(h.handlers); total != 2 {
t.Fatalf("Expected %d handlers before UnbindAll, found %d", 2, total)
}
h.UnbindAll()
if total := len(h.handlers); total != 0 {
t.Fatalf("Expected no handlers after UnbindAll, found %d", total)
}
}
func TestHookTriggerErrorPropagation(t *testing.T) {
err := errors.New("test")
scenarios := []struct {
name string
handlers []func(*Event) error
expectedError error
}{
{
"without error",
[]func(*Event) error{
func(e *Event) error { return e.Next() },
func(e *Event) error { return e.Next() },
},
nil,
},
{
"with error",
[]func(*Event) error{
func(e *Event) error { return e.Next() },
func(e *Event) error { e.Next(); return err },
func(e *Event) error { return e.Next() },
},
err,
},
}
for _, s := range scenarios {
t.Run(s.name, func(t *testing.T) {
h := Hook[*Event]{}
for _, handler := range s.handlers {
h.BindFunc(handler)
}
result := h.Trigger(&Event{})
if result != s.expectedError {
t.Fatalf("Expected %v, got %v", s.expectedError, result)
}
})
}
}

84
tools/hook/tagged.go Normal file
View file

@ -0,0 +1,84 @@
package hook
import (
"github.com/pocketbase/pocketbase/tools/list"
)
// Tagger defines an interface for event data structs that support tags/groups/categories/etc.
// Usually used together with TaggedHook.
type Tagger interface {
Resolver
Tags() []string
}
// wrapped local Hook embedded struct to limit the public API surface.
type mainHook[T Tagger] struct {
*Hook[T]
}
// NewTaggedHook creates a new TaggedHook with the provided main hook and optional tags.
func NewTaggedHook[T Tagger](hook *Hook[T], tags ...string) *TaggedHook[T] {
return &TaggedHook[T]{
mainHook[T]{hook},
tags,
}
}
// TaggedHook defines a proxy hook which register handlers that are triggered only
// if the TaggedHook.tags are empty or includes at least one of the event data tag(s).
type TaggedHook[T Tagger] struct {
mainHook[T]
tags []string
}
// CanTriggerOn checks if the current TaggedHook can be triggered with
// the provided event data tags.
//
// It returns always true if the hook doens't have any tags.
func (h *TaggedHook[T]) CanTriggerOn(tagsToCheck []string) bool {
if len(h.tags) == 0 {
return true // match all
}
for _, t := range tagsToCheck {
if list.ExistInSlice(t, h.tags) {
return true
}
}
return false
}
// Bind registers the provided handler to the current hooks queue.
//
// It is similar to [Hook.Bind] with the difference that the handler
// function is invoked only if the event data tags satisfy h.CanTriggerOn.
func (h *TaggedHook[T]) Bind(handler *Handler[T]) string {
fn := handler.Func
handler.Func = func(e T) error {
if h.CanTriggerOn(e.Tags()) {
return fn(e)
}
return e.Next()
}
return h.mainHook.Bind(handler)
}
// BindFunc registers a new handler with the specified function.
//
// It is similar to [Hook.Bind] with the difference that the handler
// function is invoked only if the event data tags satisfy h.CanTriggerOn.
func (h *TaggedHook[T]) BindFunc(fn func(e T) error) string {
return h.mainHook.BindFunc(func(e T) error {
if h.CanTriggerOn(e.Tags()) {
return fn(e)
}
return e.Next()
})
}

84
tools/hook/tagged_test.go Normal file
View file

@ -0,0 +1,84 @@
package hook
import (
"strings"
"testing"
)
type mockTagsEvent struct {
Event
tags []string
}
func (m mockTagsEvent) Tags() []string {
return m.tags
}
func TestTaggedHook(t *testing.T) {
calls := ""
base := &Hook[*mockTagsEvent]{}
base.BindFunc(func(e *mockTagsEvent) error { calls += "f0"; return e.Next() })
hA := NewTaggedHook(base)
hA.BindFunc(func(e *mockTagsEvent) error { calls += "a1"; return e.Next() })
hA.Bind(&Handler[*mockTagsEvent]{
Func: func(e *mockTagsEvent) error { calls += "a2"; return e.Next() },
Priority: -1,
})
hB := NewTaggedHook(base, "b1", "b2")
hB.BindFunc(func(e *mockTagsEvent) error { calls += "b1"; return e.Next() })
hB.Bind(&Handler[*mockTagsEvent]{
Func: func(e *mockTagsEvent) error { calls += "b2"; return e.Next() },
Priority: -2,
})
hC := NewTaggedHook(base, "c1", "c2")
hC.BindFunc(func(e *mockTagsEvent) error { calls += "c1"; return e.Next() })
hC.Bind(&Handler[*mockTagsEvent]{
Func: func(e *mockTagsEvent) error { calls += "c2"; return e.Next() },
Priority: -3,
})
scenarios := []struct {
event *mockTagsEvent
expectedCalls string
}{
{
&mockTagsEvent{},
"a2f0a1",
},
{
&mockTagsEvent{tags: []string{"missing"}},
"a2f0a1",
},
{
&mockTagsEvent{tags: []string{"b2"}},
"b2a2f0a1b1",
},
{
&mockTagsEvent{tags: []string{"c1"}},
"c2a2f0a1c1",
},
{
&mockTagsEvent{tags: []string{"b1", "c2"}},
"c2b2a2f0a1b1c1",
},
}
for _, s := range scenarios {
t.Run(strings.Join(s.event.tags, "_"), func(t *testing.T) {
calls = "" // reset
err := base.Trigger(s.event)
if err != nil {
t.Fatalf("Unexpected trigger error: %v", err)
}
if calls != s.expectedCalls {
t.Fatalf("Expected calls sequence %q, got %q", s.expectedCalls, calls)
}
})
}
}

View file

@ -0,0 +1,113 @@
package inflector
import (
"regexp"
"strings"
"unicode"
)
var columnifyRemoveRegex = regexp.MustCompile(`[^\w\.\*\-\_\@\#]+`)
var snakecaseSplitRegex = regexp.MustCompile(`[\W_]+`)
// UcFirst converts the first character of a string into uppercase.
func UcFirst(str string) string {
if str == "" {
return ""
}
s := []rune(str)
return string(unicode.ToUpper(s[0])) + string(s[1:])
}
// Columnify strips invalid db identifier characters.
func Columnify(str string) string {
return columnifyRemoveRegex.ReplaceAllString(str, "")
}
// Sentenize converts and normalizes string into a sentence.
func Sentenize(str string) string {
str = strings.TrimSpace(str)
if str == "" {
return ""
}
str = UcFirst(str)
lastChar := str[len(str)-1:]
if lastChar != "." && lastChar != "?" && lastChar != "!" {
return str + "."
}
return str
}
// Sanitize sanitizes `str` by removing all characters satisfying `removePattern`.
// Returns an error if the pattern is not valid regex string.
func Sanitize(str string, removePattern string) (string, error) {
exp, err := regexp.Compile(removePattern)
if err != nil {
return "", err
}
return exp.ReplaceAllString(str, ""), nil
}
// Snakecase removes all non word characters and converts any english text into a snakecase.
// "ABBREVIATIONS" are preserved, eg. "myTestDB" will become "my_test_db".
func Snakecase(str string) string {
var result strings.Builder
// split at any non word character and underscore
words := snakecaseSplitRegex.Split(str, -1)
for _, word := range words {
if word == "" {
continue
}
if result.Len() > 0 {
result.WriteString("_")
}
for i, c := range word {
if unicode.IsUpper(c) && i > 0 &&
// is not a following uppercase character
!unicode.IsUpper(rune(word[i-1])) {
result.WriteString("_")
}
result.WriteRune(c)
}
}
return strings.ToLower(result.String())
}
// Camelize converts the provided string to its "CamelCased" version
// (non alphanumeric characters are removed).
//
// For example:
//
// inflector.Camelize("send_email") // "SendEmail"
func Camelize(str string) string {
var result strings.Builder
var isPrevSpecial bool
for _, c := range str {
if !unicode.IsLetter(c) && !unicode.IsNumber(c) {
isPrevSpecial = true
continue
}
if isPrevSpecial || result.Len() == 0 {
isPrevSpecial = false
result.WriteRune(unicode.ToUpper(c))
} else {
result.WriteRune(c)
}
}
return result.String()
}

View file

@ -0,0 +1,175 @@
package inflector_test
import (
"fmt"
"testing"
"github.com/pocketbase/pocketbase/tools/inflector"
)
func TestUcFirst(t *testing.T) {
scenarios := []struct {
val string
expected string
}{
{"", ""},
{" ", " "},
{"Test", "Test"},
{"test", "Test"},
{"test test2", "Test test2"},
}
for i, s := range scenarios {
t.Run(fmt.Sprintf("%d_%#v", i, s.val), func(t *testing.T) {
result := inflector.UcFirst(s.val)
if result != s.expected {
t.Fatalf("Expected %q, got %q", s.expected, result)
}
})
}
}
func TestColumnify(t *testing.T) {
scenarios := []struct {
val string
expected string
}{
{"", ""},
{" ", ""},
{"123", "123"},
{"Test.", "Test."},
{" test ", "test"},
{"test1.test2", "test1.test2"},
{"@test!abc", "@testabc"},
{"#test?abc", "#testabc"},
{"123test(123)#", "123test123#"},
{"test1--test2", "test1--test2"},
}
for i, s := range scenarios {
t.Run(fmt.Sprintf("%d_%#v", i, s.val), func(t *testing.T) {
result := inflector.Columnify(s.val)
if result != s.expected {
t.Fatalf("Expected %q, got %q", s.expected, result)
}
})
}
}
func TestSentenize(t *testing.T) {
scenarios := []struct {
val string
expected string
}{
{"", ""},
{" ", ""},
{".", "."},
{"?", "?"},
{"!", "!"},
{"Test", "Test."},
{" test ", "Test."},
{"hello world", "Hello world."},
{"hello world.", "Hello world."},
{"hello world!", "Hello world!"},
{"hello world?", "Hello world?"},
}
for i, s := range scenarios {
t.Run(fmt.Sprintf("%d_%#v", i, s.val), func(t *testing.T) {
result := inflector.Sentenize(s.val)
if result != s.expected {
t.Fatalf("Expected %q, got %q", s.expected, result)
}
})
}
}
func TestSanitize(t *testing.T) {
scenarios := []struct {
val string
pattern string
expected string
expectErr bool
}{
{"", ``, "", false},
{" ", ``, " ", false},
{" ", ` `, "", false},
{"", `[A-Z]`, "", false},
{"abcABC", `[A-Z]`, "abc", false},
{"abcABC", `[A-Z`, "", true}, // invalid pattern
}
for i, s := range scenarios {
t.Run(fmt.Sprintf("%d_%#v", i, s.val), func(t *testing.T) {
result, err := inflector.Sanitize(s.val, s.pattern)
hasErr := err != nil
if s.expectErr != hasErr {
t.Fatalf("Expected hasErr %v, got %v (%v)", s.expectErr, hasErr, err)
}
if result != s.expected {
t.Fatalf("Expected %q, got %q", s.expected, result)
}
})
}
}
func TestSnakecase(t *testing.T) {
scenarios := []struct {
val string
expected string
}{
{"", ""},
{" ", ""},
{"!@#$%^", ""},
{"...", ""},
{"_", ""},
{"John Doe", "john_doe"},
{"John_Doe", "john_doe"},
{".a!b@c#d$e%123. ", "a_b_c_d_e_123"},
{"HelloWorld", "hello_world"},
{"HelloWorld1HelloWorld2", "hello_world1_hello_world2"},
{"TEST", "test"},
{"testABR", "test_abr"},
}
for i, s := range scenarios {
t.Run(fmt.Sprintf("%d_%#v", i, s.val), func(t *testing.T) {
result := inflector.Snakecase(s.val)
if result != s.expected {
t.Fatalf("Expected %q, got %q", s.expected, result)
}
})
}
}
func TestCamelize(t *testing.T) {
scenarios := []struct {
val string
expected string
}{
{"", ""},
{" ", ""},
{"Test", "Test"},
{"test", "Test"},
{"testTest2", "TestTest2"},
{"TestTest2", "TestTest2"},
{"test test2", "TestTest2"},
{"test-test2", "TestTest2"},
{"test'test2", "TestTest2"},
{"test1test2", "Test1test2"},
{"1test-test2", "1testTest2"},
{"123", "123"},
{"123a", "123a"},
}
for i, s := range scenarios {
t.Run(fmt.Sprintf("%d_%#v", i, s.val), func(t *testing.T) {
result := inflector.Camelize(s.val)
if result != s.expected {
t.Fatalf("Expected %q, got %q", s.expected, result)
}
})
}
}

View file

@ -0,0 +1,89 @@
package inflector
import (
"log"
"regexp"
"github.com/pocketbase/pocketbase/tools/store"
)
var compiledPatterns = store.New[string, *regexp.Regexp](nil)
// note: the patterns are extracted from popular Ruby/PHP/Node.js inflector packages
var singularRules = []struct {
pattern string // lazily compiled
replacement string
}{
{"(?i)([nrlm]ese|deer|fish|sheep|measles|ois|pox|media|ss)$", "${1}"},
{"(?i)^(sea[- ]bass)$", "${1}"},
{"(?i)(s)tatuses$", "${1}tatus"},
{"(?i)(f)eet$", "${1}oot"},
{"(?i)(t)eeth$", "${1}ooth"},
{"(?i)^(.*)(menu)s$", "${1}${2}"},
{"(?i)(quiz)zes$", "${1}"},
{"(?i)(matr)ices$", "${1}ix"},
{"(?i)(vert|ind)ices$", "${1}ex"},
{"(?i)^(ox)en", "${1}"},
{"(?i)(alias)es$", "${1}"},
{"(?i)(alumn|bacill|cact|foc|fung|nucle|radi|stimul|syllab|termin|viri?)i$", "${1}us"},
{"(?i)([ftw]ax)es", "${1}"},
{"(?i)(cris|ax|test)es$", "${1}is"},
{"(?i)(shoe)s$", "${1}"},
{"(?i)(o)es$", "${1}"},
{"(?i)ouses$", "ouse"},
{"(?i)([^a])uses$", "${1}us"},
{"(?i)([m|l])ice$", "${1}ouse"},
{"(?i)(x|ch|ss|sh)es$", "${1}"},
{"(?i)(m)ovies$", "${1}ovie"},
{"(?i)(s)eries$", "${1}eries"},
{"(?i)([^aeiouy]|qu)ies$", "${1}y"},
{"(?i)([lr])ves$", "${1}f"},
{"(?i)(tive)s$", "${1}"},
{"(?i)(hive)s$", "${1}"},
{"(?i)(drive)s$", "${1}"},
{"(?i)([^fo])ves$", "${1}fe"},
{"(?i)(^analy)ses$", "${1}sis"},
{"(?i)(analy|diagno|^ba|(p)arenthe|(p)rogno|(s)ynop|(t)he)ses$", "${1}${2}sis"},
{"(?i)([ti])a$", "${1}um"},
{"(?i)(p)eople$", "${1}erson"},
{"(?i)(m)en$", "${1}an"},
{"(?i)(c)hildren$", "${1}hild"},
{"(?i)(n)ews$", "${1}ews"},
{"(?i)(n)etherlands$", "${1}etherlands"},
{"(?i)eaus$", "eau"},
{"(?i)(currenc)ies$", "${1}y"},
{"(?i)^(.*us)$", "${1}"},
{"(?i)s$", ""},
}
// Singularize converts the specified word into its singular version.
//
// For example:
//
// inflector.Singularize("people") // "person"
func Singularize(word string) string {
if word == "" {
return ""
}
for _, rule := range singularRules {
re := compiledPatterns.GetOrSet(rule.pattern, func() *regexp.Regexp {
re, err := regexp.Compile(rule.pattern)
if err != nil {
return nil
}
return re
})
if re == nil {
// log only for debug purposes
log.Println("[Singularize] failed to retrieve/compile rule pattern " + rule.pattern)
continue
}
if re.MatchString(word) {
return re.ReplaceAllString(word, rule.replacement)
}
}
return word
}

View file

@ -0,0 +1,76 @@
package inflector_test
import (
"testing"
"github.com/pocketbase/pocketbase/tools/inflector"
)
func TestSingularize(t *testing.T) {
scenarios := []struct {
word string
expected string
}{
{"abcnese", "abcnese"},
{"deer", "deer"},
{"sheep", "sheep"},
{"measles", "measles"},
{"pox", "pox"},
{"media", "media"},
{"bliss", "bliss"},
{"sea-bass", "sea-bass"},
{"Statuses", "Status"},
{"Feet", "Foot"},
{"Teeth", "Tooth"},
{"abcmenus", "abcmenu"},
{"Quizzes", "Quiz"},
{"Matrices", "Matrix"},
{"Vertices", "Vertex"},
{"Indices", "Index"},
{"Aliases", "Alias"},
{"Alumni", "Alumnus"},
{"Bacilli", "Bacillus"},
{"Cacti", "Cactus"},
{"Fungi", "Fungus"},
{"Nuclei", "Nucleus"},
{"Radii", "Radius"},
{"Stimuli", "Stimulus"},
{"Syllabi", "Syllabus"},
{"Termini", "Terminus"},
{"Viri", "Virus"},
{"Faxes", "Fax"},
{"Crises", "Crisis"},
{"Axes", "Axis"},
{"Shoes", "Shoe"},
{"abcoes", "abco"},
{"Houses", "House"},
{"Mice", "Mouse"},
{"abcxes", "abcx"},
{"Movies", "Movie"},
{"Series", "Series"},
{"abcquies", "abcquy"},
{"Relatives", "Relative"},
{"Drives", "Drive"},
{"aardwolves", "aardwolf"},
{"Analyses", "Analysis"},
{"Diagnoses", "Diagnosis"},
{"People", "Person"},
{"Men", "Man"},
{"Children", "Child"},
{"News", "News"},
{"Netherlands", "Netherlands"},
{"Tableaus", "Tableau"},
{"Currencies", "Currency"},
{"abcs", "abc"},
{"abc", "abc"},
}
for _, s := range scenarios {
t.Run(s.word, func(t *testing.T) {
result := inflector.Singularize(s.word)
if result != s.expected {
t.Fatalf("Expected %q, got %q", s.expected, result)
}
})
}
}

163
tools/list/list.go Normal file
View file

@ -0,0 +1,163 @@
package list
import (
"encoding/json"
"regexp"
"strings"
"github.com/pocketbase/pocketbase/tools/store"
"github.com/spf13/cast"
)
var cachedPatterns = store.New[string, *regexp.Regexp](nil)
// SubtractSlice returns a new slice with only the "base" elements
// that don't exist in "subtract".
func SubtractSlice[T comparable](base []T, subtract []T) []T {
var result = make([]T, 0, len(base))
for _, b := range base {
if !ExistInSlice(b, subtract) {
result = append(result, b)
}
}
return result
}
// ExistInSlice checks whether a comparable element exists in a slice of the same type.
func ExistInSlice[T comparable](item T, list []T) bool {
for _, v := range list {
if v == item {
return true
}
}
return false
}
// ExistInSliceWithRegex checks whether a string exists in a slice
// either by direct match, or by a regular expression (eg. `^\w+$`).
//
// Note: Only list items starting with '^' and ending with '$' are treated as regular expressions!
func ExistInSliceWithRegex(str string, list []string) bool {
for _, field := range list {
isRegex := strings.HasPrefix(field, "^") && strings.HasSuffix(field, "$")
if !isRegex {
// check for direct match
if str == field {
return true
}
continue
}
// check for regex match
pattern := cachedPatterns.Get(field)
if pattern == nil {
var err error
pattern, err = regexp.Compile(field)
if err != nil {
continue
}
// "cache" the pattern to avoid compiling it every time
// (the limit size is arbitrary and it is there to prevent the cache growing too big)
//
// @todo consider replacing with TTL or LRU type cache
cachedPatterns.SetIfLessThanLimit(field, pattern, 500)
}
if pattern != nil && pattern.MatchString(str) {
return true
}
}
return false
}
// ToInterfaceSlice converts a generic slice to slice of interfaces.
func ToInterfaceSlice[T any](list []T) []any {
result := make([]any, len(list))
for i := range list {
result[i] = list[i]
}
return result
}
// NonzeroUniques returns only the nonzero unique values from a slice.
func NonzeroUniques[T comparable](list []T) []T {
result := make([]T, 0, len(list))
existMap := make(map[T]struct{}, len(list))
var zeroVal T
for _, val := range list {
if val == zeroVal {
continue
}
if _, ok := existMap[val]; ok {
continue
}
existMap[val] = struct{}{}
result = append(result, val)
}
return result
}
// ToUniqueStringSlice casts `value` to a slice of non-zero unique strings.
func ToUniqueStringSlice(value any) (result []string) {
switch val := value.(type) {
case nil:
// nothing to cast
case []string:
result = val
case string:
if val == "" {
break
}
// check if it is a json encoded array of strings
if strings.Contains(val, "[") {
if err := json.Unmarshal([]byte(val), &result); err != nil {
// not a json array, just add the string as single array element
result = append(result, val)
}
} else {
// just add the string as single array element
result = append(result, val)
}
case json.Marshaler: // eg. JSONArray
raw, _ := val.MarshalJSON()
_ = json.Unmarshal(raw, &result)
default:
result = cast.ToStringSlice(value)
}
return NonzeroUniques(result)
}
// ToChunks splits list into chunks.
//
// Zero or negative chunkSize argument is normalized to 1.
//
// See https://go.dev/wiki/SliceTricks#batching-with-minimal-allocation.
func ToChunks[T any](list []T, chunkSize int) [][]T {
if chunkSize <= 0 {
chunkSize = 1
}
chunks := make([][]T, 0, (len(list)+chunkSize-1)/chunkSize)
if len(list) == 0 {
return chunks
}
for chunkSize < len(list) {
list, chunks = list[chunkSize:], append(chunks, list[0:chunkSize:chunkSize])
}
return append(chunks, list)
}

310
tools/list/list_test.go Normal file
View file

@ -0,0 +1,310 @@
package list_test
import (
"encoding/json"
"fmt"
"testing"
"github.com/pocketbase/pocketbase/tools/list"
"github.com/pocketbase/pocketbase/tools/types"
)
func TestSubtractSliceString(t *testing.T) {
scenarios := []struct {
base []string
subtract []string
expected string
}{
{
[]string{},
[]string{},
`[]`,
},
{
[]string{},
[]string{"1", "2", "3", "4"},
`[]`,
},
{
[]string{"1", "2", "3", "4"},
[]string{},
`["1","2","3","4"]`,
},
{
[]string{"1", "2", "3", "4"},
[]string{"1", "2", "3", "4"},
`[]`,
},
{
[]string{"1", "2", "3", "4", "7"},
[]string{"2", "4", "5", "6"},
`["1","3","7"]`,
},
}
for i, s := range scenarios {
t.Run(fmt.Sprintf("%d_%s", i, s.expected), func(t *testing.T) {
result := list.SubtractSlice(s.base, s.subtract)
raw, err := json.Marshal(result)
if err != nil {
t.Fatalf("Failed to serialize: %v", err)
}
strResult := string(raw)
if strResult != s.expected {
t.Fatalf("Expected %v, got %v", s.expected, strResult)
}
})
}
}
func TestSubtractSliceInt(t *testing.T) {
scenarios := []struct {
base []int
subtract []int
expected string
}{
{
[]int{},
[]int{},
`[]`,
},
{
[]int{},
[]int{1, 2, 3, 4},
`[]`,
},
{
[]int{1, 2, 3, 4},
[]int{},
`[1,2,3,4]`,
},
{
[]int{1, 2, 3, 4},
[]int{1, 2, 3, 4},
`[]`,
},
{
[]int{1, 2, 3, 4, 7},
[]int{2, 4, 5, 6},
`[1,3,7]`,
},
}
for i, s := range scenarios {
t.Run(fmt.Sprintf("%d_%s", i, s.expected), func(t *testing.T) {
result := list.SubtractSlice(s.base, s.subtract)
raw, err := json.Marshal(result)
if err != nil {
t.Fatalf("Failed to serialize: %v", err)
}
strResult := string(raw)
if strResult != s.expected {
t.Fatalf("Expected %v, got %v", s.expected, strResult)
}
})
}
}
func TestExistInSliceString(t *testing.T) {
scenarios := []struct {
item string
list []string
expected bool
}{
{"", []string{""}, true},
{"", []string{"1", "2", "test 123"}, false},
{"test", []string{}, false},
{"test", []string{"TEST"}, false},
{"test", []string{"1", "2", "test 123"}, false},
{"test", []string{"1", "2", "test"}, true},
}
for i, s := range scenarios {
t.Run(fmt.Sprintf("%d_%s", i, s.item), func(t *testing.T) {
result := list.ExistInSlice(s.item, s.list)
if result != s.expected {
t.Fatalf("Expected %v, got %v", s.expected, result)
}
})
}
}
func TestExistInSliceInt(t *testing.T) {
scenarios := []struct {
item int
list []int
expected bool
}{
{0, []int{}, false},
{0, []int{0}, true},
{4, []int{1, 2, 3}, false},
{1, []int{1, 2, 3}, true},
{-1, []int{0, 1, 2, 3}, false},
{-1, []int{0, -1, -2, -3, -4}, true},
}
for i, s := range scenarios {
t.Run(fmt.Sprintf("%d_%d", i, s.item), func(t *testing.T) {
result := list.ExistInSlice(s.item, s.list)
if result != s.expected {
t.Fatalf("Expected %v, got %v", s.expected, result)
}
})
}
}
func TestExistInSliceWithRegex(t *testing.T) {
scenarios := []struct {
item string
list []string
expected bool
}{
{"", []string{``}, true},
{"", []string{`^\W+$`}, false},
{" ", []string{`^\W+$`}, true},
{"test", []string{`^\invalid[+$`}, false}, // invalid regex
{"test", []string{`^\W+$`, "test"}, true},
{`^\W+$`, []string{`^\W+$`, "test"}, false}, // direct match shouldn't work for this case
{`\W+$`, []string{`\W+$`, "test"}, true}, // direct match should work for this case because it is not an actual supported pattern format
{"!?@", []string{`\W+$`, "test"}, false}, // the method requires the pattern elems to start with '^'
{"!?@", []string{`^\W+`, "test"}, false}, // the method requires the pattern elems to end with '$'
{"!?@", []string{`^\W+$`, "test"}, true},
{"!?@test", []string{`^\W+$`, "test"}, false},
}
for i, s := range scenarios {
t.Run(fmt.Sprintf("%d_%s", i, s.item), func(t *testing.T) {
result := list.ExistInSliceWithRegex(s.item, s.list)
if result != s.expected {
t.Fatalf("Expected %v, got %v", s.expected, result)
}
})
}
}
func TestToInterfaceSlice(t *testing.T) {
scenarios := []struct {
items []string
}{
{[]string{}},
{[]string{""}},
{[]string{"1", "test"}},
{[]string{"test1", "test1", "test2", "test3"}},
}
for i, s := range scenarios {
t.Run(fmt.Sprintf("%d_%#v", i, s.items), func(t *testing.T) {
result := list.ToInterfaceSlice(s.items)
if len(result) != len(s.items) {
t.Fatalf("Expected length %d, got %d", len(s.items), len(result))
}
for j, v := range result {
if v != s.items[j] {
t.Fatalf("Result list item doesn't match with the original list item, got %v VS %v", v, s.items[j])
}
}
})
}
}
func TestNonzeroUniquesString(t *testing.T) {
scenarios := []struct {
items []string
expected []string
}{
{[]string{}, []string{}},
{[]string{""}, []string{}},
{[]string{"1", "test"}, []string{"1", "test"}},
{[]string{"test1", "", "test2", "Test2", "test1", "test3"}, []string{"test1", "test2", "Test2", "test3"}},
}
for i, s := range scenarios {
t.Run(fmt.Sprintf("%d_%#v", i, s.items), func(t *testing.T) {
result := list.NonzeroUniques(s.items)
if len(result) != len(s.expected) {
t.Fatalf("Expected length %d, got %d", len(s.expected), len(result))
}
for j, v := range result {
if v != s.expected[j] {
t.Fatalf("Result list item doesn't match with the expected list item, got %v VS %v", v, s.expected[j])
}
}
})
}
}
func TestToUniqueStringSlice(t *testing.T) {
scenarios := []struct {
value any
expected []string
}{
{nil, []string{}},
{"", []string{}},
{[]any{}, []string{}},
{[]int{}, []string{}},
{"test", []string{"test"}},
{[]int{1, 2, 3}, []string{"1", "2", "3"}},
{[]any{0, 1, "test", ""}, []string{"0", "1", "test"}},
{[]string{"test1", "test2", "test1"}, []string{"test1", "test2"}},
{`["test1", "test2", "test2"]`, []string{"test1", "test2"}},
{types.JSONArray[string]{"test1", "test2", "test1"}, []string{"test1", "test2"}},
}
for i, s := range scenarios {
t.Run(fmt.Sprintf("%d_%#v", i, s.value), func(t *testing.T) {
result := list.ToUniqueStringSlice(s.value)
if len(result) != len(s.expected) {
t.Fatalf("Expected length %d, got %d", len(s.expected), len(result))
}
for j, v := range result {
if v != s.expected[j] {
t.Fatalf("Result list item doesn't match with the expected list item, got %v vs %v", v, s.expected[j])
}
}
})
}
}
func TestToChunks(t *testing.T) {
scenarios := []struct {
items []any
chunkSize int
expected string
}{
{nil, 2, "[]"},
{[]any{}, 2, "[]"},
{[]any{1, 2, 3, 4}, -1, "[[1],[2],[3],[4]]"},
{[]any{1, 2, 3, 4}, 0, "[[1],[2],[3],[4]]"},
{[]any{1, 2, 3, 4}, 2, "[[1,2],[3,4]]"},
{[]any{1, 2, 3, 4, 5}, 2, "[[1,2],[3,4],[5]]"},
{[]any{1, 2, 3, 4, 5}, 10, "[[1,2,3,4,5]]"},
}
for i, s := range scenarios {
t.Run(fmt.Sprintf("%d_%#v", i, s.items), func(t *testing.T) {
result := list.ToChunks(s.items, s.chunkSize)
raw, err := json.Marshal(result)
if err != nil {
t.Fatal(err)
}
rawStr := string(raw)
if rawStr != s.expected {
t.Fatalf("Expected %v, got %v", s.expected, rawStr)
}
})
}
}

View file

@ -0,0 +1,325 @@
package logger
import (
"context"
"encoding/json"
"errors"
"log/slog"
"sync"
validation "github.com/go-ozzo/ozzo-validation/v4"
"github.com/pocketbase/pocketbase/tools/types"
)
var _ slog.Handler = (*BatchHandler)(nil)
// BatchOptions are options for the BatchHandler.
type BatchOptions struct {
// WriteFunc processes the batched logs.
WriteFunc func(ctx context.Context, logs []*Log) error
// BeforeAddFunc is optional function that is invoked every time
// before a new log is added to the batch queue.
//
// Return false to skip adding the log into the batch queue.
BeforeAddFunc func(ctx context.Context, log *Log) bool
// Level reports the minimum level to log.
// Levels with lower levels are discarded.
// If nil, the Handler uses [slog.LevelInfo].
Level slog.Leveler
// BatchSize specifies how many logs to accumulate before calling WriteFunc.
// If not set or 0, fallback to 100 by default.
BatchSize int
}
// NewBatchHandler creates a slog compatible handler that writes JSON
// logs on batches (default to 100), using the given options.
//
// Panics if [BatchOptions.WriteFunc] is not defined.
//
// Example:
//
// l := slog.New(logger.NewBatchHandler(logger.BatchOptions{
// WriteFunc: func(ctx context.Context, logs []*Log) error {
// for _, l := range logs {
// fmt.Println(l.Level, l.Message, l.Data)
// }
// return nil
// }
// }))
// l.Info("Example message", "title", "lorem ipsum")
func NewBatchHandler(options BatchOptions) *BatchHandler {
h := &BatchHandler{
mux: &sync.Mutex{},
options: &options,
}
if h.options.WriteFunc == nil {
panic("options.WriteFunc must be set")
}
if h.options.Level == nil {
h.options.Level = slog.LevelInfo
}
if h.options.BatchSize == 0 {
h.options.BatchSize = 100
}
h.logs = make([]*Log, 0, h.options.BatchSize)
return h
}
// BatchHandler is a slog handler that writes records on batches.
//
// The log records attributes are formatted in JSON.
//
// Requires the [BatchOptions.WriteFunc] option to be defined.
type BatchHandler struct {
mux *sync.Mutex
parent *BatchHandler
options *BatchOptions
group string
attrs []slog.Attr
logs []*Log
}
// Enabled reports whether the handler handles records at the given level.
//
// The handler ignores records whose level is lower.
func (h *BatchHandler) Enabled(ctx context.Context, level slog.Level) bool {
return level >= h.options.Level.Level()
}
// WithGroup returns a new BatchHandler that starts a group.
//
// All logger attributes will be resolved under the specified group name.
func (h *BatchHandler) WithGroup(name string) slog.Handler {
if name == "" {
return h
}
return &BatchHandler{
parent: h,
mux: h.mux,
options: h.options,
group: name,
}
}
// WithAttrs returns a new BatchHandler loaded with the specified attributes.
func (h *BatchHandler) WithAttrs(attrs []slog.Attr) slog.Handler {
if len(attrs) == 0 {
return h
}
return &BatchHandler{
parent: h,
mux: h.mux,
options: h.options,
attrs: attrs,
}
}
// Handle formats the slog.Record argument as JSON object and adds it
// to the batch queue.
//
// If the batch queue threshold has been reached, the WriteFunc option
// is invoked with the accumulated logs which in turn will reset the batch queue.
func (h *BatchHandler) Handle(ctx context.Context, r slog.Record) error {
if h.group != "" {
h.mux.Lock()
attrs := make([]any, 0, len(h.attrs)+r.NumAttrs())
for _, a := range h.attrs {
attrs = append(attrs, a)
}
h.mux.Unlock()
r.Attrs(func(a slog.Attr) bool {
attrs = append(attrs, a)
return true
})
r = slog.NewRecord(r.Time, r.Level, r.Message, r.PC)
r.AddAttrs(slog.Group(h.group, attrs...))
} else if len(h.attrs) > 0 {
r = r.Clone()
h.mux.Lock()
r.AddAttrs(h.attrs...)
h.mux.Unlock()
}
if h.parent != nil {
return h.parent.Handle(ctx, r)
}
data := make(map[string]any, r.NumAttrs())
r.Attrs(func(a slog.Attr) bool {
if err := h.resolveAttr(data, a); err != nil {
return false
}
return true
})
log := &Log{
Time: r.Time,
Level: r.Level,
Message: r.Message,
Data: types.JSONMap[any](data),
}
if h.options.BeforeAddFunc != nil && !h.options.BeforeAddFunc(ctx, log) {
return nil
}
h.mux.Lock()
h.logs = append(h.logs, log)
totalLogs := len(h.logs)
h.mux.Unlock()
if totalLogs >= h.options.BatchSize {
if err := h.WriteAll(ctx); err != nil {
return err
}
}
return nil
}
// SetLevel updates the handler options level to the specified one.
func (h *BatchHandler) SetLevel(level slog.Level) {
h.mux.Lock()
h.options.Level = level
h.mux.Unlock()
}
// WriteAll writes all accumulated Log entries and resets the batch queue.
func (h *BatchHandler) WriteAll(ctx context.Context) error {
if h.parent != nil {
// invoke recursively the parent level handler since the most
// top level one is holding the logs queue.
return h.parent.WriteAll(ctx)
}
h.mux.Lock()
totalLogs := len(h.logs)
// no logs to write
if totalLogs == 0 {
h.mux.Unlock()
return nil
}
// create a copy of the logs slice to prevent blocking during write
logs := make([]*Log, totalLogs)
copy(logs, h.logs)
h.logs = h.logs[:0] // reset
h.mux.Unlock()
return h.options.WriteFunc(ctx, logs)
}
// resolveAttr writes attr into data.
func (h *BatchHandler) resolveAttr(data map[string]any, attr slog.Attr) error {
// ensure that the attr value is resolved before doing anything else
attr.Value = attr.Value.Resolve()
if attr.Equal(slog.Attr{}) {
return nil // ignore empty attrs
}
switch attr.Value.Kind() {
case slog.KindGroup:
attrs := attr.Value.Group()
if len(attrs) == 0 {
return nil // ignore empty groups
}
// create a submap to wrap the resolved group attributes
groupData := make(map[string]any, len(attrs))
for _, subAttr := range attrs {
h.resolveAttr(groupData, subAttr)
}
if len(groupData) > 0 {
data[attr.Key] = groupData
}
default:
data[attr.Key] = normalizeLogAttrValue(attr.Value.Any())
}
return nil
}
func normalizeLogAttrValue(rawAttrValue any) any {
switch attrV := rawAttrValue.(type) {
case validation.Errors:
out := make(map[string]any, len(attrV))
for k, v := range attrV {
out[k] = serializeLogError(v)
}
return out
case map[string]validation.Error:
out := make(map[string]any, len(attrV))
for k, v := range attrV {
out[k] = serializeLogError(v)
}
return out
case map[string]error:
out := make(map[string]any, len(attrV))
for k, v := range attrV {
out[k] = serializeLogError(v)
}
return out
case map[string]any:
out := make(map[string]any, len(attrV))
for k, v := range attrV {
switch vv := v.(type) {
case error:
out[k] = serializeLogError(vv)
default:
out[k] = normalizeLogAttrValue(vv)
}
}
return out
case error:
// check for wrapped validation.Errors
var ve validation.Errors
if errors.As(attrV, &ve) {
out := make(map[string]any, len(ve))
for k, v := range ve {
out[k] = serializeLogError(v)
}
return map[string]any{
"data": out,
"raw": serializeLogError(attrV),
}
}
return serializeLogError(attrV)
default:
return attrV
}
}
func serializeLogError(err error) any {
if err == nil {
return nil
}
// prioritize a json structured format (e.g. validation.Errors)
jsonErr, ok := err.(json.Marshaler)
if ok {
return jsonErr
}
// fallback to its original string representation
return err.Error()
}

View file

@ -0,0 +1,354 @@
package logger
import (
"context"
"errors"
"fmt"
"log/slog"
"testing"
"time"
validation "github.com/go-ozzo/ozzo-validation/v4"
)
func TestNewBatchHandlerPanic(t *testing.T) {
defer func() {
if r := recover(); r == nil {
t.Errorf("Expected to panic.")
}
}()
NewBatchHandler(BatchOptions{})
}
func TestNewBatchHandlerDefaults(t *testing.T) {
h := NewBatchHandler(BatchOptions{
WriteFunc: func(ctx context.Context, logs []*Log) error {
return nil
},
})
if h.options.BatchSize != 100 {
t.Fatalf("Expected default BatchSize %d, got %d", 100, h.options.BatchSize)
}
if h.options.Level != slog.LevelInfo {
t.Fatalf("Expected default Level Info, got %v", h.options.Level)
}
if h.options.BeforeAddFunc != nil {
t.Fatal("Expected default BeforeAddFunc to be nil")
}
if h.options.WriteFunc == nil {
t.Fatal("Expected default WriteFunc to be set")
}
if h.group != "" {
t.Fatalf("Expected empty group, got %s", h.group)
}
if len(h.attrs) != 0 {
t.Fatalf("Expected empty attrs, got %v", h.attrs)
}
if len(h.logs) != 0 {
t.Fatalf("Expected empty logs queue, got %v", h.logs)
}
}
func TestBatchHandlerEnabled(t *testing.T) {
h := NewBatchHandler(BatchOptions{
Level: slog.LevelWarn,
WriteFunc: func(ctx context.Context, logs []*Log) error {
return nil
},
})
l := slog.New(h)
scenarios := []struct {
level slog.Level
expected bool
}{
{slog.LevelDebug, false},
{slog.LevelInfo, false},
{slog.LevelWarn, true},
{slog.LevelError, true},
}
for _, s := range scenarios {
t.Run(fmt.Sprintf("Level %v", s.level), func(t *testing.T) {
result := l.Enabled(context.Background(), s.level)
if result != s.expected {
t.Fatalf("Expected %v, got %v", s.expected, result)
}
})
}
}
func TestBatchHandlerSetLevel(t *testing.T) {
h := NewBatchHandler(BatchOptions{
Level: slog.LevelWarn,
WriteFunc: func(ctx context.Context, logs []*Log) error {
return nil
},
})
if h.options.Level != slog.LevelWarn {
t.Fatalf("Expected the initial level to be %d, got %d", slog.LevelWarn, h.options.Level)
}
h.SetLevel(slog.LevelDebug)
if h.options.Level != slog.LevelDebug {
t.Fatalf("Expected the updated level to be %d, got %d", slog.LevelDebug, h.options.Level)
}
}
func TestBatchHandlerWithAttrsAndWithGroup(t *testing.T) {
h0 := NewBatchHandler(BatchOptions{
WriteFunc: func(ctx context.Context, logs []*Log) error {
return nil
},
})
h1 := h0.WithAttrs([]slog.Attr{slog.Int("test1", 1)}).(*BatchHandler)
h2 := h1.WithGroup("h2_group").(*BatchHandler)
h3 := h2.WithAttrs([]slog.Attr{slog.Int("test2", 2)}).(*BatchHandler)
scenarios := []struct {
name string
handler *BatchHandler
expectedParent *BatchHandler
expectedGroup string
expectedAttrs int
}{
{
"h0",
h0,
nil,
"",
0,
},
{
"h1",
h1,
h0,
"",
1,
},
{
"h2",
h2,
h1,
"h2_group",
0,
},
{
"h3",
h3,
h2,
"",
1,
},
}
for _, s := range scenarios {
t.Run(s.name, func(t *testing.T) {
if s.handler.group != s.expectedGroup {
t.Fatalf("Expected group %q, got %q", s.expectedGroup, s.handler.group)
}
if s.handler.parent != s.expectedParent {
t.Fatalf("Expected parent %v, got %v", s.expectedParent, s.handler.parent)
}
if totalAttrs := len(s.handler.attrs); totalAttrs != s.expectedAttrs {
t.Fatalf("Expected %d attrs, got %d", s.expectedAttrs, totalAttrs)
}
})
}
}
func TestBatchHandlerHandle(t *testing.T) {
ctx := context.Background()
beforeLogs := []*Log{}
writeLogs := []*Log{}
h := NewBatchHandler(BatchOptions{
BatchSize: 3,
BeforeAddFunc: func(_ context.Context, log *Log) bool {
beforeLogs = append(beforeLogs, log)
// skip test2 log
return log.Message != "test2"
},
WriteFunc: func(_ context.Context, logs []*Log) error {
writeLogs = logs
return nil
},
})
h.Handle(ctx, slog.NewRecord(time.Now(), slog.LevelInfo, "test1", 0))
h.Handle(ctx, slog.NewRecord(time.Now(), slog.LevelInfo, "test2", 0))
h.Handle(ctx, slog.NewRecord(time.Now(), slog.LevelInfo, "test3", 0))
// no batch write
{
checkLogMessages([]string{"test1", "test2", "test3"}, beforeLogs, t)
checkLogMessages([]string{"test1", "test3"}, h.logs, t)
// should be empty because no batch write has happened yet
if totalWriteLogs := len(writeLogs); totalWriteLogs != 0 {
t.Fatalf("Expected %d writeLogs, got %d", 0, totalWriteLogs)
}
}
// add one more log to trigger the batch write
{
h.Handle(ctx, slog.NewRecord(time.Now(), slog.LevelInfo, "test4", 0))
// should be empty after the batch write
checkLogMessages([]string{}, h.logs, t)
checkLogMessages([]string{"test1", "test3", "test4"}, writeLogs, t)
}
}
func TestBatchHandlerWriteAll(t *testing.T) {
ctx := context.Background()
beforeLogs := []*Log{}
writeLogs := []*Log{}
h := NewBatchHandler(BatchOptions{
BatchSize: 3,
BeforeAddFunc: func(_ context.Context, log *Log) bool {
beforeLogs = append(beforeLogs, log)
return true
},
WriteFunc: func(_ context.Context, logs []*Log) error {
writeLogs = logs
return nil
},
})
h.Handle(ctx, slog.NewRecord(time.Now(), slog.LevelInfo, "test1", 0))
h.Handle(ctx, slog.NewRecord(time.Now(), slog.LevelInfo, "test2", 0))
checkLogMessages([]string{"test1", "test2"}, beforeLogs, t)
checkLogMessages([]string{"test1", "test2"}, h.logs, t)
checkLogMessages([]string{}, writeLogs, t) // empty because the batch size hasn't been reached
// force trigger the batch write
h.WriteAll(ctx)
checkLogMessages([]string{"test1", "test2"}, beforeLogs, t)
checkLogMessages([]string{}, h.logs, t) // reset
checkLogMessages([]string{"test1", "test2"}, writeLogs, t)
}
func TestBatchHandlerAttrsFormat(t *testing.T) {
ctx := context.Background()
beforeLogs := []*Log{}
h0 := NewBatchHandler(BatchOptions{
BeforeAddFunc: func(_ context.Context, log *Log) bool {
beforeLogs = append(beforeLogs, log)
return true
},
WriteFunc: func(_ context.Context, logs []*Log) error {
return nil
},
})
h1 := h0.WithAttrs([]slog.Attr{slog.Int("a", 1), slog.String("b", "123")})
h2 := h1.WithGroup("sub").WithAttrs([]slog.Attr{
slog.Int("c", 3),
slog.Any("d", map[string]any{"d.1": 1}),
slog.Any("e", errors.New("example error")),
})
record := slog.NewRecord(time.Now(), slog.LevelInfo, "hello", 0)
record.AddAttrs(slog.String("name", "test"))
h0.Handle(ctx, record)
h1.Handle(ctx, record)
h2.Handle(ctx, record)
// errors serialization checks
errorsRecord := slog.NewRecord(time.Now(), slog.LevelError, "details", 0)
errorsRecord.Add("validation.Errors", validation.Errors{
"a": validation.NewError("validation_code", "validation_message"),
"b": errors.New("plain"),
})
errorsRecord.Add("wrapped_validation.Errors", fmt.Errorf("wrapped: %w", validation.Errors{
"a": validation.NewError("validation_code", "validation_message"),
"b": errors.New("plain"),
}))
errorsRecord.Add("map[string]any", map[string]any{
"a": validation.NewError("validation_code", "validation_message"),
"b": errors.New("plain"),
"c": "test_any",
"d": map[string]any{
"nestedA": validation.NewError("nested_code", "nested_message"),
"nestedB": errors.New("nested_plain"),
},
})
errorsRecord.Add("map[string]error", map[string]error{
"a": validation.NewError("validation_code", "validation_message"),
"b": errors.New("plain"),
})
errorsRecord.Add("map[string]validation.Error", map[string]validation.Error{
"a": validation.NewError("validation_code", "validation_message"),
"b": nil,
})
errorsRecord.Add("plain_error", errors.New("plain"))
h0.Handle(ctx, errorsRecord)
expected := []string{
`{"name":"test"}`,
`{"a":1,"b":"123","name":"test"}`,
`{"a":1,"b":"123","sub":{"c":3,"d":{"d.1":1},"e":"example error","name":"test"}}`,
`{"map[string]any":{"a":"validation_message","b":"plain","c":"test_any","d":{"nestedA":"nested_message","nestedB":"nested_plain"}},"map[string]error":{"a":"validation_message","b":"plain"},"map[string]validation.Error":{"a":"validation_message","b":null},"plain_error":"plain","validation.Errors":{"a":"validation_message","b":"plain"},"wrapped_validation.Errors":{"data":{"a":"validation_message","b":"plain"},"raw":"wrapped: a: validation_message; b: plain."}}`,
}
if len(beforeLogs) != len(expected) {
t.Fatalf("Expected %d logs, got %d", len(expected), len(beforeLogs))
}
for i, data := range expected {
t.Run(fmt.Sprintf("log handler %d", i), func(t *testing.T) {
log := beforeLogs[i]
raw, _ := log.Data.MarshalJSON()
if string(raw) != data {
t.Fatalf("Expected \n%s \ngot \n%s", data, raw)
}
})
}
}
func checkLogMessages(expected []string, logs []*Log, t *testing.T) {
if len(logs) != len(expected) {
t.Fatalf("Expected %d batched logs, got %d (expected: %v)", len(expected), len(logs), expected)
}
for _, message := range expected {
exists := false
for _, l := range logs {
if l.Message == message {
exists = true
continue
}
}
if !exists {
t.Fatalf("Missing %q log message", message)
}
}
}

17
tools/logger/log.go Normal file
View file

@ -0,0 +1,17 @@
package logger
import (
"log/slog"
"time"
"github.com/pocketbase/pocketbase/tools/types"
)
// Log is similar to [slog.Record] bit contains the log attributes as
// preformatted JSON map.
type Log struct {
Time time.Time
Data types.JSONMap[any]
Message string
Level slog.Level
}

118
tools/mailer/html2text.go Normal file
View file

@ -0,0 +1,118 @@
package mailer
import (
"regexp"
"strings"
"github.com/pocketbase/pocketbase/tools/list"
"golang.org/x/net/html"
)
var whitespaceRegex = regexp.MustCompile(`\s+`)
var tagsToSkip = []string{
"style", "script", "iframe", "applet", "object", "svg", "img",
"button", "form", "textarea", "input", "select", "option", "template",
}
var inlineTags = []string{
"a", "span", "small", "strike", "strong",
"sub", "sup", "em", "b", "u", "i",
}
// Very rudimentary auto HTML to Text mail body converter.
//
// Caveats:
// - This method doesn't check for correctness of the HTML document.
// - Links will be converted to "[text](url)" format.
// - List items (<li>) are prefixed with "- ".
// - Indentation is stripped (both tabs and spaces).
// - Trailing spaces are preserved.
// - Multiple consequence newlines are collapsed as one unless multiple <br> tags are used.
func html2Text(htmlDocument string) (string, error) {
doc, err := html.Parse(strings.NewReader(htmlDocument))
if err != nil {
return "", err
}
var builder strings.Builder
var canAddNewLine bool
// see https://pkg.go.dev/golang.org/x/net/html#Parse
var f func(*html.Node, *strings.Builder)
f = func(n *html.Node, activeBuilder *strings.Builder) {
isLink := n.Type == html.ElementNode && n.Data == "a"
if isLink {
var linkBuilder strings.Builder
activeBuilder = &linkBuilder
} else if activeBuilder == nil {
activeBuilder = &builder
}
switch n.Type {
case html.TextNode:
txt := whitespaceRegex.ReplaceAllString(n.Data, " ")
// the prev node has new line so it is safe to trim the indentation
if !canAddNewLine {
txt = strings.TrimLeft(txt, " ")
}
if txt != "" {
activeBuilder.WriteString(txt)
canAddNewLine = true
}
case html.ElementNode:
if n.Data == "br" {
// always write new lines when <br> tag is used
activeBuilder.WriteString("\r\n")
canAddNewLine = false
} else if canAddNewLine && !list.ExistInSlice(n.Data, inlineTags) {
activeBuilder.WriteString("\r\n")
canAddNewLine = false
}
// prefix list items with dash
if n.Data == "li" {
activeBuilder.WriteString("- ")
}
}
for c := n.FirstChild; c != nil; c = c.NextSibling {
if c.Type != html.ElementNode || !list.ExistInSlice(c.Data, tagsToSkip) {
f(c, activeBuilder)
}
}
// format links as [label](href)
if isLink {
linkTxt := strings.TrimSpace(activeBuilder.String())
if linkTxt == "" {
linkTxt = "LINK"
}
builder.WriteString("[")
builder.WriteString(linkTxt)
builder.WriteString("]")
// link href attr extraction
for _, a := range n.Attr {
if a.Key == "href" {
if a.Val != "" {
builder.WriteString("(")
builder.WriteString(a.Val)
builder.WriteString(")")
}
break
}
}
activeBuilder.Reset()
}
}
f(doc, &builder)
return strings.TrimSpace(builder.String()), nil
}

View file

@ -0,0 +1,131 @@
package mailer
import (
"testing"
)
func TestHTML2Text(t *testing.T) {
scenarios := []struct {
html string
expected string
}{
{
"",
"",
},
{
"ab c",
"ab c",
},
{
"<!-- test html comment -->",
"",
},
{
"<!-- test html comment --> a ",
"a",
},
{
"<span>a</span>b<span>c</span>",
"abc",
},
{
`<a href="a/b/c">test</span>`,
"[test](a/b/c)",
},
{
`<a href="">test</span>`,
"[test]",
},
{
"<span>a</span> <span>b</span>",
"a b",
},
{
"<span>a</span> b <span>c</span>",
"a b c",
},
{
"<span>a</span> b <div>c</div>",
"a b \r\nc",
},
{
`
<!DOCTYPE html PUBLIC "-//W3C//DTD XHTML 1.0 Strict//EN" "http://www.w3.org/TR/xhtml1/DTD/xhtml1-strict.dtd">
<html xmlns="http://www.w3.org/1999/xhtml">
<head>
<meta http-equiv="Content-Type" content="text/html; charset=UTF-8" />
<meta name="viewport" content="width=device-width,initial-scale=1" />
<style>
body {
padding: 0;
}
</style>
</head>
<body>
<!-- test html comment -->
<style>
body {
padding: 0;
}
</style>
<div class="wrapper">
<div class="content">
<p>Lorem ipsum</p>
<p>Dolor sit amet</p>
<p>
<a href="a/b/c">Verify</a>
</p>
<br>
<p>
<a href="a/b/c"><strong>Verify2.1</strong> <strong>Verify2.2</strong></a>
</p>
<br>
<br>
<div>
<div>
<div>
<ul>
<li>ul.test1</li>
<li>ul.test2</li>
<li>ul.test3</li>
</ul>
<ol>
<li>ol.test1</li>
<li>ol.test2</li>
<li>ol.test3</li>
</ol>
</div>
</div>
</div>
<select>
<option>Option 1</option>
<option>Option 2</option>
</select>
<textarea>test</textarea>
<input type="text" value="test" />
<button>test</button>
<p>
Thanks,<br/>
PocketBase team
</p>
</div>
</div>
</body>
</html>
`,
"Lorem ipsum \r\nDolor sit amet \r\n[Verify](a/b/c) \r\n[Verify2.1 Verify2.2](a/b/c) \r\n\r\n- ul.test1 \r\n- ul.test2 \r\n- ul.test3 \r\n- ol.test1 \r\n- ol.test2 \r\n- ol.test3 \r\nThanks,\r\nPocketBase team",
},
}
for i, s := range scenarios {
result, err := html2Text(s.html)
if err != nil {
t.Errorf("(%d) Unexpected error %v", i, err)
}
if result != s.expected {
t.Errorf("(%d) Expected \n(%q)\n%v,\n\ngot:\n\n(%q)\n%v", i, s.expected, s.expected, result, result)
}
}
}

72
tools/mailer/mailer.go Normal file
View file

@ -0,0 +1,72 @@
package mailer
import (
"bytes"
"io"
"net/mail"
"github.com/gabriel-vasile/mimetype"
"github.com/pocketbase/pocketbase/tools/hook"
)
// Message defines a generic email message struct.
type Message struct {
From mail.Address `json:"from"`
To []mail.Address `json:"to"`
Bcc []mail.Address `json:"bcc"`
Cc []mail.Address `json:"cc"`
Subject string `json:"subject"`
HTML string `json:"html"`
Text string `json:"text"`
Headers map[string]string `json:"headers"`
Attachments map[string]io.Reader `json:"attachments"`
InlineAttachments map[string]io.Reader `json:"inlineAttachments"`
}
// Mailer defines a base mail client interface.
type Mailer interface {
// Send sends an email with the provided Message.
Send(message *Message) error
}
// SendInterceptor is optional interface for registering mail send hooks.
type SendInterceptor interface {
OnSend() *hook.Hook[*SendEvent]
}
type SendEvent struct {
hook.Event
Message *Message
}
// addressesToStrings converts the provided address to a list of serialized RFC 5322 strings.
//
// To export only the email part of mail.Address, you can set withName to false.
func addressesToStrings(addresses []mail.Address, withName bool) []string {
result := make([]string, len(addresses))
for i, addr := range addresses {
if withName && addr.Name != "" {
result[i] = addr.String()
} else {
// keep only the email part to avoid wrapping in angle-brackets
result[i] = addr.Address
}
}
return result
}
// detectReaderMimeType reads the first couple bytes of the reader to detect its MIME type.
//
// Returns a new combined reader from the partial read + the remaining of the original reader.
func detectReaderMimeType(r io.Reader) (io.Reader, string, error) {
readCopy := new(bytes.Buffer)
mime, err := mimetype.DetectReader(io.TeeReader(r, readCopy))
if err != nil {
return nil, "", err
}
return io.MultiReader(readCopy, r), mime.String(), nil
}

View file

@ -0,0 +1,77 @@
package mailer
import (
"fmt"
"io"
"net/mail"
"strings"
"testing"
)
func TestAddressesToStrings(t *testing.T) {
t.Parallel()
scenarios := []struct {
withName bool
addresses []mail.Address
expected []string
}{
{
true,
[]mail.Address{{Name: "John Doe", Address: "test1@example.com"}, {Name: "Jane Doe", Address: "test2@example.com"}},
[]string{`"John Doe" <test1@example.com>`, `"Jane Doe" <test2@example.com>`},
},
{
true,
[]mail.Address{{Name: "John Doe", Address: "test1@example.com"}, {Address: "test2@example.com"}},
[]string{`"John Doe" <test1@example.com>`, `test2@example.com`},
},
{
false,
[]mail.Address{{Name: "John Doe", Address: "test1@example.com"}, {Name: "Jane Doe", Address: "test2@example.com"}},
[]string{`test1@example.com`, `test2@example.com`},
},
}
for _, s := range scenarios {
t.Run(fmt.Sprintf("%v_%v", s.withName, s.addresses), func(t *testing.T) {
result := addressesToStrings(s.addresses, s.withName)
if len(s.expected) != len(result) {
t.Fatalf("Expected\n%v\ngot\n%v", s.expected, result)
}
for k, v := range s.expected {
if v != result[k] {
t.Fatalf("Expected %d address %q, got %q", k, v, result[k])
}
}
})
}
}
func TestDetectReaderMimeType(t *testing.T) {
t.Parallel()
str := "#!/bin/node\n" + strings.Repeat("a", 10000) // ensure that it is large enough to remain after the signature sniffing
r, mime, err := detectReaderMimeType(strings.NewReader(str))
if err != nil {
t.Fatal(err)
}
expectedMime := "text/javascript"
if mime != expectedMime {
t.Fatalf("Expected mime %q, got %q", expectedMime, mime)
}
raw, err := io.ReadAll(r)
if err != nil {
t.Fatal(err)
}
rawStr := string(raw)
if rawStr != str {
t.Fatalf("Expected content\n%s\ngot\n%s", str, rawStr)
}
}

99
tools/mailer/sendmail.go Normal file
View file

@ -0,0 +1,99 @@
package mailer
import (
"bytes"
"errors"
"mime"
"net/http"
"os/exec"
"strings"
"github.com/pocketbase/pocketbase/tools/hook"
)
var _ Mailer = (*Sendmail)(nil)
// Sendmail implements [mailer.Mailer] interface and defines a mail
// client that sends emails via the "sendmail" *nix command.
//
// This client is usually recommended only for development and testing.
type Sendmail struct {
onSend *hook.Hook[*SendEvent]
}
// OnSend implements [mailer.SendInterceptor] interface.
func (c *Sendmail) OnSend() *hook.Hook[*SendEvent] {
if c.onSend == nil {
c.onSend = &hook.Hook[*SendEvent]{}
}
return c.onSend
}
// Send implements [mailer.Mailer] interface.
func (c *Sendmail) Send(m *Message) error {
if c.onSend != nil {
return c.onSend.Trigger(&SendEvent{Message: m}, func(e *SendEvent) error {
return c.send(e.Message)
})
}
return c.send(m)
}
func (c *Sendmail) send(m *Message) error {
toAddresses := addressesToStrings(m.To, false)
headers := make(http.Header)
headers.Set("Subject", mime.QEncoding.Encode("utf-8", m.Subject))
headers.Set("From", m.From.String())
headers.Set("Content-Type", "text/html; charset=UTF-8")
headers.Set("To", strings.Join(toAddresses, ","))
cmdPath, err := findSendmailPath()
if err != nil {
return err
}
var buffer bytes.Buffer
// write
// ---
if err := headers.Write(&buffer); err != nil {
return err
}
if _, err := buffer.Write([]byte("\r\n")); err != nil {
return err
}
if m.HTML != "" {
if _, err := buffer.Write([]byte(m.HTML)); err != nil {
return err
}
} else {
if _, err := buffer.Write([]byte(m.Text)); err != nil {
return err
}
}
// ---
sendmail := exec.Command(cmdPath, strings.Join(toAddresses, ","))
sendmail.Stdin = &buffer
return sendmail.Run()
}
func findSendmailPath() (string, error) {
options := []string{
"/usr/sbin/sendmail",
"/usr/bin/sendmail",
"sendmail",
}
for _, option := range options {
path, err := exec.LookPath(option)
if err == nil {
return path, err
}
}
return "", errors.New("failed to locate a sendmail executable path")
}

211
tools/mailer/smtp.go Normal file
View file

@ -0,0 +1,211 @@
package mailer
import (
"errors"
"fmt"
"net/smtp"
"strings"
"github.com/domodwyer/mailyak/v3"
"github.com/pocketbase/pocketbase/tools/hook"
"github.com/pocketbase/pocketbase/tools/security"
)
var _ Mailer = (*SMTPClient)(nil)
const (
SMTPAuthPlain = "PLAIN"
SMTPAuthLogin = "LOGIN"
)
// SMTPClient defines a SMTP mail client structure that implements
// `mailer.Mailer` interface.
type SMTPClient struct {
onSend *hook.Hook[*SendEvent]
TLS bool
Port int
Host string
Username string
Password string
// SMTP auth method to use
// (if not explicitly set, defaults to "PLAIN")
AuthMethod string
// LocalName is optional domain name used for the EHLO/HELO exchange
// (if not explicitly set, defaults to "localhost").
//
// This is required only by some SMTP servers, such as Gmail SMTP-relay.
LocalName string
}
// OnSend implements [mailer.SendInterceptor] interface.
func (c *SMTPClient) OnSend() *hook.Hook[*SendEvent] {
if c.onSend == nil {
c.onSend = &hook.Hook[*SendEvent]{}
}
return c.onSend
}
// Send implements [mailer.Mailer] interface.
func (c *SMTPClient) Send(m *Message) error {
if c.onSend != nil {
return c.onSend.Trigger(&SendEvent{Message: m}, func(e *SendEvent) error {
return c.send(e.Message)
})
}
return c.send(m)
}
func (c *SMTPClient) send(m *Message) error {
var smtpAuth smtp.Auth
if c.Username != "" || c.Password != "" {
switch c.AuthMethod {
case SMTPAuthLogin:
smtpAuth = &smtpLoginAuth{c.Username, c.Password}
default:
smtpAuth = smtp.PlainAuth("", c.Username, c.Password, c.Host)
}
}
// create mail instance
var yak *mailyak.MailYak
if c.TLS {
var tlsErr error
yak, tlsErr = mailyak.NewWithTLS(fmt.Sprintf("%s:%d", c.Host, c.Port), smtpAuth, nil)
if tlsErr != nil {
return tlsErr
}
} else {
yak = mailyak.New(fmt.Sprintf("%s:%d", c.Host, c.Port), smtpAuth)
}
if c.LocalName != "" {
yak.LocalName(c.LocalName)
}
if m.From.Name != "" {
yak.FromName(m.From.Name)
}
yak.From(m.From.Address)
yak.Subject(m.Subject)
yak.HTML().Set(m.HTML)
if m.Text == "" {
// try to generate a plain text version of the HTML
if plain, err := html2Text(m.HTML); err == nil {
yak.Plain().Set(plain)
}
} else {
yak.Plain().Set(m.Text)
}
if len(m.To) > 0 {
yak.To(addressesToStrings(m.To, true)...)
}
if len(m.Bcc) > 0 {
yak.Bcc(addressesToStrings(m.Bcc, true)...)
}
if len(m.Cc) > 0 {
yak.Cc(addressesToStrings(m.Cc, true)...)
}
// add regular attachements (if any)
for name, data := range m.Attachments {
r, mime, err := detectReaderMimeType(data)
if err != nil {
return err
}
yak.AttachWithMimeType(name, r, mime)
}
// add inline attachments (if any)
for name, data := range m.InlineAttachments {
r, mime, err := detectReaderMimeType(data)
if err != nil {
return err
}
yak.AttachInlineWithMimeType(name, r, mime)
}
// add custom headers (if any)
var hasMessageId bool
for k, v := range m.Headers {
if strings.EqualFold(k, "Message-ID") {
hasMessageId = true
}
yak.AddHeader(k, v)
}
if !hasMessageId {
// add a default message id if missing
fromParts := strings.Split(m.From.Address, "@")
if len(fromParts) == 2 {
yak.AddHeader("Message-ID", fmt.Sprintf("<%s@%s>",
security.PseudorandomString(15),
fromParts[1],
))
}
}
return yak.Send()
}
// -------------------------------------------------------------------
// AUTH LOGIN
// -------------------------------------------------------------------
var _ smtp.Auth = (*smtpLoginAuth)(nil)
// smtpLoginAuth defines an AUTH that implements the LOGIN authentication mechanism.
//
// AUTH LOGIN is obsolete[1] but some mail services like outlook requires it [2].
//
// NB!
// It will only send the credentials if the connection is using TLS or is connected to localhost.
// Otherwise authentication will fail with an error, without sending the credentials.
//
// [1]: https://github.com/golang/go/issues/40817
// [2]: https://support.microsoft.com/en-us/office/outlook-com-no-longer-supports-auth-plain-authentication-07f7d5e9-1697-465f-84d2-4513d4ff0145?ui=en-us&rs=en-us&ad=us
type smtpLoginAuth struct {
username, password string
}
// Start initializes an authentication with the server.
//
// It is part of the [smtp.Auth] interface.
func (a *smtpLoginAuth) Start(server *smtp.ServerInfo) (string, []byte, error) {
// Must have TLS, or else localhost server.
// Note: If TLS is not true, then we can't trust ANYTHING in ServerInfo.
// In particular, it doesn't matter if the server advertises LOGIN auth.
// That might just be the attacker saying
// "it's ok, you can trust me with your password."
if !server.TLS && !isLocalhost(server.Name) {
return "", nil, errors.New("unencrypted connection")
}
return "LOGIN", nil, nil
}
// Next "continues" the auth process by feeding the server with the requested data.
//
// It is part of the [smtp.Auth] interface.
func (a *smtpLoginAuth) Next(fromServer []byte, more bool) ([]byte, error) {
if more {
switch strings.ToLower(string(fromServer)) {
case "username:":
return []byte(a.username), nil
case "password:":
return []byte(a.password), nil
}
}
return nil, nil
}
func isLocalhost(name string) bool {
return name == "localhost" || name == "127.0.0.1" || name == "::1"
}

166
tools/mailer/smtp_test.go Normal file
View file

@ -0,0 +1,166 @@
package mailer
import (
"net/smtp"
"testing"
)
func TestLoginAuthStart(t *testing.T) {
auth := smtpLoginAuth{username: "test", password: "123456"}
scenarios := []struct {
name string
serverInfo *smtp.ServerInfo
expectError bool
}{
{
"localhost without tls",
&smtp.ServerInfo{TLS: false, Name: "localhost"},
false,
},
{
"localhost with tls",
&smtp.ServerInfo{TLS: true, Name: "localhost"},
false,
},
{
"127.0.0.1 without tls",
&smtp.ServerInfo{TLS: false, Name: "127.0.0.1"},
false,
},
{
"127.0.0.1 with tls",
&smtp.ServerInfo{TLS: false, Name: "127.0.0.1"},
false,
},
{
"::1 without tls",
&smtp.ServerInfo{TLS: false, Name: "::1"},
false,
},
{
"::1 with tls",
&smtp.ServerInfo{TLS: false, Name: "::1"},
false,
},
{
"non-localhost without tls",
&smtp.ServerInfo{TLS: false, Name: "example.com"},
true,
},
{
"non-localhost with tls",
&smtp.ServerInfo{TLS: true, Name: "example.com"},
false,
},
}
for _, s := range scenarios {
t.Run(s.name, func(t *testing.T) {
method, resp, err := auth.Start(s.serverInfo)
hasErr := err != nil
if hasErr != s.expectError {
t.Fatalf("Expected hasErr %v, got %v", s.expectError, hasErr)
}
if hasErr {
return
}
if len(resp) != 0 {
t.Fatalf("Expected empty data response, got %v", resp)
}
if method != "LOGIN" {
t.Fatalf("Expected LOGIN, got %v", method)
}
})
}
}
func TestLoginAuthNext(t *testing.T) {
auth := smtpLoginAuth{username: "test", password: "123456"}
{
// example|false
r1, err := auth.Next([]byte("example:"), false)
if err != nil {
t.Fatalf("[example|false] Unexpected error %v", err)
}
if len(r1) != 0 {
t.Fatalf("[example|false] Expected empty part, got %v", r1)
}
// example|true
r2, err := auth.Next([]byte("example:"), true)
if err != nil {
t.Fatalf("[example|true] Unexpected error %v", err)
}
if len(r2) != 0 {
t.Fatalf("[example|true] Expected empty part, got %v", r2)
}
}
// ---------------------------------------------------------------
{
// username:|false
r1, err := auth.Next([]byte("username:"), false)
if err != nil {
t.Fatalf("[username|false] Unexpected error %v", err)
}
if len(r1) != 0 {
t.Fatalf("[username|false] Expected empty part, got %v", r1)
}
// username:|true
r2, err := auth.Next([]byte("username:"), true)
if err != nil {
t.Fatalf("[username|true] Unexpected error %v", err)
}
if str := string(r2); str != auth.username {
t.Fatalf("[username|true] Expected %s, got %s", auth.username, str)
}
// uSeRnAmE:|true
r3, err := auth.Next([]byte("uSeRnAmE:"), true)
if err != nil {
t.Fatalf("[uSeRnAmE|true] Unexpected error %v", err)
}
if str := string(r3); str != auth.username {
t.Fatalf("[uSeRnAmE|true] Expected %s, got %s", auth.username, str)
}
}
// ---------------------------------------------------------------
{
// password:|false
r1, err := auth.Next([]byte("password:"), false)
if err != nil {
t.Fatalf("[password|false] Unexpected error %v", err)
}
if len(r1) != 0 {
t.Fatalf("[password|false] Expected empty part, got %v", r1)
}
// password:|true
r2, err := auth.Next([]byte("password:"), true)
if err != nil {
t.Fatalf("[password|true] Unexpected error %v", err)
}
if str := string(r2); str != auth.password {
t.Fatalf("[password|true] Expected %s, got %s", auth.password, str)
}
// pAsSwOrD:|true
r3, err := auth.Next([]byte("pAsSwOrD:"), true)
if err != nil {
t.Fatalf("[pAsSwOrD|true] Unexpected error %v", err)
}
if str := string(r3); str != auth.password {
t.Fatalf("[pAsSwOrD|true] Expected %s, got %s", auth.password, str)
}
}
}

Some files were not shown because too many files have changed in this diff Show more