Adding upstream version 3.0.0.
Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
parent
4199417ac3
commit
8274b1bf1b
21 changed files with 2147 additions and 0 deletions
1
.gitattributes
vendored
Normal file
1
.gitattributes
vendored
Normal file
|
@ -0,0 +1 @@
|
|||
* text=lf
|
7
.github/codecov.yml
vendored
Normal file
7
.github/codecov.yml
vendored
Normal file
|
@ -0,0 +1,7 @@
|
|||
coverage:
|
||||
status:
|
||||
patch: off
|
||||
project:
|
||||
default:
|
||||
target: 75%
|
||||
threshold: null
|
14
.github/dependabot.yml
vendored
Normal file
14
.github/dependabot.yml
vendored
Normal file
|
@ -0,0 +1,14 @@
|
|||
version: 2
|
||||
updates:
|
||||
- package-ecosystem: "github-actions"
|
||||
directory: "/"
|
||||
labels: ["dependencies"]
|
||||
schedule:
|
||||
interval: "weekly"
|
||||
day: "saturday"
|
||||
- package-ecosystem: "gomod"
|
||||
directory: "/"
|
||||
labels: ["dependencies"]
|
||||
schedule:
|
||||
interval: "weekly"
|
||||
day: "saturday"
|
30
.github/workflows/test.yml
vendored
Normal file
30
.github/workflows/test.yml
vendored
Normal file
|
@ -0,0 +1,30 @@
|
|||
name: test
|
||||
on:
|
||||
pull_request:
|
||||
paths-ignore:
|
||||
- '*.md'
|
||||
push:
|
||||
branches:
|
||||
- master
|
||||
paths-ignore:
|
||||
- '*.md'
|
||||
jobs:
|
||||
test:
|
||||
name: test
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 3
|
||||
steps:
|
||||
- uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: 1.23.3
|
||||
- uses: actions/checkout@v4
|
||||
- name: Test (race)
|
||||
run: go test ./... -race
|
||||
- name: Test (coverage)
|
||||
run: go test ./... -coverprofile=coverage.txt -covermode=atomic
|
||||
- name: Codecov
|
||||
uses: codecov/codecov-action@v5.1.2
|
||||
with:
|
||||
files: ./coverage.txt
|
||||
token: ${{ secrets.CODECOV_TOKEN }}
|
||||
|
3
.gitignore
vendored
Normal file
3
.gitignore
vendored
Normal file
|
@ -0,0 +1,3 @@
|
|||
.idea
|
||||
*.iml
|
||||
/vendor
|
21
LICENSE
Normal file
21
LICENSE
Normal file
|
@ -0,0 +1,21 @@
|
|||
MIT License
|
||||
|
||||
Copyright (c) 2021-2025 TwiN
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
319
README.md
Normal file
319
README.md
Normal file
|
@ -0,0 +1,319 @@
|
|||
# g8
|
||||
|
||||

