Adding upstream version 0.28.1.
Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
parent
88f1d47ab6
commit
e28c88ef14
933 changed files with 194711 additions and 0 deletions
62
tools/security/crypto.go
Normal file
62
tools/security/crypto.go
Normal file
|
@ -0,0 +1,62 @@
|
|||
package security
|
||||
|
||||
import (
|
||||
"crypto/hmac"
|
||||
"crypto/md5"
|
||||
"crypto/sha256"
|
||||
"crypto/sha512"
|
||||
"crypto/subtle"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// S256Challenge creates base64 encoded sha256 challenge string derived from code.
|
||||
// The padding of the result base64 string is stripped per [RFC 7636].
|
||||
//
|
||||
// [RFC 7636]: https://datatracker.ietf.org/doc/html/rfc7636#section-4.2
|
||||
func S256Challenge(code string) string {
|
||||
h := sha256.New()
|
||||
h.Write([]byte(code))
|
||||
return strings.TrimRight(base64.URLEncoding.EncodeToString(h.Sum(nil)), "=")
|
||||
}
|
||||
|
||||
// MD5 creates md5 hash from the provided plain text.
|
||||
func MD5(text string) string {
|
||||
h := md5.New()
|
||||
h.Write([]byte(text))
|
||||
return fmt.Sprintf("%x", h.Sum(nil))
|
||||
}
|
||||
|
||||
// SHA256 creates sha256 hash as defined in FIPS 180-4 from the provided text.
|
||||
func SHA256(text string) string {
|
||||
h := sha256.New()
|
||||
h.Write([]byte(text))
|
||||
return fmt.Sprintf("%x", h.Sum(nil))
|
||||
}
|
||||
|
||||
// SHA512 creates sha512 hash as defined in FIPS 180-4 from the provided text.
|
||||
func SHA512(text string) string {
|
||||
h := sha512.New()
|
||||
h.Write([]byte(text))
|
||||
return fmt.Sprintf("%x", h.Sum(nil))
|
||||
}
|
||||
|
||||
// HS256 creates a HMAC hash with sha256 digest algorithm.
|
||||
func HS256(text string, secret string) string {
|
||||
h := hmac.New(sha256.New, []byte(secret))
|
||||
h.Write([]byte(text))
|
||||
return fmt.Sprintf("%x", h.Sum(nil))
|
||||
}
|
||||
|
||||
// HS512 creates a HMAC hash with sha512 digest algorithm.
|
||||
func HS512(text string, secret string) string {
|
||||
h := hmac.New(sha512.New, []byte(secret))
|
||||
h.Write([]byte(text))
|
||||
return fmt.Sprintf("%x", h.Sum(nil))
|
||||
}
|
||||
|
||||
// Equal compares two hash strings for equality without leaking timing information.
|
||||
func Equal(hash1 string, hash2 string) bool {
|
||||
return subtle.ConstantTimeCompare([]byte(hash1), []byte(hash2)) == 1
|
||||
}
|
156
tools/security/crypto_test.go
Normal file
156
tools/security/crypto_test.go
Normal file
|
@ -0,0 +1,156 @@
|
|||
package security_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/pocketbase/pocketbase/tools/security"
|
||||
)
|
||||
|
||||
func TestS256Challenge(t *testing.T) {
|
||||
scenarios := []struct {
|
||||
code string
|
||||
expected string
|
||||
}{
|
||||
{"", "47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU"},
|
||||
{"123", "pmWkWSBCL51Bfkhn79xPuKBKHz__H6B-mY6G9_eieuM"},
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
t.Run(s.code, func(t *testing.T) {
|
||||
result := security.S256Challenge(s.code)
|
||||
|
||||
if result != s.expected {
|
||||
t.Fatalf("Expected %q, got %q", s.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMD5(t *testing.T) {
|
||||
scenarios := []struct {
|
||||
code string
|
||||
expected string
|
||||
}{
|
||||
{"", "d41d8cd98f00b204e9800998ecf8427e"},
|
||||
{"123", "202cb962ac59075b964b07152d234b70"},
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
t.Run(s.code, func(t *testing.T) {
|
||||
result := security.MD5(s.code)
|
||||
|
||||
if result != s.expected {
|
||||
t.Fatalf("Expected %v, got %v", s.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSHA256(t *testing.T) {
|
||||
scenarios := []struct {
|
||||
code string
|
||||
expected string
|
||||
}{
|
||||
{"", "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"},
|
||||
{"123", "a665a45920422f9d417e4867efdc4fb8a04a1f3fff1fa07e998e86f7f7a27ae3"},
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
t.Run(s.code, func(t *testing.T) {
|
||||
result := security.SHA256(s.code)
|
||||
|
||||
if result != s.expected {
|
||||
t.Fatalf("Expected %v, got %v", s.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSHA512(t *testing.T) {
|
||||
scenarios := []struct {
|
||||
code string
|
||||
expected string
|
||||
}{
|
||||
{"", "cf83e1357eefb8bdf1542850d66d8007d620e4050b5715dc83f4a921d36ce9ce47d0d13c5d85f2b0ff8318d2877eec2f63b931bd47417a81a538327af927da3e"},
|
||||
{"123", "3c9909afec25354d551dae21590bb26e38d53f2173b8d3dc3eee4c047e7ab1c1eb8b85103e3be7ba613b31bb5c9c36214dc9f14a42fd7a2fdb84856bca5c44c2"},
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
t.Run(s.code, func(t *testing.T) {
|
||||
result := security.SHA512(s.code)
|
||||
|
||||
if result != s.expected {
|
||||
t.Fatalf("Expected %v, got %v", s.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHS256(t *testing.T) {
|
||||
scenarios := []struct {
|
||||
text string
|
||||
secret string
|
||||
expected string
|
||||
}{
|
||||
{" ", "test", "9fb4e4a12d50728683a222b4fc466a69ee977332cfcdd6b9ebb44c7121dbd99f"},
|
||||
{" ", "test2", "d792417a504716e22805d940125ec12e68e8cb18fc84674703bd96c59f1e1228"},
|
||||
{"hello", "test", "f151ea24bda91a18e89b8bb5793ef324b2a02133cce15a28a719acbd2e58a986"},
|
||||
{"hello", "test2", "16436e8dcbf3d7b5b0455573b27e6372699beb5bfe94e6a2a371b14b4ae068f4"},
|
||||
}
|
||||
|
||||
for i, s := range scenarios {
|
||||
t.Run(fmt.Sprintf("%d-%s", i, s.text), func(t *testing.T) {
|
||||
result := security.HS256(s.text, s.secret)
|
||||
|
||||
if result != s.expected {
|
||||
t.Fatalf("Expected \n%v, \ngot \n%v", s.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHS512(t *testing.T) {
|
||||
scenarios := []struct {
|
||||
text string
|
||||
secret string
|
||||
expected string
|
||||
}{
|
||||
{" ", "test", "eb3bdb0352c95c38880c1f645fc7e1d1332644f938f50de0d73876e42d6f302e599bb526531ba79940e8b314369aaef3675322d8d851f9fc6ea9ed121286d196"},
|
||||
{" ", "test2", "8b69e84e9252af78ae8b1c4bed3c9f737f69a3df33064cfbefe76b36d19d1827285e543cdf066cdc8bd556cc0cd0e212d52e9c12a50cd16046181ff127f4cf7f"},
|
||||
{"hello", "test", "44f280e11103e295c26cd61dd1cdd8178b531b860466867c13b1c37a26b6389f8af110efbe0bb0717b9d9c87f6fe1c97b3b1690936578890e5669abf279fe7fd"},
|
||||
{"hello", "test2", "d7f10b1b66941b20817689b973ca9dfc971090e28cfb8becbddd6824569b323eca6a0cdf2c387aa41e15040007dca5a011dd4e4bb61cfd5011aa7354d866f6ef"},
|
||||
}
|
||||
|
||||
for i, s := range scenarios {
|
||||
t.Run(fmt.Sprintf("%d-%q", i, s.text), func(t *testing.T) {
|
||||
result := security.HS512(s.text, s.secret)
|
||||
|
||||
if result != s.expected {
|
||||
t.Fatalf("Expected \n%v, \ngot \n%v", s.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEqual(t *testing.T) {
|
||||
scenarios := []struct {
|
||||
hash1 string
|
||||
hash2 string
|
||||
expected bool
|
||||
}{
|
||||
{"", "", true},
|
||||
{"abc", "abd", false},
|
||||
{"abc", "abc", true},
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
t.Run(fmt.Sprintf("%qVS%q", s.hash1, s.hash2), func(t *testing.T) {
|
||||
result := security.Equal(s.hash1, s.hash2)
|
||||
|
||||
if result != s.expected {
|
||||
t.Fatalf("Expected %v, got %v", s.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
62
tools/security/encrypt.go
Normal file
62
tools/security/encrypt.go
Normal file
|
@ -0,0 +1,62 @@
|
|||
package security
|
||||
|
||||
import (
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
crand "crypto/rand"
|
||||
"encoding/base64"
|
||||
"io"
|
||||
)
|
||||
|
||||
// Encrypt encrypts "data" with the specified "key" (must be valid 32 char AES key).
|
||||
//
|
||||
// This method uses AES-256-GCM block cypher mode.
|
||||
func Encrypt(data []byte, key string) (string, error) {
|
||||
block, err := aes.NewCipher([]byte(key))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
gcm, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
nonce := make([]byte, gcm.NonceSize())
|
||||
|
||||
// populates the nonce with a cryptographically secure random sequence
|
||||
if _, err := io.ReadFull(crand.Reader, nonce); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
cipherByte := gcm.Seal(nonce, nonce, data, nil)
|
||||
|
||||
result := base64.StdEncoding.EncodeToString(cipherByte)
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// Decrypt decrypts encrypted text with key (must be valid 32 chars AES key).
|
||||
//
|
||||
// This method uses AES-256-GCM block cypher mode.
|
||||
func Decrypt(cipherText string, key string) ([]byte, error) {
|
||||
block, err := aes.NewCipher([]byte(key))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
gcm, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
nonceSize := gcm.NonceSize()
|
||||
|
||||
cipherByte, err := base64.StdEncoding.DecodeString(cipherText)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
nonce, cipherByteClean := cipherByte[:nonceSize], cipherByte[nonceSize:]
|
||||
return gcm.Open(nil, nonce, cipherByteClean, nil)
|
||||
}
|
80
tools/security/encrypt_test.go
Normal file
80
tools/security/encrypt_test.go
Normal file
|
@ -0,0 +1,80 @@
|
|||
package security_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/pocketbase/pocketbase/tools/security"
|
||||
)
|
||||
|
||||
func TestEncrypt(t *testing.T) {
|
||||
scenarios := []struct {
|
||||
data string
|
||||
key string
|
||||
expectError bool
|
||||
}{
|
||||
{"", "", true},
|
||||
{"123", "test", true}, // key must be valid 32 char aes string
|
||||
{"123", "abcdabcdabcdabcdabcdabcdabcdabcd", false},
|
||||
}
|
||||
|
||||
for i, s := range scenarios {
|
||||
t.Run(fmt.Sprintf("%d_%s", i, s.data), func(t *testing.T) {
|
||||
result, err := security.Encrypt([]byte(s.data), s.key)
|
||||
|
||||
hasErr := err != nil
|
||||
|
||||
if hasErr != s.expectError {
|
||||
t.Fatalf("Expected hasErr %v, got %v (%v)", s.expectError, hasErr, err)
|
||||
}
|
||||
|
||||
if hasErr {
|
||||
if result != "" {
|
||||
t.Fatalf("Expected empty Encrypt result on error, got %q", result)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// try to decrypt
|
||||
decrypted, err := security.Decrypt(result, s.key)
|
||||
if err != nil || string(decrypted) != s.data {
|
||||
t.Fatalf("Expected decrypted value to match with the data input, got %q (%v)", decrypted, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecrypt(t *testing.T) {
|
||||
scenarios := []struct {
|
||||
cipher string
|
||||
key string
|
||||
expectError bool
|
||||
expectedData string
|
||||
}{
|
||||
{"", "", true, ""},
|
||||
{"123", "test", true, ""}, // key must be valid 32 char aes string
|
||||
{"8kcEqilvvYKYcfnSr0aSC54gmnQCsB02SaB8ATlnA==", "abcdabcdabcdabcdabcdabcdabcdabcd", true, ""}, // illegal base64 encoded cipherText
|
||||
{"8kcEqilvv+YKYcfnSr0aSC54gmnQCsB02SaB8ATlnA==", "abcdabcdabcdabcdabcdabcdabcdabcd", false, "123"},
|
||||
}
|
||||
|
||||
for i, s := range scenarios {
|
||||
t.Run(fmt.Sprintf("%d_%s", i, s.key), func(t *testing.T) {
|
||||
result, err := security.Decrypt(s.cipher, s.key)
|
||||
|
||||
hasErr := err != nil
|
||||
|
||||
if hasErr != s.expectError {
|
||||
t.Fatalf("Expected hasErr %v, got %v (%v)", s.expectError, hasErr, err)
|
||||
}
|
||||
|
||||
if hasErr {
|
||||
return
|
||||
}
|
||||
|
||||
if str := string(result); str != s.expectedData {
|
||||
t.Fatalf("Expected %q, got %q", s.expectedData, str)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
56
tools/security/jwt.go
Normal file
56
tools/security/jwt.go
Normal file
|
@ -0,0 +1,56 @@
|
|||
package security
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
)
|
||||
|
||||
// ParseUnverifiedJWT parses JWT and returns its claims
|
||||
// but DOES NOT verify the signature.
|
||||
//
|
||||
// It verifies only the exp, iat and nbf claims.
|
||||
func ParseUnverifiedJWT(token string) (jwt.MapClaims, error) {
|
||||
claims := jwt.MapClaims{}
|
||||
|
||||
parser := &jwt.Parser{}
|
||||
_, _, err := parser.ParseUnverified(token, claims)
|
||||
|
||||
if err == nil {
|
||||
err = jwt.NewValidator(jwt.WithIssuedAt()).Validate(claims)
|
||||
}
|
||||
|
||||
return claims, err
|
||||
}
|
||||
|
||||
// ParseJWT verifies and parses JWT and returns its claims.
|
||||
func ParseJWT(token string, verificationKey string) (jwt.MapClaims, error) {
|
||||
parser := jwt.NewParser(jwt.WithValidMethods([]string{"HS256"}))
|
||||
|
||||
parsedToken, err := parser.Parse(token, func(t *jwt.Token) (any, error) {
|
||||
return []byte(verificationKey), nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if claims, ok := parsedToken.Claims.(jwt.MapClaims); ok && parsedToken.Valid {
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
return nil, errors.New("unable to parse token")
|
||||
}
|
||||
|
||||
// NewJWT generates and returns new HS256 signed JWT.
|
||||
func NewJWT(payload jwt.MapClaims, signingKey string, duration time.Duration) (string, error) {
|
||||
claims := jwt.MapClaims{
|
||||
"exp": time.Now().Add(duration).Unix(),
|
||||
}
|
||||
|
||||
for k, v := range payload {
|
||||
claims[k] = v
|
||||
}
|
||||
|
||||
return jwt.NewWithClaims(jwt.SigningMethodHS256, claims).SignedString([]byte(signingKey))
|
||||
}
|
196
tools/security/jwt_test.go
Normal file
196
tools/security/jwt_test.go
Normal file
|
@ -0,0 +1,196 @@
|
|||
package security_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/pocketbase/pocketbase/tools/security"
|
||||
)
|
||||
|
||||
func TestParseUnverifiedJWT(t *testing.T) {
|
||||
// invalid formatted JWT
|
||||
result1, err1 := security.ParseUnverifiedJWT("eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJuYW1lIjoidGVzdCJ9")
|
||||
if err1 == nil {
|
||||
t.Error("Expected error got nil")
|
||||
}
|
||||
if len(result1) > 0 {
|
||||
t.Error("Expected no parsed claims, got", result1)
|
||||
}
|
||||
|
||||
// properly formatted JWT with INVALID claims
|
||||
// {"name": "test", "exp":1516239022}
|
||||
result2, err2 := security.ParseUnverifiedJWT("eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJuYW1lIjoidGVzdCIsImV4cCI6MTUxNjIzOTAyMn0.xYHirwESfSEW3Cq2BL47CEASvD_p_ps3QCA54XtNktU")
|
||||
if err2 == nil {
|
||||
t.Error("Expected error got nil")
|
||||
}
|
||||
if len(result2) != 2 || result2["name"] != "test" {
|
||||
t.Errorf("Expected to have 2 claims, got %v", result2)
|
||||
}
|
||||
|
||||
// properly formatted JWT with VALID claims (missing exp)
|
||||
// {"name": "test"}
|
||||
result3, err3 := security.ParseUnverifiedJWT("eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJuYW1lIjoidGVzdCJ9.ml0QsTms3K9wMygTu41ZhKlTyjmW9zHQtoS8FUsCCjU")
|
||||
if err3 != nil {
|
||||
t.Error("Expected nil, got", err3)
|
||||
}
|
||||
if len(result3) != 1 || result3["name"] != "test" {
|
||||
t.Errorf("Expected to have 1 claim, got %v", result3)
|
||||
}
|
||||
|
||||
// properly formatted JWT with VALID claims (valid exp)
|
||||
// {"name": "test", "exp": 2208985261}
|
||||
result4, err4 := security.ParseUnverifiedJWT("eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJuYW1lIjoidGVzdCIsImV4cCI6MjIwODk4NTI2MX0._0KQu60hYNx5wkBIpEaoX35shXRicb0X_0VdWKWb-3k")
|
||||
if err4 != nil {
|
||||
t.Error("Expected nil, got", err4)
|
||||
}
|
||||
if len(result4) != 2 || result4["name"] != "test" {
|
||||
t.Errorf("Expected to have 2 claims, got %v", result4)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseJWT(t *testing.T) {
|
||||
scenarios := []struct {
|
||||
token string
|
||||
secret string
|
||||
expectError bool
|
||||
expectClaims jwt.MapClaims
|
||||
}{
|
||||
// invalid formatted JWT
|
||||
{
|
||||
"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJuYW1lIjoidGVzdCJ9",
|
||||
"test",
|
||||
true,
|
||||
nil,
|
||||
},
|
||||
// properly formatted JWT with INVALID claims and INVALID secret
|
||||
// {"name": "test", "exp": 1516239022}
|
||||
{
|
||||
"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJuYW1lIjoidGVzdCIsImV4cCI6MTUxNjIzOTAyMn0.xYHirwESfSEW3Cq2BL47CEASvD_p_ps3QCA54XtNktU",
|
||||
"invalid",
|
||||
true,
|
||||
nil,
|
||||
},
|
||||
// properly formatted JWT with INVALID claims and VALID secret
|
||||
// {"name": "test", "exp": 1516239022}
|
||||
{
|
||||
"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJuYW1lIjoidGVzdCIsImV4cCI6MTUxNjIzOTAyMn0.xYHirwESfSEW3Cq2BL47CEASvD_p_ps3QCA54XtNktU",
|
||||
"test",
|
||||
true,
|
||||
nil,
|
||||
},
|
||||
// properly formatted JWT with VALID claims and INVALID secret
|
||||
// {"name": "test", "exp": 1898636137}
|
||||
{
|
||||
"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJuYW1lIjoidGVzdCIsImV4cCI6MTg5ODYzNjEzN30.gqRkHjpK5s1PxxBn9qPaWEWxTbpc1PPSD-an83TsXRY",
|
||||
"invalid",
|
||||
true,
|
||||
nil,
|
||||
},
|
||||
// properly formatted EXPIRED JWT with VALID secret
|
||||
// {"name": "test", "exp": 1652097610}
|
||||
{
|
||||
"eyJhbGciOiJIUzI1NiJ9.eyJuYW1lIjoidGVzdCIsImV4cCI6OTU3ODczMzc0fQ.0oUUKUnsQHs4nZO1pnxQHahKtcHspHu4_AplN2sGC4A",
|
||||
"test",
|
||||
true,
|
||||
nil,
|
||||
},
|
||||
// properly formatted JWT with VALID claims and VALID secret
|
||||
// {"name": "test", "exp": 1898636137}
|
||||
{
|
||||
"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJuYW1lIjoidGVzdCIsImV4cCI6MTg5ODYzNjEzN30.gqRkHjpK5s1PxxBn9qPaWEWxTbpc1PPSD-an83TsXRY",
|
||||
"test",
|
||||
false,
|
||||
jwt.MapClaims{"name": "test", "exp": 1898636137.0},
|
||||
},
|
||||
// properly formatted JWT with VALID claims (without exp) and VALID secret
|
||||
// {"name": "test"}
|
||||
{
|
||||
"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJuYW1lIjoidGVzdCJ9.ml0QsTms3K9wMygTu41ZhKlTyjmW9zHQtoS8FUsCCjU",
|
||||
"test",
|
||||
false,
|
||||
jwt.MapClaims{"name": "test"},
|
||||
},
|
||||
}
|
||||
|
||||
for i, s := range scenarios {
|
||||
t.Run(fmt.Sprintf("%d_%s", i, s.token), func(t *testing.T) {
|
||||
result, err := security.ParseJWT(s.token, s.secret)
|
||||
|
||||
hasErr := err != nil
|
||||
|
||||
if hasErr != s.expectError {
|
||||
t.Fatalf("Expected hasErr %v, got %v (%v)", s.expectError, hasErr, err)
|
||||
}
|
||||
|
||||
if len(result) != len(s.expectClaims) {
|
||||
t.Fatalf("Expected %v claims got %v", s.expectClaims, result)
|
||||
}
|
||||
|
||||
for k, v := range s.expectClaims {
|
||||
v2, ok := result[k]
|
||||
if !ok {
|
||||
t.Fatalf("Missing expected claim %q", k)
|
||||
}
|
||||
if v != v2 {
|
||||
t.Fatalf("Expected %v for %q claim, got %v", v, k, v2)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewJWT(t *testing.T) {
|
||||
scenarios := []struct {
|
||||
claims jwt.MapClaims
|
||||
key string
|
||||
duration time.Duration
|
||||
expectError bool
|
||||
}{
|
||||
// empty, zero duration
|
||||
{jwt.MapClaims{}, "", 0, true},
|
||||
// empty, 10 seconds duration
|
||||
{jwt.MapClaims{}, "", 10 * time.Second, false},
|
||||
// non-empty, 10 seconds duration
|
||||
{jwt.MapClaims{"name": "test"}, "test", 10 * time.Second, false},
|
||||
}
|
||||
|
||||
for i, scenario := range scenarios {
|
||||
t.Run(strconv.Itoa(i), func(t *testing.T) {
|
||||
token, tokenErr := security.NewJWT(scenario.claims, scenario.key, scenario.duration)
|
||||
if tokenErr != nil {
|
||||
t.Fatalf("Expected NewJWT to succeed, got error %v", tokenErr)
|
||||
}
|
||||
|
||||
claims, parseErr := security.ParseJWT(token, scenario.key)
|
||||
|
||||
hasParseErr := parseErr != nil
|
||||
if hasParseErr != scenario.expectError {
|
||||
t.Fatalf("Expected hasParseErr to be %v, got %v (%v)", scenario.expectError, hasParseErr, parseErr)
|
||||
}
|
||||
|
||||
if scenario.expectError {
|
||||
return
|
||||
}
|
||||
|
||||
if _, ok := claims["exp"]; !ok {
|
||||
t.Fatalf("Missing required claim exp, got %v", claims)
|
||||
}
|
||||
|
||||
// clear exp claim to match with the scenario ones
|
||||
delete(claims, "exp")
|
||||
|
||||
if len(claims) != len(scenario.claims) {
|
||||
t.Fatalf("Expected %v claims, got %v", scenario.claims, claims)
|
||||
}
|
||||
|
||||
for j, k := range claims {
|
||||
if claims[j] != scenario.claims[j] {
|
||||
t.Fatalf("Expected %v for %q claim, got %v", claims[j], k, scenario.claims[j])
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
59
tools/security/random.go
Normal file
59
tools/security/random.go
Normal file
|
@ -0,0 +1,59 @@
|
|||
package security
|
||||
|
||||
import (
|
||||
cryptoRand "crypto/rand"
|
||||
"math/big"
|
||||
mathRand "math/rand" // @todo replace with rand/v2?
|
||||
)
|
||||
|
||||
const defaultRandomAlphabet = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789"
|
||||
|
||||
// RandomString generates a cryptographically random string with the specified length.
|
||||
//
|
||||
// The generated string matches [A-Za-z0-9]+ and it's transparent to URL-encoding.
|
||||
func RandomString(length int) string {
|
||||
return RandomStringWithAlphabet(length, defaultRandomAlphabet)
|
||||
}
|
||||
|
||||
// RandomStringWithAlphabet generates a cryptographically random string
|
||||
// with the specified length and characters set.
|
||||
//
|
||||
// It panics if for some reason rand.Int returns a non-nil error.
|
||||
func RandomStringWithAlphabet(length int, alphabet string) string {
|
||||
b := make([]byte, length)
|
||||
max := big.NewInt(int64(len(alphabet)))
|
||||
|
||||
for i := range b {
|
||||
n, err := cryptoRand.Int(cryptoRand.Reader, max)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
b[i] = alphabet[n.Int64()]
|
||||
}
|
||||
|
||||
return string(b)
|
||||
}
|
||||
|
||||
// PseudorandomString generates a pseudorandom string with the specified length.
|
||||
//
|
||||
// The generated string matches [A-Za-z0-9]+ and it's transparent to URL-encoding.
|
||||
//
|
||||
// For a cryptographically random string (but a little bit slower) use RandomString instead.
|
||||
func PseudorandomString(length int) string {
|
||||
return PseudorandomStringWithAlphabet(length, defaultRandomAlphabet)
|
||||
}
|
||||
|
||||
// PseudorandomStringWithAlphabet generates a pseudorandom string
|
||||
// with the specified length and characters set.
|
||||
//
|
||||
// For a cryptographically random (but a little bit slower) use RandomStringWithAlphabet instead.
|
||||
func PseudorandomStringWithAlphabet(length int, alphabet string) string {
|
||||
b := make([]byte, length)
|
||||
max := len(alphabet)
|
||||
|
||||
for i := range b {
|
||||
b[i] = alphabet[mathRand.Intn(max)]
|
||||
}
|
||||
|
||||
return string(b)
|
||||
}
|
152
tools/security/random_by_regex.go
Normal file
152
tools/security/random_by_regex.go
Normal file
|
@ -0,0 +1,152 @@
|
|||
package security
|
||||
|
||||
import (
|
||||
cryptoRand "crypto/rand"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"regexp/syntax"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const defaultMaxRepeat = 6
|
||||
|
||||
var anyCharNotNLPairs = []rune{'A', 'Z', 'a', 'z', '0', '9'}
|
||||
|
||||
// RandomStringByRegex generates a random string matching the regex pattern.
|
||||
// If optFlags is not set, fallbacks to [syntax.Perl].
|
||||
//
|
||||
// NB! While the source of the randomness comes from [crypto/rand] this method
|
||||
// is not recommended to be used on its own in critical secure contexts because
|
||||
// the generated length could vary too much on the used pattern and may not be
|
||||
// as secure as simply calling [security.RandomString].
|
||||
// If you still insist on using it for such purposes, consider at least
|
||||
// a large enough minimum length for the generated string, e.g. `[a-z0-9]{30}`.
|
||||
//
|
||||
// This function is inspired by github.com/pipe01/revregexp, github.com/lucasjones/reggen and other similar packages.
|
||||
func RandomStringByRegex(pattern string, optFlags ...syntax.Flags) (string, error) {
|
||||
var flags syntax.Flags
|
||||
if len(optFlags) == 0 {
|
||||
flags = syntax.Perl
|
||||
} else {
|
||||
for _, f := range optFlags {
|
||||
flags |= f
|
||||
}
|
||||
}
|
||||
|
||||
r, err := syntax.Parse(pattern, flags)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
var sb = new(strings.Builder)
|
||||
|
||||
err = writeRandomStringByRegex(r, sb)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return sb.String(), nil
|
||||
}
|
||||
|
||||
func writeRandomStringByRegex(r *syntax.Regexp, sb *strings.Builder) error {
|
||||
// https://pkg.go.dev/regexp/syntax#Op
|
||||
switch r.Op {
|
||||
case syntax.OpCharClass:
|
||||
c, err := randomRuneFromPairs(r.Rune)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = sb.WriteRune(c)
|
||||
return err
|
||||
case syntax.OpAnyChar, syntax.OpAnyCharNotNL:
|
||||
c, err := randomRuneFromPairs(anyCharNotNLPairs)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = sb.WriteRune(c)
|
||||
return err
|
||||
case syntax.OpAlternate:
|
||||
idx, err := randomNumber(len(r.Sub))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return writeRandomStringByRegex(r.Sub[idx], sb)
|
||||
case syntax.OpConcat:
|
||||
var err error
|
||||
for _, sub := range r.Sub {
|
||||
err = writeRandomStringByRegex(sub, sb)
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
return err
|
||||
case syntax.OpRepeat:
|
||||
return repeatRandomStringByRegex(r.Sub[0], sb, r.Min, r.Max)
|
||||
case syntax.OpQuest:
|
||||
return repeatRandomStringByRegex(r.Sub[0], sb, 0, 1)
|
||||
case syntax.OpPlus:
|
||||
return repeatRandomStringByRegex(r.Sub[0], sb, 1, -1)
|
||||
case syntax.OpStar:
|
||||
return repeatRandomStringByRegex(r.Sub[0], sb, 0, -1)
|
||||
case syntax.OpCapture:
|
||||
return writeRandomStringByRegex(r.Sub[0], sb)
|
||||
case syntax.OpLiteral:
|
||||
_, err := sb.WriteString(string(r.Rune))
|
||||
return err
|
||||
default:
|
||||
return fmt.Errorf("unsupported pattern operator %d", r.Op)
|
||||
}
|
||||
}
|
||||
|
||||
func repeatRandomStringByRegex(r *syntax.Regexp, sb *strings.Builder, min int, max int) error {
|
||||
if max < 0 {
|
||||
max = defaultMaxRepeat
|
||||
}
|
||||
|
||||
if max < min {
|
||||
max = min
|
||||
}
|
||||
|
||||
n := min
|
||||
if max != min {
|
||||
randRange, err := randomNumber(max - min)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
n += randRange
|
||||
}
|
||||
|
||||
var err error
|
||||
for i := 0; i < n; i++ {
|
||||
err = writeRandomStringByRegex(r, sb)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func randomRuneFromPairs(pairs []rune) (rune, error) {
|
||||
idx, err := randomNumber(len(pairs) / 2)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return randomRuneFromRange(pairs[idx*2], pairs[idx*2+1])
|
||||
}
|
||||
|
||||
func randomRuneFromRange(min rune, max rune) (rune, error) {
|
||||
offset, err := randomNumber(int(max - min + 1))
|
||||
if err != nil {
|
||||
return min, err
|
||||
}
|
||||
|
||||
return min + rune(offset), nil
|
||||
}
|
||||
|
||||
func randomNumber(maxSoft int) (int, error) {
|
||||
randRange, err := cryptoRand.Int(cryptoRand.Reader, big.NewInt(int64(maxSoft)))
|
||||
|
||||
return int(randRange.Int64()), err
|
||||
}
|
66
tools/security/random_by_regex_test.go
Normal file
66
tools/security/random_by_regex_test.go
Normal file
|
@ -0,0 +1,66 @@
|
|||
package security_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"regexp"
|
||||
"regexp/syntax"
|
||||
"slices"
|
||||
"testing"
|
||||
|
||||
"github.com/pocketbase/pocketbase/tools/security"
|
||||
)
|
||||
|
||||
func TestRandomStringByRegex(t *testing.T) {
|
||||
generated := []string{}
|
||||
|
||||
scenarios := []struct {
|
||||
pattern string
|
||||
flags []syntax.Flags
|
||||
expectError bool
|
||||
}{
|
||||
{``, nil, true},
|
||||
{`test`, nil, false},
|
||||
{`\d+`, []syntax.Flags{syntax.POSIX}, true},
|
||||
{`\d+`, nil, false},
|
||||
{`\d*`, nil, false},
|
||||
{`\d{1,10}`, nil, false},
|
||||
{`\d{3}`, nil, false},
|
||||
{`\d{0,}-abc`, nil, false},
|
||||
{`[a-zA-Z]*`, nil, false},
|
||||
{`[^a-zA-Z]{5,30}`, nil, false},
|
||||
{`\w+_abc`, nil, false},
|
||||
{`[a-zA-Z_]*`, nil, false},
|
||||
{`[2-9]{5}-\w+`, nil, false},
|
||||
{`(a|b|c)`, nil, false},
|
||||
}
|
||||
|
||||
for i, s := range scenarios {
|
||||
t.Run(fmt.Sprintf("%d_%q", i, s.pattern), func(t *testing.T) {
|
||||
str, err := security.RandomStringByRegex(s.pattern, s.flags...)
|
||||
|
||||
hasErr := err != nil
|
||||
if hasErr != s.expectError {
|
||||
t.Fatalf("Expected hasErr %v, got %v (%v)", s.expectError, hasErr, err)
|
||||
}
|
||||
|
||||
if hasErr {
|
||||
return
|
||||
}
|
||||
|
||||
r, err := regexp.Compile(s.pattern)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if !r.Match([]byte(str)) {
|
||||
t.Fatalf("Expected %q to match pattern %v", str, s.pattern)
|
||||
}
|
||||
|
||||
if slices.Contains(generated, str) {
|
||||
t.Fatalf("The generated string %q already exists in\n%v", str, generated)
|
||||
}
|
||||
|
||||
generated = append(generated, str)
|
||||
})
|
||||
}
|
||||
}
|
89
tools/security/random_test.go
Normal file
89
tools/security/random_test.go
Normal file
|
@ -0,0 +1,89 @@
|
|||
package security_test
|
||||
|
||||
import (
|
||||
"regexp"
|
||||
"testing"
|
||||
|
||||
"github.com/pocketbase/pocketbase/tools/security"
|
||||
)
|
||||
|
||||
func TestRandomString(t *testing.T) {
|
||||
testRandomString(t, security.RandomString)
|
||||
}
|
||||
|
||||
func TestRandomStringWithAlphabet(t *testing.T) {
|
||||
testRandomStringWithAlphabet(t, security.RandomStringWithAlphabet)
|
||||
}
|
||||
|
||||
func TestPseudorandomString(t *testing.T) {
|
||||
testRandomString(t, security.PseudorandomString)
|
||||
}
|
||||
|
||||
func TestPseudorandomStringWithAlphabet(t *testing.T) {
|
||||
testRandomStringWithAlphabet(t, security.PseudorandomStringWithAlphabet)
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
|
||||
func testRandomStringWithAlphabet(t *testing.T, randomFunc func(n int, alphabet string) string) {
|
||||
scenarios := []struct {
|
||||
alphabet string
|
||||
expectPattern string
|
||||
}{
|
||||
{"0123456789_", `[0-9_]+`},
|
||||
{"abcdef123", `[abcdef123]+`},
|
||||
{"!@#$%^&*()", `[\!\@\#\$\%\^\&\*\(\)]+`},
|
||||
}
|
||||
|
||||
for i, s := range scenarios {
|
||||
generated := make([]string, 0, 1000)
|
||||
length := 10
|
||||
|
||||
for j := 0; j < 1000; j++ {
|
||||
result := randomFunc(length, s.alphabet)
|
||||
|
||||
if len(result) != length {
|
||||
t.Fatalf("(%d:%d) Expected the length of the string to be %d, got %d", i, j, length, len(result))
|
||||
}
|
||||
|
||||
reg := regexp.MustCompile(s.expectPattern)
|
||||
if match := reg.MatchString(result); !match {
|
||||
t.Fatalf("(%d:%d) The generated string should have only %s characters, got %q", i, j, s.expectPattern, result)
|
||||
}
|
||||
|
||||
for _, str := range generated {
|
||||
if str == result {
|
||||
t.Fatalf("(%d:%d) Repeating random string - found %q in %q", i, j, result, generated)
|
||||
}
|
||||
}
|
||||
|
||||
generated = append(generated, result)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func testRandomString(t *testing.T, randomFunc func(n int) string) {
|
||||
generated := make([]string, 0, 1000)
|
||||
reg := regexp.MustCompile(`[a-zA-Z0-9]+`)
|
||||
length := 10
|
||||
|
||||
for i := 0; i < 1000; i++ {
|
||||
result := randomFunc(length)
|
||||
|
||||
if len(result) != length {
|
||||
t.Fatalf("(%d) Expected the length of the string to be %d, got %d", i, length, len(result))
|
||||
}
|
||||
|
||||
if match := reg.MatchString(result); !match {
|
||||
t.Fatalf("(%d) The generated string should have only [a-zA-Z0-9]+ characters, got %q", i, result)
|
||||
}
|
||||
|
||||
for _, str := range generated {
|
||||
if str == result {
|
||||
t.Fatalf("(%d) Repeating random string - found %q in \n%v", i, result, generated)
|
||||
}
|
||||
}
|
||||
|
||||
generated = append(generated, result)
|
||||
}
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue