Trainer#

The Trainer class provides an API for feature-complete training in most standard use cases.

Before instantiating your Trainer, create a TrainingArguments instance to access all the points of customization during training.
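
For orientation, here is a minimal end-to-end sketch in the style of the library's example usage; the DIV2K dataset helpers and the EDSR model below are illustrative choices, not requirements of the API:

from datasets import load_dataset
from super_image import Trainer, TrainingArguments, EdsrConfig, EdsrModel
from super_image.data import TrainDataset, EvalDataset, augment_five_crop

# Illustrative data: DIV2K 4x pairs, augmented with five-crop patches.
augmented = load_dataset('eugenesiow/Div2k', 'bicubic_x4', split='train') \
    .map(augment_five_crop, batched=True, desc='Augmenting Dataset')
train_dataset = TrainDataset(augmented)
eval_dataset = EvalDataset(load_dataset('eugenesiow/Div2k', 'bicubic_x4', split='validation'))

training_args = TrainingArguments(
    output_dir='./results',      # directory for weights and checkpoints
    num_train_epochs=1000,       # total number of training epochs
)

model = EdsrModel(EdsrConfig(scale=4))    # train an EDSR 4x model from scratch

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
)
trainer.train()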

Trainer#

Trainer is a simple class that implements the training and evaluation loop in PyTorch for training super-image models.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `model` | `PreTrainedModel` or `torch.nn.Module`, *optional* | The model to train, evaluate or use for predictions. If not provided, a `model_init` must be passed. Note: `Trainer` is optimized to work with the `PreTrainedModel` classes provided by the library. You can still use your own model defined as a `torch.nn.Module`, as long as it works the same way as the super_image models (a minimal sketch of such a module follows this table). | *required* |
| `args` | `TrainingArguments`, *optional* | The arguments to tweak for training. Will default to a basic instance of `TrainingArguments` with `output_dir` set to a directory named `tmp_trainer` in the current directory if not provided. | *required* |
| `train_dataset` | `torch.utils.data.dataset.Dataset` or `torch.utils.data.dataset.IterableDataset` | The dataset to use for training. | *required* |
| `eval_dataset` | `torch.utils.data.dataset.Dataset`, *optional* | The dataset to use for evaluation. | *required* |
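
Because `Trainer` also accepts a plain `torch.nn.Module`, here is a minimal, hypothetical sketch of what "work the same way as the super_image models" means in practice: judging from the `train()` and `save_model()` sources further down, the module needs a forward pass that maps a low-resolution batch to a super-resolved batch, plus a `config` attribute exposing `scale` and `model_type`.

import torch.nn as nn
from types import SimpleNamespace

class TinyUpscaler(nn.Module):
    """Hypothetical custom model compatible with Trainer."""

    def __init__(self, scale=2):
        super().__init__()
        # train() reads model.config.model_type and save_model() reads
        # model.config.scale, so expose both on a config attribute.
        self.config = SimpleNamespace(scale=scale, model_type='tiny')
        self.body = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 3 * scale ** 2, kernel_size=3, padding=1),
            nn.PixelShuffle(scale),   # upscale by the configured factor
        )

    def forward(self, x):
        # x: low-resolution batch (N, 3, H, W) -> (N, 3, H * scale, W * scale)
        return self.body(x)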

get_eval_dataloader(self) #

Returns the evaluation `torch.utils.data.DataLoader`.

Source code in super_image\trainer.py
def get_eval_dataloader(self) -> DataLoader:
    """
    Returns the evaluation :class:`~torch.utils.data.DataLoader`.
    """

    eval_dataset = self.eval_dataset
    if eval_dataset is None:
        eval_dataset = self.train_dataset

    return DataLoader(
        dataset=eval_dataset,
        batch_size=1,
    )
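
As the listing shows, the evaluation loader always uses a batch size of 1 and falls back to the training dataset when no `eval_dataset` was passed. A quick check, assuming a `trainer` built as in the sketch at the top of this page:

eval_loader = trainer.get_eval_dataloader()
print(eval_loader.batch_size)   # always 1 for evaluation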

get_train_dataloader(self) #

Returns the training `torch.utils.data.DataLoader`.

Source code in super_image\trainer.py
def get_train_dataloader(self) -> DataLoader:
    """
    Returns the training :class:`~torch.utils.data.DataLoader`.
    """

    if self.train_dataset is None:
        raise ValueError("Trainer: training requires a train_dataset.")

    train_dataset = self.train_dataset

    return DataLoader(
        dataset=train_dataset,
        batch_size=self.args.train_batch_size,
        shuffle=True,
        num_workers=self.args.dataloader_num_workers,
        pin_memory=self.args.dataloader_pin_memory,
    )
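
The training loader, by contrast, takes its configuration from `TrainingArguments`; again assuming the `trainer` from the earlier sketch:

train_loader = trainer.get_train_dataloader()
print(train_loader.batch_size)    # == args.train_batch_size
print(train_loader.num_workers)   # == args.dataloader_num_workers
print(train_loader.pin_memory)    # == args.dataloader_pin_memory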

save_model(self, output_dir=None) #

Will save the model, so you can reload it using `from_pretrained()`. Will only save from the main process.

Source code in super_image\trainer.py
def save_model(self, output_dir: Optional[str] = None):
    """
    Will save the model, so you can reload it using :obj:`from_pretrained()`.
    Will only save from the main process.
    """

    output_dir = output_dir if output_dir is not None else self.args.output_dir
    os.makedirs(output_dir, exist_ok=True)

    if not isinstance(self.model, PreTrainedModel):
        # Setup scale
        scale = self.model.config.scale
        if scale is not None:
            weights_name = WEIGHTS_NAME_SCALE.format(scale=scale)
        else:
            weights_name = WEIGHTS_NAME

        weights = copy.deepcopy(self.model.state_dict())
        torch.save(weights, os.path.join(output_dir, weights_name))
    else:
        self.model.save_pretrained(output_dir)
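
A round-trip sketch continuing the EDSR example above, with an illustrative output path: save the trained weights, then reload them with `from_pretrained()` as the docstring suggests (whether the `scale` argument is needed when loading from a local directory is an assumption here):

trainer.save_model('./results')                          # writes the weights (and config for a PreTrainedModel)
model = EdsrModel.from_pretrained('./results', scale=4)  # reload later for inference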

train(self, resume_from_checkpoint=None, **kwargs) #

Main training entry point.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `resume_from_checkpoint` | `str` or `bool`, *optional* | If a `str`, local path to a saved checkpoint as saved by a previous instance of `Trainer`. If a `bool` and equals `True`, load the last checkpoint in `args.output_dir` as saved by a previous instance of `Trainer`. If present, training will resume from the model/optimizer/scheduler states loaded here. | `None` |
| `kwargs` | | Additional keyword arguments used to hide deprecated arguments. | `{}` |
Source code in super_image\trainer.py
def train(
        self,
        resume_from_checkpoint: Optional[Union[str, bool]] = None,
        **kwargs,
):
    """
    Main training entry point.
    Args:
        resume_from_checkpoint (:obj:`str` or :obj:`bool`, `optional`):
            If a :obj:`str`, local path to a saved checkpoint as saved by a previous instance of
            :class:`~super_image.Trainer`. If a :obj:`bool` and equals `True`, load the last checkpoint in
            `args.output_dir` as saved by a previous instance of :class:`~super_image.Trainer`. If present,
            training will resume from the model/optimizer/scheduler states loaded here.
        kwargs:
            Additional keyword arguments used to hide deprecated arguments
    """
    args = self.args

    epochs_trained = 0
    device = args.device
    num_train_epochs = args.num_train_epochs
    learning_rate = args.learning_rate
    train_batch_size = args.train_batch_size
    train_dataset = self.train_dataset
    train_dataloader = self.get_train_dataloader()
    step_size = int(len(train_dataset) / train_batch_size * 200)

    # # Load potential model checkpoint
    # if isinstance(resume_from_checkpoint, bool) and resume_from_checkpoint:
    #     resume_from_checkpoint = get_last_checkpoint(args.output_dir)
    #     if resume_from_checkpoint is None:
    #         raise ValueError(f"No valid checkpoint found in output directory ({args.output_dir})")
    #
    # if resume_from_checkpoint is not None:
    #     if not os.path.isfile(os.path.join(resume_from_checkpoint, WEIGHTS_NAME)):
    #         raise ValueError(f"Can't find a valid checkpoint at {resume_from_checkpoint}")
    #
    #     logger.info(f"Loading model from {resume_from_checkpoint}).")
    #
    #     if os.path.isfile(os.path.join(resume_from_checkpoint, CONFIG_NAME)):
    #         config = PretrainedConfig.from_json_file(os.path.join(resume_from_checkpoint, CONFIG_NAME))
    #
    #     state_dict = torch.load(os.path.join(resume_from_checkpoint, WEIGHTS_NAME), map_location="cpu")
    #     # If the model is on the GPU, it still works!
    #     self._load_state_dict_in_model(state_dict)
    #
    #     # release memory
    #     del state_dict

    optimizer = Adam(self.model.parameters(), lr=learning_rate)
    scheduler = lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=self.args.gamma)

    for epoch in range(epochs_trained, num_train_epochs):
        for param_group in optimizer.param_groups:
            param_group['lr'] = learning_rate * (0.1 ** (epoch // int(num_train_epochs * 0.8)))

        self.model.train()
        epoch_losses = AverageMeter()

        with tqdm(total=(len(train_dataset) - len(train_dataset) % train_batch_size)) as t:
            t.set_description(f'epoch: {epoch}/{num_train_epochs - 1}')

            for data in train_dataloader:
                inputs, labels = data

                inputs = inputs.to(device)
                labels = labels.to(device)

                if self.model.config.model_type == 'SMSR':
                    # update tau for gumbel softmax
                    tau = max(1 - (epoch - 1) / 500, 0.4)
                    for m in self.model.modules():
                        if hasattr(m, '_set_tau'):
                            m._set_tau(tau)

                preds = self.model(inputs)
                criterion = nn.L1Loss()
                loss = criterion(preds, labels)

                epoch_losses.update(loss.item(), len(inputs))

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                scheduler.step()

                t.set_postfix(loss=f'{epoch_losses.avg:.6f}')
                t.update(len(inputs))

        self.eval(epoch)
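
Note from the loop above that training uses an Adam optimizer with an L1 loss, that a `StepLR` scheduler steps every batch with the configured `gamma`, and that the learning rate is additionally reset at the top of each epoch, dropping by a factor of 10 once 80% of `num_train_epochs` have elapsed. A worked sketch of that per-epoch formula with the default values (illustrative numbers only):

learning_rate, num_train_epochs = 1e-4, 1000
for epoch in (0, 799, 800, 999):
    lr = learning_rate * (0.1 ** (epoch // int(num_train_epochs * 0.8)))
    print(epoch, lr)   # 1e-4 up to epoch 799, then 1e-5 from epoch 800 onwards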

TrainingArguments#

TrainingArguments is the dataclass of arguments that relate to the training loop itself.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `output_dir` | `str` | The output directory where the model predictions and checkpoints will be written. | *required* |
| `overwrite_output_dir` | `bool`, *optional* | If `True`, overwrite the content of the output directory. Use this to continue training if `output_dir` points to a checkpoint directory. | `False` |
| `learning_rate` | `float`, *optional* | The initial learning rate for the `torch.optim.Adam` optimizer. | `1e-4` |
| `gamma` | `float`, *optional* | The decay factor (gamma) applied to the learning rate by the `torch.optim.lr_scheduler.StepLR` scheduler. | `0.5` |
| `num_train_epochs` | `int`, *optional* | Total number of training epochs to perform. | `1000` |
| `save_strategy` | `str` or `transformers.trainer_utils.IntervalStrategy`, *optional* | The checkpoint save strategy to adopt during training. Possible values are `"no"` (no save is done during training), `"epoch"` (save at the end of each epoch) and `"steps"` (save every `save_steps`). | `"steps"` |
| `save_steps` | `int`, *optional* | Number of update steps between two checkpoint saves if `save_strategy="steps"`. | `500` |
| `save_total_limit` | `int`, *optional* | If a value is passed, limits the total number of checkpoints, deleting the older checkpoints in `output_dir`. | |
| `no_cuda` | `bool`, *optional* | Whether to avoid using CUDA even when it is available. | `False` |
| `seed` | `int`, *optional* | Random seed that will be set at the beginning of training. | `42` |
| `fp16` | `bool`, *optional* | Whether to use 16-bit (mixed) precision training instead of 32-bit training. | `False` |
| `per_device_train_batch_size` | `int`, *optional* | The batch size per GPU/CPU for training. | `16` |
| `local_rank` | `int`, *optional* | Rank of the process during distributed training. | `-1` |
| `dataloader_num_workers` | `int`, *optional* | Number of subprocesses to use for data loading. `0` means the data is loaded in the main process. | `0` |
| `dataloader_pin_memory` | `bool`, *optional* | Whether to pin memory in the data loaders. | `True` |
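
For illustration, a `TrainingArguments` instance spelling out several of the options above; the values shown are simply the documented defaults:

from super_image import TrainingArguments

training_args = TrainingArguments(
    output_dir='./results',            # required: where weights and checkpoints are written
    learning_rate=1e-4,                # initial Adam learning rate
    gamma=0.5,                         # StepLR decay factor
    num_train_epochs=1000,
    per_device_train_batch_size=16,
    dataloader_num_workers=0,
    dataloader_pin_memory=True,
    seed=42,
)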

device: torch.device property readonly #

The device used by this process.

n_gpu property readonly #

The number of GPUs used by this process.

Note

This will only be greater than one when you have multiple GPUs available but are not using distributed training. For distributed training, it will always be 1.

train_batch_size: int property readonly #

The actual batch size for training (may differ from `per_device_train_batch_size` in distributed training).
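
These derived properties can be read straight off the arguments object; a small sketch (the printed values depend on the machine):

args = TrainingArguments(output_dir='./results', per_device_train_batch_size=16)
print(args.device)            # e.g. a CUDA device when one is available and no_cuda=False
print(args.n_gpu)             # number of GPUs visible to this process
print(args.train_batch_size)  # effective per-step training batch size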