Skip to content

Commit

Permalink
Add deploy() method, make constructor type safe. (#80)
Browse files Browse the repository at this point in the history
  • Loading branch information
sentilesdal authored Dec 15, 2023
1 parent e800dd1 commit f64ed38
Show file tree
Hide file tree
Showing 35 changed files with 1,903 additions and 155 deletions.
2 changes: 1 addition & 1 deletion conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def w3_init(local_chain: str) -> Web3:


@pytest.fixture(scope="function")
def w3(w3_init: Web3) -> Iterator[Web3]: # type: ignore
def w3(w3_init: Web3) -> Iterator[Web3]:
"""resets the anvil instance at the function level so each test gets a fresh chain.
Parameters
Expand Down
17 changes: 14 additions & 3 deletions example/abis/Example.json

Large diffs are not rendered by default.

8 changes: 7 additions & 1 deletion example/contracts/Example.sol
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ pragma solidity ^0.8.0;

contract Example {

string contractName;

struct SimpleStruct {
uint intVal;
string strVal;
Expand All @@ -18,6 +20,10 @@ contract Example {
InnerStruct innerStruct;
}

constructor(string memory name) {
contractName = name;
}

function flipFlop(uint flip, uint flop) public pure returns (uint _flop, uint _flip) {
return (flop,flip);
}
Expand Down Expand Up @@ -77,7 +83,7 @@ contract Example {
});
}

function mixStructsAndPrimitives() public pure returns (SimpleStruct memory simpleStruct, NestedStruct memory, uint, string memory name, bool YesOrNo) {
function mixStructsAndPrimitives() public pure returns (SimpleStruct memory simpleStruct, NestedStruct memory, uint, string memory _name, bool YesOrNo) {
simpleStruct = SimpleStruct({
intVal: 1,
strVal: "You are number 1"
Expand Down
14 changes: 10 additions & 4 deletions example/test_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def test_using_vanilla_web3py(self, w3: Web3):
ExampleContract = w3.eth.contract(abi=abi, bytecode=bytecode)

# Submit the transaction that deploys the contract
tx_hash = ExampleContract.constructor().transact()
tx_hash = ExampleContract.constructor("example").transact()
# Wait for the transaction to be mined, and get the transaction receipt
tx_receipt = w3.eth.wait_for_transaction_receipt(tx_hash)

Expand All @@ -57,7 +57,9 @@ def test_using_vanilla_web3py(self, w3: Web3):
def test_flip_flop(self, w3):
"""Tests single value"""

deployed_contract = ExampleContract.deploy(w3=w3, signer=w3.eth.accounts[0])
deployed_contract = ExampleContract.deploy(
w3=w3, account=w3.eth.accounts[0], constructorArgs=ExampleContract.ConstructorArgs("example")
)

flip = 1
flop = 2
Expand All @@ -70,7 +72,9 @@ def test_flip_flop(self, w3):
def test_simple_structs(self, w3):
"""Tests single value"""

deployed_contract = ExampleContract.deploy(w3=w3, signer=w3.eth.accounts[0])
deployed_contract = ExampleContract.deploy(
w3=w3, account=w3.eth.accounts[0], constructorArgs=ExampleContract.ConstructorArgs("example")
)

input_struct: SimpleStruct = SimpleStruct(1, "string")
output_struct = deployed_contract.functions.singleSimpleStruct(input_struct).call()
Expand All @@ -81,7 +85,9 @@ def test_simple_structs(self, w3):
def test_nested_structs(self, w3):
"""Tests single value"""

deployed_contract = ExampleContract.deploy(w3=w3, signer=w3.eth.accounts[0])
deployed_contract = ExampleContract.deploy(
w3=w3, account=w3.eth.accounts[0], constructorArgs=ExampleContract.ConstructorArgs("example")
)

input_struct = NestedStruct(1, "string", InnerStruct(True))
output_struct: NestedStruct = deployed_contract.functions.singleNestedStruct(input_struct).call()
Expand Down
83 changes: 76 additions & 7 deletions example/types/ExampleContract.py

Large diffs are not rendered by default.

90 changes: 46 additions & 44 deletions pypechain/render/contract.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from web3.types import ABI

from pypechain.utilities.abi import (
get_abi_constructor,
get_abi_items,
get_input_names,
get_input_names_and_types,
Expand All @@ -16,7 +17,6 @@
get_output_names_and_types,
get_output_types,
get_structs_for_abi,
is_abi_constructor,
is_abi_event,
is_abi_function,
load_abi_from_file,
Expand Down Expand Up @@ -83,6 +83,7 @@ def render_contract_file(contract_name: str, abi_file_path: Path) -> str:
has_bytecode=has_bytecode,
has_events=has_events,
contract_name=contract_name,
constructor=constructor_data,
functions=function_datas,
)

Expand Down Expand Up @@ -202,53 +203,54 @@ def get_function_datas(abi: ABI) -> GetFunctionDatasReturnValue:
A tuple where the first value is a dictionary of FunctionData's keyed by function name and
the second value is SignatureData for the constructor.
"""

# handle constructor
abi_constructor = get_abi_constructor(abi)
constructor_data: SignatureData | None = (
{
"input_names_and_types": get_input_names_and_types(abi_constructor),
"input_names": get_input_names(abi_constructor),
"input_types": get_input_types(abi_constructor),
"outputs": get_output_names(abi_constructor),
"output_types": get_output_names_and_types(abi_constructor),
}
if abi_constructor
else None
)

# handle all other functions
function_datas: dict[str, FunctionData] = {}
constructor_data: SignatureData | None = None
for abi_function in get_abi_items(abi):
if is_abi_function(abi_function):
# handle constructor
if is_abi_constructor(abi_function):
constructor_data = {
"input_names_and_types": get_input_names_and_types(abi_function),
"input_names": get_input_names(abi_function),
"input_types": get_input_types(abi_function),
"outputs": get_output_names(abi_function),
"output_types": get_output_names_and_types(abi_function),
}

# handle all other functions
name = abi_function.get("name", "")
name = re.sub(r"\W|^(?=\d)", "_", name)
signature_data: SignatureData = {
"input_names_and_types": get_input_names_and_types(abi_function),
"input_names": get_input_names(abi_function),
"input_types": get_input_types(abi_function),
"outputs": get_output_names(abi_function),
"output_types": get_output_types(abi_function),
}

function_data: FunctionData = {
"name": name,
"capitalized_name": capitalize_first_letter_only(name),
"signature_datas": [signature_data],
"has_overloading": False,
"has_multiple_return_signatures": False,
"has_multiple_return_values": False,
}
if not function_datas.get(name):
function_datas[name] = function_data
function_datas[name]["has_multiple_return_values"] = get_has_multiple_return_values([signature_data])
else:
name = abi_function.get("name", "")
name = re.sub(r"\W|^(?=\d)", "_", name)
signature_data: SignatureData = {
"input_names_and_types": get_input_names_and_types(abi_function),
"input_names": get_input_names(abi_function),
"input_types": get_input_types(abi_function),
"outputs": get_output_names(abi_function),
"output_types": get_output_types(abi_function),
}

function_data: FunctionData = {
"name": name,
"capitalized_name": capitalize_first_letter_only(name),
"signature_datas": [signature_data],
"has_overloading": False,
"has_multiple_return_signatures": False,
"has_multiple_return_values": False,
}
if not function_datas.get(name):
function_datas[name] = function_data
function_datas[name]["has_multiple_return_values"] = get_has_multiple_return_values(
[signature_data]
)
else:
signature_datas = function_datas[name]["signature_datas"]
signature_datas.append(signature_data)
function_datas[name]["has_overloading"] = len(signature_datas) > 1
function_datas[name]["has_multiple_return_signatures"] = get_has_multiple_return_signatures(
signature_datas
)
function_datas[name]["has_multiple_return_values"] = get_has_multiple_return_values(signature_datas)
signature_datas = function_datas[name]["signature_datas"]
signature_datas.append(signature_data)
function_datas[name]["has_overloading"] = len(signature_datas) > 1
function_datas[name]["has_multiple_return_signatures"] = get_has_multiple_return_signatures(
signature_datas
)
function_datas[name]["has_multiple_return_values"] = get_has_multiple_return_values(signature_datas)
return GetFunctionDatasReturnValue(function_datas, constructor_data)


Expand Down
9 changes: 6 additions & 3 deletions pypechain/render/init.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
"""Functions to render __init__.py from a list of filenames usng a jinja2 template."""
import subprocess
from pathlib import Path

from pypechain.utilities.file import write_string_to_file
from pypechain.utilities.format import apply_black_formatting
from pypechain.utilities.templates import get_jinja_env


Expand All @@ -9,5 +11,6 @@ def render_init_file(output_dir: str, file_names: list[str], line_length):
env = get_jinja_env()
init_template = env.get_template("init.py.jinja2")
init_code = init_template.render(file_names=file_names)
formatted_init_code = apply_black_formatting(init_code, line_length)
write_string_to_file(f"{output_dir}/__init__.py", formatted_init_code)
init_file_path = Path(f"{output_dir}/__init__.py")
write_string_to_file(init_file_path, init_code)
subprocess.run(f"black --line-length={line_length} {init_file_path}", shell=True, check=True)
26 changes: 4 additions & 22 deletions pypechain/render/main.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
"""Functions to render Python files from an abi usng a jinja2 template."""

import os
import subprocess
from pathlib import Path

from pypechain.render.contract import render_contract_file
from pypechain.render.types import render_types_file
from pypechain.utilities.file import write_string_to_file
from pypechain.utilities.format import apply_black_formatting
from pypechain.utilities.format import format_file


def render_files(abi_file_path: str, output_dir: str, line_length: int = 120) -> list[str]:
Expand All @@ -29,30 +28,13 @@ def render_files(abi_file_path: str, output_dir: str, line_length: int = 120) ->

contract_file_path = Path(f"{contract_path}Contract.py")
write_string_to_file(contract_file_path, rendered_contract_code)
format_file(contract_file_path)
format_file(contract_file_path, line_length)
file_names.append(f"{contract_name}Contract")

# TODO: write tests for this conditional write.
if rendered_types_code:
types_file_path = Path(f"{contract_path}Types.py")
formatted_types_code = apply_black_formatting(rendered_types_code, line_length)
write_string_to_file(types_file_path, formatted_types_code)
format_file(types_file_path)
write_string_to_file(types_file_path, rendered_types_code)
format_file(types_file_path, line_length)
file_names.append(f"{contract_name}Types")

return file_names


def format_file(file_path: Path, line_length: int = 120) -> None:
"""Formats a file with isort and black.
Parameters
----------
file_path : Path
The file to be formatted.
line_length : int, optional
Black's line-length config option.
"""

subprocess.run(f"isort {file_path}", shell=True, check=True)
subprocess.run(f"black --line-length={line_length} {file_path}", shell=True, check=True)
18 changes: 11 additions & 7 deletions pypechain/templates/contract.py/base.py.jinja2
Original file line number Diff line number Diff line change
Expand Up @@ -17,22 +17,26 @@ https://github.com/delvtech/pypechain"""
# This file is bound to get very long depending on contract sizes.
# pylint: disable=too-many-lines

# methods are overriden with specific arguments instead of generic *args, **kwargs
# pylint: disable=arguments-differ

from __future__ import annotations

from dataclasses import fields, is_dataclass
from typing import Any, {% if has_multiple_return_values %}NamedTuple, {% endif %}Tuple, Type, TypeVar, cast{% if has_overloading %}, overload{% endif %}
{% if has_events %}from typing import Iterable, Sequence {% endif %}
from typing import Any, NamedTuple, Tuple, Type, TypeVar, cast, overload
from typing import Iterable, Sequence

from eth_typing import ChecksumAddress{% if has_bytecode %}, HexStr{% endif %}
from eth_typing import ChecksumAddress, HexStr
from eth_account.signers.local import LocalAccount
from hexbytes import HexBytes
from typing_extensions import Self
from web3 import Web3
from web3.contract.contract import Contract, ContractFunction, ContractFunctions
{% if has_events %}from web3.contract.contract import ContractEvent, ContractEvents{% endif %}
from web3.contract.contract import Contract, ContractFunction, ContractFunctions, ContractConstructor
from web3.contract.contract import ContractEvent, ContractEvents
from web3.exceptions import FallbackNotFound
from web3.types import ABI, BlockIdentifier, CallOverride, TxParams
{% if has_events %}from web3.types import EventData{% endif %}
{% if has_events %}from web3._utils.filters import LogFilter{% endif %}
from web3.types import EventData
from web3._utils.filters import LogFilter
{% if structs_for_abi|length > 0 %}from .{{contract_name}}Types import {{ structs_for_abi|join(', ')}}{% endif %}

T = TypeVar("T")
Expand Down
Loading

0 comments on commit f64ed38

Please sign in to comment.