Skip to content

Commit

Permalink
fix #31 implement all middleware functions to have a NetHTTP alternative
Browse files Browse the repository at this point in the history
  • Loading branch information
szuecs committed Feb 11, 2018
1 parent 8f27dda commit 33c395e
Show file tree
Hide file tree
Showing 3 changed files with 295 additions and 8 deletions.
104 changes: 104 additions & 0 deletions example/zalando_nethttp/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
// Zalando specific example.
package main

import (
"flag"
"fmt"
"net/http"
"time"

"github.com/golang/glog"
"github.com/zalando/gin-oauth2"
"github.com/zalando/gin-oauth2/zalando"
"goji.io"
"goji.io/pat"
)

var USERS []zalando.AccessTuple = []zalando.AccessTuple{
{"/employees", "sszuecs", "Sandor Szücs"},
{"/employees", "njuettner", "Nick Jüttner"},
}

var TEAMS []zalando.AccessTuple = []zalando.AccessTuple{
{"teams", "opensourceguild", "OpenSource"},
{"teams", "tm", "Platform Engineering / System"},
{"teams", "teapot", "Platform / Cloud API"},
}
var SERVICES []zalando.AccessTuple = []zalando.AccessTuple{
{"services", "foo", "Fooservice"},
}

func loggerMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
glog.Infof("loggerMiddleware: Got request: %s", req.URL)
next.ServeHTTP(rw, req)
})
}

func hello(w http.ResponseWriter, r *http.Request) {
name := pat.Param(r, "name")
fmt.Fprintf(w, "Hello, %s!\n", name)
}

func main() {
flag.Parse()
// start glog flusher
go func() {
for range time.Tick(1 * time.Second) {
glog.Flush()
}
}()

mux := goji.NewMux()
mux.Use(loggerMiddleware)
mux.Use(ginoauth2.RequestLoggerNetHTTP([]string{"uid"}, "data"))
ginoauth2.VarianceTimer = 3000 * time.Millisecond // defaults to 30s

public := goji.SubMux()
mux.Handle(pat.New("/api/*"), public)
public.HandleFunc(pat.Get("/:name"), hello)

private := goji.SubMux()
mux.Handle(pat.New("/private/*"), private)
privateGroup := goji.SubMux()
mux.Handle(pat.New("/privateGroup/*"), privateGroup)
privateUser := goji.SubMux()
mux.Handle(pat.New("/privateUser/*"), privateUser)
privateService := goji.SubMux()
mux.Handle(pat.New("/privateService/*"), privateService)
glog.Infof("Register allowed users: %+v and groups: %+v and services: %+v", USERS, TEAMS, SERVICES)

private.Use(ginoauth2.AuthChainNetHTTP(zalando.OAuth2Endpoint, zalando.UidCheckNetHTTP(USERS), zalando.GroupCheckNetHTTP(TEAMS), zalando.UidCheckNetHTTP(SERVICES)))
privateGroup.Use(ginoauth2.AuthNetHTTP(zalando.GroupCheckNetHTTP(TEAMS), zalando.OAuth2Endpoint))
privateUser.Use(ginoauth2.AuthNetHTTP(zalando.UidCheckNetHTTP(USERS), zalando.OAuth2Endpoint))
privateService.Use(ginoauth2.AuthNetHTTP(zalando.ScopeAndCheckNetHTTP("uidcheck", "uid", "bar"), zalando.OAuth2Endpoint))

private.HandleFunc(pat.Get("/"), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
h := w.Header()
uid := h.Get("uid")
fmt.Fprintf(w, "Hello from private for groups and users: %s\n", uid)
}))

privateGroup.HandleFunc(pat.Get("/"), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
h := w.Header()
uid := h.Get("uid")
team := h.Get("team")
fmt.Fprintf(w, "Hello from private group: uid: %s, team: %s\n", uid, team)
}))

privateUser.HandleFunc(pat.Get("/"), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
h := w.Header()
uid := h.Get("uid")
fmt.Fprintf(w, "Hello from private user: uid: %s\n", uid)
}))

