Python

Tensorflowのモデリング2〜RNN

投稿日:

Reccurent Neural Network(RNN)

Reccurent Neural Network(RNN)は、前のデータの出力をインプットに使うモデル.
一個前のデータとの関連を踏まえて学習する仕組み.
そのため、時系列データや文字の並び順が大事なデータなどの学習に適している.
今回は、MNISTで28*28の画像データを28*1のデータ28個に分けて学習してみた.
前の列との関係も含めて学習してみる.

こうしたRNNはある状況で、勾配消失か勾配爆発しやすいとされており、Long short-term memory(LSTM)を使うとよいらしい.
人間の記憶のように短期記憶を長くもたせる仕組みを取り入れているらしい.
LSTMではだいぶ前のステップのデータも記憶しながら、いい感じに忘却することも学ぶ.
これで幾つか前のデータとの関連もうまく学習できるという仕組み.

Tensorflowでは、RNN用のモデルも用意されているので、引数とかの理解のために仕組みがだいたいわかればよい.
次のサイトを参考にMNISTの学習をしてみた.
これを理解できれば自然言語処理もできちゃう? MNISTでRNN(LSTM)を触りながら解説
TensorFlow MNISTをRNNでやってみる

入力値を計算

RNNに入力するのは最初に述べたとおり、28*1を28個の連続データとして入力する.
tf.transposeは、行列の順番を[1,0,2]の順番に入れ替える.

RNNのモデル

cellというRNNの計算単位を設定して、RNN関数にぶち込む.
tf.contrib.rnn.BasicLSTMCellを使う.
forget_biasは最初は忘れにくくするバイアスでデフォルトの1.
rnnの計算をするtf.contrib.rnn.static_rnnにいれる.

損失関数とオプティマイザーの定義

このあたりはだいたい同じで.

学習結果

今回も比較のため1000回の学習に留める.

結果、0.94.

-Python
-, , ,

執筆者:

関連記事

Raspberry Pi3とsense HATで遊ぶ

By: Su Yin Khoo – CC BY 2.0 目次1 Raspberry Pi32 Raspbian3 Raspbianでsshを使う4 raspi-configで初期設定5 R …

Python3で統計の勉強まとめ

By: Anssi Koskinen – CC BY 2.0 目次1 Hackerrank 10 Days of Statistics2 Day0 平均値、中央値、最頻値3 Day1 四分 …

はじめてのKaggle~pandas、scikit-learn

By: Internet Archive Book Images – Flickr Commons 目次1 kaggle2 kaggleコマンドのインストール3 タイタニック:災害の機械学 …

Pythonではじめての機械学習~scikit-learn、tensorflow

By: krheesy – CC BY 2.0 目次1 機械学習2 scikit-learn3 tensorflow 機械学習 勉強しているところなので、間違っている箇所もあるかもしれませ …

Kaggleの機械学習のコースで勉強したまとめ

By: Martin Howard – CC BY 2.0 目次1 Machine Learning Course2 データの事前処理3 Model Validation 学習結果の評価4 …