KerasでGAN

Posted on May 31, 2017 by naca_cyan13

前回の記事 - Mnistチュートリアルをやってみた

Kerasは簡単にニューラルネットワークを扱うことができるライブラリです。それを使ってみて学ぼう、というのがこの記事の目的です。Kerasの各機能をまとめ、GANという機械学習モデルを実装しました。

今回も完全に手探りなので、間違いはDisqusで指摘しまくってください。

Kerasとは

Keras Documentation より

Kerasは,Pythonで書かれた,TensorFlowまたはTheano上で実行可能な高水準のニュー ラルネットワークライブラリです. Kerasは,迅速な実験を可能にすることに重点を置いて開発されました. 可能な限り遅れなくアイデアから結果に進められることは,良い研究をする上で重要です.

では具体的にどのように使うのかを見て行きましょう。

Sequential

Kerasを使って学習モデルを構築する上で、コアとなるものがSequentialオブジェクトです。 Sequentialはモデルのレイヤーを積み重ねたものです。

例えば入力層784次元、中間層1024次元、出力層10次元のニューラルネットワークの構築 は以下のコードで実装できます。

from keras.models import Sequential
from keras.layers import Dense, Activation

model = Sequntial()
model.add(Dense(1024, input_dim=784))
model.add(Activation('sigmoid'))
model.add(Dense(10))
model.add(Activation('softmax'))

次に、各レイヤーのオブジェクトとしてどんなものが用意されているのかまとめます。

Dense, Activate

上の例で登場したDenseは全結合ニューラルネットワークレイヤーです。 Activationは活性化レイヤーで、RuLUsigmoidtanhなど多数用意されていま す。

Convolution, UpSampling

Convolutionは畳み込みニューラルネットワークレイヤー(CNN)です。CNNについては後 述します。UpSamplingとはベクトルの次元を増やすもので、画像を引き伸ばすようなイ メージのものです。

その他

その他にもMaxPoolingLocallyConnectedなどのレイヤーオブジェクトが用意されて います。 しかし、今回のテーマには使わなかったので今後使う機会があったらまとめようと思いま す。

CNNとは

CNN(Convolutional Neural Network)とは、ニューラルネットワークに畳み込みフィルタ の概念を持ち込んだものです。画像におけるCNNを、数式を使って説明します。

\mathrm{x} を画像、MN を画像のサイズ、\omegaを畳み込みフィルタ、 mn を畳み込みフィルタのサイズ、\mathrm{b}をバイアス、 k をフィルタのインデックスとしたとき、出力\mathrm{a}^{k}

\displaystyle a_{ij}^{k} =\sum_{s=0}^{m-1} \sum_{t=0}^{n-1} \omega_{st}^{k} x_{i + s, j + t} + b^{k}

となります。ここで、CNNレイヤーを通ったあとのテンソルのランクを考えます。もとの画像が行列、つまり2階テンソル。これがフィルタの数だけ積み上がるので、3階のテンソルになります。

次に、CNNにおける学習を見てみましょう。学習すべきパラメータは\omegabです。これをバックプロパゲーションで学習させます。損失関数をEとすると、修正量\Delta \omega_{st}^k\Delta b^kはそれぞれ

\displaystyle \Delta \omega_{st}^k = - \eta_{\omega} \sum_{i=0}^{M-m} \sum_{j=0}^{N-n} \frac{\partial E}{ \partial a_{ij}^k} \frac{\partial a_{ij}^k}{\partial \omega_{st}^{k}} \displaystyle \Delta b^k = - \eta_{b} \sum_{i=0}^{M-m} \sum_{j = 0}^{N - n} \frac{\partial E}{\partial a_{ij}^k} \frac{\partial a_{ij}^k}{\partial b^k}

となります。ただし\eta_{\omega}\eta_{b}は学習係数です。もっと多層になる場合も、連鎖律を使って分解すれば求めることができます。


改めて今回やったこと

GANという学習モデルを使って、りんごの画像を読み取らせ、それに似たような画像を出力させてみました。

GANとは

GAN(Generative Adversarial Network)とは、生成モデルGeneratorと、識別モデルDiscriminatorを使って、訓練データと似たようなものを生成するための学習モデルです。

学習について説明します。まずに乱数のベクトルを入力します。それを元に、Generatorが画像をネットワークを使って生成します。その生成された画像と訓練データを交互にDiscriminatorに入力します。このとき、Discrimatorはその画像が訓練データであるなら1を、そうでない(偽物)なら0を出力させるようにします。その教師信号をGeneratorにも伝播させます。

この学習により、Discriminatorはより正確に生成された画像と訓練データを識別できるようになり、Generatorはより正確に訓練データとそっくりな画像を生成する事ができるようになります。

今回使った学習モデル

Kerasには学習モデルを画像として出力する機能があります。以下が今回使った学習モデルの画像です。実装はここのコードをかなり参考にさせていただきました。

右:Generator、左:Discriminator

訓練データ

googleさんで適当に「りんご 画像」と検索して出てきたものをダウンロードしました。

りんご 画像

結果

batch 0

batch 2500

batch 5000

まとめ

ハイパーパラメータの調整方法を知らないのと、RGBの3色の色ベクトルと畳み込みでできたチャネルをごっちゃにしてCNNに流したりしているので、色がバラバラになっている部分があります。

また、こんなにもニューラルネットワークがメモリを食うものとは思わなかったので、最初は256x256(x3)の画像を読ませようとしてpythonに怒られたりと、そのあたりで学ぶことが多かったです。近いうちにRGB画像対応のGANのもっとよい実装をやってみたいと思います。

GANの実装を通して、Kerasやニューラルネットワークについて知るという目的は概ね達成されたので満足です。

今後やりたいこと

次は画像から離れ、RNN(Recurrent Neural Network)などをつかって、チャットボットなどを作れたらいいなと思っています。

参考文献