Skip to content

Commit

Permalink
adaptive hope-fear learner
Browse files Browse the repository at this point in the history
  • Loading branch information
redpony committed Feb 10, 2014
1 parent 3798fb9 commit 31b5d03
Show file tree
Hide file tree
Showing 11 changed files with 287 additions and 8 deletions.
1 change: 1 addition & 0 deletions mteval/ns.h
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ inline const SufficientStats operator-(const SufficientStats& a, const Sufficien
// Abstract interface for scoring a single segment: implementations fill in
// the sufficient statistics for one hypothesis.
struct SegmentEvaluator {
virtual ~SegmentEvaluator();
// Evaluates the hypothesis `hyp` (a sequence of word ids) and stores the
// resulting statistics in `*out`.
virtual void Evaluate(const std::vector<WordID>& hyp, SufficientStats* out) const = 0;
// Source-side text for this segment. It is filled in by the DocumentScorer
// constructor that reads a combined "SRC ||| REF ..." file.
std::string src; // this may not always be available
};

// Instructions for implementing a new metric
Expand Down
34 changes: 34 additions & 0 deletions mteval/ns_docscorer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,40 @@ DocumentScorer::~DocumentScorer() {}

DocumentScorer::DocumentScorer() {}

// Builds one segment evaluator per line of a combined source/reference file.
// Every line must have the form
//   SRC ||| REF [||| REF2 ||| REF3 ...]
// The text before the first " ||| " is kept as the evaluator's `src` string;
// everything after it is converted to word ids and split on the "|||" token
// into the individual references. The process aborts on the first line that
// does not contain the " ||| " separator.
DocumentScorer::DocumentScorer(const EvaluationMetric* metric,
                               const string& src_ref_file) {
  const WordID kDIV = TD::Convert("|||");
  assert(!src_ref_file.empty());
  cerr << "Loading source and references from " << src_ref_file << "...\n";
  const string kSeparator(" ||| ");
  ReadFile rf(src_ref_file);
  istream& in = *rf.stream();
  string line;
  vector<WordID> ref_tokens;
  vector<vector<WordID> > references;
  for (unsigned line_no = 1; getline(in, line); ++line_no) {
    const size_t src_len = line.find(kSeparator);
    if (src_len == string::npos) {
      cerr << "Expected SRC ||| REF [||| REF2 ||| REF3 ...] in line " << line_no << endl;
      abort();
    }
    references.clear();
    ref_tokens.clear();
    // Convert only the part of the line that follows the first separator.
    TD::ConvertSentence(line, &ref_tokens, src_len + kSeparator.size());
    // Cut the id sequence at every "|||" token; each piece is one reference.
    vector<WordID>::const_iterator ref_begin = ref_tokens.begin();
    for (vector<WordID>::const_iterator it = ref_tokens.begin(); it != ref_tokens.end(); ++it) {
      if (*it != kDIV) continue;
      references.push_back(vector<WordID>(ref_begin, it));
      ref_begin = it + 1;
    }
    // The last reference runs to the end of the line.
    references.push_back(vector<WordID>(ref_begin, ref_tokens.end()));
    scorers_.push_back(metric->CreateSegmentEvaluator(references));
    scorers_.back()->src = line.substr(0, src_len);
  }
}

void DocumentScorer::Init(const EvaluationMetric* metric,
const vector<string>& ref_files,
const string& src_file,
Expand Down
9 changes: 9 additions & 0 deletions mteval/scorer_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#include <boost/test/unit_test.hpp>
#include <boost/test/floating_point_comparison.hpp>

#include "ns_docscorer.h"
#include "ns.h"
#include "tdict.h"
#include "scorer.h"
Expand Down Expand Up @@ -223,4 +224,12 @@ BOOST_AUTO_TEST_CASE(NewScoreAPI) {
//cerr << metric->ComputeScore(statse) << endl;
}

// Verifies that DocumentScorer can load the combined "SRC ||| REF ||| ..."
// dev set format: one scorer per line, with the source side of each line
// kept on the scorer.
BOOST_AUTO_TEST_CASE(HybridSourceReferenceFileFormat) {
  // The test data directory defaults to TEST_DATA and can be overridden by
  // passing exactly one extra command line argument.
  std::string data_dir(TEST_DATA);
  if (boost::unit_test::framework::master_test_suite().argc == 2)
    data_dir = boost::unit_test::framework::master_test_suite().argv[1];
  const std::string devset_file = data_dir + "/devset.txt";
  EvaluationMetric* metric = EvaluationMetric::Instance("IBM_BLEU");
  DocumentScorer ds(metric, devset_file);
  BOOST_CHECK_EQUAL(2, ds.size());
  BOOST_CHECK_EQUAL("Quelltext hier .", ds[0]->src);
}

BOOST_AUTO_TEST_SUITE_END()
2 changes: 2 additions & 0 deletions mteval/test_data/devset.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Quelltext hier . ||| source text here . ||| original text . ||| some source text .
ein anderer Satz . ||| another sentence . ||| a different sentence .
13 changes: 9 additions & 4 deletions training/mira/Makefile.am
Original file line number Diff line number Diff line change
@@ -1,15 +1,20 @@
bin_PROGRAMS = kbest_mira \
kbest_cut_mira
bin_PROGRAMS = \
kbest_mira \
kbest_cut_mira \
ada_opt_sm

EXTRA_DIST = mira.py

# ada_opt_sm: the adaptive hope-fear learner. Unlike the other programs built
# here it also links ../utils/libtraining_utils.a.
ada_opt_sm_SOURCES = ada_opt_sm.cc
ada_opt_sm_LDFLAGS= -rdynamic
ada_opt_sm_LDADD = ../utils/libtraining_utils.a ../../decoder/libcdec.a ../../klm/search/libksearch.a ../../mteval/libmteval.a ../../utils/libutils.a ../../klm/lm/libklm.a ../../klm/util/libklm_util.a ../../klm/util/double-conversion/libklm_util_double.a

kbest_mira_SOURCES = kbest_mira.cc
kbest_mira_LDFLAGS= -rdynamic
kbest_mira_LDADD = ../../decoder/libcdec.a ../../klm/search/libksearch.a ../../mteval/libmteval.a ../../utils/libutils.a ../../klm/lm/libklm.a ../../klm/util/libklm_util.a ../../klm/util/double-conversion/libklm_util_double.a


kbest_cut_mira_SOURCES = kbest_cut_mira.cc
kbest_cut_mira_LDFLAGS= -rdynamic
kbest_cut_mira_LDADD = ../../decoder/libcdec.a ../../klm/search/libksearch.a ../../mteval/libmteval.a ../../utils/libutils.a ../../klm/lm/libklm.a ../../klm/util/libklm_util.a ../../klm/util/double-conversion/libklm_util_double.a

AM_CPPFLAGS = -W -Wall -Wno-sign-compare -I$(top_srcdir)/utils -I$(top_srcdir)/decoder -I$(top_srcdir)/mteval
AM_CPPFLAGS = -W -Wall -Wno-sign-compare -I$(top_srcdir)/utils -I$(top_srcdir)/decoder -I$(top_srcdir)/mteval -I$(top_srcdir)/training/utils
198 changes: 198 additions & 0 deletions training/mira/ada_opt_sm.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,198 @@
#include "config.h"

#include <boost/container/flat_map.hpp>
#include <boost/shared_ptr.hpp>
#include <boost/program_options.hpp>
#include <boost/program_options/variables_map.hpp>

#include "filelib.h"
#include "stringlib.h"
#include "weights.h"
#include "sparse_vector.h"
#include "candidate_set.h"
#include "sentence_metadata.h"
#include "ns.h"
#include "ns_docscorer.h"
#include "verbose.h"
#include "hg.h"
#include "ff_register.h"
#include "decoder.h"
#include "fdict.h"
#include "sampler.h"

using namespace std;
namespace po = boost::program_options;

// Random number generator; created in main(), seeded from --random_seed when
// that option is given.
boost::shared_ptr<MT19937> rng;
// Accumulated k-best candidate sets, indexed by sentence id; main() resizes
// this to the number of dev set segments.
vector<training::CandidateSet> kbests;
// Per-feature state of the learner, updated in main():
//   G       - running sum of squared gradient values
//   u       - running sum of gradient values
//   lambdas - current feature weights
SparseVector<weight_t> G, u, lambdas;
// Factor by which the pseudo-document statistics are scaled each time a new
// translation forest is processed.
double pseudo_doc_decay = 0.9;

bool InitCommandLine(int argc, char** argv, po::variables_map* conf) {
po::options_description opts("Configuration options");
opts.add_options()
("decoder_config,c",po::value<string>(),"[REQ] Decoder configuration file")
("devset,d",po::value<string>(),"[REQ] Source/reference development set")
("weights,w",po::value<string>(),"Initial feature weights file")
("mt_metric,m",po::value<string>()->default_value("ibm_bleu"), "Scoring metric (ibm_bleu, nist_bleu, koehn_bleu, ter, combi)")
("size",po::value<unsigned>()->default_value(0), "Process rank (for multiprocess mode)")
("rank",po::value<unsigned>()->default_value(1), "Number of processes (for multiprocess mode)")
("optimizer,o",po::value<unsigned>()->default_value(1), "Optimizer (Adaptive MIRA=1)")
("fear,f",po::value<unsigned>()->default_value(1), "Fear selection (model-cost=1, maxcost=2, maxscore=3)")
("hope,h",po::value<unsigned>()->default_value(1), "Hope selection (model+cost=1, mincost=2)")
("eta0", po::value<double>()->default_value(0.1), "Initial step size")
("random_seed,S", po::value<uint32_t>(), "Random seed (if not specified, /dev/random will be used)")
("mt_metric_scale,s", po::value<double>()->default_value(1.0), "Scale MT loss function by this amount")
("pseudo_doc,e", "Use pseudo-documents for approximate scoring")
("k_best_size,k", po::value<unsigned>()->default_value(500), "Size of hypothesis list to search for oracles");
po::options_description clo("Command line options");
clo.add_options()
("config", po::value<string>(), "Configuration file")
("help,H", "Print this help message and exit");
po::options_description dconfig_options, dcmdline_options;
dconfig_options.add(opts);
dcmdline_options.add(opts).add(clo);

po::store(parse_command_line(argc, argv, dcmdline_options), *conf);
if (conf->count("config")) {
ifstream config((*conf)["config"].as<string>().c_str());
po::store(po::parse_config_file(config, dconfig_options), *conf);
}
po::notify(*conf);

if (conf->count("help")
|| !conf->count("decoder_config")
|| !conf->count("devset")) {
cerr << dcmdline_options << endl;
return false;
}
return true;
}

struct TrainingObserver : public DecoderObserver {
explicit TrainingObserver(const EvaluationMetric& m, const int k) : metric(m), kbest_size(k), cur_eval() {}

const EvaluationMetric& metric;
const int kbest_size;
const SegmentEvaluator* cur_eval;
SufficientStats pdoc;
unsigned hi, vi, fi; // hope, viterbi, fear

void SetSegmentEvaluator(const SegmentEvaluator* eval) {
cur_eval = eval;
}

virtual void NotifySourceParseFailure(const SentenceMetadata& smeta) {
cerr << "Failed to translate sentence with ID = " << smeta.GetSentenceID() << endl;
abort();
}

unsigned CostAugmentedDecode(const training::CandidateSet& cs,
const SparseVector<double>& w,
double alpha = 0) {
unsigned best_i = 0;
double best = -numeric_limits<double>::infinity();
for (unsigned i = 0; i < cs.size(); ++i) {
double s = cs[i].fmap.dot(w);
if (alpha)
s += alpha * metric.ComputeScore(cs[i].eval_feats + pdoc);
if (s > best) {
best = s;
best_i = i;
}
}
return best_i;
}

virtual void NotifyTranslationForest(const SentenceMetadata& smeta, Hypergraph* hg) {
pdoc *= pseudo_doc_decay;
const unsigned sent_id = smeta.GetSentenceID();
kbests[sent_id].AddUniqueKBestCandidates(*hg, kbest_size, cur_eval);
vi = CostAugmentedDecode(kbests[sent_id], lambdas);
hi = CostAugmentedDecode(kbests[sent_id], lambdas, 1.0);
fi = CostAugmentedDecode(kbests[sent_id], lambdas, -1.0);
cerr << sent_id << " ||| " << TD::GetString(kbests[sent_id][vi].ewords) << " ||| " << metric.ComputeScore(kbests[sent_id][vi].eval_feats + pdoc) << endl;
pdoc += kbests[sent_id][vi].eval_feats; // update pseudodoc stats
}
};

// Entry point of the adaptive hope-fear learner.
//
// Loads the metric, the combined source/reference dev set and the decoder,
// then repeatedly samples a dev set sentence, decodes it, and, when the
// Viterbi and hope candidates differ, updates the feature weights from the
// (fear - hope) feature difference. The final weights are written to "-".
// Returns 1 on a usage error or an empty dev set, 0 otherwise.
int main(int argc, char** argv) {
  SetSilent(true); // turn off verbose decoder output
  register_feature_functions();

  po::variables_map conf;
  if (!InitCommandLine(argc, argv, &conf)) return 1;

  if (conf.count("random_seed"))
    rng.reset(new MT19937(conf["random_seed"].as<uint32_t>()));
  else
    rng.reset(new MT19937);

  string metric_name = UppercaseString(conf["mt_metric"].as<string>());
  if (metric_name == "COMBI") {
    cerr << "WARNING: 'combi' metric is no longer supported, switching to 'COMB:TER=-0.5;IBM_BLEU=0.5'\n";
    metric_name = "COMB:TER=-0.5;IBM_BLEU=0.5";
  } else if (metric_name == "BLEU") {
    cerr << "WARNING: 'BLEU' is ambiguous, assuming 'IBM_BLEU'\n";
    metric_name = "IBM_BLEU";
  }
  EvaluationMetric* metric = EvaluationMetric::Instance(metric_name);
  DocumentScorer ds(metric, conf["devset"].as<string>());
  cerr << "Loaded " << ds.size() << " references for scoring with " << metric_name << endl;
  if (ds.size() == 0) {
    // Sampling a sentence below would index an empty scorer list.
    cerr << "Development set is empty, nothing to learn from.\n";
    return 1;
  }
  kbests.resize(ds.size());
  // NOTE(review): the eta0 option is parsed but not read here; the step size
  // is this fixed value.
  double eta = 0.001;

  ReadFile ini_rf(conf["decoder_config"].as<string>());
  Decoder decoder(ini_rf.stream());

  vector<weight_t>& dense_weights = decoder.CurrentWeightVector();
  if (conf.count("weights")) {
    Weights::InitFromFile(conf["weights"].as<string>(), &dense_weights);
    Weights::InitSparseVector(dense_weights, &lambdas);
  }

  TrainingObserver observer(*metric, conf["k_best_size"].as<unsigned>());

  unsigned num = 200;
  for (unsigned iter = 1; iter < num; ++iter) {
    // Presumably copies the sparse weights into the decoder's dense weight
    // vector so that decoding uses the latest update - TODO confirm.
    lambdas.init_vector(&dense_weights);
    unsigned sent_id = rng->next() * ds.size();
    // Guard against an index one past the end should the sampled value
    // reach the upper bound.
    if (sent_id >= ds.size()) sent_id = ds.size() - 1;
    cerr << "Learning from sentence id: " << sent_id << endl;
    observer.SetSegmentEvaluator(ds[sent_id]);
    decoder.SetId(sent_id);
    decoder.Decode(ds[sent_id]->src, &observer);
    if (observer.vi != observer.hi) { // viterbi != hope
      SparseVector<double> grad = kbests[sent_id][observer.fi].fmap;
      grad -= kbests[sent_id][observer.hi].fmap;
      cerr << "GRAD: " << grad << endl;
      const SparseVector<double>& g = grad;
      // Accumulate the gradient (u) and squared gradient (G) per feature.
#if HAVE_CXX11 && (__GNUC_MINOR__ > 4 || __GNUC__ > 4)
      for (auto& gi : g) {
#else
      for (SparseVector<double>::const_iterator it = g.begin(); it != g.end(); ++it) {
        const pair<unsigned,double>& gi = *it;
#endif
        if (gi.second) {
          u[gi.first] += gi.second;
          G[gi.first] += gi.second * gi.second;
          lambdas.set_value(gi.first, 1.0); // this is a dummy value to trigger recomputation
        }
      }
      // Recompute every weight present in lambdas from the accumulated
      // statistics: the sign is opposite to that of u, the magnitude is
      // eta * |u| / sqrt(G) (the two factors of iter cancel).
      for (SparseVector<double>::iterator it = lambdas.begin(); it != lambdas.end(); ++it) {
        const pair<unsigned,double>& xi = *it;
        double z = fabs(u[xi.first] / iter) - 0.0;
        double s = 1;
        if (u[xi.first] > 0) s = -1;
        if (z > 0 && G[xi.first]) {
          lambdas.set_value(xi.first, eta * s * z * iter / sqrt(G[xi.first]));
        } else {
          lambdas.set_value(xi.first, 0.0);
        }
      }
    }
  }
  // The dense vector is otherwise only refreshed at the top of each
  // iteration, so without this the update made in the final iteration would
  // be missing from the weights written below.
  lambdas.init_vector(&dense_weights);
  cerr << "Optimization complete.\n";
  Weights::WriteToFile("-", dense_weights, true);
  return 0;
}

15 changes: 15 additions & 0 deletions training/utils/candidate_set.cc
Original file line number Diff line number Diff line change
Expand Up @@ -171,4 +171,19 @@ void CandidateSet::AddKBestCandidates(const Hypergraph& hg, size_t kbest_size, c
Dedup();
}

void CandidateSet::AddUniqueKBestCandidates(const Hypergraph& hg, size_t kbest_size, const SegmentEvaluator* scorer) {
typedef KBest::KBestDerivations<vector<WordID>, ESentenceTraversal, KBest::FilterUnique> K;
K kbest(hg, kbest_size);

for (unsigned i = 0; i < kbest_size; ++i) {
const K::Derivation* d =
kbest.LazyKthBest(hg.nodes_.size() - 1, i);
if (!d) break;
cs.push_back(Candidate(d->yield, d->feature_values));
if (scorer)
scorer->Evaluate(d->yield, &cs.back().eval_feats);
}
Dedup();
}

}
2 changes: 1 addition & 1 deletion training/utils/candidate_set.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ class CandidateSet {
void ReadFromFile(const std::string& file);
void WriteToFile(const std::string& file) const;
void AddKBestCandidates(const Hypergraph& hg, size_t kbest_size, const SegmentEvaluator* scorer = NULL);
// TODO add code to do unique k-best
void AddUniqueKBestCandidates(const Hypergraph& hg, size_t kbest_size, const SegmentEvaluator* scorer = NULL);
// TODO add code to draw k samples

private:
Expand Down
15 changes: 15 additions & 0 deletions utils/stringlib.h
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,21 @@ void VisitTokens(std::string const& s,F f) {
VisitTokens(mp.p,mp.p+s.size(),f);
}

// Calls `f` on each token of `s`, ignoring the first `start` characters.
// Tokenization is delegated to the (begin, end, f) overload of VisitTokens.
// Nothing is visited when `start` is at or past the end of `s`, which
// includes the case of an empty string.
template <class F>
void VisitTokens(std::string const& s,F f, unsigned start) {
  // Without this check an offset larger than the string length would make
  // the begin pointer below point past the end of the buffer.
  if (start >= s.size()) return;
  mutable_c_str mp(s);
  SLIBDBG("mp="<<mp.p);
  VisitTokens(mp.p+start,mp.p+s.size(),f);
}

inline void SplitCommandAndParam(const std::string& in, std::string* cmd, std::string* param) {
cmd->clear();
param->clear();
Expand Down
4 changes: 2 additions & 2 deletions utils/tdict.cc
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ struct add_wordids {

}

void TD::ConvertSentence(std::string const& s, std::vector<WordID>* ids) {
void TD::ConvertSentence(std::string const& s, std::vector<WordID>* ids, unsigned start) {
ids->clear();
VisitTokens(s,add_wordids(ids));
VisitTokens(s,add_wordids(ids),start);
}
2 changes: 1 addition & 1 deletion utils/tdict.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

struct TD {
static WordID end(); // next id to be assigned; [begin,end) give the non-reserved tokens seen so far
static void ConvertSentence(std::string const& sent, std::vector<WordID>* ids);
static void ConvertSentence(std::string const& sent, std::vector<WordID>* ids, unsigned start=0);
static void GetWordIDs(const std::vector<std::string>& strings, std::vector<WordID>* ids);
static std::string GetString(const std::vector<WordID>& str);
static std::string GetString(WordID const* i,WordID const* e);
Expand Down

0 comments on commit 31b5d03

Please sign in to comment.