1
0
Fork 0

Adding upstream version 3.0.0.

Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
Daniel Baumann 2025-05-16 22:32:21 +02:00
parent 4199417ac3
commit 8274b1bf1b
Signed by: daniel
GPG key ID: FBB4F0E80A80222F
21 changed files with 2147 additions and 0 deletions

1
.gitattributes vendored Normal file
View file

@ -0,0 +1 @@
* text=lf

7
.github/codecov.yml vendored Normal file
View file

@ -0,0 +1,7 @@
coverage:
status:
patch: off
project:
default:
target: 75%
threshold: null

14
.github/dependabot.yml vendored Normal file
View file

@ -0,0 +1,14 @@
version: 2
updates:
- package-ecosystem: "github-actions"
directory: "/"
labels: ["dependencies"]
schedule:
interval: "weekly"
day: "saturday"
- package-ecosystem: "gomod"
directory: "/"
labels: ["dependencies"]
schedule:
interval: "weekly"
day: "saturday"

30
.github/workflows/test.yml vendored Normal file
View file

@ -0,0 +1,30 @@
name: test
on:
pull_request:
paths-ignore:
- '*.md'
push:
branches:
- master
paths-ignore:
- '*.md'
jobs:
test:
name: test
runs-on: ubuntu-latest
timeout-minutes: 3
steps:
- uses: actions/setup-go@v5
with:
go-version: 1.23.3
- uses: actions/checkout@v4
- name: Test (race)
run: go test ./... -race
- name: Test (coverage)
run: go test ./... -coverprofile=coverage.txt -covermode=atomic
- name: Codecov
uses: codecov/codecov-action@v5.1.2
with:
files: ./coverage.txt
token: ${{ secrets.CODECOV_TOKEN }}

3
.gitignore vendored Normal file
View file

@ -0,0 +1,3 @@
.idea
*.iml
/vendor

21
LICENSE Normal file
View file

@ -0,0 +1,21 @@
MIT License
Copyright (c) 2021-2025 TwiN
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

319
README.md Normal file
View file

