Adding upstream version 0.28.1.
Signed-off-by: Daniel Baumann <daniel@debian.org>
parent 88f1d47ab6
commit e28c88ef14
933 changed files with 194711 additions and 0 deletions
tools/archive/create.go (new file, 91 lines)
@@ -0,0 +1,91 @@
package archive

import (
    "archive/zip"
    "compress/flate"
    "errors"
    "io"
    "io/fs"
    "os"
    "path/filepath"
    "strings"
)

// Create creates a new zip archive from src dir content and saves it in dest path.
//
// You can specify skipPaths to skip/ignore certain directories and files
// (relative to src), preventing them from being added to the final archive.
func Create(src string, dest string, skipPaths ...string) error {
    if err := os.MkdirAll(filepath.Dir(dest), os.ModePerm); err != nil {
        return err
    }

    zf, err := os.Create(dest)
    if err != nil {
        return err
    }

    zw := zip.NewWriter(zf)

    // register a custom Deflate compressor
    zw.RegisterCompressor(zip.Deflate, func(out io.Writer) (io.WriteCloser, error) {
        return flate.NewWriter(out, flate.BestSpeed)
    })

    err = zipAddFS(zw, os.DirFS(src), skipPaths...)
    if err != nil {
        // try to cleanup at least the created zip file
        return errors.Join(err, zw.Close(), zf.Close(), os.Remove(dest))
    }

    return errors.Join(zw.Close(), zf.Close())
}

// note: remove after a similar method is added to the std lib (https://github.com/golang/go/issues/54898)
func zipAddFS(w *zip.Writer, fsys fs.FS, skipPaths ...string) error {
    return fs.WalkDir(fsys, ".", func(name string, d fs.DirEntry, err error) error {
        if err != nil {
            return err
        }

        if d.IsDir() {
            return nil
        }

        // skip ignored paths
        for _, ignore := range skipPaths {
            if ignore == name ||
                strings.HasPrefix(filepath.Clean(name)+string(os.PathSeparator), filepath.Clean(ignore)+string(os.PathSeparator)) {
                return nil
            }
        }

        info, err := d.Info()
        if err != nil {
            return err
        }

        h, err := zip.FileInfoHeader(info)
        if err != nil {
            return err
        }

        h.Name = name
        h.Method = zip.Deflate

        fw, err := w.CreateHeader(h)
        if err != nil {
            return err
        }

        f, err := fsys.Open(name)
        if err != nil {
            return err
        }
        defer f.Close()

        _, err = io.Copy(fw, f)

        return err
    })
}
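The skipPaths check above is prefix-based on cleaned paths: an entry matches either the exact relative name or anything nested under a directory with that name, never a partial name. A minimal sketch of the rule in isolation (skipMatch is a hypothetical helper, not part of the package):

package archive

import (
    "os"
    "path/filepath"
    "strings"
)

// skipMatch mirrors the skipPaths comparison in zipAddFS (hypothetical helper).
func skipMatch(ignore, name string) bool {
    sep := string(os.PathSeparator)
    return ignore == name ||
        strings.HasPrefix(filepath.Clean(name)+sep, filepath.Clean(ignore)+sep)
}

// skipMatch("a/b/c", "a/b/c/sub2") == true  (nested under the ignored dir)
// skipMatch("test", "test")        == true  (exact match)
// skipMatch("test", "test2")       == false (partial names don't match)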
tools/archive/create_test.go (new file, 125 lines)
@@ -0,0 +1,125 @@
package archive_test

import (
    "os"
    "path/filepath"
    "testing"

    "github.com/pocketbase/pocketbase/tools/archive"
)

func TestCreateFailure(t *testing.T) {
    testDir := createTestDir(t)
    defer os.RemoveAll(testDir)

    zipPath := filepath.Join(os.TempDir(), "pb_test.zip")
    defer os.RemoveAll(zipPath)

    missingDir := filepath.Join(os.TempDir(), "missing")

    if err := archive.Create(missingDir, zipPath); err == nil {
        t.Fatal("Expected to fail due to missing directory or file")
    }

    if _, err := os.Stat(zipPath); err == nil {
        t.Fatalf("Expected the zip file not to be created")
    }
}

func TestCreateSuccess(t *testing.T) {
    testDir := createTestDir(t)
    defer os.RemoveAll(testDir)

    zipName := "pb_test.zip"
    zipPath := filepath.Join(os.TempDir(), zipName)
    defer os.RemoveAll(zipPath)

    // zip testDir content (excluding the "test" file and the "a/b/c" dir)
    if err := archive.Create(testDir, zipPath, "a/b/c", "test"); err != nil {
        t.Fatalf("Failed to create archive: %v", err)
    }

    info, err := os.Stat(zipPath)
    if err != nil {
        t.Fatalf("Failed to retrieve the generated zip file: %v", err)
    }

    if name := info.Name(); name != zipName {
        t.Fatalf("Expected zip with name %q, got %q", zipName, name)
    }

    expectedSize := int64(544)
    if size := info.Size(); size != expectedSize {
        t.Fatalf("Expected zip with size %d, got %d", expectedSize, size)
    }
}

// -------------------------------------------------------------------

// note: make sure to call os.RemoveAll(dir) after you are done
// working with the created test dir.
func createTestDir(t *testing.T) string {
    dir, err := os.MkdirTemp(os.TempDir(), "pb_zip_test")
    if err != nil {
        t.Fatal(err)
    }

    if err := os.MkdirAll(filepath.Join(dir, "a/b/c"), os.ModePerm); err != nil {
        t.Fatal(err)
    }

    {
        f, err := os.OpenFile(filepath.Join(dir, "test"), os.O_WRONLY|os.O_CREATE, 0644)
        if err != nil {
            t.Fatal(err)
        }
        f.Close()
    }

    {
        f, err := os.OpenFile(filepath.Join(dir, "test2"), os.O_WRONLY|os.O_CREATE, 0644)
        if err != nil {
            t.Fatal(err)
        }
        f.Close()
    }

    {
        f, err := os.OpenFile(filepath.Join(dir, "a/test"), os.O_WRONLY|os.O_CREATE, 0644)
        if err != nil {
            t.Fatal(err)
        }
        f.Close()
    }

    {
        f, err := os.OpenFile(filepath.Join(dir, "a/b/sub1"), os.O_WRONLY|os.O_CREATE, 0644)
        if err != nil {
            t.Fatal(err)
        }
        f.Close()
    }

    {
        f, err := os.OpenFile(filepath.Join(dir, "a/b/c/sub2"), os.O_WRONLY|os.O_CREATE, 0644)
        if err != nil {
            t.Fatal(err)
        }
        f.Close()
    }

    {
        f, err := os.OpenFile(filepath.Join(dir, "a/b/c/sub3"), os.O_WRONLY|os.O_CREATE, 0644)
        if err != nil {
            t.Fatal(err)
        }
        f.Close()
    }

    // symbolic link
    if err := os.Symlink(filepath.Join(dir, "test"), filepath.Join(dir, "test_symlink")); err != nil {
        t.Fatal(err)
    }

    return dir
}
tools/archive/extract.go (new file, 77 lines)
@@ -0,0 +1,77 @@
package archive

import (
    "archive/zip"
    "fmt"
    "io"
    "os"
    "path/filepath"
    "strings"
)

// Extract extracts the zip archive at "src" to "dest".
//
// Note that only dirs and regular files will be extracted.
// Symbolic links, named pipes, sockets, or any other irregular files
// are skipped because they come with too many edge cases and ambiguities.
func Extract(src, dest string) error {
    zr, err := zip.OpenReader(src)
    if err != nil {
        return err
    }
    defer zr.Close()

    // normalize dest path to check later for Zip Slip
    dest = filepath.Clean(dest) + string(os.PathSeparator)

    for _, f := range zr.File {
        err := extractFile(f, dest)
        if err != nil {
            return err
        }
    }

    return nil
}

// extractFile extracts the provided zipFile into "basePath/zipFileName" path,
// creating all the necessary path directories.
func extractFile(zipFile *zip.File, basePath string) error {
    path := filepath.Join(basePath, zipFile.Name)

    // check for Zip Slip
    if !strings.HasPrefix(path, basePath) {
        return fmt.Errorf("invalid file path: %s", path)
    }

    r, err := zipFile.Open()
    if err != nil {
        return err
    }
    defer r.Close()

    // allow only dirs or regular files
    if zipFile.FileInfo().IsDir() {
        if err := os.MkdirAll(path, os.ModePerm); err != nil {
            return err
        }
    } else if zipFile.FileInfo().Mode().IsRegular() {
        // ensure that the file path directories are created
        if err := os.MkdirAll(filepath.Dir(path), os.ModePerm); err != nil {
            return err
        }

        f, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, zipFile.Mode())
        if err != nil {
            return err
        }
        defer f.Close()

        _, err = io.Copy(f, r)
        if err != nil {
            return err
        }
    }

    return nil
}
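Together, Create and Extract form a simple zip round trip. A minimal usage sketch, assuming hypothetical ./pb_data and ./backup.zip paths:

package main

import (
    "log"

    "github.com/pocketbase/pocketbase/tools/archive"
)

func main() {
    // zip ./pb_data into ./backup.zip, skipping its local "backups" subdir
    if err := archive.Create("./pb_data", "./backup.zip", "backups"); err != nil {
        log.Fatal(err)
    }

    // extract the archive into a fresh directory
    if err := archive.Extract("./backup.zip", "./pb_data_restored"); err != nil {
        log.Fatal(err)
    }
}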
tools/archive/extract_test.go (new file, 88 lines)
@@ -0,0 +1,88 @@
package archive_test

import (
    "io/fs"
    "os"
    "path/filepath"
    "testing"

    "github.com/pocketbase/pocketbase/tools/archive"
)

func TestExtractFailure(t *testing.T) {
    testDir := createTestDir(t)
    defer os.RemoveAll(testDir)

    missingZipPath := filepath.Join(os.TempDir(), "pb_missing_test.zip")
    extractedPath := filepath.Join(os.TempDir(), "pb_zip_extract")
    defer os.RemoveAll(extractedPath)

    if err := archive.Extract(missingZipPath, extractedPath); err == nil {
        t.Fatal("Expected Extract to fail due to missing zipPath")
    }

    if _, err := os.Stat(extractedPath); err == nil {
        t.Fatalf("Expected %q to not be created", extractedPath)
    }
}

func TestExtractSuccess(t *testing.T) {
    testDir := createTestDir(t)
    defer os.RemoveAll(testDir)

    zipPath := filepath.Join(os.TempDir(), "pb_test.zip")
    defer os.RemoveAll(zipPath)

    extractedPath := filepath.Join(os.TempDir(), "pb_zip_extract")
    defer os.RemoveAll(extractedPath)

    // zip testDir content (with exclude)
    if err := archive.Create(testDir, zipPath, "a/b/c", "test2", "sub2"); err != nil {
        t.Fatalf("Failed to create archive: %v", err)
    }

    if err := archive.Extract(zipPath, extractedPath); err != nil {
        t.Fatalf("Failed to extract %q in %q", zipPath, extractedPath)
    }

    availableFiles := []string{}

    walkErr := filepath.WalkDir(extractedPath, func(path string, d fs.DirEntry, err error) error {
        if err != nil {
            return err
        }

        if d.IsDir() {
            return nil
        }

        availableFiles = append(availableFiles, path)

        return nil
    })
    if walkErr != nil {
        t.Fatalf("Failed to read the extracted dir: %v", walkErr)
    }

    // (note: symbolic links and other irregular files should be missing)
    expectedFiles := []string{
        filepath.Join(extractedPath, "test"),
        filepath.Join(extractedPath, "a/test"),
        filepath.Join(extractedPath, "a/b/sub1"),
    }

    if len(availableFiles) != len(expectedFiles) {
        t.Fatalf("Expected \n%v, \ngot \n%v", expectedFiles, availableFiles)
    }

ExpectedLoop:
    for _, expected := range expectedFiles {
        for _, available := range availableFiles {
            if available == expected {
                continue ExpectedLoop
            }
        }

        t.Fatalf("Missing file %q in \n%v", expected, availableFiles)
    }
}
tools/auth/apple.go (new file, 166 lines)
@@ -0,0 +1,166 @@
package auth

import (
    "context"
    "encoding/json"
    "errors"
    "strings"

    "github.com/golang-jwt/jwt/v5"
    "github.com/pocketbase/pocketbase/tools/types"
    "github.com/spf13/cast"
    "golang.org/x/oauth2"
)

func init() {
    Providers[NameApple] = wrapFactory(NewAppleProvider)
}

var _ Provider = (*Apple)(nil)

// NameApple is the unique name of the Apple provider.
const NameApple string = "apple"

// Apple allows authentication via Apple OAuth2.
//
// [OIDC differences]: https://bitbucket.org/openid/connect/src/master/How-Sign-in-with-Apple-differs-from-OpenID-Connect.md
type Apple struct {
    BaseProvider

    jwksURL string
}

// NewAppleProvider creates a new Apple provider instance with some defaults.
func NewAppleProvider() *Apple {
    return &Apple{
        BaseProvider: BaseProvider{
            ctx:         context.Background(),
            displayName: "Apple",
            pkce:        true,
            scopes:      []string{"name", "email"},
            authURL:     "https://appleid.apple.com/auth/authorize",
            tokenURL:    "https://appleid.apple.com/auth/token",
        },
        jwksURL: "https://appleid.apple.com/auth/keys",
    }
}

// FetchAuthUser returns an AuthUser instance based on the provided token.
//
// API reference: https://developer.apple.com/documentation/sign_in_with_apple/tokenresponse
func (p *Apple) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
    data, err := p.FetchRawUserInfo(token)
    if err != nil {
        return nil, err
    }

    rawUser := map[string]any{}
    if err := json.Unmarshal(data, &rawUser); err != nil {
        return nil, err
    }

    extracted := struct {
        Id            string `json:"sub"`
        Name          string `json:"name"`
        Email         string `json:"email"`
        EmailVerified any    `json:"email_verified"` // could be string or bool
        User          struct {
            Name struct {
                FirstName string `json:"firstName"`
                LastName  string `json:"lastName"`
            } `json:"name"`
        } `json:"user"`
    }{}
    if err := json.Unmarshal(data, &extracted); err != nil {
        return nil, err
    }

    user := &AuthUser{
        Id:           extracted.Id,
        Name:         extracted.Name,
        RawUser:      rawUser,
        AccessToken:  token.AccessToken,
        RefreshToken: token.RefreshToken,
    }

    user.Expiry, _ = types.ParseDateTime(token.Expiry)

    if cast.ToBool(extracted.EmailVerified) {
        user.Email = extracted.Email
    }

    if user.Name == "" {
        user.Name = strings.TrimSpace(extracted.User.Name.FirstName + " " + extracted.User.Name.LastName)
    }

    return user, nil
}

// FetchRawUserInfo implements Provider.FetchRawUserInfo interface.
//
// Apple doesn't have a UserInfo endpoint and claims about users
// are instead included in the "id_token" (https://openid.net/specs/openid-connect-core-1_0.html#id_tokenExample).
func (p *Apple) FetchRawUserInfo(token *oauth2.Token) ([]byte, error) {
    idToken, _ := token.Extra("id_token").(string)

    claims, err := p.parseAndVerifyIdToken(idToken)
    if err != nil {
        return nil, err
    }

    // Apple only returns the user object the first time the user authorizes the app
    // https://developer.apple.com/documentation/sign_in_with_apple/sign_in_with_apple_js/configuring_your_webpage_for_sign_in_with_apple#3331292
    rawUser, _ := token.Extra("user").(string)
    if rawUser != "" {
        user := map[string]any{}
        err = json.Unmarshal([]byte(rawUser), &user)
        if err != nil {
            return nil, err
        }
        claims["user"] = user
    }

    return json.Marshal(claims)
}

func (p *Apple) parseAndVerifyIdToken(idToken string) (jwt.MapClaims, error) {
    if idToken == "" {
        return nil, errors.New("empty id_token")
    }

    // extract the token header params and claims
    // ---
    claims := jwt.MapClaims{}
    t, _, err := jwt.NewParser().ParseUnverified(idToken, claims)
    if err != nil {
        return nil, err
    }

    // validate common claims per https://developer.apple.com/documentation/sign_in_with_apple/sign_in_with_apple_rest_api/verifying_a_user#3383769
    // ---
    jwtValidator := jwt.NewValidator(
        jwt.WithExpirationRequired(),
        jwt.WithIssuedAt(),
        jwt.WithLeeway(idTokenLeeway),
        jwt.WithIssuer("https://appleid.apple.com"),
        jwt.WithAudience(p.clientId),
    )
    err = jwtValidator.Validate(claims)
    if err != nil {
        return nil, err
    }

    // validate id_token signature
    //
    // note: this step could be technically considered optional because we trust
    // the token which is a result of direct TLS communication with the provider
    // (see also https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation)
    // ---
    kid, _ := t.Header["kid"].(string)
    err = validateIdTokenSignature(p.ctx, idToken, p.jwksURL, kid)
    if err != nil {
        return nil, err
    }

    return claims, nil
}
tools/auth/auth.go (new file, 157 lines)
@@ -0,0 +1,157 @@
package auth

import (
    "context"
    "encoding/json"
    "errors"
    "net/http"

    "github.com/pocketbase/pocketbase/tools/types"
    "golang.org/x/oauth2"
)

// ProviderFactoryFunc defines a function for initializing a new OAuth2 provider.
type ProviderFactoryFunc func() Provider

// Providers defines a map with all of the available OAuth2 providers.
//
// To register a new provider append a new entry in the map.
var Providers = map[string]ProviderFactoryFunc{}

// NewProviderByName returns a new preconfigured provider instance by its name identifier.
func NewProviderByName(name string) (Provider, error) {
    factory, ok := Providers[name]
    if !ok {
        return nil, errors.New("missing provider " + name)
    }

    return factory(), nil
}

// Provider defines a common interface for an OAuth2 client.
type Provider interface {
    // Context returns the context associated with the provider (if any).
    Context() context.Context

    // SetContext assigns the specified context to the current provider.
    SetContext(ctx context.Context)

    // PKCE indicates whether the provider can use the PKCE flow.
    PKCE() bool

    // SetPKCE toggles the state whether the provider can use the PKCE flow or not.
    SetPKCE(enable bool)

    // DisplayName usually returns the provider name as it is officially written
    // and it could be used directly in the UI.
    DisplayName() string

    // SetDisplayName sets the provider's display name.
    SetDisplayName(displayName string)

    // Scopes returns the provider access permissions that will be requested.
    Scopes() []string

    // SetScopes sets the provider access permissions that will be requested later.
    SetScopes(scopes []string)

    // ClientId returns the provider client's app ID.
    ClientId() string

    // SetClientId sets the provider client's ID.
    SetClientId(clientId string)

    // ClientSecret returns the provider client's app secret.
    ClientSecret() string

    // SetClientSecret sets the provider client's app secret.
    SetClientSecret(secret string)

    // RedirectURL returns the end address to redirect the user
    // going through the OAuth flow.
    RedirectURL() string

    // SetRedirectURL sets the provider's RedirectURL.
    SetRedirectURL(url string)

    // AuthURL returns the provider's authorization service url.
    AuthURL() string

    // SetAuthURL sets the provider's AuthURL.
    SetAuthURL(url string)

    // TokenURL returns the provider's token exchange service url.
    TokenURL() string

    // SetTokenURL sets the provider's TokenURL.
    SetTokenURL(url string)

    // UserInfoURL returns the provider's user info api url.
    UserInfoURL() string

    // SetUserInfoURL sets the provider's UserInfoURL.
    SetUserInfoURL(url string)

    // Extra returns a shallow copy of any custom config data
    // that the provider may need.
    Extra() map[string]any

    // SetExtra updates the provider's custom config data.
    SetExtra(data map[string]any)

    // Client returns an http client using the provided token.
    Client(token *oauth2.Token) *http.Client

    // BuildAuthURL returns a URL to the provider's consent page
    // that asks for permissions for the required scopes explicitly.
    BuildAuthURL(state string, opts ...oauth2.AuthCodeOption) string

    // FetchToken converts an authorization code to a token.
    FetchToken(code string, opts ...oauth2.AuthCodeOption) (*oauth2.Token, error)

    // FetchRawUserInfo requests and returns the raw OAuth user api response.
    FetchRawUserInfo(token *oauth2.Token) ([]byte, error)

    // FetchAuthUser is similar to FetchRawUserInfo, but normalizes and
    // marshalizes the user api response into a standardized AuthUser struct.
    FetchAuthUser(token *oauth2.Token) (user *AuthUser, err error)
}

// wrapFactory is a helper that wraps a Provider specific factory
// function and returns its result as a Provider interface.
func wrapFactory[T Provider](factory func() T) ProviderFactoryFunc {
    return func() Provider {
        return factory()
    }
}

// AuthUser defines a standardized OAuth2 user data structure.
type AuthUser struct {
    Expiry       types.DateTime `json:"expiry"`
    RawUser      map[string]any `json:"rawUser"`
    Id           string         `json:"id"`
    Name         string         `json:"name"`
    Username     string         `json:"username"`
    Email        string         `json:"email"`
    AvatarURL    string         `json:"avatarURL"`
    AccessToken  string         `json:"accessToken"`
    RefreshToken string         `json:"refreshToken"`

    // @todo
    // Deprecated: use AvatarURL instead.
    // AvatarUrl will be removed after dropping v0.22 support.
    AvatarUrl string `json:"avatarUrl"`
}

// MarshalJSON implements the [json.Marshaler] interface.
//
// @todo remove after dropping v0.22 support
func (au AuthUser) MarshalJSON() ([]byte, error) {
    type alias AuthUser // prevent recursion

    au2 := alias(au)
    au2.AvatarUrl = au.AvatarURL // ensure that the legacy field is populated

    return json.Marshal(au2)
}
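The Provider interface above covers the whole OAuth2 lifecycle, from building the consent URL to normalizing the user profile. A minimal sketch of that flow, with placeholder credentials, redirect URL, and authorization code:

package main

import (
    "log"

    "github.com/pocketbase/pocketbase/tools/auth"
)

func main() {
    provider, err := auth.NewProviderByName(auth.NameGitea)
    if err != nil {
        log.Fatal(err)
    }

    provider.SetClientId("CLIENT_ID")         // placeholder
    provider.SetClientSecret("CLIENT_SECRET") // placeholder
    provider.SetRedirectURL("https://example.com/oauth2-redirect")

    // 1. send the user to the provider's consent page
    log.Println("visit:", provider.BuildAuthURL("random-state"))

    // 2. exchange the code received on the redirect for a token
    token, err := provider.FetchToken("AUTHORIZATION_CODE") // placeholder
    if err != nil {
        log.Fatal(err)
    }

    // 3. normalize the user api response into an AuthUser
    user, err := provider.FetchAuthUser(token)
    if err != nil {
        log.Fatal(err)
    }

    log.Println(user.Id, user.Email)
}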
tools/auth/auth_test.go (new file, 299 lines)
@@ -0,0 +1,299 @@
package auth_test

import (
    "testing"

    "github.com/pocketbase/pocketbase/tools/auth"
)

func TestProvidersCount(t *testing.T) {
    expected := 30

    if total := len(auth.Providers); total != expected {
        t.Fatalf("Expected %d providers, got %d", expected, total)
    }
}

func TestNewProviderByName(t *testing.T) {
    var err error
    var p auth.Provider

    // invalid
    p, err = auth.NewProviderByName("invalid")
    if err == nil {
        t.Error("Expected error, got nil")
    }
    if p != nil {
        t.Errorf("Expected provider to be nil, got %v", p)
    }

    // google
    p, err = auth.NewProviderByName(auth.NameGoogle)
    if err != nil {
        t.Errorf("Expected nil, got error %v", err)
    }
    if _, ok := p.(*auth.Google); !ok {
        t.Error("Expected to be instance of *auth.Google")
    }

    // facebook
    p, err = auth.NewProviderByName(auth.NameFacebook)
    if err != nil {
        t.Errorf("Expected nil, got error %v", err)
    }
    if _, ok := p.(*auth.Facebook); !ok {
        t.Error("Expected to be instance of *auth.Facebook")
    }

    // github
    p, err = auth.NewProviderByName(auth.NameGithub)
    if err != nil {
        t.Errorf("Expected nil, got error %v", err)
    }
    if _, ok := p.(*auth.Github); !ok {
        t.Error("Expected to be instance of *auth.Github")
    }

    // gitlab
    p, err = auth.NewProviderByName(auth.NameGitlab)
    if err != nil {
        t.Errorf("Expected nil, got error %v", err)
    }
    if _, ok := p.(*auth.Gitlab); !ok {
        t.Error("Expected to be instance of *auth.Gitlab")
    }

    // twitter
    p, err = auth.NewProviderByName(auth.NameTwitter)
    if err != nil {
        t.Errorf("Expected nil, got error %v", err)
    }
    if _, ok := p.(*auth.Twitter); !ok {
        t.Error("Expected to be instance of *auth.Twitter")
    }

    // discord
    p, err = auth.NewProviderByName(auth.NameDiscord)
    if err != nil {
        t.Errorf("Expected nil, got error %v", err)
    }
    if _, ok := p.(*auth.Discord); !ok {
        t.Error("Expected to be instance of *auth.Discord")
    }

    // microsoft
    p, err = auth.NewProviderByName(auth.NameMicrosoft)
    if err != nil {
        t.Errorf("Expected nil, got error %v", err)
    }
    if _, ok := p.(*auth.Microsoft); !ok {
        t.Error("Expected to be instance of *auth.Microsoft")
    }

    // spotify
    p, err = auth.NewProviderByName(auth.NameSpotify)
    if err != nil {
        t.Errorf("Expected nil, got error %v", err)
    }
    if _, ok := p.(*auth.Spotify); !ok {
        t.Error("Expected to be instance of *auth.Spotify")
    }

    // kakao
    p, err = auth.NewProviderByName(auth.NameKakao)
    if err != nil {
        t.Errorf("Expected nil, got error %v", err)
    }
    if _, ok := p.(*auth.Kakao); !ok {
        t.Error("Expected to be instance of *auth.Kakao")
    }

    // twitch
    p, err = auth.NewProviderByName(auth.NameTwitch)
    if err != nil {
        t.Errorf("Expected nil, got error %v", err)
    }
    if _, ok := p.(*auth.Twitch); !ok {
        t.Error("Expected to be instance of *auth.Twitch")
    }

    // strava
    p, err = auth.NewProviderByName(auth.NameStrava)
    if err != nil {
        t.Errorf("Expected nil, got error %v", err)
    }
    if _, ok := p.(*auth.Strava); !ok {
        t.Error("Expected to be instance of *auth.Strava")
    }

    // gitee
    p, err = auth.NewProviderByName(auth.NameGitee)
    if err != nil {
        t.Errorf("Expected nil, got error %v", err)
    }
    if _, ok := p.(*auth.Gitee); !ok {
        t.Error("Expected to be instance of *auth.Gitee")
    }

    // livechat
    p, err = auth.NewProviderByName(auth.NameLivechat)
    if err != nil {
        t.Errorf("Expected nil, got error %v", err)
    }
    if _, ok := p.(*auth.Livechat); !ok {
        t.Error("Expected to be instance of *auth.Livechat")
    }

    // gitea
    p, err = auth.NewProviderByName(auth.NameGitea)
    if err != nil {
        t.Errorf("Expected nil, got error %v", err)
    }
    if _, ok := p.(*auth.Gitea); !ok {
        t.Error("Expected to be instance of *auth.Gitea")
    }

    // oidc
    p, err = auth.NewProviderByName(auth.NameOIDC)
    if err != nil {
        t.Errorf("Expected nil, got error %v", err)
    }
    if _, ok := p.(*auth.OIDC); !ok {
        t.Error("Expected to be instance of *auth.OIDC")
    }

    // oidc2
    p, err = auth.NewProviderByName(auth.NameOIDC + "2")
    if err != nil {
        t.Errorf("Expected nil, got error %v", err)
    }
    if _, ok := p.(*auth.OIDC); !ok {
        t.Error("Expected to be instance of *auth.OIDC")
    }

    // oidc3
    p, err = auth.NewProviderByName(auth.NameOIDC + "3")
    if err != nil {
        t.Errorf("Expected nil, got error %v", err)
    }
    if _, ok := p.(*auth.OIDC); !ok {
        t.Error("Expected to be instance of *auth.OIDC")
    }

    // apple
    p, err = auth.NewProviderByName(auth.NameApple)
    if err != nil {
        t.Errorf("Expected nil, got error %v", err)
    }
    if _, ok := p.(*auth.Apple); !ok {
        t.Error("Expected to be instance of *auth.Apple")
    }

    // instagram
    p, err = auth.NewProviderByName(auth.NameInstagram)
    if err != nil {
        t.Errorf("Expected nil, got error %v", err)
    }
    if _, ok := p.(*auth.Instagram); !ok {
        t.Error("Expected to be instance of *auth.Instagram")
    }

    // vk
    p, err = auth.NewProviderByName(auth.NameVK)
    if err != nil {
        t.Errorf("Expected nil, got error %v", err)
    }
    if _, ok := p.(*auth.VK); !ok {
        t.Error("Expected to be instance of *auth.VK")
    }

    // yandex
    p, err = auth.NewProviderByName(auth.NameYandex)
    if err != nil {
        t.Errorf("Expected nil, got error %v", err)
    }
    if _, ok := p.(*auth.Yandex); !ok {
        t.Error("Expected to be instance of *auth.Yandex")
    }

    // patreon
    p, err = auth.NewProviderByName(auth.NamePatreon)
    if err != nil {
        t.Errorf("Expected nil, got error %v", err)
    }
    if _, ok := p.(*auth.Patreon); !ok {
        t.Error("Expected to be instance of *auth.Patreon")
    }

    // mailcow
    p, err = auth.NewProviderByName(auth.NameMailcow)
    if err != nil {
        t.Errorf("Expected nil, got error %v", err)
    }
    if _, ok := p.(*auth.Mailcow); !ok {
        t.Error("Expected to be instance of *auth.Mailcow")
    }

    // bitbucket
    p, err = auth.NewProviderByName(auth.NameBitbucket)
    if err != nil {
        t.Errorf("Expected nil, got error %v", err)
    }
    if _, ok := p.(*auth.Bitbucket); !ok {
        t.Error("Expected to be instance of *auth.Bitbucket")
    }

    // planningcenter
    p, err = auth.NewProviderByName(auth.NamePlanningcenter)
    if err != nil {
        t.Errorf("Expected nil, got error %v", err)
    }
    if _, ok := p.(*auth.Planningcenter); !ok {
        t.Error("Expected to be instance of *auth.Planningcenter")
    }

    // notion
    p, err = auth.NewProviderByName(auth.NameNotion)
    if err != nil {
        t.Errorf("Expected nil, got error %v", err)
    }
    if _, ok := p.(*auth.Notion); !ok {
        t.Error("Expected to be instance of *auth.Notion")
    }

    // monday
    p, err = auth.NewProviderByName(auth.NameMonday)
    if err != nil {
        t.Errorf("Expected nil, got error %v", err)
    }
    if _, ok := p.(*auth.Monday); !ok {
        t.Error("Expected to be instance of *auth.Monday")
    }

    // wakatime
    p, err = auth.NewProviderByName(auth.NameWakatime)
    if err != nil {
        t.Errorf("Expected nil, got error %v", err)
    }
    if _, ok := p.(*auth.Wakatime); !ok {
        t.Error("Expected to be instance of *auth.Wakatime")
    }

    // linear
    p, err = auth.NewProviderByName(auth.NameLinear)
    if err != nil {
        t.Errorf("Expected nil, got error %v", err)
    }
    if _, ok := p.(*auth.Linear); !ok {
        t.Error("Expected to be instance of *auth.Linear")
    }

    // trakt
    p, err = auth.NewProviderByName(auth.NameTrakt)
    if err != nil {
        t.Errorf("Expected nil, got error %v", err)
    }
    if _, ok := p.(*auth.Trakt); !ok {
        t.Error("Expected to be instance of *auth.Trakt")
    }
}
tools/auth/base_provider.go (new file, 203 lines)
@@ -0,0 +1,203 @@
package auth

import (
    "context"
    "fmt"
    "io"
    "maps"
    "net/http"

    "golang.org/x/oauth2"
)

// BaseProvider defines common fields and methods used by OAuth2 client providers.
type BaseProvider struct {
    ctx          context.Context
    clientId     string
    clientSecret string
    displayName  string
    redirectURL  string
    authURL      string
    tokenURL     string
    userInfoURL  string
    scopes       []string
    pkce         bool
    extra        map[string]any
}

// Context implements Provider.Context() interface method.
func (p *BaseProvider) Context() context.Context {
    return p.ctx
}

// SetContext implements Provider.SetContext() interface method.
func (p *BaseProvider) SetContext(ctx context.Context) {
    p.ctx = ctx
}

// PKCE implements Provider.PKCE() interface method.
func (p *BaseProvider) PKCE() bool {
    return p.pkce
}

// SetPKCE implements Provider.SetPKCE() interface method.
func (p *BaseProvider) SetPKCE(enable bool) {
    p.pkce = enable
}

// DisplayName implements Provider.DisplayName() interface method.
func (p *BaseProvider) DisplayName() string {
    return p.displayName
}

// SetDisplayName implements Provider.SetDisplayName() interface method.
func (p *BaseProvider) SetDisplayName(displayName string) {
    p.displayName = displayName
}

// Scopes implements Provider.Scopes() interface method.
func (p *BaseProvider) Scopes() []string {
    return p.scopes
}

// SetScopes implements Provider.SetScopes() interface method.
func (p *BaseProvider) SetScopes(scopes []string) {
    p.scopes = scopes
}

// ClientId implements Provider.ClientId() interface method.
func (p *BaseProvider) ClientId() string {
    return p.clientId
}

// SetClientId implements Provider.SetClientId() interface method.
func (p *BaseProvider) SetClientId(clientId string) {
    p.clientId = clientId
}

// ClientSecret implements Provider.ClientSecret() interface method.
func (p *BaseProvider) ClientSecret() string {
    return p.clientSecret
}

// SetClientSecret implements Provider.SetClientSecret() interface method.
func (p *BaseProvider) SetClientSecret(secret string) {
    p.clientSecret = secret
}

// RedirectURL implements Provider.RedirectURL() interface method.
func (p *BaseProvider) RedirectURL() string {
    return p.redirectURL
}

// SetRedirectURL implements Provider.SetRedirectURL() interface method.
func (p *BaseProvider) SetRedirectURL(url string) {
    p.redirectURL = url
}

// AuthURL implements Provider.AuthURL() interface method.
func (p *BaseProvider) AuthURL() string {
    return p.authURL
}

// SetAuthURL implements Provider.SetAuthURL() interface method.
func (p *BaseProvider) SetAuthURL(url string) {
    p.authURL = url
}

// TokenURL implements Provider.TokenURL() interface method.
func (p *BaseProvider) TokenURL() string {
    return p.tokenURL
}

// SetTokenURL implements Provider.SetTokenURL() interface method.
func (p *BaseProvider) SetTokenURL(url string) {
    p.tokenURL = url
}

// UserInfoURL implements Provider.UserInfoURL() interface method.
func (p *BaseProvider) UserInfoURL() string {
    return p.userInfoURL
}

// SetUserInfoURL implements Provider.SetUserInfoURL() interface method.
func (p *BaseProvider) SetUserInfoURL(url string) {
    p.userInfoURL = url
}

// Extra implements Provider.Extra() interface method.
func (p *BaseProvider) Extra() map[string]any {
    return maps.Clone(p.extra)
}

// SetExtra implements Provider.SetExtra() interface method.
func (p *BaseProvider) SetExtra(data map[string]any) {
    p.extra = data
}

// BuildAuthURL implements Provider.BuildAuthURL() interface method.
func (p *BaseProvider) BuildAuthURL(state string, opts ...oauth2.AuthCodeOption) string {
    return p.oauth2Config().AuthCodeURL(state, opts...)
}

// FetchToken implements Provider.FetchToken() interface method.
func (p *BaseProvider) FetchToken(code string, opts ...oauth2.AuthCodeOption) (*oauth2.Token, error) {
    return p.oauth2Config().Exchange(p.ctx, code, opts...)
}

// Client implements Provider.Client() interface method.
func (p *BaseProvider) Client(token *oauth2.Token) *http.Client {
    return p.oauth2Config().Client(p.ctx, token)
}

// FetchRawUserInfo implements Provider.FetchRawUserInfo() interface method.
func (p *BaseProvider) FetchRawUserInfo(token *oauth2.Token) ([]byte, error) {
    req, err := http.NewRequestWithContext(p.ctx, "GET", p.userInfoURL, nil)
    if err != nil {
        return nil, err
    }

    return p.sendRawUserInfoRequest(req, token)
}

// sendRawUserInfoRequest sends the specified user info request and returns its raw response body.
func (p *BaseProvider) sendRawUserInfoRequest(req *http.Request, token *oauth2.Token) ([]byte, error) {
    client := p.Client(token)

    res, err := client.Do(req)
    if err != nil {
        return nil, err
    }
    defer res.Body.Close()

    result, err := io.ReadAll(res.Body)
    if err != nil {
        return nil, err
    }

    // http.Client.Get doesn't treat non 2xx responses as error
    if res.StatusCode >= 400 {
        return nil, fmt.Errorf(
            "failed to fetch OAuth2 user profile via %s (%d):\n%s",
            p.userInfoURL,
            res.StatusCode,
            string(result),
        )
    }

    return result, nil
}

// oauth2Config constructs an oauth2.Config instance based on the provider settings.
func (p *BaseProvider) oauth2Config() *oauth2.Config {
    return &oauth2.Config{
        RedirectURL:  p.redirectURL,
        ClientID:     p.clientId,
        ClientSecret: p.clientSecret,
        Scopes:       p.scopes,
        Endpoint: oauth2.Endpoint{
            AuthURL:  p.authURL,
            TokenURL: p.tokenURL,
        },
    }
}
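Every concrete provider in this package follows the same pattern: embed BaseProvider, fill in the endpoints in a factory, register the factory in init, and implement FetchAuthUser. A hedged sketch for a hypothetical "example" provider (all example.com endpoints and JSON field names below are invented for illustration):

package auth

import (
    "context"
    "encoding/json"

    "github.com/pocketbase/pocketbase/tools/types"
    "golang.org/x/oauth2"
)

func init() {
    Providers["example"] = wrapFactory(NewExampleProvider)
}

// Example is a hypothetical OAuth2 provider shown only for illustration.
type Example struct {
    BaseProvider
}

// NewExampleProvider creates a new Example provider instance with some defaults.
func NewExampleProvider() *Example {
    return &Example{BaseProvider{
        ctx:         context.Background(),
        displayName: "Example",
        pkce:        true,
        scopes:      []string{"profile", "email"},
        authURL:     "https://example.com/oauth2/authorize",
        tokenURL:    "https://example.com/oauth2/token",
        userInfoURL: "https://example.com/api/user",
    }}
}

// FetchAuthUser maps the hypothetical user api response to an AuthUser,
// mirroring the shape of the concrete providers in this package.
func (p *Example) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
    data, err := p.FetchRawUserInfo(token)
    if err != nil {
        return nil, err
    }

    rawUser := map[string]any{}
    if err := json.Unmarshal(data, &rawUser); err != nil {
        return nil, err
    }

    extracted := struct {
        Id    string `json:"id"`
        Name  string `json:"name"`
        Email string `json:"email"`
    }{}
    if err := json.Unmarshal(data, &extracted); err != nil {
        return nil, err
    }

    user := &AuthUser{
        Id:           extracted.Id,
        Name:         extracted.Name,
        Email:        extracted.Email,
        RawUser:      rawUser,
        AccessToken:  token.AccessToken,
        RefreshToken: token.RefreshToken,
    }

    user.Expiry, _ = types.ParseDateTime(token.Expiry)

    return user, nil
}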
tools/auth/base_provider_test.go (new file, 269 lines)
@@ -0,0 +1,269 @@
package auth

import (
    "bytes"
    "context"
    "encoding/json"
    "testing"

    "golang.org/x/oauth2"
)

func TestContext(t *testing.T) {
    b := BaseProvider{}

    before := b.Context()
    if before != nil {
        t.Errorf("Expected nil context, got %v", before)
    }

    b.SetContext(context.Background())

    after := b.Context()
    if after == nil {
        t.Error("Expected non-nil context")
    }
}

func TestDisplayName(t *testing.T) {
    b := BaseProvider{}

    before := b.DisplayName()
    if before != "" {
        t.Fatalf("Expected displayName to be empty, got %v", before)
    }

    b.SetDisplayName("test")

    after := b.DisplayName()
    if after != "test" {
        t.Fatalf("Expected displayName to be 'test', got %v", after)
    }
}

func TestPKCE(t *testing.T) {
    b := BaseProvider{}

    before := b.PKCE()
    if before != false {
        t.Fatalf("Expected pkce to be %v, got %v", false, before)
    }

    b.SetPKCE(true)

    after := b.PKCE()
    if after != true {
        t.Fatalf("Expected pkce to be %v, got %v", true, after)
    }
}

func TestScopes(t *testing.T) {
    b := BaseProvider{}

    before := b.Scopes()
    if len(before) != 0 {
        t.Fatalf("Expected 0 scopes, got %v", before)
    }

    b.SetScopes([]string{"test1", "test2"})

    after := b.Scopes()
    if len(after) != 2 {
        t.Fatalf("Expected 2 scopes, got %v", after)
    }
}

func TestClientId(t *testing.T) {
    b := BaseProvider{}

    before := b.ClientId()
    if before != "" {
        t.Fatalf("Expected clientId to be empty, got %v", before)
    }

    b.SetClientId("test")

    after := b.ClientId()
    if after != "test" {
        t.Fatalf("Expected clientId to be 'test', got %v", after)
    }
}

func TestClientSecret(t *testing.T) {
    b := BaseProvider{}

    before := b.ClientSecret()
    if before != "" {
        t.Fatalf("Expected clientSecret to be empty, got %v", before)
    }

    b.SetClientSecret("test")

    after := b.ClientSecret()
    if after != "test" {
        t.Fatalf("Expected clientSecret to be 'test', got %v", after)
    }
}

func TestRedirectURL(t *testing.T) {
    b := BaseProvider{}

    before := b.RedirectURL()
    if before != "" {
        t.Fatalf("Expected RedirectURL to be empty, got %v", before)
    }

    b.SetRedirectURL("test")

    after := b.RedirectURL()
    if after != "test" {
        t.Fatalf("Expected RedirectURL to be 'test', got %v", after)
    }
}

func TestAuthURL(t *testing.T) {
    b := BaseProvider{}

    before := b.AuthURL()
    if before != "" {
        t.Fatalf("Expected authURL to be empty, got %v", before)
    }

    b.SetAuthURL("test")

    after := b.AuthURL()
    if after != "test" {
        t.Fatalf("Expected authURL to be 'test', got %v", after)
    }
}

func TestTokenURL(t *testing.T) {
    b := BaseProvider{}

    before := b.TokenURL()
    if before != "" {
        t.Fatalf("Expected tokenURL to be empty, got %v", before)
    }

    b.SetTokenURL("test")

    after := b.TokenURL()
    if after != "test" {
        t.Fatalf("Expected tokenURL to be 'test', got %v", after)
    }
}

func TestUserInfoURL(t *testing.T) {
    b := BaseProvider{}

    before := b.UserInfoURL()
    if before != "" {
        t.Fatalf("Expected userInfoURL to be empty, got %v", before)
    }

    b.SetUserInfoURL("test")

    after := b.UserInfoURL()
    if after != "test" {
        t.Fatalf("Expected userInfoURL to be 'test', got %v", after)
    }
}

func TestExtra(t *testing.T) {
    b := BaseProvider{}

    before := b.Extra()
    if before != nil {
        t.Fatalf("Expected extra to be empty, got %v", before)
    }

    extra := map[string]any{"a": 1, "b": 2}

    b.SetExtra(extra)

    after := b.Extra()

    rawExtra, err := json.Marshal(extra)
    if err != nil {
        t.Fatal(err)
    }

    rawAfter, err := json.Marshal(after)
    if err != nil {
        t.Fatal(err)
    }

    if !bytes.Equal(rawExtra, rawAfter) {
        t.Fatalf("Expected extra to be\n%s\ngot\n%s", rawExtra, rawAfter)
    }

    // ensure that it was shallow copied
    after["b"] = 3
    if d := b.Extra(); d["b"] != 2 {
        t.Fatalf("Expected extra to remain unchanged, got\n%v", d)
    }
}

func TestBuildAuthURL(t *testing.T) {
    b := BaseProvider{
        authURL:      "authURL_test",
        tokenURL:     "tokenURL_test",
        redirectURL:  "redirectURL_test",
        clientId:     "clientId_test",
        clientSecret: "clientSecret_test",
        scopes:       []string{"test_scope"},
    }

    expected := "authURL_test?access_type=offline&client_id=clientId_test&prompt=consent&redirect_uri=redirectURL_test&response_type=code&scope=test_scope&state=state_test"
    result := b.BuildAuthURL("state_test", oauth2.AccessTypeOffline, oauth2.ApprovalForce)

    if result != expected {
        t.Errorf("Expected auth url %q, got %q", expected, result)
    }
}

func TestClient(t *testing.T) {
    b := BaseProvider{}

    result := b.Client(&oauth2.Token{})
    if result == nil {
        t.Error("Expected *http.Client instance, got nil")
    }
}

func TestOauth2Config(t *testing.T) {
    b := BaseProvider{
        authURL:      "authURL_test",
        tokenURL:     "tokenURL_test",
        redirectURL:  "redirectURL_test",
        clientId:     "clientId_test",
        clientSecret: "clientSecret_test",
        scopes:       []string{"test"},
    }

    result := b.oauth2Config()

    if result.RedirectURL != b.RedirectURL() {
        t.Errorf("Expected redirectURL %s, got %s", b.RedirectURL(), result.RedirectURL)
    }

    if result.ClientID != b.ClientId() {
        t.Errorf("Expected clientId %s, got %s", b.ClientId(), result.ClientID)
    }

    if result.ClientSecret != b.ClientSecret() {
        t.Errorf("Expected clientSecret %s, got %s", b.ClientSecret(), result.ClientSecret)
    }

    if result.Endpoint.AuthURL != b.AuthURL() {
        t.Errorf("Expected authURL %s, got %s", b.AuthURL(), result.Endpoint.AuthURL)
    }

    if result.Endpoint.TokenURL != b.TokenURL() {
        t.Errorf("Expected tokenURL %s, got %s", b.TokenURL(), result.Endpoint.TokenURL)
    }

    if len(result.Scopes) != len(b.Scopes()) || result.Scopes[0] != b.Scopes()[0] {
        t.Errorf("Expected scopes %s, got %s", b.Scopes(), result.Scopes)
    }
}
tools/auth/bitbucket.go (new file, 136 lines)
@@ -0,0 +1,136 @@
package auth

import (
    "context"
    "encoding/json"
    "errors"
    "io"

    "github.com/pocketbase/pocketbase/tools/types"
    "golang.org/x/oauth2"
)

func init() {
    Providers[NameBitbucket] = wrapFactory(NewBitbucketProvider)
}

var _ Provider = (*Bitbucket)(nil)

// NameBitbucket is the unique name of the Bitbucket provider.
const NameBitbucket = "bitbucket"

// Bitbucket is an auth provider for Bitbucket.
type Bitbucket struct {
    BaseProvider
}

// NewBitbucketProvider creates a new Bitbucket provider instance with some defaults.
func NewBitbucketProvider() *Bitbucket {
    return &Bitbucket{BaseProvider{
        ctx:         context.Background(),
        displayName: "Bitbucket",
        pkce:        false,
        scopes:      []string{"account"},
        authURL:     "https://bitbucket.org/site/oauth2/authorize",
        tokenURL:    "https://bitbucket.org/site/oauth2/access_token",
        userInfoURL: "https://api.bitbucket.org/2.0/user",
    }}
}

// FetchAuthUser returns an AuthUser instance based on the Bitbucket's user API.
//
// API reference: https://developer.atlassian.com/cloud/bitbucket/rest/api-group-users/#api-user-get
func (p *Bitbucket) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
    data, err := p.FetchRawUserInfo(token)
    if err != nil {
        return nil, err
    }

    rawUser := map[string]any{}
    if err := json.Unmarshal(data, &rawUser); err != nil {
        return nil, err
    }

    extracted := struct {
        UUID          string `json:"uuid"`
        Username      string `json:"username"`
        DisplayName   string `json:"display_name"`
        AccountStatus string `json:"account_status"`
        Links         struct {
            Avatar struct {
                Href string `json:"href"`
            } `json:"avatar"`
        } `json:"links"`
    }{}
    if err := json.Unmarshal(data, &extracted); err != nil {
        return nil, err
    }

    if extracted.AccountStatus != "active" {
        return nil, errors.New("the Bitbucket user is not active")
    }

    email, err := p.fetchPrimaryEmail(token)
    if err != nil {
        return nil, err
    }

    user := &AuthUser{
        Id:           extracted.UUID,
        Name:         extracted.DisplayName,
        Username:     extracted.Username,
        Email:        email,
        AvatarURL:    extracted.Links.Avatar.Href,
        RawUser:      rawUser,
        AccessToken:  token.AccessToken,
        RefreshToken: token.RefreshToken,
    }

    user.Expiry, _ = types.ParseDateTime(token.Expiry)

    return user, nil
}

// fetchPrimaryEmail sends an API request to retrieve the first
// verified primary email.
//
// NB! This method can succeed and still return an empty email.
// Error responses that are result of insufficient scopes permissions are ignored.
//
// API reference: https://developer.atlassian.com/cloud/bitbucket/rest/api-group-users/#api-user-emails-get
func (p *Bitbucket) fetchPrimaryEmail(token *oauth2.Token) (string, error) {
    response, err := p.Client(token).Get(p.userInfoURL + "/emails")
    if err != nil {
        return "", err
    }
    defer response.Body.Close()

    // ignore common http errors caused by insufficient scope permissions
    // (the email field is optional, aka. return the auth user without it)
    if response.StatusCode >= 400 {
        return "", nil
    }

    data, err := io.ReadAll(response.Body)
    if err != nil {
        return "", err
    }

    expected := struct {
        Values []struct {
            Email     string `json:"email"`
            IsPrimary bool   `json:"is_primary"`
        } `json:"values"`
    }{}
    if err := json.Unmarshal(data, &expected); err != nil {
        return "", err
    }

    for _, v := range expected.Values {
        if v.IsPrimary {
            return v.Email, nil
        }
    }

    return "", nil
}
tools/auth/discord.go (new file, 91 lines)
@@ -0,0 +1,91 @@
package auth

import (
    "context"
    "encoding/json"
    "fmt"

    "github.com/pocketbase/pocketbase/tools/types"
    "golang.org/x/oauth2"
)

func init() {
    Providers[NameDiscord] = wrapFactory(NewDiscordProvider)
}

var _ Provider = (*Discord)(nil)

// NameDiscord is the unique name of the Discord provider.
const NameDiscord string = "discord"

// Discord allows authentication via Discord OAuth2.
type Discord struct {
    BaseProvider
}

// NewDiscordProvider creates a new Discord provider instance with some defaults.
func NewDiscordProvider() *Discord {
    // https://discord.com/developers/docs/topics/oauth2
    // https://discord.com/developers/docs/resources/user#get-current-user
    return &Discord{BaseProvider{
        ctx:         context.Background(),
        displayName: "Discord",
        pkce:        true,
        scopes:      []string{"identify", "email"},
        authURL:     "https://discord.com/api/oauth2/authorize",
        tokenURL:    "https://discord.com/api/oauth2/token",
        userInfoURL: "https://discord.com/api/users/@me",
    }}
}

// FetchAuthUser returns an AuthUser instance from Discord's user api.
//
// API reference: https://discord.com/developers/docs/resources/user#user-object
func (p *Discord) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
    data, err := p.FetchRawUserInfo(token)
    if err != nil {
        return nil, err
    }

    rawUser := map[string]any{}
    if err := json.Unmarshal(data, &rawUser); err != nil {
        return nil, err
    }

    extracted := struct {
        Id            string `json:"id"`
        Username      string `json:"username"`
        Discriminator string `json:"discriminator"`
        Avatar        string `json:"avatar"`
        Email         string `json:"email"`
        Verified      bool   `json:"verified"`
    }{}
    if err := json.Unmarshal(data, &extracted); err != nil {
        return nil, err
    }

    // Build a full avatar URL using the avatar hash provided in the API response
    // https://discord.com/developers/docs/reference#image-formatting
    avatarURL := fmt.Sprintf("https://cdn.discordapp.com/avatars/%s/%s.png", extracted.Id, extracted.Avatar)

    // Concatenate the user's username and discriminator into a single username string
    username := fmt.Sprintf("%s#%s", extracted.Username, extracted.Discriminator)

    user := &AuthUser{
        Id:           extracted.Id,
        Name:         username,
        Username:     extracted.Username,
        AvatarURL:    avatarURL,
        RawUser:      rawUser,
        AccessToken:  token.AccessToken,
        RefreshToken: token.RefreshToken,
    }

    user.Expiry, _ = types.ParseDateTime(token.Expiry)

    if extracted.Verified {
        user.Email = extracted.Email
    }

    return user, nil
}
tools/auth/facebook.go (new file, 78 lines)
@@ -0,0 +1,78 @@
package auth

import (
    "context"
    "encoding/json"

    "github.com/pocketbase/pocketbase/tools/types"
    "golang.org/x/oauth2"
    "golang.org/x/oauth2/facebook"
)

func init() {
    Providers[NameFacebook] = wrapFactory(NewFacebookProvider)
}

var _ Provider = (*Facebook)(nil)

// NameFacebook is the unique name of the Facebook provider.
const NameFacebook string = "facebook"

// Facebook allows authentication via Facebook OAuth2.
type Facebook struct {
    BaseProvider
}

// NewFacebookProvider creates a new Facebook provider instance with some defaults.
func NewFacebookProvider() *Facebook {
    return &Facebook{BaseProvider{
        ctx:         context.Background(),
        displayName: "Facebook",
        pkce:        true,
        scopes:      []string{"email"},
        authURL:     facebook.Endpoint.AuthURL,
        tokenURL:    facebook.Endpoint.TokenURL,
        userInfoURL: "https://graph.facebook.com/me?fields=name,email,picture.type(large)",
    }}
}

// FetchAuthUser returns an AuthUser instance based on the Facebook's user api.
//
// API reference: https://developers.facebook.com/docs/graph-api/reference/user/
func (p *Facebook) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
    data, err := p.FetchRawUserInfo(token)
    if err != nil {
        return nil, err
    }

    rawUser := map[string]any{}
    if err := json.Unmarshal(data, &rawUser); err != nil {
        return nil, err
    }

    extracted := struct {
        Id      string
        Name    string
        Email   string
        Picture struct {
            Data struct{ Url string }
        }
    }{}
    if err := json.Unmarshal(data, &extracted); err != nil {
        return nil, err
    }

    user := &AuthUser{
        Id:           extracted.Id,
        Name:         extracted.Name,
        Email:        extracted.Email,
        AvatarURL:    extracted.Picture.Data.Url,
        RawUser:      rawUser,
        AccessToken:  token.AccessToken,
        RefreshToken: token.RefreshToken,
    }

    user.Expiry, _ = types.ParseDateTime(token.Expiry)

    return user, nil
}
tools/auth/gitea.go (new file, 78 lines)
@@ -0,0 +1,78 @@
package auth

import (
    "context"
    "encoding/json"
    "strconv"

    "github.com/pocketbase/pocketbase/tools/types"
    "golang.org/x/oauth2"
)

func init() {
    Providers[NameGitea] = wrapFactory(NewGiteaProvider)
}

var _ Provider = (*Gitea)(nil)

// NameGitea is the unique name of the Gitea provider.
const NameGitea string = "gitea"

// Gitea allows authentication via Gitea OAuth2.
type Gitea struct {
    BaseProvider
}

// NewGiteaProvider creates a new Gitea provider instance with some defaults.
func NewGiteaProvider() *Gitea {
    return &Gitea{BaseProvider{
        ctx:         context.Background(),
        displayName: "Gitea",
        pkce:        true,
        scopes:      []string{"read:user", "user:email"},
        authURL:     "https://gitea.com/login/oauth/authorize",
        tokenURL:    "https://gitea.com/login/oauth/access_token",
        userInfoURL: "https://gitea.com/api/v1/user",
    }}
}

// FetchAuthUser returns an AuthUser instance based on Gitea's user api.
//
// API reference: https://try.gitea.io/api/swagger#/user/userGetCurrent
func (p *Gitea) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
    data, err := p.FetchRawUserInfo(token)
    if err != nil {
        return nil, err
    }

    rawUser := map[string]any{}
    if err := json.Unmarshal(data, &rawUser); err != nil {
        return nil, err
    }

    extracted := struct {
        Name      string `json:"full_name"`
        Username  string `json:"login"`
        Email     string `json:"email"`
        AvatarURL string `json:"avatar_url"`
        Id        int64  `json:"id"`
    }{}
    if err := json.Unmarshal(data, &extracted); err != nil {
        return nil, err
    }

    user := &AuthUser{
        Id:           strconv.FormatInt(extracted.Id, 10),
        Name:         extracted.Name,
        Username:     extracted.Username,
        Email:        extracted.Email,
        AvatarURL:    extracted.AvatarURL,
        RawUser:      rawUser,
        AccessToken:  token.AccessToken,
        RefreshToken: token.RefreshToken,
    }

    user.Expiry, _ = types.ParseDateTime(token.Expiry)

    return user, nil
}
142
tools/auth/gitee.go
Normal file
@@ -0,0 +1,142 @@
package auth

import (
	"context"
	"encoding/json"
	"io"
	"strconv"

	"github.com/go-ozzo/ozzo-validation/v4/is"
	"github.com/pocketbase/pocketbase/tools/types"
	"golang.org/x/oauth2"
)

func init() {
	Providers[NameGitee] = wrapFactory(NewGiteeProvider)
}

var _ Provider = (*Gitee)(nil)

// NameGitee is the unique name of the Gitee provider.
const NameGitee string = "gitee"

// Gitee allows authentication via Gitee OAuth2.
type Gitee struct {
	BaseProvider
}

// NewGiteeProvider creates a new Gitee provider instance with some defaults.
func NewGiteeProvider() *Gitee {
	return &Gitee{BaseProvider{
		ctx:         context.Background(),
		displayName: "Gitee",
		pkce:        true,
		scopes:      []string{"user_info", "emails"},
		authURL:     "https://gitee.com/oauth/authorize",
		tokenURL:    "https://gitee.com/oauth/token",
		userInfoURL: "https://gitee.com/api/v5/user",
	}}
}

// FetchAuthUser returns an AuthUser instance based on Gitee's user api.
//
// API reference: https://gitee.com/api/v5/swagger#/getV5User
func (p *Gitee) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
	data, err := p.FetchRawUserInfo(token)
	if err != nil {
		return nil, err
	}

	rawUser := map[string]any{}
	if err := json.Unmarshal(data, &rawUser); err != nil {
		return nil, err
	}

	extracted := struct {
		Login     string `json:"login"`
		Name      string `json:"name"`
		Email     string `json:"email"`
		AvatarURL string `json:"avatar_url"`
		Id        int64  `json:"id"`
	}{}
	if err := json.Unmarshal(data, &extracted); err != nil {
		return nil, err
	}

	user := &AuthUser{
		Id:           strconv.FormatInt(extracted.Id, 10),
		Name:         extracted.Name,
		Username:     extracted.Login,
		AvatarURL:    extracted.AvatarURL,
		RawUser:      rawUser,
		AccessToken:  token.AccessToken,
		RefreshToken: token.RefreshToken,
	}

	user.Expiry, _ = types.ParseDateTime(token.Expiry)

	if extracted.Email != "" && is.EmailFormat.Validate(extracted.Email) == nil {
		// valid public primary email
		user.Email = extracted.Email
	} else {
		// send an additional optional request to retrieve the email
		email, err := p.fetchPrimaryEmail(token)
		if err != nil {
			return nil, err
		}
		user.Email = email
	}

	return user, nil
}

// fetchPrimaryEmail sends an API request to retrieve the verified primary email,
// in case the user hasn't set "Public email address" or has unchecked
// the "Access your emails data" permission during authentication.
//
// NB! This method can succeed and still return an empty email.
// Error responses that are a result of insufficient scope permissions are ignored.
//
// API reference: https://gitee.com/api/v5/swagger#/getV5Emails
func (p *Gitee) fetchPrimaryEmail(token *oauth2.Token) (string, error) {
	client := p.Client(token)

	response, err := client.Get("https://gitee.com/api/v5/emails")
	if err != nil {
		return "", err
	}
	defer response.Body.Close()

	// ignore common http errors caused by insufficient scope permissions
	if response.StatusCode == 401 || response.StatusCode == 403 || response.StatusCode == 404 {
		return "", nil
	}

	content, err := io.ReadAll(response.Body)
	if err != nil {
		return "", err
	}

	emails := []struct {
		Email string
		State string
		Scope []string
	}{}
	if err := json.Unmarshal(content, &emails); err != nil {
		// ignore the unmarshal error in case "Keep my email address private"
		// was set because response.Body will be something like:
		// {"email":"12285415+test@user.noreply.gitee.com"}
		return "", nil
	}

	// extract the first verified primary email
	for _, email := range emails {
		for _, scope := range email.Scope {
			if email.State == "confirmed" && scope == "primary" && is.EmailFormat.Validate(email.Email) == nil {
				return email.Email, nil
			}
		}
	}

	return "", nil
}
136
tools/auth/github.go
Normal file
@@ -0,0 +1,136 @@
package auth

import (
	"context"
	"encoding/json"
	"io"
	"strconv"

	"github.com/pocketbase/pocketbase/tools/types"
	"golang.org/x/oauth2"
	"golang.org/x/oauth2/github"
)

func init() {
	Providers[NameGithub] = wrapFactory(NewGithubProvider)
}

var _ Provider = (*Github)(nil)

// NameGithub is the unique name of the Github provider.
const NameGithub string = "github"

// Github allows authentication via Github OAuth2.
type Github struct {
	BaseProvider
}

// NewGithubProvider creates a new GitHub provider instance with some defaults.
func NewGithubProvider() *Github {
	return &Github{BaseProvider{
		ctx:         context.Background(),
		displayName: "GitHub",
		pkce:        true, // technically not supported yet, but it is safe as the PKCE params are just ignored
		scopes:      []string{"read:user", "user:email"},
		authURL:     github.Endpoint.AuthURL,
		tokenURL:    github.Endpoint.TokenURL,
		userInfoURL: "https://api.github.com/user",
	}}
}

// FetchAuthUser returns an AuthUser instance based on GitHub's user api.
//
// API reference: https://docs.github.com/en/rest/reference/users#get-the-authenticated-user
func (p *Github) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
	data, err := p.FetchRawUserInfo(token)
	if err != nil {
		return nil, err
	}

	rawUser := map[string]any{}
	if err := json.Unmarshal(data, &rawUser); err != nil {
		return nil, err
	}

	extracted := struct {
		Login     string `json:"login"`
		Name      string `json:"name"`
		Email     string `json:"email"`
		AvatarURL string `json:"avatar_url"`
		Id        int64  `json:"id"`
	}{}
	if err := json.Unmarshal(data, &extracted); err != nil {
		return nil, err
	}

	user := &AuthUser{
		Id:           strconv.FormatInt(extracted.Id, 10),
		Name:         extracted.Name,
		Username:     extracted.Login,
		Email:        extracted.Email,
		AvatarURL:    extracted.AvatarURL,
		RawUser:      rawUser,
		AccessToken:  token.AccessToken,
		RefreshToken: token.RefreshToken,
	}

	user.Expiry, _ = types.ParseDateTime(token.Expiry)

	// in case the user has set "Keep my email address private", send an
	// **optional** API request to retrieve the verified primary email
	if user.Email == "" {
		email, err := p.fetchPrimaryEmail(token)
		if err != nil {
			return nil, err
		}
		user.Email = email
	}

	return user, nil
}

// fetchPrimaryEmail sends an API request to retrieve the verified
// primary email, in case "Keep my email address private" was set.
//
// NB! This method can succeed and still return an empty email.
// Error responses that are a result of insufficient scope permissions are ignored.
//
// API reference: https://docs.github.com/en/rest/users/emails?apiVersion=2022-11-28
func (p *Github) fetchPrimaryEmail(token *oauth2.Token) (string, error) {
	client := p.Client(token)

	response, err := client.Get(p.userInfoURL + "/emails")
	if err != nil {
		return "", err
	}
	defer response.Body.Close()

	// ignore common http errors caused by insufficient scope permissions
	// (the email field is optional, aka. return the auth user without it)
	if response.StatusCode == 401 || response.StatusCode == 403 || response.StatusCode == 404 {
		return "", nil
	}

	content, err := io.ReadAll(response.Body)
	if err != nil {
		return "", err
	}

	emails := []struct {
		Email    string
		Verified bool
		Primary  bool
	}{}
	if err := json.Unmarshal(content, &emails); err != nil {
		return "", err
	}

	// extract the verified primary email
	for _, email := range emails {
		if email.Verified && email.Primary {
			return email.Email, nil
		}
	}

	return "", nil
}
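As a point of reference, a self-contained sketch of the JSON shape that fetchPrimaryEmail above decodes; the sample payload is illustrative, not captured from a live response:

package main

import (
	"encoding/json"
	"fmt"
)

func main() {
	// illustrative body for GET /user/emails, matching the anonymous
	// struct decoded in fetchPrimaryEmail (the concrete values are made up)
	content := []byte(`[
		{"email": "octocat@example.com", "verified": true, "primary": true},
		{"email": "old@example.com", "verified": false, "primary": false}
	]`)

	emails := []struct {
		Email    string
		Verified bool
		Primary  bool
	}{}
	if err := json.Unmarshal(content, &emails); err != nil {
		panic(err)
	}

	// same selection rule as the provider: first verified primary entry wins
	for _, email := range emails {
		if email.Verified && email.Primary {
			fmt.Println("primary email:", email.Email) // octocat@example.com
			break
		}
	}
}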
78
tools/auth/gitlab.go
Normal file
@@ -0,0 +1,78 @@
package auth

import (
	"context"
	"encoding/json"
	"strconv"

	"github.com/pocketbase/pocketbase/tools/types"
	"golang.org/x/oauth2"
)

func init() {
	Providers[NameGitlab] = wrapFactory(NewGitlabProvider)
}

var _ Provider = (*Gitlab)(nil)

// NameGitlab is the unique name of the Gitlab provider.
const NameGitlab string = "gitlab"

// Gitlab allows authentication via Gitlab OAuth2.
type Gitlab struct {
	BaseProvider
}

// NewGitlabProvider creates a new Gitlab provider instance with some defaults.
func NewGitlabProvider() *Gitlab {
	return &Gitlab{BaseProvider{
		ctx:         context.Background(),
		displayName: "GitLab",
		pkce:        true,
		scopes:      []string{"read_user"},
		authURL:     "https://gitlab.com/oauth/authorize",
		tokenURL:    "https://gitlab.com/oauth/token",
		userInfoURL: "https://gitlab.com/api/v4/user",
	}}
}

// FetchAuthUser returns an AuthUser instance based on Gitlab's user api.
//
// API reference: https://docs.gitlab.com/ee/api/users.html#for-admin
func (p *Gitlab) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
	data, err := p.FetchRawUserInfo(token)
	if err != nil {
		return nil, err
	}

	rawUser := map[string]any{}
	if err := json.Unmarshal(data, &rawUser); err != nil {
		return nil, err
	}

	extracted := struct {
		Name      string `json:"name"`
		Username  string `json:"username"`
		Email     string `json:"email"`
		AvatarURL string `json:"avatar_url"`
		Id        int64  `json:"id"`
	}{}
	if err := json.Unmarshal(data, &extracted); err != nil {
		return nil, err
	}

	user := &AuthUser{
		Id:           strconv.FormatInt(extracted.Id, 10),
		Name:         extracted.Name,
		Username:     extracted.Username,
		Email:        extracted.Email,
		AvatarURL:    extracted.AvatarURL,
		RawUser:      rawUser,
		AccessToken:  token.AccessToken,
		RefreshToken: token.RefreshToken,
	}

	user.Expiry, _ = types.ParseDateTime(token.Expiry)

	return user, nil
}
80
tools/auth/google.go
Normal file
@@ -0,0 +1,80 @@
package auth

import (
	"context"
	"encoding/json"

	"github.com/pocketbase/pocketbase/tools/types"
	"golang.org/x/oauth2"
)

func init() {
	Providers[NameGoogle] = wrapFactory(NewGoogleProvider)
}

var _ Provider = (*Google)(nil)

// NameGoogle is the unique name of the Google provider.
const NameGoogle string = "google"

// Google allows authentication via Google OAuth2.
type Google struct {
	BaseProvider
}

// NewGoogleProvider creates a new Google provider instance with some defaults.
func NewGoogleProvider() *Google {
	return &Google{BaseProvider{
		ctx:         context.Background(),
		displayName: "Google",
		pkce:        true,
		scopes: []string{
			"https://www.googleapis.com/auth/userinfo.profile",
			"https://www.googleapis.com/auth/userinfo.email",
		},
		authURL:     "https://accounts.google.com/o/oauth2/v2/auth",
		tokenURL:    "https://oauth2.googleapis.com/token",
		userInfoURL: "https://www.googleapis.com/oauth2/v3/userinfo",
	}}
}

// FetchAuthUser returns an AuthUser instance based on Google's user api.
func (p *Google) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
	data, err := p.FetchRawUserInfo(token)
	if err != nil {
		return nil, err
	}

	rawUser := map[string]any{}
	if err := json.Unmarshal(data, &rawUser); err != nil {
		return nil, err
	}

	extracted := struct {
		Id            string `json:"sub"`
		Name          string `json:"name"`
		Picture       string `json:"picture"`
		Email         string `json:"email"`
		EmailVerified bool   `json:"email_verified"`
	}{}
	if err := json.Unmarshal(data, &extracted); err != nil {
		return nil, err
	}

	user := &AuthUser{
		Id:           extracted.Id,
		Name:         extracted.Name,
		AvatarURL:    extracted.Picture,
		RawUser:      rawUser,
		AccessToken:  token.AccessToken,
		RefreshToken: token.RefreshToken,
	}

	user.Expiry, _ = types.ParseDateTime(token.Expiry)

	if extracted.EmailVerified {
		user.Email = extracted.Email
	}

	return user, nil
}
82
tools/auth/instagram.go
Normal file
@@ -0,0 +1,82 @@
package auth

import (
	"context"
	"encoding/json"

	"github.com/pocketbase/pocketbase/tools/types"
	"golang.org/x/oauth2"
)

func init() {
	Providers[NameInstagram] = wrapFactory(NewInstagramProvider)
}

var _ Provider = (*Instagram)(nil)

// NameInstagram is the unique name of the Instagram provider.
const NameInstagram string = "instagram2" // "2" suffix to avoid conflicts with the old deprecated version

// Instagram allows authentication via Instagram Login OAuth2.
type Instagram struct {
	BaseProvider
}

// NewInstagramProvider creates a new Instagram provider instance with some defaults.
func NewInstagramProvider() *Instagram {
	return &Instagram{BaseProvider{
		ctx:         context.Background(),
		displayName: "Instagram",
		pkce:        true,
		scopes:      []string{"instagram_business_basic"},
		authURL:     "https://www.instagram.com/oauth/authorize",
		tokenURL:    "https://api.instagram.com/oauth/access_token",
		userInfoURL: "https://graph.instagram.com/me?fields=id,username,account_type,user_id,name,profile_picture_url,followers_count,follows_count,media_count",
	}}
}

// FetchAuthUser returns an AuthUser instance based on the Instagram Login user api response.
//
// API reference: https://developers.facebook.com/docs/instagram-platform/instagram-api-with-instagram-login/get-started#fields
func (p *Instagram) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
	data, err := p.FetchRawUserInfo(token)
	if err != nil {
		return nil, err
	}

	rawUser := map[string]any{}
	if err := json.Unmarshal(data, &rawUser); err != nil {
		return nil, err
	}

	// include the list of granted permissions in the RawUser payload
	if _, ok := rawUser["permissions"]; !ok {
		if permissions := token.Extra("permissions"); permissions != nil {
			rawUser["permissions"] = permissions
		}
	}

	extracted := struct {
		Id        string `json:"user_id"`
		Name      string `json:"name"`
		Username  string `json:"username"`
		AvatarURL string `json:"profile_picture_url"`
	}{}
	if err := json.Unmarshal(data, &extracted); err != nil {
		return nil, err
	}

	user := &AuthUser{
		Id:           extracted.Id,
		Username:     extracted.Username,
		Name:         extracted.Name,
		AvatarURL:    extracted.AvatarURL,
		RawUser:      rawUser,
		AccessToken:  token.AccessToken,
		RefreshToken: token.RefreshToken,
	}

	user.Expiry, _ = types.ParseDateTime(token.Expiry)

	return user, nil
}
86
tools/auth/kakao.go
Normal file
@@ -0,0 +1,86 @@
package auth

import (
	"context"
	"encoding/json"
	"strconv"

	"github.com/pocketbase/pocketbase/tools/types"
	"golang.org/x/oauth2"
	"golang.org/x/oauth2/kakao"
)

func init() {
	Providers[NameKakao] = wrapFactory(NewKakaoProvider)
}

var _ Provider = (*Kakao)(nil)

// NameKakao is the unique name of the Kakao provider.
const NameKakao string = "kakao"

// Kakao allows authentication via Kakao OAuth2.
type Kakao struct {
	BaseProvider
}

// NewKakaoProvider creates a new Kakao provider instance with some defaults.
func NewKakaoProvider() *Kakao {
	return &Kakao{BaseProvider{
		ctx:         context.Background(),
		displayName: "Kakao",
		pkce:        true,
		scopes:      []string{"account_email", "profile_nickname", "profile_image"},
		authURL:     kakao.Endpoint.AuthURL,
		tokenURL:    kakao.Endpoint.TokenURL,
		userInfoURL: "https://kapi.kakao.com/v2/user/me",
	}}
}

// FetchAuthUser returns an AuthUser instance based on Kakao's user api.
//
// API reference: https://developers.kakao.com/docs/latest/en/kakaologin/rest-api#req-user-info-response
func (p *Kakao) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
	data, err := p.FetchRawUserInfo(token)
	if err != nil {
		return nil, err
	}

	rawUser := map[string]any{}
	if err := json.Unmarshal(data, &rawUser); err != nil {
		return nil, err
	}

	extracted := struct {
		Profile struct {
			Nickname string `json:"nickname"`
			ImageURL string `json:"profile_image"`
		} `json:"properties"`
		KakaoAccount struct {
			Email           string `json:"email"`
			IsEmailVerified bool   `json:"is_email_verified"`
			IsEmailValid    bool   `json:"is_email_valid"`
		} `json:"kakao_account"`
		Id int64 `json:"id"`
	}{}
	if err := json.Unmarshal(data, &extracted); err != nil {
		return nil, err
	}

	user := &AuthUser{
		Id:           strconv.FormatInt(extracted.Id, 10),
		Username:     extracted.Profile.Nickname,
		AvatarURL:    extracted.Profile.ImageURL,
		RawUser:      rawUser,
		AccessToken:  token.AccessToken,
		RefreshToken: token.RefreshToken,
	}

	user.Expiry, _ = types.ParseDateTime(token.Expiry)

	if extracted.KakaoAccount.IsEmailValid && extracted.KakaoAccount.IsEmailVerified {
		user.Email = extracted.KakaoAccount.Email
	}

	return user, nil
}
109
tools/auth/linear.go
Normal file
@@ -0,0 +1,109 @@
package auth

import (
	"bytes"
	"context"
	"encoding/json"
	"errors"
	"net/http"

	"github.com/pocketbase/pocketbase/tools/types"
	"golang.org/x/oauth2"
)

func init() {
	Providers[NameLinear] = wrapFactory(NewLinearProvider)
}

var _ Provider = (*Linear)(nil)

// NameLinear is the unique name of the Linear provider.
const NameLinear string = "linear"

// Linear allows authentication via Linear OAuth2.
type Linear struct {
	BaseProvider
}

// NewLinearProvider creates a new Linear provider instance with some defaults.
//
// API reference: https://developers.linear.app/docs/oauth/authentication
func NewLinearProvider() *Linear {
	return &Linear{BaseProvider{
		ctx:         context.Background(),
		displayName: "Linear",
		pkce:        false, // Linear doesn't support PKCE at the moment and returns an error if enabled
		scopes:      []string{"read"},
		authURL:     "https://linear.app/oauth/authorize",
		tokenURL:    "https://api.linear.app/oauth/token",
		userInfoURL: "https://api.linear.app/graphql",
	}}
}

// FetchAuthUser returns an AuthUser instance based on Linear's user api.
//
// API reference: https://developers.linear.app/docs/graphql/working-with-the-graphql-api#authentication
func (p *Linear) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
	data, err := p.FetchRawUserInfo(token)
	if err != nil {
		return nil, err
	}

	rawUser := map[string]any{}
	if err := json.Unmarshal(data, &rawUser); err != nil {
		return nil, err
	}

	extracted := struct {
		Data struct {
			Viewer struct {
				Id          string `json:"id"`
				DisplayName string `json:"displayName"`
				Name        string `json:"name"`
				Email       string `json:"email"`
				AvatarURL   string `json:"avatarUrl"`
				Active      bool   `json:"active"`
			} `json:"viewer"`
		} `json:"data"`
	}{}

	if err := json.Unmarshal(data, &extracted); err != nil {
		return nil, err
	}

	if !extracted.Data.Viewer.Active {
		return nil, errors.New("the Linear user account is not active")
	}

	user := &AuthUser{
		Id:           extracted.Data.Viewer.Id,
		Name:         extracted.Data.Viewer.Name,
		Username:     extracted.Data.Viewer.DisplayName,
		Email:        extracted.Data.Viewer.Email,
		AvatarURL:    extracted.Data.Viewer.AvatarURL,
		RawUser:      rawUser,
		AccessToken:  token.AccessToken,
		RefreshToken: token.RefreshToken,
	}

	user.Expiry, _ = types.ParseDateTime(token.Expiry)

	return user, nil
}

// FetchRawUserInfo implements Provider.FetchRawUserInfo interface method.
//
// Linear doesn't have a UserInfo endpoint and information on the user
// is retrieved using their GraphQL API (https://developers.linear.app/docs/graphql/working-with-the-graphql-api#queries-and-mutations)
func (p *Linear) FetchRawUserInfo(token *oauth2.Token) ([]byte, error) {
	query := []byte(`{"query": "query Me { viewer { id displayName name email avatarUrl active } }"}`)
	bodyReader := bytes.NewReader(query)

	req, err := http.NewRequestWithContext(p.ctx, "POST", p.userInfoURL, bodyReader)
	if err != nil {
		return nil, err
	}
	req.Header.Set("Content-Type", "application/json")

	return p.sendRawUserInfoRequest(req, token)
}
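For context, the GraphQL call that FetchRawUserInfo composes can be reproduced standalone roughly as below; this is a minimal sketch that assumes the token is attached as a standard OAuth2 bearer header (sendRawUserInfoRequest is not shown in this diff), and the access token value is a placeholder:

package main

import (
	"bytes"
	"fmt"
	"io"
	"net/http"
)

func main() {
	// same viewer query as the provider sends
	query := []byte(`{"query": "query Me { viewer { id displayName name email avatarUrl active } }"}`)

	req, err := http.NewRequest("POST", "https://api.linear.app/graphql", bytes.NewReader(query))
	if err != nil {
		panic(err)
	}
	req.Header.Set("Content-Type", "application/json")
	// assumption: bearer auth, as in the common oauth2 transport (placeholder token)
	req.Header.Set("Authorization", "Bearer LINEAR_ACCESS_TOKEN")

	res, err := http.DefaultClient.Do(req)
	if err != nil {
		panic(err)
	}
	defer res.Body.Close()

	body, _ := io.ReadAll(res.Body)
	fmt.Println(res.StatusCode, string(body))
}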
79
tools/auth/livechat.go
Normal file
@@ -0,0 +1,79 @@
package auth

import (
	"context"
	"encoding/json"

	"github.com/pocketbase/pocketbase/tools/types"
	"golang.org/x/oauth2"
)

func init() {
	Providers[NameLivechat] = wrapFactory(NewLivechatProvider)
}

var _ Provider = (*Livechat)(nil)

// NameLivechat is the unique name of the Livechat provider.
const NameLivechat = "livechat"

// Livechat allows authentication via Livechat OAuth2.
type Livechat struct {
	BaseProvider
}

// NewLivechatProvider creates a new Livechat provider instance with some defaults.
func NewLivechatProvider() *Livechat {
	return &Livechat{BaseProvider{
		ctx:         context.Background(),
		displayName: "LiveChat",
		pkce:        true,
		scopes:      []string{}, // default scopes are specified from the provider dashboard
		authURL:     "https://accounts.livechat.com/",
		tokenURL:    "https://accounts.livechat.com/token",
		userInfoURL: "https://accounts.livechat.com/v2/accounts/me",
	}}
}

// FetchAuthUser returns an AuthUser based on the Livechat accounts API.
//
// API reference: https://developers.livechat.com/docs/authorization
func (p *Livechat) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
	data, err := p.FetchRawUserInfo(token)
	if err != nil {
		return nil, err
	}

	rawUser := map[string]any{}
	if err := json.Unmarshal(data, &rawUser); err != nil {
		return nil, err
	}

	extracted := struct {
		Id            string `json:"account_id"`
		Name          string `json:"name"`
		Email         string `json:"email"`
		EmailVerified bool   `json:"email_verified"`
		AvatarURL     string `json:"avatar_url"`
	}{}
	if err := json.Unmarshal(data, &extracted); err != nil {
		return nil, err
	}

	user := &AuthUser{
		Id:           extracted.Id,
		Name:         extracted.Name,
		AvatarURL:    extracted.AvatarURL,
		RawUser:      rawUser,
		AccessToken:  token.AccessToken,
		RefreshToken: token.RefreshToken,
	}

	user.Expiry, _ = types.ParseDateTime(token.Expiry)

	if extracted.EmailVerified {
		user.Email = extracted.Email
	}

	return user, nil
}
84
tools/auth/mailcow.go
Normal file
@@ -0,0 +1,84 @@
package auth

import (
	"context"
	"encoding/json"
	"errors"
	"strings"

	"github.com/pocketbase/pocketbase/tools/types"
	"golang.org/x/oauth2"
)

func init() {
	Providers[NameMailcow] = wrapFactory(NewMailcowProvider)
}

var _ Provider = (*Mailcow)(nil)

// NameMailcow is the unique name of the mailcow provider.
const NameMailcow string = "mailcow"

// Mailcow allows authentication via mailcow OAuth2.
type Mailcow struct {
	BaseProvider
}

// NewMailcowProvider creates a new mailcow provider instance with some defaults.
func NewMailcowProvider() *Mailcow {
	return &Mailcow{BaseProvider{
		ctx:         context.Background(),
		displayName: "mailcow",
		pkce:        true,
		scopes:      []string{"profile"},
	}}
}

// FetchAuthUser returns an AuthUser instance based on mailcow's user api.
//
// API reference: https://github.com/mailcow/mailcow-dockerized/blob/master/data/web/oauth/profile.php
func (p *Mailcow) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
	data, err := p.FetchRawUserInfo(token)
	if err != nil {
		return nil, err
	}

	rawUser := map[string]any{}
	if err := json.Unmarshal(data, &rawUser); err != nil {
		return nil, err
	}

	extracted := struct {
		Id       string `json:"id"`
		Username string `json:"username"`
		Email    string `json:"email"`
		FullName string `json:"full_name"`
		Active   int    `json:"active"`
	}{}
	if err := json.Unmarshal(data, &extracted); err != nil {
		return nil, err
	}

	if extracted.Active != 1 {
		return nil, errors.New("the mailcow user is not active")
	}

	user := &AuthUser{
		Id:           extracted.Id,
		Name:         extracted.FullName,
		Username:     extracted.Username,
		Email:        extracted.Email,
		RawUser:      rawUser,
		AccessToken:  token.AccessToken,
		RefreshToken: token.RefreshToken,
	}

	user.Expiry, _ = types.ParseDateTime(token.Expiry)

	// mailcow usernames are usually just the email addresses, so we take the part in front of the @
	if strings.Contains(user.Username, "@") {
		user.Username = strings.Split(user.Username, "@")[0]
	}

	return user, nil
}
76
tools/auth/microsoft.go
Normal file
@@ -0,0 +1,76 @@
package auth

import (
	"context"
	"encoding/json"

	"github.com/pocketbase/pocketbase/tools/types"
	"golang.org/x/oauth2"
	"golang.org/x/oauth2/microsoft"
)

func init() {
	Providers[NameMicrosoft] = wrapFactory(NewMicrosoftProvider)
}

var _ Provider = (*Microsoft)(nil)

// NameMicrosoft is the unique name of the Microsoft provider.
const NameMicrosoft string = "microsoft"

// Microsoft allows authentication via AzureADEndpoint OAuth2.
type Microsoft struct {
	BaseProvider
}

// NewMicrosoftProvider creates a new Microsoft AD provider instance with some defaults.
func NewMicrosoftProvider() *Microsoft {
	endpoints := microsoft.AzureADEndpoint("")
	return &Microsoft{BaseProvider{
		ctx:         context.Background(),
		displayName: "Microsoft",
		pkce:        true,
		scopes:      []string{"User.Read"},
		authURL:     endpoints.AuthURL,
		tokenURL:    endpoints.TokenURL,
		userInfoURL: "https://graph.microsoft.com/v1.0/me",
	}}
}

// FetchAuthUser returns an AuthUser instance based on Microsoft's user api.
//
// API reference: https://learn.microsoft.com/en-us/azure/active-directory/develop/userinfo
// Graph explorer: https://developer.microsoft.com/en-us/graph/graph-explorer
func (p *Microsoft) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
	data, err := p.FetchRawUserInfo(token)
	if err != nil {
		return nil, err
	}

	rawUser := map[string]any{}
	if err := json.Unmarshal(data, &rawUser); err != nil {
		return nil, err
	}

	extracted := struct {
		Id    string `json:"id"`
		Name  string `json:"displayName"`
		Email string `json:"mail"`
	}{}
	if err := json.Unmarshal(data, &extracted); err != nil {
		return nil, err
	}

	user := &AuthUser{
		Id:           extracted.Id,
		Name:         extracted.Name,
		Email:        extracted.Email,
		RawUser:      rawUser,
		AccessToken:  token.AccessToken,
		RefreshToken: token.RefreshToken,
	}

	user.Expiry, _ = types.ParseDateTime(token.Expiry)

	return user, nil
}
108
tools/auth/monday.go
Normal file
@@ -0,0 +1,108 @@
package auth

import (
	"bytes"
	"context"
	"encoding/json"
	"errors"
	"net/http"

	"github.com/pocketbase/pocketbase/tools/types"
	"golang.org/x/oauth2"
)

func init() {
	Providers[NameMonday] = wrapFactory(NewMondayProvider)
}

var _ Provider = (*Monday)(nil)

// NameMonday is the unique name of the Monday provider.
const NameMonday = "monday"

// Monday is an auth provider for monday.com.
type Monday struct {
	BaseProvider
}

// NewMondayProvider creates a new Monday provider instance with some defaults.
func NewMondayProvider() *Monday {
	return &Monday{BaseProvider{
		ctx:         context.Background(),
		displayName: "monday.com",
		pkce:        true,
		scopes:      []string{"me:read"},
		authURL:     "https://auth.monday.com/oauth2/authorize",
		tokenURL:    "https://auth.monday.com/oauth2/token",
		userInfoURL: "https://api.monday.com/v2",
	}}
}

// FetchAuthUser returns an AuthUser instance based on Monday's user api.
//
// API reference: https://developer.monday.com/api-reference/reference/me
func (p *Monday) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
	data, err := p.FetchRawUserInfo(token)
	if err != nil {
		return nil, err
	}

	rawUser := map[string]any{}
	if err := json.Unmarshal(data, &rawUser); err != nil {
		return nil, err
	}

	extracted := struct {
		Data struct {
			Me struct {
				Id         string `json:"id"`
				Enabled    bool   `json:"enabled"`
				Name       string `json:"name"`
				Email      string `json:"email"`
				IsVerified bool   `json:"is_verified"`
				Avatar     string `json:"photo_small"`
			} `json:"me"`
		} `json:"data"`
	}{}
	if err := json.Unmarshal(data, &extracted); err != nil {
		return nil, err
	}

	if !extracted.Data.Me.Enabled {
		return nil, errors.New("the monday.com user account is not enabled")
	}

	user := &AuthUser{
		Id:           extracted.Data.Me.Id,
		Name:         extracted.Data.Me.Name,
		AvatarURL:    extracted.Data.Me.Avatar,
		RawUser:      rawUser,
		AccessToken:  token.AccessToken,
		RefreshToken: token.RefreshToken,
	}

	if extracted.Data.Me.IsVerified {
		user.Email = extracted.Data.Me.Email
	}

	user.Expiry, _ = types.ParseDateTime(token.Expiry)

	return user, nil
}

// FetchRawUserInfo implements Provider.FetchRawUserInfo interface method.
//
// monday.com doesn't have a UserInfo endpoint and information on the user
// is retrieved using their GraphQL API (https://developer.monday.com/api-reference/reference/me#queries)
func (p *Monday) FetchRawUserInfo(token *oauth2.Token) ([]byte, error) {
	query := []byte(`{"query": "query { me { id enabled name email is_verified photo_small }}"}`)
	bodyReader := bytes.NewReader(query)

	req, err := http.NewRequestWithContext(p.ctx, "POST", p.userInfoURL, bodyReader)
	if err != nil {
		return nil, err
	}
	req.Header.Set("Content-Type", "application/json")

	return p.sendRawUserInfoRequest(req, token)
}
106
tools/auth/notion.go
Normal file
@@ -0,0 +1,106 @@
package auth

import (
	"context"
	"encoding/json"
	"net/http"

	"github.com/pocketbase/pocketbase/tools/types"
	"golang.org/x/oauth2"
)

func init() {
	Providers[NameNotion] = wrapFactory(NewNotionProvider)
}

var _ Provider = (*Notion)(nil)

// NameNotion is the unique name of the Notion provider.
const NameNotion string = "notion"

// Notion allows authentication via Notion OAuth2.
type Notion struct {
	BaseProvider
}

// NewNotionProvider creates a new Notion provider instance with some defaults.
func NewNotionProvider() *Notion {
	return &Notion{BaseProvider{
		ctx:         context.Background(),
		displayName: "Notion",
		pkce:        true,
		authURL:     "https://api.notion.com/v1/oauth/authorize",
		tokenURL:    "https://api.notion.com/v1/oauth/token",
		userInfoURL: "https://api.notion.com/v1/users/me",
	}}
}

// FetchAuthUser returns an AuthUser instance based on Notion's User api.
// API reference: https://developers.notion.com/reference/get-self
func (p *Notion) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
	data, err := p.FetchRawUserInfo(token)
	if err != nil {
		return nil, err
	}

	rawUser := map[string]any{}
	if err := json.Unmarshal(data, &rawUser); err != nil {
		return nil, err
	}

	extracted := struct {
		AvatarURL string `json:"avatar_url"`
		Bot       struct {
			Owner struct {
				Type string `json:"type"`
				User struct {
					AvatarURL string `json:"avatar_url"`
					Id        string `json:"id"`
					Name      string `json:"name"`
					Person    struct {
						Email string `json:"email"`
					} `json:"person"`
				} `json:"user"`
			} `json:"owner"`
			WorkspaceName string `json:"workspace_name"`
		} `json:"bot"`
		Id        string `json:"id"`
		Name      string `json:"name"`
		Object    string `json:"object"`
		RequestId string `json:"request_id"`
		Type      string `json:"type"`
	}{}

	if err := json.Unmarshal(data, &extracted); err != nil {
		return nil, err
	}

	user := &AuthUser{
		Id:           extracted.Bot.Owner.User.Id,
		Name:         extracted.Bot.Owner.User.Name,
		Email:        extracted.Bot.Owner.User.Person.Email,
		AvatarURL:    extracted.Bot.Owner.User.AvatarURL,
		RawUser:      rawUser,
		AccessToken:  token.AccessToken,
		RefreshToken: token.RefreshToken,
	}

	user.Expiry, _ = types.ParseDateTime(token.Expiry)

	return user, nil
}

// FetchRawUserInfo implements Provider.FetchRawUserInfo interface method.
//
// This differs from BaseProvider because Notion requires a version header for all requests
// (https://developers.notion.com/reference/versioning).
func (p *Notion) FetchRawUserInfo(token *oauth2.Token) ([]byte, error) {
	req, err := http.NewRequestWithContext(p.ctx, "GET", p.userInfoURL, nil)
	if err != nil {
		return nil, err
	}

	req.Header.Set("Notion-Version", "2022-06-28")

	return p.sendRawUserInfoRequest(req, token)
}
292
tools/auth/oidc.go
Normal file
@@ -0,0 +1,292 @@
package auth

import (
	"context"
	"crypto/rsa"
	"encoding/base64"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"math/big"
	"net/http"
	"os"
	"strconv"
	"strings"
	"time"

	"github.com/golang-jwt/jwt/v5"
	"github.com/pocketbase/pocketbase/tools/security"
	"github.com/pocketbase/pocketbase/tools/types"
	"github.com/spf13/cast"
	"golang.org/x/oauth2"
)

// idTokenLeeway is the optional leeway for the id_token timestamp fields validation.
//
// It can be changed externally using the PB_ID_TOKEN_LEEWAY env variable
// (the value must be in seconds, e.g. "PB_ID_TOKEN_LEEWAY=60" for 1 minute).
var idTokenLeeway time.Duration = 5 * time.Minute

func init() {
	Providers[NameOIDC] = wrapFactory(NewOIDCProvider)
	Providers[NameOIDC+"2"] = wrapFactory(NewOIDCProvider)
	Providers[NameOIDC+"3"] = wrapFactory(NewOIDCProvider)

	if leewayStr := os.Getenv("PB_ID_TOKEN_LEEWAY"); leewayStr != "" {
		leeway, err := strconv.Atoi(leewayStr)
		if err == nil {
			idTokenLeeway = time.Duration(leeway) * time.Second
		}
	}
}

var _ Provider = (*OIDC)(nil)

// NameOIDC is the unique name of the OpenID Connect (OIDC) provider.
const NameOIDC string = "oidc"

// OIDC allows authentication via OpenID Connect (OIDC) OAuth2 provider.
//
// If specified, the user data is fetched from the userInfoURL.
// Otherwise - from the id_token payload.
//
// The provider supports the following Extra config options:
//   - "jwksURL" - url to the keys to validate the id_token signature (optional and used only when reading the user data from the id_token)
//   - "issuers" - list of valid issuers for the iss id_token claim (optional and used only when reading the user data from the id_token)
type OIDC struct {
	BaseProvider
}

// NewOIDCProvider creates a new OpenID Connect (OIDC) provider instance with some defaults.
func NewOIDCProvider() *OIDC {
	return &OIDC{BaseProvider{
		ctx:         context.Background(),
		displayName: "OIDC",
		pkce:        true,
		scopes: []string{
			"openid", // minimal requirement to return the id
			"email",
			"profile",
		},
	}}
}

// FetchAuthUser returns an AuthUser instance based on the provider's user api.
//
// API reference: https://openid.net/specs/openid-connect-core-1_0.html#StandardClaims
func (p *OIDC) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
	data, err := p.FetchRawUserInfo(token)
	if err != nil {
		return nil, err
	}

	rawUser := map[string]any{}
	if err := json.Unmarshal(data, &rawUser); err != nil {
		return nil, err
	}

	extracted := struct {
		Id            string `json:"sub"`
		Name          string `json:"name"`
		Username      string `json:"preferred_username"`
		Picture       string `json:"picture"`
		Email         string `json:"email"`
		EmailVerified any    `json:"email_verified"` // see #6657
	}{}
	if err := json.Unmarshal(data, &extracted); err != nil {
		return nil, err
	}

	user := &AuthUser{
		Id:           extracted.Id,
		Name:         extracted.Name,
		Username:     extracted.Username,
		AvatarURL:    extracted.Picture,
		RawUser:      rawUser,
		AccessToken:  token.AccessToken,
		RefreshToken: token.RefreshToken,
	}

	user.Expiry, _ = types.ParseDateTime(token.Expiry)

	if cast.ToBool(extracted.EmailVerified) {
		user.Email = extracted.Email
	}

	return user, nil
}

// FetchRawUserInfo implements Provider.FetchRawUserInfo interface method.
//
// It either fetches the data from p.userInfoURL or, if not set, returns the id_token claims.
func (p *OIDC) FetchRawUserInfo(token *oauth2.Token) ([]byte, error) {
	if p.userInfoURL != "" {
		return p.BaseProvider.FetchRawUserInfo(token)
	}

	claims, err := p.parseIdToken(token)
	if err != nil {
		return nil, err
	}

	return json.Marshal(claims)
}

func (p *OIDC) parseIdToken(token *oauth2.Token) (jwt.MapClaims, error) {
	idToken, _ := token.Extra("id_token").(string)
	if idToken == "" {
		return nil, errors.New("empty id_token")
	}

	claims := jwt.MapClaims{}
	t, _, err := jwt.NewParser().ParseUnverified(idToken, claims)
	if err != nil {
		return nil, err
	}

	// validate common claims
	jwtValidator := jwt.NewValidator(
		jwt.WithIssuedAt(),
		jwt.WithLeeway(idTokenLeeway),
		jwt.WithAudience(p.clientId),
	)
	err = jwtValidator.Validate(claims)
	if err != nil {
		return nil, err
	}

	// validate iss (if "issuers" extra config is set)
	issuers := cast.ToStringSlice(p.Extra()["issuers"])
	if len(issuers) > 0 {
		var isIssValid bool
		claimIssuer, _ := claims.GetIssuer()

		for _, issuer := range issuers {
			if security.Equal(claimIssuer, issuer) {
				isIssValid = true
				break
			}
		}

		if !isIssValid {
			return nil, fmt.Errorf("iss must be one of %v, got %#v", issuers, claims["iss"])
		}
	}

	// validate signature (if "jwksURL" extra config is set)
	//
	// note: this step could be technically considered optional because we trust
	// the token which is a result of direct TLS communication with the provider
	// (see also https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation)
	jwksURL := cast.ToString(p.Extra()["jwksURL"])
	if jwksURL != "" {
		kid, _ := t.Header["kid"].(string)
		err = validateIdTokenSignature(p.ctx, idToken, jwksURL, kid)
		if err != nil {
			return nil, err
		}
	}

	return claims, nil
}

func validateIdTokenSignature(ctx context.Context, idToken string, jwksURL string, kid string) error {
	// fetch the public key set
	// ---
	if kid == "" {
		return errors.New("missing kid header value")
	}

	key, err := fetchJWK(ctx, jwksURL, kid)
	if err != nil {
		return err
	}

	// decode the key params per RFC 7518 (https://tools.ietf.org/html/rfc7518#section-6.3)
	// and construct a valid publicKey from them
	// ---
	exponent, err := base64.RawURLEncoding.DecodeString(strings.TrimRight(key.E, "="))
	if err != nil {
		return err
	}

	modulus, err := base64.RawURLEncoding.DecodeString(strings.TrimRight(key.N, "="))
	if err != nil {
		return err
	}

	publicKey := &rsa.PublicKey{
		// https://tools.ietf.org/html/rfc7517#appendix-A.1
		E: int(big.NewInt(0).SetBytes(exponent).Uint64()),
		N: big.NewInt(0).SetBytes(modulus),
	}

	// verify the signature
	// ---
	parser := jwt.NewParser(jwt.WithValidMethods([]string{key.Alg}))

	parsedToken, err := parser.Parse(idToken, func(t *jwt.Token) (any, error) {
		return publicKey, nil
	})
	if err != nil {
		return err
	}

	if !parsedToken.Valid {
		return errors.New("the parsed id_token is invalid")
	}

	return nil
}

type jwk struct {
	Kty string
	Kid string
	Use string
	Alg string
	N   string
	E   string
}

func fetchJWK(ctx context.Context, jwksURL string, kid string) (*jwk, error) {
	req, err := http.NewRequestWithContext(ctx, "GET", jwksURL, nil)
	if err != nil {
		return nil, err
	}

	res, err := http.DefaultClient.Do(req)
	if err != nil {
		return nil, err
	}
	defer res.Body.Close()

	rawBody, err := io.ReadAll(res.Body)
	if err != nil {
		return nil, err
	}

	// http.Client.Get doesn't treat non 2xx responses as errors
	if res.StatusCode >= 400 {
		return nil, fmt.Errorf(
			"failed to verify the provided id_token (%d):\n%s",
			res.StatusCode,
			string(rawBody),
		)
	}

	jwks := struct {
		Keys []*jwk
	}{}
	if err := json.Unmarshal(rawBody, &jwks); err != nil {
		return nil, err
	}

	for _, key := range jwks.Keys {
		if key.Kid == kid {
			return key, nil
		}
	}

	return nil, fmt.Errorf("jwk with kid %q was not found", kid)
}
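To make the key handling above concrete, a self-contained sketch of turning a JWK's base64url-encoded modulus and exponent into an rsa.PublicKey, mirroring validateIdTokenSignature; the sample N/E values are illustrative only (a real modulus is much longer):

package main

import (
	"crypto/rsa"
	"encoding/base64"
	"fmt"
	"math/big"
	"strings"
)

func main() {
	// illustrative JWK params; per RFC 7518 both values are
	// base64url-encoded big-endian integers without padding
	n := "u1SU1LfVLPHCozMxH2Mo" // made-up short modulus, for demonstration only
	e := "AQAB"                 // the common public exponent 65537

	exponent, err := base64.RawURLEncoding.DecodeString(strings.TrimRight(e, "="))
	if err != nil {
		panic(err)
	}

	modulus, err := base64.RawURLEncoding.DecodeString(strings.TrimRight(n, "="))
	if err != nil {
		panic(err)
	}

	publicKey := &rsa.PublicKey{
		E: int(big.NewInt(0).SetBytes(exponent).Uint64()),
		N: big.NewInt(0).SetBytes(modulus),
	}

	fmt.Println("exponent:", publicKey.E, "modulus bits:", publicKey.N.BitLen())
}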
88
tools/auth/patreon.go
Normal file
@@ -0,0 +1,88 @@
package auth

import (
	"context"
	"encoding/json"

	"github.com/pocketbase/pocketbase/tools/types"
	"golang.org/x/oauth2"
	"golang.org/x/oauth2/endpoints"
)

func init() {
	Providers[NamePatreon] = wrapFactory(NewPatreonProvider)
}

var _ Provider = (*Patreon)(nil)

// NamePatreon is the unique name of the Patreon provider.
const NamePatreon string = "patreon"

// Patreon allows authentication via Patreon OAuth2.
type Patreon struct {
	BaseProvider
}

// NewPatreonProvider creates a new Patreon provider instance with some defaults.
func NewPatreonProvider() *Patreon {
	return &Patreon{BaseProvider{
		ctx:         context.Background(),
		displayName: "Patreon",
		pkce:        true,
		scopes:      []string{"identity", "identity[email]"},
		authURL:     endpoints.Patreon.AuthURL,
		tokenURL:    endpoints.Patreon.TokenURL,
		userInfoURL: "https://www.patreon.com/api/oauth2/v2/identity?fields%5Buser%5D=full_name,email,vanity,image_url,is_email_verified",
	}}
}

// FetchAuthUser returns an AuthUser instance based on Patreon's identity api.
//
// API reference:
// https://docs.patreon.com/#get-api-oauth2-v2-identity
// https://docs.patreon.com/#user-v2
func (p *Patreon) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
	data, err := p.FetchRawUserInfo(token)
	if err != nil {
		return nil, err
	}

	rawUser := map[string]any{}
	if err := json.Unmarshal(data, &rawUser); err != nil {
		return nil, err
	}

	extracted := struct {
		Data struct {
			Id         string `json:"id"`
			Attributes struct {
				Email           string `json:"email"`
				Name            string `json:"full_name"`
				Username        string `json:"vanity"`
				AvatarURL       string `json:"image_url"`
				IsEmailVerified bool   `json:"is_email_verified"`
			} `json:"attributes"`
		} `json:"data"`
	}{}
	if err := json.Unmarshal(data, &extracted); err != nil {
		return nil, err
	}

	user := &AuthUser{
		Id:           extracted.Data.Id,
		Username:     extracted.Data.Attributes.Username,
		Name:         extracted.Data.Attributes.Name,
		AvatarURL:    extracted.Data.Attributes.AvatarURL,
		RawUser:      rawUser,
		AccessToken:  token.AccessToken,
		RefreshToken: token.RefreshToken,
	}

	user.Expiry, _ = types.ParseDateTime(token.Expiry)

	if extracted.Data.Attributes.IsEmailVerified {
		user.Email = extracted.Data.Attributes.Email
	}

	return user, nil
}
85
tools/auth/planningcenter.go
Normal file
@@ -0,0 +1,85 @@
package auth

import (
	"context"
	"encoding/json"
	"errors"

	"github.com/pocketbase/pocketbase/tools/types"
	"golang.org/x/oauth2"
)

func init() {
	Providers[NamePlanningcenter] = wrapFactory(NewPlanningcenterProvider)
}

var _ Provider = (*Planningcenter)(nil)

// NamePlanningcenter is the unique name of the Planningcenter provider.
const NamePlanningcenter string = "planningcenter"

// Planningcenter allows authentication via Planningcenter OAuth2.
type Planningcenter struct {
	BaseProvider
}

// NewPlanningcenterProvider creates a new Planningcenter provider instance with some defaults.
func NewPlanningcenterProvider() *Planningcenter {
	return &Planningcenter{BaseProvider{
		ctx:         context.Background(),
		displayName: "Planning Center",
		pkce:        true,
		scopes:      []string{"people"},
		authURL:     "https://api.planningcenteronline.com/oauth/authorize",
		tokenURL:    "https://api.planningcenteronline.com/oauth/token",
		userInfoURL: "https://api.planningcenteronline.com/people/v2/me",
	}}
}

// FetchAuthUser returns an AuthUser instance based on Planningcenter's user api.
//
// API reference: https://developer.planning.center/docs/#/overview/authentication
func (p *Planningcenter) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
	data, err := p.FetchRawUserInfo(token)
	if err != nil {
		return nil, err
	}

	rawUser := map[string]any{}
	if err := json.Unmarshal(data, &rawUser); err != nil {
		return nil, err
	}

	extracted := struct {
		Data struct {
			Id         string `json:"id"`
			Attributes struct {
				Status    string `json:"status"`
				Name      string `json:"name"`
				AvatarURL string `json:"avatar"`
				// don't map the email because users can have multiple assigned
				// and it's not clear if they are verified
			}
		}
	}{}
	if err := json.Unmarshal(data, &extracted); err != nil {
		return nil, err
	}

	if extracted.Data.Attributes.Status != "active" {
		return nil, errors.New("the user is not active")
	}

	user := &AuthUser{
		Id:           extracted.Data.Id,
		Name:         extracted.Data.Attributes.Name,
		AvatarURL:    extracted.Data.Attributes.AvatarURL,
		RawUser:      rawUser,
		AccessToken:  token.AccessToken,
		RefreshToken: token.RefreshToken,
	}

	user.Expiry, _ = types.ParseDateTime(token.Expiry)

	return user, nil
}
87
tools/auth/spotify.go
Normal file
@@ -0,0 +1,87 @@
package auth

import (
	"context"
	"encoding/json"

	"github.com/pocketbase/pocketbase/tools/types"
	"golang.org/x/oauth2"
	"golang.org/x/oauth2/spotify"
)

func init() {
	Providers[NameSpotify] = wrapFactory(NewSpotifyProvider)
}

var _ Provider = (*Spotify)(nil)

// NameSpotify is the unique name of the Spotify provider.
const NameSpotify string = "spotify"

// Spotify allows authentication via Spotify OAuth2.
type Spotify struct {
	BaseProvider
}

// NewSpotifyProvider creates a new Spotify provider instance with some defaults.
func NewSpotifyProvider() *Spotify {
	return &Spotify{BaseProvider{
		ctx:         context.Background(),
		displayName: "Spotify",
		pkce:        true,
		scopes: []string{
			"user-read-private",
			// currently Spotify doesn't return information whether the email is verified or not
			// "user-read-email",
		},
		authURL:     spotify.Endpoint.AuthURL,
		tokenURL:    spotify.Endpoint.TokenURL,
		userInfoURL: "https://api.spotify.com/v1/me",
	}}
}

// FetchAuthUser returns an AuthUser instance based on Spotify's user api.
//
// API reference: https://developer.spotify.com/documentation/web-api/reference/#/operations/get-current-users-profile
func (p *Spotify) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
	data, err := p.FetchRawUserInfo(token)
	if err != nil {
		return nil, err
	}

	rawUser := map[string]any{}
	if err := json.Unmarshal(data, &rawUser); err != nil {
		return nil, err
	}

	extracted := struct {
		Id     string `json:"id"`
		Name   string `json:"display_name"`
		Images []struct {
			URL string `json:"url"`
		} `json:"images"`
		// don't map the email because per the official docs
		// the email field is "unverified" and there is no proof
		// that it actually belongs to the user
		// Email string `json:"email"`
	}{}
	if err := json.Unmarshal(data, &extracted); err != nil {
		return nil, err
	}

	user := &AuthUser{
		Id:           extracted.Id,
		Name:         extracted.Name,
		RawUser:      rawUser,
		AccessToken:  token.AccessToken,
		RefreshToken: token.RefreshToken,
	}

	user.Expiry, _ = types.ParseDateTime(token.Expiry)

	if len(extracted.Images) > 0 {
		user.AvatarURL = extracted.Images[0].URL
	}

	return user, nil
}
85
tools/auth/strava.go
Normal file
@@ -0,0 +1,85 @@
package auth

import (
	"context"
	"encoding/json"
	"strconv"

	"github.com/pocketbase/pocketbase/tools/types"
	"golang.org/x/oauth2"
)

func init() {
	Providers[NameStrava] = wrapFactory(NewStravaProvider)
}

var _ Provider = (*Strava)(nil)

// NameStrava is the unique name of the Strava provider.
const NameStrava string = "strava"

// Strava allows authentication via Strava OAuth2.
type Strava struct {
	BaseProvider
}

// NewStravaProvider creates a new Strava provider instance with some defaults.
func NewStravaProvider() *Strava {
	return &Strava{BaseProvider{
		ctx:         context.Background(),
		displayName: "Strava",
		pkce:        true,
		scopes: []string{
			"profile:read_all",
		},
		authURL:     "https://www.strava.com/oauth/authorize",
		tokenURL:    "https://www.strava.com/api/v3/oauth/token",
		userInfoURL: "https://www.strava.com/api/v3/athlete",
	}}
}

// FetchAuthUser returns an AuthUser instance based on Strava's user api.
//
// API reference: https://developers.strava.com/docs/authentication/
func (p *Strava) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
	data, err := p.FetchRawUserInfo(token)
	if err != nil {
		return nil, err
	}

	rawUser := map[string]any{}
	if err := json.Unmarshal(data, &rawUser); err != nil {
		return nil, err
	}

	extracted := struct {
		Id              int64  `json:"id"`
		FirstName       string `json:"firstname"`
		LastName        string `json:"lastname"`
		Username        string `json:"username"`
		ProfileImageURL string `json:"profile"`

		// At the time of writing, Strava OAuth2 doesn't support returning the user email address
		// Email string `json:"email"`
	}{}
	if err := json.Unmarshal(data, &extracted); err != nil {
		return nil, err
	}

	user := &AuthUser{
		Name:         extracted.FirstName + " " + extracted.LastName,
		Username:     extracted.Username,
		AvatarURL:    extracted.ProfileImageURL,
		RawUser:      rawUser,
		AccessToken:  token.AccessToken,
		RefreshToken: token.RefreshToken,
	}

	user.Expiry, _ = types.ParseDateTime(token.Expiry)

	if extracted.Id != 0 {
		user.Id = strconv.FormatInt(extracted.Id, 10)
	}

	return user, nil
}
102
tools/auth/trakt.go
Normal file
@@ -0,0 +1,102 @@
package auth

import (
    "context"
    "encoding/json"
    "net/http"

    "github.com/pocketbase/pocketbase/tools/types"
    "golang.org/x/oauth2"
)

func init() {
    Providers[NameTrakt] = wrapFactory(NewTraktProvider)
}

var _ Provider = (*Trakt)(nil)

// NameTrakt is the unique name of the Trakt provider.
const NameTrakt string = "trakt"

// Trakt allows authentication via Trakt OAuth2.
type Trakt struct {
    BaseProvider
}

// NewTraktProvider creates a new Trakt provider instance with some defaults.
func NewTraktProvider() *Trakt {
    return &Trakt{BaseProvider{
        ctx:         context.Background(),
        displayName: "Trakt",
        pkce:        true,
        authURL:     "https://trakt.tv/oauth/authorize",
        tokenURL:    "https://api.trakt.tv/oauth/token",
        userInfoURL: "https://api.trakt.tv/users/settings",
    }}
}

// FetchAuthUser returns an AuthUser instance based on Trakt's user settings API.
// API reference: https://trakt.docs.apiary.io/#reference/users/settings/retrieve-settings
func (p *Trakt) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
    data, err := p.FetchRawUserInfo(token)
    if err != nil {
        return nil, err
    }

    rawUser := map[string]any{}
    if err := json.Unmarshal(data, &rawUser); err != nil {
        return nil, err
    }

    extracted := struct {
        User struct {
            Username string `json:"username"`
            Name     string `json:"name"`
            Ids      struct {
                Slug string `json:"slug"`
                UUID string `json:"uuid"`
            } `json:"ids"`
            Images struct {
                Avatar struct {
                    Full string `json:"full"`
                } `json:"avatar"`
            } `json:"images"`
        } `json:"user"`
    }{}

    if err := json.Unmarshal(data, &extracted); err != nil {
        return nil, err
    }

    user := &AuthUser{
        Id:           extracted.User.Ids.UUID,
        Username:     extracted.User.Username,
        Name:         extracted.User.Name,
        AvatarURL:    extracted.User.Images.Avatar.Full,
        RawUser:      rawUser,
        AccessToken:  token.AccessToken,
        RefreshToken: token.RefreshToken,
    }

    user.Expiry, _ = types.ParseDateTime(token.Expiry)

    return user, nil
}

// FetchRawUserInfo implements Provider.FetchRawUserInfo interface method.
//
// This differs from BaseProvider because Trakt requires a number of
// mandatory headers for all requests
// (https://trakt.docs.apiary.io/#introduction/required-headers).
func (p *Trakt) FetchRawUserInfo(token *oauth2.Token) ([]byte, error) {
    req, err := http.NewRequestWithContext(p.ctx, "GET", p.userInfoURL, nil)
    if err != nil {
        return nil, err
    }

    req.Header.Set("Content-type", "application/json")
    req.Header.Set("trakt-api-key", p.clientId)
    req.Header.Set("trakt-api-version", "2")

    return p.sendRawUserInfoRequest(req, token)
}
100
tools/auth/twitch.go
Normal file
@@ -0,0 +1,100 @@
package auth

import (
    "context"
    "encoding/json"
    "errors"
    "net/http"

    "github.com/pocketbase/pocketbase/tools/types"
    "golang.org/x/oauth2"
    "golang.org/x/oauth2/twitch"
)

func init() {
    Providers[NameTwitch] = wrapFactory(NewTwitchProvider)
}

var _ Provider = (*Twitch)(nil)

// NameTwitch is the unique name of the Twitch provider.
const NameTwitch string = "twitch"

// Twitch allows authentication via Twitch OAuth2.
type Twitch struct {
    BaseProvider
}

// NewTwitchProvider creates a new Twitch provider instance with some defaults.
func NewTwitchProvider() *Twitch {
    return &Twitch{BaseProvider{
        ctx:         context.Background(),
        displayName: "Twitch",
        pkce:        true,
        scopes:      []string{"user:read:email"},
        authURL:     twitch.Endpoint.AuthURL,
        tokenURL:    twitch.Endpoint.TokenURL,
        userInfoURL: "https://api.twitch.tv/helix/users",
    }}
}

// FetchAuthUser returns an AuthUser instance based on Twitch's user API.
//
// API reference: https://dev.twitch.tv/docs/api/reference#get-users
func (p *Twitch) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
    data, err := p.FetchRawUserInfo(token)
    if err != nil {
        return nil, err
    }

    rawUser := map[string]any{}
    if err := json.Unmarshal(data, &rawUser); err != nil {
        return nil, err
    }

    extracted := struct {
        Data []struct {
            Id              string `json:"id"`
            Login           string `json:"login"`
            DisplayName     string `json:"display_name"`
            Email           string `json:"email"`
            ProfileImageURL string `json:"profile_image_url"`
        } `json:"data"`
    }{}
    if err := json.Unmarshal(data, &extracted); err != nil {
        return nil, err
    }

    if len(extracted.Data) == 0 {
        return nil, errors.New("failed to fetch AuthUser data")
    }

    user := &AuthUser{
        Id:           extracted.Data[0].Id,
        Name:         extracted.Data[0].DisplayName,
        Username:     extracted.Data[0].Login,
        Email:        extracted.Data[0].Email,
        AvatarURL:    extracted.Data[0].ProfileImageURL,
        RawUser:      rawUser,
        AccessToken:  token.AccessToken,
        RefreshToken: token.RefreshToken,
    }

    user.Expiry, _ = types.ParseDateTime(token.Expiry)

    return user, nil
}

// FetchRawUserInfo implements Provider.FetchRawUserInfo interface method.
//
// This differs from BaseProvider because Twitch requires the Client-Id header.
func (p *Twitch) FetchRawUserInfo(token *oauth2.Token) ([]byte, error) {
    req, err := http.NewRequestWithContext(p.ctx, "GET", p.userInfoURL, nil)
    if err != nil {
        return nil, err
    }

    req.Header.Set("Client-Id", p.clientId)

    return p.sendRawUserInfoRequest(req, token)
}
87
tools/auth/twitter.go
Normal file
@@ -0,0 +1,87 @@
package auth

import (
    "context"
    "encoding/json"

    "github.com/pocketbase/pocketbase/tools/types"
    "golang.org/x/oauth2"
)

func init() {
    Providers[NameTwitter] = wrapFactory(NewTwitterProvider)
}

var _ Provider = (*Twitter)(nil)

// NameTwitter is the unique name of the Twitter provider.
const NameTwitter string = "twitter"

// Twitter allows authentication via Twitter OAuth2.
type Twitter struct {
    BaseProvider
}

// NewTwitterProvider creates a new Twitter provider instance with some defaults.
func NewTwitterProvider() *Twitter {
    return &Twitter{BaseProvider{
        ctx:         context.Background(),
        displayName: "Twitter",
        pkce:        true,
        scopes: []string{
            "users.read",

            // we don't actually use this scope, but for some reason it is required by the `/2/users/me` endpoint
            // (see https://developer.twitter.com/en/docs/twitter-api/users/lookup/api-reference/get-users-me)
            "tweet.read",
        },
        authURL:     "https://twitter.com/i/oauth2/authorize",
        tokenURL:    "https://api.twitter.com/2/oauth2/token",
        userInfoURL: "https://api.twitter.com/2/users/me?user.fields=id,name,username,profile_image_url",
    }}
}

// FetchAuthUser returns an AuthUser instance based on Twitter's user API.
//
// API reference: https://developer.twitter.com/en/docs/twitter-api/users/lookup/api-reference/get-users-me
func (p *Twitter) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
    data, err := p.FetchRawUserInfo(token)
    if err != nil {
        return nil, err
    }

    rawUser := map[string]any{}
    if err := json.Unmarshal(data, &rawUser); err != nil {
        return nil, err
    }

    extracted := struct {
        Data struct {
            Id              string `json:"id"`
            Name            string `json:"name"`
            Username        string `json:"username"`
            ProfileImageURL string `json:"profile_image_url"`

            // NB! At the time of writing, Twitter OAuth2 doesn't support returning the user email address
            // (see https://twittercommunity.com/t/which-api-to-get-user-after-oauth2-authorization/162417/33)
            // Email string `json:"email"`
        } `json:"data"`
    }{}
    if err := json.Unmarshal(data, &extracted); err != nil {
        return nil, err
    }

    user := &AuthUser{
        Id:           extracted.Data.Id,
        Name:         extracted.Data.Name,
        Username:     extracted.Data.Username,
        AvatarURL:    extracted.Data.ProfileImageURL,
        RawUser:      rawUser,
        AccessToken:  token.AccessToken,
        RefreshToken: token.RefreshToken,
    }

    user.Expiry, _ = types.ParseDateTime(token.Expiry)

    return user, nil
}
94
tools/auth/vk.go
Normal file
@@ -0,0 +1,94 @@
package auth

import (
    "context"
    "encoding/json"
    "errors"
    "fmt"
    "strconv"
    "strings"

    "github.com/pocketbase/pocketbase/tools/types"
    "golang.org/x/oauth2"
    "golang.org/x/oauth2/vk"
)

func init() {
    Providers[NameVK] = wrapFactory(NewVKProvider)
}

var _ Provider = (*VK)(nil)

// NameVK is the unique name of the VK provider.
const NameVK string = "vk"

// VK allows authentication via VK OAuth2.
type VK struct {
    BaseProvider
}

// NewVKProvider creates a new VK provider instance with some defaults.
//
// Docs: https://dev.vk.com/api/oauth-parameters
func NewVKProvider() *VK {
    return &VK{BaseProvider{
        ctx:         context.Background(),
        displayName: "ВКонтакте",
        pkce:        false, // VK currently doesn't support PKCE and throws an error if PKCE params are sent
        scopes:      []string{"email"},
        authURL:     vk.Endpoint.AuthURL,
        tokenURL:    vk.Endpoint.TokenURL,
        userInfoURL: "https://api.vk.com/method/users.get?fields=photo_max,screen_name&v=5.131",
    }}
}

// FetchAuthUser returns an AuthUser instance based on VK's user API.
//
// API reference: https://dev.vk.com/method/users.get
func (p *VK) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
    data, err := p.FetchRawUserInfo(token)
    if err != nil {
        return nil, err
    }

    rawUser := map[string]any{}
    if err := json.Unmarshal(data, &rawUser); err != nil {
        return nil, err
    }

    extracted := struct {
        Response []struct {
            Id        int64  `json:"id"`
            FirstName string `json:"first_name"`
            LastName  string `json:"last_name"`
            Username  string `json:"screen_name"`
            AvatarURL string `json:"photo_max"`
        } `json:"response"`
    }{}

    if err := json.Unmarshal(data, &extracted); err != nil {
        return nil, err
    }

    if len(extracted.Response) == 0 {
        return nil, errors.New("missing response entry")
    }

    user := &AuthUser{
        Id:           strconv.FormatInt(extracted.Response[0].Id, 10),
        Name:         strings.TrimSpace(extracted.Response[0].FirstName + " " + extracted.Response[0].LastName),
        Username:     extracted.Response[0].Username,
        AvatarURL:    extracted.Response[0].AvatarURL,
        RawUser:      rawUser,
        AccessToken:  token.AccessToken,
        RefreshToken: token.RefreshToken,
    }

    user.Expiry, _ = types.ParseDateTime(token.Expiry)

    if email := token.Extra("email"); email != nil {
        user.Email = fmt.Sprint(email)
    }

    return user, nil
}
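VK is the one provider in this batch that reads the user email from the token response body rather than from the user info endpoint, via oauth2's Token.Extra. A small self-contained sketch of that mechanism follows; the token here is hand-built, whereas a real token returned by a code exchange carries the extra fields automatically.

package main

import (
    "fmt"

    "golang.org/x/oauth2"
)

func main() {
    // WithExtra attaches raw token-endpoint response fields to the token.
    token := (&oauth2.Token{AccessToken: "demo"}).WithExtra(map[string]interface{}{
        "email": "user@example.com", // placeholder value
    })

    // Extra returns nil when the field is absent, hence the guard,
    // mirroring the check in FetchAuthUser above.
    if email := token.Extra("email"); email != nil {
        fmt.Println("email:", fmt.Sprint(email))
    }
}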
90
tools/auth/wakatime.go
Normal file
@@ -0,0 +1,90 @@
package auth

import (
    "context"
    "encoding/json"

    "github.com/pocketbase/pocketbase/tools/types"
    "golang.org/x/oauth2"
)

func init() {
    Providers[NameWakatime] = wrapFactory(NewWakatimeProvider)
}

var _ Provider = (*Wakatime)(nil)

// NameWakatime is the unique name of the Wakatime provider.
const NameWakatime = "wakatime"

// Wakatime is an auth provider for Wakatime.
type Wakatime struct {
    BaseProvider
}

// NewWakatimeProvider creates a new Wakatime provider instance with some defaults.
func NewWakatimeProvider() *Wakatime {
    return &Wakatime{BaseProvider{
        ctx:         context.Background(),
        displayName: "WakaTime",
        pkce:        true,
        scopes:      []string{"email"},
        authURL:     "https://wakatime.com/oauth/authorize",
        tokenURL:    "https://wakatime.com/oauth/token",
        userInfoURL: "https://wakatime.com/api/v1/users/current",
    }}
}

// FetchAuthUser returns an AuthUser instance based on WakaTime's user API.
//
// API reference: https://wakatime.com/developers#users
func (p *Wakatime) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
    data, err := p.FetchRawUserInfo(token)
    if err != nil {
        return nil, err
    }

    rawUser := map[string]any{}
    if err := json.Unmarshal(data, &rawUser); err != nil {
        return nil, err
    }

    extracted := struct {
        Data struct {
            Id               string `json:"id"`
            DisplayName      string `json:"display_name"`
            Username         string `json:"username"`
            Email            string `json:"email"`
            Photo            string `json:"photo"`
            IsPhotoPublic    bool   `json:"photo_public"`
            IsEmailConfirmed bool   `json:"is_email_confirmed"`
        } `json:"data"`
    }{}

    if err := json.Unmarshal(data, &extracted); err != nil {
        return nil, err
    }

    user := &AuthUser{
        Id:           extracted.Data.Id,
        Name:         extracted.Data.DisplayName,
        Username:     extracted.Data.Username,
        RawUser:      rawUser,
        AccessToken:  token.AccessToken,
        RefreshToken: token.RefreshToken,
    }

    // note: we don't check for is_email_public field because PocketBase
    // has its own emailVisibility flag which is false by default
    if extracted.Data.IsEmailConfirmed {
        user.Email = extracted.Data.Email
    }

    if extracted.Data.IsPhotoPublic {
        user.AvatarURL = extracted.Data.Photo
    }

    user.Expiry, _ = types.ParseDateTime(token.Expiry)

    return user, nil
}
84
tools/auth/yandex.go
Normal file
@@ -0,0 +1,84 @@
package auth

import (
    "context"
    "encoding/json"

    "github.com/pocketbase/pocketbase/tools/types"
    "golang.org/x/oauth2"
    "golang.org/x/oauth2/yandex"
)

func init() {
    Providers[NameYandex] = wrapFactory(NewYandexProvider)
}

var _ Provider = (*Yandex)(nil)

// NameYandex is the unique name of the Yandex provider.
const NameYandex string = "yandex"

// Yandex allows authentication via Yandex OAuth2.
type Yandex struct {
    BaseProvider
}

// NewYandexProvider creates a new Yandex provider instance with some defaults.
//
// Docs: https://yandex.ru/dev/id/doc/en/
func NewYandexProvider() *Yandex {
    return &Yandex{BaseProvider{
        ctx:         context.Background(),
        displayName: "Yandex",
        pkce:        true,
        scopes:      []string{"login:email", "login:avatar", "login:info"},
        authURL:     yandex.Endpoint.AuthURL,
        tokenURL:    yandex.Endpoint.TokenURL,
        userInfoURL: "https://login.yandex.ru/info",
    }}
}

// FetchAuthUser returns an AuthUser instance based on Yandex's user API.
//
// API reference: https://yandex.ru/dev/id/doc/en/user-information#response-format
func (p *Yandex) FetchAuthUser(token *oauth2.Token) (*AuthUser, error) {
    data, err := p.FetchRawUserInfo(token)
    if err != nil {
        return nil, err
    }

    rawUser := map[string]any{}
    if err := json.Unmarshal(data, &rawUser); err != nil {
        return nil, err
    }

    extracted := struct {
        Id            string `json:"id"`
        Name          string `json:"real_name"`
        Username      string `json:"login"`
        Email         string `json:"default_email"`
        IsAvatarEmpty bool   `json:"is_avatar_empty"`
        AvatarId      string `json:"default_avatar_id"`
    }{}
    if err := json.Unmarshal(data, &extracted); err != nil {
        return nil, err
    }

    user := &AuthUser{
        Id:           extracted.Id,
        Name:         extracted.Name,
        Username:     extracted.Username,
        Email:        extracted.Email,
        RawUser:      rawUser,
        AccessToken:  token.AccessToken,
        RefreshToken: token.RefreshToken,
    }

    user.Expiry, _ = types.ParseDateTime(token.Expiry)

    if !extracted.IsAvatarEmpty {
        user.AvatarURL = "https://avatars.yandex.net/get-yapic/" + extracted.AvatarId + "/islands-200"
    }

    return user, nil
}
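A worked instance of the avatar URL construction above, using a made-up default_avatar_id value purely for illustration:

package main

import "fmt"

func main() {
    avatarId := "31804/example" // hypothetical default_avatar_id, for illustration only
    fmt.Println("https://avatars.yandex.net/get-yapic/" + avatarId + "/islands-200")
    // -> https://avatars.yandex.net/get-yapic/31804/example/islands-200
}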
228
tools/cron/cron.go
Normal file
@@ -0,0 +1,228 @@
// Package cron implements a crontab-like service to execute and schedule
// repetitive tasks/jobs.
//
// Example:
//
//    c := cron.New()
//    c.MustAdd("dailyReport", "0 0 * * *", func() { ... })
//    c.Start()
package cron

import (
    "errors"
    "fmt"
    "slices"
    "sync"
    "time"
)

// Cron is a crontab-like struct for tasks/jobs scheduling.
type Cron struct {
    timezone   *time.Location
    ticker     *time.Ticker
    startTimer *time.Timer
    tickerDone chan bool
    jobs       []*Job
    interval   time.Duration
    mux        sync.RWMutex
}

// New creates a new Cron struct with default tick interval of 1 minute
// and timezone in UTC.
//
// You can change the default tick interval with Cron.SetInterval().
// You can change the default timezone with Cron.SetTimezone().
func New() *Cron {
    return &Cron{
        interval:   1 * time.Minute,
        timezone:   time.UTC,
        jobs:       []*Job{},
        tickerDone: make(chan bool),
    }
}

// SetInterval changes the current cron tick interval
// (it usually should be >= 1 minute).
func (c *Cron) SetInterval(d time.Duration) {
    // update interval
    c.mux.Lock()
    wasStarted := c.ticker != nil
    c.interval = d
    c.mux.Unlock()

    // restart the ticker
    if wasStarted {
        c.Start()
    }
}

// SetTimezone changes the current cron tick timezone.
func (c *Cron) SetTimezone(l *time.Location) {
    c.mux.Lock()
    defer c.mux.Unlock()

    c.timezone = l
}

// MustAdd is similar to Add() but panics on failure.
func (c *Cron) MustAdd(jobId string, cronExpr string, run func()) {
    if err := c.Add(jobId, cronExpr, run); err != nil {
        panic(err)
    }
}

// Add registers a single cron job.
//
// If there is already a job with the provided id, then the old job
// will be replaced with the new one.
//
// cronExpr is a regular cron expression, eg. "0 */3 * * *" (aka. at minute 0 past every 3rd hour).
// Check cron.NewSchedule() for the supported tokens.
func (c *Cron) Add(jobId string, cronExpr string, fn func()) error {
    if fn == nil {
        return errors.New("failed to add new cron job: fn must be non-nil function")
    }

    schedule, err := NewSchedule(cronExpr)
    if err != nil {
        return fmt.Errorf("failed to add new cron job: %w", err)
    }

    c.mux.Lock()
    defer c.mux.Unlock()

    // remove previous (if any)
    c.jobs = slices.DeleteFunc(c.jobs, func(j *Job) bool {
        return j.Id() == jobId
    })

    // add new
    c.jobs = append(c.jobs, &Job{
        id:       jobId,
        fn:       fn,
        schedule: schedule,
    })

    return nil
}

// Remove removes a single cron job by its id.
func (c *Cron) Remove(jobId string) {
    c.mux.Lock()
    defer c.mux.Unlock()

    if c.jobs == nil {
        return // nothing to remove
    }

    c.jobs = slices.DeleteFunc(c.jobs, func(j *Job) bool {
        return j.Id() == jobId
    })
}

// RemoveAll removes all registered cron jobs.
func (c *Cron) RemoveAll() {
    c.mux.Lock()
    defer c.mux.Unlock()

    c.jobs = []*Job{}
}

// Total returns the current total number of registered cron jobs.
func (c *Cron) Total() int {
    c.mux.RLock()
    defer c.mux.RUnlock()

    return len(c.jobs)
}

// Jobs returns a shallow copy of the currently registered cron jobs.
func (c *Cron) Jobs() []*Job {
    c.mux.RLock()
    defer c.mux.RUnlock()

    copy := make([]*Job, len(c.jobs))
    for i, j := range c.jobs {
        copy[i] = j
    }

    return copy
}

// Stop stops the current cron ticker (if not already).
//
// You can resume the ticker by calling Start().
func (c *Cron) Stop() {
    c.mux.Lock()
    defer c.mux.Unlock()

    if c.startTimer != nil {
        c.startTimer.Stop()
        c.startTimer = nil
    }

    if c.ticker == nil {
        return // already stopped
    }

    c.tickerDone <- true
    c.ticker.Stop()
    c.ticker = nil
}

// Start starts the cron ticker.
//
// Calling Start() on already started cron will restart the ticker.
func (c *Cron) Start() {
    c.Stop()

    // delay the ticker to start at 00 of 1 c.interval duration
    now := time.Now()
    next := now.Add(c.interval).Truncate(c.interval)
    delay := next.Sub(now)

    c.mux.Lock()
    c.startTimer = time.AfterFunc(delay, func() {
        c.mux.Lock()
        c.ticker = time.NewTicker(c.interval)
        c.mux.Unlock()

        // run immediately at 00
        c.runDue(time.Now())

        // run after each tick
        go func() {
            for {
                select {
                case <-c.tickerDone:
                    return
                case t := <-c.ticker.C:
                    c.runDue(t)
                }
            }
        }()
    })
    c.mux.Unlock()
}

// HasStarted checks whether the current Cron ticker has been started.
func (c *Cron) HasStarted() bool {
    c.mux.RLock()
    defer c.mux.RUnlock()

    return c.ticker != nil
}

// runDue runs all registered jobs that are scheduled for the provided time.
func (c *Cron) runDue(t time.Time) {
    c.mux.RLock()
    defer c.mux.RUnlock()

    moment := NewMoment(t.In(c.timezone))

    for _, j := range c.jobs {
        if j.schedule.IsDue(moment) {
            go j.Run()
        }
    }
}
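A runnable sketch of the Cron API above. Note that Start() delays the first tick so runs align to interval boundaries (next = now.Add(interval).Truncate(interval)); the short interval here is only to make the example finish quickly.

package main

import (
    "log"
    "time"

    "github.com/pocketbase/pocketbase/tools/cron"
)

func main() {
    c := cron.New()

    // tick every second instead of the 1 minute default
    c.SetInterval(1 * time.Second)

    // with a 1s interval every tick matches the "* * * * *" schedule
    c.MustAdd("heartbeat", "* * * * *", func() {
        log.Println("tick")
    })

    c.Start()
    log.Println("started:", c.HasStarted())

    time.Sleep(3 * time.Second)

    c.Stop()
    log.Println("registered jobs:", c.Total())
}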
305
tools/cron/cron_test.go
Normal file
@@ -0,0 +1,305 @@
package cron

import (
    "encoding/json"
    "slices"
    "testing"
    "time"
)

func TestCronNew(t *testing.T) {
    t.Parallel()

    c := New()

    expectedInterval := 1 * time.Minute
    if c.interval != expectedInterval {
        t.Fatalf("Expected default interval %v, got %v", expectedInterval, c.interval)
    }

    expectedTimezone := time.UTC
    if c.timezone.String() != expectedTimezone.String() {
        t.Fatalf("Expected default timezone %v, got %v", expectedTimezone, c.timezone)
    }

    if len(c.jobs) != 0 {
        t.Fatalf("Expected no jobs by default, got \n%v", c.jobs)
    }

    if c.ticker != nil {
        t.Fatal("Expected the ticker NOT to be initialized")
    }
}

func TestCronSetInterval(t *testing.T) {
    t.Parallel()

    c := New()

    interval := 2 * time.Minute

    c.SetInterval(interval)

    if c.interval != interval {
        t.Fatalf("Expected interval %v, got %v", interval, c.interval)
    }
}

func TestCronSetTimezone(t *testing.T) {
    t.Parallel()

    c := New()

    timezone, _ := time.LoadLocation("Asia/Tokyo")

    c.SetTimezone(timezone)

    if c.timezone.String() != timezone.String() {
        t.Fatalf("Expected timezone %v, got %v", timezone, c.timezone)
    }
}

func TestCronAddAndRemove(t *testing.T) {
    t.Parallel()

    c := New()

    if err := c.Add("test0", "* * * * *", nil); err == nil {
        t.Fatal("Expected nil function error")
    }

    if err := c.Add("test1", "invalid", func() {}); err == nil {
        t.Fatal("Expected invalid cron expression error")
    }

    if err := c.Add("test2", "* * * * *", func() {}); err != nil {
        t.Fatal(err)
    }

    if err := c.Add("test3", "* * * * *", func() {}); err != nil {
        t.Fatal(err)
    }

    if err := c.Add("test4", "* * * * *", func() {}); err != nil {
        t.Fatal(err)
    }

    // overwrite test2
    if err := c.Add("test2", "1 2 3 4 5", func() {}); err != nil {
        t.Fatal(err)
    }

    if err := c.Add("test5", "1 2 3 4 5", func() {}); err != nil {
        t.Fatal(err)
    }

    // mock job deletion
    c.Remove("test4")

    // try to remove non-existing (should be no-op)
    c.Remove("missing")

    indexedJobs := make(map[string]*Job, len(c.jobs))
    for _, j := range c.jobs {
        indexedJobs[j.Id()] = j
    }

    // check job keys
    {
        expectedKeys := []string{"test3", "test2", "test5"}

        if v := len(c.jobs); v != len(expectedKeys) {
            t.Fatalf("Expected %d jobs, got %d", len(expectedKeys), v)
        }

        for _, k := range expectedKeys {
            if indexedJobs[k] == nil {
                t.Fatalf("Expected job with key %s, got nil", k)
            }
        }
    }

    // check the jobs schedule
    {
        expectedSchedules := map[string]string{
            "test2": `{"minutes":{"1":{}},"hours":{"2":{}},"days":{"3":{}},"months":{"4":{}},"daysOfWeek":{"5":{}}}`,
"test3": `{"minutes":{"0":{},"1":{},"10":{},"11":{},"12":{},"13":{},"14":{},"15":{},"16":{},"17":{},"18":{},"19":{},"2":{},"20":{},"21":{},"22":{},"23":{},"24":{},"25":{},"26":{},"27":{},"28":{},"29":{},"3":{},"30":{},"31":{},"32":{},"33":{},"34":{},"35":{},"36":{},"37":{},"38":{},"39":{},"4":{},"40":{},"41":{},"42":{},"43":{},"44":{},"45":{},"46":{},"47":{},"48":{},"49":{},"5":{},"50":{},"51":{},"52":{},"53":{},"54":{},"55":{},"56":{},"57":{},"58":{},"59":{},"6":{},"7":{},"8":{},"9":{}},"hours":{"0":{},"1":{},"10":{},"11":{},"12":{},"13":{},"14":{},"15":{},"16":{},"17":{},"18":{},"19":{},"2":{},"20":{},"21":{},"22":{},"23":{},"3":{},"4":{},"5":{},"6":{},"7":{},"8":{},"9":{}},"days":{"1":{},"10":{},"11":{},"12":{},"13":{},"14":{},"15":{},"16":{},"17":{},"18":{},"19":{},"2":{},"20":{},"21":{},"22":{},"23":{},"24":{},"25":{},"26":{},"27":{},"28":{},"29":{},"3":{},"30":{},"31":{},"4":{},"5":{},"6":{},"7":{},"8":{},"9":{}},"months":{"1":{},"10":{},"11":{},"12":{},"2":{},"3":{},"4":{},"5":{},"6":{},"7":{},"8":{},"9":{}},"daysOfWeek":{"0":{},"1":{},"2":{},"3":{},"4":{},"5":{},"6":{}}}`,
            "test5": `{"minutes":{"1":{}},"hours":{"2":{}},"days":{"3":{}},"months":{"4":{}},"daysOfWeek":{"5":{}}}`,
        }
        for k, v := range expectedSchedules {
            raw, err := json.Marshal(indexedJobs[k].schedule)
            if err != nil {
                t.Fatal(err)
            }

            if string(raw) != v {
                t.Fatalf("Expected %q schedule \n%s, \ngot \n%s", k, v, raw)
            }
        }
    }
}

func TestCronMustAdd(t *testing.T) {
    t.Parallel()

    c := New()

    defer func() {
        if r := recover(); r == nil {
            t.Errorf("test1 didn't panic")
        }
    }()

    c.MustAdd("test1", "* * * * *", nil)

    c.MustAdd("test2", "* * * * *", func() {})

    if !slices.ContainsFunc(c.jobs, func(j *Job) bool { return j.Id() == "test2" }) {
        t.Fatal("Couldn't find job test2")
    }
}

func TestCronRemoveAll(t *testing.T) {
    t.Parallel()

    c := New()

    if err := c.Add("test1", "* * * * *", func() {}); err != nil {
        t.Fatal(err)
    }

    if err := c.Add("test2", "* * * * *", func() {}); err != nil {
        t.Fatal(err)
    }

    if err := c.Add("test3", "* * * * *", func() {}); err != nil {
        t.Fatal(err)
    }

    if v := len(c.jobs); v != 3 {
        t.Fatalf("Expected %d jobs, got %d", 3, v)
    }

    c.RemoveAll()

    if v := len(c.jobs); v != 0 {
        t.Fatalf("Expected %d jobs, got %d", 0, v)
    }
}

func TestCronTotal(t *testing.T) {
    t.Parallel()

    c := New()

    if v := c.Total(); v != 0 {
        t.Fatalf("Expected 0 jobs, got %v", v)
    }

    if err := c.Add("test1", "* * * * *", func() {}); err != nil {
        t.Fatal(err)
    }

    if err := c.Add("test2", "* * * * *", func() {}); err != nil {
        t.Fatal(err)
    }

    // overwrite
    if err := c.Add("test1", "* * * * *", func() {}); err != nil {
        t.Fatal(err)
    }

    if v := c.Total(); v != 2 {
        t.Fatalf("Expected 2 jobs, got %v", v)
    }
}

func TestCronJobs(t *testing.T) {
    t.Parallel()

    c := New()

    calls := ""

    if err := c.Add("a", "1 * * * *", func() { calls += "a" }); err != nil {
        t.Fatal(err)
    }

    if err := c.Add("b", "2 * * * *", func() { calls += "b" }); err != nil {
        t.Fatal(err)
    }

    // overwrite
    if err := c.Add("b", "3 * * * *", func() { calls += "b" }); err != nil {
        t.Fatal(err)
    }

    jobs := c.Jobs()

    if len(jobs) != 2 {
        t.Fatalf("Expected 2 jobs, got %v", len(jobs))
    }

    for _, j := range jobs {
        j.Run()
    }

    expectedCalls := "ab"
    if calls != expectedCalls {
        t.Fatalf("Expected %q calls, got %q", expectedCalls, calls)
    }
}

func TestCronStartStop(t *testing.T) {
    t.Parallel()

    test1 := 0
    test2 := 0

    c := New()

    c.SetInterval(500 * time.Millisecond)

    c.Add("test1", "* * * * *", func() {
        test1++
    })

    c.Add("test2", "* * * * *", func() {
        test2++
    })

    expectedCalls := 2

    // call Start twice to check that the previous ticker is reset
    c.Start()
    c.Start()

    time.Sleep(1 * time.Second)

    // call Stop twice to ensure that the second stop is a no-op
    c.Stop()
    c.Stop()

    if test1 != expectedCalls {
        t.Fatalf("Expected %d test1, got %d", expectedCalls, test1)
    }
    if test2 != expectedCalls {
        t.Fatalf("Expected %d test2, got %d", expectedCalls, test2)
    }

    // resume for 2 seconds
    c.Start()

    time.Sleep(2 * time.Second)

    c.Stop()

    expectedCalls += 4

    if test1 != expectedCalls {
        t.Fatalf("Expected %d test1, got %d", expectedCalls, test1)
    }
    if test2 != expectedCalls {
        t.Fatalf("Expected %d test2, got %d", expectedCalls, test2)
    }
}
41
tools/cron/job.go
Normal file
@@ -0,0 +1,41 @@
package cron

import "encoding/json"

// Job defines a single registered cron job.
type Job struct {
    fn       func()
    schedule *Schedule
    id       string
}

// Id returns the cron job id.
func (j *Job) Id() string {
    return j.id
}

// Expression returns the plain cron job schedule expression.
func (j *Job) Expression() string {
    return j.schedule.rawExpr
}

// Run runs the cron job function.
func (j *Job) Run() {
    if j.fn != nil {
        j.fn()
    }
}

// MarshalJSON implements [json.Marshaler] and exports the current
// job's data into valid JSON.
func (j Job) MarshalJSON() ([]byte, error) {
    plain := struct {
        Id         string `json:"id"`
        Expression string `json:"expression"`
    }{
        Id:         j.Id(),
        Expression: j.Expression(),
    }

    return json.Marshal(plain)
}
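Because Job implements json.Marshaler, the slice returned by Cron.Jobs() serializes directly, which is convenient for exposing the schedule state; a small sketch with arbitrary job ids and expressions:

package main

import (
    "encoding/json"
    "fmt"

    "github.com/pocketbase/pocketbase/tools/cron"
)

func main() {
    c := cron.New()
    c.MustAdd("backup", "0 0 * * *", func() {})
    c.MustAdd("cleanup", "0 * * * *", func() {})

    raw, err := json.Marshal(c.Jobs())
    if err != nil {
        panic(err)
    }

    // prints something like:
    // [{"id":"backup","expression":"0 0 * * *"},{"id":"cleanup","expression":"0 * * * *"}]
    fmt.Println(string(raw))
}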
71
tools/cron/job_test.go
Normal file
@@ -0,0 +1,71 @@
package cron

import (
    "encoding/json"
    "testing"
)

func TestJobId(t *testing.T) {
    expected := "test"

    j := Job{id: expected}

    if j.Id() != expected {
        t.Fatalf("Expected job with id %q, got %q", expected, j.Id())
    }
}

func TestJobExpr(t *testing.T) {
    expected := "1 2 3 4 5"

    s, err := NewSchedule(expected)
    if err != nil {
        t.Fatal(err)
    }

    j := Job{schedule: s}

    if j.Expression() != expected {
        t.Fatalf("Expected job with cron expression %q, got %q", expected, j.Expression())
    }
}

func TestJobRun(t *testing.T) {
    defer func() {
        if r := recover(); r != nil {
            t.Errorf("Shouldn't panic: %v", r)
        }
    }()

    calls := ""

    j1 := Job{}
    j2 := Job{fn: func() { calls += "2" }}

    j1.Run()
    j2.Run()

    expected := "2"
    if calls != expected {
        t.Fatalf("Expected calls %q, got %q", expected, calls)
    }
}

func TestJobMarshalJSON(t *testing.T) {
    s, err := NewSchedule("1 2 3 4 5")
    if err != nil {
        t.Fatal(err)
    }

    j := Job{id: "test_id", schedule: s}

    raw, err := json.Marshal(j)
    if err != nil {
        t.Fatal(err)
    }

    expected := `{"id":"test_id","expression":"1 2 3 4 5"}`
    if str := string(raw); str != expected {
        t.Fatalf("Expected\n%s\ngot\n%s", expected, str)
    }
}
218
tools/cron/schedule.go
Normal file
@@ -0,0 +1,218 @@
package cron

import (
    "errors"
    "fmt"
    "strconv"
    "strings"
    "time"
)

// Moment represents a parsed single time moment.
type Moment struct {
    Minute    int `json:"minute"`
    Hour      int `json:"hour"`
    Day       int `json:"day"`
    Month     int `json:"month"`
    DayOfWeek int `json:"dayOfWeek"`
}

// NewMoment creates a new Moment from the specified time.
func NewMoment(t time.Time) *Moment {
    return &Moment{
        Minute:    t.Minute(),
        Hour:      t.Hour(),
        Day:       t.Day(),
        Month:     int(t.Month()),
        DayOfWeek: int(t.Weekday()),
    }
}

// Schedule stores parsed information for each time component when a cron job should run.
type Schedule struct {
    Minutes    map[int]struct{} `json:"minutes"`
    Hours      map[int]struct{} `json:"hours"`
    Days       map[int]struct{} `json:"days"`
    Months     map[int]struct{} `json:"months"`
    DaysOfWeek map[int]struct{} `json:"daysOfWeek"`

    rawExpr string
}

// IsDue checks whether the provided Moment satisfies the current Schedule.
func (s *Schedule) IsDue(m *Moment) bool {
    if _, ok := s.Minutes[m.Minute]; !ok {
        return false
    }

    if _, ok := s.Hours[m.Hour]; !ok {
        return false
    }

    if _, ok := s.Days[m.Day]; !ok {
        return false
    }

    if _, ok := s.DaysOfWeek[m.DayOfWeek]; !ok {
        return false
    }

    if _, ok := s.Months[m.Month]; !ok {
        return false
    }

    return true
}

var macros = map[string]string{
    "@yearly":   "0 0 1 1 *",
    "@annually": "0 0 1 1 *",
    "@monthly":  "0 0 1 * *",
    "@weekly":   "0 0 * * 0",
    "@daily":    "0 0 * * *",
    "@midnight": "0 0 * * *",
    "@hourly":   "0 * * * *",
}

// NewSchedule creates a new Schedule from a cron expression.
//
// A cron expression could be a macro OR 5 segments separated by space,
// representing: minute, hour, day of the month, month and day of the week.
//
// The following segment formats are supported:
//   - wildcard: *
//   - range: 1-30
//   - step: */n or 1-30/n
//   - list: 1,2,3,10-20/n
//
// The following macros are supported:
//   - @yearly (or @annually)
//   - @monthly
//   - @weekly
//   - @daily (or @midnight)
//   - @hourly
func NewSchedule(cronExpr string) (*Schedule, error) {
    if v, ok := macros[cronExpr]; ok {
        cronExpr = v
    }

    segments := strings.Split(cronExpr, " ")
    if len(segments) != 5 {
        return nil, errors.New("invalid cron expression - must be a valid macro or have exactly 5 space separated segments")
    }

    minutes, err := parseCronSegment(segments[0], 0, 59)
    if err != nil {
        return nil, err
    }

    hours, err := parseCronSegment(segments[1], 0, 23)
    if err != nil {
        return nil, err
    }

    days, err := parseCronSegment(segments[2], 1, 31)
    if err != nil {
        return nil, err
    }

    months, err := parseCronSegment(segments[3], 1, 12)
    if err != nil {
        return nil, err
    }

    daysOfWeek, err := parseCronSegment(segments[4], 0, 6)
    if err != nil {
        return nil, err
    }

    return &Schedule{
        Minutes:    minutes,
        Hours:      hours,
        Days:       days,
        Months:     months,
        DaysOfWeek: daysOfWeek,
        rawExpr:    cronExpr,
    }, nil
}

// parseCronSegment parses a single cron expression segment and
// returns its time schedule slots.
func parseCronSegment(segment string, min int, max int) (map[int]struct{}, error) {
    slots := map[int]struct{}{}

    list := strings.Split(segment, ",")
    for _, p := range list {
        stepParts := strings.Split(p, "/")

        // step (*/n, 1-30/n)
        var step int
        switch len(stepParts) {
        case 1:
            step = 1
        case 2:
            parsedStep, err := strconv.Atoi(stepParts[1])
            if err != nil {
                return nil, err
            }
            if parsedStep < 1 || parsedStep > max {
                return nil, fmt.Errorf("invalid segment step boundary - the step must be between 1 and %d", max)
            }
            step = parsedStep
        default:
            return nil, errors.New("invalid segment step format - must be in the format */n or 1-30/n")
        }

        // find the min and max range of the segment part
        var rangeMin, rangeMax int
        if stepParts[0] == "*" {
            rangeMin = min
            rangeMax = max
        } else {
            // single digit (1) or range (1-30)
            rangeParts := strings.Split(stepParts[0], "-")
            switch len(rangeParts) {
            case 1:
                if step != 1 {
                    return nil, errors.New("invalid segment step - step > 1 could be used only with the wildcard or range format")
                }
                parsed, err := strconv.Atoi(rangeParts[0])
                if err != nil {
                    return nil, err
                }
                if parsed < min || parsed > max {
                    return nil, errors.New("invalid segment value - must be between the min and max of the segment")
                }
                rangeMin = parsed
                rangeMax = rangeMin
            case 2:
                parsedMin, err := strconv.Atoi(rangeParts[0])
                if err != nil {
                    return nil, err
                }
                if parsedMin < min || parsedMin > max {
                    return nil, fmt.Errorf("invalid segment range minimum - must be between %d and %d", min, max)
                }
                rangeMin = parsedMin

                parsedMax, err := strconv.Atoi(rangeParts[1])
                if err != nil {
                    return nil, err
                }
                if parsedMax < parsedMin || parsedMax > max {
                    return nil, fmt.Errorf("invalid segment range maximum - must be between %d and %d", rangeMin, max)
                }
                rangeMax = parsedMax
            default:
                return nil, errors.New("invalid segment range format - the range must have 1 or 2 parts")
            }
        }

        // fill the slots
        for i := rangeMin; i <= rangeMax; i += step {
            slots[i] = struct{}{}
        }
    }

    return slots, nil
}
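A worked example of the parsing and matching above: a schedule for weekdays at 9:00 probed with two moments. The concrete dates are arbitrary; 2023-05-09 is a Tuesday and 2023-05-13 a Saturday.

package main

import (
    "fmt"
    "time"

    "github.com/pocketbase/pocketbase/tools/cron"
)

func main() {
    // minute 0, hour 9, any day/month, Monday-Friday
    s, err := cron.NewSchedule("0 9 * * 1-5")
    if err != nil {
        panic(err)
    }

    // Tuesday (weekday 2) at 09:00 -> due
    tue := time.Date(2023, 5, 9, 9, 0, 0, 0, time.UTC)
    fmt.Println(s.IsDue(cron.NewMoment(tue))) // true

    // Saturday (weekday 6) at 09:00 -> not due
    sat := time.Date(2023, 5, 13, 9, 0, 0, 0, time.UTC)
    fmt.Println(s.IsDue(cron.NewMoment(sat))) // false
}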
409
tools/cron/schedule_test.go
Normal file
@@ -0,0 +1,409 @@
package cron_test

import (
    "encoding/json"
    "fmt"
    "testing"
    "time"

    "github.com/pocketbase/pocketbase/tools/cron"
)

func TestNewMoment(t *testing.T) {
    t.Parallel()

    date, err := time.Parse("2006-01-02 15:04", "2023-05-09 15:20")
    if err != nil {
        t.Fatal(err)
    }

    m := cron.NewMoment(date)

    if m.Minute != 20 {
        t.Fatalf("Expected m.Minute %d, got %d", 20, m.Minute)
    }

    if m.Hour != 15 {
        t.Fatalf("Expected m.Hour %d, got %d", 15, m.Hour)
    }

    if m.Day != 9 {
        t.Fatalf("Expected m.Day %d, got %d", 9, m.Day)
    }

    if m.Month != 5 {
        t.Fatalf("Expected m.Month %d, got %d", 5, m.Month)
    }

    if m.DayOfWeek != 2 {
        t.Fatalf("Expected m.DayOfWeek %d, got %d", 2, m.DayOfWeek)
    }
}

func TestNewSchedule(t *testing.T) {
    t.Parallel()

    scenarios := []struct {
        cronExpr       string
        expectError    bool
        expectSchedule string
    }{
        {
            "invalid",
            true,
            "",
        },
        {
            "* * * *",
            true,
            "",
        },
        {
            "* * * * * *",
            true,
            "",
        },
        {
            "2/3 * * * *",
            true,
            "",
        },
        {
            "* * * * *",
            false,
`{"minutes":{"0":{},"1":{},"10":{},"11":{},"12":{},"13":{},"14":{},"15":{},"16":{},"17":{},"18":{},"19":{},"2":{},"20":{},"21":{},"22":{},"23":{},"24":{},"25":{},"26":{},"27":{},"28":{},"29":{},"3":{},"30":{},"31":{},"32":{},"33":{},"34":{},"35":{},"36":{},"37":{},"38":{},"39":{},"4":{},"40":{},"41":{},"42":{},"43":{},"44":{},"45":{},"46":{},"47":{},"48":{},"49":{},"5":{},"50":{},"51":{},"52":{},"53":{},"54":{},"55":{},"56":{},"57":{},"58":{},"59":{},"6":{},"7":{},"8":{},"9":{}},"hours":{"0":{},"1":{},"10":{},"11":{},"12":{},"13":{},"14":{},"15":{},"16":{},"17":{},"18":{},"19":{},"2":{},"20":{},"21":{},"22":{},"23":{},"3":{},"4":{},"5":{},"6":{},"7":{},"8":{},"9":{}},"days":{"1":{},"10":{},"11":{},"12":{},"13":{},"14":{},"15":{},"16":{},"17":{},"18":{},"19":{},"2":{},"20":{},"21":{},"22":{},"23":{},"24":{},"25":{},"26":{},"27":{},"28":{},"29":{},"3":{},"30":{},"31":{},"4":{},"5":{},"6":{},"7":{},"8":{},"9":{}},"months":{"1":{},"10":{},"11":{},"12":{},"2":{},"3":{},"4":{},"5":{},"6":{},"7":{},"8":{},"9":{}},"daysOfWeek":{"0":{},"1":{},"2":{},"3":{},"4":{},"5":{},"6":{}}}`,
        },
        {
            "*/2 */3 */5 */4 */2",
            false,
`{"minutes":{"0":{},"10":{},"12":{},"14":{},"16":{},"18":{},"2":{},"20":{},"22":{},"24":{},"26":{},"28":{},"30":{},"32":{},"34":{},"36":{},"38":{},"4":{},"40":{},"42":{},"44":{},"46":{},"48":{},"50":{},"52":{},"54":{},"56":{},"58":{},"6":{},"8":{}},"hours":{"0":{},"12":{},"15":{},"18":{},"21":{},"3":{},"6":{},"9":{}},"days":{"1":{},"11":{},"16":{},"21":{},"26":{},"31":{},"6":{}},"months":{"1":{},"5":{},"9":{}},"daysOfWeek":{"0":{},"2":{},"4":{},"6":{}}}`,
        },

        // minute segment
        {
            "-1 * * * *",
            true,
            "",
        },
        {
            "60 * * * *",
            true,
            "",
        },
        {
            "0 * * * *",
            false,
`{"minutes":{"0":{}},"hours":{"0":{},"1":{},"10":{},"11":{},"12":{},"13":{},"14":{},"15":{},"16":{},"17":{},"18":{},"19":{},"2":{},"20":{},"21":{},"22":{},"23":{},"3":{},"4":{},"5":{},"6":{},"7":{},"8":{},"9":{}},"days":{"1":{},"10":{},"11":{},"12":{},"13":{},"14":{},"15":{},"16":{},"17":{},"18":{},"19":{},"2":{},"20":{},"21":{},"22":{},"23":{},"24":{},"25":{},"26":{},"27":{},"28":{},"29":{},"3":{},"30":{},"31":{},"4":{},"5":{},"6":{},"7":{},"8":{},"9":{}},"months":{"1":{},"10":{},"11":{},"12":{},"2":{},"3":{},"4":{},"5":{},"6":{},"7":{},"8":{},"9":{}},"daysOfWeek":{"0":{},"1":{},"2":{},"3":{},"4":{},"5":{},"6":{}}}`,
        },
        {
            "59 * * * *",
            false,
`{"minutes":{"59":{}},"hours":{"0":{},"1":{},"10":{},"11":{},"12":{},"13":{},"14":{},"15":{},"16":{},"17":{},"18":{},"19":{},"2":{},"20":{},"21":{},"22":{},"23":{},"3":{},"4":{},"5":{},"6":{},"7":{},"8":{},"9":{}},"days":{"1":{},"10":{},"11":{},"12":{},"13":{},"14":{},"15":{},"16":{},"17":{},"18":{},"19":{},"2":{},"20":{},"21":{},"22":{},"23":{},"24":{},"25":{},"26":{},"27":{},"28":{},"29":{},"3":{},"30":{},"31":{},"4":{},"5":{},"6":{},"7":{},"8":{},"9":{}},"months":{"1":{},"10":{},"11":{},"12":{},"2":{},"3":{},"4":{},"5":{},"6":{},"7":{},"8":{},"9":{}},"daysOfWeek":{"0":{},"1":{},"2":{},"3":{},"4":{},"5":{},"6":{}}}`,
        },
        {
            "1,2,5,7,40-50/2 * * * *",
            false,
`{"minutes":{"1":{},"2":{},"40":{},"42":{},"44":{},"46":{},"48":{},"5":{},"50":{},"7":{}},"hours":{"0":{},"1":{},"10":{},"11":{},"12":{},"13":{},"14":{},"15":{},"16":{},"17":{},"18":{},"19":{},"2":{},"20":{},"21":{},"22":{},"23":{},"3":{},"4":{},"5":{},"6":{},"7":{},"8":{},"9":{}},"days":{"1":{},"10":{},"11":{},"12":{},"13":{},"14":{},"15":{},"16":{},"17":{},"18":{},"19":{},"2":{},"20":{},"21":{},"22":{},"23":{},"24":{},"25":{},"26":{},"27":{},"28":{},"29":{},"3":{},"30":{},"31":{},"4":{},"5":{},"6":{},"7":{},"8":{},"9":{}},"months":{"1":{},"10":{},"11":{},"12":{},"2":{},"3":{},"4":{},"5":{},"6":{},"7":{},"8":{},"9":{}},"daysOfWeek":{"0":{},"1":{},"2":{},"3":{},"4":{},"5":{},"6":{}}}`,
        },

        // hour segment
        {
            "* -1 * * *",
            true,
            "",
        },
        {
            "* 24 * * *",
            true,
            "",
        },
        {
            "* 0 * * *",
            false,
`{"minutes":{"0":{},"1":{},"10":{},"11":{},"12":{},"13":{},"14":{},"15":{},"16":{},"17":{},"18":{},"19":{},"2":{},"20":{},"21":{},"22":{},"23":{},"24":{},"25":{},"26":{},"27":{},"28":{},"29":{},"3":{},"30":{},"31":{},"32":{},"33":{},"34":{},"35":{},"36":{},"37":{},"38":{},"39":{},"4":{},"40":{},"41":{},"42":{},"43":{},"44":{},"45":{},"46":{},"47":{},"48":{},"49":{},"5":{},"50":{},"51":{},"52":{},"53":{},"54":{},"55":{},"56":{},"57":{},"58":{},"59":{},"6":{},"7":{},"8":{},"9":{}},"hours":{"0":{}},"days":{"1":{},"10":{},"11":{},"12":{},"13":{},"14":{},"15":{},"16":{},"17":{},"18":{},"19":{},"2":{},"20":{},"21":{},"22":{},"23":{},"24":{},"25":{},"26":{},"27":{},"28":{},"29":{},"3":{},"30":{},"31":{},"4":{},"5":{},"6":{},"7":{},"8":{},"9":{}},"months":{"1":{},"10":{},"11":{},"12":{},"2":{},"3":{},"4":{},"5":{},"6":{},"7":{},"8":{},"9":{}},"daysOfWeek":{"0":{},"1":{},"2":{},"3":{},"4":{},"5":{},"6":{}}}`,
        },
        {
            "* 23 * * *",
            false,
`{"minutes":{"0":{},"1":{},"10":{},"11":{},"12":{},"13":{},"14":{},"15":{},"16":{},"17":{},"18":{},"19":{},"2":{},"20":{},"21":{},"22":{},"23":{},"24":{},"25":{},"26":{},"27":{},"28":{},"29":{},"3":{},"30":{},"31":{},"32":{},"33":{},"34":{},"35":{},"36":{},"37":{},"38":{},"39":{},"4":{},"40":{},"41":{},"42":{},"43":{},"44":{},"45":{},"46":{},"47":{},"48":{},"49":{},"5":{},"50":{},"51":{},"52":{},"53":{},"54":{},"55":{},"56":{},"57":{},"58":{},"59":{},"6":{},"7":{},"8":{},"9":{}},"hours":{"23":{}},"days":{"1":{},"10":{},"11":{},"12":{},"13":{},"14":{},"15":{},"16":{},"17":{},"18":{},"19":{},"2":{},"20":{},"21":{},"22":{},"23":{},"24":{},"25":{},"26":{},"27":{},"28":{},"29":{},"3":{},"30":{},"31":{},"4":{},"5":{},"6":{},"7":{},"8":{},"9":{}},"months":{"1":{},"10":{},"11":{},"12":{},"2":{},"3":{},"4":{},"5":{},"6":{},"7":{},"8":{},"9":{}},"daysOfWeek":{"0":{},"1":{},"2":{},"3":{},"4":{},"5":{},"6":{}}}`,
        },
        {
            "* 3,4,8-16/3,7 * * *",
            false,
`{"minutes":{"0":{},"1":{},"10":{},"11":{},"12":{},"13":{},"14":{},"15":{},"16":{},"17":{},"18":{},"19":{},"2":{},"20":{},"21":{},"22":{},"23":{},"24":{},"25":{},"26":{},"27":{},"28":{},"29":{},"3":{},"30":{},"31":{},"32":{},"33":{},"34":{},"35":{},"36":{},"37":{},"38":{},"39":{},"4":{},"40":{},"41":{},"42":{},"43":{},"44":{},"45":{},"46":{},"47":{},"48":{},"49":{},"5":{},"50":{},"51":{},"52":{},"53":{},"54":{},"55":{},"56":{},"57":{},"58":{},"59":{},"6":{},"7":{},"8":{},"9":{}},"hours":{"11":{},"14":{},"3":{},"4":{},"7":{},"8":{}},"days":{"1":{},"10":{},"11":{},"12":{},"13":{},"14":{},"15":{},"16":{},"17":{},"18":{},"19":{},"2":{},"20":{},"21":{},"22":{},"23":{},"24":{},"25":{},"26":{},"27":{},"28":{},"29":{},"3":{},"30":{},"31":{},"4":{},"5":{},"6":{},"7":{},"8":{},"9":{}},"months":{"1":{},"10":{},"11":{},"12":{},"2":{},"3":{},"4":{},"5":{},"6":{},"7":{},"8":{},"9":{}},"daysOfWeek":{"0":{},"1":{},"2":{},"3":{},"4":{},"5":{},"6":{}}}`,
        },

        // day segment
        {
            "* * 0 * *",
            true,
            "",
        },
        {
            "* * 32 * *",
            true,
            "",
        },
        {
            "* * 1 * *",
            false,
`{"minutes":{"0":{},"1":{},"10":{},"11":{},"12":{},"13":{},"14":{},"15":{},"16":{},"17":{},"18":{},"19":{},"2":{},"20":{},"21":{},"22":{},"23":{},"24":{},"25":{},"26":{},"27":{},"28":{},"29":{},"3":{},"30":{},"31":{},"32":{},"33":{},"34":{},"35":{},"36":{},"37":{},"38":{},"39":{},"4":{},"40":{},"41":{},"42":{},"43":{},"44":{},"45":{},"46":{},"47":{},"48":{},"49":{},"5":{},"50":{},"51":{},"52":{},"53":{},"54":{},"55":{},"56":{},"57":{},"58":{},"59":{},"6":{},"7":{},"8":{},"9":{}},"hours":{"0":{},"1":{},"10":{},"11":{},"12":{},"13":{},"14":{},"15":{},"16":{},"17":{},"18":{},"19":{},"2":{},"20":{},"21":{},"22":{},"23":{},"3":{},"4":{},"5":{},"6":{},"7":{},"8":{},"9":{}},"days":{"1":{}},"months":{"1":{},"10":{},"11":{},"12":{},"2":{},"3":{},"4":{},"5":{},"6":{},"7":{},"8":{},"9":{}},"daysOfWeek":{"0":{},"1":{},"2":{},"3":{},"4":{},"5":{},"6":{}}}`,
        },
        {
            "* * 31 * *",
            false,
`{"minutes":{"0":{},"1":{},"10":{},"11":{},"12":{},"13":{},"14":{},"15":{},"16":{},"17":{},"18":{},"19":{},"2":{},"20":{},"21":{},"22":{},"23":{},"24":{},"25":{},"26":{},"27":{},"28":{},"29":{},"3":{},"30":{},"31":{},"32":{},"33":{},"34":{},"35":{},"36":{},"37":{},"38":{},"39":{},"4":{},"40":{},"41":{},"42":{},"43":{},"44":{},"45":{},"46":{},"47":{},"48":{},"49":{},"5":{},"50":{},"51":{},"52":{},"53":{},"54":{},"55":{},"56":{},"57":{},"58":{},"59":{},"6":{},"7":{},"8":{},"9":{}},"hours":{"0":{},"1":{},"10":{},"11":{},"12":{},"13":{},"14":{},"15":{},"16":{},"17":{},"18":{},"19":{},"2":{},"20":{},"21":{},"22":{},"23":{},"3":{},"4":{},"5":{},"6":{},"7":{},"8":{},"9":{}},"days":{"31":{}},"months":{"1":{},"10":{},"11":{},"12":{},"2":{},"3":{},"4":{},"5":{},"6":{},"7":{},"8":{},"9":{}},"daysOfWeek":{"0":{},"1":{},"2":{},"3":{},"4":{},"5":{},"6":{}}}`,
        },
        {
            "* * 5,6,20-30/3,1 * *",
            false,
`{"minutes":{"0":{},"1":{},"10":{},"11":{},"12":{},"13":{},"14":{},"15":{},"16":{},"17":{},"18":{},"19":{},"2":{},"20":{},"21":{},"22":{},"23":{},"24":{},"25":{},"26":{},"27":{},"28":{},"29":{},"3":{},"30":{},"31":{},"32":{},"33":{},"34":{},"35":{},"36":{},"37":{},"38":{},"39":{},"4":{},"40":{},"41":{},"42":{},"43":{},"44":{},"45":{},"46":{},"47":{},"48":{},"49":{},"5":{},"50":{},"51":{},"52":{},"53":{},"54":{},"55":{},"56":{},"57":{},"58":{},"59":{},"6":{},"7":{},"8":{},"9":{}},"hours":{"0":{},"1":{},"10":{},"11":{},"12":{},"13":{},"14":{},"15":{},"16":{},"17":{},"18":{},"19":{},"2":{},"20":{},"21":{},"22":{},"23":{},"3":{},"4":{},"5":{},"6":{},"7":{},"8":{},"9":{}},"days":{"1":{},"20":{},"23":{},"26":{},"29":{},"5":{},"6":{}},"months":{"1":{},"10":{},"11":{},"12":{},"2":{},"3":{},"4":{},"5":{},"6":{},"7":{},"8":{},"9":{}},"daysOfWeek":{"0":{},"1":{},"2":{},"3":{},"4":{},"5":{},"6":{}}}`,
        },

        // month segment
        {
            "* * * 0 *",
            true,
            "",
        },
        {
            "* * * 13 *",
            true,
            "",
        },
        {
            "* * * 1 *",
            false,
`{"minutes":{"0":{},"1":{},"10":{},"11":{},"12":{},"13":{},"14":{},"15":{},"16":{},"17":{},"18":{},"19":{},"2":{},"20":{},"21":{},"22":{},"23":{},"24":{},"25":{},"26":{},"27":{},"28":{},"29":{},"3":{},"30":{},"31":{},"32":{},"33":{},"34":{},"35":{},"36":{},"37":{},"38":{},"39":{},"4":{},"40":{},"41":{},"42":{},"43":{},"44":{},"45":{},"46":{},"47":{},"48":{},"49":{},"5":{},"50":{},"51":{},"52":{},"53":{},"54":{},"55":{},"56":{},"57":{},"58":{},"59":{},"6":{},"7":{},"8":{},"9":{}},"hours":{"0":{},"1":{},"10":{},"11":{},"12":{},"13":{},"14":{},"15":{},"16":{},"17":{},"18":{},"19":{},"2":{},"20":{},"21":{},"22":{},"23":{},"3":{},"4":{},"5":{},"6":{},"7":{},"8":{},"9":{}},"days":{"1":{},"10":{},"11":{},"12":{},"13":{},"14":{},"15":{},"16":{},"17":{},"18":{},"19":{},"2":{},"20":{},"21":{},"22":{},"23":{},"24":{},"25":{},"26":{},"27":{},"28":{},"29":{},"3":{},"30":{},"31":{},"4":{},"5":{},"6":{},"7":{},"8":{},"9":{}},"months":{"1":{}},"daysOfWeek":{"0":{},"1":{},"2":{},"3":{},"4":{},"5":{},"6":{}}}`,
        },
        {
            "* * * 12 *",
            false,
`{"minutes":{"0":{},"1":{},"10":{},"11":{},"12":{},"13":{},"14":{},"15":{},"16":{},"17":{},"18":{},"19":{},"2":{},"20":{},"21":{},"22":{},"23":{},"24":{},"25":{},"26":{},"27":{},"28":{},"29":{},"3":{},"30":{},"31":{},"32":{},"33":{},"34":{},"35":{},"36":{},"37":{},"38":{},"39":{},"4":{},"40":{},"41":{},"42":{},"43":{},"44":{},"45":{},"46":{},"47":{},"48":{},"49":{},"5":{},"50":{},"51":{},"52":{},"53":{},"54":{},"55":{},"56":{},"57":{},"58":{},"59":{},"6":{},"7":{},"8":{},"9":{}},"hours":{"0":{},"1":{},"10":{},"11":{},"12":{},"13":{},"14":{},"15":{},"16":{},"17":{},"18":{},"19":{},"2":{},"20":{},"21":{},"22":{},"23":{},"3":{},"4":{},"5":{},"6":{},"7":{},"8":{},"9":{}},"days":{"1":{},"10":{},"11":{},"12":{},"13":{},"14":{},"15":{},"16":{},"17":{},"18":{},"19":{},"2":{},"20":{},"21":{},"22":{},"23":{},"24":{},"25":{},"26":{},"27":{},"28":{},"29":{},"3":{},"30":{},"31":{},"4":{},"5":{},"6":{},"7":{},"8":{},"9":{}},"months":{"12":{}},"daysOfWeek":{"0":{},"1":{},"2":{},"3":{},"4":{},"5":{},"6":{}}}`,
},
{
"* * * 1,4,5-10/2 *",
false,
`{"minutes":{"0":{},"1":{},"10":{},"11":{},"12":{},"13":{},"14":{},"15":{},"16":{},"17":{},"18":{},"19":{},"2":{},"20":{},"21":{},"22":{},"23":{},"24":{},"25":{},"26":{},"27":{},"28":{},"29":{},"3":{},"30":{},"31":{},"32":{},"33":{},"34":{},"35":{},"36":{},"37":{},"38":{},"39":{},"4":{},"40":{},"41":{},"42":{},"43":{},"44":{},"45":{},"46":{},"47":{},"48":{},"49":{},"5":{},"50":{},"51":{},"52":{},"53":{},"54":{},"55":{},"56":{},"57":{},"58":{},"59":{},"6":{},"7":{},"8":{},"9":{}},"hours":{"0":{},"1":{},"10":{},"11":{},"12":{},"13":{},"14":{},"15":{},"16":{},"17":{},"18":{},"19":{},"2":{},"20":{},"21":{},"22":{},"23":{},"3":{},"4":{},"5":{},"6":{},"7":{},"8":{},"9":{}},"days":{"1":{},"10":{},"11":{},"12":{},"13":{},"14":{},"15":{},"16":{},"17":{},"18":{},"19":{},"2":{},"20":{},"21":{},"22":{},"23":{},"24":{},"25":{},"26":{},"27":{},"28":{},"29":{},"3":{},"30":{},"31":{},"4":{},"5":{},"6":{},"7":{},"8":{},"9":{}},"months":{"1":{},"4":{},"5":{},"7":{},"9":{}},"daysOfWeek":{"0":{},"1":{},"2":{},"3":{},"4":{},"5":{},"6":{}}}`,
},

// day of week segment
{
"* * * * -1",
true,
"",
},
{
"* * * * 7",
true,
"",
},
{
"* * * * 0",
false,
`{"minutes":{"0":{},"1":{},"10":{},"11":{},"12":{},"13":{},"14":{},"15":{},"16":{},"17":{},"18":{},"19":{},"2":{},"20":{},"21":{},"22":{},"23":{},"24":{},"25":{},"26":{},"27":{},"28":{},"29":{},"3":{},"30":{},"31":{},"32":{},"33":{},"34":{},"35":{},"36":{},"37":{},"38":{},"39":{},"4":{},"40":{},"41":{},"42":{},"43":{},"44":{},"45":{},"46":{},"47":{},"48":{},"49":{},"5":{},"50":{},"51":{},"52":{},"53":{},"54":{},"55":{},"56":{},"57":{},"58":{},"59":{},"6":{},"7":{},"8":{},"9":{}},"hours":{"0":{},"1":{},"10":{},"11":{},"12":{},"13":{},"14":{},"15":{},"16":{},"17":{},"18":{},"19":{},"2":{},"20":{},"21":{},"22":{},"23":{},"3":{},"4":{},"5":{},"6":{},"7":{},"8":{},"9":{}},"days":{"1":{},"10":{},"11":{},"12":{},"13":{},"14":{},"15":{},"16":{},"17":{},"18":{},"19":{},"2":{},"20":{},"21":{},"22":{},"23":{},"24":{},"25":{},"26":{},"27":{},"28":{},"29":{},"3":{},"30":{},"31":{},"4":{},"5":{},"6":{},"7":{},"8":{},"9":{}},"months":{"1":{},"10":{},"11":{},"12":{},"2":{},"3":{},"4":{},"5":{},"6":{},"7":{},"8":{},"9":{}},"daysOfWeek":{"0":{}}}`,
},
{
"* * * * 6",
false,
`{"minutes":{"0":{},"1":{},"10":{},"11":{},"12":{},"13":{},"14":{},"15":{},"16":{},"17":{},"18":{},"19":{},"2":{},"20":{},"21":{},"22":{},"23":{},"24":{},"25":{},"26":{},"27":{},"28":{},"29":{},"3":{},"30":{},"31":{},"32":{},"33":{},"34":{},"35":{},"36":{},"37":{},"38":{},"39":{},"4":{},"40":{},"41":{},"42":{},"43":{},"44":{},"45":{},"46":{},"47":{},"48":{},"49":{},"5":{},"50":{},"51":{},"52":{},"53":{},"54":{},"55":{},"56":{},"57":{},"58":{},"59":{},"6":{},"7":{},"8":{},"9":{}},"hours":{"0":{},"1":{},"10":{},"11":{},"12":{},"13":{},"14":{},"15":{},"16":{},"17":{},"18":{},"19":{},"2":{},"20":{},"21":{},"22":{},"23":{},"3":{},"4":{},"5":{},"6":{},"7":{},"8":{},"9":{}},"days":{"1":{},"10":{},"11":{},"12":{},"13":{},"14":{},"15":{},"16":{},"17":{},"18":{},"19":{},"2":{},"20":{},"21":{},"22":{},"23":{},"24":{},"25":{},"26":{},"27":{},"28":{},"29":{},"3":{},"30":{},"31":{},"4":{},"5":{},"6":{},"7":{},"8":{},"9":{}},"months":{"1":{},"10":{},"11":{},"12":{},"2":{},"3":{},"4":{},"5":{},"6":{},"7":{},"8":{},"9":{}},"daysOfWeek":{"6":{}}}`,
},
{
"* * * * 1,2-5/2",
false,
`{"minutes":{"0":{},"1":{},"10":{},"11":{},"12":{},"13":{},"14":{},"15":{},"16":{},"17":{},"18":{},"19":{},"2":{},"20":{},"21":{},"22":{},"23":{},"24":{},"25":{},"26":{},"27":{},"28":{},"29":{},"3":{},"30":{},"31":{},"32":{},"33":{},"34":{},"35":{},"36":{},"37":{},"38":{},"39":{},"4":{},"40":{},"41":{},"42":{},"43":{},"44":{},"45":{},"46":{},"47":{},"48":{},"49":{},"5":{},"50":{},"51":{},"52":{},"53":{},"54":{},"55":{},"56":{},"57":{},"58":{},"59":{},"6":{},"7":{},"8":{},"9":{}},"hours":{"0":{},"1":{},"10":{},"11":{},"12":{},"13":{},"14":{},"15":{},"16":{},"17":{},"18":{},"19":{},"2":{},"20":{},"21":{},"22":{},"23":{},"3":{},"4":{},"5":{},"6":{},"7":{},"8":{},"9":{}},"days":{"1":{},"10":{},"11":{},"12":{},"13":{},"14":{},"15":{},"16":{},"17":{},"18":{},"19":{},"2":{},"20":{},"21":{},"22":{},"23":{},"24":{},"25":{},"26":{},"27":{},"28":{},"29":{},"3":{},"30":{},"31":{},"4":{},"5":{},"6":{},"7":{},"8":{},"9":{}},"months":{"1":{},"10":{},"11":{},"12":{},"2":{},"3":{},"4":{},"5":{},"6":{},"7":{},"8":{},"9":{}},"daysOfWeek":{"1":{},"2":{},"4":{}}}`,
},

// macros
{
"@yearly",
false,
`{"minutes":{"0":{}},"hours":{"0":{}},"days":{"1":{}},"months":{"1":{}},"daysOfWeek":{"0":{},"1":{},"2":{},"3":{},"4":{},"5":{},"6":{}}}`,
},
{
"@annually",
false,
`{"minutes":{"0":{}},"hours":{"0":{}},"days":{"1":{}},"months":{"1":{}},"daysOfWeek":{"0":{},"1":{},"2":{},"3":{},"4":{},"5":{},"6":{}}}`,
},
{
"@monthly",
false,
`{"minutes":{"0":{}},"hours":{"0":{}},"days":{"1":{}},"months":{"1":{},"10":{},"11":{},"12":{},"2":{},"3":{},"4":{},"5":{},"6":{},"7":{},"8":{},"9":{}},"daysOfWeek":{"0":{},"1":{},"2":{},"3":{},"4":{},"5":{},"6":{}}}`,
},
{
"@weekly",
false,
`{"minutes":{"0":{}},"hours":{"0":{}},"days":{"1":{},"10":{},"11":{},"12":{},"13":{},"14":{},"15":{},"16":{},"17":{},"18":{},"19":{},"2":{},"20":{},"21":{},"22":{},"23":{},"24":{},"25":{},"26":{},"27":{},"28":{},"29":{},"3":{},"30":{},"31":{},"4":{},"5":{},"6":{},"7":{},"8":{},"9":{}},"months":{"1":{},"10":{},"11":{},"12":{},"2":{},"3":{},"4":{},"5":{},"6":{},"7":{},"8":{},"9":{}},"daysOfWeek":{"0":{}}}`,
},
{
"@daily",
false,
`{"minutes":{"0":{}},"hours":{"0":{}},"days":{"1":{},"10":{},"11":{},"12":{},"13":{},"14":{},"15":{},"16":{},"17":{},"18":{},"19":{},"2":{},"20":{},"21":{},"22":{},"23":{},"24":{},"25":{},"26":{},"27":{},"28":{},"29":{},"3":{},"30":{},"31":{},"4":{},"5":{},"6":{},"7":{},"8":{},"9":{}},"months":{"1":{},"10":{},"11":{},"12":{},"2":{},"3":{},"4":{},"5":{},"6":{},"7":{},"8":{},"9":{}},"daysOfWeek":{"0":{},"1":{},"2":{},"3":{},"4":{},"5":{},"6":{}}}`,
},
{
"@midnight",
false,
`{"minutes":{"0":{}},"hours":{"0":{}},"days":{"1":{},"10":{},"11":{},"12":{},"13":{},"14":{},"15":{},"16":{},"17":{},"18":{},"19":{},"2":{},"20":{},"21":{},"22":{},"23":{},"24":{},"25":{},"26":{},"27":{},"28":{},"29":{},"3":{},"30":{},"31":{},"4":{},"5":{},"6":{},"7":{},"8":{},"9":{}},"months":{"1":{},"10":{},"11":{},"12":{},"2":{},"3":{},"4":{},"5":{},"6":{},"7":{},"8":{},"9":{}},"daysOfWeek":{"0":{},"1":{},"2":{},"3":{},"4":{},"5":{},"6":{}}}`,
},
{
"@hourly",
false,
`{"minutes":{"0":{}},"hours":{"0":{},"1":{},"10":{},"11":{},"12":{},"13":{},"14":{},"15":{},"16":{},"17":{},"18":{},"19":{},"2":{},"20":{},"21":{},"22":{},"23":{},"3":{},"4":{},"5":{},"6":{},"7":{},"8":{},"9":{}},"days":{"1":{},"10":{},"11":{},"12":{},"13":{},"14":{},"15":{},"16":{},"17":{},"18":{},"19":{},"2":{},"20":{},"21":{},"22":{},"23":{},"24":{},"25":{},"26":{},"27":{},"28":{},"29":{},"3":{},"30":{},"31":{},"4":{},"5":{},"6":{},"7":{},"8":{},"9":{}},"months":{"1":{},"10":{},"11":{},"12":{},"2":{},"3":{},"4":{},"5":{},"6":{},"7":{},"8":{},"9":{}},"daysOfWeek":{"0":{},"1":{},"2":{},"3":{},"4":{},"5":{},"6":{}}}`,
},
}

for _, s := range scenarios {
t.Run(s.cronExpr, func(t *testing.T) {
schedule, err := cron.NewSchedule(s.cronExpr)

hasErr := err != nil
if hasErr != s.expectError {
t.Fatalf("Expected hasErr to be %v, got %v (%v)", s.expectError, hasErr, err)
}

if hasErr {
return
}

encoded, err := json.Marshal(schedule)
if err != nil {
t.Fatalf("Failed to marshal the result schedule: %v", err)
}
encodedStr := string(encoded)

if encodedStr != s.expectSchedule {
t.Fatalf("Expected \n%s, \ngot \n%s", s.expectSchedule, encodedStr)
}
})
}
}

func TestScheduleIsDue(t *testing.T) {
t.Parallel()

scenarios := []struct {
cronExpr string
moment *cron.Moment
expected bool
}{
{
"* * * * *",
&cron.Moment{},
false,
},
{
"* * * * *",
&cron.Moment{
Minute: 1,
Hour: 1,
Day: 1,
Month: 1,
DayOfWeek: 1,
},
true,
},
{
"5 * * * *",
&cron.Moment{
Minute: 1,
Hour: 1,
Day: 1,
Month: 1,
DayOfWeek: 1,
},
false,
},
{
"5 * * * *",
&cron.Moment{
Minute: 5,
Hour: 1,
Day: 1,
Month: 1,
DayOfWeek: 1,
},
true,
},
{
"* 2-6 * * 2,3",
&cron.Moment{
Minute: 1,
Hour: 2,
Day: 1,
Month: 1,
DayOfWeek: 1,
},
false,
},
{
"* 2-6 * * 2,3",
&cron.Moment{
Minute: 1,
Hour: 2,
Day: 1,
Month: 1,
DayOfWeek: 3,
},
true,
},
{
"* * 1,2,5,15-18 * *",
&cron.Moment{
Minute: 1,
Hour: 1,
Day: 6,
Month: 1,
DayOfWeek: 1,
},
false,
},
{
"* * 1,2,5,15-18/2 * *",
&cron.Moment{
Minute: 1,
Hour: 1,
Day: 2,
Month: 1,
DayOfWeek: 1,
},
true,
},
{
"* * 1,2,5,15-18/2 * *",
&cron.Moment{
Minute: 1,
Hour: 1,
Day: 18,
Month: 1,
DayOfWeek: 1,
},
false,
},
{
"* * 1,2,5,15-18/2 * *",
&cron.Moment{
Minute: 1,
Hour: 1,
Day: 17,
Month: 1,
DayOfWeek: 1,
},
true,
},
}

for i, s := range scenarios {
t.Run(fmt.Sprintf("%d-%s", i, s.cronExpr), func(t *testing.T) {
schedule, err := cron.NewSchedule(s.cronExpr)
if err != nil {
t.Fatalf("Unexpected cron error: %v", err)
}

result := schedule.IsDue(s.moment)

if result != s.expected {
t.Fatalf("Expected %v, got %v", s.expected, result)
}
})
}
}
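
For context, a minimal usage sketch of the schedule API exercised by the tests above (not part of the diff; the cron expression and Moment values are illustrative):

package main

import (
	"fmt"

	"github.com/pocketbase/pocketbase/tools/cron"
)

func main() {
	// "5 * * * *" matches only when the minute is 5
	schedule, err := cron.NewSchedule("5 * * * *")
	if err != nil {
		panic(err)
	}

	// check whether the schedule matches a specific point in time
	due := schedule.IsDue(&cron.Moment{Minute: 5, Hour: 1, Day: 1, Month: 1, DayOfWeek: 1})
	fmt.Println(due) // true
}
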
217
tools/dbutils/index.go
Normal file
@ -0,0 +1,217 @@
package dbutils

import (
"regexp"
"strings"

"github.com/pocketbase/pocketbase/tools/tokenizer"
)

var (
indexRegex = regexp.MustCompile(`(?im)create\s+(unique\s+)?\s*index\s*(if\s+not\s+exists\s+)?(\S*)\s+on\s+(\S*)\s*\(([\s\S]*)\)(?:\s*where\s+([\s\S]*))?`)
indexColumnRegex = regexp.MustCompile(`(?im)^([\s\S]+?)(?:\s+collate\s+([\w]+))?(?:\s+(asc|desc))?$`)
)

// IndexColumn represents a single parsed SQL index column.
type IndexColumn struct {
Name string `json:"name"` // identifier or expression
Collate string `json:"collate"`
Sort string `json:"sort"`
}

// Index represents a single parsed SQL CREATE INDEX expression.
type Index struct {
SchemaName string `json:"schemaName"`
IndexName string `json:"indexName"`
TableName string `json:"tableName"`
Where string `json:"where"`
Columns []IndexColumn `json:"columns"`
Unique bool `json:"unique"`
Optional bool `json:"optional"`
}

// IsValid checks if the current Index contains the minimum required fields to be considered valid.
func (idx Index) IsValid() bool {
return idx.IndexName != "" && idx.TableName != "" && len(idx.Columns) > 0
}

// Build returns a "CREATE INDEX" SQL string from the current index parts.
//
// Returns an empty string if idx.IsValid() is false.
func (idx Index) Build() string {
if !idx.IsValid() {
return ""
}

var str strings.Builder

str.WriteString("CREATE ")

if idx.Unique {
str.WriteString("UNIQUE ")
}

str.WriteString("INDEX ")

if idx.Optional {
str.WriteString("IF NOT EXISTS ")
}

if idx.SchemaName != "" {
str.WriteString("`")
str.WriteString(idx.SchemaName)
str.WriteString("`.")
}

str.WriteString("`")
str.WriteString(idx.IndexName)
str.WriteString("` ")

str.WriteString("ON `")
str.WriteString(idx.TableName)
str.WriteString("` (")

if len(idx.Columns) > 1 {
str.WriteString("\n ")
}

var hasCol bool
for _, col := range idx.Columns {
trimmedColName := strings.TrimSpace(col.Name)
if trimmedColName == "" {
continue
}

if hasCol {
str.WriteString(",\n ")
}

if strings.Contains(col.Name, "(") || strings.Contains(col.Name, " ") {
// most likely an expression
str.WriteString(trimmedColName)
} else {
// regular identifier
str.WriteString("`")
str.WriteString(trimmedColName)
str.WriteString("`")
}

if col.Collate != "" {
str.WriteString(" COLLATE ")
str.WriteString(col.Collate)
}

if col.Sort != "" {
str.WriteString(" ")
str.WriteString(strings.ToUpper(col.Sort))
}

hasCol = true
}

if hasCol && len(idx.Columns) > 1 {
str.WriteString("\n")
}

str.WriteString(")")

if idx.Where != "" {
str.WriteString(" WHERE ")
str.WriteString(idx.Where)
}

return str.String()
}

// ParseIndex parses the provided "CREATE INDEX" SQL string into Index struct.
func ParseIndex(createIndexExpr string) Index {
result := Index{}

matches := indexRegex.FindStringSubmatch(createIndexExpr)
if len(matches) != 7 {
return result
}

trimChars := "`\"'[]\r\n\t\f\v "

// Unique
// ---
result.Unique = strings.TrimSpace(matches[1]) != ""

// Optional (aka. "IF NOT EXISTS")
// ---
result.Optional = strings.TrimSpace(matches[2]) != ""

// SchemaName and IndexName
// ---
nameTk := tokenizer.NewFromString(matches[3])
nameTk.Separators('.')

nameParts, _ := nameTk.ScanAll()
if len(nameParts) == 2 {
result.SchemaName = strings.Trim(nameParts[0], trimChars)
result.IndexName = strings.Trim(nameParts[1], trimChars)
} else {
result.IndexName = strings.Trim(nameParts[0], trimChars)
}

// TableName
// ---
result.TableName = strings.Trim(matches[4], trimChars)

// Columns
// ---
columnsTk := tokenizer.NewFromString(matches[5])
columnsTk.Separators(',')

rawColumns, _ := columnsTk.ScanAll()

result.Columns = make([]IndexColumn, 0, len(rawColumns))

for _, col := range rawColumns {
colMatches := indexColumnRegex.FindStringSubmatch(col)
if len(colMatches) != 4 {
continue
}

trimmedName := strings.Trim(colMatches[1], trimChars)
if trimmedName == "" {
continue
}

result.Columns = append(result.Columns, IndexColumn{
Name: trimmedName,
Collate: strings.TrimSpace(colMatches[2]),
Sort: strings.ToUpper(colMatches[3]),
})
}

// WHERE expression
// ---
result.Where = strings.TrimSpace(matches[6])

return result
}

// FindSingleColumnUniqueIndex returns the first matching single column unique index.
func FindSingleColumnUniqueIndex(indexes []string, column string) (Index, bool) {
var index Index

for _, idx := range indexes {
index := ParseIndex(idx)
if index.Unique && len(index.Columns) == 1 && strings.EqualFold(index.Columns[0].Name, column) {
return index, true
}
}

return index, false
}

// Deprecated: Use `_, ok := FindSingleColumnUniqueIndex(indexes, column)` instead.
//
// HasSingleColumnUniqueIndex loosely checks whether the specified column has
// a single column unique index (WHERE statements are ignored).
func HasSingleColumnUniqueIndex(column string, indexes []string) bool {
_, ok := FindSingleColumnUniqueIndex(indexes, column)
return ok
}
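
A short usage sketch of the helpers above (not part of the diff; the index, table and column names are made up):

package main

import (
	"fmt"

	"github.com/pocketbase/pocketbase/tools/dbutils"
)

func main() {
	// parse a raw CREATE INDEX statement into its parts
	idx := dbutils.ParseIndex("CREATE UNIQUE INDEX `idx_email` ON `users` (`email`)")
	fmt.Println(idx.Unique, idx.TableName, idx.Columns[0].Name) // true users email

	// rebuild a normalized SQL string from the parsed parts
	fmt.Println(idx.Build())

	// look for a single column unique index among raw definitions
	if _, ok := dbutils.FindSingleColumnUniqueIndex([]string{idx.Build()}, "email"); ok {
		fmt.Println("email has a single column unique index")
	}
}
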
405
tools/dbutils/index_test.go
Normal file
@ -0,0 +1,405 @@
package dbutils_test

import (
"bytes"
"encoding/json"
"fmt"
"strings"
"testing"

"github.com/pocketbase/pocketbase/tools/dbutils"
)

func TestParseIndex(t *testing.T) {
scenarios := []struct {
index string
expected dbutils.Index
}{
// invalid
{
`invalid`,
dbutils.Index{},
},
// simple (multiple spaces between the table and columns list)
{
`create index indexname on tablename (col1)`,
dbutils.Index{
IndexName: "indexname",
TableName: "tablename",
Columns: []dbutils.IndexColumn{
{Name: "col1"},
},
},
},
// simple (no space between the table and the columns list)
{
`create index indexname on tablename(col1)`,
dbutils.Index{
IndexName: "indexname",
TableName: "tablename",
Columns: []dbutils.IndexColumn{
{Name: "col1"},
},
},
},
// all fields
{
`CREATE UNIQUE INDEX IF NOT EXISTS "schemaname".[indexname] on 'tablename' (
col0,
` + "`" + `col1` + "`" + `,
json_extract("col2", "$.a") asc,
"col3" collate NOCASE,
"col4" collate RTRIM desc
) where test = 1`,
dbutils.Index{
Unique: true,
Optional: true,
SchemaName: "schemaname",
IndexName: "indexname",
TableName: "tablename",
Columns: []dbutils.IndexColumn{
{Name: "col0"},
{Name: "col1"},
{Name: `json_extract("col2", "$.a")`, Sort: "ASC"},
{Name: `col3`, Collate: "NOCASE"},
{Name: `col4`, Collate: "RTRIM", Sort: "DESC"},
},
Where: "test = 1",
},
},
}

for i, s := range scenarios {
t.Run(fmt.Sprintf("scenario_%d", i), func(t *testing.T) {
result := dbutils.ParseIndex(s.index)

resultRaw, err := json.Marshal(result)
if err != nil {
t.Fatalf("Failed to marshal parse result: %v", err)
}

expectedRaw, err := json.Marshal(s.expected)
if err != nil {
t.Fatalf("Failed to marshal expected index: %v", err)
}

if !bytes.Equal(resultRaw, expectedRaw) {
t.Errorf("Expected \n%s \ngot \n%s", expectedRaw, resultRaw)
}
})
}
}

func TestIndexIsValid(t *testing.T) {
scenarios := []struct {
name string
index dbutils.Index
expected bool
}{
{
"empty",
dbutils.Index{},
false,
},
{
"no index name",
dbutils.Index{
TableName: "table",
Columns: []dbutils.IndexColumn{{Name: "col"}},
},
false,
},
{
"no table name",
dbutils.Index{
IndexName: "index",
Columns: []dbutils.IndexColumn{{Name: "col"}},
},
false,
},
{
"no columns",
dbutils.Index{
IndexName: "index",
TableName: "table",
},
false,
},
{
"min valid",
dbutils.Index{
IndexName: "index",
TableName: "table",
Columns: []dbutils.IndexColumn{{Name: "col"}},
},
true,
},
{
"all fields",
dbutils.Index{
Optional: true,
Unique: true,
SchemaName: "schema",
IndexName: "index",
TableName: "table",
Columns: []dbutils.IndexColumn{{Name: "col"}},
Where: "test = 1 OR test = 2",
},
true,
},
}

for _, s := range scenarios {
t.Run(s.name, func(t *testing.T) {
result := s.index.IsValid()
if result != s.expected {
t.Fatalf("Expected %v, got %v", s.expected, result)
}
})
}
}

func TestIndexBuild(t *testing.T) {
scenarios := []struct {
name string
index dbutils.Index
expected string
}{
{
"empty",
dbutils.Index{},
"",
},
{
"no index name",
dbutils.Index{
TableName: "table",
Columns: []dbutils.IndexColumn{{Name: "col"}},
},
"",
},
{
"no table name",
dbutils.Index{
IndexName: "index",
Columns: []dbutils.IndexColumn{{Name: "col"}},
},
"",
},
{
"no columns",
dbutils.Index{
IndexName: "index",
TableName: "table",
},
"",
},
{
"min valid",
dbutils.Index{
IndexName: "index",
TableName: "table",
Columns: []dbutils.IndexColumn{{Name: "col"}},
},
"CREATE INDEX `index` ON `table` (`col`)",
},
{
"all fields",
dbutils.Index{
Optional: true,
Unique: true,
SchemaName: "schema",
IndexName: "index",
TableName: "table",
Columns: []dbutils.IndexColumn{
{Name: "col1", Collate: "NOCASE", Sort: "asc"},
{Name: "col2", Sort: "desc"},
{Name: `json_extract("col3", "$.a")`, Collate: "NOCASE"},
},
Where: "test = 1 OR test = 2",
},
"CREATE UNIQUE INDEX IF NOT EXISTS `schema`.`index` ON `table` (\n `col1` COLLATE NOCASE ASC,\n `col2` DESC,\n " + `json_extract("col3", "$.a")` + " COLLATE NOCASE\n) WHERE test = 1 OR test = 2",
},
}

for _, s := range scenarios {
t.Run(s.name, func(t *testing.T) {
result := s.index.Build()
if result != s.expected {
t.Fatalf("Expected \n%v \ngot \n%v", s.expected, result)
}
})
}
}

func TestHasSingleColumnUniqueIndex(t *testing.T) {
scenarios := []struct {
name string
column string
indexes []string
expected bool
}{
{
"empty indexes",
"test",
nil,
false,
},
{
"empty column",
"",
[]string{
"CREATE UNIQUE INDEX `index1` ON `example` (`test`)",
},
false,
},
{
"mismatched column",
"test",
[]string{
"CREATE UNIQUE INDEX `index1` ON `example` (`test2`)",
},
false,
},
{
"non unique index",
"test",
[]string{
"CREATE INDEX `index1` ON `example` (`test`)",
},
false,
},
{
"matching column and unique index",
"test",
[]string{
"CREATE UNIQUE INDEX `index1` ON `example` (`test`)",
},
true,
},
{
"multiple columns",
"test",
[]string{
"CREATE UNIQUE INDEX `index1` ON `example` (`test`, `test2`)",
},
false,
},
{
"multiple indexes",
"test",
[]string{
"CREATE UNIQUE INDEX `index1` ON `example` (`test`, `test2`)",
"CREATE UNIQUE INDEX `index2` ON `example` (`test`)",
},
true,
},
{
"partial unique index",
"test",
[]string{
"CREATE UNIQUE INDEX `index` ON `example` (`test`) where test != ''",
},
true,
},
}

for _, s := range scenarios {
t.Run(s.name, func(t *testing.T) {
result := dbutils.HasSingleColumnUniqueIndex(s.column, s.indexes)
if result != s.expected {
t.Fatalf("Expected %v got %v", s.expected, result)
}
})
}
}

func TestFindSingleColumnUniqueIndex(t *testing.T) {
scenarios := []struct {
name string
column string
indexes []string
expected bool
}{
{
"empty indexes",
"test",
nil,
false,
},
{
"empty column",
"",
[]string{
"CREATE UNIQUE INDEX `index1` ON `example` (`test`)",
},
false,
},
{
"mismatched column",
"test",
[]string{
"CREATE UNIQUE INDEX `index1` ON `example` (`test2`)",
},
false,
},
{
"non unique index",
"test",
[]string{
"CREATE INDEX `index1` ON `example` (`test`)",
},
false,
},
{
"matching column and unique index",
"test",
[]string{
"CREATE UNIQUE INDEX `index1` ON `example` (`test`)",
},
true,
},
{
"multiple columns",
"test",
[]string{
"CREATE UNIQUE INDEX `index1` ON `example` (`test`, `test2`)",
},
false,
},
{
"multiple indexes",
"test",
[]string{
"CREATE UNIQUE INDEX `index1` ON `example` (`test`, `test2`)",
"CREATE UNIQUE INDEX `index2` ON `example` (`test`)",
},
true,
},
{
"partial unique index",
"test",
[]string{
"CREATE UNIQUE INDEX `index` ON `example` (`test`) where test != ''",
},
true,
},
}

for _, s := range scenarios {
t.Run(s.name, func(t *testing.T) {
index, exists := dbutils.FindSingleColumnUniqueIndex(s.indexes, s.column)
if exists != s.expected {
t.Fatalf("Expected exists %v got %v", s.expected, exists)
}

if !exists && len(index.Columns) > 0 {
t.Fatal("Expected index.Columns to be empty")
}

if exists && !strings.EqualFold(index.Columns[0].Name, s.column) {
t.Fatalf("Expected to find column %q in %v", s.column, index)
}
})
}
}
51
tools/dbutils/json.go
Normal file
@ -0,0 +1,51 @@
package dbutils

import (
"fmt"
"strings"
)

// JSONEach returns a JSON_EACH SQLite string expression with
// some normalizations for non-json columns.
func JSONEach(column string) string {
// note: we are not using the new and shorter "if(x,y)" syntax for
// compatibility with custom drivers that use an older SQLite version
return fmt.Sprintf(
`json_each(CASE WHEN iif(json_valid([[%s]]), json_type([[%s]])='array', FALSE) THEN [[%s]] ELSE json_array([[%s]]) END)`,
column, column, column, column,
)
}

// JSONArrayLength returns a JSON_ARRAY_LENGTH SQLite string expression
// with some normalizations for non-json columns.
//
// It works with both json and non-json column values.
//
// Returns 0 for empty string or NULL column values.
func JSONArrayLength(column string) string {
// note: we are not using the new and shorter "if(x,y)" syntax for
// compatibility with custom drivers that use an older SQLite version
return fmt.Sprintf(
`json_array_length(CASE WHEN iif(json_valid([[%s]]), json_type([[%s]])='array', FALSE) THEN [[%s]] ELSE (CASE WHEN [[%s]] = '' OR [[%s]] IS NULL THEN json_array() ELSE json_array([[%s]]) END) END)`,
column, column, column, column, column, column,
)
}

// JSONExtract returns a JSON_EXTRACT SQLite string expression with
// some normalizations for non-json columns.
func JSONExtract(column string, path string) string {
// prefix the path with dot if it is not starting with array notation
if path != "" && !strings.HasPrefix(path, "[") {
path = "." + path
}

return fmt.Sprintf(
// note: the extra object wrapping is needed to work around the cases where a json_extract is used with non-json columns.
"(CASE WHEN json_valid([[%s]]) THEN JSON_EXTRACT([[%s]], '$%s') ELSE JSON_EXTRACT(json_object('pb', [[%s]]), '$.pb%s') END)",
column,
column,
path,
column,
path,
)
}
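
A small sketch of what these helpers emit (not part of the diff; the column names are made up, and the [[...]] markers in the output are dbx-style column placeholders resolved by the query builder):

package main

import (
	"fmt"

	"github.com/pocketbase/pocketbase/tools/dbutils"
)

func main() {
	// each helper returns a SQL fragment that tolerates
	// both json and plain (non-json) column values
	fmt.Println(dbutils.JSONEach("tags"))
	fmt.Println(dbutils.JSONArrayLength("tags"))
	fmt.Println(dbutils.JSONExtract("data", "a.b[2].c"))
}
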
65
tools/dbutils/json_test.go
Normal file
@ -0,0 +1,65 @@
package dbutils_test

import (
"testing"

"github.com/pocketbase/pocketbase/tools/dbutils"
)

func TestJSONEach(t *testing.T) {
result := dbutils.JSONEach("a.b")

expected := "json_each(CASE WHEN iif(json_valid([[a.b]]), json_type([[a.b]])='array', FALSE) THEN [[a.b]] ELSE json_array([[a.b]]) END)"

if result != expected {
t.Fatalf("Expected\n%v\ngot\n%v", expected, result)
}
}

func TestJSONArrayLength(t *testing.T) {
result := dbutils.JSONArrayLength("a.b")

expected := "json_array_length(CASE WHEN iif(json_valid([[a.b]]), json_type([[a.b]])='array', FALSE) THEN [[a.b]] ELSE (CASE WHEN [[a.b]] = '' OR [[a.b]] IS NULL THEN json_array() ELSE json_array([[a.b]]) END) END)"

if result != expected {
t.Fatalf("Expected\n%v\ngot\n%v", expected, result)
}
}

func TestJSONExtract(t *testing.T) {
scenarios := []struct {
name string
column string
path string
expected string
}{
{
"empty path",
"a.b",
"",
"(CASE WHEN json_valid([[a.b]]) THEN JSON_EXTRACT([[a.b]], '$') ELSE JSON_EXTRACT(json_object('pb', [[a.b]]), '$.pb') END)",
},
{
"starting with array index",
"a.b",
"[1].a[2]",
"(CASE WHEN json_valid([[a.b]]) THEN JSON_EXTRACT([[a.b]], '$[1].a[2]') ELSE JSON_EXTRACT(json_object('pb', [[a.b]]), '$.pb[1].a[2]') END)",
},
{
"starting with key",
"a.b",
"a.b[2].c",
"(CASE WHEN json_valid([[a.b]]) THEN JSON_EXTRACT([[a.b]], '$.a.b[2].c') ELSE JSON_EXTRACT(json_object('pb', [[a.b]]), '$.pb.a.b[2].c') END)",
},
}

for _, s := range scenarios {
t.Run(s.name, func(t *testing.T) {
result := dbutils.JSONExtract(s.column, s.path)

if result != s.expected {
t.Fatalf("Expected\n%v\ngot\n%v", s.expected, result)
}
})
}
}
736
tools/filesystem/blob/bucket.go
Normal file
@ -0,0 +1,736 @@
// Package blob defines a lightweight abstraction for interacting with
// various storage services (local filesystem, S3, etc.).
//
// NB!
// For compatibility with earlier PocketBase versions and to prevent
// unnecessary breaking changes, this package is based on and implemented
// as a minimal, stripped-down version of the previously used gocloud.dev/blob.
// While there is no promise that it won't diverge in the future to better
// accommodate some PocketBase specific use cases, currently it copies and
// tries to follow as closely as possible the same implementations,
// conventions and rules for the key escaping/unescaping, blob read/write
// interfaces and struct options as gocloud.dev/blob, therefore the
// credit goes to the original Go Cloud Development Kit Authors.
package blob

import (
"bytes"
"context"
"crypto/md5"
"errors"
"fmt"
"io"
"log"
"mime"
"runtime"
"strings"
"sync"
"time"
"unicode/utf8"
)

var (
ErrNotFound = errors.New("resource not found")
ErrClosed = errors.New("bucket or blob is closed")
)

// Bucket provides an easy and portable way to interact with blobs
// within a "bucket", including read, write, and list operations.
// To create a Bucket, use constructors found in driver subpackages.
type Bucket struct {
drv Driver

// mu protects the closed variable.
// Read locks are kept to allow holding a read lock for long-running calls,
// and thereby prevent closing until a call finishes.
mu sync.RWMutex
closed bool
}

// NewBucket creates a new *Bucket based on a specific driver implementation.
func NewBucket(drv Driver) *Bucket {
return &Bucket{drv: drv}
}

// ListOptions sets options for listing blobs via Bucket.List.
type ListOptions struct {
// Prefix indicates that only blobs with a key starting with this prefix
// should be returned.
Prefix string

// Delimiter sets the delimiter used to define a hierarchical namespace,
// like a filesystem with "directories". It is highly recommended that you
// use "" or "/" as the Delimiter. Other values should work through this API,
// but service UIs generally assume "/".
//
// An empty delimiter means that the bucket is treated as a single flat
// namespace.
//
// A non-empty delimiter means that any result with the delimiter in its key
// after Prefix is stripped will be returned with ListObject.IsDir = true,
// ListObject.Key truncated after the delimiter, and zero values for other
// ListObject fields. These results represent "directories". Multiple results
// in a "directory" are returned as a single result.
Delimiter string

// PageSize sets the maximum number of objects to be returned.
// 0 means no maximum; driver implementations should choose a reasonable
// max. It is guaranteed to be >= 0.
PageSize int

// PageToken may be filled in with the NextPageToken from a previous
// ListPaged call.
PageToken []byte
}

// ListPage represents a page of results returned from ListPaged.
type ListPage struct {
// Objects is the slice of objects found. If ListOptions.PageSize > 0,
// it should have at most ListOptions.PageSize entries.
//
// Objects should be returned in lexicographical order of UTF-8 encoded keys,
// including across pages. I.e., all objects returned from a ListPage request
// made using a PageToken from a previous ListPage request's NextPageToken
// should have Key >= the Key for all objects from the previous request.
Objects []*ListObject `json:"objects"`

// NextPageToken should be left empty unless there are more objects
// to return. The value may be returned as ListOptions.PageToken on a
// subsequent ListPaged call, to fetch the next page of results.
// It can be an arbitrary []byte; it need not be a valid key.
NextPageToken []byte `json:"nextPageToken"`
}

// ListIterator iterates over List results.
type ListIterator struct {
b *Bucket
opts *ListOptions
page *ListPage
nextIdx int
}

// Next returns a *ListObject for the next blob.
// It returns (nil, io.EOF) if there are no more.
func (i *ListIterator) Next(ctx context.Context) (*ListObject, error) {
if i.page != nil {
// We've already got a page of results.
if i.nextIdx < len(i.page.Objects) {
// Next object is in the page; return it.
dobj := i.page.Objects[i.nextIdx]
i.nextIdx++
return &ListObject{
Key: dobj.Key,
ModTime: dobj.ModTime,
Size: dobj.Size,
MD5: dobj.MD5,
IsDir: dobj.IsDir,
}, nil
}

if len(i.page.NextPageToken) == 0 {
// Done with current page, and there are no more; return io.EOF.
return nil, io.EOF
}

// We need to load the next page.
i.opts.PageToken = i.page.NextPageToken
}

i.b.mu.RLock()
defer i.b.mu.RUnlock()

if i.b.closed {
return nil, ErrClosed
}

// Loading a new page.
p, err := i.b.drv.ListPaged(ctx, i.opts)
if err != nil {
return nil, wrapError(i.b.drv, err, "")
}

i.page = p
i.nextIdx = 0

return i.Next(ctx)
}

// ListObject represents a single blob returned from List.
type ListObject struct {
// Key is the key for this blob.
Key string `json:"key"`

// ModTime is the time the blob was last modified.
ModTime time.Time `json:"modTime"`

// Size is the size of the blob's content in bytes.
Size int64 `json:"size"`

// MD5 is an MD5 hash of the blob contents or nil if not available.
MD5 []byte `json:"md5"`

// IsDir indicates that this result represents a "directory" in the
// hierarchical namespace, ending in ListOptions.Delimiter. Key can be
// passed as ListOptions.Prefix to list items in the "directory".
// Fields other than Key and IsDir will not be set if IsDir is true.
IsDir bool `json:"isDir"`
}

// List returns a ListIterator that can be used to iterate over blobs in a
// bucket, in lexicographical order of UTF-8 encoded keys. The underlying
// implementation fetches results in pages.
//
// A nil ListOptions is treated the same as the zero value.
//
// List is not guaranteed to include all recently-written blobs;
// some services are only eventually consistent.
func (b *Bucket) List(opts *ListOptions) *ListIterator {
if opts == nil {
opts = &ListOptions{}
}

dopts := &ListOptions{
Prefix: opts.Prefix,
Delimiter: opts.Delimiter,
}

return &ListIterator{b: b, opts: dopts}
}
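
A minimal iteration sketch for List (not part of the diff; the context, bucket value and prefix are assumed to come from the caller):

package example

import (
	"context"
	"fmt"
	"io"

	"github.com/pocketbase/pocketbase/tools/filesystem/blob"
)

// listAll prints the key and size of every blob under prefix.
func listAll(ctx context.Context, bkt *blob.Bucket, prefix string) error {
	iter := bkt.List(&blob.ListOptions{Prefix: prefix})
	for {
		obj, err := iter.Next(ctx)
		if err == io.EOF {
			break // no more results
		}
		if err != nil {
			return err
		}
		fmt.Println(obj.Key, obj.Size)
	}
	return nil
}
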

// FirstPageToken is the pageToken to pass to ListPage to retrieve the first page of results.
var FirstPageToken = []byte("first page")

// ListPage returns a page of ListObject results for blobs in a bucket, in lexicographical
// order of UTF-8 encoded keys.
//
// To fetch the first page, pass FirstPageToken as the pageToken. For subsequent pages, pass
// the pageToken returned from a previous call to ListPage.
// It is not possible to "skip ahead" pages.
//
// Each call will return pageSize results, unless there are not enough blobs to fill the
// page, in which case it will return fewer results (possibly 0).
//
// If there are no more blobs available, ListPage will return an empty pageToken. Note that
// this may happen regardless of the number of returned results -- the last page might have
// 0 results (i.e., if the last item was deleted), pageSize results, or anything in between.
//
// Calling ListPage with an empty pageToken will immediately return io.EOF. When looping
// over pages, callers can either check for an empty pageToken, or they can make one more
// call and check for io.EOF.
//
// The underlying implementation fetches results in pages, but one call to ListPage may
// require multiple page fetches.
//
// A nil ListOptions is treated the same as the zero value.
//
// ListPage is not guaranteed to include all recently-written blobs;
// some services are only eventually consistent.
func (b *Bucket) ListPage(ctx context.Context, pageToken []byte, pageSize int, opts *ListOptions) (retval []*ListObject, nextPageToken []byte, err error) {
if opts == nil {
opts = &ListOptions{}
}
if pageSize <= 0 {
return nil, nil, fmt.Errorf("pageSize must be > 0 (%d)", pageSize)
}

// Nil pageToken means no more results.
if len(pageToken) == 0 {
return nil, nil, io.EOF
}

// FirstPageToken fetches the first page. Drivers use nil.
// The public API doesn't use nil for the first page because it would be too easy to
// keep fetching forever (since the last page returns nil for the next pageToken).
if bytes.Equal(pageToken, FirstPageToken) {
pageToken = nil
}

b.mu.RLock()
defer b.mu.RUnlock()

if b.closed {
return nil, nil, ErrClosed
}

dopts := &ListOptions{
Prefix: opts.Prefix,
Delimiter: opts.Delimiter,
PageToken: pageToken,
PageSize: pageSize,
}

retval = make([]*ListObject, 0, pageSize)
for len(retval) < pageSize {
p, err := b.drv.ListPaged(ctx, dopts)
if err != nil {
return nil, nil, wrapError(b.drv, err, "")
}

for _, dobj := range p.Objects {
retval = append(retval, &ListObject{
Key: dobj.Key,
ModTime: dobj.ModTime,
Size: dobj.Size,
MD5: dobj.MD5,
IsDir: dobj.IsDir,
})
}

// ListPaged may return fewer results than pageSize. If there are more results
// available, signalled by non-empty p.NextPageToken, try to fetch the remainder
// of the page.
// It does not work to ask for more results than we need, because then we'd have
// a NextPageToken on a non-page boundary.
dopts.PageSize = pageSize - len(retval)
dopts.PageToken = p.NextPageToken
if len(dopts.PageToken) == 0 {
dopts.PageToken = nil
break
}
}

return retval, dopts.PageToken, nil
}
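
A manual paging sketch for ListPage (not part of the diff; the page size of 100 and the bucket value are illustrative):

package example

import (
	"context"
	"fmt"

	"github.com/pocketbase/pocketbase/tools/filesystem/blob"
)

// pageThrough fetches results 100 at a time until the token runs out.
func pageThrough(ctx context.Context, bkt *blob.Bucket) error {
	token := blob.FirstPageToken
	for {
		objects, next, err := bkt.ListPage(ctx, token, 100, nil)
		if err != nil {
			return err
		}
		for _, obj := range objects {
			fmt.Println(obj.Key)
		}
		if len(next) == 0 {
			break // an empty token signals the last page
		}
		token = next
	}
	return nil
}
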

// Attributes contains attributes about a blob.
type Attributes struct {
// CacheControl specifies caching attributes that services may use
// when serving the blob.
// https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Cache-Control
CacheControl string `json:"cacheControl"`

// ContentDisposition specifies whether the blob content is expected to be
// displayed inline or as an attachment.
// https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Content-Disposition
ContentDisposition string `json:"contentDisposition"`

// ContentEncoding specifies the encoding used for the blob's content, if any.
// https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Content-Encoding
ContentEncoding string `json:"contentEncoding"`

// ContentLanguage specifies the language used in the blob's content, if any.
// https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Content-Language
ContentLanguage string `json:"contentLanguage"`

// ContentType is the MIME type of the blob. It will not be empty.
// https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Content-Type
ContentType string `json:"contentType"`

// Metadata holds key/value pairs associated with the blob.
// Keys are guaranteed to be in lowercase, even if the backend service
// has case-sensitive keys (although note that Metadata written via
// this package will always be lowercased). If there are duplicate
// case-insensitive keys (e.g., "foo" and "FOO"), only one value
// will be kept, and it is undefined which one.
Metadata map[string]string `json:"metadata"`

// CreateTime is the time the blob was created, if available. If not available,
// CreateTime will be the zero time.
CreateTime time.Time `json:"createTime"`

// ModTime is the time the blob was last modified.
ModTime time.Time `json:"modTime"`

// Size is the size of the blob's content in bytes.
Size int64 `json:"size"`

// MD5 is an MD5 hash of the blob contents or nil if not available.
MD5 []byte `json:"md5"`

// ETag for the blob; see https://en.wikipedia.org/wiki/HTTP_ETag.
ETag string `json:"etag"`
}

// Attributes returns attributes for the blob stored at key.
//
// If the blob does not exist, Attributes returns an error wrapping
// ErrNotFound.
func (b *Bucket) Attributes(ctx context.Context, key string) (_ *Attributes, err error) {
if !utf8.ValidString(key) {
return nil, fmt.Errorf("Attributes key must be a valid UTF-8 string: %q", key)
}

b.mu.RLock()
defer b.mu.RUnlock()
if b.closed {
return nil, ErrClosed
}

a, err := b.drv.Attributes(ctx, key)
if err != nil {
return nil, wrapError(b.drv, err, key)
}

var md map[string]string
if len(a.Metadata) > 0 {
// Services are inconsistent, but at least some treat keys
// as case-insensitive. To make the behavior consistent, we
// force-lowercase them when writing and reading.
md = make(map[string]string, len(a.Metadata))
for k, v := range a.Metadata {
md[strings.ToLower(k)] = v
}
}

return &Attributes{
CacheControl: a.CacheControl,
ContentDisposition: a.ContentDisposition,
ContentEncoding: a.ContentEncoding,
ContentLanguage: a.ContentLanguage,
ContentType: a.ContentType,
Metadata: md,
CreateTime: a.CreateTime,
ModTime: a.ModTime,
Size: a.Size,
MD5: a.MD5,
ETag: a.ETag,
}, nil
}

// Exists returns true if a blob exists at key, false if it does not exist, or
// an error.
//
// It is a shortcut for calling Attributes and checking if it returns an error
// wrapping ErrNotFound.
func (b *Bucket) Exists(ctx context.Context, key string) (bool, error) {
_, err := b.Attributes(ctx, key)
if err == nil {
return true, nil
}

if errors.Is(err, ErrNotFound) {
return false, nil
}

return false, err
}

// NewReader is a shortcut for NewRangeReader with offset=0 and length=-1.
func (b *Bucket) NewReader(ctx context.Context, key string) (*Reader, error) {
return b.newRangeReader(ctx, key, 0, -1)
}

// NewRangeReader returns a Reader to read content from the blob stored at key.
// It reads at most length bytes starting at offset (>= 0).
// If length is negative, it will read till the end of the blob.
//
// For the purposes of Seek, the returned Reader will start at offset and
// end at the minimum of the actual end of the blob or (if length > 0) offset + length.
//
// Note that ctx is used for all reads performed during the lifetime of the reader.
//
// If the blob does not exist, NewRangeReader returns an error wrapping
// ErrNotFound. Exists is a lighter-weight way to check for existence.
//
// The caller must call Close on the returned Reader when done reading.
func (b *Bucket) NewRangeReader(ctx context.Context, key string, offset, length int64) (_ *Reader, err error) {
return b.newRangeReader(ctx, key, offset, length)
}

func (b *Bucket) newRangeReader(ctx context.Context, key string, offset, length int64) (_ *Reader, err error) {
b.mu.RLock()
defer b.mu.RUnlock()
if b.closed {
return nil, ErrClosed
}

if offset < 0 {
return nil, fmt.Errorf("NewRangeReader offset must be non-negative (%d)", offset)
}

if !utf8.ValidString(key) {
return nil, fmt.Errorf("NewRangeReader key must be a valid UTF-8 string: %q", key)
}

var dr DriverReader
dr, err = b.drv.NewRangeReader(ctx, key, offset, length)
if err != nil {
return nil, wrapError(b.drv, err, key)
}

r := &Reader{
drv: b.drv,
r: dr,
key: key,
ctx: ctx,
baseOffset: offset,
baseLength: length,
savedOffset: -1,
}

_, file, lineno, ok := runtime.Caller(2)
runtime.SetFinalizer(r, func(r *Reader) {
if !r.closed {
var caller string
if ok {
caller = fmt.Sprintf(" (%s:%d)", file, lineno)
}
log.Printf("A blob.Reader reading from %q was never closed%s", key, caller)
}
})

return r, nil
}
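
A read sketch (not part of the diff; it assumes the package's Reader satisfies io.ReadCloser, as the Close requirement above implies):

package example

import (
	"context"
	"io"

	"github.com/pocketbase/pocketbase/tools/filesystem/blob"
)

// readAll fetches the full content of the blob stored at key.
func readAll(ctx context.Context, bkt *blob.Bucket, key string) ([]byte, error) {
	r, err := bkt.NewReader(ctx, key)
	if err != nil {
		return nil, err
	}
	defer r.Close() // the reader must always be closed

	return io.ReadAll(r)
}
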

// WriterOptions sets options for NewWriter.
type WriterOptions struct {
// BufferSize changes the default size in bytes of the chunks that
// Writer will upload in a single request; larger blobs will be split into
// multiple requests.
//
// This option may be ignored by some drivers.
//
// If 0, the driver will choose a reasonable default.
//
// If the Writer is used to do many small writes concurrently, using a
// smaller BufferSize may reduce memory usage.
BufferSize int

// MaxConcurrency changes the default concurrency for parts of an upload.
//
// This option may be ignored by some drivers.
//
// If 0, the driver will choose a reasonable default.
MaxConcurrency int

// CacheControl specifies caching attributes that services may use
// when serving the blob.
// https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Cache-Control
CacheControl string

// ContentDisposition specifies whether the blob content is expected to be
// displayed inline or as an attachment.
// https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Content-Disposition
ContentDisposition string

// ContentEncoding specifies the encoding used for the blob's content, if any.
// https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Content-Encoding
ContentEncoding string

// ContentLanguage specifies the language used in the blob's content, if any.
// https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Content-Language
ContentLanguage string

// ContentType specifies the MIME type of the blob being written. If not set,
// it will be inferred from the content using the algorithm described at
// http://mimesniff.spec.whatwg.org/.
// Set DisableContentTypeDetection to true to disable the above and force
// the ContentType to stay empty.
// https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Content-Type
ContentType string

// When true, if ContentType is the empty string, it will stay the empty
// string rather than being inferred from the content.
// Note that while the blob will be written with an empty string ContentType,
// most providers will fill one in during reads, so don't expect an empty
// ContentType if you read the blob back.
DisableContentTypeDetection bool

// ContentMD5 is used as a message integrity check.
// If len(ContentMD5) > 0, the MD5 hash of the bytes written must match
// ContentMD5, or Close will return an error without completing the write.
// https://tools.ietf.org/html/rfc1864
ContentMD5 []byte

// Metadata holds key/value strings to be associated with the blob, or nil.
// Keys may not be empty, and are lowercased before being written.
// Duplicate case-insensitive keys (e.g., "foo" and "FOO") will result in
// an error.
Metadata map[string]string
}

// NewWriter returns a Writer that writes to the blob stored at key.
// A nil WriterOptions is treated the same as the zero value.
//
// If a blob with this key already exists, it will be replaced.
// The blob being written is not guaranteed to be readable until Close
// has been called; until then, any previous blob will still be readable.
// Even after Close is called, newly written blobs are not guaranteed to be
// returned from List; some services are only eventually consistent.
//
// The returned Writer will store ctx for later use in Write and/or Close.
// To abort a write, cancel ctx; otherwise, it must remain open until
// Close is called.
//
// The caller must call Close on the returned Writer, even if the write is
// aborted.
func (b *Bucket) NewWriter(ctx context.Context, key string, opts *WriterOptions) (_ *Writer, err error) {
if !utf8.ValidString(key) {
return nil, fmt.Errorf("NewWriter key must be a valid UTF-8 string: %q", key)
}
if opts == nil {
opts = &WriterOptions{}
}
dopts := &WriterOptions{
CacheControl: opts.CacheControl,
ContentDisposition: opts.ContentDisposition,
ContentEncoding: opts.ContentEncoding,
ContentLanguage: opts.ContentLanguage,
ContentMD5: opts.ContentMD5,
BufferSize: opts.BufferSize,
MaxConcurrency: opts.MaxConcurrency,
DisableContentTypeDetection: opts.DisableContentTypeDetection,
}

if len(opts.Metadata) > 0 {
// Services are inconsistent, but at least some treat keys
// as case-insensitive. To make the behavior consistent, we
// force-lowercase them when writing and reading.
md := make(map[string]string, len(opts.Metadata))
for k, v := range opts.Metadata {
if k == "" {
return nil, errors.New("WriterOptions.Metadata keys may not be empty strings")
}
if !utf8.ValidString(k) {
return nil, fmt.Errorf("WriterOptions.Metadata keys must be valid UTF-8 strings: %q", k)
}
if !utf8.ValidString(v) {
return nil, fmt.Errorf("WriterOptions.Metadata values must be valid UTF-8 strings: %q", v)
}
lowerK := strings.ToLower(k)
if _, found := md[lowerK]; found {
return nil, fmt.Errorf("WriterOptions.Metadata has a duplicate case-insensitive metadata key: %q", lowerK)
}
md[lowerK] = v
}
dopts.Metadata = md
}

b.mu.RLock()
defer b.mu.RUnlock()
if b.closed {
return nil, ErrClosed
}

ctx, cancel := context.WithCancel(ctx)

w := &Writer{
drv: b.drv,
cancel: cancel,
key: key,
contentMD5: opts.ContentMD5,
md5hash: md5.New(),
}

if opts.ContentType != "" || opts.DisableContentTypeDetection {
var ct string
if opts.ContentType != "" {
t, p, err := mime.ParseMediaType(opts.ContentType)
if err != nil {
cancel()
return nil, err
}
ct = mime.FormatMediaType(t, p)
}
dw, err := b.drv.NewTypedWriter(ctx, key, ct, dopts)
if err != nil {
cancel()
return nil, wrapError(b.drv, err, key)
}
w.w = dw
} else {
// Save the fields needed to call NewTypedWriter later, once we've gotten
// sniffLen bytes; see the comment on Writer.
w.ctx = ctx
w.opts = dopts
w.buf = bytes.NewBuffer([]byte{})
}

_, file, lineno, ok := runtime.Caller(1)
runtime.SetFinalizer(w, func(w *Writer) {
if !w.closed {
var caller string
if ok {
caller = fmt.Sprintf(" (%s:%d)", file, lineno)
}
log.Printf("A blob.Writer writing to %q was never closed%s", key, caller)
}
})

return w, nil
}
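
A write sketch (not part of the diff; the key, content type and data are illustrative):

package example

import (
	"context"

	"github.com/pocketbase/pocketbase/tools/filesystem/blob"
)

// write stores data under key with an explicit content type.
func write(ctx context.Context, bkt *blob.Bucket, key string, data []byte) error {
	w, err := bkt.NewWriter(ctx, key, &blob.WriterOptions{
		ContentType: "application/octet-stream",
	})
	if err != nil {
		return err
	}

	if _, err = w.Write(data); err != nil {
		w.Close() // Close must be called even when the write is aborted
		return err
	}

	return w.Close()
}
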

// Copy the blob stored at srcKey to dstKey.
//
// If the source blob does not exist, Copy returns an error that
// matches ErrNotFound.
//
// If the destination blob already exists, it is overwritten.
func (b *Bucket) Copy(ctx context.Context, dstKey, srcKey string) (err error) {
    if !utf8.ValidString(srcKey) {
        return fmt.Errorf("Copy srcKey must be a valid UTF-8 string: %q", srcKey)
    }

    if !utf8.ValidString(dstKey) {
        return fmt.Errorf("Copy dstKey must be a valid UTF-8 string: %q", dstKey)
    }

    b.mu.RLock()
    defer b.mu.RUnlock()

    if b.closed {
        return ErrClosed
    }

    return wrapError(b.drv, b.drv.Copy(ctx, dstKey, srcKey), fmt.Sprintf("%s -> %s", srcKey, dstKey))
}
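A one-line usage sketch (the keys are hypothetical):

// Overwrites "backup/avatar.png" if it already exists.
if err := bucket.Copy(ctx, "backup/avatar.png", "users/avatar.png"); err != nil {
    return err
}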

// Delete deletes the blob stored at key.
//
// If the blob does not exist, Delete returns an error that
// matches ErrNotFound.
func (b *Bucket) Delete(ctx context.Context, key string) (err error) {
    if !utf8.ValidString(key) {
        return fmt.Errorf("Delete key must be a valid UTF-8 string: %q", key)
    }

    b.mu.RLock()
    defer b.mu.RUnlock()

    if b.closed {
        return ErrClosed
    }

    return wrapError(b.drv, b.drv.Delete(ctx, key), key)
}

// Close releases any resources used for the bucket.
//
// @todo Consider removing it.
func (b *Bucket) Close() error {
    b.mu.Lock()
    prev := b.closed
    b.closed = true
    b.mu.Unlock()

    if prev {
        return ErrClosed
    }

    return wrapError(b.drv, b.drv.Close(), "")
}

func wrapError(b Driver, err error, key string) error {
    if err == nil {
        return nil
    }

    // don't wrap or normalize EOF errors since there are many places
    // in the standard library (e.g. io.ReadAll) that rely on checks
    // such as "err == io.EOF" and they will fail
    if errors.Is(err, io.EOF) {
        return err
    }

    err = b.NormalizeError(err)

    if key != "" {
        err = fmt.Errorf("[key: %s] %w", key, err)
    }

    return err
}
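Because wrapError wraps with %w after NormalizeError, callers can keep using errors.Is on the package's sentinel values; a sketch, assuming the driver normalizes missing keys to ErrNotFound:

// The "[key: ...]" prefix added by wrapError does not break errors.Is checks.
err := bucket.Delete(ctx, "missing.txt")
if errors.Is(err, ErrNotFound) {
    // the blob was already gone; often safe to ignore
}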
107
tools/filesystem/blob/driver.go
Normal file
@@ -0,0 +1,107 @@
package blob

import (
    "context"
    "io"
    "time"
)

// ReaderAttributes contains a subset of attributes about a blob that are
// accessible from Reader.
type ReaderAttributes struct {
    // ContentType is the MIME type of the blob object. It must not be empty.
    // https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Content-Type
    ContentType string `json:"contentType"`

    // ModTime is the time the blob object was last modified.
    ModTime time.Time `json:"modTime"`

    // Size is the size of the object in bytes.
    Size int64 `json:"size"`
}

// DriverReader reads an object from the blob.
type DriverReader interface {
    io.ReadCloser

    // Attributes returns a subset of attributes about the blob.
    // The portable type will not modify the returned ReaderAttributes.
    Attributes() *ReaderAttributes
}

// DriverWriter writes an object to the blob.
type DriverWriter interface {
    io.WriteCloser
}

// Driver provides read, write and delete operations on objects within
// a bucket on the blob service.
type Driver interface {
    NormalizeError(err error) error

    // Attributes returns attributes for the blob. If the specified object does
    // not exist, Attributes must return an error for which ErrorCode returns ErrNotFound.
    // The portable type will not modify the returned Attributes.
    Attributes(ctx context.Context, key string) (*Attributes, error)

    // ListPaged lists objects in the bucket, in lexicographical order by
    // UTF-8-encoded key, returning pages of objects at a time.
    // Services are only required to be eventually consistent with respect
    // to recently written or deleted objects. That is to say, there is no
    // guarantee that an object that's been written will immediately be returned
    // from ListPaged.
    // opts is guaranteed to be non-nil.
    ListPaged(ctx context.Context, opts *ListOptions) (*ListPage, error)

    // NewRangeReader returns a Reader that reads part of an object, reading at
    // most length bytes starting at the given offset. If length is negative, it
    // will read until the end of the object. If the specified object does not
    // exist, NewRangeReader must return an error for which ErrorCode returns ErrNotFound.
    // opts is guaranteed to be non-nil.
    //
    // The returned Reader *may* also implement Downloader if the underlying
    // implementation can take advantage of that. The Download call is guaranteed
    // to be the only call to the Reader. For such readers, offset will always
    // be 0 and length will always be -1.
    NewRangeReader(ctx context.Context, key string, offset, length int64) (DriverReader, error)

    // NewTypedWriter returns a Writer that writes to an object associated with key.
    //
    // A new object will be created unless an object with this key already exists.
    // Otherwise any previous object with the same key will be replaced.
    // The object may not be available (and any previous object will remain)
    // until Close has been called.
    //
    // contentType sets the MIME type of the object to be written.
    // opts is guaranteed to be non-nil.
    //
    // The caller must call Close on the returned Writer when done writing.
    //
    // Implementations should abort an ongoing write if ctx is later canceled,
    // and do any necessary cleanup in Close. Close should then return ctx.Err().
    //
    // The returned Writer *may* also implement Uploader if the underlying
    // implementation can take advantage of that. The Upload call is guaranteed
    // to be the only non-Close call to the Writer.
    NewTypedWriter(ctx context.Context, key, contentType string, opts *WriterOptions) (DriverWriter, error)

    // Copy copies the object associated with srcKey to dstKey.
    //
    // If the source object does not exist, Copy must return an error for which
    // ErrorCode returns ErrNotFound.
    //
    // If the destination object already exists, it should be overwritten.
    //
    // opts is guaranteed to be non-nil.
    Copy(ctx context.Context, dstKey, srcKey string) error

    // Delete deletes the object associated with key. If the specified object does
    // not exist, Delete must return an error for which ErrorCode returns ErrNotFound.
    Delete(ctx context.Context, key string) error

    // Close cleans up any resources used by the Bucket. Once Close is called,
    // there will be no method calls to the Bucket other than As, ErrorAs, and
    // ErrorCode. There may be open readers or writers that will receive calls.
    // It is up to the driver as to how these will be handled.
    Close() error
}
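To make the contract concrete, here is a hedged in-package sketch of the smallest possible Driver: it discards writes and reports every key as missing. The type names are hypothetical; only the method set mirrors the interface above.

// discardDriver is an illustrative no-op Driver (hypothetical, not part of the diff).
type discardDriver struct{}

var _ Driver = discardDriver{} // compile-time interface check

func (discardDriver) NormalizeError(err error) error { return err }

func (discardDriver) Attributes(ctx context.Context, key string) (*Attributes, error) {
    return nil, ErrNotFound // every key is "missing"
}

func (discardDriver) ListPaged(ctx context.Context, opts *ListOptions) (*ListPage, error) {
    return &ListPage{}, nil // always an empty page
}

func (discardDriver) NewRangeReader(ctx context.Context, key string, offset, length int64) (DriverReader, error) {
    return nil, ErrNotFound
}

func (discardDriver) NewTypedWriter(ctx context.Context, key, contentType string, opts *WriterOptions) (DriverWriter, error) {
    return nopWriteCloser{io.Discard}, nil // accept and drop all bytes
}

func (discardDriver) Copy(ctx context.Context, dstKey, srcKey string) error { return ErrNotFound }
func (discardDriver) Delete(ctx context.Context, key string) error          { return ErrNotFound }
func (discardDriver) Close() error                                          { return nil }

// nopWriteCloser adds a no-op Close to an io.Writer.
type nopWriteCloser struct{ io.Writer }

func (nopWriteCloser) Close() error { return nil }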
153
tools/filesystem/blob/hex.go
Normal file
@@ -0,0 +1,153 @@
package blob

import (
    "fmt"
    "strconv"
)

// Copied from gocloud.dev/blob to avoid nuances around the specific
// HEX escaping/unescaping rules.
//
// -------------------------------------------------------------------
// Copyright 2019 The Go Cloud Development Kit Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// -------------------------------------------------------------------

// HexEscape returns s, with all runes for which shouldEscape returns true
// escaped to "__0xXXX__", where XXX is the hex representation of the rune
// value. For example, " " would escape to "__0x20__".
//
// Non-UTF-8 strings will have their non-UTF-8 characters escaped to
// unicode.ReplacementChar; the original value is lost. Please file an
// issue if you need non-UTF8 support.
//
// Note: shouldEscape takes the whole string as a slice of runes and an
// index. Passing it a single byte or a single rune doesn't provide
// enough context for some escape decisions; for example, the caller might
// want to escape the second "/" in "//" but not the first one.
// We pass a slice of runes instead of the string or a slice of bytes
// because some decisions will be made on a rune basis (e.g., encode
// all non-ASCII runes).
func HexEscape(s string, shouldEscape func(s []rune, i int) bool) string {
    // Do a first pass to see which runes (if any) need escaping.
    runes := []rune(s)
    var toEscape []int
    for i := range runes {
        if shouldEscape(runes, i) {
            toEscape = append(toEscape, i)
        }
    }
    if len(toEscape) == 0 {
        return s
    }

    // Each escaped rune turns into at most 14 runes ("__0x7fffffff__"),
    // so allocate an extra 13 for each. We'll reslice at the end
    // if we didn't end up using them.
    escaped := make([]rune, len(runes)+13*len(toEscape))
    n := 0 // current index into toEscape
    j := 0 // current index into escaped
    for i, r := range runes {
        if n < len(toEscape) && i == toEscape[n] {
            // We were asked to escape this rune.
            for _, x := range fmt.Sprintf("__%#x__", r) {
                escaped[j] = x
                j++
            }
            n++
        } else {
            escaped[j] = r
            j++
        }
    }

    return string(escaped[0:j])
}
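A small usage sketch (the escape predicate is illustrative and assumes the unicode package is imported):

// Escape everything that is not a letter or a digit.
escaped := HexEscape("hi there!", func(r []rune, i int) bool {
    return !unicode.IsLetter(r[i]) && !unicode.IsDigit(r[i])
})
// escaped == "hi__0x20__there__0x21__"
// HexUnescape(escaped) == "hi there!" (see HexUnescape below)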

// unescape tries to unescape starting at r[i].
// It returns a boolean indicating whether the unescaping was successful,
// and (if true) the unescaped rune and the last index of r that was used
// during unescaping.
func unescape(r []rune, i int) (bool, rune, int) {
    // Look for "__0x".
    if r[i] != '_' {
        return false, 0, 0
    }
    i++
    if i >= len(r) || r[i] != '_' {
        return false, 0, 0
    }
    i++
    if i >= len(r) || r[i] != '0' {
        return false, 0, 0
    }
    i++
    if i >= len(r) || r[i] != 'x' {
        return false, 0, 0
    }
    i++

    // Capture the digits until the next "_" (if any).
    var hexdigits []rune
    for ; i < len(r) && r[i] != '_'; i++ {
        hexdigits = append(hexdigits, r[i])
    }

    // Look for the trailing "__".
    if i >= len(r) || r[i] != '_' {
        return false, 0, 0
    }
    i++
    if i >= len(r) || r[i] != '_' {
        return false, 0, 0
    }

    // Parse the hex digits into an int32.
    retval, err := strconv.ParseInt(string(hexdigits), 16, 32)
    if err != nil {
        return false, 0, 0
    }

    return true, rune(retval), i
}

// HexUnescape reverses HexEscape.
func HexUnescape(s string) string {
    var unescaped []rune

    runes := []rune(s)
    for i := 0; i < len(runes); i++ {
        if ok, newR, newI := unescape(runes, i); ok {
            // We unescaped some runes starting at i, resulting in the
            // unescaped rune newR. The last rune used was newI.
            if unescaped == nil {
                // This is the first rune we've encountered that
                // needed unescaping. Allocate a buffer and copy any
                // previous runes.
                unescaped = make([]rune, i)
                copy(unescaped, runes)
            }
            unescaped = append(unescaped, newR)
            i = newI
        } else if unescaped != nil {
            unescaped = append(unescaped, runes[i])
        }
    }

    if unescaped == nil {
        return s
    }

    return string(unescaped)
}
196
tools/filesystem/blob/reader.go
Normal file
@@ -0,0 +1,196 @@
package blob

import (
    "context"
    "fmt"
    "io"
    "log"
    "time"
)

// Largely copied from gocloud.dev/blob.Reader to minimize breaking changes.
//
// -------------------------------------------------------------------
// Copyright 2019 The Go Cloud Development Kit Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// -------------------------------------------------------------------

var _ io.ReadSeekCloser = (*Reader)(nil)

// Reader reads bytes from a blob.
// It implements io.ReadSeekCloser, and must be closed after reads are finished.
type Reader struct {
    ctx            context.Context // Used to recreate r after Seeks
    r              DriverReader
    drv            Driver
    key            string
    baseOffset     int64 // The base offset provided to NewRangeReader.
    baseLength     int64 // The length provided to NewRangeReader (may be negative).
    relativeOffset int64 // Current offset (relative to baseOffset).
    savedOffset    int64 // Last relativeOffset for r, saved after relativeOffset is changed in Seek, or -1 if no Seek.
    closed         bool
}

// Read implements io.Reader (https://golang.org/pkg/io/#Reader).
func (r *Reader) Read(p []byte) (int, error) {
    if r.savedOffset != -1 {
        // We've done one or more Seeks since the last read. We may have
        // to recreate the Reader.
        //
        // Note that remembering the savedOffset and lazily resetting the
        // reader like this allows the caller to Seek, then Seek again back,
        // to the original offset, without having to recreate the reader.
        // We only have to recreate the reader if we actually read after a Seek.
        // This is an important optimization because it's common to Seek
        // to (SeekEnd, 0) and use the return value to determine the size
        // of the data, then Seek back to (SeekStart, 0).
        saved := r.savedOffset
        if r.relativeOffset == saved {
            // Nope! We're at the same place we left off.
            r.savedOffset = -1
        } else {
            // Yep! We've changed the offset. Recreate the reader.
            length := r.baseLength
            if length >= 0 {
                length -= r.relativeOffset
                if length < 0 {
                    // Shouldn't happen based on checks in Seek.
                    return 0, fmt.Errorf("invalid Seek (base length %d, relative offset %d)", r.baseLength, r.relativeOffset)
                }
            }
            newR, err := r.drv.NewRangeReader(r.ctx, r.key, r.baseOffset+r.relativeOffset, length)
            if err != nil {
                return 0, wrapError(r.drv, err, r.key)
            }
            _ = r.r.Close()
            r.savedOffset = -1
            r.r = newR
        }
    }
    n, err := r.r.Read(p)
    r.relativeOffset += int64(n)
    return n, wrapError(r.drv, err, r.key)
}

// Seek implements io.Seeker (https://golang.org/pkg/io/#Seeker).
func (r *Reader) Seek(offset int64, whence int) (int64, error) {
    if r.savedOffset == -1 {
        // Save the current offset for our reader. If the Seek changes the
        // offset, and then we try to read, we'll need to recreate the reader.
        // See comment above in Read for why we do it lazily.
        r.savedOffset = r.relativeOffset
    }
    // The maximum relative offset is the minimum of:
    // 1. The actual size of the blob, minus our initial baseOffset.
    // 2. The length provided to NewRangeReader (if it was non-negative).
    maxRelativeOffset := r.Size() - r.baseOffset
    if r.baseLength >= 0 && r.baseLength < maxRelativeOffset {
        maxRelativeOffset = r.baseLength
    }
    switch whence {
    case io.SeekStart:
        r.relativeOffset = offset
    case io.SeekCurrent:
        r.relativeOffset += offset
    case io.SeekEnd:
        r.relativeOffset = maxRelativeOffset + offset
    }
    if r.relativeOffset < 0 {
        // "Seeking to an offset before the start of the file is an error."
        invalidOffset := r.relativeOffset
        r.relativeOffset = 0
        return 0, fmt.Errorf("Seek resulted in invalid offset %d, using 0", invalidOffset)
    }
    if r.relativeOffset > maxRelativeOffset {
        // "Seeking to any positive offset is legal, but the behavior of subsequent
        // I/O operations on the underlying object is implementation-dependent."
        // We'll choose to set the offset to the EOF.
        log.Printf("blob.Reader.Seek set an offset after EOF (base offset/length from NewRangeReader %d, %d; actual blob size %d; relative offset %d -> absolute offset %d).", r.baseOffset, r.baseLength, r.Size(), r.relativeOffset, r.baseOffset+r.relativeOffset)
        r.relativeOffset = maxRelativeOffset
    }
    return r.relativeOffset, nil
}

// Close implements io.Closer (https://golang.org/pkg/io/#Closer).
func (r *Reader) Close() error {
    r.closed = true
    err := wrapError(r.drv, r.r.Close(), r.key)
    return err
}

// ContentType returns the MIME type of the blob.
func (r *Reader) ContentType() string {
    return r.r.Attributes().ContentType
}

// ModTime returns the time the blob was last modified.
func (r *Reader) ModTime() time.Time {
    return r.r.Attributes().ModTime
}

// Size returns the size of the blob content in bytes.
func (r *Reader) Size() int64 {
    return r.r.Attributes().Size
}

// WriteTo reads from r and writes to w until there's no more data or
// an error occurs.
// The return value is the number of bytes written to w.
//
// It implements the io.WriterTo interface.
func (r *Reader) WriteTo(w io.Writer) (int64, error) {
    // If the writer has a ReaderFrom method, use it to do the copy.
    // Don't do this for our own *Writer to avoid infinite recursion.
    // Avoids an allocation and a copy.
    switch w.(type) {
    case *Writer:
    default:
        if rf, ok := w.(io.ReaderFrom); ok {
            n, err := rf.ReadFrom(r)
            return n, err
        }
    }

    _, nw, err := readFromWriteTo(r, w)
    return nw, err
}

// readFromWriteTo is a helper for ReadFrom and WriteTo.
// It reads data from r and writes to w, until EOF or a read/write error.
// It returns the number of bytes read from r and the number of bytes
// written to w.
func readFromWriteTo(r io.Reader, w io.Writer) (int64, int64, error) {
    // Note: can't use io.Copy because it will try to use r.WriteTo
    // or w.WriteTo, which is recursive in this context.
    buf := make([]byte, 1024)
    var totalRead, totalWritten int64
    for {
        numRead, rerr := r.Read(buf)
        if numRead > 0 {
            totalRead += int64(numRead)
            numWritten, werr := w.Write(buf[0:numRead])
            totalWritten += int64(numWritten)
            if werr != nil {
                return totalRead, totalWritten, werr
            }
        }
        if rerr == io.EOF {
            // Done!
            return totalRead, totalWritten, nil
        }
        if rerr != nil {
            return totalRead, totalWritten, rerr
        }
    }
}
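A usage sketch of the seek-then-read pattern that the lazy reset above is optimized for (bucket and key are hypothetical; NewReader is the full-range reader constructor defined elsewhere in this package):

r, err := bucket.NewReader(ctx, "notes/hello.txt")
if err != nil {
    return err
}
defer r.Close()

size, _ := r.Seek(0, io.SeekEnd) // determine the size...
r.Seek(0, io.SeekStart)          // ...and rewind; no re-fetch happens until the next Read

buf := make([]byte, size)
_, err = io.ReadFull(r, buf)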
184
tools/filesystem/blob/writer.go
Normal file
@@ -0,0 +1,184 @@
package blob

import (
    "bytes"
    "context"
    "fmt"
    "hash"
    "io"
    "net/http"
)

// Largely copied from gocloud.dev/blob.Writer to minimize breaking changes.
//
// -------------------------------------------------------------------
// Copyright 2019 The Go Cloud Development Kit Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// -------------------------------------------------------------------

var _ io.WriteCloser = (*Writer)(nil)

// Writer writes bytes to a blob.
//
// It implements io.WriteCloser (https://golang.org/pkg/io/#Closer), and must be
// closed after all writes are done.
type Writer struct {
    drv          Driver
    w            DriverWriter
    key          string
    cancel       func() // cancels the ctx provided to NewTypedWriter if contentMD5 verification fails
    contentMD5   []byte
    md5hash      hash.Hash
    bytesWritten int
    closed       bool

    // These fields are non-zero values only when w is nil (not yet created).
    //
    // A ctx is stored in the Writer since we need to pass it into NewTypedWriter
    // when we finish detecting the content type of the blob and create the
    // underlying driver.Writer. This step happens inside Write or Close and
    // neither of them takes a context.Context as an argument.
    //
    // All 3 fields are only initialized when we create the Writer without
    // setting the w field, and are reset to zero values after w is created.
    ctx  context.Context
    opts *WriterOptions
    buf  *bytes.Buffer
}

// sniffLen is the byte size of Writer.buf used to detect content-type.
const sniffLen = 512

// Write implements the io.Writer interface (https://golang.org/pkg/io/#Writer).
//
// Writes may happen asynchronously, so the returned error can be nil
// even if the actual write eventually fails. The write is only guaranteed to
// have succeeded if Close returns no error.
func (w *Writer) Write(p []byte) (int, error) {
    if len(w.contentMD5) > 0 {
        if _, err := w.md5hash.Write(p); err != nil {
            return 0, err
        }
    }

    if w.w != nil {
        return w.write(p)
    }

    // If w is not yet created due to no content-type being passed in, try to sniff
    // the MIME type based on at most 512 bytes of the blob content of p.

    // Detect the content-type directly if the first chunk is at least 512 bytes.
    if w.buf.Len() == 0 && len(p) >= sniffLen {
        return w.open(p)
    }

    // Store p in w.buf and detect the content-type when the size of content in
    // w.buf is at least 512 bytes.
    n, err := w.buf.Write(p)
    if err != nil {
        return 0, err
    }

    if w.buf.Len() >= sniffLen {
        // Note that w.open will return the full length of the buffer; we don't want
        // to return that as the length of this write since some of them were written in
        // previous writes. Instead, we return the n from this write, above.
        _, err := w.open(w.buf.Bytes())
        return n, err
    }

    return n, nil
}

// Close closes the blob writer. The write operation is not guaranteed
// to have succeeded until Close returns with no error.
//
// Close may return an error if the context provided to create the
// Writer is canceled or reaches its deadline.
func (w *Writer) Close() (err error) {
    w.closed = true

    // Verify the MD5 hash of what was written matches the ContentMD5 provided by the user.
    if len(w.contentMD5) > 0 {
        md5sum := w.md5hash.Sum(nil)
        if !bytes.Equal(md5sum, w.contentMD5) {
            // No match! Return an error, but first cancel the context and call the
            // driver's Close function to ensure the write is aborted.
            w.cancel()
            if w.w != nil {
                _ = w.w.Close()
            }
            return fmt.Errorf("the WriterOptions.ContentMD5 you specified (%X) did not match what was written (%X)", w.contentMD5, md5sum)
        }
    }

    defer w.cancel()

    if w.w != nil {
        return wrapError(w.drv, w.w.Close(), w.key)
    }

    if _, err := w.open(w.buf.Bytes()); err != nil {
        return err
    }

    return wrapError(w.drv, w.w.Close(), w.key)
}

// open tries to detect the MIME type of p and write it to the blob.
// The error it returns is wrapped.
func (w *Writer) open(p []byte) (int, error) {
    ct := http.DetectContentType(p)

    var err error
    w.w, err = w.drv.NewTypedWriter(w.ctx, w.key, ct, w.opts)
    if err != nil {
        return 0, wrapError(w.drv, err, w.key)
    }

    // Set the 3 fields needed for lazy NewTypedWriter back to zero values
    // (see the comment on Writer).
    w.buf = nil
    w.ctx = nil
    w.opts = nil

    return w.write(p)
}

func (w *Writer) write(p []byte) (int, error) {
    n, err := w.w.Write(p)
    w.bytesWritten += n
    return n, wrapError(w.drv, err, w.key)
}

// ReadFrom reads from r and writes to w until EOF or error.
// The return value is the number of bytes read from r.
//
// It implements the io.ReaderFrom interface.
func (w *Writer) ReadFrom(r io.Reader) (int64, error) {
    // If the reader has a WriteTo method, use it to do the copy.
    // Don't do this for our own *Reader to avoid infinite recursion.
    // Avoids an allocation and a copy.
    switch r.(type) {
    case *Reader:
    default:
        if wt, ok := r.(io.WriterTo); ok {
            return wt.WriteTo(w)
        }
    }

    nr, _, err := readFromWriteTo(r, w)
    return nr, err
}
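A sketch of the ContentMD5 verification path: when the digest passed in WriterOptions doesn't match what was actually written, Close fails and the write is aborted (bucket, ctx and key are hypothetical; assumes crypto/md5 is imported):

payload := []byte("hello")
sum := md5.Sum(payload)

w, err := bucket.NewWriter(ctx, "checked.txt", &WriterOptions{ContentMD5: sum[:]})
if err != nil {
    return err
}
if _, err := w.Write(payload); err != nil {
    return errors.Join(err, w.Close())
}
return w.Close() // returns an error if the written bytes hash differently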
268
tools/filesystem/file.go
Normal file
@@ -0,0 +1,268 @@
package filesystem

import (
    "bytes"
    "context"
    "errors"
    "fmt"
    "io"
    "mime/multipart"
    "net/http"
    "os"
    "path"
    "regexp"
    "strings"

    "github.com/gabriel-vasile/mimetype"
    "github.com/pocketbase/pocketbase/tools/inflector"
    "github.com/pocketbase/pocketbase/tools/security"
)

// FileReader defines an interface for a file resource reader.
type FileReader interface {
    Open() (io.ReadSeekCloser, error)
}

// File defines a single file [io.ReadSeekCloser] resource.
//
// The file could be from a local path, multipart/form-data header, etc.
type File struct {
    Reader       FileReader `form:"-" json:"-" xml:"-"`
    Name         string     `form:"name" json:"name" xml:"name"`
    OriginalName string     `form:"originalName" json:"originalName" xml:"originalName"`
    Size         int64      `form:"size" json:"size" xml:"size"`
}

// AsMap implements [core.mapExtractor] and returns a value suitable
// to be used in an API rule expression.
func (f *File) AsMap() map[string]any {
    return map[string]any{
        "name":         f.Name,
        "originalName": f.OriginalName,
        "size":         f.Size,
    }
}

// NewFileFromPath creates a new File instance from the provided local file path.
func NewFileFromPath(path string) (*File, error) {
    f := &File{}

    info, err := os.Stat(path)
    if err != nil {
        return nil, err
    }

    f.Reader = &PathReader{Path: path}
    f.Size = info.Size()
    f.OriginalName = info.Name()
    f.Name = normalizeName(f.Reader, f.OriginalName)

    return f, nil
}

// NewFileFromBytes creates a new File instance from the provided byte slice.
func NewFileFromBytes(b []byte, name string) (*File, error) {
    size := len(b)
    if size == 0 {
        return nil, errors.New("cannot create an empty file")
    }

    f := &File{}

    f.Reader = &BytesReader{b}
    f.Size = int64(size)
    f.OriginalName = name
    f.Name = normalizeName(f.Reader, f.OriginalName)

    return f, nil
}

// NewFileFromMultipart creates a new File from the provided multipart header.
func NewFileFromMultipart(mh *multipart.FileHeader) (*File, error) {
    f := &File{}

    f.Reader = &MultipartReader{Header: mh}
    f.Size = mh.Size
    f.OriginalName = mh.Filename
    f.Name = normalizeName(f.Reader, f.OriginalName)

    return f, nil
}

// NewFileFromURL creates a new File from the provided url by
// downloading the resource and loading it as a BytesReader.
//
// Example
//
//	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
//	defer cancel()
//
//	file, err := filesystem.NewFileFromURL(ctx, "https://example.com/image.png")
func NewFileFromURL(ctx context.Context, url string) (*File, error) {
    req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
    if err != nil {
        return nil, err
    }

    res, err := http.DefaultClient.Do(req)
    if err != nil {
        return nil, err
    }
    defer res.Body.Close()

    if res.StatusCode < 200 || res.StatusCode > 399 {
        return nil, fmt.Errorf("failed to download url %s (%d)", url, res.StatusCode)
    }

    var buf bytes.Buffer

    if _, err = io.Copy(&buf, res.Body); err != nil {
        return nil, err
    }

    return NewFileFromBytes(buf.Bytes(), path.Base(url))
}

// -------------------------------------------------------------------

var _ FileReader = (*MultipartReader)(nil)

// MultipartReader defines a FileReader from [multipart.FileHeader].
type MultipartReader struct {
    Header *multipart.FileHeader
}

// Open implements the [filesystem.FileReader] interface.
func (r *MultipartReader) Open() (io.ReadSeekCloser, error) {
    return r.Header.Open()
}

// -------------------------------------------------------------------

var _ FileReader = (*PathReader)(nil)

// PathReader defines a FileReader from a local file path.
type PathReader struct {
    Path string
}

// Open implements the [filesystem.FileReader] interface.
func (r *PathReader) Open() (io.ReadSeekCloser, error) {
    return os.Open(r.Path)
}

// -------------------------------------------------------------------

var _ FileReader = (*BytesReader)(nil)

// BytesReader defines a FileReader from bytes content.
type BytesReader struct {
    Bytes []byte
}

// Open implements the [filesystem.FileReader] interface.
func (r *BytesReader) Open() (io.ReadSeekCloser, error) {
    return &bytesReadSeekCloser{bytes.NewReader(r.Bytes)}, nil
}

type bytesReadSeekCloser struct {
    *bytes.Reader
}

// Close implements the [io.ReadSeekCloser] interface.
func (r *bytesReadSeekCloser) Close() error {
    return nil
}

// -------------------------------------------------------------------

var _ FileReader = (openFuncAsReader)(nil)

// openFuncAsReader defines a FileReader from a bare Open function.
type openFuncAsReader func() (io.ReadSeekCloser, error)

// Open implements the [filesystem.FileReader] interface.
func (r openFuncAsReader) Open() (io.ReadSeekCloser, error) {
    return r()
}

// -------------------------------------------------------------------

var extInvalidCharsRegex = regexp.MustCompile(`[^\w\.\*\-\+\=\#]+`)

const randomAlphabet = "abcdefghijklmnopqrstuvwxyz0123456789"

func normalizeName(fr FileReader, name string) string {
    // extension
    // ---
    originalExt := extractExtension(name)
    cleanExt := extInvalidCharsRegex.ReplaceAllString(originalExt, "")
    if cleanExt == "" {
        // try to detect the extension from the file content
        cleanExt, _ = detectExtension(fr)
    }
    if extLength := len(cleanExt); extLength > 20 {
        // keep only the last 20 characters (it is multibyte safe after the regex replace)
        cleanExt = "." + cleanExt[extLength-20:]
    }

    // name
    // ---
    cleanName := inflector.Snakecase(strings.TrimSuffix(name, originalExt))
    if length := len(cleanName); length < 3 {
        // the name is too short so we concatenate an additional random part
        cleanName += security.RandomStringWithAlphabet(10, randomAlphabet)
    } else if length > 100 {
        // keep only the first 100 characters (it is multibyte safe after Snakecase)
        cleanName = cleanName[:100]
    }

    return fmt.Sprintf(
        "%s_%s%s",
        cleanName,
        security.RandomStringWithAlphabet(10, randomAlphabet), // ensure that there is always a random part
        cleanExt,
    )
}

// extractExtension extracts the extension (with leading dot) from name.
//
// This differs from filepath.Ext() by supporting double extensions (eg. ".tar.gz").
//
// Returns an empty string if no match is found.
//
// Example:
//
//	extractExtension("test.txt")      // .txt
//	extractExtension("test.tar.gz")   // .tar.gz
//	extractExtension("test.a.tar.gz") // .tar.gz
func extractExtension(name string) string {
    primaryDot := strings.LastIndex(name, ".")

    if primaryDot == -1 {
        return ""
    }

    // look for secondary extension
    secondaryDot := strings.LastIndex(name[:primaryDot], ".")
    if secondaryDot >= 0 {
        return name[secondaryDot:]
    }

    return name[primaryDot:]
}

// detectExtension tries to detect the extension from file mime type.
func detectExtension(fr FileReader) (string, error) {
    r, err := fr.Open()
    if err != nil {
        return "", err
    }
    defer r.Close()

    mt, err := mimetype.DetectReader(r)
    if err != nil {
        return "", err
    }

    return mt.Extension(), nil
}
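A short sketch tying the constructors and FileReader together (the content and name are hypothetical):

f, err := filesystem.NewFileFromBytes([]byte("hello"), "greeting.txt")
if err != nil {
    return err
}

rc, err := f.Reader.Open()
if err != nil {
    return err
}
defer rc.Close()

data, _ := io.ReadAll(rc)
fmt.Println(f.OriginalName, f.Name, len(data)) // "greeting.txt", "greeting_<random>.txt", 5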
231
tools/filesystem/file_test.go
Normal file
@@ -0,0 +1,231 @@
package filesystem_test

import (
    "context"
    "fmt"
    "net/http"
    "net/http/httptest"
    "os"
    "path/filepath"
    "regexp"
    "strconv"
    "strings"
    "testing"

    "github.com/pocketbase/pocketbase/tests"
    "github.com/pocketbase/pocketbase/tools/filesystem"
)

func TestFileAsMap(t *testing.T) {
    file, err := filesystem.NewFileFromBytes([]byte("test"), "test123.txt")
    if err != nil {
        t.Fatal(err)
    }

    result := file.AsMap()

    if len(result) != 3 {
        t.Fatalf("Expected map with %d keys, got\n%v", 3, result)
    }

    if result["size"] != int64(4) {
        t.Fatalf("Expected size %d, got %#v", 4, result["size"])
    }

    if str, ok := result["name"].(string); !ok || !strings.HasPrefix(str, "test123") {
        t.Fatalf("Expected name to have prefix %q, got %#v", "test123", result["name"])
    }

    if result["originalName"] != "test123.txt" {
        t.Fatalf("Expected originalName %q, got %#v", "test123.txt", result["originalName"])
    }
}

func TestNewFileFromPath(t *testing.T) {
    testDir := createTestDir(t)
    defer os.RemoveAll(testDir)

    // missing file
    _, err := filesystem.NewFileFromPath("missing")
    if err == nil {
        t.Fatal("Expected error, got nil")
    }

    // existing file
    originalName := "image_! noext"
    normalizedNamePattern := regexp.QuoteMeta("image_noext_") + `\w{10}` + regexp.QuoteMeta(".png")
    f, err := filesystem.NewFileFromPath(filepath.Join(testDir, originalName))
    if err != nil {
        t.Fatalf("Expected nil error, got %v", err)
    }
    if f.OriginalName != originalName {
        t.Fatalf("Expected OriginalName %q, got %q", originalName, f.OriginalName)
    }
    if match, err := regexp.Match(normalizedNamePattern, []byte(f.Name)); !match {
        t.Fatalf("Expected Name to match %v, got %q (%v)", normalizedNamePattern, f.Name, err)
    }
    if f.Size != 73 {
        t.Fatalf("Expected Size %v, got %v", 73, f.Size)
    }
    if _, ok := f.Reader.(*filesystem.PathReader); !ok {
        t.Fatalf("Expected Reader to be PathReader, got %v", f.Reader)
    }
}

func TestNewFileFromBytes(t *testing.T) {
    // nil bytes
    if _, err := filesystem.NewFileFromBytes(nil, "photo.jpg"); err == nil {
        t.Fatal("Expected error, got nil")
    }

    // zero bytes
    if _, err := filesystem.NewFileFromBytes([]byte{}, "photo.jpg"); err == nil {
        t.Fatal("Expected error, got nil")
    }

    originalName := "image_! noext"
    normalizedNamePattern := regexp.QuoteMeta("image_noext_") + `\w{10}` + regexp.QuoteMeta(".txt")
    f, err := filesystem.NewFileFromBytes([]byte("text\n"), originalName)
    if err != nil {
        t.Fatal(err)
    }
    if f.Size != 5 {
        t.Fatalf("Expected Size %v, got %v", 5, f.Size)
    }
    if f.OriginalName != originalName {
        t.Fatalf("Expected OriginalName %q, got %q", originalName, f.OriginalName)
    }
    if match, err := regexp.Match(normalizedNamePattern, []byte(f.Name)); !match {
        t.Fatalf("Expected Name to match %v, got %q (%v)", normalizedNamePattern, f.Name, err)
    }
}

func TestNewFileFromMultipart(t *testing.T) {
    formData, mp, err := tests.MockMultipartData(nil, "test")
    if err != nil {
        t.Fatal(err)
    }

    req := httptest.NewRequest("", "/", formData)
    req.Header.Set("Content-Type", mp.FormDataContentType())
    req.ParseMultipartForm(32 << 20)

    _, mh, err := req.FormFile("test")
    if err != nil {
        t.Fatal(err)
    }

    f, err := filesystem.NewFileFromMultipart(mh)
    if err != nil {
        t.Fatal(err)
    }

    originalNamePattern := regexp.QuoteMeta("tmpfile-") + `\w+` + regexp.QuoteMeta(".txt")
    if match, err := regexp.Match(originalNamePattern, []byte(f.OriginalName)); !match {
        t.Fatalf("Expected OriginalName to match %v, got %q (%v)", originalNamePattern, f.OriginalName, err)
    }

    normalizedNamePattern := regexp.QuoteMeta("tmpfile_") + `\w+\_\w{10}` + regexp.QuoteMeta(".txt")
    if match, err := regexp.Match(normalizedNamePattern, []byte(f.Name)); !match {
        t.Fatalf("Expected Name to match %v, got %q (%v)", normalizedNamePattern, f.Name, err)
    }

    if f.Size != 4 {
        t.Fatalf("Expected Size %v, got %v", 4, f.Size)
    }

    if _, ok := f.Reader.(*filesystem.MultipartReader); !ok {
        t.Fatalf("Expected Reader to be MultipartReader, got %v", f.Reader)
    }
}

func TestNewFileFromURLTimeout(t *testing.T) {
    srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
        if r.URL.Path == "/error" {
            w.WriteHeader(http.StatusInternalServerError)
        }

        fmt.Fprintf(w, "test")
    }))
    defer srv.Close()

    // cancelled context
    {
        ctx, cancel := context.WithCancel(context.Background())
        cancel()
        f, err := filesystem.NewFileFromURL(ctx, srv.URL+"/cancel")
        if err == nil {
            t.Fatal("[ctx_cancel] Expected error, got nil")
        }
        if f != nil {
            t.Fatalf("[ctx_cancel] Expected file to be nil, got %v", f)
        }
    }

    // error response
    {
        f, err := filesystem.NewFileFromURL(context.Background(), srv.URL+"/error")
        if err == nil {
            t.Fatal("[error_status] Expected error, got nil")
        }
        if f != nil {
            t.Fatalf("[error_status] Expected file to be nil, got %v", f)
        }
    }

    // valid response
    {
        originalName := "image_! noext"
        normalizedNamePattern := regexp.QuoteMeta("image_noext_") + `\w{10}` + regexp.QuoteMeta(".txt")

        f, err := filesystem.NewFileFromURL(context.Background(), srv.URL+"/"+originalName)
        if err != nil {
            t.Fatalf("[valid] Unexpected error %v", err)
        }
        if f == nil {
            t.Fatal("[valid] Expected non-nil file")
        }

        // check the created file fields
        if f.OriginalName != originalName {
            t.Fatalf("Expected OriginalName %q, got %q", originalName, f.OriginalName)
        }
        if match, err := regexp.Match(normalizedNamePattern, []byte(f.Name)); !match {
            t.Fatalf("Expected Name to match %v, got %q (%v)", normalizedNamePattern, f.Name, err)
        }
        if f.Size != 4 {
            t.Fatalf("Expected Size %v, got %v", 4, f.Size)
        }
        if _, ok := f.Reader.(*filesystem.BytesReader); !ok {
            t.Fatalf("Expected Reader to be BytesReader, got %v", f.Reader)
        }
    }
}

func TestFileNameNormalizations(t *testing.T) {
    scenarios := []struct {
        name    string
        pattern string
    }{
        {"", `^\w{10}_\w{10}\.txt$`},
        {".png", `^\w{10}_\w{10}\.png$`},
        {".tar.gz", `^\w{10}_\w{10}\.tar\.gz$`},
        {"a.tar.gz", `^a\w{10}_\w{10}\.tar\.gz$`},
        {"a.b.c.d.tar.gz", `^a_b_c_d_\w{10}\.tar\.gz$`},
        {"abcd", `^abcd_\w{10}\.txt$`},
        {"a b! c d . 456", `^a_b_c_d_\w{10}\.456$`}, // normalize spaces
        {strings.Repeat("a", 101) + "." + strings.Repeat("b", 21), `^a{100}_\w{10}\.b{20}$`}, // name and extension length trim
    }

    for i, s := range scenarios {
        t.Run(strconv.Itoa(i)+"_"+s.name, func(t *testing.T) {
            f, err := filesystem.NewFileFromBytes([]byte("abc"), s.name)
            if err != nil {
                t.Fatal(err)
            }
            if match, err := regexp.Match(s.pattern, []byte(f.Name)); !match {
                t.Fatalf("Expected Name to match %v, got %q (%v)", s.pattern, f.Name, err)
            }
        })
    }
}
564
tools/filesystem/filesystem.go
Normal file
@@ -0,0 +1,564 @@
package filesystem

import (
    "context"
    "errors"
    "image"
    "io"
    "mime/multipart"
    "net/http"
    "os"
    "path"
    "path/filepath"
    "regexp"
    "sort"
    "strconv"
    "strings"

    "github.com/disintegration/imaging"
    "github.com/fatih/color"
    "github.com/gabriel-vasile/mimetype"
    "github.com/pocketbase/pocketbase/tools/filesystem/blob"
    "github.com/pocketbase/pocketbase/tools/filesystem/internal/fileblob"
    "github.com/pocketbase/pocketbase/tools/filesystem/internal/s3blob"
    "github.com/pocketbase/pocketbase/tools/filesystem/internal/s3blob/s3"
    "github.com/pocketbase/pocketbase/tools/list"

    // explicit webp decoder because disintegration/imaging does not support webp
    _ "golang.org/x/image/webp"
)

// note: the same as blob.ErrNotFound for backward compatibility with earlier versions
var ErrNotFound = blob.ErrNotFound

const metadataOriginalName = "original-filename"

type System struct {
    ctx    context.Context
    bucket *blob.Bucket
}

// NewS3 initializes an S3 filesystem instance.
//
// NB! Make sure to call `Close()` after you are done working with it.
func NewS3(
    bucketName string,
    region string,
    endpoint string,
    accessKey string,
    secretKey string,
    s3ForcePathStyle bool,
) (*System, error) {
    ctx := context.Background() // default context

    client := &s3.S3{
        Bucket:       bucketName,
        Region:       region,
        Endpoint:     endpoint,
        AccessKey:    accessKey,
        SecretKey:    secretKey,
        UsePathStyle: s3ForcePathStyle,
    }

    drv, err := s3blob.New(client)
    if err != nil {
        return nil, err
    }

    return &System{ctx: ctx, bucket: blob.NewBucket(drv)}, nil
}

// NewLocal initializes a new local filesystem instance.
//
// NB! Make sure to call `Close()` after you are done working with it.
func NewLocal(dirPath string) (*System, error) {
    ctx := context.Background() // default context

    // make sure that the directory exists
    if err := os.MkdirAll(dirPath, os.ModePerm); err != nil {
        return nil, err
    }

    drv, err := fileblob.New(dirPath, &fileblob.Options{
        NoTempDir: true,
    })
    if err != nil {
        return nil, err
    }

    return &System{ctx: ctx, bucket: blob.NewBucket(drv)}, nil
}
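A minimal sketch of constructing and disposing a local filesystem instance (the directory path is hypothetical):

fsys, err := filesystem.NewLocal("/tmp/pb_data/storage")
if err != nil {
    return err
}
defer fsys.Close() // per the NB! note above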

// SetContext assigns the specified context to the current filesystem.
func (s *System) SetContext(ctx context.Context) {
    s.ctx = ctx
}

// Close releases any resources used for the related filesystem.
func (s *System) Close() error {
    return s.bucket.Close()
}

// Exists checks if file with fileKey path exists or not.
func (s *System) Exists(fileKey string) (bool, error) {
    return s.bucket.Exists(s.ctx, fileKey)
}

// Attributes returns the attributes for the file with fileKey path.
//
// If the file doesn't exist it returns ErrNotFound.
func (s *System) Attributes(fileKey string) (*blob.Attributes, error) {
    return s.bucket.Attributes(s.ctx, fileKey)
}

// GetReader returns a file content reader for the given fileKey.
//
// NB! Make sure to call Close() on the file after you are done working with it.
//
// If the file doesn't exist, it returns ErrNotFound.
func (s *System) GetReader(fileKey string) (*blob.Reader, error) {
    return s.bucket.NewReader(s.ctx, fileKey)
}

// Deprecated: Please use GetReader(fileKey) instead.
func (s *System) GetFile(fileKey string) (*blob.Reader, error) {
    color.Yellow("Deprecated: Please replace GetFile with GetReader.")
    return s.GetReader(fileKey)
}

// GetReuploadableFile constructs a new reuploadable File value
// from the associated fileKey blob.Reader.
//
// If preserveName is false then the returned File.Name will have
// a new randomly generated suffix, otherwise it will reuse the original one.
//
// This method could be useful in case you want to clone an existing
// Record file and assign it to a new Record (e.g. in a Record duplicate action).
//
// If you simply want to copy an existing file to a new location you
// could check the Copy(srcKey, dstKey) method.
func (s *System) GetReuploadableFile(fileKey string, preserveName bool) (*File, error) {
    attrs, err := s.Attributes(fileKey)
    if err != nil {
        return nil, err
    }

    name := path.Base(fileKey)
    originalName := attrs.Metadata[metadataOriginalName]
    if originalName == "" {
        originalName = name
    }

    file := &File{}
    file.Size = attrs.Size
    file.OriginalName = originalName
    file.Reader = openFuncAsReader(func() (io.ReadSeekCloser, error) {
        return s.GetReader(fileKey)
    })

    if preserveName {
        file.Name = name
    } else {
        file.Name = normalizeName(file.Reader, originalName)
    }

    return file, nil
}

// Copy copies the file stored at srcKey to dstKey.
//
// If srcKey file doesn't exist, it returns ErrNotFound.
//
// If dstKey file already exists, it is overwritten.
func (s *System) Copy(srcKey, dstKey string) error {
    return s.bucket.Copy(s.ctx, dstKey, srcKey)
}

// List returns a flat list with info for all files under the specified prefix.
func (s *System) List(prefix string) ([]*blob.ListObject, error) {
    files := []*blob.ListObject{}

    iter := s.bucket.List(&blob.ListOptions{
        Prefix: prefix,
    })

    for {
        obj, err := iter.Next(s.ctx)
        if err != nil {
            if !errors.Is(err, io.EOF) {
                return nil, err
            }
            break
        }
        files = append(files, obj)
    }

    return files, nil
}

// Upload writes content into the fileKey location.
func (s *System) Upload(content []byte, fileKey string) error {
    opts := &blob.WriterOptions{
        ContentType: mimetype.Detect(content).String(),
    }

    w, writerErr := s.bucket.NewWriter(s.ctx, fileKey, opts)
    if writerErr != nil {
        return writerErr
    }

    if _, err := w.Write(content); err != nil {
        return errors.Join(err, w.Close())
    }

    return w.Close()
}
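Usage sketch (the key and content are hypothetical); the content type is detected automatically from the bytes:

if err := fsys.Upload([]byte("hello"), "subdir/hello.txt"); err != nil {
    return err
}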

// UploadFile uploads the provided File to the fileKey location.
func (s *System) UploadFile(file *File, fileKey string) error {
    f, err := file.Reader.Open()
    if err != nil {
        return err
    }
    defer f.Close()

    mt, err := mimetype.DetectReader(f)
    if err != nil {
        return err
    }

    // rewind
    f.Seek(0, io.SeekStart)

    originalName := file.OriginalName
    if len(originalName) > 255 {
        // keep only the first 255 chars as a very rudimentary measure
        // to prevent the metadata from growing too big in size
        originalName = originalName[:255]
    }
    opts := &blob.WriterOptions{
        ContentType: mt.String(),
        Metadata: map[string]string{
            metadataOriginalName: originalName,
        },
    }

    w, err := s.bucket.NewWriter(s.ctx, fileKey, opts)
    if err != nil {
        return err
    }

    if _, err := w.ReadFrom(f); err != nil {
        w.Close()
        return err
    }

    return w.Close()
}
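A sketch combining the File constructors with UploadFile (the paths are hypothetical):

f, err := filesystem.NewFileFromPath("/tmp/avatar.png")
if err != nil {
    return err
}
// f.Name already carries the normalized random-suffixed name.
if err := fsys.UploadFile(f, "users/abc123/"+f.Name); err != nil {
    return err
}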

// UploadMultipart uploads the provided multipart file to the fileKey location.
func (s *System) UploadMultipart(fh *multipart.FileHeader, fileKey string) error {
    f, err := fh.Open()
    if err != nil {
        return err
    }
    defer f.Close()

    mt, err := mimetype.DetectReader(f)
    if err != nil {
        return err
    }

    // rewind
    f.Seek(0, io.SeekStart)

    originalName := fh.Filename
    if len(originalName) > 255 {
        // keep only the first 255 chars as a very rudimentary measure
        // to prevent the metadata from growing too big in size
        originalName = originalName[:255]
    }
    opts := &blob.WriterOptions{
        ContentType: mt.String(),
        Metadata: map[string]string{
            metadataOriginalName: originalName,
        },
    }

    w, err := s.bucket.NewWriter(s.ctx, fileKey, opts)
    if err != nil {
        return err
    }

    _, err = w.ReadFrom(f)
    if err != nil {
        w.Close()
        return err
    }

    return w.Close()
}

// Delete deletes the stored file at fileKey location.
//
// If the file doesn't exist, it returns ErrNotFound.
func (s *System) Delete(fileKey string) error {
    return s.bucket.Delete(s.ctx, fileKey)
}

// DeletePrefix deletes everything starting with the specified prefix.
//
// The prefix could be a subpath (ex. "/a/b/") or a filename prefix (ex. "/a/b/file_").
func (s *System) DeletePrefix(prefix string) []error {
    failed := []error{}

    if prefix == "" {
        failed = append(failed, errors.New("prefix mustn't be empty"))
        return failed
    }

    dirsMap := map[string]struct{}{}

    var isPrefixDir bool

    // treat the prefix as directory only if it ends with trailing slash
    if strings.HasSuffix(prefix, "/") {
        isPrefixDir = true
        dirsMap[strings.TrimRight(prefix, "/")] = struct{}{}
    }

    // delete all files with the prefix
    // ---
    iter := s.bucket.List(&blob.ListOptions{
        Prefix: prefix,
    })
    for {
        obj, err := iter.Next(s.ctx)
        if err != nil {
            if !errors.Is(err, io.EOF) {
                failed = append(failed, err)
            }
            break
        }

        if err := s.Delete(obj.Key); err != nil {
            failed = append(failed, err)
        } else if isPrefixDir {
            slashIdx := strings.LastIndex(obj.Key, "/")
            if slashIdx > -1 {
                dirsMap[obj.Key[:slashIdx]] = struct{}{}
            }
        }
    }
    // ---

    // try to delete the empty remaining dir objects
    // (this operation usually is optional and there is no need to strictly check the result)
    // ---
    // fill dirs slice
    dirs := make([]string, 0, len(dirsMap))
    for d := range dirsMap {
        dirs = append(dirs, d)
    }

    // sort the child dirs first, aka. ["a/b/c", "a/b", "a"]
    sort.SliceStable(dirs, func(i, j int) bool {
        return len(strings.Split(dirs[i], "/")) > len(strings.Split(dirs[j], "/"))
    })

    // delete dirs
    for _, d := range dirs {
        if d != "" {
            s.Delete(d)
        }
    }
    // ---

    return failed
}
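DeletePrefix collects failures instead of stopping at the first one, so callers inspect the returned slice; a sketch (the prefix is hypothetical):

if errs := fsys.DeletePrefix("users/abc123/"); len(errs) > 0 {
    return errors.Join(errs...)
}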

// IsEmptyDir checks if the provided dir prefix doesn't have any files.
//
// A trailing slash will be appended to a non-empty dir string argument
// to ensure that the checked prefix is a "directory".
//
// Returns "false" in case the dir has at least one file, otherwise "true".
func (s *System) IsEmptyDir(dir string) bool {
    if dir != "" && !strings.HasSuffix(dir, "/") {
        dir += "/"
    }

    iter := s.bucket.List(&blob.ListOptions{
        Prefix: dir,
    })

    _, err := iter.Next(s.ctx)

    return err != nil && errors.Is(err, io.EOF)
}
|
||||
|
||||
var inlineServeContentTypes = []string{
	// image
	"image/png", "image/jpg", "image/jpeg", "image/gif", "image/webp", "image/x-icon", "image/bmp",
	// video
	"video/webm", "video/mp4", "video/3gpp", "video/quicktime", "video/x-ms-wmv",
	// audio
	"audio/basic", "audio/aiff", "audio/mpeg", "audio/midi", "audio/mp3", "audio/wave",
	"audio/wav", "audio/x-wav", "audio/x-mpeg", "audio/x-m4a", "audio/aac",
	// document
	"application/pdf", "application/x-pdf",
}

// manualExtensionContentTypes is a map of file extensions to content types.
var manualExtensionContentTypes = map[string]string{
	".svg": "image/svg+xml",   // (see https://github.com/whatwg/mimesniff/issues/7)
	".css": "text/css",        // (see https://github.com/gabriel-vasile/mimetype/pull/113)
	".js":  "text/javascript", // (see https://github.com/pocketbase/pocketbase/issues/6597)
	".mjs": "text/javascript",
}

// forceAttachmentParam is the name of the request query parameter to
// force "Content-Disposition: attachment" header.
const forceAttachmentParam = "download"

// Serve serves the file at fileKey location to an HTTP response.
//
// If the `download` query parameter is used, the file will always be served
// as a download regardless of its type (aka. with "Content-Disposition: attachment").
//
// Internally this method uses [http.ServeContent] so Range requests,
// If-Match, If-Unmodified-Since, etc. headers are handled transparently.
func (s *System) Serve(res http.ResponseWriter, req *http.Request, fileKey string, name string) error {
	br, readErr := s.GetReader(fileKey)
	if readErr != nil {
		return readErr
	}
	defer br.Close()

	var forceAttachment bool
	if raw := req.URL.Query().Get(forceAttachmentParam); raw != "" {
		forceAttachment, _ = strconv.ParseBool(raw)
	}

	disposition := "attachment"
	realContentType := br.ContentType()
	if !forceAttachment && list.ExistInSlice(realContentType, inlineServeContentTypes) {
		disposition = "inline"
	}

	// make an exception for specific content types and force a custom
	// content type to send in the response so that it can be loaded properly
	extContentType := realContentType
	if ct, found := manualExtensionContentTypes[filepath.Ext(name)]; found {
		extContentType = ct
	}

	setHeaderIfMissing(res, "Content-Disposition", disposition+"; filename="+name)
	setHeaderIfMissing(res, "Content-Type", extContentType)
	setHeaderIfMissing(res, "Content-Security-Policy", "default-src 'none'; media-src 'self'; style-src 'unsafe-inline'; sandbox")

	// set a default cache-control header
	// (valid for 30 days but the cache is allowed to reuse the file for any requests
	// that are made in the last day while revalidating the res in the background)
	setHeaderIfMissing(res, "Cache-Control", "max-age=2592000, stale-while-revalidate=86400")

	http.ServeContent(res, req, name, br.ModTime(), br)

	return nil
}

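// Usage sketch (illustrative only, not part of the upstream file): Serve is
// intended to be called from within an HTTP handler; the "/files/{key}" route
// and the key extraction below are hypothetical:
//
//	http.HandleFunc("/files/{key}", func(w http.ResponseWriter, r *http.Request) {
//		key := r.PathValue("key") // requires Go 1.22+
//		if err := fsys.Serve(w, r, key, filepath.Base(key)); err != nil {
//			http.Error(w, "not found", http.StatusNotFound)
//		}
//	})
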
// note: expects key to be in a canonical form (eg. "accept-encoding" should be "Accept-Encoding").
func setHeaderIfMissing(res http.ResponseWriter, key string, value string) {
	if _, ok := res.Header()[key]; !ok {
		res.Header().Set(key, value)
	}
}

var ThumbSizeRegex = regexp.MustCompile(`^(\d+)x(\d+)(t|b|f)?$`)

// CreateThumb creates a new thumb image for the file at originalKey location.
// The new thumb file is stored at thumbKey location.
//
// thumbSize is in the format:
//   - 0xH  (eg. 0x100)    - resize to H height preserving the aspect ratio
//   - Wx0  (eg. 300x0)    - resize to W width preserving the aspect ratio
//   - WxH  (eg. 300x100)  - resize and crop to WxH viewbox (from center)
//   - WxHt (eg. 300x100t) - resize and crop to WxH viewbox (from top)
//   - WxHb (eg. 300x100b) - resize and crop to WxH viewbox (from bottom)
//   - WxHf (eg. 300x100f) - fit inside a WxH viewbox (without cropping)
func (s *System) CreateThumb(originalKey string, thumbKey, thumbSize string) error {
	sizeParts := ThumbSizeRegex.FindStringSubmatch(thumbSize)
	if len(sizeParts) != 4 {
		return errors.New("thumb size must be in WxH, WxHt, WxHb or WxHf format")
	}

	width, _ := strconv.Atoi(sizeParts[1])
	height, _ := strconv.Atoi(sizeParts[2])
	resizeType := sizeParts[3]

	if width == 0 && height == 0 {
		return errors.New("thumb width and height cannot be zero at the same time")
	}

	// fetch the original
	r, readErr := s.GetReader(originalKey)
	if readErr != nil {
		return readErr
	}
	defer r.Close()

	// create an imaging object from the original reader
	// (note: only the first frame for animated image formats)
	img, decodeErr := imaging.Decode(r, imaging.AutoOrientation(true))
	if decodeErr != nil {
		return decodeErr
	}

	var thumbImg *image.NRGBA

	if width == 0 || height == 0 {
		// force resize preserving aspect ratio
		thumbImg = imaging.Resize(img, width, height, imaging.Linear)
	} else {
		switch resizeType {
		case "f":
			// fit
			thumbImg = imaging.Fit(img, width, height, imaging.Linear)
		case "t":
			// fill and crop from top
			thumbImg = imaging.Fill(img, width, height, imaging.Top, imaging.Linear)
		case "b":
			// fill and crop from bottom
			thumbImg = imaging.Fill(img, width, height, imaging.Bottom, imaging.Linear)
		default:
			// fill and crop from center
			thumbImg = imaging.Fill(img, width, height, imaging.Center, imaging.Linear)
		}
	}

	opts := &blob.WriterOptions{
		ContentType: r.ContentType(),
	}

	// open a thumb storage writer (aka. prepare for upload)
	w, writerErr := s.bucket.NewWriter(s.ctx, thumbKey, opts)
	if writerErr != nil {
		return writerErr
	}

	// try to detect the thumb format based on the original file name
	// (falls back to png on error)
	format, err := imaging.FormatFromFilename(thumbKey)
	if err != nil {
		format = imaging.PNG
	}

	// thumb encode (aka. upload)
	if err := imaging.Encode(w, thumbImg, format); err != nil {
		w.Close()
		return err
	}

	// check for close errors to ensure that the thumb was really saved
	return w.Close()
}
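
// Usage sketch (illustrative only, not part of the upstream file): the
// thumbSize argument maps to the formats documented above, e.g.:
//
//	_ = fsys.CreateThumb("posts/a/photo.jpg", "thumbs/100x100_photo.jpg", "100x100")   // crop from center
//	_ = fsys.CreateThumb("posts/a/photo.jpg", "thumbs/0x300_photo.jpg", "0x300")       // height only, keep ratio
//	_ = fsys.CreateThumb("posts/a/photo.jpg", "thumbs/100x100f_photo.jpg", "100x100f") // fit without cropping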
1016
tools/filesystem/filesystem_test.go
Normal file
File diff suppressed because it is too large
84
tools/filesystem/internal/fileblob/attrs.go
Normal file
@@ -0,0 +1,84 @@
package fileblob

import (
	"encoding/json"
	"fmt"
	"os"
)

// Largely copied from gocloud.dev/blob/fileblob to apply the same
// retrieve and write side-car .attrs rules.
//
// -------------------------------------------------------------------
// Copyright 2018 The Go Cloud Development Kit Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// -------------------------------------------------------------------

const attrsExt = ".attrs"

var errAttrsExt = fmt.Errorf("file extension %q is reserved", attrsExt)

// xattrs stores extended attributes for an object. The format is like
// filesystem extended attributes, see
// https://www.freedesktop.org/wiki/CommonExtendedAttributes.
type xattrs struct {
	CacheControl       string            `json:"user.cache_control"`
	ContentDisposition string            `json:"user.content_disposition"`
	ContentEncoding    string            `json:"user.content_encoding"`
	ContentLanguage    string            `json:"user.content_language"`
	ContentType        string            `json:"user.content_type"`
	Metadata           map[string]string `json:"user.metadata"`
	MD5                []byte            `json:"md5"`
}

// setAttrs creates a "path.attrs" file alongside the blob to store its
// attributes in JSON format.
func setAttrs(path string, xa xattrs) error {
	f, err := os.Create(path + attrsExt)
	if err != nil {
		return err
	}

	if err := json.NewEncoder(f).Encode(xa); err != nil {
		f.Close()
		os.Remove(f.Name())
		return err
	}

	return f.Close()
}

// getAttrs looks at the "path.attrs" file to retrieve the attributes and
// decodes them into a xattrs struct. It doesn't return an error when there is
// no such .attrs file.
func getAttrs(path string) (xattrs, error) {
	f, err := os.Open(path + attrsExt)
	if err != nil {
		if os.IsNotExist(err) {
			// Handle gracefully for non-existent .attrs files.
			return xattrs{
				ContentType: "application/octet-stream",
			}, nil
		}
		return xattrs{}, err
	}

	xa := new(xattrs)
	if err := json.NewDecoder(f).Decode(xa); err != nil {
		f.Close()
		return xattrs{}, err
	}

	return *xa, f.Close()
}
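
// Illustrative sketch (not part of the upstream file): for a blob written at
// "pb_data/storage/avatar.png", the "pb_data/storage/avatar.png.attrs" sidecar
// would hold a single JSON object whose keys follow the xattrs struct tags, e.g.:
//
//	{"user.cache_control":"","user.content_disposition":"","user.content_encoding":"","user.content_language":"","user.content_type":"image/png","user.metadata":null,"md5":"1B2M2Y8AsgTpgAmY7PhCfg=="}
//
// (the md5 value shown is the base64 of an empty payload's MD5, purely as an example)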
713
tools/filesystem/internal/fileblob/fileblob.go
Normal file
@@ -0,0 +1,713 @@
// Package fileblob provides a blob.Bucket driver implementation.
//
// NB! To minimize breaking changes with older PocketBase releases,
// the driver is a stripped down and adapted version of the previously
// used gocloud.dev/blob/fileblob, hence many of the below doc comments,
// struct options and interface implementations are the same.
//
// To avoid partial writes, fileblob writes to a temporary file and then renames
// the temporary file to the final path on Close. By default, it creates these
// temporary files in `os.TempDir`. If `os.TempDir` is on a different mount than
// your base bucket path, the `os.Rename` will fail with `invalid cross-device link`.
// To avoid this, either configure the temp dir by setting the environment
// variable `TMPDIR`, or set `Options.NoTempDir` to `true` (fileblob will create
// the temporary files next to the actual files instead of in a temporary directory).
//
// By default fileblob stores blob metadata in "sidecar" files under the original
// filename with an additional ".attrs" suffix.
// This behaviour can be changed via `Options.Metadata`;
// writing of those metadata files can be suppressed by setting it to
// `MetadataDontWrite` or its equivalent "metadata=skip" in the URL for the opener.
// In either case, absent any stored metadata many `blob.Attributes` fields
// will be set to default values.
//
// The blob abstraction supports all UTF-8 strings; to make this work with services lacking
// full UTF-8 support, strings must be escaped (during writes) and unescaped
// (during reads). The following escapes are performed for fileblob:
//   - Blob keys: ASCII characters 0-31 are escaped to "__0x<hex>__".
//     If os.PathSeparator != "/", it is also escaped.
//     Additionally, the "/" in "../", the trailing "/" in "//", and a trailing
//     "/" in key names are escaped in the same way.
//     On Windows, the characters "<>:"|?*" are also escaped.
//
// Example:
//
//	drv, _ := fileblob.New("/path/to/dir", nil)
//	bucket := blob.NewBucket(drv)
package fileblob

import (
	"context"
	"crypto/md5"
	"errors"
	"fmt"
	"hash"
	"io"
	"io/fs"
	"os"
	"path/filepath"
	"strconv"
	"strings"
	"time"

	"github.com/pocketbase/pocketbase/tools/filesystem/blob"
)

const defaultPageSize = 1000

type metadataOption string // Not exported as subject to change.

// Settings for Options.Metadata.
const (
	// Metadata gets written to a separate file.
	MetadataInSidecar metadataOption = ""

	// Writes won't carry metadata, as per the package docstring.
	MetadataDontWrite metadataOption = "skip"
)

// Options sets options for constructing a *blob.Bucket backed by fileblob.
type Options struct {
	// Refers to the strategy for how to deal with metadata (such as blob.Attributes).
	// For supported values please see the Metadata* constants.
	// If left unchanged, 'MetadataInSidecar' will be used.
	Metadata metadataOption

	// The FileMode to use when creating directories for the top-level directory
	// backing the bucket (when CreateDir is true), and for subdirectories for keys.
	// Defaults to 0777.
	DirFileMode os.FileMode

	// If true, create the directory backing the Bucket if it does not exist
	// (using os.MkdirAll).
	CreateDir bool

	// If true, don't use os.TempDir for temporary files, but instead place them
	// next to the actual files. This may result in "stranded" temporary files
	// (e.g., if the application is killed before the file cleanup runs).
	//
	// If your bucket directory is on a different mount than os.TempDir, you will
	// need to set this to true, as os.Rename will fail across mount points.
	NoTempDir bool
}

// New creates a new instance of the fileblob driver backed by the
// filesystem and rooted at dir, which must exist.
func New(dir string, opts *Options) (blob.Driver, error) {
	if opts == nil {
		opts = &Options{}
	}
	if opts.DirFileMode == 0 {
		opts.DirFileMode = os.FileMode(0o777)
	}

	absdir, err := filepath.Abs(dir)
	if err != nil {
		return nil, fmt.Errorf("failed to convert %s into an absolute path: %v", dir, err)
	}

	// Optionally, create the directory if it does not already exist.
	info, err := os.Stat(absdir)
	if err != nil && opts.CreateDir && os.IsNotExist(err) {
		err = os.MkdirAll(absdir, opts.DirFileMode)
		if err != nil {
			return nil, fmt.Errorf("tried to create directory but failed: %v", err)
		}
		info, err = os.Stat(absdir)
	}
	if err != nil {
		return nil, err
	}

	if !info.IsDir() {
		return nil, fmt.Errorf("%s is not a directory", absdir)
	}

	return &driver{dir: absdir, opts: opts}, nil
}

type driver struct {
	opts *Options
	dir  string
}

// Close implements [blob/Driver.Close].
func (drv *driver) Close() error {
	return nil
}

// NormalizeError implements [blob/Driver.NormalizeError].
func (drv *driver) NormalizeError(err error) error {
	if os.IsNotExist(err) {
		return errors.Join(err, blob.ErrNotFound)
	}

	return err
}

// path returns the full path for a key.
func (drv *driver) path(key string) (string, error) {
	path := filepath.Join(drv.dir, escapeKey(key))

	if strings.HasSuffix(path, attrsExt) {
		return "", errAttrsExt
	}

	return path, nil
}

// forKey returns the full path, os.FileInfo, and attributes for key.
func (drv *driver) forKey(key string) (string, os.FileInfo, *xattrs, error) {
	path, err := drv.path(key)
	if err != nil {
		return "", nil, nil, err
	}

	info, err := os.Stat(path)
	if err != nil {
		return "", nil, nil, err
	}

	if info.IsDir() {
		return "", nil, nil, os.ErrNotExist
	}

	xa, err := getAttrs(path)
	if err != nil {
		return "", nil, nil, err
	}

	return path, info, &xa, nil
}

// ListPaged implements [blob/Driver.ListPaged].
func (drv *driver) ListPaged(ctx context.Context, opts *blob.ListOptions) (*blob.ListPage, error) {
	var pageToken string
	if len(opts.PageToken) > 0 {
		pageToken = string(opts.PageToken)
	}

	pageSize := opts.PageSize
	if pageSize == 0 {
		pageSize = defaultPageSize
	}

	// If opts.Delimiter != "", lastPrefix contains the last "directory" key we
	// added. It is used to avoid adding it again; all files in this "directory"
	// are collapsed to the single directory entry.
	var lastPrefix string
	var lastKeyAdded string

	// If the Prefix contains a "/", we can set the root of the Walk
	// to the path specified by the Prefix as any files below the path will not
	// match the Prefix.
	// Note that we use "/" explicitly and not os.PathSeparator, as the opts.Prefix
	// is in the unescaped form.
	root := drv.dir
	if i := strings.LastIndex(opts.Prefix, "/"); i > -1 {
		root = filepath.Join(root, opts.Prefix[:i])
	}

	var result blob.ListPage

	// Do a full recursive scan of the root directory.
	err := filepath.WalkDir(root, func(path string, info fs.DirEntry, err error) error {
		if err != nil {
			// Couldn't read this file/directory for some reason; just skip it.
			return nil
		}

		// Skip the self-generated attribute files.
		if strings.HasSuffix(path, attrsExt) {
			return nil
		}

		// os.Walk returns the root directory; skip it.
		if path == drv.dir {
			return nil
		}

		// Strip the <drv.dir> prefix from path.
		prefixLen := len(drv.dir)
		// Include the separator for non-root.
		if drv.dir != "/" {
			prefixLen++
		}
		path = path[prefixLen:]

		// Unescape the path to get the key.
		key := unescapeKey(path)

		// Skip all directories. If opts.Delimiter is set, we'll create
		// pseudo-directories later.
		// Note that returning nil means that we'll still recurse into it;
		// we're just not adding a result for the directory itself.
		if info.IsDir() {
			key += "/"
			// Avoid recursing into subdirectories if the directory name already
			// doesn't match the prefix; any files in it are guaranteed not to match.
			if len(key) > len(opts.Prefix) && !strings.HasPrefix(key, opts.Prefix) {
				return filepath.SkipDir
			}
			// Similarly, avoid recursing into subdirectories if we're making
			// "directories" and all of the files in this subdirectory are guaranteed
			// to collapse to a "directory" that we've already added.
			if lastPrefix != "" && strings.HasPrefix(key, lastPrefix) {
				return filepath.SkipDir
			}
			return nil
		}

		// Skip files/directories that don't match the Prefix.
		if !strings.HasPrefix(key, opts.Prefix) {
			return nil
		}

		var md5 []byte
		if xa, err := getAttrs(path); err == nil {
			// Note: we only have the MD5 hash for blobs that we wrote.
			// For other blobs, md5 will remain nil.
			md5 = xa.MD5
		}

		fi, err := info.Info()
		if err != nil {
			return err
		}

		obj := &blob.ListObject{
			Key:     key,
			ModTime: fi.ModTime(),
			Size:    fi.Size(),
			MD5:     md5,
		}

		// If using Delimiter, collapse "directories".
		if opts.Delimiter != "" {
			// Strip the prefix, which may contain Delimiter.
			keyWithoutPrefix := key[len(opts.Prefix):]
			// See if the key still contains Delimiter.
			// If no, it's a file and we just include it.
			// If yes, it's a file in a "sub-directory" and we want to collapse
			// all files in that "sub-directory" into a single "directory" result.
			if idx := strings.Index(keyWithoutPrefix, opts.Delimiter); idx != -1 {
				prefix := opts.Prefix + keyWithoutPrefix[0:idx+len(opts.Delimiter)]
				// We've already included this "directory"; don't add it.
				if prefix == lastPrefix {
					return nil
				}
				// Update the object to be a "directory".
				obj = &blob.ListObject{
					Key:   prefix,
					IsDir: true,
				}
				lastPrefix = prefix
			}
		}

		// If there's a pageToken, skip anything before it.
		if pageToken != "" && obj.Key <= pageToken {
			return nil
		}

		// If we've already got a full page of results, set NextPageToken and stop.
		// Unless the current object is a directory, in which case there may
		// still be objects coming that are alphabetically before it (since
		// we appended the delimiter). In that case, keep going; we'll trim the
		// extra entries (if any) before returning.
		if len(result.Objects) == pageSize && !obj.IsDir {
			result.NextPageToken = []byte(result.Objects[pageSize-1].Key)
			return io.EOF
		}

		result.Objects = append(result.Objects, obj)

		// Normally, objects are added in the correct order (by Key).
		// However, sometimes adding the file delimiter messes that up
		// (e.g., if the file delimiter is later in the alphabet than the last character of a key).
		// Detect if this happens and swap if needed.
		if len(result.Objects) > 1 && obj.Key < lastKeyAdded {
			i := len(result.Objects) - 1
			result.Objects[i-1], result.Objects[i] = result.Objects[i], result.Objects[i-1]
			lastKeyAdded = result.Objects[i].Key
		} else {
			lastKeyAdded = obj.Key
		}

		return nil
	})
	if err != nil && err != io.EOF {
		return nil, err
	}

	if len(result.Objects) > pageSize {
		result.Objects = result.Objects[0:pageSize]
		result.NextPageToken = []byte(result.Objects[pageSize-1].Key)
	}

	return &result, nil
}

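// Paging sketch (illustrative only, not part of the upstream file): callers
// drive pagination by feeding NextPageToken back in as PageToken until it
// comes back empty:
//
//	var token []byte
//	for {
//		page, err := drv.ListPaged(ctx, &blob.ListOptions{Prefix: "a/", PageToken: token})
//		if err != nil {
//			break
//		}
//		// ... consume page.Objects ...
//		if len(page.NextPageToken) == 0 {
//			break
//		}
//		token = page.NextPageToken
//	}
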
// Attributes implements [blob/Driver.Attributes].
func (drv *driver) Attributes(ctx context.Context, key string) (*blob.Attributes, error) {
	_, info, xa, err := drv.forKey(key)
	if err != nil {
		return nil, err
	}

	return &blob.Attributes{
		CacheControl:       xa.CacheControl,
		ContentDisposition: xa.ContentDisposition,
		ContentEncoding:    xa.ContentEncoding,
		ContentLanguage:    xa.ContentLanguage,
		ContentType:        xa.ContentType,
		Metadata:           xa.Metadata,
		// CreateTime left as the zero time.
		ModTime: info.ModTime(),
		Size:    info.Size(),
		MD5:     xa.MD5,
		ETag:    fmt.Sprintf("\"%x-%x\"", info.ModTime().UnixNano(), info.Size()),
	}, nil
}

// NewRangeReader implements [blob/Driver.NewRangeReader].
func (drv *driver) NewRangeReader(ctx context.Context, key string, offset, length int64) (blob.DriverReader, error) {
	path, info, xa, err := drv.forKey(key)
	if err != nil {
		return nil, err
	}

	f, err := os.Open(path)
	if err != nil {
		return nil, err
	}

	if offset > 0 {
		if _, err := f.Seek(offset, io.SeekStart); err != nil {
			return nil, err
		}
	}

	r := io.Reader(f)
	if length >= 0 {
		r = io.LimitReader(r, length)
	}

	return &reader{
		r: r,
		c: f,
		attrs: &blob.ReaderAttributes{
			ContentType: xa.ContentType,
			ModTime:     info.ModTime(),
			Size:        info.Size(),
		},
	}, nil
}

func createTemp(path string, noTempDir bool) (*os.File, error) {
	// Use a custom createTemp function rather than os.CreateTemp() as
	// os.CreateTemp() sets the permissions of the tempfile to 0600, rather than
	// 0666, making it inconsistent with the directories and attribute files.
	try := 0
	for {
		// Append the current time with nanosecond precision and .tmp to the
		// base path. If the file already exists try again. Nanosecond changes enough
		// between each iteration to make a conflict unlikely. Using the full
		// time lowers the chance of a collision with a file using a similar
		// pattern, but has undefined behavior after the year 2262.
		var name string
		if noTempDir {
			name = path
		} else {
			name = filepath.Join(os.TempDir(), filepath.Base(path))
		}
		name += "." + strconv.FormatInt(time.Now().UnixNano(), 16) + ".tmp"

		f, err := os.OpenFile(name, os.O_RDWR|os.O_CREATE|os.O_EXCL, 0o666)
		if os.IsExist(err) {
			if try++; try < 10000 {
				continue
			}
			return nil, &os.PathError{Op: "createtemp", Path: path + ".*.tmp", Err: os.ErrExist}
		}

		return f, err
	}
}

// NewTypedWriter implements [blob/Driver.NewTypedWriter].
func (drv *driver) NewTypedWriter(ctx context.Context, key, contentType string, opts *blob.WriterOptions) (blob.DriverWriter, error) {
	path, err := drv.path(key)
	if err != nil {
		return nil, err
	}

	err = os.MkdirAll(filepath.Dir(path), drv.opts.DirFileMode)
	if err != nil {
		return nil, err
	}

	f, err := createTemp(path, drv.opts.NoTempDir)
	if err != nil {
		return nil, err
	}

	if drv.opts.Metadata == MetadataDontWrite {
		w := &writer{
			ctx:  ctx,
			File: f,
			path: path,
		}
		return w, nil
	}

	var metadata map[string]string
	if len(opts.Metadata) > 0 {
		metadata = opts.Metadata
	}

	return &writerWithSidecar{
		ctx:        ctx,
		f:          f,
		path:       path,
		contentMD5: opts.ContentMD5,
		md5hash:    md5.New(),
		attrs: xattrs{
			CacheControl:       opts.CacheControl,
			ContentDisposition: opts.ContentDisposition,
			ContentEncoding:    opts.ContentEncoding,
			ContentLanguage:    opts.ContentLanguage,
			ContentType:        contentType,
			Metadata:           metadata,
		},
	}, nil
}

// Copy implements [blob/Driver.Copy].
func (drv *driver) Copy(ctx context.Context, dstKey, srcKey string) error {
	// Note: we could use NewRangeReader here, but since we need to copy all of
	// the metadata (from xa), it's more efficient to do it directly.
	srcPath, _, xa, err := drv.forKey(srcKey)
	if err != nil {
		return err
	}

	f, err := os.Open(srcPath)
	if err != nil {
		return err
	}
	defer f.Close()

	// We'll write the copy using Writer, to avoid re-implementing making of a
	// temp file, cleaning up after partial failures, etc.
	wopts := blob.WriterOptions{
		CacheControl:       xa.CacheControl,
		ContentDisposition: xa.ContentDisposition,
		ContentEncoding:    xa.ContentEncoding,
		ContentLanguage:    xa.ContentLanguage,
		Metadata:           xa.Metadata,
	}

	// Create a cancelable context so we can cancel the write if there are problems.
	writeCtx, cancel := context.WithCancel(ctx)
	defer cancel()

	w, err := drv.NewTypedWriter(writeCtx, dstKey, xa.ContentType, &wopts)
	if err != nil {
		return err
	}

	_, err = io.Copy(w, f)
	if err != nil {
		cancel() // cancel before Close cancels the write
		w.Close()
		return err
	}

	return w.Close()
}

// Delete implements [blob/Driver.Delete].
func (b *driver) Delete(ctx context.Context, key string) error {
	path, err := b.path(key)
	if err != nil {
		return err
	}

	err = os.Remove(path)
	if err != nil {
		return err
	}

	err = os.Remove(path + attrsExt)
	if err != nil && !os.IsNotExist(err) {
		return err
	}

	return nil
}

// -------------------------------------------------------------------

type reader struct {
	r     io.Reader
	c     io.Closer
	attrs *blob.ReaderAttributes
}

func (r *reader) Read(p []byte) (int, error) {
	if r.r == nil {
		return 0, io.EOF
	}
	return r.r.Read(p)
}

func (r *reader) Close() error {
	if r.c == nil {
		return nil
	}
	return r.c.Close()
}

// Attributes implements [blob/DriverReader.Attributes].
func (r *reader) Attributes() *blob.ReaderAttributes {
	return r.attrs
}

// -------------------------------------------------------------------

// writerWithSidecar implements the strategy of storing metadata in a distinct file.
type writerWithSidecar struct {
	ctx        context.Context
	md5hash    hash.Hash
	f          *os.File
	path       string
	attrs      xattrs
	contentMD5 []byte
}

func (w *writerWithSidecar) Write(p []byte) (n int, err error) {
	n, err = w.f.Write(p)
	if err != nil {
		// Don't hash the unwritten tail twice when writing is resumed.
		w.md5hash.Write(p[:n])
		return n, err
	}

	if _, err := w.md5hash.Write(p); err != nil {
		return n, err
	}

	return n, nil
}

func (w *writerWithSidecar) Close() error {
	err := w.f.Close()
	if err != nil {
		return err
	}

	// Always delete the temp file. On success, it will have been
	// renamed so the Remove will fail.
	defer func() {
		_ = os.Remove(w.f.Name())
	}()

	// Check if the write was cancelled.
	if err := w.ctx.Err(); err != nil {
		return err
	}

	md5sum := w.md5hash.Sum(nil)
	w.attrs.MD5 = md5sum

	// Write the attributes file.
	if err := setAttrs(w.path, w.attrs); err != nil {
		return err
	}

	// Rename the temp file to path.
	if err := os.Rename(w.f.Name(), w.path); err != nil {
		_ = os.Remove(w.path + attrsExt)
		return err
	}

	return nil
}

// writer is a file with a temporary name until closed.
//
// Embedding os.File allows the likes of io.Copy to use optimizations,
// which is why it is not folded into writerWithSidecar.
type writer struct {
	*os.File
	ctx  context.Context
	path string
}

func (w *writer) Close() error {
	err := w.File.Close()
	if err != nil {
		return err
	}

	// Always delete the temp file. On success, it will have been renamed so
	// the Remove will fail.
	tempname := w.Name()
	defer os.Remove(tempname)

	// Check if the write was cancelled.
	if err := w.ctx.Err(); err != nil {
		return err
	}

	// Rename the temp file to path.
	return os.Rename(tempname, w.path)
}

// -------------------------------------------------------------------

// escapeKey does all required escaping for UTF-8 strings to work with the filesystem.
func escapeKey(s string) string {
	s = blob.HexEscape(s, func(r []rune, i int) bool {
		c := r[i]
		switch {
		case c < 32:
			return true
		// We're going to replace '/' with os.PathSeparator below. In order for this
		// to be reversible, we need to escape raw os.PathSeparators.
		case os.PathSeparator != '/' && c == os.PathSeparator:
			return true
		// For "../", escape the trailing slash.
		case i > 1 && c == '/' && r[i-1] == '.' && r[i-2] == '.':
			return true
		// For "//", escape the trailing slash.
		case i > 0 && c == '/' && r[i-1] == '/':
			return true
		// Escape the trailing slash in a key.
		case c == '/' && i == len(r)-1:
			return true
		// https://docs.microsoft.com/en-us/windows/desktop/fileio/naming-a-file
		case os.PathSeparator == '\\' && (c == '>' || c == '<' || c == ':' || c == '"' || c == '|' || c == '?' || c == '*'):
			return true
		}
		return false
	})

	// Replace "/" with os.PathSeparator if needed, so that the local filesystem
	// can use subdirectories.
	if os.PathSeparator != '/' {
		s = strings.ReplaceAll(s, "/", string(os.PathSeparator))
	}

	return s
}

// unescapeKey reverses escapeKey.
func unescapeKey(s string) string {
	if os.PathSeparator != '/' {
		s = strings.ReplaceAll(s, string(os.PathSeparator), "/")
	}

	return blob.HexUnescape(s)
}
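
// Escaping sketch (illustrative only, not part of the upstream file), assuming
// a '/' path separator:
//
//	escapeKey("a//b")          // "a/__0x2f__b" (the trailing "/" of "//" is hex-escaped)
//	escapeKey("a/b/")          // "a/b__0x2f__" (a trailing slash in a key is hex-escaped)
//	unescapeKey("a/b__0x2f__") // "a/b/"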
59
tools/filesystem/internal/s3blob/s3/copy_object.go
Normal file
@@ -0,0 +1,59 @@
package s3

import (
	"context"
	"encoding/xml"
	"net/http"
	"net/url"
	"strings"
	"time"
)

// https://docs.aws.amazon.com/AmazonS3/latest/API/API_CopyObject.html#API_CopyObject_ResponseSyntax
type CopyObjectResponse struct {
	CopyObjectResult  xml.Name  `json:"copyObjectResult" xml:"CopyObjectResult"`
	ETag              string    `json:"etag" xml:"ETag"`
	LastModified      time.Time `json:"lastModified" xml:"LastModified"`
	ChecksumType      string    `json:"checksumType" xml:"ChecksumType"`
	ChecksumCRC32     string    `json:"checksumCRC32" xml:"ChecksumCRC32"`
	ChecksumCRC32C    string    `json:"checksumCRC32C" xml:"ChecksumCRC32C"`
	ChecksumCRC64NVME string    `json:"checksumCRC64NVME" xml:"ChecksumCRC64NVME"`
	ChecksumSHA1      string    `json:"checksumSHA1" xml:"ChecksumSHA1"`
	ChecksumSHA256    string    `json:"checksumSHA256" xml:"ChecksumSHA256"`
}

// CopyObject copies a single object from srcKey to the dstKey destination
// (both keys are expected to be within the same bucket).
//
// https://docs.aws.amazon.com/AmazonS3/latest/API/API_CopyObject.html
func (s3 *S3) CopyObject(ctx context.Context, srcKey string, dstKey string, optReqFuncs ...func(*http.Request)) (*CopyObjectResponse, error) {
	req, err := http.NewRequestWithContext(ctx, http.MethodPut, s3.URL(dstKey), nil)
	if err != nil {
		return nil, err
	}

	// per the doc the header value must be URL-encoded
	req.Header.Set("x-amz-copy-source", url.PathEscape(s3.Bucket+"/"+strings.TrimLeft(srcKey, "/")))

	// apply optional request funcs
	for _, fn := range optReqFuncs {
		if fn != nil {
			fn(req)
		}
	}

	resp, err := s3.SignAndSend(req)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()

	result := &CopyObjectResponse{}

	err = xml.NewDecoder(resp.Body).Decode(result)
	if err != nil {
		return nil, err
	}

	return result, nil
}
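
// Usage sketch (illustrative only, not part of the upstream file — "s3Client"
// is an assumed, already configured *S3):
//
//	resp, err := s3Client.CopyObject(ctx, "old/key.txt", "new/key.txt")
//	if err == nil {
//		log.Println("copied, new ETag:", resp.ETag)
//	}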
67
tools/filesystem/internal/s3blob/s3/copy_object_test.go
Normal file
@@ -0,0 +1,67 @@
package s3_test

import (
	"context"
	"io"
	"net/http"
	"strings"
	"testing"

	"github.com/pocketbase/pocketbase/tools/filesystem/internal/s3blob/s3"
	"github.com/pocketbase/pocketbase/tools/filesystem/internal/s3blob/s3/tests"
)

func TestS3CopyObject(t *testing.T) {
	t.Parallel()

	httpClient := tests.NewClient(
		&tests.RequestStub{
			Method: http.MethodPut,
			URL:    "http://test_bucket.example.com/@dst_test",
			Match: func(req *http.Request) bool {
				return tests.ExpectHeaders(req.Header, map[string]string{
					"test_header":       "test",
					"x-amz-copy-source": "test_bucket%2F@src_test",
					"Authorization":     "^.+Credential=123/.+$",
				})
			},
			Response: &http.Response{
				Body: io.NopCloser(strings.NewReader(`
					<CopyObjectResult>
						<LastModified>2025-01-01T01:02:03.456Z</LastModified>
						<ETag>test_etag</ETag>
					</CopyObjectResult>
				`)),
			},
		},
	)

	s3Client := &s3.S3{
		Client:    httpClient,
		Region:    "test_region",
		Bucket:    "test_bucket",
		Endpoint:  "http://example.com",
		AccessKey: "123",
		SecretKey: "abc",
	}

	copyResp, err := s3Client.CopyObject(context.Background(), "@src_test", "@dst_test", func(r *http.Request) {
		r.Header.Set("test_header", "test")
	})
	if err != nil {
		t.Fatal(err)
	}

	err = httpClient.AssertNoRemaining()
	if err != nil {
		t.Fatal(err)
	}

	if copyResp.ETag != "test_etag" {
		t.Fatalf("Expected ETag %q, got %q", "test_etag", copyResp.ETag)
	}

	if date := copyResp.LastModified.Format("2006-01-02T15:04:05.000Z"); date != "2025-01-01T01:02:03.456Z" {
		t.Fatalf("Expected LastModified %q, got %q", "2025-01-01T01:02:03.456Z", date)
	}
}
31
tools/filesystem/internal/s3blob/s3/delete_object.go
Normal file
@@ -0,0 +1,31 @@
package s3

import (
	"context"
	"net/http"
)

// DeleteObject deletes a single object by its key.
//
// https://docs.aws.amazon.com/AmazonS3/latest/API/API_DeleteObject.html
func (s3 *S3) DeleteObject(ctx context.Context, key string, optFuncs ...func(*http.Request)) error {
	req, err := http.NewRequestWithContext(ctx, http.MethodDelete, s3.URL(key), nil)
	if err != nil {
		return err
	}

	// apply optional request funcs
	for _, fn := range optFuncs {
		if fn != nil {
			fn(req)
		}
	}

	resp, err := s3.SignAndSend(req)
	if err != nil {
		return err
	}
	defer resp.Body.Close()

	return nil
}
48
tools/filesystem/internal/s3blob/s3/delete_object_test.go
Normal file
@@ -0,0 +1,48 @@
package s3_test

import (
	"context"
	"net/http"
	"testing"

	"github.com/pocketbase/pocketbase/tools/filesystem/internal/s3blob/s3"
	"github.com/pocketbase/pocketbase/tools/filesystem/internal/s3blob/s3/tests"
)

func TestS3DeleteObject(t *testing.T) {
	t.Parallel()

	httpClient := tests.NewClient(
		&tests.RequestStub{
			Method: http.MethodDelete,
			URL:    "http://test_bucket.example.com/test_key",
			Match: func(req *http.Request) bool {
				return tests.ExpectHeaders(req.Header, map[string]string{
					"test_header":   "test",
					"Authorization": "^.+Credential=123/.+$",
				})
			},
		},
	)

	s3Client := &s3.S3{
		Client:    httpClient,
		Region:    "test_region",
		Bucket:    "test_bucket",
		Endpoint:  "http://example.com",
		AccessKey: "123",
		SecretKey: "abc",
	}

	err := s3Client.DeleteObject(context.Background(), "test_key", func(r *http.Request) {
		r.Header.Set("test_header", "test")
	})
	if err != nil {
		t.Fatal(err)
	}

	err = httpClient.AssertNoRemaining()
	if err != nil {
		t.Fatal(err)
	}
}
49
tools/filesystem/internal/s3blob/s3/error.go
Normal file
@@ -0,0 +1,49 @@
package s3

import (
	"encoding/xml"
	"strconv"
	"strings"
)

var _ error = (*ResponseError)(nil)

// ResponseError defines a general S3 response error.
//
// https://docs.aws.amazon.com/AmazonS3/latest/API/ErrorResponses.html
type ResponseError struct {
	XMLName   xml.Name `json:"-" xml:"Error"`
	Code      string   `json:"code" xml:"Code"`
	Message   string   `json:"message" xml:"Message"`
	RequestId string   `json:"requestId" xml:"RequestId"`
	Resource  string   `json:"resource" xml:"Resource"`
	Raw       []byte   `json:"-" xml:"-"`
	Status    int      `json:"status" xml:"Status"`
}

// Error implements the std error interface.
func (err *ResponseError) Error() string {
	var strBuilder strings.Builder

	strBuilder.WriteString(strconv.Itoa(err.Status))
	strBuilder.WriteString(" ")

	if err.Code != "" {
		strBuilder.WriteString(err.Code)
	} else {
		strBuilder.WriteString("S3ResponseError")
	}

	if err.Message != "" {
		strBuilder.WriteString(": ")
		strBuilder.WriteString(err.Message)
	}

	if len(err.Raw) > 0 {
		strBuilder.WriteString("\n(RAW: ")
		strBuilder.Write(err.Raw)
		strBuilder.WriteString(")")
	}

	return strBuilder.String()
}
86
tools/filesystem/internal/s3blob/s3/error_test.go
Normal file
@@ -0,0 +1,86 @@
package s3_test

import (
	"encoding/json"
	"encoding/xml"
	"testing"

	"github.com/pocketbase/pocketbase/tools/filesystem/internal/s3blob/s3"
)

func TestResponseErrorSerialization(t *testing.T) {
	raw := `
		<?xml version="1.0" encoding="UTF-8"?>
		<Error>
			<Code>test_code</Code>
			<Message>test_message</Message>
			<RequestId>test_request_id</RequestId>
			<Resource>test_resource</Resource>
		</Error>
	`

	respErr := &s3.ResponseError{
		Status: 123,
		Raw:    []byte("test"),
	}

	err := xml.Unmarshal([]byte(raw), &respErr)
	if err != nil {
		t.Fatal(err)
	}

	jsonRaw, err := json.Marshal(respErr)
	if err != nil {
		t.Fatal(err)
	}
	jsonStr := string(jsonRaw)

	expected := `{"code":"test_code","message":"test_message","requestId":"test_request_id","resource":"test_resource","status":123}`

	if expected != jsonStr {
		t.Fatalf("Expected JSON\n%s\ngot\n%s", expected, jsonStr)
	}
}

func TestResponseErrorErrorInterface(t *testing.T) {
	scenarios := []struct {
		name     string
		err      *s3.ResponseError
		expected string
	}{
		{
			"empty",
			&s3.ResponseError{},
			"0 S3ResponseError",
		},
		{
			"with code and message (nil raw)",
			&s3.ResponseError{
				Status:  123,
				Code:    "test_code",
				Message: "test_message",
			},
			"123 test_code: test_message",
		},
		{
			"with code and message (non-nil raw)",
			&s3.ResponseError{
				Status:  123,
				Code:    "test_code",
				Message: "test_message",
				Raw:     []byte("test_raw"),
			},
			"123 test_code: test_message\n(RAW: test_raw)",
		},
	}

	for _, s := range scenarios {
		t.Run(s.name, func(t *testing.T) {
			result := s.err.Error()

			if result != s.expected {
				t.Fatalf("Expected\n%s\ngot\n%s", s.expected, result)
			}
		})
	}
}
43
tools/filesystem/internal/s3blob/s3/get_object.go
Normal file
@@ -0,0 +1,43 @@
package s3

import (
	"context"
	"io"
	"net/http"
)

// https://docs.aws.amazon.com/AmazonS3/latest/API/API_GetObject.html#API_GetObject_ResponseElements
type GetObjectResponse struct {
	Body io.ReadCloser `json:"-" xml:"-"`

	HeadObjectResponse
}

// GetObject retrieves a single object by its key.
//
// NB! Make sure to call GetObjectResponse.Body.Close() when done working with the result.
//
// https://docs.aws.amazon.com/AmazonS3/latest/API/API_GetObject.html
func (s3 *S3) GetObject(ctx context.Context, key string, optFuncs ...func(*http.Request)) (*GetObjectResponse, error) {
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, s3.URL(key), nil)
	if err != nil {
		return nil, err
	}

	// apply optional request funcs
	for _, fn := range optFuncs {
		if fn != nil {
			fn(req)
		}
	}

	resp, err := s3.SignAndSend(req)
	if err != nil {
		return nil, err
	}

	result := &GetObjectResponse{Body: resp.Body}
	result.load(resp.Header)

	return result, nil
}
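
// Usage sketch (illustrative only, not part of the upstream file) — note the
// mandatory Body.Close():
//
//	obj, err := s3Client.GetObject(ctx, "test_key")
//	if err != nil {
//		return err
//	}
//	defer obj.Body.Close()
//
//	data, err := io.ReadAll(obj.Body)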
92
tools/filesystem/internal/s3blob/s3/get_object_test.go
Normal file
@@ -0,0 +1,92 @@
package s3_test

import (
	"context"
	"encoding/json"
	"io"
	"net/http"
	"strings"
	"testing"

	"github.com/pocketbase/pocketbase/tools/filesystem/internal/s3blob/s3"
	"github.com/pocketbase/pocketbase/tools/filesystem/internal/s3blob/s3/tests"
)

func TestS3GetObject(t *testing.T) {
	t.Parallel()

	httpClient := tests.NewClient(
		&tests.RequestStub{
			Method: http.MethodGet,
			URL:    "http://test_bucket.example.com/test_key",
			Match: func(req *http.Request) bool {
				return tests.ExpectHeaders(req.Header, map[string]string{
					"test_header":   "test",
					"Authorization": "^.+Credential=123/.+$",
				})
			},
			Response: &http.Response{
				Header: http.Header{
					"Last-Modified":       []string{"Mon, 01 Feb 2025 03:04:05 GMT"},
					"Cache-Control":       []string{"test_cache"},
					"Content-Disposition": []string{"test_disposition"},
					"Content-Encoding":    []string{"test_encoding"},
					"Content-Language":    []string{"test_language"},
					"Content-Type":        []string{"test_type"},
					"Content-Range":       []string{"test_range"},
					"Etag":                []string{"test_etag"},
					"Content-Length":      []string{"100"},
					"x-amz-meta-AbC":      []string{"test_meta_a"},
					"x-amz-meta-Def":      []string{"test_meta_b"},
				},
				Body: io.NopCloser(strings.NewReader("test")),
			},
		},
	)

	s3Client := &s3.S3{
		Client:    httpClient,
		Region:    "test_region",
		Bucket:    "test_bucket",
		Endpoint:  "http://example.com",
		AccessKey: "123",
		SecretKey: "abc",
	}

	resp, err := s3Client.GetObject(context.Background(), "test_key", func(r *http.Request) {
		r.Header.Set("test_header", "test")
	})
	if err != nil {
		t.Fatal(err)
	}
	defer resp.Body.Close()

	err = httpClient.AssertNoRemaining()
	if err != nil {
		t.Fatal(err)
	}

	// check body
	body, err := io.ReadAll(resp.Body)
	if err != nil {
		t.Fatal(err)
	}
	bodyStr := string(body)

	if bodyStr != "test" {
		t.Fatalf("Expected body\n%q\ngot\n%q", "test", bodyStr)
	}

	// check serialized attributes
	raw, err := json.Marshal(resp)
	if err != nil {
		t.Fatal(err)
	}
	rawStr := string(raw)

	expected := `{"metadata":{"abc":"test_meta_a","def":"test_meta_b"},"lastModified":"2025-02-01T03:04:05Z","cacheControl":"test_cache","contentDisposition":"test_disposition","contentEncoding":"test_encoding","contentLanguage":"test_language","contentType":"test_type","contentRange":"test_range","etag":"test_etag","contentLength":100}`

	if rawStr != expected {
		t.Fatalf("Expected attributes\n%s\ngot\n%s", expected, rawStr)
	}
}
89
tools/filesystem/internal/s3blob/s3/head_object.go
Normal file
@@ -0,0 +1,89 @@
package s3

import (
	"context"
	"net/http"
	"strconv"
	"time"
)

// https://docs.aws.amazon.com/AmazonS3/latest/API/API_HeadObject.html#API_HeadObject_ResponseElements
type HeadObjectResponse struct {
	// Metadata is the extra data that is stored with the S3 object (aka. the "x-amz-meta-*" header values).
	//
	// The map keys are normalized to lower-case.
	Metadata map[string]string `json:"metadata"`

	// LastModified is the date and time when the object was last modified.
	LastModified time.Time `json:"lastModified"`

	// CacheControl specifies caching behavior along the request/reply chain.
	CacheControl string `json:"cacheControl"`

	// ContentDisposition specifies presentational information for the object.
	ContentDisposition string `json:"contentDisposition"`

	// ContentEncoding indicates what content encodings have been applied to the object
	// and thus what decoding mechanisms must be applied to obtain the
	// media-type referenced by the Content-Type header field.
	ContentEncoding string `json:"contentEncoding"`

	// ContentLanguage indicates the language the content is in.
	ContentLanguage string `json:"contentLanguage"`

	// ContentType is a standard MIME type describing the format of the object data.
	ContentType string `json:"contentType"`

	// ContentRange is the portion of the object usually returned in the response for a GET request.
	ContentRange string `json:"contentRange"`

	// ETag is an opaque identifier assigned by a web
	// server to a specific version of a resource found at a URL.
	ETag string `json:"etag"`

	// ContentLength is the size of the body in bytes.
	ContentLength int64 `json:"contentLength"`
}

// load parses and loads the header values into the current HeadObjectResponse fields.
func (o *HeadObjectResponse) load(headers http.Header) {
	o.LastModified, _ = time.Parse(time.RFC1123, headers.Get("Last-Modified"))
	o.CacheControl = headers.Get("Cache-Control")
	o.ContentDisposition = headers.Get("Content-Disposition")
	o.ContentEncoding = headers.Get("Content-Encoding")
	o.ContentLanguage = headers.Get("Content-Language")
	o.ContentType = headers.Get("Content-Type")
	o.ContentRange = headers.Get("Content-Range")
	o.ETag = headers.Get("ETag")
	o.ContentLength, _ = strconv.ParseInt(headers.Get("Content-Length"), 10, 0)
	o.Metadata = extractMetadata(headers)
}

// HeadObject sends a HEAD request for a single object to check its
// existence and to retrieve its metadata.
//
// https://docs.aws.amazon.com/AmazonS3/latest/API/API_HeadObject.html
func (s3 *S3) HeadObject(ctx context.Context, key string, optFuncs ...func(*http.Request)) (*HeadObjectResponse, error) {
	req, err := http.NewRequestWithContext(ctx, http.MethodHead, s3.URL(key), nil)
	if err != nil {
		return nil, err
	}

	// apply optional request funcs
	for _, fn := range optFuncs {
		if fn != nil {
			fn(req)
		}
	}

	resp, err := s3.SignAndSend(req)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()

	result := &HeadObjectResponse{}
	result.load(resp.Header)

	return result, nil
}
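
// Usage sketch (illustrative only, not part of the upstream file):
//
//	head, err := s3Client.HeadObject(ctx, "avatars/u1.png")
//	if err == nil {
//		log.Printf("size: %d, type: %s", head.ContentLength, head.ContentType)
//	}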
77
tools/filesystem/internal/s3blob/s3/head_object_test.go
Normal file
@@ -0,0 +1,77 @@
package s3_test

import (
	"context"
	"encoding/json"
	"net/http"
	"testing"

	"github.com/pocketbase/pocketbase/tools/filesystem/internal/s3blob/s3"
	"github.com/pocketbase/pocketbase/tools/filesystem/internal/s3blob/s3/tests"
)

func TestS3HeadObject(t *testing.T) {
	t.Parallel()

	httpClient := tests.NewClient(
		&tests.RequestStub{
			Method: http.MethodHead,
			URL:    "http://test_bucket.example.com/test_key",
			Match: func(req *http.Request) bool {
				return tests.ExpectHeaders(req.Header, map[string]string{
					"test_header":   "test",
					"Authorization": "^.+Credential=123/.+$",
				})
			},
			Response: &http.Response{
				Header: http.Header{
					"Last-Modified":       []string{"Mon, 01 Feb 2025 03:04:05 GMT"},
					"Cache-Control":       []string{"test_cache"},
					"Content-Disposition": []string{"test_disposition"},
					"Content-Encoding":    []string{"test_encoding"},
					"Content-Language":    []string{"test_language"},
					"Content-Type":        []string{"test_type"},
					"Content-Range":       []string{"test_range"},
					"Etag":                []string{"test_etag"},
					"Content-Length":      []string{"100"},
					"x-amz-meta-AbC":      []string{"test_meta_a"},
					"x-amz-meta-Def":      []string{"test_meta_b"},
				},
				Body: http.NoBody,
			},
		},
	)

	s3Client := &s3.S3{
		Client:    httpClient,
		Region:    "test_region",
		Bucket:    "test_bucket",
		Endpoint:  "http://example.com",
		AccessKey: "123",
		SecretKey: "abc",
	}

	resp, err := s3Client.HeadObject(context.Background(), "test_key", func(r *http.Request) {
		r.Header.Set("test_header", "test")
	})
	if err != nil {
		t.Fatal(err)
	}

	err = httpClient.AssertNoRemaining()
	if err != nil {
		t.Fatal(err)
	}

	raw, err := json.Marshal(resp)
	if err != nil {
		t.Fatal(err)
	}
	rawStr := string(raw)

	expected := `{"metadata":{"abc":"test_meta_a","def":"test_meta_b"},"lastModified":"2025-02-01T03:04:05Z","cacheControl":"test_cache","contentDisposition":"test_disposition","contentEncoding":"test_encoding","contentLanguage":"test_language","contentType":"test_type","contentRange":"test_range","etag":"test_etag","contentLength":100}`

	if rawStr != expected {
		t.Fatalf("Expected response\n%s\ngot\n%s", expected, rawStr)
	}
}
165
tools/filesystem/internal/s3blob/s3/list_objects.go
Normal file
@@ -0,0 +1,165 @@
package s3

import (
	"context"
	"encoding/xml"
	"net/http"
	"net/url"
	"strconv"
	"time"
)

// ListParams defines optional parameters for the ListObjects request.
type ListParams struct {
	// ContinuationToken indicates that the list is being continued on this bucket with a token.
	// ContinuationToken is obfuscated and is not a real key.
	// You can use this ContinuationToken for pagination of the list results.
	ContinuationToken string `json:"continuationToken"`

	// Delimiter is a character that you use to group keys.
	//
	// For directory buckets, "/" is the only supported delimiter.
	Delimiter string `json:"delimiter"`

	// Prefix limits the response to keys that begin with the specified prefix.
	Prefix string `json:"prefix"`

	// EncodingType is used to encode the object keys in the response.
	// Responses are encoded only in UTF-8.
	// An object key can contain any Unicode character.
	// However, the XML 1.0 parser can't parse certain characters,
	// such as characters with an ASCII value from 0 to 10.
	// For characters that aren't supported in XML 1.0, you can add
	// this parameter to request that S3 encode the keys in the response.
	//
	// Valid Values: url
	EncodingType string `json:"encodingType"`

	// StartAfter is where you want S3 to start listing from.
	// S3 starts listing after this specified key.
	// StartAfter can be any key in the bucket.
	//
	// This functionality is not supported for directory buckets.
	StartAfter string `json:"startAfter"`

	// MaxKeys sets the maximum number of keys returned in the response.
	// By default, the action returns up to 1,000 key names.
	// The response might contain fewer keys but will never contain more.
	MaxKeys int `json:"maxKeys"`

	// FetchOwner returns the owner field with each key in the result.
	FetchOwner bool `json:"fetchOwner"`
}

// Encode encodes the parameters in a properly formatted query string.
func (l *ListParams) Encode() string {
	query := url.Values{}

	query.Add("list-type", "2")

	if l.ContinuationToken != "" {
		query.Add("continuation-token", l.ContinuationToken)
	}

	if l.Delimiter != "" {
		query.Add("delimiter", l.Delimiter)
	}

	if l.Prefix != "" {
		query.Add("prefix", l.Prefix)
	}

	if l.EncodingType != "" {
		query.Add("encoding-type", l.EncodingType)
	}

	if l.FetchOwner {
		query.Add("fetch-owner", "true")
	}

	if l.MaxKeys > 0 {
		query.Add("max-keys", strconv.Itoa(l.MaxKeys))
	}

	if l.StartAfter != "" {
		query.Add("start-after", l.StartAfter)
	}

	return query.Encode()
}
|
||||
|
||||
// ListObjects retrieves paginated objects list.
|
||||
//
|
||||
// https://docs.aws.amazon.com/AmazonS3/latest/API/API_ListObjectsV2.html
|
||||
func (s3 *S3) ListObjects(ctx context.Context, params ListParams, optReqFuncs ...func(*http.Request)) (*ListObjectsResponse, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, s3.URL("?"+params.Encode()), nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// apply optional request funcs
|
||||
for _, fn := range optReqFuncs {
|
||||
if fn != nil {
|
||||
fn(req)
|
||||
}
|
||||
}
|
||||
|
||||
resp, err := s3.SignAndSend(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
result := &ListObjectsResponse{}
|
||||
|
||||
err = xml.NewDecoder(resp.Body).Decode(result)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// https://docs.aws.amazon.com/AmazonS3/latest/API/API_ListObjectsV2.html#API_ListObjectsV2_ResponseSyntax
|
||||
type ListObjectsResponse struct {
|
||||
XMLName xml.Name `json:"-" xml:"ListBucketResult"`
|
||||
EncodingType string `json:"encodingType" xml:"EncodingType"`
|
||||
Name string `json:"name" xml:"Name"`
|
||||
Prefix string `json:"prefix" xml:"Prefix"`
|
||||
Delimiter string `json:"delimiter" xml:"Delimiter"`
|
||||
ContinuationToken string `json:"continuationToken" xml:"ContinuationToken"`
|
||||
NextContinuationToken string `json:"nextContinuationToken" xml:"NextContinuationToken"`
|
||||
StartAfter string `json:"startAfter" xml:"StartAfter"`
|
||||
|
||||
CommonPrefixes []*ListObjectCommonPrefix `json:"commonPrefixes" xml:"CommonPrefixes"`
|
||||
|
||||
Contents []*ListObjectContent `json:"contents" xml:"Contents"`
|
||||
|
||||
KeyCount int `json:"keyCount" xml:"KeyCount"`
|
||||
MaxKeys int `json:"maxKeys" xml:"MaxKeys"`
|
||||
IsTruncated bool `json:"isTruncated" xml:"IsTruncated"`
|
||||
}
|
||||
|
||||
type ListObjectCommonPrefix struct {
|
||||
Prefix string `json:"prefix" xml:"Prefix"`
|
||||
}
|
||||
|
||||
type ListObjectContent struct {
|
||||
Owner struct {
|
||||
DisplayName string `json:"displayName" xml:"DisplayName"`
|
||||
ID string `json:"id" xml:"ID"`
|
||||
} `json:"owner" xml:"Owner"`
|
||||
|
||||
ChecksumAlgorithm string `json:"checksumAlgorithm" xml:"ChecksumAlgorithm"`
|
||||
ETag string `json:"etag" xml:"ETag"`
|
||||
Key string `json:"key" xml:"Key"`
|
||||
StorageClass string `json:"storageClass" xml:"StorageClass"`
|
||||
LastModified time.Time `json:"lastModified" xml:"LastModified"`
|
||||
|
||||
RestoreStatus struct {
|
||||
RestoreExpiryDate time.Time `json:"restoreExpiryDate" xml:"RestoreExpiryDate"`
|
||||
IsRestoreInProgress bool `json:"isRestoreInProgress" xml:"IsRestoreInProgress"`
|
||||
} `json:"restoreStatus" xml:"RestoreStatus"`
|
||||
|
||||
Size int64 `json:"size" xml:"Size"`
|
||||
}
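
Editor's note: the ContinuationToken/NextContinuationToken pair documented above is what drives pagination: a truncated response carries a NextContinuationToken that is fed back as ContinuationToken on the next request. Below is a minimal, hypothetical sketch of paging through a prefix with this client (the client configuration and prefix are placeholder values, and the package is internal to PocketBase, so this only illustrates the call pattern):

package main

import (
	"context"
	"fmt"
	"log"

	"github.com/pocketbase/pocketbase/tools/filesystem/internal/s3blob/s3"
)

func main() {
	// hypothetical client configuration, for illustration only
	client := &s3.S3{
		Endpoint:  "example.com",
		Region:    "us-east-1",
		Bucket:    "test",
		AccessKey: "...",
		SecretKey: "...",
	}

	params := s3.ListParams{Prefix: "a/", MaxKeys: 100}

	for {
		resp, err := client.ListObjects(context.Background(), params)
		if err != nil {
			log.Fatal(err)
		}

		for _, c := range resp.Contents {
			fmt.Println(c.Key, c.Size)
		}

		// IsTruncated signals that more pages follow;
		// NextContinuationToken resumes listing after the last returned key.
		if !resp.IsTruncated {
			break
		}
		params.ContinuationToken = resp.NextContinuationToken
	}
}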
157
tools/filesystem/internal/s3blob/s3/list_objects_test.go
Normal file
@@ -0,0 +1,157 @@
package s3_test

import (
	"context"
	"encoding/json"
	"io"
	"net/http"
	"strings"
	"testing"

	"github.com/pocketbase/pocketbase/tools/filesystem/internal/s3blob/s3"
	"github.com/pocketbase/pocketbase/tools/filesystem/internal/s3blob/s3/tests"
)

func TestS3ListParamsEncode(t *testing.T) {
	t.Parallel()

	scenarios := []struct {
		name     string
		params   s3.ListParams
		expected string
	}{
		{
			"blank",
			s3.ListParams{},
			"list-type=2",
		},
		{
			"filled",
			s3.ListParams{
				ContinuationToken: "test_ct",
				Delimiter:         "test_delimiter",
				Prefix:            "test_prefix",
				EncodingType:      "test_et",
				StartAfter:        "test_sa",
				MaxKeys:           1,
				FetchOwner:        true,
			},
			"continuation-token=test_ct&delimiter=test_delimiter&encoding-type=test_et&fetch-owner=true&list-type=2&max-keys=1&prefix=test_prefix&start-after=test_sa",
		},
	}

	for _, s := range scenarios {
		t.Run(s.name, func(t *testing.T) {
			result := s.params.Encode()
			if result != s.expected {
				t.Fatalf("Expected\n%s\ngot\n%s", s.expected, result)
			}
		})
	}
}

func TestS3ListObjects(t *testing.T) {
	t.Parallel()

	listParams := s3.ListParams{
		ContinuationToken: "test_ct",
		Delimiter:         "test_delimiter",
		Prefix:            "test_prefix",
		EncodingType:      "test_et",
		StartAfter:        "test_sa",
		MaxKeys:           10,
		FetchOwner:        true,
	}

	httpClient := tests.NewClient(
		&tests.RequestStub{
			Method: http.MethodGet,
			URL:    "http://test_bucket.example.com/?" + listParams.Encode(),
			Match: func(req *http.Request) bool {
				return tests.ExpectHeaders(req.Header, map[string]string{
					"test_header":   "test",
					"Authorization": "^.+Credential=123/.+$",
				})
			},
			Response: &http.Response{
				Body: io.NopCloser(strings.NewReader(`
					<?xml version="1.0" encoding="UTF-8"?>
					<ListBucketResult xmlns="http://s3.amazonaws.com/doc/2006-03-01/">
						<Name>example</Name>
						<EncodingType>test_encoding</EncodingType>
						<Prefix>a/</Prefix>
						<Delimiter>/</Delimiter>
						<ContinuationToken>ct</ContinuationToken>
						<NextContinuationToken>nct</NextContinuationToken>
						<StartAfter>example0.txt</StartAfter>
						<KeyCount>1</KeyCount>
						<MaxKeys>3</MaxKeys>
						<IsTruncated>true</IsTruncated>
						<Contents>
							<Key>example1.txt</Key>
							<LastModified>2025-01-01T01:02:03.123Z</LastModified>
							<ChecksumAlgorithm>test_ca</ChecksumAlgorithm>
							<ETag>test_etag1</ETag>
							<Size>123</Size>
							<StorageClass>STANDARD</StorageClass>
							<Owner>
								<DisplayName>owner_dn</DisplayName>
								<ID>owner_id</ID>
							</Owner>
							<RestoreStatus>
								<RestoreExpiryDate>2025-01-02T01:02:03.123Z</RestoreExpiryDate>
								<IsRestoreInProgress>true</IsRestoreInProgress>
							</RestoreStatus>
						</Contents>
						<Contents>
							<Key>example2.txt</Key>
							<LastModified>2025-01-02T01:02:03.123Z</LastModified>
							<ETag>test_etag2</ETag>
							<Size>456</Size>
							<StorageClass>STANDARD</StorageClass>
						</Contents>
						<CommonPrefixes>
							<Prefix>a/b/</Prefix>
						</CommonPrefixes>
						<CommonPrefixes>
							<Prefix>a/c/</Prefix>
						</CommonPrefixes>
					</ListBucketResult>
				`)),
			},
		},
	)

	s3Client := &s3.S3{
		Client:    httpClient,
		Region:    "test_region",
		Bucket:    "test_bucket",
		Endpoint:  "http://example.com",
		AccessKey: "123",
		SecretKey: "abc",
	}

	resp, err := s3Client.ListObjects(context.Background(), listParams, func(r *http.Request) {
		r.Header.Set("test_header", "test")
	})
	if err != nil {
		t.Fatal(err)
	}

	err = httpClient.AssertNoRemaining()
	if err != nil {
		t.Fatal(err)
	}

	raw, err := json.Marshal(resp)
	if err != nil {
		t.Fatal(err)
	}
	rawStr := string(raw)

	expected := `{"encodingType":"test_encoding","name":"example","prefix":"a/","delimiter":"/","continuationToken":"ct","nextContinuationToken":"nct","startAfter":"example0.txt","commonPrefixes":[{"prefix":"a/b/"},{"prefix":"a/c/"}],"contents":[{"owner":{"displayName":"owner_dn","id":"owner_id"},"checksumAlgorithm":"test_ca","etag":"test_etag1","key":"example1.txt","storageClass":"STANDARD","lastModified":"2025-01-01T01:02:03.123Z","restoreStatus":{"restoreExpiryDate":"2025-01-02T01:02:03.123Z","isRestoreInProgress":true},"size":123},{"owner":{"displayName":"","id":""},"checksumAlgorithm":"","etag":"test_etag2","key":"example2.txt","storageClass":"STANDARD","lastModified":"2025-01-02T01:02:03.123Z","restoreStatus":{"restoreExpiryDate":"0001-01-01T00:00:00Z","isRestoreInProgress":false},"size":456}],"keyCount":1,"maxKeys":3,"isTruncated":true}`

	if rawStr != expected {
		t.Fatalf("Expected response\n%s\ngot\n%s", expected, rawStr)
	}
}
370
tools/filesystem/internal/s3blob/s3/s3.go
Normal file
@@ -0,0 +1,370 @@
// Package s3 implements a lightweight client for interacting with the
// REST APIs of any S3 compatible service.
//
// It implements only the minimal functionality required by PocketBase
// such as objects list, get, copy, delete and upload.
//
// For more details on why we don't use the official aws-sdk-go-v2, see
// https://github.com/pocketbase/pocketbase/discussions/6562.
//
// Example:
//
//	client := &s3.S3{
//		Endpoint:     "example.com",
//		Region:       "us-east-1",
//		Bucket:       "test",
//		AccessKey:    "...",
//		SecretKey:    "...",
//		UsePathStyle: true,
//	}
//	resp, err := client.GetObject(context.Background(), "abc.txt")
package s3

import (
	"crypto/hmac"
	"crypto/sha256"
	"encoding/hex"
	"encoding/xml"
	"errors"
	"fmt"
	"io"
	"net/http"
	"net/url"
	"slices"
	"strings"
	"time"
)

const (
	awsS3ServiceCode     = "s3"
	awsSignAlgorithm     = "AWS4-HMAC-SHA256"
	awsTerminationString = "aws4_request"
	metadataPrefix       = "x-amz-meta-"
	dateTimeFormat       = "20060102T150405Z"
)

type HTTPClient interface {
	Do(req *http.Request) (*http.Response, error)
}

type S3 struct {
	// Client specifies a custom HTTP client to send the request with.
	//
	// If not explicitly set, it falls back to http.DefaultClient.
	Client HTTPClient

	Bucket       string
	Region       string
	Endpoint     string // can be with or without the schema
	AccessKey    string
	SecretKey    string
	UsePathStyle bool
}

// URL constructs an S3 request URL based on the current configuration.
func (s3 *S3) URL(path string) string {
	scheme := "https"
	endpoint := strings.TrimRight(s3.Endpoint, "/")
	if after, ok := strings.CutPrefix(endpoint, "https://"); ok {
		endpoint = after
	} else if after, ok := strings.CutPrefix(endpoint, "http://"); ok {
		endpoint = after
		scheme = "http"
	}

	path = strings.TrimLeft(path, "/")

	if s3.UsePathStyle {
		return fmt.Sprintf("%s://%s/%s/%s", scheme, endpoint, s3.Bucket, path)
	}

	return fmt.Sprintf("%s://%s.%s/%s", scheme, s3.Bucket, endpoint, path)
}

// SignAndSend signs the provided request per AWS Signature v4 and sends it.
//
// It automatically normalizes all 4xx/5xx responses to ResponseError.
//
// Note: Don't forget to call resp.Body.Close() after done with the result.
func (s3 *S3) SignAndSend(req *http.Request) (*http.Response, error) {
	s3.sign(req)

	client := s3.Client
	if client == nil {
		client = http.DefaultClient
	}

	resp, err := client.Do(req)
	if err != nil {
		return nil, err
	}

	if resp.StatusCode >= 400 {
		defer resp.Body.Close()

		respErr := &ResponseError{
			Status: resp.StatusCode,
		}

		respErr.Raw, err = io.ReadAll(resp.Body)
		if err != nil && !errors.Is(err, io.EOF) {
			return nil, errors.Join(err, respErr)
		}

		if len(respErr.Raw) > 0 {
			err = xml.Unmarshal(respErr.Raw, respErr)
			if err != nil {
				return nil, errors.Join(err, respErr)
			}
		}

		return nil, respErr
	}

	return resp, nil
}

// https://docs.aws.amazon.com/IAM/latest/UserGuide/reference_sigv-create-signed-request.html#create-signed-request-steps
func (s3 *S3) sign(req *http.Request) {
	// fallback to the Unsigned payload option
	// (data integrity checks could still be applied via the content-md5 or x-amz-checksum-* headers)
	if req.Header.Get("x-amz-content-sha256") == "" {
		req.Header.Set("x-amz-content-sha256", "UNSIGNED-PAYLOAD")
	}

	reqDateTime, _ := time.Parse(dateTimeFormat, req.Header.Get("x-amz-date"))
	if reqDateTime.IsZero() {
		reqDateTime = time.Now().UTC()
		req.Header.Set("x-amz-date", reqDateTime.Format(dateTimeFormat))
	}

	req.Header.Set("host", req.URL.Host)

	date := reqDateTime.Format("20060102")

	dateTime := reqDateTime.Format(dateTimeFormat)

	// 1. Create canonical request
	// https://docs.aws.amazon.com/IAM/latest/UserGuide/reference_sigv-create-signed-request.html#create-canonical-request
	// ---------------------------------------------------------------
	canonicalHeaders, signedHeaders := canonicalAndSignedHeaders(req)

	canonicalParts := []string{
		req.Method,
		escapePath(req.URL.Path),
		escapeQuery(req.URL.Query()),
		canonicalHeaders,
		signedHeaders,
		req.Header.Get("x-amz-content-sha256"),
	}

	// 2. Create a hash of the canonical request
	// https://docs.aws.amazon.com/IAM/latest/UserGuide/reference_sigv-create-signed-request.html#create-canonical-request-hash
	// ---------------------------------------------------------------
	hashedCanonicalRequest := sha256Hex([]byte(strings.Join(canonicalParts, "\n")))

	// 3. Create a string to sign
	// https://docs.aws.amazon.com/IAM/latest/UserGuide/reference_sigv-create-signed-request.html#create-string-to-sign
	// ---------------------------------------------------------------
	scope := strings.Join([]string{
		date,
		s3.Region,
		awsS3ServiceCode,
		awsTerminationString,
	}, "/")

	stringToSign := strings.Join([]string{
		awsSignAlgorithm,
		dateTime,
		scope,
		hashedCanonicalRequest,
	}, "\n")

	// 4. Derive a signing key for SigV4
	// https://docs.aws.amazon.com/IAM/latest/UserGuide/reference_sigv-create-signed-request.html#derive-signing-key
	// ---------------------------------------------------------------
	dateKey := hmacSHA256([]byte("AWS4"+s3.SecretKey), date)
	dateRegionKey := hmacSHA256(dateKey, s3.Region)
	dateRegionServiceKey := hmacSHA256(dateRegionKey, awsS3ServiceCode)
	signingKey := hmacSHA256(dateRegionServiceKey, awsTerminationString)
	signature := hex.EncodeToString(hmacSHA256(signingKey, stringToSign))

	// 5. Add the signature to the request
	// https://docs.aws.amazon.com/IAM/latest/UserGuide/reference_sigv-create-signed-request.html#add-signature-to-request
	authorization := fmt.Sprintf(
		"%s Credential=%s/%s, SignedHeaders=%s, Signature=%s",
		awsSignAlgorithm,
		s3.AccessKey,
		scope,
		signedHeaders,
		signature,
	)

	req.Header.Set("authorization", authorization)
}

func sha256Hex(content []byte) string {
	h := sha256.New()
	h.Write(content)
	return hex.EncodeToString(h.Sum(nil))
}

func hmacSHA256(key []byte, content string) []byte {
	mac := hmac.New(sha256.New, key)
	mac.Write([]byte(content))
	return mac.Sum(nil)
}

func canonicalAndSignedHeaders(req *http.Request) (string, string) {
	signed := []string{}
	canonical := map[string]string{}

	for key, values := range req.Header {
		normalizedKey := strings.ToLower(key)

		if normalizedKey != "host" &&
			normalizedKey != "content-type" &&
			!strings.HasPrefix(normalizedKey, "x-amz-") {
			continue
		}

		signed = append(signed, normalizedKey)

		// for each value:
		// trim any leading or trailing spaces
		// convert sequential spaces to a single space
		normalizedValues := make([]string, len(values))
		for i, v := range values {
			normalizedValues[i] = strings.ReplaceAll(strings.TrimSpace(v), "  ", " ")
		}

		canonical[normalizedKey] = strings.Join(normalizedValues, ",")
	}

	slices.Sort(signed)

	var sortedCanonical strings.Builder
	for _, key := range signed {
		sortedCanonical.WriteString(key)
		sortedCanonical.WriteString(":")
		sortedCanonical.WriteString(canonical[key])
		sortedCanonical.WriteString("\n")
	}

	return sortedCanonical.String(), strings.Join(signed, ";")
}

// extractMetadata parses and extracts the metadata from the specified request headers.
//
// The metadata keys are all lowercased and without the "x-amz-meta-" prefix.
func extractMetadata(headers http.Header) map[string]string {
	result := map[string]string{}

	for k, v := range headers {
		if len(v) == 0 {
			continue
		}

		metadataKey, ok := strings.CutPrefix(strings.ToLower(k), metadataPrefix)
		if !ok {
			continue
		}

		result[metadataKey] = v[0]
	}

	return result
}

// escapeQuery returns the URI encoded request query parameters according to the AWS S3 spec requirements
// (it is similar to url.Values.Encode but instead of url.QueryEscape uses our own escape method).
func escapeQuery(values url.Values) string {
	if len(values) == 0 {
		return ""
	}

	var buf strings.Builder

	keys := make([]string, 0, len(values))
	for k := range values {
		keys = append(keys, k)
	}
	slices.Sort(keys)

	for _, k := range keys {
		vs := values[k]
		keyEscaped := escape(k)
		for _, v := range vs {
			if buf.Len() > 0 {
				buf.WriteByte('&')
			}
			buf.WriteString(keyEscaped)
			buf.WriteByte('=')
			buf.WriteString(escape(v))
		}
	}

	return buf.String()
}

// escapePath returns the URI encoded request path according to the AWS S3 spec requirements.
func escapePath(path string) string {
	parts := strings.Split(path, "/")

	for i, part := range parts {
		parts[i] = escape(part)
	}

	return strings.Join(parts, "/")
}

const upperhex = "0123456789ABCDEF"

// escape is similar to the std url.escape but implements the AWS [UriEncode requirements]:
//   - URI encode every byte except the unreserved characters: 'A'-'Z', 'a'-'z', '0'-'9', '-', '.', '_', and '~'.
//   - The space character is a reserved character and must be encoded as "%20" (and not as "+").
//   - Each URI encoded byte is formed by a '%' and the two-digit hexadecimal value of the byte.
//   - Letters in the hexadecimal value must be uppercase, for example "%1A".
//
// [UriEncode requirements]: https://docs.aws.amazon.com/IAM/latest/UserGuide/reference_sigv-create-signed-request.html
func escape(s string) string {
	hexCount := 0
	for i := 0; i < len(s); i++ {
		c := s[i]
		if shouldEscape(c) {
			hexCount++
		}
	}

	if hexCount == 0 {
		return s
	}

	result := make([]byte, len(s)+2*hexCount)

	j := 0
	for i := 0; i < len(s); i++ {
		c := s[i]
		if shouldEscape(c) {
			result[j] = '%'
			result[j+1] = upperhex[c>>4]
			result[j+2] = upperhex[c&15]
			j += 3
		} else {
			result[j] = c
			j++
		}
	}

	return string(result)
}

// > "URI encode every byte except the unreserved characters: 'A'-'Z', 'a'-'z', '0'-'9', '-', '.', '_', and '~'."
func shouldEscape(c byte) bool {
	isUnreserved := (c >= 'A' && c <= 'Z') ||
		(c >= 'a' && c <= 'z') ||
		(c >= '0' && c <= '9') ||
		c == '-' || c == '.' || c == '_' || c == '~'

	return !isUnreserved
}
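
Editor's note: since SignAndSend accepts any *http.Request, the client can also be used for S3 calls that have no dedicated helper. A minimal sketch composing URL with SignAndSend is shown below; the fetchRaw helper and the key value are hypothetical, and GetObject (referenced in the package example) is what you would normally use for this:

package main

import (
	"fmt"
	"io"
	"net/http"

	"github.com/pocketbase/pocketbase/tools/filesystem/internal/s3blob/s3"
)

// fetchRaw issues a signed GET for key and returns the response body bytes.
// It exists purely to illustrate the URL + SignAndSend composition.
func fetchRaw(client *s3.S3, key string) ([]byte, error) {
	req, err := http.NewRequest(http.MethodGet, client.URL(key), nil)
	if err != nil {
		return nil, err
	}

	// SignAndSend applies the SigV4 Authorization header and
	// normalizes 4xx/5xx responses to ResponseError.
	resp, err := client.SignAndSend(req)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close() // per the SignAndSend doc note

	return io.ReadAll(resp.Body)
}

func main() {
	// hypothetical client configuration, for illustration only
	client := &s3.S3{Endpoint: "example.com", Region: "us-east-1", Bucket: "test", AccessKey: "...", SecretKey: "..."}

	data, err := fetchRaw(client, "abc.txt")
	fmt.Println(len(data), err)
}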
35
tools/filesystem/internal/s3blob/s3/s3_escape_test.go
Normal file
@@ -0,0 +1,35 @@
package s3

import (
	"net/url"
	"testing"
)

func TestEscapePath(t *testing.T) {
	t.Parallel()

	escaped := escapePath("/ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_.~ !@#$%^&*()+={}[]?><\\|,`'\"/@sub1/@sub2/a/b/c/1/2/3")

	expected := "/ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_.~%20%21%40%23%24%25%5E%26%2A%28%29%2B%3D%7B%7D%5B%5D%3F%3E%3C%5C%7C%2C%60%27%22/%40sub1/%40sub2/a/b/c/1/2/3"

	if escaped != expected {
		t.Fatalf("Expected\n%s\ngot\n%s", expected, escaped)
	}
}

func TestEscapeQuery(t *testing.T) {
	t.Parallel()

	escaped := escapeQuery(url.Values{
		"abc": []string{"123"},
		"/ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_.~ !@#$%^&*()+={}[]?><\\|,`'\"": []string{
			"/ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_.~ !@#$%^&*()+={}[]?><\\|,`'\"",
		},
	})

	expected := "%2FABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_.~%20%21%40%23%24%25%5E%26%2A%28%29%2B%3D%7B%7D%5B%5D%3F%3E%3C%5C%7C%2C%60%27%22=%2FABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_.~%20%21%40%23%24%25%5E%26%2A%28%29%2B%3D%7B%7D%5B%5D%3F%3E%3C%5C%7C%2C%60%27%22&abc=123"

	if escaped != expected {
		t.Fatalf("Expected\n%s\ngot\n%s", expected, escaped)
	}
}
256
tools/filesystem/internal/s3blob/s3/s3_test.go
Normal file
@@ -0,0 +1,256 @@
package s3_test

import (
	"io"
	"net/http"
	"strings"
	"testing"

	"github.com/pocketbase/pocketbase/tools/filesystem/internal/s3blob/s3"
	"github.com/pocketbase/pocketbase/tools/filesystem/internal/s3blob/s3/tests"
)

func TestS3URL(t *testing.T) {
	t.Parallel()

	scenarios := []struct {
		name     string
		s3Client *s3.S3
		expected string
	}{
		{
			"no schema",
			&s3.S3{
				Region:    "test_region",
				Bucket:    "test_bucket",
				Endpoint:  "example.com/",
				AccessKey: "123",
				SecretKey: "abc",
			},
			"https://test_bucket.example.com/test_key/a/b/c?q=1",
		},
		{
			"with https schema",
			&s3.S3{
				Region:    "test_region",
				Bucket:    "test_bucket",
				Endpoint:  "https://example.com/",
				AccessKey: "123",
				SecretKey: "abc",
			},
			"https://test_bucket.example.com/test_key/a/b/c?q=1",
		},
		{
			"with http schema",
			&s3.S3{
				Region:    "test_region",
				Bucket:    "test_bucket",
				Endpoint:  "http://example.com/",
				AccessKey: "123",
				SecretKey: "abc",
			},
			"http://test_bucket.example.com/test_key/a/b/c?q=1",
		},
		{
			"path style addressing (non-explicit schema)",
			&s3.S3{
				Region:       "test_region",
				Bucket:       "test_bucket",
				Endpoint:     "example.com/",
				AccessKey:    "123",
				SecretKey:    "abc",
				UsePathStyle: true,
			},
			"https://example.com/test_bucket/test_key/a/b/c?q=1",
		},
		{
			"path style addressing (explicit schema)",
			&s3.S3{
				Region:       "test_region",
				Bucket:       "test_bucket",
				Endpoint:     "http://example.com/",
				AccessKey:    "123",
				SecretKey:    "abc",
				UsePathStyle: true,
			},
			"http://example.com/test_bucket/test_key/a/b/c?q=1",
		},
	}

	for _, s := range scenarios {
		t.Run(s.name, func(t *testing.T) {
			result := s.s3Client.URL("/test_key/a/b/c?q=1")
			if result != s.expected {
				t.Fatalf("Expected URL\n%s\ngot\n%s", s.expected, result)
			}
		})
	}
}

func TestS3SignAndSend(t *testing.T) {
	t.Parallel()

	testResponse := func() *http.Response {
		return &http.Response{
			Body: io.NopCloser(strings.NewReader("test_response")),
		}
	}

	scenarios := []struct {
		name     string
		path     string
		reqFunc  func(req *http.Request)
		s3Client *s3.S3
	}{
		{
			"minimal",
			"/test",
			func(req *http.Request) {
				req.Header.Set("x-amz-date", "20250102T150405Z")
			},
			&s3.S3{
				Region:    "test_region",
				Bucket:    "test_bucket",
				Endpoint:  "https://example.com/",
				AccessKey: "123",
				SecretKey: "abc",
				Client: tests.NewClient(&tests.RequestStub{
					Method:   http.MethodGet,
					URL:      "https://test_bucket.example.com/test",
					Response: testResponse(),
					Match: func(req *http.Request) bool {
						return tests.ExpectHeaders(req.Header, map[string]string{
							"Authorization":        "AWS4-HMAC-SHA256 Credential=123/20250102/test_region/s3/aws4_request, SignedHeaders=host;x-amz-content-sha256;x-amz-date, Signature=ea093662bc1deef08dfb4ac35453dfaad5ea89edf102e9dd3b7156c9a27e4c1f",
							"Host":                 "test_bucket.example.com",
							"X-Amz-Content-Sha256": "UNSIGNED-PAYLOAD",
							"X-Amz-Date":           "20250102T150405Z",
						})
					},
				}),
			},
		},
		{
			"minimal with different access and secret keys",
			"/test",
			func(req *http.Request) {
				req.Header.Set("x-amz-date", "20250102T150405Z")
			},
			&s3.S3{
				Region:    "test_region",
				Bucket:    "test_bucket",
				Endpoint:  "https://example.com/",
				AccessKey: "456",
				SecretKey: "def",
				Client: tests.NewClient(&tests.RequestStub{
					Method:   http.MethodGet,
					URL:      "https://test_bucket.example.com/test",
					Response: testResponse(),
					Match: func(req *http.Request) bool {
						return tests.ExpectHeaders(req.Header, map[string]string{
							"Authorization":        "AWS4-HMAC-SHA256 Credential=456/20250102/test_region/s3/aws4_request, SignedHeaders=host;x-amz-content-sha256;x-amz-date, Signature=17510fa1f724403dd0a563b61c9b31d1d718f877fcbd75455620d17a8afce5fb",
							"Host":                 "test_bucket.example.com",
							"X-Amz-Content-Sha256": "UNSIGNED-PAYLOAD",
							"X-Amz-Date":           "20250102T150405Z",
						})
					},
				}),
			},
		},
		{
			"minimal with special characters",
			"/ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_.~!@#$^&*()=/@sub?a=1&@b=@2",
			func(req *http.Request) {
				req.Header.Set("x-amz-date", "20250102T150405Z")
			},
			&s3.S3{
				Region:    "test_region",
				Bucket:    "test_bucket",
				Endpoint:  "https://example.com/",
				AccessKey: "456",
				SecretKey: "def",
				Client: tests.NewClient(&tests.RequestStub{
					Method:   http.MethodGet,
					URL:      "https://test_bucket.example.com/ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_.~!@#$%5E&*()=/@sub?a=1&@b=@2",
					Response: testResponse(),
					Match: func(req *http.Request) bool {
						return tests.ExpectHeaders(req.Header, map[string]string{
							"Authorization":        "AWS4-HMAC-SHA256 Credential=456/20250102/test_region/s3/aws4_request, SignedHeaders=host;x-amz-content-sha256;x-amz-date, Signature=e0001982deef1652704f74503203e77d83d4c88369421f9fca644d96f2a62a3c",
							"Host":                 "test_bucket.example.com",
							"X-Amz-Content-Sha256": "UNSIGNED-PAYLOAD",
							"X-Amz-Date":           "20250102T150405Z",
						})
					},
				}),
			},
		},
		{
			"with extra headers",
			"/test",
			func(req *http.Request) {
				req.Header.Set("x-amz-date", "20250102T150405Z")
				req.Header.Set("x-amz-content-sha256", "test_sha256")
				req.Header.Set("x-amz-example", "123")
				req.Header.Set("x-amz-meta-a", "456")
				req.Header.Set("content-type", "image/png")
				req.Header.Set("x-test", "789") // shouldn't be included in the signing headers
			},
			&s3.S3{
				Region:    "test_region",
				Bucket:    "test_bucket",
				Endpoint:  "https://example.com/",
				AccessKey: "123",
				SecretKey: "abc",
				Client: tests.NewClient(&tests.RequestStub{
					Method:   http.MethodGet,
					URL:      "https://test_bucket.example.com/test",
					Response: testResponse(),
					Match: func(req *http.Request) bool {
						return tests.ExpectHeaders(req.Header, map[string]string{
							"authorization":        "AWS4-HMAC-SHA256 Credential=123/20250102/test_region/s3/aws4_request, SignedHeaders=content-type;host;x-amz-content-sha256;x-amz-date;x-amz-example;x-amz-meta-a, Signature=86dccbcd012c33073dc99e9d0a9e0b717a4d8c11c37848cfa9a4a02716bc0db3",
							"host":                 "test_bucket.example.com",
							"x-amz-date":           "20250102T150405Z",
							"x-amz-content-sha256": "test_sha256",
							"x-amz-example":        "123",
							"x-amz-meta-a":         "456",
							"x-test":               "789",
						})
					},
				}),
			},
		},
	}

	for _, s := range scenarios {
		t.Run(s.name, func(t *testing.T) {
			req, err := http.NewRequest(http.MethodGet, s.s3Client.URL(s.path), strings.NewReader("test_request"))
			if err != nil {
				t.Fatal(err)
			}

			if s.reqFunc != nil {
				s.reqFunc(req)
			}

			resp, err := s.s3Client.SignAndSend(req)
			if err != nil {
				t.Fatal(err)
			}
			defer resp.Body.Close()

			err = s.s3Client.Client.(*tests.Client).AssertNoRemaining()
			if err != nil {
				t.Fatal(err)
			}

			expectedBody := "test_response"

			body, err := io.ReadAll(resp.Body)
			if err != nil {
				t.Fatal(err)
			}
			if str := string(body); str != expectedBody {
				t.Fatalf("Expected body %q, got %q", expectedBody, str)
			}
		})
	}
}
111
tools/filesystem/internal/s3blob/s3/tests/client.go
Normal file
@@ -0,0 +1,111 @@
// Package tests contains various test helpers and utilities to assist
// with the S3 client testing.
package tests

import (
	"errors"
	"fmt"
	"io"
	"net/http"
	"regexp"
	"slices"
	"strings"
	"sync"
)

// NewClient creates a new test Client loaded with the specified RequestStubs.
func NewClient(stubs ...*RequestStub) *Client {
	return &Client{stubs: stubs}
}

type RequestStub struct {
	Method   string
	URL      string // plain string or regex pattern wrapped in "^pattern$"
	Match    func(req *http.Request) bool
	Response *http.Response
}

type Client struct {
	stubs []*RequestStub
	mu    sync.Mutex
}

// AssertNoRemaining asserts that the current client has no unprocessed request stubs remaining.
func (c *Client) AssertNoRemaining() error {
	c.mu.Lock()
	defer c.mu.Unlock()

	if len(c.stubs) == 0 {
		return nil
	}

	msgParts := make([]string, 0, len(c.stubs)+1)
	msgParts = append(msgParts, "not all stub requests were processed:")
	for _, stub := range c.stubs {
		msgParts = append(msgParts, "- "+stub.Method+" "+stub.URL)
	}

	return errors.New(strings.Join(msgParts, "\n"))
}

// Do implements the [s3.HTTPClient] interface.
func (c *Client) Do(req *http.Request) (*http.Response, error) {
	c.mu.Lock()
	defer c.mu.Unlock()

	for i, stub := range c.stubs {
		if req.Method != stub.Method {
			continue
		}

		urlPattern := stub.URL
		if !strings.HasPrefix(urlPattern, "^") && !strings.HasSuffix(urlPattern, "$") {
			urlPattern = "^" + regexp.QuoteMeta(urlPattern) + "$"
		}

		urlRegex, err := regexp.Compile(urlPattern)
		if err != nil {
			return nil, err
		}

		if !urlRegex.MatchString(req.URL.String()) {
			continue
		}

		if stub.Match != nil && !stub.Match(req) {
			continue
		}

		// remove from the remaining stubs
		c.stubs = slices.Delete(c.stubs, i, i+1)

		response := stub.Response
		if response == nil {
			response = &http.Response{}
		}
		if response.Header == nil {
			response.Header = http.Header{}
		}
		if response.Body == nil {
			response.Body = http.NoBody
		}

		response.Request = req

		return response, nil
	}

	var body []byte
	if req.Body != nil {
		defer req.Body.Close()
		body, _ = io.ReadAll(req.Body)
	}

	return nil, fmt.Errorf(
		"the below request doesn't have a corresponding stub:\n%s %s\nHeaders: %v\nBody: %q",
		req.Method,
		req.URL.String(),
		req.Header,
		body,
	)
}
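
Editor's note: a stub URL is matched verbatim unless it is wrapped in ^...$, in which case it is compiled as a regular expression, and each stub is consumed at most once. A minimal sketch of wiring this stub client into a test follows; the stubbed URL and response body are arbitrary illustration values:

package tests_test

import (
	"io"
	"net/http"
	"strings"
	"testing"

	"github.com/pocketbase/pocketbase/tools/filesystem/internal/s3blob/s3/tests"
)

func TestStubClientSketch(t *testing.T) {
	client := tests.NewClient(
		// the ^...$ form matches any GET under the bucket host as a regex
		&tests.RequestStub{
			Method:   http.MethodGet,
			URL:      `^http://test_bucket\.example\.com/.+$`,
			Response: &http.Response{Body: io.NopCloser(strings.NewReader("ok"))},
		},
	)

	req, err := http.NewRequest(http.MethodGet, "http://test_bucket.example.com/abc.txt", nil)
	if err != nil {
		t.Fatal(err)
	}

	resp, err := client.Do(req)
	if err != nil {
		t.Fatal(err)
	}
	defer resp.Body.Close()

	// fails if any registered stub was never consumed
	if err := client.AssertNoRemaining(); err != nil {
		t.Fatal(err)
	}
}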
33
tools/filesystem/internal/s3blob/s3/tests/headers.go
Normal file
@@ -0,0 +1,33 @@
package tests

import (
	"net/http"
	"regexp"
	"strings"
)

// ExpectHeaders checks whether the specified headers match the expectations.
// The expectations map entry key is the header name.
// The expectations map entry value is the first header value. If wrapped with `^...$`
// it is compared as a regular expression.
func ExpectHeaders(headers http.Header, expectations map[string]string) bool {
	for h, expected := range expectations {
		v := headers.Get(h)

		pattern := expected
		if !strings.HasPrefix(pattern, "^") && !strings.HasSuffix(pattern, "$") {
			pattern = "^" + regexp.QuoteMeta(pattern) + "$"
		}

		expectedRegex, err := regexp.Compile(pattern)
		if err != nil {
			return false
		}

		if !expectedRegex.MatchString(v) {
			return false
		}
	}

	return true
}
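
Editor's note: the same `^...$` convention used for stub URLs applies to these header expectations. A small hedged illustration (the header values are arbitrary):

package tests_test

import (
	"net/http"
	"testing"

	"github.com/pocketbase/pocketbase/tools/filesystem/internal/s3blob/s3/tests"
)

func TestExpectHeadersSketch(t *testing.T) {
	h := http.Header{}
	h.Set("Authorization", "AWS4-HMAC-SHA256 Credential=123/20250102/test_region/s3/aws4_request")
	h.Set("X-Custom", "abc")

	ok := tests.ExpectHeaders(h, map[string]string{
		"X-Custom":      "abc",                   // plain value -> exact match
		"Authorization": "^.+Credential=123/.+$", // ^...$ -> regex match
	})
	if !ok {
		t.Fatal("expected the headers to match")
	}
}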
414
tools/filesystem/internal/s3blob/s3/uploader.go
Normal file
@@ -0,0 +1,414 @@
package s3

import (
	"bytes"
	"context"
	"encoding/xml"
	"errors"
	"fmt"
	"io"
	"net/http"
	"net/url"
	"slices"
	"strconv"
	"strings"
	"sync"

	"golang.org/x/sync/errgroup"
)

var ErrUsedUploader = errors.New("the Uploader has already been used")

const (
	defaultMaxConcurrency int = 5
	defaultMinPartSize    int = 6 << 20
)

// Uploader handles the upload of a single S3 object.
//
// If the Payload size is less than the configured MinPartSize it sends
// a single (PutObject) request, otherwise it performs a chunked/multipart upload.
type Uploader struct {
	// S3 is the S3 client instance performing the upload object request (required).
	S3 *S3

	// Payload is the object content to upload (required).
	Payload io.Reader

	// Key is the destination key of the uploaded object (required).
	Key string

	// Metadata specifies the optional metadata to write with the object upload.
	Metadata map[string]string

	// MaxConcurrency specifies the max number of workers to use when
	// performing chunked/multipart upload.
	//
	// If zero or negative, defaults to 5.
	//
	// This option is used only when the Payload size is > MinPartSize.
	MaxConcurrency int

	// MinPartSize specifies the min Payload size required to perform
	// chunked/multipart upload.
	//
	// If zero or negative, defaults to ~6MB.
	MinPartSize int

	uploadId       string
	uploadedParts  []*mpPart
	lastPartNumber int
	mu             sync.Mutex // guards lastPartNumber and the uploadedParts slice
	used           bool
}

// Upload processes the current Uploader instance.
//
// Users can specify optional optReqFuncs that will be passed down to all Upload internal requests
// (single upload, multipart init, multipart parts upload, multipart complete, multipart abort).
//
// Note that after this call the Uploader should be discarded (i.e. it can no longer be used).
func (u *Uploader) Upload(ctx context.Context, optReqFuncs ...func(*http.Request)) error {
	if u.used {
		return ErrUsedUploader
	}

	err := u.validateAndNormalize()
	if err != nil {
		return err
	}

	initPart, _, err := u.nextPart()
	if err != nil && !errors.Is(err, io.EOF) {
		return err
	}

	if len(initPart) < u.MinPartSize {
		return u.singleUpload(ctx, initPart, optReqFuncs...)
	}

	err = u.multipartInit(ctx, optReqFuncs...)
	if err != nil {
		return fmt.Errorf("multipart init error: %w", err)
	}

	err = u.multipartUpload(ctx, initPart, optReqFuncs...)
	if err != nil {
		return errors.Join(
			u.multipartAbort(ctx, optReqFuncs...),
			fmt.Errorf("multipart upload error: %w", err),
		)
	}

	err = u.multipartComplete(ctx, optReqFuncs...)
	if err != nil {
		return errors.Join(
			u.multipartAbort(ctx, optReqFuncs...),
			fmt.Errorf("multipart complete error: %w", err),
		)
	}

	return nil
}

// -------------------------------------------------------------------

func (u *Uploader) validateAndNormalize() error {
	if u.S3 == nil {
		return errors.New("Uploader.S3 must be a non-empty and properly initialized S3 client instance")
	}

	if u.Key == "" {
		return errors.New("Uploader.Key is required")
	}

	if u.Payload == nil {
		return errors.New("Uploader.Payload must be non-nil")
	}

	if u.MaxConcurrency <= 0 {
		u.MaxConcurrency = defaultMaxConcurrency
	}

	if u.MinPartSize <= 0 {
		u.MinPartSize = defaultMinPartSize
	}

	return nil
}

func (u *Uploader) singleUpload(ctx context.Context, part []byte, optReqFuncs ...func(*http.Request)) error {
	if u.used {
		return ErrUsedUploader
	}

	req, err := http.NewRequestWithContext(ctx, http.MethodPut, u.S3.URL(u.Key), bytes.NewReader(part))
	if err != nil {
		return err
	}

	req.Header.Set("Content-Length", strconv.Itoa(len(part)))

	for k, v := range u.Metadata {
		req.Header.Set(metadataPrefix+k, v)
	}

	// apply optional request funcs
	for _, fn := range optReqFuncs {
		if fn != nil {
			fn(req)
		}
	}

	resp, err := u.S3.SignAndSend(req)
	if err != nil {
		return err
	}
	defer resp.Body.Close()

	return nil
}

// -------------------------------------------------------------------

type mpPart struct {
	XMLName    xml.Name `xml:"Part"`
	ETag       string   `xml:"ETag"`
	PartNumber int      `xml:"PartNumber"`
}

// https://docs.aws.amazon.com/AmazonS3/latest/API/API_CreateMultipartUpload.html
func (u *Uploader) multipartInit(ctx context.Context, optReqFuncs ...func(*http.Request)) error {
	if u.used {
		return ErrUsedUploader
	}

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, u.S3.URL(u.Key+"?uploads"), nil)
	if err != nil {
		return err
	}

	for k, v := range u.Metadata {
		req.Header.Set(metadataPrefix+k, v)
	}

	// apply optional request funcs
	for _, fn := range optReqFuncs {
		if fn != nil {
			fn(req)
		}
	}

	resp, err := u.S3.SignAndSend(req)
	if err != nil {
		return err
	}
	defer resp.Body.Close()

	body := &struct {
		XMLName  xml.Name `xml:"InitiateMultipartUploadResult"`
		UploadId string   `xml:"UploadId"`
	}{}

	err = xml.NewDecoder(resp.Body).Decode(body)
	if err != nil {
		return err
	}

	u.uploadId = body.UploadId

	return nil
}

// https://docs.aws.amazon.com/AmazonS3/latest/API/API_AbortMultipartUpload.html
func (u *Uploader) multipartAbort(ctx context.Context, optReqFuncs ...func(*http.Request)) error {
	u.mu.Lock()
	defer u.mu.Unlock()

	u.used = true

	// ensure that the specified abort context is always valid to allow cleanup
	var abortCtx = ctx
	if abortCtx.Err() != nil {
		abortCtx = context.Background()
	}

	query := url.Values{"uploadId": []string{u.uploadId}}

	req, err := http.NewRequestWithContext(abortCtx, http.MethodDelete, u.S3.URL(u.Key+"?"+query.Encode()), nil)
	if err != nil {
		return err
	}

	// apply optional request funcs
	for _, fn := range optReqFuncs {
		if fn != nil {
			fn(req)
		}
	}

	resp, err := u.S3.SignAndSend(req)
	if err != nil {
		return err
	}
	defer resp.Body.Close()

	return nil
}

// https://docs.aws.amazon.com/AmazonS3/latest/API/API_CompleteMultipartUpload.html
func (u *Uploader) multipartComplete(ctx context.Context, optReqFuncs ...func(*http.Request)) error {
	u.mu.Lock()
	defer u.mu.Unlock()

	u.used = true

	// the list of parts must be sorted in ascending order
	slices.SortFunc(u.uploadedParts, func(a, b *mpPart) int {
		if a.PartNumber < b.PartNumber {
			return -1
		}
		if a.PartNumber > b.PartNumber {
			return 1
		}
		return 0
	})

	// build a request payload with the uploaded parts
	xmlParts := &struct {
		XMLName xml.Name `xml:"CompleteMultipartUpload"`
		Parts   []*mpPart
	}{
		Parts: u.uploadedParts,
	}
	rawXMLParts, err := xml.Marshal(xmlParts)
	if err != nil {
		return err
	}
	reqPayload := strings.NewReader(xml.Header + string(rawXMLParts))

	query := url.Values{"uploadId": []string{u.uploadId}}
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, u.S3.URL(u.Key+"?"+query.Encode()), reqPayload)
	if err != nil {
		return err
	}

	// apply optional request funcs
	for _, fn := range optReqFuncs {
		if fn != nil {
			fn(req)
		}
	}

	resp, err := u.S3.SignAndSend(req)
	if err != nil {
		return err
	}
	defer resp.Body.Close()

	return nil
}

func (u *Uploader) nextPart() ([]byte, int, error) {
	u.mu.Lock()
	defer u.mu.Unlock()

	part := make([]byte, u.MinPartSize)
	n, err := io.ReadFull(u.Payload, part)

	// normalize io.EOF errors and ensure that io.EOF is returned only when there were no read bytes
	if err != nil && (errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF)) {
		if n == 0 {
			err = io.EOF
		} else {
			err = nil
		}
	}

	u.lastPartNumber++

	return part[0:n], u.lastPartNumber, err
}

func (u *Uploader) multipartUpload(ctx context.Context, initPart []byte, optReqFuncs ...func(*http.Request)) error {
	var g errgroup.Group
	g.SetLimit(u.MaxConcurrency)

	totalParallel := u.MaxConcurrency

	if len(initPart) != 0 {
		totalParallel--
		initPartNumber := u.lastPartNumber
		g.Go(func() error {
			mp, err := u.uploadPart(ctx, initPartNumber, initPart, optReqFuncs...)
			if err != nil {
				return err
			}

			u.mu.Lock()
			u.uploadedParts = append(u.uploadedParts, mp)
			u.mu.Unlock()

			return nil
		})
	}

	for i := 0; i < totalParallel; i++ {
		g.Go(func() error {
			for {
				part, num, err := u.nextPart()
				if err != nil {
					if errors.Is(err, io.EOF) {
						break
					}
					return err
				}

				mp, err := u.uploadPart(ctx, num, part, optReqFuncs...)
				if err != nil {
					return err
				}

				u.mu.Lock()
				u.uploadedParts = append(u.uploadedParts, mp)
				u.mu.Unlock()
			}

			return nil
		})
	}

	return g.Wait()
}

// https://docs.aws.amazon.com/AmazonS3/latest/API/API_UploadPart.html
func (u *Uploader) uploadPart(ctx context.Context, partNumber int, partData []byte, optReqFuncs ...func(*http.Request)) (*mpPart, error) {
	query := url.Values{}
	query.Set("uploadId", u.uploadId)
	query.Set("partNumber", strconv.Itoa(partNumber))

	req, err := http.NewRequestWithContext(ctx, http.MethodPut, u.S3.URL(u.Key+"?"+query.Encode()), bytes.NewReader(partData))
	if err != nil {
		return nil, err
	}

	req.Header.Set("Content-Length", strconv.Itoa(len(partData)))

	// apply optional request funcs
	for _, fn := range optReqFuncs {
		if fn != nil {
			fn(req)
		}
	}

	resp, err := u.S3.SignAndSend(req)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()

	return &mpPart{
		PartNumber: partNumber,
		ETag:       resp.Header.Get("ETag"),
	}, nil
}
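
Editor's note: to make the threshold logic above concrete: payloads smaller than MinPartSize go out as one PutObject request, while anything larger is read in MinPartSize chunks uploaded by up to MaxConcurrency workers, with an automatic abort on failure. A minimal, hypothetical sketch (the file name and key are placeholders):

package main

import (
	"context"
	"log"
	"os"

	"github.com/pocketbase/pocketbase/tools/filesystem/internal/s3blob/s3"
)

func main() {
	f, err := os.Open("backup.zip") // placeholder payload
	if err != nil {
		log.Fatal(err)
	}
	defer f.Close()

	// hypothetical client configuration, for illustration only
	client := &s3.S3{Endpoint: "example.com", Region: "us-east-1", Bucket: "test", AccessKey: "...", SecretKey: "..."}

	u := &s3.Uploader{
		S3:      client,
		Key:     "backups/backup.zip",
		Payload: f,
		// zero MaxConcurrency/MinPartSize fall back to 5 workers and a ~6MB part size
	}

	// a failed part or complete request also aborts the multipart upload server-side
	if err := u.Upload(context.Background()); err != nil {
		log.Fatal(err)
	}

	// the Uploader is single-use; calling Upload again returns ErrUsedUploader
}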
463
tools/filesystem/internal/s3blob/s3/uploader_test.go
Normal file
@@ -0,0 +1,463 @@
package s3_test

import (
	"context"
	"io"
	"net/http"
	"strings"
	"testing"

	"github.com/pocketbase/pocketbase/tools/filesystem/internal/s3blob/s3"
	"github.com/pocketbase/pocketbase/tools/filesystem/internal/s3blob/s3/tests"
)

func TestUploaderRequiredFields(t *testing.T) {
	t.Parallel()

	s3Client := &s3.S3{
		Client:    tests.NewClient(&tests.RequestStub{Method: "PUT", URL: `^.+$`}), // match every upload
		Region:    "test_region",
		Bucket:    "test_bucket",
		Endpoint:  "http://example.com",
		AccessKey: "123",
		SecretKey: "abc",
	}

	payload := strings.NewReader("test")

	scenarios := []struct {
		name          string
		uploader      *s3.Uploader
		expectedError bool
	}{
		{
			"blank",
			&s3.Uploader{},
			true,
		},
		{
			"no Key",
			&s3.Uploader{S3: s3Client, Payload: payload},
			true,
		},
		{
			"no S3",
			&s3.Uploader{Key: "abc", Payload: payload},
			true,
		},
		{
			"no Payload",
			&s3.Uploader{S3: s3Client, Key: "abc"},
			true,
		},
		{
			"with S3, Key and Payload",
			&s3.Uploader{S3: s3Client, Key: "abc", Payload: payload},
			false,
		},
	}

	for _, s := range scenarios {
		t.Run(s.name, func(t *testing.T) {
			err := s.uploader.Upload(context.Background())

			hasErr := err != nil
			if hasErr != s.expectedError {
				t.Fatalf("Expected hasErr %v, got %v", s.expectedError, hasErr)
			}
		})
	}
}

func TestUploaderSingleUpload(t *testing.T) {
	t.Parallel()

	httpClient := tests.NewClient(
		&tests.RequestStub{
			Method: http.MethodPut,
			URL:    "http://test_bucket.example.com/test_key",
			Match: func(req *http.Request) bool {
				body, err := io.ReadAll(req.Body)
				if err != nil {
					return false
				}

				return string(body) == "abcdefg" && tests.ExpectHeaders(req.Header, map[string]string{
					"Content-Length": "7",
					"x-amz-meta-a":   "123",
					"x-amz-meta-b":   "456",
					"test_header":    "test",
					"Authorization":  "^.+Credential=123/.+$",
				})
			},
		},
	)

	uploader := &s3.Uploader{
		S3: &s3.S3{
			Client:    httpClient,
			Region:    "test_region",
			Bucket:    "test_bucket",
			Endpoint:  "http://example.com",
			AccessKey: "123",
			SecretKey: "abc",
		},
		Key:         "test_key",
		Payload:     strings.NewReader("abcdefg"),
		Metadata:    map[string]string{"a": "123", "b": "456"},
		MinPartSize: 8,
	}

	err := uploader.Upload(context.Background(), func(r *http.Request) {
		r.Header.Set("test_header", "test")
	})
	if err != nil {
		t.Fatal(err)
	}

	err = httpClient.AssertNoRemaining()
	if err != nil {
		t.Fatal(err)
	}
}

func TestUploaderMultipartUploadSuccess(t *testing.T) {
	t.Parallel()

	httpClient := tests.NewClient(
		&tests.RequestStub{
			Method: http.MethodPost,
			URL:    "http://test_bucket.example.com/test_key?uploads",
			Match: func(req *http.Request) bool {
				return tests.ExpectHeaders(req.Header, map[string]string{
					"x-amz-meta-a":  "123",
					"x-amz-meta-b":  "456",
					"test_header":   "test",
					"Authorization": "^.+Credential=123/.+$",
				})
			},
			Response: &http.Response{
				Body: io.NopCloser(strings.NewReader(`
					<?xml version="1.0" encoding="UTF-8"?>
					<InitiateMultipartUploadResult>
						<Bucket>test_bucket</Bucket>
						<Key>test_key</Key>
						<UploadId>test_id</UploadId>
					</InitiateMultipartUploadResult>
				`)),
			},
		},
		&tests.RequestStub{
			Method: http.MethodPut,
			URL:    "http://test_bucket.example.com/test_key?partNumber=1&uploadId=test_id",
			Match: func(req *http.Request) bool {
				body, err := io.ReadAll(req.Body)
				if err != nil {
					return false
				}

				return string(body) == "abc" && tests.ExpectHeaders(req.Header, map[string]string{
					"Content-Length": "3",
					"test_header":    "test",
					"Authorization":  "^.+Credential=123/.+$",
				})
			},
			Response: &http.Response{
				Header: http.Header{"Etag": []string{"etag1"}},
			},
		},
		&tests.RequestStub{
			Method: http.MethodPut,
			URL:    "http://test_bucket.example.com/test_key?partNumber=2&uploadId=test_id",
			Match: func(req *http.Request) bool {
				body, err := io.ReadAll(req.Body)
				if err != nil {
					return false
				}

				return string(body) == "def" && tests.ExpectHeaders(req.Header, map[string]string{
					"Content-Length": "3",
					"test_header":    "test",
					"Authorization":  "^.+Credential=123/.+$",
				})
			},
			Response: &http.Response{
				Header: http.Header{"Etag": []string{"etag2"}},
			},
		},
		&tests.RequestStub{
			Method: http.MethodPut,
			URL:    "http://test_bucket.example.com/test_key?partNumber=3&uploadId=test_id",
			Match: func(req *http.Request) bool {
				body, err := io.ReadAll(req.Body)
				if err != nil {
					return false
				}
				return string(body) == "g" && tests.ExpectHeaders(req.Header, map[string]string{
					"Content-Length": "1",
					"test_header":    "test",
					"Authorization":  "^.+Credential=123/.+$",
				})
			},
			Response: &http.Response{
				Header: http.Header{"Etag": []string{"etag3"}},
			},
		},
		&tests.RequestStub{
			Method: http.MethodPost,
			URL:    "http://test_bucket.example.com/test_key?uploadId=test_id",
			Match: func(req *http.Request) bool {
				body, err := io.ReadAll(req.Body)
				if err != nil {
					return false
				}

				expected := `<CompleteMultipartUpload><Part><ETag>etag1</ETag><PartNumber>1</PartNumber></Part><Part><ETag>etag2</ETag><PartNumber>2</PartNumber></Part><Part><ETag>etag3</ETag><PartNumber>3</PartNumber></Part></CompleteMultipartUpload>`

				return strings.Contains(string(body), expected) && tests.ExpectHeaders(req.Header, map[string]string{
					"test_header":   "test",
					"Authorization": "^.+Credential=123/.+$",
				})
			},
		},
	)

	uploader := &s3.Uploader{
		S3: &s3.S3{
			Client:    httpClient,
			Region:    "test_region",
			Bucket:    "test_bucket",
			Endpoint:  "http://example.com",
			AccessKey: "123",
			SecretKey: "abc",
		},
		Key:         "test_key",
		Payload:     strings.NewReader("abcdefg"),
		Metadata:    map[string]string{"a": "123", "b": "456"},
		MinPartSize: 3,
	}

	err := uploader.Upload(context.Background(), func(r *http.Request) {
		r.Header.Set("test_header", "test")
	})
	if err != nil {
		t.Fatal(err)
	}

	err = httpClient.AssertNoRemaining()
	if err != nil {
		t.Fatal(err)
	}
}

func TestUploaderMultipartUploadPartFailure(t *testing.T) {
	t.Parallel()

	httpClient := tests.NewClient(
		&tests.RequestStub{
			Method: http.MethodPost,
			URL:    "http://test_bucket.example.com/test_key?uploads",
			Match: func(req *http.Request) bool {
				return tests.ExpectHeaders(req.Header, map[string]string{
					"x-amz-meta-a":  "123",
					"x-amz-meta-b":  "456",
					"test_header":   "test",
					"Authorization": "^.+Credential=123/.+$",
				})
			},
			Response: &http.Response{
				Body: io.NopCloser(strings.NewReader(`
					<?xml version="1.0" encoding="UTF-8"?>
					<InitiateMultipartUploadResult>
						<Bucket>test_bucket</Bucket>
						<Key>test_key</Key>
						<UploadId>test_id</UploadId>
					</InitiateMultipartUploadResult>
				`)),
			},
		},
		&tests.RequestStub{
			Method: http.MethodPut,
			URL:    "http://test_bucket.example.com/test_key?partNumber=1&uploadId=test_id",
			Match: func(req *http.Request) bool {
				body, err := io.ReadAll(req.Body)
				if err != nil {
					return false
				}
				return string(body) == "abc" && tests.ExpectHeaders(req.Header, map[string]string{
					"Content-Length": "3",
					"test_header":    "test",
					"Authorization":  "^.+Credential=123/.+$",
				})
			},
			Response: &http.Response{
				Header: http.Header{"Etag": []string{"etag1"}},
			},
		},
		&tests.RequestStub{
			Method: http.MethodPut,
			URL:    "http://test_bucket.example.com/test_key?partNumber=2&uploadId=test_id",
			Match: func(req *http.Request) bool {
				return tests.ExpectHeaders(req.Header, map[string]string{
					"test_header":   "test",
					"Authorization": "^.+Credential=123/.+$",
				})
			},
			Response: &http.Response{
				StatusCode: 400,
			},
		},
		&tests.RequestStub{
			Method: http.MethodDelete,
			URL:    "http://test_bucket.example.com/test_key?uploadId=test_id",
			Match: func(req *http.Request) bool {
				return tests.ExpectHeaders(req.Header, map[string]string{
					"test_header":   "test",
					"Authorization": "^.+Credential=123/.+$",
				})
			},
		},
	)

	uploader := &s3.Uploader{
		S3: &s3.S3{
			Client:    httpClient,
			Region:    "test_region",
			Bucket:    "test_bucket",
			Endpoint:  "http://example.com",
			AccessKey: "123",
			SecretKey: "abc",
		},
		Key:         "test_key",
		Payload:     strings.NewReader("abcdefg"),
		Metadata:    map[string]string{"a": "123", "b": "456"},
		MinPartSize: 3,
	}

	err := uploader.Upload(context.Background(), func(r *http.Request) {
		r.Header.Set("test_header", "test")
	})
	if err == nil {
		t.Fatal("Expected non-nil error")
	}

	err = httpClient.AssertNoRemaining()
	if err != nil {
		t.Fatal(err)
	}
}

func TestUploaderMultipartUploadCompleteFailure(t *testing.T) {
	t.Parallel()

	httpClient := tests.NewClient(
		&tests.RequestStub{
			Method: http.MethodPost,
			URL:    "http://test_bucket.example.com/test_key?uploads",
			Match: func(req *http.Request) bool {
				return tests.ExpectHeaders(req.Header, map[string]string{
					"x-amz-meta-a":  "123",
					"x-amz-meta-b":  "456",
					"test_header":   "test",
					"Authorization": "^.+Credential=123/.+$",
				})
			},
			Response: &http.Response{
				Body: io.NopCloser(strings.NewReader(`
					<?xml version="1.0" encoding="UTF-8"?>
					<InitiateMultipartUploadResult>
						<Bucket>test_bucket</Bucket>
						<Key>test_key</Key>
						<UploadId>test_id</UploadId>
					</InitiateMultipartUploadResult>
				`)),
			},
		},
		&tests.RequestStub{
			Method: http.MethodPut,
			URL:    "http://test_bucket.example.com/test_key?partNumber=1&uploadId=test_id",
			Match: func(req *http.Request) bool {
				body, err := io.ReadAll(req.Body)
				if err != nil {
					return false
				}
				return string(body) == "abc" && tests.ExpectHeaders(req.Header, map[string]string{
					"Content-Length": "3",
					"test_header":    "test",
					"Authorization":  "^.+Credential=123/.+$",
				})
			},
			Response: &http.Response{
				Header: http.Header{"Etag": []string{"etag1"}},
			},
		},
		&tests.RequestStub{
			Method: http.MethodPut,
			URL:    "http://test_bucket.example.com/test_key?partNumber=2&uploadId=test_id",
			Match: func(req *http.Request) bool {
				body, err := io.ReadAll(req.Body)
				if err != nil {
					return false
				}
				return string(body) == "def" && tests.ExpectHeaders(req.Header, map[string]string{
					"Content-Length": "3",
					"test_header":    "test",
					"Authorization":  "^.+Credential=123/.+$",
				})
			},
			Response: &http.Response{
				Header: http.Header{"Etag": []string{"etag2"}},
			},
		},
		&tests.RequestStub{
			Method: http.MethodPost,
			URL:    "http://test_bucket.example.com/test_key?uploadId=test_id",
			Match: func(req *http.Request) bool {
				return tests.ExpectHeaders(req.Header, map[string]string{
					"test_header":   "test",
					"Authorization": "^.+Credential=123/.+$",
				})
			},
			Response: &http.Response{
|
||||
StatusCode: 400,
|
||||
},
|
||||
},
|
||||
&tests.RequestStub{
|
||||
Method: http.MethodDelete,
|
||||
URL: "http://test_bucket.example.com/test_key?uploadId=test_id",
|
||||
Match: func(req *http.Request) bool {
|
||||
return tests.ExpectHeaders(req.Header, map[string]string{
|
||||
"test_header": "test",
|
||||
"Authorization": "^.+Credential=123/.+$",
|
||||
})
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
uploader := &s3.Uploader{
|
||||
S3: &s3.S3{
|
||||
Client: httpClient,
|
||||
Region: "test_region",
|
||||
Bucket: "test_bucket",
|
||||
Endpoint: "http://example.com",
|
||||
AccessKey: "123",
|
||||
SecretKey: "abc",
|
||||
},
|
||||
Key: "test_key",
|
||||
Payload: strings.NewReader("abcdef"),
|
||||
Metadata: map[string]string{"a": "123", "b": "456"},
|
||||
MinPartSize: 3,
|
||||
}
|
||||
|
||||
err := uploader.Upload(context.Background(), func(r *http.Request) {
|
||||
r.Header.Set("test_header", "test")
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatal("Expected non-nil error")
|
||||
}
|
||||
|
||||
err = httpClient.AssertNoRemaining()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
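The three multipart tests above pin down the full upload lifecycle against stubbed HTTP traffic: initiate (POST ?uploads), one PUT per part, complete (POST ?uploadId), and an abort (DELETE ?uploadId) whenever a part or the completion request fails. A minimal sketch of driving the same uploader directly, using only the fields and Upload signature the tests exercise (the wrapper function and its error policy are illustrative, not part of the package):

// illustrative sketch only; s3Client is assumed to be a configured *s3.S3
// like the ones constructed in the tests above
func uploadSketch(ctx context.Context, s3Client *s3.S3) error {
	uploader := &s3.Uploader{
		S3:          s3Client,
		Key:         "test_key",
		Payload:     strings.NewReader("abcdefg"),
		MinPartSize: 3, // payloads larger than this are sent as a multipart upload
	}

	// On a part or completion failure the uploader aborts the multipart
	// upload itself (the DELETE stubs above), so the caller only needs to
	// propagate the returned error.
	return uploader.Upload(ctx, func(r *http.Request) {
		r.Header.Set("test_header", "test")
	})
}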
485
tools/filesystem/internal/s3blob/s3blob.go
Normal file
485
tools/filesystem/internal/s3blob/s3blob.go
Normal file
|
@ -0,0 +1,485 @@
|
|||
// Package s3blob provides a blob.Bucket S3 driver implementation.
//
// NB! To minimize breaking changes with older PocketBase releases,
// the driver is based on the previously used gocloud.dev/blob/s3blob,
// hence many of the below doc comments, struct options and interface
// implementations are the same.
//
// The blob abstraction supports all UTF-8 strings; to make this work with services lacking
// full UTF-8 support, strings must be escaped (during writes) and unescaped
// (during reads). The following escapes are performed for s3blob:
//   - Blob keys: ASCII characters 0-31 are escaped to "__0x<hex>__".
//     Additionally, the "/" in "../" is escaped in the same way.
//   - Metadata keys: Escaped using URL encoding, then additionally "@:=" are
//     escaped using "__0x<hex>__". These characters were determined by
//     experimentation.
//   - Metadata values: Escaped using URL encoding.
//
// Example:
//
//	drv, _ := s3blob.New(&s3.S3{
//		Bucket:    "bucketName",
//		Region:    "region",
//		Endpoint:  "endpoint",
//		AccessKey: "accessKey",
//		SecretKey: "secretKey",
//	})
//	bucket := blob.NewBucket(drv)
package s3blob

import (
	"context"
	"encoding/base64"
	"encoding/hex"
	"errors"
	"fmt"
	"io"
	"net/http"
	"net/url"
	"sort"
	"strconv"
	"strings"

	"github.com/pocketbase/pocketbase/tools/filesystem/blob"
	"github.com/pocketbase/pocketbase/tools/filesystem/internal/s3blob/s3"
)

const defaultPageSize = 1000

// New creates a new instance of the S3 driver backed by the internal S3 client.
func New(s3Client *s3.S3) (blob.Driver, error) {
	if s3Client.Bucket == "" {
		return nil, errors.New("s3blob.New: missing bucket name")
	}

	if s3Client.Endpoint == "" {
		return nil, errors.New("s3blob.New: missing endpoint")
	}

	if s3Client.Region == "" {
		return nil, errors.New("s3blob.New: missing region")
	}

	return &driver{s3: s3Client}, nil
}

type driver struct {
	s3 *s3.S3
}

// Close implements [blob/Driver.Close].
func (drv *driver) Close() error {
	return nil // nothing to close
}

// NormalizeError implements [blob/Driver.NormalizeError].
func (drv *driver) NormalizeError(err error) error {
	// already normalized
	if errors.Is(err, blob.ErrNotFound) {
		return err
	}

	// normalize based on its S3 error status or code
	var ae *s3.ResponseError
	if errors.As(err, &ae) {
		if ae.Status == 404 {
			return errors.Join(err, blob.ErrNotFound)
		}

		switch ae.Code {
		case "NoSuchBucket", "NoSuchKey", "NotFound":
			return errors.Join(err, blob.ErrNotFound)
		}
	}

	return err
}

// ListPaged implements [blob/Driver.ListPaged].
func (drv *driver) ListPaged(ctx context.Context, opts *blob.ListOptions) (*blob.ListPage, error) {
	pageSize := opts.PageSize
	if pageSize == 0 {
		pageSize = defaultPageSize
	}

	listParams := s3.ListParams{
		MaxKeys: pageSize,
	}
	if len(opts.PageToken) > 0 {
		listParams.ContinuationToken = string(opts.PageToken)
	}
	if opts.Prefix != "" {
		listParams.Prefix = escapeKey(opts.Prefix)
	}
	if opts.Delimiter != "" {
		listParams.Delimiter = escapeKey(opts.Delimiter)
	}

	resp, err := drv.s3.ListObjects(ctx, listParams)
	if err != nil {
		return nil, err
	}

	page := blob.ListPage{}
	if resp.NextContinuationToken != "" {
		page.NextPageToken = []byte(resp.NextContinuationToken)
	}

	if n := len(resp.Contents) + len(resp.CommonPrefixes); n > 0 {
		page.Objects = make([]*blob.ListObject, n)
		for i, obj := range resp.Contents {
			page.Objects[i] = &blob.ListObject{
				Key:     unescapeKey(obj.Key),
				ModTime: obj.LastModified,
				Size:    obj.Size,
				MD5:     eTagToMD5(obj.ETag),
			}
		}

		for i, prefix := range resp.CommonPrefixes {
			page.Objects[i+len(resp.Contents)] = &blob.ListObject{
				Key:   unescapeKey(prefix.Prefix),
				IsDir: true,
			}
		}

		if len(resp.Contents) > 0 && len(resp.CommonPrefixes) > 0 {
			// S3 gives us blobs and "directories" in separate lists; sort them.
			sort.Slice(page.Objects, func(i, j int) bool {
				return page.Objects[i].Key < page.Objects[j].Key
			})
		}
	}

	return &page, nil
}

// Attributes implements [blob/Driver.Attributes].
func (drv *driver) Attributes(ctx context.Context, key string) (*blob.Attributes, error) {
	key = escapeKey(key)

	resp, err := drv.s3.HeadObject(ctx, key)
	if err != nil {
		return nil, err
	}

	md := make(map[string]string, len(resp.Metadata))
	for k, v := range resp.Metadata {
		// See the package comments for more details on escaping of metadata keys & values.
		md[blob.HexUnescape(urlUnescape(k))] = urlUnescape(v)
	}

	return &blob.Attributes{
		CacheControl:       resp.CacheControl,
		ContentDisposition: resp.ContentDisposition,
		ContentEncoding:    resp.ContentEncoding,
		ContentLanguage:    resp.ContentLanguage,
		ContentType:        resp.ContentType,
		Metadata:           md,
		// CreateTime not supported; left as the zero time.
		ModTime: resp.LastModified,
		Size:    resp.ContentLength,
		MD5:     eTagToMD5(resp.ETag),
		ETag:    resp.ETag,
	}, nil
}

// NewRangeReader implements [blob/Driver.NewRangeReader].
func (drv *driver) NewRangeReader(ctx context.Context, key string, offset, length int64) (blob.DriverReader, error) {
	key = escapeKey(key)

	var byteRange string
	if offset > 0 && length < 0 {
		byteRange = fmt.Sprintf("bytes=%d-", offset)
	} else if length == 0 {
		// AWS doesn't support a zero-length read; we'll read 1 byte and then
		// ignore it in favor of http.NoBody below.
		byteRange = fmt.Sprintf("bytes=%d-%d", offset, offset)
	} else if length >= 0 {
		byteRange = fmt.Sprintf("bytes=%d-%d", offset, offset+length-1)
	}

	reqOpt := func(req *http.Request) {
		req.Header.Set("Range", byteRange)
	}

	resp, err := drv.s3.GetObject(ctx, key, reqOpt)
	if err != nil {
		return nil, err
	}

	body := resp.Body
	if length == 0 {
		body = http.NoBody
	}

	return &reader{
		body: body,
		attrs: &blob.ReaderAttributes{
			ContentType: resp.ContentType,
			ModTime:     resp.LastModified,
			Size:        getSize(resp.ContentLength, resp.ContentRange),
		},
	}, nil
}

// NewTypedWriter implements [blob/Driver.NewTypedWriter].
func (drv *driver) NewTypedWriter(ctx context.Context, key string, contentType string, opts *blob.WriterOptions) (blob.DriverWriter, error) {
	key = escapeKey(key)

	u := &s3.Uploader{
		S3:  drv.s3,
		Key: key,
	}

	if opts.BufferSize != 0 {
		u.MinPartSize = opts.BufferSize
	}

	if opts.MaxConcurrency != 0 {
		u.MaxConcurrency = opts.MaxConcurrency
	}

	md := make(map[string]string, len(opts.Metadata))
	for k, v := range opts.Metadata {
		// See the package comments for more details on escaping of metadata keys & values.
		k = blob.HexEscape(url.PathEscape(k), func(runes []rune, i int) bool {
			c := runes[i]
			return c == '@' || c == ':' || c == '='
		})
		md[k] = url.PathEscape(v)
	}
	u.Metadata = md

	var reqOptions []func(*http.Request)
	reqOptions = append(reqOptions, func(r *http.Request) {
		r.Header.Set("Content-Type", contentType)

		if opts.CacheControl != "" {
			r.Header.Set("Cache-Control", opts.CacheControl)
		}
		if opts.ContentDisposition != "" {
			r.Header.Set("Content-Disposition", opts.ContentDisposition)
		}
		if opts.ContentEncoding != "" {
			r.Header.Set("Content-Encoding", opts.ContentEncoding)
		}
		if opts.ContentLanguage != "" {
			r.Header.Set("Content-Language", opts.ContentLanguage)
		}
		if len(opts.ContentMD5) > 0 {
			r.Header.Set("Content-MD5", base64.StdEncoding.EncodeToString(opts.ContentMD5))
		}
	})

	return &writer{
		ctx:        ctx,
		uploader:   u,
		donec:      make(chan struct{}),
		reqOptions: reqOptions,
	}, nil
}

// Copy implements [blob/Driver.Copy].
func (drv *driver) Copy(ctx context.Context, dstKey, srcKey string) error {
	dstKey = escapeKey(dstKey)
	srcKey = escapeKey(srcKey)
	_, err := drv.s3.CopyObject(ctx, srcKey, dstKey)
	return err
}

// Delete implements [blob/Driver.Delete].
func (drv *driver) Delete(ctx context.Context, key string) error {
	key = escapeKey(key)
	return drv.s3.DeleteObject(ctx, key)
}

// -------------------------------------------------------------------

// reader reads an S3 object. It implements io.ReadCloser.
type reader struct {
	attrs *blob.ReaderAttributes
	body  io.ReadCloser
}

// Read implements [io/ReadCloser.Read].
func (r *reader) Read(p []byte) (int, error) {
	return r.body.Read(p)
}

// Close closes the reader itself. It must be called when done reading.
func (r *reader) Close() error {
	return r.body.Close()
}

// Attributes implements [blob/DriverReader.Attributes].
func (r *reader) Attributes() *blob.ReaderAttributes {
	return r.attrs
}

// -------------------------------------------------------------------

// writer writes an S3 object. It implements io.WriteCloser.
type writer struct {
	ctx      context.Context
	err      error // written before donec closes
	uploader *s3.Uploader

	// Ends of an io.Pipe, created when the first byte is written.
	pw *io.PipeWriter
	pr *io.PipeReader

	donec chan struct{} // closed when done writing

	reqOptions []func(*http.Request)
}

// Write appends p to w.pw. The caller must call Close when done writing.
func (w *writer) Write(p []byte) (int, error) {
	// Avoid opening the pipe for a zero-length write;
	// the concrete implementation can handle these for empty blobs.
	if len(p) == 0 {
		return 0, nil
	}

	if w.pw == nil {
		// We'll write into pw and use pr as an io.Reader for the
		// Upload call to S3.
		w.pr, w.pw = io.Pipe()
		w.open(w.pr, true)
	}

	return w.pw.Write(p)
}

// r may be nil if we're Closing and no data was written.
// If closePipeOnError is true, w.pr will be closed if there's an
// error uploading to S3.
func (w *writer) open(r io.Reader, closePipeOnError bool) {
	// This goroutine will keep running until Close, unless there's an error.
	go func() {
		defer func() {
			close(w.donec)
		}()

		if r == nil {
			// AWS doesn't like a nil Body.
			r = http.NoBody
		}

		w.uploader.Payload = r

		err := w.uploader.Upload(w.ctx, w.reqOptions...)
		if err != nil {
			if closePipeOnError {
				w.pr.CloseWithError(err)
			}
			w.err = err
		}
	}()
}

// Close completes the writer and closes it. Any error occurring during write
// will be returned. If a writer is closed before any Write is called, Close
// will create an empty file at the given key.
func (w *writer) Close() error {
	if w.pr != nil {
		defer w.pr.Close()
	}

	if w.pw == nil {
		// We never got any bytes written. We'll write an http.NoBody.
		w.open(nil, false)
	} else if err := w.pw.Close(); err != nil {
		return err
	}

	<-w.donec

	return w.err
}

// -------------------------------------------------------------------

// eTagToMD5 processes an ETag header and returns an MD5 hash if possible.
// S3's ETag header is sometimes a quoted hexstring of the MD5. Other times,
// notably when the object was uploaded in multiple parts, it is not.
// We do the best we can.
// Some links about ETag:
// https://docs.aws.amazon.com/AmazonS3/latest/API/RESTCommonResponseHeaders.html
// https://github.com/aws/aws-sdk-net/issues/815
// https://teppen.io/2018/06/23/aws_s3_etags/
func eTagToMD5(etag string) []byte {
	// No header at all.
	if etag == "" {
		return nil
	}

	// Strip the expected leading and trailing quotes.
	if len(etag) < 2 || etag[0] != '"' || etag[len(etag)-1] != '"' {
		return nil
	}
	unquoted := etag[1 : len(etag)-1]

	// Un-hex; we return nil on error. In particular, we'll get an error here
	// for multi-part uploaded blobs, whose ETag will contain a "-" and so will
	// never be a legal hex encoding.
	md5, err := hex.DecodeString(unquoted)
	if err != nil {
		return nil
	}

	return md5
}

func getSize(contentLength int64, contentRange string) int64 {
	// Default size to ContentLength, but that's incorrect for partial-length reads,
	// where ContentLength refers to the size of the returned Body, not the entire
	// size of the blob. ContentRange has the full size.
	size := contentLength
	if contentRange != "" {
		// Sample: bytes 10-14/27 (where 27 is the full size).
		parts := strings.Split(contentRange, "/")
		if len(parts) == 2 {
			if i, err := strconv.ParseInt(parts[1], 10, 64); err == nil {
				size = i
			}
		}
	}

	return size
}

// escapeKey does all required escaping for UTF-8 strings to work with S3.
func escapeKey(key string) string {
	return blob.HexEscape(key, func(r []rune, i int) bool {
		c := r[i]

		// S3 doesn't handle these characters (determined via experimentation).
		if c < 32 {
			return true
		}

		// For "../", escape the trailing slash.
		if i > 1 && c == '/' && r[i-1] == '.' && r[i-2] == '.' {
			return true
		}

		return false
	})
}

// unescapeKey reverses escapeKey.
func unescapeKey(key string) string {
	return blob.HexUnescape(key)
}

// urlUnescape reverses URLEscape using url.PathUnescape. If the unescape
// returns an error, it returns s.
func urlUnescape(s string) string {
	if u, err := url.PathUnescape(s); err == nil {
		return u
	}

	return s
}
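For orientation, the driver above is consumed through the blob abstraction; the wiring below simply expands the example from the package comment (assuming blob.NewBucket returns a *blob.Bucket; the credential values are placeholders):

// sketch: constructing a bucket backed by the s3blob driver
func newBucketSketch() (*blob.Bucket, error) {
	drv, err := s3blob.New(&s3.S3{
		Bucket:    "bucketName",
		Region:    "region",
		Endpoint:  "endpoint",
		AccessKey: "accessKey",
		SecretKey: "secretKey",
	})
	if err != nil {
		return nil, err
	}

	return blob.NewBucket(drv), nil
}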
605
tools/filesystem/internal/s3blob/s3blob_test.go
Normal file
605
tools/filesystem/internal/s3blob/s3blob_test.go
Normal file
|
@ -0,0 +1,605 @@
|
|||
package s3blob_test

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"net/http"
	"strings"
	"testing"

	"github.com/pocketbase/pocketbase/tools/filesystem/blob"
	"github.com/pocketbase/pocketbase/tools/filesystem/internal/s3blob"
	"github.com/pocketbase/pocketbase/tools/filesystem/internal/s3blob/s3"
	"github.com/pocketbase/pocketbase/tools/filesystem/internal/s3blob/s3/tests"
)

func TestNew(t *testing.T) {
	t.Parallel()

	scenarios := []struct {
		name        string
		s3Client    *s3.S3
		expectError bool
	}{
		{
			"blank",
			&s3.S3{},
			true,
		},
		{
			"no bucket",
			&s3.S3{Region: "b", Endpoint: "c"},
			true,
		},
		{
			"no endpoint",
			&s3.S3{Bucket: "a", Region: "b"},
			true,
		},
		{
			"no region",
			&s3.S3{Bucket: "a", Endpoint: "c"},
			true,
		},
		{
			"with bucket, endpoint and region",
			&s3.S3{Bucket: "a", Region: "b", Endpoint: "c"},
			false,
		},
	}

	for _, s := range scenarios {
		t.Run(s.name, func(t *testing.T) {
			drv, err := s3blob.New(s.s3Client)

			hasErr := err != nil
			if hasErr != s.expectError {
				t.Fatalf("Expected hasErr %v, got %v (%v)", s.expectError, hasErr, err)
			}

			if err == nil && drv == nil {
				t.Fatal("Expected non-nil driver instance")
			}
		})
	}
}

func TestDriverClose(t *testing.T) {
	t.Parallel()

	drv, err := s3blob.New(&s3.S3{Bucket: "a", Region: "b", Endpoint: "c"})
	if err != nil {
		t.Fatal(err)
	}

	err = drv.Close()
	if err != nil {
		t.Fatalf("Expected nil, got error %v", err)
	}
}

func TestDriverNormalizeError(t *testing.T) {
	t.Parallel()

	drv, err := s3blob.New(&s3.S3{Bucket: "a", Region: "b", Endpoint: "c"})
	if err != nil {
		t.Fatal(err)
	}

	scenarios := []struct {
		name              string
		err               error
		expectErrNotFound bool
	}{
		{
			"plain error",
			errors.New("test"),
			false,
		},
		{
			"response error with only status (non-404)",
			&s3.ResponseError{Status: 123},
			false,
		},
		{
			"response error with only status (404)",
			&s3.ResponseError{Status: 404},
			true,
		},
		{
			"response error with custom code",
			&s3.ResponseError{Code: "test"},
			false,
		},
		{
			"response error with NoSuchBucket code",
			&s3.ResponseError{Code: "NoSuchBucket"},
			true,
		},
		{
			"response error with NoSuchKey code",
			&s3.ResponseError{Code: "NoSuchKey"},
			true,
		},
		{
			"response error with NotFound code",
			&s3.ResponseError{Code: "NotFound"},
			true,
		},
		{
			"wrapped response error with NotFound code", // ensures that the entire error's tree is checked
			fmt.Errorf("test: %w", &s3.ResponseError{Code: "NotFound"}),
			true,
		},
		{
			"already normalized error",
			fmt.Errorf("test: %w", blob.ErrNotFound),
			true,
		},
	}

	for _, s := range scenarios {
		t.Run(s.name, func(t *testing.T) {
			err := drv.NormalizeError(s.err)
			if err == nil {
				t.Fatal("Expected non-nil error")
			}

			isErrNotFound := errors.Is(err, blob.ErrNotFound)
			if isErrNotFound != s.expectErrNotFound {
				t.Fatalf("Expected isErrNotFound %v, got %v (%v)", s.expectErrNotFound, isErrNotFound, err)
			}
		})
	}
}

func TestDriverDeleteEscaping(t *testing.T) {
	t.Parallel()

	httpClient := tests.NewClient(&tests.RequestStub{
		Method: http.MethodDelete,
		URL:    "https://test_bucket.example.com/..__0x2f__abc/test/",
	})

	drv, err := s3blob.New(&s3.S3{
		Bucket:   "test_bucket",
		Region:   "test_region",
		Endpoint: "https://example.com",
		Client:   httpClient,
	})
	if err != nil {
		t.Fatal(err)
	}

	err = drv.Delete(context.Background(), "../abc/test/")
	if err != nil {
		t.Fatal(err)
	}

	err = httpClient.AssertNoRemaining()
	if err != nil {
		t.Fatal(err)
	}
}

func TestDriverCopyEscaping(t *testing.T) {
	t.Parallel()

	httpClient := tests.NewClient(&tests.RequestStub{
		Method: http.MethodPut,
		URL:    "https://test_bucket.example.com/..__0x2f__a/",
		Match: func(req *http.Request) bool {
			return tests.ExpectHeaders(req.Header, map[string]string{
				"x-amz-copy-source": "test_bucket%2F..__0x2f__b%2F",
			})
		},
		Response: &http.Response{
			Body: io.NopCloser(strings.NewReader(`<CopyObjectResult></CopyObjectResult>`)),
		},
	})

	drv, err := s3blob.New(&s3.S3{
		Bucket:   "test_bucket",
		Region:   "test_region",
		Endpoint: "https://example.com",
		Client:   httpClient,
	})
	if err != nil {
		t.Fatal(err)
	}

	err = drv.Copy(context.Background(), "../a/", "../b/")
	if err != nil {
		t.Fatal(err)
	}

	err = httpClient.AssertNoRemaining()
	if err != nil {
		t.Fatal(err)
	}
}

func TestDriverAttributes(t *testing.T) {
	t.Parallel()

	httpClient := tests.NewClient(&tests.RequestStub{
		Method: http.MethodHead,
		URL:    "https://test_bucket.example.com/..__0x2f__a/",
		Response: &http.Response{
			Header: http.Header{
				"Last-Modified":       []string{"Mon, 01 Feb 2025 03:04:05 GMT"},
				"Cache-Control":       []string{"test_cache"},
				"Content-Disposition": []string{"test_disposition"},
				"Content-Encoding":    []string{"test_encoding"},
				"Content-Language":    []string{"test_language"},
				"Content-Type":        []string{"test_type"},
				"Content-Range":       []string{"test_range"},
				"Etag":                []string{`"ce5be8b6f53645c596306c4572ece521"`},
				"Content-Length":      []string{"100"},
				"x-amz-meta-AbC%40":   []string{"%40test_meta_a"},
				"x-amz-meta-Def":      []string{"test_meta_b"},
			},
			Body: http.NoBody,
		},
	})

	drv, err := s3blob.New(&s3.S3{
		Bucket:   "test_bucket",
		Region:   "test_region",
		Endpoint: "https://example.com",
		Client:   httpClient,
	})
	if err != nil {
		t.Fatal(err)
	}

	attrs, err := drv.Attributes(context.Background(), "../a/")
	if err != nil {
		t.Fatal(err)
	}

	raw, err := json.Marshal(attrs)
	if err != nil {
		t.Fatal(err)
	}

	expected := `{"cacheControl":"test_cache","contentDisposition":"test_disposition","contentEncoding":"test_encoding","contentLanguage":"test_language","contentType":"test_type","metadata":{"abc@":"@test_meta_a","def":"test_meta_b"},"createTime":"0001-01-01T00:00:00Z","modTime":"2025-02-01T03:04:05Z","size":100,"md5":"zlvotvU2RcWWMGxFcuzlIQ==","etag":"\"ce5be8b6f53645c596306c4572ece521\""}`
	if str := string(raw); str != expected {
		t.Fatalf("Expected attributes\n%s\ngot\n%s", expected, str)
	}

	err = httpClient.AssertNoRemaining()
	if err != nil {
		t.Fatal(err)
	}
}

func TestDriverListPaged(t *testing.T) {
	t.Parallel()

	listResponse := func() *http.Response {
		return &http.Response{
			Body: io.NopCloser(strings.NewReader(`
				<?xml version="1.0" encoding="UTF-8"?>
				<ListBucketResult xmlns="http://s3.amazonaws.com/doc/2006-03-01/">
					<Name>example</Name>
					<ContinuationToken>ct</ContinuationToken>
					<NextContinuationToken>test_next</NextContinuationToken>
					<StartAfter>example0.txt</StartAfter>
					<KeyCount>1</KeyCount>
					<MaxKeys>3</MaxKeys>
					<Contents>
						<Key>..__0x2f__prefixB/test/example.txt</Key>
						<LastModified>2025-01-01T01:02:03.123Z</LastModified>
						<ETag>"ce5be8b6f53645c596306c4572ece521"</ETag>
						<Size>123</Size>
					</Contents>
					<Contents>
						<Key>prefixA/..__0x2f__escape.txt</Key>
						<LastModified>2025-01-02T01:02:03.123Z</LastModified>
						<Size>456</Size>
					</Contents>
					<CommonPrefixes>
						<Prefix>prefixA</Prefix>
					</CommonPrefixes>
					<CommonPrefixes>
						<Prefix>..__0x2f__prefixB</Prefix>
					</CommonPrefixes>
				</ListBucketResult>
			`)),
		}
	}

	expectedPage := `{"objects":[{"key":"../prefixB","modTime":"0001-01-01T00:00:00Z","size":0,"md5":null,"isDir":true},{"key":"../prefixB/test/example.txt","modTime":"2025-01-01T01:02:03.123Z","size":123,"md5":"zlvotvU2RcWWMGxFcuzlIQ==","isDir":false},{"key":"prefixA","modTime":"0001-01-01T00:00:00Z","size":0,"md5":null,"isDir":true},{"key":"prefixA/../escape.txt","modTime":"2025-01-02T01:02:03.123Z","size":456,"md5":null,"isDir":false}],"nextPageToken":"dGVzdF9uZXh0"}`

	httpClient := tests.NewClient(
		&tests.RequestStub{
			Method:   http.MethodGet,
			URL:      "https://test_bucket.example.com/?list-type=2&max-keys=1000",
			Response: listResponse(),
		},
		&tests.RequestStub{
			Method:   http.MethodGet,
			URL:      "https://test_bucket.example.com/?continuation-token=test_token&delimiter=test_delimiter&list-type=2&max-keys=123&prefix=test_prefix",
			Response: listResponse(),
		},
	)

	drv, err := s3blob.New(&s3.S3{
		Bucket:   "test_bucket",
		Region:   "test_region",
		Endpoint: "https://example.com",
		Client:   httpClient,
	})
	if err != nil {
		t.Fatal(err)
	}

	scenarios := []struct {
		name     string
		opts     *blob.ListOptions
		expected string
	}{
		{
			"empty options",
			&blob.ListOptions{},
			expectedPage,
		},
		{
			"filled options",
			&blob.ListOptions{Prefix: "test_prefix", Delimiter: "test_delimiter", PageSize: 123, PageToken: []byte("test_token")},
			expectedPage,
		},
	}

	for _, s := range scenarios {
		t.Run(s.name, func(t *testing.T) {
			page, err := drv.ListPaged(context.Background(), s.opts)
			if err != nil {
				t.Fatal(err)
			}

			raw, err := json.Marshal(page)
			if err != nil {
				t.Fatal(err)
			}

			if str := string(raw); s.expected != str {
				t.Fatalf("Expected page result\n%s\ngot\n%s", s.expected, str)
			}
		})
	}

	err = httpClient.AssertNoRemaining()
	if err != nil {
		t.Fatal(err)
	}
}

func TestDriverNewRangeReader(t *testing.T) {
	t.Parallel()

	scenarios := []struct {
		offset        int64
		length        int64
		httpClient    *tests.Client
		expectedAttrs string
	}{
		{
			0,
			0,
			tests.NewClient(&tests.RequestStub{
				Method: http.MethodGet,
				URL:    "https://test_bucket.example.com/..__0x2f__abc/test.txt",
				Match: func(req *http.Request) bool {
					return tests.ExpectHeaders(req.Header, map[string]string{
						"Range": "bytes=0-0",
					})
				},
				Response: &http.Response{
					Header: http.Header{
						"Last-Modified":  []string{"Mon, 01 Feb 2025 03:04:05 GMT"},
						"Content-Type":   []string{"test_ct"},
						"Content-Length": []string{"123"},
					},
					Body: io.NopCloser(strings.NewReader("test")),
				},
			}),
			`{"contentType":"test_ct","modTime":"2025-02-01T03:04:05Z","size":123}`,
		},
		{
			10,
			-1,
			tests.NewClient(&tests.RequestStub{
				Method: http.MethodGet,
				URL:    "https://test_bucket.example.com/..__0x2f__abc/test.txt",
				Match: func(req *http.Request) bool {
					return tests.ExpectHeaders(req.Header, map[string]string{
						"Range": "bytes=10-",
					})
				},
				Response: &http.Response{
					Header: http.Header{
						"Last-Modified":  []string{"Mon, 01 Feb 2025 03:04:05 GMT"},
						"Content-Type":   []string{"test_ct"},
						"Content-Range":  []string{"bytes 1-1/456"}, // should take precedence over content-length
						"Content-Length": []string{"123"},
					},
					Body: io.NopCloser(strings.NewReader("test")),
				},
			}),
			`{"contentType":"test_ct","modTime":"2025-02-01T03:04:05Z","size":456}`,
		},
		{
			10,
			0,
			tests.NewClient(&tests.RequestStub{
				Method: http.MethodGet,
				URL:    "https://test_bucket.example.com/..__0x2f__abc/test.txt",
				Match: func(req *http.Request) bool {
					return tests.ExpectHeaders(req.Header, map[string]string{
						"Range": "bytes=10-10",
					})
				},
				Response: &http.Response{
					Header: http.Header{
						"Last-Modified": []string{"Mon, 01 Feb 2025 03:04:05 GMT"},
						"Content-Type":  []string{"test_ct"},
						// no range and length headers
						// "Content-Range": []string{"bytes 1-1/456"},
						// "Content-Length": []string{"123"},
					},
					Body: io.NopCloser(strings.NewReader("test")),
				},
			}),
			`{"contentType":"test_ct","modTime":"2025-02-01T03:04:05Z","size":0}`,
		},
		{
			10,
			20,
			tests.NewClient(&tests.RequestStub{
				Method: http.MethodGet,
				URL:    "https://test_bucket.example.com/..__0x2f__abc/test.txt",
				Match: func(req *http.Request) bool {
					return tests.ExpectHeaders(req.Header, map[string]string{
						"Range": "bytes=10-29",
					})
				},
				Response: &http.Response{
					Header: http.Header{
						"Last-Modified": []string{"Mon, 01 Feb 2025 03:04:05 GMT"},
						"Content-Type":  []string{"test_ct"},
						// with range header but invalid format -> content-length takes precedence
						"Content-Range":  []string{"bytes invalid-456"},
						"Content-Length": []string{"123"},
					},
					Body: io.NopCloser(strings.NewReader("test")),
				},
			}),
			`{"contentType":"test_ct","modTime":"2025-02-01T03:04:05Z","size":123}`,
		},
	}

	for _, s := range scenarios {
		t.Run(fmt.Sprintf("offset_%d_length_%d", s.offset, s.length), func(t *testing.T) {
			drv, err := s3blob.New(&s3.S3{
				Bucket:   "test_bucket",
				Region:   "test_region",
				Endpoint: "https://example.com",
				Client:   s.httpClient,
			})
			if err != nil {
				t.Fatal(err)
			}

			r, err := drv.NewRangeReader(context.Background(), "../abc/test.txt", s.offset, s.length)
			if err != nil {
				t.Fatal(err)
			}
			defer r.Close()

			// the response body should always be replaced with http.NoBody
			if s.length == 0 {
				body := make([]byte, 1)
				n, err := r.Read(body)
				if n != 0 || !errors.Is(err, io.EOF) {
					t.Fatalf("Expected body to be http.NoBody, got %v (%v)", body, err)
				}
			}

			rawAttrs, err := json.Marshal(r.Attributes())
			if err != nil {
				t.Fatal(err)
			}

			if str := string(rawAttrs); str != s.expectedAttrs {
				t.Fatalf("Expected attributes\n%s\ngot\n%s", s.expectedAttrs, str)
			}

			err = s.httpClient.AssertNoRemaining()
			if err != nil {
				t.Fatal(err)
			}
		})
	}
}

func TestDriverNewTypedWriter(t *testing.T) {
	t.Parallel()

	httpClient := tests.NewClient(
		&tests.RequestStub{
			Method: http.MethodPut,
			URL:    "https://test_bucket.example.com/..__0x2f__abc/test/",
			Match: func(req *http.Request) bool {
				body, err := io.ReadAll(req.Body)
				if err != nil {
					return false
				}

				return string(body) == "test" && tests.ExpectHeaders(req.Header, map[string]string{
					"cache-control":       "test_cache_control",
					"content-disposition": "test_content_disposition",
					"content-encoding":    "test_content_encoding",
					"content-language":    "test_content_language",
					"content-type":        "test_ct",
					"content-md5":         "dGVzdA==",
				})
			},
		},
	)

	drv, err := s3blob.New(&s3.S3{
		Bucket:   "test_bucket",
		Region:   "test_region",
		Endpoint: "https://example.com",
		Client:   httpClient,
	})
	if err != nil {
		t.Fatal(err)
	}

	options := &blob.WriterOptions{
		CacheControl:       "test_cache_control",
		ContentDisposition: "test_content_disposition",
		ContentEncoding:    "test_content_encoding",
		ContentLanguage:    "test_content_language",
		ContentType:        "test_content_type", // should be ignored
		ContentMD5:         []byte("test"),
		Metadata:           map[string]string{"@test_meta_a": "@test"},
	}

	w, err := drv.NewTypedWriter(context.Background(), "../abc/test/", "test_ct", options)
	if err != nil {
		t.Fatal(err)
	}

	n, err := w.Write(nil)
	if err != nil {
		t.Fatal(err)
	}
	if n != 0 {
		t.Fatalf("Expected nil write to result in %d written bytes, got %d", 0, n)
	}

	n, err = w.Write([]byte("test"))
	if err != nil {
		t.Fatal(err)
	}
	if n != 4 {
		t.Fatalf("Expected write to result in %d written bytes, got %d", 4, n)
	}

	err = w.Close()
	if err != nil {
		t.Fatal(err)
	}

	err = httpClient.AssertNoRemaining()
	if err != nil {
		t.Fatal(err)
	}
}
45
tools/hook/event.go
Normal file
45
tools/hook/event.go
Normal file
|
@ -0,0 +1,45 @@
|
|||
package hook

// Resolver defines a common interface for a Hook event (see [Event]).
type Resolver interface {
	// Next triggers the next handler in the hook's chain (if any).
	Next() error

	// note: kept only for the generic interface; may get removed in the future
	nextFunc() func() error
	setNextFunc(f func() error)
}

var _ Resolver = (*Event)(nil)

// Event implements [Resolver] and it is intended to be used as a base
// Hook event that you can embed in your custom typed event structs.
//
// Example:
//
//	type CustomEvent struct {
//		hook.Event
//
//		SomeField int
//	}
type Event struct {
	next func() error
}

// Next calls the next hook handler.
func (e *Event) Next() error {
	if e.next != nil {
		return e.next()
	}
	return nil
}

// nextFunc returns the function that Next calls.
func (e *Event) nextFunc() func() error {
	return e.next
}

// setNextFunc sets the function that Next calls.
func (e *Event) setNextFunc(f func() error) {
	e.next = f
}
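To make the embedding pattern above concrete, here is a small sketch (the event type and handler are hypothetical) showing how a custom event reuses Event's chain plumbing and how a handler wraps the rest of the chain via Next:

// sketch: a custom typed event embedding the base Event
type customEvent struct {
	Event

	SomeField int // hypothetical payload field
}

// a handler can run logic both before and after the rest of the chain
func exampleHandler(e *customEvent) error {
	// ... "before" logic ...
	err := e.Next() // continue with the next handler, if any
	// ... "after" logic ...
	return err
}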
29
tools/hook/event_test.go
Normal file
29
tools/hook/event_test.go
Normal file
|
@ -0,0 +1,29 @@
|
|||
package hook

import "testing"

func TestEventNext(t *testing.T) {
	calls := 0

	e := Event{}

	if e.nextFunc() != nil {
		t.Fatalf("Expected nextFunc to be nil")
	}

	e.setNextFunc(func() error {
		calls++
		return nil
	})

	if e.nextFunc() == nil {
		t.Fatalf("Expected nextFunc to be non-nil")
	}

	e.Next()
	e.Next()

	if calls != 2 {
		t.Fatalf("Expected %d calls, got %d", 2, calls)
	}
}
178
tools/hook/hook.go
Normal file
178
tools/hook/hook.go
Normal file
|
@ -0,0 +1,178 @@
|
|||
package hook

import (
	"sort"
	"sync"

	"github.com/pocketbase/pocketbase/tools/security"
)

// Handler defines a single Hook handler.
// Multiple handlers can share the same id.
// If Id is not explicitly set it will be autogenerated by Hook.Bind and Hook.BindFunc.
type Handler[T Resolver] struct {
	// Func defines the handler function to execute.
	//
	// Note that users need to call e.Next() in order to proceed with
	// the execution of the hook chain.
	Func func(T) error

	// Id is the unique identifier of the handler.
	//
	// It could be used later to remove the handler from a hook via [Hook.Unbind].
	//
	// If missing, an autogenerated value will be assigned when adding
	// the handler to a hook.
	Id string

	// Priority allows changing the default exec priority of the handler within a hook.
	//
	// If 0, the handler will be executed in the same order it was registered.
	Priority int
}

// Hook defines a generic concurrent safe structure for managing event hooks.
//
// When using a custom event it must embed the base [hook.Event].
//
// Example:
//
//	type CustomEvent struct {
//		hook.Event
//		SomeField int
//	}
//
//	h := Hook[*CustomEvent]{}
//
//	h.BindFunc(func(e *CustomEvent) error {
//		println(e.SomeField)
//
//		return e.Next()
//	})
//
//	h.Trigger(&CustomEvent{ SomeField: 123 })
type Hook[T Resolver] struct {
	handlers []*Handler[T]
	mu       sync.RWMutex
}

// Bind registers the provided handler to the current hooks queue.
//
// If handler.Id is empty it is updated with an autogenerated value.
//
// If a handler from the current hook list has Id matching handler.Id
// then the old handler is replaced with the new one.
func (h *Hook[T]) Bind(handler *Handler[T]) string {
	h.mu.Lock()
	defer h.mu.Unlock()

	var exists bool

	if handler.Id == "" {
		handler.Id = generateHookId()

		// ensure that it doesn't exist
	DUPLICATE_CHECK:
		for _, existing := range h.handlers {
			if existing.Id == handler.Id {
				handler.Id = generateHookId()
				goto DUPLICATE_CHECK
			}
		}
	} else {
		// replace existing
		for i, existing := range h.handlers {
			if existing.Id == handler.Id {
				h.handlers[i] = handler
				exists = true
				break
			}
		}
	}

	// append new
	if !exists {
		h.handlers = append(h.handlers, handler)
	}

	// sort handlers by Priority, preserving the original order of equal items
	sort.SliceStable(h.handlers, func(i, j int) bool {
		return h.handlers[i].Priority < h.handlers[j].Priority
	})

	return handler.Id
}

// BindFunc is similar to Bind but registers a new handler from just the provided function.
//
// The registered handler is added with a default 0 priority and the id will be autogenerated.
//
// If you want to register a handler with custom priority or id use the [Hook.Bind] method.
func (h *Hook[T]) BindFunc(fn func(e T) error) string {
	return h.Bind(&Handler[T]{Func: fn})
}

// Unbind removes one or many hook handlers by their ids.
func (h *Hook[T]) Unbind(idsToRemove ...string) {
	h.mu.Lock()
	defer h.mu.Unlock()

	for _, id := range idsToRemove {
		for i := len(h.handlers) - 1; i >= 0; i-- {
			if h.handlers[i].Id == id {
				h.handlers = append(h.handlers[:i], h.handlers[i+1:]...)
				break // for now stop on the first occurrence since we don't allow handlers with duplicated ids
			}
		}
	}
}

// UnbindAll removes all registered handlers.
func (h *Hook[T]) UnbindAll() {
	h.mu.Lock()
	defer h.mu.Unlock()

	h.handlers = nil
}

// Length returns the total number of registered hook handlers.
func (h *Hook[T]) Length() int {
	h.mu.RLock()
	defer h.mu.RUnlock()

	return len(h.handlers)
}

// Trigger executes all registered hook handlers one by one
// with the specified event as an argument.
//
// Optionally, this method also allows registering additional one-off
// handler funcs that will be temporarily appended to the handlers queue.
//
// NB! Each hook handler must call event.Next() in order for the hook chain to proceed.
func (h *Hook[T]) Trigger(event T, oneOffHandlerFuncs ...func(T) error) error {
	h.mu.RLock()
	handlers := make([]func(T) error, 0, len(h.handlers)+len(oneOffHandlerFuncs))
	for _, handler := range h.handlers {
		handlers = append(handlers, handler.Func)
	}
	handlers = append(handlers, oneOffHandlerFuncs...)
	h.mu.RUnlock()

	event.setNextFunc(nil) // reset in case the event is being reused

	for i := len(handlers) - 1; i >= 0; i-- {
		i := i
		old := event.nextFunc()
		event.setNextFunc(func() error {
			event.setNextFunc(old)
			return handlers[i](event)
		})
	}

	return event.Next()
}

func generateHookId() string {
	return security.PseudorandomString(20)
}
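Trigger builds the chain back to front: walking the handlers in reverse, it wraps the event's current next function so that each handler's e.Next() restores the previous wrapper and calls the following handler, which is what makes pre/post logic around Next possible. A stripped-down sketch of the same construction over plain functions (hypothetical names, no generics) may make the loop easier to follow:

// sketch: the reverse-wrap chain construction used by Trigger above
func chain(handlers []func(next func() error) error) func() error {
	next := func() error { return nil } // terminal no-op

	for i := len(handlers) - 1; i >= 0; i-- {
		h := handlers[i]
		prev := next
		next = func() error { return h(prev) } // each wrapper hands the previous one to its handler
	}

	return next // invoking this runs handlers[0] first
}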
162
tools/hook/hook_test.go
Normal file
162
tools/hook/hook_test.go
Normal file
|
@ -0,0 +1,162 @@
|
|||
package hook

import (
	"errors"
	"testing"
)

func TestHookBindAndBindFunc(t *testing.T) {
	calls := ""

	h := Hook[*Event]{}

	h.BindFunc(func(e *Event) error { calls += "1"; return e.Next() })
	h.BindFunc(func(e *Event) error { calls += "2"; return e.Next() })
	h3Id := h.BindFunc(func(e *Event) error { calls += "3"; return e.Next() })
	h.Bind(&Handler[*Event]{
		Id:   h3Id, // should replace 3
		Func: func(e *Event) error { calls += "3'"; return e.Next() },
	})
	h.Bind(&Handler[*Event]{
		Func:     func(e *Event) error { calls += "4"; return e.Next() },
		Priority: -2,
	})
	h.Bind(&Handler[*Event]{
		Func:     func(e *Event) error { calls += "5"; return e.Next() },
		Priority: -1,
	})
	h.Bind(&Handler[*Event]{
		Func: func(e *Event) error { calls += "6"; return e.Next() },
	})
	h.Bind(&Handler[*Event]{
		Func: func(e *Event) error { calls += "7"; e.Next(); return errors.New("test") }, // error shouldn't stop the chain
	})

	h.Trigger(
		&Event{},
		func(e *Event) error { calls += "8"; return e.Next() },
		func(e *Event) error { calls += "9"; return nil }, // skip next
		func(e *Event) error { calls += "10"; return e.Next() },
	)

	if total := len(h.handlers); total != 7 {
		t.Fatalf("Expected %d handlers, found %d", 7, total)
	}

	expectedCalls := "45123'6789"

	if calls != expectedCalls {
		t.Fatalf("Expected calls sequence %q, got %q", expectedCalls, calls)
	}
}

func TestHookLength(t *testing.T) {
	h := Hook[*Event]{}

	if l := h.Length(); l != 0 {
		t.Fatalf("Expected 0 hook handlers, got %d", l)
	}

	h.BindFunc(func(e *Event) error { return e.Next() })
	h.BindFunc(func(e *Event) error { return e.Next() })

	if l := h.Length(); l != 2 {
		t.Fatalf("Expected 2 hook handlers, got %d", l)
	}
}

func TestHookUnbind(t *testing.T) {
	h := Hook[*Event]{}

	calls := ""

	id0 := h.BindFunc(func(e *Event) error { calls += "0"; return e.Next() })
	id1 := h.BindFunc(func(e *Event) error { calls += "1"; return e.Next() })
	h.BindFunc(func(e *Event) error { calls += "2"; return e.Next() })
	h.Bind(&Handler[*Event]{
		Func: func(e *Event) error { calls += "3"; return e.Next() },
	})

	h.Unbind("missing") // should do nothing and not panic

	if total := len(h.handlers); total != 4 {
		t.Fatalf("Expected %d handlers, got %d", 4, total)
	}

	h.Unbind(id1, id0)

	if total := len(h.handlers); total != 2 {
		t.Fatalf("Expected %d handlers, got %d", 2, total)
	}

	err := h.Trigger(&Event{}, func(e *Event) error { calls += "4"; return e.Next() })
	if err != nil {
		t.Fatal(err)
	}

	expectedCalls := "234"

	if calls != expectedCalls {
		t.Fatalf("Expected calls sequence %q, got %q", expectedCalls, calls)
	}
}

func TestHookUnbindAll(t *testing.T) {
	h := Hook[*Event]{}

	h.UnbindAll() // should do nothing and not panic

	h.BindFunc(func(e *Event) error { return nil })
	h.BindFunc(func(e *Event) error { return nil })

	if total := len(h.handlers); total != 2 {
		t.Fatalf("Expected %d handlers before UnbindAll, found %d", 2, total)
	}

	h.UnbindAll()

	if total := len(h.handlers); total != 0 {
		t.Fatalf("Expected no handlers after UnbindAll, found %d", total)
	}
}

func TestHookTriggerErrorPropagation(t *testing.T) {
	err := errors.New("test")

	scenarios := []struct {
		name          string
		handlers      []func(*Event) error
		expectedError error
	}{
		{
			"without error",
			[]func(*Event) error{
				func(e *Event) error { return e.Next() },
				func(e *Event) error { return e.Next() },
			},
			nil,
		},
		{
			"with error",
			[]func(*Event) error{
				func(e *Event) error { return e.Next() },
				func(e *Event) error { e.Next(); return err },
				func(e *Event) error { return e.Next() },
			},
			err,
		},
	}

	for _, s := range scenarios {
		t.Run(s.name, func(t *testing.T) {
			h := Hook[*Event]{}
			for _, handler := range s.handlers {
				h.BindFunc(handler)
			}
			result := h.Trigger(&Event{})
			if result != s.expectedError {
				t.Fatalf("Expected %v, got %v", s.expectedError, result)
			}
		})
	}
}
84
tools/hook/tagged.go
Normal file
84
tools/hook/tagged.go
Normal file
|
@ -0,0 +1,84 @@
|
|||
package hook

import (
	"github.com/pocketbase/pocketbase/tools/list"
)

// Tagger defines an interface for event data structs that support tags/groups/categories/etc.
// Usually used together with TaggedHook.
type Tagger interface {
	Resolver

	Tags() []string
}

// wrapped local Hook embedded struct to limit the public API surface.
type mainHook[T Tagger] struct {
	*Hook[T]
}

// NewTaggedHook creates a new TaggedHook with the provided main hook and optional tags.
func NewTaggedHook[T Tagger](hook *Hook[T], tags ...string) *TaggedHook[T] {
	return &TaggedHook[T]{
		mainHook[T]{hook},
		tags,
	}
}

// TaggedHook defines a proxy hook which registers handlers that are triggered only
// if the TaggedHook.tags are empty or include at least one of the event data tag(s).
type TaggedHook[T Tagger] struct {
	mainHook[T]

	tags []string
}

// CanTriggerOn checks if the current TaggedHook can be triggered with
// the provided event data tags.
//
// It always returns true if the hook doesn't have any tags.
func (h *TaggedHook[T]) CanTriggerOn(tagsToCheck []string) bool {
	if len(h.tags) == 0 {
		return true // match all
	}

	for _, t := range tagsToCheck {
		if list.ExistInSlice(t, h.tags) {
			return true
		}
	}

	return false
}

// Bind registers the provided handler to the current hooks queue.
//
// It is similar to [Hook.Bind] with the difference that the handler
// function is invoked only if the event data tags satisfy h.CanTriggerOn.
func (h *TaggedHook[T]) Bind(handler *Handler[T]) string {
	fn := handler.Func

	handler.Func = func(e T) error {
		if h.CanTriggerOn(e.Tags()) {
			return fn(e)
		}

		return e.Next()
	}

	return h.mainHook.Bind(handler)
}

// BindFunc registers a new handler with the specified function.
//
// It is similar to [Hook.BindFunc] with the difference that the handler
// function is invoked only if the event data tags satisfy h.CanTriggerOn.
func (h *TaggedHook[T]) BindFunc(fn func(e T) error) string {
	return h.mainHook.BindFunc(func(e T) error {
		if h.CanTriggerOn(e.Tags()) {
			return fn(e)
		}

		return e.Next()
	})
}
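A short usage sketch of the proxy behavior above, mirroring the test file that follows (mockTagsEvent is the Tagger implementation defined there): handlers bound through a tagged hook run only when the event reports a matching tag, while an untagged TaggedHook matches every event.

// sketch: tag-scoped handlers sharing one base hook
func taggedSketch() error {
	base := &Hook[*mockTagsEvent]{}

	users := NewTaggedHook(base, "users")
	users.BindFunc(func(e *mockTagsEvent) error {
		// runs only when e.Tags() contains "users"
		return e.Next()
	})

	all := NewTaggedHook(base) // no tags -> matches every event
	all.BindFunc(func(e *mockTagsEvent) error { return e.Next() })

	return base.Trigger(&mockTagsEvent{tags: []string{"users"}})
}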
84
tools/hook/tagged_test.go
Normal file
84
tools/hook/tagged_test.go
Normal file
|
@ -0,0 +1,84 @@
|
|||
package hook

import (
	"strings"
	"testing"
)

type mockTagsEvent struct {
	Event
	tags []string
}

func (m mockTagsEvent) Tags() []string {
	return m.tags
}

func TestTaggedHook(t *testing.T) {
	calls := ""

	base := &Hook[*mockTagsEvent]{}
	base.BindFunc(func(e *mockTagsEvent) error { calls += "f0"; return e.Next() })

	hA := NewTaggedHook(base)
	hA.BindFunc(func(e *mockTagsEvent) error { calls += "a1"; return e.Next() })
	hA.Bind(&Handler[*mockTagsEvent]{
		Func:     func(e *mockTagsEvent) error { calls += "a2"; return e.Next() },
		Priority: -1,
	})

	hB := NewTaggedHook(base, "b1", "b2")
	hB.BindFunc(func(e *mockTagsEvent) error { calls += "b1"; return e.Next() })
	hB.Bind(&Handler[*mockTagsEvent]{
		Func:     func(e *mockTagsEvent) error { calls += "b2"; return e.Next() },
		Priority: -2,
	})

	hC := NewTaggedHook(base, "c1", "c2")
	hC.BindFunc(func(e *mockTagsEvent) error { calls += "c1"; return e.Next() })
	hC.Bind(&Handler[*mockTagsEvent]{
		Func:     func(e *mockTagsEvent) error { calls += "c2"; return e.Next() },
		Priority: -3,
	})

	scenarios := []struct {
		event         *mockTagsEvent
		expectedCalls string
	}{
		{
			&mockTagsEvent{},
			"a2f0a1",
		},
		{
			&mockTagsEvent{tags: []string{"missing"}},
			"a2f0a1",
		},
		{
			&mockTagsEvent{tags: []string{"b2"}},
			"b2a2f0a1b1",
		},
		{
			&mockTagsEvent{tags: []string{"c1"}},
			"c2a2f0a1c1",
		},
		{
			&mockTagsEvent{tags: []string{"b1", "c2"}},
			"c2b2a2f0a1b1c1",
		},
	}

	for _, s := range scenarios {
		t.Run(strings.Join(s.event.tags, "_"), func(t *testing.T) {
			calls = "" // reset

			err := base.Trigger(s.event)
			if err != nil {
				t.Fatalf("Unexpected trigger error: %v", err)
			}

			if calls != s.expectedCalls {
				t.Fatalf("Expected calls sequence %q, got %q", s.expectedCalls, calls)
			}
		})
	}
}
113
tools/inflector/inflector.go
Normal file
@@ -0,0 +1,113 @@
package inflector

import (
	"regexp"
	"strings"
	"unicode"
)

var columnifyRemoveRegex = regexp.MustCompile(`[^\w\.\*\-\_\@\#]+`)
var snakecaseSplitRegex = regexp.MustCompile(`[\W_]+`)

// UcFirst converts the first character of a string into uppercase.
func UcFirst(str string) string {
	if str == "" {
		return ""
	}

	s := []rune(str)

	return string(unicode.ToUpper(s[0])) + string(s[1:])
}

// Columnify strips invalid db identifier characters.
func Columnify(str string) string {
	return columnifyRemoveRegex.ReplaceAllString(str, "")
}

// Sentenize converts and normalizes a string into a sentence.
func Sentenize(str string) string {
	str = strings.TrimSpace(str)
	if str == "" {
		return ""
	}

	str = UcFirst(str)

	lastChar := str[len(str)-1:]
	if lastChar != "." && lastChar != "?" && lastChar != "!" {
		return str + "."
	}

	return str
}

// Sanitize sanitizes `str` by removing all characters satisfying `removePattern`.
// Returns an error if the pattern is not a valid regex string.
func Sanitize(str string, removePattern string) (string, error) {
	exp, err := regexp.Compile(removePattern)
	if err != nil {
		return "", err
	}

	return exp.ReplaceAllString(str, ""), nil
}

// Snakecase removes all non-word characters and converts any English text into snake case.
// "ABBREVIATIONS" are preserved, e.g. "myTestDB" will become "my_test_db".
func Snakecase(str string) string {
	var result strings.Builder

	// split at any non word character and underscore
	words := snakecaseSplitRegex.Split(str, -1)

	for _, word := range words {
		if word == "" {
			continue
		}

		if result.Len() > 0 {
			result.WriteString("_")
		}

		for i, c := range word {
			if unicode.IsUpper(c) && i > 0 &&
				// is not a following uppercase character
				!unicode.IsUpper(rune(word[i-1])) {
				result.WriteString("_")
			}

			result.WriteRune(c)
		}
	}

	return strings.ToLower(result.String())
}

// Camelize converts the provided string to its "CamelCased" version
// (non alphanumeric characters are removed).
//
// For example:
//
//	inflector.Camelize("send_email") // "SendEmail"
func Camelize(str string) string {
	var result strings.Builder

	var isPrevSpecial bool

	for _, c := range str {
		if !unicode.IsLetter(c) && !unicode.IsNumber(c) {
			isPrevSpecial = true
			continue
		}

		if isPrevSpecial || result.Len() == 0 {
			isPrevSpecial = false
			result.WriteRune(unicode.ToUpper(c))
		} else {
			result.WriteRune(c)
		}
	}

	return result.String()
}
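A minimal usage sketch of the inflector helpers above; the expected outputs in the comments mirror the test cases in the file that follows.

package main

import (
	"fmt"

	"github.com/pocketbase/pocketbase/tools/inflector"
)

func main() {
	fmt.Println(inflector.UcFirst("hello world"))   // "Hello world"
	fmt.Println(inflector.Columnify("@test!abc"))   // "@testabc"
	fmt.Println(inflector.Sentenize("hello world")) // "Hello world."
	fmt.Println(inflector.Snakecase("HelloWorld"))  // "hello_world"
	fmt.Println(inflector.Camelize("send_email"))   // "SendEmail"
}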
175
tools/inflector/inflector_test.go
Normal file
@@ -0,0 +1,175 @@
package inflector_test

import (
	"fmt"
	"testing"

	"github.com/pocketbase/pocketbase/tools/inflector"
)

func TestUcFirst(t *testing.T) {
	scenarios := []struct {
		val      string
		expected string
	}{
		{"", ""},
		{" ", " "},
		{"Test", "Test"},
		{"test", "Test"},
		{"test test2", "Test test2"},
	}

	for i, s := range scenarios {
		t.Run(fmt.Sprintf("%d_%#v", i, s.val), func(t *testing.T) {
			result := inflector.UcFirst(s.val)
			if result != s.expected {
				t.Fatalf("Expected %q, got %q", s.expected, result)
			}
		})
	}
}

func TestColumnify(t *testing.T) {
	scenarios := []struct {
		val      string
		expected string
	}{
		{"", ""},
		{" ", ""},
		{"123", "123"},
		{"Test.", "Test."},
		{" test ", "test"},
		{"test1.test2", "test1.test2"},
		{"@test!abc", "@testabc"},
		{"#test?abc", "#testabc"},
		{"123test(123)#", "123test123#"},
		{"test1--test2", "test1--test2"},
	}

	for i, s := range scenarios {
		t.Run(fmt.Sprintf("%d_%#v", i, s.val), func(t *testing.T) {
			result := inflector.Columnify(s.val)
			if result != s.expected {
				t.Fatalf("Expected %q, got %q", s.expected, result)
			}
		})
	}
}

func TestSentenize(t *testing.T) {
	scenarios := []struct {
		val      string
		expected string
	}{
		{"", ""},
		{" ", ""},
		{".", "."},
		{"?", "?"},
		{"!", "!"},
		{"Test", "Test."},
		{" test ", "Test."},
		{"hello world", "Hello world."},
		{"hello world.", "Hello world."},
		{"hello world!", "Hello world!"},
		{"hello world?", "Hello world?"},
	}

	for i, s := range scenarios {
		t.Run(fmt.Sprintf("%d_%#v", i, s.val), func(t *testing.T) {
			result := inflector.Sentenize(s.val)
			if result != s.expected {
				t.Fatalf("Expected %q, got %q", s.expected, result)
			}
		})
	}
}

func TestSanitize(t *testing.T) {
	scenarios := []struct {
		val       string
		pattern   string
		expected  string
		expectErr bool
	}{
		{"", ``, "", false},
		{" ", ``, " ", false},
		{" ", ` `, "", false},
		{"", `[A-Z]`, "", false},
		{"abcABC", `[A-Z]`, "abc", false},
		{"abcABC", `[A-Z`, "", true}, // invalid pattern
	}

	for i, s := range scenarios {
		t.Run(fmt.Sprintf("%d_%#v", i, s.val), func(t *testing.T) {
			result, err := inflector.Sanitize(s.val, s.pattern)
			hasErr := err != nil

			if s.expectErr != hasErr {
				t.Fatalf("Expected hasErr %v, got %v (%v)", s.expectErr, hasErr, err)
			}

			if result != s.expected {
				t.Fatalf("Expected %q, got %q", s.expected, result)
			}
		})
	}
}

func TestSnakecase(t *testing.T) {
	scenarios := []struct {
		val      string
		expected string
	}{
		{"", ""},
		{" ", ""},
		{"!@#$%^", ""},
		{"...", ""},
		{"_", ""},
		{"John Doe", "john_doe"},
		{"John_Doe", "john_doe"},
		{".a!b@c#d$e%123. ", "a_b_c_d_e_123"},
		{"HelloWorld", "hello_world"},
		{"HelloWorld1HelloWorld2", "hello_world1_hello_world2"},
		{"TEST", "test"},
		{"testABR", "test_abr"},
	}

	for i, s := range scenarios {
		t.Run(fmt.Sprintf("%d_%#v", i, s.val), func(t *testing.T) {
			result := inflector.Snakecase(s.val)
			if result != s.expected {
				t.Fatalf("Expected %q, got %q", s.expected, result)
			}
		})
	}
}

func TestCamelize(t *testing.T) {
	scenarios := []struct {
		val      string
		expected string
	}{
		{"", ""},
		{" ", ""},
		{"Test", "Test"},
		{"test", "Test"},
		{"testTest2", "TestTest2"},
		{"TestTest2", "TestTest2"},
		{"test test2", "TestTest2"},
		{"test-test2", "TestTest2"},
		{"test'test2", "TestTest2"},
		{"test1test2", "Test1test2"},
		{"1test-test2", "1testTest2"},
		{"123", "123"},
		{"123a", "123a"},
	}

	for i, s := range scenarios {
		t.Run(fmt.Sprintf("%d_%#v", i, s.val), func(t *testing.T) {
			result := inflector.Camelize(s.val)
			if result != s.expected {
				t.Fatalf("Expected %q, got %q", s.expected, result)
			}
		})
	}
}
89
tools/inflector/singularize.go
Normal file
@@ -0,0 +1,89 @@
package inflector

import (
	"log"
	"regexp"

	"github.com/pocketbase/pocketbase/tools/store"
)

var compiledPatterns = store.New[string, *regexp.Regexp](nil)

// note: the patterns are extracted from popular Ruby/PHP/Node.js inflector packages
var singularRules = []struct {
	pattern     string // lazily compiled
	replacement string
}{
	{"(?i)([nrlm]ese|deer|fish|sheep|measles|ois|pox|media|ss)$", "${1}"},
	{"(?i)^(sea[- ]bass)$", "${1}"},
	{"(?i)(s)tatuses$", "${1}tatus"},
	{"(?i)(f)eet$", "${1}oot"},
	{"(?i)(t)eeth$", "${1}ooth"},
	{"(?i)^(.*)(menu)s$", "${1}${2}"},
	{"(?i)(quiz)zes$", "${1}"},
	{"(?i)(matr)ices$", "${1}ix"},
	{"(?i)(vert|ind)ices$", "${1}ex"},
	{"(?i)^(ox)en", "${1}"},
	{"(?i)(alias)es$", "${1}"},
	{"(?i)(alumn|bacill|cact|foc|fung|nucle|radi|stimul|syllab|termin|viri?)i$", "${1}us"},
	{"(?i)([ftw]ax)es", "${1}"},
	{"(?i)(cris|ax|test)es$", "${1}is"},
	{"(?i)(shoe)s$", "${1}"},
	{"(?i)(o)es$", "${1}"},
	{"(?i)ouses$", "ouse"},
	{"(?i)([^a])uses$", "${1}us"},
	{"(?i)([m|l])ice$", "${1}ouse"},
	{"(?i)(x|ch|ss|sh)es$", "${1}"},
	{"(?i)(m)ovies$", "${1}ovie"},
	{"(?i)(s)eries$", "${1}eries"},
	{"(?i)([^aeiouy]|qu)ies$", "${1}y"},
	{"(?i)([lr])ves$", "${1}f"},
	{"(?i)(tive)s$", "${1}"},
	{"(?i)(hive)s$", "${1}"},
	{"(?i)(drive)s$", "${1}"},
	{"(?i)([^fo])ves$", "${1}fe"},
	{"(?i)(^analy)ses$", "${1}sis"},
	{"(?i)(analy|diagno|^ba|(p)arenthe|(p)rogno|(s)ynop|(t)he)ses$", "${1}${2}sis"},
	{"(?i)([ti])a$", "${1}um"},
	{"(?i)(p)eople$", "${1}erson"},
	{"(?i)(m)en$", "${1}an"},
	{"(?i)(c)hildren$", "${1}hild"},
	{"(?i)(n)ews$", "${1}ews"},
	{"(?i)(n)etherlands$", "${1}etherlands"},
	{"(?i)eaus$", "eau"},
	{"(?i)(currenc)ies$", "${1}y"},
	{"(?i)^(.*us)$", "${1}"},
	{"(?i)s$", ""},
}

// Singularize converts the specified word into its singular version.
//
// For example:
//
//	inflector.Singularize("people") // "person"
func Singularize(word string) string {
	if word == "" {
		return ""
	}

	for _, rule := range singularRules {
		re := compiledPatterns.GetOrSet(rule.pattern, func() *regexp.Regexp {
			re, err := regexp.Compile(rule.pattern)
			if err != nil {
				return nil
			}
			return re
		})
		if re == nil {
			// log only for debug purposes
			log.Println("[Singularize] failed to retrieve/compile rule pattern " + rule.pattern)
			continue
		}

		if re.MatchString(word) {
			return re.ReplaceAllString(word, rule.replacement)
		}
	}

	return word
}
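A quick sketch of Singularize; the results follow from the rule table above (the first matching rule wins) and match the test cases in the file that follows.

package main

import (
	"fmt"

	"github.com/pocketbase/pocketbase/tools/inflector"
)

func main() {
	fmt.Println(inflector.Singularize("people"))   // "person"
	fmt.Println(inflector.Singularize("Matrices")) // "Matrix"
	fmt.Println(inflector.Singularize("news"))     // "news" (uncountable)
}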
76
tools/inflector/singularize_test.go
Normal file
@@ -0,0 +1,76 @@
package inflector_test

import (
	"testing"

	"github.com/pocketbase/pocketbase/tools/inflector"
)

func TestSingularize(t *testing.T) {
	scenarios := []struct {
		word     string
		expected string
	}{
		{"abcnese", "abcnese"},
		{"deer", "deer"},
		{"sheep", "sheep"},
		{"measles", "measles"},
		{"pox", "pox"},
		{"media", "media"},
		{"bliss", "bliss"},
		{"sea-bass", "sea-bass"},
		{"Statuses", "Status"},
		{"Feet", "Foot"},
		{"Teeth", "Tooth"},
		{"abcmenus", "abcmenu"},
		{"Quizzes", "Quiz"},
		{"Matrices", "Matrix"},
		{"Vertices", "Vertex"},
		{"Indices", "Index"},
		{"Aliases", "Alias"},
		{"Alumni", "Alumnus"},
		{"Bacilli", "Bacillus"},
		{"Cacti", "Cactus"},
		{"Fungi", "Fungus"},
		{"Nuclei", "Nucleus"},
		{"Radii", "Radius"},
		{"Stimuli", "Stimulus"},
		{"Syllabi", "Syllabus"},
		{"Termini", "Terminus"},
		{"Viri", "Virus"},
		{"Faxes", "Fax"},
		{"Crises", "Crisis"},
		{"Axes", "Axis"},
		{"Shoes", "Shoe"},
		{"abcoes", "abco"},
		{"Houses", "House"},
		{"Mice", "Mouse"},
		{"abcxes", "abcx"},
		{"Movies", "Movie"},
		{"Series", "Series"},
		{"abcquies", "abcquy"},
		{"Relatives", "Relative"},
		{"Drives", "Drive"},
		{"aardwolves", "aardwolf"},
		{"Analyses", "Analysis"},
		{"Diagnoses", "Diagnosis"},
		{"People", "Person"},
		{"Men", "Man"},
		{"Children", "Child"},
		{"News", "News"},
		{"Netherlands", "Netherlands"},
		{"Tableaus", "Tableau"},
		{"Currencies", "Currency"},
		{"abcs", "abc"},
		{"abc", "abc"},
	}

	for _, s := range scenarios {
		t.Run(s.word, func(t *testing.T) {
			result := inflector.Singularize(s.word)
			if result != s.expected {
				t.Fatalf("Expected %q, got %q", s.expected, result)
			}
		})
	}
}
163
tools/list/list.go
Normal file
@@ -0,0 +1,163 @@
package list

import (
	"encoding/json"
	"regexp"
	"strings"

	"github.com/pocketbase/pocketbase/tools/store"
	"github.com/spf13/cast"
)

var cachedPatterns = store.New[string, *regexp.Regexp](nil)

// SubtractSlice returns a new slice with only the "base" elements
// that don't exist in "subtract".
func SubtractSlice[T comparable](base []T, subtract []T) []T {
	var result = make([]T, 0, len(base))

	for _, b := range base {
		if !ExistInSlice(b, subtract) {
			result = append(result, b)
		}
	}

	return result
}

// ExistInSlice checks whether a comparable element exists in a slice of the same type.
func ExistInSlice[T comparable](item T, list []T) bool {
	for _, v := range list {
		if v == item {
			return true
		}
	}

	return false
}

// ExistInSliceWithRegex checks whether a string exists in a slice
// either by direct match, or by a regular expression (eg. `^\w+$`).
//
// Note: Only list items starting with '^' and ending with '$' are treated as regular expressions!
func ExistInSliceWithRegex(str string, list []string) bool {
	for _, field := range list {
		isRegex := strings.HasPrefix(field, "^") && strings.HasSuffix(field, "$")

		if !isRegex {
			// check for direct match
			if str == field {
				return true
			}
			continue
		}

		// check for regex match
		pattern := cachedPatterns.Get(field)
		if pattern == nil {
			var err error
			pattern, err = regexp.Compile(field)
			if err != nil {
				continue
			}
			// "cache" the pattern to avoid compiling it every time
			// (the limit size is arbitrary and it is there to prevent the cache from growing too big)
			//
			// @todo consider replacing with TTL or LRU type cache
			cachedPatterns.SetIfLessThanLimit(field, pattern, 500)
		}

		if pattern != nil && pattern.MatchString(str) {
			return true
		}
	}

	return false
}

// ToInterfaceSlice converts a generic slice to slice of interfaces.
func ToInterfaceSlice[T any](list []T) []any {
	result := make([]any, len(list))

	for i := range list {
		result[i] = list[i]
	}

	return result
}

// NonzeroUniques returns only the nonzero unique values from a slice.
func NonzeroUniques[T comparable](list []T) []T {
	result := make([]T, 0, len(list))
	existMap := make(map[T]struct{}, len(list))

	var zeroVal T

	for _, val := range list {
		if val == zeroVal {
			continue
		}
		if _, ok := existMap[val]; ok {
			continue
		}
		existMap[val] = struct{}{}
		result = append(result, val)
	}

	return result
}

// ToUniqueStringSlice casts `value` to a slice of non-zero unique strings.
func ToUniqueStringSlice(value any) (result []string) {
	switch val := value.(type) {
	case nil:
		// nothing to cast
	case []string:
		result = val
	case string:
		if val == "" {
			break
		}

		// check if it is a json encoded array of strings
		if strings.Contains(val, "[") {
			if err := json.Unmarshal([]byte(val), &result); err != nil {
				// not a json array, just add the string as single array element
				result = append(result, val)
			}
		} else {
			// just add the string as single array element
			result = append(result, val)
		}
	case json.Marshaler: // eg. JSONArray
		raw, _ := val.MarshalJSON()
		_ = json.Unmarshal(raw, &result)
	default:
		result = cast.ToStringSlice(value)
	}

	return NonzeroUniques(result)
}

// ToChunks splits list into chunks.
//
// Zero or negative chunkSize argument is normalized to 1.
//
// See https://go.dev/wiki/SliceTricks#batching-with-minimal-allocation.
func ToChunks[T any](list []T, chunkSize int) [][]T {
	if chunkSize <= 0 {
		chunkSize = 1
	}

	chunks := make([][]T, 0, (len(list)+chunkSize-1)/chunkSize)

	if len(list) == 0 {
		return chunks
	}

	for chunkSize < len(list) {
		list, chunks = list[chunkSize:], append(chunks, list[0:chunkSize:chunkSize])
	}

	return append(chunks, list)
}
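A short sketch exercising a few of the list helpers above; the commented values match the table-driven tests in the file that follows. Note in particular that regex list items must start with '^' and end with '$' to be treated as patterns.

package main

import (
	"fmt"

	"github.com/pocketbase/pocketbase/tools/list"
)

func main() {
	// regex entry (starts with '^' and ends with '$')
	fmt.Println(list.ExistInSliceWithRegex("!?@", []string{`^\W+$`, "test"})) // true

	// zero values and duplicates are dropped
	fmt.Println(list.NonzeroUniques([]string{"a", "", "a", "b"})) // [a b]

	// batching with minimal allocation
	fmt.Println(list.ToChunks([]int{1, 2, 3, 4, 5}, 2)) // [[1 2] [3 4] [5]]
}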
310
tools/list/list_test.go
Normal file
@@ -0,0 +1,310 @@
package list_test

import (
	"encoding/json"
	"fmt"
	"testing"

	"github.com/pocketbase/pocketbase/tools/list"
	"github.com/pocketbase/pocketbase/tools/types"
)

func TestSubtractSliceString(t *testing.T) {
	scenarios := []struct {
		base     []string
		subtract []string
		expected string
	}{
		{
			[]string{},
			[]string{},
			`[]`,
		},
		{
			[]string{},
			[]string{"1", "2", "3", "4"},
			`[]`,
		},
		{
			[]string{"1", "2", "3", "4"},
			[]string{},
			`["1","2","3","4"]`,
		},
		{
			[]string{"1", "2", "3", "4"},
			[]string{"1", "2", "3", "4"},
			`[]`,
		},
		{
			[]string{"1", "2", "3", "4", "7"},
			[]string{"2", "4", "5", "6"},
			`["1","3","7"]`,
		},
	}

	for i, s := range scenarios {
		t.Run(fmt.Sprintf("%d_%s", i, s.expected), func(t *testing.T) {
			result := list.SubtractSlice(s.base, s.subtract)

			raw, err := json.Marshal(result)
			if err != nil {
				t.Fatalf("Failed to serialize: %v", err)
			}

			strResult := string(raw)

			if strResult != s.expected {
				t.Fatalf("Expected %v, got %v", s.expected, strResult)
			}
		})
	}
}

func TestSubtractSliceInt(t *testing.T) {
	scenarios := []struct {
		base     []int
		subtract []int
		expected string
	}{
		{
			[]int{},
			[]int{},
			`[]`,
		},
		{
			[]int{},
			[]int{1, 2, 3, 4},
			`[]`,
		},
		{
			[]int{1, 2, 3, 4},
			[]int{},
			`[1,2,3,4]`,
		},
		{
			[]int{1, 2, 3, 4},
			[]int{1, 2, 3, 4},
			`[]`,
		},
		{
			[]int{1, 2, 3, 4, 7},
			[]int{2, 4, 5, 6},
			`[1,3,7]`,
		},
	}

	for i, s := range scenarios {
		t.Run(fmt.Sprintf("%d_%s", i, s.expected), func(t *testing.T) {
			result := list.SubtractSlice(s.base, s.subtract)

			raw, err := json.Marshal(result)
			if err != nil {
				t.Fatalf("Failed to serialize: %v", err)
			}

			strResult := string(raw)

			if strResult != s.expected {
				t.Fatalf("Expected %v, got %v", s.expected, strResult)
			}
		})
	}
}

func TestExistInSliceString(t *testing.T) {
	scenarios := []struct {
		item     string
		list     []string
		expected bool
	}{
		{"", []string{""}, true},
		{"", []string{"1", "2", "test 123"}, false},
		{"test", []string{}, false},
		{"test", []string{"TEST"}, false},
		{"test", []string{"1", "2", "test 123"}, false},
		{"test", []string{"1", "2", "test"}, true},
	}

	for i, s := range scenarios {
		t.Run(fmt.Sprintf("%d_%s", i, s.item), func(t *testing.T) {
			result := list.ExistInSlice(s.item, s.list)
			if result != s.expected {
				t.Fatalf("Expected %v, got %v", s.expected, result)
			}
		})
	}
}

func TestExistInSliceInt(t *testing.T) {
	scenarios := []struct {
		item     int
		list     []int
		expected bool
	}{
		{0, []int{}, false},
		{0, []int{0}, true},
		{4, []int{1, 2, 3}, false},
		{1, []int{1, 2, 3}, true},
		{-1, []int{0, 1, 2, 3}, false},
		{-1, []int{0, -1, -2, -3, -4}, true},
	}

	for i, s := range scenarios {
		t.Run(fmt.Sprintf("%d_%d", i, s.item), func(t *testing.T) {
			result := list.ExistInSlice(s.item, s.list)
			if result != s.expected {
				t.Fatalf("Expected %v, got %v", s.expected, result)
			}
		})
	}
}

func TestExistInSliceWithRegex(t *testing.T) {
	scenarios := []struct {
		item     string
		list     []string
		expected bool
	}{
		{"", []string{``}, true},
		{"", []string{`^\W+$`}, false},
		{" ", []string{`^\W+$`}, true},
		{"test", []string{`^\invalid[+$`}, false}, // invalid regex
		{"test", []string{`^\W+$`, "test"}, true},
		{`^\W+$`, []string{`^\W+$`, "test"}, false}, // direct match shouldn't work for this case
		{`\W+$`, []string{`\W+$`, "test"}, true},    // direct match should work for this case because it is not an actual supported pattern format
		{"!?@", []string{`\W+$`, "test"}, false},    // the method requires the pattern elems to start with '^'
		{"!?@", []string{`^\W+`, "test"}, false},    // the method requires the pattern elems to end with '$'
		{"!?@", []string{`^\W+$`, "test"}, true},
		{"!?@test", []string{`^\W+$`, "test"}, false},
	}

	for i, s := range scenarios {
		t.Run(fmt.Sprintf("%d_%s", i, s.item), func(t *testing.T) {
			result := list.ExistInSliceWithRegex(s.item, s.list)
			if result != s.expected {
				t.Fatalf("Expected %v, got %v", s.expected, result)
			}
		})
	}
}

func TestToInterfaceSlice(t *testing.T) {
	scenarios := []struct {
		items []string
	}{
		{[]string{}},
		{[]string{""}},
		{[]string{"1", "test"}},
		{[]string{"test1", "test1", "test2", "test3"}},
	}

	for i, s := range scenarios {
		t.Run(fmt.Sprintf("%d_%#v", i, s.items), func(t *testing.T) {
			result := list.ToInterfaceSlice(s.items)

			if len(result) != len(s.items) {
				t.Fatalf("Expected length %d, got %d", len(s.items), len(result))
			}

			for j, v := range result {
				if v != s.items[j] {
					t.Fatalf("Result list item doesn't match with the original list item, got %v VS %v", v, s.items[j])
				}
			}
		})
	}
}

func TestNonzeroUniquesString(t *testing.T) {
	scenarios := []struct {
		items    []string
		expected []string
	}{
		{[]string{}, []string{}},
		{[]string{""}, []string{}},
		{[]string{"1", "test"}, []string{"1", "test"}},
		{[]string{"test1", "", "test2", "Test2", "test1", "test3"}, []string{"test1", "test2", "Test2", "test3"}},
	}

	for i, s := range scenarios {
		t.Run(fmt.Sprintf("%d_%#v", i, s.items), func(t *testing.T) {
			result := list.NonzeroUniques(s.items)

			if len(result) != len(s.expected) {
				t.Fatalf("Expected length %d, got %d", len(s.expected), len(result))
			}

			for j, v := range result {
				if v != s.expected[j] {
					t.Fatalf("Result list item doesn't match with the expected list item, got %v VS %v", v, s.expected[j])
				}
			}
		})
	}
}

func TestToUniqueStringSlice(t *testing.T) {
	scenarios := []struct {
		value    any
		expected []string
	}{
		{nil, []string{}},
		{"", []string{}},
		{[]any{}, []string{}},
		{[]int{}, []string{}},
		{"test", []string{"test"}},
		{[]int{1, 2, 3}, []string{"1", "2", "3"}},
		{[]any{0, 1, "test", ""}, []string{"0", "1", "test"}},
		{[]string{"test1", "test2", "test1"}, []string{"test1", "test2"}},
		{`["test1", "test2", "test2"]`, []string{"test1", "test2"}},
		{types.JSONArray[string]{"test1", "test2", "test1"}, []string{"test1", "test2"}},
	}

	for i, s := range scenarios {
		t.Run(fmt.Sprintf("%d_%#v", i, s.value), func(t *testing.T) {
			result := list.ToUniqueStringSlice(s.value)

			if len(result) != len(s.expected) {
				t.Fatalf("Expected length %d, got %d", len(s.expected), len(result))
			}

			for j, v := range result {
				if v != s.expected[j] {
					t.Fatalf("Result list item doesn't match with the expected list item, got %v vs %v", v, s.expected[j])
				}
			}
		})
	}
}

func TestToChunks(t *testing.T) {
	scenarios := []struct {
		items     []any
		chunkSize int
		expected  string
	}{
		{nil, 2, "[]"},
		{[]any{}, 2, "[]"},
		{[]any{1, 2, 3, 4}, -1, "[[1],[2],[3],[4]]"},
		{[]any{1, 2, 3, 4}, 0, "[[1],[2],[3],[4]]"},
		{[]any{1, 2, 3, 4}, 2, "[[1,2],[3,4]]"},
		{[]any{1, 2, 3, 4, 5}, 2, "[[1,2],[3,4],[5]]"},
		{[]any{1, 2, 3, 4, 5}, 10, "[[1,2,3,4,5]]"},
	}

	for i, s := range scenarios {
		t.Run(fmt.Sprintf("%d_%#v", i, s.items), func(t *testing.T) {
			result := list.ToChunks(s.items, s.chunkSize)

			raw, err := json.Marshal(result)
			if err != nil {
				t.Fatal(err)
			}
			rawStr := string(raw)

			if rawStr != s.expected {
				t.Fatalf("Expected %v, got %v", s.expected, rawStr)
			}
		})
	}
}
325
tools/logger/batch_handler.go
Normal file
@@ -0,0 +1,325 @@
package logger

import (
	"context"
	"encoding/json"
	"errors"
	"log/slog"
	"sync"

	validation "github.com/go-ozzo/ozzo-validation/v4"
	"github.com/pocketbase/pocketbase/tools/types"
)

var _ slog.Handler = (*BatchHandler)(nil)

// BatchOptions are options for the BatchHandler.
type BatchOptions struct {
	// WriteFunc processes the batched logs.
	WriteFunc func(ctx context.Context, logs []*Log) error

	// BeforeAddFunc is an optional function that is invoked every time
	// before a new log is added to the batch queue.
	//
	// Return false to skip adding the log into the batch queue.
	BeforeAddFunc func(ctx context.Context, log *Log) bool

	// Level reports the minimum level to log.
	// Records with lower levels are discarded.
	// If nil, the Handler uses [slog.LevelInfo].
	Level slog.Leveler

	// BatchSize specifies how many logs to accumulate before calling WriteFunc.
	// If not set or 0, it falls back to 100 by default.
	BatchSize int
}

// NewBatchHandler creates a slog compatible handler that writes JSON
// logs in batches (100 by default), using the given options.
//
// Panics if [BatchOptions.WriteFunc] is not defined.
//
// Example:
//
//	l := slog.New(logger.NewBatchHandler(logger.BatchOptions{
//		WriteFunc: func(ctx context.Context, logs []*Log) error {
//			for _, l := range logs {
//				fmt.Println(l.Level, l.Message, l.Data)
//			}
//			return nil
//		},
//	}))
//	l.Info("Example message", "title", "lorem ipsum")
func NewBatchHandler(options BatchOptions) *BatchHandler {
	h := &BatchHandler{
		mux:     &sync.Mutex{},
		options: &options,
	}

	if h.options.WriteFunc == nil {
		panic("options.WriteFunc must be set")
	}

	if h.options.Level == nil {
		h.options.Level = slog.LevelInfo
	}

	if h.options.BatchSize == 0 {
		h.options.BatchSize = 100
	}

	h.logs = make([]*Log, 0, h.options.BatchSize)

	return h
}

// BatchHandler is a slog handler that writes records in batches.
//
// The log records attributes are formatted in JSON.
//
// Requires the [BatchOptions.WriteFunc] option to be defined.
type BatchHandler struct {
	mux     *sync.Mutex
	parent  *BatchHandler
	options *BatchOptions
	group   string
	attrs   []slog.Attr
	logs    []*Log
}

// Enabled reports whether the handler handles records at the given level.
//
// The handler ignores records whose level is lower.
func (h *BatchHandler) Enabled(ctx context.Context, level slog.Level) bool {
	return level >= h.options.Level.Level()
}

// WithGroup returns a new BatchHandler that starts a group.
//
// All logger attributes will be resolved under the specified group name.
func (h *BatchHandler) WithGroup(name string) slog.Handler {
	if name == "" {
		return h
	}

	return &BatchHandler{
		parent:  h,
		mux:     h.mux,
		options: h.options,
		group:   name,
	}
}

// WithAttrs returns a new BatchHandler loaded with the specified attributes.
func (h *BatchHandler) WithAttrs(attrs []slog.Attr) slog.Handler {
	if len(attrs) == 0 {
		return h
	}

	return &BatchHandler{
		parent:  h,
		mux:     h.mux,
		options: h.options,
		attrs:   attrs,
	}
}

// Handle formats the slog.Record argument as a JSON object and adds it
// to the batch queue.
//
// If the batch queue threshold has been reached, the WriteFunc option
// is invoked with the accumulated logs which in turn will reset the batch queue.
func (h *BatchHandler) Handle(ctx context.Context, r slog.Record) error {
	if h.group != "" {
		h.mux.Lock()
		attrs := make([]any, 0, len(h.attrs)+r.NumAttrs())
		for _, a := range h.attrs {
			attrs = append(attrs, a)
		}
		h.mux.Unlock()

		r.Attrs(func(a slog.Attr) bool {
			attrs = append(attrs, a)
			return true
		})

		r = slog.NewRecord(r.Time, r.Level, r.Message, r.PC)
		r.AddAttrs(slog.Group(h.group, attrs...))
	} else if len(h.attrs) > 0 {
		r = r.Clone()

		h.mux.Lock()
		r.AddAttrs(h.attrs...)
		h.mux.Unlock()
	}

	if h.parent != nil {
		return h.parent.Handle(ctx, r)
	}

	data := make(map[string]any, r.NumAttrs())

	r.Attrs(func(a slog.Attr) bool {
		if err := h.resolveAttr(data, a); err != nil {
			return false
		}
		return true
	})

	log := &Log{
		Time:    r.Time,
		Level:   r.Level,
		Message: r.Message,
		Data:    types.JSONMap[any](data),
	}

	if h.options.BeforeAddFunc != nil && !h.options.BeforeAddFunc(ctx, log) {
		return nil
	}

	h.mux.Lock()
	h.logs = append(h.logs, log)
	totalLogs := len(h.logs)
	h.mux.Unlock()

	if totalLogs >= h.options.BatchSize {
		if err := h.WriteAll(ctx); err != nil {
			return err
		}
	}

	return nil
}

// SetLevel updates the handler options level to the specified one.
func (h *BatchHandler) SetLevel(level slog.Level) {
	h.mux.Lock()
	h.options.Level = level
	h.mux.Unlock()
}

// WriteAll writes all accumulated Log entries and resets the batch queue.
func (h *BatchHandler) WriteAll(ctx context.Context) error {
	if h.parent != nil {
		// recursively invoke the parent handler since the
		// top level one is holding the logs queue.
		return h.parent.WriteAll(ctx)
	}

	h.mux.Lock()

	totalLogs := len(h.logs)

	// no logs to write
	if totalLogs == 0 {
		h.mux.Unlock()
		return nil
	}

	// create a copy of the logs slice to prevent blocking during write
	logs := make([]*Log, totalLogs)
	copy(logs, h.logs)
	h.logs = h.logs[:0] // reset

	h.mux.Unlock()

	return h.options.WriteFunc(ctx, logs)
}

// resolveAttr writes attr into data.
func (h *BatchHandler) resolveAttr(data map[string]any, attr slog.Attr) error {
	// ensure that the attr value is resolved before doing anything else
	attr.Value = attr.Value.Resolve()

	if attr.Equal(slog.Attr{}) {
		return nil // ignore empty attrs
	}

	switch attr.Value.Kind() {
	case slog.KindGroup:
		attrs := attr.Value.Group()
		if len(attrs) == 0 {
			return nil // ignore empty groups
		}

		// create a submap to wrap the resolved group attributes
		groupData := make(map[string]any, len(attrs))

		for _, subAttr := range attrs {
			h.resolveAttr(groupData, subAttr)
		}

		if len(groupData) > 0 {
			data[attr.Key] = groupData
		}
	default:
		data[attr.Key] = normalizeLogAttrValue(attr.Value.Any())
	}

	return nil
}

func normalizeLogAttrValue(rawAttrValue any) any {
	switch attrV := rawAttrValue.(type) {
	case validation.Errors:
		out := make(map[string]any, len(attrV))
		for k, v := range attrV {
			out[k] = serializeLogError(v)
		}
		return out
	case map[string]validation.Error:
		out := make(map[string]any, len(attrV))
		for k, v := range attrV {
			out[k] = serializeLogError(v)
		}
		return out
	case map[string]error:
		out := make(map[string]any, len(attrV))
		for k, v := range attrV {
			out[k] = serializeLogError(v)
		}
		return out
	case map[string]any:
		out := make(map[string]any, len(attrV))
		for k, v := range attrV {
			switch vv := v.(type) {
			case error:
				out[k] = serializeLogError(vv)
			default:
				out[k] = normalizeLogAttrValue(vv)
			}
		}
		return out
	case error:
		// check for wrapped validation.Errors
		var ve validation.Errors
		if errors.As(attrV, &ve) {
			out := make(map[string]any, len(ve))
			for k, v := range ve {
				out[k] = serializeLogError(v)
			}
			return map[string]any{
				"data": out,
				"raw":  serializeLogError(attrV),
			}
		}
		return serializeLogError(attrV)
	default:
		return attrV
	}
}

func serializeLogError(err error) any {
	if err == nil {
		return nil
	}

	// prioritize a json structured format (e.g. validation.Errors)
	jsonErr, ok := err.(json.Marshaler)
	if ok {
		return jsonErr
	}

	// fallback to its original string representation
	return err.Error()
}
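A minimal, self-contained sketch wiring BatchHandler into slog; the BatchSize of 2 and the console WriteFunc are illustrative choices, not part of the file above.

package main

import (
	"context"
	"fmt"
	"log/slog"

	"github.com/pocketbase/pocketbase/tools/logger"
)

func main() {
	h := logger.NewBatchHandler(logger.BatchOptions{
		BatchSize: 2, // flush after every 2 records (illustrative)
		WriteFunc: func(ctx context.Context, logs []*logger.Log) error {
			for _, l := range logs {
				fmt.Println(l.Level, l.Message, l.Data)
			}
			return nil
		},
	})

	l := slog.New(h)
	l.Info("first", "a", 1)
	l.Info("second", "b", 2) // reaching BatchSize triggers WriteFunc

	// flush anything still queued before exiting
	_ = h.WriteAll(context.Background())
}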
354
tools/logger/batch_handler_test.go
Normal file
@@ -0,0 +1,354 @@
package logger

import (
	"context"
	"errors"
	"fmt"
	"log/slog"
	"testing"
	"time"

	validation "github.com/go-ozzo/ozzo-validation/v4"
)

func TestNewBatchHandlerPanic(t *testing.T) {
	defer func() {
		if r := recover(); r == nil {
			t.Errorf("Expected to panic.")
		}
	}()

	NewBatchHandler(BatchOptions{})
}

func TestNewBatchHandlerDefaults(t *testing.T) {
	h := NewBatchHandler(BatchOptions{
		WriteFunc: func(ctx context.Context, logs []*Log) error {
			return nil
		},
	})

	if h.options.BatchSize != 100 {
		t.Fatalf("Expected default BatchSize %d, got %d", 100, h.options.BatchSize)
	}

	if h.options.Level != slog.LevelInfo {
		t.Fatalf("Expected default Level Info, got %v", h.options.Level)
	}

	if h.options.BeforeAddFunc != nil {
		t.Fatal("Expected default BeforeAddFunc to be nil")
	}

	if h.options.WriteFunc == nil {
		t.Fatal("Expected default WriteFunc to be set")
	}

	if h.group != "" {
		t.Fatalf("Expected empty group, got %s", h.group)
	}

	if len(h.attrs) != 0 {
		t.Fatalf("Expected empty attrs, got %v", h.attrs)
	}

	if len(h.logs) != 0 {
		t.Fatalf("Expected empty logs queue, got %v", h.logs)
	}
}

func TestBatchHandlerEnabled(t *testing.T) {
	h := NewBatchHandler(BatchOptions{
		Level: slog.LevelWarn,
		WriteFunc: func(ctx context.Context, logs []*Log) error {
			return nil
		},
	})

	l := slog.New(h)

	scenarios := []struct {
		level    slog.Level
		expected bool
	}{
		{slog.LevelDebug, false},
		{slog.LevelInfo, false},
		{slog.LevelWarn, true},
		{slog.LevelError, true},
	}

	for _, s := range scenarios {
		t.Run(fmt.Sprintf("Level %v", s.level), func(t *testing.T) {
			result := l.Enabled(context.Background(), s.level)

			if result != s.expected {
				t.Fatalf("Expected %v, got %v", s.expected, result)
			}
		})
	}
}

func TestBatchHandlerSetLevel(t *testing.T) {
	h := NewBatchHandler(BatchOptions{
		Level: slog.LevelWarn,
		WriteFunc: func(ctx context.Context, logs []*Log) error {
			return nil
		},
	})

	if h.options.Level != slog.LevelWarn {
		t.Fatalf("Expected the initial level to be %d, got %d", slog.LevelWarn, h.options.Level)
	}

	h.SetLevel(slog.LevelDebug)

	if h.options.Level != slog.LevelDebug {
		t.Fatalf("Expected the updated level to be %d, got %d", slog.LevelDebug, h.options.Level)
	}
}

func TestBatchHandlerWithAttrsAndWithGroup(t *testing.T) {
	h0 := NewBatchHandler(BatchOptions{
		WriteFunc: func(ctx context.Context, logs []*Log) error {
			return nil
		},
	})

	h1 := h0.WithAttrs([]slog.Attr{slog.Int("test1", 1)}).(*BatchHandler)
	h2 := h1.WithGroup("h2_group").(*BatchHandler)
	h3 := h2.WithAttrs([]slog.Attr{slog.Int("test2", 2)}).(*BatchHandler)

	scenarios := []struct {
		name           string
		handler        *BatchHandler
		expectedParent *BatchHandler
		expectedGroup  string
		expectedAttrs  int
	}{
		{
			"h0",
			h0,
			nil,
			"",
			0,
		},
		{
			"h1",
			h1,
			h0,
			"",
			1,
		},
		{
			"h2",
			h2,
			h1,
			"h2_group",
			0,
		},
		{
			"h3",
			h3,
			h2,
			"",
			1,
		},
	}

	for _, s := range scenarios {
		t.Run(s.name, func(t *testing.T) {
			if s.handler.group != s.expectedGroup {
				t.Fatalf("Expected group %q, got %q", s.expectedGroup, s.handler.group)
			}

			if s.handler.parent != s.expectedParent {
				t.Fatalf("Expected parent %v, got %v", s.expectedParent, s.handler.parent)
			}

			if totalAttrs := len(s.handler.attrs); totalAttrs != s.expectedAttrs {
				t.Fatalf("Expected %d attrs, got %d", s.expectedAttrs, totalAttrs)
			}
		})
	}
}

func TestBatchHandlerHandle(t *testing.T) {
	ctx := context.Background()

	beforeLogs := []*Log{}
	writeLogs := []*Log{}

	h := NewBatchHandler(BatchOptions{
		BatchSize: 3,
		BeforeAddFunc: func(_ context.Context, log *Log) bool {
			beforeLogs = append(beforeLogs, log)

			// skip test2 log
			return log.Message != "test2"
		},
		WriteFunc: func(_ context.Context, logs []*Log) error {
			writeLogs = logs
			return nil
		},
	})

	h.Handle(ctx, slog.NewRecord(time.Now(), slog.LevelInfo, "test1", 0))
	h.Handle(ctx, slog.NewRecord(time.Now(), slog.LevelInfo, "test2", 0))
	h.Handle(ctx, slog.NewRecord(time.Now(), slog.LevelInfo, "test3", 0))

	// no batch write
	{
		checkLogMessages([]string{"test1", "test2", "test3"}, beforeLogs, t)

		checkLogMessages([]string{"test1", "test3"}, h.logs, t)

		// should be empty because no batch write has happened yet
		if totalWriteLogs := len(writeLogs); totalWriteLogs != 0 {
			t.Fatalf("Expected %d writeLogs, got %d", 0, totalWriteLogs)
		}
	}

	// add one more log to trigger the batch write
	{
		h.Handle(ctx, slog.NewRecord(time.Now(), slog.LevelInfo, "test4", 0))

		// should be empty after the batch write
		checkLogMessages([]string{}, h.logs, t)

		checkLogMessages([]string{"test1", "test3", "test4"}, writeLogs, t)
	}
}

func TestBatchHandlerWriteAll(t *testing.T) {
	ctx := context.Background()

	beforeLogs := []*Log{}
	writeLogs := []*Log{}

	h := NewBatchHandler(BatchOptions{
		BatchSize: 3,
		BeforeAddFunc: func(_ context.Context, log *Log) bool {
			beforeLogs = append(beforeLogs, log)
			return true
		},
		WriteFunc: func(_ context.Context, logs []*Log) error {
			writeLogs = logs
			return nil
		},
	})

	h.Handle(ctx, slog.NewRecord(time.Now(), slog.LevelInfo, "test1", 0))
	h.Handle(ctx, slog.NewRecord(time.Now(), slog.LevelInfo, "test2", 0))

	checkLogMessages([]string{"test1", "test2"}, beforeLogs, t)
	checkLogMessages([]string{"test1", "test2"}, h.logs, t)
	checkLogMessages([]string{}, writeLogs, t) // empty because the batch size hasn't been reached

	// force trigger the batch write
	h.WriteAll(ctx)

	checkLogMessages([]string{"test1", "test2"}, beforeLogs, t)
	checkLogMessages([]string{}, h.logs, t) // reset
	checkLogMessages([]string{"test1", "test2"}, writeLogs, t)
}

func TestBatchHandlerAttrsFormat(t *testing.T) {
	ctx := context.Background()

	beforeLogs := []*Log{}

	h0 := NewBatchHandler(BatchOptions{
		BeforeAddFunc: func(_ context.Context, log *Log) bool {
			beforeLogs = append(beforeLogs, log)
			return true
		},
		WriteFunc: func(_ context.Context, logs []*Log) error {
			return nil
		},
	})

	h1 := h0.WithAttrs([]slog.Attr{slog.Int("a", 1), slog.String("b", "123")})

	h2 := h1.WithGroup("sub").WithAttrs([]slog.Attr{
		slog.Int("c", 3),
		slog.Any("d", map[string]any{"d.1": 1}),
		slog.Any("e", errors.New("example error")),
	})

	record := slog.NewRecord(time.Now(), slog.LevelInfo, "hello", 0)
	record.AddAttrs(slog.String("name", "test"))

	h0.Handle(ctx, record)
	h1.Handle(ctx, record)
	h2.Handle(ctx, record)

	// errors serialization checks
	errorsRecord := slog.NewRecord(time.Now(), slog.LevelError, "details", 0)
	errorsRecord.Add("validation.Errors", validation.Errors{
		"a": validation.NewError("validation_code", "validation_message"),
		"b": errors.New("plain"),
	})
	errorsRecord.Add("wrapped_validation.Errors", fmt.Errorf("wrapped: %w", validation.Errors{
		"a": validation.NewError("validation_code", "validation_message"),
		"b": errors.New("plain"),
	}))
	errorsRecord.Add("map[string]any", map[string]any{
		"a": validation.NewError("validation_code", "validation_message"),
		"b": errors.New("plain"),
		"c": "test_any",
		"d": map[string]any{
			"nestedA": validation.NewError("nested_code", "nested_message"),
			"nestedB": errors.New("nested_plain"),
		},
	})
	errorsRecord.Add("map[string]error", map[string]error{
		"a": validation.NewError("validation_code", "validation_message"),
		"b": errors.New("plain"),
	})
	errorsRecord.Add("map[string]validation.Error", map[string]validation.Error{
		"a": validation.NewError("validation_code", "validation_message"),
		"b": nil,
	})
	errorsRecord.Add("plain_error", errors.New("plain"))
	h0.Handle(ctx, errorsRecord)

	expected := []string{
		`{"name":"test"}`,
		`{"a":1,"b":"123","name":"test"}`,
		`{"a":1,"b":"123","sub":{"c":3,"d":{"d.1":1},"e":"example error","name":"test"}}`,
		`{"map[string]any":{"a":"validation_message","b":"plain","c":"test_any","d":{"nestedA":"nested_message","nestedB":"nested_plain"}},"map[string]error":{"a":"validation_message","b":"plain"},"map[string]validation.Error":{"a":"validation_message","b":null},"plain_error":"plain","validation.Errors":{"a":"validation_message","b":"plain"},"wrapped_validation.Errors":{"data":{"a":"validation_message","b":"plain"},"raw":"wrapped: a: validation_message; b: plain."}}`,
	}

	if len(beforeLogs) != len(expected) {
		t.Fatalf("Expected %d logs, got %d", len(expected), len(beforeLogs))
	}

	for i, data := range expected {
		t.Run(fmt.Sprintf("log handler %d", i), func(t *testing.T) {
			log := beforeLogs[i]
			raw, _ := log.Data.MarshalJSON()
			if string(raw) != data {
				t.Fatalf("Expected \n%s \ngot \n%s", data, raw)
			}
		})
	}
}

func checkLogMessages(expected []string, logs []*Log, t *testing.T) {
	if len(logs) != len(expected) {
		t.Fatalf("Expected %d batched logs, got %d (expected: %v)", len(expected), len(logs), expected)
	}

	for _, message := range expected {
		exists := false
		for _, l := range logs {
			if l.Message == message {
				exists = true
				continue
			}
		}
		if !exists {
			t.Fatalf("Missing %q log message", message)
		}
	}
}
17
tools/logger/log.go
Normal file
@@ -0,0 +1,17 @@
package logger

import (
	"log/slog"
	"time"

	"github.com/pocketbase/pocketbase/tools/types"
)

// Log is similar to [slog.Record] but contains the log attributes as
// a preformatted JSON map.
type Log struct {
	Time    time.Time
	Data    types.JSONMap[any]
	Message string
	Level   slog.Level
}
118
tools/mailer/html2text.go
Normal file
@@ -0,0 +1,118 @@
package mailer

import (
	"regexp"
	"strings"

	"github.com/pocketbase/pocketbase/tools/list"
	"golang.org/x/net/html"
)

var whitespaceRegex = regexp.MustCompile(`\s+`)

var tagsToSkip = []string{
	"style", "script", "iframe", "applet", "object", "svg", "img",
	"button", "form", "textarea", "input", "select", "option", "template",
}

var inlineTags = []string{
	"a", "span", "small", "strike", "strong",
	"sub", "sup", "em", "b", "u", "i",
}

// Very rudimentary auto HTML to Text mail body converter.
//
// Caveats:
// - This method doesn't check for correctness of the HTML document.
// - Links will be converted to "[text](url)" format.
// - List items (<li>) are prefixed with "- ".
// - Indentation is stripped (both tabs and spaces).
// - Trailing spaces are preserved.
// - Multiple consecutive newlines are collapsed as one unless multiple <br> tags are used.
func html2Text(htmlDocument string) (string, error) {
	doc, err := html.Parse(strings.NewReader(htmlDocument))
	if err != nil {
		return "", err
	}

	var builder strings.Builder
	var canAddNewLine bool

	// see https://pkg.go.dev/golang.org/x/net/html#Parse
	var f func(*html.Node, *strings.Builder)
	f = func(n *html.Node, activeBuilder *strings.Builder) {
		isLink := n.Type == html.ElementNode && n.Data == "a"

		if isLink {
			var linkBuilder strings.Builder
			activeBuilder = &linkBuilder
		} else if activeBuilder == nil {
			activeBuilder = &builder
		}

		switch n.Type {
		case html.TextNode:
			txt := whitespaceRegex.ReplaceAllString(n.Data, " ")

			// the previous node ended with a new line so it is safe to trim the indentation
			if !canAddNewLine {
				txt = strings.TrimLeft(txt, " ")
			}

			if txt != "" {
				activeBuilder.WriteString(txt)
				canAddNewLine = true
			}
		case html.ElementNode:
			if n.Data == "br" {
				// always write new lines when <br> tag is used
				activeBuilder.WriteString("\r\n")
				canAddNewLine = false
			} else if canAddNewLine && !list.ExistInSlice(n.Data, inlineTags) {
				activeBuilder.WriteString("\r\n")
				canAddNewLine = false
			}

			// prefix list items with dash
			if n.Data == "li" {
				activeBuilder.WriteString("- ")
			}
		}

		for c := n.FirstChild; c != nil; c = c.NextSibling {
			if c.Type != html.ElementNode || !list.ExistInSlice(c.Data, tagsToSkip) {
				f(c, activeBuilder)
			}
		}

		// format links as [label](href)
		if isLink {
			linkTxt := strings.TrimSpace(activeBuilder.String())
			if linkTxt == "" {
				linkTxt = "LINK"
			}

			builder.WriteString("[")
			builder.WriteString(linkTxt)
			builder.WriteString("]")

			// link href attr extraction
			for _, a := range n.Attr {
				if a.Key == "href" {
					if a.Val != "" {
						builder.WriteString("(")
						builder.WriteString(a.Val)
						builder.WriteString(")")
					}
					break
				}
			}

			activeBuilder.Reset()
		}
	}

	f(doc, &builder)

	return strings.TrimSpace(builder.String()), nil
}
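A hypothetical in-package sketch of html2Text (the function is unexported, so it can only be called from within the mailer package); the commented result is derived from the conversion rules documented above.

package mailer

import "fmt"

// demoHTML2Text is an illustrative helper, not part of the package above.
func demoHTML2Text() {
	txt, _ := html2Text(`<p>Hello</p><p><a href="https://example.com">Verify</a></p>`)

	// links become "[label](href)" and block elements start on a new line:
	fmt.Println(txt) // "Hello\r\n[Verify](https://example.com)"
}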
131
tools/mailer/html2text_test.go
Normal file
@@ -0,0 +1,131 @@
package mailer

import (
	"testing"
)

func TestHTML2Text(t *testing.T) {
	scenarios := []struct {
		html     string
		expected string
	}{
		{
			"",
			"",
		},
		{
			"ab c",
			"ab c",
		},
		{
			"<!-- test html comment -->",
			"",
		},
		{
			"<!-- test html comment --> a ",
			"a",
		},
		{
			"<span>a</span>b<span>c</span>",
			"abc",
		},
		{
			`<a href="a/b/c">test</span>`,
			"[test](a/b/c)",
		},
		{
			`<a href="">test</span>`,
			"[test]",
		},
		{
			"<span>a</span> <span>b</span>",
			"a b",
		},
		{
			"<span>a</span> b <span>c</span>",
			"a b c",
		},
		{
			"<span>a</span> b <div>c</div>",
			"a b \r\nc",
		},
		{
			`
			<!DOCTYPE html PUBLIC "-//W3C//DTD XHTML 1.0 Strict//EN" "http://www.w3.org/TR/xhtml1/DTD/xhtml1-strict.dtd">
			<html xmlns="http://www.w3.org/1999/xhtml">
			<head>
				<meta http-equiv="Content-Type" content="text/html; charset=UTF-8" />
				<meta name="viewport" content="width=device-width,initial-scale=1" />
				<style>
					body {
						padding: 0;
					}
				</style>
			</head>
			<body>
				<!-- test html comment -->
				<style>
					body {
						padding: 0;
					}
				</style>
				<div class="wrapper">
					<div class="content">
						<p>Lorem ipsum</p>
						<p>Dolor sit amet</p>
						<p>
							<a href="a/b/c">Verify</a>
						</p>
						<br>
						<p>
							<a href="a/b/c"><strong>Verify2.1</strong> <strong>Verify2.2</strong></a>
						</p>
						<br>
						<br>
						<div>
							<div>
								<div>
									<ul>
										<li>ul.test1</li>
										<li>ul.test2</li>
										<li>ul.test3</li>
									</ul>
									<ol>
										<li>ol.test1</li>
										<li>ol.test2</li>
										<li>ol.test3</li>
									</ol>
								</div>
							</div>
						</div>
						<select>
							<option>Option 1</option>
							<option>Option 2</option>
						</select>
						<textarea>test</textarea>
						<input type="text" value="test" />
						<button>test</button>
						<p>
							Thanks,<br/>
							PocketBase team
						</p>
					</div>
				</div>
			</body>
			</html>
			`,
			"Lorem ipsum \r\nDolor sit amet \r\n[Verify](a/b/c) \r\n[Verify2.1 Verify2.2](a/b/c) \r\n\r\n- ul.test1 \r\n- ul.test2 \r\n- ul.test3 \r\n- ol.test1 \r\n- ol.test2 \r\n- ol.test3 \r\nThanks,\r\nPocketBase team",
		},
	}

	for i, s := range scenarios {
		result, err := html2Text(s.html)
		if err != nil {
			t.Errorf("(%d) Unexpected error %v", i, err)
		}

		if result != s.expected {
			t.Errorf("(%d) Expected \n(%q)\n%v,\n\ngot:\n\n(%q)\n%v", i, s.expected, s.expected, result, result)
		}
	}
}
72
tools/mailer/mailer.go
Normal file
@@ -0,0 +1,72 @@
package mailer

import (
	"bytes"
	"io"
	"net/mail"

	"github.com/gabriel-vasile/mimetype"
	"github.com/pocketbase/pocketbase/tools/hook"
)

// Message defines a generic email message struct.
type Message struct {
	From              mail.Address         `json:"from"`
	To                []mail.Address       `json:"to"`
	Bcc               []mail.Address       `json:"bcc"`
	Cc                []mail.Address       `json:"cc"`
	Subject           string               `json:"subject"`
	HTML              string               `json:"html"`
	Text              string               `json:"text"`
	Headers           map[string]string    `json:"headers"`
	Attachments       map[string]io.Reader `json:"attachments"`
	InlineAttachments map[string]io.Reader `json:"inlineAttachments"`
}

// Mailer defines a base mail client interface.
type Mailer interface {
	// Send sends an email with the provided Message.
	Send(message *Message) error
}

// SendInterceptor is an optional interface for registering mail send hooks.
type SendInterceptor interface {
	OnSend() *hook.Hook[*SendEvent]
}

type SendEvent struct {
	hook.Event
	Message *Message
}

// addressesToStrings converts the provided addresses to a list of serialized RFC 5322 strings.
//
// To export only the email part of mail.Address, you can set withName to false.
func addressesToStrings(addresses []mail.Address, withName bool) []string {
	result := make([]string, len(addresses))

	for i, addr := range addresses {
		if withName && addr.Name != "" {
			result[i] = addr.String()
		} else {
			// keep only the email part to avoid wrapping in angle-brackets
			result[i] = addr.Address
		}
	}

	return result
}

// detectReaderMimeType reads the first few bytes of the reader to detect its MIME type.
//
// Returns a new combined reader from the partial read + the remaining of the original reader.
func detectReaderMimeType(r io.Reader) (io.Reader, string, error) {
	readCopy := new(bytes.Buffer)

	mime, err := mimetype.DetectReader(io.TeeReader(r, readCopy))
	if err != nil {
		return nil, "", err
	}

	return io.MultiReader(readCopy, r), mime.String(), nil
}
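A hedged sketch of a custom Mailer that also implements SendInterceptor, mirroring the Trigger pattern used by the Sendmail client further below; consoleMailer and its console output are illustrative, not part of the package.

package main

import (
	"fmt"
	"net/mail"

	"github.com/pocketbase/pocketbase/tools/hook"
	"github.com/pocketbase/pocketbase/tools/mailer"
)

// consoleMailer is a hypothetical Mailer that just prints the message.
type consoleMailer struct {
	onSend *hook.Hook[*mailer.SendEvent]
}

// OnSend implements the optional mailer.SendInterceptor interface.
func (c *consoleMailer) OnSend() *hook.Hook[*mailer.SendEvent] {
	if c.onSend == nil {
		c.onSend = &hook.Hook[*mailer.SendEvent]{}
	}
	return c.onSend
}

// Send implements the mailer.Mailer interface.
func (c *consoleMailer) Send(m *mailer.Message) error {
	return c.OnSend().Trigger(&mailer.SendEvent{Message: m}, func(e *mailer.SendEvent) error {
		fmt.Println("to:", e.Message.To, "subject:", e.Message.Subject)
		return nil
	})
}

func main() {
	var client mailer.Mailer = &consoleMailer{}
	_ = client.Send(&mailer.Message{
		To:      []mail.Address{{Address: "user@example.com"}},
		Subject: "Hello",
	})
}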
77
tools/mailer/mailer_test.go
Normal file
77
tools/mailer/mailer_test.go
Normal file
|
@ -0,0 +1,77 @@
|
|||
package mailer

import (
	"fmt"
	"io"
	"net/mail"
	"strings"
	"testing"
)

func TestAddressesToStrings(t *testing.T) {
	t.Parallel()

	scenarios := []struct {
		withName  bool
		addresses []mail.Address
		expected  []string
	}{
		{
			true,
			[]mail.Address{{Name: "John Doe", Address: "test1@example.com"}, {Name: "Jane Doe", Address: "test2@example.com"}},
			[]string{`"John Doe" <test1@example.com>`, `"Jane Doe" <test2@example.com>`},
		},
		{
			true,
			[]mail.Address{{Name: "John Doe", Address: "test1@example.com"}, {Address: "test2@example.com"}},
			[]string{`"John Doe" <test1@example.com>`, `test2@example.com`},
		},
		{
			false,
			[]mail.Address{{Name: "John Doe", Address: "test1@example.com"}, {Name: "Jane Doe", Address: "test2@example.com"}},
			[]string{`test1@example.com`, `test2@example.com`},
		},
	}

	for _, s := range scenarios {
		t.Run(fmt.Sprintf("%v_%v", s.withName, s.addresses), func(t *testing.T) {
			result := addressesToStrings(s.addresses, s.withName)

			if len(s.expected) != len(result) {
				t.Fatalf("Expected\n%v\ngot\n%v", s.expected, result)
			}

			for k, v := range s.expected {
				if v != result[k] {
					t.Fatalf("Expected %d address %q, got %q", k, v, result[k])
				}
			}
		})
	}
}

func TestDetectReaderMimeType(t *testing.T) {
	t.Parallel()

	str := "#!/bin/node\n" + strings.Repeat("a", 10000) // ensure that it is large enough to remain after the signature sniffing

	r, mime, err := detectReaderMimeType(strings.NewReader(str))
	if err != nil {
		t.Fatal(err)
	}

	expectedMime := "text/javascript"
	if mime != expectedMime {
		t.Fatalf("Expected mime %q, got %q", expectedMime, mime)
	}

	raw, err := io.ReadAll(r)
	if err != nil {
		t.Fatal(err)
	}
	rawStr := string(raw)

	if rawStr != str {
		t.Fatalf("Expected content\n%s\ngot\n%s", str, rawStr)
	}
}
99
tools/mailer/sendmail.go
Normal file
@@ -0,0 +1,99 @@
package mailer

import (
	"bytes"
	"errors"
	"mime"
	"net/http"
	"os/exec"
	"strings"

	"github.com/pocketbase/pocketbase/tools/hook"
)

var _ Mailer = (*Sendmail)(nil)

// Sendmail implements [mailer.Mailer] interface and defines a mail
// client that sends emails via the "sendmail" *nix command.
//
// This client is usually recommended only for development and testing.
type Sendmail struct {
	onSend *hook.Hook[*SendEvent]
}

// OnSend implements [mailer.SendInterceptor] interface.
func (c *Sendmail) OnSend() *hook.Hook[*SendEvent] {
	if c.onSend == nil {
		c.onSend = &hook.Hook[*SendEvent]{}
	}
	return c.onSend
}

// Send implements [mailer.Mailer] interface.
func (c *Sendmail) Send(m *Message) error {
	if c.onSend != nil {
		return c.onSend.Trigger(&SendEvent{Message: m}, func(e *SendEvent) error {
			return c.send(e.Message)
		})
	}

	return c.send(m)
}

func (c *Sendmail) send(m *Message) error {
	toAddresses := addressesToStrings(m.To, false)

	headers := make(http.Header)
	headers.Set("Subject", mime.QEncoding.Encode("utf-8", m.Subject))
	headers.Set("From", m.From.String())
	headers.Set("Content-Type", "text/html; charset=UTF-8")
	headers.Set("To", strings.Join(toAddresses, ","))

	cmdPath, err := findSendmailPath()
	if err != nil {
		return err
	}

	var buffer bytes.Buffer

	// write
	// ---
	if err := headers.Write(&buffer); err != nil {
		return err
	}
	if _, err := buffer.Write([]byte("\r\n")); err != nil {
		return err
	}
	if m.HTML != "" {
		if _, err := buffer.Write([]byte(m.HTML)); err != nil {
			return err
		}
	} else {
		if _, err := buffer.Write([]byte(m.Text)); err != nil {
			return err
		}
	}
	// ---

	sendmail := exec.Command(cmdPath, strings.Join(toAddresses, ","))
	sendmail.Stdin = &buffer

	return sendmail.Run()
}

func findSendmailPath() (string, error) {
	options := []string{
		"/usr/sbin/sendmail",
		"/usr/bin/sendmail",
		"sendmail",
	}

	for _, option := range options {
		path, err := exec.LookPath(option)
		if err == nil {
			return path, err
		}
	}

	return "", errors.New("failed to locate a sendmail executable path")
}
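A short sketch of intercepting sends via the OnSend hook. This assumes the tools/hook package's BindFunc registration and e.Next() chaining conventions, so treat it as illustrative rather than canonical:

package main

import (
	"log"
	"net/mail"

	"github.com/pocketbase/pocketbase/tools/mailer"
)

func main() {
	client := &mailer.Sendmail{}

	// register an interceptor that runs around every Send call
	client.OnSend().BindFunc(func(e *mailer.SendEvent) error {
		e.Message.Subject = "[dev] " + e.Message.Subject // mutate before the actual send
		return e.Next()                                  // continue the hook chain
	})

	err := client.Send(&mailer.Message{
		From:    mail.Address{Address: "noreply@example.com"},  // placeholder
		To:      []mail.Address{{Address: "user@example.com"}}, // placeholder
		Subject: "Hello",
		HTML:    "<p>Hello!</p>",
	})
	if err != nil {
		log.Fatal(err)
	}
}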
211
tools/mailer/smtp.go
Normal file
@@ -0,0 +1,211 @@
package mailer

import (
	"errors"
	"fmt"
	"net/smtp"
	"strings"

	"github.com/domodwyer/mailyak/v3"
	"github.com/pocketbase/pocketbase/tools/hook"
	"github.com/pocketbase/pocketbase/tools/security"
)

var _ Mailer = (*SMTPClient)(nil)

const (
	SMTPAuthPlain = "PLAIN"
	SMTPAuthLogin = "LOGIN"
)

// SMTPClient defines an SMTP mail client structure that implements
// `mailer.Mailer` interface.
type SMTPClient struct {
	onSend *hook.Hook[*SendEvent]

	TLS      bool
	Port     int
	Host     string
	Username string
	Password string

	// SMTP auth method to use
	// (if not explicitly set, defaults to "PLAIN")
	AuthMethod string

	// LocalName is an optional domain name used for the EHLO/HELO exchange
	// (if not explicitly set, defaults to "localhost").
	//
	// This is required only by some SMTP servers, such as Gmail SMTP-relay.
	LocalName string
}

// OnSend implements [mailer.SendInterceptor] interface.
func (c *SMTPClient) OnSend() *hook.Hook[*SendEvent] {
	if c.onSend == nil {
		c.onSend = &hook.Hook[*SendEvent]{}
	}
	return c.onSend
}

// Send implements [mailer.Mailer] interface.
func (c *SMTPClient) Send(m *Message) error {
	if c.onSend != nil {
		return c.onSend.Trigger(&SendEvent{Message: m}, func(e *SendEvent) error {
			return c.send(e.Message)
		})
	}

	return c.send(m)
}

func (c *SMTPClient) send(m *Message) error {
	var smtpAuth smtp.Auth
	if c.Username != "" || c.Password != "" {
		switch c.AuthMethod {
		case SMTPAuthLogin:
			smtpAuth = &smtpLoginAuth{c.Username, c.Password}
		default:
			smtpAuth = smtp.PlainAuth("", c.Username, c.Password, c.Host)
		}
	}

	// create mail instance
	var yak *mailyak.MailYak
	if c.TLS {
		var tlsErr error
		yak, tlsErr = mailyak.NewWithTLS(fmt.Sprintf("%s:%d", c.Host, c.Port), smtpAuth, nil)
		if tlsErr != nil {
			return tlsErr
		}
	} else {
		yak = mailyak.New(fmt.Sprintf("%s:%d", c.Host, c.Port), smtpAuth)
	}

	if c.LocalName != "" {
		yak.LocalName(c.LocalName)
	}

	if m.From.Name != "" {
		yak.FromName(m.From.Name)
	}
	yak.From(m.From.Address)
	yak.Subject(m.Subject)
	yak.HTML().Set(m.HTML)

	if m.Text == "" {
		// try to generate a plain text version of the HTML
		if plain, err := html2Text(m.HTML); err == nil {
			yak.Plain().Set(plain)
		}
	} else {
		yak.Plain().Set(m.Text)
	}

	if len(m.To) > 0 {
		yak.To(addressesToStrings(m.To, true)...)
	}

	if len(m.Bcc) > 0 {
		yak.Bcc(addressesToStrings(m.Bcc, true)...)
	}

	if len(m.Cc) > 0 {
		yak.Cc(addressesToStrings(m.Cc, true)...)
	}

	// add regular attachments (if any)
	for name, data := range m.Attachments {
		r, mime, err := detectReaderMimeType(data)
		if err != nil {
			return err
		}
		yak.AttachWithMimeType(name, r, mime)
	}

	// add inline attachments (if any)
	for name, data := range m.InlineAttachments {
		r, mime, err := detectReaderMimeType(data)
		if err != nil {
			return err
		}
		yak.AttachInlineWithMimeType(name, r, mime)
	}

	// add custom headers (if any)
	var hasMessageId bool
	for k, v := range m.Headers {
		if strings.EqualFold(k, "Message-ID") {
			hasMessageId = true
		}
		yak.AddHeader(k, v)
	}
	if !hasMessageId {
		// add a default message id if missing
		fromParts := strings.Split(m.From.Address, "@")
		if len(fromParts) == 2 {
			yak.AddHeader("Message-ID", fmt.Sprintf("<%s@%s>",
				security.PseudorandomString(15),
				fromParts[1],
			))
		}
	}

	return yak.Send()
}

// -------------------------------------------------------------------
// AUTH LOGIN
// -------------------------------------------------------------------

var _ smtp.Auth = (*smtpLoginAuth)(nil)

// smtpLoginAuth defines an AUTH that implements the LOGIN authentication mechanism.
//
// AUTH LOGIN is obsolete[1] but some mail services like Outlook require it[2].
//
// NB!
// It will only send the credentials if the connection is using TLS or is connected to localhost.
// Otherwise authentication will fail with an error, without sending the credentials.
//
// [1]: https://github.com/golang/go/issues/40817
// [2]: https://support.microsoft.com/en-us/office/outlook-com-no-longer-supports-auth-plain-authentication-07f7d5e9-1697-465f-84d2-4513d4ff0145?ui=en-us&rs=en-us&ad=us
type smtpLoginAuth struct {
	username, password string
}

// Start initializes an authentication with the server.
//
// It is part of the [smtp.Auth] interface.
func (a *smtpLoginAuth) Start(server *smtp.ServerInfo) (string, []byte, error) {
	// Must have TLS, or else localhost server.
	// Note: If TLS is not true, then we can't trust ANYTHING in ServerInfo.
	// In particular, it doesn't matter if the server advertises LOGIN auth.
	// That might just be the attacker saying
	// "it's ok, you can trust me with your password."
	if !server.TLS && !isLocalhost(server.Name) {
		return "", nil, errors.New("unencrypted connection")
	}

	return "LOGIN", nil, nil
}

// Next "continues" the auth process by feeding the server with the requested data.
//
// It is part of the [smtp.Auth] interface.
func (a *smtpLoginAuth) Next(fromServer []byte, more bool) ([]byte, error) {
	if more {
		switch strings.ToLower(string(fromServer)) {
		case "username:":
			return []byte(a.username), nil
		case "password:":
			return []byte(a.password), nil
		}
	}

	return nil, nil
}

func isLocalhost(name string) bool {
	return name == "localhost" || name == "127.0.0.1" || name == "::1"
}
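And a configuration sketch for the SMTP client, reusing the kind of Message shown earlier. The host, port, and credentials are placeholders; note that TLS: true makes the client dial an implicit TLS connection via mailyak.NewWithTLS, while the plain client may still upgrade opportunistically via STARTTLS:

package main

import (
	"log"
	"net/mail"

	"github.com/pocketbase/pocketbase/tools/mailer"
)

func main() {
	client := &mailer.SMTPClient{
		Host:       "smtp.example.com", // placeholder
		Port:       587,                // placeholder
		Username:   "user",             // placeholder
		Password:   "secret",           // placeholder
		TLS:        false,              // true -> implicit TLS connection
		AuthMethod: mailer.SMTPAuthPlain,
	}

	err := client.Send(&mailer.Message{
		From:    mail.Address{Name: "Example App", Address: "noreply@example.com"},
		To:      []mail.Address{{Address: "user@example.com"}},
		Subject: "Hello",
		Text:    "Hello!",
	})
	if err != nil {
		log.Fatal(err)
	}
}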
166
tools/mailer/smtp_test.go
Normal file
@@ -0,0 +1,166 @@
package mailer

import (
	"net/smtp"
	"testing"
)

func TestLoginAuthStart(t *testing.T) {
	auth := smtpLoginAuth{username: "test", password: "123456"}

	scenarios := []struct {
		name        string
		serverInfo  *smtp.ServerInfo
		expectError bool
	}{
		{
			"localhost without tls",
			&smtp.ServerInfo{TLS: false, Name: "localhost"},
			false,
		},
		{
			"localhost with tls",
			&smtp.ServerInfo{TLS: true, Name: "localhost"},
			false,
		},
		{
			"127.0.0.1 without tls",
			&smtp.ServerInfo{TLS: false, Name: "127.0.0.1"},
			false,
		},
		{
			"127.0.0.1 with tls",
			&smtp.ServerInfo{TLS: true, Name: "127.0.0.1"},
			false,
		},
		{
			"::1 without tls",
			&smtp.ServerInfo{TLS: false, Name: "::1"},
			false,
		},
		{
			"::1 with tls",
			&smtp.ServerInfo{TLS: true, Name: "::1"},
			false,
		},
		{
			"non-localhost without tls",
			&smtp.ServerInfo{TLS: false, Name: "example.com"},
			true,
		},
		{
			"non-localhost with tls",
			&smtp.ServerInfo{TLS: true, Name: "example.com"},
			false,
		},
	}

	for _, s := range scenarios {
		t.Run(s.name, func(t *testing.T) {
			method, resp, err := auth.Start(s.serverInfo)

			hasErr := err != nil
			if hasErr != s.expectError {
				t.Fatalf("Expected hasErr %v, got %v", s.expectError, hasErr)
			}

			if hasErr {
				return
			}

			if len(resp) != 0 {
				t.Fatalf("Expected empty data response, got %v", resp)
			}

			if method != "LOGIN" {
				t.Fatalf("Expected LOGIN, got %v", method)
			}
		})
	}
}

func TestLoginAuthNext(t *testing.T) {
	auth := smtpLoginAuth{username: "test", password: "123456"}

	{
		// example|false
		r1, err := auth.Next([]byte("example:"), false)
		if err != nil {
			t.Fatalf("[example|false] Unexpected error %v", err)
		}
		if len(r1) != 0 {
			t.Fatalf("[example|false] Expected empty part, got %v", r1)
		}

		// example|true
		r2, err := auth.Next([]byte("example:"), true)
		if err != nil {
			t.Fatalf("[example|true] Unexpected error %v", err)
		}
		if len(r2) != 0 {
			t.Fatalf("[example|true] Expected empty part, got %v", r2)
		}
	}

	// ---------------------------------------------------------------

	{
		// username:|false
		r1, err := auth.Next([]byte("username:"), false)
		if err != nil {
			t.Fatalf("[username|false] Unexpected error %v", err)
		}
		if len(r1) != 0 {
			t.Fatalf("[username|false] Expected empty part, got %v", r1)
		}

		// username:|true
		r2, err := auth.Next([]byte("username:"), true)
		if err != nil {
			t.Fatalf("[username|true] Unexpected error %v", err)
		}
		if str := string(r2); str != auth.username {
			t.Fatalf("[username|true] Expected %s, got %s", auth.username, str)
		}

		// uSeRnAmE:|true
		r3, err := auth.Next([]byte("uSeRnAmE:"), true)
		if err != nil {
			t.Fatalf("[uSeRnAmE|true] Unexpected error %v", err)
		}
		if str := string(r3); str != auth.username {
			t.Fatalf("[uSeRnAmE|true] Expected %s, got %s", auth.username, str)
		}
	}

	// ---------------------------------------------------------------

	{
		// password:|false
		r1, err := auth.Next([]byte("password:"), false)
		if err != nil {
			t.Fatalf("[password|false] Unexpected error %v", err)
		}
		if len(r1) != 0 {
			t.Fatalf("[password|false] Expected empty part, got %v", r1)
		}

		// password:|true
		r2, err := auth.Next([]byte("password:"), true)
		if err != nil {
			t.Fatalf("[password|true] Unexpected error %v", err)
		}
		if str := string(r2); str != auth.password {
			t.Fatalf("[password|true] Expected %s, got %s", auth.password, str)
		}

		// pAsSwOrD:|true
		r3, err := auth.Next([]byte("pAsSwOrD:"), true)
		if err != nil {
			t.Fatalf("[pAsSwOrD|true] Unexpected error %v", err)
		}
		if str := string(r3); str != auth.password {
			t.Fatalf("[pAsSwOrD|true] Expected %s, got %s", auth.password, str)
		}
	}
}
Some files were not shown because too many files have changed in this diff.