# Copyright 2021 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Iterable, Iterator, Callable
import random
import reprlib
from towhee.hparam import param_scope, dynamic_dispatch
from towhee.functional.entity import Entity
from towhee.functional.option import Option, Some, Empty
from towhee.functional.mixins import AllMixins
[docs]class DataCollection(Iterable, AllMixins):
"""
DataCollection is a pythonic computation and processing framework
for unstructured data in machine learning and data science.

It allows a data scientist or researcher to assemble a data processing
pipeline, do their model work (embedding, transforming, or classification)
and apply it to the business (search, recommendation, or shopping)
with a method-chaining style API.

Examples:

1. Create a data collection from a list or an iterator:

>>> dc = DataCollection([0, 1, 2, 3, 4])
>>> dc = DataCollection(iter([0, 1, 2, 3, 4]))

2. Chaining function invocations makes your code clean and fluent:

>>> (
...     dc.map(lambda x: x+1)
...       .map(lambda x: x*2)
... ).to_list()
[2, 4, 6, 8, 10]

3. Multi-line closures are also supported via decorator syntax:

>>> dc = DataCollection([1,2,3,4])
>>> @dc.map
... def add1(x):
...     return x+1
>>> @add1.map
... def mul2(x):
...     return x *2
>>> mul2.to_list()
[4, 6, 8, 10]

>>> dc = DataCollection([1,2,3,4])
>>> @dc.filter
... def ge3(x):
...     return x>=3
>>> ge3.to_list()
[3, 4]

`DataCollection` is designed to behave like a python list or iterator.
Suppose you are running the following code:

.. code-block:: python
   :linenos:

   dc.map(stage1)
     .map(stage2)

1. `iterator` and `stream mode`: When a `DataCollection` object is created
   from an iterator, it behaves as a python iterator and performs
   `stream-wise` data processing:

   a. `DataCollection` takes one element from the input and applies
      `stage1` and `stage2` sequentially;
   b. Since `DataCollection` holds no data, indexing or shuffle is not
      supported;

2. `list` and `unstream mode`: If a `DataCollection` object is created from
   a list, it will hold all the input values and perform stage-wise
   computations:

   a. `stage2` will wait until all the calculations are done in `stage1`;
   b. A new `DataCollection` will be created to hold all the outputs for
      each stage. You can perform list operations on the resulting
      `DataCollection`;
"""
[docs] def __init__(self, iterable: Iterable) -> None:
"""Initializes a new DataCollection instance.
Args:
iterable (Iterable): input data
"""
super().__init__()
self._iterable = iterable
def __iter__(self):
if hasattr(self._iterable, 'iterrows'):
return (x[1] for x in self._iterable.iterrows())
return iter(self._iterable)
[docs] def stream(self):
"""
Create a stream data collection.
Examples:
1. Convert a data collection to streamed version
>>> dc = DataCollection([0, 1, 2, 3, 4])
>>> dc.is_stream
False
>>> dc = dc.stream()
>>> dc.is_stream
True
"""
# pylint: disable=protected-access
iterable = iter(self._iterable) if not self.is_stream else self._iterable
return self._factory(iterable, parent_stream = False)
[docs] def unstream(self):
"""
Create a unstream data collection.
Examples:
1. Create a unstream data collection
>>> dc = DataCollection(iter(range(5))).unstream()
>>> dc.is_stream
False
2. Convert a streamed data collection to unstream version
>>> dc = DataCollection(iter(range(5)))
>>> dc.is_stream
True
>>> dc = dc.unstream()
>>> dc.is_stream
False
"""
iterable = list(self._iterable) if self.is_stream else self._iterable
return self._factory(iterable, parent_stream = False)
@property
def is_stream(self):
"""
Check whether the data collection is stream or unstream.
Examples:
>>> dc = DataCollection([0,1,2,3,4])
>>> dc.is_stream
False
>>> result = dc.map(lambda x: x+1)
>>> result.is_stream
False
>>> result._iterable
[1, 2, 3, 4, 5]
>>> dc = DataCollection(iter(range(5)))
>>> dc.is_stream
True
>>> result = dc.map(lambda x: x+1)
>>> result.is_stream
True
>>> isinstance(result._iterable, Iterable)
True
"""
return isinstance(self._iterable, Iterator)
def _factory(self, iterable, parent_stream=True):
    """
    Build the DataCollection that holds the result of an operation.

    The new collection is constructed inside a `param_scope()` in which
    `data_collection.parent` is set to this collection.

    Args:
        iterable: An iterable object, the data to store in the new collection.
        parent_stream: When True, `iterable` is first converted to this
            collection's mode: turned into an iterator if this collection is
            streamed, or drained into a list if it is not. When not True,
            `iterable` is stored untouched.

    Returns:
        DataCollection: DataCollection encapsulating the iterable.
    """
    if parent_stream is True:
        want_stream = self.is_stream
        have_stream = isinstance(iterable, Iterator)
        if want_stream and not have_stream:
            iterable = iter(iterable)
        elif have_stream and not want_stream:
            iterable = list(iterable)

    with param_scope() as hp:
        hp().data_collection.parent = self
        return DataCollection(iterable)
