Chainer/MNIST

提供: fukudat
移動: 案内検索

MNIST とは,手書き文字認識を題材にした,Machine Learning 業界の標準ベンチマークデータセット.Chainer のソースコードには MNIST を使った手書き文字認識モデルのサンプルが付属している.

入手方法

Chainer/examples/mnist から MNIST のサンプルをゲット.

大して大きくないので git で Chainer 丸ごと clone しても良いし,サンプルファイル train_mnist.py だけをダウンロードしても良い.

丸ごと clone の場合,

$ git clone https://github.com/chainer/chainer chainer
$ cd chainer
$ git checkout v3.1.0

ここで,v3.1.0 は使用している chainer のバージョン (pip list で表示される).バージョンを合わせておかないと思わぬエラーが出ることがある.ソースが丸ごとコピーされるので,chainer/examples/mnist/train_mnist.py を使う.

個別にダウンロードするなら,https://raw.githubusercontent.com/chainer/chainer/master/examples/mnist/train_mnist.py を mnist ディレクトリにゲット.

$ mkdir mnist
$ cd mnist
$ wget https://raw.githubusercontent.com/chainer/chainer/master/examples/mnist/train_mnist.py
$ chmod a+x train_mnist.py

実行

まずは default で実行してみる.

$ ./train_mnist
GPU: -1
# unit: 1000
# Minibatch-size: 100
# epoch: 20

epoch       main/loss   validation/main/loss  main/accuracy  validation/main/accuracy  elapsed_time
1           0.191714    0.0924508             0.94215        0.9721                    30.0245       
2           0.0734554   0.10176               0.9773         0.9673                    64.1463       
...

終了すると ./result というディレクトリの下に以下のファイルが作成される.

$ ls -1 ./result
accuracy.png         # epoch vs accuracy 曲線        
cg.dot               # network 構造を記録したファイル. dotコマンドで表示できる.
log                  # 学習の進行具合を記録した JSON データ                 
loss.png             # epoch vs loss 曲線             
snapshot_iter_12000  # モデルの状態を保存したファイル.このファイルから学習を継>続できる.

train_mnist.py には以下のオプションが存在する.

--unit #unit
ユニット数を指定する.default = 1000
--epoch #epoch
エポック数 (データを読み込む回数) を指定する.default = 20
--batchsize #batchsize
ミニバッチで読み込むイメージの枚数.default = 100
--frequency #interval
snapshot をとる頻度.default = -1 (途中でスナップショットを取らない.終了時には取る)
--noplot
accuracy.png, loss.png を出力しない.default = 出力する
--out directory
結果を出力するディレクトリを指定する.default = result
--resume snapshotfile
snapshotfileにある途中結果を読み込んで学習を始める.default = 最初から学習する

例えば,result/snapshot_iter_12000 を読み込んで,epoch 30 まで学習を進めるには,

$ ./train_mnist.py --resume result/snapshot_iter_12000 --epoch 30
GPU: -1
# unit: 1000
# Minibatch-size: 100
# epoch: 30

epoch       main/loss   validation/main/loss  main/accuracy  validation/main/accuracy  elapsed_time
1           0.193788    0.0922422             0.941267       0.9697                    25.6605       
2           0.0743671   0.0806236             0.9769         0.9736                    52.6914       
...
20          0.0072788   0.100755              0.99785        0.9836                    570.553       
20          0.000123401                        1                                        571.249       
21          0.00131836  0.0921488             0.999616       0.9867                    607.077       
 

グラフ構造の表示

学習が始まると,result/cg.dot ファイルにネットワークグラフの構造が出力される. 人間が読める形式ではないので,次のコマンドで可視化する.ただし,graphviz パッケージがインストールされている必要があるので,Chainer/Ubuntu#graphvizのInstall or Chainer/Mac#graphvizのInstall を参照してインストールしておく.

$ cd result
$ dot  -Tpng cg.dot -o cg.png

すると,cg.png ファイルができるのでグラフ構造をビジュアルに表示できる.

個人用ツール
名前空間

変種
操作
案内
ツールボックス