Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Bank] Add discord.Object support to bank.get_balance #4654

Open
wants to merge 3 commits into
base: V3/develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 38 additions & 28 deletions redbot/core/bank.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ def _decode_time(time: int) -> datetime:
return datetime.utcfromtimestamp(time)


async def get_balance(member: discord.Member) -> int:
async def get_balance(member: Union[discord.Member, discord.User, discord.Object]) -> int:
"""Get the current balance of a member.

Parameters
Expand All @@ -221,12 +221,14 @@ async def get_balance(member: discord.Member) -> int:
return acc.balance


async def can_spend(member: discord.Member, amount: int) -> bool:
async def can_spend(
member: Union[discord.Member, discord.User, discord.Object], amount: int
) -> bool:
"""Determine if a member can spend the given amount.

Parameters
----------
member : discord.Member
member : Union[discord.Member, discord.User, discord.Object]
The member wanting to spend.
amount : int
The amount the member wants to spend.
Expand All @@ -250,12 +252,14 @@ async def can_spend(member: discord.Member, amount: int) -> bool:
return await get_balance(member) >= amount


async def set_balance(member: Union[discord.Member, discord.User], amount: int) -> int:
async def set_balance(
member: Union[discord.Member, discord.User, discord.Object], amount: int
) -> int:
"""Set an account balance.

Parameters
----------
member : Union[discord.Member, discord.User]
member : Union[discord.Member, discord.User, discord.Object]
The member whose balance to set.
amount : int
The amount to set the balance to.
Expand All @@ -281,15 +285,17 @@ async def set_balance(member: Union[discord.Member, discord.User], amount: int)
raise TypeError("Amount must be of type int, not {}.".format(type(amount)))
if amount < 0:
raise ValueError("Not allowed to have negative balance.")

guild = getattr(member, "guild", None)
display_name = getattr(member, "display_name", "John Doe")
max_bal = await get_max_balance(guild)

if amount > max_bal:
currency = await get_currency_name(guild)
raise errors.BalanceTooHigh(
user=member.display_name, max_balance=max_bal, currency_name=currency
)
raise errors.BalanceTooHigh(user=display_name, max_balance=max_bal, currency_name=currency)

if await is_global():
group = _config.user(member)
group = _config.user_from_id(member.id)
else:
group = _config.member(member)
await group.balance.set(amount)
Expand All @@ -299,7 +305,7 @@ async def set_balance(member: Union[discord.Member, discord.User], amount: int)
await group.created_at.set(time)

if await group.name() == "":
await group.name.set(member.display_name)
await group.name.set(display_name)

return amount

Expand All @@ -308,12 +314,14 @@ def _invalid_amount(amount: int) -> bool:
return amount < 0


async def withdraw_credits(member: discord.Member, amount: int) -> int:
async def withdraw_credits(
member: Union[discord.Member, discord.User, discord.Object], amount: int
) -> int:
"""Remove a certain amount of credits from an account.

Parameters
----------
member : discord.Member
member : Union[discord.Member, discord.User, discord.Object]
The member to withdraw credits from.
amount : int
The amount to withdraw.
Expand Down Expand Up @@ -353,12 +361,14 @@ async def withdraw_credits(member: discord.Member, amount: int) -> int:
return await set_balance(member, bal - amount)


async def deposit_credits(member: discord.Member, amount: int) -> int:
async def deposit_credits(
member: Union[discord.Member, discord.User, discord.Object], amount: int
) -> int:
"""Add a given amount of credits to an account.

Parameters
----------
member : discord.Member
member : Union[discord.Member, discord.User, discord.Object]
The member to deposit credits to.
amount : int
The amount to deposit.
Expand Down Expand Up @@ -390,17 +400,17 @@ async def deposit_credits(member: discord.Member, amount: int) -> int:


async def transfer_credits(
from_: Union[discord.Member, discord.User],
to: Union[discord.Member, discord.User],
from_: Union[discord.Member, discord.User, discord.Object],
to: Union[discord.Member, discord.User, discord.Object],
amount: int,
):
"""Transfer a given amount of credits from one account to another.

Parameters
----------
from_: Union[discord.Member, discord.User]
from_: Union[discord.Member, discord.User, discord.Object]
The member to transfer from.
to : Union[discord.Member, discord.User]
to : Union[discord.Member, discord.User, discord.Object]
The member to transfer to.
amount : int
The amount to transfer.
Expand Down Expand Up @@ -430,14 +440,14 @@ async def transfer_credits(
humanize_number(amount, override_locale="en_US")
)
)

