# Source code for towhee.models.coformer.utils

# Copyright 2022 Zilliz. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the 'License');
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an 'AS IS' BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Original code from
# Modified by Zilliz.

import torch
from torch import Tensor
import torch.distributed as dist

from typing import Optional, List

class NestedTensor(object):
    """Pair a batch tensor with an optional mask.

    ``nested_tensor_from_tensor_list`` builds instances whose ``mask`` is
    ``True`` at padded positions and ``False`` where real data was copied.
    The class itself does not validate or constrain either field.
    """

    def __init__(self, tensors, mask: Optional[Tensor]):
        # tensors: the (padded) data; mask: optional companion tensor or None.
        self.tensors = tensors
        self.mask = mask

    def to(self, device) -> 'NestedTensor':
        """Return a new NestedTensor with its tensors moved to ``device``.

        Args:
            device: anything accepted by ``Tensor.to`` (e.g. ``'cpu'`` or a
                ``torch.device``).

        Returns:
            A new NestedTensor. ``tensors`` is moved to ``device``, and so is
            ``mask`` when it is not None. A None mask stays None. The
            original instance is left unchanged.
        """
        cast_tensor = self.tensors.to(device)
        mask = self.mask
        cast_mask = mask.to(device) if mask is not None else None
        return NestedTensor(cast_tensor, cast_mask)

    def decompose(self):
        """Return the ``(tensors, mask)`` pair."""
        return self.tensors, self.mask

    def __repr__(self):
        # Only the data is shown; the mask is omitted from the representation.
        return str(self.tensors)
def is_main_process():
    """Report whether this process has rank 0.

    Rank 0 is also what ``get_rank`` returns when torch.distributed is not
    available or not initialized, so a non-distributed run counts as main.
    """
    current_rank = get_rank()
    return current_rank == 0
def get_rank():
    """Return this process's torch.distributed rank, or 0 when not distributed."""
    return dist.get_rank() if is_dist_avail_and_initialized() else 0
def is_dist_avail_and_initialized():
    """Return True only if torch.distributed is both available and initialized.

    ``dist.is_available()`` is checked first. Short-circuiting means
    ``dist.is_initialized()`` is never consulted on builds without
    distributed support.
    """
    return dist.is_available() and dist.is_initialized()
def _max_by_axis(the_list): # type: (List[List[int]]) -> List[int] maxes = the_list[0] for sublist in the_list[1:]: for index, item in enumerate(sublist): maxes[index] = max(maxes[index], item) return maxes
def nested_tensor_from_tensor_list(tensor_list: List[Tensor]):
    """Pad a list of 3-D tensors to a common size and wrap them in a NestedTensor.

    Each input is copied into the leading corner of a zero-filled batch
    tensor whose size along every axis is the maximum over all inputs, so
    differently sized inputs are supported. The mask starts all ``True`` and
    is set to ``False`` over each input's axes 1 and 2 extent, so ``True``
    marks padding.

    Args:
        tensor_list: non-empty list of 3-D tensors (presumably (C, H, W)
            images -- TODO confirm with callers). dtype and device of the
            output are taken from the first tensor.

    Returns:
        NestedTensor with ``tensors`` of shape (B, max_d0, max_d1, max_d2)
        and a bool ``mask`` of shape (B, max_d1, max_d2).

    Raises:
        ValueError: if ``tensor_list`` is empty or any tensor is not 3-D.
    """
    if not tensor_list:
        raise ValueError('tensor_list must not be empty')
    # Check every tensor, not just the first, so a mixed-rank list fails
    # here with a clear error rather than inside the shape arithmetic.
    if any(img.ndim != 3 for img in tensor_list):
        raise ValueError('not supported')

    max_size = _max_by_axis([list(img.shape) for img in tensor_list])
    batch_shape = [len(tensor_list)] + max_size
    b, _, h, w = batch_shape
    dtype = tensor_list[0].dtype
    device = tensor_list[0].device
    tensor = torch.zeros(batch_shape, dtype=dtype, device=device)
    # Fully masked to begin with; each input's valid region is cleared below.
    mask = torch.ones((b, h, w), dtype=torch.bool, device=device)
    for img, pad_img, m in zip(tensor_list, tensor, mask):
        pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
        m[: img.shape[1], : img.shape[2]] = False
    return NestedTensor(tensor, mask)