From 1bd2ac906d59b157d565ccd44a585863d5742c58 Mon Sep 17 00:00:00 2001 From: RedFlames Date: Wed, 6 Nov 2024 04:00:17 +0100 Subject: [PATCH] Proof of concept: Generic OAuth --- .../Components/CelesteNetEmojiComponent.cs | 3 + CelesteNet.Server.FrontendModule/Frontend.cs | 15 +++ .../FrontendSettings.cs | 45 +++++++++ .../RCEPs/RCEPPublic.cs | 98 +++++++++++++------ 4 files changed, 132 insertions(+), 29 deletions(-) diff --git a/CelesteNet.Client/Components/CelesteNetEmojiComponent.cs b/CelesteNet.Client/Components/CelesteNetEmojiComponent.cs index d8532814..ac8e6533 100644 --- a/CelesteNet.Client/Components/CelesteNetEmojiComponent.cs +++ b/CelesteNet.Client/Components/CelesteNetEmojiComponent.cs @@ -98,6 +98,9 @@ public CelesteNetEmojiComponent(CelesteNetClientContext context, Game game) } public void Handle(CelesteNetConnection con, DataNetEmoji netemoji) { + if (Client?.Options?.AvatarsDisabled == true) + return; + lock (Pending) { // Get the emoji asset if (!Pending.TryGetValue(netemoji.ID, out NetEmojiAsset asset)) diff --git a/CelesteNet.Server.FrontendModule/Frontend.cs b/CelesteNet.Server.FrontendModule/Frontend.cs index f016a131..b7a1c361 100644 --- a/CelesteNet.Server.FrontendModule/Frontend.cs +++ b/CelesteNet.Server.FrontendModule/Frontend.cs @@ -6,6 +6,7 @@ using System.Linq; using System.Net; using System.Reflection; +using System.Security.Cryptography; using System.Timers; using Celeste.Mod.CelesteNet.DataTypes; using Celeste.Mod.CelesteNet.Server.Chat; @@ -26,6 +27,8 @@ public class Frontend : CelesteNetServerModule { public readonly Dictionary TaggedUsers = new(); + public readonly RSA RSAKeysOAuth = RSA.Create(); + private HttpServer? HTTPServer; private WebSocketServiceHost? WSHost; @@ -500,6 +503,18 @@ public void Broadcast(Action callback) { } } + public static string SignString(RSA keys, string content) { + var signData = System.Text.Encoding.UTF8.GetBytes(content); + var signature = keys.SignData(signData, HashAlgorithmName.SHA256, RSASignaturePadding.Pss); + return Convert.ToBase64String(signature); + } + + public static bool VerifyString(RSA keys, string content, string signature) { + var signData = System.Text.Encoding.UTF8.GetBytes(content); + var signatureData = Convert.FromBase64String(signature); + return keys.VerifyData(signData, signatureData, HashAlgorithmName.SHA256, RSASignaturePadding.Pss); + } + #region Read / Parse Helpers public NameValueCollection ParseQueryString(string url) { diff --git a/CelesteNet.Server.FrontendModule/FrontendSettings.cs b/CelesteNet.Server.FrontendModule/FrontendSettings.cs index a9ba2c55..c80a7883 100644 --- a/CelesteNet.Server.FrontendModule/FrontendSettings.cs +++ b/CelesteNet.Server.FrontendModule/FrontendSettings.cs @@ -22,6 +22,51 @@ public class FrontendSettings : CelesteNetServerModuleSettings { public float NetPlusStatsUpdateRate { get; set; } = 1000; + public class OAuthProvider { + public string OAuthPathAuthorize { get; set; } = ""; + public string OAuthPathToken { get; set; } = ""; + public string OAuthScope { get; set; } = "identify"; + public string OAuthClientID { get; set; } = ""; + public string OAuthClientSecret { get; set; } = ""; + + public string ServiceUserAPI { get; set; } = ""; + + public string ServiceUserJsonPathUid { get; set; } = "$.id"; + + public string ServiceUserJsonPathName { get; set; } = "$.username"; + + public string ServiceUserJsonPathPfp { get; set; } = "$.avatar"; + + // will be put through string.Format with {0} = uid (ServiceUserJsonPathUid) and {1} = pfpFragment (ServiceUserJsonPathPfp) + public string ServiceUserAvatarURL { get; set; } = ""; + + public string ServiceUserAvatarDefaultURL { get; set; } = ""; + + public string OAuthURL(string redirectURL, string state) { + return $"{OAuthPathAuthorize}?client_id={OAuthClientID}&redirect_uri={Uri.EscapeDataString(redirectURL)}&response_type=code&scope={OAuthScope}&state={Uri.EscapeDataString(state)}"; + } + } + + public Dictionary OAuthProviders { get; set; } = + new Dictionary() { + { "discord", + new OAuthProvider() + { + OAuthPathAuthorize = "https://discord.com/oauth2/authorize", + OAuthPathToken = "https://discord.com/api/oauth2/token", + ServiceUserAPI = "https://discord.com/api/users/@me", + ServiceUserJsonPathUid = "$.id", + ServiceUserJsonPathName = "$.['global_name', 'username']", + ServiceUserJsonPathPfp = "$.avatar", + ServiceUserAvatarURL = "https://cdn.discordapp.com/avatars/{0}/{1}.png?size=64", + ServiceUserAvatarDefaultURL = "https://cdn.discordapp.com/embed/avatars/0.png" + } + } + }; + + [YamlIgnore] + public string OAuthRedirectURL => $"{CanonicalAPIRoot}/oauth"; + // TODO: Separate Discord auth module! [YamlIgnore] public string DiscordOAuthURL => $"https://discord.com/oauth2/authorize?client_id={DiscordOAuthClientID}&redirect_uri={Uri.EscapeDataString(DiscordOAuthRedirectURL)}&response_type=code&scope=identify"; diff --git a/CelesteNet.Server.FrontendModule/RCEPs/RCEPPublic.cs b/CelesteNet.Server.FrontendModule/RCEPs/RCEPPublic.cs index 7fc8d1d7..4095ee5c 100644 --- a/CelesteNet.Server.FrontendModule/RCEPs/RCEPPublic.cs +++ b/CelesteNet.Server.FrontendModule/RCEPs/RCEPPublic.cs @@ -3,8 +3,10 @@ using System.IO; using System.Linq; using System.Net.Http; +using System.Security.Cryptography; using Celeste.Mod.CelesteNet.Server.Chat; using Newtonsoft.Json; +using Newtonsoft.Json.Linq; using SixLabors.ImageSharp; using SixLabors.ImageSharp.Drawing; using SixLabors.ImageSharp.Drawing.Processing; @@ -17,16 +19,35 @@ namespace Celeste.Mod.CelesteNet.Server.Control { public static partial class RCEndpoints { - [RCEndpoint(false, "/discordauth", "", "", "Discord OAuth2", "User auth using Discord.")] - public static void DiscordOAuth(Frontend f, HttpRequestEventArgs c) { + [RCEndpoint(false, "/oauth", "", "", "OAuth2", "User auth using provider.")] + public static void OAuth(Frontend f, HttpRequestEventArgs c) { NameValueCollection args = f.ParseQueryString(c.Request.RawUrl); + string? provider = args["provider"]; + FrontendSettings.OAuthProvider? oauthProvider = null; + if (!provider.IsNullOrEmpty()) + f.Settings.OAuthProviders.TryGetValue(provider, out oauthProvider); + if (args.Count == 0) { - // c.Response.Redirect(f.Settings.OAuthURL); + c.Response.StatusCode = (int)HttpStatusCode.BadRequest; + f.RespondJSON(c, new { + Error = "No OAuth provider specified." + }); + return; + } else if (args.Count == 1 && oauthProvider != null) { + string newCSRFToken = "testinglol"; + + string newState = $"{provider}.{newCSRFToken}.{Frontend.SignString(f.RSAKeysOAuth, newCSRFToken)}"; + + string redirectOAuth = oauthProvider.OAuthURL(f.Settings.OAuthRedirectURL, newState); + + Logger.Log(LogLevel.CRI, "frontend-oauth", $"Generated state token: {newState}"); + Logger.Log(LogLevel.CRI, "frontend-oauth", $"redirect URI: {redirectOAuth}"); + c.Response.StatusCode = (int) HttpStatusCode.Redirect; - c.Response.Headers.Set("Location", f.Settings.DiscordOAuthURL); + c.Response.Headers.Set("Location", redirectOAuth); f.RespondJSON(c, new { - Info = $"Redirecting to {f.Settings.DiscordOAuthURL}" + Info = $"Redirecting to {redirectOAuth}" }); return; } @@ -43,37 +64,55 @@ public static void DiscordOAuth(Frontend f, HttpRequestEventArgs c) { } string? code = args["code"]; - if (code.IsNullOrEmpty()) { + string? state = args["state"]; + if (code.IsNullOrEmpty() || state.IsNullOrEmpty()) { c.Response.StatusCode = (int) HttpStatusCode.BadRequest; f.RespondJSON(c, new { - Error = "No code specified." + Error = "OAuth2 code or state parameter missing." + }); + return; + } + + state = System.Uri.UnescapeDataString(state); + string[] splitState = state.Split("."); + + Logger.Log(LogLevel.CRI, "frontend-oauth", $"State: {state}"); + Logger.Log(LogLevel.CRI, "frontend-oauth", $"State split: {splitState}"); + + if (splitState.Length != 3 || !f.Settings.OAuthProviders.TryGetValue(splitState[0], out oauthProvider) || !Frontend.VerifyString(f.RSAKeysOAuth, splitState[1], splitState[2])) { + c.Response.StatusCode = (int)HttpStatusCode.BadRequest; + f.RespondJSON(c, new { + Error = $"OAuth2 CSRF state {System.Uri.EscapeDataString(state)} could not be verified!" }); return; } dynamic? tokenData; - dynamic? userData; + JObject? userData; using (HttpClient client = new()) { + Logger.Log(LogLevel.CRI, "frontend-oauth", $"requesting: {oauthProvider.OAuthPathToken} with code {code}"); #pragma warning disable CS8714 // new FormUrlEncodedContent expects nullable. - using (Stream s = client.PostAsync("https://discord.com/api/oauth2/token", new FormUrlEncodedContent(new Dictionary() { + using (Stream s = client.PostAsync(oauthProvider.OAuthPathToken, new FormUrlEncodedContent(new Dictionary() { #pragma warning restore CS8714 - { "client_id", f.Settings.DiscordOAuthClientID }, - { "client_secret", f.Settings.DiscordOAuthClientSecret }, + { "client_id", oauthProvider.OAuthClientID }, + { "client_secret", oauthProvider.OAuthClientSecret }, { "grant_type", "authorization_code" }, { "code", code }, - { "redirect_uri", f.Settings.DiscordOAuthRedirectURL }, - { "scope", "identity" } + { "redirect_uri", f.Settings.OAuthRedirectURL }, + { "scope", oauthProvider.OAuthScope } })).Await().Content.ReadAsStreamAsync().Await()) using (StreamReader sr = new(s)) using (JsonTextReader jtr = new(sr)) tokenData = f.Serializer.Deserialize(jtr); + Logger.Log(LogLevel.CRI, "frontend-oauth", $"tokenData: {tokenData}"); + if (tokenData?.access_token?.ToString() is not string token || tokenData?.token_type?.ToString() is not string tokenType || token.IsNullOrEmpty() || tokenType.IsNullOrEmpty()) { - Logger.Log(LogLevel.CRI, "frontend-discordauth", $"Failed to obtain token: {tokenData}"); + Logger.Log(LogLevel.CRI, "frontend-oauth", $"Failed to obtain token: {tokenData}"); c.Response.StatusCode = (int) HttpStatusCode.InternalServerError; f.RespondJSON(c, new { Error = "Couldn't obtain access token from Discord." @@ -81,25 +120,26 @@ public static void DiscordOAuth(Frontend f, HttpRequestEventArgs c) { return; } + if (tokenType == "bearer") + tokenType = "Bearer"; using (Stream s = client.SendAsync(new HttpRequestMessage { - RequestUri = new("https://discord.com/api/users/@me"), + RequestUri = new(oauthProvider.ServiceUserAPI), Method = HttpMethod.Get, Headers = { { "Authorization", $"{tokenType} {token}" } } }).Await().Content.ReadAsStreamAsync().Await()) using (StreamReader sr = new(s)) - using (JsonTextReader jtr = new(sr)) - userData = f.Serializer.Deserialize(jtr); + userData = JObject.Parse(sr.ReadToEnd()); } - if (!(userData?.id?.ToString() is string uid) || + if (!((string?)userData?.SelectTokens(oauthProvider.ServiceUserJsonPathUid).FirstOrDefault() is string uid) || uid.IsNullOrEmpty()) { - Logger.Log(LogLevel.CRI, "frontend-discordauth", $"Failed to obtain ID: {userData}"); + Logger.Log(LogLevel.CRI, "frontend-oauth", $"Failed to obtain ID: {userData}"); c.Response.StatusCode = (int) HttpStatusCode.InternalServerError; f.RespondJSON(c, new { - Error = "Couldn't obtain user ID from Discord." + Error = $"Couldn't obtain user ID from OAuth provider {provider}." }); return; } @@ -107,28 +147,28 @@ public static void DiscordOAuth(Frontend f, HttpRequestEventArgs c) { string key = f.Server.UserData.Create(uid, false); BasicUserInfo info = f.Server.UserData.Load(uid); - if (userData.global_name?.ToString() is string global_name && !global_name.IsNullOrEmpty()) { + if ((string?)userData?.SelectTokens(oauthProvider.ServiceUserJsonPathName).FirstOrDefault() is string global_name && !global_name.IsNullOrEmpty()) { info.Name = global_name; } else { - info.Name = userData.username.ToString(); + info.Name = ""; } if (info.Name.Length > 32) { info.Name = info.Name.Substring(0, 32); } - info.Discrim = userData.discriminator.ToString(); + info.Discrim = ""; f.Server.UserData.Save(uid, info); + string? pfpFragment = (string?)userData?.SelectTokens(oauthProvider.ServiceUserJsonPathPfp).FirstOrDefault(); + Image avatarOrig; using (HttpClient client = new()) { + string avatarURL = string.Format(oauthProvider.ServiceUserAvatarURL, new object[] { uid, pfpFragment ?? "" }); try { - using Stream s = client.GetAsync( - $"https://cdn.discordapp.com/avatars/{uid}/{userData.avatar.ToString()}.png?size=64" - ).Await().Content.ReadAsStreamAsync().Await(); + using Stream s = client.GetAsync(avatarURL).Await().Content.ReadAsStreamAsync().Await(); avatarOrig = Image.Load(s); } catch { - using Stream s = client.GetAsync( - $"https://cdn.discordapp.com/embed/avatars/{((int) userData.discriminator) % 6}.png" - ).Await().Content.ReadAsStreamAsync().Await(); + avatarURL = oauthProvider.ServiceUserAvatarDefaultURL; + using Stream s = client.GetAsync(avatarURL).Await().Content.ReadAsStreamAsync().Await(); avatarOrig = Image.Load(s); } }