Trainer#
The Trainer class provides an API for feature-complete training in most standard use cases.
Before instantiating your Trainer, create a TrainingArguments instance, which exposes all the points of customization during training.
Trainer#
Trainer is a simple class implementing the training and eval loop for PyTorch to train a super-image model.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
model |
|
class: |
required |
args |
|
class: |
required |
train_dataset |
|
obj: |
required |
eval_dataset |
|
obj: |
required |
get_eval_dataloader(self)
#
Returns the evaluation `torch.utils.data.DataLoader`.
Source code in super_image\trainer.py
def get_eval_dataloader(self) -> DataLoader:
    """
    Returns the evaluation :class:`~torch.utils.data.DataLoader`.

    Uses :obj:`self.eval_dataset` when it is set and falls back to :obj:`self.train_dataset`
    otherwise. The loader is built with ``batch_size=1``.

    Raises:
        ValueError: If neither an evaluation dataset nor a training dataset is available.
    """
    eval_dataset = self.eval_dataset if self.eval_dataset is not None else self.train_dataset
    if eval_dataset is None:
        # Fail fast with a clear message (same style as get_train_dataloader) instead of
        # passing ``None`` on to DataLoader.
        raise ValueError("Trainer: evaluation requires an eval_dataset.")
    return DataLoader(
        dataset=eval_dataset,
        batch_size=1,
    )
get_train_dataloader(self)
#
Returns the training `torch.utils.data.DataLoader`.
Source code in super_image\trainer.py
def get_train_dataloader(self) -> DataLoader:
    """
    Build the :class:`~torch.utils.data.DataLoader` used for training.

    The loader shuffles :obj:`self.train_dataset` and takes its batch size, worker count and
    pin-memory flag from :obj:`self.args`.

    Raises:
        ValueError: If :obj:`self.train_dataset` is :obj:`None`.
    """
    dataset = self.train_dataset
    if dataset is None:
        raise ValueError("Trainer: training requires a train_dataset.")

    settings = self.args
    loader_kwargs = {
        "batch_size": settings.train_batch_size,
        "shuffle": True,
        "num_workers": settings.dataloader_num_workers,
        "pin_memory": settings.dataloader_pin_memory,
    }
    return DataLoader(dataset=dataset, **loader_kwargs)
save_model(self, output_dir=None)
#
Will save the model, so you can reload it using `from_pretrained()`.
Will only save from the main process.
Source code in super_image\trainer.py
def save_model(self, output_dir: Optional[str] = None):
    """
    Write the model to disk so that it can be reloaded with :obj:`from_pretrained()`.

    Args:
        output_dir (:obj:`str`, `optional`):
            Target directory; :obj:`self.args.output_dir` is used when omitted. The directory is
            created if it does not exist yet.
    """
    target_dir = self.args.output_dir if output_dir is None else output_dir
    os.makedirs(target_dir, exist_ok=True)

    if isinstance(self.model, PreTrainedModel):
        # PreTrainedModel instances are saved through their own save_pretrained().
        self.model.save_pretrained(target_dir)
        return

    # Any other model: store a copy of the raw state dict. The file name is formatted with
    # the config's scale when one is set, otherwise the plain weights name is used.
    scale = self.model.config.scale
    filename = WEIGHTS_NAME if scale is None else WEIGHTS_NAME_SCALE.format(scale=scale)
    state = copy.deepcopy(self.model.state_dict())
    torch.save(state, os.path.join(target_dir, filename))
