Skip to content

Commit

Permalink
l2r bugfixes
Browse files Browse the repository at this point in the history
git-svn-id: https://ws10smt.googlecode.com/svn/trunk@634 ec762483-ff6d-05da-a07a-a48fb63a330f
  • Loading branch information
graehl committed Aug 31, 2010
1 parent 926cedc commit e417508
Show file tree
Hide file tree
Showing 11 changed files with 301 additions and 35 deletions.
86 changes: 63 additions & 23 deletions decoder/apply_fsa_models.cc
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
#include <queue>
#include "apply_fsa_models.h"
#include <stdexcept>
#include <cassert>
#include <queue>
#include <stdint.h>

#include "writer.h"
#include "hg.h"
#include "ff_fsa_dynamic.h"
#include "ff_from_fsa.h"
#include "feature_vector.h"
#include "stringlib.h"
#include "apply_models.h"
#include <stdexcept>
#include <cassert>
#include "cfg.h"
#include "hg_cfg.h"
#include "utoa.h"
Expand All @@ -16,15 +19,15 @@
#include "d_ary_heap.h"
#include "agenda.h"
#include "show.h"
#include <stdint.h>
#include "string_to.h"

#define DFSA(x) x
//fsa earley chart

#define DPFSA(x) x
//prefix trie

#define DBUILDTRIE(x) x
#define DBUILDTRIE(x)

