# Source code for towhee.functional.mixins.dataset

# Copyright 2021 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Union
from pathlib import Path

from towhee.functional.entity import Entity


class DatasetMixin:
    """
    Mixin for dealing with dataset.

    Provides alternate constructors that build a collection from files, zip
    archives, cameras, JSON Lines and CSV files, plus a train/test split helper.
    Every constructor wraps its data with `cls(...)` and returns `.stream()`.
    """
    # pylint: disable=import-outside-toplevel

    @classmethod
    def from_glob(cls, *args):  # pragma: no cover
        """
        Generate a file list from one or more glob patterns.

        Args:
            *args (`str`):
                Glob patterns. The matches of each pattern are appended in
                argument order, each group in the order `glob.glob` returns.

        Returns:
            The stream produced by `cls(files).stream()`, where `files` holds
            the matched paths.
        """
        from glob import glob

        files = []
        for path in args:
            files.extend(glob(path))
        return cls(files).stream()

    @classmethod
    def from_zip(cls, url, pattern, mode='r'):  # pragma: no cover
        """
        Load the contents of the files stored in a zip archive.

        Args:
            url (`Union[str, Path]`):
                A URL or a local path leading to the zip archive.
            pattern (`str`):
                An `fnmatch`-style pattern; archive members whose names match
                it are read.
            mode (`str`):
                Mode passed to `ZipFile.open` for every matching member.

        Returns:
            A stream yielding the content (`bytes`) of every matching member.
        """
        import fnmatch
        from io import BytesIO
        from urllib.request import urlopen
        from zipfile import ZipFile

        from towhee.utils.repo_normalize import RepoNormalize

        def inner():
            if RepoNormalize(str(url)).url_valid():
                # Remote archive: download it fully into memory so that
                # ZipFile gets a seekable file object.
                with urlopen(url) as zip_file:
                    zip_path = BytesIO(zip_file.read())
            else:
                zip_path = str(Path(url).resolve())
            with ZipFile(zip_path, 'r') as zfile:
                for path in fnmatch.filter(zfile.namelist(), pattern):
                    with zfile.open(path, mode=mode) as f:
                        yield f.read()

        return cls(inner()).stream()

    @classmethod
    def from_camera(cls, device_id=0, limit=-1):  # pragma: no cover
        """
        Read images from a camera.

        Args:
            device_id (`int`):
                Device identifier passed to `cv2.VideoCapture`.
            limit (`int`):
                Number of frames to yield. A negative value never reaches zero,
                so frames are read without limit.

        Returns:
            A stream of the frames returned by `cap.read()`.
        """
        from towhee.utils.cv2_utils import cv2

        def inner():
            remaining = limit
            cap = cv2.VideoCapture(device_id)
            try:
                while remaining != 0:
                    retval, im = cap.read()
                    # A failed read is retried and does not consume the budget.
                    if retval:
                        yield im
                        remaining -= 1
            finally:
                # Release the device even if the consumer stops iterating early.
                cap.release()

        return cls(inner()).stream()

    @classmethod
    def from_json(cls, json_path: Union[str, Path], encoding: str = 'utf-8'):
        """
        Read a JSON Lines file (one JSON object per line).

        Args:
            json_path (`Union[str, Path]`):
                Path of the file to read.
            encoding (`str`):
                Text encoding used to open the file.

        Returns:
            A stream of `Entity(**obj)`, one per non-blank line. Each line is
            expected to decode to a JSON object with string keys.
        """
        import json

        def inner():
            with open(json_path, 'r', encoding=encoding) as f:
                for line in f:
                    # Skip blank lines (e.g. a trailing newline) instead of
                    # failing inside json.loads.
                    if not line.strip():
                        continue
                    yield Entity(**json.loads(line))

        return cls(inner()).stream()

    @classmethod
    def from_csv(cls, csv_path: Union[str, Path], encoding: str = 'utf-8-sig'):
        """
        Read a CSV file whose first row is the header.

        Args:
            csv_path (`Union[str, Path]`):
                Path of the file to read.
            encoding (`str`):
                Text encoding; the default `utf-8-sig` also strips a leading BOM.

        Returns:
            A stream of `Entity(**row)`, one per data row, keyed by the header.
        """
        import csv

        def inner():
            # newline='' lets the csv module handle quoted fields that contain
            # line breaks, as the csv documentation requires.
            with open(csv_path, 'r', encoding=encoding, newline='') as f:
                for row in csv.DictReader(f):
                    yield Entity(**row)

        return cls(inner()).stream()

    def random_sample(self):
        # core API already exists
        pass

    def filter_data(self):
        # core API already exists
        pass

    def split_train_test(self, size: Union[list, tuple] = (0.9, 0.1), **kws):
        """
        Split DataCollection to train and test data.

        Args:
            size (`Union[list, tuple]`):
                `size[0]` is the train size and `size[1]` the test size; both are
                forwarded to `train_test_split`. An immutable tuple default avoids
                sharing a mutable list between calls.
            **kws:
                Extra keyword arguments forwarded to
                `sklearn_utils.train_test_split` (e.g. `shuffle`).

        Returns:
            A `(train, test)` pair, each wrapped with `self._factory`.

        Examples:

        >>> from towhee.functional import DataCollection
        >>> dc = DataCollection.range(10)
        >>> train, test = dc.split_train_test(shuffle=False)
        >>> train.to_list()
        [0, 1, 2, 3, 4, 5, 6, 7, 8]
        >>> test.to_list()
        [9]
        """
        from towhee.utils import sklearn_utils

        train_size = size[0]
        test_size = size[1]
        train, test = sklearn_utils.train_test_split(
            self._iterable, train_size=train_size, test_size=test_size, **kws
        )
        return self._factory(train), self._factory(test)