This notebook demonstrates several techniques for effectively training neural network models using the callbacks mechanism of the fastai library (v1).
Getting the Data
from fastai.vision import *
from fastai.callbacks import *
# Download (if needed) and extract CIFAR-10; `path` is the dataset root and
# also becomes the learner path (history.csv and saved models go under it).
path = untar_data(URLs.CIFAR)
path.ls()
# Build the DataBunch:
#  - scan only the top-level `train` folder: the root also holds the `test`
#    folder that add_test_folder() attaches below, and scanning the whole root
#    would leak those test images into the training/validation split;
#  - hold out a random 10% of the training images for validation;
#  - label each image by its parent folder name;
#  - apply fastai's default augmentations (get_transforms());
#  - attach the `test` folder as the test set;
#  - normalize with ImageNet statistics (used by the pretrained ResNet-50 later).
data = (ImageList.from_folder(path, include=['train'])
        .split_by_rand_pct(0.1)
        .label_from_folder()
        .transform(get_transforms())
        .add_test_folder()
        .databunch(bs=128)
        .normalize(imagenet_stats))
data.classes
data.c
data.show_batch(figsize=(6,6))
We will decode the dataset labels to be more meaningful to a human.
Randomly Initialized Conv Model
# One column per training run holding its per-epoch validation accuracy;
# 5 rows to match the 5-epoch fit_one_cycle calls in this notebook.
results = pd.DataFrame(index=range(5))
# Peak GPU memory allocated so far on device 0, in MiB.
int(torch.cuda.max_memory_allocated(0) / 2**20)
# ResNet-50 trained from scratch (pretrained=False). CSVLogger output is read
# back from path/'history.csv' below; ShowGraph presumably gives live plots.
learn = cnn_learner(
    data,
    models.resnet50,
    metrics=[accuracy],
    pretrained=False,
    callback_fns=[CSVLogger, ShowGraph],
)
# Learning-rate range test, then plot it to pick max_lr for the fit below.
learn.lr_find()
learn.recorder.plot()
learn.fit_one_cycle(5, max_lr=2e-2)
results['resnet50'] = pd.read_csv(path / 'history.csv', usecols=['accuracy'])
results.head()
results.plot()
Pretrained ConvNet Model
# Stage 1: ImageNet-pretrained ResNet-50. The body starts frozen (it is
# unfrozen before stage 2), so this fit mainly trains the new head.
learn = cnn_learner(data, models.resnet50, metrics=[accuracy], pretrained=True, callback_fns=[CSVLogger,ShowGraph])
learn.lr_find()
learn.recorder.plot()
learn.fit_one_cycle(5,max_lr=1e-2)
# history.csv is re-read right after each fit so it reflects that fit only
# (assumes CSVLogger's default non-append mode — confirm).
# Use a distinct column: writing to 'resnet50' would overwrite the
# randomly-initialised run's accuracy curve recorded earlier.
results['resnet50 Stage-1'] = pd.read_csv(path/'history.csv',usecols=['accuracy'])
results.plot()
learn.save('stage-1')
# Stage 2: fine-tune the whole network; slice(lo, hi) spreads learning rates
# across layer groups (fastai convention).
learn.unfreeze()
learn.lr_find()
learn.recorder.plot()
learn.fit_one_cycle(5,max_lr=slice(1e-5,1e-4))
# Record stage 2 as well, so every stage appears in the comparison plot.
results['resnet50 Stage-2'] = pd.read_csv(path/'history.csv',usecols=['accuracy'])
results.plot()
learn.save('stage-2')
# Stage 3: continue fine-tuning with a 10x smaller learning-rate range.
learn.lr_find()
learn.recorder.plot()
learn.fit_one_cycle(5,max_lr=slice(1e-6,1e-5))
results['resnet50 Stage-3'] = pd.read_csv(path/'history.csv',usecols=['accuracy'])
results.plot()
learn.save('stage-3')
# Stage 4: one more fine-tuning pass with a slightly higher learning-rate range.
learn.lr_find()
learn.recorder.plot()
learn.fit_one_cycle(5,max_lr=slice(3e-6,3e-5))
learn.save('stage-4')
results['resnet50 Stage-4'] = pd.read_csv(path/'history.csv',usecols=['accuracy'])
results.plot()
Results
# Inspect the errors of the final model (the learner after the stage-4 fit).
# NOTE(review): from_learner's default dataset is presumably the validation set — confirm.
interp = ClassificationInterpretation.from_learner(learn)
# Per-sample losses and their indices, presumably sorted by decreasing loss;
# kept as top-level names in case later cells use them.
losses,idxs = interp.top_losses()
interp.plot_confusion_matrix(figsize=(8,8))
# Plot the 25 highest-loss samples.
interp.plot_top_losses(25, figsize=(13,13))
# List class pairs the model confuses at least 3 times.
interp.most_confused(min_val=3)