super-image
State-of-the-art image super resolution models for PyTorch.
Installation#
With pip
:
pip install super-image
Quick Start#
Quickly utilise pre-trained models for upscaling your images 2x, 3x and 4x. See the full list of models below.
from super_image import EdsrModel, ImageLoader
from PIL import Image
import requests
url = 'https://paperswithcode.com/media/datasets/Set5-0000002728-07a9793f_zA3bDjj.jpg'
image = Image.open(requests.get(url, stream=True).raw)
model = EdsrModel.from_pretrained('eugenesiow/edsr-base', scale=2)
inputs = ImageLoader.load_image(image)
preds = model(inputs)
ImageLoader.save_image(preds, './scaled_2x.png')
ImageLoader.save_compare(inputs, preds, './scaled_2x_compare.png')
Pre-trained Models#
Pre-trained models are available at various scales and hosted at the awesome huggingface_hub
. By default the models were pretrained on DIV2K, a dataset of 800 high-quality (2K resolution) images for training, augmented to 4000 images and uses a dev set of 100 validation images (images numbered 801 to 900).
The leaderboard below shows the PSNR / SSIM metrics for each model at various scales on various test sets (Set5, Set14, BSD100, Urban100). The higher the better. All training was to 1000 epochs (some publications, like a2n, train to >1000 epochs in their experiments).
Scale x2#
Rank | Model | Params | Set5 | Set14 | BSD100 | Urban100 |
---|---|---|---|---|---|---|
1 | msrn | 5.9m | 38.08/0.9609 | 33.75/0.9183 | 33.82/0.9258 | 32.14/0.9287 |
2 | mdsr | 2.7m | 38.04/0.9608 | 33.71/0.9184 | 33.79/0.9256 | 32.14/0.9283 |
3 | msrn-bam | 5.9m | 38.02/0.9608 | 33.73/0.9186 | 33.78/0.9253 | 32.08/0.9276 |
4 | edsr-base | 1.5m | 38.02/0.9607 | 33.66/0.9180 | 33.77/0.9254 | 32.04/0.9276 |
5 | mdsr-bam | 2.7m | 38/0.9607 | 33.68/0.9182 | 33.77/0.9253 | 32.04/0.9272 |
6 | awsrn-bam | 1.4m | 37.99/0.9606 | 33.66/0.918 | 33.76/0.9253 | 31.95/0.9265 |
7 | a2n | 1.0m | 37.87/0.9602 | 33.54/0.9171 | 33.67/0.9244 | 31.71/0.9240 |
8 | carn | 1.6m | 37.89/0.9602 | 33.53/0.9173 | 33.66/0.9242 | 31.62/0.9229 |
9 | carn-bam | 1.6m | 37.83/0.96 | 33.51/0.9166 | 33.64/0.924 | 31.53/0.922 |
10 | pan | 260k | 37.77/0.9599 | 33.42/0.9162 | 33.6/0.9235 | 31.31/0.9197 |
11 | pan-bam | 260k | 37.7/0.9596 | 33.4/0.9161 | 33.6/0.9234 | 31.35/0.92 |
Scale x3#
Rank | Model | Params | Set5 | Set14 | BSD100 | Urban100 |
---|---|---|---|---|---|---|
1 | msrn | 6.1m | 35.12/0.9409 | 31.08/0.8593 | 29.67/0.8198 | 29.31/0.8743 |
2 | mdsr | 2.9m | 35.11/0.9406 | 31.06/0.8593 | 29.66/0.8196 | 29.29/0.8738 |
3 | msrn-bam | 5.9m | 35.13/0.9408 | 31.06/0.8588 | 29.65/0.8196 | 29.26/0.8736 |
4 | mdsr-bam | 2.9m | 35.07/0.9402 | 31.04/0.8582 | 29.62/0.8188 | 29.16/0.8717 |
5 | edsr-base | 1.5m | 35.01/0.9402 | 31.01/0.8583 | 29.63/0.8190 | 29.19/0.8722 |
6 | awsrn-bam | 1.5m | 35.05/0.9403 | 31.01/0.8581 | 29.63/0.8188 | 29.14/0.871 |
7 | carn | 1.6m | 34.88/0.9391 | 30.93/0.8566 | 29.56/0.8173 | 28.95/0.867 |
8 | a2n | 1.0m | 34.8/0.9387 | 30.94/0.8568 | 29.56/0.8173 | 28.95/0.8671 |
9 | carn-bam | 1.6m | 34.82/0.9385 | 30.9/0.8558 | 29.54/0.8166 | 28.84/0.8648 |
10 | pan-bam | 260k | 34.62/0.9371 | 30.83/0.8545 | 29.47/0.8153 | 28.64/0.861 |
11 | pan | 260k | 34.64/0.9376 | 30.8/0.8544 | 29.47/0.815 | 28.61/0.8603 |
Scale x4#
Rank | Model | Params | Set5 | Set14 | BSD100 | Urban100 |
---|---|---|---|---|---|---|
1 | msrn | 6.1m | 32.19/0.8951 | 28.78/0.7862 | 28.53/0.7657 | 26.12/0.7866 |
2 | msrn-bam | 5.9m | 32.26/0.8955 | 28.78/0.7859 | 28.51/0.7651 | 26.10/0.7857 |
3 | mdsr | 2.8m | 32.26/0.8953 | 28.77/0.7856 | 28.53/0.7653 | 26.07/0.7851 |
4 | mdsr-bam | 2.9m | 32.19/0.8949 | 28.73/0.7847 | 28.50/0.7645 | 26.02/0.7834 |
5 | awsrn-bam | 1.6m | 32.13/0.8947 | 28.75/0.7851 | 28.51/0.7647 | 26.03/0.7838 |
6 | edsr-base | 1.5m | 32.12/0.8947 | 28.72/0.7845 | 28.50/0.7644 | 26.02/0.7832 |
7 | a2n | 1.0m | 32.07/0.8933 | 28.68/0.7830 | 28.44/0.7624 | 25.89/0.7787 |
8 | carn | 1.6m | 32.05/0.8931 | 28.67/0.7828 | 28.44/0.7625 | 25.85/0.7768 |
9 | carn-bam | 1.6m | 32.0/0.8923 | 28.62/0.7822 | 28.41/0.7614 | 25.77/0.7741 |
10 | pan | 270k | 31.92/0.8915 | 28.57/0.7802 | 28.35/0.7595 | 25.63/0.7692 |
11 | pan-bam | 270k | 31.9/0.8911 | 28.54/0.7795 | 28.32/0.7591 | 25.6/0.7691 |
You can find a notebook to easily run evaluation on pretrained models below:
Train Models#
We need the huggingface datasets library to download the data:
pip install datasets
from datasets import load_dataset
from super_image.data import EvalDataset, TrainDataset, augment_five_crop
augmented_dataset = load_dataset('eugenesiow/Div2k', 'bicubic_x4', split='train')\
.map(augment_five_crop, batched=True, desc="Augmenting Dataset") # download and augment the data with the five_crop method
train_dataset = TrainDataset(augmented_dataset) # prepare the train dataset for loading PyTorch DataLoader
eval_dataset = EvalDataset(load_dataset('eugenesiow/Div2k', 'bicubic_x4', split='validation')) # prepare the eval dataset for the PyTorch DataLoader
The training code is provided below:
from super_image import Trainer, TrainingArguments, EdsrModel, EdsrConfig
training_args = TrainingArguments(
output_dir='./results', # output directory
num_train_epochs=1000, # total number of training epochs
)
config = EdsrConfig(
scale=4, # train a model to upscale 4x
)
model = EdsrModel(config)
trainer = Trainer(
model=model, # the instantiated model to be trained
args=training_args, # training arguments, defined above
train_dataset=train_dataset, # training dataset
eval_dataset=eval_dataset # evaluation dataset
)
trainer.train()