トップ 差分 一覧 ソース 検索 ヘルプ RSS ログイン

BugTrack-R備忘録/59

R備忘録 /状態空間モデリング/donlp2/その他のメモ

R備忘録 - 記事一覧

ソースを修正して、nnet(ニューラルネットワーク)の計算速度の改善をはかる

  • 投稿者: みゅ
  • カテゴリ: なし
  • 優先度: 普通
  • 状態: 完了
  • 日時: 2010年04月03日 23時14分01秒

内容

  • メモ
  • まあ、いまさらって感じですが.

リンク

方法

  • OS : LINUX
  • ATLAS化されたBLASを使って高速化.ネットワークが大きい場合(ニューロンの数が多い場合)には有効.
  • R-2.8.0で確認.それ以下でも大丈夫だと思いますが自己責任でお願いします.
  • nnetはR-2.8.0で自動的にインストールされます.ソースを見てみると【VR_7.2-44.tar.gz】の中にnnetがあります.Rのソース(nnet.R)を見ると
tmp <- .C(VR_dovm,
           as.integer(ntr), Z, weights,
    as.integer(length(wts)),
    wts=as.double(wts),
    val=double(1),
    as.integer(maxit),
    as.logical(trace),
    as.integer(mask),
           as.double(abstol), as.double(reltol),
           ifail = integer(1)
    )
  • これがメインのルーティンぽい.Cのソース(nnet.c)を見るとvmminという関数を呼んでいる.
void
VR_dovm(Sint *ntr, Sdata *train, Sdata *weights,
Sint *Nw, double *wts, double *Fmin,
Sint *maxit, Sint *trace, Sint *mask,
double *abstol, double *reltol, int *ifail)
{
   int fncount, grcount;
   NTrain = *ntr;
   TrainIn = train;
   TrainOut = train + Ninputs * NTrain;
   Weights = weights;
   vmmin((int) *Nw, wts, Fmin, fminfn, fmingr, 
  (int) *maxit, (int) *trace, mask,
  *abstol, *reltol, REPORT, NULL, &fncount, &grcount, ifail);
}
  • このvmminというのはRの関数optim()のmethod="BFGS"を指定したときに呼ばれる関数.なので今度はoptim.cのvmmin関数を見てみた.すると行列の演算なんかをBLASを使わずに、べた書きしている.例えば以下.
   //gradproj = 0.0;
   //for (i = 0; i < n; i++) {
   //  s = 0.0;
   //  for (j = 0; j <= i; j++) s -= B[i][j] * g[l[j]];
   //  for (j = i + 1; j < n; j++) s -= B[j][i] * g[l[j]];
   //  t[i] = s;
   //  gradproj += s * g[l[i]];
   //}
  • ここはgradproj := c %*% B %*% c という計算を行っている.こういう行列演算は行列が大きい場合にはATLASを使用すると計算が非常に早くなる.ここは以下のように書くとソースもすっきりして何を計算しているかわかりやすくもなる.
F77_CALL(dsymv)(UPLO, &n, &alphaMOne, B0, &n, c, &INCX, &betaZero, t, &INCY);
gradproj = F77_CALL(ddot)(&n, t, &INCX, c, &INCY);
  • こんな感じでvmmin関数を修正したソースを以下に置いたので、興味のある方は使ってみてください.
  • optim.cのvmminの定義のところを以下でそっくり置き換えて、RをビルドしなおせばOK.
  • vmmin.txt(121) (文字コードはeuc-jpです)
  • バグ報告や質問なんかは別館のほうにお願いします.
  • makeして、「make install」する前に、makeしたディレクトリで下で動作を確認できます.
bin/R
  • として、Rが起動するのでここでnnetの動作確認をして、スピードアップしてることを確認してから「make install」すれば、今のRをアンインストールしなくて済みます.
  • Rは以下でコンフィグしないとlapackを正しく認識してくれない(かも).
./configure --enable-R-shlib 
    --with-blas="-L/usr/local/atlas/lib -llapack -lptcblas -lptf77blas -latlas -L/usr/local/atlas/lib -lptf77blas -latlas -lpthread" 
    --with-lapack

ついでに

  • optim関数の「method = "L-BFGS-B"」の場合も修正しちゃいます.といってもこっちはlinpackを使っていて、その中でBLASを呼んでいるので、別に計算速度がはやくなるわけじゃありません.LINPACK使っているのが、気持ち悪いから.それだけ・・・
  • src/appl/lbfgsb.cを修正します.
//#include <R_ext/Linpack.h>
char *UPLO = "U";
char *TRANS_N = "N";
char *TRANS_T = "T";
char *DIAG = "N";
static const int NRHS = 1;

   //F77_CALL(dtrsl)(&wt[wt_offset], &m, col, &p[*col + 1], &c__11, info);
   F77_CALL(dtrtrs)(UPLO, TRANS_T, DIAG, col, &NRHS, &wt[wt_offset], &m, &p[*col + 1], &m, info);
・・・
   //F77_CALL(dtrsl)(&wt[wt_offset], &m, col, &p[*col + 1], &c__1, info);
   F77_CALL(dtrtrs)(UPLO, TRANS_N, DIAG, col, &NRHS, &wt[wt_offset], &m, &p[*col + 1], &m, info);
・・・
   for (js = *col + 1; js <= col2; ++js) {
       //F77_CALL(dtrsl)(&wn[wn_offset], &m2, col,
       //              &wn[js * wn_dim1 + 1], &c__11, info);
       F77_CALL(dtrtrs)(UPLO, TRANS_T, DIAG, col, &NRHS, &wn[wn_offset], &m2, &wn[js * wn_dim1 + 1], &m2, info);
   }
・・・
   /*Rprintf(" job : %d\n", c__11);
   for( i=0; i<m2; i++ ){
       Rprintf("%d > ", i);
       for( k=0; k<col2; k++ ){
           Rprintf("%f, ", wn[wn_offset+i+m2*k]);
       }
       Rprintf(" b : %f\n", wv[1 + i]);
   }*/
   //F77_CALL(dtrsl)(&wn[wn_offset], &m2, &col2, &wv[1], &c__11, info);
   F77_CALL(dtrtrs)(UPLO, TRANS_T, DIAG, &col2, &NRHS, &wn[wn_offset], &m2, &wv[1], &m2, info);
   if (*info != 0) {
       return;
   }
   /*for( i=0; i<m2; i++ ){
       Rprintf("%d > ", i);
       for( k=0; k<col2; k++ ){
           Rprintf("%f, ", wn[wn_offset+i+m2*k]);
       }
       Rprintf(" b : %f\n", wv[1 + i]);
   }*/
・・・
   //F77_CALL(dtrsl)(&wn[wn_offset], &m2, &col2, &wv[1], &c__1, info);
   F77_CALL(dtrtrs)(UPLO, TRANS_N, DIAG, &col2, &NRHS, &wn[wn_offset], &m2, &wv[1], &m2, info);

コメント