# Source code for towhee.models.action_clip.action_clip

# Copyright 2021 Zilliz. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Union

import torch.cuda
from torch import nn

from towhee.models import clip
from .action_clip_utils import map_state_dict, get_configs
from .text_prompt import text_prompt
from .visual_prompt import VisualPrompt

class ActionClip:
    """
    Wrapper that pairs a CLIP-style encoder with an optional visual prompt.

    Args:
        clip_model (`nn.Module`):
            Backbone used for both frames and text. If it exposes
            `encode_image` / `encode_text` those are used; otherwise the
            object itself is called.
        visual_prompt (`nn.Module`):
            Optional callable applied to the per-frame features of a video.
            When None, the per-frame features are returned unchanged.
        text_prompt_flag (`bool`):
            When True, `encode_text` passes its input through `text_prompt`
            before encoding.
    """

    def __init__(
            self,
            clip_model: nn.Module,
            visual_prompt: nn.Module = None,
            text_prompt_flag: bool = True
    ):
        super().__init__()
        self.clip_model = clip_model
        self.visual_prompt = visual_prompt
        self.text_prompt_flag = text_prompt_flag

    def encode_video(self, video):
        """
        Encode a batch of videos frame by frame.

        Args:
            video: 5-D tensor whose dimensions are unpacked as
                (batch, frames, channels, height, width).

        Returns:
            The output of `visual_prompt` applied to features of shape
            (batch, frames, -1), or those features themselves when no visual
            prompt is set.
        """
        b, t, c, h, w = video.size()
        # reshape (not view) so that non-contiguous inputs, e.g. a permuted
        # (b, c, t, h, w) clip, are accepted; for contiguous tensors the
        # result is identical to view.
        frames = video.reshape(-1, c, h, w)
        if hasattr(self.clip_model, "encode_image"):
            frame_features = self.clip_model.encode_image(frames)
        else:
            frame_features = self.clip_model(frames)
        # Split the merged (batch * frames) axis back apart.
        image_features = frame_features.reshape(b, t, -1)
        # Explicit None check: truthiness of a module depends on __len__ for
        # container modules, which is not what is meant here.
        if self.visual_prompt is not None:
            visual_features = self.visual_prompt(image_features)
        else:
            visual_features = image_features
        return visual_features

    def encode_text(self, text: list):
        """
        Encode text with the wrapped model.

        Args:
            text (`list`):
                Input handed to `text_prompt` when `text_prompt_flag` is set
                (only the first of its three return values is used), otherwise
                handed to the encoder as-is.

        Returns:
            Whatever the wrapped model's `encode_text` (or the model call
            itself, when that method is absent) returns.
        """
        if self.text_prompt_flag:
            text, _, _ = text_prompt(text)
        if hasattr(self.clip_model, "encode_text"):
            text_features = self.clip_model.encode_text(text)
        else:
            text_features = self.clip_model(text)
        return text_features
def create_model(
        clip_model: Union[str, nn.Module],
        pretrained: bool = False,
        jit: bool = False,
        device: str = None,
        checkpoints: dict = None,
        cfg: dict = None
):
    """
    Build an `ActionClip` from a CLIP model (or the name of one).

    Args:
        clip_model (`Union[str, nn.Module]`):
            Either a model name forwarded to `clip.create_model`, or an
            already constructed module that is used as-is.
        pretrained (`bool`):
            Forwarded to `clip.create_model` when a name is given. Also gates
            loading of `checkpoints`: they are ignored unless this is True.
        jit (`bool`):
            Forwarded to `clip.create_model` when a name is given.
        device (`str`):
            Target device; defaults to "cuda" when available, else "cpu".
        checkpoints (`dict`):
            Optional mapping that may contain "clip_state_dict" and/or
            "visual_state_dict"; each present entry is passed through
            `map_state_dict` and loaded into the matching module.
        cfg (`dict`):
            Optional overrides merged into the configs returned by
            `get_configs`. Must provide "visual_prompt_type" and "num_frames"
            if the base configs do not.

    Returns:
        An `ActionClip` wrapping the CLIP model and a freshly built
        `VisualPrompt`.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    if isinstance(clip_model, str):
        # Keep the name around: clip_model is rebound to the module below,
        # and get_configs takes a `model_name` (presumably the string key,
        # not the instantiated module).
        model_name = clip_model
        clip_model = clip.create_model(model_name=model_name, pretrained=pretrained, jit=jit, device=device)
        configs = get_configs(model_name=model_name)
    else:
        configs = get_configs()
    if cfg is not None:
        configs.update(**cfg)

    visual_prompt_type = configs["visual_prompt_type"]
    num_frames = configs["num_frames"]
    visual_prompt = VisualPrompt(visual_prompt_type, clip_model.state_dict(), num_frames).to(device)

    if pretrained and checkpoints:
        if "clip_state_dict" in checkpoints:
            clip_state_dict = map_state_dict(checkpoints, "clip_state_dict")
            clip_model.load_state_dict(clip_state_dict)
        if "visual_state_dict" in checkpoints:
            visual_state_dict = map_state_dict(checkpoints, "visual_state_dict")
            visual_prompt.load_state_dict(visual_state_dict)

    model = ActionClip(clip_model=clip_model, visual_prompt=visual_prompt)
    return model