This notebook demonstrates various techniques for training neural network models effectively using the callbacks mechanism of the fastai library (v1).

Getting the Data

# fastai v1 convention: star imports bring in torch, numpy (np), pandas (pd),
# matplotlib, and the vision/callback APIs used throughout this notebook.
from fastai.vision import *
from fastai.callbacks import *
import fastai
%matplotlib inline
# Download & extract CIFAR-10 (cached under ~/.fastai/data/cifar10).
path = untar_data(URLs.CIFAR) 
# Fix the RNG so the random 80/20 train/valid split below is reproducible.
np.random.seed(10)
data = (ImageList.from_folder(path)
        .split_by_rand_pct(0.2)      # hold out 20% of 'train' for validation
        .label_from_folder()         # class label = parent folder name
        .transform(get_transforms()) # default augmentations (flip, rotate, zoom, ...)
        .add_test_folder()           # attach the unlabelled 'test' folder for inference
        .databunch(bs=128)
        .normalize(imagenet_stats))  # pretrained backbone expects ImageNet statistics
data
ImageDataBunch;

Train: LabelList (48000 items)
x: ImageList
Image (3, 32, 32),Image (3, 32, 32),Image (3, 32, 32),Image (3, 32, 32),Image (3, 32, 32)
y: CategoryList
cat,cat,cat,cat,cat
Path: /home/condor/.fastai/data/cifar10;

Valid: LabelList (12000 items)
x: ImageList
Image (3, 32, 32),Image (3, 32, 32),Image (3, 32, 32),Image (3, 32, 32),Image (3, 32, 32)
y: CategoryList
cat,airplane,deer,automobile,cat
Path: /home/condor/.fastai/data/cifar10;

Test: LabelList (10000 items)
x: ImageList
Image (3, 32, 32),Image (3, 32, 32),Image (3, 32, 32),Image (3, 32, 32),Image (3, 32, 32)
y: EmptyLabelList
,,,,
Path: /home/condor/.fastai/data/cifar10
data.classes
['airplane',
 'automobile',
 'bird',
 'cat',
 'deer',
 'dog',
 'frog',
 'horse',
 'ship',
 'truck']
# Sanity-check one batch of augmented training images.
data.show_batch(figsize=(6,6))
# Accumulator for the per-epoch metrics of every training stage; each stage
# appends the history.csv written by the CSVLogger callback.
results = pd.DataFrame()

Pretrained Resnet34 Model