train(self, resume_from_checkpoint=None, **kwargs)
#
Main training entry point.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
resume_from_checkpoint |
Union[bool, str] |
obj: |
None |
kwargs |
|
Additional keyword arguments used to hide deprecated arguments |
{} |
Source code in super_image\trainer.py
def train(
    self,
    resume_from_checkpoint: Optional[Union[str, bool]] = None,
    **kwargs,
):
    """
    Main training entry point.

    Trains :obj:`self.model` for ``args.num_train_epochs`` epochs with Adam and an L1 loss on the
    batches from :meth:`get_train_dataloader`, and calls :obj:`self.eval(epoch)` after each epoch.

    Args:
        resume_from_checkpoint (:obj:`str` or :obj:`bool`, `optional`):
            Intended to select a checkpoint to resume from (a local path, or :obj:`True` for the
            last checkpoint in `args.output_dir`). NOTE: the loading logic below is commented
            out, so this argument is accepted but currently has no effect.
        kwargs:
            Additional keyword arguments used to hide deprecated arguments
    """
    args = self.args
    epochs_trained = 0
    device = args.device
    num_train_epochs = args.num_train_epochs
    learning_rate = args.learning_rate
    train_batch_size = args.train_batch_size
    train_dataset = self.train_dataset
    train_dataloader = self.get_train_dataloader()
    num_samples = len(train_dataset)
    # Clamp to 1 so a dataset much smaller than the batch size cannot round the
    # scheduler period down to zero.
    step_size = max(1, int(num_samples / train_batch_size * 200))
    # Number of epochs after which the base lr is divided by 10. Clamped to 1 because
    # int(num_train_epochs * 0.8) is 0 for a single-epoch run, which would make the
    # floor division below raise ZeroDivisionError.
    decay_every = max(1, int(num_train_epochs * 0.8))

    # TODO: checkpoint resumption is not wired up yet; the draft logic is kept for reference.
    # # Load potential model checkpoint
    # if isinstance(resume_from_checkpoint, bool) and resume_from_checkpoint:
    #     resume_from_checkpoint = get_last_checkpoint(args.output_dir)
    #     if resume_from_checkpoint is None:
    #         raise ValueError(f"No valid checkpoint found in output directory ({args.output_dir})")
    #
    # if resume_from_checkpoint is not None:
    #     if not os.path.isfile(os.path.join(resume_from_checkpoint, WEIGHTS_NAME)):
    #         raise ValueError(f"Can't find a valid checkpoint at {resume_from_checkpoint}")
    #
    #     logger.info(f"Loading model from {resume_from_checkpoint}).")
    #
    #     if os.path.isfile(os.path.join(resume_from_checkpoint, CONFIG_NAME)):
    #         config = PretrainedConfig.from_json_file(os.path.join(resume_from_checkpoint, CONFIG_NAME))
    #
    #     state_dict = torch.load(os.path.join(resume_from_checkpoint, WEIGHTS_NAME), map_location="cpu")
    #     # If the model is on the GPU, it still works!
    #     self._load_state_dict_in_model(state_dict)
    #
    #     # release memory
    #     del state_dict

    optimizer = Adam(self.model.parameters(), lr=learning_rate)
    scheduler = lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=self.args.gamma)
    # The loss module holds no per-batch state here, so build it once instead of per batch.
    criterion = nn.L1Loss()

    for epoch in range(epochs_trained, num_train_epochs):
        # NOTE(review): this overwrites the lr that StepLR has set so far at the start of
        # every epoch -- confirm the two schedules are meant to coexist.
        for param_group in optimizer.param_groups:
            param_group['lr'] = learning_rate * (0.1 ** (epoch // decay_every))

        self.model.train()
        epoch_losses = AverageMeter()

        if self.model.config.model_type == 'SMSR':
            # update tau for gumbel softmax; tau depends only on the epoch, so it is
            # set once per epoch rather than once per batch
            tau = max(1 - (epoch - 1) / 500, 0.4)
            for m in self.model.modules():
                if hasattr(m, '_set_tau'):
                    m._set_tau(tau)

        # The progress bar advances by len(inputs) per batch, including the last partial
        # batch, so its total is the full dataset size.
        with tqdm(total=num_samples) as t:
            t.set_description(f'epoch: {epoch}/{num_train_epochs - 1}')

            for inputs, labels in train_dataloader:
                inputs = inputs.to(device)
                labels = labels.to(device)

                preds = self.model(inputs)
                loss = criterion(preds, labels)
                epoch_losses.update(loss.item(), len(inputs))

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                # The scheduler is stepped per batch, not per epoch.
                scheduler.step()

                t.set_postfix(loss=f'{epoch_losses.avg:.6f}')
                t.update(len(inputs))

        self.eval(epoch)
TrainingArguments#
TrainingArguments is the data class of arguments which relate to the training loop itself.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
output_dir |
|
obj: |
required |
overwrite_output_dir |
|
obj: |
required |
learning_rate |
|
obj: |
required |
gamma |
|
obj: |
required |
num_train_epochs |
|
obj: |
required |
save_strategy |
|
obj: |
required |
save_steps |
|
obj: |
required |
save_total_limit |
|
obj: |
required |
no_cuda |
|
obj: |
required |
seed |
|
obj: |
required |
fp16 |
|
obj: |
required |
per_device_train_batch_size |
|
obj: |
required |
local_rank |
|
obj: |
required |
dataloader_num_workers |
|
obj: |
required |
dataloader_pin_memory |
|
obj: |
required |
device: torch.device
property
readonly
#
The device used by this process.
n_gpu
property
readonly
#
The number of GPUs used by this process.
Note
This will only be greater than one when you have multiple GPUs available but are not using distributed training. For distributed training, it will always be 1.
train_batch_size: int
property
readonly
#
The actual batch size for training (may differ from `per_device_train_batch_size` in distributed training).