@ -0,0 +1,319 @@
# g8
![test](https://github.com/TwiN/g8/workflows/test/badge.svg?branch=master)
[![Go Report Card](https://goreportcard.com/badge/github.com/TwiN/g8)](https://goreportcard.com/report/github.com/TwiN/g8/v3)
[![codecov](https://codecov.io/gh/TwiN/g8/branch/master/graph/badge.svg)](https://codecov.io/gh/TwiN/g8)
[![Go version](https://img.shields.io/github/go-mod/go-version/TwiN/g8.svg)](https://github.com/TwiN/g8)
[![Go Reference](https://pkg.go.dev/badge/github.com/TwiN/g8.svg)](https://pkg.go.dev/github.com/TwiN/g8/v3)
[![Follow TwiN](https://img.shields.io/github/followers/TwiN?label=Follow&style=social)](https://github.com/TwiN)
g8, pronounced gate, is a simple Go library for protecting HTTP handlers.
Tired of constantly re-implementing a security layer for each application? Me too, that's why I made g8.
## Installation
```console
go get -u github.com/TwiN/g8/v3
```
## Usage
Because the entire purpose of g8 is to NOT waste time configuring the layer of security, the primary emphasis is to
keep it as simple as possible.
### Simple
Just want a simple layer of security without the need for advanced permissions? This configuration is what you're
looking for.
```go
authorizationService := g8.NewAuthorizationService().WithToken("mytoken")
gate := g8.New().WithAuthorizationService(authorizationService)
router := http.NewServeMux()
router.Handle("/unprotected", yourHandler)
router.Handle("/protected", gate.Protect(yourHandler))
http.ListenAndServe(":8080", router)
```
The endpoint `/protected` is now only accessible if you pass the header `Authorization: Bearer mytoken`.
If you use `http.HandleFunc` instead of `http.Handle`, you may use `gate.ProtectFunc(yourHandler)` instead.
If you're not using the `Authorization` header, you can specify a custom token extractor.
This enables use cases like [Protecting a handler using session cookie](#protecting-a-handler-using-session-cookie)
### Advanced permissions
If you have tokens with more permissions than others, g8's permission system will make managing authorization a breeze.
Rather than registering tokens, think of it as registering clients, the only difference being that clients may be
configured with permissions while tokens cannot.
```go
authorizationService := g8.NewAuthorizationService().WithClient(g8.NewClient("mytoken").WithPermission("admin"))
gate := g8.New().WithAuthorizationService(authorizationService)
router := http.NewServeMux()
router.Handle("/unprotected", yourHandler)
router.Handle("/protected-with-admin", gate.ProtectWithPermissions(yourHandler, []string{"admin"}))
http.ListenAndServe(":8080", router)
```
The endpoint `/protected-with-admin` is now only accessible if you pass the header `Authorization: Bearer mytoken`,
because the client with the token `mytoken` has the permission `admin`. Note that the following handler would also be
accessible with that token:
```go
router.Handle("/protected", gate.Protect(yourHandler))
```
To clarify, both clients and tokens have access to handlers that aren't protected with extra permissions, and
essentially, tokens are registered as clients with no extra permissions in the background.
Creating a token like so:
```go
authorizationService := g8.NewAuthorizationService().WithToken("mytoken")
```
is the equivalent of creating the following client:
```go
authorizationService := g8.NewAuthorizationService().WithClient(g8.NewClient("mytoken"))
```
### With client provider
A client provider's task is to retrieve a Client from an external source (e.g. a database) when provided with a token.
You should use a client provider when you have a lot of tokens and it wouldn't make sense to register all of them using
`AuthorizationService`'s `WithToken`/`WithTokens`/`WithClient`/`WithClients`.
Note that the provider is used as a fallback source. As such, if a token is explicitly registered using one of the 4
aforementioned functions, the client provider will not be used.
```go
clientProvider := g8.NewClientProvider(func(token string) *g8.Client {
// We'll assume that the following function calls your database and returns a struct "User" that
// has the user's token as well as the permissions granted to said user
user := database.GetUserByToken(token)
if user != nil {
return g8.NewClient(user.Token).WithPermissions(user.Permissions)
}
return nil
})
authorizationService := g8.NewAuthorizationService().WithClientProvider(clientProvider)
gate := g8.New().WithAuthorizationService(authorizationService)
```
You can also configure the client provider to cache the output of the function you provide to retrieve clients by token:
```go
clientProvider := g8.NewClientProvider(...).WithCache(ttl, maxSize)
```
Since g8 leverages [TwiN/gocache](https://github.com/TwiN/gocache) (unless you're using `WithCustomCache`),
you can also use gocache's constants for configuring the TTL and the maximum size:
- Setting the TTL to `gocache.NoExpiration` (-1) will disable the TTL.
- Setting the maximum size to `gocache.NoMaxSize` (0) will disable the maximum cache size
To avoid any misunderstandings, using a client provider is not mandatory. If you only have a few tokens and you can load
them on application start, you can just leverage `AuthorizationService`'s `WithToken`/`WithTokens`/`WithClient`/`WithClients`.
## AuthorizationService
As the previous examples may have hinted, there are several ways to create clients. The one thing they have
in common is that they all go through AuthorizationService, which is in charge of both managing clients and determining
whether a request should be blocked or allowed through.
| Function | Description |
|:-------------------|:---------------------------------------------------------------------------------------------------------------------------------|
| WithToken | Creates a single static client with no extra permissions |
| WithTokens | Creates a slice of static clients with no extra permissions |
| WithClient | Creates a single static client |
| WithClients | Creates a slice of static clients |
| WithClientProvider | Creates a client provider which will allow a fallback to a dynamic source (e.g. to a database) when a static client is not found |
Except for `WithClientProvider`, every functions listed above can be called more than once.
As a result, you may safely perform actions like this:
```go
authorizationService := g8.NewAuthorizationService().
WithToken("123").
WithToken("456").
WithClient(g8.NewClient("789").WithPermission("admin"))
gate := g8.New().WithAuthorizationService(authorizationService)
```
Be aware that g8.Client supports a list of permissions as well. You may call `WithPermission` several times, or call
`WithPermissions` with a slice of permissions instead.
### Permissions
Unlike client permissions, handler permissions are requirements.
A client may have as many permissions as you want, but for said client to have access to a handler protected by
permissions, the client must have all permissions defined by said handler in order to have access to it.
In other words, a client with the permissions `create`, `read`, `update` and `delete` would have access to all of these handlers:
```go
gate := g8.New().WithAuthorizationService(g8.NewAuthorizationService().WithClient(g8.NewClient("mytoken").WithPermissions([]string{"create", "read", "update", "delete"})))
router := http.NewServeMux()
router.Handle("/", gate.Protect(homeHandler)) // equivalent of gate.ProtectWithPermissions(homeHandler, []string{})
router.Handle("/create", gate.ProtectWithPermissions(createHandler, []string{"create"}))
router.Handle("/read", gate.ProtectWithPermissions(readHandler, []string{"read"}))
router.Handle("/update", gate.ProtectWithPermissions(updateHandler, []string{"update"}))
router.Handle("/delete", gate.ProtectWithPermissions(deleteHandler, []string{"delete"}))
router.Handle("/crud", gate.ProtectWithPermissions(crudHandler, []string{"create", "read", "update", "delete"}))
```
But it would not have access to the following handler, because while `mytoken` has the `read` permission, it does not
have the `backup` permission:
```go
router.Handle("/backup", gate.ProtectWithPermissions(&testHandler{}, []string{"read", "backup"}))
```
If you're using an HTTP library that supports middlewares like [mux](https://github.com/gorilla/mux), you can protect
an entire group of handlers instead using `gate.Protect` or `gate.PermissionMiddleware()`:
```go
router := mux.NewRouter()
userRouter := router.PathPrefix("/").Subrouter()
userRouter.Use(gate.Protect)
userRouter.HandleFunc("/api/v1/users/me", getUserProfile).Methods("GET")
userRouter.HandleFunc("/api/v1/users/me/friends", getUserFriends).Methods("GET")
userRouter.HandleFunc("/api/v1/users/me/email", updateUserEmail).Methods("PATCH")
adminRouter := router.PathPrefix("/").Subrouter()
adminRouter.Use(gate.PermissionMiddleware("admin"))
adminRouter.HandleFunc("/api/v1/users/{id}/ban", banUserByID).Methods("POST")
adminRouter.HandleFunc("/api/v1/users/{id}/delete", deleteUserByID).Methods("DELETE")
```
## Rate limiting
To add a rate limit of 100 requests per second:
```go
gate := g8.New().WithRateLimit(100)
```
## Accessing the token from the protected handlers
If you need to access the token from the handlers you are protecting with g8, you can retrieve it from the
request context by using the key `g8.TokenContextKey`:
```go
http.Handle("/handle", gate.ProtectFunc(func(w http.ResponseWriter, r *http.Request) {
token, _ := r.Context().Value(g8.TokenContextKey).(string)
// ...
}))
```
## Examples
### Protecting a handler using session cookie
If you want to only allow authenticated users to access a handler, you can use a custom token extractor function
combined with a client provider.
First, we'll create a function to extract the session ID from the session cookie. While a session ID does not
theoretically refer to a token, g8 uses the term `token` as a blanket term to refer to any string that can be used to
identify a client.
```go
customTokenExtractorFunc := func(request *http.Request) string {
sessionCookie, err := request.Cookie("session")
if err != nil {
return ""
}
return sessionCookie.Value
}
```
Next, we need to create a client provider that will validate our token, which refers to the session ID in this case.
```go
clientProvider := g8.NewClientProvider(func(token string) *g8.Client {
// We'll assume that the following function calls your database and validates whether the session is valid.
isSessionValid := database.CheckIfSessionIsValid(token)
if !isSessionValid {
return nil // Returning nil will cause the gate to return a 401 Unauthorized.
}
// You could also retrieve the user and their permissions if you wanted instead, but for this example,
// all we care about is confirming whether the session is valid or not.
return g8.NewClient(token)
})
```
Keep in mind that you can get really creative with the client provider above.
For instance, you could refresh the session's expiration time, which will allow the user to stay logged in for
as long as they're active.
You're also not limited to using something stateful like the example above. You could use a JWT and have your client
provider validate said JWT.
Finally, we can create the authorization service and the gate:
```go
authorizationService := g8.NewAuthorizationService().WithClientProvider(clientProvider)
gate := g8.New().WithAuthorizationService(authorizationService).WithCustomTokenExtractor(customTokenExtractorFunc)
```
If you need to access the token (session ID in this case) from the protected handlers, you can retrieve it from the
request context by using the key `g8.TokenContextKey`:
```go
http.Handle("/handle", gate.ProtectFunc(func(w http.ResponseWriter, r *http.Request) {
sessionID, _ := r.Context().Value(g8.TokenContextKey).(string)
// ...
}))
```
### Using a custom header
The logic is the same as the example above:
```go
customTokenExtractorFunc := func(request *http.Request) string {
return request.Header.Get("X-API-Token")
}
clientProvider := g8.NewClientProvider(func(token string) *g8.Client {
// We'll assume that the following function calls your database and returns a struct "User" that
// has the user's token as well as the permissions granted to said user
user := database.GetUserByToken(token)
if user != nil {
return g8.NewClient(user.Token).WithPermissions(user.Permissions)
}
return nil
})
authorizationService := g8.NewAuthorizationService().WithClientProvider(clientProvider)
gate := g8.New().WithAuthorizationService(authorizationService).WithCustomTokenExtractor(customTokenExtractorFunc)
```
### Using a custom cache
```go
package main
import (
g8 "github.com/TwiN/g8/v3"
)
type customCache struct {
entries map[string]any
sync.Mutex
}
func (c *customCache) Get(key string) (value any, exists bool) {
return nil, false
}
func (c *customCache) Set(key string, value any) {
// ...
}
// To verify the implementation
var _ g8.Cache = (*customCache)(nil)
func main() {
getClientByTokenFunc := func(token string) *g8.Client {
// We'll assume that the following function calls your database and returns a struct "User" that
// has the user's token as well as the permissions granted to said user
user := database.GetUserByToken(token)
if user != nil {
return g8.NewClient(user.Token).WithPermissions(user.Permissions).WithData(user.Data)
}
return nil
}
// Create the provider with the custom cache
provider := g8.NewClientProvider(getClientByTokenFunc).WithCustomCache(&customCache{})
}
```

130
authorization.go Normal file
View file

@ -0,0 +1,130 @@
package g8
import (
"sync"
)
// AuthorizationService is the service that manages client/token registry and client fallback as well as the service
// that determines whether a token meets the specific requirements to be authorized by a Gate or not.
type AuthorizationService struct {
clients map[string]*Client
clientProvider *ClientProvider
mutex sync.RWMutex
}
// NewAuthorizationService creates a new AuthorizationService
func NewAuthorizationService() *AuthorizationService {
return &AuthorizationService{
clients: make(map[string]*Client),
}
}
// WithToken is used to specify a single token for which authorization will be granted
//
// The client that will be created from this token will have access to all handlers that are not protected with a
// specific permission.
//
// In other words, if you were to do the following:
//
// gate := g8.New().WithAuthorizationService(g8.NewAuthorizationService().WithToken("12345"))
//
// The following handler would be accessible with the token 12345:
//
// router.Handle("/1st-handler", gate.Protect(yourHandler))
//
// But not this one would not be accessible with the token 12345:
//
// router.Handle("/2nd-handler", gate.ProtectWithPermissions(yourOtherHandler, []string{"admin"}))
//
// Calling this function multiple times will add multiple clients, though you may want to use WithTokens instead
// if you plan to add multiple clients
//
// If you wish to configure advanced permissions, consider using WithClient instead.
func (authorizationService *AuthorizationService) WithToken(token string) *AuthorizationService {
authorizationService.mutex.Lock()
authorizationService.clients[token] = NewClient(token)
authorizationService.mutex.Unlock()
return authorizationService
}
// WithTokens is used to specify a slice of tokens for which authorization will be granted
func (authorizationService *AuthorizationService) WithTokens(tokens []string) *AuthorizationService {
authorizationService.mutex.Lock()
for _, token := range tokens {
authorizationService.clients[token] = NewClient(token)
}
authorizationService.mutex.Unlock()
return authorizationService
}
// WithClient is used to specify a single client for which authorization will be granted
//
// When compared to WithToken, the advantage of using this function is that you may specify the client's
// permissions and thus, be a lot more granular with what endpoint a token has access to.
//
// In other words, if you were to do the following:
//
// gate := g8.New().WithAuthorizationService(g8.NewAuthorizationService().WithClient(g8.NewClient("12345").WithPermission("mod")))
//
// The following handlers would be accessible with the token 12345:
//
// router.Handle("/1st-handler", gate.ProtectWithPermissions(yourHandler, []string{"mod"}))
// router.Handle("/2nd-handler", gate.Protect(yourOtherHandler))
//
// But not this one, because the user does not have the permission "admin":
//
// router.Handle("/3rd-handler", gate.ProtectWithPermissions(yetAnotherHandler, []string{"admin"}))
//
// Calling this function multiple times will add multiple clients, though you may want to use WithClients instead
// if you plan to add multiple clients
func (authorizationService *AuthorizationService) WithClient(client *Client) *AuthorizationService {
authorizationService.mutex.Lock()
authorizationService.clients[client.Token] = client
authorizationService.mutex.Unlock()
return authorizationService
}
// WithClients is used to specify a slice of clients for which authorization will be granted
func (authorizationService *AuthorizationService) WithClients(clients []*Client) *AuthorizationService {
authorizationService.mutex.Lock()
for _, client := range clients {
authorizationService.clients[client.Token] = client
}
authorizationService.mutex.Unlock()
return authorizationService
}
// WithClientProvider allows specifying a custom provider to fetch clients by token.
//
// For example, you can use it to fallback to making a call in your database when a request is made with a token that
// hasn't been specified via WithToken, WithTokens, WithClient or WithClients.
func (authorizationService *AuthorizationService) WithClientProvider(provider *ClientProvider) *AuthorizationService {
authorizationService.clientProvider = provider
return authorizationService
}
// Authorize checks whether a client with a given token exists and has the permissions required.
//
// If permissionsRequired is nil or empty and a client with the given token exists, said client will have access to all
// handlers that are not protected by a given permission.
//
// Returns the client is authorized (or nil if no client was authorized), as well as whether the token is authorized
func (authorizationService *AuthorizationService) Authorize(token string, permissionsRequired []string) (client *Client, authorized bool) {
if len(token) == 0 {
return nil, false
}
authorizationService.mutex.RLock()
client, _ = authorizationService.clients[token]
authorizationService.mutex.RUnlock()
// If there's no clients with the given token directly stored in the AuthorizationService, fall back to the
// client provider, if there's one configured.
if client == nil && authorizationService.clientProvider != nil {
client = authorizationService.clientProvider.GetClientByToken(token)
}
if client != nil && client.HasPermissions(permissionsRequired) {
// If the client has the required permissions, return true and the client
return client, true
}
return nil, false
}

108
authorization_test.go Normal file
View file

@ -0,0 +1,108 @@
package g8
import "testing"
func TestAuthorizationService_Authorize(t *testing.T) {
authorizationService := NewAuthorizationService().WithToken("token")
if _, authorized := authorizationService.Authorize("token", nil); !authorized {
t.Error("should've returned true")
}
if _, authorized := authorizationService.Authorize("bad-token", nil); authorized {
t.Error("should've returned false")
}
if _, authorized := authorizationService.Authorize("token", []string{"admin"}); authorized {
t.Error("should've returned false")
}
if _, authorized := authorizationService.Authorize("", nil); authorized {
t.Error("should've returned false")
}
}
func TestAuthorizationService_AuthorizeWithPermissions(t *testing.T) {
authorizationService := NewAuthorizationService().WithClient(NewClient("token").WithPermissions([]string{"a", "b"}))
if _, authorized := authorizationService.Authorize("token", nil); !authorized {
t.Error("should've returned true")
}
if _, authorized := authorizationService.Authorize("token", []string{"a"}); !authorized {
t.Error("should've returned true")
}
if _, authorized := authorizationService.Authorize("token", []string{"b"}); !authorized {
t.Error("should've returned true")
}
if _, authorized := authorizationService.Authorize("token", []string{"a", "b"}); !authorized {
t.Error("should've returned true")
}
if _, authorized := authorizationService.Authorize("token", []string{"c"}); authorized {
t.Error("should've returned false")
}
if _, authorized := authorizationService.Authorize("token", []string{"a", "c"}); authorized {
t.Error("should've returned false")
}
if _, authorized := authorizationService.Authorize("bad-token", nil); authorized {
t.Error("should've returned false")
}
if _, authorized := authorizationService.Authorize("bad-token", []string{"a"}); authorized {
t.Error("should've returned false")
}
if _, authorized := authorizationService.Authorize("", []string{"a"}); authorized {
t.Error("should've returned false")
}
}
func TestAuthorizationService_WithToken(t *testing.T) {
authorizationService := NewAuthorizationService().WithToken("token")
if _, authorized := authorizationService.Authorize("token", nil); !authorized {
t.Error("should've returned true")
}
if _, authorized := authorizationService.Authorize("bad-token", nil); authorized {
t.Error("should've returned false")
}
if _, authorized := authorizationService.Authorize("token", []string{"admin"}); authorized {
t.Error("should've returned false")
}
}
func TestAuthorizationService_WithTokens(t *testing.T) {
authorizationService := NewAuthorizationService().WithTokens([]string{"1", "2"})
if _, authorized := authorizationService.Authorize("1", nil); !authorized {
t.Error("should've returned true")
}
if _, authorized := authorizationService.Authorize("2", nil); !authorized {
t.Error("should've returned true")
}
if _, authorized := authorizationService.Authorize("3", nil); authorized {
t.Error("should've returned false")
}
}
func TestAuthorizationService_WithClient(t *testing.T) {
authorizationService := NewAuthorizationService().WithClient(NewClient("token").WithPermissions([]string{"a", "b"}))
if _, authorized := authorizationService.Authorize("token", []string{"a", "b"}); !authorized {
t.Error("should've returned true")
}
if _, authorized := authorizationService.Authorize("token", []string{"a"}); !authorized {
t.Error("should've returned true")
}
if _, authorized := authorizationService.Authorize("token", []string{"b"}); !authorized {
t.Error("should've returned true")
}
if _, authorized := authorizationService.Authorize("token", []string{"c"}); authorized {
t.Error("should've returned false")
}
}
func TestAuthorizationService_WithClients(t *testing.T) {
authorizationService := NewAuthorizationService().WithClients([]*Client{NewClient("1").WithPermission("a"), NewClient("2").WithPermission("b")})
if _, authorized := authorizationService.Authorize("1", []string{"a"}); !authorized {
t.Error("should've returned true")
}
if _, authorized := authorizationService.Authorize("2", []string{"b"}); !authorized {
t.Error("should've returned true")
}
if _, authorized := authorizationService.Authorize("1", []string{"b"}); authorized {
t.Error("should've returned false")
}
if _, authorized := authorizationService.Authorize("2", []string{"a"}); authorized {
t.Error("should've returned false")
}
}

13
cache.go Normal file
View file

@ -0,0 +1,13 @@
package g8
import (
"github.com/TwiN/gocache/v2"
)
type Cache interface {
Get(key string) (value any, exists bool)
Set(key string, value any)
}
// Make sure that gocache.Cache is compatible with the interface
var _ Cache = (*gocache.Cache)(nil)

79
client.go Normal file
View file

@ -0,0 +1,79 @@
package g8
// Client is a struct containing both a Token and a slice of extra Permissions that said token has.
type Client struct {
// Token is the value used to authenticate with the API.
Token string
// Permissions is a slice of extra permissions that may be used for more granular access control.
//
// If you only wish to use Gate.Protect and Gate.ProtectFunc, you do not have to worry about this,
// since they're only used by Gate.ProtectWithPermissions and Gate.ProtectFuncWithPermissions
Permissions []string
// Data is a field that can be used to store any data you want to associate with the client.
Data any
}
// NewClient creates a Client with a given token
func NewClient(token string) *Client {
return &Client{
Token: token,
}
}
// NewClientWithPermissions creates a Client with a slice of permissions
// Equivalent to using NewClient and WithPermissions
func NewClientWithPermissions(token string, permissions []string) *Client {
return NewClient(token).WithPermissions(permissions)
}
// NewClientWithData creates a Client with some data
// Equivalent to using NewClient and WithData
func NewClientWithData(token string, data any) *Client {
return NewClient(token).WithData(data)
}
// NewClientWithPermissionsAndData creates a Client with a slice of permissions and some data
// Equivalent to using NewClient, WithPermissions and WithData
func NewClientWithPermissionsAndData(token string, permissions []string, data any) *Client {
return NewClient(token).WithPermissions(permissions).WithData(data)
}
// WithPermissions adds a slice of permissions to a client
func (client *Client) WithPermissions(permissions []string) *Client {
client.Permissions = append(client.Permissions, permissions...)
return client
}
// WithPermission adds a permission to a client
func (client *Client) WithPermission(permission string) *Client {
client.Permissions = append(client.Permissions, permission)
return client
}
// WithData attaches data to a client
func (client *Client) WithData(data any) *Client {
client.Data = data
return client
}
// HasPermission checks whether a client has a given permission
func (client *Client) HasPermission(permissionRequired string) bool {
for _, permission := range client.Permissions {
if permissionRequired == permission {
return true
}
}
return false
}
// HasPermissions checks whether a client has the all permissions passed
func (client *Client) HasPermissions(permissionsRequired []string) bool {
for _, permissionRequired := range permissionsRequired {
if !client.HasPermission(permissionRequired) {
return false
}
}
return true
}

76
client_test.go Normal file
View file

@ -0,0 +1,76 @@
package g8
import "testing"
func TestClient_HasPermission(t *testing.T) {
client := NewClientWithPermissions("token", []string{"a", "b"})
if !client.HasPermission("a") {
t.Errorf("client has permissions %s, therefore HasPermission(a) should've been true", client.Permissions)
}
if !client.HasPermission("b") {
t.Errorf("client has permissions %s, therefore HasPermission(b) should've been true", client.Permissions)
}
if client.HasPermission("c") {
t.Errorf("client has permissions %s, therefore HasPermission(c) should've been false", client.Permissions)
}
if client.HasPermission("ab") {
t.Errorf("client has permissions %s, therefore HasPermission(ab) should've been false", client.Permissions)
}
}
func TestClient_HasPermissions(t *testing.T) {
client := NewClientWithPermissions("token", []string{"a", "b"})
if !client.HasPermissions(nil) {
t.Errorf("client has permissions %s, therefore HasPermissions(nil) should've been true", client.Permissions)
}
if !client.HasPermissions([]string{"a"}) {
t.Errorf("client has permissions %s, therefore HasPermissions([a]) should've been true", client.Permissions)
}
if !client.HasPermissions([]string{"b"}) {
t.Errorf("client has permissions %s, therefore HasPermissions([b]) should've been true", client.Permissions)
}
if !client.HasPermissions([]string{"a", "b"}) {
t.Errorf("client has permissions %s, therefore HasPermissions([a, b]) should've been true", client.Permissions)
}
if client.HasPermissions([]string{"a", "b", "c"}) {
t.Errorf("client has permissions %s, therefore HasPermissions([a, b, c]) should've been false", client.Permissions)
}
}
func TestClient_WithData(t *testing.T) {
client := NewClient("token")
if client.Data != nil {
t.Error("expected client data to be nil")
}
client.WithData(5)
if client.Data != 5 {
t.Errorf("expected client data to be 5, got %d", client.Data)
}
client.WithData(map[string]string{"key": "value"})
if data, ok := client.Data.(map[string]string); !ok || data["key"] != "value" {
t.Errorf("expected client data to be map[string]string{key: value}, got %v", client.Data)
}
}
func TestNewClientWithData(t *testing.T) {
client := NewClientWithData("token", 5)
if client.Data != 5 {
t.Errorf("expected client data to be 5, got %d", client.Data)
}
}
func TestNewClientWithPermissionsAndData(t *testing.T) {
client := NewClientWithPermissionsAndData("token", []string{"a", "b"}, 5)
if client.Data != 5 {
t.Errorf("expected client data to be 5, got %d", client.Data)
}
if !client.HasPermission("a") {
t.Errorf("client has permissions %s, therefore HasPermission(a) should've been true", client.Permissions)
}
if !client.HasPermission("b") {
t.Errorf("client has permissions %s, therefore HasPermission(b) should've been true", client.Permissions)
}
if client.HasPermission("c") {
t.Errorf("client has permissions %s, therefore HasPermission(c) should've been false", client.Permissions)
}
}

102
clientprovider.go Normal file
View file

@ -0,0 +1,102 @@
package g8
import (
"time"
"github.com/TwiN/gocache/v2"
)
// ClientProvider has the task of retrieving a Client from an external source (e.g. a database) when provided with a
// token. It should be used when you have a lot of tokens, and it wouldn't make sense to register all of them using
// AuthorizationService's WithToken, WithTokens, WithClient or WithClients.
//
// Note that the provider is used as a fallback source. As such, if a token is explicitly registered using one of the 4
// aforementioned functions, the client provider will not be used by the AuthorizationService when a request is made
// with said token. It will, however, be called upon if a token that is not explicitly registered in
// AuthorizationService is sent alongside a request going through the Gate.
//
// clientProvider := g8.NewClientProvider(func(token string) *g8.Client {
// // We'll assume that the following function calls your database and returns a struct "User" that
// // has the user's token as well as the permissions granted to said user
// user := database.GetUserByToken(token)
// if user != nil {
// return g8.NewClient(user.Token).WithPermissions(user.Permissions)
// }
// return nil
// })
// gate := g8.New().WithAuthorizationService(g8.NewAuthorizationService().WithClientProvider(clientProvider))
type ClientProvider struct {
getClientByTokenFunc func(token string) *Client
cache Cache
}
// NewClientProvider creates a ClientProvider
// The parameter that must be passed is a function that the provider will use to retrieve a client by a given token
//
// Example:
//
// clientProvider := g8.NewClientProvider(func(token string) *g8.Client {
// // We'll assume that the following function calls your database and returns a struct "User" that
// // has the user's token as well as the permissions granted to said user
// user := database.GetUserByToken(token)
// if user == nil {
// return nil
// }
// return g8.NewClient(user.Token).WithPermissions(user.Permissions)
// })
// gate := g8.New().WithAuthorizationService(g8.NewAuthorizationService().WithClientProvider(clientProvider))
func NewClientProvider(getClientByTokenFunc func(token string) *Client) *ClientProvider {
return &ClientProvider{
getClientByTokenFunc: getClientByTokenFunc,
}
}
// WithCache enables an in-memory cache for the ClientProvider.
//
// Example:
//
// clientProvider := g8.NewClientProvider(func(token string) *g8.Client {
// // We'll assume that the following function calls your database and returns a struct "User" that
// // has the user's token as well as the permissions granted to said user
// user := database.GetUserByToken(token)
// if user != nil {
// return g8.NewClient(user.Token).WithPermissions(user.Permissions)
// }
// return nil
// })
// gate := g8.New().WithAuthorizationService(g8.NewAuthorizationService().WithClientProvider(clientProvider.WithCache(time.Hour, 70000)))
func (provider *ClientProvider) WithCache(ttl time.Duration, maxSize int) *ClientProvider {
return provider.WithCustomCache(
gocache.NewCache().WithEvictionPolicy(gocache.LeastRecentlyUsed).WithMaxSize(maxSize).WithDefaultTTL(ttl),
)
}
// WithCustomCache allows you to use a custom cache implementation instead of the default one.
// By default, using WithCache will leverage gocache.
//
// Note that the custom cache must implement the Cache interface
func (provider *ClientProvider) WithCustomCache(cache Cache) *ClientProvider {
provider.cache = cache
return provider
}
// GetClientByToken retrieves a client by its token through the provided getClientByTokenFunc.
func (provider *ClientProvider) GetClientByToken(token string) *Client {
if provider.cache == nil {
return provider.getClientByTokenFunc(token)
}
if cachedClient, exists := provider.cache.Get(token); exists {
if cachedClient == nil {
return nil
}
// Safely typecast the client.
// Regardless of whether the typecast is successful or not, we return client since it'll be either client or
// nil. Technically, it should never be nil, but it's better to be safe than sorry.
client, _ := cachedClient.(*Client)
return client
}
client := provider.getClientByTokenFunc(token)
provider.cache.Set(token, client)
return client
}

131
clientprovider_test.go Normal file
View file

@ -0,0 +1,131 @@
package g8
import (
"sync"
"testing"
"time"
"github.com/TwiN/gocache/v2"
)
var (
getClientByTokenFunc = func(token string) *Client {
if token == "valid-token" {
return NewClient("valid-token").WithData("client-data")
}
return nil
}
)
func TestClientProvider_GetClientByToken(t *testing.T) {
provider := NewClientProvider(getClientByTokenFunc)
if client := provider.GetClientByToken("valid-token"); client == nil {
t.Error("should've returned a client")
} else if client.Data != "client-data" {
t.Error("expected client data to be 'client-data', got", client.Data)
}
if client := provider.GetClientByToken("invalid-token"); client != nil {
t.Error("should've returned nil")
}
}
func TestClientProvider_WithCache(t *testing.T) {
provider := NewClientProvider(getClientByTokenFunc).WithCache(gocache.NoExpiration, 10000)
if provider.cache.(*gocache.Cache).Count() != 0 {
t.Error("expected cache to be empty")
}
if client := provider.GetClientByToken("valid-token"); client == nil {
t.Error("expected client, got nil")
}
if provider.cache.(*gocache.Cache).Count() != 1 {
t.Error("expected cache size to be 1")
}
if client := provider.GetClientByToken("valid-token"); client == nil {
t.Error("expected client, got nil")
}
if provider.cache.(*gocache.Cache).Count() != 1 {
t.Error("expected cache size to be 1")
}
if client := provider.GetClientByToken("invalid-token"); client != nil {
t.Error("expected nil, got", client)
}
if provider.cache.(*gocache.Cache).Count() != 2 {
t.Error("expected cache size to be 2")
}
if client := provider.GetClientByToken("invalid-token"); client != nil {
t.Error("expected nil, got", client)
}
if client := provider.GetClientByToken("invalid-token"); client != nil {
t.Error("should've returned nil (cached)")
}
}
func TestClientProvider_WithCacheAndExpiration(t *testing.T) {
provider := NewClientProvider(getClientByTokenFunc).WithCache(10*time.Millisecond, 10)
provider.GetClientByToken("token")
if provider.cache.(*gocache.Cache).Count() != 1 {
t.Error("expected cache size to be 1")
}
if provider.cache.(*gocache.Cache).Stats().ExpiredKeys != 0 {
t.Error("expected cache statistics to report 0 expired key")
}
time.Sleep(15 * time.Millisecond)
provider.GetClientByToken("token")
if provider.cache.(*gocache.Cache).Stats().ExpiredKeys != 1 {
t.Error("expected cache statistics to report 1 expired key")
}
}
type customCache struct {
entries map[string]any
sync.Mutex
}
func (c *customCache) Get(key string) (value any, exists bool) {
c.Lock()
v, exists := c.entries[key]
c.Unlock()
return v, exists
}
func (c *customCache) Set(key string, value any) {
c.Lock()
if c.entries == nil {
c.entries = make(map[string]any)
}
c.entries[key] = value
c.Unlock()
}
var _ Cache = (*customCache)(nil)
func TestClientProvider_WithCustomCache(t *testing.T) {
provider := NewClientProvider(getClientByTokenFunc).WithCustomCache(&customCache{})
if len(provider.cache.(*customCache).entries) != 0 {
t.Error("expected cache to be empty")
}
if client := provider.GetClientByToken("valid-token"); client == nil {
t.Error("expected client, got nil")
}
if len(provider.cache.(*customCache).entries) != 1 {
t.Error("expected cache size to be 1")
}
if client := provider.GetClientByToken("valid-token"); client == nil {
t.Error("expected client, got nil")
}
if len(provider.cache.(*customCache).entries) != 1 {
t.Error("expected cache size to be 1")
}
if client := provider.GetClientByToken("invalid-token"); client != nil {
t.Error("expected nil, got", client)
}
if len(provider.cache.(*customCache).entries) != 2 {
t.Error("expected cache size to be 2")
}
if client := provider.GetClientByToken("invalid-token"); client != nil {
t.Error("expected nil, got", client)
}
if client := provider.GetClientByToken("invalid-token"); client != nil {
t.Error("should've returned nil (cached)")
}
}

245
gate.go Normal file
View file

@ -0,0 +1,245 @@
package g8
import (
"context"
"net/http"
"strings"
)
const (
// AuthorizationHeader is the header in which g8 looks for the authorization bearer token
AuthorizationHeader = "Authorization"
// DefaultUnauthorizedResponseBody is the default response body returned if a request was sent with a missing or invalid token
DefaultUnauthorizedResponseBody = "token is missing or invalid"
// DefaultTooManyRequestsResponseBody is the default response body returned if a request exceeded the allowed rate limit
DefaultTooManyRequestsResponseBody = "too many requests"
// TokenContextKey is the key used to store the client's token in the context.
TokenContextKey = "g8.token"
// DataContextKey is the key used to store the client's data in the context.
DataContextKey = "g8.data"
)
// Gate is lock to the front door of your API, letting only those you allow through.
type Gate struct {
authorizationService *AuthorizationService
unauthorizedResponseBody []byte
customTokenExtractorFunc func(request *http.Request) string
rateLimiter *RateLimiter
tooManyRequestsResponseBody []byte
}
// Deprecated: use New instead.
func NewGate(authorizationService *AuthorizationService) *Gate {
return &Gate{
authorizationService: authorizationService,
unauthorizedResponseBody: []byte(DefaultUnauthorizedResponseBody),
tooManyRequestsResponseBody: []byte(DefaultTooManyRequestsResponseBody),
}
}
// New creates a new Gate.
func New() *Gate {
return &Gate{
unauthorizedResponseBody: []byte(DefaultUnauthorizedResponseBody),
tooManyRequestsResponseBody: []byte(DefaultTooManyRequestsResponseBody),
}
}
// WithAuthorizationService sets the authorization service to use.
//
// If there is no authorization service, Gate will not enforce authorization.
func (gate *Gate) WithAuthorizationService(authorizationService *AuthorizationService) *Gate {
gate.authorizationService = authorizationService
return gate
}
// WithCustomUnauthorizedResponseBody sets a custom response body when Gate determines that a request must be blocked
func (gate *Gate) WithCustomUnauthorizedResponseBody(unauthorizedResponseBody []byte) *Gate {
gate.unauthorizedResponseBody = unauthorizedResponseBody
return gate
}
// WithCustomTokenExtractor allows the specification of a custom function to extract a token from a request.
// If a custom token extractor is not specified, the token will be extracted from the Authorization header.
//
// For instance, if you're using a session cookie, you can extract the token from the cookie like so:
//
// authorizationService := g8.NewAuthorizationService()
// customTokenExtractorFunc := func(request *http.Request) string {
// sessionCookie, err := request.Cookie("session")
// if err != nil {
// return ""
// }
// return sessionCookie.Value
// }
// gate := g8.New().WithAuthorizationService(authorizationService).WithCustomTokenExtractor(customTokenExtractorFunc)
//
// You would normally use this with a client provider that matches whatever need you have.
// For example, if you're using a session cookie, your client provider would retrieve the user from the session ID
// extracted by this custom token extractor.
//
// Note that for the sake of convenience, the token extracted from the request is passed the protected handlers request
// context under the key TokenContextKey. This is especially useful if the token is in fact a session ID.
func (gate *Gate) WithCustomTokenExtractor(customTokenExtractorFunc func(request *http.Request) string) *Gate {
gate.customTokenExtractorFunc = customTokenExtractorFunc
return gate
}
// WithRateLimit adds rate limiting to the Gate
//
// If you just want to use a gate for rate limiting purposes:
//
// gate := g8.New().WithRateLimit(50)
func (gate *Gate) WithRateLimit(maximumRequestsPerSecond int) *Gate {
gate.rateLimiter = NewRateLimiter(maximumRequestsPerSecond)
return gate
}
// Protect secures a handler, requiring requests going through to have a valid Authorization Bearer token.
// Unlike ProtectWithPermissions, Protect will allow access to any registered tokens, regardless of their permissions
// or lack thereof.
//
// Example:
//
// gate := g8.New().WithAuthorizationService(g8.NewAuthorizationService().WithToken("token"))
// router := http.NewServeMux()
// // Without protection
// router.Handle("/handle", yourHandler)
// // With protection
// router.Handle("/handle", gate.Protect(yourHandler))
//
// The token extracted from the request is passed to the handlerFunc request context under the key TokenContextKey
func (gate *Gate) Protect(handler http.Handler) http.Handler {
return gate.ProtectWithPermissions(handler, nil)
}
// ProtectWithPermissions secures a handler, requiring requests going through to have a valid Authorization Bearer token
// as well as a slice of permissions that must be met.
//
// Example:
//
// gate := g8.New().WithAuthorizationService(g8.NewAuthorizationService().WithClient(g8.NewClient("token").WithPermission("ADMIN")))
// router := http.NewServeMux()
// // Without protection
// router.Handle("/handle", yourHandler)
// // With protection
// router.Handle("/handle", gate.ProtectWithPermissions(yourHandler, []string{"admin"}))
//
// The token extracted from the request is passed to the handlerFunc request context under the key TokenContextKey
func (gate *Gate) ProtectWithPermissions(handler http.Handler, permissions []string) http.Handler {
return gate.ProtectFuncWithPermissions(func(writer http.ResponseWriter, request *http.Request) {
handler.ServeHTTP(writer, request)
}, permissions)
}
// ProtectWithPermission does the same thing as ProtectWithPermissions, but for a single permission instead of a
// slice of permissions
//
// See ProtectWithPermissions for further documentation
func (gate *Gate) ProtectWithPermission(handler http.Handler, permission string) http.Handler {
return gate.ProtectFuncWithPermissions(func(writer http.ResponseWriter, request *http.Request) {
handler.ServeHTTP(writer, request)
}, []string{permission})
}
// ProtectFunc secures a handlerFunc, requiring requests going through to have a valid Authorization Bearer token.
// Unlike ProtectFuncWithPermissions, ProtectFunc will allow access to any registered tokens, regardless of their
// permissions or lack thereof.
//
// Example:
//
// gate := g8.New().WithAuthorizationService(g8.NewAuthorizationService().WithToken("token"))
// router := http.NewServeMux()
// // Without protection
// router.HandleFunc("/handle", yourHandlerFunc)
// // With protection
// router.HandleFunc("/handle", gate.ProtectFunc(yourHandlerFunc))
//
// The token extracted from the request is passed to the handlerFunc request context under the key TokenContextKey
func (gate *Gate) ProtectFunc(handlerFunc http.HandlerFunc) http.HandlerFunc {
return gate.ProtectFuncWithPermissions(handlerFunc, nil)
}
// ProtectFuncWithPermissions secures a handler, requiring requests going through to have a valid Authorization Bearer
// token as well as a slice of permissions that must be met.
//
// Example:
//
// gate := g8.New().WithAuthorizationService(g8.NewAuthorizationService().WithClient(g8.NewClient("token").WithPermission("admin")))
// router := http.NewServeMux()
// // Without protection
// router.HandleFunc("/handle", yourHandlerFunc)
// // With protection
// router.HandleFunc("/handle", gate.ProtectFuncWithPermissions(yourHandlerFunc, []string{"admin"}))
//
// The token extracted from the request is passed to the handlerFunc request context under the key TokenContextKey
func (gate *Gate) ProtectFuncWithPermissions(handlerFunc http.HandlerFunc, permissions []string) http.HandlerFunc {
return func(writer http.ResponseWriter, request *http.Request) {
if gate.rateLimiter != nil {
if !gate.rateLimiter.Try() {
writer.WriteHeader(http.StatusTooManyRequests)
_, _ = writer.Write(gate.tooManyRequestsResponseBody)
return
}
}
if gate.authorizationService != nil {
token := gate.ExtractTokenFromRequest(request)
if client, authorized := gate.authorizationService.Authorize(token, permissions); !authorized {
writer.WriteHeader(http.StatusUnauthorized)
_, _ = writer.Write(gate.unauthorizedResponseBody)
return
} else {
request = request.WithContext(context.WithValue(request.Context(), TokenContextKey, token))
if client != nil && client.Data != nil {
request = request.WithContext(context.WithValue(request.Context(), DataContextKey, client.Data))
}
}
}
handlerFunc(writer, request)
}
}
// ProtectFuncWithPermission does the same thing as ProtectFuncWithPermissions, but for a single permission instead of a
// slice of permissions
//
// See ProtectFuncWithPermissions for further documentation
func (gate *Gate) ProtectFuncWithPermission(handlerFunc http.HandlerFunc, permission string) http.HandlerFunc {
return gate.ProtectFuncWithPermissions(handlerFunc, []string{permission})
}
// ExtractTokenFromRequest extracts a token from a request.
//
// By default, it extracts the bearer token from the AuthorizationHeader, but if a customTokenExtractorFunc is defined,
// it will use that instead.
//
// Note that this method is internally used by Protect, ProtectWithPermission, ProtectFunc and
// ProtectFuncWithPermissions, but it is exposed in case you need to use it directly.
func (gate *Gate) ExtractTokenFromRequest(request *http.Request) string {
if gate.customTokenExtractorFunc != nil {
// A custom token extractor function is defined, so we'll use it instead of the default token extraction logic
return gate.customTokenExtractorFunc(request)
}
return strings.TrimPrefix(request.Header.Get(AuthorizationHeader), "Bearer ")
}
// PermissionMiddleware is a middleware that behaves like ProtectWithPermission, but it is meant to be used
// as a middleware for libraries that support such a feature.
//
// For instance, if you are using github.com/gorilla/mux, you can use PermissionMiddleware like so:
//
// router := mux.NewRouter()
// router.Use(gate.PermissionMiddleware("admin"))
// router.Handle("/admin/handle", adminHandler)
//
// If you do not want to protect a router with a specific permission, you can use Gate.Protect instead.
func (gate *Gate) PermissionMiddleware(permissions ...string) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return gate.ProtectWithPermissions(next, permissions)
}
}

208
gate_bench_test.go Normal file
View file

@ -0,0 +1,208 @@
package g8
import (
"fmt"
"net/http"
"net/http/httptest"
"testing"
)
var handler http.Handler = &testHandler{}
func BenchmarkTestHandler(b *testing.B) {
request, _ := http.NewRequest("GET", "/handle", nil)
router := http.NewServeMux()
router.Handle("/handle", handler)
for n := 0; n < b.N; n++ {
responseRecorder := httptest.NewRecorder()
router.ServeHTTP(responseRecorder, request)
if responseRecorder.Code != http.StatusOK {
b.Fatalf("%s %s should have returned %d, but returned %d instead", request.Method, request.URL, http.StatusOK, responseRecorder.Code)
}
}
b.ReportAllocs()
}
func BenchmarkGate_ProtectWhenNoAuthorizationHeader(b *testing.B) {
gate := New().WithAuthorizationService(NewAuthorizationService().WithToken("good-token"))
request, _ := http.NewRequest("GET", "/handle", nil)
router := http.NewServeMux()
router.Handle("/handle", gate.Protect(handler))
for n := 0; n < b.N; n++ {
responseRecorder := httptest.NewRecorder()
router.ServeHTTP(responseRecorder, request)
if responseRecorder.Code != http.StatusUnauthorized {
b.Fatalf("%s %s should have returned %d, but returned %d instead", request.Method, request.URL, http.StatusUnauthorized, responseRecorder.Code)
}
}
b.ReportAllocs()
}
func BenchmarkGate_ProtectWithInvalidToken(b *testing.B) {
gate := New().WithAuthorizationService(NewAuthorizationService().WithToken("good-token"))
request, _ := http.NewRequest("GET", "/handle", nil)
request.Header.Add("Authorization", fmt.Sprintf("Bearer %s", "bad-token"))
router := http.NewServeMux()
router.Handle("/handle", gate.Protect(handler))
for n := 0; n < b.N; n++ {
responseRecorder := httptest.NewRecorder()
router.ServeHTTP(responseRecorder, request)
if responseRecorder.Code != http.StatusUnauthorized {
b.Fatalf("%s %s should have returned %d, but returned %d instead", request.Method, request.URL, http.StatusUnauthorized, responseRecorder.Code)
}
}
b.ReportAllocs()
}
func BenchmarkGate_ProtectWithValidToken(b *testing.B) {
gate := New().WithAuthorizationService(NewAuthorizationService().WithToken("good-token"))
request, _ := http.NewRequest("GET", "/handle", http.NoBody)
request.Header.Add("Authorization", fmt.Sprintf("Bearer %s", "good-token"))
router := http.NewServeMux()
router.Handle("/handle", gate.Protect(handler))
for n := 0; n < b.N; n++ {
responseRecorder := httptest.NewRecorder()
router.ServeHTTP(responseRecorder, request)
if responseRecorder.Code != http.StatusOK {
b.Fatalf("%s %s should have returned %d, but returned %d instead", request.Method, request.URL, http.StatusOK, responseRecorder.Code)
}
}
b.ReportAllocs()
}
func BenchmarkGate_ProtectWithPermissionsAndValidToken(b *testing.B) {
gate := New().WithAuthorizationService(NewAuthorizationService().WithClient(NewClient("token").WithPermission("admin")))
request, _ := http.NewRequest("GET", "/handle", http.NoBody)
request.Header.Add("Authorization", fmt.Sprintf("Bearer %s", "token"))
router := http.NewServeMux()
router.Handle("/handle", gate.ProtectWithPermissions(handler, []string{"admin"}))
for n := 0; n < b.N; n++ {
responseRecorder := httptest.NewRecorder()
router.ServeHTTP(responseRecorder, request)
if responseRecorder.Code != http.StatusOK {
b.Fatalf("%s %s should have returned %d, but returned %d instead", request.Method, request.URL, http.StatusOK, responseRecorder.Code)
}
}
b.ReportAllocs()
}
func BenchmarkGate_ProtectWithPermissionsAndValidTokenButInsufficientPermissions(b *testing.B) {
gate := New().WithAuthorizationService(NewAuthorizationService().WithClient(NewClient("token").WithPermission("mod")))
request, _ := http.NewRequest("GET", "/handle", http.NoBody)
request.Header.Add("Authorization", fmt.Sprintf("Bearer %s", "token"))
router := http.NewServeMux()
router.Handle("/handle", gate.ProtectWithPermissions(handler, []string{"admin"}))
for n := 0; n < b.N; n++ {
responseRecorder := httptest.NewRecorder()
router.ServeHTTP(responseRecorder, request)
if responseRecorder.Code != http.StatusUnauthorized {
b.Fatalf("%s %s should have returned %d, but returned %d instead", request.Method, request.URL, http.StatusUnauthorized, responseRecorder.Code)
}
}
b.ReportAllocs()
}
func BenchmarkGate_ProtectConcurrently(b *testing.B) {
gate := New().WithAuthorizationService(NewAuthorizationService().WithToken("good-token"))
request, _ := http.NewRequest("GET", "/handle", http.NoBody)
request.Header.Add("Authorization", fmt.Sprintf("Bearer %s", "good-token"))
badRequest, _ := http.NewRequest("GET", "/handle", http.NoBody)
badRequest.Header.Add("Authorization", fmt.Sprintf("Bearer %s", "bad-token"))
router := http.NewServeMux()
router.Handle("/handle", gate.Protect(handler))
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
responseRecorder := httptest.NewRecorder()
router.ServeHTTP(responseRecorder, request)
if responseRecorder.Code != http.StatusOK {
b.Fatalf("%s %s should have returned %d, but returned %d instead", request.Method, request.URL, http.StatusOK, responseRecorder.Code)
}
responseRecorder = httptest.NewRecorder()
router.ServeHTTP(responseRecorder, badRequest)
if responseRecorder.Code != http.StatusUnauthorized {
b.Fatalf("%s %s should have returned %d, but returned %d instead", badRequest.Method, badRequest.URL, http.StatusUnauthorized, responseRecorder.Code)
}
}
})
b.ReportAllocs()
}
func BenchmarkGate_ProtectWithClientProviderConcurrently(b *testing.B) {
gate := New().WithAuthorizationService(NewAuthorizationService().WithClientProvider(mockClientProvider))
request, _ := http.NewRequest("GET", "/handle", http.NoBody)
request.Header.Add("Authorization", fmt.Sprintf("Bearer %s", TestProviderClientToken))
firstBadRequest, _ := http.NewRequest("GET", "/handle", http.NoBody)
firstBadRequest.Header.Set("Authorization", fmt.Sprintf("Bearer %s", "bad-token-1"))
secondBadRequest, _ := http.NewRequest("GET", "/handle", http.NoBody)
secondBadRequest.Header.Set("Authorization", fmt.Sprintf("Bearer %s", "bad-token-2"))
router := http.NewServeMux()
router.Handle("/handle", gate.Protect(handler))
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
responseRecorder := httptest.NewRecorder()
router.ServeHTTP(responseRecorder, request)
if responseRecorder.Code != http.StatusOK {
b.Fatalf("%s %s should have returned %d, but returned %d instead", request.Method, request.URL, http.StatusOK, responseRecorder.Code)
}
responseRecorder = httptest.NewRecorder()
router.ServeHTTP(responseRecorder, firstBadRequest)
if responseRecorder.Code != http.StatusUnauthorized {
b.Fatalf("%s %s should have returned %d, but returned %d instead", firstBadRequest.Method, firstBadRequest.URL, http.StatusUnauthorized, responseRecorder.Code)
}
responseRecorder = httptest.NewRecorder()
router.ServeHTTP(responseRecorder, secondBadRequest)
if responseRecorder.Code != http.StatusUnauthorized {
b.Fatalf("%s %s should have returned %d, but returned %d instead", secondBadRequest.Method, secondBadRequest.URL, http.StatusUnauthorized, responseRecorder.Code)
}
}
})
b.ReportAllocs()
}
func BenchmarkGate_ProtectWithValidTokenAndCustomTokenExtractorFuncConcurrently(b *testing.B) {
customTokenExtractorFunc := func(request *http.Request) string {
sessionCookie, err := request.Cookie("session")
if err != nil {
return ""
}
return sessionCookie.Value
}
gate := New().WithAuthorizationService(NewAuthorizationService().WithToken("good-token")).WithCustomTokenExtractor(customTokenExtractorFunc)
request, _ := http.NewRequest("GET", "/handle", http.NoBody)
request.AddCookie(&http.Cookie{Name: "session", Value: "good-token"})
router := http.NewServeMux()
router.Handle("/handle", gate.Protect(handler))
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
responseRecorder := httptest.NewRecorder()
router.ServeHTTP(responseRecorder, request)
if responseRecorder.Code != http.StatusOK {
b.Fatalf("%s %s should have returned %d, but returned %d instead", request.Method, request.URL, http.StatusOK, responseRecorder.Code)
}
}
})
b.ReportAllocs()
}

538
gate_test.go Normal file
View file

@ -0,0 +1,538 @@
package g8
import (
"fmt"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
)
const (
FirstTestProviderClientPermission = "permission-1"
SecondTestProviderClientPermission = "permission-2"
TestProviderClientToken = "client-token-from-provider"
TestProviderClientData = "client-data-from-provider"
)
var (
mockClientProvider = NewClientProvider(func(token string) *Client {
// We'll pretend that there's only one token that's valid in the client provider, every other token
// returns nil
if token == TestProviderClientToken {
return &Client{
Token: TestProviderClientToken,
Data: TestProviderClientData,
Permissions: []string{FirstTestProviderClientPermission, SecondTestProviderClientPermission},
}
}
return nil
})
)
type testHandler struct {
}
func (handler *testHandler) ServeHTTP(writer http.ResponseWriter, _ *http.Request) {
writer.WriteHeader(http.StatusOK)
}
func testHandlerFunc(writer http.ResponseWriter, _ *http.Request) {
writer.WriteHeader(http.StatusOK)
}
func TestUsability(t *testing.T) {
gate := New().WithAuthorizationService(NewAuthorizationService().WithToken("good-token"))
var handler http.Handler = &testHandler{}
handlerFunc := func(writer http.ResponseWriter, request *http.Request) {
writer.WriteHeader(http.StatusOK)
}
router := http.NewServeMux()
router.Handle("/handle", handler)
router.Handle("/handle-protected", gate.Protect(handler))
router.HandleFunc("/handlefunc", handlerFunc)
router.HandleFunc("/handlefunc-protected", gate.ProtectFunc(handlerFunc))
}
func TestNewGate(t *testing.T) {
gate := NewGate(nil)
if gate == nil {
t.Error("gate should not be nil")
}
}
func TestUnprotectedHandler(t *testing.T) {
request, _ := http.NewRequest("GET", "/handle", http.NoBody)
request.Header.Add("Authorization", fmt.Sprintf("Bearer %s", "bad-token"))
responseRecorder := httptest.NewRecorder()
router := http.NewServeMux()
router.Handle("/handle", &testHandler{})
router.ServeHTTP(responseRecorder, request)
if responseRecorder.Code != http.StatusOK {
t.Errorf("%s %s should have returned %d, but returned %d instead", request.Method, request.URL, http.StatusOK, responseRecorder.Code)
}
}
func TestGate_ProtectWithInvalidToken(t *testing.T) {
gate := New().WithAuthorizationService(NewAuthorizationService().WithToken("good-token"))
request, _ := http.NewRequest("GET", "/handle", http.NoBody)
request.Header.Add("Authorization", fmt.Sprintf("Bearer %s", "bad-token"))
responseRecorder := httptest.NewRecorder()
router := http.NewServeMux()
router.Handle("/handle", gate.Protect(&testHandler{}))
router.ServeHTTP(responseRecorder, request)
if responseRecorder.Code != http.StatusUnauthorized {
t.Errorf("%s %s should have returned %d, but returned %d instead", request.Method, request.URL, http.StatusUnauthorized, responseRecorder.Code)
}
}
func TestGate_ProtectWithValidToken(t *testing.T) {
gate := New().WithAuthorizationService(NewAuthorizationService().WithToken("good-token"))
request, _ := http.NewRequest("GET", "/handle", http.NoBody)
request.Header.Add("Authorization", fmt.Sprintf("Bearer %s", "good-token"))
responseRecorder := httptest.NewRecorder()
router := http.NewServeMux()
router.Handle("/handle", gate.Protect(&testHandler{}))
router.ServeHTTP(responseRecorder, request)
if responseRecorder.Code != http.StatusOK {
t.Errorf("%s %s should have returned %d, but returned %d instead", request.Method, request.URL, http.StatusOK, responseRecorder.Code)
}
}
func TestGate_ProtectMultipleTimes(t *testing.T) {
gate := New().WithAuthorizationService(NewAuthorizationService().WithToken("good-token"))
request, _ := http.NewRequest("GET", "/handle", http.NoBody)
request.Header.Add("Authorization", fmt.Sprintf("Bearer %s", "good-token"))
badRequest, _ := http.NewRequest("GET", "/handle", http.NoBody)
badRequest.Header.Add("Authorization", fmt.Sprintf("Bearer %s", "bad-token"))
router := http.NewServeMux()
router.Handle("/handle", gate.Protect(&testHandler{}))
for i := 0; i < 100; i++ {
responseRecorder := httptest.NewRecorder()
router.ServeHTTP(responseRecorder, request)
if responseRecorder.Code != http.StatusOK {
t.Errorf("%s %s should have returned %d, but returned %d instead", request.Method, request.URL, http.StatusOK, responseRecorder.Code)
}
responseRecorder = httptest.NewRecorder()
router.ServeHTTP(responseRecorder, badRequest)
if responseRecorder.Code != http.StatusUnauthorized {
t.Errorf("%s %s should have returned %d, but returned %d instead", badRequest.Method, badRequest.URL, http.StatusOK, responseRecorder.Code)
}
}
}
func TestGate_ProtectWithValidTokenExposedThroughClientProvider(t *testing.T) {
gate := New().WithAuthorizationService(NewAuthorizationService().WithClientProvider(mockClientProvider))
request, _ := http.NewRequest("GET", "/handle", http.NoBody)
request.Header.Add("Authorization", fmt.Sprintf("Bearer %s", TestProviderClientToken))
responseRecorder := httptest.NewRecorder()
router := http.NewServeMux()
router.Handle("/handle", gate.Protect(&testHandler{}))
router.ServeHTTP(responseRecorder, request)
if responseRecorder.Code != http.StatusOK {
t.Errorf("%s %s should have returned %d, but returned %d instead", request.Method, request.URL, http.StatusOK, responseRecorder.Code)
}
}
func TestGate_ProtectWithValidTokenExposedThroughClientProviderWithCache(t *testing.T) {
gate := New().WithAuthorizationService(NewAuthorizationService().WithClientProvider(mockClientProvider.WithCache(60*time.Minute, 70000)))
request, _ := http.NewRequest("GET", "/handle", http.NoBody)
request.Header.Add("Authorization", fmt.Sprintf("Bearer %s", TestProviderClientToken))
responseRecorder := httptest.NewRecorder()
router := http.NewServeMux()
router.Handle("/handle", gate.Protect(&testHandler{}))
router.ServeHTTP(responseRecorder, request)
if responseRecorder.Code != http.StatusOK {
t.Errorf("%s %s should have returned %d, but returned %d instead", request.Method, request.URL, http.StatusOK, responseRecorder.Code)
}
}
func TestGate_ProtectWithInvalidTokenWhenUsingClientProvider(t *testing.T) {
gate := New().WithAuthorizationService(NewAuthorizationService().WithClientProvider(mockClientProvider))
request, _ := http.NewRequest("GET", "/handle", http.NoBody)
request.Header.Add("Authorization", fmt.Sprintf("Bearer %s", "bad-token"))
responseRecorder := httptest.NewRecorder()
router := http.NewServeMux()
router.Handle("/handle", gate.Protect(&testHandler{}))
router.ServeHTTP(responseRecorder, request)
if responseRecorder.Code != http.StatusUnauthorized {
t.Errorf("%s %s should have returned %d, but returned %d instead", request.Method, request.URL, http.StatusUnauthorized, responseRecorder.Code)
}
}
func TestGate_ProtectWithPermissionsWhenValidTokenAndSufficientPermissionsWhileUsingClientProvider(t *testing.T) {
gate := New().WithAuthorizationService(NewAuthorizationService().WithClientProvider(mockClientProvider))
request, _ := http.NewRequest("GET", "/handle", http.NoBody)
request.Header.Add("Authorization", fmt.Sprintf("Bearer %s", TestProviderClientToken))
responseRecorder := httptest.NewRecorder()
router := http.NewServeMux()
router.Handle("/handle", gate.ProtectWithPermissions(&testHandler{}, []string{SecondTestProviderClientPermission}))
router.ServeHTTP(responseRecorder, request)
// Since the client returned from the mockClientProvider has FirstTestProviderClientPermission and
// SecondTestProviderClientPermission and the testHandler is protected by SecondTestProviderClientPermission,
// the request should be authorized
if responseRecorder.Code != http.StatusOK {
t.Errorf("%s %s should have returned %d, but returned %d instead", request.Method, request.URL, http.StatusOK, responseRecorder.Code)
}
}
func TestGate_ProtectWithPermissionsWhenValidTokenAndInsufficientPermissionsWhileUsingClientProvider(t *testing.T) {
gate := New().WithAuthorizationService(NewAuthorizationService().WithClientProvider(mockClientProvider))
request, _ := http.NewRequest("GET", "/handle", http.NoBody)
request.Header.Add("Authorization", fmt.Sprintf("Bearer %s", TestProviderClientToken))
responseRecorder := httptest.NewRecorder()
router := http.NewServeMux()
router.Handle("/handle", gate.ProtectWithPermissions(&testHandler{}, []string{"unrelated-permission"}))
router.ServeHTTP(responseRecorder, request)
// Since the client returned from the mockClientProvider has FirstTestProviderClientPermission and
// SecondTestProviderClientPermission and the testHandler is protected by a permission that the client does not
// have, the request should be not be authorized
if responseRecorder.Code != http.StatusUnauthorized {
t.Errorf("%s %s should have returned %d, but returned %d instead", request.Method, request.URL, http.StatusUnauthorized, responseRecorder.Code)
}
}
func TestGate_ProtectWithPermissionsWhenClientHasSufficientPermissions(t *testing.T) {
gate := New().WithAuthorizationService(NewAuthorizationService().WithClient(NewClient("token").WithPermission("admin")))
request, _ := http.NewRequest("GET", "/handle", http.NoBody)
request.Header.Add("Authorization", fmt.Sprintf("Bearer %s", "token"))
responseRecorder := httptest.NewRecorder()
router := http.NewServeMux()
router.Handle("/handle", gate.ProtectWithPermissions(&testHandler{}, []string{"admin"}))
router.ServeHTTP(responseRecorder, request)
// Since the client registered directly in the AuthorizationService has the permission "admin" and the testHandler
// is protected by the permission "admin", the request should be authorized
if responseRecorder.Code != http.StatusOK {
t.Errorf("%s %s should have returned %d, but returned %d instead", request.Method, request.URL, http.StatusOK, responseRecorder.Code)
}
}
func TestGate_ProtectWithPermissionsWhenClientHasInsufficientPermissions(t *testing.T) {
gate := New().WithAuthorizationService(NewAuthorizationService().WithClient(NewClientWithPermissions("token", []string{"mod"})))
request, _ := http.NewRequest("GET", "/handle", http.NoBody)
request.Header.Add("Authorization", fmt.Sprintf("Bearer %s", "token"))
responseRecorder := httptest.NewRecorder()
router := http.NewServeMux()
router.Handle("/handle", gate.ProtectWithPermissions(&testHandler{}, []string{"admin"}))
router.ServeHTTP(responseRecorder, request)
// Since the client registered directly in the AuthorizationService has the permission "mod" and the
// testHandler is protected by the permission "admin", the request should be not be authorized
if responseRecorder.Code != http.StatusUnauthorized {
t.Errorf("%s %s should have returned %d, but returned %d instead", request.Method, request.URL, http.StatusUnauthorized, responseRecorder.Code)
}
}
func TestGate_ProtectWithPermissions(t *testing.T) {
gate := New().WithAuthorizationService(NewAuthorizationService().WithClient(NewClient("mytoken").WithPermissions([]string{"create", "read", "update", "delete"})))
router := http.NewServeMux()
router.Handle("/create", gate.ProtectWithPermissions(&testHandler{}, []string{"create"}))
router.Handle("/read", gate.ProtectWithPermissions(&testHandler{}, []string{"read"}))
router.Handle("/update", gate.ProtectWithPermissions(&testHandler{}, []string{"update"}))
router.Handle("/delete", gate.ProtectWithPermissions(&testHandler{}, []string{"delete"}))
router.Handle("/crud", gate.ProtectWithPermissions(&testHandler{}, []string{"create", "read", "update", "delete"}))
router.Handle("/backup", gate.ProtectWithPermissions(&testHandler{}, []string{"read", "backup"}))
checkRouterOutput := func(t *testing.T, router *http.ServeMux, url string, expectedResponseCode int) {
t.Run(strings.TrimPrefix(url, "/"), func(t *testing.T) {
request, _ := http.NewRequest("GET", url, http.NoBody)
request.Header.Add("Authorization", fmt.Sprintf("Bearer %s", "mytoken"))
responseRecorder := httptest.NewRecorder()
router.ServeHTTP(responseRecorder, request)
if responseRecorder.Code != expectedResponseCode {
t.Errorf("%s %s should have returned %d, but returned %d instead", request.Method, request.URL, expectedResponseCode, responseRecorder.Code)
}
})
}
checkRouterOutput(t, router, "/create", http.StatusOK)
checkRouterOutput(t, router, "/read", http.StatusOK)
checkRouterOutput(t, router, "/update", http.StatusOK)
checkRouterOutput(t, router, "/delete", http.StatusOK)
checkRouterOutput(t, router, "/crud", http.StatusOK)
checkRouterOutput(t, router, "/backup", http.StatusUnauthorized)
}
func TestGate_ProtectWithPermissionWhenClientHasSufficientPermissions(t *testing.T) {
gate := New().WithAuthorizationService(NewAuthorizationService().WithClient(NewClient("token").WithPermission("admin")))
request, _ := http.NewRequest("GET", "/handle", http.NoBody)
request.Header.Add("Authorization", fmt.Sprintf("Bearer %s", "token"))
responseRecorder := httptest.NewRecorder()
router := http.NewServeMux()
router.Handle("/handle", gate.ProtectWithPermission(&testHandler{}, "admin"))
router.ServeHTTP(responseRecorder, request)
// Since the client registered directly in the AuthorizationService has the permission "admin" and the testHandler
// is protected by the permission "admin", the request should be authorized
if responseRecorder.Code != http.StatusOK {
t.Errorf("%s %s should have returned %d, but returned %d instead", request.Method, request.URL, http.StatusOK, responseRecorder.Code)
}
}
func TestGate_ProtectWithPermissionWhenClientHasInsufficientPermissions(t *testing.T) {
gate := New().WithAuthorizationService(NewAuthorizationService().WithClient(NewClientWithPermissions("token", []string{"mod"})))
request, _ := http.NewRequest("GET", "/handle", http.NoBody)
request.Header.Add("Authorization", fmt.Sprintf("Bearer %s", "token"))
responseRecorder := httptest.NewRecorder()
router := http.NewServeMux()
router.Handle("/handle", gate.ProtectWithPermission(&testHandler{}, "admin"))
router.ServeHTTP(responseRecorder, request)
// Since the client registered directly in the AuthorizationService has the permission "mod" and the
// testHandler is protected by the permission "admin", the request should be not be authorized
if responseRecorder.Code != http.StatusUnauthorized {
t.Errorf("%s %s should have returned %d, but returned %d instead", request.Method, request.URL, http.StatusUnauthorized, responseRecorder.Code)
}
}
func TestGate_PermissionMiddlewareWhenClientHasSufficientPermissions(t *testing.T) {
gate := New().WithAuthorizationService(NewAuthorizationService().WithClient(NewClient("token").WithPermission("admin")))
request, _ := http.NewRequest("GET", "/handle", http.NoBody)
request.Header.Add("Authorization", fmt.Sprintf("Bearer %s", "token"))
responseRecorder := httptest.NewRecorder()
router := http.NewServeMux()
router.Handle("/handle", gate.PermissionMiddleware("admin")(&testHandler{}))
router.ServeHTTP(responseRecorder, request)
// Since the client registered directly in the AuthorizationService has the permission "admin" and the testHandler
// is protected by the permission "admin", the request should be authorized
if responseRecorder.Code != http.StatusOK {
t.Errorf("%s %s should have returned %d, but returned %d instead", request.Method, request.URL, http.StatusOK, responseRecorder.Code)
}
}
func TestGate_PermissionMiddlewareWhenClientHasInsufficientPermissions(t *testing.T) {
gate := New().WithAuthorizationService(NewAuthorizationService().WithClient(NewClientWithPermissions("token", []string{"mod"})))
request, _ := http.NewRequest("GET", "/handle", http.NoBody)
request.Header.Add("Authorization", fmt.Sprintf("Bearer %s", "token"))
responseRecorder := httptest.NewRecorder()
router := http.NewServeMux()
router.Handle("/handle", gate.PermissionMiddleware("admin")(&testHandler{}))
router.ServeHTTP(responseRecorder, request)
// Since the client registered directly in the AuthorizationService has the permission "mod" and the
// testHandler is protected by the permission "admin", the request should be not be authorized
if responseRecorder.Code != http.StatusUnauthorized {
t.Errorf("%s %s should have returned %d, but returned %d instead", request.Method, request.URL, http.StatusUnauthorized, responseRecorder.Code)
}
}
func TestGate_ProtectFuncWithInvalidToken(t *testing.T) {
gate := New().WithAuthorizationService(NewAuthorizationService().WithToken("good-token"))
request, _ := http.NewRequest("GET", "/handle", http.NoBody)
request.Header.Add("Authorization", fmt.Sprintf("Bearer %s", "bad-token"))
responseRecorder := httptest.NewRecorder()
router := http.NewServeMux()
router.Handle("/handle", gate.ProtectFunc(testHandlerFunc))
router.ServeHTTP(responseRecorder, request)
if responseRecorder.Code != http.StatusUnauthorized {
t.Errorf("%s %s should have returned %d, but returned %d instead", request.Method, request.URL, http.StatusUnauthorized, responseRecorder.Code)
}
}
func TestGate_ProtectFuncWithValidToken(t *testing.T) {
gate := New().WithAuthorizationService(NewAuthorizationService().WithToken("good-token"))
request, _ := http.NewRequest("GET", "/handle", http.NoBody)
request.Header.Add("Authorization", fmt.Sprintf("Bearer %s", "good-token"))
responseRecorder := httptest.NewRecorder()
router := http.NewServeMux()
router.Handle("/handle", gate.ProtectFunc(testHandlerFunc))
router.ServeHTTP(responseRecorder, request)
if responseRecorder.Code != http.StatusOK {
t.Errorf("%s %s should have returned %d, but returned %d instead", request.Method, request.URL, http.StatusOK, responseRecorder.Code)
}
}
func TestGate_ProtectFuncWithPermissionWhenClientHasSufficientPermissions(t *testing.T) {
gate := New().WithAuthorizationService(NewAuthorizationService().WithClient(NewClient("token").WithPermission("admin")))
request, _ := http.NewRequest("GET", "/handle", http.NoBody)
request.Header.Add("Authorization", fmt.Sprintf("Bearer %s", "token"))
responseRecorder := httptest.NewRecorder()
router := http.NewServeMux()
router.HandleFunc("/handle", gate.ProtectFuncWithPermission(testHandlerFunc, "admin"))
router.ServeHTTP(responseRecorder, request)
// Since the client registered directly in the AuthorizationService has the permission "admin" and the testHandler
// is protected by the permission "admin", the request should be authorized
if responseRecorder.Code != http.StatusOK {
t.Errorf("%s %s should have returned %d, but returned %d instead", request.Method, request.URL, http.StatusOK, responseRecorder.Code)
}
}
func TestGate_ProtectFuncWithPermissionWhenClientHasInsufficientPermissions(t *testing.T) {
gate := New().WithAuthorizationService(NewAuthorizationService().WithClient(NewClientWithPermissions("token", []string{"mod"})))
request, _ := http.NewRequest("GET", "/handle", http.NoBody)
request.Header.Add("Authorization", fmt.Sprintf("Bearer %s", "token"))
responseRecorder := httptest.NewRecorder()
router := http.NewServeMux()
router.HandleFunc("/handle", gate.ProtectFuncWithPermission(testHandlerFunc, "admin"))
router.ServeHTTP(responseRecorder, request)
// Since the client registered directly in the AuthorizationService has the permission "mod" and the
// testHandler is protected by the permission "admin", the request should be not be authorized
if responseRecorder.Code != http.StatusUnauthorized {
t.Errorf("%s %s should have returned %d, but returned %d instead", request.Method, request.URL, http.StatusUnauthorized, responseRecorder.Code)
}
}
func TestGate_WithCustomUnauthorizedResponseBody(t *testing.T) {
gate := New().WithAuthorizationService(NewAuthorizationService()).WithCustomUnauthorizedResponseBody([]byte("test"))
request, _ := http.NewRequest("GET", "/handle", http.NoBody)
request.Header.Add("Authorization", fmt.Sprintf("Bearer %s", "bad-token"))
responseRecorder := httptest.NewRecorder()
router := http.NewServeMux()
router.Handle("/handle", gate.Protect(&testHandler{}))
router.ServeHTTP(responseRecorder, request)
if responseRecorder.Code != http.StatusUnauthorized {
t.Errorf("%s %s should have returned %d, but returned %d instead", request.Method, request.URL, http.StatusUnauthorized, responseRecorder.Code)
}
if responseBody, _ := io.ReadAll(responseRecorder.Body); string(responseBody) != "test" {
t.Errorf("%s %s should have returned %s, but returned %s instead", request.Method, request.URL, "test", string(responseBody))
}
}
func TestGate_ProtectWithNoAuthorizationService(t *testing.T) {
gate := New()
request, _ := http.NewRequest("GET", "/handle", http.NoBody)
responseRecorder := httptest.NewRecorder()
router := http.NewServeMux()
router.Handle("/handle", gate.Protect(&testHandler{}))
router.ServeHTTP(responseRecorder, request)
if responseRecorder.Code != http.StatusOK {
t.Errorf("%s %s should have returned %d, but returned %d instead", request.Method, request.URL, http.StatusOK, responseRecorder.Code)
}
}
func TestGate_ProtectWithRateLimit(t *testing.T) {
gate := New().WithRateLimit(2)
request, _ := http.NewRequest("GET", "/handle", http.NoBody)
router := http.NewServeMux()
router.Handle("/handle", gate.Protect(&testHandler{}))
responseRecorder := httptest.NewRecorder()
router.ServeHTTP(responseRecorder, request)
if responseRecorder.Code != http.StatusOK {
t.Errorf("%s %s should have returned %d, but returned %d instead", request.Method, request.URL, http.StatusOK, responseRecorder.Code)
}
responseRecorder = httptest.NewRecorder()
router.ServeHTTP(responseRecorder, request)
if responseRecorder.Code != http.StatusOK {
t.Errorf("%s %s should have returned %d, but returned %d instead", request.Method, request.URL, http.StatusOK, responseRecorder.Code)
}
responseRecorder = httptest.NewRecorder()
router.ServeHTTP(responseRecorder, request)
if responseRecorder.Code != http.StatusTooManyRequests {
t.Errorf("%s %s should have returned %d, but returned %d instead", request.Method, request.URL, http.StatusTooManyRequests, responseRecorder.Code)
}
// Wait for rate limit time window to pass
time.Sleep(time.Second)
responseRecorder = httptest.NewRecorder()
router.ServeHTTP(responseRecorder, request)
if responseRecorder.Code != http.StatusOK {
t.Errorf("%s %s should have returned %d, but returned %d instead", request.Method, request.URL, http.StatusOK, responseRecorder.Code)
}
}
func TestGate_WithCustomTokenExtractor(t *testing.T) {
authorizationService := NewAuthorizationService().WithClientProvider(mockClientProvider)
customTokenExtractorFunc := func(request *http.Request) string {
sessionCookie, err := request.Cookie("session")
if err != nil {
return ""
}
return sessionCookie.Value
}
gate := New().WithAuthorizationService(authorizationService).WithCustomTokenExtractor(customTokenExtractorFunc)
request, _ := http.NewRequest("GET", "/handle", http.NoBody)
request.AddCookie(&http.Cookie{Name: "session", Value: TestProviderClientToken})
responseRecorder := httptest.NewRecorder()
router := http.NewServeMux()
router.Handle("/handle", gate.ProtectFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Context().Value(TokenContextKey) != TestProviderClientToken {
t.Errorf("token should have been passed to the request context")
}
if r.Context().Value(DataContextKey) != TestProviderClientData {
t.Errorf("data should have been passed to the request context")
}
w.WriteHeader(http.StatusOK)
}))
router.ServeHTTP(responseRecorder, request)
if responseRecorder.Code != http.StatusOK {
t.Errorf("%s %s should have returned %d, but returned %d instead", request.Method, request.URL, http.StatusOK, responseRecorder.Code)
}
}
func TestGateWithCustomHeader(t *testing.T) {
authorizationService := NewAuthorizationService().WithClientProvider(mockClientProvider)
customTokenExtractorFunc := func(request *http.Request) string {
return request.Header.Get("X-API-Token")
}
gate := New().WithAuthorizationService(authorizationService).WithCustomTokenExtractor(customTokenExtractorFunc)
request, _ := http.NewRequest("GET", "/handle", http.NoBody)
request.Header.Add("X-API-Token", TestProviderClientToken)
responseRecorder := httptest.NewRecorder()
router := http.NewServeMux()
router.Handle("/handle", gate.ProtectFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Context().Value(TokenContextKey) != TestProviderClientToken {
t.Errorf("token should have been passed to the request context")
}
if r.Context().Value(DataContextKey) != TestProviderClientData {
t.Errorf("data should have been passed to the request context")
}
w.WriteHeader(http.StatusOK)
}))
router.ServeHTTP(responseRecorder, request)
if responseRecorder.Code != http.StatusOK {
t.Errorf("%s %s should have returned %d, but returned %d instead", request.Method, request.URL, http.StatusOK, responseRecorder.Code)
}
}

