Source code for towhee.models.clip.clip_utils

# Built on top of the original implementation at https://github.com/openai/CLIP
#
# Modifications by Copyright 2022 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import hashlib
import urllib.request
from tqdm import tqdm
from typing import Union, List
from pkg_resources import packaging
import warnings

import torch
from torch import nn

from .simple_tokenizer import SimpleTokenizer


[docs]def base_configs(): return dict( # vision embed_dim=512, image_resolution=224, vision_layers=12, vision_width=768, vision_patch_size=16, # text context_length=77, vocab_size=49408, transformer_width=512, transformer_heads=8, transformer_layers=12 )
[docs]def get_configs(model_name): if model_name == "clip_resnet_r50": configs = base_configs() configs.update(dict( url=("https://openaipublic.azureedge.net/clip/models/" "afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt"), embed_dim=1024, vision_layers=(3, 4, 6, 3), vision_width=64, vision_patch_size=None, )) elif model_name == "clip_resnet_r101": configs = base_configs() configs.update(dict( url=("https://openaipublic.azureedge.net/clip/models/" "8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt"), embed_dim=512, vision_layers=(3, 4, 23, 3), vision_width=64, vision_patch_size=None, )) elif model_name == "clip_vit_b16": configs = base_configs() configs.update(dict( url=("https://openaipublic.azureedge.net/clip/models/" "5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt"), )) elif model_name == "clip_vit_b32": configs = base_configs() configs.update(dict( url=("https://openaipublic.azureedge.net/clip/models/" "40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt"), vision_patch_size=32, multilingual_model="M-CLIP/XLM-Roberta-Large-Vit-B-32" )) else: raise ValueError(f"Invalid model name '{model_name}'.") return configs
[docs]def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) \ -> Union[torch.IntTensor, torch.LongTensor]: """ Returns the tokenized representation of given input string(s) Parameters Args: texts : Union[str, List[str]] An input string or a list of input strings to tokenize context_length : int The context length to use; all CLIP models use 77 as the context length truncate: bool Whether to truncate the text in case its encoding is longer than the context length Returns: A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]. We return LongTensor when torch version is <1.8.0, since older index_select requires indices to be long. """ if isinstance(texts, str): texts = [texts] tokenizer = SimpleTokenizer() sot_token = tokenizer.encoder["<|startoftext|>"] eot_token = tokenizer.encoder["<|endoftext|>"] all_tokens = [[sot_token] + tokenizer.encode(text) + [eot_token] for text in texts] if packaging.version.parse(torch.__version__) < packaging.version.parse("1.8.0"): result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) else: result = torch.zeros(len(all_tokens), context_length, dtype=torch.int) for i, tokens in enumerate(all_tokens): if len(tokens) > context_length: if truncate: tokens = tokens[:context_length] tokens[-1] = eot_token else: raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}") result[i, :len(tokens)] = torch.tensor(tokens) return result
def _download(url: str, root: str): os.makedirs(root, exist_ok=True) filename = os.path.basename(url) expected_sha256 = url.split("/")[-2] download_target = os.path.join(root, filename) if os.path.exists(download_target) and not os.path.isfile(download_target): raise RuntimeError(f"{download_target} exists and is not a regular file") if os.path.isfile(download_target): with open(download_target, "rb") as f: if hashlib.sha256(f.read()).hexdigest() == expected_sha256: return download_target else: warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading file.") with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit="iB", unit_scale=True, unit_divisor=1024) as loop: while True: buffer = source.read(8192) if not buffer: break output.write(buffer) loop.update(len(buffer)) with open(download_target, "rb") as f: if hashlib.sha256(f.read()).hexdigest() != expected_sha256: raise RuntimeError("Model has been downloaded but the SHA256 checksum does not not match") return download_target
[docs]def convert_weights(model: nn.Module): """Convert applicable model parameters to fp16""" def _convert_weights_to_fp16(layer): if isinstance(layer, (nn.Conv1d, nn.Conv2d, nn.Linear)): layer.weight.data = layer.weight.data.half() if layer.bias is not None: layer.bias.data = layer.bias.data.half() if isinstance(layer, nn.MultiheadAttention): for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]: tensor = getattr(layer, attr) if tensor is not None: tensor.data = tensor.data.half() for name in ["text_projection", "proj"]: if hasattr(layer, name): attr = getattr(layer, name) if attr is not None: attr.data = attr.data.half() model.apply(_convert_weights_to_fp16)
[docs]def patch_device(model, device): device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]) device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1] def patch_dev(module): try: graphs = [module.graph] if hasattr(module, "graph") else [] except RuntimeError: graphs = [] if hasattr(module, "forward1"): graphs.append(module.forward1.graph) for graph in graphs: for node in graph.findAllNodes("prim::Constant"): if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"): node.copyAttributes(device_node) model.apply(patch_dev) patch_dev(model.encode_image) patch_dev(model.encode_text)
[docs]def patch_float(model): float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[]) float_input = list(float_holder.graph.findNode("aten::to").inputs())[1] float_node = float_input.node() def patch_flt(module): try: graphs = [module.graph] if hasattr(module, "graph") else [] except RuntimeError: graphs = [] if hasattr(module, "forward1"): graphs.append(module.forward1.graph) for graph in graphs: for node in graph.findAllNodes("aten::to"): inputs = list(node.inputs()) for i in [1, 2]: # dtype can be the second or third argument to aten::to() if inputs[i].node()["value"] == 5: inputs[i].node().copyAttributes(float_node) model.apply(patch_flt) patch_flt(model.encode_image) patch_flt(model.encode_text) model.float()