towhee.operator.base.NNOperator

class towhee.operator.base.NNOperator(framework: str = 'pytorch')

Bases: Operator

Base class for neural network operators that rely on a machine learning framework.

Parameters:

framework (str) – The machine learning framework the operator uses. Defaults to 'pytorch'.

Methods

load

Load the model checkpoint into an operator.

save

Save the model checkpoint to the given path.

save_model

Save the model locally.

setup_trainer

Set up the trainer instance in the operator before training and set the trainer parameters.

supported_model_names

Return a list of supported model names.

train

Start training the operator.

Attributes

flag

framework

key

shared_type

trainer

abstract __call__()

The framework calls the __call__ function repeatedly, once for each input.

Raises:

An exception during __call__ can terminate the graph run.

__init__(framework: str = 'pytorch')

Initialize the operator. Before a graph starts, the framework calls the Operator __init__ function.

Raises:

An exception during __init__ can terminate the graph run.
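The lifecycle described above (one __init__ call before the graph starts, then one __call__ per input) can be sketched as follows. This is a minimal, self-contained illustration under stated assumptions: `OperatorStub` is a hypothetical stand-in for NNOperator so the snippet runs without towhee installed, and `ScaleOperator` and `run_graph` are hypothetical names, not part of the towhee API.

```python
from abc import ABC, abstractmethod
from typing import Any, Iterable, List


class OperatorStub(ABC):
    """Hypothetical stand-in for NNOperator; not the towhee implementation."""

    def __init__(self, framework: str = 'pytorch'):
        # Mirrors the documented constructor argument and its default.
        self._framework = framework

    @property
    def framework(self) -> str:
        return self._framework

    @abstractmethod
    def __call__(self, data: Any) -> Any:
        """Called once for every input item."""


class ScaleOperator(OperatorStub):
    """Hypothetical operator that multiplies each input by a fixed factor."""

    def __init__(self, factor: float, framework: str = 'pytorch'):
        super().__init__(framework)
        if factor == 0:
            # An exception raised here would stop the graph before it runs.
            raise ValueError('factor must be non-zero')
        self.factor = factor

    def __call__(self, data: float) -> float:
        return data * self.factor


def run_graph(op: OperatorStub, inputs: Iterable[Any]) -> List[Any]:
    # Hypothetical driver: the operator is constructed once beforehand,
    # then __call__ is invoked repeatedly, once per input.
    return [op(x) for x in inputs]


op = ScaleOperator(factor=2.0)
print(op.framework)              # pytorch
print(run_graph(op, [1, 2, 3]))  # [2.0, 4.0, 6.0]
```

Because __call__ is abstract, the stub itself cannot be instantiated; only subclasses that implement it can.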

load(path: Optional[str] = None)

Load the model checkpoint into an operator.

Parameters:

path (Optional[str]) – The folder path containing the model’s checkpoints.

save(path: str, overwrite: bool = True)

Save the model checkpoint into the path.

Parameters:
  • path (str) – The folder path to save the model’s checkpoints into.

  • overwrite (bool) – If True, overwrite the path if it already exists.

Raises:

FileExistsError – If overwrite is False and the path already exists.
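The overwrite rule for save, paired with load, can be illustrated with a minimal sketch. The helper names `save_checkpoint` and `load_checkpoint`, the `checkpoint.json` file name, and the use of a JSON dict in place of real model weights are all hypothetical; the actual operator stores framework-specific checkpoint files. The sketch also assumes that overwriting means writing into the existing folder rather than deleting it first.

```python
import json
import tempfile
from pathlib import Path
from typing import Any, Dict

CHECKPOINT_FILE = 'checkpoint.json'  # hypothetical file name


def save_checkpoint(state: Dict[str, Any], path: str, overwrite: bool = True) -> Path:
    folder = Path(path)
    if folder.exists() and not overwrite:
        # Matches the documented behavior: refuse to reuse an existing path.
        raise FileExistsError(f'{folder} already exists')
    # Assumption: overwriting writes into the existing folder.
    folder.mkdir(parents=True, exist_ok=True)
    target = folder / CHECKPOINT_FILE
    target.write_text(json.dumps(state))
    return target


def load_checkpoint(path: str) -> Dict[str, Any]:
    # Read the state back from the folder written by save_checkpoint.
    return json.loads((Path(path) / CHECKPOINT_FILE).read_text())


with tempfile.TemporaryDirectory() as tmp:
    ckpt_dir = str(Path(tmp) / 'my_model')
    save_checkpoint({'epoch': 3}, ckpt_dir)
    print(load_checkpoint(ckpt_dir))  # {'epoch': 3}
    try:
        save_checkpoint({'epoch': 4}, ckpt_dir, overwrite=False)
    except FileExistsError as err:
        print('refused:', err)
```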

save_model()

Save the model locally.

setup_trainer(training_config=None, train_dataset=None, eval_dataset=None, train_dataloader=None, eval_dataloader=None, model_card=None)

Set up the trainer instance in the operator before training and set the trainer parameters.

Parameters:
  • training_config (TrainingConfig) – The config of this trainer.

  • train_dataset (Union[Dataset, TowheeDataSet]) – Training dataset.

  • eval_dataset (Union[Dataset, TowheeDataSet]) – Evaluation dataset.

  • train_dataloader (Union[DataLoader, Iterable]) – If specified, Trainer will use it to load training data. Otherwise, Trainer will build a dataloader from train_dataset.

  • eval_dataloader (Union[DataLoader, Iterable]) – If specified, Trainer will use it to load evaluation data. Otherwise, Trainer will build a dataloader from eval_dataset.

  • model_card (ModelCard) – Model card containing the training information.
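The dataloader fallback for setup_trainer (use an explicit dataloader if given, otherwise build one from the dataset) can be sketched in plain Python. `resolve_dataloader` and `batch_iter` are hypothetical helpers, and a simple generator of lists stands in for a framework DataLoader; the real Trainer builds its loaders with its own framework-specific settings.

```python
from typing import Iterable, Iterator, List, Optional, Sequence, TypeVar

T = TypeVar('T')


def batch_iter(dataset: Sequence[T], batch_size: int) -> Iterator[List[T]]:
    # Minimal stand-in for a DataLoader: yields consecutive fixed-size batches.
    for start in range(0, len(dataset), batch_size):
        yield list(dataset[start:start + batch_size])


def resolve_dataloader(dataloader: Optional[Iterable],
                       dataset: Optional[Sequence],
                       batch_size: int = 2) -> Optional[Iterable]:
    # An explicitly passed dataloader wins; otherwise build one from the dataset.
    if dataloader is not None:
        return dataloader
    if dataset is not None:
        return batch_iter(dataset, batch_size)
    return None


train_loader = resolve_dataloader(None, [1, 2, 3, 4, 5])
eval_loader = resolve_dataloader([[10, 20]], [7, 8, 9])
print(list(train_loader))  # [[1, 2], [3, 4], [5]]
print(list(eval_loader))   # [[10, 20]]
```

The generator returned by `batch_iter` is single-pass; a real DataLoader can be iterated once per epoch.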

supported_model_names()

Return a list of supported model names.
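One plausible use of supported_model_names is validating a requested model name before the operator builds anything. The sketch below is hypothetical: `EmbeddingOperatorStub` does not inherit from towhee, and the names in its list are placeholders rather than models any real operator supports.

```python
from typing import List


class EmbeddingOperatorStub:
    """Hypothetical operator stub; not part of towhee."""

    # Placeholder names for illustration only.
    _MODELS = ['model-small', 'model-base']

    def __init__(self, model_name: str, framework: str = 'pytorch'):
        if model_name not in self.supported_model_names():
            raise ValueError(
                f'unsupported model {model_name!r}; '
                f'choose one of {self.supported_model_names()}')
        self.model_name = model_name
        self.framework = framework

    def supported_model_names(self) -> List[str]:
        # Return a copy so callers cannot mutate the class-level list.
        return list(self._MODELS)


op = EmbeddingOperatorStub('model-base')
print(op.supported_model_names())  # ['model-small', 'model-base']
```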

train(training_config=None, train_dataset=None, eval_dataset=None, resume_checkpoint_path=None, **kwargs)

Start training the operator.

Parameters:
  • training_config (TrainingConfig) – The config of this trainer.

  • train_dataset (Union[Dataset, TowheeDataSet]) – Training dataset.

  • eval_dataset (Union[Dataset, TowheeDataSet]) – Evaluation dataset.

  • resume_checkpoint_path (str) – Path of the checkpoint to resume training from, if resuming.

  • **kwargs (Any) – Additional keyword arguments.
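How resume_checkpoint_path might be used can be illustrated with a toy loop. Everything here is hypothetical: `toy_train`, the `state.json` file, and the assumption that a checkpoint records the last completed epoch are illustrative choices, not the behavior of towhee's Trainer.

```python
import json
import tempfile
from pathlib import Path
from typing import List, Optional


def toy_train(total_epochs: int, output_dir: str,
              resume_checkpoint_path: Optional[str] = None) -> List[int]:
    """Hypothetical training loop that can resume from a saved epoch counter."""
    start_epoch = 0
    if resume_checkpoint_path is not None:
        # Assumption: the checkpoint stores the last completed epoch.
        state = json.loads((Path(resume_checkpoint_path) / 'state.json').read_text())
        start_epoch = state['epoch'] + 1
    ran = []
    out = Path(output_dir)
    out.mkdir(parents=True, exist_ok=True)
    for epoch in range(start_epoch, total_epochs):
        ran.append(epoch)  # a real loop would update model weights here
        # Record progress after each epoch so a later run can resume.
        (out / 'state.json').write_text(json.dumps({'epoch': epoch}))
    return ran


with tempfile.TemporaryDirectory() as tmp:
    print(toy_train(2, tmp))                              # [0, 1]
    print(toy_train(5, tmp, resume_checkpoint_path=tmp))  # [2, 3, 4]
```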