diff --git a/python/coinbase-agentkit/coinbase_agentkit/__init__.py b/python/coinbase-agentkit/coinbase_agentkit/__init__.py index a9b3671fb..f377512b3 100644 --- a/python/coinbase-agentkit/coinbase_agentkit/__init__.py +++ b/python/coinbase-agentkit/coinbase_agentkit/__init__.py @@ -6,6 +6,7 @@ cdp_api_action_provider, cdp_wallet_action_provider, create_action, + erc20_action_provider, morpho_action_provider, pyth_action_provider, twitter_action_provider, @@ -36,6 +37,7 @@ "EvmWalletProvider", "EthAccountWalletProvider", "EthAccountWalletProviderConfig", + "erc20_action_provider", "cdp_api_action_provider", "cdp_wallet_action_provider", "morpho_action_provider", diff --git a/python/coinbase-agentkit/coinbase_agentkit/action_providers/__init__.py b/python/coinbase-agentkit/coinbase_agentkit/action_providers/__init__.py index f7cf4b009..309498483 100644 --- a/python/coinbase-agentkit/coinbase_agentkit/action_providers/__init__.py +++ b/python/coinbase-agentkit/coinbase_agentkit/action_providers/__init__.py @@ -2,6 +2,7 @@ from .action_provider import Action, ActionProvider from .cdp.cdp_api_action_provider import CdpApiActionProvider, cdp_api_action_provider from .cdp.cdp_wallet_action_provider import CdpWalletActionProvider, cdp_wallet_action_provider +from .erc20.erc20_action_provider import ERC20ActionProvider, erc20_action_provider from .morpho.morpho_action_provider import MorphoActionProvider, morpho_action_provider from .pyth.pyth_action_provider import PythActionProvider, pyth_action_provider from .twitter.twitter_action_provider import TwitterActionProvider, twitter_action_provider @@ -16,6 +17,8 @@ "cdp_api_action_provider", "CdpWalletActionProvider", "cdp_wallet_action_provider", + "ERC20ActionProvider", + "erc20_action_provider", "MorphoActionProvider", "morpho_action_provider", "PythActionProvider", diff --git a/python/coinbase-agentkit/coinbase_agentkit/action_providers/erc20/__init__.py b/python/coinbase-agentkit/coinbase_agentkit/action_providers/erc20/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/python/coinbase-agentkit/coinbase_agentkit/action_providers/erc20/constants.py b/python/coinbase-agentkit/coinbase_agentkit/action_providers/erc20/constants.py new file mode 100644 index 000000000..08d042ad9 --- /dev/null +++ b/python/coinbase-agentkit/coinbase_agentkit/action_providers/erc20/constants.py @@ -0,0 +1,40 @@ +"""Constants for the ERC20 action provider.""" + +ERC20_ABI = [ + { + "type": "function", + "name": "balanceOf", + "stateMutability": "view", + "inputs": [ + { + "name": "account", + "type": "address", + }, + ], + "outputs": [ + { + "type": "uint256", + }, + ], + }, + { + "type": "function", + "name": "transfer", + "stateMutability": "nonpayable", + "inputs": [ + { + "name": "recipient", + "type": "address", + }, + { + "name": "amount", + "type": "uint256", + }, + ], + "outputs": [ + { + "type": "bool", + }, + ], + }, +] diff --git a/python/coinbase-agentkit/coinbase_agentkit/action_providers/erc20/erc20_action_provider.py b/python/coinbase-agentkit/coinbase_agentkit/action_providers/erc20/erc20_action_provider.py new file mode 100644 index 000000000..48bf4ea93 --- /dev/null +++ b/python/coinbase-agentkit/coinbase_agentkit/action_providers/erc20/erc20_action_provider.py @@ -0,0 +1,124 @@ +"""ERC20 action provider implementation.""" + +from web3 import Web3 + +from ...network import Network +from ...wallet_providers import EvmWalletProvider +from ..action_decorator import create_action +from ..action_provider import ActionProvider +from .constants import ERC20_ABI +from .schemas import GetBalanceSchema, TransferSchema + + +class ERC20ActionProvider(ActionProvider[EvmWalletProvider]): + """Action provider for ERC20 tokens.""" + + def __init__(self) -> None: + """Initialize the ERC20 action provider.""" + super().__init__("erc20", []) + + @create_action( + name="get_balance", + description=""" + This tool will get the balance of an ERC20 asset in the wallet. It takes the contract address as input. + """, + schema=GetBalanceSchema, + ) + def get_balance(self, wallet_provider: EvmWalletProvider, args: GetBalanceSchema) -> str: + """Get the balance of an ERC20 token. + + Args: + wallet_provider: The wallet provider to get the balance from. + args: The input arguments for the action. + + Returns: + A message containing the balance. + + """ + try: + validated_args = GetBalanceSchema(**args) + + balance = wallet_provider.read_contract( + contract_address=validated_args.contract_address, + abi=ERC20_ABI, + function_name="balanceOf", + args=[wallet_provider.get_address()], + ) + + return f"Balance of {validated_args.contract_address} is {balance}" + except Exception as e: + return f"Error getting balance: {e!s}" + + @create_action( + name="transfer", + description=""" + This tool will transfer an ERC20 token from the wallet to another onchain address. + + It takes the following inputs: + - amount: The amount to transfer + - contract_address: The contract address of the token to transfer + - destination: Where to send the tokens + + Important notes: + - Ensure sufficient balance of the input asset before transferring + - When sending native assets (e.g. 'eth' on base-mainnet), ensure there is sufficient balance for the transfer itself AND the gas cost of this transfer + """, + schema=TransferSchema, + ) + def transfer(self, wallet_provider: EvmWalletProvider, args: TransferSchema) -> str: + """Transfer a specified amount of an ERC20 token to a destination onchain. + + Args: + wallet_provider: The wallet provider to transfer the asset from. + args: The input arguments for the action. + + Returns: + A message containing the transfer details. + + """ + try: + validated_args = TransferSchema(**args) + + contract = Web3().eth.contract(address=validated_args.contract_address, abi=ERC20_ABI) + data = contract.encode_abi( + "transfer", [validated_args.destination, int(validated_args.amount)] + ) + + tx_hash = wallet_provider.send_transaction( + { + "to": validated_args.contract_address, + "data": data, + } + ) + + wallet_provider.wait_for_transaction_receipt(tx_hash) + + return ( + f"Transferred {validated_args.amount} of {validated_args.contract_address} " + f"to {validated_args.destination}.\n" + f"Transaction hash for the transfer: {tx_hash}" + ) + except Exception as e: + return f"Error transferring the asset: {e!s}" + + def supports_network(self, network: Network) -> bool: + """Check if the ERC20 action provider supports the given network. + + Args: + network: The network to check. + + Returns: + True if the ERC20 action provider supports the network, false otherwise. + + """ + return network.protocol_family == "evm" + + +def erc20_action_provider() -> ERC20ActionProvider: + """Create a new instance of the ERC20 action provider. + + Returns: + A new ERC20 action provider instance. + + """ + return ERC20ActionProvider() diff --git a/python/coinbase-agentkit/coinbase_agentkit/action_providers/erc20/schemas.py b/python/coinbase-agentkit/coinbase_agentkit/action_providers/erc20/schemas.py new file mode 100644 index 000000000..9199f4bc4 --- /dev/null +++ b/python/coinbase-agentkit/coinbase_agentkit/action_providers/erc20/schemas.py @@ -0,0 +1,28 @@ +"""Schemas for the ERC20 action provider.""" + +from pydantic import BaseModel, Field, field_validator + +from .validators import wei_amount_validator + + +class GetBalanceSchema(BaseModel): + """Schema for getting the balance of an ERC20 token.""" + + contract_address: str = Field( + ..., + description="The contract address of the token to get the balance for", + ) + + +class TransferSchema(BaseModel): + """Schema for transferring ERC20 tokens.""" + + amount: str = Field(description="The amount of the asset to transfer in wei") + contract_address: str = Field(description="The contract address of the token to transfer") + destination: str = Field(description="The destination to transfer the funds") + + @field_validator("amount") + @classmethod + def validate_wei_amount(cls, v: str) -> str: + """Validate wei amount.""" + return wei_amount_validator(v) diff --git a/python/coinbase-agentkit/coinbase_agentkit/action_providers/erc20/utils.py b/python/coinbase-agentkit/coinbase_agentkit/action_providers/erc20/utils.py new file mode 100644 index 000000000..e69de29bb diff --git a/python/coinbase-agentkit/coinbase_agentkit/action_providers/erc20/validators.py b/python/coinbase-agentkit/coinbase_agentkit/action_providers/erc20/validators.py new file mode 100644 index 000000000..784128ee3 --- /dev/null +++ b/python/coinbase-agentkit/coinbase_agentkit/action_providers/erc20/validators.py @@ -0,0 +1,24 @@ +"""Validators for ERC20 action inputs.""" + +import re + +from pydantic_core import PydanticCustomError + + +def wei_amount_validator(value: str) -> str: + """Validate that amount is a valid wei value (positive whole number as string).""" + if not re.match(r"^[0-9]+$", value): + raise PydanticCustomError( + "wei_format", + "Amount must be a positive whole number as a string", + {"value": value}, + ) + + if int(value) <= 0: + raise PydanticCustomError( + "positive_wei", + "Amount must be greater than 0", + {"value": value}, + ) + + return value diff --git a/python/coinbase-agentkit/tests/action_providers/erc20/__init__.py b/python/coinbase-agentkit/tests/action_providers/erc20/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/python/coinbase-agentkit/tests/action_providers/erc20/conftest.py b/python/coinbase-agentkit/tests/action_providers/erc20/conftest.py new file mode 100644 index 000000000..edb9eb5a8 --- /dev/null +++ b/python/coinbase-agentkit/tests/action_providers/erc20/conftest.py @@ -0,0 +1,21 @@ +"""Test fixtures for ERC20 action provider tests.""" + +from unittest.mock import Mock + +import pytest + +from coinbase_agentkit.wallet_providers.evm_wallet_provider import EvmWalletProvider + +MOCK_AMOUNT = "1000000000000000000" +MOCK_CONTRACT_ADDRESS = "0x1234567890123456789012345678901234567890" +MOCK_DESTINATION = "0x9876543210987654321098765432109876543210" +MOCK_ADDRESS = "0x1234567890123456789012345678901234567890" + + +@pytest.fixture +def mock_wallet(): + """Create a mock wallet provider.""" + mock = Mock(spec=EvmWalletProvider) + mock.get_address.return_value = MOCK_ADDRESS + mock.read_contract.return_value = MOCK_AMOUNT + return mock diff --git a/python/coinbase-agentkit/tests/action_providers/erc20/test_erc20_action_provider.py b/python/coinbase-agentkit/tests/action_providers/erc20/test_erc20_action_provider.py new file mode 100644 index 000000000..b39679075 --- /dev/null +++ b/python/coinbase-agentkit/tests/action_providers/erc20/test_erc20_action_provider.py @@ -0,0 +1,149 @@ +"""Tests for the ERC20 action provider.""" + +import pytest +from web3 import Web3 + +from coinbase_agentkit.action_providers.erc20.constants import ERC20_ABI +from coinbase_agentkit.action_providers.erc20.erc20_action_provider import ( + erc20_action_provider, +) +from coinbase_agentkit.action_providers.erc20.schemas import GetBalanceSchema, TransferSchema +from coinbase_agentkit.network import Network + +from .conftest import ( + MOCK_AMOUNT, + MOCK_CONTRACT_ADDRESS, + MOCK_DESTINATION, +) + + +def test_get_balance_schema_valid(): + """Test that the GetBalanceSchema validates correctly.""" + valid_input = {"contract_address": MOCK_CONTRACT_ADDRESS} + schema = GetBalanceSchema(**valid_input) + assert schema.contract_address == MOCK_CONTRACT_ADDRESS + + +def test_get_balance_schema_invalid(): + """Test that the GetBalanceSchema fails on invalid input.""" + with pytest.raises(ValueError): + GetBalanceSchema() + + +def test_get_balance_success(mock_wallet): + """Test successful get_balance call.""" + args = {"contract_address": MOCK_CONTRACT_ADDRESS} + provider = erc20_action_provider() + + response = provider.get_balance(mock_wallet, args) + + mock_wallet.read_contract.assert_called_once_with( + contract_address=MOCK_CONTRACT_ADDRESS, + abi=ERC20_ABI, + function_name="balanceOf", + args=[mock_wallet.get_address()], + ) + assert f"Balance of {MOCK_CONTRACT_ADDRESS} is {MOCK_AMOUNT}" in response + + +def test_get_balance_error(mock_wallet): + """Test get_balance with error.""" + args = {"contract_address": MOCK_CONTRACT_ADDRESS} + error = Exception("Failed to get balance") + mock_wallet.read_contract.side_effect = error + provider = erc20_action_provider() + + response = provider.get_balance(mock_wallet, args) + + mock_wallet.read_contract.assert_called_once_with( + contract_address=MOCK_CONTRACT_ADDRESS, + abi=ERC20_ABI, + function_name="balanceOf", + args=[mock_wallet.get_address()], + ) + assert f"Error getting balance: {error!s}" in response + + +def test_transfer_schema_valid(): + """Test that the TransferSchema validates correctly.""" + valid_input = { + "amount": MOCK_AMOUNT, + "contract_address": MOCK_CONTRACT_ADDRESS, + "destination": MOCK_DESTINATION, + } + schema = TransferSchema(**valid_input) + assert schema.amount == MOCK_AMOUNT + assert schema.contract_address == MOCK_CONTRACT_ADDRESS + assert schema.destination == MOCK_DESTINATION + + +def test_transfer_schema_invalid(): + """Test that the TransferSchema fails on invalid input.""" + with pytest.raises(ValueError): + TransferSchema() + + +def test_transfer_success(mock_wallet): + """Test successful transfer call.""" + args = { + "amount": MOCK_AMOUNT, + "contract_address": MOCK_CONTRACT_ADDRESS, + "destination": MOCK_DESTINATION, + } + provider = erc20_action_provider() + + mock_tx_hash = "0xghijkl987654321" + mock_wallet.send_transaction.return_value = mock_tx_hash + + response = provider.transfer(mock_wallet, args) + + contract = Web3().eth.contract(address=MOCK_CONTRACT_ADDRESS, abi=ERC20_ABI) + expected_data = contract.encode_abi("transfer", [MOCK_DESTINATION, int(MOCK_AMOUNT)]) + + mock_wallet.send_transaction.assert_called_once_with( + { + "to": MOCK_CONTRACT_ADDRESS, + "data": expected_data, + } + ) + mock_wallet.wait_for_transaction_receipt.assert_called_once_with(mock_tx_hash) + assert f"Transferred {MOCK_AMOUNT} of {MOCK_CONTRACT_ADDRESS} to {MOCK_DESTINATION}" in response + assert f"Transaction hash for the transfer: {mock_tx_hash}" in response + + +def test_transfer_error(mock_wallet): + """Test transfer with error.""" + args = { + "amount": MOCK_AMOUNT, + "contract_address": MOCK_CONTRACT_ADDRESS, + "destination": MOCK_DESTINATION, + } + error = Exception("Failed to execute transfer") + mock_wallet.send_transaction.side_effect = error + provider = erc20_action_provider() + + response = provider.transfer(mock_wallet, args) + + contract = Web3().eth.contract(address=MOCK_CONTRACT_ADDRESS, abi=ERC20_ABI) + expected_data = contract.encode_abi("transfer", [MOCK_DESTINATION, int(MOCK_AMOUNT)]) + + mock_wallet.send_transaction.assert_called_once_with( + { + "to": MOCK_CONTRACT_ADDRESS, + "data": expected_data, + } + ) + assert f"Error transferring the asset: {error!s}" in response + + +def test_supports_network(): + """Test network support based on protocol family.""" + test_cases = [ + ("evm", True), + ("solana", False), + ] + + provider = erc20_action_provider() + for protocol_family, expected in test_cases: + network = Network(chain_id=1, protocol_family=protocol_family) + assert provider.supports_network(network) is expected