Pixel Attention Network (PAN)#

Overview#

The PAN model proposes a lightweight convolutional neural network for image super-resolution. Pixel attention (PA) is similar in formulation to channel attention and spatial attention, but it produces 3D attention maps instead of a 1D attention vector or a 2D map. This attention scheme introduces fewer additional parameters yet generates better SR results.
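
A minimal sketch of the PA layer as formulated in the paper: a 1x1 convolution followed by a sigmoid yields a per-pixel, per-channel (3D) attention map that rescales the input features elementwise. The class and argument names here are illustrative, not taken from the library.

import torch
import torch.nn as nn

class PixelAttention(nn.Module):
    """1x1 conv + sigmoid -> an (N, C, H, W) attention map, one weight per pixel per channel."""

    def __init__(self, nf):
        super().__init__()
        self.conv = nn.Conv2d(nf, nf, kernel_size=1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        attn = self.sigmoid(self.conv(x))  # same shape as x, so the attention map is 3D per sample
        return x * attn                    # elementwise rescaling of the features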

This implementation also applies the balanced attention mechanism (BAM) proposed by Wang et al. (2021) to further improve the results.

It was introduced in the paper Efficient Image Super-Resolution Using Pixel Attention by Zhao et al. (2020) and first released in this repository.
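
A minimal usage sketch with the library's high-level API; the `eugenesiow/pan-bam` checkpoint name and the file paths are assumptions for illustration.

from PIL import Image
from super_image import PanModel, ImageLoader

image = Image.open('input.png')  # any low-resolution input image
model = PanModel.from_pretrained('eugenesiow/pan-bam', scale=2)  # assumed Hub checkpoint name
inputs = ImageLoader.load_image(image)   # PIL image -> model-ready tensor
preds = model(inputs)                    # super-resolved output
ImageLoader.save_image(preds, 'scaled_2x.png')
ImageLoader.save_compare(inputs, preds, 'scaled_2x_compare.png')  # side-by-side comparison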

PanConfig#

PanModel#

forward(self, x) #

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note: although the recipe for the forward pass needs to be defined within this function, one should call the `Module` instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
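
In practice that means invoking the model directly; `model` and `lr` below are placeholder names.

sr = model(lr)            # preferred: runs any registered hooks
# sr = model.forward(lr)  # works, but silently skips the hooks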

Source code in super_image/models/pan/modeling_pan.py
def forward(self, x):
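    # Shallow feature extraction, then the SCPA trunk with a long skip connection.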
    fea = self.conv_first(x)
    trunk = self.trunk_conv(self.SCPA_trunk(fea))
    fea = fea + trunk

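    # x2/x3: a single nearest-neighbour upsample stage; x4: two successive x2 stages.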
    if self.scale == 2 or self.scale == 3:
        fea = self.upconv1(functional.interpolate(fea, scale_factor=self.scale, mode='nearest'))
        fea = self.lrelu(self.att1(fea))
        fea = self.lrelu(self.HRconv1(fea))
    elif self.scale == 4:
        fea = self.upconv1(functional.interpolate(fea, scale_factor=2, mode='nearest'))
        fea = self.lrelu(self.att1(fea))
        fea = self.lrelu(self.HRconv1(fea))
        fea = self.upconv2(functional.interpolate(fea, scale_factor=2, mode='nearest'))
        fea = self.lrelu(self.att2(fea))
        fea = self.lrelu(self.HRconv2(fea))

    out = self.conv_last(fea)

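    # Global residual: add a bilinearly interpolated copy of the input to the reconstruction.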
    ilr = functional.interpolate(x, scale_factor=self.scale, mode='bilinear', align_corners=False)
    out = out + ilr
    return out
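
As an illustration of the branches above, a hypothetical shape check, assuming `model` is a PAN instance configured with scale=4:

import torch

lr = torch.randn(1, 3, 64, 64)       # one 64x64 RGB low-resolution image
sr = model(lr)                       # x4 path: two nearest-neighbour x2 stages
assert sr.shape == (1, 3, 256, 256)  # height and width scaled by 4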

load_state_dict(self, state_dict, strict=True) #

Copies parameters and buffers from `state_dict` into this module and its descendants. If `strict` is `True`, then the keys of `state_dict` must exactly match the keys returned by this module's `state_dict()` function.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `state_dict` | `dict` | a dict containing parameters and persistent buffers | *required* |
| `strict` | `bool` | whether to strictly enforce that the keys in `state_dict` match the keys returned by this module's `state_dict()` function | `True` |

Returns:

| Type | Description |
| --- | --- |
| `NamedTuple` with `missing_keys` and `unexpected_keys` fields | `missing_keys` is a list of str containing the missing keys; `unexpected_keys` is a list of str containing the unexpected keys |
Source code in super_image/models/pan/modeling_pan.py
def load_state_dict(self, state_dict, strict=True):
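    # Copy tensors entry by entry; size mismatches and unknown keys are tolerated only for 'tail' parameters.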
    own_state = self.state_dict()
    for name, param in state_dict.items():
        if name in own_state:
            if isinstance(param, nn.Parameter):
                param = param.data
            try:
                own_state[name].copy_(param)
            except Exception:
                if name.find('tail') >= 0:
                    print('Replace pre-trained upsampler to new one...')
                else:
                    raise RuntimeError(f'While copying the parameter named {name}, '
                                       f'whose dimensions in the model are {own_state[name].size()} and '
                                       f'whose dimensions in the checkpoint are {param.size()}.')
        elif strict:
            if name.find('tail') == -1:
                raise KeyError(f'unexpected key "{name}" in state_dict')

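    # Under strict loading, keys present in the model but absent from the checkpoint are an error.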
    if strict:
        missing = set(own_state.keys()) - set(state_dict.keys())
        if len(missing) > 0:
            raise KeyError(f'missing keys in state_dict: "{missing}"')
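
A hedged sketch of loading a checkpoint through this method; the PanConfig arguments and the filename are illustrative assumptions, not documented defaults.

import torch
from super_image import PanModel, PanConfig

model = PanModel(PanConfig(scale=4))                       # assumed constructor signature
state_dict = torch.load('pan_x4.pth', map_location='cpu')  # illustrative checkpoint file
model.load_state_dict(state_dict, strict=True)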