Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add arrays contains function #860

Merged
merged 1 commit into from
Feb 3, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 16 additions & 1 deletion examples/get_started/common_sql_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ def num_chars_udf(file):
return ([],)


dc = DataChain.from_storage("gs://datachain-demo/dogs-and-cats/")
dc = DataChain.from_storage("gs://datachain-demo/dogs-and-cats/", anon=True)
dc.map(num_chars_udf, params=["file"], output={"num_chars": list[str]}).select(
"file.path", "num_chars"
).show(5)
Expand All @@ -32,6 +32,12 @@ def num_chars_udf(file):
.show(5)
)

# Split the name derived from "file.path" on "." and add two columns that
# report whether the resulting parts contain "dog" / "cat".
parts = string.split(path.name(C("file.path")), ".")
chain = dc.mutate(
    isdog=array.contains(parts, "dog"),
    iscat=array.contains(parts, "cat"),
)
chain.select("file.path", "isdog", "iscat").show(5)

chain = dc.mutate(
a=array.length(string.split("file.path", "/")),
Expand Down Expand Up @@ -79,6 +85,15 @@ def num_chars_udf(file):
3 dogs-and-cats/cat.10.json cat.10 json
4 dogs-and-cats/cat.100.jpg cat.100 jpg

[Limited by 5 rows]
file isdog iscat
path
0 dogs-and-cats/cat.1.jpg 0 1
1 dogs-and-cats/cat.1.json 0 1
2 dogs-and-cats/cat.10.jpg 0 1
3 dogs-and-cats/cat.10.json 0 1
4 dogs-and-cats/cat.100.jpg 0 1

[Limited by 5 rows]
Processed: 400 rows [00:00, 16496.93 rows/s]
a b greatest least
Expand Down
3 changes: 2 additions & 1 deletion src/datachain/func/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
row_number,
sum,
)
from .array import cosine_distance, euclidean_distance, length, sip_hash_64
from .array import contains, cosine_distance, euclidean_distance, length, sip_hash_64
from .conditional import case, greatest, ifelse, isnone, least
from .numeric import bit_and, bit_hamming_distance, bit_or, bit_xor, int_hash_64
from .random import rand
Expand All @@ -34,6 +34,7 @@
"case",
"collect",
"concat",
"contains",
"cosine_distance",
"count",
"dense_rank",
Expand Down
40 changes: 39 additions & 1 deletion src/datachain/func/array.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from collections.abc import Sequence
from typing import Union
from typing import Any, Union

from datachain.sql.functions import array

Expand Down Expand Up @@ -140,6 +140,44 @@ def length(arg: Union[str, Sequence, Func]) -> Func:
return Func("length", inner=array.length, cols=cols, args=args, result_type=int)


def contains(arr: Union[str, Sequence, Func], elem: Any) -> Func:
    """
    Builds a function that tests whether `elem` is a member of the `arr` array.

    Args:
        arr (str | Sequence | Func): Array to search.
            A string is treated as the name of an array column,
            a sequence is treated as a literal array of values,
            and a Func is treated as a function returning an array.
        elem (Any): Value to look for in the array.

    Returns:
        Func: A Func object that represents the contains function. The result
            is 1 when the element is found in the array and 0 when it is not.

    Example:
        ```py
        dc.mutate(
            contains1=func.array.contains("signal.values", 3),
            contains2=func.array.contains([1, 2, 3, 4, 5], 7),
        )
        ```
    """
    # Column names and nested functions are passed as columns; anything else
    # is passed as a plain literal argument.
    is_column = isinstance(arr, (str, Func))
    cols = [arr] if is_column else None
    args = None if is_column else [arr]

    def inner(arg):
        # list and dict elements are flagged through the third argument
        # of the SQL-level array.contains function.
        is_json = type(elem) in [list, dict]
        return array.contains(arg, elem, is_json)

    return Func("contains", inner=inner, cols=cols, args=args, result_type=int)