#define PRINT_PREFIX 1
#if PRINT_PREFIX
Expand Down Expand Up @@ -101,23 +104,54 @@ struct TrieBackP {

FsaFeatureFunction const* print_fsa=0;
CFG const* print_cfg=0;
inline void print_cfg_rhs(std::ostream &o,WordID w) {
if (print_cfg)
print_cfg->print_rhs_name(o,w);
// Writes the printable name of rhs symbol w to o and returns o so the call can be chained.
// With a grammar (pcfg non-null) the grammar's own print_rhs_name is used; without one,
// CFG's static, grammar-independent printer is used instead.
inline ostream& print_cfg_rhs(std::ostream &o,WordID w,CFG const*pcfg=print_cfg) {
  if (!pcfg) {
    CFG::static_print_rhs_name(o,w);
    return o;
  }
  pcfg->print_rhs_name(o,w);
  return o;
}

// Name of nonterminal n as a string: taken from pcfg when a grammar is supplied,
// otherwise from CFG's static, grammar-independent naming.
inline std::string nt_name(WordID n,CFG const*pcfg=print_cfg) {
  if (!pcfg)
    return CFG::static_nt_name(n);
  return pcfg->nt_name(n);
}

template <class V>
ostream& print_by_nt(std::ostream &o,V const& v,CFG const*pcfg=print_cfg,char const* header="\nNT -> X\n") {
o<<header;
for (int i=0;i<v.size();++i)
o << nt_name(i,pcfg) << " -> "<<v[i]<<"\n";
return o;
}

// Prints `header`, then one line "<key name> -> <value>" for each (key,value) entry of the
// map-like container v, in the container's iteration order. Keys are rendered with
// print_cfg_rhs(o,key,pcfg). Returns o so the call can be chained.
template <class V>
ostream& print_map_by_nt(std::ostream &o,V const& v,CFG const*pcfg=print_cfg,char const* header="\nNT -> X\n") {
  o<<header;
  typename V::const_iterator entry=v.begin();
  typename V::const_iterator const last=v.end();
  while (entry!=last) {
    print_cfg_rhs(o,entry->first,pcfg);
    o << " -> " << entry->second << "\n";
    ++entry;
  }
  return o;
}


struct PrefixTrieEdge {
// PrefixTrieEdge() { }
PrefixTrieEdge()
// : dest(0),w(TD::max_wordid)
{}
PrefixTrieEdge(WordID w,NodeP dest)
: dest(dest),w(w)
{}
// explicit PrefixTrieEdge(best_t p) : p(p),dest(0) { }
best_t p;// viterbi additional prob, i.e. product over path incl. p_final = total rule prob

best_t p;// viterbi additional prob, i.e. product over path incl. p_final = total rule prob. note: for final edge, set this.
//DPFSA()
// we can probably just store deltas, but for debugging remember the full p
// best_t delta; //
NodeP dest;
bool is_final() const { return dest==0; }
WordID w; // for lhs, this will be nonneg NTHandle instead. // not set if is_final() // actually, set to lhs nt index
best_t p_dest() const;
WordID w; // for root and for is_final(), this will be (negated) NTHandle.

// for sorting most probable first in adj; actually >(p)
inline bool operator <(PrefixTrieEdge const& o) const {
Expand Down Expand Up @@ -218,7 +252,7 @@ struct PrefixTrieNode {
for (int i=0,e=adj.size();i!=e;++i) {
PrefixTrieEdge const& edge=adj[i];
// assert(edge.p.is_1()); // actually, after done_building, e will have telescoped dest->p/p.
NTHandle n=edge.w;
NTHandle n=-edge.w;
assert(n>=0);
SHOWM3(DPFSA,"index_lhs",i,edge,n);
v[n]=edge.dest;
Expand All @@ -228,7 +262,10 @@ struct PrefixTrieNode {
template <class PV>
void done_root(PV &v) {
assert(is_root());
SHOWM1(DBUILDTRIE,"done_root",OSTRF1(print_map_by_nt,edge_for));
done_building_r(); //sets adj
SHOWM1(DBUILDTRIE,"done_root",OSTRF1(print_by_nt,adj));
// SHOWM1(DBUILDTRIE,done_root,adj);
// index_adj(); // we want an index for the root node?. don't think so - index_lhs handles it. also we stopped clearing edge_for.
index_lhs(v); // uses adj
}
Expand All @@ -244,7 +281,7 @@ struct PrefixTrieNode {
// for done_building; compute incremental (telescoped) edge p
PrefixTrieEdge /*const&*/ operator()(PrefixTrieEdgeFor::value_type & pair) const {
PrefixTrieEdge &e=pair.second;//const_cast<PrefixTrieEdge&>(pair.second);
e.p=(e.dest->p)/p;
e.p=e.p_dest()/p;
return e;
}

Expand All @@ -265,6 +302,7 @@ struct PrefixTrieNode {
// (*this)(*i);
}
#endif
SHOWM1(DBUILDTRIE,"done building adj",prange(adj.begin(),adj.end(),true));
assert(adj.size()==edge_for.size());
// if (final) p_final/=p;
std::sort(adj.begin(),adj.end());
Expand All @@ -287,26 +325,26 @@ struct PrefixTrieNode {
inline NodeP build(W w,best_t rulep) {
return build(lhs,w,rulep);
}
inline NodeP build_lhs(NTHandle w,best_t rulep) {
return build(w,w,rulep);
inline NodeP build_lhs(NTHandle n,best_t rulep) {
return build(n,-n,rulep);
}

NodeP build(NTHandle lhs_,W w,best_t rulep) {
PrefixTrieEdgeFor::iterator i=edge_for.find(w);
if (i!=edge_for.end())
return improve_edge(i->second,rulep);
PrefixTrieEdge &e=i->second;
NodeP r=new PrefixTrieNode(lhs_,rulep);
IF_PRINT_PREFIX(r->backp=BP(w,this));
e.dest=r;
// edge_for.insert(i,PrefixTrieEdgeFor::value_type(w,PrefixTrieEdge(w,r)));
add(edge_for,w,PrefixTrieEdge(w,r));
SHOWM4(DBUILDTRIE,"built node",this,w,*r,r);
return r;
}

void set_final(NTHandle lhs_,best_t pf) {
assert(no_adj());
final=true;
PrefixTrieEdge &e=edge_for[-1];
PrefixTrieEdge &e=edge_for[null_wordid];
e.p=pf;
e.dest=0;
e.w=lhs_;
Expand Down Expand Up @@ -335,6 +373,10 @@ struct PrefixTrieNode {
PRINT_SELF(PrefixTrieNode)
};

// p of the node this edge leads to. A final edge has no destination node (dest is null),
// so the edge's own p, which was set directly on it, is returned instead.
inline best_t PrefixTrieEdge::p_dest() const {
  if (!dest)
    return p; // final edge: no sentinel node exists, p was set on the edge itself
  return dest->p;
}


//Trie starts with lhs (nonneg index), then continues w/ rhs (mixed >0 word, else NT)
// trie ends with final edge, which points to a per-lhs prefix node
Expand All @@ -358,7 +400,9 @@ struct PrefixTrie {
SHOWM2(DBUILDTRIE,"PrefixTrie()",rulesp->size(),lhs2.size());
cfg.VisitRuleIds(*this);
root.done_root(lhs2);
SHOWM4(DBUILDTRIE,"done w/ PrefixTrie: ",root,root.adj.size(),lhs2.size(),lhs2[0]);
SHOWM3(DBUILDTRIE,"done w/ PrefixTrie: ",root,root.adj.size(),lhs2.size());
DBUILDTRIE(print_by_nt(cerr,lhs2,cfgp));
SHOWM1(DBUILDTRIE,"lhs2",OSTRF2(print_by_nt,lhs2,cfgp));
}

void operator()(int ri) {
Expand Down Expand Up @@ -526,12 +570,8 @@ struct Chart {
} else {
break;
}

}


}

}

Chart(CFG &cfg,SentenceMetadata const& smeta,FsaFF const& fsa,unsigned reserve=FSA_AGENDA_RESERVE)
Expand Down
10 changes: 9 additions & 1 deletion decoder/cfg.h
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,16 @@ struct CFG {
if (w<=0) return nt_name(-w);
else return TD::Convert(w);
}
// Grammar-independent rendering of a nonterminal: streams the handle n between square brackets.
static void static_print_nt_name(std::ostream &o,NTHandle n) {
  o << '[';
  o << n;
  o << ']';
}
// String form of static_print_nt_name: renders nonterminal handle w into a temporary
// string stream and returns the accumulated text.
static std::string static_nt_name(NTHandle w) {
  std::ostringstream rendered;
  static_print_nt_name(rendered,w);
  return rendered.str();
}
static void static_print_rhs_name(std::ostream &o,WordID w) {
if (w<=0) o<<'['<<-w<<']';
if (w<=0) static_print_nt_name(o,-w);
else o<<TD::Convert(w);
}
static std::string static_rhs_name(WordID w) {
Expand Down
3 changes: 2 additions & 1 deletion training/online_optimizer.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,8 @@ class OnlineOptimizer {
public:
virtual ~OnlineOptimizer();
OnlineOptimizer(const std::tr1::shared_ptr<LearningRateSchedule>& s,
size_t training_instances) : schedule_(s), k_(), N_(training_instances) {}
size_t training_instances)
: N_(training_instances),schedule_(s),k_() {}
void UpdateWeights(const SparseVector<double>& approx_g, SparseVector<double>* weights) {
++k_;
const double eta = schedule_->eta(k_);
Expand Down
4 changes: 2 additions & 2 deletions utils/d_ary_heap.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
#define D_ARY_UP_GRAEHL 0 // untested
#define D_ARY_APPEND_ALWAYS_PUSH 1 // heapify (0) is untested. otherwise switch between push and heapify depending on size (cache effects, existing items vs. # appended ones)

#define D_ARY_TRACK_OUT_OF_HEAP 0 // shouldn't need to track, because in contains() false positives looking up stale or random loc map values are impossible - we just check key
#define D_ARY_VERIFY_HEAP 0
#define D_ARY_TRACK_OUT_OF_HEAP 1 // shouldn't need to track, because in contains() false positives looking up stale or random loc map values are impossible - we just check key
#define D_ARY_VERIFY_HEAP 1
// This is a very expensive test so it should be disabled even when NDEBUG is not defined

/* adapted from boost/graph/detail/d_ary_heap.hpp
Expand Down
27 changes: 26 additions & 1 deletion utils/hash.h
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,6 @@ typename H::mapped_type & get_or_call(H &ht,K const& k,F const& f) {
}
}


// the below could also return a ref to the mapped max/min. they have the advantage of not falsely claiming an improvement when an equal value already existed. otherwise you could just modify the get_default and if equal assume new.
template <class H,class K>
bool improve_mapped_max(H &ht,K const& k,typename H::mapped_type const& v) {
Expand All @@ -110,6 +109,32 @@ bool improve_mapped_max(H &ht,K const& k,typename H::mapped_type const& v) {
return false;
}


// Stores v under k, like ht[k]=v, overwriting any value already there.
// Returns true when k had no previous value (a new entry was created), false when an old value was replaced.
template <class H,class K>
bool put(H &ht,K const& k,typename H::mapped_type const& v) {
  typedef typename H::value_type Entry;
  std::pair<typename H::iterator,bool> attempt=ht.insert(Entry(k,v));
  bool const was_absent=attempt.second;
  if (!was_absent)
    attempt.first->second=v; // insert left the old entry untouched, so overwrite it here
  return was_absent;
}

// Adds (k,v) only when k is not present yet. An existing value is never modified.
// Returns true when the entry was added, false when k already had a value.
template <class H,class K>
bool maybe_add(H &ht,K const& k,typename H::mapped_type const& v) {
  typedef typename H::value_type Entry;
  return ht.insert(Entry(k,v)).second;
}

// Inserts (k,v) into ht. Precondition: ht[k] must not exist (yet); this is checked by assert.
// When asserts are compiled out and the precondition is violated, the existing value is left as is.
template <class H,class K>
void add(H &ht,K const& k,typename H::mapped_type const& v) {
  // The insertion is performed outside the assert so it still happens when NDEBUG removes the check.
  bool const fresh=ht.insert(typename H::value_type(k,v)).second;
  assert(fresh);
  (void)fresh; // keeps NDEBUG builds free of an unused-variable warning
}


template <class H,class K>
bool improve_mapped_min(H &ht,K const& k,typename H::mapped_type const& v) {
std::pair<typename H::iterator,bool> inew=ht.insert(typename H::value_type(k,v));
Expand Down
26 changes: 25 additions & 1 deletion utils/intern_pool.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,30 @@ struct compose_indirect {
}


};

// Equality predicate that compares items through a key: kf extracts the key from an item and
// f compares two keys. Items may be passed by reference or through pointers / pointer-like
// handles; handles that compare equal are equal, and otherwise both must be non-null and
// have keys that f accepts as equal.
template <class KeyF,class F,class Arg=typename KeyF::argument_type>
struct equal_indirect {
  typedef bool result_type;
  typedef Arg *argument_type; // we also accept Arg &
  KeyF kf; // item -> key
  F f;     // key equality

  // Items by const reference: compare their keys.
  result_type operator()(Arg const& a1,Arg const& a2) const {
    return f(kf(a1),kf(a2));
  }
  // Items by mutable reference: compare their keys.
  result_type operator()(Arg & a1,Arg & a2) const {
    return f(kf(a1),kf(a2));
  }
  // Raw pointers: identical pointers (including two nulls) are equal; a single null is unequal.
  result_type operator()(Arg * a1,Arg * a2) const {
    if (a1==a2)
      return true;
    if (!a1||!a2)
      return false;
    return f(kf(*a1),kf(*a2));
  }
  // Any other pair of dereferenceable, bool-testable handles: same rule as for raw pointers.
  template <class V,class W>
  result_type operator()(V const& v,W const&w) const {
    if (v==w)
      return true;
    if (!(v&&w))
      return false;
    return f(kf(*v),kf(*w));
  }
};

/*
Expand All @@ -79,7 +103,7 @@ struct intern_pool : Pool {
typedef typename KeyF::result_type Key;
typedef Item *Handle;
typedef compose_indirect<KeyF,HashKey,Item> HashDeep;
typedef compose_indirect<KeyF,EqKey,Item> EqDeep;
typedef equal_indirect<KeyF,EqKey,Item> EqDeep;
typedef HASH_SET<Handle,HashDeep,EqDeep> Canonical;
typedef typename Canonical::iterator CFind;
typedef std::pair<CFind,bool> CInsert;
Expand Down
12 changes: 11 additions & 1 deletion utils/show.h
Original file line number Diff line number Diff line change
@@ -1,12 +1,19 @@
#ifndef UTILS__SHOW_H
#define UTILS__SHOW_H


//usage: string s=OSTR(1<<" "<<c);
#define OSTR(expr) ((dynamic_cast<ostringstream &>(ostringstream()<<std::dec<<expr)).str())
#define OSTRF(f) ((dynamic_cast<ostringstream &>(f(ostringstream()<<std::dec))).str())
#define OSTRF1(f,x) ((dynamic_cast<ostringstream &>(f(ostringstream()<<std::dec,x))).str())
#define OSTRF2(f,x1,x2) ((dynamic_cast<ostringstream &>(f(ostringstream()<<std::dec,x1,x2))).str())
// std::dec (or seekp, or another manip) is needed to convert to std::ostream reference.

#ifndef SHOWS
#include <iostream>
#define SHOWS std::cerr
#endif


#define SELF_TYPE_PRINT \
template <class Char,class Traits> \
inline friend std::basic_ostream<Char,Traits> & operator <<(std::basic_ostream<Char,Traits> &o, self_type const& me) \
Expand All @@ -26,6 +33,8 @@

#define PRINT_SELF(self) typedef self self_type; SELF_TYPE_PRINT_OSTREAM



#undef SHOWALWAYS
#define SHOWALWAYS(x) x

Expand Down Expand Up @@ -62,6 +71,7 @@ careful: none of this is wrapped in a block. so you can't use one of these macr
#define SHOW7(IF,x,y0,y1,y2,y3,y4,y5) SHOW1(IF,x) SHOW6(IF,y0,y1,y2,y3,y4,y5)

#define SHOWM(IF,m,x) SHOWP(IF,m<<": ") SHOW(IF,x)
#define SHOWM1(IF,m,x) SHOWM(IF,m,x)
#define SHOWM2(IF,m,x0,x1) SHOWP(IF,m<<": ") SHOW2(IF,x0,x1)
#define SHOWM3(IF,m,x0,x1,x2) SHOWP(IF,m<<": ") SHOW3(IF,x0,x1,x2)
#define SHOWM4(IF,m,x0,x1,x2,x3) SHOWP(IF,m<<": ") SHOW4(IF,x0,x1,x2,x3)
Expand Down
4 changes: 0 additions & 4 deletions utils/stringlib.h
Original file line number Diff line number Diff line change
@@ -1,10 +1,6 @@
#ifndef CDEC_STRINGLIB_H_
#define CDEC_STRINGLIB_H_

//usage: string s=MAKESTRE(1<<" "<<c);
#define MAKESTR(expr) ((dynamic_cast<ostringstream &>(ostringstream()<<std::dec<<expr)).str())
// std::dec (or seekp, or another manip) is needed to convert to std::ostream reference.

#ifdef STRINGLIB_DEBUG
#include <iostream>
#define SLIBDBG(x) do { std::cerr<<"DBG(stringlib): "<<x<<std::endl; } while(0)
Expand Down
3 changes: 2 additions & 1 deletion utils/tdict.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@ struct TD {
}
*/
static const WordID max_wordid=0x7fffffff;
static const WordID none=(WordID)-1; // Vocab_None
static const WordID null=max_wordid-1;
static const WordID none=(WordID)-1; // Vocab_None - this will collide with mixed node/variable id / word space, though. max_wordid will be distinct (still positive)
static char const* const ss_str; //="<s>";
static char const* const se_str; //="</s>";
static char const* const unk_str; //="<unk>";
Expand Down
6 changes: 6 additions & 0 deletions utils/wordid.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
#ifndef _WORD_ID_H_
#define _WORD_ID_H_

#include <limits>

typedef int WordID;

//namespace {
static const WordID null_wordid=std::numeric_limits<WordID>::max();
//}

#endif
Loading

0 comments on commit e417508

Please sign in to comment.