# Source code for towhee.functional.mixins.milvus

# Copyright 2021 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Iterable, Tuple

from towhee.hparam import param_scope
from towhee.utils.log import engine_log
from ..entity import Entity


def _milvus_insert(iterable: Iterable, index: Tuple[str], collection, batch_size: int): # pragma: no cover
    """Insert the entities of ``iterable`` into a Milvus collection in batches.

    Rows are accumulated column-wise (one list per field named in ``index``)
    and sent with ``collection.insert`` each time ``batch_size`` rows have been
    gathered; any remaining rows are sent once the iterable is exhausted.

    Args:
        iterable: Entities to insert. Each field is read with ``getattr``.
        index: Names of the entity attributes to insert, in column order. When
            ``None``, the attribute names of the first entity (its
            ``__dict__`` keys) are used.
        collection: A collection object, or a collection name that is turned
            into one with ``Collection(name)``.
        batch_size: Number of rows per ``collection.insert`` call.

    Returns:
        A ``MutationResult`` wrapping the accumulated ``insert_count`` and
        ``primary_keys``. ``delete_count``, ``upsert_count`` and ``timestamp``
        are taken from the last insert call, or are ``0`` when the iterable
        was empty and nothing was inserted.

    Raises:
        Exception: Any error raised while reading entities or inserting is
            logged and re-raised unchanged.

    Note:
        ``collection.load()`` is always called before leaving the function,
        whether or not the insertion succeeded.
    """
    # pylint: disable=import-outside-toplevel
    from towhee.utils.thirdparty.milvus_utils import Collection, MutationResult
    if isinstance(collection, str):
        collection = Collection(collection)
    primary_keys = []
    insert_count = 0
    # Both stay None until the first entity / first insert, so an empty
    # iterable no longer trips over unbound locals.
    data = None
    mr = None

    try:
        for entity in iterable:
            if data is None:
                if index is None:
                    index = tuple(entity.__dict__.keys())
                data = [[] for _ in index]

            for column, field in zip(data, index):
                column.append(getattr(entity, field))
            if len(data[0]) == batch_size:
                mr = collection.insert(data)
                data = [[] for _ in index]
                primary_keys += mr.primary_keys
                insert_count += mr.insert_count
                engine_log.info('Successfully inserted %d row data.', mr.insert_count)

        # Flush the rows left over from the last, incomplete batch.
        if data and data[0]:
            mr = collection.insert(data)
            primary_keys += mr.primary_keys
            insert_count += mr.insert_count
            engine_log.info('Successfully inserted %d row data.', mr.insert_count)

        summary = Entity(insert_count=insert_count,
                         primary_keys=primary_keys,
                         delete_count=mr.delete_count if mr is not None else 0,
                         upsert_count=mr.upsert_count if mr is not None else 0,
                         timestamp=mr.timestamp if mr is not None else 0)
        milvus_mr = MutationResult(summary)
    except Exception as e:  # pylint: disable=broad-except
        engine_log.error('Error when insert data to milvus with %s.', e)
        raise
    finally:
        collection.load()
    return milvus_mr


def _to_milvus_callback(self): # pragma: no cover
    """Build the callback that implements ``to_milvus`` for ``self``.

    The returned ``wrapper(name, index, *arg, **kws)`` accepts:

    - ``index``: a field name or a tuple of field names to insert; a single
      string is wrapped into a one-element tuple, ``None`` is passed through.
    - positional ``arg``: ``(collection,)`` or ``(collection, batch_size)``.
    - keyword ``collection``: overrides the positional collection.
    - keyword ``batch``: overrides the batch size (converted with ``int``);
      the batch size defaults to 1.
    - keyword ``stream``: when given and falsy, ``self.unstream()`` is
      inserted instead of ``self``.

    The wrapper inserts the data with ``_milvus_insert`` and returns the data
    collection that was inserted.

    Raises (from the wrapper):
        ValueError: If no collection was supplied, positionally or by keyword.
    """
    # pylint: disable=consider-using-get
    def wrapper(_: str, index, *arg, **kws):
        batch_size = 1
        collection = None
        if isinstance(index, str):
            index = (index,)

        if len(arg) == 1:
            collection, = arg
        elif len(arg) == 2:
            collection, batch_size = arg

        if 'collection' in kws:
            collection = kws['collection']
        if 'batch' in kws:
            batch_size = int(kws['batch'])

        # Fail early with a readable message instead of an UnboundLocalError
        # (or an attribute error deep inside the insert) further down.
        if collection is None:
            raise ValueError('to_milvus requires a collection: pass it positionally '
                             'or as the `collection` keyword argument.')

        dc_data = self
        if 'stream' in kws and not kws['stream']:
            dc_data = self.unstream()

        _milvus_insert(dc_data, index, collection, batch_size)
        return dc_data
    return wrapper


class MilvusMixin: # pragma: no cover
    """
    Mixins for Milvus, such as loading data into Milvus collections. Note that the Milvus collection is created before loading the data.
    Refer to https://milvus.io/docs/v2.0.x/create_collection.md.

    Args:
        collection (`Union[str, Collection]`):
            The collection name or pymilvus.Collection in Milvus.
        batch (`int`):
            The batch size to load into Milvus, defaults to 1.
        stream (`bool`, optional):
            Whether to use stream mode with `to_milvus`, defaults to True. When set to False,
            the DataCollection is unstreamed before the data is inserted.

    Returns:
        The DataCollection.

    Examples:

    .. note::
        The shape of embedding vector refer to https://towhee.io/image-embedding/timm. And the dimension of the "test" collection should
        be the same as it.

    >>> import towhee
    >>> mr = ( #doctest: +SKIP
    ...     towhee.glob['path']('./*.jpg')
    ...           .image_decode['path', 'img']()
    ...           .image_embedding.timm['img', 'vec'](model_name='resnet50')
    ...           .to_milvus['vec'](collection='test', batch=1000)
    ... )
    """

    def __init__(self):
        super().__init__()
        # Expose `to_milvus` through param_scope's dispatcher so that it can be
        # called as `.to_milvus[index](collection=..., batch=...)`.
        self.to_milvus = param_scope().dispatch(_to_milvus_callback(self))