|
||||
[](https://goreportcard.com/report/github.com/TwiN/g8/v3)
|
||||
[](https://codecov.io/gh/TwiN/g8)
|
||||
[](https://github.com/TwiN/g8)
|
||||
[](https://pkg.go.dev/github.com/TwiN/g8/v3)
|
||||
[](https://github.com/TwiN)
|
||||
|
||||
g8, pronounced gate, is a simple Go library for protecting HTTP handlers.
|
||||
|
||||
Tired of constantly re-implementing a security layer for each application? Me too, that's why I made g8.
|
||||
|
||||
|
||||
## Installation
|
||||
```console
|
||||
go get -u github.com/TwiN/g8/v3
|
||||
```
|
||||
|
||||
|
||||
## Usage
|
||||
Because the entire purpose of g8 is to NOT waste time configuring the layer of security, the primary emphasis is to
|
||||
keep it as simple as possible.
|
||||
|
||||
|
||||
### Simple
|
||||
Just want a simple layer of security without the need for advanced permissions? This configuration is what you're
|
||||
looking for.
|
||||
|
||||
```go
|
||||
authorizationService := g8.NewAuthorizationService().WithToken("mytoken")
|
||||
gate := g8.New().WithAuthorizationService(authorizationService)
|
||||
|
||||
router := http.NewServeMux()
|
||||
router.Handle("/unprotected", yourHandler)
|
||||
router.Handle("/protected", gate.Protect(yourHandler))
|
||||
|
||||
http.ListenAndServe(":8080", router)
|
||||
```
|
||||
|
||||
The endpoint `/protected` is now only accessible if you pass the header `Authorization: Bearer mytoken`.
|
||||
|
||||
If you use `http.HandleFunc` instead of `http.Handle`, you may use `gate.ProtectFunc(yourHandler)` instead.
|
||||
|
||||
If you're not using the `Authorization` header, you can specify a custom token extractor.
|
||||
This enables use cases like [Protecting a handler using session cookie](#protecting-a-handler-using-session-cookie)
|
||||
|
||||
|
||||
### Advanced permissions
|
||||
If you have tokens with more permissions than others, g8's permission system will make managing authorization a breeze.
|
||||
|
||||
Rather than registering tokens, think of it as registering clients, the only difference being that clients may be
|
||||
configured with permissions while tokens cannot.
|
||||
|
||||
```go
|
||||
authorizationService := g8.NewAuthorizationService().WithClient(g8.NewClient("mytoken").WithPermission("admin"))
|
||||
gate := g8.New().WithAuthorizationService(authorizationService)
|
||||
|
||||
router := http.NewServeMux()
|
||||
router.Handle("/unprotected", yourHandler)
|
||||
router.Handle("/protected-with-admin", gate.ProtectWithPermissions(yourHandler, []string{"admin"}))
|
||||
|
||||
http.ListenAndServe(":8080", router)
|
||||
```
|
||||
|
||||
The endpoint `/protected-with-admin` is now only accessible if you pass the header `Authorization: Bearer mytoken`,
|
||||
because the client with the token `mytoken` has the permission `admin`. Note that the following handler would also be
|
||||
accessible with that token:
|
||||
```go
|
||||
router.Handle("/protected", gate.Protect(yourHandler))
|
||||
```
|
||||
|
||||
To clarify, both clients and tokens have access to handlers that aren't protected with extra permissions, and
|
||||
essentially, tokens are registered as clients with no extra permissions in the background.
|
||||
|
||||
Creating a token like so:
|
||||
```go
|
||||
authorizationService := g8.NewAuthorizationService().WithToken("mytoken")
|
||||
```
|
||||
is the equivalent of creating the following client:
|
||||
```go
|
||||
authorizationService := g8.NewAuthorizationService().WithClient(g8.NewClient("mytoken"))
|
||||
```
|
||||
|
||||
|
||||
### With client provider
|
||||
A client provider's task is to retrieve a Client from an external source (e.g. a database) when provided with a token.
|
||||
You should use a client provider when you have a lot of tokens and it wouldn't make sense to register all of them using
|
||||
`AuthorizationService`'s `WithToken`/`WithTokens`/`WithClient`/`WithClients`.
|
||||
|
||||
Note that the provider is used as a fallback source. As such, if a token is explicitly registered using one of the 4
|
||||
aforementioned functions, the client provider will not be used.
|
||||
|
||||
```go
|
||||
clientProvider := g8.NewClientProvider(func(token string) *g8.Client {
|
||||
// We'll assume that the following function calls your database and returns a struct "User" that
|
||||
// has the user's token as well as the permissions granted to said user
|
||||
user := database.GetUserByToken(token)
|
||||
if user != nil {
|
||||
return g8.NewClient(user.Token).WithPermissions(user.Permissions)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
authorizationService := g8.NewAuthorizationService().WithClientProvider(clientProvider)
|
||||
gate := g8.New().WithAuthorizationService(authorizationService)
|
||||
```
|
||||
|
||||
You can also configure the client provider to cache the output of the function you provide to retrieve clients by token:
|
||||
```go
|
||||
clientProvider := g8.NewClientProvider(...).WithCache(ttl, maxSize)
|
||||
```
|
||||
|
||||
Since g8 leverages [TwiN/gocache](https://github.com/TwiN/gocache) (unless you're using `WithCustomCache`),
|
||||
you can also use gocache's constants for configuring the TTL and the maximum size:
|
||||
- Setting the TTL to `gocache.NoExpiration` (-1) will disable the TTL.
|
||||
- Setting the maximum size to `gocache.NoMaxSize` (0) will disable the maximum cache size
|
||||
|
||||
To avoid any misunderstandings, using a client provider is not mandatory. If you only have a few tokens and you can load
|
||||
them on application start, you can just leverage `AuthorizationService`'s `WithToken`/`WithTokens`/`WithClient`/`WithClients`.
|
||||
|
||||
|
||||
## AuthorizationService
|
||||
As the previous examples may have hinted, there are several ways to create clients. The one thing they have
|
||||
in common is that they all go through AuthorizationService, which is in charge of both managing clients and determining
|
||||
whether a request should be blocked or allowed through.
|
||||
|
||||
| Function | Description |
|
||||
|:-------------------|:---------------------------------------------------------------------------------------------------------------------------------|
|
||||
| WithToken | Creates a single static client with no extra permissions |
|
||||
| WithTokens | Creates a slice of static clients with no extra permissions |
|
||||
| WithClient | Creates a single static client |
|
||||
| WithClients | Creates a slice of static clients |
|
||||
| WithClientProvider | Creates a client provider which will allow a fallback to a dynamic source (e.g. to a database) when a static client is not found |
|
||||
|
||||
Except for `WithClientProvider`, every functions listed above can be called more than once.
|
||||
As a result, you may safely perform actions like this:
|
||||
```go
|
||||
authorizationService := g8.NewAuthorizationService().
|
||||
WithToken("123").
|
||||
WithToken("456").
|
||||
WithClient(g8.NewClient("789").WithPermission("admin"))
|
||||
gate := g8.New().WithAuthorizationService(authorizationService)
|
||||
```
|
||||
|
||||
Be aware that g8.Client supports a list of permissions as well. You may call `WithPermission` several times, or call
|
||||
`WithPermissions` with a slice of permissions instead.
|
||||
|
||||
|
||||
### Permissions
|
||||
Unlike client permissions, handler permissions are requirements.
|
||||
|
||||
A client may have as many permissions as you want, but for said client to have access to a handler protected by
|
||||
permissions, the client must have all permissions defined by said handler in order to have access to it.
|
||||
|
||||
In other words, a client with the permissions `create`, `read`, `update` and `delete` would have access to all of these handlers:
|
||||
```go
|
||||
gate := g8.New().WithAuthorizationService(g8.NewAuthorizationService().WithClient(g8.NewClient("mytoken").WithPermissions([]string{"create", "read", "update", "delete"})))
|
||||
router := http.NewServeMux()
|
||||
router.Handle("/", gate.Protect(homeHandler)) // equivalent of gate.ProtectWithPermissions(homeHandler, []string{})
|
||||
router.Handle("/create", gate.ProtectWithPermissions(createHandler, []string{"create"}))
|
||||
router.Handle("/read", gate.ProtectWithPermissions(readHandler, []string{"read"}))
|
||||
router.Handle("/update", gate.ProtectWithPermissions(updateHandler, []string{"update"}))
|
||||
router.Handle("/delete", gate.ProtectWithPermissions(deleteHandler, []string{"delete"}))
|
||||
router.Handle("/crud", gate.ProtectWithPermissions(crudHandler, []string{"create", "read", "update", "delete"}))
|
||||
```
|
||||
But it would not have access to the following handler, because while `mytoken` has the `read` permission, it does not
|
||||
have the `backup` permission:
|
||||
```go
|
||||
router.Handle("/backup", gate.ProtectWithPermissions(&testHandler{}, []string{"read", "backup"}))
|
||||
```
|
||||
|
||||
If you're using an HTTP library that supports middlewares like [mux](https://github.com/gorilla/mux), you can protect
|
||||
an entire group of handlers instead using `gate.Protect` or `gate.PermissionMiddleware()`:
|
||||
```go
|
||||
router := mux.NewRouter()
|
||||
|
||||
userRouter := router.PathPrefix("/").Subrouter()
|
||||
userRouter.Use(gate.Protect)
|
||||
userRouter.HandleFunc("/api/v1/users/me", getUserProfile).Methods("GET")
|
||||
userRouter.HandleFunc("/api/v1/users/me/friends", getUserFriends).Methods("GET")
|
||||
userRouter.HandleFunc("/api/v1/users/me/email", updateUserEmail).Methods("PATCH")
|
||||
|
||||
adminRouter := router.PathPrefix("/").Subrouter()
|
||||
adminRouter.Use(gate.PermissionMiddleware("admin"))
|
||||
adminRouter.HandleFunc("/api/v1/users/{id}/ban", banUserByID).Methods("POST")
|
||||
adminRouter.HandleFunc("/api/v1/users/{id}/delete", deleteUserByID).Methods("DELETE")
|
||||
```
|
||||
|
||||
|
||||
## Rate limiting
|
||||
To add a rate limit of 100 requests per second:
|
||||
```go
|
||||
gate := g8.New().WithRateLimit(100)
|
||||
```
|
||||
|
||||
|
||||
## Accessing the token from the protected handlers
|
||||
If you need to access the token from the handlers you are protecting with g8, you can retrieve it from the
|
||||
request context by using the key `g8.TokenContextKey`:
|
||||
```go
|
||||
http.Handle("/handle", gate.ProtectFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
token, _ := r.Context().Value(g8.TokenContextKey).(string)
|
||||
// ...
|
||||
}))
|
||||
```
|
||||
|
||||
## Examples
|
||||
### Protecting a handler using session cookie
|
||||
If you want to only allow authenticated users to access a handler, you can use a custom token extractor function
|
||||
combined with a client provider.
|
||||
|
||||
First, we'll create a function to extract the session ID from the session cookie. While a session ID does not
|
||||
theoretically refer to a token, g8 uses the term `token` as a blanket term to refer to any string that can be used to
|
||||
identify a client.
|
||||
```go
|
||||
customTokenExtractorFunc := func(request *http.Request) string {
|
||||
sessionCookie, err := request.Cookie("session")
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return sessionCookie.Value
|
||||
}
|
||||
```
|
||||
|
||||
Next, we need to create a client provider that will validate our token, which refers to the session ID in this case.
|
||||
```go
|
||||
clientProvider := g8.NewClientProvider(func(token string) *g8.Client {
|
||||
// We'll assume that the following function calls your database and validates whether the session is valid.
|
||||
isSessionValid := database.CheckIfSessionIsValid(token)
|
||||
if !isSessionValid {
|
||||
return nil // Returning nil will cause the gate to return a 401 Unauthorized.
|
||||
}
|
||||
// You could also retrieve the user and their permissions if you wanted instead, but for this example,
|
||||
// all we care about is confirming whether the session is valid or not.
|
||||
return g8.NewClient(token)
|
||||
})
|
||||
```
|
||||
|
||||
Keep in mind that you can get really creative with the client provider above.
|
||||
For instance, you could refresh the session's expiration time, which will allow the user to stay logged in for
|
||||
as long as they're active.
|
||||
|
||||
You're also not limited to using something stateful like the example above. You could use a JWT and have your client
|
||||
provider validate said JWT.
|
||||
|
||||
Finally, we can create the authorization service and the gate:
|
||||
```go
|
||||
authorizationService := g8.NewAuthorizationService().WithClientProvider(clientProvider)
|
||||
gate := g8.New().WithAuthorizationService(authorizationService).WithCustomTokenExtractor(customTokenExtractorFunc)
|
||||
```
|
||||
|
||||
If you need to access the token (session ID in this case) from the protected handlers, you can retrieve it from the
|
||||
request context by using the key `g8.TokenContextKey`:
|
||||
```go
|
||||
http.Handle("/handle", gate.ProtectFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
sessionID, _ := r.Context().Value(g8.TokenContextKey).(string)
|
||||
// ...
|
||||
}))
|
||||
```
|
||||
|
||||
### Using a custom header
|
||||
The logic is the same as the example above:
|
||||
```go
|
||||
customTokenExtractorFunc := func(request *http.Request) string {
|
||||
return request.Header.Get("X-API-Token")
|
||||
}
|
||||
|
||||
clientProvider := g8.NewClientProvider(func(token string) *g8.Client {
|
||||
// We'll assume that the following function calls your database and returns a struct "User" that
|
||||
// has the user's token as well as the permissions granted to said user
|
||||
user := database.GetUserByToken(token)
|
||||
if user != nil {
|
||||
return g8.NewClient(user.Token).WithPermissions(user.Permissions)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
authorizationService := g8.NewAuthorizationService().WithClientProvider(clientProvider)
|
||||
gate := g8.New().WithAuthorizationService(authorizationService).WithCustomTokenExtractor(customTokenExtractorFunc)
|
||||
```
|
||||
|
||||
### Using a custom cache
|
||||
|
||||
```go
|
||||
package main
|
||||
|
||||
import (
|
||||
g8 "github.com/TwiN/g8/v3"
|
||||
)
|
||||
|
||||
type customCache struct {
|
||||
entries map[string]any
|
||||
sync.Mutex
|
||||
}
|
||||
|
||||
func (c *customCache) Get(key string) (value any, exists bool) {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
func (c *customCache) Set(key string, value any) {
|
||||
// ...
|
||||
}
|
||||
|
||||
// To verify the implementation
|
||||
var _ g8.Cache = (*customCache)(nil)
|
||||
|
||||
func main() {
|
||||
getClientByTokenFunc := func(token string) *g8.Client {
|
||||
// We'll assume that the following function calls your database and returns a struct "User" that
|
||||
// has the user's token as well as the permissions granted to said user
|
||||
user := database.GetUserByToken(token)
|
||||
if user != nil {
|
||||
return g8.NewClient(user.Token).WithPermissions(user.Permissions).WithData(user.Data)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
// Create the provider with the custom cache
|
||||
provider := g8.NewClientProvider(getClientByTokenFunc).WithCustomCache(&customCache{})
|
||||
}
|
||||
```
|
130
authorization.go
Normal file
130
authorization.go
Normal file
|
@ -0,0 +1,130 @@
|
|||
package g8
|
||||
|
||||
import (
|
||||
"sync"
|
||||
)
|
||||
|
||||
// AuthorizationService is the service that manages client/token registry and client fallback as well as the service
|
||||
// that determines whether a token meets the specific requirements to be authorized by a Gate or not.
|
||||
type AuthorizationService struct {
|
||||
clients map[string]*Client
|
||||
clientProvider *ClientProvider
|
||||
|
||||
mutex sync.RWMutex
|
||||
}
|
||||
|
||||
// NewAuthorizationService creates a new AuthorizationService
|
||||
func NewAuthorizationService() *AuthorizationService {
|
||||
return &AuthorizationService{
|
||||
clients: make(map[string]*Client),
|
||||
}
|
||||
}
|
||||
|
||||
// WithToken is used to specify a single token for which authorization will be granted
|
||||
//
|
||||
// The client that will be created from this token will have access to all handlers that are not protected with a
|
||||
// specific permission.
|
||||
//
|
||||
// In other words, if you were to do the following:
|
||||
//
|
||||
// gate := g8.New().WithAuthorizationService(g8.NewAuthorizationService().WithToken("12345"))
|
||||
//
|
||||
// The following handler would be accessible with the token 12345:
|
||||
//
|
||||
// router.Handle("/1st-handler", gate.Protect(yourHandler))
|
||||
//
|
||||
// But not this one would not be accessible with the token 12345:
|
||||
//
|
||||
// router.Handle("/2nd-handler", gate.ProtectWithPermissions(yourOtherHandler, []string{"admin"}))
|
||||
//
|
||||
// Calling this function multiple times will add multiple clients, though you may want to use WithTokens instead
|
||||
// if you plan to add multiple clients
|
||||
//
|
||||
// If you wish to configure advanced permissions, consider using WithClient instead.
|
||||
func (authorizationService *AuthorizationService) WithToken(token string) *AuthorizationService {
|
||||
authorizationService.mutex.Lock()
|
||||
authorizationService.clients[token] = NewClient(token)
|
||||
authorizationService.mutex.Unlock()
|
||||
return authorizationService
|
||||
}
|
||||
|
||||
// WithTokens is used to specify a slice of tokens for which authorization will be granted
|
||||
func (authorizationService *AuthorizationService) WithTokens(tokens []string) *AuthorizationService {
|
||||
authorizationService.mutex.Lock()
|
||||
for _, token := range tokens {
|
||||
authorizationService.clients[token] = NewClient(token)
|
||||
}
|
||||
authorizationService.mutex.Unlock()
|
||||
return authorizationService
|
||||
}
|
||||
|
||||
// WithClient is used to specify a single client for which authorization will be granted
|
||||
//
|
||||
// When compared to WithToken, the advantage of using this function is that you may specify the client's
|
||||
// permissions and thus, be a lot more granular with what endpoint a token has access to.
|
||||
//
|
||||
// In other words, if you were to do the following:
|
||||
//
|
||||
// gate := g8.New().WithAuthorizationService(g8.NewAuthorizationService().WithClient(g8.NewClient("12345").WithPermission("mod")))
|
||||
//
|
||||
// The following handlers would be accessible with the token 12345:
|
||||
//
|
||||
// router.Handle("/1st-handler", gate.ProtectWithPermissions(yourHandler, []string{"mod"}))
|
||||
// router.Handle("/2nd-handler", gate.Protect(yourOtherHandler))
|
||||
//
|
||||
// But not this one, because the user does not have the permission "admin":
|
||||
//
|
||||
// router.Handle("/3rd-handler", gate.ProtectWithPermissions(yetAnotherHandler, []string{"admin"}))
|
||||
//
|
||||
// Calling this function multiple times will add multiple clients, though you may want to use WithClients instead
|
||||
// if you plan to add multiple clients
|
||||
func (authorizationService *AuthorizationService) WithClient(client *Client) *AuthorizationService {
|
||||
authorizationService.mutex.Lock()
|
||||
authorizationService.clients[client.Token] = client
|
||||
authorizationService.mutex.Unlock()
|
||||
return authorizationService
|
||||
}
|
||||
|
||||
// WithClients is used to specify a slice of clients for which authorization will be granted
|
||||
func (authorizationService *AuthorizationService) WithClients(clients []*Client) *AuthorizationService {
|
||||
authorizationService.mutex.Lock()
|
||||
for _, client := range clients {
|
||||
authorizationService.clients[client.Token] = client
|
||||
}
|
||||
authorizationService.mutex.Unlock()
|
||||
return authorizationService
|
||||
}
|
||||
|
||||
// WithClientProvider allows specifying a custom provider to fetch clients by token.
|
||||
//
|
||||
// For example, you can use it to fallback to making a call in your database when a request is made with a token that
|
||||
// hasn't been specified via WithToken, WithTokens, WithClient or WithClients.
|
||||
func (authorizationService *AuthorizationService) WithClientProvider(provider *ClientProvider) *AuthorizationService {
|
||||
authorizationService.clientProvider = provider
|
||||
return authorizationService
|
||||
}
|
||||
|
||||
// Authorize checks whether a client with a given token exists and has the permissions required.
|
||||
//
|
||||
// If permissionsRequired is nil or empty and a client with the given token exists, said client will have access to all
|
||||
// handlers that are not protected by a given permission.
|
||||
//
|
||||
// Returns the client is authorized (or nil if no client was authorized), as well as whether the token is authorized
|
||||
func (authorizationService *AuthorizationService) Authorize(token string, permissionsRequired []string) (client *Client, authorized bool) {
|
||||
if len(token) == 0 {
|
||||
return nil, false
|
||||
}
|
||||
authorizationService.mutex.RLock()
|
||||
client, _ = authorizationService.clients[token]
|
||||
authorizationService.mutex.RUnlock()
|
||||
// If there's no clients with the given token directly stored in the AuthorizationService, fall back to the
|
||||
// client provider, if there's one configured.
|
||||
if client == nil && authorizationService.clientProvider != nil {
|
||||
client = authorizationService.clientProvider.GetClientByToken(token)
|
||||
}
|
||||
if client != nil && client.HasPermissions(permissionsRequired) {
|
||||
// If the client has the required permissions, return true and the client
|
||||
return client, true
|
||||
}
|
||||
return nil, false
|
||||
}
|
108
authorization_test.go
Normal file
108
authorization_test.go
Normal file
|
@ -0,0 +1,108 @@
|
|||
package g8
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestAuthorizationService_Authorize(t *testing.T) {
|
||||
authorizationService := NewAuthorizationService().WithToken("token")
|
||||
if _, authorized := authorizationService.Authorize("token", nil); !authorized {
|
||||
t.Error("should've returned true")
|
||||
}
|
||||
if _, authorized := authorizationService.Authorize("bad-token", nil); authorized {
|
||||
t.Error("should've returned false")
|
||||
}
|
||||
if _, authorized := authorizationService.Authorize("token", []string{"admin"}); authorized {
|
||||
t.Error("should've returned false")
|
||||
}
|
||||
if _, authorized := authorizationService.Authorize("", nil); authorized {
|
||||
t.Error("should've returned false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthorizationService_AuthorizeWithPermissions(t *testing.T) {
|
||||
authorizationService := NewAuthorizationService().WithClient(NewClient("token").WithPermissions([]string{"a", "b"}))
|
||||
if _, authorized := authorizationService.Authorize("token", nil); !authorized {
|
||||
t.Error("should've returned true")
|
||||
}
|
||||
if _, authorized := authorizationService.Authorize("token", []string{"a"}); !authorized {
|
||||
t.Error("should've returned true")
|
||||
}
|
||||
if _, authorized := authorizationService.Authorize("token", []string{"b"}); !authorized {
|
||||
t.Error("should've returned true")
|
||||
}
|
||||
if _, authorized := authorizationService.Authorize("token", []string{"a", "b"}); !authorized {
|
||||
t.Error("should've returned true")
|
||||
}
|
||||
if _, authorized := authorizationService.Authorize("token", []string{"c"}); authorized {
|
||||
t.Error("should've returned false")
|
||||
}
|
||||
if _, authorized := authorizationService.Authorize("token", []string{"a", "c"}); authorized {
|
||||
t.Error("should've returned false")
|
||||
}
|
||||
if _, authorized := authorizationService.Authorize("bad-token", nil); authorized {
|
||||
t.Error("should've returned false")
|
||||
}
|
||||
if _, authorized := authorizationService.Authorize("bad-token", []string{"a"}); authorized {
|
||||
t.Error("should've returned false")
|
||||
}
|
||||
if _, authorized := authorizationService.Authorize("", []string{"a"}); authorized {
|
||||
t.Error("should've returned false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthorizationService_WithToken(t *testing.T) {
|
||||
authorizationService := NewAuthorizationService().WithToken("token")
|
||||
if _, authorized := authorizationService.Authorize("token", nil); !authorized {
|
||||
t.Error("should've returned true")
|
||||
}
|
||||
if _, authorized := authorizationService.Authorize("bad-token", nil); authorized {
|
||||
t.Error("should've returned false")
|
||||
}
|
||||
if _, authorized := authorizationService.Authorize("token", []string{"admin"}); authorized {
|
||||
t.Error("should've returned false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthorizationService_WithTokens(t *testing.T) {
|
||||
authorizationService := NewAuthorizationService().WithTokens([]string{"1", "2"})
|
||||
if _, authorized := authorizationService.Authorize("1", nil); !authorized {
|
||||
t.Error("should've returned true")
|
||||
}
|
||||
if _, authorized := authorizationService.Authorize("2", nil); !authorized {
|
||||
t.Error("should've returned true")
|
||||
}
|
||||
if _, authorized := authorizationService.Authorize("3", nil); authorized {
|
||||
t.Error("should've returned false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthorizationService_WithClient(t *testing.T) {
|
||||
authorizationService := NewAuthorizationService().WithClient(NewClient("token").WithPermissions([]string{"a", "b"}))
|
||||
if _, authorized := authorizationService.Authorize("token", []string{"a", "b"}); !authorized {
|
||||
t.Error("should've returned true")
|
||||
}
|
||||
if _, authorized := authorizationService.Authorize("token", []string{"a"}); !authorized {
|
||||
t.Error("should've returned true")
|
||||
}
|
||||
if _, authorized := authorizationService.Authorize("token", []string{"b"}); !authorized {
|
||||
t.Error("should've returned true")
|
||||
}
|
||||
if _, authorized := authorizationService.Authorize("token", []string{"c"}); authorized {
|
||||
t.Error("should've returned false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthorizationService_WithClients(t *testing.T) {
|
||||
authorizationService := NewAuthorizationService().WithClients([]*Client{NewClient("1").WithPermission("a"), NewClient("2").WithPermission("b")})
|
||||
if _, authorized := authorizationService.Authorize("1", []string{"a"}); !authorized {
|
||||
t.Error("should've returned true")
|
||||
}
|
||||
if _, authorized := authorizationService.Authorize("2", []string{"b"}); !authorized {
|
||||
t.Error("should've returned true")
|
||||
}
|
||||
if _, authorized := authorizationService.Authorize("1", []string{"b"}); authorized {
|
||||
t.Error("should've returned false")
|
||||
}
|
||||
if _, authorized := authorizationService.Authorize("2", []string{"a"}); authorized {
|
||||
t.Error("should've returned false")
|
||||
}
|
||||
}
|
13
cache.go
Normal file
13
cache.go
Normal file
|
@ -0,0 +1,13 @@
|
|||
package g8
|
||||
|
||||
import (
|
||||
"github.com/TwiN/gocache/v2"
|
||||
)
|
||||
|
||||
type Cache interface {
|
||||
Get(key string) (value any, exists bool)
|
||||
Set(key string, value any)
|
||||
}
|
||||
|
||||
// Make sure that gocache.Cache is compatible with the interface
|
||||
var _ Cache = (*gocache.Cache)(nil)
|
79
client.go
Normal file
79
client.go
Normal file
|
@ -0,0 +1,79 @@
|
|||
package g8
|
||||
|
||||
// Client is a struct containing both a Token and a slice of extra Permissions that said token has.
|
||||
type Client struct {
|
||||
// Token is the value used to authenticate with the API.
|
||||
Token string
|
||||
|
||||
// Permissions is a slice of extra permissions that may be used for more granular access control.
|
||||
//
|
||||
// If you only wish to use Gate.Protect and Gate.ProtectFunc, you do not have to worry about this,
|
||||
// since they're only used by Gate.ProtectWithPermissions and Gate.ProtectFuncWithPermissions
|
||||
Permissions []string
|
||||
|
||||
// Data is a field that can be used to store any data you want to associate with the client.
|
||||
Data any
|
||||
}
|
||||
|
||||
// NewClient creates a Client with a given token
|
||||
func NewClient(token string) *Client {
|
||||
return &Client{
|
||||
Token: token,
|
||||
}
|
||||
}
|
||||
|
||||
// NewClientWithPermissions creates a Client with a slice of permissions
|
||||
// Equivalent to using NewClient and WithPermissions
|
||||
func NewClientWithPermissions(token string, permissions []string) *Client {
|
||||
return NewClient(token).WithPermissions(permissions)
|
||||
}
|
||||
|
||||
// NewClientWithData creates a Client with some data
|
||||
// Equivalent to using NewClient and WithData
|
||||
func NewClientWithData(token string, data any) *Client {
|
||||
return NewClient(token).WithData(data)
|
||||
}
|
||||
|
||||
// NewClientWithPermissionsAndData creates a Client with a slice of permissions and some data
|
||||
// Equivalent to using NewClient, WithPermissions and WithData
|
||||
func NewClientWithPermissionsAndData(token string, permissions []string, data any) *Client {
|
||||
return NewClient(token).WithPermissions(permissions).WithData(data)
|
||||
}
|
||||
|
||||
// WithPermissions adds a slice of permissions to a client
|
||||
func (client *Client) WithPermissions(permissions []string) *Client {
|
||||
client.Permissions = append(client.Permissions, permissions...)
|
||||
return client
|
||||
}
|
||||
|
||||
// WithPermission adds a permission to a client
|
||||
func (client *Client) WithPermission(permission string) *Client {
|
||||
client.Permissions = append(client.Permissions, permission)
|
||||
return client
|
||||
}
|
||||
|
||||
// WithData attaches data to a client
|
||||
func (client *Client) WithData(data any) *Client {
|
||||
client.Data = data
|
||||
return client
|
||||
}
|
||||
|
||||
// HasPermission checks whether a client has a given permission
|
||||
func (client *Client) HasPermission(permissionRequired string) bool {
|
||||
for _, permission := range client.Permissions {
|
||||
if permissionRequired == permission {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// HasPermissions checks whether a client has the all permissions passed
|
||||
func (client *Client) HasPermissions(permissionsRequired []string) bool {
|
||||
for _, permissionRequired := range permissionsRequired {
|
||||
if !client.HasPermission(permissionRequired) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
76
client_test.go
Normal file
76
client_test.go
Normal file
|
@ -0,0 +1,76 @@
|
|||
package g8
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestClient_HasPermission(t *testing.T) {
|
||||
client := NewClientWithPermissions("token", []string{"a", "b"})
|
||||
if !client.HasPermission("a") {
|
||||
t.Errorf("client has permissions %s, therefore HasPermission(a) should've been true", client.Permissions)
|
||||
}
|
||||
if !client.HasPermission("b") {
|
||||
t.Errorf("client has permissions %s, therefore HasPermission(b) should've been true", client.Permissions)
|
||||
}
|
||||
if client.HasPermission("c") {
|
||||
t.Errorf("client has permissions %s, therefore HasPermission(c) should've been false", client.Permissions)
|
||||
}
|
||||
if client.HasPermission("ab") {
|
||||
t.Errorf("client has permissions %s, therefore HasPermission(ab) should've been false", client.Permissions)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClient_HasPermissions(t *testing.T) {
|
||||
client := NewClientWithPermissions("token", []string{"a", "b"})
|
||||
if !client.HasPermissions(nil) {
|
||||
t.Errorf("client has permissions %s, therefore HasPermissions(nil) should've been true", client.Permissions)
|
||||
}
|
||||
if !client.HasPermissions([]string{"a"}) {
|
||||
t.Errorf("client has permissions %s, therefore HasPermissions([a]) should've been true", client.Permissions)
|
||||
}
|
||||
if !client.HasPermissions([]string{"b"}) {
|
||||
t.Errorf("client has permissions %s, therefore HasPermissions([b]) should've been true", client.Permissions)
|
||||
}
|
||||
if !client.HasPermissions([]string{"a", "b"}) {
|
||||
t.Errorf("client has permissions %s, therefore HasPermissions([a, b]) should've been true", client.Permissions)
|
||||
}
|
||||
if client.HasPermissions([]string{"a", "b", "c"}) {
|
||||
t.Errorf("client has permissions %s, therefore HasPermissions([a, b, c]) should've been false", client.Permissions)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClient_WithData(t *testing.T) {
|
||||
client := NewClient("token")
|
||||
if client.Data != nil {
|
||||
t.Error("expected client data to be nil")
|
||||
}
|
||||
client.WithData(5)
|
||||
if client.Data != 5 {
|
||||
t.Errorf("expected client data to be 5, got %d", client.Data)
|
||||
}
|
||||
client.WithData(map[string]string{"key": "value"})
|
||||
if data, ok := client.Data.(map[string]string); !ok || data["key"] != "value" {
|
||||
t.Errorf("expected client data to be map[string]string{key: value}, got %v", client.Data)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewClientWithData(t *testing.T) {
|
||||
client := NewClientWithData("token", 5)
|
||||
if client.Data != 5 {
|
||||
t.Errorf("expected client data to be 5, got %d", client.Data)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewClientWithPermissionsAndData(t *testing.T) {
|
||||
client := NewClientWithPermissionsAndData("token", []string{"a", "b"}, 5)
|
||||
if client.Data != 5 {
|
||||
t.Errorf("expected client data to be 5, got %d", client.Data)
|
||||
}
|
||||
if !client.HasPermission("a") {
|
||||
t.Errorf("client has permissions %s, therefore HasPermission(a) should've been true", client.Permissions)
|
||||
}
|
||||
if !client.HasPermission("b") {
|
||||
t.Errorf("client has permissions %s, therefore HasPermission(b) should've been true", client.Permissions)
|
||||
}
|
||||
if client.HasPermission("c") {
|
||||
t.Errorf("client has permissions %s, therefore HasPermission(c) should've been false", client.Permissions)
|
||||
}
|
||||
}
|
102
clientprovider.go
Normal file
102
clientprovider.go
Normal file
|
@ -0,0 +1,102 @@
|
|||
package g8
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/TwiN/gocache/v2"
|
||||
)
|
||||
|
||||
// ClientProvider has the task of retrieving a Client from an external source (e.g. a database) when provided with a
|
||||
// token. It should be used when you have a lot of tokens, and it wouldn't make sense to register all of them using
|
||||
// AuthorizationService's WithToken, WithTokens, WithClient or WithClients.
|
||||
//
|
||||
// Note that the provider is used as a fallback source. As such, if a token is explicitly registered using one of the 4
|
||||
// aforementioned functions, the client provider will not be used by the AuthorizationService when a request is made
|
||||
// with said token. It will, however, be called upon if a token that is not explicitly registered in
|
||||
// AuthorizationService is sent alongside a request going through the Gate.
|
||||
//
|
||||
// clientProvider := g8.NewClientProvider(func(token string) *g8.Client {
|
||||
// // We'll assume that the following function calls your database and returns a struct "User" that
|
||||
// // has the user's token as well as the permissions granted to said user
|
||||
// user := database.GetUserByToken(token)
|
||||
// if user != nil {
|
||||
// return g8.NewClient(user.Token).WithPermissions(user.Permissions)
|
||||
// }
|
||||
// return nil
|
||||
// })
|
||||
// gate := g8.New().WithAuthorizationService(g8.NewAuthorizationService().WithClientProvider(clientProvider))
|
||||
type ClientProvider struct {
|
||||
getClientByTokenFunc func(token string) *Client
|
||||
|
||||
cache Cache
|
||||
}
|
||||
|
||||
// NewClientProvider creates a ClientProvider
|
||||
// The parameter that must be passed is a function that the provider will use to retrieve a client by a given token
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// clientProvider := g8.NewClientProvider(func(token string) *g8.Client {
|
||||
// // We'll assume that the following function calls your database and returns a struct "User" that
|
||||
// // has the user's token as well as the permissions granted to said user
|
||||
// user := database.GetUserByToken(token)
|
||||
// if user == nil {
|
||||
// return nil
|
||||
// }
|
||||
// return g8.NewClient(user.Token).WithPermissions(user.Permissions)
|
||||
// })
|
||||
// gate := g8.New().WithAuthorizationService(g8.NewAuthorizationService().WithClientProvider(clientProvider))
|
||||
func NewClientProvider(getClientByTokenFunc func(token string) *Client) *ClientProvider {
|
||||
return &ClientProvider{
|
||||
getClientByTokenFunc: getClientByTokenFunc,
|
||||
}
|
||||
}
|
||||
|
||||
// WithCache enables an in-memory cache for the ClientProvider.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// clientProvider := g8.NewClientProvider(func(token string) *g8.Client {
|
||||
// // We'll assume that the following function calls your database and returns a struct "User" that
|
||||
// // has the user's token as well as the permissions granted to said user
|
||||
// user := database.GetUserByToken(token)
|
||||
// if user != nil {
|
||||
// return g8.NewClient(user.Token).WithPermissions(user.Permissions)
|
||||
// }
|
||||
// return nil
|
||||
// })
|
||||
// gate := g8.New().WithAuthorizationService(g8.NewAuthorizationService().WithClientProvider(clientProvider.WithCache(time.Hour, 70000)))
|
||||
func (provider *ClientProvider) WithCache(ttl time.Duration, maxSize int) *ClientProvider {
|
||||
return provider.WithCustomCache(
|
||||
gocache.NewCache().WithEvictionPolicy(gocache.LeastRecentlyUsed).WithMaxSize(maxSize).WithDefaultTTL(ttl),
|
||||
)
|
||||
}
|
||||
|
||||
// WithCustomCache allows you to use a custom cache implementation instead of the default one.
|
||||
// By default, using WithCache will leverage gocache.
|
||||
//
|
||||
// Note that the custom cache must implement the Cache interface
|
||||
func (provider *ClientProvider) WithCustomCache(cache Cache) *ClientProvider {
|
||||
provider.cache = cache
|
||||
return provider
|
||||
}
|
||||
|
||||
// GetClientByToken retrieves a client by its token through the provided getClientByTokenFunc.
|
||||
func (provider *ClientProvider) GetClientByToken(token string) *Client {
|
||||
if provider.cache == nil {
|
||||
return provider.getClientByTokenFunc(token)
|
||||
}
|
||||
if cachedClient, exists := provider.cache.Get(token); exists {
|
||||
if cachedClient == nil {
|
||||
return nil
|
||||
}
|
||||
// Safely typecast the client.
|
||||
// Regardless of whether the typecast is successful or not, we return client since it'll be either client or
|
||||
// nil. Technically, it should never be nil, but it's better to be safe than sorry.
|
||||
client, _ := cachedClient.(*Client)
|
||||
return client
|
||||
}
|
||||
client := provider.getClientByTokenFunc(token)
|
||||
provider.cache.Set(token, client)
|
||||
return client
|
||||
}
|
131
clientprovider_test.go
Normal file
131
clientprovider_test.go
Normal file
|
@ -0,0 +1,131 @@
|
|||
package g8
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/TwiN/gocache/v2"
|
||||
)
|
||||
|
||||
var (
|
||||
getClientByTokenFunc = func(token string) *Client {
|
||||
if token == "valid-token" {
|
||||
return NewClient("valid-token").WithData("client-data")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
)
|
||||
|
||||
func TestClientProvider_GetClientByToken(t *testing.T) {
|
||||
provider := NewClientProvider(getClientByTokenFunc)
|
||||
if client := provider.GetClientByToken("valid-token"); client == nil {
|
||||
t.Error("should've returned a client")
|
||||
} else if client.Data != "client-data" {
|
||||
t.Error("expected client data to be 'client-data', got", client.Data)
|
||||
}
|
||||
if client := provider.GetClientByToken("invalid-token"); client != nil {
|
||||
t.Error("should've returned nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientProvider_WithCache(t *testing.T) {
|
||||
provider := NewClientProvider(getClientByTokenFunc).WithCache(gocache.NoExpiration, 10000)
|
||||
if provider.cache.(*gocache.Cache).Count() != 0 {
|
||||
t.Error("expected cache to be empty")
|
||||
}
|
||||
if client := provider.GetClientByToken("valid-token"); client == nil {
|
||||
t.Error("expected client, got nil")
|
||||
}
|
||||
if provider.cache.(*gocache.Cache).Count() != 1 {
|
||||
t.Error("expected cache size to be 1")
|
||||
}
|
||||
if client := provider.GetClientByToken("valid-token"); client == nil {
|
||||
t.Error("expected client, got nil")
|
||||
}
|
||||
if provider.cache.(*gocache.Cache).Count() != 1 {
|
||||
t.Error("expected cache size to be 1")
|
||||
}
|
||||
if client := provider.GetClientByToken("invalid-token"); client != nil {
|
||||
t.Error("expected nil, got", client)
|
||||
}
|
||||
if provider.cache.(*gocache.Cache).Count() != 2 {
|
||||
t.Error("expected cache size to be 2")
|
||||
}
|
||||
if client := provider.GetClientByToken("invalid-token"); client != nil {
|
||||
t.Error("expected nil, got", client)
|
||||
}
|
||||
if client := provider.GetClientByToken("invalid-token"); client != nil {
|
||||
t.Error("should've returned nil (cached)")
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientProvider_WithCacheAndExpiration(t *testing.T) {
|
||||
provider := NewClientProvider(getClientByTokenFunc).WithCache(10*time.Millisecond, 10)
|
||||
provider.GetClientByToken("token")
|
||||
if provider.cache.(*gocache.Cache).Count() != 1 {
|
||||
t.Error("expected cache size to be 1")
|
||||
}
|
||||
if provider.cache.(*gocache.Cache).Stats().ExpiredKeys != 0 {
|
||||
t.Error("expected cache statistics to report 0 expired key")
|
||||
}
|
||||
time.Sleep(15 * time.Millisecond)
|
||||
provider.GetClientByToken("token")
|
||||
if provider.cache.(*gocache.Cache).Stats().ExpiredKeys != 1 {
|
||||
t.Error("expected cache statistics to report 1 expired key")
|
||||
}
|
||||
}
|
||||
|
||||
type customCache struct {
|
||||
entries map[string]any
|
||||
sync.Mutex
|
||||
}
|
||||
|
||||
func (c *customCache) Get(key string) (value any, exists bool) {
|
||||
c.Lock()
|
||||
v, exists := c.entries[key]
|
||||
c.Unlock()
|
||||
return v, exists
|
||||
}
|
||||
|
||||
func (c *customCache) Set(key string, value any) {
|
||||
c.Lock()
|
||||
if c.entries == nil {
|
||||
c.entries = make(map[string]any)
|
||||
}
|
||||
c.entries[key] = value
|
||||
c.Unlock()
|
||||
}
|
||||
|
||||
var _ Cache = (*customCache)(nil)
|
||||
|
||||
func TestClientProvider_WithCustomCache(t *testing.T) {
|
||||
provider := NewClientProvider(getClientByTokenFunc).WithCustomCache(&customCache{})
|
||||
if len(provider.cache.(*customCache).entries) != 0 {
|
||||
t.Error("expected cache to be empty")
|
||||
}
|
||||
if client := provider.GetClientByToken("valid-token"); client == nil {
|
||||
t.Error("expected client, got nil")
|
||||
}
|
||||
if len(provider.cache.(*customCache).entries) != 1 {
|
||||
t.Error("expected cache size to be 1")
|
||||
}
|
||||
if client := provider.GetClientByToken("valid-token"); client == nil {
|
||||
t.Error("expected client, got nil")
|
||||
}
|
||||
if len(provider.cache.(*customCache).entries) != 1 {
|
||||
t.Error("expected cache size to be 1")
|
||||
}
|
||||
if client := provider.GetClientByToken("invalid-token"); client != nil {
|
||||
t.Error("expected nil, got", client)
|
||||
}
|
||||
if len(provider.cache.(*customCache).entries) != 2 {
|
||||
t.Error("expected cache size to be 2")
|
||||
}
|
||||
if client := provider.GetClientByToken("invalid-token"); client != nil {
|
||||
t.Error("expected nil, got", client)
|
||||
}
|
||||
if client := provider.GetClientByToken("invalid-token"); client != nil {
|
||||
t.Error("should've returned nil (cached)")
|
||||
}
|
||||
}
|
245
gate.go
Normal file
245
gate.go
Normal file
|
@ -0,0 +1,245 @@
|
|||
package g8
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
// AuthorizationHeader is the header in which g8 looks for the authorization bearer token
|
||||
AuthorizationHeader = "Authorization"
|
||||
|
||||
// DefaultUnauthorizedResponseBody is the default response body returned if a request was sent with a missing or invalid token
|
||||
DefaultUnauthorizedResponseBody = "token is missing or invalid"
|
||||
|
||||
// DefaultTooManyRequestsResponseBody is the default response body returned if a request exceeded the allowed rate limit
|
||||
DefaultTooManyRequestsResponseBody = "too many requests"
|
||||
|
||||
// TokenContextKey is the key used to store the client's token in the context.
|
||||
TokenContextKey = "g8.token"
|
||||
|
||||
// DataContextKey is the key used to store the client's data in the context.
|
||||
DataContextKey = "g8.data"
|
||||
)
|
||||
|
||||
// Gate is lock to the front door of your API, letting only those you allow through.
|
||||
type Gate struct {
|
||||
authorizationService *AuthorizationService
|
||||
unauthorizedResponseBody []byte
|
||||
|
||||
customTokenExtractorFunc func(request *http.Request) string
|
||||
|
||||
rateLimiter *RateLimiter
|
||||
tooManyRequestsResponseBody []byte
|
||||
}
|
||||
|
||||
// Deprecated: use New instead.
|
||||
func NewGate(authorizationService *AuthorizationService) *Gate {
|
||||
return &Gate{
|
||||
authorizationService: authorizationService,
|
||||
unauthorizedResponseBody: []byte(DefaultUnauthorizedResponseBody),
|
||||
tooManyRequestsResponseBody: []byte(DefaultTooManyRequestsResponseBody),
|
||||
}
|
||||
}
|
||||
|
||||
// New creates a new Gate.
|
||||
func New() *Gate {
|
||||
return &Gate{
|
||||
unauthorizedResponseBody: []byte(DefaultUnauthorizedResponseBody),
|
||||
tooManyRequestsResponseBody: []byte(DefaultTooManyRequestsResponseBody),
|
||||
}
|
||||
}
|
||||
|
||||
// WithAuthorizationService sets the authorization service to use.
|
||||
//
|
||||
// If there is no authorization service, Gate will not enforce authorization.
|
||||
func (gate *Gate) WithAuthorizationService(authorizationService *AuthorizationService) *Gate {
|
||||
gate.authorizationService = authorizationService
|
||||
return gate
|
||||
}
|
||||
|
||||
// WithCustomUnauthorizedResponseBody sets a custom response body when Gate determines that a request must be blocked
|
||||
func (gate *Gate) WithCustomUnauthorizedResponseBody(unauthorizedResponseBody []byte) *Gate {
|
||||
gate.unauthorizedResponseBody = unauthorizedResponseBody
|
||||
return gate
|
||||
}
|
||||
|
||||
// WithCustomTokenExtractor allows the specification of a custom function to extract a token from a request.
|
||||
// If a custom token extractor is not specified, the token will be extracted from the Authorization header.
|
||||
//
|
||||
// For instance, if you're using a session cookie, you can extract the token from the cookie like so:
|
||||
//
|
||||
// authorizationService := g8.NewAuthorizationService()
|
||||
// customTokenExtractorFunc := func(request *http.Request) string {
|
||||
// sessionCookie, err := request.Cookie("session")
|
||||
// if err != nil {
|
||||
// return ""
|
||||
// }
|
||||
// return sessionCookie.Value
|
||||
// }
|
||||
// gate := g8.New().WithAuthorizationService(authorizationService).WithCustomTokenExtractor(customTokenExtractorFunc)
|
||||
//
|
||||
// You would normally use this with a client provider that matches whatever need you have.
|
||||
// For example, if you're using a session cookie, your client provider would retrieve the user from the session ID
|
||||
// extracted by this custom token extractor.
|
||||
//
|
||||
// Note that for the sake of convenience, the token extracted from the request is passed the protected handlers request
|
||||
// context under the key TokenContextKey. This is especially useful if the token is in fact a session ID.
|
||||
func (gate *Gate) WithCustomTokenExtractor(customTokenExtractorFunc func(request *http.Request) string) *Gate {
|
||||
gate.customTokenExtractorFunc = customTokenExtractorFunc
|
||||
return gate
|
||||
}
|
||||
|
||||
// WithRateLimit adds rate limiting to the Gate
|
||||
//
|
||||
// If you just want to use a gate for rate limiting purposes:
|
||||
//
|
||||
// gate := g8.New().WithRateLimit(50)
|
||||
func (gate *Gate) WithRateLimit(maximumRequestsPerSecond int) *Gate {
|
||||
gate.rateLimiter = NewRateLimiter(maximumRequestsPerSecond)
|
||||
return gate
|
||||
}
|
||||
|
||||
// Protect secures a handler, requiring requests going through to have a valid Authorization Bearer token.
|
||||
// Unlike ProtectWithPermissions, Protect will allow access to any registered tokens, regardless of their permissions
|
||||
// or lack thereof.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// gate := g8.New().WithAuthorizationService(g8.NewAuthorizationService().WithToken("token"))
|
||||
// router := http.NewServeMux()
|
||||
// // Without protection
|
||||
// router.Handle("/handle", yourHandler)
|
||||
// // With protection
|
||||
// router.Handle("/handle", gate.Protect(yourHandler))
|
||||
//
|
||||
// The token extracted from the request is passed to the handlerFunc request context under the key TokenContextKey
|
||||
func (gate *Gate) Protect(handler http.Handler) http.Handler {
|
||||
return gate.ProtectWithPermissions(handler, nil)
|
||||
}
|
||||
|
||||
// ProtectWithPermissions secures a handler, requiring requests going through to have a valid Authorization Bearer token
|
||||
// as well as a slice of permissions that must be met.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// gate := g8.New().WithAuthorizationService(g8.NewAuthorizationService().WithClient(g8.NewClient("token").WithPermission("ADMIN")))
|
||||
// router := http.NewServeMux()
|
||||
// // Without protection
|
||||
// router.Handle("/handle", yourHandler)
|
||||
// // With protection
|
||||
// router.Handle("/handle", gate.ProtectWithPermissions(yourHandler, []string{"admin"}))
|
||||
//
|
||||
// The token extracted from the request is passed to the handlerFunc request context under the key TokenContextKey
|
||||
func (gate *Gate) ProtectWithPermissions(handler http.Handler, permissions []string) http.Handler {
|
||||
return gate.ProtectFuncWithPermissions(func(writer http.ResponseWriter, request *http.Request) {
|
||||
handler.ServeHTTP(writer, request)
|
||||
}, permissions)
|
||||
}
|
||||
|
||||
// ProtectWithPermission does the same thing as ProtectWithPermissions, but for a single permission instead of a
|
||||
// slice of permissions
|
||||
//
|
||||
// See ProtectWithPermissions for further documentation
|
||||
func (gate *Gate) ProtectWithPermission(handler http.Handler, permission string) http.Handler {
|
||||
return gate.ProtectFuncWithPermissions(func(writer http.ResponseWriter, request *http.Request) {
|
||||
handler.ServeHTTP(writer, request)
|
||||
}, []string{permission})
|
||||
}
|
||||
|
||||
// ProtectFunc secures a handlerFunc, requiring requests going through to have a valid Authorization Bearer token.
|
||||
// Unlike ProtectFuncWithPermissions, ProtectFunc will allow access to any registered tokens, regardless of their
|
||||
// permissions or lack thereof.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// gate := g8.New().WithAuthorizationService(g8.NewAuthorizationService().WithToken("token"))
|
||||
// router := http.NewServeMux()
|
||||
// // Without protection
|
||||
// router.HandleFunc("/handle", yourHandlerFunc)
|
||||
// // With protection
|
||||
// router.HandleFunc("/handle", gate.ProtectFunc(yourHandlerFunc))
|
||||
//
|
||||
// The token extracted from the request is passed to the handlerFunc request context under the key TokenContextKey
|
||||
func (gate *Gate) ProtectFunc(handlerFunc http.HandlerFunc) http.HandlerFunc {
|
||||
return gate.ProtectFuncWithPermissions(handlerFunc, nil)
|
||||
}
|
||||
|
||||
// ProtectFuncWithPermissions secures a handler, requiring requests going through to have a valid Authorization Bearer
|
||||
// token as well as a slice of permissions that must be met.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// gate := g8.New().WithAuthorizationService(g8.NewAuthorizationService().WithClient(g8.NewClient("token").WithPermission("admin")))
|
||||
// router := http.NewServeMux()
|
||||
// // Without protection
|
||||
// router.HandleFunc("/handle", yourHandlerFunc)
|
||||
// // With protection
|
||||
// router.HandleFunc("/handle", gate.ProtectFuncWithPermissions(yourHandlerFunc, []string{"admin"}))
|
||||
//
|
||||
// The token extracted from the request is passed to the handlerFunc request context under the key TokenContextKey
|
||||
func (gate *Gate) ProtectFuncWithPermissions(handlerFunc http.HandlerFunc, permissions []string) http.HandlerFunc {
|
||||
return func(writer http.ResponseWriter, request *http.Request) {
|
||||
if gate.rateLimiter != nil {
|
||||
if !gate.rateLimiter.Try() {
|
||||
writer.WriteHeader(http.StatusTooManyRequests)
|
||||
_, _ = writer.Write(gate.tooManyRequestsResponseBody)
|
||||
return
|
||||
}
|
||||
}
|
||||
if gate.authorizationService != nil {
|
||||
token := gate.ExtractTokenFromRequest(request)
|
||||
if client, authorized := gate.authorizationService.Authorize(token, permissions); !authorized {
|
||||
writer.WriteHeader(http.StatusUnauthorized)
|
||||
_, _ = writer.Write(gate.unauthorizedResponseBody)
|
||||
return
|
||||
} else {
|
||||
request = request.WithContext(context.WithValue(request.Context(), TokenContextKey, token))
|
||||
if client != nil && client.Data != nil {
|
||||
request = request.WithContext(context.WithValue(request.Context(), DataContextKey, client.Data))
|
||||
}
|
||||
}
|
||||
}
|
||||
handlerFunc(writer, request)
|
||||
}
|
||||
}
|
||||
|
||||
// ProtectFuncWithPermission does the same thing as ProtectFuncWithPermissions, but for a single permission instead of a
|
||||
// slice of permissions
|
||||
//
|
||||
// See ProtectFuncWithPermissions for further documentation
|
||||
func (gate *Gate) ProtectFuncWithPermission(handlerFunc http.HandlerFunc, permission string) http.HandlerFunc {
|
||||
return gate.ProtectFuncWithPermissions(handlerFunc, []string{permission})
|
||||
}
|
||||
|
||||
// ExtractTokenFromRequest extracts a token from a request.
|
||||
//
|
||||
// By default, it extracts the bearer token from the AuthorizationHeader, but if a customTokenExtractorFunc is defined,
|
||||
// it will use that instead.
|
||||
//
|
||||
// Note that this method is internally used by Protect, ProtectWithPermission, ProtectFunc and
|
||||
// ProtectFuncWithPermissions, but it is exposed in case you need to use it directly.
|
||||
func (gate *Gate) ExtractTokenFromRequest(request *http.Request) string {
|
||||
if gate.customTokenExtractorFunc != nil {
|
||||
// A custom token extractor function is defined, so we'll use it instead of the default token extraction logic
|
||||
return gate.customTokenExtractorFunc(request)
|
||||
}
|
||||
return strings.TrimPrefix(request.Header.Get(AuthorizationHeader), "Bearer ")
|
||||
}
|
||||
|
||||
// PermissionMiddleware is a middleware that behaves like ProtectWithPermission, but it is meant to be used
|
||||
// as a middleware for libraries that support such a feature.
|
||||
//
|
||||
// For instance, if you are using github.com/gorilla/mux, you can use PermissionMiddleware like so:
|
||||
//
|
||||
// router := mux.NewRouter()
|
||||
// router.Use(gate.PermissionMiddleware("admin"))
|
||||
// router.Handle("/admin/handle", adminHandler)
|
||||
//
|
||||
// If you do not want to protect a router with a specific permission, you can use Gate.Protect instead.
|
||||
func (gate *Gate) PermissionMiddleware(permissions ...string) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return gate.ProtectWithPermissions(next, permissions)
|
||||
}
|
||||
}
|
208
gate_bench_test.go
Normal file
208
gate_bench_test.go
Normal file
|
@ -0,0 +1,208 @@
|
|||
package g8
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
var handler http.Handler = &testHandler{}
|
||||
|
||||
func BenchmarkTestHandler(b *testing.B) {
|
||||
request, _ := http.NewRequest("GET", "/handle", nil)
|
||||
|
||||
router := http.NewServeMux()
|
||||
router.Handle("/handle", handler)
|
||||
|
||||
for n := 0; n < b.N; n++ {
|
||||
responseRecorder := httptest.NewRecorder()
|
||||
router.ServeHTTP(responseRecorder, request)
|
||||
if responseRecorder.Code != http.StatusOK {
|
||||
b.Fatalf("%s %s should have returned %d, but returned %d instead", request.Method, request.URL, http.StatusOK, responseRecorder.Code)
|
||||
}
|
||||
}
|
||||
b.ReportAllocs()
|
||||
}
|
||||
|
||||
func BenchmarkGate_ProtectWhenNoAuthorizationHeader(b *testing.B) {
|
||||
gate := New().WithAuthorizationService(NewAuthorizationService().WithToken("good-token"))
|
||||
request, _ := http.NewRequest("GET", "/handle", nil)
|
||||
|
||||
router := http.NewServeMux()
|
||||
router.Handle("/handle", gate.Protect(handler))
|
||||
|
||||
for n := 0; n < b.N; n++ {
|
||||
responseRecorder := httptest.NewRecorder()
|
||||
router.ServeHTTP(responseRecorder, request)
|
||||
if responseRecorder.Code != http.StatusUnauthorized {
|
||||
b.Fatalf("%s %s should have returned %d, but returned %d instead", request.Method, request.URL, http.StatusUnauthorized, responseRecorder.Code)
|
||||
}
|
||||
}
|
||||
b.ReportAllocs()
|
||||
}
|
||||
|
||||
func BenchmarkGate_ProtectWithInvalidToken(b *testing.B) {
|
||||
gate := New().WithAuthorizationService(NewAuthorizationService().WithToken("good-token"))
|
||||
request, _ := http.NewRequest("GET", "/handle", nil)
|
||||
request.Header.Add("Authorization", fmt.Sprintf("Bearer %s", "bad-token"))
|
||||
|
||||
router := http.NewServeMux()
|
||||
router.Handle("/handle", gate.Protect(handler))
|
||||
|
||||
for n := 0; n < b.N; n++ {
|
||||
responseRecorder := httptest.NewRecorder()
|
||||
router.ServeHTTP(responseRecorder, request)
|
||||
if responseRecorder.Code != http.StatusUnauthorized {
|
||||
b.Fatalf("%s %s should have returned %d, but returned %d instead", request.Method, request.URL, http.StatusUnauthorized, responseRecorder.Code)
|
||||
}
|
||||
}
|
||||
b.ReportAllocs()
|
||||
}
|
||||
|
||||
func BenchmarkGate_ProtectWithValidToken(b *testing.B) {
|
||||
gate := New().WithAuthorizationService(NewAuthorizationService().WithToken("good-token"))
|
||||
request, _ := http.NewRequest("GET", "/handle", http.NoBody)
|
||||
request.Header.Add("Authorization", fmt.Sprintf("Bearer %s", "good-token"))
|
||||
|
||||
router := http.NewServeMux()
|
||||
router.Handle("/handle", gate.Protect(handler))
|
||||
|
||||
for n := 0; n < b.N; n++ {
|
||||
responseRecorder := httptest.NewRecorder()
|
||||
router.ServeHTTP(responseRecorder, request)
|
||||
if responseRecorder.Code != http.StatusOK {
|
||||
b.Fatalf("%s %s should have returned %d, but returned %d instead", request.Method, request.URL, http.StatusOK, responseRecorder.Code)
|
||||
}
|
||||
}
|
||||
b.ReportAllocs()
|
||||
}
|
||||
|
||||
func BenchmarkGate_ProtectWithPermissionsAndValidToken(b *testing.B) {
|
||||
gate := New().WithAuthorizationService(NewAuthorizationService().WithClient(NewClient("token").WithPermission("admin")))
|
||||
request, _ := http.NewRequest("GET", "/handle", http.NoBody)
|
||||
request.Header.Add("Authorization", fmt.Sprintf("Bearer %s", "token"))
|
||||
|
||||
router := http.NewServeMux()
|
||||
router.Handle("/handle", gate.ProtectWithPermissions(handler, []string{"admin"}))
|
||||
|
||||
for n := 0; n < b.N; n++ {
|
||||
responseRecorder := httptest.NewRecorder()
|
||||
router.ServeHTTP(responseRecorder, request)
|
||||
if responseRecorder.Code != http.StatusOK {
|
||||
b.Fatalf("%s %s should have returned %d, but returned %d instead", request.Method, request.URL, http.StatusOK, responseRecorder.Code)
|
||||
}
|
||||
}
|
||||
b.ReportAllocs()
|
||||
}
|
||||
|
||||
func BenchmarkGate_ProtectWithPermissionsAndValidTokenButInsufficientPermissions(b *testing.B) {
|
||||
gate := New().WithAuthorizationService(NewAuthorizationService().WithClient(NewClient("token").WithPermission("mod")))
|
||||
request, _ := http.NewRequest("GET", "/handle", http.NoBody)
|
||||
request.Header.Add("Authorization", fmt.Sprintf("Bearer %s", "token"))
|
||||
|
||||
router := http.NewServeMux()
|
||||
router.Handle("/handle", gate.ProtectWithPermissions(handler, []string{"admin"}))
|
||||
|
||||
for n := 0; n < b.N; n++ {
|
||||
responseRecorder := httptest.NewRecorder()
|
||||
router.ServeHTTP(responseRecorder, request)
|
||||
if responseRecorder.Code != http.StatusUnauthorized {
|
||||
b.Fatalf("%s %s should have returned %d, but returned %d instead", request.Method, request.URL, http.StatusUnauthorized, responseRecorder.Code)
|
||||
}
|
||||
}
|
||||
b.ReportAllocs()
|
||||
}
|
||||
|
||||
func BenchmarkGate_ProtectConcurrently(b *testing.B) {
|
||||
gate := New().WithAuthorizationService(NewAuthorizationService().WithToken("good-token"))
|
||||
|
||||
request, _ := http.NewRequest("GET", "/handle", http.NoBody)
|
||||
request.Header.Add("Authorization", fmt.Sprintf("Bearer %s", "good-token"))
|
||||
|
||||
badRequest, _ := http.NewRequest("GET", "/handle", http.NoBody)
|
||||
badRequest.Header.Add("Authorization", fmt.Sprintf("Bearer %s", "bad-token"))
|
||||
|
||||
router := http.NewServeMux()
|
||||
router.Handle("/handle", gate.Protect(handler))
|
||||
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
for pb.Next() {
|
||||
responseRecorder := httptest.NewRecorder()
|
||||
router.ServeHTTP(responseRecorder, request)
|
||||
if responseRecorder.Code != http.StatusOK {
|
||||
b.Fatalf("%s %s should have returned %d, but returned %d instead", request.Method, request.URL, http.StatusOK, responseRecorder.Code)
|
||||
}
|
||||
responseRecorder = httptest.NewRecorder()
|
||||
router.ServeHTTP(responseRecorder, badRequest)
|
||||
if responseRecorder.Code != http.StatusUnauthorized {
|
||||
b.Fatalf("%s %s should have returned %d, but returned %d instead", badRequest.Method, badRequest.URL, http.StatusUnauthorized, responseRecorder.Code)
|
||||
}
|
||||
}
|
||||
})
|
||||
b.ReportAllocs()
|
||||
}
|
||||
|
||||
func BenchmarkGate_ProtectWithClientProviderConcurrently(b *testing.B) {
|
||||
gate := New().WithAuthorizationService(NewAuthorizationService().WithClientProvider(mockClientProvider))
|
||||
|
||||
request, _ := http.NewRequest("GET", "/handle", http.NoBody)
|
||||
request.Header.Add("Authorization", fmt.Sprintf("Bearer %s", TestProviderClientToken))
|
||||
|
||||
firstBadRequest, _ := http.NewRequest("GET", "/handle", http.NoBody)
|
||||
firstBadRequest.Header.Set("Authorization", fmt.Sprintf("Bearer %s", "bad-token-1"))
|
||||
|
||||
secondBadRequest, _ := http.NewRequest("GET", "/handle", http.NoBody)
|
||||
secondBadRequest.Header.Set("Authorization", fmt.Sprintf("Bearer %s", "bad-token-2"))
|
||||
|
||||
router := http.NewServeMux()
|
||||
router.Handle("/handle", gate.Protect(handler))
|
||||
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
for pb.Next() {
|
||||
responseRecorder := httptest.NewRecorder()
|
||||
router.ServeHTTP(responseRecorder, request)
|
||||
if responseRecorder.Code != http.StatusOK {
|
||||
b.Fatalf("%s %s should have returned %d, but returned %d instead", request.Method, request.URL, http.StatusOK, responseRecorder.Code)
|
||||
}
|
||||
responseRecorder = httptest.NewRecorder()
|
||||
router.ServeHTTP(responseRecorder, firstBadRequest)
|
||||
if responseRecorder.Code != http.StatusUnauthorized {
|
||||
b.Fatalf("%s %s should have returned %d, but returned %d instead", firstBadRequest.Method, firstBadRequest.URL, http.StatusUnauthorized, responseRecorder.Code)
|
||||
}
|
||||
responseRecorder = httptest.NewRecorder()
|
||||
router.ServeHTTP(responseRecorder, secondBadRequest)
|
||||
if responseRecorder.Code != http.StatusUnauthorized {
|
||||
b.Fatalf("%s %s should have returned %d, but returned %d instead", secondBadRequest.Method, secondBadRequest.URL, http.StatusUnauthorized, responseRecorder.Code)
|
||||
}
|
||||
}
|
||||
})
|
||||
b.ReportAllocs()
|
||||
}
|
||||
|
||||
func BenchmarkGate_ProtectWithValidTokenAndCustomTokenExtractorFuncConcurrently(b *testing.B) {
|
||||
customTokenExtractorFunc := func(request *http.Request) string {
|
||||
sessionCookie, err := request.Cookie("session")
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return sessionCookie.Value
|
||||
}
|
||||
gate := New().WithAuthorizationService(NewAuthorizationService().WithToken("good-token")).WithCustomTokenExtractor(customTokenExtractorFunc)
|
||||
request, _ := http.NewRequest("GET", "/handle", http.NoBody)
|
||||
request.AddCookie(&http.Cookie{Name: "session", Value: "good-token"})
|
||||
|
||||
router := http.NewServeMux()
|
||||
router.Handle("/handle", gate.Protect(handler))
|
||||
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
for pb.Next() {
|
||||
responseRecorder := httptest.NewRecorder()
|
||||
router.ServeHTTP(responseRecorder, request)
|
||||
if responseRecorder.Code != http.StatusOK {
|
||||
b.Fatalf("%s %s should have returned %d, but returned %d instead", request.Method, request.URL, http.StatusOK, responseRecorder.Code)
|
||||
}
|
||||
}
|
||||
})
|
||||
b.ReportAllocs()
|
||||
}
|
538
gate_test.go
Normal file
538
gate_test.go
Normal file
|
@ -0,0 +1,538 @@
|
|||
package g8
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
FirstTestProviderClientPermission = "permission-1"
|
||||
SecondTestProviderClientPermission = "permission-2"
|
||||
TestProviderClientToken = "client-token-from-provider"
|
||||
TestProviderClientData = "client-data-from-provider"
|
||||
)
|
||||
|
||||
var (
|
||||
mockClientProvider = NewClientProvider(func(token string) *Client {
|
||||
// We'll pretend that there's only one token that's valid in the client provider, every other token
|
||||
// returns nil
|
||||
if token == TestProviderClientToken {
|
||||
return &Client{
|
||||
Token: TestProviderClientToken,
|
||||
Data: TestProviderClientData,
|
||||
Permissions: []string{FirstTestProviderClientPermission, SecondTestProviderClientPermission},
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
)
|
||||
|
||||
type testHandler struct {
|
||||
}
|
||||
|
||||
func (handler *testHandler) ServeHTTP(writer http.ResponseWriter, _ *http.Request) {
|
||||
writer.WriteHeader(http.StatusOK)
|
||||
}
|
||||
|
||||
func testHandlerFunc(writer http.ResponseWriter, _ *http.Request) {
|
||||
writer.WriteHeader(http.StatusOK)
|
||||
}
|
||||
|
||||
func TestUsability(t *testing.T) {
|
||||
gate := New().WithAuthorizationService(NewAuthorizationService().WithToken("good-token"))
|
||||
|
||||
var handler http.Handler = &testHandler{}
|
||||
handlerFunc := func(writer http.ResponseWriter, request *http.Request) {
|
||||
writer.WriteHeader(http.StatusOK)
|
||||
}
|
||||
|
||||
router := http.NewServeMux()
|
||||
router.Handle("/handle", handler)
|
||||
router.Handle("/handle-protected", gate.Protect(handler))
|
||||
router.HandleFunc("/handlefunc", handlerFunc)
|
||||
router.HandleFunc("/handlefunc-protected", gate.ProtectFunc(handlerFunc))
|
||||
}
|
||||
|
||||
func TestNewGate(t *testing.T) {
|
||||
gate := NewGate(nil)
|
||||
if gate == nil {
|
||||
t.Error("gate should not be nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnprotectedHandler(t *testing.T) {
|
||||
request, _ := http.NewRequest("GET", "/handle", http.NoBody)
|
||||
request.Header.Add("Authorization", fmt.Sprintf("Bearer %s", "bad-token"))
|
||||
responseRecorder := httptest.NewRecorder()
|
||||
|
||||
router := http.NewServeMux()
|
||||
router.Handle("/handle", &testHandler{})
|
||||
router.ServeHTTP(responseRecorder, request)
|
||||
|
||||
if responseRecorder.Code != http.StatusOK {
|
||||
t.Errorf("%s %s should have returned %d, but returned %d instead", request.Method, request.URL, http.StatusOK, responseRecorder.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGate_ProtectWithInvalidToken(t *testing.T) {
|
||||
gate := New().WithAuthorizationService(NewAuthorizationService().WithToken("good-token"))
|
||||
request, _ := http.NewRequest("GET", "/handle", http.NoBody)
|
||||
request.Header.Add("Authorization", fmt.Sprintf("Bearer %s", "bad-token"))
|
||||
responseRecorder := httptest.NewRecorder()
|
||||
|
||||
router := http.NewServeMux()
|
||||
router.Handle("/handle", gate.Protect(&testHandler{}))
|
||||
router.ServeHTTP(responseRecorder, request)
|
||||
|
||||
if responseRecorder.Code != http.StatusUnauthorized {
|
||||
t.Errorf("%s %s should have returned %d, but returned %d instead", request.Method, request.URL, http.StatusUnauthorized, responseRecorder.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGate_ProtectWithValidToken(t *testing.T) {
|
||||
gate := New().WithAuthorizationService(NewAuthorizationService().WithToken("good-token"))
|
||||
request, _ := http.NewRequest("GET", "/handle", http.NoBody)
|
||||
request.Header.Add("Authorization", fmt.Sprintf("Bearer %s", "good-token"))
|
||||
responseRecorder := httptest.NewRecorder()
|
||||
|
||||
router := http.NewServeMux()
|
||||
router.Handle("/handle", gate.Protect(&testHandler{}))
|
||||
router.ServeHTTP(responseRecorder, request)
|
||||
|
||||
if responseRecorder.Code != http.StatusOK {
|
||||
t.Errorf("%s %s should have returned %d, but returned %d instead", request.Method, request.URL, http.StatusOK, responseRecorder.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGate_ProtectMultipleTimes(t *testing.T) {
|
||||
gate := New().WithAuthorizationService(NewAuthorizationService().WithToken("good-token"))
|
||||
request, _ := http.NewRequest("GET", "/handle", http.NoBody)
|
||||
request.Header.Add("Authorization", fmt.Sprintf("Bearer %s", "good-token"))
|
||||
badRequest, _ := http.NewRequest("GET", "/handle", http.NoBody)
|
||||
badRequest.Header.Add("Authorization", fmt.Sprintf("Bearer %s", "bad-token"))
|
||||
|
||||
router := http.NewServeMux()
|
||||
router.Handle("/handle", gate.Protect(&testHandler{}))
|
||||
|
||||
for i := 0; i < 100; i++ {
|
||||
responseRecorder := httptest.NewRecorder()
|
||||
router.ServeHTTP(responseRecorder, request)
|
||||
if responseRecorder.Code != http.StatusOK {
|
||||
t.Errorf("%s %s should have returned %d, but returned %d instead", request.Method, request.URL, http.StatusOK, responseRecorder.Code)
|
||||
}
|
||||
responseRecorder = httptest.NewRecorder()
|
||||
router.ServeHTTP(responseRecorder, badRequest)
|
||||
if responseRecorder.Code != http.StatusUnauthorized {
|
||||
t.Errorf("%s %s should have returned %d, but returned %d instead", badRequest.Method, badRequest.URL, http.StatusOK, responseRecorder.Code)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestGate_ProtectWithValidTokenExposedThroughClientProvider(t *testing.T) {
|
||||
gate := New().WithAuthorizationService(NewAuthorizationService().WithClientProvider(mockClientProvider))
|
||||
request, _ := http.NewRequest("GET", "/handle", http.NoBody)
|
||||
request.Header.Add("Authorization", fmt.Sprintf("Bearer %s", TestProviderClientToken))
|
||||
responseRecorder := httptest.NewRecorder()
|
||||
|
||||
router := http.NewServeMux()
|
||||
router.Handle("/handle", gate.Protect(&testHandler{}))
|
||||
router.ServeHTTP(responseRecorder, request)
|
||||
|
||||
if responseRecorder.Code != http.StatusOK {
|
||||
t.Errorf("%s %s should have returned %d, but returned %d instead", request.Method, request.URL, http.StatusOK, responseRecorder.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGate_ProtectWithValidTokenExposedThroughClientProviderWithCache(t *testing.T) {
|
||||
gate := New().WithAuthorizationService(NewAuthorizationService().WithClientProvider(mockClientProvider.WithCache(60*time.Minute, 70000)))
|
||||
request, _ := http.NewRequest("GET", "/handle", http.NoBody)
|
||||
request.Header.Add("Authorization", fmt.Sprintf("Bearer %s", TestProviderClientToken))
|
||||
responseRecorder := httptest.NewRecorder()
|
||||
|
||||
router := http.NewServeMux()
|
||||
router.Handle("/handle", gate.Protect(&testHandler{}))
|
||||
router.ServeHTTP(responseRecorder, request)
|
||||
|
||||
if responseRecorder.Code != http.StatusOK {
|
||||
t.Errorf("%s %s should have returned %d, but returned %d instead", request.Method, request.URL, http.StatusOK, responseRecorder.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGate_ProtectWithInvalidTokenWhenUsingClientProvider(t *testing.T) {
|
||||
gate := New().WithAuthorizationService(NewAuthorizationService().WithClientProvider(mockClientProvider))
|
||||
request, _ := http.NewRequest("GET", "/handle", http.NoBody)
|
||||
request.Header.Add("Authorization", fmt.Sprintf("Bearer %s", "bad-token"))
|
||||
responseRecorder := httptest.NewRecorder()
|
||||
|
||||
router := http.NewServeMux()
|
||||
router.Handle("/handle", gate.Protect(&testHandler{}))
|
||||
router.ServeHTTP(responseRecorder, request)
|
||||
|
||||
if responseRecorder.Code != http.StatusUnauthorized {
|
||||
t.Errorf("%s %s should have returned %d, but returned %d instead", request.Method, request.URL, http.StatusUnauthorized, responseRecorder.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGate_ProtectWithPermissionsWhenValidTokenAndSufficientPermissionsWhileUsingClientProvider(t *testing.T) {
|
||||
gate := New().WithAuthorizationService(NewAuthorizationService().WithClientProvider(mockClientProvider))
|
||||
request, _ := http.NewRequest("GET", "/handle", http.NoBody)
|
||||
request.Header.Add("Authorization", fmt.Sprintf("Bearer %s", TestProviderClientToken))
|
||||
responseRecorder := httptest.NewRecorder()
|
||||
|
||||
router := http.NewServeMux()
|
||||
router.Handle("/handle", gate.ProtectWithPermissions(&testHandler{}, []string{SecondTestProviderClientPermission}))
|
||||
router.ServeHTTP(responseRecorder, request)
|
||||
|
||||
// Since the client returned from the mockClientProvider has FirstTestProviderClientPermission and
|
||||
// SecondTestProviderClientPermission and the testHandler is protected by SecondTestProviderClientPermission,
|
||||
// the request should be authorized
|
||||
if responseRecorder.Code != http.StatusOK {
|
||||
t.Errorf("%s %s should have returned %d, but returned %d instead", request.Method, request.URL, http.StatusOK, responseRecorder.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGate_ProtectWithPermissionsWhenValidTokenAndInsufficientPermissionsWhileUsingClientProvider(t *testing.T) {
|
||||
gate := New().WithAuthorizationService(NewAuthorizationService().WithClientProvider(mockClientProvider))
|
||||
request, _ := http.NewRequest("GET", "/handle", http.NoBody)
|
||||
request.Header.Add("Authorization", fmt.Sprintf("Bearer %s", TestProviderClientToken))
|
||||
responseRecorder := httptest.NewRecorder()
|
||||
|
||||
router := http.NewServeMux()
|
||||
router.Handle("/handle", gate.ProtectWithPermissions(&testHandler{}, []string{"unrelated-permission"}))
|
||||
router.ServeHTTP(responseRecorder, request)
|
||||
|
||||
// Since the client returned from the mockClientProvider has FirstTestProviderClientPermission and
|
||||
// SecondTestProviderClientPermission and the testHandler is protected by a permission that the client does not
|
||||
// have, the request should be not be authorized
|
||||
if responseRecorder.Code != http.StatusUnauthorized {
|
||||
t.Errorf("%s %s should have returned %d, but returned %d instead", request.Method, request.URL, http.StatusUnauthorized, responseRecorder.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGate_ProtectWithPermissionsWhenClientHasSufficientPermissions(t *testing.T) {
|
||||
gate := New().WithAuthorizationService(NewAuthorizationService().WithClient(NewClient("token").WithPermission("admin")))
|
||||
request, _ := http.NewRequest("GET", "/handle", http.NoBody)
|
||||
request.Header.Add("Authorization", fmt.Sprintf("Bearer %s", "token"))
|
||||
responseRecorder := httptest.NewRecorder()
|
||||
|
||||
router := http.NewServeMux()
|
||||
router.Handle("/handle", gate.ProtectWithPermissions(&testHandler{}, []string{"admin"}))
|
||||
router.ServeHTTP(responseRecorder, request)
|
||||
|
||||
// Since the client registered directly in the AuthorizationService has the permission "admin" and the testHandler
|
||||
// is protected by the permission "admin", the request should be authorized
|
||||
if responseRecorder.Code != http.StatusOK {
|
||||
t.Errorf("%s %s should have returned %d, but returned %d instead", request.Method, request.URL, http.StatusOK, responseRecorder.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGate_ProtectWithPermissionsWhenClientHasInsufficientPermissions(t *testing.T) {
|
||||
gate := New().WithAuthorizationService(NewAuthorizationService().WithClient(NewClientWithPermissions("token", []string{"mod"})))
|
||||
request, _ := http.NewRequest("GET", "/handle", http.NoBody)
|
||||
request.Header.Add("Authorization", fmt.Sprintf("Bearer %s", "token"))
|
||||
responseRecorder := httptest.NewRecorder()
|
||||
|
||||
router := http.NewServeMux()
|
||||
router.Handle("/handle", gate.ProtectWithPermissions(&testHandler{}, []string{"admin"}))
|
||||
router.ServeHTTP(responseRecorder, request)
|
||||
|
||||
// Since the client registered directly in the AuthorizationService has the permission "mod" and the
|
||||
// testHandler is protected by the permission "admin", the request should be not be authorized
|
||||
if responseRecorder.Code != http.StatusUnauthorized {
|
||||
t.Errorf("%s %s should have returned %d, but returned %d instead", request.Method, request.URL, http.StatusUnauthorized, responseRecorder.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGate_ProtectWithPermissions(t *testing.T) {
|
||||
gate := New().WithAuthorizationService(NewAuthorizationService().WithClient(NewClient("mytoken").WithPermissions([]string{"create", "read", "update", "delete"})))
|
||||
|
||||
router := http.NewServeMux()
|
||||
router.Handle("/create", gate.ProtectWithPermissions(&testHandler{}, []string{"create"}))
|
||||
router.Handle("/read", gate.ProtectWithPermissions(&testHandler{}, []string{"read"}))
|
||||
router.Handle("/update", gate.ProtectWithPermissions(&testHandler{}, []string{"update"}))
|
||||
router.Handle("/delete", gate.ProtectWithPermissions(&testHandler{}, []string{"delete"}))
|
||||
router.Handle("/crud", gate.ProtectWithPermissions(&testHandler{}, []string{"create", "read", "update", "delete"}))
|
||||
router.Handle("/backup", gate.ProtectWithPermissions(&testHandler{}, []string{"read", "backup"}))
|
||||
|
||||
checkRouterOutput := func(t *testing.T, router *http.ServeMux, url string, expectedResponseCode int) {
|
||||
t.Run(strings.TrimPrefix(url, "/"), func(t *testing.T) {
|
||||
request, _ := http.NewRequest("GET", url, http.NoBody)
|
||||
request.Header.Add("Authorization", fmt.Sprintf("Bearer %s", "mytoken"))
|
||||
responseRecorder := httptest.NewRecorder()
|
||||
router.ServeHTTP(responseRecorder, request)
|
||||
if responseRecorder.Code != expectedResponseCode {
|
||||
t.Errorf("%s %s should have returned %d, but returned %d instead", request.Method, request.URL, expectedResponseCode, responseRecorder.Code)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
checkRouterOutput(t, router, "/create", http.StatusOK)
|
||||
checkRouterOutput(t, router, "/read", http.StatusOK)
|
||||
checkRouterOutput(t, router, "/update", http.StatusOK)
|
||||
checkRouterOutput(t, router, "/delete", http.StatusOK)
|
||||
checkRouterOutput(t, router, "/crud", http.StatusOK)
|
||||
checkRouterOutput(t, router, "/backup", http.StatusUnauthorized)
|
||||
}
|
||||
|
||||
func TestGate_ProtectWithPermissionWhenClientHasSufficientPermissions(t *testing.T) {
|
||||
gate := New().WithAuthorizationService(NewAuthorizationService().WithClient(NewClient("token").WithPermission("admin")))
|
||||
request, _ := http.NewRequest("GET", "/handle", http.NoBody)
|
||||
request.Header.Add("Authorization", fmt.Sprintf("Bearer %s", "token"))
|
||||
responseRecorder := httptest.NewRecorder()
|
||||
|
||||
router := http.NewServeMux()
|
||||
router.Handle("/handle", gate.ProtectWithPermission(&testHandler{}, "admin"))
|
||||
router.ServeHTTP(responseRecorder, request)
|
||||
|
||||
// Since the client registered directly in the AuthorizationService has the permission "admin" and the testHandler
|
||||
// is protected by the permission "admin", the request should be authorized
|
||||
if responseRecorder.Code != http.StatusOK {
|
||||
t.Errorf("%s %s should have returned %d, but returned %d instead", request.Method, request.URL, http.StatusOK, responseRecorder.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGate_ProtectWithPermissionWhenClientHasInsufficientPermissions(t *testing.T) {
|
||||
gate := New().WithAuthorizationService(NewAuthorizationService().WithClient(NewClientWithPermissions("token", []string{"mod"})))
|
||||
request, _ := http.NewRequest("GET", "/handle", http.NoBody)
|
||||
request.Header.Add("Authorization", fmt.Sprintf("Bearer %s", "token"))
|
||||
responseRecorder := httptest.NewRecorder()
|
||||
|
||||
router := http.NewServeMux()
|
||||
router.Handle("/handle", gate.ProtectWithPermission(&testHandler{}, "admin"))
|
||||
router.ServeHTTP(responseRecorder, request)
|
||||
|
||||
// Since the client registered directly in the AuthorizationService has the permission "mod" and the
|
||||
// testHandler is protected by the permission "admin", the request should be not be authorized
|
||||
if responseRecorder.Code != http.StatusUnauthorized {
|
||||
t.Errorf("%s %s should have returned %d, but returned %d instead", request.Method, request.URL, http.StatusUnauthorized, responseRecorder.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGate_PermissionMiddlewareWhenClientHasSufficientPermissions(t *testing.T) {
|
||||
gate := New().WithAuthorizationService(NewAuthorizationService().WithClient(NewClient("token").WithPermission("admin")))
|
||||
request, _ := http.NewRequest("GET", "/handle", http.NoBody)
|
||||
request.Header.Add("Authorization", fmt.Sprintf("Bearer %s", "token"))
|
||||
responseRecorder := httptest.NewRecorder()
|
||||
|
||||
router := http.NewServeMux()
|
||||
router.Handle("/handle", gate.PermissionMiddleware("admin")(&testHandler{}))
|
||||
router.ServeHTTP(responseRecorder, request)
|
||||
|
||||
// Since the client registered directly in the AuthorizationService has the permission "admin" and the testHandler
|
||||
// is protected by the permission "admin", the request should be authorized
|
||||
if responseRecorder.Code != http.StatusOK {
|
||||
t.Errorf("%s %s should have returned %d, but returned %d instead", request.Method, request.URL, http.StatusOK, responseRecorder.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGate_PermissionMiddlewareWhenClientHasInsufficientPermissions(t *testing.T) {
|
||||
gate := New().WithAuthorizationService(NewAuthorizationService().WithClient(NewClientWithPermissions("token", []string{"mod"})))
|
||||
request, _ := http.NewRequest("GET", "/handle", http.NoBody)
|
||||
request.Header.Add("Authorization", fmt.Sprintf("Bearer %s", "token"))
|
||||
responseRecorder := httptest.NewRecorder()
|
||||
|
||||
router := http.NewServeMux()
|
||||
router.Handle("/handle", gate.PermissionMiddleware("admin")(&testHandler{}))
|
||||
router.ServeHTTP(responseRecorder, request)
|
||||
|
||||
// Since the client registered directly in the AuthorizationService has the permission "mod" and the
|
||||
// testHandler is protected by the permission "admin", the request should be not be authorized
|
||||
if responseRecorder.Code != http.StatusUnauthorized {
|
||||
t.Errorf("%s %s should have returned %d, but returned %d instead", request.Method, request.URL, http.StatusUnauthorized, responseRecorder.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGate_ProtectFuncWithInvalidToken(t *testing.T) {
|
||||
gate := New().WithAuthorizationService(NewAuthorizationService().WithToken("good-token"))
|
||||
request, _ := http.NewRequest("GET", "/handle", http.NoBody)
|
||||
request.Header.Add("Authorization", fmt.Sprintf("Bearer %s", "bad-token"))
|
||||
responseRecorder := httptest.NewRecorder()
|
||||
|
||||
router := http.NewServeMux()
|
||||
router.Handle("/handle", gate.ProtectFunc(testHandlerFunc))
|
||||
router.ServeHTTP(responseRecorder, request)
|
||||
|
||||
if responseRecorder.Code != http.StatusUnauthorized {
|
||||
t.Errorf("%s %s should have returned %d, but returned %d instead", request.Method, request.URL, http.StatusUnauthorized, responseRecorder.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGate_ProtectFuncWithValidToken(t *testing.T) {
|
||||
gate := New().WithAuthorizationService(NewAuthorizationService().WithToken("good-token"))
|
||||
request, _ := http.NewRequest("GET", "/handle", http.NoBody)
|
||||
request.Header.Add("Authorization", fmt.Sprintf("Bearer %s", "good-token"))
|
||||
responseRecorder := httptest.NewRecorder()
|
||||
|
||||
router := http.NewServeMux()
|
||||
router.Handle("/handle", gate.ProtectFunc(testHandlerFunc))
|
||||
router.ServeHTTP(responseRecorder, request)
|
||||
|
||||
if responseRecorder.Code != http.StatusOK {
|
||||
t.Errorf("%s %s should have returned %d, but returned %d instead", request.Method, request.URL, http.StatusOK, responseRecorder.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGate_ProtectFuncWithPermissionWhenClientHasSufficientPermissions(t *testing.T) {
|
||||
gate := New().WithAuthorizationService(NewAuthorizationService().WithClient(NewClient("token").WithPermission("admin")))
|
||||
request, _ := http.NewRequest("GET", "/handle", http.NoBody)
|
||||
request.Header.Add("Authorization", fmt.Sprintf("Bearer %s", "token"))
|
||||
responseRecorder := httptest.NewRecorder()
|
||||
|
||||
router := http.NewServeMux()
|
||||
router.HandleFunc("/handle", gate.ProtectFuncWithPermission(testHandlerFunc, "admin"))
|
||||
router.ServeHTTP(responseRecorder, request)
|
||||
|
||||
// Since the client registered directly in the AuthorizationService has the permission "admin" and the testHandler
|
||||
// is protected by the permission "admin", the request should be authorized
|
||||
if responseRecorder.Code != http.StatusOK {
|
||||
t.Errorf("%s %s should have returned %d, but returned %d instead", request.Method, request.URL, http.StatusOK, responseRecorder.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGate_ProtectFuncWithPermissionWhenClientHasInsufficientPermissions(t *testing.T) {
|
||||
gate := New().WithAuthorizationService(NewAuthorizationService().WithClient(NewClientWithPermissions("token", []string{"mod"})))
|
||||
request, _ := http.NewRequest("GET", "/handle", http.NoBody)
|
||||
request.Header.Add("Authorization", fmt.Sprintf("Bearer %s", "token"))
|
||||
responseRecorder := httptest.NewRecorder()
|
||||
|
||||
router := http.NewServeMux()
|
||||
router.HandleFunc("/handle", gate.ProtectFuncWithPermission(testHandlerFunc, "admin"))
|
||||
router.ServeHTTP(responseRecorder, request)
|
||||
|
||||
// Since the client registered directly in the AuthorizationService has the permission "mod" and the
|
||||
// testHandler is protected by the permission "admin", the request should be not be authorized
|
||||
if responseRecorder.Code != http.StatusUnauthorized {
|
||||
t.Errorf("%s %s should have returned %d, but returned %d instead", request.Method, request.URL, http.StatusUnauthorized, responseRecorder.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGate_WithCustomUnauthorizedResponseBody(t *testing.T) {
|
||||
gate := New().WithAuthorizationService(NewAuthorizationService()).WithCustomUnauthorizedResponseBody([]byte("test"))
|
||||
request, _ := http.NewRequest("GET", "/handle", http.NoBody)
|
||||
request.Header.Add("Authorization", fmt.Sprintf("Bearer %s", "bad-token"))
|
||||
responseRecorder := httptest.NewRecorder()
|
||||
|
||||
router := http.NewServeMux()
|
||||
router.Handle("/handle", gate.Protect(&testHandler{}))
|
||||
router.ServeHTTP(responseRecorder, request)
|
||||
|
||||
if responseRecorder.Code != http.StatusUnauthorized {
|
||||
t.Errorf("%s %s should have returned %d, but returned %d instead", request.Method, request.URL, http.StatusUnauthorized, responseRecorder.Code)
|
||||
}
|
||||
if responseBody, _ := io.ReadAll(responseRecorder.Body); string(responseBody) != "test" {
|
||||
t.Errorf("%s %s should have returned %s, but returned %s instead", request.Method, request.URL, "test", string(responseBody))
|
||||
}
|
||||
}
|
||||
|
||||
func TestGate_ProtectWithNoAuthorizationService(t *testing.T) {
|
||||
gate := New()
|
||||
request, _ := http.NewRequest("GET", "/handle", http.NoBody)
|
||||
responseRecorder := httptest.NewRecorder()
|
||||
|
||||
router := http.NewServeMux()
|
||||
router.Handle("/handle", gate.Protect(&testHandler{}))
|
||||
router.ServeHTTP(responseRecorder, request)
|
||||
|
||||
if responseRecorder.Code != http.StatusOK {
|
||||
t.Errorf("%s %s should have returned %d, but returned %d instead", request.Method, request.URL, http.StatusOK, responseRecorder.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGate_ProtectWithRateLimit(t *testing.T) {
|
||||
gate := New().WithRateLimit(2)
|
||||
request, _ := http.NewRequest("GET", "/handle", http.NoBody)
|
||||
router := http.NewServeMux()
|
||||
router.Handle("/handle", gate.Protect(&testHandler{}))
|
||||
|
||||
responseRecorder := httptest.NewRecorder()
|
||||
router.ServeHTTP(responseRecorder, request)
|
||||
if responseRecorder.Code != http.StatusOK {
|
||||
t.Errorf("%s %s should have returned %d, but returned %d instead", request.Method, request.URL, http.StatusOK, responseRecorder.Code)
|
||||
}
|
||||
|
||||
responseRecorder = httptest.NewRecorder()
|
||||
router.ServeHTTP(responseRecorder, request)
|
||||
if responseRecorder.Code != http.StatusOK {
|
||||
t.Errorf("%s %s should have returned %d, but returned %d instead", request.Method, request.URL, http.StatusOK, responseRecorder.Code)
|
||||
}
|
||||
|
||||
responseRecorder = httptest.NewRecorder()
|
||||
router.ServeHTTP(responseRecorder, request)
|
||||
if responseRecorder.Code != http.StatusTooManyRequests {
|
||||
t.Errorf("%s %s should have returned %d, but returned %d instead", request.Method, request.URL, http.StatusTooManyRequests, responseRecorder.Code)
|
||||
}
|
||||
|
||||
// Wait for rate limit time window to pass
|
||||
time.Sleep(time.Second)
|
||||
|
||||
responseRecorder = httptest.NewRecorder()
|
||||
router.ServeHTTP(responseRecorder, request)
|
||||
if responseRecorder.Code != http.StatusOK {
|
||||
t.Errorf("%s %s should have returned %d, but returned %d instead", request.Method, request.URL, http.StatusOK, responseRecorder.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGate_WithCustomTokenExtractor(t *testing.T) {
|
||||
authorizationService := NewAuthorizationService().WithClientProvider(mockClientProvider)
|
||||
customTokenExtractorFunc := func(request *http.Request) string {
|
||||
sessionCookie, err := request.Cookie("session")
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return sessionCookie.Value
|
||||
}
|
||||
gate := New().WithAuthorizationService(authorizationService).WithCustomTokenExtractor(customTokenExtractorFunc)
|
||||
|
||||
request, _ := http.NewRequest("GET", "/handle", http.NoBody)
|
||||
request.AddCookie(&http.Cookie{Name: "session", Value: TestProviderClientToken})
|
||||
responseRecorder := httptest.NewRecorder()
|
||||
|
||||
router := http.NewServeMux()
|
||||
router.Handle("/handle", gate.ProtectFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Context().Value(TokenContextKey) != TestProviderClientToken {
|
||||
t.Errorf("token should have been passed to the request context")
|
||||
}
|
||||
if r.Context().Value(DataContextKey) != TestProviderClientData {
|
||||
t.Errorf("data should have been passed to the request context")
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
router.ServeHTTP(responseRecorder, request)
|
||||
|
||||
if responseRecorder.Code != http.StatusOK {
|
||||
t.Errorf("%s %s should have returned %d, but returned %d instead", request.Method, request.URL, http.StatusOK, responseRecorder.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGateWithCustomHeader(t *testing.T) {
|
||||
authorizationService := NewAuthorizationService().WithClientProvider(mockClientProvider)
|
||||
customTokenExtractorFunc := func(request *http.Request) string {
|
||||
return request.Header.Get("X-API-Token")
|
||||
}
|
||||
gate := New().WithAuthorizationService(authorizationService).WithCustomTokenExtractor(customTokenExtractorFunc)
|
||||
|
||||
request, _ := http.NewRequest("GET", "/handle", http.NoBody)
|
||||
request.Header.Add("X-API-Token", TestProviderClientToken)
|
||||
responseRecorder := httptest.NewRecorder()
|
||||
|
||||
router := http.NewServeMux()
|
||||
router.Handle("/handle", gate.ProtectFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Context().Value(TokenContextKey) != TestProviderClientToken {
|
||||
t.Errorf("token should have been passed to the request context")
|
||||
}
|
||||
if r.Context().Value(DataContextKey) != TestProviderClientData {
|
||||
t.Errorf("data should have been passed to the request context")
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
router.ServeHTTP(responseRecorder, request)
|
||||
|
||||
if responseRecorder.Code != http.StatusOK {
|
||||
t.Errorf("%s %s should have returned %d, but returned %d instead", request.Method, request.URL, http.StatusOK, responseRecorder.Code)
|
||||
}
|
||||
}
|
5
go.mod
Normal file
5
go.mod
Normal file
|
@ -0,0 +1,5 @@
|
|||
module github.com/TwiN/g8/v3
|
||||
|
||||
go 1.23.3
|
||||
|
||||
require github.com/TwiN/gocache/v2 v2.2.2
|
2
go.sum
Normal file
2
go.sum
Normal file
|
@ -0,0 +1,2 @@
|
|||
github.com/TwiN/gocache/v2 v2.2.2 h1:4HToPfDV8FSbaYO5kkbhLpEllUYse5rAf+hVU/mSsuI=
|
||||
github.com/TwiN/gocache/v2 v2.2.2/go.mod h1:WfIuwd7GR82/7EfQqEtmLFC3a2vqaKbs4Pe6neB7Gyc=
|
42
ratelimiter.go
Normal file
42
ratelimiter.go
Normal file
|
@ -0,0 +1,42 @@
|
|||
package g8
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// RateLimiter is a fixed rate limiter
|
||||
type RateLimiter struct {
|
||||
maximumExecutionsPerSecond int
|
||||
executionsLeftInWindow int
|
||||
windowStartTime time.Time
|
||||
mutex sync.Mutex
|
||||
}
|
||||
|
||||
// NewRateLimiter creates a RateLimiter
|
||||
func NewRateLimiter(maximumExecutionsPerSecond int) *RateLimiter {
|
||||
return &RateLimiter{
|
||||
windowStartTime: time.Now(),
|
||||
executionsLeftInWindow: maximumExecutionsPerSecond,
|
||||
maximumExecutionsPerSecond: maximumExecutionsPerSecond,
|
||||
}
|
||||
}
|
||||
|
||||
// Try updates the number of executions if the rate limit quota hasn't been reached and returns whether the
|
||||
// attempt was successful or not.
|
||||
//
|
||||
// Returns false if the execution was not successful (rate limit quota has been reached)
|
||||
// Returns true if the execution was successful (rate limit quota has not been reached)
|
||||
func (r *RateLimiter) Try() bool {
|
||||
r.mutex.Lock()
|
||||
defer r.mutex.Unlock()
|
||||
if time.Now().Add(-time.Second).After(r.windowStartTime) {
|
||||
r.windowStartTime = time.Now()
|
||||
r.executionsLeftInWindow = r.maximumExecutionsPerSecond
|
||||
}
|
||||
if r.executionsLeftInWindow == 0 {
|
||||
return false
|
||||
}
|
||||
r.executionsLeftInWindow--
|
||||
return true
|
||||
}
|
73
ratelimiter_test.go
Normal file
73
ratelimiter_test.go
Normal file
|
@ -0,0 +1,73 @@
|
|||
package g8
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestNewRateLimiter(t *testing.T) {
|
||||
rl := NewRateLimiter(2)
|
||||
if rl.maximumExecutionsPerSecond != 2 {
|
||||
t.Errorf("expected maximumExecutionsPerSecond to be %d, got %d", 2, rl.maximumExecutionsPerSecond)
|
||||
}
|
||||
if rl.executionsLeftInWindow != 2 {
|
||||
t.Errorf("expected executionsLeftInWindow to be %d, got %d", 2, rl.executionsLeftInWindow)
|
||||
}
|
||||
// First execution: should not be rate limited
|
||||
if notRateLimited := rl.Try(); !notRateLimited {
|
||||
t.Error("expected Try to return true")
|
||||
}
|
||||
if rl.maximumExecutionsPerSecond != 2 {
|
||||
t.Errorf("expected maximumExecutionsPerSecond to be %d, got %d", 2, rl.maximumExecutionsPerSecond)
|
||||
}
|
||||
if rl.executionsLeftInWindow != 1 {
|
||||
t.Errorf("expected executionsLeftInWindow to be %d, got %d", 1, rl.executionsLeftInWindow)
|
||||
}
|
||||
// Second execution: should not be rate limited
|
||||
if notRateLimited := rl.Try(); !notRateLimited {
|
||||
t.Error("expected Try to return true")
|
||||
}
|
||||
if rl.maximumExecutionsPerSecond != 2 {
|
||||
t.Errorf("expected maximumExecutionsPerSecond to be %d, got %d", 2, rl.maximumExecutionsPerSecond)
|
||||
}
|
||||
if rl.executionsLeftInWindow != 0 {
|
||||
t.Errorf("expected executionsLeftInWindow to be %d, got %d", 0, rl.executionsLeftInWindow)
|
||||
}
|
||||
// Third execution: should be rate limited
|
||||
if notRateLimited := rl.Try(); notRateLimited {
|
||||
t.Error("expected Try to return false")
|
||||
}
|
||||
if rl.maximumExecutionsPerSecond != 2 {
|
||||
t.Errorf("expected maximumExecutionsPerSecond to be %d, got %d", 2, rl.maximumExecutionsPerSecond)
|
||||
}
|
||||
if rl.executionsLeftInWindow != 0 {
|
||||
t.Errorf("expected executionsLeftInWindow to be %d, got %d", 0, rl.executionsLeftInWindow)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRateLimiter_Try(t *testing.T) {
|
||||
rl := NewRateLimiter(5)
|
||||
for i := 0; i < 20; i++ {
|
||||
notRateLimited := rl.Try()
|
||||
if i < 5 {
|
||||
if !notRateLimited {
|
||||
t.Fatal("expected to not be rate limited")
|
||||
}
|
||||
} else {
|
||||
if notRateLimited {
|
||||
t.Fatal("expected to be rate limited")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRateLimiter_TryAlwaysUnderRateLimit(t *testing.T) {
|
||||
rl := NewRateLimiter(20)
|
||||
for i := 0; i < 45; i++ {
|
||||
notRateLimited := rl.Try()
|
||||
if !notRateLimited {
|
||||
t.Fatal("expected to not be rate limited")
|
||||
}
|
||||
time.Sleep(51 * time.Millisecond)
|
||||
}
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue