Skip to content

Commit

Permalink
fix: fix offset factory (#208)
Browse files Browse the repository at this point in the history
  • Loading branch information
jemmyshin authored Dec 7, 2022
1 parent f6fdc9d commit 61e0a57
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 6 deletions.
12 changes: 8 additions & 4 deletions annlite/storage/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,9 +306,14 @@ def _offset_factory(_, record):
self._conn.row_factory = _offset_factory

cursor = self._conn.cursor()
offsets = cursor.execute(sql, params).fetchall()
self._conn.row_factory = None
return offsets if offsets else []

try:
offsets = cursor.execute(sql, params).fetchall()
self._conn.row_factory = None
return offsets if offsets else []
except Exception as e:
self._conn.row_factory = None
raise e

def delete(self, doc_ids: List[str]):
"""Delete the docs
Expand Down Expand Up @@ -361,7 +366,6 @@ def count(self, where_clause: str = '', where_params: Tuple = ()):
# # EXPLAIN SQL query
# for row in self._conn.execute('EXPLAIN QUERY PLAN ' + sql, params):
# print(row)

return self._conn.execute(sql, params).fetchone()[0]
else:
sql = f'SELECT MAX(_id) from {self.name} LIMIT 1;'
Expand Down
46 changes: 46 additions & 0 deletions tests/test_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,3 +209,49 @@ def test_filter_with_limit_offset(tmpfile, limit, offset, order_by, ascending):
assert m.tags[order_by] <= matches[i + 1].tags[order_by]
else:
assert m.tags[order_by] >= matches[i + 1].tags[order_by]


@pytest.mark.parametrize('limit', [1, 5])
def test_filter_with_wrong_columns(tmpfile, limit):
N = 100
D = 128

index = AnnLite(
D,
columns=[('price', float)],
data_path=tmpfile,
)
X = np.random.random((N, D)).astype(np.float32)

docs = DocumentArray(
[
Document(id=f'{i}', embedding=X[i], tags={'price': random.random()})
for i in range(N)
]
)

index.index(docs)

matches = index.filter(
filter={'price': {'$lte': 50}},
limit=limit,
include_metadata=True,
)

assert len(matches) == limit

import sqlite3

with pytest.raises(sqlite3.OperationalError):
matches = index.filter(
filter={'price_': {'$lte': 50}},
include_metadata=True,
)

matches = index.filter(
filter={'price': {'$lte': 50}},
limit=limit,
include_metadata=True,
)

assert len(matches) == limit
4 changes: 2 additions & 2 deletions tests/test_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,10 +306,10 @@ def test_remote_backup_restore(tmpdir):
index.index(docs)

tmpname = uuid.uuid4().hex
index.backup(target_name='test', token=token)
index.backup(target_name='test_remote_backup_restore', token=token)

index = AnnLite(n_dim=D, data_path=tmpdir / 'workspace' / '0')
index.restore(source_name='test', token=token)
index.restore(source_name='test_remote_backup_restore', token=token)

delete_artifact(tmpname)
status = index.stat
Expand Down

0 comments on commit 61e0a57

Please sign in to comment.