# Source code for towhee.functional.mixins.kv_storage

# Copyright 2021 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# pylint: disable=import-outside-toplevel
# pylint: disable=protected-access
# pylint: disable=consider-using-get
from io import BytesIO
from typing import Iterable

import numpy as np
from towhee.hparam import param_scope


def _insert_leveldb_callback(self):
    """Build the callback that writes a collection's items into LevelDB.

    Args:
        self: The collection to export. The callback reads its ``is_stream``
            flag, calls ``unstream()`` when that flag is set, and iterates
            over its ``_iterable`` attribute.

    Returns:
        A ``wrapper(_, index, *arg, **kws)`` callable closed over ``self``.
    """

    def wrapper(_, index, *arg, **kws):
        """Store one key/value attribute pair of every item in a LevelDB store.

        Args:
            _: Ignored.
            index: ``index[0][0]`` is the name of the item attribute holding
                the key (or an iterable of keys) and ``index[0][1]`` the name
                of the attribute holding the value.
            *arg: When exactly one positional argument is passed it is used
                as the database path.
            **kws: ``path`` takes precedence over the positional path.

        Returns:
            The collection that was written; the result of ``unstream()``
            when the source collection was a stream, otherwise ``self``.
        """
        from towhee.utils.thirdparty.plyvel_utils import plyvel

        # ``arg`` is always a tuple, so only its length needs checking.
        path = arg[0] if len(arg) == 1 else None
        if 'path' in kws:
            path = kws['path']

        db = plyvel.DB(path, create_if_missing=True)
        try:
            dc_data = self.unstream() if self.is_stream else self

            for item in dc_data._iterable:
                keys = getattr(item, index[0][0])
                val = getattr(item, index[0][1])

                # A string is iterable but is meant as one key, not a
                # sequence of single-character keys.
                if isinstance(keys, str) or not isinstance(keys, Iterable):
                    keys = [keys]

                if isinstance(val, np.ndarray):
                    # Arrays are serialised in .npy format so the reader can
                    # rebuild them with np.load.
                    buffer = BytesIO()
                    np.save(buffer, val, allow_pickle=True)
                    payload = buffer.getvalue()
                else:
                    payload = str(val).encode('utf-8')

                # Every key of the item maps to the same encoded value.
                for k in keys:
                    db.put(str(k).encode('utf-8'), payload)
        finally:
            # Close the handle even when an item fails, so the database is
            # not left open after an error.
            db.close()

        return dc_data

    return wrapper


def _from_leveldb_callback(self):
    """Build the callback that fills item attributes from a LevelDB store.

    Args:
        self: The collection to enrich. The callback reads its ``is_stream``
            flag, calls ``unstream()`` when that flag is set, and iterates
            over its ``_iterable`` attribute.

    Returns:
        A ``wrapper(_, index, *arg, **kws)`` callable closed over ``self``.
    """

    def wrapper(_, index, *arg, **kws):
        """Look up each item's key(s) in LevelDB and attach the value(s).

        Args:
            _: Ignored.
            index: ``index[0]`` is the name of the item attribute holding the
                key (or an iterable of keys); ``index[1]`` is the name of the
                attribute that receives the looked-up value(s).
            *arg: ``(path,)`` or ``(path, is_ndarray)``; any other number of
                positional arguments is ignored.
            **kws: ``path`` and ``is_ndarray`` take precedence over the
                positional values.

        Returns:
            The collection that was processed; the result of ``unstream()``
            when the source collection was a stream, otherwise ``self``.
            A single key yields a single value, an iterable of keys yields a
            list of values. Values are ``str`` unless ``is_ndarray`` is
            truthy, in which case they are loaded with ``np.load``.
        """
        from towhee.utils.thirdparty.plyvel_utils import plyvel

        path = None
        is_ndarray = False

        # ``arg`` is always a tuple, so only its length needs checking.
        if len(arg) in (1, 2):
            path = arg[0]
        if len(arg) == 2:
            is_ndarray = arg[1]

        if 'path' in kws:
            path = kws['path']
        if 'is_ndarray' in kws:
            is_ndarray = kws['is_ndarray']

        db = plyvel.DB(path, create_if_missing=True)

        def fetch(k):
            """Return the decoded value stored under ``k``."""
            raw = db.get(str(k).encode('utf-8'))
            if not is_ndarray:
                return raw.decode('utf-8')
            # NOTE(security): allow_pickle=True can execute arbitrary code
            # while loading; only read databases from a trusted source.
            return np.load(BytesIO(raw), allow_pickle=True)

        try:
            dc_data = self.unstream() if self.is_stream else self

            for item in dc_data._iterable:
                key = getattr(item, index[0])
                # A string is iterable but is meant as one key, not a
                # sequence of single-character keys.
                if isinstance(key, str) or not isinstance(key, Iterable):
                    # The decoded text is what gets stored, matching what the
                    # multi-key branch produces.
                    setattr(item, index[1], fetch(key))
                else:
                    setattr(item, index[1], [fetch(k) for k in key])
        finally:
            # Close the handle even when a lookup or decode fails.
            db.close()

        return dc_data

    return wrapper


class KVStorageMixin:  # pragma: no cover
    """
    Mixin for kv storage.

    Each instance gets two attributes, ``insert_leveldb`` and
    ``from_leveldb``, created by passing the module's LevelDB callbacks
    (bound to the instance) to ``param_scope().dispatch``.
    """

    def __init__(self):
        """Continue the cooperative ``__init__`` chain and attach the LevelDB callbacks."""
        super().__init__()
        # Each callback closes over this instance so it can reach the
        # collection's data when invoked.
        self.insert_leveldb = param_scope().dispatch(_insert_leveldb_callback(self))
        self.from_leveldb = param_scope().dispatch(_from_leveldb_callback(self))