guild = getattr(to, "guild", None)
display_name = getattr(to, "display_name", "John Doe")
max_bal = await get_max_balance(guild)

if await get_balance(to) + amount > max_bal:
currency = await get_currency_name(guild)
raise errors.BalanceTooHigh(
user=to.display_name, max_balance=max_bal, currency_name=currency
)
raise errors.BalanceTooHigh(user=display_name, max_balance=max_bal, currency_name=currency)

await withdraw_credits(from_, amount)
return await deposit_credits(to, amount)
Expand Down Expand Up @@ -599,14 +609,14 @@ async def get_leaderboard_position(
return pos[0]


async def get_account(member: Union[discord.Member, discord.User]) -> Account:
async def get_account(member: Union[discord.Member, discord.User, discord.Object]) -> Account:
"""Get the appropriate account for the given user or member.

A member is required if the bank is currently guild specific.

Parameters
----------
member : `discord.User` or `discord.Member`
member : Union[discord.Member, discord.User, discord.Object]
The user whose account to get.

Returns
Expand All @@ -620,12 +630,12 @@ async def get_account(member: Union[discord.Member, discord.User]) -> Account:
else:
all_accounts = await _config.all_members(member.guild)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This line will and should fail error with discord.Object. Even if it were to succeed with Object, get_default_balance below would throw a runtime error


guild = getattr(member, "guild", None)
display_name = getattr(member, "display_name", "John Doe")

if member.id not in all_accounts:
acc_data = {"name": member.display_name, "created_at": _DEFAULT_MEMBER["created_at"]}
try:
acc_data["balance"] = await get_default_balance(member.guild)
except AttributeError:
acc_data["balance"] = await get_default_balance()
acc_data = {"name": display_name, "created_at": _DEFAULT_MEMBER["created_at"]}
acc_data["balance"] = await get_default_balance(guild)
else:
acc_data = all_accounts[member.id]

Expand Down
12 changes: 12 additions & 0 deletions redbot/pytest/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
"empty_message",
"empty_role",
"empty_user",
"object_as_member_factory",
"member_factory",
"user_factory",
"ctx",
Expand Down Expand Up @@ -136,6 +137,17 @@ def empty_user(user_factory):
return user_factory.get()


@pytest.fixture()
def object_as_member_factory(guild_factory):
mock_member = namedtuple("Object", "id guild display_name")

class ObjectAsMemberFactory:
def get(self):
return mock_member(random.randint(1, 999999999), guild_factory.get(), "Testing_Name")

return ObjectAsMemberFactory()


@pytest.fixture(scope="module")
def empty_message():
mock_msg = namedtuple("Message", "content")
Expand Down
29 changes: 29 additions & 0 deletions tests/cogs/test_economy.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,27 @@ async def test_bank_transfer(bank, member_factory):
assert bal2 + 50 == newbal2


@pytest.mark.asyncio
async def test_bank_transfer_from_objects(bank, object_as_member_factory):
mbr1 = object_as_member_factory.get()
mbr2 = object_as_member_factory.get()
bal1 = (await bank.get_account(mbr1)).balance
bal2 = (await bank.get_account(mbr2)).balance
await bank.transfer_credits(mbr1, mbr2, 50)
newbal1 = (await bank.get_account(mbr1)).balance
newbal2 = (await bank.get_account(mbr2)).balance
assert bal1 - 50 == newbal1
assert bal2 + 50 == newbal2


@pytest.mark.asyncio
async def test_bank_get_from_object(bank, object_as_member_factory):
mbr = object_as_member_factory.get()
await bank.set_balance(mbr, 250)
acc = await bank.get_account(mbr)
assert acc.balance == 250


@pytest.mark.asyncio
async def test_bank_set(bank, member_factory):
mbr = member_factory.get()
Expand All @@ -36,6 +57,14 @@ async def test_bank_set(bank, member_factory):
assert acc.balance == 250


@pytest.mark.asyncio
async def test_bank_set_from_object(bank, object_as_member_factory):
mbr = object_as_member_factory.get()
await bank.set_balance(mbr, 250)
acc = await bank.get_account(mbr)
assert acc.balance == 250


@pytest.mark.asyncio
async def test_bank_can_spend(bank, member_factory):
mbr = member_factory.get()
Expand Down