Python

Tensorflowのモデリング1〜CNN

投稿日:2018年1月23日 更新日:

畳み込みニューラルネットワーク

今までは単純なディープラーニングのモデルを利用してきたけど、畳み込みニューラルネットワーク(CNN)のモデルを試してみた.
畳み込みニューラルネットワークは、画像などの学習に適したモデル.
MNISTのチュートリアルでは、手書きの画像を単純な1次元のベクトル(28×28=784)で入力して、学習していた.
実際の画像の特徴は、部分に現れることも多いハズ.
そこで、28*28*1の画像(*1はチャンネルという.)として入力値として、例えば、5*5*1のフィルターに合致するかどうかをチェックしていく.
フィルターにマッチすればするほど1に近く、逆に全く反転している場合は−1に近くなるようにする.
この計算を畳み込み計算という.このフィルターを学習していくことで特徴を抽出することができる.
こうした計算をする層を畳み込み(Conv)層という.

そして、一定の範囲(ここでは2*2)の最大値を抜き出していく.
そうした処理をすることで多少特徴が上下左右にずれていても、特徴を抽出することができる.
こうした処理をする層をプーリング(Pool)層といい、最大値を抽出するやり方をMAX-Pooling法という.

CNNの仕組みの解説は次の本が詳しい.

ゼロから作るDeep Learning ―Pythonで学ぶディープラーニングの理論と実装

ゼロから作るDeep Learning ―Pythonで学ぶディープラーニングの理論と実装書籍

作者斎藤 康毅

発行オライリージャパン

発売日2016年9月24日

カテゴリー単行本(ソフトカバー)

ページ数320

ISBN4873117585

Supported by amazon Product Advertising API

仕組みはややこしいが、実装は難しくなく、Tensorflowには、Conv層、Pool層用のライブラリが用意されている.
仕組みをある程度理解しておかないと引数の意味がわからないので、そのために理解しておけばいいかと思われる.

Conv層

Deep MNIST for Expertsに従って実装した.

まず、入力した値を28×28×1の形に変換する.
reshapeで-1のところはそのままという意味.ここはバッチトレーニング用のデータの数になる.

次に、学習するためのウェイトとバイアスを定義する.
ここでは、ウェイトはフィルターの形になる.5*5*1の形のフィルターを学習することになる.
最後の32は出力する形.32の行列を出力する.これで次の深層学習のセルに入力することができる.

tf.nn.conv2dを適用する.stridesは移動するベクトル.[0]と[3]の値は1でなければならない.padding=’SAME’はまわりを0で囲む処理(ゼロパディング).
TensorflowでCNNを作る際に使いそうな関数を列挙してみた
[ディープラーニング] パディングとストライド
最後に活性化関数としてRELU関数を適用している.

POOL層

tf.nn.max_poolを適用する.
ksizeは最大値を取る領域の大きさ、stridesは領域の動く範囲.
padding=’SAME’はまわりを0で囲む処理(ゼロパディング).

残りの処理

MNIST FOR EXPERTに従い、設定する.
もう1個CNNレイヤーのあとに全結合層を2つ.

途中で、ドロップアウトを設定.複雑なモデルで過学習を抑制するためにランダムにニューロンを消す仕組み.
tf.nn.dropoutを適用.

ソフトマックスはここでは適用しない.

なぜならば、クロスエントロピーと一緒に計算するから.

結果

チュートリアルでは20000回トレーニングしてるけど、前との比較のため1000回トレーニングしてみる.
めちゃおそい.

学習回数を増やせば99%までいくぽい.

ハイレベルニューラルネットレイヤーを使う

典型的なパターンはtf.layersにまとめられている.抽象度が高くその意味でハイレベル.
A Guide to TF Layers: Building a Convolutional Neural Network

Variableを定義しなくても、次のように書くことができる.引数も少しシンプルに書けるようになっている.

-Python
-, , ,

執筆者:

関連記事

市場テクニカル分析ライブラリta-libとPython用のラッパーTa-Libをインストールする

By: GotCredit – CC BY 2.0 目次1 ta-lib2 Ta-Lib ta-lib ta-libは市場のテクニカル分析用のライブラリ集. linuxの場合はソースコード …

pandasとmatplotlibで株式取引の可視化

By: sprklg – CC BY 2.0 目次1 可視化2 下準備3 pandasのplot4 matplotlibでsubplot5 売買データをプロットする6 番外編 ローソク足チ …

Raspberry Pi3とsense HATで遊ぶ

By: Su Yin Khoo – CC BY 2.0 目次1 Raspberry Pi32 Raspbian3 Raspbianでsshを使う4 raspi-configで初期設定5 R …

Pythonではじめての機械学習~scikit-learn、tensorflow

By: krheesy – CC BY 2.0 目次1 機械学習2 scikit-learn3 tensorflow 機械学習 勉強しているところなので、間違っている箇所もあるかもしれませ …

Pythonではじめての強化学習〜OpenAIGym

By: scarletgreen – CC BY 2.0 目次1 強化学習2 強化学習問題のモデル化〜マルコフ決定過程3 Open AI Gym4 Open AI Gymのインストール5 …