Unsupervised training algorithm (Baum-Welch implementation).
Definition at line 486 of file hmm.cc. References MorphoStream::get_next_word(), TaggerWord::get_string_tags(), TaggerWord::get_superficial_form(), TaggerWord::get_tags(), Collection::has_not(), and Collection::size(). { int i, j, k, t, len, nw = 0; TaggerWord *word=NULL; TTag tag; set<TTag> tags, pretags; set<TTag>::iterator itag, jtag; map <int, double> gamma; map <int, double>::iterator jt, kt; map < int, map <int, double> > alpha, beta, xsi, phi; map < int, map <int, double> >::iterator it; double prob, loli; vector < set<TTag> > pending; Collection &output = td->getOutput(); int ndesconocidas=0; // alpha => forward probabilities // beta => backward probabilities MorphoStream morpho_stream(ftxt, true, td); loli = 0; tag = eos; tags.clear(); tags.insert(tag); pending.push_back(tags); alpha[0].clear(); alpha[0][tag] = 1; word = morpho_stream.get_next_word(); while (word) { //wcerr<<L"Enter para continuar\n"; //getchar(); if (++nw%10000==0) wcerr<<L'.'<<flush; //wcerr<<*word<<L"\n"; pretags = pending.back(); tags = word->get_tags(); if (tags.size()==0) { // This is an unknown word tags = td->getOpenClass(); ndesconocidas++; } if (output.has_not(tags)) { wstring errors; errors = L"A new ambiguity class was found. 
I cannot continue.\n"; errors+= L"Word '"+word->get_superficial_form()+L"' not found in the dictionary.\n"; errors+= L"New ambiguity class: "+word->get_string_tags()+L"\n"; errors+= L"Take a look at the dictionary, then retrain."; fatal_error(errors); } k = output[tags]; len = pending.size(); alpha[len].clear(); //Forward probabilities for (itag=tags.begin(); itag!=tags.end(); itag++) { i=*itag; for (jtag=pretags.begin(); jtag!=pretags.end(); jtag++) { j=*jtag; //cerr<<"previous alpha["<<len<<"]["<<i<<"]="<<alpha[len][i]<<"\n"; //cerr<<"alpha["<<len-1<<"]["<<j<<"]="<<alpha[len-1][j]<<"\n"; //cerr<<"a["<<j<<"]["<<i<<"]="<<a[j][i]<<"\n"; //cerr<<"b["<<i<<"]["<<k<<"]="<<b[i][k]<<"\n"; alpha[len][i] += alpha[len-1][j]*(td->getA())[j][i]*(td->getB())[i][k]; } if (alpha[len][i]==0) alpha[len][i]=DBL_MIN; //cerr<<"alpha["<<len<<"]["<<i<<"]="<<alpha[len][i]<<"\n--------\n"; } if (tags.size()>1) { pending.push_back(tags); } else { // word is unambiguous tag = *tags.begin(); beta[0].clear(); beta[0][tag] = 1; prob = alpha[len][tag]; //cerr<<"prob="<<prob<<"\n"; //cerr<<"alpha["<<len<<"]["<<tag<<"]="<<alpha[len][tag]<<"\n"; loli -= log(prob); for (t=0; t<len; t++) { // loop from T-1 to 0 pretags = pending.back(); pending.pop_back(); k = output[tags]; beta[1-t%2].clear(); for (itag=tags.begin(); itag!=tags.end(); itag++) { i=*itag; for (jtag=pretags.begin(); jtag!=pretags.end(); jtag++) { j = *jtag; beta[1-t%2][j] += (td->getA())[j][i]*(td->getB())[i][k]*beta[t%2][i]; xsi[j][i] += alpha[len-t-1][j]*(td->getA())[j][i]*(td->getB())[i][k]*beta[t%2][i]/prob; } double previous_value = gamma[i]; gamma[i] += alpha[len-t][i]*beta[t%2][i]/prob; if (isnan(gamma[i])) { wcerr<<L"NAN(3) gamma["<<i<<L"] = "<<gamma[i]<<L" alpha["<<len-t<<L"]["<<i<<L"]= "<<alpha[len-t][i] <<L" beta["<<t%2<<L"]["<<i<<L"] = "<<beta[t%2][i]<<L" prob = "<<prob<<L" previous gamma = "<<previous_value<<L"\n"; exit(1); } if (isinf(gamma[i])) { wcerr<<L"INF(3) gamma["<<i<<L"] = "<<gamma[i]<<L" 
alpha["<<len-t<<L"]["<<i<<L"]= "<<alpha[len-t][i] <<L" beta["<<t%2<<L"]["<<i<<L"] = "<<beta[t%2][i]<<L" prob = "<<prob<<L" previous gamma = "<<previous_value<<L"\n"; exit(1); } if (gamma[i]==0) { //cout<<"ZERO(3) gamma["<<i<<"] = "<<gamma[i]<<" alpha["<<len-t<<"]["<<i<<"]= "<<alpha[len-t][i] // <<" beta["<<t%2<<"]["<<i<<"] = "<<beta[t%2][i]<<" prob = "<<prob<<" previous gamma = "<<previous_value<<"\n"; gamma[i]=DBL_MIN; //exit(1); } phi[i][k] += alpha[len-t][i]*beta[t%2][i]/prob; } tags=pretags; } tags.clear(); tags.insert(tag); pending.push_back(tags); alpha[0].clear(); alpha[0][tag] = 1; } delete word; word = morpho_stream.get_next_word(); } if ((pending.size()>1) || ((tag!=eos)&&(tag != (td->getTagIndex())[L"TAG_kEOF"]))) wcerr<<L"Warning: Thee las tag is not the end-of-sentence-tag\n"; int N = td->getN(); int M = td->getM(); //Clean previous values for(i=0; i<N; i++) { for(j=0; j<N; j++) (td->getA())[i][j]=ZERO; for(k=0; k<M; k++) (td->getB())[i][k]=ZERO; } // new parameters for (it=xsi.begin(); it!=xsi.end(); it++) { i = it->first; for (jt=xsi[i].begin(); jt!=xsi[i].end(); jt++) { j = jt->first; if (xsi[i][j]>0) { if (gamma[i]==0) { wcerr<<L"Warning: gamma["<<i<<L"]=0\n"; gamma[i]=DBL_MIN; } (td->getA())[i][j] = xsi[i][j]/gamma[i]; if (isnan((td->getA())[i][j])) { wcerr<<L"NAN\n"; wcerr <<L"Error: BW - NAN(1) a["<<i<<L"]["<<j<<L"]="<<(td->getA())[i][j]<<L"\txsi["<<i<<L"]["<<j<<L"]="<<xsi[i][j]<<L"\tgamma["<<i<<L"]="<<gamma[i]<<L"\n"; exit(1); } if (isinf((td->getA())[i][j])) { wcerr<<L"INF\n"; wcerr <<L"Error: BW - INF(1) a["<<i<<L"]["<<j<<L"]="<<(td->getA())[i][j]<<L"\txsi["<<i<<L"]["<<j<<L"]="<<xsi[i][j]<<L"\tgamma["<<i<<L"]="<<gamma[i]<<L"\n"; exit(1); } if ((td->getA())[i][j]==0) { //cerr <<"Error: BW - ZERO(1) a["<<i<<"]["<<j<<"]="<<(td->getA())[i][j]<<"\txsi["<<i<<"]["<<j<<"]="<<xsi[i][j]<<"\tgamma["<<i<<"]="<<gamma[i]<<"\n"; // exit(1); } } } } for (it=phi.begin(); it!=phi.end(); it++) { i = it->first; for (kt=phi[i].begin(); kt!=phi[i].end(); kt++) { k 
= kt->first; if (phi[i][k]>0) { (td->getB())[i][k] = phi[i][k]/gamma[i]; if (isnan((td->getB())[i][k])) { wcerr<<L"Error: BW - NAN(2) b["<<i<<L"]["<<k<<L"]="<<(td->getB())[i][k]<<L"\tphi["<<i<<L"]["<<k<<L"]="<<phi[i][k]<<L"\tgamma["<<i<<L"]="<<gamma[i]<<L"\n"; exit(1); } if (isinf((td->getB())[i][k])) { wcerr<<L"Error: BW - INF(2) b["<<i<<L"]["<<k<<L"]="<<(td->getB())[i][k]<<L"\tphi["<<i<<L"]["<<k<<L"]="<<phi[i][k]<<L"\tgamma["<<i<<L"]="<<gamma[i]<<L"\n"; exit(1); } if ((td->getB())[i][k]==0) { //cerr <<"Error: BW - ZERO(2) b["<<i<<"]["<<k<<"]="<<(td->getB())[i][k]<<"\tphi["<<i<<"]["<<k<<"]="<<phi[i][k]<<"\tgamma["<<i<<"]="<<gamma[i]<<"\n"; // exit(1); } } } } //It can be possible that a probability is not updated //We normalize the probabilitites for(i=0; i<N; i++) { double sum=0; for(j=0; j<N; j++) sum+=(td->getA())[i][j]; for(j=0; j<N; j++) (td->getA())[i][j]=(td->getA())[i][j]/sum; } for(i=0; i<N; i++) { double sum=0; for(k=0; k<M; k++) { if(output[k].find(i)!=output[k].end()) sum+=(td->getB())[i][k]; } for(k=0; k<M; k++) { if(output[k].find(i)!=output[k].end()) (td->getB())[i][k]=(td->getB())[i][k]/sum; } } wcerr<<L"Log="<<loli<<L"\n"; }
Here is the call graph for this function:
(The call graph image is not included in this text rendering.)