Training#
Train super-image models for image super-resolution tasks.
Setting up the Environment#
Install the library#
We install the super-image library and the Hugging Face datasets library using pip:
pip install -qq datasets super-image
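To quickly verify the installation, you can import both packages and print their versions. This is a minimal sketch; it assumes both packages expose a __version__ attribute, which is standard for PyPI packages but not guaranteed:
import datasets
import super_image

# Print the installed versions to confirm the install worked
# (assumption: both packages expose __version__, as most PyPI packages do).
print(datasets.__version__)
print(super_image.__version__)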
Loading and Augmenting the Dataset#
We download the Div2k dataset using the Hugging Face datasets library. You can explore more super-resolution datasets on the Hugging Face Hub.
We then follow the pre-processing and augmentation method of Wang et al. (2021). This will take a while, so go grab a coffee.
- Note that you can change bicubic_x4 to any of bicubic_x2, bicubic_x3 or bicubic_x4.
- If you don't want to do augmentation on your dataset, you can just do:
train_dataset = TrainDataset(load_dataset('eugenesiow/Div2k', 'bicubic_x4', split='train'))
- If you want evaluation to be faster, you can use the much smaller Set5 dataset:
eval_dataset = EvalDataset(load_dataset('eugenesiow/Set5', 'bicubic_x4', split='validation'))
from datasets import load_dataset
from super_image.data import EvalDataset, TrainDataset, augment_five_crop
augmented_dataset = load_dataset('eugenesiow/Div2k', 'bicubic_x4', split='train')\
.map(augment_five_crop, batched=True, desc="Augmenting Dataset") # download and augment the data with the five_crop method
train_dataset = TrainDataset(augmented_dataset) # prepare the train dataset for loading PyTorch DataLoader
eval_dataset = EvalDataset(load_dataset('eugenesiow/Div2k', 'bicubic_x4', split='validation')) # prepare the eval dataset for the PyTorch DataLoader
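Since TrainDataset and EvalDataset are prepared for the PyTorch DataLoader, you can sanity-check the data before committing to a long training run. This is a minimal sketch; the exact structure of each item depends on your super-image version, so it only inspects what comes out:
from torch.utils.data import DataLoader

# Sketch: wrap the prepared training set in a vanilla PyTorch DataLoader
# and pull one batch to check everything loads before training.
loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
batch = next(iter(loader))
print(len(train_dataset))  # number of augmented training samples
print(type(batch))         # inspect the batch structure your version returns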
Training the Model#
We then train the model. A GPU is strongly recommended for this step.
from super_image import Trainer, TrainingArguments, EdsrModel, EdsrConfig
training_args = TrainingArguments(
output_dir='./results', # output directory
num_train_epochs=1000, # total number of training epochs
)
config = EdsrConfig(
scale=4, # train a model to upscale 4x
)
model = EdsrModel(config)
trainer = Trainer(
model=model, # the instantiated model to be trained
args=training_args, # training arguments, defined above
train_dataset=train_dataset, # training dataset
eval_dataset=eval_dataset # evaluation dataset
)
trainer.train()
After each epoch of training, the PSNR and SSIM scores on the validation set are reported.
The best model found over the 1000 epochs is saved.
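Once training finishes, you can run the model on an image. The snippet below follows the inference pattern from the super-image README; loading weights from './results' is an assumption based on the output_dir above, so point from_pretrained at wherever your best checkpoint was actually saved:
from PIL import Image
import requests
from super_image import EdsrModel, ImageLoader

# Load the trained weights. './results' matches output_dir above
# (assumption: the Trainer saved the best checkpoint there).
model = EdsrModel.from_pretrained('./results', scale=4)

# Fetch a sample low-resolution image (the butterfly image from Set5).
url = 'https://paperswithcode.com/media/datasets/Set5-0000002728-07a9793f_zA3bDjj.jpg'
image = Image.open(requests.get(url, stream=True).raw)

inputs = ImageLoader.load_image(image)  # convert the PIL image to a model input tensor
preds = model(inputs)                   # upscale 4x

ImageLoader.save_image(preds, './scaled_4x.png')                    # save the upscaled output
ImageLoader.save_compare(inputs, preds, './scaled_4x_compare.png')  # side-by-side comparison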
Try Other Architectures#
- You can try the other architectures in super-image.
- Compare the performance of the different architectures via the leaderboard.
- View the various pretrained models on the Hugging Face Hub.
Here is an example using another architecture, MSRN:
from super_image import Trainer, TrainingArguments, MsrnModel, MsrnConfig
training_args = TrainingArguments(
output_dir='./results_msrn', # output directory
num_train_epochs=2, # total number of training epochs
)
config = MsrnConfig(
scale=4, # train a model to upscale 4x
bam=True, # use balanced attention
)
model = MsrnModel(config)
trainer = Trainer(
model=model, # the instantiated model to be trained
args=training_args, # training arguments, defined above
train_dataset=train_dataset, # training dataset
eval_dataset=eval_dataset # evaluation dataset
)
trainer.train()
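To benchmark the trained model against the numbers on the leaderboard, you can evaluate it on a standard test set such as Set5. This sketch is based on the evaluation example in the super-image README; the EvalMetrics helper and its evaluate signature are an assumption from that README, so check your installed version:
from datasets import load_dataset
from super_image.data import EvalDataset, EvalMetrics

# Evaluate the freshly trained MSRN model on Set5
# (assumption: EvalMetrics().evaluate(model, dataset) reports PSNR/SSIM,
# as in the library README; adapt if your version differs).
set5 = EvalDataset(load_dataset('eugenesiow/Set5', 'bicubic_x4', split='validation'))
EvalMetrics().evaluate(model, set5)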