towhee.models.collaborative_experts.collaborative_experts

Functions

create_model

Create CENet model.

drop_nans

Remove nans, which we expect to find at missing indices.
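A minimal sketch of what removing NaNs at missing-modality indices can look like; the helper name and signature below are illustrative assumptions, not the actual towhee API:

    import torch

    def drop_nans_sketch(x: torch.Tensor) -> torch.Tensor:
        """Replace NaN entries (e.g. features of missing modalities) with zeros."""
        x = x.clone()              # avoid mutating the caller's tensor
        x[torch.isnan(x)] = 0.0    # NaNs mark missing indices; zero them out
        return x

    feats = torch.tensor([[1.0, float("nan")], [3.0, 4.0]])
    print(drop_nans_sketch(feats))  # tensor([[1., 0.], [3., 4.]])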

kronecker_prod

Compute a Kronecker product.
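For reference, a Kronecker product of two tensors can be formed directly with torch.kron; whether this helper matches that exact behaviour is an assumption:

    import torch

    a = torch.tensor([1.0, 2.0])
    b = torch.tensor([3.0, 4.0])
    # torch.kron builds the Kronecker product: every element of `a` scales all of `b`.
    print(torch.kron(a, b))  # tensor([3., 4., 6., 8.])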

sharded_cross_view_inner_product

Compute a similarity matrix from sharded vectors.

sharded_single_view_inner_product

Compute a similarity matrix from sharded vectors.
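Conceptually, both functions score one view against another by accumulating per-expert ("shard") inner products of L2-normalised embeddings. A rough sketch of that idea, with hypothetical variable names and without the weighting and missing-modality handling of the real implementation:

    import torch
    import torch.nn.functional as F

    def sharded_similarity_sketch(text_embds, vid_embds):
        """text_embds / vid_embds: dicts mapping expert name -> (batch, dim) tensors."""
        sims = 0.0
        for expert in vid_embds:
            t = F.normalize(text_embds[expert], dim=-1)  # L2-normalise each shard
            v = F.normalize(vid_embds[expert], dim=-1)
            sims = sims + t @ v.t()                      # cosine similarity per expert
        return sims                                      # (num_texts, num_videos)

    text = {"rgb": torch.randn(4, 8), "audio": torch.randn(4, 8)}
    video = {"rgb": torch.randn(5, 8), "audio": torch.randn(5, 8)}
    print(sharded_similarity_sketch(text, video).shape)  # torch.Size([4, 5])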

Classes

CEModule

CE Module.

:param expert_dims: dimension of experts
:type expert_dims: int
:param text_dim: dimension of text
:type text_dim: int
:param use_ce: use collaborative experts
:type use_ce: bool
:param verbose: verbose mode
:type verbose: bool
:param l2renorm: l2 norm for CEModule
:type l2renorm: bool
:param num_classes: number of classes
:type num_classes: int
:param trn_config: train configs
:type trn_config: dict
:param trn_cat: train categories
:type trn_cat: int
:param use_mish: use mish module
:type use_mish: int
:param include_self: include self
:type include_self: int
:param num_h_layers: number of layers for h_reason
:type num_h_layers: int
:param num_g_layers: number of layers for g_reason
:type num_g_layers: int
:param disable_nan_checks: disable nan checks
:type disable_nan_checks: bool
:param random_feats: random features
:type random_feats: set
:param test_caption_mode: test caption mode
:type test_caption_mode: str
:param mimic_ce_dims: mimic collaborative experts dimension
:type mimic_ce_dims: bool
:param concat_experts: concat embedding of experts
:type concat_experts: bool
:param concat_mix_experts: concat mix experts
:type concat_mix_experts: bool
:param freeze_weights: freeze weights
:type freeze_weights: bool
:param task: task string
:type task: str
:param keep_missing_modalities: assign every expert/text inner product the same weight, even if the expert is missing
:type keep_missing_modalities: bool
:param vlad_feat_sizes: vlad feature sizes
:type vlad_feat_sizes: dict
:param same_dim: same dimension
:type same_dim: int
:param use_bn_reason: use batch normalization
:type use_bn_reason: int

CENet

Collaborative Experts Module.

ContextGating

ContextGating Module.

:param dimension: dimension of input
:type dimension: int
:param add_batch_norm: add batch normalization
:type add_batch_norm: int
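Context gating multiplies the input by a learned sigmoid gate, optionally batch-normalised. A rough sketch of the pattern under that assumption, not the exact towhee implementation:

    import torch
    from torch import nn

    class ContextGatingSketch(nn.Module):
        def __init__(self, dimension: int, add_batch_norm: bool = True):
            super().__init__()
            self.fc = nn.Linear(dimension, dimension)
            self.bn = nn.BatchNorm1d(dimension) if add_batch_norm else None

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            gate = self.fc(x)               # learn a per-feature gate from the input itself
            if self.bn is not None:
                gate = self.bn(gate)
            return x * torch.sigmoid(gate)  # suppress or pass each feature dimension

    x = torch.randn(2, 16)
    print(ContextGatingSketch(16)(x).shape)  # torch.Size([2, 16])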

ContextGatingReasoning

Args:
    dimension (int): dimension of input
    add_batch_norm (int): add batch normalization

G_reason

G_reason Module.

:param same_dim: same dimension
:type same_dim: int
:param num_inputs: number of inputs
:type num_inputs: int
:param non_lin: non-linear module
:type non_lin: nn.module

GatedEmbeddingUnit

Args:
    input_dimension (int): dimension of input
    output_dimension (int): dimension of output
    use_bn (bool): use batch normalization
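In the Collaborative Experts design, a gated embedding unit typically projects the input, applies context gating, and L2-normalises the result. A hedged sketch under those assumptions; the actual layer composition in towhee may differ:

    import torch
    import torch.nn.functional as F
    from torch import nn

    class GatedEmbeddingUnitSketch(nn.Module):
        def __init__(self, input_dimension: int, output_dimension: int, use_bn: bool = True):
            super().__init__()
            self.fc = nn.Linear(input_dimension, output_dimension)
            self.gate_fc = nn.Linear(output_dimension, output_dimension)
            self.bn = nn.BatchNorm1d(output_dimension) if use_bn else None

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            x = self.fc(x)                 # project into the joint embedding space
            gate = self.gate_fc(x)         # context gating: learn a sigmoid gate
            if self.bn is not None:
                gate = self.bn(gate)
            x = x * torch.sigmoid(gate)
            return F.normalize(x, dim=-1)  # unit-length embeddings for cosine retrieval

    emb = GatedEmbeddingUnitSketch(32, 16)(torch.randn(4, 32))
    print(emb.shape)  # torch.Size([4, 16])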

GatedEmbeddingUnitReasoning

Args:
    output_dimension (int): dimension of output

MimicCEGatedEmbeddingUnit

Args:
    input_dimension (int): dimension of input
    output_dimension (int): dimension of output
    use_bn (bool): use batch normalization

Mish

Applies the mish function element-wise:
mish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + exp(x)))
SRC: https://github.com/digantamisra98/Mish/blob/master/Mish/Torch/mish.py
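The formula above translates directly into a few lines of PyTorch; this follows the stated definition rather than the referenced implementation:

    import torch
    import torch.nn.functional as F

    def mish(x: torch.Tensor) -> torch.Tensor:
        # mish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + exp(x)))
        return x * torch.tanh(F.softplus(x))

    print(mish(torch.tensor([-1.0, 0.0, 1.0])))  # ~ tensor([-0.3034, 0.0000, 0.8651])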

ReduceDim

ReduceDim Module.

:param input_dimension: dimension of input
:type input_dimension: int
:param output_dimension: dimension of output
:type output_dimension: int

RelationModuleMultiScale

RelationModuleMultiScale Module.

:param img_feature_dim: image feature dimension
:type img_feature_dim: int
:param num_frames: number of frames
:type num_frames: int
:param num_class: number of classes
:type num_class: int

RelationModuleMultiScale_Cat

RelationModuleMultiScale_Cat Module.

:param img_feature_dim: image feature dimension
:type img_feature_dim: int
:param num_frames: number of frames
:type num_frames: int
:param num_class: number of classes
:type num_class: int

SpatialMLP

SpatialMLP module.

:param dimension: dimension of input
:type dimension: int

TemporalAttention

TemporalAttention Module.

:param img_feature_dim: image feature dimension
:type img_feature_dim: int
:param num_attention: number of attention
:type num_attention: int