# Attention in Attention Network for Image Super-Resolution (A2N)

## Overview
The A2N model proposes an attention-in-attention network (A2N) for highly accurate image super-resolution. Specifically, A2N consists of a non-attention branch and a coupling attention branch. An attention dropout module generates dynamic attention weights for these two branches based on the input features, which suppresses unwanted attention adjustments. This allows the attention modules to specialize on beneficial examples without penalizing the rest, and thus greatly improves the capacity of the attention network with little parameter overhead.
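The toy sketch below (not the library's implementation; module names and layer sizes are illustrative) shows the core idea of mixing an attention branch and a non-attention branch with input-dependent weights:

```python
import torch
import torch.nn as nn


class ToyAttentionInAttentionBlock(nn.Module):
    """Illustrative only: two branches mixed by dynamic, input-dependent weights."""

    def __init__(self, channels):
        super().__init__()
        # non-attention branch: a plain convolution
        self.non_attention = nn.Conv2d(channels, channels, 3, padding=1)
        # attention branch: predicts a modulation map applied to the features
        self.attention = nn.Sequential(
            nn.Conv2d(channels, channels, 3, padding=1),
            nn.Sigmoid(),
        )
        # "attention dropout" module: predicts one weight per branch from the input
        self.branch_weights = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(channels, 2),
        )

    def forward(self, x):
        w = torch.softmax(self.branch_weights(x), dim=1)   # shape (B, 2)
        w_att = w[:, 0].view(-1, 1, 1, 1)
        w_non = w[:, 1].view(-1, 1, 1, 1)
        features = self.non_attention(x)
        attended = self.attention(x) * features
        # dynamic mixture of the two branches, plus a local residual
        return x + w_att * attended + w_non * features
```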
More importantly, the model is lightweight and fast to train (~1.5M parameters, ~4 MB).
It was introduced in the paper *Attention in Attention Network for Image Super-Resolution* by Chen et al. (2021) and first released in this repository.
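As a quick usage sketch (assuming the `eugenesiow/a2n` pretrained checkpoint on the Hugging Face Hub and the library's `ImageLoader` helper; the input filename is a placeholder):

```python
from PIL import Image
from super_image import A2nModel, ImageLoader

image = Image.open('input.png')                       # any low-resolution RGB image

model = A2nModel.from_pretrained('eugenesiow/a2n', scale=2)
inputs = ImageLoader.load_image(image)                # PIL image -> normalized tensor
preds = model(inputs)                                 # super-resolved output tensor

ImageLoader.save_image(preds, './scaled_2x.png')      # save the upscaled image
```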
## A2nConfig
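`A2nConfig` holds the model hyperparameters. A minimal sketch of building an untrained model from a config (assuming `scale` is the constructor argument that sets the upscaling factor):

```python
from super_image import A2nConfig, A2nModel

config = A2nConfig(scale=4)   # 4x super-resolution
model = A2nModel(config)
```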
## A2nModel
### forward(self, x)
Defines the computation performed at every call.

Should be overridden by all subclasses.

Note: although the recipe for the forward pass needs to be defined within this function, one should call the `Module` instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
Source code in `super_image/models/a2n/modeling_a2n.py`
```python
def forward(self, x):
    # shallow feature extraction
    fea = self.conv_first(x)
    # attention-in-attention trunk with a residual connection
    trunk = self.trunk_conv(self.AAB_trunk(fea))
    fea = fea + trunk
    if self.scale == 2 or self.scale == 3:
        # single nearest-neighbour upsampling step for x2 / x3
        fea = self.upconv1(functional.interpolate(fea, scale_factor=self.scale, mode='nearest'))
        fea = self.lrelu(self.att1(fea))
        fea = self.lrelu(self.HRconv1(fea))
    elif self.scale == 4:
        # two successive x2 upsampling steps for x4
        fea = self.upconv1(functional.interpolate(fea, scale_factor=2, mode='nearest'))
        fea = self.lrelu(self.att1(fea))
        fea = self.lrelu(self.HRconv1(fea))
        fea = self.upconv2(functional.interpolate(fea, scale_factor=2, mode='nearest'))
        fea = self.lrelu(self.att2(fea))
        fea = self.lrelu(self.HRconv2(fea))
    out = self.conv_last(fea)
    # global residual: add the bilinearly upscaled low-resolution input
    ilr = functional.interpolate(x, scale_factor=self.scale, mode='bilinear', align_corners=False)
    out = out + ilr
    return out
```
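For example, calling the module on a dummy low-resolution batch (assuming `model` is an `A2nModel` instance configured for 4x upscaling):

```python
import torch

lr = torch.rand(1, 3, 64, 64)   # one 3-channel 64x64 low-resolution image
sr = model(lr)                  # call the module itself so registered hooks run
print(sr.shape)                 # expected: torch.Size([1, 3, 256, 256])
```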
### load_state_dict(self, state_dict, strict=True)
Copies parameters and buffers from `state_dict` into this module and its descendants. If `strict` is `True`, then the keys of `state_dict` must exactly match the keys returned by this module's `state_dict()` function.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `state_dict` | `dict` | a dict containing parameters and persistent buffers. | *required* |
| `strict` | `bool` | whether to strictly enforce that the keys in `state_dict` match the keys returned by this module's `state_dict()` function. | `True` |
Returns:

| Type | Description |
|---|---|
| `NamedTuple` | `NamedTuple` with `missing_keys` and `unexpected_keys` fields |
Source code in `super_image/models/a2n/modeling_a2n.py`
```python
def load_state_dict(self, state_dict, strict=True):
    own_state = self.state_dict()
    for name, param in state_dict.items():
        if name in own_state:
            if isinstance(param, nn.Parameter):
                # unwrap Parameters to copy the underlying tensor data
                param = param.data
            try:
                own_state[name].copy_(param)
            except Exception:
                # shape mismatches are only tolerated for the 'tail' layers
                if name.find('tail') == -1:
                    raise RuntimeError(f'While copying the parameter named {name}, '
                                       f'whose dimensions in the model are {own_state[name].size()} and '
                                       f'whose dimensions in the checkpoint are {param.size()}.')
        elif strict:
            # in strict mode, unknown keys (outside 'tail') are an error
            if name.find('tail') == -1:
                raise KeyError(f'unexpected key "{name}" in state_dict')
```