-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathquery.py
160 lines (146 loc) · 7.02 KB
/
query.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
import math
import pickle
import shlex
import re
from collections import defaultdict
from itertools import chain, groupby
from operator import itemgetter
from kgram import KGramIndex
from diskindex import DiskIndex
from normalize import query_normalize, remove_special_characters
from utils import intersect_sorted_lists, union_sorted_lists
# Threshold handed to KGramIndex.find_spelling_candidates() when looking for
# spelling suggestions (presumably a k-gram similarity cutoff -- confirm
# against kgram.py).
THRESHOLD = .35
class QueryProcessor(object):
    """Query front end over the on-disk index and the k-gram index.

    Supports boolean queries ('+' separates OR-ed literals, whitespace
    separated parts of a literal are AND-ed, double-quoted parts are phrase
    queries), ranked queries, wildcard ('*') terms and spelling suggestions
    backed by the k-gram index.
    """

    # A spell-check token is a run of word characters and/or '*', so that a
    # wildcard term such as "app*" stays in one piece and can be skipped.
    _TOKEN_RE = re.compile(r"[\w*]+")

    def __init__(self, path='bin/', num_docs=0):
        """Load the pickled k-gram index and open the disk index.

        Args:
            path: Directory prefix of the index files.  It is concatenated
                directly with 'kgram.bin', so it must end with a path
                separator.  The same value is handed to DiskIndex.
            num_docs: Corpus size, used for the query-term weight in
                ranked_query().
        """
        # SECURITY: pickle.load() can execute arbitrary code while loading.
        # Only point `path` at index files this application built itself.
        with open('{}kgram.bin'.format(path), 'rb') as f:
            self.kgram_index = pickle.load(f)
        self.disk_index = DiskIndex(path)
        # Number of documents returned for ranked queries issued via query().
        self.k_docs = 10
        self.num_docs = num_docs
        self.path = path

    def query(self, query, ranked_flag=False):
        """Run a ranked query if `ranked_flag` is set, else a boolean one."""
        if ranked_flag:
            return self.ranked_query(query, self.k_docs)
        return self.boolean_query(query)

    def check_spelling(self, query, vocab, ranked_flag=False):
        """Dispatch to the ranked or the boolean spell checker."""
        if ranked_flag:
            return self.check_spelling_ranked(query, vocab)
        return self.check_spelling_boolean(query, vocab)

    def check_spelling_boolean(self, query, vocab):
        """Spell-check a boolean query while keeping its operators intact.

        Args:
            query: Raw boolean query string.
            vocab: Container of known terms (tested with `in`).

        Returns:
            The query with every misspelled term replaced, or None when
            nothing had to be corrected or when at least one unknown term
            has no spelling candidate.
        """
        terms = self._TOKEN_RE.findall(query)
        new_terms = self._correct_terms(terms, vocab)
        if terms == new_terms or not all(new_terms):
            return None
        corrections = {}
        for term, new in zip(terms, new_terms):
            if term != new:
                corrections[term] = new
        return self._apply_corrections(query, corrections)

    def check_spelling_ranked(self, query, vocab):
        """Spell-check a ranked query, where symbols do not matter.

        Terms are the whitespace separated words of `query`.

        Returns:
            The corrected terms joined by single spaces, or None when
            nothing had to be corrected or when at least one unknown term
            has no spelling candidate.
        """
        terms = query.split()
        new_terms = self._correct_terms(terms, vocab)
        if terms == new_terms or not all(new_terms):
            return None
        return " ".join(new_terms)

    def _correct_terms(self, terms, vocab):
        """Map each term to itself or to its best spelling suggestion.

        Wildcard terms and terms found in `vocab` (after
        remove_special_characters) are kept as they are.  Any other term is
        replaced by select_best_spelling(), which may yield None.
        """
        corrected = []
        for term in terms:
            if '*' in term or remove_special_characters(term) in vocab:
                corrected.append(term)
            else:
                corrected.append(self.select_best_spelling(term))
        return corrected

    @classmethod
    def _apply_corrections(cls, query, corrections):
        """Replace whole tokens of `query` according to `corrections`.

        Working token by token in a single pass means a correction for
        "cat" cannot touch "category", and text that was already replaced
        is never replaced again.
        """
        def swap(match):
            token = match.group(0)
            return corrections.get(token, token)
        return cls._TOKEN_RE.sub(swap, query)

    def select_best_spelling(self, term):
        """Return the spelling candidate with the highest document frequency.

        Candidates come from kgram_index.find_spelling_candidates(term,
        THRESHOLD) (presumably already filtered by edit distance -- confirm
        in kgram.py).  The first candidate wins a tie.  Returns None when
        there is no candidate.
        """
        candidates = self.kgram_index.find_spelling_candidates(term, THRESHOLD)
        if not candidates:
            return None
        frequencies = self.disk_index.get_doc_frequency(candidates)
        return candidates[frequencies.index(max(frequencies))]

    def ranked_query(self, query, k):
        """Return the k best documents for `query`, scored term at a time.

        Every posting adds wdt * wqt to its document's accumulator, where
        wqt = ln(1 + num_docs / len(postings)) and wdt = 1 + ln(posting[1]).

        Args:
            query: Whitespace separated query words; words containing '*'
                are expanded through the k-gram index.
            k: Number of results requested from disk_index.get_k_scores().

        Returns:
            Whatever disk_index.get_k_scores(accumulator, k) returns.
        """
        # Plain words first, wildcard expansions afterwards.  Building a
        # separate list avoids growing the list that is being iterated.
        terms = []
        expansions = []
        for word in query.split():
            if '*' in word:
                expansions.extend(self._expand_wildcard(word))
            else:
                terms.append(query_normalize(word))
        terms.extend(expansions)

        accumulator = defaultdict(float)
        for term in terms:
            postings = self.disk_index.get_postings(term)
            if not postings:
                continue
            # float() keeps this a true division on every Python version.
            wqt = math.log(1 + float(self.num_docs) / len(postings))
            for posting in postings:
                # posting[0] keys the accumulator (document id); posting[1]
                # is presumably the term frequency -- confirm in diskindex.
                wdt = 1 + math.log(posting[1])
                accumulator[posting[0]] += wdt * wqt
        return self.disk_index.get_k_scores(accumulator, k)

    def _expand_wildcard(self, term):
        """Expand a wildcard term into distinct, normalised index terms.

        The expansions get the same lowercasing and query_normalize()
        treatment that boolean_query() applies, so they can be looked up in
        the disk index.  Duplicates are dropped so that several expansions
        normalising to one term are not scored more than once.
        """
        seen = set()
        normalized = []
        for match in self.wildcard_query(term.lower()):
            index_term = query_normalize(match)
            if index_term not in seen:
                seen.add(index_term)
                normalized.append(index_term)
        return normalized

    def boolean_query(self, query):
        """Return the ids of the documents satisfying a boolean query.

        The query is split on '+' into literals whose results are unioned.
        Each literal is split with shlex, so a double-quoted part becomes a
        phrase; the document lists of all parts are intersected.

        Returns:
            A list of matching document ids.
        """
        query_literals = self.process_query(query)
        index = self.disk_index.retrieve_postings(query_literals)
        success_doc_ids = []
        for literal in query_literals:
            # Guard for unmatched " symbols
            try:
                queries = shlex.split(literal)
            except ValueError as e:
                print(e)
                print(literal)
                queries = [literal]
            docs_with_all_queries = []
            for subliteral in queries:
                terms = subliteral.split()
                has_wildcard = False
                for term in terms:
                    if '*' not in term:
                        continue
                    has_wildcard = True
                    # Recurse with the expansions OR-ed together to fetch
                    # their postings, then keep the resulting documents.
                    expansions = self.wildcard_query(term.lower())
                    res = self.boolean_query('+'.join(expansions))
                    # NOTE(review): a wildcard matching no document is
                    # skipped instead of emptying the intersection.
                    if res:
                        docs_with_all_queries.append(res)
                if has_wildcard:
                    continue
                terms = [query_normalize(term) for term in terms]
                # sorted() is stable, so postings of one document keep the
                # order of `terms`; get_current_postings() relies on that.
                combined_postings = sorted(
                    chain.from_iterable(index[term] for term in terms),
                    key=itemgetter(0))
                docs_with_all_queries.append(
                    self.get_current_postings(combined_postings, terms))
            if docs_with_all_queries:
                ids_intersect = docs_with_all_queries[0]
                for doc_ids in docs_with_all_queries[1:]:
                    ids_intersect = intersect_sorted_lists(ids_intersect, doc_ids)
                success_doc_ids = union_sorted_lists(success_doc_ids, ids_intersect)
        return list(success_doc_ids)

    def wildcard_query(self, query):
        """Look up the vocabulary strings matching a wildcard pattern.

        '$' is prepended unless the pattern starts with '*' and appended
        unless it ends with '*'.  The pattern is then split on '*', empty
        pieces are dropped, and the distinct pieces are passed to
        kgram_index.get_intersection_grams().

        Returns:
            Whatever get_intersection_grams() returns; callers iterate it
            as a collection of strings.
        """
        if not query.startswith('*'):
            query = '$' + query
        if not query.endswith('*'):
            query = query + '$'
        grams = set(piece for piece in query.split('*') if piece)
        return self.kgram_index.get_intersection_grams(grams)

    @staticmethod
    def process_query(query):
        """Split a query on '+' and strip whitespace around each literal."""
        return [literal.strip() for literal in query.split('+')]

    @staticmethod
    def get_current_postings(combined_postings_lists, subliterals):
        """Return the documents matching a single term or a phrase.

        Args:
            combined_postings_lists: Postings of all `subliterals`, sorted
                by document id (element 0).  For phrases, element 2 is the
                list of positions, and the postings of one document must
                appear in the order of `subliterals`.
            subliterals: The normalised terms; more than one term means a
                phrase query.

        Returns:
            Document ids in input order.  For a single term that is every
            document present.  For a phrase it is every document that has a
            posting for each term and in which the terms occur at
            consecutive positions.
        """
        is_phrase = len(subliterals) > 1
        matching_docs = []
        for doc_id, group in groupby(combined_postings_lists, itemgetter(0)):
            if not is_phrase:
                matching_docs.append(doc_id)
                continue
            doc_postings = list(group)
            if len(doc_postings) != len(subliterals):
                # At least one term of the phrase is missing here.
                continue
            # Shift the positions of the i-th term back by i; the phrase
            # occurs wherever all shifted lists share a position.
            aligned = [[position - offset for position in posting[2]]
                       for offset, posting in enumerate(doc_postings)]
            common = aligned[0]
            for positions in aligned[1:]:
                common = list(intersect_sorted_lists(common, positions))
            if common:
                matching_docs.append(doc_id)
        return matching_docs