diff --git a/crosshair/diff_behavior.py b/crosshair/diff_behavior.py index 934acd93..375f3f41 100644 --- a/crosshair/diff_behavior.py +++ b/crosshair/diff_behavior.py @@ -17,6 +17,7 @@ StateSpaceContext, VerificationStatus, ) +from crosshair.test_util import flexible_equal from crosshair.tracers import ( COMPOSITE_TRACER, CoverageResult, @@ -227,7 +228,7 @@ def run_iteration( result1 = describe_behavior(fn1, args1) result2 = describe_behavior(fn2, args2) space.detach_path() - if result1 == result2 and args1 == args2: + if flexible_equal(result1, result2) and flexible_equal(args1, args2): debug("Functions equivalent") return (VerificationStatus.CONFIRMED, None) debug("Functions differ") diff --git a/crosshair/diff_behavior_test.py b/crosshair/diff_behavior_test.py index 3e965d68..a441d0fc 100644 --- a/crosshair/diff_behavior_test.py +++ b/crosshair/diff_behavior_test.py @@ -146,6 +146,18 @@ def f(a: Optional[Callable[[int], int]]): assert diffs == [] +def test_diff_behavior_nan() -> None: + def f(x: float): + return x + + diffs = diff_behavior( + FunctionInfo.from_fn(f), + FunctionInfo.from_fn(f), + DEFAULT_OPTIONS, + ) + assert diffs == [] + + if __name__ == "__main__": if ("-v" in sys.argv) or ("--verbose" in sys.argv): set_debug(True) diff --git a/crosshair/fuzz_core_test.py b/crosshair/fuzz_core_test.py index 5bf99b82..5abf4b4b 100644 --- a/crosshair/fuzz_core_test.py +++ b/crosshair/fuzz_core_test.py @@ -45,6 +45,7 @@ StateSpaceContext, ) from crosshair.stubs_parser import signature_from_stubs +from crosshair.test_util import flexible_equal from crosshair.tracers import COMPOSITE_TRACER, NoTracing, ResumedTracing from crosshair.util import CrosshairUnsupported, debug, type_args_of @@ -300,7 +301,7 @@ def symbolic_checker( postexec_symbolic_args = deep_realize(postexec_symbolic_args) symbolic_ret = deep_realize(symbolic_ret) symbolic_exc = deep_realize(symbolic_exc) - rets_differ = realize(bool(literal_ret != symbolic_ret)) + rets_differ = not realize(flexible_equal(literal_ret, symbolic_ret)) postexec_args_differ = realize( bool(postexec_literal_args != postexec_symbolic_args) ) diff --git a/crosshair/test_util.py b/crosshair/test_util.py index 6173b3d1..0b47b9ff 100644 --- a/crosshair/test_util.py +++ b/crosshair/test_util.py @@ -1,9 +1,9 @@ import pathlib import sys -import traceback +from collections.abc import Container from copy import deepcopy from dataclasses import dataclass, replace -from typing import Callable, Iterable, List, Mapping, Optional, Sequence, Set, Tuple +from typing import Callable, Iterable, List, Mapping, Optional, Sequence, Tuple from crosshair.core import ( AnalysisMessage, @@ -21,6 +21,7 @@ ch_stack, debug, in_debug, + is_iterable, is_pure_python, name_of_type, ) @@ -128,6 +129,27 @@ def flexible_equal(a, b): return True if a != a and b != b: # handle float('nan') return True + if ( + is_iterable(a) + and not isinstance(a, Container) + and is_iterable(b) + and not isinstance(b, Container) + ): # unsized iterables compare by contents + a, b = list(a), list(b) + if type(a) == type(b): + # Recursively apply flexible_equal for most containers: + if isinstance(a, Mapping): + if len(a) != len(b): + return False + for k, v in a.items(): + if not flexible_equal(v, b[k]): + return False + return True + if isinstance(a, Container) and not isinstance(a, (str, bytes)): + if len(a) != len(b): + return False + return all(flexible_equal(ai, bi) for ai, bi in zip(a, b)) + return a == b diff --git a/crosshair/test_util_test.py b/crosshair/test_util_test.py new file mode 100644 index 00000000..02bfc1d6 --- /dev/null +++ b/crosshair/test_util_test.py @@ -0,0 +1,16 @@ +from crosshair.test_util import flexible_equal + + +def test_flexible_equal(): + assert float("nan") != float("nan") + assert flexible_equal(float("nan"), float("nan")) + assert flexible_equal((42, float("nan")), (42, float("nan"))) + assert not flexible_equal([float("nan"), 11], [float("nan"), 22]) + + def gen(): + yield 11 + yield 22 + + assert flexible_equal(gen(), iter([11, 22])) + assert not flexible_equal(gen(), iter([11, 22, 33])) + assert not flexible_equal(gen(), iter([11]))