5
go.mod Normal file
View file

@ -0,0 +1,5 @@
module github.com/TwiN/g8/v3
go 1.23.3
require github.com/TwiN/gocache/v2 v2.2.2

2
go.sum Normal file
View file

@ -0,0 +1,2 @@
github.com/TwiN/gocache/v2 v2.2.2 h1:4HToPfDV8FSbaYO5kkbhLpEllUYse5rAf+hVU/mSsuI=
github.com/TwiN/gocache/v2 v2.2.2/go.mod h1:WfIuwd7GR82/7EfQqEtmLFC3a2vqaKbs4Pe6neB7Gyc=

42
ratelimiter.go Normal file
View file

@ -0,0 +1,42 @@
package g8
import (
"sync"
"time"
)
// RateLimiter is a fixed rate limiter
type RateLimiter struct {
maximumExecutionsPerSecond int
executionsLeftInWindow int
windowStartTime time.Time
mutex sync.Mutex
}
// NewRateLimiter creates a RateLimiter
func NewRateLimiter(maximumExecutionsPerSecond int) *RateLimiter {
return &RateLimiter{
windowStartTime: time.Now(),
executionsLeftInWindow: maximumExecutionsPerSecond,
maximumExecutionsPerSecond: maximumExecutionsPerSecond,
}
}
// Try updates the number of executions if the rate limit quota hasn't been reached and returns whether the
// attempt was successful or not.
//
// Returns false if the execution was not successful (rate limit quota has been reached)
// Returns true if the execution was successful (rate limit quota has not been reached)
func (r *RateLimiter) Try() bool {
r.mutex.Lock()
defer r.mutex.Unlock()
if time.Now().Add(-time.Second).After(r.windowStartTime) {
r.windowStartTime = time.Now()
r.executionsLeftInWindow = r.maximumExecutionsPerSecond
}
if r.executionsLeftInWindow == 0 {
return false
}
r.executionsLeftInWindow--
return true
}

