Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add admin Discord link endpoint #49

Merged
merged 15 commits into from
Sep 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
15 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 36 additions & 1 deletion api/account/discord.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,17 @@ import (
"errors"
"net/http"
"net/url"

"github.com/bwmarrin/discordgo"
)

var (
DiscordClientID string
DiscordClientSecret string
DiscordCallbackURL string

DiscordSession *discordgo.Session
DiscordGuildID string
)

func HandleDiscordCallback(w http.ResponseWriter, r *http.Request) (string, error) {
Expand All @@ -36,7 +41,6 @@ func HandleDiscordCallback(w http.ResponseWriter, r *http.Request) (string, erro
http.Redirect(w, r, GameURL, http.StatusSeeOther)
return "", errors.New("code is empty")
}

discordId, err := RetrieveDiscordId(code)
if err != nil {
http.Redirect(w, r, GameURL, http.StatusSeeOther)
Expand Down Expand Up @@ -106,3 +110,34 @@ func RetrieveDiscordId(code string) (string, error) {

return user.Id, nil
}

func IsUserDiscordAdmin(discordId string, discordGuildID string) (bool, error) {
// fetch all roles from discord
roles, err := DiscordSession.GuildRoles(discordGuildID)
if err != nil {
return false, err
}

// fetch all roles from user
userRoles, err := DiscordSession.GuildMember(discordGuildID, discordId)
if err != nil {
return false, err
}

// check if user has a "Dev" or a "Division Heads" role
var hasRole bool
for _, role := range userRoles.Roles {
for _, guildRole := range roles {
if role == guildRole.ID && (guildRole.Name == "Dev" || guildRole.Name == "Division Heads" || guildRole.Name == "Helper") {
hasRole = true
break
}
}
}

if !hasRole {
return false, nil
}

return true, nil
}
4 changes: 3 additions & 1 deletion api/account/info.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,16 +26,18 @@ type InfoResponse struct {
DiscordId string `json:"discordId"`
GoogleId string `json:"googleId"`
LastSessionSlot int `json:"lastSessionSlot"`
HasAdminRole bool `json:"hasAdminRole"`
}

// /account/info - get account info
func Info(username string, discordId string, googleId string, uuid []byte) (InfoResponse, error) {
func Info(username string, discordId string, googleId string, uuid []byte, hasAdminRole bool) (InfoResponse, error) {
slot, _ := db.GetLatestSessionSaveDataSlot(uuid)
response := InfoResponse{
Username: username,
LastSessionSlot: slot,
DiscordId: discordId,
GoogleId: googleId,
HasAdminRole: hasAdminRole,
}
return response, nil
}
4 changes: 4 additions & 0 deletions api/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,10 @@ func Init(mux *http.ServeMux) error {
// auth
mux.HandleFunc("/auth/{provider}/callback", handleProviderCallback)
mux.HandleFunc("/auth/{provider}/logout", handleProviderLogout)

// admin
mux.HandleFunc("POST /admin/account/discord-link", handleAdminDiscordLink)

return nil
}

Expand Down
45 changes: 44 additions & 1 deletion api/endpoints.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
"encoding/json"
"errors"
"fmt"
"log"
"net/http"
"strconv"
"strings"
Expand Down Expand Up @@ -68,7 +69,13 @@
return
}
}
response, err := account.Info(username, discordId, googleId, uuid)

var hasAdminRole bool
if discordId != "" {
hasAdminRole, _ = account.IsUserDiscordAdmin(discordId, account.DiscordGuildID)
}

response, err := account.Info(username, discordId, googleId, uuid, hasAdminRole)
if err != nil {
httpError(w, r, err, http.StatusInternalServerError)
return
Expand Down Expand Up @@ -158,7 +165,7 @@
}

func handleGameClassicSessionCount(w http.ResponseWriter, r *http.Request) {
w.Write([]byte(strconv.Itoa(classicSessionCount)))

Check failure on line 168 in api/endpoints.go

View workflow job for this annotation

GitHub Actions / Build (linux)

Error return value of `w.Write` is not checked (errcheck)

Check failure on line 168 in api/endpoints.go

View workflow job for this annotation

GitHub Actions / Build (windows)

Error return value of `w.Write` is not checked (errcheck)
}

func handleSession(w http.ResponseWriter, r *http.Request) {
Expand Down Expand Up @@ -555,7 +562,7 @@
httpError(w, r, err, http.StatusInternalServerError)
}

w.Write([]byte(strconv.Itoa(count)))

Check failure on line 565 in api/endpoints.go

View workflow job for this annotation

GitHub Actions / Build (linux)

Error return value of `w.Write` is not checked (errcheck)

Check failure on line 565 in api/endpoints.go

View workflow job for this annotation

GitHub Actions / Build (windows)

Error return value of `w.Write` is not checked (errcheck)
}

