Multi-scale Residual Network for Image Super-Resolution (MSRN)
Overview
The MSRN model proposes a feature extraction structure called the multi-scale residual block. This module can "adaptively detect image features at different scales" and "exploit the potential features of the image".
This model also applies the balanced attention mechanism (BAM) proposed by Wang et al. (2021) to further improve the results.
It was introduced in the paper Multi-scale Residual Network for Image Super-Resolution by Li et al. (2018) and first released in this repository.
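As a rough guide to what one multi-scale residual block (MSRB) computes, the equations below follow our reading of Li et al. (2018); they are not taken from this library's source. M_{n-1} is the block input, \sigma is ReLU, * is convolution, [\cdot,\cdot] is channel-wise concatenation, and w and b are each layer's weights and biases.

```latex
\begin{aligned}
S_1 &= \sigma\left(w^{1}_{3\times 3} * M_{n-1} + b^{1}\right) \\
P_1 &= \sigma\left(w^{1}_{5\times 5} * M_{n-1} + b^{1}\right) \\
S_2 &= \sigma\left(w^{2}_{3\times 3} * [S_1, P_1] + b^{2}\right) \\
P_2 &= \sigma\left(w^{2}_{5\times 5} * [P_1, S_1] + b^{2}\right) \\
S'  &= w^{3}_{1\times 1} * [S_2, P_2] + b^{3} \\
M_n &= S' + M_{n-1}
\end{aligned}
```

The 3x3 and 5x5 branches exchange features after their first layer, and the 1x1 convolution fuses the concatenated maps back to the width of M_{n-1}, which is what makes the final residual addition well-defined.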
MsrnConfig
This is the configuration class to store the configuration of a super_image.MsrnModel. It is used to instantiate the model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults (bam=False) yields a configuration similar to the plain MSRN architecture; set bam=True to obtain the MSRN BAM architecture.
Configuration objects inherit from super_image.PretrainedConfig and can be used to control the model outputs. Read the documentation of super_image.PretrainedConfig for more information.
Examples:
from super_image import MsrnModel, MsrnConfig

# Initializing a configuration
config = MsrnConfig(
    scale=4,   # train a model to upscale 4x
    bam=True,  # use balanced attention (BAM)
)

# Initializing a model from the configuration
model = MsrnModel(config)

# Accessing the model configuration
configuration = model.config
__init__(self, scale=None, n_blocks=8, n_feats=64, rgb_range=255, bam=False, rgb_mean=(0.4488, 0.4371, 0.404), rgb_std=(1.0, 1.0, 1.0), data_parallel=False, **kwargs)
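Parameters:
Name | Type | Description | Default |
---|---|---|---|
scale | int | Upscaling factor of the super-resolution model (e.g. 4 for 4x upscaling). | None |
n_blocks | int | Number of multi-scale residual blocks. | 8 |
n_feats | int | Number of feature channels (convolution filters). | 64 |
rgb_range | int | RGB value range, used as a multiplier in the MeanShift layer. | 255 |
bam | bool | Whether to use balanced attention modules (BAM). | False |
rgb_mean | tuple | Per-channel RGB mean (DIV2K_RGB_MEAN). | (0.4488, 0.4371, 0.404) |
rgb_std | tuple | Per-channel RGB standard deviation (DIV2K_RGB_STD). | (1.0, 1.0, 1.0) |
data_parallel | bool | Whether to use multiple GPUs for training. | False |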
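To show how rgb_range, rgb_mean and rgb_std interact, here is a minimal per-pixel sketch of MeanShift arithmetic. It assumes super_image's MeanShift follows the EDSR-PyTorch convention y_c = (x_c + sign * rgb_range * mean_c) / std_c; that convention and the helper mean_shift are assumptions for illustration, not the library's code. Note also that in the MsrnModel.forward source the self.sub_mean and self.add_mean calls are commented out, so this normalization may not be applied by this model at all.

```python
# Minimal per-pixel sketch of EDSR-style MeanShift arithmetic.
# ASSUMPTION: super_image's MeanShift uses y_c = (x_c + sign * rgb_range * mean_c) / std_c.
# `mean_shift` is a hypothetical helper, not part of super_image.
from typing import Sequence, Tuple

# Default values copied from the MsrnConfig signature.
DIV2K_RGB_MEAN = (0.4488, 0.4371, 0.404)
DIV2K_RGB_STD = (1.0, 1.0, 1.0)


def mean_shift(pixel: Sequence[float],
               rgb_range: float = 255,
               rgb_mean: Sequence[float] = DIV2K_RGB_MEAN,
               rgb_std: Sequence[float] = DIV2K_RGB_STD,
               sign: int = -1) -> Tuple[float, ...]:
    """Shift one RGB pixel: sign=-1 subtracts the scaled mean, sign=+1 adds it back."""
    return tuple((x + sign * rgb_range * m) / s
                 for x, m, s in zip(pixel, rgb_mean, rgb_std))


pixel = (200.0, 120.0, 30.0)
centred = mean_shift(pixel, sign=-1)      # analogous to sub_mean
restored = mean_shift(centred, sign=+1)   # analogous to add_mean
print(centred, restored)
```

Because rgb_mean is given in [0, 1] units, rgb_range rescales it to the pixel range (0.4488 * 255 for the red channel here). With the default rgb_std of 1.0 the two shifts invert each other exactly; with other std values they would not, under this convention.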
Source code in super_image\models\msrn\configuration_msrn.py
def __init__(self, scale=None, n_blocks=8, n_feats=64, rgb_range=255, bam=False,
             rgb_mean=DIV2K_RGB_MEAN, rgb_std=DIV2K_RGB_STD,
             data_parallel=False, **kwargs):
    """
    Args:
        scale (int): Scale for the model to train an upscaler/super-res model.
        n_blocks (int): Number of blocks.
        n_feats (int): Number of filters.
        rgb_range (int):
            Range of RGB as a multiplier to the MeanShift.
        data_parallel (bool):
            Option to use multiple GPUs for training.
        bam (bool): Option to use balanced attention modules instead (BAM)
    """
    super().__init__(**kwargs)
    self.scale = scale
    self.n_blocks = n_blocks
    self.n_feats = n_feats
    self.rgb_range = rgb_range
    self.rgb_mean = rgb_mean
    self.rgb_std = rgb_std
    self.data_parallel = data_parallel
    self.bam = bam
MsrnModel
config_class
The configuration class of MsrnModel is MsrnConfig; its documentation is identical to the MsrnConfig section above.
forward(self, x)
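Defines the computation performed at every call. In MsrnModel, the input passes through self.head and then through the n_blocks multi-scale residual blocks in self.body; the outputs of all blocks, together with the head output, are concatenated along the channel dimension and passed to self.tail, which contains the upsampler.
Note: although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.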
Source code in super_image\models\msrn\modeling_msrn.py
def forward(self, x):
    # x = self.sub_mean(x)
    x = self.head(x)
    res = x
    MSRB_out = []
    for i in range(self.n_blocks):
        x = self.body[i](x)
        MSRB_out.append(x)
    MSRB_out.append(res)
    res = torch.cat(MSRB_out, 1)
    x = self.tail(res)
    # x = self.add_mean(x)
    return x
load_state_dict(self, state_dict, strict=True)
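Copies parameters and buffers from state_dict into this module and its descendants. If strict is True, the keys of state_dict must match the keys returned by this module's torch.nn.Module.state_dict function, with one exception: unexpected keys whose names contain 'tail' are ignored. Missing keys still raise a KeyError.
Parameters whose names contain 'tail' and whose values cannot be copied (for example because of a shape mismatch) are skipped and a message is printed, so the model keeps its own upsampler weights, for example when loading a checkpoint trained at a different scale. A failed copy of any other parameter raises a RuntimeError.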
Parameters:
Name | Type | Description | Default |
---|---|---|---|
state_dict | dict | A dict containing parameters and persistent buffers. | required |
strict | bool | Whether to strictly enforce that the keys in state_dict match the keys returned by this module's state_dict() function. | True |
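Returns:
Type | Description |
---|---|
None | This override returns nothing. Unlike torch.nn.Module.load_state_dict, it does not return a NamedTuple with missing_keys and unexpected_keys; with strict=True, missing keys and unexpected non-tail keys raise a KeyError instead. |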
Source code in super_image\models\msrn\modeling_msrn.py
def load_state_dict(self, state_dict, strict=True):
    own_state = self.state_dict()
    for name, param in state_dict.items():
        if name in own_state:
            if isinstance(param, nn.Parameter):
                param = param.data
            try:
                own_state[name].copy_(param)
            except Exception:
                if name.find('tail') >= 0:
                    print('Replace pre-trained upsampler to new one...')
                else:
                    raise RuntimeError(f'While copying the parameter named {name}, '
                                       f'whose dimensions in the model are {own_state[name].size()} and '
                                       f'whose dimensions in the checkpoint are {param.size()}.')
        elif strict:
            if name.find('tail') == -1:
                raise KeyError(f'unexpected key "{name}" in state_dict')
    if strict:
        missing = set(own_state.keys()) - set(state_dict.keys())
        if len(missing) > 0:
            raise KeyError(f'missing keys in state_dict: "{missing}"')
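The key-matching rules of this override can be dry-run without PyTorch. In this sketch, parameters are replaced by dicts mapping names to shape tuples, and plan_load is a hypothetical helper, not part of super_image. It simplifies one detail: it treats only equal shapes as copyable, whereas Tensor.copy_ also accepts broadcastable shapes. The parameter names and shapes in the usage lines are made up.

```python
# Dry run of the rules in MsrnModel.load_state_dict using name -> shape dicts.
# `plan_load` is hypothetical. SIMPLIFICATION: a copy "succeeds" only for equal
# shapes, while the real Tensor.copy_ also accepts broadcastable shapes.
from typing import Dict, List, Tuple

Shapes = Dict[str, Tuple[int, ...]]


def plan_load(own: Shapes, ckpt: Shapes, strict: bool = True) -> Tuple[List[str], List[str]]:
    """Return (copied, skipped_tail) names, raising where the override would raise."""
    copied: List[str] = []
    skipped_tail: List[str] = []
    for name, shape in ckpt.items():
        if name in own:
            if own[name] == shape:
                copied.append(name)        # own_state[name].copy_(param) succeeds
            elif 'tail' in name:
                skipped_tail.append(name)  # prints "Replace pre-trained upsampler..."
            else:
                raise RuntimeError(f'shape mismatch for {name}: {own[name]} vs {shape}')
        elif strict and 'tail' not in name:
            raise KeyError(f'unexpected key "{name}" in state_dict')
    if strict:
        missing = set(own) - set(ckpt)     # tail keys are not exempt here
        if missing:
            raise KeyError(f'missing keys in state_dict: "{missing}"')
    return copied, skipped_tail


own = {'head.weight': (8, 3, 3, 3), 'tail.weight': (12, 8, 3, 3)}
ckpt = {'head.weight': (8, 3, 3, 3), 'tail.weight': (27, 8, 3, 3)}
print(plan_load(own, ckpt))  # (['head.weight'], ['tail.weight'])
```

The asymmetry is intentional in the original code: a mismatched upsampler is tolerated, but a mismatch anywhere else in the network is treated as an error.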