2021年時点 「PyTorchで始める深層学習」の修正点(7章-4)
小泉訓著の「PyTorchで始める深層学習」の7章-5、自分のデータセット画像の分類で修正点がいくつかあったので、ここにメモする。
(2018年の本なので仕方ないことである)
① リスト7-10のtrain画像のパス
本には
files = os.listdir("./hymenoptera_data/" + d)
と書いてあるが、これではtrainフォルダ内の画像にはアクセスできない。
よって、
files = os.listdir("./hymenoptera_data/train/" + d)
と書き換えると良い。
② train画像の種類
trainフォルダにはgifなども混じっている。よって、リスト7-10では
という様に書き換えれば上手くいく。
③ Variableについて
次はエラーでは無いが、pytorchでは
from torch.autograd import Variable
はもう必要無くなっている(depricated)。
よってこれを含むコードは飛ばしてもよい。
(あっても良い)
④lossの書き方
学習の部分で
total_loss += loss.data[0]
とあるのだが、以下の様に書き換える必要がある。
total_loss += loss.item()
以上で画像分類のコードが動く。