towhee.models.utils.init_vit_weights.init_vit_weights

towhee.models.utils.init_vit_weights.init_vit_weights(module: torch.nn.modules.module.Module, trunc_normal_std=0.02, name: str = '', head_bias: float = 0.0, jax_impl: bool = False)

ViT weight initialization.

  • When called without the name, head_bias, or jax_impl arguments, it behaves exactly like the author's original init, for compatibility with previous hyperparameters and downstream use cases such as DeiT.
  • When called with a valid name (the module name) and jax_impl=True, it is intended to match the JAX implementation, though this match is not guaranteed.
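The docstring names the two modes but does not say what each one does to a given layer. The sketch below illustrates the idea in PyTorch under explicit assumptions. The per-layer rules are modeled on the ViT initializer in the timm library and are not taken from towhee's source: truncated-normal weights by default, zeroed head weights with a `head_bias` fill, and Xavier-uniform weights with near-zero MLP biases when `jax_impl=True`. `vit_init_sketch` and `TinyBlock` are hypothetical names, and only `nn.Linear` and `nn.LayerNorm` are handled.

```python
import math

import torch
from torch import nn


def vit_init_sketch(module: nn.Module, name: str = "", head_bias: float = 0.0,
                    jax_impl: bool = False, trunc_normal_std: float = 0.02) -> None:
    # Hypothetical, timm-style ViT init for illustration; NOT towhee's source code.
    if isinstance(module, nn.Linear):
        if name.startswith("head"):
            # Name-aware path: zero the classifier weights, fill the bias with head_bias.
            nn.init.zeros_(module.weight)
            if module.bias is not None:
                nn.init.constant_(module.bias, head_bias)
        elif jax_impl:
            # Assumed JAX-style rule: Xavier-uniform weights, tiny normal bias in MLP layers.
            nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                if "mlp" in name:
                    nn.init.normal_(module.bias, std=1e-6)
                else:
                    nn.init.zeros_(module.bias)
        else:
            # Default path: truncated-normal weights, zero bias.
            nn.init.trunc_normal_(module.weight, std=trunc_normal_std)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
    elif isinstance(module, nn.LayerNorm):
        # Norm layers start out as the identity transform.
        nn.init.ones_(module.weight)
        nn.init.zeros_(module.bias)


class TinyBlock(nn.Module):
    """Toy stand-in with ViT-like submodule names (norm, mlp, head)."""

    def __init__(self, dim: int = 8, num_classes: int = 10):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))
        self.head = nn.Linear(dim, num_classes)


# Default call: no names are passed, so every Linear (head included) gets truncated-normal init.
default_model = TinyBlock()
default_model.apply(vit_init_sketch)

# Name-aware JAX-style call: pass each submodule's qualified name explicitly.
# head_bias = -log(num_classes) is an assumed value chosen for illustration.
jax_model = TinyBlock()
for sub_name, sub in jax_model.named_modules():
    vit_init_sketch(sub, name=sub_name, head_bias=-math.log(10), jax_impl=True)
```

The design point is that `nn.Module.apply` passes only the module, so name-dependent rules never fire through it; to get name-aware behavior, iterate over `named_modules()` and pass `name` yourself. Assuming towhee is installed, the documented function can be used the same two ways, for example `model.apply(init_vit_weights)` for the default behavior.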