読者です 読者をやめる 読者になる 読者になる

それなりにうまくいったNB

前回の記事では、うまくいかなかったナイーブベイズを紹介しました.

結局何が悪いのかよくわからないので別のデータを使ってやってみることにした.

使用データ:カテゴリ付きの新聞記事データ
特徴ベクトル:単語出現頻度

データサンプル

単語行列.txt
4191 22855 4
49 12:3 20:3 62:6 97:2 ... 1 21755:2 1 
…

1行目->記事数 単語数 カテゴリ数
2行目以降->ある記事の単一の単語数 単語ID:出現数 … 最後の数値はカテゴリID 

単語.txt
1	patty
2	hough
3	sandia

ID\t単語

せっかくなので少しだけコードを改良する.
・交差検定(k-hold cross validation)の導入
・未知語に対してラプラススムージング(\alpha=1)の導入
・事前確率は学習データの全文書と各カテゴリの文書数から求める


コードはそう簡単に綺麗にはなりませんね…

#include <iostream>
#include <fstream>
#include <cstdlib>
#include <cmath>
#include <iomanip>
#include <vector>
#include <algorithm>
#include <map>
#include <sstream>
#include <ctime>
#include <math.h>

using namespace std;

//シャッフル用
vector< int > VecA;
//カテゴリ行列
vector< int > VecC;
//カテゴリごとの記事数
vector< int > VecD;
//カテゴリごとの単語出現確率行列
vector< vector<double> > MatW;
int Dm,Dn,Dc,K;

class NB{
public:
  int word;
  int freq;
  vector< vector <NB> > MatA;
  void readfile(char *fn1);
  void initValue(int s, int t);
  double calNB(int s, int t);
};

class Random
{
public:
  // コンストラクタ
  Random()
  {
    srand( static_cast<unsigned int>( time(NULL) ) );
  }

  // 関数オブジェクト
  unsigned int operator()(unsigned int max)
  {
    double tmp = static_cast<double>( rand() ) / static_cast<double>( RAND_MAX );
    return static_cast<unsigned int>( tmp * max );
  }
};

void NB::readfile(char *fn1){
  ifstream fin;
  int i,j,size,k,l;
  char c;
  NB obj;
  fin.open(fn1);
  if(!fin){
    cerr << "ERROR: Failed to open file" << endl;
    exit(1);
  }
  fin >> Dm >> Dn >> Dc;
  MatA.resize(Dm);
  VecA.resize(Dm);
  VecC.resize(Dm);
  for(i=0;i<Dm;i++){
    fin >> size;
    MatA[i].resize(size);
    for(j=0;j<size;j++){
      fin >> k >> c >> l;
      obj.word = k-1;
      obj.freq = l;
      MatA[i][j] = obj;
    }
    VecA[i] = i;
    fin >> k;
    VecC[i] = k-1; 
  }
  fin.close();
}

//学習データの作成
void NB::initValue(int S, int T){
  int i,j,k;
  VecD.resize(Dc);
  MatW.resize(Dn);
  for(i=0;i<Dn;i++){
    MatW[i].resize(Dc);
    for(j=0;j<Dc;j++){
      MatW[i][j] = 0.0;
    }
  }
  for(i=0;i<Dc;i++) VecD[i] = 0;
  for(i=0;i<Dm;i++){
    if(i < S || i >= T){
      for(j=0;j<MatA[VecA[i]].size();j++){
	MatW[MatA[VecA[i]][j].word][VecC[VecA[i]]] += MatA[VecA[i]][j].freq;
	VecD[VecC[VecA[i]]] += MatA[VecA[i]][j].freq;
      }
    }
  }
}


//カテゴリごとに確率計算.
double NB::calNB(int S, int T){
  int i,j,k,max,count;
  double v,w,P;
  for(i=S,count=0;i<T;i++){
    for(j=0,v=0.0;j<Dc;j++){
      v = log(VecD[j]/(double)(Dm-Dm/K));
      for(k=0;k<MatA[VecA[i]].size();k++) v += log(MatW[MatA[VecA[i]][k].word][j]+1.0/(double)(VecD[VecC[VecA[i]]]+Dn));
      if(j==0){
	max = j;
	P = v;
      }else{
	if(v > P){
	  max = j;
	  P = v;
	}
      }
    }
    if(max = VecC[VecA[i]]) count++;
  }
  return (double)count/(double)(Dm/K);
}

int main(int argc, char **argv){
  int i,j;
  double means, SD, v;
  NB obj;
  Random r;
  //文書ごとの単語行列
  obj.readfile(argv[1]); //readRT.lbl
  cout << "read OK\n";
  K = atoi(argv[2]); //交差検定の分割個数
  vector< double > VecK;
  VecK.resize(10);
  for(j=0;j<10;j++){
    random_shuffle( VecA.begin(),VecA.end() ,r);
    for(i=0,v=0.0;i<K;i++){
      int S = (Dm/K)*i;
      int T = (Dm/K)*(i+1);
      obj.initValue(S,T);
      v += obj.calNB(S,T);
      //cout << v << endl;
    }
    VecK[j] = v/K;
    printf("cal%dok->%e\n",j+1,v/K); 
  }
  for(i=0,means=0.0;i<10;i++){
    means += VecK[i];
  }
  means /= 10;
  for(i=0,SD=0.0;i<10;i++){
    SD += pow(VecK[i]-means,2);
  }
  printf("分類精度=%e, SD=%e",means,SD);
  return 0;
}

結果
交差検定を10回繰り返して平均と偏差を取ってみる

$ ./nb_cross ./readdata.dat 10 > logclos10.txt &

cal1ok->7.785203e-01
cal2ok->7.782816e-01
cal3ok->7.782816e-01
cal4ok->7.782816e-01
cal5ok->7.785203e-01
cal6ok->7.782816e-01
cal7ok->7.785203e-01
cal8ok->7.782816e-01
cal9ok->7.782816e-01
cal10ok->7.782816e-01
分類精度=7.783532e-01, SD=1.196165e-07


$ ./nb_cross ./readdata.dat 100 > logclos100.txt &

cal1ok->7.790244e-01
cal2ok->7.782927e-01
cal3ok->7.787805e-01
cal4ok->7.795122e-01
cal5ok->7.790244e-01
cal6ok->7.773171e-01
cal7ok->7.773171e-01
cal8ok->7.773171e-01
cal9ok->7.780488e-01
cal10ok->7.773171e-01
分類精度=7.781951e-01, SD=6.567519e-06

分類精度は約78%くらい.Kを大きくすると学習データが大きくなるし偏差は小さくなりそうなんだがなぁ…
まだ何か間違っているかもしれないがそれっぽい結果は出ている。

ひとまずよしとして次の章に進みます。