diff --git a/CHANGELOG.md b/CHANGELOG.md index 0108e9e4d..af4de938b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed * Fix issue with `(count nil)` throwing an exception (#759). + * Fix issue with keyword fn not testing for test membership in sets (#762). ## [v0.1.0b0] ### Added diff --git a/src/basilisp/lang/keyword.py b/src/basilisp/lang/keyword.py index b8ebaf32d..5f62ebb6b 100644 --- a/src/basilisp/lang/keyword.py +++ b/src/basilisp/lang/keyword.py @@ -1,9 +1,14 @@ import threading from functools import total_ordering -from typing import Iterable, Optional +from typing import Iterable, Optional, Union from basilisp.lang import map as lmap -from basilisp.lang.interfaces import IAssociative, ILispObject, IPersistentMap +from basilisp.lang.interfaces import ( + IAssociative, + ILispObject, + IPersistentMap, + IPersistentSet, +) _LOCK = threading.Lock() _INTERN: IPersistentMap[int, "Keyword"] = lmap.PersistentMap.empty() @@ -53,7 +58,9 @@ def __lt__(self, other): return False return self._ns < other._ns or self._name < other._name - def __call__(self, m: IAssociative, default=None): + def __call__(self, m: Union[IAssociative, IPersistentSet], default=None): + if isinstance(m, IPersistentSet): + return self if self in m else default try: return m.val_at(self, default) except AttributeError: diff --git a/tests/basilisp/keyword_test.py b/tests/basilisp/keyword_test.py index 943010a77..f8923bac2 100644 --- a/tests/basilisp/keyword_test.py +++ b/tests/basilisp/keyword_test.py @@ -3,6 +3,7 @@ import pytest from basilisp.lang import map as lmap +from basilisp.lang import set as lset from basilisp.lang.keyword import Keyword, complete, keyword @@ -46,6 +47,10 @@ def test_keyword_as_function(): assert "hi" == kw(lmap.map({kw: "hi"})) assert None is kw(lmap.map({"hi": kw})) + assert kw == kw(lset.s(kw)) + assert None is kw(lset.s(1)) + assert "hi" is kw(lset.s(1), default="hi") + @pytest.mark.parametrize( "o",