-
Notifications
You must be signed in to change notification settings - Fork 77
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
11 changed files
with
287 additions
and
8 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
Quelltext hier . ||| source text here . ||| original text . ||| some source text . | ||
ein anderer Satz . ||| another sentence . ||| a different sentece . |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,15 +1,20 @@ | ||
bin_PROGRAMS = kbest_mira \ | ||
kbest_cut_mira | ||
bin_PROGRAMS = \ | ||
kbest_mira \ | ||
kbest_cut_mira \ | ||
ada_opt_sm | ||
|
||
EXTRA_DIST = mira.py | ||
|
||
ada_opt_sm_SOURCES = ada_opt_sm.cc | ||
ada_opt_sm_LDFLAGS= -rdynamic | ||
ada_opt_sm_LDADD = ../utils/libtraining_utils.a ../../decoder/libcdec.a ../../klm/search/libksearch.a ../../mteval/libmteval.a ../../utils/libutils.a ../../klm/lm/libklm.a ../../klm/util/libklm_util.a ../../klm/util/double-conversion/libklm_util_double.a | ||
|
||
kbest_mira_SOURCES = kbest_mira.cc | ||
kbest_mira_LDFLAGS= -rdynamic | ||
kbest_mira_LDADD = ../../decoder/libcdec.a ../../klm/search/libksearch.a ../../mteval/libmteval.a ../../utils/libutils.a ../../klm/lm/libklm.a ../../klm/util/libklm_util.a ../../klm/util/double-conversion/libklm_util_double.a | ||
|
||
|
||
kbest_cut_mira_SOURCES = kbest_cut_mira.cc | ||
kbest_cut_mira_LDFLAGS= -rdynamic | ||
kbest_cut_mira_LDADD = ../../decoder/libcdec.a ../../klm/search/libksearch.a ../../mteval/libmteval.a ../../utils/libutils.a ../../klm/lm/libklm.a ../../klm/util/libklm_util.a ../../klm/util/double-conversion/libklm_util_double.a | ||
|
||
AM_CPPFLAGS = -W -Wall -Wno-sign-compare -I$(top_srcdir)/utils -I$(top_srcdir)/decoder -I$(top_srcdir)/mteval | ||
AM_CPPFLAGS = -W -Wall -Wno-sign-compare -I$(top_srcdir)/utils -I$(top_srcdir)/decoder -I$(top_srcdir)/mteval -I$(top_srcdir)/training/utils |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,198 @@ | ||
#include "config.h" | ||
|
||
#include <boost/container/flat_map.hpp> | ||
#include <boost/shared_ptr.hpp> | ||
#include <boost/program_options.hpp> | ||
#include <boost/program_options/variables_map.hpp> | ||
|
||
#include "filelib.h" | ||
#include "stringlib.h" | ||
#include "weights.h" | ||
#include "sparse_vector.h" | ||
#include "candidate_set.h" | ||
#include "sentence_metadata.h" | ||
#include "ns.h" | ||
#include "ns_docscorer.h" | ||
#include "verbose.h" | ||
#include "hg.h" | ||
#include "ff_register.h" | ||
#include "decoder.h" | ||
#include "fdict.h" | ||
#include "sampler.h" | ||
|
||
using namespace std; | ||
namespace po = boost::program_options; | ||
|
||
boost::shared_ptr<MT19937> rng; | ||
vector<training::CandidateSet> kbests; | ||
SparseVector<weight_t> G, u, lambdas; | ||
double pseudo_doc_decay = 0.9; | ||
|
||
bool InitCommandLine(int argc, char** argv, po::variables_map* conf) { | ||
po::options_description opts("Configuration options"); | ||
opts.add_options() | ||
("decoder_config,c",po::value<string>(),"[REQ] Decoder configuration file") | ||
("devset,d",po::value<string>(),"[REQ] Source/reference development set") | ||
("weights,w",po::value<string>(),"Initial feature weights file") | ||
("mt_metric,m",po::value<string>()->default_value("ibm_bleu"), "Scoring metric (ibm_bleu, nist_bleu, koehn_bleu, ter, combi)") | ||
("size",po::value<unsigned>()->default_value(0), "Process rank (for multiprocess mode)") | ||
("rank",po::value<unsigned>()->default_value(1), "Number of processes (for multiprocess mode)") | ||
("optimizer,o",po::value<unsigned>()->default_value(1), "Optimizer (Adaptive MIRA=1)") | ||
("fear,f",po::value<unsigned>()->default_value(1), "Fear selection (model-cost=1, maxcost=2, maxscore=3)") | ||
("hope,h",po::value<unsigned>()->default_value(1), "Hope selection (model+cost=1, mincost=2)") | ||
("eta0", po::value<double>()->default_value(0.1), "Initial step size") | ||
("random_seed,S", po::value<uint32_t>(), "Random seed (if not specified, /dev/random will be used)") | ||
("mt_metric_scale,s", po::value<double>()->default_value(1.0), "Scale MT loss function by this amount") | ||
("pseudo_doc,e", "Use pseudo-documents for approximate scoring") | ||
("k_best_size,k", po::value<unsigned>()->default_value(500), "Size of hypothesis list to search for oracles"); | ||
po::options_description clo("Command line options"); | ||
clo.add_options() | ||
("config", po::value<string>(), "Configuration file") | ||
("help,H", "Print this help message and exit"); | ||
po::options_description dconfig_options, dcmdline_options; | ||
dconfig_options.add(opts); | ||
dcmdline_options.add(opts).add(clo); | ||
|
||
po::store(parse_command_line(argc, argv, dcmdline_options), *conf); | ||
if (conf->count("config")) { | ||
ifstream config((*conf)["config"].as<string>().c_str()); | ||
po::store(po::parse_config_file(config, dconfig_options), *conf); | ||
} | ||
po::notify(*conf); | ||
|
||
if (conf->count("help") | ||
|| !conf->count("decoder_config") | ||
|| !conf->count("devset")) { | ||
cerr << dcmdline_options << endl; | ||
return false; | ||
} | ||
return true; | ||
} | ||
|
||
struct TrainingObserver : public DecoderObserver { | ||
explicit TrainingObserver(const EvaluationMetric& m, const int k) : metric(m), kbest_size(k), cur_eval() {} | ||
|
||
const EvaluationMetric& metric; | ||
const int kbest_size; | ||
const SegmentEvaluator* cur_eval; | ||
SufficientStats pdoc; | ||
unsigned hi, vi, fi; // hope, viterbi, fear | ||
|
||
void SetSegmentEvaluator(const SegmentEvaluator* eval) { | ||
cur_eval = eval; | ||
} | ||
|
||
virtual void NotifySourceParseFailure(const SentenceMetadata& smeta) { | ||
cerr << "Failed to translate sentence with ID = " << smeta.GetSentenceID() << endl; | ||
abort(); | ||
} | ||
|
||
unsigned CostAugmentedDecode(const training::CandidateSet& cs, | ||
const SparseVector<double>& w, | ||
double alpha = 0) { | ||
unsigned best_i = 0; | ||
double best = -numeric_limits<double>::infinity(); | ||
for (unsigned i = 0; i < cs.size(); ++i) { | ||
double s = cs[i].fmap.dot(w); | ||
if (alpha) | ||
s += alpha * metric.ComputeScore(cs[i].eval_feats + pdoc); | ||
if (s > best) { | ||
best = s; | ||
best_i = i; | ||
} | ||
} | ||
return best_i; | ||
} | ||
|
||
virtual void NotifyTranslationForest(const SentenceMetadata& smeta, Hypergraph* hg) { | ||
pdoc *= pseudo_doc_decay; | ||
const unsigned sent_id = smeta.GetSentenceID(); | ||
kbests[sent_id].AddUniqueKBestCandidates(*hg, kbest_size, cur_eval); | ||
vi = CostAugmentedDecode(kbests[sent_id], lambdas); | ||
hi = CostAugmentedDecode(kbests[sent_id], lambdas, 1.0); | ||
fi = CostAugmentedDecode(kbests[sent_id], lambdas, -1.0); | ||
cerr << sent_id << " ||| " << TD::GetString(kbests[sent_id][vi].ewords) << " ||| " << metric.ComputeScore(kbests[sent_id][vi].eval_feats + pdoc) << endl; | ||
pdoc += kbests[sent_id][vi].eval_feats; // update pseudodoc stats | ||
} | ||
}; | ||
|
||
int main(int argc, char** argv) { | ||
SetSilent(true); // turn off verbose decoder output | ||
register_feature_functions(); | ||
|
||
po::variables_map conf; | ||
if (!InitCommandLine(argc, argv, &conf)) return 1; | ||
|
||
if (conf.count("random_seed")) | ||
rng.reset(new MT19937(conf["random_seed"].as<uint32_t>())); | ||
else | ||
rng.reset(new MT19937); | ||
|
||
string metric_name = UppercaseString(conf["mt_metric"].as<string>()); | ||
if (metric_name == "COMBI") { | ||
cerr << "WARNING: 'combi' metric is no longer supported, switching to 'COMB:TER=-0.5;IBM_BLEU=0.5'\n"; | ||
metric_name = "COMB:TER=-0.5;IBM_BLEU=0.5"; | ||
} else if (metric_name == "BLEU") { | ||
cerr << "WARNING: 'BLEU' is ambiguous, assuming 'IBM_BLEU'\n"; | ||
metric_name = "IBM_BLEU"; | ||
} | ||
EvaluationMetric* metric = EvaluationMetric::Instance(metric_name); | ||
DocumentScorer ds(metric, conf["devset"].as<string>()); | ||
cerr << "Loaded " << ds.size() << " references for scoring with " << metric_name << endl; | ||
kbests.resize(ds.size()); | ||
double eta = 0.001; | ||
|
||
ReadFile ini_rf(conf["decoder_config"].as<string>()); | ||
Decoder decoder(ini_rf.stream()); | ||
|
||
vector<weight_t>& dense_weights = decoder.CurrentWeightVector(); | ||
if (conf.count("weights")) { | ||
Weights::InitFromFile(conf["weights"].as<string>(), &dense_weights); | ||
Weights::InitSparseVector(dense_weights, &lambdas); | ||
} | ||
|
||
TrainingObserver observer(*metric, conf["k_best_size"].as<unsigned>()); | ||
|
||
unsigned num = 200; | ||
for (unsigned iter = 1; iter < num; ++iter) { | ||
lambdas.init_vector(&dense_weights); | ||
unsigned sent_id = rng->next() * ds.size(); | ||
cerr << "Learning from sentence id: " << sent_id << endl; | ||
observer.SetSegmentEvaluator(ds[sent_id]); | ||
decoder.SetId(sent_id); | ||
decoder.Decode(ds[sent_id]->src, &observer); | ||
if (observer.vi != observer.hi) { // viterbi != hope | ||
SparseVector<double> grad = kbests[sent_id][observer.fi].fmap; | ||
grad -= kbests[sent_id][observer.hi].fmap; | ||
cerr << "GRAD: " << grad << endl; | ||
const SparseVector<double>& g = grad; | ||
#if HAVE_CXX11 && (__GNUC_MINOR__ > 4 || __GNUC__ > 4) | ||
for (auto& gi : g) { | ||
#else | ||
for (SparseVector<double>::const_iterator it = g.begin(); it != g.end(); ++it) { | ||
const pair<unsigned,double>& gi = *it; | ||
#endif | ||
if (gi.second) { | ||
u[gi.first] += gi.second; | ||
G[gi.first] += gi.second * gi.second; | ||
lambdas.set_value(gi.first, 1.0); // this is a dummy value to trigger recomputation | ||
} | ||
} | ||
for (SparseVector<double>::iterator it = lambdas.begin(); it != lambdas.end(); ++it) { | ||
const pair<unsigned,double>& xi = *it; | ||
double z = fabs(u[xi.first] / iter) - 0.0; | ||
double s = 1; | ||
if (u[xi.first] > 0) s = -1; | ||
if (z > 0 && G[xi.first]) { | ||
lambdas.set_value(xi.first, eta * s * z * iter / sqrt(G[xi.first])); | ||
} else { | ||
lambdas.set_value(xi.first, 0.0); | ||
} | ||
} | ||
} | ||
} | ||
cerr << "Optimization complete.\n"; | ||
Weights::WriteToFile("-", dense_weights, true); | ||
return 0; | ||
} | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters