-
Notifications
You must be signed in to change notification settings - Fork 6
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #239 from clulab/kwalcock/vectors
Add vectors to elasticsearch
- Loading branch information
Showing
10 changed files
with
172 additions
and
12 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
from argparse import ArgumentParser | ||
from elasticsearch import Elasticsearch | ||
from sentence_transformers import SentenceTransformer | ||
|
||
class VectorSearcher(): | ||
def __init__(self, url, username, password, model_name): | ||
super().__init__() | ||
self.elasticsearch = Elasticsearch(url, basic_auth=(username, password)) | ||
self.sentence_transformer = SentenceTransformer(model_name) | ||
|
||
def search(self, k, text): | ||
index = "habitus2" | ||
# This vector is assumed to be normalized. | ||
vector = self.sentence_transformer.encode(text).tolist() | ||
query = { | ||
"field": "chatVector", | ||
"query_vector": vector, | ||
"k": k, | ||
"num_candidates": k | ||
} | ||
# The maximum value for size is limited by the index's index.max_result_window. | ||
# If more results are desired, they need to be paged. | ||
result = self.elasticsearch.search(index=index, knn=query, source=False, from_=0, size=k) | ||
hits = result.body["hits"]["hits"] | ||
ids_and_scores = [(hit["_id"], hit["_score"]) for hit in hits] | ||
return ids_and_scores | ||
|
||
def run(username, password, k, text): | ||
url = "https://elasticsearch.habitus.clulab.org/" | ||
model_name = "all-MiniLM-L6-v2" | ||
vector_searcher = VectorSearcher(url, username, password, model_name) | ||
ids_and_scores = vector_searcher.search(k, text) | ||
for index, (id, score) in enumerate(ids_and_scores): | ||
print(index, id, score) | ||
|
||
def get_args(): | ||
argument_parser = ArgumentParser() | ||
argument_parser.add_argument("-u", "--username", required=True, help="elasticsearch username") | ||
argument_parser.add_argument("-p", "--password", required=True, help="elasticsearch password") | ||
argument_parser.add_argument("-k", "--k", required=True, help="number of nearest neighbors") | ||
argument_parser.add_argument("-t", "--text", required=True, help="text to be matched") | ||
args = argument_parser.parse_args() | ||
return args.username, args.password, args.k, args.text | ||
|
||
if __name__ == "__main__": | ||
username, password, k, text = get_args() | ||
run(username, password, k, text) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
from pandas import DataFrame | ||
from pipeline import InputStage | ||
|
||
import pandas | ||
|
||
class VectorInputStage(InputStage): | ||
def __init__(self, file_name: str) -> None: | ||
super().__init__(".") | ||
self.file_name = file_name | ||
|
||
def mk_data_frame(self, file_name: str) -> DataFrame: | ||
data_frame = pandas.read_csv(self.file_name, sep="\t", encoding="utf-8", na_values=[""], keep_default_na=False, dtype={ | ||
"url": str, | ||
"sentenceIndex": int, | ||
"sentence": str, | ||
"belief": bool, | ||
"sent_locs": str, | ||
"context_locs": str | ||
}) | ||
return data_frame | ||
|
||
def run(self) -> DataFrame: | ||
data_frame = self.mk_data_frame(self.file_name) | ||
# data_frame = data_frame[0:1000] # TODO: remove | ||
return data_frame |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
from argparse import ArgumentParser | ||
from pandas_output_stage import PandasOutputStage | ||
from pipeline import Pipeline | ||
from vector_input_stage import VectorInputStage | ||
from vector_vector_stage import VectorVectorStage | ||
from typing import Tuple | ||
|
||
|
||
def get_in_and_out() -> Tuple[str, str]: | ||
argument_parser = ArgumentParser() | ||
argument_parser.add_argument("-i", "--input", required=True, help="input file name") | ||
argument_parser.add_argument("-o", "--output", required=True, help="output file name") | ||
args = argument_parser.parse_args() | ||
return args.input, args.output | ||
|
||
if __name__ == "__main__": | ||
vector_model_name: str = "all-MiniLM-L6-v2" | ||
input_file_name: str = "../corpora/uganda-mining/uganda-2.tsv" | ||
output_file_name: str = "../corpora/uganda-mining/uganda-2-vectors.tsv" | ||
# input_file_name, output_file_name = get_in_and_out() | ||
pipeline = Pipeline( | ||
VectorInputStage(input_file_name), | ||
[ | ||
VectorVectorStage(vector_model_name) | ||
], | ||
PandasOutputStage(output_file_name) | ||
) | ||
pipeline.run() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
from datasets import Dataset, DatasetDict | ||
from pandas import DataFrame | ||
from pipeline import InnerStage | ||
from sentence_transformers import SentenceTransformer | ||
|
||
import numpy | ||
import torch | ||
|
||
class VectorVectorStage(InnerStage): | ||
def __init__(self, model_name: str) -> None: | ||
super().__init__() | ||
self.sentence_transformer = SentenceTransformer(model_name) | ||
|
||
def encode(self, index, sentence): | ||
print(index) | ||
vector = self.sentence_transformer.encode(sentence) | ||
vector_strings = [str(value) for value in vector] | ||
vector_string = ", ".join(vector_strings) | ||
return vector_string | ||
|
||
def mk_vectors(self, data_frame: DataFrame): | ||
vectors = [self.encode(index, sentence) for index, sentence in enumerate(data_frame["sentence"])] | ||
return vectors | ||
|
||
def run(self, data_frame: DataFrame) -> DataFrame: | ||
vectors = self.mk_vectors(data_frame) | ||
data_frame["vector"] = vectors | ||
return data_frame |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters