Checkpoint

class kospeech.checkpoint.checkpoint.Checkpoint(model: torch.nn.modules.module.Module = None, optimizer: kospeech.optim.Optimizer = None, trainset_list: list = None, validset: kospeech.data.data_loader.SpectrogramDataset = None, epoch: int = None)[source]

The Checkpoint class manages saving and loading of a model during training. It allows training to be suspended and resumed at a later time (e.g. when running on a cluster using sequential jobs). To make a checkpoint, initialize a Checkpoint object with the arguments below, then call its save() method to write the parameters to disk.

Parameters
  • model (nn.Module) – model being trained

  • optimizer (kospeech.optim.Optimizer) – stores the state of the optimizer

  • trainset_list (list) – list of training datasets (SpectrogramDataset)

  • validset (kospeech.data.data_loader.SpectrogramDataset) – validation dataset

  • epoch (int) – current epoch (an epoch is a loop through the full training data)

Variables
  • SAVE_PATH (str) – path of file to save

  • LOAD_PATH (str) – path of file to load

  • TRAINER_STATE_NAME (str) – name of the file storing trainer states

  • MODEL_NAME (str) – name of the file storing model
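
The save/load cycle described above can be sketched in a few lines. The class and file names below (MiniCheckpoint, trainer_states.pkl) are hypothetical stand-ins, and pickle stands in for whatever serialization kospeech actually uses; this only illustrates the pattern of bundling training state, writing it to a timestamped subdirectory, and restoring it later.

```python
import os
import pickle
import tempfile
import time


class MiniCheckpoint:
    """Illustrative sketch of the checkpoint pattern (not kospeech's code)."""

    TRAINER_STATE_NAME = "trainer_states.pkl"  # hypothetical file name

    def __init__(self, state: dict, checkpoint_dir: str):
        self.state = state                  # e.g. model/optimizer/epoch
        self.checkpoint_dir = checkpoint_dir

    def save(self) -> str:
        # Subdirectory named with the current local time, Y_M_D_H_M_S.
        subdir = os.path.join(self.checkpoint_dir,
                              time.strftime("%Y_%m_%d_%H_%M_%S"))
        os.makedirs(subdir, exist_ok=True)
        with open(os.path.join(subdir, self.TRAINER_STATE_NAME), "wb") as f:
            pickle.dump(self.state, f)
        return subdir

    @classmethod
    def load(cls, path: str, checkpoint_dir: str) -> "MiniCheckpoint":
        # Restore a checkpoint previously written by save().
        with open(os.path.join(path, cls.TRAINER_STATE_NAME), "rb") as f:
            return cls(pickle.load(f), checkpoint_dir)


root = tempfile.mkdtemp()
ckpt = MiniCheckpoint({"epoch": 3}, root)
path = ckpt.save()
restored = MiniCheckpoint.load(path, root)
print(restored.state["epoch"])  # 3
```

In the real class, the state bundled at save time is the model, optimizer, trainset_list, validset, and epoch listed above, so a resumed run can continue from exactly where training stopped.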

get_latest_checkpoint()[source]

Returns the path to the last saved checkpoint’s subdirectory. Precondition: at least one checkpoint has been saved (i.e., the latest checkpoint subdirectory exists).
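
A minimal sketch of how "latest" can be resolved, assuming subdirectories are named with the zero-padded Y_M_D_H_M_S timestamp described under save(): because every field is zero-padded and ordered from most to least significant, lexicographic order equals chronological order. The function name and directory names here are illustrative, not kospeech's implementation.

```python
import os
import tempfile


def get_latest_checkpoint(checkpoint_dir: str) -> str:
    # Names like 2021_05_02_09_30_00 sort chronologically when sorted
    # as strings, so the last entry is the most recent checkpoint.
    subdirs = sorted(os.listdir(checkpoint_dir))
    return os.path.join(checkpoint_dir, subdirs[-1])


root = tempfile.mkdtemp()
for name in ("2021_05_01_12_00_00", "2021_05_02_09_30_00"):
    os.makedirs(os.path.join(root, name))
print(os.path.basename(get_latest_checkpoint(root)))  # 2021_05_02_09_30_00
```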

load(path)[source]

Loads a Checkpoint object that was previously saved to disk.

Parameters

path (str) – path to the checkpoint subdirectory

Returns

checkpoint object with fields copied from those stored on disk

Return type

checkpoint (Checkpoint)

save()[source]

Saves the current model and related training parameters into a subdirectory of the checkpoint directory. The name of the subdirectory is the current local time in Y_M_D_H_M_S format.
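
Assuming the Y_M_D_H_M_S naming corresponds to a strftime-style format (the exact format string is an assumption, not taken from the source), a timestamp of that shape can be produced and parsed back like this:

```python
import time

# Hypothetical format string matching the Y_M_D_H_M_S naming scheme.
FMT = "%Y_%m_%d_%H_%M_%S"

now = time.localtime()
stamp = time.strftime(FMT, now)   # e.g. '2021_05_02_09_30_00'
parsed = time.strptime(stamp, FMT)
print(parsed.tm_year == now.tm_year)  # True
```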