// redirect link after authorizing application link
Expand Down Expand Up @@ -660,3 +667,39 @@
}
w.WriteHeader(http.StatusOK)
}

func handleAdminDiscordLink(w http.ResponseWriter, r *http.Request) {
err := r.ParseForm()
if err != nil {
httpError(w, r, fmt.Errorf("failed to parse request form: %s", err), http.StatusBadRequest)
return
}

uuid, err := uuidFromRequest(r)
if err != nil {
httpError(w, r, err, http.StatusUnauthorized)
return
}

userDiscordId, err := db.FetchDiscordIdByUUID(uuid)
if err != nil {
httpError(w, r, err, http.StatusUnauthorized)
return
}

hasRole, err := account.IsUserDiscordAdmin(userDiscordId, account.DiscordGuildID)
if !hasRole || err != nil {
httpError(w, r, fmt.Errorf("user does not have the required role"), http.StatusForbidden)
return
}

err = db.AddDiscordIdByUsername(r.Form.Get("discordId"), r.Form.Get("username"))
if err != nil {
httpError(w, r, err, http.StatusInternalServerError)
return
}

log.Printf("%s: %s added discord id %s to username %s", r.URL.Path, userDiscordId, r.Form.Get("discordId"), r.Form.Get("username"))

w.WriteHeader(http.StatusOK)
}
28 changes: 28 additions & 0 deletions db/account.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,34 @@ func FetchGoogleIdByUsername(username string) (string, error) {
return googleId.String, nil
}

func FetchDiscordIdByUUID(uuid []byte) (string, error) {
var discordId sql.NullString
err := handle.QueryRow("SELECT discordId FROM accounts WHERE uuid = ?", uuid).Scan(&discordId)
if err != nil {
return "", err
}

if !discordId.Valid {
return "", nil
}

return discordId.String, nil
}

func FetchGoogleIdByUUID(uuid []byte) (string, error) {
var googleId sql.NullString
err := handle.QueryRow("SELECT googleId FROM accounts WHERE uuid = ?", uuid).Scan(&googleId)
if err != nil {
return "", err
}

if !googleId.Valid {
return "", nil
}

return googleId.String, nil
}

