Showing 10 changed files with 743 additions and 259 deletions.
@@ -0,0 +1,105 @@
import numpy as np
import re
from nltk.corpus import stopwords

from .utils import cosine_similarities

re_ws = re.compile(r'\s+')
# Despite the name, this matches punctuation and symbols (anything other
# than word characters, whitespace and apostrophes), not numbers.
re_num = re.compile(r'[^\w\s\']', flags=re.UNICODE)
THRESHOLD = 0.1
# Relative weights applied to a term's embedding when building topic
# vectors, keyed by the term's origin code ('WC' and 'WSSC' mark terms
# discovered from Elasticsearch in extract.py; the others appear to come
# from the taxonomy itself).
WEIGHTING = {
    'C': 2,
    'SC': 2,
    'SSC': 1,
    'WC': 3,
    'WSSC': 5
}
STOPWORDS_LANGUAGE = 'english'
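cosine_similarities is imported from a .utils module that is not part of this diff. The code below treats it as taking one query vector and a stack of candidate vectors, returning an array of cosine similarities; a minimal sketch under that assumption:

```python
import numpy as np

def cosine_similarities(vec, vecs):
    # Cosine similarity of one query vector against each row of `vecs`.
    vecs = np.asarray(vecs, dtype=np.float32)
    norms = np.linalg.norm(vecs, axis=1) * np.linalg.norm(vec)
    return (vecs @ vec) / np.where(norms == 0.0, 1.0, norms)
```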

class WModel:
    def __init__(self, model):
        self.model = model
        self._model_dims = model.get_dims()

    def __getitem__(self, name):
        # Allocate the buffer on the Python side, so the Rust extension
        # does not have to worry about the numpy array's lifetime.
        a = np.zeros((self._model_dims,), dtype=np.float32)

        self.model.load_embedding(name, a)

        return a
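For illustration, WModel wraps the model handle as a read-only mapping from words to vectors. Assuming FfModel (the Rust extension imported in the next file) exposes get_dims and load_embedding as used above:

```python
from ff_fasttext._ff_fasttext import FfModel  # Rust extension, per extract.py

model = FfModel('test_data/wiki.en.fifu')  # model path used elsewhere in this commit
wm = WModel(model)
vec = wm['inflation']  # float32 vector of length model.get_dims()
```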

class CategoryManager:
    _stop_words = None
    _model = None
    _classifier_bow = None
    _topic_vectors = None

    def __init__(self, word_model):
        self._categories = {}
        self._model = WModel(word_model)
        self._stop_words = stopwords.words(STOPWORDS_LANGUAGE)

    def add_categories_from_bow(self, name, classifier_bow):
        # For each category, build a topic vector: the mean of its terms'
        # weighted embeddings, keeping the words alongside (see the toy
        # example after this class).
        topic_vectors = [
            (np.mean([WEIGHTING[code] * self._model[w] for code, w in words], axis=0),
             [w for _, w in words])
            for words in classifier_bow.values()
        ]
        self._categories[name] = (classifier_bow, topic_vectors)
    def closest(self, text, cat, classifier_bow_vec):
        if cat not in classifier_bow_vec:
            return []

        word_list = set(sum(self.strip_document(text), []))
        # Score each word in the document by its mean cosine similarity to
        # the category's bag-of-words vectors.
        # TODO: double check model.embedding_similarities(
        #     cm._model[word], get_cat_bow(cat)
        # )
        word_scores = [
            (word, cosine_similarities(self._model[word], classifier_bow_vec[cat]).mean())
            for word in word_list
        ]
        return [
            word for word, score in sorted(word_scores, key=lambda ws: ws[1], reverse=True)
            if score > 0.5
        ]
    def strip_document(self, doc):
        # Normalise a document (or list of strings) into lists of lowercase,
        # punctuation-free, non-stopword words, one list per comma-separated
        # segment, e.g. 'Retail Sales Index, latest figures (2023)' ->
        # [['retail', 'sales', 'index'], ['latest', 'figures', '2023']].
        if isinstance(doc, list):
            doc = ' '.join(doc)

        docs = doc.split(',')
        word_list = []
        for doc in docs:
            doc = doc.replace('\n', ' ').replace('_', ' ').replace('\'', '').lower()
            doc = re_ws.sub(' ', re_num.sub('', doc)).strip()

            if doc == '':
                # Skip empty segments (e.g. trailing commas).
                continue

            word_list.append([w for w in doc.split(' ') if w not in self._stop_words])

        return word_list
    def test(self, sentence, category_group='dtcats'):
        classifier_bow, topic_vectors = self._categories[category_group]

        clean = self.strip_document(sentence)

        if not clean:
            return []

        # dict_keys is not indexable in Python 3, so take a list up front.
        category_keys = list(classifier_bow)
        tags = set()
        for words in clean:
            if not words:
                continue

            # Mean embedding of the phrase, compared against every topic vector.
            vec = np.mean([self._model[w] for w in words], axis=0)
            result = cosine_similarities(vec, [t for t, _ in topic_vectors])

            top = np.nonzero(result > THRESHOLD)[0]

            tags.update({(result[i], category_keys[i]) for i in top})

        # (score, category key) pairs, best first.
        return sorted(tags, reverse=True)
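As a toy illustration of the weighting in add_categories_from_bow, with two-dimensional stand-in embeddings in place of the model (words and values hypothetical):

```python
import numpy as np

WEIGHTING = {'C': 2, 'WC': 3}  # subset of the weights defined above
embeddings = {'price': np.array([1.0, 0.0]), 'inflation': np.array([0.0, 1.0])}

bow = [('C', 'price'), ('WC', 'inflation')]
topic_vector = np.mean([WEIGHTING[code] * embeddings[w] for code, w in bow], axis=0)
print(topic_vector)  # [1.  1.5] -- the mean of the weighted embeddings
```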

@@ -0,0 +1,78 @@
import os
from collections import Counter

from elasticsearch2 import Elasticsearch
from elasticsearch_dsl import Search, Q
from nltk import download

from ._ff_fasttext import FfModel
from .category_manager import CategoryManager
from .taxonomy import get_taxonomy, taxonomy_to_categories, categories_to_classifier_bow

# A discovered term is kept for a full (three-part) category only if it
# occurs in more than this many of that category's datasets...
APPEARANCE_THRESHOLD = 5
# ...and for a broader two-part category only above this higher bar.
UPPER_APPEARANCE_THRESHOLD = 10
HOST = os.getenv('ELASTICSEARCH_HOST', 'http://localhost:9200')
ELASTICSEARCH_INDEX = os.getenv('ELASTICSEARCH_INDEX', 'ons1639492069322')

def get_datasets(cm, classifier_bow):
    # Embedding vectors for every term in each category's bag of words.
    classifier_bow_vec = {
        k: [cm._model[w[1]] for w in words]
        for k, words in classifier_bow.items()
    }
    datasets = {}
    # results_df = pd.DataFrame((d.to_dict() for d in s.scan()))
    # /businesseconomy../business/activitiespeopel/123745
    client = Elasticsearch([HOST])

    s = Search(using=client, index=ELASTICSEARCH_INDEX) \
        .filter('bool', must=[Q('exists', field='description.title')])
    for hit in s.scan():
        try:
            datasets[hit.description.title] = {
                # The first three URI segments give the taxonomy category.
                'category': tuple(hit.uri.split('/')[1:4]),
                'text': f'{hit.description.title} {hit.description.metaDescription}'
            }
            datasets[hit.description.title]['bow'] = cm.closest(
                datasets[hit.description.title]['text'],
                datasets[hit.description.title]['category'],
                classifier_bow_vec
            )
        except AttributeError:
            # Skip hits missing any of the fields accessed above.
            pass
    return datasets
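The fields consumed by the scan above suggest documents shaped roughly like the following (hypothetical record: field names are taken from the code, values invented to match the URI pattern in the dev comment above):

```python
hit = {
    'uri': '/economy/inflationandpriceindices/consumerpriceinflation/datasets/123745',
    'description': {
        'title': 'Consumer price inflation tables',
        'metaDescription': 'Monthly measures of UK consumer price inflation.',
    },
}

category = tuple(hit['uri'].split('/')[1:4])
# ('economy', 'inflationandpriceindices', 'consumerpriceinflation')
```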

def discover_terms(datasets, classifier_bow):
    discovered_terms = {}
    # Could do with lemmatizing.
    # Count, for each category (both the full three-part key and its
    # two-part prefix), how many datasets contain each closest-word.
    for ds in datasets.values():
        if ds['category'][0:2] not in discovered_terms:
            discovered_terms[ds['category'][0:2]] = Counter()
        discovered_terms[ds['category'][0:2]].update(set(ds['bow']))
        if ds['category'] not in discovered_terms:
            discovered_terms[ds['category']] = Counter()
        discovered_terms[ds['category']].update(set(ds['bow']))

    # Keep only terms that recur across enough datasets; broader two-part
    # categories demand the higher threshold.
    discovered_terms = {
        k: [w for w, c in count.items()
            if c > (APPEARANCE_THRESHOLD if len(k) > 2 else UPPER_APPEARANCE_THRESHOLD)]
        for k, count in discovered_terms.items()
    }
    # Fold the surviving terms back into the classifier's bags of words,
    # in place: 'WSSC' for full categories, 'WC' for two-part ones.
    for key, terms in classifier_bow.items():
        if key in discovered_terms:
            terms += [('WSSC', w) for w in discovered_terms[key]]
        if key[0:2] in discovered_terms:
            terms += [('WC', w) for w in discovered_terms[key[0:2]]]
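A concrete illustration of the thresholding (hypothetical category keys and counts):

```python
from collections import Counter

counts = {
    ('economy', 'inflationandpriceindices', 'consumerpriceinflation'):
        Counter({'prices': 7, 'basket': 3}),
    ('economy', 'inflationandpriceindices'):
        Counter({'prices': 12, 'basket': 8}),
}
kept = {
    k: [w for w, c in count.items()
        if c > (5 if len(k) > 2 else 10)]  # APPEARANCE_THRESHOLD vs UPPER_APPEARANCE_THRESHOLD
    for k, count in counts.items()
}
# kept == {('economy', 'inflationandpriceindices', 'consumerpriceinflation'): ['prices'],
#          ('economy', 'inflationandpriceindices'): ['prices']}
```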

def append_discovered_terms_from_elasticsearch(cm, classifier_bow):
    datasets = get_datasets(cm, classifier_bow)
    discover_terms(datasets, classifier_bow)


def load(model_file):
    model = FfModel(model_file)
    # Download the NLTK stopwords list (a no-op if already present).
    download('stopwords')

    category_manager = CategoryManager(model)

    # Build the category bags of words from the taxonomy, enrich them with
    # terms discovered in Elasticsearch, then register them as 'onyxcats'.
    taxonomy = get_taxonomy()
    categories = taxonomy_to_categories(taxonomy)

    classifier_bow = categories_to_classifier_bow(category_manager.strip_document, categories)
    append_discovered_terms_from_elasticsearch(category_manager, classifier_bow)
    category_manager.add_categories_from_bow('onyxcats', classifier_bow)

    return category_manager
Empty file.

@@ -0,0 +1,13 @@
import click

from ff_fasttext.extract import load


@click.command()
def main():
    category_manager = load('test_data/wiki.en.fifu')
    while True:
        word = input("Sentence? ")
        if word in ('\\quit', '\\q'):
            break
        categories = category_manager.test(word.strip(), 'onyxcats')
        # Show up to five matches scoring over 0.3, rendered as
        # "level1->level2->level3(score)".
        categories = ['->'.join(c[1]) + f'({c[0]:.2f})' for c in categories if c[0] > 0.3][:5]
        print('\n'.join(categories))

@@ -0,0 +1,19 @@
from fastapi import FastAPI

from .extract import load

THRESHOLD = 0.4


def make_app(category_manager):
    app = FastAPI()

    @app.get('/categories')
    def get_categories(query: str):
        categories = category_manager.test(query.strip(), 'onyxcats')
        # test() returns (score, category key) pairs; filter on the score.
        return [
            c for c in categories if c[0] > THRESHOLD
        ]

    return app


category_manager = load('test_data/wiki.en.fifu')
app = make_app(category_manager)
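A quick way to exercise the endpoint, assuming this module lives at ff_fasttext/api.py (the file path is not shown in this diff) and the model file is available locally:

```python
from fastapi.testclient import TestClient

from ff_fasttext.api import app  # loads the model at import time

client = TestClient(app)
response = client.get('/categories', params={'query': 'consumer price inflation'})
# Tuples serialise to JSON arrays, e.g.
# [[0.62, ['economy', 'inflationandpriceindices', 'consumerpriceinflation']]]
# (illustrative values).
print(response.json())
```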