# Transfer learning: ResNet-34 pretrained on ImageNet. CSVLogger writes
# history.csv next to the data after every fit; ShowGraph plots losses live.
learn = cnn_learner(data, models.resnet34, metrics=[accuracy], pretrained=True, callback_fns=[CSVLogger,ShowGraph])
# LR range test (Leslie Smith): pick max_lr a bit before the loss blows up.
learn.lr_find()
learn.recorder.plot()
LR Finder is complete, type {learner_name}.recorder.plot() to see the graph.
# One-cycle policy for 5 epochs; backbone is still frozen, only the head trains.
learn.fit_one_cycle(5,max_lr=1e-2)
epoch train_loss valid_loss accuracy time
0 1.299677 1.165882 0.620750 01:16
1 0.998541 0.990291 0.671167 01:16
2 0.885245 0.801127 0.721583 01:16
3 0.785631 0.689487 0.760667 01:16
4 0.733769 0.654575 0.772417 01:16
# DataFrame.append was deprecated in pandas 1.4 and removed in 2.0 — use
# pd.concat. ignore_index=True keeps the epoch index monotonically increasing
# across stages (the original first call omitted it; harmless here since
# `results` is empty, but now consistent with the later cells).
results = pd.concat([results, pd.read_csv(path/'history.csv', usecols=['train_loss','valid_loss','accuracy'])], ignore_index=True)
results
train_loss valid_loss accuracy
0 1.299677 1.165882 0.620750
1 0.998541 0.990291 0.671167
2 0.885245 0.801127 0.721583
3 0.785631 0.689487 0.760667
4 0.733769 0.654575 0.772417
# Accuracy and loss curves accumulated so far (x axis = epoch across stages).
acc_ax = results['accuracy'].plot(title="Model's Accuracy")
acc_ax.set_xlabel("Epoch");
loss_ax = results[['train_loss','valid_loss']].plot(title="Model's Loss")
loss_ax.set_xlabel("Epoch");
learn.save('stage-1')
learn.lr_find()
learn.recorder.plot()
LR Finder is complete, type {learner_name}.recorder.plot() to see the graph.
learn.fit_one_cycle(5,max_lr=1e-4)
epoch train_loss valid_loss accuracy time
0 0.710697 0.666801 0.768750 01:17
1 0.698880 0.655413 0.772417 01:17
2 0.695578 0.655291 0.771417 01:16
3 0.703674 0.645886 0.774250 01:15
4 0.688558 0.653952 0.771417 01:17
# DataFrame.append was removed in pandas 2.0 — use pd.concat instead.
results = pd.concat([results, pd.read_csv(path/'history.csv', usecols=['train_loss','valid_loss','accuracy'])], ignore_index=True)
ax = results['accuracy'].plot(title='Model\'s Accuracy')
ax.set_xlabel("Epoch");
ax = results[['train_loss','valid_loss']].plot(title='Model\'s Loss')
ax.set_xlabel("Epoch");
learn.save('stage-2')
learn.unfreeze()
learn.lr_find(start_lr=1e-9)
learn.recorder.plot()
LR Finder is complete, type {learner_name}.recorder.plot() to see the graph.
learn.fit_one_cycle(5,max_lr=slice(1e-6,3e-5))
epoch train_loss valid_loss accuracy time
0 0.691009 0.653497 0.771250 01:20
1 0.704839 0.640239 0.776917 01:20
2 0.680828 0.627638 0.782250 01:20
3 0.669888 0.624195 0.783333 01:20
4 0.670113 0.625992 0.782583 01:20
# DataFrame.append was removed in pandas 2.0 — use pd.concat instead.
results = pd.concat([results, pd.read_csv(path/'history.csv', usecols=['train_loss','valid_loss','accuracy'])], ignore_index=True)
ax = results['accuracy'].plot(title='Model\'s Accuracy')
ax.set_xlabel("Epoch");
ax = results[['train_loss','valid_loss']].plot(title='Model\'s Loss')
ax.set_xlabel("Epoch");
learn.save('stage-3')
learn.lr_find(start_lr=1e-9)
learn.recorder.plot()
LR Finder is complete, type {learner_name}.recorder.plot() to see the graph.
learn.fit_one_cycle(5,max_lr=slice(1e-6,3e-5))
epoch train_loss valid_loss accuracy time
0 0.659414 0.626180 0.781417 01:21
1 0.631002 0.610892 0.786083 01:21
2 0.655484 0.608659 0.788167 01:21
3 0.647288 0.607414 0.787167 01:21
4 0.625431 0.597066 0.791750 01:21
# DataFrame.append was removed in pandas 2.0 — use pd.concat instead.
results = pd.concat([results, pd.read_csv(path/'history.csv', usecols=['train_loss','valid_loss','accuracy'])], ignore_index=True)
ax = results['accuracy'].plot(title='Model\'s Accuracy')
ax.set_xlabel("Epoch");
ax = results[['train_loss','valid_loss']].plot(title='Model\'s Loss')
ax.set_xlabel("Epoch");
learn.save('stage-4')
learn.lr_find(start_lr=1e-9)
learn.recorder.plot()
LR Finder is complete, type {learner_name}.recorder.plot() to see the graph.
learn.fit_one_cycle(5,max_lr=slice(1e-6,1e-5))
epoch train_loss valid_loss accuracy time
0 0.615472 0.597207 0.791667 01:20
1 0.632519 0.596387 0.793417 01:20
2 0.635458 0.599571 0.790833 01:20
3 0.616331 0.595620 0.792583 01:19
4 0.622578 0.599206 0.791833 01:18
# DataFrame.append was removed in pandas 2.0 — use pd.concat instead.
results = pd.concat([results, pd.read_csv(path/'history.csv', usecols=['train_loss','valid_loss','accuracy'])], ignore_index=True)
ax = results['accuracy'].plot(title='Model\'s Accuracy')
ax.set_xlabel("Epoch");
ax = results[['train_loss','valid_loss']].plot(title='Model\'s Loss')
ax.set_xlabel("Epoch");
learn.save('stage-5')
learn.lr_find(start_lr=1e-9)
learn.recorder.plot()
LR Finder is complete, type {learner_name}.recorder.plot() to see the graph.
learn.fit_one_cycle(5,max_lr=slice(1e-5,1e-4))
epoch train_loss valid_loss accuracy time
0 0.638027 0.602997 0.789083 01:23
1 0.621414 0.579477 0.795417 01:23
2 0.586148 0.555526 0.806333 01:22
3 0.552351 0.532826 0.813583 01:22
4 0.526386 0.532579 0.814167 01:22
# DataFrame.append was removed in pandas 2.0 — use pd.concat instead.
results = pd.concat([results, pd.read_csv(path/'history.csv', usecols=['train_loss','valid_loss','accuracy'])], ignore_index=True)
ax = results['accuracy'].plot(title='Model\'s Accuracy')
ax.set_xlabel("Epoch");
ax = results[['train_loss','valid_loss']].plot(title='Model\'s Loss')
ax.set_xlabel("Epoch");
learn.save('stage-6')
learn.lr_find(start_lr=1e-9)
learn.recorder.plot()
LR Finder is complete, type {learner_name}.recorder.plot() to see the graph.
learn.fit_one_cycle(5,max_lr=slice(1e-5,1e-4))
epoch train_loss valid_loss accuracy time
0 0.532331 0.545799 0.811083 01:22
1 0.540433 0.530244 0.816583 01:21
2 0.501745 0.515830 0.819250 01:20
3 0.474019 0.510703 0.822250 01:23
4 0.453968 0.507984 0.824500 01:23
# DataFrame.append was removed in pandas 2.0 — use pd.concat instead.
results = pd.concat([results, pd.read_csv(path/'history.csv', usecols=['train_loss','valid_loss','accuracy'])], ignore_index=True)
ax = results['accuracy'].plot(title='Model\'s Accuracy')
ax.set_xlabel("Epoch");
ax = results[['train_loss','valid_loss']].plot(title='Model\'s Loss')
ax.set_xlabel("Epoch");
learn.save('stage-7')
learn.lr_find(start_lr=1e-10)
learn.recorder.plot()
LR Finder is complete, type {learner_name}.recorder.plot() to see the graph.
learn.fit_one_cycle(5,max_lr=slice(1e-6,1e-5))
epoch train_loss valid_loss accuracy time
0 0.455548 0.501867 0.825500 01:22
1 0.445887 0.502957 0.824917 01:22
2 0.459576 0.502366 0.827000 01:22
3 0.442764 0.502596 0.825417 01:22
4 0.438643 0.499660 0.826333 01:22
# DataFrame.append was removed in pandas 2.0 — use pd.concat instead.
results = pd.concat([results, pd.read_csv(path/'history.csv', usecols=['train_loss','valid_loss','accuracy'])], ignore_index=True)
ax = results['accuracy'].plot(title='Model\'s Accuracy')
ax.set_xlabel("Epoch");
learn.save('stage-8')
learn.load('stage-8');
learn.lr_find(start_lr=1e-11)
learn.recorder.plot()
LR Finder is complete, type {learner_name}.recorder.plot() to see the graph.
learn.fit_one_cycle(1,max_lr=slice(1e-8,1e-7))
epoch train_loss valid_loss accuracy time
0 0.447769 0.506052 0.825333 01:19
# DataFrame.append was removed in pandas 2.0 — use pd.concat instead.
results = pd.concat([results, pd.read_csv(path/'history.csv', usecols=['train_loss','valid_loss','accuracy'])], ignore_index=True)
ax = results['accuracy'].plot(title='Model\'s Accuracy')
ax.set_xlabel("Epoch");
ax = results[['train_loss','valid_loss']].plot(title='Model\'s Loss')
ax.set_xlabel("Epoch");
learn.save('stage-9')
learn.load('stage-9');
learn.lr_find(start_lr=1e-10)
learn.recorder.plot()
LR Finder is complete, type {learner_name}.recorder.plot() to see the graph.
learn.fit_one_cycle(5,max_lr=slice(1e-6,1e-7))
epoch train_loss valid_loss accuracy time
0 0.437188 0.499579 0.826833 01:18
1 0.425254 0.500412 0.827500 01:18
2 0.437297 0.501330 0.826667 01:18
3 0.436054 0.498833 0.827417 01:18
4 0.423017 0.497802 0.827750 01:19
# DataFrame.append was removed in pandas 2.0 — use pd.concat instead.
results = pd.concat([results, pd.read_csv(path/'history.csv', usecols=['train_loss','valid_loss','accuracy'])], ignore_index=True)
ax = results['accuracy'].plot(title='Model\'s Accuracy')
ax.set_xlabel("Epoch");
ax = results[['train_loss','valid_loss']].plot(title='Model\'s Loss')
ax.set_xlabel("Epoch");
learn.export('resnet34-acc.82775')