func FetchUsernameBySessionToken(token []byte) (string, error) {
var username string
err := handle.QueryRow("SELECT a.username FROM accounts a JOIN sessions s ON a.uuid = s.uuid WHERE s.token = ?", token).Scan(&username)
Expand Down
6 changes: 5 additions & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,8 @@ require (
github.com/klauspost/compress v1.17.9
)

require golang.org/x/sys v0.19.0 // indirect
require (
github.com/bwmarrin/discordgo v0.28.1 // indirect
github.com/gorilla/websocket v1.4.2 // indirect
golang.org/x/sys v0.19.0 // indirect
)
10 changes: 10 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
@@ -1,12 +1,22 @@
github.com/bwmarrin/discordgo v0.28.1 h1:gXsuo2GBO7NbR6uqmrrBDplPUx2T3nzu775q/Rd1aG4=
github.com/bwmarrin/discordgo v0.28.1/go.mod h1:NJZpH+1AfhIcyQsPeuBKsUtYrRnjkyu0kIVMCHkZtRY=
github.com/go-sql-driver/mysql v1.7.1 h1:lUIinVbN1DY0xBg0eMOzmmtGoHwWBbvnWubQUrtU8EI=
github.com/go-sql-driver/mysql v1.7.1/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI=
github.com/golang-jwt/jwt/v5 v5.2.1 h1:OuVbFODueb089Lh128TAcimifWaLhJwVflnrgM17wHk=
github.com/golang-jwt/jwt/v5 v5.2.1/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
github.com/klauspost/compress v1.17.9 h1:6KIumPrER1LHsvBVuDa0r5xaG0Es51mhhB9BQB2qeMA=
github.com/klauspost/compress v1.17.9/go.mod h1:Di0epgTjJY877eYKx5yC51cX2A2Vl2ibi7bDH9ttBbw=
github.com/gorilla/websocket v1.4.2 h1:+/TMaTYc4QFitKJxsQ7Yye35DkWvkdLcvGKqM+x0Ufc=
github.com/gorilla/websocket v1.4.2/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/robfig/cron/v3 v3.0.1 h1:WdRxkvbJztn8LMz/QEvLN5sBU+xKpSqwwUO1Pjr4qDs=
github.com/robfig/cron/v3 v3.0.1/go.mod h1:eQICP3HwyT7UooqI/z+Ov+PtYAWygg1TEWWzGIFLtro=
golang.org/x/crypto v0.0.0-20210421170649-83a5a9bb288b/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4=
golang.org/x/crypto v0.22.0 h1:g1v0xeRhjcugydODzvb3mEM9SQ0HGp9s/nh3COQ/C30=
golang.org/x/crypto v0.22.0/go.mod h1:vr6Su+7cTlO45qkww3VDJlzDn0ctJvRgYbC2NvXHt+M=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.19.0 h1:q5f1RH2jigJ1MoAWp2KTp3gm5zAGFUTarQZ5U386+4o=
golang.org/x/sys v0.19.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
79 changes: 45 additions & 34 deletions rogueserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,66 +19,69 @@ package main

import (
"encoding/gob"
"flag"
"log"
"net"
"net/http"
"os"
"strconv"

"github.com/bwmarrin/discordgo"
"github.com/pagefaultgames/rogueserver/api"
"github.com/pagefaultgames/rogueserver/api/account"
"github.com/pagefaultgames/rogueserver/db"
)

func main() {
// flag stuff
debug := flag.Bool("debug", false, "use debug mode")
// env stuff
debug, _ := strconv.ParseBool(os.Getenv("debug"))

proto := flag.String("proto", "tcp", "protocol for api to use (tcp, unix)")
addr := flag.String("addr", "0.0.0.0:8001", "network address for api to listen on")
tlscert := flag.String("tlscert", "", "tls certificate path")
tlskey := flag.String("tlskey", "", "tls key path")
proto := getEnv("proto", "tcp")
addr := getEnv("addr", "0.0.0.0:8001")
tlscert := getEnv("tlscert", "")
tlskey := getEnv("tlskey", "")

dbuser := flag.String("dbuser", "pokerogue", "database username")
dbpass := flag.String("dbpass", "pokerogue", "database password")
dbproto := flag.String("dbproto", "tcp", "protocol for database connection")
dbaddr := flag.String("dbaddr", "localhost", "database address")
dbname := flag.String("dbname", "pokeroguedb", "database name")
dbuser := getEnv("dbuser", "pokerogue")
dbpass := getEnv("dbpass", "pokerogue")
dbproto := getEnv("dbproto", "tcp")
dbaddr := getEnv("dbaddr", "localhost")
dbname := getEnv("dbname", "pokeroguedb")

discordclientid := flag.String("discordclientid", "dcid", "Discord Oauth2 Client ID")
discordsecretid := flag.String("discordsecretid", "dsid", "Discord Oauth2 Secret ID")
discordclientid := getEnv("discordclientid", "")
discordsecretid := getEnv("discordsecretid", "")

googleclientid := flag.String("googleclientid", "gcid", "Google Oauth2 Client ID")
googlesecretid := flag.String("googlesecretid", "gsid", "Google Oauth2 Secret ID")
googleclientid := getEnv("googleclientid", "")
googlesecretid := getEnv("googlesecretid", "")

callbackurl := flag.String("callbackurl", "http://localhost:8001/", "Callback URL for Oauth2 Client")
callbackurl := getEnv("callbackurl", "http://localhost:8001/")

gameurl := flag.String("gameurl", "https://pokerogue.net", "URL for game server")
gameurl := getEnv("gameurl", "https://pokerogue.net")

flag.Parse()
discordbottoken := getEnv("discordbottoken", "")
discordguildid := getEnv("discordguildid", "")

account.GameURL = *gameurl
account.GameURL = gameurl

account.DiscordClientID = *discordclientid
account.DiscordClientSecret = *discordsecretid
account.DiscordCallbackURL = *callbackurl + "/auth/discord/callback"

account.GoogleClientID = *googleclientid
account.GoogleClientSecret = *googlesecretid
account.GoogleCallbackURL = *callbackurl + "/auth/google/callback"
account.DiscordClientID = discordclientid
account.DiscordClientSecret = discordsecretid
account.DiscordCallbackURL = callbackurl + "/auth/discord/callback"

account.GoogleClientID = googleclientid
account.GoogleClientSecret = googlesecretid
account.GoogleCallbackURL = callbackurl + "/auth/google/callback"
account.DiscordSession, _ = discordgo.New("Bot " + discordbottoken)
account.DiscordGuildID = discordguildid
// register gob types
gob.Register([]interface{}{})
gob.Register(map[string]interface{}{})

// get database connection
err := db.Init(*dbuser, *dbpass, *dbproto, *dbaddr, *dbname)
err := db.Init(dbuser, dbpass, dbproto, dbaddr, dbname)
if err != nil {
log.Fatalf("failed to initialize database: %s", err)
}

// create listener
listener, err := createListener(*proto, *addr)
listener, err := createListener(proto, addr)
if err != nil {
log.Fatalf("failed to create net listener: %s", err)
}
Expand All @@ -92,14 +95,14 @@ func main() {

// start web server
handler := prodHandler(mux, gameurl)
if *debug {
if debug {
handler = debugHandler(mux)
}

if *tlscert == "" {
if tlscert == "" {
err = http.Serve(listener, handler)
} else {
err = http.ServeTLS(listener, handler, *tlscert, *tlskey)
err = http.ServeTLS(listener, handler, tlscert, tlskey)
}
if err != nil {
log.Fatalf("failed to create http server or server errored: %s", err)
Expand All @@ -126,11 +129,11 @@ func createListener(proto, addr string) (net.Listener, error) {
return listener, nil
}

func prodHandler(router *http.ServeMux, clienturl *string) http.Handler {
func prodHandler(router *http.ServeMux, clienturl string) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Access-Control-Allow-Headers", "Authorization, Content-Type")
w.Header().Set("Access-Control-Allow-Methods", "OPTIONS, GET, POST")
w.Header().Set("Access-Control-Allow-Origin", *clienturl)
w.Header().Set("Access-Control-Allow-Origin", clienturl)

if r.Method == "OPTIONS" {
w.WriteHeader(http.StatusOK)
Expand All @@ -155,3 +158,11 @@ func debugHandler(router *http.ServeMux) http.Handler {
router.ServeHTTP(w, r)
})
}

func getEnv(key string, defaultValue string) string {
if value, ok := os.LookupEnv(key); ok {
return value
}

return defaultValue
}
Loading