From 8274b1bf1b01b0de0e5bfda51e96569fd8dcc3ff Mon Sep 17 00:00:00 2001 From: Daniel Baumann Date: Fri, 16 May 2025 22:32:21 +0200 Subject: [PATCH] Adding upstream version 3.0.0. Signed-off-by: Daniel Baumann --- .gitattributes | 1 + .github/codecov.yml | 7 + .github/dependabot.yml | 14 + .github/workflows/test.yml | 30 +++ .gitignore | 3 + LICENSE | 21 ++ README.md | 319 ++++++++++++++++++++++ authorization.go | 130 +++++++++ authorization_test.go | 108 ++++++++ cache.go | 13 + client.go | 79 ++++++ client_test.go | 76 ++++++ clientprovider.go | 102 +++++++ clientprovider_test.go | 131 +++++++++ gate.go | 245 +++++++++++++++++ gate_bench_test.go | 208 ++++++++++++++ gate_test.go | 538 +++++++++++++++++++++++++++++++++++++ go.mod | 5 + go.sum | 2 + ratelimiter.go | 42 +++ ratelimiter_test.go | 73 +++++ 21 files changed, 2147 insertions(+) create mode 100644 .gitattributes create mode 100644 .github/codecov.yml create mode 100644 .github/dependabot.yml create mode 100644 .github/workflows/test.yml create mode 100644 .gitignore create mode 100644 LICENSE create mode 100644 README.md create mode 100644 authorization.go create mode 100644 authorization_test.go create mode 100644 cache.go create mode 100644 client.go create mode 100644 client_test.go create mode 100644 clientprovider.go create mode 100644 clientprovider_test.go create mode 100644 gate.go create mode 100644 gate_bench_test.go create mode 100644 gate_test.go create mode 100644 go.mod create mode 100644 go.sum create mode 100644 ratelimiter.go create mode 100644 ratelimiter_test.go diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 0000000..7d07d70 --- /dev/null +++ b/.gitattributes @@ -0,0 +1 @@ +* text=lf diff --git a/.github/codecov.yml b/.github/codecov.yml new file mode 100644 index 0000000..61a25cc --- /dev/null +++ b/.github/codecov.yml @@ -0,0 +1,7 @@ +coverage: + status: + patch: off + project: + default: + target: 75% + threshold: null \ No newline at end of file diff --git a/.github/dependabot.yml b/.github/dependabot.yml new file mode 100644 index 0000000..88431dc --- /dev/null +++ b/.github/dependabot.yml @@ -0,0 +1,14 @@ +version: 2 +updates: + - package-ecosystem: "github-actions" + directory: "/" + labels: ["dependencies"] + schedule: + interval: "weekly" + day: "saturday" + - package-ecosystem: "gomod" + directory: "/" + labels: ["dependencies"] + schedule: + interval: "weekly" + day: "saturday" diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml new file mode 100644 index 0000000..e9b8f08 --- /dev/null +++ b/.github/workflows/test.yml @@ -0,0 +1,30 @@ +name: test +on: + pull_request: + paths-ignore: + - '*.md' + push: + branches: + - master + paths-ignore: + - '*.md' +jobs: + test: + name: test + runs-on: ubuntu-latest + timeout-minutes: 3 + steps: + - uses: actions/setup-go@v5 + with: + go-version: 1.23.3 + - uses: actions/checkout@v4 + - name: Test (race) + run: go test ./... -race + - name: Test (coverage) + run: go test ./... -coverprofile=coverage.txt -covermode=atomic + - name: Codecov + uses: codecov/codecov-action@v5.1.2 + with: + files: ./coverage.txt + token: ${{ secrets.CODECOV_TOKEN }} + diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..887a7c3 --- /dev/null +++ b/.gitignore @@ -0,0 +1,3 @@ +.idea +*.iml +/vendor \ No newline at end of file diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..7be0409 --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2021-2025 TwiN + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/README.md b/README.md new file mode 100644 index 0000000..2ac1bd8 --- /dev/null +++ b/README.md @@ -0,0 +1,319 @@ +# g8 + +![test](https://github.com/TwiN/g8/workflows/test/badge.svg?branch=master) +[![Go Report Card](https://goreportcard.com/badge/github.com/TwiN/g8)](https://goreportcard.com/report/github.com/TwiN/g8/v3) +[![codecov](https://codecov.io/gh/TwiN/g8/branch/master/graph/badge.svg)](https://codecov.io/gh/TwiN/g8) +[![Go version](https://img.shields.io/github/go-mod/go-version/TwiN/g8.svg)](https://github.com/TwiN/g8) +[![Go Reference](https://pkg.go.dev/badge/github.com/TwiN/g8.svg)](https://pkg.go.dev/github.com/TwiN/g8/v3) +[![Follow TwiN](https://img.shields.io/github/followers/TwiN?label=Follow&style=social)](https://github.com/TwiN) + +g8, pronounced gate, is a simple Go library for protecting HTTP handlers. + +Tired of constantly re-implementing a security layer for each application? Me too, that's why I made g8. + + +## Installation +```console +go get -u github.com/TwiN/g8/v3 +``` + + +## Usage +Because the entire purpose of g8 is to NOT waste time configuring the layer of security, the primary emphasis is to +keep it as simple as possible. + + +### Simple +Just want a simple layer of security without the need for advanced permissions? This configuration is what you're +looking for. + +```go +authorizationService := g8.NewAuthorizationService().WithToken("mytoken") +gate := g8.New().WithAuthorizationService(authorizationService) + +router := http.NewServeMux() +router.Handle("/unprotected", yourHandler) +router.Handle("/protected", gate.Protect(yourHandler)) + +http.ListenAndServe(":8080", router) +``` + +The endpoint `/protected` is now only accessible if you pass the header `Authorization: Bearer mytoken`. + +If you use `http.HandleFunc` instead of `http.Handle`, you may use `gate.ProtectFunc(yourHandler)` instead. + +If you're not using the `Authorization` header, you can specify a custom token extractor. +This enables use cases like [Protecting a handler using session cookie](#protecting-a-handler-using-session-cookie) + + +### Advanced permissions +If you have tokens with more permissions than others, g8's permission system will make managing authorization a breeze. + +Rather than registering tokens, think of it as registering clients, the only difference being that clients may be +configured with permissions while tokens cannot. + +```go +authorizationService := g8.NewAuthorizationService().WithClient(g8.NewClient("mytoken").WithPermission("admin")) +gate := g8.New().WithAuthorizationService(authorizationService) + +router := http.NewServeMux() +router.Handle("/unprotected", yourHandler) +router.Handle("/protected-with-admin", gate.ProtectWithPermissions(yourHandler, []string{"admin"})) + +http.ListenAndServe(":8080", router) +``` + +The endpoint `/protected-with-admin` is now only accessible if you pass the header `Authorization: Bearer mytoken`, +because the client with the token `mytoken` has the permission `admin`. Note that the following handler would also be +accessible with that token: +```go +router.Handle("/protected", gate.Protect(yourHandler)) +``` + +To clarify, both clients and tokens have access to handlers that aren't protected with extra permissions, and +essentially, tokens are registered as clients with no extra permissions in the background. + +Creating a token like so: +```go +authorizationService := g8.NewAuthorizationService().WithToken("mytoken") +``` +is the equivalent of creating the following client: +```go +authorizationService := g8.NewAuthorizationService().WithClient(g8.NewClient("mytoken")) +``` + + +### With client provider +A client provider's task is to retrieve a Client from an external source (e.g. a database) when provided with a token. +You should use a client provider when you have a lot of tokens and it wouldn't make sense to register all of them using +`AuthorizationService`'s `WithToken`/`WithTokens`/`WithClient`/`WithClients`. + +Note that the provider is used as a fallback source. As such, if a token is explicitly registered using one of the 4 +aforementioned functions, the client provider will not be used. + +```go +clientProvider := g8.NewClientProvider(func(token string) *g8.Client { + // We'll assume that the following function calls your database and returns a struct "User" that + // has the user's token as well as the permissions granted to said user + user := database.GetUserByToken(token) + if user != nil { + return g8.NewClient(user.Token).WithPermissions(user.Permissions) + } + return nil +}) +authorizationService := g8.NewAuthorizationService().WithClientProvider(clientProvider) +gate := g8.New().WithAuthorizationService(authorizationService) +``` + +You can also configure the client provider to cache the output of the function you provide to retrieve clients by token: +```go +clientProvider := g8.NewClientProvider(...).WithCache(ttl, maxSize) +``` + +Since g8 leverages [TwiN/gocache](https://github.com/TwiN/gocache) (unless you're using `WithCustomCache`), +you can also use gocache's constants for configuring the TTL and the maximum size: +- Setting the TTL to `gocache.NoExpiration` (-1) will disable the TTL. +- Setting the maximum size to `gocache.NoMaxSize` (0) will disable the maximum cache size + +To avoid any misunderstandings, using a client provider is not mandatory. If you only have a few tokens and you can load +them on application start, you can just leverage `AuthorizationService`'s `WithToken`/`WithTokens`/`WithClient`/`WithClients`. + + +## AuthorizationService +As the previous examples may have hinted, there are several ways to create clients. The one thing they have +in common is that they all go through AuthorizationService, which is in charge of both managing clients and determining +whether a request should be blocked or allowed through. + +| Function | Description | +|:-------------------|:---------------------------------------------------------------------------------------------------------------------------------| +| WithToken | Creates a single static client with no extra permissions | +| WithTokens | Creates a slice of static clients with no extra permissions | +| WithClient | Creates a single static client | +| WithClients | Creates a slice of static clients | +| WithClientProvider | Creates a client provider which will allow a fallback to a dynamic source (e.g. to a database) when a static client is not found | + +Except for `WithClientProvider`, every functions listed above can be called more than once. +As a result, you may safely perform actions like this: +```go +authorizationService := g8.NewAuthorizationService(). + WithToken("123"). + WithToken("456"). + WithClient(g8.NewClient("789").WithPermission("admin")) +gate := g8.New().WithAuthorizationService(authorizationService) +``` + +Be aware that g8.Client supports a list of permissions as well. You may call `WithPermission` several times, or call +`WithPermissions` with a slice of permissions instead. + + +### Permissions +Unlike client permissions, handler permissions are requirements. + +A client may have as many permissions as you want, but for said client to have access to a handler protected by +permissions, the client must have all permissions defined by said handler in order to have access to it. + +In other words, a client with the permissions `create`, `read`, `update` and `delete` would have access to all of these handlers: +```go +gate := g8.New().WithAuthorizationService(g8.NewAuthorizationService().WithClient(g8.NewClient("mytoken").WithPermissions([]string{"create", "read", "update", "delete"}))) +router := http.NewServeMux() +router.Handle("/", gate.Protect(homeHandler)) // equivalent of gate.ProtectWithPermissions(homeHandler, []string{}) +router.Handle("/create", gate.ProtectWithPermissions(createHandler, []string{"create"})) +router.Handle("/read", gate.ProtectWithPermissions(readHandler, []string{"read"})) +router.Handle("/update", gate.ProtectWithPermissions(updateHandler, []string{"update"})) +router.Handle("/delete", gate.ProtectWithPermissions(deleteHandler, []string{"delete"})) +router.Handle("/crud", gate.ProtectWithPermissions(crudHandler, []string{"create", "read", "update", "delete"})) +``` +But it would not have access to the following handler, because while `mytoken` has the `read` permission, it does not +have the `backup` permission: +```go +router.Handle("/backup", gate.ProtectWithPermissions(&testHandler{}, []string{"read", "backup"})) +``` + +If you're using an HTTP library that supports middlewares like [mux](https://github.com/gorilla/mux), you can protect +an entire group of handlers instead using `gate.Protect` or `gate.PermissionMiddleware()`: +```go +router := mux.NewRouter() + +userRouter := router.PathPrefix("/").Subrouter() +userRouter.Use(gate.Protect) +userRouter.HandleFunc("/api/v1/users/me", getUserProfile).Methods("GET") +userRouter.HandleFunc("/api/v1/users/me/friends", getUserFriends).Methods("GET") +userRouter.HandleFunc("/api/v1/users/me/email", updateUserEmail).Methods("PATCH") + +adminRouter := router.PathPrefix("/").Subrouter() +adminRouter.Use(gate.PermissionMiddleware("admin")) +adminRouter.HandleFunc("/api/v1/users/{id}/ban", banUserByID).Methods("POST") +adminRouter.HandleFunc("/api/v1/users/{id}/delete", deleteUserByID).Methods("DELETE") +``` + + +## Rate limiting +To add a rate limit of 100 requests per second: +```go +gate := g8.New().WithRateLimit(100) +``` + + +## Accessing the token from the protected handlers +If you need to access the token from the handlers you are protecting with g8, you can retrieve it from the +request context by using the key `g8.TokenContextKey`: +```go +http.Handle("/handle", gate.ProtectFunc(func(w http.ResponseWriter, r *http.Request) { + token, _ := r.Context().Value(g8.TokenContextKey).(string) + // ... +})) +``` + +## Examples +### Protecting a handler using session cookie +If you want to only allow authenticated users to access a handler, you can use a custom token extractor function +combined with a client provider. + +First, we'll create a function to extract the session ID from the session cookie. While a session ID does not +theoretically refer to a token, g8 uses the term `token` as a blanket term to refer to any string that can be used to +identify a client. +```go +customTokenExtractorFunc := func(request *http.Request) string { + sessionCookie, err := request.Cookie("session") + if err != nil { + return "" + } + return sessionCookie.Value +} +``` + +Next, we need to create a client provider that will validate our token, which refers to the session ID in this case. +```go +clientProvider := g8.NewClientProvider(func(token string) *g8.Client { + // We'll assume that the following function calls your database and validates whether the session is valid. + isSessionValid := database.CheckIfSessionIsValid(token) + if !isSessionValid { + return nil // Returning nil will cause the gate to return a 401 Unauthorized. + } + // You could also retrieve the user and their permissions if you wanted instead, but for this example, + // all we care about is confirming whether the session is valid or not. + return g8.NewClient(token) +}) +``` + +Keep in mind that you can get really creative with the client provider above. +For instance, you could refresh the session's expiration time, which will allow the user to stay logged in for +as long as they're active. + +You're also not limited to using something stateful like the example above. You could use a JWT and have your client +provider validate said JWT. + +Finally, we can create the authorization service and the gate: +```go +authorizationService := g8.NewAuthorizationService().WithClientProvider(clientProvider) +gate := g8.New().WithAuthorizationService(authorizationService).WithCustomTokenExtractor(customTokenExtractorFunc) +``` + +If you need to access the token (session ID in this case) from the protected handlers, you can retrieve it from the +request context by using the key `g8.TokenContextKey`: +```go +http.Handle("/handle", gate.ProtectFunc(func(w http.ResponseWriter, r *http.Request) { + sessionID, _ := r.Context().Value(g8.TokenContextKey).(string) + // ... +})) +``` + +### Using a custom header +The logic is the same as the example above: +```go +customTokenExtractorFunc := func(request *http.Request) string { + return request.Header.Get("X-API-Token") +} + +clientProvider := g8.NewClientProvider(func(token string) *g8.Client { + // We'll assume that the following function calls your database and returns a struct "User" that + // has the user's token as well as the permissions granted to said user + user := database.GetUserByToken(token) + if user != nil { + return g8.NewClient(user.Token).WithPermissions(user.Permissions) + } + return nil +}) +authorizationService := g8.NewAuthorizationService().WithClientProvider(clientProvider) +gate := g8.New().WithAuthorizationService(authorizationService).WithCustomTokenExtractor(customTokenExtractorFunc) +``` + +### Using a custom cache + +```go +package main + +import ( + g8 "github.com/TwiN/g8/v3" +) + +type customCache struct { + entries map[string]any + sync.Mutex +} + +func (c *customCache) Get(key string) (value any, exists bool) { + return nil, false +} + +func (c *customCache) Set(key string, value any) { + // ... +} + +// To verify the implementation +var _ g8.Cache = (*customCache)(nil) + +func main() { + getClientByTokenFunc := func(token string) *g8.Client { + // We'll assume that the following function calls your database and returns a struct "User" that + // has the user's token as well as the permissions granted to said user + user := database.GetUserByToken(token) + if user != nil { + return g8.NewClient(user.Token).WithPermissions(user.Permissions).WithData(user.Data) + } + return nil + } + // Create the provider with the custom cache + provider := g8.NewClientProvider(getClientByTokenFunc).WithCustomCache(&customCache{}) +} +``` \ No newline at end of file diff --git a/authorization.go b/authorization.go new file mode 100644 index 0000000..5b41699 --- /dev/null +++ b/authorization.go @@ -0,0 +1,130 @@ +package g8 + +import ( + "sync" +) + +// AuthorizationService is the service that manages client/token registry and client fallback as well as the service +// that determines whether a token meets the specific requirements to be authorized by a Gate or not. +type AuthorizationService struct { + clients map[string]*Client + clientProvider *ClientProvider + + mutex sync.RWMutex +} + +// NewAuthorizationService creates a new AuthorizationService +func NewAuthorizationService() *AuthorizationService { + return &AuthorizationService{ + clients: make(map[string]*Client), + } +} + +// WithToken is used to specify a single token for which authorization will be granted +// +// The client that will be created from this token will have access to all handlers that are not protected with a +// specific permission. +// +// In other words, if you were to do the following: +// +// gate := g8.New().WithAuthorizationService(g8.NewAuthorizationService().WithToken("12345")) +// +// The following handler would be accessible with the token 12345: +// +// router.Handle("/1st-handler", gate.Protect(yourHandler)) +// +// But not this one would not be accessible with the token 12345: +// +// router.Handle("/2nd-handler", gate.ProtectWithPermissions(yourOtherHandler, []string{"admin"})) +// +// Calling this function multiple times will add multiple clients, though you may want to use WithTokens instead +// if you plan to add multiple clients +// +// If you wish to configure advanced permissions, consider using WithClient instead. +func (authorizationService *AuthorizationService) WithToken(token string) *AuthorizationService { + authorizationService.mutex.Lock() + authorizationService.clients[token] = NewClient(token) + authorizationService.mutex.Unlock() + return authorizationService +} + +// WithTokens is used to specify a slice of tokens for which authorization will be granted +func (authorizationService *AuthorizationService) WithTokens(tokens []string) *AuthorizationService { + authorizationService.mutex.Lock() + for _, token := range tokens { + authorizationService.clients[token] = NewClient(token) + } + authorizationService.mutex.Unlock() + return authorizationService +} + +// WithClient is used to specify a single client for which authorization will be granted +// +// When compared to WithToken, the advantage of using this function is that you may specify the client's +// permissions and thus, be a lot more granular with what endpoint a token has access to. +// +// In other words, if you were to do the following: +// +// gate := g8.New().WithAuthorizationService(g8.NewAuthorizationService().WithClient(g8.NewClient("12345").WithPermission("mod"))) +// +// The following handlers would be accessible with the token 12345: +// +// router.Handle("/1st-handler", gate.ProtectWithPermissions(yourHandler, []string{"mod"})) +// router.Handle("/2nd-handler", gate.Protect(yourOtherHandler)) +// +// But not this one, because the user does not have the permission "admin": +// +// router.Handle("/3rd-handler", gate.ProtectWithPermissions(yetAnotherHandler, []string{"admin"})) +// +// Calling this function multiple times will add multiple clients, though you may want to use WithClients instead +// if you plan to add multiple clients +func (authorizationService *AuthorizationService) WithClient(client *Client) *AuthorizationService { + authorizationService.mutex.Lock() + authorizationService.clients[client.Token] = client + authorizationService.mutex.Unlock() + return authorizationService +} + +// WithClients is used to specify a slice of clients for which authorization will be granted +func (authorizationService *AuthorizationService) WithClients(clients []*Client) *AuthorizationService { + authorizationService.mutex.Lock() + for _, client := range clients { + authorizationService.clients[client.Token] = client + } + authorizationService.mutex.Unlock() + return authorizationService +} + +// WithClientProvider allows specifying a custom provider to fetch clients by token. +// +// For example, you can use it to fallback to making a call in your database when a request is made with a token that +// hasn't been specified via WithToken, WithTokens, WithClient or WithClients. +func (authorizationService *AuthorizationService) WithClientProvider(provider *ClientProvider) *AuthorizationService { + authorizationService.clientProvider = provider + return authorizationService +} + +// Authorize checks whether a client with a given token exists and has the permissions required. +// +// If permissionsRequired is nil or empty and a client with the given token exists, said client will have access to all +// handlers that are not protected by a given permission. +// +// Returns the client is authorized (or nil if no client was authorized), as well as whether the token is authorized +func (authorizationService *AuthorizationService) Authorize(token string, permissionsRequired []string) (client *Client, authorized bool) { + if len(token) == 0 { + return nil, false + } + authorizationService.mutex.RLock() + client, _ = authorizationService.clients[token] + authorizationService.mutex.RUnlock() + // If there's no clients with the given token directly stored in the AuthorizationService, fall back to the + // client provider, if there's one configured. + if client == nil && authorizationService.clientProvider != nil { + client = authorizationService.clientProvider.GetClientByToken(token) + } + if client != nil && client.HasPermissions(permissionsRequired) { + // If the client has the required permissions, return true and the client + return client, true + } + return nil, false +} diff --git a/authorization_test.go b/authorization_test.go new file mode 100644 index 0000000..a339bdb --- /dev/null +++ b/authorization_test.go @@ -0,0 +1,108 @@ +package g8 + +import "testing" + +func TestAuthorizationService_Authorize(t *testing.T) { + authorizationService := NewAuthorizationService().WithToken("token") + if _, authorized := authorizationService.Authorize("token", nil); !authorized { + t.Error("should've returned true") + } + if _, authorized := authorizationService.Authorize("bad-token", nil); authorized { + t.Error("should've returned false") + } + if _, authorized := authorizationService.Authorize("token", []string{"admin"}); authorized { + t.Error("should've returned false") + } + if _, authorized := authorizationService.Authorize("", nil); authorized { + t.Error("should've returned false") + } +} + +func TestAuthorizationService_AuthorizeWithPermissions(t *testing.T) { + authorizationService := NewAuthorizationService().WithClient(NewClient("token").WithPermissions([]string{"a", "b"})) + if _, authorized := authorizationService.Authorize("token", nil); !authorized { + t.Error("should've returned true") + } + if _, authorized := authorizationService.Authorize("token", []string{"a"}); !authorized { + t.Error("should've returned true") + } + if _, authorized := authorizationService.Authorize("token", []string{"b"}); !authorized { + t.Error("should've returned true") + } + if _, authorized := authorizationService.Authorize("token", []string{"a", "b"}); !authorized { + t.Error("should've returned true") + } + if _, authorized := authorizationService.Authorize("token", []string{"c"}); authorized { + t.Error("should've returned false") + } + if _, authorized := authorizationService.Authorize("token", []string{"a", "c"}); authorized { + t.Error("should've returned false") + } + if _, authorized := authorizationService.Authorize("bad-token", nil); authorized { + t.Error("should've returned false") + } + if _, authorized := authorizationService.Authorize("bad-token", []string{"a"}); authorized { + t.Error("should've returned false") + } + if _, authorized := authorizationService.Authorize("", []string{"a"}); authorized { + t.Error("should've returned false") + } +} + +func TestAuthorizationService_WithToken(t *testing.T) { + authorizationService := NewAuthorizationService().WithToken("token") + if _, authorized := authorizationService.Authorize("token", nil); !authorized { + t.Error("should've returned true") + } + if _, authorized := authorizationService.Authorize("bad-token", nil); authorized { + t.Error("should've returned false") + } + if _, authorized := authorizationService.Authorize("token", []string{"admin"}); authorized { + t.Error("should've returned false") + } +} + +func TestAuthorizationService_WithTokens(t *testing.T) { + authorizationService := NewAuthorizationService().WithTokens([]string{"1", "2"}) + if _, authorized := authorizationService.Authorize("1", nil); !authorized { + t.Error("should've returned true") + } + if _, authorized := authorizationService.Authorize("2", nil); !authorized { + t.Error("should've returned true") + } + if _, authorized := authorizationService.Authorize("3", nil); authorized { + t.Error("should've returned false") + } +} + +func TestAuthorizationService_WithClient(t *testing.T) { + authorizationService := NewAuthorizationService().WithClient(NewClient("token").WithPermissions([]string{"a", "b"})) + if _, authorized := authorizationService.Authorize("token", []string{"a", "b"}); !authorized { + t.Error("should've returned true") + } + if _, authorized := authorizationService.Authorize("token", []string{"a"}); !authorized { + t.Error("should've returned true") + } + if _, authorized := authorizationService.Authorize("token", []string{"b"}); !authorized { + t.Error("should've returned true") + } + if _, authorized := authorizationService.Authorize("token", []string{"c"}); authorized { + t.Error("should've returned false") + } +} + +func TestAuthorizationService_WithClients(t *testing.T) { + authorizationService := NewAuthorizationService().WithClients([]*Client{NewClient("1").WithPermission("a"), NewClient("2").WithPermission("b")}) + if _, authorized := authorizationService.Authorize("1", []string{"a"}); !authorized { + t.Error("should've returned true") + } + if _, authorized := authorizationService.Authorize("2", []string{"b"}); !authorized { + t.Error("should've returned true") + } + if _, authorized := authorizationService.Authorize("1", []string{"b"}); authorized { + t.Error("should've returned false") + } + if _, authorized := authorizationService.Authorize("2", []string{"a"}); authorized { + t.Error("should've returned false") + } +} diff --git a/cache.go b/cache.go new file mode 100644 index 0000000..634d4a4 --- /dev/null +++ b/cache.go @@ -0,0 +1,13 @@ +package g8 + +import ( + "github.com/TwiN/gocache/v2" +) + +type Cache interface { + Get(key string) (value any, exists bool) + Set(key string, value any) +} + +// Make sure that gocache.Cache is compatible with the interface +var _ Cache = (*gocache.Cache)(nil) diff --git a/client.go b/client.go new file mode 100644 index 0000000..89cf45c --- /dev/null +++ b/client.go @@ -0,0 +1,79 @@ +package g8 + +// Client is a struct containing both a Token and a slice of extra Permissions that said token has. +type Client struct { + // Token is the value used to authenticate with the API. + Token string + + // Permissions is a slice of extra permissions that may be used for more granular access control. + // + // If you only wish to use Gate.Protect and Gate.ProtectFunc, you do not have to worry about this, + // since they're only used by Gate.ProtectWithPermissions and Gate.ProtectFuncWithPermissions + Permissions []string + + // Data is a field that can be used to store any data you want to associate with the client. + Data any +} + +// NewClient creates a Client with a given token +func NewClient(token string) *Client { + return &Client{ + Token: token, + } +} + +// NewClientWithPermissions creates a Client with a slice of permissions +// Equivalent to using NewClient and WithPermissions +func NewClientWithPermissions(token string, permissions []string) *Client { + return NewClient(token).WithPermissions(permissions) +} + +// NewClientWithData creates a Client with some data +// Equivalent to using NewClient and WithData +func NewClientWithData(token string, data any) *Client { + return NewClient(token).WithData(data) +} + +// NewClientWithPermissionsAndData creates a Client with a slice of permissions and some data +// Equivalent to using NewClient, WithPermissions and WithData +func NewClientWithPermissionsAndData(token string, permissions []string, data any) *Client { + return NewClient(token).WithPermissions(permissions).WithData(data) +} + +// WithPermissions adds a slice of permissions to a client +func (client *Client) WithPermissions(permissions []string) *Client { + client.Permissions = append(client.Permissions, permissions...) + return client +} + +// WithPermission adds a permission to a client +func (client *Client) WithPermission(permission string) *Client { + client.Permissions = append(client.Permissions, permission) + return client +} + +// WithData attaches data to a client +func (client *Client) WithData(data any) *Client { + client.Data = data + return client +} + +// HasPermission checks whether a client has a given permission +func (client *Client) HasPermission(permissionRequired string) bool { + for _, permission := range client.Permissions { + if permissionRequired == permission { + return true + } + } + return false +} + +// HasPermissions checks whether a client has the all permissions passed +func (client *Client) HasPermissions(permissionsRequired []string) bool { + for _, permissionRequired := range permissionsRequired { + if !client.HasPermission(permissionRequired) { + return false + } + } + return true +} diff --git a/client_test.go b/client_test.go new file mode 100644 index 0000000..efa15a1 --- /dev/null +++ b/client_test.go @@ -0,0 +1,76 @@ +package g8 + +import "testing" + +func TestClient_HasPermission(t *testing.T) { + client := NewClientWithPermissions("token", []string{"a", "b"}) + if !client.HasPermission("a") { + t.Errorf("client has permissions %s, therefore HasPermission(a) should've been true", client.Permissions) + } + if !client.HasPermission("b") { + t.Errorf("client has permissions %s, therefore HasPermission(b) should've been true", client.Permissions) + } + if client.HasPermission("c") { + t.Errorf("client has permissions %s, therefore HasPermission(c) should've been false", client.Permissions) + } + if client.HasPermission("ab") { + t.Errorf("client has permissions %s, therefore HasPermission(ab) should've been false", client.Permissions) + } +} + +func TestClient_HasPermissions(t *testing.T) { + client := NewClientWithPermissions("token", []string{"a", "b"}) + if !client.HasPermissions(nil) { + t.Errorf("client has permissions %s, therefore HasPermissions(nil) should've been true", client.Permissions) + } + if !client.HasPermissions([]string{"a"}) { + t.Errorf("client has permissions %s, therefore HasPermissions([a]) should've been true", client.Permissions) + } + if !client.HasPermissions([]string{"b"}) { + t.Errorf("client has permissions %s, therefore HasPermissions([b]) should've been true", client.Permissions) + } + if !client.HasPermissions([]string{"a", "b"}) { + t.Errorf("client has permissions %s, therefore HasPermissions([a, b]) should've been true", client.Permissions) + } + if client.HasPermissions([]string{"a", "b", "c"}) { + t.Errorf("client has permissions %s, therefore HasPermissions([a, b, c]) should've been false", client.Permissions) + } +} + +func TestClient_WithData(t *testing.T) { + client := NewClient("token") + if client.Data != nil { + t.Error("expected client data to be nil") + } + client.WithData(5) + if client.Data != 5 { + t.Errorf("expected client data to be 5, got %d", client.Data) + } + client.WithData(map[string]string{"key": "value"}) + if data, ok := client.Data.(map[string]string); !ok || data["key"] != "value" { + t.Errorf("expected client data to be map[string]string{key: value}, got %v", client.Data) + } +} + +func TestNewClientWithData(t *testing.T) { + client := NewClientWithData("token", 5) + if client.Data != 5 { + t.Errorf("expected client data to be 5, got %d", client.Data) + } +} + +func TestNewClientWithPermissionsAndData(t *testing.T) { + client := NewClientWithPermissionsAndData("token", []string{"a", "b"}, 5) + if client.Data != 5 { + t.Errorf("expected client data to be 5, got %d", client.Data) + } + if !client.HasPermission("a") { + t.Errorf("client has permissions %s, therefore HasPermission(a) should've been true", client.Permissions) + } + if !client.HasPermission("b") { + t.Errorf("client has permissions %s, therefore HasPermission(b) should've been true", client.Permissions) + } + if client.HasPermission("c") { + t.Errorf("client has permissions %s, therefore HasPermission(c) should've been false", client.Permissions) + } +} diff --git a/clientprovider.go b/clientprovider.go new file mode 100644 index 0000000..5017684 --- /dev/null +++ b/clientprovider.go @@ -0,0 +1,102 @@ +package g8 + +import ( + "time" + + "github.com/TwiN/gocache/v2" +) + +// ClientProvider has the task of retrieving a Client from an external source (e.g. a database) when provided with a +// token. It should be used when you have a lot of tokens, and it wouldn't make sense to register all of them using +// AuthorizationService's WithToken, WithTokens, WithClient or WithClients. +// +// Note that the provider is used as a fallback source. As such, if a token is explicitly registered using one of the 4 +// aforementioned functions, the client provider will not be used by the AuthorizationService when a request is made +// with said token. It will, however, be called upon if a token that is not explicitly registered in +// AuthorizationService is sent alongside a request going through the Gate. +// +// clientProvider := g8.NewClientProvider(func(token string) *g8.Client { +// // We'll assume that the following function calls your database and returns a struct "User" that +// // has the user's token as well as the permissions granted to said user +// user := database.GetUserByToken(token) +// if user != nil { +// return g8.NewClient(user.Token).WithPermissions(user.Permissions) +// } +// return nil +// }) +// gate := g8.New().WithAuthorizationService(g8.NewAuthorizationService().WithClientProvider(clientProvider)) +type ClientProvider struct { + getClientByTokenFunc func(token string) *Client + + cache Cache +} + +// NewClientProvider creates a ClientProvider +// The parameter that must be passed is a function that the provider will use to retrieve a client by a given token +// +// Example: +// +// clientProvider := g8.NewClientProvider(func(token string) *g8.Client { +// // We'll assume that the following function calls your database and returns a struct "User" that +// // has the user's token as well as the permissions granted to said user +// user := database.GetUserByToken(token) +// if user == nil { +// return nil +// } +// return g8.NewClient(user.Token).WithPermissions(user.Permissions) +// }) +// gate := g8.New().WithAuthorizationService(g8.NewAuthorizationService().WithClientProvider(clientProvider)) +func NewClientProvider(getClientByTokenFunc func(token string) *Client) *ClientProvider { + return &ClientProvider{ + getClientByTokenFunc: getClientByTokenFunc, + } +} + +// WithCache enables an in-memory cache for the ClientProvider. +// +// Example: +// +// clientProvider := g8.NewClientProvider(func(token string) *g8.Client { +// // We'll assume that the following function calls your database and returns a struct "User" that +// // has the user's token as well as the permissions granted to said user +// user := database.GetUserByToken(token) +// if user != nil { +// return g8.NewClient(user.Token).WithPermissions(user.Permissions) +// } +// return nil +// }) +// gate := g8.New().WithAuthorizationService(g8.NewAuthorizationService().WithClientProvider(clientProvider.WithCache(time.Hour, 70000))) +func (provider *ClientProvider) WithCache(ttl time.Duration, maxSize int) *ClientProvider { + return provider.WithCustomCache( + gocache.NewCache().WithEvictionPolicy(gocache.LeastRecentlyUsed).WithMaxSize(maxSize).WithDefaultTTL(ttl), + ) +} + +// WithCustomCache allows you to use a custom cache implementation instead of the default one. +// By default, using WithCache will leverage gocache. +// +// Note that the custom cache must implement the Cache interface +func (provider *ClientProvider) WithCustomCache(cache Cache) *ClientProvider { + provider.cache = cache + return provider +} + +// GetClientByToken retrieves a client by its token through the provided getClientByTokenFunc. +func (provider *ClientProvider) GetClientByToken(token string) *Client { + if provider.cache == nil { + return provider.getClientByTokenFunc(token) + } + if cachedClient, exists := provider.cache.Get(token); exists { + if cachedClient == nil { + return nil + } + // Safely typecast the client. + // Regardless of whether the typecast is successful or not, we return client since it'll be either client or + // nil. Technically, it should never be nil, but it's better to be safe than sorry. + client, _ := cachedClient.(*Client) + return client + } + client := provider.getClientByTokenFunc(token) + provider.cache.Set(token, client) + return client +} diff --git a/clientprovider_test.go b/clientprovider_test.go new file mode 100644 index 0000000..0c1f44e --- /dev/null +++ b/clientprovider_test.go @@ -0,0 +1,131 @@ +package g8 + +import ( + "sync" + "testing" + "time" + + "github.com/TwiN/gocache/v2" +) + +var ( + getClientByTokenFunc = func(token string) *Client { + if token == "valid-token" { + return NewClient("valid-token").WithData("client-data") + } + return nil + } +) + +func TestClientProvider_GetClientByToken(t *testing.T) { + provider := NewClientProvider(getClientByTokenFunc) + if client := provider.GetClientByToken("valid-token"); client == nil { + t.Error("should've returned a client") + } else if client.Data != "client-data" { + t.Error("expected client data to be 'client-data', got", client.Data) + } + if client := provider.GetClientByToken("invalid-token"); client != nil { + t.Error("should've returned nil") + } +} + +func TestClientProvider_WithCache(t *testing.T) { + provider := NewClientProvider(getClientByTokenFunc).WithCache(gocache.NoExpiration, 10000) + if provider.cache.(*gocache.Cache).Count() != 0 { + t.Error("expected cache to be empty") + } + if client := provider.GetClientByToken("valid-token"); client == nil { + t.Error("expected client, got nil") + } + if provider.cache.(*gocache.Cache).Count() != 1 { + t.Error("expected cache size to be 1") + } + if client := provider.GetClientByToken("valid-token"); client == nil { + t.Error("expected client, got nil") + } + if provider.cache.(*gocache.Cache).Count() != 1 { + t.Error("expected cache size to be 1") + } + if client := provider.GetClientByToken("invalid-token"); client != nil { + t.Error("expected nil, got", client) + } + if provider.cache.(*gocache.Cache).Count() != 2 { + t.Error("expected cache size to be 2") + } + if client := provider.GetClientByToken("invalid-token"); client != nil { + t.Error("expected nil, got", client) + } + if client := provider.GetClientByToken("invalid-token"); client != nil { + t.Error("should've returned nil (cached)") + } +} + +func TestClientProvider_WithCacheAndExpiration(t *testing.T) { + provider := NewClientProvider(getClientByTokenFunc).WithCache(10*time.Millisecond, 10) + provider.GetClientByToken("token") + if provider.cache.(*gocache.Cache).Count() != 1 { + t.Error("expected cache size to be 1") + } + if provider.cache.(*gocache.Cache).Stats().ExpiredKeys != 0 { + t.Error("expected cache statistics to report 0 expired key") + } + time.Sleep(15 * time.Millisecond) + provider.GetClientByToken("token") + if provider.cache.(*gocache.Cache).Stats().ExpiredKeys != 1 { + t.Error("expected cache statistics to report 1 expired key") + } +} + +type customCache struct { + entries map[string]any + sync.Mutex +} + +func (c *customCache) Get(key string) (value any, exists bool) { + c.Lock() + v, exists := c.entries[key] + c.Unlock() + return v, exists +} + +func (c *customCache) Set(key string, value any) { + c.Lock() + if c.entries == nil { + c.entries = make(map[string]any) + } + c.entries[key] = value + c.Unlock() +} + +var _ Cache = (*customCache)(nil) + +func TestClientProvider_WithCustomCache(t *testing.T) { + provider := NewClientProvider(getClientByTokenFunc).WithCustomCache(&customCache{}) + if len(provider.cache.(*customCache).entries) != 0 { + t.Error("expected cache to be empty") + } + if client := provider.GetClientByToken("valid-token"); client == nil { + t.Error("expected client, got nil") + } + if len(provider.cache.(*customCache).entries) != 1 { + t.Error("expected cache size to be 1") + } + if client := provider.GetClientByToken("valid-token"); client == nil { + t.Error("expected client, got nil") + } + if len(provider.cache.(*customCache).entries) != 1 { + t.Error("expected cache size to be 1") + } + if client := provider.GetClientByToken("invalid-token"); client != nil { + t.Error("expected nil, got", client) + } + if len(provider.cache.(*customCache).entries) != 2 { + t.Error("expected cache size to be 2") + } + if client := provider.GetClientByToken("invalid-token"); client != nil { + t.Error("expected nil, got", client) + } + if client := provider.GetClientByToken("invalid-token"); client != nil { + t.Error("should've returned nil (cached)") + } +} diff --git a/gate.go b/gate.go new file mode 100644 index 0000000..d7db789 --- /dev/null +++ b/gate.go @@ -0,0 +1,245 @@ +package g8 + +import ( + "context" + "net/http" + "strings" +) + +const ( + // AuthorizationHeader is the header in which g8 looks for the authorization bearer token + AuthorizationHeader = "Authorization" + + // DefaultUnauthorizedResponseBody is the default response body returned if a request was sent with a missing or invalid token + DefaultUnauthorizedResponseBody = "token is missing or invalid" + + // DefaultTooManyRequestsResponseBody is the default response body returned if a request exceeded the allowed rate limit + DefaultTooManyRequestsResponseBody = "too many requests" + + // TokenContextKey is the key used to store the client's token in the context. + TokenContextKey = "g8.token" + + // DataContextKey is the key used to store the client's data in the context. + DataContextKey = "g8.data" +) + +// Gate is lock to the front door of your API, letting only those you allow through. +type Gate struct { + authorizationService *AuthorizationService + unauthorizedResponseBody []byte + + customTokenExtractorFunc func(request *http.Request) string + + rateLimiter *RateLimiter + tooManyRequestsResponseBody []byte +} + +// Deprecated: use New instead. +func NewGate(authorizationService *AuthorizationService) *Gate { + return &Gate{ + authorizationService: authorizationService, + unauthorizedResponseBody: []byte(DefaultUnauthorizedResponseBody), + tooManyRequestsResponseBody: []byte(DefaultTooManyRequestsResponseBody), + } +} + +// New creates a new Gate. +func New() *Gate { + return &Gate{ + unauthorizedResponseBody: []byte(DefaultUnauthorizedResponseBody), + tooManyRequestsResponseBody: []byte(DefaultTooManyRequestsResponseBody), + } +} + +// WithAuthorizationService sets the authorization service to use. +// +// If there is no authorization service, Gate will not enforce authorization. +func (gate *Gate) WithAuthorizationService(authorizationService *AuthorizationService) *Gate { + gate.authorizationService = authorizationService + return gate +} + +// WithCustomUnauthorizedResponseBody sets a custom response body when Gate determines that a request must be blocked +func (gate *Gate) WithCustomUnauthorizedResponseBody(unauthorizedResponseBody []byte) *Gate { + gate.unauthorizedResponseBody = unauthorizedResponseBody + return gate +} + +// WithCustomTokenExtractor allows the specification of a custom function to extract a token from a request. +// If a custom token extractor is not specified, the token will be extracted from the Authorization header. +// +// For instance, if you're using a session cookie, you can extract the token from the cookie like so: +// +// authorizationService := g8.NewAuthorizationService() +// customTokenExtractorFunc := func(request *http.Request) string { +// sessionCookie, err := request.Cookie("session") +// if err != nil { +// return "" +// } +// return sessionCookie.Value +// } +// gate := g8.New().WithAuthorizationService(authorizationService).WithCustomTokenExtractor(customTokenExtractorFunc) +// +// You would normally use this with a client provider that matches whatever need you have. +// For example, if you're using a session cookie, your client provider would retrieve the user from the session ID +// extracted by this custom token extractor. +// +// Note that for the sake of convenience, the token extracted from the request is passed the protected handlers request +// context under the key TokenContextKey. This is especially useful if the token is in fact a session ID. +func (gate *Gate) WithCustomTokenExtractor(customTokenExtractorFunc func(request *http.Request) string) *Gate { + gate.customTokenExtractorFunc = customTokenExtractorFunc + return gate +} + +// WithRateLimit adds rate limiting to the Gate +// +// If you just want to use a gate for rate limiting purposes: +// +// gate := g8.New().WithRateLimit(50) +func (gate *Gate) WithRateLimit(maximumRequestsPerSecond int) *Gate { + gate.rateLimiter = NewRateLimiter(maximumRequestsPerSecond) + return gate +} + +// Protect secures a handler, requiring requests going through to have a valid Authorization Bearer token. +// Unlike ProtectWithPermissions, Protect will allow access to any registered tokens, regardless of their permissions +// or lack thereof. +// +// Example: +// +// gate := g8.New().WithAuthorizationService(g8.NewAuthorizationService().WithToken("token")) +// router := http.NewServeMux() +// // Without protection +// router.Handle("/handle", yourHandler) +// // With protection +// router.Handle("/handle", gate.Protect(yourHandler)) +// +// The token extracted from the request is passed to the handlerFunc request context under the key TokenContextKey +func (gate *Gate) Protect(handler http.Handler) http.Handler { + return gate.ProtectWithPermissions(handler, nil) +} + +// ProtectWithPermissions secures a handler, requiring requests going through to have a valid Authorization Bearer token +// as well as a slice of permissions that must be met. +// +// Example: +// +// gate := g8.New().WithAuthorizationService(g8.NewAuthorizationService().WithClient(g8.NewClient("token").WithPermission("ADMIN"))) +// router := http.NewServeMux() +// // Without protection +// router.Handle("/handle", yourHandler) +// // With protection +// router.Handle("/handle", gate.ProtectWithPermissions(yourHandler, []string{"admin"})) +// +// The token extracted from the request is passed to the handlerFunc request context under the key TokenContextKey +func (gate *Gate) ProtectWithPermissions(handler http.Handler, permissions []string) http.Handler { + return gate.ProtectFuncWithPermissions(func(writer http.ResponseWriter, request *http.Request) { + handler.ServeHTTP(writer, request) + }, permissions) +} + +// ProtectWithPermission does the same thing as ProtectWithPermissions, but for a single permission instead of a +// slice of permissions +// +// See ProtectWithPermissions for further documentation +func (gate *Gate) ProtectWithPermission(handler http.Handler, permission string) http.Handler { + return gate.ProtectFuncWithPermissions(func(writer http.ResponseWriter, request *http.Request) { + handler.ServeHTTP(writer, request) + }, []string{permission}) +} + +// ProtectFunc secures a handlerFunc, requiring requests going through to have a valid Authorization Bearer token. +// Unlike ProtectFuncWithPermissions, ProtectFunc will allow access to any registered tokens, regardless of their +// permissions or lack thereof. +// +// Example: +// +// gate := g8.New().WithAuthorizationService(g8.NewAuthorizationService().WithToken("token")) +// router := http.NewServeMux() +// // Without protection +// router.HandleFunc("/handle", yourHandlerFunc) +// // With protection +// router.HandleFunc("/handle", gate.ProtectFunc(yourHandlerFunc)) +// +// The token extracted from the request is passed to the handlerFunc request context under the key TokenContextKey +func (gate *Gate) ProtectFunc(handlerFunc http.HandlerFunc) http.HandlerFunc { + return gate.ProtectFuncWithPermissions(handlerFunc, nil) +} + +// ProtectFuncWithPermissions secures a handler, requiring requests going through to have a valid Authorization Bearer +// token as well as a slice of permissions that must be met. +// +// Example: +// +// gate := g8.New().WithAuthorizationService(g8.NewAuthorizationService().WithClient(g8.NewClient("token").WithPermission("admin"))) +// router := http.NewServeMux() +// // Without protection +// router.HandleFunc("/handle", yourHandlerFunc) +// // With protection +// router.HandleFunc("/handle", gate.ProtectFuncWithPermissions(yourHandlerFunc, []string{"admin"})) +// +// The token extracted from the request is passed to the handlerFunc request context under the key TokenContextKey +func (gate *Gate) ProtectFuncWithPermissions(handlerFunc http.HandlerFunc, permissions []string) http.HandlerFunc { + return func(writer http.ResponseWriter, request *http.Request) { + if gate.rateLimiter != nil { + if !gate.rateLimiter.Try() { + writer.WriteHeader(http.StatusTooManyRequests) + _, _ = writer.Write(gate.tooManyRequestsResponseBody) + return + } + } + if gate.authorizationService != nil { + token := gate.ExtractTokenFromRequest(request) + if client, authorized := gate.authorizationService.Authorize(token, permissions); !authorized { + writer.WriteHeader(http.StatusUnauthorized) + _, _ = writer.Write(gate.unauthorizedResponseBody) + return + } else { + request = request.WithContext(context.WithValue(request.Context(), TokenContextKey, token)) + if client != nil && client.Data != nil { + request = request.WithContext(context.WithValue(request.Context(), DataContextKey, client.Data)) + } + } + } + handlerFunc(writer, request) + } +} + +// ProtectFuncWithPermission does the same thing as ProtectFuncWithPermissions, but for a single permission instead of a +// slice of permissions +// +// See ProtectFuncWithPermissions for further documentation +func (gate *Gate) ProtectFuncWithPermission(handlerFunc http.HandlerFunc, permission string) http.HandlerFunc { + return gate.ProtectFuncWithPermissions(handlerFunc, []string{permission}) +} + +// ExtractTokenFromRequest extracts a token from a request. +// +// By default, it extracts the bearer token from the AuthorizationHeader, but if a customTokenExtractorFunc is defined, +// it will use that instead. +// +// Note that this method is internally used by Protect, ProtectWithPermission, ProtectFunc and +// ProtectFuncWithPermissions, but it is exposed in case you need to use it directly. +func (gate *Gate) ExtractTokenFromRequest(request *http.Request) string { + if gate.customTokenExtractorFunc != nil { + // A custom token extractor function is defined, so we'll use it instead of the default token extraction logic + return gate.customTokenExtractorFunc(request) + } + return strings.TrimPrefix(request.Header.Get(AuthorizationHeader), "Bearer ") +} + +// PermissionMiddleware is a middleware that behaves like ProtectWithPermission, but it is meant to be used +// as a middleware for libraries that support such a feature. +// +// For instance, if you are using github.com/gorilla/mux, you can use PermissionMiddleware like so: +// +// router := mux.NewRouter() +// router.Use(gate.PermissionMiddleware("admin")) +// router.Handle("/admin/handle", adminHandler) +// +// If you do not want to protect a router with a specific permission, you can use Gate.Protect instead. +func (gate *Gate) PermissionMiddleware(permissions ...string) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return gate.ProtectWithPermissions(next, permissions) + } +} diff --git a/gate_bench_test.go b/gate_bench_test.go new file mode 100644 index 0000000..0be4c88 --- /dev/null +++ b/gate_bench_test.go @@ -0,0 +1,208 @@ +package g8 + +import ( + "fmt" + "net/http" + "net/http/httptest" + "testing" +) + +var handler http.Handler = &testHandler{} + +func BenchmarkTestHandler(b *testing.B) { + request, _ := http.NewRequest("GET", "/handle", nil) + + router := http.NewServeMux() + router.Handle("/handle", handler) + + for n := 0; n < b.N; n++ { + responseRecorder := httptest.NewRecorder() + router.ServeHTTP(responseRecorder, request) + if responseRecorder.Code != http.StatusOK { + b.Fatalf("%s %s should have returned %d, but returned %d instead", request.Method, request.URL, http.StatusOK, responseRecorder.Code) + } + } + b.ReportAllocs() +} + +func BenchmarkGate_ProtectWhenNoAuthorizationHeader(b *testing.B) { + gate := New().WithAuthorizationService(NewAuthorizationService().WithToken("good-token")) + request, _ := http.NewRequest("GET", "/handle", nil) + + router := http.NewServeMux() + router.Handle("/handle", gate.Protect(handler)) + + for n := 0; n < b.N; n++ { + responseRecorder := httptest.NewRecorder() + router.ServeHTTP(responseRecorder, request) + if responseRecorder.Code != http.StatusUnauthorized { + b.Fatalf("%s %s should have returned %d, but returned %d instead", request.Method, request.URL, http.StatusUnauthorized, responseRecorder.Code) + } + } + b.ReportAllocs() +} + +func BenchmarkGate_ProtectWithInvalidToken(b *testing.B) { + gate := New().WithAuthorizationService(NewAuthorizationService().WithToken("good-token")) + request, _ := http.NewRequest("GET", "/handle", nil) + request.Header.Add("Authorization", fmt.Sprintf("Bearer %s", "bad-token")) + + router := http.NewServeMux() + router.Handle("/handle", gate.Protect(handler)) + + for n := 0; n < b.N; n++ { + responseRecorder := httptest.NewRecorder() + router.ServeHTTP(responseRecorder, request) + if responseRecorder.Code != http.StatusUnauthorized { + b.Fatalf("%s %s should have returned %d, but returned %d instead", request.Method, request.URL, http.StatusUnauthorized, responseRecorder.Code) + } + } + b.ReportAllocs() +} + +func BenchmarkGate_ProtectWithValidToken(b *testing.B) { + gate := New().WithAuthorizationService(NewAuthorizationService().WithToken("good-token")) + request, _ := http.NewRequest("GET", "/handle", http.NoBody) + request.Header.Add("Authorization", fmt.Sprintf("Bearer %s", "good-token")) + + router := http.NewServeMux() + router.Handle("/handle", gate.Protect(handler)) + + for n := 0; n < b.N; n++ { + responseRecorder := httptest.NewRecorder() + router.ServeHTTP(responseRecorder, request) + if responseRecorder.Code != http.StatusOK { + b.Fatalf("%s %s should have returned %d, but returned %d instead", request.Method, request.URL, http.StatusOK, responseRecorder.Code) + } + } + b.ReportAllocs() +} + +func BenchmarkGate_ProtectWithPermissionsAndValidToken(b *testing.B) { + gate := New().WithAuthorizationService(NewAuthorizationService().WithClient(NewClient("token").WithPermission("admin"))) + request, _ := http.NewRequest("GET", "/handle", http.NoBody) + request.Header.Add("Authorization", fmt.Sprintf("Bearer %s", "token")) + + router := http.NewServeMux() + router.Handle("/handle", gate.ProtectWithPermissions(handler, []string{"admin"})) + + for n := 0; n < b.N; n++ { + responseRecorder := httptest.NewRecorder() + router.ServeHTTP(responseRecorder, request) + if responseRecorder.Code != http.StatusOK { + b.Fatalf("%s %s should have returned %d, but returned %d instead", request.Method, request.URL, http.StatusOK, responseRecorder.Code) + } + } + b.ReportAllocs() +} + +func BenchmarkGate_ProtectWithPermissionsAndValidTokenButInsufficientPermissions(b *testing.B) { + gate := New().WithAuthorizationService(NewAuthorizationService().WithClient(NewClient("token").WithPermission("mod"))) + request, _ := http.NewRequest("GET", "/handle", http.NoBody) + request.Header.Add("Authorization", fmt.Sprintf("Bearer %s", "token")) + + router := http.NewServeMux() + router.Handle("/handle", gate.ProtectWithPermissions(handler, []string{"admin"})) + + for n := 0; n < b.N; n++ { + responseRecorder := httptest.NewRecorder() + router.ServeHTTP(responseRecorder, request) + if responseRecorder.Code != http.StatusUnauthorized { + b.Fatalf("%s %s should have returned %d, but returned %d instead", request.Method, request.URL, http.StatusUnauthorized, responseRecorder.Code) + } + } + b.ReportAllocs() +} + +func BenchmarkGate_ProtectConcurrently(b *testing.B) { + gate := New().WithAuthorizationService(NewAuthorizationService().WithToken("good-token")) + + request, _ := http.NewRequest("GET", "/handle", http.NoBody) + request.Header.Add("Authorization", fmt.Sprintf("Bearer %s", "good-token")) + + badRequest, _ := http.NewRequest("GET", "/handle", http.NoBody) + badRequest.Header.Add("Authorization", fmt.Sprintf("Bearer %s", "bad-token")) + + router := http.NewServeMux() + router.Handle("/handle", gate.Protect(handler)) + + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + responseRecorder := httptest.NewRecorder() + router.ServeHTTP(responseRecorder, request) + if responseRecorder.Code != http.StatusOK { + b.Fatalf("%s %s should have returned %d, but returned %d instead", request.Method, request.URL, http.StatusOK, responseRecorder.Code) + } + responseRecorder = httptest.NewRecorder() + router.ServeHTTP(responseRecorder, badRequest) + if responseRecorder.Code != http.StatusUnauthorized { + b.Fatalf("%s %s should have returned %d, but returned %d instead", badRequest.Method, badRequest.URL, http.StatusUnauthorized, responseRecorder.Code) + } + } + }) + b.ReportAllocs() +} + +func BenchmarkGate_ProtectWithClientProviderConcurrently(b *testing.B) { + gate := New().WithAuthorizationService(NewAuthorizationService().WithClientProvider(mockClientProvider)) + + request, _ := http.NewRequest("GET", "/handle", http.NoBody) + request.Header.Add("Authorization", fmt.Sprintf("Bearer %s", TestProviderClientToken)) + + firstBadRequest, _ := http.NewRequest("GET", "/handle", http.NoBody) + firstBadRequest.Header.Set("Authorization", fmt.Sprintf("Bearer %s", "bad-token-1")) + + secondBadRequest, _ := http.NewRequest("GET", "/handle", http.NoBody) + secondBadRequest.Header.Set("Authorization", fmt.Sprintf("Bearer %s", "bad-token-2")) + + router := http.NewServeMux() + router.Handle("/handle", gate.Protect(handler)) + + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + responseRecorder := httptest.NewRecorder() + router.ServeHTTP(responseRecorder, request) + if responseRecorder.Code != http.StatusOK { + b.Fatalf("%s %s should have returned %d, but returned %d instead", request.Method, request.URL, http.StatusOK, responseRecorder.Code) + } + responseRecorder = httptest.NewRecorder() + router.ServeHTTP(responseRecorder, firstBadRequest) + if responseRecorder.Code != http.StatusUnauthorized { + b.Fatalf("%s %s should have returned %d, but returned %d instead", firstBadRequest.Method, firstBadRequest.URL, http.StatusUnauthorized, responseRecorder.Code) + } + responseRecorder = httptest.NewRecorder() + router.ServeHTTP(responseRecorder, secondBadRequest) + if responseRecorder.Code != http.StatusUnauthorized { + b.Fatalf("%s %s should have returned %d, but returned %d instead", secondBadRequest.Method, secondBadRequest.URL, http.StatusUnauthorized, responseRecorder.Code) + } + } + }) + b.ReportAllocs() +} + +func BenchmarkGate_ProtectWithValidTokenAndCustomTokenExtractorFuncConcurrently(b *testing.B) { + customTokenExtractorFunc := func(request *http.Request) string { + sessionCookie, err := request.Cookie("session") + if err != nil { + return "" + } + return sessionCookie.Value + } + gate := New().WithAuthorizationService(NewAuthorizationService().WithToken("good-token")).WithCustomTokenExtractor(customTokenExtractorFunc) + request, _ := http.NewRequest("GET", "/handle", http.NoBody) + request.AddCookie(&http.Cookie{Name: "session", Value: "good-token"}) + + router := http.NewServeMux() + router.Handle("/handle", gate.Protect(handler)) + + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + responseRecorder := httptest.NewRecorder() + router.ServeHTTP(responseRecorder, request) + if responseRecorder.Code != http.StatusOK { + b.Fatalf("%s %s should have returned %d, but returned %d instead", request.Method, request.URL, http.StatusOK, responseRecorder.Code) + } + } + }) + b.ReportAllocs() +} diff --git a/gate_test.go b/gate_test.go new file mode 100644 index 0000000..4c3f3ce --- /dev/null +++ b/gate_test.go @@ -0,0 +1,538 @@ +package g8 + +import ( + "fmt" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" +) + +const ( + FirstTestProviderClientPermission = "permission-1" + SecondTestProviderClientPermission = "permission-2" + TestProviderClientToken = "client-token-from-provider" + TestProviderClientData = "client-data-from-provider" +) + +var ( + mockClientProvider = NewClientProvider(func(token string) *Client { + // We'll pretend that there's only one token that's valid in the client provider, every other token + // returns nil + if token == TestProviderClientToken { + return &Client{ + Token: TestProviderClientToken, + Data: TestProviderClientData, + Permissions: []string{FirstTestProviderClientPermission, SecondTestProviderClientPermission}, + } + } + return nil + }) +) + +type testHandler struct { +} + +func (handler *testHandler) ServeHTTP(writer http.ResponseWriter, _ *http.Request) { + writer.WriteHeader(http.StatusOK) +} + +func testHandlerFunc(writer http.ResponseWriter, _ *http.Request) { + writer.WriteHeader(http.StatusOK) +} + +func TestUsability(t *testing.T) { + gate := New().WithAuthorizationService(NewAuthorizationService().WithToken("good-token")) + + var handler http.Handler = &testHandler{} + handlerFunc := func(writer http.ResponseWriter, request *http.Request) { + writer.WriteHeader(http.StatusOK) + } + + router := http.NewServeMux() + router.Handle("/handle", handler) + router.Handle("/handle-protected", gate.Protect(handler)) + router.HandleFunc("/handlefunc", handlerFunc) + router.HandleFunc("/handlefunc-protected", gate.ProtectFunc(handlerFunc)) +} + +func TestNewGate(t *testing.T) { + gate := NewGate(nil) + if gate == nil { + t.Error("gate should not be nil") + } +} + +func TestUnprotectedHandler(t *testing.T) { + request, _ := http.NewRequest("GET", "/handle", http.NoBody) + request.Header.Add("Authorization", fmt.Sprintf("Bearer %s", "bad-token")) + responseRecorder := httptest.NewRecorder() + + router := http.NewServeMux() + router.Handle("/handle", &testHandler{}) + router.ServeHTTP(responseRecorder, request) + + if responseRecorder.Code != http.StatusOK { + t.Errorf("%s %s should have returned %d, but returned %d instead", request.Method, request.URL, http.StatusOK, responseRecorder.Code) + } +} + +func TestGate_ProtectWithInvalidToken(t *testing.T) { + gate := New().WithAuthorizationService(NewAuthorizationService().WithToken("good-token")) + request, _ := http.NewRequest("GET", "/handle", http.NoBody) + request.Header.Add("Authorization", fmt.Sprintf("Bearer %s", "bad-token")) + responseRecorder := httptest.NewRecorder() + + router := http.NewServeMux() + router.Handle("/handle", gate.Protect(&testHandler{})) + router.ServeHTTP(responseRecorder, request) + + if responseRecorder.Code != http.StatusUnauthorized { + t.Errorf("%s %s should have returned %d, but returned %d instead", request.Method, request.URL, http.StatusUnauthorized, responseRecorder.Code) + } +} + +func TestGate_ProtectWithValidToken(t *testing.T) { + gate := New().WithAuthorizationService(NewAuthorizationService().WithToken("good-token")) + request, _ := http.NewRequest("GET", "/handle", http.NoBody) + request.Header.Add("Authorization", fmt.Sprintf("Bearer %s", "good-token")) + responseRecorder := httptest.NewRecorder() + + router := http.NewServeMux() + router.Handle("/handle", gate.Protect(&testHandler{})) + router.ServeHTTP(responseRecorder, request) + + if responseRecorder.Code != http.StatusOK { + t.Errorf("%s %s should have returned %d, but returned %d instead", request.Method, request.URL, http.StatusOK, responseRecorder.Code) + } +} + +func TestGate_ProtectMultipleTimes(t *testing.T) { + gate := New().WithAuthorizationService(NewAuthorizationService().WithToken("good-token")) + request, _ := http.NewRequest("GET", "/handle", http.NoBody) + request.Header.Add("Authorization", fmt.Sprintf("Bearer %s", "good-token")) + badRequest, _ := http.NewRequest("GET", "/handle", http.NoBody) + badRequest.Header.Add("Authorization", fmt.Sprintf("Bearer %s", "bad-token")) + + router := http.NewServeMux() + router.Handle("/handle", gate.Protect(&testHandler{})) + + for i := 0; i < 100; i++ { + responseRecorder := httptest.NewRecorder() + router.ServeHTTP(responseRecorder, request) + if responseRecorder.Code != http.StatusOK { + t.Errorf("%s %s should have returned %d, but returned %d instead", request.Method, request.URL, http.StatusOK, responseRecorder.Code) + } + responseRecorder = httptest.NewRecorder() + router.ServeHTTP(responseRecorder, badRequest) + if responseRecorder.Code != http.StatusUnauthorized { + t.Errorf("%s %s should have returned %d, but returned %d instead", badRequest.Method, badRequest.URL, http.StatusOK, responseRecorder.Code) + } + } +} + +func TestGate_ProtectWithValidTokenExposedThroughClientProvider(t *testing.T) { + gate := New().WithAuthorizationService(NewAuthorizationService().WithClientProvider(mockClientProvider)) + request, _ := http.NewRequest("GET", "/handle", http.NoBody) + request.Header.Add("Authorization", fmt.Sprintf("Bearer %s", TestProviderClientToken)) + responseRecorder := httptest.NewRecorder() + + router := http.NewServeMux() + router.Handle("/handle", gate.Protect(&testHandler{})) + router.ServeHTTP(responseRecorder, request) + + if responseRecorder.Code != http.StatusOK { + t.Errorf("%s %s should have returned %d, but returned %d instead", request.Method, request.URL, http.StatusOK, responseRecorder.Code) + } +} + +func TestGate_ProtectWithValidTokenExposedThroughClientProviderWithCache(t *testing.T) { + gate := New().WithAuthorizationService(NewAuthorizationService().WithClientProvider(mockClientProvider.WithCache(60*time.Minute, 70000))) + request, _ := http.NewRequest("GET", "/handle", http.NoBody) + request.Header.Add("Authorization", fmt.Sprintf("Bearer %s", TestProviderClientToken)) + responseRecorder := httptest.NewRecorder() + + router := http.NewServeMux() + router.Handle("/handle", gate.Protect(&testHandler{})) + router.ServeHTTP(responseRecorder, request) + + if responseRecorder.Code != http.StatusOK { + t.Errorf("%s %s should have returned %d, but returned %d instead", request.Method, request.URL, http.StatusOK, responseRecorder.Code) + } +} + +func TestGate_ProtectWithInvalidTokenWhenUsingClientProvider(t *testing.T) { + gate := New().WithAuthorizationService(NewAuthorizationService().WithClientProvider(mockClientProvider)) + request, _ := http.NewRequest("GET", "/handle", http.NoBody) + request.Header.Add("Authorization", fmt.Sprintf("Bearer %s", "bad-token")) + responseRecorder := httptest.NewRecorder() + + router := http.NewServeMux() + router.Handle("/handle", gate.Protect(&testHandler{})) + router.ServeHTTP(responseRecorder, request) + + if responseRecorder.Code != http.StatusUnauthorized { + t.Errorf("%s %s should have returned %d, but returned %d instead", request.Method, request.URL, http.StatusUnauthorized, responseRecorder.Code) + } +} + +func TestGate_ProtectWithPermissionsWhenValidTokenAndSufficientPermissionsWhileUsingClientProvider(t *testing.T) { + gate := New().WithAuthorizationService(NewAuthorizationService().WithClientProvider(mockClientProvider)) + request, _ := http.NewRequest("GET", "/handle", http.NoBody) + request.Header.Add("Authorization", fmt.Sprintf("Bearer %s", TestProviderClientToken)) + responseRecorder := httptest.NewRecorder() + + router := http.NewServeMux() + router.Handle("/handle", gate.ProtectWithPermissions(&testHandler{}, []string{SecondTestProviderClientPermission})) + router.ServeHTTP(responseRecorder, request) + + // Since the client returned from the mockClientProvider has FirstTestProviderClientPermission and + // SecondTestProviderClientPermission and the testHandler is protected by SecondTestProviderClientPermission, + // the request should be authorized + if responseRecorder.Code != http.StatusOK { + t.Errorf("%s %s should have returned %d, but returned %d instead", request.Method, request.URL, http.StatusOK, responseRecorder.Code) + } +} + +func TestGate_ProtectWithPermissionsWhenValidTokenAndInsufficientPermissionsWhileUsingClientProvider(t *testing.T) { + gate := New().WithAuthorizationService(NewAuthorizationService().WithClientProvider(mockClientProvider)) + request, _ := http.NewRequest("GET", "/handle", http.NoBody) + request.Header.Add("Authorization", fmt.Sprintf("Bearer %s", TestProviderClientToken)) + responseRecorder := httptest.NewRecorder() + + router := http.NewServeMux() + router.Handle("/handle", gate.ProtectWithPermissions(&testHandler{}, []string{"unrelated-permission"})) + router.ServeHTTP(responseRecorder, request) + + // Since the client returned from the mockClientProvider has FirstTestProviderClientPermission and + // SecondTestProviderClientPermission and the testHandler is protected by a permission that the client does not + // have, the request should be not be authorized + if responseRecorder.Code != http.StatusUnauthorized { + t.Errorf("%s %s should have returned %d, but returned %d instead", request.Method, request.URL, http.StatusUnauthorized, responseRecorder.Code) + } +} + +func TestGate_ProtectWithPermissionsWhenClientHasSufficientPermissions(t *testing.T) { + gate := New().WithAuthorizationService(NewAuthorizationService().WithClient(NewClient("token").WithPermission("admin"))) + request, _ := http.NewRequest("GET", "/handle", http.NoBody) + request.Header.Add("Authorization", fmt.Sprintf("Bearer %s", "token")) + responseRecorder := httptest.NewRecorder() + + router := http.NewServeMux() + router.Handle("/handle", gate.ProtectWithPermissions(&testHandler{}, []string{"admin"})) + router.ServeHTTP(responseRecorder, request) + + // Since the client registered directly in the AuthorizationService has the permission "admin" and the testHandler + // is protected by the permission "admin", the request should be authorized + if responseRecorder.Code != http.StatusOK { + t.Errorf("%s %s should have returned %d, but returned %d instead", request.Method, request.URL, http.StatusOK, responseRecorder.Code) + } +} + +func TestGate_ProtectWithPermissionsWhenClientHasInsufficientPermissions(t *testing.T) { + gate := New().WithAuthorizationService(NewAuthorizationService().WithClient(NewClientWithPermissions("token", []string{"mod"}))) + request, _ := http.NewRequest("GET", "/handle", http.NoBody) + request.Header.Add("Authorization", fmt.Sprintf("Bearer %s", "token")) + responseRecorder := httptest.NewRecorder() + + router := http.NewServeMux() + router.Handle("/handle", gate.ProtectWithPermissions(&testHandler{}, []string{"admin"})) + router.ServeHTTP(responseRecorder, request) + + // Since the client registered directly in the AuthorizationService has the permission "mod" and the + // testHandler is protected by the permission "admin", the request should be not be authorized + if responseRecorder.Code != http.StatusUnauthorized { + t.Errorf("%s %s should have returned %d, but returned %d instead", request.Method, request.URL, http.StatusUnauthorized, responseRecorder.Code) + } +} + +func TestGate_ProtectWithPermissions(t *testing.T) { + gate := New().WithAuthorizationService(NewAuthorizationService().WithClient(NewClient("mytoken").WithPermissions([]string{"create", "read", "update", "delete"}))) + + router := http.NewServeMux() + router.Handle("/create", gate.ProtectWithPermissions(&testHandler{}, []string{"create"})) + router.Handle("/read", gate.ProtectWithPermissions(&testHandler{}, []string{"read"})) + router.Handle("/update", gate.ProtectWithPermissions(&testHandler{}, []string{"update"})) + router.Handle("/delete", gate.ProtectWithPermissions(&testHandler{}, []string{"delete"})) + router.Handle("/crud", gate.ProtectWithPermissions(&testHandler{}, []string{"create", "read", "update", "delete"})) + router.Handle("/backup", gate.ProtectWithPermissions(&testHandler{}, []string{"read", "backup"})) + + checkRouterOutput := func(t *testing.T, router *http.ServeMux, url string, expectedResponseCode int) { + t.Run(strings.TrimPrefix(url, "/"), func(t *testing.T) { + request, _ := http.NewRequest("GET", url, http.NoBody) + request.Header.Add("Authorization", fmt.Sprintf("Bearer %s", "mytoken")) + responseRecorder := httptest.NewRecorder() + router.ServeHTTP(responseRecorder, request) + if responseRecorder.Code != expectedResponseCode { + t.Errorf("%s %s should have returned %d, but returned %d instead", request.Method, request.URL, expectedResponseCode, responseRecorder.Code) + } + }) + } + + checkRouterOutput(t, router, "/create", http.StatusOK) + checkRouterOutput(t, router, "/read", http.StatusOK) + checkRouterOutput(t, router, "/update", http.StatusOK) + checkRouterOutput(t, router, "/delete", http.StatusOK) + checkRouterOutput(t, router, "/crud", http.StatusOK) + checkRouterOutput(t, router, "/backup", http.StatusUnauthorized) +} + +func TestGate_ProtectWithPermissionWhenClientHasSufficientPermissions(t *testing.T) { + gate := New().WithAuthorizationService(NewAuthorizationService().WithClient(NewClient("token").WithPermission("admin"))) + request, _ := http.NewRequest("GET", "/handle", http.NoBody) + request.Header.Add("Authorization", fmt.Sprintf("Bearer %s", "token")) + responseRecorder := httptest.NewRecorder() + + router := http.NewServeMux() + router.Handle("/handle", gate.ProtectWithPermission(&testHandler{}, "admin")) + router.ServeHTTP(responseRecorder, request) + + // Since the client registered directly in the AuthorizationService has the permission "admin" and the testHandler + // is protected by the permission "admin", the request should be authorized + if responseRecorder.Code != http.StatusOK { + t.Errorf("%s %s should have returned %d, but returned %d instead", request.Method, request.URL, http.StatusOK, responseRecorder.Code) + } +} + +func TestGate_ProtectWithPermissionWhenClientHasInsufficientPermissions(t *testing.T) { + gate := New().WithAuthorizationService(NewAuthorizationService().WithClient(NewClientWithPermissions("token", []string{"mod"}))) + request, _ := http.NewRequest("GET", "/handle", http.NoBody) + request.Header.Add("Authorization", fmt.Sprintf("Bearer %s", "token")) + responseRecorder := httptest.NewRecorder() + + router := http.NewServeMux() + router.Handle("/handle", gate.ProtectWithPermission(&testHandler{}, "admin")) + router.ServeHTTP(responseRecorder, request) + + // Since the client registered directly in the AuthorizationService has the permission "mod" and the + // testHandler is protected by the permission "admin", the request should be not be authorized + if responseRecorder.Code != http.StatusUnauthorized { + t.Errorf("%s %s should have returned %d, but returned %d instead", request.Method, request.URL, http.StatusUnauthorized, responseRecorder.Code) + } +} + +func TestGate_PermissionMiddlewareWhenClientHasSufficientPermissions(t *testing.T) { + gate := New().WithAuthorizationService(NewAuthorizationService().WithClient(NewClient("token").WithPermission("admin"))) + request, _ := http.NewRequest("GET", "/handle", http.NoBody) + request.Header.Add("Authorization", fmt.Sprintf("Bearer %s", "token")) + responseRecorder := httptest.NewRecorder() + + router := http.NewServeMux() + router.Handle("/handle", gate.PermissionMiddleware("admin")(&testHandler{})) + router.ServeHTTP(responseRecorder, request) + + // Since the client registered directly in the AuthorizationService has the permission "admin" and the testHandler + // is protected by the permission "admin", the request should be authorized + if responseRecorder.Code != http.StatusOK { + t.Errorf("%s %s should have returned %d, but returned %d instead", request.Method, request.URL, http.StatusOK, responseRecorder.Code) + } +} + +func TestGate_PermissionMiddlewareWhenClientHasInsufficientPermissions(t *testing.T) { + gate := New().WithAuthorizationService(NewAuthorizationService().WithClient(NewClientWithPermissions("token", []string{"mod"}))) + request, _ := http.NewRequest("GET", "/handle", http.NoBody) + request.Header.Add("Authorization", fmt.Sprintf("Bearer %s", "token")) + responseRecorder := httptest.NewRecorder() + + router := http.NewServeMux() + router.Handle("/handle", gate.PermissionMiddleware("admin")(&testHandler{})) + router.ServeHTTP(responseRecorder, request) + + // Since the client registered directly in the AuthorizationService has the permission "mod" and the + // testHandler is protected by the permission "admin", the request should be not be authorized + if responseRecorder.Code != http.StatusUnauthorized { + t.Errorf("%s %s should have returned %d, but returned %d instead", request.Method, request.URL, http.StatusUnauthorized, responseRecorder.Code) + } +} + +func TestGate_ProtectFuncWithInvalidToken(t *testing.T) { + gate := New().WithAuthorizationService(NewAuthorizationService().WithToken("good-token")) + request, _ := http.NewRequest("GET", "/handle", http.NoBody) + request.Header.Add("Authorization", fmt.Sprintf("Bearer %s", "bad-token")) + responseRecorder := httptest.NewRecorder() + + router := http.NewServeMux() + router.Handle("/handle", gate.ProtectFunc(testHandlerFunc)) + router.ServeHTTP(responseRecorder, request) + + if responseRecorder.Code != http.StatusUnauthorized { + t.Errorf("%s %s should have returned %d, but returned %d instead", request.Method, request.URL, http.StatusUnauthorized, responseRecorder.Code) + } +} + +func TestGate_ProtectFuncWithValidToken(t *testing.T) { + gate := New().WithAuthorizationService(NewAuthorizationService().WithToken("good-token")) + request, _ := http.NewRequest("GET", "/handle", http.NoBody) + request.Header.Add("Authorization", fmt.Sprintf("Bearer %s", "good-token")) + responseRecorder := httptest.NewRecorder() + + router := http.NewServeMux() + router.Handle("/handle", gate.ProtectFunc(testHandlerFunc)) + router.ServeHTTP(responseRecorder, request) + + if responseRecorder.Code != http.StatusOK { + t.Errorf("%s %s should have returned %d, but returned %d instead", request.Method, request.URL, http.StatusOK, responseRecorder.Code) + } +} + +func TestGate_ProtectFuncWithPermissionWhenClientHasSufficientPermissions(t *testing.T) { + gate := New().WithAuthorizationService(NewAuthorizationService().WithClient(NewClient("token").WithPermission("admin"))) + request, _ := http.NewRequest("GET", "/handle", http.NoBody) + request.Header.Add("Authorization", fmt.Sprintf("Bearer %s", "token")) + responseRecorder := httptest.NewRecorder() + + router := http.NewServeMux() + router.HandleFunc("/handle", gate.ProtectFuncWithPermission(testHandlerFunc, "admin")) + router.ServeHTTP(responseRecorder, request) + + // Since the client registered directly in the AuthorizationService has the permission "admin" and the testHandler + // is protected by the permission "admin", the request should be authorized + if responseRecorder.Code != http.StatusOK { + t.Errorf("%s %s should have returned %d, but returned %d instead", request.Method, request.URL, http.StatusOK, responseRecorder.Code) + } +} + +func TestGate_ProtectFuncWithPermissionWhenClientHasInsufficientPermissions(t *testing.T) { + gate := New().WithAuthorizationService(NewAuthorizationService().WithClient(NewClientWithPermissions("token", []string{"mod"}))) + request, _ := http.NewRequest("GET", "/handle", http.NoBody) + request.Header.Add("Authorization", fmt.Sprintf("Bearer %s", "token")) + responseRecorder := httptest.NewRecorder() + + router := http.NewServeMux() + router.HandleFunc("/handle", gate.ProtectFuncWithPermission(testHandlerFunc, "admin")) + router.ServeHTTP(responseRecorder, request) + + // Since the client registered directly in the AuthorizationService has the permission "mod" and the + // testHandler is protected by the permission "admin", the request should be not be authorized + if responseRecorder.Code != http.StatusUnauthorized { + t.Errorf("%s %s should have returned %d, but returned %d instead", request.Method, request.URL, http.StatusUnauthorized, responseRecorder.Code) + } +} + +func TestGate_WithCustomUnauthorizedResponseBody(t *testing.T) { + gate := New().WithAuthorizationService(NewAuthorizationService()).WithCustomUnauthorizedResponseBody([]byte("test")) + request, _ := http.NewRequest("GET", "/handle", http.NoBody) + request.Header.Add("Authorization", fmt.Sprintf("Bearer %s", "bad-token")) + responseRecorder := httptest.NewRecorder() + + router := http.NewServeMux() + router.Handle("/handle", gate.Protect(&testHandler{})) + router.ServeHTTP(responseRecorder, request) + + if responseRecorder.Code != http.StatusUnauthorized { + t.Errorf("%s %s should have returned %d, but returned %d instead", request.Method, request.URL, http.StatusUnauthorized, responseRecorder.Code) + } + if responseBody, _ := io.ReadAll(responseRecorder.Body); string(responseBody) != "test" { + t.Errorf("%s %s should have returned %s, but returned %s instead", request.Method, request.URL, "test", string(responseBody)) + } +} + +func TestGate_ProtectWithNoAuthorizationService(t *testing.T) { + gate := New() + request, _ := http.NewRequest("GET", "/handle", http.NoBody) + responseRecorder := httptest.NewRecorder() + + router := http.NewServeMux() + router.Handle("/handle", gate.Protect(&testHandler{})) + router.ServeHTTP(responseRecorder, request) + + if responseRecorder.Code != http.StatusOK { + t.Errorf("%s %s should have returned %d, but returned %d instead", request.Method, request.URL, http.StatusOK, responseRecorder.Code) + } +} + +func TestGate_ProtectWithRateLimit(t *testing.T) { + gate := New().WithRateLimit(2) + request, _ := http.NewRequest("GET", "/handle", http.NoBody) + router := http.NewServeMux() + router.Handle("/handle", gate.Protect(&testHandler{})) + + responseRecorder := httptest.NewRecorder() + router.ServeHTTP(responseRecorder, request) + if responseRecorder.Code != http.StatusOK { + t.Errorf("%s %s should have returned %d, but returned %d instead", request.Method, request.URL, http.StatusOK, responseRecorder.Code) + } + + responseRecorder = httptest.NewRecorder() + router.ServeHTTP(responseRecorder, request) + if responseRecorder.Code != http.StatusOK { + t.Errorf("%s %s should have returned %d, but returned %d instead", request.Method, request.URL, http.StatusOK, responseRecorder.Code) + } + + responseRecorder = httptest.NewRecorder() + router.ServeHTTP(responseRecorder, request) + if responseRecorder.Code != http.StatusTooManyRequests { + t.Errorf("%s %s should have returned %d, but returned %d instead", request.Method, request.URL, http.StatusTooManyRequests, responseRecorder.Code) + } + + // Wait for rate limit time window to pass + time.Sleep(time.Second) + + responseRecorder = httptest.NewRecorder() + router.ServeHTTP(responseRecorder, request) + if responseRecorder.Code != http.StatusOK { + t.Errorf("%s %s should have returned %d, but returned %d instead", request.Method, request.URL, http.StatusOK, responseRecorder.Code) + } +} + +func TestGate_WithCustomTokenExtractor(t *testing.T) { + authorizationService := NewAuthorizationService().WithClientProvider(mockClientProvider) + customTokenExtractorFunc := func(request *http.Request) string { + sessionCookie, err := request.Cookie("session") + if err != nil { + return "" + } + return sessionCookie.Value + } + gate := New().WithAuthorizationService(authorizationService).WithCustomTokenExtractor(customTokenExtractorFunc) + + request, _ := http.NewRequest("GET", "/handle", http.NoBody) + request.AddCookie(&http.Cookie{Name: "session", Value: TestProviderClientToken}) + responseRecorder := httptest.NewRecorder() + + router := http.NewServeMux() + router.Handle("/handle", gate.ProtectFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Context().Value(TokenContextKey) != TestProviderClientToken { + t.Errorf("token should have been passed to the request context") + } + if r.Context().Value(DataContextKey) != TestProviderClientData { + t.Errorf("data should have been passed to the request context") + } + w.WriteHeader(http.StatusOK) + })) + router.ServeHTTP(responseRecorder, request) + + if responseRecorder.Code != http.StatusOK { + t.Errorf("%s %s should have returned %d, but returned %d instead", request.Method, request.URL, http.StatusOK, responseRecorder.Code) + } +} + +func TestGateWithCustomHeader(t *testing.T) { + authorizationService := NewAuthorizationService().WithClientProvider(mockClientProvider) + customTokenExtractorFunc := func(request *http.Request) string { + return request.Header.Get("X-API-Token") + } + gate := New().WithAuthorizationService(authorizationService).WithCustomTokenExtractor(customTokenExtractorFunc) + + request, _ := http.NewRequest("GET", "/handle", http.NoBody) + request.Header.Add("X-API-Token", TestProviderClientToken) + responseRecorder := httptest.NewRecorder() + + router := http.NewServeMux() + router.Handle("/handle", gate.ProtectFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Context().Value(TokenContextKey) != TestProviderClientToken { + t.Errorf("token should have been passed to the request context") + } + if r.Context().Value(DataContextKey) != TestProviderClientData { + t.Errorf("data should have been passed to the request context") + } + w.WriteHeader(http.StatusOK) + })) + router.ServeHTTP(responseRecorder, request) + + if responseRecorder.Code != http.StatusOK { + t.Errorf("%s %s should have returned %d, but returned %d instead", request.Method, request.URL, http.StatusOK, responseRecorder.Code) + } +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..751f874 --- /dev/null +++ b/go.mod @@ -0,0 +1,5 @@ +module github.com/TwiN/g8/v3 + +go 1.23.3 + +require github.com/TwiN/gocache/v2 v2.2.2 diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..13e5145 --- /dev/null +++ b/go.sum @@ -0,0 +1,2 @@ +github.com/TwiN/gocache/v2 v2.2.2 h1:4HToPfDV8FSbaYO5kkbhLpEllUYse5rAf+hVU/mSsuI= +github.com/TwiN/gocache/v2 v2.2.2/go.mod h1:WfIuwd7GR82/7EfQqEtmLFC3a2vqaKbs4Pe6neB7Gyc= diff --git a/ratelimiter.go b/ratelimiter.go new file mode 100644 index 0000000..2645ff7 --- /dev/null +++ b/ratelimiter.go @@ -0,0 +1,42 @@ +package g8 + +import ( + "sync" + "time" +) + +// RateLimiter is a fixed rate limiter +type RateLimiter struct { + maximumExecutionsPerSecond int + executionsLeftInWindow int + windowStartTime time.Time + mutex sync.Mutex +} + +// NewRateLimiter creates a RateLimiter +func NewRateLimiter(maximumExecutionsPerSecond int) *RateLimiter { + return &RateLimiter{ + windowStartTime: time.Now(), + executionsLeftInWindow: maximumExecutionsPerSecond, + maximumExecutionsPerSecond: maximumExecutionsPerSecond, + } +} + +// Try updates the number of executions if the rate limit quota hasn't been reached and returns whether the +// attempt was successful or not. +// +// Returns false if the execution was not successful (rate limit quota has been reached) +// Returns true if the execution was successful (rate limit quota has not been reached) +func (r *RateLimiter) Try() bool { + r.mutex.Lock() + defer r.mutex.Unlock() + if time.Now().Add(-time.Second).After(r.windowStartTime) { + r.windowStartTime = time.Now() + r.executionsLeftInWindow = r.maximumExecutionsPerSecond + } + if r.executionsLeftInWindow == 0 { + return false + } + r.executionsLeftInWindow-- + return true +} diff --git a/ratelimiter_test.go b/ratelimiter_test.go new file mode 100644 index 0000000..70cdd6d --- /dev/null +++ b/ratelimiter_test.go @@ -0,0 +1,73 @@ +package g8 + +import ( + "testing" + "time" +) + +func TestNewRateLimiter(t *testing.T) { + rl := NewRateLimiter(2) + if rl.maximumExecutionsPerSecond != 2 { + t.Errorf("expected maximumExecutionsPerSecond to be %d, got %d", 2, rl.maximumExecutionsPerSecond) + } + if rl.executionsLeftInWindow != 2 { + t.Errorf("expected executionsLeftInWindow to be %d, got %d", 2, rl.executionsLeftInWindow) + } + // First execution: should not be rate limited + if notRateLimited := rl.Try(); !notRateLimited { + t.Error("expected Try to return true") + } + if rl.maximumExecutionsPerSecond != 2 { + t.Errorf("expected maximumExecutionsPerSecond to be %d, got %d", 2, rl.maximumExecutionsPerSecond) + } + if rl.executionsLeftInWindow != 1 { + t.Errorf("expected executionsLeftInWindow to be %d, got %d", 1, rl.executionsLeftInWindow) + } + // Second execution: should not be rate limited + if notRateLimited := rl.Try(); !notRateLimited { + t.Error("expected Try to return true") + } + if rl.maximumExecutionsPerSecond != 2 { + t.Errorf("expected maximumExecutionsPerSecond to be %d, got %d", 2, rl.maximumExecutionsPerSecond) + } + if rl.executionsLeftInWindow != 0 { + t.Errorf("expected executionsLeftInWindow to be %d, got %d", 0, rl.executionsLeftInWindow) + } + // Third execution: should be rate limited + if notRateLimited := rl.Try(); notRateLimited { + t.Error("expected Try to return false") + } + if rl.maximumExecutionsPerSecond != 2 { + t.Errorf("expected maximumExecutionsPerSecond to be %d, got %d", 2, rl.maximumExecutionsPerSecond) + } + if rl.executionsLeftInWindow != 0 { + t.Errorf("expected executionsLeftInWindow to be %d, got %d", 0, rl.executionsLeftInWindow) + } +} + +func TestRateLimiter_Try(t *testing.T) { + rl := NewRateLimiter(5) + for i := 0; i < 20; i++ { + notRateLimited := rl.Try() + if i < 5 { + if !notRateLimited { + t.Fatal("expected to not be rate limited") + } + } else { + if notRateLimited { + t.Fatal("expected to be rate limited") + } + } + } +} + +func TestRateLimiter_TryAlwaysUnderRateLimit(t *testing.T) { + rl := NewRateLimiter(20) + for i := 0; i < 45; i++ { + notRateLimited := rl.Try() + if !notRateLimited { + t.Fatal("expected to not be rate limited") + } + time.Sleep(51 * time.Millisecond) + } +}