Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Proof of concept: Generic OAuth #162

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions CelesteNet.Server.FrontendModule/Frontend.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
using System.Linq;
using System.Net;
using System.Reflection;
using System.Security.Cryptography;
using System.Timers;
using Celeste.Mod.CelesteNet.DataTypes;
using Celeste.Mod.CelesteNet.Server.Chat;
Expand All @@ -26,6 +27,8 @@ public class Frontend : CelesteNetServerModule<FrontendSettings> {

public readonly Dictionary<string, BasicUserInfo> TaggedUsers = new();

public readonly RSA RSAKeysOAuth = RSA.Create();

private HttpServer? HTTPServer;
private WebSocketServiceHost? WSHost;

Expand Down Expand Up @@ -500,6 +503,18 @@ public void Broadcast(Action<FrontendWebSocket> callback) {
}
}

public static string SignString(RSA keys, string content) {
var signData = System.Text.Encoding.UTF8.GetBytes(content);
var signature = keys.SignData(signData, HashAlgorithmName.SHA256, RSASignaturePadding.Pss);
return Convert.ToBase64String(signature);
}

public static bool VerifyString(RSA keys, string content, string signature) {
var signData = System.Text.Encoding.UTF8.GetBytes(content);
var signatureData = Convert.FromBase64String(signature);
return keys.VerifyData(signData, signatureData, HashAlgorithmName.SHA256, RSASignaturePadding.Pss);
}

#region Read / Parse Helpers