Results

# Build the interpretation object (runs validation-set predictions once).
interp = ClassificationInterpretation.from_learner(learn)
losses,idxs = interp.top_losses()
interp.plot_confusion_matrix(figsize=(8,8))
from fastai.widgets.class_confusion import ClassConfusion
# (removed leftover `ClassConfusion??` source-introspection — dev-time only,
# breaks the narrative on a clean run)
# Bug fix: `classes` is only defined much later in the notebook (as an
# id->name dict for inference), so it is undefined here on a fresh
# Restart & Run All. The label names at this point live in data.classes.
conf = ClassConfusion(interp, classlist=data.classes)
Please enter a value for `k`, or the top images you will see: 5
100%|██████████| 84/84 [02:03<00:00,  5.18s/it]
<Figure size 432x288 with 0 Axes>
interp.plot_top_losses(25, figsize=(13,13))
# Class pairs confused at least 10 times; per the output below, dog/cat
# confusions dominate by a wide margin.
interp.most_confused(min_val=10)
[('dog', 'cat', 173),
 ('cat', 'dog', 147),
 ('cat', 'frog', 103),
 ('dog', 'horse', 75),
 ('deer', 'frog', 69),
 ('truck', 'automobile', 66),
 ('automobile', 'truck', 61),
 ('bird', 'frog', 60),
 ('cat', 'deer', 53),
 ('ship', 'airplane', 52),
 ('deer', 'horse', 48),
 ('dog', 'deer', 47),
 ('bird', 'airplane', 46),
 ('cat', 'bird', 44),
 ('deer', 'bird', 43),
 ('deer', 'cat', 42),
 ('dog', 'bird', 41),
 ('dog', 'frog', 40),
 ('bird', 'deer', 39),
 ('airplane', 'ship', 37),
 ('ship', 'truck', 35),
 ('horse', 'dog', 34),
 ('airplane', 'truck', 32),
 ('horse', 'cat', 32),
 ('bird', 'cat', 31),
 ('bird', 'dog', 31),
 ('airplane', 'bird', 30),
 ('cat', 'horse', 30),
 ('horse', 'deer', 28),
 ('airplane', 'automobile', 26),
 ('bird', 'horse', 26),
 ('frog', 'bird', 25),
 ('frog', 'cat', 25),
 ('truck', 'airplane', 25),
 ('deer', 'airplane', 24),
 ('deer', 'dog', 23),
 ('ship', 'automobile', 23),
 ('cat', 'truck', 15),
 ('cat', 'airplane', 14),
 ('truck', 'ship', 14),
 ('horse', 'frog', 13),
 ('airplane', 'deer', 12),
 ('automobile', 'ship', 12),
 ('cat', 'ship', 11),
 ('horse', 'bird', 11),
 ('dog', 'airplane', 10),
 ('frog', 'airplane', 10)]
