Source code for towhee.data.dataset.image_datasets

# coding=utf-8
# Copyright 2020-present the HuggingFace Inc. team and 2021 Zilliz.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from towhee.utils.thirdparty.pandas_utils import pandas as pd
import torch

from pandas import Series

# pylint: disable=ungrouped-imports
from towhee.utils.pil_utils import PILImage as Image
from torch.utils.data import Dataset
from torchvision import transforms
from typing import Tuple


[docs]class PyTorchImageDataset(Dataset): """ PyTorchImageDataset is a dataset class for training. Args: image_path (:obj:`str`): Path to the images for your dataset. label_file (:obj:`str`): Path to your label file. The label file should be a csv file. The columns in this file should be [image_name, category], 'image_name' is the path of your images, 'category' is the label of accordance image. For example: [image_name, dog] for one row. Note that the first row should be[image_name, category] data_transform (:obj:`Compose`): PyTorch transform of the input images. """
[docs] def __init__(self, image_path: str, label_file: str, data_transform: transforms.Compose = None): self.image_path = image_path self.label_file = label_file self.data_transform = data_transform df = pd.read_csv(self.label_file) image_names = Series.to_numpy(df['image_name']) for i in range(len(image_names)): if os.path.splitext(image_names[i])[1] == '': image_names[i] += '.jpg' images = image_names.tolist() self.images = [os.path.join(self.image_path, i) for i in images] categories = Series.to_numpy(df['category']) # Count the categories breed_set = set(categories) breed_list = list(breed_set) dic = {} for i in range(len(breed_list)): dic[breed_list[i]] = i self.labels = [dic[categories[i]] for i in range(len(categories))] self.num_classes = len(breed_list)
def __getitem__(self, index: int) -> Tuple[torch.Tensor, int]: label = self.labels[index] fn = self.images[index] img = Image.open(fn) if self.data_transform: img = self.data_transform(img) return (img, label) def __len__(self) -> int: return len(self.labels)