Skip to content

Commit

Permalink
optional support for doing perfect hashing of feature strings to save…
Browse files Browse the repository at this point in the history
… lots of memory
  • Loading branch information
Chris Dyer committed Sep 13, 2011
1 parent c41704e commit e7993fb
Show file tree
Hide file tree
Showing 9 changed files with 269 additions and 53 deletions.
22 changes: 20 additions & 2 deletions decoder/decoder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -328,6 +328,7 @@ struct DecoderImpl {
bool write_gradient; // TODO Observer
bool feature_expectations; // TODO Observer
bool output_training_vector; // TODO Observer
bool remove_intersected_rule_annotations;

static void ConvertSV(const SparseVector<prob_t>& src, SparseVector<double>* trg) {
for (SparseVector<prob_t>::const_iterator it = src.begin(); it != src.end(); ++it)
Expand Down Expand Up @@ -361,6 +362,9 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream
("grammar,g",po::value<vector<string> >()->composing(),"Either SCFG grammar file(s) or phrase tables file(s)")
("per_sentence_grammar_file", po::value<string>(), "Optional (and possibly not implemented) per sentence grammar file enables all per sentence grammars to be stored in a single large file and accessed by offset")
("list_feature_functions,L","List available feature functions")
#ifdef HAVE_CMPH
("cmph_perfect_feature_hash,h", po::value<string>(), "Load perfect hash function for features")
#endif

("weights,w",po::value<string>(),"Feature weights file (initial forest / pass 1)")
("feature_function,F",po::value<vector<string> >()->composing(), "Pass 1 additional feature function(s) (-L for list)")
Expand Down Expand Up @@ -433,7 +437,8 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream
("feature_expectations","Write feature expectations for all features in chart (**OBJ** will be the partition)")
("vector_format",po::value<string>()->default_value("b64"), "Sparse vector serialization format for feature expectations or gradients, includes (text or b64)")
("combine_size,C",po::value<int>()->default_value(1), "When option -G is used, process this many sentence pairs before writing the gradient (1=emit after every sentence pair)")
("forest_output,O",po::value<string>(),"Directory to write forests to");
("forest_output,O",po::value<string>(),"Directory to write forests to")
("remove_intersected_rule_annotations", "After forced decoding is completed, remove nonterminal annotations (i.e., the source side spans)");

// ob.AddOptions(&opts);
#ifdef FSA_RESCORING
Expand All @@ -443,7 +448,7 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream
po::options_description clo("Command line options");
clo.add_options()
("config,c", po::value<vector<string> >(&cfg_files), "Configuration file(s) - latest has priority")
("help,h", "Print this help message and exit")
("help,?", "Print this help message and exit")
("usage,u", po::value<string>(), "Describe a feature function type")
("compgen", "Print just option names suitable for bash command line completion builtin 'compgen'")
;
Expand Down Expand Up @@ -645,6 +650,12 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream
FD::Freeze(); // this means we can't see the feature names of not-weighted features
}

if (conf.count("cmph_perfect_feature_hash")) {
cerr << "Loading perfect hash function from " << conf["cmph_perfect_feature_hash"].as<string>() << " ...\n";
FD::EnableHash(conf["cmph_perfect_feature_hash"].as<string>());
cerr << " " << FD::NumFeats() << " features in map\n";
}

// set up translation back end
if (formalism == "scfg")
translator.reset(new SCFGTranslator(conf));
Expand Down Expand Up @@ -695,6 +706,7 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream
unique_kbest = conf.count("unique_k_best");
get_oracle_forest = conf.count("get_oracle_forest");
oracle.show_derivation=conf.count("show_derivations");
remove_intersected_rule_annotations = conf.count("remove_intersected_rule_annotations");

#ifdef FSA_RESCORING
cfg_options.Validate();
Expand Down Expand Up @@ -1010,6 +1022,12 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) {
// if (!SILENT) cerr << " USING UNIFORM WEIGHTS\n";
// for (int i = 0; i < forest.edges_.size(); ++i)
// forest.edges_[i].edge_prob_=prob_t::One(); }
if (remove_intersected_rule_annotations) {
for (unsigned i = 0; i < forest.edges_.size(); ++i)
if (forest.edges_[i].rule_ &&
forest.edges_[i].rule_->parent_rule_)
forest.edges_[i].rule_ = forest.edges_[i].rule_->parent_rule_;
}
forest.Reweight(last_weights);
if (!SILENT) forest_stats(forest," Constr. forest",show_tree_structure,oracle.show_derivation);
if (!SILENT) cerr << " Constr. VitTree: " << ViterbiFTree(forest) << endl;
Expand Down
9 changes: 7 additions & 2 deletions utils/Makefile.am
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
noinst_PROGRAMS = ts
TESTS = ts
noinst_PROGRAMS = ts phmt
TESTS = ts phmt

if HAVE_GTEST
noinst_PROGRAMS += \
Expand Down Expand Up @@ -27,6 +27,11 @@ libutils_a_SOURCES = \
verbose.cc \
weights.cc

if HAVE_CMPH
libutils_a_SOURCES += perfect_hash.cc
endif

phmt_SOURCES = phmt.cc
ts_SOURCES = ts.cc
dict_test_SOURCES = dict_test.cc
dict_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS)
Expand Down
4 changes: 4 additions & 0 deletions utils/fdict.cc
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,10 @@ using namespace std;
Dict FD::dict_;
bool FD::frozen_ = false;