def sip_hash_64(arg: Union[str, Sequence]) -> Func:
"""
Computes the SipHash-64 hash of the array.
Expand Down
14 changes: 13 additions & 1 deletion src/datachain/sql/functions/array.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from sqlalchemy.sql.functions import GenericFunction

from datachain.sql.types import Float, Int64
from datachain.sql.types import Boolean, Float, Int64
from datachain.sql.utils import compiler_not_implemented


Expand Down Expand Up @@ -37,6 +37,17 @@ class length(GenericFunction): # noqa: N801
inherit_cache = True


class contains(GenericFunction):  # noqa: N801
    """
    SQL function that checks whether an element is present in an array.

    Declared under the "array" package with the name "contains" and a
    Boolean result type.
    """

    package = "array"
    name = "contains"
    type = Boolean()
    inherit_cache = True


class sip_hash_64(GenericFunction): # noqa: N801
"""
Computes the SipHash-64 hash of the array.
Expand All @@ -51,4 +62,5 @@ class sip_hash_64(GenericFunction): # noqa: N801
compiler_not_implemented(cosine_distance)
compiler_not_implemented(euclidean_distance)
compiler_not_implemented(length)
compiler_not_implemented(contains)
compiler_not_implemented(sip_hash_64)
18 changes: 17 additions & 1 deletion src/datachain/sql/sqlite/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ def setup():
compiles(sql_path.file_stem, "sqlite")(compile_path_file_stem)
compiles(sql_path.file_ext, "sqlite")(compile_path_file_ext)
compiles(array.length, "sqlite")(compile_array_length)
compiles(array.contains, "sqlite")(compile_array_contains)
compiles(string.length, "sqlite")(compile_string_length)
compiles(string.split, "sqlite")(compile_string_split)
compiles(string.regexp_replace, "sqlite")(compile_string_regexp_replace)
Expand Down Expand Up @@ -269,13 +270,16 @@ def create_string_functions(conn):

_registered_function_creators["string_functions"] = create_string_functions

has_json_extension = functions_exist(["json_array_length"])
has_json_extension = functions_exist(["json_array_length", "json_array_contains"])
if not has_json_extension:

def create_json_functions(conn):
conn.create_function(
"json_array_length", 1, py_json_array_length, deterministic=True
)
conn.create_function(
"json_array_contains", 3, py_json_array_contains, deterministic=True
)

_registered_function_creators["json_functions"] = create_json_functions

Expand Down Expand Up @@ -428,10 +432,22 @@ def py_json_array_length(arr):
return len(orjson.loads(arr))


def py_json_array_contains(arr, value, is_json):
    """
    Check whether a JSON-encoded array contains a value.

    Args:
        arr: JSON text of the array to search, or None for a SQL NULL.
        value: Element to look for. When `is_json` is truthy this is itself
            JSON text and is decoded before the comparison.
        is_json: Truthy when `value` is JSON-encoded.

    Returns:
        None when `arr` is None, otherwise the result of the Python `in`
        test of the value against the decoded array.
    """
    if arr is None:
        # A NULL array cannot be decoded; propagate NULL rather than raising
        # from inside the user-defined function and aborting the query.
        return None
    if is_json:
        value = orjson.loads(value)
    return value in orjson.loads(arr)


def compile_array_length(element, compiler, **kwargs):
    """Compile the array length element as a `json_array_length` call."""
    operands = element.clauses.clauses
    json_length = func.json_array_length(*operands)
    return compiler.process(json_length, **kwargs)


def compile_array_contains(element, compiler, **kwargs):
    """Compile the array contains element as a `json_array_contains` call."""
    operands = element.clauses.clauses
    json_contains = func.json_array_contains(*operands)
    return compiler.process(json_contains, **kwargs)


def compile_string_length(element, compiler, **kwargs):
    """Compile the string length element as a `length` call."""
    operands = element.clauses.clauses
    length_call = func.length(*operands)
    return compiler.process(length_call, **kwargs)

Expand Down
5 changes: 5 additions & 0 deletions src/datachain/sql/sqlite/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,10 @@
return orjson.dumps(arr).decode("utf-8")


def adapt_dict(dct):
    """Serialize a dict with orjson and return the result as a UTF-8 string."""
    encoded = orjson.dumps(dct)
    return encoded.decode("utf-8")

Check warning on line 35 in src/datachain/sql/sqlite/types.py

View check run for this annotation

Codecov / codecov/patch

src/datachain/sql/sqlite/types.py#L35

Added line #L35 was not covered by tests


def convert_array(arr):
    """Decode a JSON-encoded array value with orjson."""
    decoded = orjson.loads(arr)
    return decoded

Expand All @@ -52,6 +56,7 @@

def register_type_converters():
sqlite3.register_adapter(list, adapt_array)
sqlite3.register_adapter(dict, adapt_dict)
sqlite3.register_converter("ARRAY", convert_array)
if numpy_imported:
sqlite3.register_adapter(np.ndarray, adapt_np_array)
Expand Down
18 changes: 18 additions & 0 deletions tests/unit/sql/test_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,24 @@ def test_length(warehouse):
assert result == ((4, 5, 2),)


def test_contains(warehouse):
    """Check `contains` on literal arrays of strings, floats, lists and None."""
    words = ["abc", "def", "g", "hi"]
    columns = [
        func.contains(words, "abc").label("contains1"),
        func.contains(words, "cdf").label("contains2"),
        func.contains([3.0, 5.0, 1.0, 6.0, 1.0], 1.0).label("contains3"),
        func.contains([[1, None, 3], [4, 5, 6]], [1, None, 3]).label("contains4"),
        # Not supported yet by CH, need to add it later + some Pydantic model as
        # an input:
        # func.contains(
        #     [{"c": 1, "a": True}, {"b": False}], {"a": True, "c": 1}
        # ).label("contains5"),
        func.contains([1, None, 3], None).label("contains6"),
        func.contains([1, True, 3], True).label("contains7"),
    ]
    query = select(*columns)
    result = tuple(warehouse.db.execute(query))
    # One row, one value per labelled column above, in the same order.
    expected = ((1, 0, 1, 1, 1, 1),)
    assert result == expected


def test_length_on_split(warehouse):
query = select(
func.array.length(func.string.split(func.literal("abc/def/g/hi"), "/")),
Expand Down
25 changes: 25 additions & 0 deletions tests/unit/test_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
isnone,
literal,
)
from datachain.func.array import contains
from datachain.func.random import rand
from datachain.func.string import length as strlen
from datachain.lib.signal_schema import SignalSchema
Expand Down Expand Up @@ -797,3 +798,27 @@ def test_isnone_with_ifelse_mutate(col):
res = dc.mutate(test=ifelse(isnone(col), "NONE", "NOT_NONE"))
assert list(res.order_by("num").collect("test")) == ["NOT_NONE"] * 3 + ["NONE"] * 2
assert res.schema["test"] is str


def test_array_contains():
    """Check `contains` on an array column given by name and by `C` object."""
    # Row i (for i in 2..6) holds list(range(1, i)) repeated i times, so the
    # value 3 first appears in the row where val == 4.
    dc = DataChain.from_values(
        arr=[list(range(1, i)) * i for i in range(2, 7)],
        val=list(range(2, 7)),
    )

    cases = [
        ("arr", 3, [0, 0, 1, 1, 1]),
        (C("arr"), 3, [0, 0, 1, 1, 1]),
        (C("arr"), 10, [0, 0, 0, 0, 0]),
        (C("arr"), None, [0, 0, 0, 0, 0]),
    ]
    for column, needle, expected in cases:
        chain = dc.mutate(res=contains(column, needle))
        assert list(chain.order_by("val").collect("res")) == expected