Models

The base class PreTrainedModel implements the common methods for loading and saving a model, either from a local file or directory or from a pretrained model configuration provided by the library (downloaded from the Hugging Face Hub).

PreTrainedModel

save_pretrained(self, save_directory, save_config=True, state_dict=None, save_function=torch.save)

Save a model and its configuration file to a directory, so that it can be re-loaded using the super_image.PreTrainedModel.from_pretrained class method.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| save_directory | Union[str, os.PathLike] | Directory to which to save. Will be created if it doesn't exist. | required |
| save_config | bool | Whether or not to save the config of the model. Useful in distributed training (for example on TPUs), where this function must be called on all processes. In that case, set save_config=True only on the main process to avoid race conditions. | True |
| state_dict | Optional[dict] | The state dictionary of the model to save (a nested dictionary of torch.Tensor). Defaults to self.state_dict(), but can be used to save only parts of the model, or when special precautions are needed to recover the state dictionary (for example, with model parallelism). | None |
| save_function | Callable | The function used to save the state dictionary. Useful when torch.save needs to be replaced by another method. | torch.save |
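
The three optional parameters above can be combined. The following is a minimal sketch using only the standard library, not code from the library itself. The helper names (`is_main_process`, `filter_state_dict`, `pickle_save`), the key names in `full_state`, and the `weights.pkl` file name are hypothetical. It also assumes the launcher exposes the process rank through a `RANK` environment variable, as `torchrun` does. Plain lists stand in for `torch.Tensor` values so the example runs without PyTorch.

```python
import os
import pickle
import tempfile


def is_main_process(rank):
    """True only for rank 0, the one process that should write the config."""
    return rank == 0


def filter_state_dict(state_dict, prefix):
    """Keep only the entries whose key starts with `prefix`."""
    return {k: v for k, v in state_dict.items() if k.startswith(prefix)}


def pickle_save(obj, path):
    """Hypothetical replacement for torch.save with the same (obj, path) call shape."""
    with open(path, "wb") as f:
        pickle.dump(obj, f)


# Assumption: the launcher sets RANK; default to 0 for single-process runs.
rank = int(os.environ.get("RANK", "0"))
save_config = is_main_process(rank)

# Stand-in for model.state_dict(); real values would be torch.Tensor objects.
full_state = {
    "head.weight": [1.0, 2.0],
    "body.0.weight": [3.0],
    "tail.weight": [4.0],
}
body_only = filter_state_dict(full_state, "body.")

# With a real model the call would look like this (not executed here):
# model.save_pretrained(save_dir, save_config=save_config,
#                       state_dict=body_only, save_function=pickle_save)
with tempfile.TemporaryDirectory() as save_dir:
    out_path = os.path.join(save_dir, "weights.pkl")
    pickle_save(body_only, out_path)
    with open(out_path, "rb") as f:
        restored = pickle.load(f)
```

Any callable that accepts the state dictionary followed by the output path should fit `save_function`, since that is how the method invokes it.
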
Source code in super_image\modeling_utils.py
def save_pretrained(
        self,
        save_directory: Union[str, os.PathLike],
        save_config: bool = True,
        state_dict: Optional[dict] = None,
        save_function: Callable = torch.save,
):
    """
    Save a model and its configuration file to a directory, so that it can be re-loaded using the
    :func:`~super_image.PreTrainedModel.from_pretrained` class method.
    Arguments:
        save_directory (:obj:`str` or :obj:`os.PathLike`):
            Directory to which to save. Will be created if it doesn't exist.
        save_config (:obj:`bool`, `optional`, defaults to :obj:`True`):
            Whether or not to save the config of the model. Useful when in distributed training like TPUs and need
            to call this function on all processes. In this case, set :obj:`save_config=True` only on the main
            process to avoid race conditions.
        state_dict (nested dictionary of :obj:`torch.Tensor`):
            The state dictionary of the model to save. Will default to :obj:`self.state_dict()`, but can be used to
            only save parts of the model or if special precautions need to be taken when recovering the state
            dictionary of a model (like when using model parallelism).
        save_function (:obj:`Callable`):
            The function to use to save the state dictionary. When we need to replace :obj:`torch.save` by another
            method.
    """
    if os.path.isfile(save_directory):
        logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
        return

    os.makedirs(save_directory, exist_ok=True)

    model_to_save = self

    # Setup scale
    scale = self.config.scale
    if scale is not None:
        weights_name = WEIGHTS_NAME_SCALE.format(scale=scale)
    else:
        weights_name = WEIGHTS_NAME

    # Save the config
    if save_config:
        model_to_save.config.save_pretrained(save_directory)

    # Save the model
    if state_dict is None:
        state_dict = model_to_save.state_dict()

    # If we save using the predefined names, we can load using `from_pretrained`
    output_model_file = os.path.join(save_directory, weights_name)
    save_function(state_dict, output_model_file)

    logger.info(f"Model weights saved in {output_model_file}")
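
The path handling in the listing above can be illustrated in isolation. This is a sketch under stated assumptions, not the library's implementation. The values given to `WEIGHTS_NAME` and `WEIGHTS_NAME_SCALE` are assumed for illustration (the real constants are defined elsewhere in the library), and `resolve_output_file` and the `edsr` directory name are hypothetical. One deliberate difference: the method above logs an error and returns when the path is an existing file, whereas this helper raises so the caller cannot miss the failure.

```python
import os
import tempfile

# Assumed values for illustration only; the library defines the real constants.
WEIGHTS_NAME = "pytorch_model.pt"
WEIGHTS_NAME_SCALE = "pytorch_model_{scale}x.pt"


def resolve_output_file(save_directory, scale=None):
    """Create the directory if needed and return the weights file path.

    Hypothetical helper: raises instead of logging when the path is a file.
    """
    if os.path.isfile(save_directory):
        raise NotADirectoryError(f"{save_directory} is a file, not a directory")
    os.makedirs(save_directory, exist_ok=True)

    # A configured scale selects the scale-specific file name.
    if scale is not None:
        weights_name = WEIGHTS_NAME_SCALE.format(scale=scale)
    else:
        weights_name = WEIGHTS_NAME
    return os.path.join(save_directory, weights_name)


with tempfile.TemporaryDirectory() as tmp:
    target = os.path.join(tmp, "edsr")
    scaled = os.path.basename(resolve_output_file(target, scale=4))
    unscaled = os.path.basename(resolve_output_file(target))
```

Because the file name depends on the scale, weights saved for different scales can share one directory without overwriting each other.
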