def exception_safe(self):
    """
    Make the data collection exception-safe by wrapping elements with `Option`.

    Elements that already are an `Option` are left alone; every other
    element is boxed into `Some`.

    Examples:

    1. Exception breaks pipeline execution:

    >>> dc = DataCollection.range(5)
    >>> dc.map(lambda x: x / (0 if x == 3 else 2)).to_list()
    Traceback (most recent call last):
    ZeroDivisionError: division by zero

    2. Exception-safe execution

    >>> dc.exception_safe().map(lambda x: x / (0 if x == 3 else 2)).to_list()
    [Some(0.0), Some(0.5), Some(1.0), Empty(), Some(2.0)]

    >>> dc.exception_safe().map(lambda x: x / (0 if x == 3 else 2)).filter(lambda x: x < 1.5).to_list()
    [Some(0.0), Some(0.5), Some(1.0), Empty()]

    >>> dc.exception_safe().map(lambda x: x / (0 if x == 3 else 2)).filter(lambda x: x < 1.5, drop_empty=True).to_list()
    [Some(0.0), Some(0.5), Some(1.0)]
    """
    def box(item):
        if isinstance(item, Option):
            return item
        return Some(item)

    return self._factory(map(box, self._iterable))
[docs] def safe(self):
"""
Shortcut for `exception_safe`
"""
return self.exception_safe()
[docs] def select_from(self, other):
"""
Select data from dc with list(self).
Examples:
>>> dc1 = DataCollection([0.8, 0.9, 8.1, 9.2])
>>> dc2 = DataCollection([[1, 2, 0], [2, 3, 0]])
>>> dc3 = dc2.select_from(dc1)
>>> list(dc3)
[[0.9, 8.1, 0.8], [8.1, 9.2, 0.8]]
"""
def inner(x):
if isinstance(x, Iterable):
return [other[i] for i in x]
return other[x]
result = map(inner, self._iterable)
return self._factory(result)
def fill_empty(self, default: Any = None) -> 'DataCollection':
    """
    Unbox `Option` values and fill `Empty` with default values.

    Every element that is a `Some` is replaced by its payload; every other
    element is replaced by `default`.

    Args:
        default (Any): default value to replace empty values;

    Returns:
        DataCollection: data collection with empty values filled with `default`;

    Examples:

    >>> dc = DataCollection.range(5)
    >>> dc.safe().map(lambda x: x / (0 if x == 3 else 2)).fill_empty(-1.0).to_list()
    [0.0, 0.5, 1.0, -1.0, 2.0]
    """
    def unbox(item):
        if isinstance(item, Some):
            return item.get()
        return default

    return self._factory(map(unbox, self._iterable))
def drop_empty(self, callback: Callable = None) -> 'DataCollection':
    """
    Unbox `Option` values and drop `Empty`.

    Only the payloads of `Some` elements are kept. When a `callback` is
    given it is invoked with each `Empty` element as it is encountered.

    Args:
        callback (Callable): handler for empty values;

    Returns:
        DataCollection: data collection that drops empty values;

    Examples:

    >>> dc = DataCollection.range(5)
    >>> dc.safe().map(lambda x: x / (0 if x == 3 else 2)).drop_empty().to_list()
    [0.0, 0.5, 1.0, 2.0]

    Get inputs that cause exceptions:

    >>> exception_inputs = []
    >>> result = dc.safe().map(lambda x: x / (0 if x == 3 else 2)).drop_empty(lambda x: exception_inputs.append(x.get().value))
    >>> exception_inputs
    [3]
    """
    notify = callback is not None

    def unboxed(items):
        for item in items:
            if notify and isinstance(item, Empty):
                callback(item)
            if isinstance(item, Some):
                yield item.get()

    return self._factory(unboxed(self._iterable))