interp.learn.show_results(rows=4,figsize=(10,14))

Inference

# Load the exported Learner for inference, attaching the test folder.
# NOTE(review): the model was exported to `path` (~/.fastai/data/cifar10)
# earlier, but is loaded here from ./cifar10/ — presumably the exported file
# was copied there; confirm the path before a fresh run.
modelpath = Path('./cifar10/')
learn = load_learner(modelpath, file='resnet34-acc.82775', test=ImageList.from_folder(path/'test'))
# NOTE(review): `file` is not a fastai Learner attribute — the lines below just
# set and read an ad-hoc Python attribute and could be dropped.
learn.file = 'resnet34-acc.82775'
learn.file
'resnet34-acc.82775'
learn.file = 'resnet34-acc.82775'
print(learn.file)
resnet34-acc.82775
# Test-time augmentation over the test set, then argmax to class ids.
# NOTE(review): fastai v1's Learner.TTA returns a (preds, targets) tuple, so
# calling .argmax on the result directly looks wrong — presumably this should
# be `preds, _ = learn.TTA(...)`. TODO confirm against the fastai version used.
preds = learn.TTA(ds_type=DatasetType.Test)
preds = preds.argmax(dim=1)
# Map CIFAR-10 label ids (the argmax indices above) to human-readable names.
# Order matches data.classes (alphabetical, as produced by label_from_folder).
classes = {
    idx: name
    for idx, name in enumerate([
        'airplane', 'automobile', 'bird', 'cat', 'deer',
        'dog', 'frog', 'horse', 'ship', 'truck',
    ])
}
# Show 10 random test images titled with their predicted class names.
# NOTE(review): this draw is unseeded at this point, so the sample differs on
# every re-run — fine for eyeballing, not reproducible.
idxs = np.random.randint(0,len(preds),(10,))
xs = learn.data.test_ds.x[idxs]
zs = preds[idxs]
for i in range(10):
    xs[i].show(title=str(classes[int(zs[i])]),figsize=(1,1))
learn.show_results(rows=3,ds_type=DatasetType.Test,figsize=(6,6))
learn.show_results(ds_type=DatasetType.Test,figsize=(10,10),preds=preds)
learn.show_results(rows=3,ds_type=DatasetType.Test,figsize=(6,6))
# NOTE(review): leftover debugging below — `ax` is never defined anywhere in
# this notebook (hidden state from a since-deleted cell), so these lines fail
# on a fresh Restart & Run All and should be removed. The loop also shadows
# `ax` while iterating over it.
ax[0][0].plot(learn.data.test_dl)
(3, 3)
for i,ax in enumerate(ax):
    print(i,ax)
0 [<matplotlib.axes._subplots.AxesSubplot object at 0x7fb19ddf9198>
 <matplotlib.axes._subplots.AxesSubplot object at 0x7fb19f521a90>
 <matplotlib.axes._subplots.AxesSubplot object at 0x7fb19f49eba8>
 <matplotlib.axes._subplots.AxesSubplot object at 0x7fb19f58ec50>
 <matplotlib.axes._subplots.AxesSubplot object at 0x7fb19d42f7f0>]
1 [<matplotlib.axes._subplots.AxesSubplot object at 0x7fb19d3f9d30>
 <matplotlib.axes._subplots.AxesSubplot object at 0x7fb19d659320>
 <matplotlib.axes._subplots.AxesSubplot object at 0x7fb19d8c6898>
 <matplotlib.axes._subplots.AxesSubplot object at 0x7fb19d7dae80>
 <matplotlib.axes._subplots.AxesSubplot object at 0x7fb19d5e1470>]
