2021年時点 「PyTorchで始める深層学習」の修正点(6章-3)
小泉訓著の「PyTorchで始める深層学習」の6章-3、手書き文字の分類で修正点がいくつかあったので、ここにメモする。
(2018年の本なので仕方ないことである)
①
sklearnのdatasetsでmnistデータをダウンロードするコードがあるが、本にあるfetch_mldataはもう使えないらしい。
なのでコードは以下のように書き換えると上手くいく。
mnist = datasets.fetch_openml('mnist_784', version=1, data_home="./data/")
②
train_X, test_X, train_Y, test_Y = model_selection.train_test_split(
mnist_data, mnist_label, train_size = train_size, test_size = test_size
)
で得たtrain_y、test_yはobject型であるため、その後の処理
train_Y = torch.from_numpy(train_Y).long()
などでエラーとなる。
これの対処法は
object型→numpyのint
に変えておくことだ。
この様なコードを加えると良い。
train_Y = train_Y.astype(np.int)
test_Y = test_Y.astype(np.int)
③
次はエラーでは無いが、pytorchでは
from torch.autograd import Variable
はもう必要無くなっている(depricated)。
よってこれを含むコードは飛ばしてもよい。
(あっても良い)
④
学習の部分で
total_loss += loss.data[0]
とあるのだが、以下の様に書き換える必要がある。
total_loss += loss.item()
以上で手書き文字の認識コードが動く。