[docs] def map(self, *arg):
"""
Apply operator to data collection.
Args:
*arg (Callable): functions/operators to apply to data collection;
Returns:
DataCollection: data collections that contains computation results;
Examples:
>>> dc = DataCollection([1,2,3,4])
>>> dc.map(lambda x: x+1).map(lambda x: x*2).to_list()
[4, 6, 8, 10]
"""
# mmap
if len(arg) > 1:
return self.mmap(list(arg))
unary_op = arg[0]
# smap map for stateful operator
if hasattr(unary_op, 'is_stateful') and unary_op.is_stateful:
return self.smap(unary_op)
# pmap
if self.get_executor() is not None:
return self.pmap(unary_op)
if hasattr(self._iterable, 'map'):
return self._factory(self._iterable.map(unary_op))
if hasattr(self._iterable, 'apply') and hasattr(unary_op, '__dataframe_apply__'):
return self._factory(unary_op.__dataframe_apply__(self._iterable))
# map
def inner(x):
if isinstance(x, Option):
return x.map(unary_op)
else:
return unary_op(x)
result = map(inner, self._iterable)
return self._factory(result)
[docs] def zip(self, *others) -> 'DataCollection':
"""
Combine two data collections.
Args:
*others (DataCollection): other data collections;
Returns:
DataCollection: data collection with zipped values;
Examples:
>>> dc1 = DataCollection([1,2,3,4])
>>> dc2 = DataCollection([1,2,3,4]).map(lambda x: x+1)
>>> dc3 = dc1.zip(dc2)
>>> list(dc3)
[(1, 2), (2, 3), (3, 4), (4, 5)]
"""
return self._factory(zip(self, *others))
[docs] def filter(self, unary_op: Callable, drop_empty=False) -> 'DataCollection':
"""
Filter data collection with `unary_op`.
Args:
unary_op (`Callable`):
Callable to decide whether to filter the element;
drop_empty (`bool`):
Drop empty values. Defaults to False.
Returns:
DataCollection: filtered data collection
"""
# return filter(unary_op, self)
def inner(x):
if isinstance(x, Option):
if isinstance(x, Some):
return unary_op(x.get())
return not drop_empty
return unary_op(x)
if hasattr(self._iterable, 'filter'):
return self._factory(self._iterable.filter(unary_op))
if hasattr(self._iterable, 'apply') and hasattr(unary_op, '__dataframe_filter__'):
return DataCollection(unary_op.__dataframe_apply__(self._iterable))
return self._factory(filter(inner, self._iterable))
[docs] def sample(self, ratio=1.0) -> 'DataCollection':
"""
Sample the data collection.
Args:
ratio (float): sample ratio;
Returns:
DataCollection: sampled data collection;
Examples:
>>> dc = DataCollection(range(10000))
>>> result = dc.sample(0.1)
>>> ratio = len(result.to_list()) / 10000.
>>> 0.09 < ratio < 0.11
True
"""
return self._factory(filter(lambda _: random.random() < ratio, self))
[docs] @staticmethod
def range(*arg, **kws):
"""
Generate data collection with ranged numbers.
Examples:
>>> DataCollection.range(5).to_list()
[0, 1, 2, 3, 4]
"""
return DataCollection(range(*arg, **kws))
def batch(self, size, drop_tail=False, raw=True):
    """
    Create small batches from data collections.

    Plain elements are collected into lists of `size` items. `Entity`
    elements are merged instead: a batch is a single new `Entity` whose
    attributes hold the list of values taken from the batched entities.
    The input entities themselves are left untouched.

    Args:
        size (int): window size;
        drop_tail (bool): drop tailing windows that not full, defaults to False;
        raw (bool): when True every batch is emitted as built; when False a
            batch that is not a list (i.e. a merged `Entity`) is wrapped in a
            one-element list. Defaults to True.

    Returns:
        DataCollection of batched windows or batch raw data

    Examples:

    >>> dc = DataCollection(range(10))
    >>> [list(batch) for batch in dc.batch(2, raw=False)]
    [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]]

    >>> dc = DataCollection(range(10))
    >>> dc.batch(3)
    [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]

    >>> dc = DataCollection(range(10))
    >>> dc.batch(3, drop_tail=True)
    [[0, 1, 2], [3, 4, 5], [6, 7, 8]]

    >>> from towhee import Entity
    >>> dc = DataCollection([Entity(a=a, b=b) for a,b in zip(['abc', 'vdfvcd', 'cdsc'], [1,2,3])])
    >>> dc.batch(2)
    [<Entity dict_keys(['a', 'b'])>, <Entity dict_keys(['a', 'b'])>]
    """
    def inner():
        buff = []
        count = 0
        for ele in self._iterable:
            if isinstance(ele, Entity):
                if count == 0:
                    # Start the batch in a fresh Entity. Reusing `ele` as the
                    # buffer would overwrite the attributes of an element
                    # that still belongs to the source collection.
                    buff = Entity(**{key: [value] for key, value in ele.__dict__.items()})
                else:
                    for key, value in ele.__dict__.items():
                        buff.__dict__[key].append(value)
            else:
                buff.append(ele)
            # Count every element, including the first Entity of a batch, so
            # that the flush check below also fires when size == 1.
            count += 1

            if count == size:
                yield buff if raw or isinstance(buff, list) else [buff]
                buff = []
                count = 0

        # Emit the incomplete trailing batch unless asked to drop it.
        if not drop_tail and count > 0:
            yield buff if raw or isinstance(buff, list) else [buff]

    return self._factory(inner())
