towhee.operator.base.NNOperator
- class towhee.operator.base.NNOperator(framework: str = 'pytorch')
Bases:
Operator
Neural Network related operators that involve machine learning frameworks.
- Parameters:
framework (str) – The framework to apply.
Methods
Load the model checkpoint into an operator.
Save the model checkpoint into the path.
Save model to local.
Set up the trainer instance in the operator before training and set trainer parameters.
Return a list of supported model names.
Start to train an operator.
Attributes
flag
framework
key
shared_type
trainer
- abstract __call__()
The framework calls the __call__ function repeatedly, once for every input.
Args:
Returns:
- Raises:
An exception during __init__ can terminate the graph run.
- __init__(framework: str = 'pytorch')
Initialize the operator. Before a graph starts, the framework calls the Operator's __init__ function.
Args:
- Raises:
An exception during __init__ can terminate the graph run.
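The relationship between __init__ and the abstract __call__ can be illustrated with a minimal sketch. It does not import towhee: `NNOperatorSketch` and `UppercaseOperator` are hypothetical stand-ins that only mirror the documented signature (`framework: str = 'pytorch'`) and the fact that `__call__` is abstract. The list comprehension imitates the framework calling the operator once per input.

```python
from abc import ABC, abstractmethod


class NNOperatorSketch(ABC):
    """Hypothetical stand-in for NNOperator; not the towhee class itself."""

    def __init__(self, framework: str = 'pytorch'):
        # Mirrors the documented signature: keep the framework name on the instance.
        self.framework = framework

    @abstractmethod
    def __call__(self, *args, **kwargs):
        """Subclasses implement the per-input computation here."""


class UppercaseOperator(NNOperatorSketch):
    """Toy subclass; a real operator would run a model instead."""

    def __call__(self, text: str) -> str:
        return text.upper()


op = UppercaseOperator()
# Imitates the framework invoking the operator once for every input.
outputs = [op(item) for item in ['a', 'b']]
```

Because `__call__` is abstract in this sketch, instantiating the base class directly raises `TypeError`; only subclasses that implement it can be constructed.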
- load(path: Optional[str] = None)
Load the model checkpoint into an operator.
- Parameters:
path (str) – The folder path containing the model’s checkpoints.
- save(path: str, overwrite: bool = True)
Save the model checkpoint into the path.
- Parameters:
path (str) – The folder path containing the model’s checkpoints.
overwrite (bool) – If True, overwrite the path if it already exists.
- Raises:
FileExistsError – If overwrite is False and the path already exists.
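The overwrite rule can be sketched in plain Python. `save_checkpoint_sketch` is a hypothetical helper, not the towhee implementation; the file name `checkpoint.txt` and its contents are placeholders, and only the documented behavior (raise `FileExistsError` when the path exists and overwrite is False) is modeled.

```python
import os
import tempfile


def save_checkpoint_sketch(path: str, overwrite: bool = True) -> str:
    """Hypothetical helper illustrating the documented overwrite rule."""
    if os.path.exists(path) and not overwrite:
        raise FileExistsError(f'{path} already exists and overwrite is False')
    os.makedirs(path, exist_ok=True)
    # Placeholder file standing in for real checkpoint data.
    target = os.path.join(path, 'checkpoint.txt')
    with open(target, 'w', encoding='utf-8') as handle:
        handle.write('placeholder weights')
    return target


with tempfile.TemporaryDirectory() as root:
    ckpt_dir = os.path.join(root, 'ckpt')
    first = save_checkpoint_sketch(ckpt_dir)           # path does not exist yet
    save_checkpoint_sketch(ckpt_dir, overwrite=True)   # existing path is overwritten
    try:
        save_checkpoint_sketch(ckpt_dir, overwrite=False)
        refused = False
    except FileExistsError:
        refused = True
```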
- setup_trainer(training_config=None, train_dataset=None, eval_dataset=None, train_dataloader=None, eval_dataloader=None, model_card=None)
Set up the trainer instance in the operator before training and set trainer parameters.
- Parameters:
training_config (TrainingConfig) – The config of this trainer.
train_dataset (Union[Dataset, TowheeDataSet]) – Training dataset.
eval_dataset (Union[Dataset, TowheeDataSet]) – Evaluation dataset.
train_dataloader (Union[DataLoader, Iterable]) – If specified, Trainer will use it to load training data. Otherwise, Trainer will build a dataloader from train_dataset.
eval_dataloader (Union[DataLoader, Iterable]) – If specified, Trainer will use it to load evaluation data. Otherwise, Trainer will build a dataloader from train_dataset.
model_card (ModelCard) – Model card containing the training information.
Returns:
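The dataloader fallback described for setup_trainer (use the dataloader if one is passed, otherwise build one from a dataset) can be sketched without any framework. `resolve_dataloader` and `batches` are hypothetical names, and the `batch_size` parameter is an assumption made for illustration; the actual Trainer logic may differ.

```python
from typing import Iterable, Iterator, List, Optional, Sequence


def batches(dataset: Sequence, batch_size: int) -> Iterator[List]:
    """Yield consecutive slices of the dataset as lists."""
    for start in range(0, len(dataset), batch_size):
        yield list(dataset[start:start + batch_size])


def resolve_dataloader(dataloader: Optional[Iterable],
                       dataset: Optional[Sequence],
                       batch_size: int = 2) -> Iterable:
    """Hypothetical helper: prefer an explicit dataloader, else build one."""
    if dataloader is not None:
        return dataloader
    if dataset is None:
        raise ValueError('either a dataloader or a dataset is required')
    return batches(dataset, batch_size)


# An explicit dataloader is returned unchanged.
explicit = resolve_dataloader([[10, 20]], dataset=[1, 2, 3])
# Without one, batches are built from the dataset.
built = list(resolve_dataloader(None, dataset=[1, 2, 3]))
```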
- train(training_config=None, train_dataset=None, eval_dataset=None, resume_checkpoint_path=None, **kwargs)
Start to train an operator.
- Parameters:
training_config (TrainingConfig) – The config of this trainer.
train_dataset (Union[Dataset, TowheeDataSet]) – Training dataset.
eval_dataset (Union[Dataset, TowheeDataSet]) – Evaluation dataset.
resume_checkpoint_path (str) – Path of the checkpoint to resume from, if resuming training.
**kwargs (Any) – Additional keyword arguments.