privateService.HandleFunc(pat.Get("/"), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
h := w.Header()
cn := h.Get("cn")
fmt.Fprintf(w, "Hello from private service cn: %s\n", cn)
}))

glog.Info("bootstrapped application")
http.ListenAndServe("localhost:8081", mux)

}
108 changes: 101 additions & 7 deletions ginoauth2.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,10 @@ type TokenContainer struct {
// access.
type AccessCheckFunction func(tc *TokenContainer, ctx *gin.Context) bool

// AccessCheckFunctionNetHTTP is a function that checks if a given token grants
// access.
type AccessCheckFunctionNetHTTP func(tc *TokenContainer, w http.ResponseWriter, r *http.Request) bool

func extractToken(r *http.Request) (*oauth2.Token, error) {
hdr := r.Header.Get("Authorization")
if hdr == "" {
Expand Down Expand Up @@ -180,12 +184,12 @@ func GetTokenContainer(token *oauth2.Token) (*TokenContainer, error) {
return ParseTokenContainer(token, data)
}

func getTokenContainer(ctx *gin.Context) (*TokenContainer, bool) {
func getTokenContainer(r *http.Request) (*TokenContainer, bool) {
var oauthToken *oauth2.Token
var tc *TokenContainer
var err error

if oauthToken, err = extractToken(ctx.Request); err != nil {
if oauthToken, err = extractToken(r); err != nil {
glog.Errorf("[Gin-OAuth] Can not extract oauth2.Token, caused by: %s", err)
return nil, false
}
Expand Down Expand Up @@ -232,6 +236,11 @@ func Auth(accessCheckFunction AccessCheckFunction, endpoints oauth2.Endpoint) gi
return AuthChain(endpoints, accessCheckFunction)
}

// AuthNetHTTP is the net/http version of Auth
func AuthNetHTTP(accessCheckFunction AccessCheckFunctionNetHTTP, endpoints oauth2.Endpoint) func(http.Handler) http.Handler {
return AuthChainNetHTTP(endpoints, accessCheckFunction)
}

// AuthChain is a router middleware that can be used to get an authenticated
// and authorized service for the whole router group. Similar to Auth, but
// takes a chain of AccessCheckFunctions and only fails if all of them fails.
Expand Down Expand Up @@ -262,7 +271,7 @@ func AuthChain(endpoints oauth2.Endpoint, accessCheckFunctions ...AccessCheckFun
varianceControl := make(chan bool, 1)

go func() {
tokenContainer, ok := getTokenContainer(ctx)
tokenContainer, ok := getTokenContainer(ctx.Request)
if !ok {
// set LOCATION header to auth endpoint such that the user can easily get a new access-token
ctx.Writer.Header().Set("Location", endpoints.AuthURL)
Expand Down Expand Up @@ -309,6 +318,70 @@ func AuthChain(endpoints oauth2.Endpoint, accessCheckFunctions ...AccessCheckFun
}
}

// AuthChainNetHTTP is the net/http version of AuthChain
func AuthChainNetHTTP(endpoints oauth2.Endpoint, accessCheckFunctions ...AccessCheckFunctionNetHTTP) func(http.Handler) http.Handler {
// init
AuthInfoURL = endpoints.TokenURL
// middleware
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
t := time.Now()
varianceControl := make(chan bool, 1)

go func() {
tokenContainer, ok := getTokenContainer(request)
if !ok {
// set LOCATION header to auth endpoint such that the user can easily get a new access-token
writer.Header().Set("Location", endpoints.AuthURL)
writer.WriteHeader(http.StatusUnauthorized)
writer.Write([]byte("No token in context"))
varianceControl <- false
return
}

if !tokenContainer.Valid() {
// set LOCATION header to auth endpoint such that the user can easily get a new access-token
writer.Header().Set("Location", endpoints.AuthURL)
writer.WriteHeader(http.StatusUnauthorized)
writer.Write([]byte("Invalid Token"))
varianceControl <- false
return
}

for i, fn := range accessCheckFunctions {
if fn(tokenContainer, writer, request) {
varianceControl <- true
break
}

if len(accessCheckFunctions)-1 == i {
writer.WriteHeader(http.StatusForbidden)
writer.Write([]byte("Access to the Resource is fobidden"))
varianceControl <- false
return
}
}
}()

select {
case ok := <-varianceControl:
if !ok {
glog.V(2).Infof("[Gin-OAuth] %12v %s access not allowed", time.Since(t), request.URL.Path)
return
}
case <-time.After(VarianceTimer):
writer.WriteHeader(http.StatusGatewayTimeout)
writer.Write([]byte("Authorization check overtime"))
glog.V(2).Infof("[Gin-OAuth] %12v %s overtime", time.Since(t), request.URL.Path)
return
}

glog.V(2).Infof("[Gin-OAuth] %12v %s access allowed", time.Since(t), request.URL.Path)
next.ServeHTTP(writer, request)
})
}
}

// RequestLogger is a middleware that logs all the request and prints
// relevant information. This can be used for logging all the
// requests that contain important information and are authorized.
Expand All @@ -333,12 +406,10 @@ func RequestLogger(keys []string, contentKey string) gin.HandlerFunc {
c.Next()
err := c.Errors
if request.Method != "GET" && err == nil {
data, e := c.Get(contentKey)
if e != false { //key is non existent
if data, ok := c.Get(contentKey); ok {
values := make([]string, 0)
for _, key := range keys {
val, keyPresent := c.Get(key)
if keyPresent {
if val, ok := c.Get(key); ok {
values = append(values, val.(string))
}
}
Expand All @@ -348,4 +419,27 @@ func RequestLogger(keys []string, contentKey string) gin.HandlerFunc {
}
}

// RequestLoggerNetHTTP is the net/http version of RequestLogger.
func RequestLoggerNetHTTP(keys []string, contentKey string) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
request := r
next.ServeHTTP(w, r.WithContext(ctx))
if request.Method != "GET" {
if data, ok := ctx.Value(contentKey).(string); ok {
values := make([]string, 0)
for _, key := range keys {
s, ok := ctx.Value(key).(string)
if ok {
values = append(values, s)
}
}
glog.Infof("[Gin-OAuth] Request: %+v for %s", data, strings.Join(values, "-"))
}
}
})
}
}

// vim: ts=4 sw=4 noexpandtab nolist syn=go
91 changes: 90 additions & 1 deletion zalando/zalando.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,10 +102,45 @@ func GroupCheck(at []AccessTuple) func(tc *ginoauth2.TokenContainer, ctx *gin.Co
}
}

// GroupCheckNetHTTP is the net/http version of GroupCheck
func GroupCheckNetHTTP(at []AccessTuple) func(tc *ginoauth2.TokenContainer, w http.ResponseWriter, r *http.Request) bool {
ats := at
return func(tc *ginoauth2.TokenContainer, w http.ResponseWriter, r *http.Request) bool {
blob, err := RequestTeamInfo(tc, TeamAPI)
if err != nil {
glog.Errorf("[Gin-OAuth] failed to get team info, caused by: %s", err)
return false
}
var data []TeamInfo
err = json.Unmarshal(blob, &data)
if err != nil {
glog.Errorf("[Gin-OAuth] JSON.Unmarshal failed, caused by: %s", err)
return false
}
granted := false
for _, teamInfo := range data {
for idx := range ats {
at := ats[idx]
if teamInfo.Id == at.Uid {
granted = true
glog.Infof("[Gin-OAuth] Grant access to %s as team member of \"%s\"\n", tc.Scopes["uid"].(string), teamInfo.Id)
}
if teamInfo.Type == "official" {
if uid, ok := tc.Scopes["uid"].(string); ok {
w.Header().Set("uid", uid)
w.Header().Set("team", teamInfo.Id)
}
}
}
}
return granted
}
}

// UidCheck is an authorization function that checks UID scope
// TokenContainer must be Valid. As side effect it sets "uid" and
// "cn" in the gin.Context to the authorized uid and cn (Realname).
func UidCheck(at []AccessTuple) func(tc *ginoauth2.TokenContainer, ctx *gin.Context) bool {
func UidCheck(at []AccessTuple) ginoauth2.AccessCheckFunction {
ats := at
return func(tc *ginoauth2.TokenContainer, ctx *gin.Context) bool {
uid := tc.Scopes["uid"].(string)
Expand All @@ -122,6 +157,24 @@ func UidCheck(at []AccessTuple) func(tc *ginoauth2.TokenContainer, ctx *gin.Cont
}
}

// UidCheckNetHTTP is the net/http version of UidCheck
func UidCheckNetHTTP(at []AccessTuple) ginoauth2.AccessCheckFunctionNetHTTP {
ats := at
return func(tc *ginoauth2.TokenContainer, w http.ResponseWriter, r *http.Request) bool {
uid := tc.Scopes["uid"].(string)
for idx := range ats {
at := ats[idx]
if tc.Realm == at.Realm && uid == at.Uid {
w.Header().Set("uid", uid)
w.Header().Set("cn", at.Cn)
glog.Infof("[Gin-OAuth] Grant access to %s\n", uid)
return true
}
}
return false
}
}

// ScopeCheck does an OR check of scopes given from token of the
// request to all provided scopes. If one of provided scopes is in the
// Scopes of the token it grants access to the resource.
Expand All @@ -141,6 +194,23 @@ func ScopeCheck(name string, scopes ...string) func(tc *ginoauth2.TokenContainer
}
}

// ScopeCheckNetHTTP is the net/http version of ScopeCheck
func ScopeCheckNetHTTP(name string, scopes ...string) func(tc *ginoauth2.TokenContainer, w http.ResponseWriter, r *http.Request) bool {
glog.Infof("ScopeCheck %s configured to grant access for scopes: %v", name, scopes)
configuredScopes := scopes
return func(tc *ginoauth2.TokenContainer, w http.ResponseWriter, r *http.Request) bool {
scopesFromToken := make([]string, 0)
for _, s := range configuredScopes {
if cur, ok := tc.Scopes[s].(string); ok {
glog.V(2).Infof("Found configured scope %s", cur)
scopesFromToken = append(scopesFromToken, cur)
w.Header().Add(s, cur)
}
}
return len(scopesFromToken) > 0
}
}

// ScopeAndCheck does an AND check of scopes given from token of the
// request to all provided scopes. Only if all of provided scopes are found in the
// Scopes of the token it grants access to the resource.
Expand All @@ -162,6 +232,25 @@ func ScopeAndCheck(name string, scopes ...string) func(tc *ginoauth2.TokenContai
}
}

// ScopeAndCheckNetHTTP is the net/http version of ScopeAndCheck
func ScopeAndCheckNetHTTP(name string, scopes ...string) func(tc *ginoauth2.TokenContainer, w http.ResponseWriter, r *http.Request) bool {
glog.Infof("ScopeCheck %s configured to grant access only if scopes: %v are present", name, scopes)
configuredScopes := scopes
return func(tc *ginoauth2.TokenContainer, w http.ResponseWriter, r *http.Request) bool {
scopesFromToken := make([]string, 0)
for _, s := range configuredScopes {
if cur, ok := tc.Scopes[s].(string); ok {
glog.V(2).Infof("Found configured scope %s", cur)
scopesFromToken = append(scopesFromToken, cur)
w.Header().Add(s, cur)
} else {
return false
}
}
return true
}
}

// NoAuthorization sets "team" and "uid" in the context without
// checking if the user/team is authorized.
func NoAuthorization() func(tc *ginoauth2.TokenContainer, ctx *gin.Context) bool {
Expand Down

0 comments on commit 33c395e

Please sign in to comment.