Skip to content

Commit

Permalink
Parameterize oauth urls
Browse files Browse the repository at this point in the history
  • Loading branch information
tygern committed Jun 12, 2024
1 parent 8d72ba0 commit ee8ac76
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 33 deletions.
3 changes: 2 additions & 1 deletion negotiator/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ def create_app(env: Environment = Environment.from_env()) -> Flask:
message_gateway = MessageGateway(db_template)
negotiation_service = NegotiationService(db_template, negotiation_gateway, message_gateway)

oauth_client = OAuthClient(client_id=env.client_id, client_secret=env.client_secret, host_url=env.host_url)
oauth_client = OAuthClient(client_id=env.client_id, client_secret=env.client_secret, host_url=env.host_url,
oauth_url=env.oauth_url, user_info_url=env.user_info_url)
allowed_emails = AllowedEmails(domains=env.allowed_domains, addresses=env.allowed_addresses)

app.register_blueprint(index_page())
Expand Down
4 changes: 4 additions & 0 deletions negotiator/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ class Environment:
openai_api_key: str
client_id: str
client_secret: str
oauth_url: str
user_info_url: str
host_url: str
allowed_domains: str
allowed_addresses: str
Expand All @@ -24,6 +26,8 @@ def from_env(cls) -> 'Environment':
openai_api_key=cls.__require_env('OPENAI_API_KEY'),
client_id=cls.__require_env('CLIENT_ID'),
client_secret=cls.__require_env('CLIENT_SECRET'),
oauth_url=os.environ.get('OAUTH_URL', 'https://accounts.google.com/o/oauth2'),
user_info_url=os.environ.get('USER_INFO_URL', 'https://www.googleapis.com/oauth2/v3/userinfo'),
host_url=cls.__require_env('HOST_URL'),
allowed_domains=os.environ.get('ALLOWED_DOMAINS', ""),
allowed_addresses=os.environ.get('ALLOWED_ADDRESSES', ""),
Expand Down
13 changes: 7 additions & 6 deletions negotiator/oauth/oauth_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,12 @@


class OAuthClient:
def __init__(self, client_id: str, client_secret: str, host_url: str):
def __init__(self, client_id: str, client_secret: str, host_url: str, oauth_url: str, user_info_url: str):
self.__client_id = client_id
self.__client_secret = client_secret
self.__host_url = host_url
self.__oauth_url = oauth_url
self.__user_info_url = user_info_url

def auth_url(self, state) -> str:
query_string = urlencode({
Expand All @@ -18,10 +20,10 @@ def auth_url(self, state) -> str:
'scope': 'email',
'state': state,
})
return f"https://accounts.google.com/o/oauth2/auth?{query_string}"
return f"{self.__oauth_url}/auth?{query_string}"

def fetch_access_token(self, code: str) -> Union[None, str]:
response = requests.post('https://accounts.google.com/o/oauth2/token', data={
response = requests.post(f'{self.__oauth_url}/token', data={
'client_id': self.__client_id,
'client_secret': self.__client_secret,
'code': code,
Expand All @@ -34,9 +36,8 @@ def fetch_access_token(self, code: str) -> Union[None, str]:

return response.json().get('access_token')

@staticmethod
def read_email_from_token(token: str) -> Union[None, str]:
response = requests.get('https://www.googleapis.com/oauth2/v3/userinfo', headers={
def read_email_from_token(self, token: str) -> Union[None, str]:
response = requests.get(self.__user_info_url, headers={
'Authorization': f'Bearer {token}',
'Accept': 'application/json',
})
Expand Down
45 changes: 19 additions & 26 deletions tests/oauth/test_oauth_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,52 +6,48 @@


class TestOAuthClient(TestCase):
def setUp(self):
self.client = OAuthClient(client_id='some_client_id', client_secret='some_client_secret',
host_url='https://host.example.com', oauth_url='https://oauth.example.com',
user_info_url='https://userinfo.example.com')

def test_auth_url(self):
client = OAuthClient('some_client_id', 'some_client_secret', 'https://example.com')
auth_url = client.auth_url('some_state')
auth_url = self.client.auth_url('some_state')
self.assertEqual(
'https://accounts.google.com/o/oauth2/auth?client_id=some_client_id&redirect_uri=https%3A%2F%2Fexample.com%2Foauth%2Fcallback&response_type=code&scope=email&state=some_state',
'https://oauth.example.com/auth?client_id=some_client_id'
'&redirect_uri=https%3A%2F%2Fhost.example.com%2Foauth%2Fcallback&response_type=code&scope=email'
'&state=some_state',
auth_url
)

@responses.activate
def test_fetch_access_token(self):
client = OAuthClient('some_client_id', 'some_client_secret', 'https://example.com')
token_endpoint = responses.post(
'https://accounts.google.com/o/oauth2/token',
json={'access_token': 'some_access_token'}
)
token_endpoint = responses.post('https://oauth.example.com/token', json={'access_token': 'some_access_token'})

access_token = client.fetch_access_token('some_code')
access_token = self.client.fetch_access_token('some_code')

self.assertEqual('some_access_token', access_token)
self.assertEqual(1, token_endpoint.call_count)
recorded_request = responses.calls[0].request
self.assertEqual(
"client_id=some_client_id&client_secret=some_client_secret&code=some_code&grant_type=authorization_code&redirect_uri=https%3A%2F%2Fexample.com%2Foauth%2Fcallback",
"client_id=some_client_id&client_secret=some_client_secret&code=some_code&grant_type=authorization_code"
"&redirect_uri=https%3A%2F%2Fhost.example.com%2Foauth%2Fcallback",
recorded_request.body
)

@responses.activate
def test_fetch_access_token_bad_request(self):
client = OAuthClient('some_client_id', 'some_client_secret', 'https://example.com')
responses.post(
'https://accounts.google.com/o/oauth2/token',
status=400
)
responses.post('https://oauth.example.com/token', status=400)

access_token = client.fetch_access_token('some_code')
access_token = self.client.fetch_access_token('some_code')

self.assertIsNone(access_token)

@responses.activate
def test_read_email_from_token(self):
user_info_endpoint = responses.get(
'https://www.googleapis.com/oauth2/v3/userinfo',
json={'email': '[email protected]'}
)
user_info_endpoint = responses.get('https://userinfo.example.com', json={'email': '[email protected]'})

email = OAuthClient.read_email_from_token('some_token')
email = self.client.read_email_from_token('some_token')

self.assertEqual('[email protected]', email)
self.assertEqual(1, user_info_endpoint.call_count)
Expand All @@ -63,11 +59,8 @@ def test_read_email_from_token(self):

@responses.activate
def test_read_email_from_token_bad_request(self):
responses.get(
'https://www.googleapis.com/oauth2/v3/userinfo',
status=400
)
responses.get('https://userinfo.example.com', status=400)

email = OAuthClient.read_email_from_token('some_token')
email = self.client.read_email_from_token('some_token')

self.assertIsNone(email)

0 comments on commit ee8ac76

Please sign in to comment.