Skip to content

Commit

Permalink
Merge pull request #294 from NatLibFi/revert-additional-vw_ensemble-f…
Browse files Browse the repository at this point in the history
…eatures

Revert PR #288 additional features for vw_ensemble
  • Loading branch information
osma authored Jul 5, 2019
2 parents a9bb0d7 + 8f5fcff commit b13f02b
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 35 deletions.
41 changes: 9 additions & 32 deletions annif/backend/vw_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,6 @@ class VWEnsembleBackend(
# will make it more careful so that it will require more training data.
DEFAULT_DISCOUNT_RATE = 0.01

# score threshold for "zero features": scores lower than this will be
# considered zero and marked with a zero feature given to VW
ZERO_THRESHOLD = 0.001

def _load_subject_freq(self):
path = os.path.join(self.datadir, self.FREQ_FILE)
if not os.path.exists(path):
Expand Down Expand Up @@ -98,30 +94,17 @@ def _source_project_ids(self):
sources = annif.util.parse_sources(self.params['sources'])
return [project_id for project_id, _ in sources]

@staticmethod
def _format_value(true):
def _format_example(self, subject_id, scores, true=None):
if true is None:
return ''
val = ''
elif true:
return 1
val = 1
else:
return -1

def _format_example(self, subject_id, scores, true=None):
features = " ".join(["{}:{:.6f}".format(proj, scores[proj_idx])
for proj_idx, proj
in enumerate(self._source_project_ids)])
zero_features = " ".join(["zero^{}".format(proj)
for proj_idx, proj
in enumerate(self._source_project_ids)
if scores[proj_idx] < self.ZERO_THRESHOLD])
return "{} |raw {} {} |{} {} {}".format(
self._format_value(true),
features,
zero_features,
subject_id,
features,
zero_features)
val = -1
ex = "{} |{}".format(val, subject_id)
for proj_idx, proj in enumerate(self._source_project_ids):
ex += " {}:{:.6f}".format(proj, scores[proj_idx])
return ex

def _doc_score_vector(self, doc, source_projects):
score_vectors = []
Expand All @@ -136,8 +119,7 @@ def _doc_to_example(self, doc, project, source_projects):
true = subjects.as_vector(project.subjects)
score_vector = self._doc_score_vector(doc, source_projects)
for subj_id in range(len(true)):
if true[subj_id] \
or score_vector[:, subj_id].sum() >= self.ZERO_THRESHOLD:
if true[subj_id] or score_vector[:, subj_id].sum() > 0.0:
ex = (subj_id, self._format_example(
subj_id,
score_vector[:, subj_id],
Expand All @@ -154,11 +136,6 @@ def _create_examples(self, corpus, project):
random.shuffle(examples)
return examples

def _create_model(self, project):
# add interactions between raw (descriptor-invariant) features to
# the mix
super()._create_model(project, {'q': 'rr'})

@staticmethod
def _write_freq_file(subject_freq, filename):
with open(filename, 'w') as freqfile:
Expand Down
5 changes: 2 additions & 3 deletions tests/test_backend_vw_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ def test_vw_ensemble_format_example(datadir):
datadir=str(datadir))

ex = vw_ensemble._format_example(0, [0.5])
assert ex == ' |raw dummy-en:0.500000 |0 dummy-en:0.500000 '
assert ex == ' |0 dummy-en:0.500000'


def test_vw_ensemble_format_example_avoid_sci_notation(datadir):
Expand All @@ -137,5 +137,4 @@ def test_vw_ensemble_format_example_avoid_sci_notation(datadir):
datadir=str(datadir))

ex = vw_ensemble._format_example(0, [7.24e-05])
assert ex == ' |raw dummy-en:0.000072 zero^dummy-en' + \
' |0 dummy-en:0.000072 zero^dummy-en'
assert ex == ' |0 dummy-en:0.000072'

0 comments on commit b13f02b

Please sign in to comment.