#ifdef HAVE_CMPH
PerfectHashFunction* FD::hash_ = NULL;
#endif

std::string FD::Convert(std::vector<WordID> const& v) {
return Convert(&*v.begin(),&*v.end());
}
Expand Down
36 changes: 36 additions & 0 deletions utils/fdict.h
Original file line number Diff line number Diff line change
@@ -1,23 +1,56 @@
#ifndef _FDICT_H_
#define _FDICT_H_

#include "config.h"

#include <iostream>
#include <string>
#include <vector>
#include "dict.h"

#ifdef HAVE_CMPH
#include "perfect_hash.h"
#include "string_to.h"
#endif

struct FD {
// once the FD is frozen, new features not already in the
// dictionary will return 0
static void Freeze() {
frozen_ = true;
}
static bool UsingPerfectHashFunction() {
#ifdef HAVE_CMPH
return hash_;
#else
return false;
#endif
}
static void EnableHash(const std::string& cmph_file) {
#ifdef HAVE_CMPH
hash_ = new PerfectHashFunction(cmph_file);
#endif
}
static inline int NumFeats() {
#ifdef HAVE_CMPH
if (hash_) return hash_->number_of_keys();
#endif
return dict_.max() + 1;
}
static inline WordID Convert(const std::string& s) {
#ifdef HAVE_CMPH
if (hash_) return (*hash_)(s);
#endif
return dict_.Convert(s, frozen_);
}
static inline const std::string& Convert(const WordID& w) {
#ifdef HAVE_CMPH
if (hash_) {
static std::string tls;
tls = to_string(w);
return tls;
}
#endif
return dict_.Convert(w);
}
static std::string Convert(WordID const *i,WordID const* e);
Expand All @@ -29,6 +62,9 @@ struct FD {
static Dict dict_;
private:
static bool frozen_;
#ifdef HAVE_CMPH
static PerfectHashFunction* hash_;
#endif
};

#endif
37 changes: 37 additions & 0 deletions utils/perfect_hash.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
#include "config.h"

#ifdef HAVE_CMPH

#include "perfect_hash.h"

#include <cstdio>
#include <iostream>

using namespace std;

PerfectHashFunction::~PerfectHashFunction() {
cmph_destroy(mphf_);
}

PerfectHashFunction::PerfectHashFunction(const string& fname) {
FILE* f = fopen(fname.c_str(), "r");
if (!f) {
cerr << "Failed to open file " << fname << " for reading: cannot load hash function.\n";
abort();
}
mphf_ = cmph_load(f);
if (!mphf_) {
cerr << "cmph_load failed on " << fname << "!\n";
abort();
}
}

size_t PerfectHashFunction::operator()(const string& key) const {
return cmph_search(mphf_, &key[0], key.size());
}

size_t PerfectHashFunction::number_of_keys() const {
return cmph_size(mphf_);
}

#endif
24 changes: 24 additions & 0 deletions utils/perfect_hash.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
#ifndef _PERFECT_HASH_MAP_H_
#define _PERFECT_HASH_MAP_H_

#include "config.h"

#ifndef HAVE_CMPH
#error libcmph is required to use PerfectHashFunction
#endif

#include <vector>
#include <boost/utility.hpp>
#include "cmph.h"

class PerfectHashFunction : boost::noncopyable {
public:
explicit PerfectHashFunction(const std::string& fname);
~PerfectHashFunction();
size_t operator()(const std::string& key) const;
size_t number_of_keys() const;
private:
cmph_t *mphf_;
};

#endif
44 changes: 44 additions & 0 deletions utils/phmt.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
#include "config.h"

#ifndef HAVE_CMPH
int main() {
return 0;
}
#else

#include <iostream>
#include "weights.h"
#include "fdict.h"

using namespace std;

int main(int argc, char** argv) {
if (argc != 2) { cerr << "Usage: " << argv[0] << " file.mphf\n"; return 1; }
FD::EnableHash(argv[1]);
cerr << "Number of keys: " << FD::NumFeats() << endl;
cerr << "LexFE = " << FD::Convert("LexFE") << endl;
cerr << "LexEF = " << FD::Convert("LexEF") << endl;
{
Weights w;
vector<weight_t> v(FD::NumFeats());
v[FD::Convert("LexFE")] = 1.0;
v[FD::Convert("LexEF")] = 0.5;
w.InitFromVector(v);
cerr << "Writing...\n";
w.WriteToFile("weights.bin");
cerr << "Done.\n";
}
{
Weights w;
vector<weight_t> v(FD::NumFeats());
cerr << "Reading...\n";
w.InitFromFile("weights.bin");
cerr << "Done.\n";
w.InitVector(&v);
assert(v[FD::Convert("LexFE")] == 1.0);
assert(v[FD::Convert("LexEF")] == 0.5);
}
}

#endif

Loading

0 comments on commit e7993fb

Please sign in to comment.