2 [<matplotlib.axes._subplots.AxesSubplot object at 0x7fb19d4faa20>
 <matplotlib.axes._subplots.AxesSubplot object at 0x7fb19d954fd0>
 <matplotlib.axes._subplots.AxesSubplot object at 0x7fb19d78c5c0>
 <matplotlib.axes._subplots.AxesSubplot object at 0x7fb19d9d2b70>
 <matplotlib.axes._subplots.AxesSubplot object at 0x7fb19d36d160>]
3 [<matplotlib.axes._subplots.AxesSubplot object at 0x7fb19d337710>
 <matplotlib.axes._subplots.AxesSubplot object at 0x7fb19d27fcc0>
 <matplotlib.axes._subplots.AxesSubplot object at 0x7fb19d2ac2b0>
 <matplotlib.axes._subplots.AxesSubplot object at 0x7fb19d24c860>
 <matplotlib.axes._subplots.AxesSubplot object at 0x7fb19d271e10>]
4 [<matplotlib.axes._subplots.AxesSubplot object at 0x7fb19d21c400>
 <matplotlib.axes._subplots.AxesSubplot object at 0x7fb19d23d9b0>
 <matplotlib.axes._subplots.AxesSubplot object at 0x7fb19d1e0f60>
 <matplotlib.axes._subplots.AxesSubplot object at 0x7fb19d18a550>
 <matplotlib.axes._subplots.AxesSubplot object at 0x7fb19d1adb00>]
learn.pred_batch(ds_type=DatasetType.Test)
Exception ignored in: <bound method _MultiProcessingDataLoaderIter.__del__ of <torch.utils.data.dataloader._MultiProcessingDataLoaderIter object at 0x7fb90d853b70>>
Traceback (most recent call last):
  File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/dataloader.py", line 926, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/dataloader.py", line 906, in _shutdown_workers
    w.join()
  File "/usr/lib/python3.6/multiprocessing/process.py", line 124, in join
    res = self._popen.wait(timeout)
  File "/usr/lib/python3.6/multiprocessing/popen_fork.py", line 50, in wait
    return self.poll(os.WNOHANG if timeout == 0.0 else 0)
  File "/usr/lib/python3.6/multiprocessing/popen_fork.py", line 28, in poll
    pid, sts = os.waitpid(self.pid, flag)
