2021年時点 「PyTorchで始める深層学習」の修正点(6章-3)

小泉訓著の「PyTorchで始める深層学習」の6章-3、手書き文字の分類で修正点がいくつかあったので、ここにメモする。

(2018年の本なので仕方ないことである)

 

sklearnのdatasetsでmnistデータをダウンロードするコードがあるが、本にあるfetch_mldataはもう使えないらしい。

なのでコードは以下のように書き換えると上手くいく。

mnist = datasets.fetch_openml('mnist_784', version=1, data_home="./data/")

 

train_X, test_X, train_Y, test_Y = model_selection.train_test_split(
mnist_data, mnist_label, train_size = train_size, test_size = test_size
)

で得たtrain_y、test_yはobject型であるため、その後の処理

train_Y = torch.from_numpy(train_Y).long()

などでエラーとなる。

 

これの対処法は

object型→numpyのint

に変えておくことだ。

 

この様なコードを加えると良い。

train_Y = train_Y.astype(np.int)
test_Y = test_Y.astype(np.int)

 

次はエラーでは無いが、pytorchでは

from torch.autograd import Variable

はもう必要無くなっている(depricated)。

よってこれを含むコードは飛ばしてもよい。

(あっても良い)

 

学習の部分で

total_loss += loss.data[0]

とあるのだが、以下の様に書き換える必要がある。

total_loss += loss.item()

 

以上で手書き文字の認識コードが動く。