Models
The base class `PreTrainedModel` implements the common methods for loading and saving a model, either from a local file or directory, or from a pretrained model configuration provided by the library (downloaded from the Hugging Face Hub).
PreTrainedModel
`save_pretrained(self, save_directory, save_config=True, state_dict=None, save_function=torch.save)`
Save a model and its configuration file to a directory, so that it can be re-loaded using the `super_image.PreTrainedModel.from_pretrained` class method.
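Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `save_directory` | `Union[str, os.PathLike]` | Directory to which to save. Will be created if it doesn't exist. | *required* |
| `save_config` | `bool` | Whether or not to save the config of the model. Useful in distributed training (e.g. on TPUs) when this function must be called on all processes; in that case set `save_config=True` only on the main process to avoid race conditions. | `True` |
| `state_dict` | `Optional[dict]` | The state dictionary of the model to save. Defaults to `self.state_dict()`, but can be used to save only parts of the model, or when special precautions are needed when recovering the state dictionary (e.g. with model parallelism). | `None` |
| `save_function` | `Callable` | The function used to save the state dictionary, for when `torch.save` needs to be replaced by another method. | `torch.save` |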
Source code in `super_image/modeling_utils.py`
def save_pretrained(
    self,
    save_directory: Union[str, os.PathLike],
    save_config: bool = True,
    state_dict: Optional[dict] = None,
    save_function: Callable = torch.save,
):
    """
    Save a model and its configuration file to a directory, so that it can be re-loaded using the
    :func:`~super_image.PreTrainedModel.from_pretrained` class method.

    The weights file name is ``WEIGHTS_NAME_SCALE`` formatted with ``scale=self.config.scale`` when
    ``self.config.scale`` is not ``None``, and ``WEIGHTS_NAME`` otherwise.

    Arguments:
        save_directory (:obj:`str` or :obj:`os.PathLike`):
            Directory to which to save. Will be created if it doesn't exist. If the path points to an
            existing file, an error is logged and nothing is saved.
        save_config (:obj:`bool`, `optional`, defaults to :obj:`True`):
            Whether or not to save the config of the model. Useful in distributed training (e.g. on
            TPUs) when this function must be called on all processes: set :obj:`save_config=True`
            only on the main process to avoid race conditions.
        state_dict (nested dictionary of :obj:`torch.Tensor`, `optional`):
            The state dictionary of the model to save. Defaults to :obj:`self.state_dict()`, but can be
            used to save only parts of the model, or when special precautions need to be taken when
            recovering the state dictionary of a model (like when using model parallelism).
        save_function (:obj:`Callable`, `optional`, defaults to :obj:`torch.save`):
            The function used to save the state dictionary, called as
            ``save_function(state_dict, output_model_file)``. Use it when :obj:`torch.save` needs to be
            replaced by another method.

    Returns:
        :obj:`None`.
    """
    if os.path.isfile(save_directory):
        # Deliberately best-effort: report the problem and bail out instead of raising.
        logger.error("Provided path (%s) should be a directory, not a file", save_directory)
        return
    os.makedirs(save_directory, exist_ok=True)

    # Choose the weights file name: scale-specific when the config defines a scale.
    scale = self.config.scale
    weights_name = WEIGHTS_NAME_SCALE.format(scale=scale) if scale is not None else WEIGHTS_NAME

    if save_config:
        self.config.save_pretrained(save_directory)

    if state_dict is None:
        state_dict = self.state_dict()

    # Saving under the predefined names is what allows `from_pretrained` to find the weights.
    output_model_file = os.path.join(save_directory, weights_name)
    save_function(state_dict, output_model_file)
    logger.info("Model weights saved in %s", output_model_file)