public NameValueCollection ParseQueryString(string url) {
Expand Down
45 changes: 45 additions & 0 deletions CelesteNet.Server.FrontendModule/FrontendSettings.cs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,51 @@ public class FrontendSettings : CelesteNetServerModuleSettings {

public float NetPlusStatsUpdateRate { get; set; } = 1000;

public class OAuthProvider {
public string OAuthPathAuthorize { get; set; } = "";
public string OAuthPathToken { get; set; } = "";
public string OAuthScope { get; set; } = "identify";
public string OAuthClientID { get; set; } = "";
public string OAuthClientSecret { get; set; } = "";

public string ServiceUserAPI { get; set; } = "";

public string ServiceUserJsonPathUid { get; set; } = "$.id";

public string ServiceUserJsonPathName { get; set; } = "$.username";

public string ServiceUserJsonPathPfp { get; set; } = "$.avatar";

// will be put through string.Format with {0} = uid (ServiceUserJsonPathUid) and {1} = pfpFragment (ServiceUserJsonPathPfp)
public string ServiceUserAvatarURL { get; set; } = "";

public string ServiceUserAvatarDefaultURL { get; set; } = "";

public string OAuthURL(string redirectURL, string state) {
return $"{OAuthPathAuthorize}?client_id={OAuthClientID}&redirect_uri={Uri.EscapeDataString(redirectURL)}&response_type=code&scope={OAuthScope}&state={Uri.EscapeDataString(state)}";
}
}

public Dictionary<string,OAuthProvider> OAuthProviders { get; set; } =
new Dictionary<string, OAuthProvider>() {
{ "discord",
new OAuthProvider()
{
OAuthPathAuthorize = "https://discord.com/oauth2/authorize",
OAuthPathToken = "https://discord.com/api/oauth2/token",
ServiceUserAPI = "https://discord.com/api/users/@me",
ServiceUserJsonPathUid = "$.id",
ServiceUserJsonPathName = "$.['global_name', 'username']",
ServiceUserJsonPathPfp = "$.avatar",
ServiceUserAvatarURL = "https://cdn.discordapp.com/avatars/{0}/{1}.png?size=64",
ServiceUserAvatarDefaultURL = "https://cdn.discordapp.com/embed/avatars/0.png"
}
}
};

[YamlIgnore]
public string OAuthRedirectURL => $"{CanonicalAPIRoot}/oauth";

// TODO: Separate Discord auth module!
[YamlIgnore]
public string DiscordOAuthURL => $"https://discord.com/oauth2/authorize?client_id={DiscordOAuthClientID}&redirect_uri={Uri.EscapeDataString(DiscordOAuthRedirectURL)}&response_type=code&scope=identify";
Expand Down
98 changes: 69 additions & 29 deletions CelesteNet.Server.FrontendModule/RCEPs/RCEPPublic.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@
using System.IO;
using System.Linq;
using System.Net.Http;
using System.Security.Cryptography;
using Celeste.Mod.CelesteNet.Server.Chat;
using Newtonsoft.Json;
using Newtonsoft.Json.Linq;
using SixLabors.ImageSharp;
using SixLabors.ImageSharp.Drawing;
using SixLabors.ImageSharp.Drawing.Processing;
Expand All @@ -17,16 +19,35 @@
namespace Celeste.Mod.CelesteNet.Server.Control {
public static partial class RCEndpoints {

[RCEndpoint(false, "/discordauth", "", "", "Discord OAuth2", "User auth using Discord.")]
public static void DiscordOAuth(Frontend f, HttpRequestEventArgs c) {
[RCEndpoint(false, "/oauth", "", "", "OAuth2", "User auth using provider.")]
public static void OAuth(Frontend f, HttpRequestEventArgs c) {
NameValueCollection args = f.ParseQueryString(c.Request.RawUrl);

string? provider = args["provider"];
FrontendSettings.OAuthProvider? oauthProvider = null;
if (!provider.IsNullOrEmpty())
f.Settings.OAuthProviders.TryGetValue(provider, out oauthProvider);

if (args.Count == 0) {
// c.Response.Redirect(f.Settings.OAuthURL);
c.Response.StatusCode = (int)HttpStatusCode.BadRequest;
f.RespondJSON(c, new {
Error = "No OAuth provider specified."
});
return;
} else if (args.Count == 1 && oauthProvider != null) {
string newCSRFToken = "testinglol";

string newState = $"{provider}.{newCSRFToken}.{Frontend.SignString(f.RSAKeysOAuth, newCSRFToken)}";

string redirectOAuth = oauthProvider.OAuthURL(f.Settings.OAuthRedirectURL, newState);

Logger.Log(LogLevel.CRI, "frontend-oauth", $"Generated state token: {newState}");
Logger.Log(LogLevel.CRI, "frontend-oauth", $"redirect URI: {redirectOAuth}");

c.Response.StatusCode = (int) HttpStatusCode.Redirect;
c.Response.Headers.Set("Location", f.Settings.DiscordOAuthURL);
c.Response.Headers.Set("Location", redirectOAuth);
f.RespondJSON(c, new {
Info = $"Redirecting to {f.Settings.DiscordOAuthURL}"
Info = $"Redirecting to {redirectOAuth}"
});
return;
}
Expand All @@ -43,92 +64,111 @@ public static void DiscordOAuth(Frontend f, HttpRequestEventArgs c) {
}

string? code = args["code"];
if (code.IsNullOrEmpty()) {
string? state = args["state"];
if (code.IsNullOrEmpty() || state.IsNullOrEmpty()) {
c.Response.StatusCode = (int) HttpStatusCode.BadRequest;
f.RespondJSON(c, new {
Error = "No code specified."
Error = "OAuth2 code or state parameter missing."
});
return;
}

state = System.Uri.UnescapeDataString(state);
string[] splitState = state.Split(".");

Logger.Log(LogLevel.CRI, "frontend-oauth", $"State: {state}");
Logger.Log(LogLevel.CRI, "frontend-oauth", $"State split: {splitState}");

if (splitState.Length != 3 || !f.Settings.OAuthProviders.TryGetValue(splitState[0], out oauthProvider) || !Frontend.VerifyString(f.RSAKeysOAuth, splitState[1], splitState[2])) {
c.Response.StatusCode = (int)HttpStatusCode.BadRequest;
f.RespondJSON(c, new {
Error = $"OAuth2 CSRF state {System.Uri.EscapeDataString(state)} could not be verified!"
});
return;
}

dynamic? tokenData;
dynamic? userData;
JObject? userData;

using (HttpClient client = new()) {
Logger.Log(LogLevel.CRI, "frontend-oauth", $"requesting: {oauthProvider.OAuthPathToken} with code {code}");
#pragma warning disable CS8714 // new FormUrlEncodedContent expects nullable.
using (Stream s = client.PostAsync("https://discord.com/api/oauth2/token", new FormUrlEncodedContent(new Dictionary<string?, string?>() {
using (Stream s = client.PostAsync(oauthProvider.OAuthPathToken, new FormUrlEncodedContent(new Dictionary<string?, string?>() {
#pragma warning restore CS8714
{ "client_id", f.Settings.DiscordOAuthClientID },
{ "client_secret", f.Settings.DiscordOAuthClientSecret },
{ "client_id", oauthProvider.OAuthClientID },
{ "client_secret", oauthProvider.OAuthClientSecret },
{ "grant_type", "authorization_code" },
{ "code", code },
{ "redirect_uri", f.Settings.DiscordOAuthRedirectURL },
{ "scope", "identity" }
{ "redirect_uri", f.Settings.OAuthRedirectURL },
{ "scope", oauthProvider.OAuthScope }
})).Await().Content.ReadAsStreamAsync().Await())
using (StreamReader sr = new(s))
using (JsonTextReader jtr = new(sr))
tokenData = f.Serializer.Deserialize<dynamic>(jtr);

Logger.Log(LogLevel.CRI, "frontend-oauth", $"tokenData: {tokenData}");

if (tokenData?.access_token?.ToString() is not string token ||
tokenData?.token_type?.ToString() is not string tokenType ||
token.IsNullOrEmpty() ||
tokenType.IsNullOrEmpty()) {
Logger.Log(LogLevel.CRI, "frontend-discordauth", $"Failed to obtain token: {tokenData}");
Logger.Log(LogLevel.CRI, "frontend-oauth", $"Failed to obtain token: {tokenData}");
c.Response.StatusCode = (int) HttpStatusCode.InternalServerError;
f.RespondJSON(c, new {
Error = "Couldn't obtain access token from Discord."
});
return;
}

if (tokenType == "bearer")
tokenType = "Bearer";

using (Stream s = client.SendAsync(new HttpRequestMessage {
RequestUri = new("https://discord.com/api/users/@me"),
RequestUri = new(oauthProvider.ServiceUserAPI),
Method = HttpMethod.Get,
Headers = {
{ "Authorization", $"{tokenType} {token}" }
}
}).Await().Content.ReadAsStreamAsync().Await())
using (StreamReader sr = new(s))
using (JsonTextReader jtr = new(sr))
userData = f.Serializer.Deserialize<dynamic>(jtr);
userData = JObject.Parse(sr.ReadToEnd());
}

if (!(userData?.id?.ToString() is string uid) ||
if (!((string?)userData?.SelectTokens(oauthProvider.ServiceUserJsonPathUid).FirstOrDefault() is string uid) ||
uid.IsNullOrEmpty()) {
Logger.Log(LogLevel.CRI, "frontend-discordauth", $"Failed to obtain ID: {userData}");
Logger.Log(LogLevel.CRI, "frontend-oauth", $"Failed to obtain ID: {userData}");
c.Response.StatusCode = (int) HttpStatusCode.InternalServerError;
f.RespondJSON(c, new {
Error = "Couldn't obtain user ID from Discord."
Error = $"Couldn't obtain user ID from OAuth provider {provider}."
});
return;
}

string key = f.Server.UserData.Create(uid, false);
BasicUserInfo info = f.Server.UserData.Load<BasicUserInfo>(uid);

if (userData.global_name?.ToString() is string global_name && !global_name.IsNullOrEmpty()) {
if ((string?)userData?.SelectTokens(oauthProvider.ServiceUserJsonPathName).FirstOrDefault() is string global_name && !global_name.IsNullOrEmpty()) {
info.Name = global_name;
} else {
info.Name = userData.username.ToString();
info.Name = "";
}
if (info.Name.Length > 32) {
info.Name = info.Name.Substring(0, 32);
}
info.Discrim = userData.discriminator.ToString();
info.Discrim = "";
f.Server.UserData.Save(uid, info);

string? pfpFragment = (string?)userData?.SelectTokens(oauthProvider.ServiceUserJsonPathPfp).FirstOrDefault();

Image avatarOrig;
using (HttpClient client = new()) {
string avatarURL = string.Format(oauthProvider.ServiceUserAvatarURL, new object[] { uid, pfpFragment ?? "" });
try {
using Stream s = client.GetAsync(
$"https://cdn.discordapp.com/avatars/{uid}/{userData.avatar.ToString()}.png?size=64"
).Await().Content.ReadAsStreamAsync().Await();
using Stream s = client.GetAsync(avatarURL).Await().Content.ReadAsStreamAsync().Await();
avatarOrig = Image.Load<Rgba32>(s);
} catch {
using Stream s = client.GetAsync(
$"https://cdn.discordapp.com/embed/avatars/{((int) userData.discriminator) % 6}.png"
).Await().Content.ReadAsStreamAsync().Await();
avatarURL = oauthProvider.ServiceUserAvatarDefaultURL;
using Stream s = client.GetAsync(avatarURL).Await().Content.ReadAsStreamAsync().Await();
avatarOrig = Image.Load(s);
}
}
Expand Down