This notebook demonstrates techniques for training neural network models effectively with the callbacks mechanism of the fastai library (v1).

Getting the Data

from fastai.vision import *
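# Imagenette v2: a 10-class subset of ImageNet with train/ and val/ folders
path = untar_data('https://s3.amazonaws.com/fast-ai-imageclas/imagenette2')
# Map the WordNet IDs used as folder names to readable class names
lbl_dict = dict(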
    n01440764='tench',
    n02102040='English springer',
    n02979186='cassette player',
    n03000684='chain saw',
    n03028079='church',
    n03394916='French horn',
    n03417042='garbage truck',
    n03425413='gas pump',
    n03445777='golf ball',
    n03888257='parachute'
)
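def lbl_trfm(o):
    # Each image sits in a folder named after its WordNet ID (e.g. n01440764)
    return lbl_dict[Path(o).parent.name]
tfms = get_transforms()  # default fastai v1 augmentations (flips, rotation, zoom, lighting, warp)
data = (ImageList.from_folder(path)    # collect every image file under path
        .split_by_folder(valid='val')  # train/ -> training set, val/ -> validation set
        .label_from_func(lbl_trfm)     # label each image from its folder's WordNet ID
        .transform(tfms, size=320)     # augment and resize to 320x320
        .databunch(bs=32)              # build the DataLoaders with batch size 32
        .normalize(imagenet_stats))    # standardize with ImageNet channel statistics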
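Both `split_by_folder` and `label_from_func` rely on the folder layout of the extracted archive. A minimal stdlib sketch of that per-file logic, assuming the layout `imagenette2/<split>/<wnid>/<file>`; the function `split_and_label`, the `SPLITS` table and the example file name are hypothetical, and fastai's own handling of files outside both folders may differ:

```python
from pathlib import PurePosixPath

# Two entries copied from lbl_dict above (the other eight are omitted)
LBL_DICT = {'n01440764': 'tench', 'n03888257': 'parachute'}

# Hypothetical table: split-folder name on disk -> dataset role,
# mirroring split_by_folder(valid='val') with the default train='train'
SPLITS = {'train': 'train', 'val': 'valid'}

def split_and_label(p):
    """Return (role, class name) for one image path.

    Assumes the layout imagenette2/<split>/<wnid>/<file>; in this sketch,
    role is None for files outside the 'train' and 'val' folders.
    """
    path = PurePosixPath(p)
    role = SPLITS.get(path.parent.parent.name)  # grandparent folder is the split
    label = LBL_DICT[path.parent.name]          # parent folder is the WordNet ID
    return role, label

# Hypothetical file name, for illustration only
print(split_and_label('imagenette2/val/n03888257/example.JPEG'))  # ('valid', 'parachute')
```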
data
ImageDataBunch;

Train: LabelList (9469 items)
x: ImageList
Image (3, 320, 320),Image (3, 320, 320),Image (3, 320, 320),Image (3, 320, 320),Image (3, 320, 320)
y: CategoryList
parachute,parachute,parachute,parachute,parachute
Path: /root/.fastai/data/imagenette2;

Valid: LabelList (3925 items)
x: ImageList
Image (3, 320, 320),Image (3, 320, 320),Image (3, 320, 320),Image (3, 320, 320),Image (3, 320, 320)
y: CategoryList
parachute,parachute,parachute,parachute,parachute
Path: /root/.fastai/data/imagenette2;

Test: None
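The final `.normalize(imagenet_stats)` step standardizes every channel with the ImageNet mean and standard deviation, the statistics the pretrained ResNet-50 weights expect. A plain-Python sketch of that arithmetic on a single pixel; the constants are the values fastai v1 ships as `imagenet_stats`, while the helper names are hypothetical:

```python
# Per-channel ImageNet statistics (R, G, B), as stored in fastai v1's imagenet_stats
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

def normalize_pixel(rgb, mean=IMAGENET_MEAN, std=IMAGENET_STD):
    """Standardize one RGB pixel (values in [0, 1]) channel by channel."""
    return tuple((x - m) / s for x, m, s in zip(rgb, mean, std))

def denormalize_pixel(z, mean=IMAGENET_MEAN, std=IMAGENET_STD):
    """Invert normalize_pixel, e.g. before displaying an image."""
    return tuple(v * s + m for v, m, s in zip(z, mean, std))

white = (1.0, 1.0, 1.0)
z = normalize_pixel(white)
print([round(v, 3) for v in z])                     # [2.249, 2.429, 2.64]
print([round(v, 3) for v in denormalize_pixel(z)])  # [1.0, 1.0, 1.0]
```

Normalized values routinely fall outside [0, 1], which is why images must be denormalized before they are displayed.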
data.show_batch(rows=5, figsize=(16,16))

Training the Model

# PyTorch initializes CUDA lazily, so this explicit call is optional
if torch.cuda.is_available():
    torch.cuda.init()
from fastai.callbacks import *
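# Pretrained ResNet-50 with a new classification head. CSVLogger records each epoch's
# metrics to a CSV file; ShowGraph redraws the loss curves after every epoch.
learn = cnn_learner(data, models.resnet50, metrics=[accuracy], callback_fns=[CSVLogger, ShowGraph])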
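In fastai v1, each entry of `callback_fns` is a callable that the `Learner` invokes with itself at the start of `fit`, and the resulting callback's hooks (`on_train_begin`, `on_epoch_end`, `on_train_end` and others) fire at fixed points of the training loop. The toy below is a hypothetical, self-contained imitation of that mechanism, not fastai's implementation: `ToyCSVLogger` writes one row per epoch in the spirit of `CSVLogger`, and the losses are fake numbers.

```python
import csv
import io

class Callback:
    """Base class: every hook is a no-op unless a subclass overrides it."""
    def on_train_begin(self, **kwargs): pass
    def on_epoch_end(self, epoch, metrics, **kwargs): pass
    def on_train_end(self, **kwargs): pass

class ToyCSVLogger(Callback):
    """Hypothetical stand-in for fastai's CSVLogger: one CSV row per epoch."""
    def __init__(self, stream):
        self.writer = csv.writer(stream)

    def on_train_begin(self, **kwargs):
        self.writer.writerow(['epoch', 'train_loss', 'valid_loss'])

    def on_epoch_end(self, epoch, metrics, **kwargs):
        self.writer.writerow([epoch, metrics['train_loss'], metrics['valid_loss']])

def toy_fit(n_epochs, callbacks):
    """Training-loop skeleton that only fires callback hooks (no real model)."""
    for cb in callbacks:
        cb.on_train_begin()
    for epoch in range(n_epochs):
        # A real loop would train and validate here; these losses are fake
        metrics = {'train_loss': 1.0 / (epoch + 1), 'valid_loss': 1.2 / (epoch + 1)}
        for cb in callbacks:
            cb.on_epoch_end(epoch=epoch, metrics=metrics)
    for cb in callbacks:
        cb.on_train_end()

buf = io.StringIO()
toy_fit(2, [ToyCSVLogger(buf)])
print(buf.getvalue())
```

The design point is that the loop never needs to know what a callback does: logging, plotting or early stopping all plug in through the same hooks. By default, fastai v1's `CSVLogger` writes to `history.csv` in `learn.path`.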
learn.lr_find()
LR Finder is complete, type {learner_name}.recorder.plot() to see the graph.
learn.recorder.plot()
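`lr_find` runs a learning-rate range test: it trains briefly while raising the learning rate exponentially (fastai v1 defaults: `start_lr=1e-7`, `end_lr=10`, `num_it=100`), records the loss at each step and stops early if the loss diverges. A sketch of the schedule plus a common way to read the resulting plot, picking the LR where the loss drops fastest; the helper names and the synthetic losses are hypothetical, and fastai's exact step indexing may differ:

```python
import math

def lr_schedule(start_lr=1e-7, end_lr=10.0, num_it=100):
    """Exponentially spaced learning rates from start_lr to end_lr (inclusive)."""
    ratio = end_lr / start_lr
    return [start_lr * ratio ** (i / (num_it - 1)) for i in range(num_it)]

def steepest_descent_lr(lrs, losses):
    """Return the LR at which the loss falls fastest with respect to log10(LR)."""
    best_i, best_slope = None, math.inf
    for i in range(1, len(lrs)):
        slope = (losses[i] - losses[i - 1]) / (math.log10(lrs[i]) - math.log10(lrs[i - 1]))
        if slope < best_slope:
            best_i, best_slope = i, slope
    return lrs[best_i]

lrs = lr_schedule()
print(lrs[0], lrs[-1])  # roughly 1e-07 and 10

# Synthetic losses: flat, then a steep drop, then divergence
print(steepest_descent_lr([1e-4, 1e-3, 1e-2, 1e-1, 1.0], [2.0, 1.9, 1.2, 1.0, 3.0]))  # 0.01
```

In recent fastai v1 releases, `learn.recorder.plot(suggestion=True)` marks a point chosen by a similar steepest-gradient rule.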
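learn.freeze()  # cnn_learner already returns a frozen model; the call makes stage 1 explicit
learn.fit_one_cycle(2, max_lr=1e-2)  # stage 1: train the new head on top of the frozen pretrained body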
epoch train_loss valid_loss accuracy time
0 0.578849 0.410253 0.879490 06:13
1 0.187801 0.092428 0.968662 06:03
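`fit_one_cycle` follows the one-cycle policy: the learning rate warms up from `max_lr/div_factor` to `max_lr`, then anneals to a much smaller value, while momentum moves the opposite way. The sketch below reproduces only the learning-rate curve, under what we understand to be the fastai v1 defaults (`div_factor=25`, `pct_start=0.3`, `final_div=div_factor*1e4`, cosine interpolation in both phases); momentum is omitted and the function names are our own:

```python
import math

def annealing_cos(start, end, pct):
    """Cosine interpolation from start (pct=0) to end (pct=1)."""
    return end + (start - end) / 2 * (math.cos(math.pi * pct) + 1)

def one_cycle_lr(pct, max_lr, div_factor=25.0, pct_start=0.3, final_div=None):
    """Learning rate at training progress pct in [0, 1] under a one-cycle policy."""
    if final_div is None:
        final_div = div_factor * 1e4
    if pct < pct_start:
        # Warm-up phase: max_lr/div_factor -> max_lr
        return annealing_cos(max_lr / div_factor, max_lr, pct / pct_start)
    # Annealing phase: max_lr -> max_lr/final_div
    return annealing_cos(max_lr, max_lr / final_div, (pct - pct_start) / (1 - pct_start))

for pct in (0.0, 0.15, 0.3, 0.65, 1.0):
    print(pct, one_cycle_lr(pct, max_lr=1e-2))
```

With `max_lr=1e-2` as in stage 1, this starts at 4e-4, peaks at 1e-2 after 30% of the iterations and ends at 4e-8. `learn.recorder.plot_lr()` plots the schedule fastai actually used.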
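learn.unfreeze()  # stage 2: fine-tune all layer groups
# Discriminative learning rates: 1e-4 for the earliest layers up to 1e-3 for the head
learn.fit_one_cycle(3, max_lr=slice(1e-4, 1e-3))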
epoch train_loss valid_loss accuracy time
0 0.471138 0.314138 0.898344 07:54
1 0.288256 0.203786 0.935287 07:50
2 0.135715 0.108479 0.963822 07:46
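Passing a `slice` as `max_lr` enables discriminative learning rates: fastai v1 spreads `start` to `stop` geometrically over the model's layer groups (three for a ResNet built by `cnn_learner`, as far as we know), so earlier, more general layers change less than the head. Only the slice's `start` and `stop` are read; a step value is ignored. A stdlib sketch of that expansion, modelled on fastai's `even_mults` helper (the function `lr_per_group` is hypothetical):

```python
def even_mults(start, stop, n):
    """n learning rates evenly spaced on a log scale from start to stop (n >= 2)."""
    step = (stop / start) ** (1 / (n - 1))
    return [start * step ** i for i in range(n)]

def lr_per_group(lr, n_groups):
    """Expand a float or a slice into one learning rate per layer group."""
    if not isinstance(lr, slice):
        return [lr] * n_groups  # a plain float applies to every group
    if lr.start:
        return even_mults(lr.start, lr.stop, n_groups)  # lr.step is never used
    # slice(None, stop): stop/10 for every group except the last (the head)
    return [lr.stop / 10] * (n_groups - 1) + [lr.stop]

print(lr_per_group(slice(1e-4, 1e-3), 3))
```

For `slice(1e-4, 1e-3)` over three groups this gives roughly 1e-4, 3.2e-4 and 1e-3.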

Results

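interp = ClassificationInterpretation.from_learner(learn)  # predictions and losses on the validation set
losses, idxs = interp.top_losses()  # per-image losses in decreasing order, with their indices
interp.plot_confusion_matrix(figsize=(8,8))
interp.plot_top_losses(9, figsize=(16,16))  # the 9 validation images with the highest loss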
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
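`ClassificationInterpretation` collects the model's predictions and per-image losses on the validation set. The bookkeeping behind the confusion matrix and the top-losses plot can be sketched in plain Python; the data below is made up, and the helpers only mirror the idea of fastai's `interp.most_confused(min_val=...)` and `interp.top_losses(k)`, not their return types:

```python
from collections import Counter

def confusion_counts(actual, predicted):
    """Count (actual, predicted) pairs; off-diagonal pairs are mistakes."""
    return Counter(zip(actual, predicted))

def most_confused(actual, predicted, min_val=1):
    """Off-diagonal (actual, predicted, count) triples, most frequent first."""
    counts = confusion_counts(actual, predicted)
    pairs = [(a, p, n) for (a, p), n in counts.items() if a != p and n >= min_val]
    return sorted(pairs, key=lambda t: t[2], reverse=True)

def top_losses(losses, k):
    """Indices of the k samples with the largest loss, largest first."""
    return sorted(range(len(losses)), key=lambda i: losses[i], reverse=True)[:k]

# Made-up predictions and losses for four validation images
actual = ['church', 'church', 'gas pump', 'tench']
predicted = ['church', 'gas pump', 'gas pump', 'tench']
losses = [0.05, 2.3, 0.10, 0.02]

print(most_confused(actual, predicted))  # [('church', 'gas pump', 1)]
print(top_losses(losses, 2))             # [1, 2]
```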