Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Echidna printer Improve values extraction #2574

Merged
merged 5 commits into from
Oct 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 65 additions & 34 deletions slither/printers/guidance/echidna.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from slither.core.expressions import NewContract
from slither.core.slither_core import SlitherCore
from slither.core.solidity_types import TypeAlias
from slither.core.source_mapping.source_mapping import SourceMapping
from slither.core.variables.state_variable import StateVariable
from slither.core.variables.variable import Variable
from slither.printers.abstract_printer import AbstractPrinter
Expand Down Expand Up @@ -179,29 +180,74 @@ class ConstantValue(NamedTuple): # pylint: disable=inherit-non-class,too-few-pu
type: str


def _extract_constants_from_irs( # pylint: disable=too-many-branches,too-many-nested-blocks
def _extract_constant_from_read(
ir: Operation,
r: SourceMapping,
all_cst_used: List[ConstantValue],
all_cst_used_in_binary: Dict[str, List[ConstantValue]],
context_explored: Set[Node],
) -> None:
var_read = r.points_to_origin if isinstance(r, ReferenceVariable) else r
# Do not report struct_name in a.struct_name
if isinstance(ir, Member):
return
if isinstance(var_read, Variable) and var_read.is_constant:
# In case of type conversion we use the destination type
if isinstance(ir, TypeConversion):
if isinstance(ir.type, TypeAlias):
value_type = ir.type.type
else:
value_type = ir.type
else:
value_type = var_read.type
try:
value = ConstantFolding(var_read.expression, value_type).result()
all_cst_used.append(ConstantValue(str(value), str(value_type)))
except NotConstant:
pass
if isinstance(var_read, Constant):
all_cst_used.append(ConstantValue(str(var_read.value), str(var_read.type)))
if isinstance(var_read, StateVariable):
if var_read.node_initialization:
if var_read.node_initialization.irs:
if var_read.node_initialization in context_explored:
return
context_explored.add(var_read.node_initialization)
_extract_constants_from_irs(
var_read.node_initialization.irs,
all_cst_used,
all_cst_used_in_binary,
context_explored,
)


def _extract_constant_from_binary(
ir: Binary,
all_cst_used: List[ConstantValue],
all_cst_used_in_binary: Dict[str, List[ConstantValue]],
):
for r in ir.read:
if isinstance(r, Constant):
all_cst_used_in_binary[str(ir.type)].append(ConstantValue(str(r.value), str(r.type)))
if isinstance(ir.variable_left, Constant) or isinstance(ir.variable_right, Constant):
if ir.lvalue:
try:
type_ = ir.lvalue.type
cst = ConstantFolding(ir.expression, type_).result()
all_cst_used.append(ConstantValue(str(cst.value), str(type_)))
except NotConstant:
pass


def _extract_constants_from_irs(
irs: List[Operation],
all_cst_used: List[ConstantValue],
all_cst_used_in_binary: Dict[str, List[ConstantValue]],
context_explored: Set[Node],
) -> None:
for ir in irs:
if isinstance(ir, Binary):
for r in ir.read:
if isinstance(r, Constant):
all_cst_used_in_binary[str(ir.type)].append(
ConstantValue(str(r.value), str(r.type))
)
if isinstance(ir.variable_left, Constant) or isinstance(
ir.variable_right, Constant
):
if ir.lvalue:
try:
type_ = ir.lvalue.type
cst = ConstantFolding(ir.expression, type_).result()
all_cst_used.append(ConstantValue(str(cst.value), str(type_)))
except NotConstant:
pass
_extract_constant_from_binary(ir, all_cst_used, all_cst_used_in_binary)
if isinstance(ir, TypeConversion):
if isinstance(ir.variable, Constant):
if isinstance(ir.type, TypeAlias):
Expand All @@ -222,24 +268,9 @@ def _extract_constants_from_irs( # pylint: disable=too-many-branches,too-many-n
except ValueError: # index could fail; should never happen in working solidity code
pass
for r in ir.read:
var_read = r.points_to_origin if isinstance(r, ReferenceVariable) else r
# Do not report struct_name in a.struct_name
if isinstance(ir, Member):
continue
if isinstance(var_read, Constant):
all_cst_used.append(ConstantValue(str(var_read.value), str(var_read.type)))
if isinstance(var_read, StateVariable):
if var_read.node_initialization:
if var_read.node_initialization.irs:
if var_read.node_initialization in context_explored:
continue
context_explored.add(var_read.node_initialization)
_extract_constants_from_irs(
var_read.node_initialization.irs,
all_cst_used,
all_cst_used_in_binary,
context_explored,
)
_extract_constant_from_read(
ir, r, all_cst_used, all_cst_used_in_binary, context_explored
)


def _extract_constants(
Expand Down
172 changes: 166 additions & 6 deletions slither/visitors/expression/constants_folding.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@
TupleExpression,
TypeConversion,
CallExpression,
MemberAccess,
)
from slither.core.expressions.elementary_type_name_expression import ElementaryTypeNameExpression
from slither.core.variables import Variable
from slither.utils.integer_conversion import convert_string_to_fraction, convert_string_to_int
from slither.visitors.expression.expression import ExpressionVisitor
Expand All @@ -27,7 +29,13 @@ class NotConstant(Exception):
KEY = "ConstantFolding"

CONSTANT_TYPES_OPERATIONS = Union[
Literal, BinaryOperation, UnaryOperation, Identifier, TupleExpression, TypeConversion
Literal,
BinaryOperation,
UnaryOperation,
Identifier,
TupleExpression,
TypeConversion,
MemberAccess,
]


Expand Down Expand Up @@ -69,6 +77,9 @@ def result(self) -> "Literal":
# pylint: disable=import-outside-toplevel
def _post_identifier(self, expression: Identifier) -> None:
from slither.core.declarations.solidity_variables import SolidityFunction
from slither.core.declarations.enum import Enum
from slither.core.solidity_types.type_alias import TypeAlias
from slither.core.declarations.contract import Contract

if isinstance(expression.value, Variable):
if expression.value.is_constant:
Expand All @@ -77,7 +88,14 @@ def _post_identifier(self, expression: Identifier) -> None:
# Everything outside of literal
if isinstance(
expr,
(BinaryOperation, UnaryOperation, Identifier, TupleExpression, TypeConversion),
(
BinaryOperation,
UnaryOperation,
Identifier,
TupleExpression,
TypeConversion,
MemberAccess,
),
):
cf = ConstantFolding(expr, self._type)
expr = cf.result()
Expand All @@ -88,20 +106,41 @@ def _post_identifier(self, expression: Identifier) -> None:
elif isinstance(expression.value, SolidityFunction):
set_val(expression, expression.value)
else:
raise NotConstant
# Enum: We don't want to raise an error for a direct access to an Enum as they can be converted to a constant value
# We can't handle it here because we don't have the field accessed so we do it in _post_member_access
# TypeAlias: Support when a .wrap() is done with a constant
# Contract: Support when a constatn is use from a different contract
if not isinstance(expression.value, (Enum, TypeAlias, Contract)):
raise NotConstant

# pylint: disable=too-many-branches,too-many-statements
def _post_binary_operation(self, expression: BinaryOperation) -> None:
expression_left = expression.expression_left
expression_right = expression.expression_right
if not isinstance(
expression_left,
(Literal, BinaryOperation, UnaryOperation, Identifier, TupleExpression, TypeConversion),
(
Literal,
BinaryOperation,
UnaryOperation,
Identifier,
TupleExpression,
TypeConversion,
MemberAccess,
),
):
raise NotConstant
if not isinstance(
expression_right,
(Literal, BinaryOperation, UnaryOperation, Identifier, TupleExpression, TypeConversion),
(
Literal,
BinaryOperation,
UnaryOperation,
Identifier,
TupleExpression,
TypeConversion,
MemberAccess,
),
):
raise NotConstant
left = get_val(expression_left)
Expand Down Expand Up @@ -205,6 +244,34 @@ def _post_assignement_operation(self, expression: expressions.AssignmentOperatio
raise NotConstant

def _post_call_expression(self, expression: expressions.CallExpression) -> None:
from slither.core.declarations.solidity_variables import SolidityFunction
from slither.core.declarations.enum import Enum
from slither.core.solidity_types import TypeAlias

# pylint: disable=too-many-boolean-expressions
if (
isinstance(expression.called, Identifier)
and expression.called.value == SolidityFunction("type()")
and len(expression.arguments) == 1
and (
isinstance(expression.arguments[0], ElementaryTypeNameExpression)
or isinstance(expression.arguments[0], Identifier)
and isinstance(expression.arguments[0].value, Enum)
)
):
# Returning early to support type(ElemType).max/min or type(MyEnum).max/min
return
if (
isinstance(expression.called.expression, Identifier)
and isinstance(expression.called.expression.value, TypeAlias)
and isinstance(expression.called, MemberAccess)
and expression.called.member_name == "wrap"
and len(expression.arguments) == 1
):
# Handle constants in .wrap of user defined type
set_val(expression, get_val(expression.arguments[0]))
return

called = get_val(expression.called)
args = [get_val(arg) for arg in expression.arguments]
if called.name == "keccak256(bytes)":
Expand All @@ -220,12 +287,104 @@ def _post_conditional_expression(self, expression: expressions.ConditionalExpres
def _post_elementary_type_name_expression(
self, expression: expressions.ElementaryTypeNameExpression
) -> None:
raise NotConstant
# We don't have to raise an exception to support type(uint112).max or similar
pass

def _post_index_access(self, expression: expressions.IndexAccess) -> None:
raise NotConstant

# pylint: disable=too-many-locals
def _post_member_access(self, expression: expressions.MemberAccess) -> None:
from slither.core.declarations import (
SolidityFunction,
Contract,
EnumContract,
EnumTopLevel,
Enum,
)
from slither.core.solidity_types import UserDefinedType, TypeAlias

# pylint: disable=too-many-nested-blocks
if isinstance(expression.expression, CallExpression) and expression.member_name in [
"min",
"max",
]:
if isinstance(expression.expression.called, Identifier):
if expression.expression.called.value == SolidityFunction("type()"):
assert len(expression.expression.arguments) == 1
type_expression_found = expression.expression.arguments[0]
type_found: Union[ElementaryType, UserDefinedType]
if isinstance(type_expression_found, ElementaryTypeNameExpression):
type_expression_found_type = type_expression_found.type
assert isinstance(type_expression_found_type, ElementaryType)
type_found = type_expression_found_type
value = (
type_found.max if expression.member_name == "max" else type_found.min
)
set_val(expression, value)
return
# type(enum).max/min
# Case when enum is in another contract e.g. type(C.E).max
if isinstance(type_expression_found, MemberAccess):
contract = type_expression_found.expression.value
assert isinstance(contract, Contract)
for enum in contract.enums:
if enum.name == type_expression_found.member_name:
type_found_in_expression = enum
type_found = UserDefinedType(enum)
break
else:
assert isinstance(type_expression_found, Identifier)
type_found_in_expression = type_expression_found.value
assert isinstance(type_found_in_expression, (EnumContract, EnumTopLevel))
type_found = UserDefinedType(type_found_in_expression)
value = (
type_found_in_expression.max
if expression.member_name == "max"
else type_found_in_expression.min
)
set_val(expression, value)
return
elif isinstance(expression.expression, Identifier) and isinstance(
expression.expression.value, Enum
):
# Handle direct access to enum field
set_val(expression, expression.expression.value.values.index(expression.member_name))
return
elif isinstance(expression.expression, Identifier) and isinstance(
expression.expression.value, TypeAlias
):
# User defined type .wrap call handled in _post_call_expression
return
elif (
isinstance(expression.expression.value, Contract)
and expression.member_name in expression.expression.value.variables_as_dict
and expression.expression.value.variables_as_dict[expression.member_name].is_constant
):
# Handles when a constant is accessed on another contract
variables = expression.expression.value.variables_as_dict
if isinstance(variables[expression.member_name].expression, MemberAccess):
self._post_member_access(variables[expression.member_name].expression)
set_val(expression, get_val(variables[expression.member_name].expression))
return

# If the variable is a Literal we convert its value to int
if isinstance(variables[expression.member_name].expression, Literal):
value = convert_string_to_int(
variables[expression.member_name].expression.converted_value
)
# If the variable is a UnaryOperation we need convert its value to int
# and replacing possible spaces
elif isinstance(variables[expression.member_name].expression, UnaryOperation):
value = convert_string_to_int(
str(variables[expression.member_name].expression).replace(" ", "")
)
else:
value = variables[expression.member_name].expression

set_val(expression, value)
return

raise NotConstant

def _post_new_array(self, expression: expressions.NewArray) -> None:
Expand Down Expand Up @@ -272,6 +431,7 @@ def _post_type_conversion(self, expression: expressions.TypeConversion) -> None:
TupleExpression,
TypeConversion,
CallExpression,
MemberAccess,
),
):
raise NotConstant
Expand Down
22 changes: 22 additions & 0 deletions tests/unit/slithir/test_constantfolding.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from pathlib import Path

from slither import Slither
from slither.printers.guidance.echidna import _extract_constants, ConstantValue

TEST_DATA_DIR = Path(__file__).resolve().parent / "test_data"


def test_enum_max_min(solc_binary_path) -> None:
solc_path = solc_binary_path("0.8.19")
slither = Slither(Path(TEST_DATA_DIR, "constantfolding.sol").as_posix(), solc=solc_path)

contracts = slither.get_contract_from_name("A")

constants = _extract_constants(contracts)[0]["A"]["use()"]

assert set(constants) == {
ConstantValue(value="2", type="uint256"),
ConstantValue(value="10", type="uint256"),
ConstantValue(value="100", type="uint256"),
ConstantValue(value="4294967295", type="uint32"),
}
Loading