# Copyright 2021 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
from typing import Dict, Any, Optional, Set, Union, List
from towhee.functional.entity import Entity
from towhee.hparam import dynamic_dispatch, param_scope
# pylint: disable=protected-access
class EntityMixin:
    """
    Mixin to help deal with Entity.

    The methods below rely on the host class providing ``_iterable``,
    ``_factory`` and ``map``.

    Examples:

    1. define an operator with `register` decorator

    >>> from towhee import register
    >>> from towhee import DataCollection
    >>> @register
    ... def add_1(x):
    ...     return x+1

    2. apply the operator to named field of entity and save result to another named field

    >>> (
    ...     DataCollection([dict(a=1, b=2), dict(a=2, b=3)])
    ...         .as_entity()
    ...         .add_1['a', 'c']() # <-- use field `a` as input and field `c` as output
    ...         .as_str()
    ...         .to_list()
    ... )
    ['{"a": 1, "b": 2, "c": 2}', '{"a": 2, "b": 3, "c": 3}']
    """

    def __init__(self):
        # pylint: disable=useless-super-delegation
        super().__init__()

    @property
    def select(self):
        """
        Select the entity on the specified fields.

        Field names are taken from the dispatch scope's ``_index`` (the
        subscript syntax shown below); when that is None, the positional
        arguments of the call are used instead. With no fields at all the
        entities are passed through unchanged.

        Examples:

        1. Select the entity on one specified field:

        >>> from towhee import Entity
        >>> from towhee import DataCollection
        >>> dc = DataCollection([Entity(a=i, b=i, c=i) for i in range(2)])
        >>> dc.select['a']().to_list()
        [<Entity dict_keys(['a'])>, <Entity dict_keys(['a'])>]

        2. Select multiple fields and unpack the entity:

        >>> (
        ...     DataCollection([Entity(a=i, b=i, c=i) for i in range(5)])
        ...         .select['a', 'b']()
        ...         .as_raw()
        ...         .to_list()
        ... )
        [(0, 0), (1, 1), (2, 2), (3, 3), (4, 4)]

        3. Another field selection syntax (not suggested):

        >>> (
        ...     DataCollection([Entity(a=i, b=i, c=i) for i in range(5)])
        ...         .select('a', 'b')
        ...         .as_raw()
        ...         .to_list()
        ... )
        [(0, 0), (1, 1), (2, 2), (3, 3), (4, 4)]
        """
        @dynamic_dispatch
        def selector(*arg):
            index = param_scope()._index
            # A single field name is normalised to a one-element tuple.
            if isinstance(index, str):
                index = (index, )
            # Fall back to positional arguments when no index was supplied.
            if index is None and arg:
                index = arg

            def inner(entity: 'Entity'):
                if index is None:
                    return entity
                return Entity(**{col: getattr(entity, col) for col in index})

            return self.map(inner)

        return selector

    # pylint: disable=invalid-name
    def fill_entity(self,
                    _DefaultKVs: Optional[Dict[str, Any]] = None,
                    _ReplaceNoneValue: bool = False,
                    **kws):
        """
        When DataCollection's iterable consists of Entities and some indexes are missing,
        fill default value for those indexes.

        Args:
            _DefaultKVs (`Dict[str, Any]`):
                The key-value pairs stored in a dict. They take precedence over
                the same keys given as keyword arguments.
            _ReplaceNoneValue (`bool`):
                Whether to replace None in Entity's value. A None value is
                replaced by the default for that index, or by 0 when the
                default itself is None.
            **kws:
                Default values given as keyword arguments.

        Returns:
            A new collection built with ``_factory``; the entities themselves
            are updated in place.
        """
        if _DefaultKVs:
            kws.update(_DefaultKVs)

        def fill(entity: 'Entity'):
            for key, default in kws.items():
                if not hasattr(entity, key):
                    setattr(entity, key, default)
                # Inspect the entity's current value (not the default) so that
                # existing non-None values are never overwritten.
                if _ReplaceNoneValue and getattr(entity, key) is None:
                    setattr(entity, key, 0 if default is None else default)
            return entity

        return self._factory(map(fill, self._iterable))

    def as_entity(self, schema: Optional[List[str]] = None):
        """
        Convert elements into Entities.

        Args:
            schema (Optional[List[str]]):
                schema contains field names. When omitted, every element is
                unpacked as keyword arguments into an Entity.

        Examples:

        1. convert dicts into entities:

        >>> from towhee import DataCollection
        >>> (
        ...     DataCollection([dict(a=1, b=2), dict(a=2, b=3)])
        ...         .as_entity()
        ...         .as_str()
        ...         .to_list()
        ... )
        ['{"a": 1, "b": 2}', '{"a": 2, "b": 3}']

        2. convert tuples into entities:

        >>> from towhee import DataCollection
        >>> (
        ...     DataCollection([(1, 2), (2, 3)])
        ...         .as_entity(schema=['a', 'b'])
        ...         .as_str()
        ...         .to_list()
        ... )
        ['{"a": 1, "b": 2}', '{"a": 2, "b": 3}']

        3. convert single value into entities:

        >>> from towhee import DataCollection
        >>> (
        ...     DataCollection([1, 2])
        ...         .as_entity(schema=['a'])
        ...         .as_str()
        ...         .to_list()
        ... )
        ['{"a": 1}', '{"a": 2}']
        """
        if schema is None:
            def inner(x):
                return Entity(**x)
        else:
            # With a one-field schema each element is the value itself, so it
            # is wrapped in a tuple before being zipped with the schema.
            single_field = len(schema) == 1

            def inner(x):
                values = (x, ) if single_field else x
                return Entity(**dict(zip(schema, values)))

        return self._factory(map(inner, self._iterable))

    def parse_json(self):
        """
        Parse string to entities.

        Each element is decoded with ``json.loads`` and the result is unpacked
        as keyword arguments into an Entity.

        Examples:

        >>> from towhee import DataCollection
        >>> dc = (
        ...     DataCollection(['{"x": 1}'])
        ...         .parse_json()
        ... )
        >>> dc[0].x
        1
        """
        def inner(x):
            return Entity(**json.loads(x))

        return self.map(inner)

    def as_json(self):
        """
        Convert entities to json.

        Each entity's ``__dict__`` is serialised with ``json.dumps``.

        Examples:

        >>> from towhee import DataCollection, Entity
        >>> (
        ...     DataCollection([Entity(x=1)])
        ...         .as_json()
        ... )
        ['{"x": 1}']
        """
        def inner(x):
            return json.dumps(x.__dict__)

        return self.map(inner)

    def as_raw(self):
        """
        Convert entities into raw python values.

        An entity with exactly one field is unpacked to that bare value; any
        other entity becomes a tuple of its field values.

        Examples:

        1. unpack multiple values from entities:

        >>> from towhee import DataCollection
        >>> (
        ...     DataCollection([(1, 2), (2, 3)])
        ...         .as_entity(schema=['a', 'b'])
        ...         .as_raw()
        ...         .to_list()
        ... )
        [(1, 2), (2, 3)]

        2. unpack single value from entities:

        >>> (
        ...     DataCollection([1, 2])
        ...         .as_entity(schema=['a'])
        ...         .as_raw()
        ...         .to_list()
        ... )
        [1, 2]
        """
        def inner(x):
            fields = x.__dict__
            if len(fields) == 1:
                return next(iter(fields.values()))
            return tuple(getattr(x, name) for name in fields)

        return self.map(inner)

    def replace(self, **kws):
        """
        Replace specific attributes with given values.

        Args:
            **kws:
                Maps an attribute name to a dict of ``{old_value: new_value}``.
                An attribute whose current value is a key of that dict is set
                to the corresponding new value; other values are left alone.

        Returns:
            A new collection built with ``_factory``; the entities themselves
            are updated in place.
        """
        def inner(entity: 'Entity'):
            for index, convert_dict in kws.items():
                origin_value = getattr(entity, index)
                if origin_value in convert_dict:
                    setattr(entity, index, convert_dict[origin_value])
            return entity

        return self._factory(map(inner, self._iterable))

    def dropna(self, na: Set[str] = {'', None}) -> Union[bool, 'DataCollection']:  # pylint: disable=dangerous-default-value
        """
        Drop entities that contain some specific values.

        Args:
            na (`Set[str]`):
                Those entities contain values in na will be dropped. The
                default set is only read, never modified.

        Returns:
            A new collection built with ``_factory`` holding the entities
            that contain none of the values in ``na``.
        """
        def is_na(val) -> bool:
            try:
                return val in na
            except TypeError:
                # Membership in a set hashes the value; an unhashable value
                # (list, array, ...) cannot be a member, so it is not NA.
                return False

        def inner(entity: 'Entity'):
            return not any(is_na(val) for val in entity.__dict__.values())

        return self._factory(filter(inner, self._iterable))

    def rename(self, column: Dict[str, str]):
        """
        Rename a column in DataCollection.

        All renames are applied simultaneously, so swapping two names
        (``{'a': 'b', 'b': 'a'}``) keeps both values.

        Args:
            column (`Dict[str, str]`):
                The columns to rename and their corresponding new name.

        Returns:
            A new collection built with ``_factory``; the entities themselves
            are updated in place.

        Raises:
            KeyError: when an entity lacks one of the columns to rename. The
                entity is left untouched in that case.
        """
        def inner(x):
            fields = x.__dict__
            # Read every value before touching the dict so that a missing key
            # fails early and overlapping old/new names cannot clobber data.
            moved = {new: fields[old] for old, new in column.items()}
            for old in column:
                del fields[old]
            fields.update(moved)
            return x

        return self._factory(map(inner, self._iterable))

    @property
    def df(self):
        """
        Return the underlying pandas DataFrame.

        Raises:
            TypeError: when ``_iterable`` is not a pandas DataFrame.
        """
        # pylint: disable=import-outside-toplevel
        import pandas as pd
        if isinstance(self._iterable, pd.DataFrame):
            return self._iterable
        raise TypeError(
            'data collection is not created from pandas DataFrame.')