gensimでDoc2Vecと格闘する
今回は少し前に大ブームになっていたらしいDoc2Vec( Word2Vec)です。Doc2Vecでも内部ではWord2Vecが動いているので、どちらにしてもWord2Vecです。gensimを使ってPythonから呼び出そうと思いましたが、困ったことに使い方がさっぱりわかりません。ネット上に転がっているサンプルは、うまく動かなかったり、どこからかコピペしたものを焼き増ししたり、説明なく謎のパラメーター設定をしてたりしていたため、自分の環境ではうまく動かせず、年末・年始を非常に有意義に消化してしまいました。
まだ詳しいところ(各種パラメーター等)理解していませんが、結果が出たので雑にまとめておきます。
各パラメータとサンプルコード
今回も以前Qiitaに投稿した際に使用した、Yahooニュースの原稿を用いて学習をすすめます。
パラメーターについて
- 学習率の初期値alphaは、今回かなり小さめに設定しました。ニュース原稿のような長文で複雑なデータは、非常に発散しやすいためです。
- min_alphaもかなり小さくしています。
- windowは変更しても大きな変化はありませんでした。
- size(ベクトルの次元数)も100以上ではあまり変化がなく、また高次元すぎて類似度が無駄に下がるだけなので100にしています。
- min_countはとりあえず1にしておきます。
- その他のパラメータはデフォルトにします。
- 下手に弄らないほうがよかったです。
- 詳しくは公式を見るのが一番でしょう。gensim: models.doc2vec – Deep learning with paragraph2vec
以下gensimを用いたコード付近の抜粋です。抜粋なのでコピペしても動かないでしょう。
サンプルコード
def sentences(): # wakati_gen = news_tokenizer.read_wakati() wakati_gen = list(news_tokenizer.read_wakati()) random.shuffle(wakati_gen) for category, tokens in wakati_gen: # categoryはニュース記事のカテゴリー(サッカー,野球など) # tokensはニュース原稿を形態素解析して単語を並べたリスト yield doc2vec.LabeledSentence(tokens, tags=[category]) def main(): word = 'microsoft' data = sentences() # ニュース原稿は複雑なデータなので学習率は小さめにしておく model = doc2vec.Doc2Vec(data, size=100, alpha=0.0025, min_alpha=0.000001, window=15, min_count=1) # microsoft に最も似た単語を並べる(未学習) print(model.wv.most_similar(word)) training = 10 for epoch in range(training): data = sentences() model.train(data, total_examples=model.corpus_count, epochs=model.iter) # microsoft に最も似た単語を並べる(学習中) print(model.wv.most_similar(word)) model.save('./data/vector/d2v/category.model')
ネット上に転がっているソースは、学習率のmodel.alpha
やmodel.min_alpha
を訓練ループ毎に減らしているものがありました。ですが、そんなことをしなくてもWord2Vec内でデフォルトで線形に減衰していきます。(非線形に減らしたいなら別ですが)
microsoftに類似するワードの経過
未学習の状態ではmicrosoftとは関係ない「三十三観音」とか「クラフトビール」とかが出てきます。
ですが最終的には「マイクロソフト」や「azure」などが出てきます。「linux」が出てくるのは、WSLが話題になった次期だからでしょうか。
おもしろい
[('pot', 0.40417373180389404), ('クリーク・アンド・リバー', 0.3955865800380707), ('アンドリュー・クライスバーグ', 0.38725024461746216), ( 'クラフトビール', 0.3792040944099426), ('三十三観音', 0.3782668709754944), ('cafe', 0.37654101848602295), ('情報基盤センター', 0.3752627968788147), ('イオンモール岡山', 0.3726036548614502), ('カッチー', 0.3725459575653076), ('輸送用機器', 0.3713015913963318)] [('android', 0.9968464374542236), ('プラットフォーム', 0.9967312812805176), ('azure', 0.9967294931411743), ('windows', 0.9965859055519104), ('aws', 0.9965192675590515), ('ios', 0.9956364035606384), ('アップデート', 0.9955232739448547), ('提供', 0.995313286781311), ('ソリューション', 0.9951471090316772), ('セキュリティ', 0.9951217174530029)] [('connect', 0.9941026568412781), ('linux', 0.9937072992324829), ('azure', 0.9918037056922913), ('更新プログラム', 0.9912915229797363), ('マイクロソフト', 0.9907968640327454), ('プラットフォーム', 0.9906506538391113), ('android', 0.990537166595459), ('windows', 0.9905165433883667), ('platform', 0.9904587864875793), ('商用', 0.9898083209991455)] [('source', 0.9873005151748657), ('マイクロソフト', 0.9856327176094055), ('open', 0.9854219555854797), ('linux', 0.984656810760498), ('sdk', 0.9821251034736633), ('azure', 0.9807733297348022), ('platform', 0.9805046319961548), ('software', 0.9794276356697083), ('oss', 0.9788164496421814), ('aws', 0.9785921573638916)] reading wakati files [('linux', 0.9825512766838074), ('source', 0.9787257313728333), ('open', 0.9778124690055847), ('adobe', 0.9764299392700195), ('platform',0.9740426540374756), ('プラットフォーム', 0.9735702872276306), ('oss', 0.9733889102935791), ('aws', 0.9732674956321716), ('マイクロソフト', 0.9702072739601135), ('プログラミング言語', 0.9694012403488159)] .... [('source', 0.953851580619812), ('open', 0.9504380822181702), ('linux', 0.9310433864593506), ('azure', 0.9303960204124451), ('aws', 0.9150442481040955), ('vmware', 0.9144693613052368), ('プラットフォーム', 0.9121274948120117), ('platform', 0.9118857383728027), ('ソフトウェア', 0.9039658308029175), ('開発者', 0.9016284942626953)]
ソースコード
全コードはここにあります。ニュース原稿はありませんので、自分で集めましょう。
もしくはロンウィットのサイトからライブドアニュースを取ってきて形態素解析するのもよいでしょう。ダウンロード - 株式会社ロンウイット