KeyboardInterrupt: 
tensor([[6.8408e-01, 5.7007e-03, 5.3190e-02, 3.4258e-03, 7.4425e-03, 4.2081e-03,
         1.0031e-03, 7.4358e-02, 1.5332e-01, 1.3269e-02],
        [8.8712e-02, 1.3003e-03, 1.6897e-02, 6.2706e-03, 6.9148e-02, 3.8449e-03,
         3.3436e-04, 3.6343e-02, 7.7591e-01, 1.2452e-03],
        [7.0001e-01, 8.0732e-03, 4.3004e-02, 2.2150e-02, 2.9275e-02, 4.1997e-03,
         2.7561e-02, 1.8817e-03, 1.5634e-01, 7.5021e-03],
        [1.7121e-04, 2.6147e-05, 1.7973e-05, 2.6687e-05, 4.6262e-06, 4.4637e-06,
         8.9420e-06, 1.5052e-06, 9.9969e-01, 5.0907e-05],
        [3.7220e-02, 1.2537e-03, 1.3567e-02, 2.2394e-04, 7.6938e-05, 3.9906e-05,
         9.5503e-05, 3.1466e-05, 9.4732e-01, 1.7589e-04],
        [2.4263e-03, 5.2237e-02, 2.4334e-04, 1.2789e-04, 2.3181e-05, 5.9032e-05,
         1.4350e-03, 9.5051e-05, 9.4130e-01, 2.0526e-03],
        [1.8045e-03, 2.4069e-05, 2.0543e-04, 4.0256e-06, 2.1597e-06, 1.3024e-06,
         6.4138e-06, 1.5679e-06, 9.9794e-01, 1.2305e-05],
        [1.1888e-02, 8.7392e-02, 2.9185e-04, 9.6800e-04, 5.4154e-05, 4.2001e-05,
         2.8520e-04, 6.0223e-05, 8.2782e-01, 7.1203e-02],
        [5.5805e-04, 1.6555e-03, 1.1028e-03, 3.1171e-04, 8.2866e-04, 3.1551e-04,
         2.2857e-03, 9.8176e-05, 9.8334e-01, 9.5025e-03],
        [8.6505e-03, 3.4907e-04, 3.0368e-04, 1.0471e-04, 3.6354e-03, 7.6966e-05,
         3.0304e-04, 3.5167e-03, 9.7463e-01, 8.4280e-03],
        [3.0857e-04, 1.3648e-02, 3.1474e-04, 4.0362e-04, 1.7455e-05, 1.8675e-04,
         1.2725e-03, 6.3684e-04, 9.8270e-01, 5.1431e-04],
        [1.1699e-04, 1.8148e-05, 4.6238e-05, 5.4841e-06, 1.3519e-06, 9.8452e-06,
         4.5152e-06, 1.0416e-05, 9.9976e-01, 2.8690e-05],
        [7.2473e-06, 2.1709e-06, 3.7460e-07, 9.3380e-08, 1.1154e-08, 1.8529e-07,
         1.8999e-07, 3.4900e-08, 9.9999e-01, 1.5388e-06],
        [3.4056e-03, 2.8286e-02, 8.4777e-04, 3.5079e-03, 1.2765e-03, 2.9098e-03,
         1.9687e-02, 2.0317e-03, 9.0676e-01, 3.1286e-02],
        [1.6008e-05, 1.0401e-05, 1.3696e-06, 4.6332e-06, 4.4693e-07, 5.0362e-07,
         4.3970e-07, 4.1795e-07, 9.9969e-01, 2.7157e-04],
        [1.5209e-04, 1.4402e-04, 6.6023e-05, 3.4667e-05, 1.1050e-05, 2.0418e-05,
         1.5894e-05, 1.2215e-05, 9.9943e-01, 1.0869e-04],
        [1.7585e-02, 3.2862e-01, 1.4537e-03, 9.3943e-03, 1.1344e-03, 5.7568e-04,
         7.9109e-03, 1.1603e-03, 6.2977e-01, 2.4008e-03],
        [3.1852e-03, 1.3149e-02, 1.7228e-04, 2.1794e-04, 1.3239e-04, 6.9000e-05,
         2.5192e-04, 3.4176e-04, 9.5812e-01, 2.4358e-02],
        [4.6982e-04, 2.0873e-04, 1.2923e-03, 3.7042e-04, 2.7356e-05, 1.4631e-04,
         1.9128e-04, 1.1896e-04, 9.9229e-01, 4.8888e-03],
        [1.5854e-02, 1.3884e-02, 8.3481e-04, 5.2603e-03, 5.8387e-04, 8.3195e-04,
         1.6641e-03, 6.4984e-04, 9.1134e-01, 4.9098e-02],
        [3.4516e-04, 6.5250e-05, 6.8618e-06, 1.5830e-06, 5.2314e-07, 1.1869e-06,
         1.5422e-06, 1.3356e-06, 9.9957e-01, 1.0257e-05],
        [3.2370e-03, 1.5353e-03, 4.4592e-05, 2.4029e-05, 1.9324e-05, 2.5402e-05,
         5.9010e-05, 7.0338e-05, 9.9443e-01, 5.5255e-04],
        [3.5302e-03, 2.3571e-03, 1.1423e-04, 1.9935e-04, 6.3655e-05, 2.1379e-05,
         1.2050e-04, 1.1137e-05, 9.9323e-01, 3.5043e-04],
        [2.3451e-05, 5.2255e-06, 3.1127e-06, 2.1761e-06, 1.2159e-06, 5.0042e-07,
         5.9002e-07, 7.9745e-07, 9.9988e-01, 8.1847e-05],
        [1.1080e-03, 1.4036e-03, 4.3716e-04, 8.0268e-04, 1.3389e-04, 1.2233e-03,
         3.4560e-04, 1.1004e-03, 9.9264e-01, 8.0938e-04],
        [5.7497e-05, 1.8224e-05, 5.0266e-06, 3.7927e-06, 5.9268e-07, 3.6134e-07,
         7.9320e-07, 6.2822e-07, 9.9984e-01, 7.7309e-05],
        [5.4783e-04, 6.2582e-04, 4.1998e-04, 1.5549e-05, 1.3988e-05, 2.8148e-05,
         1.1940e-04, 3.4644e-06, 9.9759e-01, 6.3673e-04],
        [2.3260e-04, 2.0917e-02, 6.2222e-05, 2.8709e-04, 4.0610e-05, 4.8510e-05,
         1.0236e-04, 1.6688e-04, 9.7673e-01, 1.4095e-03],
        [1.2970e-04, 7.8655e-05, 1.4382e-05, 4.5519e-05, 5.5774e-06, 6.3306e-06,
         9.8852e-06, 5.1166e-06, 9.9630e-01, 3.4022e-03],
        [1.4305e-02, 6.4228e-02, 7.7209e-04, 3.9936e-03, 7.5476e-04, 6.4526e-04,
         4.7390e-04, 3.4776e-03, 9.0600e-01, 5.3523e-03],
        [4.6988e-01, 7.2976e-04, 1.5712e-03, 1.0244e-03, 2.0497e-03, 5.5418e-05,
         4.0344e-05, 1.3052e-04, 5.1679e-01, 7.7339e-03],
        [6.2238e-05, 6.2101e-04, 3.3958e-05, 1.1921e-05, 4.2233e-06, 2.3170e-05,
         2.9465e-05, 2.5280e-05, 9.9664e-01, 2.5514e-03],
        [2.3308e-01, 1.9573e-02, 5.4823e-01, 2.2979e-02, 4.8375e-02, 1.0574e-02,
         1.1212e-02, 5.2205e-02, 5.0045e-02, 3.7277e-03],
        [5.9598e-02, 1.2880e-02, 1.5692e-03, 1.0012e-03, 4.9919e-03, 7.4143e-04,
         1.1570e-03, 1.7433e-03, 6.8672e-01, 2.2960e-01],
        [7.9180e-02, 1.3316e-02, 4.3819e-03, 2.0526e-03, 6.3495e-04, 3.3066e-04,
         1.1661e-02, 5.6447e-04, 2.6667e-01, 6.2121e-01],
        [1.3946e-05, 5.0354e-06, 1.8152e-06, 3.6558e-06, 4.2602e-07, 2.1876e-07,
         8.6102e-07, 1.2151e-07, 9.9997e-01, 2.8319e-06],
        [4.8494e-05, 3.5102e-06, 6.0758e-06, 5.8964e-06, 1.4340e-06, 2.5628e-06,
         2.2669e-05, 1.5119e-06, 9.9983e-01, 7.3962e-05],
        [7.6004e-02, 2.8450e-01, 4.0106e-02, 1.5867e-02, 1.9982e-03, 5.0583e-03,
         2.4114e-02, 4.7531e-04, 4.2636e-01, 1.2552e-01],
        [4.1234e-04, 1.0982e-03, 9.3264e-05, 6.5309e-04, 1.1622e-05, 1.9614e-05,
         7.7053e-05, 5.1408e-05, 9.9452e-01, 3.0629e-03],
        [2.2695e-03, 2.0149e-01, 1.1189e-03, 3.0826e-02, 1.5222e-03, 3.0781e-03,
         1.0425e-02, 3.6759e-04, 7.3255e-01, 1.6354e-02],
        [7.5925e-04, 3.6030e-04, 5.7776e-05, 7.1331e-05, 2.5138e-05, 4.6625e-06,
         1.8102e-05, 4.6631e-06, 9.9861e-01, 8.7165e-05],
        [7.7294e-01, 8.3702e-03, 1.6192e-02, 2.3444e-03, 1.8206e-02, 1.1239e-03,
         6.0560e-03, 5.5273e-03, 1.1241e-01, 5.6827e-02],
        [1.5257e-03, 9.1004e-05, 4.4130e-05, 1.1393e-05, 6.9907e-06, 1.3442e-06,
         2.5484e-06, 3.4308e-06, 9.9818e-01, 1.3300e-04],
        [1.6020e-05, 8.2537e-05, 5.4327e-06, 9.7800e-06, 2.5828e-06, 3.3646e-06,
         2.1840e-05, 1.3028e-05, 9.9983e-01, 1.8984e-05],
        [4.9410e-04, 5.9237e-05, 3.5377e-05, 3.3948e-05, 2.4046e-06, 5.5708e-06,
         1.0124e-05, 1.6617e-06, 9.9930e-01, 5.8464e-05],
        [4.0622e-03, 1.6547e-04, 3.9655e-04, 1.5595e-03, 1.2394e-04, 3.5269e-04,
         4.3029e-05, 9.1227e-05, 9.8434e-01, 8.8661e-03],
        [1.7587e-02, 1.6534e-01, 3.2706e-04, 1.0901e-04, 1.4627e-04, 1.4403e-05,
         7.3428e-05, 8.5107e-05, 7.9402e-01, 2.2289e-02],
        [2.8088e-03, 2.2361e-04, 8.5391e-05, 2.2682e-04, 4.9494e-05, 1.5324e-05,
         2.4079e-05, 2.0527e-05, 9.9612e-01, 4.2287e-04],
        [2.7577e-04, 2.1623e-04, 2.9526e-05, 1.4715e-04, 2.4754e-05, 1.1524e-05,
         1.2466e-05, 2.0089e-05, 9.9899e-01, 2.6821e-04],
        [2.2405e-04, 1.4956e-05, 1.6650e-04, 4.5493e-05, 8.9055e-05, 1.0261e-05,
         6.3178e-06, 1.5625e-05, 9.9931e-01, 1.1609e-04],
        [2.3056e-04, 3.7985e-05, 3.3086e-05, 1.0165e-04, 1.7689e-05, 1.0451e-05,
         1.1259e-05, 5.9795e-06, 9.9947e-01, 7.8206e-05],
        [2.7728e-04, 1.6038e-03, 7.7481e-06, 8.4107e-05, 3.2046e-06, 3.2471e-06,
         6.8308e-06, 8.5400e-07, 9.9787e-01, 1.4567e-04],
        [3.1002e-04, 8.4504e-05, 1.8937e-05, 1.5231e-05, 4.1803e-06, 3.5561e-06,
         1.3106e-06, 2.1131e-06, 9.9870e-01, 8.5739e-04],
        [8.7196e-05, 2.1540e-05, 8.8427e-06, 3.0097e-06, 9.4019e-07, 8.0467e-07,
         2.5104e-06, 7.3468e-07, 9.9981e-01, 6.1766e-05],
        [1.1707e-03, 1.9382e-04, 2.6993e-05, 1.0871e-05, 4.0469e-06, 3.9220e-06,
         1.5585e-05, 2.7385e-06, 9.9822e-01, 3.5520e-04],
        [1.2498e-04, 1.2529e-06, 1.4651e-06, 3.4074e-07, 6.6630e-08, 2.6237e-07,
         2.0303e-07, 2.0325e-07, 9.9986e-01, 6.4914e-06],
        [2.8939e-05, 6.2726e-05, 1.8810e-05, 2.1980e-05, 2.0815e-06, 2.8081e-06,
         6.8333e-06, 6.2072e-06, 9.9945e-01, 4.0238e-04],
        [2.7279e-02, 1.2098e-03, 4.8523e-04, 1.0481e-04, 3.6851e-04, 6.1317e-05,
         6.0193e-05, 6.0287e-04, 9.6806e-01, 1.7645e-03],
        [3.3046e-03, 1.4413e-05, 1.8892e-05, 3.8076e-06, 3.8060e-05, 2.2704e-06,
         8.7199e-07, 7.5069e-07, 9.9659e-01, 2.4991e-05],
        [1.0898e-03, 6.6100e-03, 8.5384e-05, 1.2267e-04, 3.4761e-06, 2.2966e-05,
         4.5974e-05, 6.5390e-06, 9.9194e-01, 7.2549e-05],
        [1.7583e-03, 1.9188e-04, 1.8919e-03, 8.8702e-05, 2.1734e-05, 3.5054e-05,
         1.6764e-04, 2.0359e-05, 9.8967e-01, 6.1555e-03],
        [3.3622e-03, 2.8995e-03, 1.6684e-04, 2.5759e-04, 1.0425e-04, 3.0977e-05,
         5.6246e-05, 1.2328e-04, 9.9182e-01, 1.1783e-03],
        [7.8772e-04, 5.6673e-03, 1.1697e-04, 1.6290e-04, 8.6451e-05, 1.2833e-05,
         1.0302e-04, 3.8482e-05, 9.9192e-01, 1.1018e-03],
        [6.0275e-04, 4.1923e-05, 8.5904e-04, 1.6200e-05, 1.0102e-05, 8.8638e-06,
         9.5299e-05, 5.9083e-06, 9.9829e-01, 6.6334e-05]])