73
ratelimiter_test.go Normal file
View file

@ -0,0 +1,73 @@
package g8
import (
"testing"
"time"
)
func TestNewRateLimiter(t *testing.T) {
rl := NewRateLimiter(2)
if rl.maximumExecutionsPerSecond != 2 {
t.Errorf("expected maximumExecutionsPerSecond to be %d, got %d", 2, rl.maximumExecutionsPerSecond)
}
if rl.executionsLeftInWindow != 2 {
t.Errorf("expected executionsLeftInWindow to be %d, got %d", 2, rl.executionsLeftInWindow)
}
// First execution: should not be rate limited
if notRateLimited := rl.Try(); !notRateLimited {
t.Error("expected Try to return true")
}
if rl.maximumExecutionsPerSecond != 2 {
t.Errorf("expected maximumExecutionsPerSecond to be %d, got %d", 2, rl.maximumExecutionsPerSecond)
}
if rl.executionsLeftInWindow != 1 {
t.Errorf("expected executionsLeftInWindow to be %d, got %d", 1, rl.executionsLeftInWindow)
}
// Second execution: should not be rate limited
if notRateLimited := rl.Try(); !notRateLimited {
t.Error("expected Try to return true")
}
if rl.maximumExecutionsPerSecond != 2 {
t.Errorf("expected maximumExecutionsPerSecond to be %d, got %d", 2, rl.maximumExecutionsPerSecond)
}
if rl.executionsLeftInWindow != 0 {
t.Errorf("expected executionsLeftInWindow to be %d, got %d", 0, rl.executionsLeftInWindow)
}
// Third execution: should be rate limited
if notRateLimited := rl.Try(); notRateLimited {
t.Error("expected Try to return false")
}
if rl.maximumExecutionsPerSecond != 2 {
t.Errorf("expected maximumExecutionsPerSecond to be %d, got %d", 2, rl.maximumExecutionsPerSecond)
}
if rl.executionsLeftInWindow != 0 {
t.Errorf("expected executionsLeftInWindow to be %d, got %d", 0, rl.executionsLeftInWindow)
}
}
func TestRateLimiter_Try(t *testing.T) {
rl := NewRateLimiter(5)
for i := 0; i < 20; i++ {
notRateLimited := rl.Try()
if i < 5 {
if !notRateLimited {
t.Fatal("expected to not be rate limited")
}
} else {
if notRateLimited {
t.Fatal("expected to be rate limited")
}
}
}
}
func TestRateLimiter_TryAlwaysUnderRateLimit(t *testing.T) {
rl := NewRateLimiter(20)
for i := 0; i < 45; i++ {
notRateLimited := rl.Try()
if !notRateLimited {
t.Fatal("expected to not be rate limited")
}
time.Sleep(51 * time.Millisecond)
}
}