[docs] def rolling(self, size: int, drop_head=True, drop_tail=True):
"""
Create rolling windows from data collections.
Args:
size (int): window size;
drop_head (bool): drop headding windows that not full;
drop_tail (bool): drop tailing windows that not full;
Returns:
DataCollection: data collection of rolling windows;
Examples:
>>> dc = DataCollection(range(5))
>>> [list(batch) for batch in dc.rolling(3)]
[[0, 1, 2], [1, 2, 3], [2, 3, 4]]
>>> dc = DataCollection(range(5))
>>> [list(batch) for batch in dc.rolling(3, drop_head=False)]
[[0], [0, 1], [0, 1, 2], [1, 2, 3], [2, 3, 4]]
>>> dc = DataCollection(range(5))
>>> [list(batch) for batch in dc.rolling(3, drop_tail=False)]
[[0, 1, 2], [1, 2, 3], [2, 3, 4], [3, 4], [4]]
"""
def inner():
buff = []
for ele in self._iterable:
buff.append(ele)
if not drop_head or len(buff) == size:
yield buff.copy()
if len(buff) == size:
buff = buff[1:]
while not drop_tail and len(buff) > 0:
yield buff
buff = buff[1:]
return self._factory(inner())
[docs] def flatten(self) -> 'DataCollection':
"""
Flatten nested data collections.
Returns:
DataCollection: flattened data collection;
Examples:
>>> dc = DataCollection(range(10))
>>> nested_dc = dc.batch(2)
>>> nested_dc.flatten().to_list()
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
"""
def inner():
for ele in self._iterable:
if isinstance(ele, Iterable):
for nested_ele in iter(ele):
yield nested_ele
else:
yield ele
return self._factory(inner())
[docs] def shuffle(self) -> 'DataCollection':
"""
Shuffle an unstreamed data collection in place.
Returns:
DataCollection: shuffled data collection;
Examples:
1. Shuffle:
>>> dc = DataCollection([0, 1, 2, 3, 4])
>>> a = dc.shuffle()
>>> tuple(a) == tuple(range(5))
False
2. streamed data collection is not supported:
>>> dc = DataCollection([0, 1, 2, 3, 4]).stream()
>>> _ = dc.shuffle()
Traceback (most recent call last):
TypeError: shuffle is not supported for streamed data collection.
"""
if self.is_stream:
raise TypeError('shuffle is not supported for streamed data collection.')
iterable = random.sample(self._iterable, len(self._iterable))
return self._factory(iterable)
[docs] def __getattr__(self, name):
"""
Unknown method dispatcher.
When a unknown method is invoked on a `DataCollection` object,
the function call will be dispatched to a method resolver.
By registering function to the resolver, you are able to extend
`DataCollection`'s API at runtime without modifying its code.
Examples:
1. Define two operators:
>>> class my_add:
... def __init__(self, val):
... self.val = val
... def __call__(self, x):
... return x+self.val
>>> class my_mul:
... def __init__(self, val):
... self.val = val
... def __call__(self, x):
... return x*self.val
2. Register the operators to `DataCollection`'s execution context with `param_scope`:
>>> from towhee import param_scope
>>> with param_scope(dispatcher={
... 'add': my_add, # register `my_add` as `dc.add`
... 'mul': my_mul # register `my_mul` as `dc.mul`
... }):
... dc = DataCollection([1,2,3,4])
... dc.add(1).mul(2).to_list() # call registered operator
[4, 6, 8, 10]
"""
# pylint: disable=protected-access
with param_scope() as hp:
dispatcher = hp().dispatcher({})
@dynamic_dispatch
def wrapper(*arg, **kws):
with param_scope() as hp:
path = hp._name
index = hp._index
if self.get_backend() == 'ray':
return self.ray_resolve(dispatcher, path, index, *arg, **kws)
op = self.resolve(dispatcher, path, index, *arg, **kws)
return self.map(op)
return getattr(wrapper, name)
[docs] def __getitem__(self, index):
"""
Indexing for data collection.
Examples:
>>> dc = DataCollection([0, 1, 2, 3, 4])
>>> dc[0]
0
>>> dc.stream()[1]
Traceback (most recent call last):
TypeError: indexing is only supported for data collection created from list or pandas DataFrame.
"""
if not hasattr(self._iterable, '__getitem__'):
raise TypeError('indexing is only supported for '
'data collection created from list or pandas DataFrame.')
if isinstance(index, int):
return self._iterable[index]
return DataCollection(self._iterable[index])
[docs] def __setitem__(self, index, value):
"""
Indexing for data collection.
Examples:
>>> dc = DataCollection([0, 1, 2, 3, 4])
>>> dc[0]
0
>>> dc[0] = 5
>>> dc._iterable[0]
5
>>> dc.stream()[0]
Traceback (most recent call last):
TypeError: indexing is only supported for data collection created from list or pandas DataFrame.
"""
if not hasattr(self._iterable, '__setitem__'):
raise TypeError('indexing is only supported for '
'data collection created from list or pandas DataFrame.')
self._iterable[index] = value
[docs] def append(self, item: Any) -> 'DataCollection':
"""
Append item to data collection
Args:
item (Any): the item to append
Returns:
DataCollection: self
Examples:
>>> dc = DataCollection([0, 1, 2])
>>> dc.append(3).append(4)
[0, 1, 2, 3, 4]
"""
if hasattr(self._iterable, 'append'):
self._iterable.append(item)
return self
raise TypeError('appending is only supported for '
'data collection created from list.')
[docs] def __rshift__(self, unary_op):
"""
Chain the operators with `>>`.
Examples:
>>> dc = DataCollection([1,2,3,4])
>>> (dc
... >> (lambda x: x+1)
... >> (lambda x: x*2)
... ).to_list()
[4, 6, 8, 10]
"""
return self.map(unary_op)
def __or__(self, unary_op):
return self.map(unary_op)
[docs] def __add__(self, other):
"""
Concat two data collections:
Examples:
>>> (DataCollection.range(5) + DataCollection.range(5)).to_list()
[0, 1, 2, 3, 4, 0, 1, 2, 3, 4]
>>> (DataCollection.range(5) + DataCollection.range(5) + DataCollection.range(5)).to_list()
[0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 0, 1, 2, 3, 4]
"""
def inner():
for x in self:
yield x
for x in other:
yield x
return self._factory(inner())
[docs] def __repr__(self) -> str:
"""
Return a string representation for DataCollection.
Examples:
>>> DataCollection([1, 2, 3]).unstream()
[1, 2, 3]
>>> DataCollection([1, 2, 3]).stream() #doctest: +ELLIPSIS
<list_iterator object at...>
"""
if isinstance(self._iterable, list):
return reprlib.repr(self._iterable)
if hasattr(self._iterable, '__repr__'):
return repr(self._iterable)
return super().__repr__()
[docs] def head(self, n: int = 5):
"""
Get the first n lines of a DataCollection.
Args:
n (`int`):
The number of lines to print. Default value is 5.
Examples:
>>> DataCollection.range(10).head(3).to_list()
[0, 1, 2]
"""
def inner():
for i, x in enumerate(self._iterable):
if i >= n:
break
yield x
return self._factory(inner())
def run(self):
for _ in self._iterable:
pass
def to_list(self):
return self._iterable if isinstance(self._iterable, list) else list(self)
if __name__ == '__main__':
    # Executed as a script: run the doctest examples embedded in this module.
    import doctest

    doctest.testmod()