learn.data.test_ds[1000][0]
classes[int(preds[1000])]
'deer'
learn.data.test_ds[1][0]
classes[int(preds[1])],preds[1]
('cat', tensor(3))
preds = learn.pred_batch(ds_type=DatasetType.Test)
classes[preds]
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-184-f521ba2cf553> in <module>()
----> 1 classes[preds]

TypeError: only integer tensors of a single element can be converted to an index
classes[preds.argmax(dim=1)[1].item()]
'ship'
learn.data.y
tensor([6, 3, 3,  ..., 8, 8, 8])
learn.data.show_batch(49,ds_type=DatasetType.Test,figsize=(10,10))
# Colab-only from here on: list saved checkpoints, then download the export.
from google.colab import files
!ls -l '/root/.fastai/data/cifar10/models/'
total 2225952
-rw-r--r-- 1 root root  91813552 Dec 16 08:07 stage-1.pth
-rw-r--r-- 1 root root  91813527 Dec 16 08:16 stage-2.pth
-rw-r--r-- 1 root root 261966435 Dec 16 08:24 stage-3.pth
-rw-r--r-- 1 root root 261966485 Dec 16 08:33 stage-4.pth
-rw-r--r-- 1 root root 261966479 Dec 16 08:43 stage-5.pth
-rw-r--r-- 1 root root 261966467 Dec 16 08:53 stage-6.pth
-rw-r--r-- 1 root root 261966479 Dec 16 09:02 stage-7.pth
-rw-r--r-- 1 root root 261966473 Dec 16 09:13 stage-8.pth
-rw-r--r-- 1 root root 261966483 Dec 16 09:26 stage-9.pth
-rw-r--r-- 1 root root 261966324 Dec 16 10:38 tmp.pth
!ls -l '/root/.fastai/data/cifar10/'
total 85380
-rw-r--r--  1 root root      217 Dec 16 10:46 history.csv
-rw-r--r--  1 1001 1001       60 Nov 18  2016 labels.txt
drwxr-xr-x  2 root root     4096 Dec 16 10:38 models
-rw-r--r--  1 root root 87407208 Dec 16 10:48 resnet34-acc.82775
drwxr-xr-x 12 1001 1001     4096 Nov 15  2017 test
drwxr-xr-x 12 1001 1001     4096 Nov 15  2017 train
# Download the exported model from the Colab VM to the local machine.
# (The call was duplicated in the original — a single download suffices.)
files.download('/root/.fastai/data/cifar10/resnet34-acc.82775')