【定番】手書き数字を機械学習してみる PART1
こんにちは、データサイエンティストになりたい、kingsmanです。
このブログは本来、技術系のブログとして、プログラミングのエントリを書いていこうと思っていたのですが、まだ技術系の話を一回も書いていません。
(院試に落ちてから色々悩んでいたため、その余裕がありませんでした。)
このままだといつまでたっても、機械学習ができるようにはならないので、練習がてらエントリを書きます。
お題は機械学習チュートリアルのド定番ともいえる、手書き文字の識別です。
大量の手書き文字を学習しておいて、手書き文字の認識を行います。詳しくはしりませんが、郵便はがきの送り先も、機械で読み取っているらしいですね。
おそらくソフトウェアには、機械学習が使われているではないでしょうか。
とりあえずは教科書のコードを移していこうと思います。(もちろん何が起こっているか、その一行に何の意味があるかを理解しながら)
ですが、教科書通りだと芸がないし、エントリにわざわざ書く意味がないので、少し発展させます。
具体的には、
(ⅰ)テスト用の手書き文字のデータを自分で作る
学習データはScikit learnに入っているデータセットを使います。
教科書のサンプルコードはSVCクラスを使っていますが、ほかにもLinearSVC, NuSVCなどがあるようです。正直SVCが何の略かもわかっていませんが、そこも合わせて勉強します。
今回は(ⅰ)について書きます。というか、まだ(ⅱ)は出来ていません。
(1)手書き文字を書いて、png形式で保存する
ありがたいことに、「お絵描き png」で検索したらぴったりのWebアプリが出てきました。
当然ですがAdobe Illustratorのような高性能なグラフィックデザインソフトは必要ありません。これで十分です。しかもpng形式で簡単に保存できます。
これをPycharmの作業ディレクトリと保存しました。
(PyCharm上の手書き数字8のpngファイル)
これはあくまで画像ファイルなので、これをピクセルデータに変換します。
(2)画像をピクセルデータに変換
一応、コードを書きました。
def image_to_data(imagefile):
import numpy as np
from PIL import Image
image = Image.open(imagefile).convert("L")
image = image.resize((8,8),Image.ANTIALIAS)
img = np.asarray(image,dtype=float)
img = np.floor(16-16*(img/256))
print(img)
from matplotlib import pyplot as plt, cm
plt.imshow(img, cmap=cm.gray_r, interpolation="nearest")
plt.show()
return img
if __name__ == "__main__":
image_8 = image_to_data("./sample8.png")
image_to_data(image_8)
ラスト2行目の"./sample8.png"は、作業ディレクトリにある手書き数字のpngファイル名です。
サンプルコードをいじって作ったのですが、出力はあるものの、エラーっぽいものが出てきます。サンプルを少しいじってくらいで、エラーに悩まされるあたり、まだまだ勉強が必要です。解決したら、記事を修正します。くそコードなので、絶対に参考にしないでください。
(手書きの8を、8×8のグレイスケールで表現した図)
標準ライブラリしか使ってこなかったせいか、新しいライブラリをインポートして使おうとすると、分からないことだらけですね。メソッドのパラメータに何を入れるべきかがさっぱりです。教科書もすべてを解説しているわけではないので、ここは自分でパッケージの使い方を調べるしかなさそうです。
英語は苦手ではないので頑張ります。
勉強すればするほど、分からないことが増えてめげそうです。
地道にコツコツやっていきます。