#include "lexprune.h" #include "item.h" #include "lattice.h" #include "chart.h" #include #include #include extern chart *Chart; using namespace std; int lexprune(tTrigramModel *model, double threshold) { int frozen = 0; vector > states_by_end(Chart->length()); vector > states_by_start(Chart->length()); for (unsigned int t = 0; t < Chart->length(); ++t) { //states (normalised tags) seen for this end point boost::unordered_map seen; for(chart_iter_end_passive iter(Chart, t); iter.valid(); ++iter) { tItem *it = iter.current(); if (unblocked_lex_complete(it)) { //assemble state elements string tag, word, letype; const tLexItem *lexitem = dynamic_cast(it); if (lexitem) { //cast successful, LexItem tag = print_name(lexitem->type()); letype = tag; word = word_from_lexitem(lexitem, model); } else { tag = it->printname(); tItem *daughter = *(it->daughters().begin()); lexitem = dynamic_cast(daughter); while (!lexitem) {//lex rules, not lex type tag = string(daughter->printname()) + ":" + tag; daughter = *(daughter->daughters().begin()); lexitem = dynamic_cast(daughter); } tag = string(print_name(lexitem->type())) + ":" +tag; letype = string(print_name(lexitem->type())); word = word_from_lexitem(lexitem, model); } model->normalise(&word, &tag); int start = it->start(); ostringstream key(tag); key << ":" << start; if (seen.count(key.str()) > 0) { //recorded by tag and start position seen[key.str()]->itemptr(it); } else { double emit = model->getEmit(tag, word); tState *state = new tState(it, tag, emit, letype); seen[key.str()] = state; states_by_end[t].push_back(state); states_by_start[start].push_back(state); //calculate forward probabilities if (start == 0) {//base case states_by_end[t].back()->add_prev(NULL, emit + model->getTransProb(STAG(), STAG(), tag)); } else { for (vector::iterator iit = states_by_end[start].begin(); iit != states_by_end[start].end(); ++iit) { double elems = 0; string itag = (*iit)->tag(); for (vector::iterator hit = (*iit)->prevs().begin(); hit != (*iit)->prevs().end(); ++hit) { string htag = STAG(); if (hit->backptr() != NULL) htag = hit->backptr()->tag(); elems += exp(hit->alpha() + model->getTransProb(htag, itag, tag)); } if (elems == 0) {elems = DBL_MIN;} //catch log(0) issues states_by_end[t].back()->add_prev(*iit, emit + log(elems)); } //end w-2 states } //end non-base case } //end unseen state } //end valid item } //end end point } //end all end points, end forward pass //calculate normalisation double normalise = 0; for (vector::iterator eit = states_by_end[states_by_end.size()-1].begin(); eit != states_by_end[states_by_end.size()-1].end(); ++eit) { double elems = 0; string itag = (*eit)->tag(); for (vector::iterator hit = (*eit)->prevs().begin(); hit != (*eit)->prevs().end(); ++hit) { string htag = STAG(); if (hit->backptr() != NULL) htag = hit->backptr()->tag(); elems += exp(hit->alpha() + model->getTransProb(htag, itag, ETAG())); } normalise += elems; } //if normalise is not greater than zero, there is zero probability of finding //a path through the lattice (probably because of unknown words) and hence //running a backwards path is meaningless. I guess we could freeze everything, //but it's probably best to leave it to the calling code to deal with gappy //lattices. 
  //calculate normalisation
  double normalise = 0;
  for (vector<tState *>::iterator eit = states_by_end[states_by_end.size()-1].begin();
       eit != states_by_end[states_by_end.size()-1].end(); ++eit) {
    double elems = 0;
    string itag = (*eit)->tag();
    for (vector<tPrev>::iterator hit = (*eit)->prevs().begin();
         hit != (*eit)->prevs().end(); ++hit) {
      string htag = STAG();
      if (hit->backptr() != NULL)
        htag = hit->backptr()->tag();
      elems += exp(hit->alpha() + model->getTransProb(htag, itag, ETAG()));
    }
    normalise += elems;
  }

  //if normalise is not greater than zero, there is zero probability of finding
  //a path through the lattice (probably because of unknown words) and hence
  //running a backwards pass is meaningless. I guess we could freeze everything,
  //but it's probably best to leave it to the calling code to deal with gappy
  //lattices.   - rdridan 07/10/2014
  if (normalise > 0) {
    double c = log(normalise);
    //backwards pass
    for (unsigned int t = states_by_end.size()-1; t > 0; --t) { //t
      for (vector<tState *>::iterator it = states_by_end[t].begin();
           it != states_by_end[t].end(); ++it) { //j
        double post_elems = 0;
        for (vector<tPrev>::iterator iit = (*it)->prevs().begin();
             iit != (*it)->prevs().end(); ++iit) { //i
          string itag = STAG();
          if (iit->backptr() != NULL)
            itag = iit->backptr()->tag();
          if (t == states_by_end.size()-1) { //base case: k is the end tag ETAG()
            iit->beta(model->getTransProb(itag, (*it)->tag(), ETAG()));
            post_elems += exp(iit->beta() + iit->alpha() - c);
          } else {
            double elems = 0;
            for (vector<tState *>::iterator kit = states_by_start[t].begin();
                 kit != states_by_start[t].end(); ++kit) { //k
              for (vector<tPrev>::iterator jit = (*kit)->prevs().begin();
                   jit != (*kit)->prevs().end(); ++jit) { //jk
                if (jit->backptr() == *it) { //beta_jk
                  elems += exp(jit->beta() + (*kit)->emit()
                               + model->getTransProb(itag, (*it)->tag(), (*kit)->tag()));
                }
              }
            } //end of k
            iit->beta(log(elems));
            post_elems += exp(iit->beta() + iit->alpha() - c);
          }
        } //end of i
        double posterior = post_elems;
        if (posterior < threshold && !model->on_whitelist((*it)->let())) {
          for (vector<tItem *>::iterator fi = (*it)->itemptr().begin();
               fi != (*it)->itemptr().end(); ++fi) {
            if (model->ut_debug()) {
              std::cerr << "freezing " << (*it)->tag() << " at prob "
                        << posterior << std::endl;
            }
            (*fi)->freeze(false);
            frozen++;
          }
        }
      } //end of j
    } //end of t, end forward-backward
  }

  //clean up
  for (unsigned int t = 0; t < states_by_end.size() - 1; ++t) {
    for (vector<tState *>::iterator it = states_by_end[t].begin();
         it != states_by_end[t].end(); ++it)
      delete *it;
  }
  return frozen;
}
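// viterbi() builds the same lattice of lexical states but keeps only the
// single best tag sequence rather than posteriors: informally, instead of
// summing over predecessors it maximises
//   delta(i,j) = max_h [ delta(h,i) + log P(j | h,i) ] + emit(j)
// records the best predecessor for each link, and finally freezes every
// lexical item that is not on the traced-back best path.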
int viterbi(tTrigramModel *model) {
  int frozen = 0;
  vector<vector<tState *> > states_by_end(Chart->length());
  vector<vector<tState *> > states_by_start(Chart->length());

  for (unsigned int t = 0; t < Chart->length(); ++t) {
    //states (normalised tags) seen for this end point
    boost::unordered_map<string, tState *> seen;
    for (chart_iter_end_passive iter(Chart, t); iter.valid(); ++iter) {
      tItem *it = iter.current();
      if (unblocked_lex_complete(it)) {
        //assemble state elements
        string tag, word, let;
        const tLexItem *lexitem = dynamic_cast<const tLexItem *>(it);
        if (lexitem) { //cast successful, LexItem
          tag = print_name(lexitem->type());
          let = tag;
          word = word_from_lexitem(lexitem, model);
        } else {
          tag = it->printname();
          tItem *daughter = *(it->daughters().begin());
          lexitem = dynamic_cast<const tLexItem *>(daughter);
          while (!lexitem) { //lex rules, not lex type
            tag = string(daughter->printname()) + ":" + tag;
            daughter = *(daughter->daughters().begin());
            lexitem = dynamic_cast<const tLexItem *>(daughter);
          }
          tag = string(print_name(lexitem->type())) + ":" + tag;
          let = string(print_name(lexitem->type()));
          word = word_from_lexitem(lexitem, model);
        }
        model->normalise(&word, &tag);
        int start = it->start();
        ostringstream key;
        key << tag << ":" << start;
        if (seen.count(key.str()) > 0) { //recorded by tag and start position
          seen[key.str()]->itemptr(it);
        } else {
          double emit = model->getEmit(tag, word);
          tState *state = new tState(it, tag, emit, let);
          seen[key.str()] = state;
          states_by_end[t].push_back(state);
          states_by_start[start].push_back(state);
          double maxProb = log(DBL_MIN);
          tState *bestNode = NULL;
          int best_idx = -1;
          //calculate forward probabilities
          if (start == 0) { //base case
            states_by_end[t].back()->add_prev(NULL,
              emit + model->getTransProb(STAG(), STAG(), tag), -1);
          } else {
            for (vector<tState *>::iterator iit = states_by_end[start].begin();
                 iit != states_by_end[start].end(); ++iit) {
              maxProb = log(DBL_MIN);
              best_idx = -1;
              string itag = (*iit)->tag();
              int idx = 0;
              for (vector<tPrev>::iterator hit = (*iit)->prevs().begin();
                   hit != (*iit)->prevs().end(); ++hit) {
                string htag = STAG();
                if (hit->backptr() != NULL)
                  htag = hit->backptr()->tag();
                double logProb = hit->delta()
                                 + model->getTransProb(htag, itag, tag) + emit;
                if (best_idx == -1 || maxProb < logProb) {
                  maxProb = logProb;
                  best_idx = idx;
                  bestNode = *iit;
                }
                idx++;
              }
              states_by_end[t].back()->add_prev(bestNode, maxProb, best_idx);
            } //end w-2 states
          } //end non-base case
        } //end unseen state
      } //end valid item
    } //end end point
  } //end all end points, end forward pass

  //backwards pass
  tState *best = NULL;
  double bestdelta = log(DBL_MIN);
  int best_uidx = -1;
  unsigned int t = states_by_end.size()-1;
  for (vector<tState *>::iterator it = states_by_end[t].begin();
       it != states_by_end[t].end(); ++it) { //j
    int idx = 0;
    for (vector<tPrev>::iterator iit = (*it)->prevs().begin();
         iit != (*it)->prevs().end(); ++iit) { //i
      string itag = STAG();
      if (iit->backptr() != NULL)
        itag = iit->backptr()->tag();
      double endDelta = model->getTransProb(itag, (*it)->tag(), ETAG()) + iit->delta();
      if (best == NULL || bestdelta < endDelta) {
        best = *it;
        bestdelta = endDelta;
        best_uidx = idx;
      }
      idx++;
    } //end of i
  } //end of j, found end of best path

  tState *curr = best;
  int curridx = best_uidx;
  for (unsigned int t = states_by_end.size()-1; t > 0; --t) { //t
    for (vector<tState *>::iterator it = states_by_end[t].begin();
         it != states_by_end[t].end(); ++it) { //j
      if (*it == curr) {
        curr = (*it)->prevs()[curridx].backptr();
        curridx = (*it)->prevs()[curridx].trace();
      } else {
        for (vector<tItem *>::iterator fi = (*it)->itemptr().begin();
             fi != (*it)->itemptr().end(); ++fi) {
          (*fi)->freeze(false);
          frozen++;
        }
      }
    }
  }

  //clean up
  for (unsigned int t = 0; t < states_by_end.size() - 1; ++t) {
    for (vector<tState *>::iterator it = states_by_end[t].begin();
         it != states_by_end[t].end(); ++it)
      delete *it;
  }
  return frozen;
}

string word_from_lexitem(const tLexItem *lexitem, tTrigramModel *model) {
  string word = lexitem->orth();
  string caseclass;
  for (item_citer dit = lexitem->daughters().begin();
       dit != lexitem->daughters().end() && caseclass.empty(); ++dit) {
    tInputItem *dp = dynamic_cast<tInputItem *>(*dit);
    if (dp) {
      //TODO get path from settings
      fs casefs = dp->get_fs().get_path_value("+CLASS.+CASE");
      if (casefs != FAIL)
        caseclass = casefs.printname();
    }
  }
  if (!caseclass.empty()) {
    word += model->caseclass_sep();
    word += caseclass;
  }
  return word;
}
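// A minimal sketch of how these entry points might be driven (hypothetical;
// the actual call sites, option handling and threshold value live in the
// calling parser code, not in this file):
//
//   tTrigramModel *model = ...;              // loaded and configured elsewhere
//   double threshold = 0.001;                // example posterior cut-off, in [0,1]
//   int pruned = lexprune(model, threshold); // posterior-based pruning
//   // or: int pruned = viterbi(model);      // keep only the 1-best tag path
//   // pruned is the number of chart items frozen by the pruner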