trainer.utils package


trainer.utils.layer_freezer module

class trainer.utils.layer_freezer.LayerFreezer(model)[source]

Bases: object

Utilities to freeze/unfreeze layers.


model – a model with weights.




>>> from towhee.trainer.utils.layer_freezer import LayerFreezer
>>> from towhee.models import vit
>>> my_model = vit.create_model()
>>> my_freezer = LayerFreezer(my_model)
>>> # Check if modules in the last layer are frozen
>>> my_freezer.status(-1)
>>> # Check if modules in the layer "head" are frozen
>>> my_freezer.status("head")
['unfrozen', 'unfrozen']
>>> # Show all frozen layers
>>> my_freezer.show_frozen_layers()
['patch_embed', 'head']
>>> # Freeze layers by a list of layer indexes
>>> my_freezer.by_idx([0, -1])
>>> # Freeze layers by a list of layer names
>>> my_freezer.by_names(['head'])
>>> # Freeze all layers
>>> my_freezer.set_all()
>>> # Unfreeze all layers
>>> my_freezer.set_all(freeze=False)
>>> # Freeze all except the last layer
>>> my_freezer.set_slice(-1)
by_idx(idx: list, freeze: bool = True)[source]

Freeze/unfreeze layers by indexes

  • idx (list) – a list of layer indexes

  • freeze (bool) – whether to freeze (True) or unfreeze (False) the layers (default: True)

by_names(names: list, freeze: bool = True)[source]

Freeze/unfreeze layers by names

  • names (list) – a list of layer names

  • freeze (bool) – whether to freeze (True) or unfreeze (False) the layers (default: True)

set_all(freeze: bool = True)[source]

Freeze/unfreeze all layers.


freeze (bool) – whether to freeze (True) or unfreeze (False) the layers (default: True)

set_slice(slice_num: int, freeze: bool = True)[source]

Freeze/unfreeze layers by list slice.

  • slice_num (int) – number to slice the list of layers

  • freeze (bool) – whether to freeze (True) or unfreeze (False) the layers (default: True)


show_frozen_layers()[source]

Show the names of all frozen layers.

Returns: a list of the names of the frozen layers.

status(layer: Union[str, int])[source]

Check if a layer is frozen or not by its name or index


layer (Union[str, int]) – the name or index of layer.


Returns: a list of statuses ('frozen' or 'unfrozen') indicating whether each module in the layer is frozen.
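
The idea behind freezing can be sketched in plain PyTorch: a layer is frozen when its parameters have requires_grad set to False. The example below is a minimal illustration, not the LayerFreezer source; the helper names set_frozen and status_of are hypothetical, and reporting one status per parameter is an assumption about how a status list could be built.

```python
import torch.nn as nn

# A small stand-in model; any nn.Module with child layers would do.
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))


def set_frozen(module: nn.Module, freeze: bool = True) -> None:
    """Hypothetical helper: toggle gradient tracking for every parameter."""
    for param in module.parameters():
        param.requires_grad = not freeze


def status_of(module: nn.Module) -> list:
    """Hypothetical helper: one 'frozen'/'unfrozen' entry per parameter."""
    return ["unfrozen" if p.requires_grad else "frozen" for p in module.parameters()]


layers = list(model.children())

# Freeze everything except the last layer, using an ordinary list slice.
for layer in layers[:-1]:
    set_frozen(layer)

print(status_of(layers[0]))   # parameters of the first Linear layer
print(status_of(layers[-1]))  # parameters of the last Linear layer
```

Frozen parameters receive no gradients, so an optimizer leaves them unchanged during fine-tuning.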

trainer.utils.trainer_utils module

Utilities for the Trainer.

class trainer.utils.trainer_utils.SchedulerType(value)[source]

Bases: enum.Enum

An enumeration of the supported scheduler types, identified by their string values.

CONSTANT = 'constant'
CONSTANT_WITH_WARMUP = 'constant_with_warmup'
COSINE = 'cosine'
COSINE_WITH_RESTARTS = 'cosine_with_restarts'
LINEAR = 'linear'
POLYNOMIAL = 'polynomial'
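
Because SchedulerType is a standard enum.Enum, a member can be looked up from its string value. The sketch below re-declares the enum locally with the members listed above so that it runs without the package installed; it is an illustration, not the library source.

```python
from enum import Enum


class SchedulerType(Enum):
    # Local mirror of the documented members, for illustration only.
    CONSTANT = "constant"
    CONSTANT_WITH_WARMUP = "constant_with_warmup"
    COSINE = "cosine"
    COSINE_WITH_RESTARTS = "cosine_with_restarts"
    LINEAR = "linear"
    POLYNOMIAL = "polynomial"


# Look a member up by its string value, e.g. one read from a config file.
scheduler = SchedulerType("cosine")
assert scheduler is SchedulerType.COSINE

# Unknown values raise ValueError, so a typo in a config fails early.
try:
    SchedulerType("cosin")
except ValueError as err:
    print(f"rejected: {err}")
```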
class trainer.utils.trainer_utils.TrainOutput(global_step, training_loss)[source]

Bases: NamedTuple


Return a nicely formatted representation string

global_step: int

Alias for field number 0

training_loss: float

Alias for field number 1
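
Since TrainOutput is a NamedTuple, its fields can be read by name or by position. The sketch below mirrors the two documented fields in a local class so it runs standalone; the values are made up for illustration.

```python
from typing import NamedTuple


class TrainOutput(NamedTuple):
    # Local mirror of the documented fields, for illustration only.
    global_step: int
    training_loss: float


result = TrainOutput(global_step=1000, training_loss=0.25)

# Fields are reachable by name or by position (field numbers 0 and 1).
assert result.global_step == result[0] == 1000
assert result.training_loss == result[1] == 0.25

# Like any tuple, it can be unpacked directly.
step, loss = result
```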

trainer.utils.trainer_utils.get_last_checkpoint(out_dir: str)[source]
trainer.utils.trainer_utils.honor_type(obj, generator: Generator)[source]

Cast a generator to the same type as obj (list, tuple or namedtuple)
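
One way such a cast could work is sketched below. This is an assumption about the approach, not the library source; honor_type_sketch and Point are hypothetical names. The only special case is a namedtuple, whose constructor takes its fields as separate arguments.

```python
from collections import namedtuple
from typing import Any, Iterable


def honor_type_sketch(obj: Any, generator: Iterable) -> Any:
    """Rebuild `generator` as the same container type as `obj`."""
    # A namedtuple takes its fields as separate positional arguments,
    # so the generator has to be unpacked rather than passed whole.
    if isinstance(obj, tuple) and hasattr(obj, "_fields"):
        return type(obj)(*generator)
    return type(obj)(generator)


Point = namedtuple("Point", ["x", "y"])  # hypothetical example type

doubled_list = honor_type_sketch([1, 2], (v * 2 for v in [1, 2]))
doubled_tuple = honor_type_sketch((1, 2), (v * 2 for v in (1, 2)))
doubled_point = honor_type_sketch(Point(1, 2), (v * 2 for v in Point(1, 2)))
```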

trainer.utils.trainer_utils.is_torch_tensor(tensor: Any)[source]
trainer.utils.trainer_utils.recursively_apply(func: typing.Callable, data: typing.Any, *args, test_type: typing.Callable = <function is_torch_tensor>, error_on_other_type: bool = False, **kwargs)[source]

Recursively apply a function on a data structure that is a nested list/tuple/dictionary of a given base type.

  • func (callable) – The function to recursively apply.

  • data (nested list/tuple/dictionary) – The data on which to apply func.

  • *args – Positional arguments that will be passed to func when applied on the unpacked data.

  • test_type (callable, optional, defaults to is_torch_tensor) – A predicate that returns True for the objects to which func should be applied.

  • error_on_other_type (bool, optional, defaults to False) – Whether to raise an error if, after unpacking data, an object is found for which test_type returns False. If False, such objects are left unchanged.

  • **kwargs – Keyword arguments that will be passed to func when applied on the unpacked data.


Returns: The same data structure as data, with func applied to every object for which test_type returns True.
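
The traversal described above can be sketched in pure Python. This is a simplified illustration under stated assumptions, not the library source: recursively_apply_sketch, is_float and scale are hypothetical names, the predicate is passed explicitly instead of defaulting to is_torch_tensor, and TypeError is an arbitrary choice of exception.

```python
from typing import Any, Callable


def recursively_apply_sketch(
    func: Callable,
    data: Any,
    *args,
    test_type: Callable[[Any], bool],
    error_on_other_type: bool = False,
    **kwargs,
) -> Any:
    def recurse(item: Any) -> Any:
        return recursively_apply_sketch(
            func, item, *args,
            test_type=test_type,
            error_on_other_type=error_on_other_type,
            **kwargs,
        )

    if isinstance(data, (list, tuple)):
        items = [recurse(item) for item in data]
        # namedtuples take their fields as separate positional arguments.
        if hasattr(data, "_fields"):
            return type(data)(*items)
        return type(data)(items)
    if isinstance(data, dict):
        return {key: recurse(value) for key, value in data.items()}
    if test_type(data):
        return func(data, *args, **kwargs)
    if error_on_other_type:
        raise TypeError(f"Unsupported leaf of type {type(data).__name__}")
    return data


def is_float(obj: Any) -> bool:
    return isinstance(obj, float)


def scale(value: float, factor: float) -> float:
    return value * factor


batch = {"loss": [0.5, 1.5], "tag": "run-1", "pair": (2.0, "x")}
# Only the float leaves are scaled; strings pass through untouched.
scaled = recursively_apply_sketch(scale, batch, 2.0, test_type=is_float)
```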

trainer.utils.trainer_utils.reduce_value(value, average=True)[source]
trainer.utils.trainer_utils.send_to_device(tensor: Any, device: torch.device)[source]

Recursively sends the elements in a nested list/tuple/dictionary of tensors to a given device. Borrowed from huggingface/accelerate.

  • tensor (nested list/tuple/dictionary of torch.Tensor) – The data to send to a given device.

  • device (torch.device) – The device to send the data to.


Returns: The same data structure as tensor with all tensors sent to the proper device.
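
The recursion can be illustrated without PyTorch by using a stand-in object. In the sketch below, FakeTensor is a hypothetical class that only records a device string, and send_to_device_sketch is an assumed simplification rather than the library source. A real tensor would be moved with its own .to(device) method in the same place.

```python
from typing import Any


class FakeTensor:
    """Hypothetical stand-in for torch.Tensor that only records its device."""

    def __init__(self, device: str = "cpu") -> None:
        self.device = device

    def to(self, device: str) -> "FakeTensor":
        # Like Tensor.to, return a new object rather than mutating in place.
        return FakeTensor(device)


def send_to_device_sketch(data: Any, device: str) -> Any:
    if isinstance(data, (list, tuple)):
        moved = [send_to_device_sketch(item, device) for item in data]
        # namedtuples take their fields as separate positional arguments.
        return type(data)(*moved) if hasattr(data, "_fields") else type(data)(moved)
    if isinstance(data, dict):
        return {key: send_to_device_sketch(value, device) for key, value in data.items()}
    if isinstance(data, FakeTensor):
        return data.to(device)
    return data  # non-tensor leaves are returned unchanged


batch = {"x": FakeTensor(), "meta": ["id-1", (FakeTensor(), 3)]}
moved_batch = send_to_device_sketch(batch, "cuda:0")
```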

trainer.utils.trainer_utils.set_seed(seed: int)[source]

Helper function that sets the seed in random, numpy, and torch for reproducible behavior.


seed (int) – The seed to set.
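
A minimal sketch of such a helper is shown below. It is an assumption about the approach, not the library source: set_seed_sketch is a hypothetical name, and numpy and torch are seeded only if they happen to be installed so that the example also runs with the standard library alone.

```python
import random


def set_seed_sketch(seed: int) -> None:
    """Seed every random number generator that is available."""
    random.seed(seed)
    try:
        import numpy as np
        np.random.seed(seed)
    except ImportError:
        pass  # numpy is optional in this sketch
    try:
        import torch
        torch.manual_seed(seed)
    except ImportError:
        pass  # torch is optional in this sketch


# Re-seeding with the same value reproduces the same sequence of draws.
set_seed_sketch(42)
first_draws = [random.random() for _ in range(3)]
set_seed_sketch(42)
second_draws = [random.random() for _ in range(3)]
```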

trainer.utils.trainer_utils.unwrap_model(model: torch.nn.modules.module.Module) → torch.nn.modules.module.Module[source]

Unwraps a model from potential containers (as used in distributed training).


model (torch.nn.Module) – The model to unwrap.
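
Containers such as torch.nn.DataParallel and DistributedDataParallel keep the wrapped model in a .module attribute, so unwrapping can be sketched as following that attribute until it is absent. The code below is an assumed illustration, not the library source; TinyModel, Wrapper and unwrap_model_sketch are hypothetical stand-ins so that the example runs without PyTorch.

```python
class TinyModel:
    """Hypothetical stand-in for a torch.nn.Module."""


class Wrapper:
    """Hypothetical stand-in for a container that stores its model in .module."""

    def __init__(self, module) -> None:
        self.module = module


def unwrap_model_sketch(model):
    # Peel off containers one at a time; nested wrappers are handled too.
    while hasattr(model, "module"):
        model = model.module
    return model


inner = TinyModel()
wrapped_twice = Wrapper(Wrapper(inner))
unwrapped = unwrap_model_sketch(wrapped_twice)
```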

Module contents