Training using progressive resizing
Quote from Fastbook:
jargon: progressive resizing: Gradually using larger and larger images as you train.
Progressive resizing is a very effective technique for training a model from scratch or with transfer learning, and IceVision now offers good support for it.
For more information about the progressive resizing technique, check out the references below:
Fastai Fastbook Chapter: https://github.com/fastai/fastbook/blob/master/07_sizing_and_tta.ipynb
Check out the section: Progressive Resizing
Paper highlighting the importance of progressive resizing:
Introduction
This tutorial walks you through the different steps of training the fridge dataset using the progressive resizing technique.
The main differences from standard IceVision training are:
- The get_dataloaders() method:
# DataLoaders
ds, dls = get_dataloaders(model_type, [train_records, valid_records], [train_tfms, valid_tfms], batch_size=16, num_workers=2)
- Replacing the dataloaders (corresponding to different image sizes) in either a Fastai Learner object or a Pytorch-Lightning Trainer object, as follows:
- For Fastai:
# Replace current dataloaders by the new ones (corresponding to the new size)
learn.dls = fastai_dls
# Standard training
learn.lr_find()
learn.fine_tune(10, 1e-4, freeze_epochs=1)
- For Pytorch-Lightning:
# Replace current dataloaders by the new ones (corresponding to the new size)
trainer.train_dataloader = dls[0]
trainer.valid_dataloader = dls[1]
# Standard training
trainer.fit(light_model, dls[0], dls[1])
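Putting the two steps together, a full progressive-resizing run boils down to a loop like the sketch below (Fastai variant, built only from the calls shown above; it assumes the model, learner, records, and transform helpers introduced later in this tutorial):
# Sketch: one training pass per (size, presize) pair, reusing the same model weights
for size, presize in [(384, 512), (512, 640), (640, 768)]:
    train_tfms = tfms.A.Adapter([*tfms.A.aug_tfms(size=size, presize=presize), tfms.A.Normalize()])
    valid_tfms = tfms.A.Adapter([*tfms.A.resize_and_pad(size), tfms.A.Normalize()])
    ds, dls = get_dataloaders(model_type, [train_records, valid_records], [train_tfms, valid_tfms], batch_size=16, num_workers=2)
    learn.dls = convert_dataloaders_to_fastai(dls=dls)  # swap in the new-size dataloaders
    learn.fine_tune(10, 1e-4, freeze_epochs=1)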
Installing IceVision and IceData
Install from PyPI...
# # Torch - Torchvision - IceVision - IceData - MMDetection - YOLOv5 - EfficientDet Installation
# !wget https://raw.githubusercontent.com/airctic/icevision/master/icevision_install.sh
# # Choose your installation target: cuda11 or cuda10 or cpu
# !bash icevision_install.sh cuda11
... or from icevision master
# Torch - Torchvision - IceVision - IceData - MMDetection - YOLOv5 - EfficientDet Installation
!wget https://raw.githubusercontent.com/airctic/icevision/master/icevision_install.sh
# Choose your installation target: cuda11 or cuda10 or cpu
!bash icevision_install.sh cuda11 master
# Restart kernel after installation
import IPython
IPython.Application.instance().kernel.do_shutdown(True)
Imports
from icevision.all import *
from icevision.models.utils import get_dataloaders
from fastai.callback.tracker import SaveModelCallback
INFO - The mmdet config folder already exists. No need to downloaded it. Path : /Users/fra/.icevision/mmdetection_configs/mmdetection_configs-2.16.0/configs | icevision.models.mmdet.download_configs:download_mmdet_configs:17
Datasets: Fridge Objects dataset
The Fridge Objects dataset is a tiny dataset that contains 134 images of 4 classes: can, carton, milk bottle, and water bottle.
IceVision provides very handy methods for loading a dataset, parsing annotations, and more.
# Download the dataset
url = "https://cvbp-secondary.z19.web.core.windows.net/datasets/object_detection/odFridgeObjects.zip"
dest_dir = "fridge"
data_dir = icedata.load_data(url, dest_dir)
# Parser
# Create the parser
parser = parsers.VOCBBoxParser(annotations_dir=data_dir / "odFridgeObjects/annotations", images_dir=data_dir / "odFridgeObjects/images")
# Parse annotations to create records
train_records, valid_records = parser.parse()
parser.class_map
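The class map (also shown in the record output further below) should look like this:
<ClassMap: {'background': 0, 'carton': 1, 'milk_bottle': 2, 'can': 3, 'water_bottle': 4}>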
Creating a model
# Just change the value of selection to try another model
selection = 0
# Some models below need an explicit image size; use the initial training size (see the Transforms section)
image_size = 384
extra_args = {}
if selection == 0:
model_type = models.mmdet.retinanet
backbone = model_type.backbones.resnet50_fpn_1x
elif selection == 1:
# The Retinanet model is also implemented in the torchvision library
model_type = models.torchvision.retinanet
backbone = model_type.backbones.resnet50_fpn
elif selection == 2:
model_type = models.ross.efficientdet
backbone = model_type.backbones.tf_lite0
# The efficientdet model requires an img_size parameter
extra_args['img_size'] = image_size
elif selection == 3:
model_type = models.ultralytics.yolov5
backbone = model_type.backbones.small
# The yolov5 model requires an img_size parameter
extra_args['img_size'] = image_size
model_type, backbone, extra_args
# Instantiate the model
model = model_type.model(backbone=backbone(pretrained=True), num_classes=len(parser.class_map), **extra_args)
Train and Validation Dataset Transforms
Initial size: first pass (size = 384)
# Transforms
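# presize > size: augmentations run on a larger image first, which is then resized
# down to `size`, avoiding artifacts from augmenting small images (fastai's presizing trick)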
presize = 512
size = 384
train_tfms = tfms.A.Adapter([*tfms.A.aug_tfms(size=size, presize=presize), tfms.A.Normalize()])
valid_tfms = tfms.A.Adapter([*tfms.A.resize_and_pad(size), tfms.A.Normalize()])
# DataLoaders
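# get_dataloaders builds both the Datasets (ds) and the DataLoaders (dls) for the train/valid splits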
ds, dls = get_dataloaders(model_type, [train_records, valid_records], [train_tfms, valid_tfms], batch_size=16, num_workers=2)
# dls[0].dataset[0]
ds[0][0]
BaseRecord
samples = [ds[0][0] for _ in range(3)]
show_samples(samples, ncols=3)
common:
- Image size ImgSize(width=384, height=384)
- Filepath: /root/.icevision/data/fridge/odFridgeObjects/images/112.jpg
- Img: 384x384x3 <np.ndarray> Image
- Record ID: 15
detection:
- BBoxes: [<BBox (xmin:131.17572064203858, ymin:239.0191364865004, xmax:325.2123962126736, ymax:312.68812907493447)>, <BBox (xmin:82.22909604282495, ymin:82.72545127450111, xmax:181.53168623735334, ymax:288.81880204425784)>]
- Class Map: <ClassMap: {'background': 0, 'carton': 1, 'milk_bottle': 2, 'can': 3, 'water_bottle': 4}>
- Labels: [4, 1]
DataLoader
model_type.show_batch(first(dls[0]), ncols=4)
Metrics
metrics = [COCOMetric(metric_type=COCOMetricType.bbox)]
Training
IceVision is an agnostic framework, meaning it can be plugged into other DL frameworks such as Fastai and Pytorch-Lightning.
You could also plug it into another DL framework using your own custom code.
Training using fastai
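# SaveModelCallback saves a checkpoint every time the monitored COCOMetric improves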
learn = model_type.fastai.learner(dls=dls, model=model, metrics=metrics, cbs=SaveModelCallback(monitor='COCOMetric'))
learn.lr_find()
SuggestedLRs(lr_min=0.00010000000474974513, lr_steep=0.00015848931798245758)
First Pass: training with size = 384
learn.fine_tune(10, 1e-4, freeze_epochs=1)
epoch | train_loss | valid_loss | COCOMetric | time |
---|---|---|---|---|
0 | 1.303019 | 1.198495 | 0.018335 | 00:07 |
Better model found at epoch 0 with COCOMetric value: 0.018334512022630832.
epoch | train_loss | valid_loss | COCOMetric | time |
---|---|---|---|---|
0 | 1.146912 | 1.077046 | 0.076843 | 00:06 |
1 | 1.078911 | 0.835020 | 0.172249 | 00:06 |
2 | 0.936218 | 0.585196 | 0.317453 | 00:06 |
3 | 0.803767 | 0.500256 | 0.409220 | 00:06 |
4 | 0.701060 | 0.394725 | 0.598500 | 00:06 |
5 | 0.623153 | 0.357143 | 0.649460 | 00:06 |
6 | 0.564359 | 0.360178 | 0.691023 | 00:06 |
7 | 0.518424 | 0.343954 | 0.707619 | 00:06 |
8 | 0.478071 | 0.333479 | 0.735202 | 00:06 |
9 | 0.447511 | 0.331000 | 0.739396 | 00:06 |
Better model found at epoch 0 with COCOMetric value: 0.07684290821192277.
Better model found at epoch 1 with COCOMetric value: 0.17224854342050785.
Better model found at epoch 2 with COCOMetric value: 0.31745342072333244.
Better model found at epoch 3 with COCOMetric value: 0.4092202239902201.
Better model found at epoch 4 with COCOMetric value: 0.5985001096045487.
Better model found at epoch 5 with COCOMetric value: 0.6494598763585641.
Better model found at epoch 6 with COCOMetric value: 0.6910227038257398.
Better model found at epoch 7 with COCOMetric value: 0.7076194302452169.
Better model found at epoch 8 with COCOMetric value: 0.735202499994352.
Better model found at epoch 9 with COCOMetric value: 0.7393958039433105.
Restart the resizing process from here
Subsequent Passes: training with larger sizes (size = 512, then size = 640)
# Second Pass (size = 512)
# presize = 640
# size = 512
# Third Pass (size = 640)
presize = 768
size = 640
train_tfms = tfms.A.Adapter([*tfms.A.aug_tfms(size=size, presize=presize), tfms.A.Normalize()])
valid_tfms = tfms.A.Adapter([*tfms.A.resize_and_pad(size), tfms.A.Normalize()])
ds, dls = get_dataloaders(model_type, [train_records, valid_records], [train_tfms, valid_tfms], batch_size=16, num_workers=2)
dls[0].dataset[0]
BaseRecord
samples = [ds[0][0] for _ in range(3)]
show_samples(samples, ncols=3)
common:
- Image size ImgSize(width=640, height=640)
- Filepath: /root/.icevision/data/fridge/odFridgeObjects/images/112.jpg
- Img: 640x640x3 <np.ndarray> Image
- Record ID: 15
detection:
- BBoxes: [<BBox (xmin:98.14351388888991, ymin:376.93381905417573, xmax:576.7811863685447, ymax:620.2620189083259)>, <BBox (xmin:5.201505318565632, ymin:12.437821266274758, xmax:329.20034254651534, ymax:498.3316159707134)>]
- Class Map: <ClassMap: {'background': 0, 'carton': 1, 'milk_bottle': 2, 'can': 3, 'water_bottle': 4}>
- Labels: [4, 1]
Convert Pytorch DataLoaders to Fastai DataLoaders
The Fastai Learner expects Fastai DataLoaders, so the plain pytorch dataloaders returned by get_dataloaders() have to be wrapped first.
from icevision.engines.fastai import *
fastai_dls = convert_dataloaders_to_fastai(dls=dls)
# Replace current dataloaders by the new ones (corresponding to the new size)
learn.dls = fastai_dls
print(fastai_dls[0])
print(learn.dls[0])
learn.lr_find()
<icevision.engines.fastai.adapters.convert_dataloader_to_fastai.convert_dataloader_to_fastai.<locals>.FastaiDataLoaderWithCollate object at 0x7fe804b4b710>
<icevision.engines.fastai.adapters.convert_dataloader_to_fastai.convert_dataloader_to_fastai.<locals>.FastaiDataLoaderWithCollate object at 0x7fe804b4b710>
SuggestedLRs(lr_min=2.2908675418875645e-07, lr_steep=9.12010818865383e-07)
learn.fine_tune(10, 1e-4, freeze_epochs=1)
# learn.fit_one_cycle(10, 2e-4)
epoch | train_loss | valid_loss | COCOMetric | time |
---|---|---|---|---|
0 | 0.200535 | 0.183777 | 0.857028 | 00:13 |
Better model found at epoch 0 with COCOMetric value: 0.8570280948418122.
epoch | train_loss | valid_loss | COCOMetric | time |
---|---|---|---|---|
0 | 0.192144 | 0.188399 | 0.846013 | 00:10 |
1 | 0.184424 | 0.167971 | 0.865162 | 00:10 |
2 | 0.192361 | 0.194635 | 0.853915 | 00:10 |
3 | 0.191424 | 0.195987 | 0.837771 | 00:10 |
4 | 0.192338 | 0.186021 | 0.828367 | 00:10 |
5 | 0.189835 | 0.152657 | 0.869115 | 00:10 |
6 | 0.184088 | 0.163762 | 0.859548 | 00:10 |
7 | 0.179976 | 0.155070 | 0.868776 | 00:10 |
8 | 0.175261 | 0.156499 | 0.867725 | 00:10 |
9 | 0.171284 | 0.156758 | 0.863302 | 00:10 |
Better model found at epoch 0 with COCOMetric value: 0.8460132624669883.
Better model found at epoch 1 with COCOMetric value: 0.8651620055945506.
Better model found at epoch 5 with COCOMetric value: 0.8691152429678355.
model_type.show_results(model, ds[1], detection_threshold=.5)
Inference
Predicting a batch of images
Instead of predicting a whole list of images at once, we can process a small batch at a time: this option is more memory efficient.
infer_dl = model_type.infer_dl(ds[1], batch_size=4, shuffle=False)
preds = model_type.predict_from_dl(model, infer_dl, keep_images=True)
show_preds(preds=preds[:4])
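For a small dataset you can also predict the whole list in one go. A minimal sketch, assuming the predict entry point (which mirrors predict_from_dl but takes the dataset directly):
# Predict all validation samples at once (less memory efficient than batching)
preds = model_type.predict(model, ds[1], keep_images=True)
show_preds(preds=preds[:4])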
Training using Pytorch Lightning
Follow the same procedure as in the Fastai example. The only difference is that we don't need to convert the dataloaders: Pytorch-Lightning dataloaders are just pytorch dataloaders.
# Create a model
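# ModelAdapter wraps the IceVision model as a LightningModule; we only need to supply the optimizer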
class LightModel(model_type.lightning.ModelAdapter):
def configure_optimizers(self):
return SGD(self.parameters(), lr=1e-4)
light_model = LightModel(model, metrics=metrics)
# Create a trainer
trainer = pl.Trainer(max_epochs=10, gpus=1)
# First Pass (size = 384)
# Transforms
presize = 512
size = 384
train_tfms = tfms.A.Adapter([*tfms.A.aug_tfms(size=size, presize=presize), tfms.A.Normalize()])
valid_tfms = tfms.A.Adapter([*tfms.A.resize_and_pad(size), tfms.A.Normalize()])
# Dataloaders
ds, dls = get_dataloaders(model_type, [train_records, valid_records], [train_tfms, valid_tfms], batch_size=16, num_workers=2)
First training
trainer.fit(light_model, dls[0], dls[1])
# Second Pass (size = 512)
presize = 640
size = 512
# Third Pass (size = 640)
# presize = 768
# size = 640
train_tfms = tfms.A.Adapter([*tfms.A.aug_tfms(size=size, presize=presize), tfms.A.Normalize()])
valid_tfms = tfms.A.Adapter([*tfms.A.resize_and_pad(size), tfms.A.Normalize()])
# Dataloaders
ds, dls = get_dataloaders(model_type, [train_records, valid_records], [train_tfms, valid_tfms], batch_size=16, num_workers=2)
# Replace current dataloaders by the new ones (corresponding to the new size)
trainer.train_dataloader = dls[0]
trainer.valid_dataloader = dls[1]
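# Note: passing the new dataloaders to trainer.fit below is what actually applies them;
# depending on your Pytorch-Lightning version, the two assignments above may be redundant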
# Subsequent training
trainer.fit(light_model, dls[0], dls[1])
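Note: depending on your Pytorch-Lightning version, calling fit a second time on the same trainer may stop immediately because max_epochs has already been reached. If so, a fresh trainer for the subsequent pass is a safe workaround (a sketch; this notebook reuses the same trainer):
# Create a fresh trainer for the subsequent pass (assumption: needed on newer PL versions)
trainer = pl.Trainer(max_epochs=10, gpus=1)
trainer.fit(light_model, dls[0], dls[1])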
Saving Model on Google Drive
from google.colab import drive
drive.mount('/content/gdrive', force_remount=True)
root_dir = Path('/content/gdrive/My Drive/')
torch.save(model.state_dict(), root_dir/'icevision/models/fridge/fridge_retinanet_prog_resizing_1.pth')
Mounted at /content/gdrive
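To reload the weights later, rebuild the model and load the state dict (a sketch using standard pytorch calls and the same path as above):
# Rebuild the architecture, then load the saved weights
model = model_type.model(backbone=backbone(pretrained=False), num_classes=len(parser.class_map), **extra_args)
model.load_state_dict(torch.load(root_dir/'icevision/models/fridge/fridge_retinanet_prog_resizing_1.pth'))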
Happy Learning!
If you need any assistance, feel free to join our forum.